diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py index 9d52d59..dbc8fdb 100644 --- a/ucsinfer/__main__.py +++ b/ucsinfer/__main__.py @@ -5,11 +5,15 @@ import sys import subprocess import cmd -from typing import Optional, Tuple, IO +from typing import Optional, IO -import numpy as np -import platformdirs -import tqdm +# from .inference import classify_text_ranked + +from .inference import InferenceContext + +# import numpy as np +# import platformdirs +# import tqdm from sentence_transformers import SentenceTransformer ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -19,48 +23,6 @@ os.environ['TOKENIZERS_PARALLELISM'] = 'false' import warnings warnings.simplefilter(action='ignore', category=FutureWarning) -def load_ucs_categories() -> list: - cats = [] - ucs_defs = os.path.join(ROOT_DIR, 'ucs-community', 'json', 'en.json') - - with open(ucs_defs, 'r') as f: - cats = json.load(f) - - return cats - -def encoode_category(cat_defn: dict, model: SentenceTransformer) -> np.ndarray: - sentence_components = [cat_defn['Explanations'], - cat_defn['Category'], - cat_defn['SubCategory'] - ] - sentence_components += cat_defn['Synonyms'] - sentence = ", ".join(sentence_components) - return model.encode(sentence, convert_to_numpy=True) - -def load_embeddings(ucs: list, model) -> list: - cache_dir = platformdirs.user_cache_dir('ucsinfer', 'Squad 51') - embedding_cache = os.path.join(cache_dir, f"ucs_embedding.cache") - embeddings = [] - - if os.path.exists(embedding_cache): - with open(embedding_cache, 'rb') as f: - embeddings = pickle.load(f) - - else: - print("Calculating embeddings...") - - for cat_defn in tqdm.tqdm(ucs): - embeddings += [{ - 'CatID': cat_defn['CatID'], - 'Embedding': encoode_category(cat_defn, model) - }] - - os.makedirs(os.path.dirname(embedding_cache), exist_ok=True) - with open(embedding_cache, 'wb') as g: - pickle.dump(embeddings, g) - - return embeddings - def description(path: str) -> Optional[str]: result = subprocess.run(['ffprobe', '-show_format', '-of', @@ -78,16 +40,7 @@ def description(path: str) -> Optional[str]: if tags: return tags.get("comment", None) - -def classify_text_ranked(text, embeddings_list, model, limit=5): - text_embedding = model.encode(text, convert_to_numpy=True) - embeddings = np.array([info['Embedding'] for info in embeddings_list]) - sim = model.similarity(text_embedding, embeddings)[0] - maxinds = np.argsort(sim)[-limit:] - return [embeddings_list[x]['CatID'] for x in reversed(maxinds)] - - -def recommend_category(path, embeddings, model) -> Tuple[str, list]: +def recommend_category(ctx, path) -> tuple[str, list]: """ Get a text description of the file at `path` and a list of UCS cat IDs """ @@ -95,39 +48,25 @@ def recommend_category(path, embeddings, model) -> Tuple[str, list]: if desc is None: desc = os.path.basename(path) - return desc, classify_text_ranked(desc, embeddings, model) + return desc, ctx.classify_text_ranked(desc) -def lookup_cat(catid: str, ucs: list) -> tuple[str,str, str]: - return next( ((x['Category'], x['SubCategory'], x['Explanations']) \ - for x in ucs if x['CatID'] == catid)) +from shutil import get_terminal_size class Commands(cmd.Cmd): + ctx: InferenceContext def __init__(self, completekey: str = "tab", stdin: IO[str] | None = None, stdout: IO[str] | None = None) -> None: super().__init__(completekey, stdin, stdout) self.file_list = [] - self.model: Optional[SentenceTransformer] = None - self.embeddings: Optional[list] = None self.catlist: list = [] self.rec_list = [] self.history = [] - - def default(self, line): - try: - rec = int(line) - if rec < len(self.rec_list): - print(f"Accept option {rec}") - ind = rec - self.history = [self.rec_list[ind]] + self.history[0:4] - self.onecmd("next") - else: - pass - - except ValueError: - super().default(line) + # def lookup_cat(self, catid: str) -> tuple[str,str, str]: + # return next( ((x['Category'], x['SubCategory'], x['Explanations']) \ + # for x in self.catlist if x['CatID'] == catid)) def preloop(self) -> None: self.file_cursor = 0 @@ -136,7 +75,18 @@ class Commands(cmd.Cmd): return super().preloop() def precmd(self, line: str): - return super().precmd(line) + try: + rec = int(line) + if rec < len(self.rec_list): + pass + else: + pass + + except ValueError: + pass + + finally: + return super().precmd(line) def postcmd(self, stop: bool, line: str) -> bool: if not stop: @@ -148,22 +98,27 @@ class Commands(cmd.Cmd): self.prompt = f"(ucsinfer:{self.file_cursor}/{len(self.file_list)}) " def setup_for_file(self): - file = self.file_list[self.file_cursor] - desc, recs = recommend_category(file, self.embeddings, self.model) - self.onecmd('file') - print(f" >> {desc}") - self.print_recommendations(recs) + if len(self.file_list) == 0: + print(" >> NO FILES!") + else: + file = self.file_list[self.file_cursor] + desc, recs = recommend_category(self.ctx, file) + self.onecmd('file') + print(f" >> {desc}") + self.print_recommendations(recs) def print_recommendations(self, top_recs): + self.rec_list = [] + + cols, _ = get_terminal_size((80,20)) def print_one_rec(index, rec): - cat, subcat, exp = lookup_cat(rec, self.catlist) - line = f" [{index:2}] : {rec} - {cat} / {subcat} - {exp}" - if len(line) > 75: - line = line[0:75] + "..." + cat, subcat, exp = self.ctx.lookup_category(rec) + line = f" [{index:2}] {rec} - {cat} / {subcat} - {exp}" + if len(line) > cols - 3: + line = line[0:cols - 3] + "..." print(line) - self.rec_list = [] print("Suggested from description:") for rec in top_recs: print_one_rec(len(self.rec_list), rec) @@ -175,6 +130,34 @@ class Commands(cmd.Cmd): print_one_rec(len(self.rec_list), rec) self.rec_list.append(rec) + def do_about(self, line: str): + 'Print information about recommendation NUMBER' + try: + picked = int(line) + if picked < len(self.rec_list): + cat, subcat, exp = \ + self.lookup_cat(self.rec_list[picked]) + + print(f" CatID: {self.rec_list[picked]}") + print(f" Category: {cat}") + print(f" SubCategory: {subcat}") + print(f" Explanation: {exp}") + + except ValueError: + print(f" *** Value \"{line}\" not recognized") + + + def do_use(self, line: str): + """Apply recomendation NUMBER to the current file and advance to the + next one""" + + try: + picked = int(line) + print(f" :: Using {self.rec_list[picked]}") + self.onecmd('next') + except ValueError: + print(" *** Value \"{line}\" not recognized") + def do_file(self, _): 'Print info about the current file' print("---") @@ -185,10 +168,10 @@ class Commands(cmd.Cmd): else: print( " > No file") - def do_lookup(self, args): - 'print a list of UCS categories similar to the argument' - self.rec_list = classify_text_ranked(args, self.embeddings, - self.model) + # def do_lookup(self, args): + # 'print a list of UCS categories similar to the argument' + # self.rec_list = classify_text_ranked(args, self.embeddings, + # self.model) def do_ls(self, _): 'Print list of all files in the buffer' @@ -216,17 +199,15 @@ class Commands(cmd.Cmd): def main(): - cats = load_ucs_categories() + # cats = load_ucs_categories() print(f"Loaded UCS categories.", file=sys.stderr) model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2") - embeddings = load_embeddings(cats, model) + # embeddings = load_embeddings(cats, model) print(f"Loaded embeddings...", file=sys.stderr) com = Commands() com.file_list = sys.argv[1:] - com.model = model - com.embeddings = embeddings - com.catlist = cats + com.ctx = InferenceContext(model=model) com.cmdloop() diff --git a/ucsinfer/inference.py b/ucsinfer/inference.py new file mode 100644 index 0000000..5750f84 --- /dev/null +++ b/ucsinfer/inference.py @@ -0,0 +1,113 @@ + +import os.path +import json +import pickle +from functools import cached_property + +from typing import NamedTuple + +import numpy as np +import platformdirs + +from sentence_transformers import SentenceTransformer + +def classify_text_ranked(text, embeddings_list, model, limit=5): + text_embedding = model.encode(text, convert_to_numpy=True) + embeddings = np.array([info['Embedding'] for info in embeddings_list]) + sim = model.similarity(text_embedding, embeddings)[0] + maxinds = np.argsort(sim)[-limit:] + return [embeddings_list[x]['CatID'] for x in reversed(maxinds)] + + +class Ucs(NamedTuple): + catid: str + category: str + subcategory: str + explanations: str + synonymns: list[str] + + @classmethod + def from_dict(cls, d: dict): + return Ucs(catid=d['CatID'], category=d['Category'], + subcategory=d['SubCategory'], + explanations=d['Explanations'], synonymns=d['Synonyms']) + +class InferenceContext: + """ + Maintains caches and resources for UCS category inference. + """ + + model: SentenceTransformer + + def __init__(self, model: SentenceTransformer): + self.model = model + + @cached_property + def catlist(self) -> list[Ucs]: + FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + cats = [] + ucs_defs = os.path.join(FILE_ROOT_DIR, 'ucs-community', 'json', + 'en.json') + + with open(ucs_defs, 'r') as f: + cats = json.load(f) + + return [Ucs.from_dict(cat) for cat in cats] + + @cached_property + def embeddings(self) -> list[dict]: + cache_dir = platformdirs.user_cache_dir("ucsinfer", 'Squad 51') + embedding_cache = os.path.join(cache_dir, f"ucs_embedding.cache") + embeddings = [] + + if os.path.exists(embedding_cache): + with open(embedding_cache, 'rb') as f: + embeddings = pickle.load(f) + + else: + print("Calculating embeddings...") + + for cat_defn in self.catlist: + embeddings += [{ + 'CatID': cat_defn.catid, + 'Embedding': self._encode_category(cat_defn) + }] + + os.makedirs(os.path.dirname(embedding_cache), exist_ok=True) + with open(embedding_cache, 'wb') as g: + pickle.dump(embeddings, g) + + return embeddings + + def _encode_category(self, cat: Ucs) -> np.ndarray: + sentence_components = [cat.explanations, + cat.category, + cat.subcategory + ] + sentence_components += cat.synonymns + sentence = ", ".join(sentence_components) + return self.model.encode(sentence, convert_to_numpy=True) + + def classify_text_ranked(self, text: str, limit: int = 5) -> list[str]: + """ + Sort all UCS catids according to their similarity to `text` + """ + text_embedding = self.model.encode(text, convert_to_numpy=True) + embeddings = np.array([info['Embedding'] for info in self.embeddings]) + sim = self.model.similarity(text_embedding, embeddings)[0] + maxinds = np.argsort(sim)[-limit:] + return [self.embeddings[x]['CatID'] for x in reversed(maxinds)] + + def lookup_category(self, catid) -> tuple[str, str, str]: + """ + Get the category, subcategory and explanations phrase for a `catid` + + :raises: StopIterator if CatId is not on the schedule + """ + i = ( + (x.category, x.subcategory, x.explanations) \ + for x in self.catlist if x.catid == catid + ) + + return next(i) +