Committing module

This commit is contained in:
Jamie Hardt
2025-08-03 15:03:54 -07:00
commit 6c424099fd
9 changed files with 787 additions and 0 deletions

215
ucsinfer/__main__.py Normal file
View File

@@ -0,0 +1,215 @@
import os
import json
import pickle
import sys
import subprocess
import cmd
from typing import Optional, Tuple, IO
import numpy as np
import platformdirs
import tqdm
from sentence_transformers import SentenceTransformer
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
def load_ucs_categories() -> list:
cats = []
ucs_defs = os.path.join(ROOT_DIR, 'ucs-community', 'json', 'en.json')
with open(ucs_defs, 'r') as f:
cats = json.load(f)
return cats
def encoode_category(cat_defn: dict, model: SentenceTransformer) -> np.ndarray:
sentence_components = [cat_defn['Explanations'], cat_defn['Category'], cat_defn['SubCategory']]
sentence_components += cat_defn['Synonyms']
sentence = ", ".join(sentence_components)
return model.encode(sentence, convert_to_numpy=True)
def load_embeddings(ucs: list, model) -> list:
cache_dir = platformdirs.user_cache_dir('ucsinfer', 'Squad 51')
embedding_cache = os.path.join(cache_dir, f"ucs_embedding.cache")
embeddings = []
if os.path.exists(embedding_cache):
with open(embedding_cache, 'rb') as f:
embeddings = pickle.load(f)
else:
print("Calculating embeddings...")
for cat_defn in tqdm.tqdm(ucs):
embeddings += [{
'CatID': cat_defn['CatID'],
'Embedding': encoode_category(cat_defn, model)
}]
os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
with open(embedding_cache, 'wb') as g:
pickle.dump(embeddings, g)
return embeddings
def description(path: str) -> Optional[str]:
result = subprocess.run(['ffprobe', '-show_format', '-of', 'json', path], capture_output=True)
# print(result)
try:
result.check_returncode()
except:
return None
stream = json.loads(result.stdout)
fmt = stream.get("format", None)
if fmt:
tags = fmt.get("tags", None)
if tags:
return tags.get("comment", None)
def classify_text_ranked(text, embeddings_list, model, limit=5):
text_embedding = model.encode(text, convert_to_numpy=True)
embeddings = np.array([info['Embedding'] for info in embeddings_list])
sim = model.similarity(text_embedding, embeddings)[0]
maxinds = np.argsort(sim)[-limit:]
return [embeddings_list[x]['CatID'] for x in reversed(maxinds)]
def recommend_category(path, embeddings, model) -> Tuple[str, list]:
"""
Get a text description of the file at `path` and a list of UCS cat IDs
"""
desc = description(path)
if desc is None:
desc = os.path.basename(path)
return desc, classify_text_ranked(desc, embeddings, model)
def lookup_cat(catid: str, ucs: list) -> Optional[tuple[str,str]]:
return next( ((x['Category'], x['SubCategory']) for x in ucs if x['CatID'] == catid) , None)
class Commands(cmd.Cmd):
def __init__(self, completekey: str = "tab", stdin: IO[str] | None = None,
stdout: IO[str] | None = None) -> None:
super().__init__(completekey, stdin, stdout)
self.file_list = []
self.model = None
self.embeddings = None
self.catlist = None
self._rec_list = []
self._file_cursor = 0
@property
def file_cursor(self):
return self._file_cursor
@file_cursor.setter
def file_cursor(self, val):
self._file_cursor = val
self.onecmd('file')
@property
def rec_list(self):
return self._rec_list
@rec_list.setter
def rec_list(self, value):
self._rec_list = value
if isinstance(self.rec_list, list) and self.catlist:
for i, cat_id in enumerate(self.rec_list):
cat, subcat = lookup_cat(cat_id, self.catlist)
print(f" [ {i+1} ]: {cat_id} ({cat} / {subcat})")
def default(self, line):
if len(self.rec_list) > 0:
try:
rec = int(line)
if rec < len(self.rec_list):
print(f"Accept option {rec}")
self.onecmd("next")
else:
pass
except ValueError:
super().default(line)
else:
super().default(line)
def preloop(self) -> None:
self.file_cursor = 0
self.update_prompt()
return super().preloop()
def postcmd(self, stop: bool, line: str) -> bool:
return super().postcmd(stop, line)
def update_prompt(self):
self.prompt = f"(ucsinfer:{self.file_cursor}/{len(self.file_list)}) "
def do_file(self, args):
'Print info about the current file'
if self.file_cursor < len(self.file_list):
self.update_prompt()
path = self.file_list[self.file_cursor]
f = os.path.basename(path)
print(f" > {f}")
desc, recs = recommend_category(path, self.embeddings, self.model)
print(f" >> {desc}")
self.rec_list = recs
else:
print(" > No file")
def do_addcontext(self, args):
'Add the argument to all file descriptions before searching for '
'similar. Enter a blank value to reset.'
pass
def do_lookup(self, args):
'print a list of UCS categories similar to the argument'
self.rec_list = classify_text_ranked(args, self.embeddings, self.model)
def do_next(self, _):
'go to next file'
self.file_cursor += 1
def do_prev(self, _):
'go to previous file'
self.file_cursor -= 1
def do_quit(self, _):
'exit'
print("Exiting...")
return True
def main():
cats = load_ucs_categories()
print(f"Loaded UCS categories.", file=sys.stderr)
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
embeddings = load_embeddings(cats, model)
print(f"Loaded embeddings...", file=sys.stderr)
com = Commands()
com.file_list = sys.argv[1:]
com.model = model
com.embeddings = embeddings
com.catlist = cats
com.cmdloop()
if __name__ == '__main__':
main()