tweaks
This commit is contained in:
10
TODO.md
10
TODO.md
@@ -2,17 +2,15 @@
|
|||||||
|
|
||||||
- Use History when adding catids
|
- Use History when adding catids
|
||||||
|
|
||||||
|
|
||||||
## Gather
|
## Gather
|
||||||
|
|
||||||
- Maybe more dataset configurations
|
- Maybe more dataset configurations
|
||||||
|
|
||||||
## Validate
|
|
||||||
|
|
||||||
A function for validating a dataset for finetuning
|
|
||||||
|
|
||||||
## Fine-tune
|
## Fine-tune
|
||||||
|
|
||||||
- Implement
|
- Implement BatchAllTripletLoss
|
||||||
|
|
||||||
|
|
||||||
## Evaluate
|
## Evaluate
|
||||||
@@ -22,6 +20,8 @@ A function for validating a dataset for finetuning
|
|||||||
- Print raw output
|
- Print raw output
|
||||||
- Maybe load everything into a sqlite for slicker reporting
|
- Maybe load everything into a sqlite for slicker reporting
|
||||||
|
|
||||||
|
|
||||||
## Utility
|
## Utility
|
||||||
|
|
||||||
- Dataset partitioning
|
- Clear caches
|
||||||
|
|
||||||
|
@@ -1,9 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import csv
|
import csv
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from typing import Generator
|
|
||||||
|
|
||||||
import tqdm
|
import tqdm
|
||||||
import click
|
import click
|
||||||
@@ -26,10 +24,14 @@ logger.addHandler(stream_handler)
|
|||||||
@click.group(epilog="For more information see "
|
@click.group(epilog="For more information see "
|
||||||
"<https://git.squad51.us/jamie/ucsinfer>")
|
"<https://git.squad51.us/jamie/ucsinfer>")
|
||||||
@click.option('--verbose', '-v', flag_value=True, help='Verbose output')
|
@click.option('--verbose', '-v', flag_value=True, help='Verbose output')
|
||||||
|
@click.option('--model', type=str, metavar="<model-name>",
|
||||||
|
default="paraphrase-multilingual-mpnet-base-v2",
|
||||||
|
show_default=True,
|
||||||
|
help="Select the sentence_transformer model to use")
|
||||||
@click.option('--no-model-cache', flag_value=True,
|
@click.option('--no-model-cache', flag_value=True,
|
||||||
help="Don't use local model cache")
|
help="Don't use local model cache")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def ucsinfer(ctx, verbose, no_model_cache):
|
def ucsinfer(ctx, verbose, no_model_cache, model):
|
||||||
"""
|
"""
|
||||||
Tools for applying UCS categories to sounds using large-language Models
|
Tools for applying UCS categories to sounds using large-language Models
|
||||||
"""
|
"""
|
||||||
@@ -47,6 +49,12 @@ def ucsinfer(ctx, verbose, no_model_cache):
|
|||||||
|
|
||||||
ctx.ensure_object(dict)
|
ctx.ensure_object(dict)
|
||||||
ctx.obj['model_cache'] = not no_model_cache
|
ctx.obj['model_cache'] = not no_model_cache
|
||||||
|
ctx.obj['model_name'] = model
|
||||||
|
|
||||||
|
if no_model_cache:
|
||||||
|
logger.info("Model cache inhibited by config")
|
||||||
|
|
||||||
|
logger.info("Using model {model}")
|
||||||
|
|
||||||
|
|
||||||
@ucsinfer.command('recommend')
|
@ucsinfer.command('recommend')
|
||||||
@@ -54,10 +62,6 @@ def ucsinfer(ctx, verbose, no_model_cache):
|
|||||||
help="Recommend a category for given text instead of reading "
|
help="Recommend a category for given text instead of reading "
|
||||||
"from a file")
|
"from a file")
|
||||||
@click.argument('paths', nargs=-1, metavar='<paths>')
|
@click.argument('paths', nargs=-1, metavar='<paths>')
|
||||||
@click.option('--model', type=str, metavar="<model-name>",
|
|
||||||
default="paraphrase-multilingual-mpnet-base-v2",
|
|
||||||
show_default=True,
|
|
||||||
help="Select the sentence_transformer model to use")
|
|
||||||
@click.option('--interactive','-i', flag_value=True, default=False,
|
@click.option('--interactive','-i', flag_value=True, default=False,
|
||||||
help="After processing each path in <paths>, prompt for a "
|
help="After processing each path in <paths>, prompt for a "
|
||||||
"recommendation to accept, and then prepend the selection to "
|
"recommendation to accept, and then prepend the selection to "
|
||||||
@@ -66,7 +70,7 @@ def ucsinfer(ctx, verbose, no_model_cache):
|
|||||||
help="Skip files that already have a UCS category in their "
|
help="Skip files that already have a UCS category in their "
|
||||||
"name.")
|
"name.")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def recommend(ctx, text, paths, model, interactive, skip_ucs):
|
def recommend(ctx, text, paths, interactive, skip_ucs):
|
||||||
"""
|
"""
|
||||||
Infer a UCS category for a text description
|
Infer a UCS category for a text description
|
||||||
|
|
||||||
@@ -77,7 +81,7 @@ def recommend(ctx, text, paths, model, interactive, skip_ucs):
|
|||||||
of ranked subcategories is printed to the terminal for each PATH.
|
of ranked subcategories is printed to the terminal for each PATH.
|
||||||
"""
|
"""
|
||||||
logger.debug("RECOMMEND mode")
|
logger.debug("RECOMMEND mode")
|
||||||
inference_ctx = InferenceContext(model,
|
inference_ctx = InferenceContext(ctx.obj['model_name'],
|
||||||
use_cached_model=ctx.obj['model_cache'])
|
use_cached_model=ctx.obj['model_cache'])
|
||||||
|
|
||||||
if text is not None:
|
if text is not None:
|
||||||
@@ -167,7 +171,8 @@ def gather(paths, outfile):
|
|||||||
|
|
||||||
|
|
||||||
@ucsinfer.command('finetune')
|
@ucsinfer.command('finetune')
|
||||||
def finetune():
|
@click.pass_context
|
||||||
|
def finetune(ctx):
|
||||||
"""
|
"""
|
||||||
Fine-tune a model with training data
|
Fine-tune a model with training data
|
||||||
"""
|
"""
|
||||||
@@ -183,14 +188,10 @@ def finetune():
|
|||||||
@click.option('--no-foley', 'no_foley', flag_value=True, default=False,
|
@click.option('--no-foley', 'no_foley', flag_value=True, default=False,
|
||||||
help="Ignore any data in the set with FOLYProp or FOLYFeet "
|
help="Ignore any data in the set with FOLYProp or FOLYFeet "
|
||||||
"category")
|
"category")
|
||||||
@click.option('--model', type=str, metavar="<model-name>",
|
|
||||||
default="paraphrase-multilingual-mpnet-base-v2",
|
|
||||||
show_default=True,
|
|
||||||
help="Select the sentence_transformer model to use")
|
|
||||||
@click.argument('dataset', type=click.File('r', encoding='utf8'),
|
@click.argument('dataset', type=click.File('r', encoding='utf8'),
|
||||||
default='dataset.csv')
|
default='dataset.csv')
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
def evaluate(ctx, dataset, offset, limit, model, no_foley):
|
def evaluate(ctx, dataset, offset, limit, no_foley):
|
||||||
"""
|
"""
|
||||||
Use datasets to evaluate model performance
|
Use datasets to evaluate model performance
|
||||||
|
|
||||||
@@ -211,12 +212,11 @@ def evaluate(ctx, dataset, offset, limit, model, no_foley):
|
|||||||
foley, and so these categories can be excluded with the --no-foley option.
|
foley, and so these categories can be excluded with the --no-foley option.
|
||||||
"""
|
"""
|
||||||
logger.debug("EVALUATE mode")
|
logger.debug("EVALUATE mode")
|
||||||
inference_context = InferenceContext(model,
|
inference_context = InferenceContext(ctx.obj['model_name'],
|
||||||
use_cached_model=
|
use_cached_model=
|
||||||
ctx.obj['model_cache'])
|
ctx.obj['model_cache'])
|
||||||
reader = csv.reader(dataset)
|
reader = csv.reader(dataset)
|
||||||
|
|
||||||
logger.info(f"Evaluating model {model}...")
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
if offset > 0:
|
if offset > 0:
|
||||||
|
Reference in New Issue
Block a user