Removed unnecessary parameter

This commit is contained in:
2025-09-03 11:42:10 -07:00
parent d330f47462
commit 184edcb7e4
2 changed files with 10 additions and 10 deletions

View File

@@ -3,7 +3,6 @@ import sys
import csv import csv
import logging import logging
from sentence_transformers import SentenceTransformer
import tqdm import tqdm
import click import click
from tabulate import tabulate, SEPARATING_LINE from tabulate import tabulate, SEPARATING_LINE
@@ -59,8 +58,7 @@ def recommend(text, paths, model, interactive, skip_ucs):
the synonyms an explanations definied for each UCS subcategory. A list the synonyms an explanations definied for each UCS subcategory. A list
of ranked subcategories is printed to the terminal for each PATH. of ranked subcategories is printed to the terminal for each PATH.
""" """
m = SentenceTransformer(model) ctx = InferenceContext(model)
ctx = InferenceContext(m, model)
if text is not None: if text is not None:
print_recommendation(None, text, ctx, interactive_rename=False) print_recommendation(None, text, ctx, interactive_rename=False)
@@ -177,8 +175,7 @@ def evaluate(dataset, offset, limit, model, no_foley):
classified according to their subject and not wether or not they were classified according to their subject and not wether or not they were
foley, and so these categories can be excluded with the --no-foley option. foley, and so these categories can be excluded with the --no-foley option.
""" """
m = SentenceTransformer(model) ctx = InferenceContext(model)
ctx = InferenceContext(m, model)
reader = csv.reader(dataset) reader = csv.reader(dataset)
print(f"Evaluating model {model}...") print(f"Evaluating model {model}...")

View File

@@ -54,8 +54,8 @@ class InferenceContext:
model: SentenceTransformer model: SentenceTransformer
model_name: str model_name: str
def __init__(self, model: SentenceTransformer, model_name: str): def __init__(self, model_name: str):
self.model = model self.model = SentenceTransformer(model_name)
self.model_name = model_name self.model_name = model_name
@cached_property @cached_property
@@ -64,9 +64,12 @@ class InferenceContext:
@cached_property @cached_property
def embeddings(self) -> list[dict]: def embeddings(self) -> list[dict]:
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", 'Squad 51') cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
embedding_cache = os.path.join(cache_dir, 'Squad 51')
f"{self.model_name}-ucs_embedding.cache") embedding_cache = os.path.join(
cache_dir,
f"{self.model_name}-ucs_embedding.cache")
embeddings = [] embeddings = []
if os.path.exists(embedding_cache): if os.path.exists(embedding_cache):