Improved logging

2025-09-03 12:23:44 -07:00
parent 184edcb7e4
commit 2a0b0367f8
2 changed files with 73 additions and 31 deletions

File 1 of 2:

@@ -11,25 +11,39 @@ from .inference import InferenceContext, load_ucs
 from .recommend import print_recommendation
 from .util import ffmpeg_description, parse_ucs
 
+logger = logging.getLogger('ucsinfer')
+logger.setLevel(logging.DEBUG)
+
+stream_handler = logging.StreamHandler()
+stream_handler.setLevel(logging.WARN)
+formatter = logging.Formatter(
+    '%(asctime)s. %(levelname)s %(name)s: %(message)s')
+stream_handler.setFormatter(formatter)
+logger.addHandler(stream_handler)
+
 
 @click.group(epilog="For more information see "
              "<https://git.squad51.us/jamie/ucsinfer>")
-@click.option('--verbose', '-v', flag_value='verbose', help='Verbose output')
-def ucsinfer(verbose):
+@click.option('--verbose', '-v', flag_value=True, help='Verbose output')
+@click.option('--no-model-cache', flag_value=True,
+              help="Don't use local model cache")
+@click.pass_context
+def ucsinfer(ctx, verbose, no_model_cache):
     """
     Tools for applying UCS categories to sounds using large-language Models
     """
     if verbose:
-        logging.basicConfig(format="%(levelname)s: %(message)s",
-                            level=logging.DEBUG)
+        stream_handler.setLevel(logging.DEBUG)
+        logger.info("Verbose logging is enabled")
     else:
         import warnings
         warnings.filterwarnings(
             action='ignore', module='torch', category=FutureWarning,
             message=r"`encoder_attention_mask` is deprecated.*")
-        logging.basicConfig(format="%(levelname)s: %(message)s",
-                            level=logging.WARN)
+        stream_handler.setLevel(logging.WARNING)
+
+    ctx.ensure_object(dict)
+    ctx.obj['model_cache'] = not no_model_cache
 
 
 @ucsinfer.command('recommend')
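
The hunk above splits responsibilities: the `ucsinfer` logger stays at DEBUG and the handler gates what is emitted, so `--verbose` only has to lower the handler's threshold. A minimal sketch of the same stdlib pattern in isolation (the child logger name is illustrative):

```python
import logging

# Mirrors the commit's setup: the package logger accepts every record,
# the handler decides what is actually shown.
logger = logging.getLogger('ucsinfer')
logger.setLevel(logging.DEBUG)

handler = logging.StreamHandler()
handler.setLevel(logging.WARNING)
handler.setFormatter(logging.Formatter(
    '%(asctime)s. %(levelname)s %(name)s: %(message)s'))
logger.addHandler(handler)

# A child logger (name illustrative) needs no configuration of its own:
# its records propagate up to the 'ucsinfer' handler.
child = logging.getLogger('ucsinfer.inference')
child.debug("suppressed at the default handler level")
handler.setLevel(logging.DEBUG)  # what --verbose now does
child.debug("visible once the handler threshold drops")
```
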
@@ -48,7 +62,8 @@ def ucsinfer(verbose):
 @click.option('-s', '--skip-ucs', flag_value=True, default=False,
               help="Skip files that already have a UCS category in their "
               "name.")
-def recommend(text, paths, model, interactive, skip_ucs):
+@click.pass_context
+def recommend(ctx, text, paths, model, interactive, skip_ucs):
     """
     Infer a UCS category for a text description
@@ -58,12 +73,15 @@ def recommend(text, paths, model, interactive, skip_ucs):
     the synonyms an explanations definied for each UCS subcategory. A list
     of ranked subcategories is printed to the terminal for each PATH.
     """
-    ctx = InferenceContext(model)
+    logger.debug("RECOMMEND mode")
+    inference_ctx = InferenceContext(model,
+                                     use_cached_model=ctx.obj['model_cache'])
 
     if text is not None:
-        print_recommendation(None, text, ctx, interactive_rename=False)
+        print_recommendation(None, text, inference_ctx,
+                             interactive_rename=False)
 
-    catlist = [x.catid for x in ctx.catlist]
+    catlist = [x.catid for x in inference_ctx.catlist]
 
     for path in paths:
         basename = os.path.basename(path)
@@ -75,7 +93,7 @@ def recommend(text, paths, model, interactive, skip_ucs):
             text = os.path.basename(path)
 
         while True:
-            retval = print_recommendation(path, text, ctx, interactive)
+            retval = print_recommendation(path, text, inference_ctx, interactive)
             if not retval:
                 break
             if retval[0] is False:
@@ -109,6 +127,8 @@ def gather(paths, outfile):
     is the CatID indicated for the file, and the second is the embedded file
     description for the file as returned by ffprobe.
     """
+    logger.debug("GATHER mode")
+
     types = ['.wav', '.flac']
     table = csv.writer(outfile)
     print(f"Loading category list...")
@@ -116,7 +136,7 @@ def gather(paths, outfile):
     scan_list = []
     for path in paths:
-        print(f"Scanning directory {path}...", file=sys.stdout)
+        logger.info(f"Scanning directory {path}...")
         for dirpath, _, filenames in os.walk(path):
             for filename in filenames:
                 root, ext = os.path.splitext(filename)
@@ -126,7 +146,7 @@ def gather(paths, outfile):
                     scan_list.append((ucs_components.cat_id,
                                       os.path.join(dirpath, filename)))
-    print(f"Found {len(scan_list)} files to process.")
+    logger.info(f"Found {len(scan_list)} files to process.")
 
     for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
         if desc := ffmpeg_description(pair[1]):
@@ -138,7 +158,8 @@ def finetune():
     """
     Fine-tune a model with training data
     """
-    pass
+    logger.debug("FINETUNE mode")
 
 
 @ucsinfer.command('evaluate')
@@ -155,7 +176,8 @@ def finetune():
               help="Select the sentence_transformer model to use")
 @click.argument('dataset', type=click.File('r', encoding='utf8'),
                 default='dataset.csv')
-def evaluate(dataset, offset, limit, model, no_foley):
+@click.pass_context
+def evaluate(ctx, dataset, offset, limit, model, no_foley):
     """
     Use datasets to evaluate model performance
@@ -175,19 +197,24 @@ def evaluate(dataset, offset, limit, model, no_foley):
     classified according to their subject and not wether or not they were
     foley, and so these categories can be excluded with the --no-foley option.
     """
-    ctx = InferenceContext(model)
+    logger.debug("EVALUATE mode")
+    inference_context = InferenceContext(model,
+                                         use_cached_model=
+                                         ctx.obj['model_cache'])
     reader = csv.reader(dataset)
-    print(f"Evaluating model {model}...")
+    logger.info(f"Evaluating model {model}...")
     results = []
 
     if offset > 0:
-        print(f"Skipping {offset} records...")
+        logger.debug(f"Skipping {offset} records...")
 
     if limit > 0:
-        print(f"Will only evaluate {limit} records...")
+        logger.debug(f"Will only evaluate {limit} records...")
 
-    progress_bar = tqdm.tqdm(total=limit)
+    progress_bar = tqdm.tqdm(total=limit,
+                             desc="Processing dataset...",
+                             unit="rec")
 
     for i, row in enumerate(reader):
         if i < offset:
             continue
@@ -199,7 +226,7 @@ def evaluate(dataset, offset, limit, model, no_foley):
         if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
             continue
 
-        guesses = ctx.classify_text_ranked(description, limit=10)
+        guesses = inference_context.classify_text_ranked(description, limit=10)
         if cat_id == guesses[0]:
             results.append({'catid': cat_id, 'result': "TOP"})
         elif cat_id in guesses[0:5]:
@@ -241,9 +268,9 @@ def evaluate(dataset, offset, limit, model, no_foley):
         ["Top 10 Result:", f"{total_top_10}",
          f"{float(total_top_10)/float(total):.2%}"],
         SEPARATING_LINE,
-        ["UCS category count:", f"{len(ctx.catlist)}"],
+        ["UCS category count:", f"{len(inference_context.catlist)}"],
         ["Total categories in sample:", f"{total_cats}",
-         f"{float(total_cats)/float(len(ctx.catlist)):.2%}"],
+         f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
         [f"Most missed category ({miss_counts[-1][0]}):",
          f"{miss_counts[-1][1]}",
          f"{float(miss_counts[-1][1])/float(total):.2%}"]
@@ -255,4 +282,4 @@ def evaluate(dataset, offset, limit, model, no_foley):
 if __name__ == '__main__':
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
-    ucsinfer()
+    ucsinfer(obj={})
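
Seeding `ucsinfer(obj={})` is what lets `ctx.ensure_object(dict)` succeed on every invocation and lets subcommands read `ctx.obj['model_cache']`. A minimal sketch of that click pattern under the same assumptions (group and command names here are illustrative):

```python
import click

@click.group()
@click.option('--no-model-cache', flag_value=True,
              help="Don't use local model cache")
@click.pass_context
def cli(ctx, no_model_cache):
    ctx.ensure_object(dict)  # no-op when the caller already seeded obj={}
    ctx.obj['model_cache'] = not no_model_cache

@cli.command('show')
@click.pass_context
def show(ctx):
    # Subcommands see the dict the group populated.
    click.echo(f"model_cache={ctx.obj['model_cache']}")

if __name__ == '__main__':
    cli(obj={})  # same seeding as ucsinfer(obj={}) above
```
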

File 2 of 2:

@@ -54,9 +54,24 @@ class InferenceContext:
     model: SentenceTransformer
     model_name: str
 
-    def __init__(self, model_name: str):
-        self.model = SentenceTransformer(model_name)
+    def __init__(self, model_name: str, use_cached_model: bool = True):
         self.model_name = model_name
+        cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
+                                                'Squad 51')
+        model_cache_path = os.path.join(cache_dir,
+                                        f"{self.model_name}.cache")
+
+        if use_cached_model:
+            if os.path.exists(model_cache_path):
+                self.model = SentenceTransformer(model_cache_path)
+            else:
+                self.model = SentenceTransformer(model_name)
+                self.model.save(model_cache_path)
+        else:
+            self.model = SentenceTransformer(model_name)
 
     @cached_property
     def catlist(self) -> list[Ucs]:
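
With this change, the first construction downloads the model and saves a copy under the platformdirs cache; later constructions load the saved copy instead. A hedged usage sketch (the model name is illustrative; the import path is assumed from the `.inference` relative import):

```python
from ucsinfer.inference import InferenceContext

# First run: downloads 'all-MiniLM-L6-v2' (name illustrative), then saves
# it under the user cache dir as "<cache>/all-MiniLM-L6-v2.cache".
ctx = InferenceContext('all-MiniLM-L6-v2')

# Subsequent runs load the saved copy instead of re-downloading.
ctx = InferenceContext('all-MiniLM-L6-v2')

# What --no-model-cache wires through: always load from the source name.
ctx = InferenceContext('all-MiniLM-L6-v2', use_cached_model=False)
```
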
@@ -66,14 +81,14 @@ class InferenceContext:
     def embeddings(self) -> list[dict]:
         cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
                                                 'Squad 51')
-        embedding_cache = os.path.join(
+        embedding_cache_path = os.path.join(
             cache_dir,
             f"{self.model_name}-ucs_embedding.cache")
 
         embeddings = []
-        if os.path.exists(embedding_cache):
-            with open(embedding_cache, 'rb') as f:
+        if os.path.exists(embedding_cache_path):
+            with open(embedding_cache_path, 'rb') as f:
                 embeddings = pickle.load(f)
         else:
@@ -85,8 +100,8 @@ class InferenceContext:
                 'Embedding': self._encode_category(cat_defn)
             }]
-            os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
-            with open(embedding_cache, 'wb') as g:
+            os.makedirs(os.path.dirname(embedding_cache_path), exist_ok=True)
+            with open(embedding_cache_path, 'wb') as g:
                 pickle.dump(embeddings, g)
 
         return embeddings