Split training sets from gather
TODO.md (3 changes)
@@ -17,8 +17,9 @@
 
 - Print more information about the dataset coverage of UCS
 - Allow skipping model testing for this
 
 - Print raw output
-- Maybe load everything into a sqlite for slicker reporting
+<!-- - Maybe load everything into a sqlite for slicker reporting -->
 
+
 ## Utility
@@ -163,10 +163,12 @@ def gather(ctx, paths, out, ucs_data):
         logger.info('Creating dataset for UCS categories instead of from PATH')
         paths = []
 
+    walker_p = tqdm.tqdm(total=None, unit='dir', desc="Walking filesystem...")
     for path in paths:
         for dirpath, _, filenames in os.walk(path):
             logger.info(f"Walking directory {dirpath}")
             for filename in filenames:
+                walker_p.update()
                 root, ext = os.path.splitext(filename)
                 if ext not in types or filename.startswith("._"):
                     continue
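
Because the size of the filesystem walk is unknown up front, the progress bar is created with total=None, which makes tqdm show a running count and rate instead of a percentage, and it is closed once the walk finishes. A minimal sketch of the same pattern outside of gather (the function name and argument below are illustrative, not part of ucsinfer):

import os
import tqdm

def count_media_files(paths: list[str]) -> int:
    # total=None: tqdm cannot compute a percentage, so it displays an
    # open-ended counter, which suits a walk of unknown length.
    walker_p = tqdm.tqdm(total=None, unit='dir', desc="Walking filesystem...")
    found = 0
    for path in paths:
        for dirpath, _, filenames in os.walk(path):
            for _filename in filenames:
                walker_p.update()  # advance the counter by one entry
                found += 1
    walker_p.close()  # close so the bar prints its final state cleanly
    return found
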
@@ -175,6 +177,7 @@ def gather(ctx, paths, out, ucs_data):
                 p = os.path.join(dirpath, filename)
                 logger.info(f"Adding path to scan list {p}")
                 scan_list.append((ucs_components.cat_id, p))
+    walker_p.close()
 
     logger.info(f"Found {len(scan_list)} files to process.")
 
@@ -184,7 +187,7 @@ def gather(ctx, paths, out, ucs_data):
         try:
             desc = ffmpeg_description(pair[1])
         except CalledProcessError as e:
-            logger.error(f"ffprobe returned error {e.returncode}: " \
+            logger.error(f"ffprobe returned error ({e.returncode}): " \
                          + e.stderr)
             continue
 
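
The handler above assumes ffmpeg_description shells out to ffprobe with output captured as text, so that a failing probe raises subprocess.CalledProcessError carrying .returncode and a string .stderr that can be concatenated into the log message. The commit does not show ffmpeg_description itself; the sketch below is a hypothetical version of such a wrapper, with the ffprobe flags and the tag lookup as assumptions:

import json
import subprocess

def ffmpeg_description(path: str) -> str:
    # Hypothetical wrapper: check=True raises CalledProcessError on a non-zero
    # exit, and text=True makes .stderr a str (matching `+ e.stderr` above).
    proc = subprocess.run(
        ['ffprobe', '-v', 'error', '-show_format', '-of', 'json', path],
        capture_output=True, text=True, check=True)
    tags = json.loads(proc.stdout).get('format', {}).get('tags', {})
    return tags.get('description') or tags.get('comment') or ''
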
ucsinfer/evaluate.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+
+
+from sentence_transformers import SentenceTransformer
+from sentence_transformers.evaluation import BinaryClassificationEvaluator
+from datasets import load_from_disk, DatasetDict
+
+
+def evaluate_model(model: SentenceTransformer, dataset):
+
+    # eval_dataset =
+
+    # Initialize the evaluator
+    binary_acc_evaluator = BinaryClassificationEvaluator(
+        sentences1=eval_dataset["sentence1"],
+        sentences2=eval_dataset["sentence2"],
+        labels=eval_dataset["label"],
+        name="quora_duplicates_dev",
+    )
+    results = binary_acc_evaluator(model)
+    '''
+    Binary Accuracy Evaluation of the model on the quora_duplicates_dev dataset:
+    Accuracy with Cosine-Similarity: 81.60 (Threshold: 0.8352)
+    F1 with Cosine-Similarity: 75.27 (Threshold: 0.7715)
+    Precision with Cosine-Similarity: 65.81
+    Recall with Cosine-Similarity: 87.89
+    Average Precision with Cosine-Similarity: 76.03
+    Matthews Correlation with Cosine-Similarity: 62.48
+    '''
+    print(binary_acc_evaluator.primary_metric)
+    # => "quora_duplicates_dev_cosine_ap"
+    print(results[binary_acc_evaluator.primary_metric])
+    # => 0.760277070888393
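
BinaryClassificationEvaluator expects three parallel lists, two lists of sentences and a 0/1 label per pair, while the DatasetDict produced below stores single (sentence, class) rows, so the 'eval' split still has to be turned into pairs before the stub above can run. One possible conversion, purely illustrative (pair_up, the random pairing strategy, and the seed are not part of the commit):

import random

def pair_up(eval_split, seed: int = 0) -> dict[str, list]:
    # Hypothetical helper: pair each sentence with a random other sentence and
    # label the pair 1 when both share a UCS class, 0 otherwise.
    rng = random.Random(seed)
    sentences = eval_split['sentence']
    classes = eval_split['class']
    pairs = {'sentence1': [], 'sentence2': [], 'label': []}
    for i, (sentence, cls) in enumerate(zip(sentences, classes)):
        j = rng.randrange(len(sentences))
        if j == i:
            continue
        pairs['sentence1'].append(sentence)
        pairs['sentence2'].append(sentences[j])
        pairs['label'].append(1 if classes[j] == cls else 0)
    return pairs

# e.g. eval_dataset = pair_up(dataset['eval']) inside evaluate_model
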
@@ -1,13 +1,16 @@
 from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
+from datasets.dataset_dict import DatasetDict
 
 from typing import Generator, Any
 
 # https://www.sbert.net/docs/sentence_transformer/loss_overview.html
 
 def build_sentence_class_dataset(
-        records: Generator[tuple[str, str], Any, None], catlist: list[str]) -> Dataset:
+        records: Generator[tuple[str, str], Any, None],
+        catlist: list[str]) -> DatasetDict:
     """
     Create a new dataset for `records` which contains (sentence, class) pairs.
+    The dataset is split into train, test and eval slices.
 
     :param records: a generator for records that generates pairs of
         (sentence, catid)
@@ -16,9 +19,12 @@ def build_sentence_class_dataset(
 
     labels = ClassLabel(names=catlist)
 
+    features = Features({'sentence': Value('string'),
+                         'class': labels})
+
     info = DatasetInfo(
         description=f"(sentence, UCS CatID) pairs gathered by the "
-                    "ucsinfer tool on {}")
+                    "ucsinfer tool on {}", features=features)
 
 
     items: list[dict] = []
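
ClassLabel keeps the full catlist in the dataset's features and represents each 'class' value as an integer index into that list, so the string CatIDs appended to items end up stored as ids while the names stay in the schema. A small sketch of the mapping (the CatIDs here are invented for illustration):

from datasets import ClassLabel

labels = ClassLabel(names=['AMBForst', 'DSGNRise', 'VEHCar'])  # invented CatIDs

print(labels.num_classes)          # -> 3
print(labels.str2int('DSGNRise'))  # -> 1
print(labels.int2str(2))           # -> 'VEHCar'
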
@@ -26,9 +32,16 @@ def build_sentence_class_dataset(
         items += [{'sentence': obj[0], 'class': obj[1]}]
 
 
-    return Dataset.from_list(items, features=Features({'sentence': Value('string'),
-                                                        'class': labels}),
-                             info=info)
+    whole = Dataset.from_list(items, features=features, info=info)
+
+    split_set = whole.train_test_split(0.2)
+    test_eval_set = split_set['test'].train_test_split(0.5)
+
+    return DatasetDict({
+        'train': split_set['train'],
+        'test': test_eval_set['train'],
+        'eval': test_eval_set['test']
+    })
 
 
 # def build_sentence_anchor_dataset() -> Dataset:
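
train_test_split(0.2) holds out 20% of the rows, and splitting that held-out slice again with train_test_split(0.5) divides it evenly, so the resulting DatasetDict is roughly an 80/10/10 train/test/eval split. A hedged usage sketch, under the assumption that build_sentence_class_dataset is importable from a ucsinfer module and that the splits are written to disk for the evaluate step (module path, save path, and sample records are all made up):

from datasets import load_from_disk
from ucsinfer.dataset import build_sentence_class_dataset  # module path assumed

# Made-up records: (sentence, CatID) pairs, as expected by the generator argument.
records = ((f"sound description {i}", 'AMBForst' if i % 2 else 'DSGNRise')
           for i in range(100))
dataset = build_sentence_class_dataset(records, ['AMBForst', 'DSGNRise'])

# 100 rows -> about 80 train, 10 test, 10 eval.
print({name: split.num_rows for name, split in dataset.items()})

dataset.save_to_disk('ucs_dataset')        # DatasetDict.save_to_disk
reloaded = load_from_disk('ucs_dataset')   # returns the same DatasetDict of splits
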