Compare commits

...

5 Commits

Author SHA1 Message Date
615d8ab279 Refactoring of gather 2025-09-10 23:49:33 -07:00
9a887c4ed5 Refactoring of gather 2025-09-10 23:38:33 -07:00
63b140209b Command-line entry point 2025-09-10 22:23:57 -07:00
d181ac73b1 Made prompt shorter 2025-09-10 22:08:16 -07:00
bddce23c76 Added to prompt help 2025-09-10 22:07:53 -07:00
4 changed files with 68 additions and 48 deletions

View File

@@ -26,3 +26,6 @@ build-backend = "poetry.core.masonry.api"
ipython = "^9.4.0"
jupyter = "^1.1.1"
[tool.poetry.scripts]
ucsinfer = "ucsinfer.__main__:ucsinfer"

View File

@@ -1,7 +1,6 @@
import os
# import csv
import logging
from subprocess import CalledProcessError
from itertools import chain
import tqdm
@@ -9,7 +8,8 @@ import click
# from tabulate import tabulate, SEPARATING_LINE
from .inference import InferenceContext, load_ucs
from .gather import build_sentence_class_dataset, print_dataset_stats
from .gather import (build_sentence_class_dataset, print_dataset_stats,
ucs_definitions_generator, scan_metadata, walk_path)
from .recommend import print_recommendation
from .util import ffmpeg_description, parse_ucs
@@ -157,67 +157,30 @@ def gather(ctx, paths, out, ucs_data):
"""
logger.debug("GATHER mode")
types = ['.wav', '.flac']
logger.debug(f"Loading category list...")
ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
scan_list = []
scan_list: list[tuple[str,str]] = []
catid_list = [cat.catid for cat in ucs]
if ucs_data:
logger.info('Creating dataset for UCS categories instead of from PATH')
paths = []
walker_p = tqdm.tqdm(total=None, unit='dir', desc="Walking filesystem...")
for path in paths:
for dirpath, _, filenames in os.walk(path):
logger.info(f"Walking directory {dirpath}")
for filename in filenames:
walker_p.update()
root, ext = os.path.splitext(filename)
if ext not in types or filename.startswith("._"):
continue
if (ucs_components := parse_ucs(root, catid_list)):
p = os.path.join(dirpath, filename)
logger.info(f"Adding path to scan list {p}")
scan_list.append((ucs_components.cat_id, p))
walker_p.close()
scan_list += walk_path(path, catid_list)
logger.info(f"Found {len(scan_list)} files to process.")
def scan_metadata():
for pair in tqdm.tqdm(scan_list, unit='files'):
logger.info(f"Scanning file with ffprobe: {pair[1]}")
try:
desc = ffmpeg_description(pair[1])
except CalledProcessError as e:
logger.error(f"ffprobe returned error (){e.returncode}): " \
+ e.stderr)
continue
if desc:
yield desc, str(pair[0])
else:
comps = parse_ucs(os.path.basename(pair[1]), catid_list)
assert comps
yield comps.fx_name, str(pair[0])
def ucs_metadata():
for cat in ucs:
yield cat.explanations, cat.catid
yield ", ".join(cat.synonymns), cat.catid
logger.info("Building dataset...")
dataset = build_sentence_class_dataset(chain(scan_metadata(),
ucs_metadata()),
catid_list)
dataset = build_sentence_class_dataset(
chain(scan_metadata(scan_list, catid_list),
ucs_definitions_generator(ucs)),
catid_list)
logger.info(f"Saving dataset to disk at {out}")
print_dataset_stats(dataset)
print_dataset_stats(dataset, catid_list)
dataset.save_to_disk(out)

View File

@@ -1,9 +1,63 @@
from .inference import Ucs
from .util import ffmpeg_description, parse_ucs
from subprocess import CalledProcessError
import os.path
from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
from datasets.dataset_dict import DatasetDict
from typing import Iterator
from typing import Iterator, Generator
from tabulate import tabulate
import logging
import tqdm
def walk_path(path:str, catid_list) -> list[tuple[str,str]]:
    """
    Recursively walk `path` and collect audio files whose names parse as UCS.

    Only files with a `.wav` or `.flac` extension are considered, and files
    whose names start with `._` are ignored. The file name, minus its
    extension, is handed to `parse_ucs`; when that returns a truthy result
    the file is appended to the scan list.

    :param path: root directory to walk.
    :param catid_list: category IDs, passed through to `parse_ucs`.
    :returns: a list of `(cat_id, file_path)` tuples in walk order.
    """
    types = ['.wav', '.flac']
    logger = logging.getLogger('ucsinfer')
    scan_list = []

    # Using the progress bar as a context manager guarantees it is closed
    # (and the terminal line released) even if the walk raises part-way.
    # NOTE(review): the bar is advanced once per *file* although its unit is
    # labelled 'dir' -- confirm which of the two is intended.
    with tqdm.tqdm(total=None, unit='dir',
                   desc="Walking filesystem...") as walker_p:
        for dirpath, _, filenames in os.walk(path):
            logger.info(f"Walking directory {dirpath}")
            for filename in filenames:
                walker_p.update()
                root, ext = os.path.splitext(filename)

                # The extension match is case-sensitive.
                if ext not in types or filename.startswith("._"):
                    continue

                if (ucs_components := parse_ucs(root, catid_list)):
                    p = os.path.join(dirpath, filename)
                    logger.info(f"Adding path to scan list {p}")
                    scan_list.append((ucs_components.cat_id, p))

    return scan_list
def scan_metadata(scan_list: list[tuple[str,str]], catid_list: list[str]):
    """
    Yield a `(text, cat_id)` pair for every file in `scan_list`.

    Each file is passed to `ffmpeg_description`. If that returns a truthy
    description it is used as the text; otherwise the file's base name is
    parsed with `parse_ucs` and its `fx_name` is used instead. Files for
    which `ffmpeg_description` raises `CalledProcessError` are logged and
    skipped.

    :param scan_list: `(cat_id, file_path)` tuples to process.
    :param catid_list: category IDs, passed through to `parse_ucs`.
    :returns: a generator of `(text, cat_id)` tuples; `cat_id` is always
        converted with `str()`.
    """
    logger = logging.getLogger('ucsinfer')
    for pair in tqdm.tqdm(scan_list, unit='files'):
        logger.info(f"Scanning file with ffprobe: {pair[1]}")
        try:
            desc = ffmpeg_description(pair[1])
        except CalledProcessError as e:
            # stderr is None when it was not captured and bytes when the
            # process ran in binary mode; normalise it so that logging the
            # failure can never itself raise.
            stderr = e.stderr
            if isinstance(stderr, bytes):
                stderr = stderr.decode(errors='replace')
            logger.error("ffprobe returned error (%s): %s",
                         e.returncode, stderr or "")
            continue

        if desc:
            yield desc, str(pair[0])
        else:
            # NOTE(review): the base name passed here still carries its file
            # extension -- confirm parse_ucs accepts that form.
            comps = parse_ucs(os.path.basename(pair[1]), catid_list)
            # Invariant check: entries in scan_list are expected to have
            # parseable names, so a falsy result here is a programming error.
            assert comps
            yield comps.fx_name, str(pair[0])
def ucs_definitions_generator(ucs: list[Ucs]) \
        -> Generator[tuple[str,str],None, None]:
    """
    Yield two `(text, catid)` pairs for every category in `ucs`.

    For each category the first pair carries its `explanations` value and
    the second carries its `synonymns` joined with ", ".

    :param ucs: the categories to emit pairs for.
    :returns: a generator of `(text, catid)` tuples.
    """
    for entry in ucs:
        explanation_text = entry.explanations
        yield explanation_text, entry.catid

        synonym_text = ", ".join(entry.synonymns)
        yield synonym_text, entry.catid
def print_dataset_stats(dataset: DatasetDict, catlist: list[str]):

View File

@@ -31,7 +31,7 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
print(f"- {i}: {r} ({cat}-{subcat})")
if interactive_rename and path is not None:
response = input("(n#), t [text], c [cat], ?, q > ")
response = input("(n#), t, c, b, ?, q > ")
if m := match(r'^([0-9]+)', response):
selection = int(m.group(1))