diff --git a/README.md b/README.md
index 307c1ab..b3ca746 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,31 @@
 # ucsinfer
-Universal Category System inference.
\ No newline at end of file
+Universal Category System inference.
+
+## Running
+
+```sh
+python -m ucsinfer [command]
+```
+Pass `--help` to see a summary of subcommands and options.
+
+The subcommands implemented at this time are `gather` and `evaluate`.
+
+## Subcommands
+
+* recommend
+
+  Infer a UCS category for a text description.
+
+* gather
+
+  Scan files to capture existing text descriptions and UCS categories, and
+  save them as a dataset.
+
+* finetune
+
+  Fine-tune an existing sentence embedding model with training data.
+
+* evaluate
+
+  Use datasets to evaluate the performance of fine-tuning.
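For a concrete picture of the workflow, a typical session with the two implemented subcommands might look like this (the library path is purely illustrative):

```sh
# Build a dataset from a UCS-named library, then score the base model on it
python -m ucsinfer gather --outfile dataset.csv ~/SFX/LibraryA
python -m ucsinfer evaluate --limit 500 dataset.csv
```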
diff --git a/pyproject.toml b/pyproject.toml
index 0d99163..1ceacea 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,9 @@ dependencies = [
     "sentence-transformers (>=5.0.0,<6.0.0)",
     "numpy (>=2.3.2,<3.0.0)",
     "tqdm (>=4.67.1,<5.0.0)",
-    "platformdirs (>=4.3.8,<5.0.0)"
+    "platformdirs (>=4.3.8,<5.0.0)",
+    "click (>=8.2.1,<9.0.0)",
+    "tabulate (>=0.9.0,<0.10.0)"
 ]
diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py
index 61816dc..b14e6e7 100644
--- a/ucsinfer/__main__.py
+++ b/ucsinfer/__main__.py
@@ -1,205 +1,143 @@
 import os
-import json
 import sys
-import subprocess
-import cmd
-
-from typing import Optional, IO
-
-from .inference import InferenceContext
+import csv
 
 from sentence_transformers import SentenceTransformer
+import tqdm
+import click
+from tabulate import tabulate, SEPARATING_LINE
+
+from .inference import InferenceContext, load_ucs
+from .util import ffmpeg_description, parse_ucs
 
 
-def description(path: str) -> Optional[str]:
-    result = subprocess.run(['ffprobe', '-show_format', '-of',
-                             'json', path], capture_output=True)
+@click.group()
+@click.option('--verbose', is_flag=True, help='Verbose output')
+def ucsinfer(verbose: bool):
+    pass
 
-    try:
-        result.check_returncode()
-    except:
-        return None
-
-    stream = json.loads(result.stdout)
-    fmt = stream.get("format", None)
-    if fmt:
-        tags = fmt.get("tags", None)
-        if tags:
-            return tags.get("comment", None)
 
-def recommend_category(ctx: InferenceContext, path) -> tuple[str, list]:
+@ucsinfer.command('recommend')
+def recommend():
     """
-    Get a text description of the file at `path` and a list of UCS cat IDs
+    Infer a UCS category for a text description
     """
-    desc = description(path)
-    if desc is None:
-        desc = os.path.basename(path)
-
-    return desc, ctx.classify_text_ranked(desc)
+    pass
 
-from shutil import get_terminal_size
 
+@ucsinfer.command('gather')
+@click.option('--outfile', type=click.File(mode='w', encoding='utf8'),
+              default='dataset.csv', show_default=True)
+@click.argument('paths', nargs=-1)
+def gather(paths, outfile):
+    """
+    Scan files under PATHS to build a training dataset
+    """
+    types = ['.wav', '.flac']
+    table = csv.writer(outfile)
+    print("Loading category list...")
+    catid_list = [cat.catid for cat in load_ucs()]
 
-class Commands(cmd.Cmd):
-    ctx: InferenceContext
+    scan_list = []
+    for path in paths:
+        print(f"Scanning directory {path}...", file=sys.stdout)
+        for dirpath, _, filenames in os.walk(path):
+            for filename in filenames:
+                root, ext = os.path.splitext(filename)
+                if not filename.startswith("._") and ext in types and \
+                        (ucs_components := parse_ucs(root, catid_list)):
+                    scan_list.append((ucs_components.cat_id,
+                                      os.path.join(dirpath, filename)))
 
-    def __init__(self, completekey: str = "tab", stdin: IO[str] | None = None,
-                 stdout: IO[str] | None = None) -> None:
-        super().__init__(completekey, stdin, stdout)
-        self.file_list = []
-        self.catlist: list = []
-        self.rec_list = []
-        self.history = []
+    print(f"Found {len(scan_list)} files to process.")
 
-    def preloop(self) -> None:
-        self.file_cursor = 0
-        self.update_prompt()
-        self.setup_for_file()
-        return super().preloop()
-
-    def default(self, line: str):
-        try:
-            sel = int(line)
-            self.onecmd(f"use {sel}")
-
-        except ValueError:
-            return super().default(line)
-
-        return super().default(line)
-
-    def precmd(self, line: str):
-        try:
-            rec = int(line)
-            if rec < len(self.rec_list):
-                pass
-            else:
-                pass
-
-        except ValueError:
-            pass
-
-        finally:
-            return super().precmd(line)
-
-    def postcmd(self, stop: bool, line: str) -> bool:
-        if not stop:
-            self.update_prompt()
-            self.setup_for_file()
-        return super().postcmd(stop, line)
-
-    def update_prompt(self):
-        self.prompt = f"(ucsinfer:{self.file_cursor}/{len(self.file_list)}) "
-
-    def setup_for_file(self):
-        if len(self.file_list) == 0:
-            print(" >> NO FILES!")
-        else:
-            file = self.file_list[self.file_cursor]
-            desc, recs = recommend_category(self.ctx, file)
-            self.onecmd('file')
-            print(f" >> {desc}")
-            self.print_recommendations(recs)
-
-    def print_recommendations(self, top_recs):
-        self.rec_list = []
-
-        cols, _ = get_terminal_size((80,20))
-
-        def print_one_rec(index, rec):
-            cat, subcat, exp = self.ctx.lookup_category(rec)
-            line = f" [{index:2}] {rec} - {cat} / {subcat} - {exp}"
-            if len(line) > cols - 3:
-                line = line[0:cols - 3] + "..."
-            print(line)
-
-        print("Suggested from description:")
-        for rec in top_recs:
-            print_one_rec(len(self.rec_list), rec)
-            self.rec_list.append(rec)
-
-        if len(self.history) > 0:
-            print("History:")
-            for rec in self.history:
-                print_one_rec(len(self.rec_list), rec)
-                self.rec_list.append(rec)
-
-    def do_about(self, line: str):
-        'Print information about recommendation NUMBER'
-        try:
-            picked = int(line)
-            if picked < len(self.rec_list):
-                cat, subcat, exp = \
-                    self.ctx.lookup_category(self.rec_list[picked])
-
-                print(f" CatID: {self.rec_list[picked]}")
-                print(f" Category: {cat}")
-                print(f" SubCategory: {subcat}")
-                print(f" Explanation: {exp}")
-
-        except ValueError:
-            print(f" *** Value \"{line}\" not recognized")
+    for cat_id, filepath in tqdm.tqdm(scan_list, unit='files',
+                                      file=sys.stderr):
+        if desc := ffmpeg_description(filepath):
+            table.writerow([cat_id, desc])
 
-    def do_use(self, line: str):
-        """Apply recomendation NUMBER to the current file and advance to the
-        next one"""
-
-        try:
-            picked = int(line)
-            print(f" :: Using {self.rec_list[picked]}")
-            self.do_next("")
-        except ValueError:
-            print(" *** Value \"{line}\" not recognized")
-
-    def do_file(self, _):
-        'Print info about the current file'
-        print("---")
-        if self.file_cursor < len(self.file_list):
-            path = self.file_list[self.file_cursor]
-            f = os.path.basename(path)
-            print(f" > {f}")
-        else:
-            print( " > No file")
-
-    def do_ls(self, _):
-        'Print list of all files in the buffer'
-        for file in self.file_list[self.file_cursor:] + \
-                self.file_list[0:self.file_cursor]:
-            f = os.path.basename(file)
-            print(f" > {f}")
-
-    def do_next(self, _):
-        'go to next file'
-        self.file_cursor += 1
-        self.file_cursor = self.file_cursor % len(self.file_list)
-        self.setup_for_file()
-
-    def do_prev(self, _):
-        'go to previous file'
-        self.file_cursor -= 1
-        self.file_cursor = self.file_cursor % len(self.file_list)
-        self.setup_for_file()
-
-    def do_bye(self, _):
-        'exit the program'
-        print("Exiting...")
-        return True
 
+@ucsinfer.command('finetune')
+def finetune():
+    """
+    Fine-tune a model with training data
+    """
+    pass
 
-def main():
+@ucsinfer.command('evaluate')
+@click.option('--offset', type=int, default=0)
+@click.option('--limit', type=int, default=-1)
+@click.argument('dataset', type=click.File('r', encoding='utf8'),
+                default='dataset.csv')
+def evaluate(dataset, offset, limit):
+    """
+    Use datasets to evaluate model performance
+    """
     model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
+    ctx = InferenceContext(model)
+    reader = csv.reader(dataset)
 
-    com = Commands()
-    com.file_list = sys.argv[1:]
-    com.ctx = InferenceContext(model=model)
+    results = []
+    for i, row in enumerate(tqdm.tqdm(reader)):
+        if i < offset:
+            continue
 
-    com.cmdloop()
+        if limit > 0 and i >= limit + offset:
+            break
+
+        cat_id, description = row
+        guesses = ctx.classify_text_ranked(description, limit=10)
+        if cat_id == guesses[0]:
+            results.append({'catid': cat_id, 'result': "TOP"})
+        elif cat_id in guesses[0:5]:
+            results.append({'catid': cat_id, 'result': "TOP_5"})
+        elif cat_id in guesses:
+            results.append({'catid': cat_id, 'result': "TOP_10"})
+        else:
+            results.append({'catid': cat_id, 'result': "MISS"})
+
+    total = len(results)
+    if total == 0:
+        print("No records evaluated.", file=sys.stderr)
+        return
+
+    total_top = len([x for x in results if x['result'] == 'TOP'])
+    total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
+    total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
+    cats = set([x['catid'] for x in results])
+    total_cats = len(cats)
+
+    miss_counts = []
+    for cat in cats:
+        miss_counts.append((cat, len([x for x in results
+                                      if x['catid'] == cat
+                                      and x['result'] == 'MISS'])))
+
+    miss_counts = sorted(miss_counts, key=lambda x: x[1])
+
+    print(" === RESULTS === ")
+
+    table = [
+        ["Total records in sample:", f"{total}"],
+        ["Top Result:", f"{total_top}",
+         f"{float(total_top)/float(total):.2%}"],
+        ["Top 5 Result:", f"{total_top_5}",
+         f"{float(total_top_5)/float(total):.2%}"],
+        ["Top 10 Result:", f"{total_top_10}",
+         f"{float(total_top_10)/float(total):.2%}"],
+        SEPARATING_LINE,
+        ["UCS category count:", f"{len(ctx.catlist)}"],
+        ["Total categories in sample:", f"{total_cats}",
+         f"{float(total_cats)/float(len(ctx.catlist)):.2%}"],
+        [f"Most missed category ({miss_counts[-1][0]}):",
+         f"{miss_counts[-1][1]}",
+         f"{float(miss_counts[-1][1])/float(total):.2%}"]
+    ]
+
+    print(tabulate(table, headers=['', 'n', 'pct']))
+
 
 if __name__ == '__main__':
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
     import warnings
     warnings.simplefilter(action='ignore', category=FutureWarning)
-    main()
+    ucsinfer()
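For reference, the file that `gather` writes and `evaluate` reads is a plain two-column CSV of CatID and description. A minimal sketch of consuming it outside the tool, assuming the default `dataset.csv` name:

```python
import csv

# Each row is (CatID, description), as written by the `gather` subcommand
with open('dataset.csv', encoding='utf8') as f:
    for cat_id, description in csv.reader(f):
        print(f"{cat_id}: {description}")
```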
diff --git a/ucsinfer/inference.py b/ucsinfer/inference.py
index 5750f84..a951411 100644
--- a/ucsinfer/inference.py
+++ b/ucsinfer/inference.py
@@ -32,6 +32,17 @@ class Ucs(NamedTuple):
                    subcategory=d['SubCategory'], explanations=d['Explanations'],
                    synonymns=d['Synonyms'])
 
+def load_ucs() -> list[Ucs]:
+    """Load the UCS category definitions bundled with the package."""
+    FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+    ucs_defs = os.path.join(FILE_ROOT_DIR, 'ucs-community', 'json',
+                            'en.json')
+
+    with open(ucs_defs, 'r') as f:
+        cats = json.load(f)
+
+    return [Ucs.from_dict(cat) for cat in cats]
+
 
 class InferenceContext:
     """
     Maintains caches and resources for UCS category inference.
@@ -44,15 +55,7 @@ class InferenceContext:
 
     @cached_property
     def catlist(self) -> list[Ucs]:
-        FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-        cats = []
-        ucs_defs = os.path.join(FILE_ROOT_DIR, 'ucs-community', 'json',
-                                'en.json')
-
-        with open(ucs_defs, 'r') as f:
-            cats = json.load(f)
-
-        return [Ucs.from_dict(cat) for cat in cats]
+        return load_ucs()
 
     @cached_property
     def embeddings(self) -> list[dict]:
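A quick sketch of exercising the newly extracted `load_ucs` on its own (this assumes the bundled `ucs-community/json/en.json` definitions are present, as the loader expects):

```python
from ucsinfer.inference import load_ucs

cats = load_ucs()  # list of Ucs records parsed from ucs-community/json/en.json
print(f"{len(cats)} categories loaded; first CatID: {cats[0].catid}")
```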
diff --git a/ucsinfer/util.py b/ucsinfer/util.py
new file mode 100644
index 0000000..8a3d0ed
--- /dev/null
+++ b/ucsinfer/util.py
@@ -0,0 +1,58 @@
+import subprocess
+import json
+from typing import NamedTuple, Optional
+from re import match
+
+
+def ffmpeg_description(path: str) -> Optional[str]:
+    """Read the 'comment' tag from a media file with ffprobe, if present."""
+    result = subprocess.run(['ffprobe', '-show_format', '-of',
+                             'json', path], capture_output=True)
+
+    try:
+        result.check_returncode()
+    except subprocess.CalledProcessError:
+        return None
+
+    stream = json.loads(result.stdout)
+    fmt = stream.get("format", None)
+    if fmt:
+        tags = fmt.get("tags", None)
+        if tags:
+            return tags.get("comment", None)
+
+
+class UcsNameComponents(NamedTuple):
+    cat_id: str
+    user_cat: str | None
+    vendor_cat: str | None
+    fx_name: str
+    creator: str | None
+    source: str | None
+    user_data: str | None
+
+
+def parse_ucs(basename: str,
+              catid_list: list[str]) -> Optional[UcsNameComponents]:
+    """Parse a UCS-style filename into its components, returning None if
+    the name does not follow the convention or the CatID is unknown."""
+
+    regexp1 = r"^(?P<CatID>[A-Za-z]+)(-(?P<UserCat>[^_]+))?_((?P<VendorCat>[^-]+)-)?(?P<FXName>[^_]+)"
+
+    regexp2 = r"(_(?P<CreatorID>[^_]+)(_(?P<SourceID>[^_]+)(_(?P<UserData>[^.]+))?)?)?"
+
+    regexp = regexp1 + regexp2
+
+    matches = match(regexp, basename)
+
+    if matches is None:
+        return None
+
+    if matches.group('CatID') not in catid_list:
+        return None
+
+    return UcsNameComponents(cat_id=matches.group('CatID'),
+                             user_cat=matches.group('UserCat'),
+                             vendor_cat=matches.group('VendorCat'),
+                             fx_name=matches.group('FXName'),
+                             creator=matches.group('CreatorID'),
+                             source=matches.group('SourceID'),
+                             user_data=matches.group('UserData'))
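To make the filename convention concrete, here is a small, illustrative use of `parse_ucs`; the one-element CatID list is a stand-in for the full list that `load_ucs` provides:

```python
from ucsinfer.util import parse_ucs

# "CatID_FXName_CreatorID_SourceID", per the UCS filename convention;
# the name and CatID list below are made up for the example
components = parse_ucs("AMBForst_Morning Birds_JSmith_LibA", ["AMBForst"])
if components:
    print(components.cat_id, components.fx_name)  # AMBForst Morning Birds
```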