Markov: rebuild the tables, use the context stuff in a better fashion this time

the module will drop your old markov tables (and everything in them) if you have
them, so if there's data there, be sure to back it up and figure out a migration
strategy first (probably annoying, and probably requiring a script).

the big change is that each learned line is now associated with a context, and
channels are also associated with contexts. this should allow for better
partitioning of multiple brains, and for changing which channels point to which brain.

also caught in the wake: some additional logging verbosity (sqlite error messages
now name the method they came from), and learned text is no longer lower()ed.

the script to dump a file into the database has also been updated with the above
changes: it now takes an integer context id and no longer lowercases its input.
This commit is contained in:
Brian S. Stephan 2012-02-28 23:23:14 -06:00
parent 79ddce0bcb
commit 26bc8bec34
2 changed files with 102 additions and 145 deletions

View File

@ -76,39 +76,33 @@ class Markov(Module):
"""Create the markov chain table.""" """Create the markov chain table."""
version = self.db_module_registered(self.__class__.__name__) version = self.db_module_registered(self.__class__.__name__)
if (version == None): if (version == None or version < 9):
db = self.get_db() db = self.get_db()
try: try:
db.execute(''' version = 9
CREATE TABLE markov_chain (
id INTEGER PRIMARY KEY AUTOINCREMENT, # recreating the tables, since i need to add some foreign key constraints
k1 TEXT NOT NULL, db.execute('''DROP INDEX IF EXISTS markov_chain_keys_and_context_index''')
k2 TEXT NOT NULL, db.execute('''DROP INDEX IF EXISTS markov_chain_keys_index''')
v TEXT NOT NULL db.execute('''DROP INDEX IF EXISTS markov_chain_value_and_context_index''')
)''') db.execute('''DROP TABLE IF EXISTS markov_chain''')
sql = 'INSERT INTO drbotzo_modules VALUES (?,?)' db.execute('''DROP TABLE IF EXISTS markov_target_to_context_map''')
db.execute(sql, (self.__class__.__name__, 1)) db.execute('''DROP TABLE IF EXISTS markov_chatter_target''')
db.commit() db.execute('''DROP TABLE IF EXISTS markov_context''')
db.close()
version = 1
self._learn_line('')
except sqlite3.Error as e:
db.rollback()
db.close()
print("sqlite error: " + str(e))
raise
if (version < 2):
db = self.get_db()
try:
db.execute(''' db.execute('''
ALTER TABLE markov_chain CREATE TABLE markov_chatter_target (
ADD COLUMN context TEXT DEFAULT NULL''') id INTEGER PRIMARY KEY AUTOINCREMENT,
target TEXT NOT NULL,
chance INTEGER NOT NULL DEFAULT 99999
)''')
db.execute(''' db.execute('''
CREATE TABLE markov_context ( CREATE TABLE markov_context (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
context TEXT NOT NULL context TEXT NOT NULL
)''') )''')
db.execute(''' db.execute('''
CREATE TABLE markov_target_to_context_map ( CREATE TABLE markov_target_to_context_map (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
@ -116,101 +110,30 @@ class Markov(Module):
context_id INTEGER NOT NULL, context_id INTEGER NOT NULL,
FOREIGN KEY(context_id) REFERENCES markov_context(id) FOREIGN KEY(context_id) REFERENCES markov_context(id)
)''') )''')
db.execute('UPDATE drbotzo_modules SET version = ? WHERE module = ?',
(2, self.__class__.__name__))
db.commit()
db.close()
version = 2
except sqlite3.Error as e:
db.rollback()
db.close()
print('sqlite error: ' + str(e))
raise
if (version < 3):
db = self.get_db()
try:
db.execute(''' db.execute('''
CREATE INDEX markov_chain_keys_index CREATE TABLE markov_chain (
ON markov_chain (k1, k2)''')
db.execute('UPDATE drbotzo_modules SET version = ? WHERE module = ?',
(3, self.__class__.__name__))
db.commit()
db.close()
version = 3
except sqlite3.Error as e:
db.rollback()
db.close()
print('sqlite error: ' + str(e))
raise
if (version < 4):
db = self.get_db()
try:
db.execute('UPDATE drbotzo_modules SET version = ? WHERE module = ?',
(4, self.__class__.__name__))
db.commit()
db.close()
version = 4
except sqlite3.Error as e:
db.rollback()
db.close()
print('sqlite error: ' + str(e))
raise
if (version < 5):
db = self.get_db()
try:
version = 5
db.execute('''
CREATE TABLE markov_chatter_target (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
target TEXT NOT NULL k1 TEXT NOT NULL,
k2 TEXT NOT NULL,
v TEXT NOT NULL,
context_id INTEGER DEFAULT NULL,
FOREIGN KEY(context_id) REFERENCES markov_context(id)
)''') )''')
db.commit()
db.close()
self.db_register_module_version(self.__class__.__name__, version)
except sqlite3.Error as e:
db.rollback()
db.close()
print('sqlite error: ' + str(e))
raise
if (version < 6):
db = self.get_db()
try:
version = 6
db.execute(''' db.execute('''
CREATE INDEX markov_chain_keys_and_context_index CREATE INDEX markov_chain_keys_and_context_id_index
ON markov_chain (k1, k2, context)''') ON markov_chain (k1, k2, context_id)''')
db.commit()
db.close()
self.db_register_module_version(self.__class__.__name__, version)
except sqlite3.Error as e:
db.rollback()
db.close()
print('sqlite error: ' + str(e))
raise
if (version < 7):
db = self.get_db()
try:
version = 7
db.execute(''' db.execute('''
ALTER TABLE markov_chatter_target ADD COLUMN chance INTEGER NOT NULL DEFAULT 99999''') CREATE INDEX markov_chain_value_and_context_id_index
db.commit() ON markov_chain (v, context_id)''')
db.close()
self.db_register_module_version(self.__class__.__name__, version)
except sqlite3.Error as e:
db.rollback()
db.close()
print('sqlite error: ' + str(e))
raise
if (version < 8):
db = self.get_db()
try:
version = 8
db.execute('''
CREATE INDEX markov_chain_value_and_context_index
ON markov_chain (v, context)''')
db.commit() db.commit()
db.close() db.close()
self.db_register_module_version(self.__class__.__name__, version) self.db_register_module_version(self.__class__.__name__, version)
self._learn_line('','')
except sqlite3.Error as e: except sqlite3.Error as e:
db.rollback() db.rollback()
db.close() db.close()
@ -255,11 +178,14 @@ class Markov(Module):
target = event.target() target = event.target()
if self.trainre.search(what): if self.trainre.search(what):
return self.reply(connection, event, self.markov_train(connection, event, nick, userhost, what, admin_unlocked)) return self.reply(connection, event, self.markov_train(connection, event, nick,
userhost, what, admin_unlocked))
elif self.learnre.search(what): elif self.learnre.search(what):
return self.reply(connection, event, self.markov_learn(connection, event, nick, userhost, what, admin_unlocked)) return self.reply(connection, event, self.markov_learn(connection, event, nick,
userhost, what, admin_unlocked))
elif self.replyre.search(what) and not self.shut_up: elif self.replyre.search(what) and not self.shut_up:
return self.reply(connection, event, self.markov_reply(connection, event, nick, userhost, what, admin_unlocked)) return self.reply(connection, event, self.markov_reply(connection, event, nick,
userhost, what, admin_unlocked))
if not self.shut_up: if not self.shut_up:
# not a command, so see if i'm being mentioned # not a command, so see if i'm being mentioned
@ -269,7 +195,8 @@ class Markov(Module):
if addressed_re.match(what): if addressed_re.match(what):
# i was addressed directly, so respond, addressing the speaker # i was addressed directly, so respond, addressing the speaker
self.lines_seen.append(('.self.said.', datetime.now())) self.lines_seen.append(('.self.said.', datetime.now()))
return self.reply(connection, event, '{0:s}: {1:s}'.format(nick, self._generate_line(target, line=addressed_re.match(what).group(1)))) return self.reply(connection, event, '{0:s}: {1:s}'.format(nick,
self._generate_line(target, line=addressed_re.match(what).group(1))))
else: else:
# i wasn't addressed directly, so just respond # i wasn't addressed directly, so just respond
self.lines_seen.append(('.self.said.', datetime.now())) self.lines_seen.append(('.self.said.', datetime.now()))
@ -365,7 +292,8 @@ class Markov(Module):
for (nick,then) in self.lines_seen: for (nick,then) in self.lines_seen:
rdelta = relativedelta(datetime.now(), then) rdelta = relativedelta(datetime.now(), then)
if rdelta.years == 0 and rdelta.months == 0 and rdelta.days == 0 and rdelta.hours == 0 and rdelta.minutes == 0 and rdelta.seconds <= 29: if (rdelta.years == 0 and rdelta.months == 0 and rdelta.days == 0 and
rdelta.hours == 0 and rdelta.minutes == 0 and rdelta.seconds <= 29):
last_30_sec_lines.append((nick,then)) last_30_sec_lines.append((nick,then))
if len(last_30_sec_lines) >= 8: if len(last_30_sec_lines) >= 8:
@ -374,7 +302,8 @@ class Markov(Module):
self.shut_up = True self.shut_up = True
targets = self._get_chatter_targets() targets = self._get_chatter_targets()
for t in targets: for t in targets:
self.sendmsg(self.connection, t['target'], 'shutting up for 30 seconds due to last 30 seconds of activity') self.sendmsg(self.connection, t['target'],
'shutting up for 30 seconds due to last 30 seconds of activity')
def _learn_line(self, line, target): def _learn_line(self, line, target):
"""Create Markov chains from the provided line.""" """Create Markov chains from the provided line."""
@ -383,10 +312,10 @@ class Markov(Module):
k1 = self.start1 k1 = self.start1
k2 = self.start2 k2 = self.start2
context = target context_id = self._get_context_id_for_target(target)
# if there's no context, this is probably a sub-command. don't learn it # if there's no target, this is probably a sub-command. don't learn it
if context: if target:
words = line.split() words = line.split()
if len(words) <= 0: if len(words) <= 0:
@ -395,18 +324,20 @@ class Markov(Module):
try: try:
db = self.get_db() db = self.get_db()
cur = db.cursor() cur = db.cursor()
statement = 'INSERT INTO markov_chain (k1, k2, v, context) VALUES (?, ?, ?, ?)' statement = 'INSERT INTO markov_chain (k1, k2, v, context_id) VALUES (?, ?, ?, ?)'
for word in words: for word in words:
cur.execute(statement, (k1.decode('utf-8', 'replace').lower(), k2.decode('utf-8', 'replace').lower(), word.decode('utf-8', 'replace').lower(), context)) cur.execute(statement, (k1.decode('utf-8', 'replace'),
k2.decode('utf-8', 'replace'), word.decode('utf-8', 'replace'), context_id))
k1, k2 = k2, word k1, k2 = k2, word
cur.execute(statement, (k1.decode('utf-8', 'replace').lower(), k2.decode('utf-8', 'replace').lower(), self.stop, context)) cur.execute(statement, (k1.decode('utf-8', 'replace'),
k2.decode('utf-8', 'replace'), self.stop, context_id))
db.commit() db.commit()
db.close() db.close()
except sqlite3.Error as e: except sqlite3.Error as e:
db.rollback() db.rollback()
db.close() db.close()
print("sqlite error: " + str(e)) print("sqlite error in Markov._learn_line: " + str(e))
raise raise
def _generate_line(self, target, line='', min_size=15, max_size=100): def _generate_line(self, target, line='', min_size=15, max_size=100):
@ -427,7 +358,7 @@ class Markov(Module):
words = line.split() words = line.split()
target_word = words[random.randint(0, len(words)-1)] target_word = words[random.randint(0, len(words)-1)]
context = target context_id = self._get_context_id_for_target(target)
# start with an empty chain, and work from there # start with an empty chain, and work from there
gen_words = [self.start1, self.start2] gen_words = [self.start1, self.start2]
@ -438,7 +369,7 @@ class Markov(Module):
# we'll just pick a word and work backwards # we'll just pick a word and work backwards
if gen_words[-1] == self.start2 and target_word != '': if gen_words[-1] == self.start2 and target_word != '':
working_backwards = [] working_backwards = []
key_hits = self._retrieve_k2_for_value(target_word, context) key_hits = self._retrieve_k2_for_value(target_word, context_id)
if len(key_hits): if len(key_hits):
working_backwards.append(target_word) working_backwards.append(target_word)
# generate new word # generate new word
@ -446,7 +377,7 @@ class Markov(Module):
target_word = words[random.randint(0, len(words)-1)] target_word = words[random.randint(0, len(words)-1)]
# work backwards until we randomly bump into a start # work backwards until we randomly bump into a start
while True: while True:
key_hits = self._retrieve_k2_for_value(working_backwards[0], context) key_hits = self._retrieve_k2_for_value(working_backwards[0], context_id)
if target_word in key_hits: if target_word in key_hits:
found_word = target_word found_word = target_word
# generate new word # generate new word
@ -464,7 +395,7 @@ class Markov(Module):
else: else:
working_backwards.insert(0, found_word) working_backwards.insert(0, found_word)
key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context) key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context_id)
# use the chain that includes the target word, if it is found # use the chain that includes the target word, if it is found
if target_word != '' and target_word in key_hits: if target_word != '' and target_word in key_hits:
gen_words.append(target_word) gen_words.append(target_word)
@ -487,7 +418,7 @@ class Markov(Module):
return ' '.join(gen_words).encode('utf-8', 'ignore') return ' '.join(gen_words).encode('utf-8', 'ignore')
def _retrieve_chains_for_key(self, k1, k2, context): def _retrieve_chains_for_key(self, k1, k2, context_id):
"""Get the value(s) for a given key (a pair of strings).""" """Get the value(s) for a given key (a pair of strings)."""
values = [] values = []
@ -499,10 +430,12 @@ class Markov(Module):
# a faster fashion than selecting all starts # a faster fashion than selecting all starts
max_id = self._get_max_chain_id() max_id = self._get_max_chain_id()
rand_id = random.randint(1,max_id) rand_id = random.randint(1,max_id)
query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND (context = ? OR context IS NULL) AND id >= {0:d} LIMIT 1'.format(rand_id) query = ('SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND '
'(context_id = ? OR context_id IS NULL) AND id >= {0:d} LIMIT 1'.format(rand_id))
else: else:
query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND (context = ? OR context IS NULL)' query = ('SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND '
cursor = db.execute(query, (k1,k2,context)) '(context_id = ? OR context_id IS NULL)')
cursor = db.execute(query, (k1,k2,context_id))
results = cursor.fetchall() results = cursor.fetchall()
for result in results: for result in results:
@ -512,17 +445,17 @@ class Markov(Module):
return values return values
except sqlite3.Error as e: except sqlite3.Error as e:
db.close() db.close()
print('sqlite error: ' + str(e)) print('sqlite error in Markov._retrieve_chains_for_key: ' + str(e))
raise raise
def _retrieve_k2_for_value(self, v, context): def _retrieve_k2_for_value(self, v, context_id):
"""Get the value(s) for a given key (a pair of strings).""" """Get the value(s) for a given key (a pair of strings)."""
values = [] values = []
try: try:
db = self.get_db() db = self.get_db()
query = 'SELECT k2 FROM markov_chain WHERE v = ? AND (context = ? OR context IS NULL)' query = 'SELECT k2 FROM markov_chain WHERE v = ? AND (context_id = ? OR context_id IS NULL)'
cursor = db.execute(query, (v,context)) cursor = db.execute(query, (v,context_id))
results = cursor.fetchall() results = cursor.fetchall()
for result in results: for result in results:
@ -532,7 +465,7 @@ class Markov(Module):
return values return values
except sqlite3.Error as e: except sqlite3.Error as e:
db.close() db.close()
print('sqlite error: ' + str(e)) print('sqlite error in Markov._retrieve_k2_for_value: ' + str(e))
raise raise
def _get_chatter_targets(self): def _get_chatter_targets(self):
@ -548,7 +481,7 @@ class Markov(Module):
return results return results
except sqlite3.Error as e: except sqlite3.Error as e:
db.close() db.close()
print('sqlite error: ' + str(e)) print('sqlite error in Markov._get_chatter_targets: ' + str(e))
raise raise
def _get_one_chatter_target(self): def _get_one_chatter_target(self):
@ -575,7 +508,31 @@ class Markov(Module):
return None return None
except sqlite3.Error as e: except sqlite3.Error as e:
db.close() db.close()
print('sqlite error: ' + str(e)) print('sqlite error in Markov._get_max_chain_id: ' + str(e))
raise
def _get_context_id_for_target(self, target):
"""Get the context ID for the desired/input target."""
try:
db = self.get_db()
query = '''
SELECT mc.id FROM markov_context mc
INNER JOIN markov_target_to_context_map mt
ON mt.context_id = mc.id
WHERE mt.target = ?
'''
cursor = db.execute(query, (target,))
result = cursor.fetchone()
db.close()
if result:
return result['id']
else:
return None
except sqlite3.Error as e:
db.close()
print('sqlite error in Markov._get_context_id_for_target: ' + str(e))
raise raise
# vi:tabstop=4:expandtab:autoindent # vi:tabstop=4:expandtab:autoindent

View File

@ -21,16 +21,16 @@ import os
import sqlite3 import sqlite3
import sys import sys
parser = argparse.ArgumentParser(description='Import lines into the specified context.') parser = argparse.ArgumentParser(description='Import lines into the specified context_id.')
parser.add_argument('context', metavar='CONTEXT', type=str, nargs=1) parser.add_argument('context_id', metavar='CONTEXT', type=int, nargs=1)
args = parser.parse_args() args = parser.parse_args()
print(args.context[0]) print(args.context_id[0])
db = sqlite3.connect('dr.botzo.data') db = sqlite3.connect('dr.botzo.data')
db.row_factory = sqlite3.Row db.row_factory = sqlite3.Row
cur = db.cursor() cur = db.cursor()
statement = 'INSERT INTO markov_chain (k1, k2, v, context) VALUES (?, ?, ?, ?)' statement = 'INSERT INTO markov_chain (k1, k2, v, context_id) VALUES (?, ?, ?, ?)'
for line in sys.stdin: for line in sys.stdin:
# set up the head of the chain # set up the head of the chain
w1 = '__start1' w1 = '__start1'
@ -39,7 +39,7 @@ for line in sys.stdin:
# for each word pair, add the next word to the dictionary # for each word pair, add the next word to the dictionary
for word in line.split(): for word in line.split():
try: try:
cur.execute(statement, (w1.decode('utf-8', 'replace').lower(), w2.decode('utf-8', 'replace').lower(), word.decode('utf-8', 'replace').lower(), args.context[0])) cur.execute(statement, (w1.decode('utf-8', 'replace'), w2.decode('utf-8', 'replace'), word.decode('utf-8', 'replace'), args.context_id[0]))
except sqlite3.Error as e: except sqlite3.Error as e:
db.rollback() db.rollback()
print("sqlite error: " + str(e)) print("sqlite error: " + str(e))
@ -48,7 +48,7 @@ for line in sys.stdin:
w1, w2 = w2, word w1, w2 = w2, word
try: try:
cur.execute(statement, (w1.decode('utf-8', 'replace').lower(), w2.decode('utf-8', 'replace').lower(), '__stop', args.context[0])) cur.execute(statement, (w1.decode('utf-8', 'replace'), w2.decode('utf-8', 'replace'), '__stop', args.context_id[0]))
db.commit() db.commit()
except sqlite3.Error as e: except sqlite3.Error as e:
db.rollback() db.rollback()