import streamlit as st
import networkx as nx
from pyvis.network import Network
import pickle
import math
import random
import requests
import os
from huggingface_hub import hf_hub_download
# Pickled graph filename for every supported brand. All brands share the
# "<brand>_graph.pkl" naming pattern; the HuggingFace repo itself is resolved
# per brand inside load_graph_from_hf.
BRAND_GRAPHS = {
    brand: f"{brand}_graph.pkl"
    for brand in ("drumeo", "pianote", "singeo", "guitareo")
}

# Token read from the HF_TOKEN environment variable (None when it is unset).
# It is passed to the hub download and sent as the Bearer token to the API.
AUTH_TOKEN = os.environ.get("HF_TOKEN")

# Endpoint that ranks content IDs for a given user.
API_URL = "https://MusoraProductDepartment-PWGenerator.hf.space/rank_items/"
@st.cache_resource
def _fetch_graph_from_hf(brand):
    """
    Download and unpickle the graph for ``brand``, raising on any failure.

    Only successful loads are cached: an exception raised here propagates to
    the caller and is not stored by ``st.cache_resource``, so a transient
    network error does not pin a ``None`` result for the brand.
    """
    repo_id = f'MusoraProductDepartment/{brand}-graph'
    file_path = hf_hub_download(
        repo_id=repo_id,
        filename=BRAND_GRAPHS[brand],
        token=AUTH_TOKEN,
        cache_dir='/tmp',
        repo_type='model',
    )
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # This is only acceptable while the repository above is fully trusted.
    with open(file_path, 'rb') as f:
        return pickle.load(f)


def load_graph_from_hf(brand):
    """
    Load the graph for the selected brand from HuggingFace Hub.

    Args:
        brand: Key of ``BRAND_GRAPHS`` naming the brand to load.

    Returns:
        The unpickled graph object, or ``None`` when the download or the
        unpickling fails (the error is reported through ``st.error``).
    """
    try:
        return _fetch_graph_from_hf(brand)
    except Exception as e:
        # UI boundary: report the problem and let the caller decide what to do.
        st.error(f"Error loading graph from HuggingFace: {e}")
        return None
def filter_graph(graph, node_threshold=10, edge_threshold=5):
    """
    Return a pruned copy of ``graph`` keeping only well-connected nodes and
    sufficiently heavy edges.

    Nodes whose degree is below ``node_threshold`` are dropped first; then,
    within the remaining subgraph, every edge whose "weight" attribute
    (treated as 0 when absent) is below ``edge_threshold`` is removed.
    The input graph is left untouched.
    """
    kept_nodes = []
    for candidate in graph.nodes:
        if graph.degree(candidate) >= node_threshold:
            kept_nodes.append(candidate)

    pruned = graph.subgraph(kept_nodes).copy()

    # Collect the weak edges before removing anything so the edge view is
    # never mutated while it is being iterated.
    weak_edges = [
        (src, dst)
        for src, dst, attrs in pruned.edges(data=True)
        if attrs.get("weight", 0) < edge_threshold
    ]
    for src, dst in weak_edges:
        pruned.remove_edge(src, dst)

    return pruned
def get_rankings_from_api(brand, user_id, content_ids):
    """
    Call the rank_items API to fetch rankings for the given user and content IDs.

    Args:
        brand: Brand name; sent upper-cased.
        user_id: User identifier; converted with ``int()``.
        content_ids: Iterable of content identifiers; each converted with ``int()``.

    Returns:
        The decoded JSON response, or an empty dict when the conversion, the
        request, or the decoding fails (the error is shown via ``st.error``).
    """
    try:
        payload = {
            "brand": brand.upper(),
            "user_id": int(user_id),
            "content_ids": [int(content_id) for content_id in content_ids],
        }
        headers = {
            "Authorization": f"Bearer {AUTH_TOKEN}",
            "accept": "application/json",
            "Content-Type": "application/json",
        }
        # requests waits forever by default; bound both the connect and the
        # read phase so a stalled endpoint cannot hang the whole script run.
        response = requests.post(
            API_URL,
            json=payload,
            headers=headers,
            timeout=(10, 60),
        )
        response.raise_for_status()
        return response.json()
    except Exception as e:
        # Best-effort call: personalization is optional, so report and fall back.
        st.error(f"Error calling rank_items API: {e}")
        return {}
def rank_to_color(rank, max_rank):
    """
    Map a rank to a grayscale color, where dark gray indicates high relevance
    (low rank) and light gray indicates low relevance (high rank).

    Args:
        rank: Zero-based position of the item in the ranking.
        max_rank: Number of ranked items; ranks above it count as unranked.

    Returns:
        ``"#E8E8E8"`` for unranked items (``rank > max_rank``) or when there is
        no ranking at all (``max_rank <= 0``); otherwise an ``"rgb(i, i, i)"``
        string whose intensity grows linearly from 55 (rank 0) to 255
        (rank equal to ``max_rank``).
    """
    # max_rank <= 0 means nothing was ranked; treating it as "unranked" also
    # avoids a division by zero below.
    if max_rank <= 0 or rank > max_rank:
        return "#E8E8E8"  # Very light gray for unranked items
    intensity = int(55 + (rank / max_rank) * 200)  # Darker for lower ranks
    return f"rgb({intensity}, {intensity}, {intensity})"  # Grayscale
def dynamic_visualize_graph(graph, start_node, layers=3, top_k=5, show_titles=False, rankings=None):
    """
    Render a layered expansion around ``start_node`` as an interactive pyvis
    network embedded in the Streamlit page.

    Starting from ``start_node``, the ``top_k`` heaviest outgoing edges of each
    frontier node are followed for ``layers`` rounds. Every node is drawn and
    expanded once; edges that lead back to an already drawn node are still
    added, so all followed links between visited nodes are visible.

    Args:
        graph: Directed graph whose node ids are strings (ids are looked up
            via ``str(...)``); nodes may carry a 'title' attribute and edges
            a 'weight' attribute.
        start_node: Id of the node to expand from; must exist in ``graph``.
        layers: Number of expansion rounds.
        top_k: Maximum number of outgoing edges followed per node.
        show_titles: When true, labels include the first 15 title characters.
        rankings: Optional sequence of content ids ordered from most to least
            relevant; when non-empty, node borders are shaded by rank.
    """
    net = Network(notebook=False, width="100%", height="600px", directed=True)
    net.set_options("""
    var options = {
      "physics": {
        "barnesHut": {
          "gravitationalConstant": -15000,
          "centralGravity": 0.8
        }
      }
    }
    """)

    start_id = str(start_node)
    max_rank = len(rankings) if rankings else 0
    unranked = max_rank + 1  # sentinel rank for ids missing from the ranking

    # Position of each ranked id, keyed by its string form. Node ids are
    # strings here while the ranking service may return integers (assumption:
    # confirm against the API); normalizing makes both forms match. setdefault
    # keeps the first position of a repeated id, like list.index would.
    rank_by_id = {}
    for position, content_id in enumerate(rankings or ()):
        rank_by_id.setdefault(str(content_id), position)

    def _add_node(node_id, background=None):
        """Add ``node_id`` to the network with degree-based size and tooltip."""
        title = graph.nodes[node_id].get('title', 'No title available')
        in_degree = graph.in_degree(node_id)
        out_degree = graph.out_degree(node_id)
        rank = rank_by_id.get(node_id, unranked)
        if background is None:
            # Red: mostly arrived at; green: mostly departed from.
            if in_degree > out_degree * 1.5:
                background = 'red'
            elif out_degree > in_degree * 1.5:
                background = 'green'
            else:
                background = 'lightblue'
        # Without rankings the border simply matches the fill color.
        border = rank_to_color(rank, max_rank) if rankings else background
        label = node_id if not show_titles else f"{node_id}: {title[:15]}..."
        net.add_node(
            node_id,
            label=label,
            title=f"{title}, In-degree: {in_degree}, Out-degree: {out_degree}, Rank: {rank}",
            size=(in_degree + out_degree) * 0.15,
            color={"background": background, "border": border},
            borderWidth=3,
            borderWidthSelected=6,
        )

    # The start node is always dark blue so it stands out.
    _add_node(start_id, background='darkblue')

    visited_nodes = {start_id}
    added_edges = set()
    current_nodes = [start_id]

    for _ in range(layers):
        next_nodes = []
        for node in current_nodes:
            # Heaviest outgoing edges first; a missing weight counts as 0,
            # matching how filter_graph treats unweighted edges.
            neighbors = sorted(
                ((str(neighbor), data.get('weight', 0)) for neighbor, data in graph[node].items()),
                key=lambda pair: pair[1],
                reverse=True,
            )[:top_k]
            for neighbor, weight in neighbors:
                if neighbor not in visited_nodes:
                    _add_node(neighbor)
                    visited_nodes.add(neighbor)
                    next_nodes.append(neighbor)
                edge = (node, neighbor)
                if edge not in added_edges:
                    # Log scale keeps very heavy edges from dominating the view.
                    edge_width = math.log(weight + 1) * 8
                    net.add_edge(node, neighbor, label=f"w:{weight}", width=edge_width, color='lightgray')
                    added_edges.add(edge)
        if not next_nodes:
            break  # nothing left to expand
        current_nodes = next_nodes

    html_content = net.generate_html()
    st.components.v1.html(html_content, height=600, scrolling=False)
st.title("Popular Path Expansion + Personalization")

# Brand Selection
selected_brand = st.selectbox("Select a brand:", options=list(BRAND_GRAPHS.keys()))

# The graph is needed on every run, whether or not the brand changed.
G = load_graph_from_hf(selected_brand)
if G is None:
    # The loader has already reported the failure; nothing else can be drawn.
    st.stop()
if len(G) == 0:
    st.error("The graph for the selected brand is empty.")
    st.stop()

# On the first run, or when the brand changes, pick a fresh default start node.
if st.session_state.get("selected_brand") != selected_brand:
    st.session_state.selected_brand = selected_brand
    # Sort nodes by popularity (in-degree + out-degree) and select from top 20
    popular_nodes = sorted(G.nodes, key=lambda n: G.in_degree(n) + G.out_degree(n), reverse=True)
    st.session_state.start_node = random.choice(popular_nodes[:20])

# Random Selection Button
if st.button("Random Selection"):
    st.session_state.start_node = random.choice(list(G.nodes))

# Node ids are handled as strings; surrounding whitespace is ignored.
start_node = st.text_input(
    "Enter the starting node ID:",
    value=str(st.session_state.start_node)
).strip()

# Input: Student ID
student_id = st.text_input("Enter a student ID (optional):", value="").strip()

# Toggle for showing content titles
show_titles = st.checkbox("Show content titles", value=False)

# Filter the graph
node_degree_threshold = 1
edge_weight_threshold = 1
G_filtered = filter_graph(G, node_threshold=node_degree_threshold, edge_threshold=edge_weight_threshold)

# Fetch rankings if student ID is provided
rankings = {}
if student_id:
    try:
        numeric_student_id = int(student_id)
    except ValueError:
        # Personalization is optional: report the bad input and carry on.
        st.error("Please enter a valid numeric student ID.")
    else:
        content_ids = list(G_filtered.nodes)
        response = get_rankings_from_api(selected_brand, numeric_student_id, content_ids)
        # Tolerate an empty or unexpectedly shaped payload instead of crashing.
        if isinstance(response, dict):
            rankings = response.get('ranked_content_ids') or {}

layers = st.slider("Depth to explore:", 1, 6, value=3)
top_k = st.slider("Branching factor (per node):", 1, 6, value=3)

if st.button("Expand Graph"):
    if start_node in G_filtered:
        dynamic_visualize_graph(G_filtered, start_node, layers=layers, top_k=top_k, show_titles=show_titles, rankings=rankings)
    else:
        st.error("The starting node is not in the graph!")