Cleaned up code and refactored to new file

This commit is contained in:
2025-10-14 10:31:34 -07:00
parent 5fc57cf7c8
commit 3ca921ad02
2 changed files with 61 additions and 56 deletions

View File

@@ -1,13 +1,11 @@
import os import os
import logging import logging
from itertools import chain from itertools import chain
import csv
from typing import Generator
import click import click
from .inference import InferenceContext, load_ucs from .inference import InferenceContext, load_ucs
from .import_csv import csv_to_data
from .gather import (build_sentence_class_dataset, print_dataset_stats, from .gather import (build_sentence_class_dataset, print_dataset_stats,
ucs_definitions_generator, scan_metadata, walk_path) ucs_definitions_generator, scan_metadata, walk_path)
from .recommend import print_recommendation from .recommend import print_recommendation
@@ -22,6 +20,7 @@ formatter = logging.Formatter(
stream_handler.setFormatter(formatter) stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
@click.group(epilog="For more information see " @click.group(epilog="For more information see "
"<https://git.squad51.us/jamie/ucsinfer>") "<https://git.squad51.us/jamie/ucsinfer>")
@click.option('--verbose', '-v', flag_value=True, help='Verbose output') @click.option('--verbose', '-v', flag_value=True, help='Verbose output')
@@ -121,7 +120,8 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
text = os.path.basename(path) text = os.path.basename(path)
while True: while True:
retval = print_recommendation(path, text, inference_ctx, interactive) retval = print_recommendation(
path, text, inference_ctx, interactive)
if not retval: if not retval:
break break
if retval[0] is False: if retval[0] is False:
@@ -136,28 +136,6 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
os.rename(path, new_path) os.rename(path, new_path)
break break
def csv_to_data(paths, description_key, filename_key, catid_list) -> Generator[tuple[str, str], None, None]:
"""
Accepts a list of paths and returns an iterator of (sentence, class)
tuples.
"""
for path in paths:
with open(path, 'r') as f:
records = csv.DictReader(f)
assert filename_key in records.fieldnames, \
(f"Filename key `{filename_key}` not present in file "
"{path}")
assert description_key in records.fieldnames, \
(f"Description key `{description_key}` not present in "
"file {path}")
for record in records:
ucs_comps = parse_ucs(record[filename_key], catid_list)
if ucs_comps:
yield (record[description_key], ucs_comps.cat_id)
@ucsinfer.command('csv') @ucsinfer.command('csv')
@click.option('--filename-col', default="FileName", @click.option('--filename-col', default="FileName",
@@ -239,6 +217,7 @@ def gather(ctx, paths, out):
print_dataset_stats(dataset, catid_list) print_dataset_stats(dataset, catid_list)
dataset.save_to_disk(out) dataset.save_to_disk(out)
@ucsinfer.command('qualify') @ucsinfer.command('qualify')
def qualify(): def qualify():
""" """
@@ -260,7 +239,6 @@ def finetune(ctx):
logger.debug("FINETUNE mode") logger.debug("FINETUNE mode")
@ucsinfer.command('evaluate') @ucsinfer.command('evaluate')
@click.argument('dataset', default='dataset/') @click.argument('dataset', default='dataset/')
@click.pass_context @click.pass_context
@@ -272,6 +250,5 @@ def evaluate(ctx, dataset, offset, limit):
logger.warning("Model evaluation is not currently implemented") logger.warning("Model evaluation is not currently implemented")
if __name__ == '__main__': if __name__ == '__main__':
ucsinfer(obj={}) ucsinfer(obj={})

28
ucsinfer/import_csv.py Normal file
View File

@@ -0,0 +1,28 @@
import csv
from typing import Generator
from .util import parse_ucs
def csv_to_data(paths, description_key, filename_key, catid_list) -> Generator[tuple[str, str], None, None]:
"""
Accepts a list of paths and returns an iterator of (sentence, class)
tuples.
"""
for path in paths:
with open(path, 'r') as f:
records = csv.DictReader(f)
assert filename_key in records.fieldnames, \
(f"Filename key `{filename_key}` not present in file "
"{path}")
assert description_key in records.fieldnames, \
(f"Description key `{description_key}` not present in "
"file {path}")
for record in records:
ucs_comps = parse_ucs(record[filename_key], catid_list)
if ucs_comps:
yield (record[description_key], ucs_comps.cat_id)