Markov: rewrite backwards/forwards chainer

This clarifies a bunch of sections and seems slightly faster.

target_word (which was randomly selected from the input every
time) is replaced with seed_words, a shuffled list of the input words.
This eliminates accidental reuse of the target word, which
would result in chains like X X X X X X X X X X X X X because
the chainer would keep targeting X.

The rest of this is mostly just debug cleanup, though to simplify
the backwards code, it now only tries to find one target word.
This commit is contained in:
Brian S. Stephan 2012-07-29 09:39:07 -05:00
parent 9ca37c3990
commit 390e925360
1 changed files with 63 additions and 55 deletions

View File

@ -327,12 +327,15 @@ class Markov(Module):
if (min_size > 20):
raise Exception("min_size is too large: %d" % min_size)
words = []
target_word = ''
# get a random word from the input
if line != '':
words = line.split()
target_word = words[random.randint(0, len(words)-1)]
seed_words = []
# shuffle the words in the input
seed_words = line.split()
random.shuffle(seed_words)
self.log.debug("seed words: {0:s}".format(seed_words))
# hit to generate a new seed word immediately if possible
seed_word = None
hit_word = None
context_id = self._get_context_id_for_target(target)
@ -341,72 +344,80 @@ class Markov(Module):
# walk a chain, randomly, building the list of words
while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
# pick a word from the shuffled seed words, if we need a new one
if seed_word == hit_word:
if len(seed_words) > 0:
seed_word = seed_words.pop()
self.log.debug("picked new seed word: {0:s}".format(seed_word))
else:
seed_word = None
self.log.debug("ran out of seed words")
# first, see if we have an empty response and a target word.
# we'll just pick a word and work backwards
if gen_words[-1] == self.start2 and target_word != '':
# if so, work backwards, otherwise forwards
if gen_words[-1] == self.start2 and seed_word is not None:
# work backwards
working_backwards = []
key_hits = self._retrieve_k2_for_value(target_word, context_id)
key_hits = self._retrieve_k2_for_value(seed_word, context_id)
if len(key_hits):
working_backwards.append(target_word)
self.log.debug("added '{0:s}' to working_backwards".format(target_word))
found_word = seed_word
working_backwards.append(found_word)
self.log.debug("started working backwards with: {0:s}".format(found_word))
self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards)))
# generate new word
found_word = ''
target_word = words[random.randint(0, len(words)-1)]
# work backwards until we randomly bump into a start
# now work backwards until we randomly bump into a start
while True:
key_hits = self._retrieve_k2_for_value(working_backwards[0], context_id)
if target_word in key_hits:
found_word = target_word
# generate new word
if len(filter(lambda a: a != target_word, words)) > 1 and False:
# if we have more than one target word, get a new one (otherwise give up)
target_word = random.choice(filter(lambda a: a != target_word, words))
else:
target_word = ''
else:
found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
if found_word == self.start2 or len(working_backwards) >= max_size + 2:
self.log.debug("done working backwards")
if self.start2 in key_hits:
self.log.debug("working backwards found start, finishing")
gen_words = gen_words + working_backwards
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
break
else:
found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
working_backwards.insert(0, found_word)
self.log.debug("added '{0:s}' to working_backwards".format(found_word))
self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards)))
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context_id)
# use the chain that includes the target word, if it is found
if target_word != '' and target_word in key_hits:
gen_words.append(target_word)
self.log.debug("added target word '{0:s}' to gen_words".format(target_word))
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
# generate new word
target_word = words[random.randint(0, len(words)-1)]
hit_word = seed_word
else:
gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
# work forwards
# tack a new chain onto the list and resume if we're too short
if gen_words[-1] == self.stop and len(gen_words) < min_size + 2:
self.log.debug("starting a new chain on end of old one")
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context_id)
# chop off the end text, if it was the keyword indicating an end of chain
if gen_words[-1] == self.stop:
gen_words = gen_words[:-1]
# new word 1
key_hits = self._retrieve_chains_for_key(self.start1, self.start2, context_id)
gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
if gen_words[-1] == self.stop:
break
# use the chain that includes the target word, if it is found
if seed_word is not None and seed_word in key_hits:
hit_word = seed_word
self.log.debug("added seed word '{0:s}' to gen_words".format(hit_word))
else:
# new word 2
key_hits = self._retrieve_chains_for_key(self.start2, gen_words[-1], context_id)
hit_word = self._get_suitable_word_from_choices(key_hits, gen_words, min_size)
self.log.debug("added random word '{0:s}' to gen_words".format(hit_word))
# from either method, append result
gen_words.append(hit_word)
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
# tack a new chain onto the list and resume if we're too short
if gen_words[-1] == self.stop and len(gen_words) < min_size + 2:
self.log.debug("starting a new chain on end of old one")
# chop off the end text, if it was the keyword indicating an end of chain
if gen_words[-1] == self.stop:
gen_words = gen_words[:-1]
# new word 1
key_hits = self._retrieve_chains_for_key(self.start1, self.start2, context_id)
gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
# the database is probably empty if we got a stop from this
if gen_words[-1] == self.stop:
break
else:
# new word 2
key_hits = self._retrieve_chains_for_key(self.start2, gen_words[-1], context_id)
gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
# chop off the seed data at the start
gen_words = gen_words[2:]
@ -425,16 +436,13 @@ class Markov(Module):
if len(gen_words) < min_size + 2 and len(filter(lambda a: a != self.stop, key_hits)) > 0:
found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
self.log.debug("added '{0:s}' to gen_words".format(found_word))
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
return found_word
elif len(key_hits) <= 0:
self.log.debug("no hits found, appending stop")
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
return self.stop
else:
found_word = random.choice(key_hits)
self.log.debug("added '{0:s}' to gen_words".format(found_word))
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
return found_word
def _retrieve_chains_for_key(self, k1, k2, context_id):