From 9a887c4ed59e5aec2ac85dba1dea96126cb0f800 Mon Sep 17 00:00:00 2001
From: Jamie Hardt
Date: Wed, 10 Sep 2025 23:38:33 -0700
Subject: [PATCH] Refactoring of gather

---
 ucsinfer/__main__.py | 40 +++++++++-------------------------------
 ucsinfer/gather.py   | 35 ++++++++++++++++++++++++++++++++++-
 2 files changed, 43 insertions(+), 32 deletions(-)

diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py
index 4ca2550..22974da 100644
--- a/ucsinfer/__main__.py
+++ b/ucsinfer/__main__.py
@@ -1,7 +1,6 @@
 import os
 # import csv
 import logging
-from subprocess import CalledProcessError
 from itertools import chain
 
 import tqdm
@@ -9,7 +8,8 @@ import click
 # from tabulate import tabulate, SEPARATING_LINE
 
 from .inference import InferenceContext, load_ucs
-from .gather import build_sentence_class_dataset, print_dataset_stats
+from .gather import (build_sentence_class_dataset, print_dataset_stats,
+                     ucs_definitions_generator, scan_metadata)
 from .recommend import print_recommendation
 from .util import ffmpeg_description, parse_ucs
 
@@ -162,7 +162,7 @@ def gather(ctx, paths, out, ucs_data):
     logger.debug(f"Loading category list...")
     ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
 
-    scan_list = []
+    scan_list: list[tuple[str,str]] = []
     catid_list = [cat.catid for cat in ucs]
 
     if ucs_data:
@@ -184,40 +184,18 @@ def gather(ctx, paths, out, ucs_data):
                 logger.info(f"Adding path to scan list {p}")
                 scan_list.append((ucs_components.cat_id, p))
         walker_p.close()
-    
+
     logger.info(f"Found {len(scan_list)} files to process.")
 
-    def scan_metadata():
-        for pair in tqdm.tqdm(scan_list, unit='files'):
-            logger.info(f"Scanning file with ffprobe: {pair[1]}")
-            try:
-                desc = ffmpeg_description(pair[1])
-            except CalledProcessError as e:
-                logger.error(f"ffprobe returned error (){e.returncode}): " \
-                             + e.stderr)
-                continue
-
-            if desc:
-                yield desc, str(pair[0])
-            else:
-                comps = parse_ucs(os.path.basename(pair[1]), catid_list)
-                assert comps
-                yield comps.fx_name, str(pair[0])
-
-    def ucs_metadata():
-        for cat in ucs:
-            yield cat.explanations, cat.catid
-            yield ", ".join(cat.synonymns), cat.catid
 
-    logger.info("Building dataset...")
-    dataset = build_sentence_class_dataset(chain(scan_metadata(),
-                                                 ucs_metadata()),
-                                           catid_list)
+    dataset = build_sentence_class_dataset(
+        chain(scan_metadata(scan_list, catid_list),
+              ucs_definitions_generator(ucs)),
+        catid_list)
 
-    logger.info(f"Saving dataset to disk at {out}")
-    print_dataset_stats(dataset)
+    print_dataset_stats(dataset, catid_list)
 
     dataset.save_to_disk(out)
diff --git a/ucsinfer/gather.py b/ucsinfer/gather.py
index 9d20cd8..b0ab51e 100644
--- a/ucsinfer/gather.py
+++ b/ucsinfer/gather.py
@@ -1,9 +1,42 @@
+from .inference import Ucs
+from .util import ffmpeg_description, parse_ucs
+
+from subprocess import CalledProcessError
+import os.path
+
 from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
 from datasets.dataset_dict import DatasetDict
-from typing import Iterator
+from typing import Iterator, Generator
 from tabulate import tabulate
 
+import logging
+import tqdm
+
+def scan_metadata(scan_list: list[tuple[str,str]], catid_list: list[str]):
+    logger = logging.getLogger('ucsinfer')
+    for pair in tqdm.tqdm(scan_list, unit='files'):
+        logger.info(f"Scanning file with ffprobe: {pair[1]}")
+        try:
+            desc = ffmpeg_description(pair[1])
+        except CalledProcessError as e:
+            logger.error(f"ffprobe returned error ({e.returncode}): " \
+                         + e.stderr)
+            continue
+
+        if desc:
+            yield desc, str(pair[0])
+        else:
+            comps = parse_ucs(os.path.basename(pair[1]), catid_list)
+            assert comps
+            yield comps.fx_name, str(pair[0])
+
+
+def ucs_definitions_generator(ucs: list[Ucs]) \
+        -> Generator[tuple[str,str], None, None]:
+    for cat in ucs:
+        yield cat.explanations, cat.catid
+        yield ", ".join(cat.synonymns), cat.catid
 
 
 def print_dataset_stats(dataset: DatasetDict, catlist: list[str]):