From 04332b73eebc6fe8a4973a6313dc86b0ab700c91 Mon Sep 17 00:00:00 2001 From: Jamie Hardt Date: Thu, 4 Sep 2025 10:54:35 -0700 Subject: [PATCH] Fixed a bug in cat masking --- ucsinfer/inference.py | 11 +++++++++-- ucsinfer/recommend.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/ucsinfer/inference.py b/ucsinfer/inference.py index aab3f7d..be6739e 100644 --- a/ucsinfer/inference.py +++ b/ucsinfer/inference.py @@ -103,7 +103,12 @@ class InferenceContext: else: print(f"Calculating embeddings for model {self.model_name}...") - for cat_defn in self.catlist: + # we need to calculate the embeddings for all cats, not just the + # ones we're loading for this run + + full_catlist = load_ucs(full_ucs= True) + + for cat_defn in full_catlist: embeddings += [{ 'CatID': cat_defn.catid, 'Embedding': self._encode_category(cat_defn) @@ -113,7 +118,9 @@ class InferenceContext: with open(embedding_cache_path, 'wb') as g: pickle.dump(embeddings, g) - return embeddings + whitelisted_cats = [cat.catid for cat in self.catlist] + + return [e for e in embeddings if e['CatID'] in whitelisted_cats] def _encode_category(self, cat: Ucs) -> np.ndarray: sentence_components = [cat.explanations, diff --git a/ucsinfer/recommend.py b/ucsinfer/recommend.py index 5b98f33..6c92abd 100644 --- a/ucsinfer/recommend.py +++ b/ucsinfer/recommend.py @@ -1,7 +1,6 @@ # recommend.py from re import match - from .inference import InferenceContext def print_recommendation(path: str | None, text: str, ctx: InferenceContext, @@ -23,6 +22,7 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext, print(f"Path: {path}") print(f"Text: {text or ''}") + for i, r in enumerate(recs): cat, subcat, _ = ctx.lookup_category(r) print(f"- {i}: {r} ({cat}-{subcat})")