This commit is contained in:
Jamie Hardt
2025-08-10 23:12:45 -07:00
parent 75f88d483b
commit c4b3324a3b

View File

@@ -7,22 +7,10 @@ import cmd
from typing import Optional, IO
# from .inference import classify_text_ranked
from .inference import InferenceContext
# import numpy as np
# import platformdirs
# import tqdm
from sentence_transformers import SentenceTransformer
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
def description(path: str) -> Optional[str]:
result = subprocess.run(['ffprobe', '-show_format', '-of',
@@ -40,7 +28,7 @@ def description(path: str) -> Optional[str]:
if tags:
return tags.get("comment", None)
def recommend_category(ctx, path) -> tuple[str, list]:
def recommend_category(ctx: InferenceContext, path) -> tuple[str, list]:
"""
Get a text description of the file at `path` and a list of UCS cat IDs
"""
@@ -64,16 +52,22 @@ class Commands(cmd.Cmd):
self.rec_list = []
self.history = []
# def lookup_cat(self, catid: str) -> tuple[str,str, str]:
# return next( ((x['Category'], x['SubCategory'], x['Explanations']) \
# for x in self.catlist if x['CatID'] == catid))
def preloop(self) -> None:
self.file_cursor = 0
self.update_prompt()
self.setup_for_file()
return super().preloop()
def default(self, line: str):
try:
sel = int(line)
self.onecmd(f"use {sel}")
except ValueError:
return super().default(line)
return super().default(line)
def precmd(self, line: str):
try:
rec = int(line)
@@ -136,7 +130,7 @@ class Commands(cmd.Cmd):
picked = int(line)
if picked < len(self.rec_list):
cat, subcat, exp = \
self.lookup_cat(self.rec_list[picked])
self.ctx.lookup_category(self.rec_list[picked])
print(f" CatID: {self.rec_list[picked]}")
print(f" Category: {cat}")
@@ -154,7 +148,7 @@ class Commands(cmd.Cmd):
try:
picked = int(line)
print(f" :: Using {self.rec_list[picked]}")
self.onecmd('next')
self.do_next("")
except ValueError:
print(" *** Value \"{line}\" not recognized")
@@ -168,11 +162,6 @@ class Commands(cmd.Cmd):
else:
print( " > No file")
# def do_lookup(self, args):
# 'print a list of UCS categories similar to the argument'
# self.rec_list = classify_text_ranked(args, self.embeddings,
# self.model)
def do_ls(self, _):
'Print list of all files in the buffer'
for file in self.file_list[self.file_cursor:] + \
@@ -199,11 +188,7 @@ class Commands(cmd.Cmd):
def main():
# cats = load_ucs_categories()
print(f"Loaded UCS categories.", file=sys.stderr)
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
# embeddings = load_embeddings(cats, model)
print(f"Loaded embeddings...", file=sys.stderr)
com = Commands()
com.file_list = sys.argv[1:]
@@ -213,4 +198,9 @@ def main():
if __name__ == '__main__':
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
main()