Markov: _retrieve_random_k2_for_value
Rather than getting all k2s for a value from the database, then walking the list and picking one at random, pick one for a value at random via a query. This simplifies the code and is (usually) faster than the old way, which has been removed. It would be even faster if it weren't for that context_id stuff, but so it goes.
This commit is contained in:
parent
b5be0501de
commit
5a55227cf9
@ -375,9 +375,10 @@ class Markov(Module):
|
||||
# work backwards
|
||||
|
||||
working_backwards = []
|
||||
key_hits = self._retrieve_k2_for_value(seed_word, context_id)
|
||||
if len(key_hits):
|
||||
back_k2 = self._retrieve_random_k2_for_value(seed_word, context_id)
|
||||
if back_k2:
|
||||
found_word = seed_word
|
||||
working_backwards.append(back_k2)
|
||||
working_backwards.append(found_word)
|
||||
self.log.debug("started working backwards with: {0:s}".format(found_word))
|
||||
self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards)))
|
||||
@ -389,28 +390,22 @@ class Markov(Module):
|
||||
# some time is spent on better forward chains
|
||||
max_back = random.randint(1, max_size/2) + random.randint(1, max_size/2)
|
||||
self.log.debug("max_back: {0:d}".format(max_back))
|
||||
while True:
|
||||
key_hits = self._retrieve_k2_for_value(working_backwards[0], context_id)
|
||||
if self.start2 in key_hits and len(working_backwards) > max_back:
|
||||
self.log.debug("max_back exceeded, cleanly finishing")
|
||||
while len(working_backwards) < max_back:
|
||||
back_k2 = self._retrieve_random_k2_for_value(working_backwards[0], context_id)
|
||||
if back_k2 == self.start2:
|
||||
self.log.debug("random further back was start2, finishing")
|
||||
gen_words += working_backwards
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
break
|
||||
elif len(key_hits) == 0:
|
||||
self.log.debug("no key_hits, finishing")
|
||||
gen_words += working_backwards
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
break
|
||||
elif len(filter(lambda a: a!= self.start2, key_hits)) == 0:
|
||||
self.log.debug("only start2 in key_hits, finishing")
|
||||
gen_words += working_backwards
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
break
|
||||
else:
|
||||
found_word = random.choice(filter(lambda a: a != self.start2, key_hits))
|
||||
working_backwards.insert(0, found_word)
|
||||
self.log.debug("added '{0:s}' to working_backwards".format(found_word))
|
||||
elif back_k2:
|
||||
working_backwards.insert(0, back_k2)
|
||||
self.log.debug("added '{0:s}' to working_backwards".format(back_k2))
|
||||
self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards)))
|
||||
else:
|
||||
self.log.debug("nothing (at all!?) further back, finishing")
|
||||
gen_words += working_backwards
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
break
|
||||
|
||||
hit_word = seed_word
|
||||
else:
|
||||
@ -538,23 +533,30 @@ class Markov(Module):
|
||||
raise
|
||||
finally: cur.close()
|
||||
|
||||
def _retrieve_k2_for_value(self, v, context_id):
|
||||
"""Get the value(s) for a given key (a pair of strings)."""
|
||||
def _retrieve_random_k2_for_value(self, v, context_id):
|
||||
"""Get one k2 for a given value."""
|
||||
|
||||
values = []
|
||||
db = self.get_db()
|
||||
try:
|
||||
query = 'SELECT k2 FROM markov_chain WHERE v = %s AND context_id = %s'
|
||||
query = '''
|
||||
SELECT k2 FROM markov_chain AS r1
|
||||
JOIN (
|
||||
SELECT (RAND() * (SELECT MAX(id) FROM markov_chain)) AS id
|
||||
) AS r2
|
||||
WHERE r1.id >= r2.id
|
||||
AND r1.v = %s
|
||||
AND r1.context_id = %s
|
||||
ORDER BY r1.id ASC
|
||||
LIMIT 1
|
||||
'''
|
||||
cur = db.cursor(mdb.cursors.DictCursor)
|
||||
cur.execute(query, (v, context_id))
|
||||
results = cur.fetchall()
|
||||
|
||||
for result in results:
|
||||
values.append(result['k2'])
|
||||
|
||||
return values
|
||||
result = cur.fetchone()
|
||||
if result:
|
||||
return result['k2']
|
||||
except mdb.Error as e:
|
||||
self.log.error("database error in _retrieve_k2_for_value")
|
||||
self.log.error("database error in _retrieve_random_k2_for_value")
|
||||
self.log.exception(e)
|
||||
raise
|
||||
finally: cur.close()
|
||||
|
Loading…
Reference in New Issue
Block a user