diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py index 1ea7215..1b5fab6 100644 --- a/ucsinfer/__main__.py +++ b/ucsinfer/__main__.py @@ -77,13 +77,17 @@ def finetune(): @ucsinfer.command('evaluate') -@click.option('--offset', type=int, default=0) -@click.option('--limit', type=int, default=-1) +@click.option('--offset', type=int, default=0, metavar="", + help='Skip this many records in the dataset before processing') +@click.option('--limit', type=int, default=-1, metavar="", + help='Process this many records and then exit') @click.option('--no-foley', 'no_foley', flag_value=True, default=False, help="Ignore any data in the set with FOLYProp or FOLYFeet " "category") -@click.option('--model', type=str, - default="paraphrase-multilingual-mpnet-base-v2") +@click.option('--model', type=str, metavar="", + default="paraphrase-multilingual-mpnet-base-v2", + show_default=True, + help="Select the sentence_transformer model to use") @click.argument('dataset', type=click.File('r', encoding='utf8'), default='dataset.csv') def evaluate(dataset, offset, limit, model, no_foley): @@ -110,8 +114,17 @@ def evaluate(dataset, offset, limit, model, no_foley): ctx = InferenceContext(m, model) reader = csv.reader(dataset) + print(f"Evaluating model {model}...") results = [] - for i, row in enumerate(tqdm.tqdm(reader)): + + if offset > 0: + print(f"Skipping {offset} records...") + + if limit > 0: + print(f"Will only evaluate {limit} records...") + + progress_bar = tqdm.tqdm(total=limit) + for i, row in enumerate(reader): if i < offset: continue @@ -131,6 +144,8 @@ def evaluate(dataset, offset, limit, model, no_foley): results.append({'catid': cat_id, 'result': "TOP_10"}) else: results.append({'catid': cat_id, 'result': "MISS"}) + + progress_bar.update(1) total = len(results) total_top = len([x for x in results if x['result'] == 'TOP'])