This puts additional pressure on the sentence generator, which retries many times to get a sentence that is long, but not too long. It has only been tested on a small context so far, so it is not ready to go live yet, but the results so far are pretty good.
204 lines
7.1 KiB
Python
204 lines
7.1 KiB
Python
import logging
|
|
import random
|
|
|
|
from django.db.models import Sum
|
|
|
|
from markov.models import MarkovContext, MarkovState, MarkovTarget
|
|
|
|
|
|
# module-level logger used throughout this file
log = logging.getLogger('markov.lib')
|
|
|
|
|
|
def generate_line(context, topics=None, min_words=15, max_words=30, max_sentences=3, max_tries=5):
    """String multiple sentences together into one line of output.

    Sentences are appended until the line has at least ``min_words`` words or
    ``max_sentences`` sentences have been generated (in at most ``max_tries``
    rounds), whichever comes first. Before another sentence is appended, the
    previous one gets a punctuation mark if it doesn't already end with one.

    :param context: the MarkovContext to generate from
    :param topics: optional list of topic words to steer generation towards; a
        copy is used, so the caller's list is left untouched
    :param min_words: stop adding sentences once the line has this many words
    :param max_words: maximum length of each individual sentence
    :param max_sentences: maximum number of sentences to string together
    :param max_tries: maximum number of rounds; also passed down as the
        per-sentence retry budget
    :return: list of words (empty if nothing could be generated)
    """
    # generate_sentence() removes topics from the list as it uses them; that should
    # only affect this line, not the caller's list
    topics = list(topics) if topics else None

    # each sentence only needs to carry its share of the line's minimum length
    # (max() guards the division against a non-positive max_sentences)
    min_words_per_sentence = min_words / max(max_sentences, 1)

    # words already ending in one of these don't need more punctuation
    terminal = (',', '.', '!', '?')

    line = []
    for sentences in range(1, max_tries + 1):
        line += generate_longish_sentence(context, topics=topics, min_words=min_words_per_sentence,
                                          max_words=max_words, max_tries=max_tries)

        if sentences >= max_sentences or len(line) >= min_words:
            return line

        # another sentence is coming, so close this one off -- unless nothing usable
        # has been generated yet (empty line, or an empty final word)
        if line and line[-1] and not line[-1].endswith(terminal):
            line[-1] += random.choice([',', '.', '!'])

    # if we got here, we need to give up
    return line
|
|
|
|
|
|
def generate_longish_sentence(context, topics=None, min_words=15, max_words=30, max_tries=100):
    """Generate a Markov chain, but throw away the short ones unless we get desperate.

    Up to ``max_tries`` sentences are generated, and the first one with at least
    ``min_words`` words is returned. If none is long enough, the last attempt is
    returned anyway.

    Each attempt works on its own copy of ``topics``, so rejected attempts don't
    use up topic words. Only the returned attempt's topic usage is written back
    to the caller's list.

    :param context: the MarkovContext to generate from
    :param topics: optional list of topic words, updated in place as described above
    :param min_words: minimum number of words for a sentence to be accepted
    :param max_words: maximum number of words per sentence
    :param max_tries: how many sentences to generate before settling
    :return: list of words; empty if ``max_tries`` is not positive or nothing was generated
    """
    sent = []
    attempt_topics = topics

    for _ in range(max_tries):
        attempt_topics = list(topics) if topics else topics
        sent = generate_sentence(context, topics=attempt_topics, min_words=min_words, max_words=max_words)
        if len(sent) >= min_words:
            break

    # if no attempt was long enough we just give up and use the last one; either
    # way, keep the topic usage of the attempt we're actually returning
    if topics and attempt_topics is not topics:
        topics[:] = attempt_topics

    return sent
|
|
|
|
|
|
def generate_sentence(context, topics=None, min_words=15, max_words=30):
    """Generate a Markov chain.

    If ``topics`` is non-empty, one topic word is picked at random (and removed
    from ``topics``). If that word occurs in ``context``, the chain is first walked
    backwards from it to a sentence start and then continued forwards. Otherwise,
    the chain is built forwards from the start markers.

    While walking forwards:

    * once the chain has at least ``min_words`` words, a stop is taken as soon as
      one is a possible next word;
    * otherwise, a possible next word that is one of the remaining ``topics`` is
      preferred (and removed from ``topics``);
    * otherwise, a next word is drawn at random.

    NOTE: ``topics`` is modified in place as topic words are used.

    :param context: the MarkovContext to generate from
    :param topics: optional list of topic words to steer towards
    :param min_words: length after which an available stop is taken immediately
    :param max_words: chains longer than this are abandoned
    :return: list of words without start/stop markers; empty if the chain got
        longer than ``max_words``
    """
    words = _walk_backwards_from_topic(context, topics, max_words) if topics else []

    # if we didn't get topic stuff, we need to start (forwards) here, otherwise we use
    # what we already put together (obviously)
    if len(words) < 2:
        words = [MarkovState._start1, MarkovState._start2]

    markers = (MarkovState._start1, MarkovState._start2, MarkovState._stop)
    # how many real (non-marker) words the chain has so far
    real_count = sum(1 for word in words if word not in markers)

    while words[-1] != MarkovState._stop:
        if real_count > max_words:
            # this chain is going to be abandoned below anyway; skip the remaining queries
            log.debug("%s is already too long, not extending it", words)
            break

        log.debug("looking for '%s','%s'", words[-2], words[-1])
        new_states = MarkovState.objects.filter(context=context, k1=words[-2], k2=words[-1])
        # check against the candidate words themselves: a plain `x in new_states`
        # compares a string with model instances and is never true
        next_words = {state.v for state in new_states}
        log.debug("states retrieved")

        # try to find candidate next words that are in our targets
        target_hits = next_words.intersection(topics) if topics else set()

        # if we've reached min_words, and got a stop naturally, use it
        if real_count >= min_words and MarkovState._stop in next_words:
            log.debug("found a stop in the results, intentionally picking it")
            next_word = MarkovState._stop
        elif target_hits:
            next_word = random.choice(sorted(target_hits))
            log.debug("found a topic hit %s, using it", next_word)
            topics.remove(next_word)
        else:
            next_word = get_word_out_of_states(new_states)

        words.append(next_word)
        log.debug("picked %s", next_word)
        if next_word not in markers:
            real_count += 1

    words = [word for word in words if word not in markers]

    # if what we found is too long, abandon it, sadly
    if len(words) > max_words:
        log.debug("%s is too long, i'm going to give up on it", words)
        words.clear()

    return words


def _walk_backwards_from_topic(context, topics, max_words):
    """Pick (and remove) a random topic word and walk backwards from it to a start.

    :param context: the MarkovContext to walk in
    :param topics: non-empty list of topic words; the chosen one is removed
    :param max_words: walks with more real words than this are abandoned
    :return: ``[_start2, ..., topic_word]``; or an empty list if the topic doesn't
        occur in ``context`` or the walk got longer than ``max_words``
    """
    topic_word = random.choice(topics)
    topics.remove(topic_word)
    log.debug("looking for topic '%s'", topic_word)

    if not MarkovState.objects.filter(context=context, v=topic_word).exists():
        return []

    log.debug("found '%s', starting backwards", topic_word)
    words = [topic_word]
    while words[0] != MarkovState._start2 and len(words) <= max_words:
        log.debug("looking backwards for '%s'", words[0])
        new_states = MarkovState.objects.filter(context=context, v=words[0])
        previous_words = {state.k2 for state in new_states}

        # if we find a start, use it
        if MarkovState._start2 in previous_words:
            log.debug("found a start2 in the results, intentionally picking it")
            words.insert(0, MarkovState._start2)
        else:
            words.insert(0, get_word_out_of_states(new_states, backwards=True))
        log.debug("picked %s", words[0])

    # the leading start2 marker (if we reached one) doesn't count towards the length
    real_count = len(words) - (1 if words[0] == MarkovState._start2 else 0)
    if real_count > max_words:
        # if what we found is too long, abandon it, sadly
        log.debug("%s is too long, i'm going to give up on it", words)
        return []

    return words
|
|
|
|
|
|
def get_or_create_target_context(target_name):
    """Return the context for a provided nick/channel, creating missing ones.

    ``target_name`` is lowercased before lookup. If no MarkovTarget with that name
    exists, a MarkovContext and a MarkovTarget are both fetched-or-created under
    that name. If the target exists but its context is missing, a MarkovContext
    named after the target is fetched-or-created and attached to it.

    :param target_name: nick or channel name
    :return: the MarkovContext for the target
    """
    target_name = target_name.lower()

    # find the stuff, or create it
    try:
        target = MarkovTarget.objects.get(name=target_name)
    except MarkovTarget.DoesNotExist:
        # we need to create a context and a target, and we have to make the context first;
        # lacking a good idea, name the context after the target until configured otherwise
        context, _ = MarkovContext.objects.get_or_create(name=target_name)
        target, _ = MarkovTarget.objects.get_or_create(name=target_name, context=context)
        return target.context

    # a dangling relation raises DoesNotExist; treat an unset (None) one the same way
    # in case the foreign key is nullable -- NOTE(review): confirm against the model
    try:
        context = target.context
    except MarkovContext.DoesNotExist:
        context = None

    if context is None:
        # lacking a good idea, just create one with this target name until configured otherwise
        context, _ = MarkovContext.objects.get_or_create(name=target_name)
        target.context = context
        target.save()

    return context
|
|
|
|
|
|
def get_word_out_of_states(states, backwards=False):
    """Pick one random word out of the given states, weighted by their ``count``.

    :param states: a MarkovState queryset to choose from; if it is empty, all
        states whose ``v`` is the stop marker are used instead
    :param backwards: if True, return the chosen state's ``k2`` instead of its ``v``
    :return: the chosen word. If there is nothing to choose from at all, returns
        ``MarkovState._start2`` (backwards) or ``MarkovState._stop`` (forwards),
        so that the caller's walk ends.
    """
    # work around possible broken data, where a k1,k2 should have a value but doesn't
    if not states.exists():
        # NOTE(review): this fallback is not restricted to the caller's context
        states = MarkovState.objects.filter(v=MarkovState._stop)

    count_sum = states.aggregate(Sum('count'))['count__sum'] or 0
    if count_sum <= 0:
        new_word = MarkovState._start2 if backwards else MarkovState._stop
        log.warning("no weighted states to pick from, ending the chain with '%s'", new_word)
        return new_word

    # hit is in 1..count_sum (inclusive), so every state owns exactly `count` of the
    # possible hits; starting at 0 would give the first state one extra share
    hit = random.randint(1, count_sum)
    log.debug("sum: %d hit: %d", count_sum, hit)

    new_word = ''
    running = 0
    for state in states.iterator():
        running += state.count
        if running >= hit:
            new_word = state.k2 if backwards else state.v
            break

    log.debug("found '%s'", new_word)
    return new_word
|
|
|
|
|
|
def learn_line(line, context):
    """Create or update the MarkovStates for a given line of text.

    The line is split on whitespace and wrapped in the start/stop markers. Then
    the ``count`` of every consecutive (k1, k2) -> v word triple is incremented in
    ``context``, creating the MarkovState when it doesn't exist yet.

    Nothing is learned from a line without any words, or from a line containing a
    word longer than the ``k1`` field's ``max_length``.

    :param line: the text to learn from
    :param context: the MarkovContext to learn into
    """
    log.debug("learning %s...", line[:40])

    words = line.split()
    if not words:
        # an empty line would only teach "start -> stop", which just makes empty
        # sentences more likely when generating
        return

    words = [MarkovState._start1, MarkovState._start2] + words + [MarkovState._stop]

    # one overlong word and the whole line is skipped
    # NOTE(review): assumes k2 and v share k1's max_length -- verify against the model
    max_length = MarkovState._meta.get_field('k1').max_length
    if any(len(word) > max_length for word in words):
        log.debug("skipping line with a word longer than %d characters", max_length)
        return

    for k1, k2, v in zip(words, words[1:], words[2:]):
        log.debug("'%s','%s' -> '%s'", k1, k2, v)
        state, _ = MarkovState.objects.get_or_create(context=context, k1=k1, k2=k2, v=v)
        state.count += 1
        state.save()
|