Compare commits
2 Commits
336d6f013e
...
0a2aaa2a22
Author | SHA1 | Date | |
---|---|---|---|
0a2aaa2a22 | |||
d855ba4c78 |
7
TODO.md
7
TODO.md
@@ -4,12 +4,17 @@
|
|||||||
|
|
||||||
## Gather
|
## Gather
|
||||||
|
|
||||||
- Add "source" column for tracking provenance
|
- Maybe more dataset configurations
|
||||||
|
|
||||||
|
## Validate
|
||||||
|
|
||||||
|
A function for validating a dataset for finetuning
|
||||||
|
|
||||||
## Fine-tune
|
## Fine-tune
|
||||||
|
|
||||||
- Implement
|
- Implement
|
||||||
|
|
||||||
|
|
||||||
## Evaluate
|
## Evaluate
|
||||||
|
|
||||||
- Print more information about the dataset coverage of UCS
|
- Print more information about the dataset coverage of UCS
|
||||||
|
@@ -13,7 +13,8 @@ dependencies = [
|
|||||||
"tqdm (>=4.67.1,<5.0.0)",
|
"tqdm (>=4.67.1,<5.0.0)",
|
||||||
"platformdirs (>=4.3.8,<5.0.0)",
|
"platformdirs (>=4.3.8,<5.0.0)",
|
||||||
"click (>=8.2.1,<9.0.0)",
|
"click (>=8.2.1,<9.0.0)",
|
||||||
"tabulate (>=0.9.0,<0.10.0)"
|
"tabulate (>=0.9.0,<0.10.0)",
|
||||||
|
"datasets (>=4.0.0,<5.0.0)"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@@ -3,11 +3,14 @@ import sys
|
|||||||
import csv
|
import csv
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
import tqdm
|
import tqdm
|
||||||
import click
|
import click
|
||||||
from tabulate import tabulate, SEPARATING_LINE
|
from tabulate import tabulate, SEPARATING_LINE
|
||||||
|
|
||||||
from .inference import InferenceContext, load_ucs
|
from .inference import InferenceContext, load_ucs
|
||||||
|
from .gather import build_sentence_class_dataset
|
||||||
from .recommend import print_recommendation
|
from .recommend import print_recommendation
|
||||||
from .util import ffmpeg_description, parse_ucs
|
from .util import ffmpeg_description, parse_ucs
|
||||||
|
|
||||||
@@ -117,22 +120,24 @@ def gather(paths, outfile):
|
|||||||
"""
|
"""
|
||||||
Scan files to build a training dataset at PATH
|
Scan files to build a training dataset at PATH
|
||||||
|
|
||||||
The `gather` command walks the directory hierarchy for each path in PATHS
|
The `gather` command is used to build a training dataset for finetuning the
|
||||||
and looks for .wav and .flac files that are named according to the UCS
|
selected model. Description sentences and UCS categories are collected from
|
||||||
file naming guidelines, with at least a CatID and FX Name, divided by an
|
'.wav' and '.flac' files on-disk that have valid UCS filenames and assigned
|
||||||
underscore.
|
CatIDs, and this information is recorded into a HuggingFace dataset.
|
||||||
|
|
||||||
For every file ucsinfer finds that meets this criteria, it creates a record
|
Gather scans the filesystem in two passes: first, the directory tree is
|
||||||
in an output dataset CSV file. The dataset file has two columns: the first
|
walked by os.walk and a list of filenames that meet the above name criteria
|
||||||
is the CatID indicated for the file, and the second is the embedded file
|
is compiled. After this list is compiled, each file is scanned one-by-one
|
||||||
description for the file as returned by ffprobe.
|
with ffprobe to obtain its "description" metadata; if this isn't present,
|
||||||
|
the parsed 'fxname' of the file becomes the description.
|
||||||
"""
|
"""
|
||||||
logger.debug("GATHER mode")
|
logger.debug("GATHER mode")
|
||||||
|
|
||||||
types = ['.wav', '.flac']
|
types = ['.wav', '.flac']
|
||||||
table = csv.writer(outfile)
|
|
||||||
logger.debug(f"Loading category list...")
|
logger.debug(f"Loading category list...")
|
||||||
catid_list = [cat.catid for cat in load_ucs()]
|
ucs = load_ucs()
|
||||||
|
catid_list = [cat.catid for cat in ucs]
|
||||||
|
|
||||||
scan_list = []
|
scan_list = []
|
||||||
for path in paths:
|
for path in paths:
|
||||||
@@ -148,10 +153,18 @@ def gather(paths, outfile):
|
|||||||
|
|
||||||
logger.info(f"Found {len(scan_list)} files to process.")
|
logger.info(f"Found {len(scan_list)} files to process.")
|
||||||
|
|
||||||
for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
|
def scan_metadata():
|
||||||
if desc := ffmpeg_description(pair[1]):
|
for pair in tqdm.tqdm(scan_list, unit='files'):
|
||||||
table.writerow([pair[0], desc])
|
if desc := ffmpeg_description(pair[1]):
|
||||||
|
yield desc, str(pair[0])
|
||||||
|
else:
|
||||||
|
comps = parse_ucs(os.path.basename(pair[1]), catid_list)
|
||||||
|
assert comps
|
||||||
|
yield comps.fx_name, str(pair[0])
|
||||||
|
|
||||||
|
dataset = build_sentence_class_dataset(scan_metadata())
|
||||||
|
dataset.save_to_disk(outfile)
|
||||||
|
|
||||||
|
|
||||||
@ucsinfer.command('finetune')
|
@ucsinfer.command('finetune')
|
||||||
def finetune():
|
def finetune():
|
||||||
|
Reference in New Issue
Block a user