diff --git a/TODO.md b/TODO.md
index 6707ce9..cc0775f 100644
--- a/TODO.md
+++ b/TODO.md
@@ -17,8 +17,9 @@
 
 - Print more information about the dataset coverage of UCS
 - Allow skipping model testing for this
+  - Print raw output
 
-- Maybe load everything into a sqlite for slicker reporting
+
 
 ## Utility
 
diff --git a/ucsinfer/__main__.py b/ucsinfer/__main__.py
index 72cde6e..c329a4e 100644
--- a/ucsinfer/__main__.py
+++ b/ucsinfer/__main__.py
@@ -163,10 +163,12 @@ def gather(ctx, paths, out, ucs_data):
         logger.info('Creating dataset for UCS categories instead of from PATH')
         paths = []
 
+    walker_p = tqdm.tqdm(total=None, unit='dir', desc="Walking filesystem...")
    for path in paths:
        for dirpath, _, filenames in os.walk(path):
            logger.info(f"Walking directory {dirpath}")
            for filename in filenames:
+                walker_p.update()
                root, ext = os.path.splitext(filename)
                if ext not in types or filename.startswith("._"):
                    continue
@@ -175,6 +177,7 @@ def gather(ctx, paths, out, ucs_data):
            p = os.path.join(dirpath, filename)
            logger.info(f"Adding path to scan list {p}")
            scan_list.append((ucs_components.cat_id, p))
+    walker_p.close()
 
    logger.info(f"Found {len(scan_list)} files to process.")
 
@@ -184,7 +187,7 @@ def gather(ctx, paths, out, ucs_data):
        try:
            desc = ffmpeg_description(pair[1])
        except CalledProcessError as e:
-            logger.error(f"ffprobe returned error {e.returncode}: " \
+            logger.error(f"ffprobe returned error ({e.returncode}): " \
                + e.stderr)
            continue
 
diff --git a/ucsinfer/evaluate.py b/ucsinfer/evaluate.py
new file mode 100644
index 0000000..70f05cf
--- /dev/null
+++ b/ucsinfer/evaluate.py
@@ -0,0 +1,32 @@
+
+
+from sentence_transformers import SentenceTransformer
+from sentence_transformers.evaluation import BinaryClassificationEvaluator
+from datasets import load_from_disk, DatasetDict
+
+
+def evaluate_model(model: SentenceTransformer, dataset):
+
+    eval_dataset = dataset  # expects "sentence1", "sentence2", "label" columns
+
+    # Initialize the evaluator
+    binary_acc_evaluator = BinaryClassificationEvaluator(
+        sentences1=eval_dataset["sentence1"],
+        sentences2=eval_dataset["sentence2"],
+        labels=eval_dataset["label"],
+        name="quora_duplicates_dev",
+    )
+    results = binary_acc_evaluator(model)
+    '''
+    Binary Accuracy Evaluation of the model on the quora_duplicates_dev dataset:
+    Accuracy with Cosine-Similarity:             81.60  (Threshold: 0.8352)
+    F1 with Cosine-Similarity:                   75.27  (Threshold: 0.7715)
+    Precision with Cosine-Similarity:            65.81
+    Recall with Cosine-Similarity:               87.89
+    Average Precision with Cosine-Similarity:    76.03
+    Matthews Correlation with Cosine-Similarity: 62.48
+    '''
+    print(binary_acc_evaluator.primary_metric)
+    # => "quora_duplicates_dev_cosine_ap"
+    print(results[binary_acc_evaluator.primary_metric])
+    # => 0.760277070888393
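Note: a possible way to drive `evaluate_model` above, sketched under assumptions: the checkpoint name and dataset path are hypothetical, and the dataset saved on disk must already provide the `sentence1`/`sentence2`/`label` columns that `BinaryClassificationEvaluator` consumes.

```python
from sentence_transformers import SentenceTransformer
from datasets import load_from_disk

from ucsinfer.evaluate import evaluate_model

# Hypothetical checkpoint and path, for illustration only
model = SentenceTransformer("all-MiniLM-L6-v2")
pairs = load_from_disk("./ucs_pairs")   # DatasetDict with an 'eval' split
evaluate_model(model, pairs["eval"])    # evaluate on the held-out split
```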
diff --git a/ucsinfer/gather.py b/ucsinfer/gather.py
index 1700f52..f038983 100644
--- a/ucsinfer/gather.py
+++ b/ucsinfer/gather.py
@@ -1,13 +1,16 @@
 from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
+from datasets.dataset_dict import DatasetDict
 from typing import Generator, Any
 
 # https://www.sbert.net/docs/sentence_transformer/loss_overview.html
 
 
 def build_sentence_class_dataset(
-        records: Generator[tuple[str, str], Any, None], catlist: list[str]) -> Dataset:
+        records: Generator[tuple[str, str], Any, None],
+        catlist: list[str]) -> DatasetDict:
     """
     Create a new dataset for `records` which contains (sentence, class) pairs.
+    The dataset is split into train, test and eval slices.
 
     :param records: a generator for records that generates pairs of
         (sentence, catid)
@@ -16,9 +19,12 @@ def build_sentence_class_dataset(
 
     labels = ClassLabel(names=catlist)
 
+    features = Features({'sentence': Value('string'),
+                         'class': labels})
+
     info = DatasetInfo(
         description=f"(sentence, UCS CatID) pairs gathered by the "
-        "ucsinfer tool on {}")
+        "ucsinfer tool on {}", features=features)
 
     items: list[dict] = []
 
@@ -26,9 +32,16 @@ def build_sentence_class_dataset(
         items += [{'sentence': obj[0], 'class': obj[1]}]
 
-    return Dataset.from_list(items, features=Features({'sentence': Value('string'),
-                                                       'class': labels}),
-                             info=info)
+    whole = Dataset.from_list(items, features=features, info=info)
+
+    split_set = whole.train_test_split(0.2)
+    test_eval_set = split_set['test'].train_test_split(0.5)
+
+    return DatasetDict({
+        'train': split_set['train'],
+        'test': test_eval_set['train'],
+        'eval': test_eval_set['test']
+    })
 
 
 # def build_sentence_anchor_dataset() -> Dataset:
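For reference, `train_test_split(0.2)` followed by a 50/50 split of the test slice yields an 80/10/10 train/test/eval partition. A minimal sketch (hypothetical CatIDs and sentences) of how the new return value behaves:

```python
from ucsinfer.gather import build_sentence_class_dataset

catlist = ["AMBForst", "DOORWood"]  # hypothetical UCS CatIDs
records = ((f"example sound description {i}", catlist[i % 2])
           for i in range(100))

ds = build_sentence_class_dataset(records, catlist)
print({name: len(split) for name, split in ds.items()})
# => {'train': 80, 'test': 10, 'eval': 10}
```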