Refactoring of gather
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
import os
|
||||
# import csv
|
||||
import logging
|
||||
from subprocess import CalledProcessError
|
||||
from itertools import chain
|
||||
|
||||
import tqdm
|
||||
@@ -9,7 +8,8 @@ import click
|
||||
# from tabulate import tabulate, SEPARATING_LINE
|
||||
|
||||
from .inference import InferenceContext, load_ucs
|
||||
from .gather import build_sentence_class_dataset, print_dataset_stats
|
||||
from .gather import (build_sentence_class_dataset, print_dataset_stats,
|
||||
ucs_definitions_generator, scan_metadata)
|
||||
from .recommend import print_recommendation
|
||||
from .util import ffmpeg_description, parse_ucs
|
||||
|
||||
@@ -162,7 +162,7 @@ def gather(ctx, paths, out, ucs_data):
|
||||
logger.debug(f"Loading category list...")
|
||||
ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
|
||||
|
||||
scan_list = []
|
||||
scan_list: list[tuple[str,str]] = []
|
||||
catid_list = [cat.catid for cat in ucs]
|
||||
|
||||
if ucs_data:
|
||||
@@ -187,37 +187,15 @@ def gather(ctx, paths, out, ucs_data):
|
||||
|
||||
logger.info(f"Found {len(scan_list)} files to process.")
|
||||
|
||||
def scan_metadata():
|
||||
for pair in tqdm.tqdm(scan_list, unit='files'):
|
||||
logger.info(f"Scanning file with ffprobe: {pair[1]}")
|
||||
try:
|
||||
desc = ffmpeg_description(pair[1])
|
||||
except CalledProcessError as e:
|
||||
logger.error(f"ffprobe returned error (){e.returncode}): " \
|
||||
+ e.stderr)
|
||||
continue
|
||||
|
||||
if desc:
|
||||
yield desc, str(pair[0])
|
||||
else:
|
||||
comps = parse_ucs(os.path.basename(pair[1]), catid_list)
|
||||
assert comps
|
||||
yield comps.fx_name, str(pair[0])
|
||||
|
||||
def ucs_metadata():
|
||||
for cat in ucs:
|
||||
yield cat.explanations, cat.catid
|
||||
yield ", ".join(cat.synonymns), cat.catid
|
||||
|
||||
logger.info("Building dataset...")
|
||||
|
||||
dataset = build_sentence_class_dataset(chain(scan_metadata(),
|
||||
ucs_metadata()),
|
||||
dataset = build_sentence_class_dataset(
|
||||
chain(scan_metadata(scan_list, catid_list),
|
||||
ucs_definitions_generator(ucs)),
|
||||
catid_list)
|
||||
|
||||
|
||||
logger.info(f"Saving dataset to disk at {out}")
|
||||
print_dataset_stats(dataset)
|
||||
print_dataset_stats(dataset, catid_list)
|
||||
dataset.save_to_disk(out)
|
||||
|
||||
|
||||
|
@@ -1,9 +1,42 @@
|
||||
from .inference import Ucs
|
||||
from .util import ffmpeg_description, parse_ucs
|
||||
|
||||
from subprocess import CalledProcessError
|
||||
import os.path
|
||||
|
||||
from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
|
||||
from datasets.dataset_dict import DatasetDict
|
||||
|
||||
from typing import Iterator
|
||||
from typing import Iterator, Generator
|
||||
|
||||
from tabulate import tabulate
|
||||
import logging
|
||||
import tqdm
|
||||
|
||||
def scan_metadata(scan_list: list[tuple[str,str]], catid_list: list[str]):
|
||||
logger = logging.getLogger('ucsinfer')
|
||||
for pair in tqdm.tqdm(scan_list, unit='files'):
|
||||
logger.info(f"Scanning file with ffprobe: {pair[1]}")
|
||||
try:
|
||||
desc = ffmpeg_description(pair[1])
|
||||
except CalledProcessError as e:
|
||||
logger.error(f"ffprobe returned error (){e.returncode}): " \
|
||||
+ e.stderr)
|
||||
continue
|
||||
|
||||
if desc:
|
||||
yield desc, str(pair[0])
|
||||
else:
|
||||
comps = parse_ucs(os.path.basename(pair[1]), catid_list)
|
||||
assert comps
|
||||
yield comps.fx_name, str(pair[0])
|
||||
|
||||
|
||||
def ucs_definitions_generator(ucs: list[Ucs]) \
|
||||
-> Generator[tuple[str,str],None, None]:
|
||||
for cat in ucs:
|
||||
yield cat.explanations, cat.catid
|
||||
yield ", ".join(cat.synonymns), cat.catid
|
||||
|
||||
def print_dataset_stats(dataset: DatasetDict, catlist: list[str]):
|
||||
|
||||
|
Reference in New Issue
Block a user