| | from .resnet import * |
| | import logging |
| | logger = logging.getLogger('base') |
| |
|
def create_CD_model(opt):
    """Build the change-detection model selected by ``opt['model']['name']``.

    Args:
        opt: Nested options mapping. ``opt['model']['name']`` selects the
            architecture; the other ``opt['model']`` entries read below are
            forwarded as keyword arguments to that architecture's constructor.

    Returns:
        The constructed model instance.

    Raises:
        NotImplementedError: If ``opt['model']['name']`` is not a supported
            architecture (currently only ``'STNR'``).
        KeyError: If ``opt`` lacks ``'model'`` or a required model option.
    """
    model_opt = opt['model']
    name = model_opt['name']

    if name == 'STNR':
        # Imported lazily so this dependency is only needed when STNR is chosen.
        from models.STNR import STNR as stnr

        cd_model = stnr(
            spatial_dims=model_opt['spatial_dims'],
            in_channels=model_opt['in_channels'],
            init_filters=model_opt['init_filters'],
            out_channels=model_opt['n_classes'],
            mode=model_opt['mode'],
            conv_mode=model_opt['conv_mode'],
            up_mode=model_opt['up_mode'],
            up_conv_mode=model_opt['up_conv_mode'],
            norm=model_opt['norm'],
            blocks_down=model_opt['blocks_down'],
            blocks_up=model_opt['blocks_up'],
            # NOTE: both the keyword and the config key are spelled
            # 'resdiual' (sic); do not correct one without the other.
            resdiual=model_opt['resdiual'],
            diff_abs=model_opt['diff_abs'],
            stage=model_opt['stage'],
            mamba_act=model_opt['mamba_act'],
            local_query_model=model_opt['local_query_model'],
        )
        return cd_model

    # Previously an unknown name left ``cd_model`` unbound; fail loudly instead.
    raise NotImplementedError(f"CD model [{name}] is not recognized.")
| |
|