Compare commits
2 Commits
3d67623d77
...
fb56ca1dd4
Author | SHA1 | Date | |
---|---|---|---|
![]() |
fb56ca1dd4 | ||
![]() |
5ea64d089f |
@@ -52,7 +52,7 @@ def gather(paths, outfile):
|
||||
|
||||
print(f"Found {len(scan_list)} files to process.")
|
||||
|
||||
for pair in tqdm.tqdm(scan_list, unit='files',file=sys.stderr):
|
||||
for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
|
||||
if desc := ffmpeg_description(pair[1]):
|
||||
table.writerow([pair[0], desc])
|
||||
|
||||
@@ -107,31 +107,32 @@ def evaluate(dataset, offset, limit):
|
||||
|
||||
miss_counts = []
|
||||
for cat in cats:
|
||||
miss_counts.append((cat, len([x for x in results \
|
||||
if x['catid'] == cat and x['result'] == 'MISS'])))
|
||||
miss_counts.append(
|
||||
(cat, len([x for x in results
|
||||
if x['catid'] == cat and x['result'] == 'MISS'])))
|
||||
|
||||
miss_counts = sorted(miss_counts, key=lambda x: x[1])
|
||||
|
||||
print(f" === RESULTS === ")
|
||||
|
||||
table = [
|
||||
["Total records in sample:", f"{total}"],
|
||||
["Top Result:", f"{total_top}",
|
||||
f"{float(total_top)/float(total):.2%}"],
|
||||
["Top 5 Result:", f"{total_top_5}",
|
||||
f"{float(total_top_5)/float(total):.2%}"],
|
||||
["Top 10 Result:", f"{total_top_10}",
|
||||
f"{float(total_top_10)/float(total):.2%}"],
|
||||
SEPARATING_LINE,
|
||||
["UCS category count:", f"{len(ctx.catlist)}"],
|
||||
["Total categories in sample:", f"{total_cats}",
|
||||
f"{float(total_cats)/float(len(ctx.catlist)):.2%}"],
|
||||
[f"Most missed category ({miss_counts[-1][0]}):",
|
||||
f"{miss_counts[-1][1]}",
|
||||
f"{float(miss_counts[-1][1])/float(total):.2%}"]
|
||||
]
|
||||
["Total records in sample:", f"{total}"],
|
||||
["Top Result:", f"{total_top}",
|
||||
f"{float(total_top)/float(total):.2%}"],
|
||||
["Top 5 Result:", f"{total_top_5}",
|
||||
f"{float(total_top_5)/float(total):.2%}"],
|
||||
["Top 10 Result:", f"{total_top_10}",
|
||||
f"{float(total_top_10)/float(total):.2%}"],
|
||||
SEPARATING_LINE,
|
||||
["UCS category count:", f"{len(ctx.catlist)}"],
|
||||
["Total categories in sample:", f"{total_cats}",
|
||||
f"{float(total_cats)/float(len(ctx.catlist)):.2%}"],
|
||||
[f"Most missed category ({miss_counts[-1][0]}):",
|
||||
f"{miss_counts[-1][1]}",
|
||||
f"{float(miss_counts[-1][1])/float(total):.2%}"]
|
||||
]
|
||||
|
||||
print(tabulate(table, headers=['','n','pct']))
|
||||
print(tabulate(table, headers=['', 'n', 'pct']))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -11,6 +11,7 @@ import platformdirs
|
||||
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
def classify_text_ranked(text, embeddings_list, model, limit=5):
|
||||
text_embedding = model.encode(text, convert_to_numpy=True)
|
||||
embeddings = np.array([info['Embedding'] for info in embeddings_list])
|
||||
@@ -32,6 +33,7 @@ class Ucs(NamedTuple):
|
||||
subcategory=d['SubCategory'],
|
||||
explanations=d['Explanations'], synonymns=d['Synonyms'])
|
||||
|
||||
|
||||
def load_ucs() -> list[Ucs]:
|
||||
FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
cats = []
|
||||
@@ -43,6 +45,7 @@ def load_ucs() -> list[Ucs]:
|
||||
|
||||
return [Ucs.from_dict(cat) for cat in cats]
|
||||
|
||||
|
||||
class InferenceContext:
|
||||
"""
|
||||
Maintains caches and resources for UCS category inference.
|
||||
@@ -72,9 +75,9 @@ class InferenceContext:
|
||||
|
||||
for cat_defn in self.catlist:
|
||||
embeddings += [{
|
||||
'CatID': cat_defn.catid,
|
||||
'Embedding': self._encode_category(cat_defn)
|
||||
}]
|
||||
'CatID': cat_defn.catid,
|
||||
'Embedding': self._encode_category(cat_defn)
|
||||
}]
|
||||
|
||||
os.makedirs(os.path.dirname(embedding_cache), exist_ok=True)
|
||||
with open(embedding_cache, 'wb') as g:
|
||||
@@ -84,9 +87,9 @@ class InferenceContext:
|
||||
|
||||
def _encode_category(self, cat: Ucs) -> np.ndarray:
|
||||
sentence_components = [cat.explanations,
|
||||
cat.category,
|
||||
cat.subcategory
|
||||
]
|
||||
cat.category,
|
||||
cat.subcategory
|
||||
]
|
||||
sentence_components += cat.synonymns
|
||||
sentence = ", ".join(sentence_components)
|
||||
return self.model.encode(sentence, convert_to_numpy=True)
|
||||
@@ -108,9 +111,8 @@ class InferenceContext:
|
||||
:raises: StopIterator if CatId is not on the schedule
|
||||
"""
|
||||
i = (
|
||||
(x.category, x.subcategory, x.explanations) \
|
||||
for x in self.catlist if x.catid == catid
|
||||
)
|
||||
(x.category, x.subcategory, x.explanations)
|
||||
for x in self.catlist if x.catid == catid
|
||||
)
|
||||
|
||||
return next(i)
|
||||
|
||||
|
@@ -6,6 +6,7 @@ from re import match
|
||||
|
||||
from .inference import Ucs
|
||||
|
||||
|
||||
def ffmpeg_description(path: str) -> Optional[str]:
|
||||
result = subprocess.run(['ffprobe', '-show_format', '-of',
|
||||
'json', path], capture_output=True)
|
||||
@@ -65,12 +66,14 @@ def build_ucs(components: UcsNameComponents, extension: str) -> str:
|
||||
"""
|
||||
Build a UCS filename
|
||||
"""
|
||||
assert components.validate(), "UcsNameComponents contains invalid characters"
|
||||
assert components.validate(), \
|
||||
"UcsNameComponents contains invalid characters"
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def parse_ucs(rootname: str, catid_list: list[str]) -> Optional[UcsNameComponents]:
|
||||
def parse_ucs(rootname: str,
|
||||
catid_list: list[str]) -> Optional[UcsNameComponents]:
|
||||
"""
|
||||
Parse the UCS components from a file name root.
|
||||
|
||||
@@ -79,11 +82,12 @@ def parse_ucs(rootname: str, catid_list: list[str]) -> Optional[UcsNameComponent
|
||||
:returns: the components, or `None` if the filename is not in UCS format
|
||||
"""
|
||||
|
||||
regexp1 = r"^(?P<CatID>[A-z]+)(-(?P<UserCat>[^_]+))?_((?P<VendorCat>[^-]+)-)?(?P<FXName>[^_]+)"
|
||||
regexp1 = r"^(?P<CatID>[A-z]+)(-(?P<UserCat>[^_]+))?_"
|
||||
regexp2 = r"((?P<VendorCat>[^-]+)-)?(?P<FXName>[^_]+)"
|
||||
regexp3 = r"(_(?P<CreatorID>[^_]+)(_(?P<SourceID>[^_]+)"
|
||||
regexp4 = r"(_(?P<UserData>[^.]+))?)?)?"
|
||||
|
||||
regexp2 = r"(_(?P<CreatorID>[^_]+)(_(?P<SourceID>[^_]+)(_(?P<UserData>[^.]+))?)?)?"
|
||||
|
||||
regexp = regexp1 + regexp2
|
||||
regexp = regexp1 + regexp2 + regexp3 + regexp4
|
||||
|
||||
matches = match(regexp, rootname)
|
||||
|
||||
|
Reference in New Issue
Block a user