Markov: use sqlite backend for brain
this keeps us from having the entire markov chain in memory and having to do the pickling and so on. in many ways, this is a good thing. in one way, this is a bad thing. each line on irc will create a __start1,__start2 item in the database, which means starting a chain will be an expensive process. (approx 3 seconds, from irc logs of 600,000 K lines). following selects run much faster, but the first one is dog slow. a later commit should hopefully fix this.
This commit is contained in:
parent
28694ed82f
commit
1712a7db53
@ -20,6 +20,7 @@ import cPickle
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
import sqlite3
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from extlib import irclib
|
from extlib import irclib
|
||||||
@ -41,10 +42,6 @@ class Markov(Module):
|
|||||||
def __init__(self, irc, config, server):
|
def __init__(self, irc, config, server):
|
||||||
"""Create the Markov chainer, and learn text from a file if available."""
|
"""Create the Markov chainer, and learn text from a file if available."""
|
||||||
|
|
||||||
Module.__init__(self, irc, config, server)
|
|
||||||
|
|
||||||
self.brain_filename = 'dr.botzo.markov'
|
|
||||||
|
|
||||||
# set up some keywords for use in the chains --- don't change these
|
# set up some keywords for use in the chains --- don't change these
|
||||||
# once you've created a brain
|
# once you've created a brain
|
||||||
self.start1 = '__start1'
|
self.start1 = '__start1'
|
||||||
@ -60,13 +57,33 @@ class Markov(Module):
|
|||||||
self.learnre = re.compile(learnpattern)
|
self.learnre = re.compile(learnpattern)
|
||||||
self.replyre = re.compile(replypattern)
|
self.replyre = re.compile(replypattern)
|
||||||
|
|
||||||
try:
|
Module.__init__(self, irc, config, server)
|
||||||
brainfile = open(self.brain_filename, 'r')
|
|
||||||
self.brain = cPickle.load(brainfile)
|
def db_init(self):
|
||||||
brainfile.close()
|
"""Create the markov chain table."""
|
||||||
except IOError:
|
|
||||||
self.brain = {}
|
version = self.db_module_registered(self.__class__.__name__)
|
||||||
self.brain.setdefault((self.start1, self.start2), []).append(self.stop)
|
if (version == None):
|
||||||
|
db = self.get_db()
|
||||||
|
try:
|
||||||
|
db.execute('''
|
||||||
|
CREATE TABLE markov_chain (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
k1 TEXT NOT NULL,
|
||||||
|
k2 TEXT NOT NULL,
|
||||||
|
v TEXT NOT NULL
|
||||||
|
)''')
|
||||||
|
db.execute('CREATE INDEX markov_chain_key_index ON markov_chain (k1, k2)')
|
||||||
|
sql = 'INSERT INTO drbotzo_modules VALUES (?,?)'
|
||||||
|
db.execute(sql, (self.__class__.__name__, 1))
|
||||||
|
db.commit()
|
||||||
|
version = 1
|
||||||
|
|
||||||
|
self._learn_line('')
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
db.rollback()
|
||||||
|
print("sqlite error: " + str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
def register_handlers(self):
|
def register_handlers(self):
|
||||||
"""Handle pubmsg/privmsg, to learn and/or reply to IRC events."""
|
"""Handle pubmsg/privmsg, to learn and/or reply to IRC events."""
|
||||||
@ -82,13 +99,6 @@ class Markov(Module):
|
|||||||
self.server.remove_global_handler('pubmsg', self.learn_from_irc_event)
|
self.server.remove_global_handler('pubmsg', self.learn_from_irc_event)
|
||||||
self.server.remove_global_handler('privmsg', self.learn_from_irc_event)
|
self.server.remove_global_handler('privmsg', self.learn_from_irc_event)
|
||||||
|
|
||||||
def save(self):
|
|
||||||
"""Pickle the brain upon save."""
|
|
||||||
|
|
||||||
brainfile = open(self.brain_filename, 'w')
|
|
||||||
cPickle.dump(self.brain, brainfile)
|
|
||||||
brainfile.close()
|
|
||||||
|
|
||||||
def learn_from_irc_event(self, connection, event):
|
def learn_from_irc_event(self, connection, event):
|
||||||
"""Learn from IRC events."""
|
"""Learn from IRC events."""
|
||||||
|
|
||||||
@ -170,16 +180,24 @@ class Markov(Module):
|
|||||||
"""Create Markov chains from the provided line."""
|
"""Create Markov chains from the provided line."""
|
||||||
|
|
||||||
# set up the head of the chain
|
# set up the head of the chain
|
||||||
w1 = self.start1
|
k1 = self.start1
|
||||||
w2 = self.start2
|
k2 = self.start2
|
||||||
|
|
||||||
# for each word pair, add the next word to the dictionary
|
try:
|
||||||
for word in line.split():
|
db = self.get_db()
|
||||||
self.brain.setdefault((w1, w2), []).append(word.lower())
|
cur = db.cursor()
|
||||||
w1, w2 = w2, word.lower()
|
statement = 'INSERT INTO markov_chain (k1, k2, v) VALUES (?, ?, ?)'
|
||||||
|
|
||||||
# cap the end of the chain
|
for word in line.split():
|
||||||
self.brain.setdefault((w1, w2), []).append(self.stop)
|
cur.execute(statement, (k1.decode('utf-8', 'replace').lower(), k2.decode('utf-8', 'replace').lower(), word.decode('utf-8', 'replace').lower()))
|
||||||
|
k1, k2 = k2, word
|
||||||
|
cur.execute(statement, (k1.decode('utf-8', 'replace').lower(), k2.decode('utf-8', 'replace').lower(), self.stop))
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
db.rollback()
|
||||||
|
print("sqlite error: " + str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
def _reply(self, min_size=15, max_size=100):
|
def _reply(self, min_size=15, max_size=100):
|
||||||
"""Generate a totally random string from the chains, of specified limit of words."""
|
"""Generate a totally random string from the chains, of specified limit of words."""
|
||||||
@ -200,10 +218,11 @@ class Markov(Module):
|
|||||||
|
|
||||||
# walk a chain, randomly, building the list of words
|
# walk a chain, randomly, building the list of words
|
||||||
while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
|
while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
|
||||||
if len(gen_words) < min_size and len(filter(lambda a: a != self.stop, self.brain[(gen_words[-2], gen_words[-1])])) > 0:
|
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1])
|
||||||
|
if len(gen_words) < min_size and len(filter(lambda a: a != self.stop, key_hits)) > 0:
|
||||||
# we aren't at min size yet and we have at least one chain path
|
# we aren't at min size yet and we have at least one chain path
|
||||||
# that isn't (yet) the end. take one of those.
|
# that isn't (yet) the end. take one of those.
|
||||||
gen_words.append(random.choice(filter(lambda a: a != self.stop, self.brain[(gen_words[-2], gen_words[-1])])))
|
gen_words.append(random.choice(filter(lambda a: a != self.stop, key_hits)))
|
||||||
min_search_tries = 0
|
min_search_tries = 0
|
||||||
elif len(gen_words) < min_size and min_search_tries <= 10:
|
elif len(gen_words) < min_size and min_search_tries <= 10:
|
||||||
# we aren't at min size yet and the only path we currently have is
|
# we aren't at min size yet and the only path we currently have is
|
||||||
@ -215,7 +234,7 @@ class Markov(Module):
|
|||||||
# either we have hit our min size requirement, or we haven't but
|
# either we have hit our min size requirement, or we haven't but
|
||||||
# we also exhausted min_search_tries. either way, just pick a word
|
# we also exhausted min_search_tries. either way, just pick a word
|
||||||
# at random, knowing it may be the end of the chain
|
# at random, knowing it may be the end of the chain
|
||||||
gen_words.append(random.choice(self.brain[(gen_words[-2], gen_words[-1])]))
|
gen_words.append(random.choice(key_hits))
|
||||||
min_search_tries = 0
|
min_search_tries = 0
|
||||||
|
|
||||||
# chop off the seed data at the start
|
# chop off the seed data at the start
|
||||||
@ -247,16 +266,17 @@ class Markov(Module):
|
|||||||
|
|
||||||
# walk a chain, randomly, building the list of words
|
# walk a chain, randomly, building the list of words
|
||||||
while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
|
while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
|
||||||
|
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1])
|
||||||
# use the chain that includes the target word, if it is found
|
# use the chain that includes the target word, if it is found
|
||||||
if target_word in self.brain[(gen_words[-2], gen_words[-1])]:
|
if target_word in key_hits:
|
||||||
gen_words.append(target_word)
|
gen_words.append(target_word)
|
||||||
# generate new word
|
# generate new word
|
||||||
target_word = words[random.randint(0, len(words)-1)]
|
target_word = words[random.randint(0, len(words)-1)]
|
||||||
else:
|
else:
|
||||||
if len(gen_words) < min_size and len(filter(lambda a: a != self.stop, self.brain[(gen_words[-2], gen_words[-1])])) > 0:
|
if len(gen_words) < min_size and len(filter(lambda a: a != self.stop, key_hits)) > 0:
|
||||||
gen_words.append(random.choice(filter(lambda a: a != self.stop, self.brain[(gen_words[-2], gen_words[-1])])))
|
gen_words.append(random.choice(filter(lambda a: a != self.stop, key_hits)))
|
||||||
else:
|
else:
|
||||||
gen_words.append(random.choice(self.brain[(gen_words[-2], gen_words[-1])]))
|
gen_words.append(random.choice(key_hits))
|
||||||
|
|
||||||
# chop off the seed data at the start
|
# chop off the seed data at the start
|
||||||
gen_words = gen_words[2:]
|
gen_words = gen_words[2:]
|
||||||
@ -267,5 +287,23 @@ class Markov(Module):
|
|||||||
|
|
||||||
return ' '.join(gen_words)
|
return ' '.join(gen_words)
|
||||||
|
|
||||||
|
def _retrieve_chains_for_key(self, k1, k2):
|
||||||
|
"""Get the value(s) for a given key (a pair of strings)."""
|
||||||
|
|
||||||
|
values = []
|
||||||
|
try:
|
||||||
|
db = self.get_db()
|
||||||
|
query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ?'
|
||||||
|
cursor = db.execute(query, (k1,k2))
|
||||||
|
results = cursor.fetchall()
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
values.append(result['v'])
|
||||||
|
|
||||||
|
return values
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
print('sqlite error: ' + str(e))
|
||||||
|
raise
|
||||||
|
|
||||||
# vi:tabstop=4:expandtab:autoindent
|
# vi:tabstop=4:expandtab:autoindent
|
||||||
# kate: indent-mode python;indent-width 4;replace-tabs on;
|
# kate: indent-mode python;indent-width 4;replace-tabs on;
|
||||||
|
Loading…
Reference in New Issue
Block a user