Reworking

This commit is contained in:
Jamie Hardt
2025-08-09 20:58:53 -07:00
parent 6c424099fd
commit 53a91f6103
2 changed files with 277 additions and 114 deletions

View File

@@ -29,7 +29,10 @@ def load_ucs_categories() -> list:
return cats
def encoode_category(cat_defn: dict, model: SentenceTransformer) -> np.ndarray:
sentence_components = [cat_defn['Explanations'], cat_defn['Category'], cat_defn['SubCategory']]
sentence_components = [cat_defn['Explanations'],
cat_defn['Category'],
cat_defn['SubCategory']
]
sentence_components += cat_defn['Synonyms']
sentence = ", ".join(sentence_components)
return model.encode(sentence, convert_to_numpy=True)
@@ -60,8 +63,9 @@ def load_embeddings(ucs: list, model) -> list:
def description(path: str) -> Optional[str]:
result = subprocess.run(['ffprobe', '-show_format', '-of', 'json', path], capture_output=True)
# print(result)
result = subprocess.run(['ffprobe', '-show_format', '-of',
'json', path], capture_output=True)
try:
result.check_returncode()
except:
@@ -93,8 +97,9 @@ def recommend_category(path, embeddings, model) -> Tuple[str, list]:
return desc, classify_text_ranked(desc, embeddings, model)
def lookup_cat(catid: str, ucs: list) -> Optional[tuple[str,str]]:
return next( ((x['Category'], x['SubCategory']) for x in ucs if x['CatID'] == catid) , None)
def lookup_cat(catid: str, ucs: list) -> tuple[str,str]:
return next( ((x['Category'], x['SubCategory']) \
for x in ucs if x['CatID'] == catid))
class Commands(cmd.Cmd):
@@ -103,94 +108,84 @@ class Commands(cmd.Cmd):
stdout: IO[str] | None = None) -> None:
super().__init__(completekey, stdin, stdout)
self.file_list = []
self.model = None
self.embeddings = None
self.catlist = None
self._rec_list = []
self._file_cursor = 0
self.model: Optional[SentenceTransformer] = None
self.embeddings: Optional[list] = None
self.catlist: Optional[list] = None
self.model_rec_list = []
self.history = []
@property
def file_cursor(self):
return self._file_cursor
@file_cursor.setter
def file_cursor(self, val):
self._file_cursor = val
self.onecmd('file')
@property
def rec_list(self):
return self._rec_list
@rec_list.setter
def rec_list(self, value):
self._rec_list = value
if isinstance(self.rec_list, list) and self.catlist:
for i, cat_id in enumerate(self.rec_list):
cat, subcat = lookup_cat(cat_id, self.catlist)
print(f" [ {i+1} ]: {cat_id} ({cat} / {subcat})")
def default(self, line):
if len(self.rec_list) > 0:
try:
rec = int(line)
if rec < len(self.rec_list):
print(f"Accept option {rec}")
self.onecmd("next")
else:
pass
except ValueError:
super().default(line)
else:
def default(self, line):
try:
rec = int(line)
if rec < len(self.model_rec_list):
print(f"Accept option {rec}")
ind = rec - 1
self.history = [self.model_rec_list[ind]] + self.history[0:4]
self.onecmd("next")
else:
pass
except ValueError:
super().default(line)
def preloop(self) -> None:
self.file_cursor = 0
self.update_prompt()
return super().preloop()
def postcmd(self, stop: bool, line: str) -> bool:
self.update_prompt()
return super().postcmd(stop, line)
def update_prompt(self):
self.prompt = f"(ucsinfer:{self.file_cursor}/{len(self.file_list)}) "
def do_file(self, args):
def do_file(self, _):
'Print info about the current file'
if self.file_cursor < len(self.file_list):
self.update_prompt()
path = self.file_list[self.file_cursor]
f = os.path.basename(path)
print(f" > {f}")
else:
print( " > No file")
def do_rec(self, _):
if self.file_cursor < len(self.file_list):
self.update_prompt()
path = self.file_list[self.file_cursor]
desc, recs = recommend_category(path, self.embeddings, self.model)
print(f" >> {desc}")
self.rec_list = recs
self.model_rec_list = recs
else:
print(" > No file")
def do_addcontext(self, args):
'Add the argument to all file descriptions before searching for '
'similar. Enter a blank value to reset.'
pass
self.model_rec_list = []
def do_lookup(self, args):
'print a list of UCS categories similar to the argument'
self.rec_list = classify_text_ranked(args, self.embeddings, self.model)
self.model_rec_list = classify_text_ranked(args, self.embeddings,
self.model)
def do_ls(self, _):
'Print list of all files in the buffer'
for file in self.file_list[self.file_cursor:] + \
self.file_list[0:self.file_cursor]:
f = os.path.basename(file)
print(f" > {f}")
def do_next(self, _):
'go to next file'
self.file_cursor += 1
self.file_cursor = self.file_cursor % len(self.file_list)
def do_prev(self, _):
'go to previous file'
self.file_cursor -= 1
self.file_cursor = self.file_cursor % len(self.file_list)
def do_quit(self, _):
'exit'
def do_bye(self, _):
'exit the program'
print("Exiting...")
return True