Dua Rajper committed on
Commit
560caa5
·
verified ·
1 Parent(s): 7c2b1bc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +289 -0
app.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import google.generativeai as genai
4
+ from dotenv import load_dotenv
5
+ import time
6
+ from typing import Any, List, Optional
7
+ import numpy as np
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
+ import tensorflow as tf
10
+ from tensorflow.keras.models import Sequential
11
+ from tensorflow.keras.layers import Dense, Input
12
+ from tensorflow.keras.utils import to_categorical
13
+ from tensorflow.keras.optimizers import Adam
14
+
15
+ # Load environment variables
16
+ load_dotenv()
17
+ GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
18
+
19
+ # Configure Generative AI model
20
+ if GOOGLE_API_KEY:
21
+ genai.configure(api_key=GOOGLE_API_KEY)
22
+ else:
23
+ st.error(
24
+ "Google AI Studio API key not found. Please add it to your .env file. "
25
+ "You can obtain an API key from https://makersuite.google.com/."
26
+ )
27
+ st.stop()
28
+
29
+ st.title("Embeddings and Vector Search Demo")
30
+ st.subheader("Explore Embeddings and Vector Databases")
31
+
32
+ # Sidebar for explanations
33
+ with st.sidebar:
34
+ st.header("Embeddings and Vector Search")
35
+ st.markdown(
36
+ """
37
+ This app demonstrates how embeddings and vector databases can be used for various tasks.
38
+ """
39
+ )
40
+ st.subheader("Key Concepts:")
41
+ st.markdown(
42
+ """
43
+ - **Embeddings**: Numerical representations of text, capturing semantic meaning.
44
+ - **Vector Databases**: Databases optimized for storing and querying vectors (simulated here).
45
+ - **Retrieval Augmented Generation (RAG)**: Combining retrieval with LLM generation.
46
+ - **Cosine Similarity**: A measure of similarity between two vectors.
47
+ - **Neural Networks**: Using embeddings as input for classification.
48
+ """
49
+ )
50
+ st.subheader("Whitepaper Insights")
51
+ st.markdown(
52
+ """
53
+ - Efficient similarity search using vector indexes (e.g., ANN).
54
+ - Handling large datasets and scalability considerations.
55
+ - Applications of embeddings: search, recommendation, classification, etc.
56
+ """
57
+ )
58
+
59
+ # --- Helper Functions ---
60
+ def code_block(text: str, language: str = "text") -> None:
61
+ """Displays text as a formatted code block in Streamlit."""
62
+ st.markdown(f"```{language}\n{text}\n```", unsafe_allow_html=True)
63
+
64
+ def display_response(response: Any) -> None:
65
+ """Displays the model's response."""
66
+ if response and hasattr(response, "text"):
67
+ st.subheader("Generated Response:")
68
+ st.markdown(response.text)
69
+ else:
70
+ st.error("Failed to generate a response.")
71
+
72
+ def generate_embeddings(texts: List[str], model_name: str = "models/embedding-001") -> Optional[List[List[float]]]:
73
+ """Generates embeddings for a list of texts using a specified model.
74
+ Args:
75
+ texts: List of text strings.
76
+ model_name: Name of the embedding model.
77
+ Returns:
78
+ List of embeddings (list of floats) or None on error.
79
+ """
80
+ try:
81
+ # Use the embedding model directly
82
+ embeddings = []
83
+ for text in texts:
84
+ result = genai.embed_content(
85
+ model=model_name,
86
+ content=text,
87
+ task_type="retrieval_document" # or "retrieval_query" for queries
88
+ )
89
+ embeddings.append(result['embedding'])
90
+ return embeddings
91
+ except Exception as e:
92
+ st.error(f"Error generating embeddings with model '{model_name}': {e}")
93
+ return None
94
+
95
+ def generate_with_retry(prompt: str, model_name: str, generation_config: genai.types.GenerationConfig, max_retries: int = 3, delay: int = 5) -> Any:
96
+ """Generates content with retry logic and error handling."""
97
+ for i in range(max_retries):
98
+ try:
99
+ model = genai.GenerativeModel(model_name)
100
+ response = model.generate_content(prompt, generation_config=generation_config)
101
+ return response
102
+ except Exception as e:
103
+ error_message = str(e)
104
+ st.warning(f"Error during generation (attempt {i + 1}/{max_retries}): {error_message}")
105
+ if "404" in error_message and "not found" in error_message:
106
+ st.error(
107
+ f"Model '{model_name}' is not available or not supported. Please select a different model."
108
+ )
109
+ return None
110
+ elif i < max_retries - 1:
111
+ st.info(f"Retrying in {delay} seconds...")
112
+ time.sleep(delay)
113
+ else:
114
+ st.error(f"Failed to generate content after {max_retries} attempts. Please check your prompt and model.")
115
+ return None
116
+ return None
117
+
118
+ def calculate_similarity(embedding1: List[float], embedding2: List[float]) -> float:
119
+ """Calculates the cosine similarity between two embeddings."""
120
+ return cosine_similarity(np.array(embedding1).reshape(1, -1), np.array(embedding2).reshape(1, -1))[0][0]
121
+
122
+ def create_and_train_model(
123
+ embeddings: List[List[float]],
124
+ labels: List[int],
125
+ num_classes: int,
126
+ epochs: int,
127
+ batch_size: int,
128
+ learning_rate: float,
129
+ optimizer_str: str
130
+ ) -> tf.keras.Model:
131
+ """Creates and trains a neural network for classification."""
132
+ model = Sequential([
133
+ Input(shape=(len(embeddings[0]),), # Fixed the double comma here
134
+ Dense(64, activation='relu'),
135
+ Dense(32, activation='relu'),
136
+ Dense(num_classes, activation='softmax')
137
+ ])
138
+
139
+ if optimizer_str.lower() == 'adam':
140
+ optimizer = Adam(learning_rate=learning_rate)
141
+ elif optimizer_str.lower() == 'sgd':
142
+ optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)
143
+ elif optimizer_str.lower() == 'rmsprop':
144
+ optimizer = tf.keras.optimizers.RMSprop(learning_rate=learning_rate)
145
+ else:
146
+ optimizer = Adam(learning_rate=learning_rate)
147
+
148
+ model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
149
+ encoded_labels = to_categorical(labels, num_classes=num_classes)
150
+ model.fit(np.array(embeddings), encoded_labels, epochs=epochs, batch_size=batch_size, verbose=0)
151
+ return model
152
+ # --- RAG Question Answering ---
153
+ st.header("RAG Question Answering")
154
+ rag_model_name = st.selectbox("Select model for RAG:", ["gemini-pro"], index=0)
155
+ rag_embedding_model = st.selectbox("Select embedding model for RAG:", ["models/embedding-001"], index=0)
156
+ rag_context = st.text_area(
157
+ "Enter your context documents:",
158
+ "Relevant information to answer the question. Separate documents with newlines.",
159
+ height=150,
160
+ )
161
+ rag_question = st.text_area("Ask a question about the context:", "What is the main topic?", height=70)
162
+ rag_max_context_length = st.number_input("Maximum Context Length", min_value=100, max_value=2000, value=500, step=100)
163
+
164
+ if st.button("Answer with RAG"):
165
+ if not rag_context or not rag_question:
166
+ st.warning("Please provide both context and a question.")
167
+ else:
168
+ with st.spinner("Generating answer..."):
169
+ try:
170
+ # 1. Generate embeddings for the context
171
+ context_embeddings = generate_embeddings(rag_context.split('\n'), rag_embedding_model)
172
+ if not context_embeddings:
173
+ st.stop()
174
+
175
+ # 2. Generate embedding for the question
176
+ question_embedding = generate_embeddings([rag_question], rag_embedding_model)
177
+ if not question_embedding:
178
+ st.stop()
179
+
180
+ # 3. Calculate similarity scores
181
+ similarities = cosine_similarity(np.array(question_embedding).reshape(1, -1), np.array(context_embeddings))[0]
182
+
183
+ # 4. Find the most relevant document(s)
184
+ most_relevant_index = np.argmax(similarities)
185
+ relevant_context = rag_context.split('\n')[most_relevant_index]
186
+ if len(relevant_context) > rag_max_context_length:
187
+ relevant_context = relevant_context[:rag_max_context_length]
188
+
189
+ # 5. Construct the prompt
190
+ rag_prompt = f"Use the following context to answer the question: '{rag_question}'.\nContext: {relevant_context}"
191
+
192
+ # 6. Generate the answer
193
+ response = generate_with_retry(rag_prompt, rag_model_name, generation_config=genai.types.GenerationConfig())
194
+ if response:
195
+ display_response(response)
196
+ except Exception as e:
197
+ st.error(f"An error occurred: {e}")
198
+
199
+ # --- Text Similarity ---
200
+ st.header("Text Similarity")
201
+ similarity_embedding_model = st.selectbox("Select embedding model for similarity:", ["models/embedding-001"], index=0)
202
+ text1 = st.text_area("Enter text 1:", "This is the first sentence.", height=70)
203
+ text2 = st.text_area("Enter text 2:", "This is a similar sentence.", height=70)
204
+
205
+ if st.button("Calculate Similarity"):
206
+ if not text1 or not text2:
207
+ st.warning("Please provide both texts.")
208
+ else:
209
+ with st.spinner("Calculating similarity..."):
210
+ try:
211
+ embeddings = generate_embeddings([text1, text2], similarity_embedding_model)
212
+ if not embeddings:
213
+ st.stop()
214
+ similarity = calculate_similarity(embeddings[0], embeddings[1])
215
+ st.subheader("Cosine Similarity:")
216
+ st.write(similarity)
217
+ except Exception as e:
218
+ st.error(f"An error occurred: {e}")
219
+
220
+ # --- Neural Classification ---
221
+ st.header("Neural Classification with Embeddings")
222
+ classification_embedding_model = st.selectbox("Select embedding model for classification:", ["models/embedding-001"], index=0)
223
+ classification_data = st.text_area(
224
+ "Enter your training data (text, label pairs), separated by newlines. Example: text1,0\\ntext2,1",
225
+ "text1,0\ntext2,1\ntext3,0\ntext4,1",
226
+ height=150,
227
+ )
228
+ classification_prompt = st.text_area("Enter text to classify:", "This is a test text.", height=70)
229
+ num_epochs = st.number_input("Number of Epochs", min_value=1, max_value=200, value=10, step=1)
230
+ batch_size = st.number_input("Batch Size", min_value=1, max_value=128, value=32, step=1)
231
+ learning_rate = st.number_input("Learning Rate", min_value=0.0001, max_value=0.1, value=0.0001, step=0.0001, format="%.4f")
232
+ optimizer_str = st.selectbox("Optimizer", ['adam', 'sgd', 'rmsprop'], index=0)
233
+
234
+ def process_classification_data(data: str) -> Optional[tuple[List[str], List[int]]]:
235
+ """Processes the classification data string into lists of texts and labels."""
236
+ data_pairs = [line.split(',') for line in data.split('\n') if ',' in line]
237
+ if not data_pairs:
238
+ st.error("No valid data pairs found. Please ensure each line contains 'text,label'.")
239
+ return None
240
+ texts = []
241
+ labels = []
242
+ for i, pair in enumerate(data_pairs):
243
+ if len(pair) != 2:
244
+ st.error(f"Invalid data format in line {i + 1}: '{','.join(pair)}'. Expected 'text,label'.")
245
+ return None
246
+ text = pair[0].strip()
247
+ label_str = pair[1].strip()
248
+ try:
249
+ label = int(label_str)
250
+ texts.append(text)
251
+ labels.append(label)
252
+ except ValueError:
253
+ st.error(f"Invalid label value in line {i + 1}: '{label_str}'. Label must be an integer.")
254
+ return None
255
+ return texts, labels
256
+
257
+ if st.button("Classify"):
258
+ if not classification_data or not classification_prompt:
259
+ st.warning("Please provide training data and text to classify.")
260
+ else:
261
+ with st.spinner("Classifying..."):
262
+ try:
263
+ processed_data = process_classification_data(classification_data)
264
+ if not processed_data:
265
+ st.stop()
266
+ train_texts, train_labels = processed_data
267
+ num_classes = len(set(train_labels))
268
+
269
+ train_embeddings = generate_embeddings(train_texts, classification_embedding_model)
270
+ if not train_embeddings:
271
+ st.stop()
272
+
273
+ model = create_and_train_model(
274
+ train_embeddings, train_labels, num_classes, num_epochs, batch_size, learning_rate, optimizer_str
275
+ )
276
+
277
+ predict_embedding = generate_embeddings([classification_prompt], classification_embedding_model)
278
+ if not predict_embedding:
279
+ st.stop()
280
+
281
+ prediction = model.predict(np.array([predict_embedding]), verbose=0)
282
+ predicted_class = np.argmax(prediction[0])
283
+ st.subheader("Predicted Class:")
284
+ st.write(predicted_class)
285
+ st.subheader("Prediction Probabilities:")
286
+ st.write(prediction)
287
+
288
+ except Exception as e:
289
+ st.error(f"An error occurred: {e}")