#!/usr/bin/env python3
"""
Initialize the RAG system by creating embeddings and a FAISS index.
"""
import hashlib
import sqlite3
import sys
from pathlib import Path
from typing import List

# Add project root to Python path
sys.path.insert(0, str(Path(__file__).parent.parent))

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

from config import CHUNK_OVERLAP, CHUNK_SIZE, DATA_DIR, EMBEDDING_MODEL
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> List[str]:
    """Split text into word-based chunks of up to chunk_size words,
    with consecutive chunks sharing `overlap` words."""
    words = text.split()
    chunks = []
    step = max(1, chunk_size - overlap)  # guard against overlap >= chunk_size
    for i in range(0, len(words), step):
        chunk = " ".join(words[i:i + chunk_size])
        chunks.append(chunk)
        if i + chunk_size >= len(words):
            break
    return chunks
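# Example: with chunk_size=200 and overlap=50 the step is 150 words, so a
# 500-word document yields three chunks covering words 0-199, 150-349,
# and 300-499.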
def initialize_rag():
    """Initialize the RAG system with sample data."""
    print("Initializing RAG system...")

    # Load embedding model
    print(f"Loading embedding model: {EMBEDDING_MODEL}")
    embedder = SentenceTransformer(EMBEDDING_MODEL)
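    # EMBEDDING_MODEL is read from config; a typical choice would be a
    # sentence-transformers checkpoint such as "all-MiniLM-L6-v2" (384
    # dimensions), but that exact name is an assumption, not pinned here.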
    # Collect all documents
    documents = []
    doc_ids = []
    chunk_metadata = []

    # First, check if we have documents
    md_files = list(DATA_DIR.glob("*.md"))
    txt_files = list(DATA_DIR.glob("*.txt"))
    if not md_files and not txt_files:
        print("No documents found. Running download_sample_data.py first...")
        # Try to create sample data
        from scripts.download_sample_data import download_sample_data
        download_sample_data()
        # Refresh file list
        md_files = list(DATA_DIR.glob("*.md"))
        txt_files = list(DATA_DIR.glob("*.txt"))

    print(f"Found {len(md_files)} .md files and {len(txt_files)} .txt files")
    for file_path, file_type in (
        [(p, 'markdown') for p in md_files] + [(p, 'text') for p in txt_files]
    ):
        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()
        chunks = chunk_text(content)
        documents.extend(chunks)
        doc_ids.extend([file_path.name] * len(chunks))
        for j in range(len(chunks)):
            chunk_metadata.append({
                'doc_id': file_path.name,
                'chunk_index': j,
                'file_type': file_type,
            })
| print(f"Found {len(documents)} chunks from {len(set(doc_ids))} documents") | |
| if not documents: | |
| print("ERROR: No documents found. Please add documents to the data/ directory first.") | |
| return | |
    # Create embeddings
    print("Creating embeddings...")
    embeddings = embedder.encode(documents, show_progress_bar=True, batch_size=32)

    # Create FAISS index
    print("Creating FAISS index...")
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)  # L2 distance
    index.add(embeddings.astype(np.float32))
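    # IndexFlatL2 performs exact brute-force search, which is fine at this
    # scale; for cosine similarity, normalize the embeddings and use
    # IndexFlatIP instead.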
    # Save FAISS index
    faiss_index_path = DATA_DIR / "faiss_index.bin"
    faiss.write_index(index, str(faiss_index_path))
    print(f"Saved FAISS index to {faiss_index_path}")
    # Create document store (SQLite)
    print("Creating document store...")
    conn = sqlite3.connect(DATA_DIR / "docstore.db")
    cursor = conn.cursor()

    # Create tables
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS chunks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            chunk_text TEXT NOT NULL,
            doc_id TEXT NOT NULL,
            chunk_hash TEXT UNIQUE NOT NULL,
            embedding_hash TEXT,
            chunk_index INTEGER,
            file_type TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """)
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS embedding_cache (
            text_hash TEXT PRIMARY KEY,
            embedding BLOB NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            access_count INTEGER DEFAULT 0
        )
    """)
    # Insert chunks
    inserted_count = 0
    for chunk, doc_id, metadata in zip(documents, doc_ids, chunk_metadata):
        chunk_hash = hashlib.md5(chunk.encode()).hexdigest()
        try:
            cursor.execute(
                """INSERT INTO chunks
                   (chunk_text, doc_id, chunk_hash, chunk_index, file_type)
                   VALUES (?, ?, ?, ?, ?)""",
                (chunk, doc_id, chunk_hash, metadata['chunk_index'], metadata['file_type'])
            )
            inserted_count += 1
        except sqlite3.IntegrityError:
            # Skip duplicates
            pass
    conn.commit()
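    # (For large corpora, executemany with "INSERT OR IGNORE" would be faster
    # than per-row exception handling; the simple loop is kept for clarity.)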
    # Create indexes for performance
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_chunk_hash ON chunks(chunk_hash)")
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_doc_id ON chunks(doc_id)")
    conn.commit()
    conn.close()
    print(f"Saved {inserted_count} chunks to document store")
    # Also create embedding_cache.db if it doesn't exist
    cache_path = DATA_DIR / "embedding_cache.db"
    if not cache_path.exists():
        conn = sqlite3.connect(cache_path)
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS embedding_cache (
                text_hash TEXT PRIMARY KEY,
                embedding BLOB NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                access_count INTEGER DEFAULT 0
            )
        """)
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_created_at ON embedding_cache(created_at)")
        conn.commit()
        conn.close()
        print(f"Created embedding cache at {cache_path}")
| print("\nRAG system initialized successfully!") | |
| print(f"FAISS index: {faiss_index_path}") | |
| print(f"Document store: {DATA_DIR / 'docstore.db'}") | |
| print(f"Embedding cache: {DATA_DIR / 'embedding_cache.db'}") | |
| print(f"Total chunks: {len(documents)}") | |
| print(f"Embedding dimension: {dimension}") | |
| print("\nYou can now start the API server with: python -m app.main") | |
if __name__ == "__main__":
    initialize_rag()
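# A minimal retrieval sketch over the artifacts written above (illustration
# only; assumes the same config module and that the index already exists):
#
#   import faiss
#   from sentence_transformers import SentenceTransformer
#   from config import DATA_DIR, EMBEDDING_MODEL
#
#   index = faiss.read_index(str(DATA_DIR / "faiss_index.bin"))
#   embedder = SentenceTransformer(EMBEDDING_MODEL)
#   query_vec = embedder.encode(["example query"]).astype("float32")
#   distances, positions = index.search(query_vec, 5)  # top-5 nearest chunks
#   # positions index into the order chunks were added to the FAISS index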