Markov: rewrite backwards/forwards chainer
This clarifies a bunch of sections and seems slightly faster. target_word (which would be randomly selected from the input every time) is replaced with seed_words, a shuffled list from the input. This is to eliminate accidental reuse of the target word, which would result in chains like X X X X X X X X X X X X X, because it'd keep targeting X. The rest of this is mostly just debug cleanup, though to simplify the backwards code it only tries to find one target word.
This commit is contained in:
parent
9ca37c3990
commit
390e925360
|
@ -327,12 +327,15 @@ class Markov(Module):
|
|||
if (min_size > 20):
|
||||
raise Exception("min_size is too large: %d" % min_size)
|
||||
|
||||
words = []
|
||||
target_word = ''
|
||||
# get a random word from the input
|
||||
if line != '':
|
||||
words = line.split()
|
||||
target_word = words[random.randint(0, len(words)-1)]
|
||||
seed_words = []
|
||||
# shuffle the words in the input
|
||||
seed_words = line.split()
|
||||
random.shuffle(seed_words)
|
||||
self.log.debug("seed words: {0:s}".format(seed_words))
|
||||
|
||||
# hit to generate a new seed word immediately if possible
|
||||
seed_word = None
|
||||
hit_word = None
|
||||
|
||||
context_id = self._get_context_id_for_target(target)
|
||||
|
||||
|
@ -341,72 +344,80 @@ class Markov(Module):
|
|||
|
||||
# walk a chain, randomly, building the list of words
|
||||
while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
|
||||
# pick a word from the shuffled seed words, if we need a new one
|
||||
if seed_word == hit_word:
|
||||
if len(seed_words) > 0:
|
||||
seed_word = seed_words.pop()
|
||||
self.log.debug("picked new seed word: {0:s}".format(seed_word))
|
||||
else:
|
||||
seed_word = None
|
||||
self.log.debug("ran out of seed words")
|
||||
|
||||
# first, see if we have an empty response and a target word.
|
||||
# we'll just pick a word and work backwards
|
||||
if gen_words[-1] == self.start2 and target_word != '':
|
||||
# if so, work backwards, otherwise forwards
|
||||
if gen_words[-1] == self.start2 and seed_word is not None:
|
||||
# work backwards
|
||||
|
||||
working_backwards = []
|
||||
key_hits = self._retrieve_k2_for_value(target_word, context_id)
|
||||
key_hits = self._retrieve_k2_for_value(seed_word, context_id)
|
||||
if len(key_hits):
|
||||
working_backwards.append(target_word)
|
||||
self.log.debug("added '{0:s}' to working_backwards".format(target_word))
|
||||
found_word = seed_word
|
||||
working_backwards.append(found_word)
|
||||
self.log.debug("started working backwards with: {0:s}".format(found_word))
|
||||
self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards)))
|
||||
# generate new word
|
||||
found_word = ''
|
||||
target_word = words[random.randint(0, len(words)-1)]
|
||||
# work backwards until we randomly bump into a start
|
||||
|
||||
# now work backwards until we randomly bump into a start
|
||||
while True:
|
||||
key_hits = self._retrieve_k2_for_value(working_backwards[0], context_id)
|
||||
if target_word in key_hits:
|
||||
found_word = target_word
|
||||
# generate new word
|
||||
if len(filter(lambda a: a != target_word, words)) > 1 and False:
|
||||
# if we have more than one target word, get a new one (otherwise give up)
|
||||
target_word = random.choice(filter(lambda a: a != target_word, words))
|
||||
else:
|
||||
target_word = ''
|
||||
else:
|
||||
found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
|
||||
|
||||
if found_word == self.start2 or len(working_backwards) >= max_size + 2:
|
||||
self.log.debug("done working backwards")
|
||||
if self.start2 in key_hits:
|
||||
self.log.debug("working backwards found start, finishing")
|
||||
gen_words = gen_words + working_backwards
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
break
|
||||
else:
|
||||
found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
|
||||
working_backwards.insert(0, found_word)
|
||||
self.log.debug("added '{0:s}' to working_backwards".format(found_word))
|
||||
self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards)))
|
||||
|
||||
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context_id)
|
||||
# use the chain that includes the target word, if it is found
|
||||
if target_word != '' and target_word in key_hits:
|
||||
gen_words.append(target_word)
|
||||
self.log.debug("added target word '{0:s}' to gen_words".format(target_word))
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
# generate new word
|
||||
target_word = words[random.randint(0, len(words)-1)]
|
||||
hit_word = seed_word
|
||||
else:
|
||||
gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
|
||||
# work forwards
|
||||
|
||||
# tack a new chain onto the list and resume if we're too short
|
||||
if gen_words[-1] == self.stop and len(gen_words) < min_size + 2:
|
||||
self.log.debug("starting a new chain on end of old one")
|
||||
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context_id)
|
||||
|
||||
# chop off the end text, if it was the keyword indicating an end of chain
|
||||
if gen_words[-1] == self.stop:
|
||||
gen_words = gen_words[:-1]
|
||||
|
||||
# new word 1
|
||||
key_hits = self._retrieve_chains_for_key(self.start1, self.start2, context_id)
|
||||
gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
|
||||
|
||||
if gen_words[-1] == self.stop:
|
||||
break
|
||||
# use the chain that includes the target word, if it is found
|
||||
if seed_word is not None and seed_word in key_hits:
|
||||
hit_word = seed_word
|
||||
self.log.debug("added seed word '{0:s}' to gen_words".format(hit_word))
|
||||
else:
|
||||
# new word 2
|
||||
key_hits = self._retrieve_chains_for_key(self.start2, gen_words[-1], context_id)
|
||||
hit_word = self._get_suitable_word_from_choices(key_hits, gen_words, min_size)
|
||||
self.log.debug("added random word '{0:s}' to gen_words".format(hit_word))
|
||||
|
||||
# from either method, append result
|
||||
gen_words.append(hit_word)
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
|
||||
# tack a new chain onto the list and resume if we're too short
|
||||
if gen_words[-1] == self.stop and len(gen_words) < min_size + 2:
|
||||
self.log.debug("starting a new chain on end of old one")
|
||||
|
||||
# chop off the end text, if it was the keyword indicating an end of chain
|
||||
if gen_words[-1] == self.stop:
|
||||
gen_words = gen_words[:-1]
|
||||
|
||||
# new word 1
|
||||
key_hits = self._retrieve_chains_for_key(self.start1, self.start2, context_id)
|
||||
gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
|
||||
|
||||
# the database is probably empty if we got a stop from this
|
||||
if gen_words[-1] == self.stop:
|
||||
break
|
||||
else:
|
||||
# new word 2
|
||||
key_hits = self._retrieve_chains_for_key(self.start2, gen_words[-1], context_id)
|
||||
gen_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
|
||||
|
||||
# chop off the seed data at the start
|
||||
gen_words = gen_words[2:]
|
||||
|
||||
|
@ -425,16 +436,13 @@ class Markov(Module):
|
|||
if len(gen_words) < min_size + 2 and len(filter(lambda a: a != self.stop, key_hits)) > 0:
|
||||
found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
|
||||
self.log.debug("added '{0:s}' to gen_words".format(found_word))
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
return found_word
|
||||
elif len(key_hits) <= 0:
|
||||
self.log.debug("no hits found, appending stop")
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
return self.stop
|
||||
else:
|
||||
found_word = random.choice(key_hits)
|
||||
self.log.debug("added '{0:s}' to gen_words".format(found_word))
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
return found_word
|
||||
|
||||
def _retrieve_chains_for_key(self, k1, k2, context_id):
|
||||
|
|
Loading…
Reference in New Issue