Dataset类

用来创建自己的数据集,提供一种方式去获取数据及其label。

1.如何获取每一个数据及其label;2.告诉我们总共有多少数据

help:所有的数据集都需要继承该类,所有的子类都应该重写__getitem__方法(获取每一个数据及其label),选择性重写__len__类(返回数据集的大小)

(b站土堆蚂蚁和蜜蜂案例数据集下载:https://download.pytorch.org/tutorial/hymenoptera_data.zip)

创建自己的数据集:1.新建数据类继承Dataset类;2.重写方法;3.实例化使用(注意文件的路径修改为自己的路径)文章来源地址https://uudwc.com/A/OqLaz

from torch.utils.data import Dataset
from PIL import Image
import os

class Mydata(Dataset):

   def __init__(self, root_dir, label_dir):
       self.root_dir = root_dir
       self.label_dir = label_dir
       #路径相加
       self.path = os.path.join(self.root_dir, self.label_dir)
       #左侧为list类型,右侧函数是将该文件夹下的所有文件变成一个列表,保存的是文件名
       self.img_path = os.listdir(self.path)

   def __getitem__(self, idx):
       img_name = self.img_path[idx]
       #获取图片的路径
       img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
       #获取数据集中的图片
       img = Image.open(img_item_path)
       #获取图片对应的标签(此处标签为父目录的文件名)
       label = self.label_dir
       return img, label

   def __len__(self):
       return len(self.img_path)


root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
#实例化
ants_dataset = Mydata(root_dir, ants_label_dir)
bees_dataset = Mydata(root_dir, bees_label_dir)

train_dataset = ants_dataset + bees_dataset
#获取数据集中的图片和标签
img, label = train_dataset[0]
 

原文地址:https://www.cnblogs.com/yq-ydky/archive/2023/08/09/17616581.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请联系站长进行投诉反馈,一经查实,立即删除!

上一篇 2023年09月17日 23:14
GPT-4助力数据分析:提升效率与洞察力的未来关键技术
下一篇 2023年09月17日 23:18