Added another function to recommend
@@ -2,13 +2,14 @@ import os
 # import csv
 import logging
 from subprocess import CalledProcessError
+from itertools import chain
 
 import tqdm
 import click
 # from tabulate import tabulate, SEPARATING_LINE
 
 from .inference import InferenceContext, load_ucs
-from .gather import build_sentence_class_dataset
+from .gather import build_sentence_class_dataset, print_dataset_stats
 from .recommend import print_recommendation
 from .util import ffmpeg_description, parse_ucs
 
@@ -198,19 +199,20 @@ def gather(ctx, paths, out, ucs_data):
             assert comps
             yield comps.fx_name, str(pair[0])
 
 
     def ucs_metadata():
         for cat in ucs:
             yield cat.explanations, cat.catid
             yield ", ".join(cat.synonymns), cat.catid
 
     logger.info("Building dataset...")
-    if ucs_data:
-        dataset = build_sentence_class_dataset(ucs_metadata(), catid_list)
-    else:
-        dataset = build_sentence_class_dataset(scan_metadata(), catid_list)
+    dataset = build_sentence_class_dataset(chain(scan_metadata(),
+                                                 ucs_metadata()),
+                                           catid_list)
+
 
     logger.info(f"Saving dataset to disk at {out}")
+    print_dataset_stats(dataset)
     dataset.save_to_disk(out)
 
 
@@ -225,114 +227,15 @@ def finetune(ctx):
 
 
 @ucsinfer.command('evaluate')
-@click.option('--offset', type=int, default=0, metavar="<int>",
-              help='Skip this many records in the dataset before processing')
-@click.option('--limit', type=int, default=-1, metavar="<int>",
-              help='Process this many records and then exit')
-@click.argument('dataset', type=click.File('r', encoding='utf8'),
-                default='dataset.csv')
+@click.argument('dataset', default='dataset/')
 @click.pass_context
-def evaluate(ctx, dataset, offset, limit, no_foley):
+def evaluate(ctx, dataset, offset, limit):
     """
-    Use datasets to evaluate model performance
-
-    The `evaluate` command reads the input DATASET file row by row and
-    performs a classifcation of the given description against the selected
-    model (either the default or using the --model option). The command then
-    checks if the model inferred the correct category as given by the dataset.
-
-    The model gives its top 10 possible categories for a given description,
-    and the results are tabulated according to (1) wether the top
-    classification was correct, (2) wether the correct classifcation was in the
-    top 5, or (3) wether it was in the top 10. The worst-performing category,
-    the one with the most misses, is also reported as well as the category
-    coverage, how many categories are present in the dataset.
-
-    NOTE: With experimentation it was found that foley items generally were
-    classified according to their subject and not wether or not they were
-    foley, and so these categories can be excluded with the --no-foley option.
     """
     logger.debug("EVALUATE mode")
     logger.warning("Model evaluation is not currently implemented")
-    # inference_context = InferenceContext(
-    #     ctx.obj['model_name'], use_cached_model=ctx.obj['model_cache'],
-    #     use_full_ucs=ctx.obj['complete_ucs'])
-    #
-    # reader = csv.reader(dataset)
-    #
-    # results = []
-    #
-    # if offset > 0:
-    #     logger.debug(f"Skipping {offset} records...")
-    #
-    # if limit > 0:
-    #     logger.debug(f"Will only evaluate {limit} records...")
-    #
-    # progress_bar = tqdm.tqdm(total=limit,
-    #                          desc="Processing dataset...",
-    #                          unit="rec")
-    # for i, row in enumerate(reader):
-    #     if i < offset:
-    #         continue
-    #
-    #     if limit > 0 and i >= limit + offset:
-    #         break
-    #
-    #     cat_id, description = row
-    #     if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
-    #         continue
-    #
-    #     guesses = inference_context.classify_text_ranked(description, limit=10)
-    #     if cat_id == guesses[0]:
-    #         results.append({'catid': cat_id, 'result': "TOP"})
-    #     elif cat_id in guesses[0:5]:
-    #         results.append({'catid': cat_id, 'result': "TOP_5"})
-    #     elif cat_id in guesses:
-    #         results.append({'catid': cat_id, 'result': "TOP_10"})
-    #     else:
-    #         results.append({'catid': cat_id, 'result': "MISS"})
-    #
-    #     progress_bar.update(1)
-    #
-    # total = len(results)
-    # total_top = len([x for x in results if x['result'] == 'TOP'])
-    # total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
-    # total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
-    #
-    # cats = set([x['catid'] for x in results])
-    # total_cats = len(cats)
-    #
-    # miss_counts = []
-    # for cat in cats:
-    #     miss_counts.append(
-    #         (cat, len([x for x in results
-    #                    if x['catid'] == cat and x['result'] == 'MISS'])))
-    #
-    # miss_counts = sorted(miss_counts, key=lambda x: x[1])
-    #
-    # print(f"## Results for Model {model} ##\n")
-    #
-    # if no_foley:
-    #     print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
-    #
-    # table = [
-    #     ["Total records in sample:", f"{total}"],
-    #     ["Top Result:", f"{total_top}",
-    #      f"{float(total_top)/float(total):.2%}"],
-    #     ["Top 5 Result:", f"{total_top_5}",
-    #      f"{float(total_top_5)/float(total):.2%}"],
-    #     ["Top 10 Result:", f"{total_top_10}",
-    #      f"{float(total_top_10)/float(total):.2%}"],
-    #     SEPARATING_LINE,
-    #     ["UCS category count:", f"{len(inference_context.catlist)}"],
-    #     ["Total categories in sample:", f"{total_cats}",
-    #      f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
-    #     [f"Most missed category ({miss_counts[-1][0]}):",
-    #      f"{miss_counts[-1][1]}",
-    #      f"{float(miss_counts[-1][1])/float(total):.2%}"]
-    # ]
-    #
-    # print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
 
 
 if __name__ == '__main__':
@@ -1,32 +1,7 @@
 
 
-from sentence_transformers import SentenceTransformer
-from sentence_transformers.evaluation import BinaryClassificationEvaluator
-from datasets import load_dataset_from_disk, DatasetDict
+# from sentence_transformers import SentenceTransformer
+# from sentence_transformers.evaluation import BinaryClassificationEvaluator
+# from datasets import load_dataset_from_disk, DatasetDict
+#
 
-
-def evaluate_model(model: SentenceTransformer, dataset):
-
-    # eval_dataset =
-
-    # Initialize the evaluator
-    binary_acc_evaluator = BinaryClassificationEvaluator(
-        sentences1=eval_dataset["sentence1"],
-        sentences2=eval_dataset["sentence2"],
-        labels=eval_dataset["label"],
-        name="quora_duplicates_dev",
-    )
-    results = binary_acc_evaluator(model)
-    '''
-    Binary Accuracy Evaluation of the model on the quora_duplicates_dev dataset:
-    Accuracy with Cosine-Similarity: 81.60 (Threshold: 0.8352)
-    F1 with Cosine-Similarity: 75.27 (Threshold: 0.7715)
-    Precision with Cosine-Similarity: 65.81
-    Recall with Cosine-Similarity: 87.89
-    Average Precision with Cosine-Similarity: 76.03
-    Matthews Correlation with Cosine-Similarity: 62.48
-    '''
-    print(binary_acc_evaluator.primary_metric)
-    # => "quora_duplicates_dev_cosine_ap"
-    print(results[binary_acc_evaluator.primary_metric])
-    # => 0.760277070888393
@@ -1,12 +1,24 @@
 from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
 from datasets.dataset_dict import DatasetDict
 
-from typing import Generator, Any
+from typing import Iterator
 
+from tabulate import tabulate
+
+def print_dataset_stats(dataset: DatasetDict, catlist: list[str]):
+
+    data_table = []
+    data_table.append([["Total records in combined dataset:", len(dataset)]])
+    data_table.append([["Total records in `train`:", len(dataset['train'])]])
+
+    tab = tabulate(data_table)
+
+    print(tab)
+
 # https://www.sbert.net/docs/sentence_transformer/loss_overview.html
 
 def build_sentence_class_dataset(
-        records: Generator[tuple[str, str], Any, None],
+        records: Iterator[tuple[str, str]],
         catlist: list[str]) -> DatasetDict:
     """
     Create a new dataset for `records` which contains (sentence, class) pairs.
@@ -43,12 +43,16 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
         text = m.group(1)
         return True, text, None
 
+    elif m := match(r'^c (.*)', response):
+        return True, None, m.group(1)
+
     elif response.startswith("?"):
         print("""
 Choices:
 - Enter recommendation number to rename file,
 - "t [text]" to search for new recommendations based on [text]
 - "p" re-use the last selected cat-id
+- "c [cat]" to type in a category by hand
 - "?" for this message
 - "q" to quit
 - or any other key to skip this file and continue to next file