| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Any, Dict, List, Optional, Sequence |
|
|
| import torch |
| import transformers |
|
|
| from .conversation import default_conversation, SeparatorStyle |
| from .mm_utils import tokenizer_image_token |
| from .constants import IGNORE_INDEX, SENTINEL_TOKEN |
|
|
| |
| |
| |
| |
| |
|
|
# A fixed 10-round question/answer conversation used to probe the chat
# template (see `infer_stop_tokens`) without depending on real data.
# Built so that every turn is an INDEPENDENT dict: the tokenization helpers
# mutate message dicts in place (e.g. stripping "value"), and the previous
# `[...] * 10` literal aliased the same two dicts twenty times.
DUMMY_CONVERSATION = [
    {"from": "human", "value": "question"} if i % 2 == 0 else {"from": "gpt", "value": "answer"}
    for i in range(20)
]
|
|
|
|
def tokenize_conversation_legacy(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    """Tokenize a conversation using the legacy separator-style prompt format.

    Args:
        messages: Non-empty sequence of turns, each a dict with "from"
            ("human"/"gpt") and "value" keys.
        tokenizer: Tokenizer used to encode the rendered prompt.
        add_generation_prompt: If True, append an empty assistant turn so the
            rendered prompt ends ready for generation.
        overrides: Optional map from sender ("human"/"gpt") to replacement text
            applied to every turn from that sender.
        no_system_prompt: If True, render the prompt with an empty system message.

    Returns:
        1-D tensor of input ids for the rendered prompt.
    """
    conv = default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    if no_system_prompt:
        conv.system = ""

    # Work on a copy: the original code appended the generation turn directly
    # to the caller's list, mutating it as a side effect.
    messages = list(messages)

    # Drop a leading non-human turn so the conversation starts with the user.
    if messages[0]["from"] != "human":
        messages = messages[1:]

    # A trailing assistant turn with no value makes the prompt end with the
    # assistant header, ready for generation.
    if add_generation_prompt:
        messages.append({"from": "gpt", "value": None})

    conv.messages = []
    for turn, message in enumerate(messages):
        role = roles[message["from"]]
        # Turns must strictly alternate human/gpt after the cleanup above.
        assert role == conv.roles[turn % 2]
        if overrides is not None and message["from"] in overrides:
            conv.append_message(role, overrides[message["from"]])
        else:
            conv.append_message(role, message["value"])

    return tokenizer_image_token(conv.get_prompt(), tokenizer, return_tensors="pt")
|
|
|
|
def tokenize_conversation(
    messages: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    add_generation_prompt: bool = False,
    overrides: Optional[Dict[str, str]] = None,
    no_system_prompt: bool = False,
) -> torch.Tensor:
    """Tokenize a conversation with the tokenizer's chat template.

    Falls back to `tokenize_conversation_legacy` when the active conversation
    template does not use SeparatorStyle.AUTO.

    Args:
        messages: Sequence of turns, each a dict with "from" ("human"/"gpt")
            and "value" keys.
        tokenizer: Tokenizer whose chat template renders the prompt.
        add_generation_prompt: If True, end the prompt with the assistant
            header, ready for generation.
        overrides: Optional map from sender ("human"/"gpt") to replacement text
            applied to every turn from that sender.
        no_system_prompt: If True, prepend an empty system message.

    Returns:
        1-D tensor of input ids for the rendered prompt.

    Raises:
        ValueError: If a turn has a sender other than "human" or "gpt"
            (AUTO path only).
    """
    # Normalize whitespace on shallow copies: the original code stripped
    # message["value"] in place, mutating the caller's dicts.
    messages = [{**m, "value": m["value"].strip()} for m in messages]

    if default_conversation.sep_style != SeparatorStyle.AUTO:
        return tokenize_conversation_legacy(
            messages,
            tokenizer,
            add_generation_prompt=add_generation_prompt,
            overrides=overrides,
            no_system_prompt=no_system_prompt,
        )

    # Convert the "from"/"value" schema to the "role"/"content" schema that
    # `apply_chat_template` expects.
    conversation = []
    for m in messages:
        if m["from"] == "human":
            role = "user"
        elif m["from"] == "gpt":
            role = "assistant"
        else:
            raise ValueError(f"Unexpected sender '{m['from']}' in conversation entry.")

        content = m["value"]
        if overrides is not None and m["from"] in overrides:
            content = overrides[m["from"]]
        conversation.append({"role": role, "content": content})

    if no_system_prompt:
        conversation = [{"role": "system", "content": ""}] + conversation

    text = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=add_generation_prompt,
        tokenize=False,
    )
    return tokenizer_image_token(text, tokenizer, return_tensors="pt")
|
|
|
|
def _maybe_add_sentinel_token(tokenizer: transformers.PreTrainedTokenizer) -> None:
    """Register SENTINEL_TOKEN on *tokenizer* as a special token, exactly once.

    No-op when the tokenizer already carries a `sentinel_token` attribute;
    otherwise adds the token to the vocabulary and caches both the token
    string and its id on the tokenizer object.
    """
    if hasattr(tokenizer, "sentinel_token"):
        return
    tokenizer.add_tokens([SENTINEL_TOKEN], special_tokens=True)
    tokenizer.sentinel_token = SENTINEL_TOKEN
    tokenizer.sentinel_token_id = tokenizer.convert_tokens_to_ids(SENTINEL_TOKEN)
|
|
|
|
def preprocess_conversation(
    conversation: Sequence[Dict[str, str]],
    tokenizer: transformers.PreTrainedTokenizer,
    no_system_prompt: bool = False,
    retried: bool = False,
) -> Dict[str, Any]:
    """Tokenize *conversation* and build training labels that supervise only
    the assistant ("gpt") responses.

    Args:
        conversation: Sequence of turns with "from"/"value" keys.
        tokenizer: Tokenizer used for both tokenization passes.
        no_system_prompt: Forwarded to `tokenize_conversation`.
        retried: Internal flag; when True, use the more aggressive sentinel
            masking (see below). Callers should leave it False.

    Returns:
        Dict with "input_ids" (1-D tensor) and "labels" (same shape; tokens
        outside assistant responses are IGNORE_INDEX).
    """
    inputs = tokenize_conversation(conversation, tokenizer, no_system_prompt=no_system_prompt)
    # Start fully masked; response positions are filled in below.
    # full_like replaces the original `ones_like(...) * IGNORE_INDEX` idiom.
    labels = torch.full_like(inputs, IGNORE_INDEX)

    # Second pass: re-tokenize with every assistant turn replaced by the
    # sentinel token, yielding a "template" of the non-response scaffolding.
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(
        conversation, tokenizer, overrides={"gpt": SENTINEL_TOKEN}, no_system_prompt=no_system_prompt
    )

    # Drop each sentinel and the token immediately after it (the separator
    # that ends a response). On retry, also drop the token before the
    # sentinel — presumably it can merge with the response text in some
    # tokenizers, breaking alignment.
    mask = torch.ones_like(template, dtype=torch.bool)
    for k in range(template.size(0) - 1):
        if template[k] == tokenizer.sentinel_token_id:
            mask[k : k + 2] = False
            if k > 0 and retried:
                mask[k - 1] = False
    template = template[mask]

    # Greedy alignment: input tokens that match the next template token are
    # scaffolding; everything else belongs to a response and gets labeled.
    p = 0
    for k in range(inputs.size(0)):
        if p < template.size(0) and inputs[k] == template[p]:
            p += 1
        else:
            labels[k] = inputs[k]

    # If the template was not fully consumed the alignment failed: retry once
    # with the aggressive masking, then give up and mask the whole sample.
    if p < template.size(0):
        if not retried:
            return preprocess_conversation(
                conversation,
                tokenizer,
                no_system_prompt=no_system_prompt,
                retried=True,
            )
        print(f"Failed to process the conversation: '{conversation}'. All tokens will be masked in the label.")
        labels[:] = IGNORE_INDEX

    return {"input_ids": inputs, "labels": labels}
|
|
|
|
def infer_stop_tokens(tokenizer: transformers.PreTrainedTokenizer) -> List[str]:
    """Infer the stop strings for *tokenizer*'s chat template.

    Tokenizes a dummy conversation with every assistant turn replaced by the
    sentinel token; whatever token directly follows each sentinel is the
    separator that ends a response, so it is collected as a stop token
    alongside the EOS token.
    """
    _maybe_add_sentinel_token(tokenizer)
    template = tokenize_conversation(DUMMY_CONVERSATION, tokenizer, overrides={"gpt": SENTINEL_TOKEN})

    sentinel_id = tokenizer.sentinel_token_id
    stop_tokens = {tokenizer.eos_token}
    for k, token in enumerate(template[:-1]):
        if token == sentinel_id:
            stop_tokens.add(tokenizer.decode(template[k + 1]))
    return list(stop_tokens)
|
|