"""
Markov - Chatterbot via Markov chains for IRC
Copyright (C) 2010  Brian S. Stephan

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
"""

import cPickle
from datetime import datetime
import os
import random
import re
import sqlite3
import sys
import thread
import time

from dateutil.parser import *
from dateutil.relativedelta import *
from extlib import irclib

from Module import Module

class Markov(Module):

    """
    Create a chatterbot very similar to a MegaHAL, but simpler and
    implemented in pure Python. Proof of concept code from Ape.

    Ape wrote: based on this:
    http://uswaretech.com/blog/2009/06/pseudo-random-text-markov-chains-python/
    and this:
    http://code.activestate.com/recipes/194364-the-markov-chain-algorithm/

    Chains are stored in the markov_chain table as (k1, k2) -> v rows,
    optionally tagged with a context id so different targets (channels)
    can have separate brains. A background thread (thread_do) handles
    rate limiting ("shut up" checks) and random chatter.
    """

    def __init__(self, irc, config, server):
        """Create the Markov chainer and start the background worker thread."""

        # set up some keywords for use in the chains --- don't change these
        # once you've created a brain
        self.start1 = '__start1'
        self.start2 = '__start2'
        self.stop = '__stop'

        # set up regexes, for replying to specific stuff
        trainpattern = r'^!markov\s+train\s+(.*)$'
        learnpattern = r'^!markov\s+learn\s+(.*)$'
        replypattern = r'^!markov\s+reply(\s+min=(\d+))?(\s+max=(\d+))?(\s+(.*)$|$)'

        self.trainre = re.compile(trainpattern)
        self.learnre = re.compile(learnpattern)
        self.replyre = re.compile(replypattern)

        self.shut_up = False
        # (nick, datetime) pairs for recent lines; '.self.said.' marks the bot's own lines
        self.lines_seen = []

        Module.__init__(self, irc, config, server)

        self.next_shut_up_check = 0
        self.next_chatter_check = 0
        # last connection seen in learn_from_irc_event; used by the background thread
        self.connection = None
        thread.start_new_thread(self.thread_do, ())

    def db_init(self):
        """Create (or recreate, for versions before 9) the markov tables.

        Raises:
            sqlite3.Error: if any schema statement fails (after rolling back).
        """

        version = self.db_module_registered(self.__class__.__name__)
        if version is not None and version >= 9:
            return

        version = 9
        db = self.get_db()
        try:
            # recreating the tables, since i need to add some foreign key constraints
            db.execute('''DROP INDEX IF EXISTS markov_chain_keys_and_context_index''')
            db.execute('''DROP INDEX IF EXISTS markov_chain_keys_index''')
            db.execute('''DROP INDEX IF EXISTS markov_chain_value_and_context_index''')
            db.execute('''DROP TABLE IF EXISTS markov_chain''')
            db.execute('''DROP TABLE IF EXISTS markov_target_to_context_map''')
            db.execute('''DROP TABLE IF EXISTS markov_chatter_target''')
            db.execute('''DROP TABLE IF EXISTS markov_context''')

            db.execute('''
                CREATE TABLE markov_chatter_target (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    target TEXT NOT NULL,
                    chance INTEGER NOT NULL DEFAULT 99999
                )''')

            db.execute('''
                CREATE TABLE markov_context (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    context TEXT NOT NULL
                )''')

            db.execute('''
                CREATE TABLE markov_target_to_context_map (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    target TEXT NOT NULL,
                    context_id INTEGER NOT NULL,
                    FOREIGN KEY(context_id) REFERENCES markov_context(id)
                )''')

            db.execute('''
                CREATE TABLE markov_chain (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    k1 TEXT NOT NULL,
                    k2 TEXT NOT NULL,
                    v TEXT NOT NULL,
                    context_id INTEGER DEFAULT NULL,
                    FOREIGN KEY(context_id) REFERENCES markov_context(id)
                )''')

            db.execute('''
                CREATE INDEX markov_chain_keys_and_context_id_index
                    ON markov_chain (k1, k2, context_id)''')

            db.execute('''
                CREATE INDEX markov_chain_value_and_context_id_index
                    ON markov_chain (v, context_id)''')

            db.commit()
        except sqlite3.Error as e:
            db.rollback()
            print('sqlite error: ' + str(e))
            raise
        finally:
            # close exactly once, whether or not the schema change succeeded
            db.close()

        self.db_register_module_version(self.__class__.__name__, version)

    def register_handlers(self):
        """Handle pubmsg/privmsg, to learn and/or reply to IRC events."""

        self.server.add_global_handler('pubmsg', self.on_pub_or_privmsg, self.priority())
        self.server.add_global_handler('privmsg', self.on_pub_or_privmsg, self.priority())
        self.server.add_global_handler('pubmsg', self.learn_from_irc_event)
        self.server.add_global_handler('privmsg', self.learn_from_irc_event)

    def unregister_handlers(self):
        """Remove every handler added by register_handlers."""

        self.server.remove_global_handler('pubmsg', self.on_pub_or_privmsg)
        self.server.remove_global_handler('privmsg', self.on_pub_or_privmsg)
        self.server.remove_global_handler('pubmsg', self.learn_from_irc_event)
        self.server.remove_global_handler('privmsg', self.learn_from_irc_event)

    def learn_from_irc_event(self, connection, event):
        """Learn from IRC events.

        Strips a leading "<mynick>: " address, records the line for the
        shut-up rate check, remembers the connection for the background
        thread, and learns the text unless it is a !markov command.
        """

        what = event.arguments()[0]
        my_nick = connection.get_nickname()
        # nicks may contain regex metacharacters ([ ] \ ^ { } |), so escape them
        what = re.sub('^' + re.escape(my_nick) + r'[:,]\s+', '', what)
        target = event.target()
        nick = irclib.nm_to_n(event.source())

        self.lines_seen.append((nick, datetime.now()))
        self.connection = connection

        # don't learn from commands
        if self.trainre.search(what) or self.learnre.search(what) or self.replyre.search(what):
            return

        self._learn_line(what, target)

    def do(self, connection, event, nick, userhost, what, admin_unlocked):
        """Handle !markov commands, and reply when the bot's nick is mentioned.

        Returns the result of self.reply(...) when something was said,
        otherwise None.
        """

        target = event.target()

        if self.trainre.search(what):
            return self.reply(connection, event, self.markov_train(connection, event, nick,
              userhost, what, admin_unlocked))
        elif self.learnre.search(what):
            return self.reply(connection, event, self.markov_learn(connection, event, nick,
              userhost, what, admin_unlocked))
        elif self.replyre.search(what) and not self.shut_up:
            return self.reply(connection, event, self.markov_reply(connection, event, nick,
              userhost, what, admin_unlocked))

        if self.shut_up:
            return None

        # not a command, so see if i'm being mentioned. escape the nick since
        # IRC nicks may contain regex metacharacters
        escaped_nick = re.escape(connection.get_nickname())
        if re.search(escaped_nick, what, re.IGNORECASE) is None:
            return None

        addressed = re.match('^' + escaped_nick + r'[:,]\s+(.*)', what)
        self.lines_seen.append(('.self.said.', datetime.now()))
        if addressed:
            # i was addressed directly, so respond, addressing the speaker
            return self.reply(connection, event, '{0:s}: {1:s}'.format(nick,
              self._generate_line(target, line=addressed.group(1))))

        # i wasn't addressed directly, so just respond
        return self.reply(connection, event, '{0:s}'.format(self._generate_line(target, line=what)))

    def markov_train(self, connection, event, nick, userhost, what, admin_unlocked):
        """Learn lines from a file. Good for initializing a brain.

        Only honored when admin_unlocked is true; lines are learned into the
        context of the event's target. Returns a status message, or None when
        the command did not match or the admin lock is in place.
        """

        match = self.trainre.search(what)
        if match and admin_unlocked:
            filename = match.group(1)
            target = event.target()

            try:
                with open(filename, 'r') as training_file:
                    for line in training_file:
                        self._learn_line(line, target)

                return 'Learned from \'{0:s}\'.'.format(filename)
            except IOError:
                return 'No such file \'{0:s}\'.'.format(filename)

    def markov_learn(self, connection, event, nick, userhost, what, admin_unlocked):
        """Learn one line, as provided to the command, and return it."""

        target = event.target()
        match = self.learnre.search(what)
        if match:
            line = match.group(1)
            self._learn_line(line, target)

            # return what was learned, for weird chaining purposes
            return line

    def markov_reply(self, connection, event, nick, userhost, what, admin_unlocked):
        """Generate a reply to one line, without learning it.

        Optional "min=N" / "max=N" arguments override the default sizes
        (15 and 100) passed to _generate_line.
        """

        target = event.target()
        match = self.replyre.search(what)
        if match:
            min_size = int(match.group(2)) if match.group(2) else 15
            max_size = int(match.group(4)) if match.group(4) else 100

            self.lines_seen.append(('.self.said.', datetime.now()))
            if match.group(5) != '':
                return self._generate_line(target, line=match.group(6),
                                           min_size=min_size, max_size=max_size)
            return self._generate_line(target, min_size=min_size, max_size=max_size)

    def thread_do(self):
        """Background loop: run the periodic checks once a second until shutdown."""

        while not self.is_shutdown:
            self._do_shut_up_checks()
            self._do_random_chatter_check()
            time.sleep(1)

    def _do_random_chatter_check(self):
        """Every 600 seconds, maybe say something in each chatter target.

        A target with chance N > 0 is spoken to with probability 1/N per check.
        """

        # don't immediately potentially chatter, let the bot
        # join channels first
        if self.next_chatter_check == 0:
            self.next_chatter_check = time.time() + 600

        if self.next_chatter_check >= time.time():
            return

        self.next_chatter_check = time.time() + 600

        if self.connection is None:
            # i haven't seen any text yet...
            return

        for chatter_target in self._get_chatter_targets():
            chance = chatter_target['chance']
            if chance > 0 and random.randint(1, chance) == 1:
                self.sendmsg(self.connection, chatter_target['target'],
                             self._generate_line(chatter_target['target']))

    def _do_shut_up_checks(self):
        """Check to see if we've been talking too much, and shut up if so.

        Every 30 seconds: clear shut_up, then set it again if the bot itself
        said 8 or more of the lines seen in the last 30 seconds.
        """

        if self.next_shut_up_check >= time.time():
            return

        self.shut_up = False
        self.next_shut_up_check = time.time() + 30

        now = datetime.now()

        # lines_seen is only ever appended with datetime.now(), so stale
        # entries form a prefix. delete them in place (rather than rebinding
        # the list) so lines appended concurrently by the IRC side are kept,
        # and so the list doesn't grow without bound.
        stale = 0
        for _, then in self.lines_seen:
            if (now - then).total_seconds() < 30:
                break
            stale += 1
        del self.lines_seen[:stale]

        lines_i_said = sum(1 for (who, then) in list(self.lines_seen)
                           if who == '.self.said.' and (now - then).total_seconds() < 30)
        if lines_i_said >= 8:
            self.shut_up = True
            if self.connection is not None:
                for chatter_target in self._get_chatter_targets():
                    self.sendmsg(self.connection, chatter_target['target'],
                      'shutting up for 30 seconds due to last 30 seconds of activity')

    @staticmethod
    def _to_unicode(text):
        """Decode a UTF-8 byte string (replacing bad bytes); return text unchanged otherwise."""

        if isinstance(text, bytes):
            return text.decode('utf-8', 'replace')
        return text

    def _learn_line(self, line, target):
        """Create Markov chains from the provided line.

        Each word becomes a (k1, k2) -> v row keyed by the two preceding
        words (seeded with the start keywords), followed by a row whose value
        is the stop keyword. Rows are tagged with target's context id.

        Returns line when target is set but the line has no words; otherwise None.

        Raises:
            sqlite3.Error: if the insert fails (after rolling back).
        """

        # if there's no target, this is probably a sub-command. don't learn it
        if not target:
            return None

        words = [self._to_unicode(word) for word in line.split()]
        if not words:
            return line

        context_id = self._get_context_id_for_target(target)

        keys = [self._to_unicode(self.start1), self._to_unicode(self.start2)] + words
        rows = [(keys[i], keys[i + 1], word, context_id) for i, word in enumerate(words)]
        # terminate the chain after the last two words
        rows.append((keys[-2], keys[-1], self.stop, context_id))

        db = self.get_db()
        try:
            db.executemany('INSERT INTO markov_chain (k1, k2, v, context_id) VALUES (?, ?, ?, ?)',
                           rows)
            db.commit()
        except sqlite3.Error as e:
            db.rollback()
            print("sqlite error in Markov._learn_line: " + str(e))
            raise
        finally:
            db.close()
        return None

    def _generate_line(self, target, line='', min_size=15, max_size=100):
        """Reply to a line, using some text in the line as a point in the chain.

        A random word of line (if any) is used as a seed: the chain is walked
        backwards from it toward a start, then forwards to a stop. Generation
        avoids stopping before min_size tokens (counting the two start
        keywords) when possible, and never exceeds max_size words.

        Returns the generated text as a UTF-8 encoded byte string.

        Raises:
            Exception: if max_size <= 3 or min_size > 20.
        """

        # if the limit is too low, there's nothing to do
        if max_size <= 3:
            raise Exception("max_size is too small: %d" % max_size)

        # if the min is too large, abort
        if min_size > 20:
            raise Exception("min_size is too large: %d" % min_size)

        # decode up front so seed words can be bound in queries, compared
        # with, and joined to the unicode values stored in the database
        words = self._to_unicode(line).split()
        # get a random word from the input ('' means no target)
        target_word = random.choice(words) if words else ''

        context_id = self._get_context_id_for_target(target)

        # start with an empty chain, and work from there
        gen_words = [self.start1, self.start2]

        # walk a chain, randomly, building the list of words
        while len(gen_words) < max_size + 2 and gen_words[-1] != self.stop:
            # first, see if we have an empty response and a target word.
            # we'll just pick a word and work backwards
            if gen_words[-1] == self.start2 and target_word != '':
                working_backwards = []
                if self._retrieve_k2_for_value(target_word, context_id):
                    working_backwards.append(target_word)
                    # generate new word to look for while walking back
                    target_word = random.choice(words)
                    # work backwards until we randomly bump into a start
                    while True:
                        key_hits = self._retrieve_k2_for_value(working_backwards[0], context_id)
                        if target_word in key_hits:
                            found_word = target_word
                            # only chase one extra target word on the way back
                            target_word = ''
                        else:
                            candidates = [k for k in key_hits if k != self.stop]
                            # with no known predecessor, treat this as the start
                            found_word = random.choice(candidates) if candidates else self.start2

                        if found_word == self.start2 or len(working_backwards) >= max_size + 2:
                            gen_words = gen_words + working_backwards
                            break
                        working_backwards.insert(0, found_word)

            key_hits = self._retrieve_chains_for_key(gen_words[-2], gen_words[-1], context_id)
            # use the chain that includes the target word, if it is found
            if target_word != '' and target_word in key_hits:
                gen_words.append(target_word)
                # generate new word
                target_word = random.choice(words)
            else:
                non_stop_hits = [v for v in key_hits if v != self.stop]
                if len(gen_words) < min_size and non_stop_hits:
                    # too short to end yet: prefer a continuation over a stop
                    gen_words.append(random.choice(non_stop_hits))
                elif not key_hits:
                    gen_words.append(self.stop)
                else:
                    gen_words.append(random.choice(key_hits))

        # chop off the seed data at the start
        gen_words = gen_words[2:]

        # chop off the end text, if it was the keyword indicating an end of chain
        if gen_words and gen_words[-1] == self.stop:
            gen_words = gen_words[:-1]

        return ' '.join(gen_words).encode('utf-8', 'ignore')

    def _retrieve_chains_for_key(self, k1, k2, context_id):
        """Get the value(s) following a given key (a pair of strings).

        Rows in context_id or with no context are considered. For the
        (start1, start2) key only one quasi-random start value is returned,
        or [] when the chain table is empty.
        """

        base_query = ('SELECT v FROM markov_chain WHERE k1 = ? AND k2 = ? AND '
                      '(context_id = ? OR context_id IS NULL)')
        params = (k1, k2, context_id)

        if k1 == self.start1 and k2 == self.start2:
            # hack. get a quasi-random start from the database, in
            # a faster fashion than selecting all starts
            max_id = self._get_max_chain_id()
            if max_id is None:
                # empty brain: there is no start to pick
                return []
            rand_id = random.randint(1, max_id)
            # try the first start at/after rand_id, then wrap around to the
            # last one before it, so a late rand_id doesn't yield nothing
            queries = [
                (base_query + ' AND id >= ? LIMIT 1', params + (rand_id,)),
                (base_query + ' AND id < ? ORDER BY id DESC LIMIT 1', params + (rand_id,)),
            ]
        else:
            queries = [(base_query, params)]

        db = self.get_db()
        try:
            for query, args in queries:
                results = db.execute(query, args).fetchall()
                if results:
                    return [result['v'] for result in results]
            return []
        except sqlite3.Error as e:
            print('sqlite error in Markov._retrieve_chains_for_key: ' + str(e))
            raise
        finally:
            db.close()

    def _retrieve_k2_for_value(self, v, context_id):
        """Get the k2 key(s) of every chain whose value is v.

        Rows in context_id or with no context are considered.
        """

        db = self.get_db()
        try:
            query = 'SELECT k2 FROM markov_chain WHERE v = ? AND (context_id = ? OR context_id IS NULL)'
            results = db.execute(query, (v, context_id)).fetchall()
            return [result['k2'] for result in results]
        except sqlite3.Error as e:
            print('sqlite error in Markov._retrieve_k2_for_value: ' + str(e))
            raise
        finally:
            db.close()

    def _get_chatter_targets(self):
        """Get all possible chatter targets as rows with 'target' and 'chance'."""

        # need to create our own db object, since this is likely going to be in a new thread
        db = self.get_db()
        try:
            query = 'SELECT target, chance FROM markov_chatter_target'
            return db.execute(query).fetchall()
        except sqlite3.Error as e:
            print('sqlite error in Markov._get_chatter_targets: ' + str(e))
            raise
        finally:
            # fetched rows stay usable after the connection is closed
            db.close()

    def _get_one_chatter_target(self):
        """Select one random chatter target, or None if there are none."""

        targets = self._get_chatter_targets()
        if targets:
            return random.choice(targets)
        return None

    def _get_max_chain_id(self):
        """Get the highest id in the chain table, or None if it is empty."""

        db = self.get_db()
        try:
            query = '''
                SELECT id FROM markov_chain ORDER BY id DESC LIMIT 1
                '''
            result = db.execute(query).fetchone()
            return result['id'] if result else None
        except sqlite3.Error as e:
            print('sqlite error in Markov._get_max_chain_id: ' + str(e))
            raise
        finally:
            db.close()

    def _get_context_id_for_target(self, target):
        """Get the context ID mapped to target, or None if it has no mapping."""

        db = self.get_db()
        try:
            query = '''
                SELECT mc.id FROM markov_context mc
                  INNER JOIN markov_target_to_context_map mt
                    ON mt.context_id = mc.id
                    WHERE mt.target = ?
                '''
            result = db.execute(query, (target,)).fetchone()
            return result['id'] if result else None
        except sqlite3.Error as e:
            print('sqlite error in Markov._get_context_id_for_target: ' + str(e))
            raise
        finally:
            db.close()

# vi:tabstop=4:expandtab:autoindent
# kate: indent-mode python;indent-width 4;replace-tabs on;