diff --git a/ucsinfer/inference.py b/ucsinfer/inference.py index aab3f7d..be6739e 100644 --- a/ucsinfer/inference.py +++ b/ucsinfer/inference.py @@ -103,7 +103,12 @@ class InferenceContext: else: print(f"Calculating embeddings for model {self.model_name}...") - for cat_defn in self.catlist: + # we need to calculate the embeddings for all cats, not just the + # ones we're loading for this run + + full_catlist = load_ucs(full_ucs= True) + + for cat_defn in full_catlist: embeddings += [{ 'CatID': cat_defn.catid, 'Embedding': self._encode_category(cat_defn) @@ -113,7 +118,9 @@ class InferenceContext: with open(embedding_cache_path, 'wb') as g: pickle.dump(embeddings, g) - return embeddings + whitelisted_cats = [cat.catid for cat in self.catlist] + + return [e for e in embeddings if e['CatID'] in whitelisted_cats] def _encode_category(self, cat: Ucs) -> np.ndarray: sentence_components = [cat.explanations, diff --git a/ucsinfer/recommend.py b/ucsinfer/recommend.py index 5b98f33..6c92abd 100644 --- a/ucsinfer/recommend.py +++ b/ucsinfer/recommend.py @@ -1,7 +1,6 @@ # recommend.py from re import match - from .inference import InferenceContext def print_recommendation(path: str | None, text: str, ctx: InferenceContext, @@ -23,6 +22,7 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext, print(f"Path: {path}") print(f"Text: {text or ''}") + for i, r in enumerate(recs): cat, subcat, _ = ctx.lookup_category(r) print(f"- {i}: {r} ({cat}-{subcat})")