| import random |
| from dataclasses import dataclass |
| from typing import Dict, Iterable, Iterator, List, Optional, Tuple |
|
|
| import torch |
| from torch.utils.data import IterableDataset |
| from datasets import load_dataset |
| from transformers import PreTrainedTokenizerBase |
| import yaml |
|
|
@dataclass
class DataSource:
    """Configuration record for one Hugging Face dataset used as a text source.

    Built from YAML by load_sources_from_yaml and consumed by build_streams
    and TokenChunkDataset.
    """

    name: str  # human-readable label; used in warning messages on restart failure
    hf_path: str  # dataset path/id passed to datasets.load_dataset
    hf_name: Optional[str]  # optional dataset config name (may be None)
    split: str  # dataset split to read, e.g. "train"
    text_field: str  # column in each row holding the raw text
    weight: int = 1  # relative sampling weight (clamped to >= 1 by TokenChunkDataset)
    streaming: bool = True  # stream from the Hub instead of downloading fully
|
|
def load_sources_from_yaml(path: str) -> List[DataSource]:
    """Parse a YAML config file into a list of DataSource records.

    The file is expected to contain a top-level ``sources:`` list; each entry
    may supply ``name``, ``hf_path``, ``hf_name``, ``split`` (default
    ``"train"``), ``text_field`` (default ``"text"``), ``weight`` (default 1)
    and ``streaming`` (default True).

    Args:
        path: filesystem path to the YAML config.

    Returns:
        A non-empty list of DataSource records.

    Raises:
        ValueError: if the file defines no sources.
    """
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # safe_load returns None for an empty file; treat that as "no sources"
    # instead of crashing with AttributeError.
    srcs = [
        DataSource(
            name=s.get("name"),
            hf_path=s.get("hf_path"),
            hf_name=s.get("hf_name"),
            split=s.get("split", "train"),
            text_field=s.get("text_field", "text"),
            weight=int(s.get("weight", 1)),
            streaming=bool(s.get("streaming", True)),
        )
        for s in (cfg or {}).get("sources", [])
    ]
    if not srcs:
        # `assert` is stripped under `python -O`; raise explicitly instead.
        raise ValueError("No data sources configured")
    return srcs
|
|
def build_streams(sources: List[DataSource]) -> List[Iterator[Dict]]:
    """Open one streaming row iterator per configured data source."""
    return [
        iter(load_dataset(src.hf_path, src.hf_name, split=src.split, streaming=src.streaming))
        for src in sources
    ]
|
|
def weighted_choice(weights: List[int]) -> int:
    """Sample an index with probability proportional to its weight.

    Draws one integer in [1, sum(weights)] and walks the list, subtracting
    each weight until the draw is exhausted. Requires sum(weights) >= 1.
    """
    remaining = random.randint(1, sum(weights))
    for idx, w in enumerate(weights):
        remaining -= w
        if remaining <= 0:
            return idx
    # Unreachable when the total is positive; defensive fallback.
    return len(weights) - 1
|
|
class TokenChunkDataset(IterableDataset):
    """Infinite iterable dataset yielding (input, target) token chunks.

    Texts are drawn from multiple Hugging Face streaming sources via weighted
    sampling, tokenized, concatenated (with an optional EOS separator between
    documents), and sliced into next-token-prediction pairs of length
    ``seq_len``.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        sources: List[DataSource],
        seq_len: int,
        eos_token_id: Optional[int] = None,
    ):
        """
        Args:
            tokenizer: tokenizer providing ``encode`` (and, optionally, an
                ``eos_token_id`` attribute).
            sources: weighted data sources to sample from.
            seq_len: length of each yielded input/target sequence.
            eos_token_id: explicit EOS id; falls back to the tokenizer's
                ``eos_token_id`` attribute when omitted.
        """
        super().__init__()
        self.tok = tokenizer
        self.sources = sources
        self.seq_len = seq_len
        # Prefer the explicit id; otherwise use the tokenizer's, if any.
        self.eos_id = eos_token_id if eos_token_id is not None else getattr(tokenizer, "eos_token_id", None)
        # Clamp to >= 1 so every source keeps a nonzero chance of selection.
        self.weights = [max(1, s.weight) for s in sources]

    def _restart_stream(self, i: int) -> Iterator[Dict]:
        """Re-open the stream for source ``i`` after it is exhausted."""
        ds = load_dataset(
            self.sources[i].hf_path,
            self.sources[i].hf_name,
            split=self.sources[i].split,
            streaming=self.sources[i].streaming,
        )
        return iter(ds)

    def _iter_texts(self) -> Iterator[str]:
        """Yield raw text strings forever, sampling sources by weight.

        Exhausted streams are restarted; rows whose text field is missing or
        empty are skipped silently.
        """
        iters = build_streams(self.sources)
        while True:
            i = weighted_choice(self.weights)
            try:
                row = next(iters[i])
            except StopIteration:
                try:
                    iters[i] = self._restart_stream(i)
                    row = next(iters[i])
                except Exception as e:
                    # StopIteration subclasses Exception, so the original
                    # (StopIteration, Exception) tuple was redundant; a single
                    # Exception clause covers both the empty-restart case and
                    # load failures.
                    print(f"Warning: Could not restart iterator for source {self.sources[i].name}: {e}")
                    continue
            text = row.get(self.sources[i].text_field, None)
            if isinstance(text, str) and len(text) > 0:
                yield text

    def _safe_encode(self, text: str) -> List[int]:
        """Tokenize ``text``; on failure, log and return [] instead of raising."""
        try:
            return self.tok.encode(text)
        except Exception as e:
            print(f"Encoding error for text: {text[:50]}... Error: {e}")
            return []

    def _iter_token_ids(self) -> Iterator[int]:
        """Flatten encoded texts into one token stream, EOS-separated."""
        for text in self._iter_texts():
            ids = self._safe_encode(text)
            if not ids:
                # A failed/empty encoding previously still emitted a lone EOS;
                # skip it so no spurious separator tokens enter the stream.
                continue
            if self.eos_id is not None:
                ids.append(self.eos_id)
            yield from ids

    def __iter__(self):
        """Yield (x, y) LongTensor pairs of shape (seq_len,), y shifted by one."""
        buf: List[int] = []
        for tok_id in self._iter_token_ids():
            buf.append(tok_id)
            while len(buf) >= self.seq_len + 1:
                x = torch.tensor(buf[:self.seq_len], dtype=torch.long)
                y = torch.tensor(buf[1:self.seq_len + 1], dtype=torch.long)
                # Keep the last token: it is both y's tail and the next x's head.
                del buf[:self.seq_len]
                yield x, y

    def __len__(self):
        # Nominal length: the stream is effectively unbounded, but some
        # trainers require __len__. TODO(review): make this configurable.
        return 1000000
|
|
|
|