From 5ea64d089f9d33be7e3225a435e2b32678f50158 Mon Sep 17 00:00:00 2001 From: Jamie Hardt Date: Tue, 26 Aug 2025 17:14:56 -0700 Subject: [PATCH] Autopep --- ucsinfer/__main__.py | 54 +++++++++++++++++++++---------------------- ucsinfer/inference.py | 34 ++++++++++++++------------- ucsinfer/util.py | 23 +++++++++--------- 3 files changed, 57 insertions(+), 54 deletions(-) diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py index b14e6e7..dcdd70a 100644 --- a/ucsinfer/__main__.py +++ b/ucsinfer/__main__.py @@ -22,7 +22,7 @@ def recommend(): """ Infer a UCS category for a text description """ - pass + pass @ucsinfer.command('gather') @@ -36,7 +36,7 @@ def gather(paths, outfile): types = ['.wav', '.flac'] table = csv.writer(outfile) print(f"Loading category list...") - catid_list = [cat.catid for cat in load_ucs()] + catid_list = [cat.catid for cat in load_ucs()] scan_list = [] for path in paths: @@ -47,12 +47,12 @@ def gather(paths, outfile): if ext in types and \ (ucs_components := parse_ucs(root, catid_list)) and \ not filename.startswith("._"): - scan_list.append((ucs_components.cat_id, + scan_list.append((ucs_components.cat_id, os.path.join(dirpath, filename))) print(f"Found {len(scan_list)} files to process.") - for pair in tqdm.tqdm(scan_list, unit='files',file=sys.stderr): + for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr): if desc := ffmpeg_description(pair[1]): table.writerow([pair[0], desc]) @@ -62,13 +62,13 @@ def finetune(): """ Fine-tune a model with training data """ - pass + pass @ucsinfer.command('evaluate') @click.option('--offset', type=int, default=0) @click.option('--limit', type=int, default=-1) -@click.argument('dataset', type=click.File('r', encoding='utf8'), +@click.argument('dataset', type=click.File('r', encoding='utf8'), default='dataset.csv') def evaluate(dataset, offset, limit): """ @@ -82,7 +82,7 @@ def evaluate(dataset, offset, limit): for i, row in enumerate(tqdm.tqdm(reader)): if i < offset: continue - + if limit > 0 and i >= limit + offset: break @@ -107,33 +107,33 @@ def evaluate(dataset, offset, limit): miss_counts = [] for cat in cats: - miss_counts.append((cat, len([x for x in results \ - if x['catid'] == cat and x['result'] == 'MISS']))) + miss_counts.append((cat, len([x for x in results + if x['catid'] == cat and x['result'] == 'MISS']))) miss_counts = sorted(miss_counts, key=lambda x: x[1]) print(f" === RESULTS === ") - + table = [ - ["Total records in sample:", f"{total}"], - ["Top Result:", f"{total_top}", - f"{float(total_top)/float(total):.2%}"], - ["Top 5 Result:", f"{total_top_5}", - f"{float(total_top_5)/float(total):.2%}"], - ["Top 10 Result:", f"{total_top_10}", - f"{float(total_top_10)/float(total):.2%}"], - SEPARATING_LINE, - ["UCS category count:", f"{len(ctx.catlist)}"], - ["Total categories in sample:", f"{total_cats}", - f"{float(total_cats)/float(len(ctx.catlist)):.2%}"], - [f"Most missed category ({miss_counts[-1][0]}):", - f"{miss_counts[-1][1]}", - f"{float(miss_counts[-1][1])/float(total):.2%}"] - ] + ["Total records in sample:", f"{total}"], + ["Top Result:", f"{total_top}", + f"{float(total_top)/float(total):.2%}"], + ["Top 5 Result:", f"{total_top_5}", + f"{float(total_top_5)/float(total):.2%}"], + ["Top 10 Result:", f"{total_top_10}", + f"{float(total_top_10)/float(total):.2%}"], + SEPARATING_LINE, + ["UCS category count:", f"{len(ctx.catlist)}"], + ["Total categories in sample:", f"{total_cats}", + f"{float(total_cats)/float(len(ctx.catlist)):.2%}"], + [f"Most missed category ({miss_counts[-1][0]}):", + f"{miss_counts[-1][1]}", + f"{float(miss_counts[-1][1])/float(total):.2%}"] + ] + + print(tabulate(table, headers=['', 'n', 'pct'])) - print(tabulate(table, headers=['','n','pct'])) - if __name__ == '__main__': os.environ['TOKENIZERS_PARALLELISM'] = 'false' diff --git a/ucsinfer/inference.py b/ucsinfer/inference.py index a951411..c7cee72 100644 --- a/ucsinfer/inference.py +++ b/ucsinfer/inference.py @@ -11,6 +11,7 @@ import platformdirs from sentence_transformers import SentenceTransformer + def classify_text_ranked(text, embeddings_list, model, limit=5): text_embedding = model.encode(text, convert_to_numpy=True) embeddings = np.array([info['Embedding'] for info in embeddings_list]) @@ -23,15 +24,16 @@ class Ucs(NamedTuple): catid: str category: str subcategory: str - explanations: str + explanations: str synonymns: list[str] @classmethod def from_dict(cls, d: dict): - return Ucs(catid=d['CatID'], category=d['Category'], - subcategory=d['SubCategory'], + return Ucs(catid=d['CatID'], category=d['Category'], + subcategory=d['SubCategory'], explanations=d['Explanations'], synonymns=d['Synonyms']) + def load_ucs() -> list[Ucs]: FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) cats = [] @@ -43,6 +45,7 @@ def load_ucs() -> list[Ucs]: return [Ucs.from_dict(cat) for cat in cats] + class InferenceContext: """ Maintains caches and resources for UCS category inference. @@ -72,9 +75,9 @@ class InferenceContext: for cat_defn in self.catlist: embeddings += [{ - 'CatID': cat_defn.catid, - 'Embedding': self._encode_category(cat_defn) - }] + 'CatID': cat_defn.catid, + 'Embedding': self._encode_category(cat_defn) + }] os.makedirs(os.path.dirname(embedding_cache), exist_ok=True) with open(embedding_cache, 'wb') as g: @@ -83,10 +86,10 @@ class InferenceContext: return embeddings def _encode_category(self, cat: Ucs) -> np.ndarray: - sentence_components = [cat.explanations, - cat.category, - cat.subcategory - ] + sentence_components = [cat.explanations, + cat.category, + cat.subcategory + ] sentence_components += cat.synonymns sentence = ", ".join(sentence_components) return self.model.encode(sentence, convert_to_numpy=True) @@ -104,13 +107,12 @@ class InferenceContext: def lookup_category(self, catid) -> tuple[str, str, str]: """ Get the category, subcategory and explanations phrase for a `catid` - + :raises: StopIterator if CatId is not on the schedule """ i = ( - (x.category, x.subcategory, x.explanations) \ - for x in self.catlist if x.catid == catid - ) - - return next(i) + (x.category, x.subcategory, x.explanations) + for x in self.catlist if x.catid == catid + ) + return next(i) diff --git a/ucsinfer/util.py b/ucsinfer/util.py index fed9e55..08a43ac 100644 --- a/ucsinfer/util.py +++ b/ucsinfer/util.py @@ -1,20 +1,21 @@ import subprocess import json -from typing import NamedTuple, Optional +from typing import NamedTuple, Optional from re import match -from .inference import Ucs +from .inference import Ucs + def ffmpeg_description(path: str) -> Optional[str]: - result = subprocess.run(['ffprobe', '-show_format', '-of', + result = subprocess.run(['ffprobe', '-show_format', '-of', 'json', path], capture_output=True) try: result.check_returncode() except: return None - + stream = json.loads(result.stdout) fmt = stream.get("format", None) if fmt: @@ -28,10 +29,10 @@ class UcsNameComponents(NamedTuple): Components of a UCS filename """ cat_id: str - user_cat: str | None + user_cat: str | None vendor_cat: str | None fx_name: str - creator: str | None + creator: str | None source: str | None user_data: str | None @@ -43,7 +44,7 @@ class UcsNameComponents(NamedTuple): return False if self.user_cat and not match(r"[^\-_]+", self.user_cat): - return False + return False if self.vendor_cat and not match(r"[^\-_]+", self.vendor_cat): return False @@ -52,7 +53,7 @@ class UcsNameComponents(NamedTuple): return False if self.creator and not match(r"[^_]+", self.creator): - return False + return False if self.source and not match(r"[^_]+", self.source): return False @@ -73,7 +74,7 @@ def build_ucs(components: UcsNameComponents, extension: str) -> str: def parse_ucs(rootname: str, catid_list: list[str]) -> Optional[UcsNameComponents]: """ Parse the UCS components from a file name root. - + :param rootname: filename root, the basename of the file without extension :param catid_list: a list of all UCS CatIDs :returns: the components, or `None` if the filename is not in UCS format @@ -82,8 +83,8 @@ def parse_ucs(rootname: str, catid_list: list[str]) -> Optional[UcsNameComponent regexp1 = r"^(?P[A-z]+)(-(?P[^_]+))?_((?P[^-]+)-)?(?P[^_]+)" regexp2 = r"(_(?P[^_]+)(_(?P[^_]+)(_(?P[^.]+))?)?)?" - - regexp = regexp1 + regexp2 + + regexp = regexp1 + regexp2 matches = match(regexp, rootname)