File size: 1,409 Bytes
747451d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# /*---------------------------------------------------------------------------------------------
#  * Copyright (c) 2025 STMicroelectronics.
#  * All rights reserved.
#  *
#  * This software is licensed under terms that can be found in the LICENSE file in
#  * the root directory of this software component.
#  * If no LICENSE file comes with this software, it is provided AS-IS.
#  *--------------------------------------------------------------------------------------------*/

def prepare_kwargs_for_model(cfg):
    in_chans = 3
    #if getattr(cfg, 'input_size', None) is not None:
    in_chans = tuple(cfg.model.input_shape)[0] # TODO ST: HWC Torch: CHW

    model_kwargs = {
        'in_chans': in_chans,
        'drop_rate': getattr(getattr(cfg, 'data_augmentation', None), 'drop', None),
        'drop_path_rate': getattr(getattr(cfg, 'data_augmentation', None), 'drop_path', None),
        'drop_block_rate': getattr(getattr(cfg, 'data_augmentation', None), 'drop_block', None),

        'global_pool': getattr(getattr(cfg, 'model', None), 'gp', None),

        'bn_momentum': getattr(getattr(cfg, 'training', None), 'bn_momentum', None),
        'bn_eps': getattr(getattr(cfg, 'training', None), 'bn_eps', None),
        'scriptable': False,
        'checkpoint_path': getattr(cfg.model, "model_path", None),
        **(getattr(cfg, 'model_kwargs', {}) or {}),
    }
    return model_kwargs