Split inference code out to inference.py
This commit is contained in:
@@ -5,11 +5,15 @@ import sys
|
||||
import subprocess
|
||||
import cmd
|
||||
|
||||
from typing import Optional, Tuple, IO
|
||||
from typing import Optional, IO
|
||||
|
||||
import numpy as np
|
||||
import platformdirs
|
||||
import tqdm
|
||||
# from .inference import classify_text_ranked
|
||||
|
||||
from .inference import InferenceContext
|
||||
|
||||
# import numpy as np
|
||||
# import platformdirs
|
||||
# import tqdm
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -19,48 +23,6 @@ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
||||
import warnings
|
||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||
|
||||
def load_ucs_categories() -> list:
|
||||
cats = []
|
||||
ucs_defs = os.path.join(ROOT_DIR, 'ucs-community', 'json', 'en.json')
|
||||
|
||||
with open(ucs_defs, 'r') as f:
|
||||
cats = json.load(f)
|
||||
|
||||
return cats
|
||||
|
||||
def encoode_category(cat_defn: dict, model: SentenceTransformer) -> np.ndarray:
|
||||
sentence_components = [cat_defn['Explanations'],
|
||||
cat_defn['Category'],
|
||||
cat_defn['SubCategory']
|
||||
]
|
||||
sentence_components += cat_defn['Synonyms']
|
||||
sentence = ", ".join(sentence_components)
|
||||
return model.encode(sentence, convert_to_numpy=True)
|
||||
|
||||
def load_embeddings(ucs: list, model) -> list:
|
||||
cache_dir = platformdirs.user_cache_dir('ucsinfer', 'Squad 51')
|
||||
embedding_cache = os.path.join(cache_dir, f"ucs_embedding.cache")
|
||||
embeddings = []
|
||||
|
||||
if os.path.exists(embedding_cache):
|
||||
with open(embedding_cache, 'rb') as f:
|
||||
embeddings = pickle.load(f)
|
||||
|
||||
else:
|
||||
print("Calculating embeddings...")
|
||||
|
||||
for cat_defn in tqdm.tqdm(ucs):
|
||||
embeddings += [{
|
||||
'CatID': cat_defn['CatID'],
|
||||
'Embedding': encoode_category(cat_defn, model)
|
||||
}]
|
||||
|
||||
os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
|
||||
with open(embedding_cache, 'wb') as g:
|
||||
pickle.dump(embeddings, g)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def description(path: str) -> Optional[str]:
|
||||
result = subprocess.run(['ffprobe', '-show_format', '-of',
|
||||
@@ -78,16 +40,7 @@ def description(path: str) -> Optional[str]:
|
||||
if tags:
|
||||
return tags.get("comment", None)
|
||||
|
||||
|
||||
def classify_text_ranked(text, embeddings_list, model, limit=5):
|
||||
text_embedding = model.encode(text, convert_to_numpy=True)
|
||||
embeddings = np.array([info['Embedding'] for info in embeddings_list])
|
||||
sim = model.similarity(text_embedding, embeddings)[0]
|
||||
maxinds = np.argsort(sim)[-limit:]
|
||||
return [embeddings_list[x]['CatID'] for x in reversed(maxinds)]
|
||||
|
||||
|
||||
def recommend_category(path, embeddings, model) -> Tuple[str, list]:
|
||||
def recommend_category(ctx, path) -> tuple[str, list]:
|
||||
"""
|
||||
Get a text description of the file at `path` and a list of UCS cat IDs
|
||||
"""
|
||||
@@ -95,39 +48,25 @@ def recommend_category(path, embeddings, model) -> Tuple[str, list]:
|
||||
if desc is None:
|
||||
desc = os.path.basename(path)
|
||||
|
||||
return desc, classify_text_ranked(desc, embeddings, model)
|
||||
return desc, ctx.classify_text_ranked(desc)
|
||||
|
||||
def lookup_cat(catid: str, ucs: list) -> tuple[str,str, str]:
|
||||
return next( ((x['Category'], x['SubCategory'], x['Explanations']) \
|
||||
for x in ucs if x['CatID'] == catid))
|
||||
|
||||
from shutil import get_terminal_size
|
||||
|
||||
class Commands(cmd.Cmd):
|
||||
ctx: InferenceContext
|
||||
|
||||
def __init__(self, completekey: str = "tab", stdin: IO[str] | None = None,
|
||||
stdout: IO[str] | None = None) -> None:
|
||||
super().__init__(completekey, stdin, stdout)
|
||||
self.file_list = []
|
||||
self.model: Optional[SentenceTransformer] = None
|
||||
self.embeddings: Optional[list] = None
|
||||
self.catlist: list = []
|
||||
self.rec_list = []
|
||||
self.history = []
|
||||
|
||||
|
||||
def default(self, line):
|
||||
try:
|
||||
rec = int(line)
|
||||
if rec < len(self.rec_list):
|
||||
print(f"Accept option {rec}")
|
||||
ind = rec
|
||||
self.history = [self.rec_list[ind]] + self.history[0:4]
|
||||
self.onecmd("next")
|
||||
else:
|
||||
pass
|
||||
|
||||
except ValueError:
|
||||
super().default(line)
|
||||
# def lookup_cat(self, catid: str) -> tuple[str,str, str]:
|
||||
# return next( ((x['Category'], x['SubCategory'], x['Explanations']) \
|
||||
# for x in self.catlist if x['CatID'] == catid))
|
||||
|
||||
def preloop(self) -> None:
|
||||
self.file_cursor = 0
|
||||
@@ -136,6 +75,17 @@ class Commands(cmd.Cmd):
|
||||
return super().preloop()
|
||||
|
||||
def precmd(self, line: str):
|
||||
try:
|
||||
rec = int(line)
|
||||
if rec < len(self.rec_list):
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
finally:
|
||||
return super().precmd(line)
|
||||
|
||||
def postcmd(self, stop: bool, line: str) -> bool:
|
||||
@@ -148,22 +98,27 @@ class Commands(cmd.Cmd):
|
||||
self.prompt = f"(ucsinfer:{self.file_cursor}/{len(self.file_list)}) "
|
||||
|
||||
def setup_for_file(self):
|
||||
if len(self.file_list) == 0:
|
||||
print(" >> NO FILES!")
|
||||
else:
|
||||
file = self.file_list[self.file_cursor]
|
||||
desc, recs = recommend_category(file, self.embeddings, self.model)
|
||||
desc, recs = recommend_category(self.ctx, file)
|
||||
self.onecmd('file')
|
||||
print(f" >> {desc}")
|
||||
self.print_recommendations(recs)
|
||||
|
||||
def print_recommendations(self, top_recs):
|
||||
self.rec_list = []
|
||||
|
||||
cols, _ = get_terminal_size((80,20))
|
||||
|
||||
def print_one_rec(index, rec):
|
||||
cat, subcat, exp = lookup_cat(rec, self.catlist)
|
||||
line = f" [{index:2}] : {rec} - {cat} / {subcat} - {exp}"
|
||||
if len(line) > 75:
|
||||
line = line[0:75] + "..."
|
||||
cat, subcat, exp = self.ctx.lookup_category(rec)
|
||||
line = f" [{index:2}] {rec} - {cat} / {subcat} - {exp}"
|
||||
if len(line) > cols - 3:
|
||||
line = line[0:cols - 3] + "..."
|
||||
print(line)
|
||||
|
||||
self.rec_list = []
|
||||
print("Suggested from description:")
|
||||
for rec in top_recs:
|
||||
print_one_rec(len(self.rec_list), rec)
|
||||
@@ -175,6 +130,34 @@ class Commands(cmd.Cmd):
|
||||
print_one_rec(len(self.rec_list), rec)
|
||||
self.rec_list.append(rec)
|
||||
|
||||
def do_about(self, line: str):
|
||||
'Print information about recommendation NUMBER'
|
||||
try:
|
||||
picked = int(line)
|
||||
if picked < len(self.rec_list):
|
||||
cat, subcat, exp = \
|
||||
self.lookup_cat(self.rec_list[picked])
|
||||
|
||||
print(f" CatID: {self.rec_list[picked]}")
|
||||
print(f" Category: {cat}")
|
||||
print(f" SubCategory: {subcat}")
|
||||
print(f" Explanation: {exp}")
|
||||
|
||||
except ValueError:
|
||||
print(f" *** Value \"{line}\" not recognized")
|
||||
|
||||
|
||||
def do_use(self, line: str):
|
||||
"""Apply recomendation NUMBER to the current file and advance to the
|
||||
next one"""
|
||||
|
||||
try:
|
||||
picked = int(line)
|
||||
print(f" :: Using {self.rec_list[picked]}")
|
||||
self.onecmd('next')
|
||||
except ValueError:
|
||||
print(" *** Value \"{line}\" not recognized")
|
||||
|
||||
def do_file(self, _):
|
||||
'Print info about the current file'
|
||||
print("---")
|
||||
@@ -185,10 +168,10 @@ class Commands(cmd.Cmd):
|
||||
else:
|
||||
print( " > No file")
|
||||
|
||||
def do_lookup(self, args):
|
||||
'print a list of UCS categories similar to the argument'
|
||||
self.rec_list = classify_text_ranked(args, self.embeddings,
|
||||
self.model)
|
||||
# def do_lookup(self, args):
|
||||
# 'print a list of UCS categories similar to the argument'
|
||||
# self.rec_list = classify_text_ranked(args, self.embeddings,
|
||||
# self.model)
|
||||
|
||||
def do_ls(self, _):
|
||||
'Print list of all files in the buffer'
|
||||
@@ -216,17 +199,15 @@ class Commands(cmd.Cmd):
|
||||
|
||||
|
||||
def main():
|
||||
cats = load_ucs_categories()
|
||||
# cats = load_ucs_categories()
|
||||
print(f"Loaded UCS categories.", file=sys.stderr)
|
||||
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
|
||||
embeddings = load_embeddings(cats, model)
|
||||
# embeddings = load_embeddings(cats, model)
|
||||
print(f"Loaded embeddings...", file=sys.stderr)
|
||||
|
||||
com = Commands()
|
||||
com.file_list = sys.argv[1:]
|
||||
com.model = model
|
||||
com.embeddings = embeddings
|
||||
com.catlist = cats
|
||||
com.ctx = InferenceContext(model=model)
|
||||
|
||||
com.cmdloop()
|
||||
|
||||
|
113
ucsinfer/inference.py
Normal file
113
ucsinfer/inference.py
Normal file
@@ -0,0 +1,113 @@
|
||||
|
||||
import os.path
|
||||
import json
|
||||
import pickle
|
||||
from functools import cached_property
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import platformdirs
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
def classify_text_ranked(text, embeddings_list, model, limit=5):
|
||||
text_embedding = model.encode(text, convert_to_numpy=True)
|
||||
embeddings = np.array([info['Embedding'] for info in embeddings_list])
|
||||
sim = model.similarity(text_embedding, embeddings)[0]
|
||||
maxinds = np.argsort(sim)[-limit:]
|
||||
return [embeddings_list[x]['CatID'] for x in reversed(maxinds)]
|
||||
|
||||
|
||||
class Ucs(NamedTuple):
|
||||
catid: str
|
||||
category: str
|
||||
subcategory: str
|
||||
explanations: str
|
||||
synonymns: list[str]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict):
|
||||
return Ucs(catid=d['CatID'], category=d['Category'],
|
||||
subcategory=d['SubCategory'],
|
||||
explanations=d['Explanations'], synonymns=d['Synonyms'])
|
||||
|
||||
class InferenceContext:
|
||||
"""
|
||||
Maintains caches and resources for UCS category inference.
|
||||
"""
|
||||
|
||||
model: SentenceTransformer
|
||||
|
||||
def __init__(self, model: SentenceTransformer):
|
||||
self.model = model
|
||||
|
||||
@cached_property
|
||||
def catlist(self) -> list[Ucs]:
|
||||
FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
cats = []
|
||||
ucs_defs = os.path.join(FILE_ROOT_DIR, 'ucs-community', 'json',
|
||||
'en.json')
|
||||
|
||||
with open(ucs_defs, 'r') as f:
|
||||
cats = json.load(f)
|
||||
|
||||
return [Ucs.from_dict(cat) for cat in cats]
|
||||
|
||||
@cached_property
|
||||
def embeddings(self) -> list[dict]:
|
||||
cache_dir = platformdirs.user_cache_dir("ucsinfer", 'Squad 51')
|
||||
embedding_cache = os.path.join(cache_dir, f"ucs_embedding.cache")
|
||||
embeddings = []
|
||||
|
||||
if os.path.exists(embedding_cache):
|
||||
with open(embedding_cache, 'rb') as f:
|
||||
embeddings = pickle.load(f)
|
||||
|
||||
else:
|
||||
print("Calculating embeddings...")
|
||||
|
||||
for cat_defn in self.catlist:
|
||||
embeddings += [{
|
||||
'CatID': cat_defn.catid,
|
||||
'Embedding': self._encode_category(cat_defn)
|
||||
}]
|
||||
|
||||
os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
|
||||
with open(embedding_cache, 'wb') as g:
|
||||
pickle.dump(embeddings, g)
|
||||
|
||||
return embeddings
|
||||
|
||||
def _encode_category(self, cat: Ucs) -> np.ndarray:
|
||||
sentence_components = [cat.explanations,
|
||||
cat.category,
|
||||
cat.subcategory
|
||||
]
|
||||
sentence_components += cat.synonymns
|
||||
sentence = ", ".join(sentence_components)
|
||||
return self.model.encode(sentence, convert_to_numpy=True)
|
||||
|
||||
def classify_text_ranked(self, text: str, limit: int = 5) -> list[str]:
|
||||
"""
|
||||
Sort all UCS catids according to their similarity to `text`
|
||||
"""
|
||||
text_embedding = self.model.encode(text, convert_to_numpy=True)
|
||||
embeddings = np.array([info['Embedding'] for info in self.embeddings])
|
||||
sim = self.model.similarity(text_embedding, embeddings)[0]
|
||||
maxinds = np.argsort(sim)[-limit:]
|
||||
return [self.embeddings[x]['CatID'] for x in reversed(maxinds)]
|
||||
|
||||
def lookup_category(self, catid) -> tuple[str, str, str]:
|
||||
"""
|
||||
Get the category, subcategory and explanations phrase for a `catid`
|
||||
|
||||
:raises: StopIterator if CatId is not on the schedule
|
||||
"""
|
||||
i = (
|
||||
(x.category, x.subcategory, x.explanations) \
|
||||
for x in self.catlist if x.catid == catid
|
||||
)
|
||||
|
||||
return next(i)
|
||||
|
Reference in New Issue
Block a user