Added better error reporting to gather
And a bunch of options to control behavior
@@ -1,11 +1,11 @@
 import os
-import csv
+# import csv
 import logging
+from subprocess import CalledProcessError
 
 import tqdm
 import click
-from tabulate import tabulate, SEPARATING_LINE
+# from tabulate import tabulate, SEPARATING_LINE
 
 from .inference import InferenceContext, load_ucs
 from .gather import build_sentence_class_dataset
@@ -30,8 +30,11 @@ logger.addHandler(stream_handler)
               help="Select the sentence_transformer model to use")
 @click.option('--no-model-cache', flag_value=True,
               help="Don't use local model cache")
+@click.option('--complete-ucs', flag_value=True, default=False,
+              help="Use all UCS categories. By default, all 'FOLEY' and "
+              "'ARCHIVED' UCS categories are excluded from all functions.")
 @click.pass_context
-def ucsinfer(ctx, verbose, no_model_cache, model):
+def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
     """
     Tools for applying UCS categories to sounds using large-language Models
     """
@@ -50,10 +53,17 @@ def ucsinfer(ctx, verbose, no_model_cache, model):
     ctx.ensure_object(dict)
     ctx.obj['model_cache'] = not no_model_cache
     ctx.obj['model_name'] = model
+    ctx.obj['complete_ucs'] = complete_ucs
 
     if no_model_cache:
         logger.info("Model cache inhibited by config")
 
+    if complete_ucs:
+        logger.info("Using complete UCS category list")
+    else:
+        logger.info("Non-descriptive UCS categories will be excluded. Turn "
+                    "this option off by passing --complete-ucs.")
+
     logger.info(f"Using model {model}")
 
 
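Note: the new `--complete-ucs` flag rides along in click's context object, the same way `--no-model-cache` and `--model` already do: the group callback stashes it in `ctx.obj` and each subcommand reads it back out. A minimal, self-contained sketch of that pattern, using hypothetical `cli`/`sub` names rather than the real `ucsinfer` entry points:

```python
import click

@click.group()
@click.option('--complete-ucs', flag_value=True, default=False)
@click.pass_context
def cli(ctx, complete_ucs):
    # ensure_object creates ctx.obj as a dict if no parent set one up
    ctx.ensure_object(dict)
    ctx.obj['complete_ucs'] = complete_ucs

@cli.command()
@click.pass_context
def sub(ctx):
    # subcommands see the group-level option through the shared context
    click.echo(f"complete_ucs = {ctx.obj['complete_ucs']}")

if __name__ == '__main__':
    cli()
```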
@@ -82,7 +92,8 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
     """
     logger.debug("RECOMMEND mode")
     inference_ctx = InferenceContext(ctx.obj['model_name'],
-                                     use_cached_model=ctx.obj['model_cache'])
+                                     use_cached_model=ctx.obj['model_cache'],
+                                     use_full_ucs=ctx.obj['complete_ucs'])
 
     if text is not None:
         print_recommendation(None, text, inference_ctx,
@@ -122,7 +133,8 @@ def recommend(ctx, text, paths, interactive, skip_ucs):
               "on the UCS category explanations and synonymns (PATHS will "
               "be ignored.)")
 @click.argument('paths', nargs=-1)
-def gather(paths, out, ucs_data):
+@click.pass_context
+def gather(ctx, paths, out, ucs_data):
     """
     Scan files to build a training dataset
 
@@ -142,7 +154,7 @@ def gather(paths, out, ucs_data):
     types = ['.wav', '.flac']
 
     logger.debug(f"Loading category list...")
-    ucs = load_ucs()
+    ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
 
     scan_list = []
     catid_list = [cat.catid for cat in ucs]
@@ -152,32 +164,44 @@ def gather(paths, out, ucs_data):
         paths = []
 
     for path in paths:
-        logger.info(f"Scanning directory {path}...")
         for dirpath, _, filenames in os.walk(path):
+            logger.info(f"Walking directory {dirpath}")
             for filename in filenames:
                 root, ext = os.path.splitext(filename)
-                if ext in types and \
-                        (ucs_components := parse_ucs(root, catid_list)) and \
-                        not filename.startswith("._"):
-                    scan_list.append((ucs_components.cat_id,
-                                      os.path.join(dirpath, filename)))
+                if ext not in types or filename.startswith("._"):
+                    continue
+
+                if (ucs_components := parse_ucs(root, catid_list)):
+                    p = os.path.join(dirpath, filename)
+                    logger.info(f"Adding path to scan list {p}")
+                    scan_list.append((ucs_components.cat_id, p))
 
     logger.info(f"Found {len(scan_list)} files to process.")
 
     def scan_metadata():
         for pair in tqdm.tqdm(scan_list, unit='files'):
-            if desc := ffmpeg_description(pair[1]):
+            logger.info(f"Scanning file with ffprobe: {pair[1]}")
+            try:
+                desc = ffmpeg_description(pair[1])
+            except CalledProcessError as e:
+                logger.error(f"ffprobe returned error {e.returncode}: "
+                             + e.stderr)
+                continue
+
+            if desc:
                 yield desc, str(pair[0])
             else:
                 comps = parse_ucs(os.path.basename(pair[1]), catid_list)
                 assert comps
                 yield comps.fx_name, str(pair[0])
 
     def ucs_metadata():
         for cat in ucs:
             yield cat.explanations, cat.catid
             yield ", ".join(cat.synonymns), cat.catid
 
+    logger.info("Building dataset...")
     if ucs_data:
         dataset = build_sentence_class_dataset(ucs_metadata(), catid_list)
     else:
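Note: `gather` now logs an ffprobe failure and skips the file instead of silently folding it into the "no description" path. The fields the handler uses come straight from the stdlib: `check_returncode()` raises `CalledProcessError` carrying `returncode` and the captured `stderr`. A small runnable sketch with a stand-in failing command (not ffprobe itself):

```python
from subprocess import run, CalledProcessError

# 'false' exits non-zero, standing in for a failing ffprobe invocation
result = run(['false'], capture_output=True, text=True)
try:
    result.check_returncode()  # raises CalledProcessError on non-zero exit
except CalledProcessError as e:
    # the same attributes gather's error log interpolates
    print(f"command failed with {e.returncode}: {e.stderr}")
```

One caveat: `e.stderr` is only a `str` when the run captured text output; with plain `capture_output=True` it is `bytes`, so the string concatenation in the handler assumes text mode.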
@@ -202,9 +226,6 @@ def finetune(ctx):
               help='Skip this many records in the dataset before processing')
 @click.option('--limit', type=int, default=-1, metavar="<int>",
               help='Process this many records and then exit')
-@click.option('--no-foley', 'no_foley', flag_value=True, default=False,
-              help="Ignore any data in the set with FOLYProp or FOLYFeet "
-              "category")
 @click.argument('dataset', type=click.File('r', encoding='utf8'),
                 default='dataset.csv')
 @click.pass_context
@@ -229,84 +250,86 @@ def evaluate(ctx, dataset, offset, limit, no_foley):
     foley, and so these categories can be excluded with the --no-foley option.
     """
     logger.debug("EVALUATE mode")
-    inference_context = InferenceContext(ctx.obj['model_name'],
-                                         use_cached_model=
-                                         ctx.obj['model_cache'])
-    reader = csv.reader(dataset)
-
-    results = []
-
-    if offset > 0:
-        logger.debug(f"Skipping {offset} records...")
-
-    if limit > 0:
-        logger.debug(f"Will only evaluate {limit} records...")
-
-    progress_bar = tqdm.tqdm(total=limit,
-                             desc="Processing dataset...",
-                             unit="rec")
-    for i, row in enumerate(reader):
-        if i < offset:
-            continue
-
-        if limit > 0 and i >= limit + offset:
-            break
-
-        cat_id, description = row
-        if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
-            continue
-
-        guesses = inference_context.classify_text_ranked(description, limit=10)
-        if cat_id == guesses[0]:
-            results.append({'catid': cat_id, 'result': "TOP"})
-        elif cat_id in guesses[0:5]:
-            results.append({'catid': cat_id, 'result': "TOP_5"})
-        elif cat_id in guesses:
-            results.append({'catid': cat_id, 'result': "TOP_10"})
-        else:
-            results.append({'catid': cat_id, 'result': "MISS"})
-
-        progress_bar.update(1)
-
-    total = len(results)
-    total_top = len([x for x in results if x['result'] == 'TOP'])
-    total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
-    total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
-
-    cats = set([x['catid'] for x in results])
-    total_cats = len(cats)
-
-    miss_counts = []
-    for cat in cats:
-        miss_counts.append(
-            (cat, len([x for x in results
-                       if x['catid'] == cat and x['result'] == 'MISS'])))
-
-    miss_counts = sorted(miss_counts, key=lambda x: x[1])
-
-    print(f"## Results for Model {model} ##\n")
-
-    if no_foley:
-        print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
-
-    table = [
-        ["Total records in sample:", f"{total}"],
-        ["Top Result:", f"{total_top}",
-         f"{float(total_top)/float(total):.2%}"],
-        ["Top 5 Result:", f"{total_top_5}",
-         f"{float(total_top_5)/float(total):.2%}"],
-        ["Top 10 Result:", f"{total_top_10}",
-         f"{float(total_top_10)/float(total):.2%}"],
-        SEPARATING_LINE,
-        ["UCS category count:", f"{len(inference_context.catlist)}"],
-        ["Total categories in sample:", f"{total_cats}",
-         f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
-        [f"Most missed category ({miss_counts[-1][0]}):",
-         f"{miss_counts[-1][1]}",
-         f"{float(miss_counts[-1][1])/float(total):.2%}"]
-    ]
-
-    print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
+    logger.warning("Model evaluation is not currently implemented")
+    # inference_context = InferenceContext(
+    #     ctx.obj['model_name'], use_cached_model=ctx.obj['model_cache'],
+    #     use_full_ucs=ctx.obj['complete_ucs'])
+    #
+    # reader = csv.reader(dataset)
+    #
+    # results = []
+    #
+    # if offset > 0:
+    #     logger.debug(f"Skipping {offset} records...")
+    #
+    # if limit > 0:
+    #     logger.debug(f"Will only evaluate {limit} records...")
+    #
+    # progress_bar = tqdm.tqdm(total=limit,
+    #                          desc="Processing dataset...",
+    #                          unit="rec")
+    # for i, row in enumerate(reader):
+    #     if i < offset:
+    #         continue
+    #
+    #     if limit > 0 and i >= limit + offset:
+    #         break
+    #
+    #     cat_id, description = row
+    #     if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
+    #         continue
+    #
+    #     guesses = inference_context.classify_text_ranked(description, limit=10)
+    #     if cat_id == guesses[0]:
+    #         results.append({'catid': cat_id, 'result': "TOP"})
+    #     elif cat_id in guesses[0:5]:
+    #         results.append({'catid': cat_id, 'result': "TOP_5"})
+    #     elif cat_id in guesses:
+    #         results.append({'catid': cat_id, 'result': "TOP_10"})
+    #     else:
+    #         results.append({'catid': cat_id, 'result': "MISS"})
+    #
+    #     progress_bar.update(1)
+    #
+    # total = len(results)
+    # total_top = len([x for x in results if x['result'] == 'TOP'])
+    # total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
+    # total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
+    #
+    # cats = set([x['catid'] for x in results])
+    # total_cats = len(cats)
+    #
+    # miss_counts = []
+    # for cat in cats:
+    #     miss_counts.append(
+    #         (cat, len([x for x in results
+    #                    if x['catid'] == cat and x['result'] == 'MISS'])))
+    #
+    # miss_counts = sorted(miss_counts, key=lambda x: x[1])
+    #
+    # print(f"## Results for Model {model} ##\n")
+    #
+    # if no_foley:
+    #     print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
+    #
+    # table = [
+    #     ["Total records in sample:", f"{total}"],
+    #     ["Top Result:", f"{total_top}",
+    #      f"{float(total_top)/float(total):.2%}"],
+    #     ["Top 5 Result:", f"{total_top_5}",
+    #      f"{float(total_top_5)/float(total):.2%}"],
+    #     ["Top 10 Result:", f"{total_top_10}",
+    #      f"{float(total_top_10)/float(total):.2%}"],
+    #     SEPARATING_LINE,
+    #     ["UCS category count:", f"{len(inference_context.catlist)}"],
+    #     ["Total categories in sample:", f"{total_cats}",
+    #      f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
+    #     [f"Most missed category ({miss_counts[-1][0]}):",
+    #      f"{miss_counts[-1][1]}",
+    #      f"{float(miss_counts[-1][1])/float(total):.2%}"]
+    # ]
+    #
+    # print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
 
 
 if __name__ == '__main__':
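Note: the evaluate body is parked behind a warning until it is reworked, but the tally it performs reduces to bucketing each record by where the true category lands in the ranked guesses. A runnable sketch with stubbed rankings (stand-in data; the real guesses come from `InferenceContext.classify_text_ranked`):

```python
# (true catid, ranked guesses) pairs standing in for dataset rows
records = [('AMBForst', ['AMBForst', 'AMBCity']),
           ('DOORWood', ['DOORMetl', 'DOORWood']),
           ('GUNAuto',  ['AMBCity'])]

results = []
for cat_id, guesses in records:
    if cat_id == guesses[0]:
        results.append({'catid': cat_id, 'result': "TOP"})
    elif cat_id in guesses[0:5]:
        results.append({'catid': cat_id, 'result': "TOP_5"})
    elif cat_id in guesses:
        results.append({'catid': cat_id, 'result': "TOP_10"})
    else:
        results.append({'catid': cat_id, 'result': "MISS"})

print([r['result'] for r in results])  # ['TOP', 'TOP_5', 'MISS']
```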
ucsinfer/gather.py (new file)
@@ -0,0 +1,33 @@
+from datasets import Dataset, Features, Value, ClassLabel
+
+from typing import Generator, Any
+
+# https://www.sbert.net/docs/sentence_transformer/loss_overview.html
+
+
+def build_sentence_class_dataset(
+        records: Generator[tuple[str, str], Any, None],
+        catlist: list[str]) -> Dataset:
+    """
+    Create a new dataset for `records` which contains (sentence, class) pairs.
+
+    :param records: a generator for records that generates pairs of
+        (sentence, catid)
+    :returns: A dataset with two columns: (sentence, hash(catid))
+    """
+
+    labels = ClassLabel(names=catlist)
+
+    items: list[dict] = []
+    for obj in records:
+        items += [{'sentence': obj[0], 'class': obj[1]}]
+
+    return Dataset.from_list(
+        items, features=Features({'sentence': Value('string'),
+                                  'class': labels}))
+
+
+# def build_sentence_anchor_dataset() -> Dataset:
+#     """
+#     Create a new dataset for `records` which contains (sentence, anchor) pairs.
+#     """
+#     pass
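Note: a quick usage sketch for the new dataset builder, driving it with a tiny in-memory generator (the sentences and category IDs below are stand-ins, not real UCS data):

```python
from ucsinfer.gather import build_sentence_class_dataset

catlist = ['AMBForst', 'DOORWood']

def records():
    # (sentence, catid) pairs, the same shape scan_metadata() yields
    yield "birds and wind through tall trees", 'AMBForst'
    yield "heavy oak door creaks open", 'DOORWood'

ds = build_sentence_class_dataset(records(), catlist)
print(ds.features['class'].names)  # ['AMBForst', 'DOORWood']
print(ds[0])                       # {'sentence': '...', 'class': 0}
```

The `ClassLabel` feature stores each `catid` as an integer index into `catlist`, which is the label format classification losses expect.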
@@ -34,7 +34,7 @@ class Ucs(NamedTuple):
                    explanations=d['Explanations'], synonymns=d['Synonyms'])
 
 
-def load_ucs() -> list[Ucs]:
+def load_ucs(full_ucs: bool = True) -> list[Ucs]:
     FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
     cats = []
     ucs_defs = os.path.join(FILE_ROOT_DIR, 'ucs-community', 'json',
@@ -43,7 +43,13 @@ def load_ucs() -> list[Ucs]:
     with open(ucs_defs, 'r') as f:
         cats = json.load(f)
 
-    return [Ucs.from_dict(cat) for cat in cats]
+    ucs = [Ucs.from_dict(cat) for cat in cats]
+
+    if full_ucs:
+        return ucs
+    else:
+        return [cat for cat in ucs if
+                cat.category not in ['FOLEY', 'ARCHIVED']]
 
 
 class InferenceContext:
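Note: the default is deliberately `full_ucs=True`, so `load_ucs()` on its own still returns every category; the exclusion only applies when a caller passes `full_ucs=False`. A sketch of the filtering semantics with stand-in records (the real `Ucs` NamedTuple has more fields):

```python
from typing import NamedTuple

class Cat(NamedTuple):       # stand-in for the real Ucs NamedTuple
    catid: str
    category: str

ucs = [Cat('AMBForst', 'AMBIENCE'), Cat('FOLYProp', 'FOLEY')]

# what load_ucs(full_ucs=False) keeps: everything outside FOLEY/ARCHIVED
kept = [c for c in ucs if c.category not in ['FOLEY', 'ARCHIVED']]
print([c.catid for c in kept])  # ['AMBForst']
```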
@@ -54,8 +60,11 @@ class InferenceContext:
     model: SentenceTransformer
     model_name: str
 
-    def __init__(self, model_name: str, use_cached_model: bool = True):
+    def __init__(self, model_name: str, use_cached_model: bool = True,
+                 use_full_ucs: bool = False):
         self.model_name = model_name
+        self.use_full_ucs = use_full_ucs
+
         cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
                                                 'Squad 51')
 
@@ -75,7 +84,7 @@ class InferenceContext:
|
|||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def catlist(self) -> list[Ucs]:
|
def catlist(self) -> list[Ucs]:
|
||||||
return load_ucs()
|
return load_ucs(full_ucs=self.use_full_ucs)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def embeddings(self) -> list[dict]:
|
def embeddings(self) -> list[dict]:
|
||||||
|
@@ -9,10 +9,7 @@ def ffmpeg_description(path: str) -> Optional[str]:
     result = subprocess.run(['ffprobe', '-show_format', '-of',
                              'json', path], capture_output=True)
 
-    try:
-        result.check_returncode()
-    except:
-        return None
+    result.check_returncode()
 
     stream = json.loads(result.stdout)
     fmt = stream.get("format", None)
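Note: with the bare `except:` removed, `ffmpeg_description` no longer converts every ffprobe failure into `None`; `check_returncode()` lets the `CalledProcessError` propagate up to `gather`, which now logs it (see above). For reference, the JSON that `ffprobe -show_format -of json` emits nests any description under the format-level tags; a sketch of pulling it out, assuming that tag layout:

```python
import json

# abridged shape of `ffprobe -show_format -of json <file>` output
probe = json.loads('{"format": {"filename": "a.wav",'
                   ' "tags": {"description": "door slam"}}}')
fmt = probe.get("format", None)
if fmt:
    print(fmt.get("tags", {}).get("description"))  # door slam
```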