File size: 3,001 Bytes
89c6054
 
 
 
 
 
 
 
 
 
a176b28
89c6054
 
 
a176b28
89c6054
 
41f7e4f
a176b28
89c6054
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a176b28
89c6054
 
 
a176b28
89c6054
a176b28
89c6054
 
a176b28
89c6054
a176b28
 
 
89c6054
 
 
 
 
a176b28
 
89c6054
 
 
a176b28
89c6054
 
 
 
 
a176b28
89c6054
 
a176b28
89c6054
 
 
 
 
 
 
a176b28
 
9b173c6
a176b28
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
import tempfile
import os
import base64
import cv2
import io
import re
from together import Together
import releaf_ai  # this should still contain your SYSTEM_PROMPT

app = FastAPI()

# Together client configuration.
# The API key is read from the environment rather than committed to source
# control. A key that has ever been checked into the repository should be
# treated as leaked and rotated.
API_KEY = os.environ.get("TOGETHER_API_KEY")
if not API_KEY:
    # Fail at startup with an actionable message instead of on the first request.
    raise RuntimeError("TOGETHER_API_KEY environment variable is not set")
client = Together(api_key=API_KEY)
MODEL_NAME = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"

# System prompt sent with every /predict request; defined in the releaf_ai module.
SYSTEM_PROMPT = releaf_ai.SYSTEM_PROMPT

def encode_image_to_base64(image: Image.Image) -> str:
    """Serialize *image* as JPEG and return the bytes as base64 text.

    The returned string is what gets embedded after the
    ``data:image/jpeg;base64,`` prefix in the model request.
    """
    with io.BytesIO() as jpeg_buffer:
        image.save(jpeg_buffer, format="JPEG")
        raw_bytes = jpeg_buffer.getvalue()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

def extract_score(text: str):
    """Pull the numeric score out of the model's reply.

    Finds the first "Score:" label (any letter case), skips any whitespace
    after the colon (newlines included), and returns the digits that follow
    as an int. Returns None when no label followed by a number is present.
    """
    hit = re.search(r"(?i)Score:\s*(\d+)", text)
    if hit is None:
        return None
    digits = hit.group(1)
    return int(digits)

def extract_activity(text: str):
    """Return the value after the "Detected Activity:" label in the model reply.

    The label is matched case-insensitively. Whitespace after the colon is
    skipped, and the value runs to the end of its line. The value is stripped
    of surrounding whitespace, including a carriage return from CRLF line
    endings. Returns "Unknown" when the label is absent or nothing follows it.
    """
    # `.` never matches "\n", so `(.+)` stops at the end of the line on its own.
    # No trailing newline is required, so a label on the reply's last line
    # is still found.
    match = re.search(r"(?i)Detected Activity:\s*(.+)", text)
    activity = match.group(1).strip() if match else ""
    return activity or "Unknown"

# Videos are summarized as a _GRID_SIDE x _GRID_SIDE mosaic of evenly spaced
# frames, each resized to _FRAME_SIZE, so the model sees the clip as one image.
_GRID_SIDE = 3
_FRAME_SIZE = (256, 256)


def _video_bytes_to_grid(data: bytes) -> Image.Image:
    """Sample evenly spaced frames from an encoded video and tile them into a grid.

    ``cv2.VideoCapture`` is given a filesystem path, so *data* is written to a
    temporary file. The file and the capture handle are released even if
    decoding fails.

    Raises:
        HTTPException: 400 if no frame could be decoded.
    """
    n_frames = _GRID_SIDE * _GRID_SIDE
    fd, temp_path = tempfile.mkstemp()
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)

        cap = cv2.VideoCapture(temp_path)
        try:
            total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Clamp to 1 so very short (or unknown-length) clips still step forward.
            interval = max(total // n_frames, 1)
            frames = []
            for i in range(n_frames):
                cap.set(cv2.CAP_PROP_POS_FRAMES, i * interval)
                ret, frame = cap.read()
                if ret:
                    # OpenCV decodes to BGR; PIL expects RGB.
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(Image.fromarray(frame).resize(_FRAME_SIZE))
        finally:
            cap.release()
    finally:
        os.remove(temp_path)

    if not frames:
        raise HTTPException(status_code=400, detail="Could not decode any video frames")

    w, h = frames[0].size
    grid = Image.new("RGB", (_GRID_SIDE * w, _GRID_SIDE * h))
    for idx, frame in enumerate(frames):
        grid.paste(frame, ((idx % _GRID_SIDE) * w, (idx // _GRID_SIDE) * h))
    return grid


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Score an uploaded image or video with the vision model.

    Images are decoded and converted to RGB. Videos are reduced to a single
    grid image of sampled frames. The image is sent to the model together with
    SYSTEM_PROMPT.

    Returns:
        JSON with "points" (int score or None), "task" (detected activity or
        "Unknown") and "raw" (the model's full reply text).

    Raises:
        HTTPException: 400 for an unsupported content type, an undecodable
            image, or a video with no decodable frames; 500 for any other
            failure, with the exception message as the detail.
    """
    from fastapi.concurrency import run_in_threadpool

    # content_type can be None when the client sends no Content-Type header.
    content_type = file.content_type or ""
    try:
        if content_type.startswith("image"):
            data = await file.read()
            try:
                image = Image.open(io.BytesIO(data)).convert("RGB")
            except OSError as e:  # includes PIL.UnidentifiedImageError
                raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

        elif content_type.startswith("video"):
            data = await file.read()
            # OpenCV decoding is blocking CPU/disk work; keep it off the event loop.
            image = await run_in_threadpool(_video_bytes_to_grid, data)

        else:
            raise HTTPException(status_code=400, detail="Unsupported file type")

        b64_img = encode_image_to_base64(image)
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{b64_img}"}}
            ]}
        ]
        # The Together client call is synchronous network I/O; run it in a worker thread.
        res = await run_in_threadpool(
            client.chat.completions.create, model=MODEL_NAME, messages=messages
        )
        # message.content can be None; normalize so the regex helpers get a str.
        reply = res.choices[0].message.content or ""

        return JSONResponse({
            "points": extract_score(reply),
            "task": extract_activity(reply),
            "raw": reply
        })

    except HTTPException:
        # Preserve deliberate client errors (e.g. 400) instead of turning them into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e