Work on evaluator, and some model stats.
This commit is contained in:
18
MODELS.rst
Normal file
18
MODELS.rst
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
Results for Model paraphrase-multilingual-mpnet-base-v2
|
||||||
|
===================
|
||||||
|
|
||||||
|
================================ ==== ======
|
||||||
|
.. n pct
|
||||||
|
================================ ==== ======
|
||||||
|
Total records in sample: 3445
|
||||||
|
Top Result: 469 13.61%
|
||||||
|
Top 5 Result: 519 15.07%
|
||||||
|
Top 10 Result: 513 14.89%
|
||||||
|
================================ ==== ======
|
||||||
|
UCS category count: 752
|
||||||
|
Total categories in sample: 240 31.91%
|
||||||
|
Most missed category (FOLYProp): 1057 30.68%
|
||||||
|
================================ ==== ======
|
||||||
|
|
||||||
|
|
||||||
|
|
@@ -12,8 +12,8 @@ from .util import ffmpeg_description, parse_ucs
|
|||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@click.option('--verbose', flag_value='verbose', help='Verbose output')
|
# @click.option('--verbose', flag_value='verbose', help='Verbose output')
|
||||||
def ucsinfer(verbose: bool):
|
def ucsinfer():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@@ -68,14 +68,17 @@ def finetune():
|
|||||||
@ucsinfer.command('evaluate')
|
@ucsinfer.command('evaluate')
|
||||||
@click.option('--offset', type=int, default=0)
|
@click.option('--offset', type=int, default=0)
|
||||||
@click.option('--limit', type=int, default=-1)
|
@click.option('--limit', type=int, default=-1)
|
||||||
|
@click.option('--no-foley', type=bool, default=False)
|
||||||
|
@click.option('--model', type=str,
|
||||||
|
default="paraphrase-multilingual-mpnet-base-v2")
|
||||||
@click.argument('dataset', type=click.File('r', encoding='utf8'),
|
@click.argument('dataset', type=click.File('r', encoding='utf8'),
|
||||||
default='dataset.csv')
|
default='dataset.csv')
|
||||||
def evaluate(dataset, offset, limit):
|
def evaluate(dataset, offset, limit, model, no_foley):
|
||||||
"""
|
"""
|
||||||
Use datasets to evauluate model performance
|
Use datasets to evauluate model performance
|
||||||
"""
|
"""
|
||||||
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
|
m = SentenceTransformer(model)
|
||||||
ctx = InferenceContext(model)
|
ctx = InferenceContext(m, model)
|
||||||
reader = csv.reader(dataset)
|
reader = csv.reader(dataset)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
@@ -113,7 +116,8 @@ def evaluate(dataset, offset, limit):
|
|||||||
|
|
||||||
miss_counts = sorted(miss_counts, key=lambda x: x[1])
|
miss_counts = sorted(miss_counts, key=lambda x: x[1])
|
||||||
|
|
||||||
print(f" === RESULTS === ")
|
print(f"Results for Model {model}")
|
||||||
|
print("=====\n")
|
||||||
|
|
||||||
table = [
|
table = [
|
||||||
["Total records in sample:", f"{total}"],
|
["Total records in sample:", f"{total}"],
|
||||||
@@ -132,7 +136,7 @@ def evaluate(dataset, offset, limit):
|
|||||||
f"{float(miss_counts[-1][1])/float(total):.2%}"]
|
f"{float(miss_counts[-1][1])/float(total):.2%}"]
|
||||||
]
|
]
|
||||||
|
|
||||||
print(tabulate(table, headers=['', 'n', 'pct']))
|
print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='rst'))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@@ -52,9 +52,11 @@ class InferenceContext:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
model: SentenceTransformer
|
model: SentenceTransformer
|
||||||
|
model_name: str
|
||||||
|
|
||||||
def __init__(self, model: SentenceTransformer):
|
def __init__(self, model: SentenceTransformer, model_name: str):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.model_name = model_name
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def catlist(self) -> list[Ucs]:
|
def catlist(self) -> list[Ucs]:
|
||||||
@@ -62,8 +64,9 @@ class InferenceContext:
|
|||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def embeddings(self) -> list[dict]:
|
def embeddings(self) -> list[dict]:
|
||||||
cache_dir = platformdirs.user_cache_dir("ucsinfer", 'Squad 51')
|
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", 'Squad 51')
|
||||||
embedding_cache = os.path.join(cache_dir, f"ucs_embedding.cache")
|
embedding_cache = os.path.join(cache_dir,
|
||||||
|
f"{self.model_name}-ucs_embedding.cache")
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
|
||||||
if os.path.exists(embedding_cache):
|
if os.path.exists(embedding_cache):
|
||||||
@@ -71,7 +74,7 @@ class InferenceContext:
|
|||||||
embeddings = pickle.load(f)
|
embeddings = pickle.load(f)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print("Calculating embeddings...")
|
print(f"Calculating embeddings for model {self.model_name}...")
|
||||||
|
|
||||||
for cat_defn in self.catlist:
|
for cat_defn in self.catlist:
|
||||||
embeddings += [{
|
embeddings += [{
|
||||||
|
Reference in New Issue
Block a user