File size: 3,643 Bytes
1a976d8
 
 
 
 
 
 
 
d054e60
1a976d8
d054e60
 
 
 
 
 
 
 
 
 
 
1a976d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b2967c
1a976d8
 
 
 
 
 
 
 
6b2967c
1a976d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import torch
import joblib
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Human-readable descriptions for every complexity label the model may emit.
# Each value is a (notation, title, description) tuple; the API handler unpacks
# it positionally. Labels are covered both in Big-O spelling ("O(N^2)") and in
# the plain-word spelling used by the codeparrot "codecomplex" dataset
# ("quadratic", "nlogn", ...), so either label-encoder vocabulary resolves.
_CONSTANT = ("O(1)", "⚡ Constant Time", "Executes in the same time regardless of input size. Very fast!")
_LOGARITHMIC = ("O(log N)", "🔍 Logarithmic Time", "Very efficient! Common in binary search algorithms.")
_LINEAR = ("O(N)", "📈 Linear Time", "Execution time grows linearly with input size.")
_LINEARITHMIC = ("O(N log N)", "⚙️ Linearithmic Time", "Common in efficient sorting algorithms like merge sort.")
_QUADRATIC = ("O(N²)", "🐢 Quadratic Time", "Execution time grows quadratically. Common in nested loops.")
_CUBIC = ("O(N³)", "🦕 Cubic Time", "Triple nested loops. Avoid for large inputs.")
_EXPONENTIAL = ("O(2ⁿ)", "💀 Exponential Time", "NP-Hard complexity. Only feasible for very small inputs.")
_NP = ("O(NP)", "💀 NP-Complete", "Infeasible for large inputs without approximation.")

DESCRIPTIONS = {
    # Big-O style labels
    "O(1)": _CONSTANT,
    "O(N)": _LINEAR,
    "O(log N)": _LOGARITHMIC,
    "O(N log N)": _LINEARITHMIC,
    "O(N^2)": _QUADRATIC,
    "O(N^3)": _CUBIC,
    "O(2^N)": _EXPONENTIAL,
    "O(NP)": _NP,
    # codeparrot/codecomplex style labels
    "constant": _CONSTANT,
    "logn": _LOGARITHMIC,
    "linear": _LINEAR,
    "nlogn": _LINEARITHMIC,
    "quadratic": _QUADRATIC,
    "cubic": _CUBIC,
    "np": _NP,
}

app = FastAPI(title="Code Complexity Predictor API")


# Request body for POST /api/predict: the raw source snippet to classify.
class PredictRequest(BaseModel):
    code: str


# Inference state. Every name starts unset and is filled in by the startup
# hook; the request handler only reads these.
model = tokenizer = le = device = None

@app.on_event("startup")
def load_resources():
    """Load the tokenizer, label encoder and fine-tuned classifier into module globals.

    Runs once at application startup. Files are looked up relative to the
    current working directory. A missing ``label_encoder.pkl`` or
    ``best_model.pt`` only prints a warning. In that case ``le`` stays ``None``,
    or the classifier keeps the head created by ``from_pretrained`` instead of
    the fine-tuned weights.
    """
    global model, tokenizer, le, device
    print("Loading resources...")
    base_checkpoint = "microsoft/graphcodebert-base"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)

    # Load label encoder.
    # NOTE(review): joblib.load unpickles arbitrary objects; only deploy a
    # label_encoder.pkl that comes from a trusted source.
    if os.path.exists("label_encoder.pkl"):
        le = joblib.load("label_encoder.pkl")
    else:
        print("WARNING: label_encoder.pkl not found!")

    # The classification head needs one output per encoder class. Derive the
    # size from the encoder when it is available, otherwise keep the
    # historical 7-class head.
    classes = getattr(le, "classes_", None)
    num_labels = len(classes) if classes is not None else 7

    # Load model
    model = AutoModelForSequenceClassification.from_pretrained(base_checkpoint, num_labels=num_labels)
    if os.path.exists("best_model.pt"):
        # The checkpoint is fed straight to load_state_dict, so it should be a
        # plain tensor state_dict; weights_only=True refuses anything else.
        state_dict = torch.load("best_model.pt", map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
    else:
        print("WARNING: best_model.pt not found!")

    model.to(device)
    model.eval()
    print("Resources loaded successfully!")

@app.post("/api/predict")
def predict_complexity(request: PredictRequest):
    """Predict the time complexity of ``request.code``.

    Returns a JSON object with ``notation``, ``title`` and ``description``
    looked up in ``DESCRIPTIONS``. A label missing from that table falls back
    to the raw label as notation and title, with an empty description.

    Raises:
        HTTPException(400): the submitted code is empty or whitespace-only.
        HTTPException(503): the model resources were not loaded at startup
            (e.g. ``label_encoder.pkl`` was missing, leaving ``le`` as None).
        HTTPException(500): tokenization, inference or label decoding failed.
    """
    code = request.code
    if not code.strip():
        raise HTTPException(status_code=400, detail="Code cannot be empty")
    # Fail with a clear status instead of letting an AttributeError on None
    # surface as an opaque 500.
    if model is None or tokenizer is None or le is None:
        raise HTTPException(status_code=503, detail="Model resources are not loaded")

    try:
        # A single sequence needs no padding (the attention mask would hide
        # pad tokens anyway); truncate to 512 tokens as before.
        inputs = tokenizer(code, truncation=True, max_length=512, return_tensors="pt")
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        pred = torch.argmax(logits, dim=1).item()

        # str(): encoder classes may be numpy scalars (e.g. integer labels),
        # which the JSON response encoder cannot serialise.
        label = str(le.inverse_transform([pred])[0])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    notation, title, description = DESCRIPTIONS.get(label, (label, label, ""))
    return {
        "notation": notation,
        "title": title,
        "description": description,
    }

# Serve the static frontend from the site root. This mount stays at the end of
# the module: routes are matched in registration order, so "/api/predict"
# (declared above) must be registered before this catch-all "/" mount.
# html=True makes directory requests serve their index.html.
app.mount(
    "/",
    StaticFiles(directory="frontend", html=True),
    name="frontend",
)