commit 46a693bf93
parent 0a2aaa2a22
2025-09-03 14:23:09 -07:00
2 changed files with 22 additions and 22 deletions

TODO.md

@@ -2,17 +2,15 @@
 - Use History when adding catids
 ## Gather
 - Maybe more dataset configurations
-## Validate
-A function for validating a dataset for finetuning
 ## Fine-tune
-- Implement
+- Implement BatchAllTripletLoss
 ## Evaluate
@@ -22,6 +20,8 @@ A function for validating a dataset for finetuning
 - Print raw output
 - Maybe load everything into a sqlite for slicker reporting
 ## Utility
-- Dataset partitioning
+- Clear caches
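
The fine-tune item now names BatchAllTripletLoss, which is an existing loss in the sentence-transformers package. Below is a rough sketch of how that loss is typically wired up; the dataset layout, the `rows` shape, and the label mapping are assumptions for illustration, not taken from this repository.

```python
# Minimal sketch of BatchAllTripletLoss fine-tuning with sentence-transformers.
# Assumption: each training row is a (description, ucs_category) pair; the real
# dataset.csv layout used by ucsinfer is not shown in this commit.
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

def finetune_sketch(rows, model_name="paraphrase-multilingual-mpnet-base-v2"):
    model = SentenceTransformer(model_name)
    # Map each UCS category to an integer label, as the triplet loss expects
    labels = {cat: i for i, cat in enumerate(sorted({cat for _, cat in rows}))}
    examples = [InputExample(texts=[text], label=labels[cat]) for text, cat in rows]
    loader = DataLoader(examples, shuffle=True, batch_size=32)
    loss = losses.BatchAllTripletLoss(model=model)
    model.fit(train_objectives=[(loader, loss)], epochs=1, warmup_steps=100)
    return model
```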

ucsinfer CLI module (file name not shown)

@@ -1,9 +1,7 @@
 import os
-import sys
 import csv
 import logging
-from typing import Generator
 import tqdm
 import click
@@ -26,10 +24,14 @@ logger.addHandler(stream_handler)
 @click.group(epilog="For more information see "
              "<https://git.squad51.us/jamie/ucsinfer>")
 @click.option('--verbose', '-v', flag_value=True, help='Verbose output')
+@click.option('--model', type=str, metavar="<model-name>",
+              default="paraphrase-multilingual-mpnet-base-v2",
+              show_default=True,
+              help="Select the sentence_transformer model to use")
 @click.option('--no-model-cache', flag_value=True,
               help="Don't use local model cache")
 @click.pass_context
-def ucsinfer(ctx, verbose, no_model_cache):
+def ucsinfer(ctx, verbose, no_model_cache, model):
     """
     Tools for applying UCS categories to sounds using large-language Models
     """
@@ -47,6 +49,12 @@ def ucsinfer(ctx, verbose, no_model_cache):
     ctx.ensure_object(dict)
     ctx.obj['model_cache'] = not no_model_cache
+    ctx.obj['model_name'] = model
+    if no_model_cache:
+        logger.info("Model cache inhibited by config")
+    logger.info("Using model {model}")
 @ucsinfer.command('recommend')
@@ -54,10 +62,6 @@ def ucsinfer(ctx, verbose, no_model_cache):
               help="Recommend a category for given text instead of reading "
                    "from a file")
 @click.argument('paths', nargs=-1, metavar='<paths>')
-@click.option('--model', type=str, metavar="<model-name>",
-              default="paraphrase-multilingual-mpnet-base-v2",
-              show_default=True,
-              help="Select the sentence_transformer model to use")
 @click.option('--interactive','-i', flag_value=True, default=False,
               help="After processing each path in <paths>, prompt for a "
                    "recommendation to accept, and then prepend the selection to "
@@ -66,7 +70,7 @@ def ucsinfer(ctx, verbose, no_model_cache):
               help="Skip files that already have a UCS category in their "
                    "name.")
 @click.pass_context
-def recommend(ctx, text, paths, model, interactive, skip_ucs):
+def recommend(ctx, text, paths, interactive, skip_ucs):
     """
     Infer a UCS category for a text description
@@ -77,7 +81,7 @@ def recommend(ctx, text, paths, model, interactive, skip_ucs):
     of ranked subcategories is printed to the terminal for each PATH.
     """
     logger.debug("RECOMMEND mode")
-    inference_ctx = InferenceContext(model,
+    inference_ctx = InferenceContext(ctx.obj['model_name'],
                                      use_cached_model=ctx.obj['model_cache'])
     if text is not None:
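
The substantive change in this file is that --model moves from the individual subcommands up to the ucsinfer group and is handed to each subcommand through ctx.obj. Here is a minimal, self-contained sketch of that click pattern; the group and command names ("cli", "greet") are hypothetical and not part of ucsinfer.

```python
# Stand-alone illustration of a group-level click option shared via ctx.obj.
import click

@click.group()
@click.option('--model', default='paraphrase-multilingual-mpnet-base-v2',
              show_default=True, help='Shared model name for all subcommands')
@click.pass_context
def cli(ctx, model):
    ctx.ensure_object(dict)          # make sure ctx.obj is a dict
    ctx.obj['model_name'] = model    # stash the group-level option

@cli.command()
@click.pass_context
def greet(ctx):
    click.echo(f"using {ctx.obj['model_name']}")  # subcommand reads it back

if __name__ == '__main__':
    cli()
```

One practical consequence of this refactor is that the option now has to be given before the subcommand (ucsinfer --model <model-name> recommend ...) rather than after it.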
@@ -167,7 +171,8 @@ def gather(paths, outfile):
 @ucsinfer.command('finetune')
-def finetune():
+@click.pass_context
+def finetune(ctx):
     """
     Fine-tune a model with training data
     """
@@ -183,14 +188,10 @@ def finetune():
 @click.option('--no-foley', 'no_foley', flag_value=True, default=False,
               help="Ignore any data in the set with FOLYProp or FOLYFeet "
                    "category")
-@click.option('--model', type=str, metavar="<model-name>",
-              default="paraphrase-multilingual-mpnet-base-v2",
-              show_default=True,
-              help="Select the sentence_transformer model to use")
 @click.argument('dataset', type=click.File('r', encoding='utf8'),
                 default='dataset.csv')
 @click.pass_context
-def evaluate(ctx, dataset, offset, limit, model, no_foley):
+def evaluate(ctx, dataset, offset, limit, no_foley):
     """
     Use datasets to evaluate model performance
@@ -211,12 +212,11 @@ def evaluate(ctx, dataset, offset, limit, model, no_foley):
     foley, and so these categories can be excluded with the --no-foley option.
     """
     logger.debug("EVALUATE mode")
-    inference_context = InferenceContext(model,
+    inference_context = InferenceContext(ctx.obj['model_name'],
                                          use_cached_model=
                                          ctx.obj['model_cache'])
     reader = csv.reader(dataset)
-    logger.info(f"Evaluating model {model}...")
     results = []
     if offset > 0:
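
For completeness, a hedged usage sketch driving the refactored CLI through click's test runner. The import path of the ucsinfer group is an assumption about the package layout, and the call assumes a dataset.csv exists in the working directory.

```python
# Exercising the group-level --model option with click's CliRunner.
from click.testing import CliRunner
from ucsinfer.__main__ import ucsinfer  # assumed module path, adjust as needed

runner = CliRunner()
# --model now precedes the subcommand name because it belongs to the group
result = runner.invoke(ucsinfer, ['--model', 'paraphrase-multilingual-mpnet-base-v2',
                                  'evaluate', 'dataset.csv'])
print(result.exit_code, result.output)
```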