| import tensorflow as tf |
| from tensorflow.keras import layers, models |
| from .config import Config |
|
|
class ModelFactory:
    """Factory class to reconstruct model architectures for weight loading.

    Each builder recreates the layer graph so that previously saved weights
    can be loaded onto the returned (uncompiled) model. Keep layer creation
    order and graph topology stable: adding, removing or reordering layers
    changes the auto-generated layer names and the order in which weights
    are matched, which breaks loading of existing weight files.
    """

    @staticmethod
    def build_resnet50_classifier(*, weights='imagenet'):
        """Reconstruct the ResNet50 classifier.

        Architecture: ResNet50 backbone (no top) -> GlobalAveragePooling2D
        -> Dense(512, relu) -> Dropout(0.5)
        -> Dense(len(Config.CLASSES), softmax).

        Args:
            weights: Initial weights for the ResNet50 backbone, forwarded to
                ``tf.keras.applications.ResNet50``. Defaults to
                ``'imagenet'`` (the historical behaviour). Pass ``None`` when
                the full trained weights are loaded afterwards, to avoid
                downloading ImageNet weights that would be overwritten.

        Returns:
            An uncompiled Keras ``Model`` whose input shape is
            ``Config.INPUT_SHAPE`` and whose output is a softmax over
            ``len(Config.CLASSES)`` classes.
        """
        base_model = tf.keras.applications.ResNet50(
            weights=weights,
            include_top=False,
            input_shape=Config.INPUT_SHAPE,
        )
        x = base_model.output
        x = layers.GlobalAveragePooling2D()(x)
        x = layers.Dense(512, activation='relu')(x)
        x = layers.Dropout(0.5)(x)
        output = layers.Dense(len(Config.CLASSES), activation='softmax')(x)
        return models.Model(inputs=base_model.input, outputs=output)

    @staticmethod
    def build_cia_net(*, weights='imagenet'):
        """Reconstruct the CIA-Net segmentation model.

        A DenseNet121 encoder feeds a top-down decoder with two parallel
        branches (nuclei and contour). At each decoder stage both branches
        are upsampled x2 and projected with 1x1 convs. Each is then summed
        with a 1x1 projection of the matching encoder feature map and mixed
        by an IAM block. The IAM block concatenates both branches, applies a
        shared 3x3 conv, then one 3x3 ReLU conv per branch.

        Args:
            weights: Initial weights for the DenseNet121 encoder, forwarded to
                ``tf.keras.applications.DenseNet121``. Defaults to
                ``'imagenet'`` (the historical behaviour). Pass ``None`` when
                trained weights are loaded afterwards, to skip the download.

        Returns:
            An uncompiled Keras ``Model`` mapping an ``(H, W, 3)`` image to
            two single-channel sigmoid maps named ``nuclei_output`` and
            ``contour_output``.

        Note:
            H and W are left dynamic, but the skip-connection ``Add`` layers
            only line up when every x2 upsampling exactly undoes an encoder
            downsampling. Presumably H and W must therefore be multiples of
            32. TODO confirm.
        """

        def iam_block(nuc, con, filters):
            """Fuse two branches: concat -> shared 3x3 conv -> per-branch 3x3 ReLU conv."""
            merged = layers.Concatenate()([nuc, con])
            smooth = layers.Conv2D(filters, 3, padding='same')(merged)
            nuc_refine = layers.Conv2D(filters, 3, padding='same', activation='relu')(smooth)
            con_refine = layers.Conv2D(filters, 3, padding='same', activation='relu')(smooth)
            return nuc_refine, con_refine

        def decoder_stage(nuc, con, skip, filters):
            """Upsample and project both branches, add the projected skip, fuse with IAM."""
            # Layers are created and called in the same order as the original
            # hand-unrolled stages, so auto-generated layer names and weight
            # order are unchanged.
            nuc_up = layers.Conv2D(filters, 1, padding='same')(layers.UpSampling2D()(nuc))
            con_up = layers.Conv2D(filters, 1, padding='same')(layers.UpSampling2D()(con))
            skip_lat = layers.Conv2D(filters, 1, padding='same')(skip)
            nuc_merged = layers.Add()([nuc_up, skip_lat])
            con_merged = layers.Add()([con_up, skip_lat])
            return iam_block(nuc_merged, con_merged, filters)

        inputs = layers.Input(shape=(None, None, 3))

        base = tf.keras.applications.DenseNet121(
            include_top=False,
            weights=weights,
            input_tensor=inputs,
        )

        # Encoder taps, shallowest to deepest.
        # NOTE(review): these are the underscore-style names used by Keras 3
        # applications; tf.keras 2.x names the first tap 'conv1/relu'.
        # Confirm the target Keras version.
        enc1 = base.get_layer('conv1_relu').output
        enc2 = base.get_layer('conv2_block6_concat').output
        enc3 = base.get_layer('conv3_block12_concat').output
        enc4 = base.get_layer('conv4_block24_concat').output
        bottleneck = base.get_layer('relu').output

        # Deepest stage: a single shared trunk. The IAM block receives the
        # same tensor on both inputs and splits it into the two branches.
        x = layers.Conv2D(256, 3, padding='same', activation='relu')(bottleneck)
        x = layers.UpSampling2D()(x)
        enc4_lat = layers.Conv2D(256, 1, padding='same')(enc4)
        m4 = layers.Add()([x, enc4_lat])
        nuc, con = iam_block(m4, m4, 256)

        # Remaining stages, in the original order (enc3 -> enc2 -> enc1).
        for filters, skip in ((128, enc3), (64, enc2), (32, enc1)):
            nuc, con = decoder_stage(nuc, con, skip, filters)

        final_nuc = layers.UpSampling2D()(nuc)
        final_con = layers.UpSampling2D()(con)

        out_nuc = layers.Conv2D(1, 1, activation='sigmoid', name='nuclei_output')(final_nuc)
        out_con = layers.Conv2D(1, 1, activation='sigmoid', name='contour_output')(final_con)

        return models.Model(inputs=inputs, outputs=[out_nuc, out_con])