Compare commits

3 commits (e419f698c9 ... 04332b73ee):
- 04332b73ee
- 103fffe0a4
- 10519f9c1a

TODO.md (3 changes)

@@ -17,8 +17,9 @@

- Print more information about the dataset coverage of UCS
- Allow skipping model testing for this

- Print raw output
- Maybe load everything into a sqlite for slicker reporting
<!-- - Maybe load everything into a sqlite for slicker reporting -->


## Utility

@@ -2,13 +2,14 @@ import os
# import csv
import logging
from subprocess import CalledProcessError
from itertools import chain

import tqdm
import click
# from tabulate import tabulate, SEPARATING_LINE

from .inference import InferenceContext, load_ucs
from .gather import build_sentence_class_dataset
from .gather import build_sentence_class_dataset, print_dataset_stats
from .recommend import print_recommendation
from .util import ffmpeg_description, parse_ucs

@@ -163,10 +164,12 @@ def gather(ctx, paths, out, ucs_data):
logger.info('Creating dataset for UCS categories instead of from PATH')
paths = []

walker_p = tqdm.tqdm(total=None, unit='dir', desc="Walking filesystem...")
for path in paths:
for dirpath, _, filenames in os.walk(path):
logger.info(f"Walking directory {dirpath}")
for filename in filenames:
walker_p.update()
root, ext = os.path.splitext(filename)
if ext not in types or filename.startswith("._"):
continue
@@ -175,6 +178,7 @@ def gather(ctx, paths, out, ucs_data):
p = os.path.join(dirpath, filename)
logger.info(f"Adding path to scan list {p}")
scan_list.append((ucs_components.cat_id, p))
walker_p.close()

logger.info(f"Found {len(scan_list)} files to process.")

@@ -184,7 +188,7 @@ def gather(ctx, paths, out, ucs_data):
try:
desc = ffmpeg_description(pair[1])
except CalledProcessError as e:
logger.error(f"ffprobe returned error {e.returncode}: " \
logger.error(f"ffprobe returned error (){e.returncode}): " \
+ e.stderr)
continue

@@ -195,19 +199,20 @@ def gather(ctx, paths, out, ucs_data):
assert comps
yield comps.fx_name, str(pair[0])


def ucs_metadata():
for cat in ucs:
yield cat.explanations, cat.catid
yield ", ".join(cat.synonymns), cat.catid

logger.info("Building dataset...")
if ucs_data:
dataset = build_sentence_class_dataset(ucs_metadata(), catid_list)
else:
dataset = build_sentence_class_dataset(scan_metadata(), catid_list)

dataset = build_sentence_class_dataset(chain(scan_metadata(),
                                             ucs_metadata()),
                                       catid_list)


logger.info(f"Saving dataset to disk at {out}")
print_dataset_stats(dataset)
dataset.save_to_disk(out)

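For context, a minimal sketch of what the chained call above consumes: itertools.chain lazily concatenates the two generators, so build_sentence_class_dataset sees the file-derived rows followed by the UCS category-text rows. The sample values below are illustrative and are not taken from the commit.

    from itertools import chain

    def scan_metadata():
        # illustrative stand-in for the generator above: (description, cat_id)
        # pairs derived from scanned audio files
        yield "glass shatters and falls on concrete", "GLASBrk"

    def ucs_metadata():
        # illustrative stand-in: (explanation/synonym text, cat_id) pairs
        # taken from the UCS category list itself
        yield "breaking, smashing glass", "GLASBrk"

    combined = chain(scan_metadata(), ucs_metadata())
    print(list(combined))  # merged stream handed to build_sentence_class_dataset
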
@@ -222,114 +227,15 @@ def finetune(ctx):


@ucsinfer.command('evaluate')
@click.option('--offset', type=int, default=0, metavar="<int>",
help='Skip this many records in the dataset before processing')
@click.option('--limit', type=int, default=-1, metavar="<int>",
help='Process this many records and then exit')
@click.argument('dataset', type=click.File('r', encoding='utf8'),
default='dataset.csv')
@click.argument('dataset', default='dataset/')
@click.pass_context
def evaluate(ctx, dataset, offset, limit, no_foley):
def evaluate(ctx, dataset, offset, limit):
"""
|
||||
Use datasets to evaluate model performance
|
||||
|
||||
The `evaluate` command reads the input DATASET file row by row and
|
||||
performs a classifcation of the given description against the selected
|
||||
model (either the default or using the --model option). The command then
|
||||
checks if the model inferred the correct category as given by the dataset.
|
||||
|
||||
The model gives its top 10 possible categories for a given description,
|
||||
and the results are tabulated according to (1) wether the top
|
||||
classification was correct, (2) wether the correct classifcation was in the
|
||||
top 5, or (3) wether it was in the top 10. The worst-performing category,
|
||||
the one with the most misses, is also reported as well as the category
|
||||
coverage, how many categories are present in the dataset.
|
||||
|
||||
NOTE: With experimentation it was found that foley items generally were
|
||||
classified according to their subject and not wether or not they were
|
||||
foley, and so these categories can be excluded with the --no-foley option.
|
||||
|
||||
"""
|
||||
logger.debug("EVALUATE mode")
logger.warning("Model evaluation is not currently implemented")
# inference_context = InferenceContext(
# ctx.obj['model_name'], use_cached_model=ctx.obj['model_cache'],
# use_full_ucs=ctx.obj['complete_ucs'])
#
# reader = csv.reader(dataset)
#
# results = []
#
# if offset > 0:
# logger.debug(f"Skipping {offset} records...")
#
# if limit > 0:
# logger.debug(f"Will only evaluate {limit} records...")
#
# progress_bar = tqdm.tqdm(total=limit,
# desc="Processing dataset...",
# unit="rec")
# for i, row in enumerate(reader):
# if i < offset:
# continue
#
# if limit > 0 and i >= limit + offset:
# break
#
# cat_id, description = row
# if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
# continue
#
# guesses = inference_context.classify_text_ranked(description, limit=10)
# if cat_id == guesses[0]:
# results.append({'catid': cat_id, 'result': "TOP"})
# elif cat_id in guesses[0:5]:
# results.append({'catid': cat_id, 'result': "TOP_5"})
# elif cat_id in guesses:
# results.append({'catid': cat_id, 'result': "TOP_10"})
# else:
# results.append({'catid': cat_id, 'result': "MISS"})
#
# progress_bar.update(1)
#
# total = len(results)
# total_top = len([x for x in results if x['result'] == 'TOP'])
# total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
# total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
#
# cats = set([x['catid'] for x in results])
# total_cats = len(cats)
#
# miss_counts = []
# for cat in cats:
# miss_counts.append(
# (cat, len([x for x in results
# if x['catid'] == cat and x['result'] == 'MISS'])))
#
# miss_counts = sorted(miss_counts, key=lambda x: x[1])
#
# print(f"## Results for Model {model} ##\n")
#
# if no_foley:
# print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
#
# table = [
# ["Total records in sample:", f"{total}"],
# ["Top Result:", f"{total_top}",
# f"{float(total_top)/float(total):.2%}"],
# ["Top 5 Result:", f"{total_top_5}",
# f"{float(total_top_5)/float(total):.2%}"],
# ["Top 10 Result:", f"{total_top_10}",
# f"{float(total_top_10)/float(total):.2%}"],
# SEPARATING_LINE,
# ["UCS category count:", f"{len(inference_context.catlist)}"],
# ["Total categories in sample:", f"{total_cats}",
# f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
# [f"Most missed category ({miss_counts[-1][0]}):",
# f"{miss_counts[-1][1]}",
# f"{float(miss_counts[-1][1])/float(total):.2%}"]
# ]
#
# print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))



if __name__ == '__main__':

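The evaluate command is stubbed out in this commit, but its DATASET argument now defaults to a directory ('dataset/') rather than a CSV file. A hedged sketch of how a future implementation might read the saved DatasetDict back; iter_eval_records is a hypothetical helper, and the 'sentence'/'class' feature names come from the build_sentence_class_dataset changes below.

    from datasets import load_from_disk

    def iter_eval_records(dataset_path: str = "dataset/"):
        # Hypothetical helper, not part of this commit: load the DatasetDict
        # written by `gather` and yield (sentence, CatID) pairs from its
        # 'eval' split for scoring.
        ds = load_from_disk(dataset_path)
        labels = ds["eval"].features["class"]  # ClassLabel feature
        for row in ds["eval"]:
            # 'class' is stored as an integer index; int2str() recovers the CatID
            yield row["sentence"], labels.int2str(row["class"])
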
ucsinfer/evaluate.py (new file, 7 lines)

@@ -0,0 +1,7 @@


# from sentence_transformers import SentenceTransformer
# from sentence_transformers.evaluation import BinaryClassificationEvaluator
# from datasets import load_dataset_from_disk, DatasetDict
#

@@ -1,13 +1,28 @@
from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
from datasets.dataset_dict import DatasetDict

from typing import Generator, Any
from typing import Iterator

from tabulate import tabulate

def print_dataset_stats(dataset: DatasetDict, catlist: list[str]):

data_table = []
data_table.append([["Total records in combined dataset:", len(dataset)]])
data_table.append([["Total records in `train`:", len(dataset['train'])]])

tab = tabulate(data_table)

print(tab)

# https://www.sbert.net/docs/sentence_transformer/loss_overview.html

def build_sentence_class_dataset(
records: Generator[tuple[str, str], Any, None], catlist: list[str]) -> Dataset:
records: Iterator[tuple[str, str]],
catlist: list[str]) -> DatasetDict:
"""
Create a new dataset for `records` which contains (sentence, class) pairs.
The dataset is split into train and test slices.

:param records: a generator for records that generates pairs of
(sentence, catid)
@@ -16,9 +31,12 @@ def build_sentence_class_dataset(

labels = ClassLabel(names=catlist)

features = Features({'sentence': Value('string'),
                     'class': labels})

info = DatasetInfo(
description=f"(sentence, UCS CatID) pairs gathered by the "
"ucsinfer tool on {}")
"ucsinfer tool on {}", features= features)


items: list[dict] = []
@@ -26,9 +44,16 @@ def build_sentence_class_dataset(
items += [{'sentence': obj[0], 'class': obj[1]}]


return Dataset.from_list(items, features=Features({'sentence': Value('string'),
                                                   'class': labels}),
                         info=info)
whole = Dataset.from_list(items, features=features, info=info)

split_set = whole.train_test_split(0.2)
test_eval_set = split_set['test'].train_test_split(0.5)

return DatasetDict({
    'train': split_set['train'],
    'test': test_eval_set['train'],
    'eval': test_eval_set['test']
})


# def build_sentence_anchor_dataset() -> Dataset:

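A hedged usage sketch of the new return shape: train_test_split(0.2) holds back 20% of the rows, and the second train_test_split(0.5) divides that holdout evenly, so the resulting DatasetDict is roughly 80% train, 10% test and 10% eval. The records and CatIDs below are illustrative only.

    # ten synthetic (sentence, catid) rows standing in for gathered metadata
    records = ((f"illustrative description {i}",
                "GLASBrk" if i % 2 else "DOORWood") for i in range(10))
    catid_list = ["GLASBrk", "DOORWood"]

    ds = build_sentence_class_dataset(records, catid_list)
    # expected split sizes for 10 rows: ~8 train, 1 test, 1 eval
    print({name: len(split) for name, split in ds.items()})
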
@@ -103,7 +103,12 @@ class InferenceContext:
else:
print(f"Calculating embeddings for model {self.model_name}...")

for cat_defn in self.catlist:
# we need to calculate the embeddings for all cats, not just the
# ones we're loading for this run

full_catlist = load_ucs(full_ucs= True)

for cat_defn in full_catlist:
embeddings += [{
'CatID': cat_defn.catid,
'Embedding': self._encode_category(cat_defn)
@@ -113,7 +118,9 @@ class InferenceContext:
with open(embedding_cache_path, 'wb') as g:
pickle.dump(embeddings, g)

return embeddings
whitelisted_cats = [cat.catid for cat in self.catlist]

return [e for e in embeddings if e['CatID'] in whitelisted_cats]

def _encode_category(self, cat: Ucs) -> np.ndarray:
sentence_components = [cat.explanations,

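The change above computes and caches embeddings for every UCS category, then filters the cached list down to the categories selected for the current run, so the pickle cache stays valid regardless of which subset is loaded later. A minimal sketch of that cache-all-then-filter pattern; the function and parameter names here are hypothetical stand-ins, not the InferenceContext API.

    import os
    import pickle

    def load_embeddings(cache_path, full_catlist, active_catlist, encode):
        # Hypothetical sketch: embed the full category list once and cache it
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as f:
                embeddings = pickle.load(f)
        else:
            embeddings = [{'CatID': c.catid, 'Embedding': encode(c)}
                          for c in full_catlist]
            with open(cache_path, 'wb') as f:
                pickle.dump(embeddings, f)

        # filter down to the categories requested for this run
        wanted = {c.catid for c in active_catlist}
        return [e for e in embeddings if e['CatID'] in wanted]
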
@@ -1,7 +1,6 @@
# recommend.py

from re import match

from .inference import InferenceContext

def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
@@ -23,6 +22,7 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
print(f"Path: {path}")

print(f"Text: {text or '<None>'}")

for i, r in enumerate(recs):
cat, subcat, _ = ctx.lookup_category(r)
print(f"- {i}: {r} ({cat}-{subcat})")
@@ -42,6 +42,9 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
print("searching for new matches")
text = m.group(1)
return True, text, None

elif m := match(r'^c (.*)', response):
return True, None, m.group(1)

elif response.startswith("?"):
print("""
@@ -49,6 +52,7 @@ Choices:
- Enter recommendation number to rename file,
- "t [text]" to search for new recommendations based on [text]
- "p" re-use the last selected cat-id
- "c [cat]" to type in a category by hand
- "?" for this message
- "q" to quit
- or any other key to skip this file and continue to next file