dr.botzo/modules/Markov.py

659 lines
27 KiB
Python
Raw Normal View History

"""
Markov - Chatterbot via Markov chains for IRC
Copyright (C) 2010 Brian S. Stephan
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from datetime import datetime
import random
import re
import thread
import time
2012-07-27 20:38:45 -05:00
from dateutil.relativedelta import *
import MySQLdb as mdb
from extlib import irclib
from Module import Module
class Markov(Module):
"""
Create a chatterbot very similar to a MegaHAL, but simpler and
implemented in pure Python. Proof of concept code from Ape.
2011-01-20 14:15:10 -06:00
Ape wrote: based on this:
http://uswaretech.com/blog/2009/06/pseudo-random-text-markov-chains-python/
and this:
http://code.activestate.com/recipes/194364-the-markov-chain-algorithm/
"""
2011-01-20 14:15:10 -06:00
def __init__(self, irc, config):
    """Create the Markov chainer and start its background worker thread.

    Keyword arguments:
    irc - the bot object; passed to Module.__init__ and used to register
          _generate_line as the "markov_generate_line" XML-RPC function
    config - the configuration object, passed to Module.__init__
    """

    # set up some keywords for use in the chains --- don't change these
    # once you've created a brain
    self.start1 = '__start1'
    self.start2 = '__start2'
    self.stop = '__stop'

    # set up regexes, for replying to specific stuff. raw strings so the
    # \s and \d escapes reach the regex engine untouched
    learnpattern = r'^!markov\s+learn\s+(.*)$'
    replypattern = r'^!markov\s+reply(\s+min=(\d+))?(\s+max=(\d+))?(\s+(.*)$|$)'
    self.learnre = re.compile(learnpattern)
    self.replyre = re.compile(replypattern)

    # shut_up is toggled by _do_shut_up_checks; lines_seen holds
    # (nick, datetime) tuples of recently seen/said lines
    self.shut_up = False
    self.lines_seen = []

    Module.__init__(self, irc, config)

    # 0 means "never checked yet" for both timers
    self.next_shut_up_check = 0
    self.next_chatter_check = 0
    # most recent IRC connection seen by learn_from_irc_event
    self.connection = None

    # the worker needs the attributes above, so start it last
    thread.start_new_thread(self.thread_do, ())

    irc.xmlrpc_register_function(self._generate_line, "markov_generate_line")
def db_init(self):
    """Create the markov tables when no schema version is registered yet.

    Does nothing if db_module_registered() reports a version for this
    class. Otherwise creates the chatter target, context, target-to-context
    map and chain tables plus the two chain indexes, then registers schema
    version 1.

    Raises mdb.Error (after logging it) if any statement fails.
    """

    version = self.db_module_registered(self.__class__.__name__)
    if version is not None:
        return

    version = 1
    db = self.get_db()
    # cur stays None if cursor creation itself fails, so the finally
    # clause cannot raise a NameError that hides the real error
    cur = None
    try:
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute('''
            CREATE TABLE markov_chatter_target (
                id SERIAL,
                target VARCHAR(256) NOT NULL,
                chance INTEGER NOT NULL DEFAULT 99999
            ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin
            ''')
        cur.execute('''
            CREATE TABLE markov_context (
                id SERIAL,
                context VARCHAR(256) NOT NULL
            ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin
            ''')
        cur.execute('''
            CREATE TABLE markov_target_to_context_map (
                id SERIAL,
                target VARCHAR(256) NOT NULL,
                context_id BIGINT(20) UNSIGNED NOT NULL,
                FOREIGN KEY(context_id) REFERENCES markov_context(id)
            ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin
            ''')
        cur.execute('''
            CREATE TABLE markov_chain (
                id SERIAL,
                k1 VARCHAR(128) NOT NULL,
                k2 VARCHAR(128) NOT NULL,
                v VARCHAR(128) NOT NULL,
                context_id BIGINT(20) UNSIGNED NOT NULL,
                FOREIGN KEY(context_id) REFERENCES markov_context(id)
            ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin
            ''')
        cur.execute('''
            CREATE INDEX markov_chain_keys_and_context_id_index
            ON markov_chain (k1, k2, context_id)''')
        cur.execute('''
            CREATE INDEX markov_chain_value_and_context_id_index
            ON markov_chain (v, context_id)''')
        db.commit()
        self.db_register_module_version(self.__class__.__name__, version)
    except mdb.Error as e:
        # NOTE(review): MySQL DDL usually commits implicitly, so this
        # rollback probably cannot undo tables already created -- confirm
        db.rollback()
        self.log.error("database error trying to create tables")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
def register_handlers(self):
    """Attach the reply and learn handlers to public and private messages.

    on_pub_or_privmsg is registered with this module's priority();
    learn_from_irc_event is registered without an explicit priority.
    """

    server = self.irc.server
    event_types = ('pubmsg', 'privmsg')

    for event_type in event_types:
        server.add_global_handler(event_type, self.on_pub_or_privmsg, self.priority())

    for event_type in event_types:
        server.add_global_handler(event_type, self.learn_from_irc_event)
def unregister_handlers(self):
    """Detach the reply and learn handlers from public and private messages."""

    server = self.irc.server

    for handler in (self.on_pub_or_privmsg, self.learn_from_irc_event):
        for event_type in ('pubmsg', 'privmsg'):
            server.remove_global_handler(event_type, handler)
def learn_from_irc_event(self, connection, event):
    """Learn from IRC events.

    Strips a leading "<botnick>: " or "<botnick>, " from the text, records
    the line in lines_seen, remembers the connection, and feeds everything
    that is not a !markov command to _learn_line.

    Keyword arguments:
    connection - the IRC connection the event arrived on
    event - the pubmsg/privmsg event
    """

    what = ''.join(event.arguments()[0])
    my_nick = connection.get_nickname()
    # the nickname is escaped because IRC nicks may contain regex
    # metacharacters such as [ ] \ ^ | { }
    what = re.sub('^' + re.escape(my_nick) + r'[:,]\s+', '', what)
    target = event.target()
    nick = irclib.nm_to_n(event.source())

    # for private messages, the context is the speaker rather than the target
    if not irclib.is_channel(target):
        target = nick

    self.lines_seen.append((nick, datetime.now()))
    # remembered so the background thread has something to talk through
    self.connection = connection

    # don't learn from commands
    if self.learnre.search(what) or self.replyre.search(what):
        return

    self._learn_line(what, target, event)
def do(self, connection, event, nick, userhost, what, admin_unlocked):
    """Handle commands and inputs.

    "!markov learn" is always handled. "!markov reply" and plain mentions
    of the bot's nickname are only answered while shut_up is False.
    Returns whatever self.irc.reply returns, or None when nothing is said.
    """

    target = event.target()

    if self.learnre.search(what):
        return self.irc.reply(event, self.markov_learn(connection, event, nick,
                                                       userhost, what, admin_unlocked))
    elif self.replyre.search(what) and not self.shut_up:
        return self.irc.reply(event, self.markov_reply(connection, event, nick,
                                                       userhost, what, admin_unlocked))

    if not self.shut_up:
        # the nickname is escaped because IRC nicks may contain regex
        # metacharacters such as [ ] \ ^ | { }
        my_nick = re.escape(connection.get_nickname())

        # not a command, so see if i'm being mentioned
        if re.search(my_nick, what, re.IGNORECASE) is not None:
            # note: the mention test ignores case, the address test does not
            addressed = re.match('^' + my_nick + r'[:,]\s+(.*)', what)
            self.lines_seen.append(('.self.said.', datetime.now()))

            if addressed:
                # i was addressed directly, so respond, addressing the speaker
                return self.irc.reply(event, '{0:s}: {1:s}'.format(nick,
                  self._generate_line(target, line=addressed.group(1))))
            else:
                # i wasn't addressed directly, so just respond
                return self.irc.reply(event, '{0:s}'.format(self._generate_line(target, line=what)))
def markov_learn(self, connection, event, nick, userhost, what, admin_unlocked):
    """Learn the text given to a "!markov learn" command.

    Returns the learned text (handy when commands are chained), or None
    when 'what' is not a learn command.
    """

    target = event.target()
    found = self.learnre.search(what)
    if not found:
        return None

    learned = found.group(1)
    self._learn_line(learned, target, event)

    # return what was learned, for weird chaining purposes
    return learned
def markov_reply(self, connection, event, nick, userhost, what, admin_unlocked):
    """Answer a "!markov reply" command without learning from it.

    Optional min=N / max=N arguments override the default sizes of 15 and
    100 words; any trailing text seeds the reply. Returns the generated
    line, or None when 'what' is not a reply command.
    """

    target = event.target()
    found = self.replyre.search(what)
    if not found:
        return None

    # groups 2 and 4 are the digits of min= and max=, when given
    min_size = int(found.group(2)) if found.group(2) else 15
    max_size = int(found.group(4)) if found.group(4) else 100

    self.lines_seen.append(('.self.said.', datetime.now()))

    # group 5 is the empty string when no seed text followed the command
    if found.group(5) != '':
        seed = found.group(6)
        return self._generate_line(target, line=seed, min_size=min_size, max_size=max_size)

    return self._generate_line(target, min_size=min_size, max_size=max_size)
def thread_do(self):
"""Do various things."""
while not self.is_shutdown:
self._do_shut_up_checks()
self._do_random_chatter_check()
time.sleep(1)
def _do_random_chatter_check(self):
"""Randomly say something to a channel."""
# don't immediately potentially chatter, let the bot
# join channels first
if self.next_chatter_check == 0:
self.next_chatter_check = time.time() + 600
if self.next_chatter_check < time.time():
self.next_chatter_check = time.time() + 600
if self.connection is None:
# i haven't seen any text yet...
return
targets = self._get_chatter_targets()
for t in targets:
if t['chance'] > 0:
a = random.randint(1, t['chance'])
if a == 1:
self.sendmsg(self.connection, t['target'], self._generate_line(t['target']))
def _do_shut_up_checks(self):
"""Check to see if we've been talking too much, and shut up if so."""
if self.next_shut_up_check < time.time():
self.shut_up = False
self.next_shut_up_check = time.time() + 30
last_30_sec_lines = []
for (nick,then) in self.lines_seen:
rdelta = relativedelta(datetime.now(), then)
if (rdelta.years == 0 and rdelta.months == 0 and rdelta.days == 0 and
rdelta.hours == 0 and rdelta.minutes == 0 and rdelta.seconds <= 29):
last_30_sec_lines.append((nick,then))
if len(last_30_sec_lines) >= 8:
lines_i_said = len(filter(lambda (a,b): a == '.self.said.', last_30_sec_lines))
if lines_i_said >= 8:
self.shut_up = True
targets = self._get_chatter_targets()
for t in targets:
self.sendmsg(self.connection, t['target'],
'shutting up for 30 seconds due to last 30 seconds of activity')
def _learn_line(self, line, target, event):
"""Create Markov chains from the provided line."""
# set up the head of the chain
k1 = self.start1
k2 = self.start2
context_id = self._get_context_id_for_target(target)
# don't learn recursion
if not event._recursing:
words = line.split()
2012-07-29 17:46:14 -05:00
if len(words) == 0:
return line
db = self.get_db()
try:
cur = db.cursor(mdb.cursors.DictCursor)
statement = 'INSERT INTO markov_chain (k1, k2, v, context_id) VALUES (%s, %s, %s, %s)'
for word in words:
cur.execute(statement, (k1, k2, word, context_id))
k1, k2 = k2, word
cur.execute(statement, (k1, k2, self.stop, context_id))
db.commit()
except mdb.Error as e:
db.rollback()
self.log.error("database error learning line")
self.log.exception(e)
raise
finally: cur.close()
def _generate_line(self, target, line='', min_size=15, max_size=100):
"""
Create a line, optionally using some text in a seed as a point in the chain.
Keyword arguments:
target - the target to retrieve the context for (i.e. a channel or nick)
line - the line to reply to, by picking a random word and seeding with it
min_size - the minimum desired size in words. not guaranteed
max_size - the maximum desired size in words. not guaranteed
"""
# if the limit is too low, there's nothing to do
if (max_size <= 3):
raise Exception("max_size is too small: %d" % max_size)
# if the min is too large, abort
if (min_size > 20):
raise Exception("min_size is too large: %d" % min_size)
seed_words = []
# shuffle the words in the input
seed_words = line.split()
random.shuffle(seed_words)
self.log.debug("seed words: {0:s}".format(seed_words))
# hit to generate a new seed word immediately if possible
seed_word = None
hit_word = None
context_id = self._get_context_id_for_target(target)
# start with an empty chain, and work from there
2011-04-23 16:27:07 -05:00
gen_words = [self.start1, self.start2]
# walk a chain, randomly, building the list of words
while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
# pick a word from the shuffled seed words, if we need a new one
if seed_word == hit_word:
if len(seed_words) > 0:
seed_word = seed_words.pop()
self.log.debug("picked new seed word: {0:s}".format(seed_word))
else:
seed_word = None
self.log.debug("ran out of seed words")
# first, see if we have an empty response and a target word.
# if so, work backwards, otherwise forwards
if gen_words[-1] == self.start2 and seed_word is not None:
# work backwards
working_backwards = []
key_hits = self._retrieve_k2_for_value(seed_word, context_id)
if len(key_hits):
found_word = seed_word
working_backwards.append(found_word)
self.log.debug("started working backwards with: {0:s}".format(found_word))
2012-07-28 13:32:58 -05:00
self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards)))
# now work backwards until we randomly bump into a start
while True:
key_hits = self._retrieve_k2_for_value(working_backwards[0], context_id)
if self.start2 in key_hits and len(working_backwards) > 2:
self.log.debug("working backwards forced start, finishing")
gen_words += working_backwards
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
break
elif len(key_hits) == 0:
self.log.debug("no key_hits, finishing")
gen_words += working_backwards
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
break
elif len(filter(lambda a: a!= self.start2, key_hits)) == 0:
self.log.debug("only start2 in key_hits, finishing")
gen_words += working_backwards
2012-07-28 13:32:58 -05:00
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
break
else:
found_word = random.choice(filter(lambda a: a != self.start2, key_hits))
working_backwards.insert(0, found_word)
2012-07-28 13:32:58 -05:00
self.log.debug("added '{0:s}' to working_backwards".format(found_word))
self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards)))
hit_word = seed_word
else:
# work forwards
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context_id)
# use the chain that includes the target word, if it is found
if seed_word is not None and seed_word in key_hits:
hit_word = seed_word
self.log.debug("added seed word '{0:s}' to gen_words".format(hit_word))
else:
hit_word = self._get_suitable_word_from_choices(key_hits, gen_words, min_size)
self.log.debug("added random word '{0:s}' to gen_words".format(hit_word))
# from either method, append result
gen_words.append(hit_word)
self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
# tack a new chain onto the list and resume if we're too short
if gen_words[-1] == self.stop and len(gen_words) < min_size + 3:
self.log.debug("starting a new chain on end of old one")
# chop off the end text, if it was the keyword indicating an end of chain
if gen_words[-1] == self.stop:
gen_words = gen_words[:-1]
# monkey with the end word to make it more like an actual sentence end
sentence_end = gen_words[-1]
eos_punctuation = ['!', '?', ',', '.']
if sentence_end[-1] not in eos_punctuation:
random.shuffle(eos_punctuation)
gen_words[-1] = sentence_end + eos_punctuation.pop()
2012-07-29 15:44:43 -05:00
self.log.debug("monkeyed with end of sentence, it's now: {0:s}".format(gen_words[-1]))
new_chain_words = []
# new word 1
key_hits = self._retrieve_chains_for_key(self.start1, self.start2, context_id)
new_chain_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
# the database is probably empty if we got a stop from this
2012-09-17 16:23:42 -05:00
if new_chain_words[0] == self.stop:
break
# new word 2
key_hits = self._retrieve_chains_for_key(self.start2, new_chain_words[0], context_id)
new_chain_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
2012-09-17 16:23:42 -05:00
if new_chain_words[1] != self.stop:
# two valid words, try for a third and check for "foo:"
# new word 3 (which we may need below)
key_hits = self._retrieve_chains_for_key(new_chain_words[0], new_chain_words[1], context_id)
new_chain_words.append(self._get_suitable_word_from_choices(key_hits, gen_words, min_size))
# if the first word is "foo:", start with the second
addressing_suffixes = [':', ',']
if new_chain_words[0][-1] in addressing_suffixes:
gen_words += new_chain_words[1:]
self.log.debug("appending following anti-address " \
"new_chain_words: {0:s}".format(new_chain_words[1:]))
2012-09-17 16:23:42 -05:00
elif new_chain_words[2] == self.stop:
gen_words += new_chain_words[0:1]
self.log.debug("appending following anti-stop " \
"new_chain_words: {0:s}".format(new_chain_words[0:1]))
else:
gen_words += new_chain_words[0:]
self.log.debug("appending following extended " \
"new_chain_words: {0:s}".format(new_chain_words[0:]))
else:
# well, we got one word out of this... let's go with it
# and let the loop check if we need more
self.log.debug("appending following short new_chain_words: {0:s}".format(new_chain_words))
gen_words += new_chain_words
# chop off the seed data at the start
gen_words = gen_words[2:]
# chop off the end text, if it was the keyword indicating an end of chain
if gen_words[-1] == self.stop:
gen_words = gen_words[:-1]
2012-07-15 01:11:21 -05:00
return ' '.join(gen_words)
def _get_suitable_word_from_choices(self, key_hits, gen_words, min_size):
"""Given an existing set of words, and key hits, pick one."""
# first, if we're not yet at min_size, pick a non-stop word if it exists
# else, if there were no results, append stop
# otherwise, pick a random result
if len(gen_words) < min_size + 2 and len(filter(lambda a: a != self.stop, key_hits)) > 0:
found_word = random.choice(filter(lambda a: a != self.stop, key_hits))
return found_word
2012-07-29 17:46:14 -05:00
elif len(key_hits) == 0:
return self.stop
else:
found_word = random.choice(key_hits)
return found_word
def _retrieve_chains_for_key(self, k1, k2, context_id):
    """Get the value(s) for a given key (a pair of strings).

    For the start key (start1, start2) at most one quasi-random value is
    returned rather than every line start in the context.

    Returns a list of words, empty when the key has no values.
    Raises mdb.Error (after logging it) if the query fails.
    """

    wants_start = (k1 == self.start1 and k2 == self.start2)
    rand_id = None
    if wants_start:
        # hack. get a quasi-random start from the database, in
        # a faster fashion than selecting all starts
        max_id = self._get_max_chain_id()
        if max_id is None:
            # the chain table is empty, so there are no starts to pick from
            return []
        rand_id = random.randint(1, max_id)

    db = self.get_db()
    # cur stays None if cursor creation itself fails, so the finally
    # clause cannot raise a NameError that hides the real error
    cur = None
    try:
        cur = db.cursor(mdb.cursors.DictCursor)
        if wants_start:
            query = ('SELECT v FROM markov_chain WHERE k1 = %s AND k2 = %s AND '
                     '(context_id = %s) AND id >= %s LIMIT 1')
            cur.execute(query, (k1, k2, context_id, rand_id))
            results = cur.fetchall()
            if not results:
                # the random id landed past the last start of this
                # context; take the nearest start below it instead of
                # reporting that there are no starts
                query = ('SELECT v FROM markov_chain WHERE k1 = %s AND k2 = %s AND '
                         '(context_id = %s) AND id < %s ORDER BY id DESC LIMIT 1')
                cur.execute(query, (k1, k2, context_id, rand_id))
                results = cur.fetchall()
        else:
            query = ('SELECT v FROM markov_chain WHERE k1 = %s AND k2 = %s AND '
                     '(context_id = %s)')
            cur.execute(query, (k1, k2, context_id))
            results = cur.fetchall()

        return [result['v'] for result in results]
    except mdb.Error as e:
        self.log.error("database error in _retrieve_chains_for_key")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
def _retrieve_k2_for_value(self, v, context_id):
    """Get the k2 key(s) of every chain in the context whose value is v.

    Since k2 is the word stored directly before the value, this yields
    the words seen immediately before v (possibly the start2 marker).

    Returns a list of words, empty when v is unknown in the context.
    Raises mdb.Error (after logging it) if the query fails.
    """

    db = self.get_db()
    # cur stays None if cursor creation itself fails, so the finally
    # clause cannot raise a NameError that hides the real error
    cur = None
    try:
        query = 'SELECT k2 FROM markov_chain WHERE v = %s AND context_id = %s'
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query, (v, context_id))
        return [result['k2'] for result in cur.fetchall()]
    except mdb.Error as e:
        self.log.error("database error in _retrieve_k2_for_value")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
def _get_chatter_targets(self):
    """Get all possible chatter targets.

    Returns the fetched rows, each a dict with 'target' and 'chance' keys.
    Raises mdb.Error (after logging it) if the query fails.
    """

    # need to create our own db object, since this is likely going to be in a new thread
    db = self.get_db()
    # cur stays None if cursor creation itself fails, so the finally
    # clause cannot raise a NameError that hides the real error
    cur = None
    try:
        query = 'SELECT target, chance FROM markov_chatter_target'
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query)
        results = cur.fetchall()
        return results
    except mdb.Error as e:
        self.log.error("database error in _get_chatter_targets")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
def _get_one_chatter_target(self):
"""Select one random chatter target."""
targets = self._get_chatter_targets()
if targets:
return targets[random.randint(0, len(targets)-1)]
def _get_max_chain_id(self):
    """Get the highest id in the chain table.

    Returns the id, or None when the table has no rows.
    Raises mdb.Error (after logging it) if the query fails.
    """

    db = self.get_db()
    # cur stays None if cursor creation itself fails, so the finally
    # clause cannot raise a NameError that hides the real error
    cur = None
    try:
        query = '''
            SELECT id FROM markov_chain ORDER BY id DESC LIMIT 1
            '''
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query)
        result = cur.fetchone()
        if result:
            return result['id']
        return None
    except mdb.Error as e:
        self.log.error("database error in _get_max_chain_id")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
def _get_context_id_for_target(self, target):
    """Get the context ID for the desired/input target.

    When the target is not mapped to a context yet, one is created for it
    and the lookup is repeated.

    Raises mdb.Error (after logging it) if the query fails.
    """

    db = self.get_db()
    # cur stays None if cursor creation itself fails, so the finally
    # clause cannot raise a NameError that hides the real error
    cur = None
    try:
        query = '''
            SELECT mc.id FROM markov_context mc
            INNER JOIN markov_target_to_context_map mt
                ON mt.context_id = mc.id
            WHERE mt.target = %s
            '''
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query, (target,))
        result = cur.fetchone()
    except mdb.Error as e:
        self.log.error("database error in _get_context_id_for_target")
        self.log.exception(e)
        raise
    finally:
        # the cursor is closed before its connection, never after
        if cur is not None:
            cur.close()
        # NOTE(review): this is the only method in the class that closes
        # what get_db() returns; confirm whether get_db() hands out a
        # fresh connection per call or a shared one
        db.close()

    if result:
        return result['id']

    # auto-generate a context to keep things private
    self._add_context_for_target(target)
    return self._get_context_id_for_target(target)
def _add_context_for_target(self, target):
    """Create a new context for the desired/input target.

    Inserts a context named after the target, maps the target to it, and
    commits both inserts together.

    Returns the id of the new context.
    Raises mdb.Error (after logging it and rolling back) if an insert fails.
    """

    db = self.get_db()
    # cur stays None if cursor creation itself fails, so the finally
    # clause cannot raise a NameError that hides the real error
    cur = None
    try:
        cur = db.cursor(mdb.cursors.DictCursor)

        statement = 'INSERT INTO markov_context (context) VALUES (%s)'
        cur.execute(statement, (target,))
        # use the id generated by the insert above; looking the row up
        # again by context name breaks when two contexts share a name
        context_id = cur.lastrowid

        statement = '''
            INSERT INTO markov_target_to_context_map (target, context_id)
            VALUES (%s, %s)
            '''
        cur.execute(statement, (target, context_id))

        db.commit()
        return context_id
    except mdb.Error as e:
        db.rollback()
        self.log.error("database error in _add_context_for_target")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
# vi:tabstop=4:expandtab:autoindent
# kate: indent-mode python;indent-width 4;replace-tabs on;