From c15c499869f6dfbaa8e31fa89baa7c1af8b616a1 Mon Sep 17 00:00:00 2001 From: Jamie Hardt Date: Sat, 13 Sep 2025 23:03:29 -0700 Subject: [PATCH] Rewriting structure, added to TODO --- TODO.md | 10 ++++++++++ ucsinfer/__main__.py | 25 ++++++++++++++++++------- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/TODO.md b/TODO.md index cc0775f..bf83d05 100644 --- a/TODO.md +++ b/TODO.md @@ -8,9 +8,19 @@ - Maybe more dataset configurations +## Qualify + +- Print stats for a dataset + ## Fine-tune +- https://www.sbert.net/docs/sentence_transformer/loss_overview.html#loss-table + - Use (anchor, positive) pairs to train a new model + - Use (sentence) + class labels to train a new model - Implement BatchAllTripletLoss +- Implement a two-phase training regime + - 1. Train with anchored definitions then... + - 2. Train with class labels ## Evaluate diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py index ccad36a..54fdd2d 100644 --- a/ucsinfer/__main__.py +++ b/ucsinfer/__main__.py @@ -139,18 +139,18 @@ def recommend(ctx, text, paths, interactive, skip_ucs): @ucsinfer.command('gather') @click.option('--out', default='dataset/', show_default=True) -@click.option('--ucs-data', flag_value=True, help="Create a dataset based " - "on the UCS category explanations and synonymns (PATHS will " - "be ignored.)") +# @click.option('--ucs-data', flag_value=True, help="Create a dataset based " +# "on the UCS category explanations and synonymns (PATHS will " +# "be ignored.)") @click.argument('paths', nargs=-1) @click.pass_context def gather(ctx, paths, out, ucs_data): """ Scan files to build a training dataset - The `gather` is used to build a training dataset for finetuning the - selected model. Description sentences and UCS categories are collected from - '.wav' and '.flac' files on-disk that have valid UCS filenames and assigned + `gather` is used to build a training dataset for finetuning the selected + model. Description sentences and UCS categories are collected from '.wav' + and '.flac' files on-disk that have valid UCS filenames and assigned CatIDs, and this information is recorded into a HuggingFace dataset. Gather scans the filesystem in two passes: first, the directory tree is @@ -186,7 +186,18 @@ def gather(ctx, paths, out, ucs_data): logger.info(f"Saving dataset to disk at {out}") print_dataset_stats(dataset, catid_list) dataset.save_to_disk(out) - + +@ucsinfer.command('qualify') +def qualify(): + """ + Check and prepare a dataset for finetuning + + `quality` reads a dataset and will output statistics on its coverage of the + UCS, and will add the UCS canoncial definitions to the dataset for every + extant category. + """ + pass + @ucsinfer.command('finetune') @click.pass_context