Compare commits

...

17 Commits

Author SHA1 Message Date
615d8ab279 Refactoring of gather 2025-09-10 23:49:33 -07:00
9a887c4ed5 Refactoring of gather 2025-09-10 23:38:33 -07:00
63b140209b Command-line entry point:wq 2025-09-10 22:23:57 -07:00
d181ac73b1 Made prompt shorter 2025-09-10 22:08:16 -07:00
bddce23c76 Added to prompt help 2025-09-10 22:07:53 -07:00
b4758dd138 Some features for recommend, a browse feature 2025-09-06 14:10:53 -07:00
04332b73ee Fixed a bug in cat masking 2025-09-04 10:54:35 -07:00
103fffe0a4 Added another function to recommend 2025-09-04 10:43:31 -07:00
10519f9c1a Split training sets from gather 2025-09-03 19:40:55 -07:00
e419f698c9 Dataset metadata 2025-09-03 16:31:05 -07:00
6cd0415a26 README notes 2025-09-03 16:25:05 -07:00
b833d6d3c0 Added better error reporting to gather
And a bunch of options to control behavior
2025-09-03 16:07:41 -07:00
fee55a7d5a Twiddles 2025-09-03 14:56:41 -07:00
0594899bdd dataset writing in HF format 2025-09-03 14:52:10 -07:00
46a693bf93 tweaks 2025-09-03 14:23:09 -07:00
0a2aaa2a22 refining gather, adding TODOs 2025-09-03 14:13:12 -07:00
d855ba4c78 removed print message 2025-09-03 12:25:33 -07:00
9 changed files with 266 additions and 173 deletions

View File

@@ -49,15 +49,13 @@ Pass `--help` to see a summary of subcommands and options.
* gather
Scan files to capture existing text descriptions and UCS categories
and save as a dataset. This function is used to construct datasets
that `evaluate` can use to test models and finetune can use to
refine them.
and save as a dataset.
* ~finetune~ (planned)
Fine-tune an existing sentence embedding model with training data.
* evaluate
* ~evaluate~ (FIXME phase)
Use datasets to evaluate the performance of a model and fine-tuning.

14
TODO.md
View File

@@ -2,21 +2,27 @@
- Use History when adding catids
## Gather
- Add "source" column for tracking provenance
- Maybe more dataset configurations
## Fine-tune
- Implement
- Implement BatchAllTripletLoss
## Evaluate
- Print more information about the dataset coverage of UCS
- Allow skipping model testing for this
- Print raw output
- Maybe load everything into a sqlite for slicker reporting
<!-- - Maybe load everything into a sqlite for slicker reporting -->
## Utility
- Dataset partitioning
- Clear caches

View File

@@ -13,7 +13,8 @@ dependencies = [
"tqdm (>=4.67.1,<5.0.0)",
"platformdirs (>=4.3.8,<5.0.0)",
"click (>=8.2.1,<9.0.0)",
"tabulate (>=0.9.0,<0.10.0)"
"tabulate (>=0.9.0,<0.10.0)",
"datasets (>=4.0.0,<5.0.0)"
]
@@ -25,3 +26,6 @@ build-backend = "poetry.core.masonry.api"
ipython = "^9.4.0"
jupyter = "^1.1.1"
[tool.poetry.scripts]
ucsinfer = "ucsinfer.__main__:ucsinfer"

View File

@@ -1,13 +1,15 @@
import os
import sys
import csv
# import csv
import logging
from itertools import chain
import tqdm
import click
from tabulate import tabulate, SEPARATING_LINE
# from tabulate import tabulate, SEPARATING_LINE
from .inference import InferenceContext, load_ucs
from .gather import (build_sentence_class_dataset, print_dataset_stats,
ucs_definitions_generator, scan_metadata, walk_path)
from .recommend import print_recommendation
from .util import ffmpeg_description, parse_ucs
@@ -23,10 +25,17 @@ logger.addHandler(stream_handler)
@click.group(epilog="For more information see "
"<https://git.squad51.us/jamie/ucsinfer>")
@click.option('--verbose', '-v', flag_value=True, help='Verbose output')
@click.option('--model', type=str, metavar="<model-name>",
default="paraphrase-multilingual-mpnet-base-v2",
show_default=True,
help="Select the sentence_transformer model to use")
@click.option('--no-model-cache', flag_value=True,
help="Don't use local model cache")
@click.option('--complete-ucs', flag_value=True, default=False,
help="Use all UCS categories. By default, all 'FOLEY' and "
"'ARCHIVED' UCS categories are excluded from all functions.")
@click.pass_context
def ucsinfer(ctx, verbose, no_model_cache):
def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
"""
Tools for applying UCS categories to sounds using large-language Models
"""
@@ -44,6 +53,19 @@ def ucsinfer(ctx, verbose, no_model_cache):
ctx.ensure_object(dict)
ctx.obj['model_cache'] = not no_model_cache
ctx.obj['model_name'] = model
ctx.obj['complete_ucs'] = complete_ucs
if no_model_cache:
logger.info("Model cache inhibited by config")
if complete_ucs:
logger.info("Using complete UCS catgeory list")
else:
logger.info("Non-descriptive UCS categories will be excluded. Turn "
"this option off by passing --complete-ucs.")
logger.info(f"Using model {model}")
@ucsinfer.command('recommend')
@@ -51,10 +73,6 @@ def ucsinfer(ctx, verbose, no_model_cache):
help="Recommend a category for given text instead of reading "
"from a file")
@click.argument('paths', nargs=-1, metavar='<paths>')
@click.option('--model', type=str, metavar="<model-name>",
default="paraphrase-multilingual-mpnet-base-v2",
show_default=True,
help="Select the sentence_transformer model to use")
@click.option('--interactive','-i', flag_value=True, default=False,
help="After processing each path in <paths>, prompt for a "
"recommendation to accept, and then prepend the selection to "
@@ -63,7 +81,7 @@ def ucsinfer(ctx, verbose, no_model_cache):
help="Skip files that already have a UCS category in their "
"name.")
@click.pass_context
def recommend(ctx, text, paths, model, interactive, skip_ucs):
def recommend(ctx, text, paths, interactive, skip_ucs):
"""
Infer a UCS category for a text description
@@ -74,8 +92,9 @@ def recommend(ctx, text, paths, model, interactive, skip_ucs):
of ranked subcategories is printed to the terminal for each PATH.
"""
logger.debug("RECOMMEND mode")
inference_ctx = InferenceContext(model,
use_cached_model=ctx.obj['model_cache'])
inference_ctx = InferenceContext(ctx.obj['model_name'],
use_cached_model=ctx.obj['model_cache'],
use_full_ucs=ctx.obj['complete_ucs'])
if text is not None:
print_recommendation(None, text, inference_ctx,
@@ -84,6 +103,11 @@ def recommend(ctx, text, paths, model, interactive, skip_ucs):
catlist = [x.catid for x in inference_ctx.catlist]
for path in paths:
_, ext = os.path.splitext(path)
if ext not in (".wav", ".flac"):
continue
basename = os.path.basename(path)
if skip_ucs and parse_ucs(basename, catlist):
continue
@@ -110,51 +134,59 @@ def recommend(ctx, text, paths, model, interactive, skip_ucs):
@ucsinfer.command('gather')
@click.option('--outfile', type=click.File(mode='w', encoding='utf8'),
default='dataset.csv', show_default=True)
@click.option('--out', default='dataset/', show_default=True)
@click.option('--ucs-data', flag_value=True, help="Create a dataset based "
"on the UCS category explanations and synonymns (PATHS will "
"be ignored.)")
@click.argument('paths', nargs=-1)
def gather(paths, outfile):
@click.pass_context
def gather(ctx, paths, out, ucs_data):
"""
Scan files to build a training dataset at PATH
Scan files to build a training dataset
The `gather` command walks the directory hierarchy for each path in PATHS
and looks for .wav and .flac files that are named according to the UCS
file naming guidelines, with at least a CatID and FX Name, divided by an
underscore.
The `gather` is used to build a training dataset for finetuning the
selected model. Description sentences and UCS categories are collected from
'.wav' and '.flac' files on-disk that have valid UCS filenames and assigned
CatIDs, and this information is recorded into a HuggingFace dataset.
For every file ucsinfer finds that meets this criteria, it creates a record
in an output dataset CSV file. The dataset file has two columns: the first
is the CatID indicated for the file, and the second is the embedded file
description for the file as returned by ffprobe.
Gather scans the filesystem in two passes: first, the directory tree is
walked by os.walk and a list of filenames that meet the above name criteria
is compiled. After this list is compiled, each file is scanned one-by-one
with ffprobe to obtain its "description" metadata; if this isn't present,
the parsed 'fxname' of te file becomes the description.
"""
logger.debug("GATHER mode")
types = ['.wav', '.flac']
table = csv.writer(outfile)
logger.debug(f"Loading category list...")
catid_list = [cat.catid for cat in load_ucs()]
ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
scan_list: list[tuple[str,str]] = []
catid_list = [cat.catid for cat in ucs]
if ucs_data:
logger.info('Creating dataset for UCS categories instead of from PATH')
paths = []
scan_list = []
for path in paths:
logger.info(f"Scanning directory {path}...")
for dirpath, _, filenames in os.walk(path):
for filename in filenames:
root, ext = os.path.splitext(filename)
if ext in types and \
(ucs_components := parse_ucs(root, catid_list)) and \
not filename.startswith("._"):
scan_list.append((ucs_components.cat_id,
os.path.join(dirpath, filename)))
scan_list += walk_path(path, catid_list)
logger.info(f"Found {len(scan_list)} files to process.")
for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
if desc := ffmpeg_description(pair[1]):
table.writerow([pair[0], desc])
logger.info("Building dataset...")
dataset = build_sentence_class_dataset(
chain(scan_metadata(scan_list, catid_list),
ucs_definitions_generator(ucs)),
catid_list)
logger.info(f"Saving dataset to disk at {out}")
print_dataset_stats(dataset, catid_list)
dataset.save_to_disk(out)
@ucsinfer.command('finetune')
def finetune():
@click.pass_context
def finetune(ctx):
"""
Fine-tune a model with training data
"""
@@ -163,120 +195,15 @@ def finetune():
@ucsinfer.command('evaluate')
@click.option('--offset', type=int, default=0, metavar="<int>",
help='Skip this many records in the dataset before processing')
@click.option('--limit', type=int, default=-1, metavar="<int>",
help='Process this many records and then exit')
@click.option('--no-foley', 'no_foley', flag_value=True, default=False,
help="Ignore any data in the set with FOLYProp or FOLYFeet "
"category")
@click.option('--model', type=str, metavar="<model-name>",
default="paraphrase-multilingual-mpnet-base-v2",
show_default=True,
help="Select the sentence_transformer model to use")
@click.argument('dataset', type=click.File('r', encoding='utf8'),
default='dataset.csv')
@click.argument('dataset', default='dataset/')
@click.pass_context
def evaluate(ctx, dataset, offset, limit, model, no_foley):
def evaluate(ctx, dataset, offset, limit):
"""
Use datasets to evaluate model performance
The `evaluate` command reads the input DATASET file row by row and
performs a classifcation of the given description against the selected
model (either the default or using the --model option). The command then
checks if the model inferred the correct category as given by the dataset.
The model gives its top 10 possible categories for a given description,
and the results are tabulated according to (1) wether the top
classification was correct, (2) wether the correct classifcation was in the
top 5, or (3) wether it was in the top 10. The worst-performing category,
the one with the most misses, is also reported as well as the category
coverage, how many categories are present in the dataset.
NOTE: With experimentation it was found that foley items generally were
classified according to their subject and not wether or not they were
foley, and so these categories can be excluded with the --no-foley option.
"""
logger.debug("EVALUATE mode")
inference_context = InferenceContext(model,
use_cached_model=
ctx.obj['model_cache'])
reader = csv.reader(dataset)
logger.warning("Model evaluation is not currently implemented")
logger.info(f"Evaluating model {model}...")
results = []
if offset > 0:
logger.debug(f"Skipping {offset} records...")
if limit > 0:
logger.debug(f"Will only evaluate {limit} records...")
progress_bar = tqdm.tqdm(total=limit,
desc="Processing dataset...",
unit="rec")
for i, row in enumerate(reader):
if i < offset:
continue
if limit > 0 and i >= limit + offset:
break
cat_id, description = row
if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
continue
guesses = inference_context.classify_text_ranked(description, limit=10)
if cat_id == guesses[0]:
results.append({'catid': cat_id, 'result': "TOP"})
elif cat_id in guesses[0:5]:
results.append({'catid': cat_id, 'result': "TOP_5"})
elif cat_id in guesses:
results.append({'catid': cat_id, 'result': "TOP_10"})
else:
results.append({'catid': cat_id, 'result': "MISS"})
progress_bar.update(1)
total = len(results)
total_top = len([x for x in results if x['result'] == 'TOP'])
total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
cats = set([x['catid'] for x in results])
total_cats = len(cats)
miss_counts = []
for cat in cats:
miss_counts.append(
(cat, len([x for x in results
if x['catid'] == cat and x['result'] == 'MISS'])))
miss_counts = sorted(miss_counts, key=lambda x: x[1])
print(f"## Results for Model {model} ##\n")
if no_foley:
print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
table = [
["Total records in sample:", f"{total}"],
["Top Result:", f"{total_top}",
f"{float(total_top)/float(total):.2%}"],
["Top 5 Result:", f"{total_top_5}",
f"{float(total_top_5)/float(total):.2%}"],
["Top 10 Result:", f"{total_top_10}",
f"{float(total_top_10)/float(total):.2%}"],
SEPARATING_LINE,
["UCS category count:", f"{len(inference_context.catlist)}"],
["Total categories in sample:", f"{total_cats}",
f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
[f"Most missed category ({miss_counts[-1][0]}):",
f"{miss_counts[-1][1]}",
f"{float(miss_counts[-1][1])/float(total):.2%}"]
]
print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
if __name__ == '__main__':

7
ucsinfer/evaluate.py Normal file
View File

@@ -0,0 +1,7 @@
# from sentence_transformers import SentenceTransformer
# from sentence_transformers.evaluation import BinaryClassificationEvaluator
# from datasets import load_dataset_from_disk, DatasetDict
#

118
ucsinfer/gather.py Normal file
View File

@@ -0,0 +1,118 @@
from .inference import Ucs
from .util import ffmpeg_description, parse_ucs
from subprocess import CalledProcessError
import os.path
from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
from datasets.dataset_dict import DatasetDict
from typing import Iterator, Generator
from tabulate import tabulate
import logging
import tqdm
def walk_path(path:str, catid_list) -> list[tuple[str,str]]:
types = ['.wav', '.flac']
logger = logging.getLogger('ucsinfer')
walker_p = tqdm.tqdm(total=None, unit='dir', desc="Walking filesystem...")
scan_list = []
for dirpath, _, filenames in os.walk(path):
logger.info(f"Walking directory {dirpath}")
for filename in filenames:
walker_p.update()
root, ext = os.path.splitext(filename)
if ext not in types or filename.startswith("._"):
continue
if (ucs_components := parse_ucs(root, catid_list)):
p = os.path.join(dirpath, filename)
logger.info(f"Adding path to scan list {p}")
scan_list.append((ucs_components.cat_id, p))
return scan_list
def scan_metadata(scan_list: list[tuple[str,str]], catid_list: list[str]):
logger = logging.getLogger('ucsinfer')
for pair in tqdm.tqdm(scan_list, unit='files'):
logger.info(f"Scanning file with ffprobe: {pair[1]}")
try:
desc = ffmpeg_description(pair[1])
except CalledProcessError as e:
logger.error(f"ffprobe returned error (){e.returncode}): " \
+ e.stderr)
continue
if desc:
yield desc, str(pair[0])
else:
comps = parse_ucs(os.path.basename(pair[1]), catid_list)
assert comps
yield comps.fx_name, str(pair[0])
def ucs_definitions_generator(ucs: list[Ucs]) \
-> Generator[tuple[str,str],None, None]:
for cat in ucs:
yield cat.explanations, cat.catid
yield ", ".join(cat.synonymns), cat.catid
def print_dataset_stats(dataset: DatasetDict, catlist: list[str]):
data_table = []
data_table.append([["Total records in combined dataset:", len(dataset)]])
data_table.append([["Total records in `train`:", len(dataset['train'])]])
tab = tabulate(data_table)
print(tab)
# https://www.sbert.net/docs/sentence_transformer/loss_overview.html
def build_sentence_class_dataset(
records: Iterator[tuple[str, str]],
catlist: list[str]) -> DatasetDict:
"""
Create a new dataset for `records` which contains (sentence, class) pairs.
The dataset is split into train and test slices.
:param records: a generator for records that generates pairs of
(sentence, catid)
:returns: A dataset with two columns: (sentence, hash(catid))
"""
labels = ClassLabel(names=catlist)
features = Features({'sentence': Value('string'),
'class': labels})
info = DatasetInfo(
description=f"(sentence, UCS CatID) pairs gathered by the "
"ucsinfer tool on {}", features= features)
items: list[dict] = []
for obj in records:
items += [{'sentence': obj[0], 'class': obj[1]}]
whole = Dataset.from_list(items, features=features, info=info)
split_set = whole.train_test_split(0.2)
test_eval_set = split_set['test'].train_test_split(0.5)
return DatasetDict({
'train': split_set['train'],
'test': test_eval_set['train'],
'eval': test_eval_set['test']
})
# def build_sentence_anchor_dataset() -> Dataset:
# """
# Create a new dataset for `records` which contains (sentence, anchor) pairs.
# """
# pass

View File

@@ -34,7 +34,7 @@ class Ucs(NamedTuple):
explanations=d['Explanations'], synonymns=d['Synonyms'])
def load_ucs() -> list[Ucs]:
def load_ucs(full_ucs: bool = True) -> list[Ucs]:
FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
cats = []
ucs_defs = os.path.join(FILE_ROOT_DIR, 'ucs-community', 'json',
@@ -43,7 +43,13 @@ def load_ucs() -> list[Ucs]:
with open(ucs_defs, 'r') as f:
cats = json.load(f)
return [Ucs.from_dict(cat) for cat in cats]
ucs = [Ucs.from_dict(cat) for cat in cats]
if full_ucs:
return ucs
else:
return [cat for cat in ucs if \
cat.category not in ['FOLEY', 'ARCHIVED']]
class InferenceContext:
@@ -54,8 +60,11 @@ class InferenceContext:
model: SentenceTransformer
model_name: str
def __init__(self, model_name: str, use_cached_model: bool = True):
def __init__(self, model_name: str, use_cached_model: bool = True,
use_full_ucs: bool = False):
self.model_name = model_name
self.use_full_ucs = use_full_ucs
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
'Squad 51')
@@ -75,7 +84,7 @@ class InferenceContext:
@cached_property
def catlist(self) -> list[Ucs]:
return load_ucs()
return load_ucs(full_ucs=self.use_full_ucs)
@cached_property
def embeddings(self) -> list[dict]:
@@ -94,7 +103,12 @@ class InferenceContext:
else:
print(f"Calculating embeddings for model {self.model_name}...")
for cat_defn in self.catlist:
# we need to calculate the embeddings for all cats, not just the
# ones we're loading for this run
full_catlist = load_ucs(full_ucs= True)
for cat_defn in full_catlist:
embeddings += [{
'CatID': cat_defn.catid,
'Embedding': self._encode_category(cat_defn)
@@ -104,7 +118,9 @@ class InferenceContext:
with open(embedding_cache_path, 'wb') as g:
pickle.dump(embeddings, g)
return embeddings
whitelisted_cats = [cat.catid for cat in self.catlist]
return [e for e in embeddings if e['CatID'] in whitelisted_cats]
def _encode_category(self, cat: Ucs) -> np.ndarray:
sentence_components = [cat.explanations,

View File

@@ -4,8 +4,10 @@ from re import match
from .inference import InferenceContext
from tabulate import tabulate
def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
interactive_rename: bool):
interactive_rename: bool, recommend_limit=10):
"""
Print recommendations interactively.
@@ -17,18 +19,19 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
`print_recommendation` should be called again with this argument.
- if retval[2] is a str, this is the catid the user has selected.
"""
recs = ctx.classify_text_ranked(text)
recs = ctx.classify_text_ranked(text, limit=recommend_limit)
print("----------")
if path:
print(f"Path: {path}")
print(f"Text: {text or '<None>'}")
for i, r in enumerate(recs):
cat, subcat, _ = ctx.lookup_category(r)
print(f"- {i}: {r} ({cat}-{subcat})")
if interactive_rename and path is not None:
response = input("#, t [text], ?, q > ")
response = input("(n#), t, c, b, ?, q > ")
if m := match(r'^([0-9]+)', response):
selection = int(m.group(1))
@@ -43,12 +46,27 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
text = m.group(1)
return True, text, None
elif m := match(r'^c (.+)', response):
return True, None, m.group(1)
elif m := match(r'^b (.+)', response):
expt = []
for cat in ctx.catlist:
if cat.catid.startswith(m.group(1)):
expt.append([f"{cat.catid}: ({cat.category}-{cat.subcategory})",
cat.explanations])
print(tabulate(expt, maxcolwidths=80))
return True, text, None
elif response.startswith("?"):
print("""
Choices:
- Enter recommendation number to rename file,
- "t [text]" to search for new recommendations based on [text]
- "t <text>" to search for new recommendations based on <text>
- "p" re-use the last selected cat-id
- "c <cat>" to type in a category by hand
- "b <cat>" browse category list for categories starting with <cat>
- "?" for this message
- "q" to quit
- or any other key to skip this file and continue to next file
@@ -56,6 +74,8 @@ Choices:
return True, text, None
elif response.startswith('q'):
return (False, None, None)
else:
print()
else:
return None

View File

@@ -9,10 +9,7 @@ def ffmpeg_description(path: str) -> Optional[str]:
result = subprocess.run(['ffprobe', '-show_format', '-of',
'json', path], capture_output=True)
try:
result.check_returncode()
except:
return None
stream = json.loads(result.stdout)
fmt = stream.get("format", None)