File size: 9,022 Bytes
209e6b2
 
 
 
 
292c185
7d63ff8
 
c0428c6
209e6b2
c0428c6
2aa8945
c0428c6
 
 
 
2aa8945
 
c0428c6
 
7d63ff8
 
 
 
209e6b2
c0428c6
 
 
 
 
 
79d9770
8d70952
98f49ee
c0428c6
 
 
 
 
 
 
209e6b2
2aa8945
c0428c6
 
 
2aa8945
 
 
 
 
 
 
 
 
 
 
 
 
c0428c6
7d63ff8
 
 
 
 
 
 
d89e6da
da71e01
7d63ff8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa2b395
7d63ff8
 
 
 
209e6b2
 
 
 
 
122c116
 
 
209e6b2
122c116
 
209e6b2
 
292c185
d8ff0e1
122c116
7d63ff8
 
209e6b2
a1850db
 
 
292c185
a1850db
7d63ff8
 
 
 
 
292c185
a1850db
292c185
7d63ff8
 
 
 
 
292c185
a1850db
209e6b2
 
 
 
 
0953001
209e6b2
 
292c185
209e6b2
 
 
292c185
 
 
 
7d63ff8
 
292c185
7d63ff8
 
 
 
 
 
292c185
 
 
 
7d63ff8
292c185
7d63ff8
 
 
292c185
 
122c116
292c185
c3c6baa
292c185
209e6b2
292c185
209e6b2
292c185
209e6b2
 
 
 
122c116
7d63ff8
2aa8945
 
 
209e6b2
122c116
 
c0428c6
292c185
 
 
 
 
122c116
c0428c6
209e6b2
292c185
 
 
 
f3c758e
209e6b2
f3c758e
209e6b2
122c116
afb7de3
4622e0f
afb7de3
 
 
f3c758e
 
7d63ff8
 
 
292c185
 
 
 
 
 
122c116
 
7d63ff8
 
40af051
 
 
 
 
c0428c6
209e6b2
 
 
 
2aa8945
7d63ff8
209e6b2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
import streamlit as st
import networkx as nx
from pyvis.network import Network
import pickle
import math
import random
import requests
import os
from huggingface_hub import hf_hub_download

# Maps each supported brand name to the pickle file name that
# load_graph_from_hf downloads from that brand's HuggingFace model repo.
# The keys also populate the brand selectbox in the UI.
BRAND_GRAPHS = {
    'drumeo': 'drumeo_graph.pkl',
    'pianote': 'pianote_graph.pkl',
    'singeo': 'singeo_graph.pkl',
    'guitareo': 'guitareo_graph.pkl'
}

# HuggingFace access settings.
# The repository id is built per brand inside load_graph_from_hf, so there is
# no single repository constant here.
# AUTH_TOKEN comes from the HF_TOKEN environment variable (None when unset); it
# is passed to hf_hub_download and sent as the bearer token to API_URL.
AUTH_TOKEN = os.getenv('HF_TOKEN')
# Endpoint that get_rankings_from_api POSTs to.
API_URL = "https://MusoraProductDepartment-PWGenerator.hf.space/rank_items/"


@st.cache_resource
def load_graph_from_hf(brand):
    """
    Download and unpickle the graph file for ``brand`` from the HuggingFace Hub.

    The repository id is ``MusoraProductDepartment/<brand>-graph`` and the file
    name is looked up in ``BRAND_GRAPHS``. The download is cached under
    ``/tmp``.

    Returns the unpickled object, or ``None`` after reporting the problem with
    ``st.error`` when anything fails (unknown brand, download error, unreadable
    or invalid pickle).
    """
    try:
        local_path = hf_hub_download(
            repo_id=f'MusoraProductDepartment/{brand}-graph',
            filename=BRAND_GRAPHS[brand],
            token=AUTH_TOKEN,
            cache_dir='/tmp',
            repo_type='model',
        )
        # NOTE(review): unpickling can execute arbitrary code, so this is only
        # safe while the repository contents are trusted.
        with open(local_path, 'rb') as graph_file:
            loaded_graph = pickle.load(graph_file)
    except Exception as exc:
        st.error(f"Error loading graph from HuggingFace: {exc}")
        return None
    return loaded_graph


def filter_graph(graph, node_threshold=10, edge_threshold=5):
    """
    Build a reduced copy of ``graph`` with only popular nodes and heavy edges.

    A node is kept when its degree is at least ``node_threshold``. In the copied
    subgraph of those nodes, every edge whose ``"weight"`` attribute (taken as 0
    when absent) is below ``edge_threshold`` is removed. All removals happen on
    the copy, which is returned.
    """
    kept_nodes = []
    for candidate in graph.nodes:
        if graph.degree(candidate) >= node_threshold:
            kept_nodes.append(candidate)

    pruned = graph.subgraph(kept_nodes).copy()

    # Collect the weak edges from a snapshot first, so that removing them does
    # not interfere with iterating over the edge view.
    weak_edges = [
        (source, target)
        for source, target, attrs in list(pruned.edges(data=True))
        if attrs.get("weight", 0) < edge_threshold
    ]
    for source, target in weak_edges:
        pruned.remove_edge(source, target)

    return pruned


def get_rankings_from_api(brand, user_id, content_ids, timeout=60):
    """
    Call the rank_items API to fetch rankings for the given user and content IDs.

    Args:
        brand: Brand name; sent upper-cased.
        user_id: User identifier; converted with ``int`` before sending.
        content_ids: Iterable of content identifiers; each is converted with
            ``int`` before sending.
        timeout: Seconds to wait for the API before giving up. Without a
            timeout ``requests`` waits indefinitely, which would freeze the
            app run if the endpoint stalls.

    Returns:
        The decoded JSON response, or an empty dict after reporting the problem
        with ``st.error`` when the request (or building it) fails.
    """
    try:
        payload = {
            "brand": brand.upper(),
            "user_id": int(user_id),
            "content_ids": [int(content_id) for content_id in content_ids]
        }
        headers = {
            "Authorization": f"Bearer {AUTH_TOKEN}",
            "accept": "application/json",
            "Content-Type": "application/json"
        }
        response = requests.post(API_URL, json=payload, headers=headers, timeout=timeout)
        # Turn HTTP error statuses into exceptions so they are reported below.
        response.raise_for_status()
        return response.json()
    except Exception as e:
        # Deliberate best effort: personalization is optional, so any failure
        # is shown to the user and the caller continues without rankings.
        st.error(f"Error calling rank_items API: {e}")
        return {}


def rank_to_color(rank, max_rank):
    """
    Map a rank to a grayscale color, where dark gray indicates high relevance (low rank),
    and light gray indicates low relevance (high rank).

    Args:
        rank: Position of the item; 0 maps to the darkest shade.
        max_rank: Largest rank considered "ranked". A ``rank`` above it is
            treated as unranked.

    Returns:
        ``"#E8E8E8"`` for unranked items or when ``max_rank`` is not positive,
        otherwise an ``"rgb(i, i, i)"`` string with ``i`` growing linearly from
        55 (rank 0) to 255 (rank equal to ``max_rank``).
    """
    unranked_color = "#E8E8E8"  # Very light gray for unranked items
    # A non-positive max_rank means there is no ranking scale at all; treating
    # it as unranked also avoids dividing by zero below.
    if max_rank <= 0 or rank > max_rank:
        return unranked_color
    intensity = int(55 + (rank / max_rank) * 200)  # Darker for lower ranks
    return f"rgb({intensity}, {intensity}, {intensity})"  # Grayscale


def dynamic_visualize_graph(graph, start_node, layers=3, top_k=5, show_titles=False, rankings=None):
    """
    Render an interactive, layer-by-layer expansion of ``graph`` in Streamlit.

    Starting from ``start_node``, each of the ``layers`` rounds follows the
    ``top_k`` heaviest outgoing edges of every node discovered in the previous
    round. Nodes are sized by total degree. The start node is dark blue; other
    nodes are red when in-degree exceeds 1.5x out-degree, green when out-degree
    exceeds 1.5x in-degree, and light blue otherwise. When ``rankings`` is
    non-empty, each node's border is shaded by ``rank_to_color`` according to
    the node's position in ``rankings``.

    Args:
        graph: Directed graph, accessed through ``nodes``, ``in_degree``,
            ``out_degree`` and ``graph[node].items()``. Node keys are expected
            to be strings, because neighbours are looked up by ``str(key)``.
        start_node: Root node ID; converted with ``str`` and must be in ``graph``.
        layers: Number of expansion rounds.
        top_k: Maximum number of outgoing edges followed per node.
        show_titles: When true, labels also show the first 15 title characters.
        rankings: Optional sequence of content IDs ordered by rank. Entries are
            compared to node IDs by their string form.

    Returns:
        None. The generated HTML is embedded in the page.
    """
    net = Network(notebook=False, width="100%", height="600px", directed=True)
    net.set_options("""
    var options = {
      "physics": {
        "barnesHut": {
          "gravitationalConstant": -15000,
          "centralGravity": 0.8
        }
      }
    }
    """)

    start_id = str(start_node)
    max_rank = len(rankings) if rankings else 0
    # Rank given to nodes missing from the rankings; rank_to_color treats any
    # rank above max_rank as unranked.
    unranked = max_rank + 1

    # Index the rankings once instead of scanning the list for every node.
    # Keys are normalised to str so they line up with the string node IDs
    # (NOTE(review): the API is sent integer IDs and may return them as
    # integers - confirm). setdefault keeps the first position of a repeated
    # ID, matching list.index.
    rank_by_id = {}
    for position, content_id in enumerate(rankings or ()):
        rank_by_id.setdefault(str(content_id), position)

    def describe(node_id):
        """Return (title, in_degree, out_degree, rank) for ``node_id``."""
        return (
            graph.nodes[node_id].get('title', 'No title available'),
            graph.in_degree(node_id),
            graph.out_degree(node_id),
            rank_by_id.get(node_id, unranked),
        )

    def make_label(node_id, title):
        """Return the node label, optionally extended with a title prefix."""
        return f"{node_id}: {title[:15]}..." if show_titles else node_id

    # Add the starting node with a dark blue background and a tooltip.
    start_title, start_in_degree, start_out_degree, start_rank = describe(start_id)
    start_border_color = rank_to_color(start_rank, max_rank) if rankings else 'darkblue'
    net.add_node(
        start_id,
        label=make_label(start_id, start_title),
        color={"background": "darkblue", "border": start_border_color},
        title=f"{start_title}, In-degree: {start_in_degree}, Out-degree: {start_out_degree}, Rank: {start_rank}",
        size=(start_in_degree + start_out_degree) * 0.15,
        borderWidth=3,
        borderWidthSelected=6
    )

    visited_nodes = {start_id}
    added_edges = set()
    current_nodes = [start_id]

    for _ in range(layers):
        next_nodes = []
        for node in current_nodes:
            # Heaviest outgoing edges first. A missing weight counts as 0, the
            # same default filter_graph uses.
            neighbors = sorted(
                ((str(neighbor), data.get('weight', 0)) for neighbor, data in graph[node].items()),
                key=lambda pair: pair[1],
                reverse=True
            )[:top_k]

            for neighbor, weight in neighbors:
                if neighbor not in visited_nodes:
                    neighbor_title, neighbor_in_degree, neighbor_out_degree, neighbor_rank = describe(neighbor)

                    if neighbor_in_degree > neighbor_out_degree * 1.5:
                        node_color = 'red'
                    elif neighbor_out_degree > neighbor_in_degree * 1.5:
                        node_color = 'green'
                    else:
                        node_color = 'lightblue'
                    neighbor_border_color = rank_to_color(neighbor_rank, max_rank) if rankings else node_color

                    net.add_node(
                        neighbor,
                        label=make_label(neighbor, neighbor_title),
                        title=f"{neighbor_title}, In-degree: {neighbor_in_degree}, Out-degree: {neighbor_out_degree}, Rank: {neighbor_rank}",
                        size=(neighbor_in_degree + neighbor_out_degree) * 0.15,
                        color={"background": node_color, "border": neighbor_border_color},
                        borderWidth=3,
                        borderWidthSelected=6
                    )
                    # Queue each node for expansion only once. Expanding a node
                    # again would revisit the same top_k edges, which are
                    # already drawn, so it changes nothing in the output.
                    visited_nodes.add(neighbor)
                    next_nodes.append(neighbor)

                edge = (node, neighbor)
                if edge not in added_edges:
                    # Log scale keeps very heavy edges from dominating the view.
                    edge_width = math.log(weight + 1) * 8
                    net.add_edge(node, neighbor, label=f"w:{weight}", width=edge_width, color='lightgray')
                    added_edges.add(edge)

        current_nodes = next_nodes

    html_content = net.generate_html()
    st.components.v1.html(html_content, height=600, scrolling=False)


st.title("Popular Path Expansion + Personalization")

# Brand Selection
selected_brand = st.selectbox("Select a brand:", options=list(BRAND_GRAPHS.keys()))

# The loader reports its own error and returns None on failure. Nothing below
# can work without a graph, so end this run instead of crashing on G.nodes.
G = load_graph_from_hf(selected_brand)
if G is None:
    st.stop()

# Choose a new starting node when the brand changes, or when none is stored yet
# (for example because an earlier run stopped before storing one).
brand_changed = (
    "selected_brand" not in st.session_state
    or st.session_state.selected_brand != selected_brand
)
if brand_changed or "start_node" not in st.session_state:
    st.session_state.selected_brand = selected_brand

    # Sort nodes by popularity (in-degree + out-degree) and select from top 20
    popular_nodes = sorted(G.nodes, key=lambda n: G.in_degree(n) + G.out_degree(n), reverse=True)
    top_20_nodes = popular_nodes[:20]
    if not top_20_nodes:
        st.error("The selected graph has no nodes.")
        st.stop()
    st.session_state.start_node = random.choice(top_20_nodes)

# Random Selection Button
if st.button("Random Selection"):
    st.session_state.start_node = random.choice(list(G.nodes))

# Node IDs are compared as strings; surrounding whitespace from typing or
# pasting is dropped so it does not cause a spurious "not in the graph" error.
start_node = st.text_input(
    "Enter the starting node ID:",
    value=str(st.session_state.start_node)
).strip()


# Input: Student ID
student_id = st.text_input("Enter a student ID (optional):", value="").strip()

# Toggle for showing content titles
show_titles = st.checkbox("Show content titles", value=False)

# Filter the graph
node_degree_threshold = 1
edge_weight_threshold = 1
G_filtered = filter_graph(G, node_threshold=node_degree_threshold, edge_threshold=edge_weight_threshold)

# Fetch rankings if student ID is provided
rankings = {}
if student_id:
    try:
        student_id_value = int(student_id)
    except ValueError:
        # Personalization is optional: report the bad input and carry on
        # without rankings rather than crashing the page.
        st.error("Please enter a valid numeric student ID.")
    else:
        content_ids = list(G_filtered.nodes)
        api_response = get_rankings_from_api(selected_brand, student_id_value, content_ids)
        # An empty or unexpected payload simply means "no rankings".
        if isinstance(api_response, dict):
            rankings = api_response.get('ranked_content_ids') or {}

layers = st.slider("Depth to explore:", 1, 6, value=3)
top_k = st.slider("Branching factor (per node):", 1, 6, value=3)

if st.button("Expand Graph"):
    if start_node in G_filtered:
        dynamic_visualize_graph(G_filtered, start_node, layers=layers, top_k=top_k, show_titles=show_titles, rankings=rankings)
    else:
        st.error("The starting node is not in the graph!")