Removed unnecessary parameter
This commit is contained in:
@@ -3,7 +3,6 @@ import sys
|
||||
import csv
|
||||
import logging
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import tqdm
|
||||
import click
|
||||
from tabulate import tabulate, SEPARATING_LINE
|
||||
@@ -59,8 +58,7 @@ def recommend(text, paths, model, interactive, skip_ucs):
|
||||
the synonyms an explanations definied for each UCS subcategory. A list
|
||||
of ranked subcategories is printed to the terminal for each PATH.
|
||||
"""
|
||||
m = SentenceTransformer(model)
|
||||
ctx = InferenceContext(m, model)
|
||||
ctx = InferenceContext(model)
|
||||
|
||||
if text is not None:
|
||||
print_recommendation(None, text, ctx, interactive_rename=False)
|
||||
@@ -177,8 +175,7 @@ def evaluate(dataset, offset, limit, model, no_foley):
|
||||
classified according to their subject and not wether or not they were
|
||||
foley, and so these categories can be excluded with the --no-foley option.
|
||||
"""
|
||||
m = SentenceTransformer(model)
|
||||
ctx = InferenceContext(m, model)
|
||||
ctx = InferenceContext(model)
|
||||
reader = csv.reader(dataset)
|
||||
|
||||
print(f"Evaluating model {model}...")
|
||||
|
@@ -54,8 +54,8 @@ class InferenceContext:
|
||||
model: SentenceTransformer
|
||||
model_name: str
|
||||
|
||||
def __init__(self, model: SentenceTransformer, model_name: str):
|
||||
self.model = model
|
||||
def __init__(self, model_name: str):
|
||||
self.model = SentenceTransformer(model_name)
|
||||
self.model_name = model_name
|
||||
|
||||
@cached_property
|
||||
@@ -64,9 +64,12 @@ class InferenceContext:
|
||||
|
||||
@cached_property
|
||||
def embeddings(self) -> list[dict]:
|
||||
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", 'Squad 51')
|
||||
embedding_cache = os.path.join(cache_dir,
|
||||
f"{self.model_name}-ucs_embedding.cache")
|
||||
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
|
||||
'Squad 51')
|
||||
embedding_cache = os.path.join(
|
||||
cache_dir,
|
||||
f"{self.model_name}-ucs_embedding.cache")
|
||||
|
||||
embeddings = []
|
||||
|
||||
if os.path.exists(embedding_cache):
|
||||
|
Reference in New Issue
Block a user