WIP
This commit is contained in:
@@ -7,22 +7,10 @@ import cmd
|
|||||||
|
|
||||||
from typing import Optional, IO
|
from typing import Optional, IO
|
||||||
|
|
||||||
# from .inference import classify_text_ranked
|
|
||||||
|
|
||||||
from .inference import InferenceContext
|
from .inference import InferenceContext
|
||||||
|
|
||||||
# import numpy as np
|
|
||||||
# import platformdirs
|
|
||||||
# import tqdm
|
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
|
|
||||||
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
warnings.simplefilter(action='ignore', category=FutureWarning)
|
|
||||||
|
|
||||||
|
|
||||||
def description(path: str) -> Optional[str]:
|
def description(path: str) -> Optional[str]:
|
||||||
result = subprocess.run(['ffprobe', '-show_format', '-of',
|
result = subprocess.run(['ffprobe', '-show_format', '-of',
|
||||||
@@ -40,7 +28,7 @@ def description(path: str) -> Optional[str]:
|
|||||||
if tags:
|
if tags:
|
||||||
return tags.get("comment", None)
|
return tags.get("comment", None)
|
||||||
|
|
||||||
def recommend_category(ctx, path) -> tuple[str, list]:
|
def recommend_category(ctx: InferenceContext, path) -> tuple[str, list]:
|
||||||
"""
|
"""
|
||||||
Get a text description of the file at `path` and a list of UCS cat IDs
|
Get a text description of the file at `path` and a list of UCS cat IDs
|
||||||
"""
|
"""
|
||||||
@@ -64,16 +52,22 @@ class Commands(cmd.Cmd):
|
|||||||
self.rec_list = []
|
self.rec_list = []
|
||||||
self.history = []
|
self.history = []
|
||||||
|
|
||||||
# def lookup_cat(self, catid: str) -> tuple[str,str, str]:
|
|
||||||
# return next( ((x['Category'], x['SubCategory'], x['Explanations']) \
|
|
||||||
# for x in self.catlist if x['CatID'] == catid))
|
|
||||||
|
|
||||||
def preloop(self) -> None:
|
def preloop(self) -> None:
|
||||||
self.file_cursor = 0
|
self.file_cursor = 0
|
||||||
self.update_prompt()
|
self.update_prompt()
|
||||||
self.setup_for_file()
|
self.setup_for_file()
|
||||||
return super().preloop()
|
return super().preloop()
|
||||||
|
|
||||||
|
def default(self, line: str):
|
||||||
|
try:
|
||||||
|
sel = int(line)
|
||||||
|
self.onecmd(f"use {sel}")
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
return super().default(line)
|
||||||
|
|
||||||
|
return super().default(line)
|
||||||
|
|
||||||
def precmd(self, line: str):
|
def precmd(self, line: str):
|
||||||
try:
|
try:
|
||||||
rec = int(line)
|
rec = int(line)
|
||||||
@@ -136,7 +130,7 @@ class Commands(cmd.Cmd):
|
|||||||
picked = int(line)
|
picked = int(line)
|
||||||
if picked < len(self.rec_list):
|
if picked < len(self.rec_list):
|
||||||
cat, subcat, exp = \
|
cat, subcat, exp = \
|
||||||
self.lookup_cat(self.rec_list[picked])
|
self.ctx.lookup_category(self.rec_list[picked])
|
||||||
|
|
||||||
print(f" CatID: {self.rec_list[picked]}")
|
print(f" CatID: {self.rec_list[picked]}")
|
||||||
print(f" Category: {cat}")
|
print(f" Category: {cat}")
|
||||||
@@ -154,7 +148,7 @@ class Commands(cmd.Cmd):
|
|||||||
try:
|
try:
|
||||||
picked = int(line)
|
picked = int(line)
|
||||||
print(f" :: Using {self.rec_list[picked]}")
|
print(f" :: Using {self.rec_list[picked]}")
|
||||||
self.onecmd('next')
|
self.do_next("")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print(" *** Value \"{line}\" not recognized")
|
print(" *** Value \"{line}\" not recognized")
|
||||||
|
|
||||||
@@ -168,11 +162,6 @@ class Commands(cmd.Cmd):
|
|||||||
else:
|
else:
|
||||||
print( " > No file")
|
print( " > No file")
|
||||||
|
|
||||||
# def do_lookup(self, args):
|
|
||||||
# 'print a list of UCS categories similar to the argument'
|
|
||||||
# self.rec_list = classify_text_ranked(args, self.embeddings,
|
|
||||||
# self.model)
|
|
||||||
|
|
||||||
def do_ls(self, _):
|
def do_ls(self, _):
|
||||||
'Print list of all files in the buffer'
|
'Print list of all files in the buffer'
|
||||||
for file in self.file_list[self.file_cursor:] + \
|
for file in self.file_list[self.file_cursor:] + \
|
||||||
@@ -199,11 +188,7 @@ class Commands(cmd.Cmd):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# cats = load_ucs_categories()
|
|
||||||
print(f"Loaded UCS categories.", file=sys.stderr)
|
|
||||||
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
|
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
|
||||||
# embeddings = load_embeddings(cats, model)
|
|
||||||
print(f"Loaded embeddings...", file=sys.stderr)
|
|
||||||
|
|
||||||
com = Commands()
|
com = Commands()
|
||||||
com.file_list = sys.argv[1:]
|
com.file_list = sys.argv[1:]
|
||||||
@@ -213,4 +198,9 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.simplefilter(action='ignore', category=FutureWarning)
|
||||||
|
|
||||||
main()
|
main()
|
||||||
|
Reference in New Issue
Block a user