| |
|
import os
from typing import Any, Dict

import cv2

from flow_modules.aiflows.OpenAIChatFlowModule import OpenAIChatAtomicFlow
from flows.utils.general_helpers import encode_from_buffer, encode_image
| |
|
| |
|
class VisionAtomicFlow(OpenAIChatAtomicFlow):
    """OpenAI chat flow whose user messages can carry images and video frames.

    A user message is built as a list of content parts: the formatted text
    prompt first, then any video frames, then any images found under
    ``input_data["data"]``.
    """

    # Lower-cased file extension -> MIME subtype used in ``data:image/...`` URLs.
    _IMAGE_MIME_SUBTYPES = {
        "jpg": "jpeg",
        "jpeg": "jpeg",
        "png": "png",
        "webp": "webp",
        "gif": "gif",
    }
    # Values accepted for an image spec's ``"type"`` key.
    _SUPPORTED_IMAGE_TYPES = ("local_path", "url")

    @staticmethod
    def get_image(image):
        """Build an ``image_url`` content part from an image spec.

        Args:
            image: dict with keys
                - ``"type"``: ``"local_path"`` or ``"url"``;
                - ``"image"``: a file path (``local_path``) or a URL (``url``);
                - ``"detail"`` (optional): copied into ``image_url["detail"]``
                  when present and not None.

        Returns:
            ``{"type": "image_url", "image_url": {"url": ...}}``. For a local
            file, the URL is a ``data:image/<subtype>;base64,<payload>`` URL
            whose payload comes from ``encode_image``.

        Raises:
            ValueError: if ``"type"`` is unsupported, or if a local file's
                extension is not one of the supported image extensions.
            KeyError: if ``"image"`` is missing.
        """
        image_type = image.get("type", None)
        supported_image_types = VisionAtomicFlow._SUPPORTED_IMAGE_TYPES
        # Explicit raise rather than ``assert``: asserts vanish under ``python -O``
        # and the spec would then silently produce ``url=None``.
        if image_type not in supported_image_types:
            raise ValueError(
                f"Must define a valid image type for every image \n your type: {image_type} \n supported types: {list(supported_image_types)} "
            )

        if image_type == "local_path":
            image_path = image["image"]
            # splitext only looks at the final path component, so dots in
            # directory names are ignored; lower() accepts e.g. "photo.JPG".
            extension = os.path.splitext(image_path)[1].lstrip(".").lower()
            mime_subtype = VisionAtomicFlow._IMAGE_MIME_SUBTYPES.get(extension)
            if mime_subtype is None:
                raise ValueError(
                    f"Unsupported image extension {extension!r} for {image_path!r}; "
                    f"supported extensions: {sorted(VisionAtomicFlow._IMAGE_MIME_SUBTYPES)}"
                )
            # The payload must follow "base64," directly -- no whitespace.
            url = f"data:image/{mime_subtype};base64,{encode_image(image_path)}"
        else:  # image_type == "url"
            url = image["image"]

        image_url = {"url": url}
        if image.get("detail") is not None:
            image_url["detail"] = image["detail"]
        return {"type": "image_url", "image_url": image_url}

    @staticmethod
    def _encode_frame(frame):
        """JPEG-encode one decoded frame and pass the buffer to ``encode_from_buffer``.

        NOTE(review): ``encode_from_buffer`` is defined elsewhere; presumably it
        returns base64 text -- confirm.

        Raises:
            ValueError: if OpenCV reports that JPEG encoding failed.
        """
        ok, buffer = cv2.imencode(".jpg", frame)
        if not ok:
            raise ValueError("Failed to JPEG-encode a video frame")
        return encode_from_buffer(buffer)

    @staticmethod
    def _is_forward_slice(start, end, step):
        """Return True when ``[start:end:step]`` can be applied while streaming.

        This holds when the start is a non-negative int, the end is None or a
        non-negative int, and the step is a positive int. Only then are the
        selected indices known without the total frame count.
        """
        return (
            isinstance(start, int) and start >= 0
            and (end is None or (isinstance(end, int) and end >= 0))
            and isinstance(step, int) and step > 0
        )

    @staticmethod
    def get_video(video):
        """Decode a video and return sampled frames as content parts.

        Args:
            video: dict with keys
                - ``"video_path"`` (required): path passed to ``cv2.VideoCapture``;
                - ``"resize"`` (default 768): copied unchanged into every frame part;
                - ``"frame_step_size"`` (default 10), ``"start_frame"`` (default 0),
                  ``"end_frame"`` (default None): frames are selected exactly as
                  ``all_frames[start_frame:end_frame:frame_step_size]``.

        Returns:
            A list of ``{"image": <encoded frame>, "resize": resize}`` dicts.

        Raises:
            ValueError: if the video cannot be opened or a frame fails to encode.
        """
        video_path = video["video_path"]
        resize = video.get("resize", 768)
        frame_step_size = video.get("frame_step_size", 10)
        start_frame = video.get("start_frame", 0)
        end_frame = video.get("end_frame", None)

        # Forward slices are applied while streaming: only the selected frames
        # are encoded and decoding stops at end_frame. Other slice forms depend
        # on the total frame count, so every frame is encoded and sliced afterwards.
        streaming = VisionAtomicFlow._is_forward_slice(start_frame, end_frame, frame_step_size)

        capture = cv2.VideoCapture(video_path)
        encoded_frames = []
        try:
            if not capture.isOpened():
                raise ValueError(f"Could not open video: {video_path}")
            frame_index = 0
            while True:
                if streaming and end_frame is not None and frame_index >= end_frame:
                    break
                success, frame = capture.read()
                if not success:
                    break
                selected = not streaming or (
                    frame_index >= start_frame
                    and (frame_index - start_frame) % frame_step_size == 0
                )
                if selected:
                    encoded_frames.append(VisionAtomicFlow._encode_frame(frame))
                frame_index += 1
        finally:
            # Always release the capture handle, even if encoding raised.
            capture.release()

        if not streaming:
            encoded_frames = encoded_frames[start_frame:end_frame:frame_step_size]
        return [{"image": encoded, "resize": resize} for encoded in encoded_frames]

    @staticmethod
    def get_user_message(prompt_template, input_data: Dict[str, Any]):
        """Build the content-part list for one user message.

        The text part comes first, followed by video frames (if
        ``input_data["data"]`` has ``"video"``) and then images (if it has
        ``"images"``). A missing or empty ``"data"`` entry yields a text-only
        message.
        """
        content = VisionAtomicFlow._get_message(prompt_template=prompt_template, input_data=input_data)
        media_data = input_data.get("data") or {}
        if "video" in media_data:
            content.extend(VisionAtomicFlow.get_video(media_data["video"]))
        if "images" in media_data:
            content.extend(VisionAtomicFlow.get_image(image) for image in media_data["images"])
        return content

    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
        """Format ``prompt_template`` with its input variables and wrap it as a text part.

        Raises:
            KeyError: if ``input_data`` lacks one of the template's input variables.
        """
        template_kwargs = {name: input_data[name] for name in prompt_template.input_variables}
        return [{"type": "text", "text": prompt_template.format(**template_kwargs)}]

    def _process_input(self, input_data: Dict[str, Any]):
        """Build this turn's user message and append it to the chat state.

        If the conversation is not yet initialized, it is initialized first.
        That first message uses ``init_human_message_prompt_template`` when it
        is set, and ``human_message_prompt_template`` otherwise. Later turns
        always use ``human_message_prompt_template``.
        """
        if self._is_conversation_initialized():
            prompt_template = self.human_message_prompt_template
        else:
            self._initialize_conversation(input_data)
            init_template = getattr(self, "init_human_message_prompt_template", None)
            prompt_template = (
                init_template if init_template is not None else self.human_message_prompt_template
            )

        user_message_content = self.get_user_message(prompt_template, input_data)
        self._state_update_add_chat_message(role=self.flow_config["user_name"],
                                            content=user_message_content)