Compare commits

..

3 Commits

Author SHA1 Message Date
Jamie Hardt
4f7b2a73cb Merge branch 'master' of https://git.squad51.us/jamie/ucsinfer 2025-08-27 15:05:26 -07:00
Jamie Hardt
cc78291f1d Implementation 2025-08-27 15:03:06 -07:00
Jamie Hardt
7b3930ece0 Added more helpful prompts and help 2025-08-27 14:04:46 -07:00

View File

@@ -2,6 +2,8 @@ import os
import sys
import csv
from typing import Optional
from sentence_transformers import SentenceTransformer
import tqdm
import click
@@ -11,19 +13,60 @@ from .inference import InferenceContext, load_ucs
from .util import ffmpeg_description, parse_ucs
def recommend_text(text: str, ctx: InferenceContext):
    """
    Produce category recommendations for a text description.

    Placeholder: no lookup is implemented yet, so this returns None
    regardless of `text` and `ctx`.
    """
    # Deliberately undecorated: this is a plain helper called by the
    # `recommend` command, so it must stay an ordinary function rather than
    # a click group or command.
    return None
# Root command group; subcommands register with `@ucsinfer.command(...)`.
# The docstring below doubles as the group's help text.
@click.group(
    epilog="For more information see <https://git.squad51.us/jamie/ucsinfer>"
)
# @click.option('--verbose', flag_value='verbose', help='Verbose output')
def ucsinfer():
    """
    Tools for applying UCS categories to sounds using large-language Models
    """
# `recommend` subcommand. Click passes the parsed command line as keyword
# arguments named after each parameter declaration, so the declared names
# must match the function signature exactly.
@ucsinfer.command('recommend')
@click.option('--text', type=str, default=None,
              help="Recommend a category for given text instead of reading "
              "from a file")
@click.argument('paths', nargs=-1, metavar='[FILES]...')
@click.option('--model', type=str, metavar="<model-name>",
              default="paraphrase-multilingual-mpnet-base-v2",
              show_default=True,
              help="Select the sentence_transformer model to use")
def recommend(text, paths, model):
    """
    Infer a UCS category for a text description
    """
    # text:  description given with --text, or None when the option is absent.
    # paths: the positional file arguments; each one's description is read
    #        with ffmpeg_description().
    # model: sentence_transformer model name used to build the context.
    m = SentenceTransformer(model)
    ctx = InferenceContext(m, model)

    recommendations = []

    if text is not None:
        recommendations.append({
            "text": text,
            "recommendations": recommend_text(text, ctx)
        })

    for path in paths:
        # A separate name keeps the --text value from being overwritten; a
        # falsy description is recorded as None with no recommendations.
        description = ffmpeg_description(path) or None
        recommendations.append({
            "path": path,
            "text": description,
            "recommendations": (recommend_text(description, ctx)
                                if description else None),
        })

    # NOTE(review): `recommendations` is collected but nothing visible here
    # prints or returns it -- confirm where the results are meant to go.
@ucsinfer.command('gather')
@@ -77,13 +120,17 @@ def finetune():
@ucsinfer.command('evaluate')
@click.option('--offset', type=int, default=0)
@click.option('--limit', type=int, default=-1)
@click.option('--offset', type=int, default=0, metavar="<int>",
help='Skip this many records in the dataset before processing')
@click.option('--limit', type=int, default=-1, metavar="<int>",
help='Process this many records and then exit')
@click.option('--no-foley', 'no_foley', flag_value=True, default=False,
help="Ignore any data in the set with FOLYProp or FOLYFeet "
"category")
@click.option('--model', type=str,
default="paraphrase-multilingual-mpnet-base-v2")
@click.option('--model', type=str, metavar="<model-name>",
default="paraphrase-multilingual-mpnet-base-v2",
show_default=True,
help="Select the sentence_transformer model to use")
@click.argument('dataset', type=click.File('r', encoding='utf8'),
default='dataset.csv')
def evaluate(dataset, offset, limit, model, no_foley):
@@ -110,8 +157,17 @@ def evaluate(dataset, offset, limit, model, no_foley):
ctx = InferenceContext(m, model)
reader = csv.reader(dataset)
print(f"Evaluating model {model}...")
results = []
for i, row in enumerate(tqdm.tqdm(reader)):
if offset > 0:
print(f"Skipping {offset} records...")
if limit > 0:
print(f"Will only evaluate {limit} records...")
progress_bar = tqdm.tqdm(total=limit)
for i, row in enumerate(reader):
if i < offset:
continue
@@ -132,6 +188,8 @@ def evaluate(dataset, offset, limit, model, no_foley):
else:
results.append({'catid': cat_id, 'result': "MISS"})
progress_bar.update(1)
total = len(results)
total_top = len([x for x in results if x['result'] == 'TOP'])
total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])