Restructured
Have removed suggestions for now; added dataset gathering and evaluation.
README.md (+28)
@@ -1,3 +1,31 @@
 # ucsinfer
 
 Universal Category System inference.
+
+## Running
+
+```sh
+python -m ucsinfer [command]
+```
+
+Pass `--help` to see a summary of subcommands and options.
+
+The subcommands available at this time are `gather` and `evaluate`.
+
+## Functions
+
+* recommend
+
+  Infer a UCS category for a text description.
+
+* gather
+
+  Scan files to capture existing text descriptions and UCS categories and save
+  as a dataset.
+
+* finetune
+
+  Fine-tune an existing sentence embedding model with training data.
+
+* evaluate
+
+  Use datasets to evaluate the performance of fine-tuning.
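A usage sketch of the workflow the new README describes, based on the options introduced in this commit (`gather --outfile` and the `evaluate` DATASET argument both default to `dataset.csv`; the library path is a placeholder):

```sh
# Scan a sound library for .wav/.flac files whose names carry UCS CatIDs
# and write CatID/description rows to dataset.csv
python -m ucsinfer gather --outfile dataset.csv /path/to/sound-library

# Score the base model's category predictions against the gathered dataset
python -m ucsinfer evaluate dataset.csv
```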
@@ -11,7 +11,9 @@ dependencies = [
     "sentence-transformers (>=5.0.0,<6.0.0)",
     "numpy (>=2.3.2,<3.0.0)",
     "tqdm (>=4.67.1,<5.0.0)",
-    "platformdirs (>=4.3.8,<5.0.0)"
+    "platformdirs (>=4.3.8,<5.0.0)",
+    "click (>=8.2.1,<9.0.0)",
+    "tabulate (>=0.9.0,<0.10.0)"
 ]
@@ -1,199 +1,137 @@
 import os
-import json
 import sys
-import subprocess
-import cmd
+import csv
 
-from typing import Optional, IO
-
-from .inference import InferenceContext
-
 from sentence_transformers import SentenceTransformer
+import tqdm
+import click
+from tabulate import tabulate, SEPARATING_LINE
+
+from .inference import InferenceContext, load_ucs
+from .util import ffmpeg_description, parse_ucs
 
 
-def description(path: str) -> Optional[str]:
-    result = subprocess.run(['ffprobe', '-show_format', '-of',
-                             'json', path], capture_output=True)
-
-    try:
-        result.check_returncode()
-    except:
-        return None
-
-    stream = json.loads(result.stdout)
-    fmt = stream.get("format", None)
-    if fmt:
-        tags = fmt.get("tags", None)
-        if tags:
-            return tags.get("comment", None)
-
-
-def recommend_category(ctx: InferenceContext, path) -> tuple[str, list]:
-    """
-    Get a text description of the file at `path` and a list of UCS cat IDs
-    """
-    desc = description(path)
-    if desc is None:
-        desc = os.path.basename(path)
-
-    return desc, ctx.classify_text_ranked(desc)
-
-
-from shutil import get_terminal_size
-
-
-class Commands(cmd.Cmd):
-    ctx: InferenceContext
-
-    def __init__(self, completekey: str = "tab", stdin: IO[str] | None = None,
-                 stdout: IO[str] | None = None) -> None:
-        super().__init__(completekey, stdin, stdout)
-        self.file_list = []
-        self.catlist: list = []
-        self.rec_list = []
-        self.history = []
-
-    def preloop(self) -> None:
-        self.file_cursor = 0
-        self.update_prompt()
-        self.setup_for_file()
-        return super().preloop()
-
-    def default(self, line: str):
-        try:
-            sel = int(line)
-            self.onecmd(f"use {sel}")
-
-        except ValueError:
-            return super().default(line)
-
-        return super().default(line)
-
-    def precmd(self, line: str):
-        try:
-            rec = int(line)
-            if rec < len(self.rec_list):
-                pass
-            else:
-                pass
-
-        except ValueError:
-            pass
-
-        finally:
-            return super().precmd(line)
-
-    def postcmd(self, stop: bool, line: str) -> bool:
-        if not stop:
-            self.update_prompt()
-            self.setup_for_file()
-        return super().postcmd(stop, line)
-
-    def update_prompt(self):
-        self.prompt = f"(ucsinfer:{self.file_cursor}/{len(self.file_list)}) "
-
-    def setup_for_file(self):
-        if len(self.file_list) == 0:
-            print(" >> NO FILES!")
-        else:
-            file = self.file_list[self.file_cursor]
-            desc, recs = recommend_category(self.ctx, file)
-            self.onecmd('file')
-            print(f" >> {desc}")
-            self.print_recommendations(recs)
-
-    def print_recommendations(self, top_recs):
-        self.rec_list = []
-
-        cols, _ = get_terminal_size((80,20))
-
-        def print_one_rec(index, rec):
-            cat, subcat, exp = self.ctx.lookup_category(rec)
-            line = f" [{index:2}] {rec} - {cat} / {subcat} - {exp}"
-            if len(line) > cols - 3:
-                line = line[0:cols - 3] + "..."
-            print(line)
-
-        print("Suggested from description:")
-        for rec in top_recs:
-            print_one_rec(len(self.rec_list), rec)
-            self.rec_list.append(rec)
-
-        if len(self.history) > 0:
-            print("History:")
-            for rec in self.history:
-                print_one_rec(len(self.rec_list), rec)
-                self.rec_list.append(rec)
-
-    def do_about(self, line: str):
-        'Print information about recommendation NUMBER'
-        try:
-            picked = int(line)
-            if picked < len(self.rec_list):
-                cat, subcat, exp = \
-                    self.ctx.lookup_category(self.rec_list[picked])
-
-                print(f" CatID: {self.rec_list[picked]}")
-                print(f" Category: {cat}")
-                print(f" SubCategory: {subcat}")
-                print(f" Explanation: {exp}")
-
-        except ValueError:
-            print(f" *** Value \"{line}\" not recognized")
-
-    def do_use(self, line: str):
-        """Apply recomendation NUMBER to the current file and advance to the
-        next one"""
-
-        try:
-            picked = int(line)
-            print(f" :: Using {self.rec_list[picked]}")
-            self.do_next("")
-        except ValueError:
-            print(" *** Value \"{line}\" not recognized")
-
-    def do_file(self, _):
-        'Print info about the current file'
-        print("---")
-        if self.file_cursor < len(self.file_list):
-            path = self.file_list[self.file_cursor]
-            f = os.path.basename(path)
-            print(f" > {f}")
-        else:
-            print( " > No file")
-
-    def do_ls(self, _):
-        'Print list of all files in the buffer'
-        for file in self.file_list[self.file_cursor:] + \
-                self.file_list[0:self.file_cursor]:
-            f = os.path.basename(file)
-            print(f" > {f}")
-
-    def do_next(self, _):
-        'go to next file'
-        self.file_cursor += 1
-        self.file_cursor = self.file_cursor % len(self.file_list)
-        self.setup_for_file()
-
-    def do_prev(self, _):
-        'go to previous file'
-        self.file_cursor -= 1
-        self.file_cursor = self.file_cursor % len(self.file_list)
-        self.setup_for_file()
-
-    def do_bye(self, _):
-        'exit the program'
-        print("Exiting...")
-        return True
-
-
-def main():
-    model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
-
-    com = Commands()
-    com.file_list = sys.argv[1:]
-    com.ctx = InferenceContext(model=model)
-
-    com.cmdloop()
+@click.group()
+@click.option('--verbose', flag_value='verbose', help='Verbose output')
+def ucsinfer(verbose: bool):
+    pass
+
+
+@ucsinfer.command('recommend')
+def recommend():
+    """
+    Infer a UCS category for a text description
+    """
+    pass
+
+
+@ucsinfer.command('gather')
+@click.option('--outfile', type=click.File(mode='w', encoding='utf8'),
+              default='dataset.csv', show_default=True)
+@click.argument('paths', nargs=-1)
+def gather(paths, outfile):
+    """
+    Scan files to build a training dataset at PATH
+    """
+    types = ['.wav', '.flac']
+    table = csv.writer(outfile)
+    print(f"Loading category list...")
+    catid_list = [cat.catid for cat in load_ucs()]
+
+    scan_list = []
+    for path in paths:
+        print(f"Scanning directory {path}...", file=sys.stdout)
+        for dirpath, _, filenames in os.walk(path):
+            for filename in filenames:
+                root, ext = os.path.splitext(filename)
+                if ext in types and \
+                        (ucs_components := parse_ucs(root, catid_list)) and \
+                        not filename.startswith("._"):
+                    scan_list.append((ucs_components.cat_id,
+                                      os.path.join(dirpath, filename)))
+
+    print(f"Found {len(scan_list)} files to process.")
+
+    for pair in tqdm.tqdm(scan_list, unit='files',file=sys.stderr):
+        if desc := ffmpeg_description(pair[1]):
+            table.writerow([pair[0], desc])
+
+
+@ucsinfer.command('finetune')
+def finetune():
+    """
+    Fine-tune a model with training data
+    """
+    pass
+
+
+@ucsinfer.command('evaluate')
+@click.option('--offset', type=int, default=0)
+@click.option('--limit', type=int, default=-1)
+@click.argument('dataset', type=click.File('r', encoding='utf8'),
+                default='dataset.csv')
+def evaluate(dataset, offset, limit):
+    """
+    Use datasets to evaluate model performance
+    """
+    model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
+    ctx = InferenceContext(model)
+    reader = csv.reader(dataset)
+
+    results = []
+    for i, row in enumerate(tqdm.tqdm(reader)):
+        if i < offset:
+            continue
+
+        if limit > 0 and i >= limit + offset:
+            break
+
+        cat_id, description = row
+        guesses = ctx.classify_text_ranked(description, limit=10)
+        if cat_id == guesses[0]:
+            results.append({'catid': cat_id, 'result': "TOP"})
+        elif cat_id in guesses[0:5]:
+            results.append({'catid': cat_id, 'result': "TOP_5"})
+        elif cat_id in guesses:
+            results.append({'catid': cat_id, 'result': "TOP_10"})
+        else:
+            results.append({'catid': cat_id, 'result': "MISS"})
+
+    total = len(results)
+    total_top = len([x for x in results if x['result'] == 'TOP'])
+    total_top_5 = len([x for x in results if x['result'] == 'TOP_5'])
+    total_top_10 = len([x for x in results if x['result'] == 'TOP_10'])
+
+    cats = set([x['catid'] for x in results])
+    total_cats = len(cats)
+
+    miss_counts = []
+    for cat in cats:
+        miss_counts.append((cat, len([x for x in results \
+            if x['catid'] == cat and x['result'] == 'MISS'])))
+
+    miss_counts = sorted(miss_counts, key=lambda x: x[1])
+
+    print(f" === RESULTS === ")
+
+    table = [
+        ["Total records in sample:", f"{total}"],
+        ["Top Result:", f"{total_top}",
+         f"{float(total_top)/float(total):.2%}"],
+        ["Top 5 Result:", f"{total_top_5}",
+         f"{float(total_top_5)/float(total):.2%}"],
+        ["Top 10 Result:", f"{total_top_10}",
+         f"{float(total_top_10)/float(total):.2%}"],
+        SEPARATING_LINE,
+        ["UCS category count:", f"{len(ctx.catlist)}"],
+        ["Total categories in sample:", f"{total_cats}",
+         f"{float(total_cats)/float(len(ctx.catlist)):.2%}"],
+        [f"Most missed category ({miss_counts[-1][0]}):",
+         f"{miss_counts[-1][1]}",
+         f"{float(miss_counts[-1][1])/float(total):.2%}"]
+    ]
+
+    print(tabulate(table, headers=['','n','pct']))
 
 
 if __name__ == '__main__':
@@ -202,4 +140,4 @@ if __name__ == '__main__':
     import warnings
     warnings.simplefilter(action='ignore', category=FutureWarning)
 
-    main()
+    ucsinfer()
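For larger datasets, the `--offset` and `--limit` options added to `evaluate` above can restrict scoring to a slice of the gathered CSV; a sketch (the row numbers are arbitrary placeholders):

```sh
# Score only rows 1000-1999 of the gathered dataset
python -m ucsinfer evaluate --offset 1000 --limit 1000 dataset.csv
```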
@@ -32,6 +32,17 @@ class Ucs(NamedTuple):
                    subcategory=d['SubCategory'],
                    explanations=d['Explanations'], synonymns=d['Synonyms'])
 
 
+def load_ucs() -> list[Ucs]:
+    FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+    cats = []
+    ucs_defs = os.path.join(FILE_ROOT_DIR, 'ucs-community', 'json',
+                            'en.json')
+
+    with open(ucs_defs, 'r') as f:
+        cats = json.load(f)
+
+    return [Ucs.from_dict(cat) for cat in cats]
+
+
 class InferenceContext:
     """
     Maintains caches and resources for UCS category inference.
@@ -44,15 +55,7 @@ class InferenceContext:
 
     @cached_property
     def catlist(self) -> list[Ucs]:
-        FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-        cats = []
-        ucs_defs = os.path.join(FILE_ROOT_DIR, 'ucs-community', 'json',
-                                'en.json')
-
-        with open(ucs_defs, 'r') as f:
-            cats = json.load(f)
-
-        return [Ucs.from_dict(cat) for cat in cats]
+        return load_ucs()
 
     @cached_property
     def embeddings(self) -> list[dict]:
ucsinfer/util.py (new file, +58)
@@ -0,0 +1,58 @@
+import subprocess
+import json
+from typing import NamedTuple, Optional
+from re import match
+
+
+from .inference import Ucs
+
+def ffmpeg_description(path: str) -> Optional[str]:
+    result = subprocess.run(['ffprobe', '-show_format', '-of',
+                             'json', path], capture_output=True)
+
+    try:
+        result.check_returncode()
+    except:
+        return None
+
+    stream = json.loads(result.stdout)
+    fmt = stream.get("format", None)
+    if fmt:
+        tags = fmt.get("tags", None)
+        if tags:
+            return tags.get("comment", None)
+
+
+class UcsNameComponents(NamedTuple):
+    cat_id: str
+    user_cat: str | None
+    vendor_cat: str | None
+    fx_name: str
+    creator: str | None
+    source: str | None
+    user_data: str | None
+
+
+def parse_ucs(basename: str, catid_list: list[str]) -> Optional[UcsNameComponents]:
+
+    regexp1 = r"^(?P<CatID>[A-z]+)(-(?P<UserCat>[^_]+))?_((?P<VendorCat>[^-]+)-)?(?P<FXName>[^_]+)"
+
+    regexp2 = r"(_(?P<CreatorID>[^_]+)(_(?P<SourceID>[^_]+)(_(?P<UserData>[^.]+))?)?)?"
+
+    regexp = regexp1 + regexp2
+
+    matches = match(regexp, basename)
+
+    if matches is None:
+        return None
+
+    if matches.group('CatID') not in catid_list:
+        return None
+
+    return UcsNameComponents(cat_id=matches.group('CatID'),
+                             user_cat=matches.group('UserCat'),
+                             vendor_cat=matches.group('VendorCat'),
+                             fx_name=matches.group('FXName'),
+                             creator=matches.group('CreatorID'),
+                             source=matches.group('SourceID'),
+                             user_data=matches.group('UserData'))