用来创建自己的数据集,提供一种方式去获取数据及其label。
1.如何获取每一个数据及其label;2.告诉我们总共有多少数据
help:所有的数据集都需要继承该类,所有的子类都应该重写__getitem__
方法(获取每一个数据及其label),选择性重写__len__
类(返回数据集的大小)
(b站土堆蚂蚁和蜜蜂案例数据集下载:https://download.pytorch.org/tutorial/hymenoptera_data.zip)文章来源:https://uudwc.com/A/OqLaz
创建自己的数据集:1.新建数据类继承Dataset类;2.重写方法;3.实例化使用(注意文件的路径修改为自己的路径)文章来源地址https://uudwc.com/A/OqLaz
from torch.utils.data import Dataset
from PIL import Image
import os
class Mydata(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
#路径相加
self.path = os.path.join(self.root_dir, self.label_dir)
#左侧为list类型,右侧函数是将该文件夹下的所有文件变成一个列表,保存的是文件名
self.img_path = os.listdir(self.path)
def __getitem__(self, idx):
img_name = self.img_path[idx]
#获取图片的路径
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
#获取数据集中的图片
img = Image.open(img_item_path)
#获取图片对应的标签(此处标签为父目录的文件名)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
#实例化
ants_dataset = Mydata(root_dir, ants_label_dir)
bees_dataset = Mydata(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset
#获取数据集中的图片和标签
img, label = train_dataset[0]