diff --git a/TODO.md b/TODO.md index e1cd9bf..ec7377f 100644 --- a/TODO.md +++ b/TODO.md @@ -4,12 +4,17 @@ ## Gather -- Add "source" column for tracking provenance +- Maybe more dataset configurations + +## Validate + +A function for validating a dataset for finetuning ## Fine-tune - Implement + ## Evaluate - Print more information about the dataset coverage of UCS diff --git a/pyproject.toml b/pyproject.toml index c4649df..a2a4608 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,8 @@ dependencies = [ "tqdm (>=4.67.1,<5.0.0)", "platformdirs (>=4.3.8,<5.0.0)", "click (>=8.2.1,<9.0.0)", - "tabulate (>=0.9.0,<0.10.0)" + "tabulate (>=0.9.0,<0.10.0)", + "datasets (>=4.0.0,<5.0.0)" ] diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py index d41da68..3e8de5a 100644 --- a/ucsinfer/__main__.py +++ b/ucsinfer/__main__.py @@ -3,11 +3,14 @@ import sys import csv import logging +from typing import Generator + import tqdm import click from tabulate import tabulate, SEPARATING_LINE from .inference import InferenceContext, load_ucs +from .gather import build_sentence_class_dataset from .recommend import print_recommendation from .util import ffmpeg_description, parse_ucs @@ -117,23 +120,24 @@ def gather(paths, outfile): """ Scan files to build a training dataset at PATH - The `gather` command walks the directory hierarchy for each path in PATHS - and looks for .wav and .flac files that are named according to the UCS - file naming guidelines, with at least a CatID and FX Name, divided by an - underscore. - - For every file ucsinfer finds that meets this criteria, it creates a record - in an output dataset CSV file. The dataset file has two columns: the first - is the CatID indicated for the file, and the second is the embedded file - description for the file as returned by ffprobe. + The `gather` is used to build a training dataset for finetuning the + selected model. 
Description sentences and UCS categories are collected from + '.wav' and '.flac' files on-disk that have valid UCS filenames and assigned + CatIDs, and this information is recorded into a HuggingFace dataset. + + Gather scans the filesystem in two passes: first, the directory tree is + walked by os.walk and a list of filenames that meet the above name criteria + is compiled. After this list is compiled, each file is scanned one-by-one + with ffprobe to obtain its "description" metadata; if this isn't present, + the parsed 'fxname' of the file becomes the description. """ logger.debug("GATHER mode") types = ['.wav', '.flac'] - table = csv.writer(outfile) logger.debug(f"Loading category list...") - catid_list = [cat.catid for cat in load_ucs()] + ucs = load_ucs() + catid_list = [cat.catid for cat in ucs] scan_list = [] for path in paths: @@ -149,10 +153,18 @@ def gather(paths, outfile): logger.info(f"Found {len(scan_list)} files to process.") - for pair in tqdm.tqdm(scan_list, unit='files'): - if desc := ffmpeg_description(pair[1]): - table.writerow([pair[0], desc]) + def scan_metadata(): + for pair in tqdm.tqdm(scan_list, unit='files'): + if desc := ffmpeg_description(pair[1]): + yield desc, str(pair[0]) + else: + comps = parse_ucs(os.path.basename(pair[1]), catid_list) + assert comps + yield comps.fx_name, str(pair[0]) + dataset = build_sentence_class_dataset(scan_metadata()) + dataset.save_to_disk(outfile) + @ucsinfer.command('finetune') def finetune():