diff --git a/modules/Markov.py b/modules/Markov.py index a034851..f8293d2 100644 --- a/modules/Markov.py +++ b/modules/Markov.py @@ -418,6 +418,30 @@ class Markov(Module): # walk a chain, randomly, building the list of words while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop: + # first, see if we have an empty response and a target word. + # we'll just pick a word and work backwards + if gen_words[-1] == self.start2 and target_word != '': + working_backwards = [] + working_backwards.append(target_word) + # generate new word + found_word = '' + target_word = words[random.randint(0, len(words)-1)] + # work backwards until we randomly bump into a start + while True: + key_hits = self._retrieve_k2_for_value(working_backwards[0], context) + if target_word in key_hits: + found_word = target_word + # generate new word + target_word = words[random.randint(0, len(words)-1)] + else: + found_word = random.choice(filter(lambda a: a != self.stop, key_hits)) + + if found_word == self.start2: + gen_words = gen_words + working_backwards + break + else: + working_backwards.insert(0, found_word) + key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context) # use the chain that includes the target word, if it is found if target_word != '' and target_word in key_hits: @@ -469,6 +493,26 @@ class Markov(Module): print('sqlite error: ' + str(e)) raise + def _retrieve_k2_for_value(self, v, context): + """Get the value(s) for a given key (a pair of strings).""" + + values = [] + try: + db = self.get_db() + query = 'SELECT k2 FROM markov_chain WHERE v = ? AND (context = ? OR context IS NULL)' + cursor = db.execute(query, (v,context)) + results = cursor.fetchall() + + for result in results: + values.append(result['k2']) + + db.close() + return values + except sqlite3.Error as e: + db.close() + print('sqlite error: ' + str(e)) + raise + def _get_chatter_targets(self): """Get all possible chatter targets."""