Compare commits

3 Commits: f40a336893 ... 1c5461a6a9
| Author | SHA1 | Date |
|---|---|---|
| | 1c5461a6a9 | |
| | aea0bbb7be | |
| | c687e4614f | |

MODELS.rst (new file, +18 lines)
@@ -0,0 +1,18 @@
+Results for Model paraphrase-multilingual-mpnet-base-v2
+===================
+
+================================  ====  ======
+..                                   n  pct
+================================  ====  ======
+Total records in sample:          3445
+Top Result:                        469  13.61%
+Top 5 Result:                      519  15.07%
+Top 10 Result:                     513  14.89%
+================================  ====  ======
+UCS category count:                752
+Total categories in sample:        240  31.91%
+Most missed category (FOLYProp):  1057  30.68%
+================================  ====  ======
+
+
+
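
The new MODELS.rst above looks like a capture of the evaluate command's output now that it emits an RST-style table (see the `tablefmt='rst'` change below). A minimal sketch of how such a table is produced with python-tabulate, reusing the figures from MODELS.rst rather than recomputing them:

```python
from tabulate import tabulate  # python-tabulate

# Rows mirror the summary table built in the evaluate command; the values
# here are copied from MODELS.rst above for illustration only.
table = [
    ["Total records in sample:", "3445", ""],
    ["Top Result:", "469", "13.61%"],
    ["Top 5 Result:", "519", "15.07%"],
    ["Top 10 Result:", "513", "14.89%"],
]

# tablefmt='rst' renders a reStructuredText simple table, which is why the
# captured output pastes cleanly into an .rst file.
print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='rst'))
```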
@@ -12,8 +12,8 @@ from .util import ffmpeg_description, parse_ucs
 
 
 @click.group()
-@click.option('--verbose', flag_value='verbose', help='Verbose output')
-def ucsinfer(verbose: bool):
+# @click.option('--verbose', flag_value='verbose', help='Verbose output')
+def ucsinfer():
     pass
 
 
@@ -68,14 +68,17 @@ def finetune():
 @ucsinfer.command('evaluate')
 @click.option('--offset', type=int, default=0)
 @click.option('--limit', type=int, default=-1)
+@click.option('--no-foley', type=bool, default=False)
+@click.option('--model', type=str,
+              default="paraphrase-multilingual-mpnet-base-v2")
 @click.argument('dataset', type=click.File('r', encoding='utf8'),
                 default='dataset.csv')
-def evaluate(dataset, offset, limit):
+def evaluate(dataset, offset, limit, model, no_foley):
     """
     Use datasets to evauluate model performance
     """
-    model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")
-    ctx = InferenceContext(model)
+    m = SentenceTransformer(model)
+    ctx = InferenceContext(m, model)
     reader = csv.reader(dataset)
 
     results = []
@@ -113,7 +116,8 @@ def evaluate(dataset, offset, limit):
 
     miss_counts = sorted(miss_counts, key=lambda x: x[1])
 
-    print(f" === RESULTS === ")
+    print(f"Results for Model {model}")
+    print("=====\n")
 
     table = [
         ["Total records in sample:", f"{total}"],
@@ -132,7 +136,7 @@ def evaluate(dataset, offset, limit):
          f"{float(miss_counts[-1][1])/float(total):.2%}"]
     ]
 
-    print(tabulate(table, headers=['', 'n', 'pct']))
+    print(tabulate(table, headers=['', 'n', 'pct'], tablefmt='rst'))
 
 
 if __name__ == '__main__':
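
The evaluate command now takes the model name and a foley filter on the command line instead of hard-coding the SentenceTransformer model. A minimal sketch of exercising the new options with click's test runner; the import path `ucsinfer.__main__` is an assumption (the diff does not show the module's filename), and dataset.csv must exist for the click.File argument to open:

```python
from click.testing import CliRunner

# Assumed import path; the diff only shows the module body, not its filename.
from ucsinfer.__main__ import ucsinfer

runner = CliRunner()
result = runner.invoke(ucsinfer, [
    'evaluate',
    '--model', 'paraphrase-multilingual-mpnet-base-v2',
    '--no-foley', 'true',   # type=bool takes an explicit value, not a bare flag
    '--limit', '100',
    'dataset.csv',          # must exist: click.File('r') opens it eagerly
])
print(result.output)
```

Because `--no-foley` is declared with `type=bool` rather than `is_flag=True`, click expects an explicit value such as true/false after the option rather than treating it as a bare switch.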
@@ -52,9 +52,11 @@ class InferenceContext:
     """
 
     model: SentenceTransformer
+    model_name: str
 
-    def __init__(self, model: SentenceTransformer):
+    def __init__(self, model: SentenceTransformer, model_name: str):
         self.model = model
+        self.model_name = model_name
 
     @cached_property
     def catlist(self) -> list[Ucs]:
@@ -62,8 +64,9 @@ class InferenceContext:
 
     @cached_property
     def embeddings(self) -> list[dict]:
-        cache_dir = platformdirs.user_cache_dir("ucsinfer", 'Squad 51')
-        embedding_cache = os.path.join(cache_dir, f"ucs_embedding.cache")
+        cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", 'Squad 51')
+        embedding_cache = os.path.join(cache_dir,
+                                       f"{self.model_name}-ucs_embedding.cache")
         embeddings = []
 
         if os.path.exists(embedding_cache):
@@ -71,7 +74,7 @@ class InferenceContext:
                 embeddings = pickle.load(f)
 
         else:
-            print("Calculating embeddings...")
+            print(f"Calculating embeddings for model {self.model_name}...")
 
             for cat_defn in self.catlist:
                 embeddings += [{
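
The embedding cache is now keyed by model name, so switching `--model` no longer reuses embeddings computed for a different SentenceTransformer model. A rough sketch of the resulting cache path, using the appname/appauthor values from the hunk above (the exact directory is platform-dependent):

```python
import os
import platformdirs

model_name = "paraphrase-multilingual-mpnet-base-v2"

# Same appname/appauthor as in InferenceContext.embeddings above.
cache_dir = platformdirs.user_cache_dir("us.squad51.ucsinfer", "Squad 51")
embedding_cache = os.path.join(cache_dir, f"{model_name}-ucs_embedding.cache")

# e.g. on Linux: ~/.cache/us.squad51.ucsinfer/<model>-ucs_embedding.cache
print(embedding_cache)
```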
@@ -4,9 +4,6 @@ from typing import NamedTuple, Optional
 from re import match
 
 
-from .inference import Ucs
-
-
 def ffmpeg_description(path: str) -> Optional[str]:
     result = subprocess.run(['ffprobe', '-show_format', '-of',
                              'json', path], capture_output=True)