EnYa32 committed on
Commit
40c63dd
·
verified ·
1 Parent(s): 7bcf463

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +158 -37
src/streamlit_app.py CHANGED
@@ -1,40 +1,161 @@
1
- import altair as alt
2
- import numpy as np
3
  import pandas as pd
 
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
1
  import pandas as pd
2
+ import numpy as np
3
  import streamlit as st
4
+ import joblib
5
+ from pathlib import Path
6
+
7
# -------------------------
# Page config
# -------------------------
# NOTE: st.set_page_config must be the first Streamlit call in the script;
# it sets the browser tab title/icon and the page layout.
st.set_page_config(
    page_title='Clustering Predictor (KMeans / GMM)',
    page_icon='🧩',
    layout='centered'
)

# Page header and a short usage hint for the user.
st.title('🧩 Clustering Predictor (KMeans / GMM)')
st.write(
    'Upload a CSV file and get cluster predictions using your saved preprocessing (Scaler + PCA) and model.'
)
20
+
21
+ # -------------------------
22
+ # Paths (HF friendly)
23
+ # -------------------------
24
+ BASE_DIR = Path(__file__).resolve().parent
25
+
26
+ FEATURES_PATH = BASE_DIR / 'feature_names.pkl'
27
+ SCALER_PATH = BASE_DIR / 'scaler.pkl'
28
+ PCA_PATH = BASE_DIR / 'pca.pkl'
29
+
30
+ KMEANS_PATH = BASE_DIR / 'kmeans_model_k9.pkl'
31
+ GMM_PATH = BASE_DIR / 'gmm_model_k9.pkl'
32
+
33
# -------------------------
# Load assets
# -------------------------
@st.cache_resource
def load_assets():
    """Load preprocessing artifacts and clustering models from disk.

    Cached with ``st.cache_resource`` so the pickles are read only once
    per process, not on every Streamlit rerun.

    Returns:
        tuple: ``(feature_names, scaler, pca, models)`` where ``models``
        maps a display name to a fitted model exposing ``predict``.

    Raises:
        FileNotFoundError: if any required preprocessing artifact is
            missing, or if no model file is found at all.
    """
    # All three preprocessing artifacts are mandatory; collect every
    # missing one so the user sees the full list in a single error.
    missing = [p.name for p in (FEATURES_PATH, SCALER_PATH, PCA_PATH) if not p.exists()]
    if missing:
        raise FileNotFoundError(
            f'Missing required files in repo root: {missing}. '
            'Please upload them next to streamlit_app.py.'
        )

    feature_names = joblib.load(FEATURES_PATH)
    scaler = joblib.load(SCALER_PATH)
    pca = joblib.load(PCA_PATH)

    # Each model is optional individually, but at least one must exist.
    models = {}
    if KMEANS_PATH.exists():
        models['KMeans (k=9)'] = joblib.load(KMEANS_PATH)
    if GMM_PATH.exists():
        models['GMM (k=9)'] = joblib.load(GMM_PATH)

    if not models:
        raise FileNotFoundError(
            "No model files found. Upload 'kmeans_model_k9.pkl' and/or 'gmm_model_k9.pkl' next to streamlit_app.py."
        )

    return feature_names, scaler, pca, models
65
+
66
# Fail fast: surface any load error in the UI and halt the script so the
# code below never runs with undefined globals.
try:
    feature_names, scaler, pca, models = load_assets()
except Exception as e:
    st.error(str(e))
    st.stop()
71
+
72
# -------------------------
# Model selector
# -------------------------
# Let the user pick one of the loaded models; `model` is used by
# predict_clusters below.
model_name = st.selectbox('Select model', list(models.keys()))
model = models[model_name]

# Show the expected input schema so users can fix their CSV before uploading.
st.caption('Expected input columns:')
with st.expander('Show feature columns'):
    st.write(feature_names)

# -------------------------
# Upload CSV
# -------------------------
# `uploaded` is None until the user provides a file.
uploaded = st.file_uploader('Upload CSV', type=['csv'])
86
+
87
+ def preprocess_df(df_in: pd.DataFrame) -> tuple[pd.DataFrame, pd.Series | None]:
88
+ # Accept either 'id' or 'Id' columns (optional)
89
+ id_col = None
90
+ if 'id' in df_in.columns:
91
+ id_col = 'id'
92
+ elif 'Id' in df_in.columns:
93
+ id_col = 'Id'
94
+
95
+ ids = df_in[id_col].copy() if id_col else None
96
+
97
+ X = df_in.drop(columns=[id_col], errors='ignore').copy()
98
+
99
+ # Validate columns
100
+ missing_cols = [c for c in feature_names if c not in X.columns]
101
+ extra_cols = [c for c in X.columns if c not in feature_names]
102
+
103
+ if missing_cols:
104
+ raise ValueError(f'Missing required columns: {missing_cols}')
105
+
106
+ # Keep only expected columns + correct order
107
+ X = X[feature_names]
108
+
109
+ # Convert to numeric if possible (safety)
110
+ for c in X.columns:
111
+ X[c] = pd.to_numeric(X[c], errors='coerce')
112
+
113
+ if X.isna().any().any():
114
+ # You can choose a strategy; here we fail fast so user fixes the input.
115
+ bad_cols = X.columns[X.isna().any()].tolist()
116
+ raise ValueError(
117
+ f'Found NaNs after converting to numeric. Check these columns: {bad_cols}. '
118
+ 'Make sure your CSV has valid numeric values.'
119
+ )
120
+
121
+ return X, ids
122
+
123
def predict_clusters(X: pd.DataFrame, *, scaler_override=None, pca_override=None, model_override=None) -> np.ndarray:
    """Run the saved preprocessing pipeline and predict cluster labels.

    Args:
        X: Feature dataframe (columns already validated/ordered by
            ``preprocess_df``).
        scaler_override: Optional scaler to use instead of the module-level
            ``scaler`` (keyword-only; mainly for testing).
        pca_override: Optional PCA transformer to use instead of the
            module-level ``pca``.
        model_override: Optional model to use instead of the currently
            selected module-level ``model``.

    Returns:
        np.ndarray: Predicted cluster label per row of ``X``.
    """
    # Fall back to the module-level objects so existing callers
    # (predict_clusters(X_up)) behave exactly as before.
    s = scaler_override if scaler_override is not None else scaler
    p = pca_override if pca_override is not None else pca
    m = model_override if model_override is not None else model

    X_scaled = s.transform(X)
    X_pca = p.transform(X_scaled)

    # KMeans and GMM both expose predict()
    return m.predict(X_pca)
130
+
131
# Main flow: once a CSV is uploaded, validate it, predict, and offer the
# results for download. Any validation/prediction error is shown in the UI
# and the script is halted.
if uploaded is not None:
    try:
        df_up = pd.read_csv(uploaded)
        X_up, ids = preprocess_df(df_up)
        preds = predict_clusters(X_up)

        # Assemble the output table, with the original Id column (if any)
        # restored as the first column.
        out = pd.DataFrame({'Predicted': preds})
        if ids is not None:
            out.insert(0, 'Id', ids)

        st.success('✅ Predictions created successfully.')
        # Preview only the first 30 rows; the full table is downloadable below.
        st.dataframe(out.head(30), use_container_width=True)

        st.download_button(
            'Download predictions as CSV',
            data=out.to_csv(index=False).encode('utf-8'),
            file_name='predictions.csv',
            mime='text/csv'
        )

        # Quick info: how many rows fell into each cluster.
        st.subheader('Cluster distribution')
        dist = pd.Series(preds).value_counts().sort_index()
        st.write(dist)

    except Exception as e:
        st.error(str(e))
        st.stop()

else:
    st.info('Upload a CSV file to generate cluster predictions.')