Markov: use sqlite backend for brain

This keeps us from having to hold the entire Markov chain in memory
and from having to pickle and unpickle it. In many ways, this is a
good thing.

In one way, this is a bad thing. Each line on IRC creates a
(__start1, __start2) row in the database, so the select that starts
a chain has to fetch every one of those rows and is an expensive
process (approx. 3 seconds, from IRC logs of 600,000 K lines). The
following selects run much faster, but the first one is dog slow.
A later commit should hopefully fix this.
This commit is contained in:
Brian S. Stephan 2011-02-24 20:39:32 -06:00
parent 28694ed82f
commit 1712a7db53
1 changed files with 71 additions and 33 deletions

View File

@ -20,6 +20,7 @@ import cPickle
import os
import random
import re
import sqlite3
import sys
from extlib import irclib
@ -41,10 +42,6 @@ class Markov(Module):
def __init__(self, irc, config, server):
"""Create the Markov chainer, and learn text from a file if available."""
Module.__init__(self, irc, config, server)
self.brain_filename = 'dr.botzo.markov'
# set up some keywords for use in the chains --- don't change these
# once you've created a brain
self.start1 = '__start1'
@ -60,13 +57,33 @@ class Markov(Module):
self.learnre = re.compile(learnpattern)
self.replyre = re.compile(replypattern)
try:
brainfile = open(self.brain_filename, 'r')
self.brain = cPickle.load(brainfile)
brainfile.close()
except IOError:
self.brain = {}
self.brain.setdefault((self.start1, self.start2), []).append(self.stop)
Module.__init__(self, irc, config, server)
def db_init(self):
    """Create the markov chain table if this module is not yet registered.

    Looks this class's name up via db_module_registered(). When that
    returns None, creates the markov_chain table and its (k1, k2) index,
    records the module in drbotzo_modules at version 1, commits, and then
    learns an empty line. Does nothing for any other return value.

    Raises:
        sqlite3.Error: re-raised after rolling back and printing the error.
    """
    version = self.db_module_registered(self.__class__.__name__)
    if version is None:
        db = self.get_db()
        try:
            db.execute('''
                CREATE TABLE markov_chain (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    k1 TEXT NOT NULL,
                    k2 TEXT NOT NULL,
                    v TEXT NOT NULL
                )''')
            # every chain lookup filters on the (k1, k2) pair
            db.execute('CREATE INDEX markov_chain_key_index ON markov_chain (k1, k2)')
            sql = 'INSERT INTO drbotzo_modules VALUES (?,?)'
            db.execute(sql, (self.__class__.__name__, 1))
            db.commit()

            # learning an empty line appears to seed the brain with a
            # start-pair -> stop row, so a brand new database has a chain
            # to walk -- NOTE(review): confirm against _learn_line
            self._learn_line('')
        except sqlite3.Error as e:
            db.rollback()
            print("sqlite error: " + str(e))
            raise
def register_handlers(self):
"""Handle pubmsg/privmsg, to learn and/or reply to IRC events."""
@ -82,13 +99,6 @@ class Markov(Module):
self.server.remove_global_handler('pubmsg', self.learn_from_irc_event)
self.server.remove_global_handler('privmsg', self.learn_from_irc_event)
def save(self):
    """Pickle the brain upon save.

    Writes self.brain to the file named by self.brain_filename,
    replacing any previous contents.
    """
    # the with-block closes the file even if the dump raises
    with open(self.brain_filename, 'w') as brainfile:
        cPickle.dump(self.brain, brainfile)
def learn_from_irc_event(self, connection, event):
"""Learn from IRC events."""
@ -170,16 +180,24 @@ class Markov(Module):
"""Create Markov chains from the provided line."""
# set up the head of the chain
w1 = self.start1
w2 = self.start2
k1 = self.start1
k2 = self.start2
# for each word pair, add the next word to the dictionary
for word in line.split():
self.brain.setdefault((w1, w2), []).append(word.lower())
w1, w2 = w2, word.lower()
try:
db = self.get_db()
cur = db.cursor()
statement = 'INSERT INTO markov_chain (k1, k2, v) VALUES (?, ?, ?)'
# cap the end of the chain
self.brain.setdefault((w1, w2), []).append(self.stop)
for word in line.split():
cur.execute(statement, (k1.decode('utf-8', 'replace').lower(), k2.decode('utf-8', 'replace').lower(), word.decode('utf-8', 'replace').lower()))
k1, k2 = k2, word
cur.execute(statement, (k1.decode('utf-8', 'replace').lower(), k2.decode('utf-8', 'replace').lower(), self.stop))
db.commit()
except sqlite3.Error as e:
db.rollback()
print("sqlite error: " + str(e))
raise
def _reply(self, min_size=15, max_size=100):
"""Generate a totally random string from the chains, of specified limit of words."""
@ -200,10 +218,11 @@ class Markov(Module):
# walk a chain, randomly, building the list of words
while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
if len(gen_words) < min_size and len(filter(lambda a: a != self.stop, self.brain[(gen_words[-2], gen_words[-1])])) > 0:
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1])
if len(gen_words) < min_size and len(filter(lambda a: a != self.stop, key_hits)) > 0:
# we aren't at min size yet and we have at least one chain path
# that isn't (yet) the end. take one of those.
gen_words.append(random.choice(filter(lambda a: a != self.stop, self.brain[(gen_words[-2], gen_words[-1])])))
gen_words.append(random.choice(filter(lambda a: a != self.stop, key_hits)))
min_search_tries = 0
elif len(gen_words) < min_size and min_search_tries <= 10:
# we aren't at min size yet and the only path we currently have is
@ -215,7 +234,7 @@ class Markov(Module):
# either we have hit our min size requirement, or we haven't but
# we also exhausted min_search_tries. either way, just pick a word
# at random, knowing it may be the end of the chain
gen_words.append(random.choice(self.brain[(gen_words[-2], gen_words[-1])]))
gen_words.append(random.choice(key_hits))
min_search_tries = 0
# chop off the seed data at the start
@ -247,16 +266,17 @@ class Markov(Module):
# walk a chain, randomly, building the list of words
while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1])
# use the chain that includes the target word, if it is found
if target_word in self.brain[(gen_words[-2], gen_words[-1])]:
if target_word in key_hits:
gen_words.append(target_word)
# generate new word
target_word = words[random.randint(0, len(words)-1)]
else:
if len(gen_words) < min_size and len(filter(lambda a: a != self.stop, self.brain[(gen_words[-2], gen_words[-1])])) > 0:
gen_words.append(random.choice(filter(lambda a: a != self.stop, self.brain[(gen_words[-2], gen_words[-1])])))
if len(gen_words) < min_size and len(filter(lambda a: a != self.stop, key_hits)) > 0:
gen_words.append(random.choice(filter(lambda a: a != self.stop, key_hits)))
else:
gen_words.append(random.choice(self.brain[(gen_words[-2], gen_words[-1])]))
gen_words.append(random.choice(key_hits))
# chop off the seed data at the start
gen_words = gen_words[2:]
@ -267,5 +287,23 @@ class Markov(Module):
return ' '.join(gen_words)
def _retrieve_chains_for_key(self, k1, k2):
    """Look up every stored value for the key pair (k1, k2).

    Returns a list holding the 'v' column of each markov_chain row whose
    k1 and k2 both match; the list is empty when nothing matches. Rows
    are read by column name.

    Raises:
        sqlite3.Error: re-raised after printing the error.
    """
    try:
        rows = self.get_db().execute(
            'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ?',
            (k1, k2),
        ).fetchall()
        return [row['v'] for row in rows]
    except sqlite3.Error as err:
        print('sqlite error: ' + str(err))
        raise
# vi:tabstop=4:expandtab:autoindent
# kate: indent-mode python;indent-width 4;replace-tabs on;