Markov: use sqlite backend for brain
This keeps us from having to hold the entire Markov chain in memory and from having to pickle it, and so on. In many ways this is a good thing; in one way it is a bad thing. Each line on IRC creates a (__start1, __start2) item in the database, which means starting a chain is an expensive process (approx. 3 seconds, from IRC logs of 600,000 K lines). Subsequent selects run much faster, but the first one is dog slow. A later commit should hopefully fix this.
This commit is contained in:
parent
28694ed82f
commit
1712a7db53
@ -20,6 +20,7 @@ import cPickle
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sqlite3
|
||||
import sys
|
||||
|
||||
from extlib import irclib
|
||||
@ -41,10 +42,6 @@ class Markov(Module):
|
||||
def __init__(self, irc, config, server):
|
||||
"""Create the Markov chainer, and learn text from a file if available."""
|
||||
|
||||
Module.__init__(self, irc, config, server)
|
||||
|
||||
self.brain_filename = 'dr.botzo.markov'
|
||||
|
||||
# set up some keywords for use in the chains --- don't change these
|
||||
# once you've created a brain
|
||||
self.start1 = '__start1'
|
||||
@ -60,13 +57,33 @@ class Markov(Module):
|
||||
self.learnre = re.compile(learnpattern)
|
||||
self.replyre = re.compile(replypattern)
|
||||
|
||||
try:
|
||||
brainfile = open(self.brain_filename, 'r')
|
||||
self.brain = cPickle.load(brainfile)
|
||||
brainfile.close()
|
||||
except IOError:
|
||||
self.brain = {}
|
||||
self.brain.setdefault((self.start1, self.start2), []).append(self.stop)
|
||||
Module.__init__(self, irc, config, server)
|
||||
|
||||
def db_init(self):
    """Create the markov chain table.

    Does nothing when ``db_module_registered`` already reports a version
    for this class. Otherwise it creates the ``markov_chain`` table and
    its ``(k1, k2)`` index, records version 1 for this class in
    ``drbotzo_modules``, commits, and finally calls ``_learn_line('')``.

    Raises:
        sqlite3.Error: re-raised after the transaction is rolled back and
            the error has been printed.
    """

    # None is a singleton, so compare by identity rather than equality
    if self.db_module_registered(self.__class__.__name__) is not None:
        return

    db = self.get_db()
    try:
        db.execute('''
            CREATE TABLE markov_chain (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                k1 TEXT NOT NULL,
                k2 TEXT NOT NULL,
                v TEXT NOT NULL
            )''')
        db.execute('CREATE INDEX markov_chain_key_index ON markov_chain (k1, k2)')
        sql = 'INSERT INTO drbotzo_modules VALUES (?,?)'
        db.execute(sql, (self.__class__.__name__, 1))
        db.commit()

        # learn an empty line once the schema has been committed
        self._learn_line('')
    except sqlite3.Error as e:
        db.rollback()
        print("sqlite error: " + str(e))
        raise
|
||||
|
||||
def register_handlers(self):
|
||||
"""Handle pubmsg/privmsg, to learn and/or reply to IRC events."""
|
||||
@ -82,13 +99,6 @@ class Markov(Module):
|
||||
self.server.remove_global_handler('pubmsg', self.learn_from_irc_event)
|
||||
self.server.remove_global_handler('privmsg', self.learn_from_irc_event)
|
||||
|
||||
def save(self):
    """Pickle the brain upon save.

    Writes ``self.brain`` to the file named by ``self.brain_filename``,
    replacing any previous contents.
    """

    # the with-block closes the file even if cPickle.dump raises
    with open(self.brain_filename, 'w') as brainfile:
        cPickle.dump(self.brain, brainfile)
|
||||
|
||||
def learn_from_irc_event(self, connection, event):
|
||||
"""Learn from IRC events."""
|
||||
|
||||
@ -170,16 +180,24 @@ class Markov(Module):
|
||||
"""Create Markov chains from the provided line."""
|
||||
|
||||
# set up the head of the chain
|
||||
w1 = self.start1
|
||||
w2 = self.start2
|
||||
k1 = self.start1
|
||||
k2 = self.start2
|
||||
|
||||
# for each word pair, add the next word to the dictionary
|
||||
for word in line.split():
|
||||
self.brain.setdefault((w1, w2), []).append(word.lower())
|
||||
w1, w2 = w2, word.lower()
|
||||
try:
|
||||
db = self.get_db()
|
||||
cur = db.cursor()
|
||||
statement = 'INSERT INTO markov_chain (k1, k2, v) VALUES (?, ?, ?)'
|
||||
|
||||
# cap the end of the chain
|
||||
self.brain.setdefault((w1, w2), []).append(self.stop)
|
||||
for word in line.split():
|
||||
cur.execute(statement, (k1.decode('utf-8', 'replace').lower(), k2.decode('utf-8', 'replace').lower(), word.decode('utf-8', 'replace').lower()))
|
||||
k1, k2 = k2, word
|
||||
cur.execute(statement, (k1.decode('utf-8', 'replace').lower(), k2.decode('utf-8', 'replace').lower(), self.stop))
|
||||
|
||||
db.commit()
|
||||
except sqlite3.Error as e:
|
||||
db.rollback()
|
||||
print("sqlite error: " + str(e))
|
||||
raise
|
||||
|
||||
def _reply(self, min_size=15, max_size=100):
|
||||
"""Generate a totally random string from the chains, of specified limit of words."""
|
||||
@ -200,10 +218,11 @@ class Markov(Module):
|
||||
|
||||
# walk a chain, randomly, building the list of words
|
||||
while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
|
||||
if len(gen_words) < min_size and len(filter(lambda a: a != self.stop, self.brain[(gen_words[-2], gen_words[-1])])) > 0:
|
||||
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1])
|
||||
if len(gen_words) < min_size and len(filter(lambda a: a != self.stop, key_hits)) > 0:
|
||||
# we aren't at min size yet and we have at least one chain path
|
||||
# that isn't (yet) the end. take one of those.
|
||||
gen_words.append(random.choice(filter(lambda a: a != self.stop, self.brain[(gen_words[-2], gen_words[-1])])))
|
||||
gen_words.append(random.choice(filter(lambda a: a != self.stop, key_hits)))
|
||||
min_search_tries = 0
|
||||
elif len(gen_words) < min_size and min_search_tries <= 10:
|
||||
# we aren't at min size yet and the only path we currently have is
|
||||
@ -215,7 +234,7 @@ class Markov(Module):
|
||||
# either we have hit our min size requirement, or we haven't but
|
||||
# we also exhausted min_search_tries. either way, just pick a word
|
||||
# at random, knowing it may be the end of the chain
|
||||
gen_words.append(random.choice(self.brain[(gen_words[-2], gen_words[-1])]))
|
||||
gen_words.append(random.choice(key_hits))
|
||||
min_search_tries = 0
|
||||
|
||||
# chop off the seed data at the start
|
||||
@ -247,16 +266,17 @@ class Markov(Module):
|
||||
|
||||
# walk a chain, randomly, building the list of words
|
||||
while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
|
||||
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1])
|
||||
# use the chain that includes the target word, if it is found
|
||||
if target_word in self.brain[(gen_words[-2], gen_words[-1])]:
|
||||
if target_word in key_hits:
|
||||
gen_words.append(target_word)
|
||||
# generate new word
|
||||
target_word = words[random.randint(0, len(words)-1)]
|
||||
else:
|
||||
if len(gen_words) < min_size and len(filter(lambda a: a != self.stop, self.brain[(gen_words[-2], gen_words[-1])])) > 0:
|
||||
gen_words.append(random.choice(filter(lambda a: a != self.stop, self.brain[(gen_words[-2], gen_words[-1])])))
|
||||
if len(gen_words) < min_size and len(filter(lambda a: a != self.stop, key_hits)) > 0:
|
||||
gen_words.append(random.choice(filter(lambda a: a != self.stop, key_hits)))
|
||||
else:
|
||||
gen_words.append(random.choice(self.brain[(gen_words[-2], gen_words[-1])]))
|
||||
gen_words.append(random.choice(key_hits))
|
||||
|
||||
# chop off the seed data at the start
|
||||
gen_words = gen_words[2:]
|
||||
@ -267,5 +287,23 @@ class Markov(Module):
|
||||
|
||||
return ' '.join(gen_words)
|
||||
|
||||
def _retrieve_chains_for_key(self, k1, k2):
|
||||
"""Get the value(s) for a given key (a pair of strings)."""
|
||||
|
||||
values = []
|
||||
try:
|
||||
db = self.get_db()
|
||||
query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ?'
|
||||
cursor = db.execute(query, (k1,k2))
|
||||
results = cursor.fetchall()
|
||||
|
||||
for result in results:
|
||||
values.append(result['v'])
|
||||
|
||||
return values
|
||||
except sqlite3.Error as e:
|
||||
print('sqlite error: ' + str(e))
|
||||
raise
|
||||
|
||||
# vi:tabstop=4:expandtab:autoindent
|
||||
# kate: indent-mode python;indent-width 4;replace-tabs on;
|
||||
|
Loading…
Reference in New Issue
Block a user