156 lines
5.0 KiB
Python
156 lines
5.0 KiB
Python
|
|
import os.path
|
|
import json
|
|
import pickle
|
|
from functools import cached_property
|
|
|
|
from typing import NamedTuple
|
|
|
|
import numpy as np
|
|
import platformdirs
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
def classify_text_ranked(text, embeddings_list, model, limit=5):
    """
    Rank UCS category IDs by their similarity to `text`.

    :param text: free text to classify
    :param embeddings_list: dicts bearing 'CatID' and 'Embedding' keys
    :param model: model exposing `encode` and `similarity` (e.g. a
        SentenceTransformer)
    :param limit: maximum number of catids to return
    :returns: up to `limit` catids, most similar first
    """
    query = model.encode(text, convert_to_numpy=True)
    matrix = np.array([entry['Embedding'] for entry in embeddings_list])
    scores = model.similarity(query, matrix)[0]
    # argsort is ascending, so the last `limit` indices are the best;
    # walk them back-to-front to report most-similar first
    top = np.argsort(scores)[-limit:]
    return [embeddings_list[idx]['CatID'] for idx in reversed(top)]
|
|
|
|
|
|
class Ucs(NamedTuple):
    """
    One entry of the Universal Category System (UCS) schedule.
    """
    # UCS category identifier (the 'CatID' field of the JSON schedule)
    catid: str
    # top-level category name
    category: str
    # subcategory name
    subcategory: str
    # explanatory phrase describing what sounds belong to this category
    explanations: str
    # alternate terms for the category; the misspelled field name is
    # kept as-is for backward compatibility with existing callers
    synonymns: list[str]

    @classmethod
    def from_dict(cls, d: dict) -> 'Ucs':
        """
        Build a record from one JSON schedule entry.

        :param d: dict with 'CatID', 'Category', 'SubCategory',
            'Explanations' and 'Synonyms' keys
        :raises KeyError: if any expected key is missing
        """
        # construct via cls (not Ucs) so subclasses build instances of
        # themselves
        return cls(catid=d['CatID'],
                   category=d['Category'],
                   subcategory=d['SubCategory'],
                   explanations=d['Explanations'],
                   synonymns=d['Synonyms'])
|
|
|
|
|
|
def load_ucs(full_ucs: bool = True) -> list[Ucs]:
    """
    Load the UCS category schedule from the bundled `ucs-community` JSON.

    :param full_ucs: when True, return every category on the schedule;
        when False, omit the 'FOLEY' and 'ARCHIVED' categories
    :returns: the schedule as a list of Ucs records
    :raises FileNotFoundError: if the bundled en.json is not present
    """
    file_root_dir = os.path.dirname(os.path.abspath(__file__))
    ucs_defs = os.path.join(file_root_dir, 'ucs-community', 'json',
                            'en.json')

    # JSON is UTF-8 by specification; pin the encoding instead of
    # depending on the platform locale
    with open(ucs_defs, 'r', encoding='utf-8') as f:
        cats = json.load(f)

    ucs = [Ucs.from_dict(cat) for cat in cats]

    if full_ucs:
        return ucs
    return [cat for cat in ucs if cat.category not in ('FOLEY', 'ARCHIVED')]
|
|
|
|
|
|
class InferenceContext:
    """
    Maintains caches and resources for UCS category inference.

    Wraps a SentenceTransformer model together with on-disk caches for
    the model itself and for the per-category embeddings derived from it.
    """

    model: SentenceTransformer
    model_name: str

    def __init__(self, model_name: str, use_cached_model: bool = True,
                 use_full_ucs: bool = False):
        """
        :param model_name: name of the SentenceTransformer model to load
        :param use_cached_model: when True, load the model from the user
            cache directory if present, and save it there after a fresh
            load; when False, always load by name
        :param use_full_ucs: when True, infer against the complete UCS
            schedule; otherwise FOLEY and ARCHIVED categories are excluded
        """
        self.model_name = model_name
        self.use_full_ucs = use_full_ucs

        model_cache_path = os.path.join(self._cache_dir(),
                                        f"{self.model_name}.cache")

        if use_cached_model:
            if os.path.exists(model_cache_path):
                self.model = SentenceTransformer(model_cache_path)
            else:
                self.model = SentenceTransformer(model_name)
                self.model.save(model_cache_path)
        else:
            self.model = SentenceTransformer(model_name)

    @staticmethod
    def _cache_dir() -> str:
        """Per-user cache directory for ucsinfer artifacts."""
        return platformdirs.user_cache_dir("us.squad51.ucsinfer",
                                           'Squad 51')

    @cached_property
    def catlist(self) -> list[Ucs]:
        """The UCS categories in effect for this run (may be filtered)."""
        return load_ucs(full_ucs=self.use_full_ucs)

    @cached_property
    def embeddings(self) -> list[dict]:
        """
        Embeddings for the active categories, as a list of dicts with
        'CatID' and 'Embedding' keys. Computed once per model and cached
        on disk; subsequent runs load the pickle instead of re-encoding.
        """
        embedding_cache_path = os.path.join(
            self._cache_dir(),
            f"{self.model_name}-ucs_embedding.cache")

        if os.path.exists(embedding_cache_path):
            # NOTE: pickle is only loaded from our own cache directory,
            # written below — not untrusted input
            with open(embedding_cache_path, 'rb') as f:
                embeddings = pickle.load(f)
        else:
            print(f"Calculating embeddings for model {self.model_name}...")

            # we need to calculate the embeddings for all cats, not just
            # the ones we're loading for this run, so the cache stays
            # valid regardless of the use_full_ucs setting
            embeddings = [
                {
                    'CatID': cat_defn.catid,
                    'Embedding': self._encode_category(cat_defn),
                }
                for cat_defn in load_ucs(full_ucs=True)
            ]

            os.makedirs(os.path.dirname(embedding_cache_path),
                        exist_ok=True)
            with open(embedding_cache_path, 'wb') as g:
                pickle.dump(embeddings, g)

        # set membership: O(1) per lookup instead of scanning a list
        whitelisted_cats = {cat.catid for cat in self.catlist}

        return [e for e in embeddings if e['CatID'] in whitelisted_cats]

    def _encode_category(self, cat: Ucs) -> np.ndarray:
        """
        Encode one category's text fields (explanations, category,
        subcategory and synonyms) into a single embedding.
        """
        sentence_components = [cat.explanations,
                               cat.category,
                               cat.subcategory]
        sentence_components += cat.synonymns
        sentence = ", ".join(sentence_components)
        return self.model.encode(sentence, convert_to_numpy=True)

    def classify_text_ranked(self, text: str, limit: int = 5) -> list[str]:
        """
        Sort all UCS catids according to their similarity to `text`

        :param text: free text to classify
        :param limit: maximum number of catids to return
        :returns: up to `limit` catids, most similar first
        """
        text_embedding = self.model.encode(text, convert_to_numpy=True)
        embeddings = np.array([info['Embedding']
                               for info in self.embeddings])
        sim = self.model.similarity(text_embedding, embeddings)[0]
        # ascending argsort: take the last `limit` and reverse for a
        # best-first ordering
        maxinds = np.argsort(sim)[-limit:]
        return [self.embeddings[x]['CatID'] for x in reversed(maxinds)]

    def lookup_category(self, catid) -> tuple[str, str, str]:
        """
        Get the category, subcategory and explanations phrase for a `catid`

        :raises: StopIteration if `catid` is not on the schedule
        """
        matches = (
            (x.category, x.subcategory, x.explanations)
            for x in self.catlist if x.catid == catid
        )
        return next(matches)
|