Compare commits

...

2 Commits

Author SHA1 Message Date
Jamie Hardt
fb56ca1dd4 Formatting 2025-08-26 17:21:03 -07:00
Jamie Hardt
5ea64d089f Autopep 2025-08-26 17:14:56 -07:00
3 changed files with 65 additions and 58 deletions

View File

@@ -22,7 +22,7 @@ def recommend():
""" """
Infer a UCS category for a text description Infer a UCS category for a text description
""" """
pass pass
@ucsinfer.command('gather') @ucsinfer.command('gather')
@@ -36,7 +36,7 @@ def gather(paths, outfile):
types = ['.wav', '.flac'] types = ['.wav', '.flac']
table = csv.writer(outfile) table = csv.writer(outfile)
print(f"Loading category list...") print(f"Loading category list...")
catid_list = [cat.catid for cat in load_ucs()] catid_list = [cat.catid for cat in load_ucs()]
scan_list = [] scan_list = []
for path in paths: for path in paths:
@@ -47,12 +47,12 @@ def gather(paths, outfile):
if ext in types and \ if ext in types and \
(ucs_components := parse_ucs(root, catid_list)) and \ (ucs_components := parse_ucs(root, catid_list)) and \
not filename.startswith("._"): not filename.startswith("._"):
scan_list.append((ucs_components.cat_id, scan_list.append((ucs_components.cat_id,
os.path.join(dirpath, filename))) os.path.join(dirpath, filename)))
print(f"Found {len(scan_list)} files to process.") print(f"Found {len(scan_list)} files to process.")
for pair in tqdm.tqdm(scan_list, unit='files',file=sys.stderr): for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
if desc := ffmpeg_description(pair[1]): if desc := ffmpeg_description(pair[1]):
table.writerow([pair[0], desc]) table.writerow([pair[0], desc])
@@ -62,13 +62,13 @@ def finetune():
""" """
Fine-tune a model with training data Fine-tune a model with training data
""" """
pass pass
@ucsinfer.command('evaluate') @ucsinfer.command('evaluate')
@click.option('--offset', type=int, default=0) @click.option('--offset', type=int, default=0)
@click.option('--limit', type=int, default=-1) @click.option('--limit', type=int, default=-1)
@click.argument('dataset', type=click.File('r', encoding='utf8'), @click.argument('dataset', type=click.File('r', encoding='utf8'),
default='dataset.csv') default='dataset.csv')
def evaluate(dataset, offset, limit): def evaluate(dataset, offset, limit):
""" """
@@ -82,7 +82,7 @@ def evaluate(dataset, offset, limit):
for i, row in enumerate(tqdm.tqdm(reader)): for i, row in enumerate(tqdm.tqdm(reader)):
if i < offset: if i < offset:
continue continue
if limit > 0 and i >= limit + offset: if limit > 0 and i >= limit + offset:
break break
@@ -107,33 +107,34 @@ def evaluate(dataset, offset, limit):
miss_counts = [] miss_counts = []
for cat in cats: for cat in cats:
miss_counts.append((cat, len([x for x in results \ miss_counts.append(
if x['catid'] == cat and x['result'] == 'MISS']))) (cat, len([x for x in results
if x['catid'] == cat and x['result'] == 'MISS'])))
miss_counts = sorted(miss_counts, key=lambda x: x[1]) miss_counts = sorted(miss_counts, key=lambda x: x[1])
print(f" === RESULTS === ") print(f" === RESULTS === ")
table = [ table = [
["Total records in sample:", f"{total}"], ["Total records in sample:", f"{total}"],
["Top Result:", f"{total_top}", ["Top Result:", f"{total_top}",
f"{float(total_top)/float(total):.2%}"], f"{float(total_top)/float(total):.2%}"],
["Top 5 Result:", f"{total_top_5}", ["Top 5 Result:", f"{total_top_5}",
f"{float(total_top_5)/float(total):.2%}"], f"{float(total_top_5)/float(total):.2%}"],
["Top 10 Result:", f"{total_top_10}", ["Top 10 Result:", f"{total_top_10}",
f"{float(total_top_10)/float(total):.2%}"], f"{float(total_top_10)/float(total):.2%}"],
SEPARATING_LINE, SEPARATING_LINE,
["UCS category count:", f"{len(ctx.catlist)}"], ["UCS category count:", f"{len(ctx.catlist)}"],
["Total categories in sample:", f"{total_cats}", ["Total categories in sample:", f"{total_cats}",
f"{float(total_cats)/float(len(ctx.catlist)):.2%}"], f"{float(total_cats)/float(len(ctx.catlist)):.2%}"],
[f"Most missed category ({miss_counts[-1][0]}):", [f"Most missed category ({miss_counts[-1][0]}):",
f"{miss_counts[-1][1]}", f"{miss_counts[-1][1]}",
f"{float(miss_counts[-1][1])/float(total):.2%}"] f"{float(miss_counts[-1][1])/float(total):.2%}"]
] ]
print(tabulate(table, headers=['', 'n', 'pct']))
print(tabulate(table, headers=['','n','pct']))
if __name__ == '__main__': if __name__ == '__main__':
os.environ['TOKENIZERS_PARALLELISM'] = 'false' os.environ['TOKENIZERS_PARALLELISM'] = 'false'

View File

@@ -11,6 +11,7 @@ import platformdirs
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
def classify_text_ranked(text, embeddings_list, model, limit=5): def classify_text_ranked(text, embeddings_list, model, limit=5):
text_embedding = model.encode(text, convert_to_numpy=True) text_embedding = model.encode(text, convert_to_numpy=True)
embeddings = np.array([info['Embedding'] for info in embeddings_list]) embeddings = np.array([info['Embedding'] for info in embeddings_list])
@@ -23,15 +24,16 @@ class Ucs(NamedTuple):
catid: str catid: str
category: str category: str
subcategory: str subcategory: str
explanations: str explanations: str
synonymns: list[str] synonymns: list[str]
@classmethod @classmethod
def from_dict(cls, d: dict): def from_dict(cls, d: dict):
return Ucs(catid=d['CatID'], category=d['Category'], return Ucs(catid=d['CatID'], category=d['Category'],
subcategory=d['SubCategory'], subcategory=d['SubCategory'],
explanations=d['Explanations'], synonymns=d['Synonyms']) explanations=d['Explanations'], synonymns=d['Synonyms'])
def load_ucs() -> list[Ucs]: def load_ucs() -> list[Ucs]:
FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
cats = [] cats = []
@@ -43,6 +45,7 @@ def load_ucs() -> list[Ucs]:
return [Ucs.from_dict(cat) for cat in cats] return [Ucs.from_dict(cat) for cat in cats]
class InferenceContext: class InferenceContext:
""" """
Maintains caches and resources for UCS category inference. Maintains caches and resources for UCS category inference.
@@ -72,9 +75,9 @@ class InferenceContext:
for cat_defn in self.catlist: for cat_defn in self.catlist:
embeddings += [{ embeddings += [{
'CatID': cat_defn.catid, 'CatID': cat_defn.catid,
'Embedding': self._encode_category(cat_defn) 'Embedding': self._encode_category(cat_defn)
}] }]
os.makedirs(os.path.dirname(embedding_cache), exist_ok=True) os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
with open(embedding_cache, 'wb') as g: with open(embedding_cache, 'wb') as g:
@@ -83,10 +86,10 @@ class InferenceContext:
return embeddings return embeddings
def _encode_category(self, cat: Ucs) -> np.ndarray: def _encode_category(self, cat: Ucs) -> np.ndarray:
sentence_components = [cat.explanations, sentence_components = [cat.explanations,
cat.category, cat.category,
cat.subcategory cat.subcategory
] ]
sentence_components += cat.synonymns sentence_components += cat.synonymns
sentence = ", ".join(sentence_components) sentence = ", ".join(sentence_components)
return self.model.encode(sentence, convert_to_numpy=True) return self.model.encode(sentence, convert_to_numpy=True)
@@ -104,13 +107,12 @@ class InferenceContext:
def lookup_category(self, catid) -> tuple[str, str, str]: def lookup_category(self, catid) -> tuple[str, str, str]:
""" """
Get the category, subcategory and explanations phrase for a `catid` Get the category, subcategory and explanations phrase for a `catid`
:raises: StopIterator if CatId is not on the schedule :raises: StopIterator if CatId is not on the schedule
""" """
i = ( i = (
(x.category, x.subcategory, x.explanations) \ (x.category, x.subcategory, x.explanations)
for x in self.catlist if x.catid == catid for x in self.catlist if x.catid == catid
) )
return next(i)
return next(i)

View File

@@ -1,20 +1,21 @@
import subprocess import subprocess
import json import json
from typing import NamedTuple, Optional from typing import NamedTuple, Optional
from re import match from re import match
from .inference import Ucs from .inference import Ucs
def ffmpeg_description(path: str) -> Optional[str]: def ffmpeg_description(path: str) -> Optional[str]:
result = subprocess.run(['ffprobe', '-show_format', '-of', result = subprocess.run(['ffprobe', '-show_format', '-of',
'json', path], capture_output=True) 'json', path], capture_output=True)
try: try:
result.check_returncode() result.check_returncode()
except: except:
return None return None
stream = json.loads(result.stdout) stream = json.loads(result.stdout)
fmt = stream.get("format", None) fmt = stream.get("format", None)
if fmt: if fmt:
@@ -28,10 +29,10 @@ class UcsNameComponents(NamedTuple):
Components of a UCS filename Components of a UCS filename
""" """
cat_id: str cat_id: str
user_cat: str | None user_cat: str | None
vendor_cat: str | None vendor_cat: str | None
fx_name: str fx_name: str
creator: str | None creator: str | None
source: str | None source: str | None
user_data: str | None user_data: str | None
@@ -43,7 +44,7 @@ class UcsNameComponents(NamedTuple):
return False return False
if self.user_cat and not match(r"[^\-_]+", self.user_cat): if self.user_cat and not match(r"[^\-_]+", self.user_cat):
return False return False
if self.vendor_cat and not match(r"[^\-_]+", self.vendor_cat): if self.vendor_cat and not match(r"[^\-_]+", self.vendor_cat):
return False return False
@@ -52,7 +53,7 @@ class UcsNameComponents(NamedTuple):
return False return False
if self.creator and not match(r"[^_]+", self.creator): if self.creator and not match(r"[^_]+", self.creator):
return False return False
if self.source and not match(r"[^_]+", self.source): if self.source and not match(r"[^_]+", self.source):
return False return False
@@ -65,25 +66,28 @@ def build_ucs(components: UcsNameComponents, extension: str) -> str:
""" """
Build a UCS filename Build a UCS filename
""" """
assert components.validate(), "UcsNameComponents contains invalid characters" assert components.validate(), \
"UcsNameComponents contains invalid characters"
return "" return ""
def parse_ucs(rootname: str, catid_list: list[str]) -> Optional[UcsNameComponents]: def parse_ucs(rootname: str,
catid_list: list[str]) -> Optional[UcsNameComponents]:
""" """
Parse the UCS components from a file name root. Parse the UCS components from a file name root.
:param rootname: filename root, the basename of the file without extension :param rootname: filename root, the basename of the file without extension
:param catid_list: a list of all UCS CatIDs :param catid_list: a list of all UCS CatIDs
:returns: the components, or `None` if the filename is not in UCS format :returns: the components, or `None` if the filename is not in UCS format
""" """
regexp1 = r"^(?P<CatID>[A-z]+)(-(?P<UserCat>[^_]+))?_((?P<VendorCat>[^-]+)-)?(?P<FXName>[^_]+)" regexp1 = r"^(?P<CatID>[A-z]+)(-(?P<UserCat>[^_]+))?_"
regexp2 = r"((?P<VendorCat>[^-]+)-)?(?P<FXName>[^_]+)"
regexp3 = r"(_(?P<CreatorID>[^_]+)(_(?P<SourceID>[^_]+)"
regexp4 = r"(_(?P<UserData>[^.]+))?)?)?"
regexp2 = r"(_(?P<CreatorID>[^_]+)(_(?P<SourceID>[^_]+)(_(?P<UserData>[^.]+))?)?)?" regexp = regexp1 + regexp2 + regexp3 + regexp4
regexp = regexp1 + regexp2
matches = match(regexp, rootname) matches = match(regexp, rootname)