Markov: massive rewrite of the chainer
a bunch of logic is moved around, some queries are improved, max_size does what it's actually supposed to do. all in all this is a much clearer chainer, even if the actual results are more or less the same. it's probably a bit faster in most cases but slower in situations when all the seed words have been consumed and it needs to do __start1,__start2 chains (since there's so many of them, it's rather slow). otherwise, it tries to use seed words in sentences, combining multiple sentences when possible. there's a lot more in the periphery, but that's the general idea
This commit is contained in:
parent
5d90c98fb2
commit
5314dadc07
|
@ -222,7 +222,7 @@ class Markov(Module):
|
|||
match = self.replyre.search(what)
|
||||
if match:
|
||||
min_size = 15
|
||||
max_size = 100
|
||||
max_size = 30
|
||||
|
||||
if match.group(2):
|
||||
min_size = int(match.group(2))
|
||||
|
@ -323,7 +323,7 @@ class Markov(Module):
|
|||
raise
|
||||
finally: cur.close()
|
||||
|
||||
def _generate_line(self, target, line='', min_size=15, max_size=100):
|
||||
def _generate_line(self, target, line='', min_size=15, max_size=30):
|
||||
"""Create a line, optionally using some text in a seed as a point in
|
||||
the chain.
|
||||
|
||||
|
@ -358,22 +358,51 @@ class Markov(Module):
|
|||
# start with an empty chain, and work from there
|
||||
gen_words = [self.start1, self.start2]
|
||||
|
||||
# walk a chain, randomly, building the list of words
|
||||
while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
|
||||
# build a response by creating multiple sentences
|
||||
while len(gen_words) < max_size + 2:
|
||||
# if we're past the min and on a stop, we can end
|
||||
if len(gen_words) > min_size + 2:
|
||||
if gen_words[-1] == self.stop:
|
||||
break
|
||||
|
||||
# pick a word from the shuffled seed words, if we need a new one
|
||||
if seed_word == hit_word:
|
||||
if len(seed_words) > 0:
|
||||
seed_word = seed_words.pop()
|
||||
self.log.debug("picked new seed word: {0:s}".format(seed_word))
|
||||
self.log.debug("picked new seed word: "
|
||||
"{0:s}".format(seed_word))
|
||||
else:
|
||||
seed_word = None
|
||||
self.log.debug("ran out of seed words")
|
||||
|
||||
# first, see if we have an empty response and a target word.
|
||||
# if so, work backwards, otherwise forwards
|
||||
if gen_words[-1] == self.start2 and seed_word is not None:
|
||||
# work backwards
|
||||
# if we have a stop, the word before it might need to be
|
||||
# made to look like a sentence end
|
||||
if gen_words[-1] == self.stop:
|
||||
# chop off the stop, temporarily
|
||||
gen_words = gen_words[:-1]
|
||||
|
||||
# we should have a real word, make it look like a
|
||||
# sentence end
|
||||
sentence_end = gen_words[-1]
|
||||
eos_punctuation = ['!', '?', ',', '.']
|
||||
if sentence_end[-1] not in eos_punctuation:
|
||||
random.shuffle(eos_punctuation)
|
||||
gen_words[-1] = sentence_end + eos_punctuation.pop()
|
||||
self.log.debug("monkeyed with end of sentence, it's "
|
||||
"now: {0:s}".format(gen_words[-1]))
|
||||
|
||||
# put the stop back on
|
||||
gen_words.append(self.stop)
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
|
||||
# first, see if we should start a new sentence. if so,
|
||||
# work backwards
|
||||
if gen_words[-1] in (self.start2, self.stop) and seed_word is not None:
|
||||
# drop a stop, since we're starting another sentence
|
||||
if gen_words[-1] == self.stop:
|
||||
gen_words = gen_words[:-1]
|
||||
|
||||
# work backwards from seed_word
|
||||
working_backwards = []
|
||||
back_k2 = self._retrieve_random_k2_for_value(seed_word, context_id)
|
||||
if back_k2:
|
||||
|
@ -391,7 +420,8 @@ class Markov(Module):
|
|||
# the weaker-context reverse chaining, we make max_size
|
||||
# a non-linear distribution, making it more likely that
|
||||
# some time is spent on better forward chains
|
||||
max_back = random.randint(1, max_size/2) + random.randint(1, max_size/2)
|
||||
max_back = min(random.randint(1, max_size/2) + random.randint(1, max_size/2),
|
||||
max_size/4)
|
||||
self.log.debug("max_back: {0:d}".format(max_back))
|
||||
while len(working_backwards) < max_back:
|
||||
back_k2 = self._retrieve_random_k2_for_value(working_backwards[0], context_id)
|
||||
|
@ -408,80 +438,59 @@ class Markov(Module):
|
|||
|
||||
gen_words += working_backwards
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
hit_word = gen_words[-1]
|
||||
else:
|
||||
# work forwards
|
||||
self.log.debug("looking forwards")
|
||||
prefer = seed_word if seed_word else ''
|
||||
self.log.debug("preferring: '{0:s}'".format(prefer))
|
||||
forw_v = self._retrieve_random_v_for_k1_and_k2_with_pref(gen_words[-2],
|
||||
gen_words[-1],
|
||||
prefer, context_id)
|
||||
if forw_v:
|
||||
gen_words.append(forw_v)
|
||||
self.log.debug("added random word '{0:s}' to gen_words".format(forw_v))
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
else:
|
||||
# append stop, let below code clean it up if necessary
|
||||
gen_words.append(self.stop)
|
||||
self.log.debug("nothing found, adding stop")
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
|
||||
# tack a new chain onto the list and resume if we're too short
|
||||
if gen_words[-1] == self.stop and len(gen_words) < min_size + 3:
|
||||
self.log.debug("starting a new chain on end of old one")
|
||||
|
||||
# chop off the end text, if it was the keyword indicating an end of chain
|
||||
# we are working forward, with either:
|
||||
# * a pair of words (normal path, filling out a sentence)
|
||||
# * start1, start2 (completely new chain, no seed words)
|
||||
# * stop (new sentence in existing chain, no seed words)
|
||||
self.log.debug("working forwards")
|
||||
forw_v = None
|
||||
if gen_words[-1] in (self.start2, self.stop):
|
||||
# case 2 or 3 above, need to work forward on a beginning
|
||||
# of a sentence (this is slow)
|
||||
if gen_words[-1] == self.stop:
|
||||
# remove the stop if it's there
|
||||
gen_words = gen_words[:-1]
|
||||
|
||||
# monkey with the end word to make it more like an actual sentence end
|
||||
sentence_end = gen_words[-1]
|
||||
eos_punctuation = ['!', '?', ',', '.']
|
||||
if sentence_end[-1] not in eos_punctuation:
|
||||
random.shuffle(eos_punctuation)
|
||||
gen_words[-1] = sentence_end + eos_punctuation.pop()
|
||||
self.log.debug("monkeyed with end of sentence, it's now: {0:s}".format(gen_words[-1]))
|
||||
new_sentence = self._create_chain_with_k1_k2(self.start1,
|
||||
self.start2,
|
||||
3,
|
||||
context_id)
|
||||
|
||||
new_chain_words = []
|
||||
# new word 1
|
||||
key_hits = self._retrieve_chains_for_key(self.start1, self.start2, context_id)
|
||||
new_chain_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
|
||||
# the database is probably empty if we got a stop from this
|
||||
if new_chain_words[0] == self.stop:
|
||||
break
|
||||
# new word 2
|
||||
key_hits = self._retrieve_chains_for_key(self.start2, new_chain_words[0], context_id)
|
||||
new_chain_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
|
||||
if new_chain_words[1] != self.stop:
|
||||
# two valid words, try for a third and check for "foo:"
|
||||
|
||||
# new word 3 (which we may need below)
|
||||
key_hits = self._retrieve_chains_for_key(new_chain_words[0], new_chain_words[1], context_id)
|
||||
new_chain_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
|
||||
|
||||
# if the first word is "foo:", start with the second
|
||||
addressing_suffixes = [':', ',']
|
||||
if new_chain_words[0][-1] in addressing_suffixes:
|
||||
gen_words += new_chain_words[1:]
|
||||
self.log.debug("appending following anti-address " \
|
||||
"new_chain_words: {0:s}".format(new_chain_words[1:]))
|
||||
elif new_chain_words[2] == self.stop:
|
||||
gen_words += new_chain_words[0:1]
|
||||
self.log.debug("appending following anti-stop " \
|
||||
"new_chain_words: {0:s}".format(new_chain_words[0:1]))
|
||||
else:
|
||||
gen_words += new_chain_words[0:]
|
||||
self.log.debug("appending following extended " \
|
||||
"new_chain_words: {0:s}".format(new_chain_words[0:]))
|
||||
if len(new_sentence) > 0:
|
||||
self.log.debug("started new sentence "
|
||||
"'{0:s}'".format(" ".join(new_sentence)))
|
||||
gen_words += new_sentence
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
else:
|
||||
# well, we got one word out of this... let's go with it
|
||||
# and let the loop check if we need more
|
||||
self.log.debug("appending following short new_chain_words: {0:s}".format(new_chain_words))
|
||||
gen_words += new_chain_words
|
||||
# this is a problem. we started a sentence on
|
||||
# start1,start2, and still didn't find anything. to
|
||||
# avoid endlessly looping we need to abort here
|
||||
break
|
||||
else:
|
||||
if seed_word:
|
||||
self.log.debug("preferring: '{0:s}'".format(seed_word))
|
||||
forw_v = self._retrieve_random_v_for_k1_and_k2_with_pref(gen_words[-2],
|
||||
gen_words[-1],
|
||||
seed_word,
|
||||
context_id)
|
||||
else:
|
||||
forw_v = self._retrieve_random_v_for_k1_and_k2(gen_words[-2],
|
||||
gen_words[-1],
|
||||
context_id)
|
||||
|
||||
# no matter forwards or backwards, use the end of the sentence
|
||||
# as our current hit word
|
||||
hit_word = gen_words[-1]
|
||||
if forw_v:
|
||||
gen_words.append(forw_v)
|
||||
self.log.debug("added random word '{0:s}' to gen_words".format(forw_v))
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
hit_word = gen_words[-1]
|
||||
else:
|
||||
# append stop. this is an end to a sentence (since
|
||||
# we had non-start words to begin with)
|
||||
gen_words.append(self.stop)
|
||||
self.log.debug("nothing found, added stop")
|
||||
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
|
||||
|
||||
# chop off the seed data at the start
|
||||
gen_words = gen_words[2:]
|
||||
|
@ -492,48 +501,32 @@ class Markov(Module):
|
|||
|
||||
return ' '.join(gen_words)
|
||||
|
||||
def _get_suitable_word_from_choices(self, key_hits, gen_words, min_size):
|
||||
"""Given an existing set of words, and key hits, pick one."""
|
||||
|
||||
# first, if we're not yet at min_size, pick a non-stop word if it exists
|
||||
# else, if there were no results, append stop
|
||||
# otherwise, pick a random result
|
||||
if len(gen_words) < min_size + 2 and len(filter(lambda a: a != self.stop, key_hits)) > 0:
|
||||
found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
|
||||
return found_word
|
||||
elif len(key_hits) == 0:
|
||||
return self.stop
|
||||
else:
|
||||
found_word = random.choice(key_hits)
|
||||
return found_word
|
||||
|
||||
def _retrieve_chains_for_key(self, k1, k2, context_id):
|
||||
"""Get the value(s) for a given key (a pair of strings)."""
|
||||
def _retrieve_random_v_for_k1_and_k2(self, k1, k2, context_id):
|
||||
"""Get one v for a given k1,k2."""
|
||||
|
||||
self.log.debug("searching with '{0:s}','{1:s}'".format(k1, k2))
|
||||
values = []
|
||||
db = self.get_db()
|
||||
try:
|
||||
query = ''
|
||||
if k1 == self.start1 and k2 == self.start2:
|
||||
# hack. get a quasi-random start from the database, in
|
||||
# a faster fashion than selecting all starts
|
||||
max_id = self._get_max_chain_id()
|
||||
rand_id = random.randint(1, max_id)
|
||||
query = ('SELECT v FROM markov_chain WHERE k1 = %s AND k2 = %s AND '
|
||||
'(context_id = %s) AND id >= {0:d} LIMIT 1'.format(rand_id))
|
||||
else:
|
||||
query = ('SELECT v FROM markov_chain WHERE k1 = %s AND k2 = %s AND '
|
||||
'(context_id = %s)')
|
||||
query = '''
|
||||
SELECT v FROM markov_chain AS r1
|
||||
JOIN (
|
||||
SELECT (RAND() * (SELECT MAX(id) FROM markov_chain)) AS id
|
||||
) AS r2
|
||||
WHERE r1.k1 = %s
|
||||
AND r1.k2 = %s
|
||||
AND r1.context_id = %s
|
||||
ORDER BY r1.id >= r2.id DESC, r1.id ASC
|
||||
LIMIT 1
|
||||
'''
|
||||
cur = db.cursor(mdb.cursors.DictCursor)
|
||||
cur.execute(query, (k1, k2, context_id))
|
||||
results = cur.fetchall()
|
||||
|
||||
for result in results:
|
||||
values.append(result['v'])
|
||||
|
||||
return values
|
||||
result = cur.fetchone()
|
||||
if result:
|
||||
self.log.debug("found '{0:s}'".format(result['v']))
|
||||
return result['v']
|
||||
except mdb.Error as e:
|
||||
self.log.error("database error in _retrieve_chains_for_key")
|
||||
self.log.error("database error in _retrieve_random_v_for_k1_and_k2")
|
||||
self.log.exception(e)
|
||||
raise
|
||||
finally: cur.close()
|
||||
|
@ -545,6 +538,8 @@ class Markov(Module):
|
|||
|
||||
"""
|
||||
|
||||
self.log.debug("searching with '{0:s}','{1:s}', prefer "
|
||||
"'{2:s}'".format(k1, k2, prefer))
|
||||
values = []
|
||||
db = self.get_db()
|
||||
try:
|
||||
|
@ -553,17 +548,17 @@ class Markov(Module):
|
|||
JOIN (
|
||||
SELECT (RAND() * (SELECT MAX(id) FROM markov_chain)) AS id
|
||||
) AS r2
|
||||
WHERE r1.id >= r2.id
|
||||
AND r1.k1 = %s
|
||||
WHERE r1.k1 = %s
|
||||
AND r1.k2 = %s
|
||||
AND r1.context_id = %s
|
||||
ORDER BY r1.v = %s DESC, r1.id ASC
|
||||
ORDER BY r1.id >= r2.id DESC, r1.v = %s DESC, r1.id ASC
|
||||
LIMIT 1
|
||||
'''
|
||||
cur = db.cursor(mdb.cursors.DictCursor)
|
||||
cur.execute(query, (k1, k2, context_id, prefer))
|
||||
result = cur.fetchone()
|
||||
if result:
|
||||
self.log.debug("found '{0:s}'".format(result['v']))
|
||||
return result['v']
|
||||
except mdb.Error as e:
|
||||
self.log.error("database error in _retrieve_random_v_for_k1_and_k2_with_pref")
|
||||
|
@ -582,10 +577,9 @@ class Markov(Module):
|
|||
JOIN (
|
||||
SELECT (RAND() * (SELECT MAX(id) FROM markov_chain)) AS id
|
||||
) AS r2
|
||||
WHERE r1.id >= r2.id
|
||||
AND r1.v = %s
|
||||
WHERE r1.v = %s
|
||||
AND r1.context_id = %s
|
||||
ORDER BY r1.id ASC
|
||||
ORDER BY r1.id >= r2.id DESC, r1.id ASC
|
||||
LIMIT 1
|
||||
'''
|
||||
cur = db.cursor(mdb.cursors.DictCursor)
|
||||
|
@ -599,6 +593,25 @@ class Markov(Module):
|
|||
raise
|
||||
finally: cur.close()
|
||||
|
||||
def _create_chain_with_k1_k2(self, k1, k2, length, context_id):
|
||||
"""Create a chain of the given length, using k1,k2.
|
||||
|
||||
k1,k2 does not appear in the resulting chain.
|
||||
|
||||
"""
|
||||
|
||||
chain = [k1, k2]
|
||||
self.log.debug("creating chain for {0:s},{1:s}".format(k1, k2))
|
||||
|
||||
for _ in range(length):
|
||||
v = self._retrieve_random_v_for_k1_and_k2(chain[-2],
|
||||
chain[-1],
|
||||
context_id)
|
||||
if v:
|
||||
chain.append(v)
|
||||
|
||||
return chain[2:]
|
||||
|
||||
def _get_chatter_targets(self):
|
||||
"""Get all possible chatter targets."""
|
||||
|
||||
|
@ -616,34 +629,6 @@ class Markov(Module):
|
|||
raise
|
||||
finally: cur.close()
|
||||
|
||||
def _get_one_chatter_target(self):
|
||||
"""Select one random chatter target."""
|
||||
|
||||
targets = self._get_chatter_targets()
|
||||
if targets:
|
||||
return targets[random.randint(0, len(targets)-1)]
|
||||
|
||||
def _get_max_chain_id(self):
|
||||
"""Get the highest id in the chain table."""
|
||||
|
||||
db = self.get_db()
|
||||
try:
|
||||
query = '''
|
||||
SELECT id FROM markov_chain ORDER BY id DESC LIMIT 1
|
||||
'''
|
||||
cur = db.cursor(mdb.cursors.DictCursor)
|
||||
cur.execute(query)
|
||||
result = cur.fetchone()
|
||||
if result:
|
||||
return result['id']
|
||||
else:
|
||||
return None
|
||||
except mdb.Error as e:
|
||||
self.log.error("database error in _get_max_chain_id")
|
||||
self.log.exception(e)
|
||||
raise
|
||||
finally: cur.close()
|
||||
|
||||
def _get_context_id_for_target(self, target):
|
||||
"""Get the context ID for the desired/input target."""
|
||||
|
||||
|
|
Loading…
Reference in New Issue