import math
import os
import pickle
import random

import networkx as nx
import requests
import streamlit as st
from huggingface_hub import hf_hub_download
from pyvis.network import Network

# Dictionary to map brands to their respective HuggingFace model repo files
BRAND_GRAPHS = {
    'drumeo': 'drumeo_graph.pkl',
    'pianote': 'pianote_graph.pkl',
    'singeo': 'singeo_graph.pkl',
    'guitareo': 'guitareo_graph.pkl'
}

# Token used both for the HuggingFace Hub download and the rank_items API.
AUTH_TOKEN = os.getenv('HF_TOKEN')
API_URL = "https://MusoraProductDepartment-PWGenerator.hf.space/rank_items/"
# Seconds to wait for the rank_items API before giving up (avoids hanging the app).
API_TIMEOUT = 30
# The default start node is drawn at random from this many most-connected nodes.
TOP_START_CANDIDATES = 20


@st.cache_resource
def _download_graph(brand):
    """Download and unpickle the graph for *brand* from HuggingFace Hub.

    Raises on any failure instead of returning a sentinel, because
    ``st.cache_resource`` does not cache calls that raise; a transient
    download error is therefore retried on the next rerun instead of being
    cached forever.
    """
    repo_id = f'MusoraProductDepartment/{brand}-graph'
    file_path = hf_hub_download(repo_id=repo_id, filename=BRAND_GRAPHS[brand],
                                token=AUTH_TOKEN, cache_dir='/tmp', repo_type='model')
    # SECURITY: pickle.load can execute arbitrary code. This is only acceptable
    # because the file comes from the organisation's own repo; keep that repo
    # write-protected.
    with open(file_path, 'rb') as f:
        return pickle.load(f)


def load_graph_from_hf(brand):
    """Load the graph for the selected brand from HuggingFace Hub.

    Returns the unpickled graph, or ``None`` (after showing an error in the
    UI) if the brand is unknown or the download/unpickling fails.
    """
    try:
        return _download_graph(brand)
    except Exception as e:  # UI boundary: report any failure to the user.
        st.error(f"Error loading graph from HuggingFace: {e}")
        return None


def filter_graph(graph, node_threshold=10, edge_threshold=5):
    """Return a copy of *graph* restricted to popular nodes and edges.

    Nodes with total degree below *node_threshold* are dropped, then edges
    whose ``weight`` attribute (missing counts as 0) is below
    *edge_threshold* are removed. The input graph is not modified.
    """
    popular_nodes = [node for node in graph.nodes if graph.degree(node) >= node_threshold]
    filtered_graph = graph.subgraph(popular_nodes).copy()
    weak_edges = [
        (u, v)
        for u, v, weight in filtered_graph.edges(data="weight", default=0)
        if weight < edge_threshold
    ]
    filtered_graph.remove_edges_from(weak_edges)
    return filtered_graph


def get_rankings_from_api(brand, user_id, content_ids, timeout=API_TIMEOUT):
    """Call the rank_items API to fetch rankings for a user and content IDs.

    Args:
        brand: Brand key; sent upper-cased.
        user_id: Student ID; converted with ``int()``.
        content_ids: Iterable of content IDs; each converted with ``int()``.
        timeout: Seconds to wait for the HTTP request.

    Returns:
        The decoded JSON response, or ``{}`` (after showing an error in the
        UI) if the IDs are not numeric or the request/decoding fails.
    """
    try:
        payload = {
            "brand": brand.upper(),
            "user_id": int(user_id),
            "content_ids": [int(content_id) for content_id in content_ids]
        }
        headers = {
            "Authorization": f"Bearer {AUTH_TOKEN}",
            "accept": "application/json",
            "Content-Type": "application/json"
        }
        response = requests.post(API_URL, json=payload, headers=headers, timeout=timeout)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError, TypeError) as e:
        # ValueError/TypeError cover non-numeric IDs and undecodable JSON.
        st.error(f"Error calling rank_items API: {e}")
        return {}


def rank_to_color(rank, max_rank):
    """Map a 0-based rank to a grayscale CSS color.

    Dark gray means high relevance (low rank), light gray means low
    relevance (high rank). Ranks above *max_rank*, or any rank when there
    are no rankings (``max_rank <= 0``), get a very light "unranked" gray.
    """
    if max_rank <= 0 or rank > max_rank:  # Handle items without ranking
        return "#E8E8E8"  # Very light gray for unranked items
    intensity = int(55 + (rank / max_rank) * 200)  # Darker for lower ranks
    return f"rgb({intensity}, {intensity}, {intensity})"  # Grayscale


def _rank_lookup(rankings):
    """Map each ranked content ID (as ``str``) to its first 0-based position."""
    # The API request sends IDs as ints while graph nodes are compared as
    # str, so normalise here; setdefault keeps the first position like list.index.
    lookup = {}
    for position, content_id in enumerate(rankings or []):
        lookup.setdefault(str(content_id), position)
    return lookup


def _degree_color(in_degree, out_degree):
    """Pick a fill color: red if in-degree dominates, green if out-degree dominates."""
    if in_degree > out_degree * 1.5:
        return 'red'
    if out_degree > in_degree * 1.5:
        return 'green'
    return 'lightblue'


def _add_node(net, graph, node, show_titles, rank_lookup, max_rank, background=None):
    """Add *node* to the pyvis network with size, colors and tooltip.

    Size grows with total degree. Without an explicit *background* the fill
    comes from ``_degree_color``. The border reflects the node's rank when
    rankings exist, otherwise it matches the fill.
    """
    title = str(graph.nodes[node].get('title') or 'No title available')
    in_degree = graph.in_degree(node)
    out_degree = graph.out_degree(node)
    if background is None:
        background = _degree_color(in_degree, out_degree)

    rank = rank_lookup.get(node)
    if rank_lookup:
        border = rank_to_color(max_rank + 1 if rank is None else rank, max_rank)
    else:
        border = background
    rank_text = "N/A" if rank is None else rank

    label = f"{node}: {title[:15]}..." if show_titles else node
    net.add_node(
        node,
        label=label,
        color={"background": background, "border": border},
        title=f"{title}, In-degree: {in_degree}, Out-degree: {out_degree}, Rank: {rank_text}",
        size=(in_degree + out_degree) * 0.15,
        borderWidth=3,
        borderWidthSelected=6
    )


def dynamic_visualize_graph(graph, start_node, layers=3, top_k=5, show_titles=False, rankings=None):
    """Render a breadth-first expansion of *graph* from *start_node* in Streamlit.

    Each layer follows at most *top_k* heaviest outgoing edges of every node
    placed in the previous layer, for *layers* layers. The start node is
    filled dark blue. When *rankings* (an ordered list of content IDs) is
    given, node borders are shaded by rank.
    """
    net = Network(notebook=False, width="100%", height="600px", directed=True)
    net.set_options("""
    var options = {
      "physics": {
        "barnesHut": {
          "gravitationalConstant": -15000,
          "centralGravity": 0.8
        }
      }
    }
    """)

    start = str(start_node)
    rank_lookup = _rank_lookup(rankings)
    max_rank = len(rankings) if rankings else 0

    _add_node(net, graph, start, show_titles, rank_lookup, max_rank, background='darkblue')
    visited_nodes = {start}
    added_edges = set()
    current_nodes = [start]

    for _ in range(layers):
        next_nodes = []
        for node in current_nodes:
            neighbors = sorted(
                ((str(neighbor), data.get('weight', 0)) for neighbor, data in graph[node].items()),
                key=lambda item: item[1],
                reverse=True
            )[:top_k]
            for neighbor, weight in neighbors:
                if neighbor not in visited_nodes:
                    _add_node(net, graph, neighbor, show_titles, rank_lookup, max_rank)
                    visited_nodes.add(neighbor)
                    # Only newly placed nodes are expanded in the next layer.
                    next_nodes.append(neighbor)
                # Edges to already-placed nodes are still drawn so cross-links
                # show; added_edges prevents drawing the same edge twice.
                edge = (node, neighbor)
                if edge not in added_edges:
                    edge_width = math.log(weight + 1) * 8
                    net.add_edge(node, neighbor, label=f"w:{weight}", width=edge_width, color='lightgray')
                    added_edges.add(edge)
        current_nodes = next_nodes

    html_content = net.generate_html()
    st.components.v1.html(html_content, height=600, scrolling=False)


def main():
    """Render the Streamlit UI for exploring popular paths in a brand graph."""
    st.title("Popular Path Expansion + Personalization")

    # Brand Selection
    selected_brand = st.selectbox("Select a brand:", options=list(BRAND_GRAPHS.keys()))
    G = load_graph_from_hf(selected_brand)
    if G is None:
        st.stop()  # load_graph_from_hf has already shown the error.
    if G.number_of_nodes() == 0:
        st.error("The graph for this brand has no nodes.")
        st.stop()

    # On a brand change, pick a fresh default start node among the most
    # popular nodes (in-degree + out-degree).
    if st.session_state.get("selected_brand") != selected_brand or "start_node" not in st.session_state:
        st.session_state.selected_brand = selected_brand
        popular_nodes = sorted(G.nodes, key=lambda n: G.in_degree(n) + G.out_degree(n), reverse=True)
        st.session_state.start_node = random.choice(popular_nodes[:TOP_START_CANDIDATES])

    # Random Selection Button
    if st.button("Random Selection"):
        st.session_state.start_node = random.choice(list(G.nodes))

    start_node = st.text_input(
        "Enter the starting node ID:",
        value=str(st.session_state.start_node)
    ).strip()
    try:
        int(start_node)
    except ValueError:
        st.error("Please enter a valid numeric content ID.")
        st.stop()

    # Input: Student ID (optional). Invalid input disables personalization
    # instead of crashing the app.
    student_id = st.text_input("Enter a student ID (optional):", value="").strip()
    user_id = None
    if student_id:
        try:
            user_id = int(student_id)
        except ValueError:
            st.error("Please enter a valid numeric student ID.")

    # Toggle for showing content titles
    show_titles = st.checkbox("Show content titles", value=False)

    # Filter the graph
    node_degree_threshold = 1
    edge_weight_threshold = 1
    G_filtered = filter_graph(G, node_threshold=node_degree_threshold, edge_threshold=edge_weight_threshold)

    layers = st.slider("Depth to explore:", 1, 6, value=3)
    top_k = st.slider("Branching factor (per node):", 1, 6, value=3)

    if st.button("Expand Graph"):
        if start_node not in G_filtered:
            st.error("The starting node is not in the graph!")
            return
        # Rankings are fetched only when rendering, not on every rerun.
        rankings = []
        if user_id is not None:
            response = get_rankings_from_api(selected_brand, user_id, list(G_filtered.nodes))
            if isinstance(response, dict):
                rankings = response.get('ranked_content_ids') or []
        dynamic_visualize_graph(G_filtered, start_node, layers=layers, top_k=top_k,
                                show_titles=show_titles, rankings=rankings)


if __name__ == "__main__":
    main()