Files
ucsinfer/ucsinfer/inference.py

156 lines
5.0 KiB
Python

import os.path
import json
import pickle
from functools import cached_property
from typing import NamedTuple
import numpy as np
import platformdirs
from sentence_transformers import SentenceTransformer
def classify_text_ranked(text, embeddings_list, model, limit=5):
    """
    Rank UCS CatIDs by their similarity to `text`.

    :param text: query text to classify
    :param embeddings_list: dicts with 'CatID' and 'Embedding' keys
    :param model: a SentenceTransformer (or compatible) model providing
        `encode` and `similarity`
    :param limit: maximum number of CatIDs to return
    :returns: up to `limit` CatIDs, most-similar first
    """
    query_vec = model.encode(text, convert_to_numpy=True)
    matrix = np.array([entry['Embedding'] for entry in embeddings_list])
    scores = model.similarity(query_vec, matrix)[0]
    # Take the `limit` highest-scoring indices, best first.
    top = np.argsort(scores)[-limit:]
    return [embeddings_list[idx]['CatID'] for idx in top[::-1]]
class Ucs(NamedTuple):
    """One UCS category record from the UCS schedule."""

    catid: str              # UCS CatID, e.g. "AMBForst"
    category: str           # top-level category name
    subcategory: str        # subcategory name
    explanations: str       # descriptive phrase for the category
    synonymns: list[str]    # synonym words (field name kept for compat)

    @classmethod
    def from_dict(cls, d: dict):
        """Build a Ucs record from one entry of the UCS JSON schedule."""
        return cls(
            catid=d['CatID'],
            category=d['Category'],
            subcategory=d['SubCategory'],
            explanations=d['Explanations'],
            synonymns=d['Synonyms'],
        )
def load_ucs(full_ucs: bool = True) -> list[Ucs]:
    """
    Load the UCS category list from the bundled `ucs-community` JSON data.

    :param full_ucs: when False, omit the 'FOLEY' and 'ARCHIVED' categories
    :returns: list of Ucs records parsed from en.json
    :raises FileNotFoundError: if the bundled en.json file is missing
    """
    file_root_dir = os.path.dirname(os.path.abspath(__file__))
    ucs_defs = os.path.join(file_root_dir, 'ucs-community', 'json',
                            'en.json')
    # NOTE: the original pre-initialized `cats = []` and immediately
    # overwrote it with json.load; the dead assignment is removed.
    with open(ucs_defs, 'r') as f:
        cats = json.load(f)
    ucs = [Ucs.from_dict(cat) for cat in cats]
    if full_ucs:
        return ucs
    return [cat for cat in ucs
            if cat.category not in ('FOLEY', 'ARCHIVED')]
class InferenceContext:
    """
    Maintains caches and resources for UCS category inference.

    Wraps a SentenceTransformer model and persists both the model and the
    per-category embeddings in the platform user cache directory so repeated
    runs avoid re-downloading the model and re-encoding the schedule.
    """

    model: SentenceTransformer
    model_name: str

    def __init__(self, model_name: str, use_cached_model: bool = True,
                 use_full_ucs: bool = False):
        """
        :param model_name: sentence-transformers model name to load
        :param use_cached_model: load the model from (and save it to) the
            on-disk cache instead of fetching it every run
        :param use_full_ucs: include FOLEY and ARCHIVED categories in
            `catlist`
        """
        self.model_name = model_name
        self.use_full_ucs = use_full_ucs

        model_cache_path = os.path.join(self._cache_dir(),
                                        f"{self.model_name}.cache")
        if use_cached_model and os.path.exists(model_cache_path):
            self.model = SentenceTransformer(model_cache_path)
        else:
            self.model = SentenceTransformer(model_name)
            if use_cached_model:
                # Populate the cache for the next run.
                self.model.save(model_cache_path)

    @staticmethod
    def _cache_dir() -> str:
        """Platform-appropriate user cache directory for this application."""
        # Previously this was computed in two places; factored out so the
        # model cache and embedding cache always agree on the location.
        return platformdirs.user_cache_dir("us.squad51.ucsinfer",
                                           'Squad 51')

    @cached_property
    def catlist(self) -> list[Ucs]:
        """UCS categories active for this run (possibly filtered)."""
        return load_ucs(full_ucs=self.use_full_ucs)

    @cached_property
    def embeddings(self) -> list[dict]:
        """
        Embeddings for the active categories, loaded from the on-disk
        cache or computed (for the full schedule) and cached on first use.

        :returns: dicts of {'CatID': str, 'Embedding': np.ndarray},
            filtered to the categories present in `catlist`
        """
        embedding_cache_path = os.path.join(
            self._cache_dir(),
            f"{self.model_name}-ucs_embedding.cache")

        if os.path.exists(embedding_cache_path):
            # NOTE(review): pickle.load on a cache file this class wrote
            # itself; do not point this path at untrusted data.
            with open(embedding_cache_path, 'rb') as f:
                embeddings = pickle.load(f)
        else:
            print(f"Calculating embeddings for model {self.model_name}...")
            # we need to calculate the embeddings for all cats, not just the
            # ones we're loading for this run
            full_catlist = load_ucs(full_ucs=True)
            embeddings = [
                {'CatID': cat_defn.catid,
                 'Embedding': self._encode_category(cat_defn)}
                for cat_defn in full_catlist
            ]
            os.makedirs(os.path.dirname(embedding_cache_path), exist_ok=True)
            with open(embedding_cache_path, 'wb') as g:
                pickle.dump(embeddings, g)

        # Set membership: the original scanned a list for every entry.
        whitelisted_cats = {cat.catid for cat in self.catlist}
        return [e for e in embeddings if e['CatID'] in whitelisted_cats]

    def _encode_category(self, cat: Ucs) -> np.ndarray:
        """Encode one category's text (explanations, names, synonyms)."""
        sentence_components = [cat.explanations,
                               cat.category,
                               cat.subcategory]
        sentence_components += cat.synonymns
        sentence = ", ".join(sentence_components)
        return self.model.encode(sentence, convert_to_numpy=True)

    def classify_text_ranked(self, text: str, limit: int = 5) -> list[str]:
        """
        Sort all UCS catids according to their similarity to `text`

        :param text: query text to classify
        :param limit: maximum number of CatIDs to return
        :returns: up to `limit` CatIDs, most-similar first
        """
        # Delegate to the module-level implementation, which this method
        # previously duplicated line for line.
        return classify_text_ranked(text, self.embeddings, self.model,
                                    limit=limit)

    def lookup_category(self, catid) -> tuple[str, str, str]:
        """
        Get the category, subcategory and explanations phrase for a `catid`

        :raises StopIteration: if `catid` is not on the schedule
            (the original docstring named a nonexistent "StopIterator")
        """
        return next(
            (x.category, x.subcategory, x.explanations)
            for x in self.catlist if x.catid == catid
        )