diff --git a/TODO.md b/TODO.md index ec7377f..6707ce9 100644 --- a/TODO.md +++ b/TODO.md @@ -2,17 +2,15 @@ - Use History when adding catids + ## Gather - Maybe more dataset configurations -## Validate - -A function for validating a dataset for finetuning ## Fine-tune -- Implement +- Implement BatchAllTripletLoss ## Evaluate @@ -22,6 +20,8 @@ A function for validating a dataset for finetuning - Print raw output - Maybe load everything into a sqlite for slicker reporting + ## Utility -- Dataset partitioning +- Clear caches + diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py index 3e8de5a..b669c4c 100644 --- a/ucsinfer/__main__.py +++ b/ucsinfer/__main__.py @@ -1,9 +1,7 @@ import os -import sys import csv import logging -from typing import Generator import tqdm import click @@ -26,10 +24,14 @@ logger.addHandler(stream_handler) @click.group(epilog="For more information see " "") @click.option('--verbose', '-v', flag_value=True, help='Verbose output') +@click.option('--model', type=str, metavar="", + default="paraphrase-multilingual-mpnet-base-v2", + show_default=True, + help="Select the sentence_transformer model to use") @click.option('--no-model-cache', flag_value=True, help="Don't use local model cache") @click.pass_context -def ucsinfer(ctx, verbose, no_model_cache): +def ucsinfer(ctx, verbose, no_model_cache, model): """ Tools for applying UCS categories to sounds using large-language Models """ @@ -47,6 +49,12 @@ def ucsinfer(ctx, verbose, no_model_cache): ctx.ensure_object(dict) ctx.obj['model_cache'] = not no_model_cache + ctx.obj['model_name'] = model + + if no_model_cache: + logger.info("Model cache inhibited by config") + + logger.info("Using model {model}") @ucsinfer.command('recommend') @@ -54,10 +62,6 @@ def ucsinfer(ctx, verbose, no_model_cache): help="Recommend a category for given text instead of reading " "from a file") @click.argument('paths', nargs=-1, metavar='') -@click.option('--model', type=str, metavar="", - default="paraphrase-multilingual-mpnet-base-v2", - show_default=True, - help="Select the sentence_transformer model to use") @click.option('--interactive','-i', flag_value=True, default=False, help="After processing each path in , prompt for a " "recommendation to accept, and then prepend the selection to " @@ -66,7 +70,7 @@ def ucsinfer(ctx, verbose, no_model_cache): help="Skip files that already have a UCS category in their " "name.") @click.pass_context -def recommend(ctx, text, paths, model, interactive, skip_ucs): +def recommend(ctx, text, paths, interactive, skip_ucs): """ Infer a UCS category for a text description @@ -77,7 +81,7 @@ def recommend(ctx, text, paths, model, interactive, skip_ucs): of ranked subcategories is printed to the terminal for each PATH. """ logger.debug("RECOMMEND mode") - inference_ctx = InferenceContext(model, + inference_ctx = InferenceContext(ctx.obj['model_name'], use_cached_model=ctx.obj['model_cache']) if text is not None: @@ -167,7 +171,8 @@ def gather(paths, outfile): @ucsinfer.command('finetune') -def finetune(): +@click.pass_context +def finetune(ctx): """ Fine-tune a model with training data """ @@ -183,14 +188,10 @@ def finetune(): @click.option('--no-foley', 'no_foley', flag_value=True, default=False, help="Ignore any data in the set with FOLYProp or FOLYFeet " "category") -@click.option('--model', type=str, metavar="", - default="paraphrase-multilingual-mpnet-base-v2", - show_default=True, - help="Select the sentence_transformer model to use") @click.argument('dataset', type=click.File('r', encoding='utf8'), default='dataset.csv') @click.pass_context -def evaluate(ctx, dataset, offset, limit, model, no_foley): +def evaluate(ctx, dataset, offset, limit, no_foley): """ Use datasets to evaluate model performance @@ -211,12 +212,11 @@ def evaluate(ctx, dataset, offset, limit, model, no_foley): foley, and so these categories can be excluded with the --no-foley option. """ logger.debug("EVALUATE mode") - inference_context = InferenceContext(model, + inference_context = InferenceContext(ctx.obj['model_name'], use_cached_model= ctx.obj['model_cache']) reader = csv.reader(dataset) - logger.info(f"Evaluating model {model}...") results = [] if offset > 0: