# /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2025 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
def prepare_kwargs_for_model(cfg):
    """Build the keyword arguments used to instantiate a model from a config.

    Every value is looked up leniently. A missing config section or attribute
    yields ``None`` instead of raising, so the model factory can apply its own
    default.

    Args:
        cfg: Configuration object that may expose ``model``,
            ``data_augmentation`` and ``training`` sections as attributes,
            plus an optional ``model_kwargs`` mapping of extra/override
            arguments.

    Returns:
        dict: Model keyword arguments. Entries from ``cfg.model_kwargs`` are
        merged last, so they override any computed entry with the same key.
    """
    def _get(section, key):
        # Read cfg.<section>.<key>; None if either level is absent.
        return getattr(getattr(cfg, section, None), key, None)

    # Channel count comes from the first element of model.input_shape.
    # TODO ST: HWC Torch: CHW -- index 0 is the channel axis only for a
    # channels-first (CHW) shape; confirm for every caller.
    # Without a configured (non-empty) input shape, default to 3 channels.
    input_shape = _get('model', 'input_shape')
    shape = tuple(input_shape) if input_shape is not None else ()
    in_chans = shape[0] if shape else 3

    model_kwargs = {
        'in_chans': in_chans,
        'drop_rate': _get('data_augmentation', 'drop'),
        'drop_path_rate': _get('data_augmentation', 'drop_path'),
        'drop_block_rate': _get('data_augmentation', 'drop_block'),
        'global_pool': _get('model', 'gp'),
        'bn_momentum': _get('training', 'bn_momentum'),
        'bn_eps': _get('training', 'bn_eps'),
        'scriptable': False,
        'checkpoint_path': _get('model', 'model_path'),
        # User-supplied overrides come last so they take precedence.
        **(getattr(cfg, 'model_kwargs', None) or {}),
    }
    return model_kwargs