diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py index dbc8fdb..f224c30 100644 --- a/ucsinfer/__main__.py +++ b/ucsinfer/__main__.py @@ -7,22 +7,10 @@ import cmd from typing import Optional, IO -# from .inference import classify_text_ranked - from .inference import InferenceContext -# import numpy as np -# import platformdirs -# import tqdm from sentence_transformers import SentenceTransformer -ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) - -os.environ['TOKENIZERS_PARALLELISM'] = 'false' - -import warnings -warnings.simplefilter(action='ignore', category=FutureWarning) - def description(path: str) -> Optional[str]: result = subprocess.run(['ffprobe', '-show_format', '-of', @@ -40,7 +28,7 @@ def description(path: str) -> Optional[str]: if tags: return tags.get("comment", None) -def recommend_category(ctx, path) -> tuple[str, list]: +def recommend_category(ctx: InferenceContext, path) -> tuple[str, list]: """ Get a text description of the file at `path` and a list of UCS cat IDs """ @@ -64,16 +52,22 @@ class Commands(cmd.Cmd): self.rec_list = [] self.history = [] - # def lookup_cat(self, catid: str) -> tuple[str,str, str]: - # return next( ((x['Category'], x['SubCategory'], x['Explanations']) \ - # for x in self.catlist if x['CatID'] == catid)) - def preloop(self) -> None: self.file_cursor = 0 self.update_prompt() self.setup_for_file() return super().preloop() + def default(self, line: str): + try: + sel = int(line) + self.onecmd(f"use {sel}") + + except ValueError: + return super().default(line) + + return super().default(line) + def precmd(self, line: str): try: rec = int(line) @@ -136,7 +130,7 @@ class Commands(cmd.Cmd): picked = int(line) if picked < len(self.rec_list): cat, subcat, exp = \ - self.lookup_cat(self.rec_list[picked]) + self.ctx.lookup_category(self.rec_list[picked]) print(f" CatID: {self.rec_list[picked]}") print(f" Category: {cat}") @@ -154,7 +148,7 @@ class Commands(cmd.Cmd): try: picked = int(line) print(f" :: Using {self.rec_list[picked]}") - self.onecmd('next') + self.do_next("") except ValueError: print(" *** Value \"{line}\" not recognized") @@ -168,11 +162,6 @@ class Commands(cmd.Cmd): else: print( " > No file") - # def do_lookup(self, args): - # 'print a list of UCS categories similar to the argument' - # self.rec_list = classify_text_ranked(args, self.embeddings, - # self.model) - def do_ls(self, _): 'Print list of all files in the buffer' for file in self.file_list[self.file_cursor:] + \ @@ -199,11 +188,7 @@ class Commands(cmd.Cmd): def main(): - # cats = load_ucs_categories() - print(f"Loaded UCS categories.", file=sys.stderr) model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2") - # embeddings = load_embeddings(cats, model) - print(f"Loaded embeddings...", file=sys.stderr) com = Commands() com.file_list = sys.argv[1:] @@ -213,4 +198,9 @@ def main(): if __name__ == '__main__': + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + + import warnings + warnings.simplefilter(action='ignore', category=FutureWarning) + main()