"""Provide methods for manipulating markov chain processing."""
import logging
import random

from django.db.models import Sum

from markov.models import MarkovState

# Module-level logger; every message in this module is emitted at DEBUG level.
log = logging.getLogger(__name__)


def generate_line(context, topics=None, min_words=15, max_words=30, sentence_bias=2, max_tries=5):
    """Build one line of output by chaining generated sentences together.

    Each sentence is requested from generate_longish_sentence() with a minimum length of
    ``min_words / sentence_bias``. Sentences are appended until the line holds at least
    ``min_words`` words or ``max_tries`` sentences have been requested.

    :param context: passed through to sentence generation
    :param topics: optional list of topic words, passed through to sentence generation
    :param min_words: number of words the whole line should reach
    :param max_words: per-sentence maximum, passed through to sentence generation
    :param sentence_bias: divisor applied to ``min_words`` for each sentence's minimum
    :param max_tries: how many sentences to request at most (also passed through as the
                      per-sentence retry budget)
    :returns: list of words; may be shorter than ``min_words`` if every attempt was used up
    """
    terminators = (',', '.', '!', '?', ':')
    per_sentence_min = min_words / sentence_bias
    line = []
    attempt = 0
    while attempt < max_tries:
        line.extend(generate_longish_sentence(context, topics=topics, min_words=per_sentence_min,
                                              max_words=max_words, max_tries=max_tries))
        if len(line) >= min_words:
            return line

        # still too short, so another sentence may be appended: make sure the current one
        # ends with some punctuation before that happens
        if line and line[-1][-1] not in terminators:
            line[-1] = line[-1] + random.SystemRandom().choice(['?', '.', '!'])
        attempt += 1

    # ran out of attempts; return whatever was collected
    return line


def generate_longish_sentence(context, topics=None, min_words=15, max_words=30, max_tries=100):
    """Generate a Markov chain, but throw away the short ones unless we get desperate.

    generate_sentence() is called up to ``max_tries`` times. The first result with at least
    ``min_words`` words is returned. If none qualifies, the last (short) attempt is returned
    instead of nothing.

    :param context: passed through to generate_sentence()
    :param topics: optional list of topic words, passed through to generate_sentence()
                   (which may remove used words from it)
    :param min_words: minimum number of words for a sentence to be accepted
    :param max_words: passed through to generate_sentence()
    :param max_tries: how many sentences to generate at most
    :returns: list of words; an empty list if ``max_tries`` is not positive
    """
    # start with an empty *list* so the return type matches generate_sentence() even when
    # no attempt is made at all (previously this returned the empty string)
    sent = []
    tries = 0
    while tries < max_tries:
        sent = generate_sentence(context, topics=topics, min_words=min_words, max_words=max_words)
        if len(sent) >= min_words:
            log.debug("found a longish sentence, %s", sent)
            return sent

        log.debug("%s isn't long enough, going to try again", sent)
        tries += 1

    # if we got here, we need to just give up
    return sent


def generate_sentence(context, topics=None, min_words=15, max_words=30):
    """Generate a single Markov chain sentence.

    When ``topics`` is non-empty, one topic word is picked at random and removed from
    ``topics``. If that word appears as a ``v`` value in ``context``, the chain is first built
    backwards from it until a sentence start is reached, then continued forwards. Otherwise
    the chain is built forwards from the start markers.

    Building forwards follows these rules, in this order:
    - Once the chain (start markers included) is longer than ``min_words`` tokens and a stop
      is available, the stop is taken.
    - Otherwise, a next word that is also a topic is preferred; it is removed from ``topics``.
    - Otherwise, while the chain is at most ``min_words`` tokens long, stop is avoided.
    - Otherwise, a weighted random next word is picked.

    :param context: value used to filter MarkovState rows (``context=`` lookup)
    :param topics: optional list of words to steer towards; mutated in place -- topic
                   words that get used are removed from it
    :param min_words: chain length (counted including start markers) up to which stops
                      are avoided
    :param max_words: sentences with more than this many real words are abandoned
    :returns: list of words with the start/stop markers removed, or ``[]`` if the sentence
              came out longer than ``max_words``
    :raises ValueError: may propagate from get_word_out_of_states() when there are no
                        states to pick from
    """
    markers = (MarkovState._start1, MarkovState._start2, MarkovState._stop)

    def word_count(chain):
        # number of real words in the chain, not counting start/stop markers
        return sum(1 for word in chain if word not in markers)

    words = []
    # if we have topics, try to work from one and work backwards
    if topics:
        topic_word = random.SystemRandom().choice(topics)
        topics.remove(topic_word)
        log.debug("looking for topic '%s'", topic_word)

        if MarkovState.objects.filter(context=context, v=topic_word).exists():
            log.debug("found '%s', starting backwards", topic_word)
            words.insert(0, topic_word)
            while word_count(words) <= max_words and words[0] != MarkovState._start2:
                log.debug("looking backwards for '%s'", words[0])
                new_states = MarkovState.objects.filter(context=context, v=words[0])
                # if the word is known to begin a sentence, use the start. new_states holds
                # model instances, so the marker must be looked up on the k2 field -- a plain
                # `marker in new_states` test compares a string to models and never matches
                if new_states.filter(k2=MarkovState._start2).exists():
                    log.debug("found a start2 in the results, intentionally picking it")
                    words.insert(0, MarkovState._start2)
                else:
                    words.insert(0, get_word_out_of_states(new_states, backwards=True))
                    log.debug("picked %s", words[0])

    # if what we found is too long, abandon it, sadly (the start marker doesn't count)
    if word_count(words) > max_words:
        log.debug("%s is too long, i'm going to give up on it", words)
        words = []

    # if we didn't get topic stuff, we need to start (forwards) here, otherwise we use
    # what we already put together (obviously)
    if not words:
        words = [MarkovState._start1, MarkovState._start2]

    while words[-1] != MarkovState._stop:
        if word_count(words) > max_words:
            # the chain only ever grows, so this sentence would be abandoned at the end
            # anyway; give up now instead of continuing to query for a stop
            log.debug("%s is too long, i'm going to give up on it", words)
            return []

        log.debug("looking for '%s','%s'", words[-2], words[-1])
        new_states = MarkovState.objects.filter(context=context, k1=words[-2], k2=words[-1])
        log.debug("states retrieved")

        # candidate next words that are also in our targets
        if topics:
            target_hits = list(set(new_states.filter(v__in=topics).values_list('v', flat=True)))
        else:
            target_hits = []

        if len(words) > min_words and new_states.filter(v=MarkovState._stop).exists():
            # if we're over min_words, and got a stop naturally, use it
            log.debug("found a stop in the results, intentionally picking it")
            words.append(MarkovState._stop)
        elif target_hits:
            # if there's a target word in the states, pick it
            target_hit = random.SystemRandom().choice(target_hits)
            log.debug("found a topic hit %s, using it", target_hit)
            topics.remove(target_hit)
            words.append(target_hit)
        elif len(words) <= min_words:
            # if we still need more words, intentionally avoid stop
            words.append(get_word_out_of_states(new_states.exclude(v=MarkovState._stop)))
            log.debug("picked (stop avoidance) %s", words[-1])
        else:
            words.append(get_word_out_of_states(new_states))
            log.debug("picked %s", words[-1])

    words = [word for word in words if word not in markers]

    # if what we found is too long, abandon it, sadly
    if len(words) > max_words:
        log.debug("%s is too long, i'm going to give up on it", words)
        words.clear()

    return words


def get_word_out_of_states(states, backwards=False):
    """Pick one random word out of the given states, weighted by each state's ``count``.

    :param states: QuerySet of MarkovState rows to choose from
    :param backwards: if True, return the chosen state's ``k2`` (the word preceding ``v``);
                      otherwise return its ``v`` (the word following ``k1``, ``k2``)
    :returns: the chosen word, or ``''`` if iterating the states selected nothing
    :raises ValueError: if the summed ``count`` of the states to pick from is None or 0
    """
    # work around possible broken data, where a k1,k2 should have a value but doesn't
    # NOTE(review): this fallback is not restricted to the caller's context. Forwards it
    # still yields the stop marker, but backwards it returns k2 from an arbitrary stop state.
    if not states.exists():
        states = MarkovState.objects.filter(v=MarkovState._stop)

    count_sum = states.aggregate(Sum('count'))['count__sum']
    if not count_sum:
        # this being None probably means there's no data for this context
        raise ValueError("no markov states to generate from")

    # randint is inclusive at both ends. Drawing from 1..count_sum gives every state exactly
    # ``count`` chances; allowing 0 would give the first state one extra chance.
    hit = random.SystemRandom().randint(1, count_sum)
    log.debug("sum: %s hit: %s", count_sum, hit)

    new_word = ''
    running = 0
    for state in states.iterator():
        running += state.count
        if running >= hit:
            new_word = state.k2 if backwards else state.v
            break

    log.debug("found '%s'", new_word)
    return new_word


def learn_line(line, context):
    """Create a bunch of MarkovStates for a given line of text.

    The line is split on whitespace and wrapped in the start/stop markers. For every run of
    three consecutive tokens (k1, k2, v), the matching MarkovState for ``context`` is fetched
    or created and its ``count`` is incremented and saved.

    If any token is longer than the ``k1`` field's max_length, the whole line is skipped and
    nothing is written.

    :param line: text to learn from
    :param context: value stored in each state's ``context`` field
    """
    log.debug("learning %s...", line[:40])

    words = [MarkovState._start1, MarkovState._start2] + line.split() + [MarkovState._stop]

    # NOTE(review): only k1's max_length is checked; this assumes k2 and v share it -- confirm
    max_length = MarkovState._meta.get_field('k1').max_length
    if any(len(word) > max_length for word in words):
        return

    # walk every consecutive (k1, k2, v) triple; zip stops at the shortest slice, so the last
    # triple ends on the stop marker
    for k1, k2, v in zip(words, words[1:], words[2:]):
        log.debug("'%s','%s' -> '%s'", k1, k2, v)
        state, _ = MarkovState.objects.get_or_create(context=context, k1=k1, k2=k2, v=v)
        state.count += 1
        state.save()