diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py
index 2fc1e06..ed20fc9 100644
--- a/ucsinfer/__main__.py
+++ b/ucsinfer/__main__.py
@@ -11,25 +11,39 @@
 from .inference import InferenceContext, load_ucs
 from .recommend import print_recommendation
 from .util import ffmpeg_description, parse_ucs
 
+logger = logging.getLogger('ucsinfer')
+logger.setLevel(logging.DEBUG)
+stream_handler = logging.StreamHandler()
+stream_handler.setLevel(logging.WARNING)
+formatter = logging.Formatter(
+    '%(asctime)s. %(levelname)s %(name)s: %(message)s')
+stream_handler.setFormatter(formatter)
+logger.addHandler(stream_handler)
+
 
 @click.group(epilog="For more information see "
              "")
-@click.option('--verbose', '-v', flag_value='verbose', help='Verbose output')
-def ucsinfer(verbose):
+@click.option('--verbose', '-v', flag_value=True, help='Verbose output')
+@click.option('--no-model-cache', flag_value=True,
+              help="Don't use local model cache")
+@click.pass_context
+def ucsinfer(ctx, verbose, no_model_cache):
     """
     Tools for applying UCS categories to sounds using large language models
     """
     if verbose:
-        logging.basicConfig(format="%(levelname)s: %(message)s",
-                            level=logging.DEBUG)
+        stream_handler.setLevel(logging.DEBUG)
+        logger.info("Verbose logging is enabled")
     else:
         import warnings
         warnings.filterwarnings(
             action='ignore', module='torch', category=FutureWarning,
             message=r"`encoder_attention_mask` is deprecated.*")
-        logging.basicConfig(format="%(levelname)s: %(message)s",
-                            level=logging.WARN)
+        stream_handler.setLevel(logging.WARNING)
+
+    ctx.ensure_object(dict)
+    ctx.obj['model_cache'] = not no_model_cache
 
 
 @ucsinfer.command('recommend')
@@ -48,7 +62,8 @@ def ucsinfer(verbose):
 @click.option('-s', '--skip-ucs', flag_value=True, default=False,
               help="Skip files that already have a UCS category in their "
                    "name.")
-def recommend(text, paths, model, interactive, skip_ucs):
+@click.pass_context
+def recommend(ctx, text, paths, model, interactive, skip_ucs):
     """
     Infer a UCS category for a text description
 
@@ -58,12 +73,15 @@ def recommend(text, paths, model, interactive, skip_ucs):
     the synonyms and explanations defined for each UCS subcategory. A list of
     ranked subcategories is printed to the terminal for each PATH.
     """
-    ctx = InferenceContext(model)
+    logger.debug("RECOMMEND mode")
+    inference_ctx = InferenceContext(model,
+                                     use_cached_model=ctx.obj['model_cache'])
 
     if text is not None:
-        print_recommendation(None, text, ctx, interactive_rename=False)
+        print_recommendation(None, text, inference_ctx,
+                             interactive_rename=False)
 
-    catlist = [x.catid for x in ctx.catlist]
+    catlist = [x.catid for x in inference_ctx.catlist]
 
     for path in paths:
         basename = os.path.basename(path)
@@ -75,7 +93,7 @@ def recommend(text, paths, model, interactive, skip_ucs):
             text = os.path.basename(path)
 
         while True:
-            retval = print_recommendation(path, text, ctx, interactive)
+            retval = print_recommendation(path, text, inference_ctx, interactive)
             if not retval:
                 break
             if retval[0] is False:
@@ -109,6 +127,8 @@
     is the CatID indicated for the file, and the second is the embedded file
     description for the file as returned by ffprobe.
     """
+    logger.debug("GATHER mode")
+
     types = ['.wav', '.flac']
     table = csv.writer(outfile)
     print(f"Loading category list...")
@@ -116,7 +136,7 @@
     scan_list = []
 
     for path in paths:
-        print(f"Scanning directory {path}...", file=sys.stdout)
+        logger.info(f"Scanning directory {path}...")
         for dirpath, _, filenames in os.walk(path):
             for filename in filenames:
                 root, ext = os.path.splitext(filename)
@@ -126,7 +146,7 @@
                     scan_list.append((ucs_components.cat_id,
                                       os.path.join(dirpath, filename)))
 
-    print(f"Found {len(scan_list)} files to process.")
+    logger.info(f"Found {len(scan_list)} files to process.")
 
     for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
         if desc := ffmpeg_description(pair[1]):
@@ -138,7 +158,8 @@ def finetune():
     """
     Fine-tune a model with training data
     """
-    pass
+    logger.debug("FINETUNE mode")
+
 
 @ucsinfer.command('evaluate')
@@ -155,7 +176,8 @@ def finetune():
               help="Select the sentence_transformer model to use")
 @click.argument('dataset', type=click.File('r', encoding='utf8'),
                 default='dataset.csv')
-def evaluate(dataset, offset, limit, model, no_foley):
+@click.pass_context
+def evaluate(ctx, dataset, offset, limit, model, no_foley):
     """
     Use datasets to evaluate model performance
 
@@ -175,19 +197,24 @@ def evaluate(dataset, offset, limit, model, no_foley):
     classified according to their subject and not whether or not they were
     foley, and so these categories can be excluded with the --no-foley option.
     """
-    ctx = InferenceContext(model)
+    logger.debug("EVALUATE mode")
+    inference_context = InferenceContext(
+        model,
+        use_cached_model=ctx.obj['model_cache'])
     reader = csv.reader(dataset)
 
-    print(f"Evaluating model {model}...")
+    logger.info(f"Evaluating model {model}...")
 
     results = []
 
     if offset > 0:
-        print(f"Skipping {offset} records...")
+        logger.debug(f"Skipping {offset} records...")
 
     if limit > 0:
-        print(f"Will only evaluate {limit} records...")
+        logger.debug(f"Will only evaluate {limit} records...")
 
-    progress_bar = tqdm.tqdm(total=limit)
+    progress_bar = tqdm.tqdm(total=limit,
+                             desc="Processing dataset...",
+                             unit="rec")
     for i, row in enumerate(reader):
         if i < offset:
             continue
@@ -199,7 +226,7 @@
         if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
             continue
 
-        guesses = ctx.classify_text_ranked(description, limit=10)
+        guesses = inference_context.classify_text_ranked(description, limit=10)
         if cat_id == guesses[0]:
             results.append({'catid': cat_id, 'result': "TOP"})
         elif cat_id in guesses[0:5]:
@@ -241,9 +268,9 @@
         ["Top 10 Result:", f"{total_top_10}",
          f"{float(total_top_10)/float(total):.2%}"],
         SEPARATING_LINE,
-        ["UCS category count:", f"{len(ctx.catlist)}"],
+        ["UCS category count:", f"{len(inference_context.catlist)}"],
         ["Total categories in sample:", f"{total_cats}",
-         f"{float(total_cats)/float(len(ctx.catlist)):.2%}"],
+         f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
         [f"Most missed category ({miss_counts[-1][0]}):",
          f"{miss_counts[-1][1]}",
          f"{float(miss_counts[-1][1])/float(total):.2%}"]
@@ -255,4 +282,4 @@
 
 if __name__ == '__main__':
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-    ucsinfer()
+    ucsinfer(obj={})
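Note on the Click wiring above: the group callback now stores shared state on the context object, and the entry point seeds it with `ucsinfer(obj={})`. A minimal sketch of that plumbing, reduced to just the cache flag (the `cli` and `recommend` bodies here are illustrative, not the patch's code):

    import click

    @click.group()
    @click.option('--no-model-cache', flag_value=True,
                  help="Don't use local model cache")
    @click.pass_context
    def cli(ctx, no_model_cache):
        ctx.ensure_object(dict)                      # create ctx.obj if absent
        ctx.obj['model_cache'] = not no_model_cache  # store the positive sense

    @cli.command('recommend')
    @click.pass_context
    def recommend(ctx):
        # Subcommands read the shared setting back off ctx.obj.
        click.echo(f"model cache enabled: {ctx.obj['model_cache']}")

    if __name__ == '__main__':
        cli(obj={})  # seed obj explicitly, as __main__.py now does

Storing the positive sense (`model_cache`) rather than the raw negative flag keeps every consumer from double-negating `no_model_cache`.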
""" + logger.debug("GATHER mode") + types = ['.wav', '.flac'] table = csv.writer(outfile) print(f"Loading category list...") @@ -116,7 +136,7 @@ def gather(paths, outfile): scan_list = [] for path in paths: - print(f"Scanning directory {path}...", file=sys.stdout) + logger.info(f"Scanning directory {path}...") for dirpath, _, filenames in os.walk(path): for filename in filenames: root, ext = os.path.splitext(filename) @@ -126,7 +146,7 @@ def gather(paths, outfile): scan_list.append((ucs_components.cat_id, os.path.join(dirpath, filename))) - print(f"Found {len(scan_list)} files to process.") + logger.info(f"Found {len(scan_list)} files to process.") for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr): if desc := ffmpeg_description(pair[1]): @@ -138,7 +158,8 @@ def finetune(): """ Fine-tune a model with training data """ - pass + logger.debug("FINETUNE mode") + @ucsinfer.command('evaluate') @@ -155,7 +176,8 @@ def finetune(): help="Select the sentence_transformer model to use") @click.argument('dataset', type=click.File('r', encoding='utf8'), default='dataset.csv') -def evaluate(dataset, offset, limit, model, no_foley): +@click.pass_context +def evaluate(ctx, dataset, offset, limit, model, no_foley): """ Use datasets to evaluate model performance @@ -175,19 +197,24 @@ def evaluate(dataset, offset, limit, model, no_foley): classified according to their subject and not wether or not they were foley, and so these categories can be excluded with the --no-foley option. """ - ctx = InferenceContext(model) + logger.debug("EVALUATE mode") + inference_context = InferenceContext(model, + use_cached_model= + ctx.obj['model_cache']) reader = csv.reader(dataset) - print(f"Evaluating model {model}...") + logger.info(f"Evaluating model {model}...") results = [] if offset > 0: - print(f"Skipping {offset} records...") + logger.debug(f"Skipping {offset} records...") if limit > 0: - print(f"Will only evaluate {limit} records...") + logger.debug(f"Will only evaluate {limit} records...") - progress_bar = tqdm.tqdm(total=limit) + progress_bar = tqdm.tqdm(total=limit, + desc="Processing dataset...", + unit="rec") for i, row in enumerate(reader): if i < offset: continue @@ -199,7 +226,7 @@ def evaluate(dataset, offset, limit, model, no_foley): if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']: continue - guesses = ctx.classify_text_ranked(description, limit=10) + guesses = inference_context.classify_text_ranked(description, limit=10) if cat_id == guesses[0]: results.append({'catid': cat_id, 'result': "TOP"}) elif cat_id in guesses[0:5]: @@ -241,9 +268,9 @@ def evaluate(dataset, offset, limit, model, no_foley): ["Top 10 Result:", f"{total_top_10}", f"{float(total_top_10)/float(total):.2%}"], SEPARATING_LINE, - ["UCS category count:", f"{len(ctx.catlist)}"], + ["UCS category count:", f"{len(inference_context.catlist)}"], ["Total categories in sample:", f"{total_cats}", - f"{float(total_cats)/float(len(ctx.catlist)):.2%}"], + f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"], [f"Most missed category ({miss_counts[-1][0]}):", f"{miss_counts[-1][1]}", f"{float(miss_counts[-1][1])/float(total):.2%}"] @@ -255,4 +282,4 @@ def evaluate(dataset, offset, limit, model, no_foley): if __name__ == '__main__': os.environ['TOKENIZERS_PARALLELISM'] = 'false' - ucsinfer() + ucsinfer(obj={}) diff --git a/ucsinfer/inference.py b/ucsinfer/inference.py index b5fc4a5..353b8a3 100644 --- a/ucsinfer/inference.py +++ b/ucsinfer/inference.py @@ -54,9 +54,24 @@ class InferenceContext: model: SentenceTransformer 
@@ -66,14 +81,14 @@ class InferenceContext:
     def embeddings(self) -> list[dict]:
         cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
                                                 'Squad 51')
-        embedding_cache = os.path.join(
+        embedding_cache_path = os.path.join(
             cache_dir,
             f"{self.model_name}-ucs_embedding.cache")
 
         embeddings = []
 
-        if os.path.exists(embedding_cache):
-            with open(embedding_cache, 'rb') as f:
+        if os.path.exists(embedding_cache_path):
+            with open(embedding_cache_path, 'rb') as f:
                 embeddings = pickle.load(f)
 
         else:
@@ -85,8 +100,8 @@
                     'Embedding': self._encode_category(cat_defn)
                 }]
 
-            os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
-            with open(embedding_cache, 'wb') as g:
+            os.makedirs(os.path.dirname(embedding_cache_path), exist_ok=True)
+            with open(embedding_cache_path, 'wb') as g:
                 pickle.dump(embeddings, g)
 
         return embeddings
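With the rename, the pickled category embeddings and the saved model follow the same `*_path` naming and live under the same cache directory. A hedged usage sketch of the extended constructor (the model name is only an example):

    from ucsinfer.inference import InferenceContext

    # First construction resolves the name through sentence-transformers
    # and saves a copy under the user cache dir; later constructions load
    # that saved copy instead.
    ctx = InferenceContext('all-MiniLM-L6-v2')

    # use_cached_model=False mirrors the new --no-model-cache flag and
    # bypasses the saved copy entirely.
    fresh = InferenceContext('all-MiniLM-L6-v2', use_cached_model=False)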