From a0588869f3ea0bcd94ed08a312b8372e51a74dfd Mon Sep 17 00:00:00 2001 From: "Brian S. Stephan" Date: Tue, 14 Jun 2011 22:10:57 -0500 Subject: [PATCH] Markov: add selecting by context, in order to segregate chains by channel adding chains by context has existed for a while, this should allow for querying for chains with null context or the current context. lightly tested --- modules/Markov.py | 40 ++++++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/modules/Markov.py b/modules/Markov.py index 2e8b4be..b0ae8e0 100644 --- a/modules/Markov.py +++ b/modules/Markov.py @@ -159,6 +159,19 @@ class Markov(Module): db.rollback() print('sqlite error: ' + str(e)) raise + if (version < 6): + db = self.get_db() + try: + version = 6 + db.execute(''' + CREATE INDEX markov_chain_keys_and_context_index + ON markov_chain (k1, k2, context)''') + db.commit() + self.db_register_module_version(self.__class__.__name__, version) + except sqlite3.Error as e: + db.rollback() + print('sqlite error: ' + str(e)) + raise def register_handlers(self): """Handle pubmsg/privmsg, to learn and/or reply to IRC events.""" @@ -195,6 +208,8 @@ class Markov(Module): def do(self, connection, event, nick, userhost, what, admin_unlocked): """Handle commands and inputs.""" + target = event.target() + if self.trainre.search(what): return self.reply(connection, event, self.markov_train(connection, event, nick, userhost, what, admin_unlocked)) elif self.learnre.search(what): @@ -210,11 +225,11 @@ class Markov(Module): if addressed_re.match(what): # i was addressed directly, so respond, addressing the speaker self.lines_seen.append(('.self.said.', datetime.now())) - return self.reply(connection, event, '{0:s}: {1:s}'.format(nick, self._generate_line(line=addressed_re.match(what).group(1)))) + return self.reply(connection, event, '{0:s}: {1:s}'.format(nick, self._generate_line(line=addressed_re.match(what).group(1), target=target))) else: # i wasn't addressed directly, so just respond self.lines_seen.append(('.self.said.', datetime.now())) - return self.reply(connection, event, '{0:s}'.format(self._generate_line(line=what))) + return self.reply(connection, event, '{0:s}'.format(self._generate_line(line=what, target=target))) def markov_train(self, connection, event, nick, userhost, what, admin_unlocked): """Learn lines from a file. Good for initializing a brain.""" @@ -246,6 +261,7 @@ class Markov(Module): def markov_reply(self, connection, event, nick, userhost, what, admin_unlocked): """Generate a reply to one line, without learning it.""" + target = event.target() match = self.replyre.search(what) if match: min_size = 15 @@ -259,10 +275,10 @@ class Markov(Module): if match.group(5) != '': line = match.group(6) self.lines_seen.append(('.self.said.', datetime.now())) - return self._generate_line(line=line, min_size=min_size, max_size=max_size) + return self._generate_line(line=line, min_size=min_size, max_size=max_size, target=target) else: self.lines_seen.append(('.self.said.', datetime.now())) - return self._generate_line(min_size=min_size, max_size=max_size) + return self._generate_line(min_size=min_size, max_size=max_size, target=target) def timer_do(self): """Do various things.""" @@ -329,7 +345,7 @@ class Markov(Module): print("sqlite error: " + str(e)) raise - def _generate_line(self, line='', min_size=15, max_size=100): + def _generate_line(self, line='', min_size=15, max_size=100, target=None): """Reply to a line, using some text in the line as a point in the chain.""" # if the limit is too low, there's nothing to do @@ -347,12 +363,16 @@ class Markov(Module): words = line.split() target_word = words[random.randint(0, len(words)-1)] + context = 0 + if target: + context = self._get_context_for_target(target) + # start with an empty chain, and work from there gen_words = [self.start1, self.start2] # walk a chain, randomly, building the list of words while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop: - key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1]) + key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context) # use the chain that includes the target word, if it is found if target_word != '' and target_word in key_hits: gen_words.append(target_word) @@ -375,7 +395,7 @@ class Markov(Module): return ' '.join(gen_words).encode('utf-8', 'ignore') - def _retrieve_chains_for_key(self, k1, k2): + def _retrieve_chains_for_key(self, k1, k2, context): """Get the value(s) for a given key (a pair of strings).""" values = [] @@ -387,10 +407,10 @@ class Markov(Module): # a faster fashion than selecting all starts max_id = self._get_max_chain_id() rand_id = random.randint(1,max_id) - query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND id >= {0:d} LIMIT 1'.format(rand_id) + query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND (context = ? OR context IS NULL) AND id >= {0:d} LIMIT 1'.format(rand_id) else: - query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ?' - cursor = db.execute(query, (k1,k2)) + query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND (context = ? OR context IS NULL)' + cursor = db.execute(query, (k1,k2,context)) results = cursor.fetchall() for result in results: