refining gather, adding TODOs

2025-09-03 14:13:12 -07:00
parent d855ba4c78
commit 0a2aaa2a22
3 changed files with 34 additions and 16 deletions

View File

@@ -4,12 +4,17 @@
 ## Gather
-- Add "source" column for tracking provenance
+- Maybe more dataset configurations
+
+## Validate
+
+A function for validating a dataset for finetuning
 
 ## Fine-tune
 - Implement
 
 ## Evaluate
 - Print more information about the dataset coverage of UCS

View File

@@ -13,7 +13,8 @@ dependencies = [
"tqdm (>=4.67.1,<5.0.0)", "tqdm (>=4.67.1,<5.0.0)",
"platformdirs (>=4.3.8,<5.0.0)", "platformdirs (>=4.3.8,<5.0.0)",
"click (>=8.2.1,<9.0.0)", "click (>=8.2.1,<9.0.0)",
"tabulate (>=0.9.0,<0.10.0)" "tabulate (>=0.9.0,<0.10.0)",
"datasets (>=4.0.0,<5.0.0)"
] ]

View File

@@ -3,11 +3,14 @@ import sys
 import csv
 import logging
+from typing import Generator
 import tqdm
 import click
 from tabulate import tabulate, SEPARATING_LINE
 from .inference import InferenceContext, load_ucs
+from .gather import build_sentence_class_dataset
 from .recommend import print_recommendation
 from .util import ffmpeg_description, parse_ucs
@@ -117,23 +120,24 @@ def gather(paths, outfile):
""" """
Scan files to build a training dataset at PATH Scan files to build a training dataset at PATH
The `gather` command walks the directory hierarchy for each path in PATHS The `gather` is used to build a training dataset for finetuning the
and looks for .wav and .flac files that are named according to the UCS selected model. Description sentences and UCS categories are collected from
file naming guidelines, with at least a CatID and FX Name, divided by an '.wav' and '.flac' files on-disk that have valid UCS filenames and assigned
underscore. CatIDs, and this information is recorded into a HuggingFace dataset.
For every file ucsinfer finds that meets this criteria, it creates a record Gather scans the filesystem in two passes: first, the directory tree is
in an output dataset CSV file. The dataset file has two columns: the first walked by os.walk and a list of filenames that meet the above name criteria
is the CatID indicated for the file, and the second is the embedded file is compiled. After this list is compiled, each file is scanned one-by-one
description for the file as returned by ffprobe. with ffprobe to obtain its "description" metadata; if this isn't present,
the parsed 'fxname' of te file becomes the description.
""" """
logger.debug("GATHER mode") logger.debug("GATHER mode")
types = ['.wav', '.flac'] types = ['.wav', '.flac']
table = csv.writer(outfile)
logger.debug(f"Loading category list...") logger.debug(f"Loading category list...")
catid_list = [cat.catid for cat in load_ucs()] ucs = load_ucs()
catid_list = [cat.catid for cat in ucs]
scan_list = [] scan_list = []
for path in paths: for path in paths:
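The UCS naming rule the docstring refers to puts at least a CatID and an FX Name in the filename, divided by underscores. The real `parse_ucs` lives in `ucsinfer.util` and isn't shown in this diff; as a rough illustration only, a minimal stand-in consistent with the call sites here might look like:

```python
# Hypothetical stand-in for ucsinfer.util.parse_ucs, for illustration only:
# split a UCS-style filename ("CatID_FXName_...") on underscores and check
# the CatID against the known category list.
from typing import NamedTuple, Optional

class UcsComponents(NamedTuple):
    catid: str
    fx_name: str

def parse_ucs_sketch(filename: str,
                     catid_list: list[str]) -> Optional[UcsComponents]:
    stem = filename.rsplit('.', 1)[0]   # drop the extension
    parts = stem.split('_')
    if len(parts) < 2 or parts[0] not in catid_list:
        return None                     # not a valid UCS filename
    return UcsComponents(catid=parts[0], fx_name=parts[1])

# e.g., assuming 'DOORWood' is in catid_list:
# parse_ucs_sketch("DOORWood_Creaky Oak Door_JS_0001.wav", catid_list)
# -> UcsComponents(catid='DOORWood', fx_name='Creaky Oak Door')
```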
@@ -149,9 +153,17 @@ def gather(paths, outfile):
logger.info(f"Found {len(scan_list)} files to process.") logger.info(f"Found {len(scan_list)} files to process.")
def scan_metadata():
for pair in tqdm.tqdm(scan_list, unit='files'): for pair in tqdm.tqdm(scan_list, unit='files'):
if desc := ffmpeg_description(pair[1]): if desc := ffmpeg_description(pair[1]):
table.writerow([pair[0], desc]) yield desc, str(pair[0])
else:
comps = parse_ucs(os.path.basename(pair[1]), catid_list)
assert comps
yield comps.fx_name, str(pair[0])
dataset = build_sentence_class_dataset(scan_metadata())
dataset.save_to_disk(outfile)
@ucsinfer.command('finetune') @ucsinfer.command('finetune')
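`build_sentence_class_dataset` itself is defined in `.gather` and isn't part of this diff. Judging from the call site, it consumes the `(sentence, catid)` pairs yielded by `scan_metadata()` and returns a HuggingFace dataset, which is what the new `datasets` dependency is for. A hedged sketch of that shape, with column names that are purely assumptions:

```python
# Sketch only: the real build_sentence_class_dataset lives in
# ucsinfer/gather.py and is not shown in this commit. The 'sentence' and
# 'label' column names are assumptions.
from typing import Iterable, Tuple
from datasets import Dataset, Features, Value

def build_sentence_class_dataset_sketch(
        pairs: Iterable[Tuple[str, str]]) -> Dataset:
    def records():
        # Dataset.from_generator expects a callable that yields dicts
        for sentence, catid in pairs:
            yield {'sentence': sentence, 'label': catid}
    features = Features({'sentence': Value('string'),
                         'label': Value('string')})
    return Dataset.from_generator(records, features=features)
```

Whatever the real column layout is, `dataset.save_to_disk(outfile)` then writes the Arrow-backed dataset to the output directory, replacing the CSV output of the previous version.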