Removed unnecessary parameter

This commit is contained in:
2025-09-03 11:42:10 -07:00
parent d330f47462
commit 184edcb7e4
2 changed files with 10 additions and 10 deletions

View File

@@ -3,7 +3,6 @@ import sys
import csv
import logging
from sentence_transformers import SentenceTransformer
import tqdm
import click
from tabulate import tabulate, SEPARATING_LINE
@@ -59,8 +58,7 @@ def recommend(text, paths, model, interactive, skip_ucs):
the synonyms an explanations definied for each UCS subcategory. A list
of ranked subcategories is printed to the terminal for each PATH.
"""
m = SentenceTransformer(model)
ctx = InferenceContext(m, model)
ctx = InferenceContext(model)
if text is not None:
print_recommendation(None, text, ctx, interactive_rename=False)
@@ -177,8 +175,7 @@ def evaluate(dataset, offset, limit, model, no_foley):
classified according to their subject and not wether or not they were
foley, and so these categories can be excluded with the --no-foley option.
"""
m = SentenceTransformer(model)
ctx = InferenceContext(m, model)
ctx = InferenceContext(model)
reader = csv.reader(dataset)
print(f"Evaluating model {model}...")

View File

@@ -54,8 +54,8 @@ class InferenceContext:
model: SentenceTransformer
model_name: str
def __init__(self, model: SentenceTransformer, model_name: str):
self.model = model
def __init__(self, model_name: str):
self.model = SentenceTransformer(model_name)
self.model_name = model_name
@cached_property
@@ -64,9 +64,12 @@ class InferenceContext:
@cached_property
def embeddings(self) -> list[dict]:
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", 'Squad 51')
embedding_cache = os.path.join(cache_dir,
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
'Squad 51')
embedding_cache = os.path.join(
cache_dir,
f"{self.model_name}-ucs_embedding.cache")
embeddings = []
if os.path.exists(embedding_cache):