Compare commits
3 Commits
f40a336893
...
1c5461a6a9
Author | SHA1 | Date | |
---|---|---|---|
![]() |
1c5461a6a9 | ||
![]() |
aea0bbb7be | ||
![]() |
c687e4614f |
18
MODELS.rst
Normal file
18
MODELS.rst
Normal file
@@ -0,0 +1,18 @@
|
||||
Results for Model paraphrase-multilingual-mpnet-base-v2
|
||||
===================
|
||||
|
||||
================================ ==== ======
|
||||
.. n pct
|
||||
================================ ==== ======
|
||||
Total records in sample: 3445
|
||||
Top Result: 469 13.61%
|
||||
Top 5 Result: 519 15.07%
|
||||
Top 10 Result: 513 14.89%
|
||||
================================ ==== ======
|
||||
UCS category count: 752
|
||||
Total categories in sample: 240 31.91%
|
||||
Most missed category (FOLYProp): 1057 30.68%
|
||||
================================ ==== ======
|
||||
|
||||
|
||||
|
@@ -12,8 +12,8 @@ from .util import ffmpeg_description, parse_ucs
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.option('--verbose', flag_value='verbose', help='Verbose output')
|
||||
def ucsinfer(verbose: bool):
|
||||
# @click.option('--verbose', flag_value='verbose', help='Verbose output')
|
||||
def ucsinfer():
|
||||
pass
|
||||
|
||||
|
||||
@@ -68,14 +68,17 @@ def finetune():
|
||||
@ucsinfer.command('evaluate')
|
||||
@click.option('--offset', type=int, default=0)
|
||||
@click.option('--limit', type=int, default=-1)
|
||||
@click.option('--no-foley', type=bool, default=False)
|
||||
@click.option('--model', type=str,
|
||||
default="paraphrase-multilingual-mpnet-base-v2")
|
||||
@click.argument('dataset', type=click.File('r', encoding='utf8'),
|
||||
default='dataset.csv')
|
||||
def evaluate(dataset, offset, limit):
|
||||
def evaluate(dataset, offset, limit, model, no_foley):
|
||||
"""
|
||||
Use datasets to evauluate model performance
|
||||
"""
|
||||
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
|
||||
ctx = InferenceContext(model)
|
||||
m = SentenceTransformer(model)
|
||||
ctx = InferenceContext(m, model)
|
||||
reader = csv.reader(dataset)
|
||||
|
||||
results = []
|
||||
@@ -113,7 +116,8 @@ def evaluate(dataset, offset, limit):
|
||||
|
||||
miss_counts = sorted(miss_counts, key=lambda x: x[1])
|
||||
|
||||
print(f" === RESULTS === ")
|
||||
print(f"Results for Model {model}")
|
||||
print("=====\n")
|
||||
|
||||
table = [
|
||||
["Total records in sample:", f"{total}"],
|
||||
@@ -132,7 +136,7 @@ def evaluate(dataset, offset, limit):
|
||||
f"{float(miss_counts[-1][1])/float(total):.2%}"]
|
||||
]
|
||||
|
||||
print(tabulate(table, headers=['', 'n', 'pct']))
|
||||
print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='rst'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -52,9 +52,11 @@ class InferenceContext:
|
||||
"""
|
||||
|
||||
model: SentenceTransformer
|
||||
model_name: str
|
||||
|
||||
def __init__(self, model: SentenceTransformer):
|
||||
def __init__(self, model: SentenceTransformer, model_name: str):
|
||||
self.model = model
|
||||
self.model_name = model_name
|
||||
|
||||
@cached_property
|
||||
def catlist(self) -> list[Ucs]:
|
||||
@@ -62,8 +64,9 @@ class InferenceContext:
|
||||
|
||||
@cached_property
|
||||
def embeddings(self) -> list[dict]:
|
||||
cache_dir = platformdirs.user_cache_dir("ucsinfer", 'Squad 51')
|
||||
embedding_cache = os.path.join(cache_dir, f"ucs_embedding.cache")
|
||||
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", 'Squad 51')
|
||||
embedding_cache = os.path.join(cache_dir,
|
||||
f"{self.model_name}-ucs_embedding.cache")
|
||||
embeddings = []
|
||||
|
||||
if os.path.exists(embedding_cache):
|
||||
@@ -71,7 +74,7 @@ class InferenceContext:
|
||||
embeddings = pickle.load(f)
|
||||
|
||||
else:
|
||||
print("Calculating embeddings...")
|
||||
print(f"Calculating embeddings for model {self.model_name}...")
|
||||
|
||||
for cat_defn in self.catlist:
|
||||
embeddings += [{
|
||||
|
@@ -4,9 +4,6 @@ from typing import NamedTuple, Optional
|
||||
from re import match
|
||||
|
||||
|
||||
from .inference import Ucs
|
||||
|
||||
|
||||
def ffmpeg_description(path: str) -> Optional[str]:
|
||||
result = subprocess.run(['ffprobe', '-show_format', '-of',
|
||||
'json', path], capture_output=True)
|
||||
|
Reference in New Issue
Block a user