Compare commits

...

5 Commits

SHA1 Message Date
336d6f013e removed print message 2025-09-03 12:24:31 -07:00
2a0b0367f8 Improved logging 2025-09-03 12:23:44 -07:00
184edcb7e4 Removed unnecessary parameter 2025-09-03 11:42:10 -07:00
d330f47462 Added some logging code 2025-09-03 11:36:38 -07:00
2fcdc24699 Made the warnings filter more specific 2025-09-03 11:16:51 -07:00
2 changed files with 89 additions and 38 deletions

View File

@@ -1,8 +1,8 @@
 import os
 import sys
 import csv
+import logging
-from sentence_transformers import SentenceTransformer
 import tqdm
 import click
 from tabulate import tabulate, SEPARATING_LINE
@@ -11,14 +11,39 @@ from .inference import InferenceContext, load_ucs
 from .recommend import print_recommendation
 from .util import ffmpeg_description, parse_ucs
 
+logger = logging.getLogger('ucsinfer')
+logger.setLevel(logging.DEBUG)
+
+stream_handler = logging.StreamHandler()
+stream_handler.setLevel(logging.WARN)
+formatter = logging.Formatter(
+    '%(asctime)s. %(levelname)s %(name)s: %(message)s')
+stream_handler.setFormatter(formatter)
+logger.addHandler(stream_handler)
+
+
 @click.group(epilog="For more information see "
              "<https://git.squad51.us/jamie/ucsinfer>")
-# @click.option('--verbose', flag_value='verbose', help='Verbose output')
-def ucsinfer():
+@click.option('--verbose', '-v', flag_value=True, help='Verbose output')
+@click.option('--no-model-cache', flag_value=True,
+              help="Don't use local model cache")
+@click.pass_context
+def ucsinfer(ctx, verbose, no_model_cache):
     """
     Tools for applying UCS categories to sounds using large-language Models
     """
-    pass
+    if verbose:
+        stream_handler.setLevel(logging.DEBUG)
+        logger.info("Verbose logging is enabled")
+    else:
+        import warnings
+        warnings.filterwarnings(
+            action='ignore', module='torch', category=FutureWarning,
+            message=r"`encoder_attention_mask` is deprecated.*")
+        stream_handler.setLevel(logging.WARNING)
+
+    ctx.ensure_object(dict)
+    ctx.obj['model_cache'] = not no_model_cache
 
 
 @ucsinfer.command('recommend')
@@ -37,23 +62,26 @@ def ucsinfer():
 @click.option('-s', '--skip-ucs', flag_value=True, default=False,
               help="Skip files that already have a UCS category in their "
                    "name.")
-def recommend(text, paths, model, interactive, skip_ucs):
+@click.pass_context
+def recommend(ctx, text, paths, model, interactive, skip_ucs):
     """
     Infer a UCS category for a text description
 
     "Description" text metadata is extracted from audio files given as PATHS,
     or text can be provided directly using the "--text" option. The selected
     model is then used to attempt to classify the given text according to
-    the synonyms and explanations definied for each UCS subcategory. A list
+    the synonyms an explanations definied for each UCS subcategory. A list
     of ranked subcategories is printed to the terminal for each PATH.
     """
-    m = SentenceTransformer(model)
-    ctx = InferenceContext(m, model)
+    logger.debug("RECOMMEND mode")
+    inference_ctx = InferenceContext(model,
+                                     use_cached_model=ctx.obj['model_cache'])
 
     if text is not None:
-        print_recommendation(None, text, ctx, interactive_rename=False)
+        print_recommendation(None, text, inference_ctx,
+                             interactive_rename=False)
 
-    catlist = [x.catid for x in ctx.catlist]
+    catlist = [x.catid for x in inference_ctx.catlist]
 
     for path in paths:
         basename = os.path.basename(path)
@@ -65,7 +93,7 @@ def recommend(text, paths, model, interactive, skip_ucs):
             text = os.path.basename(path)
 
         while True:
-            retval = print_recommendation(path, text, ctx, interactive)
+            retval = print_recommendation(path, text, inference_ctx, interactive)
 
             if not retval:
                 break
             if retval[0] is False:
@@ -99,14 +127,16 @@ def gather(paths, outfile):
     is the CatID indicated for the file, and the second is the embedded file
     description for the file as returned by ffprobe.
     """
+    logger.debug("GATHER mode")
     types = ['.wav', '.flac']
     table = csv.writer(outfile)
 
-    print(f"Loading category list...")
+    logger.debug(f"Loading category list...")
     catid_list = [cat.catid for cat in load_ucs()]
     scan_list = []
 
     for path in paths:
-        print(f"Scanning directory {path}...", file=sys.stdout)
+        logger.info(f"Scanning directory {path}...")
         for dirpath, _, filenames in os.walk(path):
             for filename in filenames:
                 root, ext = os.path.splitext(filename)
@@ -116,7 +146,7 @@ def gather(paths, outfile):
                     scan_list.append((ucs_components.cat_id,
                                       os.path.join(dirpath, filename)))
 
-    print(f"Found {len(scan_list)} files to process.")
+    logger.info(f"Found {len(scan_list)} files to process.")
 
     for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
         if desc := ffmpeg_description(pair[1]):
@@ -128,7 +158,8 @@ def finetune():
     """
     Fine-tune a model with training data
     """
-    pass
+    logger.debug("FINETUNE mode")
 
 
 @ucsinfer.command('evaluate')
@@ -145,7 +176,8 @@ def finetune():
               help="Select the sentence_transformer model to use")
 @click.argument('dataset', type=click.File('r', encoding='utf8'),
                 default='dataset.csv')
-def evaluate(dataset, offset, limit, model, no_foley):
+@click.pass_context
+def evaluate(ctx, dataset, offset, limit, model, no_foley):
     """
     Use datasets to evaluate model performance
@@ -165,20 +197,24 @@ def evaluate(dataset, offset, limit, model, no_foley):
     classified according to their subject and not wether or not they were
     foley, and so these categories can be excluded with the --no-foley option.
     """
-    m = SentenceTransformer(model)
-    ctx = InferenceContext(m, model)
+    logger.debug("EVALUATE mode")
+    inference_context = InferenceContext(model,
+                                         use_cached_model=
+                                         ctx.obj['model_cache'])
     reader = csv.reader(dataset)
 
-    print(f"Evaluating model {model}...")
+    logger.info(f"Evaluating model {model}...")
     results = []
 
    if offset > 0:
-        print(f"Skipping {offset} records...")
+        logger.debug(f"Skipping {offset} records...")
     if limit > 0:
-        print(f"Will only evaluate {limit} records...")
+        logger.debug(f"Will only evaluate {limit} records...")
 
-    progress_bar = tqdm.tqdm(total=limit)
+    progress_bar = tqdm.tqdm(total=limit,
+                             desc="Processing dataset...",
+                             unit="rec")
 
     for i, row in enumerate(reader):
         if i < offset:
             continue
@@ -190,7 +226,7 @@ def evaluate(dataset, offset, limit, model, no_foley):
         if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
             continue
 
-        guesses = ctx.classify_text_ranked(description, limit=10)
+        guesses = inference_context.classify_text_ranked(description, limit=10)
 
         if cat_id == guesses[0]:
             results.append({'catid': cat_id, 'result': "TOP"})
         elif cat_id in guesses[0:5]:
@@ -232,9 +268,9 @@ def evaluate(dataset, offset, limit, model, no_foley):
         ["Top 10 Result:", f"{total_top_10}",
          f"{float(total_top_10)/float(total):.2%}"],
         SEPARATING_LINE,
-        ["UCS category count:", f"{len(ctx.catlist)}"],
+        ["UCS category count:", f"{len(inference_context.catlist)}"],
         ["Total categories in sample:", f"{total_cats}",
-         f"{float(total_cats)/float(len(ctx.catlist)):.2%}"],
+         f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
         [f"Most missed category ({miss_counts[-1][0]}):",
          f"{miss_counts[-1][1]}",
          f"{float(miss_counts[-1][1])/float(total):.2%}"]
@@ -246,7 +282,4 @@ def evaluate(dataset, offset, limit, model, no_foley):
 
 if __name__ == '__main__':
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-    import warnings
-    warnings.simplefilter(action='ignore', category=FutureWarning)
-    ucsinfer()
+    ucsinfer(obj={})
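
The heart of this change is the new group-level wiring: verbosity becomes a handler-level switch instead of a commented-out stub, and the model-cache preference reaches every subcommand through click's context object rather than a global. Below is a minimal, runnable sketch of that pattern, distilled from the diff; the `cli` name and the `show-cache` subcommand are hypothetical, added only to demonstrate reading `ctx.obj`:

    import logging
    import click

    logger = logging.getLogger('ucsinfer')
    logger.setLevel(logging.DEBUG)            # the logger passes everything through...
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.WARNING)  # ...the handler decides what is shown
    logger.addHandler(stream_handler)


    @click.group()
    @click.option('--verbose', '-v', flag_value=True, help='Verbose output')
    @click.option('--no-model-cache', flag_value=True,
                  help="Don't use local model cache")
    @click.pass_context
    def cli(ctx, verbose, no_model_cache):
        if verbose:
            stream_handler.setLevel(logging.DEBUG)
        # ctx.obj is shared state that any subcommand decorated with
        # @click.pass_context can read
        ctx.ensure_object(dict)
        ctx.obj['model_cache'] = not no_model_cache


    @cli.command('show-cache')   # hypothetical subcommand, for illustration only
    @click.pass_context
    def show_cache(ctx):
        click.echo(f"model cache enabled: {ctx.obj['model_cache']}")


    if __name__ == '__main__':
        cli(obj={})

Keeping the logger itself at DEBUG and throttling at the handler means a later handler (a file handler, say) could capture everything even while the console stays quiet.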

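The scoring inside `evaluate` is a ranked hit-rate tally: a record counts as TOP when the true CatID is the first guess, TOP_5 when it falls in the first five, TOP_10 in the first ten, and a miss otherwise. A self-contained sketch of that accounting, with made-up sample rows (the CatIDs below are illustrative, not from the dataset):

    from collections import Counter

    def tally(results):
        """results: (true_catid, ranked_guesses) pairs, best guess first."""
        buckets = Counter()
        for cat_id, guesses in results:
            if cat_id == guesses[0]:
                buckets['TOP'] += 1        # exact top-1 hit
            elif cat_id in guesses[:5]:
                buckets['TOP_5'] += 1
            elif cat_id in guesses[:10]:
                buckets['TOP_10'] += 1
            else:
                buckets['MISS'] += 1       # feeds the most-missed-category report
        return buckets

    # Made-up rows: (true category, model's ranked guesses)
    sample = [
        ('DOORWood', ['DOORWood', 'DOORMetl', 'FOLYProp']),
        ('AMBForst', ['BIRDSong', 'AMBForst', 'WINDLeaf']),
    ]
    print(tally(sample))   # Counter({'TOP': 1, 'TOP_5': 1})
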
View File

@@ -54,9 +54,24 @@ class InferenceContext:
     model: SentenceTransformer
     model_name: str
 
-    def __init__(self, model: SentenceTransformer, model_name: str):
-        self.model = model
+    def __init__(self, model_name: str, use_cached_model: bool = True):
         self.model_name = model_name
+        cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
+                                                'Squad 51')
+        model_cache_path = os.path.join(cache_dir,
+                                        f"{self.model_name}.cache")
+
+        if use_cached_model:
+            if os.path.exists(model_cache_path):
+                self.model = SentenceTransformer(model_cache_path)
+            else:
+                self.model = SentenceTransformer(model_name)
+                self.model.save(model_cache_path)
+        else:
+            self.model = SentenceTransformer(model_name)
 
     @cached_property
     def catlist(self) -> list[Ucs]:
@@ -64,13 +79,16 @@ class InferenceContext:
     @cached_property
     def embeddings(self) -> list[dict]:
-        cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", 'Squad 51')
-        embedding_cache = os.path.join(cache_dir,
+        cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
+                                                'Squad 51')
+        embedding_cache_path = os.path.join(
+            cache_dir,
             f"{self.model_name}-ucs_embedding.cache")
 
         embeddings = []
-        if os.path.exists(embedding_cache):
-            with open(embedding_cache, 'rb') as f:
+        if os.path.exists(embedding_cache_path):
+            with open(embedding_cache_path, 'rb') as f:
                 embeddings = pickle.load(f)
         else:
@@ -82,8 +100,8 @@ class InferenceContext:
                     'Embedding': self._encode_category(cat_defn)
                 }]
-            os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
-            with open(embedding_cache, 'wb') as g:
+            os.makedirs(os.path.dirname(embedding_cache_path), exist_ok=True)
+            with open(embedding_cache_path, 'wb') as g:
                 pickle.dump(embeddings, g)
 
         return embeddings
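
Read on its own, the constructor's caching scheme is a write-through cache keyed by model name. A standalone sketch under the same assumptions (platformdirs cache directory; `SentenceTransformer.save` creates missing directories); `load_model` is an illustrative name, not part of the module:

    import os

    import platformdirs
    from sentence_transformers import SentenceTransformer

    def load_model(model_name: str,
                   use_cached_model: bool = True) -> SentenceTransformer:
        """Load a sentence-transformer, preferring the local on-disk copy."""
        cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", 'Squad 51')
        model_cache_path = os.path.join(cache_dir, f"{model_name}.cache")

        if use_cached_model and os.path.exists(model_cache_path):
            # Cache hit: load the serialized model without touching the network
            return SentenceTransformer(model_cache_path)

        model = SentenceTransformer(model_name)
        if use_cached_model:
            # Write-through: save now so the next run can start offline
            model.save(model_cache_path)
        return model

One consequence of joining the raw model name into the path: a slash-qualified name such as `sentence-transformers/all-MiniLM-L6-v2` nests the cache one directory deeper. `save` tolerates that, but it is worth knowing when clearing the cache by hand.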