Compare commits
	
		
			2 Commits
		
	
	
		
			336d6f013e
			...
			0a2aaa2a22
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 0a2aaa2a22 | |||
| d855ba4c78 | 
							
								
								
									
										7
									
								
								TODO.md
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								TODO.md
									
									
									
									
									
								
							| @@ -4,12 +4,17 @@ | |||||||
|  |  | ||||||
| ## Gather | ## Gather | ||||||
|  |  | ||||||
| - Add "source" column for tracking provenance | - Maybe more dataset configurations | ||||||
|  |  | ||||||
|  | ## Validate  | ||||||
|  |  | ||||||
|  | A function for validating a dataset for finetuning | ||||||
|  |  | ||||||
| ## Fine-tune | ## Fine-tune | ||||||
|  |  | ||||||
| - Implement | - Implement | ||||||
|  |  | ||||||
|  |  | ||||||
| ## Evaluate | ## Evaluate | ||||||
|  |  | ||||||
| - Print more information about the dataset coverage of UCS | - Print more information about the dataset coverage of UCS | ||||||
|   | |||||||
| @@ -13,7 +13,8 @@ dependencies = [ | |||||||
|     "tqdm (>=4.67.1,<5.0.0)", |     "tqdm (>=4.67.1,<5.0.0)", | ||||||
|     "platformdirs (>=4.3.8,<5.0.0)", |     "platformdirs (>=4.3.8,<5.0.0)", | ||||||
|     "click (>=8.2.1,<9.0.0)", |     "click (>=8.2.1,<9.0.0)", | ||||||
|     "tabulate (>=0.9.0,<0.10.0)" |     "tabulate (>=0.9.0,<0.10.0)", | ||||||
|  |     "datasets (>=4.0.0,<5.0.0)" | ||||||
| ] | ] | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -3,11 +3,14 @@ import sys | |||||||
| import csv | import csv | ||||||
| import logging | import logging | ||||||
|  |  | ||||||
|  | from typing import Generator | ||||||
|  |  | ||||||
| import tqdm | import tqdm | ||||||
| import click | import click | ||||||
| from tabulate import tabulate, SEPARATING_LINE | from tabulate import tabulate, SEPARATING_LINE | ||||||
|  |  | ||||||
| from .inference import InferenceContext, load_ucs | from .inference import InferenceContext, load_ucs | ||||||
|  | from .gather import build_sentence_class_dataset  | ||||||
| from .recommend import print_recommendation | from .recommend import print_recommendation | ||||||
| from .util import ffmpeg_description, parse_ucs | from .util import ffmpeg_description, parse_ucs | ||||||
|  |  | ||||||
| @@ -117,22 +120,24 @@ def gather(paths, outfile): | |||||||
|     """ |     """ | ||||||
|     Scan files to build a training dataset at PATH |     Scan files to build a training dataset at PATH | ||||||
|      |      | ||||||
|     The `gather` command walks the directory hierarchy for each path in PATHS  |     The `gather` command is used to build a training dataset for finetuning the  | ||||||
|     and looks for .wav and .flac files that are named according to the UCS  |     selected model. Description sentences and UCS categories are collected from  | ||||||
|     file naming guidelines, with at least a CatID and FX Name, divided by an  |     '.wav' and '.flac' files on-disk that have valid UCS filenames and assigned  | ||||||
|     underscore. |     CatIDs, and this information is recorded into a HuggingFace dataset. | ||||||
|      |      | ||||||
|     For every file ucsinfer finds that meets this criteria, it creates a record |     Gather scans the filesystem in two passes: first, the directory tree is  | ||||||
|     in an output dataset CSV file. The dataset file has two columns: the first |     walked by os.walk and a list of filenames that meet the above name criteria  | ||||||
|     is the CatID indicated for the file, and the second is the embedded file  |     is compiled. After this list is compiled, each file is scanned one-by-one | ||||||
|     description for the file as returned by ffprobe. |     with ffprobe to obtain its "description" metadata; if this isn't present, | ||||||
|  |     the parsed 'fxname' of the file becomes the description. | ||||||
|     """ |     """ | ||||||
|     logger.debug("GATHER mode") |     logger.debug("GATHER mode") | ||||||
|  |  | ||||||
|     types = ['.wav', '.flac'] |     types = ['.wav', '.flac'] | ||||||
|     table = csv.writer(outfile) |  | ||||||
|     logger.debug(f"Loading category list...") |     logger.debug(f"Loading category list...") | ||||||
|     catid_list = [cat.catid for cat in load_ucs()] |     ucs = load_ucs() | ||||||
|  |     catid_list = [cat.catid for cat in ucs] | ||||||
|  |  | ||||||
|     scan_list = [] |     scan_list = [] | ||||||
|     for path in paths: |     for path in paths: | ||||||
| @@ -148,9 +153,17 @@ def gather(paths, outfile): | |||||||
|  |  | ||||||
|     logger.info(f"Found {len(scan_list)} files to process.") |     logger.info(f"Found {len(scan_list)} files to process.") | ||||||
|  |  | ||||||
|     for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr): |     def scan_metadata(): | ||||||
|  |         for pair in tqdm.tqdm(scan_list, unit='files'): | ||||||
|             if desc := ffmpeg_description(pair[1]): |             if desc := ffmpeg_description(pair[1]): | ||||||
|             table.writerow([pair[0], desc]) |                 yield desc, str(pair[0]) | ||||||
|  |             else: | ||||||
|  |                 comps = parse_ucs(os.path.basename(pair[1]), catid_list) | ||||||
|  |                 assert comps | ||||||
|  |                 yield comps.fx_name, str(pair[0]) | ||||||
|  |  | ||||||
|  |     dataset = build_sentence_class_dataset(scan_metadata()) | ||||||
|  |     dataset.save_to_disk(outfile) | ||||||
|      |      | ||||||
|  |  | ||||||
| @ucsinfer.command('finetune') | @ucsinfer.command('finetune') | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user