this should make the bot wait longer for table locks, assuming i read the docs right
517 lines
19 KiB
Python
517 lines
19 KiB
Python
"""
|
|
Markov - Chatterbot via Markov chains for IRC
|
|
Copyright (C) 2010 Brian S. Stephan
|
|
|
|
This program is free software: you can redistribute it and/or modify
|
|
it under the terms of the GNU General Public License as published by
|
|
the Free Software Foundation, either version 3 of the License, or
|
|
(at your option) any later version.
|
|
|
|
This program is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
GNU General Public License for more details.
|
|
|
|
You should have received a copy of the GNU General Public License
|
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
"""
|
|
|
|
import cPickle
|
|
from datetime import datetime
|
|
import os
|
|
import random
|
|
import re
|
|
import sqlite3
|
|
import sys
|
|
import thread
|
|
import time
|
|
|
|
from dateutil.parser import *
|
|
from dateutil.relativedelta import *
|
|
from extlib import irclib
|
|
|
|
from Module import Module
|
|
|
|
class Markov(Module):
|
|
|
|
"""
|
|
Create a chatterbot very similar to a MegaHAL, but simpler and
|
|
implemented in pure Python. Proof of concept code from Ape.
|
|
|
|
Ape wrote: based on this:
|
|
http://uswaretech.com/blog/2009/06/pseudo-random-text-markov-chains-python/
|
|
and this:
|
|
http://code.activestate.com/recipes/194364-the-markov-chain-algorithm/
|
|
"""
|
|
|
|
def __init__(self, irc, config, server):
    """Set up chain sentinels, command regexes and state, then start the worker.

    irc, config and server are handed unchanged to Module.__init__.
    A background thread running thread_do() is started as the last step.
    """

    # sentinel tokens marking the head and tail of every chain --- don't
    # change these once you've created a brain
    self.start1, self.start2, self.stop = '__start1', '__start2', '__stop'

    # command matchers, used both to dispatch commands and to avoid
    # learning from them
    self.trainre = re.compile(r'^!markov\s+train\s+(.*)$')
    self.learnre = re.compile(r'^!markov\s+learn\s+(.*)$')
    self.replyre = re.compile(r'^!markov\s+reply(\s+min=(\d+))?(\s+max=(\d+))?(\s+(.*)$|$)')

    # (nick, datetime) records of recent lines, consumed by the shut-up check
    self.lines_seen = []
    self.shut_up = False

    # NOTE(review): the base initializer runs only after the attributes
    # above exist --- presumably it may call back into this class; confirm
    # against Module before reordering.
    Module.__init__(self, irc, config, server)

    self.connection = None
    self.next_chatter_check = 0
    self.next_shut_up_check = 0
    thread.start_new_thread(self.thread_do, ())
|
|
|
|
def db_init(self):
    """Create or upgrade the Markov tables, one schema version at a time.

    The currently registered version is read with db_module_registered();
    every migration newer than that is applied in order, ending at
    version 7. Each migration runs in its own transaction on its own
    connection.

    Raises sqlite3.Error (after rolling back and printing the error) if
    any migration statement fails.
    """

    module_name = self.__class__.__name__
    bump_version = 'UPDATE drbotzo_modules SET version = ? WHERE module = ?'

    def migrate(statements):
        """Run (sql, params) pairs in one transaction on a fresh connection."""
        db = self.get_db()
        try:
            for sql, params in statements:
                db.execute(sql, params)
            db.commit()
        except sqlite3.Error as e:
            db.rollback()
            print('sqlite error: ' + str(e))
            raise
        finally:
            # close exactly once, after any rollback has happened
            db.close()

    version = self.db_module_registered(module_name)

    if version is None:
        # fresh install: create the chain table and register version 1.
        # (the old code then called self._learn_line('') without the
        # required target argument, which raised TypeError here; with no
        # words and no context that call could never learn anything, so
        # it has been dropped.)
        migrate([
            ('''
                CREATE TABLE markov_chain (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    k1 TEXT NOT NULL,
                    k2 TEXT NOT NULL,
                    v TEXT NOT NULL
                )''', ()),
            ('INSERT INTO drbotzo_modules VALUES (?,?)', (module_name, 1)),
        ])
        version = 1

    if version < 2:
        # add per-target context support
        migrate([
            ('''
                ALTER TABLE markov_chain
                ADD COLUMN context TEXT DEFAULT NULL''', ()),
            ('''
                CREATE TABLE markov_context (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    context TEXT NOT NULL
                )''', ()),
            ('''
                CREATE TABLE markov_target_to_context_map (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    target TEXT NOT NULL,
                    context_id INTEGER NOT NULL,
                    FOREIGN KEY(context_id) REFERENCES markov_context(id)
                )''', ()),
            (bump_version, (2, module_name)),
        ])
        version = 2

    if version < 3:
        migrate([
            ('''
                CREATE INDEX markov_chain_keys_index
                ON markov_chain (k1, k2)''', ()),
            (bump_version, (3, module_name)),
        ])
        version = 3

    if version < 4:
        # version 4 carries no schema change, only the version bump
        migrate([
            (bump_version, (4, module_name)),
        ])
        version = 4

    # from version 5 on, the version is recorded through
    # db_register_module_version() after the migration has been committed
    if version < 5:
        migrate([
            ('''
                CREATE TABLE markov_chatter_target (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    target TEXT NOT NULL
                )''', ()),
        ])
        version = 5
        self.db_register_module_version(module_name, version)

    if version < 6:
        migrate([
            ('''
                CREATE INDEX markov_chain_keys_and_context_index
                ON markov_chain (k1, k2, context)''', ()),
        ])
        version = 6
        self.db_register_module_version(module_name, version)

    if version < 7:
        migrate([
            ('''
                ALTER TABLE markov_chatter_target ADD COLUMN chance INTEGER NOT NULL DEFAULT 99999''', ()),
        ])
        version = 7
        self.db_register_module_version(module_name, version)
|
|
|
|
def register_handlers(self):
    """Hook the command handler and the learner into pubmsg and privmsg."""

    message_events = ('pubmsg', 'privmsg')

    # the command/reply handler is registered with this module's priority
    for event_type in message_events:
        self.server.add_global_handler(event_type, self.on_pub_or_privmsg, self.priority())

    # the learner is registered without an explicit priority
    for event_type in message_events:
        self.server.add_global_handler(event_type, self.learn_from_irc_event)
|
|
|
|
def unregister_handlers(self):
    """Remove every handler that register_handlers() installed."""

    message_events = ('pubmsg', 'privmsg')

    for event_type in message_events:
        self.server.remove_global_handler(event_type, self.on_pub_or_privmsg)

    for event_type in message_events:
        self.server.remove_global_handler(event_type, self.learn_from_irc_event)
|
|
|
|
def learn_from_irc_event(self, connection, event):
    """Learn from a pubmsg/privmsg event, unless it is a !markov command.

    A leading "<my nick>:" or "<my nick>," address is stripped before
    learning. Every event is also recorded in self.lines_seen (for the
    shut-up check), and the connection is remembered so the background
    thread can send messages later.
    """

    what = ''.join(event.arguments()[0])

    # the nick is literal text, not a pattern: escape it so characters
    # such as '.', '[', '|' or '\' in a nick can neither match unrelated
    # text nor make the regex invalid
    my_nick = re.escape(connection.get_nickname())
    what = re.sub('^' + my_nick + r'[:,]\s+', '', what)

    target = event.target()
    nick = irclib.nm_to_n(event.source())

    self.lines_seen.append((nick, datetime.now()))
    self.connection = connection

    # don't learn from commands
    if self.trainre.search(what) or self.learnre.search(what) or self.replyre.search(what):
        return

    self._learn_line(what, target)
|
|
|
|
def do(self, connection, event, nick, userhost, what, admin_unlocked):
    """Handle !markov commands, and reply when the bot's nick is mentioned.

    Returns whatever self.reply() returns when a command or a mention was
    handled, and None otherwise (including whenever the bot is currently
    shut up and the input is not a train/learn command).
    """

    target = event.target()

    if self.trainre.search(what):
        return self.reply(connection, event, self.markov_train(connection, event, nick, userhost, what, admin_unlocked))
    elif self.learnre.search(what):
        return self.reply(connection, event, self.markov_learn(connection, event, nick, userhost, what, admin_unlocked))
    elif self.replyre.search(what) and not self.shut_up:
        return self.reply(connection, event, self.markov_reply(connection, event, nick, userhost, what, admin_unlocked))

    if self.shut_up:
        return None

    # not a command, so see if i'm being mentioned. the nick is escaped
    # because it is literal text: regex metacharacters in a nick would
    # otherwise match the wrong text or raise re.error
    my_nick = re.escape(connection.get_nickname())
    if re.search(my_nick, what, re.IGNORECASE) is None:
        return None

    # the mention test above ignores case; the direct-address test below
    # is case-sensitive, as it always was
    addressed = re.match('^' + my_nick + r'[:,]\s+(.*)', what)

    self.lines_seen.append(('.self.said.', datetime.now()))
    if addressed:
        # i was addressed directly, so respond, addressing the speaker
        return self.reply(connection, event, '{0:s}: {1:s}'.format(nick, self._generate_line(target, line=addressed.group(1))))

    # i wasn't addressed directly, so just respond
    return self.reply(connection, event, '{0:s}'.format(self._generate_line(target, line=what)))
|
|
|
|
def markov_train(self, connection, event, nick, userhost, what, admin_unlocked):
    """Learn every line of a file. Good for initializing a brain.

    The file name is taken from the "!markov train <file>" command in
    what. Nothing happens, and None is returned, unless admin_unlocked is
    true. Lines are learned into the context of the target the command
    was issued in, the same way markov_learn() does it.

    Returns a status string: success, or "No such file" on IOError.
    """

    match = self.trainre.search(what)
    if not (match and admin_unlocked):
        return None

    filename = match.group(1)

    # _learn_line() needs a target to learn into; the old call omitted it
    # and raised TypeError on the first line of the file
    target = event.target()

    try:
        # 'with' closes the file even if learning a line raises
        with open(filename, 'r') as training_file:
            for line in training_file:
                self._learn_line(line, target)
    except IOError:
        return 'No such file \'{0:s}\'.'.format(filename)

    return 'Learned from \'{0:s}\'.'.format(filename)
|
|
|
|
def markov_learn(self, connection, event, nick, userhost, what, admin_unlocked):
    """Learn the single line given to "!markov learn <line>".

    Returns the learned text (handy when commands are chained), or None
    when what is not a learn command.
    """

    where = event.target()
    found = self.learnre.search(what)
    if not found:
        return None

    text = found.group(1)
    self._learn_line(text, where)

    # hand back what was learned, for weird chaining purposes
    return text
|
|
|
|
def markov_reply(self, connection, event, nick, userhost, what, admin_unlocked):
    """Generate a reply for "!markov reply [min=N] [max=N] [text]", without learning.

    Returns the generated line, or None when what is not a reply command.
    """

    where = event.target()
    found = self.replyre.search(what)
    if found is None:
        return None

    # size limits default to 15/100 unless min=/max= were supplied
    options = {
        'min_size': int(found.group(2)) if found.group(2) else 15,
        'max_size': int(found.group(4)) if found.group(4) else 100,
    }

    # group 5 is empty only when nothing at all followed the options;
    # otherwise group 6 holds the text to reply to
    if found.group(5) != '':
        options['line'] = found.group(6)

    # count this as a line said by the bot, for the shut-up check
    self.lines_seen.append(('.self.said.', datetime.now()))
    return self._generate_line(where, **options)
|
|
|
|
def thread_do(self):
    """Background loop: run the periodic checks once a second until shutdown."""

    while True:
        if self.is_shutdown:
            break

        self._do_shut_up_checks()
        self._do_random_chatter_check()
        time.sleep(1)
|
|
|
|
def _do_random_chatter_check(self):
    """Every 600 seconds, give each chatter target a 1-in-chance shot at a random line."""

    # first call: only schedule the first real check, so the bot has
    # time to join channels before it can chatter
    if self.next_chatter_check == 0:
        self.next_chatter_check = time.time() + 600

    if self.next_chatter_check >= time.time():
        return

    self.next_chatter_check = time.time() + 600

    if self.connection is None:
        # i haven't seen any text yet, so there is nothing to send on
        return

    for entry in self._get_chatter_targets():
        roll = random.randint(1, entry['chance'])
        if roll != 1:
            continue
        self.sendmsg(self.connection, entry['target'], self._generate_line(entry['target']))
|
|
|
|
def _do_shut_up_checks(self):
    """Check to see if we've been talking too much, and shut up if so.

    Runs at most once every 30 seconds. Each run first clears
    self.shut_up; it is set again when at least 15 lines were seen in the
    last 30 seconds and at least 8 of them were said by the bot itself.
    In that case every chatter target is told about it.
    """

    if self.next_shut_up_check >= time.time():
        return

    self.shut_up = False
    self.next_shut_up_check = time.time() + 30

    def is_recent(then):
        """True when then lies within the last 30 seconds."""
        rdelta = relativedelta(datetime.now(), then)
        return (rdelta.years == 0 and rdelta.months == 0 and rdelta.days == 0
                and rdelta.hours == 0 and rdelta.minutes == 0 and rdelta.seconds <= 29)

    # entries are appended in time order and only the last 30 seconds
    # ever matter, so drop stale ones from the front; otherwise the list
    # grows for as long as the bot runs. only single pop/append calls
    # touch the list, since handlers append to it from another thread
    while self.lines_seen and not is_recent(self.lines_seen[0][1]):
        self.lines_seen.pop(0)

    last_30_sec_lines = [(nick, then) for (nick, then) in list(self.lines_seen) if is_recent(then)]

    if len(last_30_sec_lines) < 15:
        return

    lines_i_said = sum(1 for (nick, then) in last_30_sec_lines if nick == '.self.said.')
    if lines_i_said < 8:
        return

    self.shut_up = True

    # without a connection there is nowhere to announce it
    if self.connection is None:
        return

    for t in self._get_chatter_targets():
        self.sendmsg(self.connection, t['target'], 'shutting up for 30 seconds due to last 30 seconds of activity')
|
|
|
|
def _learn_line(self, line, target=None):
    """Create Markov chains from the provided line.

    line is split on whitespace; every (k1, k2) -> v step of the chain
    start1, start2, word..., stop is inserted into markov_chain, tagged
    with target as its context. Words are stored decoded from UTF-8 and
    lowercased.

    target defaults to None so that callers without a context can still
    call this; without a context nothing is learned.

    Returns line when it contains no words, None otherwise.
    Raises sqlite3.Error (after rollback) if an insert fails.
    """

    context = target

    # if there's no context, this is probably a sub-command. don't learn it
    if not context:
        return None

    words = line.split()
    if len(words) <= 0:
        return line

    def normalize(text):
        """Decode and lowercase a token the way stored keys are kept."""
        return text.decode('utf-8', 'replace').lower()

    # head sentinels + words + tail sentinel; the tail sentinel is only
    # ever stored as a value, so it is kept exactly as configured
    chain = [normalize(self.start1), normalize(self.start2)]
    chain.extend(normalize(word) for word in words)
    chain.append(self.stop)

    # every run of three consecutive tokens is one (k1, k2, v) row
    rows = [(k1, k2, v, context) for (k1, k2, v) in zip(chain, chain[1:], chain[2:])]

    statement = 'INSERT INTO markov_chain (k1, k2, v, context) VALUES (?, ?, ?, ?)'

    # opened outside the try so the error path never touches an
    # unassigned connection
    db = self.get_db()
    try:
        cur = db.cursor()
        cur.executemany(statement, rows)
        db.commit()
    except sqlite3.Error as e:
        db.rollback()
        print("sqlite error: " + str(e))
        raise
    finally:
        db.close()
|
|
|
|
def _generate_line(self, target, line='', min_size=15, max_size=100):
    """Reply to a line, using some text in the line as a point in the chain.

    Walks the chain from the start sentinels, choosing a random follower
    at each step and preferring a randomly picked word from line whenever
    it is among the followers. The stop sentinel is avoided while the
    chain (including the two start sentinels) is shorter than min_size,
    and the walk ends at the stop sentinel or after max_size words.

    target is used as the context for chain lookups.

    Returns the generated words joined by spaces, UTF-8 encoded.
    Raises Exception if max_size <= 3 or min_size > 20.
    """

    # if the limit is too low, there's nothing to do
    if max_size <= 3:
        raise Exception("max_size is too small: %d" % max_size)

    # if the min is too large, abort
    if min_size > 20:
        raise Exception("min_size is too large: %d" % min_size)

    def normalize(word):
        """Lowercase (and decode, if needed) like the stored chain words."""
        if isinstance(word, bytes):
            word = word.decode('utf-8', 'replace')
        return word.lower()

    # candidate words from the input. they are normalized because the
    # chain is stored lowercased: an input word containing capitals could
    # otherwise never match. a whitespace-only line simply yields no
    # candidates instead of crashing on an empty choice
    words = [normalize(word) for word in line.split()]
    target_word = random.choice(words) if words else ''

    context = target

    # start with an empty chain, and work from there
    gen_words = [self.start1, self.start2]

    # walk a chain, randomly, building the list of words
    while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
        key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context)

        # use the chain that includes the target word, if it is found
        if target_word != '' and target_word in key_hits:
            gen_words.append(target_word)
            # pick a new word to steer towards
            target_word = random.choice(words)
            continue

        non_stop_hits = [hit for hit in key_hits if hit != self.stop]
        if len(gen_words) < min_size and non_stop_hits:
            # too short to stop yet, and there is a way to keep going
            gen_words.append(random.choice(non_stop_hits))
        elif not key_hits:
            # dead end
            gen_words.append(self.stop)
        else:
            gen_words.append(random.choice(key_hits))

    # chop off the seed data at the start
    gen_words = gen_words[2:]

    # chop off the end text, if it was the keyword indicating an end of chain
    if gen_words and gen_words[-1] == self.stop:
        gen_words = gen_words[:-1]

    return ' '.join(gen_words).encode('utf-8', 'ignore')
|
|
|
|
def _retrieve_chains_for_key(self, k1, k2, context):
    """Get the value(s) for a given key (a pair of strings).

    Only rows whose context equals context, or is NULL, are considered.
    For the start-of-chain key at most one, quasi-randomly chosen, value
    is returned.

    Returns a list of values, possibly empty.
    Raises sqlite3.Error (after printing it) if the query fails.
    """

    query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND (context = ? OR context IS NULL)'
    params = [k1, k2, context]

    if k1 == self.start1 and k2 == self.start2:
        # hack. get a quasi-random start from the database, in
        # a faster fashion than selecting all starts
        max_id = self._get_max_chain_id()
        if max_id is None:
            # empty table: there is no chain to start from (the old code
            # crashed in random.randint(1, None) here)
            return []
        query += ' AND id >= ? LIMIT 1'
        params.append(random.randint(1, max_id))

    # opened outside the try so the cleanup never touches an unassigned
    # connection
    db = self.get_db()
    try:
        cursor = db.execute(query, params)
        return [row['v'] for row in cursor.fetchall()]
    except sqlite3.Error as e:
        print('sqlite error: ' + str(e))
        raise
    finally:
        db.close()
|
|
|
|
def _get_chatter_targets(self):
    """Get all possible chatter targets.

    Returns every row of markov_chatter_target, each with its target and
    chance columns.
    Raises sqlite3.Error (after printing it) if the query fails.
    """

    # need to create our own db object, since this is likely going to be in a new thread
    db = self.get_db()
    try:
        query = 'SELECT target, chance FROM markov_chatter_target'
        return db.execute(query).fetchall()
    except sqlite3.Error as e:
        print('sqlite error: ' + str(e))
        raise
    finally:
        # the old code only closed on error, leaking one connection per
        # successful call; the fetched rows remain usable after closing
        db.close()
|
|
|
|
def _get_one_chatter_target(self):
    """Pick one chatter target at random; None when there are none."""

    candidates = self._get_chatter_targets()
    if not candidates:
        return None

    pick = random.randint(0, len(candidates) - 1)
    return candidates[pick]
|
|
|
|
def _get_max_chain_id(self):
    """Return the largest id in markov_chain, or None when the table is empty."""

    try:
        conn = self.get_db()
        newest_first = '''
            SELECT id FROM markov_chain ORDER BY id DESC LIMIT 1
        '''
        row = conn.execute(newest_first).fetchone()
        conn.close()

        # fetchone() gives None when the table has no rows
        return row['id'] if row else None
    except sqlite3.Error as err:
        conn.close()
        print('sqlite error: ' + str(err))
        raise
|
|
|
|
# vi:tabstop=4:expandtab:autoindent
|
|
# kate: indent-mode python;indent-width 4;replace-tabs on;
|