WIP
This commit is contained in:
@@ -7,22 +7,10 @@ import cmd
|
||||
|
||||
from typing import Optional, IO
|
||||
|
||||
# from .inference import classify_text_ranked
|
||||
|
||||
from .inference import InferenceContext
|
||||
|
||||
# import numpy as np
|
||||
# import platformdirs
|
||||
# import tqdm
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
||||
|
||||
import warnings
|
||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||
|
||||
|
||||
def description(path: str) -> Optional[str]:
|
||||
result = subprocess.run(['ffprobe', '-show_format', '-of',
|
||||
@@ -40,7 +28,7 @@ def description(path: str) -> Optional[str]:
|
||||
if tags:
|
||||
return tags.get("comment", None)
|
||||
|
||||
def recommend_category(ctx, path) -> tuple[str, list]:
|
||||
def recommend_category(ctx: InferenceContext, path) -> tuple[str, list]:
|
||||
"""
|
||||
Get a text description of the file at `path` and a list of UCS cat IDs
|
||||
"""
|
||||
@@ -64,16 +52,22 @@ class Commands(cmd.Cmd):
|
||||
self.rec_list = []
|
||||
self.history = []
|
||||
|
||||
# def lookup_cat(self, catid: str) -> tuple[str,str, str]:
|
||||
# return next( ((x['Category'], x['SubCategory'], x['Explanations']) \
|
||||
# for x in self.catlist if x['CatID'] == catid))
|
||||
|
||||
def preloop(self) -> None:
|
||||
self.file_cursor = 0
|
||||
self.update_prompt()
|
||||
self.setup_for_file()
|
||||
return super().preloop()
|
||||
|
||||
def default(self, line: str):
|
||||
try:
|
||||
sel = int(line)
|
||||
self.onecmd(f"use {sel}")
|
||||
|
||||
except ValueError:
|
||||
return super().default(line)
|
||||
|
||||
return super().default(line)
|
||||
|
||||
def precmd(self, line: str):
|
||||
try:
|
||||
rec = int(line)
|
||||
@@ -136,7 +130,7 @@ class Commands(cmd.Cmd):
|
||||
picked = int(line)
|
||||
if picked < len(self.rec_list):
|
||||
cat, subcat, exp = \
|
||||
self.lookup_cat(self.rec_list[picked])
|
||||
self.ctx.lookup_category(self.rec_list[picked])
|
||||
|
||||
print(f" CatID: {self.rec_list[picked]}")
|
||||
print(f" Category: {cat}")
|
||||
@@ -154,7 +148,7 @@ class Commands(cmd.Cmd):
|
||||
try:
|
||||
picked = int(line)
|
||||
print(f" :: Using {self.rec_list[picked]}")
|
||||
self.onecmd('next')
|
||||
self.do_next("")
|
||||
except ValueError:
|
||||
print(" *** Value \"{line}\" not recognized")
|
||||
|
||||
@@ -168,11 +162,6 @@ class Commands(cmd.Cmd):
|
||||
else:
|
||||
print( " > No file")
|
||||
|
||||
# def do_lookup(self, args):
|
||||
# 'print a list of UCS categories similar to the argument'
|
||||
# self.rec_list = classify_text_ranked(args, self.embeddings,
|
||||
# self.model)
|
||||
|
||||
def do_ls(self, _):
|
||||
'Print list of all files in the buffer'
|
||||
for file in self.file_list[self.file_cursor:] + \
|
||||
@@ -199,11 +188,7 @@ class Commands(cmd.Cmd):
|
||||
|
||||
|
||||
def main():
|
||||
# cats = load_ucs_categories()
|
||||
print(f"Loaded UCS categories.", file=sys.stderr)
|
||||
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
|
||||
# embeddings = load_embeddings(cats, model)
|
||||
print(f"Loaded embeddings...", file=sys.stderr)
|
||||
|
||||
com = Commands()
|
||||
com.file_list = sys.argv[1:]
|
||||
@@ -213,4 +198,9 @@ def main():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
||||
|
||||
import warnings
|
||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||
|
||||
main()
|
||||
|
Reference in New Issue
Block a user