Compare commits

...

5 Commits

SHA1 Message Date
3ca921ad02 Cleaned up code and refactored to new file 2025-10-14 10:31:34 -07:00
5fc57cf7c8 CSV import implementation 2025-10-14 10:26:57 -07:00
2fa5e4575d twiddle 2025-10-14 09:09:48 -07:00
e5698fec7b Plumbing for CSV import from online logs 2025-09-26 21:08:41 -07:00
c75365b856 TODO 2025-09-26 20:52:08 -07:00
4 changed files with 103 additions and 47 deletions

View File

@@ -18,9 +18,8 @@
 - Use (anchor, positive) pairs to train a new model
 - Use (sentence) + class labels to train a new model
 - Implement BatchAllTripletLoss
-- Implement a two-phase training regime
-  1. Train with anchored definitions then...
-  2. Train with class labels
+- Train with anchored definitions and/or...
+- Train with class labels
 
 ## Evaluate
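
For context on the `BatchAllTripletLoss` item above: sentence-transformers ships this loss, and it trains on (sentence, class-label) batches by forming every valid (anchor, positive, negative) triplet inside each batch. A minimal sketch using the classic `model.fit` training API; the model name, sentences, and labels are illustrative, not taken from this repository:

```python
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer("all-MiniLM-L6-v2")  # illustrative base model

# "(sentence) + class labels": the integer label stands in for a UCS category
train_examples = [
    InputExample(texts=["glass shattering on concrete"], label=12),
    InputExample(texts=["slow footsteps on gravel"], label=7),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

# BatchAllTripletLoss derives all valid triplets within a batch from the labels
train_loss = losses.BatchAllTripletLoss(model=model)
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1)
```

This matches the "(sentence) + class labels" bullet: no pairs are mined offline; the loss builds triplets from the labels at batch time.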

View File

@@ -1,13 +1,11 @@
 import os
-# import csv
 import logging
 from itertools import chain
-import tqdm
 import click
-# from tabulate import tabulate, SEPARATING_LINE
 from .inference import InferenceContext, load_ucs
+from .import_csv import csv_to_data
 from .gather import (build_sentence_class_dataset, print_dataset_stats,
                      ucs_definitions_generator, scan_metadata, walk_path)
 from .recommend import print_recommendation
@@ -22,6 +20,7 @@ formatter = logging.Formatter(
 stream_handler.setFormatter(formatter)
 logger.addHandler(stream_handler)
 
+
 @click.group(epilog="For more information see "
              "<https://git.squad51.us/jamie/ucsinfer>")
 @click.option('--verbose', '-v', flag_value=True, help='Verbose output')
@@ -52,8 +51,8 @@ def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
    stream_handler.setLevel(logging.WARNING)
 
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-    logger.info("Setting TOKENIZERS_PARALLELISM environment variable to `false"
-                " explicitly")
+    logger.info("Setting TOKENIZERS_PARALLELISM environment variable to "
+                "`false` explicitly")
 
    ctx.ensure_object(dict)
    ctx.obj['model_cache'] = not no_model_cache
@@ -121,7 +120,8 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
            text = os.path.basename(path)
 
        while True:
-            retval = print_recommendation(path, text, inference_ctx, interactive)
+            retval = print_recommendation(
+                path, text, inference_ctx, interactive)
            if not retval:
                break
            if retval[0] is False:
@@ -137,16 +137,50 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
            break
 
-@ucsinfer.command('gather')
+@ucsinfer.command('csv')
+@click.option('--filename-col', default="FileName",
+              help="Heading or index of the column containing filenames",
+              show_default=True)
+@click.option('--description-col', default="TrackDescription",
+              help="Heading or index of the column containing descriptions",
+              show_default=True)
 @click.option('--out', default='dataset/', show_default=True)
-# @click.option('--ucs-data', flag_value=True, help="Create a dataset based "
-#               "on the UCS category explanations and synonymns (PATHS will "
-#               "be ignored.)")
 @click.argument('paths', nargs=-1)
 @click.pass_context
-def gather(ctx, paths, out, ucs_data):
+def import_csv(ctx, paths: list[str], out, filename_col, description_col):
     """
-    Scan files to build a training dataset
+    Scan training data from CSV files
+
+    `csv` is used to build a training dataset for finetuning the selected
+    model, like the `gather` command, except that instead of scanning the
+    file system it builds a dataset from descriptions and UCS filenames in
+    the columns of a CSV file.
+    """
+    logger.debug("CSV mode")
+    logger.debug("Loading category list...")
+    ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
+    catid_list = [cat.catid for cat in ucs]
+
+    logger.info("Building dataset from csv...")
+    dataset = build_sentence_class_dataset(
+        chain(csv_to_data(paths, description_col, filename_col, catid_list),
+              ucs_definitions_generator(ucs)), catid_list)
+
+    logger.info(f"Saving dataset to disk at {out}")
+    print_dataset_stats(dataset, catid_list)
+    dataset.save_to_disk(out)
+
+
+@ucsinfer.command('gather')
+@click.option('--out', default='dataset/', show_default=True)
+@click.argument('paths', nargs=-1)
+@click.pass_context
+def gather(ctx, paths, out):
+    """
+    Scan training data from audio files
 
     `gather` is used to build a training dataset for finetuning the selected
     model. Description sentences and UCS categories are collected from '.wav'
@@ -167,16 +201,12 @@ def gather(ctx, paths, out, ucs_data):
    scan_list: list[tuple[str, str]] = []
    catid_list = [cat.catid for cat in ucs]
 
-    if ucs_data:
-        logger.info('Creating dataset for UCS categories instead of from PATH')
-        paths = []
-
    for path in paths:
        scan_list += walk_path(path, catid_list)
 
    logger.info(f"Found {len(scan_list)} files to process.")
-    logger.info("Building dataset...")
+    logger.info("Building dataset files...")
    dataset = build_sentence_class_dataset(
        chain(scan_metadata(scan_list, catid_list),
@@ -187,6 +217,7 @@ def gather(ctx, paths, out, ucs_data):
    print_dataset_stats(dataset, catid_list)
    dataset.save_to_disk(out)
 
+
 @ucsinfer.command('qualify')
 def qualify():
     """
@@ -208,7 +239,6 @@ def finetune(ctx):
    logger.debug("FINETUNE mode")
 
-
 @ucsinfer.command('evaluate')
 @click.argument('dataset', default='dataset/')
 @click.pass_context
@@ -220,6 +250,5 @@ def evaluate(ctx, dataset, offset, limit):
    logger.warning("Model evaluation is not currently implemented")
 
-
 if __name__ == '__main__':
     ucsinfer(obj={})
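
With the options added above, the new subcommand would be invoked roughly as `ucsinfer csv --filename-col FileName --description-col TrackDescription --out dataset/ online_logs.csv`. A quick smoke-test sketch using click's built-in test runner; the import path and the CSV file name are assumptions, not taken from this repository:

```python
# Hypothetical smoke test for the new `csv` subcommand; assumes the click
# group is importable as ucsinfer.__main__.ucsinfer and that online_logs.csv
# exists with FileName and TrackDescription columns.
from click.testing import CliRunner
from ucsinfer.__main__ import ucsinfer

runner = CliRunner()
result = runner.invoke(ucsinfer, [
    'csv',
    '--filename-col', 'FileName',
    '--description-col', 'TrackDescription',
    '--out', 'dataset/',
    'online_logs.csv',
], obj={})
assert result.exit_code == 0, result.output
```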

ucsinfer/import_csv.py Normal file
View File

@@ -0,0 +1,28 @@
+import csv
+from typing import Generator
+
+from .util import parse_ucs
+
+
+def csv_to_data(paths, description_key, filename_key,
+                catid_list) -> Generator[tuple[str, str], None, None]:
+    """
+    Accepts a list of paths and returns an iterator of (sentence, class)
+    tuples.
+    """
+    for path in paths:
+        with open(path, 'r') as f:
+            records = csv.DictReader(f)
+            assert filename_key in records.fieldnames, \
+                (f"Filename key `{filename_key}` not present in file "
+                 f"{path}")
+            assert description_key in records.fieldnames, \
+                (f"Description key `{description_key}` not present in "
+                 f"file {path}")
+
+            for record in records:
+                ucs_comps = parse_ucs(record[filename_key], catid_list)
+                if ucs_comps:
+                    yield (record[description_key], ucs_comps.cat_id)
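
Standalone, `csv_to_data` can be driven directly. A sketch assuming the package is importable and that `parse_ucs` recognizes UCS-style filenames; the CSV path, column names, and CatIDs below are illustrative:

```python
# Hypothetical driver for csv_to_data; online_logs.csv and the CatIDs here
# are placeholders, not data from this repository.
from ucsinfer.import_csv import csv_to_data

catid_list = ['GLASImpt', 'METLImpt']  # illustrative UCS CatIDs
for sentence, cat_id in csv_to_data(['online_logs.csv'],
                                    'TrackDescription', 'FileName',
                                    catid_list):
    print(f'{cat_id}\t{sentence}')
```

Rows whose filenames do not parse to a known CatID are silently skipped, so the generator only yields usable (sentence, class) pairs.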