Compare commits
	
		
			3 Commits
		
	
	
		
			354d1d4e40
			...
			4f7b2a73cb
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
|   | 4f7b2a73cb | ||
|   | cc78291f1d | ||
|   | 7b3930ece0 | 
| @@ -2,6 +2,8 @@ import os | |||||||
| import sys | import sys | ||||||
| import csv | import csv | ||||||
|  |  | ||||||
|  | from typing import Optional | ||||||
|  |  | ||||||
| from sentence_transformers import SentenceTransformer | from sentence_transformers import SentenceTransformer | ||||||
| import tqdm | import tqdm | ||||||
| import click | import click | ||||||
| @@ -11,19 +13,60 @@ from .inference import InferenceContext, load_ucs | |||||||
| from .util import ffmpeg_description, parse_ucs | from .util import ffmpeg_description, parse_ucs | ||||||
|  |  | ||||||
|  |  | ||||||
| @click.group() | def recommend_text(text: str, ctx: InferenceContext): | ||||||
|  |     return None | ||||||
|  |  | ||||||
|  | @click.group(epilog="For more information see " | ||||||
|  |              "<https://git.squad51.us/jamie/ucsinfer>") | ||||||
| # @click.option('--verbose', flag_value='verbose', help='Verbose output') | # @click.option('--verbose', flag_value='verbose', help='Verbose output') | ||||||
| def ucsinfer(): | def ucsinfer(): | ||||||
|  |     """ | ||||||
|  |     Tools for applying UCS categories to sounds using large-language Models  | ||||||
|  |  | ||||||
|  |     """ | ||||||
|     pass |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
| @ucsinfer.command('recommend') | @ucsinfer.command('recommend') | ||||||
| def recommend(): | @click.option('--text', type=Optional[str], default=None, | ||||||
|  |               help="Recommend a category for given text instead of reading " | ||||||
|  |               "from a file") | ||||||
|  | @click.argument('files', nargs=-1) | ||||||
|  | @click.option('--model', type=str, metavar="<model-name>",  | ||||||
|  |               default="paraphrase-multilingual-mpnet-base-v2", | ||||||
|  |               show_default=True,  | ||||||
|  |               help="Select the sentence_transformer model to use") | ||||||
|  | def recommend(text, paths, model): | ||||||
|     """ |     """ | ||||||
|     Infer a UCS category for a text description |     Infer a UCS category for a text description | ||||||
|  |  | ||||||
|     """ |     """ | ||||||
|     pass |     m = SentenceTransformer(model) | ||||||
|  |     ctx = InferenceContext(m, model) | ||||||
|  |  | ||||||
|  |     recommendations = [] | ||||||
|  |     if text is not None: | ||||||
|  |         recommendations.append({ | ||||||
|  |             "text": text, | ||||||
|  |             "recommendations": recommend_text(text, ctx) | ||||||
|  |             }) | ||||||
|  |      | ||||||
|  |     for path in paths: | ||||||
|  |         text = ffmpeg_description(path) | ||||||
|  |         if text: | ||||||
|  |             recommendations.append({ | ||||||
|  |                 "path":path, | ||||||
|  |                 "text":text, | ||||||
|  |                 "recommendations":recommend_text(text, ctx) | ||||||
|  |                 }) | ||||||
|  |         else: | ||||||
|  |             recommendations.append({ | ||||||
|  |                 "path":path,  | ||||||
|  |                 "text":None, | ||||||
|  |                 "recommendations":None | ||||||
|  |                 }) | ||||||
|  |                      | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @ucsinfer.command('gather') | @ucsinfer.command('gather') | ||||||
| @@ -77,13 +120,17 @@ def finetune(): | |||||||
|  |  | ||||||
|  |  | ||||||
| @ucsinfer.command('evaluate') | @ucsinfer.command('evaluate') | ||||||
| @click.option('--offset', type=int, default=0) | @click.option('--offset', type=int, default=0, metavar="<int>",  | ||||||
| @click.option('--limit', type=int, default=-1) |               help='Skip this many records in the dataset before processing') | ||||||
|  | @click.option('--limit', type=int, default=-1, metavar="<int>", | ||||||
|  |               help='Process this many records and then exit') | ||||||
| @click.option('--no-foley', 'no_foley', flag_value=True, default=False,  | @click.option('--no-foley', 'no_foley', flag_value=True, default=False,  | ||||||
|               help="Ignore any data in the set with FOLYProp or FOLYFeet " |               help="Ignore any data in the set with FOLYProp or FOLYFeet " | ||||||
|               "category") |               "category") | ||||||
| @click.option('--model', type=str,  | @click.option('--model', type=str, metavar="<model-name>",  | ||||||
|               default="paraphrase-multilingual-mpnet-base-v2") |               default="paraphrase-multilingual-mpnet-base-v2", | ||||||
|  |               show_default=True,  | ||||||
|  |               help="Select the sentence_transformer model to use") | ||||||
| @click.argument('dataset', type=click.File('r', encoding='utf8'), | @click.argument('dataset', type=click.File('r', encoding='utf8'), | ||||||
|                 default='dataset.csv') |                 default='dataset.csv') | ||||||
| def evaluate(dataset, offset, limit, model, no_foley): | def evaluate(dataset, offset, limit, model, no_foley): | ||||||
| @@ -110,8 +157,17 @@ def evaluate(dataset, offset, limit, model, no_foley): | |||||||
|     ctx = InferenceContext(m, model) |     ctx = InferenceContext(m, model) | ||||||
|     reader = csv.reader(dataset) |     reader = csv.reader(dataset) | ||||||
|  |  | ||||||
|  |     print(f"Evaluating model {model}...") | ||||||
|     results = [] |     results = [] | ||||||
|     for i, row in enumerate(tqdm.tqdm(reader)): |      | ||||||
|  |     if offset > 0: | ||||||
|  |         print(f"Skipping {offset} records...") | ||||||
|  |  | ||||||
|  |     if limit > 0: | ||||||
|  |         print(f"Will only evaluate {limit} records...") | ||||||
|  |  | ||||||
|  |     progress_bar = tqdm.tqdm(total=limit) | ||||||
|  |     for i, row in enumerate(reader): | ||||||
|         if i < offset: |         if i < offset: | ||||||
|             continue |             continue | ||||||
|  |  | ||||||
| @@ -131,6 +187,8 @@ def evaluate(dataset, offset, limit, model, no_foley): | |||||||
|             results.append({'catid': cat_id, 'result': "TOP_10"}) |             results.append({'catid': cat_id, 'result': "TOP_10"}) | ||||||
|         else: |         else: | ||||||
|             results.append({'catid': cat_id, 'result': "MISS"}) |             results.append({'catid': cat_id, 'result': "MISS"}) | ||||||
|  |          | ||||||
|  |         progress_bar.update(1) | ||||||
|  |  | ||||||
|     total = len(results) |     total = len(results) | ||||||
|     total_top = len([x for x in results if x['result'] == 'TOP']) |     total_top = len([x for x in results if x['result'] == 'TOP']) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user