# FBAGSTM's picture
# STM32 AI Experimentation Hub
# 747451d
# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2022-2023 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import tensorflow as tf
def grayscale_not_supported(images, function_name=None):
    """
    Asserts that the images of a batch have exactly 3 channels.

    The channel count is read from the last dimension of the dynamic shape
    of `images` and compared to 3 with `tf.debugging.assert_equal`. Images
    with any other channel count (e.g. 1-channel grayscale) make the
    assertion fail with a message naming the calling function.

    Args:
        images (tf.Tensor): batch of images, channels in the last dimension
        function_name (str): name of the augmentation function, used in the
            error message
    """
    num_channels = tf.shape(images)[-1]
    error_message = f"\nFunction `{function_name}`: grayscale images are not supported."
    tf.debugging.assert_equal(num_channels, tf.constant(3), message=error_message)
def check_dataaug_argument(arg, arg_name, function_name=None, data_type=None, tuples=1):
    """
    Checks the value and data type of an argument of a data augmentation function.

    Argument values may be integers, floats, or tuples/lists of length 2
    whose right value is strictly greater than the left value.

    Args:
        arg:
            The argument to check.
        arg_name:
            A string, the name of the argument (used in error messages).
        function_name:
            A string, the name of the function calling check_dataaug_argument()
            (used in error messages).
        data_type:
            Specifies the expected data type of the argument value(s). It may
            be int, float, (int, float) or (float, int). If None, the data
            type of the values is not checked.
        tuples:
            An integer that specifies how tuples should be handled:
                0: tuples are not accepted, scalars only.
                1: both tuples and scalars are accepted.
                2: scalars are not accepted, tuples only.
            By default, both scalars and tuples are accepted.

    Raises:
        ValueError: if `tuples` is not in {0, 1, 2}; if `arg` is None or is not
            an int, float, tuple or list; if a tuple/scalar is given where it is
            not allowed; if a tuple does not have exactly 2 elements; if its
            values do not have the expected data type; or if its right value is
            not greater than its left value.
    """
    def _check_data_type(value):
        # Exact type comparisons (not isinstance) are intentional: a bool is
        # rejected where an int is expected, and an int where a float is expected.
        value_type = type(value)
        if data_type == int and value_type != int:
            raise ValueError("\nArgument `{}` of function `{}`: expecting integer values. "
                             "Received {}".format(arg_name, function_name, value))
        if data_type == float and value_type != float:
            raise ValueError("\nArgument `{}` of function `{}`: expecting float values. "
                             "Received {}".format(arg_name, function_name, value))
        if data_type == (int, float) or data_type == (float, int):
            if value_type != int and value_type != float:
                raise ValueError("\nArgument `{}` of function `{}`: expecting float or integer values. "
                                 "Received {}".format(arg_name, function_name, value))

    if tuples not in {0, 1, 2}:
        # The function itself is used incorrectly. The braces are doubled so that
        # str.format() emits them literally instead of parsing them as a field.
        raise ValueError("\nFunction `check_dataaug_argument()` called for argument `{}` of "
                         "function `{}`: `tuples` must be an integer value in {{0, 1, 2}}. "
                         "Received {}".format(arg_name, function_name, tuples))

    if arg is None:
        # The argument is not set.
        raise ValueError("\nFunction `{}`: the argument `{}` is not set. "
                         "Received None.".format(function_name, arg_name))

    if not isinstance(arg, (int, float, tuple, list)):
        raise ValueError("\nArgument `{}` of function `{}`: invalid data type. "
                         "Received {}".format(arg_name, function_name, arg))

    is_sequence = isinstance(arg, (tuple, list))

    if tuples == 0 and is_sequence:
        # Tuples are not accepted, scalars only.
        raise ValueError("\nArgument `{}` of function `{}`: tuples are not supported. "
                         "Received {}".format(arg_name, function_name, arg))

    if tuples == 2 and not is_sequence:
        # Only tuples are accepted, no scalars.
        raise ValueError("\nArgument `{}` of function `{}`: expecting a tuple of length 2. "
                         "Received {}".format(arg_name, function_name, arg))

    if not is_sequence:
        _check_data_type(arg)
        return

    if len(arg) != 2:
        raise ValueError("\nArgument `{}` of function `{}`: the tuple should have 2 elements. "
                         "Received {}".format(arg_name, function_name, arg))

    # Check the element types before comparing them, so that mistyped elements
    # get a clear data-type error instead of a TypeError raised by `<=`.
    _check_data_type(arg[0])
    _check_data_type(arg[1])

    if arg[1] <= arg[0]:
        raise ValueError("\nArgument `{}` of function `{}`: the tuple right value should be greater "
                         "than the left value. Received {}".format(arg_name, function_name, arg))
def remap_pixel_values_range(images, input_range, output_range, dtype=tf.float32):
    """
    Linearly remaps pixel values from `input_range` to `output_range` and
    casts the result to `dtype`.

    The linear map sends input_range[0] to output_range[0] and
    input_range[1] to output_range[1]. The remapped values are then clipped
    to output_range. When both ranges compare equal, pixel values are left
    untouched and only the cast to `dtype` is applied.

    Example: uint8 pixels in [0, 255] can be remapped to float pixels in
    [0.0, 1.0] with input_range=(0, 255) and output_range=(0.0, 1.0).

    Args:
        images (tf.Tensor): batch of input images
        input_range (tuple): (min, max) pixel range at input
        output_range (tuple): (min, max) pixel range at output
        dtype: data type of the output images

    Returns:
        (tf.Tensor): batch of remapped images cast to `dtype`
    """
    if input_range == output_range:
        return tf.cast(images, dtype)

    in_low, in_high = input_range
    out_low, out_high = output_range

    pixels = tf.cast(images, tf.float32)
    # Equivalent to out_low + (x - in_low) * (out_high - out_low) / (in_high - in_low),
    # written with a single division.
    pixels = ((out_high - out_low) * pixels + out_low * in_high - out_high * in_low) / (in_high - in_low)
    pixels = tf.clip_by_value(pixels, out_low, out_high)
    return tf.cast(pixels, dtype)
def apply_change_rate(images, images_augmented, change_rate=1.0):
    """
    Mixes augmented images and original images on a per-image basis.

    Each image of the batch is independently replaced by its augmented
    version with probability `change_rate`; otherwise the original image is
    kept. For example, if `change_rate` is set to 0.25, 25% of the input
    images get changed on average (75% don't). If it is set to 0.0, no
    images are changed. If it is set to 1.0, `images_augmented` is returned
    as is.

    Args:
        images (tf.Tensor): input batch of images, batch dimension first
        images_augmented (tf.Tensor): augmented batch of images, expected to
            have the same shape and dtype as `images`
        change_rate (float): probability to apply augmentation to each image,
            in the interval [0.0, 1.0]

    Returns:
        (tf.Tensor): batch in which each image is either the original image
            or its augmented version

    Raises:
        ValueError: if `change_rate` is outside the interval [0.0, 1.0].
    """
    if change_rate == 1.0:
        return images_augmented
    if change_rate < 0. or change_rate > 1.:
        raise ValueError("The value of `change_rate` must be in the interval [0, 1]. "
                         "Received {}".format(change_rate))

    images_shape = tf.shape(images)
    batch_size = images_shape[0]

    # One uniform draw per image decides whether that image gets changed.
    probs = tf.random.uniform([batch_size], minval=0, maxval=1, dtype=tf.float32)
    change = probs < change_rate

    # Reshape the per-image decision to [batch_size, 1, ..., 1] so that it
    # broadcasts over all the non-batch dimensions of the images, without
    # materializing a full-size mask.
    mask_shape = tf.concat([[batch_size], tf.ones_like(images_shape[1:])], axis=0)
    mask = tf.reshape(change, mask_shape)

    # tf.where selects values instead of blending them arithmetically, so
    # NaN/Inf pixels of the non-selected batch cannot leak into the output.
    return tf.where(mask, images_augmented, images)