Split training sets from gather
@@ -1,13 +1,16 @@
 from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
+from datasets.dataset_dict import DatasetDict
 
 from typing import Generator, Any
 
 # https://www.sbert.net/docs/sentence_transformer/loss_overview.html
 
 def build_sentence_class_dataset(
-        records: Generator[tuple[str, str], Any, None], catlist: list[str]) -> Dataset:
+        records: Generator[tuple[str, str], Any, None],
+        catlist: list[str]) -> DatasetDict:
     """
     Create a new dataset for `records` which contains (sentence, class) pairs.
+    The dataset is split into train, test and eval slices.
 
     :param records: a generator for records that generates pairs of
         (sentence, catid)
@@ -16,9 +19,12 @@ def build_sentence_class_dataset(
 
     labels = ClassLabel(names=catlist)
 
+    features = Features({'sentence': Value('string'),
+                         'class': labels})
+
     info = DatasetInfo(
         description=f"(sentence, UCS CatID) pairs gathered by the "
-        "ucsinfer tool on {}")
+        "ucsinfer tool on {}", features=features)
 
 
     items: list[dict] = []
@@ -26,9 +32,16 @@ def build_sentence_class_dataset(
         items += [{'sentence': obj[0], 'class': obj[1]}]
 
 
-    return Dataset.from_list(items, features=Features({'sentence': Value('string'),
-                                                       'class': labels}),
-                             info=info)
+    whole = Dataset.from_list(items, features=features, info=info)
+
+    split_set = whole.train_test_split(0.2)
+    test_eval_set = split_set['test'].train_test_split(0.5)
+
+    return DatasetDict({
+        'train': split_set['train'],
+        'test': test_eval_set['train'],
+        'eval': test_eval_set['test']
+    })
 
 
 # def build_sentence_anchor_dataset() -> Dataset:
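
For reference, a minimal sketch of how the new DatasetDict return value might be consumed (not part of the commit; the record generator and category IDs below are made up, and it assumes build_sentence_class_dataset is in scope). train_test_split(0.2) holds out 20% of the rows, and the second train_test_split(0.5) halves that hold-out, so the result is roughly an 80/10/10 train/test/eval split; sharing one Features object between Dataset.from_list and DatasetInfo keeps the recorded schema consistent with the data.

# Usage sketch with hypothetical records and category IDs.
from typing import Any, Generator

def fake_records() -> Generator[tuple[str, str], Any, None]:
    # Yields (sentence, catid) pairs, as the docstring describes.
    yield from [
        ('glass shatters on concrete', 'GLASS'),
        ('heavy door creaks open', 'DOOR'),
        ('footsteps crunch on gravel', 'FEET'),
        ('a window smashed by a rock', 'GLASS'),
    ] * 25  # 100 rows total, so the 80/10/10 split comes out even

ds = build_sentence_class_dataset(fake_records(), ['GLASS', 'DOOR', 'FEET'])
print({name: len(split) for name, split in ds.items()})
# -> {'train': 80, 'test': 10, 'eval': 10}
print(ds['train'].features['class'].names)
# -> ['GLASS', 'DOOR', 'FEET']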