File size: 2,910 Bytes
8e8354c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
import tempfile
import os
import base64
import cv2
import io
import re
from together import Together
import releaf_ai  # this should still contain your SYSTEM_PROMPT

app = FastAPI()

# Prefer the TOGETHER_API_KEY environment variable so the secret is not
# committed to source control; the literal fallback preserves the previous
# behavior for local development.
API_KEY = os.getenv("TOGETHER_API_KEY", "your_api_key_here")
client = Together(api_key=API_KEY)
MODEL_NAME = "meta-llama/Llama-Vision-Free"

# System prompt lives in releaf_ai so it can be shared/edited independently.
SYSTEM_PROMPT = releaf_ai.SYSTEM_PROMPT

def encode_image_to_base64(image: Image.Image) -> str:
    """Serialize a PIL image as JPEG and return it base64-encoded (ASCII str)."""
    jpeg_buffer = io.BytesIO()
    image.save(jpeg_buffer, format="JPEG")
    encoded = base64.b64encode(jpeg_buffer.getvalue())
    return encoded.decode("utf-8")

def extract_score(text: str):
    """Return the integer following a case-insensitive 'Score:' label, or None."""
    found = re.search(r"(?i)Score:\s*(\d+)", text)
    if found is None:
        return None
    return int(found.group(1))

def extract_activity(text: str):
    """Return the text after a case-insensitive 'Detected Activity:' label.

    Matches up to the end of the line; '(?:\\n|$)' (rather than a bare '\\n')
    also matches when the label sits on the last line of the reply with no
    trailing newline. Returns "Unknown" when the label is absent.
    """
    match = re.search(r"(?i)Detected Activity:\s*(.+?)\s*(?:\n|$)", text)
    return match.group(1).strip() if match else "Unknown"

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Score an uploaded image or video with the vision model.

    Images are forwarded directly. Videos are sampled at 9 evenly spaced
    frames which are composited into a 3x3 grid image before being sent.
    The model reply is parsed into a score and an activity label.

    Returns:
        JSONResponse with keys "points" (int or None), "task" (str) and
        "raw" (the unparsed model reply).

    Raises:
        HTTPException 400 for unsupported content types or undecodable
        videos; HTTPException 500 for any other failure.
    """
    try:
        content_type = file.content_type or ""
        if content_type.startswith("image"):
            image = Image.open(io.BytesIO(await file.read())).convert("RGB")

        elif content_type.startswith("video"):
            temp_path = tempfile.NamedTemporaryFile(delete=False).name
            # try/finally so the temp file never leaks, even when OpenCV
            # or the frame loop raises.
            try:
                with open(temp_path, "wb") as f:
                    f.write(await file.read())

                cap = cv2.VideoCapture(temp_path)
                total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                interval = max(total // 9, 1)

                frames = []
                for i in range(9):
                    cap.set(cv2.CAP_PROP_POS_FRAMES, i * interval)
                    ret, frame = cap.read()
                    if ret:
                        # OpenCV decodes BGR; PIL expects RGB.
                        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        frames.append(Image.fromarray(frame).resize((256, 256)))
                cap.release()
            finally:
                os.remove(temp_path)

            # Guard frames[0]: a corrupt/empty video yields no frames.
            if not frames:
                raise HTTPException(status_code=400,
                                    detail="Could not decode any frames from video")

            w, h = frames[0].size
            grid = Image.new("RGB", (3 * w, 3 * h))
            for idx, frame in enumerate(frames):
                grid.paste(frame, ((idx % 3) * w, (idx // 3) * h))
            image = grid

        else:
            raise HTTPException(status_code=400, detail="Unsupported file type")

        b64_img = encode_image_to_base64(image)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_img}"}}
            ]}
        ]
        res = client.chat.completions.create(model=MODEL_NAME, messages=messages)
        reply = res.choices[0].message.content

        return JSONResponse({
            "points": extract_score(reply),
            "task": extract_activity(reply),
            "raw": reply
        })

    except HTTPException:
        # Re-raise deliberate 4xx responses unchanged; the generic handler
        # below would otherwise mask them as 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))