# ==========================================================
# TensorFlow Computation Graph Visualizer (Advanced)
# - Standard TensorFlow (Keras) based
# - Gradio 5 compatible (no theme=)
# - CPU-friendly (disables GPU usage)
# ==========================================================

import io
import math
import os
import tempfile
import traceback
import warnings
from typing import Any, Dict, List, Optional

import gradio as gr
import networkx as nx
import numpy as np
import plotly.graph_objects as go
from PIL import Image

warnings.filterwarnings("ignore")

# Hide GPUs from TensorFlow before it is imported (set_visible_devices below is a
# second, best-effort guard). setdefault keeps any value the user already exported.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "-1")

# Try importing tensorflow; the UI still starts without it and reports the error.
TF_IMPORT_ERROR: Optional[str] = None
try:
    import tensorflow as tf
    from tensorflow import keras

    try:
        tf.config.set_visible_devices([], "GPU")
    except Exception:
        # Deliberately best-effort: this fails if devices were already
        # initialised; the app keeps working either way.
        pass
    TF_AVAILABLE = True
except Exception as e:
    TF_AVAILABLE = False
    TF_IMPORT_ERROR = str(e)


# -------------------- Helpers --------------------

# Leading bytes of a zip archive; the native ``.keras`` format is a zip file.
_ZIP_MAGIC = b"PK\x03\x04"


def _load_uploaded_model(fileobj: Any):
    """Load a Keras model from whatever a Gradio ``File`` component produced.

    Accepts a filesystem path (``str`` / ``os.PathLike``), an object whose
    ``.name`` is an existing file path (temp-file wrappers), or a binary
    file-like object such as ``io.BytesIO`` (copied to a unique temp file).

    NOTE(security): loading a model file may deserialize custom/Lambda layers;
    only load models from sources you trust.

    Raises:
        TypeError: if ``fileobj`` is none of the supported kinds.
    """
    if isinstance(fileobj, (str, os.PathLike)):
        return keras.models.load_model(os.fspath(fileobj))

    name = getattr(fileobj, "name", None)
    if isinstance(name, str) and os.path.isfile(name):
        return keras.models.load_model(name)

    if not hasattr(fileobj, "read"):
        raise TypeError(f"Unsupported upload object: {type(fileobj).__name__}")

    if hasattr(fileobj, "seek"):
        fileobj.seek(0)
    data = fileobj.read()
    # Keras chooses the loader from the file extension, so match it to the
    # content: zip archive -> .keras, otherwise the legacy .h5 default.
    suffix = ".keras" if data[:4] == _ZIP_MAGIC else ".h5"
    # A unique temp file avoids clobbering between concurrent users and works
    # on systems without /tmp.
    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
        return keras.models.load_model(tmp_path)
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            pass


def safe_load_keras_model(fileobj: Optional[Any], chosen: str):
    """Return ``(model, description)`` for an uploaded file or a built-in example.

    Args:
        fileobj: Upload from the Gradio ``File`` component (file path, temp-file
            wrapper, or binary file-like object). Falsy means "use ``chosen``".
        chosen: ``"small_cnn"`` or ``"toy_resnet"``; any other value falls back
            to ``"small_cnn"``.

    Raises:
        RuntimeError: if TensorFlow could not be imported.
    """
    if not TF_AVAILABLE:
        raise RuntimeError(f"TensorFlow not available: {TF_IMPORT_ERROR}")

    if fileobj:
        return _load_uploaded_model(fileobj), "uploaded .h5 model"

    if chosen == "small_cnn":
        # keras.Input works on both tf.keras 2.x and Keras 3 (where
        # InputLayer(input_shape=...) is deprecated) and builds the model at once.
        model = keras.Sequential([
            keras.Input(shape=(64, 64, 3)),
            keras.layers.Conv2D(16, 3, activation="relu", padding="same"),
            keras.layers.MaxPool2D(),
            keras.layers.Conv2D(32, 3, activation="relu", padding="same"),
            keras.layers.MaxPool2D(),
            keras.layers.Conv2D(64, 3, activation="relu", padding="same"),
            keras.layers.GlobalAveragePooling2D(),
            keras.layers.Dense(64, activation="relu"),
            keras.layers.Dense(10, activation="softmax"),
        ])
        return model, "Small CNN (example)"

    if chosen == "toy_resnet":
        inputs = keras.Input(shape=(64, 64, 3))
        x = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(inputs)
        # Two residual blocks: conv -> conv, then add the skip connection.
        for _ in range(2):
            y = keras.layers.Conv2D(32, 3, padding="same", activation="relu")(x)
            y = keras.layers.Conv2D(32, 3, padding="same")(y)
            x = keras.layers.add([x, y])
            x = keras.layers.ReLU()(x)
        x = keras.layers.GlobalAveragePooling2D()(x)
        outputs = keras.layers.Dense(5, activation="softmax")(x)
        # Functional models are built on construction; no explicit build() needed.
        return keras.Model(inputs, outputs), "Toy ResNet-like (example)"

    return safe_load_keras_model(None, "small_cnn")


def model_summary_str(model) -> str:
    """Capture ``model.summary()`` output as a newline-terminated string."""
    lines: List[str] = []
    # **_ absorbs extra keyword arguments some Keras versions pass to print_fn.
    model.summary(print_fn=lambda line, **_: lines.append(line))
    return "".join(f"{line}\n" for line in lines)


def _layer_shape(layer, kind: str):
    """Best-effort ``kind`` ("input" or "output") shape of ``layer``.

    Tries ``layer.<kind>_shape`` first. That attribute is missing in Keras 3
    and may raise in tf.keras 2.x. If it fails, falls back to the ``.shape`` of
    ``layer.<kind>``. Returns ``None`` when neither is available.
    """
    try:
        return getattr(layer, f"{kind}_shape")
    except (AttributeError, RuntimeError, ValueError):
        pass
    try:
        tensors = getattr(layer, kind)
    except (AttributeError, RuntimeError, ValueError):
        return None
    if isinstance(tensors, (list, tuple)):
        return [getattr(t, "shape", None) for t in tensors]
    return getattr(tensors, "shape", None)


def _iter_leaves(obj):
    """Yield the non-``None`` leaves of a nested list/tuple/dict structure."""
    if isinstance(obj, dict):
        obj = list(obj.values())
    if isinstance(obj, (list, tuple)):
        for item in obj:
            yield from _iter_leaves(item)
    elif obj is not None:
        yield obj


def _inbound_layer_names(layer) -> List[str]:
    """Names of the layers feeding ``layer``, deduplicated, in discovery order.

    Prefers ``node.parent_nodes``, whose parents carry the layer as
    ``.operation`` (Keras 3) or ``.layer`` (tf.keras 2.x). Falls back to
    ``node.inbound_layers``, which may be a single layer rather than a list,
    hence the flattening.
    """
    names: List[str] = []
    for node in getattr(layer, "_inbound_nodes", None) or []:
        parents = getattr(node, "parent_nodes", None)
        if parents is not None:
            sources = []
            for parent in parents:
                op = getattr(parent, "operation", None)
                if op is None:
                    op = getattr(parent, "layer", None)
                sources.append(op)
        else:
            sources = getattr(node, "inbound_layers", None)
        for src in _iter_leaves(sources):
            src_name = getattr(src, "name", None)
            if src_name is not None and src_name not in names:
                names.append(src_name)
    return names


# -------------------- Graph builder --------------------
def build_layer_graph(model) -> nx.DiGraph:
    """Build a directed layer graph: one node per ``model.layers`` entry.

    Edges point from each inbound layer to its consumer. Inbound layers that
    are not in ``model.layers`` get no edge (e.g. an implicit input layer).
    """
    G = nx.DiGraph()
    for layer in model.layers:
        G.add_node(
            layer.name,
            class_name=layer.__class__.__name__,
            input_shape=_layer_shape(layer, "input"),
            output_shape=_layer_shape(layer, "output"),
            params=layer.count_params(),
            inbound_layers=_inbound_layer_names(layer),
        )

    edges = [
        (src, name)
        for name, data in G.nodes(data=True)
        for src in data["inbound_layers"]
        if src in G
    ]
    G.add_edges_from(edges)
    return G


def nx_to_plotly_fig(G) -> go.Figure:
    """Render ``G`` as a Plotly figure with a deterministic spring layout."""
    pos = nx.spring_layout(G, seed=42)

    edge_x: List[Optional[float]] = []
    edge_y: List[Optional[float]] = []
    for u, v in G.edges():
        (x0, y0), (x1, y1) = pos[u], pos[v]
        # None breaks the line so each edge is drawn as a separate segment.
        edge_x += [x0, x1, None]
        edge_y += [y0, y1, None]

    names = list(G.nodes())
    node_x = [pos[n][0] for n in names]
    node_y = [pos[n][1] for n in names]
    hover = [
        f"{n}<br>{G.nodes[n].get('class_name', '')}"
        f"<br>params: {G.nodes[n].get('params', '')}"
        for n in names
    ]

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=edge_x, y=edge_y, mode="lines", hoverinfo="skip"))
    fig.add_trace(go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers+text",
        text=names,
        textposition="top center",
        hovertext=hover,
        hoverinfo="text",
    ))
    fig.update_layout(height=600, showlegend=False)
    fig.update_xaxes(visible=False)
    fig.update_yaxes(visible=False)
    return fig


# -------------------- Inspect --------------------
def node_inspect_callback(state, node_name):
    """Describe one layer and visualise its first weight tensor.

    Returns ``(markdown, image_or_None, histogram_figure_or_None)``. The
    histogram covers ``get_weights()[0]``. The image is only produced for 4-D
    kernels: it shows output channel 0 averaged over axis 2, min-max
    normalised to 0..255.
    """
    if not state:
        return "No model loaded.", None, None
    if not node_name:
        # The dropdown is cleared on every (re)load, which fires a change event.
        return "Select a layer to inspect.", None, None

    model = state["model"]
    try:
        layer = model.get_layer(node_name)
    except ValueError:
        return f"Layer `{node_name}` not found in the loaded model.", None, None

    weights = layer.get_weights()
    hist_fig = None
    img = None
    if weights:
        w = np.asarray(weights[0])
        counts, bin_edges = np.histogram(w.ravel(), bins=50)
        hist_fig = go.Figure(go.Bar(x=bin_edges[:-1], y=counts))
        if w.ndim == 4:
            ch = w[:, :, :, 0].mean(axis=2)
            # np.ptp: the ndarray.ptp method was removed in NumPy 2.0.
            ch = (ch - ch.min()) / (np.ptp(ch) + 1e-6)
            img = Image.fromarray((ch * 255).astype("uint8")).resize((256, 256))

    txt = (
        f"**Layer:** {layer.name}\n\n"
        f"- Type: `{layer.__class__.__name__}`\n"
        f"- Input: `{_layer_shape(layer, 'input')}`\n"
        f"- Output: `{_layer_shape(layer, 'output')}`\n"
        f"- Params: `{layer.count_params()}`"
    )
    return txt, img, hist_fig


# -------------------- UI --------------------
with gr.Blocks() as demo:
    gr.Markdown("# 🔎 TensorFlow Computation Graph Visualizer")

    with gr.Row():
        with gr.Column(scale=1):
            model_file = gr.File(label="Upload .h5")
            example = gr.Dropdown(
                ["small_cnn", "toy_resnet"], value="small_cnn", label="Example model"
            )
            load_btn = gr.Button("Load model")
            summary = gr.Textbox(lines=12, label="Model summary")
            params = gr.Textbox(label="Total parameters")
            error = gr.Markdown()

        with gr.Column(scale=2):
            graph_plot = gr.Plot()
            layer_select = gr.Dropdown(label="Select layer to inspect")
            node_info = gr.Markdown()
            weights_img = gr.Image()
            weights_hist = gr.Plot()

    state = gr.State()

    def on_load(file, ex):
        """Load a model and populate every output.

        On failure, show the error in the ``error`` Markdown instead of raising.
        """
        try:
            model, _ = safe_load_keras_model(file, ex)
            G = build_layer_graph(model)
            fig = nx_to_plotly_fig(G)
        except Exception as exc:
            details = traceback.format_exc()
            return (
                None,
                None,
                "",
                "",
                f"**Error loading model:** {exc}\n\n```\n{details}\n```",
                gr.update(choices=[], value=None),
            )
        return (
            {"model": model, "graph": G},
            fig,
            model_summary_str(model),
            str(model.count_params()),
            "",
            # Update the dropdown's choices; a bare list would set its value.
            gr.update(choices=list(G.nodes()), value=None),
        )

    load_btn.click(
        on_load,
        inputs=[model_file, example],
        outputs=[state, graph_plot, summary, params, error, layer_select],
    )

    layer_select.change(
        node_inspect_callback,
        inputs=[state, layer_select],
        outputs=[node_info, weights_img, weights_hist],
    )


if __name__ == "__main__":
    demo.launch()