From 184edcb7e43b5b44caf710fb894167cda9599c8e Mon Sep 17 00:00:00 2001 From: Jamie Hardt Date: Wed, 3 Sep 2025 11:42:10 -0700 Subject: [PATCH] Removed unnecessary parameter --- ucsinfer/__main__.py | 7 ++----- ucsinfer/inference.py | 13 ++++++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py index 63b0959..2fc1e06 100644 --- a/ucsinfer/__main__.py +++ b/ucsinfer/__main__.py @@ -3,7 +3,6 @@ import sys import csv import logging -from sentence_transformers import SentenceTransformer import tqdm import click from tabulate import tabulate, SEPARATING_LINE @@ -59,8 +58,7 @@ def recommend(text, paths, model, interactive, skip_ucs): the synonyms an explanations definied for each UCS subcategory. A list of ranked subcategories is printed to the terminal for each PATH. """ - m = SentenceTransformer(model) - ctx = InferenceContext(m, model) + ctx = InferenceContext(model) if text is not None: print_recommendation(None, text, ctx, interactive_rename=False) @@ -177,8 +175,7 @@ def evaluate(dataset, offset, limit, model, no_foley): classified according to their subject and not wether or not they were foley, and so these categories can be excluded with the --no-foley option. """ - m = SentenceTransformer(model) - ctx = InferenceContext(m, model) + ctx = InferenceContext(model) reader = csv.reader(dataset) print(f"Evaluating model {model}...") diff --git a/ucsinfer/inference.py b/ucsinfer/inference.py index 630c681..b5fc4a5 100644 --- a/ucsinfer/inference.py +++ b/ucsinfer/inference.py @@ -54,8 +54,8 @@ class InferenceContext: model: SentenceTransformer model_name: str - def __init__(self, model: SentenceTransformer, model_name: str): - self.model = model + def __init__(self, model_name: str): + self.model = SentenceTransformer(model_name) self.model_name = model_name @cached_property @@ -64,9 +64,12 @@ class InferenceContext: @cached_property def embeddings(self) -> list[dict]: - cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", 'Squad 51') - embedding_cache = os.path.join(cache_dir, - f"{self.model_name}-ucs_embedding.cache") + cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", + 'Squad 51') + embedding_cache = os.path.join( + cache_dir, + f"{self.model_name}-ucs_embedding.cache") + embeddings = [] if os.path.exists(embedding_cache):