refining gather, adding TODOs
TODO.md
@@ -4,12 +4,17 @@

## Gather

- Add "source" column for tracking provenance
- Maybe more dataset configurations

## Validate

A function for validating a dataset for finetuning

## Fine-tune

- Implement

## Evaluate

- Print more information about the dataset coverage of UCS

pyproject.toml
@@ -13,7 +13,8 @@ dependencies = [
    "tqdm (>=4.67.1,<5.0.0)",
    "platformdirs (>=4.3.8,<5.0.0)",
    "click (>=8.2.1,<9.0.0)",
    "tabulate (>=0.9.0,<0.10.0)"
    "tabulate (>=0.9.0,<0.10.0)",
    "datasets (>=4.0.0,<5.0.0)"
]
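
The new `datasets` dependency supplies the HuggingFace on-disk format that `gather` now writes (see the `save_to_disk` call below). A quick way to sanity-check the output; the `gathered_dataset` path here is a hypothetical placeholder for whatever path is passed to `gather`:

```python
from datasets import load_from_disk

# 'gathered_dataset' is a placeholder for the gather output directory
ds = load_from_disk('gathered_dataset')
print(ds)     # column names and row count
print(ds[0])  # first record: a description sentence and its CatID class
```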
@@ -3,11 +3,14 @@ import sys
import csv
import logging

from typing import Generator

import tqdm
import click
from tabulate import tabulate, SEPARATING_LINE

from .inference import InferenceContext, load_ucs
from .gather import build_sentence_class_dataset
from .recommend import print_recommendation
from .util import ffmpeg_description, parse_ucs
@@ -117,23 +120,24 @@ def gather(paths, outfile):
    """
    Scan files to build a training dataset at PATH

    The `gather` command walks the directory hierarchy for each path in PATHS
    and looks for .wav and .flac files that are named according to the UCS
    file naming guidelines, with at least a CatID and FX Name, divided by an
    underscore.
    The `gather` command is used to build a training dataset for finetuning
    the selected model. Description sentences and UCS categories are collected
    from '.wav' and '.flac' files on-disk that have valid UCS filenames and
    assigned CatIDs, and this information is recorded into a HuggingFace
    dataset.

    For every file ucsinfer finds that meets these criteria, it creates a
    record in an output dataset CSV file. The dataset file has two columns:
    the first is the CatID indicated for the file, and the second is the
    embedded file description for the file as returned by ffprobe.
    Gather scans the filesystem in two passes: first, the directory tree is
    walked by os.walk and a list of filenames that meet the above name
    criteria is compiled. After this list is compiled, each file is scanned
    one-by-one with ffprobe to obtain its "description" metadata; if this
    isn't present, the parsed 'fxname' of the file becomes the description.
    """
    logger.debug("GATHER mode")

    types = ['.wav', '.flac']
    table = csv.writer(outfile)

    logger.debug(f"Loading category list...")
    catid_list = [cat.catid for cat in load_ucs()]
    ucs = load_ucs()
    catid_list = [cat.catid for cat in ucs]

    scan_list = []
    for path in paths:
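
The second pass described in the docstring depends on `ffmpeg_description`, also imported from `.util` and not shown in this diff. A plausible sketch using ffprobe's JSON output; the exact tag names checked are an assumption:

```python
import json
import subprocess
from typing import Optional


def ffmpeg_description(path: str) -> Optional[str]:
    """Sketch: read an embedded 'description' tag with ffprobe, or None."""
    proc = subprocess.run(
        ['ffprobe', '-v', 'quiet', '-print_format', 'json',
         '-show_format', path],
        capture_output=True, text=True)
    if proc.returncode != 0:
        return None
    tags = json.loads(proc.stdout).get('format', {}).get('tags', {})
    # Tag casing varies between containers, so check both spellings
    return tags.get('description') or tags.get('DESCRIPTION')
```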
@@ -149,9 +153,17 @@ def gather(paths, outfile):

    logger.info(f"Found {len(scan_list)} files to process.")

    for pair in tqdm.tqdm(scan_list, unit='files'):
        if desc := ffmpeg_description(pair[1]):
            table.writerow([pair[0], desc])
    def scan_metadata():
        for pair in tqdm.tqdm(scan_list, unit='files'):
            if desc := ffmpeg_description(pair[1]):
                yield desc, str(pair[0])
            else:
                comps = parse_ucs(os.path.basename(pair[1]), catid_list)
                assert comps
                yield comps.fx_name, str(pair[0])

    dataset = build_sentence_class_dataset(scan_metadata())
    dataset.save_to_disk(outfile)


@ucsinfer.command('finetune')
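
`build_sentence_class_dataset` is imported from `.gather` and its body is not part of this diff. A minimal sketch of what it plausibly does with the new `datasets` dependency, assuming it materializes the (sentence, class) pairs yielded by `scan_metadata` into a two-column `Dataset`:

```python
from typing import Iterable, Tuple

from datasets import Dataset


def build_sentence_class_dataset(
        records: Iterable[Tuple[str, str]]) -> Dataset:
    """Sketch: collect (sentence, class) pairs into a HuggingFace Dataset."""
    sentences, labels = [], []
    for sentence, label in records:
        sentences.append(sentence)
        labels.append(label)
    return Dataset.from_dict({'sentence': sentences, 'label': labels})
```

The real implementation could equally use `Dataset.from_generator`, which consumes the generator lazily instead of buffering both columns in memory first.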