import os
import json
import time
import traceback
import uuid

import torch
from dotenv import load_dotenv
from flask import Flask, request, jsonify, render_template, Response
from flask_cors import CORS
from ultralytics import YOLO
from werkzeug.utils import secure_filename

# Import the processing logic
from processing import process_images

# Load environment variables from .env file
load_dotenv()

app = Flask(__name__)
# Enable CORS for all routes.
# NOTE(review): this allows any origin; pass `origins=` if the API is not meant to be public.
CORS(app)

# --- Configuration ---
UPLOAD_FOLDER = 'static/uploads'
MODELS_FOLDER = 'models'
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

# --- Load model names from .env file ---
PARTS_MODEL_NAME = os.getenv('PARTS_MODEL_NAME', 'best_parts_EP336.pt')
DAMAGE_MODEL_NAME = os.getenv('DAMAGE_MODEL_NAME', 'best_new_EP382.pt')

# --- Model Paths ---
PARTS_MODEL_PATH = os.path.join(MODELS_FOLDER, PARTS_MODEL_NAME)
DAMAGE_MODEL_PATH = os.path.join(MODELS_FOLDER, DAMAGE_MODEL_NAME)

app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(MODELS_FOLDER, exist_ok=True)
os.makedirs('templates', exist_ok=True)

# --- Determine Device ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


def _load_model(path, name, label):
    """Load a YOLO model from *path* and move it to ``device``.

    Args:
        path: Filesystem path of the weights file.
        name: Model file name, used only in log messages.
        label: Human-readable model label ("Parts" / "Damage") for log messages.

    Returns:
        The loaded model, or None if the file is missing or ``YOLO(path)``
        raises. Failures are logged rather than raised so the app can still
        start (and serve '/') without weights present. If only ``.to(device)``
        fails, the already-constructed model is still returned, matching the
        previous inline loading code.
    """
    model = None
    try:
        if not os.path.exists(path):
            print(f"Warning: {label} model file not found at {path}")
        else:
            model = YOLO(path)
            model.to(device)
            print(f"Successfully loaded {label.lower()} model '{name}' on {device}.")
    except Exception as e:
        print(f"Error loading {label} Model ({name}): {e}")
    return model


# --- Load YOLO Models ---
parts_model = _load_model(PARTS_MODEL_PATH, PARTS_MODEL_NAME, "Parts")
damage_model = _load_model(DAMAGE_MODEL_PATH, DAMAGE_MODEL_NAME, "Damage")


def allowed_file(filename):
    """Return True if *filename* has an extension listed in ALLOWED_EXTENSIONS.

    Only the text after the last dot is considered, compared case-insensitively.
    Empty or None names (a multipart part may carry no filename) are rejected.
    """
    if not filename or '.' not in filename:
        return False
    return filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def _build_upload_path(session_key, filename):
    """Return a unique path inside UPLOAD_FOLDER for an uploaded file.

    Both *session_key* and *filename* come from the client, so both are passed
    through ``secure_filename`` to strip path separators and '..' components;
    otherwise a crafted session_key could write outside the upload folder.
    *filename* must already have passed ``allowed_file``.
    """
    ext = filename.rsplit('.', 1)[1].lower()
    safe_key = secure_filename(session_key) or 'session'
    safe_name = secure_filename(filename) or 'upload'
    # secure_filename drops non-ASCII characters and can lose the extension
    # (e.g. a non-ASCII stem + ".png" -> "png"); re-attach the validated one.
    if not safe_name.lower().endswith('.' + ext):
        safe_name = f"{safe_name}.{ext}"
    # A millisecond timestamp alone collides when several files arrive in the
    # same request; the random suffix guarantees no upload overwrites another.
    unique_filename = f"{safe_key}_{int(time.time() * 1000)}_{uuid.uuid4().hex[:8]}_{safe_name}"
    return os.path.join(app.config['UPLOAD_FOLDER'], unique_filename)


def _remove_files(paths):
    """Best-effort deletion of temporary upload files; failures are logged, not raised."""
    for filepath in paths:
        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass
        except OSError as e:
            print(f"Error cleaning up file {filepath}: {e}")


@app.route('/')
def home():
    """Serve the main HTML page."""
    return render_template('index.html')


@app.route('/predict', methods=['POST'])
def predict():
    """
    Endpoint to receive one or more images, process them immediately,
    and return the prediction results.

    Expects a multipart form with a non-empty ``session_key`` field and one or
    more ``file`` parts. Returns the JSON-serialised result of
    ``process_images`` on success, 400 on invalid input, and 500 if processing
    fails. Uploaded files are always deleted before returning.
    """
    # 1. --- Get Session Key and Validate ---
    # Session key can be used for logging or grouping, but doesn't control logic.
    session_key = request.form.get('session_key')
    if not session_key:
        return jsonify({"error": "No session_key provided in the payload"}), 400

    # 2. --- File Validation ---
    if 'file' not in request.files:
        return jsonify({"error": "No file part in the request"}), 400

    files = request.files.getlist('file')
    if not files or all(not f.filename for f in files):
        return jsonify({"error": "No selected files"}), 400

    # NOTE(review): parts_model / damage_model may be None if loading failed at
    # startup; confirm process_images handles that, or return 503 here.

    saved_filepaths = []
    # The outer try/finally spans saving too, so files already written are
    # removed even if a later file.save() raises.
    try:
        # 3. --- Save Files and Prepare for Processing ---
        for file in files:
            if file and allowed_file(file.filename):
                filepath = _build_upload_path(session_key, file.filename)
                file.save(filepath)
                saved_filepaths.append(filepath)
            else:
                print(f"Skipped invalid file: {file.filename}")

        if not saved_filepaths:
            return jsonify({"error": "No valid files were uploaded. Allowed types: png, jpg, jpeg"}), 400

        # 4. --- Run Prediction ---
        try:
            print(f"Processing {len(saved_filepaths)} file(s) for session '{session_key}'...")
            results = process_images(parts_model, damage_model, saved_filepaths)
            print(f"Processing complete for session '{session_key}'.")
            # Return the results as a JSON response
            return Response(json.dumps(results), mimetype='application/json')
        except Exception as e:
            print(f"An error occurred during processing for session {session_key}: {e}")
            traceback.print_exc()
            return jsonify({"error": f"An error occurred during processing: {str(e)}"}), 500
    finally:
        # 5. --- Clean up the saved files ---
        _remove_files(saved_filepaths)


if __name__ == '__main__':
    # The Werkzeug debugger permits arbitrary code execution from the browser,
    # so it must not be on by default for a server bound to all interfaces.
    # Opt in explicitly with FLASK_DEBUG=1 (e.g. in .env) for local development.
    debug_mode = os.getenv('FLASK_DEBUG', '0').strip().lower() in ('1', 'true', 'yes', 'on')
    app.run(host='0.0.0.0', port=7860, debug=debug_mode)