dr.botzo/modules/Markov.py

341 lines
13 KiB
Python
Raw Normal View History

"""
Markov - Chatterbot via Markov chains for IRC
Copyright (C) 2010 Brian S. Stephan
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import cPickle
import os
import random
import re
import sqlite3
import sys
from extlib import irclib
from Module import Module
class Markov(Module):

    """
    Create a chatterbot very similar to a MegaHAL, but simpler and
    implemented in pure Python. Proof of concept code from Ape.

    Ape wrote: based on this:
    http://uswaretech.com/blog/2009/06/pseudo-random-text-markov-chains-python/
    and this:
    http://code.activestate.com/recipes/194364-the-markov-chain-algorithm/

    Chains are stored in the markov_chain table as (k1, k2) -> v rows: each
    pair of consecutive words maps to a word seen following that pair. Every
    learned line starts with the key (start1, start2) and ends with the value
    stop.
    """

    def __init__(self, irc, config, server):
        """Create the Markov chainer and load known chain beginnings from the database."""

        # set up some keywords for use in the chains --- don't change these
        # once you've created a brain
        self.start1 = '__start1'
        self.start2 = '__start2'
        self.stop = '__stop'

        # set up regexes, for replying to specific stuff
        trainpattern = r'^!markov\s+train\s+(.*)$'
        learnpattern = r'^!markov\s+learn\s+(.*)$'
        replypattern = r'^!markov\s+reply(\s+min=(\d+))?(\s+max=(\d+))?(\s+(.*)$|$)'

        self.trainre = re.compile(trainpattern)
        self.learnre = re.compile(learnpattern)
        self.replyre = re.compile(replypattern)

        Module.__init__(self, irc, config, server)

        # load the existing chain starts from the database
        self.starts = self._get_chain_beginnings()

    def db_init(self):
        """Create the markov chain table if this module is not registered yet.

        Raises:
            sqlite3.Error: if creating the table/index or registering fails;
                the transaction is rolled back first.
        """
        version = self.db_module_registered(self.__class__.__name__)
        if version is None:
            db = self.get_db()
            try:
                db.execute('''
                    CREATE TABLE markov_chain (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        k1 TEXT NOT NULL,
                        k2 TEXT NOT NULL,
                        v TEXT NOT NULL
                    )''')
                db.execute('CREATE INDEX markov_chain_key_index ON markov_chain (k1, k2)')
                sql = 'INSERT INTO drbotzo_modules VALUES (?,?)'
                db.execute(sql, (self.__class__.__name__, 1))
                db.commit()
            except sqlite3.Error as e:
                db.rollback()
                print("sqlite error: " + str(e))
                raise

    def register_handlers(self):
        """Handle pubmsg/privmsg, to learn and/or reply to IRC events."""
        self.server.add_global_handler('pubmsg', self.on_pub_or_privmsg, self.priority())
        self.server.add_global_handler('privmsg', self.on_pub_or_privmsg, self.priority())
        self.server.add_global_handler('pubmsg', self.learn_from_irc_event)
        self.server.add_global_handler('privmsg', self.learn_from_irc_event)

    def unregister_handlers(self):
        """Remove every handler added by register_handlers."""
        self.server.remove_global_handler('pubmsg', self.on_pub_or_privmsg)
        self.server.remove_global_handler('privmsg', self.on_pub_or_privmsg)
        self.server.remove_global_handler('pubmsg', self.learn_from_irc_event)
        self.server.remove_global_handler('privmsg', self.learn_from_irc_event)

    def learn_from_irc_event(self, connection, event):
        """Learn from IRC events.

        A leading "<my nick>: " or "<my nick>, " address is stripped first, and
        lines that are !markov commands are not learned.
        """
        what = ''.join(event.arguments()[0])
        my_nick = connection.get_nickname()
        # escape the nick: IRC nicks may contain regex metacharacters
        what = re.sub('^' + re.escape(my_nick) + r'[:,]\s+', '', what)

        # don't learn from commands
        if self.trainre.search(what) or self.learnre.search(what) or self.replyre.search(what):
            return

        self._learn_line(what)

    def do(self, connection, event, nick, userhost, what, admin_unlocked):
        """Handle commands and inputs.

        !markov commands are dispatched to their handlers; otherwise, if the
        bot's nick appears in the line (case-insensitively), a generated reply
        is sent, prefixed with the speaker's nick when the bot was addressed
        directly ("<my nick>: ..." / "<my nick>, ...").
        """
        if self.trainre.search(what):
            return self.reply(connection, event, self.markov_train(connection, event, nick, userhost, what, admin_unlocked))
        elif self.learnre.search(what):
            return self.reply(connection, event, self.markov_learn(connection, event, nick, userhost, what, admin_unlocked))
        elif self.replyre.search(what):
            return self.reply(connection, event, self.markov_reply(connection, event, nick, userhost, what, admin_unlocked))

        # not a command, so see if i'm being mentioned.
        # escape the nick: IRC nicks may contain regex metacharacters
        my_nick = re.escape(connection.get_nickname())
        if re.search(my_nick, what, re.IGNORECASE) is not None:
            addressed = re.match('^' + my_nick + r'[:,]\s+(.*)', what)
            if addressed:
                # i was addressed directly, so respond, addressing the speaker
                return self.reply(connection, event, '{0:s}: {1:s}'.format(nick, self._reply_to_line(addressed.group(1))))
            else:
                # i wasn't addressed directly, so just respond
                return self.reply(connection, event, '{0:s}'.format(self._reply_to_line(what)))

    def markov_train(self, connection, event, nick, userhost, what, admin_unlocked):
        """Learn lines from a file. Good for initializing a brain.

        Only acts when admin_unlocked is true; otherwise returns None.
        Returns a status message naming the file on success or IOError.
        """
        match = self.trainre.search(what)
        if match and admin_unlocked:
            filename = match.group(1)
            try:
                # close the file even if learning a line raises
                with open(filename, 'r') as brainfile:
                    for line in brainfile:
                        self._learn_line(line)
                return 'Learned from \'{0:s}\'.'.format(filename)
            except IOError:
                return 'No such file \'{0:s}\'.'.format(filename)

    def markov_learn(self, connection, event, nick, userhost, what, admin_unlocked):
        """Learn one line, as provided to the command."""
        match = self.learnre.search(what)
        if match:
            line = match.group(1)
            self._learn_line(line)

            # return what was learned, for weird chaining purposes
            return line

    def markov_reply(self, connection, event, nick, userhost, what, admin_unlocked):
        """Generate a reply to one line, without learning it.

        Optional "min=N" / "max=N" arguments override the reply size limits
        (defaults 15 and 100); any remaining text is used to steer the reply.
        """
        match = self.replyre.search(what)
        if match:
            min_size = 15
            max_size = 100

            if match.group(2):
                min_size = int(match.group(2))
            if match.group(4):
                max_size = int(match.group(4))

            if match.group(5) != '':
                line = match.group(6)
                return self._reply_to_line(line, min_size=min_size, max_size=max_size)
            else:
                return self._reply(min_size=min_size, max_size=max_size)

    def _normalize(self, word):
        """Return word as lowercased text, decoding UTF-8 bytes if needed.

        Everything written to or compared against the chain table goes
        through this, so lookups agree with what _learn_line stored.
        """
        if isinstance(word, bytes):
            word = word.decode('utf-8', 'replace')
        return word.lower()

    def _learn_line(self, line):
        """Create Markov chains from the provided line.

        Returns the line unchanged if it has no words (nothing is learned);
        otherwise returns None.

        Raises:
            sqlite3.Error: if inserting the chain fails; the transaction is
                rolled back first.
        """
        words = line.split()
        if not words:
            return line

        normalized = [self._normalize(word) for word in words]
        # keys are the two start keywords followed by the line's words; the
        # values are the line's words followed by the stop keyword, so row i
        # is (keys[i], keys[i+1], values[i])
        keys = [self._normalize(self.start1), self._normalize(self.start2)] + normalized
        values = normalized + [self.stop]
        rows = list(zip(keys, keys[1:], values))

        db = self.get_db()
        try:
            statement = 'INSERT INTO markov_chain (k1, k2, v) VALUES (?, ?, ?)'
            db.cursor().executemany(statement, rows)
            db.commit()
        except sqlite3.Error as e:
            db.rollback()
            print("sqlite error: " + str(e))
            raise

        # remember the first word in the same normalized form the database
        # holds, so replies seeded from it can find its chain
        self.starts.append(normalized[0])

    def _check_sizes(self, min_size, max_size):
        """Raise Exception if the requested reply size limits are unusable."""
        # if the limit is too low, there's nothing to do
        if max_size <= 3:
            raise Exception("max_size is too small: %d" % max_size)

        # if the min is too large, abort
        if min_size > 20:
            raise Exception("min_size is too large: %d" % min_size)

    def _finish_reply(self, gen_words):
        """Strip the seed keywords and a trailing stop keyword; encode as UTF-8."""
        # chop off the seed data at the start
        gen_words = gen_words[2:]

        # chop off the end text, if it was the keyword indicating an end of chain
        if gen_words and gen_words[-1] == self.stop:
            gen_words = gen_words[:-1]

        return ' '.join(gen_words).encode('utf-8', 'ignore')

    def _reply(self, min_size=15, max_size=100):
        """Generate a totally random string from the chains, of specified limit of words.

        Returns an empty reply when nothing has been learned yet.

        Raises:
            Exception: if max_size <= 3 or min_size > 20.
        """
        self._check_sizes(min_size, max_size)

        # with nothing learned there is no chain to walk
        if not self.starts:
            return self._finish_reply([self.start1, self.start2])

        # start with an empty chain, and work from there
        gen_words = [self.start1, self.start2, random.choice(self.starts)]

        # total number of times we've backed up hunting for the minimum size.
        # this is deliberately never reset: resetting it after each appended
        # word let a brain whose chains all end before min_size loop forever.
        min_search_tries = 0

        # walk a chain, randomly, building the list of words
        while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
            key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1])
            continuations = [hit for hit in key_hits if hit != self.stop]

            if len(gen_words) < min_size and continuations:
                # we aren't at min size yet and we have at least one chain path
                # that isn't (yet) the end. take one of those.
                gen_words.append(random.choice(continuations))
            elif len(gen_words) < min_size and min_search_tries <= 10:
                # we aren't at min size yet and the only path we currently have
                # is an end, so back up two words and try again -- but never
                # past the two seed keywords, which the key lookup needs.
                del gen_words[max(2, len(gen_words) - 2):]
                min_search_tries += 1
            elif key_hits:
                # either we have hit our min size requirement, or we exhausted
                # min_search_tries. pick a word at random, knowing it may be
                # the end of the chain
                gen_words.append(random.choice(key_hits))
            else:
                # the current pair has no recorded continuation at all
                gen_words.append(self.stop)

        return self._finish_reply(gen_words)

    def _reply_to_line(self, line, min_size=15, max_size=100):
        """Reply to a line, using some text in the line as a point in the chain.

        Falls back to a fully random reply when the line has no words, and
        returns an empty reply when nothing has been learned yet.

        Raises:
            Exception: if max_size <= 3 or min_size > 20.
        """
        self._check_sizes(min_size, max_size)

        # normalize the input the same way learned words are stored, so the
        # target-word comparison below can match chain values
        words = [self._normalize(word) for word in line.split()]
        if not words:
            # nothing to steer towards, so generate freely
            return self._reply(min_size=min_size, max_size=max_size)

        # with nothing learned there is no chain to walk
        if not self.starts:
            return self._finish_reply([self.start1, self.start2])

        # get a random word from the input
        target_word = random.choice(words)

        # start with an empty chain, and work from there
        gen_words = [self.start1, self.start2, random.choice(self.starts)]

        # walk a chain, randomly, building the list of words
        while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
            key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1])

            # use the chain that includes the target word, if it is found
            if target_word in key_hits:
                gen_words.append(target_word)
                # generate new word
                target_word = random.choice(words)
            else:
                continuations = [hit for hit in key_hits if hit != self.stop]
                if len(gen_words) < min_size and continuations:
                    gen_words.append(random.choice(continuations))
                elif not key_hits:
                    gen_words.append(self.stop)
                else:
                    gen_words.append(random.choice(key_hits))

        return self._finish_reply(gen_words)

    def _retrieve_chains_for_key(self, k1, k2):
        """Get the value(s) for a given key (a pair of strings).

        Returns one entry per matching row, so repeated values are kept.
        Rows are read by column name ('v').
        """
        try:
            db = self.get_db()
            query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ?'
            cursor = db.execute(query, (k1, k2))
            return [row['v'] for row in cursor.fetchall()]
        except sqlite3.Error as e:
            print('sqlite error: ' + str(e))
            raise

    def _get_chain_beginnings(self):
        """Get all of the first (real) words in the brain."""
        # bound parameters instead of double-quoted literals in the SQL, which
        # SQLite only treats as strings through a legacy misfeature
        return self._retrieve_chains_for_key(self._normalize(self.start1),
                                             self._normalize(self.start2))
# vi:tabstop=4:expandtab:autoindent
# kate: indent-mode python;indent-width 4;replace-tabs on;