Added better error reporting to gather

And a bunch of options to control behavior
2025-09-03 16:07:41 -07:00
parent fee55a7d5a
commit b833d6d3c0
4 changed files with 165 additions and 103 deletions


@@ -1,11 +1,11 @@
 import os
-import csv
+# import csv
 import logging
+from subprocess import CalledProcessError
 import tqdm
 import click
-from tabulate import tabulate, SEPARATING_LINE
+# from tabulate import tabulate, SEPARATING_LINE
 from .inference import InferenceContext, load_ucs
 from .gather import build_sentence_class_dataset
@@ -30,8 +30,11 @@ logger.addHandler(stream_handler)
help="Select the sentence_transformer model to use") help="Select the sentence_transformer model to use")
@click.option('--no-model-cache', flag_value=True, @click.option('--no-model-cache', flag_value=True,
help="Don't use local model cache") help="Don't use local model cache")
@click.option('--complete-ucs', flag_value=True, default=False,
help="Use all UCS categories. By default, all 'FOLEY' and "
"'ARCHIVED' UCS categories are excluded from all functions.")
@click.pass_context @click.pass_context
def ucsinfer(ctx, verbose, no_model_cache, model): def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
""" """
Tools for applying UCS categories to sounds using large-language Models Tools for applying UCS categories to sounds using large-language Models
""" """
@@ -50,10 +53,17 @@ def ucsinfer(ctx, verbose, no_model_cache, model):
     ctx.ensure_object(dict)
     ctx.obj['model_cache'] = not no_model_cache
     ctx.obj['model_name'] = model
+    ctx.obj['complete_ucs'] = complete_ucs
     if no_model_cache:
         logger.info("Model cache inhibited by config")
+    if complete_ucs:
+        logger.info("Using complete UCS category list")
+    else:
+        logger.info("Non-descriptive UCS categories will be excluded. Turn "
+                    "this option off by passing --complete-ucs.")
     logger.info(f"Using model {model}")
@@ -82,7 +92,8 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
""" """
logger.debug("RECOMMEND mode") logger.debug("RECOMMEND mode")
inference_ctx = InferenceContext(ctx.obj['model_name'], inference_ctx = InferenceContext(ctx.obj['model_name'],
use_cached_model=ctx.obj['model_cache']) use_cached_model=ctx.obj['model_cache'],
use_full_ucs=ctx.obj['complete_ucs'])
if text is not None: if text is not None:
print_recommendation(None, text, inference_ctx, print_recommendation(None, text, inference_ctx,
@@ -122,7 +133,8 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
"on the UCS category explanations and synonymns (PATHS will " "on the UCS category explanations and synonymns (PATHS will "
"be ignored.)") "be ignored.)")
@click.argument('paths', nargs=-1) @click.argument('paths', nargs=-1)
def gather(paths, out, ucs_data): @click.pass_context
def gather(ctx, paths, out, ucs_data):
""" """
Scan files to build a training dataset Scan files to build a training dataset
@@ -142,7 +154,7 @@ def gather(paths, out, ucs_data):
     types = ['.wav', '.flac']
     logger.debug(f"Loading category list...")
-    ucs = load_ucs()
+    ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
     scan_list = []
     catid_list = [cat.catid for cat in ucs]
@@ -152,32 +164,44 @@ def gather(paths, out, ucs_data):
         paths = []
     for path in paths:
+        logger.info(f"Scanning directory {path}...")
         for dirpath, _, filenames in os.walk(path):
+            logger.info(f"Walking directory {dirpath}")
             for filename in filenames:
                 root, ext = os.path.splitext(filename)
-                if ext in types and \
-                        (ucs_components := parse_ucs(root, catid_list)) and \
-                        not filename.startswith("._"):
-                    scan_list.append((ucs_components.cat_id,
-                                      os.path.join(dirpath, filename)))
+                if ext not in types or filename.startswith("._"):
+                    continue
+                if (ucs_components := parse_ucs(root, catid_list)):
+                    p = os.path.join(dirpath, filename)
+                    logger.info(f"Adding path to scan list {p}")
+                    scan_list.append((ucs_components.cat_id, p))
 
     logger.info(f"Found {len(scan_list)} files to process.")
 
     def scan_metadata():
         for pair in tqdm.tqdm(scan_list, unit='files'):
-            if desc := ffmpeg_description(pair[1]):
+            logger.info(f"Scanning file with ffprobe: {pair[1]}")
+            try:
+                desc = ffmpeg_description(pair[1])
+            except CalledProcessError as e:
+                logger.error(f"ffprobe returned error {e.returncode}: " \
+                             + e.stderr)
+                continue
+            if desc:
                 yield desc, str(pair[0])
             else:
                 comps = parse_ucs(os.path.basename(pair[1]), catid_list)
                 assert comps
                 yield comps.fx_name, str(pair[0])
 
     def ucs_metadata():
         for cat in ucs:
             yield cat.explanations, cat.catid
             yield ", ".join(cat.synonymns), cat.catid
 
+    logger.info("Building dataset...")
     if ucs_data:
         dataset = build_sentence_class_dataset(ucs_metadata(), catid_list)
     else:
@@ -202,9 +226,6 @@ def finetune(ctx):
               help='Skip this many records in the dataset before processing')
 @click.option('--limit', type=int, default=-1, metavar="<int>",
               help='Process this many records and then exit')
-@click.option('--no-foley', 'no_foley', flag_value=True, default=False,
-              help="Ignore any data in the set with FOLYProp or FOLYFeet "
-                   "category")
 @click.argument('dataset', type=click.File('r', encoding='utf8'),
                 default='dataset.csv')
 @click.pass_context
@@ -229,84 +250,86 @@ def evaluate(ctx, dataset, offset, limit, no_foley):
     foley, and so these categories can be excluded with the --no-foley option.
     """
     logger.debug("EVALUATE mode")
-    inference_context = InferenceContext(ctx.obj['model_name'],
-                                         use_cached_model=
-                                         ctx.obj['model_cache'])
-    reader = csv.reader(dataset)
-
-    results = []
-
-    if offset > 0:
-        logger.debug(f"Skipping {offset} records...")
-
-    if limit > 0:
-        logger.debug(f"Will only evaluate {limit} records...")
-
-    progress_bar = tqdm.tqdm(total=limit,
-                             desc="Processing dataset...",
-                             unit="rec")
-    for i, row in enumerate(reader):
-        if i < offset:
-            continue
-
-        if limit > 0 and i >= limit + offset:
-            break
-
-        cat_id, description = row
-        if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
-            continue
-
-        guesses = inference_context.classify_text_ranked(description, limit=10)
-        if cat_id == guesses[0]:
-            results.append({'catid': cat_id, 'result': "TOP"})
-        elif cat_id in guesses[0:5]:
-            results.append({'catid': cat_id, 'result': "TOP_5"})
-        elif cat_id in guesses:
-            results.append({'catid': cat_id, 'result': "TOP_10"})
-        else:
-            results.append({'catid': cat_id, 'result': "MISS"})
-
-        progress_bar.update(1)
-
-    total = len(results)
-    total_top = len([x for x in results if x['result'] == 'TOP'])
-    total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
-    total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
-
-    cats = set([x['catid'] for x in results])
-    total_cats = len(cats)
-
-    miss_counts = []
-    for cat in cats:
-        miss_counts.append(
-            (cat, len([x for x in results
-                       if x['catid'] == cat and x['result'] == 'MISS'])))
-
-    miss_counts = sorted(miss_counts, key=lambda x: x[1])
-
-    print(f"## Results for Model {model} ##\n")
-
-    if no_foley:
-        print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
-
-    table = [
-        ["Total records in sample:", f"{total}"],
-        ["Top Result:", f"{total_top}",
-         f"{float(total_top)/float(total):.2%}"],
-        ["Top 5 Result:", f"{total_top_5}",
-         f"{float(total_top_5)/float(total):.2%}"],
-        ["Top 10 Result:", f"{total_top_10}",
-         f"{float(total_top_10)/float(total):.2%}"],
-        SEPARATING_LINE,
-        ["UCS category count:", f"{len(inference_context.catlist)}"],
-        ["Total categories in sample:", f"{total_cats}",
-         f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
-        [f"Most missed category ({miss_counts[-1][0]}):",
-         f"{miss_counts[-1][1]}",
-         f"{float(miss_counts[-1][1])/float(total):.2%}"]
-    ]
-
-    print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
+    logger.warning("Model evaluation is not currently implemented")
+    # inference_context = InferenceContext(
+    #     ctx.obj['model_name'], use_cached_model=ctx.obj['model_cache'],
+    #     use_full_ucs=ctx.obj['complete_ucs'])
+    #
+    # reader = csv.reader(dataset)
+    #
+    # results = []
+    #
+    # if offset > 0:
+    #     logger.debug(f"Skipping {offset} records...")
+    #
+    # if limit > 0:
+    #     logger.debug(f"Will only evaluate {limit} records...")
+    #
+    # progress_bar = tqdm.tqdm(total=limit,
+    #                          desc="Processing dataset...",
+    #                          unit="rec")
+    # for i, row in enumerate(reader):
+    #     if i < offset:
+    #         continue
+    #
+    #     if limit > 0 and i >= limit + offset:
+    #         break
+    #
+    #     cat_id, description = row
+    #     if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
+    #         continue
+    #
+    #     guesses = inference_context.classify_text_ranked(description, limit=10)
+    #     if cat_id == guesses[0]:
+    #         results.append({'catid': cat_id, 'result': "TOP"})
+    #     elif cat_id in guesses[0:5]:
+    #         results.append({'catid': cat_id, 'result': "TOP_5"})
+    #     elif cat_id in guesses:
+    #         results.append({'catid': cat_id, 'result': "TOP_10"})
+    #     else:
+    #         results.append({'catid': cat_id, 'result': "MISS"})
+    #
+    #     progress_bar.update(1)
+    #
+    # total = len(results)
+    # total_top = len([x for x in results if x['result'] == 'TOP'])
+    # total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
+    # total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
+    #
+    # cats = set([x['catid'] for x in results])
+    # total_cats = len(cats)
+    #
+    # miss_counts = []
+    # for cat in cats:
+    #     miss_counts.append(
+    #         (cat, len([x for x in results
+    #                    if x['catid'] == cat and x['result'] == 'MISS'])))
+    #
+    # miss_counts = sorted(miss_counts, key=lambda x: x[1])
+    #
+    # print(f"## Results for Model {model} ##\n")
+    #
+    # if no_foley:
+    #     print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
+    #
+    # table = [
+    #     ["Total records in sample:", f"{total}"],
+    #     ["Top Result:", f"{total_top}",
+    #      f"{float(total_top)/float(total):.2%}"],
+    #     ["Top 5 Result:", f"{total_top_5}",
+    #      f"{float(total_top_5)/float(total):.2%}"],
+    #     ["Top 10 Result:", f"{total_top_10}",
+    #      f"{float(total_top_10)/float(total):.2%}"],
+    #     SEPARATING_LINE,
+    #     ["UCS category count:", f"{len(inference_context.catlist)}"],
+    #     ["Total categories in sample:", f"{total_cats}",
+    #      f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
+    #     [f"Most missed category ({miss_counts[-1][0]}):",
+    #      f"{miss_counts[-1][1]}",
+    #      f"{float(miss_counts[-1][1])/float(total):.2%}"]
+    # ]
+    #
+    # print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
 
 
 if __name__ == '__main__':
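The group-level --complete-ucs flag is stashed on ctx.obj by the ucsinfer group callback and read back by sub-commands such as gather, which is why gather now takes @click.pass_context. A self-contained sketch of that Click pattern, using toy command names rather than the real CLI:

import click

@click.group()
@click.option('--complete-ucs', flag_value=True, default=False)
@click.pass_context
def cli(ctx, complete_ucs):
    # The group callback runs first and stores shared options on ctx.obj
    ctx.ensure_object(dict)
    ctx.obj['complete_ucs'] = complete_ucs

@cli.command()
@click.pass_context
def show(ctx):
    # Sub-commands opt in with @click.pass_context and read the shared state
    click.echo(f"complete_ucs = {ctx.obj['complete_ucs']}")

if __name__ == '__main__':
    cli()  # e.g. `python sketch.py --complete-ucs show`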

ucsinfer/gather.py Normal file

@@ -0,0 +1,33 @@
+from datasets import Dataset, Features, Value, ClassLabel
+from typing import Generator, Any
+
+# https://www.sbert.net/docs/sentence_transformer/loss_overview.html
+
+
+def build_sentence_class_dataset(
+        records: Generator[tuple[str, str], Any, None], catlist: list[str]) -> Dataset:
+    """
+    Create a new dataset for `records` which contains (sentence, class) pairs.
+
+    :param records: a generator for records that generates pairs of
+        (sentence, catid)
+    :returns: A dataset with two columns: (sentence, hash(catid))
+    """
+    labels = ClassLabel(names=catlist)
+    items: list[dict] = []
+    for obj in records:
+        items += [{'sentence': obj[0], 'class': obj[1]}]
+
+    return Dataset.from_list(items, features=Features({'sentence': Value('string'),
+                                                        'class': labels}))
+
+
+# def build_sentence_anchor_dataset() -> Dataset:
+#     """
+#     Create a new dataset for `records` which contains (sentence, anchor) pairs.
+#     """
+#     pass
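To make the shape of the new dataset builder concrete, a minimal usage sketch follows; the category IDs and description sentences are invented for illustration:

from ucsinfer.gather import build_sentence_class_dataset

catlist = ["AMBForst", "DSGNSynth"]  # hypothetical UCS catids
records = (pair for pair in [
    ("birdsong and light wind in a pine forest", "AMBForst"),
    ("slowly evolving synthetic drone", "DSGNSynth"),
])

ds = build_sentence_class_dataset(records, catlist)
print(ds.features["class"].names)  # ['AMBForst', 'DSGNSynth']
print(ds[0])  # the 'class' column is a ClassLabel built from catlist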


@@ -34,7 +34,7 @@ class Ucs(NamedTuple):
                   explanations=d['Explanations'], synonymns=d['Synonyms'])
 
 
-def load_ucs() -> list[Ucs]:
+def load_ucs(full_ucs: bool = True) -> list[Ucs]:
     FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
     cats = []
     ucs_defs = os.path.join(FILE_ROOT_DIR, 'ucs-community', 'json',
@@ -43,7 +43,13 @@ def load_ucs() -> list[Ucs]:
     with open(ucs_defs, 'r') as f:
         cats = json.load(f)
 
-    return [Ucs.from_dict(cat) for cat in cats]
+    ucs = [Ucs.from_dict(cat) for cat in cats]
+    if full_ucs:
+        return ucs
+    else:
+        return [cat for cat in ucs if \
+                cat.category not in ['FOLEY', 'ARCHIVED']]
 
 
 class InferenceContext:
@@ -54,8 +60,11 @@ class InferenceContext:
     model: SentenceTransformer
     model_name: str
 
-    def __init__(self, model_name: str, use_cached_model: bool = True):
+    def __init__(self, model_name: str, use_cached_model: bool = True,
+                 use_full_ucs: bool = False):
         self.model_name = model_name
+        self.use_full_ucs = use_full_ucs
 
         cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
                                                 'Squad 51')
@@ -75,7 +84,7 @@ class InferenceContext:
     @cached_property
     def catlist(self) -> list[Ucs]:
-        return load_ucs()
+        return load_ucs(full_ucs=self.use_full_ucs)
 
     @cached_property
     def embeddings(self) -> list[dict]:
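A short sketch of the intended load_ucs semantics after this change; the excluded category names come from the new option's help text, and the package path follows the layout shown above:

from ucsinfer.inference import load_ucs

full = load_ucs(full_ucs=True)      # every UCS category definition
trimmed = load_ucs(full_ucs=False)  # drops 'FOLEY' and 'ARCHIVED' categories

assert all(cat.category not in ('FOLEY', 'ARCHIVED') for cat in trimmed)
print(f"{len(full) - len(trimmed)} categories excluded")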


@@ -9,10 +9,7 @@ def ffmpeg_description(path: str) -> Optional[str]:
     result = subprocess.run(['ffprobe', '-show_format', '-of',
                              'json', path], capture_output=True)
-    try:
-        result.check_returncode()
-    except:
-        return None
+    result.check_returncode()
 
     stream = json.loads(result.stdout)
     fmt = stream.get("format", None)
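With the try/except removed, ffmpeg_description propagates ffprobe failures as subprocess.CalledProcessError instead of silently returning None, which is what the new handling in gather's scan_metadata relies on. A minimal sketch of the calling pattern; the module path and file name are placeholders, since the real file name is not shown in this view:

import logging
from subprocess import CalledProcessError

from ucsinfer.util import ffmpeg_description  # hypothetical module path

logger = logging.getLogger(__name__)

path = "AMBForst_Birds At Dawn.wav"  # hypothetical file
try:
    desc = ffmpeg_description(path)
except CalledProcessError as e:
    # ffprobe exited non-zero; e.stderr is bytes because the call uses
    # capture_output=True without text=True
    logger.error("ffprobe returned error %d: %s",
                 e.returncode, e.stderr.decode(errors="replace"))
else:
    if desc:
        logger.info("Embedded description: %s", desc)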