Removed unnecessary parameter
This commit is contained in:
@@ -3,7 +3,6 @@ import sys
|
|||||||
import csv
|
import csv
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
import tqdm
|
import tqdm
|
||||||
import click
|
import click
|
||||||
from tabulate import tabulate, SEPARATING_LINE
|
from tabulate import tabulate, SEPARATING_LINE
|
||||||
@@ -59,8 +58,7 @@ def recommend(text, paths, model, interactive, skip_ucs):
|
|||||||
the synonyms an explanations definied for each UCS subcategory. A list
|
the synonyms an explanations definied for each UCS subcategory. A list
|
||||||
of ranked subcategories is printed to the terminal for each PATH.
|
of ranked subcategories is printed to the terminal for each PATH.
|
||||||
"""
|
"""
|
||||||
m = SentenceTransformer(model)
|
ctx = InferenceContext(model)
|
||||||
ctx = InferenceContext(m, model)
|
|
||||||
|
|
||||||
if text is not None:
|
if text is not None:
|
||||||
print_recommendation(None, text, ctx, interactive_rename=False)
|
print_recommendation(None, text, ctx, interactive_rename=False)
|
||||||
@@ -177,8 +175,7 @@ def evaluate(dataset, offset, limit, model, no_foley):
|
|||||||
classified according to their subject and not wether or not they were
|
classified according to their subject and not wether or not they were
|
||||||
foley, and so these categories can be excluded with the --no-foley option.
|
foley, and so these categories can be excluded with the --no-foley option.
|
||||||
"""
|
"""
|
||||||
m = SentenceTransformer(model)
|
ctx = InferenceContext(model)
|
||||||
ctx = InferenceContext(m, model)
|
|
||||||
reader = csv.reader(dataset)
|
reader = csv.reader(dataset)
|
||||||
|
|
||||||
print(f"Evaluating model {model}...")
|
print(f"Evaluating model {model}...")
|
||||||
|
@@ -54,8 +54,8 @@ class InferenceContext:
|
|||||||
model: SentenceTransformer
|
model: SentenceTransformer
|
||||||
model_name: str
|
model_name: str
|
||||||
|
|
||||||
def __init__(self, model: SentenceTransformer, model_name: str):
|
def __init__(self, model_name: str):
|
||||||
self.model = model
|
self.model = SentenceTransformer(model_name)
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@@ -64,9 +64,12 @@ class InferenceContext:
|
|||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def embeddings(self) -> list[dict]:
|
def embeddings(self) -> list[dict]:
|
||||||
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", 'Squad 51')
|
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
|
||||||
embedding_cache = os.path.join(cache_dir,
|
'Squad 51')
|
||||||
|
embedding_cache = os.path.join(
|
||||||
|
cache_dir,
|
||||||
f"{self.model_name}-ucs_embedding.cache")
|
f"{self.model_name}-ucs_embedding.cache")
|
||||||
|
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
|
||||||
if os.path.exists(embedding_cache):
|
if os.path.exists(embedding_cache):
|
||||||
|
Reference in New Issue
Block a user