import os
from copy import deepcopy
from queue import Queue

import requests
import torch
from PIL import Image
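

# IMAGE_TOKEN_INDEX follows the LLaVA convention: a sentinel id that cannot
# collide with real vocabulary ids and marks where image features are spliced
# into the embedded token sequence. blacklist lists control tokens that are
# stripped from user input; max_num_images caps the images per query.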
IMAGE_TOKEN_INDEX = -200
blacklist = ['<image>', '<s>', '</s>']
max_num_images = 3


def input_moderation(texts: list[list[str]]):
    # Strip special tokens from every [human, gpt] pair so users cannot
    # inject control tokens into the prompt.
    for text_pair in texts:
        for b in blacklist:
            text_pair[0] = text_pair[0].replace(b, '')
            if text_pair[1] is not None:
                text_pair[1] = text_pair[1].replace(b, '')
    return texts
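# e.g. input_moderation([['Hi <image>', None]]) -> [['Hi ', None]]

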
def insert_image_placeholder(t, num_images, placeholder='<image>', sep='\n'):
    # Prepend one placeholder per image to the text.
    for _ in range(num_images):
        t = f"{placeholder}{sep}" + t
    return t
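# e.g. insert_image_placeholder('Describe the scene.', 2)
# -> '<image>\n<image>\nDescribe the scene.'

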
def get_conv(texts):
    # Convert [[human, gpt], ...] pairs into the flat conversation format
    # consumed by preprocess().
    ret = []
    for conv in texts:
        ret.append({'from': 'human', 'value': conv[0]})
        ret.append({'from': 'gpt', 'value': conv[1]})
    return ret
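# e.g. get_conv([['<image>\nWhat is this?', None]])
# -> [{'from': 'human', 'value': '<image>\nWhat is this?'},
#     {'from': 'gpt', 'value': None}]

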
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    # Tokenize the prompt around '<image>' placeholders, then rejoin the
    # chunks with image_token_index so the model can locate visual inputs.
    prompt_chunks = [tokenizer(chunk, add_special_tokens=False).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        # Interleave sep between the elements of X: [a, b] -> [a, sep, b].
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    # Keep a leading BOS token, if the tokenizer emitted one, exactly once.
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids
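# Illustrative example (token ids are made up; with add_special_tokens=False
# no BOS is emitted, so offset stays 0):
#   tokenizer_image_token('<image>\nDescribe this image.', tokenizer)
#   -> [-200, 13, 20355, ...]
# The -200 sentinel is later swapped for vision-tower features by the model.

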
def preprocess(tokenizer, data: list, return_tensors='pt'):
    '''
    data is a flat conversation, e.g.:
    [
        {'from': 'human', 'value': ...},
        {'from': 'gpt', 'value': ...},
    ]
    '''
    if not isinstance(data, list):
        raise ValueError('data must be a list')

    return preprocess_allava(tokenizer, data, return_tensors=return_tensors)


def preprocess_vicuna_v1(tokenizer, convs: list, return_tensors, maxlen=None) -> torch.Tensor:
    # Build input ids with the Vicuna v1 chat template:
    # "USER: {human} ASSISTANT: {gpt}</s>".
    input_ids = None
    for ind, conv in enumerate(convs):
        if ind % 2 == 0:
            # Human turn; may contain '<image>' placeholders.
            h = conv['value'].strip()
            h = f"USER: {h} "
            cur_input_ids = tokenizer_image_token(prompt=h, tokenizer=tokenizer, return_tensors=return_tensors)
        else:
            # Assistant turn; a None value means generation starts here.
            g = conv['value']
            if g is not None:
                cur_input_ids = tokenizer(f"ASSISTANT: {g}</s>", add_special_tokens=False, max_length=maxlen, truncation=True, return_tensors='pt').input_ids[0]
            else:
                cur_input_ids = tokenizer("ASSISTANT:", add_special_tokens=False, max_length=maxlen, truncation=True, return_tensors='pt').input_ids[0]
        input_ids = cur_input_ids if input_ids is None else torch.cat([input_ids, cur_input_ids])
    return input_ids
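# One Vicuna-v1 round corresponds to the string
#   'USER: <human text> ASSISTANT: <gpt text></s>'
# Note: preprocess() currently routes everything to preprocess_allava below.

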
def preprocess_allava(tokenizer, convs: list, return_tensors) -> torch.Tensor:
    # Build input ids with the Mistral/Llama-2 style instruction template:
    # "[INST] {human} [/INST] {gpt}</s>".
    input_ids = None
    for ind, conv in enumerate(convs):
        if ind % 2 == 0:
            # Human turn; may contain '<image>' placeholders.
            h = conv['value'].strip()
            h = f"[INST] {h} [/INST] "
            cur_input_ids = tokenizer_image_token(prompt=h, tokenizer=tokenizer, return_tensors=return_tensors)
            if input_ids is None:
                input_ids = cur_input_ids
            else:
                input_ids = torch.cat([input_ids, cur_input_ids])
        else:
            # Assistant turn; a None value means generation starts after '[/INST] '.
            g = conv['value']
            if g is not None:
                cur_input_ids = tokenizer(f"{g}{tokenizer.eos_token}", add_special_tokens=False, truncation=True, return_tensors='pt').input_ids[0]
                input_ids = torch.cat([input_ids, cur_input_ids])
    return input_ids
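# For texts=[['<image>\nWhat is shown?', None]] the encoded prompt is
# effectively '[INST] <image>\nWhat is shown? [/INST] ', with '<image>'
# mapped to IMAGE_TOKEN_INDEX in the id sequence.

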
def get_image_tensors(processor, images, device):
    # Load every image, pad it to a square with the processor's mean color,
    # and preprocess it. A None entry becomes an all-zero tensor so that
    # text-only queries still produce a consistently shaped batch.
    list_image_tensors = []
    crop_size = processor.crop_size
    for fp in images:
        if fp is None:
            list_image_tensors.append(torch.zeros(3, crop_size['height'], crop_size['width']).to(device))
            continue
        elif isinstance(fp, str):
            image = Image.open(fp).convert('RGB')
        elif isinstance(fp, Image.Image):
            image = fp
        else:
            raise TypeError(f'Unsupported type {type(fp)}')

        def expand2square(pil_img, background_color):
            # Pad the shorter side so the image becomes square, centering the
            # content on a background of the processor's mean color.
            width, height = pil_img.size
            if pil_img.mode == 'L':
                pil_img = pil_img.convert('RGB')
            if width == height:
                return pil_img
            elif width > height:
                result = Image.new(pil_img.mode, (width, width), background_color)
                result.paste(pil_img, (0, (width - height) // 2))
                return result
            else:
                result = Image.new(pil_img.mode, (height, height), background_color)
                result.paste(pil_img, ((height - width) // 2, 0))
                return result

        image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
        image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        list_image_tensors.append(image.to(device))
    return list_image_tensors
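# processor is expected to expose crop_size, image_mean and preprocess(),
# e.g. a transformers.CLIPImageProcessor. Padding to a square first avoids
# aspect-ratio distortion when the processor resizes/center-crops.

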
def build_allava_input(tokenizer, processor, texts, images, history=None, return_history=False, device='cuda'):
    '''
    texts: a plain string (one user query) or a list of [human, gpt] pairs
    images: a path/URL/PIL.Image, a list of them, or None
    '''
    # Normalize texts into a list of [human, gpt] pairs and prepend history.
    if isinstance(texts, str):
        texts = [[texts, None]]
    else:
        assert isinstance(texts, list) and isinstance(texts[0], list), 'texts must be a list of lists'

    if history is not None:
        texts = history + texts

    texts = input_moderation(texts)

    # Normalize and load images; each entry may be a local path, a URL,
    # or an already loaded PIL.Image.
    if isinstance(images, (str, Image.Image)):
        images = [images]
    if images is None:
        images = [None]

    valid_images = []
    for img in images:
        try:
            if img is None or isinstance(img, Image.Image):
                valid_images.append(img)
            elif os.path.exists(img):
                valid_images.append(Image.open(img).convert('RGB'))
            else:
                # Anything else is treated as a URL.
                valid_images.append(Image.open(requests.get(img, stream=True).raw).convert('RGB'))
        except Exception:
            # Skip images that cannot be opened or downloaded.
            continue
    images = valid_images

    # Text-only query: keep a single None so a dummy image tensor is built.
    if images == []:
        images = [None]

    assert len(images) <= max_num_images, f'Currently at most {max_num_images} images are supported'

    # Snapshot the conversation before '<image>' placeholders are inserted so
    # the returned history stays placeholder-free for the next round.
    history = deepcopy(texts)

    # Insert one '<image>' placeholder per real image into the first turn.
    num_real_images = len(images) if None not in images else 0
    texts[0][0] = insert_image_placeholder(texts[0][0], num_real_images)

    # Tokenize the full conversation and preprocess the images.
    conv = get_conv(texts)
    input_ids = preprocess(tokenizer, conv, return_tensors='pt').unsqueeze(0).to(device)

    list_image_tensors = get_image_tensors(processor, images, device)
    image_tensors = torch.stack(list_image_tensors)

    # Probe whether bf16 tensors work on the current GPU; fall back to fp16.
    # The image tensors are cast so they match a half-precision model.
    try:
        dtype = torch.bfloat16
        torch.tensor(1, dtype=dtype).cuda()
    except Exception:
        dtype = torch.float16
    image_tensors = image_tensors.to(dtype)

    if return_history:
        return input_ids, image_tensors, history

    return input_ids, image_tensors, None
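

# A minimal usage sketch (names are illustrative; any LLaVA-style model whose
# embedding layer understands IMAGE_TOKEN_INDEX will work):
#
#   from transformers import AutoTokenizer, CLIPImageProcessor
#   tokenizer = AutoTokenizer.from_pretrained(model_dir)
#   processor = CLIPImageProcessor.from_pretrained(vision_dir)
#   input_ids, image_tensors, history = build_allava_input(
#       tokenizer, processor,
#       texts='What is shown in this image?',
#       images='https://example.com/cat.jpg',
#       return_history=True, device='cuda')
#   output_ids = model.generate(input_ids, images=image_tensors, max_new_tokens=128)

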
class TextIterStreamer:
    # A minimal streamer compatible with transformers' generate(streamer=...)
    # protocol: generate() calls put() with new token ids and end() when done;
    # iterating yields the cumulative decoded text after each update.
    def __init__(self, tokenizer, skip_prompt=False, skip_special_tokens=False):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.skip_special_tokens = skip_special_tokens
        self.tokens = []
        self.text_queue = Queue()
        self.next_tokens_are_prompt = True

    def put(self, value):
        # The first call from generate() carries the prompt ids; skip them
        # when skip_prompt is set.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
        else:
            if len(value.shape) > 1:
                value = value[0]
            self.tokens.extend(value.tolist())
            # Re-decode the whole sequence so partial multi-byte characters
            # always render correctly.
            self.text_queue.put(
                self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens))

    def end(self):
        # Sentinel that terminates iteration.
        self.text_queue.put(None)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.text_queue.get()
        if value is None:
            raise StopIteration()
        return value
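# Typical usage (sketch): run generation on a background thread and consume
# the stream on the main thread.
#
#   from threading import Thread
#   streamer = TextIterStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#   Thread(target=model.generate,
#          kwargs=dict(inputs=input_ids, images=image_tensors,
#                      streamer=streamer, max_new_tokens=128)).start()
#   for text_so_far in streamer:
#       print(text_so_far)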