Added another function to recommend

2025-09-04 10:43:31 -07:00
parent 10519f9c1a
commit 103fffe0a4
4 changed files with 34 additions and 140 deletions


@@ -2,13 +2,14 @@ import os
# import csv
import logging
from subprocess import CalledProcessError
+from itertools import chain
import tqdm
import click
# from tabulate import tabulate, SEPARATING_LINE
from .inference import InferenceContext, load_ucs
-from .gather import build_sentence_class_dataset
+from .gather import build_sentence_class_dataset, print_dataset_stats
from .recommend import print_recommendation
from .util import ffmpeg_description, parse_ucs
@@ -198,19 +199,20 @@ def gather(ctx, paths, out, ucs_data):
assert comps
yield comps.fx_name, str(pair[0])
def ucs_metadata():
for cat in ucs:
yield cat.explanations, cat.catid
yield ", ".join(cat.synonymns), cat.catid
logger.info("Building dataset...")
-if ucs_data:
-    dataset = build_sentence_class_dataset(ucs_metadata(), catid_list)
-else:
-    dataset = build_sentence_class_dataset(scan_metadata(), catid_list)
+dataset = build_sentence_class_dataset(chain(scan_metadata(),
+                                             ucs_metadata()),
+                                       catid_list)
logger.info(f"Saving dataset to disk at {out}")
+print_dataset_stats(dataset)
dataset.save_to_disk(out)
@@ -225,114 +227,15 @@ def finetune(ctx):
@ucsinfer.command('evaluate')
-@click.option('--offset', type=int, default=0, metavar="<int>",
-              help='Skip this many records in the dataset before processing')
-@click.option('--limit', type=int, default=-1, metavar="<int>",
-              help='Process this many records and then exit')
-@click.argument('dataset', type=click.File('r', encoding='utf8'),
-                default='dataset.csv')
+@click.argument('dataset', default='dataset/')
@click.pass_context
-def evaluate(ctx, dataset, offset, limit, no_foley):
+def evaluate(ctx, dataset, offset, limit):
"""
Use datasets to evaluate model performance
The `evaluate` command reads the input DATASET file row by row and
classifies each description with the selected model (either the default or
the one given with the --model option), then checks whether the model
inferred the category given by the dataset.

The model returns its top 10 candidate categories for each description, and
the results are tabulated according to whether (1) the top classification
was correct, (2) the correct classification was in the top 5, or (3) it was
in the top 10. The worst-performing category (the one with the most misses)
is also reported, along with the category coverage: how many categories are
present in the dataset.

NOTE: Experimentation showed that foley items were generally classified
according to their subject rather than whether or not they were foley, so
these categories can be excluded with the --no-foley option.
""" """
logger.debug("EVALUATE mode") logger.debug("EVALUATE mode")
logger.warning("Model evaluation is not currently implemented") logger.warning("Model evaluation is not currently implemented")
# inference_context = InferenceContext(
# ctx.obj['model_name'], use_cached_model=ctx.obj['model_cache'],
# use_full_ucs=ctx.obj['complete_ucs'])
#
# reader = csv.reader(dataset)
#
# results = []
#
# if offset > 0:
# logger.debug(f"Skipping {offset} records...")
#
# if limit > 0:
# logger.debug(f"Will only evaluate {limit} records...")
#
# progress_bar = tqdm.tqdm(total=limit,
# desc="Processing dataset...",
# unit="rec")
# for i, row in enumerate(reader):
# if i < offset:
# continue
#
# if limit > 0 and i >= limit + offset:
# break
#
# cat_id, description = row
# if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
# continue
#
# guesses = inference_context.classify_text_ranked(description, limit=10)
# if cat_id == guesses[0]:
# results.append({'catid': cat_id, 'result': "TOP"})
# elif cat_id in guesses[0:5]:
# results.append({'catid': cat_id, 'result': "TOP_5"})
# elif cat_id in guesses:
# results.append({'catid': cat_id, 'result': "TOP_10"})
# else:
# results.append({'catid': cat_id, 'result': "MISS"})
#
# progress_bar.update(1)
#
# total = len(results)
# total_top = len([x for x in results if x['result'] == 'TOP'])
# total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
# total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
#
# cats = set([x['catid'] for x in results])
# total_cats = len(cats)
#
# miss_counts = []
# for cat in cats:
# miss_counts.append(
# (cat, len([x for x in results
# if x['catid'] == cat and x['result'] == 'MISS'])))
#
# miss_counts = sorted(miss_counts, key=lambda x: x[1])
#
# print(f"## Results for Model {model} ##\n")
#
# if no_foley:
# print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
#
# table = [
# ["Total records in sample:", f"{total}"],
# ["Top Result:", f"{total_top}",
# f"{float(total_top)/float(total):.2%}"],
# ["Top 5 Result:", f"{total_top_5}",
# f"{float(total_top_5)/float(total):.2%}"],
# ["Top 10 Result:", f"{total_top_10}",
# f"{float(total_top_10)/float(total):.2%}"],
# SEPARATING_LINE,
# ["UCS category count:", f"{len(inference_context.catlist)}"],
# ["Total categories in sample:", f"{total_cats}",
# f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
# [f"Most missed category ({miss_counts[-1][0]}):",
# f"{miss_counts[-1][1]}",
# f"{float(miss_counts[-1][1])/float(total):.2%}"]
# ]
#
# print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
if __name__ == '__main__':
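The removed evaluation code above tallies whether the dataset's category lands in the model's top guess, its top 5, or its top 10. A minimal sketch of that tally, kept apart from the CSV and CLI plumbing (the classify_text_ranked call and the (cat_id, description) rows mirror the commented-out code; everything else here is illustrative):

from collections import Counter

def tally_rankings(rows, inference_context, limit=10):
    # rows: iterable of (cat_id, description) pairs, as in the old CSV dataset.
    # classify_text_ranked is assumed to return a ranked list of catids.
    buckets = Counter()
    misses_by_cat = Counter()
    for cat_id, description in rows:
        guesses = inference_context.classify_text_ranked(description, limit=limit)
        if guesses and cat_id == guesses[0]:
            buckets['TOP'] += 1
        elif cat_id in guesses[:5]:
            buckets['TOP_5'] += 1
        elif cat_id in guesses:
            buckets['TOP_10'] += 1
        else:
            buckets['MISS'] += 1
            misses_by_cat[cat_id] += 1
    return buckets, misses_by_cat

The percentages in the removed results table are then each bucket divided by the total number of rows, and the most-missed category is the largest entry in misses_by_cat.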


@@ -1,32 +1,7 @@
-from sentence_transformers import SentenceTransformer
-from sentence_transformers.evaluation import BinaryClassificationEvaluator
-from datasets import load_dataset_from_disk, DatasetDict
+# from sentence_transformers import SentenceTransformer
+# from sentence_transformers.evaluation import BinaryClassificationEvaluator
+# from datasets import load_dataset_from_disk, DatasetDict
#
def evaluate_model(model: SentenceTransformer, dataset):
# eval_dataset =
# Initialize the evaluator
binary_acc_evaluator = BinaryClassificationEvaluator(
sentences1=eval_dataset["sentence1"],
sentences2=eval_dataset["sentence2"],
labels=eval_dataset["label"],
name="quora_duplicates_dev",
)
results = binary_acc_evaluator(model)
'''
Binary Accuracy Evaluation of the model on the quora_duplicates_dev dataset:
Accuracy with Cosine-Similarity: 81.60 (Threshold: 0.8352)
F1 with Cosine-Similarity: 75.27 (Threshold: 0.7715)
Precision with Cosine-Similarity: 65.81
Recall with Cosine-Similarity: 87.89
Average Precision with Cosine-Similarity: 76.03
Matthews Correlation with Cosine-Similarity: 62.48
'''
print(binary_acc_evaluator.primary_metric)
# => "quora_duplicates_dev_cosine_ap"
print(results[binary_acc_evaluator.primary_metric])
# => 0.760277070888393


@@ -1,12 +1,24 @@
from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
from datasets.dataset_dict import DatasetDict
-from typing import Generator, Any
+from typing import Iterator
+from tabulate import tabulate
+
+
+def print_dataset_stats(dataset: DatasetDict, catlist: list[str] | None = None):
+    data_table = []
+    data_table.append(["Total records in combined dataset:",
+                       sum(len(split) for split in dataset.values())])
+    data_table.append(["Total records in `train`:", len(dataset['train'])])
+    tab = tabulate(data_table)
+    print(tab)
+
+
# https://www.sbert.net/docs/sentence_transformer/loss_overview.html
def build_sentence_class_dataset(
-        records: Generator[tuple[str, str], Any, None],
+        records: Iterator[tuple[str, str]],
        catlist: list[str]) -> DatasetDict:
    """
    Create a new dataset for `records` which contains (sentence, class) pairs.
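As a usage sketch of the function this file exports (the sample sentences and catids below are placeholders, not necessarily valid UCS CatIDs; print_dataset_stats is the helper added in this commit):

# Illustrative only: records are (sentence, catid) pairs, like the ones the
# gather command yields from file scans and UCS explanations/synonyms.
records = iter([
    ("metal door slams shut", "DOORMetl"),
    ("heavy rain on a rooftop", "RAINRoof"),
])
catid_list = ["DOORMetl", "RAINRoof"]

dataset = build_sentence_class_dataset(records, catid_list)
print_dataset_stats(dataset)       # record counts, formatted with tabulate
dataset.save_to_disk("dataset/")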


@@ -43,12 +43,16 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
text = m.group(1)
return True, text, None
+elif m := match(r'^c (.*)', response):
+    return True, None, m.group(1)
elif response.startswith("?"):
print("""
Choices:
- Enter recommendation number to rename file,
- "t [text]" to search for new recommendations based on [text]
- "p" re-use the last selected cat-id
+- "c [cat]" to type in a category by hand
- "?" for this message
- "q" to quit
- or any other key to skip this file and continue to next file
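The help text above spells out the interactive protocol. A rough sketch of the response dispatch it describes, assuming the (keep_going, new_search_text, manual_catid) return shape suggested by the returns earlier in this hunk (the numeric-choice, "p", and rename handling are simplified guesses and not the actual print_recommendation logic):

from re import match

def handle_response(response: str):
    # Sketch only: maps a typed response to (keep_going, new_search_text, manual_catid).
    if response == 'q':
        return False, None, None                # quit
    if m := match(r'^t (.*)', response):
        return True, m.group(1), None           # search again with new text
    if m := match(r'^c (.*)', response):
        return True, None, m.group(1)           # category typed in by hand (new in this commit)
    if response == '?':
        return True, None, None                 # help text is printed, then re-prompt
    return True, None, None                     # any other key skips to the next file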