# codesentry-backend/tools/code_parser.py
"""
Code ingestion: parse from raw string, GitHub URL, or base64 zip.
Extracts file contents and builds a flat list of (path, content) tuples.
"""
from __future__ import annotations
import ast
import base64
import io
import os
import re
import zipfile
from pathlib import Path
from typing import List, Optional, Tuple
# ──────────────────────────────────────────────
# Types & ingestion limits
# ──────────────────────────────────────────────
# A parsed source file: (relative_path, content).
FileEntry = Tuple[str, str] # (relative_path, content)
# Extensions treated as analyzable source; everything else is ignored.
SUPPORTED_EXTENSIONS = {".py", ".js", ".ts", ".go", ".java", ".rb", ".php", ".sh", ".yaml", ".yml", ".toml", ".json"}
# Per-file size cap; larger files are silently skipped by all parsers below.
MAX_FILE_SIZE_BYTES = 2 * 1024 * 1024 # 2 MB per file
# Upper bound on files collected per ingestion call.
MAX_TOTAL_FILES = 500
# ──────────────────────────────────────────────
# Raw code string
# ──────────────────────────────────────────────
def parse_code_string(code: str, filename: str = "input.py") -> List[FileEntry]:
    """Treat an in-memory code string as a one-element file list.

    Wraps the snippet as a (filename, code) tuple so raw pasted code flows
    through the same pipeline as zip/directory ingestion.
    """
    entry: FileEntry = (filename, code)
    return [entry]
# ──────────────────────────────────────────────
# Base64-encoded zip
# ──────────────────────────────────────────────
def parse_zip_base64(b64_content: str) -> List[FileEntry]:
    """Decode a base64-encoded zip archive and extract supported source files.

    Args:
        b64_content: Base64 string containing a zip archive.

    Returns:
        List of (archive_path, content) tuples for members whose extension is
        in SUPPORTED_EXTENSIONS and whose uncompressed size is within
        MAX_FILE_SIZE_BYTES, capped at MAX_TOTAL_FILES collected entries.

    Raises:
        ValueError: If the payload is not valid base64 or not a valid zip.
    """
    try:
        raw = base64.b64decode(b64_content)
    except Exception as exc:
        raise ValueError(f"Invalid base64 zip content: {exc}") from exc
    entries: List[FileEntry] = []
    try:
        archive = zipfile.ZipFile(io.BytesIO(raw))
    except zipfile.BadZipFile as exc:
        # Keep the error contract consistent with the base64 failure above
        # instead of leaking a raw BadZipFile to callers.
        raise ValueError(f"Invalid zip archive: {exc}") from exc
    with archive as zf:
        for info in zf.infolist():
            # Cap applies to *collected* files. The original sliced the raw
            # namelist() before filtering, so an archive whose first
            # MAX_TOTAL_FILES members were unsupported yielded nothing even
            # when supported files existed later.
            if len(entries) >= MAX_TOTAL_FILES:
                break
            if info.is_dir():
                continue
            if Path(info.filename).suffix.lower() not in SUPPORTED_EXTENSIONS:
                continue
            if info.file_size > MAX_FILE_SIZE_BYTES:
                continue
            try:
                content = zf.read(info).decode("utf-8", errors="replace")
            except Exception:
                # Best effort: skip unreadable members (e.g. bad CRC),
                # never abort the whole ingestion.
                continue
            entries.append((info.filename, content))
    return entries
# ──────────────────────────────────────────────
# Local directory (for cloned repos)
# ──────────────────────────────────────────────
def parse_directory(directory: str) -> List[FileEntry]:
    """Recursively collect supported source files from a local directory.

    Skips common vendor/VCS/cache directories, files larger than
    MAX_FILE_SIZE_BYTES, and stops once MAX_TOTAL_FILES entries are gathered.
    Paths in the result are relative to *directory*.
    """
    root = Path(directory)
    # Directories that never contain user source worth analyzing.
    ignored = {
        ".git", "__pycache__", "node_modules", ".venv", "venv",
        "env", ".env", "dist", "build", ".mypy_cache", ".pytest_cache",
    }
    collected: List[FileEntry] = []
    for candidate in root.rglob("*"):
        if len(collected) >= MAX_TOTAL_FILES:
            break
        # Any ignored name anywhere in the path excludes the whole subtree.
        if not ignored.isdisjoint(candidate.parts):
            continue
        if not candidate.is_file():
            continue
        if candidate.suffix.lower() not in SUPPORTED_EXTENSIONS:
            continue
        if candidate.stat().st_size > MAX_FILE_SIZE_BYTES:
            continue
        try:
            text = candidate.read_text(encoding="utf-8", errors="replace")
            relative = str(candidate.relative_to(root))
            collected.append((relative, text))
        except Exception:
            continue
    return collected
# ──────────────────────────────────────────────
# AST helpers (Python only)
# ──────────────────────────────────────────────
def extract_python_ast(code: str) -> Optional[ast.AST]:
    """Parse Python source into an AST; return None if it cannot be parsed.

    Besides SyntaxError, ast.parse raises ValueError for source containing
    null bytes (on Python < 3.12) and RecursionError on pathologically
    nested input. All three are treated as "not parseable" so untrusted
    user code never crashes the analysis pipeline.
    """
    try:
        return ast.parse(code)
    except (SyntaxError, ValueError, RecursionError):
        return None
def get_function_names(tree: ast.AST) -> List[str]:
    """Collect the name of every sync/async function or method in *tree*."""
    names: List[str] = []
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            names.append(node.name)
    return names
def get_imports(tree: ast.AST) -> List[str]:
    """Collect every module name referenced by import statements in *tree*."""
    found: List[str] = []
    for node in ast.walk(tree):
        if isinstance(node, ast.ImportFrom):
            # Relative imports like `from . import x` have module == None.
            if node.module is not None:
                found.append(node.module)
        elif isinstance(node, ast.Import):
            found.extend(alias.name for alias in node.names)
    return found
def get_line_content(code: str, line_number: int) -> str:
    """Fetch one line of *code* by 1-based index; "" when out of range."""
    all_lines = code.splitlines()
    index = line_number - 1
    if 0 <= index < len(all_lines):
        return all_lines[index]
    return ""
def get_snippet(code: str, line_number: int, context: int = 3) -> str:
    """Render the lines around *line_number* (1-based), marking it with '>>>'.

    Shows up to *context* lines on each side, each formatted as
    '<marker> <4-digit line number> | <text>'.
    """
    all_lines = code.splitlines()
    first = max(0, line_number - 1 - context)
    last = min(len(all_lines), line_number + context)
    rendered = [
        f"{'>>>' if num == line_number else '   '} {num:4d} | {text}"
        for num, text in enumerate(all_lines[first:last], start=first + 1)
    ]
    return "\n".join(rendered)
# ──────────────────────────────────────────────
# Regex-based pattern search across files
# ──────────────────────────────────────────────
def find_pattern_in_code(
    code: str,
    pattern: str,
    file_path: str = "unknown",
) -> List[dict]:
    """
    Scan *code* for regex *pattern* (compiled with MULTILINE | DOTALL).

    Returns one dict per match carrying file_path, the 1-based line_number
    of the match start, the matched line's content, and a context snippet.
    An invalid pattern yields an empty list rather than raising.
    """
    try:
        regex = re.compile(pattern, re.MULTILINE | re.DOTALL)
    except re.error:
        return []
    hits: List[dict] = []
    for found in regex.finditer(code):
        # Line number = newlines preceding the match start, plus one.
        lineno = code.count("\n", 0, found.start()) + 1
        hits.append(
            {
                "file_path": file_path,
                "line_number": lineno,
                "line_content": get_line_content(code, lineno),
                "snippet": get_snippet(code, lineno),
            }
        )
    return hits
def count_tokens_estimate(text: str) -> int:
    """Approximate an LLM token count assuming ~4 characters per token.

    Always returns at least 1, even for empty input.
    """
    quarter = len(text) // 4
    return quarter if quarter > 0 else 1
def build_context_block(files: List[FileEntry], max_tokens: int = 3000) -> str:
    """
    Concatenate files into a single context block for the LLM.

    Files are taken in list order until the approximate token budget is
    exhausted; later files are dropped, not truncated. Fix: the original
    returned an empty string whenever the very first file alone exceeded
    *max_tokens*; now an oversized file is truncated to fit when nothing
    has been emitted yet, so the caller always gets some context when at
    least one file is available.
    """
    blocks: List[str] = []
    used_tokens = 0
    for path, content in files:
        header = f"\n\n# === FILE: {path} ===\n"
        chunk = header + content
        chunk_tokens = count_tokens_estimate(chunk)
        if used_tokens + chunk_tokens > max_tokens:
            if not blocks:
                # Character budget for a truncated first file
                # (1 token ~ 4 chars, mirroring count_tokens_estimate).
                budget_chars = max(0, max_tokens * 4 - len(header))
                if budget_chars > 0:
                    blocks.append(header + content[:budget_chars])
            break
        blocks.append(chunk)
        used_tokens += chunk_tokens
    return "".join(blocks)