"""Roll dice when asked, intended for RPGs."""

# Enabling unicode_literals breaks ply.yacc here; it may work under Python 3.
#from __future__ import unicode_literals

import math
import re
import random

from irc.client import NickMask

import ply.lex as lex
import ply.yacc as yacc

from ircbot.lib import Plugin

class Dice(Plugin):

    """Roll simple or complex dice strings."""

    # One !ctech roll spec: "<count>d", an optional "+N"/"-N" modifier, and an
    # optional label after whitespace (e.g. "4d+2 dodge"). Compiled once here
    # instead of on every loop iteration.
    _CTECH_SPEC = re.compile(r'^(\d+)d(?:(\+|\-)(\d+))?(?:\s+(.*))?')

    def start(self):
        """Create the dice roller and register the !roll / !ctech handlers."""

        self.roller = DiceRoller()

        self.connection.reactor.add_global_regex_handler(['pubmsg', 'privmsg'], r'^!roll\s+(.*)$',
                                                         self.handle_roll, -20)
        self.connection.reactor.add_global_regex_handler(['pubmsg', 'privmsg'], r'^!ctech\s+(.*)$',
                                                         self.handle_ctech, -20)

        super(Dice, self).start()

    def stop(self):
        """Unregister the !roll / !ctech handlers added by start()."""

        self.connection.reactor.remove_global_regex_handler(['pubmsg', 'privmsg'], self.handle_roll)
        self.connection.reactor.remove_global_regex_handler(['pubmsg', 'privmsg'], self.handle_ctech)

        super(Dice, self).stop()

    def handle_roll(self, connection, event, match):
        """Handle ``!roll <dice string>`` by replying with the rolled result.

        The dice string (syntax documented on DiceRoller.do_roll) is rolled
        and the reply is prefixed with the caller's nick.
        """

        nick = NickMask(event.source).nick
        dicestr = match.group(1)

        reply = "{0:s}: {1:s}".format(nick, self.roller.do_roll(dicestr))
        # Insert "\x03" "14" (presumably the IRC colour code for grey) in front
        # of each parenthesised roll breakdown. Groups are written as \g<n>
        # because a bare "\2" directly followed by "14" is parsed by re as the
        # octal escape "\214", which emits chr(0o214) and drops group 2 (the
        # text between the total and the breakdown, e.g. the roll label).
        reply = re.sub(r'(\d+)(.*?\s+)(\(.*?\))', '\\g<1>\\g<2>\x0314\\g<3>', reply)
        return self.bot.reply(event, reply)

    def handle_ctech(self, connection, event, match):
        """Handle ``!ctech <spec>[; <spec>...]`` CthulhuTech d10 rolls.

        Each spec is ``<count>d`` with an optional ``+N``/``-N`` modifier and
        an optional label, e.g. ``!ctech 4d+2 dodge; 6d``. For each spec,
        ``count`` d10s are rolled. The reply shows the best result plus the
        modifier (see _ctech_best), or ``BOTCH`` when at least half the dice
        show 1, followed by the dice rolled from high to low. Specs that
        don't match, or that ask for zero dice, are skipped. Nothing is sent
        if no spec was rolled.
        """

        nick = NickMask(event.source).nick
        results = []
        for roll in re.split(r';\s*', match.group(1)):
            spec = self._CTECH_SPEC.search(roll)
            if spec is None:
                continue
            dice = int(spec.group(1))
            if dice < 1:
                # Nothing to roll (and no highest die to inspect).
                continue

            # Groups 2 and 3 are optional together, so group 3 is always set
            # whenever group 2 is.
            modifier = 0
            if spec.group(2) is not None:
                modifier = int(spec.group(3))
                if spec.group(2) == '-':
                    modifier = -modifier

            rolls = sorted((random.randint(1, 10) for _ in range(dice)), reverse=True)

            if self._ctech_botch(rolls):
                outcome = 'BOTCH'
            else:
                outcome = str(self._ctech_best(rolls) + modifier)

            detail = ','.join(str(r) for r in rolls)
            if modifier != 0:
                detail += ' {0:+d}'.format(modifier)
            results.append(roll + ': ' + outcome + ' [' + detail + ']')

        # Joining (rather than appending "; " per roll) means no dangling
        # separator after the last rolled spec.
        if results:
            msg = "{0:s}: {1:s}".format(nick, "; ".join(results))
            return self.bot.reply(event, msg)

    @staticmethod
    def _ctech_best(rolls):
        """Return the best CthulhuTech total for the non-empty list ``rolls``.

        The best of three methods:
        - the highest single die;
        - the largest sum of two or more dice showing the same value (a set);
        - the largest sum of three or more distinct consecutive values (a
          straight).
        Repeated values are ignored when building a straight because only one
        die per value is needed. For example, 5,4,4,3 contains the straight
        5+4+3.
        """
        highest = max(rolls)

        best_set = 0
        for value in set(rolls):
            matching = rolls.count(value)
            if matching > 1:
                best_set = max(best_set, value * matching)

        best_straight = 0
        run = []
        for value in sorted(set(rolls), reverse=True):
            if run and value == run[-1] - 1:
                run.append(value)
            else:
                # Run broken: restart the straight from this value.
                run = [value]
            if len(run) >= 3:
                best_straight = max(best_straight, sum(run))

        return max(highest, best_set, best_straight)

    @staticmethod
    def _ctech_botch(rolls):
        """Return True when at least half of ``rolls`` (rounded up) are 1s."""
        return rolls.count(1) >= math.ceil(len(rolls) / 2.0)


class DiceRoller(object):

    """Parse and roll dice strings (the syntax is documented on do_roll).

    The lexer and parser are generated by PLY from this class. Token rules
    are the ``t_*`` members and grammar rules are the docstrings of the
    ``p_*`` methods, so those strings are code and must not be reworded.
    """

    tokens = ['NUMBER', 'TEXT', 'ROLLSEP']
    literals = ['#', '/', '+', '-', 'd']

    # Comment text: leading whitespace, then everything up to the next ";".
    t_TEXT = r'\s+[^;]+'
    # Separator between independent rolls.
    t_ROLLSEP = r';\s*'

    # Top-level grammar symbol. It is stated explicitly so the grammar does
    # not depend on which p_* method happens to be defined first.
    start = 'roll'

    # Lexer/parser built lazily by build() and reused for every roll.
    _lexer = None
    _parser = None
    # Set by p_error while the current dice string is being parsed.
    _parse_failed = False

    def build(self):
        """Build this instance's PLY lexer and parser and keep them for reuse."""
        self._lexer = lex.lex(module=self)
        self._parser = yacc.yacc(module=self)

    def t_NUMBER(self, t):
        r'\d+'
        t.value = int(t.value)
        return t

    def t_error(self, t):
        # Characters no token matches are silently skipped.
        t.lexer.skip(1)

    precedence = (
        ('left', 'ROLLSEP'),
        ('left', '+', '-'),
        ('right', 'd'),
        ('left', '#'),
        ('left', '/')
    )

    # Result string of the most recent do_roll(), as returned by get_result().
    output = ""

    def roll_dice(self, keep, dice, size):
        """Roll ``dice`` dice with ``size`` faces and sum the ``keep`` highest.

        Returns a ``(total, rolls)`` tuple where ``rolls`` is a string such as
        ``"[3,6,1]"`` listing every die rolled, in roll order.
        """
        rolls = [random.randint(1, size) for _ in range(dice)]
        if keep != dice:
            kept = sorted(rolls, reverse=True)[:keep]
        else:
            kept = rolls
        return (sum(kept), "[" + ",".join(str(r) for r in rolls) + "]")

    def process_roll(self, trials, mods, comment):
        """Roll one roll string (the part up to a ";") and format the result.

        trials -- how many times to repeat the roll, or None for once
        mods -- flat list alternating operands and "+"/"-" operators, as
                built by p_modifier. An operand is either an int constant or
                a die tuple ``(left, faces)``, where ``left`` is None or
                ``[keep, count]`` and ``keep`` may be None.
        comment -- label text from the TEXT token (including its leading
                   whitespace), or None

        Returns ``"<total> <label> (<breakdown>)"`` for a single trial, or
        ``"<total> (<breakdown>), ... (<label>)"`` for several. Each die
        appears in the breakdown as ``<subtotal>[<every die rolled>]``.
        Returns an error message instead if the trial, dice or face count is
        not positive.
        """
        # NOTE(review): there is no upper bound on trials or dice, so a very
        # large count from chat input can take a lot of time and memory.
        repeat = 1 if trials is None else trials
        if repeat < 1:
            return "# of trials is incorrect: %d" % repeat
        # A whitespace-only comment is treated as no comment.
        label = comment.strip() if comment is not None else ""

        trial_results = []
        for _ in range(repeat):
            total = 0
            sign = 1
            parts = []
            for m in mods:
                if isinstance(m, tuple):
                    left, size = m
                    keep, dice = 0, 1
                    if left is not None:
                        if left[0] is not None:
                            keep = left[0]
                        dice = left[1]
                    # A keep count of 0, or larger than the dice count, keeps
                    # every die.
                    if keep > dice or keep == 0:
                        keep = dice
                    if size < 1:
                        return "# of sides for die is incorrect: %d" % size
                    if dice < 1:
                        return "# of dice is incorrect: %d" % dice
                    value, shown = self.roll_dice(keep, dice, size)
                    parts.append("%d%s" % (value, shown))
                elif m == "+":
                    sign = 1
                    parts.append("+")
                    continue
                elif m == "-":
                    sign = -1
                    parts.append("-")
                    continue
                else:
                    value = m
                    parts.append(str(m))
                total += sign * value
            trial_results.append((total, "".join(parts)))

        if repeat == 1:
            total, breakdown = trial_results[0]
            if label:
                return "%d %s (%s)" % (total, label, breakdown)
            return "%d (%s)" % (total, breakdown)

        output = ", ".join("%d (%s)" % result for result in trial_results)
        if label:
            output += " (%s)" % label
        return output

    @staticmethod
    def _as_mod_list(mod):
        """Return ``mod`` as a list, wrapping a lone operand in one."""
        return mod if isinstance(mod, list) else [mod]

    def p_roll_r(self, p):
        # Chain rolls together.

        # General idea I had when creating this grammar: A roll string is a chain
        # of modifiers, which may be repeated for a certain number of trials. It can
        # have a comment that describes the roll
        # Multiple roll strings can be chained with semicolon

        'roll : roll ROLLSEP roll'
        p[0] = p[1] + "; " + p[3]

    def p_roll(self, p):
        # Parse a basic roll string that has an "N#" trial count.

        'roll : trial modifier comment'
        p[0] = self.process_roll(p[1], self._as_mod_list(p[2]), p[3])

    def p_roll_no_trials(self, p):
        # Parse a roll string without trials (rolled once).

        'roll : modifier comment'
        p[0] = self.process_roll(None, self._as_mod_list(p[1]), p[2])

    def p_comment(self, p):
        # Parse an optional comment; None when absent.

        '''comment : TEXT
                   |'''
        p[0] = p[1] if len(p) == 2 else None

    def p_modifier(self, p):
        # Parse a modifier on a roll string.

        '''modifier : modifier "+" modifier
                    | modifier "-" modifier'''
        # Build one flat operand/operator list instead of nested lists, which
        # is what process_roll expects.
        left, op, right = p[1], p[2], p[3]
        if isinstance(left, list):
            left.extend([op, right])
            p[0] = left
        elif isinstance(right, list):
            p[0] = [left, op] + right
        else:
            p[0] = [left, op, right]

    def p_die(self, p):
        # Return the left side before the "d", and the number of faces.

        'modifier : left NUMBER'
        p[0] = (p[1], p[2])

    def p_die_num(self, p):
        'modifier : NUMBER'
        p[0] = p[1]

    def p_left(self, p):
        # Parse the number of dice we are rolling, and how many we are keeping.
        'left : keep dice'
        p[0] = [p[1], p[2]]

    def p_left_all(self, p):
        'left : dice'
        p[0] = [None, p[1]]

    def p_left_e(self, p):
        'left :'
        p[0] = None

    def p_total(self, p):
        'trial : NUMBER "#"'
        p[0] = p[1]

    def p_keep(self, p):
        'keep : NUMBER "/"'
        p[0] = p[1]

    def p_dice(self, p):
        'dice : NUMBER "d"'
        p[0] = p[1]

    def p_dice_one(self, p):
        'dice : "d"'
        p[0] = 1

    def p_error(self, p):
        # Record the failure so do_roll reports it even if PLY's error
        # recovery later manages to reduce a trailing roll.
        self._parse_failed = True
        self.output = "Unable to parse roll"

    def get_result(self):
        """Return the result string of the most recent do_roll() call."""
        return self.output

    def do_roll(self, dicestr):
        """Roll some dice and get the result (with broken out rolls).

        Keyword arguments:
        dicestr - one or more rolls separated by ";", each in the format:
        N#X/YdS+M label
            N#: do the following roll N times (optional)
            X/: take the top X rolls of the Y times rolled (optional)
            Y : roll the die specified Y times (optional, defaults to 1)
            dS: roll a S-sided die
            +M: add M to the result (-M for subtraction) (optional)

        Returns the formatted result string, or "Unable to parse roll" when
        dicestr cannot be parsed. get_result() returns the same string
        afterwards.
        """
        if self._parser is None:
            self.build()
        self._parse_failed = False
        result = self._parser.parse(dicestr, lexer=self._lexer)
        if self._parse_failed or result is None:
            result = "Unable to parse roll"
        self.output = result
        return self.get_result()


# Module-level name for this file's plugin class; presumably the bot's
# plugin loader looks up `plugin` here — verify against ircbot.
plugin = Dice