Improved logging
This commit is contained in:
@@ -11,25 +11,39 @@ from .inference import InferenceContext, load_ucs
|
|||||||
from .recommend import print_recommendation
|
from .recommend import print_recommendation
|
||||||
from .util import ffmpeg_description, parse_ucs
|
from .util import ffmpeg_description, parse_ucs
|
||||||
|
|
||||||
|
logger = logging.getLogger('ucsinfer')
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
stream_handler = logging.StreamHandler()
|
||||||
|
stream_handler.setLevel(logging.WARN)
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
'%(asctime)s. %(levelname)s %(name)s: %(message)s')
|
||||||
|
stream_handler.setFormatter(formatter)
|
||||||
|
logger.addHandler(stream_handler)
|
||||||
|
|
||||||
@click.group(epilog="For more information see "
|
@click.group(epilog="For more information see "
|
||||||
"<https://git.squad51.us/jamie/ucsinfer>")
|
"<https://git.squad51.us/jamie/ucsinfer>")
|
||||||
@click.option('--verbose', '-v', flag_value='verbose', help='Verbose output')
|
@click.option('--verbose', '-v', flag_value=True, help='Verbose output')
|
||||||
def ucsinfer(verbose):
|
@click.option('--no-model-cache', flag_value=True,
|
||||||
|
help="Don't use local model cache")
|
||||||
|
@click.pass_context
|
||||||
|
def ucsinfer(ctx, verbose, no_model_cache):
|
||||||
"""
|
"""
|
||||||
Tools for applying UCS categories to sounds using large-language Models
|
Tools for applying UCS categories to sounds using large-language Models
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
logging.basicConfig(format="%(levelname)s: %(message)s",
|
stream_handler.setLevel(logging.DEBUG)
|
||||||
level=logging.DEBUG)
|
logger.info("Verbose logging is enabled")
|
||||||
else:
|
else:
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings(
|
warnings.filterwarnings(
|
||||||
action='ignore', module='torch', category=FutureWarning,
|
action='ignore', module='torch', category=FutureWarning,
|
||||||
message=r"`encoder_attention_mask` is deprecated.*")
|
message=r"`encoder_attention_mask` is deprecated.*")
|
||||||
|
|
||||||
logging.basicConfig(format="%(levelname)s: %(message)s",
|
stream_handler.setLevel(logging.WARNING)
|
||||||
level=logging.WARN)
|
|
||||||
|
ctx.ensure_object(dict)
|
||||||
|
ctx.obj['model_cache'] = not no_model_cache
|
||||||
|
|
||||||
|
|
||||||
@ucsinfer.command('recommend')
|
@ucsinfer.command('recommend')
|
||||||
@@ -48,7 +62,8 @@ def ucsinfer(verbose):
|
|||||||
@click.option('-s', '--skip-ucs', flag_value=True, default=False,
|
@click.option('-s', '--skip-ucs', flag_value=True, default=False,
|
||||||
help="Skip files that already have a UCS category in their "
|
help="Skip files that already have a UCS category in their "
|
||||||
"name.")
|
"name.")
|
||||||
def recommend(text, paths, model, interactive, skip_ucs):
|
@click.pass_context
|
||||||
|
def recommend(ctx, text, paths, model, interactive, skip_ucs):
|
||||||
"""
|
"""
|
||||||
Infer a UCS category for a text description
|
Infer a UCS category for a text description
|
||||||
|
|
||||||
@@ -58,12 +73,15 @@ def recommend(text, paths, model, interactive, skip_ucs):
|
|||||||
the synonyms an explanations definied for each UCS subcategory. A list
|
the synonyms an explanations definied for each UCS subcategory. A list
|
||||||
of ranked subcategories is printed to the terminal for each PATH.
|
of ranked subcategories is printed to the terminal for each PATH.
|
||||||
"""
|
"""
|
||||||
ctx = InferenceContext(model)
|
logger.debug("RECOMMEND mode")
|
||||||
|
inference_ctx = InferenceContext(model,
|
||||||
|
use_cached_model=ctx.obj['model_cache'])
|
||||||
|
|
||||||
if text is not None:
|
if text is not None:
|
||||||
print_recommendation(None, text, ctx, interactive_rename=False)
|
print_recommendation(None, text, inference_ctx,
|
||||||
|
interactive_rename=False)
|
||||||
|
|
||||||
catlist = [x.catid for x in ctx.catlist]
|
catlist = [x.catid for x in inference_ctx.catlist]
|
||||||
|
|
||||||
for path in paths:
|
for path in paths:
|
||||||
basename = os.path.basename(path)
|
basename = os.path.basename(path)
|
||||||
@@ -75,7 +93,7 @@ def recommend(text, paths, model, interactive, skip_ucs):
|
|||||||
text = os.path.basename(path)
|
text = os.path.basename(path)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
retval = print_recommendation(path, text, ctx, interactive)
|
retval = print_recommendation(path, text, inference_ctx, interactive)
|
||||||
if not retval:
|
if not retval:
|
||||||
break
|
break
|
||||||
if retval[0] is False:
|
if retval[0] is False:
|
||||||
@@ -109,6 +127,8 @@ def gather(paths, outfile):
|
|||||||
is the CatID indicated for the file, and the second is the embedded file
|
is the CatID indicated for the file, and the second is the embedded file
|
||||||
description for the file as returned by ffprobe.
|
description for the file as returned by ffprobe.
|
||||||
"""
|
"""
|
||||||
|
logger.debug("GATHER mode")
|
||||||
|
|
||||||
types = ['.wav', '.flac']
|
types = ['.wav', '.flac']
|
||||||
table = csv.writer(outfile)
|
table = csv.writer(outfile)
|
||||||
print(f"Loading category list...")
|
print(f"Loading category list...")
|
||||||
@@ -116,7 +136,7 @@ def gather(paths, outfile):
|
|||||||
|
|
||||||
scan_list = []
|
scan_list = []
|
||||||
for path in paths:
|
for path in paths:
|
||||||
print(f"Scanning directory {path}...", file=sys.stdout)
|
logger.info(f"Scanning directory {path}...")
|
||||||
for dirpath, _, filenames in os.walk(path):
|
for dirpath, _, filenames in os.walk(path):
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
root, ext = os.path.splitext(filename)
|
root, ext = os.path.splitext(filename)
|
||||||
@@ -126,7 +146,7 @@ def gather(paths, outfile):
|
|||||||
scan_list.append((ucs_components.cat_id,
|
scan_list.append((ucs_components.cat_id,
|
||||||
os.path.join(dirpath, filename)))
|
os.path.join(dirpath, filename)))
|
||||||
|
|
||||||
print(f"Found {len(scan_list)} files to process.")
|
logger.info(f"Found {len(scan_list)} files to process.")
|
||||||
|
|
||||||
for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
|
for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
|
||||||
if desc := ffmpeg_description(pair[1]):
|
if desc := ffmpeg_description(pair[1]):
|
||||||
@@ -138,7 +158,8 @@ def finetune():
|
|||||||
"""
|
"""
|
||||||
Fine-tune a model with training data
|
Fine-tune a model with training data
|
||||||
"""
|
"""
|
||||||
pass
|
logger.debug("FINETUNE mode")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ucsinfer.command('evaluate')
|
@ucsinfer.command('evaluate')
|
||||||
@@ -155,7 +176,8 @@ def finetune():
|
|||||||
help="Select the sentence_transformer model to use")
|
help="Select the sentence_transformer model to use")
|
||||||
@click.argument('dataset', type=click.File('r', encoding='utf8'),
|
@click.argument('dataset', type=click.File('r', encoding='utf8'),
|
||||||
default='dataset.csv')
|
default='dataset.csv')
|
||||||
def evaluate(dataset, offset, limit, model, no_foley):
|
@click.pass_context
|
||||||
|
def evaluate(ctx, dataset, offset, limit, model, no_foley):
|
||||||
"""
|
"""
|
||||||
Use datasets to evaluate model performance
|
Use datasets to evaluate model performance
|
||||||
|
|
||||||
@@ -175,19 +197,24 @@ def evaluate(dataset, offset, limit, model, no_foley):
|
|||||||
classified according to their subject and not wether or not they were
|
classified according to their subject and not wether or not they were
|
||||||
foley, and so these categories can be excluded with the --no-foley option.
|
foley, and so these categories can be excluded with the --no-foley option.
|
||||||
"""
|
"""
|
||||||
ctx = InferenceContext(model)
|
logger.debug("EVALUATE mode")
|
||||||
|
inference_context = InferenceContext(model,
|
||||||
|
use_cached_model=
|
||||||
|
ctx.obj['model_cache'])
|
||||||
reader = csv.reader(dataset)
|
reader = csv.reader(dataset)
|
||||||
|
|
||||||
print(f"Evaluating model {model}...")
|
logger.info(f"Evaluating model {model}...")
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
if offset > 0:
|
if offset > 0:
|
||||||
print(f"Skipping {offset} records...")
|
logger.debug(f"Skipping {offset} records...")
|
||||||
|
|
||||||
if limit > 0:
|
if limit > 0:
|
||||||
print(f"Will only evaluate {limit} records...")
|
logger.debug(f"Will only evaluate {limit} records...")
|
||||||
|
|
||||||
progress_bar = tqdm.tqdm(total=limit)
|
progress_bar = tqdm.tqdm(total=limit,
|
||||||
|
desc="Processing dataset...",
|
||||||
|
unit="rec")
|
||||||
for i, row in enumerate(reader):
|
for i, row in enumerate(reader):
|
||||||
if i < offset:
|
if i < offset:
|
||||||
continue
|
continue
|
||||||
@@ -199,7 +226,7 @@ def evaluate(dataset, offset, limit, model, no_foley):
|
|||||||
if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
|
if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
guesses = ctx.classify_text_ranked(description, limit=10)
|
guesses = inference_context.classify_text_ranked(description, limit=10)
|
||||||
if cat_id == guesses[0]:
|
if cat_id == guesses[0]:
|
||||||
results.append({'catid': cat_id, 'result': "TOP"})
|
results.append({'catid': cat_id, 'result': "TOP"})
|
||||||
elif cat_id in guesses[0:5]:
|
elif cat_id in guesses[0:5]:
|
||||||
@@ -241,9 +268,9 @@ def evaluate(dataset, offset, limit, model, no_foley):
|
|||||||
["Top 10 Result:", f"{total_top_10}",
|
["Top 10 Result:", f"{total_top_10}",
|
||||||
f"{float(total_top_10)/float(total):.2%}"],
|
f"{float(total_top_10)/float(total):.2%}"],
|
||||||
SEPARATING_LINE,
|
SEPARATING_LINE,
|
||||||
["UCS category count:", f"{len(ctx.catlist)}"],
|
["UCS category count:", f"{len(inference_context.catlist)}"],
|
||||||
["Total categories in sample:", f"{total_cats}",
|
["Total categories in sample:", f"{total_cats}",
|
||||||
f"{float(total_cats)/float(len(ctx.catlist)):.2%}"],
|
f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
|
||||||
[f"Most missed category ({miss_counts[-1][0]}):",
|
[f"Most missed category ({miss_counts[-1][0]}):",
|
||||||
f"{miss_counts[-1][1]}",
|
f"{miss_counts[-1][1]}",
|
||||||
f"{float(miss_counts[-1][1])/float(total):.2%}"]
|
f"{float(miss_counts[-1][1])/float(total):.2%}"]
|
||||||
@@ -255,4 +282,4 @@ def evaluate(dataset, offset, limit, model, no_foley):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
||||||
|
|
||||||
ucsinfer()
|
ucsinfer(obj={})
|
||||||
|
@@ -54,9 +54,24 @@ class InferenceContext:
|
|||||||
model: SentenceTransformer
|
model: SentenceTransformer
|
||||||
model_name: str
|
model_name: str
|
||||||
|
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str, use_cached_model: bool = True):
|
||||||
self.model = SentenceTransformer(model_name)
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
|
||||||
|
'Squad 51')
|
||||||
|
|
||||||
|
model_cache_path = os.path.join(cache_dir,
|
||||||
|
f"{self.model_name}.cache")
|
||||||
|
|
||||||
|
if use_cached_model:
|
||||||
|
if os.path.exists(model_cache_path):
|
||||||
|
self.model = SentenceTransformer(model_cache_path)
|
||||||
|
else:
|
||||||
|
self.model = SentenceTransformer(model_name)
|
||||||
|
self.model.save(model_cache_path)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.model = SentenceTransformer(model_name)
|
||||||
|
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def catlist(self) -> list[Ucs]:
|
def catlist(self) -> list[Ucs]:
|
||||||
@@ -66,14 +81,14 @@ class InferenceContext:
|
|||||||
def embeddings(self) -> list[dict]:
|
def embeddings(self) -> list[dict]:
|
||||||
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
|
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
|
||||||
'Squad 51')
|
'Squad 51')
|
||||||
embedding_cache = os.path.join(
|
embedding_cache_path = os.path.join(
|
||||||
cache_dir,
|
cache_dir,
|
||||||
f"{self.model_name}-ucs_embedding.cache")
|
f"{self.model_name}-ucs_embedding.cache")
|
||||||
|
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
|
||||||
if os.path.exists(embedding_cache):
|
if os.path.exists(embedding_cache_path):
|
||||||
with open(embedding_cache, 'rb') as f:
|
with open(embedding_cache_path, 'rb') as f:
|
||||||
embeddings = pickle.load(f)
|
embeddings = pickle.load(f)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -85,8 +100,8 @@ class InferenceContext:
|
|||||||
'Embedding': self._encode_category(cat_defn)
|
'Embedding': self._encode_category(cat_defn)
|
||||||
}]
|
}]
|
||||||
|
|
||||||
os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
|
os.makedirs(os.path.dirname(embedding_cache_path), exist_ok=True)
|
||||||
with open(embedding_cache, 'wb') as g:
|
with open(embedding_cache_path, 'wb') as g:
|
||||||
pickle.dump(embeddings, g)
|
pickle.dump(embeddings, g)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
Reference in New Issue
Block a user