Compare commits
3 Commits: 354d1d4e40 ... 4f7b2a73cb
| SHA1 |
|---|
| 4f7b2a73cb |
| cc78291f1d |
| 7b3930ece0 |
@@ -2,6 +2,8 @@ import os
 import sys
 import csv
 
+from typing import Optional
+
 from sentence_transformers import SentenceTransformer
 import tqdm
 import click
@@ -11,19 +13,60 @@ from .inference import InferenceContext, load_ucs
 from .util import ffmpeg_description, parse_ucs
 
 
-@click.group()
+def recommend_text(text: str, ctx: InferenceContext):
+    return None
+
+
+@click.group(epilog="For more information see "
+             "<https://git.squad51.us/jamie/ucsinfer>")
 # @click.option('--verbose', flag_value='verbose', help='Verbose output')
 def ucsinfer():
+    """
+    Tools for applying UCS categories to sounds using large-language Models
+
+    """
     pass
 
 
 @ucsinfer.command('recommend')
-def recommend():
+@click.option('--text', type=Optional[str], default=None,
+              help="Recommend a category for given text instead of reading "
+              "from a file")
+@click.argument('files', nargs=-1)
+@click.option('--model', type=str, metavar="<model-name>",
+              default="paraphrase-multilingual-mpnet-base-v2",
+              show_default=True,
+              help="Select the sentence_transformer model to use")
+def recommend(text, paths, model):
     """
     Infer a UCS category for a text description
 
     """
-    pass
+    m = SentenceTransformer(model)
+    ctx = InferenceContext(m, model)
+
+    recommendations = []
+    if text is not None:
+        recommendations.append({
+            "text": text,
+            "recommendations": recommend_text(text, ctx)
+        })
+
+    for path in paths:
+        text = ffmpeg_description(path)
+        if text:
+            recommendations.append({
+                "path":path,
+                "text":text,
+                "recommendations":recommend_text(text, ctx)
+            })
+        else:
+            recommendations.append({
+                "path":path,
+                "text":None,
+                "recommendations":None
+            })
+
+
 
 
 @ucsinfer.command('gather')
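In this commit `recommend_text()` is only a stub that returns `None`, while the new `recommend` command already collects its output per file. Purely as an illustration of the kind of logic the stub could later hold (not the project's actual implementation), one way to rank UCS categories by cosine similarity with sentence-transformers might look like this; the `InferenceContext` attributes used below (`model`, `category_embeddings`, `categories`) are assumptions made only for the example:

```python
# Hypothetical sketch: the commit leaves recommend_text() returning None.
# ctx.model, ctx.category_embeddings and ctx.categories are assumed names,
# not the project's actual InferenceContext API.
from sentence_transformers import util

def recommend_text_sketch(text, ctx, top_k=5):
    # Embed the query text and rank pre-computed category embeddings
    # by cosine similarity, returning the top_k category records.
    query = ctx.model.encode(text, convert_to_tensor=True)
    scores = util.cos_sim(query, ctx.category_embeddings)[0]
    best = scores.argsort(descending=True)[:top_k]
    return [ctx.categories[int(i)] for i in best]
```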
@@ -77,13 +120,17 @@ def finetune():
 
 
 @ucsinfer.command('evaluate')
-@click.option('--offset', type=int, default=0)
-@click.option('--limit', type=int, default=-1)
+@click.option('--offset', type=int, default=0, metavar="<int>",
+              help='Skip this many records in the dataset before processing')
+@click.option('--limit', type=int, default=-1, metavar="<int>",
+              help='Process this many records and then exit')
 @click.option('--no-foley', 'no_foley', flag_value=True, default=False,
               help="Ignore any data in the set with FOLYProp or FOLYFeet "
               "category")
-@click.option('--model', type=str,
-              default="paraphrase-multilingual-mpnet-base-v2")
+@click.option('--model', type=str, metavar="<model-name>",
+              default="paraphrase-multilingual-mpnet-base-v2",
+              show_default=True,
+              help="Select the sentence_transformer model to use")
 @click.argument('dataset', type=click.File('r', encoding='utf8'),
                 default='dataset.csv')
 def evaluate(dataset, offset, limit, model, no_foley):
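The hunk above mainly adds `metavar`, `show_default`, and help text to the `evaluate` options. A small, purely illustrative way to exercise the new flags from a test (the CLI group's import path is not shown in this diff and is assumed here) could be:

```python
# Illustrative only; the module path for the `ucsinfer` click group is assumed.
from click.testing import CliRunner
from ucsinfer.__main__ import ucsinfer  # assumed import path

runner = CliRunner()
result = runner.invoke(ucsinfer, [
    "evaluate",
    "--offset", "100",   # skip the first 100 records
    "--limit", "500",    # then stop after 500 records
    "--no-foley",
    "dataset.csv",
])
print(result.output)
```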
@@ -110,8 +157,17 @@ def evaluate(dataset, offset, limit, model, no_foley):
     ctx = InferenceContext(m, model)
     reader = csv.reader(dataset)
 
+    print(f"Evaluating model {model}...")
     results = []
-    for i, row in enumerate(tqdm.tqdm(reader)):
+
+    if offset > 0:
+        print(f"Skipping {offset} records...")
+
+    if limit > 0:
+        print(f"Will only evaluate {limit} records...")
+
+    progress_bar = tqdm.tqdm(total=limit)
+    for i, row in enumerate(reader):
         if i < offset:
             continue
 
@@ -132,6 +188,8 @@ def evaluate(dataset, offset, limit, model, no_foley):
         else:
             results.append({'catid': cat_id, 'result': "MISS"})
 
+        progress_bar.update(1)
+
     total = len(results)
     total_top = len([x for x in results if x['result'] == 'TOP'])
     total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
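The diff ends at these tallies, so whatever reporting follows `total_top_5` is not shown here. Purely as a sketch, and assuming `TOP` marks a rank-1 hit while `TOP_5` marks a hit in ranks 2-5, the counts could be summarized as hit rates like this:

```python
# Sketch only: the reporting code after these tallies is outside this hunk,
# and the TOP / TOP_5 semantics are assumed as described above.
def summarize(results):
    total = len(results)
    total_top = len([x for x in results if x['result'] == 'TOP'])
    total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
    if total == 0:
        return
    print(f"top-1 accuracy: {total_top / total:.1%}")
    print(f"top-5 accuracy: {(total_top + total_top_5) / total:.1%}")

# Example: two rank-1 hits, one top-5 hit, one miss -> 50.0% / 75.0%
summarize([
    {'catid': 'A', 'result': 'TOP'},
    {'catid': 'B', 'result': 'TOP'},
    {'catid': 'C', 'result': 'TOP_5'},
    {'catid': 'D', 'result': 'MISS'},
])
```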
|