This commit is contained in:
Jamie Hardt
2025-08-10 23:12:45 -07:00
parent 75f88d483b
commit c4b3324a3b

View File

@@ -7,22 +7,10 @@ import cmd
from typing import Optional, IO from typing import Optional, IO
# from .inference import classify_text_ranked
from .inference import InferenceContext from .inference import InferenceContext
# import numpy as np
# import platformdirs
# import tqdm
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
def description(path: str) -> Optional[str]: def description(path: str) -> Optional[str]:
result = subprocess.run(['ffprobe', '-show_format', '-of', result = subprocess.run(['ffprobe', '-show_format', '-of',
@@ -40,7 +28,7 @@ def description(path: str) -> Optional[str]:
if tags: if tags:
return tags.get("comment", None) return tags.get("comment", None)
def recommend_category(ctx, path) -> tuple[str, list]: def recommend_category(ctx: InferenceContext, path) -> tuple[str, list]:
""" """
Get a text description of the file at `path` and a list of UCS cat IDs Get a text description of the file at `path` and a list of UCS cat IDs
""" """
@@ -64,16 +52,22 @@ class Commands(cmd.Cmd):
self.rec_list = [] self.rec_list = []
self.history = [] self.history = []
# def lookup_cat(self, catid: str) -> tuple[str,str, str]:
# return next( ((x['Category'], x['SubCategory'], x['Explanations']) \
# for x in self.catlist if x['CatID'] == catid))
def preloop(self) -> None: def preloop(self) -> None:
self.file_cursor = 0 self.file_cursor = 0
self.update_prompt() self.update_prompt()
self.setup_for_file() self.setup_for_file()
return super().preloop() return super().preloop()
def default(self, line: str):
try:
sel = int(line)
self.onecmd(f"use {sel}")
except ValueError:
return super().default(line)
return super().default(line)
def precmd(self, line: str): def precmd(self, line: str):
try: try:
rec = int(line) rec = int(line)
@@ -136,7 +130,7 @@ class Commands(cmd.Cmd):
picked = int(line) picked = int(line)
if picked < len(self.rec_list): if picked < len(self.rec_list):
cat, subcat, exp = \ cat, subcat, exp = \
self.lookup_cat(self.rec_list[picked]) self.ctx.lookup_category(self.rec_list[picked])
print(f" CatID: {self.rec_list[picked]}") print(f" CatID: {self.rec_list[picked]}")
print(f" Category: {cat}") print(f" Category: {cat}")
@@ -154,7 +148,7 @@ class Commands(cmd.Cmd):
try: try:
picked = int(line) picked = int(line)
print(f" :: Using {self.rec_list[picked]}") print(f" :: Using {self.rec_list[picked]}")
self.onecmd('next') self.do_next("")
except ValueError: except ValueError:
print(" *** Value \"{line}\" not recognized") print(" *** Value \"{line}\" not recognized")
@@ -168,11 +162,6 @@ class Commands(cmd.Cmd):
else: else:
print( " > No file") print( " > No file")
# def do_lookup(self, args):
# 'print a list of UCS categories similar to the argument'
# self.rec_list = classify_text_ranked(args, self.embeddings,
# self.model)
def do_ls(self, _): def do_ls(self, _):
'Print list of all files in the buffer' 'Print list of all files in the buffer'
for file in self.file_list[self.file_cursor:] + \ for file in self.file_list[self.file_cursor:] + \
@@ -199,11 +188,7 @@ class Commands(cmd.Cmd):
def main(): def main():
# cats = load_ucs_categories()
print(f"Loaded UCS categories.", file=sys.stderr)
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2") model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
# embeddings = load_embeddings(cats, model)
print(f"Loaded embeddings...", file=sys.stderr)
com = Commands() com = Commands()
com.file_list = sys.argv[1:] com.file_list = sys.argv[1:]
@@ -213,4 +198,9 @@ def main():
if __name__ == '__main__': if __name__ == '__main__':
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
main() main()