Reworking
This commit is contained in:
@@ -29,7 +29,10 @@ def load_ucs_categories() -> list:
|
||||
return cats
|
||||
|
||||
def encoode_category(cat_defn: dict, model: SentenceTransformer) -> np.ndarray:
|
||||
sentence_components = [cat_defn['Explanations'], cat_defn['Category'], cat_defn['SubCategory']]
|
||||
sentence_components = [cat_defn['Explanations'],
|
||||
cat_defn['Category'],
|
||||
cat_defn['SubCategory']
|
||||
]
|
||||
sentence_components += cat_defn['Synonyms']
|
||||
sentence = ", ".join(sentence_components)
|
||||
return model.encode(sentence, convert_to_numpy=True)
|
||||
@@ -60,8 +63,9 @@ def load_embeddings(ucs: list, model) -> list:
|
||||
|
||||
|
||||
def description(path: str) -> Optional[str]:
|
||||
result = subprocess.run(['ffprobe', '-show_format', '-of', 'json', path], capture_output=True)
|
||||
# print(result)
|
||||
result = subprocess.run(['ffprobe', '-show_format', '-of',
|
||||
'json', path], capture_output=True)
|
||||
|
||||
try:
|
||||
result.check_returncode()
|
||||
except:
|
||||
@@ -93,8 +97,9 @@ def recommend_category(path, embeddings, model) -> Tuple[str, list]:
|
||||
|
||||
return desc, classify_text_ranked(desc, embeddings, model)
|
||||
|
||||
def lookup_cat(catid: str, ucs: list) -> Optional[tuple[str,str]]:
|
||||
return next( ((x['Category'], x['SubCategory']) for x in ucs if x['CatID'] == catid) , None)
|
||||
def lookup_cat(catid: str, ucs: list) -> tuple[str,str]:
|
||||
return next( ((x['Category'], x['SubCategory']) \
|
||||
for x in ucs if x['CatID'] == catid))
|
||||
|
||||
|
||||
class Commands(cmd.Cmd):
|
||||
@@ -103,94 +108,84 @@ class Commands(cmd.Cmd):
|
||||
stdout: IO[str] | None = None) -> None:
|
||||
super().__init__(completekey, stdin, stdout)
|
||||
self.file_list = []
|
||||
self.model = None
|
||||
self.embeddings = None
|
||||
self.catlist = None
|
||||
self._rec_list = []
|
||||
self._file_cursor = 0
|
||||
self.model: Optional[SentenceTransformer] = None
|
||||
self.embeddings: Optional[list] = None
|
||||
self.catlist: Optional[list] = None
|
||||
self.model_rec_list = []
|
||||
self.history = []
|
||||
|
||||
@property
|
||||
def file_cursor(self):
|
||||
return self._file_cursor
|
||||
|
||||
@file_cursor.setter
|
||||
def file_cursor(self, val):
|
||||
self._file_cursor = val
|
||||
self.onecmd('file')
|
||||
|
||||
@property
|
||||
def rec_list(self):
|
||||
return self._rec_list
|
||||
|
||||
@rec_list.setter
|
||||
def rec_list(self, value):
|
||||
self._rec_list = value
|
||||
if isinstance(self.rec_list, list) and self.catlist:
|
||||
for i, cat_id in enumerate(self.rec_list):
|
||||
cat, subcat = lookup_cat(cat_id, self.catlist)
|
||||
print(f" [ {i+1} ]: {cat_id} ({cat} / {subcat})")
|
||||
|
||||
def default(self, line):
|
||||
if len(self.rec_list) > 0:
|
||||
try:
|
||||
rec = int(line)
|
||||
if rec < len(self.rec_list):
|
||||
print(f"Accept option {rec}")
|
||||
self.onecmd("next")
|
||||
else:
|
||||
pass
|
||||
|
||||
except ValueError:
|
||||
super().default(line)
|
||||
|
||||
else:
|
||||
def default(self, line):
|
||||
try:
|
||||
rec = int(line)
|
||||
if rec < len(self.model_rec_list):
|
||||
print(f"Accept option {rec}")
|
||||
ind = rec - 1
|
||||
self.history = [self.model_rec_list[ind]] + self.history[0:4]
|
||||
self.onecmd("next")
|
||||
else:
|
||||
pass
|
||||
|
||||
except ValueError:
|
||||
super().default(line)
|
||||
|
||||
|
||||
def preloop(self) -> None:
|
||||
self.file_cursor = 0
|
||||
self.update_prompt()
|
||||
return super().preloop()
|
||||
|
||||
def postcmd(self, stop: bool, line: str) -> bool:
|
||||
self.update_prompt()
|
||||
return super().postcmd(stop, line)
|
||||
|
||||
def update_prompt(self):
|
||||
self.prompt = f"(ucsinfer:{self.file_cursor}/{len(self.file_list)}) "
|
||||
|
||||
def do_file(self, args):
|
||||
|
||||
def do_file(self, _):
|
||||
'Print info about the current file'
|
||||
if self.file_cursor < len(self.file_list):
|
||||
self.update_prompt()
|
||||
path = self.file_list[self.file_cursor]
|
||||
f = os.path.basename(path)
|
||||
print(f" > {f}")
|
||||
else:
|
||||
print( " > No file")
|
||||
|
||||
def do_rec(self, _):
|
||||
if self.file_cursor < len(self.file_list):
|
||||
self.update_prompt()
|
||||
path = self.file_list[self.file_cursor]
|
||||
desc, recs = recommend_category(path, self.embeddings, self.model)
|
||||
print(f" >> {desc}")
|
||||
self.rec_list = recs
|
||||
self.model_rec_list = recs
|
||||
else:
|
||||
print(" > No file")
|
||||
|
||||
|
||||
|
||||
def do_addcontext(self, args):
|
||||
'Add the argument to all file descriptions before searching for '
|
||||
'similar. Enter a blank value to reset.'
|
||||
pass
|
||||
self.model_rec_list = []
|
||||
|
||||
def do_lookup(self, args):
|
||||
'print a list of UCS categories similar to the argument'
|
||||
self.rec_list = classify_text_ranked(args, self.embeddings, self.model)
|
||||
self.model_rec_list = classify_text_ranked(args, self.embeddings,
|
||||
self.model)
|
||||
|
||||
def do_ls(self, _):
|
||||
'Print list of all files in the buffer'
|
||||
for file in self.file_list[self.file_cursor:] + \
|
||||
self.file_list[0:self.file_cursor]:
|
||||
f = os.path.basename(file)
|
||||
print(f" > {f}")
|
||||
|
||||
def do_next(self, _):
|
||||
'go to next file'
|
||||
self.file_cursor += 1
|
||||
self.file_cursor = self.file_cursor % len(self.file_list)
|
||||
|
||||
def do_prev(self, _):
|
||||
'go to previous file'
|
||||
self.file_cursor -= 1
|
||||
self.file_cursor = self.file_cursor % len(self.file_list)
|
||||
|
||||
def do_quit(self, _):
|
||||
'exit'
|
||||
def do_bye(self, _):
|
||||
'exit the program'
|
||||
print("Exiting...")
|
||||
return True
|
||||
|
||||
|
Reference in New Issue
Block a user