diff --git a/modules/Markov.py b/modules/Markov.py
index 2c6e2bd..7721e66 100644
--- a/modules/Markov.py
+++ b/modules/Markov.py
@@ -386,20 +386,23 @@ class Markov(Module):
                 # generate new word
                 target_word = words[random.randint(0, len(words)-1)]
             else:
-                if len(gen_words) < min_size and len(filter(lambda a: a != self.stop, key_hits)) > 0:
-                    found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
-                    gen_words.append(found_word)
-                    self.log.debug("added '{0:s}' to gen_words".format(found_word))
-                    self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
-                elif len(key_hits) <= 0:
-                    self.log.debug("no hits found, appending stop")
-                    gen_words.append(self.stop)
-                    self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
-                else:
-                    found_word = random.choice(key_hits)
-                    gen_words.append(found_word)
-                    self.log.debug("added '{0:s}' to gen_words".format(found_word))
-                    self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
+                gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
+
+            # tack a new chain onto the list and resume if we're too short
+            if gen_words[-1] == self.stop and len(gen_words) < min_size + 2:
+                self.log.debug("starting a new chain on end of old one")
+
+                # chop off the end text, if it was the keyword indicating an end of chain
+                if gen_words[-1] == self.stop:
+                    gen_words = gen_words[:-1]
+
+                # new word 1
+                key_hits = self._retrieve_chains_for_key(self.start1, self.start2, context_id)
+                gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
+
+                # new word 2: key on the word just appended above, since found_word is no longer assigned in this branch
+                key_hits = self._retrieve_chains_for_key(self.start2, gen_words[-1], context_id)
+                gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
 
         # chop off the seed data at the start
         gen_words = gen_words[2:]
@@ -410,6 +413,34 @@ class Markov(Module):
 
         return ' '.join(gen_words)
 
+    def _get_suitable_word_from_choices(self, key_hits, gen_words, min_size):
+        """Given an existing set of words, and key hits, pick one.
+
+        While gen_words is shorter than min_size + 2, a hit other than the
+        stop marker is preferred when one exists. With no hits at all the
+        stop marker is returned; otherwise any hit may be picked. The word
+        is returned rather than appended, so gen_words is left untouched.
+        """
+
+        # build the stop-free list once, as a real list: filter() is lazy on
+        # Python 3, where len() and random.choice() would reject it
+        non_stop_hits = [hit for hit in key_hits if hit != self.stop]
+
+        if len(gen_words) < min_size + 2 and non_stop_hits:
+            found_word = random.choice(non_stop_hits)
+            self.log.debug("added '{0:s}' to gen_words".format(found_word))
+            self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
+            return found_word
+        elif len(key_hits) <= 0:
+            self.log.debug("no hits found, appending stop")
+            self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
+            return self.stop
+        else:
+            found_word = random.choice(key_hits)
+            self.log.debug("added '{0:s}' to gen_words".format(found_word))
+            self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
+            return found_word
+
     def _retrieve_chains_for_key(self, k1, k2, context_id):
         """Get the value(s) for a given key (a pair of strings)."""
 