From 390e925360b708cc1d70f7cbd4fe4fb9608a4ee4 Mon Sep 17 00:00:00 2001 From: "Brian S. Stephan" Date: Sun, 29 Jul 2012 09:39:07 -0500 Subject: [PATCH] Markov: rewrite backwards/forwards chainer this clarifies a bunch of sections and seems slightly faster target_word (which would be randomly selected from the input every time) is replaced with seed_words, a shuffled list from the input. this is to eliminate accidental reuse of the target word, which would result in chains like X X X X X X X X X X X X X because it'd keep targeting X the rest of this is mostly just debug cleanup, though to simplify the backwards code it only tries to find one target word --- modules/Markov.py | 118 +++++++++++++++++++++++++--------------------- 1 file changed, 63 insertions(+), 55 deletions(-) diff --git a/modules/Markov.py b/modules/Markov.py index 84d086e..5b5f42b 100644 --- a/modules/Markov.py +++ b/modules/Markov.py @@ -327,12 +327,15 @@ class Markov(Module): if (min_size > 20): raise Exception("min_size is too large: %d" % min_size) - words = [] - target_word = '' - # get a random word from the input - if line != '': - words = line.split() - target_word = words[random.randint(0, len(words)-1)] + seed_words = [] + # shuffle the words in the input + seed_words = line.split() + random.shuffle(seed_words) + self.log.debug("seed words: {0:s}".format(seed_words)) + + # hit to generate a new seed word immediately if possible + seed_word = None + hit_word = None context_id = self._get_context_id_for_target(target) @@ -341,72 +344,80 @@ class Markov(Module): # walk a chain, randomly, building the list of words while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop: + # pick a word from the shuffled seed words, if we need a new one + if seed_word == hit_word: + if len(seed_words) > 0: + seed_word = seed_words.pop() + self.log.debug("picked new seed word: {0:s}".format(seed_word)) + else: + seed_word = None + self.log.debug("ran out of seed words") + # first, see if we have an empty response and a target word. - # we'll just pick a word and work backwards - if gen_words[-1] == self.start2 and target_word != '': + # if so, work backwards, otherwise forwards + if gen_words[-1] == self.start2 and seed_word is not None: + # work backwards + working_backwards = [] - key_hits = self._retrieve_k2_for_value(target_word, context_id) + key_hits = self._retrieve_k2_for_value(seed_word, context_id) if len(key_hits): - working_backwards.append(target_word) - self.log.debug("added '{0:s}' to working_backwards".format(target_word)) + found_word = seed_word + working_backwards.append(found_word) + self.log.debug("started working backwards with: {0:s}".format(found_word)) self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards))) - # generate new word - found_word = '' - target_word = words[random.randint(0, len(words)-1)] - # work backwards until we randomly bump into a start + + # now work backwards until we randomly bump into a start while True: key_hits = self._retrieve_k2_for_value(working_backwards[0], context_id) - if target_word in key_hits: - found_word = target_word - # generate new word - if len(filter(lambda a: a != target_word, words)) > 1 and False: - # if we have more than one target word, get a new one (otherwise give up) - target_word = random.choice(filter(lambda a: a != target_word, words)) - else: - target_word = '' - else: - found_word = random.choice(filter(lambda a: a != self.stop, key_hits)) - - if found_word == self.start2 or len(working_backwards) >= max_size + 2: - self.log.debug("done working backwards") + if self.start2 in key_hits: + self.log.debug("working backwards found start, finishing") gen_words = gen_words + working_backwards self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) break else: + found_word = random.choice(filter(lambda a: a != self.stop, key_hits)) working_backwards.insert(0, found_word) self.log.debug("added '{0:s}' to working_backwards".format(found_word)) self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards))) - key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context_id) - # use the chain that includes the target word, if it is found - if target_word != '' and target_word in key_hits: - gen_words.append(target_word) - self.log.debug("added target word '{0:s}' to gen_words".format(target_word)) - self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) - # generate new word - target_word = words[random.randint(0, len(words)-1)] + hit_word = seed_word else: - gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size)) + # work forwards - # tack a new chain onto the list and resume if we're too short - if gen_words[-1] == self.stop and len(gen_words) < min_size + 2: - self.log.debug("starting a new chain on end of old one") + key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context_id) - # chop off the end text, if it was the keyword indicating an end of chain - if gen_words[-1] == self.stop: - gen_words = gen_words[:-1] - - # new word 1 - key_hits = self._retrieve_chains_for_key(self.start1, self.start2, context_id) - gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size)) - - if gen_words[-1] == self.stop: - break + # use the chain that includes the target word, if it is found + if seed_word is not None and seed_word in key_hits: + hit_word = seed_word + self.log.debug("added seed word '{0:s}' to gen_words".format(hit_word)) else: - # new word 2 - key_hits = self._retrieve_chains_for_key(self.start2, gen_words[-1], context_id) + hit_word = self._get_suitable_word_from_choices(key_hits, gen_words, min_size) + self.log.debug("added random word '{0:s}' to gen_words".format(hit_word)) + + # from either method, append result + gen_words.append(hit_word) + self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) + + # tack a new chain onto the list and resume if we're too short + if gen_words[-1] == self.stop and len(gen_words) < min_size + 2: + self.log.debug("starting a new chain on end of old one") + + # chop off the end text, if it was the keyword indicating an end of chain + if gen_words[-1] == self.stop: + gen_words = gen_words[:-1] + + # new word 1 + key_hits = self._retrieve_chains_for_key(self.start1, self.start2, context_id) gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size)) + # the database is probably empty if we got a stop from this + if gen_words[-1] == self.stop: + break + else: + # new word 2 + key_hits = self._retrieve_chains_for_key(self.start2, gen_words[-1], context_id) + gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size)) + # chop off the seed data at the start gen_words = gen_words[2:] @@ -425,16 +436,13 @@ class Markov(Module): if len(gen_words) < min_size + 2 and len(filter(lambda a: a != self.stop, key_hits)) > 0: found_word = random.choice(filter(lambda a: a != self.stop, key_hits)) self.log.debug("added '{0:s}' to gen_words".format(found_word)) - self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) return found_word elif len(key_hits) <= 0: self.log.debug("no hits found, appending stop") - self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) return self.stop else: found_word = random.choice(key_hits) self.log.debug("added '{0:s}' to gen_words".format(found_word)) - self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) return found_word def _retrieve_chains_for_key(self, k1, k2, context_id):