Cleaned up code and refactored to new file

This commit is contained in:
2025-10-14 10:31:34 -07:00
parent 5fc57cf7c8
commit 3ca921ad02
2 changed files with 61 additions and 56 deletions

View File

@@ -1,13 +1,11 @@
import os import os
import logging import logging
from itertools import chain from itertools import chain
import csv
from typing import Generator
import click import click
from .inference import InferenceContext, load_ucs from .inference import InferenceContext, load_ucs
from .import_csv import csv_to_data
from .gather import (build_sentence_class_dataset, print_dataset_stats, from .gather import (build_sentence_class_dataset, print_dataset_stats,
ucs_definitions_generator, scan_metadata, walk_path) ucs_definitions_generator, scan_metadata, walk_path)
from .recommend import print_recommendation from .recommend import print_recommendation
@@ -22,6 +20,7 @@ formatter = logging.Formatter(
stream_handler.setFormatter(formatter) stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
@click.group(epilog="For more information see " @click.group(epilog="For more information see "
"<https://git.squad51.us/jamie/ucsinfer>") "<https://git.squad51.us/jamie/ucsinfer>")
@click.option('--verbose', '-v', flag_value=True, help='Verbose output') @click.option('--verbose', '-v', flag_value=True, help='Verbose output')
@@ -77,7 +76,7 @@ def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
help="Recommend a category for given text instead of reading " help="Recommend a category for given text instead of reading "
"from a file") "from a file")
@click.argument('paths', nargs=-1, metavar='<paths>') @click.argument('paths', nargs=-1, metavar='<paths>')
@click.option('--interactive','-i', flag_value=True, default=False, @click.option('--interactive', '-i', flag_value=True, default=False,
help="After processing each path in <paths>, prompt for a " help="After processing each path in <paths>, prompt for a "
"recommendation to accept, and then prepend the selection to " "recommendation to accept, and then prepend the selection to "
"the file name.") "the file name.")
@@ -121,7 +120,8 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
text = os.path.basename(path) text = os.path.basename(path)
while True: while True:
retval = print_recommendation(path, text, inference_ctx, interactive) retval = print_recommendation(
path, text, inference_ctx, interactive)
if not retval: if not retval:
break break
if retval[0] is False: if retval[0] is False:
@@ -136,28 +136,6 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
os.rename(path, new_path) os.rename(path, new_path)
break break
def csv_to_data(paths, description_key, filename_key, catid_list) -> Generator[tuple[str, str], None, None]:
    """
    Accepts a list of paths and returns an iterator of (sentence, class)
    tuples.

    Each file in `paths` is read with `csv.DictReader`. For every row, the
    value in the `filename_key` column is passed to `parse_ucs` together
    with `catid_list`. When that call returns a truthy result, the row's
    `description_key` value is yielded paired with the result's `cat_id`;
    rows for which it returns a falsy value are skipped.

    Raises AssertionError if a file's header row does not contain
    `filename_key` or `description_key` (a file with no header row at all
    is treated as having no columns).
    """
    for path in paths:
        # newline='' is what the csv module asks for, so that newlines
        # embedded inside quoted fields are handled by the reader.
        with open(path, 'r', newline='') as f:
            records = csv.DictReader(f)
            # `fieldnames` is None when the file is empty; fall back to an
            # empty tuple so the membership checks below report a missing
            # column instead of failing with a TypeError.
            fieldnames = records.fieldnames or ()

            # Raised explicitly (rather than via `assert`) so the checks
            # still run under `python -O`; the exception type is kept as
            # AssertionError for backward compatibility.
            if filename_key not in fieldnames:
                raise AssertionError(
                    f"Filename key `{filename_key}` not present in file "
                    f"{path}")

            if description_key not in fieldnames:
                raise AssertionError(
                    f"Description key `{description_key}` not present in "
                    f"file {path}")

            for record in records:
                ucs_comps = parse_ucs(record[filename_key], catid_list)
                if ucs_comps:
                    yield (record[description_key], ucs_comps.cat_id)
@ucsinfer.command('csv') @ucsinfer.command('csv')
@click.option('--filename-col', default="FileName", @click.option('--filename-col', default="FileName",
@@ -189,7 +167,7 @@ def import_csv(ctx, paths: list[str], out, filename_col, description_col):
dataset = build_sentence_class_dataset( dataset = build_sentence_class_dataset(
chain(csv_to_data(paths, description_col, filename_col, catid_list), chain(csv_to_data(paths, description_col, filename_col, catid_list),
ucs_definitions_generator(ucs)),catid_list) ucs_definitions_generator(ucs)), catid_list)
logger.info(f"Saving dataset to disk at {out}") logger.info(f"Saving dataset to disk at {out}")
print_dataset_stats(dataset, catid_list) print_dataset_stats(dataset, catid_list)
@@ -220,7 +198,7 @@ def gather(ctx, paths, out):
logger.debug(f"Loading category list...") logger.debug(f"Loading category list...")
ucs = load_ucs(full_ucs=ctx.obj['complete_ucs']) ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
scan_list: list[tuple[str,str]] = [] scan_list: list[tuple[str, str]] = []
catid_list = [cat.catid for cat in ucs] catid_list = [cat.catid for cat in ucs]
for path in paths: for path in paths:
@@ -239,6 +217,7 @@ def gather(ctx, paths, out):
print_dataset_stats(dataset, catid_list) print_dataset_stats(dataset, catid_list)
dataset.save_to_disk(out) dataset.save_to_disk(out)
@ucsinfer.command('qualify') @ucsinfer.command('qualify')
def qualify(): def qualify():
""" """
@@ -260,7 +239,6 @@ def finetune(ctx):
logger.debug("FINETUNE mode") logger.debug("FINETUNE mode")
@ucsinfer.command('evaluate') @ucsinfer.command('evaluate')
@click.argument('dataset', default='dataset/') @click.argument('dataset', default='dataset/')
@click.pass_context @click.pass_context
@@ -272,6 +250,5 @@ def evaluate(ctx, dataset, offset, limit):
logger.warning("Model evaluation is not currently implemented") logger.warning("Model evaluation is not currently implemented")
if __name__ == '__main__': if __name__ == '__main__':
ucsinfer(obj={}) ucsinfer(obj={})

28
ucsinfer/import_csv.py Normal file
View File

@@ -0,0 +1,28 @@
import csv
from typing import Generator
from .util import parse_ucs
def csv_to_data(paths, description_key, filename_key, catid_list) -> Generator[tuple[str, str], None, None]:
    """
    Accepts a list of paths and returns an iterator of (sentence, class)
    tuples.

    Each file in `paths` is read with `csv.DictReader`. For every row, the
    value in the `filename_key` column is passed to `parse_ucs` together
    with `catid_list`. When that call returns a truthy result, the row's
    `description_key` value is yielded paired with the result's `cat_id`;
    rows for which it returns a falsy value are skipped.

    Raises AssertionError if a file's header row does not contain
    `filename_key` or `description_key` (a file with no header row at all
    is treated as having no columns).
    """
    for path in paths:
        # newline='' is what the csv module asks for, so that newlines
        # embedded inside quoted fields are handled by the reader.
        with open(path, 'r', newline='') as f:
            records = csv.DictReader(f)
            # `fieldnames` is None when the file is empty; fall back to an
            # empty tuple so the membership checks below report a missing
            # column instead of failing with a TypeError.
            fieldnames = records.fieldnames or ()

            # Raised explicitly (rather than via `assert`) so the checks
            # still run under `python -O`; the exception type is kept as
            # AssertionError for backward compatibility.
            if filename_key not in fieldnames:
                raise AssertionError(
                    f"Filename key `{filename_key}` not present in file "
                    f"{path}")

            if description_key not in fieldnames:
                raise AssertionError(
                    f"Description key `{description_key}` not present in "
                    f"file {path}")

            for record in records:
                ucs_comps = parse_ucs(record[filename_key], catid_list)
                if ucs_comps:
                    yield (record[description_key], ucs_comps.cat_id)