diff --git a/MODELS.rst b/MODELS.rst new file mode 100644 index 0000000..4dc1c4a --- /dev/null +++ b/MODELS.rst @@ -0,0 +1,18 @@ +Results for Model paraphrase-multilingual-mpnet-base-v2 +=================== + +================================ ==== ====== +.. n pct +================================ ==== ====== +Total records in sample: 3445 +Top Result: 469 13.61% +Top 5 Result: 519 15.07% +Top 10 Result: 513 14.89% +================================ ==== ====== +UCS category count: 752 +Total categories in sample: 240 31.91% +Most missed category (FOLYProp): 1057 30.68% +================================ ==== ====== + + + diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py index a787d85..e0d2c59 100644 --- a/ucsinfer/__main__.py +++ b/ucsinfer/__main__.py @@ -12,8 +12,8 @@ from .util import ffmpeg_description, parse_ucs @click.group() -@click.option('--verbose', flag_value='verbose', help='Verbose output') -def ucsinfer(verbose: bool): +# @click.option('--verbose', flag_value='verbose', help='Verbose output') +def ucsinfer(): pass @@ -68,14 +68,17 @@ def finetune(): @ucsinfer.command('evaluate') @click.option('--offset', type=int, default=0) @click.option('--limit', type=int, default=-1) +@click.option('--no-foley', type=bool, default=False) +@click.option('--model', type=str, + default="paraphrase-multilingual-mpnet-base-v2") @click.argument('dataset', type=click.File('r', encoding='utf8'), default='dataset.csv') -def evaluate(dataset, offset, limit): +def evaluate(dataset, offset, limit, model, no_foley): """ Use datasets to evauluate model performance """ - model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2") - ctx = InferenceContext(model) + m = SentenceTransformer(model) + ctx = InferenceContext(m, model) reader = csv.reader(dataset) results = [] @@ -113,7 +116,8 @@ def evaluate(dataset, offset, limit): miss_counts = sorted(miss_counts, key=lambda x: x[1]) - print(f" === RESULTS === ") + print(f"Results for Model {model}") + print("=====\n") table = [ ["Total records in sample:", f"{total}"], @@ -132,7 +136,7 @@ def evaluate(dataset, offset, limit): f"{float(miss_counts[-1][1])/float(total):.2%}"] ] - print(tabulate(table, headers=['', 'n', 'pct'])) + print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='rst')) if __name__ == '__main__': diff --git a/ucsinfer/inference.py b/ucsinfer/inference.py index c7cee72..630c681 100644 --- a/ucsinfer/inference.py +++ b/ucsinfer/inference.py @@ -52,9 +52,11 @@ class InferenceContext: """ model: SentenceTransformer + model_name: str - def __init__(self, model: SentenceTransformer): + def __init__(self, model: SentenceTransformer, model_name: str): self.model = model + self.model_name = model_name @cached_property def catlist(self) -> list[Ucs]: @@ -62,8 +64,9 @@ class InferenceContext: @cached_property def embeddings(self) -> list[dict]: - cache_dir = platformdirs.user_cache_dir("ucsinfer", 'Squad 51') - embedding_cache = os.path.join(cache_dir, f"ucs_embedding.cache") + cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", 'Squad 51') + embedding_cache = os.path.join(cache_dir, + f"{self.model_name}-ucs_embedding.cache") embeddings = [] if os.path.exists(embedding_cache): @@ -71,7 +74,7 @@ class InferenceContext: embeddings = pickle.load(f) else: - print("Calculating embeddings...") + print(f"Calculating embeddings for model {self.model_name}...") for cat_defn in self.catlist: embeddings += [{