VIT对原始输入图像 作 切块处理PatchEmbed

import torch
import torch.nn as nn

# 对原始输入图像 作 切块处理
class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    对2D图像作Patch Embedding操作
    """
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        """
        此函数用于初始化相关参数
        :param img_size: 输入图像的大小
        :param patch_size: 一个patch的大小
        :param in_c: 输入图像的通道数
        :param embed_dim: 输出的每个token的维度
        :param norm_layer: 指定归一化方式,默认为None
        """
        
        super().__init__()
        img_size = (img_size, img_size) # 224 -> (224, 224)
        patch_size = (patch_size, patch_size) # 16 -> (16, 16)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) # 计算原始图像被划分为(14, 14)个小块
        self.num_patches = self.grid_size[0] * self.grid_size[1] # 计算patch的个数为14*14=196个

        # 定义卷积层
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        # 定义归一化方式
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """
        此函数用于前向传播
        :param x: 原始图像
        :return: 处理后的图像
        """
        B, C, H, W = x.shape
        
        # 检查图像高宽和预先设定是否一致,不一致则报错
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        # 对图像依次作卷积、展平和调换处理: [B, C, H, W] -> [B, C, HW] -> [B, HW, C]
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        # 第2步:通过2d卷积进行线性变换 proj
        # 第3步:拉平生成线性变量 flatten
        # 第4步:块的个数 与 每块的向量维度交换位置 transpose
        x = self.proj(x).flatten(2).transpose(1, 2)
        
        # 归一化处理
        x = self.norm(x)
        return x
    
if __name__ == "__main__":
    x = torch.rand([1, 3, 224, 224])
 
    model = PatchEmbed()
    y = model(x)
    print(y.shape)

文章来源地址https://uudwc.com/A/X3ddp

原文地址:https://blog.csdn.net/liuweizj12/article/details/132225053

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请联系站长进行投诉反馈,一经查实,立即删除!

h
上一篇 2023年08月11日 10:36
下一篇 2023年08月11日 10:36