Spaces:
Runtime error
Runtime error
import json
from collections import deque

import matplotlib.pyplot as plt
import networkx as nx
class Node:
    """A named tree/graph node with an optional value, parent and children.

    Equality and hashing depend only on ``name``, so a ``Node`` can be used as
    a networkx graph node and two nodes with the same name are interchangeable
    as dict/set keys.

    ``parent`` and the elements of ``children`` are stored exactly as given:
    ``build_tree_from_dict`` passes node *names* (strings), while
    ``add_child`` / ``set_parent`` accept any object, including ``Node``s.
    """

    def __init__(self, name: str, value=None, parent=None, children: "list | None" = None):
        self.name = name
        # `None` instead of a shared mutable `[]` default; copying into a new
        # set also keeps the node from aliasing the caller's container.
        self.children = set(children) if children is not None else set()
        self.parent = parent
        self.value = value

    def __repr__(self):
        return self.name

    def __str__(self):
        return self.name

    def __eq__(self, other):
        # hash(Node("a")) == hash("a"), so dict/set lookups that mix names and
        # Nodes compare the two; returning NotImplemented (rather than touching
        # `other.name`) lets Python fall back instead of raising AttributeError.
        if not isinstance(other, Node):
            return NotImplemented
        return self.name == other.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __getstate__(self):
        """Return the instance attribute dict (pickle/copy protocol)."""
        return self.__dict__

    @staticmethod
    def _as_name(item):
        """Return ``item.name`` if ``item`` is a Node, else ``item`` unchanged."""
        return item.name if isinstance(item, Node) else item

    def to_dict(self) -> dict:
        """Return the node's attributes as a JSON-serialisable dict.

        ``children`` becomes a list sorted by ``str`` (a set is neither JSON
        serialisable nor stably ordered). Node-valued children/parent are
        reduced to their names; ``value`` is passed through unchanged.
        """
        return {
            "name": self.name,
            "children": sorted((self._as_name(c) for c in self.children), key=str),
            "parent": self._as_name(self.parent),
            "value": self.value,
        }

    def to_json(self) -> str:
        """Return a JSON string representation of the node (built from ``to_dict``)."""
        return json.dumps(self.to_dict())

    def add_child(self, child):
        """Add ``child`` (a name or a Node) to this node's children set."""
        self.children.add(child)

    def has_children(self) -> bool:
        """Return True if the node has at least one child."""
        return bool(self.children)

    def set_parent(self, new_parent):
        """Replace the node's parent reference."""
        self.parent = new_parent

    def set_value(self, new_value):
        """Replace the node's value."""
        self.value = new_value
def read_json(fname: str) -> dict:
    """Load a ``.json`` file and return its parsed top level as a dict.

    Args:
        fname: path of a file whose name ends with ``.json``.

    Returns:
        ``dict(data)`` of the parsed JSON, so the top level is expected to be
        a JSON object (or anything ``dict()`` accepts).

    Raises:
        ValueError: if ``fname`` does not end with ``.json``, or the content
            is not valid JSON (``json.JSONDecodeError`` is a ValueError).
        OSError: if the file cannot be opened.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not fname.endswith(".json"):
        raise ValueError("File must be a json file")
    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open(fname, "r", encoding="utf-8") as f:
        data = json.load(f)
    return dict(data)
def build_tree_from_dict(data: dict, connect_children: bool = True):
    """Build an undirected networkx graph of ``Node`` objects from a dict.

    Expected layout of ``data`` (one entry per node)::

        {name: {"value": ..., "parent": parent_name, "children": [child_name, ...]}}

    Missing "value" / "parent" / "children" keys default to None / None / [].

    Args:
        data: mapping described above.
        connect_children: when True, additionally connect every pair of
            siblings (children listed by the same node) with an edge.

    Returns:
        ``(G, nodes_dict)``: the ``nx.Graph`` (each node carries a ``value``
        attribute) and a dict mapping each name to its ``Node``.

    Raises:
        KeyError: if a node lists a child name that has no entry in ``data``.
    """
    G = nx.Graph()
    nodes_dict = {}
    # Children in input order (deduplicated). Edges are added from these lists
    # rather than from Node.children, whose set iteration order is
    # hash-randomised, so graph order is reproducible across runs.
    ordered_children = {}

    for name, info in data.items():
        value = info.get("value")
        children = list(dict.fromkeys(info.get("children") or []))
        ordered_children[name] = children
        nodes_dict[name] = Node(
            name=name, parent=info.get("parent"), children=children, value=value
        )
        G.add_node(nodes_dict[name], value=value)

    for name, children in ordered_children.items():
        parent_node = nodes_dict[name]
        try:
            kids = [nodes_dict[child_name] for child_name in children]
        except KeyError as err:
            raise KeyError(
                f"node {name!r} lists child {err.args[0]!r}, which is not defined in data"
            ) from err
        for i, kid in enumerate(kids):
            G.add_edge(parent_node, kid)
            if connect_children:
                # Each unordered sibling pair once (the graph is undirected).
                for sibling in kids[i + 1:]:
                    G.add_edge(kid, sibling)
    return G, nodes_dict
def build_tree_from_file(fname: str, connect_children: bool = True):
    """Read ``fname`` with ``read_json`` and build the graph with ``build_tree_from_dict``.

    Args:
        fname: path to a ``.json`` file in the layout ``build_tree_from_dict`` expects.
        connect_children: forwarded to ``build_tree_from_dict`` (default True,
            matching the previous behaviour).

    Returns:
        Whatever ``build_tree_from_dict`` returns: ``(G, nodes_dict)``.
    """
    data = read_json(fname)
    return build_tree_from_dict(data, connect_children=connect_children)
def num_edges_between_nodes(G, node1, node2):
    """Return the number of edges on a shortest unweighted path from node1 to node2.

    Asks networkx for the path length directly instead of building the path
    and subtracting one from its node count. networkx raises if either node
    is missing from ``G`` or no path connects them.
    """
    return nx.shortest_path_length(G, node1, node2)
def explore_bfs(G: "nx.Graph", source: "Node", nodes_dict: "dict[str, Node]") -> "list[Node]":
    """Explore from ``source`` breadth-first, prioritising nodes with a value.

    Children are resolved through ``nodes_dict`` (name -> Node); ``G`` is kept
    for interface compatibility but is not read. A child whose ``value`` is
    truthy is pushed to the *front* of the queue, so it is explored before any
    already-queued node; other children go to the back. Siblings are examined
    in sorted-name order so the result is reproducible across runs
    (``children`` is a set), and each node is explored at most once, so
    shared children or cycles in the data cannot cause repeats or endless loops.

    Returns:
        The nodes in the order they were explored, starting with ``source``.

    Raises:
        KeyError: if a child name has no entry in ``nodes_dict``.
    """
    explored_nodes = []
    seen_names = set()
    queue = deque([source])  # O(1) push/pop at both ends, unlike list.pop(0)
    while queue:
        node = queue.popleft()
        if node.name in seen_names:
            continue
        seen_names.add(node.name)
        explored_nodes.append(node)
        for child_name in sorted(node.children, key=str):
            child = nodes_dict[child_name]
            if child.name in seen_names:
                continue
            if child.value:
                queue.appendleft(child)
            else:
                queue.append(child)
    return explored_nodes
def from_list(node_list: list[Node], directional=True):
    """Chain ``node_list`` into a path graph with numbered edges.

    Every element becomes a node, and each consecutive pair is joined by an
    edge whose "label" attribute is its 1-based position along the chain
    (first edge 1, second edge 2, ...).

    Args:
        node_list: sequence of nodes, in chain order.
        directional: build an ``nx.DiGraph`` when True, else an ``nx.Graph``.

    Returns:
        The constructed graph.
    """
    graph_cls = nx.DiGraph if directional else nx.Graph
    chain = graph_cls()
    chain.add_nodes_from(node_list)
    for step in range(1, len(node_list)):
        chain.add_edge(node_list[step - 1], node_list[step], label=step)
    return chain
def visualize_graph(
    graph: nx.Graph,
    layout_graph: nx.Graph,
    title="BFS Tree",
    fig_size=(30, 20),
    title_fontsize=20,
    edge_width=1,
    font_size=9,
    node_size=500,
    node_shape="o",
    prog="dot",
):
    """Draw ``graph`` on a new figure and display it with ``plt.show()``.

    The drawing is exactly what ``get_graph`` produces for the same
    arguments; this function shows the figure instead of returning it.

    Args:
        graph: graph to draw; its "label" edge attributes are drawn as edge labels.
        layout_graph: graph handed to Graphviz to compute node positions
            (positions are looked up for every node of ``graph``).
        title: axes title.
        fig_size: matplotlib figure size in inches.
        title_fontsize: font size of the title.
        edge_width: line width of edges.
        font_size: font size of node and edge labels.
        node_size: marker size used for every node (constant, not value-based).
        node_shape: matplotlib marker for nodes ("o" is a circle).
        prog: Graphviz layout program (e.g. "dot").
    """
    get_graph(
        graph,
        layout_graph,
        title=title,
        fig_size=fig_size,
        title_fontsize=title_fontsize,
        edge_width=edge_width,
        font_size=font_size,
        node_size=node_size,
        node_shape=node_shape,
        prog=prog,
    )
    plt.show()
def get_graph(
    graph: nx.Graph,
    layout_graph: nx.Graph,
    title="BFS Tree",
    fig_size=(30, 20),
    title_fontsize=20,
    edge_width=1,
    font_size=9,
    node_size=500,
    node_shape="o",
    prog="dot",
):
    """Draw ``graph`` on a new matplotlib figure and return ``(fig, ax)``.

    Node positions come from ``nx.nx_agraph.graphviz_layout`` run on
    ``layout_graph`` with program ``prog`` and ``root="root"``; positions are
    looked up for every node of ``graph``, so ``layout_graph`` should contain
    all of them. The "label" edge attribute of ``graph`` is drawn on each edge.

    Nodes are coloured by their order in ``graph.nodes``: the first is
    lightgreen, the last red, and all others lightblue.

    Args:
        graph: graph to draw.
        layout_graph: graph handed to Graphviz to compute node positions.
        title: axes title.
        fig_size: matplotlib figure size in inches.
        title_fontsize: font size of the title.
        edge_width: line width of edges.
        font_size: font size of node and edge labels.
        node_size: marker size used for every node (constant, not value-based).
        node_shape: matplotlib marker for nodes ("o" is a circle).
        prog: Graphviz layout program (e.g. "dot").

    Returns:
        The created ``(fig, ax)`` pair; the figure is not shown or closed.
    """
    graphviz_args = "-Goverlap=false -Gsplines=true -Gsep=0.1 -Gnodesep=0.1 -Gmaxiter=1000 -Gepsilon=0.0001 -Gstart=0"
    # Run the Graphviz layout once and reuse it for nodes and edge labels.
    # Computing it before creating the figure means a layout failure does not
    # leave an empty figure registered with pyplot.
    pos = nx.nx_agraph.graphviz_layout(
        layout_graph, prog=prog, root="root", args=graphviz_args
    )

    n_nodes = len(graph.nodes)
    if n_nodes <= 1:
        node_color = ["lightgreen"]
    else:
        node_color = ["lightgreen"] + ["lightblue"] * (n_nodes - 2) + ["red"]

    fig, ax = plt.subplots(figsize=fig_size)
    ax.set_title(title, fontsize=title_fontsize)
    nx.draw(
        graph,
        pos=pos,
        ax=ax,
        with_labels=True,
        node_color=node_color,
        edge_color="gray",
        width=edge_width,
        font_size=font_size,
        node_size=node_size,
        node_shape=node_shape,
    )
    nx.draw_networkx_edge_labels(
        graph,
        pos=pos,
        edge_labels=nx.get_edge_attributes(graph, "label"),
        font_size=font_size,
        ax=ax,
    )
    return fig, ax