From 26bc8bec347383a369be857896b24e37aa043cae Mon Sep 17 00:00:00 2001 From: "Brian S. Stephan" Date: Tue, 28 Feb 2012 23:23:14 -0600 Subject: [PATCH] Markov: rebuild the tables, use the context stuff in a better fashion this time the module will drop your old tables if you have them, so if there's data there, be sure to back them up and figure out some migration strategy (probably annoying and probably having to script it). the big change is that each line is associated to a context now, and channels are also associated to contexts. this should allow for a better partitioning of multiple brains, and changing which channels point to which brain. also caught in the wake is some additional logging verbosity, and a change to no longer lower() everything learned. the script to dump a file into the database has also been updated with the above changes --- modules/Markov.py | 235 +++++++++-------------- scripts/import-file-into-markov_chain.py | 12 +- 2 files changed, 102 insertions(+), 145 deletions(-) diff --git a/modules/Markov.py b/modules/Markov.py index 3e1fd2c..51287ff 100644 --- a/modules/Markov.py +++ b/modules/Markov.py @@ -76,39 +76,33 @@ class Markov(Module): """Create the markov chain table.""" version = self.db_module_registered(self.__class__.__name__) - if (version == None): + if (version == None or version < 9): db = self.get_db() try: - db.execute(''' - CREATE TABLE markov_chain ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - k1 TEXT NOT NULL, - k2 TEXT NOT NULL, - v TEXT NOT NULL - )''') - sql = 'INSERT INTO drbotzo_modules VALUES (?,?)' - db.execute(sql, (self.__class__.__name__, 1)) - db.commit() - db.close() - version = 1 + version = 9 + + # recreating the tables, since i need to add some foreign key constraints + db.execute('''DROP INDEX IF EXISTS markov_chain_keys_and_context_index''') + db.execute('''DROP INDEX IF EXISTS markov_chain_keys_index''') + db.execute('''DROP INDEX IF EXISTS markov_chain_value_and_context_index''') + db.execute('''DROP TABLE IF EXISTS markov_chain''') + db.execute('''DROP TABLE IF EXISTS markov_target_to_context_map''') + db.execute('''DROP TABLE IF EXISTS markov_chatter_target''') + db.execute('''DROP TABLE IF EXISTS markov_context''') - self._learn_line('') - except sqlite3.Error as e: - db.rollback() - db.close() - print("sqlite error: " + str(e)) - raise - if (version < 2): - db = self.get_db() - try: db.execute(''' - ALTER TABLE markov_chain - ADD COLUMN context TEXT DEFAULT NULL''') + CREATE TABLE markov_chatter_target ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + target TEXT NOT NULL, + chance INTEGER NOT NULL DEFAULT 99999 + )''') + db.execute(''' CREATE TABLE markov_context ( id INTEGER PRIMARY KEY AUTOINCREMENT, context TEXT NOT NULL )''') + db.execute(''' CREATE TABLE markov_target_to_context_map ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -116,101 +110,30 @@ class Markov(Module): context_id INTEGER NOT NULL, FOREIGN KEY(context_id) REFERENCES markov_context(id) )''') - db.execute('UPDATE drbotzo_modules SET version = ? WHERE module = ?', - (2, self.__class__.__name__)) - db.commit() - db.close() - version = 2 - except sqlite3.Error as e: - db.rollback() - db.close() - print('sqlite error: ' + str(e)) - raise - if (version < 3): - db = self.get_db() - try: + db.execute(''' - CREATE INDEX markov_chain_keys_index - ON markov_chain (k1, k2)''') - db.execute('UPDATE drbotzo_modules SET version = ? WHERE module = ?', - (3, self.__class__.__name__)) - db.commit() - db.close() - version = 3 - except sqlite3.Error as e: - db.rollback() - db.close() - print('sqlite error: ' + str(e)) - raise - if (version < 4): - db = self.get_db() - try: - db.execute('UPDATE drbotzo_modules SET version = ? WHERE module = ?', - (4, self.__class__.__name__)) - db.commit() - db.close() - version = 4 - except sqlite3.Error as e: - db.rollback() - db.close() - print('sqlite error: ' + str(e)) - raise - if (version < 5): - db = self.get_db() - try: - version = 5 - db.execute(''' - CREATE TABLE markov_chatter_target ( + CREATE TABLE markov_chain ( id INTEGER PRIMARY KEY AUTOINCREMENT, - target TEXT NOT NULL + k1 TEXT NOT NULL, + k2 TEXT NOT NULL, + v TEXT NOT NULL, + context_id INTEGER DEFAULT NULL, + FOREIGN KEY(context_id) REFERENCES markov_context(id) )''') - db.commit() - db.close() - self.db_register_module_version(self.__class__.__name__, version) - except sqlite3.Error as e: - db.rollback() - db.close() - print('sqlite error: ' + str(e)) - raise - if (version < 6): - db = self.get_db() - try: - version = 6 + db.execute(''' - CREATE INDEX markov_chain_keys_and_context_index - ON markov_chain (k1, k2, context)''') - db.commit() - db.close() - self.db_register_module_version(self.__class__.__name__, version) - except sqlite3.Error as e: - db.rollback() - db.close() - print('sqlite error: ' + str(e)) - raise - if (version < 7): - db = self.get_db() - try: - version = 7 + CREATE INDEX markov_chain_keys_and_context_id_index + ON markov_chain (k1, k2, context_id)''') + db.execute(''' - ALTER TABLE markov_chatter_target ADD COLUMN chance INTEGER NOT NULL DEFAULT 99999''') - db.commit() - db.close() - self.db_register_module_version(self.__class__.__name__, version) - except sqlite3.Error as e: - db.rollback() - db.close() - print('sqlite error: ' + str(e)) - raise - if (version < 8): - db = self.get_db() - try: - version = 8 - db.execute(''' - CREATE INDEX markov_chain_value_and_context_index - ON markov_chain (v, context)''') + CREATE INDEX markov_chain_value_and_context_id_index + ON markov_chain (v, context_id)''') + db.commit() db.close() self.db_register_module_version(self.__class__.__name__, version) + + self._learn_line('','') except sqlite3.Error as e: db.rollback() db.close() @@ -255,11 +178,14 @@ class Markov(Module): target = event.target() if self.trainre.search(what): - return self.reply(connection, event, self.markov_train(connection, event, nick, userhost, what, admin_unlocked)) + return self.reply(connection, event, self.markov_train(connection, event, nick, + userhost, what, admin_unlocked)) elif self.learnre.search(what): - return self.reply(connection, event, self.markov_learn(connection, event, nick, userhost, what, admin_unlocked)) + return self.reply(connection, event, self.markov_learn(connection, event, nick, + userhost, what, admin_unlocked)) elif self.replyre.search(what) and not self.shut_up: - return self.reply(connection, event, self.markov_reply(connection, event, nick, userhost, what, admin_unlocked)) + return self.reply(connection, event, self.markov_reply(connection, event, nick, + userhost, what, admin_unlocked)) if not self.shut_up: # not a command, so see if i'm being mentioned @@ -269,7 +195,8 @@ class Markov(Module): if addressed_re.match(what): # i was addressed directly, so respond, addressing the speaker self.lines_seen.append(('.self.said.', datetime.now())) - return self.reply(connection, event, '{0:s}: {1:s}'.format(nick, self._generate_line(target, line=addressed_re.match(what).group(1)))) + return self.reply(connection, event, '{0:s}: {1:s}'.format(nick, + self._generate_line(target, line=addressed_re.match(what).group(1)))) else: # i wasn't addressed directly, so just respond self.lines_seen.append(('.self.said.', datetime.now())) @@ -365,7 +292,8 @@ class Markov(Module): for (nick,then) in self.lines_seen: rdelta = relativedelta(datetime.now(), then) - if rdelta.years == 0 and rdelta.months == 0 and rdelta.days == 0 and rdelta.hours == 0 and rdelta.minutes == 0 and rdelta.seconds <= 29: + if (rdelta.years == 0 and rdelta.months == 0 and rdelta.days == 0 and + rdelta.hours == 0 and rdelta.minutes == 0 and rdelta.seconds <= 29): last_30_sec_lines.append((nick,then)) if len(last_30_sec_lines) >= 8: @@ -374,7 +302,8 @@ class Markov(Module): self.shut_up = True targets = self._get_chatter_targets() for t in targets: - self.sendmsg(self.connection, t['target'], 'shutting up for 30 seconds due to last 30 seconds of activity') + self.sendmsg(self.connection, t['target'], + 'shutting up for 30 seconds due to last 30 seconds of activity') def _learn_line(self, line, target): """Create Markov chains from the provided line.""" @@ -383,10 +312,10 @@ class Markov(Module): k1 = self.start1 k2 = self.start2 - context = target + context_id = self._get_context_id_for_target(target) - # if there's no context, this is probably a sub-command. don't learn it - if context: + # if there's no target, this is probably a sub-command. don't learn it + if target: words = line.split() if len(words) <= 0: @@ -395,18 +324,20 @@ class Markov(Module): try: db = self.get_db() cur = db.cursor() - statement = 'INSERT INTO markov_chain (k1, k2, v, context) VALUES (?, ?, ?, ?)' + statement = 'INSERT INTO markov_chain (k1, k2, v, context_id) VALUES (?, ?, ?, ?)' for word in words: - cur.execute(statement, (k1.decode('utf-8', 'replace').lower(), k2.decode('utf-8', 'replace').lower(), word.decode('utf-8', 'replace').lower(), context)) + cur.execute(statement, (k1.decode('utf-8', 'replace'), + k2.decode('utf-8', 'replace'), word.decode('utf-8', 'replace'), context_id)) k1, k2 = k2, word - cur.execute(statement, (k1.decode('utf-8', 'replace').lower(), k2.decode('utf-8', 'replace').lower(), self.stop, context)) + cur.execute(statement, (k1.decode('utf-8', 'replace'), + k2.decode('utf-8', 'replace'), self.stop, context_id)) db.commit() db.close() except sqlite3.Error as e: db.rollback() db.close() - print("sqlite error: " + str(e)) + print("sqlite error in Markov._learn_line: " + str(e)) raise def _generate_line(self, target, line='', min_size=15, max_size=100): @@ -427,7 +358,7 @@ class Markov(Module): words = line.split() target_word = words[random.randint(0, len(words)-1)] - context = target + context_id = self._get_context_id_for_target(target) # start with an empty chain, and work from there gen_words = [self.start1, self.start2] @@ -438,7 +369,7 @@ class Markov(Module): # we'll just pick a word and work backwards if gen_words[-1] == self.start2 and target_word != '': working_backwards = [] - key_hits = self._retrieve_k2_for_value(target_word, context) + key_hits = self._retrieve_k2_for_value(target_word, context_id) if len(key_hits): working_backwards.append(target_word) # generate new word @@ -446,7 +377,7 @@ class Markov(Module): target_word = words[random.randint(0, len(words)-1)] # work backwards until we randomly bump into a start while True: - key_hits = self._retrieve_k2_for_value(working_backwards[0], context) + key_hits = self._retrieve_k2_for_value(working_backwards[0], context_id) if target_word in key_hits: found_word = target_word # generate new word @@ -464,7 +395,7 @@ class Markov(Module): else: working_backwards.insert(0, found_word) - key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context) + key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context_id) # use the chain that includes the target word, if it is found if target_word != '' and target_word in key_hits: gen_words.append(target_word) @@ -487,7 +418,7 @@ class Markov(Module): return ' '.join(gen_words).encode('utf-8', 'ignore') - def _retrieve_chains_for_key(self, k1, k2, context): + def _retrieve_chains_for_key(self, k1, k2, context_id): """Get the value(s) for a given key (a pair of strings).""" values = [] @@ -499,10 +430,12 @@ class Markov(Module): # a faster fashion than selecting all starts max_id = self._get_max_chain_id() rand_id = random.randint(1,max_id) - query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND (context = ? OR context IS NULL) AND id >= {0:d} LIMIT 1'.format(rand_id) + query = ('SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND ' + '(context_id = ? OR context_id IS NULL) AND id >= {0:d} LIMIT 1'.format(rand_id)) else: - query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND (context = ? OR context IS NULL)' - cursor = db.execute(query, (k1,k2,context)) + query = ('SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND ' + '(context_id = ? OR context_id IS NULL)') + cursor = db.execute(query, (k1,k2,context_id)) results = cursor.fetchall() for result in results: @@ -512,17 +445,17 @@ class Markov(Module): return values except sqlite3.Error as e: db.close() - print('sqlite error: ' + str(e)) + print('sqlite error in Markov._retrieve_chains_for_key: ' + str(e)) raise - def _retrieve_k2_for_value(self, v, context): + def _retrieve_k2_for_value(self, v, context_id): """Get the value(s) for a given key (a pair of strings).""" values = [] try: db = self.get_db() - query = 'SELECT k2 FROM markov_chain WHERE v = ? AND (context = ? OR context IS NULL)' - cursor = db.execute(query, (v,context)) + query = 'SELECT k2 FROM markov_chain WHERE v = ? AND (context_id = ? OR context_id IS NULL)' + cursor = db.execute(query, (v,context_id)) results = cursor.fetchall() for result in results: @@ -532,7 +465,7 @@ class Markov(Module): return values except sqlite3.Error as e: db.close() - print('sqlite error: ' + str(e)) + print('sqlite error in Markov._retrieve_k2_for_value: ' + str(e)) raise def _get_chatter_targets(self): @@ -548,7 +481,7 @@ class Markov(Module): return results except sqlite3.Error as e: db.close() - print('sqlite error: ' + str(e)) + print('sqlite error in Markov._get_chatter_targets: ' + str(e)) raise def _get_one_chatter_target(self): @@ -575,7 +508,31 @@ class Markov(Module): return None except sqlite3.Error as e: db.close() - print('sqlite error: ' + str(e)) + print('sqlite error in Markov._get_max_chain_id: ' + str(e)) + raise + + def _get_context_id_for_target(self, target): + + """Get the context ID for the desired/input target.""" + + try: + db = self.get_db() + query = ''' + SELECT mc.id FROM markov_context mc + INNER JOIN markov_target_to_context_map mt + ON mt.context_id = mc.id + WHERE mt.target = ? + ''' + cursor = db.execute(query, (target,)) + result = cursor.fetchone() + db.close() + if result: + return result['id'] + else: + return None + except sqlite3.Error as e: + db.close() + print('sqlite error in Markov._get_context_id_for_target: ' + str(e)) raise # vi:tabstop=4:expandtab:autoindent diff --git a/scripts/import-file-into-markov_chain.py b/scripts/import-file-into-markov_chain.py index 8f415c9..e55e048 100644 --- a/scripts/import-file-into-markov_chain.py +++ b/scripts/import-file-into-markov_chain.py @@ -21,16 +21,16 @@ import os import sqlite3 import sys -parser = argparse.ArgumentParser(description='Import lines into the specified context.') -parser.add_argument('context', metavar='CONTEXT', type=str, nargs=1) +parser = argparse.ArgumentParser(description='Import lines into the specified context_id.') +parser.add_argument('context_id', metavar='CONTEXT', type=int, nargs=1) args = parser.parse_args() -print(args.context[0]) +print(args.context_id[0]) db = sqlite3.connect('dr.botzo.data') db.row_factory = sqlite3.Row cur = db.cursor() -statement = 'INSERT INTO markov_chain (k1, k2, v, context) VALUES (?, ?, ?, ?)' +statement = 'INSERT INTO markov_chain (k1, k2, v, context_id) VALUES (?, ?, ?, ?)' for line in sys.stdin: # set up the head of the chain w1 = '__start1' @@ -39,7 +39,7 @@ for line in sys.stdin: # for each word pair, add the next word to the dictionary for word in line.split(): try: - cur.execute(statement, (w1.decode('utf-8', 'replace').lower(), w2.decode('utf-8', 'replace').lower(), word.decode('utf-8', 'replace').lower(), args.context[0])) + cur.execute(statement, (w1.decode('utf-8', 'replace'), w2.decode('utf-8', 'replace'), word.decode('utf-8', 'replace'), args.context_id[0])) except sqlite3.Error as e: db.rollback() print("sqlite error: " + str(e)) @@ -48,7 +48,7 @@ for line in sys.stdin: w1, w2 = w2, word try: - cur.execute(statement, (w1.decode('utf-8', 'replace').lower(), w2.decode('utf-8', 'replace').lower(), '__stop', args.context[0])) + cur.execute(statement, (w1.decode('utf-8', 'replace'), w2.decode('utf-8', 'replace'), '__stop', args.context_id[0])) db.commit() except sqlite3.Error as e: db.rollback()