"""RAG SaaS Platform - Streamlit UI.

Single-page Streamlit front end for the RAG SaaS backend: tenant
login/registration, question answering with latency metrics, source
management (documents, URLs, YouTube), cloud connectors and per-tenant
metrics.

NOTE(review): many emoji in the UI labels below (e.g. "š", "āļø", "ā ļø")
appear mis-encoded in this copy of the file. They are reproduced unchanged
because the intended glyphs cannot be recovered reliably; restore them from
the original UTF-8 source (this includes ``page_icon``).
"""

import html
import time

import pandas as pd
import requests
import streamlit as st

# Backend location pre-filled in the sidebar on first load.
DEFAULT_API_URL = "http://localhost:7860"
# Seconds to wait for ordinary API calls. Streamlit re-runs the whole script
# (and therefore every tab's fetches) on each interaction, so an unbounded
# wait on a hung backend would freeze the entire UI.
REQUEST_TIMEOUT = 30
# Question answering and ingestion may run inference / parsing on the backend,
# so they get a more generous limit.
LONG_REQUEST_TIMEOUT = 300

# Page Configuration - MUST be first Streamlit command
st.set_page_config(
    page_title="RAG SaaS Platform",
    page_icon="š",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Minimal CSS: only the source badges rendered by _render_source_badges().
# Each badge also carries a ``source-<type>`` class so per-type colours can be
# added here later.
st.markdown(
    """
<style>
.source-badge {
    display: inline-block;
    padding: 2px 10px;
    margin: 2px 6px 2px 0;
    border-radius: 12px;
    font-size: 0.8rem;
    font-weight: 600;
    color: #ffffff;
    background-color: #4b5563;
}
</style>
""",
    unsafe_allow_html=True,
)

# Initialize session state (survives reruns within one browser session).
_SESSION_DEFAULTS = {
    "tenant_id": None,
    "access_token": None,
    "api_url": DEFAULT_API_URL,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# (provider slug used in the API path, button label, name shown in the
# authorize link, name shown in the failure message)
_CONNECTOR_PROVIDERS = (
    ("gdrive", "š Connect Google Drive", "Google Drive", "GDrive"),
    ("onedrive", "š Connect OneDrive", "OneDrive", "OneDrive"),
    ("dropbox", "š Connect Dropbox", "Dropbox", "Dropbox"),
)

# Segment colours for the stacked latency bar (keys match latency_data below).
_LATENCY_COLORS = {
    "Retrieval": "#4e79a7",
    "Generation": "#f28e2b",
    "Network": "#9c9c9c",
}

_ABOUT_MARKDOWN = """
### Features

ā **Multi-Source Retrieval** - Unified search across documents, URLs, and YouTube videos

ā **Enterprise Cloud Connectors** - Native sync with Google Drive, OneDrive, and Dropbox

ā **Tenant Isolation** - Complete data and metrics separation per tenant

ā **Optimized Performance** - Model caching with 70-90% latency reduction

ā **Compliant YouTube** - Captions-first with transcript upload fallback

ā **Comprehensive Metrics** - Track all your usage and performance

### How It Works

1. **Add Sources** - Upload documents, add URLs, or YouTube videos
2. **Ask Questions** - The system retrieves from ALL your sources
3. **Get Answers** - Receive answers with citations from multiple sources
4. **Track Metrics** - Monitor your usage and performance

### Technical Details

- **Backend:** FastAPI with SQLite/PostgreSQL
- **Vector Store:** Local (upgradeable to Qdrant/Pgvector)
- **Models:** Local CPU-optimized (Qwen 0.5B + MiniLM)
- **Retrieval:** RRF fusion with diversity constraints
"""


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _tenant_headers():
    """Return the headers that scope a backend request to the current tenant."""
    # NOTE(review): the access token returned by /v1/auth/login is stored in
    # session state but never sent; confirm whether the backend expects an
    # Authorization header in addition to X-Tenant-ID.
    return {"X-Tenant-ID": st.session_state.tenant_id}


def _error_detail(response):
    """Return a human-readable message for a failed API response.

    FastAPI-style errors carry a JSON body with a ``detail`` field. Proxies
    and crashed workers can instead return HTML or an empty body; calling
    ``response.json()`` on those would raise and be mis-reported as a
    connection error, so fall back to the raw text / status code.
    """
    try:
        payload = response.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        text = (response.text or "").strip()
        return text[:200] if text else f"HTTP {response.status_code}"
    if isinstance(payload, dict):
        return payload.get("detail", "Unknown error")
    return "Unknown error"


def _as_upload(uploaded_file):
    """Convert a Streamlit ``UploadedFile`` into a ``requests`` multipart tuple.

    Sending ``getvalue()`` (rather than the file object) does not depend on
    the object's current read position, and the tuple form forwards both the
    filename and the browser-reported MIME type.
    """
    return (
        uploaded_file.name,
        uploaded_file.getvalue(),
        uploaded_file.type or "application/octet-stream",
    )


def _render_source_badges(source_types):
    """Render one styled badge per source type.

    The type names come from the backend and are injected into markup shown
    with ``unsafe_allow_html``, so they are HTML-escaped first.
    """
    badges = " ".join(
        f'<span class="source-badge source-{html.escape(str(source_type))}">'
        f"{html.escape(str(source_type).upper())}</span>"
        for source_type in source_types
    )
    st.markdown(badges, unsafe_allow_html=True)


def _latency_bar_html(latency_data, total_ms):
    """Return HTML for one horizontal bar split proportionally by component.

    ``total_ms`` must be > 0; zero-length components are omitted.
    """
    segments = []
    for component, time_ms in latency_data.items():
        if time_ms <= 0:
            continue
        width_pct = time_ms / total_ms * 100
        color = _LATENCY_COLORS.get(component, "#cccccc")
        segments.append(
            f'<div title="{component}: {time_ms:.0f}ms" '
            f'style="width:{width_pct:.2f}%;background-color:{color};"></div>'
        )
    return (
        '<div style="display:flex;width:100%;height:16px;'
        'border-radius:8px;overflow:hidden;margin:4px 0 8px 0;">'
        + "".join(segments)
        + "</div>"
    )


# ---------------------------------------------------------------------------
# Sidebar: backend URL + authentication
# ---------------------------------------------------------------------------


def _render_login_form(api_url):
    """Render the tenant login form; on success store credentials and rerun."""
    st.caption("Login with Tenant ID")
    login_tenant_id = st.text_input("Tenant ID", key="login_tenant_id")
    login_username = st.text_input("Username/Email", key="login_username")
    login_password = st.text_input("Password", type="password", key="login_password")

    if not st.button("Login", key="login_btn"):
        return
    try:
        response = requests.post(
            f"{api_url}/v1/auth/login",
            json={
                "tenant_id": login_tenant_id,
                "username": login_username,
                "password": login_password,
            },
            timeout=REQUEST_TIMEOUT,
        )
        if response.status_code == 200:
            data = response.json()
            st.session_state.tenant_id = data["tenant_id"]
            st.session_state.access_token = data["access_token"]
            tenant_name = data.get("tenant_name", data["tenant_id"])
            st.success(f"ā Logged in as {tenant_name}")
            # st.rerun() raises a BaseException-derived control exception,
            # so the ``except Exception`` below does not swallow it.
            st.rerun()
        else:
            st.error(f"Login failed: {_error_detail(response)}")
    except Exception as e:
        st.error(f"Connection error: {e}")


def _render_register_form(api_url):
    """Render the tenant registration form and show the new tenant ID."""
    st.caption("Create New Tenant Account")
    reg_name = st.text_input("Organization Name", key="reg_name")
    reg_email = st.text_input("Admin Email", key="reg_email")
    reg_password = st.text_input("Admin Password", type="password", key="reg_password")

    if not st.button("Create Tenant", key="register_btn"):
        return
    try:
        response = requests.post(
            f"{api_url}/v1/auth/tenants",
            json={
                "name": reg_name,
                "admin_email": reg_email,
                "admin_password": reg_password,
            },
            timeout=REQUEST_TIMEOUT,
        )
        if response.status_code == 200:
            data = response.json()
            st.success("ā Tenant created!")
            st.info(
                f"**Your Tenant ID:** `{data['tenant_id']}`\n\n"
                "ā ļø Save this ID - you'll need it to login!"
            )
        else:
            st.error(f"Registration failed: {_error_detail(response)}")
    except Exception as e:
        st.error(f"Connection error: {e}")


def _render_logged_in_panel():
    """Show the current tenant and a logout button that clears credentials."""
    st.success("ā Logged in")
    st.caption(f"Tenant ID: `{st.session_state.tenant_id}`")
    if st.button("Logout"):
        st.session_state.tenant_id = None
        st.session_state.access_token = None
        st.rerun()


def _render_sidebar():
    """Render the sidebar and return the normalized backend base URL."""
    with st.sidebar:
        st.title("āļø Configuration")

        raw_api_url = st.text_input(
            "Backend API URL",
            value=st.session_state.api_url,
            help="URL of the RAG SaaS backend",
        )
        st.session_state.api_url = raw_api_url
        # A trailing slash would produce "//v1/..." paths when joined below.
        api_url = raw_api_url.strip().rstrip("/")

        st.divider()

        st.subheader("š Authentication")
        if st.session_state.tenant_id:
            _render_logged_in_panel()
        else:
            login_tab, register_tab = st.tabs(["Login", "Register"])
            with login_tab:
                _render_login_form(api_url)
            with register_tab:
                _render_register_form(api_url)
    return api_url


# ---------------------------------------------------------------------------
# Landing page (not logged in)
# ---------------------------------------------------------------------------


def _render_landing_page():
    """Render the logged-out landing page with a short feature overview."""
    st.title("š RAG SaaS Platform")
    st.info("š Please login or create a tenant account to continue")

    st.subheader("Platform Features")
    features = (
        ("### š Multi-Source Retrieval",
         "Unified search across documents, URLs, and YouTube videos"),
        ("### š Tenant Metrics",
         "Isolated metrics tracking for each tenant"),
        ("### ā” Optimized Performance",
         "Model caching with 70-90% latency reduction"),
    )
    for column, (heading, blurb) in zip(st.columns(3), features):
        with column:
            st.markdown(heading)
            st.write(blurb)


# ---------------------------------------------------------------------------
# Tab 1: Ask Questions
# ---------------------------------------------------------------------------


def _render_active_sources(api_url):
    """Show which source types are currently active for this tenant."""
    try:
        response = requests.get(
            f"{api_url}/v1/sources/active",
            headers=_tenant_headers(),
            timeout=REQUEST_TIMEOUT,
        )
        if response.status_code == 200:
            source_types = response.json().get("source_types_present") or []
            if source_types:
                st.success(f"ā Active sources: {', '.join(source_types)}")
                _render_source_badges(source_types)
            else:
                st.warning("ā ļø No active sources. Add sources in the 'Manage Sources' tab.")
        else:
            st.error(f"Error fetching sources: {_error_detail(response)}")
    except Exception as e:
        st.error(f"Error fetching sources: {e}")


def _render_latency_metrics(data, total_elapsed_ms):
    """Render backend/network/total latency metrics and a stacked breakdown bar.

    ``data`` is the /v1/ask/ response; ``total_elapsed_ms`` is the client-side
    wall time for the request.
    """
    st.markdown("### ā” Performance Metrics")

    # ``or 0`` also covers keys that are present but null.
    retrieval_time = data.get("retrieval_time_ms") or 0
    generation_time = data.get("generation_time_ms") or 0
    backend_total = retrieval_time + generation_time
    # Backend timings and the client stopwatch are measured independently, so
    # the difference can dip below zero; never display negative network time.
    network_overhead = max(0, total_elapsed_ms - backend_total)

    if total_elapsed_ms < 500:
        total_color = "š¢"
    elif total_elapsed_ms < 1000:
        total_color = "š”"
    else:
        total_color = "š“"

    metric_specs = (
        ("š Retrieval", retrieval_time, "Time to search across all sources"),
        ("š¤ Generation", generation_time, "Time to generate the answer"),
        ("āļø Backend Total", backend_total, "Total backend processing time"),
        ("š Network", network_overhead, "Network latency (request + response)"),
        (f"{total_color} Total Time", total_elapsed_ms, "End-to-end response time"),
    )
    for column, (label, value_ms, help_text) in zip(st.columns(5), metric_specs):
        with column:
            st.metric(label, f"{value_ms:.0f}ms", help=help_text)

    st.markdown("#### Latency Breakdown")
    latency_data = {
        "Retrieval": retrieval_time,
        "Generation": generation_time,
        "Network": network_overhead,
    }
    total_for_percentage = sum(latency_data.values())
    if total_for_percentage > 0:
        st.markdown(
            _latency_bar_html(latency_data, total_for_percentage),
            unsafe_allow_html=True,
        )
        st.caption(" | ".join(
            f"{comp}: {(time_ms / total_for_percentage) * 100:.1f}% ({time_ms:.0f}ms)"
            for comp, time_ms in latency_data.items()
        ))


def _render_answer(data, total_elapsed_ms):
    """Render the answer, latency metrics, confidence, sources and citations."""
    st.markdown("### š” Answer")
    st.markdown(f"**{data['answer']}**")

    st.divider()
    _render_latency_metrics(data, total_elapsed_ms)
    st.divider()

    confidence = data.get("confidence") or 0
    if confidence > 0.7:
        confidence_color = "š¢"
    elif confidence > 0.4:
        confidence_color = "š”"
    else:
        confidence_color = "š“"
    st.metric(
        f"{confidence_color} Answer Confidence",
        f"{confidence:.1%}",
        help="Model's confidence in the answer quality",
    )

    sources_used = data.get("sources_used")
    if sources_used:
        st.markdown("### š Sources Used")
        _render_source_badges(sources_used)
        st.caption(f"Retrieved from {len(sources_used)} different source type(s)")

    citations = data.get("citations")
    if citations:
        st.markdown("### š Citations")
        for i, citation in enumerate(citations, 1):
            source_type = str(citation.get("source_type") or "unknown")
            with st.expander(f"Citation {i} - {source_type.upper()}"):
                st.json(citation)


def _render_ask_tab(api_url):
    """Render the question box and, on submit, query /v1/ask/ and show results."""
    st.header("Ask Questions")
    _render_active_sources(api_url)
    st.divider()

    question = st.text_area(
        "Your Question", placeholder="What would you like to know?", height=100
    )
    slider_col, button_col = st.columns([3, 1])
    with slider_col:
        top_k = st.slider("Number of results", 5, 20, 10)
    with button_col:
        ask_clicked = st.button("š Ask", use_container_width=True, type="primary")

    if not ask_clicked:
        return
    if not question.strip():
        st.warning("Please enter a question first.")
        return

    with st.spinner("š Searching across all sources..."):
        try:
            # perf_counter is monotonic, unlike time.time(), so clock
            # adjustments cannot distort the measured duration.
            request_start = time.perf_counter()
            response = requests.post(
                f"{api_url}/v1/ask/",
                headers=_tenant_headers(),
                json={"query": question, "top_k": top_k},
                timeout=LONG_REQUEST_TIMEOUT,
            )
            total_elapsed_ms = (time.perf_counter() - request_start) * 1000

            if response.status_code == 200:
                _render_answer(response.json(), total_elapsed_ms)
            else:
                st.error(f"ā Error: {_error_detail(response)}")
        except Exception as e:
            st.error(f"ā Request failed: {e}")
            st.exception(e)


# ---------------------------------------------------------------------------
# Tab 2: Manage Sources
# ---------------------------------------------------------------------------


def _render_document_upload(api_url):
    """Upload a PDF/DOCX/TXT file to /v1/ingest/upload."""
    st.subheader("Upload Documents")
    st.info("Upload PDF, DOCX, or TXT files")
    uploaded_file = st.file_uploader("Choose a file", type=["pdf", "docx", "txt"])

    # The button is only rendered once a file has been chosen.
    if uploaded_file and st.button("Upload & Ingest"):
        with st.spinner("Uploading and ingesting..."):
            try:
                response = requests.post(
                    f"{api_url}/v1/ingest/upload",
                    headers=_tenant_headers(),
                    files={"file": _as_upload(uploaded_file)},
                    timeout=LONG_REQUEST_TIMEOUT,
                )
                if response.status_code == 200:
                    st.success("ā Document uploaded and ingested successfully!")
                else:
                    st.error(f"Upload failed: {_error_detail(response)}")
            except Exception as e:
                st.error(f"Upload error: {e}")


def _render_url_form(api_url):
    """Register a web URL as a source via /v1/ingest/url."""
    st.subheader("Add Web URLs")
    url = st.text_input("Enter URL", placeholder="https://example.com/article").strip()
    url_name = st.text_input("Source Name (optional)", placeholder="Company Blog")

    if not st.button("Add URL"):
        return
    if not url:
        st.warning("Please enter a URL first.")
        return

    with st.spinner("Adding URL..."):
        try:
            response = requests.post(
                f"{api_url}/v1/ingest/url",
                headers=_tenant_headers(),
                json={"url": url, "name": url_name or url},
                timeout=LONG_REQUEST_TIMEOUT,
            )
            if response.status_code == 200:
                st.success("ā URL added and ingestion started!")
            else:
                st.error(f"Failed: {_error_detail(response)}")
        except Exception as e:
            st.error(f"Error: {e}")


def _render_youtube_form(api_url):
    """Register a YouTube video (optionally with a transcript file) as a source."""
    st.subheader("Add YouTube Videos")
    st.info("ā Captions-first strategy with transcript upload fallback")

    youtube_url = st.text_input(
        "YouTube URL", placeholder="https://youtube.com/watch?v=..."
    ).strip()
    video_name = st.text_input("Video Name (optional)")
    transcript_file = st.file_uploader(
        "Upload Transcript (optional)",
        type=["vtt", "srt", "txt"],
        help="If video has no captions, upload a transcript file",
    )
    owner_attestation = st.checkbox(
        "I attest that I have rights to use this content",
        help="Required: Confirm you own or have permission to use this video",
    )

    if not st.button("Add YouTube Video"):
        return
    if not youtube_url:
        st.warning("Please enter a YouTube URL first.")
        return
    if not owner_attestation:
        st.warning("Please confirm you have rights to use this content.")
        return

    with st.spinner("Adding YouTube video..."):
        try:
            form_data = {
                "youtube_url": youtube_url,
                "name": video_name or youtube_url,
                "owner_attestation": str(owner_attestation).lower(),
            }
            files = (
                {"transcript_file": _as_upload(transcript_file)}
                if transcript_file
                else None
            )
            response = requests.post(
                f"{api_url}/v1/sources/youtube",
                headers=_tenant_headers(),
                data=form_data,
                files=files,
                timeout=LONG_REQUEST_TIMEOUT,
            )
            if response.status_code == 200:
                result = response.json()
                st.success(
                    f"ā YouTube source added! Strategy: {result.get('strategy', 'captions_api')}"
                )
            else:
                st.error(f"Failed: {_error_detail(response)}")
        except Exception as e:
            st.error(f"Error: {e}")


def _toggle_source(api_url, source_id, enabled):
    """Enable/disable a source; rerun on success, otherwise show the error."""
    response = requests.put(
        f"{api_url}/v1/sources/{source_id}/toggle",
        params={"enabled": "true" if enabled else "false"},
        headers=_tenant_headers(),
        timeout=REQUEST_TIMEOUT,
    )
    if response.ok:
        st.rerun()
    else:
        st.error(f"Failed to update source: {_error_detail(response)}")


def _render_source_card(api_url, source):
    """Render one source's details with Enable/Disable buttons."""
    source_id = source["id"]
    with st.expander(f"{source['name']} ({source['type']})"):
        st.write(f"**ID:** {source_id}")
        st.write(f"**Type:** {source['type']}")
        st.write(f"**URI:** {source['uri']}")
        st.write(f"**Created:** {source.get('created_at', 'N/A')}")

        disable_col, enable_col = st.columns(2)
        with disable_col:
            if st.button("Disable", key=f"disable_{source_id}"):
                _toggle_source(api_url, source_id, enabled=False)
        with enable_col:
            if st.button("Enable", key=f"enable_{source_id}"):
                _toggle_source(api_url, source_id, enabled=True)


def _render_all_sources(api_url):
    """List every source of the tenant."""
    st.subheader("All Sources")
    try:
        response = requests.get(
            f"{api_url}/v1/sources/",
            headers=_tenant_headers(),
            timeout=REQUEST_TIMEOUT,
        )
        if response.status_code != 200:
            st.error(f"Error loading sources: {_error_detail(response)}")
            return
        sources = response.json()
        if not sources:
            st.info("No sources yet. Add sources using the tabs above.")
            return
        for source in sources:
            _render_source_card(api_url, source)
    except Exception as e:
        st.error(f"Error loading sources: {e}")


def _render_sources_tab(api_url):
    """Render the source-management sub-tabs."""
    st.header("Manage Sources")
    docs_tab, urls_tab, youtube_tab, all_tab = st.tabs(
        ["š Documents", "š URLs", "š„ YouTube", "š All Sources"]
    )
    with docs_tab:
        _render_document_upload(api_url)
    with urls_tab:
        _render_url_form(api_url)
    with youtube_tab:
        _render_youtube_form(api_url)
    with all_tab:
        _render_all_sources(api_url)


# ---------------------------------------------------------------------------
# Tab 3: Cloud Connectors
# ---------------------------------------------------------------------------


def _render_connect_button(api_url, slug, button_label, display_name, short_name):
    """Start the OAuth flow for one provider and show the authorization link."""
    if not st.button(button_label, use_container_width=True):
        return
    try:
        res = requests.get(
            f"{api_url}/v1/connectors/{slug}/oauth/start",
            params={"tenant_id": st.session_state.tenant_id},
            headers=_tenant_headers(),
            timeout=REQUEST_TIMEOUT,
        )
        if res.status_code == 200:
            st.markdown(f"[Click here to authorize {display_name}]({res.json()['auth_url']})")
        else:
            st.error(f"Failed to start {short_name} OAuth")
    except Exception as e:
        st.error(f"Error: {e}")


def _render_target_picker(api_url, conn_id, show_targets_key):
    """Let the user choose which folders/drives of a connector to sync."""
    with st.spinner("Loading roots..."):
        roots_res = requests.get(
            f"{api_url}/v1/connectors/{conn_id}/roots",
            headers=_tenant_headers(),
            timeout=REQUEST_TIMEOUT,
        )
    if roots_res.status_code != 200:
        st.error(f"Failed to load folders: {_error_detail(roots_res)}")
        return

    name_by_id = {root["id"]: root["name"] for root in roots_res.json()}
    target_ids = st.multiselect(
        "Select Folders/Drives to sync",
        options=list(name_by_id),
        format_func=lambda root_id: name_by_id[root_id],
        key=f"ms_{conn_id}",
    )

    if st.button("Save Targets", key=f"save_{conn_id}"):
        # NOTE(review): every target is sent as type "folder", even drives
        # (the original marked this section "Mock/Simplified").
        selected_targets = [{"id": tid, "type": "folder"} for tid in target_ids]
        save_res = requests.post(
            f"{api_url}/v1/connectors/{conn_id}/targets",
            params={"tenant_id": st.session_state.tenant_id},
            headers=_tenant_headers(),
            json=selected_targets,
            timeout=REQUEST_TIMEOUT,
        )
        if save_res.status_code == 200:
            st.success("Targets updated!")
            st.session_state[show_targets_key] = False
            st.rerun()
        else:
            st.error(f"Failed to save targets: {_error_detail(save_res)}")


def _render_connector_card(api_url, conn):
    """Render one connected account with sync and target-management controls."""
    conn_id = conn["id"]
    show_targets_key = f"show_targets_{conn_id}"
    with st.expander(f"{conn['provider'].upper()} - {conn['display_name']}"):
        status_col, sync_col, targets_col = st.columns([2, 1, 1])
        with status_col:
            st.write(f"**Status:** {conn['sync_status'].upper()}")
            if conn.get("last_sync"):
                st.write(f"**Last Sync:** {conn['last_sync']}")
        with sync_col:
            if st.button("š Trigger Sync", key=f"sync_{conn_id}"):
                sync_res = requests.post(
                    f"{api_url}/v1/connectors/{conn_id}/sync",
                    headers=_tenant_headers(),
                    timeout=REQUEST_TIMEOUT,
                )
                if sync_res.status_code == 200:
                    st.success("Sync triggered!")
                else:
                    st.error("Failed to trigger sync")
        with targets_col:
            if st.button("š Manage Targets", key=f"targets_{conn_id}"):
                st.session_state[show_targets_key] = True

        if st.session_state.get(show_targets_key):
            _render_target_picker(api_url, conn_id, show_targets_key)


def _render_connectors_tab(api_url):
    """Render OAuth connect buttons and the list of connected accounts."""
    st.header("āļø Cloud Connectors")
    st.write("Connect to your enterprise cloud storage to sync documents directly.")

    st.subheader("Connect New Account")
    for column, provider in zip(st.columns(3), _CONNECTOR_PROVIDERS):
        with column:
            _render_connect_button(api_url, *provider)

    st.divider()

    st.subheader("Your Connected Accounts")
    try:
        response = requests.get(
            f"{api_url}/v1/connectors/",
            params={"tenant_id": st.session_state.tenant_id},
            headers=_tenant_headers(),
            timeout=REQUEST_TIMEOUT,
        )
        if response.status_code != 200:
            st.error(f"Error fetching connectors: {_error_detail(response)}")
            return
        connectors = response.json()
        if not connectors:
            st.info("No cloud accounts connected yet.")
            return
        for conn in connectors:
            _render_connector_card(api_url, conn)
    except Exception as e:
        st.error(f"Error fetching connectors: {e}")


# ---------------------------------------------------------------------------
# Tab 4: Metrics
# ---------------------------------------------------------------------------


def _render_tenant_metrics(api_url, hours):
    """Fetch and display the tenant's counters and time series for ``hours``."""
    try:
        response = requests.get(
            f"{api_url}/v1/tenants/{st.session_state.tenant_id}/metrics",
            params={"hours": hours},
            headers=_tenant_headers(),
            timeout=REQUEST_TIMEOUT,
        )
        if response.status_code != 200:
            st.error(f"Failed to load metrics: {_error_detail(response)}")
            return

        metrics = response.json()
        counters = metrics.get("counters") or {}

        headline = (
            ("Total Retrievals", "retrieval_total"),
            ("Total Generations", "generation_total"),
            ("Sources Created", "sources_created"),
            ("Ingestions", "ingestion_total"),
        )
        for column, (label, counter_key) in zip(st.columns(4), headline):
            with column:
                st.metric(label, counters.get(counter_key, 0))

        st.divider()

        st.subheader("All Metrics")
        if counters:
            df = pd.DataFrame(list(counters.items()), columns=["Metric", "Value"])
            st.dataframe(df, use_container_width=True)
        else:
            st.info("No metrics data yet. Start using the platform to see metrics!")

        time_series = metrics.get("time_series")
        if time_series:
            st.subheader("Time Series Data")
            for metric_name, data_points in time_series.items():
                if data_points:
                    st.write(f"**{metric_name}**")
                    # assumes each point has "timestamp" and "value" keys
                    series_df = pd.DataFrame(data_points)
                    st.line_chart(series_df.set_index("timestamp")["value"])
    except Exception as e:
        st.error(f"Error: {e}")


def _render_cache_stats(api_url):
    """Show the backend's model-cache statistics (best effort)."""
    st.subheader("š Performance Metrics")
    try:
        response = requests.get(
            f"{api_url}/v1/metrics/cache-stats", timeout=REQUEST_TIMEOUT
        )
        if response.status_code == 200:
            cache_stats = response.json()
            hit_col, size_col, load_col = st.columns(3)
            with hit_col:
                st.metric("Cache Hit Rate", f"{cache_stats.get('hit_rate', 0):.1%}")
            with size_col:
                st.metric(
                    "Cache Size",
                    f"{cache_stats.get('cache_size', 0)}/{cache_stats.get('cache_max_size', 0)}",
                )
            with load_col:
                st.metric(
                    "Model Load Time",
                    f"{cache_stats.get('model_load_time_seconds', 0):.1f}s",
                )
    except Exception as e:
        st.warning(f"Cache stats unavailable: {e}")


def _render_metrics_tab(api_url):
    """Render the metrics tab: on-demand tenant metrics plus cache stats."""
    st.header("š Your Metrics")
    hours = st.selectbox(
        "Time Range",
        [1, 6, 12, 24, 48, 168],
        index=3,
        format_func=lambda x: f"Last {x} hours",
    )
    if st.button("Refresh Metrics"):
        _render_tenant_metrics(api_url, hours)

    st.divider()
    _render_cache_stats(api_url)


# ---------------------------------------------------------------------------
# Tab 5: About + main layout
# ---------------------------------------------------------------------------


def _render_about_tab():
    """Render the static product overview."""
    st.header("ā¹ļø About RAG SaaS Platform")
    st.markdown(_ABOUT_MARKDOWN)


def _render_main_interface(api_url):
    """Render the logged-in application with its five tabs."""
    st.title("š RAG SaaS Platform")
    ask_tab, sources_tab, connectors_tab, metrics_tab, about_tab = st.tabs(
        ["š¬ Ask Questions", "š Manage Sources", "āļø Cloud Connectors", "š Metrics", "ā¹ļø About"]
    )
    with ask_tab:
        _render_ask_tab(api_url)
    with sources_tab:
        _render_sources_tab(api_url)
    with connectors_tab:
        _render_connectors_tab(api_url)
    with metrics_tab:
        _render_metrics_tab(api_url)
    with about_tab:
        _render_about_tab()


# Streamlit executes this script top to bottom on every interaction.
api_url = _render_sidebar()
if st.session_state.tenant_id:
    _render_main_interface(api_url)
else:
    _render_landing_page()