# Hugging Face Space backend: Code Complexity Predictor API
# Standard library.
import os

# Third-party.
import joblib
import torch
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Display metadata per predicted complexity label:
#   raw model label -> (display notation, short title, description).
# The word-style keys ("constant", "linear", "quadratic") are aliases for the
# codeparrot label variants and share the exact same display entries.
# NOTE(review): the emoji below look mojibake'd from an encoding round-trip —
# verify the intended glyphs against the deployed frontend.
_CONSTANT = ("O(1)", "β‘ Constant Time", "Executes in the same time regardless of input size. Very fast!")
_LINEAR = ("O(N)", "π Linear Time", "Execution time grows linearly with input size.")
_QUADRATIC = ("O(NΒ²)", "π’ Quadratic Time", "Execution time grows quadratically. Common in nested loops.")

DESCRIPTIONS = {
    "O(1)": _CONSTANT,
    "O(N)": _LINEAR,
    "O(log N)": ("O(log N)", "π Logarithmic Time", "Very efficient! Common in binary search algorithms."),
    "O(N log N)": ("O(N log N)", "βοΈ Linearithmic Time", "Common in efficient sorting algorithms like merge sort."),
    "O(N^2)": _QUADRATIC,
    "O(N^3)": ("O(NΒ³)", "π¦ Cubic Time", "Triple nested loops. Avoid for large inputs."),
    "O(2^N)": ("O(2βΏ)", "π Exponential Time", "NP-Hard complexity. Only feasible for very small inputs."),
    "O(NP)": ("O(NP)", "π NP-Complete", "Infeasible for large inputs without approximation."),
    "constant": _CONSTANT,
    "linear": _LINEAR,
    "quadratic": _QUADRATIC,
}
| app = FastAPI(title="Code Complexity Predictor API") | |
class PredictRequest(BaseModel):
    """Request body for prediction: the raw source code to classify."""
    code: str  # source code snippet; must be non-blank (enforced by the handler)
# Global state, populated by load_resources(); all None until loading runs.
model = None      # AutoModelForSequenceClassification (GraphCodeBERT, 7 labels)
tokenizer = None  # AutoTokenizer for microsoft/graphcodebert-base
le = None         # label encoder loaded from label_encoder.pkl (used via inverse_transform)
device = None     # torch.device, "cuda" when available else "cpu"
@app.on_event("startup")
def load_resources():
    """Load tokenizer, label encoder, and fine-tuned model into module globals.

    Registered as a FastAPI startup hook so the heavy artifacts are loaded
    once at boot — as written, this function was never invoked, leaving
    ``model``/``tokenizer``/``le`` as None at request time. Missing artifact
    files are reported but do not abort startup (the server then runs with
    untrained weights and/or no label decoder).
    """
    global model, tokenizer, le, device
    print("Loading resources...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenizer matching the fine-tuned checkpoint's base architecture.
    tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")

    # Label encoder mapping class indices back to complexity-label strings.
    if os.path.exists("label_encoder.pkl"):
        le = joblib.load("label_encoder.pkl")
    else:
        print("WARNING: label_encoder.pkl not found!")

    # Base model with a 7-way classification head, then fine-tuned weights
    # if the checkpoint file is present alongside the app.
    model = AutoModelForSequenceClassification.from_pretrained("microsoft/graphcodebert-base", num_labels=7)
    if os.path.exists("best_model.pt"):
        model.load_state_dict(torch.load("best_model.pt", map_location=device))
    else:
        print("WARNING: best_model.pt not found!")
    model.to(device)
    model.eval()
    print("Resources loaded successfully!")
@app.post("/predict")
def predict_complexity(request: PredictRequest):
    """Predict the time-complexity class of a code snippet.

    As written, this handler was never registered as a route; the
    ``@app.post`` decorator wires it to POST /predict.

    Returns:
        dict with ``notation`` (display form, e.g. "O(1)"), ``title`` and
        ``description`` for the predicted class.

    Raises:
        HTTPException 400: empty/blank code.
        HTTPException 503: model resources not loaded (startup incomplete
            or artifact files missing).
        HTTPException 500: any inference failure, with the error message.
    """
    code = request.code
    if not code.strip():
        raise HTTPException(status_code=400, detail="Code cannot be empty")
    # Fail fast with a clear status instead of an opaque AttributeError-backed
    # 500 when loading has not completed or artifacts were missing.
    if model is None or tokenizer is None or le is None:
        raise HTTPException(status_code=503, detail="Model resources are not loaded")
    try:
        inputs = tokenizer(code, truncation=True, max_length=512, padding='max_length', return_tensors='pt')
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        pred = torch.argmax(outputs.logits, dim=1).item()
        # Decode the class index back to the training-time label string.
        label = le.inverse_transform([pred])[0]
        # Fall back to the raw label when it has no curated display entry.
        notation, title, description = DESCRIPTIONS.get(label, (label, label, ""))
        return {
            "notation": notation,
            "title": title,
            "description": description
        }
    except Exception as e:
        # Surface the underlying inference error to the client as a 500.
        raise HTTPException(status_code=500, detail=str(e))
# Mount frontend: serve ./frontend (index.html at "/") as static files.
# Registered last so any API routes defined above take precedence over this
# catch-all mount.
app.mount("/", StaticFiles(directory="frontend", html=True), name="frontend")