Compare commits
	
		
			3 Commits
		
	
	
		
			354d1d4e40
			...
			4f7b2a73cb
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 4f7b2a73cb | ||
|   | cc78291f1d | ||
|   | 7b3930ece0 | ||
| @@ -2,6 +2,8 @@ import os | ||||
| import sys | ||||
| import csv | ||||
|  | ||||
| from typing import Optional | ||||
|  | ||||
| from sentence_transformers import SentenceTransformer | ||||
| import tqdm | ||||
| import click | ||||
| @@ -11,19 +13,60 @@ from .inference import InferenceContext, load_ucs | ||||
| from .util import ffmpeg_description, parse_ucs | ||||
|  | ||||
|  | ||||
| @click.group() | ||||
| def recommend_text(text: str, ctx: InferenceContext): | ||||
|     return None | ||||
|  | ||||
| @click.group(epilog="For more information see " | ||||
|              "<https://git.squad51.us/jamie/ucsinfer>") | ||||
| # @click.option('--verbose', flag_value='verbose', help='Verbose output') | ||||
| def ucsinfer(): | ||||
|     """ | ||||
|     Tools for applying UCS categories to sounds using large-language Models  | ||||
|  | ||||
|     """ | ||||
|     pass | ||||
|  | ||||
|  | ||||
| @ucsinfer.command('recommend') | ||||
| def recommend(): | ||||
| @click.option('--text', type=Optional[str], default=None, | ||||
|               help="Recommend a category for given text instead of reading " | ||||
|               "from a file") | ||||
| @click.argument('files', nargs=-1) | ||||
| @click.option('--model', type=str, metavar="<model-name>",  | ||||
|               default="paraphrase-multilingual-mpnet-base-v2", | ||||
|               show_default=True,  | ||||
|               help="Select the sentence_transformer model to use") | ||||
| def recommend(text, paths, model): | ||||
|     """ | ||||
|     Infer a UCS category for a text description | ||||
|  | ||||
|     """ | ||||
|     pass | ||||
|     m = SentenceTransformer(model) | ||||
|     ctx = InferenceContext(m, model) | ||||
|  | ||||
|     recommendations = [] | ||||
|     if text is not None: | ||||
|         recommendations.append({ | ||||
|             "text": text, | ||||
|             "recommendations": recommend_text(text, ctx) | ||||
|             }) | ||||
|      | ||||
|     for path in paths: | ||||
|         text = ffmpeg_description(path) | ||||
|         if text: | ||||
|             recommendations.append({ | ||||
|                 "path":path, | ||||
|                 "text":text, | ||||
|                 "recommendations":recommend_text(text, ctx) | ||||
|                 }) | ||||
|         else: | ||||
|             recommendations.append({ | ||||
|                 "path":path,  | ||||
|                 "text":None, | ||||
|                 "recommendations":None | ||||
|                 }) | ||||
|                      | ||||
|  | ||||
|  | ||||
|  | ||||
| @ucsinfer.command('gather') | ||||
| @@ -77,13 +120,17 @@ def finetune(): | ||||
|  | ||||
|  | ||||
| @ucsinfer.command('evaluate') | ||||
| @click.option('--offset', type=int, default=0) | ||||
| @click.option('--limit', type=int, default=-1) | ||||
| @click.option('--offset', type=int, default=0, metavar="<int>",  | ||||
|               help='Skip this many records in the dataset before processing') | ||||
| @click.option('--limit', type=int, default=-1, metavar="<int>", | ||||
|               help='Process this many records and then exit') | ||||
| @click.option('--no-foley', 'no_foley', flag_value=True, default=False,  | ||||
|               help="Ignore any data in the set with FOLYProp or FOLYFeet " | ||||
|               "category") | ||||
| @click.option('--model', type=str,  | ||||
|               default="paraphrase-multilingual-mpnet-base-v2") | ||||
| @click.option('--model', type=str, metavar="<model-name>",  | ||||
|               default="paraphrase-multilingual-mpnet-base-v2", | ||||
|               show_default=True,  | ||||
|               help="Select the sentence_transformer model to use") | ||||
| @click.argument('dataset', type=click.File('r', encoding='utf8'), | ||||
|                 default='dataset.csv') | ||||
| def evaluate(dataset, offset, limit, model, no_foley): | ||||
| @@ -110,8 +157,17 @@ def evaluate(dataset, offset, limit, model, no_foley): | ||||
|     ctx = InferenceContext(m, model) | ||||
|     reader = csv.reader(dataset) | ||||
|  | ||||
|     print(f"Evaluating model {model}...") | ||||
|     results = [] | ||||
|     for i, row in enumerate(tqdm.tqdm(reader)): | ||||
|      | ||||
|     if offset > 0: | ||||
|         print(f"Skipping {offset} records...") | ||||
|  | ||||
|     if limit > 0: | ||||
|         print(f"Will only evaluate {limit} records...") | ||||
|  | ||||
|     progress_bar = tqdm.tqdm(total=limit) | ||||
|     for i, row in enumerate(reader): | ||||
|         if i < offset: | ||||
|             continue | ||||
|  | ||||
| @@ -132,6 +188,8 @@ def evaluate(dataset, offset, limit, model, no_foley): | ||||
|         else: | ||||
|             results.append({'catid': cat_id, 'result': "MISS"}) | ||||
|          | ||||
|         progress_bar.update(1) | ||||
|  | ||||
|     total = len(results) | ||||
|     total_top = len([x for x in results if x['result'] == 'TOP']) | ||||
|     total_top_5 = len([x for x in results if x['result'] == 'TOP_5']) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user