|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """PyTorch - Flax general utilities."""
|
|
|
| from pickle import UnpicklingError
|
|
|
| import jax
|
| import jax.numpy as jnp
|
| import numpy as np
|
| from flax.serialization import from_bytes
|
| from flax.traverse_util import flatten_dict
|
|
|
| from ..utils import logging
|
|
|
|
|
# Module-level logger. Note that `logging` here is the package's own `..utils.logging`
# wrapper (imported above), not the standard-library module.
logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_flax_checkpoint_in_pytorch_model(pt_model, model_file):
    """Deserialize a Flax checkpoint file and load its weights into a PyTorch model.

    Args:
        pt_model: The PyTorch model that receives the weights; it is passed on to
            ``load_flax_weights_in_pytorch_model``.
        model_file: Path to a msgpack-serialized Flax checkpoint, as read by
            ``flax.serialization.from_bytes``.

    Returns:
        Whatever ``load_flax_weights_in_pytorch_model`` returns (``pt_model`` with the weights loaded).

    Raises:
        OSError: If the file cannot be deserialized. When the file looks like a git-lfs pointer
            (its content starts with ``version``), the message explains how to fetch the real weights.
    """
    with open(model_file, "rb") as flax_state_f:
        serialized_state = flax_state_f.read()

    try:
        flax_state = from_bytes(None, serialized_state)
    except (UnpicklingError, ValueError) as e:
        # `from_bytes` decodes msgpack, whose decode errors (e.g. `msgpack.exceptions.ExtraData`)
        # subclass ValueError; catching UnpicklingError alone made this handler unreachable.
        if serialized_state.startswith(b"version"):
            # A git-lfs pointer file is a small text file beginning with "version ...".
            raise OSError(
                "You seem to have cloned a repository without having git-lfs installed. Please"
                " install git-lfs and run `git lfs install` followed by `git lfs pull` in the"
                " folder you cloned."
            ) from e
        raise OSError(f"Unable to convert {model_file} to Flax deserializable object. ") from e

    # Drop the raw bytes before materializing the PyTorch weights to keep peak memory down.
    del serialized_state

    return load_flax_weights_in_pytorch_model(pt_model, flax_state)
|
|
|
|
|
def _dot_numeric_suffixes(name):
    """Rewrite every ``_<digit>`` in ``name`` as ``.<digit>`` (e.g. ``down_blocks_0`` -> ``down_blocks.0``)."""
    for digit in "0123456789":
        name = name.replace(f"_{digit}", f".{digit}")
    return name


def _flax_param_to_pytorch(flax_name, flax_tensor):
    """Translate one flattened Flax parameter into a PyTorch state-dict entry.

    Args:
        flax_name: Dot-joined Flax parameter path, as produced by ``flatten_dict(..., sep=".")``.
        flax_tensor: The parameter array.

    Returns:
        A ``(pytorch_name, tensor)`` pair. A ``kernel`` leaf is renamed to ``weight`` and transposed
        (axes ``(3, 2, 0, 1)`` for 4-D kernels, ``.T`` otherwise); a ``scale`` leaf is renamed to
        ``weight`` without touching the tensor. Unless some path component is exactly
        ``time_embedding``, every ``_<digit>`` in the path becomes ``.<digit>``.
    """
    *parents, leaf = flax_name.split(".")

    if leaf == "kernel":
        leaf = "weight"
        if flax_tensor.ndim == 4:
            # Presumably Flax conv layout (H, W, in, out) -> PyTorch (out, in, H, W); confirm per model.
            flax_tensor = jnp.transpose(flax_tensor, (3, 2, 0, 1))
        else:
            # Presumably Flax Dense (in, out) -> PyTorch Linear (out, in).
            flax_tensor = flax_tensor.T
    elif leaf == "scale":
        leaf = "weight"

    path = [*parents, leaf]
    if "time_embedding" not in path:
        path = [_dot_numeric_suffixes(part) for part in path]

    return ".".join(path), flax_tensor


def load_flax_weights_in_pytorch_model(pt_model, flax_state):
    """Load flax checkpoints in a PyTorch model.

    Args:
        pt_model: Model exposing ``state_dict()`` / ``load_state_dict()`` (presumably a
            ``torch.nn.Module``). Its ``base_model_prefix`` attribute is overwritten with ``""``.
        flax_state: Nested mapping of Flax parameters; ``bfloat16`` leaves are upcast to ``float32``.

    Returns:
        ``pt_model`` itself, after ``load_state_dict`` has been called on it.

    Raises:
        ImportError: If PyTorch cannot be imported (an explanatory error is logged first).
        ValueError: If a Flax parameter maps onto a PyTorch key whose shape differs.

    Flax parameters without a PyTorch counterpart, and PyTorch parameters left without a Flax
    source, are reported via ``logger.warning`` instead of raising.
    """
    try:
        import torch
    except ImportError:
        logger.error(
            "Loading Flax weights in PyTorch requires both PyTorch and Flax to be installed. Please see"
            " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
            " instructions."
        )
        raise

    # Upcast any bfloat16 leaves to float32 before handing them to PyTorch.
    bf16_flags = flatten_dict(jax.tree_util.tree_map(lambda leaf: leaf.dtype == jnp.bfloat16, flax_state)).values()
    if any(bf16_flags):
        logger.warning(
            "Found ``bfloat16`` weights in Flax model. Casting all ``bfloat16`` weights to ``float32`` "
            "before loading those in PyTorch model."
        )
        flax_state = jax.tree_util.tree_map(
            lambda leaf: leaf.astype(np.float32) if leaf.dtype == jnp.bfloat16 else leaf, flax_state
        )

    pt_model.base_model_prefix = ""

    flat_flax_state = flatten_dict(flax_state, sep=".")
    pt_state = pt_model.state_dict()

    unexpected_keys = []
    still_missing = set(pt_state)

    for flax_name, flax_tensor in flat_flax_state.items():
        pt_name, flax_tensor = _flax_param_to_pytorch(flax_name, flax_tensor)

        if pt_name not in pt_state:
            unexpected_keys.append(pt_name)
            continue

        expected_shape = pt_state[pt_name].shape
        if flax_tensor.shape != expected_shape:
            raise ValueError(
                f"Flax checkpoint seems to be incorrect. Weight {flax_name} was expected "
                f"to be of shape {expected_shape}, but is {flax_tensor.shape}."
            )

        # torch.from_numpy needs a real numpy array (jax arrays are converted here).
        if not isinstance(flax_tensor, np.ndarray):
            flax_tensor = np.asarray(flax_tensor)
        pt_state[pt_name] = torch.from_numpy(flax_tensor)
        still_missing.remove(pt_name)

    pt_model.load_state_dict(pt_state)

    missing_keys = list(still_missing)
    model_name = pt_model.__class__.__name__

    if unexpected_keys:
        logger.warning(
            "Some weights of the Flax model were not used when initializing the PyTorch model"
            f" {model_name}: {unexpected_keys}\n- This IS expected if you are initializing"
            f" {model_name} from a Flax model trained on another task or with another architecture"
            " (e.g. initializing a BertForSequenceClassification model from a FlaxBertForPreTraining model).\n- This"
            f" IS NOT expected if you are initializing {model_name} from a Flax model that you expect"
            " to be exactly identical (e.g. initializing a BertForSequenceClassification model from a"
            " FlaxBertForSequenceClassification model)."
        )
    if missing_keys:
        logger.warning(
            f"Some weights of {model_name} were not initialized from the Flax model and are newly"
            f" initialized: {missing_keys}\nYou should probably TRAIN this model on a down-stream task to be able to"
            " use it for predictions and inference."
        )

    return pt_model
|
|
|