Use multiprocessing

This commit is contained in:
2022-01-31 20:54:06 -05:00
parent bcc86e18f9
commit 6c1055c49d

View File

@@ -2,8 +2,11 @@ import copy
import logging
import random
from dataclasses import dataclass, field
from enum import Enum, auto
from itertools import repeat
from multiprocessing import Pool
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from pathlib import Path
import string
@@ -41,13 +44,12 @@ class DictionaryProvider:
def __init__(self, limit: Optional[int]=None):
self.limit = limit
self.dictionary = self._get_dictionary(self._word_filter(WORD_LENGTH))
if limit:
self.dictionary = random.sample(self.dictionary, limit)
def provide(self, ) -> List[Word]:
words = self._get_dictionary(self._word_filter(WORD_LENGTH))
if self.limit:
return words[:self.limit]
else:
return words
def provide(self) -> List[Word]:
return self.dictionary.copy()
def _get_dictionary(self, filter_fun: Callable[str, bool] = None) -> List[Word]:
l = []
@@ -159,8 +161,20 @@ class WordleMoveCalculator:
word for word in self.dictionary if self.word_playable.is_word_playable(state, word)
]
move_rankings = [] # type: List[Tuple[Word, int]]
for guess in words:
with Pool(4) as p:
move_rankings = p.starmap(self.calc_guess, zip(repeat(state, len(words)), words))
move_rankings = sorted(move_rankings, key=lambda x: x[1], reverse=True)
if limit:
move_rankings = move_rankings[0:limit]
return move_rankings
def calc_guess(self, state: WordleState, guess: Word):
# TODO: Don't recompute me, already determined in `calculate`
words = [
word for word in self.dictionary if self.word_playable.is_word_playable(state, word)
]
word_total = 0
for answer in words:
temp_state = WordleState(
@@ -173,19 +187,15 @@ class WordleMoveCalculator:
word_total += self.state_evaluator.evaluate(temp_state)
word_score = word_total / len(words)
move_rankings.append((guess, word_score))
return (guess, word_score)
move_rankings = sorted(move_rankings, key=lambda x: x[1], reverse=True)
if limit:
move_rankings = move_rankings[0:limit]
return move_rankings
def main():
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)
dictionary_provider = DictionaryProvider(limit=100)
dictionary_provider = DictionaryProvider(limit=200)
word_playable = WordleWordPlayable()
state_evaluator = WordleRemainingWordsStateEvaluator(dictionary_provider, word_playable)
move_processor = WordleMoveProcessor()