diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py
index 661b62a..72cde6e 100644
--- a/ucsinfer/__main__.py
+++ b/ucsinfer/__main__.py
@@ -1,11 +1,11 @@
 import os
-import csv
+# import csv
 import logging
-
+from subprocess import CalledProcessError
 import tqdm
 import click
-from tabulate import tabulate, SEPARATING_LINE
+# from tabulate import tabulate, SEPARATING_LINE
 
 from .inference import InferenceContext, load_ucs
 from .gather import build_sentence_class_dataset
@@ -30,8 +30,11 @@ logger.addHandler(stream_handler)
               help="Select the sentence_transformer model to use")
 @click.option('--no-model-cache', flag_value=True,
               help="Don't use local model cache")
+@click.option('--complete-ucs', flag_value=True, default=False,
+              help="Use the complete UCS category list. By default, 'FOLEY' "
+              "and 'ARCHIVED' UCS categories are excluded from all commands.")
 @click.pass_context
-def ucsinfer(ctx, verbose, no_model_cache, model):
+def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
     """
     Tools for applying UCS categories to sounds using large-language Models
     """
@@ -50,10 +53,17 @@ def ucsinfer(ctx, verbose, no_model_cache, model):
     ctx.ensure_object(dict)
     ctx.obj['model_cache'] = not no_model_cache
     ctx.obj['model_name'] = model
+    ctx.obj['complete_ucs'] = complete_ucs
 
     if no_model_cache:
         logger.info("Model cache inhibited by config")
 
+    if complete_ucs:
+        logger.info("Using complete UCS category list")
+    else:
+        logger.info("FOLEY and ARCHIVED UCS categories will be excluded. "
+                    "Pass --complete-ucs to include them.")
+
     logger.info(f"Using model {model}")
@@ -82,7 +92,8 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
     """
     logger.debug("RECOMMEND mode")
     inference_ctx = InferenceContext(ctx.obj['model_name'],
-                                     use_cached_model=ctx.obj['model_cache'])
+                                     use_cached_model=ctx.obj['model_cache'],
+                                     use_full_ucs=ctx.obj['complete_ucs'])
 
     if text is not None:
         print_recommendation(None, text, inference_ctx,
@@ -122,7 +133,8 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
              "on the UCS category explanations and synonymns (PATHS will "
              "be ignored.)")
 @click.argument('paths', nargs=-1)
-def gather(paths, out, ucs_data):
+@click.pass_context
+def gather(ctx, paths, out, ucs_data):
     """
     Scan files to build a training dataset
 
@@ -142,7 +154,7 @@ def gather(paths, out, ucs_data):
     types = ['.wav', '.flac']
 
     logger.debug(f"Loading category list...")
-    ucs = load_ucs()
+    ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
     scan_list = []
     catid_list = [cat.catid for cat in ucs]
 
@@ -152,32 +164,46 @@
         paths = []
 
     for path in paths:
-        logger.info(f"Scanning directory {path}...")
         for dirpath, _, filenames in os.walk(path):
+            logger.info(f"Walking directory {dirpath}")
             for filename in filenames:
                 root, ext = os.path.splitext(filename)
-                if ext in types and \
-                        (ucs_components := parse_ucs(root, catid_list)) and \
-                        not filename.startswith("._"):
-                    scan_list.append((ucs_components.cat_id,
-                                      os.path.join(dirpath, filename)))
+                if ext not in types or filename.startswith("._"):
+                    continue
+
+                if (ucs_components := parse_ucs(root, catid_list)):
+                    p = os.path.join(dirpath, filename)
+                    logger.info(f"Adding path to scan list {p}")
+                    scan_list.append((ucs_components.cat_id, p))
 
     logger.info(f"Found {len(scan_list)} files to process.")
 
     def scan_metadata():
         for pair in tqdm.tqdm(scan_list, unit='files'):
-            if desc := ffmpeg_description(pair[1]):
+            logger.info(f"Scanning file with ffprobe: {pair[1]}")
+            try:
+                desc = ffmpeg_description(pair[1])
+            except CalledProcessError as e:
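+                # ffprobe exits non-zero for files it cannot parse; log the
+                # error and skip this file rather than aborting the scan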
+                logger.error(f"ffprobe returned error {e.returncode}: "
+                             + e.stderr.decode(errors='replace'))
+                continue
+
+            if desc:
                 yield desc, str(pair[0])
             else:
                 comps = parse_ucs(os.path.basename(pair[1]), catid_list)
                 assert comps
                 yield comps.fx_name, str(pair[0])
 
+
     def ucs_metadata():
         for cat in ucs:
             yield cat.explanations, cat.catid
             yield ", ".join(cat.synonymns), cat.catid
 
+    logger.info("Building dataset...")
     if ucs_data:
         dataset = build_sentence_class_dataset(ucs_metadata(), catid_list)
     else:
@@ -202,9 +226,6 @@ def finetune(ctx):
              help='Skip this many records in the dataset before processing')
 @click.option('--limit', type=int, default=-1, metavar="",
              help='Process this many records and then exit')
-@click.option('--no-foley', 'no_foley', flag_value=True, default=False,
-              help="Ignore any data in the set with FOLYProp or FOLYFeet "
-              "category")
 @click.argument('dataset', type=click.File('r', encoding='utf8'),
                 default='dataset.csv')
 @click.pass_context
@@ -229,84 +250,88 @@ def evaluate(ctx, dataset, offset, limit, no_foley):
     foley, and so these categories can be excluded with the --no-foley
     option.
     """
     logger.debug("EVALUATE mode")
-    inference_context = InferenceContext(ctx.obj['model_name'],
-                                         use_cached_model=
-                                         ctx.obj['model_cache'])
-    reader = csv.reader(dataset)
-
-    results = []
-
-    if offset > 0:
-        logger.debug(f"Skipping {offset} records...")
-
-    if limit > 0:
-        logger.debug(f"Will only evaluate {limit} records...")
-
-    progress_bar = tqdm.tqdm(total=limit,
-                             desc="Processing dataset...",
-                             unit="rec")
-    for i, row in enumerate(reader):
-        if i < offset:
-            continue
-
-        if limit > 0 and i >= limit + offset:
-            break
-
-        cat_id, description = row
-        if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
-            continue
-
-        guesses = inference_context.classify_text_ranked(description, limit=10)
-        if cat_id == guesses[0]:
-            results.append({'catid': cat_id, 'result': "TOP"})
-        elif cat_id in guesses[0:5]:
-            results.append({'catid': cat_id, 'result': "TOP_5"})
-        elif cat_id in guesses:
-            results.append({'catid': cat_id, 'result': "TOP_10"})
-        else:
-            results.append({'catid': cat_id, 'result': "MISS"})
-
-        progress_bar.update(1)
-
-    total = len(results)
-    total_top = len([x for x in results if x['result'] == 'TOP'])
-    total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
-    total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
-
-    cats = set([x['catid'] for x in results])
-    total_cats = len(cats)
-
-    miss_counts = []
-    for cat in cats:
-        miss_counts.append(
-            (cat, len([x for x in results
-                       if x['catid'] == cat and x['result'] == 'MISS'])))
-
-    miss_counts = sorted(miss_counts, key=lambda x: x[1])
-
-    print(f"## Results for Model {model} ##\n")
-
-    if no_foley:
-        print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
-
-    table = [
-        ["Total records in sample:", f"{total}"],
-        ["Top Result:", f"{total_top}",
-         f"{float(total_top)/float(total):.2%}"],
-        ["Top 5 Result:", f"{total_top_5}",
-         f"{float(total_top_5)/float(total):.2%}"],
-        ["Top 10 Result:", f"{total_top_10}",
-         f"{float(total_top_10)/float(total):.2%}"],
-        SEPARATING_LINE,
-        ["UCS category count:", f"{len(inference_context.catlist)}"],
-        ["Total categories in sample:", f"{total_cats}",
-         f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
-        [f"Most missed category ({miss_counts[-1][0]}):",
-         f"{miss_counts[-1][1]}",
-         f"{float(miss_counts[-1][1])/float(total):.2%}"]
-    ]
-
-    print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
+    logger.warning("Model evaluation is not currently implemented")
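+
+    # The previous implementation is kept below for reference: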
+    # inference_context = InferenceContext(
+    #     ctx.obj['model_name'], use_cached_model=ctx.obj['model_cache'],
+    #     use_full_ucs=ctx.obj['complete_ucs'])
+    #
+    # reader = csv.reader(dataset)
+    #
+    # results = []
+    #
+    # if offset > 0:
+    #     logger.debug(f"Skipping {offset} records...")
+    #
+    # if limit > 0:
+    #     logger.debug(f"Will only evaluate {limit} records...")
+    #
+    # progress_bar = tqdm.tqdm(total=limit,
+    #                          desc="Processing dataset...",
+    #                          unit="rec")
+    # for i, row in enumerate(reader):
+    #     if i < offset:
+    #         continue
+    #
+    #     if limit > 0 and i >= limit + offset:
+    #         break
+    #
+    #     cat_id, description = row
+    #     if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
+    #         continue
+    #
+    #     guesses = inference_context.classify_text_ranked(description, limit=10)
+    #     if cat_id == guesses[0]:
+    #         results.append({'catid': cat_id, 'result': "TOP"})
+    #     elif cat_id in guesses[0:5]:
+    #         results.append({'catid': cat_id, 'result': "TOP_5"})
+    #     elif cat_id in guesses:
+    #         results.append({'catid': cat_id, 'result': "TOP_10"})
+    #     else:
+    #         results.append({'catid': cat_id, 'result': "MISS"})
+    #
+    #     progress_bar.update(1)
+    #
+    # total = len(results)
+    # total_top = len([x for x in results if x['result'] == 'TOP'])
+    # total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
+    # total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
+    #
+    # cats = set([x['catid'] for x in results])
+    # total_cats = len(cats)
+    #
+    # miss_counts = []
+    # for cat in cats:
+    #     miss_counts.append(
+    #         (cat, len([x for x in results
+    #                    if x['catid'] == cat and x['result'] == 'MISS'])))
+    #
+    # miss_counts = sorted(miss_counts, key=lambda x: x[1])
+    #
+    # print(f"## Results for Model {model} ##\n")
+    #
+    # if no_foley:
+    #     print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
+    #
+    # table = [
+    #     ["Total records in sample:", f"{total}"],
+    #     ["Top Result:", f"{total_top}",
+    #      f"{float(total_top)/float(total):.2%}"],
+    #     ["Top 5 Result:", f"{total_top_5}",
+    #      f"{float(total_top_5)/float(total):.2%}"],
+    #     ["Top 10 Result:", f"{total_top_10}",
+    #      f"{float(total_top_10)/float(total):.2%}"],
+    #     SEPARATING_LINE,
+    #     ["UCS category count:", f"{len(inference_context.catlist)}"],
+    #     ["Total categories in sample:", f"{total_cats}",
+    #      f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
+    #     [f"Most missed category ({miss_counts[-1][0]}):",
+    #      f"{miss_counts[-1][1]}",
+    #      f"{float(miss_counts[-1][1])/float(total):.2%}"]
+    # ]
+    #
+    # print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
 
 
 if __name__ == '__main__':
diff --git a/ucsinfer/gather.py b/ucsinfer/gather.py
new file mode 100644
index 0000000..e355ecb
--- /dev/null
+++ b/ucsinfer/gather.py
@@ -0,0 +1,35 @@
+from datasets import Dataset, Features, Value, ClassLabel
+
+from typing import Generator, Any
+
+# https://www.sbert.net/docs/sentence_transformer/loss_overview.html
+
+def build_sentence_class_dataset(
+        records: Generator[tuple[str, str], Any, None], catlist: list[str]) -> Dataset:
+    """
+    Create a new dataset for `records` which contains (sentence, class) pairs.
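+    Each string catid in `records` is encoded to an integer index into
+    `catlist` by the `ClassLabel` feature when the dataset is built.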
+
+    :param records: a generator for records that generates pairs of
+        (sentence, catid)
+    :returns: A dataset with two columns: (sentence, ClassLabel index of catid)
+    """
+
+    labels = ClassLabel(names=catlist)
+
+    items: list[dict] = []
+    for obj in records:
+        items.append({'sentence': obj[0], 'class': obj[1]})
+
+
+    return Dataset.from_list(items, features=Features({'sentence': Value('string'),
+                                                       'class': labels}))
+
+
+# def build_sentence_anchor_dataset() -> Dataset:
+#     """
+#     Create a new dataset for `records` which contains (sentence, anchor) pairs.
+#     """
+#     pass
+
diff --git a/ucsinfer/inference.py b/ucsinfer/inference.py
index 353b8a3..aab3f7d 100644
--- a/ucsinfer/inference.py
+++ b/ucsinfer/inference.py
@@ -34,7 +34,7 @@ class Ucs(NamedTuple):
                    explanations=d['Explanations'], synonymns=d['Synonyms'])
 
 
-def load_ucs() -> list[Ucs]:
+def load_ucs(full_ucs: bool = True) -> list[Ucs]:
     FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
     cats = []
     ucs_defs = os.path.join(FILE_ROOT_DIR, 'ucs-community', 'json',
@@ -43,7 +43,13 @@ def load_ucs() -> list[Ucs]:
     with open(ucs_defs, 'r') as f:
         cats = json.load(f)
 
-    return [Ucs.from_dict(cat) for cat in cats]
+    ucs = [Ucs.from_dict(cat) for cat in cats]
+
+    if full_ucs:
+        return ucs
+    else:
+        return [cat for cat in ucs
+                if cat.category not in ['FOLEY', 'ARCHIVED']]
 
 
 class InferenceContext:
@@ -54,8 +60,11 @@ class InferenceContext:
     model: SentenceTransformer
     model_name: str
 
-    def __init__(self, model_name: str, use_cached_model: bool = True):
+    def __init__(self, model_name: str, use_cached_model: bool = True,
+                 use_full_ucs: bool = False):
         self.model_name = model_name
+        self.use_full_ucs = use_full_ucs
+
         cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
                                                 'Squad 51')
@@ -75,7 +84,7 @@ class InferenceContext:
 
     @cached_property
     def catlist(self) -> list[Ucs]:
-        return load_ucs()
+        return load_ucs(full_ucs=self.use_full_ucs)
 
     @cached_property
     def embeddings(self) -> list[dict]:
diff --git a/ucsinfer/util.py b/ucsinfer/util.py
index a2bd0f0..334b15a 100644
--- a/ucsinfer/util.py
+++ b/ucsinfer/util.py
@@ -9,10 +9,8 @@ def ffmpeg_description(path: str) -> Optional[str]:
     result = subprocess.run(['ffprobe', '-show_format', '-of', 'json', path],
                             capture_output=True)
 
-    try:
-        result.check_returncode()
-    except:
-        return None
+    # Raises CalledProcessError on failure; gather catches this and skips
+    result.check_returncode()
 
     stream = json.loads(result.stdout)
     fmt = stream.get("format", None)