PyTorch - Dataset 迭代数据接口 __getitem__ 异常处理

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/133378772

Dataset

在模型训练的过程中,加载数据部分,极其容易出现异常,以及不可控的因素,需要通过异常捕获的方式,及时处理,常用方式就是使用 collate_fn,除此之外,还可以直接跳过错误样本,运行下一个样本进行补充。

PyTorch Dataset 类是一个抽象类,用于表示一个数据集,可以将数据和标签封装成一个可迭代的对象。要使用 Dataset 类,我们需要继承它,并实现两个方法:

  • __getitem__(self, index):根据给定的索引,返回数据集中的一个样本和对应的标签。
  • __len__(self):返回数据集中的样本数量。

即:

  1. 将数据获取封装成单独函数。
  2. 使用 while True 持续监控,如果运行正确,即 break 跳过。
  3. 如果运行失败,则打印日志,选择下一个样本运行,即 idx += 1
  4. 注意,索引不要溢出。

源码如下:文章来源地址https://uudwc.com/A/Ev0wB

    def __getitem__(self, idx):
        # TODO: 解决数据异常问题,KeyError,尽量保持数据干净
        while True:
            try:
                feats = self.getitem_wrapper(idx)
                break
            except Exception as e:
                name = self.idx_to_chain_id(idx)
                logger.error(f"err sample: {name} !!!")
                idx += 1
                idx = idx % len(self._chain_ids)  # 避免溢出
        return feats

原文地址:https://blog.csdn.net/u012515223/article/details/133378772

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处: 如若内容造成侵权/违法违规/事实不符,请联系站长进行投诉反馈,一经查实,立即删除!

上一篇 2023年10月24日 03:04
【考研数学】高等数学第七模块 —— 曲线积分与曲面积分 | 3. 对面积的曲面积分(第一类曲面积分)
下一篇 2023年10月24日 05:04