import torch
import torch.nn as nn
# 对原始输入图像 作 切块处理
class PatchEmbed(nn.Module):
"""
2D Image to Patch Embedding
对2D图像作Patch Embedding操作
"""
def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
"""
此函数用于初始化相关参数
:param img_size: 输入图像的大小
:param patch_size: 一个patch的大小
:param in_c: 输入图像的通道数
:param embed_dim: 输出的每个token的维度
:param norm_layer: 指定归一化方式,默认为None
"""
super().__init__()
img_size = (img_size, img_size) # 224 -> (224, 224)
patch_size = (patch_size, patch_size) # 16 -> (16, 16)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) # 计算原始图像被划分为(14, 14)个小块
self.num_patches = self.grid_size[0] * self.grid_size[1] # 计算patch的个数为14*14=196个
# 定义卷积层
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
# 定义归一化方式
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
"""
此函数用于前向传播
:param x: 原始图像
:return: 处理后的图像
"""
B, C, H, W = x.shape
# 检查图像高宽和预先设定是否一致,不一致则报错
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
# 对图像依次作卷积、展平和调换处理: [B, C, H, W] -> [B, C, HW] -> [B, HW, C]
# flatten: [B, C, H, W] -> [B, C, HW]
# transpose: [B, C, HW] -> [B, HW, C]
# 第2步:通过2d卷积进行线性变换 proj
# 第3步:拉平生成线性变量 flatten
# 第4步:块的个数 与 每块的向量维度交换位置 transpose
x = self.proj(x).flatten(2).transpose(1, 2)
# 归一化处理
x = self.norm(x)
return x
if __name__ == "__main__":
x = torch.rand([1, 3, 224, 224])
model = PatchEmbed()
y = model(x)
print(y.shape)
文章来源地址https://uudwc.com/A/X3ddp
文章来源:https://uudwc.com/A/X3ddp