Compare commits

..

1 Commits

Author SHA1 Message Date
3ca921ad02 Cleaned up code and refactored to new file 2025-10-14 10:31:34 -07:00
2 changed files with 61 additions and 56 deletions

View File

@@ -1,13 +1,11 @@
import os import os
import logging import logging
from itertools import chain from itertools import chain
import csv
from typing import Generator
import click import click
from .inference import InferenceContext, load_ucs from .inference import InferenceContext, load_ucs
from .import_csv import csv_to_data
from .gather import (build_sentence_class_dataset, print_dataset_stats, from .gather import (build_sentence_class_dataset, print_dataset_stats,
ucs_definitions_generator, scan_metadata, walk_path) ucs_definitions_generator, scan_metadata, walk_path)
from .recommend import print_recommendation from .recommend import print_recommendation
@@ -22,6 +20,7 @@ formatter = logging.Formatter(
stream_handler.setFormatter(formatter) stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
@click.group(epilog="For more information see " @click.group(epilog="For more information see "
"<https://git.squad51.us/jamie/ucsinfer>") "<https://git.squad51.us/jamie/ucsinfer>")
@click.option('--verbose', '-v', flag_value=True, help='Verbose output') @click.option('--verbose', '-v', flag_value=True, help='Verbose output')
@@ -77,7 +76,7 @@ def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
help="Recommend a category for given text instead of reading " help="Recommend a category for given text instead of reading "
"from a file") "from a file")
@click.argument('paths', nargs=-1, metavar='<paths>') @click.argument('paths', nargs=-1, metavar='<paths>')
@click.option('--interactive','-i', flag_value=True, default=False, @click.option('--interactive', '-i', flag_value=True, default=False,
help="After processing each path in <paths>, prompt for a " help="After processing each path in <paths>, prompt for a "
"recommendation to accept, and then prepend the selection to " "recommendation to accept, and then prepend the selection to "
"the file name.") "the file name.")
@@ -121,7 +120,8 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
text = os.path.basename(path) text = os.path.basename(path)
while True: while True:
retval = print_recommendation(path, text, inference_ctx, interactive) retval = print_recommendation(
path, text, inference_ctx, interactive)
if not retval: if not retval:
break break
if retval[0] is False: if retval[0] is False:
@@ -136,28 +136,6 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
os.rename(path, new_path) os.rename(path, new_path)
break break
def csv_to_data(paths, description_key, filename_key, catid_list) -> Generator[tuple[str, str], None, None]:
"""
Accepts a list of paths and returns an iterator of (sentence, class)
tuples.
"""
for path in paths:
with open(path, 'r') as f:
records = csv.DictReader(f)
assert filename_key in records.fieldnames, \
(f"Filename key `{filename_key}` not present in file "
"{path}")
assert description_key in records.fieldnames, \
(f"Description key `{description_key}` not present in "
"file {path}")
for record in records:
ucs_comps = parse_ucs(record[filename_key], catid_list)
if ucs_comps:
yield (record[description_key], ucs_comps.cat_id)
@ucsinfer.command('csv') @ucsinfer.command('csv')
@click.option('--filename-col', default="FileName", @click.option('--filename-col', default="FileName",
@@ -189,7 +167,7 @@ def import_csv(ctx, paths: list[str], out, filename_col, description_col):
dataset = build_sentence_class_dataset( dataset = build_sentence_class_dataset(
chain(csv_to_data(paths, description_col, filename_col, catid_list), chain(csv_to_data(paths, description_col, filename_col, catid_list),
ucs_definitions_generator(ucs)),catid_list) ucs_definitions_generator(ucs)), catid_list)
logger.info(f"Saving dataset to disk at {out}") logger.info(f"Saving dataset to disk at {out}")
print_dataset_stats(dataset, catid_list) print_dataset_stats(dataset, catid_list)
@@ -220,7 +198,7 @@ def gather(ctx, paths, out):
logger.debug(f"Loading category list...") logger.debug(f"Loading category list...")
ucs = load_ucs(full_ucs=ctx.obj['complete_ucs']) ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
scan_list: list[tuple[str,str]] = [] scan_list: list[tuple[str, str]] = []
catid_list = [cat.catid for cat in ucs] catid_list = [cat.catid for cat in ucs]
for path in paths: for path in paths:
@@ -239,6 +217,7 @@ def gather(ctx, paths, out):
print_dataset_stats(dataset, catid_list) print_dataset_stats(dataset, catid_list)
dataset.save_to_disk(out) dataset.save_to_disk(out)
@ucsinfer.command('qualify') @ucsinfer.command('qualify')
def qualify(): def qualify():
""" """
@@ -260,7 +239,6 @@ def finetune(ctx):
logger.debug("FINETUNE mode") logger.debug("FINETUNE mode")
@ucsinfer.command('evaluate') @ucsinfer.command('evaluate')
@click.argument('dataset', default='dataset/') @click.argument('dataset', default='dataset/')
@click.pass_context @click.pass_context
@@ -272,6 +250,5 @@ def evaluate(ctx, dataset, offset, limit):
logger.warning("Model evaluation is not currently implemented") logger.warning("Model evaluation is not currently implemented")
if __name__ == '__main__': if __name__ == '__main__':
ucsinfer(obj={}) ucsinfer(obj={})

28
ucsinfer/import_csv.py Normal file
View File

@@ -0,0 +1,28 @@
import csv
from typing import Generator
from .util import parse_ucs
def csv_to_data(paths, description_key, filename_key, catid_list) -> Generator[tuple[str, str], None, None]:
"""
Accepts a list of paths and returns an iterator of (sentence, class)
tuples.
"""
for path in paths:
with open(path, 'r') as f:
records = csv.DictReader(f)
assert filename_key in records.fieldnames, \
(f"Filename key `{filename_key}` not present in file "
"{path}")
assert description_key in records.fieldnames, \
(f"Description key `{description_key}` not present in "
"file {path}")
for record in records:
ucs_comps = parse_ucs(record[filename_key], catid_list)
if ucs_comps:
yield (record[description_key], ucs_comps.cat_id)