UNet网络制作

UNet网络制作

代码参考UNet数据集制作及代码实现_哔哩哔哩_bilibili,根据该UP主的代码,加上我的个人整理和理解。(这个UP主的代码感觉很好,很规范

UNet网络由三部分组成:卷积块,下采样层,上采样层。

卷积块

UNet网络中卷积块进行了两次卷积。

class Conv_Block(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Conv_Block, self).__init__()
        self.layer = nn.Sequential(
            # padding_mode = "reflect" 增强特征提取
            nn.Conv2d(in_channel,out_channel, 3, 1, 1, padding_mode="reflect", bias = False),
            nn.BatchNorm2d(out_channel), # 二维批归一化层,归一化卷积层的输出。用于加速训练和增强模型的泛化能力。
            nn.Dropout2d(0.3), # 二维随机失活层,以概率0.3随机抑制特征,用于防止过拟合。
            nn.LeakyReLU(), # 带有负斜率的修正线性单元激活函数,引入非线性变换。

            nn.Conv2d(out_channel, out_channel, 3, 1, 1, padding_mode="reflect", bias = False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout(0.3),
            nn.LeakyReLU()
        )
    
    def forward(self, x):
        return self.layer(x) 

下采样层

UNet网络的下采样层中进行了一次卷积。下采样将图像大小减半,通道数不变,同时保留更多的重要特征。

class DownSample(nn.Module):
    def __init__(self, channel):
        super(DownSample, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(channel, channel, 3, 2, 1, padding_mode="reflect", bias = False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.layer(x)

上采样层

UNet网络的上采样层中进行了一次卷积操作和双线性插值上采样。卷积用于减低通道数,并将其与上一层的特征图进行拼接。用于恢复图像大小,同时提取更加精细的特征。

class UpSample(nn.Module):
    def __init__(self, channel):
        super(UpSample, self).__init__()
        self.layer = nn.Conv2d(channel, channel // 2, 1, 1)

    def forward(self, x, feature_map):
        up = F.interpolate(x, scale_factor=2, mode="nearest")
        out = self.layer(up)
        return torch.cat((out, feature_map), dim = 1)

网络模型

UNet网络模型由编码器和解码器两部分组成。编码器包含了四个 Conv_Block 和四个 DownSample 层,用于逐步提取图像的高级特征。解码器包含了四个 UpSample 和四个 Conv_Block 层,用于通过上采样和特征融合从编码器中恢复图像的细节。最后通过一个卷积层和 Sigmoid 激活函数得到二分类输出,用于分割图像。

class UNet(nn.Module):
    def __init__(self):
        super(UNet,self).__init__()
        self.c1 = Conv_Block(3, 64)
        self.d1 = DownSample(64)
        self.c2 = Conv_Block(64, 128)
        self.d2 = DownSample(128)
        self.c3 = Conv_Block(128, 256)
        self.d3 = DownSample(256)
        self.c4 = Conv_Block(256,512)
        self.d4 = DownSample(512)
        self.c5 = Conv_Block(512, 1024)

        self.u1 = UpSample(1024)
        self.c6 = Conv_Block(1024, 512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)
        self.out = nn.Conv2d(64,3,3,1,1)
        # 二分类
        self.Th = nn.Sigmoid()
    
    def forward(self, x):
        R1 = self.c1(x)
        R2 = self.c2(self.d1(R1))
        R3 = self.c3(self.d2(R2))
        R4 = self.c4(self.d3(R3))
        R5 = self.c5(self.d4(R4))

        O1 = self.c6(self.u1(R5, R4))
        O2 = self.c7(self.u2(O1, R3))
        O3 = self.c8(self.u3(O2, R2))
        O4 = self.c9(self.u4(O3, R1))

        return self.Th(self.out(O4))
 

测试

一致则正确。文章来源地址https://uudwc.com/A/k9y4X

if __name__ == "__main__":
    x = torch.randn(2,3,256,256)
    net=UNet()
    print(net(x).shape)

原文地址:https://blog.csdn.net/qq_63432403/article/details/133268539

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请联系站长进行投诉反馈,一经查实,立即删除!

h
上一篇 2023年09月25日 11:35
图层混合算法(一)
下一篇 2023年09月25日 11:36