Compare commits

...

3 Commits

SHA1        Message                               Date
04332b73ee  Fixed a bug in cat masking            2025-09-04 10:54:35 -07:00
103fffe0a4  Added another function to recommend   2025-09-04 10:43:31 -07:00
10519f9c1a  Split training sets from gather       2025-09-03 19:40:55 -07:00
6 changed files with 70 additions and 120 deletions


@@ -17,8 +17,9 @@
 - Print more information about the dataset coverage of UCS
 - Allow skipping model testing for this
 - Print raw output
-- Maybe load everything into a sqlite for slicker reporting
+<!-- - Maybe load everything into a sqlite for slicker reporting -->
 ## Utility


@@ -2,13 +2,14 @@ import os
 # import csv
 import logging
 from subprocess import CalledProcessError
+from itertools import chain
 import tqdm
 import click
 # from tabulate import tabulate, SEPARATING_LINE
 from .inference import InferenceContext, load_ucs
-from .gather import build_sentence_class_dataset
+from .gather import build_sentence_class_dataset, print_dataset_stats
 from .recommend import print_recommendation
 from .util import ffmpeg_description, parse_ucs
@@ -163,10 +164,12 @@ def gather(ctx, paths, out, ucs_data):
         logger.info('Creating dataset for UCS categories instead of from PATH')
         paths = []
+    walker_p = tqdm.tqdm(total=None, unit='dir', desc="Walking filesystem...")
     for path in paths:
         for dirpath, _, filenames in os.walk(path):
             logger.info(f"Walking directory {dirpath}")
             for filename in filenames:
+                walker_p.update()
                 root, ext = os.path.splitext(filename)
                 if ext not in types or filename.startswith("._"):
                     continue
@@ -175,6 +178,7 @@ def gather(ctx, paths, out, ucs_data):
                 p = os.path.join(dirpath, filename)
                 logger.info(f"Adding path to scan list {p}")
                 scan_list.append((ucs_components.cat_id, p))
+    walker_p.close()
     logger.info(f"Found {len(scan_list)} files to process.")
@@ -184,7 +188,7 @@ def gather(ctx, paths, out, ucs_data):
             try:
                 desc = ffmpeg_description(pair[1])
             except CalledProcessError as e:
-                logger.error(f"ffprobe returned error {e.returncode}: " \
+                logger.error(f"ffprobe returned error ({e.returncode}): " \
                              + e.stderr)
                 continue
@@ -195,19 +199,20 @@ def gather(ctx, paths, out, ucs_data):
             assert comps
             yield comps.fx_name, str(pair[0])
     def ucs_metadata():
         for cat in ucs:
             yield cat.explanations, cat.catid
             yield ", ".join(cat.synonymns), cat.catid
     logger.info("Building dataset...")
-    if ucs_data:
-        dataset = build_sentence_class_dataset(ucs_metadata(), catid_list)
-    else:
-        dataset = build_sentence_class_dataset(scan_metadata(), catid_list)
+    dataset = build_sentence_class_dataset(chain(scan_metadata(),
+                                                 ucs_metadata()),
+                                           catid_list)
     logger.info(f"Saving dataset to disk at {out}")
+    print_dataset_stats(dataset)
     dataset.save_to_disk(out)
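Note on the change above: `gather` previously built the dataset from either the scanned-file metadata or the UCS metadata; it now always combines both streams. A minimal sketch of the pattern (with made-up generators and CatID values, not the project's real ones) showing how `itertools.chain` lazily concatenates the two (sentence, CatID) streams into the single iterator that `build_sentence_class_dataset` consumes:

```python
from itertools import chain

def scan_metadata():
    # stand-in for the generator derived from scanned filenames
    yield "glass bottle dropped on concrete", "GLASImpt"  # made-up CatID

def ucs_metadata():
    # stand-in for the generator derived from UCS explanations and synonyms
    yield "glass breaking, shattering, smashing", "GLASImpt"  # made-up CatID

records = chain(scan_metadata(), ucs_metadata())
print(list(records))  # records from both sources, in one stream
```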
@@ -222,114 +227,15 @@ def finetune(ctx):
 @ucsinfer.command('evaluate')
-@click.option('--offset', type=int, default=0, metavar="<int>",
-              help='Skip this many records in the dataset before processing')
-@click.option('--limit', type=int, default=-1, metavar="<int>",
-              help='Process this many records and then exit')
-@click.argument('dataset', type=click.File('r', encoding='utf8'),
-                default='dataset.csv')
+@click.argument('dataset', default='dataset/')
 @click.pass_context
-def evaluate(ctx, dataset, offset, limit, no_foley):
+def evaluate(ctx, dataset, offset, limit):
     """
-    Use datasets to evaluate model performance
-    The `evaluate` command reads the input DATASET file row by row and
-    performs a classification of the given description against the selected
-    model (either the default or using the --model option). The command then
-    checks if the model inferred the correct category as given by the dataset.
-    The model gives its top 10 possible categories for a given description,
-    and the results are tabulated according to (1) whether the top
-    classification was correct, (2) whether the correct classification was in the
-    top 5, or (3) whether it was in the top 10. The worst-performing category,
-    the one with the most misses, is also reported, as well as the category
-    coverage: how many categories are present in the dataset.
-    NOTE: With experimentation it was found that foley items generally were
-    classified according to their subject and not whether or not they were
-    foley, and so these categories can be excluded with the --no-foley option.
     """
     logger.debug("EVALUATE mode")
     logger.warning("Model evaluation is not currently implemented")
-    # inference_context = InferenceContext(
-    #     ctx.obj['model_name'], use_cached_model=ctx.obj['model_cache'],
-    #     use_full_ucs=ctx.obj['complete_ucs'])
-    #
-    # reader = csv.reader(dataset)
-    #
-    # results = []
-    #
-    # if offset > 0:
-    #     logger.debug(f"Skipping {offset} records...")
-    #
-    # if limit > 0:
-    #     logger.debug(f"Will only evaluate {limit} records...")
-    #
-    # progress_bar = tqdm.tqdm(total=limit,
-    #                          desc="Processing dataset...",
-    #                          unit="rec")
-    # for i, row in enumerate(reader):
-    #     if i < offset:
-    #         continue
-    #
-    #     if limit > 0 and i >= limit + offset:
-    #         break
-    #
-    #     cat_id, description = row
-    #     if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
-    #         continue
-    #
-    #     guesses = inference_context.classify_text_ranked(description, limit=10)
-    #     if cat_id == guesses[0]:
-    #         results.append({'catid': cat_id, 'result': "TOP"})
-    #     elif cat_id in guesses[0:5]:
-    #         results.append({'catid': cat_id, 'result': "TOP_5"})
-    #     elif cat_id in guesses:
-    #         results.append({'catid': cat_id, 'result': "TOP_10"})
-    #     else:
-    #         results.append({'catid': cat_id, 'result': "MISS"})
-    #
-    #     progress_bar.update(1)
-    #
-    # total = len(results)
-    # total_top = len([x for x in results if x['result'] == 'TOP'])
-    # total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
-    # total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
-    #
-    # cats = set([x['catid'] for x in results])
-    # total_cats = len(cats)
-    #
-    # miss_counts = []
-    # for cat in cats:
-    #     miss_counts.append(
-    #         (cat, len([x for x in results
-    #                    if x['catid'] == cat and x['result'] == 'MISS'])))
-    #
-    # miss_counts = sorted(miss_counts, key=lambda x: x[1])
-    #
-    # print(f"## Results for Model {model} ##\n")
-    #
-    # if no_foley:
-    #     print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
-    #
-    # table = [
-    #     ["Total records in sample:", f"{total}"],
-    #     ["Top Result:", f"{total_top}",
-    #      f"{float(total_top)/float(total):.2%}"],
-    #     ["Top 5 Result:", f"{total_top_5}",
-    #      f"{float(total_top_5)/float(total):.2%}"],
-    #     ["Top 10 Result:", f"{total_top_10}",
-    #      f"{float(total_top_10)/float(total):.2%}"],
-    #     SEPARATING_LINE,
-    #     ["UCS category count:", f"{len(inference_context.catlist)}"],
-    #     ["Total categories in sample:", f"{total_cats}",
-    #      f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
-    #     [f"Most missed category ({miss_counts[-1][0]}):",
-    #      f"{miss_counts[-1][1]}",
-    #      f"{float(miss_counts[-1][1])/float(total):.2%}"]
-    # ]
-    #
-    # print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
 if __name__ == '__main__':
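The removed docstring and commented-out block describe a top-1/top-5/top-10 tally over the dataset. For reference, the scoring scheme it describes reduces to roughly the sketch below (the `classify` callable is hypothetical; the commit itself leaves evaluation unimplemented):

```python
from collections import Counter

def tally_results(rows, classify):
    """Count how often the true CatID lands in the top 1, 5, or 10 guesses.

    `rows` yields (cat_id, description) pairs; `classify` is a hypothetical
    callable that returns a ranked list of CatIDs for a description.
    """
    tally = Counter()
    misses = Counter()
    for cat_id, description in rows:
        guesses = classify(description, limit=10)
        if cat_id == guesses[0]:
            tally["TOP"] += 1
        elif cat_id in guesses[:5]:
            tally["TOP_5"] += 1
        elif cat_id in guesses:
            tally["TOP_10"] += 1
        else:
            misses[cat_id] += 1
    return tally, misses  # misses.most_common(1) is the worst category
```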

ucsinfer/evaluate.py (new file, 7 additions)

@@ -0,0 +1,7 @@
+# from sentence_transformers import SentenceTransformer
+# from sentence_transformers.evaluation import BinaryClassificationEvaluator
+# from datasets import load_dataset_from_disk, DatasetDict
+#


@@ -1,13 +1,28 @@
 from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
+from datasets.dataset_dict import DatasetDict
-from typing import Generator, Any
+from typing import Iterator
+from tabulate import tabulate
+def print_dataset_stats(dataset: DatasetDict, catlist: list[str]):
+    data_table = []
+    data_table.append([["Total records in combined dataset:", len(dataset)]])
+    data_table.append([["Total records in `train`:", len(dataset['train'])]])
+    tab = tabulate(data_table)
+    print(tab)
 # https://www.sbert.net/docs/sentence_transformer/loss_overview.html
 def build_sentence_class_dataset(
-        records: Generator[tuple[str, str], Any, None], catlist: list[str]) -> Dataset:
+        records: Iterator[tuple[str, str]],
+        catlist: list[str]) -> DatasetDict:
     """
     Create a new dataset for `records` which contains (sentence, class) pairs.
+    The dataset is split into train and test slices.
     :param records: a generator for records that generates pairs of
         (sentence, catid)
@@ -16,9 +31,12 @@ def build_sentence_class_dataset(
     labels = ClassLabel(names=catlist)
+    features = Features({'sentence': Value('string'),
+                         'class': labels})
     info = DatasetInfo(
         description=f"(sentence, UCS CatID) pairs gathered by the "
-                    "ucsinfer tool on {}")
+                    "ucsinfer tool on {}", features=features)
     items: list[dict] = []
@@ -26,9 +44,16 @@ def build_sentence_class_dataset(
         items += [{'sentence': obj[0], 'class': obj[1]}]
-    return Dataset.from_list(items, features=Features({'sentence': Value('string'),
-                                                        'class': labels}),
-                             info=info)
+    whole = Dataset.from_list(items, features=features, info=info)
+    split_set = whole.train_test_split(0.2)
+    test_eval_set = split_set['test'].train_test_split(0.5)
+    return DatasetDict({
+        'train': split_set['train'],
+        'test': test_eval_set['train'],
+        'eval': test_eval_set['test']
+    })
 # def build_sentence_anchor_dataset() -> Dataset:
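The nested `train_test_split` calls above produce roughly an 80/10/10 train/test/eval split. A small sketch (assuming the Hugging Face `datasets` package and a hypothetical `dataset/` directory written by `gather`) of loading the saved `DatasetDict` and checking the slice sizes:

```python
from datasets import load_from_disk

ds = load_from_disk("dataset/")  # hypothetical output path from `gather`
for split in ("train", "test", "eval"):
    print(split, len(ds[split]))  # roughly 80% / 10% / 10% of the records
```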


@@ -103,7 +103,12 @@ class InferenceContext:
         else:
             print(f"Calculating embeddings for model {self.model_name}...")
-            for cat_defn in self.catlist:
+            # we need to calculate the embeddings for all cats, not just the
+            # ones we're loading for this run
+            full_catlist = load_ucs(full_ucs=True)
+            for cat_defn in full_catlist:
                 embeddings += [{
                     'CatID': cat_defn.catid,
                     'Embedding': self._encode_category(cat_defn)
@@ -113,7 +118,9 @@ class InferenceContext:
             with open(embedding_cache_path, 'wb') as g:
                 pickle.dump(embeddings, g)
-        return embeddings
+        whitelisted_cats = [cat.catid for cat in self.catlist]
+        return [e for e in embeddings if e['CatID'] in whitelisted_cats]
     def _encode_category(self, cat: Ucs) -> np.ndarray:
         sentence_components = [cat.explanations,


@@ -1,7 +1,6 @@
 # recommend.py
 from re import match
 from .inference import InferenceContext
 def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
@@ -23,6 +22,7 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
print(f"Path: {path}") print(f"Path: {path}")
print(f"Text: {text or '<None>'}") print(f"Text: {text or '<None>'}")
for i, r in enumerate(recs): for i, r in enumerate(recs):
cat, subcat, _ = ctx.lookup_category(r) cat, subcat, _ = ctx.lookup_category(r)
print(f"- {i}: {r} ({cat}-{subcat})") print(f"- {i}: {r} ({cat}-{subcat})")
@@ -43,12 +43,16 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
         text = m.group(1)
         return True, text, None
+    elif m := match(r'^c (.*)', response):
+        return True, None, m.group(1)
     elif response.startswith("?"):
         print("""
 Choices:
 - Enter recommendation number to rename file,
 - "t [text]" to search for new recommendations based on [text]
 - "p" re-use the last selected cat-id
+- "c [cat]" to type in a category by hand
 - "?" for this message
 - "q" to quit
 - or any other key to skip this file and continue to next file