Datasets
Custom dataset for transformers Trainer
import json

import torch
from torch.utils.data import Dataset
from transformers import GPT2TokenizerFast


class WhatsappClassificationDataset(Dataset):
    """Dataset for the sequence classification task.

    The tokenizer is applied to the input text with truncation and padding.

    Example:
        >>> from whatsapp_dataset import WhatsappClassificationDataset
        >>> dataset = WhatsappClassificationDataset("data/whatsapp/classification/cased/train.json")
        >>> dataset[0]
        {'input_ids': tensor([10814, 39869, 11, ..., 50256, 50256, 50256]),
         'attention_mask': tensor([1, 1, 1, ..., 0, 0, 0]),
         'labels': tensor(1)}
    """

    def __init__(self, data_path):
        # MODEL_NAME (a GPT-2 checkpoint) is assumed to be defined elsewhere in the module
        self.tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_NAME)
        # note: adding a new pad token would change the vocab size,
        # so to keep it simple just reuse an existing special token
        # https://github.com/huggingface/transformers/issues/6263
        self.tokenizer.pad_token = self.tokenizer.eos_token
        texts, self.labels = self._parse_classification_data(data_path)
        self.encodings = self.tokenizer(texts, truncation=True, padding=True)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

    @staticmethod
    def _parse_classification_data(data_path):
        # each line of the file is a JSON object: {"sentence": ..., "label": ...}
        texts = []
        labels = []
        with open(data_path) as f:
            for line in f:
                x = json.loads(line)
                texts.append(x["sentence"])
                labels.append(x["label"])
        return texts, labels
Note: Since we set tokenizer.pad_token to the EOS token, we must also update the model config so the model knows which token id is used for padding:
# padding token set on the GPT-2 tokenizer => need to update model.config as well
# https://github.com/huggingface/transformers/issues/6263
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
model.config.pad_token_id = model.config.eos_token_id
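If you would rather use a dedicated padding token instead of reusing the EOS token, the usual pattern is to add the token and then resize the model's embeddings to the new vocabulary size. This is not what the code above does; it is just a sketch of the alternative:

# alternative (not used above): register a dedicated [PAD] token
tokenizer = GPT2TokenizerFast.from_pretrained(MODEL_NAME)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
# the vocab grew by one token, so the embedding matrix must be resized to match
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id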
The output is of the form:
dataset = WhatsappClassificationDataset(data_file)
dataset[i]
{
'input_ids': tensor([10814, 39869, 11, ..., 50256, 50256, 50256]),
'attention_mask': tensor([1, 1, 1, ..., 0, 0, 0]),
'labels': tensor(1)
}
Slicing works as you would expect too. (The example below uses WhatsappPromptsDataset, whose labels are token sequences with -100 masking rather than single class labels.)
dataset = WhatsappPromptsDataset(data_file)
dataset[0:3]
{
'input_ids': tensor([[ 50, 14715, 3208, ..., 50256, 50256, 50256],
[ 50, 14715, 3208, ..., 50256, 50256, 50256],
[ 50, 14715, 3208, ..., 50256, 50256, 50256]]),
'attention_mask': tensor([[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0],
[1, 1, 1, ..., 0, 0, 0]]),
'labels': tensor([[ -100, -100, -100, ..., 50256, 50256, 50256],
[ -100, -100, -100, ..., 50256, 50256, 50256],
[ -100, -100, -100, ..., 50256, 50256, 50256]])
}
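Because each item is a dict of equal-length tensors, the classification dataset also works with a plain PyTorch DataLoader, which is handy for sanity-checking batches before wiring up the Trainer. A minimal sketch:

from torch.utils.data import DataLoader

dataset = WhatsappClassificationDataset(data_file)
# the default collate_fn stacks the per-item tensors into batch tensors
loader = DataLoader(dataset, batch_size=8, shuffle=True)
batch = next(iter(loader))
print(batch["input_ids"].shape)       # (8, max_seq_len)
print(batch["attention_mask"].shape)  # (8, max_seq_len)
print(batch["labels"].shape)          # (8,)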
Consume this dataset as follows:
data_files = {
"train": os.path.join(args.datapath, "train.json"),
"val": os.path.join(args.datapath, "val.json"),
}
train_dataset = WhatsappClassificationDataset(data_files["train"])
val_dataset = WhatsappClassificationDataset(data_files["val"])
...
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
)
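compute_metrics is passed to the Trainer above but never shown. A minimal sketch, assuming plain accuracy is the metric you care about:

import numpy as np

def compute_metrics(eval_pred):
    # the Trainer passes an EvalPrediction with raw logits and the reference labels
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    preds = np.argmax(logits, axis=-1)
    return {"accuracy": float((preds == labels).mean())}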
Create custom Huggingface dataset from json
Note: I don't recommend this approach. In my experience it can be error-prone and hard to debug. Prefer native PyTorch datasets.
Build a JSON Lines file and create a Hugging Face dataset following:
https://huggingface.co/docs/datasets/master/loading_datasets.html
Create JSON files where each line corresponds to a JSON object, like this:
{"sentence": "Alright, thank you", "label": 0}
{"sentence": "Driving lesson today?", "label": 1}
...
Example
import json

# raw_path, train_path and process_line stand in for your own data paths and cleaning logic
with open(raw_path) as f_raw, open(train_path, 'w+') as f_train:
    for line in f_raw:
        # process raw text
        sentence, label = process_line(line)
        # write one JSON object per line
        row = {"sentence": sentence, "label": label}
        f_train.write(json.dumps(row) + "\n")
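What process_line does depends entirely on your raw export format. Purely as an illustration (the tab-separated format and integer labels here are hypothetical, not from the original notes), it might look like:

def process_line(line):
    # hypothetical: assumes the raw export has tab-separated "label<TAB>sentence" lines
    label, sentence = line.rstrip("\n").split("\t", 1)
    return sentence, int(label)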
Use this to create a dataset.
import os

from datasets import load_dataset
data_files = {
"train": os.path.join(args.datapath, "train.json"),
"val": os.path.join(args.datapath, "val.json"),
}
dataset = load_dataset('json', data_files=data_files)
print(dataset)
Giving:
DatasetDict({
train: Dataset({
features: ['sentence', 'label'],
num_rows: 8101
})
val: Dataset({
features: ['sentence', 'label'],
num_rows: 923
    })
})
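To feed this DatasetDict to the Trainer you still need to tokenize it, since it only contains raw sentences and labels. A minimal sketch, assuming the same GPT-2 tokenizer configured above (with the EOS token reused as padding):

def tokenize_fn(batch):
    # batch is a dict of lists; this adds input_ids and attention_mask columns
    return tokenizer(batch["sentence"], truncation=True, padding="max_length")

tokenized = dataset.map(tokenize_fn, batched=True)
tokenized = tokenized.rename_column("label", "labels")
tokenized.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

# then: Trainer(..., train_dataset=tokenized["train"], eval_dataset=tokenized["val"], ...)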