Split training sets from gather

2025-09-03 19:40:55 -07:00
parent e419f698c9
commit 10519f9c1a
4 changed files with 56 additions and 7 deletions


@@ -17,8 +17,9 @@
 - Print more information about the dataset coverage of UCS
 - Allow skipping model testing for this
 - Print raw output
-- Maybe load everything into a sqlite for slicker reporting
+<!-- - Maybe load everything into a sqlite for slicker reporting -->
 ## Utility


@@ -163,10 +163,12 @@ def gather(ctx, paths, out, ucs_data):
     logger.info('Creating dataset for UCS categories instead of from PATH')
     paths = []
+    walker_p = tqdm.tqdm(total=None, unit='dir', desc="Walking filesystem...")
     for path in paths:
         for dirpath, _, filenames in os.walk(path):
             logger.info(f"Walking directory {dirpath}")
             for filename in filenames:
+                walker_p.update()
                 root, ext = os.path.splitext(filename)
                 if ext not in types or filename.startswith("._"):
                     continue
@@ -175,6 +177,7 @@ def gather(ctx, paths, out, ucs_data):
                 p = os.path.join(dirpath, filename)
                 logger.info(f"Adding path to scan list {p}")
                 scan_list.append((ucs_components.cat_id, p))
+    walker_p.close()
     logger.info(f"Found {len(scan_list)} files to process.")
@@ -184,7 +187,7 @@ def gather(ctx, paths, out, ucs_data):
         try:
             desc = ffmpeg_description(pair[1])
         except CalledProcessError as e:
-            logger.error(f"ffprobe returned error {e.returncode}: " \
+            logger.error(f"ffprobe returned error ({e.returncode}): " \
                 + e.stderr)
             continue
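
The progress bar added above runs tqdm in indeterminate mode: with total=None it shows a running count and rate instead of a percentage, which suits a filesystem walk whose size is not known up front. A minimal standalone sketch of the same pattern (the function name and unit label are illustrative, not part of ucsinfer):

import os
import tqdm

def count_scannable_files(path: str) -> int:
    # total=None puts tqdm in counter mode: no percentage, just items seen
    walker_p = tqdm.tqdm(total=None, unit='file', desc="Walking filesystem...")
    found = 0
    for _dirpath, _dirnames, filenames in os.walk(path):
        for _filename in filenames:
            walker_p.update()  # advance the counter by one file
            found += 1
    walker_p.close()           # finish the bar once the walk is done
    return found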

ucsinfer/evaluate.py (new file, 32 lines)

@@ -0,0 +1,32 @@
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import BinaryClassificationEvaluator
from datasets import load_from_disk, DatasetDict


def evaluate_model(model: SentenceTransformer, dataset):
    # Assumes `dataset` is already the evaluation split, with the
    # "sentence1", "sentence2" and "label" columns the evaluator expects
    eval_dataset = dataset

    # Initialize the evaluator
    binary_acc_evaluator = BinaryClassificationEvaluator(
        sentences1=eval_dataset["sentence1"],
        sentences2=eval_dataset["sentence2"],
        labels=eval_dataset["label"],
        name="quora_duplicates_dev",
    )
    results = binary_acc_evaluator(model)

    '''
    Binary Accuracy Evaluation of the model on the quora_duplicates_dev dataset:
    Accuracy with Cosine-Similarity:             81.60 (Threshold: 0.8352)
    F1 with Cosine-Similarity:                   75.27 (Threshold: 0.7715)
    Precision with Cosine-Similarity:            65.81
    Recall with Cosine-Similarity:               87.89
    Average Precision with Cosine-Similarity:    76.03
    Matthews Correlation with Cosine-Similarity: 62.48
    '''
    print(binary_acc_evaluator.primary_metric)
    # => "quora_duplicates_dev_cosine_ap"
    print(results[binary_acc_evaluator.primary_metric])
    # => 0.760277070888393
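
As written, evaluate_model follows the sentence-transformers BinaryClassificationEvaluator recipe, so it needs sentence-pair columns rather than the (sentence, class) records built elsewhere in this commit. A hedged usage sketch with toy data; the pair sentences, labels and model checkpoint are placeholders, not part of this commit:

from datasets import Dataset
from sentence_transformers import SentenceTransformer

# Toy pairs in the column layout BinaryClassificationEvaluator expects;
# label 1 marks sentences that should embed close together, 0 otherwise.
pairs = Dataset.from_dict({
    "sentence1": ["metal door slams shut", "steady rain on pavement"],
    "sentence2": ["heavy metal door closing", "crowd chatter in a cafe"],
    "label": [1, 0],
})

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder checkpoint
evaluate_model(model, pairs)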


@@ -1,13 +1,16 @@
 from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
+from datasets.dataset_dict import DatasetDict
 from typing import Generator, Any
 
 # https://www.sbert.net/docs/sentence_transformer/loss_overview.html
 
 def build_sentence_class_dataset(
-        records: Generator[tuple[str, str], Any, None], catlist: list[str]) -> Dataset:
+        records: Generator[tuple[str, str], Any, None],
+        catlist: list[str]) -> DatasetDict:
     """
     Create a new dataset for `records` which contains (sentence, class) pairs.
+    The dataset is split into train, test and eval slices.
 
     :param records: a generator for records that generates pairs of
         (sentence, catid)
@@ -16,9 +19,12 @@ def build_sentence_class_dataset(
     labels = ClassLabel(names=catlist)
+    features = Features({'sentence': Value('string'),
+                         'class': labels})
 
     info = DatasetInfo(
         description=f"(sentence, UCS CatID) pairs gathered by the "
-            "ucsinfer tool on {}")
+            "ucsinfer tool on {}", features=features)
items: list[dict] = [] items: list[dict] = []
@@ -26,9 +32,16 @@ def build_sentence_class_dataset(
         items += [{'sentence': obj[0], 'class': obj[1]}]
 
-    return Dataset.from_list(items, features=Features({'sentence': Value('string'),
-                                                        'class': labels}),
-                             info=info)
+    whole = Dataset.from_list(items, features=features, info=info)
+
+    split_set = whole.train_test_split(0.2)
+    test_eval_set = split_set['test'].train_test_split(0.5)
+
+    return DatasetDict({
+        'train': split_set['train'],
+        'test': test_eval_set['train'],
+        'eval': test_eval_set['test']
+    })
 
 # def build_sentence_anchor_dataset() -> Dataset:
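
The nested train_test_split above carves 20% off for a held-out pool and then halves that pool, leaving roughly 80% train, 10% test and 10% eval. A small usage sketch of the revised function, assuming made-up sentences and CatIDs rather than real gathered data:

def demo_records():
    # Yields (sentence, UCS CatID) pairs; the values are purely illustrative
    for i in range(20):
        yield (f"metal door impact, take {i}", "DOORMetl")
        yield (f"steady rain on pavement, take {i}", "RAINCnst")

splits = build_sentence_class_dataset(demo_records(), ["DOORMetl", "RAINCnst"])
print({name: len(split) for name, split in splits.items()})
# => {'train': 32, 'test': 4, 'eval': 4} for the 40 demo records
splits.save_to_disk("ucs_dataset")  # output path is an assumption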