From 615d8ab27963786d20a65f705cb7b188b2f15cec Mon Sep 17 00:00:00 2001 From: Jamie Hardt Date: Wed, 10 Sep 2025 23:49:33 -0700 Subject: [PATCH] Refactoring of gather --- ucsinfer/__main__.py | 19 ++----------------- ucsinfer/gather.py | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py index 22974da..986577c 100644 --- a/ucsinfer/__main__.py +++ b/ucsinfer/__main__.py @@ -9,7 +9,7 @@ import click from .inference import InferenceContext, load_ucs from .gather import (build_sentence_class_dataset, print_dataset_stats, - ucs_definitions_generator, scan_metadata) + ucs_definitions_generator, scan_metadata, walk_path) from .recommend import print_recommendation from .util import ffmpeg_description, parse_ucs @@ -157,8 +157,6 @@ def gather(ctx, paths, out, ucs_data): """ logger.debug("GATHER mode") - types = ['.wav', '.flac'] - logger.debug(f"Loading category list...") ucs = load_ucs(full_ucs=ctx.obj['complete_ucs']) @@ -169,22 +167,9 @@ def gather(ctx, paths, out, ucs_data): logger.info('Creating dataset for UCS categories instead of from PATH') paths = [] - walker_p = tqdm.tqdm(total=None, unit='dir', desc="Walking filesystem...") for path in paths: - for dirpath, _, filenames in os.walk(path): - logger.info(f"Walking directory {dirpath}") - for filename in filenames: - walker_p.update() - root, ext = os.path.splitext(filename) - if ext not in types or filename.startswith("._"): - continue + scan_list += walk_path(path, catid_list) - if (ucs_components := parse_ucs(root, catid_list)): - p = os.path.join(dirpath, filename) - logger.info(f"Adding path to scan list {p}") - scan_list.append((ucs_components.cat_id, p)) - walker_p.close() - logger.info(f"Found {len(scan_list)} files to process.") logger.info("Building dataset...") diff --git a/ucsinfer/gather.py b/ucsinfer/gather.py index b0ab51e..3c87d98 100644 --- a/ucsinfer/gather.py +++ b/ucsinfer/gather.py @@ -13,6 +13,27 @@ from tabulate 
def walk_path(path: str, catid_list: list[str]) -> list[tuple[str, str]]:
    """
    Walk `path` recursively and collect audio files with UCS-parseable names.

    Every file found by `os.walk` whose extension is exactly '.wav' or
    '.flac' (the comparison is case-sensitive) and whose name does not start
    with "._" has its extension-less name handed to `parse_ucs`. Files for
    which `parse_ucs` returns a truthy result are added to the result.

    :param path: root directory handed to `os.walk`
    :param catid_list: passed through unchanged to `parse_ucs`
    :returns: a list of `(cat_id, file_path)` tuples in `os.walk` order
    """
    # NOTE(review): relies on `os` and `parse_ucs` being in scope at module
    # level in gather.py -- confirm the module imports both.
    types = ('.wav', '.flac')
    logger = logging.getLogger('ucsinfer')
    scan_list = []

    # Use the bar as a context manager so it is always closed, including when
    # the walk raises; the bare constructor call left it open on return.
    with tqdm.tqdm(total=None, unit='dir',
                   desc="Walking filesystem...") as walker_p:
        for dirpath, _, filenames in os.walk(path):
            logger.info(f"Walking directory {dirpath}")
            for filename in filenames:
                # The bar advances once per file seen, whether or not the
                # file ends up in the scan list.
                walker_p.update()
                root, ext = os.path.splitext(filename)
                if ext not in types or filename.startswith("._"):
                    continue

                if (ucs_components := parse_ucs(root, catid_list)):
                    p = os.path.join(dirpath, filename)
                    logger.info(f"Adding path to scan list {p}")
                    scan_list.append((ucs_components.cat_id, p))

    return scan_list