@@ -5,7 +5,8 @@ from itertools import chain
 import click

 from .inference import InferenceContext, load_ucs
-from .gather import (build_sentence_class_dataset, print_dataset_stats,
+from .import_csv import csv_to_data
+from .gather import (build_sentence_class_dataset, print_dataset_stats,
                      ucs_definitions_generator, scan_metadata, walk_path)
 from .recommend import print_recommendation
 from .util import ffmpeg_description, parse_ucs
@@ -15,18 +16,19 @@ logger.setLevel(logging.DEBUG)
 stream_handler = logging.StreamHandler()
 stream_handler.setLevel(logging.WARN)
 formatter = logging.Formatter(
-    '%(asctime)s. %(levelname)s %(name)s: %(message)s')
+    '%(asctime)s. %(levelname)s %(name)s: %(message)s')
 stream_handler.setFormatter(formatter)
 logger.addHandler(stream_handler)

+
 @click.group(epilog="For more information see "
              "<https://git.squad51.us/jamie/ucsinfer>")
 @click.option('--verbose', '-v', flag_value=True, help='Verbose output')
-@click.option('--model', type=str, metavar="<model-name>",
+@click.option('--model', type=str, metavar="<model-name>",
               default="paraphrase-multilingual-mpnet-base-v2",
-              show_default=True,
+              show_default=True,
               help="Select the sentence_transformer model to use")
-@click.option('--no-model-cache', flag_value=True,
+@click.option('--no-model-cache', flag_value=True,
               help="Don't use local model cache")
 @click.option('--complete-ucs', flag_value=True, default=False,
               help="Use all UCS categories. By default, all 'FOLEY' and "
@@ -36,21 +38,21 @@ def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
     """
     Tools for applying UCS categories to sounds using large-language Models
     """
-
+
     if verbose:
         stream_handler.setLevel(logging.DEBUG)
         logger.info("Verbose logging is enabled")
     else:
         import warnings
         warnings.filterwarnings(
-            action='ignore', module='torch', category=FutureWarning,
-            message=r"`encoder_attention_mask` is deprecated.*")
+            action='ignore', module='torch', category=FutureWarning,
+            message=r"`encoder_attention_mask` is deprecated.*")

         stream_handler.setLevel(logging.WARNING)

     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-    logger.info("Setting TOKENIZERS_PARALLELISM environment variable to `false"
-                " explicitly")
+    logger.info("Setting TOKENIZERS_PARALLELISM environment variable to "
+                "`false` explicitly")

     ctx.ensure_object(dict)
     ctx.obj['model_cache'] = not no_model_cache
@@ -74,11 +76,11 @@ def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
               help="Recommend a category for given text instead of reading "
               "from a file")
 @click.argument('paths', nargs=-1, metavar='<paths>')
-@click.option('--interactive','-i', flag_value=True, default=False,
+@click.option('--interactive', '-i', flag_value=True, default=False,
               help="After processing each path in <paths>, prompt for a "
               "recommendation to accept, and then prepend the selection to "
               "the file name.")
-@click.option('-s', '--skip-ucs', flag_value=True, default=False,
+@click.option('-s', '--skip-ucs', flag_value=True, default=False,
               help="Skip files that already have a UCS category in their "
               "name.")
 @click.pass_context
@@ -93,57 +95,59 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
     of ranked subcategories is printed to the terminal for each PATH.
     """
     logger.debug("RECOMMEND mode")
-    inference_ctx = InferenceContext(ctx.obj['model_name'],
+    inference_ctx = InferenceContext(ctx.obj['model_name'],
                                      use_cached_model=ctx.obj['model_cache'],
                                      use_full_ucs=ctx.obj['complete_ucs'])

     if text is not None:
-        print_recommendation(None, text, inference_ctx,
+        print_recommendation(None, text, inference_ctx,
                              interactive_rename=False)
-
+
     catlist = [x.catid for x in inference_ctx.catlist]

     for path in paths:
-        _, ext = os.path.splitext(path)
-
+        _, ext = os.path.splitext(path)
+
         if ext not in (".wav", ".flac"):
             continue

         basename = os.path.basename(path)
         if skip_ucs and parse_ucs(basename, catlist):
-            continue
+            continue

         text = ffmpeg_description(path)
         if not text:
             text = os.path.basename(path)

         while True:
-            retval = print_recommendation(path, text, inference_ctx, interactive)
+            retval = print_recommendation(
+                path, text, inference_ctx, interactive)
             if not retval:
                 break
             if retval[0] is False:
                 return
             elif retval[1] is not None:
                 text = retval[1]
-                continue
+                continue
             elif retval[2] is not None:
                 new_name = retval[2] + '_' + os.path.basename(path)
                 new_path = os.path.join(os.path.dirname(path), new_name)
                 print(f"Renaming {path} \n to {new_path}")
                 os.rename(path, new_path)
                 break
-
+
+
 @ucsinfer.command('csv')
-@click.option('--filename-col', default="FileName",
+@click.option('--filename-col', default="FileName",
               help="Heading or index of the column containing filenames",
               show_default=True)
-@click.option('--description-col', default="TrackDescription",
+@click.option('--description-col', default="TrackDescription",
               help="Heading or index of the column containing descriptions",
               show_default=True)
 @click.option('--out', default='dataset/', show_default=True)
 @click.argument('paths', nargs=-1)
 @click.pass_context
-def csv(ctx, paths, out, filename_col, description_col):
+def import_csv(ctx, paths: list[str], out, filename_col, description_col):
     """
     Scan training data from CSV files

@@ -152,24 +156,37 @@ def csv(ctx, paths, out, filename_col, description_col):
     file system it builds a dataset from descriptions and UCS filenames in
     columns of a CSV file.
     """
-    pass
+    logger.debug("CSV mode")
+
+    logger.debug(f"Loading category list...")
+    ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
+
+    catid_list = [cat.catid for cat in ucs]
+
+    logger.info("Building dataset from csv...")
+
+    dataset = build_sentence_class_dataset(
+        chain(csv_to_data(paths, description_col, filename_col, catid_list),
+              ucs_definitions_generator(ucs)), catid_list)
+
+    logger.info(f"Saving dataset to disk at {out}")
+    print_dataset_stats(dataset, catid_list)
+    dataset.save_to_disk(out)

+
 @ucsinfer.command('gather')
 @click.option('--out', default='dataset/', show_default=True)
-# @click.option('--ucs-data', flag_value=True, help="Create a dataset based "
-#               "on the UCS category explanations and synonymns (PATHS will "
-#               "be ignored.)")
 @click.argument('paths', nargs=-1)
 @click.pass_context
-def gather(ctx, paths, out, ucs_data):
+def gather(ctx, paths, out):
     """
     Scan training data from audio files
-
+
     `gather` is used to build a training dataset for finetuning the selected
     model. Description sentences and UCS categories are collected from '.wav'
     and '.flac' files on-disk that have valid UCS filenames and assigned
     CatIDs, and this information is recorded into a HuggingFace dataset.
-
+
     Gather scans the filesystem in two passes: first, the directory tree is
     walked by os.walk and a list of filenames that meet the above name criteria
     is compiled. After this list is compiled, each file is scanned one-by-one
@@ -181,29 +198,26 @@ def gather(ctx, paths, out, ucs_data):
     logger.debug(f"Loading category list...")
     ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])

-    scan_list: list[tuple[str,str]] = []
+    scan_list: list[tuple[str, str]] = []
     catid_list = [cat.catid for cat in ucs]

-    if ucs_data:
-        logger.info('Creating dataset for UCS categories instead of from PATH')
-        paths = []
-
     for path in paths:
         scan_list += walk_path(path, catid_list)

     logger.info(f"Found {len(scan_list)} files to process.")

-    logger.info("Building dataset...")
+    logger.info("Building dataset files...")

     dataset = build_sentence_class_dataset(
-        chain(scan_metadata(scan_list, catid_list),
-              ucs_definitions_generator(ucs)),
-        catid_list)
-
+        chain(scan_metadata(scan_list, catid_list),
+              ucs_definitions_generator(ucs)),
+        catid_list)
+
     logger.info(f"Saving dataset to disk at {out}")
     print_dataset_stats(dataset, catid_list)
     dataset.save_to_disk(out)

+
 @ucsinfer.command('qualify')
 def qualify():
     """
@@ -225,7 +239,6 @@ def finetune(ctx):
     logger.debug("FINETUNE mode")

-

 @ucsinfer.command('evaluate')
 @click.argument('dataset', default='dataset/')
 @click.pass_context
@@ -235,7 +248,6 @@ def evaluate(ctx, dataset, offset, limit):
     """
     logger.debug("EVALUATE mode")
     logger.warning("Model evaluation is not currently implemented")

-

 if __name__ == '__main__':
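
Not part of the patch, and only loosely grounded in it: a minimal sketch of how the reworked commands could be smoke-tested with click's CliRunner. The group and subcommand names (`ucsinfer`, `csv`) and their options come from the diff above; the import path `ucsinfer.__main__` and the sample CSV arguments are assumptions for illustration only.

# Hedged sketch, not from the patch: exercise the click group defined in the diff.
# ASSUMPTION: the module shown above is importable as ucsinfer.__main__.
from click.testing import CliRunner

from ucsinfer.__main__ import ucsinfer

runner = CliRunner()

# '--help' walks the option wiring (including the renamed import_csv callback)
# without loading a sentence-transformer model or writing a dataset to disk.
result = runner.invoke(ucsinfer, ['csv', '--help'])
print(result.exit_code, result.output)

# A real run might look like this (hypothetical CSV path and output directory):
# runner.invoke(ucsinfer, ['-v', 'csv', '--filename-col', 'FileName',
#                          '--description-col', 'TrackDescription',
#                          '--out', 'dataset/', 'metadata.csv'])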