From 103fffe0a4686b0fab84227ee847da2cfc56ec45 Mon Sep 17 00:00:00 2001
From: Jamie Hardt
Date: Thu, 4 Sep 2025 10:43:31 -0700
Subject: [PATCH] Added another function to recommend

---
 ucsinfer/__main__.py  | 121 +++++------------------------------------
 ucsinfer/evaluate.py  |  33 ++----------
 ucsinfer/gather.py    |  16 +++++-
 ucsinfer/recommend.py |   4 ++
 4 files changed, 34 insertions(+), 140 deletions(-)

diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py
index c329a4e..b4cb286 100644
--- a/ucsinfer/__main__.py
+++ b/ucsinfer/__main__.py
@@ -2,13 +2,14 @@ import os
 # import csv
 import logging
 from subprocess import CalledProcessError
+from itertools import chain
 
 import tqdm
 import click
 # from tabulate import tabulate, SEPARATING_LINE
 
 from .inference import InferenceContext, load_ucs
-from .gather import build_sentence_class_dataset
+from .gather import build_sentence_class_dataset, print_dataset_stats
 from .recommend import print_recommendation
 from .util import ffmpeg_description, parse_ucs
 
@@ -198,19 +199,20 @@ def gather(ctx, paths, out, ucs_data):
             assert comps
             yield comps.fx_name, str(pair[0])
 
-
     def ucs_metadata():
         for cat in ucs:
            yield cat.explanations, cat.catid
            yield ", ".join(cat.synonymns), cat.catid
 
     logger.info("Building dataset...")
-    if ucs_data:
-        dataset = build_sentence_class_dataset(ucs_metadata(), catid_list)
-    else:
-        dataset = build_sentence_class_dataset(scan_metadata(), catid_list)
+
+    dataset = build_sentence_class_dataset(chain(scan_metadata(),
+                                                 ucs_metadata()),
+                                           catid_list)
+
     logger.info(f"Saving dataset to disk at {out}")
+    print_dataset_stats(dataset)
     dataset.save_to_disk(out)
 
 
@@ -225,114 +227,15 @@ def finetune(ctx):
 
 
 @ucsinfer.command('evaluate')
-@click.option('--offset', type=int, default=0, metavar="",
-              help='Skip this many records in the dataset before processing')
-@click.option('--limit', type=int, default=-1, metavar="",
-              help='Process this many records and then exit')
-@click.argument('dataset', type=click.File('r', encoding='utf8'),
-                default='dataset.csv')
+@click.argument('dataset', default='dataset/')
 @click.pass_context
-def evaluate(ctx, dataset, offset, limit, no_foley):
+def evaluate(ctx, dataset):
     """
-    Use datasets to evaluate model performance
-
-    The `evaluate` command reads the input DATASET file row by row and
-    performs a classifcation of the given description against the selected
-    model (either the default or using the --model option). The command then
-    checks if the model inferred the correct category as given by the dataset.
-
-    The model gives its top 10 possible categories for a given description,
-    and the results are tabulated according to (1) wether the top
-    classification was correct, (2) wether the correct classifcation was in the
-    top 5, or (3) wether it was in the top 10. The worst-performing category,
-    the one with the most misses, is also reported as well as the category
-    coverage, how many categories are present in the dataset.
-
-    NOTE: With experimentation it was found that foley items generally were
-    classified according to their subject and not wether or not they were
-    foley, and so these categories can be excluded with the --no-foley option.
-    """
+    """
     logger.debug("EVALUATE mode")
     logger.warning("Model evaluation is not currently implemented")
 
-    # inference_context = InferenceContext(
-    #     ctx.obj['model_name'], use_cached_model=ctx.obj['model_cache'],
-    #     use_full_ucs=ctx.obj['complete_ucs'])
-    #
-    # reader = csv.reader(dataset)
-    #
-    # results = []
-    #
-    # if offset > 0:
-    #     logger.debug(f"Skipping {offset} records...")
-    #
-    # if limit > 0:
-    #     logger.debug(f"Will only evaluate {limit} records...")
-    #
-    # progress_bar = tqdm.tqdm(total=limit,
-    #                          desc="Processing dataset...",
-    #                          unit="rec")
-    # for i, row in enumerate(reader):
-    #     if i < offset:
-    #         continue
-    #
-    #     if limit > 0 and i >= limit + offset:
-    #         break
-    #
-    #     cat_id, description = row
-    #     if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
-    #         continue
-    #
-    #     guesses = inference_context.classify_text_ranked(description, limit=10)
-    #     if cat_id == guesses[0]:
-    #         results.append({'catid': cat_id, 'result': "TOP"})
-    #     elif cat_id in guesses[0:5]:
-    #         results.append({'catid': cat_id, 'result': "TOP_5"})
-    #     elif cat_id in guesses:
-    #         results.append({'catid': cat_id, 'result': "TOP_10"})
-    #     else:
-    #         results.append({'catid': cat_id, 'result': "MISS"})
-    #
-    #     progress_bar.update(1)
-    #
-    # total = len(results)
-    # total_top = len([x for x in results if x['result'] == 'TOP'])
-    # total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
-    # total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
-    #
-    # cats = set([x['catid'] for x in results])
-    # total_cats = len(cats)
-    #
-    # miss_counts = []
-    # for cat in cats:
-    #     miss_counts.append(
-    #         (cat, len([x for x in results
-    #                    if x['catid'] == cat and x['result'] == 'MISS'])))
-    #
-    # miss_counts = sorted(miss_counts, key=lambda x: x[1])
-    #
-    # print(f"## Results for Model {model} ##\n")
-    #
-    # if no_foley:
-    #     print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
-    #
-    # table = [
-    #     ["Total records in sample:", f"{total}"],
-    #     ["Top Result:", f"{total_top}",
-    #      f"{float(total_top)/float(total):.2%}"],
-    #     ["Top 5 Result:", f"{total_top_5}",
-    #      f"{float(total_top_5)/float(total):.2%}"],
-    #     ["Top 10 Result:", f"{total_top_10}",
-    #      f"{float(total_top_10)/float(total):.2%}"],
-    #     SEPARATING_LINE,
-    #     ["UCS category count:", f"{len(inference_context.catlist)}"],
-    #     ["Total categories in sample:", f"{total_cats}",
-    #      f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
-    #     [f"Most missed category ({miss_counts[-1][0]}):",
-    #      f"{miss_counts[-1][1]}",
-    #      f"{float(miss_counts[-1][1])/float(total):.2%}"]
-    # ]
-    #
-    # print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
+
 
 
 if __name__ == '__main__':
diff --git a/ucsinfer/evaluate.py b/ucsinfer/evaluate.py
index 70f05cf..a19383c 100644
--- a/ucsinfer/evaluate.py
+++ b/ucsinfer/evaluate.py
@@ -1,32 +1,7 @@
-from sentence_transformers import SentenceTransformer
-from sentence_transformers.evaluation import BinaryClassificationEvaluator
-from datasets import load_dataset_from_disk, DatasetDict
+# from sentence_transformers import SentenceTransformer
+# from sentence_transformers.evaluation import BinaryClassificationEvaluator
+# from datasets import load_dataset_from_disk, DatasetDict
+#
 
-
-def evaluate_model(model: SentenceTransformer, dataset):
-
-    # eval_dataset =
-
-    # Initialize the evaluator
-    binary_acc_evaluator = BinaryClassificationEvaluator(
-        sentences1=eval_dataset["sentence1"],
-        sentences2=eval_dataset["sentence2"],
-        labels=eval_dataset["label"],
-        name="quora_duplicates_dev",
-    )
-
-    results = binary_acc_evaluator(model)
-    '''
-    Binary Accuracy Evaluation of the model on the quora_duplicates_dev dataset:
-    Accuracy with Cosine-Similarity: 81.60 (Threshold: 0.8352)
-    F1 with Cosine-Similarity: 75.27 (Threshold: 0.7715)
-    Precision with Cosine-Similarity: 65.81
-    Recall with Cosine-Similarity: 87.89
-    Average Precision with Cosine-Similarity: 76.03
-    Matthews Correlation with Cosine-Similarity: 62.48
-    '''
-    print(binary_acc_evaluator.primary_metric)
-    # => "quora_duplicates_dev_cosine_ap"
-    print(results[binary_acc_evaluator.primary_metric])
-    # => 0.760277070888393
diff --git a/ucsinfer/gather.py b/ucsinfer/gather.py
index f038983..9d20cd8 100644
--- a/ucsinfer/gather.py
+++ b/ucsinfer/gather.py
@@ -1,12 +1,24 @@
 from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
 from datasets.dataset_dict import DatasetDict
-from typing import Generator, Any
+from typing import Iterator
+
+from tabulate import tabulate
+
+def print_dataset_stats(dataset: DatasetDict):
+    total_records = sum(len(split) for split in dataset.values())
+    data_table = []
+    data_table.append(["Total records in combined dataset:", total_records])
+    data_table.append(["Total records in `train`:", len(dataset['train'])])
+
+    tab = tabulate(data_table)
+
+    print(tab)
 
 
 # https://www.sbert.net/docs/sentence_transformer/loss_overview.html
 def build_sentence_class_dataset(
-        records: Generator[tuple[str, str], Any, None],
+        records: Iterator[tuple[str, str]],
         catlist: list[str]) -> DatasetDict:
     """
     Create a new dataset for `records` which contains (sentence, class) pairs.
diff --git a/ucsinfer/recommend.py b/ucsinfer/recommend.py
index 6c73c56..5b98f33 100644
--- a/ucsinfer/recommend.py
+++ b/ucsinfer/recommend.py
@@ -42,6 +42,9 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
             print("searching for new matches")
             text = m.group(1)
             return True, text, None
+
+        elif m := match(r'^c (.*)', response):
+            return True, None, m.group(1)
 
         elif response.startswith("?"):
             print("""
@@ -49,6 +52,7 @@ Choices:
 - Enter recommendation number to rename file,
 - "t [text]" to search for new recommendations based on [text]
 - "p" re-use the last selected cat-id
+- "c [cat]" to type in a category by hand
 - "?" for this message
 - "q" to quit
 - or any other key to skip this file and continue to next file