diff --git a/modules/Markov.py b/modules/Markov.py index e5f5d57..98bbcc7 100644 --- a/modules/Markov.py +++ b/modules/Markov.py @@ -222,7 +222,7 @@ class Markov(Module): match = self.replyre.search(what) if match: min_size = 15 - max_size = 100 + max_size = 30 if match.group(2): min_size = int(match.group(2)) @@ -323,7 +323,7 @@ class Markov(Module): raise finally: cur.close() - def _generate_line(self, target, line='', min_size=15, max_size=100): + def _generate_line(self, target, line='', min_size=15, max_size=30): """Create a line, optionally using some text in a seed as a point in the chain. @@ -358,22 +358,51 @@ class Markov(Module): # start with an empty chain, and work from there gen_words = [self.start1, self.start2] - # walk a chain, randomly, building the list of words - while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop: + # build a response by creating multiple sentences + while len(gen_words) < max_size + 2: + # if we're past the min and on a stop, we can end + if len(gen_words) > min_size + 2: + if gen_words[-1] == self.stop: + break + # pick a word from the shuffled seed words, if we need a new one if seed_word == hit_word: if len(seed_words) > 0: seed_word = seed_words.pop() - self.log.debug("picked new seed word: {0:s}".format(seed_word)) + self.log.debug("picked new seed word: " + "{0:s}".format(seed_word)) else: seed_word = None self.log.debug("ran out of seed words") - # first, see if we have an empty response and a target word. - # if so, work backwards, otherwise forwards - if gen_words[-1] == self.start2 and seed_word is not None: - # work backwards + # if we have a stop, the word before it might need to be + # made to look like a sentence end + if gen_words[-1] == self.stop: + # chop off the stop, temporarily + gen_words = gen_words[:-1] + # we should have a real word, make it look like a + # sentence end + sentence_end = gen_words[-1] + eos_punctuation = ['!', '?', ',', '.'] + if sentence_end[-1] not in eos_punctuation: + random.shuffle(eos_punctuation) + gen_words[-1] = sentence_end + eos_punctuation.pop() + self.log.debug("monkeyed with end of sentence, it's " + "now: {0:s}".format(gen_words[-1])) + + # put the stop back on + gen_words.append(self.stop) + self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) + + # first, see if we should start a new sentence. if so, + # work backwards + if gen_words[-1] in (self.start2, self.stop) and seed_word is not None: + # drop a stop, since we're starting another sentence + if gen_words[-1] == self.stop: + gen_words = gen_words[:-1] + + # work backwards from seed_word working_backwards = [] back_k2 = self._retrieve_random_k2_for_value(seed_word, context_id) if back_k2: @@ -391,7 +420,8 @@ class Markov(Module): # the weaker-context reverse chaining, we make max_size # a non-linear distribution, making it more likely that # some time is spent on better forward chains - max_back = random.randint(1, max_size/2) + random.randint(1, max_size/2) + max_back = min(random.randint(1, max_size/2) + random.randint(1, max_size/2), + max_size/4) self.log.debug("max_back: {0:d}".format(max_back)) while len(working_backwards) < max_back: back_k2 = self._retrieve_random_k2_for_value(working_backwards[0], context_id) @@ -408,80 +438,59 @@ class Markov(Module): gen_words += working_backwards self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) + hit_word = gen_words[-1] else: - # work forwards - self.log.debug("looking forwards") - prefer = seed_word if seed_word else '' - self.log.debug("preferring: '{0:s}'".format(prefer)) - forw_v = self._retrieve_random_v_for_k1_and_k2_with_pref(gen_words[-2], - gen_words[-1], - prefer, context_id) - if forw_v: - gen_words.append(forw_v) - self.log.debug("added random word '{0:s}' to gen_words".format(forw_v)) - self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) - else: - # append stop, let below code clean it up if necessary - gen_words.append(self.stop) - self.log.debug("nothing found, adding stop") - self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) - - # tack a new chain onto the list and resume if we're too short - if gen_words[-1] == self.stop and len(gen_words) < min_size + 3: - self.log.debug("starting a new chain on end of old one") - - # chop off the end text, if it was the keyword indicating an end of chain + # we are working forward, with either: + # * a pair of words (normal path, filling out a sentence) + # * start1, start2 (completely new chain, no seed words) + # * stop (new sentence in existing chain, no seed words) + self.log.debug("working forwards") + forw_v = None + if gen_words[-1] in (self.start2, self.stop): + # case 2 or 3 above, need to work forward on a beginning + # of a sentence (this is slow) if gen_words[-1] == self.stop: + # remove the stop if it's there gen_words = gen_words[:-1] - # monkey with the end word to make it more like an actual sentence end - sentence_end = gen_words[-1] - eos_punctuation = ['!', '?', ',', '.'] - if sentence_end[-1] not in eos_punctuation: - random.shuffle(eos_punctuation) - gen_words[-1] = sentence_end + eos_punctuation.pop() - self.log.debug("monkeyed with end of sentence, it's now: {0:s}".format(gen_words[-1])) + new_sentence = self._create_chain_with_k1_k2(self.start1, + self.start2, + 3, + context_id) - new_chain_words = [] - # new word 1 - key_hits = self._retrieve_chains_for_key(self.start1, self.start2, context_id) - new_chain_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size)) - # the database is probably empty if we got a stop from this - if new_chain_words[0] == self.stop: - break - # new word 2 - key_hits = self._retrieve_chains_for_key(self.start2, new_chain_words[0], context_id) - new_chain_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size)) - if new_chain_words[1] != self.stop: - # two valid words, try for a third and check for "foo:" - - # new word 3 (which we may need below) - key_hits = self._retrieve_chains_for_key(new_chain_words[0], new_chain_words[1], context_id) - new_chain_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size)) - - # if the first word is "foo:", start with the second - addressing_suffixes = [':', ','] - if new_chain_words[0][-1] in addressing_suffixes: - gen_words += new_chain_words[1:] - self.log.debug("appending following anti-address " \ - "new_chain_words: {0:s}".format(new_chain_words[1:])) - elif new_chain_words[2] == self.stop: - gen_words += new_chain_words[0:1] - self.log.debug("appending following anti-stop " \ - "new_chain_words: {0:s}".format(new_chain_words[0:1])) - else: - gen_words += new_chain_words[0:] - self.log.debug("appending following extended " \ - "new_chain_words: {0:s}".format(new_chain_words[0:])) + if len(new_sentence) > 0: + self.log.debug("started new sentence " + "'{0:s}'".format(" ".join(new_sentence))) + gen_words += new_sentence + self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) else: - # well, we got one word out of this... let's go with it - # and let the loop check if we need more - self.log.debug("appending following short new_chain_words: {0:s}".format(new_chain_words)) - gen_words += new_chain_words + # this is a problem. we started a sentence on + # start1,start2, and still didn't find anything. to + # avoid endlessly looping we need to abort here + break + else: + if seed_word: + self.log.debug("preferring: '{0:s}'".format(seed_word)) + forw_v = self._retrieve_random_v_for_k1_and_k2_with_pref(gen_words[-2], + gen_words[-1], + seed_word, + context_id) + else: + forw_v = self._retrieve_random_v_for_k1_and_k2(gen_words[-2], + gen_words[-1], + context_id) - # no matter forwards or backwards, use the end of the sentence - # as our current hit word - hit_word = gen_words[-1] + if forw_v: + gen_words.append(forw_v) + self.log.debug("added random word '{0:s}' to gen_words".format(forw_v)) + self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) + hit_word = gen_words[-1] + else: + # append stop. this is an end to a sentence (since + # we had non-start words to begin with) + gen_words.append(self.stop) + self.log.debug("nothing found, added stop") + self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) # chop off the seed data at the start gen_words = gen_words[2:] @@ -492,48 +501,32 @@ class Markov(Module): return ' '.join(gen_words) - def _get_suitable_word_from_choices(self, key_hits, gen_words, min_size): - """Given an existing set of words, and key hits, pick one.""" - - # first, if we're not yet at min_size, pick a non-stop word if it exists - # else, if there were no results, append stop - # otherwise, pick a random result - if len(gen_words) < min_size + 2 and len(filter(lambda a: a != self.stop, key_hits)) > 0: - found_word = random.choice(filter(lambda a: a != self.stop, key_hits)) - return found_word - elif len(key_hits) == 0: - return self.stop - else: - found_word = random.choice(key_hits) - return found_word - - def _retrieve_chains_for_key(self, k1, k2, context_id): - """Get the value(s) for a given key (a pair of strings).""" + def _retrieve_random_v_for_k1_and_k2(self, k1, k2, context_id): + """Get one v for a given k1,k2.""" + self.log.debug("searching with '{0:s}','{1:s}'".format(k1, k2)) values = [] db = self.get_db() try: - query = '' - if k1 == self.start1 and k2 == self.start2: - # hack. get a quasi-random start from the database, in - # a faster fashion than selecting all starts - max_id = self._get_max_chain_id() - rand_id = random.randint(1, max_id) - query = ('SELECT v FROM markov_chain WHERE k1 = %s AND k2 = %s AND ' - '(context_id = %s) AND id >= {0:d} LIMIT 1'.format(rand_id)) - else: - query = ('SELECT v FROM markov_chain WHERE k1 = %s AND k2 = %s AND ' - '(context_id = %s)') + query = ''' + SELECT v FROM markov_chain AS r1 + JOIN ( + SELECT (RAND() * (SELECT MAX(id) FROM markov_chain)) AS id + ) AS r2 + WHERE r1.k1 = %s + AND r1.k2 = %s + AND r1.context_id = %s + ORDER BY r1.id >= r2.id DESC, r1.id ASC + LIMIT 1 + ''' cur = db.cursor(mdb.cursors.DictCursor) cur.execute(query, (k1, k2, context_id)) - results = cur.fetchall() - - for result in results: - values.append(result['v']) - - return values + result = cur.fetchone() + if result: + self.log.debug("found '{0:s}'".format(result['v'])) + return result['v'] except mdb.Error as e: - self.log.error("database error in _retrieve_chains_for_key") + self.log.error("database error in _retrieve_random_v_for_k1_and_k2") self.log.exception(e) raise finally: cur.close() @@ -545,6 +538,8 @@ class Markov(Module): """ + self.log.debug("searching with '{0:s}','{1:s}', prefer " + "'{2:s}'".format(k1, k2, prefer)) values = [] db = self.get_db() try: @@ -553,17 +548,17 @@ class Markov(Module): JOIN ( SELECT (RAND() * (SELECT MAX(id) FROM markov_chain)) AS id ) AS r2 - WHERE r1.id >= r2.id - AND r1.k1 = %s + WHERE r1.k1 = %s AND r1.k2 = %s AND r1.context_id = %s - ORDER BY r1.v = %s DESC, r1.id ASC + ORDER BY r1.id >= r2.id DESC, r1.v = %s DESC, r1.id ASC LIMIT 1 ''' cur = db.cursor(mdb.cursors.DictCursor) cur.execute(query, (k1, k2, context_id, prefer)) result = cur.fetchone() if result: + self.log.debug("found '{0:s}'".format(result['v'])) return result['v'] except mdb.Error as e: self.log.error("database error in _retrieve_random_v_for_k1_and_k2_with_pref") @@ -582,10 +577,9 @@ class Markov(Module): JOIN ( SELECT (RAND() * (SELECT MAX(id) FROM markov_chain)) AS id ) AS r2 - WHERE r1.id >= r2.id - AND r1.v = %s + WHERE r1.v = %s AND r1.context_id = %s - ORDER BY r1.id ASC + ORDER BY r1.id >= r2.id DESC, r1.id ASC LIMIT 1 ''' cur = db.cursor(mdb.cursors.DictCursor) @@ -599,6 +593,25 @@ class Markov(Module): raise finally: cur.close() + def _create_chain_with_k1_k2(self, k1, k2, length, context_id): + """Create a chain of the given length, using k1,k2. + + k1,k2 does not appear in the resulting chain. + + """ + + chain = [k1, k2] + self.log.debug("creating chain for {0:s},{1:s}".format(k1, k2)) + + for _ in range(length): + v = self._retrieve_random_v_for_k1_and_k2(chain[-2], + chain[-1], + context_id) + if v: + chain.append(v) + + return chain[2:] + def _get_chatter_targets(self): """Get all possible chatter targets.""" @@ -616,34 +629,6 @@ class Markov(Module): raise finally: cur.close() - def _get_one_chatter_target(self): - """Select one random chatter target.""" - - targets = self._get_chatter_targets() - if targets: - return targets[random.randint(0, len(targets)-1)] - - def _get_max_chain_id(self): - """Get the highest id in the chain table.""" - - db = self.get_db() - try: - query = ''' - SELECT id FROM markov_chain ORDER BY id DESC LIMIT 1 - ''' - cur = db.cursor(mdb.cursors.DictCursor) - cur.execute(query) - result = cur.fetchone() - if result: - return result['id'] - else: - return None - except mdb.Error as e: - self.log.error("database error in _get_max_chain_id") - self.log.exception(e) - raise - finally: cur.close() - def _get_context_id_for_target(self, target): """Get the context ID for the desired/input target."""