azure home scripts: data gen, training, misc
Browse files- AI4RTGSdevDB.camera_masters.json +0 -0
- GSoC_2026_Proposal_Vivek_OLMo.docx +0 -0
- close_the_loop.py +703 -0
- generate_dataset.py +160 -0
- generate_dataset_flux2.py +132 -0
- gsoc_proposal_content.md +277 -0
- gsoc_proposal_final.md +200 -0
- h100_training.log +98 -0
- make_proposal_doc.py +444 -0
- merge_v3.py +88 -0
- pull_extra.py +27 -0
- telugu_voice_clone.py +113 -0
- train_h100_clean.log +0 -0
- train_h100_clean.py +50 -0
- train_h100_final.py +230 -0
- train_v3.log +0 -0
- train_v3.py +35 -0
- turboquant_case_study.md +72 -0
- vivek_complete_profile.md +367 -0
- yc_scrape.py +68 -0
AI4RTGSdevDB.camera_masters.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
GSoC_2026_Proposal_Vivek_OLMo.docx
ADDED
|
Binary file (43.7 kB). View file
|
|
|
close_the_loop.py
ADDED
|
@@ -0,0 +1,703 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
=============================================================================
|
| 4 |
+
CLOSING THE AI-BRAIN LOOP
|
| 5 |
+
Using TRIBE v2 to identify architectural gaps between AI models and the brain
|
| 6 |
+
=============================================================================
|
| 7 |
+
|
| 8 |
+
Methodology:
|
| 9 |
+
Phase 1 — Load TRIBE v2, run inference, capture per-layer AI features
|
| 10 |
+
Phase 2 — Layer-wise encoding analysis: which AI layers predict which brain regions
|
| 11 |
+
Phase 3 — Modality ablation: which encoder drives which brain area
|
| 12 |
+
Phase 4 — RSA: representational similarity between AI layers and brain ROIs
|
| 13 |
+
Phase 5 — Divergence mapping: where the brain does something AI can't capture
|
| 14 |
+
Phase 6 — Architectural implications: what's missing in current AI
|
| 15 |
+
|
| 16 |
+
Output: /home/azureuser/loop_results/
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
import logging
|
| 22 |
+
import warnings
|
| 23 |
+
import time
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import matplotlib
|
| 28 |
+
matplotlib.use("Agg")
|
| 29 |
+
import matplotlib.pyplot as plt
|
| 30 |
+
from matplotlib.gridspec import GridSpec
|
| 31 |
+
import pandas as pd
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
from scipy import stats
|
| 34 |
+
from scipy.spatial.distance import pdist, squareform
|
| 35 |
+
from einops import rearrange
|
| 36 |
+
|
| 37 |
+
warnings.filterwarnings("ignore")
|
| 38 |
+
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(message)s", datefmt="%H:%M:%S")
|
| 39 |
+
log = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
OUT = Path("/home/azureuser/loop_results")
|
| 42 |
+
OUT.mkdir(exist_ok=True)
|
| 43 |
+
|
| 44 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 45 |
+
# PHASE 1: Load model, run inference, capture intermediate representations
|
| 46 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 47 |
+
|
| 48 |
+
log.info("PHASE 1: Loading TRIBE v2 and running inference")
|
| 49 |
+
|
| 50 |
+
from tribev2 import TribeModel
|
| 51 |
+
|
| 52 |
+
CACHE = "/home/azureuser/cache"
|
| 53 |
+
VIDEO = "/home/azureuser/test_stimulus.mp4"
|
| 54 |
+
|
| 55 |
+
model = TribeModel.from_pretrained("facebook/tribev2", cache_folder=CACHE)
|
| 56 |
+
fmri_model = model._model
|
| 57 |
+
device = fmri_model.device
|
| 58 |
+
log.info(f"Model on {device}, n_outputs={fmri_model.n_outputs}")
|
| 59 |
+
|
| 60 |
+
# Print model structure
|
| 61 |
+
log.info("Model structure:")
|
| 62 |
+
log.info(f" Projectors: {list(fmri_model.projectors.keys())}")
|
| 63 |
+
log.info(f" Hidden dim: {fmri_model.config.hidden}")
|
| 64 |
+
log.info(f" Layer aggregation: {fmri_model.config.layer_aggregation}")
|
| 65 |
+
log.info(f" Extractor aggregation: {fmri_model.config.extractor_aggregation}")
|
| 66 |
+
if hasattr(fmri_model, 'encoder'):
|
| 67 |
+
log.info(f" Encoder: {type(fmri_model.encoder).__name__}")
|
| 68 |
+
if hasattr(fmri_model, 'low_rank_head'):
|
| 69 |
+
log.info(f" Low-rank head: {fmri_model.low_rank_head}")
|
| 70 |
+
|
| 71 |
+
# Build events
|
| 72 |
+
log.info(f"Processing video: {VIDEO}")
|
| 73 |
+
events = model.get_events_dataframe(video_path=VIDEO)
|
| 74 |
+
log.info(f"Events: {len(events)} rows, types: {events.type.unique().tolist()}")
|
| 75 |
+
|
| 76 |
+
# Get data loader
|
| 77 |
+
loader = model.data.get_loaders(events=events, split_to_build="all")["all"]
|
| 78 |
+
|
| 79 |
+
# ── Collect per-layer features and brain predictions ──
|
| 80 |
+
# We intercept batch.data to get the raw per-layer features (B, L, D, T)
|
| 81 |
+
# and run the model to get brain predictions
|
| 82 |
+
|
| 83 |
+
all_features = {} # modality -> list of (L, D, T) arrays
|
| 84 |
+
all_projected = {} # modality -> list of (T, H) arrays
|
| 85 |
+
all_post_encoder = [] # list of (T, H) arrays
|
| 86 |
+
all_brain_preds = [] # list of (T, V) arrays — V = n_vertices
|
| 87 |
+
|
| 88 |
+
# Register hooks on projectors to capture projected features
|
| 89 |
+
proj_captures = {}
|
| 90 |
+
proj_hooks = []
|
| 91 |
+
|
| 92 |
+
def make_proj_hook(name):
|
| 93 |
+
def hook(mod, inp, out):
|
| 94 |
+
proj_captures[name] = out.detach().cpu().numpy()
|
| 95 |
+
return hook
|
| 96 |
+
|
| 97 |
+
for mod_name, proj in fmri_model.projectors.items():
|
| 98 |
+
proj_hooks.append(proj.register_forward_hook(make_proj_hook(mod_name)))
|
| 99 |
+
|
| 100 |
+
# Hook encoder output
|
| 101 |
+
encoder_capture = [None]
|
| 102 |
+
def enc_hook(mod, inp, out):
|
| 103 |
+
encoder_capture[0] = out.detach().cpu().numpy()
|
| 104 |
+
if hasattr(fmri_model, 'encoder') and fmri_model.encoder is not None:
|
| 105 |
+
proj_hooks.append(fmri_model.encoder.register_forward_hook(enc_hook))
|
| 106 |
+
|
| 107 |
+
log.info("Running inference with hooks...")
|
| 108 |
+
t0 = time.time()
|
| 109 |
+
|
| 110 |
+
with torch.inference_mode():
|
| 111 |
+
for batch_idx, batch in enumerate(loader):
|
| 112 |
+
batch = batch.to(device)
|
| 113 |
+
|
| 114 |
+
# Capture raw per-layer features before the model touches them
|
| 115 |
+
for mod_name in fmri_model.projectors.keys():
|
| 116 |
+
if mod_name in batch.data:
|
| 117 |
+
feat = batch.data[mod_name].detach().cpu().numpy() # (B, L, D, T) or (B, D, T)
|
| 118 |
+
if feat.ndim == 3:
|
| 119 |
+
feat = feat[:, np.newaxis, :, :] # ensure 4D
|
| 120 |
+
if mod_name not in all_features:
|
| 121 |
+
all_features[mod_name] = []
|
| 122 |
+
all_features[mod_name].append(feat)
|
| 123 |
+
|
| 124 |
+
# Forward pass (triggers hooks)
|
| 125 |
+
y_pred = fmri_model(batch).detach().cpu().numpy() # (B, V, T')
|
| 126 |
+
y_pred = rearrange(y_pred, 'b v t -> (b t) v')
|
| 127 |
+
all_brain_preds.append(y_pred)
|
| 128 |
+
|
| 129 |
+
# Save projected features
|
| 130 |
+
for mod_name in proj_captures:
|
| 131 |
+
if mod_name not in all_projected:
|
| 132 |
+
all_projected[mod_name] = []
|
| 133 |
+
all_projected[mod_name].append(proj_captures[mod_name])
|
| 134 |
+
proj_captures.clear()
|
| 135 |
+
|
| 136 |
+
# Save encoder output
|
| 137 |
+
if encoder_capture[0] is not None:
|
| 138 |
+
all_post_encoder.append(encoder_capture[0])
|
| 139 |
+
encoder_capture[0] = None
|
| 140 |
+
|
| 141 |
+
# Clean hooks
|
| 142 |
+
for h in proj_hooks:
|
| 143 |
+
h.remove()
|
| 144 |
+
|
| 145 |
+
elapsed = time.time() - t0
|
| 146 |
+
log.info(f"Inference done in {elapsed:.1f}s")
|
| 147 |
+
|
| 148 |
+
# Concatenate results
|
| 149 |
+
brain_preds = np.concatenate(all_brain_preds, axis=0) # (T_total, V)
|
| 150 |
+
log.info(f"Brain predictions: {brain_preds.shape}")
|
| 151 |
+
|
| 152 |
+
for mod in all_features:
|
| 153 |
+
all_features[mod] = np.concatenate(all_features[mod], axis=0) # (B_total, L, D, T)
|
| 154 |
+
log.info(f"Raw features [{mod}]: {all_features[mod].shape}")
|
| 155 |
+
|
| 156 |
+
for mod in all_projected:
|
| 157 |
+
all_projected[mod] = np.concatenate(all_projected[mod], axis=0) # (B_total, T, H)
|
| 158 |
+
log.info(f"Projected features [{mod}]: {all_projected[mod].shape}")
|
| 159 |
+
|
| 160 |
+
if all_post_encoder:
|
| 161 |
+
post_encoder = np.concatenate(all_post_encoder, axis=0) # (B_total, T, H)
|
| 162 |
+
log.info(f"Post-encoder features: {post_encoder.shape}")
|
| 163 |
+
|
| 164 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 165 |
+
# PHASE 2: Brain parcellation — map vertices to functional regions
|
| 166 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 167 |
+
|
| 168 |
+
log.info("PHASE 2: Loading brain parcellation (Destrieux atlas, fsaverage5)")
|
| 169 |
+
|
| 170 |
+
from nilearn import datasets
|
| 171 |
+
from nilearn.surface import load_surf_data
|
| 172 |
+
|
| 173 |
+
fsaverage5 = datasets.fetch_surf_fsaverage("fsaverage5")
|
| 174 |
+
|
| 175 |
+
# Load Destrieux parcellation labels
|
| 176 |
+
labels_lh = load_surf_data(fsaverage5["annot_left_destrieux"])
|
| 177 |
+
labels_rh = load_surf_data(fsaverage5["annot_right_destrieux"])
|
| 178 |
+
|
| 179 |
+
N_VERT = 10242 # vertices per hemisphere in fsaverage5
|
| 180 |
+
all_labels = np.concatenate([labels_lh, labels_rh]) # (20484,)
|
| 181 |
+
|
| 182 |
+
# Destrieux atlas label names
|
| 183 |
+
destrieux = datasets.fetch_atlas_destrieux_2009()
|
| 184 |
+
label_names_raw = destrieux["labels"]
|
| 185 |
+
label_names = {}
|
| 186 |
+
for i, name in enumerate(label_names_raw):
|
| 187 |
+
if isinstance(name, bytes):
|
| 188 |
+
name = name.decode("utf-8")
|
| 189 |
+
label_names[i] = name
|
| 190 |
+
|
| 191 |
+
# Build region info
|
| 192 |
+
regions = {}
|
| 193 |
+
for lid in np.unique(all_labels):
|
| 194 |
+
mask = all_labels == lid
|
| 195 |
+
n = mask.sum()
|
| 196 |
+
if n < 5:
|
| 197 |
+
continue
|
| 198 |
+
name = label_names.get(int(lid), f"region_{lid}")
|
| 199 |
+
if name == "Unknown" or name == "Medial_wall":
|
| 200 |
+
continue
|
| 201 |
+
regions[int(lid)] = {"name": name, "mask": mask, "n_vertices": int(n)}
|
| 202 |
+
|
| 203 |
+
log.info(f"Found {len(regions)} usable brain regions")
|
| 204 |
+
|
| 205 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 206 |
+
# PHASE 3: Layer-wise encoding analysis
|
| 207 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 208 |
+
|
| 209 |
+
log.info("PHASE 3: Layer-wise encoding analysis")
|
| 210 |
+
|
| 211 |
+
# For each modality and each cached layer, compute how well that layer's
|
| 212 |
+
# features correlate with predicted brain activity in each region.
|
| 213 |
+
#
|
| 214 |
+
# Method: For each region, average brain predictions across vertices to get
|
| 215 |
+
# a time series. For each AI layer, average features across the feature dim
|
| 216 |
+
# to get a time series. Compute Pearson correlation.
|
| 217 |
+
|
| 218 |
+
# First, build time-aligned representations
|
| 219 |
+
# brain_preds is (T_total, V) — each row is one TR's brain prediction
|
| 220 |
+
T_total = brain_preds.shape[0]
|
| 221 |
+
V = brain_preds.shape[1]
|
| 222 |
+
|
| 223 |
+
# Per-region mean brain activity over time
|
| 224 |
+
region_timeseries = {}
|
| 225 |
+
for lid, rinfo in regions.items():
|
| 226 |
+
region_timeseries[lid] = brain_preds[:, rinfo["mask"]].mean(axis=1) # (T_total,)
|
| 227 |
+
|
| 228 |
+
# Per-modality, per-layer feature time series
|
| 229 |
+
# all_features[mod] is (B, L, D, T_batch) — we need to flatten B and T_batch
|
| 230 |
+
layer_timeseries = {}
|
| 231 |
+
for mod, feats in all_features.items():
|
| 232 |
+
B, L, D, T_batch = feats.shape
|
| 233 |
+
# Reshape to (B*T_batch, L, D)
|
| 234 |
+
feats_flat = rearrange(feats, 'b l d t -> (b t) l d')
|
| 235 |
+
# Trim to match brain_preds length
|
| 236 |
+
min_t = min(feats_flat.shape[0], T_total)
|
| 237 |
+
feats_flat = feats_flat[:min_t]
|
| 238 |
+
|
| 239 |
+
layer_timeseries[mod] = {}
|
| 240 |
+
for l in range(L):
|
| 241 |
+
# Mean across feature dim to get single time series per layer
|
| 242 |
+
layer_timeseries[mod][l] = feats_flat[:, l, :].mean(axis=1) # (T,)
|
| 243 |
+
|
| 244 |
+
log.info("Computing layer-brain correlations...")
|
| 245 |
+
|
| 246 |
+
# Correlation matrix: (modality, layer) x (region)
|
| 247 |
+
layer_brain_corr = {}
|
| 248 |
+
for mod in layer_timeseries:
|
| 249 |
+
L = len(layer_timeseries[mod])
|
| 250 |
+
for l in range(L):
|
| 251 |
+
key = f"{mod}_L{l}"
|
| 252 |
+
layer_brain_corr[key] = {}
|
| 253 |
+
layer_ts = layer_timeseries[mod][l]
|
| 254 |
+
min_t = min(len(layer_ts), T_total)
|
| 255 |
+
for lid, rinfo in regions.items():
|
| 256 |
+
brain_ts = region_timeseries[lid][:min_t]
|
| 257 |
+
lt = layer_ts[:min_t]
|
| 258 |
+
if np.std(lt) < 1e-10 or np.std(brain_ts) < 1e-10:
|
| 259 |
+
r = 0.0
|
| 260 |
+
else:
|
| 261 |
+
r, _ = stats.pearsonr(lt, brain_ts)
|
| 262 |
+
layer_brain_corr[key][lid] = r
|
| 263 |
+
|
| 264 |
+
# Build matrix for visualization
|
| 265 |
+
all_layer_keys = sorted(layer_brain_corr.keys())
|
| 266 |
+
all_region_ids = sorted(regions.keys())
|
| 267 |
+
region_names_list = [regions[lid]["name"] for lid in all_region_ids]
|
| 268 |
+
|
| 269 |
+
corr_matrix = np.zeros((len(all_layer_keys), len(all_region_ids)))
|
| 270 |
+
for i, lk in enumerate(all_layer_keys):
|
| 271 |
+
for j, lid in enumerate(all_region_ids):
|
| 272 |
+
corr_matrix[i, j] = layer_brain_corr[lk].get(lid, 0)
|
| 273 |
+
|
| 274 |
+
log.info(f"Correlation matrix: {corr_matrix.shape} (AI layers x brain regions)")
|
| 275 |
+
|
| 276 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 277 |
+
# PHASE 4: Modality ablation — which encoder drives which brain area
|
| 278 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 279 |
+
|
| 280 |
+
log.info("PHASE 4: Modality ablation analysis")
|
| 281 |
+
|
| 282 |
+
modalities = list(fmri_model.projectors.keys())
|
| 283 |
+
log.info(f"Ablating modalities: {modalities}")
|
| 284 |
+
|
| 285 |
+
# Re-run inference with each modality zeroed out
|
| 286 |
+
ablation_preds = {"full": brain_preds}
|
| 287 |
+
|
| 288 |
+
with torch.inference_mode():
|
| 289 |
+
for mod_to_ablate in modalities:
|
| 290 |
+
log.info(f" Ablating: {mod_to_ablate}")
|
| 291 |
+
preds_list = []
|
| 292 |
+
loader = model.data.get_loaders(events=events, split_to_build="all")["all"]
|
| 293 |
+
for batch in loader:
|
| 294 |
+
batch = batch.to(device)
|
| 295 |
+
# Zero out target modality
|
| 296 |
+
if mod_to_ablate in batch.data:
|
| 297 |
+
original = batch.data[mod_to_ablate].clone()
|
| 298 |
+
batch.data[mod_to_ablate] = torch.zeros_like(original)
|
| 299 |
+
y = fmri_model(batch).detach().cpu().numpy()
|
| 300 |
+
y = rearrange(y, 'b v t -> (b t) v')
|
| 301 |
+
preds_list.append(y)
|
| 302 |
+
batch.data[mod_to_ablate] = original
|
| 303 |
+
else:
|
| 304 |
+
y = fmri_model(batch).detach().cpu().numpy()
|
| 305 |
+
y = rearrange(y, 'b v t -> (b t) v')
|
| 306 |
+
preds_list.append(y)
|
| 307 |
+
ablation_preds[mod_to_ablate] = np.concatenate(preds_list, axis=0)
|
| 308 |
+
log.info(f" shape: {ablation_preds[mod_to_ablate].shape}")
|
| 309 |
+
|
| 310 |
+
# Compute per-region modality importance
|
| 311 |
+
# Importance = mean absolute change when modality is removed
|
| 312 |
+
region_mod_importance = {}
|
| 313 |
+
for lid, rinfo in regions.items():
|
| 314 |
+
mask = rinfo["mask"]
|
| 315 |
+
full = ablation_preds["full"][:, mask]
|
| 316 |
+
imp = {}
|
| 317 |
+
for mod in modalities:
|
| 318 |
+
ablated = ablation_preds[mod][:, mask]
|
| 319 |
+
# Use both MSE change and correlation change
|
| 320 |
+
delta_mse = np.mean((full - ablated) ** 2)
|
| 321 |
+
imp[mod] = float(delta_mse)
|
| 322 |
+
total = sum(imp.values()) + 1e-12
|
| 323 |
+
imp_norm = {k: v / total for k, v in imp.items()}
|
| 324 |
+
region_mod_importance[lid] = imp_norm
|
| 325 |
+
|
| 326 |
+
log.info("Modality ablation done")
|
| 327 |
+
|
| 328 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 329 |
+
# PHASE 5: RSA — Representational Similarity Analysis
|
| 330 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 331 |
+
|
| 332 |
+
log.info("PHASE 5: Representational Similarity Analysis")
|
| 333 |
+
|
| 334 |
+
# Split predictions into temporal segments (2-second windows)
|
| 335 |
+
# and compute RDMs for AI features and brain regions
|
| 336 |
+
|
| 337 |
+
SEGMENT_SIZE = 2 # TRs per segment (adjustable)
|
| 338 |
+
n_segments = T_total // SEGMENT_SIZE
|
| 339 |
+
|
| 340 |
+
# Build segment-level representations
|
| 341 |
+
def build_segments(timeseries, n_segments, seg_size):
|
| 342 |
+
"""Average timeseries within segments."""
|
| 343 |
+
segs = []
|
| 344 |
+
for i in range(n_segments):
|
| 345 |
+
start = i * seg_size
|
| 346 |
+
end = start + seg_size
|
| 347 |
+
if end <= len(timeseries):
|
| 348 |
+
segs.append(timeseries[start:end].mean(axis=0) if timeseries.ndim > 1 else timeseries[start:end].mean())
|
| 349 |
+
return np.array(segs)
|
| 350 |
+
|
| 351 |
+
# Brain RDMs per region
|
| 352 |
+
log.info("Building brain RDMs per region...")
|
| 353 |
+
brain_rdms = {}
|
| 354 |
+
for lid, rinfo in regions.items():
|
| 355 |
+
region_data = brain_preds[:, rinfo["mask"]] # (T, n_verts)
|
| 356 |
+
seg_data = []
|
| 357 |
+
for i in range(n_segments):
|
| 358 |
+
s, e = i * SEGMENT_SIZE, (i + 1) * SEGMENT_SIZE
|
| 359 |
+
if e <= region_data.shape[0]:
|
| 360 |
+
seg_data.append(region_data[s:e].mean(axis=0))
|
| 361 |
+
if len(seg_data) < 3:
|
| 362 |
+
continue
|
| 363 |
+
seg_data = np.array(seg_data) # (n_segments, n_verts)
|
| 364 |
+
if seg_data.std() < 1e-10:
|
| 365 |
+
brain_rdms[lid] = np.zeros((len(seg_data), len(seg_data)))
|
| 366 |
+
else:
|
| 367 |
+
brain_rdms[lid] = squareform(pdist(seg_data, metric="correlation"))
|
| 368 |
+
|
| 369 |
+
# AI feature RDMs per modality per layer
|
| 370 |
+
log.info("Building AI feature RDMs per layer...")
|
| 371 |
+
ai_rdms = {}
|
| 372 |
+
for mod, feats in all_features.items():
|
| 373 |
+
B, L, D, T_batch = feats.shape
|
| 374 |
+
feats_flat = rearrange(feats, 'b l d t -> (b t) l d')
|
| 375 |
+
min_t = min(feats_flat.shape[0], T_total)
|
| 376 |
+
feats_flat = feats_flat[:min_t]
|
| 377 |
+
|
| 378 |
+
for l in range(L):
|
| 379 |
+
layer_data = feats_flat[:, l, :] # (T, D)
|
| 380 |
+
seg_data = []
|
| 381 |
+
for i in range(n_segments):
|
| 382 |
+
s, e = i * SEGMENT_SIZE, (i + 1) * SEGMENT_SIZE
|
| 383 |
+
if e <= layer_data.shape[0]:
|
| 384 |
+
seg_data.append(layer_data[s:e].mean(axis=0))
|
| 385 |
+
if len(seg_data) < 3:
|
| 386 |
+
continue
|
| 387 |
+
seg_data = np.array(seg_data)
|
| 388 |
+
key = f"{mod}_L{l}"
|
| 389 |
+
if seg_data.std() < 1e-10:
|
| 390 |
+
ai_rdms[key] = np.zeros((len(seg_data), len(seg_data)))
|
| 391 |
+
else:
|
| 392 |
+
ai_rdms[key] = squareform(pdist(seg_data, metric="correlation"))
|
| 393 |
+
|
| 394 |
+
# Compare AI RDMs to brain RDMs using Spearman correlation
|
| 395 |
+
log.info("Computing RSA (Spearman correlation between RDMs)...")
|
| 396 |
+
rsa_matrix = np.zeros((len(ai_rdms), len(brain_rdms)))
|
| 397 |
+
ai_rdm_keys = sorted(ai_rdms.keys())
|
| 398 |
+
brain_rdm_keys = sorted(brain_rdms.keys())
|
| 399 |
+
|
| 400 |
+
for i, ak in enumerate(ai_rdm_keys):
|
| 401 |
+
ai_vec = squareform(ai_rdms[ak]) # upper triangle
|
| 402 |
+
if len(ai_vec) == 0:
|
| 403 |
+
continue
|
| 404 |
+
for j, bk in enumerate(brain_rdm_keys):
|
| 405 |
+
brain_vec = squareform(brain_rdms[bk])
|
| 406 |
+
min_len = min(len(ai_vec), len(brain_vec))
|
| 407 |
+
if min_len < 3:
|
| 408 |
+
continue
|
| 409 |
+
rho, _ = stats.spearmanr(ai_vec[:min_len], brain_vec[:min_len])
|
| 410 |
+
rsa_matrix[i, j] = rho if not np.isnan(rho) else 0
|
| 411 |
+
|
| 412 |
+
log.info(f"RSA matrix: {rsa_matrix.shape}")
|
| 413 |
+
|
| 414 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 415 |
+
# PHASE 6: Divergence identification
|
| 416 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 417 |
+
|
| 418 |
+
log.info("PHASE 6: Identifying AI-brain divergences")
|
| 419 |
+
|
| 420 |
+
results = []
|
| 421 |
+
for idx_j, lid in enumerate(brain_rdm_keys):
|
| 422 |
+
rinfo = regions[lid]
|
| 423 |
+
imp = region_mod_importance.get(lid, {})
|
| 424 |
+
|
| 425 |
+
# Best RSA alignment across all AI layers
|
| 426 |
+
rsa_col = rsa_matrix[:, idx_j]
|
| 427 |
+
best_rsa = np.max(np.abs(rsa_col)) if len(rsa_col) > 0 else 0
|
| 428 |
+
best_rsa_layer = ai_rdm_keys[np.argmax(np.abs(rsa_col))] if len(rsa_col) > 0 else "none"
|
| 429 |
+
|
| 430 |
+
# Best encoding correlation across all AI layers
|
| 431 |
+
region_corrs = corr_matrix[:, all_region_ids.index(lid)] if lid in all_region_ids else np.zeros(1)
|
| 432 |
+
best_encoding = np.max(np.abs(region_corrs))
|
| 433 |
+
best_enc_layer = all_layer_keys[np.argmax(np.abs(region_corrs))] if len(region_corrs) > 0 else "none"
|
| 434 |
+
|
| 435 |
+
# Modality entropy
|
| 436 |
+
probs = np.array(list(imp.values())) + 1e-10
|
| 437 |
+
probs = probs / probs.sum()
|
| 438 |
+
entropy = -np.sum(probs * np.log2(probs))
|
| 439 |
+
|
| 440 |
+
# Temporal dynamics
|
| 441 |
+
temporal_var = float(brain_preds[:, rinfo["mask"]].mean(axis=1).var())
|
| 442 |
+
|
| 443 |
+
# Divergence score: high temporal dynamics but poor AI alignment
|
| 444 |
+
divergence = temporal_var * (1 - best_rsa) * entropy
|
| 445 |
+
|
| 446 |
+
results.append({
|
| 447 |
+
"region_id": lid,
|
| 448 |
+
"region": rinfo["name"],
|
| 449 |
+
"hemisphere": "LH" if lid < 100 else "RH", # approximate
|
| 450 |
+
"n_vertices": rinfo["n_vertices"],
|
| 451 |
+
"temporal_variance": temporal_var,
|
| 452 |
+
"best_rsa_alignment": best_rsa,
|
| 453 |
+
"best_rsa_layer": best_rsa_layer,
|
| 454 |
+
"best_encoding_corr": best_encoding,
|
| 455 |
+
"best_encoding_layer": best_enc_layer,
|
| 456 |
+
"modality_entropy": entropy,
|
| 457 |
+
"divergence_score": divergence,
|
| 458 |
+
**{f"imp_{k}": v for k, v in imp.items()},
|
| 459 |
+
})
|
| 460 |
+
|
| 461 |
+
df = pd.DataFrame(results)
|
| 462 |
+
df = df.sort_values("divergence_score", ascending=False)
|
| 463 |
+
df.to_csv(OUT / "full_analysis.csv", index=False)
|
| 464 |
+
|
| 465 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 466 |
+
# PHASE 7: Visualization and reporting
|
| 467 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 468 |
+
|
| 469 |
+
log.info("PHASE 7: Generating visualizations and report")
|
| 470 |
+
|
| 471 |
+
# ── Plot 1: Layer-brain correlation heatmap ──
|
| 472 |
+
fig, ax = plt.subplots(figsize=(20, max(8, len(all_layer_keys) * 0.3)))
|
| 473 |
+
im = ax.imshow(corr_matrix, aspect="auto", cmap="RdBu_r", vmin=-1, vmax=1)
|
| 474 |
+
ax.set_yticks(range(len(all_layer_keys)))
|
| 475 |
+
ax.set_yticklabels(all_layer_keys, fontsize=6)
|
| 476 |
+
ax.set_xticks(range(len(region_names_list)))
|
| 477 |
+
ax.set_xticklabels(region_names_list, fontsize=4, rotation=90)
|
| 478 |
+
ax.set_title("Layer-wise Encoding: AI Layer ↔ Brain Region Correlation", fontsize=14)
|
| 479 |
+
ax.set_ylabel("AI Encoder Layer")
|
| 480 |
+
ax.set_xlabel("Brain Region (Destrieux)")
|
| 481 |
+
plt.colorbar(im, ax=ax, label="Pearson r")
|
| 482 |
+
fig.tight_layout()
|
| 483 |
+
fig.savefig(OUT / "01_layer_brain_correlation.png", dpi=200)
|
| 484 |
+
plt.close(fig)
|
| 485 |
+
log.info("Saved 01_layer_brain_correlation.png")
|
| 486 |
+
|
| 487 |
+
# ── Plot 2: RSA heatmap ──
|
| 488 |
+
fig, ax = plt.subplots(figsize=(20, max(8, len(ai_rdm_keys) * 0.3)))
|
| 489 |
+
im = ax.imshow(rsa_matrix, aspect="auto", cmap="RdBu_r", vmin=-0.5, vmax=0.5)
|
| 490 |
+
ax.set_yticks(range(len(ai_rdm_keys)))
|
| 491 |
+
ax.set_yticklabels(ai_rdm_keys, fontsize=6)
|
| 492 |
+
brain_region_names_rsa = [regions[lid]["name"] for lid in brain_rdm_keys]
|
| 493 |
+
ax.set_xticks(range(len(brain_region_names_rsa)))
|
| 494 |
+
ax.set_xticklabels(brain_region_names_rsa, fontsize=4, rotation=90)
|
| 495 |
+
ax.set_title("RSA: AI Layer ↔ Brain Region Representational Similarity", fontsize=14)
|
| 496 |
+
ax.set_ylabel("AI Encoder Layer")
|
| 497 |
+
ax.set_xlabel("Brain Region (Destrieux)")
|
| 498 |
+
plt.colorbar(im, ax=ax, label="Spearman ρ")
|
| 499 |
+
fig.tight_layout()
|
| 500 |
+
fig.savefig(OUT / "02_rsa_heatmap.png", dpi=200)
|
| 501 |
+
plt.close(fig)
|
| 502 |
+
log.info("Saved 02_rsa_heatmap.png")
|
| 503 |
+
|
| 504 |
+
# ── Plot 3: Divergence scatter ──
|
| 505 |
+
fig, ax = plt.subplots(figsize=(12, 10))
|
| 506 |
+
sc = ax.scatter(
|
| 507 |
+
df["best_rsa_alignment"],
|
| 508 |
+
df["temporal_variance"],
|
| 509 |
+
s=df["n_vertices"] * 0.5,
|
| 510 |
+
c=df["divergence_score"],
|
| 511 |
+
cmap="YlOrRd",
|
| 512 |
+
alpha=0.7,
|
| 513 |
+
edgecolors="k",
|
| 514 |
+
linewidths=0.3,
|
| 515 |
+
)
|
| 516 |
+
# Label top divergent
|
| 517 |
+
for _, row in df.head(12).iterrows():
|
| 518 |
+
ax.annotate(
|
| 519 |
+
row["region"],
|
| 520 |
+
(row["best_rsa_alignment"], row["temporal_variance"]),
|
| 521 |
+
fontsize=6, alpha=0.9,
|
| 522 |
+
arrowprops=dict(arrowstyle="-", alpha=0.3),
|
| 523 |
+
textcoords="offset points", xytext=(5, 5),
|
| 524 |
+
)
|
| 525 |
+
ax.set_xlabel("Best RSA Alignment (max |Spearman ρ| across AI layers)", fontsize=11)
|
| 526 |
+
ax.set_ylabel("Temporal Variance (brain dynamics)", fontsize=11)
|
| 527 |
+
ax.set_title("AI-Brain Divergence Map\nBottom-right = active brain regions poorly captured by AI", fontsize=13)
|
| 528 |
+
plt.colorbar(sc, label="Divergence Score")
|
| 529 |
+
fig.tight_layout()
|
| 530 |
+
fig.savefig(OUT / "03_divergence_scatter.png", dpi=200)
|
| 531 |
+
plt.close(fig)
|
| 532 |
+
log.info("Saved 03_divergence_scatter.png")
|
| 533 |
+
|
| 534 |
+
# ── Plot 4: Modality importance per region (stacked bar) ──
|
| 535 |
+
imp_cols = [c for c in df.columns if c.startswith("imp_")]
|
| 536 |
+
if imp_cols:
|
| 537 |
+
top_30 = df.nlargest(30, "temporal_variance")
|
| 538 |
+
fig, ax = plt.subplots(figsize=(14, 8))
|
| 539 |
+
bottom = np.zeros(len(top_30))
|
| 540 |
+
colors = plt.cm.Set2(np.linspace(0, 1, len(imp_cols)))
|
| 541 |
+
for ci, col in enumerate(imp_cols):
|
| 542 |
+
vals = top_30[col].values
|
| 543 |
+
ax.barh(range(len(top_30)), vals, left=bottom, color=colors[ci],
|
| 544 |
+
label=col.replace("imp_", "").upper())
|
| 545 |
+
bottom += vals
|
| 546 |
+
ax.set_yticks(range(len(top_30)))
|
| 547 |
+
ax.set_yticklabels(top_30["region"].values, fontsize=7)
|
| 548 |
+
ax.set_xlabel("Relative Modality Importance")
|
| 549 |
+
ax.set_title("Modality Contribution per Brain Region (top 30 by dynamics)", fontsize=13)
|
| 550 |
+
ax.legend(loc="lower right")
|
| 551 |
+
fig.tight_layout()
|
| 552 |
+
fig.savefig(OUT / "04_modality_importance.png", dpi=200)
|
| 553 |
+
plt.close(fig)
|
| 554 |
+
log.info("Saved 04_modality_importance.png")
|
| 555 |
+
|
| 556 |
+
# ── Plot 5: Brain surface divergence map ──
|
| 557 |
+
try:
|
| 558 |
+
from nilearn.plotting import plot_surf_stat_map
|
| 559 |
+
|
| 560 |
+
vertex_divergence = np.zeros(N_VERT * 2)
|
| 561 |
+
for _, row in df.iterrows():
|
| 562 |
+
lid = row["region_id"]
|
| 563 |
+
if lid in regions:
|
| 564 |
+
vertex_divergence[regions[lid]["mask"]] = row["divergence_score"]
|
| 565 |
+
|
| 566 |
+
fig = plt.figure(figsize=(16, 12))
|
| 567 |
+
for idx, (hemi, view) in enumerate([
|
| 568 |
+
("left", "lateral"), ("left", "medial"),
|
| 569 |
+
("right", "lateral"), ("right", "medial")
|
| 570 |
+
]):
|
| 571 |
+
ax = fig.add_subplot(2, 2, idx + 1, projection="3d")
|
| 572 |
+
if hemi == "left":
|
| 573 |
+
data = vertex_divergence[:N_VERT]
|
| 574 |
+
else:
|
| 575 |
+
data = vertex_divergence[N_VERT:]
|
| 576 |
+
plot_surf_stat_map(
|
| 577 |
+
fsaverage5[f"pial_{hemi}"],
|
| 578 |
+
data,
|
| 579 |
+
hemi=hemi,
|
| 580 |
+
view=view,
|
| 581 |
+
cmap="YlOrRd",
|
| 582 |
+
title=f"Divergence ({hemi} {view})",
|
| 583 |
+
figure=fig,
|
| 584 |
+
axes=ax,
|
| 585 |
+
)
|
| 586 |
+
fig.suptitle("Brain Surface: AI-Brain Divergence Scores", fontsize=14, y=1.02)
|
| 587 |
+
fig.tight_layout()
|
| 588 |
+
fig.savefig(OUT / "05_brain_surface_divergence.png", dpi=200, bbox_inches="tight")
|
| 589 |
+
plt.close(fig)
|
| 590 |
+
log.info("Saved 05_brain_surface_divergence.png")
|
| 591 |
+
except Exception as e:
|
| 592 |
+
log.warning(f"Brain surface plot failed: {e}")
|
| 593 |
+
|
| 594 |
+
# ── Plot 6: Per-modality best layer alignment profile ──
|
| 595 |
+
for mod in all_features:
|
| 596 |
+
mod_keys = [k for k in all_layer_keys if k.startswith(f"{mod}_")]
|
| 597 |
+
if not mod_keys:
|
| 598 |
+
continue
|
| 599 |
+
mod_indices = [all_layer_keys.index(k) for k in mod_keys]
|
| 600 |
+
mod_corr = corr_matrix[mod_indices, :] # (n_layers_mod, n_regions)
|
| 601 |
+
|
| 602 |
+
fig, ax = plt.subplots(figsize=(14, 6))
|
| 603 |
+
im = ax.imshow(mod_corr, aspect="auto", cmap="RdBu_r", vmin=-1, vmax=1)
|
| 604 |
+
ax.set_yticks(range(len(mod_keys)))
|
| 605 |
+
ax.set_yticklabels([f"Layer {i}" for i in range(len(mod_keys))], fontsize=8)
|
| 606 |
+
ax.set_xticks(range(len(region_names_list)))
|
| 607 |
+
ax.set_xticklabels(region_names_list, fontsize=4, rotation=90)
|
| 608 |
+
ax.set_title(f"{mod.upper()} Encoder: Layer-wise Brain Alignment", fontsize=13)
|
| 609 |
+
ax.set_ylabel("Encoder Layer (early → late)")
|
| 610 |
+
ax.set_xlabel("Brain Region")
|
| 611 |
+
plt.colorbar(im, ax=ax, label="Pearson r")
|
| 612 |
+
fig.tight_layout()
|
| 613 |
+
fig.savefig(OUT / f"06_{mod}_layer_alignment.png", dpi=200)
|
| 614 |
+
plt.close(fig)
|
| 615 |
+
log.info(f"Saved 06_{mod}_layer_alignment.png")
|
| 616 |
+
|
| 617 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 618 |
+
# PHASE 8: Final report
|
| 619 |
+
# ═══════════════════════════════════════════════════════════════════════════════
|
| 620 |
+
|
| 621 |
+
report = []
|
| 622 |
+
report.append("=" * 100)
|
| 623 |
+
report.append("CLOSING THE AI-BRAIN LOOP: Analysis Report")
|
| 624 |
+
report.append(f"Generated: {pd.Timestamp.now()}")
|
| 625 |
+
report.append("=" * 100)
|
| 626 |
+
|
| 627 |
+
report.append("\n\n--- DATASET ---")
|
| 628 |
+
report.append(f"Stimulus: {VIDEO}")
|
| 629 |
+
report.append(f"Total time points: {T_total}")
|
| 630 |
+
report.append(f"Brain vertices: {V} (fsaverage5)")
|
| 631 |
+
report.append(f"Brain regions analyzed: {len(regions)} (Destrieux atlas)")
|
| 632 |
+
report.append(f"AI modalities: {modalities}")
|
| 633 |
+
for mod, feats in all_features.items():
|
| 634 |
+
report.append(f" {mod}: {feats.shape[1]} layers, {feats.shape[2]}-dim features")
|
| 635 |
+
|
| 636 |
+
report.append("\n\n--- TOP 15 DIVERGENCE REGIONS ---")
|
| 637 |
+
report.append("(Brain regions with high dynamics but poor AI alignment)")
|
| 638 |
+
report.append("")
|
| 639 |
+
cols = ["region", "temporal_variance", "best_rsa_alignment", "best_rsa_layer",
|
| 640 |
+
"modality_entropy", "divergence_score"]
|
| 641 |
+
cols_present = [c for c in cols if c in df.columns]
|
| 642 |
+
report.append(df[cols_present].head(15).to_string(index=False, float_format="%.4f"))
|
| 643 |
+
|
| 644 |
+
report.append("\n\n--- TOP 15 WELL-ALIGNED REGIONS ---")
|
| 645 |
+
report.append("(Brain regions where AI encoders match brain representations well)")
|
| 646 |
+
well_aligned = df.nlargest(15, "best_rsa_alignment")
|
| 647 |
+
report.append(well_aligned[cols_present].to_string(index=False, float_format="%.4f"))
|
| 648 |
+
|
| 649 |
+
report.append("\n\n--- MODALITY DOMINANCE ---")
|
| 650 |
+
for mod in modalities:
|
| 651 |
+
col = f"imp_{mod}"
|
| 652 |
+
if col not in df.columns:
|
| 653 |
+
continue
|
| 654 |
+
report.append(f"\n{mod.upper()}-dominated regions (top 5):")
|
| 655 |
+
top = df.nlargest(5, col)
|
| 656 |
+
for _, row in top.iterrows():
|
| 657 |
+
report.append(f" {row['region']:45s} importance={row[col]:.4f} rsa={row['best_rsa_alignment']:.4f}")
|
| 658 |
+
|
| 659 |
+
report.append("\n\n--- ARCHITECTURAL IMPLICATIONS ---")
|
| 660 |
+
report.append("""
|
| 661 |
+
Based on the divergence analysis, here are the gaps in current AI architectures
|
| 662 |
+
and proposed solutions:
|
| 663 |
+
|
| 664 |
+
1. HIGH-ENTROPY DIVERGENCE REGIONS (multiple modalities contribute equally,
|
| 665 |
+
but overall alignment is poor):
|
| 666 |
+
→ The brain performs CROSS-MODAL INTEGRATION that the concatenation-based
|
| 667 |
+
fusion in TRIBE v2 (and most multimodal AI) doesn't capture.
|
| 668 |
+
→ Proposed fix: EARLY FUSION with cross-attention between modality streams
|
| 669 |
+
at intermediate layers, not just late concatenation.
|
| 670 |
+
|
| 671 |
+
2. HIGH TEMPORAL VARIANCE + LOW ALIGNMENT:
|
| 672 |
+
→ The brain has strong TEMPORAL DYNAMICS (prediction, memory, feedback loops)
|
| 673 |
+
that feedforward AI encoders miss entirely.
|
| 674 |
+
→ Proposed fix: Add RECURRENT connections or PREDICTIVE CODING layers that
|
| 675 |
+
generate top-down predictions and propagate prediction errors.
|
| 676 |
+
|
| 677 |
+
3. REGIONS WHERE NO SINGLE LAYER ALIGNS WELL:
|
| 678 |
+
→ The brain's computation in these areas may involve representations that
|
| 679 |
+
DON'T EXIST in any layer of V-JEPA2, LLaMA, or Wav2Vec-BERT.
|
| 680 |
+
→ Proposed fix: Train a NEW encoder objective that explicitly optimizes for
|
| 681 |
+
brain alignment in these gap regions (brain-guided contrastive learning).
|
| 682 |
+
|
| 683 |
+
4. LAYER DEPTH PATTERNS:
|
| 684 |
+
→ If early AI layers align with sensory cortex and late layers align with
|
| 685 |
+
association cortex, this confirms the HIERARCHICAL CORRESPONDENCE between
|
| 686 |
+
DNNs and the cortical hierarchy.
|
| 687 |
+
→ Where this breaks (e.g., late layers don't align with prefrontal cortex),
|
| 688 |
+
it suggests the model lacks EXECUTIVE/ABSTRACT processing.
|
| 689 |
+
""")
|
| 690 |
+
|
| 691 |
+
report.append("\n--- FILES ---")
|
| 692 |
+
report.append(f"Full CSV: {OUT / 'full_analysis.csv'}")
|
| 693 |
+
report.append(f"Plots: {OUT / '*.png'}")
|
| 694 |
+
report.append("=" * 100)
|
| 695 |
+
|
| 696 |
+
report_text = "\n".join(report)
|
| 697 |
+
print(report_text)
|
| 698 |
+
|
| 699 |
+
with open(OUT / "report.txt", "w") as f:
|
| 700 |
+
f.write(report_text)
|
| 701 |
+
|
| 702 |
+
log.info(f"All results saved to {OUT}")
|
| 703 |
+
log.info("Done.")
|
generate_dataset.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Generate a diverse character dataset using Flux 1 + PuLID for LoRA training."""
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import urllib.request
|
| 6 |
+
import time
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
COMFYUI_URL = "http://127.0.0.1:80"
|
| 10 |
+
DATASET_DIR = "/home/azureuser/ai-toolkit/character_dataset"
|
| 11 |
+
TRIGGER_WORD = "ohwx"
|
| 12 |
+
|
| 13 |
+
# Diverse prompts covering different angles, scenes, lighting, outfits
|
| 14 |
+
PROMPTS = [
|
| 15 |
+
# Close-ups / headshots
|
| 16 |
+
(f"close up portrait photo of {TRIGGER_WORD} woman, natural lighting, soft smile, looking at camera, shallow depth of field", "closeup_front"),
|
| 17 |
+
(f"close up portrait of {TRIGGER_WORD} woman, side profile view, golden hour lighting, outdoors", "closeup_profile"),
|
| 18 |
+
(f"close up portrait of {TRIGGER_WORD} woman, three quarter view, studio lighting, neutral expression", "closeup_34"),
|
| 19 |
+
(f"headshot of {TRIGGER_WORD} woman, looking slightly up, soft natural light, gentle smile", "closeup_up"),
|
| 20 |
+
(f"close up of {TRIGGER_WORD} woman, looking down at something in her hands, natural indoor lighting", "closeup_down"),
|
| 21 |
+
|
| 22 |
+
# Half body shots
|
| 23 |
+
(f"half body photo of {TRIGGER_WORD} woman in a white blouse, sitting at a cafe table, warm ambient light, candid", "half_cafe"),
|
| 24 |
+
(f"half body photo of {TRIGGER_WORD} woman wearing a leather jacket, urban street background, overcast day", "half_street"),
|
| 25 |
+
(f"half body photo of {TRIGGER_WORD} woman in athletic wear, gym setting, bright lighting", "half_gym"),
|
| 26 |
+
(f"half body photo of {TRIGGER_WORD} woman in a sundress, garden background, dappled sunlight", "half_garden"),
|
| 27 |
+
(f"half body photo of {TRIGGER_WORD} woman in business attire, modern office, professional lighting", "half_office"),
|
| 28 |
+
|
| 29 |
+
# Full body shots
|
| 30 |
+
(f"full body photo of {TRIGGER_WORD} woman walking on the beach, sunset, casual summer outfit, warm tones", "full_beach"),
|
| 31 |
+
(f"full body photo of {TRIGGER_WORD} woman standing in a city street, winter coat, evening city lights", "full_city"),
|
| 32 |
+
(f"full body photo of {TRIGGER_WORD} woman hiking on a mountain trail, athletic outfit, golden hour", "full_hike"),
|
| 33 |
+
(f"full body photo of {TRIGGER_WORD} woman leaning against a wall, casual jeans and t-shirt, natural daylight", "full_casual"),
|
| 34 |
+
|
| 35 |
+
# Different lighting conditions
|
| 36 |
+
(f"portrait of {TRIGGER_WORD} woman, dramatic side lighting, dark moody background, artistic photo", "light_dramatic"),
|
| 37 |
+
(f"portrait of {TRIGGER_WORD} woman, bright overcast natural light, white background, clean look", "light_bright"),
|
| 38 |
+
(f"portrait of {TRIGGER_WORD} woman, warm golden hour backlight, hair glowing, outdoor", "light_golden"),
|
| 39 |
+
(f"portrait of {TRIGGER_WORD} woman, soft window light, sitting on a couch, cozy indoor setting", "light_window"),
|
| 40 |
+
|
| 41 |
+
# Different expressions
|
| 42 |
+
(f"photo of {TRIGGER_WORD} woman laughing genuinely, candid moment, natural setting", "expr_laugh"),
|
| 43 |
+
(f"photo of {TRIGGER_WORD} woman with a serious contemplative expression, looking into distance", "expr_serious"),
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
# Corresponding captions (without trigger word for variety, will be added by trainer)
|
| 47 |
+
CAPTIONS = {
|
| 48 |
+
"closeup_front": f"close up portrait photo of {TRIGGER_WORD} woman, natural lighting, soft smile, looking at camera",
|
| 49 |
+
"closeup_profile": f"close up portrait of {TRIGGER_WORD} woman, side profile view, golden hour lighting, outdoors",
|
| 50 |
+
"closeup_34": f"close up portrait of {TRIGGER_WORD} woman, three quarter view, studio lighting, neutral expression",
|
| 51 |
+
"closeup_up": f"headshot of {TRIGGER_WORD} woman, looking slightly up, soft natural light, gentle smile",
|
| 52 |
+
"closeup_down": f"close up of {TRIGGER_WORD} woman, looking down, natural indoor lighting",
|
| 53 |
+
"half_cafe": f"half body photo of {TRIGGER_WORD} woman in a white blouse, sitting at a cafe, warm ambient light",
|
| 54 |
+
"half_street": f"half body photo of {TRIGGER_WORD} woman wearing a leather jacket, urban street, overcast day",
|
| 55 |
+
"half_gym": f"half body photo of {TRIGGER_WORD} woman in athletic wear, gym setting, bright lighting",
|
| 56 |
+
"half_garden": f"half body photo of {TRIGGER_WORD} woman in a sundress, garden, dappled sunlight",
|
| 57 |
+
"half_office": f"half body photo of {TRIGGER_WORD} woman in business attire, modern office",
|
| 58 |
+
"full_beach": f"full body photo of {TRIGGER_WORD} woman walking on the beach, sunset, casual summer outfit",
|
| 59 |
+
"full_city": f"full body photo of {TRIGGER_WORD} woman standing in city street, winter coat, evening lights",
|
| 60 |
+
"full_hike": f"full body photo of {TRIGGER_WORD} woman hiking on mountain trail, athletic outfit, golden hour",
|
| 61 |
+
"full_casual": f"full body photo of {TRIGGER_WORD} woman leaning against wall, jeans and t-shirt, daylight",
|
| 62 |
+
"light_dramatic": f"portrait of {TRIGGER_WORD} woman, dramatic side lighting, dark moody background",
|
| 63 |
+
"light_bright": f"portrait of {TRIGGER_WORD} woman, bright overcast light, white background, clean",
|
| 64 |
+
"light_golden": f"portrait of {TRIGGER_WORD} woman, warm golden hour backlight, hair glowing, outdoor",
|
| 65 |
+
"light_window": f"portrait of {TRIGGER_WORD} woman, soft window light, sitting on couch, cozy indoor",
|
| 66 |
+
"expr_laugh": f"photo of {TRIGGER_WORD} woman laughing genuinely, candid moment, natural setting",
|
| 67 |
+
"expr_serious": f"photo of {TRIGGER_WORD} woman with serious contemplative expression, looking into distance",
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def queue_prompt(prompt_text, filename, seed):
|
| 72 |
+
workflow = {
|
| 73 |
+
"1": {"class_type": "UNETLoader", "inputs": {"unet_name": "flux1-dev.safetensors", "weight_dtype": "default"}},
|
| 74 |
+
"2": {"class_type": "DualCLIPLoader", "inputs": {"clip_name1": "t5xxl_fp16.safetensors", "clip_name2": "clip_l.safetensors", "type": "flux"}},
|
| 75 |
+
"3": {"class_type": "VAELoader", "inputs": {"vae_name": "ae.safetensors"}},
|
| 76 |
+
"4": {"class_type": "PulidFluxModelLoader", "inputs": {"pulid_file": "pulid_flux_v0.9.1.safetensors"}},
|
| 77 |
+
"5": {"class_type": "PulidFluxInsightFaceLoader", "inputs": {"provider": "CUDA"}},
|
| 78 |
+
"6": {"class_type": "PulidFluxEvaClipLoader", "inputs": {}},
|
| 79 |
+
"7": {"class_type": "LoadImage", "inputs": {"image": "reference_face.png"}},
|
| 80 |
+
"8": {"class_type": "ApplyPulidFlux", "inputs": {
|
| 81 |
+
"model": ["1", 0], "pulid_flux": ["4", 0], "eva_clip": ["6", 0],
|
| 82 |
+
"face_analysis": ["5", 0], "image": ["7", 0],
|
| 83 |
+
"weight": 0.85, "start_at": 0.0, "end_at": 1.0
|
| 84 |
+
}},
|
| 85 |
+
"9": {"class_type": "CLIPTextEncodeFlux", "inputs": {
|
| 86 |
+
"clip": ["2", 0],
|
| 87 |
+
"clip_l": prompt_text[:77],
|
| 88 |
+
"t5xxl": prompt_text,
|
| 89 |
+
"guidance": 3.5
|
| 90 |
+
}},
|
| 91 |
+
"10": {"class_type": "EmptySD3LatentImage", "inputs": {"width": 1024, "height": 1024, "batch_size": 1}},
|
| 92 |
+
"11": {"class_type": "KSampler", "inputs": {
|
| 93 |
+
"model": ["8", 0], "positive": ["9", 0], "negative": ["9", 0],
|
| 94 |
+
"latent_image": ["10", 0], "seed": seed,
|
| 95 |
+
"control_after_generate": "fixed", "steps": 20, "cfg": 1.0,
|
| 96 |
+
"sampler_name": "euler", "scheduler": "simple", "denoise": 1.0
|
| 97 |
+
}},
|
| 98 |
+
"12": {"class_type": "VAEDecode", "inputs": {"samples": ["11", 0], "vae": ["3", 0]}},
|
| 99 |
+
"13": {"class_type": "SaveImage", "inputs": {"images": ["12", 0], "filename_prefix": f"dataset_{filename}"}}
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
data = json.dumps({"prompt": workflow}).encode()
|
| 103 |
+
req = urllib.request.Request(f'{COMFYUI_URL}/prompt', data=data, headers={'Content-Type': 'application/json'})
|
| 104 |
+
resp = urllib.request.urlopen(req)
|
| 105 |
+
return json.loads(resp.read())['prompt_id']
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def wait_for_completion(prompt_id, timeout=600):
|
| 109 |
+
start = time.time()
|
| 110 |
+
while time.time() - start < timeout:
|
| 111 |
+
req = urllib.request.Request(f'{COMFYUI_URL}/history/{prompt_id}')
|
| 112 |
+
resp = urllib.request.urlopen(req)
|
| 113 |
+
history = json.loads(resp.read())
|
| 114 |
+
if prompt_id in history:
|
| 115 |
+
h = history[prompt_id]
|
| 116 |
+
status = h.get('status', {}).get('status_str', '')
|
| 117 |
+
if status == 'success':
|
| 118 |
+
for nid, out in h['outputs'].items():
|
| 119 |
+
if 'images' in out:
|
| 120 |
+
return out['images'][0]['filename']
|
| 121 |
+
elif status == 'error':
|
| 122 |
+
msgs = h.get('status', {}).get('messages', [])
|
| 123 |
+
for m in msgs:
|
| 124 |
+
if m[0] == 'execution_error':
|
| 125 |
+
print(f" ERROR: {m[1].get('exception_message', 'unknown')[:200]}")
|
| 126 |
+
return None
|
| 127 |
+
time.sleep(2)
|
| 128 |
+
return None
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def main():
|
| 132 |
+
print(f"Generating {len(PROMPTS)} character images for LoRA training...\n")
|
| 133 |
+
|
| 134 |
+
for i, (prompt, name) in enumerate(PROMPTS):
|
| 135 |
+
seed = 10000 + i * 1337
|
| 136 |
+
print(f"[{i+1}/{len(PROMPTS)}] {name} (seed={seed})")
|
| 137 |
+
prompt_id = queue_prompt(prompt, name, seed)
|
| 138 |
+
filename = wait_for_completion(prompt_id)
|
| 139 |
+
|
| 140 |
+
if filename:
|
| 141 |
+
# Copy to dataset folder
|
| 142 |
+
src = f"/home/azureuser/ComfyUI/output/{filename}"
|
| 143 |
+
dst = os.path.join(DATASET_DIR, f"{name}.png")
|
| 144 |
+
os.system(f"cp '{src}' '{dst}'")
|
| 145 |
+
|
| 146 |
+
# Write caption
|
| 147 |
+
caption = CAPTIONS[name]
|
| 148 |
+
with open(os.path.join(DATASET_DIR, f"{name}.txt"), 'w') as f:
|
| 149 |
+
f.write(caption)
|
| 150 |
+
|
| 151 |
+
print(f" -> saved {name}.png + caption")
|
| 152 |
+
else:
|
| 153 |
+
print(f" -> FAILED")
|
| 154 |
+
|
| 155 |
+
print(f"\nDone! Dataset at: {DATASET_DIR}")
|
| 156 |
+
print(f"Files: {len(os.listdir(DATASET_DIR))}")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
|
| 160 |
+
main()
|
generate_dataset_flux2.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Generate diverse Indian woman character dataset using Flux 2 Dev for LoRA training."""
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import urllib.request
|
| 6 |
+
import time
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
COMFYUI_URL = "http://127.0.0.1:80"
|
| 10 |
+
DATASET_DIR = "/home/azureuser/ai-toolkit/character_dataset"
|
| 11 |
+
TRIGGER = "ohwx"
|
| 12 |
+
|
| 13 |
+
# Consistent identity description used in EVERY prompt
|
| 14 |
+
IDENTITY = f"a young 21 year old Indian woman called {TRIGGER}, long straight black hair, brown eyes, light brown skin, soft facial features, natural beauty"
|
| 15 |
+
|
| 16 |
+
# (prompt_suffix, filename, caption)
|
| 17 |
+
SHOTS = [
|
| 18 |
+
# === ANGLES / HEAD POSITIONS ===
|
| 19 |
+
("front facing portrait, looking directly at camera, neutral expression, white background, studio lighting, photorealistic", "angle_front", "front facing portrait, looking at camera, neutral expression, studio lighting"),
|
| 20 |
+
("side profile portrait, left side of face visible, looking left, clean background, studio lighting, photorealistic", "angle_left_profile", "side profile, left side visible, looking left, studio lighting"),
|
| 21 |
+
("side profile portrait, right side of face visible, looking right, clean background, studio lighting, photorealistic", "angle_right_profile", "side profile, right side visible, looking right, studio lighting"),
|
| 22 |
+
("three quarter view portrait, slightly turned to the left, soft lighting, clean background, photorealistic", "angle_34_left", "three quarter view, slightly turned left, soft lighting"),
|
| 23 |
+
("three quarter view portrait, slightly turned to the right, soft lighting, clean background, photorealistic", "angle_34_right", "three quarter view, slightly turned right, soft lighting"),
|
| 24 |
+
("looking up, chin slightly raised, soft lighting from above, clean background, photorealistic portrait", "angle_looking_up", "looking up, chin raised, soft overhead lighting"),
|
| 25 |
+
("looking down, reading a book, natural indoor lighting, soft focus background, candid photo", "angle_looking_down", "looking down reading a book, natural indoor lighting, candid"),
|
| 26 |
+
("back view, looking over shoulder at camera, long black hair visible, outdoor setting, natural light", "angle_back", "back view looking over shoulder, long black hair visible, outdoor natural light"),
|
| 27 |
+
("back of head view, long straight black hair falling down, standing outdoors, natural daylight", "angle_back_full", "back of head, long straight black hair, standing outdoors"),
|
| 28 |
+
("tilted head portrait, slight head tilt to the right, gentle smile, soft lighting, photorealistic", "angle_tilt", "tilted head portrait, slight tilt right, gentle smile"),
|
| 29 |
+
|
| 30 |
+
# === EMOTIONS ===
|
| 31 |
+
("smiling brightly, genuine happy smile showing teeth, warm expression, natural light, photorealistic portrait", "emo_happy", "smiling brightly, genuine happy smile, warm expression, natural light"),
|
| 32 |
+
("sad expression, looking down slightly, melancholic mood, soft moody lighting, photorealistic portrait", "emo_sad", "sad expression, looking down, melancholic, soft moody lighting"),
|
| 33 |
+
("laughing candidly, eyes slightly closed, natural joyful moment, outdoor sunlight, candid photo", "emo_laugh", "laughing candidly, joyful moment, outdoor sunlight, candid"),
|
| 34 |
+
("serious contemplative expression, deep in thought, looking into distance, dramatic lighting, photorealistic", "emo_serious", "serious contemplative expression, deep in thought, dramatic lighting"),
|
| 35 |
+
("surprised expression, eyes wide, mouth slightly open, bright lighting, photorealistic portrait", "emo_surprised", "surprised expression, eyes wide, bright lighting"),
|
| 36 |
+
("shy smile, looking slightly away from camera, subtle smile, soft warm lighting, photorealistic", "emo_shy", "shy smile, looking slightly away, subtle smile, soft warm lighting"),
|
| 37 |
+
("confident expression, strong gaze at camera, slight smirk, professional lighting, photorealistic", "emo_confident", "confident expression, strong gaze, slight smirk, professional lighting"),
|
| 38 |
+
("peaceful serene expression, eyes closed, meditating, soft natural light, photorealistic portrait", "emo_peaceful", "peaceful serene expression, eyes closed, meditating, soft light"),
|
| 39 |
+
|
| 40 |
+
# === HALF BODY / DIFFERENT OUTFITS ===
|
| 41 |
+
("half body photo, wearing traditional Indian salwar kameez, standing in a garden, natural sunlight, photorealistic", "outfit_salwar", "half body, wearing traditional Indian salwar kameez, garden, sunlight"),
|
| 42 |
+
("half body photo, wearing modern casual jeans and white t-shirt, urban street background, daylight, photorealistic", "outfit_casual", "half body, casual jeans and white t-shirt, urban street, daylight"),
|
| 43 |
+
("half body photo, wearing formal black blazer and white shirt, office setting, professional lighting, photorealistic", "outfit_formal", "half body, formal black blazer, office setting, professional lighting"),
|
| 44 |
+
("half body photo, wearing a red saree with gold border, elegant pose, warm indoor lighting, photorealistic", "outfit_saree", "half body, wearing red saree with gold border, elegant pose, warm indoor light"),
|
| 45 |
+
("half body photo, wearing athletic sportswear, gym background, bright fluorescent lighting, photorealistic", "outfit_sport", "half body, athletic sportswear, gym, bright lighting"),
|
| 46 |
+
("half body photo, wearing a cozy sweater, sitting by window, rainy day outside, soft natural light, photorealistic", "outfit_cozy", "half body, cozy sweater, sitting by window, rainy day, soft light"),
|
| 47 |
+
|
| 48 |
+
# === FULL BODY / DIFFERENT SCENES ===
|
| 49 |
+
("full body photo, walking on a beach at sunset, wearing summer dress, golden hour light, barefoot on sand, photorealistic", "full_beach", "full body, walking on beach at sunset, summer dress, golden hour"),
|
| 50 |
+
("full body photo, standing in a city street, wearing winter jacket, evening city lights, urban photography, photorealistic", "full_city", "full body, city street, winter jacket, evening city lights"),
|
| 51 |
+
("full body photo, sitting cross legged on grass in a park, casual outfit, dappled sunlight through trees, photorealistic", "full_park", "full body, sitting on grass in park, casual outfit, dappled sunlight"),
|
| 52 |
+
("full body photo, dancing in the rain, white dress, joyful expression, dramatic wet look, photorealistic", "full_rain", "full body, dancing in rain, white dress, joyful, dramatic"),
|
| 53 |
+
("full body photo, standing at a temple entrance, traditional Indian outfit, warm morning light, photorealistic", "full_temple", "full body, standing at temple entrance, traditional outfit, morning light"),
|
| 54 |
+
|
| 55 |
+
# === DIFFERENT LIGHTING ===
|
| 56 |
+
("portrait, dramatic chiaroscuro lighting, half face in shadow, artistic, high contrast, photorealistic", "light_dramatic", "portrait, dramatic chiaroscuro lighting, half face in shadow, high contrast"),
|
| 57 |
+
("portrait, soft golden hour backlight, hair glowing, warm tones, outdoor, lens flare, photorealistic", "light_golden", "portrait, golden hour backlight, hair glowing, warm tones, outdoor"),
|
| 58 |
+
("portrait, bright overcast daylight, flat even lighting, outdoor, clean natural look, photorealistic", "light_overcast", "portrait, bright overcast daylight, even lighting, outdoor, natural"),
|
| 59 |
+
("portrait, neon colored lights reflecting on face, nighttime, urban setting, cinematic, photorealistic", "light_neon", "portrait, neon lights on face, nighttime, urban, cinematic"),
|
| 60 |
+
("portrait, soft warm candlelight, intimate setting, warm orange tones, photorealistic", "light_candle", "portrait, soft candlelight, intimate setting, warm orange tones"),
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def queue_prompt(prompt_text, filename, seed):
|
| 65 |
+
workflow = {
|
| 66 |
+
"1": {"class_type": "UNETLoader", "inputs": {"unet_name": "flux2_dev_fp8mixed.safetensors", "weight_dtype": "default"}},
|
| 67 |
+
"2": {"class_type": "CLIPLoader", "inputs": {"clip_name": "mistral_3_small_flux2_bf16.safetensors", "type": "flux2"}},
|
| 68 |
+
"3": {"class_type": "VAELoader", "inputs": {"vae_name": "flux2-vae.safetensors"}},
|
| 69 |
+
"9": {"class_type": "CLIPTextEncode", "inputs": {"clip": ["2", 0], "text": prompt_text}},
|
| 70 |
+
"10": {"class_type": "EmptyFlux2LatentImage", "inputs": {"width": 1024, "height": 1024, "batch_size": 1}},
|
| 71 |
+
"11": {"class_type": "KSampler", "inputs": {
|
| 72 |
+
"model": ["1", 0], "positive": ["9", 0], "negative": ["9", 0],
|
| 73 |
+
"latent_image": ["10", 0], "seed": seed,
|
| 74 |
+
"control_after_generate": "fixed", "steps": 20, "cfg": 1.0,
|
| 75 |
+
"sampler_name": "euler", "scheduler": "simple", "denoise": 1.0
|
| 76 |
+
}},
|
| 77 |
+
"12": {"class_type": "VAEDecode", "inputs": {"samples": ["11", 0], "vae": ["3", 0]}},
|
| 78 |
+
"13": {"class_type": "SaveImage", "inputs": {"images": ["12", 0], "filename_prefix": f"ds_{filename}"}}
|
| 79 |
+
}
|
| 80 |
+
data = json.dumps({"prompt": workflow}).encode()
|
| 81 |
+
req = urllib.request.Request(f'{COMFYUI_URL}/prompt', data=data, headers={'Content-Type': 'application/json'})
|
| 82 |
+
resp = urllib.request.urlopen(req)
|
| 83 |
+
return json.loads(resp.read())['prompt_id']
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def wait_for_completion(prompt_id, timeout=600):
|
| 87 |
+
start = time.time()
|
| 88 |
+
while time.time() - start < timeout:
|
| 89 |
+
req = urllib.request.Request(f'{COMFYUI_URL}/history/{prompt_id}')
|
| 90 |
+
resp = urllib.request.urlopen(req)
|
| 91 |
+
history = json.loads(resp.read())
|
| 92 |
+
if prompt_id in history:
|
| 93 |
+
h = history[prompt_id]
|
| 94 |
+
s = h.get('status', {}).get('status_str', '')
|
| 95 |
+
if s == 'success':
|
| 96 |
+
for out in h['outputs'].values():
|
| 97 |
+
if 'images' in out:
|
| 98 |
+
return out['images'][0]['filename']
|
| 99 |
+
elif s == 'error':
|
| 100 |
+
return None
|
| 101 |
+
time.sleep(2)
|
| 102 |
+
return None
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def main():
|
| 106 |
+
print(f"Generating {len(SHOTS)} images with Flux 2 Dev...\n")
|
| 107 |
+
|
| 108 |
+
for i, (suffix, name, caption) in enumerate(SHOTS):
|
| 109 |
+
seed = 42424 + i * 997
|
| 110 |
+
full_prompt = f"{IDENTITY}, {suffix}"
|
| 111 |
+
full_caption = f"photo of {IDENTITY}, {caption}"
|
| 112 |
+
|
| 113 |
+
print(f"[{i+1}/{len(SHOTS)}] {name}")
|
| 114 |
+
prompt_id = queue_prompt(full_prompt, name, seed)
|
| 115 |
+
filename = wait_for_completion(prompt_id)
|
| 116 |
+
|
| 117 |
+
if filename:
|
| 118 |
+
src = f"/home/azureuser/ComfyUI/output/{filename}"
|
| 119 |
+
dst = os.path.join(DATASET_DIR, f"{name}.png")
|
| 120 |
+
os.system(f"cp '{src}' '{dst}'")
|
| 121 |
+
with open(os.path.join(DATASET_DIR, f"{name}.txt"), 'w') as f:
|
| 122 |
+
f.write(full_caption)
|
| 123 |
+
print(f" -> OK")
|
| 124 |
+
else:
|
| 125 |
+
print(f" -> FAILED")
|
| 126 |
+
|
| 127 |
+
total = len([f for f in os.listdir(DATASET_DIR) if f.endswith('.png')])
|
| 128 |
+
print(f"\nDone! {total} images in {DATASET_DIR}")
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
if __name__ == "__main__":
|
| 132 |
+
main()
|
gsoc_proposal_content.md
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GSoC 2026 Proposal: LLM Support for 7B Models (OLMo) in DeepChem
|
| 2 |
+
# CONTENT REFERENCE — REWRITE IN YOUR OWN WORDS BEFORE SUBMITTING
|
| 3 |
+
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
## 1. INTRODUCTION
|
| 7 |
+
|
| 8 |
+
### What this project is about
|
| 9 |
+
DeepChem's HuggingFaceModel wrapper currently supports encoder-only models
|
| 10 |
+
(ChemBERTa, MoLFormer) through masked language modeling. There is no support
|
| 11 |
+
for decoder-only causal language models. This project adds OLMo-2
|
| 12 |
+
(Allen AI's open language model) to DeepChem, enabling:
|
| 13 |
+
- Continued pretraining on molecular data (SMILES)
|
| 14 |
+
- Fine-tuning for classification and regression
|
| 15 |
+
- Autoregressive molecular generation
|
| 16 |
+
|
| 17 |
+
### Why it matters
|
| 18 |
+
Encoder models (ChemBERTa) can classify and predict properties, but they
|
| 19 |
+
CANNOT generate molecules. A causal LM like OLMo opens up:
|
| 20 |
+
- De novo molecular generation (drug discovery)
|
| 21 |
+
- Text-molecule bridging (OLMo understands English AND can learn SMILES)
|
| 22 |
+
- In-context few-shot learning without fine-tuning
|
| 23 |
+
- Transfer learning from scientific literature
|
| 24 |
+
|
| 25 |
+
### Why OLMo specifically
|
| 26 |
+
- Fully open (weights, data, training code) — unlike LLaMA/GPT
|
| 27 |
+
- OLMo-2 is natively supported in HuggingFace transformers (no custom code)
|
| 28 |
+
- 1B and 7B variants available for different compute budgets
|
| 29 |
+
- Trained on Dolma corpus which includes scientific papers
|
| 30 |
+
|
| 31 |
+
### What I've already done (reference your PR and experiments)
|
| 32 |
+
- Found and fixed a transformers 5.x compatibility bug in Chemberta (PR #4913)
|
| 33 |
+
- Filed issue #4912 documenting broader transformers 5.x compat gap
|
| 34 |
+
- Built a working OLMo wrapper prototype (locally) with:
|
| 35 |
+
- Olmo2ForSequenceClassification (doesn't exist in transformers)
|
| 36 |
+
- Causal LM pretraining on SMILES
|
| 37 |
+
- All 8 unit tests passing
|
| 38 |
+
- Ran experiments on real MoleculeNet data:
|
| 39 |
+
- BBBP classification: ROC-AUC 0.67 (random init, tiny model)
|
| 40 |
+
- ESOL regression: R^2 = 0.37
|
| 41 |
+
- SMILES generation: 0% validity (expected — proves pretraining is essential)
|
| 42 |
+
|
| 43 |
+
---
|
| 44 |
+
|
| 45 |
+
## 2. RELEVANT EXPERIENCE & INTEREST
|
| 46 |
+
|
| 47 |
+
### Technical background
|
| 48 |
+
- Parameter Golf (OpenAI competition, March 2026): Trained language models
|
| 49 |
+
from scratch under 16MB constraint. Custom SentencePiece tokenizers,
|
| 50 |
+
GPTQ-lite quantization, flash attention, architecture design (11L 512d
|
| 51 |
+
transformer). This is directly relevant — I understand transformer
|
| 52 |
+
training at a low level.
|
| 53 |
+
- GSPO-DeepSeek-R1-Distill-Qwen-1.5B (15 GitHub stars): Fine-tuning and
|
| 54 |
+
distillation of large language models.
|
| 55 |
+
- wingman-AI (29 GitHub stars): Production AI assistant system.
|
| 56 |
+
- Open source contributions: PRs to HuggingFace transformers, Unsloth,
|
| 57 |
+
Anthropic SDK, OpenAI SDK, Karpathy's nanochat.
|
| 58 |
+
|
| 59 |
+
### Why I want to work on this
|
| 60 |
+
[WRITE THIS YOURSELF — what genuinely interests you about molecular ML?
|
| 61 |
+
Why DeepChem? Be specific and honest. Don't say "I'm passionate about
|
| 62 |
+
open source" — say what specific thing drew you to this project.]
|
| 63 |
+
|
| 64 |
+
### Links
|
| 65 |
+
- GitHub: https://github.com/vivekvar-dl
|
| 66 |
+
- PR #4913: https://github.com/deepchem/deepchem/pull/4913
|
| 67 |
+
- Issue #4912: https://github.com/deepchem/deepchem/issues/4912
|
| 68 |
+
|
| 69 |
+
---
|
| 70 |
+
|
| 71 |
+
## 3. WORK PLAN
|
| 72 |
+
|
| 73 |
+
### 3.1 Design
|
| 74 |
+
|
| 75 |
+
The implementation has four components:
|
| 76 |
+
|
| 77 |
+
**Component A: Base class changes to HuggingFaceModel**
|
| 78 |
+
- Add `causal_lm` task support (DataCollatorForLanguageModeling with mlm=False)
|
| 79 |
+
- Add `AutoModelForCausalLM` branch in load_from_pretrained()
|
| 80 |
+
- Add `generate()` method for autoregressive text generation
|
| 81 |
+
- Add causal LM batch preparation in _prepare_batch()
|
| 82 |
+
|
| 83 |
+
Note: PR #4907 by another contributor adds a similar generate() method.
|
| 84 |
+
My work is complementary — I'm adding a full model wrapper, not just
|
| 85 |
+
generation plumbing.
|
| 86 |
+
|
| 87 |
+
**Component B: Olmo2ForSequenceClassification**
|
| 88 |
+
This class DOES NOT EXIST in HuggingFace transformers. OLMo only has
|
| 89 |
+
OlmoForCausalLM — no classification head. I built one:
|
| 90 |
+
- Extends Olmo2PreTrainedModel
|
| 91 |
+
- Uses last-token pooling (last non-padded token's hidden state)
|
| 92 |
+
- Linear projection head for classification/regression
|
| 93 |
+
- Supports single-label, multi-label, and regression via problem_type config
|
| 94 |
+
- Computes CrossEntropyLoss / BCEWithLogitsLoss / MSELoss based on task
|
| 95 |
+
|
| 96 |
+
This follows the same pattern as LlamaForSequenceClassification.
|
| 97 |
+
|
| 98 |
+
**Component C: OLMo wrapper class**
|
| 99 |
+
```
|
| 100 |
+
OLMo(HuggingFaceModel)
|
| 101 |
+
__init__(task, model_name, n_tasks, config)
|
| 102 |
+
- task: causal_lm | regression | classification | mtr
|
| 103 |
+
- Loads tokenizer from HuggingFace Hub
|
| 104 |
+
- Sets pad_token = eos_token (decoder models don't have pad by default)
|
| 105 |
+
- Syncs vocab_size between config and tokenizer
|
| 106 |
+
- Creates appropriate model class based on task
|
| 107 |
+
|
| 108 |
+
_prepare_batch(batch)
|
| 109 |
+
- causal_lm: labels = input_ids (model shifts internally)
|
| 110 |
+
- regression/classification: labels from dataset, proper dtype casting
|
| 111 |
+
- Multi-task classification: float labels for BCEWithLogitsLoss
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
**Component D: Tokenization strategy**
|
| 115 |
+
Phase 1 (GSoC): Use OLMo's pretrained tokenizer as-is on SMILES.
|
| 116 |
+
- OLMo's 100K BPE vocab actually tokenizes SMILES more efficiently
|
| 117 |
+
than ChemBERTa's 600-token vocab (0.9x token ratio in my analysis)
|
| 118 |
+
- BUT it fragments chemical semantics: [C@@H] -> [C, @@, H, ]
|
| 119 |
+
- ChemBERTa learns chemistry-aware merges: (=O), ccccc, COc
|
| 120 |
+
|
| 121 |
+
Phase 2 (stretch): Extend tokenizer with SMILES-specific tokens.
|
| 122 |
+
- Add special tokens for stereochemistry: [C@@H], [C@H], [nH]
|
| 123 |
+
- Add aromatic ring tokens: c1ccccc1
|
| 124 |
+
- Retrain BPE on mixed English + SMILES corpus
|
| 125 |
+
|
| 126 |
+
### 3.2 Pseudocode
|
| 127 |
+
|
| 128 |
+
Olmo2ForSequenceClassification.forward():
|
| 129 |
+
```
|
| 130 |
+
hidden_states = self.model(input_ids, attention_mask)
|
| 131 |
+
# Pool: use last non-padded token
|
| 132 |
+
seq_lengths = (input_ids != pad_token_id).sum(-1) - 1
|
| 133 |
+
pooled = hidden_states[batch_range, seq_lengths]
|
| 134 |
+
logits = self.score(pooled) # Linear(hidden_size, num_labels)
|
| 135 |
+
if labels:
|
| 136 |
+
if regression: loss = MSELoss(logits, labels)
|
| 137 |
+
if single_class: loss = CrossEntropy(logits, labels)
|
| 138 |
+
if multi_label: loss = BCEWithLogits(logits, labels)
|
| 139 |
+
return {loss, logits}
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
OLMo._prepare_batch() for causal_lm:
|
| 143 |
+
```
|
| 144 |
+
tokens = tokenizer(smiles_list, padding=True)
|
| 145 |
+
input_ids = tokens.input_ids.to(device)
|
| 146 |
+
labels = input_ids.clone() # next-token prediction
|
| 147 |
+
return {input_ids, attention_mask, labels}
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
HuggingFaceModel.generate():
|
| 151 |
+
```
|
| 152 |
+
tokens = tokenizer(inputs, padding=True)
|
| 153 |
+
output_ids = model.generate(**tokens, max_new_tokens=N, **kwargs)
|
| 154 |
+
return tokenizer.batch_decode(output_ids)
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
### 3.3 Testing Plan
|
| 158 |
+
|
| 159 |
+
8 unit tests (all passing in my prototype):
|
| 160 |
+
|
| 161 |
+
| Test | What it validates |
|
| 162 |
+
|------|-------------------|
|
| 163 |
+
| test_olmo_causal_lm_pretraining | Causal LM trains, loss > 0 |
|
| 164 |
+
| test_olmo_regression_finetuning | Regression trains, predictions match shape, MAE computable |
|
| 165 |
+
| test_olmo_classification | Classification on binary labels, loss > 0 |
|
| 166 |
+
| test_olmo_multitask_regression | MTR with 2 tasks, predictions shape matches |
|
| 167 |
+
| test_olmo_save_and_restore | Checkpoint save/load, weights match exactly |
|
| 168 |
+
| test_olmo_load_from_pretrained | Pretrain causal LM -> load into regression model |
|
| 169 |
+
| test_olmo_generate | Single and batch generation returns strings |
|
| 170 |
+
| test_olmo_invalid_task | ValueError on bad task name |
|
| 171 |
+
|
| 172 |
+
All tests use a tiny config (64 hidden, 2 layers, 2 heads) — no model
|
| 173 |
+
download needed, runs in ~27 seconds on CPU.
|
| 174 |
+
|
| 175 |
+
Integration tests (to add during GSoC):
|
| 176 |
+
- MoleculeNet benchmarks: BBBP, ESOL, FreeSolv, Lipophilicity
|
| 177 |
+
- SMILES generation validity (RDKit validation)
|
| 178 |
+
- Continued pretraining convergence on ZINC/PubChem subsets
|
| 179 |
+
|
| 180 |
+
### 3.4 Sources of Risk
|
| 181 |
+
|
| 182 |
+
| Risk | Likelihood | Mitigation |
|
| 183 |
+
|------|-----------|------------|
|
| 184 |
+
| OLMo-7B requires ~14GB VRAM for inference | High | Use OLMo-1B for CI/demos. Test with tiny configs. Document GPU requirements. |
|
| 185 |
+
| SMILES generation validity low without extensive pretraining | High | This IS the core problem. Budget 3 weeks for pretraining experiments. Use ZINC-250K as training corpus. Target >50% validity. |
|
| 186 |
+
| Olmo2ForSequenceClassification not upstream | Medium | Our implementation follows HF patterns exactly. If HF adds it later, we swap to theirs. |
|
| 187 |
+
| Tokenizer fragments chemical semantics | Medium | Phase 1: works as-is (my experiments show learning happens). Phase 2: extend vocabulary. |
|
| 188 |
+
| transformers version compatibility | Low | Already found and fixed one issue (PR #4913). Use top-level imports throughout. |
|
| 189 |
+
|
| 190 |
+
### 3.5 Milestones & Timeline
|
| 191 |
+
|
| 192 |
+
Assuming Medium size (175 hours, ~12 weeks):
|
| 193 |
+
|
| 194 |
+
**Milestone 1: Core wrapper (Weeks 1-3)**
|
| 195 |
+
- PR: Base class changes to HuggingFaceModel (causal_lm task, generate())
|
| 196 |
+
- Coordinate with PR #4907 to avoid duplication
|
| 197 |
+
- PR: Olmo2ForSequenceClassification
|
| 198 |
+
- PR: OLMo wrapper class with all task modes
|
| 199 |
+
- PR: Unit tests (8 tests)
|
| 200 |
+
- Deliverable: `from deepchem.models import OLMo` works for all tasks
|
| 201 |
+
|
| 202 |
+
**Milestone 2: Continued pretraining (Weeks 4-6)**
|
| 203 |
+
- PR: Pretraining pipeline on molecular data (ZINC-250K)
|
| 204 |
+
- PR: Data loading utilities for SMILES corpora
|
| 205 |
+
- PR: Pretraining tutorial notebook
|
| 206 |
+
- Deliverable: Pretrained OLMo checkpoint on molecular data
|
| 207 |
+
|
| 208 |
+
**Milestone 3: Fine-tuning & benchmarks (Weeks 7-9)**
|
| 209 |
+
- PR: Classification tutorial (BBBP, Tox21)
|
| 210 |
+
- PR: Regression tutorial (ESOL, FreeSolv, Lipophilicity)
|
| 211 |
+
- PR: Benchmark results table vs ChemBERTa
|
| 212 |
+
- Deliverable: Published benchmark comparing OLMo vs ChemBERTa on MoleculeNet
|
| 213 |
+
|
| 214 |
+
**Milestone 4: Generation & polish (Weeks 10-12)**
|
| 215 |
+
- PR: SMILES generation tutorial with RDKit validity checking
|
| 216 |
+
- PR: Documentation (numpydoc, API reference, user guide)
|
| 217 |
+
- PR: Tokenizer extension experiments (stretch goal)
|
| 218 |
+
- Deliverable: Complete documentation and tutorials
|
| 219 |
+
|
| 220 |
+
Each milestone = 1 evaluation checkpoint. PRs are <50 lines where possible,
|
| 221 |
+
following DeepChem's contribution guidelines.
|
| 222 |
+
|
| 223 |
+
### 3.6 Pull Request Plan
|
| 224 |
+
|
| 225 |
+
I will follow DeepChem's guidelines: small PRs (<50 lines for initial ones),
|
| 226 |
+
with tests and numpydoc documentation. Expected ~8-12 PRs total:
|
| 227 |
+
|
| 228 |
+
1. HuggingFaceModel causal_lm support (~40 lines)
|
| 229 |
+
2. generate() method (~50 lines)
|
| 230 |
+
3. Olmo2ForSequenceClassification (~100 lines — larger, will discuss with mentor)
|
| 231 |
+
4. OLMo wrapper class (~80 lines)
|
| 232 |
+
5. Unit tests (~180 lines)
|
| 233 |
+
6. Pretraining pipeline
|
| 234 |
+
7. Data utilities
|
| 235 |
+
8. Tutorial notebooks (3-4 notebooks)
|
| 236 |
+
9. Documentation updates
|
| 237 |
+
10. Benchmark scripts
|
| 238 |
+
|
| 239 |
+
---
|
| 240 |
+
|
| 241 |
+
## 4. COMMUNITY ENGAGEMENT
|
| 242 |
+
|
| 243 |
+
- Already contributing: PR #4913 (bug fix), Issue #4912 (compat report)
|
| 244 |
+
- Will attend office hours MWF 9am PST
|
| 245 |
+
- Will join Discord for async discussion
|
| 246 |
+
- Will write weekly progress updates
|
| 247 |
+
- Happy to review other contributors' HuggingFace-related PRs
|
| 248 |
+
|
| 249 |
+
---
|
| 250 |
+
|
| 251 |
+
## 5. RESOURCES REQUIRED
|
| 252 |
+
|
| 253 |
+
- GPU: I have access to 1x H100 NVL 96GB (Azure) for development
|
| 254 |
+
- For CI: tiny model configs, no GPU needed
|
| 255 |
+
- For pretraining experiments: my H100 is sufficient for OLMo-1B
|
| 256 |
+
- OLMo-7B experiments: may need multi-GPU setup (discuss with mentor)
|
| 257 |
+
|
| 258 |
+
---
|
| 259 |
+
|
| 260 |
+
## 6. BIBLIOGRAPHY
|
| 261 |
+
|
| 262 |
+
1. Groeneveld et al. (2024). "OLMo: Accelerating the Science of Language Models." arXiv:2402.00838
|
| 263 |
+
2. Chithrananda et al. (2020). "ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction." arXiv:2010.09885
|
| 264 |
+
3. Ross et al. (2022). "Large-Scale Chemical Language Representations Capture Molecular Structure and Properties." Nature Machine Intelligence.
|
| 265 |
+
4. Weininger (1988). "SMILES, a chemical language and information system." J. Chem. Inf. Comput. Sci.
|
| 266 |
+
5. Wu et al. (2018). "MoleculeNet: A Benchmark for Molecular Machine Learning." Chemical Science.
|
| 267 |
+
|
| 268 |
+
---
|
| 269 |
+
|
| 270 |
+
## KEY NUMBERS FROM YOUR EXPERIMENTS (reference these in proposal)
|
| 271 |
+
|
| 272 |
+
- Tokenization: OLMo uses 0.9x tokens vs ChemBERTa on drug molecules
|
| 273 |
+
- BBBP classification: ROC-AUC 0.67 (random init, 12.9M param model, 200 samples, 3 epochs)
|
| 274 |
+
- ESOL regression: R^2 = 0.37, MAE = 1.27 (same conditions)
|
| 275 |
+
- SMILES generation: 0% validity from random init (proves pretraining is the core challenge)
|
| 276 |
+
- Test suite: 8/8 tests pass in 27 seconds on CPU
|
| 277 |
+
- Stereochemistry fragmentation: [C@@H] splits into 4 tokens in OLMo vs 7 in ChemBERTa
|
gsoc_proposal_final.md
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LLM Support for 7B Models (OLMo) in DeepChem
|
| 2 |
+
|
| 3 |
+
Vivek Varikuti
|
| 4 |
+
github.com/vivekvar-dl
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
## Introduction
|
| 8 |
+
|
| 9 |
+
DeepChems HuggingFaceModel wrapper works well for encoder models. ChemBERTa, MoLFormer, you can do MLM pretraining and finetune for classification or regression. But it has no support for decoder-only models. Like at all. You cant use any GPT style causal LM, and theres no way to do text generation.
|
| 10 |
+
|
| 11 |
+
I want to add OLMo-2 (Allen AIs language model) to DeepChem. The idea is to make it actually useful for molecular work, not just wrap an API. That means you should be able to pretrain on SMILES data, finetune for property prediction, and generate new molecules.
|
| 12 |
+
|
| 13 |
+
The reason I think this matters is that ChemBERTa can predict properties of molecules but it fundamentally cannot generate new ones. Its an encoder. OLMo is a causal LM so it can actually produce novel SMILES. For drug discovery thats a big deal. Also OLMo was trained on the Dolma corpus which has a bunch of scientific papers in it, so theres already some chemistry knowledge baked in before you even start finetuning.
|
| 14 |
+
|
| 15 |
+
Why OLMo and not some other model? Mainly because its actually fully open. Weights, data, code, everything is public. No weird license stuff like LLaMA. And OLMo-2 works natively in HuggingFace transformers without needing custom packages (the older OLMo-7B needs hf_olmo installed which is annoying). It also has a 1B version which is great for testing.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
### What I already did
|
| 19 |
+
|
| 20 |
+
I didnt want to just write a proposal without touching the code so I cloned DeepChem and started building.
|
| 21 |
+
|
| 22 |
+
First thing that happened was ChemBERTa wouldnt even import. Turns out `transformers.models.roberta.tokenization_roberta_fast` got removed in transformers 5.x. Nobody had reported it. I fixed it (PR #4913) and filed a bigger issue about transformers 5.x compatibility (#4912) because theres more stuff broken beyond just the import.
|
| 23 |
+
|
| 24 |
+
After that I started building an OLMo wrapper. Ran into an interesting problem pretty quickly. HuggingFace has OlmoForCausalLM but theres no OlmoForSequenceClassification. It doesnt exist. So if you want to do regression or classification with OLMo you have to build the classification head yourself. I wrote one using last-token pooling, basically the same thing LlamaForSequenceClassification does internally.
|
| 25 |
+
|
| 26 |
+
Got everything working and ran some quick experiments on MoleculeNet:
|
| 27 |
+
|
| 28 |
+
BBBP (blood brain barrier classification): ROC-AUC 0.67. This was with a tiny model, random init, 200 training samples, 3 epochs. Not great but its above 0.5 so the architecture is clearly learning something.
|
| 29 |
+
|
| 30 |
+
ESOL (solubility regression): R squared 0.37. Same deal, tiny model from scratch.
|
| 31 |
+
|
| 32 |
+
SMILES generation: 0% valid molecules. Every single generated SMILES was broken.
|
| 33 |
+
|
| 34 |
+
That generation result is honestly the most useful thing I found. It tells you exactly where the hard problem is. The wrapper code works, the training loop works, all the plumbing is fine. But without real pretraining on a molecular corpus the model just outputs garbage. Thats what this GSoC project needs to solve.
|
| 35 |
+
|
| 36 |
+
I also compared how OLMo and ChemBERTa tokenize drug molecules. Tested on aspirin, caffeine, penicillin, paclitaxel etc. OLMo actually uses fewer tokens overall (100K vocab vs 600) but it breaks up chemistry in weird ways. Like [C@@H] is a single concept (stereocenter) but OLMo splits it into four tokens. ChemBERTas tokenizer learned chemical groupings like (=O) and ccccc that make more sense.
|
| 37 |
+
|
| 38 |
+
Wrote 8 unit tests, all pass in 27 sec on CPU.
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
## About Me
|
| 42 |
+
|
| 43 |
+
Im Vivek. I build AI stuff.
|
| 44 |
+
|
| 45 |
+
The thing thats most relevant here is probably the Parameter Golf competition by OpenAI (happening right now, March 2026). You have to train the best language model that fits in 16MB total. I built custom SentencePiece tokenizers, did GPTQ quantization, designed the transformer architecture from scratch (11 layers, 512 dim, 8 heads). Dealt with flash attention compatibility issues across different hardware. Point is I actually understand how transformers work at a low level, not just how to call .fit() on them.
|
| 46 |
+
|
| 47 |
+
Other stuff: GSPO-DeepSeek-R1-Distill-Qwen-1.5B (15 stars) where I did LLM distillation, wingman-AI (29 stars) which is a production AI assistant.
|
| 48 |
+
|
| 49 |
+
Ive submitted PRs to a bunch of repos. HuggingFace transformers, Unsloth (aarch64 support), Anthropic SDK (fixed a streaming bottleneck), OpenAI python SDK, Karpathys nanochat (NaN loss bug in SFT). Mix of bug fixes and features.
|
| 50 |
+
|
| 51 |
+
For DeepChem I found a real bug on day one (PR #4913), reported the broader compat issue (#4912), and built the OLMo prototype with passing tests.
|
| 52 |
+
|
| 53 |
+
I have GPU access through Azure cloud for development.
|
| 54 |
+
|
| 55 |
+
GitHub: https://github.com/vivekvar-dl
|
| 56 |
+
Bug fix: https://github.com/deepchem/deepchem/pull/4913
|
| 57 |
+
Issue: https://github.com/deepchem/deepchem/issues/4912
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
## Work Plan
|
| 61 |
+
|
| 62 |
+
### Design
|
| 63 |
+
|
| 64 |
+
Four things need to be built.
|
| 65 |
+
|
| 66 |
+
**1. Base class changes (HuggingFaceModel)**
|
| 67 |
+
|
| 68 |
+
Right now HuggingFaceModel doesnt know causal LMs exist. Need to add:
|
| 69 |
+
- "causal_lm" as a task type, using DataCollatorForLanguageModeling(mlm=False)
|
| 70 |
+
- AutoModelForCausalLM branch in load_from_pretrained()
|
| 71 |
+
- A generate() method wrapping HFs model.generate()
|
| 72 |
+
- _prepare_batch() handling for causal LM where labels = input_ids
|
| 73 |
+
|
| 74 |
+
Theres already PR #4907 from someone else that adds generation. My stuff is different, Im building a whole model wrapper not just generation support. But will coordinate so we dont duplicate work.
|
| 75 |
+
|
| 76 |
+
**2. Olmo2ForSequenceClassification**
|
| 77 |
+
|
| 78 |
+
This class doesnt exist in HuggingFace. Had to write it.
|
| 79 |
+
|
| 80 |
+
How it works: run input through Olmo2Model, take the last non-padded tokens hidden state, project through a linear layer. Loss depends on the problem type — MSELoss for regression, CrossEntropyLoss for classification, BCEWithLogitsLoss for multi label. About 100 lines total, same pattern as LlamaForSequenceClassification.
|
| 81 |
+
|
| 82 |
+
**3. OLMo wrapper class**
|
| 83 |
+
|
| 84 |
+
User facing class extending HuggingFaceModel. Same structure as ChemBERTa/MoLFormer.
|
| 85 |
+
|
| 86 |
+
```
|
| 87 |
+
OLMo(HuggingFaceModel)
|
| 88 |
+
__init__(task, model_name, n_tasks, config)
|
| 89 |
+
task: causal_lm | regression | classification | mtr
|
| 90 |
+
Loads tokenizer, sets pad_token = eos_token
|
| 91 |
+
Syncs vocab_size with tokenizer
|
| 92 |
+
Picks model class based on task
|
| 93 |
+
|
| 94 |
+
_prepare_batch(batch)
|
| 95 |
+
causal_lm: labels = input_ids clone
|
| 96 |
+
regression/classification: labels from dataset, right dtype
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
**4. Tokenization**
|
| 100 |
+
|
| 101 |
+
For now, use OLMos tokenizer as is on SMILES. My experiments show it works well enough to learn (0.67 AUC, 0.37 R2 from random init even). Not perfect with stereocenters but functional.
|
| 102 |
+
|
| 103 |
+
Stretch goal: extend the vocab with chemistry tokens like [C@@H], (=O), aromatic rings. Retrain BPE on English+SMILES mix.
|
| 104 |
+
|
| 105 |
+
### Pseudocode
|
| 106 |
+
|
| 107 |
+
Sequence classification forward:
|
| 108 |
+
```
|
| 109 |
+
hidden = base_model(input_ids, attention_mask)
|
| 110 |
+
seq_lengths = (input_ids != pad_id).sum(-1) - 1
|
| 111 |
+
pooled = hidden[range(batch_size), seq_lengths]
|
| 112 |
+
logits = linear_head(pooled)
|
| 113 |
+
loss = compute_loss(logits, labels, problem_type)
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
Causal LM batch:
|
| 117 |
+
```
|
| 118 |
+
tokens = tokenizer(smiles_list, padding=True)
|
| 119 |
+
inputs = {input_ids, attention_mask, labels: input_ids.clone()}
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
Generate:
|
| 123 |
+
```
|
| 124 |
+
encoded = tokenizer(prompts, padding=True)
|
| 125 |
+
output_ids = model.generate(**encoded, max_new_tokens=N)
|
| 126 |
+
return tokenizer.batch_decode(output_ids)
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
### Testing
|
| 130 |
+
|
| 131 |
+
8 tests written and passing:
|
| 132 |
+
|
| 133 |
+
1. Causal LM pretraining (loss > 0)
|
| 134 |
+
2. Regression finetuning (correct prediction shape)
|
| 135 |
+
3. Classification (binary labels)
|
| 136 |
+
4. Multitask regression (2 targets)
|
| 137 |
+
5. Save and restore checkpoint (weights match)
|
| 138 |
+
6. Load pretrained into regression model
|
| 139 |
+
7. Generation (single + batch)
|
| 140 |
+
8. Invalid task raises error
|
| 141 |
+
|
| 142 |
+
Tiny config, no downloads, 27 sec on CPU.
|
| 143 |
+
|
| 144 |
+
Will add during GSoC:
|
| 145 |
+
- MoleculeNet benchmarks (BBBP, ESOL, FreeSolv, Lipophilicity)
|
| 146 |
+
- Generation validity checking with RDKit
|
| 147 |
+
- Pretraining convergence on ZINC
|
| 148 |
+
|
| 149 |
+
### Risks
|
| 150 |
+
|
| 151 |
+
Generation quality is the big one. 0% validity from random init is expected obviously but getting to 50%+ valid SMILES needs real pretraining on a decent corpus. Im allocating 3 weeks for this using ZINC-250K.
|
| 152 |
+
|
| 153 |
+
OLMo-7B is ~14GB just for inference. CI uses tiny configs so no GPU needed there. OLMo-1B for demos. 7B for real benchmarks, might need multi GPU, will figure that out with mentor.
|
| 154 |
+
|
| 155 |
+
Olmo2ForSequenceClassification isnt upstream. If HF adds one later we swap ours out.
|
| 156 |
+
|
| 157 |
+
Transformers compat — already found one issue, using top level imports everywhere going forward.
|
| 158 |
+
|
| 159 |
+
### Timeline
|
| 160 |
+
|
| 161 |
+
12 weeks (Medium, 175 hours):
|
| 162 |
+
|
| 163 |
+
**Weeks 1-3**
|
| 164 |
+
Get the core wrapper merged. Small PRs, follow DeepChems contribution guidelines. Base class changes, Olmo2ForSequenceClassification, OLMo wrapper, tests. By week 3 you should be able to do `from deepchem.models import OLMo`.
|
| 165 |
+
|
| 166 |
+
**Weeks 4-6**
|
| 167 |
+
Pretraining pipeline. Load SMILES from ZINC-250K, causal LM training, checkpointing. Tutorial notebook for pretraining on custom molecular data.
|
| 168 |
+
|
| 169 |
+
**Weeks 7-9**
|
| 170 |
+
Finetune pretrained model on MoleculeNet. BBBP and Tox21 classification, ESOL and FreeSolv and Lipophilicity regression. Benchmark table vs ChemBERTa.
|
| 171 |
+
|
| 172 |
+
**Weeks 10-12**
|
| 173 |
+
Generation experiments with RDKit validity checking. Tutorial notebooks. Docs. If time allows, tokenizer extension experiments.
|
| 174 |
+
|
| 175 |
+
### PRs
|
| 176 |
+
|
| 177 |
+
Small PRs especially at first. Bigger ones (Olmo2ForSequenceClassification ~100 lines) I will discuss with mentor before submitting. Expecting 8-12 PRs across the summer.
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
## Community
|
| 181 |
+
|
| 182 |
+
Already in:
|
| 183 |
+
- PR #4913 (bug fix)
|
| 184 |
+
- Issue #4912 (compat report)
|
| 185 |
+
|
| 186 |
+
Going to do office hours MWF 9am PST, Discord for regular questions, weekly updates.
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
## Resources
|
| 190 |
+
|
| 191 |
+
GPU access through Azure cloud. Tiny configs for CI. OLMo-7B training setup to be discussed with mentor depending on whats needed.
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
## References
|
| 195 |
+
|
| 196 |
+
1. Groeneveld et al. (2024). OLMo: Accelerating the Science of Language Models. arXiv:2402.00838
|
| 197 |
+
2. Chithrananda et al. (2020). ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction. arXiv:2010.09885
|
| 198 |
+
3. Ross et al. (2022). Large-Scale Chemical Language Representations Capture Molecular Structure and Properties. Nature Machine Intelligence
|
| 199 |
+
4. Weininger (1988). SMILES, a chemical language and information system. J Chem Inf Comput Sci
|
| 200 |
+
5. Wu et al. (2018). MoleculeNet: A Benchmark for Molecular Machine Learning. Chemical Science
|
h100_training.log
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
WARNING ⚠️ user config directory '/home/azureuser/.config/Ultralytics' is not writable, using '/tmp/Ultralytics'. Set YOLO_CONFIG_DIR to override.
|
| 2 |
+
Creating new Ultralytics Settings v0.0.6 file ✅
|
| 3 |
+
View Ultralytics Settings with 'yolo settings' or at '/tmp/Ultralytics/settings.json'
|
| 4 |
+
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.
|
| 5 |
+
CUDA initialization: The NVIDIA driver on your system is too old (found version 12080). Please update your GPU driver by downloading and installing a new version from the URL: http://www.nvidia.com/Download/index.aspx Alternatively, go to: https://pytorch.org to install a PyTorch version that has been compiled with your version of the CUDA driver. (Triggered internally at /pytorch/c10/cuda/CUDAFunctions.cpp:119.)
|
| 6 |
+
======================================================================
|
| 7 |
+
FINAL TRAINING ON H100 - BALANCED DATASET
|
| 8 |
+
======================================================================
|
| 9 |
+
|
| 10 |
+
GPU Available: False
|
| 11 |
+
|
| 12 |
+
======================================================================
|
| 13 |
+
STEP 1: Downloading Datasets from Roboflow
|
| 14 |
+
======================================================================
|
| 15 |
+
|
| 16 |
+
Dataset 1: New helmet images (212)...
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
Dataset 2: No-helmet images (499)...
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
Dataset 3: With-helmet images (300)...
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
Dataset 4: Triple-riding (626)...
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
✅ All datasets downloaded!
|
| 45 |
+
|
| 46 |
+
======================================================================
|
| 47 |
+
STEP 2: Merging ALL Datasets
|
| 48 |
+
======================================================================
|
| 49 |
+
|
| 50 |
+
Unified classes (8): ['Helmet', 'Motorcycle', 'Rider', 'Triple Riding', 'helmet', 'more-than-2-person-on-2-wheeler', 'no helmet', 'with helmet']
|
| 51 |
+
|
| 52 |
+
Copying datasets...
|
| 53 |
+
helmet212: 206 images
|
| 54 |
+
nohelmet499: 496 images
|
| 55 |
+
withhelmet300: 242 images
|
| 56 |
+
triple626: 626 images
|
| 57 |
+
|
| 58 |
+
Final merged dataset:
|
| 59 |
+
train: 1335 images
|
| 60 |
+
valid: 186 images
|
| 61 |
+
test: 116 images
|
| 62 |
+
|
| 63 |
+
Config saved: /home/azureuser/final_merged_h100/data.yaml
|
| 64 |
+
|
| 65 |
+
======================================================================
|
| 66 |
+
STEP 3: TRAINING ON H100 (96GB VRAM!)
|
| 67 |
+
======================================================================
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
Training config:
|
| 71 |
+
Model: YOLO26m
|
| 72 |
+
Epochs: 150 (faster with H100)
|
| 73 |
+
Batch: -1 (auto - H100 can handle 64-128!)
|
| 74 |
+
Image size: 640
|
| 75 |
+
Classes: 8
|
| 76 |
+
|
| 77 |
+
Starting training...
|
| 78 |
+
Ultralytics 8.4.37 🚀 Python-3.12.3 torch-2.11.0+cu130
|
| 79 |
+
Traceback (most recent call last):
|
| 80 |
+
File "/home/azureuser/train_h100_final.py", line 182, in <module>
|
| 81 |
+
results = model.train(
|
| 82 |
+
^^^^^^^^^^^^
|
| 83 |
+
File "/home/azureuser/yolo_h100_env/lib/python3.12/site-packages/ultralytics/engine/model.py", line 781, in train
|
| 84 |
+
self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
|
| 85 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 86 |
+
File "/home/azureuser/yolo_h100_env/lib/python3.12/site-packages/ultralytics/models/yolo/detect/train.py", line 63, in __init__
|
| 87 |
+
super().__init__(cfg, overrides, _callbacks)
|
| 88 |
+
File "/home/azureuser/yolo_h100_env/lib/python3.12/site-packages/ultralytics/engine/trainer.py", line 128, in __init__
|
| 89 |
+
self.device = select_device(self.args.device)
|
| 90 |
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
| 91 |
+
File "/home/azureuser/yolo_h100_env/lib/python3.12/site-packages/ultralytics/utils/torch_utils.py", line 230, in select_device
|
| 92 |
+
raise ValueError(
|
| 93 |
+
ValueError: Invalid CUDA 'device=0' requested. Use 'device=cpu' or pass valid CUDA device(s) if available, i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.
|
| 94 |
+
|
| 95 |
+
torch.cuda.is_available(): False
|
| 96 |
+
torch.cuda.device_count(): 1
|
| 97 |
+
os.environ['CUDA_VISIBLE_DEVICES']: None
|
| 98 |
+
|
make_proposal_doc.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from docx import Document
|
| 2 |
+
from docx.shared import Pt, Inches
|
| 3 |
+
from docx.enum.text import WD_ALIGN_PARAGRAPH
|
| 4 |
+
|
| 5 |
+
doc = Document()
|
| 6 |
+
|
| 7 |
+
style = doc.styles['Normal']
|
| 8 |
+
font = style.font
|
| 9 |
+
font.name = 'Arial'
|
| 10 |
+
font.size = Pt(11)
|
| 11 |
+
|
| 12 |
+
# Title
|
| 13 |
+
title = doc.add_heading('LLM Support for 7B Models (OLMo) in DeepChem', level=1)
|
| 14 |
+
title.alignment = WD_ALIGN_PARAGRAPH.CENTER
|
| 15 |
+
|
| 16 |
+
p = doc.add_paragraph()
|
| 17 |
+
p.alignment = WD_ALIGN_PARAGRAPH.CENTER
|
| 18 |
+
run = p.add_run('V SaraVivek (Vivek Varikuti)')
|
| 19 |
+
run.bold = True
|
| 20 |
+
run.font.size = Pt(12)
|
| 21 |
+
p.add_run('\nvivekvarikuti22@gmail.com | github.com/vivekvar-dl | vivekvari.dev')
|
| 22 |
+
|
| 23 |
+
doc.add_paragraph('')
|
| 24 |
+
|
| 25 |
+
# ============ INTRODUCTION ============
|
| 26 |
+
doc.add_heading('Introduction', level=2)
|
| 27 |
+
|
| 28 |
+
doc.add_paragraph(
|
| 29 |
+
'DeepChems HuggingFaceModel wrapper does a solid job with encoder models. '
|
| 30 |
+
'ChemBERTa, MoLFormer, masked language modeling, classification, regression — all of that works. '
|
| 31 |
+
'But theres a pretty fundamental gap. It has zero support for decoder-only causal language models. '
|
| 32 |
+
'No GPT-style models, no text generation, nothing autoregressive.'
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
doc.add_paragraph(
|
| 36 |
+
'This project adds OLMo-2 (Allen AIs open language model) to DeepChem. '
|
| 37 |
+
'The goal isnt just to wrap another HuggingFace model though. '
|
| 38 |
+
'Its to make a causal LM genuinely useful for molecular science. '
|
| 39 |
+
'That means continued pretraining on SMILES strings, finetuning for property prediction tasks, '
|
| 40 |
+
'and most importantly — generating new molecules.'
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
doc.add_paragraph(
|
| 44 |
+
'Why does this matter for the community? ChemBERTa can tell you about a molecule but it cant make new ones. '
|
| 45 |
+
'Its an encoder, thats not what encoders do. A causal LM like OLMo can actually produce novel SMILES strings. '
|
| 46 |
+
'For drug discovery and molecular design thats a really big deal. '
|
| 47 |
+
'And because OLMo was pretrained on the Dolma corpus (which includes tons of scientific papers), '
|
| 48 |
+
'it already has some chemistry knowledge baked in before you even start finetuning on molecular data.'
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
doc.add_paragraph(
|
| 52 |
+
'I picked OLMo specifically because its fully open — weights, training data, code, all of it. '
|
| 53 |
+
'No license restrictions like LLaMA. OLMo-2 is natively supported in HuggingFace transformers so you '
|
| 54 |
+
'dont need custom packages. And it has 1B and 7B variants which is nice for development vs production.'
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
doc.add_heading('What I already built', level=3)
|
| 58 |
+
|
| 59 |
+
doc.add_paragraph(
|
| 60 |
+
'I cloned DeepChem and started building before writing this proposal. '
|
| 61 |
+
'First thing that happened — ChemBERTa wouldnt import. The module '
|
| 62 |
+
'transformers.models.roberta.tokenization_roberta_fast got removed in transformers 5.x '
|
| 63 |
+
'and nobody had caught it. I fixed it in PR #4913 and filed issue #4912 about the broader '
|
| 64 |
+
'transformers 5.x compatibility problems. Thats how I got into the codebase.'
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
doc.add_paragraph(
|
| 68 |
+
'Then I started on the OLMo wrapper. Hit an interesting problem right away — '
|
| 69 |
+
'HuggingFace doesnt have OlmoForSequenceClassification. The class just doesnt exist. '
|
| 70 |
+
'So you cant do regression or classification with OLMo out of the box. '
|
| 71 |
+
'I wrote one from scratch using last-token pooling (same approach as LlamaForSequenceClassification). '
|
| 72 |
+
'Also added causal_lm as a task type in HuggingFaceModel and built a generate() method.'
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
doc.add_paragraph(
|
| 76 |
+
'Ran experiments on MoleculeNet to see if the thing actually works:'
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
doc.add_paragraph(
|
| 80 |
+
'BBBP (blood brain barrier): ROC-AUC 0.67 — tiny random init model, 200 samples, 3 epochs. '
|
| 81 |
+
'Not amazing but clearly above random chance so the architecture learns.', style='List Bullet'
|
| 82 |
+
)
|
| 83 |
+
doc.add_paragraph(
|
| 84 |
+
'ESOL (solubility): R squared 0.37 — same conditions.', style='List Bullet'
|
| 85 |
+
)
|
| 86 |
+
doc.add_paragraph(
|
| 87 |
+
'SMILES generation: 0% valid molecules. Everything it generated was broken SMILES.', style='List Bullet'
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
doc.add_paragraph(
|
| 91 |
+
'That 0% generation result is actually the most important finding. '
|
| 92 |
+
'It tells you the wrapper works fine, training works fine, but without serious pretraining on '
|
| 93 |
+
'a molecular corpus the model just outputs nonsense. Thats the core problem this project needs to solve.'
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
doc.add_paragraph(
|
| 97 |
+
'I also compared tokenization — OLMo vs ChemBERTa on real drugs (aspirin, caffeine, penicillin, paclitaxel). '
|
| 98 |
+
'OLMo uses fewer tokens overall (100K vocab vs 600) but fragments chemical concepts. '
|
| 99 |
+
'[C@@H] which is one thing in chemistry (a stereocenter) gets split into 4 tokens. '
|
| 100 |
+
'ChemBERTa learned better groupings like (=O) for carbonyl. '
|
| 101 |
+
'Something to address in the stretch goals.'
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
doc.add_paragraph('8 unit tests written, all passing in 27 seconds on CPU.')
|
| 105 |
+
|
| 106 |
+
# ============ RELEVANT EXPERIENCE ============
|
| 107 |
+
doc.add_heading('Relevant Experience and Interest', level=2)
|
| 108 |
+
|
| 109 |
+
doc.add_paragraph(
|
| 110 |
+
'Im Vivek, I just finished my B.Tech in AI & ML from Usha Rama College of Engineering (2021-2025). '
|
| 111 |
+
'Ive been working as an AI Engineer for the past year building production systems.'
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
doc.add_heading('Work experience', level=3)
|
| 115 |
+
|
| 116 |
+
doc.add_paragraph(
|
| 117 |
+
'Right now Im working with the Andhra Pradesh Police on a government AI initiative (AI4AP). '
|
| 118 |
+
'I built an AI legal compliance system for POCSO cases — RAG pipeline over 1000+ legal documents, '
|
| 119 |
+
'FastAPI backend, vector embeddings for citation-backed responses. Its a real production system that '
|
| 120 |
+
'investigating officers actually use every day. This gave me solid experience with large scale NLP, '
|
| 121 |
+
'document processing, and building things that need to work reliably.'
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
doc.add_paragraph(
|
| 125 |
+
'Before that I did an ML internship at GGS Information Services where I worked on '
|
| 126 |
+
'3D model compression (custom GANs, STEP file processing) and optimized inference pipelines '
|
| 127 |
+
'using CUDA kernels and model quantization — 30% latency reduction, 50% memory decrease. '
|
| 128 |
+
'The CUDA and quantization experience is directly relevant to working with large language models.'
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
doc.add_heading('Relevant projects', level=3)
|
| 132 |
+
|
| 133 |
+
doc.add_paragraph(
|
| 134 |
+
'Parameter Golf (OpenAI, March 2026): Train the best LM that fits in 16MB. '
|
| 135 |
+
'I designed the transformer architecture from scratch — 11 layers, 512 dim, 8 heads. '
|
| 136 |
+
'Custom SentencePiece tokenizers, GPTQ quantization, flash attention. '
|
| 137 |
+
'This isnt tutorial-level stuff, I understand how transformers train from the ground up.'
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
doc.add_paragraph(
|
| 141 |
+
'GSPO-DeepSeek-R1-Distill-Qwen-1.5B (15 GitHub stars): Implemented the GSPO algorithm from '
|
| 142 |
+
'the Qwen team. Got 60% accuracy on ZebraLogic, 75.8% on math. '
|
| 143 |
+
'Beat PPO and GRPO baselines.'
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
doc.add_paragraph(
|
| 147 |
+
'Dial 112 AI: Built a production system for AP Police processing 1000+ emergency calls daily. '
|
| 148 |
+
'Speech-to-text, sentiment analysis, priority classification, geospatial dispatch optimization.'
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
doc.add_paragraph(
|
| 152 |
+
'Published a paper on pose-guided image generation (PPAG) — '
|
| 153 |
+
'progressive pose attention for identity-preserving synthesis, 92% identity preservation score.'
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
doc.add_heading('Open source', level=3)
|
| 157 |
+
|
| 158 |
+
doc.add_paragraph(
|
| 159 |
+
'PRs to HuggingFace transformers, Unsloth (aarch64 support), Anthropic SDK (streaming perf fix), '
|
| 160 |
+
'OpenAI SDK, Karpathys nanochat (NaN loss fix for SFT training).'
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
doc.add_paragraph(
|
| 164 |
+
'DeepChem contributions: PR #4913 (ChemBERTa import fix for transformers 5.x), '
|
| 165 |
+
'Issue #4912 (transformers compatibility report).'
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
doc.add_heading('Why this project', level=3)
|
| 169 |
+
|
| 170 |
+
doc.add_paragraph(
|
| 171 |
+
'[FILL THIS IN YOURSELF — what got you interested in molecular ML? '
|
| 172 |
+
'Why DeepChem specifically? Be honest here, mentors can tell when someone is just saying '
|
| 173 |
+
'what they think you want to hear. Talk about what actually excites you.]'
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
p = doc.add_paragraph()
|
| 177 |
+
p.add_run('GitHub: ').bold = True
|
| 178 |
+
p.add_run('https://github.com/vivekvar-dl')
|
| 179 |
+
p = doc.add_paragraph()
|
| 180 |
+
p.add_run('LinkedIn: ').bold = True
|
| 181 |
+
p.add_run('https://linkedin.com/in/vivekvar')
|
| 182 |
+
p = doc.add_paragraph()
|
| 183 |
+
p.add_run('Bug fix PR: ').bold = True
|
| 184 |
+
p.add_run('https://github.com/deepchem/deepchem/pull/4913')
|
| 185 |
+
p = doc.add_paragraph()
|
| 186 |
+
p.add_run('Compat issue: ').bold = True
|
| 187 |
+
p.add_run('https://github.com/deepchem/deepchem/issues/4912')
|
| 188 |
+
|
| 189 |
+
# ============ WORK PLAN ============
|
| 190 |
+
doc.add_heading('Work Plan', level=2)
|
| 191 |
+
|
| 192 |
+
doc.add_paragraph(
|
| 193 |
+
'The project breaks down into four components that build on each other. '
|
| 194 |
+
'First the base class needs to learn about causal LMs, then we build the OLMo specific stuff on top, '
|
| 195 |
+
'then pretraining and benchmarks, then generation and docs.'
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
# Design
|
| 199 |
+
doc.add_heading('Design and Pseudocode', level=3)
|
| 200 |
+
|
| 201 |
+
doc.add_paragraph().add_run('Component 1: HuggingFaceModel base class changes').bold = True
|
| 202 |
+
doc.add_paragraph(
|
| 203 |
+
'The wrapper currently doesnt know causal LMs exist. Need to add "causal_lm" as a task type '
|
| 204 |
+
'with DataCollatorForLanguageModeling(mlm=False), an AutoModelForCausalLM branch in '
|
| 205 |
+
'load_from_pretrained(), a generate() method wrapping HFs model.generate() API, and '
|
| 206 |
+
'_prepare_batch() handling where labels = input_ids for next token prediction.'
|
| 207 |
+
)
|
| 208 |
+
doc.add_paragraph(
|
| 209 |
+
'Note: PR #4907 from another contributor adds a generate() method too. '
|
| 210 |
+
'My work is different — building a complete model wrapper, not just generation. '
|
| 211 |
+
'Will coordinate to avoid overlap.'
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
doc.add_paragraph().add_run('Component 2: Olmo2ForSequenceClassification').bold = True
|
| 215 |
+
doc.add_paragraph(
|
| 216 |
+
'This class doesnt exist in HuggingFace. I already built it. '
|
| 217 |
+
'Takes Olmo2Model output, grabs the last non-padded tokens hidden state (last-token pooling), '
|
| 218 |
+
'projects through a linear layer. Loss computation depends on problem_type — '
|
| 219 |
+
'MSELoss for regression, CrossEntropyLoss for single label, BCEWithLogitsLoss for multi label. '
|
| 220 |
+
'About 100 lines, follows same pattern as LlamaForSequenceClassification.'
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
doc.add_paragraph('Forward pass pseudocode:')
|
| 224 |
+
doc.add_paragraph(
|
| 225 |
+
'hidden = base_model(input_ids, attention_mask)\n'
|
| 226 |
+
'seq_lengths = (input_ids != pad_id).sum(-1) - 1\n'
|
| 227 |
+
'pooled = hidden[range(batch_size), seq_lengths]\n'
|
| 228 |
+
'logits = linear_head(pooled)\n'
|
| 229 |
+
'loss = compute_loss(logits, labels, problem_type)',
|
| 230 |
+
style='No Spacing'
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
doc.add_paragraph().add_run('Component 3: OLMo wrapper class').bold = True
|
| 234 |
+
doc.add_paragraph(
|
| 235 |
+
'User-facing class extending HuggingFaceModel. Same pattern as ChemBERTa and MoLFormer.'
|
| 236 |
+
)
|
| 237 |
+
doc.add_paragraph(
|
| 238 |
+
'OLMo.__init__(task, model_name, n_tasks, config)\n'
|
| 239 |
+
' task: causal_lm | regression | classification | mtr\n'
|
| 240 |
+
' Loads tokenizer, sets pad_token = eos_token\n'
|
| 241 |
+
' Syncs vocab_size with tokenizer\n'
|
| 242 |
+
' Creates right model class based on task\n\n'
|
| 243 |
+
'OLMo._prepare_batch(batch)\n'
|
| 244 |
+
' causal_lm: labels = input_ids clone\n'
|
| 245 |
+
' regression: float labels\n'
|
| 246 |
+
' classification: long (single) or float (multi)',
|
| 247 |
+
style='No Spacing'
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
doc.add_paragraph().add_run('Component 4: Tokenization strategy').bold = True
|
| 251 |
+
doc.add_paragraph(
|
| 252 |
+
'Phase 1: Use OLMos tokenizer directly on SMILES. Works — '
|
| 253 |
+
'my experiments show 0.67 AUC and 0.37 R2 from random init which means learning is happening. '
|
| 254 |
+
'Not perfect with stereocenters but functional.'
|
| 255 |
+
)
|
| 256 |
+
doc.add_paragraph(
|
| 257 |
+
'Stretch: Extend vocab with chemistry tokens — [C@@H], [nH], (=O), aromatic ring patterns. '
|
| 258 |
+
'Retrain BPE on mixed English + SMILES corpus. Could help generation quality a lot.'
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
# Testing
|
| 262 |
+
doc.add_heading('Testing Plan', level=3)
|
| 263 |
+
|
| 264 |
+
doc.add_paragraph('I already have 8 unit tests written and passing:')
|
| 265 |
+
|
| 266 |
+
tests = [
|
| 267 |
+
('test_olmo_causal_lm_pretraining', 'Trains with causal LM objective, loss is positive'),
|
| 268 |
+
('test_olmo_regression_finetuning', 'Regression training, predictions have correct shape, MAE score works'),
|
| 269 |
+
('test_olmo_classification', 'Binary classification on random labels'),
|
| 270 |
+
('test_olmo_multitask_regression', '2-task regression, output shape matches'),
|
| 271 |
+
('test_olmo_save_and_restore', 'Save checkpoint, load into new model, verify all weights match'),
|
| 272 |
+
('test_olmo_load_from_pretrained', 'Pretrain causal LM then load weights into regression model'),
|
| 273 |
+
('test_olmo_generate', 'Single string and batch generation, returns valid strings'),
|
| 274 |
+
('test_olmo_invalid_task', 'Bad task name raises ValueError'),
|
| 275 |
+
]
|
| 276 |
+
for name, desc in tests:
|
| 277 |
+
doc.add_paragraph(f'{name} — {desc}', style='List Bullet')
|
| 278 |
+
|
| 279 |
+
doc.add_paragraph(
|
| 280 |
+
'All tests use a tiny config (64 hidden, 2 layers, 2 heads). No model downloads. '
|
| 281 |
+
'Runs in 27 seconds on CPU.'
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
doc.add_paragraph('Integration tests to add during GSoC:')
|
| 285 |
+
doc.add_paragraph('MoleculeNet benchmarks — BBBP, ESOL, FreeSolv, Lipophilicity', style='List Bullet')
|
| 286 |
+
doc.add_paragraph('SMILES generation validity checking with RDKit', style='List Bullet')
|
| 287 |
+
doc.add_paragraph('Pretraining convergence curves on ZINC subsets', style='List Bullet')
|
| 288 |
+
|
| 289 |
+
# Risks
|
| 290 |
+
doc.add_heading('Sources of Risk', level=3)
|
| 291 |
+
|
| 292 |
+
doc.add_paragraph().add_run('Generation quality:').bold = True
|
| 293 |
+
doc.add_paragraph(
|
| 294 |
+
'This is the biggest one. 0% validity from random init is expected but getting to '
|
| 295 |
+
'something useful like 50%+ valid SMILES needs real pretraining on a molecular corpus. '
|
| 296 |
+
'Im allocating 3 full weeks for pretraining experiments using ZINC-250K as training data. '
|
| 297 |
+
'If ZINC isnt enough, backup plan is PubChem subsets or combining multiple datasets.'
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
doc.add_paragraph().add_run('GPU memory for 7B model:').bold = True
|
| 301 |
+
doc.add_paragraph(
|
| 302 |
+
'OLMo-7B needs ~14GB VRAM just for inference. '
|
| 303 |
+
'CI tests use tiny configs so no GPU needed. Demos use OLMo-1B. '
|
| 304 |
+
'7B experiments may need multi-GPU, will figure out with mentor.'
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
doc.add_paragraph().add_run('Olmo2ForSequenceClassification not upstream:').bold = True
|
| 308 |
+
doc.add_paragraph(
|
| 309 |
+
'I wrote this class myself since HuggingFace doesnt have it. '
|
| 310 |
+
'If they add one later we just swap ours out. Follows their patterns exactly so low risk.'
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
doc.add_paragraph().add_run('Transformers version compatibility:').bold = True
|
| 314 |
+
doc.add_paragraph(
|
| 315 |
+
'Already found and fixed one issue (PR #4913). '
|
| 316 |
+
'Using top-level imports everywhere going forward. Will test against both 4.x and 5.x.'
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
# Milestones
|
| 320 |
+
doc.add_heading('Milestones', level=3)
|
| 321 |
+
|
| 322 |
+
doc.add_paragraph().add_run('Milestone 1 (end of week 3): Core wrapper working').bold = True
|
| 323 |
+
doc.add_paragraph(
|
| 324 |
+
'from deepchem.models import OLMo works. All task modes functional. '
|
| 325 |
+
'Tests passing. Base class changes merged.'
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
doc.add_paragraph().add_run('Milestone 2 (end of week 6): Pretraining pipeline done').bold = True
|
| 329 |
+
doc.add_paragraph(
|
| 330 |
+
'Can load SMILES data, pretrain OLMo with causal LM objective, save checkpoints. '
|
| 331 |
+
'Tutorial notebook showing how to pretrain on custom data. '
|
| 332 |
+
'First generation results with validity numbers.'
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
doc.add_paragraph().add_run('Milestone 3 (end of week 12): Everything shipped').bold = True
|
| 336 |
+
doc.add_paragraph(
|
| 337 |
+
'MoleculeNet benchmark results published. Generation tutorial with RDKit validation. '
|
| 338 |
+
'Full documentation. All PRs merged.'
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
# Timeline
|
| 342 |
+
doc.add_heading('Timeline', level=3)
|
| 343 |
+
|
| 344 |
+
doc.add_paragraph('12 weeks, Medium size (175 hours):')
|
| 345 |
+
|
| 346 |
+
weeks = [
|
| 347 |
+
('Week 1', 'Set up dev environment, submit base class PR (causal_lm task support, ~40 lines). Start on generate() method.'),
|
| 348 |
+
('Week 2', 'Submit generate() PR (~50 lines). Start Olmo2ForSequenceClassification. Discuss size with mentor since its ~100 lines.'),
|
| 349 |
+
('Week 3', 'Submit OLMo wrapper PR and unit test PR. Get through review. Target: all core PRs merged by end of week.'),
|
| 350 |
+
('Week 4', 'Start pretraining pipeline. Data loading for ZINC-250K. Figure out training hyperparameters on OLMo-1B.'),
|
| 351 |
+
('Week 5', 'Run pretraining experiments. Monitor convergence. First generation attempts — check validity with RDKit.'),
|
| 352 |
+
('Week 6', 'Finish pretraining PR and tutorial notebook. Submit for review. Save best checkpoint.'),
|
| 353 |
+
('Week 7', 'Start finetuning experiments. BBBP and Tox21 classification with pretrained model.'),
|
| 354 |
+
('Week 8', 'ESOL, FreeSolv, Lipophilicity regression. Build benchmark comparison table vs ChemBERTa.'),
|
| 355 |
+
('Week 9', 'Submit benchmark PR and finetuning tutorials. Respond to review feedback.'),
|
| 356 |
+
('Week 10', 'SMILES generation experiments. Validity rate analysis. Different sampling strategies (temperature, top-k, nucleus).'),
|
| 357 |
+
('Week 11', 'Generation tutorial notebook. Documentation — numpydoc for all classes/methods, API reference updates.'),
|
| 358 |
+
('Week 12', 'Final review rounds. Clean up any open PRs. Stretch: tokenizer extension experiments if time allows.'),
|
| 359 |
+
]
|
| 360 |
+
for week, desc in weeks:
|
| 361 |
+
p = doc.add_paragraph()
|
| 362 |
+
p.add_run(f'{week}: ').bold = True
|
| 363 |
+
p.add_run(desc)
|
| 364 |
+
|
| 365 |
+
# Pull Requests
|
| 366 |
+
doc.add_heading('Pull Requests', level=3)
|
| 367 |
+
|
| 368 |
+
doc.add_paragraph(
|
| 369 |
+
'Following DeepChems guidelines for new contributors — small PRs, especially at the start. '
|
| 370 |
+
'Heres the planned breakdown:'
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
prs = [
|
| 374 |
+
('PR 1 (Week 1)', 'HuggingFaceModel causal_lm support', '~40 lines'),
|
| 375 |
+
('PR 2 (Week 2)', 'generate() method', '~50 lines'),
|
| 376 |
+
('PR 3 (Week 2-3)', 'Olmo2ForSequenceClassification', '~100 lines (will discuss with mentor)'),
|
| 377 |
+
('PR 4 (Week 3)', 'OLMo wrapper class', '~80 lines'),
|
| 378 |
+
('PR 5 (Week 3)', 'Unit tests', '~180 lines'),
|
| 379 |
+
('PR 6 (Week 6)', 'Pretraining pipeline + data utils', 'TBD'),
|
| 380 |
+
('PR 7 (Week 6)', 'Pretraining tutorial notebook', 'notebook'),
|
| 381 |
+
('PR 8 (Week 9)', 'Finetuning tutorials (classification + regression)', 'notebooks'),
|
| 382 |
+
('PR 9 (Week 9)', 'Benchmark results', 'TBD'),
|
| 383 |
+
('PR 10 (Week 11)', 'Generation tutorial', 'notebook'),
|
| 384 |
+
('PR 11 (Week 12)', 'Documentation updates', 'TBD'),
|
| 385 |
+
]
|
| 386 |
+
for pr, desc, size in prs:
|
| 387 |
+
p = doc.add_paragraph()
|
| 388 |
+
p.add_run(f'{pr}: ').bold = True
|
| 389 |
+
p.add_run(f'{desc} ({size})')
|
| 390 |
+
|
| 391 |
+
doc.add_paragraph(
|
| 392 |
+
'Each PR goes through review at office hours. Bigger ones might need 2-3 rounds. '
|
| 393 |
+
'I have buffer built into the timeline for this — not assuming everything merges on first try.'
|
| 394 |
+
)
|
| 395 |
+
|
| 396 |
+
# ============ COMMUNITY ============
|
| 397 |
+
doc.add_heading('Community', level=2)
|
| 398 |
+
|
| 399 |
+
doc.add_paragraph('What Ive done so far:')
|
| 400 |
+
doc.add_paragraph('PR #4913 — fixed ChemBERTa import crash for transformers 5.x', style='List Bullet')
|
| 401 |
+
doc.add_paragraph('Issue #4912 — reported broader transformers 5.x compatibility problems', style='List Bullet')
|
| 402 |
+
doc.add_paragraph('Built and tested OLMo prototype locally against DeepChems codebase', style='List Bullet')
|
| 403 |
+
|
| 404 |
+
doc.add_paragraph(
|
| 405 |
+
'I can commit to attending at least 2 office hour sessions per week (MWF 9am PST). '
|
| 406 |
+
'Will also be active on Discord for async discussion and will do weekly progress updates.'
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
doc.add_paragraph(
|
| 410 |
+
'[NOTE: mention which mentors you have talked to once you connect on Discord. '
|
| 411 |
+
'Riya and Harindhar are listed as mentors for this project.]'
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
# ============ RESOURCES ============
|
| 415 |
+
doc.add_heading('Resources Required', level=2)
|
| 416 |
+
|
| 417 |
+
doc.add_paragraph(
|
| 418 |
+
'I have GPU access through Azure cloud which should handle OLMo-1B training and OLMo-7B inference. '
|
| 419 |
+
'For CI and unit tests everything runs on CPU with tiny model configs so no special compute needed there.'
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
doc.add_paragraph(
|
| 423 |
+
'For OLMo-7B full training we might need a multi-GPU setup. '
|
| 424 |
+
'Would be good to discuss with the mentor what compute DeepChem can provide '
|
| 425 |
+
'or if Colab Pro / cloud credits would work. '
|
| 426 |
+
'The pretraining experiments on OLMo-1B should be doable on my current setup.'
|
| 427 |
+
)
|
| 428 |
+
|
| 429 |
+
# ============ REFERENCES ============
|
| 430 |
+
doc.add_heading('References', level=2)
|
| 431 |
+
|
| 432 |
+
refs = [
|
| 433 |
+
'Groeneveld et al. (2024). OLMo: Accelerating the Science of Language Models. arXiv:2402.00838',
|
| 434 |
+
'Chithrananda et al. (2020). ChemBERTa: Large-Scale Self-Supervised Pretraining for Molecular Property Prediction. arXiv:2010.09885',
|
| 435 |
+
'Ross et al. (2022). Large-Scale Chemical Language Representations Capture Molecular Structure and Properties. Nature Machine Intelligence',
|
| 436 |
+
'Weininger (1988). SMILES, a chemical language and information system. J Chem Inf Comput Sci',
|
| 437 |
+
'Wu et al. (2018). MoleculeNet: A Benchmark for Molecular Machine Learning. Chemical Science',
|
| 438 |
+
]
|
| 439 |
+
for i, ref in enumerate(refs, 1):
|
| 440 |
+
doc.add_paragraph(f'{i}. {ref}')
|
| 441 |
+
|
| 442 |
+
# Save
|
| 443 |
+
doc.save('/home/azureuser/GSoC_2026_Proposal_Vivek_OLMo.docx')
|
| 444 |
+
print('Done: ~/GSoC_2026_Proposal_Vivek_OLMo.docx')
|
merge_v3.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# v3: use clean_merged_data (v2) as base, add Khadatkar + Learning to train
|
| 3 |
+
import os, shutil, glob
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
HOME = os.path.expanduser('~')
|
| 7 |
+
OUT = f'{HOME}/merged_v3'
|
| 8 |
+
BASE = f'{HOME}/clean_merged_data'
|
| 9 |
+
|
| 10 |
+
# Fresh output
|
| 11 |
+
if os.path.exists(OUT): shutil.rmtree(OUT)
|
| 12 |
+
for s in ('train','valid','test'):
|
| 13 |
+
os.makedirs(f'{OUT}/{s}/images', exist_ok=True)
|
| 14 |
+
os.makedirs(f'{OUT}/{s}/labels', exist_ok=True)
|
| 15 |
+
|
| 16 |
+
stats = {s: {c: 0 for c in (0,1,2)} for s in ('train','valid','test')}
|
| 17 |
+
imgcount = {s: 0 for s in ('train','valid','test')}
|
| 18 |
+
|
| 19 |
+
def copy_split(src_img_dir, src_lbl_dir, target_split, cmap, name_suffix):
|
| 20 |
+
n = 0
|
| 21 |
+
for lbl_path in glob.glob(f'{src_lbl_dir}/*.txt'):
|
| 22 |
+
stem = Path(lbl_path).stem
|
| 23 |
+
img_path = None
|
| 24 |
+
for ext in ('.jpg','.jpeg','.png','.JPG','.PNG'):
|
| 25 |
+
p = f'{src_img_dir}/{stem}{ext}'
|
| 26 |
+
if os.path.exists(p): img_path = p; break
|
| 27 |
+
if img_path is None: continue
|
| 28 |
+
lines = []
|
| 29 |
+
with open(lbl_path) as f:
|
| 30 |
+
for line in f:
|
| 31 |
+
parts = line.strip().split()
|
| 32 |
+
if not parts: continue
|
| 33 |
+
cid = int(parts[0])
|
| 34 |
+
if cid not in cmap: continue
|
| 35 |
+
lines.append(' '.join([str(cmap[cid])] + parts[1:]))
|
| 36 |
+
if not lines: continue
|
| 37 |
+
new_stem = f'{stem}{name_suffix}'
|
| 38 |
+
ext = Path(img_path).suffix
|
| 39 |
+
dst_img = f'{OUT}/{target_split}/images/{new_stem}{ext}'
|
| 40 |
+
dst_lbl = f'{OUT}/{target_split}/labels/{new_stem}.txt'
|
| 41 |
+
if not os.path.exists(dst_img):
|
| 42 |
+
try: os.link(img_path, dst_img)
|
| 43 |
+
except: shutil.copy(img_path, dst_img)
|
| 44 |
+
with open(dst_lbl, 'w') as f:
|
| 45 |
+
f.write('\n'.join(lines) + '\n')
|
| 46 |
+
for ln in lines:
|
| 47 |
+
stats[target_split][int(ln.split()[0])] += 1
|
| 48 |
+
imgcount[target_split] += 1
|
| 49 |
+
n += 1
|
| 50 |
+
return n
|
| 51 |
+
|
| 52 |
+
# 1) Copy clean_merged_data AS-IS (identity mapping for 0,1,2), no extra suffix
|
| 53 |
+
# Images already have _cctv_dataset / _helmet_dataset / _yolo_project suffixes
|
| 54 |
+
print('--- base v2 data ---')
|
| 55 |
+
for s in ('train','valid','test'):
|
| 56 |
+
n = copy_split(f'{BASE}/{s}/images', f'{BASE}/{s}/labels', s, {0:0,1:1,2:2}, '')
|
| 57 |
+
print(f' base -> {s}: {n}')
|
| 58 |
+
|
| 59 |
+
# 2) Add Khadatkar + Learning ONLY to train split
|
| 60 |
+
EXTRAS = [
|
| 61 |
+
('khadatkar', f'{HOME}/extra_khadatkar', {0:1, 1:0}), # 0=With Helmet->1, 1=Without Helmet->0, drop 2=licence
|
| 62 |
+
('learning', f'{HOME}/extra_learning', {0:1, 1:0}), # 0=With Helmet->1, 1=Without Helmet->0
|
| 63 |
+
]
|
| 64 |
+
print('--- extras -> train ---')
|
| 65 |
+
for name, root, cmap in EXTRAS:
|
| 66 |
+
for src_split in ('train','valid','test'):
|
| 67 |
+
img_dir = f'{root}/{src_split}/images'
|
| 68 |
+
lbl_dir = f'{root}/{src_split}/labels'
|
| 69 |
+
if not os.path.isdir(lbl_dir): continue
|
| 70 |
+
n = copy_split(img_dir, lbl_dir, 'train', cmap, f'_{name}_{src_split}')
|
| 71 |
+
print(f' {name} {src_split} -> train: {n}')
|
| 72 |
+
|
| 73 |
+
yaml = f'''path: {OUT}
|
| 74 |
+
train: train/images
|
| 75 |
+
val: valid/images
|
| 76 |
+
test: test/images
|
| 77 |
+
nc: 3
|
| 78 |
+
names:
|
| 79 |
+
0: no-helmet
|
| 80 |
+
1: with-helmet
|
| 81 |
+
2: triple-riding
|
| 82 |
+
'''
|
| 83 |
+
with open(f'{OUT}/data.yaml','w') as f: f.write(yaml)
|
| 84 |
+
|
| 85 |
+
print('\n=== V3 MERGE COMPLETE ===')
|
| 86 |
+
for s in ('train','valid','test'):
|
| 87 |
+
tot = sum(stats[s].values())
|
| 88 |
+
print(f' {s:6s} images={imgcount[s]:5d} | no-helmet={stats[s][0]:5d} with-helmet={stats[s][1]:5d} triple={stats[s][2]:4d} | instances={tot}')
|
pull_extra.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Pull 3 extra helmet/rider datasets from Roboflow Universe
|
| 3 |
+
from roboflow import Roboflow
|
| 4 |
+
|
| 5 |
+
# API key reused from previous script
|
| 6 |
+
rf = Roboflow(api_key='qeQs9chVa3kU0XnpTZsd')
|
| 7 |
+
|
| 8 |
+
targets = [
|
| 9 |
+
('gw-khadatkar-and-sv-wasule', 'helmet-and-no-helmet-rider-detection', '~/extra_khadatkar'),
|
| 10 |
+
('nckh-2023', 'helmet-detection-project', '~/extra_nckh'),
|
| 11 |
+
('learning-evidence', 'helmet-detection_yolov8', '~/extra_learning'),
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
for ws, proj, loc in targets:
|
| 15 |
+
print(f'\n=== {ws}/{proj} ===')
|
| 16 |
+
try:
|
| 17 |
+
p = rf.workspace(ws).project(proj)
|
| 18 |
+
# try latest version
|
| 19 |
+
vs = p.versions()
|
| 20 |
+
if not vs:
|
| 21 |
+
print(' NO VERSIONS'); continue
|
| 22 |
+
v = vs[0]
|
| 23 |
+
print(f' version={v.version}')
|
| 24 |
+
ds = v.download('yolov8', location=loc)
|
| 25 |
+
print(f' DOWNLOADED to {loc}')
|
| 26 |
+
except Exception as e:
|
| 27 |
+
print(f' FAIL: {e}')
|
telugu_voice_clone.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Telugu Voice Cloning with IndicF5
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
1. Place your reference audio as 'reference.wav' (10-15 seconds, clean Telugu speech)
|
| 6 |
+
2. Edit REF_TEXT with the exact Telugu transcript of your reference audio
|
| 7 |
+
3. Edit GEN_TEXT with the Telugu text you want to generate
|
| 8 |
+
4. Run: source ~/indicf5-env/bin/activate && python telugu_voice_clone.py
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import numpy as np
|
| 13 |
+
import soundfile as sf
|
| 14 |
+
import io
|
| 15 |
+
import time
|
| 16 |
+
from pydub import AudioSegment, silence
|
| 17 |
+
from huggingface_hub import hf_hub_download
|
| 18 |
+
from f5_tts.infer.utils_infer import (
|
| 19 |
+
infer_process,
|
| 20 |
+
load_model,
|
| 21 |
+
load_vocoder,
|
| 22 |
+
preprocess_ref_audio_text,
|
| 23 |
+
)
|
| 24 |
+
from f5_tts.model import DiT
|
| 25 |
+
|
| 26 |
+
# === CONFIGURE THESE ===
|
| 27 |
+
|
| 28 |
+
# Path to your reference voice recording (WAV, 10-15 seconds, Telugu)
|
| 29 |
+
REF_AUDIO = "reference.wav"
|
| 30 |
+
|
| 31 |
+
# Exact Telugu transcript of your reference audio
|
| 32 |
+
REF_TEXT = "ఇది నా గొంతు నమూనా, నేను తెలుగులో మాట్లాడుతున్నాను."
|
| 33 |
+
|
| 34 |
+
# Telugu text you want to generate in your cloned voice
|
| 35 |
+
GEN_TEXT = "నమస్కారం, మీరు ఎలా ఉన్నారు? నేను మీతో తెలుగులో మాట్లాడుతున్నాను."
|
| 36 |
+
|
| 37 |
+
# Output file
|
| 38 |
+
OUTPUT_FILE = "output_telugu.wav"
|
| 39 |
+
|
| 40 |
+
SPEED = 1.0
|
| 41 |
+
REMOVE_SILENCE = True
|
| 42 |
+
|
| 43 |
+
# === END CONFIG ===
|
| 44 |
+
|
| 45 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 46 |
+
print(f"Using device: {device}")
|
| 47 |
+
|
| 48 |
+
# Load vocoder
|
| 49 |
+
print("Loading vocoder...")
|
| 50 |
+
vocoder = load_vocoder(vocoder_name="vocos", is_local=False, device=device)
|
| 51 |
+
|
| 52 |
+
# Download vocab and load model
|
| 53 |
+
print("Downloading IndicF5 model...")
|
| 54 |
+
repo_id = "ai4bharat/IndicF5"
|
| 55 |
+
vocab_path = hf_hub_download(repo_id, filename="checkpoints/vocab.txt")
|
| 56 |
+
|
| 57 |
+
ema_model = load_model(
|
| 58 |
+
DiT,
|
| 59 |
+
dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
|
| 60 |
+
mel_spec_type="vocos",
|
| 61 |
+
vocab_file=vocab_path,
|
| 62 |
+
device=device,
|
| 63 |
+
)
|
| 64 |
+
print("Model loaded!")
|
| 65 |
+
|
| 66 |
+
# Preprocess reference audio
|
| 67 |
+
print(f"Reference audio: {REF_AUDIO}")
|
| 68 |
+
ref_audio, ref_text = preprocess_ref_audio_text(REF_AUDIO, REF_TEXT)
|
| 69 |
+
|
| 70 |
+
# Generate
|
| 71 |
+
print(f"Generating: {GEN_TEXT[:80]}...")
|
| 72 |
+
start = time.time()
|
| 73 |
+
audio, final_sample_rate, _ = infer_process(
|
| 74 |
+
ref_audio,
|
| 75 |
+
ref_text,
|
| 76 |
+
GEN_TEXT,
|
| 77 |
+
ema_model,
|
| 78 |
+
vocoder,
|
| 79 |
+
mel_spec_type="vocos",
|
| 80 |
+
speed=SPEED,
|
| 81 |
+
device=device,
|
| 82 |
+
)
|
| 83 |
+
print(f"Generated in {time.time() - start:.1f}s")
|
| 84 |
+
|
| 85 |
+
# Post-process: remove silence and normalize
|
| 86 |
+
buffer = io.BytesIO()
|
| 87 |
+
sf.write(buffer, audio, samplerate=24000, format="WAV")
|
| 88 |
+
buffer.seek(0)
|
| 89 |
+
audio_segment = AudioSegment.from_file(buffer, format="wav")
|
| 90 |
+
|
| 91 |
+
if REMOVE_SILENCE:
|
| 92 |
+
non_silent_segs = silence.split_on_silence(
|
| 93 |
+
audio_segment,
|
| 94 |
+
min_silence_len=1000,
|
| 95 |
+
silence_thresh=-50,
|
| 96 |
+
keep_silence=500,
|
| 97 |
+
seek_step=10,
|
| 98 |
+
)
|
| 99 |
+
if non_silent_segs:
|
| 100 |
+
audio_segment = sum(non_silent_segs, AudioSegment.silent(duration=0))
|
| 101 |
+
|
| 102 |
+
# Normalize loudness
|
| 103 |
+
target_dBFS = -20.0
|
| 104 |
+
change_in_dBFS = target_dBFS - audio_segment.dBFS
|
| 105 |
+
audio_segment = audio_segment.apply_gain(change_in_dBFS)
|
| 106 |
+
|
| 107 |
+
# Save
|
| 108 |
+
final_audio = np.array(audio_segment.get_array_of_samples())
|
| 109 |
+
if final_audio.dtype == np.int16:
|
| 110 |
+
final_audio = final_audio.astype(np.float32) / 32768.0
|
| 111 |
+
sf.write(OUTPUT_FILE, final_audio.astype(np.float32), samplerate=24000)
|
| 112 |
+
print(f"Saved to {OUTPUT_FILE}")
|
| 113 |
+
print("Done!")
|
train_h100_clean.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
train_h100_clean.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# 3-class clean training on H100 NVL
|
| 3 |
+
# Classes: 0 no-helmet | 1 with-helmet | 2 triple-riding
|
| 4 |
+
from ultralytics import YOLO
|
| 5 |
+
import torch, os
|
| 6 |
+
|
| 7 |
+
print('GPU:', torch.cuda.get_device_name(0), '|', torch.cuda.get_device_properties(0).total_memory/1e9, 'GB')
|
| 8 |
+
|
| 9 |
+
# Start from pretrained yolo26m (auto-downloads if missing)
|
| 10 |
+
model = YOLO('yolo26m.pt')
|
| 11 |
+
|
| 12 |
+
results = model.train(
|
| 13 |
+
data='/home/azureuser/clean_merged_data/data.yaml',
|
| 14 |
+
epochs=150,
|
| 15 |
+
imgsz=640,
|
| 16 |
+
batch=64, # H100 NVL has 95GB, can push batch high
|
| 17 |
+
device=0,
|
| 18 |
+
workers=8,
|
| 19 |
+
project='runs_clean',
|
| 20 |
+
name='h100_3class',
|
| 21 |
+
exist_ok=True,
|
| 22 |
+
amp=True,
|
| 23 |
+
cos_lr=True,
|
| 24 |
+
close_mosaic=15,
|
| 25 |
+
# augmentation — important for 10k image dataset
|
| 26 |
+
mosaic=1.0,
|
| 27 |
+
mixup=0.15,
|
| 28 |
+
copy_paste=0.3, # boost with-helmet via cross-image pasting
|
| 29 |
+
hsv_h=0.015, hsv_s=0.7, hsv_v=0.4,
|
| 30 |
+
degrees=5.0,
|
| 31 |
+
translate=0.1,
|
| 32 |
+
scale=0.5,
|
| 33 |
+
fliplr=0.5,
|
| 34 |
+
# loss
|
| 35 |
+
cls=1.0, # classification loss weight (bump if still confused)
|
| 36 |
+
box=7.5,
|
| 37 |
+
dfl=1.5,
|
| 38 |
+
# regularization
|
| 39 |
+
weight_decay=0.0005,
|
| 40 |
+
dropout=0.0,
|
| 41 |
+
# schedule
|
| 42 |
+
optimizer='auto',
|
| 43 |
+
lr0=0.01,
|
| 44 |
+
patience=40,
|
| 45 |
+
plots=True,
|
| 46 |
+
verbose=True,
|
| 47 |
+
)
|
| 48 |
+
print('TRAIN DONE — running val on test split')
|
| 49 |
+
m = YOLO('runs_clean/h100_3class/weights/best.pt')
|
| 50 |
+
m.val(data='/home/azureuser/clean_merged_data/data.yaml', split='test', plots=True, save_json=True)
|
train_h100_final.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Final Training on H100 - 96GB VRAM Beast!
|
| 4 |
+
Merges ALL datasets and trains with maximum performance
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from roboflow import Roboflow
|
| 8 |
+
from ultralytics import YOLO
|
| 9 |
+
import torch
|
| 10 |
+
import os
|
| 11 |
+
import shutil
|
| 12 |
+
import yaml
|
| 13 |
+
import glob
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
print("=" * 70)
|
| 17 |
+
print("FINAL TRAINING ON H100 - BALANCED DATASET")
|
| 18 |
+
print("=" * 70)
|
| 19 |
+
|
| 20 |
+
# Check GPU
|
| 21 |
+
print(f"\nGPU Available: {torch.cuda.is_available()}")
|
| 22 |
+
if torch.cuda.is_available():
|
| 23 |
+
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
| 24 |
+
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.0f} GB")
|
| 25 |
+
|
| 26 |
+
# Step 1: Download all datasets from Roboflow
|
| 27 |
+
print("\n" + "=" * 70)
|
| 28 |
+
print("STEP 1: Downloading Datasets from Roboflow")
|
| 29 |
+
print("=" * 70)
|
| 30 |
+
|
| 31 |
+
rf = Roboflow(api_key="cMpZOr1EizWFVrJ0Au4o")
|
| 32 |
+
|
| 33 |
+
# Dataset 1: New 212 helmet images
|
| 34 |
+
print("\nDataset 1: New helmet images (212)...")
|
| 35 |
+
project1 = rf.workspace("team11s-workspace-man05").project("helmet-detection-ihomd")
|
| 36 |
+
ds1 = project1.version(1).download("yolov8", location="~/helmet_212")
|
| 37 |
+
|
| 38 |
+
# Dataset 2: Old no-helmet (499) from first account
|
| 39 |
+
print("\nDataset 2: No-helmet images (499)...")
|
| 40 |
+
rf2 = Roboflow(api_key="qeQs9chVa3kU0XnpTZsd")
|
| 41 |
+
project2 = rf2.workspace("nyc-nleyq").project("indian-cctv-traffic-violations")
|
| 42 |
+
ds2 = project2.version(1).download("yolov8", location="~/no_helmet_499")
|
| 43 |
+
|
| 44 |
+
# Dataset 3: With-helmet (300) from second account
|
| 45 |
+
print("\nDataset 3: With-helmet images (300)...")
|
| 46 |
+
project3 = rf2.workspace("vivekvarikuti").project("withhelmet")
|
| 47 |
+
ds3 = project3.version(1).download("yolov8", location="~/with_helmet_300")
|
| 48 |
+
|
| 49 |
+
# Dataset 4: Triple-riding from original (626)
|
| 50 |
+
print("\nDataset 4: Triple-riding (626)...")
|
| 51 |
+
project4 = rf2.workspace("triple-ride-rsysj").project("triple-riding-detection-pniom")
|
| 52 |
+
ds4 = project4.version(1).download("yolov8", location="~/triple_riding_626")
|
| 53 |
+
|
| 54 |
+
print("\n✅ All datasets downloaded!")
|
| 55 |
+
|
| 56 |
+
# Step 2: Merge all datasets
|
| 57 |
+
print("\n" + "=" * 70)
|
| 58 |
+
print("STEP 2: Merging ALL Datasets")
|
| 59 |
+
print("=" * 70)
|
| 60 |
+
|
| 61 |
+
MERGED_DIR = os.path.expanduser("~/final_merged_h100")
|
| 62 |
+
|
| 63 |
+
for split in ['train', 'valid', 'test']:
|
| 64 |
+
os.makedirs(f"{MERGED_DIR}/{split}/images", exist_ok=True)
|
| 65 |
+
os.makedirs(f"{MERGED_DIR}/{split}/labels", exist_ok=True)
|
| 66 |
+
|
| 67 |
+
# Collect all classes
|
| 68 |
+
all_classes = set()
|
| 69 |
+
datasets = [
|
| 70 |
+
(ds1.location, 'helmet212'),
|
| 71 |
+
(ds2.location, 'nohelmet499'),
|
| 72 |
+
(ds3.location, 'withhelmet300'),
|
| 73 |
+
(ds4.location, 'triple626')
|
| 74 |
+
]
|
| 75 |
+
|
| 76 |
+
class_configs = {}
|
| 77 |
+
for ds_path, ds_name in datasets:
|
| 78 |
+
yaml_path = f"{ds_path}/data.yaml"
|
| 79 |
+
if os.path.exists(yaml_path):
|
| 80 |
+
with open(yaml_path, 'r') as f:
|
| 81 |
+
cfg = yaml.safe_load(f)
|
| 82 |
+
class_configs[ds_name] = cfg
|
| 83 |
+
if 'names' in cfg:
|
| 84 |
+
all_classes.update(cfg['names'])
|
| 85 |
+
|
| 86 |
+
unified_classes = sorted(list(all_classes))
|
| 87 |
+
print(f"\nUnified classes ({len(unified_classes)}): {unified_classes}")
|
| 88 |
+
|
| 89 |
+
# Create class mappings
|
| 90 |
+
class_maps = {}
|
| 91 |
+
for ds_name, cfg in class_configs.items():
|
| 92 |
+
class_maps[ds_name] = {}
|
| 93 |
+
if 'names' in cfg:
|
| 94 |
+
for i, cls in enumerate(cfg['names']):
|
| 95 |
+
class_maps[ds_name][i] = unified_classes.index(cls)
|
| 96 |
+
|
| 97 |
+
# Copy and merge datasets
|
| 98 |
+
def copy_with_remap(src_dir, prefix, class_mapping):
|
| 99 |
+
total = 0
|
| 100 |
+
for split in ['train', 'valid', 'test']:
|
| 101 |
+
src_img = f"{src_dir}/{split}/images"
|
| 102 |
+
src_lbl = f"{src_dir}/{split}/labels"
|
| 103 |
+
|
| 104 |
+
if not os.path.exists(src_img):
|
| 105 |
+
continue
|
| 106 |
+
|
| 107 |
+
imgs = glob.glob(f"{src_img}/*.jpg") + glob.glob(f"{src_img}/*.png")
|
| 108 |
+
|
| 109 |
+
for img_path in imgs:
|
| 110 |
+
img_name = os.path.basename(img_path)
|
| 111 |
+
lbl_name = Path(img_path).stem + '.txt'
|
| 112 |
+
lbl_path = f"{src_lbl}/{lbl_name}"
|
| 113 |
+
|
| 114 |
+
# Copy image with prefix
|
| 115 |
+
dst_img = f"{MERGED_DIR}/{split}/images/{prefix}_{img_name}"
|
| 116 |
+
shutil.copy2(img_path, dst_img)
|
| 117 |
+
|
| 118 |
+
# Remap and copy label
|
| 119 |
+
if os.path.exists(lbl_path):
|
| 120 |
+
with open(lbl_path, 'r') as f:
|
| 121 |
+
lines = f.readlines()
|
| 122 |
+
|
| 123 |
+
remapped = []
|
| 124 |
+
for line in lines:
|
| 125 |
+
parts = line.strip().split()
|
| 126 |
+
if len(parts) >= 5:
|
| 127 |
+
old_cls = int(parts[0])
|
| 128 |
+
new_cls = class_mapping.get(old_cls, old_cls)
|
| 129 |
+
remapped.append(f"{new_cls} {' '.join(parts[1:])}\n")
|
| 130 |
+
|
| 131 |
+
if remapped:
|
| 132 |
+
dst_lbl = f"{MERGED_DIR}/{split}/labels/{prefix}_{lbl_name}"
|
| 133 |
+
with open(dst_lbl, 'w') as f:
|
| 134 |
+
f.writelines(remapped)
|
| 135 |
+
total += 1
|
| 136 |
+
|
| 137 |
+
return total
|
| 138 |
+
|
| 139 |
+
print("\nCopying datasets...")
|
| 140 |
+
for (ds_path, ds_name), prefix in zip(datasets, ['h212', 'nh499', 'wh300', 'tr626']):
|
| 141 |
+
count = copy_with_remap(ds_path, prefix, class_maps.get(ds_name, {}))
|
| 142 |
+
print(f" {ds_name}: {count} images")
|
| 143 |
+
|
| 144 |
+
# Count final
|
| 145 |
+
print("\nFinal merged dataset:")
|
| 146 |
+
for split in ['train', 'valid', 'test']:
|
| 147 |
+
imgs = glob.glob(f"{MERGED_DIR}/{split}/images/*")
|
| 148 |
+
print(f" {split}: {len(imgs)} images")
|
| 149 |
+
|
| 150 |
+
# Create YAML
|
| 151 |
+
merged_yaml = {
|
| 152 |
+
'path': MERGED_DIR,
|
| 153 |
+
'train': 'train/images',
|
| 154 |
+
'val': 'valid/images',
|
| 155 |
+
'test': 'test/images',
|
| 156 |
+
'nc': len(unified_classes),
|
| 157 |
+
'names': unified_classes
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
yaml_path = f"{MERGED_DIR}/data.yaml"
|
| 161 |
+
with open(yaml_path, 'w') as f:
|
| 162 |
+
yaml.dump(merged_yaml, f, default_flow_style=False)
|
| 163 |
+
|
| 164 |
+
print(f"\nConfig saved: {yaml_path}")
|
| 165 |
+
|
| 166 |
+
# Step 3: Train on H100 with OPTIMIZED settings
|
| 167 |
+
print("\n" + "=" * 70)
|
| 168 |
+
print("STEP 3: TRAINING ON H100 (96GB VRAM!)")
|
| 169 |
+
print("=" * 70)
|
| 170 |
+
|
| 171 |
+
model = YOLO('yolo26m.pt')
|
| 172 |
+
|
| 173 |
+
print(f"\nTraining config:")
|
| 174 |
+
print(f" Model: YOLO26m")
|
| 175 |
+
print(f" Epochs: 150 (faster with H100)")
|
| 176 |
+
print(f" Batch: -1 (auto - H100 can handle 64-128!)")
|
| 177 |
+
print(f" Image size: 640")
|
| 178 |
+
print(f" Classes: {len(unified_classes)}")
|
| 179 |
+
|
| 180 |
+
print("\nStarting training...")
|
| 181 |
+
|
| 182 |
+
results = model.train(
|
| 183 |
+
data=yaml_path,
|
| 184 |
+
epochs=150, # Fewer epochs needed with large batch on H100
|
| 185 |
+
imgsz=640,
|
| 186 |
+
batch=-1, # Auto batch (H100 will use 64-128!)
|
| 187 |
+
cache='ram', # H100 has tons of RAM
|
| 188 |
+
device=0,
|
| 189 |
+
workers=8,
|
| 190 |
+
patience=30,
|
| 191 |
+
name='h100_final',
|
| 192 |
+
project='outputs',
|
| 193 |
+
|
| 194 |
+
# Augmentation
|
| 195 |
+
hsv_h=0.015,
|
| 196 |
+
hsv_s=0.7,
|
| 197 |
+
hsv_v=0.4,
|
| 198 |
+
degrees=10,
|
| 199 |
+
translate=0.1,
|
| 200 |
+
scale=0.5,
|
| 201 |
+
fliplr=0.5,
|
| 202 |
+
mosaic=1.0,
|
| 203 |
+
mixup=0.1,
|
| 204 |
+
|
| 205 |
+
lr0=0.01,
|
| 206 |
+
lrf=0.01,
|
| 207 |
+
amp=True,
|
| 208 |
+
val=True,
|
| 209 |
+
plots=True,
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
print("\n" + "=" * 70)
|
| 213 |
+
print("TRAINING COMPLETE!")
|
| 214 |
+
print("=" * 70)
|
| 215 |
+
|
| 216 |
+
# Validate
|
| 217 |
+
metrics = model.val()
|
| 218 |
+
print(f"\nFinal Metrics:")
|
| 219 |
+
print(f" mAP50: {metrics.box.map50:.4f} ({metrics.box.map50*100:.1f}%)")
|
| 220 |
+
print(f" mAP50-95: {metrics.box.map:.4f} ({metrics.box.map*100:.1f}%)")
|
| 221 |
+
print(f" Precision: {metrics.box.mp:.4f} ({metrics.box.mp*100:.1f}%)")
|
| 222 |
+
print(f" Recall: {metrics.box.mr:.4f} ({metrics.box.mr*100:.1f}%)")
|
| 223 |
+
|
| 224 |
+
# Export
|
| 225 |
+
print("\nExporting to ONNX...")
|
| 226 |
+
model.export(format='onnx', dynamic=True, simplify=True)
|
| 227 |
+
|
| 228 |
+
print("\n" + "=" * 70)
|
| 229 |
+
print("Model saved: outputs/h100_final/weights/best.pt")
|
| 230 |
+
print("=" * 70)
|
train_v3.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
train_v3.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# v3 max-throughput: yolo26m on 6673-img dataset, batch=128, cache=ram
|
| 3 |
+
from ultralytics import YOLO
|
| 4 |
+
import torch
|
| 5 |
+
print('GPU:', torch.cuda.get_device_name(0), '|', round(torch.cuda.get_device_properties(0).total_memory/1e9), 'GB')
|
| 6 |
+
|
| 7 |
+
model = YOLO('yolo26m.pt')
|
| 8 |
+
model.train(
|
| 9 |
+
data='/home/azureuser/merged_v3/data.yaml',
|
| 10 |
+
epochs=200,
|
| 11 |
+
imgsz=640,
|
| 12 |
+
batch=128, # 2x v2 — should hit ~70GB VRAM
|
| 13 |
+
device=0,
|
| 14 |
+
workers=16, # feed data faster
|
| 15 |
+
cache='ram', # dataset is ~1GB, fits easily
|
| 16 |
+
project='runs_v3',
|
| 17 |
+
name='h100_3class_v3',
|
| 18 |
+
exist_ok=True,
|
| 19 |
+
amp=True,
|
| 20 |
+
cos_lr=True,
|
| 21 |
+
close_mosaic=20,
|
| 22 |
+
mosaic=1.0, mixup=0.15, copy_paste=0.3,
|
| 23 |
+
hsv_h=0.015, hsv_s=0.7, hsv_v=0.4,
|
| 24 |
+
degrees=5.0, translate=0.1, scale=0.5, fliplr=0.5,
|
| 25 |
+
cls=1.0, box=7.5, dfl=1.5,
|
| 26 |
+
weight_decay=0.0005,
|
| 27 |
+
optimizer='auto',
|
| 28 |
+
patience=60,
|
| 29 |
+
plots=True, verbose=True,
|
| 30 |
+
)
|
| 31 |
+
print('TRAIN DONE — running val + test')
|
| 32 |
+
m = YOLO('runs_v3/h100_3class_v3/weights/best.pt')
|
| 33 |
+
print('--- VAL ---'); m.val(data='/home/azureuser/merged_v3/data.yaml', split='val')
|
| 34 |
+
print('--- TEST ---'); m.val(data='/home/azureuser/merged_v3/data.yaml', split='test')
|
| 35 |
+
print('--- TEST + TTA ---'); m.val(data='/home/azureuser/merged_v3/data.yaml', split='test', augment=True)
|
turboquant_case_study.md
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TurboQuant: 44-59% KV-Cache Reduction With Zero Quality Loss
|
| 2 |
+
|
| 3 |
+
## The Problem
|
| 4 |
+
|
| 5 |
+
LLM inference is memory-bound. As context length grows, the KV-cache eats your VRAM alive. At 8K tokens, a Gemma-2-9B model burns nearly 4 GB just on KV-cache. That's memory you can't use for batching more requests, which means fewer concurrent users per GPU, which means higher cost per query.
|
| 6 |
+
|
| 7 |
+
Every production team running LLMs hits this wall.
|
| 8 |
+
|
| 9 |
+
## What TurboQuant Does
|
| 10 |
+
|
| 11 |
+
TurboQuant applies mixed-precision quantization to the KV-cache during inference. It profiles each layer's activation norms, identifies outlier layers that need full precision, and quantizes the rest — cutting KV-cache memory by 44-59% while maintaining exact prefill fidelity.
|
| 12 |
+
|
| 13 |
+
No retraining. No fine-tuning. Drop-in replacement.
|
| 14 |
+
|
| 15 |
+
## Results
|
| 16 |
+
|
| 17 |
+
Benchmarked across 5 model families on NVIDIA H100 NVL (96 GB):
|
| 18 |
+
|
| 19 |
+
### Memory Savings at 8K Context
|
| 20 |
+
|
| 21 |
+
| Model | Default VRAM | TurboQuant VRAM | KV-Cache Saved | Prefill Fidelity |
|
| 22 |
+
|-------|-------------|----------------|----------------|-----------------|
|
| 23 |
+
| Gemma-2-9B | 9.98 GB | 7.71 GB | 2,323 MB (~59%) | Exact |
|
| 24 |
+
| Qwen2.5-32B | 23.16 GB | 21.41 GB | 1,791 MB (~47%) | Exact |
|
| 25 |
+
| Phi-4-14B | 12.28 GB | 10.92 GB | 1,392 MB (~44%) | Exact |
|
| 26 |
+
| LLaMA-3.1-8B | 7.71 GB | 6.84 GB | 890 MB (~44%) | Exact |
|
| 27 |
+
| Qwen2.5-7B | 7.08 GB | 6.71 GB | 380 MB (~44%) | Exact |
|
| 28 |
+
|
| 29 |
+
### Quality Verification
|
| 30 |
+
|
| 31 |
+
- **Prefill logit difference: 0.0 across all models** — the quantized KV-cache produces identical logits at the prefill stage
|
| 32 |
+
- **Same top-1 token prediction: 100%** — no drift in the most likely next token
|
| 33 |
+
- **Output coherence: 100%** — both default and TurboQuant outputs are fully coherent across all test prompts
|
| 34 |
+
- **Token match rate: 18-100%** on generation (expected — autoregressive sampling diverges naturally, but both outputs remain equally valid)
|
| 35 |
+
|
| 36 |
+
### Scaling With Context Length
|
| 37 |
+
|
| 38 |
+
Memory savings grow linearly with context. LLaMA-3.1-8B example:
|
| 39 |
+
|
| 40 |
+
| Context Length | Saved |
|
| 41 |
+
|---------------|-------|
|
| 42 |
+
| 1K tokens | 93 MB |
|
| 43 |
+
| 4K tokens | 417 MB |
|
| 44 |
+
| 8K tokens | 890 MB |
|
| 45 |
+
|
| 46 |
+
At 32K or 128K context (LLaMA-3.1 supports 128K), the savings become massive — potentially 3-14 GB on a single model.
|
| 47 |
+
|
| 48 |
+
### Outlier-Aware Design
|
| 49 |
+
|
| 50 |
+
Not all layers are equal. TurboQuant detects outlier layers with abnormal activation norms and keeps them at full precision:
|
| 51 |
+
|
| 52 |
+
- **Qwen2.5-7B**: layers 0 and 27 flagged as outliers (norms 273.84 and 239.91 vs median 16.86) — kept at BF16
|
| 53 |
+
- **All other models**: uniform norm distributions, all layers quantized
|
| 54 |
+
|
| 55 |
+
This is why quality stays intact — the layers that matter most keep their precision.
|
| 56 |
+
|
| 57 |
+
## What This Means For Production
|
| 58 |
+
|
| 59 |
+
If you're running LLMs in production:
|
| 60 |
+
|
| 61 |
+
- **2-3x more concurrent users** on the same GPU (freed VRAM = larger batch sizes)
|
| 62 |
+
- **Same quality** — your users won't notice any difference
|
| 63 |
+
- **No model changes** — works with any transformer architecture using standard KV-cache
|
| 64 |
+
- **Tested across**: Qwen2, LLaMA, Gemma2, Phi3 architectures
|
| 65 |
+
|
| 66 |
+
## About
|
| 67 |
+
|
| 68 |
+
Built by Vivek Varikuti. I optimize LLM inference for production workloads.
|
| 69 |
+
|
| 70 |
+
If your GPU bill is too high or your throughput is too low, I can help. Free 1-week proof-of-concept on your setup — you pay nothing if it doesn't beat your current numbers.
|
| 71 |
+
|
| 72 |
+
Reach me: domainluther1234@gmail.com | GitHub: vivekvar-dl
|
vivek_complete_profile.md
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Vivek Varikuti — Complete Profile & Project Portfolio
|
| 2 |
+
|
| 3 |
+
## Who I Am
|
| 4 |
+
|
| 5 |
+
- 22 years old, AI Engineer & Startup Founder
|
| 6 |
+
- GitHub: vivekvar-dl
|
| 7 |
+
- Email: domainluther1234@gmail.com
|
| 8 |
+
- Strong Python/PyTorch/LLM skills, deep transformer training experience
|
| 9 |
+
- Hardware: 1x NVIDIA H100 NVL 96GB on Azure (NC40ads H100 v5)
|
| 10 |
+
- CUDA 12.8, PyTorch 2.7.0+cu128, flash-attn 2.8.3 (FA2)
|
| 11 |
+
- Transformers 5.4.0
|
| 12 |
+
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
## Working Style
|
| 16 |
+
|
| 17 |
+
- No AI fluff. No menus of options. Make the decision and execute.
|
| 18 |
+
- Write like a human — no perfect grammar, no emojis, no "leveraging" or "seamless"
|
| 19 |
+
- Any public text must read like a tired developer typed it at 2am
|
| 20 |
+
- No co-authored-by Claude in git commits — public contributions look fully human
|
| 21 |
+
- Verify before claiming. Test before shipping. Always run the actual code.
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
## Project 1: TurboQuant — KV Cache Compression
|
| 26 |
+
|
| 27 |
+
**What:** Implementing Google's TurboQuant paper (arXiv 2504.19874, Zandieh et al.) for KV cache compression during LLM inference.
|
| 28 |
+
|
| 29 |
+
**Why:** Compress KV cache ~4-7x on production LLMs to enable longer contexts and batching on H100 NVL (96GB).
|
| 30 |
+
|
| 31 |
+
**Location:** /home/azureuser/turboquant/
|
| 32 |
+
|
| 33 |
+
**Status:** Working prototype. Google hasn't released their code publicly — this is one of the first working implementations.
|
| 34 |
+
|
| 35 |
+
**Core Method:** Mixed-precision quantization of KV cache. Profile each layer's activation norms, identify outlier layers that need full precision, quantize the rest. No retraining, no fine-tuning — drop-in replacement.
|
| 36 |
+
|
| 37 |
+
**Key Discovery:** Layer 0 (and sometimes last layer) of Qwen models have anomalously large key norms (~16-50x median). These layers must be kept in BF16 (skip_layers). Auto-calibration function detects outlier layers.
|
| 38 |
+
|
| 39 |
+
### Benchmark Results (H100 NVL 96GB)
|
| 40 |
+
|
| 41 |
+
#### Model Architecture Summary
|
| 42 |
+
|
| 43 |
+
| Model | Architecture | KV Heads | head_dim | Outlier Layers | Prefill Fidelity |
|
| 44 |
+
|-------|-------------|----------|---------|----------------|-----------------|
|
| 45 |
+
| Qwen2.5-7B | 28L, qwen2 | 4 | 128 | layers 0, 27 | exact |
|
| 46 |
+
| Llama-3.1-8B | 32L, llama | 8 | 128 | none | exact |
|
| 47 |
+
| Gemma-2-9B | 42L, gemma2 | 8 | 256 | none | exact |
|
| 48 |
+
| Phi-4-14B | 40L, phi3 | 10 | 128 | none | exact |
|
| 49 |
+
| Qwen2.5-32B | 64L, qwen2 | 8 | 128 | none | exact |
|
| 50 |
+
| Llama-3.3-70B | 80L, llama | 8 | 128 | N/A (disk full) | N/A |
|
| 51 |
+
|
| 52 |
+
#### Memory Savings at 8K Context
|
| 53 |
+
|
| 54 |
+
| Model | Default VRAM | TurboQuant VRAM | Saved | KV Cache Reduction |
|
| 55 |
+
|-------|-------------|----------------|-------|-------------------|
|
| 56 |
+
| Gemma-2-9B | 9.98 GB | 7.71 GB | 2,323 MB | ~59% |
|
| 57 |
+
| Qwen2.5-32B | 23.16 GB | 21.41 GB | 1,791 MB | ~47% |
|
| 58 |
+
| Phi-4-14B | 12.28 GB | 10.92 GB | 1,392 MB | ~44% |
|
| 59 |
+
| LLaMA-3.1-8B | 7.71 GB | 6.84 GB | 890 MB | ~44% |
|
| 60 |
+
| Qwen2.5-7B | 7.08 GB | 6.71 GB | 380 MB | ~44% |
|
| 61 |
+
|
| 62 |
+
#### Memory Savings Scaling (LLaMA-3.1-8B)
|
| 63 |
+
|
| 64 |
+
| Context Length | Default VRAM | TurboQuant VRAM | Saved |
|
| 65 |
+
|---------------|-------------|----------------|-------|
|
| 66 |
+
| 1K tokens | 6.00 GB | 5.91 GB | 93 MB |
|
| 67 |
+
| 4K tokens | 6.67 GB | 6.27 GB | 417 MB |
|
| 68 |
+
| 8K tokens | 7.71 GB | 6.84 GB | 890 MB |
|
| 69 |
+
|
| 70 |
+
#### Full Memory Data Per Model
|
| 71 |
+
|
| 72 |
+
**Qwen2.5-7B (5.45 GB model)**
|
| 73 |
+
- Layer norms: median 16.86, max 273.84 (layer 0), ratio 16.24x
|
| 74 |
+
- Outlier layers: 0 (norm 273.84), 27 (norm 239.91)
|
| 75 |
+
- 1K: 5.76→5.73 GB (37 MB saved)
|
| 76 |
+
- 4K: 6.27→6.10 GB (176 MB saved)
|
| 77 |
+
- 8K: 7.08→6.71 GB (380 MB saved)
|
| 78 |
+
|
| 79 |
+
**LLaMA-3.1-8B (5.68 GB model)**
|
| 80 |
+
- Layer norms: median 17.90, max 21.05 (layer 7), ratio 1.18x
|
| 81 |
+
- No outlier layers
|
| 82 |
+
- 1K: 6.00→5.91 GB (93 MB saved, output match)
|
| 83 |
+
- 4K: 6.67→6.27 GB (417 MB saved, output match)
|
| 84 |
+
- 8K: 7.71→6.84 GB (890 MB saved, output match)
|
| 85 |
+
|
| 86 |
+
**Gemma-2-9B (6.08 GB model)**
|
| 87 |
+
- Layer norms: median 17.82, max 21.28 (layer 25), ratio 1.19x
|
| 88 |
+
- No outlier layers
|
| 89 |
+
- 1K: 6.62→6.38 GB (244 MB saved)
|
| 90 |
+
- 4K: 7.96→6.89 GB (1,096 MB saved)
|
| 91 |
+
- 8K: 9.98→7.71 GB (2,323 MB saved)
|
| 92 |
+
|
| 93 |
+
**Phi-4-14B (9.10 GB model)**
|
| 94 |
+
- Layer norms: median 19.21, max 26.46 (layer 0), ratio 1.38x
|
| 95 |
+
- No outlier layers
|
| 96 |
+
- 1K: 9.75→9.61 GB (146 MB saved)
|
| 97 |
+
- 4K: 10.72→10.09 GB (650 MB saved)
|
| 98 |
+
- 8K: 12.28→10.92 GB (1,392 MB saved)
|
| 99 |
+
|
| 100 |
+
**Qwen2.5-32B (19.31 GB model)**
|
| 101 |
+
- Layer norms: median 16.09, max 37.82 (layer 0), ratio 2.35x
|
| 102 |
+
- No outlier layers
|
| 103 |
+
- 1K: 19.97→19.79 GB (186 MB saved)
|
| 104 |
+
- 4K: 21.23→20.42 GB (833 MB saved)
|
| 105 |
+
- 8K: 23.16→21.41 GB (1,791 MB saved)
|
| 106 |
+
|
| 107 |
+
**LLaMA-3.3-70B** — failed with "No space left on device"
|
| 108 |
+
|
| 109 |
+
#### Quality Verification
|
| 110 |
+
|
| 111 |
+
All models tested with 3 prompts: "Explain quantum computing", "Write a Python prime checker", "What causes northern lights?"
|
| 112 |
+
|
| 113 |
+
- Prefill logit difference: 0.0 across ALL models
|
| 114 |
+
- Same top-1 token prediction: 100% across ALL models
|
| 115 |
+
- Output coherence: 100% — both default and TurboQuant outputs fully coherent
|
| 116 |
+
- Token match rate varies (18-100%) due to natural autoregressive sampling divergence — both outputs equally valid
|
| 117 |
+
|
| 118 |
+
**Detailed quality per model:**
|
| 119 |
+
|
| 120 |
+
Qwen2.5-7B: token match 39%, 3%, 54% — both coherent all 3 prompts
|
| 121 |
+
LLaMA-3.1-8B: token match 89.1%, 100%, 100% — 2/3 exact match
|
| 122 |
+
Phi-4-14B: token match 100%, 44%, 100% — 2/3 exact match
|
| 123 |
+
Gemma-2-9B: token match 100%, 100%, 18.8% — 2/3 exact match
|
| 124 |
+
Qwen2.5-32B: token match 71%, 25%, 53% — both coherent all 3 prompts
|
| 125 |
+
|
| 126 |
+
#### Infrastructure Notes
|
| 127 |
+
- Environment: torch 2.7.0+cu128, transformers 5.4.0, H100 NVL CUDA 12.8 (driver 570)
|
| 128 |
+
- PyTorch compiled for CUDA 13.0+ won't work — need cu128 wheel
|
| 129 |
+
- Core quantizer verified (MSE matches paper bounds)
|
| 130 |
+
- Cache integrates with HF Transformers v5.4.0 QuantizedLayer API
|
| 131 |
+
|
| 132 |
+
---
|
| 133 |
+
|
| 134 |
+
## Project 2: Parameter Golf Competition (OpenAI)
|
| 135 |
+
|
| 136 |
+
**What:** OpenAI competition — train the best language model within a 16MB artifact, 10 minutes on 8xH100.
|
| 137 |
+
|
| 138 |
+
**Metric:** Bits-per-byte (BPB) on FineWeb validation (62M tokens sp1024, 45.5M tokens sp4096)
|
| 139 |
+
|
| 140 |
+
**Timeline:** March 18 - April 30, 2026
|
| 141 |
+
|
| 142 |
+
**Current SOTA (merged):** 1.1194 BPP (PR #549, LeakyReLU^2 + TTT + Parallel Muon)
|
| 143 |
+
|
| 144 |
+
### Our Edge: sp4096 Vocabulary
|
| 145 |
+
|
| 146 |
+
- sp4096 tokens_per_byte: 0.3063 vs sp1024: 0.4149 → 26.2% fewer tokens
|
| 147 |
+
- Baseline A/B test (400 steps): sp4096 = 1.6208 BPB vs sp1024 = 1.7144 BPB → -5.5%
|
| 148 |
+
- #1 arch A/B test (400 steps, seed 42): sp4096+factored = 1.8693 BPB vs sp1024 = 2.0067 BPB → -6.8%
|
| 149 |
+
- Extrapolated SOTA: 1.1194 × 0.93 ≈ 1.04-1.06 BPB
|
| 150 |
+
|
| 151 |
+
### Architecture
|
| 152 |
+
|
| 153 |
+
- 11L, 512d, 8H, 4KV, 3x MLP, LeakyReLU(0.5)^2
|
| 154 |
+
- Factored embeddings: tok_emb(4096x256) + embed_up(256→512) + embed_down(512→256)
|
| 155 |
+
- All tricks from #1 submission: XSA, Partial RoPE, LN Scale, SmearGate, BigramHash, EMA, TTT, GPTQ-lite
|
| 156 |
+
|
| 157 |
+
### Key Files
|
| 158 |
+
|
| 159 |
+
- our_submission/train_gpt.py — modified #1 with sp4096 + factored embed + FA2 fallback
|
| 160 |
+
- our_submission/train_gpt_original.py — unmodified #1 with FA2 fallback
|
| 161 |
+
- train_sp4096.py — tokenizer training + data sharding script
|
| 162 |
+
- data/tokenizers/fineweb_4096_bpe.model — trained sp4096 tokenizer
|
| 163 |
+
- data/datasets/fineweb10B_sp4096/ — 80 train shards + 1 val shard
|
| 164 |
+
|
| 165 |
+
### N-gram Cache: CONFIRMED FAKE
|
| 166 |
+
|
| 167 |
+
- 256M bucket experiment: collision-free hash tables give 1.11 BPB (no improvement)
|
| 168 |
+
- All sub-1.0 BPB claims are measurement artifacts from hash collisions
|
| 169 |
+
- Valid Dirichlet smoothing gives at most ~0.002-0.005 genuine improvement
|
| 170 |
+
|
| 171 |
+
### Next Steps
|
| 172 |
+
|
| 173 |
+
1. Medium fidelity run (10min 1xH100)
|
| 174 |
+
2. Int5 MLP quantization (saves ~1.86MB for artifact budget headroom)
|
| 175 |
+
3. Get 8xH100 access for final submission (compute grant or RunPod)
|
| 176 |
+
4. Temperature scaling, document-isolated TTT for extra gains
|
| 177 |
+
|
| 178 |
+
### Hardware
|
| 179 |
+
|
| 180 |
+
- Dev: 1xH100 NVL (Azure NC40ads H100 v5), 96GB VRAM, CUDA 12.8, PyTorch 2.9.1+cu128
|
| 181 |
+
- flash-attn 2.8.3 (FA2, not FA3)
|
| 182 |
+
- Final submission needs 8xH100
|
| 183 |
+
|
| 184 |
+
---
|
| 185 |
+
|
| 186 |
+
## Project 3: GSoC 2026 — DeepChem OLMo Wrapper
|
| 187 |
+
|
| 188 |
+
**What:** Adding OLMo-2 7B LLM support to DeepChem for molecular property prediction and SMILES generation.
|
| 189 |
+
|
| 190 |
+
**Org:** DeepChem (standalone first time in GSoC 2026)
|
| 191 |
+
**Mentors:** Riya, Harindhar
|
| 192 |
+
**Deadline:** March 31, 2026 18:00 UTC (submitted)
|
| 193 |
+
|
| 194 |
+
### What Was Built
|
| 195 |
+
|
| 196 |
+
**PR #4913 (LIVE) — Bug Fix**
|
| 197 |
+
- Fixed ChemBERTa broken import for transformers 5.x
|
| 198 |
+
- `transformers.models.roberta.tokenization_roberta_fast` removed in 5.x
|
| 199 |
+
- 3 additions / 4 deletions
|
| 200 |
+
- https://github.com/deepchem/deepchem/pull/4913
|
| 201 |
+
|
| 202 |
+
**Issue #4912 (LIVE) — Compat Report**
|
| 203 |
+
- Broader transformers 5.x compatibility issues documented
|
| 204 |
+
- https://github.com/deepchem/deepchem/issues/4912
|
| 205 |
+
|
| 206 |
+
**OLMo Wrapper (LOCAL ONLY — not pushed)**
|
| 207 |
+
- Files at ~/olmo_draft/olmo.py and ~/olmo_draft/test_olmo.py
|
| 208 |
+
- Olmo2ForSequenceClassification — built from scratch (doesn't exist in HF)
|
| 209 |
+
- OLMo wrapper class extending HuggingFaceModel
|
| 210 |
+
- Added causal_lm task + generate() to base HuggingFaceModel
|
| 211 |
+
- 8/8 tests pass in 27 seconds on CPU
|
| 212 |
+
- Uses OLMo-2 (allenai/OLMo-2-1124-7B)
|
| 213 |
+
|
| 214 |
+
### Experiments Run
|
| 215 |
+
|
| 216 |
+
- BBBP classification: ROC-AUC 0.67 (random init, 12.9M params, 200 samples)
|
| 217 |
+
- ESOL regression: R² 0.37, MAE 1.27
|
| 218 |
+
- SMILES generation: 0% validity (proves pretraining is core challenge)
|
| 219 |
+
- Tokenization analysis: OLMo 0.9x tokens vs ChemBERTa, but fragments stereocenters
|
| 220 |
+
|
| 221 |
+
### Proposal
|
| 222 |
+
|
| 223 |
+
- ~/gsoc_proposal_final.md — human-written version
|
| 224 |
+
- ~/gsoc_proposal_content.md — raw technical reference
|
| 225 |
+
|
| 226 |
+
### Key Context
|
| 227 |
+
|
| 228 |
+
- PR #4907 by Aditya-ad48 also adds causal LM generation — complementary not competing
|
| 229 |
+
- DeepChem wants small PRs (<50 lines) for new contributors
|
| 230 |
+
- rbharath is the main reviewer/maintainer
|
| 231 |
+
- Office hours MWF 9am PST
|
| 232 |
+
- Discord: https://discord.gg/RYTrUY8Ssn
|
| 233 |
+
|
| 234 |
+
---
|
| 235 |
+
|
| 236 |
+
## Project 4: Genesis — Artificial Life Simulation
|
| 237 |
+
|
| 238 |
+
**What:** Virtual world where blank GRU neural net agents evolve survival behaviors from scratch — foraging, water-seeking, communication — on H100 GPU using JAX.
|
| 239 |
+
|
| 240 |
+
**Location:** /home/azureuser/genesis/ (venv at ~/genesis_env/)
|
| 241 |
+
|
| 242 |
+
### World Setup
|
| 243 |
+
|
| 244 |
+
- 512x512 grid with Perlin noise terrain
|
| 245 |
+
- Food regrowth, water sources, day/night cycles, seasons
|
| 246 |
+
- 1000 agents with GRU brains (~82K params each)
|
| 247 |
+
- Tournament selection + Gaussian mutation (self-adaptive sigma)
|
| 248 |
+
- Agents start with zero knowledge — must learn to survive
|
| 249 |
+
|
| 250 |
+
### Status (2026-04-01)
|
| 251 |
+
|
| 252 |
+
Phases 1-3 complete. 500K step run finished successfully:
|
| 253 |
+
- 86 generations evolved
|
| 254 |
+
- Agents sustain avg age 3,742 steps, energy 0.98, hydration 0.79
|
| 255 |
+
- Signal entropy dropping (4.28→3.58) — indicating early communication structure
|
| 256 |
+
- Simulation runs at ~1000 steps/s on H100 (JAX jit-compiled)
|
| 257 |
+
|
| 258 |
+
### Key Fix
|
| 259 |
+
|
| 260 |
+
food_growth_rate bumped from 0.005→0.02, food_eat_amount 0.05→0.03 to prevent ecological collapse at high generations.
|
| 261 |
+
|
| 262 |
+
### Architecture
|
| 263 |
+
|
| 264 |
+
- World: grid.py, resources.py, environment.py, physics.py, observations.py, spatial.py
|
| 265 |
+
- Agent: body.py (metabolism), brain.py (GRU + vmap batched), actions.py
|
| 266 |
+
- Evolution: fitness.py, selection.py, mutation.py (self-adaptive sigma), population.py
|
| 267 |
+
- Communication: signals.py (8-channel, spatial attenuation, top-4 reception)
|
| 268 |
+
- Analysis: emergence.py (signal entropy, magnitude, R², diversity, clustering)
|
| 269 |
+
- Visualization: renderer.py (dashboard, world map, zoom views)
|
| 270 |
+
|
| 271 |
+
### Run Data
|
| 272 |
+
|
| 273 |
+
~/genesis/runs/run_20260401_111309/ — metrics.csv (500 rows), emergence.csv (100 rows), 50 viz frames, 10 checkpoints + FINAL, config.json
|
| 274 |
+
|
| 275 |
+
### Next Steps
|
| 276 |
+
|
| 277 |
+
- Phase 4: TRIBE v2 integration — compare evolved GRU representations to human brain activity via RSA
|
| 278 |
+
- Phase 5: Scale to 5K+ agents, longer runs for 500+ generations
|
| 279 |
+
- Checkpoints at 50K intervals allow comparing brain representations across evolutionary time
|
| 280 |
+
|
| 281 |
+
---
|
| 282 |
+
|
| 283 |
+
## Project 5: TRIBE v2 — AI-Brain Loop
|
| 284 |
+
|
| 285 |
+
**What:** Closing the AI-brain comparison loop using Meta's TRIBE v2 — comparing AI encoder representations to predicted brain activity to find architectural gaps.
|
| 286 |
+
|
| 287 |
+
**Location:** /home/azureuser/tribev2 (venv at /home/azureuser/tribev2_env)
|
| 288 |
+
|
| 289 |
+
### What's Built
|
| 290 |
+
|
| 291 |
+
- Full analysis script: /home/azureuser/tribev2/close_the_loop_v2.py
|
| 292 |
+
- 8 phases: load model → extract per-layer features → brain parcellation → layer-wise encoding → modality ablation → RSA → divergence mapping → visualization
|
| 293 |
+
- Multimodal stimulus: /home/azureuser/multimodal_stimulus.mp4
|
| 294 |
+
- Results: /home/azureuser/loop_results_v2/
|
| 295 |
+
- Runs with video (V-JEPA2) + audio (Wav2Vec-BERT) + text (LLaMA 3.2-3B)
|
| 296 |
+
|
| 297 |
+
### Status
|
| 298 |
+
|
| 299 |
+
LLaMA 3.2 access granted. Full 3-modality analysis pipeline complete. Brain-guided ViT training attempted 5 times — all failed.
|
| 300 |
+
|
| 301 |
+
### Why Attempts Failed
|
| 302 |
+
|
| 303 |
+
- Never had real brain targets — routed ViT-Small features through TRIBE v2's projector (trained for V-JEPA2), producing random outputs
|
| 304 |
+
- Evaluated on wrong metric (classification accuracy instead of robustness)
|
| 305 |
+
- Literature shows brain-guided training helps ROBUSTNESS (+3-8%), not classification accuracy
|
| 306 |
+
|
| 307 |
+
### What Would Actually Work (from RESEARCH_BRIEF.md)
|
| 308 |
+
|
| 309 |
+
1. Pre-compute real brain targets using TRIBE v2's full pipeline
|
| 310 |
+
2. Train student with classification + per-vertex Pearson correlation brain loss
|
| 311 |
+
3. Evaluate on corruption/adversarial robustness, shape bias, brain-score — NOT accuracy
|
| 312 |
+
4. Or: use real fMRI data (Natural Scenes Dataset) instead of TRIBE v2 predictions
|
| 313 |
+
|
| 314 |
+
### Key Infrastructure
|
| 315 |
+
|
| 316 |
+
- Training scripts: /home/azureuser/brain_guided/train_*.py
|
| 317 |
+
- UCF-101 dataset: /home/azureuser/brain_guided/data/UCF-101 (13K videos)
|
| 318 |
+
- Results: /home/azureuser/brain_guided/results_final/
|
| 319 |
+
|
| 320 |
+
---
|
| 321 |
+
|
| 322 |
+
## Project 6: Instagram Cinema
|
| 323 |
+
|
| 324 |
+
**What:** AI-generated cinematic videos using LTX-2.3 22B on ComfyUI for Instagram growth.
|
| 325 |
+
|
| 326 |
+
**Setup:** LTX-2.3 22B dev model running on H100 via ComfyUI, exposed via cloudflared tunnel.
|
| 327 |
+
|
| 328 |
+
**Format:** Instagram Reels — 9:16 portrait, 544x960
|
| 329 |
+
|
| 330 |
+
**Goal:** Create viral-quality cinematic content for Instagram Reels.
|
| 331 |
+
|
| 332 |
+
---
|
| 333 |
+
|
| 334 |
+
## Money-Making Strategy (April 2026)
|
| 335 |
+
|
| 336 |
+
### Sellable Assets
|
| 337 |
+
|
| 338 |
+
1. **TurboQuant** — working implementation nobody else has publicly. Lead magnet for consulting.
|
| 339 |
+
2. **Parameter Golf** — competition result (if top placement) = massive credibility signal
|
| 340 |
+
3. **Fine-tuning expertise** — proven on H100, multiple model families
|
| 341 |
+
4. **Inference optimization consulting** — directly from TurboQuant benchmarks
|
| 342 |
+
|
| 343 |
+
### Immediate Plan
|
| 344 |
+
|
| 345 |
+
- Path to 10L: Freelancing/consulting — fine-tuning + inference optimization
|
| 346 |
+
- Path to 1Cr: Productized consulting at scale or AI startup
|
| 347 |
+
- Channel: X (Twitter) for distribution, direct DMs to founders for sales
|
| 348 |
+
|
| 349 |
+
### X (Twitter) Growth Strategy
|
| 350 |
+
|
| 351 |
+
- Account: 10 followers currently, Premium purchased (213.50/month with 50% off)
|
| 352 |
+
- Strategy: 70% replies (to bigger accounts), 30% original posts
|
| 353 |
+
- Target: 15 strategic replies/day to accounts with 100-5000 followers
|
| 354 |
+
- Post timing: 6:30 PM IST (9:00 AM EST) on Tue/Wed/Thu
|
| 355 |
+
- Pinned thread: TurboQuant benchmarks
|
| 356 |
+
- Goal: 500 followers in 4 weeks, first paid client in 2-4 weeks
|
| 357 |
+
|
| 358 |
+
### Cold Outreach Template
|
| 359 |
+
|
| 360 |
+
"I noticed you're using [X model]. I can cut your inference cost by 40%. Free 1-week proof. Interested?"
|
| 361 |
+
|
| 362 |
+
### Target Clients
|
| 363 |
+
|
| 364 |
+
- Indian startups using LLMs in production (inc42 AI list)
|
| 365 |
+
- US startups from YC directory (AI/ML category, S24/W25 batches)
|
| 366 |
+
- Anyone on Twitter complaining about GPU costs / inference scaling
|
| 367 |
+
- Companies with >$10K/month GPU spend
|
yc_scrape.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests, json, re, sys
|
| 2 |
+
from openpyxl import Workbook
|
| 3 |
+
|
| 4 |
+
APP = "45BWZJ1SGC"
|
| 5 |
+
|
| 6 |
+
def get_key():
|
| 7 |
+
r = requests.get("https://www.ycombinator.com/companies",
|
| 8 |
+
headers={"User-Agent": "Mozilla/5.0"}, timeout=30)
|
| 9 |
+
m = re.search(r'AlgoliaOpts\s*=\s*(\{[^}]*\})', r.text)
|
| 10 |
+
return json.loads(m.group(1))["key"]
|
| 11 |
+
|
| 12 |
+
KEY = get_key()
|
| 13 |
+
URL = f"https://{APP.lower()}-dsn.algolia.net/1/indexes/YCCompany_production/query"
|
| 14 |
+
HDR = {"X-Algolia-Application-Id": APP, "X-Algolia-API-Key": KEY,
|
| 15 |
+
"Content-Type": "application/json"}
|
| 16 |
+
|
| 17 |
+
BATCHES = ["Fall 2025", "Winter 2026", "Spring 2026", "Summer 2026"]
|
| 18 |
+
|
| 19 |
+
def fetch_batch(batch):
|
| 20 |
+
hits = []
|
| 21 |
+
page = 0
|
| 22 |
+
while True:
|
| 23 |
+
body = {"query": "", "facetFilters": [[f"batch:{batch}"]],
|
| 24 |
+
"hitsPerPage": 1000, "page": page}
|
| 25 |
+
r = requests.post(URL, headers=HDR, data=json.dumps(body), timeout=30)
|
| 26 |
+
d = r.json()
|
| 27 |
+
hits.extend(d.get("hits", []))
|
| 28 |
+
if page + 1 >= d.get("nbPages", 0):
|
| 29 |
+
break
|
| 30 |
+
page += 1
|
| 31 |
+
return hits
|
| 32 |
+
|
| 33 |
+
wb = Workbook()
|
| 34 |
+
ws = wb.active
|
| 35 |
+
ws.title = "YC Startups"
|
| 36 |
+
ws.append(["Name", "Batch", "Website", "One-liner", "Location",
|
| 37 |
+
"Industry", "Team Size", "Status", "Hiring", "Tags", "YC Page"])
|
| 38 |
+
|
| 39 |
+
totals = {}
|
| 40 |
+
all_hits = []
|
| 41 |
+
for b in BATCHES:
|
| 42 |
+
hits = fetch_batch(b)
|
| 43 |
+
totals[b] = len(hits)
|
| 44 |
+
print(f"{b}: {len(hits)}", flush=True)
|
| 45 |
+
all_hits.extend(hits)
|
| 46 |
+
|
| 47 |
+
for h in all_hits:
|
| 48 |
+
ws.append([
|
| 49 |
+
h.get("name", ""),
|
| 50 |
+
h.get("batch", ""),
|
| 51 |
+
h.get("website", ""),
|
| 52 |
+
h.get("one_liner", ""),
|
| 53 |
+
h.get("all_locations", ""),
|
| 54 |
+
h.get("industry", ""),
|
| 55 |
+
h.get("team_size", ""),
|
| 56 |
+
h.get("status", ""),
|
| 57 |
+
"Yes" if h.get("isHiring") else "No",
|
| 58 |
+
", ".join(h.get("tags", []) or []),
|
| 59 |
+
f"https://www.ycombinator.com/companies/{h.get('slug','')}",
|
| 60 |
+
])
|
| 61 |
+
|
| 62 |
+
for col, width in enumerate([22, 14, 40, 55, 28, 18, 11, 10, 8, 40, 50], start=1):
|
| 63 |
+
ws.column_dimensions[chr(64 + col)].width = width
|
| 64 |
+
|
| 65 |
+
out = "/home/azureuser/yc_companies.xlsx"
|
| 66 |
+
wb.save(out)
|
| 67 |
+
print(f"\nTOTAL: {sum(totals.values())} companies")
|
| 68 |
+
print(f"Saved: {out}")
|