Compare commits: 336d6f013e...master

17 commits (SHA1):
615d8ab279, 9a887c4ed5, 63b140209b, d181ac73b1, bddce23c76, b4758dd138,
04332b73ee, 103fffe0a4, 10519f9c1a, e419f698c9, 6cd0415a26, b833d6d3c0,
fee55a7d5a, 0594899bdd, 46a693bf93, 0a2aaa2a22, d855ba4c78
@@ -49,15 +49,13 @@ Pass `--help` to see a summary of subcommands and options.
 * gather

   Scan files to capture existing text descriptions and UCS categories
-  and save as a dataset. This function is used to construct datasets
-  that `evaluate` can use to test models and finetune can use to
-  refine them.
+  and save as a dataset.

 * ~finetune~ (planned)

   Fine-tune an existing sentence embedding model with training data.

-* evaluate
+* ~evaluate~ (FIXME phase)

   Use datasets to evaluate the performance of a model and fine-tuning.

TODO.md

@@ -2,21 +2,27 @@
 - Use History when adding catids

 ## Gather

-- Add "source" column for tracking provenance
+- Maybe more dataset configurations

 ## Fine-tune

-- Implement
+- Implement BatchAllTripletLoss

 ## Evaluate

 - Print more information about the dataset coverage of UCS
 - Allow skipping model testing for this

 - Print raw output
-- Maybe load everything into a sqlite for slicker reporting
+<!-- - Maybe load everything into a sqlite for slicker reporting -->

 ## Utility

-- Dataset partitioning
+- Clear caches
pyproject.toml

@@ -13,7 +13,8 @@ dependencies = [
     "tqdm (>=4.67.1,<5.0.0)",
     "platformdirs (>=4.3.8,<5.0.0)",
     "click (>=8.2.1,<9.0.0)",
-    "tabulate (>=0.9.0,<0.10.0)"
+    "tabulate (>=0.9.0,<0.10.0)",
+    "datasets (>=4.0.0,<5.0.0)"
 ]


@@ -25,3 +26,6 @@ build-backend = "poetry.core.masonry.api"
 ipython = "^9.4.0"
 jupyter = "^1.1.1"

+[tool.poetry.scripts]
+ucsinfer = "ucsinfer.__main__:ucsinfer"
ucsinfer/__main__.py

@@ -1,13 +1,15 @@
 import os
-import sys
-import csv
+# import csv
 import logging
+from itertools import chain

 import tqdm
 import click
-from tabulate import tabulate, SEPARATING_LINE
+# from tabulate import tabulate, SEPARATING_LINE

 from .inference import InferenceContext, load_ucs
+from .gather import (build_sentence_class_dataset, print_dataset_stats,
+                     ucs_definitions_generator, scan_metadata, walk_path)
 from .recommend import print_recommendation
 from .util import ffmpeg_description, parse_ucs

@@ -23,10 +25,17 @@ logger.addHandler(stream_handler)
 @click.group(epilog="For more information see "
              "<https://git.squad51.us/jamie/ucsinfer>")
 @click.option('--verbose', '-v', flag_value=True, help='Verbose output')
+@click.option('--model', type=str, metavar="<model-name>",
+              default="paraphrase-multilingual-mpnet-base-v2",
+              show_default=True,
+              help="Select the sentence_transformer model to use")
 @click.option('--no-model-cache', flag_value=True,
               help="Don't use local model cache")
+@click.option('--complete-ucs', flag_value=True, default=False,
+              help="Use all UCS categories. By default, all 'FOLEY' and "
+              "'ARCHIVED' UCS categories are excluded from all functions.")
 @click.pass_context
-def ucsinfer(ctx, verbose, no_model_cache):
+def ucsinfer(ctx, verbose, no_model_cache, model, complete_ucs):
     """
     Tools for applying UCS categories to sounds using large-language Models
     """
@@ -44,6 +53,19 @@ def ucsinfer(ctx, verbose, no_model_cache):

     ctx.ensure_object(dict)
     ctx.obj['model_cache'] = not no_model_cache
+    ctx.obj['model_name'] = model
+    ctx.obj['complete_ucs'] = complete_ucs
+
+    if no_model_cache:
+        logger.info("Model cache inhibited by config")
+
+    if complete_ucs:
+        logger.info("Using complete UCS category list")
+    else:
+        logger.info("Non-descriptive UCS categories will be excluded. Turn "
+                    "this option off by passing --complete-ucs.")
+
+    logger.info(f"Using model {model}")


 @ucsinfer.command('recommend')
@@ -51,10 +73,6 @@ def ucsinfer(ctx, verbose, no_model_cache):
               help="Recommend a category for given text instead of reading "
               "from a file")
 @click.argument('paths', nargs=-1, metavar='<paths>')
-@click.option('--model', type=str, metavar="<model-name>",
-              default="paraphrase-multilingual-mpnet-base-v2",
-              show_default=True,
-              help="Select the sentence_transformer model to use")
 @click.option('--interactive','-i', flag_value=True, default=False,
               help="After processing each path in <paths>, prompt for a "
               "recommendation to accept, and then prepend the selection to "
@@ -63,7 +81,7 @@ def ucsinfer(ctx, verbose, no_model_cache):
               help="Skip files that already have a UCS category in their "
               "name.")
 @click.pass_context
-def recommend(ctx, text, paths, model, interactive, skip_ucs):
+def recommend(ctx, text, paths, interactive, skip_ucs):
     """
     Infer a UCS category for a text description

@@ -74,8 +92,9 @@ def recommend(ctx, text, paths, model, interactive, skip_ucs):
     of ranked subcategories is printed to the terminal for each PATH.
     """
     logger.debug("RECOMMEND mode")
-    inference_ctx = InferenceContext(model,
-                                     use_cached_model=ctx.obj['model_cache'])
+    inference_ctx = InferenceContext(ctx.obj['model_name'],
+                                     use_cached_model=ctx.obj['model_cache'],
+                                     use_full_ucs=ctx.obj['complete_ucs'])

     if text is not None:
         print_recommendation(None, text, inference_ctx,
@@ -84,6 +103,11 @@ def recommend(ctx, text, paths, model, interactive, skip_ucs):
     catlist = [x.catid for x in inference_ctx.catlist]

     for path in paths:
+        _, ext = os.path.splitext(path)
+
+        if ext not in (".wav", ".flac"):
+            continue
+
         basename = os.path.basename(path)
         if skip_ucs and parse_ucs(basename, catlist):
             continue
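The new extension check and the `skip_ucs` guard above both lean on `parse_ucs` pulling a CatID and FX Name out of a UCS-style filename. Below is a minimal sketch of that round trip, using a hypothetical filename and the `cat_id` / `fx_name` attributes that appear elsewhere in this diff; the exact return type of `parse_ucs` is not part of this change.

    from ucsinfer.inference import load_ucs
    from ucsinfer.util import parse_ucs

    # Hypothetical UCS-style basename: "<CatID>_<FX Name>_<vendor info>"
    basename = "AMBForst_Spring Morning Birds_S51"

    catid_list = [cat.catid for cat in load_ucs()]
    components = parse_ucs(basename, catid_list)

    if components:
        # Expected here: cat_id == "AMBForst", fx_name == "Spring Morning Birds"
        print(components.cat_id, "|", components.fx_name)
    else:
        print("Not a UCS-named file; recommend would process it normally")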
@@ -110,51 +134,59 @@ def recommend(ctx, text, paths, model, interactive, skip_ucs):


 @ucsinfer.command('gather')
-@click.option('--outfile', type=click.File(mode='w', encoding='utf8'),
-              default='dataset.csv', show_default=True)
+@click.option('--out', default='dataset/', show_default=True)
+@click.option('--ucs-data', flag_value=True, help="Create a dataset based "
+              "on the UCS category explanations and synonyms (PATHS will "
+              "be ignored.)")
 @click.argument('paths', nargs=-1)
-def gather(paths, outfile):
+@click.pass_context
+def gather(ctx, paths, out, ucs_data):
     """
-    Scan files to build a training dataset at PATH
+    Scan files to build a training dataset

-    The `gather` command walks the directory hierarchy for each path in PATHS
-    and looks for .wav and .flac files that are named according to the UCS
-    file naming guidelines, with at least a CatID and FX Name, divided by an
-    underscore.
+    The `gather` command is used to build a training dataset for finetuning
+    the selected model. Description sentences and UCS categories are collected
+    from '.wav' and '.flac' files on-disk that have valid UCS filenames and
+    assigned CatIDs, and this information is recorded into a HuggingFace
+    dataset.

-    For every file ucsinfer finds that meets this criteria, it creates a record
-    in an output dataset CSV file. The dataset file has two columns: the first
-    is the CatID indicated for the file, and the second is the embedded file
-    description for the file as returned by ffprobe.
+    Gather scans the filesystem in two passes: first, the directory tree is
+    walked by os.walk and a list of filenames that meet the above name criteria
+    is compiled. After this list is compiled, each file is scanned one-by-one
+    with ffprobe to obtain its "description" metadata; if this isn't present,
+    the parsed 'fxname' of the file becomes the description.
     """
     logger.debug("GATHER mode")

-    types = ['.wav', '.flac']
-    table = csv.writer(outfile)
     logger.debug(f"Loading category list...")
-    catid_list = [cat.catid for cat in load_ucs()]
+    ucs = load_ucs(full_ucs=ctx.obj['complete_ucs'])
+
+    scan_list: list[tuple[str,str]] = []
+    catid_list = [cat.catid for cat in ucs]
+
+    if ucs_data:
+        logger.info('Creating dataset for UCS categories instead of from PATH')
+        paths = []

-    scan_list = []
     for path in paths:
-        logger.info(f"Scanning directory {path}...")
-        for dirpath, _, filenames in os.walk(path):
-            for filename in filenames:
-                root, ext = os.path.splitext(filename)
-                if ext in types and \
-                        (ucs_components := parse_ucs(root, catid_list)) and \
-                        not filename.startswith("._"):
-                    scan_list.append((ucs_components.cat_id,
-                                      os.path.join(dirpath, filename)))
+        scan_list += walk_path(path, catid_list)

     logger.info(f"Found {len(scan_list)} files to process.")

-    for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
-        if desc := ffmpeg_description(pair[1]):
-            table.writerow([pair[0], desc])
+    logger.info("Building dataset...")
+    dataset = build_sentence_class_dataset(
+        chain(scan_metadata(scan_list, catid_list),
+              ucs_definitions_generator(ucs)),
+        catid_list)
+
+    logger.info(f"Saving dataset to disk at {out}")
+    print_dataset_stats(dataset, catid_list)
+    dataset.save_to_disk(out)


 @ucsinfer.command('finetune')
-def finetune():
+@click.pass_context
+def finetune(ctx):
     """
     Fine-tune a model with training data
     """
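`gather` now ends in `dataset.save_to_disk(out)` instead of writing a CSV, so downstream code reads the result back with `datasets.load_from_disk`. A quick sketch of inspecting the saved `DatasetDict`, assuming the default `dataset/` output directory and the `sentence` / `class` columns defined in `gather.py`:

    from datasets import load_from_disk

    ds = load_from_disk("dataset/")   # DatasetDict with 'train', 'test' and 'eval' splits
    train = ds["train"]
    print(train.num_rows, "training records")

    # 'class' is a ClassLabel feature, so integer labels map back to UCS CatIDs
    class_feature = train.features["class"]
    first = train[0]
    print(first["sentence"], "->", class_feature.int2str(first["class"]))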
@@ -163,120 +195,15 @@ def finetune():


 @ucsinfer.command('evaluate')
-@click.option('--offset', type=int, default=0, metavar="<int>",
-              help='Skip this many records in the dataset before processing')
-@click.option('--limit', type=int, default=-1, metavar="<int>",
-              help='Process this many records and then exit')
-@click.option('--no-foley', 'no_foley', flag_value=True, default=False,
-              help="Ignore any data in the set with FOLYProp or FOLYFeet "
-              "category")
-@click.option('--model', type=str, metavar="<model-name>",
-              default="paraphrase-multilingual-mpnet-base-v2",
-              show_default=True,
-              help="Select the sentence_transformer model to use")
-@click.argument('dataset', type=click.File('r', encoding='utf8'),
-                default='dataset.csv')
+@click.argument('dataset', default='dataset/')
 @click.pass_context
-def evaluate(ctx, dataset, offset, limit, model, no_foley):
+def evaluate(ctx, dataset, offset, limit):
     """
-    Use datasets to evaluate model performance
-
-    The `evaluate` command reads the input DATASET file row by row and
-    performs a classifcation of the given description against the selected
-    model (either the default or using the --model option). The command then
-    checks if the model inferred the correct category as given by the dataset.
-
-    The model gives its top 10 possible categories for a given description,
-    and the results are tabulated according to (1) wether the top
-    classification was correct, (2) wether the correct classifcation was in the
-    top 5, or (3) wether it was in the top 10. The worst-performing category,
-    the one with the most misses, is also reported as well as the category
-    coverage, how many categories are present in the dataset.
-
-    NOTE: With experimentation it was found that foley items generally were
-    classified according to their subject and not wether or not they were
-    foley, and so these categories can be excluded with the --no-foley option.
     """
     logger.debug("EVALUATE mode")
-    inference_context = InferenceContext(model,
-                                         use_cached_model=
-                                         ctx.obj['model_cache'])
-    reader = csv.reader(dataset)
-
-    logger.info(f"Evaluating model {model}...")
-    results = []
-
-    if offset > 0:
-        logger.debug(f"Skipping {offset} records...")
-
-    if limit > 0:
-        logger.debug(f"Will only evaluate {limit} records...")
-
-    progress_bar = tqdm.tqdm(total=limit,
-                             desc="Processing dataset...",
-                             unit="rec")
-    for i, row in enumerate(reader):
-        if i < offset:
-            continue
-
-        if limit > 0 and i >= limit + offset:
-            break
-
-        cat_id, description = row
-        if no_foley and cat_id in ['FOLYProp', 'FOLYFeet']:
-            continue
-
-        guesses = inference_context.classify_text_ranked(description, limit=10)
-        if cat_id == guesses[0]:
-            results.append({'catid': cat_id, 'result': "TOP"})
-        elif cat_id in guesses[0:5]:
-            results.append({'catid': cat_id, 'result': "TOP_5"})
-        elif cat_id in guesses:
-            results.append({'catid': cat_id, 'result': "TOP_10"})
-        else:
-            results.append({'catid': cat_id, 'result': "MISS"})
-
-        progress_bar.update(1)
-
-    total = len(results)
-    total_top = len([x for x in results if x['result'] == 'TOP'])
-    total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
-    total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
-
-    cats = set([x['catid'] for x in results])
-    total_cats = len(cats)
-
-    miss_counts = []
-    for cat in cats:
-        miss_counts.append(
-            (cat, len([x for x in results
-                       if x['catid'] == cat and x['result'] == 'MISS'])))
-
-    miss_counts = sorted(miss_counts, key=lambda x: x[1])
-
-    print(f"## Results for Model {model} ##\n")
-
-    if no_foley:
-        print("(FOLYProp and FOLYFeet have been omitted from the dataset.)\n")
-
-    table = [
-        ["Total records in sample:", f"{total}"],
-        ["Top Result:", f"{total_top}",
-         f"{float(total_top)/float(total):.2%}"],
-        ["Top 5 Result:", f"{total_top_5}",
-         f"{float(total_top_5)/float(total):.2%}"],
-        ["Top 10 Result:", f"{total_top_10}",
-         f"{float(total_top_10)/float(total):.2%}"],
-        SEPARATING_LINE,
-        ["UCS category count:", f"{len(inference_context.catlist)}"],
-        ["Total categories in sample:", f"{total_cats}",
-         f"{float(total_cats)/float(len(inference_context.catlist)):.2%}"],
-        [f"Most missed category ({miss_counts[-1][0]}):",
-         f"{miss_counts[-1][1]}",
-         f"{float(miss_counts[-1][1])/float(total):.2%}"]
-    ]
-
-    print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='github'))
+    logger.warning("Model evaluation is not currently implemented")


 if __name__ == '__main__':
ucsinfer/evaluate.py (new file)

@@ -0,0 +1,7 @@
+
+
+# from sentence_transformers import SentenceTransformer
+# from sentence_transformers.evaluation import BinaryClassificationEvaluator
+# from datasets import load_dataset_from_disk, DatasetDict
+#
+
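With the old CSV-based loop removed and `ucsinfer/evaluate.py` only sketching imports, evaluation is effectively TODO. One hypothetical way the removed top-1/5/10 tally could be reinstated over the new dataset layout, reusing the project's own `InferenceContext.classify_text_ranked` (a sketch, not part of this change):

    from collections import Counter

    from datasets import load_from_disk
    from ucsinfer.inference import InferenceContext

    ds = load_from_disk("dataset/")["eval"]
    labels = ds.features["class"]
    ctx = InferenceContext("paraphrase-multilingual-mpnet-base-v2")

    tally = Counter()
    for record in ds:
        expected = labels.int2str(record["class"])
        guesses = ctx.classify_text_ranked(record["sentence"], limit=10)
        if guesses[0] == expected:
            tally["top"] += 1
        elif expected in guesses[:5]:
            tally["top_5"] += 1
        elif expected in guesses:
            tally["top_10"] += 1
        else:
            tally["miss"] += 1

    total = sum(tally.values()) or 1   # avoid div-by-zero on an empty split
    for bucket in ("top", "top_5", "top_10", "miss"):
        print(f"{bucket:>7}: {tally[bucket]:5d}  {tally[bucket] / total:.2%}")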
ucsinfer/gather.py (new file)

@@ -0,0 +1,118 @@
+from .inference import Ucs
+from .util import ffmpeg_description, parse_ucs
+
+from subprocess import CalledProcessError
+import os.path
+
+from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
+from datasets.dataset_dict import DatasetDict
+
+from typing import Iterator, Generator
+
+from tabulate import tabulate
+import logging
+import tqdm
+
+
+def walk_path(path: str, catid_list) -> list[tuple[str, str]]:
+    types = ['.wav', '.flac']
+    logger = logging.getLogger('ucsinfer')
+    walker_p = tqdm.tqdm(total=None, unit='dir', desc="Walking filesystem...")
+    scan_list = []
+
+    for dirpath, _, filenames in os.walk(path):
+        logger.info(f"Walking directory {dirpath}")
+        for filename in filenames:
+            walker_p.update()
+            root, ext = os.path.splitext(filename)
+            if ext not in types or filename.startswith("._"):
+                continue
+
+            if (ucs_components := parse_ucs(root, catid_list)):
+                p = os.path.join(dirpath, filename)
+                logger.info(f"Adding path to scan list {p}")
+                scan_list.append((ucs_components.cat_id, p))
+
+    return scan_list
+
+
+def scan_metadata(scan_list: list[tuple[str, str]], catid_list: list[str]):
+    logger = logging.getLogger('ucsinfer')
+    for pair in tqdm.tqdm(scan_list, unit='files'):
+        logger.info(f"Scanning file with ffprobe: {pair[1]}")
+        try:
+            desc = ffmpeg_description(pair[1])
+        except CalledProcessError as e:
+            logger.error(f"ffprobe returned error ({e.returncode}): "
+                         + e.stderr)
+            continue
+
+        if desc:
+            yield desc, str(pair[0])
+        else:
+            comps = parse_ucs(os.path.basename(pair[1]), catid_list)
+            assert comps
+            yield comps.fx_name, str(pair[0])
+
+
+def ucs_definitions_generator(ucs: list[Ucs]) \
+        -> Generator[tuple[str, str], None, None]:
+    for cat in ucs:
+        yield cat.explanations, cat.catid
+        yield ", ".join(cat.synonymns), cat.catid
+
+
+def print_dataset_stats(dataset: DatasetDict, catlist: list[str]):
+
+    data_table = []
+    data_table.append(["Total records in combined dataset:", len(dataset)])
+    data_table.append(["Total records in `train`:", len(dataset['train'])])
+
+    tab = tabulate(data_table)
+
+    print(tab)
+
+
+# https://www.sbert.net/docs/sentence_transformer/loss_overview.html
+
+def build_sentence_class_dataset(
+        records: Iterator[tuple[str, str]],
+        catlist: list[str]) -> DatasetDict:
+    """
+    Create a new dataset for `records` which contains (sentence, class) pairs.
+    The dataset is split into train, test and eval slices.
+
+    :param records: a generator for records that generates pairs of
+        (sentence, catid)
+    :returns: A dataset with two columns: (sentence, hash(catid))
+    """
+
+    labels = ClassLabel(names=catlist)
+
+    features = Features({'sentence': Value('string'),
+                         'class': labels})
+
+    info = DatasetInfo(
+        description=f"(sentence, UCS CatID) pairs gathered by the "
+        "ucsinfer tool on {}", features=features)
+
+    items: list[dict] = []
+    for obj in records:
+        items += [{'sentence': obj[0], 'class': obj[1]}]
+
+    whole = Dataset.from_list(items, features=features, info=info)
+
+    split_set = whole.train_test_split(0.2)
+    test_eval_set = split_set['test'].train_test_split(0.5)
+
+    return DatasetDict({
+        'train': split_set['train'],
+        'test': test_eval_set['train'],
+        'eval': test_eval_set['test']
+    })
+
+
+# def build_sentence_anchor_dataset() -> Dataset:
+#     """
+#     Create a new dataset for `records` which contains (sentence, anchor) pairs.
+#     """
+#     pass
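`build_sentence_class_dataset` produces exactly the (sentence, class) shape that the classification-style sentence-transformers losses expect (see the loss-overview URL in the file above), and TODO.md earmarks BatchAllTripletLoss for the still-stubbed `finetune` command. A hedged sketch of that wiring, assuming the sentence-transformers v3 trainer API and that the label column must be renamed from 'class' to 'label'; the output path is hypothetical:

    from datasets import load_from_disk
    from sentence_transformers import (SentenceTransformer,
                                       SentenceTransformerTrainer, losses)

    ds = load_from_disk("dataset/")

    model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
    loss = losses.BatchAllTripletLoss(model)

    # The trainer treats a column named 'label' as the class label and the
    # remaining column ('sentence') as the text input.
    train = ds["train"].rename_column("class", "label")
    test = ds["test"].rename_column("class", "label")

    trainer = SentenceTransformerTrainer(model=model, train_dataset=train,
                                         eval_dataset=test, loss=loss)
    trainer.train()
    model.save("finetuned-ucsinfer-model")  # hypothetical output path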
ucsinfer/inference.py

@@ -34,7 +34,7 @@ class Ucs(NamedTuple):
                    explanations=d['Explanations'], synonymns=d['Synonyms'])


-def load_ucs() -> list[Ucs]:
+def load_ucs(full_ucs: bool = True) -> list[Ucs]:
     FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
     cats = []
     ucs_defs = os.path.join(FILE_ROOT_DIR, 'ucs-community', 'json',
@@ -43,7 +43,13 @@ def load_ucs() -> list[Ucs]:
     with open(ucs_defs, 'r') as f:
         cats = json.load(f)

-    return [Ucs.from_dict(cat) for cat in cats]
+    ucs = [Ucs.from_dict(cat) for cat in cats]
+
+    if full_ucs:
+        return ucs
+    else:
+        return [cat for cat in ucs if \
+                cat.category not in ['FOLEY', 'ARCHIVED']]


 class InferenceContext:
@@ -54,8 +60,11 @@ class InferenceContext:
     model: SentenceTransformer
     model_name: str

-    def __init__(self, model_name: str, use_cached_model: bool = True):
+    def __init__(self, model_name: str, use_cached_model: bool = True,
+                 use_full_ucs: bool = False):
         self.model_name = model_name
+        self.use_full_ucs = use_full_ucs
+
         cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer",
                                                 'Squad 51')

@@ -75,7 +84,7 @@ class InferenceContext:

     @cached_property
     def catlist(self) -> list[Ucs]:
-        return load_ucs()
+        return load_ucs(full_ucs=self.use_full_ucs)

     @cached_property
     def embeddings(self) -> list[dict]:
@@ -94,7 +103,12 @@ class InferenceContext:
         else:
             print(f"Calculating embeddings for model {self.model_name}...")

-            for cat_defn in self.catlist:
+            # we need to calculate the embeddings for all cats, not just the
+            # ones we're loading for this run
+
+            full_catlist = load_ucs(full_ucs=True)
+
+            for cat_defn in full_catlist:
                 embeddings += [{
                     'CatID': cat_defn.catid,
                     'Embedding': self._encode_category(cat_defn)
@@ -104,7 +118,9 @@ class InferenceContext:
             with open(embedding_cache_path, 'wb') as g:
                 pickle.dump(embeddings, g)

-        return embeddings
+        whitelisted_cats = [cat.catid for cat in self.catlist]
+
+        return [e for e in embeddings if e['CatID'] in whitelisted_cats]

     def _encode_category(self, cat: Ucs) -> np.ndarray:
         sentence_components = [cat.explanations,
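For context on how these cached vectors get used: each entry is a {'CatID', 'Embedding'} dict, so a ranked lookup reduces to cosine similarity between one encoded description and the cached matrix. The following is an illustrative sketch only, not the project's actual `classify_text_ranked` implementation:

    import numpy as np
    from sentence_transformers import SentenceTransformer, util

    def rank_categories(text: str, model: SentenceTransformer,
                        embeddings: list[dict], limit: int = 10) -> list[str]:
        """Return the `limit` CatIDs whose cached embeddings sit closest to `text`."""
        cat_ids = [e['CatID'] for e in embeddings]
        matrix = np.stack([e['Embedding'] for e in embeddings])

        query = model.encode(text)                  # embedding of the description
        scores = util.cos_sim(query, matrix)[0]     # one similarity score per category

        best = np.argsort(-scores.numpy())[:limit]  # highest similarity first
        return [cat_ids[i] for i in best]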
ucsinfer/recommend.py

@@ -4,8 +4,10 @@ from re import match

 from .inference import InferenceContext

+from tabulate import tabulate
+
 def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
-                         interactive_rename: bool):
+                         interactive_rename: bool, recommend_limit=10):
     """
     Print recommendations interactively.

@@ -17,18 +19,19 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
     `print_recommendation` should be called again with this argument.
     - if retval[2] is a str, this is the catid the user has selected.
     """
-    recs = ctx.classify_text_ranked(text)
+    recs = ctx.classify_text_ranked(text, limit=recommend_limit)
     print("----------")
     if path:
         print(f"Path: {path}")

     print(f"Text: {text or '<None>'}")

     for i, r in enumerate(recs):
         cat, subcat, _ = ctx.lookup_category(r)
         print(f"- {i}: {r} ({cat}-{subcat})")

     if interactive_rename and path is not None:
-        response = input("#, t [text], ?, q > ")
+        response = input("(n#), t, c, b, ?, q > ")
+
         if m := match(r'^([0-9]+)', response):
             selection = int(m.group(1))
@@ -43,12 +46,27 @@ def print_recommendation(path: str | None, text: str, ctx: InferenceContext,
             text = m.group(1)
             return True, text, None

+        elif m := match(r'^c (.+)', response):
+            return True, None, m.group(1)
+
+        elif m := match(r'^b (.+)', response):
+            expt = []
+            for cat in ctx.catlist:
+                if cat.catid.startswith(m.group(1)):
+                    expt.append([f"{cat.catid}: ({cat.category}-{cat.subcategory})",
+                                 cat.explanations])
+
+            print(tabulate(expt, maxcolwidths=80))
+            return True, text, None
+
         elif response.startswith("?"):
             print("""
 Choices:
     - Enter recommendation number to rename file,
-    - "t [text]" to search for new recommendations based on [text]
+    - "t <text>" to search for new recommendations based on <text>
     - "p" re-use the last selected cat-id
+    - "c <cat>" to type in a category by hand
+    - "b <cat>" browse category list for categories starting with <cat>
     - "?" for this message
     - "q" to quit
     - or any other key to skip this file and continue to next file
@@ -56,6 +74,8 @@ Choices:
             return True, text, None
         elif response.startswith('q'):
             return (False, None, None)
+        else:
+            print()
     else:
         return None

ucsinfer/util.py

@@ -9,10 +9,7 @@ def ffmpeg_description(path: str) -> Optional[str]:
     result = subprocess.run(['ffprobe', '-show_format', '-of',
                              'json', path], capture_output=True)

-    try:
-        result.check_returncode()
-    except:
-        return None
+    result.check_returncode()

     stream = json.loads(result.stdout)
     fmt = stream.get("format", None)