Fixed a bug in cat masking
This commit is contained in:
@@ -103,7 +103,12 @@ class InferenceContext:
|
||||
else:
|
||||
print(f"Calculating embeddings for model {self.model_name}...")
|
||||
|
||||
for cat_defn in self.catlist:
|
||||
# we need to calculate the embeddings for all cats, not just the
|
||||
# ones we're loading for this run
|
||||
|
||||
full_catlist = load_ucs(full_ucs= True)
|
||||
|
||||
for cat_defn in full_catlist:
|
||||
embeddings += [{
|
||||
'CatID': cat_defn.catid,
|
||||
'Embedding': self._encode_category(cat_defn)
|
||||
@@ -113,7 +118,9 @@ class InferenceContext:
|
||||
with open(embedding_cache_path, 'wb') as g:
|
||||
pickle.dump(embeddings, g)
|
||||
|
||||
return embeddings
|
||||
whitelisted_cats = [cat.catid for cat in self.catlist]
|
||||
|
||||
return [e for e in embeddings if e['CatID'] in whitelisted_cats]
|
||||
|
||||
def _encode_category(self, cat: Ucs) -> np.ndarray:
|
||||
sentence_components = [cat.explanations,
|
||||
|
@@ -1,7 +1,6 @@
|
||||
# recommend.py
|
||||
|
||||
from re import match
|
||||
|
||||
from .inference import InferenceContext
|
||||
|
||||
def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
|
||||
@@ -23,6 +22,7 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
|
||||
print(f"Path: {path}")
|
||||
|
||||
print(f"Text: {text or '<None>'}")
|
||||
|
||||
for i, r in enumerate(recs):
|
||||
cat, subcat, _ = ctx.lookup_category(r)
|
||||
print(f"- {i}: {r} ({cat}-{subcat})")
|
||||
|
Reference in New Issue
Block a user