diff --git a/modules/Markov.py b/modules/Markov.py index 3e1fd2c..51287ff 100644 --- a/modules/Markov.py +++ b/modules/Markov.py @@ -76,39 +76,33 @@ class Markov(Module): """Create the markov chain table.""" version = self.db_module_registered(self.__class__.__name__) - if (version == None): + if (version == None or version < 9): db = self.get_db() try: - db.execute(''' - CREATE TABLE markov_chain ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - k1 TEXT NOT NULL, - k2 TEXT NOT NULL, - v TEXT NOT NULL - )''') - sql = 'INSERT INTO drbotzo_modules VALUES (?,?)' - db.execute(sql, (self.__class__.__name__, 1)) - db.commit() - db.close() - version = 1 + version = 9 + + # recreating the tables, since i need to add some foreign key constraints + db.execute('''DROP INDEX IF EXISTS markov_chain_keys_and_context_index''') + db.execute('''DROP INDEX IF EXISTS markov_chain_keys_index''') + db.execute('''DROP INDEX IF EXISTS markov_chain_value_and_context_index''') + db.execute('''DROP TABLE IF EXISTS markov_chain''') + db.execute('''DROP TABLE IF EXISTS markov_target_to_context_map''') + db.execute('''DROP TABLE IF EXISTS markov_chatter_target''') + db.execute('''DROP TABLE IF EXISTS markov_context''') - self._learn_line('') - except sqlite3.Error as e: - db.rollback() - db.close() - print("sqlite error: " + str(e)) - raise - if (version < 2): - db = self.get_db() - try: db.execute(''' - ALTER TABLE markov_chain - ADD COLUMN context TEXT DEFAULT NULL''') + CREATE TABLE markov_chatter_target ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + target TEXT NOT NULL, + chance INTEGER NOT NULL DEFAULT 99999 + )''') + db.execute(''' CREATE TABLE markov_context ( id INTEGER PRIMARY KEY AUTOINCREMENT, context TEXT NOT NULL )''') + db.execute(''' CREATE TABLE markov_target_to_context_map ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -116,101 +110,30 @@ class Markov(Module): context_id INTEGER NOT NULL, FOREIGN KEY(context_id) REFERENCES markov_context(id) )''') - db.execute('UPDATE drbotzo_modules SET version = ? WHERE module = ?', - (2, self.__class__.__name__)) - db.commit() - db.close() - version = 2 - except sqlite3.Error as e: - db.rollback() - db.close() - print('sqlite error: ' + str(e)) - raise - if (version < 3): - db = self.get_db() - try: + db.execute(''' - CREATE INDEX markov_chain_keys_index - ON markov_chain (k1, k2)''') - db.execute('UPDATE drbotzo_modules SET version = ? WHERE module = ?', - (3, self.__class__.__name__)) - db.commit() - db.close() - version = 3 - except sqlite3.Error as e: - db.rollback() - db.close() - print('sqlite error: ' + str(e)) - raise - if (version < 4): - db = self.get_db() - try: - db.execute('UPDATE drbotzo_modules SET version = ? WHERE module = ?', - (4, self.__class__.__name__)) - db.commit() - db.close() - version = 4 - except sqlite3.Error as e: - db.rollback() - db.close() - print('sqlite error: ' + str(e)) - raise - if (version < 5): - db = self.get_db() - try: - version = 5 - db.execute(''' - CREATE TABLE markov_chatter_target ( + CREATE TABLE markov_chain ( id INTEGER PRIMARY KEY AUTOINCREMENT, - target TEXT NOT NULL + k1 TEXT NOT NULL, + k2 TEXT NOT NULL, + v TEXT NOT NULL, + context_id INTEGER DEFAULT NULL, + FOREIGN KEY(context_id) REFERENCES markov_context(id) )''') - db.commit() - db.close() - self.db_register_module_version(self.__class__.__name__, version) - except sqlite3.Error as e: - db.rollback() - db.close() - print('sqlite error: ' + str(e)) - raise - if (version < 6): - db = self.get_db() - try: - version = 6 + db.execute(''' - CREATE INDEX markov_chain_keys_and_context_index - ON markov_chain (k1, k2, context)''') - db.commit() - db.close() - self.db_register_module_version(self.__class__.__name__, version) - except sqlite3.Error as e: - db.rollback() - db.close() - print('sqlite error: ' + str(e)) - raise - if (version < 7): - db = self.get_db() - try: - version = 7 + CREATE INDEX markov_chain_keys_and_context_id_index + ON markov_chain (k1, k2, context_id)''') + db.execute(''' - ALTER TABLE markov_chatter_target ADD COLUMN chance INTEGER NOT NULL DEFAULT 99999''') - db.commit() - db.close() - self.db_register_module_version(self.__class__.__name__, version) - except sqlite3.Error as e: - db.rollback() - db.close() - print('sqlite error: ' + str(e)) - raise - if (version < 8): - db = self.get_db() - try: - version = 8 - db.execute(''' - CREATE INDEX markov_chain_value_and_context_index - ON markov_chain (v, context)''') + CREATE INDEX markov_chain_value_and_context_id_index + ON markov_chain (v, context_id)''') + db.commit() db.close() self.db_register_module_version(self.__class__.__name__, version) + + self._learn_line('','') except sqlite3.Error as e: db.rollback() db.close() @@ -255,11 +178,14 @@ class Markov(Module): target = event.target() if self.trainre.search(what): - return self.reply(connection, event, self.markov_train(connection, event, nick, userhost, what, admin_unlocked)) + return self.reply(connection, event, self.markov_train(connection, event, nick, + userhost, what, admin_unlocked)) elif self.learnre.search(what): - return self.reply(connection, event, self.markov_learn(connection, event, nick, userhost, what, admin_unlocked)) + return self.reply(connection, event, self.markov_learn(connection, event, nick, + userhost, what, admin_unlocked)) elif self.replyre.search(what) and not self.shut_up: - return self.reply(connection, event, self.markov_reply(connection, event, nick, userhost, what, admin_unlocked)) + return self.reply(connection, event, self.markov_reply(connection, event, nick, + userhost, what, admin_unlocked)) if not self.shut_up: # not a command, so see if i'm being mentioned @@ -269,7 +195,8 @@ class Markov(Module): if addressed_re.match(what): # i was addressed directly, so respond, addressing the speaker self.lines_seen.append(('.self.said.', datetime.now())) - return self.reply(connection, event, '{0:s}: {1:s}'.format(nick, self._generate_line(target, line=addressed_re.match(what).group(1)))) + return self.reply(connection, event, '{0:s}: {1:s}'.format(nick, + self._generate_line(target, line=addressed_re.match(what).group(1)))) else: # i wasn't addressed directly, so just respond self.lines_seen.append(('.self.said.', datetime.now())) @@ -365,7 +292,8 @@ class Markov(Module): for (nick,then) in self.lines_seen: rdelta = relativedelta(datetime.now(), then) - if rdelta.years == 0 and rdelta.months == 0 and rdelta.days == 0 and rdelta.hours == 0 and rdelta.minutes == 0 and rdelta.seconds <= 29: + if (rdelta.years == 0 and rdelta.months == 0 and rdelta.days == 0 and + rdelta.hours == 0 and rdelta.minutes == 0 and rdelta.seconds <= 29): last_30_sec_lines.append((nick,then)) if len(last_30_sec_lines) >= 8: @@ -374,7 +302,8 @@ class Markov(Module): self.shut_up = True targets = self._get_chatter_targets() for t in targets: - self.sendmsg(self.connection, t['target'], 'shutting up for 30 seconds due to last 30 seconds of activity') + self.sendmsg(self.connection, t['target'], + 'shutting up for 30 seconds due to last 30 seconds of activity') def _learn_line(self, line, target): """Create Markov chains from the provided line.""" @@ -383,10 +312,10 @@ class Markov(Module): k1 = self.start1 k2 = self.start2 - context = target + context_id = self._get_context_id_for_target(target) - # if there's no context, this is probably a sub-command. don't learn it - if context: + # if there's no target, this is probably a sub-command. don't learn it + if target: words = line.split() if len(words) <= 0: @@ -395,18 +324,20 @@ class Markov(Module): try: db = self.get_db() cur = db.cursor() - statement = 'INSERT INTO markov_chain (k1, k2, v, context) VALUES (?, ?, ?, ?)' + statement = 'INSERT INTO markov_chain (k1, k2, v, context_id) VALUES (?, ?, ?, ?)' for word in words: - cur.execute(statement, (k1.decode('utf-8', 'replace').lower(), k2.decode('utf-8', 'replace').lower(), word.decode('utf-8', 'replace').lower(), context)) + cur.execute(statement, (k1.decode('utf-8', 'replace'), + k2.decode('utf-8', 'replace'), word.decode('utf-8', 'replace'), context_id)) k1, k2 = k2, word - cur.execute(statement, (k1.decode('utf-8', 'replace').lower(), k2.decode('utf-8', 'replace').lower(), self.stop, context)) + cur.execute(statement, (k1.decode('utf-8', 'replace'), + k2.decode('utf-8', 'replace'), self.stop, context_id)) db.commit() db.close() except sqlite3.Error as e: db.rollback() db.close() - print("sqlite error: " + str(e)) + print("sqlite error in Markov._learn_line: " + str(e)) raise def _generate_line(self, target, line='', min_size=15, max_size=100): @@ -427,7 +358,7 @@ class Markov(Module): words = line.split() target_word = words[random.randint(0, len(words)-1)] - context = target + context_id = self._get_context_id_for_target(target) # start with an empty chain, and work from there gen_words = [self.start1, self.start2] @@ -438,7 +369,7 @@ class Markov(Module): # we'll just pick a word and work backwards if gen_words[-1] == self.start2 and target_word != '': working_backwards = [] - key_hits = self._retrieve_k2_for_value(target_word, context) + key_hits = self._retrieve_k2_for_value(target_word, context_id) if len(key_hits): working_backwards.append(target_word) # generate new word @@ -446,7 +377,7 @@ class Markov(Module): target_word = words[random.randint(0, len(words)-1)] # work backwards until we randomly bump into a start while True: - key_hits = self._retrieve_k2_for_value(working_backwards[0], context) + key_hits = self._retrieve_k2_for_value(working_backwards[0], context_id) if target_word in key_hits: found_word = target_word # generate new word @@ -464,7 +395,7 @@ class Markov(Module): else: working_backwards.insert(0, found_word) - key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context) + key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context_id) # use the chain that includes the target word, if it is found if target_word != '' and target_word in key_hits: gen_words.append(target_word) @@ -487,7 +418,7 @@ class Markov(Module): return ' '.join(gen_words).encode('utf-8', 'ignore') - def _retrieve_chains_for_key(self, k1, k2, context): + def _retrieve_chains_for_key(self, k1, k2, context_id): """Get the value(s) for a given key (a pair of strings).""" values = [] @@ -499,10 +430,12 @@ class Markov(Module): # a faster fashion than selecting all starts max_id = self._get_max_chain_id() rand_id = random.randint(1,max_id) - query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND (context = ? OR context IS NULL) AND id >= {0:d} LIMIT 1'.format(rand_id) + query = ('SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND ' + '(context_id = ? OR context_id IS NULL) AND id >= {0:d} LIMIT 1'.format(rand_id)) else: - query = 'SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND (context = ? OR context IS NULL)' - cursor = db.execute(query, (k1,k2,context)) + query = ('SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND ' + '(context_id = ? OR context_id IS NULL)') + cursor = db.execute(query, (k1,k2,context_id)) results = cursor.fetchall() for result in results: @@ -512,17 +445,17 @@ class Markov(Module): return values except sqlite3.Error as e: db.close() - print('sqlite error: ' + str(e)) + print('sqlite error in Markov._retrieve_chains_for_key: ' + str(e)) raise - def _retrieve_k2_for_value(self, v, context): + def _retrieve_k2_for_value(self, v, context_id): """Get the value(s) for a given key (a pair of strings).""" values = [] try: db = self.get_db() - query = 'SELECT k2 FROM markov_chain WHERE v = ? AND (context = ? OR context IS NULL)' - cursor = db.execute(query, (v,context)) + query = 'SELECT k2 FROM markov_chain WHERE v = ? AND (context_id = ? OR context_id IS NULL)' + cursor = db.execute(query, (v,context_id)) results = cursor.fetchall() for result in results: @@ -532,7 +465,7 @@ class Markov(Module): return values except sqlite3.Error as e: db.close() - print('sqlite error: ' + str(e)) + print('sqlite error in Markov._retrieve_k2_for_value: ' + str(e)) raise def _get_chatter_targets(self): @@ -548,7 +481,7 @@ class Markov(Module): return results except sqlite3.Error as e: db.close() - print('sqlite error: ' + str(e)) + print('sqlite error in Markov._get_chatter_targets: ' + str(e)) raise def _get_one_chatter_target(self): @@ -575,7 +508,31 @@ class Markov(Module): return None except sqlite3.Error as e: db.close() - print('sqlite error: ' + str(e)) + print('sqlite error in Markov._get_max_chain_id: ' + str(e)) + raise + + def _get_context_id_for_target(self, target): + + """Get the context ID for the desired/input target.""" + + try: + db = self.get_db() + query = ''' + SELECT mc.id FROM markov_context mc + INNER JOIN markov_target_to_context_map mt + ON mt.context_id = mc.id + WHERE mt.target = ? + ''' + cursor = db.execute(query, (target,)) + result = cursor.fetchone() + db.close() + if result: + return result['id'] + else: + return None + except sqlite3.Error as e: + db.close() + print('sqlite error in Markov._get_context_id_for_target: ' + str(e)) raise # vi:tabstop=4:expandtab:autoindent diff --git a/scripts/import-file-into-markov_chain.py b/scripts/import-file-into-markov_chain.py index 8f415c9..e55e048 100644 --- a/scripts/import-file-into-markov_chain.py +++ b/scripts/import-file-into-markov_chain.py @@ -21,16 +21,16 @@ import os import sqlite3 import sys -parser = argparse.ArgumentParser(description='Import lines into the specified context.') -parser.add_argument('context', metavar='CONTEXT', type=str, nargs=1) +parser = argparse.ArgumentParser(description='Import lines into the specified context_id.') +parser.add_argument('context_id', metavar='CONTEXT', type=int, nargs=1) args = parser.parse_args() -print(args.context[0]) +print(args.context_id[0]) db = sqlite3.connect('dr.botzo.data') db.row_factory = sqlite3.Row cur = db.cursor() -statement = 'INSERT INTO markov_chain (k1, k2, v, context) VALUES (?, ?, ?, ?)' +statement = 'INSERT INTO markov_chain (k1, k2, v, context_id) VALUES (?, ?, ?, ?)' for line in sys.stdin: # set up the head of the chain w1 = '__start1' @@ -39,7 +39,7 @@ for line in sys.stdin: # for each word pair, add the next word to the dictionary for word in line.split(): try: - cur.execute(statement, (w1.decode('utf-8', 'replace').lower(), w2.decode('utf-8', 'replace').lower(), word.decode('utf-8', 'replace').lower(), args.context[0])) + cur.execute(statement, (w1.decode('utf-8', 'replace'), w2.decode('utf-8', 'replace'), word.decode('utf-8', 'replace'), args.context_id[0])) except sqlite3.Error as e: db.rollback() print("sqlite error: " + str(e)) @@ -48,7 +48,7 @@ for line in sys.stdin: w1, w2 = w2, word try: - cur.execute(statement, (w1.decode('utf-8', 'replace').lower(), w2.decode('utf-8', 'replace').lower(), '__stop', args.context[0])) + cur.execute(statement, (w1.decode('utf-8', 'replace'), w2.decode('utf-8', 'replace'), '__stop', args.context_id[0])) db.commit() except sqlite3.Error as e: db.rollback()