<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>PyTorch Training Debugger — Live Dashboard</title>
  <!-- Plotly draws the metrics, gradient and timeline charts. -->
  <script src="https://cdn.plot.ly/plotly-2.27.0.min.js"></script>
  <style>
    /* ---- Base: reset + dark palette ---- */
    * {
      margin: 0;
      padding: 0;
      box-sizing: border-box;
    }
    body {
      font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
      background: #0d1117;
      color: #c9d1d9;
    }

    /* ---- Header bar and connection badge ---- */
    .header {
      background: #161b22;
      padding: 16px 24px;
      border-bottom: 1px solid #30363d;
      display: flex;
      align-items: center;
      gap: 16px;
    }
    .header h1 {
      font-size: 20px;
      font-weight: 600;
    }
    .header .status {
      padding: 4px 12px;
      border-radius: 12px;
      font-size: 13px;
      font-weight: 500;
    }
    .status.connected {
      background: #238636;
      color: #fff;
    }
    .status.disconnected {
      background: #da3633;
      color: #fff;
    }

    /* ---- 2x2 panel grid, viewport height minus 60px reserved for the header ---- */
    .grid {
      display: grid;
      grid-template-columns: 1fr 1fr;
      grid-template-rows: 1fr 1fr;
      gap: 12px;
      padding: 12px;
      height: calc(100vh - 60px);
    }
    .panel {
      background: #161b22;
      border: 1px solid #30363d;
      border-radius: 8px;
      overflow: hidden;
      display: flex;
      flex-direction: column;
    }
    .panel-title {
      padding: 10px 16px;
      font-size: 14px;
      font-weight: 600;
      color: #58a6ff;
      border-bottom: 1px solid #30363d;
      background: #0d1117;
    }
    .panel-body {
      flex: 1;
      padding: 8px;
      position: relative;
      min-height: 0;
    }
    .panel-body > div:first-child {
      width: 100%;
      height: 100%;
    }
    .placeholder {
      display: flex;
      align-items: center;
      justify-content: center;
      height: 100%;
      color: #484f58;
      font-style: italic;
    }

    /* ---- Task picker and run button ---- */
    #controls {
      display: flex;
      gap: 8px;
      align-items: center;
    }
    #controls select,
    #controls button {
      background: #21262d;
      color: #c9d1d9;
      border: 1px solid #30363d;
      padding: 6px 12px;
      border-radius: 6px;
      cursor: pointer;
      font-size: 13px;
    }
    #controls button:hover {
      background: #30363d;
    }
    /* Same specificity as the :hover rule and declared after it, so it wins. */
    #controls button.primary {
      background: #238636;
      border-color: #238636;
      color: #fff;
    }

    /* ---- Episode summary panel ---- */
    #summary {
      padding: 16px;
      font-size: 13px;
      line-height: 1.8;
      overflow-y: auto;
    }
    #summary .row {
      display: flex;
      justify-content: space-between;
      border-bottom: 1px solid #21262d;
      padding: 4px 0;
    }
    #summary .label {
      color: #8b949e;
    }
    #summary .value {
      font-weight: 600;
    }
    #summary .score {
      font-size: 24px;
      color: #58a6ff;
      text-align: center;
      margin: 12px 0;
    }

    /* ---- Colour-coded action tags ---- */
    .actions-list {
      display: flex;
      flex-wrap: wrap;
      gap: 4px;
      margin-top: 8px;
    }
    .action-tag {
      padding: 2px 8px;
      border-radius: 4px;
      font-size: 11px;
      font-weight: 500;
    }
    .action-tag.investigate {
      background: #1f6feb33;
      color: #58a6ff;
    }
    .action-tag.fix {
      background: #23863633;
      color: #3fb950;
    }
    .action-tag.terminal {
      background: #da363333;
      color: #f85149;
    }
    .action-tag.wrong {
      background: #da363366;
      color: #f85149;
    }
  </style>
</head>
<body>
<header class="header">
  <h1>PyTorch Training Debugger</h1>
  <!-- role="status" lets assistive tech announce the Connected/Disconnected text swaps. -->
  <div id="connStatus" class="status disconnected" role="status">Disconnected</div>
  <div id="controls">
    <!-- There is no room for a visible label in the header bar, so the picker gets an accessible name. -->
    <select id="taskSelect" aria-label="Debugging task">
      <option value="task_001">Task 1 — Exploding Gradients (Easy)</option>
      <option value="task_002">Task 2 — Vanishing Gradients (Easy)</option>
      <option value="task_003">Task 3 — Data Leakage (Medium)</option>
      <option value="task_004">Task 4 — Overfitting (Medium)</option>
      <option value="task_005">Task 5 — BatchNorm Eval (Hard)</option>
      <option value="task_006">Task 6 — Code Bug (Hard)</option>
      <option value="task_007">Task 7 — Scheduler Misconfigured (Med-Hard)</option>
    </select>
    <button class="primary" type="button" onclick="runBaseline()">Run Baseline</button>
  </div>
</header>
<main class="grid">
  <section class="panel" aria-labelledby="metricsTitle">
    <h2 id="metricsTitle" class="panel-title">Training Metrics</h2>
    <div class="panel-body"><div id="metricsChart"><div class="placeholder">Run baseline to see metrics</div></div></div>
  </section>
  <section class="panel" aria-labelledby="gradientTitle">
    <h2 id="gradientTitle" class="panel-title">Gradient & Weight Heatmap</h2>
    <div class="panel-body"><div id="gradientChart"><div class="placeholder">Not yet inspected</div></div></div>
  </section>
  <section class="panel" aria-labelledby="timelineTitle">
    <h2 id="timelineTitle" class="panel-title">Action Timeline & Rewards</h2>
    <div class="panel-body"><div id="timelineChart"><div class="placeholder">No actions yet</div></div></div>
  </section>
  <section class="panel" aria-labelledby="summaryTitle">
    <h2 id="summaryTitle" class="panel-title">Episode Summary</h2>
    <div class="panel-body" id="summary">
      <div class="placeholder">Waiting for episode</div>
    </div>
  </section>
</main>
|
|
| <script> |
// WebSocket endpoint parts derived from the page URL: same host as the page,
// and a TLS socket (wss:) whenever the page itself was served over https.
const host = window.location.host;
const wsProto = (window.location.protocol === 'https:') ? 'wss:' : 'ws:';

// Mutable dashboard state.
let ws = null;       // current socket; connect() assigns a new one on every (re)connect
let obs = null;      // most recent observation payload
let actions = [],    // episode_state.actions_taken from the latest observation
    rewards = [],    // per-step rewards in arrival order
    cumRewards = []; // running total of `rewards`
| |
/**
 * Reflect the WebSocket connection state in the header badge (#connStatus).
 * @param {boolean} connected - true once the socket is open, false after it closes.
 */
function setStatus(connected) {
  const badge = document.getElementById('connStatus');
  const state = connected ? 'connected' : 'disconnected';
  badge.className = `status ${state}`;
  badge.textContent = connected ? 'Connected' : 'Disconnected';
}
| |
/**
 * Open the dashboard WebSocket (`/ws` on the page's host) and wire its handlers.
 *
 * - open:    the header badge switches to "Connected".
 * - close:   the badge switches to "Disconnected" and connect() runs again after 2 s.
 * - error:   the socket is closed, so the close handler takes over.
 * - message: `observation` messages are unwrapped and passed to handleObservation().
 *            Frames that are not valid JSON are ignored, as are messages of any other type.
 */
function connect() {
  const socket = new WebSocket(`${wsProto}//${host}/ws`);
  ws = socket;

  // Handlers close over `socket` (not the global `ws`) so they always act on the
  // connection that fired the event, even after a reconnect reassigns `ws`.
  socket.onopen = () => setStatus(true);
  socket.onclose = () => { setStatus(false); setTimeout(connect, 2000); };
  socket.onerror = () => socket.close();
  socket.onmessage = (ev) => {
    let msg;
    try {
      msg = JSON.parse(ev.data);
    } catch (err) {
      console.warn('Ignoring non-JSON WebSocket frame', err);
      return;
    }
    if (!msg || msg.type !== 'observation' || !msg.data) return;

    // The payload is either a bare observation or an envelope of the form
    // { observation, reward, done }.
    const wrapper = msg.data;
    const obsData = wrapper.observation || wrapper;
    // Copy envelope-level fields only when present, so an observation that carries
    // its own reward/done is not overwritten with undefined.
    if (wrapper.reward !== undefined) obsData.reward = wrapper.reward;
    if (wrapper.done !== undefined) obsData.done = wrapper.done;
    handleObservation(obsData);
  };
}
| |
/**
 * Record an incoming observation and refresh every dashboard panel.
 *
 * The observation becomes the global `obs`. When it carries a reward (anything
 * other than null/undefined), that reward is appended to `rewards` and the new
 * running total to `cumRewards`. When the episode state lists the actions taken
 * so far, that list replaces `actions`.
 *
 * @param {object} data - observation payload as unwrapped by connect().
 */
function handleObservation(data) {
  obs = data;

  const reward = data.reward;
  if (reward != null) {
    const runningTotal = cumRewards.length ? cumRewards[cumRewards.length - 1] : 0;
    rewards.push(reward);
    cumRewards.push(runningTotal + reward);
  }

  const episode = data.episode_state;
  if (episode && episode.actions_taken) {
    actions = episode.actions_taken;
  }

  // Redraw the panels in layout order.
  updateMetrics(data);
  updateGradients(data);
  updateTimeline();
  updateSummary(data);
}
| |
/**
 * Plot train/val loss (left axis) and validation accuracy (right axis, fixed
 * to 0–1) against the epoch index in #metricsChart.
 *
 * Non-finite points (NaN, ±Infinity, e.g. after a loss blow-up) are replaced by
 * null, which Plotly draws as a gap. Every value therefore stays at its own
 * epoch, and the three series remain aligned on the shared x-axis. Dropping
 * such points would shift all later values left.
 * Does nothing when none of the three histories has data.
 *
 * @param {object} d - observation with optional training_loss_history,
 *   val_loss_history and val_accuracy_history arrays.
 */
function updateMetrics(d) {
  const hasData = (series) => series && series.length > 0;
  // Keep every epoch slot; non-finite values become gaps instead of being removed.
  const gapNonFinite = (series) => series.map(v => (isFinite(v) ? v : null));

  const traces = [];
  if (hasData(d.training_loss_history)) {
    traces.push({ y: gapNonFinite(d.training_loss_history), name: 'Train Loss', line: { color: '#f85149' } });
  }
  if (hasData(d.val_loss_history)) {
    traces.push({ y: gapNonFinite(d.val_loss_history), name: 'Val Loss', line: { color: '#f0883e', dash: 'dash' } });
  }
  if (hasData(d.val_accuracy_history)) {
    traces.push({ y: gapNonFinite(d.val_accuracy_history), name: 'Val Accuracy', yaxis: 'y2', line: { color: '#3fb950' } });
  }
  if (traces.length === 0) return;

  Plotly.newPlot('metricsChart', traces, {
    paper_bgcolor: 'transparent', plot_bgcolor: 'transparent',
    font: { color: '#c9d1d9', size: 11 },
    margin: { t: 10, b: 30, l: 50, r: 50 },
    xaxis: { title: 'Epoch', gridcolor: '#21262d' },
    yaxis: { title: 'Loss', gridcolor: '#21262d' },
    yaxis2: { title: 'Accuracy', overlaying: 'y', side: 'right', range: [0, 1], gridcolor: '#21262d' },
    legend: { x: 0, y: 1.15, orientation: 'h' },
    showlegend: true,
  }, { responsive: true });
}
| |
/**
 * Bar chart of mean gradient norm per layer (log-scaled y-axis) in #gradientChart.
 *
 * Each bar is red and labelled EXPLODING, blue and labelled VANISHING, or green
 * and labelled Normal. The exploding flag is checked first when both flags are set.
 * Does nothing when the observation has no gradient stats.
 *
 * @param {object} d - observation; d.gradient_stats entries provide layer_name,
 *   mean_norm, is_exploding and is_vanishing.
 */
function updateGradients(d) {
  const stats = d.gradient_stats;
  if (!stats || stats.length === 0) return;

  const names = [], norms = [], fills = [], labels = [];
  stats.forEach(g => {
    let fill = '#3fb950';
    let label = 'Normal';
    if (g.is_exploding) {
      fill = '#f85149';
      label = 'EXPLODING';
    } else if (g.is_vanishing) {
      fill = '#1f6feb';
      label = 'VANISHING';
    }
    names.push(g.layer_name);
    norms.push(g.mean_norm);
    fills.push(fill);
    labels.push(label);
  });

  const bars = {
    x: names,
    y: norms,
    type: 'bar',
    marker: { color: fills },
    text: labels,
    textposition: 'auto',
  };
  const layout = {
    paper_bgcolor: 'transparent',
    plot_bgcolor: 'transparent',
    font: { color: '#c9d1d9', size: 11 },
    margin: { t: 10, b: 30, l: 50, r: 20 },
    yaxis: { title: 'Mean Grad Norm', gridcolor: '#21262d', type: 'log' },
    xaxis: { gridcolor: '#21262d' },
  };
  Plotly.newPlot('gradientChart', [bars], layout, { responsive: true });
}
| |
/**
 * Redraw #timelineChart from the module-level `actions`, `rewards` and
 * `cumRewards` arrays.
 *
 * Draws one bar per step (green when the reward is >= 0, red otherwise) and the
 * cumulative-reward line. X ticks are labelled with shortened action names: the
 * part before ':' is kept, 'inspect_' becomes 'i_' and 'mark_diagnosed' becomes
 * 'diag'. Does nothing until at least one action has been recorded.
 */
function updateTimeline() {
  if (actions.length === 0) return;

  // 1-based step numbers shared by both traces and the tick positions.
  const steps = actions.map((_, i) => i + 1);
  const shortName = (a) => a.split(':')[0].replace('inspect_', 'i_').replace('mark_diagnosed', 'diag');

  Plotly.newPlot('timelineChart', [
    { x: steps, y: rewards, type: 'bar', name: 'Step Reward', marker: { color: rewards.map(r => (r >= 0 ? '#3fb950' : '#f85149')) } },
    { x: steps, y: cumRewards, type: 'scatter', name: 'Cumulative', line: { color: '#58a6ff', width: 2 } },
  ], {
    paper_bgcolor: 'transparent', plot_bgcolor: 'transparent',
    font: { color: '#c9d1d9', size: 11 },
    margin: { t: 10, b: 30, l: 50, r: 20 },
    xaxis: { title: 'Step', gridcolor: '#21262d', tickvals: steps, ticktext: actions.map(shortName) },
    yaxis: { title: 'Reward', gridcolor: '#21262d' },
    legend: { x: 0, y: 1.15, orientation: 'h' },
  }, { responsive: true });
}
| |
/**
 * Render the "Episode Summary" panel (#summary) from an observation.
 *
 * The panel shows:
 * - an "Episode Complete" banner when `d.done` is truthy;
 * - one row per episode_state flag;
 * - the current code snippet, when the observation has one;
 * - the available actions as colour-coded tags.
 *
 * Server-supplied text (run id, step count, code, action names) is
 * HTML-escaped before being placed into innerHTML. This keeps source code such
 * as `a < b` visible as text instead of being parsed as markup.
 *
 * @param {object} d - observation; reads episode_state, available_actions,
 *   done, run_id and code_snippet.code.
 */
function updateSummary(d) {
  // Escape via numeric character references (&#60; etc.) so the markup stays inert.
  const esc = (v) => String(v).replace(/[&<>"']/g, ch => '&#' + ch.charCodeAt(0) + ';');
  const row = (label, value) =>
    '<div class="row"><span class="label">' + label + '</span><span class="value">' + esc(value) + '</span></div>';
  const yesNo = (flag) => (flag ? 'Yes' : 'No');

  const s = d.episode_state || {};
  const avail = d.available_actions || [];
  let html = '';
  if (d.done) {
    html += '<div class="score">Episode Complete</div>';
  }
  html += row('Task', d.run_id || '-');
  html += row('Steps', s.step_count || 0);
  html += row('Gradients Inspected', yesNo(s.gradients_inspected));
  html += row('Gradients Normal', s.gradients_were_normal ? 'Yes' : '-');
  html += row('Data Inspected', yesNo(s.data_inspected));
  html += row('Model Modes Inspected', yesNo(s.model_modes_inspected));
  html += row('Code Inspected', yesNo(s.code_inspected));
  html += row('Fix Applied', yesNo(s.fix_action_taken));
  html += row('Restarted', yesNo(s.restart_after_fix));
  html += row('Diagnosed', yesNo(s.diagnosis_submitted));

  // Guard on .code as well: a snippet object without code used to throw here.
  if (d.code_snippet && d.code_snippet.code != null) {
    html += '<div style="margin-top:12px"><span class="label">Code:</span><pre style="background:#0d1117;padding:8px;border-radius:4px;font-size:11px;overflow:auto;max-height:120px;margin-top:4px">' + esc(d.code_snippet.code) + '</pre></div>';
  }

  html += '<div style="margin-top:8px"><span class="label">Available Actions:</span></div>';
  html += '<div class="actions-list">';
  avail.forEach(a => {
    let cls = 'investigate';
    if (a.startsWith('fix') || a === 'modify_config' || a === 'patch_data_loader' || a === 'add_callback' || a === 'replace_optimizer') cls = 'fix';
    if (a === 'mark_diagnosed' || a === 'restart_run') cls = 'terminal';
    html += '<span class="action-tag ' + cls + '">' + esc(a) + '</span>';
  });
  html += '</div>';
  document.getElementById('summary').innerHTML = html;
}
| |
/**
 * Send one protocol message on the current socket and resolve with the next
 * `observation` message the server sends back.
 *
 * The promise rejects instead of hanging forever when:
 * - the socket is missing or not open;
 * - the socket closes before an observation arrives;
 * - the server replies with a message whose type is "error".
 * Both listeners are detached as soon as the promise settles. Frames that are
 * not valid JSON are ignored.
 * NOTE(review): the "error" reply type is an assumption about the server
 * protocol — confirm against the server implementation.
 *
 * @param {string} type - protocol message type ('step' or 'reset').
 * @param {object} data - message payload.
 * @returns {Promise<object>} the parsed observation message.
 */
function requestObservation(type, data) {
  return new Promise((resolve, reject) => {
    // Pin the socket so listeners are removed from the same object they were added to.
    const socket = ws;
    if (!socket || socket.readyState !== WebSocket.OPEN) {
      reject(new Error('WebSocket is not open'));
      return;
    }
    const cleanup = () => {
      socket.removeEventListener('message', onMessage);
      socket.removeEventListener('close', onClose);
    };
    const onMessage = (ev) => {
      let msg;
      try {
        msg = JSON.parse(ev.data);
      } catch (_) {
        return; // not JSON: keep waiting for the observation
      }
      if (!msg) return;
      if (msg.type === 'observation') {
        cleanup();
        resolve(msg);
      } else if (msg.type === 'error') {
        cleanup();
        reject(new Error('Server error: ' + JSON.stringify(msg.data)));
      }
    };
    const onClose = () => {
      cleanup();
      reject(new Error('WebSocket closed before an observation arrived'));
    };
    socket.addEventListener('message', onMessage);
    socket.addEventListener('close', onClose);
    socket.send(JSON.stringify({ type: type, data: data }));
  });
}

/** Send an environment step (`action` is the action object) and await the resulting observation. */
function sendStep(action) {
  return requestObservation('step', action);
}

/** Reset the environment to `taskId` (fixed seed 42) and await the initial observation. */
function sendReset(taskId) {
  return requestObservation('reset', { task_id: taskId, seed: 42 });
}
| |
/**
 * Run the scripted baseline agent for the task selected in #taskSelect.
 *
 * Clears the reward/action history, resets the environment and walks a fixed
 * decision tree, stopping at the first symptom it recognises:
 *   1. inspect_gradients   -> any exploding layer: learning_rate 0.001, 'lr_too_high';
 *                             any vanishing layer: learning_rate 0.01, 'vanishing_gradients'
 *   2. inspect_data_batch  -> class_overlap_score > 0.5: patch_data_loader, 'data_leakage'
 *   3. loss histories      -> final train loss < 0.1 and final val loss above the val
 *                             loss at index 5: weight_decay 0.01, 'overfitting'
 *   4. inspect_model_modes -> any module reported as 'eval': fix_model_mode, 'batchnorm_eval_mode'
 *   5. inspect_code        -> a recognised bug in the snippet: fix_code, 'code_bug'
 *   6. val accuracy        -> accuracy at index 9 above 0.3 but gaining < 0.05 by the end:
 *                             learning_rate 0.005, 'scheduler_misconfigured'
 *   7. fallback            -> weight_decay 0.01, 'overfitting'
 * Every fix is followed by restart_run and mark_diagnosed, with a 300 ms pause
 * between steps so the panels visibly update.
 *
 * A click while a run is already in progress is ignored. If any step's promise
 * rejects, the run is aborted and the error is logged to the console.
 */
async function runBaseline() {
  // Two interleaved runs on one socket would consume each other's replies.
  if (runBaseline.busy) return;
  const taskId = document.getElementById('taskSelect').value;
  actions = []; rewards = []; cumRewards = [];
  if (!ws || ws.readyState !== WebSocket.OPEN) return;

  runBaseline.busy = true;
  try {
    await baselinePolicy(taskId);
  } catch (err) {
    console.error('Baseline run aborted:', err);
  } finally {
    runBaseline.busy = false;
  }
}

/** Resolve after `ms` milliseconds; spaces out protocol steps so each redraw is visible. */
function baselineDelay(ms) {
  return new Promise(resolve => setTimeout(resolve, ms));
}

/** Send `fixAction`, then restart_run, then mark_diagnosed with `diagnosis`, pausing 300 ms between steps. */
async function baselineFixAndDiagnose(fixAction, diagnosis) {
  await sendStep(fixAction);
  await baselineDelay(300);
  await sendStep({ action_type: 'restart_run' });
  await baselineDelay(300);
  await sendStep({ action_type: 'mark_diagnosed', diagnosis: diagnosis });
}

/**
 * Look for a known bug in a training-loop code snippet.
 *
 * Recognised patterns, checked line by line:
 * - model.eval() in the loop;
 * - a .detach() on a line that calls criterion;
 * - inplace=True;
 * - failing those, a missing optimizer.zero_grad().
 *
 * @param {string} code - snippet text, one statement per '\n'-separated line.
 * @returns {{line: number, replacement: string} | null} 1-based line number and
 *   its replacement text, or null when nothing is recognised.
 */
function baselineFindCodeFix(code) {
  const lines = code.split('\n');
  for (let i = 0; i < lines.length; i++) {
    const src = lines[i];
    const ln = src.trim();
    if (ln.includes('model.eval()')) return { line: i + 1, replacement: src.replace('model.eval()', 'model.train()') };
    if (ln.includes('.detach()') && ln.includes('criterion')) return { line: i + 1, replacement: src.replace('.detach()', '') };
    if (ln.includes('inplace=True')) return { line: i + 1, replacement: src.replace('inplace=True', '') };
  }

  // Missing zero_grad(): only applies when no non-comment line already calls it.
  const isCode = (l) => !l.trim().startsWith('#');
  if (lines.some(l => isCode(l) && l.includes('zero_grad('))) return null;

  // Gradients must be cleared *before* backward(). Inserting zero_grad() between
  // backward() and step() would discard the freshly computed gradients, so the
  // insertion point is the backward() line; the step() line is only a fallback.
  let target = lines.findIndex(l => isCode(l) && l.includes('.backward('));
  if (target < 0) target = lines.findIndex(l => l.includes('optimizer.step()'));
  if (target < 0) return null;

  // Reuse the target line's indentation so the inserted Python statement stays in the same block.
  const indent = lines[target].match(/^\s*/)[0];
  return { line: target + 1, replacement: indent + 'optimizer.zero_grad()\n' + lines[target] };
}

/** Decision tree behind runBaseline(); reads the global `obs`, which the socket handler refreshes after each step. */
async function baselinePolicy(taskId) {
  await sendReset(taskId);
  await baselineDelay(300);

  // 1. Gradient health.
  await sendStep({ action_type: 'inspect_gradients' });
  await baselineDelay(300);
  const gs = obs && obs.gradient_stats ? obs.gradient_stats : [];
  if (gs.some(g => g.is_exploding)) {
    return baselineFixAndDiagnose({ action_type: 'modify_config', target: 'learning_rate', value: 0.001 }, 'lr_too_high');
  }
  if (gs.some(g => g.is_vanishing)) {
    return baselineFixAndDiagnose({ action_type: 'modify_config', target: 'learning_rate', value: 0.01 }, 'vanishing_gradients');
  }

  // 2. Train/validation overlap in the data batch.
  await sendStep({ action_type: 'inspect_data_batch' });
  await baselineDelay(300);
  const dbs = obs && obs.data_batch_stats ? obs.data_batch_stats : {};
  if (dbs.class_overlap_score > 0.5) {
    return baselineFixAndDiagnose({ action_type: 'patch_data_loader' }, 'data_leakage');
  }

  // 3. Overfitting, judged from the loss histories already in the observation.
  const tl = obs && obs.training_loss_history ? obs.training_loss_history : [];
  const vl = obs && obs.val_loss_history ? obs.val_loss_history : [];
  const lastTrainLoss = tl.length > 0 ? tl[tl.length - 1] : 999;
  const lastValLoss = vl.length > 0 ? vl[vl.length - 1] : 0;
  const earlyValLoss = vl.length > 5 ? vl[5] : lastValLoss;
  if (lastTrainLoss < 0.1 && lastValLoss > earlyValLoss) {
    return baselineFixAndDiagnose({ action_type: 'modify_config', target: 'weight_decay', value: 0.01 }, 'overfitting');
  }

  // 4. Modules left in eval mode during training.
  await sendStep({ action_type: 'inspect_model_modes' });
  await baselineDelay(300);
  const modes = obs && obs.model_mode_info ? obs.model_mode_info : {};
  if (Object.values(modes).some(m => m === 'eval')) {
    return baselineFixAndDiagnose({ action_type: 'fix_model_mode' }, 'batchnorm_eval_mode');
  }

  // 5. Training-loop code bug. Only diagnose code_bug when a concrete fix was
  //    found; otherwise fall through so the scheduler check below stays reachable.
  await sendStep({ action_type: 'inspect_code' });
  await baselineDelay(300);
  const code = obs && obs.code_snippet ? obs.code_snippet.code : null;
  const codeFix = code ? baselineFindCodeFix(code) : null;
  if (codeFix) {
    return baselineFixAndDiagnose({ action_type: 'fix_code', line: codeFix.line, replacement: codeFix.replacement }, 'code_bug');
  }

  // 6. Validation accuracy plateau.
  const va = obs && obs.val_accuracy_history ? obs.val_accuracy_history : [];
  const midAcc = va.length > 10 ? va[9] : 0;
  const endAcc = va.length > 0 ? va[va.length - 1] : 0;
  if (midAcc > 0.3 && (endAcc - midAcc) < 0.05) {
    return baselineFixAndDiagnose({ action_type: 'modify_config', target: 'learning_rate', value: 0.005 }, 'scheduler_misconfigured');
  }

  // 7. Fallback guess.
  return baselineFixAndDiagnose({ action_type: 'modify_config', target: 'weight_decay', value: 0.01 }, 'overfitting');
}
| |
// Open the dashboard socket as soon as this script runs. connect() schedules
// its own retry 2 s after every close, so this is the only top-level call.
connect();
| </script> |
| </body> |
| </html> |
|
|