import logging
import random

from django.db.models import Sum

from markov.models import MarkovContext, MarkovState, MarkovTarget


log = logging.getLogger('markov.lib')


def generate_line(context, topics=None, min_words=15, max_words=30, max_sentences=3):
    """String multiple sentences together into a coherent sentence.

    Sentences are generated one at a time (at most five attempts) and appended
    to a single word list. The list is returned as soon as it holds at least
    ``min_words`` words or ``max_sentences`` sentences have been produced.
    Between sentences, the current last word gets a random ',', '.' or '!'
    appended, unless it already ends with one of those characters.

    :param context: the context whose Markov states are sampled.
    :param topics: optional list of topic words, passed through to the
        sentence generator (which removes the topics it consumes from it).
    :param min_words: stop once the line has at least this many words.
    :param max_words: per-sentence length cap passed to the generator.
    :param max_sentences: stop after this many sentences.
    :returns: list of words.
    """
    punctuation = (',', '.', '!')
    line = []
    for sentences in range(1, 6):
        line += generate_longish_sentence(context, topics=topics, max_words=max_words)
        if sentences >= max_sentences or len(line) >= min_words:
            return line

        # A generated sentence may be empty (an empty learned line produces a
        # start -> stop state), so guard against an empty line before looking
        # at its last word; endswith() also copes with an empty last word.
        if line and not line[-1].endswith(punctuation):
            line[-1] += random.choice(punctuation)

    # if we got here, we need to give up
    return line


def generate_longish_sentence(context, topics=None, min_words=4, max_words=30):
    """Generate a Markov chain, but throw away the short ones unless we get desperate.

    Up to five sentences are generated with the given ``topics`` and
    ``max_words``. The first one with at least ``min_words`` words is returned.
    If none qualifies, one final sentence is generated without topics and is
    returned whatever its length.

    :param context: the context whose Markov states are sampled.
    :param topics: optional list of topic words, passed to each attempt.
    :param min_words: minimum acceptable sentence length in words.
    :param max_words: length cap passed to the sentence generator.
    :returns: list of words.
    """
    for _ in range(5):
        sent = generate_sentence(context, topics=topics, max_words=max_words)
        if len(sent) >= min_words:
            return sent

    # If we got here, we need to just give up. Drop the topics, but still
    # honour the caller's length cap; previously this path silently fell back
    # to generate_sentence's default max_words.
    return generate_sentence(context, max_words=max_words)


def generate_sentence(context, topics=None, max_words=30):
    """Generate a Markov chain.

    If ``topics`` is non-empty, one topic word is picked at random and removed
    from that list. Note that this mutates the caller's list. If the word
    occurs as a ``v`` in ``context``, the chain is first built backwards from
    it until the ``_start2`` sentinel or the length cap is reached. The chain
    is then extended forwards from its last two words until the ``_stop``
    sentinel or the cap is reached.

    :param context: the context whose Markov states are sampled.
    :param topics: optional list of candidate topic words (mutated).
    :param max_words: length cap checked before each step. It counts sentinel
        tokens too, so the raw chain can reach ``max_words + 1`` entries.
    :returns: list of words with the start/stop sentinels removed.
    """
    words = []
    # if we have topics, try to work from it and work backwards
    if topics:
        topic_word = random.choice(topics)
        topics.remove(topic_word)
        log.debug(u"looking for topic '{0:s}'".format(topic_word))

        # exists() asks the database for a single row instead of loading
        # every matching state just to count them
        if MarkovState.objects.filter(context=context, v=topic_word).exists():
            log.debug(u"found '{0:s}', starting backwards".format(topic_word))
            words.insert(0, topic_word)
            while len(words) <= max_words and words[0] != MarkovState._start2:
                log.debug(u"looking backwards for '{0:s}'".format(words[0]))
                new_states = MarkovState.objects.filter(context=context, v=words[0])
                prev_word = get_word_out_of_states(new_states, backwards=True)
                if not prev_word:
                    # nothing recorded before this word; stop rather than
                    # prepend an empty string
                    break
                words.insert(0, prev_word)

    # if we didn't get topic stuff, we need to start (forwards) here
    if not words:
        words = [MarkovState._start1, MarkovState._start2]

    # Extend forwards keyed on the last two words. At least two words are
    # needed to form the (k1, k2) key.
    while 2 <= len(words) <= max_words and words[-1] != MarkovState._stop:
        log.debug(u"looking for '{0:s}','{1:s}'".format(words[-2], words[-1]))
        new_states = MarkovState.objects.filter(context=context, k1=words[-2], k2=words[-1])
        log.debug(u"states retrieved")
        next_word = get_word_out_of_states(new_states)
        if not next_word:
            # no continuation recorded for this pair
            break
        words.append(next_word)

    sentinels = (MarkovState._start1, MarkovState._start2, MarkovState._stop)
    return [word for word in words if word not in sentinels]


def get_or_create_target_context(target_name):
    """Return the context for a provided nick/channel, creating missing ones.

    - If no MarkovTarget has ``name == target_name``, a MarkovContext with that
      same name is fetched or created first. A target pointing at it is then
      fetched or created.
    - If the target exists but reading its context raises
      MarkovContext.DoesNotExist, a same-named context is fetched or created,
      assigned to the target, and the target is saved.

    NOTE(review): if the context relation can be null, a target with no
    context returns None here; confirm against the model definition.

    :param target_name: nick or channel name.
    :returns: the target's context.
    """
    try:
        target = MarkovTarget.objects.get(name=target_name)
    except MarkovTarget.DoesNotExist:
        # The context has to exist before the target can reference it. Lacking
        # a better idea, name it after the target until configured otherwise.
        new_context, _ = MarkovContext.objects.get_or_create(name=target_name)
        new_target, _ = MarkovTarget.objects.get_or_create(name=target_name, context=new_context)
        return new_target.context

    try:
        return target.context
    except MarkovContext.DoesNotExist:
        # the target exists but its context is gone: re-point it at a context
        # named after the target
        fallback_context, _ = MarkovContext.objects.get_or_create(name=target_name)
        target.context = fallback_context
        target.save()
        return target.context


def get_word_out_of_states(states, backwards=False):
    """Pick one random word out of the given states, weighted by ``count``.

    :param states: MarkovState queryset to choose from. It is aggregated
        (``Sum('count')``) and then iterated, so it is queried twice.
    :param backwards: if True, return the chosen state's ``k2`` (the word
        preceding its ``v``); otherwise return its ``v``.
    :returns: the chosen word, or ``''`` when ``states`` is empty.
    """
    new_word = ''
    running = 0
    # Sum over an empty queryset aggregates to None, not 0
    count_sum = states.aggregate(Sum('count'))['count__sum']
    if count_sum is None:
        log.debug(u"no states to pick from")
        return new_word

    # Draw from [1, count_sum] so that each state wins with probability
    # count / count_sum. Drawing from 0 (as before) gave the first state one
    # extra slot. If every count is 0 there is nothing to weight by, so a hit
    # of 0 keeps the old behaviour of taking the first state.
    hit = random.randint(1, count_sum) if count_sum > 0 else 0

    log.debug(u"sum: {0:d} hit: {1:d}".format(count_sum, hit))

    for state in states.iterator():
        running += state.count
        if running >= hit:
            new_word = state.k2 if backwards else state.v
            break

    log.debug(u"found '{0:s}'".format(new_word))
    return new_word


def learn_line(line, context):
    """Create a bunch of MarkovStates for a given line of text.

    The line is split on whitespace and wrapped in the start/stop sentinel
    tokens. Then every consecutive (k1, k2, v) triple has its state fetched or
    created in ``context``, and its ``count`` is incremented and saved. If any
    token is longer than the ``k1`` field's ``max_length``, the whole line is
    skipped and nothing is stored.

    :param line: text to learn from.
    :param context: the context the states belong to.
    """
    log.debug(u"learning %s...", line[:40])

    tokens = [MarkovState._start1, MarkovState._start2]
    tokens.extend(line.split())
    tokens.append(MarkovState._stop)

    limit = MarkovState._meta.get_field('k1').max_length
    if any(len(token) > limit for token in tokens):
        return

    # Sliding window of three: yields len(tokens) - 2 triples, from
    # (start1, start2, first word) through (..., last word, stop).
    for first, second, third in zip(tokens, tokens[1:], tokens[2:]):
        log.debug(u"'{0:s}','{1:s}' -> '{2:s}'".format(first, second, third))
        state, _ = MarkovState.objects.get_or_create(context=context,
                                                     k1=first,
                                                     k2=second,
                                                     v=third)
        state.count += 1
        state.save()