| | import os |
| | import json |
| | import torch |
| | import datasets |
| | from torch.utils.data import DataLoader, Dataset |
| | from transformers import PreTrainedTokenizerFast |
| |
|
class CustomDataset(Dataset):
    """Torch dataset that tokenizes raw text records lazily on access.

    Each element of ``data`` is expected to be a mapping with a ``"text"``
    key; items are tokenized in ``__getitem__`` with fixed-length padding
    and truncation to ``max_length``.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            self.data[idx]["text"],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dimension of size 1;
        # squeeze it off so the DataLoader's collation adds its own.
        return {
            name: encoded[name].squeeze(0)
            for name in ("input_ids", "attention_mask")
        }
| |
|
class DataLoaderHandler:
    """Loads a JSON/JSONL text dataset and wraps it in a shuffling DataLoader."""

    def __init__(self, dataset_path, tokenizer_path, batch_size=8, max_length=512):
        """
        Args:
            dataset_path: Path to a ``.json`` or ``.jsonl`` file of records
                with a ``"text"`` field.
            tokenizer_path: Path to a ``tokenizers``-format tokenizer file
                (e.g. ``tokenizer.json``).
            batch_size: Batch size for the returned DataLoader.
            max_length: Max token length used for padding/truncation.
        """
        self.dataset_path = dataset_path
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
        self.batch_size = batch_size
        self.max_length = max_length

    def load_dataset(self):
        """Read the dataset file into Python objects.

        Returns:
            For ``.json``: whatever ``json.load`` produces for the file.
            For ``.jsonl``: a list with one parsed object per non-blank line.

        Raises:
            ValueError: If the file extension is neither .json nor .jsonl.
        """
        if self.dataset_path.endswith(".jsonl"):
            # BUG FIX: the original opened the file inside a list
            # comprehension and never closed it (handle leaked until GC).
            # Blank lines are skipped so a trailing newline does not crash
            # json.loads.
            with open(self.dataset_path, "r", encoding="utf-8") as f:
                return [json.loads(line) for line in f if line.strip()]
        if self.dataset_path.endswith(".json"):
            with open(self.dataset_path, "r", encoding="utf-8") as f:
                return json.load(f)
        raise ValueError("Unsupported dataset format. Use JSON or JSONL.")

    def get_dataloader(self):
        """Build a shuffling DataLoader over the tokenized dataset."""
        data = self.load_dataset()
        dataset = CustomDataset(data, self.tokenizer, self.max_length)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
| |
|
if __name__ == "__main__":
    # Smoke test: build a dataloader from local files and print the tensor
    # shapes of a single batch.
    handler = DataLoaderHandler(
        "data/dataset.jsonl",
        "tokenizer.json",
        16,
    )
    for first_batch in handler.get_dataloader():
        print(first_batch["input_ids"].shape, first_batch["attention_mask"].shape)
        break