# ucsinfer/ucsinfer/__main__.py

import os
import csv
import logging

import tqdm
import click
from tabulate import tabulate, SEPARATING_LINE

from .inference import InferenceContext, load_ucs
from .gather import build_sentence_class_dataset
from .recommend import print_recommendation
from .util import ffmpeg_description, parse_ucs

logger = logging.getLogger('ucsinfer')
logger.setLevel(logging.DEBUG)
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.WARNING)
formatter = logging.Formatter(
    '%(asctime)s. %(levelname)s %(name)s: %(message)s')
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)


@click.group(epilog="For more information see "
                    "<https://git.squad51.us/jamie/ucsinfer>")
@click.option('--verbose', '-v', flag_value=True, help='Verbose output')
@click.option('--model', type=str, metavar="<model-name>",
              default="paraphrase-multilingual-mpnet-base-v2",
              show_default=True,
              help="Select the sentence_transformer model to use")
@click.option('--no-model-cache', flag_value=True,
              help="Don't use local model cache")
@click.pass_context
def ucsinfer(ctx, verbose, no_model_cache, model):
    """
    Tools for applying UCS categories to sounds using large language models
    """
    if verbose:
        stream_handler.setLevel(logging.DEBUG)
        logger.info("Verbose logging is enabled")
    else:
        import warnings
        warnings.filterwarnings(
            action='ignore', module='torch', category=FutureWarning,
            message=r"`encoder_attention_mask` is deprecated.*")
        stream_handler.setLevel(logging.WARNING)

    ctx.ensure_object(dict)
    ctx.obj['model_cache'] = not no_model_cache
    ctx.obj['model_name'] = model

    if no_model_cache:
        logger.info("Model cache inhibited by config")

    logger.info(f"Using model {model}")
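

# Example invocations (illustrative; assumes a `ucsinfer` console entry
# point is installed, otherwise `python -m ucsinfer` works because this
# module is __main__.py):
#
#   ucsinfer -v recommend --text "glass shattering on concrete"
#   ucsinfer --model all-MiniLM-L6-v2 recommend ~/sfx/door_creak.wav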


@ucsinfer.command('recommend')
@click.option('--text', default=None,
              help="Recommend a category for given text instead of reading "
                   "from a file")
@click.argument('paths', nargs=-1, metavar='<paths>')
@click.option('--interactive', '-i', flag_value=True, default=False,
              help="After processing each path in <paths>, prompt for a "
                   "recommendation to accept, and then prepend the selection "
                   "to the file name.")
@click.option('-s', '--skip-ucs', flag_value=True, default=False,
              help="Skip files that already have a UCS category in their "
                   "name.")
@click.pass_context
def recommend(ctx, text, paths, interactive, skip_ucs):
    """
    Infer a UCS category for a text description.

    "Description" text metadata is extracted from audio files given as PATHS,
    or text can be provided directly using the "--text" option. The selected
    model is then used to classify the given text according to the synonyms
    and explanations defined for each UCS subcategory. A list of ranked
    subcategories is printed to the terminal for each PATH.
    """
    logger.debug("RECOMMEND mode")
    inference_ctx = InferenceContext(ctx.obj['model_name'],
                                     use_cached_model=ctx.obj['model_cache'])

    if text is not None:
        # --text replaces reading from files, so skip path processing
        print_recommendation(None, text, inference_ctx,
                             interactive_rename=False)
        return
    catlist = [x.catid for x in inference_ctx.catlist]
    for path in paths:
        basename = os.path.basename(path)
        if skip_ucs and parse_ucs(basename, catlist):
            continue

        text = ffmpeg_description(path)
        if not text:
            text = os.path.basename(path)

        while True:
            retval = print_recommendation(path, text, inference_ctx,
                                          interactive)
            if not retval:
                break

            if retval[0] is False:
                # Abort processing of all remaining paths
                return
            elif retval[1] is not None:
                # Retry classification with replacement description text
                text = retval[1]
                continue
            elif retval[2] is not None:
                # Accepted recommendation: prepend the CatID to the file name
                new_name = retval[2] + '_' + os.path.basename(path)
                new_path = os.path.join(os.path.dirname(path), new_name)
                print(f"Renaming {path} \n to {new_path}")
                os.rename(path, new_path)
                break


@ucsinfer.command('gather')
@click.option('--outfile', default='dataset', show_default=True)
@click.option('--ucs-data', flag_value=True,
              help="Create a dataset based on the UCS category explanations "
                   "and synonyms (PATHS will be ignored.)")
@click.argument('paths', nargs=-1)
def gather(paths, outfile, ucs_data):
    """
    Scan files to build a training dataset.

    The `gather` command builds a training dataset for fine-tuning the
    selected model. Description sentences and UCS categories are collected
    from '.wav' and '.flac' files on disk that have valid UCS filenames and
    assigned CatIDs, and this information is recorded into a HuggingFace
    dataset.

    Gather scans the filesystem in two passes: first, the directory tree is
    walked with os.walk and a list of filenames that meet the above naming
    criteria is compiled. Then each file on the list is scanned one-by-one
    with ffprobe to obtain its "description" metadata; if this isn't present,
    the parsed 'fxname' of the file becomes the description.
    """
    logger.debug("GATHER mode")
    types = ['.wav', '.flac']

    logger.debug("Loading category list...")
    ucs = load_ucs()
    scan_list = []
    catid_list = [cat.catid for cat in ucs]

    if ucs_data:
        paths = []

    for path in paths:
        logger.info(f"Scanning directory {path}...")
        for dirpath, _, filenames in os.walk(path):
            for filename in filenames:
                root, ext = os.path.splitext(filename)
                if ext in types and \
                        (ucs_components := parse_ucs(root, catid_list)) and \
                        not filename.startswith("._"):
                    scan_list.append((ucs_components.cat_id,
                                      os.path.join(dirpath, filename)))

    logger.info(f"Found {len(scan_list)} files to process.")
    def scan_metadata():
        # Yield (description, CatID) pairs read from the scanned files
        for pair in tqdm.tqdm(scan_list, unit='files'):
            if desc := ffmpeg_description(pair[1]):
                yield desc, str(pair[0])
            else:
                comps = parse_ucs(os.path.basename(pair[1]), catid_list)
                assert comps
                yield comps.fx_name, str(pair[0])

    def ucs_metadata():
        # Yield (description, CatID) pairs from the UCS definitions themselves
        for cat in ucs:
            yield cat.explanations, cat.catid
            yield ", ".join(cat.synonymns), cat.catid
    if ucs_data:
        dataset = build_sentence_class_dataset(ucs_metadata(), catid_list)
    else:
        dataset = build_sentence_class_dataset(scan_metadata(), catid_list)

    dataset.save_to_disk(outfile)
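

# The dataset written by `gather` can be reloaded with the HuggingFace
# `datasets` library (illustrative; the column layout depends on
# build_sentence_class_dataset):
#
#   from datasets import load_from_disk
#   ds = load_from_disk('dataset')
#   print(ds)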


@ucsinfer.command('finetune')
@click.pass_context
def finetune(ctx):
    """
    Fine-tune a model with training data
    """
    logger.debug("FINETUNE mode")
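
    # Not yet implemented. A minimal sketch of one possible approach with the
    # sentence-transformers fit() API, assuming a dataset saved by `gather`
    # with 'sentence' and integer 'label' columns (hypothetical names):
    #
    #   from datasets import load_from_disk
    #   from torch.utils.data import DataLoader
    #   from sentence_transformers import (SentenceTransformer, InputExample,
    #                                      losses)
    #
    #   ds = load_from_disk('dataset')
    #   examples = [InputExample(texts=[row['sentence']], label=row['label'])
    #               for row in ds]
    #   loader = DataLoader(examples, shuffle=True, batch_size=16)
    #   model = SentenceTransformer(ctx.obj['model_name'])
    #   loss = losses.BatchAllTripletLoss(model)
    #   model.fit(train_objectives=[(loader, loss)], epochs=1)
    #   model.save('ucsinfer-finetuned')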


@ucsinfer.command('evaluate')
@click.option('--offset', type=int, default=0, metavar="<int>",
              help='Skip this many records in the dataset before processing')
@click.option('--limit', type=int, default=-1, metavar="<int>",
              help='Process this many records and then exit')
@click.option('--no-foley', 'no_foley', flag_value=True, default=False,
              help="Ignore any data in the set with FOLYProp or FOLYFeet "
                   "category")
@click.argument('dataset', type=click.File('r', encoding='utf8'),
                default='dataset.csv')
@click.pass_context
def evaluate(ctx, dataset, offset, limit, no_foley):
    """
    Use datasets to evaluate model performance.

    The `evaluate` command reads the input DATASET file row by row and
    performs a classification of the given description against the selected
    model (either the default or the one given with the --model option), then
    checks whether the model inferred the correct category as given by the
    dataset.

    The model gives its top 10 possible categories for each description, and
    the results are tabulated according to (1) whether the top classification
    was correct, (2) whether the correct classification was in the top 5, or
    (3) whether it was in the top 10. The worst-performing category (the one
    with the most misses) is also reported, as is the category coverage: how
    many categories are present in the dataset.

    NOTE: Experimentation showed that foley items were generally classified
    according to their subject rather than whether or not they were foley,
    so these categories can be excluded with the --no-foley option.
    """
    logger.debug("EVALUATE mode")
    inference_context = InferenceContext(
        ctx.obj['model_name'], use_cached_model=ctx.obj['model_cache'])

    reader = csv.reader(dataset)
    results = []

    if offset > 0:
        logger.debug(f"Skipping {offset} records...")
    if limit > 0:
        logger.debug(f"Will only evaluate {limit} records...")

    # tqdm accepts total=None when the record count is unknown
    progress_bar = tqdm.tqdm(total=limit if limit > 0 else None,
                             desc="Processing dataset...",
                             unit="rec")

    for i, row in enumerate(reader):
        if i < offset:
            continue
        if limit > 0 and i >= limit + offset:
            break

        cat_id, description = row
        if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
            continue

        guesses = inference_context.classify_text_ranked(description,
                                                         limit=10)
        if cat_id == guesses[0]:
            results.append({'catid': cat_id, 'result': "TOP"})
        elif cat_id in guesses[0:5]:
            results.append({'catid': cat_id, 'result': "TOP_5"})
        elif cat_id in guesses:
            results.append({'catid': cat_id, 'result': "TOP_10"})
        else:
            results.append({'catid': cat_id, 'result': "MISS"})

        progress_bar.update(1)
    total = len(results)
    total_top = len([x for x in results if x['result'] == 'TOP'])
    total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
    total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])

    cats = set(x['catid'] for x in results)
    total_cats = len(cats)

    # Count misses per category to find the worst performer
    miss_counts = []
    for cat in cats:
        miss_counts.append(
            (cat, len([x for x in results
                       if x['catid'] == cat and x['result'] == 'MISS'])))

    miss_counts = sorted(miss_counts, key=lambda x: x[1])

    print(f"## Results for Model {ctx.obj['model_name']} ##\n")

    if no_foley:
        print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
    table = [
        ["Total records in sample:", f"{total}"],
        ["Top Result:", f"{total_top}",
         f"{float(total_top)/float(total):.2%}"],
        ["Top 5 Result:", f"{total_top_5}",
         f"{float(total_top_5)/float(total):.2%}"],
        ["Top 10 Result:", f"{total_top_10}",
         f"{float(total_top_10)/float(total):.2%}"],
        SEPARATING_LINE,
        ["UCS category count:", f"{len(inference_context.catlist)}"],
        ["Total categories in sample:", f"{total_cats}",
         f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
        [f"Most missed category ({miss_counts[-1][0]}):",
         f"{miss_counts[-1][1]}",
         f"{float(miss_counts[-1][1])/float(total):.2%}"],
    ]

    print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))


if __name__ == '__main__':
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    ucsinfer(obj={})