import json
from collections import deque
from itertools import combinations

import matplotlib.pyplot as plt
import networkx as nx

# Extra Graphviz command-line arguments shared by every layout computed below.
_GRAPHVIZ_ARGS = (
    "-Goverlap=false -Gsplines=true -Gsep=0.1 -Gnodesep=0.1 "
    "-Gmaxiter=1000 -Gepsilon=0.0001 -Gstart=0"
)


class Node:
    """A named tree node with an optional value, a parent and children.

    Equality and hashing depend only on ``name``, so distinct ``Node``
    objects with the same name collapse to a single entry in sets, dicts
    and networkx graphs.
    """

    # NOTE: this class previously defined a ``__dict__`` *method*, which
    # shadowed the real instance attribute dict and broke ``to_json``,
    # ``copy`` and ``pickle``. ``to_dict()`` replaces it.

    def __init__(self, name: str, value=None, parent=None, children=None):
        """Create a node.

        Args:
            name: Identifier of the node; also used for ``str``/``repr``.
            value: Arbitrary payload, ``None`` by default.
            parent: The parent as supplied by the caller
                (``build_tree_from_dict`` passes the parent's name).
            children: Optional iterable of children
                (``build_tree_from_dict`` passes child names). It is copied
                into a new set; omitted or ``None`` means no children.
        """
        self.name = name
        # ``None`` default instead of a shared mutable ``[]`` literal.
        self.children = set() if children is None else set(children)
        self.parent = parent
        self.value = value

    def __repr__(self) -> str:
        return self.name

    def __str__(self) -> str:
        return self.name

    def __eq__(self, other):
        # Defer to the other operand for non-Node objects instead of raising
        # AttributeError on ``other.name``.
        if not isinstance(other, Node):
            return NotImplemented
        return self.name == other.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __getstate__(self):
        """Return the instance attributes as the pickle/copy state."""
        return self.__dict__

    @staticmethod
    def _name_of(item):
        """Return ``item.name`` for a ``Node``, otherwise ``item`` unchanged."""
        return item.name if isinstance(item, Node) else item

    def to_dict(self) -> dict:
        """Return a JSON-friendly dict of the node's attributes.

        Keys are ``name``, ``children``, ``parent`` and ``value``. Children
        (stored as a set) become a list sorted by ``str`` so the output is
        stable, and any child/parent given as a ``Node`` is reduced to its
        name. ``value`` is returned as-is.
        """
        return {
            "name": self.name,
            "children": sorted(
                (Node._name_of(child) for child in self.children), key=str
            ),
            "parent": Node._name_of(self.parent),
            "value": self.value,
        }

    def to_json(self) -> str:
        """Return a JSON string representation of the node (see ``to_dict``).

        Raises:
            TypeError: if ``value`` (or a name) is not JSON-serializable.
        """
        return json.dumps(self.to_dict())

    def add_child(self, child):
        """Add ``child`` to this node's set of children."""
        self.children.add(child)

    def has_children(self) -> bool:
        """Return True if the node has at least one child."""
        return len(self.children) > 0

    def set_parent(self, new_parent):
        """Replace the stored parent."""
        self.parent = new_parent

    def set_value(self, new_value):
        """Replace the stored value."""
        self.value = new_value


def read_json(fname: str) -> dict:
    """Load a JSON file and return its top-level content as a dict.

    Args:
        fname: Path to the file; must end with ``.json``.

    Raises:
        ValueError: if ``fname`` does not end with ``.json`` (explicit raise
            rather than ``assert``, which is stripped under ``python -O``).
    """
    if not fname.endswith(".json"):
        raise ValueError("File must be a json file")
    with open(fname, "r", encoding="utf-8") as f:
        data = json.load(f)
    return dict(data)


def build_tree_from_dict(data: dict, connect_children: bool = True):
    """Build an undirected networkx graph from a ``name -> info`` mapping.

    Every key of ``data`` is a node name; its value is a dict with keys
    ``"value"`` (the node's value), ``"parent"`` (the parent's name) and
    ``"children"`` (list of the children's names).

    Args:
        data: The mapping described above.
        connect_children: If true, also add an edge between every pair of
            siblings.

    Returns:
        ``(G, nodes_dict)``: an ``nx.Graph`` whose nodes are ``Node`` objects
        (each carrying a ``value`` node attribute), and a dict mapping each
        name to its ``Node``.

    Raises:
        KeyError: if an info dict lacks one of the required keys, or a
            child name has no entry of its own in ``data``.
    """
    G = nx.Graph()
    nodes_dict = {}

    # Create one Node per entry.
    for name, info in data.items():
        value = info["value"]
        parent = info["parent"]
        children = info["children"]
        node = Node(name=name, parent=parent, children=children, value=value)
        nodes_dict[name] = node
        G.add_node(node, value=value)

    # Connect every node to its children (and optionally siblings to each other).
    for node in nodes_dict.values():
        for child_name in node.children:
            G.add_edge(node, nodes_dict[child_name])
        if connect_children:
            # Each unordered sibling pair once; self-pairs are never produced.
            for first, second in combinations(node.children, 2):
                G.add_edge(nodes_dict[first], nodes_dict[second])

    return G, nodes_dict


def build_tree_from_file(fname: str, connect_children: bool = True):
    """Read ``fname`` with ``read_json`` and build the graph from it.

    Returns the same ``(G, nodes_dict)`` pair as ``build_tree_from_dict``;
    ``connect_children`` is forwarded to it.
    """
    data = read_json(fname)
    return build_tree_from_dict(data, connect_children=connect_children)


def num_edges_between_nodes(G, node1, node2) -> int:
    """Return the number of edges on an unweighted shortest path node1 -> node2.

    Raises whatever ``nx.shortest_path_length`` raises, e.g.
    ``nx.NetworkXNoPath`` when the nodes are not connected.
    """
    return nx.shortest_path_length(G, node1, node2)


def explore_bfs(G: nx.Graph, source: Node, nodes_dict: dict[str, Node]) -> list[Node]:
    """Explore the tree from ``source`` and return nodes in visit order.

    Children are looked up by name in ``nodes_dict``. A child whose value is
    truthy is pushed to the *front* of the queue (explored next); any other
    child goes to the back. Children of one node are visited in set
    iteration order. Each node is explored at most once, so shared children
    or cycles cannot cause repeats or an endless loop.

    Args:
        G: Unused; kept for backward compatibility.
        source: Starting node.
        nodes_dict: Mapping from node name to ``Node``.

    Returns:
        The explored nodes, in exploration order.
    """
    explored_nodes: list[Node] = []
    explored: set[Node] = set()
    # deque gives O(1) pushes/pops at both ends (list.pop(0) is O(n)).
    queue = deque([source])
    while queue:
        node = queue.popleft()
        if node in explored:
            continue
        explored.add(node)
        explored_nodes.append(node)
        for child_name in node.children:
            child = nodes_dict[child_name]
            if child.value:
                queue.appendleft(child)
            else:
                queue.append(child)
    return explored_nodes


def from_list(node_list: list[Node], directional=True):
    """Chain ``node_list`` into a path graph with numbered edges.

    Consecutive nodes are joined by an edge whose ``label`` attribute runs
    from 1 (first -> second) to ``len(node_list) - 1``.

    Args:
        node_list: Nodes in path order.
        directional: Build an ``nx.DiGraph`` if true, else an ``nx.Graph``.
    """
    G = nx.DiGraph() if directional else nx.Graph()
    G.add_nodes_from(node_list)
    for label, (src, dst) in enumerate(zip(node_list, node_list[1:]), start=1):
        G.add_edge(src, dst, label=label)
    return G


def _node_colors(num_nodes: int) -> list:
    """Color the first node light green, the last red, and the rest light blue."""
    if num_nodes < 2:
        return ["lightgreen"]
    return ["lightgreen"] + ["lightblue"] * (num_nodes - 2) + ["red"]


def _draw_graph(
    graph,
    layout_graph,
    *,
    title,
    fig_size,
    title_fontsize,
    edge_width,
    font_size,
    node_size,
    node_shape,
    prog,
    root,
):
    """Draw ``graph`` on a new figure and return ``(fig, ax)``.

    Shared implementation of ``visualize_graph`` and ``get_graph``.
    """
    fig, ax = plt.subplots(figsize=fig_size)
    ax.set_title(title, fontsize=title_fontsize)
    # Compute the Graphviz layout once so nodes and edge labels share the same
    # positions (and Graphviz only runs once).
    # NOTE(review): assumes layout_graph contains every node of graph — confirm.
    pos = nx.nx_agraph.graphviz_layout(
        layout_graph, prog=prog, root=root, args=_GRAPHVIZ_ARGS
    )
    nx.draw(
        graph,
        pos=pos,
        ax=ax,
        with_labels=True,
        # Colors follow graph.nodes order: first green, last red.
        node_color=_node_colors(len(graph.nodes)),
        edge_color="gray",
        width=edge_width,
        font_size=font_size,
        node_size=node_size,
        node_shape=node_shape,
    )
    # Edge labels come from the "label" attribute (as set by from_list).
    nx.draw_networkx_edge_labels(
        graph,
        pos=pos,
        edge_labels=nx.get_edge_attributes(graph, "label"),
        font_size=font_size,
        ax=ax,
    )
    return fig, ax


def visualize_graph(
    graph: nx.Graph,
    layout_graph: nx.Graph,
    title="BFS Tree",
    fig_size=(30, 20),
    title_fontsize=20,
    edge_width=1,
    font_size=9,
    node_size=500,
    node_shape="o",
    prog="dot",
    root="root",
):
    """Draw ``graph`` with positions laid out from ``layout_graph`` and show it.

    The first node (in ``graph.nodes`` order) is light green, the last is
    red, the others light blue; edges are labelled with their ``label``
    attribute. ``prog`` and ``root`` are passed to
    ``nx.nx_agraph.graphviz_layout``. The figure is displayed with
    ``plt.show()``.
    """
    _draw_graph(
        graph,
        layout_graph,
        title=title,
        fig_size=fig_size,
        title_fontsize=title_fontsize,
        edge_width=edge_width,
        font_size=font_size,
        node_size=node_size,
        node_shape=node_shape,
        prog=prog,
        root=root,
    )
    plt.show()


def get_graph(
    graph: nx.Graph,
    layout_graph: nx.Graph,
    title="BFS Tree",
    fig_size=(30, 20),
    title_fontsize=20,
    edge_width=1,
    font_size=9,
    node_size=500,
    node_shape="o",
    prog="dot",
    root="root",
):
    """Like ``visualize_graph`` but return ``(fig, ax)`` instead of showing it."""
    return _draw_graph(
        graph,
        layout_graph,
        title=title,
        fig_size=fig_size,
        title_fontsize=title_fontsize,
        edge_width=edge_width,
        font_size=font_size,
        node_size=node_size,
        node_shape=node_shape,
        prog=prog,
        root=root,
    )