""" Markov - Chatterbot via Markov chains for IRC Copyright (C) 2010 Brian S. Stephan This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . """ from datetime import datetime import random import re import thread import time from dateutil.relativedelta import relativedelta import MySQLdb as mdb from extlib import irclib from Module import Module class Markov(Module): """Create a chatterbot very similar to a MegaHAL, but simpler and implemented in pure Python. Proof of concept code from Ape. Ape wrote: based on this: http://uswaretech.com/blog/2009/06/pseudo-random-text-markov-chains-python/ and this: http://code.activestate.com/recipes/194364-the-markov-chain-algorithm/ """ def __init__(self, irc, config): """Create the Markov chainer, and learn text from a file if available. """ # set up some keywords for use in the chains --- don't change these # once you've created a brain self.start1 = '__start1' self.start2 = '__start2' self.stop = '__stop' # set up regexes, for replying to specific stuff learnpattern = '^!markov\s+learn\s+(.*)$' replypattern = '^!markov\s+reply(\s+min=(\d+))?(\s+max=(\d+))?(\s+(.*)$|$)' self.learnre = re.compile(learnpattern) self.replyre = re.compile(replypattern) self.shut_up = False self.lines_seen = [] Module.__init__(self, irc, config) self.next_shut_up_check = 0 self.next_chatter_check = 0 self.connection = None thread.start_new_thread(self.thread_do, ()) irc.xmlrpc_register_function(self._generate_line, "markov_generate_line") def db_init(self): """Create the markov chain table.""" version = self.db_module_registered(self.__class__.__name__) if version == None: db = self.get_db() try: version = 1 cur = db.cursor(mdb.cursors.DictCursor) cur.execute(''' CREATE TABLE markov_chatter_target ( id SERIAL, target VARCHAR(256) NOT NULL, chance INTEGER NOT NULL DEFAULT 99999 ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin ''') cur.execute(''' CREATE TABLE markov_context ( id SERIAL, context VARCHAR(256) NOT NULL ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin ''') cur.execute(''' CREATE TABLE markov_target_to_context_map ( id SERIAL, target VARCHAR(256) NOT NULL, context_id BIGINT(20) UNSIGNED NOT NULL, FOREIGN KEY(context_id) REFERENCES markov_context(id) ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin ''') cur.execute(''' CREATE TABLE markov_chain ( id SERIAL, k1 VARCHAR(128) NOT NULL, k2 VARCHAR(128) NOT NULL, v VARCHAR(128) NOT NULL, context_id BIGINT(20) UNSIGNED NOT NULL, FOREIGN KEY(context_id) REFERENCES markov_context(id) ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin ''') cur.execute(''' CREATE INDEX markov_chain_keys_and_context_id_index ON markov_chain (k1, k2, context_id)''') cur.execute(''' CREATE INDEX markov_chain_value_and_context_id_index ON markov_chain (v, context_id)''') db.commit() self.db_register_module_version(self.__class__.__name__, version) except mdb.Error as e: db.rollback() self.log.error("database error trying to create tables") self.log.exception(e) raise finally: cur.close() def register_handlers(self): """Handle pubmsg/privmsg, to learn and/or reply to IRC events.""" self.irc.server.add_global_handler('pubmsg', self.on_pub_or_privmsg, self.priority()) self.irc.server.add_global_handler('privmsg', self.on_pub_or_privmsg, self.priority()) self.irc.server.add_global_handler('pubmsg', self.learn_from_irc_event) self.irc.server.add_global_handler('privmsg', self.learn_from_irc_event) def unregister_handlers(self): self.irc.server.remove_global_handler('pubmsg', self.on_pub_or_privmsg) self.irc.server.remove_global_handler('privmsg', self.on_pub_or_privmsg) self.irc.server.remove_global_handler('pubmsg', self.learn_from_irc_event) self.irc.server.remove_global_handler('privmsg', self.learn_from_irc_event) def learn_from_irc_event(self, connection, event): """Learn from IRC events.""" what = ''.join(event.arguments()[0]) my_nick = connection.get_nickname() what = re.sub('^' + my_nick + '[:,]\s+', '', what) target = event.target() nick = irclib.nm_to_n(event.source()) if not irclib.is_channel(target): target = nick self.lines_seen.append((nick, datetime.now())) self.connection = connection # don't learn from commands if self.learnre.search(what) or self.replyre.search(what): return self._learn_line(what, target, event) def do(self, connection, event, nick, userhost, what, admin_unlocked): """Handle commands and inputs.""" target = event.target() if self.learnre.search(what): return self.irc.reply(event, self.markov_learn(connection, event, nick, userhost, what, admin_unlocked)) elif self.replyre.search(what) and not self.shut_up: return self.irc.reply(event, self.markov_reply(connection, event, nick, userhost, what, admin_unlocked)) if not self.shut_up: # not a command, so see if i'm being mentioned if re.search(connection.get_nickname(), what, re.IGNORECASE) is not None: addressed_pattern = '^' + connection.get_nickname() + '[:,]\s+(.*)' addressed_re = re.compile(addressed_pattern) if addressed_re.match(what): # i was addressed directly, so respond, addressing # the speaker self.lines_seen.append(('.self.said.', datetime.now())) return self.irc.reply(event, '{0:s}: {1:s}'.format(nick, self._generate_line(target, line=addressed_re.match(what).group(1)))) else: # i wasn't addressed directly, so just respond self.lines_seen.append(('.self.said.', datetime.now())) return self.irc.reply(event, '{0:s}'.format(self._generate_line(target, line=what))) def markov_learn(self, connection, event, nick, userhost, what, admin_unlocked): """Learn one line, as provided to the command.""" target = event.target() match = self.learnre.search(what) if match: line = match.group(1) self._learn_line(line, target, event) # return what was learned, for weird chaining purposes return line def markov_reply(self, connection, event, nick, userhost, what, admin_unlocked): """Generate a reply to one line, without learning it.""" target = event.target() match = self.replyre.search(what) if match: min_size = 15 max_size = 100 if match.group(2): min_size = int(match.group(2)) if match.group(4): max_size = int(match.group(4)) if match.group(5) != '': line = match.group(6) self.lines_seen.append(('.self.said.', datetime.now())) return self._generate_line(target, line=line, min_size=min_size, max_size=max_size) else: self.lines_seen.append(('.self.said.', datetime.now())) return self._generate_line(target, min_size=min_size, max_size=max_size) def thread_do(self): """Do various things.""" while not self.is_shutdown: self._do_shut_up_checks() self._do_random_chatter_check() time.sleep(1) def _do_random_chatter_check(self): """Randomly say something to a channel.""" # don't immediately potentially chatter, let the bot # join channels first if self.next_chatter_check == 0: self.next_chatter_check = time.time() + 600 if self.next_chatter_check < time.time(): self.next_chatter_check = time.time() + 600 if self.connection is None: # i haven't seen any text yet... return targets = self._get_chatter_targets() for t in targets: if t['chance'] > 0: a = random.randint(1, t['chance']) if a == 1: self.sendmsg(self.connection, t['target'], self._generate_line(t['target'])) def _do_shut_up_checks(self): """Check to see if we've been talking too much, and shut up if so.""" if self.next_shut_up_check < time.time(): self.shut_up = False self.next_shut_up_check = time.time() + 30 last_30_sec_lines = [] for (nick, then) in self.lines_seen: rdelta = relativedelta(datetime.now(), then) if (rdelta.years == 0 and rdelta.months == 0 and rdelta.days == 0 and rdelta.hours == 0 and rdelta.minutes == 0 and rdelta.seconds <= 29): last_30_sec_lines.append((nick, then)) if len(last_30_sec_lines) >= 8: lines_i_said = len(filter(lambda (a, b): a == '.self.said.', last_30_sec_lines)) if lines_i_said >= 8: self.shut_up = True targets = self._get_chatter_targets() for t in targets: self.sendmsg(self.connection, t['target'], 'shutting up for 30 seconds due to last 30 seconds of activity') def _learn_line(self, line, target, event): """Create Markov chains from the provided line.""" # set up the head of the chain k1 = self.start1 k2 = self.start2 context_id = self._get_context_id_for_target(target) # don't learn recursion if not event._recursing: words = line.split() if len(words) == 0: return line db = self.get_db() try: cur = db.cursor(mdb.cursors.DictCursor) statement = 'INSERT INTO markov_chain (k1, k2, v, context_id) VALUES (%s, %s, %s, %s)' for word in words: cur.execute(statement, (k1, k2, word, context_id)) k1, k2 = k2, word cur.execute(statement, (k1, k2, self.stop, context_id)) db.commit() except mdb.Error as e: db.rollback() self.log.error("database error learning line") self.log.exception(e) raise finally: cur.close() def _generate_line(self, target, line='', min_size=15, max_size=100): """Create a line, optionally using some text in a seed as a point in the chain. Keyword arguments: target - the target to retrieve the context for (i.e. a channel or nick) line - the line to reply to, by picking a random word and seeding with it min_size - the minimum desired size in words. not guaranteed max_size - the maximum desired size in words. not guaranteed """ # if the limit is too low, there's nothing to do if (max_size <= 3): raise Exception("max_size is too small: %d" % max_size) # if the min is too large, abort if (min_size > 20): raise Exception("min_size is too large: %d" % min_size) seed_words = [] # shuffle the words in the input seed_words = line.split() random.shuffle(seed_words) self.log.debug("seed words: {0:s}".format(seed_words)) # hit to generate a new seed word immediately if possible seed_word = None hit_word = None context_id = self._get_context_id_for_target(target) # start with an empty chain, and work from there gen_words = [self.start1, self.start2] # walk a chain, randomly, building the list of words while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop: # pick a word from the shuffled seed words, if we need a new one if seed_word == hit_word: if len(seed_words) > 0: seed_word = seed_words.pop() self.log.debug("picked new seed word: {0:s}".format(seed_word)) else: seed_word = None self.log.debug("ran out of seed words") # first, see if we have an empty response and a target word. # if so, work backwards, otherwise forwards if gen_words[-1] == self.start2 and seed_word is not None: # work backwards working_backwards = [] back_k2 = self._retrieve_random_k2_for_value(seed_word, context_id) if back_k2: found_word = seed_word if back_k2 == self.start2: self.log.debug("random further back was start2, swallowing") else: working_backwards.append(back_k2) working_backwards.append(found_word) self.log.debug("started working backwards with: {0:s}".format(found_word)) self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards))) # now work backwards until we randomly bump into a start # to steer the chainer away from spending too much time on # the weaker-context reverse chaining, we make max_size # a non-linear distribution, making it more likely that # some time is spent on better forward chains max_back = random.randint(1, max_size/2) + random.randint(1, max_size/2) self.log.debug("max_back: {0:d}".format(max_back)) while len(working_backwards) < max_back: back_k2 = self._retrieve_random_k2_for_value(working_backwards[0], context_id) if back_k2 == self.start2: self.log.debug("random further back was start2, finishing") break elif back_k2: working_backwards.insert(0, back_k2) self.log.debug("added '{0:s}' to working_backwards".format(back_k2)) self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards))) else: self.log.debug("nothing (at all!?) further back, finishing") break gen_words += working_backwards self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) else: # work forwards self.log.debug("looking forwards") prefer = seed_word if seed_word else '' self.log.debug("preferring: '{0:s}'".format(prefer)) forw_v = self._retrieve_random_v_for_k1_and_k2_with_pref(gen_words[-2], gen_words[-1], prefer, context_id) if forw_v: gen_words.append(forw_v) self.log.debug("added random word '{0:s}' to gen_words".format(forw_v)) self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) else: # append stop, let below code clean it up if necessary gen_words.append(self.stop) self.log.debug("nothing found, adding stop") self.log.debug("gen_words: {0:s}".format(" ".join(gen_words))) # tack a new chain onto the list and resume if we're too short if gen_words[-1] == self.stop and len(gen_words) < min_size + 3: self.log.debug("starting a new chain on end of old one") # chop off the end text, if it was the keyword indicating an end of chain if gen_words[-1] == self.stop: gen_words = gen_words[:-1] # monkey with the end word to make it more like an actual sentence end sentence_end = gen_words[-1] eos_punctuation = ['!', '?', ',', '.'] if sentence_end[-1] not in eos_punctuation: random.shuffle(eos_punctuation) gen_words[-1] = sentence_end + eos_punctuation.pop() self.log.debug("monkeyed with end of sentence, it's now: {0:s}".format(gen_words[-1])) new_chain_words = [] # new word 1 key_hits = self._retrieve_chains_for_key(self.start1, self.start2, context_id) new_chain_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size)) # the database is probably empty if we got a stop from this if new_chain_words[0] == self.stop: break # new word 2 key_hits = self._retrieve_chains_for_key(self.start2, new_chain_words[0], context_id) new_chain_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size)) if new_chain_words[1] != self.stop: # two valid words, try for a third and check for "foo:" # new word 3 (which we may need below) key_hits = self._retrieve_chains_for_key(new_chain_words[0], new_chain_words[1], context_id) new_chain_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size)) # if the first word is "foo:", start with the second addressing_suffixes = [':', ','] if new_chain_words[0][-1] in addressing_suffixes: gen_words += new_chain_words[1:] self.log.debug("appending following anti-address " \ "new_chain_words: {0:s}".format(new_chain_words[1:])) elif new_chain_words[2] == self.stop: gen_words += new_chain_words[0:1] self.log.debug("appending following anti-stop " \ "new_chain_words: {0:s}".format(new_chain_words[0:1])) else: gen_words += new_chain_words[0:] self.log.debug("appending following extended " \ "new_chain_words: {0:s}".format(new_chain_words[0:])) else: # well, we got one word out of this... let's go with it # and let the loop check if we need more self.log.debug("appending following short new_chain_words: {0:s}".format(new_chain_words)) gen_words += new_chain_words # no matter forwards or backwards, use the end of the sentence # as our current hit word hit_word = gen_words[-1] # chop off the seed data at the start gen_words = gen_words[2:] # chop off the end text, if it was the keyword indicating an end of chain if gen_words[-1] == self.stop: gen_words = gen_words[:-1] return ' '.join(gen_words) def _get_suitable_word_from_choices(self, key_hits, gen_words, min_size): """Given an existing set of words, and key hits, pick one.""" # first, if we're not yet at min_size, pick a non-stop word if it exists # else, if there were no results, append stop # otherwise, pick a random result if len(gen_words) < min_size + 2 and len(filter(lambda a: a != self.stop, key_hits)) > 0: found_word = random.choice(filter(lambda a: a != self.stop, key_hits)) return found_word elif len(key_hits) == 0: return self.stop else: found_word = random.choice(key_hits) return found_word def _retrieve_chains_for_key(self, k1, k2, context_id): """Get the value(s) for a given key (a pair of strings).""" values = [] db = self.get_db() try: query = '' if k1 == self.start1 and k2 == self.start2: # hack. get a quasi-random start from the database, in # a faster fashion than selecting all starts max_id = self._get_max_chain_id() rand_id = random.randint(1, max_id) query = ('SELECT v FROM markov_chain WHERE k1 = %s AND k2 = %s AND ' '(context_id = %s) AND id >= {0:d} LIMIT 1'.format(rand_id)) else: query = ('SELECT v FROM markov_chain WHERE k1 = %s AND k2 = %s AND ' '(context_id = %s)') cur = db.cursor(mdb.cursors.DictCursor) cur.execute(query, (k1, k2, context_id)) results = cur.fetchall() for result in results: values.append(result['v']) return values except mdb.Error as e: self.log.error("database error in _retrieve_chains_for_key") self.log.exception(e) raise finally: cur.close() def _retrieve_random_v_for_k1_and_k2_with_pref(self, k1, k2, prefer, context_id): """Get one v for a given k1,k2. Prefer that the result be prefer, if it's found. """ values = [] db = self.get_db() try: query = ''' SELECT v FROM markov_chain AS r1 JOIN ( SELECT (RAND() * (SELECT MAX(id) FROM markov_chain)) AS id ) AS r2 WHERE r1.id >= r2.id AND r1.k1 = %s AND r1.k2 = %s AND r1.context_id = %s ORDER BY r1.v = %s DESC, r1.id ASC LIMIT 1 ''' cur = db.cursor(mdb.cursors.DictCursor) cur.execute(query, (k1, k2, context_id, prefer)) result = cur.fetchone() if result: return result['v'] except mdb.Error as e: self.log.error("database error in _retrieve_random_v_for_k1_and_k2_with_pref") self.log.exception(e) raise finally: cur.close() def _retrieve_random_k2_for_value(self, v, context_id): """Get one k2 for a given value.""" values = [] db = self.get_db() try: query = ''' SELECT k2 FROM markov_chain AS r1 JOIN ( SELECT (RAND() * (SELECT MAX(id) FROM markov_chain)) AS id ) AS r2 WHERE r1.id >= r2.id AND r1.v = %s AND r1.context_id = %s ORDER BY r1.id ASC LIMIT 1 ''' cur = db.cursor(mdb.cursors.DictCursor) cur.execute(query, (v, context_id)) result = cur.fetchone() if result: return result['k2'] except mdb.Error as e: self.log.error("database error in _retrieve_random_k2_for_value") self.log.exception(e) raise finally: cur.close() def _get_chatter_targets(self): """Get all possible chatter targets.""" db = self.get_db() try: # need to create our own db object, since this is likely going to be in a new thread query = 'SELECT target, chance FROM markov_chatter_target' cur = db.cursor(mdb.cursors.DictCursor) cur.execute(query) results = cur.fetchall() return results except mdb.Error as e: self.log.error("database error in _get_chatter_targets") self.log.exception(e) raise finally: cur.close() def _get_one_chatter_target(self): """Select one random chatter target.""" targets = self._get_chatter_targets() if targets: return targets[random.randint(0, len(targets)-1)] def _get_max_chain_id(self): """Get the highest id in the chain table.""" db = self.get_db() try: query = ''' SELECT id FROM markov_chain ORDER BY id DESC LIMIT 1 ''' cur = db.cursor(mdb.cursors.DictCursor) cur.execute(query) result = cur.fetchone() if result: return result['id'] else: return None except mdb.Error as e: self.log.error("database error in _get_max_chain_id") self.log.exception(e) raise finally: cur.close() def _get_context_id_for_target(self, target): """Get the context ID for the desired/input target.""" db = self.get_db() try: query = ''' SELECT mc.id FROM markov_context mc INNER JOIN markov_target_to_context_map mt ON mt.context_id = mc.id WHERE mt.target = %s ''' cur = db.cursor(mdb.cursors.DictCursor) cur.execute(query, (target,)) result = cur.fetchone() db.close() if result: return result['id'] else: # auto-generate a context to keep things private self._add_context_for_target(target) return self._get_context_id_for_target(target) except mdb.Error as e: self.log.error("database error in _get_context_id_for_target") self.log.exception(e) raise finally: cur.close() def _add_context_for_target(self, target): """Create a new context for the desired/input target.""" db = self.get_db() try: statement = 'INSERT INTO markov_context (context) VALUES (%s)' cur = db.cursor(mdb.cursors.DictCursor) cur.execute(statement, (target,)) statement = ''' INSERT INTO markov_target_to_context_map (target, context_id) VALUES (%s, (SELECT id FROM markov_context WHERE context = %s)) ''' cur.execute(statement, (target, target)) db.commit() except mdb.Error as e: db.rollback() self.log.error("database error in _add_context_for_target") self.log.exception(e) raise finally: cur.close() try: query = ''' SELECT mc.id FROM markov_context mc INNER JOIN markov_target_to_context_map mt ON mt.context_id = mc.id WHERE mt.target = %s ''' cur = db.cursor(mdb.cursors.DictCursor) cur.execute(query, (target,)) result = cur.fetchone() if result: return result['id'] else: # auto-generate a context to keep things private self._add_context_for_target(target) return self._get_context_id_for_target(target) except mdb.Error as e: self.log.error("database error in _get_context_id_for_target") self.log.exception(e) raise finally: cur.close() # vi:tabstop=4:expandtab:autoindent # kate: indent-mode python;indent-width 4;replace-tabs on;