File size: 1,286 Bytes
226675b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
from torch import nn
import sys
sys.path.append('rscd')
from utils.build import build_from_cfg

class myModel(nn.Module):
    """Two-input model that chains a configured backbone and decoder head.

    Both sub-modules are created by ``build_from_cfg`` from the ``backbone``
    and ``decoderhead`` entries of the config passed to the constructor.
    """

    def __init__(self, cfg):
        """Build the backbone and the decoder head.

        Args:
            cfg: Config object exposing ``backbone`` and ``decoderhead``
                attributes; each one is handed to ``build_from_cfg``.
        """
        super().__init__()
        self.backbone = build_from_cfg(cfg.backbone)
        self.decoderhead = build_from_cfg(cfg.decoderhead)

    def forward(self, x1, x2, gtmask=None):
        """Run the backbone on both inputs, then the decoder head.

        Args:
            x1: First input, forwarded unchanged to the backbone.
            x2: Second input, forwarded unchanged to the backbone.
            gtmask: Optional extra argument for the decoder head. When it is
                ``None`` the head is called with the backbone output only;
                otherwise it is passed as the head's second positional
                argument.

        Returns:
            Whatever the decoder head returns.
        """
        backbone_outputs = self.backbone(x1, x2)
        # Use an identity test: ``== None`` would dispatch to the mask's own
        # ``__eq__``, which for array-like objects is an element-wise
        # comparison rather than a plain "was it supplied?" check.
        if gtmask is None:
            return self.decoderhead(backbone_outputs)
        return self.decoderhead(backbone_outputs, gtmask)

"""
Models that do not fit this backbone + decoder-head pattern can be defined in the backbone part and imported here.
"""

# model_config
def build_model(cfg):
    """Create and return a ``myModel`` instance from the given config."""
    return myModel(cfg)


if __name__ == "__main__":
    # Manual smoke test: build the model and loss from a config file, run one
    # forward pass on random inputs, and print the output shape and the loss.
    batch, side = 4, 512
    img_a = torch.randn(batch, 3, side, side)
    img_b = torch.randn(batch, 3, side, side)
    # Random integer labels drawn from {0, 1}, one per pixel.
    labels = torch.randint(0, 2, (batch, side, side))
    # NOTE(review): machine-specific absolute path; adjust before running.
    config_path = r"E:\zjuse\2308CD\rschangedetection\configs\SARASNet.py"

    from utils.config import Config
    from rscd.losses import build_loss

    settings = Config.fromfile(config_path)
    model = build_model(settings.model_config)
    prediction = model(img_a, img_b)
    print(prediction.shape)

    criterion = build_loss(settings.loss_config)
    loss_value = criterion(prediction, labels)
    print(loss_value)