Cleaned up code and refactored to new file

2025-10-14 10:31:34 -07:00
parent 5fc57cf7c8
commit 3ca921ad02
2 changed files with 61 additions and 56 deletions


@@ -1,14 +1,12 @@
 import os
 import logging
 from itertools import chain
-import csv
-from typing import Generator
 import click
 from .inference import InferenceContext, load_ucs
-from .gather import (build_sentence_class_dataset, print_dataset_stats,
+from .import_csv import csv_to_data
+from .gather import (build_sentence_class_dataset, print_dataset_stats,
                      ucs_definitions_generator, scan_metadata, walk_path)
 from .recommend import print_recommendation
 from .util import ffmpeg_description, parse_ucs
@@ -18,18 +16,19 @@ logger.setLevel(logging.DEBUG)
 stream_handler = logging.StreamHandler()
 stream_handler.setLevel(logging.WARN)
 formatter = logging.Formatter(
-        '%(asctime)s. %(levelname)s %(name)s: %(message)s')
+    '%(asctime)s. %(levelname)s %(name)s: %(message)s')
 stream_handler.setFormatter(formatter)
 logger.addHandler(stream_handler)
 
 
 @click.group(epilog="For more information see "
              "<https://git.squad51.us/jamie/ucsinfer>")
 @click.option('--verbose', '-v', flag_value=True, help='Verbose output')
-@click.option('--model', type=str, metavar="<model-name>", 
+@click.option('--model', type=str, metavar="<model-name>",
               default="paraphrase-multilingual-mpnet-base-v2",
-              show_default=True, 
+              show_default=True,
               help="Select the sentence_transformer model to use")
-@click.option('--no-model-cache', flag_value=True, 
+@click.option('--no-model-cache', flag_value=True,
               help="Don't use local model cache")
 @click.option('--complete-ucs', flag_value=True, default=False,
               help="Use all UCS categories. By default, all 'FOLEY' and "
@@ -39,15 +38,15 @@ def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
     """
     Tools for applying UCS categories to sounds using large language models
     """
     if verbose:
         stream_handler.setLevel(logging.DEBUG)
         logger.info("Verbose logging is enabled")
     else:
         import warnings
         warnings.filterwarnings(
-                action='ignore', module='torch', category=FutureWarning,
-                message=r"`encoder_attention_mask` is deprecated.*")
+            action='ignore', module='torch', category=FutureWarning,
+            message=r"`encoder_attention_mask` is deprecated.*")
         stream_handler.setLevel(logging.WARNING)
@@ -77,11 +76,11 @@ def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
               help="Recommend a category for given text instead of reading "
                    "from a file")
 @click.argument('paths', nargs=-1, metavar='<paths>')
-@click.option('--interactive','-i', flag_value=True, default=False,
+@click.option('--interactive', '-i', flag_value=True, default=False,
               help="After processing each path in <paths>, prompt for a "
                    "recommendation to accept, and then prepend the selection to "
                    "the file name.")
-@click.option('-s', '--skip-ucs', flag_value=True, default=False, 
+@click.option('-s', '--skip-ucs', flag_value=True, default=False,
               help="Skip files that already have a UCS category in their "
                    "name.")
 @click.pass_context
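
A quick way to try these flag combinations without renaming real files is click's in-process test runner. A minimal sketch — the `ucsinfer.__main__` import path and the sample file path are assumptions, not taken from this diff:

from click.testing import CliRunner
from ucsinfer.__main__ import ucsinfer  # import path assumed

runner = CliRunner()
# Ask for recommendations, skipping files that already carry a CatID
result = runner.invoke(ucsinfer, ['recommend', '--skip-ucs', 'sfx/door_creak.wav'])
print(result.output)
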
@@ -96,39 +95,40 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
     of ranked subcategories is printed to the terminal for each PATH.
     """
     logger.debug("RECOMMEND mode")
-    inference_ctx = InferenceContext(ctx.obj['model_name'], 
+    inference_ctx = InferenceContext(ctx.obj['model_name'],
                                      use_cached_model=ctx.obj['model_cache'],
                                      use_full_ucs=ctx.obj['complete_ucs'])
 
     if text is not None:
-        print_recommendation(None, text, inference_ctx, 
+        print_recommendation(None, text, inference_ctx,
                              interactive_rename=False)
 
     catlist = [x.catid for x in inference_ctx.catlist]
     for path in paths:
-        _, ext = os.path.splitext(path) 
+        _, ext = os.path.splitext(path)
         if ext not in (".wav", ".flac"):
             continue
 
         basename = os.path.basename(path)
         if skip_ucs and parse_ucs(basename, catlist):
-            continue 
+            continue
 
         text = ffmpeg_description(path)
         if not text:
            text = os.path.basename(path)
 
         while True:
-            retval = print_recommendation(path, text, inference_ctx, interactive)
+            retval = print_recommendation(
+                path, text, inference_ctx, interactive)
             if not retval:
                 break
             if retval[0] is False:
                 return
             elif retval[1] is not None:
                 text = retval[1]
-                continue 
+                continue
             elif retval[2] is not None:
                 new_name = retval[2] + '_' + os.path.basename(path)
                 new_path = os.path.join(os.path.dirname(path), new_name)
@@ -136,34 +136,12 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
             os.rename(path, new_path)
             break
 
 
-def csv_to_data(paths, description_key, filename_key, catid_list) -> Generator[tuple[str, str], None, None]:
-    """
-    Accepts a list of paths and returns an iterator of (sentence, class)
-    tuples.
-    """
-    for path in paths:
-        with open(path, 'r') as f:
-            records = csv.DictReader(f)
-            assert filename_key in records.fieldnames, \
-                (f"Filename key `{filename_key}` not present in file "
-                 "{path}")
-            assert description_key in records.fieldnames, \
-                (f"Description key `{description_key}` not present in "
-                 "file {path}")
-
-            for record in records:
-                ucs_comps = parse_ucs(record[filename_key], catid_list)
-                if ucs_comps:
-                    yield (record[description_key], ucs_comps.cat_id)
-
-
 @ucsinfer.command('csv')
-@click.option('--filename-col', default="FileName", 
+@click.option('--filename-col', default="FileName",
               help="Heading or index of the column containing filenames",
               show_default=True)
-@click.option('--description-col', default="TrackDescription", 
+@click.option('--description-col', default="TrackDescription",
               help="Heading or index of the column containing descriptions",
               show_default=True)
 @click.option('--out', default='dataset/', show_default=True)
@@ -188,8 +166,8 @@ def import_csv(ctx, paths: list[str], out, filename_col, description_col):
     logger.info("Building dataset from csv...")
     dataset = build_sentence_class_dataset(
-        chain(csv_to_data(paths, description_col, filename_col, catid_list), 
-              ucs_definitions_generator(ucs)),catid_list)
+        chain(csv_to_data(paths, description_col, filename_col, catid_list),
+              ucs_definitions_generator(ucs)), catid_list)
 
     logger.info(f"Saving dataset to disk at {out}")
     print_dataset_stats(dataset, catid_list)
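
For intuition, everything chained here yields plain (sentence, CatID) tuples, which build_sentence_class_dataset folds into a HuggingFace dataset. A hypothetical stand-in for that function — the real one lives in ucsinfer/gather.py and may differ:

from typing import Iterable
from datasets import ClassLabel, Dataset

def build_dataset_sketch(pairs: Iterable[tuple[str, str]],
                         catid_list: list[str]) -> Dataset:
    # Materialize the generator into parallel columns, mapping each
    # CatID string onto its integer index in the category list.
    sentences, labels = [], []
    for sentence, catid in pairs:
        sentences.append(sentence)
        labels.append(catid_list.index(catid))
    ds = Dataset.from_dict({'sentence': sentences, 'label': labels})
    return ds.cast_column('label', ClassLabel(names=catid_list))
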
@@ -203,12 +181,12 @@ def import_csv(ctx, paths: list[str], out, filename_col, description_col):
 def gather(ctx, paths, out):
     """
     Scan training data from audio files
 
     `gather` is used to build a training dataset for finetuning the selected
     model. Description sentences and UCS categories are collected from '.wav'
     and '.flac' files on-disk that have valid UCS filenames and assigned
     CatIDs, and this information is recorded into a HuggingFace dataset.
 
     Gather scans the filesystem in two passes: first, the directory tree is
     walked by os.walk and a list of filenames that meet the above name criteria
     is compiled. After this list is compiled, each file is scanned one-by-one
@@ -220,7 +198,7 @@ def gather(ctx, paths, out):
     logger.debug(f"Loading category list...")
     ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
-    scan_list: list[tuple[str,str]] = []
+    scan_list: list[tuple[str, str]] = []
     catid_list = [cat.catid for cat in ucs]
 
     for path in paths:
@@ -231,14 +209,15 @@ def gather(ctx, paths, out):
     logger.info("Building dataset files...")
     dataset = build_sentence_class_dataset(
-        chain(scan_metadata(scan_list, catid_list),
-              ucs_definitions_generator(ucs)),
-        catid_list)
+        chain(scan_metadata(scan_list, catid_list),
+              ucs_definitions_generator(ucs)),
+        catid_list)
 
     logger.info(f"Saving dataset to disk at {out}")
     print_dataset_stats(dataset, catid_list)
     dataset.save_to_disk(out)
 
 
 @ucsinfer.command('qualify')
 def qualify():
     """
@@ -260,7 +239,6 @@ def finetune(ctx):
     logger.debug("FINETUNE mode")
 
-
 @ucsinfer.command('evaluate')
 @click.argument('dataset', default='dataset/')
 @click.pass_context
@@ -270,7 +248,6 @@ def evaluate(ctx, dataset, offset, limit):
     """
     logger.debug("EVALUATE mode")
     logger.warning("Model evaluation is not currently implemented")
 
-
 if __name__ == '__main__':
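
The two-pass scan described in gather's docstring can be pictured with a short sketch. This is illustrative only, not the project's walk_path/scan_metadata implementation; the extension filter mirrors the '.wav'/'.flac' pair used by recommend above:

import os

def collect_audio_paths(root: str) -> list[str]:
    # Pass 1: walk the tree and compile candidate paths up front, so the
    # slower per-file metadata pass (pass 2) runs against a known total.
    found = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            if os.path.splitext(name)[1].lower() in ('.wav', '.flac'):
                found.append(os.path.join(dirpath, name))
    return found
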

ucsinfer/import_csv.py Normal file

@@ -0,0 +1,28 @@
+import csv
+from typing import Generator
+
+from .util import parse_ucs
+
+
+def csv_to_data(paths, description_key, filename_key, catid_list) -> Generator[tuple[str, str], None, None]:
+    """
+    Accepts a list of paths and returns an iterator of (sentence, class)
+    tuples.
+    """
+    for path in paths:
+        with open(path, 'r') as f:
+            records = csv.DictReader(f)
+            assert filename_key in records.fieldnames, \
+                (f"Filename key `{filename_key}` not present in file "
+                 f"{path}")
+            assert description_key in records.fieldnames, \
+                (f"Description key `{description_key}` not present in "
+                 f"file {path}")
+
+            for record in records:
+                ucs_comps = parse_ucs(record[filename_key], catid_list)
+                if ucs_comps:
+                    yield (record[description_key], ucs_comps.cat_id)
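
A usage sketch for the relocated generator — the column headings are the 'csv' command defaults, while the CSV path and CatIDs are made up for illustration:

from ucsinfer.import_csv import csv_to_data

catid_list = ['DOORWood', 'METLImpt']  # illustrative CatIDs only
pairs = csv_to_data(['metadata.csv'], 'TrackDescription', 'FileName', catid_list)
for sentence, catid in pairs:
    print(f'{catid}: {sentence}')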