File size: 7,567 Bytes
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2022-2023 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import tensorflow as tf
def grayscale_not_supported(images, function_name=None):
    """
    Assert that the last dimension of a batch of images is equal to 3.

    The check is delegated to `tf.debugging.assert_equal`, which is given
    an error message stating that grayscale images are not supported by
    the calling augmentation function.

    Args:
        images (tf.Tensor): batch of images
        function_name (str): name of the augmentation function, only used
            to build the error message
    """
    # The channel count is read from the dynamic shape (last axis).
    channels = tf.shape(images)[-1]
    expected_channels = tf.constant(3)
    error_text = f"\nFunction `{function_name}`: grayscale images are not supported."
    tf.debugging.assert_equal(channels, expected_channels, message=error_text)
def check_dataaug_argument(arg, arg_name, function_name=None, data_type=None, tuples=1):
    """
    This function is a utility that checks the data types of arguments
    used in data augmentation functions. Argument values may be integers,
    floats or tuples of length 2 (lists of length 2 are handled like tuples).

    Inputs:
        arg:
            The argument to check.
        arg_name:
            A string, the name of the argument.
        function_name:
            A string, the name of the function calling check_dataaug_argument().
        data_type:
            Specifies the expected data type for the argument. It may
            be int, float, (int, float) or (float, int). Any other value
            (including the default None) disables the data type check.
        tuples:
            An integer, specifies how tuples should be handled:
                0: tuples are not accepted, scalars only.
                1: both tuples and scalars are accepted.
                2: scalars are not accepted, tuples only.
            By default, both scalars and tuples are accepted.

    Returns:
        None. The function only raises when the argument is invalid.

    Raises:
        ValueError: if `tuples` is not 0, 1 or 2, if `arg` is None, if `arg`
            is not an int, float, tuple or list, if a tuple is received
            while `tuples` is 0, if a scalar is received while `tuples` is 2,
            if a tuple does not have exactly 2 elements, if the right value
            of a tuple is not strictly greater than the left one, or if a
            value does not have the type specified by `data_type`.
    """

    def _check_data_type(value):
        # Check that the data type of the value is as expected.
        # The comparison is made on the exact type, so subclasses of int
        # such as bool are rejected when a data type is requested.
        value_type = type(value)
        if data_type == int and value_type != int:
            raise ValueError("\nArgument `{}` of function `{}`: expecting integer values. "
                             "Received {}".format(arg_name, function_name, value))
        if data_type == float and value_type != float:
            raise ValueError("\nArgument `{}` of function `{}`: expecting float values. "
                             "Received {}".format(arg_name, function_name, value))
        if data_type == (int, float) or data_type == (float, int):
            if value_type != int and value_type != float:
                raise ValueError("\nArgument `{}` of function `{}`: expecting float or integer values. "
                                 "Received {}".format(arg_name, function_name, value))

    if tuples not in (0, 1, 2):
        # The function is used incorrectly. The literal braces of the set
        # must be doubled: left unescaped, str.format() treats `{0, 1, 2}`
        # as a replacement field and raises KeyError instead of the
        # intended ValueError.
        raise ValueError("\nArgument `tuples` of `check_dataaug_argument()` (checking argument `{}` "
                         "of function `{}`): expecting an integer value in {{0, 1, 2}}. "
                         "Received {}".format(arg_name, function_name, tuples))

    if arg is None:
        # The argument is not set.
        raise ValueError("\nFunction `{}`: the argument `{}` is not set. "
                         "Received None.".format(function_name, arg_name))

    if not isinstance(arg, (int, float, tuple, list)):
        raise ValueError("\nArgument `{}` of function `{}`: invalid data type. "
                         "Received {}".format(arg_name, function_name, arg))

    is_pair = isinstance(arg, (tuple, list))

    if tuples == 0 and is_pair:
        # Tuples are not accepted, scalars only.
        raise ValueError("\nArgument `{}` of function `{}`: tuples are not supported. "
                         "Received {}".format(arg_name, function_name, arg))

    if tuples == 2 and not is_pair:
        # Only tuples are accepted, no scalars.
        raise ValueError("\nArgument `{}` of function `{}`: expecting a tuple of length 2. "
                         "Received {}".format(arg_name, function_name, arg))

    if is_pair:
        if len(arg) != 2:
            raise ValueError("\nArgument `{}` of function `{}`: the tuple should have 2 elements. "
                             "Received {}".format(arg_name, function_name, arg))
        # The tuple must be strictly increasing (left < right).
        if arg[1] <= arg[0]:
            raise ValueError("\nArgument `{}` of function `{}`: the tuple right value should be greater "
                             "than the left value. Received{}".format(arg_name, function_name, arg))
        _check_data_type(arg[0])
        _check_data_type(arg[1])
    else:
        _check_data_type(arg)
def remap_pixel_values_range(images, input_range, output_range, dtype=tf.float32):
    """
    Linearly remap the pixel values of a batch of images from one range
    of values to another, then cast the result.

    For example, pixels with values in the [0, 255] interval and uint8
    data type can be remapped to pixels in the [0.0, 1.0] interval with
    float data type.

    When the input and output ranges are equal, the images are only cast
    to `dtype`. Otherwise the computation is done in float32 and the result
    is clipped to the output range before the final cast.

    Args:
        images (tf.Tensor): batch of input images
        input_range (tuple): pixels range at input, as (low, high)
        output_range (tuple): pixels range at output, as (low, high)
        dtype: type to cast output images

    Returns:
        (tf.Tensor): batch of remapped and casted images
    """
    needs_remap = input_range != output_range
    if not needs_remap:
        # Same range on both sides: nothing to rescale, only cast.
        return tf.cast(images, dtype)

    in_low, in_high = input_range
    out_low, out_high = output_range

    pixels = tf.cast(images, tf.float32)
    # Linear map sending in_low to out_low and in_high to out_high.
    out_span = out_high - out_low
    in_span = in_high - in_low
    pixels = (out_span * pixels + out_low*in_high - out_high*in_low) / in_span
    # Values outside the input range would land outside the output range.
    pixels = tf.clip_by_value(pixels, out_low, out_high)

    return tf.cast(pixels, dtype)
def apply_change_rate(images, images_augmented, change_rate=1.0):
    """
    This function outputs a mix of augmented images and original
    images. The argument `change_rate` is a float in the interval
    [0.0, 1.0] representing the number of changed images versus
    the total number of input images average ratio. For example,
    if `change_rate` is set to 0.25, 25% of the input images will
    get changed on average (75% won't get changed). If it is set
    to 0.0, no images are changed. If it is set to 1.0, all the
    images are changed.

    The decision is taken independently for each image of the batch
    by drawing a uniform random number in [0, 1) and comparing it
    to `change_rate`.

    Args:
        images (tf.Tensor): input batch of images
            (assumed to be 4-D with the batch on the first axis)
        images_augmented (tf.Tensor): output augmented images,
            same shape and data type as `images`
        change_rate (float): probability to apply augmentation

    Returns:
        (tf.Tensor): batch where each image is either the augmented
            image or the original one

    Raises:
        ValueError: if `change_rate` is outside the interval [0, 1].
    """
    if change_rate == 1.0:
        # Every image is changed, no random selection is needed.
        return images_augmented

    if change_rate < 0. or change_rate > 1.:
        # The two string literals are concatenated into one message.
        # Separating them with a comma would pass two arguments to
        # ValueError and the message would be rendered as a tuple.
        raise ValueError("The value of `change_rate` must be in the interval [0, 1]. "
                         "Received {}".format(change_rate))

    batch_size = tf.shape(images)[0]

    # One random draw per image; an image is changed when its draw
    # falls below the change rate.
    probs = tf.random.uniform([batch_size], minval=0, maxval=1, dtype=tf.float32)
    change = probs < change_rate

    # Give the per-image decision a shape that broadcasts over the
    # remaining three axes of the batch, instead of materializing a
    # full-size mask.
    change = tf.reshape(change, [batch_size, 1, 1, 1])

    # Selecting (rather than blending with 0/1 multiplications) keeps
    # NaN/Inf values of the unselected image from leaking into the output.
    return tf.where(change, images_augmented, images)
|