diff --git a/modules/Markov.py b/modules/Markov.py index 6bb29ef..b2f9125 100644 --- a/modules/Markov.py +++ b/modules/Markov.py @@ -413,20 +413,21 @@ class Markov(Module): hit_word = seed_word else: # work forwards - - key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context_id) - - # use the chain that includes the target word, if it is found - if seed_word is not None and seed_word in key_hits: - hit_word = seed_word - self.log.debug("added seed word '{0:s}' to gen_words".format(hit_word)) + self.log.debug("looking forwards") + prefer = seed_word if seed_word else '' + self.log.debug("preferring: '{0:s}'".format(prefer)) + forw_v = self._retrieve_random_v_for_k1_and_k2_with_pref(gen_words[-2], + gen_words[-1], + prefer, context_id) + if forw_v: + gen_words.append(forw_v) + self.log.debug("added random word '{0:s}' to gen_words".format(forw_v)) + self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) else: - hit_word = self._get_suitable_word_from_choices(key_hits, gen_words, min_size) - self.log.debug("added random word '{0:s}' to gen_words".format(hit_word)) - - # from either method, append result - gen_words.append(hit_word) - self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) + # append stop, let below code clean it up if necessary + gen_words.append(self.stop) + self.log.debug("nothing found, adding stop") + self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) # tack a new chain onto the list and resume if we're too short if gen_words[-1] == self.stop and len(gen_words) < min_size + 3: @@ -536,6 +537,39 @@ class Markov(Module): raise finally: cur.close() + def _retrieve_random_v_for_k1_and_k2_with_pref(self, k1, k2, prefer, context_id): + """Get one v for a given k1,k2. + + Prefer that the result be prefer, if it's found. + + """ + + values = [] + db = self.get_db() + try: + query = ''' + SELECT v FROM markov_chain AS r1 + JOIN ( + SELECT (RAND() * (SELECT MAX(id) FROM markov_chain)) AS id + ) AS r2 + WHERE r1.id >= r2.id + AND r1.k1 = %s + AND r1.k2 = %s + AND r1.context_id = %s + ORDER BY r1.v = %s DESC, r1.id ASC + LIMIT 1 + ''' + cur = db.cursor(mdb.cursors.DictCursor) + cur.execute(query, (k1, k2, context_id, prefer)) + result = cur.fetchone() + if result: + return result['v'] + except mdb.Error as e: + self.log.error("database error in _retrieve_random_v_for_k1_and_k2_with_pref") + self.log.exception(e) + raise + finally: cur.close() + def _retrieve_random_k2_for_value(self, v, context_id): """Get one k2 for a given value."""