65 lines
1.9 KiB
Python
65 lines
1.9 KiB
Python
from datasets import Dataset, Features, Value, ClassLabel, DatasetInfo
|
|
from datasets.dataset_dict import DatasetDict
|
|
|
|
from typing import Iterator
|
|
|
|
from tabulate import tabulate
|
|
|
|
def print_dataset_stats(dataset: DatasetDict, catlist: list[str]):
    """
    Print a small table of record counts for `dataset`.

    :param dataset: the split dataset to summarize; it must contain a
        `train` split
    :param catlist: the list of category names; accepted for interface
        compatibility but not used by this function yet
    """
    # len() of a DatasetDict is the number of splits (it is a dict), not the
    # number of rows, so the combined total is summed over every split.
    total_records = sum(len(split) for split in dataset.values())

    # tabulate expects one flat [label, value] list per row; wrapping a row
    # in an extra list would render the whole row as a single cell.
    data_table = [
        ["Total records in combined dataset:", total_records],
        ["Total records in `train`:", len(dataset['train'])],
    ]

    print(tabulate(data_table))
|
|
|
|
# https://www.sbert.net/docs/sentence_transformer/loss_overview.html
|
|
|
|
def build_sentence_class_dataset(
        records: Iterator[tuple[str, str]],
        catlist: list[str]) -> DatasetDict:
    """
    Create a new dataset for `records` which contains (sentence, class) pairs.
    The dataset is split into `train`, `test` and `eval` slices.

    :param records: a generator for records that generates pairs of
        (sentence, catid)
    :param catlist: the category names used to declare the `class` column
        as a ClassLabel feature
    :returns: A dataset dict whose splits have two columns:
        (sentence, class)
    """
    # Local import keeps this block self-contained; only used for the
    # creation date recorded in the dataset description.
    from datetime import date

    labels = ClassLabel(names=catlist)
    features = Features({'sentence': Value('string'),
                         'class': labels})

    # Both fragments must be f-strings: the "{}" placeholder used to sit in a
    # plain string literal and ended up verbatim in the description.
    # NOTE(review): the placeholder is assumed to be the creation date.
    info = DatasetInfo(
        description=f"(sentence, UCS CatID) pairs gathered by the "
        f"ucsinfer tool on {date.today().isoformat()}",
        features=features)

    # Index (rather than unpack) each record so records carrying extra
    # trailing fields keep working as before.
    items: list[dict] = [
        {'sentence': record[0], 'class': record[1]} for record in records
    ]

    whole = Dataset.from_list(items, features=features, info=info)

    # Hold out 20% of the records, then halve the held-out part.
    split_set = whole.train_test_split(0.2)
    test_eval_set = split_set['test'].train_test_split(0.5)

    return DatasetDict({
        'train': split_set['train'],
        'test': test_eval_set['train'],
        'eval': test_eval_set['test'],
    })
|
|
|
|
|
|
# def build_sentence_anchor_dataset() -> Dataset:
|
|
# """
|
|
# Create a new dataset for `records` which contains (sentence, anchor) pairs.
|
|
# """
|
|
# pass
|
|
|