diff --git a/markov/lib.py b/markov/lib.py index 9e7243a..958541a 100644 --- a/markov/lib.py +++ b/markov/lib.py @@ -164,6 +164,10 @@ def get_word_out_of_states(states, backwards=False): new_word = '' running = 0 count_sum = states.aggregate(Sum('count'))['count__sum'] + if not count_sum: + # this being None probably means there's no data for this context + raise ValueError("no markov states to generate from") + hit = random.randint(0, count_sum) log.debug("sum: {0:d} hit: {1:d}".format(count_sum, hit)) diff --git a/markov/urls.py b/markov/urls.py index aeec473..1fcb1a9 100644 --- a/markov/urls.py +++ b/markov/urls.py @@ -1,11 +1,12 @@ """URL patterns for markov stuff.""" - -from django.conf.urls import url +from django.urls import path from django.views.generic import TemplateView -from markov.views import context_index +from markov.views import context_index, rpc_generate_line_for_context urlpatterns = [ - url(r'^$', TemplateView.as_view(template_name='index.html'), name='markov_index'), - url(r'^context/(?P\d+)/$', context_index, name='markov_context_index'), + path('', TemplateView.as_view(template_name='index.html'), name='markov_index'), + path('context//', context_index, name='markov_context_index'), + + path('rpc/context//generate/', rpc_generate_line_for_context, name='markov_rpc_generate_line'), ] diff --git a/markov/views.py b/markov/views.py index 2975866..5e3f87b 100644 --- a/markov/views.py +++ b/markov/views.py @@ -1,16 +1,19 @@ """Manipulate Markov data via the Django site.""" - +import json import logging import time from django.http import HttpResponse from django.shortcuts import get_object_or_404, render +from rest_framework.authentication import BasicAuthentication +from rest_framework.permissions import IsAuthenticated +from rest_framework.response import Response +from rest_framework.decorators import api_view, authentication_classes, permission_classes import markov.lib as markovlib from markov.models import MarkovContext - -log = logging.getLogger('markov.views') +logger = logging.getLogger(__name__) def index(request): @@ -28,3 +31,33 @@ def context_index(request, context_id): end_t = time.time() return render(request, 'markov/context.html', {'chain': chain, 'context': context, 'elapsed': end_t - start_t}) + + +@api_view(['POST']) +@authentication_classes((BasicAuthentication, )) +@permission_classes((IsAuthenticated, )) +def rpc_generate_line_for_context(request, context): + """Generate a line from a given context, with optional topics included.""" + if request.method != 'POST': + return Response({'detail': "Supported method: POST."}, status=405) + + topics = None + try: + if request.body: + markov_data = json.loads(request.body) + topics = markov_data.get('topics', []) + except (json.decoder.JSONDecodeError, KeyError): + return Response({'detail': "Request body, if provided, must be JSON with an optional 'topics' parameter."}, + status=400) + + context_id = markovlib.get_or_create_target_context(context) + try: + generated_words = markovlib.generate_line(context_id, topics) + except ValueError as vex: + return Response({'detail': f"Could not generate line: {vex}", 'context': context, 'topics': topics}, + status=400) + else: + return Response({ + 'context': context, 'topics': topics, + 'generated_line': ' '.join(generated_words), 'generated_words': generated_words + })