This commit is contained in:
2025-09-03 14:56:41 -07:00
parent 0594899bdd
commit fee55a7d5a

View File

@@ -54,7 +54,7 @@ def ucsinfer(ctx, verbose, no_model_cache, model):
if no_model_cache: if no_model_cache:
logger.info("Model cache inhibited by config") logger.info("Model cache inhibited by config")
logger.info("Using model {model}") logger.info(f"Using model {model}")
@ucsinfer.command('recommend') @ucsinfer.command('recommend')
@@ -117,12 +117,12 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
@ucsinfer.command('gather') @ucsinfer.command('gather')
@click.option('--outfile', default='dataset', show_default=True) @click.option('--out', default='dataset/', show_default=True)
@click.option('--ucs-data', flag_value=True, help="Create a dataset based " @click.option('--ucs-data', flag_value=True, help="Create a dataset based "
"on the UCS category explanations and synonymns (PATHS will " "on the UCS category explanations and synonymns (PATHS will "
"be ignored.)") "be ignored.)")
@click.argument('paths', nargs=-1) @click.argument('paths', nargs=-1)
def gather(paths, outfile, ucs_data): def gather(paths, out, ucs_data):
""" """
Scan files to build a training dataset Scan files to build a training dataset
@@ -148,6 +148,7 @@ def gather(paths, outfile, ucs_data):
catid_list = [cat.catid for cat in ucs] catid_list = [cat.catid for cat in ucs]
if ucs_data: if ucs_data:
logger.info('Creating dataset for UCS categories instead of from PATH')
paths = [] paths = []
for path in paths: for path in paths:
@@ -182,7 +183,8 @@ def gather(paths, outfile, ucs_data):
else: else:
dataset = build_sentence_class_dataset(scan_metadata(), catid_list) dataset = build_sentence_class_dataset(scan_metadata(), catid_list)
dataset.save_to_disk(outfile) logger.info(f"Saving dataset to disk at {out}")
dataset.save_to_disk(out)
@ucsinfer.command('finetune') @ucsinfer.command('finetune')