Compare commits
	
		
			2 Commits
		
	
	
		
			3d67623d77
			...
			fb56ca1dd4
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | fb56ca1dd4 | ||
|   | 5ea64d089f | ||
| @@ -52,7 +52,7 @@ def gather(paths, outfile): | |||||||
|  |  | ||||||
|     print(f"Found {len(scan_list)} files to process.") |     print(f"Found {len(scan_list)} files to process.") | ||||||
|  |  | ||||||
|     for pair in tqdm.tqdm(scan_list, unit='files',file=sys.stderr): |     for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr): | ||||||
|         if desc := ffmpeg_description(pair[1]): |         if desc := ffmpeg_description(pair[1]): | ||||||
|             table.writerow([pair[0], desc]) |             table.writerow([pair[0], desc]) | ||||||
|  |  | ||||||
| @@ -107,31 +107,32 @@ def evaluate(dataset, offset, limit): | |||||||
|  |  | ||||||
|     miss_counts = [] |     miss_counts = [] | ||||||
|     for cat in cats: |     for cat in cats: | ||||||
|         miss_counts.append((cat, len([x for x in results \ |         miss_counts.append( | ||||||
|                 if x['catid'] == cat and x['result'] == 'MISS']))) |             (cat, len([x for x in results | ||||||
|  |                        if x['catid'] == cat and x['result'] == 'MISS']))) | ||||||
|  |  | ||||||
|     miss_counts = sorted(miss_counts, key=lambda x: x[1]) |     miss_counts = sorted(miss_counts, key=lambda x: x[1]) | ||||||
|  |  | ||||||
|     print(f" === RESULTS === ") |     print(f" === RESULTS === ") | ||||||
|  |  | ||||||
|     table = [ |     table = [ | ||||||
|                 ["Total records in sample:", f"{total}"], |         ["Total records in sample:", f"{total}"], | ||||||
|                 ["Top Result:", f"{total_top}",  |         ["Top Result:", f"{total_top}", | ||||||
|                  f"{float(total_top)/float(total):.2%}"], |          f"{float(total_top)/float(total):.2%}"], | ||||||
|                 ["Top 5 Result:", f"{total_top_5}",  |         ["Top 5 Result:", f"{total_top_5}", | ||||||
|                     f"{float(total_top_5)/float(total):.2%}"], |          f"{float(total_top_5)/float(total):.2%}"], | ||||||
|                 ["Top 10 Result:", f"{total_top_10}",  |         ["Top 10 Result:", f"{total_top_10}", | ||||||
|                     f"{float(total_top_10)/float(total):.2%}"], |          f"{float(total_top_10)/float(total):.2%}"], | ||||||
|                 SEPARATING_LINE, |         SEPARATING_LINE, | ||||||
|                 ["UCS category count:", f"{len(ctx.catlist)}"], |         ["UCS category count:", f"{len(ctx.catlist)}"], | ||||||
|                 ["Total categories in sample:", f"{total_cats}",  |         ["Total categories in sample:", f"{total_cats}", | ||||||
|                     f"{float(total_cats)/float(len(ctx.catlist)):.2%}"], |          f"{float(total_cats)/float(len(ctx.catlist)):.2%}"], | ||||||
|                 [f"Most missed category ({miss_counts[-1][0]}):", |         [f"Most missed category ({miss_counts[-1][0]}):", | ||||||
|                  f"{miss_counts[-1][1]}",  |          f"{miss_counts[-1][1]}", | ||||||
|                  f"{float(miss_counts[-1][1])/float(total):.2%}"] |          f"{float(miss_counts[-1][1])/float(total):.2%}"] | ||||||
|             ] |     ] | ||||||
|  |  | ||||||
|     print(tabulate(table, headers=['','n','pct'])) |     print(tabulate(table, headers=['', 'n', 'pct'])) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   | |||||||
| @@ -11,6 +11,7 @@ import platformdirs | |||||||
|  |  | ||||||
| from sentence_transformers import SentenceTransformer | from sentence_transformers import SentenceTransformer | ||||||
|  |  | ||||||
|  |  | ||||||
| def classify_text_ranked(text, embeddings_list, model, limit=5): | def classify_text_ranked(text, embeddings_list, model, limit=5): | ||||||
|     text_embedding = model.encode(text, convert_to_numpy=True) |     text_embedding = model.encode(text, convert_to_numpy=True) | ||||||
|     embeddings = np.array([info['Embedding'] for info in embeddings_list]) |     embeddings = np.array([info['Embedding'] for info in embeddings_list]) | ||||||
| @@ -32,6 +33,7 @@ class Ucs(NamedTuple): | |||||||
|                    subcategory=d['SubCategory'], |                    subcategory=d['SubCategory'], | ||||||
|                    explanations=d['Explanations'], synonymns=d['Synonyms']) |                    explanations=d['Explanations'], synonymns=d['Synonyms']) | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_ucs() -> list[Ucs]: | def load_ucs() -> list[Ucs]: | ||||||
|     FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) |     FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) | ||||||
|     cats = [] |     cats = [] | ||||||
| @@ -43,6 +45,7 @@ def load_ucs() -> list[Ucs]: | |||||||
|  |  | ||||||
|     return [Ucs.from_dict(cat) for cat in cats] |     return [Ucs.from_dict(cat) for cat in cats] | ||||||
|  |  | ||||||
|  |  | ||||||
| class InferenceContext: | class InferenceContext: | ||||||
|     """ |     """ | ||||||
|     Maintains caches and resources for UCS category inference. |     Maintains caches and resources for UCS category inference. | ||||||
| @@ -72,9 +75,9 @@ class InferenceContext: | |||||||
|  |  | ||||||
|             for cat_defn in self.catlist: |             for cat_defn in self.catlist: | ||||||
|                 embeddings += [{ |                 embeddings += [{ | ||||||
|                         'CatID': cat_defn.catid, |                     'CatID': cat_defn.catid, | ||||||
|                         'Embedding': self._encode_category(cat_defn) |                     'Embedding': self._encode_category(cat_defn) | ||||||
|                                }] |                 }] | ||||||
|  |  | ||||||
|             os.makedirs(os.path.dirname(embedding_cache), exist_ok=True) |             os.makedirs(os.path.dirname(embedding_cache), exist_ok=True) | ||||||
|             with open(embedding_cache, 'wb') as g: |             with open(embedding_cache, 'wb') as g: | ||||||
| @@ -84,9 +87,9 @@ class InferenceContext: | |||||||
|  |  | ||||||
|     def _encode_category(self, cat: Ucs) -> np.ndarray: |     def _encode_category(self, cat: Ucs) -> np.ndarray: | ||||||
|         sentence_components = [cat.explanations, |         sentence_components = [cat.explanations, | ||||||
|                            cat.category,  |                                cat.category, | ||||||
|                            cat.subcategory |                                cat.subcategory | ||||||
|                            ] |                                ] | ||||||
|         sentence_components += cat.synonymns |         sentence_components += cat.synonymns | ||||||
|         sentence = ", ".join(sentence_components) |         sentence = ", ".join(sentence_components) | ||||||
|         return self.model.encode(sentence, convert_to_numpy=True) |         return self.model.encode(sentence, convert_to_numpy=True) | ||||||
| @@ -108,9 +111,8 @@ class InferenceContext: | |||||||
|         :raises: StopIterator if CatId is not on the schedule  |         :raises: StopIterator if CatId is not on the schedule  | ||||||
|         """ |         """ | ||||||
|         i = ( |         i = ( | ||||||
|                 (x.category, x.subcategory, x.explanations) \ |             (x.category, x.subcategory, x.explanations) | ||||||
|                         for x in self.catlist if x.catid == catid |             for x in self.catlist if x.catid == catid | ||||||
|                         ) |         ) | ||||||
|  |  | ||||||
|         return next(i) |         return next(i) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -6,6 +6,7 @@ from re import match | |||||||
|  |  | ||||||
| from .inference import Ucs | from .inference import Ucs | ||||||
|  |  | ||||||
|  |  | ||||||
| def ffmpeg_description(path: str) -> Optional[str]: | def ffmpeg_description(path: str) -> Optional[str]: | ||||||
|     result = subprocess.run(['ffprobe', '-show_format', '-of', |     result = subprocess.run(['ffprobe', '-show_format', '-of', | ||||||
|                              'json', path], capture_output=True) |                              'json', path], capture_output=True) | ||||||
| @@ -65,12 +66,14 @@ def build_ucs(components: UcsNameComponents, extension: str) -> str: | |||||||
|     """ |     """ | ||||||
|     Build a UCS filename |     Build a UCS filename | ||||||
|     """ |     """ | ||||||
|     assert components.validate(), "UcsNameComponents contains invalid characters" |     assert components.validate(), \ | ||||||
|  |             "UcsNameComponents contains invalid characters" | ||||||
|  |  | ||||||
|     return "" |     return "" | ||||||
|  |  | ||||||
|  |  | ||||||
| def parse_ucs(rootname: str, catid_list: list[str]) -> Optional[UcsNameComponents]: | def parse_ucs(rootname: str,  | ||||||
|  |               catid_list: list[str]) -> Optional[UcsNameComponents]: | ||||||
|     """ |     """ | ||||||
|     Parse the UCS components from a file name root. |     Parse the UCS components from a file name root. | ||||||
|  |  | ||||||
| @@ -79,11 +82,12 @@ def parse_ucs(rootname: str, catid_list: list[str]) -> Optional[UcsNameComponent | |||||||
|     :returns: the components, or `None` if the filename is not in UCS format |     :returns: the components, or `None` if the filename is not in UCS format | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     regexp1 = r"^(?P<CatID>[A-z]+)(-(?P<UserCat>[^_]+))?_((?P<VendorCat>[^-]+)-)?(?P<FXName>[^_]+)" |     regexp1 = r"^(?P<CatID>[A-z]+)(-(?P<UserCat>[^_]+))?_" | ||||||
|  |     regexp2 = r"((?P<VendorCat>[^-]+)-)?(?P<FXName>[^_]+)" | ||||||
|  |     regexp3 = r"(_(?P<CreatorID>[^_]+)(_(?P<SourceID>[^_]+)" | ||||||
|  |     regexp4 = r"(_(?P<UserData>[^.]+))?)?)?" | ||||||
|  |  | ||||||
|     regexp2 = r"(_(?P<CreatorID>[^_]+)(_(?P<SourceID>[^_]+)(_(?P<UserData>[^.]+))?)?)?" |     regexp = regexp1 + regexp2 + regexp3 + regexp4 | ||||||
|      |  | ||||||
|     regexp = regexp1 + regexp2  |  | ||||||
|  |  | ||||||
|     matches = match(regexp, rootname) |     matches = match(regexp, rootname) | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user