| import torch |
| from transformers import AutoProcessor, LlavaForConditionalGeneration |
| from PIL import Image |
| import base64 |
| import io |
|
|
class EndpointHandler():
    """Inference handler that captions base64-encoded images with a LLaVA model.

    The processor and ``LlavaForConditionalGeneration`` model are loaded once
    from ``model_path`` in ``__init__``; each request is served by ``__call__``.
    """

    # Dtype the model weights are loaded in; pixel values are cast to match.
    DTYPE = torch.bfloat16
    # Prompt used when the request does not supply one (or supplies an empty one).
    DEFAULT_PROMPT = "Generate a caption for this image."
    SYSTEM_PROMPT = "You are a helpful image captioner."
    # Sampling settings passed to ``generate`` for every image (read-only).
    GENERATION_KWARGS = {
        "max_new_tokens": 300,
        "do_sample": True,
        "temperature": 0.6,
        "top_p": 0.9,
    }

    def __init__(self, model_path=""):
        """Load the processor and model from ``model_path``.

        Args:
            model_path: Local directory or hub id passed to ``from_pretrained``.
        """
        # Kept for callers that read it; inputs are placed on ``self.model.device``.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = LlavaForConditionalGeneration.from_pretrained(
            model_path,
            torch_dtype=self.DTYPE,
            device_map="auto" if torch.cuda.is_available() else None,
        )
        self.model.eval()

    @staticmethod
    def _b64_payload(image_b64):
        """Return the base64 body of ``image_b64``, dropping a ``data:...,`` URL prefix.

        Non-``str`` values (e.g. ``bytes``) are returned unchanged.
        """
        # The comma cannot occur in base64 text, so everything after the first
        # comma of a data URL is the payload.
        if isinstance(image_b64, str) and image_b64.startswith("data:"):
            return image_b64.partition(",")[2]
        return image_b64

    @staticmethod
    def _decode_image(image_b64):
        """Decode one base64 string (optionally a data URL) into an RGB PIL image."""
        raw = base64.b64decode(EndpointHandler._b64_payload(image_b64))
        # convert() forces PIL to actually read the data, so corrupt images fail here.
        return Image.open(io.BytesIO(raw)).convert("RGB")

    def _caption_image(self, convo_string, image):
        """Generate and return a caption string for a single image.

        Args:
            convo_string: Chat-templated prompt text (one image slot).
            image: The image to caption.
        """
        # One prompt paired with exactly one image, so placeholder count matches.
        model_inputs = self.processor(
            text=[convo_string],
            images=[image],
            return_tensors="pt",
        )
        # With device_map="auto" the model decides its input device.
        device = self.model.device
        model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
        if "pixel_values" in model_inputs:
            model_inputs["pixel_values"] = model_inputs["pixel_values"].to(self.DTYPE)

        generate_ids = self.model.generate(**model_inputs, **self.GENERATION_KWARGS)

        # generate() returns prompt + continuation; keep only the new tokens.
        prompt_len = model_inputs["input_ids"].shape[1]
        return self.processor.tokenizer.decode(
            generate_ids[0, prompt_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        ).strip()

    def __call__(self, data):
        """Caption the images in a request payload.

        Args:
            data: Request payload of the form
                ``{"inputs": {"images": <b64 str or list of b64 str>, "prompt": <str, optional>}}``.

        Returns:
            ``{"captions": str}`` for one image, ``{"captions": [str, ...]}`` for
            several (one caption per image, in input order), or
            ``{"error": str}`` when the payload is invalid.
        """
        inputs = data.get("inputs", {}) if isinstance(data, dict) else None
        if not isinstance(inputs, dict):
            return {"error": "Payload 'inputs' must be a JSON object."}

        prompt = inputs.get("prompt") or self.DEFAULT_PROMPT
        images_b64 = inputs.get("images")

        # Check emptiness before wrapping so "" is reported as "no images".
        if not images_b64:
            return {"error": "No images provided in the payload."}
        if isinstance(images_b64, str):
            images_b64 = [images_b64]

        try:
            images = [self._decode_image(img_b64) for img_b64 in images_b64]
        except Exception as e:  # request boundary: report, don't crash the worker
            return {"error": f"Failed to decode image: {e}"}

        conversation = [
            {"role": "system", "content": self.SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ]
        convo_string = self.processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )
        if not isinstance(convo_string, str):
            return {"error": "Failed to create conversation string."}

        # Generate per image: the templated prompt holds a single image slot,
        # so pairing it with all images at once would mismatch image tokens.
        captions = [self._caption_image(convo_string, image) for image in images]

        return {"captions": captions if len(captions) > 1 else captions[0]}
|
|