dr.botzo/modules/Markov.py
Brian S. Stephan 26bc8bec34 Markov: rebuild the tables, use the context stuff in a better fashion this time
the module will drop your old tables if you have them, so if there's data there,
be sure to back them up and figure out some migration strategy (probably annoying
and probably having to script it).

the big change is that each line is associated to a context now, and channels
are also associated to contexts. this should allow for a better partitioning
of multiple brains, and changing which channels point to which brain.

also caught in the wake is some additional logging verbosity, and a change to
no longer lower() everything learned.

the script to dump a file into the database has also been updated with the above
changes
2012-02-28 23:23:14 -06:00

540 lines
21 KiB
Python

"""
Markov - Chatterbot via Markov chains for IRC
Copyright (C) 2010 Brian S. Stephan
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import cPickle
from datetime import datetime
import os
import random
import re
import sqlite3
import sys
import thread
import time
from dateutil.parser import *
from dateutil.relativedelta import *
from extlib import irclib
from Module import Module
class Markov(Module):
    """
    Create a chatterbot very similar to a MegaHAL, but simpler and
    implemented in pure Python. Proof of concept code from Ape.

    Ape wrote: based on this:
    http://uswaretech.com/blog/2009/06/pseudo-random-text-markov-chains-python/
    and this:
    http://code.activestate.com/recipes/194364-the-markov-chain-algorithm/

    Learned text is stored in SQLite as (k1, k2) -> v word triples in the
    markov_chain table, each row tagged with a context id; replies are built
    by walking those triples at random.
    """
def __init__(self, irc, config, server):
    """Set up chain keywords, command regexes and state, then start the worker thread.

    irc, config and server are handed unchanged to Module.__init__.
    """

    # sentinel words marking the head and the tail of every stored chain ---
    # don't change these once you've created a brain
    self.start1 = '__start1'
    self.start2 = '__start2'
    self.stop = '__stop'

    # regexes for the !markov commands. raw strings, so that \s and \d reach
    # the regex engine as written instead of relying on Python leaving
    # unknown string escapes untouched
    self.trainre = re.compile(r'^!markov\s+train\s+(.*)$')
    self.learnre = re.compile(r'^!markov\s+learn\s+(.*)$')
    self.replyre = re.compile(r'^!markov\s+reply(\s+min=(\d+))?(\s+max=(\d+))?(\s+(.*)$|$)')

    self.shut_up = False
    # (nick, datetime) for every line seen; '.self.said.' marks our own lines
    self.lines_seen = []

    # NOTE(review): everything above is assigned before the base initializer
    # runs, presumably because it triggers db_init()/register_handlers(),
    # which use these attributes --- confirm against Module
    Module.__init__(self, irc, config, server)

    self.next_shut_up_check = 0
    self.next_chatter_check = 0
    self.connection = None

    # background loop for the shut-up and random-chatter checks
    thread.start_new_thread(self.thread_do, ())
def db_init(self):
    """Create (or recreate) the Markov tables when the schema is older than version 9.

    WARNING: this drops the existing markov_* tables, so previously learned
    data is lost. The new version is registered only after the schema change
    has been committed.

    Raises sqlite3.Error if any statement fails; the transaction is rolled
    back first.
    """

    version = self.db_module_registered(self.__class__.__name__)
    if version is None or version < 9:
        db = self.get_db()
        try:
            version = 9

            # recreating the tables, since i need to add some foreign key constraints.
            # (the index names dropped here lack the "_id" infix of the ones
            # created below; SQLite removes a table's indexes along with the
            # table, so the new names never collide)
            db.execute('''DROP INDEX IF EXISTS markov_chain_keys_and_context_index''')
            db.execute('''DROP INDEX IF EXISTS markov_chain_keys_index''')
            db.execute('''DROP INDEX IF EXISTS markov_chain_value_and_context_index''')
            db.execute('''DROP TABLE IF EXISTS markov_chain''')
            db.execute('''DROP TABLE IF EXISTS markov_target_to_context_map''')
            db.execute('''DROP TABLE IF EXISTS markov_chatter_target''')
            db.execute('''DROP TABLE IF EXISTS markov_context''')

            db.execute('''
                CREATE TABLE markov_chatter_target (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    target TEXT NOT NULL,
                    chance INTEGER NOT NULL DEFAULT 99999
                )''')
            db.execute('''
                CREATE TABLE markov_context (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    context TEXT NOT NULL
                )''')
            db.execute('''
                CREATE TABLE markov_target_to_context_map (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    target TEXT NOT NULL,
                    context_id INTEGER NOT NULL,
                    FOREIGN KEY(context_id) REFERENCES markov_context(id)
                )''')
            db.execute('''
                CREATE TABLE markov_chain (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    k1 TEXT NOT NULL,
                    k2 TEXT NOT NULL,
                    v TEXT NOT NULL,
                    context_id INTEGER DEFAULT NULL,
                    FOREIGN KEY(context_id) REFERENCES markov_context(id)
                )''')
            db.execute('''
                CREATE INDEX markov_chain_keys_and_context_id_index
                ON markov_chain (k1, k2, context_id)''')
            db.execute('''
                CREATE INDEX markov_chain_value_and_context_id_index
                ON markov_chain (v, context_id)''')

            db.commit()
        except sqlite3.Error as e:
            db.rollback()
            print('sqlite error: ' + str(e))
            raise
        finally:
            # the handle is released exactly once, whatever happened above
            db.close()

        # these run on their own connections, after ours is closed. keeping
        # them outside the try means a failure here surfaces as itself rather
        # than as a rollback attempt on an already-closed connection
        self.db_register_module_version(self.__class__.__name__, version)
        self._learn_line('', '')
def register_handlers(self):
    """Hook the command handler and the learner onto public and private messages."""

    message_events = ('pubmsg', 'privmsg')

    # the command/reply handler is registered with this module's priority ...
    for event_type in message_events:
        self.server.add_global_handler(event_type, self.on_pub_or_privmsg, self.priority())

    # ... while the learner is registered without an explicit priority
    for event_type in message_events:
        self.server.add_global_handler(event_type, self.learn_from_irc_event)
def unregister_handlers(self):
    """Remove the handlers installed for public and private messages."""

    for handler in (self.on_pub_or_privmsg, self.learn_from_irc_event):
        for event_type in ('pubmsg', 'privmsg'):
            self.server.remove_global_handler(event_type, handler)
def learn_from_irc_event(self, connection, event):
    """Learn from an IRC message, unless it is one of the !markov commands.

    Also records who spoke and when in self.lines_seen, and remembers the
    connection so the background thread can send messages later.
    """

    what = ''.join(event.arguments()[0])

    # strip a leading "mynick: " / "mynick, " so we don't learn our own name
    # as the first word. the nick is escaped because IRC nicks may contain
    # regex metacharacters such as [ ] \ | ^ { }
    my_nick = connection.get_nickname()
    what = re.sub('^' + re.escape(my_nick) + r'[:,]\s+', '', what)

    target = event.target()
    nick = irclib.nm_to_n(event.source())

    self.lines_seen.append((nick, datetime.now()))
    self.connection = connection

    # don't learn from commands
    if self.trainre.search(what) or self.learnre.search(what) or self.replyre.search(what):
        return

    self._learn_line(what, target)
def do(self, connection, event, nick, userhost, what, admin_unlocked):
    """Handle commands and inputs.

    Dispatches the !markov train/learn/reply commands; otherwise, if the
    bot's nick appears in the text (case-insensitively), replies with a
    generated line. Returns whatever self.reply() returns, or None when
    there is nothing to say or the bot is currently shut up.
    """

    target = event.target()

    if self.trainre.search(what):
        return self.reply(connection, event, self.markov_train(connection, event, nick,
                                                               userhost, what, admin_unlocked))
    elif self.learnre.search(what):
        return self.reply(connection, event, self.markov_learn(connection, event, nick,
                                                               userhost, what, admin_unlocked))
    elif self.replyre.search(what) and not self.shut_up:
        return self.reply(connection, event, self.markov_reply(connection, event, nick,
                                                               userhost, what, admin_unlocked))

    if self.shut_up:
        return None

    # not a command, so see if i'm being mentioned. the nick is escaped
    # because IRC nicks may contain regex metacharacters ([ ] \ | ^ { }),
    # which would otherwise change what matches or fail to compile
    my_nick = re.escape(connection.get_nickname())
    if re.search(my_nick, what, re.IGNORECASE) is None:
        return None

    # being addressed requires the exact-case nick at the start of the line
    addressed = re.match('^' + my_nick + r'[:,]\s+(.*)', what)

    self.lines_seen.append(('.self.said.', datetime.now()))
    if addressed:
        # i was addressed directly, so respond, addressing the speaker
        return self.reply(connection, event, '{0:s}: {1:s}'.format(nick,
            self._generate_line(target, line=addressed.group(1))))

    # i wasn't addressed directly, so just respond
    return self.reply(connection, event, '{0:s}'.format(self._generate_line(target, line=what)))
def markov_train(self, connection, event, nick, userhost, what, admin_unlocked):
    """Learn lines from a file. Good for initializing a brain.

    Only acts when the command matches and admin_unlocked is true; otherwise
    returns None. Returns a status string naming the file on success or when
    the file cannot be read.
    """

    match = self.trainre.search(what)
    if match and admin_unlocked:
        filename = match.group(1)
        # _learn_line needs a target to pick the context (and learns nothing
        # without one), so lines are learned for the channel or user the
        # command was issued in, the same way markov_learn does it
        target = event.target()

        try:
            # the with-block closes the file even if learning fails midway
            with open(filename, 'r') as training_file:
                for line in training_file:
                    self._learn_line(line, target)

            return 'Learned from \'{0:s}\'.'.format(filename)
        except IOError:
            return 'No such file \'{0:s}\'.'.format(filename)
def markov_learn(self, connection, event, nick, userhost, what, admin_unlocked):
    """Learn the text given to the learn command and hand it back.

    Returns None when the input is not a learn command.
    """

    target = event.target()

    match = self.learnre.search(what)
    if not match:
        return None

    learned = match.group(1)
    self._learn_line(learned, target)

    # the learned text is the return value, for weird chaining purposes
    return learned
def markov_reply(self, connection, event, nick, userhost, what, admin_unlocked):
    """Generate a line for the reply command, without learning anything.

    The command accepts optional min=N and max=N sizes (defaulting to 15 and
    100) and optional seed text. Returns None when the input is not a reply
    command.
    """

    target = event.target()

    match = self.replyre.search(what)
    if not match:
        return None

    min_size = int(match.group(2)) if match.group(2) else 15
    max_size = int(match.group(4)) if match.group(4) else 100

    # count this as a line the bot said, for the shut-up check
    self.lines_seen.append(('.self.said.', datetime.now()))

    # group 5 is empty when the command carries no seed text
    if match.group(5) != '':
        seed = match.group(6)
        return self._generate_line(target, line=seed, min_size=min_size, max_size=max_size)

    return self._generate_line(target, min_size=min_size, max_size=max_size)
def thread_do(self):
    """Run the periodic checks once a second until the module is shut down."""

    while True:
        if self.is_shutdown:
            break

        self._do_shut_up_checks()
        self._do_random_chatter_check()
        time.sleep(1)
def _do_random_chatter_check(self):
    """Every ten minutes, maybe say a generated line to each chatter target.

    A target with chance N is spoken to with probability 1/N per check;
    targets whose chance is zero or negative are never spoken to.
    """

    # the very first call only schedules the first real check, which gives
    # the bot time to join channels before it might chatter
    if self.next_chatter_check == 0:
        self.next_chatter_check = time.time() + 600

    if not (self.next_chatter_check < time.time()):
        return

    self.next_chatter_check = time.time() + 600

    if self.connection is None:
        # i haven't seen any text yet...
        return

    for entry in self._get_chatter_targets():
        chance = entry['chance']
        if chance <= 0:
            continue

        if random.randint(1, chance) == 1:
            self.sendmsg(self.connection, entry['target'], self._generate_line(entry['target']))
def _do_shut_up_checks(self):
    """Check to see if we've been talking too much, and shut up if so.

    Runs at most once every 30 seconds. Each run clears self.shut_up, then
    sets it again (and announces it to every chatter target) if the bot said
    8 or more lines in the last 30 seconds. Entries older than that window
    are dropped from self.lines_seen so the list does not grow forever.
    """

    if not (self.next_shut_up_check < time.time()):
        return

    self.shut_up = False
    self.next_shut_up_check = time.time() + 30

    now = datetime.now()

    def said_recently(then):
        # true when `then` lies within the 30 seconds before `now`
        age = now - then
        return age.days == 0 and age.seconds < 30

    # entries are appended in time order, so everything stale sits at the
    # front. it is removed in place with a single slice delete so that
    # lines appended meanwhile by the IRC handlers are not lost
    seen = list(self.lines_seen)
    stale = 0
    for (_who, then) in seen:
        if said_recently(then):
            break
        stale += 1
    if stale:
        del self.lines_seen[:stale]

    lines_i_said = sum(1 for (who, then) in seen[stale:]
                       if who == '.self.said.' and said_recently(then))
    if lines_i_said >= 8:
        self.shut_up = True
        # no connection has been recorded until the first IRC line is seen
        if self.connection is not None:
            for t in self._get_chatter_targets():
                self.sendmsg(self.connection, t['target'],
                             'shutting up for 30 seconds due to last 30 seconds of activity')
def _learn_line(self, line, target):
    """Create Markov chains from the provided line.

    Every word is stored as the value following the two words before it,
    with the start keywords ahead of the first word and the stop keyword
    after the last, all tagged with the context mapped to `target`.

    Returns `line` unchanged when it contains no words, otherwise None.
    Nothing is stored when `target` is empty. Raises sqlite3.Error if the
    insert fails; the transaction is rolled back first.
    """

    context_id = self._get_context_id_for_target(target)

    # if there's no target, this is probably a sub-command. don't learn it
    if not target:
        return None

    words = line.split()
    if not words:
        return line

    def as_text(word):
        # byte strings are decoded as UTF-8; text is passed through as it is
        # (calling .decode() on text would force an implicit ASCII round-trip)
        if isinstance(word, bytes):
            return word.decode('utf-8', 'replace')
        return word

    # head of the chain, then the words, then the end marker
    chain = [as_text(self.start1), as_text(self.start2)]
    chain.extend(as_text(word) for word in words)
    chain.append(self.stop)

    # one (k1, k2, v) row per sliding window of three
    rows = [(k1, k2, v, context_id) for (k1, k2, v) in zip(chain, chain[1:], chain[2:])]

    statement = 'INSERT INTO markov_chain (k1, k2, v, context_id) VALUES (?, ?, ?, ?)'

    # acquired outside the try so that a failure to open the database is
    # reported as itself, not as an unbound `db` in the error handler
    db = self.get_db()
    try:
        cur = db.cursor()
        cur.executemany(statement, rows)
        db.commit()
    except sqlite3.Error as e:
        db.rollback()
        print("sqlite error in Markov._learn_line: " + str(e))
        raise
    finally:
        db.close()
def _generate_line(self, target, line='', min_size=15, max_size=100):
    """Reply to a line, using some text in the line as a point in the chain.

    target selects the context to read chains from. line is optional seed
    text: one of its words is picked at random and, if the brain knows it,
    the reply is built backwards from that word to a start of line and then
    forwards. min_size is the length below which the walk avoids ending the
    line when it has a choice; max_size caps the number of words walked
    forwards.

    Returns the generated words joined by spaces, UTF-8 encoded.
    Raises Exception if max_size <= 3 or min_size > 20.
    """

    # if the limit is too low, there's nothing to do
    if max_size <= 3:
        raise Exception("max_size is too small: %d" % max_size)

    # if the min is too large, abort
    if min_size > 20:
        raise Exception("min_size is too large: %d" % min_size)

    # split the input into words, decoding byte strings the same way learned
    # words are decoded before being stored, so they can match stored text
    # and be handed to sqlite. an empty or whitespace-only line yields no words
    words = []
    for word in (line or '').split():
        if isinstance(word, bytes):
            word = word.decode('utf-8', 'replace')
        words.append(word)

    # get a random word from the input
    target_word = random.choice(words) if words else ''

    context_id = self._get_context_id_for_target(target)

    # start with an empty chain, and work from there
    gen_words = [self.start1, self.start2]

    # walk a chain, randomly, building the list of words
    while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
        # first, see if we have an empty response and a target word.
        # we'll just pick a word and work backwards
        if gen_words[-1] == self.start2 and target_word != '':
            if self._retrieve_k2_for_value(target_word, context_id):
                working_backwards = [target_word]
                # pick a new word to aim for
                target_word = random.choice(words)

                # work backwards until we randomly bump into a start
                while True:
                    key_hits = self._retrieve_k2_for_value(working_backwards[0], context_id)
                    if target_word in key_hits:
                        found_word = target_word
                        # one hit is enough, stop aiming for input words
                        target_word = ''
                    else:
                        candidates = [hit for hit in key_hits if hit != self.stop]
                        if candidates:
                            found_word = random.choice(candidates)
                        else:
                            # nothing precedes this word in the brain, so
                            # treat it as the start of the line
                            found_word = self.start2

                    if found_word == self.start2 or len(working_backwards) >= max_size + 2:
                        gen_words = gen_words + working_backwards
                        break
                    working_backwards.insert(0, found_word)

        key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context_id)

        # use the chain that includes the target word, if it is found
        if target_word != '' and target_word in key_hits:
            gen_words.append(target_word)
            # pick a new word to aim for
            target_word = random.choice(words)
        else:
            continuations = [hit for hit in key_hits if hit != self.stop]
            if len(gen_words) < min_size and continuations:
                # still short: keep going if there is any way to
                gen_words.append(random.choice(continuations))
            elif not key_hits:
                gen_words.append(self.stop)
            else:
                gen_words.append(random.choice(key_hits))

    # chop off the seed data at the start
    gen_words = gen_words[2:]

    # chop off the end text, if it was the keyword indicating an end of chain
    if gen_words and gen_words[-1] == self.stop:
        gen_words = gen_words[:-1]

    return ' '.join(gen_words).encode('utf-8', 'ignore')
def _retrieve_chains_for_key(self, k1, k2, context_id):
    """Get the value(s) for a given key (a pair of strings).

    Rows match when their context is `context_id` or NULL. For the start
    key at most one value is returned, picked quasi-randomly (see below);
    the result is empty when the brain is empty or the pick finds no row.

    Raises sqlite3.Error if the query fails.
    """

    query = ('SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND '
             '(context_id = ? OR context_id IS NULL)')
    params = (k1, k2, context_id)

    if k1 == self.start1 and k2 == self.start2:
        # hack. get a quasi-random start from the database, in
        # a faster fashion than selecting all starts: take one
        # start row at or after a random id
        max_id = self._get_max_chain_id()
        if max_id is None:
            # nothing has been learned yet, so there is nothing to pick from
            return []
        query += ' AND id >= ? LIMIT 1'
        params += (random.randint(1, max_id),)

    # acquired outside the try so that a failure to open the database is
    # reported as itself, not as an unbound `db` in the error handler
    db = self.get_db()
    try:
        cursor = db.execute(query, params)
        return [row['v'] for row in cursor.fetchall()]
    except sqlite3.Error as e:
        print('sqlite error in Markov._retrieve_chains_for_key: ' + str(e))
        raise
    finally:
        db.close()
def _retrieve_k2_for_value(self, v, context_id):
    """Get the word(s) that directly precede a given value.

    Returns the k2 of every chain row whose value is `v` and whose context
    is `context_id` or NULL, one entry per row (so duplicates are possible).

    Raises sqlite3.Error if the query fails.
    """

    query = 'SELECT k2 FROM markov_chain WHERE v = ? AND (context_id = ? OR context_id IS NULL)'

    # acquired outside the try so that a failure to open the database is
    # reported as itself, not as an unbound `db` in the error handler
    db = self.get_db()
    try:
        cursor = db.execute(query, (v, context_id))
        return [row['k2'] for row in cursor.fetchall()]
    except sqlite3.Error as e:
        print('sqlite error in Markov._retrieve_k2_for_value: ' + str(e))
        raise
    finally:
        db.close()
def _get_chatter_targets(self):
    """Get all possible chatter targets.

    Returns every row of markov_chatter_target, each carrying the 'target'
    and 'chance' columns. Raises sqlite3.Error if the query fails.
    """

    # need to create our own db object, since this is likely going to be in a new thread
    db = self.get_db()
    try:
        cursor = db.execute('SELECT target, chance FROM markov_chatter_target')
        return cursor.fetchall()
    except sqlite3.Error as e:
        print('sqlite error in Markov._get_chatter_targets: ' + str(e))
        raise
    finally:
        # also closes on success; the fetched rows stay usable afterwards
        db.close()
def _get_one_chatter_target(self):
    """Pick one chatter target at random; returns None when there are none."""

    candidates = self._get_chatter_targets()
    if not candidates:
        return None

    index = random.randint(0, len(candidates) - 1)
    return candidates[index]
def _get_max_chain_id(self):
    """Get the highest id in the chain table, or None if the table is empty.

    Raises sqlite3.Error if the query fails.
    """

    query = '''
        SELECT id FROM markov_chain ORDER BY id DESC LIMIT 1
    '''

    # acquired outside the try so that a failure to open the database is
    # reported as itself, not as an unbound `db` in the error handler
    db = self.get_db()
    try:
        result = db.execute(query).fetchone()
    except sqlite3.Error as e:
        print('sqlite error in Markov._get_max_chain_id: ' + str(e))
        raise
    finally:
        db.close()

    if result:
        return result['id']
    return None
def _get_context_id_for_target(self, target):
    """Get the context ID for the desired/input target.

    Returns the id of a context mapped to `target` in
    markov_target_to_context_map (the first row found, if several are
    mapped), or None when the target has no mapping.

    Raises sqlite3.Error if the query fails.
    """

    query = '''
        SELECT mc.id FROM markov_context mc
        INNER JOIN markov_target_to_context_map mt
            ON mt.context_id = mc.id
        WHERE mt.target = ?
    '''

    # acquired outside the try so that a failure to open the database is
    # reported as itself, not as an unbound `db` in the error handler
    db = self.get_db()
    try:
        result = db.execute(query, (target,)).fetchone()
    except sqlite3.Error as e:
        print('sqlite error in Markov._get_context_id_for_target: ' + str(e))
        raise
    finally:
        db.close()

    if result:
        return result['id']
    return None
# vi:tabstop=4:expandtab:autoindent
# kate: indent-mode python;indent-width 4;replace-tabs on;