Markov: index on (v, context) and other enhancements for the last commit

reduce some infinite loop possibilities, and add an index with the old <= id trick
to speed up the searching for backwards chains
This commit is contained in:
Brian S. Stephan 2011-10-16 21:13:27 -05:00
parent 42962bc48d
commit cda1d43606
1 changed files with 43 additions and 20 deletions

View File

@ -201,6 +201,21 @@ class Markov(Module):
db.close()
print('sqlite error: ' + str(e))
raise
if (version < 8):
db = self.get_db()
try:
version = 8
db.execute('''
CREATE INDEX markov_chain_value_and_context_index
ON markov_chain (v, context)''')
db.commit()
db.close()
self.db_register_module_version(self.__class__.__name__, version)
except sqlite3.Error as e:
db.rollback()
db.close()
print('sqlite error: ' + str(e))
raise
def register_handlers(self):
"""Handle pubmsg/privmsg, to learn and/or reply to IRC events."""
@ -422,25 +437,31 @@ class Markov(Module):
# we'll just pick a word and work backwards
if gen_words[-1] == self.start2 and target_word != '':
working_backwards = []
working_backwards.append(target_word)
# generate new word
found_word = ''
target_word = words[random.randint(0, len(words)-1)]
# work backwards until we randomly bump into a start
while True:
key_hits = self._retrieve_k2_for_value(working_backwards[0], context)
if target_word in key_hits:
found_word = target_word
# generate new word
target_word = words[random.randint(0, len(words)-1)]
else:
found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
key_hits = self._retrieve_k2_for_value(target_word, context)
if len(key_hits):
working_backwards.append(target_word)
# generate new word
found_word = ''
target_word = words[random.randint(0, len(words)-1)]
# work backwards until we randomly bump into a start
while True:
key_hits = self._retrieve_k2_for_value(working_backwards[0], context)
if target_word in key_hits:
found_word = target_word
# generate new word
if len(filter(lambda a: a != target_word, words)) > 1 and False:
# if we have more than one target word, get a new one (otherwise give up)
target_word = random.choice(filter(lambda a: a != target_word, words))
else:
target_word = ''
else:
found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
if found_word == self.start2:
gen_words = gen_words + working_backwards
break
else:
working_backwards.insert(0, found_word)
if found_word == self.start2 or len(working_backwards) >= max_size + 2:
gen_words = gen_words + working_backwards
break
else:
working_backwards.insert(0, found_word)
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context)
# use the chain that includes the target word, if it is found
@ -499,8 +520,10 @@ class Markov(Module):
values = []
try:
db = self.get_db()
query = 'SELECT k2 FROM markov_chain WHERE v = ? AND (context = ? OR context IS NULL)'
cursor = db.execute(query, (v,context))
max_id = self._get_max_chain_id()
rand_id = random.randint(1,max_id)
query = 'SELECT k2 FROM markov_chain WHERE v = ? AND (context = ? OR context IS NULL) AND id >= {0:d} UNION SELECT k2 FROM markov_chain WHERE v = ? AND (context = ? OR context IS NULL) AND id < {1:d} LIMIT 1'.format(rand_id, rand_id)
cursor = db.execute(query, (v,context,v,context))
results = cursor.fetchall()
for result in results: