Markov: rebuild the tables, use the context stuff in a better fashion this time

the module will drop your old tables if you have them, so if there's data there,
be sure to back them up and figure out some migration strategy (probably annoying
and probably having to script it).

the big change is that each line is associated to a context now, and channels
are also associated to contexts. this should allow for a better partitioning
of multiple brains, and changing which channels point to which brain.

also caught in the wake is some additional logging verbosity, and a change to
no longer lower() everything learned.

the script to dump a file into the database has also been updated with the above
changes
This commit is contained in:
Brian S. Stephan 2012-02-28 23:23:14 -06:00
parent 79ddce0bcb
commit 26bc8bec34
2 changed files with 102 additions and 145 deletions

View File

@ -76,39 +76,33 @@ class Markov(Module):
"""Create the markov chain table."""
version = self.db_module_registered(self.__class__.__name__)
if (version == None):
if (version == None or version < 9):
db = self.get_db()
try:
db.execute('''
CREATE TABLE markov_chain (
id INTEGER PRIMARY KEY AUTOINCREMENT,
k1 TEXT NOT NULL,
k2 TEXT NOT NULL,
v TEXT NOT NULL
)''')
sql = 'INSERT INTO drbotzo_modules VALUES (?,?)'
db.execute(sql, (self.__class__.__name__, 1))
db.commit()
db.close()
version = 1
version = 9
# recreating the tables, since i need to add some foreign key constraints
db.execute('''DROP INDEX IF EXISTS markov_chain_keys_and_context_index''')
db.execute('''DROP INDEX IF EXISTS markov_chain_keys_index''')
db.execute('''DROP INDEX IF EXISTS markov_chain_value_and_context_index''')
db.execute('''DROP TABLE IF EXISTS markov_chain''')
db.execute('''DROP TABLE IF EXISTS markov_target_to_context_map''')
db.execute('''DROP TABLE IF EXISTS markov_chatter_target''')
db.execute('''DROP TABLE IF EXISTS markov_context''')
self._learn_line('')
except sqlite3.Error as e:
db.rollback()
db.close()
print("sqlite error: " + str(e))
raise
if (version < 2):
db = self.get_db()
try:
db.execute('''
ALTER TABLE markov_chain
ADD COLUMN context TEXT DEFAULT NULL''')
CREATE TABLE markov_chatter_target (
id INTEGER PRIMARY KEY AUTOINCREMENT,
target TEXT NOT NULL,
chance INTEGER NOT NULL DEFAULT 99999
)''')
db.execute('''
CREATE TABLE markov_context (
id INTEGER PRIMARY KEY AUTOINCREMENT,
context TEXT NOT NULL
)''')
db.execute('''
CREATE TABLE markov_target_to_context_map (
id INTEGER PRIMARY KEY AUTOINCREMENT,
@ -116,101 +110,30 @@ class Markov(Module):
context_id INTEGER NOT NULL,
FOREIGN KEY(context_id) REFERENCES markov_context(id)
)''')
db.execute('UPDATE drbotzo_modules SET version = ? WHERE module = ?',
(2, self.__class__.__name__))
db.commit()
db.close()
version = 2
except sqlite3.Error as e:
db.rollback()
db.close()
print('sqlite error: ' + str(e))
raise
if (version < 3):
db = self.get_db()
try:
db.execute('''
CREATE INDEX markov_chain_keys_index
ON markov_chain (k1, k2)''')
db.execute('UPDATE drbotzo_modules SET version = ? WHERE module = ?',
(3, self.__class__.__name__))
db.commit()
db.close()
version = 3
except sqlite3.Error as e:
db.rollback()
db.close()
print('sqlite error: ' + str(e))
raise
if (version < 4):
db = self.get_db()
try:
db.execute('UPDATE drbotzo_modules SET version = ? WHERE module = ?',
(4, self.__class__.__name__))
db.commit()
db.close()
version = 4
except sqlite3.Error as e:
db.rollback()
db.close()
print('sqlite error: ' + str(e))
raise
if (version < 5):
db = self.get_db()
try:
version = 5
db.execute('''
CREATE TABLE markov_chatter_target (
CREATE TABLE markov_chain (
id INTEGER PRIMARY KEY AUTOINCREMENT,
target TEXT NOT NULL
k1 TEXT NOT NULL,
k2 TEXT NOT NULL,
v TEXT NOT NULL,
context_id INTEGER DEFAULT NULL,
FOREIGN KEY(context_id) REFERENCES markov_context(id)
)''')
db.commit()
db.close()
self.db_register_module_version(self.__class__.__name__, version)
except sqlite3.Error as e:
db.rollback()
db.close()
print('sqlite error: ' + str(e))
raise
if (version < 6):
db = self.get_db()
try:
version = 6
db.execute('''
CREATE INDEX markov_chain_keys_and_context_index
ON markov_chain (k1, k2, context)''')
db.commit()
db.close()
self.db_register_module_version(self.__class__.__name__, version)
except sqlite3.Error as e:
db.rollback()
db.close()
print('sqlite error: ' + str(e))
raise
if (version < 7):
db = self.get_db()
try:
version = 7
CREATE INDEX markov_chain_keys_and_context_id_index
ON markov_chain (k1, k2, context_id)''')
db.execute('''
ALTER TABLE markov_chatter_target ADD COLUMN chance INTEGER NOT NULL DEFAULT 99999''')
db.commit()
db.close()
self.db_register_module_version(self.__class__.__name__, version)
except sqlite3.Error as e:
db.rollback()
db.close()
print('sqlite error: ' + str(e))
raise
if (version < 8):
db = self.get_db()
try:
version = 8
db.execute('''
CREATE INDEX markov_chain_value_and_context_index
ON markov_chain (v, context)''')
CREATE INDEX markov_chain_value_and_context_id_index
ON markov_chain (v, context_id)''')
db.commit()
db.close()
self.db_register_module_version(self.__class__.__name__, version)
self._learn_line('','')
except sqlite3.Error as e:
db.rollback()
db.close()
@ -255,11 +178,14 @@ class Markov(Module):
target = event.target()
if self.trainre.search(what):
return self.reply(connection, event, self.markov_train(connection, event, nick, userhost, what, admin_unlocked))
return self.reply(connection, event, self.markov_train(connection, event, nick,
userhost, what, admin_unlocked))
elif self.learnre.search(what):
return self.reply(connection, event, self.markov_learn(connection, event, nick, userhost, what, admin_unlocked))
return self.reply(connection, event, self.markov_learn(connection, event, nick,
userhost, what, admin_unlocked))
elif self.replyre.search(what) and not self.shut_up:
return self.reply(connection, event, self.markov_reply(connection, event, nick, userhost, what, admin_unlocked))
return self.reply(connection, event, self.markov_reply(connection, event, nick,
userhost, what, admin_unlocked))
if not self.shut_up:
# not a command, so see if i'm being mentioned
@ -269,7 +195,8 @@ class Markov(Module):
if addressed_re.match(what):
# i was addressed directly, so respond, addressing the speaker
self.lines_seen.append(('.self.said.', datetime.now()))
return self.reply(connection, event, '{0:s}: {1:s}'.format(nick, self._generate_line(target, line=addressed_re.match(what).group(1))))
return self.reply(connection, event, '{0:s}: {1:s}'.format(nick,
self._generate_line(target, line=addressed_re.match(what).group(1))))
else:
# i wasn't addressed directly, so just respond
self.lines_seen.append(('.self.said.', datetime.now()))
@ -365,7 +292,8 @@ class Markov(Module):
for (nick,then) in self.lines_seen:
rdelta = relativedelta(datetime.now(), then)
if rdelta.years == 0 and rdelta.months == 0 and rdelta.days == 0 and rdelta.hours == 0 and rdelta.minutes == 0 and rdelta.seconds <= 29:
if (rdelta.years == 0 and rdelta.months == 0 and rdelta.days == 0 and
rdelta.hours == 0 and rdelta.minutes == 0 and rdelta.seconds <= 29):
last_30_sec_lines.append((nick,then))
if len(last_30_sec_lines) >= 8:
@ -374,7 +302,8 @@ class Markov(Module):
self.shut_up = True
targets = self._get_chatter_targets()
for t in targets:
self.sendmsg(self.connection, t['target'], 'shutting up for 30 seconds due to last 30 seconds of activity')
self.sendmsg(self.connection, t['target'],
'shutting up for 30 seconds due to last 30 seconds of activity')
def _learn_line(self, line, target):
"""Create Markov chains from the provided line."""
@ -383,10 +312,10 @@ class Markov(Module):
k1 = self.start1
k2 = self.start2
context = target
context_id = self._get_context_id_for_target(target)
# if there's no context, this is probably a sub-command. don't learn it
if context:
# if there's no target, this is probably a sub-command. don't learn it
if target:
words = line.split()
if len(words) <= 0:
@ -395,18 +324,20 @@ class Markov(Module):
try:
db = self.get_db()
cur = db.cursor()
statement = 'INSERT INTO markov_chain (k1, k2, v, context) VALUES (?, ?, ?, ?)'
statement = 'INSERT INTO markov_chain (k1, k2, v, context_id) VALUES (?, ?, ?, ?)'
for word in words:
cur.execute(statement, (k1.decode('utf-8', 'replace').lower(), k2.decode('utf-8', 'replace').lower(), word.decode('utf-8', 'replace').lower(), context))
cur.execute(statement, (k1.decode('utf-8', 'replace'),
k2.decode('utf-8', 'replace'), word.decode('utf-8', 'replace'), context_id))
k1, k2 = k2, word
cur.execute(statement, (k1.decode('utf-8', 'replace').lower(), k2.decode('utf-8', 'replace').lower(), self.stop, context))
cur.execute(statement, (k1.decode('utf-8', 'replace'),
k2.decode('utf-8', 'replace'), self.stop, context_id))
db.commit()
db.close()
except sqlite3.Error as e:
db.rollback()
db.close()
print("sqlite error: " + str(e))
print("sqlite error in Markov._learn_line: " + str(e))
raise
def _generate_line(self, target, line='', min_size=15, max_size=100):
@ -427,7 +358,7 @@ class Markov(Module):
words = line.split()
target_word = words[random.randint(0, len(words)-1)]
context = target
context_id = self._get_context_id_for_target(target)
# start with an empty chain, and work from there
gen_words = [self.start1, self.start2]
@ -438,7 +369,7 @@ class Markov(Module):
# we'll just pick a word and work backwards
if gen_words[-1] == self.start2 and target_word != '':
working_backwards = []
key_hits = self._retrieve_k2_for_value(target_word, context)
key_hits = self._retrieve_k2_for_value(target_word, context_id)
if len(key_hits):
working_backwards.append(target_word)
# generate new word
@ -446,7 +377,7 @@ class Markov(Module):
target_word = words[random.randint(0, len(words)-1)]
# work backwards until we randomly bump into a start
while True:
key_hits = self._retrieve_k2_for_value(working_backwards[0], context)
key_hits = self._retrieve_k2_for_value(working_backwards[0], context_id)
if target_word in key_hits:
found_word = target_word
# generate new word
@ -464,7 +395,7 @@ class Markov(Module):
else:
working_backwards.insert(0, found_word)
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context)
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context_id)
# use the chain that includes the target word, if it is found
if target_word != '' and target_word in key_hits:
gen_words.append(target_word)
@ -487,7 +418,7 @@ class Markov(Module):
return ' '.join(gen_words).encode('utf-8', 'ignore')
def _retrieve_chains_for_key(self, k1, k2, context):
def _retrieve_chains_for_key(self, k1, k2, context_id):
"""Get the value(s) for a given key (a pair of strings)."""
values = []
@ -499,10 +430,12 @@ class Markov(Module):
# a faster fashion than selecting all starts
max_id = self._get_max_chain_id()
rand_id = random.randint(1,max_id)
query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND (context = ? OR context IS NULL) AND id >= {0:d} LIMIT 1'.format(rand_id)
query = ('SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND '
'(context_id = ? OR context_id IS NULL) AND id >= {0:d} LIMIT 1'.format(rand_id))
else:
query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND (context = ? OR context IS NULL)'
cursor = db.execute(query, (k1,k2,context))
query = ('SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND '
'(context_id = ? OR context_id IS NULL)')
cursor = db.execute(query, (k1,k2,context_id))
results = cursor.fetchall()
for result in results:
@ -512,17 +445,17 @@ class Markov(Module):
return values
except sqlite3.Error as e:
db.close()
print('sqlite error: ' + str(e))
print('sqlite error in Markov._retrieve_chains_for_key: ' + str(e))
raise
def _retrieve_k2_for_value(self, v, context):
def _retrieve_k2_for_value(self, v, context_id):
"""Get the value(s) for a given key (a pair of strings)."""
values = []
try:
db = self.get_db()
query = 'SELECT k2 FROM markov_chain WHERE v = ? AND (context = ? OR context IS NULL)'
cursor = db.execute(query, (v,context))
query = 'SELECT k2 FROM markov_chain WHERE v = ? AND (context_id = ? OR context_id IS NULL)'
cursor = db.execute(query, (v,context_id))
results = cursor.fetchall()
for result in results:
@ -532,7 +465,7 @@ class Markov(Module):
return values
except sqlite3.Error as e:
db.close()
print('sqlite error: ' + str(e))
print('sqlite error in Markov._retrieve_k2_for_value: ' + str(e))
raise
def _get_chatter_targets(self):
@ -548,7 +481,7 @@ class Markov(Module):
return results
except sqlite3.Error as e:
db.close()
print('sqlite error: ' + str(e))
print('sqlite error in Markov._get_chatter_targets: ' + str(e))
raise
def _get_one_chatter_target(self):
@ -575,7 +508,31 @@ class Markov(Module):
return None
except sqlite3.Error as e:
db.close()
print('sqlite error: ' + str(e))
print('sqlite error in Markov._get_max_chain_id: ' + str(e))
raise
def _get_context_id_for_target(self, target):
"""Get the context ID for the desired/input target."""
try:
db = self.get_db()
query = '''
SELECT mc.id FROM markov_context mc
INNER JOIN markov_target_to_context_map mt
ON mt.context_id = mc.id
WHERE mt.target = ?
'''
cursor = db.execute(query, (target,))
result = cursor.fetchone()
db.close()
if result:
return result['id']
else:
return None
except sqlite3.Error as e:
db.close()
print('sqlite error in Markov._get_context_id_for_target: ' + str(e))
raise
# vi:tabstop=4:expandtab:autoindent

View File

@ -21,16 +21,16 @@ import os
import sqlite3
import sys
parser = argparse.ArgumentParser(description='Import lines into the specified context.')
parser.add_argument('context', metavar='CONTEXT', type=str, nargs=1)
parser = argparse.ArgumentParser(description='Import lines into the specified context_id.')
parser.add_argument('context_id', metavar='CONTEXT', type=int, nargs=1)
args = parser.parse_args()
print(args.context[0])
print(args.context_id[0])
db = sqlite3.connect('dr.botzo.data')
db.row_factory = sqlite3.Row
cur = db.cursor()
statement = 'INSERT INTO markov_chain (k1, k2, v, context) VALUES (?, ?, ?, ?)'
statement = 'INSERT INTO markov_chain (k1, k2, v, context_id) VALUES (?, ?, ?, ?)'
for line in sys.stdin:
# set up the head of the chain
w1 = '__start1'
@ -39,7 +39,7 @@ for line in sys.stdin:
# for each word pair, add the next word to the dictionary
for word in line.split():
try:
cur.execute(statement, (w1.decode('utf-8', 'replace').lower(), w2.decode('utf-8', 'replace').lower(), word.decode('utf-8', 'replace').lower(), args.context[0]))
cur.execute(statement, (w1.decode('utf-8', 'replace'), w2.decode('utf-8', 'replace'), word.decode('utf-8', 'replace'), args.context_id[0]))
except sqlite3.Error as e:
db.rollback()
print("sqlite error: " + str(e))
@ -48,7 +48,7 @@ for line in sys.stdin:
w1, w2 = w2, word
try:
cur.execute(statement, (w1.decode('utf-8', 'replace').lower(), w2.decode('utf-8', 'replace').lower(), '__stop', args.context[0]))
cur.execute(statement, (w1.decode('utf-8', 'replace'), w2.decode('utf-8', 'replace'), '__stop', args.context_id[0]))
db.commit()
except sqlite3.Error as e:
db.rollback()