Compare commits

Comparing `3d67623d77...fb56ca1dd4`: 2 commits

- `fb56ca1dd4`
- `5ea64d089f`
`ucsinfer` CLI module:

```diff
@@ -22,7 +22,7 @@ def recommend():
     """
     Infer a UCS category for a text description
     """
     pass
 
 
 @ucsinfer.command('gather')
@@ -36,7 +36,7 @@ def gather(paths, outfile):
     types = ['.wav', '.flac']
     table = csv.writer(outfile)
     print(f"Loading category list...")
     catid_list = [cat.catid for cat in load_ucs()]
 
     scan_list = []
     for path in paths:
@@ -47,12 +47,12 @@ def gather(paths, outfile):
             if ext in types and \
                     (ucs_components := parse_ucs(root, catid_list)) and \
                     not filename.startswith("._"):
                 scan_list.append((ucs_components.cat_id,
                                   os.path.join(dirpath, filename)))
 
     print(f"Found {len(scan_list)} files to process.")
 
-    for pair in tqdm.tqdm(scan_list, unit='files',file=sys.stderr):
+    for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
         if desc := ffmpeg_description(pair[1]):
             table.writerow([pair[0], desc])
 
@@ -62,13 +62,13 @@ def finetune():
     """
     Fine-tune a model with training data
     """
     pass
 
 
 @ucsinfer.command('evaluate')
 @click.option('--offset', type=int, default=0)
 @click.option('--limit', type=int, default=-1)
 @click.argument('dataset', type=click.File('r', encoding='utf8'),
                 default='dataset.csv')
 def evaluate(dataset, offset, limit):
     """
@@ -82,7 +82,7 @@ def evaluate(dataset, offset, limit):
     for i, row in enumerate(tqdm.tqdm(reader)):
         if i < offset:
             continue
 
         if limit > 0 and i >= limit + offset:
             break
 
@@ -107,33 +107,34 @@ def evaluate(dataset, offset, limit):
 
     miss_counts = []
     for cat in cats:
-        miss_counts.append((cat, len([x for x in results \
-            if x['catid'] == cat and x['result'] == 'MISS'])))
+        miss_counts.append(
+            (cat, len([x for x in results
+                if x['catid'] == cat and x['result'] == 'MISS'])))
 
     miss_counts = sorted(miss_counts, key=lambda x: x[1])
 
     print(f" === RESULTS === ")
 
     table = [
         ["Total records in sample:", f"{total}"],
         ["Top Result:", f"{total_top}",
          f"{float(total_top)/float(total):.2%}"],
         ["Top 5 Result:", f"{total_top_5}",
          f"{float(total_top_5)/float(total):.2%}"],
         ["Top 10 Result:", f"{total_top_10}",
          f"{float(total_top_10)/float(total):.2%}"],
         SEPARATING_LINE,
         ["UCS category count:", f"{len(ctx.catlist)}"],
         ["Total categories in sample:", f"{total_cats}",
          f"{float(total_cats)/float(len(ctx.catlist)):.2%}"],
         [f"Most missed category ({miss_counts[-1][0]}):",
          f"{miss_counts[-1][1]}",
          f"{float(miss_counts[-1][1])/float(total):.2%}"]
     ]
 
+    print(tabulate(table, headers=['', 'n', 'pct']))
 
-    print(tabulate(table, headers=['','n','pct']))
 
 
 if __name__ == '__main__':
     os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 
```
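The one noteworthy detail in `gather`, beyond the comma-spacing fix on the `tqdm` call, is that the progress bar is pointed at `sys.stderr` while the CSV rows go to `outfile`, so redirecting stdout captures clean CSV with the bar still visible on the terminal. A minimal sketch of the pattern, using hypothetical stand-in data rather than the tool's real scan results:

```python
import csv
import sys

import tqdm

# Hypothetical stand-in scan results; the real command builds scan_list
# by walking directories and parsing UCS filenames
scan_list = [("BIRDSong", "/tmp/a.wav"), ("AMBForst", "/tmp/b.flac")]

table = csv.writer(sys.stdout)  # data rows are written to stdout
for catid, path in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
    table.writerow([catid, path])  # the progress bar renders on stderr
```

Run as `python sketch.py > dataset.csv` and the bar still animates while the file receives only CSV.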
Inference module (defines `Ucs` and `InferenceContext`; imported elsewhere as `.inference`):

```diff
@@ -11,6 +11,7 @@ import platformdirs
 
 from sentence_transformers import SentenceTransformer
 
+
 def classify_text_ranked(text, embeddings_list, model, limit=5):
     text_embedding = model.encode(text, convert_to_numpy=True)
     embeddings = np.array([info['Embedding'] for info in embeddings_list])
@@ -23,15 +24,16 @@ class Ucs(NamedTuple):
     catid: str
     category: str
     subcategory: str
     explanations: str
     synonymns: list[str]
 
     @classmethod
     def from_dict(cls, d: dict):
         return Ucs(catid=d['CatID'], category=d['Category'],
                    subcategory=d['SubCategory'],
                    explanations=d['Explanations'], synonymns=d['Synonyms'])
 
+
 def load_ucs() -> list[Ucs]:
     FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
     cats = []
@@ -43,6 +45,7 @@ def load_ucs() -> list[Ucs]:
 
     return [Ucs.from_dict(cat) for cat in cats]
 
+
 class InferenceContext:
     """
     Maintains caches and resources for UCS category inference.
@@ -72,9 +75,9 @@ class InferenceContext:
 
         for cat_defn in self.catlist:
             embeddings += [{
                 'CatID': cat_defn.catid,
                 'Embedding': self._encode_category(cat_defn)
             }]
 
         os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
         with open(embedding_cache, 'wb') as g:
@@ -83,10 +86,10 @@ class InferenceContext:
         return embeddings
 
     def _encode_category(self, cat: Ucs) -> np.ndarray:
         sentence_components = [cat.explanations,
                                cat.category,
                                cat.subcategory
                                ]
         sentence_components += cat.synonymns
         sentence = ", ".join(sentence_components)
         return self.model.encode(sentence, convert_to_numpy=True)
@@ -104,13 +107,12 @@ class InferenceContext:
     def lookup_category(self, catid) -> tuple[str, str, str]:
         """
         Get the category, subcategory and explanations phrase for a `catid`
 
         :raises: StopIterator if CatId is not on the schedule
         """
         i = (
-            (x.category, x.subcategory, x.explanations) \
+            (x.category, x.subcategory, x.explanations)
             for x in self.catlist if x.catid == catid
         )
 
-        return next(i)
-
+        return next(i)
```
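For context on `classify_text_ranked`: the hunk shows the query being encoded with `model.encode(text, convert_to_numpy=True)` and the cached category embeddings being stacked with `np.array(...)`, but the scoring step falls outside the diff. A numpy-only sketch of one plausible ranking step, assuming cosine similarity as the metric; the name `rank_categories` and the toy 3-dimensional embeddings are illustrative, not from the repository:

```python
import numpy as np


def rank_categories(text_embedding: np.ndarray,
                    embeddings_list: list[dict],
                    limit: int = 5) -> list[str]:
    # Stack the per-category embeddings, as classify_text_ranked does
    embeddings = np.array([info['Embedding'] for info in embeddings_list])
    # Cosine similarity between the query and every category embedding
    sims = embeddings @ text_embedding / (
        np.linalg.norm(embeddings, axis=1) * np.linalg.norm(text_embedding))
    # Return the CatIDs of the `limit` highest scores, best first
    order = np.argsort(sims)[::-1][:limit]
    return [embeddings_list[i]['CatID'] for i in order]


# Toy embeddings; the real ones come from SentenceTransformer.encode()
cats = [{'CatID': 'AMBForst', 'Embedding': np.array([1.0, 0.0, 0.0])},
        {'CatID': 'BIRDSong', 'Embedding': np.array([0.0, 1.0, 0.0])}]
print(rank_categories(np.array([0.9, 0.1, 0.0]), cats, limit=2))
# -> ['AMBForst', 'BIRDSong']
```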
UCS filename and metadata utilities (imports `Ucs` from `.inference`):

```diff
@@ -1,20 +1,21 @@
 import subprocess
 import json
 from typing import NamedTuple, Optional
 from re import match
 
+
 from .inference import Ucs
 
 
 def ffmpeg_description(path: str) -> Optional[str]:
     result = subprocess.run(['ffprobe', '-show_format', '-of',
                              'json', path], capture_output=True)
 
     try:
         result.check_returncode()
     except:
         return None
 
     stream = json.loads(result.stdout)
     fmt = stream.get("format", None)
     if fmt:
@@ -28,10 +29,10 @@ class UcsNameComponents(NamedTuple):
     Components of a UCS filename
     """
     cat_id: str
     user_cat: str | None
     vendor_cat: str | None
     fx_name: str
     creator: str | None
     source: str | None
     user_data: str | None
 
@@ -43,7 +44,7 @@ class UcsNameComponents(NamedTuple):
             return False
 
         if self.user_cat and not match(r"[^\-_]+", self.user_cat):
             return False
 
         if self.vendor_cat and not match(r"[^\-_]+", self.vendor_cat):
             return False
@@ -52,7 +53,7 @@ class UcsNameComponents(NamedTuple):
             return False
 
         if self.creator and not match(r"[^_]+", self.creator):
             return False
 
         if self.source and not match(r"[^_]+", self.source):
             return False
@@ -65,25 +66,28 @@ def build_ucs(components: UcsNameComponents, extension: str) -> str:
     """
     Build a UCS filename
     """
-    assert components.validate(), "UcsNameComponents contains invalid characters"
+    assert components.validate(), \
+        "UcsNameComponents contains invalid characters"
 
     return ""
 
 
-def parse_ucs(rootname: str, catid_list: list[str]) -> Optional[UcsNameComponents]:
+def parse_ucs(rootname: str,
+              catid_list: list[str]) -> Optional[UcsNameComponents]:
     """
     Parse the UCS components from a file name root.
 
     :param rootname: filename root, the basename of the file without extension
     :param catid_list: a list of all UCS CatIDs
     :returns: the components, or `None` if the filename is not in UCS format
     """
 
-    regexp1 = r"^(?P<CatID>[A-z]+)(-(?P<UserCat>[^_]+))?_((?P<VendorCat>[^-]+)-)?(?P<FXName>[^_]+)"
+    regexp1 = r"^(?P<CatID>[A-z]+)(-(?P<UserCat>[^_]+))?_"
+    regexp2 = r"((?P<VendorCat>[^-]+)-)?(?P<FXName>[^_]+)"
+    regexp3 = r"(_(?P<CreatorID>[^_]+)(_(?P<SourceID>[^_]+)"
+    regexp4 = r"(_(?P<UserData>[^.]+))?)?)?"
 
-    regexp2 = r"(_(?P<CreatorID>[^_]+)(_(?P<SourceID>[^_]+)(_(?P<UserData>[^.]+))?)?)?"
+    regexp = regexp1 + regexp2 + regexp3 + regexp4
 
-    regexp = regexp1 + regexp2
-
     matches = match(regexp, rootname)
 
```
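The `parse_ucs` rewrap deserves a closer look: the long pattern is split into four string fragments, and `regexp3` opens optional groups that only `regexp4` closes, so the fragments are not individually valid regular expressions; only their concatenation compiles. A short sketch of the assembled pattern against a hypothetical UCS-style filename root (not a file from the repository):

```python
from re import match

# Fragments exactly as in the patch; regexp3 opens optional groups
# that regexp4 closes, so only the concatenation is a valid pattern
regexp1 = r"^(?P<CatID>[A-z]+)(-(?P<UserCat>[^_]+))?_"
regexp2 = r"((?P<VendorCat>[^-]+)-)?(?P<FXName>[^_]+)"
regexp3 = r"(_(?P<CreatorID>[^_]+)(_(?P<SourceID>[^_]+)"
regexp4 = r"(_(?P<UserData>[^.]+))?)?)?"
regexp = regexp1 + regexp2 + regexp3 + regexp4

# Hypothetical filename root in CatID_FXName_CreatorID_SourceID form
m = match(regexp, "BIRDSong_Robin Chirping_JSmith_FieldRecorder")
print(m.groupdict())
# {'CatID': 'BIRDSong', 'UserCat': None, 'VendorCat': None,
#  'FXName': 'Robin Chirping', 'CreatorID': 'JSmith',
#  'SourceID': 'FieldRecorder', 'UserData': None}
```

One quirk the patch preserves: the character class `[A-z]` spans the ASCII range from `A` to `z`, which also admits `[`, `\`, `]`, `^`, `_`, and the backtick. The match above still isolates `BIRDSong` only because the pattern insists on a literal `_` right after the optional `-UserCat` part, forcing the greedy `[A-z]+` to backtrack.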