Spaces:
Configuration error
Configuration error
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.responses import JSONResponse | |
| from PIL import Image | |
| import tempfile | |
| import os | |
| import base64 | |
| import cv2 | |
| import io | |
| import re | |
| from together import Together | |
| import releaf_ai # this should still contain your SYSTEM_PROMPT | |
app = FastAPI()

# Read the Together API key from the environment so the secret is not committed
# to source control. The placeholder fallback preserves the previous behaviour
# when the variable is unset (requests will then fail authentication upstream).
API_KEY = os.environ.get("TOGETHER_API_KEY", "your_api_key_here")
client = Together(api_key=API_KEY)
MODEL_NAME = "meta-llama/Llama-Vision-Free"
# System prompt sent with every request; defined in the local releaf_ai module.
SYSTEM_PROMPT = releaf_ai.SYSTEM_PROMPT
def encode_image_to_base64(image: Image.Image) -> str:
    """Return *image* JPEG-encoded as a base64 ASCII string.

    JPEG cannot store alpha or palette data, so any image whose mode is not
    RGB or L (greyscale) is converted to RGB first. Without this,
    ``Image.save(..., format="JPEG")`` raises for modes such as RGBA or P.

    Args:
        image: The PIL image to encode. It is not modified; conversion
            produces a new image.

    Returns:
        The base64 text of the JPEG bytes, suitable for a ``data:`` URL.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = io.BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
def extract_score(text: str):
    """Find a "Score: <digits>" marker in *text*.

    Matching is case-insensitive, and whitespace (including newlines) may sit
    between the colon and the number.

    Returns:
        The first such number as an ``int``, or ``None`` when no marker with
        digits is present.
    """
    found = re.search(r"(?i)Score:\s*(\d+)", text)
    if found is None:
        return None
    digits = found.group(1)
    return int(digits)
def extract_activity(text: str):
    """Return the text after "Detected Activity:" in *text*, or "Unknown".

    Matching is case-insensitive. The capture runs to the end of that line.
    ``.`` does not match a newline, so the value may also sit on the final
    line with no trailing newline; the previous pattern required a ``\\n``
    after the value and missed that case. Surrounding whitespace (including
    a Windows ``\\r``) is stripped. A missing or blank value yields "Unknown".
    """
    match = re.search(r"(?i)Detected Activity:\s*(.+)", text)
    if not match:
        return "Unknown"
    return match.group(1).strip() or "Unknown"
def _load_image(data: bytes) -> Image.Image:
    """Decode raw image bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 if Pillow cannot identify or decode the data.
    """
    try:
        return Image.open(io.BytesIO(data)).convert("RGB")
    except OSError as e:
        # PIL.UnidentifiedImageError and truncated-file errors derive from
        # OSError; they indicate bad client input, not a server fault.
        raise HTTPException(status_code=400, detail="Could not decode image") from e


def _video_to_grid(data: bytes, cols: int = 3, rows: int = 3, tile_size: int = 256) -> Image.Image:
    """Sample ``cols * rows`` evenly spaced video frames and tile them into one image.

    The bytes are written to a temporary file because ``cv2.VideoCapture``
    reads from a path. The temp file and the capture handle are always
    released, even when decoding fails. Frames that fail to decode are
    skipped, and their grid cells stay black.

    Raises:
        HTTPException: 400 if the video cannot be opened or no frame decodes.
    """
    n_frames = cols * rows
    tmp = tempfile.NamedTemporaryFile(delete=False)
    try:
        # Close the handle before OpenCV reopens the path (required on Windows).
        with tmp:
            tmp.write(data)
        cap = cv2.VideoCapture(tmp.name)
        try:
            if not cap.isOpened():
                raise HTTPException(status_code=400, detail="Could not open video")
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Guard against 0/unknown frame counts so seeking still advances.
            interval = max(total // n_frames, 1)
            frames = []
            for i in range(n_frames):
                cap.set(cv2.CAP_PROP_POS_FRAMES, i * interval)
                ret, frame = cap.read()
                if ret:
                    # OpenCV decodes to BGR; PIL expects RGB.
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(Image.fromarray(frame).resize((tile_size, tile_size)))
        finally:
            cap.release()
    finally:
        os.remove(tmp.name)

    if not frames:
        raise HTTPException(status_code=400, detail="Could not decode any frames from video")

    w, h = frames[0].size
    grid = Image.new("RGB", (cols * w, rows * h))
    for idx, frame in enumerate(frames):
        grid.paste(frame, ((idx % cols) * w, (idx // cols) * h))
    return grid


# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
# function, so FastAPI will not expose it as written — confirm the intended path.
async def predict(file: UploadFile = File(...)):
    """Score an uploaded image or video with the vision model.

    Images are decoded directly. Videos are reduced to a 3x3 grid of sampled
    frames. The resulting picture is sent as a base64 JPEG data URL, together
    with ``SYSTEM_PROMPT``, to ``MODEL_NAME`` via the Together client.

    Returns:
        JSONResponse with ``points`` (int or None, from ``extract_score``),
        ``task`` (str, from ``extract_activity``) and ``raw`` (the model reply).

    Raises:
        HTTPException: 400 for unsupported or undecodable uploads; 500 for any
            other failure (detail is the exception text).
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        # content_type can be None when the client omits the header.
        content_type = file.content_type or ""
        if content_type.startswith("image"):
            data = await file.read()
            # Decoding is CPU-bound; keep it off the event loop.
            image = await run_in_threadpool(_load_image, data)
        elif content_type.startswith("video"):
            data = await file.read()
            image = await run_in_threadpool(_video_to_grid, data)
        else:
            raise HTTPException(status_code=400, detail="Unsupported file type")

        b64_img = encode_image_to_base64(image)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_img}"}}
            ]},
        ]
        # The Together client call is synchronous; run it in a worker thread so
        # it does not block other requests on the event loop.
        res = await run_in_threadpool(
            client.chat.completions.create, model=MODEL_NAME, messages=messages
        )
        # Guard against an empty/None message so the regex helpers get a str.
        reply = res.choices[0].message.content or ""
        return JSONResponse({
            "points": extract_score(reply),
            "task": extract_activity(reply),
            "raw": reply
        })
    except HTTPException:
        # Preserve deliberate client errors (e.g. 400) instead of masking them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e