Compare commits

...

39 Commits

Author SHA1 Message Date
615d8ab279 Refactoring of gather 2025-09-10 23:49:33 -07:00
9a887c4ed5 Refactoring of gather 2025-09-10 23:38:33 -07:00
63b140209b Command-line entry point:wq 2025-09-10 22:23:57 -07:00
d181ac73b1 Made prompt shorter 2025-09-10 22:08:16 -07:00
bddce23c76 Added to prompt help 2025-09-10 22:07:53 -07:00
b4758dd138 Some features for recommend, a browse feature 2025-09-06 14:10:53 -07:00
04332b73ee Fixed a bug in cat masking 2025-09-04 10:54:35 -07:00
103fffe0a4 Added another function to recommend 2025-09-04 10:43:31 -07:00
10519f9c1a Split training sets from gather 2025-09-03 19:40:55 -07:00
e419f698c9 Dataset metadata 2025-09-03 16:31:05 -07:00
6cd0415a26 README notes 2025-09-03 16:25:05 -07:00
b833d6d3c0 Added better error reporting to gather (And a bunch of options to control behavior) 2025-09-03 16:07:41 -07:00
fee55a7d5a Twiddles 2025-09-03 14:56:41 -07:00
0594899bdd dataset writing in HF format 2025-09-03 14:52:10 -07:00
46a693bf93 tweaks 2025-09-03 14:23:09 -07:00
0a2aaa2a22 refining gather, adding TODOs 2025-09-03 14:13:12 -07:00
d855ba4c78 removed print message 2025-09-03 12:25:33 -07:00
336d6f013e removed print message 2025-09-03 12:24:31 -07:00
2a0b0367f8 Improved logging 2025-09-03 12:23:44 -07:00
184edcb7e4 Removed unnecessary parameter 2025-09-03 11:42:10 -07:00
d330f47462 Added some logging code 2025-09-03 11:36:38 -07:00
2fcdc24699 Made the warnings filter more specific 2025-09-03 11:16:51 -07:00
84eb046a38 Update README.md 2025-09-01 20:53:28 +00:00
1aa7260981 Update README.md 2025-09-01 20:53:07 +00:00
Jamie Hardt 33619d2ae2 Implemented '--skip-ucs' option for 'recommend' 2025-08-30 23:45:36 -07:00
Jamie Hardt 1a42d7d9f1 Refactored recommendation code into a new file 2025-08-30 23:29:45 -07:00
Jamie Hardt 010a93cf2b Displaying rename status better 2025-08-30 23:19:32 -07:00
Jamie Hardt 906656a8c9 Updated TODO 2025-08-30 22:33:02 -07:00
Jamie Hardt ea8bd0e09b Rename implementaion 2025-08-30 22:16:58 -07:00
Jamie Hardt 7c591e9dbb Merge branch 'master' of https://git.squad51.us/jamie/ucsinfer 2025-08-30 20:36:47 -07:00
Jamie Hardt 3009d3831e Elborated rename function to recommend 2025-08-30 20:36:36 -07:00
Jamie Hardt 47829c5427 Added rename function to recommend 2025-08-30 20:14:09 -07:00
0cb2f25568 Update TODO.md 2025-08-28 00:34:04 +00:00
009cf98039 Update README.md (Updated logline) 2025-08-28 00:28:00 +00:00
4570ed632a Update MODELS.md (Another run) 2025-08-28 00:26:45 +00:00
62654464f3 Update pyproject.toml (Added description) 2025-08-28 00:24:03 +00:00
7ef9d52135 Add TODO.md (Added a TODO list) 2025-08-28 00:21:34 +00:00
Jamie Hardt b50d8a6a06 Fixed readme 2025-08-27 15:51:26 -07:00
Jamie Hardt a25c3c857f spelling 2025-08-27 15:48:18 -07:00
10 changed files with 506 additions and 212 deletions

MODELS.md

@@ -30,3 +30,19 @@
 | UCS category count: | 752 | |
 | Total categories in sample: | 238 | 31.65% |
 | Most missed category (VEHCar): | 75 | 3.21% |
+
+## Evaluating model paraphrase-multilingual-mpnet-base-v2...
+
+(FOLYProp and FOLYFeet have been omitted from the dataset.)
+
+| | n | pct |
+|---------------------------------|------|--------|
+| Total records in sample: | 5800 | |
+| Top Result: | 681 | 11.74% |
+| Top 5 Result: | 1025 | 17.67% |
+| Top 10 Result: | 802 | 13.83% |
+| |
+| UCS category count: | 752 | |
+| Total categories in sample: | 298 | 39.63% |
+| Most missed category (MAGElem): | 481 | 8.29% |

README.md

@@ -1,6 +1,6 @@
 # ucsinfer
 
-Universal Category System LLM toolkit.
+Tools for applying UCS categories to sounds using large-language models
 
 ## Install
@@ -37,8 +37,6 @@ python -m ucsinfer [command]
 ```
 
 Pass `--help` to see a summary of subcommands and options.
-
-The subcommands available at this time are `gather` and `evaluate`.
 
 ## Functions
 
 * recommend
@@ -51,14 +49,17 @@ The subcommands available at this time are `gather` and `evaluate`.
 * gather
 
   Scan files to capture existing text descriptions and UCS categories
-  and save as a dataset. This function is used to countruct datasets
-  that `evaluate` can use to test models and finetune can use to
-  refine them.
+  and save as a dataset.
 
 * ~finetune~ (planned)
 
   Fine-tune an existing sentence embedding model with training data.
 
-* evaluate
+* ~evaluate~ (FIXME phase)
 
   Use datasets to evaluate the performance of a model and fine-tuning.
+
+# Demos and Articles
+
+* [Category Inference Experiments With UCSINFER](https://squad51.us/notebook/category_inference_experiments_ucsinfer/)
+* [UCSINFER for Renaming Sounds](https://squad51.us/notebook/ucsinfer_to_rename_sounds/)
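
Editor's note: the subcommands described in this README are implemented by the click group in ucsinfer/__main__.py, shown later in this compare. A minimal sketch of exercising the CLI from Python with click's CliRunner (a real click API); the sample text is invented:

```python
# Hypothetical invocation of the CLI documented above; the description text
# is an invented example, not from the repo.
from click.testing import CliRunner
from ucsinfer.__main__ import ucsinfer

runner = CliRunner()
result = runner.invoke(ucsinfer, ['recommend', '--text', 'car engine idling'],
                       obj={})
print(result.output)  # ranked UCS subcategory recommendations
```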

TODO.md (new file, 28 lines)

@@ -0,0 +1,28 @@
+## Recommend
+
+- Use History when adding catids
+
+## Gather
+
+- Maybe more dataset configurations
+
+## Fine-tune
+
+- Implement BatchAllTripletLoss
+
+## Evaluate
+
+- Print more information about the dataset coverage of UCS
+- Allow skipping model testing for this
+- Print raw output
+
+<!-- - Maybe load everything into a sqlite for slicker reporting -->
+
+## Utility
+
+- Clear caches
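
Editor's note: the "Implement BatchAllTripletLoss" item above dovetails with the (sentence, class) dataset that `gather` now produces (see ucsinfer/gather.py below). A hedged sketch of what that fine-tuning step could look like with the sentence-transformers trainer; the paths, split names, and the 'class' to 'label' rename are assumptions, not repo code:

```python
# A sketch only: BatchAllTripletLoss over the dataset written by
# `ucsinfer gather`. The trainer expects the label column to be named
# 'label', so the rename here is an assumption about wiring, not repo code.
from datasets import load_from_disk
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import BatchAllTripletLoss

model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
ds = load_from_disk("dataset/")
train = ds["train"].rename_column("class", "label")

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train,
    loss=BatchAllTripletLoss(model),  # groups batch items by shared label
)
trainer.train()
```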

pyproject.toml

@@ -1,7 +1,7 @@
 [project]
 name = "ucsinfer"
 version = "0.1.0"
-description = ""
+description = "Tools for applying UCS categories to sounds using large-language models"
 authors = [
     {name = "Jamie Hardt",email = "jamiehardt@me.com"}
 ]
@@ -13,7 +13,8 @@ dependencies = [
     "tqdm (>=4.67.1,<5.0.0)",
     "platformdirs (>=4.3.8,<5.0.0)",
     "click (>=8.2.1,<9.0.0)",
-    "tabulate (>=0.9.0,<0.10.0)"
+    "tabulate (>=0.9.0,<0.10.0)",
+    "datasets (>=4.0.0,<5.0.0)"
 ]
@@ -25,3 +26,6 @@ build-backend = "poetry.core.masonry.api"
 ipython = "^9.4.0"
 jupyter = "^1.1.1"
+
+[tool.poetry.scripts]
+ucsinfer = "ucsinfer.__main__:ucsinfer"
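
Editor's note: the new [tool.poetry.scripts] entry makes installers generate a `ucsinfer` console command that imports the attribute after the colon and calls it. A rough Python equivalent of that generated shim:

```python
# Roughly what the console script generated from
# "ucsinfer = ucsinfer.__main__:ucsinfer" does.
import sys
from ucsinfer.__main__ import ucsinfer

if __name__ == '__main__':
    sys.exit(ucsinfer())
```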

ucsinfer/__main__.py

@@ -1,242 +1,212 @@
 import os
-import sys
-import csv
+# import csv
+import logging
+from itertools import chain
 
-from sentence_transformers import SentenceTransformer
 import tqdm
 import click
-from tabulate import tabulate, SEPARATING_LINE
+# from tabulate import tabulate, SEPARATING_LINE
 
 from .inference import InferenceContext, load_ucs
+from .gather import (build_sentence_class_dataset, print_dataset_stats,
+                     ucs_definitions_generator, scan_metadata, walk_path)
+from .recommend import print_recommendation
 from .util import ffmpeg_description, parse_ucs
 
+logger = logging.getLogger('ucsinfer')
+logger.setLevel(logging.DEBUG)
+stream_handler = logging.StreamHandler()
+stream_handler.setLevel(logging.WARN)
+formatter = logging.Formatter(
+    '%(asctime)s. %(levelname)s %(name)s: %(message)s')
+stream_handler.setFormatter(formatter)
+logger.addHandler(stream_handler)
 
-def recommend_text(text: str, ctx: InferenceContext):
-    return ctx.classify_text_ranked(text)
-
-
-def print_recommendation(path: str | None, text: str, ctx: InferenceContext):
-    recommendations = ctx.classify_text_ranked(text)
-    print("----------")
-    if path:
-        print(f"Path: {path}")
-    print(f"Text: {text or '<None>'}")
-    for i, r in enumerate(recommendations):
-        cat, subcat, _ = ctx.lookup_category(r)
-        print(f"- {i}: {r} ({cat}-{subcat})")
 
 @click.group(epilog="For more information see "
              "<https://git.squad51.us/jamie/ucsinfer>")
-# @click.option('--verbose', flag_value='verbose', help='Verbose output')
-def ucsinfer():
+@click.option('--verbose', '-v', flag_value=True, help='Verbose output')
+@click.option('--model', type=str, metavar="<model-name>",
+              default="paraphrase-multilingual-mpnet-base-v2",
+              show_default=True,
+              help="Select the sentence_transformer model to use")
+@click.option('--no-model-cache', flag_value=True,
+              help="Don't use local model cache")
+@click.option('--complete-ucs', flag_value=True, default=False,
+              help="Use all UCS categories. By default, all 'FOLEY' and "
+              "'ARCHIVED' UCS categories are excluded from all functions.")
+@click.pass_context
+def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
     """
     Tools for applying UCS categories to sounds using large-language Models
     """
-    pass
+    if verbose:
+        stream_handler.setLevel(logging.DEBUG)
+        logger.info("Verbose logging is enabled")
+    else:
+        import warnings
+        warnings.filterwarnings(
+            action='ignore', module='torch', category=FutureWarning,
+            message=r"`encoder_attention_mask` is deprecated.*")
+        stream_handler.setLevel(logging.WARNING)
+
+    ctx.ensure_object(dict)
+    ctx.obj['model_cache'] = not no_model_cache
+    ctx.obj['model_name'] = model
+    ctx.obj['complete_ucs'] = complete_ucs
+
+    if no_model_cache:
+        logger.info("Model cache inhibited by config")
+
+    if complete_ucs:
+        logger.info("Using complete UCS catgeory list")
+    else:
+        logger.info("Non-descriptive UCS categories will be excluded. Turn "
+                    "this option off by passing --complete-ucs.")
+
+    logger.info(f"Using model {model}")
 
 
 @ucsinfer.command('recommend')
 @click.option('--text', default=None,
               help="Recommend a category for given text instead of reading "
               "from a file")
-@click.argument('paths', nargs=-1)
-@click.option('--model', type=str, metavar="<model-name>",
-              default="paraphrase-multilingual-mpnet-base-v2",
-              show_default=True,
-              help="Select the sentence_transformer model to use")
-def recommend(text, paths, model):
+@click.argument('paths', nargs=-1, metavar='<paths>')
+@click.option('--interactive', '-i', flag_value=True, default=False,
+              help="After processing each path in <paths>, prompt for a "
+              "recommendation to accept, and then prepend the selection to "
+              "the file name.")
+@click.option('-s', '--skip-ucs', flag_value=True, default=False,
+              help="Skip files that already have a UCS category in their "
+              "name.")
+@click.pass_context
+def recommend(ctx, text, paths, interactive, skip_ucs):
     """
     Infer a UCS category for a text description
 
     "Description" text metadata is extracted from audio files given as PATHS,
     or text can be provided directly using the "--text" option. The selected
     model is then used to attempt to classify the given text according to
-    the synonyms and explanations definied for each UCS subcategory. A list
+    the synonyms an explanations definied for each UCS subcategory. A list
     of ranked subcategories is printed to the terminal for each PATH.
     """
-    m = SentenceTransformer(model)
-    ctx = InferenceContext(m, model)
+    logger.debug("RECOMMEND mode")
+    inference_ctx = InferenceContext(ctx.obj['model_name'],
+                                     use_cached_model=ctx.obj['model_cache'],
+                                     use_full_ucs=ctx.obj['complete_ucs'])
 
     if text is not None:
-        print_recommendation(None, text, ctx)
+        print_recommendation(None, text, inference_ctx,
+                             interactive_rename=False)
+
+    catlist = [x.catid for x in inference_ctx.catlist]
 
     for path in paths:
-        text = ffmpeg_description(path)
-        if text:
-            print_recommendation(path, text, ctx)
-        else:
-            filename = os.path.basename(path)
-            print_recommendation(path, filename, ctx)
+        _, ext = os.path.splitext(path)
+
+        if ext not in (".wav", ".flac"):
+            continue
+
+        basename = os.path.basename(path)
+        if skip_ucs and parse_ucs(basename, catlist):
+            continue
+
+        text = ffmpeg_description(path)
+        if not text:
+            text = os.path.basename(path)
+
+        while True:
+            retval = print_recommendation(path, text, inference_ctx, interactive)
+            if not retval:
+                break
+
+            if retval[0] is False:
+                return
+            elif retval[1] is not None:
+                text = retval[1]
+                continue
+            elif retval[2] is not None:
+                new_name = retval[2] + '_' + os.path.basename(path)
+                new_path = os.path.join(os.path.dirname(path), new_name)
+                print(f"Renaming {path} \n to {new_path}")
+                os.rename(path, new_path)
+                break
 
 
 @ucsinfer.command('gather')
-@click.option('--outfile', type=click.File(mode='w', encoding='utf8'),
-              default='dataset.csv', show_default=True)
+@click.option('--out', default='dataset/', show_default=True)
+@click.option('--ucs-data', flag_value=True, help="Create a dataset based "
              "on the UCS category explanations and synonymns (PATHS will "
              "be ignored.)")
 @click.argument('paths', nargs=-1)
-def gather(paths, outfile):
+@click.pass_context
+def gather(ctx, paths, out, ucs_data):
     """
-    Scan files to build a training dataset at PATH
+    Scan files to build a training dataset
 
-    The `gather` command walks the directory hierarchy for each path in PATHS
-    and looks for .wav and .flac files that are named according to the UCS
-    file naming guidelines, with at least a CatID and FX Name, divided by an
-    underscore.
+    The `gather` is used to build a training dataset for finetuning the
+    selected model. Description sentences and UCS categories are collected from
+    '.wav' and '.flac' files on-disk that have valid UCS filenames and assigned
+    CatIDs, and this information is recorded into a HuggingFace dataset.
 
-    For every file ucsinfer finds that meets this criteria, it creates a record
-    in an output dataset CSV file. The dataset file has two columns: the first
-    is the CatID indicated for the file, and the second is the embedded file
-    description for the file as returned by ffprobe.
+    Gather scans the filesystem in two passes: first, the directory tree is
+    walked by os.walk and a list of filenames that meet the above name criteria
+    is compiled. After this list is compiled, each file is scanned one-by-one
+    with ffprobe to obtain its "description" metadata; if this isn't present,
+    the parsed 'fxname' of te file becomes the description.
     """
-    types = ['.wav', '.flac']
-    table = csv.writer(outfile)
+    logger.debug("GATHER mode")
 
-    print(f"Loading category list...")
-    catid_list = [cat.catid for cat in load_ucs()]
-    scan_list: list[tuple[str,str]] = []
+    logger.debug(f"Loading category list...")
+    ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
+    catid_list = [cat.catid for cat in ucs]
+
+    if ucs_data:
+        logger.info('Creating dataset for UCS categories instead of from PATH')
+        paths = []
+
+    scan_list = []
 
     for path in paths:
-        print(f"Scanning directory {path}...", file=sys.stdout)
-        for dirpath, _, filenames in os.walk(path):
-            for filename in filenames:
-                root, ext = os.path.splitext(filename)
-                if ext in types and \
-                        (ucs_components := parse_ucs(root, catid_list)) and \
-                        not filename.startswith("._"):
-                    scan_list.append((ucs_components.cat_id,
-                                      os.path.join(dirpath, filename)))
+        scan_list += walk_path(path, catid_list)
 
-    print(f"Found {len(scan_list)} files to process.")
-    for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
-        if desc := ffmpeg_description(pair[1]):
-            table.writerow([pair[0], desc])
+    logger.info(f"Found {len(scan_list)} files to process.")
+    logger.info("Building dataset...")
+
+    dataset = build_sentence_class_dataset(
+        chain(scan_metadata(scan_list, catid_list),
+              ucs_definitions_generator(ucs)),
+        catid_list)
+
+    logger.info(f"Saving dataset to disk at {out}")
+    print_dataset_stats(dataset, catid_list)
+    dataset.save_to_disk(out)
 
 
 @ucsinfer.command('finetune')
-def finetune():
+@click.pass_context
+def finetune(ctx):
     """
     Fine-tune a model with training data
     """
-    pass
+    logger.debug("FINETUNE mode")
 
 
 @ucsinfer.command('evaluate')
-@click.option('--offset', type=int, default=0, metavar="<int>",
-              help='Skip this many records in the dataset before processing')
-@click.option('--limit', type=int, default=-1, metavar="<int>",
-              help='Process this many records and then exit')
-@click.option('--no-foley', 'no_foley', flag_value=True, default=False,
-              help="Ignore any data in the set with FOLYProp or FOLYFeet "
-              "category")
-@click.option('--model', type=str, metavar="<model-name>",
-              default="paraphrase-multilingual-mpnet-base-v2",
-              show_default=True,
-              help="Select the sentence_transformer model to use")
-@click.argument('dataset', type=click.File('r', encoding='utf8'),
-                default='dataset.csv')
-def evaluate(dataset, offset, limit, model, no_foley):
+@click.argument('dataset', default='dataset/')
+@click.pass_context
+def evaluate(ctx, dataset, offset, limit):
     """
-    Use datasets to evaluate model performance
-
-    The `evaluate` command reads the input DATASET file row by row and
-    performs a classifcation of the given description against the selected
-    model (either the default or using the --model option). The command then
-    checks if the model inferred the correct category as given by the dataset.
-
-    The model gives its top 10 possible categories for a given description,
-    and the results are tabulated according to (1) wether the top
-    classification was correct, (2) wether the correct classifcation was in the
-    top 5, or (3) wether it was in the top 10. The worst-performing category,
-    the one with the most misses, is also reported as well as the category
-    coverage, how many categories are present in the dataset.
-
-    NOTE: With experimentation it was found that foley items generally were
-    classified according to their subject and not wether or not they were
-    foley, and so these categories can be excluded with the --no-foley option.
     """
-    m = SentenceTransformer(model)
-    ctx = InferenceContext(m, model)
-    reader = csv.reader(dataset)
-    print(f"Evaluating model {model}...")
-
-    results = []
-
-    if offset > 0:
-        print(f"Skipping {offset} records...")
-
-    if limit > 0:
-        print(f"Will only evaluate {limit} records...")
-
-    progress_bar = tqdm.tqdm(total=limit)
-    for i, row in enumerate(reader):
-        if i < offset:
-            continue
-
-        if limit > 0 and i >= limit + offset:
-            break
-
-        cat_id, description = row
-        if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
-            continue
-
-        guesses = ctx.classify_text_ranked(description, limit=10)
-        if cat_id == guesses[0]:
-            results.append({'catid': cat_id, 'result': "TOP"})
-        elif cat_id in guesses[0:5]:
-            results.append({'catid': cat_id, 'result': "TOP_5"})
-        elif cat_id in guesses:
-            results.append({'catid': cat_id, 'result': "TOP_10"})
-        else:
-            results.append({'catid': cat_id, 'result': "MISS"})
-
-        progress_bar.update(1)
-
-    total = len(results)
-    total_top = len([x for x in results if x['result'] == 'TOP'])
-    total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
-    total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
-    cats = set([x['catid'] for x in results])
-    total_cats = len(cats)
-
-    miss_counts = []
-    for cat in cats:
-        miss_counts.append(
-            (cat, len([x for x in results
-                       if x['catid'] == cat and x['result'] == 'MISS'])))
-
-    miss_counts = sorted(miss_counts, key=lambda x: x[1])
-
-    print(f"## Results for Model {model} ##\n")
-    if no_foley:
-        print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
-
-    table = [
-        ["Total records in sample:", f"{total}"],
-        ["Top Result:", f"{total_top}",
-         f"{float(total_top)/float(total):.2%}"],
-        ["Top 5 Result:", f"{total_top_5}",
-         f"{float(total_top_5)/float(total):.2%}"],
-        ["Top 10 Result:", f"{total_top_10}",
-         f"{float(total_top_10)/float(total):.2%}"],
-        SEPARATING_LINE,
-        ["UCS category count:", f"{len(ctx.catlist)}"],
-        ["Total categories in sample:", f"{total_cats}",
-         f"{float(total_cats)/float(len(ctx.catlist)):.2%}"],
-        [f"Most missed category ({miss_counts[-1][0]}):",
-         f"{miss_counts[-1][1]}",
-         f"{float(miss_counts[-1][1])/float(total):.2%}"]
-    ]
-
-    print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
+    logger.debug("EVALUATE mode")
+    logger.warning("Model evaluation is not currently implemented")
 
 
 if __name__ == '__main__':
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-    import warnings
-    warnings.simplefilter(action='ignore', category=FutureWarning)
-    ucsinfer()
+    ucsinfer(obj={})
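
Editor's note: the structural change above is that model choice, caching, and UCS filtering moved from per-command options into group-level options shared through click's context object. A minimal, self-contained sketch of that pattern; the names `cli` and `show` are illustrative, not from the repo:

```python
# Illustrative sketch of the ctx.obj pattern used above, not repo code.
import click

@click.group()
@click.option('--verbose', '-v', flag_value=True)
@click.pass_context
def cli(ctx, verbose):
    ctx.ensure_object(dict)          # create ctx.obj if the caller didn't
    ctx.obj['verbose'] = verbose     # stash shared settings for subcommands

@cli.command()
@click.pass_context
def show(ctx):
    click.echo(f"verbose={ctx.obj['verbose']}")

if __name__ == '__main__':
    cli(obj={})
```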

ucsinfer/evaluate.py (new file, 7 lines)

@@ -0,0 +1,7 @@
# from sentence_transformers import SentenceTransformer
# from sentence_transformers.evaluation import BinaryClassificationEvaluator
# from datasets import load_dataset_from_disk, DatasetDict
#

ucsinfer/gather.py (new file, 118 lines)

@@ -0,0 +1,118 @@
from .inference import Ucs
from .util import ffmpeg_description, parse_ucs

from subprocess import CalledProcessError
import os.path

from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
from datasets.dataset_dict import DatasetDict
from typing import Iterator, Generator
from tabulate import tabulate

import logging
import tqdm


def walk_path(path: str, catid_list) -> list[tuple[str, str]]:
    types = ['.wav', '.flac']
    logger = logging.getLogger('ucsinfer')
    walker_p = tqdm.tqdm(total=None, unit='dir', desc="Walking filesystem...")
    scan_list = []
    for dirpath, _, filenames in os.walk(path):
        logger.info(f"Walking directory {dirpath}")
        for filename in filenames:
            walker_p.update()
            root, ext = os.path.splitext(filename)
            if ext not in types or filename.startswith("._"):
                continue
            if (ucs_components := parse_ucs(root, catid_list)):
                p = os.path.join(dirpath, filename)
                logger.info(f"Adding path to scan list {p}")
                scan_list.append((ucs_components.cat_id, p))
    return scan_list


def scan_metadata(scan_list: list[tuple[str, str]], catid_list: list[str]):
    logger = logging.getLogger('ucsinfer')
    for pair in tqdm.tqdm(scan_list, unit='files'):
        logger.info(f"Scanning file with ffprobe: {pair[1]}")
        try:
            desc = ffmpeg_description(pair[1])
        except CalledProcessError as e:
            logger.error(f"ffprobe returned error ({e.returncode}): "
                         f"{e.stderr}")
            continue

        if desc:
            yield desc, str(pair[0])
        else:
            comps = parse_ucs(os.path.basename(pair[1]), catid_list)
            assert comps
            yield comps.fx_name, str(pair[0])


def ucs_definitions_generator(ucs: list[Ucs]) \
        -> Generator[tuple[str, str], None, None]:
    for cat in ucs:
        yield cat.explanations, cat.catid
        yield ", ".join(cat.synonymns), cat.catid


def print_dataset_stats(dataset: DatasetDict, catlist: list[str]):
    data_table = []
    data_table.append(["Total records in combined dataset:", len(dataset)])
    data_table.append(["Total records in `train`:", len(dataset['train'])])
    tab = tabulate(data_table)
    print(tab)


# https://www.sbert.net/docs/sentence_transformer/loss_overview.html
def build_sentence_class_dataset(
        records: Iterator[tuple[str, str]],
        catlist: list[str]) -> DatasetDict:
    """
    Create a new dataset for `records` which contains (sentence, class) pairs.
    The dataset is split into train and test slices.

    :param records: a generator for records that generates pairs of
        (sentence, catid)
    :returns: A dataset with two columns: (sentence, hash(catid))
    """
    labels = ClassLabel(names=catlist)
    features = Features({'sentence': Value('string'),
                         'class': labels})
    info = DatasetInfo(
        description=f"(sentence, UCS CatID) pairs gathered by the "
        "ucsinfer tool on {}", features=features)

    items: list[dict] = []
    for obj in records:
        items += [{'sentence': obj[0], 'class': obj[1]}]

    whole = Dataset.from_list(items, features=features, info=info)
    split_set = whole.train_test_split(0.2)
    test_eval_set = split_set['test'].train_test_split(0.5)
    return DatasetDict({
        'train': split_set['train'],
        'test': test_eval_set['train'],
        'eval': test_eval_set['test']
    })


# def build_sentence_anchor_dataset() -> Dataset:
#     """
#     Create a new dataset for `records` which contains (sentence, anchor)
#     pairs.
#     """
#     pass
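
Editor's note: for reference, a short sketch of reading back what build_sentence_class_dataset writes, assuming the default --out path of 'dataset/'; load_from_disk and ClassLabel.int2str are standard datasets APIs:

```python
# A sketch assuming the default --out path: read the saved DatasetDict and
# map the integer class labels back to UCS CatIDs.
from datasets import load_from_disk

ds = load_from_disk("dataset/")              # DatasetDict: train / test / eval
row = ds["train"][0]                         # {'sentence': ..., 'class': int}
class_label = ds["train"].features["class"]  # the ClassLabel built from catlist
print(row["sentence"], class_label.int2str(row["class"]))
```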

ucsinfer/inference.py

@@ -34,7 +34,7 @@ class Ucs(NamedTuple):
                    explanations=d['Explanations'], synonymns=d['Synonyms'])
 
-def load_ucs() -> list[Ucs]:
+def load_ucs(full_ucs: bool = True) -> list[Ucs]:
     FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
     cats = []
     ucs_defs = os.path.join(FILE_ROOT_DIR, 'ucs-community', 'json',
@@ -43,7 +43,13 @@ def load_ucs() -> list[Ucs]:
     with open(ucs_defs, 'r') as f:
         cats = json.load(f)
 
-    return [Ucs.from_dict(cat) for cat in cats]
+    ucs = [Ucs.from_dict(cat) for cat in cats]
+    if full_ucs:
+        return ucs
+    else:
+        return [cat for cat in ucs if \
+                cat.category not in ['FOLEY', 'ARCHIVED']]
 
 class InferenceContext:
@@ -54,39 +60,67 @@ class InferenceContext:
     model: SentenceTransformer
     model_name: str
 
-    def __init__(self, model: SentenceTransformer, model_name: str):
-        self.model = model
+    def __init__(self, model_name: str, use_cached_model: bool = True,
+                 use_full_ucs: bool = False):
         self.model_name = model_name
+        self.use_full_ucs = use_full_ucs
+        cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
+                                                'Squad 51')
+        model_cache_path = os.path.join(cache_dir,
+                                        f"{self.model_name}.cache")
+        if use_cached_model:
+            if os.path.exists(model_cache_path):
+                self.model = SentenceTransformer(model_cache_path)
+            else:
+                self.model = SentenceTransformer(model_name)
+                self.model.save(model_cache_path)
+        else:
+            self.model = SentenceTransformer(model_name)
 
     @cached_property
     def catlist(self) -> list[Ucs]:
-        return load_ucs()
+        return load_ucs(full_ucs=self.use_full_ucs)
 
     @cached_property
     def embeddings(self) -> list[dict]:
-        cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", 'Squad 51')
-        embedding_cache = os.path.join(cache_dir,
-                                       f"{self.model_name}-ucs_embedding.cache")
+        cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
                                                'Squad 51')
+        embedding_cache_path = os.path.join(
+            cache_dir,
+            f"{self.model_name}-ucs_embedding.cache")
 
         embeddings = []
-        if os.path.exists(embedding_cache):
-            with open(embedding_cache, 'rb') as f:
+        if os.path.exists(embedding_cache_path):
+            with open(embedding_cache_path, 'rb') as f:
                 embeddings = pickle.load(f)
         else:
             print(f"Calculating embeddings for model {self.model_name}...")
-            for cat_defn in self.catlist:
+            # we need to calculate the embeddings for all cats, not just the
+            # ones we're loading for this run
+            full_catlist = load_ucs(full_ucs=True)
+            for cat_defn in full_catlist:
                 embeddings += [{
                     'CatID': cat_defn.catid,
                     'Embedding': self._encode_category(cat_defn)
                 }]
-            os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
-            with open(embedding_cache, 'wb') as g:
+            os.makedirs(os.path.dirname(embedding_cache_path), exist_ok=True)
+            with open(embedding_cache_path, 'wb') as g:
                 pickle.dump(embeddings, g)
 
-        return embeddings
+        whitelisted_cats = [cat.catid for cat in self.catlist]
+        return [e for e in embeddings if e['CatID'] in whitelisted_cats]
 
     def _encode_category(self, cat: Ucs) -> np.ndarray:
         sentence_components = [cat.explanations,
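
Editor's note: classify_text_ranked is called throughout this compare but its body is not in the diff. The sketch below is one plausible implementation over the cached per-category embeddings; util.cos_sim is the real sentence-transformers API, but treat the method logic as a guess, not the repo's code:

```python
# A guess at what classify_text_ranked could look like, not repo code.
from sentence_transformers import util

def classify_text_ranked_sketch(ctx, text: str, limit: int = 10) -> list[str]:
    query = ctx.model.encode(text)                       # embed the description
    scored = [(e['CatID'], float(util.cos_sim(query, e['Embedding'])))
              for e in ctx.embeddings]                   # cached CatID vectors
    scored.sort(key=lambda pair: pair[1], reverse=True)  # best match first
    return [catid for catid, _ in scored[:limit]]
```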

ucsinfer/recommend.py (new file, 82 lines)

@@ -0,0 +1,82 @@
# recommend.py

from re import match

from .inference import InferenceContext

from tabulate import tabulate


def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
                         interactive_rename: bool, recommend_limit=10):
    """
    Print recommendations interactively.

    :returns: If interactive_rename is false or path is None, returns None.
        If interactive, returns a tuple of bool, str | None, str | None where:
        - if retval[0] is False, the user has requested processing quit.
        - if retval[1] is a str, this is the text the user has entered to
          perform a new inference instead of the file metadata.
          `print_recommendation` should be called again with this argument.
        - if retval[2] is a str, this is the catid the user has selected.
    """
    recs = ctx.classify_text_ranked(text, limit=recommend_limit)
    print("----------")
    if path:
        print(f"Path: {path}")
    print(f"Text: {text or '<None>'}")
    for i, r in enumerate(recs):
        cat, subcat, _ = ctx.lookup_category(r)
        print(f"- {i}: {r} ({cat}-{subcat})")

    if interactive_rename and path is not None:
        response = input("(n#), t, c, b, ?, q > ")
        if m := match(r'^([0-9]+)', response):
            selection = int(m.group(1))
            if 0 <= selection < len(recs):
                return True, None, recs[selection]
            else:
                print(f"Invalid index {selection}")
                return True, text, None
        elif m := match(r'^t (.*)', response):
            print("searching for new matches")
            text = m.group(1)
            return True, text, None
        elif m := match(r'^c (.+)', response):
            return True, None, m.group(1)
        elif m := match(r'^b (.+)', response):
            expt = []
            for cat in ctx.catlist:
                if cat.catid.startswith(m.group(1)):
                    expt.append([f"{cat.catid}: ({cat.category}-{cat.subcategory})",
                                 cat.explanations])
            print(tabulate(expt, maxcolwidths=80))
            return True, text, None
        elif response.startswith("?"):
            print("""
Choices:
- Enter recommendation number to rename file,
- "t <text>" to search for new recommendations based on <text>
- "p" re-use the last selected cat-id
- "c <cat>" to type in a category by hand
- "b <cat>" browse category list for categories starting with <cat>
- "?" for this message
- "q" to quit
- or any other key to skip this file and continue to next file
""")
            return True, text, None
        elif response.startswith('q'):
            return (False, None, None)
        else:
            print()
    else:
        return None

ucsinfer/util.py

@@ -1,5 +1,6 @@
 import subprocess
 import json
+import os
 
 from typing import NamedTuple, Optional
 from re import match
@@ -8,10 +9,7 @@ def ffmpeg_description(path: str) -> Optional[str]:
     result = subprocess.run(['ffprobe', '-show_format', '-of',
                              'json', path], capture_output=True)
-    try:
-        result.check_returncode()
-    except:
-        return None
+    result.check_returncode()
 
     stream = json.loads(result.stdout)
     fmt = stream.get("format", None)
@@ -59,6 +57,20 @@ class UcsNameComponents(NamedTuple):
         return False
 
+def normalize_ucs(basename: str, catid_list: list[str]):
+    """
+    Take any filename and normalize it into the UCS system
+    """
+    n, ext = os.path.splitext(basename)
+    r = parse_ucs(n, catid_list)
+    if r:
+        pass
+    else:
+        pass
+
+    return f"aaa.{ext}"
 
 def build_ucs(components: UcsNameComponents, extension: str) -> str:
     """
     Build a UCS filename
@@ -66,7 +78,29 @@ def build_ucs(components: UcsNameComponents, extension: str) -> str:
     assert components.validate(), \
         "UcsNameComponents contains invalid characters"
 
-    return ""
+    cat_segment = components.cat_id
+    if components.user_cat:
+        cat_segment += f"-{components.user_cat}"
+
+    name_segment = components.fx_name
+    if components.vendor_cat:
+        name_segment = f"{components.vendor_cat}-{components.fx_name}"
+
+    all_comps = [cat_segment, name_segment]
+
+    if components.creator:
+        all_comps += [components.creator]
+
+    if components.source:
+        all_comps += [components.source]
+
+    if components.user_data:
+        all_comps += [components.user_data]
+
+    root_name = "_".join(all_comps)
+    return root_name + '.' + extension
 
 def parse_ucs(rootname: str,
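
Editor's note: the diff cuts off at parse_ucs's signature. A hedged worked example for the completed build_ucs above; the component values are invented and assume UcsNameComponents has exactly the fields referenced in the function, with a validate() that accepts these values:

```python
# Invented values, illustrating the name assembly implemented by build_ucs:
# CatID, then FXName, then optional CreatorID / SourceID / UserData, joined
# by underscores.
name = build_ucs(
    UcsNameComponents(cat_id='AMBAir', user_cat=None, vendor_cat=None,
                      fx_name='RoomTone', creator='JH', source='Zoom',
                      user_data=None),
    'wav')
assert name == 'AMBAir_RoomTone_JH_Zoom.wav'
```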