Compare commits

...

9 Commits

Author SHA1 Message Date
615d8ab279 Refactoring of gather 2025-09-10 23:49:33 -07:00
9a887c4ed5 Refactoring of gather 2025-09-10 23:38:33 -07:00
63b140209b Command-line entry point 2025-09-10 22:23:57 -07:00
d181ac73b1 Made prompt shorter 2025-09-10 22:08:16 -07:00
bddce23c76 Added to prompt help 2025-09-10 22:07:53 -07:00
b4758dd138 Some features for recommend, a browse feature 2025-09-06 14:10:53 -07:00
04332b73ee Fixed a bug in cat masking 2025-09-04 10:54:35 -07:00
103fffe0a4 Added another function to recommend 2025-09-04 10:43:31 -07:00
10519f9c1a Split training sets from gather 2025-09-03 19:40:55 -07:00
7 changed files with 152 additions and 161 deletions

View File

@@ -17,8 +17,9 @@
- Print more information about the dataset coverage of UCS
- Allow skipping model testing for this
- Print raw output
- Maybe load everything into a sqlite for slicker reporting
<!-- - Maybe load everything into a sqlite for slicker reporting -->
## Utility

View File

@@ -26,3 +26,6 @@ build-backend = "poetry.core.masonry.api"
ipython = "^9.4.0"
jupyter = "^1.1.1"
[tool.poetry.scripts]
ucsinfer = "ucsinfer.__main__:ucsinfer"

View File

@@ -1,14 +1,15 @@
import os
# import csv
import logging
from subprocess import CalledProcessError
from itertools import chain
import tqdm
import click
# from tabulate import tabulate, SEPARATING_LINE
from .inference import InferenceContext, load_ucs
from .gather import build_sentence_class_dataset
from .gather import (build_sentence_class_dataset, print_dataset_stats,
ucs_definitions_generator, scan_metadata, walk_path)
from .recommend import print_recommendation
from .util import ffmpeg_description, parse_ucs
@@ -102,6 +103,11 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
catlist = [x.catid for x in inference_ctx.catlist]
for path in paths:
_, ext = os.path.splitext(path)
if ext not in (".wav", ".flac"):
continue
basename = os.path.basename(path)
if skip_ucs and parse_ucs(basename, catlist):
continue
@@ -151,12 +157,10 @@ def gather(ctx, paths, out, ucs_data):
"""
logger.debug("GATHER mode")
types = ['.wav', '.flac']
logger.debug(f"Loading category list...")
ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
scan_list = []
scan_list: list[tuple[str,str]] = []
catid_list = [cat.catid for cat in ucs]
if ucs_data:
@@ -164,50 +168,19 @@ def gather(ctx, paths, out, ucs_data):
paths = []
for path in paths:
for dirpath, _, filenames in os.walk(path):
logger.info(f"Walking directory {dirpath}")
for filename in filenames:
root, ext = os.path.splitext(filename)
if ext not in types or filename.startswith("._"):
continue
if (ucs_components := parse_ucs(root, catid_list)):
p = os.path.join(dirpath, filename)
logger.info(f"Adding path to scan list {p}")
scan_list.append((ucs_components.cat_id, p))
scan_list += walk_path(path, catid_list)
logger.info(f"Found {len(scan_list)} files to process.")
def scan_metadata():
for pair in tqdm.tqdm(scan_list, unit='files'):
logger.info(f"Scanning file with ffprobe: {pair[1]}")
try:
desc = ffmpeg_description(pair[1])
except CalledProcessError as e:
logger.error(f"ffprobe returned error {e.returncode}: " \
+ e.stderr)
continue
if desc:
yield desc, str(pair[0])
else:
comps = parse_ucs(os.path.basename(pair[1]), catid_list)
assert comps
yield comps.fx_name, str(pair[0])
def ucs_metadata():
for cat in ucs:
yield cat.explanations, cat.catid
yield ", ".join(cat.synonymns), cat.catid
logger.info("Building dataset...")
if ucs_data:
dataset = build_sentence_class_dataset(ucs_metadata(), catid_list)
else:
dataset = build_sentence_class_dataset(scan_metadata(), catid_list)
dataset = build_sentence_class_dataset(
chain(scan_metadata(scan_list, catid_list),
ucs_definitions_generator(ucs)),
catid_list)
logger.info(f"Saving dataset to disk at {out}")
print_dataset_stats(dataset, catid_list)
dataset.save_to_disk(out)
@@ -222,114 +195,15 @@ def finetune(ctx):
@ucsinfer.command('evaluate')
@click.option('--offset', type=int, default=0, metavar="<int>",
help='Skip this many records in the dataset before processing')
@click.option('--limit', type=int, default=-1, metavar="<int>",
help='Process this many records and then exit')
@click.argument('dataset', type=click.File('r', encoding='utf8'),
default='dataset.csv')
@click.argument('dataset', default='dataset/')
@click.pass_context
def evaluate(ctx, dataset, offset, limit, no_foley):
def evaluate(ctx, dataset, offset, limit):
"""
Use datasets to evaluate model performance
The `evaluate` command reads the input DATASET file row by row and
performs a classification of the given description against the selected
model (either the default or the one given with the --model option). The
command then checks whether the model inferred the correct category as
given by the dataset.
The model gives its top 10 possible categories for a given description, and
the results are tabulated according to (1) whether the top classification
was correct, (2) whether the correct classification was in the top 5, or
(3) whether it was in the top 10. The worst-performing category, the one
with the most misses, is also reported, along with the category coverage:
how many categories are present in the dataset.
NOTE: Experimentation found that foley items were generally classified
according to their subject rather than whether or not they were foley, so
these categories can be excluded with the --no-foley option.
"""
logger.debug("EVALUATE mode")
logger.warning("Model evaluation is not currently implemented")
# inference_context = InferenceContext(
# ctx.obj['model_name'], use_cached_model=ctx.obj['model_cache'],
# use_full_ucs=ctx.obj['complete_ucs'])
#
# reader = csv.reader(dataset)
#
# results = []
#
# if offset > 0:
# logger.debug(f"Skipping {offset} records...")
#
# if limit > 0:
# logger.debug(f"Will only evaluate {limit} records...")
#
# progress_bar = tqdm.tqdm(total=limit,
# desc="Processing dataset...",
# unit="rec")
# for i, row in enumerate(reader):
# if i < offset:
# continue
#
# if limit > 0 and i >= limit + offset:
# break
#
# cat_id, description = row
# if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
# continue
#
# guesses = inference_context.classify_text_ranked(description, limit=10)
# if cat_id == guesses[0]:
# results.append({'catid': cat_id, 'result': "TOP"})
# elif cat_id in guesses[0:5]:
# results.append({'catid': cat_id, 'result': "TOP_5"})
# elif cat_id in guesses:
# results.append({'catid': cat_id, 'result': "TOP_10"})
# else:
# results.append({'catid': cat_id, 'result': "MISS"})
#
# progress_bar.update(1)
#
# total = len(results)
# total_top = len([x for x in results if x['result'] == 'TOP'])
# total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
# total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
#
# cats = set([x['catid'] for x in results])
# total_cats = len(cats)
#
# miss_counts = []
# for cat in cats:
# miss_counts.append(
# (cat, len([x for x in results
# if x['catid'] == cat and x['result'] == 'MISS'])))
#
# miss_counts = sorted(miss_counts, key=lambda x: x[1])
#
# print(f"## Results for Model {model} ##\n")
#
# if no_foley:
# print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
#
# table = [
# ["Total records in sample:", f"{total}"],
# ["Top Result:", f"{total_top}",
# f"{float(total_top)/float(total):.2%}"],
# ["Top 5 Result:", f"{total_top_5}",
# f"{float(total_top_5)/float(total):.2%}"],
# ["Top 10 Result:", f"{total_top_10}",
# f"{float(total_top_10)/float(total):.2%}"],
# SEPARATING_LINE,
# ["UCS category count:", f"{len(inference_context.catlist)}"],
# ["Total categories in sample:", f"{total_cats}",
# f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
# [f"Most missed category ({miss_counts[-1][0]}):",
# f"{miss_counts[-1][1]}",
# f"{float(miss_counts[-1][1])/float(total):.2%}"]
# ]
#
# print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
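Since the DATASET argument now defaults to the `dataset/` directory written by `gather` (a DatasetDict saved with save_to_disk) rather than a CSV file, the commented-out CSV reader above no longer matches the data layout. A hedged sketch of how the evaluation loop could be re-seated on the saved dataset, assuming the `eval` split and the `sentence`/`class` features produced by build_sentence_class_dataset:

    from datasets import load_from_disk

    ds = load_from_disk(dataset)           # DatasetDict with train/test/eval
    labels = ds['eval'].features['class']  # ClassLabel mapping int <-> CatID
    for row in ds['eval']:
        description = row['sentence']
        cat_id = labels.int2str(row['class'])
        # classify `description` and tally TOP / TOP_5 / TOP_10 as above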
if __name__ == '__main__':

ucsinfer/evaluate.py (new file, 7 lines)
View File

@@ -0,0 +1,7 @@
# from sentence_transformers import SentenceTransformer
# from sentence_transformers.evaluation import BinaryClassificationEvaluator
# from datasets import load_from_disk, DatasetDict
#

View File

@@ -1,13 +1,82 @@
from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
from .inference import Ucs
from .util import ffmpeg_description, parse_ucs
from typing import Generator, Any
from subprocess import CalledProcessError
import os.path
from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
from datasets.dataset_dict import DatasetDict
from typing import Iterator, Generator
from tabulate import tabulate
import logging
import tqdm
def walk_path(path: str, catid_list: list[str]) -> list[tuple[str, str]]:
types = ['.wav', '.flac']
logger = logging.getLogger('ucsinfer')
walker_p = tqdm.tqdm(total=None, unit='dir', desc="Walking filesystem...")
scan_list = []
for dirpath, _, filenames in os.walk(path):
logger.info(f"Walking directory {dirpath}")
for filename in filenames:
walker_p.update()
root, ext = os.path.splitext(filename)
if ext not in types or filename.startswith("._"):
continue
if (ucs_components := parse_ucs(root, catid_list)):
p = os.path.join(dirpath, filename)
logger.info(f"Adding path to scan list {p}")
scan_list.append((ucs_components.cat_id, p))
return scan_list
def scan_metadata(scan_list: list[tuple[str,str]], catid_list: list[str]):
logger = logging.getLogger('ucsinfer')
for pair in tqdm.tqdm(scan_list, unit='files'):
logger.info(f"Scanning file with ffprobe: {pair[1]}")
try:
desc = ffmpeg_description(pair[1])
except CalledProcessError as e:
logger.error(f"ffprobe returned error (){e.returncode}): " \
+ e.stderr)
continue
if desc:
yield desc, str(pair[0])
else:
comps = parse_ucs(os.path.basename(pair[1]), catid_list)
assert comps
yield comps.fx_name, str(pair[0])
def ucs_definitions_generator(ucs: list[Ucs]) \
-> Generator[tuple[str,str],None, None]:
for cat in ucs:
yield cat.explanations, cat.catid
yield ", ".join(cat.synonymns), cat.catid
def print_dataset_stats(dataset: DatasetDict, catlist: list[str]):
data_table = []
data_table.append([["Total records in combined dataset:", len(dataset)]])
data_table.append([["Total records in `train`:", len(dataset['train'])]])
tab = tabulate(data_table)
print(tab)
# https://www.sbert.net/docs/sentence_transformer/loss_overview.html
def build_sentence_class_dataset(
records: Generator[tuple[str, str], Any, None], catlist: list[str]) -> Dataset:
records: Iterator[tuple[str, str]],
catlist: list[str]) -> DatasetDict:
"""
Create a new dataset for `records` which contains (sentence, class) pairs.
The dataset is split into train, test and eval slices.
:param records: a generator for records that generates pairs of
(sentence, catid)
@@ -16,9 +85,12 @@ def build_sentence_class_dataset(
labels = ClassLabel(names=catlist)
features = Features({'sentence': Value('string'),
'class': labels})
info = DatasetInfo(
description=f"(sentence, UCS CatID) pairs gathered by the "
"ucsinfer tool on {}")
"ucsinfer tool on {}", features= features)
items: list[dict] = []
@@ -26,9 +98,16 @@ def build_sentence_class_dataset(
items += [{'sentence': obj[0], 'class': obj[1]}]
return Dataset.from_list(items, features=Features({'sentence': Value('string'),
'class': labels}),
info=info)
whole = Dataset.from_list(items, features=features, info=info)
split_set = whole.train_test_split(0.2)
test_eval_set = split_set['test'].train_test_split(0.5)
return DatasetDict({
'train': split_set['train'],
'test': test_eval_set['train'],
'eval': test_eval_set['test']
})
# def build_sentence_anchor_dataset() -> Dataset:
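The split above first holds out 20% of the records with train_test_split(0.2) and then halves that hold-out again, so the returned DatasetDict is roughly 80% train, 10% test and 10% eval. A small usage sketch for the result, assuming it was saved to a `dataset/` directory as the gather command does:

    from datasets import load_from_disk

    ds = load_from_disk('dataset/')
    print({name: len(split) for name, split in ds.items()})
    # e.g. {'train': 8000, 'test': 1000, 'eval': 1000} for 10,000 records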

View File

@@ -103,7 +103,12 @@ class InferenceContext:
else:
print(f"Calculating embeddings for model {self.model_name}...")
for cat_defn in self.catlist:
# we need to calculate the embeddings for all cats, not just the
# ones we're loading for this run
full_catlist = load_ucs(full_ucs=True)
for cat_defn in full_catlist:
embeddings += [{
'CatID': cat_defn.catid,
'Embedding': self._encode_category(cat_defn)
@@ -113,7 +118,9 @@ class InferenceContext:
with open(embedding_cache_path, 'wb') as g:
pickle.dump(embeddings, g)
return embeddings
whitelisted_cats = [cat.catid for cat in self.catlist]
return [e for e in embeddings if e['CatID'] in whitelisted_cats]
def _encode_category(self, cat: Ucs) -> np.ndarray:
sentence_components = [cat.explanations,

View File

@@ -4,8 +4,10 @@ from re import match
from .inference import InferenceContext
from tabulate import tabulate
def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
interactive_rename: bool):
interactive_rename: bool, recommend_limit=10):
"""
Print recommendations interactively.
@@ -17,18 +19,19 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
`print_recommendation` should be called again with this argument.
- if retval[2] is a str, this is the catid the user has selected.
"""
recs = ctx.classify_text_ranked(text)
recs = ctx.classify_text_ranked(text, limit=recommend_limit)
print("----------")
if path:
print(f"Path: {path}")
print(f"Text: {text or '<None>'}")
for i, r in enumerate(recs):
cat, subcat, _ = ctx.lookup_category(r)
print(f"- {i}: {r} ({cat}-{subcat})")
if interactive_rename and path is not None:
response = input("#, t [text], ?, q > ")
response = input("(n#), t, c, b, ?, q > ")
if m := match(r'^([0-9]+)', response):
selection = int(m.group(1))
@@ -43,12 +46,27 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
text = m.group(1)
return True, text, None
elif m := match(r'^c (.+)', response):
return True, None, m.group(1)
elif m := match(r'^b (.+)', response):
expt = []
for cat in ctx.catlist:
if cat.catid.startswith(m.group(1)):
expt.append([f"{cat.catid}: ({cat.category}-{cat.subcategory})",
cat.explanations])
print(tabulate(expt, maxcolwidths=80))
return True, text, None
elif response.startswith("?"):
print("""
Choices:
- Enter recommendation number to rename file,
- "t [text]" to search for new recommendations based on [text]
- "t <text>" to search for new recommendations based on <text>
- "p" re-use the last selected cat-id
- "c <cat>" to type in a category by hand
- "b <cat>" browse category list for categories starting with <cat>
- "?" for this message
- "q" to quit
- or any other key to skip this file and continue to next file
@@ -56,6 +74,8 @@ Choices:
return True, text, None
elif response.startswith('q'):
return (False, None, None)
else:
print()
else:
return None