Markov: start new chains if the existing one is too short

This commit is contained in:
Brian S. Stephan 2012-07-28 13:55:54 -05:00
parent ced165cff4
commit a6f4827a41
1 changed files with 35 additions and 14 deletions

View File

@ -386,20 +386,23 @@ class Markov(Module):
# generate new word
target_word = words[random.randint(0, len(words)-1)]
else:
if len(gen_words) < min_size and len(filter(lambda a: a != self.stop, key_hits)) > 0:
found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
gen_words.append(found_word)
self.log.debug("added '{0:s}' to gen_words".format(found_word))
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
elif len(key_hits) <= 0:
self.log.debug("no hits found, appending stop")
gen_words.append(self.stop)
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
else:
found_word = random.choice(key_hits)
gen_words.append(found_word)
self.log.debug("added '{0:s}' to gen_words".format(found_word))
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
# tack a new chain onto the list and resume if we're too short
if gen_words[-1] == self.stop and len(gen_words) < min_size + 2:
self.log.debug("starting a new chain on end of old one")
# chop off the end text, if it was the keyword indicating an end of chain
if gen_words[-1] == self.stop:
gen_words = gen_words[:-1]
# new word 1
key_hits = self._retrieve_chains_for_key(self.start1, self.start2, context_id)
gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
# new word 2
key_hits = self._retrieve_chains_for_key(self.start2, found_word, context_id)
gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
# chop off the seed data at the start
gen_words = gen_words[2:]
@ -410,6 +413,24 @@ class Markov(Module):
return ' '.join(gen_words)
def _get_suitable_word_from_choices(self, key_hits, gen_words, min_size):
"""Given an existing set of words, and key hits, pick one."""
if len(gen_words) < min_size + 2 and len(filter(lambda a: a != self.stop, key_hits)) > 0:
found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
self.log.debug("added '{0:s}' to gen_words".format(found_word))
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
return found_word
elif len(key_hits) <= 0:
self.log.debug("no hits found, appending stop")
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
return self.stop
else:
found_word = random.choice(key_hits)
self.log.debug("added '{0:s}' to gen_words".format(found_word))
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
return found_word
def _retrieve_chains_for_key(self, k1, k2, context_id):
"""Get the value(s) for a given key (a pair of strings)."""