Cleaned up code and refactored to new file

2025-10-14 10:31:34 -07:00
parent 5fc57cf7c8
commit 3ca921ad02
2 changed files with 61 additions and 56 deletions


@@ -1,14 +1,12 @@
 import os
 import logging
 from itertools import chain
-import csv
-from typing import Generator
 import click
 from .inference import InferenceContext, load_ucs
+from .import_csv import csv_to_data
 from .gather import (build_sentence_class_dataset, print_dataset_stats,
                      ucs_definitions_generator, scan_metadata, walk_path)
 from .recommend import print_recommendation
 from .util import ffmpeg_description, parse_ucs
@@ -18,18 +16,19 @@ logger.setLevel(logging.DEBUG)
 stream_handler = logging.StreamHandler()
 stream_handler.setLevel(logging.WARN)
 formatter = logging.Formatter(
     '%(asctime)s. %(levelname)s %(name)s: %(message)s')
 stream_handler.setFormatter(formatter)
 logger.addHandler(stream_handler)


 @click.group(epilog="For more information see "
              "<https://git.squad51.us/jamie/ucsinfer>")
 @click.option('--verbose', '-v', flag_value=True, help='Verbose output')
 @click.option('--model', type=str, metavar="<model-name>",
               default="paraphrase-multilingual-mpnet-base-v2",
               show_default=True,
               help="Select the sentence_transformer model to use")
 @click.option('--no-model-cache', flag_value=True,
               help="Don't use local model cache")
 @click.option('--complete-ucs', flag_value=True, default=False,
               help="Use all UCS categories. By default, all 'FOLEY' and "
@@ -39,15 +38,15 @@ def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
""" """
Tools for applying UCS categories to sounds using large-language Models Tools for applying UCS categories to sounds using large-language Models
""" """
if verbose: if verbose:
stream_handler.setLevel(logging.DEBUG) stream_handler.setLevel(logging.DEBUG)
logger.info("Verbose logging is enabled") logger.info("Verbose logging is enabled")
else: else:
import warnings import warnings
warnings.filterwarnings( warnings.filterwarnings(
action='ignore', module='torch', category=FutureWarning, action='ignore', module='torch', category=FutureWarning,
message=r"`encoder_attention_mask` is deprecated.*") message=r"`encoder_attention_mask` is deprecated.*")
stream_handler.setLevel(logging.WARNING) stream_handler.setLevel(logging.WARNING)
@@ -77,11 +76,11 @@ def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
help="Recommend a category for given text instead of reading " help="Recommend a category for given text instead of reading "
"from a file") "from a file")
@click.argument('paths', nargs=-1, metavar='<paths>') @click.argument('paths', nargs=-1, metavar='<paths>')
@click.option('--interactive','-i', flag_value=True, default=False, @click.option('--interactive', '-i', flag_value=True, default=False,
help="After processing each path in <paths>, prompt for a " help="After processing each path in <paths>, prompt for a "
"recommendation to accept, and then prepend the selection to " "recommendation to accept, and then prepend the selection to "
"the file name.") "the file name.")
@click.option('-s', '--skip-ucs', flag_value=True, default=False, @click.option('-s', '--skip-ucs', flag_value=True, default=False,
help="Skip files that already have a UCS category in their " help="Skip files that already have a UCS category in their "
"name.") "name.")
@click.pass_context @click.pass_context
@@ -96,39 +95,40 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
     of ranked subcategories is printed to the terminal for each PATH.
     """
     logger.debug("RECOMMEND mode")
     inference_ctx = InferenceContext(ctx.obj['model_name'],
                                      use_cached_model=ctx.obj['model_cache'],
                                      use_full_ucs=ctx.obj['complete_ucs'])

     if text is not None:
         print_recommendation(None, text, inference_ctx,
                              interactive_rename=False)

     catlist = [x.catid for x in inference_ctx.catlist]
     for path in paths:
         _, ext = os.path.splitext(path)
         if ext not in (".wav", ".flac"):
             continue

         basename = os.path.basename(path)
         if skip_ucs and parse_ucs(basename, catlist):
             continue

         text = ffmpeg_description(path)
         if not text:
             text = os.path.basename(path)

         while True:
-            retval = print_recommendation(path, text, inference_ctx, interactive)
+            retval = print_recommendation(
+                path, text, inference_ctx, interactive)
             if not retval:
                 break
             if retval[0] is False:
                 return
             elif retval[1] is not None:
                 text = retval[1]
                 continue
             elif retval[2] is not None:
                 new_name = retval[2] + '_' + os.path.basename(path)
                 new_path = os.path.join(os.path.dirname(path), new_name)
@@ -136,34 +136,12 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
                 os.rename(path, new_path)
                 break


-def csv_to_data(paths, description_key, filename_key, catid_list) -> Generator[tuple[str, str], None, None]:
-    """
-    Accepts a list of paths and returns an iterator of (sentence, class)
-    tuples.
-    """
-    for path in paths:
-        with open(path, 'r') as f:
-            records = csv.DictReader(f)
-            assert filename_key in records.fieldnames, \
-                (f"Filename key `{filename_key}` not present in file "
-                 "{path}")
-            assert description_key in records.fieldnames, \
-                (f"Description key `{description_key}` not present in "
-                 "file {path}")
-
-            for record in records:
-                ucs_comps = parse_ucs(record[filename_key], catid_list)
-                if ucs_comps:
-                    yield (record[description_key], ucs_comps.cat_id)


 @ucsinfer.command('csv')
 @click.option('--filename-col', default="FileName",
               help="Heading or index of the column containing filenames",
               show_default=True)
 @click.option('--description-col', default="TrackDescription",
               help="Heading or index of the column containing descriptions",
               show_default=True)
 @click.option('--out', default='dataset/', show_default=True)
@@ -188,8 +166,8 @@ def import_csv(ctx, paths: list[str], out, filename_col, description_col):
logger.info("Building dataset from csv...") logger.info("Building dataset from csv...")
dataset = build_sentence_class_dataset( dataset = build_sentence_class_dataset(
chain(csv_to_data(paths, description_col, filename_col, catid_list), chain(csv_to_data(paths, description_col, filename_col, catid_list),
ucs_definitions_generator(ucs)),catid_list) ucs_definitions_generator(ucs)), catid_list)
logger.info(f"Saving dataset to disk at {out}") logger.info(f"Saving dataset to disk at {out}")
print_dataset_stats(dataset, catid_list) print_dataset_stats(dataset, catid_list)
@@ -203,12 +181,12 @@ def import_csv(ctx, paths: list[str], out, filename_col, description_col):
 def gather(ctx, paths, out):
     """
     Scan training data from audio files

     `gather` is used to build a training dataset for finetuning the selected
     model. Description sentences and UCS categories are collected from '.wav'
     and '.flac' files on-disk that have valid UCS filenames and assigned
     CatIDs, and this information is recorded into a HuggingFace dataset.

     Gather scans the filesystem in two passes: first, the directory tree is
     walked by os.walk and a list of filenames that meet the above name criteria
     is compiled. After this list is compiled, each file is scanned one-by-one
@@ -220,7 +198,7 @@ def gather(ctx, paths, out):
logger.debug(f"Loading category list...") logger.debug(f"Loading category list...")
ucs = load_ucs(full_ucs=ctx.obj['complete_ucs']) ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
scan_list: list[tuple[str,str]] = [] scan_list: list[tuple[str, str]] = []
catid_list = [cat.catid for cat in ucs] catid_list = [cat.catid for cat in ucs]
for path in paths: for path in paths:
@@ -231,14 +209,15 @@ def gather(ctx, paths, out):
logger.info("Building dataset files...") logger.info("Building dataset files...")
dataset = build_sentence_class_dataset( dataset = build_sentence_class_dataset(
chain(scan_metadata(scan_list, catid_list), chain(scan_metadata(scan_list, catid_list),
ucs_definitions_generator(ucs)), ucs_definitions_generator(ucs)),
catid_list) catid_list)
logger.info(f"Saving dataset to disk at {out}") logger.info(f"Saving dataset to disk at {out}")
print_dataset_stats(dataset, catid_list) print_dataset_stats(dataset, catid_list)
dataset.save_to_disk(out) dataset.save_to_disk(out)
@ucsinfer.command('qualify') @ucsinfer.command('qualify')
def qualify(): def qualify():
""" """
@@ -260,7 +239,6 @@ def finetune(ctx):
logger.debug("FINETUNE mode") logger.debug("FINETUNE mode")
@ucsinfer.command('evaluate') @ucsinfer.command('evaluate')
@click.argument('dataset', default='dataset/') @click.argument('dataset', default='dataset/')
@click.pass_context @click.pass_context
@@ -270,7 +248,6 @@ def evaluate(ctx, dataset, offset, limit):
""" """
logger.debug("EVALUATE mode") logger.debug("EVALUATE mode")
logger.warning("Model evaluation is not currently implemented") logger.warning("Model evaluation is not currently implemented")
if __name__ == '__main__': if __name__ == '__main__':
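
The two-pass scan described in gather's docstring above can be sketched as
follows. This is an illustration, not code from the commit: walk_path's
actual signature is not visible in this diff, so pass one is written against
os.walk directly, and the idea that scan_list holds (path, catid) pairs is
an assumption inferred from its list[tuple[str, str]] annotation.

    # Hypothetical sketch of gather's first pass; assumes scan_list
    # holds (path, catid) pairs.
    import os
    from ucsinfer.util import parse_ucs

    def first_pass(root: str, catid_list: list[str]) -> list[tuple[str, str]]:
        """Pass 1: collect UCS-named .wav/.flac files under root."""
        scan_list: list[tuple[str, str]] = []
        for dirpath, _dirs, filenames in os.walk(root):
            for name in filenames:
                if os.path.splitext(name)[1] not in ('.wav', '.flac'):
                    continue
                ucs_comps = parse_ucs(name, catid_list)
                if ucs_comps:  # filename carries a recognized CatID
                    scan_list.append((os.path.join(dirpath, name),
                                      ucs_comps.cat_id))
        return scan_list

    # Pass 2 is scan_metadata(scan_list, catid_list) in the diff above, which
    # yields the (sentence, class) pairs fed to build_sentence_class_dataset.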

ucsinfer/import_csv.py (new file, 28 lines)

@@ -0,0 +1,28 @@
+import csv
+from typing import Generator
+
+from .util import parse_ucs
+
+
+def csv_to_data(paths, description_key, filename_key, catid_list) -> Generator[tuple[str, str], None, None]:
+    """
+    Accepts a list of paths and returns an iterator of (sentence, class)
+    tuples.
+    """
+    for path in paths:
+        with open(path, 'r') as f:
+            records = csv.DictReader(f)
+            assert filename_key in records.fieldnames, \
+                (f"Filename key `{filename_key}` not present in file "
+                 "{path}")
+            assert description_key in records.fieldnames, \
+                (f"Description key `{description_key}` not present in "
+                 "file {path}")
+
+            for record in records:
+                ucs_comps = parse_ucs(record[filename_key], catid_list)
+                if ucs_comps:
+                    yield (record[description_key], ucs_comps.cat_id)
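
For reference, here is how the relocated function slots into the
dataset-building pipeline. This is a usage sketch, not code from the commit:
the file name metadata.csv is hypothetical, the column headings are simply
the csv command's defaults, and every call mirrors one visible in the diff
above.

    # Hypothetical usage of csv_to_data, mirroring the `csv` command.
    from itertools import chain

    from ucsinfer.import_csv import csv_to_data
    from ucsinfer.inference import load_ucs
    from ucsinfer.gather import (build_sentence_class_dataset,
                                 ucs_definitions_generator)

    ucs = load_ucs(full_ucs=False)
    catid_list = [cat.catid for cat in ucs]

    # (sentence, class) pairs from the CSV, plus one pair per UCS definition
    pairs = chain(
        csv_to_data(['metadata.csv'], 'TrackDescription', 'FileName',
                    catid_list),
        ucs_definitions_generator(ucs))

    dataset = build_sentence_class_dataset(pairs, catid_list)
    dataset.save_to_disk('dataset/')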