# NOTE: removed non-Python page artifacts ("GodRad's picture / Upload 8 files /
# db59efd verified") that were accidentally captured from the hosting UI and
# would break the module at import time.
"""
RAG SaaS Platform - Streamlit UI
"""
import streamlit as st
import requests
import pandas as pd
import time
# Page configuration -- this MUST run before any other Streamlit command,
# otherwise Streamlit raises a StreamlitAPIException.
_PAGE_OPTIONS = {
    "page_title": "RAG SaaS Platform",
    "page_icon": "🔍",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_OPTIONS)
# Minimal CSS -- kept deliberately small for faster page loads. Defines the
# full-width button style and the coloured per-source-type badge classes
# (file / web / youtube / gdrive / onedrive / dropbox) used throughout the app.
_MINIMAL_CSS = """
<style>
.stButton>button {
width: 100%; border-radius: 5px; height: 3em;
background-color: #1f6feb; color: white; font-weight: bold;
}
.source-badge {
display: inline-block; padding: 5px 10px; border-radius: 5px;
margin: 2px; font-size: 12px; font-weight: bold;
}
.source-file {background-color: #1f6feb; color: white;}
.source-web {background-color: #238636; color: white;}
.source-youtube {background-color: #da3633; color: white;}
.source-gdrive {background-color: #34a853; color: white;}
.source-onedrive {background-color: #0078d4; color: white;}
.source-dropbox {background-color: #0061ff; color: white;}
</style>
"""
st.markdown(_MINIMAL_CSS, unsafe_allow_html=True)
# Initialize session state with per-user defaults on the first script run.
# Later reruns keep whatever values the user has already set.
_SESSION_DEFAULTS = {
    "tenant_id": None,            # set after a successful login
    "access_token": None,         # JWT returned by the backend
    "api_url": "http://localhost:7860",  # default backend location
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Sidebar: backend URL configuration plus tenant login / registration.
# Fix: both auth requests previously had no `timeout`, so an unreachable
# backend would hang the Streamlit script (and the spinner) forever.
with st.sidebar:
    st.title("⚙️ Configuration")
    # Backend API URL -- editable so the UI can point at any deployment.
    api_url = st.text_input(
        "Backend API URL",
        value=st.session_state.api_url,
        help="URL of the RAG SaaS backend"
    )
    st.session_state.api_url = api_url
    st.divider()
    # Authentication Section
    st.subheader("🔐 Authentication")
    if not st.session_state.tenant_id:
        # Not logged in yet: offer Login / Register tabs.
        tab1, tab2 = st.tabs(["Login", "Register"])
        with tab1:
            st.caption("Login with Tenant ID")
            login_tenant_id = st.text_input("Tenant ID", key="login_tenant_id")
            login_username = st.text_input("Username/Email", key="login_username")
            login_password = st.text_input("Password", type="password", key="login_password")
            if st.button("Login", key="login_btn"):
                try:
                    response = requests.post(
                        f"{api_url}/v1/auth/login",
                        json={
                            "tenant_id": login_tenant_id,
                            "username": login_username,
                            "password": login_password
                        },
                        timeout=15,  # fail fast instead of hanging the UI
                    )
                    if response.status_code == 200:
                        data = response.json()
                        # Persist credentials for the rest of the session.
                        st.session_state.tenant_id = data['tenant_id']
                        st.session_state.access_token = data['access_token']
                        st.success(f"✅ Logged in as {data['tenant_name']}")
                        st.rerun()
                    else:
                        st.error(f"Login failed: {response.json().get('detail', 'Unknown error')}")
                except Exception as e:
                    # Broad catch on purpose: any network/JSON error should
                    # surface in the UI, never crash the page.
                    st.error(f"Connection error: {e}")
        with tab2:
            st.caption("Create New Tenant Account")
            reg_name = st.text_input("Organization Name", key="reg_name")
            reg_email = st.text_input("Admin Email", key="reg_email")
            reg_password = st.text_input("Admin Password", type="password", key="reg_password")
            if st.button("Create Tenant", key="register_btn"):
                try:
                    response = requests.post(
                        f"{api_url}/v1/auth/tenants",
                        json={
                            "name": reg_name,
                            "admin_email": reg_email,
                            "admin_password": reg_password
                        },
                        timeout=30,  # tenant provisioning may take a moment
                    )
                    if response.status_code == 200:
                        data = response.json()
                        st.success("✅ Tenant created!")
                        st.info(f"**Your Tenant ID:** `{data['tenant_id']}`\n\n⚠️ Save this ID - you'll need it to login!")
                    else:
                        st.error(f"Registration failed: {response.json().get('detail', 'Unknown error')}")
                except Exception as e:
                    st.error(f"Connection error: {e}")
    else:
        # Logged-in state: show identity and a logout button that clears
        # the stored credentials and reruns the script.
        st.success("✅ Logged in")
        st.caption(f"Tenant ID: `{st.session_state.tenant_id}`")
        if st.button("Logout"):
            st.session_state.tenant_id = None
            st.session_state.access_token = None
            st.rerun()
# Main content: landing page shown to anonymous visitors (no tenant yet).
if not st.session_state.tenant_id:
    st.title("🔍 RAG SaaS Platform")
    st.info("👈 Please login or create a tenant account to continue")
    # Advertise the platform features in three equal columns.
    st.subheader("Platform Features")
    _features = [
        ("### 📄 Multi-Source Retrieval",
         "Unified search across documents, URLs, and YouTube videos"),
        ("### 📊 Tenant Metrics",
         "Isolated metrics tracking for each tenant"),
        ("### ⚡ Optimized Performance",
         "Model caching with 70-90% latency reduction"),
    ]
    for _col, (_heading, _blurb) in zip(st.columns(3), _features):
        with _col:
            st.markdown(_heading)
            st.write(_blurb)
else:
    # Logged in -- show the main application interface.
    st.title(f"🔍 RAG SaaS Platform")
    # Top-level navigation tabs; tab1..tab5 are consumed by the sections below.
    tab1, tab2, tab3, tab4, tab5 = st.tabs(["💬 Ask Questions", "📁 Manage Sources", "☁️ Cloud Connectors", "📊 Metrics", "ℹ️ About"])
# Tab 1: Ask Questions
with tab1:
st.header("Ask Questions")
# Show active sources
try:
response = requests.get(
f"{api_url}/v1/sources/active",
headers={"X-Tenant-ID": st.session_state.tenant_id}
)
if response.status_code == 200:
data = response.json()
active_sources = data.get('active_sources', [])
source_types = data.get('source_types_present', [])
if source_types:
st.success(f"βœ… Active sources: {', '.join(source_types)}")
# Show source badges
badges_html = ""
for source_type in source_types:
badge_class = f"source-{source_type}"
badges_html += f'<span class="source-badge {badge_class}">{source_type.upper()}</span>'
st.markdown(badges_html, unsafe_allow_html=True)
else:
st.warning("⚠️ No active sources. Add sources in the 'Manage Sources' tab.")
except Exception as e:
st.error(f"Error fetching sources: {e}")
st.divider()
# Question input
question = st.text_area("Your Question", placeholder="What would you like to know?", height=100)
col1, col2 = st.columns([3, 1])
with col1:
top_k = st.slider("Number of results", 5, 20, 10)
with col2:
ask_btn = st.button("πŸ” Ask", use_container_width=True, type="primary")
if ask_btn and question:
with st.spinner("πŸ” Searching across all sources..."):
try:
# Measure total request time
request_start_time = time.time()
response = requests.post(
f"{api_url}/v1/ask/",
headers={"X-Tenant-ID": st.session_state.tenant_id},
json={"query": question, "top_k": top_k}
)
# Calculate total elapsed time
total_elapsed_ms = (time.time() - request_start_time) * 1000
if response.status_code == 200:
data = response.json()
# Display answer
st.markdown("### πŸ’‘ Answer")
st.markdown(f"**{data['answer']}**")
st.divider()
# ===== ENHANCED LATENCY METRICS DISPLAY =====
st.markdown("### ⚑ Performance Metrics")
# Get latency values from backend
retrieval_time = data.get('retrieval_time_ms', 0)
generation_time = data.get('generation_time_ms', 0)
backend_total = retrieval_time + generation_time
network_overhead = total_elapsed_ms - backend_total
# Create 5 columns for detailed metrics
col1, col2, col3, col4, col5 = st.columns(5)
with col1:
st.metric(
"πŸ” Retrieval",
f"{retrieval_time:.0f}ms",
help="Time to search across all sources"
)
with col2:
st.metric(
"πŸ€– Generation",
f"{generation_time:.0f}ms",
help="Time to generate the answer"
)
with col3:
st.metric(
"βš™οΈ Backend Total",
f"{backend_total:.0f}ms",
help="Total backend processing time"
)
with col4:
st.metric(
"🌐 Network",
f"{network_overhead:.0f}ms",
help="Network latency (request + response)"
)
with col5:
# Color-code total time
total_color = "🟒" if total_elapsed_ms < 500 else "🟑" if total_elapsed_ms < 1000 else "πŸ”΄"
st.metric(
f"{total_color} Total Time",
f"{total_elapsed_ms:.0f}ms",
help="End-to-end response time"
)
# Visual latency breakdown
st.markdown("#### Latency Breakdown")
latency_data = {
"Retrieval": retrieval_time,
"Generation": generation_time,
"Network": max(0, network_overhead)
}
# Create a simple bar chart representation
total_for_percentage = sum(latency_data.values())
if total_for_percentage > 0:
breakdown_html = '<div style="display: flex; width: 100%; height: 30px; border-radius: 5px; overflow: hidden;">'
colors = {
"Retrieval": "#1f6feb",
"Generation": "#238636",
"Network": "#da3633"
}
for component, time_ms in latency_data.items():
percentage = (time_ms / total_for_percentage) * 100
color = colors.get(component, "#666")
breakdown_html += f'<div style="background-color: {color}; width: {percentage}%; display: flex; align-items: center; justify-content: center; color: white; font-size: 11px; font-weight: bold;" title="{component}: {time_ms:.0f}ms ({percentage:.1f}%)">{component[:3]}</div>'
breakdown_html += '</div>'
st.markdown(breakdown_html, unsafe_allow_html=True)
# Show percentage breakdown
breakdown_text = " | ".join([
f"{comp}: {(time_ms/total_for_percentage)*100:.1f}% ({time_ms:.0f}ms)"
for comp, time_ms in latency_data.items()
])
st.caption(breakdown_text)
st.divider()
# Display confidence
confidence = data.get('confidence', 0)
confidence_color = "🟒" if confidence > 0.7 else "🟑" if confidence > 0.4 else "πŸ”΄"
st.metric(
f"{confidence_color} Answer Confidence",
f"{confidence:.1%}",
help="Model's confidence in the answer quality"
)
# Display sources used
if 'sources_used' in data and data['sources_used']:
st.markdown("### πŸ“š Sources Used")
sources_html = ""
for source_type in data['sources_used']:
badge_class = f"source-{source_type}"
sources_html += f'<span class="source-badge {badge_class}">{source_type.upper()}</span>'
st.markdown(sources_html, unsafe_allow_html=True)
st.caption(f"Retrieved from {len(data['sources_used'])} different source type(s)")
# Display citations
if data.get('citations'):
st.markdown("### πŸ“– Citations")
for i, citation in enumerate(data['citations'], 1):
with st.expander(f"Citation {i} - {citation.get('source_type', 'unknown').upper()}"):
st.json(citation)
else:
st.error(f"❌ Error: {response.json().get('detail', 'Unknown error')}")
except Exception as e:
st.error(f"❌ Request failed: {e}")
st.exception(e)
# Tab 2: Manage Sources
with tab2:
st.header("Manage Sources")
# Sub-tabs for different source types
source_tab1, source_tab2, source_tab3, source_tab4 = st.tabs(
["πŸ“„ Documents", "🌐 URLs", "πŸŽ₯ YouTube", "πŸ“‹ All Sources"]
)
with source_tab1:
st.subheader("Upload Documents")
st.info("Upload PDF, DOCX, or TXT files")
uploaded_file = st.file_uploader("Choose a file", type=['pdf', 'docx', 'txt'])
if uploaded_file and st.button("Upload & Ingest"):
with st.spinner("Uploading and ingesting..."):
try:
files = {"file": uploaded_file}
response = requests.post(
f"{api_url}/v1/ingest/upload",
headers={"X-Tenant-ID": st.session_state.tenant_id},
files=files
)
if response.status_code == 200:
st.success("βœ… Document uploaded and ingested successfully!")
else:
st.error(f"Upload failed: {response.json().get('detail', 'Unknown error')}")
except Exception as e:
st.error(f"Upload error: {e}")
with source_tab2:
st.subheader("Add Web URLs")
url = st.text_input("Enter URL", placeholder="https://example.com/article")
url_name = st.text_input("Source Name (optional)", placeholder="Company Blog")
if st.button("Add URL") and url:
with st.spinner("Adding URL..."):
try:
response = requests.post(
f"{api_url}/v1/ingest/url",
headers={"X-Tenant-ID": st.session_state.tenant_id},
json={"url": url, "name": url_name or url}
)
if response.status_code == 200:
st.success("βœ… URL added and ingestion started!")
else:
st.error(f"Failed: {response.json().get('detail', 'Unknown error')}")
except Exception as e:
st.error(f"Error: {e}")
with source_tab3:
st.subheader("Add YouTube Videos")
st.info("βœ… Captions-first strategy with transcript upload fallback")
youtube_url = st.text_input("YouTube URL", placeholder="https://youtube.com/watch?v=...")
video_name = st.text_input("Video Name (optional)")
transcript_file = st.file_uploader(
"Upload Transcript (optional)",
type=['vtt', 'srt', 'txt'],
help="If video has no captions, upload a transcript file"
)
owner_attestation = st.checkbox(
"I attest that I have rights to use this content",
help="Required: Confirm you own or have permission to use this video"
)
if st.button("Add YouTube Video") and youtube_url and owner_attestation:
with st.spinner("Adding YouTube video..."):
try:
files = {}
if transcript_file:
files['transcript_file'] = transcript_file
data = {
'youtube_url': youtube_url,
'name': video_name or youtube_url,
'owner_attestation': str(owner_attestation).lower()
}
response = requests.post(
f"{api_url}/v1/sources/youtube",
headers={"X-Tenant-ID": st.session_state.tenant_id},
data=data,
files=files if files else None
)
if response.status_code == 200:
result = response.json()
st.success(f"βœ… YouTube source added! Strategy: {result.get('strategy', 'captions_api')}")
else:
st.error(f"Failed: {response.json().get('detail', 'Unknown error')}")
except Exception as e:
st.error(f"Error: {e}")
with source_tab4:
st.subheader("All Sources")
try:
response = requests.get(
f"{api_url}/v1/sources/",
headers={"X-Tenant-ID": st.session_state.tenant_id}
)
if response.status_code == 200:
sources = response.json()
if sources:
for source in sources:
with st.expander(f"{source['name']} ({source['type']})"):
st.write(f"**ID:** {source['id']}")
st.write(f"**Type:** {source['type']}")
st.write(f"**URI:** {source['uri']}")
st.write(f"**Created:** {source.get('created_at', 'N/A')}")
# Toggle source
col1, col2 = st.columns(2)
with col1:
if st.button(f"Disable", key=f"disable_{source['id']}"):
requests.put(
f"{api_url}/v1/sources/{source['id']}/toggle?enabled=false",
headers={"X-Tenant-ID": st.session_state.tenant_id}
)
st.rerun()
with col2:
if st.button(f"Enable", key=f"enable_{source['id']}"):
requests.put(
f"{api_url}/v1/sources/{source['id']}/toggle?enabled=true",
headers={"X-Tenant-ID": st.session_state.tenant_id}
)
st.rerun()
else:
st.info("No sources yet. Add sources using the tabs above.")
except Exception as e:
st.error(f"Error loading sources: {e}")
# Tab 3: Cloud Connectors
with tab3:
st.header("☁️ Cloud Connectors")
st.write("Connect to your enterprise cloud storage to sync documents directly.")
# Connect New Account
st.subheader("Connect New Account")
col1, col2, col3 = st.columns(3)
with col1:
if st.button("πŸ”Œ Connect Google Drive", use_container_width=True):
try:
res = requests.get(f"{api_url}/v1/connectors/gdrive/oauth/start?tenant_id={st.session_state.tenant_id}")
if res.status_code == 200:
st.markdown(f"[Click here to authorize Google Drive]({res.json()['auth_url']})")
else:
st.error("Failed to start GDrive OAuth")
except Exception as e:
st.error(f"Error: {e}")
with col2:
if st.button("πŸ”Œ Connect OneDrive", use_container_width=True):
try:
res = requests.get(f"{api_url}/v1/connectors/onedrive/oauth/start?tenant_id={st.session_state.tenant_id}")
if res.status_code == 200:
st.markdown(f"[Click here to authorize OneDrive]({res.json()['auth_url']})")
else:
st.error("Failed to start OneDrive OAuth")
except Exception as e:
st.error(f"Error: {e}")
with col3:
if st.button("πŸ”Œ Connect Dropbox", use_container_width=True):
try:
res = requests.get(f"{api_url}/v1/connectors/dropbox/oauth/start?tenant_id={st.session_state.tenant_id}")
if res.status_code == 200:
st.markdown(f"[Click here to authorize Dropbox]({res.json()['auth_url']})")
else:
st.error("Failed to start Dropbox OAuth")
except Exception as e:
st.error(f"Error: {e}")
st.divider()
# Manage Existing Connectors
st.subheader("Your Connected Accounts")
try:
response = requests.get(
f"{api_url}/v1/connectors/",
params={"tenant_id": st.session_state.tenant_id}
)
if response.status_code == 200:
connectors = response.json()
if connectors:
for conn in connectors:
with st.expander(f"{conn['provider'].upper()} - {conn['display_name']}"):
col_a, col_b, col_c = st.columns([2, 1, 1])
with col_a:
st.write(f"**Status:** {conn['sync_status'].upper()}")
if conn['last_sync']:
st.write(f"**Last Sync:** {conn['last_sync']}")
with col_b:
if st.button("πŸ”„ Trigger Sync", key=f"sync_{conn['id']}"):
sync_res = requests.post(f"{api_url}/v1/connectors/{conn['id']}/sync")
if sync_res.status_code == 200:
st.success("Sync triggered!")
else:
st.error("Failed to trigger sync")
with col_c:
if st.button("πŸ“ Manage Targets", key=f"targets_{conn['id']}"):
st.session_state[f"show_targets_{conn['id']}"] = True
# Target Selection Mock/Simplified
if st.session_state.get(f"show_targets_{conn['id']}"):
st.info("Loading roots...")
roots_res = requests.get(f"{api_url}/v1/connectors/{conn['id']}/roots")
if roots_res.status_code == 200:
roots = roots_res.json()
target_ids = st.multiselect(
"Select Folders/Drives to sync",
options=[r['id'] for r in roots],
format_func=lambda x: next(r['name'] for r in roots if r['id'] == x),
key=f"ms_{conn['id']}"
)
if st.button("Save Targets", key=f"save_{conn['id']}"):
selected_targets = [{"id": tid, "type": "folder"} for tid in target_ids]
save_res = requests.post(
f"{api_url}/v1/connectors/{conn['id']}/targets",
params={"tenant_id": st.session_state.tenant_id},
json=selected_targets
)
if save_res.status_code == 200:
st.success("Targets updated!")
st.session_state[f"show_targets_{conn['id']}"] = False
st.rerun()
else:
st.info("No cloud accounts connected yet.")
except Exception as e:
st.error(f"Error fetching connectors: {e}")
# Tab 4: Metrics
with tab4:
st.header("πŸ“Š Your Metrics")
# Time range selector
hours = st.selectbox("Time Range", [1, 6, 12, 24, 48, 168], index=3, format_func=lambda x: f"Last {x} hours")
if st.button("Refresh Metrics"):
try:
response = requests.get(
f"{api_url}/v1/tenants/{st.session_state.tenant_id}/metrics?hours={hours}",
headers={"X-Tenant-ID": st.session_state.tenant_id}
)
if response.status_code == 200:
metrics = response.json()
counters = metrics.get('counters', {})
# Display key metrics
col1, col2, col3, col4 = st.columns(4)
with col1:
st.metric("Total Retrievals", counters.get('retrieval_total', 0))
with col2:
st.metric("Total Generations", counters.get('generation_total', 0))
with col3:
st.metric("Sources Created", counters.get('sources_created', 0))
with col4:
st.metric("Ingestions", counters.get('ingestion_total', 0))
st.divider()
# Show all counters
st.subheader("All Metrics")
if counters:
df = pd.DataFrame(list(counters.items()), columns=['Metric', 'Value'])
st.dataframe(df, use_container_width=True)
else:
st.info("No metrics data yet. Start using the platform to see metrics!")
# Time series data
if metrics.get('time_series'):
st.subheader("Time Series Data")
for metric_name, data_points in metrics['time_series'].items():
if data_points:
st.write(f"**{metric_name}**")
df = pd.DataFrame(data_points)
st.line_chart(df.set_index('timestamp')['value'])
else:
st.error(f"Failed to load metrics: {response.json().get('detail', 'Unknown error')}")
except Exception as e:
st.error(f"Error: {e}")
st.divider()
# Cache statistics
st.subheader("πŸš€ Performance Metrics")
try:
response = requests.get(f"{api_url}/v1/metrics/cache-stats")
if response.status_code == 200:
cache_stats = response.json()
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Cache Hit Rate", f"{cache_stats.get('hit_rate', 0):.1%}")
with col2:
st.metric("Cache Size", f"{cache_stats.get('cache_size', 0)}/{cache_stats.get('cache_max_size', 0)}")
with col3:
st.metric("Model Load Time", f"{cache_stats.get('model_load_time_seconds', 0):.1f}s")
except Exception as e:
st.warning(f"Cache stats unavailable: {e}")
# Tab 5: About
with tab5:
st.header("ℹ️ About RAG SaaS Platform")
st.markdown("""
### Features
βœ… **Multi-Source Retrieval** - Unified search across documents, URLs, and YouTube videos
βœ… **Enterprise Cloud Connectors** - Native sync with Google Drive, OneDrive, and Dropbox
βœ… **Tenant Isolation** - Complete data and metrics separation per tenant
βœ… **Optimized Performance** - Model caching with 70-90% latency reduction
βœ… **Compliant YouTube** - Captions-first with transcript upload fallback
βœ… **Comprehensive Metrics** - Track all your usage and performance
### How It Works
1. **Add Sources** - Upload documents, add URLs, or YouTube videos
2. **Ask Questions** - The system retrieves from ALL your sources
3. **Get Answers** - Receive answers with citations from multiple sources
4. **Track Metrics** - Monitor your usage and performance
### Technical Details
- **Backend:** FastAPI with SQLite/PostgreSQL
- **Vector Store:** Local (upgradeable to Qdrant/Pgvector)
- **Models:** Local CPU-optimized (Qwen 0.5B + MiniLM)
- **Retrieval:** RRF fusion with diversity constraints
""")