| import torch |
| from PIL import Image |
|
|
| from transformers import Qwen2VLProcessor, AutoProcessor |
| from transformers.models.qwen2_vl.processing_qwen2_vl import Qwen2VLProcessorKwargs |
|
|
|
|
class QwenVLProcessor:
    """Turn a LLaVA-style conversation sample into Qwen-VL training inputs.

    Wraps a Hugging Face processor loaded with ``AutoProcessor.from_pretrained``.
    Attributes not defined on this class (``tokenizer``, ``image_processor``,
    ``_merge_kwargs``, ...) are forwarded to that wrapped processor.
    """

    # Chat-template role names for the query and answer turn of each round.
    ROLE = ('user', 'assistant')

    # Header the chat template is expected to emit before an assistant turn;
    # only used as a fallback split marker in ``build_prompt``.
    _ASSISTANT_HEADER = "<|im_start|>assistant\n"

    def __init__(self, max_length=512, pretrained_model_name_or_path=None):
        """Load the wrapped processor.

        Args:
            max_length: Stored as ``self.max_length``. NOTE(review): nothing in
                this class truncates to it -- confirm a caller/collator does.
            pretrained_model_name_or_path: Passed straight to
                ``AutoProcessor.from_pretrained``.
        """
        self.processor = AutoProcessor.from_pretrained(
            pretrained_model_name_or_path)
        self.max_length = max_length

    def __getattr__(self, name: str):
        """Forward attribute lookups that fail normally to the wrapped processor.

        ``processor`` is read from ``__dict__`` directly: going through
        ``self.processor`` on an instance whose ``__init__`` has not run
        (unpickling, ``copy``, ``__new__``) would re-enter this method forever.

        Raises:
            AttributeError: If there is no wrapped processor yet, or it lacks
                ``name``.
        """
        try:
            processor = self.__dict__["processor"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(processor, name)

    def build_prompt(self, query, answer, round=0, system=None):
        """Render one conversation round with the tokenizer's chat template.

        Args:
            query: Content of the user turn.
            answer: Content of the assistant turn, or ``None`` to build an
                inference prompt only.
            round: Zero-based round index; ``system`` is only inserted when it
                is 0. (Shadows the builtin; name kept for caller compatibility.)
            system: Optional system-message content.

        Returns:
            ``(prompt, answer_text)``. ``prompt`` is the rendered text up to
            and including the assistant-turn header. ``answer_text`` is the
            rendered remainder of the assistant turn, or ``None`` when
            ``answer`` is ``None``.

        Raises:
            ValueError: If the rendered conversation cannot be split into a
                prompt and an answer.

        NOTE(review): every round is rendered independently, so if the chat
        template injects a default system prompt when none is given, it will be
        repeated in each round -- confirm against the template in use.
        """
        messages = [{"role": self.ROLE[0], "content": query}]
        if round == 0 and system:
            messages.insert(0, {"role": "system", "content": system})

        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)
        if answer is None:
            return prompt, None

        messages.append({"role": self.ROLE[1], "content": answer})
        full = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False)

        # The generation prompt should be an exact prefix of the full render.
        # Slicing on it avoids mis-splitting when the marker also occurs inside
        # the message text.
        if full.startswith(prompt):
            return prompt, full[len(prompt):]

        # Fallback: split on the last assistant header.
        header = self._ASSISTANT_HEADER
        head, sep, tail = full.rpartition(header)
        if not sep:
            raise ValueError(
                f"chat template output contains no {header!r} to split on")
        return head + sep, tail

    def _expand_image_tokens(self, conversations, image_grid_thw):
        """Return a copy of ``conversations`` with ``<image>`` markers expanded.

        Each ``<image>`` in a ``"human"`` message becomes
        ``prod(image_grid_thw[k]) // merge_size**2`` ``<|image_pad|>`` tokens.
        Here ``k`` counts markers across the whole conversation. The input
        dicts are never modified.

        NOTE(review): pads are inserted without vision start/end wrapper
        tokens; confirm the dataset text supplies them if the model needs them.

        Raises:
            IndexError: If there are more ``<image>`` markers than rows in
                ``image_grid_thw``.
        """
        if image_grid_thw is None:
            return list(conversations)

        merge_length = self.image_processor.merge_size ** 2
        expanded = []
        index = 0  # next row of image_grid_thw; shared across all turns
        for msg in conversations:
            if msg['from'] == 'human':
                text = msg['value']
                while "<image>" in text:
                    num_tokens = int(image_grid_thw[index].prod()) // merge_length
                    text = text.replace(
                        "<image>", "<|placeholder|>" * num_tokens, 1)
                    index += 1
                text = text.replace("<|placeholder|>", "<|image_pad|>")
                # Copy instead of writing back into the caller's dict.
                msg = {**msg, 'value': text}
            expanded.append(msg)
        return expanded

    def __call__(self, data_dict, **kwargs):
        """Tokenize one sample into ``input_ids`` / ``labels`` plus image data.

        Args:
            data_dict: Mapping with a ``"conversations"`` list of
                ``{"from": ..., "value": ...}`` dicts and an optional
                ``"image"`` path. Messages are paired by position: even
                indices are queries, odd indices are answers. A trailing
                unpaired query gets a generation prompt. ``data_dict`` is not
                modified.
            **kwargs: Forwarded to the wrapped processor's ``_merge_kwargs``
                when building image-processor kwargs.

        Returns:
            dict with:
            - ``input_ids``: list of token ids.
            - ``labels``: list of token ids, with prompt tokens set to -100.
            - ``pixel_values``: image tensor, or ``None`` without an image.
            - ``image_grid_thw``: per-image patch grid, or ``None`` without an
              image.
        """
        conversations = data_dict["conversations"]

        # Only a single image path under the "image" key is consumed.
        image_path = data_dict.get("image", None)
        if image_path is not None:
            with Image.open(image_path) as img:
                images = [img.convert('RGB')]
            output_kwargs = self._merge_kwargs(
                Qwen2VLProcessorKwargs,
                tokenizer_init_kwargs=self.tokenizer.init_kwargs,
                **kwargs,
            )
            image_inputs = self.image_processor(
                images=images, videos=None, **output_kwargs["images_kwargs"])
            image_grid_thw = image_inputs["image_grid_thw"]
        else:
            image_inputs = {}
            image_grid_thw = None

        new_conversation = self._expand_image_tokens(conversations, image_grid_thw)

        input_ids, labels = [], []
        for i in range(0, len(new_conversation), 2):
            query = new_conversation[i]['value']
            answer = (new_conversation[i + 1]['value']
                      if i + 1 < len(new_conversation) else None)
            query, answer = self.build_prompt(query, answer, round=i // 2)

            query_ids = self.tokenizer(
                query, add_special_tokens=True,
                return_attention_mask=False)['input_ids']
            input_ids += query_ids
            # Prompt tokens are excluded from the loss.
            labels += [-100] * len(query_ids)
            if answer is not None:
                answer_ids = self.tokenizer(
                    answer, add_special_tokens=True,
                    return_attention_mask=False)['input_ids']
                input_ids += answer_ids
                labels += answer_ids

        return {
            "input_ids": input_ids,
            "labels": labels,
            'pixel_values': image_inputs.get('pixel_values', None),
            # Returned so pixel_values can be paired with each image's patch
            # grid downstream. NOTE(review): confirm the collator forwards it.
            'image_grid_thw': image_grid_thw,
        }
|
|