Markov: add support for starting in the middle of a chain and working backwards
this only makes sense if we have a target word set, which we usually do. start with the target word and go backwords, finding k2s that lead to it (and that lead to that k2, and so on) until we get to the start-of-chain value, when we know we're done working backwards. then resume the normal appending logic probably needs some work, probably a bit slow on huge databases. analysis pending, but this appears to work
This commit is contained in:
parent
ad93ea28ec
commit
42962bc48d
@ -418,6 +418,30 @@ class Markov(Module):
|
||||
|
||||
# walk a chain, randomly, building the list of words
|
||||
while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
|
||||
# first, see if we have an empty response and a target word.
|
||||
# we'll just pick a word and work backwards
|
||||
if gen_words[-1] == self.start2 and target_word != '':
|
||||
working_backwards = []
|
||||
working_backwards.append(target_word)
|
||||
# generate new word
|
||||
found_word = ''
|
||||
target_word = words[random.randint(0, len(words)-1)]
|
||||
# work backwards until we randomly bump into a start
|
||||
while True:
|
||||
key_hits = self._retrieve_k2_for_value(working_backwards[0], context)
|
||||
if target_word in key_hits:
|
||||
found_word = target_word
|
||||
# generate new word
|
||||
target_word = words[random.randint(0, len(words)-1)]
|
||||
else:
|
||||
found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
|
||||
|
||||
if found_word == self.start2:
|
||||
gen_words = gen_words + working_backwards
|
||||
break
|
||||
else:
|
||||
working_backwards.insert(0, found_word)
|
||||
|
||||
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context)
|
||||
# use the chain that includes the target word, if it is found
|
||||
if target_word != '' and target_word in key_hits:
|
||||
@ -469,6 +493,26 @@ class Markov(Module):
|
||||
print('sqlite error: ' + str(e))
|
||||
raise
|
||||
|
||||
def _retrieve_k2_for_value(self, v, context):
|
||||
"""Get the value(s) for a given key (a pair of strings)."""
|
||||
|
||||
values = []
|
||||
try:
|
||||
db = self.get_db()
|
||||
query = 'SELECT k2 FROM markov_chain WHERE v = ? AND (context = ? OR context IS NULL)'
|
||||
cursor = db.execute(query, (v,context))
|
||||
results = cursor.fetchall()
|
||||
|
||||
for result in results:
|
||||
values.append(result['k2'])
|
||||
|
||||
db.close()
|
||||
return values
|
||||
except sqlite3.Error as e:
|
||||
db.close()
|
||||
print('sqlite error: ' + str(e))
|
||||
raise
|
||||
|
||||
def _get_chatter_targets(self):
|
||||
"""Get all possible chatter targets."""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user