| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from PIL import Image |
| | from io import BytesIO |
| | import base64 |
| | import numpy as np |
| | import os, math, cv2, re |
| |
|
| | import torch |
| | from transformers import StoppingCriteria |
| | from apollo.constants import * |
| |
|
| | import tempfile |
| | from io import BytesIO |
| | from decord import VideoReader, cpu |
| |
|
| |
|
| |
|
def read_video_cv2(video_path, all_indices):
    """Decode the requested frames of a video with OpenCV.

    The video is read sequentially from the first frame up to the largest
    requested index; frames whose position is requested are converted from
    BGR to RGB and kept.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        all_indices: Sequence (list or array) of frame indices to return.
            Indices may repeat; a repeated index yields a repeated frame.

    Returns:
        np.ndarray: The selected RGB frames stacked along a new first axis,
        in the order given by ``all_indices``. Indices that could not be
        read (e.g. beyond the end of the stream) are silently skipped.

    Raises:
        ValueError: If ``all_indices`` is empty, or if no requested frame
            could be decoded (``np.stack`` of an empty list).
    """
    index_list = np.asarray(all_indices).tolist()
    # A set gives O(1) membership per decoded frame instead of scanning
    # the whole index sequence for every frame.
    wanted = set(index_list)
    max_index = max(index_list)

    frames_by_index = {}
    vidcap = cv2.VideoCapture(video_path)
    try:
        count = 0
        while count <= max_index:
            success, frame = vidcap.read()
            if not success:
                break
            if count in wanted:
                frames_by_index[count] = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            count += 1
    finally:
        # Always hand the decoder back, even if conversion raised.
        vidcap.release()

    frames = [frames_by_index[idx] for idx in index_list if idx in frames_by_index]
    return np.stack(frames)
| |
|
def read_video_decord(video_file, all_indices):
    """Fetch the frames at ``all_indices`` with a single-threaded decord reader on CPU 0.

    Returns the batch converted to a NumPy array via ``asnumpy()``.
    """
    reader = VideoReader(video_file, num_threads=1, ctx=cpu(0))
    batch = reader.get_batch(all_indices)
    return batch.asnumpy()
| |
|
| |
|
def read_video_decord_eval(video_file, all_indices):
    """Fetch the frames at ``all_indices`` with a decord reader built with default options.

    Unlike ``read_video_decord`` no thread count or context is passed to
    ``VideoReader``. Returns the batch converted via ``asnumpy()``.
    """
    reader = VideoReader(video_file)
    batch = reader.get_batch(all_indices)
    return batch.asnumpy()
| |
|
def load_frames_from_video(video_file, all_indices, video_decode_backend="decord", eval_=False):
    """Load the requested frames of a video with the appropriate decoder.

    Args:
        video_file: Path to the video file.
        all_indices: Frame indices to decode.
        video_decode_backend: ``"opencv"`` forces the OpenCV reader; any
            other value selects decord (unless the extension forces OpenCV).
        eval_: When using decord, choose ``read_video_decord_eval`` instead
            of ``read_video_decord``.

    Returns:
        The array of frames returned by the selected reader.
    """
    # Compare the extension case-insensitively so ".GIF" / ".WEBM" take the
    # same OpenCV route as their lowercase spellings.
    video_ending = os.path.splitext(video_file)[1].lower()
    if video_ending in ('.gif', '.webm') or video_decode_backend == "opencv":
        return read_video_cv2(video_file, all_indices)
    if eval_:
        return read_video_decord_eval(video_file, all_indices)
    return read_video_decord(video_file, all_indices)
| |
|
def pad_to_center_square(frames, mean_values):
    """
    Centre one frame or a batch of frames on a square canvas filled with ``mean_values``.

    Args:
        frames (np.array): Array of shape (H, W, C) or (N, H, W, C).
        mean_values (tuple): Per-channel fill colour for the padded area.

    Returns:
        np.array: Array of shape (N, S, S, C) with S = max(H, W); a single
        (H, W, C) input is returned with a leading axis of length 1.

    Raises:
        ValueError: If ``frames`` is neither 3- nor 4-dimensional.
    """
    if frames.ndim not in (3, 4):
        raise ValueError("Input array must be either of shape (H, W, C) or (N, H, W, C)")
    batch = frames[np.newaxis, :] if frames.ndim == 3 else frames

    count, rows, cols, depth = batch.shape
    side = max(cols, rows)
    fill = np.array(mean_values, dtype=batch.dtype)

    # Start from a canvas that is entirely the fill colour.
    canvas = np.full((count, side, side, depth), fill, dtype=batch.dtype)

    # Offsets that centre the original content on the canvas.
    row_start = (side - rows) // 2
    col_start = (side - cols) // 2
    canvas[:, row_start:row_start + rows, col_start:col_start + cols, :] = batch
    return canvas
| |
|
| |
|
def expand2square(pil_img, background_color):
    """Pad a PIL image to a square, centring the original content.

    NOTE: a second ``expand2square`` is defined later in this file and
    replaces this one at import time; this version is kept consistent
    with it.

    Args:
        pil_img: The PIL image to pad.
        background_color: Fill colour for the padding. For single-channel
            (``'L'``) images a tuple/list is reduced to its first element;
            a scalar is used unchanged.

    Returns:
        ``pil_img`` itself when it is already square, otherwise a new
        square image with the original pasted in the centre.
    """
    width, height = pil_img.size
    # Grayscale images take a single fill value, matching the later
    # definition's handling while still accepting a plain scalar.
    if pil_img.mode == 'L' and isinstance(background_color, (tuple, list)):
        background_color = background_color[0]
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    # One of the two offsets is always zero: pad top/bottom for wide
    # images, left/right for tall ones.
    result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return result
| |
|
| |
|
def calculate_sample_indices(clip_duration, frames_per_clip, total_frames, original_fps, video_duration, clip_sampling_ratio=1):
    """Compute per-clip frame indices and timestamps for clip-based sampling.

    Args:
        clip_duration: Length of one clip, in seconds.
        frames_per_clip: Number of frames sampled for each clip.
        total_frames: Number of frames in the source video.
        original_fps: Frame rate of the source video.
        video_duration: Duration of the source video, in seconds.
        clip_sampling_ratio: Multiplier applied to the number of clips
            (``video_duration / clip_duration``) before rounding up.

    Returns:
        tuple: ``(clip_indices, all_indices, timestamps)`` where
        ``clip_indices`` is a list with one int64 index array per clip,
        ``all_indices`` is the same indices flattened into one list, and
        ``timestamps`` is a list of ``(start, end)`` pairs obtained by
        dividing each clip's first and last index by ``original_fps``.

    NOTE(review): ``num_clips`` is 0 when ``video_duration`` (or
    ``clip_sampling_ratio``) is 0, in which case the floor division below
    raises ``ZeroDivisionError`` — confirm callers never pass such input.
    """
    # Frame rate at which frames are sampled inside a clip.
    sample_video_fps = frames_per_clip / clip_duration
    num_clips = math.ceil((video_duration / clip_duration) * clip_sampling_ratio)
    # Number of source frames between two consecutive sampled frames.
    frame_step = original_fps / sample_video_fps
    # Number of source frames available to each clip if the video is split evenly.
    partition_len = total_frames // num_clips
    all_indices, clip_indices, timestamps = [], [], []
    if frame_step > 0.5:
        # Use an integer stride of at least one source frame.
        frame_step = max(1, int(original_fps / sample_video_fps))
        # Span of source frames covered by one clip.
        clip_len = int(frames_per_clip * frame_step)
        sample_len = min(clip_len, total_frames)
        # Offset between the starts of consecutive clips when clips must overlap.
        clip_step = (total_frames - clip_len) // max(1, (num_clips - 1)) if total_frames > clip_len else 0
        for i in range(num_clips):
            if partition_len > clip_len:
                # The partition is longer than a clip: sample a window centred in partition i.
                start_idx = (partition_len - clip_len) // 2
                end_idx = start_idx + clip_len
                indices = np.arange(start_idx, end_idx, frame_step)
                indices = np.clip(indices, 0, partition_len-1).astype(np.int64)
                indices = indices+ i * partition_len

            else:
                # Partitions are too short: sample from the start of the video and shift by clip_step.
                indices = np.arange(0, sample_len, frame_step)
                if len(indices) < frames_per_clip:
                    # Pad with sample_len; the clip below turns the padding into the last valid index.
                    padding = np.full(frames_per_clip - len(indices), sample_len)
                    indices = np.concatenate((indices, padding))

                indices = np.clip(indices, 0, sample_len-1).astype(np.int64)
                indices = indices + i * clip_step

            clip_indices.append(indices)
            all_indices.extend(list(indices))

            # Timestamps (seconds) of the first and last sampled frame of this clip.
            start_time = (indices[0] / original_fps)
            end_time = (indices[-1] / original_fps)
            timestamps.append((start_time, end_time))

    else:
        # frame_step <= 0.5: the source frame rate is well below the sampling
        # rate, so each source frame is repeated.
        # Number of times each source frame is repeated.
        num_sample = int(np.ceil(1 / frame_step))

        # Span of source frames covered by one clip.
        clip_len = int(frames_per_clip * frame_step)

        # Index template with every frame repeated num_sample times.
        indices = np.repeat(np.arange(clip_len), num_sample)

        # Keep the clip length within the template length.
        clip_len = min(clip_len, len(indices))
        clip_step = (total_frames - clip_len) // max(1, (num_clips - 1)) if total_frames > clip_len else 0

        sample_len = min(clip_len, total_frames)
        if len(indices) < frames_per_clip:
            # Pad with sample_len; clipping in the loop maps the padding to sample_len - 1.
            padding = np.full(frames_per_clip - len(indices), sample_len)
            indices = np.concatenate((indices, padding))

        # Every clip reuses the same template, shifted by i * clip_step.
        for i in range(num_clips):
            current_clip_indices = np.clip(indices, 0, sample_len-1).astype(np.int64)
            current_clip_indices = current_clip_indices + i * clip_step

            # Record this clip's indices, per clip and flattened.
            clip_indices.append(current_clip_indices)
            all_indices.extend(current_clip_indices)

            # Timestamps (seconds) of the first and last sampled frame of this clip.
            start_time = (current_clip_indices[0] / original_fps)
            end_time = (current_clip_indices[-1] / original_fps)
            timestamps.append((start_time, end_time))

    return clip_indices, all_indices, timestamps
| |
|
def calculate_sample_indices_uniform(frames_per_clip, total_frames, uniform_frame_count, original_fps):
    """Sample ``uniform_frame_count`` frame indices uniformly and group them into clips.

    When the video has at least ``uniform_frame_count`` frames the indices
    are evenly spaced from the first to the last frame. Otherwise the full
    range of frame indices is repeated until enough indices exist.

    Args:
        frames_per_clip: Number of indices per clip; must be positive.
        total_frames: Number of frames in the video; must be positive.
        uniform_frame_count: Total number of indices to sample.
        original_fps: Frame rate used to convert indices to seconds.

    Returns:
        tuple: ``(clip_indices, all_indices, timestamps)`` where
        ``clip_indices`` is a list of index arrays (the last one may be
        shorter when ``uniform_frame_count`` is not a multiple of
        ``frames_per_clip``), ``all_indices`` is the flat list of sampled
        indices, and ``timestamps`` holds ``(start, end)`` seconds per clip.

    Raises:
        ValueError: If ``frames_per_clip`` or ``total_frames`` is not positive.
    """
    if frames_per_clip <= 0:
        raise ValueError("frames_per_clip must be positive")
    if total_frames <= 0:
        raise ValueError("total_frames must be positive")

    # The original body referenced these two names without ever defining them.
    N = uniform_frame_count
    # Round up so every sampled index belongs to some clip.
    num_clips = math.ceil(N / frames_per_clip)

    if total_frames >= N:
        # Enough frames: evenly spaced indices, no repetition needed.
        indices = np.linspace(0, total_frames - 1, N, dtype=int)
    else:
        # Too few frames: cycle through the video until N indices exist.
        repeats = math.ceil(N / total_frames)
        base_indices = np.arange(total_frames)
        indices = np.tile(base_indices, repeats)[:N]

    clip_indices = [
        indices[i * frames_per_clip: (i + 1) * frames_per_clip]
        for i in range(num_clips)
    ]

    # First and last sampled frame of each clip, converted to seconds.
    timestamps = [
        (clip[0] / original_fps, clip[-1] / original_fps)
        for clip in clip_indices
    ]

    all_indices = indices.tolist()
    return clip_indices, all_indices, timestamps
| |
|
| |
|
def get_video_details(fname):
    """Return ``(total_frames, original_fps, video_duration)`` read through decord.

    Asserts that ``fname`` exists and is at least 1 KiB in size before
    opening it. The duration is the frame count divided by the average
    frame rate reported by the reader.
    """
    assert os.path.exists(fname), f'video path not found {fname}'
    assert os.path.getsize(fname) >= 1 * 1024, f"video too short {fname}"

    reader = VideoReader(fname, num_threads=-1, ctx=cpu(0))
    frame_total = len(reader)
    fps = reader.get_avg_fps()
    return frame_total, fps, frame_total / fps
| |
|
| |
|
def get_video_details_cv2(fname):
    """
    Read frame count, frame rate and duration of a video through OpenCV.

    Args:
        fname (str): Path to the video file.

    Returns:
        tuple: ``(total_frames, original_fps, video_duration)`` — the frame
        count as an int, the frame rate as reported by OpenCV, and their
        quotient in seconds.

    Raises:
        AssertionError: If the file does not exist or is smaller than 1 KiB.
        ValueError: If the video cannot be opened or the reported FPS is zero.
    """
    assert os.path.exists(fname), f'Video path not found: {fname}'
    assert os.path.getsize(fname) >= 1 * 1024, f"Video too short: {fname}"

    capture = cv2.VideoCapture(fname)
    if not capture.isOpened():
        raise ValueError(f"Failed to open video file: {fname}")

    # Query both properties first, then release once for every outcome.
    frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = capture.get(cv2.CAP_PROP_FPS)
    capture.release()

    if fps == 0:
        raise ValueError(f"Failed to get FPS for video file: {fname}")
    return frame_total, fps, frame_total / fps
| |
|
| |
|
| | |
def split_into_clips(video, frames_per_clip):
    """Cut ``video`` into consecutive clips of ``frames_per_clip`` frames.

    Only complete clips are returned; trailing frames that do not fill a
    whole clip are dropped.
    """
    clip_count = len(video) // frames_per_clip
    clips = []
    for clip_no in range(clip_count):
        first = clip_no * frames_per_clip
        clips.append(video[first:first + frames_per_clip])
    return clips
| |
|
def process_image(vision_processors, frames_per_clip, image):
    """Preprocess one image with every vision processor and repeat it ``frames_per_clip`` times.

    For each processor the image is padded to a square with the processor's
    ``image_mean`` (scaled to 0-255), preprocessed to ``pixel_values`` and
    then repeated along a frame axis. Returns one tensor per processor.
    """
    outputs = []
    for processor in vision_processors:
        fill_color = tuple(int(channel * 255) for channel in processor.image_mean)
        squared = expand2square(image, fill_color)
        batch = np.expand_dims(np.asarray(squared), 0)
        pixels = processor.preprocess(batch, return_tensors='pt')['pixel_values'][0]
        pixels = pixels.unsqueeze(0)

        # Choose the axis to repeat along: 4-D output gets a new axis at
        # position 1; otherwise axis 1 is used when it has size 1, else axis 2.
        if len(pixels.shape) == 4:
            pixels = pixels.unsqueeze(1)
            repeat_axis = 1
        elif pixels.shape[1] == 1:
            repeat_axis = 1
        else:
            repeat_axis = 2

        outputs.append(pixels.repeat_interleave(frames_per_clip, dim=repeat_axis))
    return outputs
| |
|
def process_video(vision_processors, frames_per_clip, buffer):
    """Preprocess a frame buffer clip by clip with every vision processor.

    For each processor the buffer is padded to square frames using the
    processor's ``image_mean`` (scaled to 0-255), split into clips of
    ``frames_per_clip`` frames, and each clip is preprocessed to
    ``pixel_values``. Returns one stacked tensor of clips per processor.
    """
    outputs = []
    for processor in vision_processors:
        fill_color = tuple(int(channel * 255) for channel in processor.image_mean)
        squared_frames = pad_to_center_square(buffer, fill_color)

        clip_tensors = []
        for raw_clip in split_into_clips(squared_frames, frames_per_clip):
            pixels = processor.preprocess(raw_clip, return_tensors='pt')['pixel_values']
            # Some processors wrap their result in a one-element list.
            if type(pixels) is list:
                assert len(pixels)==1, "LazyVideoDataset: error, vision processor returned clip that is list of len>1 ."
                pixels = pixels[0]
            clip_tensors.append(pixels)

        outputs.append(torch.stack(clip_tensors))
    return outputs
| |
|
def load_video(video_file, vision_processors, clip_duration, frames_per_clip, clip_sampling_ratio=1, video_decode_backend='decord', eval_=False):
    """Sample clips from a video file and preprocess them.

    Reads the video's frame count, frame rate and duration, computes the
    clip sampling indices, decodes those frames and runs them through
    ``process_video``.

    Returns:
        tuple: ``(mm_data, timestamps)`` — the per-processor clip tensors
        and the ``(start, end)`` timestamp of each clip.
    """
    frame_total, fps, duration = get_video_details(video_file)
    sampling = calculate_sample_indices(
        clip_duration,
        frames_per_clip,
        frame_total,
        fps,
        duration,
        clip_sampling_ratio=clip_sampling_ratio,
    )
    frame_indices, clip_times = sampling[1], sampling[2]
    frames = load_frames_from_video(video_file, frame_indices, video_decode_backend, eval_)
    return process_video(vision_processors, frames_per_clip, frames), clip_times
| |
|
def load_video_uniform(video_file, vision_processors, clip_duration, frames_per_clip, clip_sampling_ratio=1, video_decode_backend='decord', eval_=False, uniform_sampling=8):
    """Load ``uniform_sampling`` evenly spaced frames from a video and preprocess them.

    ``clip_duration`` and ``clip_sampling_ratio`` are accepted but not used
    by this function.

    Returns:
        tuple: ``(mm_data, [])`` — the per-processor clip tensors and an
        empty timestamp list.
    """
    frame_total, _fps, _duration = get_video_details(video_file)
    frame_indices = np.linspace(0, frame_total - 1, uniform_sampling, dtype=int)
    print('using uniform frame sampled, sampled: ', len(frame_indices), ' frames')
    frames = load_frames_from_video(video_file, frame_indices, video_decode_backend, eval_)
    return process_video(vision_processors, frames_per_clip, frames), []
| |
|
| |
|
| |
|
class ApolloMMLoader:
    """Loads a video and builds the matching textual prompt fragment.

    Attributes are plain copies of the constructor arguments.
    """

    def __init__(self, vision_processors, clip_duration, frames_per_clip, num_repeat_token, device, model_max_length = 32768, clip_sampling_ratio=1, video_decode_backend="decord"):
        """Store the loader configuration.

        Args:
            vision_processors: Processors passed on to ``load_video``.
            clip_duration: Length of one clip in seconds.
            frames_per_clip: Number of frames sampled per clip.
            num_repeat_token: Number of video placeholder tokens per clip.
            device: Device the processed clips are moved to.
            model_max_length: Token budget used to cap the clip sampling ratio.
            clip_sampling_ratio: Requested fraction of clips to sample.
            video_decode_backend: Decoder backend forwarded to ``load_video``.
        """
        self.vision_processors = vision_processors
        self.clip_duration = clip_duration
        self.device = device
        self.frames_per_clip = frames_per_clip
        self.num_repeat_token = num_repeat_token
        self.clip_sampling_ratio = clip_sampling_ratio
        self.model_max_length = model_max_length
        self.video_decode_backend = video_decode_backend

    def vidprompt(self, num_clips, video_duration):
        """Return the sentence introducing the clips of a video."""
        # Imported here because no visible top-level import provides it.
        import datetime

        # NOTE(review): `num2words` is not imported in the visible part of
        # this file; presumably it arrives via `from apollo.constants import *`
        # — confirm, otherwise this raises NameError.
        return f"You are provided the following series of {num2words(num_clips)}, {self.clip_duration} second clips from a {datetime.timedelta(seconds=video_duration)} [H:MM:SS] video.\n"

    def load_video(self, video_file):
        """Load a video and build the prompt text describing its clips.

        The clip sampling ratio is capped so that the number of video
        tokens (clips times ``num_repeat_token``) stays within
        ``model_max_length``.

        Returns:
            tuple: ``(video, replace_string)`` — the per-processor clip
            tensors moved to ``self.device`` as bfloat16, and the prompt
            text listing each clip's time range with its placeholder tokens.
        """
        _, _, video_duration = get_video_details(video_file)
        clip_sampling_ratio = min(1, (self.model_max_length * self.clip_sampling_ratio) / (video_duration * self.num_repeat_token / self.clip_duration))

        # The module-level load_video computes the sample indices itself, so
        # they are not computed separately here.
        video, timestamps = load_video(
            video_file,
            self.vision_processors,
            self.clip_duration,
            self.frames_per_clip,
            clip_sampling_ratio=clip_sampling_ratio,
            video_decode_backend=self.video_decode_backend,
            eval_=True,
        )

        num_clips = len(video[0])
        video = [v.to(device=self.device, dtype=torch.bfloat16) for v in video]
        replace_string = self.vidprompt(num_clips, video_duration)

        temporal_prompt = [f"{round(clip[0], 1)}-{round(clip[1], 1)} seconds: {X_TOKEN['video'] * self.num_repeat_token}" for clip in timestamps]
        temporal_prompt = ',\n'.join(temporal_prompt)
        replace_string = replace_string + temporal_prompt

        return video, replace_string

    def load_image(self, image_file):
        """Placeholder: image loading is not implemented; prints a notice and returns None."""
        print('implement image loading')
        return None
| |
|
| |
|
def get_frame_from_vcap(vidcap, num_frames=10, fps=None, frame_count=None):
    """Extract ``num_frames`` PIL images from an opened OpenCV capture.

    When the video has at least ``num_frames`` frames, frames are picked at
    evenly spaced positions between the first frame and ``frame_count - 2``.
    When it has fewer, every frame is read and the list is left-padded with
    blank images of the same size.

    Args:
        vidcap: An opened ``cv2.VideoCapture``. It is not released here.
        num_frames: Number of images to return.
        fps: Frame rate of the video; queried from ``vidcap`` when either
            ``fps`` or ``frame_count`` is ``None``.
        frame_count: Number of frames; queried together with ``fps``.

    Returns:
        list: ``num_frames`` PIL images. Placeholder and repeated entries
        reference the same image object.

    Raises:
        ValueError: If the stream ends before enough frames were collected.
    """
    import cv2

    if fps is None or frame_count is None:
        # If either value is missing, both are re-read from the capture.
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    if fps == 0 or frame_count == 0:
        print("Video file not found. return empty images.")
        return [
            Image.new("RGB", (720, 720)),
        ] * num_frames

    frame_interval = frame_count // num_frames
    if frame_interval == 0 and frame_count <= 1:
        print("frame_interval is equal to 0. return empty image.")
        return [
            Image.new("RGB", (720, 720)),
        ] * num_frames

    images = []
    count = 0

    if frame_count >= num_frames:
        frame_indices = np.linspace(0, frame_count - 2, num_frames, dtype=int)
        # linspace can yield the same position more than once (always when
        # frame_count == num_frames); remember how often each position is
        # wanted so the requested number of images can still be reached.
        positions, occurrences = np.unique(frame_indices, return_counts=True)
        wanted = dict(zip(positions.tolist(), occurrences.tolist()))

        while True:
            success, frame = vidcap.read()
            if not success:
                # Stream ended early; never pass a failed read to cvtColor.
                break
            repeat = wanted.get(count, 0)
            if repeat:
                img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                images.extend([Image.fromarray(img)] * repeat)
                if len(images) >= num_frames:
                    return images
            count += 1
    else:
        # The video is shorter than requested: take every frame, then
        # left-pad with blank frames of the same size.
        while True:
            success, frame = vidcap.read()
            if success:
                img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                images.append(Image.fromarray(img))
                count += 1
                continue
            if count >= 1:
                width, height = images[-1].size
                # Compute the pad length before padding so the message
                # reports the real number of added frames.
                missing = num_frames - len(images)
                images = [Image.new("RGB", (width, height))] * missing + images
                print("padding frames:", missing)
                return images
            break

    raise ValueError("Did not find enough frames in the video. return empty image.")
| |
|
| |
|
def opencv_extract_frames(vpath_or_bytesio, frames=6, fps=None, frame_count=None):
    """
    Extract frames from a video using OpenCV.

    Args:
        vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
        frames (int): Number of frames to extract from the video.
        fps: Optional frame rate forwarded to ``get_frame_from_vcap``.
        frame_count: Optional frame count forwarded to ``get_frame_from_vcap``.

    Returns:
        list: List of PIL Images extracted from the video.

    Raises:
        NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
    """

    def _extract_from_path(path):
        # Open the file, extract, and always release the decoder.
        import cv2

        vidcap = cv2.VideoCapture(path)
        try:
            return get_frame_from_vcap(vidcap, frames, fps=fps, frame_count=frame_count)
        finally:
            vidcap.release()

    if isinstance(vpath_or_bytesio, str):
        return _extract_from_path(vpath_or_bytesio)
    if isinstance(vpath_or_bytesio, (BytesIO,)):
        # OpenCV reads from a path, so spill the bytes to a temporary file.
        with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
            temp_video.write(vpath_or_bytesio.read())
            # Flush so the bytes are on disk before OpenCV reopens the file
            # by name; otherwise it may see a truncated video.
            temp_video.flush()
            return _extract_from_path(temp_video.name)
    raise NotImplementedError(type(vpath_or_bytesio))
| |
|
| |
|
def load_image_from_base64(image):
    """Decode a base64-encoded image and open it as a PIL image."""
    raw_bytes = base64.b64decode(image)
    stream = BytesIO(raw_bytes)
    return Image.open(stream)
| |
|
| |
|
def expand2square(pil_img, background_color):
    """
    Pad a PIL image with a solid colour so that it becomes square.

    Parameters:
    - pil_img: The PIL image to pad.
    - background_color: Colour used for the padding; for images in mode
      'L' only its first element is used.

    Returns:
    - The square PIL image.

    A square input is returned unchanged. A wide image is padded at the
    top and bottom, a tall image at the left and right, with the original
    content centred.
    """
    img_w, img_h = pil_img.size
    # Single-channel images take a single fill value.
    if pil_img.mode == 'L':
        background_color = background_color[0]
    if img_w == img_h:
        return pil_img

    if img_w > img_h:
        side = img_w
        offset = (0, (img_w - img_h) // 2)
    else:
        side = img_h
        offset = ((img_h - img_w) // 2, 0)

    canvas = Image.new(pil_img.mode, (side, side), background_color)
    canvas.paste(pil_img, offset)
    return canvas
| |
|
| |
|
| |
|
def process_images(images, image_processor, model_cfg):
    """Process a list of images and stack the results when their shapes agree.

    ``image_processor`` is stored on ``model_cfg`` before processing.
    Returns a single stacked tensor when every processed item has the same
    shape, otherwise the list of processed items.

    NOTE(review): ``process_image`` as defined in this file takes
    ``(vision_processors, frames_per_clip, image)``, which does not match
    the ``(image, model_cfg, None)`` call below — confirm which
    ``process_image`` this function is meant to use.
    """
    model_cfg.image_processor = image_processor

    processed = []
    for image in images:
        processed.append(process_image(image, model_cfg, None))

    shapes_match = all(item.shape == processed[0].shape for item in processed)
    if shapes_match:
        return torch.stack(processed, dim=0)
    return processed
| |
|
| |
|
| |
|
| |
|
def tokenizer_mm_token(prompt, tokenizer, return_tensors=None):
    """Tokenize ``prompt``, replacing each multimodal placeholder with ``X_TOKEN_INDEX``.

    The prompt is split at every occurrence of a value of ``X_TOKEN``; the
    text pieces are tokenized with ``tokenizer`` and a single
    ``X_TOKEN_INDEX`` is inserted for each placeholder.

    Returns:
        A list of token ids, or a ``torch.long`` tensor when
        ``return_tensors == 'pt'``.
    """
    placeholder_re = re.compile('|'.join(map(re.escape, X_TOKEN.values())))

    ids = []
    cursor = 0
    tail_skip = 0
    for hit in placeholder_re.finditer(prompt):
        begin = hit.start()
        if begin > cursor:
            # Text between the previous placeholder and this one.
            ids.extend(tokenizer(prompt[cursor:begin]).input_ids)
        elif begin == 0:
            # Prompt starts with a placeholder: seed with the tokenizer's
            # output for the empty string and skip the first id of the tail.
            ids = tokenizer('').input_ids
            tail_skip = 1
        ids.append(X_TOKEN_INDEX)
        cursor = hit.end()

    if cursor < len(prompt):
        tail_ids = tokenizer(prompt[cursor:]).input_ids
        ids.extend(tail_ids[tail_skip:])

    if return_tensors == 'pt':
        return torch.tensor(ids, dtype=torch.long)
    return ids
| |
|
| |
|
def is_gemma_tokenizer(tokenizer):
    """Return True when the tokenizer's class name contains "gemma" (case-insensitive)."""
    class_name = tokenizer.__class__.__name__
    return class_name.lower().find("gemma") != -1
| |
|
| |
|
def get_model_name_from_path(model_path):
    """Derive a model name from a model or checkpoint path.

    Leading and trailing slashes are ignored. When the last path component
    starts with ``"checkpoint-"`` and has a parent component, the result is
    ``"<parent>_<checkpoint>"``; otherwise it is the last component alone.

    Args:
        model_path: Slash-separated path to a model directory.

    Returns:
        str: The derived model name.
    """
    parts = model_path.strip("/").split("/")
    last = parts[-1]
    # A bare "checkpoint-N" has no parent to prefix; return it unchanged
    # instead of indexing past the start of the list.
    if last.startswith("checkpoint-") and len(parts) > 1:
        return parts[-2] + "_" + last
    return last
| |
|
| |
|
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires once generated text contains a keyword.

    A keyword is detected either by an exact match of its token ids at the
    end of the sequence, or by finding its text in the decoded tail of the
    newly generated tokens.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """Tokenize the keywords and remember the prompt length.

        Args:
            keywords: Iterable of keyword strings.
            tokenizer: Tokenizer used to encode keywords and decode outputs.
            input_ids: Prompt ids; ``input_ids.shape[1]`` is taken as the
                length of the prompt.
        """
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS token so only the keyword's own ids remain.
            if (
                len(cur_keyword_ids) > 1
                and cur_keyword_ids[0] == tokenizer.bos_token_id
            ):
                cur_keyword_ids = cur_keyword_ids[1:]
            self.max_keyword_len = max(self.max_keyword_len, len(cur_keyword_ids))
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        """Check a single sequence (batch of size one) for any keyword."""
        seq_len = output_ids.shape[1]
        offset = min(seq_len - self.start_len, self.max_keyword_len)
        if offset <= 0:
            # Nothing generated yet (or no keywords). Slicing with -0 would
            # select the whole sequence and match keywords in the prompt.
            return False

        self.keyword_ids = [
            keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids
        ]
        for keyword_id in self.keyword_ids:
            keyword_len = keyword_id.shape[0]
            # Skip keywords that cannot fit; comparing tensors of different
            # lengths would raise or broadcast incorrectly.
            if keyword_len == 0 or keyword_len > seq_len:
                continue
            if (output_ids[0, -keyword_len:] == keyword_id).all():
                return True

        # Fall back to a text match on the decoded tail of the new tokens.
        outputs = self.tokenizer.batch_decode(
            output_ids[:, -offset:], skip_special_tokens=True
        )[0]
        return any(keyword in outputs for keyword in self.keywords)

    def __call__(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        """Return True only when every sequence in the batch hit a keyword."""
        return all(
            self.call_for_batch(output_ids[i].unsqueeze(0), scores)
            for i in range(output_ids.shape[0])
        )
| |
|