# ucsinfer — command-line tools for applying UCS categories to sounds.
import os
|
|
import csv
|
|
import logging
|
|
|
|
|
|
import tqdm
|
|
import click
|
|
from tabulate import tabulate, SEPARATING_LINE
|
|
|
|
from .inference import InferenceContext, load_ucs
|
|
from .gather import build_sentence_class_dataset
|
|
from .recommend import print_recommendation
|
|
from .util import ffmpeg_description, parse_ucs
|
|
|
|
# Package-level logger. The logger itself accepts everything at DEBUG;
# the stream handler below is the gate that decides what reaches the
# console (the CLI entry point adjusts it based on --verbose).
logger = logging.getLogger('ucsinfer')
logger.setLevel(logging.DEBUG)

# Console handler defaults to WARN (alias of WARNING) until the CLI
# entry point raises or lowers it.
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.WARN)
formatter = logging.Formatter(
    '%(asctime)s. %(levelname)s %(name)s: %(message)s')
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
|
|
|
|
@click.group(epilog="For more information see "
             "<https://git.squad51.us/jamie/ucsinfer>")
@click.option('--verbose', '-v', flag_value=True, help='Verbose output')
@click.option('--model', type=str, metavar="<model-name>",
              default="paraphrase-multilingual-mpnet-base-v2",
              show_default=True,
              help="Select the sentence_transformer model to use")
@click.option('--no-model-cache', flag_value=True,
              help="Don't use local model cache")
@click.pass_context
def ucsinfer(ctx, verbose, no_model_cache, model):
    """
    Tools for applying UCS categories to sounds using large-language Models
    """
    if verbose:
        stream_handler.setLevel(logging.DEBUG)
        logger.info("Verbose logging is enabled")
    else:
        # In quiet mode, suppress a known-noisy torch deprecation warning
        # so it doesn't drown normal output.
        import warnings
        warnings.filterwarnings(
            action='ignore', module='torch', category=FutureWarning,
            message=r"`encoder_attention_mask` is deprecated.*")

        stream_handler.setLevel(logging.WARNING)

    # Shared state read by the subcommands via ctx.obj.
    ctx.ensure_object(dict)
    ctx.obj['model_cache'] = not no_model_cache
    ctx.obj['model_name'] = model

    if no_model_cache:
        logger.info("Model cache inhibited by config")

    # BUG FIX: the original logged the plain string "Using model {model}"
    # (missing f-prefix), so the literal placeholder text was emitted.
    # Lazy %-style args are also preferred for logging calls.
    logger.info("Using model %s", model)
|
|
|
|
|
|
@ucsinfer.command('recommend')
@click.option('--text', default=None,
              help="Recommend a category for given text instead of reading "
              "from a file")
@click.argument('paths', nargs=-1, metavar='<paths>')
@click.option('--interactive','-i', flag_value=True, default=False,
              help="After processing each path in <paths>, prompt for a "
              "recommendation to accept, and then prepend the selection to "
              "the file name.")
@click.option('-s', '--skip-ucs', flag_value=True, default=False,
              help="Skip files that already have a UCS category in their "
              "name.")
@click.pass_context
def recommend(ctx, text, paths, interactive, skip_ucs):
    """
    Infer a UCS category for a text description

    "Description" text metadata is extracted from audio files given as PATHS,
    or text can be provided directly using the "--text" option. The selected
    model is then used to attempt to classify the given text according to
    the synonyms and explanations defined for each UCS subcategory. A list
    of ranked subcategories is printed to the terminal for each PATH.
    """
    logger.debug("RECOMMEND mode")
    inference_ctx = InferenceContext(ctx.obj['model_name'],
                                     use_cached_model=ctx.obj['model_cache'])

    # Direct-text mode: classify the given text once, with no rename prompt.
    # Note: execution still falls through to the PATHS loop below, so
    # --text combined with PATHS processes both.
    if text is not None:
        print_recommendation(None, text, inference_ctx,
                             interactive_rename=False)

    # All known CatIDs — used to detect files that already carry a UCS name.
    catlist = [x.catid for x in inference_ctx.catlist]

    for path in paths:
        basename = os.path.basename(path)
        if skip_ucs and parse_ucs(basename, catlist):
            continue

        # Prefer the file's embedded "description" metadata; fall back to
        # the bare filename when none is present.
        text = ffmpeg_description(path)
        if not text:
            text = os.path.basename(path)

        # Per-path interaction loop. Assumed retval protocol (inferred from
        # usage here — confirm against print_recommendation):
        #   falsy            -> done with this path
        #   retval[0] False  -> user quit; stop the whole command
        #   retval[1] set    -> replacement description text; re-run
        #   retval[2] set    -> accepted CatID; prepend it to the filename
        while True:
            retval = print_recommendation(path, text, inference_ctx, interactive)
            if not retval:
                break
            if retval[0] is False:
                return
            elif retval[1] is not None:
                text = retval[1]
                continue
            elif retval[2] is not None:
                new_name = retval[2] + '_' + os.path.basename(path)
                new_path = os.path.join(os.path.dirname(path), new_name)
                print(f"Renaming {path} \n to {new_path}")
                os.rename(path, new_path)
                break
|
|
|
|
|
|
@ucsinfer.command('gather')
@click.option('--outfile', default='dataset', show_default=True)
@click.option('--ucs-data', flag_value=True, help="Create a dataset based "
              "on the UCS category explanations and synonymns (PATHS will "
              "be ignored.)")
@click.argument('paths', nargs=-1)
def gather(paths, outfile, ucs_data):
    """
    Scan files to build a training dataset

    The `gather` is used to build a training dataset for finetuning the
    selected model. Description sentences and UCS categories are collected from
    '.wav' and '.flac' files on-disk that have valid UCS filenames and assigned
    CatIDs, and this information is recorded into a HuggingFace dataset.

    Gather scans the filesystem in two passes: first, the directory tree is
    walked by os.walk and a list of filenames that meet the above name criteria
    is compiled. After this list is compiled, each file is scanned one-by-one
    with ffprobe to obtain its "description" metadata; if this isn't present,
    the parsed 'fxname' of the file becomes the description.
    """
    logger.debug("GATHER mode")

    # Only these audio container formats are considered.
    types = ['.wav', '.flac']

    logger.debug("Loading category list...")
    ucs = load_ucs()

    scan_list = []
    catid_list = [cat.catid for cat in ucs]

    if ucs_data:
        # Dataset is built from the UCS definitions themselves; any PATHS
        # given on the command line are ignored (see --ucs-data help).
        paths = []

    # Pass 1: walk the tree, keeping (catid, path) for audio files whose
    # names parse as UCS, skipping macOS "._" resource-fork artifacts.
    for path in paths:
        logger.info("Scanning directory %s...", path)
        for dirpath, _, filenames in os.walk(path):
            for filename in filenames:
                root, ext = os.path.splitext(filename)
                if ext in types and \
                        (ucs_components := parse_ucs(root, catid_list)) and \
                        not filename.startswith("._"):
                    scan_list.append((ucs_components.cat_id,
                                      os.path.join(dirpath, filename)))

    logger.info("Found %d files to process.", len(scan_list))

    def scan_metadata():
        # Pass 2: yield (sentence, catid) pairs from embedded "description"
        # metadata, falling back to the parsed fx_name when absent.
        for catid, filepath in tqdm.tqdm(scan_list, unit='files'):
            if desc := ffmpeg_description(filepath):
                yield desc, str(catid)
            else:
                comps = parse_ucs(os.path.basename(filepath), catid_list)
                assert comps
                yield comps.fx_name, str(catid)

    def ucs_metadata():
        # Two training sentences per category: its explanations, and its
        # comma-joined synonyms.
        for cat in ucs:
            yield cat.explanations, cat.catid
            yield ", ".join(cat.synonymns), cat.catid

    if ucs_data:
        dataset = build_sentence_class_dataset(ucs_metadata(), catid_list)
    else:
        dataset = build_sentence_class_dataset(scan_metadata(), catid_list)

    dataset.save_to_disk(outfile)
|
|
|
|
|
|
@ucsinfer.command('finetune')
@click.pass_context
def finetune(ctx):
    """
    Fine-tune a model with training data
    """
    # NOTE(review): stub — no fine-tuning is implemented yet; the command
    # only logs that it was invoked.
    logger.debug("FINETUNE mode")
|
|
|
|
|
|
|
|
@ucsinfer.command('evaluate')
@click.option('--offset', type=int, default=0, metavar="<int>",
              help='Skip this many records in the dataset before processing')
@click.option('--limit', type=int, default=-1, metavar="<int>",
              help='Process this many records and then exit')
@click.option('--no-foley', 'no_foley', flag_value=True, default=False,
              help="Ignore any data in the set with FOLYProp or FOLYFeet "
              "category")
@click.argument('dataset', type=click.File('r', encoding='utf8'),
                default='dataset.csv')
@click.pass_context
def evaluate(ctx, dataset, offset, limit, no_foley):
    """
    Use datasets to evaluate model performance

    The `evaluate` command reads the input DATASET file row by row and
    performs a classification of the given description against the selected
    model (either the default or using the --model option). The command then
    checks if the model inferred the correct category as given by the dataset.

    The model gives its top 10 possible categories for a given description,
    and the results are tabulated according to (1) whether the top
    classification was correct, (2) whether the correct classification was in
    the top 5, or (3) whether it was in the top 10. The worst-performing
    category, the one with the most misses, is also reported as well as the
    category coverage, how many categories are present in the dataset.

    NOTE: With experimentation it was found that foley items generally were
    classified according to their subject and not whether or not they were
    foley, and so these categories can be excluded with the --no-foley option.
    """
    logger.debug("EVALUATE mode")
    inference_context = InferenceContext(ctx.obj['model_name'],
                                         use_cached_model=
                                         ctx.obj['model_cache'])
    reader = csv.reader(dataset)

    results = []

    if offset > 0:
        logger.debug("Skipping %d records...", offset)

    if limit > 0:
        logger.debug("Will only evaluate %d records...", limit)

    progress_bar = tqdm.tqdm(total=limit,
                             desc="Processing dataset...",
                             unit="rec")
    for i, row in enumerate(reader):
        if i < offset:
            continue

        if limit > 0 and i >= limit + offset:
            break

        # Each row is (catid, description).
        cat_id, description = row
        if no_foley and cat_id in ('FOLYProp', 'FOLYFeet'):
            continue

        # Bucket each record by where the true category landed in the
        # model's top-10 ranking.
        guesses = inference_context.classify_text_ranked(description, limit=10)
        if cat_id == guesses[0]:
            outcome = "TOP"
        elif cat_id in guesses[0:5]:
            outcome = "TOP_5"
        elif cat_id in guesses:
            outcome = "TOP_10"
        else:
            outcome = "MISS"
        results.append({'catid': cat_id, 'result': outcome})

        progress_bar.update(1)

    # BUG FIX: the progress bar was never closed, which can garble
    # subsequent terminal output.
    progress_bar.close()

    total = len(results)

    # BUG FIX (robustness): previously an empty result set caused a
    # ZeroDivisionError in the percentage columns and an IndexError on
    # miss_counts[-1]. Bail out with a message instead.
    if total == 0:
        print("No records were evaluated; nothing to report.")
        return

    total_top = len([x for x in results if x['result'] == 'TOP'])
    total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
    total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])

    cats = set(x['catid'] for x in results)
    total_cats = len(cats)

    # Per-category miss totals, sorted ascending so the worst is last.
    miss_counts = []
    for cat in cats:
        miss_counts.append(
            (cat, len([x for x in results
                       if x['catid'] == cat and x['result'] == 'MISS'])))

    miss_counts = sorted(miss_counts, key=lambda x: x[1])

    # BUG FIX: original referenced an undefined name `model` here
    # (NameError at runtime); the model name lives in ctx.obj.
    print(f"## Results for Model {ctx.obj['model_name']} ##\n")

    if no_foley:
        print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")

    table = [
        ["Total records in sample:", f"{total}"],
        ["Top Result:", f"{total_top}",
         f"{total_top / total:.2%}"],
        ["Top 5 Result:", f"{total_top_5}",
         f"{total_top_5 / total:.2%}"],
        ["Top 10 Result:", f"{total_top_10}",
         f"{total_top_10 / total:.2%}"],
        SEPARATING_LINE,
        ["UCS category count:", f"{len(inference_context.catlist)}"],
        ["Total categories in sample:", f"{total_cats}",
         f"{total_cats / len(inference_context.catlist):.2%}"],
        [f"Most missed category ({miss_counts[-1][0]}):",
         f"{miss_counts[-1][1]}",
         f"{miss_counts[-1][1] / total:.2%}"]
    ]

    print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
|
|
|
|
|
|
if __name__ == '__main__':
    # NOTE(review): presumably set to silence the HuggingFace tokenizers
    # fork-parallelism warning — confirm against sentence-transformers docs.
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'

    ucsinfer(obj={})
|