Markov: add support for starting in the middle of a chain and working backwards

this only makes sense if we have a target word set, which we usually do.
start with the target word and go backwards, finding k2s that lead to it
(and that lead to that k2, and so on) until we get to the start-of-chain
value, when we know we're done working backwards. then resume the normal
appending logic

probably needs some work, probably a bit slow on huge databases. analysis
pending, but this appears to work
This commit is contained in:
Brian S. Stephan 2011-10-16 20:19:51 -05:00
parent ad93ea28ec
commit 42962bc48d
1 changed files with 44 additions and 0 deletions

View File

@ -418,6 +418,30 @@ class Markov(Module):
# walk a chain, randomly, building the list of words
while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
# first, see if we have an empty response and a target word.
# we'll just pick a word and work backwards
if gen_words[-1] == self.start2 and target_word != '':
working_backwards = []
working_backwards.append(target_word)
# generate new word
found_word = ''
target_word = words[random.randint(0, len(words)-1)]
# work backwards until we randomly bump into a start
while True:
key_hits = self._retrieve_k2_for_value(working_backwards[0], context)
if target_word in key_hits:
found_word = target_word
# generate new word
target_word = words[random.randint(0, len(words)-1)]
else:
found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
if found_word == self.start2:
gen_words = gen_words + working_backwards
break
else:
working_backwards.insert(0, found_word)
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context)
# use the chain that includes the target word, if it is found
if target_word != '' and target_word in key_hits:
@ -469,6 +493,26 @@ class Markov(Module):
print('sqlite error: ' + str(e))
raise
def _retrieve_k2_for_value(self, v, context):
"""Get the value(s) for a given key (a pair of strings)."""
values = []
try:
db = self.get_db()
query = 'SELECT k2 FROM markov_chain WHERE v = ? AND (context = ? OR context IS NULL)'
cursor = db.execute(query, (v,context))
results = cursor.fetchall()
for result in results:
values.append(result['k2'])
db.close()
return values
except sqlite3.Error as e:
db.close()
print('sqlite error: ' + str(e))
raise
def _get_chatter_targets(self):
"""Get all possible chatter targets."""