Files
ucsinfer/ucsinfer/inference.py
2025-08-26 17:59:01 -07:00

122 lines
3.8 KiB
Python

import os.path
import json
import pickle
from functools import cached_property
from typing import NamedTuple
import numpy as np
import platformdirs
from sentence_transformers import SentenceTransformer
def classify_text_ranked(text, embeddings_list, model, limit=5):
    """
    Rank UCS categories by semantic similarity to `text`.

    :param text: free-text description to classify
    :param embeddings_list: dicts with 'CatID' and 'Embedding' keys
    :param model: sentence-transformer model providing encode()/similarity()
    :param limit: number of top matches to return
    :returns: the `limit` best-matching CatIDs, most similar first
    """
    query = model.encode(text, convert_to_numpy=True)
    matrix = np.array([entry['Embedding'] for entry in embeddings_list])
    scores = model.similarity(query, matrix)[0]
    # argsort is ascending; take the last `limit` indices and flip them
    # so the best match comes first.
    top = np.argsort(scores)[-limit:][::-1]
    return [embeddings_list[i]['CatID'] for i in top]
class Ucs(NamedTuple):
    """One UCS category definition from the published schedule."""
    catid: str              # e.g. "AIRBlow"
    category: str           # top-level category name
    subcategory: str        # subcategory name
    explanations: str       # descriptive phrase for the category
    synonymns: list[str]    # synonym words (field name kept for compatibility)

    @classmethod
    def from_dict(cls, d: dict):
        """Build a Ucs record from one entry of the UCS JSON schedule."""
        return cls(
            catid=d['CatID'],
            category=d['Category'],
            subcategory=d['SubCategory'],
            explanations=d['Explanations'],
            synonymns=d['Synonyms'],
        )
def load_ucs() -> list[Ucs]:
    """
    Load the bundled UCS category schedule.

    Reads the ucs-community English JSON definitions shipped next to this
    module and converts each entry into a `Ucs` record.

    :returns: one `Ucs` per category in the schedule
    :raises FileNotFoundError: if the bundled en.json is missing
    """
    file_root_dir = os.path.dirname(os.path.abspath(__file__))
    ucs_defs = os.path.join(file_root_dir, 'ucs-community', 'json',
                            'en.json')
    # The schedule is UTF-8 JSON; don't rely on the platform default
    # encoding (e.g. cp1252 on Windows would mangle non-ASCII synonyms).
    with open(ucs_defs, 'r', encoding='utf-8') as f:
        cats = json.load(f)
    return [Ucs.from_dict(cat) for cat in cats]
class InferenceContext:
    """
    Maintains caches and resources for UCS category inference.

    Holds a sentence-transformer model and lazily builds — and disk-caches,
    keyed by model name — an embedding for every UCS category definition.
    """

    # Annotations quoted as forward references so the class can be
    # imported without the (heavy) sentence_transformers import resolved.
    model: 'SentenceTransformer'    # encoder used for all embedding work
    model_name: str                 # keys the on-disk embedding cache

    def __init__(self, model: 'SentenceTransformer', model_name: str):
        self.model = model
        self.model_name = model_name

    @cached_property
    def catlist(self) -> 'list[Ucs]':
        """Full UCS category schedule, loaded once from the bundled JSON."""
        return load_ucs()

    @cached_property
    def embeddings(self) -> list[dict]:
        """
        One ``{'CatID': ..., 'Embedding': ...}`` dict per UCS category.

        Results are cached on disk per model name so the (slow) encoding
        pass runs at most once per model.
        """
        cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
                                                'Squad 51')
        # NOTE(review): model_name may contain path separators (e.g.
        # "org/model"); makedirs below tolerates that on POSIX but the
        # Windows behaviour should be confirmed.
        embedding_cache = os.path.join(
            cache_dir, f"{self.model_name}-ucs_embedding.cache")

        if os.path.exists(embedding_cache):
            # Pickle is acceptable here only because the cache file is
            # produced locally by this code, never downloaded.
            with open(embedding_cache, 'rb') as f:
                return pickle.load(f)

        print(f"Calculating embeddings for model {self.model_name}...")
        embeddings = [
            {'CatID': cat_defn.catid,
             'Embedding': self._encode_category(cat_defn)}
            for cat_defn in self.catlist
        ]
        os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
        with open(embedding_cache, 'wb') as g:
            pickle.dump(embeddings, g)
        return embeddings

    @cached_property
    def _embedding_matrix(self) -> np.ndarray:
        """All category embeddings stacked row-wise, built once per context."""
        return np.array([info['Embedding'] for info in self.embeddings])

    def _encode_category(self, cat: 'Ucs') -> np.ndarray:
        """Embed one category from its explanation, names and synonyms."""
        sentence_components = [cat.explanations,
                               cat.category,
                               cat.subcategory]
        sentence_components += cat.synonymns
        sentence = ", ".join(sentence_components)
        return self.model.encode(sentence, convert_to_numpy=True)

    def classify_text_ranked(self, text: str, limit: int = 5) -> list[str]:
        """
        Return the `limit` UCS CatIDs most similar to `text`, best first.
        """
        text_embedding = self.model.encode(text, convert_to_numpy=True)
        # Use the cached matrix instead of rebuilding it on every call.
        sim = self.model.similarity(text_embedding, self._embedding_matrix)[0]
        maxinds = np.argsort(sim)[-limit:]
        return [self.embeddings[x]['CatID'] for x in reversed(maxinds)]

    def lookup_category(self, catid) -> tuple[str, str, str]:
        """
        Get the category, subcategory and explanations phrase for a `catid`

        :raises: StopIteration if `catid` is not on the schedule
        """
        i = (
            (x.category, x.subcategory, x.explanations)
            for x in self.catlist if x.catid == catid
        )
        return next(i)