719 lines
29 KiB
Python
719 lines
29 KiB
Python
"""
Markov - Chatterbot via Markov chains for IRC
Copyright (C) 2010  Brian S. Stephan

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""
|
|
|
|
from datetime import datetime
|
|
import random
|
|
import re
|
|
import thread
|
|
import time
|
|
|
|
from dateutil.relativedelta import relativedelta
|
|
import MySQLdb as mdb
|
|
|
|
from extlib import irclib
|
|
|
|
from Module import Module
|
|
|
|
class Markov(Module):

    """Create a chatterbot very similar to a MegaHAL, but simpler and
    implemented in pure Python. Proof of concept code from Ape.

    Lines are stored in the database as (k1, k2) -> v word triples, each
    tagged with a context id, and replies are built by walking those
    triples.

    Ape wrote: based on this:
    http://uswaretech.com/blog/2009/06/pseudo-random-text-markov-chains-python/
    and this:
    http://code.activestate.com/recipes/194364-the-markov-chain-algorithm/

    """
|
|
|
|
def __init__(self, irc, config):
    """Create the Markov chainer and start its background worker.

    Sets the chain sentinel keywords, compiles the command regexes,
    initializes the base Module, starts a thread running thread_do(),
    and registers _generate_line with the irc object's XML-RPC
    interface under the name "markov_generate_line".

    Arguments:
    irc - the bot object; handed to Module.__init__ and used to register
          the XML-RPC function
    config - handed to Module.__init__ unchanged

    """

    # set up some keywords for use in the chains --- don't change these
    # once you've created a brain
    self.start1 = '__start1'
    self.start2 = '__start2'
    self.stop = '__stop'

    # set up regexes, for replying to specific stuff. raw strings are used
    # so \s and \d reach the regex engine instead of being parsed as
    # (invalid) string escapes; the resulting patterns are unchanged
    learnpattern = r'^!markov\s+learn\s+(.*)$'
    replypattern = r'^!markov\s+reply(\s+min=(\d+))?(\s+max=(\d+))?(\s+(.*)$|$)'

    self.learnre = re.compile(learnpattern)
    self.replyre = re.compile(replypattern)

    self.shut_up = False
    # (nick, datetime) pairs; the bot's own lines use the nick '.self.said.'
    self.lines_seen = []

    Module.__init__(self, irc, config)

    # 0 means "not scheduled yet" for both periodic checks
    self.next_shut_up_check = 0
    self.next_chatter_check = 0
    thread.start_new_thread(self.thread_do, ())

    irc.xmlrpc_register_function(self._generate_line,
                                 "markov_generate_line")
|
|
|
|
def db_init(self):
    """Create the markov tables and indexes on first use.

    Does nothing when db_module_registered() already reports a version
    for this class. Otherwise creates the chatter-target, context,
    target-to-context map and chain tables plus two chain indexes, then
    registers schema version 1.

    Raises:
    mdb.Error - re-raised after rollback and logging if any statement
                fails

    """

    version = self.db_module_registered(self.__class__.__name__)
    if version is not None:
        return

    db = self.get_db()
    # bound before the try so the finally clause cannot hit an unbound
    # name (and mask the real error) when db.cursor() itself fails
    cur = None
    try:
        version = 1
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute('''
            CREATE TABLE markov_chatter_target (
                id SERIAL,
                target VARCHAR(256) NOT NULL,
                chance INTEGER NOT NULL DEFAULT 99999
            ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin
            ''')
        cur.execute('''
            CREATE TABLE markov_context (
                id SERIAL,
                context VARCHAR(256) NOT NULL
            ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin
            ''')
        cur.execute('''
            CREATE TABLE markov_target_to_context_map (
                id SERIAL,
                target VARCHAR(256) NOT NULL,
                context_id BIGINT(20) UNSIGNED NOT NULL,
                FOREIGN KEY(context_id) REFERENCES markov_context(id)
            ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin
            ''')
        cur.execute('''
            CREATE TABLE markov_chain (
                id SERIAL,
                k1 VARCHAR(128) NOT NULL,
                k2 VARCHAR(128) NOT NULL,
                v VARCHAR(128) NOT NULL,
                context_id BIGINT(20) UNSIGNED NOT NULL,
                FOREIGN KEY(context_id) REFERENCES markov_context(id)
            ) ENGINE=InnoDB CHARACTER SET utf8 COLLATE utf8_bin
            ''')
        cur.execute('''
            CREATE INDEX markov_chain_keys_and_context_id_index
            ON markov_chain (k1, k2, context_id)''')

        cur.execute('''
            CREATE INDEX markov_chain_value_and_context_id_index
            ON markov_chain (v, context_id)''')

        db.commit()
        self.db_register_module_version(self.__class__.__name__,
                                        version)
    except mdb.Error as e:
        # NOTE(review): MySQL DDL statements commit implicitly, so this
        # rollback presumably cannot undo tables already created ---
        # confirm before relying on it for cleanup
        db.rollback()
        self.log.error("database error trying to create tables")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
|
|
|
|
def register_handlers(self):
    """Hook the command handler and the learner into pubmsg/privmsg.

    on_pub_or_privmsg is registered with this module's priority();
    learn_from_irc_event is registered without an explicit priority.

    """

    server = self.irc.server
    event_types = ('pubmsg', 'privmsg')

    for event_type in event_types:
        server.add_global_handler(event_type, self.on_pub_or_privmsg,
                                  self.priority())

    for event_type in event_types:
        server.add_global_handler(event_type, self.learn_from_irc_event)
|
|
|
|
def unregister_handlers(self):
    """Remove the pubmsg/privmsg handlers added by register_handlers."""

    server = self.irc.server

    for handler in (self.on_pub_or_privmsg, self.learn_from_irc_event):
        for event_type in ('pubmsg', 'privmsg'):
            server.remove_global_handler(event_type, handler)
|
|
|
|
def learn_from_irc_event(self, connection, event):
    """Learn from IRC events.

    Strips a leading "<botnick>: " or "<botnick>, " from the message,
    records the speaker in self.lines_seen, and feeds the text to
    _learn_line unless it is a !markov learn/reply command.

    Arguments:
    connection - IRC connection; only get_nickname() is used
    event - IRC event; arguments(), target() and source() are used

    """

    what = ''.join(event.arguments()[0])
    my_nick = connection.get_nickname()
    # the nick is escaped because IRC nicks may contain regex
    # metacharacters (e.g. [ ] | ^), which would otherwise corrupt or
    # break the pattern
    what = re.sub('^' + re.escape(my_nick) + r'[:,]\s+', '', what)
    target = event.target()
    nick = irclib.nm_to_n(event.source())

    # for non-channel targets the speaker's nick is used as the target
    if not irclib.is_channel(target):
        target = nick

    self.lines_seen.append((nick, datetime.now()))

    # don't learn from commands
    if self.learnre.search(what) or self.replyre.search(what):
        return

    self._learn_line(what, target, event)
|
|
|
|
def do(self, connection, event, nick, userhost, what, admin_unlocked):
    """Handle commands and inputs.

    Dispatches "!markov learn ..." and "!markov reply ..." to their
    handlers. Otherwise, when the bot's nick appears anywhere in the
    message (case-insensitively), replies with a generated line; if the
    message starts with "<botnick>: " or "<botnick>, " the reply is
    prefixed with the speaker's nick. Nothing but learn commands is
    answered while self.shut_up is set.

    Returns whatever self.irc.reply() returns, or None when no reply
    is made.

    """

    # NOTE(review): unlike markov_reply, the target is not swapped for
    # the speaker's nick on non-channel messages --- confirm intended
    target = event.target()

    if self.learnre.search(what):
        return self.irc.reply(event, self.markov_learn(event,
                              nick, userhost, what, admin_unlocked))
    elif self.replyre.search(what) and not self.shut_up:
        return self.irc.reply(event, self.markov_reply(event,
                              nick, userhost, what, admin_unlocked))

    if self.shut_up:
        return None

    # not a command, so see if i'm being mentioned. the nick is escaped
    # because IRC nicks may contain regex metacharacters (e.g. [ ] | ^)
    my_nick = re.escape(connection.get_nickname())
    if re.search(my_nick, what, re.IGNORECASE) is None:
        return None

    addressed = re.match('^' + my_nick + r'[:,]\s+(.*)', what)
    self.lines_seen.append(('.self.said.', datetime.now()))

    if addressed:
        # i was addressed directly, so respond, addressing
        # the speaker
        return self.irc.reply(event, '{0:s}: {1:s}'.format(nick,
                              self._generate_line(target, line=addressed.group(1))))

    # i wasn't addressed directly, so just respond
    return self.irc.reply(event, '{0:s}'.format(self._generate_line(target, line=what)))
|
|
|
|
def markov_learn(self, event, nick, userhost, what, admin_unlocked):
    """Learn the text given to a "!markov learn <text>" command.

    Returns the learned text, or None when `what` is not a learn
    command.

    """

    destination = event.target()
    if not irclib.is_channel(destination):
        # non-channel message: file it under the speaker's nick
        destination = nick

    found = self.learnre.search(what)
    if not found:
        return None

    learned = found.group(1)
    self._learn_line(learned, destination, event)

    # hand back what was learned, for weird chaining purposes
    return learned
|
|
|
|
def markov_reply(self, event, nick, userhost, what, admin_unlocked):
    """Generate a reply for a "!markov reply" command, without learning.

    The command may carry optional "min=N" and "max=N" sizes (defaults
    15 and 30) followed by optional seed text. Returns the generated
    line, or None when `what` is not a reply command.

    """

    destination = event.target()
    if not irclib.is_channel(destination):
        # non-channel message: use the speaker's nick as the target
        destination = nick

    found = self.replyre.search(what)
    if not found:
        return None

    min_size = int(found.group(2)) if found.group(2) else 15
    max_size = int(found.group(4)) if found.group(4) else 30

    # count this reply towards the bot's own recent output
    self.lines_seen.append(('.self.said.', datetime.now()))

    if found.group(5) != '':
        # seed text was supplied after the options
        return self._generate_line(destination, line=found.group(6),
                                   min_size=min_size, max_size=max_size)

    return self._generate_line(destination, min_size=min_size, max_size=max_size)
|
|
|
|
def thread_do(self):
    """Run the periodic checks once a second until shutdown is flagged."""

    while True:
        if self.is_shutdown:
            break

        self._do_shut_up_checks()
        self._do_random_chatter_check()
        time.sleep(1)
|
|
|
|
def _do_random_chatter_check(self):
|
|
"""Randomly say something to a channel."""
|
|
|
|
# don't immediately potentially chatter, let the bot
|
|
# join channels first
|
|
if self.next_chatter_check == 0:
|
|
self.next_chatter_check = time.time() + 600
|
|
|
|
if self.next_chatter_check < time.time():
|
|
self.next_chatter_check = time.time() + 600
|
|
|
|
targets = self._get_chatter_targets()
|
|
for t in targets:
|
|
if t['chance'] > 0:
|
|
a = random.randint(1, t['chance'])
|
|
if a == 1:
|
|
self.sendmsg(t['target'], self._generate_line(t['target']))
|
|
|
|
def _do_shut_up_checks(self):
|
|
"""Check to see if we've been talking too much, and shut up if so."""
|
|
|
|
if self.next_shut_up_check < time.time():
|
|
self.shut_up = False
|
|
self.next_shut_up_check = time.time() + 30
|
|
|
|
last_30_sec_lines = []
|
|
|
|
for (nick, then) in self.lines_seen:
|
|
rdelta = relativedelta(datetime.now(), then)
|
|
if (rdelta.years == 0 and rdelta.months == 0 and rdelta.days == 0 and
|
|
rdelta.hours == 0 and rdelta.minutes == 0 and rdelta.seconds <= 29):
|
|
last_30_sec_lines.append((nick, then))
|
|
|
|
if len(last_30_sec_lines) >= 8:
|
|
lines_i_said = len(filter(lambda (a, b): a == '.self.said.', last_30_sec_lines))
|
|
if lines_i_said >= 8:
|
|
self.shut_up = True
|
|
targets = self._get_chatter_targets()
|
|
for t in targets:
|
|
self.sendmsg(t['target'],
|
|
'shutting up for 30 seconds due to last 30 seconds of activity')
|
|
|
|
def _learn_line(self, line, target, event):
|
|
"""Create Markov chains from the provided line."""
|
|
|
|
# set up the head of the chain
|
|
k1 = self.start1
|
|
k2 = self.start2
|
|
|
|
context_id = self._get_context_id_for_target(target)
|
|
|
|
# don't learn recursion
|
|
if not event._recursing:
|
|
words = line.split()
|
|
if len(words) == 0:
|
|
return line
|
|
|
|
db = self.get_db()
|
|
try:
|
|
cur = db.cursor(mdb.cursors.DictCursor)
|
|
statement = 'INSERT INTO markov_chain (k1, k2, v, context_id) VALUES (%s, %s, %s, %s)'
|
|
for word in words:
|
|
cur.execute(statement, (k1, k2, word, context_id))
|
|
k1, k2 = k2, word
|
|
cur.execute(statement, (k1, k2, self.stop, context_id))
|
|
|
|
db.commit()
|
|
except mdb.Error as e:
|
|
db.rollback()
|
|
self.log.error("database error learning line")
|
|
self.log.exception(e)
|
|
raise
|
|
finally: cur.close()
|
|
|
|
def _generate_line(self, target, line='', min_size=15, max_size=30):
    """Create a line, optionally using some text in a seed as a point in
    the chain.

    Keyword arguments:
    target - the target to retrieve the context for (i.e. a channel or nick)
    line - the line to reply to, by picking a random word and seeding with it
    min_size - the minimum desired size in words. not guaranteed
    max_size - the maximum desired size in words. not guaranteed

    Returns the generated words joined with single spaces, with the
    start/stop keywords removed; this is an empty string when nothing
    could be generated.

    Raises Exception when max_size <= 3 or min_size > 20.

    """

    # if the limit is too low, there's nothing to do
    if (max_size <= 3):
        raise Exception("max_size is too small: %d" % max_size)

    # if the min is too large, abort
    if (min_size > 20):
        raise Exception("min_size is too large: %d" % min_size)

    seed_words = []
    # shuffle the words in the input
    seed_words = line.split()
    random.shuffle(seed_words)
    self.log.debug("seed words: {0:s}".format(seed_words))

    # hit to generate a new seed word immediately if possible
    # (both start as None, so the first loop pass picks a seed word)
    seed_word = None
    hit_word = None

    context_id = self._get_context_id_for_target(target)

    # start with an empty chain, and work from there
    gen_words = [self.start1, self.start2]

    # build a response by creating multiple sentences
    # (the + 2 accounts for the two start keywords held in gen_words)
    while len(gen_words) < max_size + 2:
        # if we're past the min and on a stop, we can end
        if len(gen_words) > min_size + 2:
            if gen_words[-1] == self.stop:
                break

        # pick a word from the shuffled seed words, if we need a new one
        # (i.e. the last word generated was the current seed word)
        if seed_word == hit_word:
            if len(seed_words) > 0:
                seed_word = seed_words.pop()
                self.log.debug("picked new seed word: "
                               "{0:s}".format(seed_word))
            else:
                seed_word = None
                self.log.debug("ran out of seed words")

        # if we have a stop, the word before it might need to be
        # made to look like a sentence end
        if gen_words[-1] == self.stop:
            # chop off the stop, temporarily
            gen_words = gen_words[:-1]

            # we should have a real word, make it look like a
            # sentence end
            sentence_end = gen_words[-1]
            eos_punctuation = ['!', '?', ',', '.']
            if sentence_end[-1] not in eos_punctuation:
                random.shuffle(eos_punctuation)
                gen_words[-1] = sentence_end + eos_punctuation.pop()
                self.log.debug("monkeyed with end of sentence, it's "
                               "now: {0:s}".format(gen_words[-1]))

            # put the stop back on
            gen_words.append(self.stop)
            self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))

        # first, see if we should start a new sentence. if so,
        # work backwards
        # (the trailing "0 == 1" makes this condition always false, so
        # the backwards-chaining branch below never runs)
        if gen_words[-1] in (self.start2, self.stop) and seed_word is not None and 0 == 1:
            # drop a stop, since we're starting another sentence
            if gen_words[-1] == self.stop:
                gen_words = gen_words[:-1]

            # work backwards from seed_word
            working_backwards = []
            back_k2 = self._retrieve_random_k2_for_value(seed_word, context_id)
            if back_k2:
                found_word = seed_word
                if back_k2 == self.start2:
                    self.log.debug("random further back was start2, swallowing")
                else:
                    working_backwards.append(back_k2)
                working_backwards.append(found_word)
                self.log.debug("started working backwards with: {0:s}".format(found_word))
                self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards)))

                # now work backwards until we randomly bump into a start
                # to steer the chainer away from spending too much time on
                # the weaker-context reverse chaining, we make max_size
                # a non-linear distribution, making it more likely that
                # some time is spent on better forward chains
                max_back = min(random.randint(1, max_size/2) + random.randint(1, max_size/2),
                               max_size/4)
                self.log.debug("max_back: {0:d}".format(max_back))
                while len(working_backwards) < max_back:
                    back_k2 = self._retrieve_random_k2_for_value(working_backwards[0], context_id)
                    if back_k2 == self.start2:
                        self.log.debug("random further back was start2, finishing")
                        break
                    elif back_k2:
                        working_backwards.insert(0, back_k2)
                        self.log.debug("added '{0:s}' to working_backwards".format(back_k2))
                        self.log.debug("working_backwards: {0:s}".format(" ".join(working_backwards)))
                    else:
                        self.log.debug("nothing (at all!?) further back, finishing")
                        break

                gen_words += working_backwards
                self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
            hit_word = gen_words[-1]
        else:
            # we are working forward, with either:
            # * a pair of words (normal path, filling out a sentence)
            # * start1, start2 (completely new chain, no seed words)
            # * stop (new sentence in existing chain, no seed words)
            self.log.debug("working forwards")
            forw_v = None
            if gen_words[-1] in (self.start2, self.stop):
                # case 2 or 3 above, need to work forward on a beginning
                # of a sentence (this is slow)
                if gen_words[-1] == self.stop:
                    # remove the stop if it's there
                    gen_words = gen_words[:-1]

                # ask for up to 3 words following the start keywords
                new_sentence = self._create_chain_with_k1_k2(self.start1,
                                                             self.start2,
                                                             3, context_id,
                                                             avoid_address=True)

                if len(new_sentence) > 0:
                    self.log.debug("started new sentence "
                                   "'{0:s}'".format(" ".join(new_sentence)))
                    gen_words += new_sentence
                    self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
                else:
                    # this is a problem. we started a sentence on
                    # start1,start2, and still didn't find anything. to
                    # avoid endlessly looping we need to abort here
                    break
            else:
                if seed_word:
                    self.log.debug("preferring: '{0:s}'".format(seed_word))
                    forw_v = self._retrieve_random_v_for_k1_and_k2_with_pref(gen_words[-2],
                                                                             gen_words[-1],
                                                                             seed_word,
                                                                             context_id)
                else:
                    forw_v = self._retrieve_random_v_for_k1_and_k2(gen_words[-2],
                                                                   gen_words[-1],
                                                                   context_id)

                # NOTE(review): this if/else is placed inside the
                # forward-lookup branch based on the comment below; the
                # nesting could not be read from the flattened source ---
                # confirm against the original file
                if forw_v:
                    gen_words.append(forw_v)
                    self.log.debug("added random word '{0:s}' to gen_words".format(forw_v))
                    self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))
                    hit_word = gen_words[-1]
                else:
                    # append stop. this is an end to a sentence (since
                    # we had non-start words to begin with)
                    gen_words.append(self.stop)
                    self.log.debug("nothing found, added stop")
                    self.log.debug("gen_words: {0:s}".format(" ".join(gen_words)))

    # chop off the seed data at the start
    gen_words = gen_words[2:]

    if len(gen_words):
        # chop off the end text, if it was the keyword indicating an end of chain
        if gen_words[-1] == self.stop:
            gen_words = gen_words[:-1]
    else:
        self.log.warning("after all this we have an empty list of words. "
                         "there probably isn't any data for this context")

    return ' '.join(gen_words)
|
|
|
|
def _retrieve_random_v_for_k1_and_k2(self, k1, k2, context_id):
    """Get one v for a given k1,k2.

    The query draws a random id threshold below MAX(id) of markov_chain
    and, among the rows matching k1, k2 and context_id, returns the one
    with the lowest id at or above the threshold (falling back to the
    lowest matching id when none is above it).

    Returns the value, or None when no row matches.

    Raises:
    mdb.Error - re-raised after logging

    """

    self.log.debug("searching with '{0:s}','{1:s}'".format(k1, k2))
    db = self.get_db()
    # bound before the try so the finally clause cannot hit an unbound
    # name (and mask the real error) when db.cursor() itself fails
    cur = None
    try:
        query = '''
            SELECT v FROM markov_chain AS r1
            JOIN (
                SELECT (RAND() * (SELECT MAX(id) FROM markov_chain)) AS id
            ) AS r2
            WHERE r1.k1 = %s
            AND r1.k2 = %s
            AND r1.context_id = %s
            ORDER BY r1.id >= r2.id DESC, r1.id ASC
            LIMIT 1
            '''
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query, (k1, k2, context_id))
        result = cur.fetchone()
        if result:
            self.log.debug("found '{0:s}'".format(result['v']))
            return result['v']
        return None
    except mdb.Error as e:
        self.log.error("database error in _retrieve_random_v_for_k1_and_k2")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
|
|
|
|
def _retrieve_random_v_for_k1_and_k2_with_pref(self, k1, k2, prefer, context_id):
    """Get one v for a given k1,k2.

    Prefer that the result be prefer, if it's found.

    The preference is the second sort key of the query: rows at or
    above the random id threshold still come first, and within that
    group a row whose value equals prefer is chosen ahead of the others.

    Returns the value, or None when no row matches.

    Raises:
    mdb.Error - re-raised after logging

    """

    self.log.debug("searching with '{0:s}','{1:s}', prefer "
                   "'{2:s}'".format(k1, k2, prefer))
    db = self.get_db()
    # bound before the try so the finally clause cannot hit an unbound
    # name (and mask the real error) when db.cursor() itself fails
    cur = None
    try:
        query = '''
            SELECT v FROM markov_chain AS r1
            JOIN (
                SELECT (RAND() * (SELECT MAX(id) FROM markov_chain)) AS id
            ) AS r2
            WHERE r1.k1 = %s
            AND r1.k2 = %s
            AND r1.context_id = %s
            ORDER BY r1.id >= r2.id DESC, r1.v = %s DESC, r1.id ASC
            LIMIT 1
            '''
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query, (k1, k2, context_id, prefer))
        result = cur.fetchone()
        if result:
            self.log.debug("found '{0:s}'".format(result['v']))
            return result['v']
        return None
    except mdb.Error as e:
        self.log.error("database error in _retrieve_random_v_for_k1_and_k2_with_pref")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
|
|
|
|
def _retrieve_random_k2_for_value(self, v, context_id):
    """Get one k2 for a given value.

    Looks up markov_chain rows with the given v and context_id and
    returns the k2 (the word immediately preceding v) of one of them,
    chosen via a random id threshold.

    Returns the k2, or None when no row matches.

    Raises:
    mdb.Error - re-raised after logging

    """

    db = self.get_db()
    # bound before the try so the finally clause cannot hit an unbound
    # name (and mask the real error) when db.cursor() itself fails
    cur = None
    try:
        query = '''
            SELECT k2 FROM markov_chain AS r1
            JOIN (
                SELECT (RAND() * (SELECT MAX(id) FROM markov_chain)) AS id
            ) AS r2
            WHERE r1.v = %s
            AND r1.context_id = %s
            ORDER BY r1.id >= r2.id DESC, r1.id ASC
            LIMIT 1
            '''
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query, (v, context_id))
        result = cur.fetchone()
        if result:
            return result['k2']
        return None
    except mdb.Error as e:
        self.log.error("database error in _retrieve_random_k2_for_value")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
|
|
|
|
def _create_chain_with_k1_k2(self, k1, k2, length, context_id,
|
|
avoid_address=False):
|
|
"""Create a chain of the given length, using k1,k2.
|
|
|
|
k1,k2 does not appear in the resulting chain.
|
|
|
|
"""
|
|
|
|
chain = [k1, k2]
|
|
self.log.debug("creating chain for {0:s},{1:s}".format(k1, k2))
|
|
|
|
for _ in range(length):
|
|
v = self._retrieve_random_v_for_k1_and_k2(chain[-2],
|
|
chain[-1],
|
|
context_id)
|
|
if v:
|
|
chain.append(v)
|
|
|
|
# check for addresses (the "whoever:" in
|
|
# __start1 __start2 whoever: some words)
|
|
addressing_suffixes = [':', ',']
|
|
if len(chain) > 2 and chain[2][-1] in addressing_suffixes and avoid_address:
|
|
return chain[3:]
|
|
elif len(chain) > 2:
|
|
return chain[2:]
|
|
else:
|
|
return []
|
|
|
|
def _get_chatter_targets(self):
    """Get all possible chatter targets.

    Returns every row of markov_chatter_target as fetched through a
    DictCursor, each carrying 'target' and 'chance'.

    Raises:
    mdb.Error - re-raised after logging

    """

    # need to create our own db object, since this is likely going to be in a new thread
    db = self.get_db()
    # bound before the try so the finally clause cannot hit an unbound
    # name (and mask the real error) when db.cursor() itself fails
    cur = None
    try:
        query = 'SELECT target, chance FROM markov_chatter_target'
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query)
        results = cur.fetchall()
        return results
    except mdb.Error as e:
        self.log.error("database error in _get_chatter_targets")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
|
|
|
|
def _get_context_id_for_target(self, target):
    """Get the context ID for the desired/input target.

    When the target has no context mapping yet, one is created via
    _add_context_for_target and the lookup is repeated.

    Raises:
    mdb.Error - re-raised after logging

    """

    db = self.get_db()
    # bound before the try so the finally clause cannot hit an unbound
    # name (and mask the real error) when db.cursor() itself fails
    cur = None
    try:
        query = '''
            SELECT mc.id FROM markov_context mc
            INNER JOIN markov_target_to_context_map mt
                ON mt.context_id = mc.id
            WHERE mt.target = %s
            '''
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query, (target,))
        result = cur.fetchone()
        # NOTE(review): this is the only method here that closes the
        # connection, and it does so before the cursor is closed ---
        # confirm this matches what get_db() expects
        db.close()
        if result:
            return result['id']
        else:
            # auto-generate a context to keep things private
            self._add_context_for_target(target)
            return self._get_context_id_for_target(target)
    except mdb.Error as e:
        self.log.error("database error in _get_context_id_for_target")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
|
|
|
|
def _add_context_for_target(self, target):
    """Create a new context for the desired/input target.

    Inserts a markov_context row named after the target and a
    markov_target_to_context_map row pointing the target at it, commits,
    then looks the new mapping up again.

    Returns the id of the context the target now maps to.

    Raises:
    mdb.Error - re-raised after logging (and after rollback, for the
                insert step)

    """

    db = self.get_db()
    # bound before each try so the finally clause cannot hit an unbound
    # name (and mask the real error) when db.cursor() itself fails
    cur = None
    try:
        statement = 'INSERT INTO markov_context (context) VALUES (%s)'
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(statement, (target,))
        statement = '''
            INSERT INTO markov_target_to_context_map (target, context_id)
            VALUES (%s, (SELECT id FROM markov_context WHERE context = %s))
            '''
        cur.execute(statement, (target, target))
        db.commit()
    except mdb.Error as e:
        db.rollback()
        self.log.error("database error in _add_context_for_target")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()

    # read the freshly created mapping back on the same connection
    cur = None
    try:
        query = '''
            SELECT mc.id FROM markov_context mc
            INNER JOIN markov_target_to_context_map mt
                ON mt.context_id = mc.id
            WHERE mt.target = %s
            '''
        cur = db.cursor(mdb.cursors.DictCursor)
        cur.execute(query, (target,))
        result = cur.fetchone()
        if result:
            return result['id']
        else:
            # auto-generate a context to keep things private
            self._add_context_for_target(target)
            return self._get_context_id_for_target(target)
    except mdb.Error as e:
        # this message used to name _get_context_id_for_target, which
        # pointed log readers at the wrong method
        self.log.error("database error in _add_context_for_target")
        self.log.exception(e)
        raise
    finally:
        if cur is not None:
            cur.close()
|
|
|
|
# vi:tabstop=4:expandtab:autoindent
|
|
# kate: indent-mode python;indent-width 4;replace-tabs on;
|