Split inference code out to inference.py

This commit is contained in:
Jamie Hardt
2025-08-10 22:14:02 -07:00
parent e294feb74d
commit 75f88d483b
2 changed files with 190 additions and 96 deletions

View File

@@ -5,11 +5,15 @@ import sys
import subprocess import subprocess
import cmd import cmd
from typing import Optional, Tuple, IO from typing import Optional, IO
import numpy as np # from .inference import classify_text_ranked
import platformdirs
import tqdm from .inference import InferenceContext
# import numpy as np
# import platformdirs
# import tqdm
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -19,48 +23,6 @@ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import warnings import warnings
warnings.simplefilter(action='ignore', category=FutureWarning) warnings.simplefilter(action='ignore', category=FutureWarning)
def load_ucs_categories() -> list:
cats = []
ucs_defs = os.path.join(ROOT_DIR, 'ucs-community', 'json', 'en.json')
with open(ucs_defs, 'r') as f:
cats = json.load(f)
return cats
def encoode_category(cat_defn: dict, model: SentenceTransformer) -> np.ndarray:
sentence_components = [cat_defn['Explanations'],
cat_defn['Category'],
cat_defn['SubCategory']
]
sentence_components += cat_defn['Synonyms']
sentence = ", ".join(sentence_components)
return model.encode(sentence, convert_to_numpy=True)
def load_embeddings(ucs: list, model) -> list:
cache_dir = platformdirs.user_cache_dir('ucsinfer', 'Squad 51')
embedding_cache = os.path.join(cache_dir, f"ucs_embedding.cache")
embeddings = []
if os.path.exists(embedding_cache):
with open(embedding_cache, 'rb') as f:
embeddings = pickle.load(f)
else:
print("Calculating embeddings...")
for cat_defn in tqdm.tqdm(ucs):
embeddings += [{
'CatID': cat_defn['CatID'],
'Embedding': encoode_category(cat_defn, model)
}]
os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
with open(embedding_cache, 'wb') as g:
pickle.dump(embeddings, g)
return embeddings
def description(path: str) -> Optional[str]: def description(path: str) -> Optional[str]:
result = subprocess.run(['ffprobe', '-show_format', '-of', result = subprocess.run(['ffprobe', '-show_format', '-of',
@@ -78,16 +40,7 @@ def description(path: str) -> Optional[str]:
if tags: if tags:
return tags.get("comment", None) return tags.get("comment", None)
def recommend_category(ctx, path) -> tuple[str, list]:
def classify_text_ranked(text, embeddings_list, model, limit=5):
text_embedding = model.encode(text, convert_to_numpy=True)
embeddings = np.array([info['Embedding'] for info in embeddings_list])
sim = model.similarity(text_embedding, embeddings)[0]
maxinds = np.argsort(sim)[-limit:]
return [embeddings_list[x]['CatID'] for x in reversed(maxinds)]
def recommend_category(path, embeddings, model) -> Tuple[str, list]:
""" """
Get a text description of the file at `path` and a list of UCS cat IDs Get a text description of the file at `path` and a list of UCS cat IDs
""" """
@@ -95,39 +48,25 @@ def recommend_category(path, embeddings, model) -> Tuple[str, list]:
if desc is None: if desc is None:
desc = os.path.basename(path) desc = os.path.basename(path)
return desc, classify_text_ranked(desc, embeddings, model) return desc, ctx.classify_text_ranked(desc)
def lookup_cat(catid: str, ucs: list) -> tuple[str,str, str]:
return next( ((x['Category'], x['SubCategory'], x['Explanations']) \
for x in ucs if x['CatID'] == catid))
from shutil import get_terminal_size
class Commands(cmd.Cmd): class Commands(cmd.Cmd):
ctx: InferenceContext
def __init__(self, completekey: str = "tab", stdin: IO[str] | None = None, def __init__(self, completekey: str = "tab", stdin: IO[str] | None = None,
stdout: IO[str] | None = None) -> None: stdout: IO[str] | None = None) -> None:
super().__init__(completekey, stdin, stdout) super().__init__(completekey, stdin, stdout)
self.file_list = [] self.file_list = []
self.model: Optional[SentenceTransformer] = None
self.embeddings: Optional[list] = None
self.catlist: list = [] self.catlist: list = []
self.rec_list = [] self.rec_list = []
self.history = [] self.history = []
# def lookup_cat(self, catid: str) -> tuple[str,str, str]:
def default(self, line): # return next( ((x['Category'], x['SubCategory'], x['Explanations']) \
try: # for x in self.catlist if x['CatID'] == catid))
rec = int(line)
if rec < len(self.rec_list):
print(f"Accept option {rec}")
ind = rec
self.history = [self.rec_list[ind]] + self.history[0:4]
self.onecmd("next")
else:
pass
except ValueError:
super().default(line)
def preloop(self) -> None: def preloop(self) -> None:
self.file_cursor = 0 self.file_cursor = 0
@@ -136,7 +75,18 @@ class Commands(cmd.Cmd):
return super().preloop() return super().preloop()
def precmd(self, line: str): def precmd(self, line: str):
return super().precmd(line) try:
rec = int(line)
if rec < len(self.rec_list):
pass
else:
pass
except ValueError:
pass
finally:
return super().precmd(line)
def postcmd(self, stop: bool, line: str) -> bool: def postcmd(self, stop: bool, line: str) -> bool:
if not stop: if not stop:
@@ -148,22 +98,27 @@ class Commands(cmd.Cmd):
self.prompt = f"(ucsinfer:{self.file_cursor}/{len(self.file_list)}) " self.prompt = f"(ucsinfer:{self.file_cursor}/{len(self.file_list)}) "
def setup_for_file(self): def setup_for_file(self):
file = self.file_list[self.file_cursor] if len(self.file_list) == 0:
desc, recs = recommend_category(file, self.embeddings, self.model) print(" >> NO FILES!")
self.onecmd('file') else:
print(f" >> {desc}") file = self.file_list[self.file_cursor]
self.print_recommendations(recs) desc, recs = recommend_category(self.ctx, file)
self.onecmd('file')
print(f" >> {desc}")
self.print_recommendations(recs)
def print_recommendations(self, top_recs): def print_recommendations(self, top_recs):
self.rec_list = []
cols, _ = get_terminal_size((80,20))
def print_one_rec(index, rec): def print_one_rec(index, rec):
cat, subcat, exp = lookup_cat(rec, self.catlist) cat, subcat, exp = self.ctx.lookup_category(rec)
line = f" [{index:2}] : {rec} - {cat} / {subcat} - {exp}" line = f" [{index:2}] {rec} - {cat} / {subcat} - {exp}"
if len(line) > 75: if len(line) > cols - 3:
line = line[0:75] + "..." line = line[0:cols - 3] + "..."
print(line) print(line)
self.rec_list = []
print("Suggested from description:") print("Suggested from description:")
for rec in top_recs: for rec in top_recs:
print_one_rec(len(self.rec_list), rec) print_one_rec(len(self.rec_list), rec)
@@ -175,6 +130,34 @@ class Commands(cmd.Cmd):
print_one_rec(len(self.rec_list), rec) print_one_rec(len(self.rec_list), rec)
self.rec_list.append(rec) self.rec_list.append(rec)
def do_about(self, line: str):
'Print information about recommendation NUMBER'
try:
picked = int(line)
if picked < len(self.rec_list):
cat, subcat, exp = \
self.lookup_cat(self.rec_list[picked])
print(f" CatID: {self.rec_list[picked]}")
print(f" Category: {cat}")
print(f" SubCategory: {subcat}")
print(f" Explanation: {exp}")
except ValueError:
print(f" *** Value \"{line}\" not recognized")
def do_use(self, line: str):
"""Apply recomendation NUMBER to the current file and advance to the
next one"""
try:
picked = int(line)
print(f" :: Using {self.rec_list[picked]}")
self.onecmd('next')
except ValueError:
print(" *** Value \"{line}\" not recognized")
def do_file(self, _): def do_file(self, _):
'Print info about the current file' 'Print info about the current file'
print("---") print("---")
@@ -185,10 +168,10 @@ class Commands(cmd.Cmd):
else: else:
print( " > No file") print( " > No file")
def do_lookup(self, args): # def do_lookup(self, args):
'print a list of UCS categories similar to the argument' # 'print a list of UCS categories similar to the argument'
self.rec_list = classify_text_ranked(args, self.embeddings, # self.rec_list = classify_text_ranked(args, self.embeddings,
self.model) # self.model)
def do_ls(self, _): def do_ls(self, _):
'Print list of all files in the buffer' 'Print list of all files in the buffer'
@@ -216,17 +199,15 @@ class Commands(cmd.Cmd):
def main(): def main():
cats = load_ucs_categories() # cats = load_ucs_categories()
print(f"Loaded UCS categories.", file=sys.stderr) print(f"Loaded UCS categories.", file=sys.stderr)
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2") model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
embeddings = load_embeddings(cats, model) # embeddings = load_embeddings(cats, model)
print(f"Loaded embeddings...", file=sys.stderr) print(f"Loaded embeddings...", file=sys.stderr)
com = Commands() com = Commands()
com.file_list = sys.argv[1:] com.file_list = sys.argv[1:]
com.model = model com.ctx = InferenceContext(model=model)
com.embeddings = embeddings
com.catlist = cats
com.cmdloop() com.cmdloop()

113
ucsinfer/inference.py Normal file
View File

@@ -0,0 +1,113 @@
import os.path
import json
import pickle
from functools import cached_property
from typing import NamedTuple
import numpy as np
import platformdirs
from sentence_transformers import SentenceTransformer
def classify_text_ranked(text, embeddings_list, model, limit=5):
text_embedding = model.encode(text, convert_to_numpy=True)
embeddings = np.array([info['Embedding'] for info in embeddings_list])
sim = model.similarity(text_embedding, embeddings)[0]
maxinds = np.argsort(sim)[-limit:]
return [embeddings_list[x]['CatID'] for x in reversed(maxinds)]
class Ucs(NamedTuple):
catid: str
category: str
subcategory: str
explanations: str
synonymns: list[str]
@classmethod
def from_dict(cls, d: dict):
return Ucs(catid=d['CatID'], category=d['Category'],
subcategory=d['SubCategory'],
explanations=d['Explanations'], synonymns=d['Synonyms'])
class InferenceContext:
"""
Maintains caches and resources for UCS category inference.
"""
model: SentenceTransformer
def __init__(self, model: SentenceTransformer):
self.model = model
@cached_property
def catlist(self) -> list[Ucs]:
FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
cats = []
ucs_defs = os.path.join(FILE_ROOT_DIR, 'ucs-community', 'json',
'en.json')
with open(ucs_defs, 'r') as f:
cats = json.load(f)
return [Ucs.from_dict(cat) for cat in cats]
@cached_property
def embeddings(self) -> list[dict]:
cache_dir = platformdirs.user_cache_dir("ucsinfer", 'Squad 51')
embedding_cache = os.path.join(cache_dir, f"ucs_embedding.cache")
embeddings = []
if os.path.exists(embedding_cache):
with open(embedding_cache, 'rb') as f:
embeddings = pickle.load(f)
else:
print("Calculating embeddings...")
for cat_defn in self.catlist:
embeddings += [{
'CatID': cat_defn.catid,
'Embedding': self._encode_category(cat_defn)
}]
os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
with open(embedding_cache, 'wb') as g:
pickle.dump(embeddings, g)
return embeddings
def _encode_category(self, cat: Ucs) -> np.ndarray:
sentence_components = [cat.explanations,
cat.category,
cat.subcategory
]
sentence_components += cat.synonymns
sentence = ", ".join(sentence_components)
return self.model.encode(sentence, convert_to_numpy=True)
def classify_text_ranked(self, text: str, limit: int = 5) -> list[str]:
"""
Sort all UCS catids according to their similarity to `text`
"""
text_embedding = self.model.encode(text, convert_to_numpy=True)
embeddings = np.array([info['Embedding'] for info in self.embeddings])
sim = self.model.similarity(text_embedding, embeddings)[0]
maxinds = np.argsort(sim)[-limit:]
return [self.embeddings[x]['CatID'] for x in reversed(maxinds)]
def lookup_category(self, catid) -> tuple[str, str, str]:
"""
Get the category, subcategory and explanations phrase for a `catid`
:raises: StopIterator if CatId is not on the schedule
"""
i = (
(x.category, x.subcategory, x.explanations) \
for x in self.catlist if x.catid == catid
)
return next(i)