From 74c03cff88b744c6b47e8289cdeff4ed9f85f68d Mon Sep 17 00:00:00 2001
From: "Brian S. Stephan"
Date: Wed, 15 Jun 2011 20:40:24 -0500
Subject: [PATCH] update markov chain import script for always using a context,
 specified on command line

also read stdin rather than a file for lines
---
 scripts/import-file-into-markov_chain.py | 15 ++++++++++-----
 1 file changed, 10 insertions(+), 5 deletions(-)

diff --git a/scripts/import-file-into-markov_chain.py b/scripts/import-file-into-markov_chain.py
index 9657222..8f415c9 100644
--- a/scripts/import-file-into-markov_chain.py
+++ b/scripts/import-file-into-markov_chain.py
@@ -16,17 +16,22 @@ You should have received a copy of the GNU General Public License
 along with this program.  If not, see <http://www.gnu.org/licenses/>.
 """
 
-import fileinput
+import argparse
 import os
 import sqlite3
 import sys
 
+parser = argparse.ArgumentParser(description='Import lines into the specified context.')
+parser.add_argument('context', metavar='CONTEXT', type=str, nargs=1)
+args = parser.parse_args()
+print(args.context[0])
+
 db = sqlite3.connect('dr.botzo.data')
 db.row_factory = sqlite3.Row
 cur = db.cursor()
 
-statement = 'INSERT INTO markov_chain (k1, k2, v) VALUES (?, ?, ?)'
-for line in fileinput.input():
+statement = 'INSERT INTO markov_chain (k1, k2, v, context) VALUES (?, ?, ?, ?)'
+for line in sys.stdin:
     # set up the head of the chain
     w1 = '__start1'
     w2 = '__start2'
@@ -34,7 +39,7 @@ for line in fileinput.input():
     # for each word pair, add the next word to the dictionary
     for word in line.split():
         try:
-            cur.execute(statement, (w1.decode('utf-8', 'replace').lower(), w2.decode('utf-8', 'replace').lower(), word.decode('utf-8', 'replace').lower()))
+            cur.execute(statement, (w1.decode('utf-8', 'replace').lower(), w2.decode('utf-8', 'replace').lower(), word.decode('utf-8', 'replace').lower(), args.context[0]))
         except sqlite3.Error as e:
             db.rollback()
             print("sqlite error: " + str(e))
@@ -43,7 +48,7 @@ for line in fileinput.input():
         w1, w2 = w2, word
 
     try:
-        cur.execute(statement, (w1.decode('utf-8', 'replace').lower(), w2.decode('utf-8', 'replace').lower(), '__stop'))
+        cur.execute(statement, (w1.decode('utf-8', 'replace').lower(), w2.decode('utf-8', 'replace').lower(), '__stop', args.context[0]))
         db.commit()
     except sqlite3.Error as e:
         db.rollback()