vectorplasticity committed on
Commit
16525fb
·
verified ·
1 Parent(s): 39f0935

Add dataset utilities

Browse files
Files changed (1) hide show
  1. app/utils/dataset_utils.py +551 -0
app/utils/dataset_utils.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset Utilities - Helper functions for dataset operations
3
+ """
4
+
5
+ import logging
6
+ from typing import Dict, Any, List, Optional, Tuple
7
+ from datasets import load_dataset, Dataset, DatasetDict
8
+ from transformers import AutoTokenizer
9
+ import json
10
+ import os
11
+
12
# Module-level logger; the helpers below report failures through it.
logger = logging.getLogger(__name__)
13
+
14
+
15
def _identity_columns(*names: str) -> Dict[str, str]:
    """Return a fresh dict in which every given column name maps to itself."""
    return {name: name for name in names}


# Column layouts for commonly used datasets, keyed by dataset name.
# Every entry maps a column name to itself; "glue" and "super_glue" have no
# columns listed.
DATASET_COLUMN_MAPPINGS = {
    # Language modelling
    "wikitext": _identity_columns("text"),
    # Extractive question answering
    "squad": _identity_columns("question", "context", "answers"),
    "squad_v2": _identity_columns("question", "context", "answers"),
    # Summarisation
    "cnn_dailymail": _identity_columns("article", "highlights"),
    "xsum": _identity_columns("document", "summary"),
    "samsum": _identity_columns("dialogue", "summary"),
    "billsum": _identity_columns("text", "summary"),
    "aeslc": _identity_columns("email_body", "subject_line"),
    # Token classification
    "conll2003": _identity_columns("tokens", "ner_tags"),
    "wnut_17": _identity_columns("tokens", "ner_tags"),
    # Sentiment / review classification
    "imdb": _identity_columns("text", "label"),
    "yelp_polarity": _identity_columns("text", "label"),
    "yelp_review_full": _identity_columns("text", "label"),
    # GLUE-style tasks
    "sst2": _identity_columns("sentence", "label"),
    "cola": _identity_columns("sentence", "label"),
    "mnli": _identity_columns("premise", "hypothesis", "label"),
    "qnli": _identity_columns("question", "sentence", "label"),
    "qqp": _identity_columns("question1", "question2", "label"),
    "mrpc": _identity_columns("sentence1", "sentence2", "label"),
    "stsb": _identity_columns("sentence1", "sentence2", "label"),
    "glue": {},
    "super_glue": {},
    # Topic classification
    "trec": _identity_columns("text", "label"),
    "ag_news": _identity_columns("text", "label"),
    "dbpedia_14": _identity_columns("content", "label"),
    "20newsgroups": _identity_columns("text", "label"),
}
44
+
45
# Per-task templates: the default column names a task expects, the data
# "format" tag used for it, and a few example dataset names.
TASK_DATASET_TEMPLATES = {
    # Plain text modelling: a single text column.
    "causal-lm": {
        "text_column": "text",
        "format": "causal",
        "examples": ["wikitext", "openwebtext", "the_pile", "c4", "oscar"],
    },
    # Source/target pairs; no default column names are set here.
    "seq2seq": {
        "input_column": None,
        "target_column": None,
        "format": "seq2seq",
        "examples": ["cnn_dailymail", "xsum", "samsum", "billsum", "aeslc"],
    },
    # Pre-split words with one tag per word.
    "token-classification": {
        "tokens_column": "tokens",
        "labels_column": "ner_tags",
        "format": "token",
        "examples": ["conll2003", "wnut_17", "ontonotes5"],
    },
    # One text with one label.
    "text-classification": {
        "text_column": "text",
        "label_column": "label",
        "format": "classification",
        "examples": ["imdb", "yelp_polarity", "sst2", "ag_news", "dbpedia_14"],
    },
    # Question + context with answers.
    "question-answering": {
        "context_column": "context",
        "question_column": "question",
        "answers_column": "answers",
        "format": "qa",
        "examples": ["squad", "squad_v2", "natural_questions", "hotpotqa"],
    },
    # Input/target pairs that use the "causal" format tag.
    "reasoning": {
        "input_column": "input",
        "target_column": "target",
        "format": "causal",
        "examples": ["gsm8k", "strategyqa", "aqua"],
    },
}
84
+
85
+
86
def get_dataset_info(dataset_name: str) -> Dict[str, Any]:
    """Fetch metadata about a dataset from the HuggingFace Hub.

    Args:
        dataset_name: Repository id of the dataset on the Hub.

    Returns:
        A dict with the keys ``id``, ``author``, ``sha``, ``downloads``,
        ``tags``, ``description``, ``card_data``, ``siblings`` (file names of
        the repository files) and ``size_bytes`` (sum of the file sizes the
        Hub reported; files without a size count as 0).  If any exception is
        raised while fetching, it is logged and ``{"error": <message>}`` is
        returned instead.
    """
    try:
        from huggingface_hub import HfApi

        info = HfApi().dataset_info(dataset_name)
        # ``siblings`` may be missing/None; read it once and reuse it.
        siblings = info.siblings or []

        return {
            "id": info.id,
            "author": info.author,
            "sha": info.sha,
            # The attribute can exist but be None; fall back to the default.
            "downloads": getattr(info, "downloads", 0) or 0,
            "tags": info.tags or [],
            "description": getattr(info, "description", "") or "",
            "card_data": getattr(info, "card_data", {}),
            "siblings": [s.rfilename for s in siblings],
            "size_bytes": sum(getattr(s, "size", 0) or 0 for s in siblings),
        }
    except Exception as e:
        logger.error(f"Error getting dataset info for {dataset_name}: {e}")
        return {"error": str(e)}
108
+
109
+
110
def load_and_validate_dataset(
    dataset_name: str,
    config: Optional[str] = None,
    split: Optional[str] = None,
    trust_remote_code: bool = False,
) -> Tuple[Optional[DatasetDict], Optional[str]]:
    """Load a dataset with ``load_dataset`` and report failures as a value.

    Args:
        dataset_name: Name or path handed to ``load_dataset``.
        config: Optional configuration name, forwarded as ``name``.
        split: Optional split expression, forwarded as ``split``.
        trust_remote_code: Forwarded unchanged to ``load_dataset``.

    Returns:
        ``(dataset, None)`` on success or ``(None, error_message)`` when
        loading raised.  A single ``Dataset`` result is wrapped into a
        ``DatasetDict`` under the key ``"train"``, whatever split was asked
        for.
    """
    load_kwargs: Dict[str, Any] = {"trust_remote_code": trust_remote_code}
    # Only forward the optional arguments that were actually supplied.
    for key, value in (("name", config), ("split", split)):
        if value:
            load_kwargs[key] = value

    try:
        loaded = load_dataset(dataset_name, **load_kwargs)
        if isinstance(loaded, Dataset):
            loaded = DatasetDict({"train": loaded})
    except Exception as e:
        logger.error(f"Error loading dataset {dataset_name}: {e}")
        return None, str(e)

    return loaded, None
135
+
136
+
137
def get_dataset_schema(dataset: DatasetDict) -> Dict[str, Any]:
    """Describe the splits, row counts and columns of a dataset.

    Column information is taken from the first split only.

    Args:
        dataset: Mapping of split name to split.

    Returns:
        ``{}`` for a falsy dataset, otherwise a dict with ``splits`` (split
        names), ``columns`` (per-column details: ``name`` and, when a feature
        is available, ``dtype``, ``feature_type`` and optionally
        ``label_names``), ``num_rows`` (rows per split) and ``features``
        (string form of each feature, or ``"unknown"``).
    """
    if not dataset:
        return {}

    split_names = list(dataset.keys())
    reference = dataset[split_names[0]]

    row_counts = {}
    for name, split in dataset.items():
        row_counts[name] = len(split)

    column_details = {}
    feature_text = {}
    for column in reference.column_names:
        feature = reference.features.get(column)
        details = {"name": column}
        if feature:
            if hasattr(feature, "dtype"):
                details["dtype"] = str(feature.dtype)
            else:
                details["dtype"] = str(type(feature))
            if hasattr(feature, "names"):
                details["label_names"] = list(feature.names)
            details["feature_type"] = type(feature).__name__
            feature_text[column] = str(feature)
        else:
            feature_text[column] = "unknown"
        column_details[column] = details

    return {
        "splits": split_names,
        "columns": column_details,
        "num_rows": row_counts,
        "features": feature_text,
    }
168
+
169
+
170
def detect_task_type(dataset_name: str, dataset: DatasetDict) -> str:
    """Guess the task a dataset is meant for from its column names.

    Only the columns of the first split are inspected.  Checks run from the
    most specific pattern to the least specific one.

    Args:
        dataset_name: Name of the dataset.  Currently not used by the
            heuristic; kept for interface compatibility.
        dataset: Mapping of split name to split.

    Returns:
        ``"unknown"`` for a falsy dataset, otherwise one of
        ``"token-classification"``, ``"question-answering"``, ``"seq2seq"``,
        ``"text-classification"`` or ``"causal-lm"`` (the fallback).
    """
    if not dataset:
        return "unknown"

    columns = set(dataset[next(iter(dataset))].column_names)

    if {"tokens", "ner_tags"} <= columns:
        return "token-classification"
    if {"question", "context"} <= columns:
        return "question-answering"
    if "article" in columns or "document" in columns:
        return "seq2seq"
    if {"text", "label"} <= columns:
        return "text-classification"
    # Summary/dialogue columns must be checked before the plain-text
    # fallback; otherwise a small "text" + "summary" dataset would be
    # treated as language-modelling data purely because of its column count.
    if "dialogue" in columns or "summary" in columns:
        return "seq2seq"

    # Everything else - a lone "text" column, "input"/"target" pairs, or
    # unrecognised layouts - is treated as causal language modelling.
    return "causal-lm"
196
+
197
+
198
# For each task: the mapping keys to fill and, per key, the candidate column
# names in priority order (the first candidate present in the dataset wins).
_TASK_COLUMN_CANDIDATES = {
    "causal-lm": (
        ("text_column", ("text", "content", "document", "article", "input")),
    ),
    "seq2seq": (
        ("input_column", ("article", "document", "text", "input", "dialogue")),
        ("target_column", ("highlights", "summary", "target", "output", "subject_line")),
    ),
    "token-classification": (
        ("tokens_column", ("tokens", "words")),
        ("labels_column", ("ner_tags", "labels", "tags")),
    ),
    "text-classification": (
        ("text_column", ("text", "sentence", "content", "review")),
        ("label_column", ("label", "labels", "class", "category")),
    ),
    "question-answering": (
        ("context_column", ("context",)),
        ("question_column", ("question",)),
        ("answers_column", ("answers", "answer")),
    ),
}


def get_dataset_columns_for_task(
    dataset: DatasetDict,
    task_type: str
) -> Dict[str, str]:
    """Pick the dataset columns that play each role for a task.

    Only the columns of the first split are inspected.  For every role the
    first candidate name found in the dataset is used, for all task types
    (including question answering, where ``"answers"`` is preferred over
    ``"answer"``).

    Args:
        dataset: Mapping of split name to split.
        task_type: One of ``"causal-lm"``, ``"seq2seq"``,
            ``"token-classification"``, ``"text-classification"`` or
            ``"question-answering"``.  Any other value yields ``{}``.

    Returns:
        A dict from role key (e.g. ``"text_column"``) to column name.  Roles
        without a matching column are left out.  For ``"causal-lm"`` a
        dataset with exactly one column uses that column when no candidate
        matched.
    """
    if not dataset:
        return {}

    columns = set(dataset[next(iter(dataset))].column_names)

    mapping = {}
    for role, candidates in _TASK_COLUMN_CANDIDATES.get(task_type, ()):
        chosen = next((name for name in candidates if name in columns), None)
        if chosen is not None:
            mapping[role] = chosen

    # A single-column dataset is unambiguous for language modelling.
    if task_type == "causal-lm" and not mapping and len(columns) == 1:
        mapping["text_column"] = next(iter(columns))

    return mapping
262
+
263
+
264
def prepare_dataset_for_training(
    dataset: DatasetDict,
    tokenizer: Any,
    task_type: str,
    column_mapping: Dict[str, str],
    max_length: int = 512,
    padding: str = "max_length",
    truncation: bool = True,
) -> Tuple[DatasetDict, Dict[str, Any]]:
    """Tokenize every split of a dataset for the given task.

    Args:
        dataset: Mapping of split name to split.
        tokenizer: Callable tokenizer.  For ``"seq2seq"`` it must accept the
            ``text_target`` keyword; for ``"token-classification"`` its
            output must provide ``word_ids(batch_index=...)``.
        task_type: ``"causal-lm"``, ``"seq2seq"``, ``"token-classification"``,
            ``"text-classification"`` or ``"question-answering"``.  Any other
            value leaves the examples unchanged.
        column_mapping: Role key (e.g. ``"text_column"``) to column name.
        max_length: Forwarded to the tokenizer.
        padding: Forwarded to the tokenizer.
        truncation: Forwarded to the tokenizer.

    Returns:
        ``(tokenized_datasets, stats)``.  ``stats`` has the keys
        ``original_samples`` and ``processed_samples`` (rows per split) plus
        ``avg_length`` and ``removed_samples``, which this function leaves
        empty.

    Raises:
        ValueError: For ``"seq2seq"`` when ``input_column`` or
            ``target_column`` is missing from ``column_mapping`` (raised
            while the first batch is being tokenized).
    """
    stats = {
        "original_samples": {},
        "processed_samples": {},
        "avg_length": {},
        "removed_samples": {},
    }

    # Options shared by every tokenizer call below.
    tok_kwargs = {
        "padding": padding,
        "truncation": truncation,
        "max_length": max_length,
    }

    def tokenize_function(examples):
        """Tokenize one batch according to ``task_type``."""
        if task_type == "causal-lm":
            text_col = column_mapping.get("text_column", "text")
            if text_col not in examples:
                return examples

            outputs = tokenizer(
                examples[text_col],
                return_tensors=None,
                **tok_kwargs,
            )
            # NOTE(review): labels are a straight copy of input_ids, so any
            # padding positions are part of the labels too - confirm that
            # the training loop masks them.
            outputs["labels"] = outputs["input_ids"].copy()
            return outputs

        elif task_type == "seq2seq":
            input_col = column_mapping.get("input_column")
            target_col = column_mapping.get("target_column")

            if not input_col or not target_col:
                raise ValueError(f"Missing columns for seq2seq: {column_mapping}")

            model_inputs = tokenizer(examples[input_col], **tok_kwargs)

            # ``text_target`` is the supported replacement for the deprecated
            # ``as_target_tokenizer()`` context manager.
            labels = tokenizer(text_target=examples[target_col], **tok_kwargs)

            model_inputs["labels"] = labels["input_ids"]
            return model_inputs

        elif task_type == "token-classification":
            tokens_col = column_mapping.get("tokens_column", "tokens")
            labels_col = column_mapping.get("labels_column", "ner_tags")

            if tokens_col not in examples or labels_col not in examples:
                return examples

            tokenized_inputs = tokenizer(
                examples[tokens_col],
                is_split_into_words=True,
                **tok_kwargs,
            )

            labels = []
            for i, word_labels in enumerate(examples[labels_col]):
                word_ids = tokenized_inputs.word_ids(batch_index=i)
                previous_word_idx = None
                label_ids = []
                for word_idx in word_ids:
                    # Only the first sub-token of each word carries the word
                    # label; special tokens and continuation sub-tokens get
                    # -100.
                    if word_idx is None or word_idx == previous_word_idx:
                        label_ids.append(-100)
                    else:
                        label_ids.append(word_labels[word_idx])
                    previous_word_idx = word_idx
                labels.append(label_ids)

            tokenized_inputs["labels"] = labels
            return tokenized_inputs

        elif task_type == "text-classification":
            text_col = column_mapping.get("text_column", "text")
            if text_col not in examples:
                return examples

            tokenized = tokenizer(examples[text_col], **tok_kwargs)

            # Add labels if present
            label_col = column_mapping.get("label_column", "label")
            if label_col in examples:
                tokenized["labels"] = examples[label_col]

            return tokenized

        elif task_type == "question-answering":
            context_col = column_mapping.get("context_column", "context")
            question_col = column_mapping.get("question_column", "question")
            answers_col = column_mapping.get("answers_column", "answers")

            tokenized = tokenizer(
                examples[question_col],
                examples[context_col],
                **tok_kwargs,
            )

            if answers_col in examples:
                # Simplified placeholder: a full implementation would compute
                # the answer's token positions instead of [0, 0].
                tokenized["labels"] = [[0, 0] for _ in examples[answers_col]]

            return tokenized

        return examples

    # Columns that survive tokenization: model/label columns plus every
    # column named in the mapping.  Everything else is removed.
    keep_columns = {"labels", "label", "input_ids", "attention_mask"}
    keep_columns.update(column_mapping.values())

    tokenized_datasets = DatasetDict()
    for split_name, split_ds in dataset.items():
        stats["original_samples"][split_name] = len(split_ds)

        remove_columns = [
            col for col in split_ds.column_names if col not in keep_columns
        ]

        tokenized = split_ds.map(
            tokenize_function,
            batched=True,
            remove_columns=remove_columns,
            desc=f"Tokenizing {split_name}",
        )

        tokenized_datasets[split_name] = tokenized
        stats["processed_samples"][split_name] = len(tokenized)

    return tokenized_datasets, stats
424
+
425
+
426
def split_dataset(
    dataset: DatasetDict,
    train_split: float = 0.9,
    val_split: float = 0.1,
    seed: int = 42,
) -> DatasetDict:
    """Make sure a dataset has a validation split.

    If there is no ``"validation"`` split but a ``"train"`` split exists, the
    train split is divided with ``train_test_split`` and the held-out part
    becomes ``"validation"``.  All other splits (for example ``"test"``) are
    carried over unchanged.

    Args:
        dataset: Mapping of split name to split.
        train_split: Currently not used; the training share is whatever
            remains after ``val_split`` is held out.
        val_split: Passed as ``test_size`` to ``train_test_split``.
        seed: Passed as ``seed`` to ``train_test_split``.

    Returns:
        The input object itself when it already has a ``"validation"`` split
        or has no ``"train"`` split; otherwise a new ``DatasetDict``.
    """
    if "validation" in dataset or "train" not in dataset:
        return dataset

    parts = dataset["train"].train_test_split(
        test_size=val_split,
        seed=seed,
    )

    # Start from the existing splits so nothing besides "train" is lost.
    result = dict(dataset)
    result["train"] = parts["train"]
    result["validation"] = parts["test"]
    return DatasetDict(result)
447
+
448
+
449
def sample_dataset(
    dataset: DatasetDict,
    n_samples: int,
    split: str = "train",
    seed: int = 42,
) -> DatasetDict:
    """Shrink one split to a shuffled subset for quick experiments.

    Args:
        dataset: Mapping of split name to split.
        n_samples: Upper bound on the rows kept; capped at the split size.
        split: Name of the split to sample.
        seed: Seed passed to ``shuffle``.

    Returns:
        The input object itself when ``split`` is absent; otherwise a new
        ``DatasetDict`` in which ``split`` is replaced by the sample and the
        other splits are unchanged.
    """
    if split not in dataset:
        return dataset

    source = dataset[split]
    keep = min(n_samples, len(source))
    subset = source.shuffle(seed=seed).select(range(keep))

    return DatasetDict({**dict(dataset), split: subset})
464
+
465
+
466
def get_label_list(dataset: DatasetDict, label_column: str = "label") -> List[str]:
    """Return the label names for a column of the dataset.

    Splits are tried in order.  Named labels are preferred: the column
    feature's ``names`` attribute, or the ``names`` of its inner ``feature``
    for sequence-style columns.  When no names are available they are
    inferred from the values of the column (nested lists are flattened one
    level, ``None`` is ignored):

    * integer labels yield ``"0"`` .. ``str(max_label)``;
    * string labels yield the sorted unique values.

    Args:
        dataset: Mapping of split name to split.
        label_column: Column holding the labels.

    Returns:
        The label names, or ``[]`` when the dataset is falsy, the column is
        missing, or the labels can be neither read nor inferred.
    """
    if not dataset:
        return []

    for split_ds in dataset.values():
        if label_column not in split_ds.column_names:
            continue

        feature = split_ds.features.get(label_column)
        names = getattr(feature, "names", None)
        if names is None:
            # Sequence-style columns keep the label feature one level down.
            names = getattr(getattr(feature, "feature", None), "names", None)
        if names:
            return list(names)

        # No names available: infer them from the values themselves.
        observed = set()
        for value in split_ds[label_column]:
            if isinstance(value, (list, tuple)):
                observed.update(item for item in value if item is not None)
            elif value is not None:
                observed.add(value)

        if not observed:
            continue
        if all(isinstance(item, int) for item in observed):
            return [str(i) for i in range(max(observed) + 1)]
        if all(isinstance(item, str) for item in observed):
            return sorted(observed)

    return []
482
+
483
+
484
def estimate_dataset_size(dataset: DatasetDict) -> Dict[str, Any]:
    """Give a rough size estimate for a dataset.

    The estimate assumes about 1 KB per sample, i.e. 0.001 MB per row; it
    does not look at the actual data.

    Args:
        dataset: Mapping of split name to split.

    Returns:
        A dict with ``total_samples``, ``estimated_size_mb`` (rounded to two
        decimals) and ``splits`` (rows per split).  A falsy dataset yields
        zeros and an empty ``splits`` dict, so the result always has the same
        keys.
    """
    if not dataset:
        return {"total_samples": 0, "estimated_size_mb": 0, "splits": {}}

    split_sizes = {name: len(split) for name, split in dataset.items()}
    total_samples = sum(split_sizes.values())

    # Rough estimation: ~1KB per sample for text
    estimated_size_mb = total_samples * 0.001

    return {
        "total_samples": total_samples,
        "estimated_size_mb": round(estimated_size_mb, 2),
        "splits": split_sizes,
    }
499
+
500
+
501
def validate_dataset_for_task(
    dataset: DatasetDict,
    task_type: str,
    column_mapping: Dict[str, str],
) -> Tuple[bool, List[str]]:
    """Check that the mapped columns a task needs exist in the dataset.

    Only the columns of the first split are inspected.  Task types without
    an entry in the requirement table below produce no issues.

    Args:
        dataset: Mapping of split name to split.
        task_type: Task identifier such as ``"causal-lm"`` or ``"seq2seq"``.
        column_mapping: Role key (e.g. ``"text_column"``) to column name.

    Returns:
        ``(is_valid, issues)`` where ``issues`` holds one message per role
        that is unmapped or mapped to a column the dataset does not have.
    """
    # Per task: (mapping key, wording used in the message) for each role.
    requirements = {
        "causal-lm": [("text_column", "text column")],
        "seq2seq": [
            ("input_column", "input column"),
            ("target_column", "target column"),
        ],
        "token-classification": [
            ("tokens_column", "tokens column"),
            ("labels_column", "labels column"),
        ],
        "text-classification": [
            ("text_column", "text column"),
            ("label_column", "label column"),
        ],
        "question-answering": [
            ("context_column", "context_column"),
            ("question_column", "question_column"),
            ("answers_column", "answers_column"),
        ],
    }

    if not dataset:
        return False, ["Dataset is empty or could not be loaded"]

    columns = set(dataset[next(iter(dataset))].column_names)

    issues = []
    for key, wording in requirements.get(task_type, []):
        chosen = column_mapping.get(key)
        if chosen and chosen in columns:
            continue
        issues.append(f"Missing {wording}. Found: {columns}")

    return not issues, issues