Committing module
This commit is contained in:
215
ucsinfer/__main__.py
Normal file
215
ucsinfer/__main__.py
Normal file
@@ -0,0 +1,215 @@
|
||||
import os
|
||||
import json
|
||||
import pickle
|
||||
import sys
|
||||
import subprocess
|
||||
import cmd
|
||||
|
||||
from typing import Optional, Tuple, IO
|
||||
|
||||
import numpy as np
|
||||
import platformdirs
|
||||
import tqdm
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
# Directory that contains this module; load_ucs_categories() resolves the
# bundled UCS definition file relative to it.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# NOTE(review): presumably silences the HuggingFace tokenizers
# fork/parallelism warning -- confirm against the tokenizers documentation.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Hide FutureWarning messages for the remainder of the process.
import warnings
warnings.simplefilter('ignore', FutureWarning)
|
||||
|
||||
def load_ucs_categories() -> list:
    """Load the UCS category definitions shipped next to this module.

    Reads ``ucs-community/json/en.json`` below ``ROOT_DIR`` and returns the
    parsed JSON document unchanged.

    Returns:
        The decoded JSON value. The rest of this module iterates over it and
        reads the keys ``CatID``, ``Category``, ``SubCategory``,
        ``Explanations`` and ``Synonyms`` from each entry.

    Raises:
        OSError: if the definition file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    ucs_defs = os.path.join(ROOT_DIR, 'ucs-community', 'json', 'en.json')

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding (which is not UTF-8 on every system).
    with open(ucs_defs, 'r', encoding='utf-8') as f:
        return json.load(f)
|
||||
|
||||
def encoode_category(cat_defn: dict, model: SentenceTransformer) -> np.ndarray:
    """Embed a single UCS category definition.

    The explanation, category name, subcategory name and every synonym are
    joined with ``", "`` into one sentence, which is passed to
    ``model.encode`` with ``convert_to_numpy=True``.

    Args:
        cat_defn: One category entry; must provide ``Explanations``,
            ``Category``, ``SubCategory`` and ``Synonyms``.
        model: Object whose ``encode`` method produces the embedding.

    Returns:
        Whatever ``model.encode`` returns for the joined sentence.
    """
    parts = [
        cat_defn['Explanations'],
        cat_defn['Category'],
        cat_defn['SubCategory'],
        *cat_defn['Synonyms'],
    ]
    return model.encode(", ".join(parts), convert_to_numpy=True)
|
||||
|
||||
def load_embeddings(ucs: list, model) -> list:
    """Return one embedding per UCS category, backed by an on-disk cache.

    The cache lives in the per-user cache directory reported by
    ``platformdirs`` under the file name ``ucs_embedding.cache``. When the
    cache exists and can be read it is returned as-is; otherwise every
    category is embedded with ``encoode_category`` and the result is written
    back to the cache.

    NOTE: the cache is not keyed on the model or on the category list, so a
    change to either is not detected; delete the cache file to force a
    rebuild.

    Args:
        ucs: Category definitions, each providing at least ``CatID``.
        model: Embedding model forwarded to ``encoode_category``.

    Returns:
        A list of ``{'CatID': ..., 'Embedding': ...}`` dicts.
    """
    cache_dir = platformdirs.user_cache_dir('ucsinfer', 'Squad 51')
    embedding_cache = os.path.join(cache_dir, "ucs_embedding.cache")

    if os.path.exists(embedding_cache):
        try:
            # SECURITY NOTE: unpickling can execute arbitrary code. This is
            # acceptable only because the file sits in the user's own cache
            # directory and is written by this program.
            with open(embedding_cache, 'rb') as f:
                return pickle.load(f)
        except (pickle.UnpicklingError, EOFError, AttributeError,
                ImportError, IndexError):
            # A truncated or otherwise unreadable cache (e.g. from an
            # interrupted earlier run) must not wedge the tool: fall through
            # and rebuild it.
            pass

    print("Calculating embeddings...")

    embeddings = [
        {
            'CatID': cat_defn['CatID'],
            'Embedding': encoode_category(cat_defn, model),
        }
        for cat_defn in tqdm.tqdm(ucs)
    ]

    os.makedirs(cache_dir, exist_ok=True)

    # Write to a sibling temp file and rename it into place so a crash
    # mid-write can never leave a half-written cache behind.
    tmp_path = embedding_cache + ".tmp"
    with open(tmp_path, 'wb') as g:
        pickle.dump(embeddings, g)
    os.replace(tmp_path, embedding_cache)

    return embeddings
|
||||
|
||||
|
||||
def description(path: str) -> Optional[str]:
    """Return the ``comment`` tag that ffprobe reports for the file at *path*.

    Runs ``ffprobe -show_format -of json <path>`` and looks up
    ``format.tags.comment`` in its JSON output.

    This is a best-effort lookup: every failure mode yields ``None`` so the
    caller can fall back to something else.

    Args:
        path: Path of the media file to inspect.

    Returns:
        The comment string, or ``None`` when ffprobe is not installed, exits
        with a non-zero status, prints something that is not JSON, or the
        file has no ``comment`` tag.
    """
    try:
        result = subprocess.run(
            ['ffprobe', '-show_format', '-of', 'json', path],
            capture_output=True,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError):
        # OSError: the ffprobe executable is missing or cannot be started.
        # CalledProcessError: ffprobe ran but rejected the file.
        return None

    try:
        stream = json.loads(result.stdout)
    except ValueError:
        # Covers json.JSONDecodeError and undecodable bytes.
        return None

    if not isinstance(stream, dict):
        return None

    fmt = stream.get("format")
    if not fmt:
        return None

    tags = fmt.get("tags")
    if not tags:
        return None

    return tags.get("comment")
|
||||
|
||||
|
||||
def classify_text_ranked(text, embeddings_list, model, limit=5):
    """Rank UCS categories by similarity to *text*.

    Args:
        text: Free text to classify.
        embeddings_list: Sequence of ``{'CatID': ..., 'Embedding': ...}``
            dicts, as produced by ``load_embeddings``.
        model: Object providing ``encode(text, convert_to_numpy=True)`` and
            ``similarity(a, b)``.
        limit: Maximum number of category IDs to return.

    Returns:
        Up to *limit* ``CatID`` values, most similar first. An empty list is
        returned when *limit* is not positive or there are no embeddings.
    """
    # Guard the degenerate cases explicitly: a slice of ``[-0:]`` would
    # select *every* element rather than none.
    if limit <= 0 or not embeddings_list:
        return []

    text_embedding = model.encode(text, convert_to_numpy=True)
    embeddings = np.array([info['Embedding'] for info in embeddings_list])

    # Assumes similarity() yields one row of scores per query; row 0 belongs
    # to the single query text -- confirm against the model's documentation.
    sim = model.similarity(text_embedding, embeddings)[0]

    # argsort is ascending, so the best matches are the last `limit` indices.
    best_ascending = np.argsort(sim)[-limit:]
    return [embeddings_list[idx]['CatID'] for idx in reversed(best_ascending)]
|
||||
|
||||
|
||||
def recommend_category(path, embeddings, model) -> Tuple[str, list]:
    """Describe the file at *path* and suggest UCS categories for it.

    The description is the file's ``comment`` metadata tag as returned by
    ``description``. When that is missing or blank, the file's base name is
    used instead so the classifier always has some text to work with.

    Args:
        path: Path of the file to classify.
        embeddings: Category embeddings, forwarded to
            ``classify_text_ranked``.
        model: Embedding model, forwarded to ``classify_text_ranked``.

    Returns:
        ``(description_text, ranked_cat_ids)``.
    """
    desc = description(path)

    # A blank comment tag carries no information; treat it like a missing one.
    if not desc or not desc.strip():
        desc = os.path.basename(path)

    return desc, classify_text_ranked(desc, embeddings, model)
|
||||
|
||||
def lookup_cat(catid: str, ucs: list) -> Optional[tuple[str,str]]:
    """Find the category and subcategory names for a UCS category ID.

    Args:
        catid: The ``CatID`` to search for.
        ucs: Category definitions to search, in order.

    Returns:
        ``(Category, SubCategory)`` of the first entry whose ``CatID`` equals
        *catid*, or ``None`` when no entry matches.
    """
    for entry in ucs:
        if entry['CatID'] == catid:
            return (entry['Category'], entry['SubCategory'])
    return None
|
||||
|
||||
|
||||
class Commands(cmd.Cmd):
|
||||
|
||||
def __init__(self, completekey: str = "tab", stdin: IO[str] | None = None,
|
||||
stdout: IO[str] | None = None) -> None:
|
||||
super().__init__(completekey, stdin, stdout)
|
||||
self.file_list = []
|
||||
self.model = None
|
||||
self.embeddings = None
|
||||
self.catlist = None
|
||||
self._rec_list = []
|
||||
self._file_cursor = 0
|
||||
|
||||
@property
|
||||
def file_cursor(self):
|
||||
return self._file_cursor
|
||||
|
||||
@file_cursor.setter
|
||||
def file_cursor(self, val):
|
||||
self._file_cursor = val
|
||||
self.onecmd('file')
|
||||
|
||||
@property
|
||||
def rec_list(self):
|
||||
return self._rec_list
|
||||
|
||||
@rec_list.setter
|
||||
def rec_list(self, value):
|
||||
self._rec_list = value
|
||||
if isinstance(self.rec_list, list) and self.catlist:
|
||||
for i, cat_id in enumerate(self.rec_list):
|
||||
cat, subcat = lookup_cat(cat_id, self.catlist)
|
||||
print(f" [ {i+1} ]: {cat_id} ({cat} / {subcat})")
|
||||
|
||||
def default(self, line):
|
||||
if len(self.rec_list) > 0:
|
||||
try:
|
||||
rec = int(line)
|
||||
if rec < len(self.rec_list):
|
||||
print(f"Accept option {rec}")
|
||||
self.onecmd("next")
|
||||
else:
|
||||
pass
|
||||
|
||||
except ValueError:
|
||||
super().default(line)
|
||||
|
||||
else:
|
||||
super().default(line)
|
||||
|
||||
def preloop(self) -> None:
|
||||
self.file_cursor = 0
|
||||
self.update_prompt()
|
||||
return super().preloop()
|
||||
|
||||
def postcmd(self, stop: bool, line: str) -> bool:
|
||||
return super().postcmd(stop, line)
|
||||
|
||||
def update_prompt(self):
|
||||
self.prompt = f"(ucsinfer:{self.file_cursor}/{len(self.file_list)}) "
|
||||
|
||||
def do_file(self, args):
|
||||
'Print info about the current file'
|
||||
if self.file_cursor < len(self.file_list):
|
||||
self.update_prompt()
|
||||
path = self.file_list[self.file_cursor]
|
||||
f = os.path.basename(path)
|
||||
print(f" > {f}")
|
||||
desc, recs = recommend_category(path, self.embeddings, self.model)
|
||||
print(f" >> {desc}")
|
||||
self.rec_list = recs
|
||||
else:
|
||||
print(" > No file")
|
||||
|
||||
|
||||
|
||||
def do_addcontext(self, args):
|
||||
'Add the argument to all file descriptions before searching for '
|
||||
'similar. Enter a blank value to reset.'
|
||||
pass
|
||||
|
||||
def do_lookup(self, args):
|
||||
'print a list of UCS categories similar to the argument'
|
||||
self.rec_list = classify_text_ranked(args, self.embeddings, self.model)
|
||||
|
||||
def do_next(self, _):
|
||||
'go to next file'
|
||||
self.file_cursor += 1
|
||||
|
||||
def do_prev(self, _):
|
||||
'go to previous file'
|
||||
self.file_cursor -= 1
|
||||
|
||||
def do_quit(self, _):
|
||||
'exit'
|
||||
print("Exiting...")
|
||||
return True
|
||||
|
||||
|
||||
def main():
    """Load categories, model and embeddings, then start the interactive shell.

    Every command-line argument after the program name is treated as a file
    to step through. Progress messages are written to stderr.
    """
    categories = load_ucs_categories()
    print("Loaded UCS categories.", file=sys.stderr)

    encoder = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
    vectors = load_embeddings(categories, encoder)
    print("Loaded embeddings...", file=sys.stderr)

    shell = Commands()
    shell.file_list = sys.argv[1:]
    shell.model = encoder
    shell.embeddings = vectors
    shell.catlist = categories

    shell.cmdloop()


if __name__ == '__main__':
    main()
|
Reference in New Issue
Block a user