Split inference code out to inference.py
This commit is contained in:
@@ -5,11 +5,15 @@ import sys
|
|||||||
import subprocess
|
import subprocess
|
||||||
import cmd
|
import cmd
|
||||||
|
|
||||||
from typing import Optional, Tuple, IO
|
from typing import Optional, IO
|
||||||
|
|
||||||
import numpy as np
|
# from .inference import classify_text_ranked
|
||||||
import platformdirs
|
|
||||||
import tqdm
|
from .inference import InferenceContext
|
||||||
|
|
||||||
|
# import numpy as np
|
||||||
|
# import platformdirs
|
||||||
|
# import tqdm
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
@@ -19,48 +23,6 @@ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
|||||||
import warnings
|
import warnings
|
||||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||||
|
|
||||||
def load_ucs_categories() -> list:
|
|
||||||
cats = []
|
|
||||||
ucs_defs = os.path.join(ROOT_DIR, 'ucs-community', 'json', 'en.json')
|
|
||||||
|
|
||||||
with open(ucs_defs, 'r') as f:
|
|
||||||
cats = json.load(f)
|
|
||||||
|
|
||||||
return cats
|
|
||||||
|
|
||||||
def encoode_category(cat_defn: dict, model: SentenceTransformer) -> np.ndarray:
|
|
||||||
sentence_components = [cat_defn['Explanations'],
|
|
||||||
cat_defn['Category'],
|
|
||||||
cat_defn['SubCategory']
|
|
||||||
]
|
|
||||||
sentence_components += cat_defn['Synonyms']
|
|
||||||
sentence = ", ".join(sentence_components)
|
|
||||||
return model.encode(sentence, convert_to_numpy=True)
|
|
||||||
|
|
||||||
def load_embeddings(ucs: list, model) -> list:
|
|
||||||
cache_dir = platformdirs.user_cache_dir('ucsinfer', 'Squad 51')
|
|
||||||
embedding_cache = os.path.join(cache_dir, f"ucs_embedding.cache")
|
|
||||||
embeddings = []
|
|
||||||
|
|
||||||
if os.path.exists(embedding_cache):
|
|
||||||
with open(embedding_cache, 'rb') as f:
|
|
||||||
embeddings = pickle.load(f)
|
|
||||||
|
|
||||||
else:
|
|
||||||
print("Calculating embeddings...")
|
|
||||||
|
|
||||||
for cat_defn in tqdm.tqdm(ucs):
|
|
||||||
embeddings += [{
|
|
||||||
'CatID': cat_defn['CatID'],
|
|
||||||
'Embedding': encoode_category(cat_defn, model)
|
|
||||||
}]
|
|
||||||
|
|
||||||
os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
|
|
||||||
with open(embedding_cache, 'wb') as g:
|
|
||||||
pickle.dump(embeddings, g)
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def description(path: str) -> Optional[str]:
|
def description(path: str) -> Optional[str]:
|
||||||
result = subprocess.run(['ffprobe', '-show_format', '-of',
|
result = subprocess.run(['ffprobe', '-show_format', '-of',
|
||||||
@@ -78,16 +40,7 @@ def description(path: str) -> Optional[str]:
|
|||||||
if tags:
|
if tags:
|
||||||
return tags.get("comment", None)
|
return tags.get("comment", None)
|
||||||
|
|
||||||
|
def recommend_category(ctx, path) -> tuple[str, list]:
|
||||||
def classify_text_ranked(text, embeddings_list, model, limit=5):
|
|
||||||
text_embedding = model.encode(text, convert_to_numpy=True)
|
|
||||||
embeddings = np.array([info['Embedding'] for info in embeddings_list])
|
|
||||||
sim = model.similarity(text_embedding, embeddings)[0]
|
|
||||||
maxinds = np.argsort(sim)[-limit:]
|
|
||||||
return [embeddings_list[x]['CatID'] for x in reversed(maxinds)]
|
|
||||||
|
|
||||||
|
|
||||||
def recommend_category(path, embeddings, model) -> Tuple[str, list]:
|
|
||||||
"""
|
"""
|
||||||
Get a text description of the file at `path` and a list of UCS cat IDs
|
Get a text description of the file at `path` and a list of UCS cat IDs
|
||||||
"""
|
"""
|
||||||
@@ -95,39 +48,25 @@ def recommend_category(path, embeddings, model) -> Tuple[str, list]:
|
|||||||
if desc is None:
|
if desc is None:
|
||||||
desc = os.path.basename(path)
|
desc = os.path.basename(path)
|
||||||
|
|
||||||
return desc, classify_text_ranked(desc, embeddings, model)
|
return desc, ctx.classify_text_ranked(desc)
|
||||||
|
|
||||||
def lookup_cat(catid: str, ucs: list) -> tuple[str,str, str]:
|
|
||||||
return next( ((x['Category'], x['SubCategory'], x['Explanations']) \
|
|
||||||
for x in ucs if x['CatID'] == catid))
|
|
||||||
|
|
||||||
|
from shutil import get_terminal_size
|
||||||
|
|
||||||
class Commands(cmd.Cmd):
|
class Commands(cmd.Cmd):
|
||||||
|
ctx: InferenceContext
|
||||||
|
|
||||||
def __init__(self, completekey: str = "tab", stdin: IO[str] | None = None,
|
def __init__(self, completekey: str = "tab", stdin: IO[str] | None = None,
|
||||||
stdout: IO[str] | None = None) -> None:
|
stdout: IO[str] | None = None) -> None:
|
||||||
super().__init__(completekey, stdin, stdout)
|
super().__init__(completekey, stdin, stdout)
|
||||||
self.file_list = []
|
self.file_list = []
|
||||||
self.model: Optional[SentenceTransformer] = None
|
|
||||||
self.embeddings: Optional[list] = None
|
|
||||||
self.catlist: list = []
|
self.catlist: list = []
|
||||||
self.rec_list = []
|
self.rec_list = []
|
||||||
self.history = []
|
self.history = []
|
||||||
|
|
||||||
|
# def lookup_cat(self, catid: str) -> tuple[str,str, str]:
|
||||||
def default(self, line):
|
# return next( ((x['Category'], x['SubCategory'], x['Explanations']) \
|
||||||
try:
|
# for x in self.catlist if x['CatID'] == catid))
|
||||||
rec = int(line)
|
|
||||||
if rec < len(self.rec_list):
|
|
||||||
print(f"Accept option {rec}")
|
|
||||||
ind = rec
|
|
||||||
self.history = [self.rec_list[ind]] + self.history[0:4]
|
|
||||||
self.onecmd("next")
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
except ValueError:
|
|
||||||
super().default(line)
|
|
||||||
|
|
||||||
def preloop(self) -> None:
|
def preloop(self) -> None:
|
||||||
self.file_cursor = 0
|
self.file_cursor = 0
|
||||||
@@ -136,7 +75,18 @@ class Commands(cmd.Cmd):
|
|||||||
return super().preloop()
|
return super().preloop()
|
||||||
|
|
||||||
def precmd(self, line: str):
|
def precmd(self, line: str):
|
||||||
return super().precmd(line)
|
try:
|
||||||
|
rec = int(line)
|
||||||
|
if rec < len(self.rec_list):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
finally:
|
||||||
|
return super().precmd(line)
|
||||||
|
|
||||||
def postcmd(self, stop: bool, line: str) -> bool:
|
def postcmd(self, stop: bool, line: str) -> bool:
|
||||||
if not stop:
|
if not stop:
|
||||||
@@ -148,22 +98,27 @@ class Commands(cmd.Cmd):
|
|||||||
self.prompt = f"(ucsinfer:{self.file_cursor}/{len(self.file_list)}) "
|
self.prompt = f"(ucsinfer:{self.file_cursor}/{len(self.file_list)}) "
|
||||||
|
|
||||||
def setup_for_file(self):
|
def setup_for_file(self):
|
||||||
file = self.file_list[self.file_cursor]
|
if len(self.file_list) == 0:
|
||||||
desc, recs = recommend_category(file, self.embeddings, self.model)
|
print(" >> NO FILES!")
|
||||||
self.onecmd('file')
|
else:
|
||||||
print(f" >> {desc}")
|
file = self.file_list[self.file_cursor]
|
||||||
self.print_recommendations(recs)
|
desc, recs = recommend_category(self.ctx, file)
|
||||||
|
self.onecmd('file')
|
||||||
|
print(f" >> {desc}")
|
||||||
|
self.print_recommendations(recs)
|
||||||
|
|
||||||
def print_recommendations(self, top_recs):
|
def print_recommendations(self, top_recs):
|
||||||
|
self.rec_list = []
|
||||||
|
|
||||||
|
cols, _ = get_terminal_size((80,20))
|
||||||
|
|
||||||
def print_one_rec(index, rec):
|
def print_one_rec(index, rec):
|
||||||
cat, subcat, exp = lookup_cat(rec, self.catlist)
|
cat, subcat, exp = self.ctx.lookup_category(rec)
|
||||||
line = f" [{index:2}] : {rec} - {cat} / {subcat} - {exp}"
|
line = f" [{index:2}] {rec} - {cat} / {subcat} - {exp}"
|
||||||
if len(line) > 75:
|
if len(line) > cols - 3:
|
||||||
line = line[0:75] + "..."
|
line = line[0:cols - 3] + "..."
|
||||||
print(line)
|
print(line)
|
||||||
|
|
||||||
self.rec_list = []
|
|
||||||
print("Suggested from description:")
|
print("Suggested from description:")
|
||||||
for rec in top_recs:
|
for rec in top_recs:
|
||||||
print_one_rec(len(self.rec_list), rec)
|
print_one_rec(len(self.rec_list), rec)
|
||||||
@@ -175,6 +130,34 @@ class Commands(cmd.Cmd):
|
|||||||
print_one_rec(len(self.rec_list), rec)
|
print_one_rec(len(self.rec_list), rec)
|
||||||
self.rec_list.append(rec)
|
self.rec_list.append(rec)
|
||||||
|
|
||||||
|
def do_about(self, line: str):
|
||||||
|
'Print information about recommendation NUMBER'
|
||||||
|
try:
|
||||||
|
picked = int(line)
|
||||||
|
if picked < len(self.rec_list):
|
||||||
|
cat, subcat, exp = \
|
||||||
|
self.lookup_cat(self.rec_list[picked])
|
||||||
|
|
||||||
|
print(f" CatID: {self.rec_list[picked]}")
|
||||||
|
print(f" Category: {cat}")
|
||||||
|
print(f" SubCategory: {subcat}")
|
||||||
|
print(f" Explanation: {exp}")
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
print(f" *** Value \"{line}\" not recognized")
|
||||||
|
|
||||||
|
|
||||||
|
def do_use(self, line: str):
|
||||||
|
"""Apply recomendation NUMBER to the current file and advance to the
|
||||||
|
next one"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
picked = int(line)
|
||||||
|
print(f" :: Using {self.rec_list[picked]}")
|
||||||
|
self.onecmd('next')
|
||||||
|
except ValueError:
|
||||||
|
print(" *** Value \"{line}\" not recognized")
|
||||||
|
|
||||||
def do_file(self, _):
|
def do_file(self, _):
|
||||||
'Print info about the current file'
|
'Print info about the current file'
|
||||||
print("---")
|
print("---")
|
||||||
@@ -185,10 +168,10 @@ class Commands(cmd.Cmd):
|
|||||||
else:
|
else:
|
||||||
print( " > No file")
|
print( " > No file")
|
||||||
|
|
||||||
def do_lookup(self, args):
|
# def do_lookup(self, args):
|
||||||
'print a list of UCS categories similar to the argument'
|
# 'print a list of UCS categories similar to the argument'
|
||||||
self.rec_list = classify_text_ranked(args, self.embeddings,
|
# self.rec_list = classify_text_ranked(args, self.embeddings,
|
||||||
self.model)
|
# self.model)
|
||||||
|
|
||||||
def do_ls(self, _):
|
def do_ls(self, _):
|
||||||
'Print list of all files in the buffer'
|
'Print list of all files in the buffer'
|
||||||
@@ -216,17 +199,15 @@ class Commands(cmd.Cmd):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
cats = load_ucs_categories()
|
# cats = load_ucs_categories()
|
||||||
print(f"Loaded UCS categories.", file=sys.stderr)
|
print(f"Loaded UCS categories.", file=sys.stderr)
|
||||||
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
|
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
|
||||||
embeddings = load_embeddings(cats, model)
|
# embeddings = load_embeddings(cats, model)
|
||||||
print(f"Loaded embeddings...", file=sys.stderr)
|
print(f"Loaded embeddings...", file=sys.stderr)
|
||||||
|
|
||||||
com = Commands()
|
com = Commands()
|
||||||
com.file_list = sys.argv[1:]
|
com.file_list = sys.argv[1:]
|
||||||
com.model = model
|
com.ctx = InferenceContext(model=model)
|
||||||
com.embeddings = embeddings
|
|
||||||
com.catlist = cats
|
|
||||||
|
|
||||||
com.cmdloop()
|
com.cmdloop()
|
||||||
|
|
||||||
|
113
ucsinfer/inference.py
Normal file
113
ucsinfer/inference.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
|
||||||
|
import os.path
|
||||||
|
import json
|
||||||
|
import pickle
|
||||||
|
from functools import cached_property
|
||||||
|
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import platformdirs
|
||||||
|
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
def classify_text_ranked(text, embeddings_list, model, limit=5):
|
||||||
|
text_embedding = model.encode(text, convert_to_numpy=True)
|
||||||
|
embeddings = np.array([info['Embedding'] for info in embeddings_list])
|
||||||
|
sim = model.similarity(text_embedding, embeddings)[0]
|
||||||
|
maxinds = np.argsort(sim)[-limit:]
|
||||||
|
return [embeddings_list[x]['CatID'] for x in reversed(maxinds)]
|
||||||
|
|
||||||
|
|
||||||
|
class Ucs(NamedTuple):
|
||||||
|
catid: str
|
||||||
|
category: str
|
||||||
|
subcategory: str
|
||||||
|
explanations: str
|
||||||
|
synonymns: list[str]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, d: dict):
|
||||||
|
return Ucs(catid=d['CatID'], category=d['Category'],
|
||||||
|
subcategory=d['SubCategory'],
|
||||||
|
explanations=d['Explanations'], synonymns=d['Synonyms'])
|
||||||
|
|
||||||
|
class InferenceContext:
|
||||||
|
"""
|
||||||
|
Maintains caches and resources for UCS category inference.
|
||||||
|
"""
|
||||||
|
|
||||||
|
model: SentenceTransformer
|
||||||
|
|
||||||
|
def __init__(self, model: SentenceTransformer):
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def catlist(self) -> list[Ucs]:
|
||||||
|
FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
cats = []
|
||||||
|
ucs_defs = os.path.join(FILE_ROOT_DIR, 'ucs-community', 'json',
|
||||||
|
'en.json')
|
||||||
|
|
||||||
|
with open(ucs_defs, 'r') as f:
|
||||||
|
cats = json.load(f)
|
||||||
|
|
||||||
|
return [Ucs.from_dict(cat) for cat in cats]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def embeddings(self) -> list[dict]:
|
||||||
|
cache_dir = platformdirs.user_cache_dir("ucsinfer", 'Squad 51')
|
||||||
|
embedding_cache = os.path.join(cache_dir, f"ucs_embedding.cache")
|
||||||
|
embeddings = []
|
||||||
|
|
||||||
|
if os.path.exists(embedding_cache):
|
||||||
|
with open(embedding_cache, 'rb') as f:
|
||||||
|
embeddings = pickle.load(f)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("Calculating embeddings...")
|
||||||
|
|
||||||
|
for cat_defn in self.catlist:
|
||||||
|
embeddings += [{
|
||||||
|
'CatID': cat_defn.catid,
|
||||||
|
'Embedding': self._encode_category(cat_defn)
|
||||||
|
}]
|
||||||
|
|
||||||
|
os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
|
||||||
|
with open(embedding_cache, 'wb') as g:
|
||||||
|
pickle.dump(embeddings, g)
|
||||||
|
|
||||||
|
return embeddings
|
||||||
|
|
||||||
|
def _encode_category(self, cat: Ucs) -> np.ndarray:
|
||||||
|
sentence_components = [cat.explanations,
|
||||||
|
cat.category,
|
||||||
|
cat.subcategory
|
||||||
|
]
|
||||||
|
sentence_components += cat.synonymns
|
||||||
|
sentence = ", ".join(sentence_components)
|
||||||
|
return self.model.encode(sentence, convert_to_numpy=True)
|
||||||
|
|
||||||
|
def classify_text_ranked(self, text: str, limit: int = 5) -> list[str]:
|
||||||
|
"""
|
||||||
|
Sort all UCS catids according to their similarity to `text`
|
||||||
|
"""
|
||||||
|
text_embedding = self.model.encode(text, convert_to_numpy=True)
|
||||||
|
embeddings = np.array([info['Embedding'] for info in self.embeddings])
|
||||||
|
sim = self.model.similarity(text_embedding, embeddings)[0]
|
||||||
|
maxinds = np.argsort(sim)[-limit:]
|
||||||
|
return [self.embeddings[x]['CatID'] for x in reversed(maxinds)]
|
||||||
|
|
||||||
|
def lookup_category(self, catid) -> tuple[str, str, str]:
|
||||||
|
"""
|
||||||
|
Get the category, subcategory and explanations phrase for a `catid`
|
||||||
|
|
||||||
|
:raises: StopIterator if CatId is not on the schedule
|
||||||
|
"""
|
||||||
|
i = (
|
||||||
|
(x.category, x.subcategory, x.explanations) \
|
||||||
|
for x in self.catlist if x.catid == catid
|
||||||
|
)
|
||||||
|
|
||||||
|
return next(i)
|
||||||
|
|
Reference in New Issue
Block a user