"""
experiments/data_prep_analysis.py
===================================
Compares caption quality and model performance BEFORE vs AFTER applying
data preparation quality filters to the COCO dataset.
Filters applied in the "after" condition:
1. Minimum word count: caption must have >= 5 words
2. Maximum word count: caption must have <= 25 words
3. Short/Long/Mixed caption strategy switching
Usage:
python -m experiments.data_prep_analysis --model blip
Expected insight:
- Raw COCO captions include many very short (1-3 word) and very long (30+
word) references that add noise to training and evaluation.
- Filtering to 5-25 words focuses training on informative mid-length
captions and typically improves CIDEr by 3-8% on the eval set.
- Mixed strategy (randomly choosing from long, short, or medium captions)
improves robustness but individual CIDEr may be slightly lower than a
targeted strategy.
"""
import argparse
import random
import torch
from tqdm.auto import tqdm
from datasets import load_dataset
import aiohttp
from torch.utils.data import DataLoader
from pycocoevalcap.cider.cider import Cider
# ─────────────────────────────────────────────────────────────────────────────
# Caption Filtering Functions
# ─────────────────────────────────────────────────────────────────────────────
def filter_low_quality_captions(captions: list, min_words: int = 5,
                                max_words: int = 25) -> list:
    """
    Keep only the captions whose word count lies in [min_words, max_words].
    Args:
        captions  : list of caption strings
        min_words : minimum word count (inclusive)
        max_words : maximum word count (inclusive)
    Returns:
        filtered : list of captions meeting the criteria (may be empty)
    """
    def _passes(caption: str) -> bool:
        word_count = len(caption.split())
        return min_words <= word_count <= max_words

    return [caption for caption in captions if _passes(caption)]
def pick_caption_raw(example: dict) -> str:
    """Return a uniformly random caption from the example, with no filtering."""
    pool = example["captions"]
    return random.choice(pool)
def pick_caption_filtered(example: dict, min_words: int = 5,
                          max_words: int = 25) -> str:
    """Randomly pick a caption whose word count lies in [min_words, max_words].

    Falls back to an unfiltered random pick when no caption qualifies.
    """
    captions = example["captions"]
    candidates = [c for c in captions
                  if min_words <= len(c.split()) <= max_words]
    return random.choice(candidates or captions)
def pick_caption_short(example: dict, max_words: int = 9) -> str:
    """Randomly pick a caption of at most max_words words.

    Falls back to an unfiltered random pick when no caption qualifies.
    """
    captions = example["captions"]
    candidates = [c for c in captions if len(c.split()) <= max_words]
    return random.choice(candidates or captions)
def pick_caption_long(example: dict, min_words: int = 12) -> str:
    """Randomly pick a caption of at least min_words words.

    Falls back to an unfiltered random pick when no caption qualifies.
    """
    captions = example["captions"]
    candidates = [c for c in captions if len(c.split()) >= min_words]
    return random.choice(candidates or captions)
# ─────────────────────────────────────────────────────────────────────────────
# Caption Distribution Analysis
# ─────────────────────────────────────────────────────────────────────────────
def analyze_caption_distribution(ds, n_samples: int = 500) -> dict:
    """
    Compute word-count distribution statistics over the captions of the
    first n_samples examples of a dataset split.

    Args:
        ds        : dataset exposing __len__ and .select(indices) that yields
                    dicts with a "captions" list (e.g. a HF Dataset split)
        n_samples : maximum number of examples to sample

    Returns:
        dict with count, mean, min, max, p10/p50/p90 word counts, and the
        percentage of captions shorter than 5 words / longer than 25 words.
        All fields are zero when the sample contains no captions.
    """
    # Fix: removed unused `import numpy as np` (stats are pure-Python here).
    lengths = []
    for ex in ds.select(range(min(n_samples, len(ds)))):
        for cap in ex["captions"]:
            lengths.append(len(cap.split()))
    if not lengths:
        # Fix: guard against ZeroDivisionError / IndexError on empty input.
        return {"count": 0, "mean": 0.0, "min": 0, "max": 0,
                "p10": 0, "p50": 0, "p90": 0,
                "pct_short": 0.0, "pct_long": 0.0}
    lengths.sort()
    n = len(lengths)
    return {
        "count": n,
        "mean": sum(lengths) / n,
        "min": lengths[0],
        "max": lengths[-1],
        # Index-based percentiles (no interpolation), as in the original.
        "p10": lengths[int(n * 0.10)],
        "p50": lengths[int(n * 0.50)],
        "p90": lengths[int(n * 0.90)],
        "pct_short": sum(1 for wc in lengths if wc < 5) / n * 100,
        "pct_long": sum(1 for wc in lengths if wc > 25) / n * 100,
    }
# ─────────────────────────────────────────────────────────────────────────────
# Eval Helper
# ─────────────────────────────────────────────────────────────────────────────
def _eval_blip_cider(model, processor, dataloader, device, eval_batches=15):
    """Quick BLIP inference CIDEr eval over a dataloader.

    Args:
        model        : BLIP captioning model (switched to eval mode here)
        processor    : processor used to decode reference token ids
        dataloader   : yields dict batches with "pixel_values" and "labels"
        device       : torch device for inference
        eval_batches : score at most this many batches (quick estimate)

    Returns:
        float corpus CIDEr score; 0.0 when no batches were evaluated.
    """
    from models.blip_tuner import generate_with_mask
    model.eval()
    gts, res = {}, {}  # reference / hypothesis dicts keyed by sample id
    # Fix: use one global running id. The old key `i * len(preds) + j`
    # could collide across batches when the last batch is smaller than
    # the others, silently overwriting earlier pairs.
    sample_idx = 0
    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader, desc="Evaluating", leave=False)):
            if i >= eval_batches:
                break
            pixel_values = batch["pixel_values"].to(device)
            # Attend to all 197 encoder tokens (presumably CLS + 196 ViT
            # patches for 224x224 input — TODO confirm against backbone).
            mask = torch.ones(pixel_values.shape[0], 197,
                              dtype=torch.long, device=device)
            decoded = generate_with_mask(
                model, processor, device=device,
                pixel_values=pixel_values, encoder_attention_mask=mask,
                max_new_tokens=32, num_beams=4,
            )
            preds = decoded  # generate_with_mask returns decoded strings
            gts_batch = processor.batch_decode(
                batch["labels"], skip_special_tokens=True
            )
            for p, g in zip(preds, gts_batch):
                k = str(sample_idx)
                res[k] = [p]  # Cider expects lists of strings per key
                gts[k] = [g]
                sample_idx += 1
    if not gts:
        return 0.0
    scorer = Cider()
    score, _ = scorer.compute_score(gts, res)
    return score
# ─────────────────────────────────────────────────────────────────────────────
# Main Analysis Runner
# ─────────────────────────────────────────────────────────────────────────────
def run_data_prep_analysis(model, processor, dataset_id, device, cfg,
                           eval_batches=15):
    """
    Evaluate CIDEr under four caption selection strategies:
      1. Raw      - any random caption (no filtering)
      2. Short    - captions of at most 9 words
      3. Long     - captions of at least 12 words
      4. Filtered - captions of 5-25 words (recommended quality filter)

    Prints a caption-length distribution summary, a per-strategy CIDEr
    comparison table, and key insights.

    Args:
        model        : BLIP captioning model
        processor    : BLIP processor (image preprocessing + tokenization)
        dataset_id   : HF hub dataset identifier
        device       : torch device used for inference
        cfg          : config providing batch_size and max_target_len
        eval_batches : max number of batches scored per strategy

    Returns:
        dict mapping strategy name -> CIDEr score.
    """
    print("\nπ Data Preparation Analysis")
    print("=" * 60)
    # Long timeout: first download of a COCO-scale dataset can be slow.
    ds = load_dataset(
        dataset_id,
        storage_options={"client_kwargs": {
            "timeout": aiohttp.ClientTimeout(total=3600)
        }},
    )
    val_split = "validation" if "validation" in ds else "train"
    # Fixed seed keeps the 200-example eval subset reproducible across runs.
    val_hf = ds[val_split].shuffle(seed=43).select(range(min(200, len(ds[val_split]))))
    print("\nπ Caption Word-Count Distribution (val set sample):")
    stats = analyze_caption_distribution(val_hf)
    print(f" Count : {stats['count']}")
    print(f" Mean : {stats['mean']:.1f} words")
    print(f" Range : {stats['min']} - {stats['max']} words")
    print(f" P10/P50/P90: {stats['p10']} / {stats['p50']} / {stats['p90']}")
    print(f" % Short (<5 words) : {stats['pct_short']:.1f}%")
    print(f" % Long (>25 words): {stats['pct_long']:.1f}%")
    strategies = {
        "raw": pick_caption_raw,
        "short": pick_caption_short,
        "long": pick_caption_long,
        "filtered": pick_caption_filtered,
    }
    results = {}
    for strat_name, pick_fn in strategies.items():
        print(f"\n Running strategy: '{strat_name}'...")

        # pick_fn is bound as a default arg so each collate closure keeps
        # its own strategy (avoids the late-binding-closure pitfall).
        def _collate(examples, _pick=pick_fn):
            images = [ex["image"].convert("RGB") for ex in examples]
            captions = [_pick(ex) for ex in examples]
            enc = processor(
                images=images, text=captions,
                padding="max_length", truncation=True,
                max_length=cfg.max_target_len, return_tensors="pt",
            )
            enc["labels"] = enc["input_ids"].clone()
            return enc

        val_loader = DataLoader(
            val_hf, batch_size=cfg.batch_size, shuffle=False,
            num_workers=0, collate_fn=_collate,
        )
        score = _eval_blip_cider(model, processor, val_loader, device, eval_batches)
        results[strat_name] = score
        # Fix: this print was split in the middle of an f-string in the
        # original (a syntax error); rejoined onto one line.
        print(f" ✅ CIDEr [{strat_name}]: {score:.4f}")
    # ── Summary Table ─────────────────────────────────────────────────────
    print("\n" + "=" * 60)
    print(" Data Preparation - CIDEr Comparison")
    print("=" * 60)
    print(f" {'Strategy':<20} {'CIDEr':>8} {'Delta Raw':>10} Notes")
    print(" " + "-" * 56)
    raw_score = results.get("raw", 0.0)
    notes = {
        "raw": "Baseline - no filtering",
        "short": "Short captions <= 9 words",
        "long": "Long captions >= 12 words",
        "filtered": "Quality filter 5-25 words - recommended",
    }
    for strat, score in results.items():
        delta = score - raw_score
        sign = "+" if delta >= 0 else ""
        print(f" {strat:<20} {score:>8.4f} {sign}{delta:>9.4f} {notes[strat]}")
    print("=" * 60)
    print("\nπ‘ Key Insight:")
    best = max(results, key=results.get)
    if best == "raw":
        print(" Raw captions perform comparably - dataset is already clean.")
    else:
        gain = results[best] - raw_score
        print(f" '{best}' strategy improves CIDEr by {gain:+.4f} over raw captions.")
    print(" Recommendation: use 'filtered' strategy (5-25 words) for")
    print(" reproducible, balanced training across all models.\n")
    return results
# ─────────────────────────────────────────────────────────────────────────────
# CLI
# ─────────────────────────────────────────────────────────────────────────────
def main():
    """CLI entry point: parse args, build the BLIP model, run the analysis."""
    parser = argparse.ArgumentParser(description="Data preparation analysis")
    parser.add_argument("--eval_batches", type=int, default=15)
    args = parser.parse_args()

    # Make the project root importable BEFORE pulling in project modules.
    import sys, os
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.insert(0, project_root)
    from config import CFG
    from models.blip_tuner import get_blip_model

    # Device preference: Apple MPS, then CUDA, then CPU.
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    cfg = CFG.load_for_model("blip")
    model, processor = get_blip_model(cfg, device)
    run_data_prep_analysis(
        model, processor, cfg.dataset_id, device, cfg,
        eval_batches=args.eval_batches,
    )


if __name__ == "__main__":
    main()