Work on evaluator, and some model stats.

Jamie Hardt
2025-08-26 17:59:01 -07:00
parent c687e4614f
commit aea0bbb7be
3 changed files with 36 additions and 11 deletions

MODELS.rst (new file, 18 lines added)

@@ -0,0 +1,18 @@
+Results for Model paraphrase-multilingual-mpnet-base-v2
+=======================================================
+
+================================ ==== ======
+..                                  n pct
+================================ ==== ======
+Total records in sample:         3445
+Top Result:                       469 13.61%
+Top 5 Result:                     519 15.07%
+Top 10 Result:                    513 14.89%
+================================ ==== ======
+UCS category count:               752
+Total categories in sample:       240 31.91%
+Most missed category (FOLYProp): 1057 30.68%
+================================ ==== ======
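
For reference, the Top-N rows appear to bucket each record by where the correct UCS category lands in the model's similarity ranking (first place, second through fifth, sixth through tenth), with pct as each count's share of the 3445-record sample (e.g. 469/3445 = 13.61%). A minimal sketch of that tally, assuming a hypothetical rank_categories helper that returns category IDs best match first (an illustration, not the actual ucsinfer API):

    def rank_buckets(samples, rank_categories):
        """Tally where the true category lands in the model's ranking."""
        counts = {'top_1': 0, 'top_5': 0, 'top_10': 0}
        for text, true_category in samples:
            ranked = rank_categories(text)   # category IDs, best match first
            if true_category not in ranked:
                continue                     # a miss; tallied elsewhere
            rank = ranked.index(true_category)
            if rank == 0:
                counts['top_1'] += 1         # "Top Result"
            elif rank < 5:
                counts['top_5'] += 1         # "Top 5 Result": ranked 2nd-5th
            elif rank < 10:
                counts['top_10'] += 1        # "Top 10 Result": ranked 6th-10th
        return counts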


@@ -12,8 +12,8 @@ from .util import ffmpeg_description, parse_ucs
 @click.group()
-@click.option('--verbose', flag_value='verbose', help='Verbose output')
-def ucsinfer(verbose: bool):
+# @click.option('--verbose', flag_value='verbose', help='Verbose output')
+def ucsinfer():
     pass
@@ -68,14 +68,17 @@ def finetune():
 @ucsinfer.command('evaluate')
 @click.option('--offset', type=int, default=0)
 @click.option('--limit', type=int, default=-1)
+@click.option('--no-foley', is_flag=True, default=False)
+@click.option('--model', type=str,
+              default="paraphrase-multilingual-mpnet-base-v2")
 @click.argument('dataset', type=click.File('r', encoding='utf8'),
                 default='dataset.csv')
-def evaluate(dataset, offset, limit):
+def evaluate(dataset, offset, limit, model, no_foley):
     """
     Use datasets to evaluate model performance
     """
-    model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
-    ctx = InferenceContext(model)
+    m = SentenceTransformer(model)
+    ctx = InferenceContext(m, model)
     reader = csv.reader(dataset)
     results = []
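
With the hard-coded checkpoint replaced by a --model option, the same evaluation can be repeated against different SentenceTransformer models. For example (assuming the package exposes a ucsinfer console script; the alternative model name is only an illustration):

    ucsinfer evaluate --model all-MiniLM-L6-v2 --offset 0 --limit 500 dataset.csv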
@@ -113,7 +116,8 @@ def evaluate(dataset, offset, limit):
     miss_counts = sorted(miss_counts, key=lambda x: x[1])
-    print(f" === RESULTS === ")
+    print(f"Results for Model {model}")
+    print("=" * len(f"Results for Model {model}") + "\n")
     table = [
         ["Total records in sample:", f"{total}"],
@@ -132,7 +136,7 @@ def evaluate(dataset, offset, limit):
f"{float(miss_counts[-1][1])/float(total):.2%}"]
]
print(tabulate(table, headers=['', 'n', 'pct']))
print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='rst'))
if __name__ == '__main__':
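
The new tablefmt='rst' argument is what produces the reStructuredText simple-table markup captured in MODELS.rst above. A standalone check with the real tabulate API (rows abridged from that table):

    from tabulate import tabulate

    table = [
        ["Total records in sample:", "3445", ""],
        ["Top Result:", "469", "13.61%"],
    ]
    # Emits an RST "simple table": columns separated by spaces,
    # header and body delimited by runs of '='.
    print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='rst'))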


@@ -52,9 +52,11 @@ class InferenceContext:
"""
model: SentenceTransformer
model_name: str
def __init__(self, model: SentenceTransformer):
def __init__(self, model: SentenceTransformer, model_name: str):
self.model = model
self.model_name = model_name
@cached_property
def catlist(self) -> list[Ucs]:
@@ -62,8 +64,9 @@ class InferenceContext:
     @cached_property
     def embeddings(self) -> list[dict]:
-        cache_dir = platformdirs.user_cache_dir("ucsinfer", 'Squad 51')
-        embedding_cache = os.path.join(cache_dir, f"ucs_embedding.cache")
+        cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", 'Squad 51')
+        embedding_cache = os.path.join(cache_dir,
+                                       f"{self.model_name}-ucs_embedding.cache")
 
         embeddings = []
         if os.path.exists(embedding_cache):
@@ -71,7 +74,7 @@ class InferenceContext:
             embeddings = pickle.load(f)
         else:
-            print("Calculating embeddings...")
+            print(f"Calculating embeddings for model {self.model_name}...")
             for cat_defn in self.catlist:
                 embeddings += [{
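
Taken together, these hunks give each model its own embedding cache: the pickle file name now carries the model name, so embeddings computed for one checkpoint are never served to another. The pattern, reduced to a self-contained sketch (the function and callback names are illustrative, not the ucsinfer API):

    import os
    import pickle

    import platformdirs

    def load_or_build_embeddings(model_name, compute_embeddings):
        # Per-model cache path, mirroring the diff above
        cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", 'Squad 51')
        cache_path = os.path.join(cache_dir, f"{model_name}-ucs_embedding.cache")

        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as f:
                return pickle.load(f)        # cache hit: reuse stored embeddings

        embeddings = compute_embeddings()    # cache miss: encode from scratch
        os.makedirs(cache_dir, exist_ok=True)
        with open(cache_path, 'wb') as f:
            pickle.dump(embeddings, f)
        return embeddings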