Compare commits

...

53 Commits

Author SHA1 Message Date
615d8ab279 Refactoring of gather 2025-09-10 23:49:33 -07:00
9a887c4ed5 Refactoring of gather 2025-09-10 23:38:33 -07:00
63b140209b Command-line entry point:wq 2025-09-10 22:23:57 -07:00
d181ac73b1 Made prompt shorter 2025-09-10 22:08:16 -07:00
bddce23c76 Added to prompt help 2025-09-10 22:07:53 -07:00
b4758dd138 Some features for recommend, a browse feature 2025-09-06 14:10:53 -07:00
04332b73ee Fixed a bug in cat masking 2025-09-04 10:54:35 -07:00
103fffe0a4 Added another function to recommend 2025-09-04 10:43:31 -07:00
10519f9c1a Split training sets from gather 2025-09-03 19:40:55 -07:00
e419f698c9 Dataset metadata 2025-09-03 16:31:05 -07:00
6cd0415a26 README notes 2025-09-03 16:25:05 -07:00
b833d6d3c0 Added better error reporting to gather
And a bunch of options to control behavior
2025-09-03 16:07:41 -07:00
fee55a7d5a Twiddles 2025-09-03 14:56:41 -07:00
0594899bdd dataset writing in HF format 2025-09-03 14:52:10 -07:00
46a693bf93 tweaks 2025-09-03 14:23:09 -07:00
0a2aaa2a22 refining gather, adding TODOs 2025-09-03 14:13:12 -07:00
d855ba4c78 removed print message 2025-09-03 12:25:33 -07:00
336d6f013e removed print message 2025-09-03 12:24:31 -07:00
2a0b0367f8 Improved logging 2025-09-03 12:23:44 -07:00
184edcb7e4 Removed unnecessary parameter 2025-09-03 11:42:10 -07:00
d330f47462 Added some logging code 2025-09-03 11:36:38 -07:00
2fcdc24699 Made the warnings filter more specific 2025-09-03 11:16:51 -07:00
84eb046a38 Update README.md 2025-09-01 20:53:28 +00:00
1aa7260981 Update README.md 2025-09-01 20:53:07 +00:00
Jamie Hardt 33619d2ae2 Implemented '--skip-ucs' option for 'recommend' 2025-08-30 23:45:36 -07:00
Jamie Hardt 1a42d7d9f1 Refactored recommendation code into a new file 2025-08-30 23:29:45 -07:00
Jamie Hardt 010a93cf2b Displaying rename status better 2025-08-30 23:19:32 -07:00
Jamie Hardt 906656a8c9 Updated TODO 2025-08-30 22:33:02 -07:00
Jamie Hardt ea8bd0e09b Rename implementaion 2025-08-30 22:16:58 -07:00
Jamie Hardt 7c591e9dbb Merge branch 'master' of https://git.squad51.us/jamie/ucsinfer 2025-08-30 20:36:47 -07:00
Jamie Hardt 3009d3831e Elborated rename function to recommend 2025-08-30 20:36:36 -07:00
Jamie Hardt 47829c5427 Added rename function to recommend 2025-08-30 20:14:09 -07:00
0cb2f25568 Update TODO.md 2025-08-28 00:34:04 +00:00
009cf98039 Update README.md
Updated logline
2025-08-28 00:28:00 +00:00
4570ed632a Update MODELS.md
Another run
2025-08-28 00:26:45 +00:00
62654464f3 Update pyproject.toml
Added description
2025-08-28 00:24:03 +00:00
7ef9d52135 Add TODO.md
Added a TODO list
2025-08-28 00:21:34 +00:00
Jamie Hardt b50d8a6a06 Fixed readme 2025-08-27 15:51:26 -07:00
Jamie Hardt a25c3c857f spelling 2025-08-27 15:48:18 -07:00
Jamie Hardt 4a5397f153 Refinement 2025-08-27 15:47:38 -07:00
Jamie Hardt 65bdf3256e Added to readme 2025-08-27 15:43:19 -07:00
Jamie Hardt efd9a08212 Refine print out of recs 2025-08-27 15:41:05 -07:00
Jamie Hardt 739d27fe71 Implementation of recommend 2025-08-27 15:37:53 -07:00
Jamie Hardt de7ad3d65a Implemented recommend 2025-08-27 15:27:30 -07:00
Jamie Hardt 4f7b2a73cb Merge branch 'master' of https://git.squad51.us/jamie/ucsinfer 2025-08-27 15:05:26 -07:00
354d1d4e40 Update README.md 2025-08-27 22:05:17 +00:00
Jamie Hardt cc78291f1d Implementation 2025-08-27 15:03:06 -07:00
Jamie Hardt 7b3930ece0 Added more helpful prompts and help 2025-08-27 14:04:46 -07:00
Jamie Hardt 2d233bd1f2 Eliminated redundant usage examples 2025-08-27 13:52:02 -07:00
Jamie Hardt 3846809918 Added some more help text for function 2025-08-27 13:50:54 -07:00
Jamie Hardt d13f5acbab Enhanced readme 2025-08-27 13:34:46 -07:00
Jamie Hardt b0a3fcc748 Merge branch 'master' of https://git.squad51.us/jamie/ucsinfer 2025-08-27 13:22:30 -07:00
Jamie Hardt 2b738f561b Cleaned-up some formatting 2025-08-26 18:09:30 -07:00
10 changed files with 534 additions and 146 deletions

MODELS.md

@@ -30,3 +30,19 @@
| UCS category count: | 752 | |
| Total categories in sample: | 238 | 31.65% |
| Most missed category (VEHCar): | 75 | 3.21% |
## Evaluating model paraphrase-multilingual-mpnet-base-v2...
(FOLYProp and FOLYFeet have been omitted from the dataset.)
| | n | pct |
|---------------------------------|------|--------|
| Total records in sample: | 5800 | |
| Top Result: | 681 | 11.74% |
| Top 5 Result: | 1025 | 17.67% |
| Top 10 Result: | 802 | 13.83% |
| |
| UCS category count: | 752 | |
| Total categories in sample: | 298 | 39.63% |
| Most missed category (MAGElem): | 481 | 8.29% |

README.md

@@ -1,6 +1,6 @@
# ucsinfer
Universal Category System LLM toolkit.
Tools for applying UCS categories to sounds using large-language models
## Install
@@ -9,9 +9,13 @@ packaged on PyPI. You should clone the project to your local machine and
do an [editable install](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs)
in a [virtual environment](https://docs.python.org/3/library/venv.html).
Note: You will also need ffmpeg and ffprobe in order to interrogate audio
files for their metadata.
```sh
$ brew install ffmpeg
$ git clone https://git.squad51.us/jamie/ucsinfer.git
$ git submodule sync
$ git submodule update --init
$ python -m venv .venv
$ source .venv/bin/activate # or whatever command is appropriate for your shell
$ pip install -e .
@@ -33,25 +37,29 @@ python -m ucsinfer [command]
```
Pass `--help` to see a summary of subcommands and options.
The subcommands available at this time are `gather` and `evaluate`.
## Functions
* ~recommend~ (in-progress)
* recommend
Infer a UCS category for a text description.
Infer a UCS category for a text description. Text metadata is extracted from
audio files and the language model can recommend a corresponding list of
appropriate categories, ranked by their alignment with the category
definition (see the sketch after this list).
* gather
Scan files to capture existing text descriptions and UCS categories
and save as a dataset. This function is used to construct datasets
that `evaluate` can use to test models and finetune can use to
refine them.
and save as a dataset.
* ~finetune~ (planned)
Fine-tune an existing sentence embedding model with training data.
* evaluate
* ~evaluate~ (FIXME phase)
Use datasets to evaluate the performance of a model and fine-tuning.
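As a rough illustration, the ranking behind `recommend` amounts to embedding
the query text and each category definition, then sorting by cosine
similarity. A minimal sketch (only the model name comes from this project;
the category texts and query are invented stand-ins):

```python
# Illustrative sketch of ranking categories by alignment with their definitions.
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")

# ucsinfer builds these from the UCS category list; these two are hypothetical.
definitions = {
    "VEHCar": "Automobiles, passenger cars, engines, driving",
    "DOORWood": "Wooden doors opening, closing, slamming",
}

query = model.encode("heavy oak door slams shut")
ranked = sorted(
    definitions,
    key=lambda cat: float(util.cos_sim(query, model.encode(definitions[cat]))),
    reverse=True,
)
print(ranked)  # best-aligned category first
```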
# Demos and Articles
* [Category Inference Experiments With UCSINFER](https://squad51.us/notebook/category_inference_experiments_ucsinfer/)
* [UCSINFER for Renaming Sounds](https://squad51.us/notebook/ucsinfer_to_rename_sounds/)

TODO.md (new file, 28 lines)

@@ -0,0 +1,28 @@
## Recommend
- Use History when adding catids
## Gather
- Maybe more dataset configurations
## Fine-tune
- Implement BatchAllTripletLoss
## Evaluate
- Print more information about the dataset coverage of UCS
- Allow skipping model testing for this
- Print raw output
<!-- - Maybe load everything into a sqlite for slicker reporting -->
## Utility
- Clear caches

pyproject.toml

@@ -1,7 +1,7 @@
[project]
name = "ucsinfer"
version = "0.1.0"
description = ""
description = "Tools for applying UCS categories to sounds using large-language models"
authors = [
{name = "Jamie Hardt",email = "jamiehardt@me.com"}
]
@@ -13,7 +13,8 @@ dependencies = [
"tqdm (>=4.67.1,<5.0.0)",
"platformdirs (>=4.3.8,<5.0.0)",
"click (>=8.2.1,<9.0.0)",
"tabulate (>=0.9.0,<0.10.0)"
"tabulate (>=0.9.0,<0.10.0)",
"datasets (>=4.0.0,<5.0.0)"
]
@@ -25,3 +26,6 @@ build-backend = "poetry.core.masonry.api"
ipython = "^9.4.0"
jupyter = "^1.1.1"
[tool.poetry.scripts]
ucsinfer = "ucsinfer.__main__:ucsinfer"

ucsinfer/__main__.py

@@ -1,155 +1,212 @@
import os
import sys
import csv
# import csv
import logging
from itertools import chain
from sentence_transformers import SentenceTransformer
import tqdm
import click
from tabulate import tabulate, SEPARATING_LINE
# from tabulate import tabulate, SEPARATING_LINE
from .inference import InferenceContext, load_ucs
from .gather import (build_sentence_class_dataset, print_dataset_stats,
ucs_definitions_generator, scan_metadata, walk_path)
from .recommend import print_recommendation
from .util import ffmpeg_description, parse_ucs
logger = logging.getLogger('ucsinfer')
logger.setLevel(logging.DEBUG)
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.WARN)
formatter = logging.Formatter(
'%(asctime)s. %(levelname)s %(name)s: %(message)s')
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
@click.group()
# @click.option('--verbose', flag_value='verbose', help='Verbose output')
def ucsinfer():
pass
@click.group(epilog="For more information see "
"<https://git.squad51.us/jamie/ucsinfer>")
@click.option('--verbose', '-v', flag_value=True, help='Verbose output')
@click.option('--model', type=str, metavar="<model-name>",
default="paraphrase-multilingual-mpnet-base-v2",
show_default=True,
help="Select the sentence_transformer model to use")
@click.option('--no-model-cache', flag_value=True,
help="Don't use local model cache")
@click.option('--complete-ucs', flag_value=True, default=False,
help="Use all UCS categories. By default, all 'FOLEY' and "
"'ARCHIVED' UCS categories are excluded from all functions.")
@click.pass_context
def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
"""
Tools for applying UCS categories to sounds using large-language models
"""
if verbose:
stream_handler.setLevel(logging.DEBUG)
logger.info("Verbose logging is enabled")
else:
import warnings
warnings.filterwarnings(
action='ignore', module='torch', category=FutureWarning,
message=r"`encoder_attention_mask` is deprecated.*")
stream_handler.setLevel(logging.WARNING)
ctx.ensure_object(dict)
ctx.obj['model_cache'] = not no_model_cache
ctx.obj['model_name'] = model
ctx.obj['complete_ucs'] = complete_ucs
if no_model_cache:
logger.info("Model cache inhibited by config")
if complete_ucs:
logger.info("Using complete UCS catgeory list")
else:
logger.info("Non-descriptive UCS categories will be excluded. Turn "
"this option off by passing --complete-ucs.")
logger.info(f"Using model {model}")
@ucsinfer.command('recommend')
def recommend():
@click.option('--text', default=None,
help="Recommend a category for given text instead of reading "
"from a file")
@click.argument('paths', nargs=-1, metavar='<paths>')
@click.option('--interactive','-i', flag_value=True, default=False,
help="After processing each path in <paths>, prompt for a "
"recommendation to accept, and then prepend the selection to "
"the file name.")
@click.option('-s', '--skip-ucs', flag_value=True, default=False,
help="Skip files that already have a UCS category in their "
"name.")
@click.pass_context
def recommend(ctx, text, paths, interactive, skip_ucs):
"""
Infer a UCS category for a text description
"""
pass
"Description" text metadata is extracted from audio files given as PATHS,
or text can be provided directly using the "--text" option. The selected
model is then used to attempt to classify the given text according to
the synonyms and explanations defined for each UCS subcategory. A list
of ranked subcategories is printed to the terminal for each PATH.
"""
logger.debug("RECOMMEND mode")
inference_ctx = InferenceContext(ctx.obj['model_name'],
use_cached_model=ctx.obj['model_cache'],
use_full_ucs=ctx.obj['complete_ucs'])
if text is not None:
print_recommendation(None, text, inference_ctx,
interactive_rename=False)
catlist = [x.catid for x in inference_ctx.catlist]
for path in paths:
_, ext = os.path.splitext(path)
if ext not in (".wav", ".flac"):
continue
basename = os.path.basename(path)
if skip_ucs and parse_ucs(basename, catlist):
continue
text = ffmpeg_description(path)
if not text:
text = os.path.basename(path)
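# print_recommendation returns None to skip this file, or a tuple
# (keep_going, new_text, chosen_catid): False in slot 0 quits entirely,
# text in slot 1 re-runs the inference, a CatID in slot 2 triggers a rename.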
while True:
retval = print_recommendation(path, text, inference_ctx, interactive)
if not retval:
break
if retval[0] is False:
return
elif retval[1] is not None:
text = retval[1]
continue
elif retval[2] is not None:
new_name = retval[2] + '_' + os.path.basename(path)
new_path = os.path.join(os.path.dirname(path), new_name)
print(f"Renaming {path} \n to {new_path}")
os.rename(path, new_path)
break
@ucsinfer.command('gather')
@click.option('--outfile', type=click.File(mode='w', encoding='utf8'),
default='dataset.csv', show_default=True)
@click.option('--out', default='dataset/', show_default=True)
@click.option('--ucs-data', flag_value=True, help="Create a dataset based "
"on the UCS category explanations and synonymns (PATHS will "
"be ignored.)")
@click.argument('paths', nargs=-1)
def gather(paths, outfile):
@click.pass_context
def gather(ctx, paths, out, ucs_data):
"""
Scan files to build a training dataset at PATH
Scan files to build a training dataset
The `gather` command is used to build a training dataset for finetuning the
selected model. Description sentences and UCS categories are collected from
'.wav' and '.flac' files on-disk that have valid UCS filenames and assigned
CatIDs, and this information is recorded into a HuggingFace dataset.
Gather scans the filesystem in two passes: first, the directory tree is
walked by os.walk and a list of filenames that meet the above name criteria
is compiled; then each file is scanned one-by-one
with ffprobe to obtain its "description" metadata; if this isn't present,
the parsed 'fxname' of the file becomes the description.
"""
types = ['.wav', '.flac']
table = csv.writer(outfile)
print(f"Loading category list...")
catid_list = [cat.catid for cat in load_ucs()]
logger.debug("GATHER mode")
logger.debug(f"Loading category list...")
ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
scan_list: list[tuple[str,str]] = []
catid_list = [cat.catid for cat in ucs]
if ucs_data:
logger.info('Creating dataset for UCS categories instead of from PATH')
paths = []
scan_list = []
for path in paths:
print(f"Scanning directory {path}...", file=sys.stdout)
for dirpath, _, filenames in os.walk(path):
for filename in filenames:
root, ext = os.path.splitext(filename)
if ext in types and \
(ucs_components := parse_ucs(root, catid_list)) and \
not filename.startswith("._"):
scan_list.append((ucs_components.cat_id,
os.path.join(dirpath, filename)))
scan_list += walk_path(path, catid_list)
print(f"Found {len(scan_list)} files to process.")
logger.info(f"Found {len(scan_list)} files to process.")
for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
if desc := ffmpeg_description(pair[1]):
table.writerow([pair[0], desc])
logger.info("Building dataset...")
dataset = build_sentence_class_dataset(
chain(scan_metadata(scan_list, catid_list),
ucs_definitions_generator(ucs)),
catid_list)
logger.info(f"Saving dataset to disk at {out}")
print_dataset_stats(dataset, catid_list)
dataset.save_to_disk(out)
@ucsinfer.command('finetune')
def finetune():
@click.pass_context
def finetune(ctx):
"""
Fine-tune a model with training data
"""
pass
logger.debug("FINETUNE mode")
@ucsinfer.command('evaluate')
@click.option('--offset', type=int, default=0)
@click.option('--limit', type=int, default=-1)
@click.option('--no-foley', 'no_foley', flag_value=True, default=False,
help="Ignore any data in the set with FOLYProp or FOLYFeet "
"category")
@click.option('--model', type=str,
default="paraphrase-multilingual-mpnet-base-v2")
@click.argument('dataset', type=click.File('r', encoding='utf8'),
default='dataset.csv')
def evaluate(dataset, offset, limit, model, no_foley):
@click.argument('dataset', default='dataset/')
@click.pass_context
def evaluate(ctx, dataset, offset, limit):
"""
Use datasets to evaluate model performance
"""
m = SentenceTransformer(model)
ctx = InferenceContext(m, model)
reader = csv.reader(dataset)
results = []
for i, row in enumerate(tqdm.tqdm(reader)):
if i < offset:
continue
if limit > 0 and i >= limit + offset:
break
cat_id, description = row
if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
continue
guesses = ctx.classify_text_ranked(description, limit=10)
if cat_id == guesses[0]:
results.append({'catid': cat_id, 'result': "TOP"})
elif cat_id in guesses[0:5]:
results.append({'catid': cat_id, 'result': "TOP_5"})
elif cat_id in guesses:
results.append({'catid': cat_id, 'result': "TOP_10"})
else:
results.append({'catid': cat_id, 'result': "MISS"})
total = len(results)
total_top = len([x for x in results if x['result'] == 'TOP'])
total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
cats = set([x['catid'] for x in results])
total_cats = len(cats)
miss_counts = []
for cat in cats:
miss_counts.append(
(cat, len([x for x in results
if x['catid'] == cat and x['result'] == 'MISS'])))
miss_counts = sorted(miss_counts, key=lambda x: x[1])
print(f"## Results for Model {model} ##")
if no_foley:
print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
table = [
["Total records in sample:", f"{total}"],
["Top Result:", f"{total_top}",
f"{float(total_top)/float(total):.2%}"],
["Top 5 Result:", f"{total_top_5}",
f"{float(total_top_5)/float(total):.2%}"],
["Top 10 Result:", f"{total_top_10}",
f"{float(total_top_10)/float(total):.2%}"],
SEPARATING_LINE,
["UCS category count:", f"{len(ctx.catlist)}"],
["Total categories in sample:", f"{total_cats}",
f"{float(total_cats)/float(len(ctx.catlist)):.2%}"],
[f"Most missed category ({miss_counts[-1][0]}):",
f"{miss_counts[-1][1]}",
f"{float(miss_counts[-1][1])/float(total):.2%}"]
]
print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
logger.debug("EVALUATE mode")
logger.warning("Model evaluation is not currently implemented")
if __name__ == '__main__':
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
ucsinfer()
ucsinfer(obj={})
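Because the group parses global options and stashes them in `ctx.obj` before
any subcommand runs, the wiring can be exercised in-process with click's test
runner. A sketch (the query text is arbitrary, and the model is downloaded on
first use):

```python
# Sketch: driving the CLI through click's test runner instead of a shell.
from click.testing import CliRunner
from ucsinfer.__main__ import ucsinfer

runner = CliRunner()
# --text bypasses file scanning; the group options are parsed first.
result = runner.invoke(ucsinfer, ["--verbose", "recommend", "--text", "dog barking"])
print(result.exit_code, result.output)
```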

ucsinfer/evaluate.py (new file, 7 lines)

@@ -0,0 +1,7 @@
# from sentence_transformers import SentenceTransformer
# from sentence_transformers.evaluation import BinaryClassificationEvaluator
# from datasets import load_from_disk, DatasetDict
#
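The CSV-driven evaluation loop removed from `__main__.py` above hints at what
this module will eventually hold. A hypothetical sketch against the new
on-disk dataset, reusing names from this diff (`InferenceContext.classify_text_ranked`
and the `sentence`/`class` columns written by `gather`):

```python
# Hypothetical sketch only: top-k accuracy over the 'eval' split,
# mirroring the removed CSV-based loop.
from datasets import load_from_disk
from ucsinfer.inference import InferenceContext

ctx = InferenceContext("paraphrase-multilingual-mpnet-base-v2")
ds = load_from_disk("dataset/")["eval"]
labels = ds.features["class"]  # ClassLabel mapping int <-> CatID

top = top5 = top10 = 0
for row in ds:
    catid = labels.int2str(row["class"])
    guesses = ctx.classify_text_ranked(row["sentence"], limit=10)
    if catid == guesses[0]:
        top += 1
    elif catid in guesses[:5]:
        top5 += 1
    elif catid in guesses:
        top10 += 1

n = len(ds)
print(f"top: {top/n:.2%}  top-5: {top5/n:.2%}  top-10: {top10/n:.2%}")
```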

ucsinfer/gather.py (new file, 118 lines)

@@ -0,0 +1,118 @@
from .inference import Ucs
from .util import ffmpeg_description, parse_ucs
from subprocess import CalledProcessError
import os.path
from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
from datasets.dataset_dict import DatasetDict
from typing import Iterator, Generator
from tabulate import tabulate
import logging
import tqdm
def walk_path(path:str, catid_list) -> list[tuple[str,str]]:
types = ['.wav', '.flac']
logger = logging.getLogger('ucsinfer')
walker_p = tqdm.tqdm(total=None, unit='dir', desc="Walking filesystem...")
scan_list = []
for dirpath, _, filenames in os.walk(path):
logger.info(f"Walking directory {dirpath}")
for filename in filenames:
walker_p.update()
root, ext = os.path.splitext(filename)
if ext not in types or filename.startswith("._"):
continue
if (ucs_components := parse_ucs(root, catid_list)):
p = os.path.join(dirpath, filename)
logger.info(f"Adding path to scan list {p}")
scan_list.append((ucs_components.cat_id, p))
return scan_list
def scan_metadata(scan_list: list[tuple[str,str]], catid_list: list[str]):
logger = logging.getLogger('ucsinfer')
for pair in tqdm.tqdm(scan_list, unit='files'):
logger.info(f"Scanning file with ffprobe: {pair[1]}")
try:
desc = ffmpeg_description(pair[1])
except CalledProcessError as e:
logger.error(f"ffprobe returned error (){e.returncode}): " \
+ e.stderr)
continue
if desc:
yield desc, str(pair[0])
else:
comps = parse_ucs(os.path.basename(pair[1]), catid_list)
assert comps
yield comps.fx_name, str(pair[0])
def ucs_definitions_generator(ucs: list[Ucs]) \
-> Generator[tuple[str,str],None, None]:
for cat in ucs:
yield cat.explanations, cat.catid
yield ", ".join(cat.synonymns), cat.catid
def print_dataset_stats(dataset: DatasetDict, catlist: list[str]):
data_table = []
data_table.append([["Total records in combined dataset:", len(dataset)]])
data_table.append([["Total records in `train`:", len(dataset['train'])]])
tab = tabulate(data_table)
print(tab)
# https://www.sbert.net/docs/sentence_transformer/loss_overview.html
def build_sentence_class_dataset(
records: Iterator[tuple[str, str]],
catlist: list[str]) -> DatasetDict:
"""
Create a new dataset for `records` which contains (sentence, class) pairs.
The dataset is split into train, test, and eval slices.
:param records: a generator for records that generates pairs of
(sentence, catid)
:returns: A dataset with two columns: (sentence, class), where class is
the ClassLabel index of the catid
"""
labels = ClassLabel(names=catlist)
features = Features({'sentence': Value('string'),
'class': labels})
info = DatasetInfo(
description=f"(sentence, UCS CatID) pairs gathered by the "
"ucsinfer tool on {}", features= features)
items: list[dict] = []
for obj in records:
items += [{'sentence': obj[0], 'class': obj[1]}]
whole = Dataset.from_list(items, features=features, info=info)
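# Hold out 20%, then split that half-and-half: a net 80/10/10
# train/test/eval split.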
split_set = whole.train_test_split(0.2)
test_eval_set = split_set['test'].train_test_split(0.5)
return DatasetDict({
'train': split_set['train'],
'test': test_eval_set['train'],
'eval': test_eval_set['test']
})
# def build_sentence_anchor_dataset() -> Dataset:
# """
# Create a new dataset for `records` which contains (sentence, anchor) pairs.
# """
# pass
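TODO.md calls for `BatchAllTripletLoss`; a finetuning pass over this dataset
might look like the following hypothetical sketch, assuming the
sentence-transformers v3 trainer API (the rename is needed because the
Batch*TripletLoss losses expect an integer `label` column):

```python
# Hypothetical sketch only (see TODO.md); not part of this change.
from datasets import load_from_disk
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
from sentence_transformers.losses import BatchAllTripletLoss

model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
train = load_from_disk("dataset/")["train"].rename_column("class", "label")

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train,
    loss=BatchAllTripletLoss(model),  # needs several examples per class per batch
)
trainer.train()
model.save("ucsinfer-finetuned")
```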

ucsinfer/inference.py

@@ -34,7 +34,7 @@ class Ucs(NamedTuple):
explanations=d['Explanations'], synonymns=d['Synonyms'])
def load_ucs() -> list[Ucs]:
def load_ucs(full_ucs: bool = True) -> list[Ucs]:
FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
cats = []
ucs_defs = os.path.join(FILE_ROOT_DIR, 'ucs-community', 'json',
@@ -43,7 +43,13 @@ def load_ucs() -> list[Ucs]:
with open(ucs_defs, 'r') as f:
cats = json.load(f)
return [Ucs.from_dict(cat) for cat in cats]
ucs = [Ucs.from_dict(cat) for cat in cats]
if full_ucs:
return ucs
else:
return [cat for cat in ucs if \
cat.category not in ['FOLEY', 'ARCHIVED']]
class InferenceContext:
@@ -54,39 +60,67 @@ class InferenceContext:
model: SentenceTransformer
model_name: str
def __init__(self, model: SentenceTransformer, model_name: str):
self.model = model
def __init__(self, model_name: str, use_cached_model: bool = True,
use_full_ucs: bool = False):
self.model_name = model_name
self.use_full_ucs = use_full_ucs
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
'Squad 51')
model_cache_path = os.path.join(cache_dir,
f"{self.model_name}.cache")
if use_cached_model:
if os.path.exists(model_cache_path):
self.model = SentenceTransformer(model_cache_path)
else:
self.model = SentenceTransformer(model_name)
self.model.save(model_cache_path)
else:
self.model = SentenceTransformer(model_name)
@cached_property
def catlist(self) -> list[Ucs]:
return load_ucs()
return load_ucs(full_ucs=self.use_full_ucs)
@cached_property
def embeddings(self) -> list[dict]:
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", 'Squad 51')
embedding_cache = os.path.join(cache_dir,
f"{self.model_name}-ucs_embedding.cache")
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
'Squad 51')
embedding_cache_path = os.path.join(
cache_dir,
f"{self.model_name}-ucs_embedding.cache")
embeddings = []
if os.path.exists(embedding_cache):
with open(embedding_cache, 'rb') as f:
if os.path.exists(embedding_cache_path):
with open(embedding_cache_path, 'rb') as f:
embeddings = pickle.load(f)
else:
print(f"Calculating embeddings for model {self.model_name}...")
for cat_defn in self.catlist:
# we need to calculate the embeddings for all cats, not just the
# ones we're loading for this run
full_catlist = load_ucs(full_ucs=True)
for cat_defn in full_catlist:
embeddings += [{
'CatID': cat_defn.catid,
'Embedding': self._encode_category(cat_defn)
}]
os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
with open(embedding_cache, 'wb') as g:
os.makedirs(os.path.dirname(embedding_cache_path), exist_ok=True)
with open(embedding_cache_path, 'wb') as g:
pickle.dump(embeddings, g)
return embeddings
whitelisted_cats = [cat.catid for cat in self.catlist]
return [e for e in embeddings if e['CatID'] in whitelisted_cats]
def _encode_category(self, cat: Ucs) -> np.ndarray:
sentence_components = [cat.explanations,

ucsinfer/recommend.py (new file, 82 lines)

@@ -0,0 +1,82 @@
# recommend.py
from re import match
from .inference import InferenceContext
from tabulate import tabulate
def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
interactive_rename: bool, recommend_limit=10):
"""
Print recommendations interactively.
:returns: If interactive_rename is false or path is None, returns None.
If interactive, returns a tuple of bool, str | None, str | None where:
- if retval[0] is False, the user has requested processing quit.
- if retval[1] is a str, this is the text the user has entered to
perform a new inference instead of the file metadata.
`print_recommendation` should be called again with this argument.
- if retval[2] is a str, this is the catid the user has selected.
"""
recs = ctx.classify_text_ranked(text, limit=recommend_limit)
print("----------")
if path:
print(f"Path: {path}")
print(f"Text: {text or '<None>'}")
for i, r in enumerate(recs):
cat, subcat, _ = ctx.lookup_category(r)
print(f"- {i}: {r} ({cat}-{subcat})")
if interactive_rename and path is not None:
response = input("(n#), t, c, b, ?, q > ")
if m := match(r'^([0-9]+)', response):
selection = int(m.group(1))
if 0 <= selection < len(recs):
return True, None, recs[selection]
else:
print(f"Invalid index {selection}")
return True, text, None
elif m := match(r'^t (.*)', response):
print("searching for new matches")
text = m.group(1)
return True, text, None
elif m := match(r'^c (.+)', response):
return True, None, m.group(1)
elif m := match(r'^b (.+)', response):
expt = []
for cat in ctx.catlist:
if cat.catid.startswith(m.group(1)):
expt.append([f"{cat.catid}: ({cat.category}-{cat.subcategory})",
cat.explanations])
print(tabulate(expt, maxcolwidths=80))
return True, text, None
elif response.startswith("?"):
print("""
Choices:
- Enter recommendation number to rename file,
- "t <text>" to search for new recommendations based on <text>
- "p" re-use the last selected cat-id
- "c <cat>" to type in a category by hand
- "b <cat>" browse category list for categories starting with <cat>
- "?" for this message
- "q" to quit
- or any other key to skip this file and continue to next file
""")
return True, text, None
elif response.startswith('q'):
return (False, None, None)
else:
print()
else:
return None

ucsinfer/util.py

@@ -1,5 +1,6 @@
import subprocess
import json
import os
from typing import NamedTuple, Optional
from re import match
@@ -8,10 +9,7 @@ def ffmpeg_description(path: str) -> Optional[str]:
result = subprocess.run(['ffprobe', '-show_format', '-of',
'json', path], capture_output=True)
try:
result.check_returncode()
except:
return None
result.check_returncode()
stream = json.loads(result.stdout)
fmt = stream.get("format", None)
@@ -59,6 +57,20 @@ class UcsNameComponents(NamedTuple):
return False
def normalize_ucs(basename: str, catid_list: list[str]):
"""
Take any filename and normalize it into the UCS system
"""
n, ext = os.path.splitext(basename)
r = parse_ucs(n, catid_list)
if r:
pass
else:
pass
return f"aaa.{ext}"
def build_ucs(components: UcsNameComponents, extension: str) -> str:
"""
Build a UCS filename
@@ -66,7 +78,29 @@ def build_ucs(components: UcsNameComponents, extension: str) -> str:
assert components.validate(), \
"UcsNameComponents contains invalid characters"
return ""
cat_segment = components.cat_id
if components.user_cat:
cat_segment += f"-{components.user_cat}"
name_segment = components.fx_name
if components.vendor_cat:
name_segment = f"{components.vendor_cat}-{components.fx_name}"
all_comps = [cat_segment, name_segment]
if components.creator:
all_comps += [components.creator]
if components.source:
all_comps += [components.source]
if components.user_data:
all_comps += [components.user_data]
root_name = "_".join(all_comps)
return root_name + '.' + extension
def parse_ucs(rootname: str,