import html
import os
import re
import sys

import gradio as gr
import numpy as np
from datasets import load_dataset
from pygments import highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import get_lexer_by_name, guess_lexer
from sentence_transformers import SentenceTransformer

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from shared.components import create_method_panel, create_premium_hero

# Code-specific embedding model (CodeBERT is trained on code + natural language).
embedder = SentenceTransformer('microsoft/codebert-base')

# Module-level state shared by the search callbacks below.
dataset_sample = None   # streaming dataset handle (kept for reuse/debugging)
embeddings = None       # np.ndarray of snippet embeddings, row-aligned with code_samples
code_samples = []       # list of dicts: code text + repo metadata


def load_code_dataset(progress=gr.Progress()):
    """Stream a sample of The Stack (Python subset) and embed it.

    Populates the module-level ``code_samples`` and ``embeddings`` globals.

    Returns:
        A status string for the UI (success summary or error message).
    """
    global dataset_sample, embeddings, code_samples

    progress(0, desc="Loading The Stack dataset...")
    try:
        # Stream the Python subset so we never download the full dataset.
        dataset_sample = load_dataset(
            "bigcode/the-stack-smol",
            data_dir="data/python",
            split="train",
            streaming=True
        )

        code_samples = []
        progress(0.3, desc="Sampling code repositories...")
        # Scan at most 500 items; keep at most 300 usable snippets.
        for i, item in enumerate(dataset_sample):
            if i >= 500:
                break
            code = item.get('content', '')
            if len(code) < 50 or len(code) > 5000:  # Filter very short/long files
                continue
            licenses = item.get('max_stars_repo_licenses')
            code_samples.append({
                'code': code,
                'language': 'python',
                'size': len(code),
                'max_stars_repo_name': item.get('max_stars_repo_name', 'unknown'),
                # `or 0` also covers the key being present but None, which
                # would otherwise break the >= star comparison in search_code.
                'max_stars_count': item.get('max_stars_count', 0) or 0,
                'license': licenses[0] if licenses else 'unknown',
            })
            if len(code_samples) >= 300:
                break

        progress(0.7, desc="Creating code embeddings...")
        # Truncate to the first 512 chars: CodeBERT's context window is small
        # and a file's preamble is usually representative.
        code_texts = [c['code'][:512] for c in code_samples]
        embeddings = embedder.encode(code_texts, show_progress_bar=False)

        progress(1.0, desc="Ready!")
        avg_stars = np.mean([c['max_stars_count'] for c in code_samples])
        return f"✅ Loaded {len(code_samples)} code samples (avg stars: {avg_stars:.0f})"
    except Exception as e:
        return f"❌ Error: {str(e)}\nNote: Using fallback - dataset requires internet"


def extract_function_name(code):
    """Return the first function or class name found in *code*.

    Falls back to the literal string "code snippet" when neither a
    ``def`` nor a ``class`` statement is present.
    """
    func_match = re.search(r'def\s+(\w+)\s*\(', code)
    if func_match:
        return func_match.group(1)
    class_match = re.search(r'class\s+(\w+)\s*[:\(]', code)
    if class_match:
        return class_match.group(1)
    return "code snippet"


def syntax_highlight_code(code, language='python'):
    """Return *code* as Pygments-highlighted HTML (inline styles, monokai).

    On any highlighting failure (unknown lexer, bad input) fall back to a
    plain <pre> block so the UI still renders something.
    """
    try:
        lexer = get_lexer_by_name(language)
        formatter = HtmlFormatter(style='monokai', noclasses=True)
        return highlight(code, lexer, formatter)
    except Exception:
        # Was a bare `except:`; narrowed so KeyboardInterrupt etc. propagate.
        # NOTE(review): the original fallback markup was garbled in the paste;
        # escaping keeps arbitrary snippet text from breaking the page HTML.
        return f"<pre><code>{html.escape(code)}</code></pre>"
def search_code(query, language='python', min_stars=0, top_k=5):
    """Rank loaded code samples against a natural-language *query*.

    Filters the loaded samples by language and minimum star count, then
    scores the remainder with cosine similarity between the query embedding
    and the precomputed snippet embeddings.

    Args:
        query: Natural-language search text.
        language: Language tag a sample must match (dataset only holds 'python').
        min_stars: Minimum repo star count; dropped as a fallback if it
            filters everything out.
        top_k: Maximum number of results to return.

    Returns:
        Up to *top_k* sample dicts (copies), each augmented with a
        'similarity' float, ordered best-first. Empty list if nothing is loaded.
    """
    if embeddings is None or not code_samples:
        return []

    # Filter by language and stars.
    filtered_samples = [
        (i, sample) for i, sample in enumerate(code_samples)
        if sample['language'] == language and sample['max_stars_count'] >= min_stars
    ]
    if not filtered_samples:
        # Fallback: drop only the star filter, keeping the language filter
        # (the original code dropped both, contradicting its own comment;
        # identical in practice since only 'python' samples are loaded).
        filtered_samples = [
            (i, sample) for i, sample in enumerate(code_samples)
            if sample['language'] == language
        ]
    if not filtered_samples:
        return []

    indices = [i for i, _ in filtered_samples]
    filtered_embeddings = embeddings[indices]

    # Cosine similarity: a raw dot product would favour snippets whose
    # embeddings merely have large norms rather than similar directions.
    query_embedding = embedder.encode([query])
    q = query_embedding / (np.linalg.norm(query_embedding) + 1e-12)
    norms = np.linalg.norm(filtered_embeddings, axis=1, keepdims=True) + 1e-12
    similarities = np.dot(filtered_embeddings / norms, q.T).flatten()

    # argsort handles top_k > len(similarities) gracefully (returns all).
    top_indices = np.argsort(similarities)[-top_k:][::-1]

    # Map positions in the filtered view back to the original sample list.
    results = []
    for idx in top_indices:
        original_idx = indices[idx]
        sample = code_samples[original_idx].copy()
        sample['similarity'] = float(similarities[idx])
        results.append(sample)
    return results
def format_code_results(results, query):
    """Render search results as an HTML fragment for the gr.HTML output.

    NOTE(review): the original body was garbled in this paste — the
    per-result loop header was lost. The markup below reconstructs the
    visible pieces (query header, repo metadata line, highlighted snippet,
    copy button); verify against the intended layout.
    """
    if not results:
        return "<p>No code samples found. Try adjusting filters or query.</p>"

    html_out = f"<p><b>Query:</b> {query}</p>"
    html_out += f"<p><b>Found:</b> {len(results)} relevant code samples</p>"

    for rank, result in enumerate(results, start=1):
        html_out += f"<h3>#{rank}: {extract_function_name(result['code'])}</h3>"
        html_out += "<p>"
        html_out += f"Repo: {result['max_stars_repo_name']} | "
        html_out += f"Stars: ⭐ {result['max_stars_count']} | "
        html_out += f"License: {result['license']} | "
        html_out += f"Relevance: {result['similarity']:.3f}"
        html_out += "</p>"

        # Limit display length so one huge file doesn't dominate the page.
        code = result['code'][:1000]
        html_out += syntax_highlight_code(code, result['language'])

        # Copy button: escape backticks/dollars so the snippet survives being
        # embedded in a JavaScript template literal.
        escaped_code = result['code'].replace('`', '\\`').replace('$', '\\$')
        html_out += f"""
        <button onclick="navigator.clipboard.writeText(`{escaped_code}`)">
            📋 Copy code
        </button>
        """

    return html_out


def perform_code_search(query, language, min_stars, num_results, progress=gr.Progress()):
    """Gradio callback: run a search and return (results_html, stats_markdown).

    NOTE(review): the original `def` line was lost in the paste; the
    signature is reconstructed from the search_btn.click inputs below.
    """
    if not query or not query.strip():
        return "<p>Please enter a search query</p>", ""
    if embeddings is None:
        return "<p>Please load the dataset first</p>", ""

    progress(0, desc="Searching code...")
    results = search_code(query, language, min_stars, top_k=num_results)

    progress(0.7, desc="Formatting results...")
    formatted = format_code_results(results, query)
    progress(1.0, desc="Done!")

    stats = f"""
### 📊 Search Statistics
- **Total samples**: {len(code_samples)}
- **Results**: {len(results)}
- **Language**: {language}
- **Min stars**: {min_stars}
- **Model**: CodeBERT (Microsoft)

### 🧠 How CodeBERT Works
CodeBERT is trained on code and documentation:
- Understands programming patterns
- Maps code to natural language
- Trained on GitHub repos
- Supports multiple languages
"""
    return formatted, stats


# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks(title="Code Search Engine", theme=gr.themes.Soft()) as demo:
    create_premium_hero(
        "Semantic Code Search Engine",
        "Search code with natural language using code embeddings, dataset "
        "sampling, and syntax-highlighted retrieval results.",
        "💻",
        badge="Code Intelligence",
        highlights=["CodeBERT", "The Stack sample", "Semantic retrieval"],
    )
    create_method_panel({
        "Technique": "Encode code snippets into vectors and rank them against natural-language queries.",
        "What it proves": "You can adapt embedding search beyond documents into developer tooling.",
        "HF capability": "Combines Hub datasets with transformer embeddings in an interactive Space.",
    })

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Step 1: Load Dataset")
            load_btn = gr.Button("Load Code Dataset", variant="primary")
            load_status = gr.Textbox(label="Status", interactive=False)

            gr.Markdown("### Step 2: Search Code")
            query_input = gr.Textbox(
                label="What are you looking for?",
                placeholder="e.g., binary search implementation",
                lines=2
            )
            language = gr.Dropdown(
                choices=['python'],
                value='python',
                label="Language (more coming soon)"
            )
            min_stars = gr.Slider(
                minimum=0,
                maximum=1000,
                value=0,
                step=10,
                label="Minimum GitHub Stars"
            )
            num_results = gr.Slider(
                minimum=3,
                maximum=10,
                value=5,
                step=1,
                label="Number of Results"
            )
            search_btn = gr.Button("Search Code", variant="primary")

            gr.Markdown("""
            ### 💡 Example Searches:
            - "binary search tree"
            - "web scraper with requests"
            - "recursive fibonacci"
            - "API client with authentication"
            - "data validation decorator"
            """)

        with gr.Column(scale=2):
            results_output = gr.HTML(label="Code Results")

            with gr.Accordion("📊 Statistics & Info", open=False):
                stats_output = gr.Markdown()

    gr.Markdown("""
    ### 🎯 Why Semantic Code Search?

    **Traditional search** (GitHub, Google):
    - Keyword matching only
    - Must know exact function names
    - Hard to find by functionality

    **Semantic search** (this tool):
    - Search by what code does, not what it's called
    - "sort a list" finds quicksort, mergesort, etc.
    - Understands programming concepts

    ### 🔧 Features:
    - **Syntax highlighting** with Pygments
    - **Copy to clipboard** button
    - **Filter by stars** (code quality proxy)
    - **License information** (know before you use)
    - **CodeBERT embeddings** (code + NL understanding)

    Perfect for developers learning, debugging, or finding code examples!
    """)

    load_btn.click(
        load_code_dataset,
        outputs=[load_status]
    )
    search_btn.click(
        perform_code_search,
        inputs=[query_input, language, min_stars, num_results],
        outputs=[results_output, stats_output]
    )

if __name__ == "__main__":
    demo.launch()