Jamie Hardt
2025-08-26 17:14:56 -07:00
parent 3d67623d77
commit 5ea64d089f
3 changed files with 57 additions and 54 deletions

View File

@@ -52,7 +52,7 @@ def gather(paths, outfile):
print(f"Found {len(scan_list)} files to process.") print(f"Found {len(scan_list)} files to process.")
for pair in tqdm.tqdm(scan_list, unit='files',file=sys.stderr): for pair in tqdm.tqdm(scan_list, unit='files', file=sys.stderr):
if desc := ffmpeg_description(pair[1]): if desc := ffmpeg_description(pair[1]):
table.writerow([pair[0], desc]) table.writerow([pair[0], desc])
@@ -107,7 +107,7 @@ def evaluate(dataset, offset, limit):
    miss_counts = []
    for cat in cats:
-        miss_counts.append((cat, len([x for x in results \
+        miss_counts.append((cat, len([x for x in results
                            if x['catid'] == cat and x['result'] == 'MISS'])))
    miss_counts = sorted(miss_counts, key=lambda x: x[1])
@@ -131,7 +131,7 @@ def evaluate(dataset, offset, limit):
f"{float(miss_counts[-1][1])/float(total):.2%}"] f"{float(miss_counts[-1][1])/float(total):.2%}"]
] ]
print(tabulate(table, headers=['','n','pct'])) print(tabulate(table, headers=['', 'n', 'pct']))
if __name__ == '__main__': if __name__ == '__main__':
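
The edits in this file are purely cosmetic: spaces after commas and a dropped trailing backslash. The backslash is safe to drop because Python already continues lines implicitly inside an open bracket. A minimal standalone illustration of the two styles (the sample records are invented for the example):

results = [{'catid': 'CAT1', 'result': 'MISS'},
           {'catid': 'CAT1', 'result': 'HIT'}]

# No backslash needed: the open bracket of len([...]) keeps the
# expression going across lines.
misses = len([x for x in results
              if x['catid'] == 'CAT1' and x['result'] == 'MISS'])
print(misses)  # 1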

View File

@@ -11,6 +11,7 @@ import platformdirs
from sentence_transformers import SentenceTransformer
def classify_text_ranked(text, embeddings_list, model, limit=5):
    text_embedding = model.encode(text, convert_to_numpy=True)
    embeddings = np.array([info['Embedding'] for info in embeddings_list])
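
Only the first two lines of classify_text_ranked are visible in this hunk: the query text is encoded and the stored category embeddings are stacked into a NumPy array. A sketch of how such a ranked lookup is commonly finished, assuming cosine-similarity scoring and a top-`limit` cut; the helper name and return shape below are assumptions, not taken from this diff:

import numpy as np

def rank_by_cosine(text_embedding, embeddings, embeddings_list, limit=5):
    # Normalise both sides so a plain dot product equals cosine similarity.
    query = text_embedding / np.linalg.norm(text_embedding)
    matrix = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    scores = matrix @ query
    # Best-scoring entries first, truncated to `limit`.
    top = np.argsort(scores)[::-1][:limit]
    return [(embeddings_list[i], float(scores[i])) for i in top]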
@@ -32,6 +33,7 @@ class Ucs(NamedTuple):
                   subcategory=d['SubCategory'],
                   explanations=d['Explanations'], synonymns=d['Synonyms'])
def load_ucs() -> list[Ucs]:
    FILE_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
    cats = []
@@ -43,6 +45,7 @@ def load_ucs() -> list[Ucs]:
    return [Ucs.from_dict(cat) for cat in cats]
class InferenceContext:
    """
    Maintains caches and resources for UCS category inference.
@@ -108,9 +111,8 @@ class InferenceContext:
        :raises: StopIterator if CatId is not on the schedule
        """
        i = (
-            (x.category, x.subcategory, x.explanations) \
+            (x.category, x.subcategory, x.explanations)
            for x in self.catlist if x.catid == catid
        )
        return next(i)
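
Here again the parentheses of the generator expression already permit the line break, so the backslash was redundant. The lookup pattern itself, pulling the first matching tuple out of a generator with next(), raises StopIteration when nothing matches, which is what the docstring is referring to. A standalone sketch of the same pattern with invented data:

catlist = [('CAT1', 'Category One', 'SubOne', 'explanation text')]

def lookup(catid):
    i = (
        (category, subcategory, explanations)
        for (cid, category, subcategory, explanations) in catlist
        if cid == catid
    )
    return next(i)  # raises StopIteration if catid is not in catlist

print(lookup('CAT1'))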

View File

@@ -6,6 +6,7 @@ from re import match
from .inference import Ucs
def ffmpeg_description(path: str) -> Optional[str]:
    result = subprocess.run(['ffprobe', '-show_format', '-of',
                             'json', path], capture_output=True)
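
This hunk shows only the ffprobe invocation; the JSON it returns still has to be parsed to produce a description string. A hedged sketch of that parsing, assuming the metadata lives under format/tags and using 'description'/'comment' as guessed key names (none of this is taken from the diff):

import json
import subprocess
from typing import Optional

def ffmpeg_description_sketch(path: str) -> Optional[str]:
    result = subprocess.run(['ffprobe', '-show_format', '-of', 'json', path],
                            capture_output=True)
    if result.returncode != 0:
        return None
    info = json.loads(result.stdout)
    tags = info.get('format', {}).get('tags', {})
    # Tag names vary by container, so check a couple of common ones.
    return tags.get('description') or tags.get('comment')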