diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py index 5ebf52c..1fee0dc 100644 --- a/ucsinfer/__main__.py +++ b/ucsinfer/__main__.py @@ -1,14 +1,12 @@ import os import logging from itertools import chain -import csv - -from typing import Generator import click from .inference import InferenceContext, load_ucs -from .gather import (build_sentence_class_dataset, print_dataset_stats, +from .import_csv import csv_to_data +from .gather import (build_sentence_class_dataset, print_dataset_stats, ucs_definitions_generator, scan_metadata, walk_path) from .recommend import print_recommendation from .util import ffmpeg_description, parse_ucs @@ -18,18 +16,19 @@ logger.setLevel(logging.DEBUG) stream_handler = logging.StreamHandler() stream_handler.setLevel(logging.WARN) formatter = logging.Formatter( - '%(asctime)s. %(levelname)s %(name)s: %(message)s') + '%(asctime)s. %(levelname)s %(name)s: %(message)s') stream_handler.setFormatter(formatter) logger.addHandler(stream_handler) + @click.group(epilog="For more information see " "") @click.option('--verbose', '-v', flag_value=True, help='Verbose output') -@click.option('--model', type=str, metavar="", +@click.option('--model', type=str, metavar="", default="paraphrase-multilingual-mpnet-base-v2", - show_default=True, + show_default=True, help="Select the sentence_transformer model to use") -@click.option('--no-model-cache', flag_value=True, +@click.option('--no-model-cache', flag_value=True, help="Don't use local model cache") @click.option('--complete-ucs', flag_value=True, default=False, help="Use all UCS categories. By default, all 'FOLEY' and " @@ -39,15 +38,15 @@ def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs): """ Tools for applying UCS categories to sounds using large-language Models """ - + if verbose: stream_handler.setLevel(logging.DEBUG) logger.info("Verbose logging is enabled") else: import warnings warnings.filterwarnings( - action='ignore', module='torch', category=FutureWarning, - message=r"`encoder_attention_mask` is deprecated.*") + action='ignore', module='torch', category=FutureWarning, + message=r"`encoder_attention_mask` is deprecated.*") stream_handler.setLevel(logging.WARNING) @@ -77,11 +76,11 @@ def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs): help="Recommend a category for given text instead of reading " "from a file") @click.argument('paths', nargs=-1, metavar='') -@click.option('--interactive','-i', flag_value=True, default=False, +@click.option('--interactive', '-i', flag_value=True, default=False, help="After processing each path in , prompt for a " "recommendation to accept, and then prepend the selection to " "the file name.") -@click.option('-s', '--skip-ucs', flag_value=True, default=False, +@click.option('-s', '--skip-ucs', flag_value=True, default=False, help="Skip files that already have a UCS category in their " "name.") @click.pass_context @@ -96,39 +95,40 @@ def recommend(ctx, text, paths, interactive, skip_ucs): of ranked subcategories is printed to the terminal for each PATH. """ logger.debug("RECOMMEND mode") - inference_ctx = InferenceContext(ctx.obj['model_name'], + inference_ctx = InferenceContext(ctx.obj['model_name'], use_cached_model=ctx.obj['model_cache'], use_full_ucs=ctx.obj['complete_ucs']) if text is not None: - print_recommendation(None, text, inference_ctx, + print_recommendation(None, text, inference_ctx, interactive_rename=False) - + catlist = [x.catid for x in inference_ctx.catlist] for path in paths: - _, ext = os.path.splitext(path) - + _, ext = os.path.splitext(path) + if ext not in (".wav", ".flac"): continue basename = os.path.basename(path) if skip_ucs and parse_ucs(basename, catlist): - continue + continue text = ffmpeg_description(path) if not text: text = os.path.basename(path) while True: - retval = print_recommendation(path, text, inference_ctx, interactive) + retval = print_recommendation( + path, text, inference_ctx, interactive) if not retval: break if retval[0] is False: return elif retval[1] is not None: text = retval[1] - continue + continue elif retval[2] is not None: new_name = retval[2] + '_' + os.path.basename(path) new_path = os.path.join(os.path.dirname(path), new_name) @@ -136,34 +136,12 @@ def recommend(ctx, text, paths, interactive, skip_ucs): os.rename(path, new_path) break -def csv_to_data(paths, description_key, filename_key, catid_list) -> Generator[tuple[str, str], None, None]: - """ - Accepts a list of paths and returns an iterator of (sentence, class) - tuples. - """ - for path in paths: - with open(path, 'r') as f: - records = csv.DictReader(f) - - assert filename_key in records.fieldnames, \ - (f"Filename key `{filename_key}` not present in file " - "{path}") - - assert description_key in records.fieldnames, \ - (f"Description key `{description_key}` not present in " - "file {path}") - - for record in records: - ucs_comps = parse_ucs(record[filename_key], catid_list) - if ucs_comps: - yield (record[description_key], ucs_comps.cat_id) - @ucsinfer.command('csv') -@click.option('--filename-col', default="FileName", +@click.option('--filename-col', default="FileName", help="Heading or index of the column containing filenames", show_default=True) -@click.option('--description-col', default="TrackDescription", +@click.option('--description-col', default="TrackDescription", help="Heading or index of the column containing descriptions", show_default=True) @click.option('--out', default='dataset/', show_default=True) @@ -188,8 +166,8 @@ def import_csv(ctx, paths: list[str], out, filename_col, description_col): logger.info("Building dataset from csv...") dataset = build_sentence_class_dataset( - chain(csv_to_data(paths, description_col, filename_col, catid_list), - ucs_definitions_generator(ucs)),catid_list) + chain(csv_to_data(paths, description_col, filename_col, catid_list), + ucs_definitions_generator(ucs)), catid_list) logger.info(f"Saving dataset to disk at {out}") print_dataset_stats(dataset, catid_list) @@ -203,12 +181,12 @@ def import_csv(ctx, paths: list[str], out, filename_col, description_col): def gather(ctx, paths, out): """ Scan training data from audio files - + `gather` is used to build a training dataset for finetuning the selected model. Description sentences and UCS categories are collected from '.wav' and '.flac' files on-disk that have valid UCS filenames and assigned CatIDs, and this information is recorded into a HuggingFace dataset. - + Gather scans the filesystem in two passes: first, the directory tree is walked by os.walk and a list of filenames that meet the above name criteria is compiled. After this list is compiled, each file is scanned one-by-one @@ -220,7 +198,7 @@ def gather(ctx, paths, out): logger.debug(f"Loading category list...") ucs = load_ucs(full_ucs=ctx.obj['complete_ucs']) - scan_list: list[tuple[str,str]] = [] + scan_list: list[tuple[str, str]] = [] catid_list = [cat.catid for cat in ucs] for path in paths: @@ -231,14 +209,15 @@ def gather(ctx, paths, out): logger.info("Building dataset files...") dataset = build_sentence_class_dataset( - chain(scan_metadata(scan_list, catid_list), - ucs_definitions_generator(ucs)), - catid_list) - + chain(scan_metadata(scan_list, catid_list), + ucs_definitions_generator(ucs)), + catid_list) + logger.info(f"Saving dataset to disk at {out}") print_dataset_stats(dataset, catid_list) dataset.save_to_disk(out) + @ucsinfer.command('qualify') def qualify(): """ @@ -260,7 +239,6 @@ def finetune(ctx): logger.debug("FINETUNE mode") - @ucsinfer.command('evaluate') @click.argument('dataset', default='dataset/') @click.pass_context @@ -270,7 +248,6 @@ def evaluate(ctx, dataset, offset, limit): """ logger.debug("EVALUATE mode") logger.warning("Model evaluation is not currently implemented") - if __name__ == '__main__': diff --git a/ucsinfer/import_csv.py b/ucsinfer/import_csv.py new file mode 100644 index 0000000..9a21318 --- /dev/null +++ b/ucsinfer/import_csv.py @@ -0,0 +1,28 @@ +import csv + +from typing import Generator + +from .util import parse_ucs + + +def csv_to_data(paths, description_key, filename_key, catid_list) -> Generator[tuple[str, str], None, None]: + """ + Accepts a list of paths and returns an iterator of (sentence, class) + tuples. + """ + for path in paths: + with open(path, 'r') as f: + records = csv.DictReader(f) + + assert filename_key in records.fieldnames, \ + (f"Filename key `{filename_key}` not present in file " + "{path}") + + assert description_key in records.fieldnames, \ + (f"Description key `{description_key}` not present in " + "file {path}") + + for record in records: + ucs_comps = parse_ucs(record[filename_key], catid_list) + if ucs_comps: + yield (record[description_key], ucs_comps.cat_id)