| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Generic utilities |
| """ |
|
|
| from collections import OrderedDict |
| from dataclasses import fields |
| from typing import Any, Tuple |
|
|
| import numpy as np |
|
|
| from .import_utils import is_paddle_available |
|
|
|
|
def is_tensor(x):
    """
    Tests if `x` is a `paddle.Tensor` or `np.ndarray`.

    Args:
        x: Any object.

    Returns:
        bool: ``True`` if `x` is a `paddle.Tensor` (only checked when `is_paddle_available()` is true) or an
        `np.ndarray`, ``False`` otherwise.
    """
    if is_paddle_available():
        # Imported inside the availability guard so this module can be imported without paddle installed.
        import paddle

        # Don't return paddle's verdict directly: a NumPy array is not a paddle tensor, but it must still be
        # reported as a tensor per the contract above, so fall through to the ndarray check.
        if paddle.is_tensor(x):
            return True

    return isinstance(x, np.ndarray)
|
|
|
|
class BaseOutput(OrderedDict):
    """
    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
    tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
    python dictionary.

    Subclasses are expected to be decorated with `@dataclass`; `__post_init__` then copies the non-`None` fields into
    the mapping, and item/attribute assignments are kept in sync afterwards.

    <Tip warning={true}>

    You can't unpack a `BaseOutput` directly (iterating it yields its keys, as for any dict). Use the
    [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple before.

    </Tip>
    """

    def __post_init__(self):
        """
        Populate the mapping from the dataclass fields after the generated `__init__` has run.

        If every field except the first is `None` and the first field is a `dict`, that dict's items become the keys
        of this output. Otherwise each non-`None` field is stored under its field name.

        Raises:
            ValueError: If the subclass declares no dataclass fields.
        """
        class_fields = fields(self)

        # Safety and consistency checks
        if not class_fields:
            raise ValueError(f"{self.__class__.__name__} has no fields.")

        first_field = getattr(self, class_fields[0].name)
        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])

        if other_fields_are_none and isinstance(first_field, dict):
            # Expand the dict into keys. NOTE(review): the first field's attribute still holds the dict itself,
            # so `to_tuple()` yields `(that_dict,)` while the mapping's keys are the dict's keys.
            for key, value in first_field.items():
                self[key] = value
        else:
            for field in class_fields:
                v = getattr(self, field.name)
                if v is not None:
                    self[field.name] = v

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __getitem__(self, k):
        """
        Look up `k` as a key when it is a string; otherwise index into `to_tuple()` (integer or slice).

        Raises:
            KeyError: If `k` is a string that is not a key.
            IndexError: If `k` is an integer outside the range of `to_tuple()`.
        """
        if isinstance(k, str):
            # Direct mapping lookup: O(1), instead of copying every item into a temporary dict per access.
            return super().__getitem__(k)
        return self.to_tuple()[k]

    def __setattr__(self, name, value):
        # Keep an existing key in sync with its attribute. Names that are not yet keys, and `None` values, only
        # update the attribute (the stored key value is left untouched).
        if name in self.keys() and value is not None:
            # Call the base __setitem__ so this override and __setitem__ never re-enter each other.
            super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        # Store in the mapping first, then mirror the value as an attribute.
        super().__setitem__(key, value)
        # Call the base __setattr__ so this override and __setattr__ never re-enter each other.
        super().__setattr__(key, value)

    def to_tuple(self) -> Tuple[Any, ...]:
        """
        Convert self to a tuple containing all the attributes/keys that are not `None`.

        Values are read from the dataclass fields in declaration order; keys added by item assignment that are not
        fields are not included.
        """
        values = (getattr(self, field.name, None) for field in fields(self))
        return tuple(value for value in values if value is not None)
|
|