Fixed a bug in cat masking

This commit is contained in:
2025-09-04 10:54:35 -07:00
parent 103fffe0a4
commit 04332b73ee
2 changed files with 10 additions and 3 deletions

View File

@@ -103,7 +103,12 @@ class InferenceContext:
else:
print(f"Calculating embeddings for model {self.model_name}...")
for cat_defn in self.catlist:
# we need to calculate the embeddings for all cats, not just the
# ones we're loading for this run
full_catlist = load_ucs(full_ucs= True)
for cat_defn in full_catlist:
embeddings += [{
'CatID': cat_defn.catid,
'Embedding': self._encode_category(cat_defn)
@@ -113,7 +118,9 @@ class InferenceContext:
with open(embedding_cache_path, 'wb') as g:
pickle.dump(embeddings, g)
return embeddings
whitelisted_cats = [cat.catid for cat in self.catlist]
return [e for e in embeddings if e['CatID'] in whitelisted_cats]
def _encode_category(self, cat: Ucs) -> np.ndarray:
sentence_components = [cat.explanations,

View File

@@ -1,7 +1,6 @@
# recommend.py
from re import match
from .inference import InferenceContext
def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
@@ -23,6 +22,7 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
print(f"Path: {path}")
print(f"Text: {text or '<None>'}")
for i, r in enumerate(recs):
cat, subcat, _ = ctx.lookup_category(r)
print(f"- {i}: {r} ({cat}-{subcat})")