dr.botzo/modules/Markov.py
Brian S. Stephan 3e76f75bba Module: remove reply(), use DrBotIRC's
obviously this means all of the modules changed to accommodate. this is
one of many steps to reduce the number of times we pass connections and
servers and other such info around, when it's mostly unnecessary because
modules have a reference to DrBotIRC
2012-12-19 20:51:35 -06:00

659 lines
27 KiB
Python

"""
Markov - Chatterbot via Markov chains for IRC
Copyright (C) 2010 Brian S. Stephan
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from datetime import datetime
import random
import re
import thread
import time
from dateutil.relativedelta import *
import MySQLdb as mdb
from extlib import irclib
from Module import Module
class Markov(Module):
"""
Create a chatterbot very similar to a MegaHAL, but simpler and
implemented in pure Python. Proof of concept code from Ape.
Ape wrote: based on this:
http://uswaretech.com/blog/2009/06/pseudo-random-text-markov-chains-python/
and this:
http://code.activestate.com/recipes/194364-the-markov-chain-algorithm/
"""
def __init__(self, irc, config, server):
    """Set up chain keywords, command regexes and state, then start the worker thread.

    irc, config and server are handed straight to Module.__init__; irc is
    also used to expose _generate_line over XML-RPC as "markov_generate_line".
    """

    # chain sentinels --- these get written into the markov_chain table,
    # so don't change them once you've created a brain
    self.start1, self.start2, self.stop = '__start1', '__start2', '__stop'

    # matchers for the "!markov learn ..." and "!markov reply ..." commands
    self.learnre = re.compile(r'^!markov\s+learn\s+(.*)$')
    self.replyre = re.compile(r'^!markov\s+reply(\s+min=(\d+))?(\s+max=(\d+))?(\s+(.*)$|$)')

    # flood-control state, assigned before the base class initializer runs
    self.lines_seen = []
    self.shut_up = False

    Module.__init__(self, irc, config, server)

    # no connection is known until the first IRC event has been seen
    self.connection = None
    self.next_chatter_check = 0
    self.next_shut_up_check = 0

    # background loop for the shut-up and random chatter checks
    thread.start_new_thread(self.thread_do, ())

    irc.xmlrpc_register_function(self._generate_line, "markov_generate_line")
def db_init(self):
    """Create the Markov tables and indexes if the module is not registered yet.

    Does nothing when db_module_registered() already reports a version.
    Otherwise creates the chatter target, context, target-to-context map
    and chain tables plus two indexes on the chain table, commits, and
    registers the module at version 1.

    Raises mdb.Error (after rollback and logging) if any statement fails.
    """

    version = self.db_module_registered(self.__class__.__name__)
    if version is not None:
        return

    db = self.get_db()
    # keep the cursor name bound so the finally clause cannot raise a
    # NameError (and hide the real error) when db.cursor() itself fails
    cur = None
    try:
        version = 1
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute('''
            CREATE TABLE markov_chatter_target (
                id SERIAL,
                target VARCHAR(256) NOT NULL,
                chance INTEGER NOT NULL DEFAULT 99999
            ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin
            ''')
        cur.execute('''
            CREATE TABLE markov_context (
                id SERIAL,
                context VARCHAR(256) NOT NULL
            ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin
            ''')
        cur.execute('''
            CREATE TABLE markov_target_to_context_map (
                id SERIAL,
                target VARCHAR(256) NOT NULL,
                context_id BIGINT(20) UNSIGNED NOT NULL,
                FOREIGN KEY(context_id) REFERENCES markov_context(id)
            ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin
            ''')
        cur.execute('''
            CREATE TABLE markov_chain (
                id SERIAL,
                k1 VARCHAR(128) NOT NULL,
                k2 VARCHAR(128) NOT NULL,
                v VARCHAR(128) NOT NULL,
                context_id BIGINT(20) UNSIGNED NOT NULL,
                FOREIGN KEY(context_id) REFERENCES markov_context(id)
            ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin
            ''')
        # index for forward lookups (k1, k2 -> v)
        cur.execute('''
            CREATE INDEX markov_chain_keys_and_context_id_index
            ON markov_chain (k1, k2, context_id)''')
        # index for backward lookups (v -> k2)
        cur.execute('''
            CREATE INDEX markov_chain_value_and_context_id_index
            ON markov_chain (v, context_id)''')
        db.commit()
        self.db_register_module_version(self.__class__.__name__, version)
    except mdb.Error as e:
        db.rollback()
        self.log.error("database error trying to create tables")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
def register_handlers(self):
"""Handle pubmsg/privmsg, to learn and/or reply to IRC events."""
self.server.add_global_handler('pubmsg', self.on_pub_or_privmsg, self.priority())
self.server.add_global_handler('privmsg', self.on_pub_or_privmsg, self.priority())
self.server.add_global_handler('pubmsg', self.learn_from_irc_event)
self.server.add_global_handler('privmsg', self.learn_from_irc_event)
def unregister_handlers(self):
self.server.remove_global_handler('pubmsg', self.on_pub_or_privmsg)
self.server.remove_global_handler('privmsg', self.on_pub_or_privmsg)
self.server.remove_global_handler('pubmsg', self.learn_from_irc_event)
self.server.remove_global_handler('privmsg', self.learn_from_irc_event)
def learn_from_irc_event(self, connection, event):
    """Learn from IRC events.

    Strips a leading "<my nick>:" or "<my nick>," address from the text,
    records the speaker in lines_seen, remembers the connection for the
    background thread, and feeds the text to _learn_line unless it is a
    "!markov learn" or "!markov reply" command.
    """

    what = ''.join(event.arguments()[0])
    my_nick = connection.get_nickname()
    # the nick is escaped because IRC nicks may contain regex
    # metacharacters ('[', ']', '|', '\\', '^', ...), and because an
    # unescaped '.' would match any character
    what = re.sub('^' + re.escape(my_nick) + r'[:,]\s+', '', what)
    target = event.target()
    nick = irclib.nm_to_n(event.source())

    # private messages are filed under the sender's nick
    if not irclib.is_channel(target):
        target = nick

    self.lines_seen.append((nick, datetime.now()))
    self.connection = connection

    # don't learn from commands
    if self.learnre.search(what) or self.replyre.search(what):
        return

    self._learn_line(what, target, event)
def do(self, connection, event, nick, userhost, what, admin_unlocked):
"""Handle commands and inputs."""
target = event.target()
if self.learnre.search(what):
return self.irc.reply(event, self.markov_learn(connection, event, nick,
userhost, what, admin_unlocked))
elif self.replyre.search(what) and not self.shut_up:
return self.irc.reply(event, self.markov_reply(connection, event, nick,
userhost, what, admin_unlocked))
if not self.shut_up:
# not a command, so see if i'm being mentioned
if re.search(connection.get_nickname(), what, re.IGNORECASE) is not None:
addressed_pattern = '^' + connection.get_nickname() + '[:,]\s+(.*)'
addressed_re = re.compile(addressed_pattern)
if addressed_re.match(what):
# i was addressed directly, so respond, addressing the speaker
self.lines_seen.append(('.self.said.', datetime.now()))
return self.irc.reply(event, '{0:s}: {1:s}'.format(nick,
self._generate_line(target, line=addressed_re.match(what).group(1))))
else:
# i wasn't addressed directly, so just respond
self.lines_seen.append(('.self.said.', datetime.now()))
return self.irc.reply(event, '{0:s}'.format(self._generate_line(target, line=what)))
def markov_learn(self, connection, event, nick, userhost, what, admin_unlocked):
"""Learn one line, as provided to the command."""
target = event.target()
match = self.learnre.search(what)
if match:
line = match.group(1)
self._learn_line(line, target, event)
# return what was learned, for weird chaining purposes
return line
def markov_reply(self, connection, event, nick, userhost, what, admin_unlocked):
"""Generate a reply to one line, without learning it."""
target = event.target()
match = self.replyre.search(what)
if match:
min_size = 15
max_size = 100
if match.group(2):
min_size = int(match.group(2))
if match.group(4):
max_size = int(match.group(4))
if match.group(5) != '':
line = match.group(6)
self.lines_seen.append(('.self.said.', datetime.now()))
return self._generate_line(target, line=line, min_size=min_size, max_size=max_size)
else:
self.lines_seen.append(('.self.said.', datetime.now()))
return self._generate_line(target, min_size=min_size, max_size=max_size)
def thread_do(self):
"""Do various things."""
while not self.is_shutdown:
self._do_shut_up_checks()
self._do_random_chatter_check()
time.sleep(1)
def _do_random_chatter_check(self):
"""Randomly say something to a channel."""
# don't immediately potentially chatter, let the bot
# join channels first
if self.next_chatter_check == 0:
self.next_chatter_check = time.time() + 600
if self.next_chatter_check < time.time():
self.next_chatter_check = time.time() + 600
if self.connection is None:
# i haven't seen any text yet...
return
targets = self._get_chatter_targets()
for t in targets:
if t['chance'] > 0:
a = random.randint(1, t['chance'])
if a == 1:
self.sendmsg(self.connection, t['target'], self._generate_line(t['target']))
def _do_shut_up_checks(self):
"""Check to see if we've been talking too much, and shut up if so."""
if self.next_shut_up_check < time.time():
self.shut_up = False
self.next_shut_up_check = time.time() + 30
last_30_sec_lines = []
for (nick,then) in self.lines_seen:
rdelta = relativedelta(datetime.now(), then)
if (rdelta.years == 0 and rdelta.months == 0 and rdelta.days == 0 and
rdelta.hours == 0 and rdelta.minutes == 0 and rdelta.seconds <= 29):
last_30_sec_lines.append((nick,then))
if len(last_30_sec_lines) >= 8:
lines_i_said = len(filter(lambda (a,b): a == '.self.said.', last_30_sec_lines))
if lines_i_said >= 8:
self.shut_up = True
targets = self._get_chatter_targets()
for t in targets:
self.sendmsg(self.connection, t['target'],
'shutting up for 30 seconds due to last 30 seconds of activity')
def _learn_line(self, line, target, event):
"""Create Markov chains from the provided line."""
# set up the head of the chain
k1 = self.start1
k2 = self.start2
context_id = self._get_context_id_for_target(target)
# don't learn recursion
if not event._recursing:
words = line.split()
if len(words) == 0:
return line
db = self.get_db()
try:
cur = db.cursor(mdb.cursors.DictCursor)
statement = 'INSERT INTO markov_chain (k1, k2, v, context_id) VALUES (%s, %s, %s, %s)'
for word in words:
cur.execute(statement, (k1, k2, word, context_id))
k1, k2 = k2, word
cur.execute(statement, (k1, k2, self.stop, context_id))
db.commit()
except mdb.Error as e:
db.rollback()
self.log.error("database error learning line")
self.log.exception(e)
raise
finally: cur.close()
def _generate_line(self, target, line='', min_size=15, max_size=100):
"""
Create a line, optionally using some text in a seed as a point in the chain.
Keyword arguments:
target - the target to retrieve the context for (i.e. a channel or nick)
line - the line to reply to, by picking a random word and seeding with it
min_size - the minimum desired size in words. not guaranteed
max_size - the maximum desired size in words. not guaranteed
"""
# if the limit is too low, there's nothing to do
if (max_size <= 3):
raise Exception("max_size is too small: %d" % max_size)
# if the min is too large, abort
if (min_size > 20):
raise Exception("min_size is too large: %d" % min_size)
seed_words = []
# shuffle the words in the input
seed_words = line.split()
random.shuffle(seed_words)
self.log.debug("seed words: {0:s}".format(seed_words))
# hit to generate a new seed word immediately if possible
seed_word = None
hit_word = None
context_id = self._get_context_id_for_target(target)
# start with an empty chain, and work from there
gen_words = [self.start1, self.start2]
# walk a chain, randomly, building the list of words
while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
# pick a word from the shuffled seed words, if we need a new one
if seed_word == hit_word:
if len(seed_words) > 0:
seed_word = seed_words.pop()
self.log.debug("picked new seed word: {0:s}".format(seed_word))
else:
seed_word = None
self.log.debug("ran out of seed words")
# first, see if we have an empty response and a target word.
# if so, work backwards, otherwise forwards
if gen_words[-1] == self.start2 and seed_word is not None:
# work backwards
working_backwards = []
key_hits = self._retrieve_k2_for_value(seed_word, context_id)
if len(key_hits):
found_word = seed_word
working_backwards.append(found_word)
self.log.debug("started working backwards with: {0:s}".format(found_word))
self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards)))
# now work backwards until we randomly bump into a start
while True:
key_hits = self._retrieve_k2_for_value(working_backwards[0], context_id)
if self.start2 in key_hits and len(working_backwards) > 2:
self.log.debug("working backwards forced start, finishing")
gen_words += working_backwards
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
break
elif len(key_hits) == 0:
self.log.debug("no key_hits, finishing")
gen_words += working_backwards
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
break
elif len(filter(lambda a: a!= self.start2, key_hits)) == 0:
self.log.debug("only start2 in key_hits, finishing")
gen_words += working_backwards
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
break
else:
found_word = random.choice(filter(lambda a: a != self.start2, key_hits))
working_backwards.insert(0, found_word)
self.log.debug("added '{0:s}' to working_backwards".format(found_word))
self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards)))
hit_word = seed_word
else:
# work forwards
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context_id)
# use the chain that includes the target word, if it is found
if seed_word is not None and seed_word in key_hits:
hit_word = seed_word
self.log.debug("added seed word '{0:s}' to gen_words".format(hit_word))
else:
hit_word = self._get_suitable_word_from_choices(key_hits, gen_words, min_size)
self.log.debug("added random word '{0:s}' to gen_words".format(hit_word))
# from either method, append result
gen_words.append(hit_word)
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
# tack a new chain onto the list and resume if we're too short
if gen_words[-1] == self.stop and len(gen_words) < min_size + 3:
self.log.debug("starting a new chain on end of old one")
# chop off the end text, if it was the keyword indicating an end of chain
if gen_words[-1] == self.stop:
gen_words = gen_words[:-1]
# monkey with the end word to make it more like an actual sentence end
sentence_end = gen_words[-1]
eos_punctuation = ['!', '?', ',', '.']
if sentence_end[-1] not in eos_punctuation:
random.shuffle(eos_punctuation)
gen_words[-1] = sentence_end + eos_punctuation.pop()
self.log.debug("monkeyed with end of sentence, it's now: {0:s}".format(gen_words[-1]))
new_chain_words = []
# new word 1
key_hits = self._retrieve_chains_for_key(self.start1, self.start2, context_id)
new_chain_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
# the database is probably empty if we got a stop from this
if new_chain_words[0] == self.stop:
break
# new word 2
key_hits = self._retrieve_chains_for_key(self.start2, new_chain_words[0], context_id)
new_chain_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
if new_chain_words[1] != self.stop:
# two valid words, try for a third and check for "foo:"
# new word 3 (which we may need below)
key_hits = self._retrieve_chains_for_key(new_chain_words[0], new_chain_words[1], context_id)
new_chain_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
# if the first word is "foo:", start with the second
addressing_suffixes = [':', ',']
if new_chain_words[0][-1] in addressing_suffixes:
gen_words += new_chain_words[1:]
self.log.debug("appending following anti-address " \
"new_chain_words: {0:s}".format(new_chain_words[1:]))
elif new_chain_words[2] == self.stop:
gen_words += new_chain_words[0:1]
self.log.debug("appending following anti-stop " \
"new_chain_words: {0:s}".format(new_chain_words[0:1]))
else:
gen_words += new_chain_words[0:]
self.log.debug("appending following extended " \
"new_chain_words: {0:s}".format(new_chain_words[0:]))
else:
# well, we got one word out of this... let's go with it
# and let the loop check if we need more
self.log.debug("appending following short new_chain_words: {0:s}".format(new_chain_words))
gen_words += new_chain_words
# chop off the seed data at the start
gen_words = gen_words[2:]
# chop off the end text, if it was the keyword indicating an end of chain
if gen_words[-1] == self.stop:
gen_words = gen_words[:-1]
return ' '.join(gen_words)
def _get_suitable_word_from_choices(self, key_hits, gen_words, min_size):
"""Given an existing set of words, and key hits, pick one."""
# first, if we're not yet at min_size, pick a non-stop word if it exists
# else, if there were no results, append stop
# otherwise, pick a random result
if len(gen_words) < min_size + 2 and len(filter(lambda a: a != self.stop, key_hits)) > 0:
found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
return found_word
elif len(key_hits) == 0:
return self.stop
else:
found_word = random.choice(key_hits)
return found_word
def _retrieve_chains_for_key(self, k1, k2, context_id):
    """Get the value(s) for a given key (a pair of strings).

    For the start-of-chain key only a single quasi-random value is
    fetched; for any other key every stored value is returned.

    Returns a list of words, empty when nothing matches (including when
    the chain table holds no rows at all).
    Raises mdb.Error (after logging) when the query fails.
    """

    values = []
    db = self.get_db()
    # keep the cursor name bound so the finally clause cannot raise a
    # NameError (and hide the real error) when db.cursor() itself fails
    cur = None
    try:
        query = ('SELECT v FROM markov_chain WHERE k1 = %s AND k2 = %s AND '
                 '(context_id = %s)')
        params = (k1, k2, context_id)

        if k1 == self.start1 and k2 == self.start2:
            # hack. get a quasi-random start from the database, in
            # a faster fashion than selecting all starts
            max_id = self._get_max_chain_id()
            if max_id is None:
                # no rows at all, so there is no start to pick
                return values
            rand_id = random.randint(1, max_id)
            query += ' AND id >= %s LIMIT 1'
            params += (rand_id,)

        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query, params)
        for result in cur.fetchall():
            values.append(result['v'])

        return values
    except mdb.Error as e:
        self.log.error("database error in _retrieve_chains_for_key")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
def _retrieve_k2_for_value(self, v, context_id):
    """Get the k2 word(s) stored directly before the given value.

    This is the backwards lookup: it returns the k2 column of every
    chain row in the context whose value is v, as a list (empty when
    there is none).
    Raises mdb.Error (after logging) when the query fails.
    """

    values = []
    db = self.get_db()
    # keep the cursor name bound so the finally clause cannot raise a
    # NameError (and hide the real error) when db.cursor() itself fails
    cur = None
    try:
        query = 'SELECT k2 FROM markov_chain WHERE v = %s AND context_id = %s'
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query, (v, context_id))
        for result in cur.fetchall():
            values.append(result['k2'])

        return values
    except mdb.Error as e:
        self.log.error("database error in _retrieve_k2_for_value")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
def _get_chatter_targets(self):
    """Get all possible chatter targets.

    Returns the fetched rows, each a mapping with 'target' and 'chance'.
    Raises mdb.Error (after logging) when the query fails.
    """

    # need to create our own db object, since this is likely going to be in a new thread
    db = self.get_db()
    # keep the cursor name bound so the finally clause cannot raise a
    # NameError (and hide the real error) when db.cursor() itself fails
    cur = None
    try:
        query = 'SELECT target, chance FROM markov_chatter_target'
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query)
        results = cur.fetchall()
        return results
    except mdb.Error as e:
        self.log.error("database error in _get_chatter_targets")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
def _get_one_chatter_target(self):
"""Select one random chatter target."""
targets = self._get_chatter_targets()
if targets:
return targets[random.randint(0, len(targets)-1)]
def _get_max_chain_id(self):
    """Get the highest id in the chain table.

    Returns None when the table has no rows.
    Raises mdb.Error (after logging) when the query fails.
    """

    db = self.get_db()
    # keep the cursor name bound so the finally clause cannot raise a
    # NameError (and hide the real error) when db.cursor() itself fails
    cur = None
    try:
        query = '''
            SELECT id FROM markov_chain ORDER BY id DESC LIMIT 1
            '''
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query)
        result = cur.fetchone()
        if result:
            return result['id']
        return None
    except mdb.Error as e:
        self.log.error("database error in _get_max_chain_id")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
def _get_context_id_for_target(self, target):
    """Get the context ID for the desired/input target.

    When the target is not mapped to a context yet, one is created for
    it via _add_context_for_target and the lookup is repeated.
    Raises mdb.Error (after logging) when the query fails.
    """

    db = self.get_db()
    # keep the cursor name bound so the finally clause cannot raise a
    # NameError (and hide the real error) when db.cursor() itself fails
    cur = None
    try:
        query = '''
            SELECT mc.id FROM markov_context mc
            INNER JOIN markov_target_to_context_map mt
                ON mt.context_id = mc.id
            WHERE mt.target = %s
            '''
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query, (target,))
        result = cur.fetchone()
        # NOTE(review): the handle is deliberately not closed here; no
        # other method in this class closes what get_db() returns, and
        # closing it ahead of the cursor is the wrong order in any case.
        # confirm against get_db()'s contract
    except mdb.Error as e:
        self.log.error("database error in _get_context_id_for_target")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()

    if result:
        return result['id']

    # auto-generate a context to keep things private
    self._add_context_for_target(target)
    return self._get_context_id_for_target(target)
def _add_context_for_target(self, target):
    """Create a new context for the desired/input target.

    Inserts a context named after the target, maps the target to it and
    commits, then looks the mapping up again.

    Returns the id of the context the target is mapped to, or None if
    the mapping cannot be found after the insert.
    Raises mdb.Error (after logging, and after rollback for the insert)
    when a statement fails.
    """

    db = self.get_db()
    # keep the cursor name bound so the finally clauses cannot raise a
    # NameError (and hide the real error) when db.cursor() itself fails
    cur = None
    try:
        statement = 'INSERT INTO markov_context (context) VALUES (%s)'
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(statement, (target,))
        statement = '''
            INSERT INTO markov_target_to_context_map (target, context_id)
            VALUES (%s, (SELECT id FROM markov_context WHERE context = %s))
            '''
        cur.execute(statement, (target, target))
        db.commit()
    except mdb.Error as e:
        db.rollback()
        self.log.error("database error in _add_context_for_target")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()

    # read the new mapping back so the id can be returned
    cur = None
    try:
        query = '''
            SELECT mc.id FROM markov_context mc
            INNER JOIN markov_target_to_context_map mt
                ON mt.context_id = mc.id
            WHERE mt.target = %s
            '''
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query, (target,))
        result = cur.fetchone()
    except mdb.Error as e:
        self.log.error("database error in _add_context_for_target")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()

    if result:
        return result['id']

    # the mapping was committed above, so this should not happen. adding
    # yet another context here could repeat without end, so report the
    # problem and give up instead
    self.log.error("context for target was not found after creating it")
    return None
# vi:tabstop=4:expandtab:autoindent
# kate: indent-mode python;indent-width 4;replace-tabs on;