206 lines
5.6 KiB
Python
206 lines
5.6 KiB
Python
import os
|
|
import json
|
|
import sys
|
|
import subprocess
|
|
import cmd
|
|
|
|
from typing import Optional, IO
|
|
|
|
from .inference import InferenceContext
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
def description(path: str) -> Optional[str]:
    """
    Return the `comment` format tag that ffprobe reports for the file at
    `path`.

    ffprobe is run as `ffprobe -show_format -of json <path>` and the value at
    `format.tags.comment` of its JSON output is returned.

    Returns None when ffprobe cannot be started, exits with a non-zero
    status, prints something that is not a JSON object, or reports no
    comment tag.
    """
    try:
        result = subprocess.run(
            ['ffprobe', '-show_format', '-of', 'json', path],
            capture_output=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        # OSError: the ffprobe binary is missing or cannot be executed.
        # CalledProcessError: ffprobe ran but exited with a non-zero status.
        return None

    try:
        probe = json.loads(result.stdout)
    except ValueError:
        # Covers both undecodable bytes and malformed JSON.
        return None

    if not isinstance(probe, dict):
        return None

    # Missing or empty "format" / "tags" sections both mean "no comment".
    fmt = probe.get("format") or {}
    tags = fmt.get("tags") or {}
    return tags.get("comment")
|
|
|
|
def recommend_category(ctx: InferenceContext, path) -> tuple[str, list]:
    """
    Produce the text used to describe the file at `path` together with the
    ranked list of UCS cat IDs that `ctx` returns for that text.

    The text is whatever `description(path)` yields; when that is None the
    file's base name is used in its place.
    """
    embedded = description(path)
    text = os.path.basename(path) if embedded is None else embedded
    ranked = ctx.classify_text_ranked(text)
    return text, ranked
|
|
|
|
|
|
from shutil import get_terminal_size
|
|
|
|
class Commands(cmd.Cmd):
|
|
ctx: InferenceContext
|
|
|
|
def __init__(self, completekey: str = "tab", stdin: IO[str] | None = None,
|
|
stdout: IO[str] | None = None) -> None:
|
|
super().__init__(completekey, stdin, stdout)
|
|
self.file_list = []
|
|
self.catlist: list = []
|
|
self.rec_list = []
|
|
self.history = []
|
|
|
|
def preloop(self) -> None:
|
|
self.file_cursor = 0
|
|
self.update_prompt()
|
|
self.setup_for_file()
|
|
return super().preloop()
|
|
|
|
def default(self, line: str):
|
|
try:
|
|
sel = int(line)
|
|
self.onecmd(f"use {sel}")
|
|
|
|
except ValueError:
|
|
return super().default(line)
|
|
|
|
return super().default(line)
|
|
|
|
def precmd(self, line: str):
|
|
try:
|
|
rec = int(line)
|
|
if rec < len(self.rec_list):
|
|
pass
|
|
else:
|
|
pass
|
|
|
|
except ValueError:
|
|
pass
|
|
|
|
finally:
|
|
return super().precmd(line)
|
|
|
|
def postcmd(self, stop: bool, line: str) -> bool:
|
|
if not stop:
|
|
self.update_prompt()
|
|
self.setup_for_file()
|
|
return super().postcmd(stop, line)
|
|
|
|
def update_prompt(self):
|
|
self.prompt = f"(ucsinfer:{self.file_cursor}/{len(self.file_list)}) "
|
|
|
|
def setup_for_file(self):
|
|
if len(self.file_list) == 0:
|
|
print(" >> NO FILES!")
|
|
else:
|
|
file = self.file_list[self.file_cursor]
|
|
desc, recs = recommend_category(self.ctx, file)
|
|
self.onecmd('file')
|
|
print(f" >> {desc}")
|
|
self.print_recommendations(recs)
|
|
|
|
def print_recommendations(self, top_recs):
|
|
self.rec_list = []
|
|
|
|
cols, _ = get_terminal_size((80,20))
|
|
|
|
def print_one_rec(index, rec):
|
|
cat, subcat, exp = self.ctx.lookup_category(rec)
|
|
line = f" [{index:2}] {rec} - {cat} / {subcat} - {exp}"
|
|
if len(line) > cols - 3:
|
|
line = line[0:cols - 3] + "..."
|
|
print(line)
|
|
|
|
print("Suggested from description:")
|
|
for rec in top_recs:
|
|
print_one_rec(len(self.rec_list), rec)
|
|
self.rec_list.append(rec)
|
|
|
|
if len(self.history) > 0:
|
|
print("History:")
|
|
for rec in self.history:
|
|
print_one_rec(len(self.rec_list), rec)
|
|
self.rec_list.append(rec)
|
|
|
|
def do_about(self, line: str):
|
|
'Print information about recommendation NUMBER'
|
|
try:
|
|
picked = int(line)
|
|
if picked < len(self.rec_list):
|
|
cat, subcat, exp = \
|
|
self.ctx.lookup_category(self.rec_list[picked])
|
|
|
|
print(f" CatID: {self.rec_list[picked]}")
|
|
print(f" Category: {cat}")
|
|
print(f" SubCategory: {subcat}")
|
|
print(f" Explanation: {exp}")
|
|
|
|
except ValueError:
|
|
print(f" *** Value \"{line}\" not recognized")
|
|
|
|
|
|
def do_use(self, line: str):
|
|
"""Apply recomendation NUMBER to the current file and advance to the
|
|
next one"""
|
|
|
|
try:
|
|
picked = int(line)
|
|
print(f" :: Using {self.rec_list[picked]}")
|
|
self.do_next("")
|
|
except ValueError:
|
|
print(" *** Value \"{line}\" not recognized")
|
|
|
|
def do_file(self, _):
|
|
'Print info about the current file'
|
|
print("---")
|
|
if self.file_cursor < len(self.file_list):
|
|
path = self.file_list[self.file_cursor]
|
|
f = os.path.basename(path)
|
|
print(f" > {f}")
|
|
else:
|
|
print( " > No file")
|
|
|
|
def do_ls(self, _):
|
|
'Print list of all files in the buffer'
|
|
for file in self.file_list[self.file_cursor:] + \
|
|
self.file_list[0:self.file_cursor]:
|
|
f = os.path.basename(file)
|
|
print(f" > {f}")
|
|
|
|
def do_next(self, _):
|
|
'go to next file'
|
|
self.file_cursor += 1
|
|
self.file_cursor = self.file_cursor % len(self.file_list)
|
|
self.setup_for_file()
|
|
|
|
def do_prev(self, _):
|
|
'go to previous file'
|
|
self.file_cursor -= 1
|
|
self.file_cursor = self.file_cursor % len(self.file_list)
|
|
self.setup_for_file()
|
|
|
|
def do_bye(self, _):
|
|
'exit the program'
|
|
print("Exiting...")
|
|
return True
|
|
|
|
|
|
def main():
    """
    Start the interactive shell over the files named on the command line.

    Loads the sentence-transformer model, wraps it in an InferenceContext,
    hands both the context and `sys.argv[1:]` to a `Commands` instance and
    runs its command loop.
    """
    encoder = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")

    shell = Commands()
    shell.file_list = sys.argv[1:]
    shell.ctx = InferenceContext(model=encoder)

    shell.cmdloop()
|
|
|
|
|
|
if __name__ == "__main__":
    import warnings

    # Both settings are applied before main() builds the model.
    # NOTE(review): TOKENIZERS_PARALLELISM is presumably read by the
    # tokenizer library used by the model - confirm.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Keep FutureWarning messages out of the interactive session.
    warnings.simplefilter("ignore", FutureWarning)

    main()
|