Pranesh64 commited on
Commit
217598f
·
verified ·
1 Parent(s): f99a0e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -489
app.py CHANGED
@@ -3,13 +3,6 @@
3
  # - Standard TensorFlow (Keras) based
4
  # - Gradio 5 compatible (no theme=)
5
  # - CPU-friendly (disables GPU usage)
6
- # - Features:
7
- # * Load model (.h5) or use example models
8
- # * Graph visualization (nodes = layers, edges = inbound connections)
9
- # * Click node -> inspect layer attributes, shapes, params
10
- # * View weights (kernels as images + histogram)
11
- # * Activation maps for conv layers (image input)
12
- # * Simple vs Advanced explanatory text
13
  # ==========================================================
14
 
15
  import io
@@ -17,7 +10,7 @@ import os
17
  import math
18
  import traceback
19
  import warnings
20
- from typing import Any, Dict, List, Optional, Tuple
21
 
22
  import gradio as gr
23
  import numpy as np
@@ -27,11 +20,10 @@ import networkx as nx
27
 
28
  warnings.filterwarnings("ignore")
29
 
30
- # Try importing tensorflow (standard). If import fails, we show a friendly error.
31
  try:
32
  import tensorflow as tf
33
  from tensorflow import keras
34
- # force CPU to avoid GPU surprises in Spaces
35
  try:
36
  tf.config.set_visible_devices([], "GPU")
37
  except Exception:
@@ -41,533 +33,181 @@ except Exception as e:
41
  TF_AVAILABLE = False
42
  TF_IMPORT_ERROR = str(e)
43
 
44
-
45
  # -------------------- Helpers --------------------
46
 
47
  def safe_load_keras_model(fileobj: Optional[io.BytesIO], chosen: str):
48
- """
49
- If fileobj provided (uploaded .h5), load that model.
50
- Else create a built-in small example model depending on 'chosen'.
51
- """
52
  if not TF_AVAILABLE:
53
- raise RuntimeError("TensorFlow not available. Add 'tensorflow' to requirements.txt")
54
 
55
  if fileobj:
56
- # load uploaded .h5 bytes
57
  fileobj.seek(0)
58
  tmp_path = "/tmp/uploaded_model.h5"
59
  with open(tmp_path, "wb") as f:
60
  f.write(fileobj.read())
61
  model = keras.models.load_model(tmp_path)
62
  return model, "uploaded .h5 model"
63
- else:
64
- # built-in models: "small_cnn" or "toy_resnet"
65
- if chosen == "small_cnn":
66
- model = keras.Sequential(
67
- [
68
- keras.layers.InputLayer(input_shape=(64, 64, 3)),
69
- keras.layers.Conv2D(16, 3, activation="relu", padding="same"),
70
- keras.layers.MaxPool2D(),
71
- keras.layers.Conv2D(32, 3, activation="relu", padding="same"),
72
- keras.layers.MaxPool2D(),
73
- keras.layers.Conv2D(64, 3, activation="relu", padding="same"),
74
- keras.layers.GlobalAveragePooling2D(),
75
- keras.layers.Dense(64, activation="relu"),
76
- keras.layers.Dense(10, activation="softmax"),
77
- ]
78
- )
79
- # build model
80
- model.build(input_shape=(None, 64, 64, 3))
81
- return model, "Small CNN (example)"
82
- elif chosen == "toy_resnet":
83
- inputs = keras.Input(shape=(64, 64, 3))
84
- x = keras.layers.Conv2D(32, 3, strides=1, padding="same", activation="relu")(inputs)
85
- for _ in range(2):
86
- y = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(x)
87
- y = keras.layers.Conv2D(32, 3, padding="same")(y)
88
- x = keras.layers.add([x, y])
89
- x = keras.layers.ReLU()(x)
90
- x = keras.layers.GlobalAveragePooling2D()(x)
91
- outputs = keras.layers.Dense(5, activation="softmax")(x)
92
- model = keras.Model(inputs, outputs)
93
- model.build(input_shape=(None, 64, 64, 3))
94
- return model, "Toy ResNet-like (example)"
95
- else:
96
- # fallback to small_cnn
97
- return safe_load_keras_model(None, "small_cnn")
98
-
99
-
100
- def model_summary_str(model: keras.Model) -> str:
101
- """Return model.summary() as a string."""
102
- if not TF_AVAILABLE:
103
- return "TensorFlow not available."
104
  stream = io.StringIO()
105
  model.summary(print_fn=lambda s: stream.write(s + "\n"))
106
  return stream.getvalue()
107
 
108
-
109
  # -------------------- Graph builder --------------------
110
 
111
- def build_layer_graph(model: keras.Model):
112
- """
113
- Build a directed graph (networkx) of layers. Node attributes include:
114
- - name, class_name, inbound_layers, outbound_layers, input_shape, output_shape, params
115
- """
116
  G = nx.DiGraph()
117
- # Keras keeps layers in model.layers
118
- layers = model.layers
119
- # build simple mapping from layer.name -> layer
120
- name2layer = {layer.name: layer for layer in layers}
121
-
122
- # gather inbound/outbound info from layer._inbound_nodes
123
- for layer in layers:
124
- node_attr = {}
125
- node_attr["name"] = layer.name
126
- node_attr["class_name"] = layer.__class__.__name__
127
- try:
128
- node_attr["input_shape"] = layer.input_shape
129
- except Exception:
130
- node_attr["input_shape"] = None
131
- try:
132
- node_attr["output_shape"] = layer.output_shape
133
- except Exception:
134
- node_attr["output_shape"] = None
135
- try:
136
- node_attr["params"] = layer.count_params()
137
- except Exception:
138
- node_attr["params"] = None
139
-
140
- # inbound layer names (may be empty for InputLayer)
141
  inbound = []
142
- try:
143
- for node in getattr(layer, "_inbound_nodes", []) or []:
144
- for inbound_layer in getattr(node, "inbound_layers", []) or []:
145
- if hasattr(inbound_layer, "name"):
146
- inbound.append(inbound_layer.name)
147
- except Exception:
148
- inbound = []
149
-
150
- node_attr["inbound_layers"] = inbound
151
- G.add_node(layer.name, **node_attr)
152
-
153
- # add edges based on inbound lists
154
- for node in G.nodes(data=True):
155
- src = node[0]
156
- inbound = node[1].get("inbound_layers", [])
157
- for src_in in inbound:
158
- if not G.has_node(src_in):
159
- # sometimes inbound is a tensor name; ignore
160
- continue
161
- G.add_edge(src_in, src)
162
  return G
163
 
164
 
165
- def nx_to_plotly_fig(G: nx.DiGraph, highlight_node: Optional[str] = None):
166
- """
167
- Convert networkx graph into a Plotly network figure for interactive selection.
168
- Node hover shows class_name and params. Node click returns node name via customdata.
169
- """
170
- pos = nx.spring_layout(G, seed=42, k=0.5)
171
- node_x = []
172
- node_y = []
173
- texts = []
174
- customdata = []
175
- sizes = []
176
- for n, d in G.nodes(data=True):
177
- x, y = pos[n]
178
- node_x.append(x)
179
- node_y.append(y)
180
- cname = d.get("class_name", "")
181
- params = d.get("params", 0)
182
- texts.append(f"{n} ({cname})\nparams: {params}")
183
- customdata.append(n)
184
- sizes.append(20 if n != highlight_node else 36)
185
-
186
- edge_x = []
187
- edge_y = []
188
  for u, v in G.edges():
189
  x0, y0 = pos[u]
190
  x1, y1 = pos[v]
191
  edge_x += [x0, x1, None]
192
  edge_y += [y0, y1, None]
193
 
194
- edge_trace = go.Scatter(
195
- x=edge_x,
196
- y=edge_y,
197
- line=dict(width=1, color="#888"),
198
- hoverinfo="none",
199
- mode="lines",
200
- )
201
-
202
- node_trace = go.Scatter(
203
- x=node_x,
204
- y=node_y,
205
- mode="markers+text",
206
- text=[n for n in G.nodes()],
207
- textposition="top center",
208
- marker=dict(size=sizes, color="#1f78b4"),
209
- hoverinfo="text",
210
- hovertext=texts,
211
- customdata=customdata,
212
- )
213
 
214
- fig = go.Figure(data=[edge_trace, node_trace])
215
- fig.update_layout(
216
- showlegend=False,
217
- hovermode="closest",
218
- margin=dict(b=20, l=5, r=5, t=40),
219
- height=600,
220
- clickmode="event+select",
221
- )
222
  fig.update_xaxes(visible=False)
223
  fig.update_yaxes(visible=False)
224
  return fig
225
 
226
-
227
- # -------------------- Inspect layer details --------------------
228
-
229
- def get_layer_info(model: keras.Model, layer_name: str) -> Dict[str, Any]:
230
- """Return layer info: class, input/output shapes, params, config"""
231
- if not TF_AVAILABLE:
232
- return {"error": "TensorFlow not installed."}
233
- try:
234
- layer = model.get_layer(layer_name)
235
- except Exception as e:
236
- return {"error": f"Layer not found: {e}"}
237
- info = {
238
- "name": layer.name,
239
- "class_name": layer.__class__.__name__,
240
- "input_shape": getattr(layer, "input_shape", None),
241
- "output_shape": getattr(layer, "output_shape", None),
242
- "params": layer.count_params() if hasattr(layer, "count_params") else None,
243
- "trainable": getattr(layer, "trainable", None),
244
- "config": getattr(layer, "get_config", lambda: {})(),
245
- }
246
- return info
247
-
248
-
249
- def visualize_weights(layer):
250
- """
251
- For Conv2D kernels, show first few filters as small images.
252
- For Dense layers show weight histogram.
253
- Returns: PIL image (visual collage) and histogram data (bins/counts)
254
- """
255
- try:
256
- weights = layer.get_weights()
257
- except Exception:
258
- return None, None
259
-
260
- if len(weights) == 0:
261
- return None, None
262
-
263
- # Conv2D: kernel shape (kh, kw, in_ch, out_ch)
264
- w = weights[0]
265
- if w.ndim == 4:
266
- kh, kw, ic, oc = w.shape
267
- # visualize up to 8 filters (channels)
268
- nshow = min(8, oc)
269
- tile_w = kw
270
- tile_h = kh
271
- pad = 2
272
- # normalize each filter to 0..255
273
- imgs = []
274
- for i in range(nshow):
275
- filt = w[:, :, :, i]
276
- # collapse input channels by averaging
277
- img_arr = filt.mean(axis=2)
278
- mn, mx = img_arr.min(), img_arr.max()
279
- if mx - mn > 1e-6:
280
- img_norm = (img_arr - mn) / (mx - mn)
281
- else:
282
- img_norm = np.zeros_like(img_arr)
283
- img8 = (img_norm * 255).astype("uint8")
284
- imgs.append(Image.fromarray(img8).resize((tile_w * 8, tile_h * 8)))
285
- # stitch horizontally
286
- total_w = sum(im.width for im in imgs) + pad * (len(imgs) - 1)
287
- hmax = max(im.height for im in imgs)
288
- coll = Image.new("L", (total_w, hmax), color=0)
289
- x = 0
290
- for im in imgs:
291
- coll.paste(im, (x, 0))
292
- x += im.width + pad
293
- return coll.convert("RGB"), np.histogram(w.flatten(), bins=50)
294
- else:
295
- # Dense or other: histogram
296
- hist = np.histogram(w.flatten(), bins=80)
297
- return None, hist
298
-
299
-
300
- # -------------------- Activation extraction --------------------
301
-
302
- def build_activation_model(model: keras.Model, layer_names: List[str]):
303
- """
304
- Create a model that returns outputs of specified layers.
305
- """
306
- if not TF_AVAILABLE:
307
- raise RuntimeError("TensorFlow not available")
308
- outputs = [model.get_layer(name).output for name in layer_names]
309
- act_model = keras.Model(inputs=model.inputs, outputs=outputs)
310
- return act_model
311
-
312
-
313
- def compute_activations(act_model: keras.Model, pil_img: Image.Image):
314
- """
315
- Resize image to model input (if possible) and return activations as np arrays.
316
- For conv layers, they will be (H, W, C) arrays.
317
- """
318
- # determine required size from model input
319
- try:
320
- input_shape = act_model.input_shape
321
- except Exception:
322
- input_shape = None
323
- if input_shape and len(input_shape) == 4:
324
- ih, iw = input_shape[1], input_shape[2]
325
- else:
326
- ih, iw = 224, 224
327
- img = pil_img.convert("RGB").resize((iw, ih))
328
- arr = np.array(img).astype("float32") / 255.0
329
- arr = np.expand_dims(arr, axis=0)
330
- with np.errstate(all="ignore"):
331
- outs = act_model.predict(arr)
332
- # ensure list
333
- if not isinstance(outs, list):
334
- outs = [outs]
335
- outs_np = [o.squeeze() for o in outs]
336
- return outs_np
337
-
338
-
339
- # -------------------- GRADIO UI callbacks --------------------
340
-
341
- def load_model_callback(model_file, example_choice):
342
- if not TF_AVAILABLE:
343
- return {
344
- "error": True,
345
- "message": "TensorFlow not installed in the environment. Add 'tensorflow' to requirements.txt and redeploy."
346
- }
347
- try:
348
- model, tag = safe_load_keras_model(model_file, example_choice)
349
- summary = model_summary_str(model)
350
- G = build_layer_graph(model)
351
- fig = nx_to_plotly_fig(G)
352
- # basic stats
353
- total_params = model.count_params()
354
- return {
355
- "error": False,
356
- "model": model,
357
- "graph_fig": fig,
358
- "summary": summary,
359
- "tag": tag,
360
- "total_params": total_params,
361
- "nx_graph": G
362
- }
363
- except Exception as e:
364
- return {"error": True, "message": f"Failed to load model: {e}\n{traceback.format_exc()}"}
365
-
366
 
367
  def node_inspect_callback(state, node_name):
368
- """
369
- state contains 'model' and 'nx_graph'
370
- """
371
  if not state:
372
  return "No model loaded.", None, None
373
- model = state.get("model")
374
- nx_graph = state.get("nx_graph")
375
- if node_name is None:
376
- return "Click a node in the graph to inspect layer details.", None, None
377
- try:
378
- info = get_layer_info(model, node_name)
379
- # weights visualization
380
- layer = model.get_layer(node_name)
381
- weights_img, hist = visualize_weights(layer)
382
- # create a small HTML summary
383
- html = f"**Layer:** {info['name']} ({info['class_name']}) \n"
384
- html += f"- input shape: `{info['input_shape']}` \n"
385
- html += f"- output shape: `{info['output_shape']}` \n"
386
- html += f"- params: `{info['params']}` \n"
387
- html += f"- trainable: `{info['trainable']}` \n"
388
- return html, weights_img, hist
389
- except Exception as e:
390
- return f"Error inspecting node: {e}", None, None
391
-
392
-
393
- def activation_callback(state, uploaded_image, selected_layers_text):
394
- """
395
- Compute activations for selected layers (comma separated)
396
- Return list of PIL preview images (for convs show channel grid; for dense show vector plot)
397
- """
398
- if not state or "model" not in state:
399
- return None, "No model loaded."
400
- try:
401
- model = state["model"]
402
- # parse layers
403
- layer_names = [s.strip() for s in selected_layers_text.split(",") if s.strip()]
404
- # validate layers
405
- valid = []
406
- for name in layer_names:
407
- try:
408
- _ = model.get_layer(name)
409
- valid.append(name)
410
- except Exception:
411
- pass
412
- if len(valid) == 0:
413
- return None, "No valid layer names found. Use exact layer names from the graph or summary."
414
-
415
- act_model = build_activation_model(model, valid)
416
- activations = compute_activations(act_model, uploaded_image)
417
- # build previews: for each activation, create a montage
418
- previews = []
419
- for act in activations:
420
- if act.ndim == 3:
421
- # H,W,C -> show first up to 12 channels in a grid
422
- C = act.shape[2]
423
- nshow = min(12, C)
424
- # normalize each channel
425
- imgs = []
426
- for i in range(nshow):
427
- ch = act[:, :, i]
428
- mn, mx = ch.min(), ch.max()
429
- if mx - mn > 1e-6:
430
- chn = (ch - mn) / (mx - mn)
431
- else:
432
- chn = np.zeros_like(ch)
433
- im = Image.fromarray((chn * 255).astype("uint8")).resize((128, 128))
434
- imgs.append(im.convert("RGB"))
435
- # make grid 3x4
436
- cols = 4
437
- rows = math.ceil(len(imgs) / cols)
438
- w = cols * 128
439
- h = rows * 128
440
- collage = Image.new("RGB", (w, h), color=(0, 0, 0))
441
- x = y = 0
442
- for idx, im in enumerate(imgs):
443
- collage.paste(im, (x * 128, y * 128))
444
- x += 1
445
- if x >= cols:
446
- x = 0
447
- y += 1
448
- previews.append(collage)
449
- else:
450
- # vector -> show as small bar chart image
451
- vec = np.array(act).flatten()
452
- # scale to 0..255
453
- if vec.size > 0:
454
- mn, mx = vec.min(), vec.max()
455
- if mx - mn > 0:
456
- v = (vec - mn) / (mx - mn)
457
- else:
458
- v = np.zeros_like(vec)
459
- else:
460
- v = vec
461
- # make a simple plot image as grayscale
462
- arr = (v.reshape(1, -1) * 255).astype("uint8")
463
- im = Image.fromarray(arr).resize((512, 128)).convert("RGB")
464
- previews.append(im)
465
- return previews, "OK"
466
- except Exception as e:
467
- return None, f"Activation error: {e}\n{traceback.format_exc()}"
468
-
469
-
470
- # -------------------- Build UI (Gradio 5 compatible) --------------------
471
 
472
  with gr.Blocks() as demo:
473
- gr.Markdown("# 🔎 TensorFlow Computation Graph Visualizer (Advanced)\n"
474
- "Load a Keras `.h5` model or pick an example. Click nodes to inspect layers, view weights and activations.")
475
 
476
  with gr.Row():
477
  with gr.Column(scale=1):
478
- model_file = gr.File(label="Upload Keras model (.h5)", file_types=[".h5"])
479
- example_choice = gr.Dropdown(["small_cnn", "toy_resnet"], value="small_cnn", label="Or pick an example model")
480
  load_btn = gr.Button("Load model")
481
- summary_box = gr.Textbox(label="Model Summary", lines=12)
482
- total_params_box = gr.Textbox(label="Total Parameters", lines=1)
483
- error_box = gr.Markdown()
484
 
485
  with gr.Column(scale=2):
486
- graph_plot = gr.Plot(label="Computation Graph (click a node to inspect)")
487
- node_info = gr.Markdown("Click a node to inspect its details here.")
488
- weights_img = gr.Image(label="Weights preview (conv filters or hist)")
489
- weights_hist = gr.Plot(label="Weights histogram")
 
490
 
491
- gr.Markdown("### Activations (upload an image to see intermediate maps)")
492
- with gr.Row():
493
- with gr.Column(scale=1):
494
- act_img = gr.Image(label="Upload image for activations", type="pil")
495
- layer_names_txt = gr.Textbox(label="Layer names (comma separated) e.g. conv2d,conv2d_1", value="")
496
- act_btn = gr.Button("Compute activations")
497
- act_msg = gr.Markdown()
498
- with gr.Column(scale=2):
499
- act_preview = gr.Gallery(
500
- label="Activation previews",
501
- elem_id="act_gallery",
502
- columns=2,
503
- height="auto"
504
- )
505
- # state store for model object & nx graph
506
  state = gr.State()
507
 
508
- # load model button behavior
509
- def on_load(model_file_obj, example_choice_val):
510
- if not TF_AVAILABLE:
511
- return None, None, "", "", gr.update(visible=True, value=f"TensorFlow import failed: {TF_IMPORT_ERROR}")
512
- try:
513
- res = load_model_callback(model_file_obj, example_choice_val)
514
- if res.get("error"):
515
- return None, None, "", "", gr.update(visible=True, value=res.get("message"))
516
- model = res["model"]
517
- fig = res["graph_fig"]
518
- summary = res["summary"]
519
- total = res["total_params"]
520
- G = res["nx_graph"]
521
- st = {"model": model, "nx_graph": G}
522
- return st, fig, summary, str(total), gr.update(visible=False, value="")
523
- except Exception as e:
524
- return None, None, "", "", gr.update(visible=True, value=f"Load error: {e}\n{traceback.format_exc()}")
525
-
526
- load_btn.click(on_load, inputs=[model_file, example_choice], outputs=[state, graph_plot, summary_box, total_params_box, error_box])
527
-
528
- # when user clicks a node on the plotly graph, gradio returns event with clicked point customdata -> node name
529
- # Use gr.Plot's 'plotly_events' to capture clicks
530
- def on_node_click(evt, st):
531
- # evt is list of click events from plotly; we take first if exists
532
- if not st:
533
- return "No model loaded.", None, None
534
- try:
535
- if not evt:
536
- return "Click a node to inspect it.", None, None
537
- # evt is a list of dicts, get 'customdata'
538
- node_name = evt[0].get("customdata") or evt[0].get("pointIndex")
539
- html, wimg, hist = node_inspect_callback(st, node_name)
540
- hist_fig = None
541
- if hist is not None:
542
- # hist is a tuple (counts, bins)
543
- hist_counts, hist_bins = hist
544
- hist_fig = go.Figure(data=go.Bar(x=hist_bins[:-1].tolist(), y=hist_counts.tolist()))
545
- hist_fig.update_layout(title="Weight histogram", height=240)
546
- return html, wimg, hist_fig
547
- except Exception as e:
548
- return f"Node click error: {e}", None, None
549
-
550
- graph_plot.plotly_events(on_node_click, inputs=[gr.Plot("plotly_events"), state], outputs=[node_info, weights_img, weights_hist])
551
-
552
- # activation compute
553
- def on_compute_activations(st, uploaded_image, layer_names_txt_val):
554
- previews, msg = activation_callback(st, uploaded_image, layer_names_txt_val)
555
- if previews is None:
556
- return None, msg
557
- # convert previews to displayable list
558
- return previews, "Activations computed."
559
-
560
- act_btn.click(on_compute_activations, inputs=[state, act_img, layer_names_txt], outputs=[act_preview, act_msg])
561
-
562
- # friendly note for non-technical users
563
- with gr.Accordion("Simple explanation (for non-technical viewers)", open=False):
564
- gr.Markdown("""
565
- **Simple explanation**
566
-
567
- - Each rectangle (node) is a layer that transforms the data.
568
- - Edges show how data flows from one layer to the next.
569
- - Click any node to see what that layer does (shapes, number of parameters).
570
- - Upload an image and pick a layer to see the 'activation map' — where the network 'looks' for features.
571
- """)
572
 
573
  demo.launch()
 
3
  # - Standard TensorFlow (Keras) based
4
  # - Gradio 5 compatible (no theme=)
5
  # - CPU-friendly (disables GPU usage)
 
 
 
 
 
 
 
6
  # ==========================================================
7
 
8
  import io
 
10
  import math
11
  import traceback
12
  import warnings
13
+ from typing import Any, Dict, List, Optional
14
 
15
  import gradio as gr
16
  import numpy as np
 
20
 
21
  warnings.filterwarnings("ignore")
22
 
23
+ # Try importing tensorflow
24
  try:
25
  import tensorflow as tf
26
  from tensorflow import keras
 
27
  try:
28
  tf.config.set_visible_devices([], "GPU")
29
  except Exception:
 
33
  TF_AVAILABLE = False
34
  TF_IMPORT_ERROR = str(e)
35
 
 
36
  # -------------------- Helpers --------------------
37
 
38
  def safe_load_keras_model(fileobj: Optional[io.BytesIO], chosen: str):
 
 
 
 
39
  if not TF_AVAILABLE:
40
+ raise RuntimeError("TensorFlow not available")
41
 
42
  if fileobj:
 
43
  fileobj.seek(0)
44
  tmp_path = "/tmp/uploaded_model.h5"
45
  with open(tmp_path, "wb") as f:
46
  f.write(fileobj.read())
47
  model = keras.models.load_model(tmp_path)
48
  return model, "uploaded .h5 model"
49
+
50
+ if chosen == "small_cnn":
51
+ model = keras.Sequential([
52
+ keras.layers.InputLayer(input_shape=(64, 64, 3)),
53
+ keras.layers.Conv2D(16, 3, activation="relu", padding="same"),
54
+ keras.layers.MaxPool2D(),
55
+ keras.layers.Conv2D(32, 3, activation="relu", padding="same"),
56
+ keras.layers.MaxPool2D(),
57
+ keras.layers.Conv2D(64, 3, activation="relu", padding="same"),
58
+ keras.layers.GlobalAveragePooling2D(),
59
+ keras.layers.Dense(64, activation="relu"),
60
+ keras.layers.Dense(10, activation="softmax"),
61
+ ])
62
+ model.build((None, 64, 64, 3))
63
+ return model, "Small CNN (example)"
64
+
65
+ if chosen == "toy_resnet":
66
+ inputs = keras.Input(shape=(64, 64, 3))
67
+ x = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(inputs)
68
+ for _ in range(2):
69
+ y = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(x)
70
+ y = keras.layers.Conv2D(32, 3, padding="same")(y)
71
+ x = keras.layers.add([x, y])
72
+ x = keras.layers.ReLU()(x)
73
+ x = keras.layers.GlobalAveragePooling2D()(x)
74
+ outputs = keras.layers.Dense(5, activation="softmax")(x)
75
+ model = keras.Model(inputs, outputs)
76
+ model.build((None, 64, 64, 3))
77
+ return model, "Toy ResNet-like (example)"
78
+
79
+ return safe_load_keras_model(None, "small_cnn")
80
+
81
+
82
+ def model_summary_str(model):
 
 
 
 
 
 
 
83
  stream = io.StringIO()
84
  model.summary(print_fn=lambda s: stream.write(s + "\n"))
85
  return stream.getvalue()
86
 
 
87
  # -------------------- Graph builder --------------------
88
 
89
+ def build_layer_graph(model):
 
 
 
 
90
  G = nx.DiGraph()
91
+ for layer in model.layers:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  inbound = []
93
+ for node in getattr(layer, "_inbound_nodes", []) or []:
94
+ for l in getattr(node, "inbound_layers", []) or []:
95
+ inbound.append(l.name)
96
+ G.add_node(
97
+ layer.name,
98
+ class_name=layer.__class__.__name__,
99
+ input_shape=getattr(layer, "input_shape", None),
100
+ output_shape=getattr(layer, "output_shape", None),
101
+ params=layer.count_params(),
102
+ inbound_layers=inbound,
103
+ )
104
+ for n, d in G.nodes(data=True):
105
+ for src in d["inbound_layers"]:
106
+ if src in G:
107
+ G.add_edge(src, n)
 
 
 
 
 
108
  return G
109
 
110
 
111
+ def nx_to_plotly_fig(G):
112
+ pos = nx.spring_layout(G, seed=42)
113
+ edge_x, edge_y = [], []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  for u, v in G.edges():
115
  x0, y0 = pos[u]
116
  x1, y1 = pos[v]
117
  edge_x += [x0, x1, None]
118
  edge_y += [y0, y1, None]
119
 
120
+ node_x, node_y, labels = [], [], []
121
+ for n in G.nodes():
122
+ x, y = pos[n]
123
+ node_x.append(x)
124
+ node_y.append(y)
125
+ labels.append(n)
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
+ fig = go.Figure()
128
+ fig.add_trace(go.Scatter(x=edge_x, y=edge_y, mode="lines"))
129
+ fig.add_trace(go.Scatter(x=node_x, y=node_y, mode="markers+text", text=labels))
130
+ fig.update_layout(height=600, showlegend=False)
 
 
 
 
131
  fig.update_xaxes(visible=False)
132
  fig.update_yaxes(visible=False)
133
  return fig
134
 
135
+ # -------------------- Inspect --------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
  def node_inspect_callback(state, node_name):
 
 
 
138
  if not state:
139
  return "No model loaded.", None, None
140
+ model = state["model"]
141
+ layer = model.get_layer(node_name)
142
+ weights = layer.get_weights()
143
+ hist_fig = None
144
+ img = None
145
+
146
+ if weights:
147
+ w = weights[0]
148
+ hist = np.histogram(w.flatten(), bins=50)
149
+ hist_fig = go.Figure(go.Bar(x=hist[1][:-1], y=hist[0]))
150
+
151
+ if w.ndim == 4:
152
+ ch = w[:, :, :, 0].mean(axis=2)
153
+ ch = (ch - ch.min()) / (ch.ptp() + 1e-6)
154
+ img = Image.fromarray((ch * 255).astype("uint8")).resize((256, 256))
155
+
156
+ txt = (
157
+ f"**Layer:** {layer.name}\n\n"
158
+ f"- Type: `{layer.__class__.__name__}`\n"
159
+ f"- Input: `{layer.input_shape}`\n"
160
+ f"- Output: `{layer.output_shape}`\n"
161
+ f"- Params: `{layer.count_params()}`"
162
+ )
163
+ return txt, img, hist_fig
164
+
165
+ # -------------------- UI --------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  with gr.Blocks() as demo:
168
+ gr.Markdown("# 🔎 TensorFlow Computation Graph Visualizer")
 
169
 
170
  with gr.Row():
171
  with gr.Column(scale=1):
172
+ model_file = gr.File(label="Upload .h5")
173
+ example = gr.Dropdown(["small_cnn", "toy_resnet"], value="small_cnn")
174
  load_btn = gr.Button("Load model")
175
+ summary = gr.Textbox(lines=12)
176
+ params = gr.Textbox()
177
+ error = gr.Markdown()
178
 
179
  with gr.Column(scale=2):
180
+ graph_plot = gr.Plot()
181
+ layer_select = gr.Dropdown(label="Select layer to inspect")
182
+ node_info = gr.Markdown()
183
+ weights_img = gr.Image()
184
+ weights_hist = gr.Plot()
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  state = gr.State()
187
 
188
+ def on_load(file, ex):
189
+ model, _ = safe_load_keras_model(file, ex)
190
+ G = build_layer_graph(model)
191
+ fig = nx_to_plotly_fig(G)
192
+ return (
193
+ {"model": model, "graph": G},
194
+ fig,
195
+ model_summary_str(model),
196
+ str(model.count_params()),
197
+ "",
198
+ list(G.nodes())
199
+ )
200
+
201
+ load_btn.click(
202
+ on_load,
203
+ inputs=[model_file, example],
204
+ outputs=[state, graph_plot, summary, params, error, layer_select]
205
+ )
206
+
207
+ layer_select.change(
208
+ node_inspect_callback,
209
+ inputs=[state, layer_select],
210
+ outputs=[node_info, weights_img, weights_hist]
211
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  demo.launch()