Fixed a bug in cat masking
This commit is contained in:
@@ -103,7 +103,12 @@ class InferenceContext:
|
|||||||
else:
|
else:
|
||||||
print(f"Calculating embeddings for model {self.model_name}...")
|
print(f"Calculating embeddings for model {self.model_name}...")
|
||||||
|
|
||||||
for cat_defn in self.catlist:
|
# we need to calculate the embeddings for all cats, not just the
|
||||||
|
# ones we're loading for this run
|
||||||
|
|
||||||
|
full_catlist = load_ucs(full_ucs= True)
|
||||||
|
|
||||||
|
for cat_defn in full_catlist:
|
||||||
embeddings += [{
|
embeddings += [{
|
||||||
'CatID': cat_defn.catid,
|
'CatID': cat_defn.catid,
|
||||||
'Embedding': self._encode_category(cat_defn)
|
'Embedding': self._encode_category(cat_defn)
|
||||||
@@ -113,7 +118,9 @@ class InferenceContext:
|
|||||||
with open(embedding_cache_path, 'wb') as g:
|
with open(embedding_cache_path, 'wb') as g:
|
||||||
pickle.dump(embeddings, g)
|
pickle.dump(embeddings, g)
|
||||||
|
|
||||||
return embeddings
|
whitelisted_cats = [cat.catid for cat in self.catlist]
|
||||||
|
|
||||||
|
return [e for e in embeddings if e['CatID'] in whitelisted_cats]
|
||||||
|
|
||||||
def _encode_category(self, cat: Ucs) -> np.ndarray:
|
def _encode_category(self, cat: Ucs) -> np.ndarray:
|
||||||
sentence_components = [cat.explanations,
|
sentence_components = [cat.explanations,
|
||||||
|
@@ -1,7 +1,6 @@
|
|||||||
# recommend.py
|
# recommend.py
|
||||||
|
|
||||||
from re import match
|
from re import match
|
||||||
|
|
||||||
from .inference import InferenceContext
|
from .inference import InferenceContext
|
||||||
|
|
||||||
def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
|
def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
|
||||||
@@ -23,6 +22,7 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
|
|||||||
print(f"Path: {path}")
|
print(f"Path: {path}")
|
||||||
|
|
||||||
print(f"Text: {text or '<None>'}")
|
print(f"Text: {text or '<None>'}")
|
||||||
|
|
||||||
for i, r in enumerate(recs):
|
for i, r in enumerate(recs):
|
||||||
cat, subcat, _ = ctx.lookup_category(r)
|
cat, subcat, _ = ctx.lookup_category(r)
|
||||||
print(f"- {i}: {r} ({cat}-{subcat})")
|
print(f"- {i}: {r} ({cat}-{subcat})")
|
||||||
|
Reference in New Issue
Block a user