PyTorch分布式数据加载学习 DistributedSampler

[源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler - 罗西的思考 - 博客园

初始化

class DistributedSampler(Sampler[T_co]):
    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # 向上取整
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

如果不drop_last,那就

self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)

向上取整。

迭代

在迭代的时候,如果不能整除,那就把indices的前几个样本复制一遍:

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        if not self.drop_last:  # 一般进入这里,不会丢掉剩下的训练数据
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]   # 把indices的前几个复制一次
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

最关键的还是这行:

indices = indices[self.rank:self.total_size:self.num_replicas]

规定了每个rank的取数据的索引,起始索引是rank,每间隔num_replicas取一个

shuffle数据集

每次epoch都会shuffle数据集,但是不同进程如何保持shuffle之后数据集一致性

DistributedSampler 使用当前的epoch作为随机数种子,在计算index之前就进行配置,从而保证不同进程都使用同样的随机数种子,这样shuffle出来的数据就能确保一致。

sampler = DistributedSampler(dataset) if is_distributed else None
loader = DataLoader(dataset, shuffle=(sampler is None), ...,
                            sampler=sampler)
	for epoch in range(start_epoch, n_epochs):
    	if is_distributed:
        	sampler.set_epoch(epoch) # 这设置epoch
        train(loader)

设置 random 种子的具体使用是在迭代函数之中:

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch) # 这里设置随机种子
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

在 PyTorch 中,torch.randperm(n) 函数用于生成一个从 0n-1 的随机排列的整数序列。这个函数是非常有用的,尤其是在需要随机打乱数据或索引时,比如在训练机器学习模型时打乱数据顺序,以确保模型训练的泛化能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/887062.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu24.04远程开机

近来在几台机器上鼓捣linux桌面&#xff0c;顺便研究一下远程唤醒主机。 本篇介绍Ubuntu系统的远程唤醒&#xff0c;Windows系统的唤醒可搜索相关资料。 依赖 有远程唤醒功能的路由器&#xff08;当前一般都带这个功能&#xff09;有线连接主机&#xff08;无线连接有兴趣朋友…

推荐:五种限流(Rate Limiting)算法

推荐&#xff1a;五种限流(Rate Limiting)算法&#xff0c;发现一个不错的讲这个算法的UP,地址是&#xff1a;05~五种限流(Rate Limiting)算法_哔哩哔哩_bilibili https://www.bilibili.com/video/BV11k4SerE74/ 全部用动画展示&#xff0c;十分生动&#xff0c;比如漏桶算法&…

芝法酱学习笔记(0.5)——使用jenkins做自动打包

前言 上节讲了SpringBoot上的打包。但这些过程都是手动的&#xff0c;在实际的开发测试时&#xff0c;自动化的打包部署&#xff0c;可以大大提升团队开发的效率 一、去官网下载 1.1 官网安装命令 对于如何安装的问题&#xff0c;我向来推荐官网 wget -O /usr/share/keyri…

针对考研的C语言学习(定制化快速掌握重点4)

typedef的使用 简化变量类型 逻辑结构 集合结构&#xff1a;无关系 线性结构&#xff1a;一对一 树形结构&#xff1a;一对多 图形结构&#xff1a;多对多 存储结构 顺序存储和链式存储&#xff08;考代码&#xff09; 顺序存储优点&#xff1a;1.可以实现随机存取。2.…

C语言 | Leetcode C语言题解之题451题根据字符出现频率排序

题目&#xff1a; 题解&#xff1a; #define HASH_FIND_CHAR(head, findint, out) HASH_FIND(hh, head, findint, sizeof(char), out) #define HASH_ADD_CHAR(head, intfield, add) HASH_ADD(hh, head, intfield, sizeof(char), add)struct HashTable {char key;int val;UT_ha…

今日凌晨,ChatGPT重磅更新!—— 我心目中的终极AGI界面

今日凌晨&#xff0c;ChatGPT重磅更新&#xff01;—— 我心目中的终极AGI界面 我心目中的终极 AGI 界面是一张空白画布&#xff08;canvas&#xff09;。 今日凌晨&#xff0c;OpenAI 发布 canvas&#xff0c;一个与 ChatGPT 合作写作和编程的新界面&#xff01; canvas&…

C语言复习概要(二)

本文目录 C语言中的数组与函数详解1. 引言2. 数组2.1. 什么是数组&#xff1f;语法&#xff1a;示例&#xff1a; 2.2. 数组的初始化示例 1&#xff1a;在声明时初始化示例 2&#xff1a;部分初始化示例 3&#xff1a;运行时赋值 2.3. 数组的访问与修改示例&#xff1a; 2.4. 多…

Docker启动 Redis提示:Can‘t initialize Background Jobg

问题说明: 在使用docker启动redis失败&#xff0c;但是查看容器日志&#xff0c;除了提示 Fatal:Cant initialize Background Jobg&#xff0c;没有其他错误信息。经过长时间查找资料及试错&#xff0c;现记录下可能的产生原因及解决方案&#xff0c;以便以后参考。 产生原因&…

【从零开始实现stm32无刷电机FOC】【实践】【7.1/7 硬件设计】

目录 stm32电路磁编码器电路电机驱动电路电流采样电路电机选择本文示例硬件说明 为了承载和验证本文的FOC代码工程&#xff0c;本节设计了一个简易的三相无刷电机 硬件套件&#xff0c;主控采用非常常用的stm32f103c8t6单片机&#xff0c;电机编码器采用MT6701&#xff0c;电机…

mysql怎么修改一个字段中的所有部分数据

UPDATE videos SET VideoCode replace(VideoCode,flv,mp4); update 表名 set 字段名 replace&#xff08;字段名&#xff0c;‘修改前’&#xff0c;‘修改后’&#xff09;&#xff1b;

【工欲善其事】巧用 Sublime Text 生成带格式的 HTML 片段

文章目录 【工欲善其事】巧用 Sublime Text 生成带格式的 HTML 片段1 问题由来2 操作流程步骤1&#xff1a;打开代码片段定制页步骤2&#xff1a;在新标签页输入定制 XML步骤3&#xff1a;保存定义内容步骤4&#xff1a;功能测试 3 拓展 【工欲善其事】巧用 Sublime Text 生成带…

Elasticsearch使用Easy-Es + RestHighLevelClient实现深度分页跳页

注意&#xff01;&#xff01;&#xff01;博主只在测试环境试了一下&#xff0c;没有发到生产环境跑。因为代码还没写完客户说不用弄了( •̩̩̩̩&#xff3f;•̩̩̩̩ ) 也好&#xff0c;少个功能少点BUG 使用from size的时候发现存在max_result_window10000的限制&…

如何使用工具删除 iPhone 上的图片背景

在 iPhone 上删除背景图像变得简单易行。感谢最近 iOS 更新中引入的新功能。如今&#xff0c;iOS 用户现在可以毫不费力地删除背景&#xff0c;而无需复杂的应用程序。在这篇文章中&#xff0c;您将学习如何使用各种方法去除 iPhone 上的背景。这可确保您可以选择最适合您偏好的…

自动驾驶核心技术:感知融合、规划决策、控制执行

1、前言 简单来说&#xff0c;实现自动驾驶需要解决三个核心问题&#xff1a;“我在哪?我要去哪?我该如何去?”能完整解决这三个问题就是真正的自动驾驶。 目前&#xff0c;自动驾驶汽车关键技术主要包括环境感知、精准定位、决策与规划、控制与执行、高精地图与车联网V2X以…

Linux下的IO模型

阻塞与非阻塞IO&#xff08;Input/Output&#xff09; 阻塞与非阻塞IO&#xff08;Input/Output&#xff09;是计算机操作系统中两种不同的文件或网络通信方式。它们的主要区别在于程序在等待IO操作完成时的行为。 阻塞IO&#xff08;Blocking IO&#xff09; 在阻塞IO模式下…

无IDEA不Java:快速掌握Java集成开发环境

IntelliJ IDEA是一种强大的Java集成开发环境&#xff0c;是Java开发人员的首选工具之一。本文将介绍IDEA的基本使用方法和常用功能&#xff0c;以帮助初学者快速上手。 安装和配置 首先&#xff0c;需要下载并安装IntelliJ IDEA。在安装完成后&#xff0c;需要配置JDK&#xff…

pygame--超级马里奥(万字详细版)

超级马里奥点我下载https://github.com/marblexu/PythonSuperMario 1.游戏介绍 小时候的经典游戏&#xff0c;代码参考了github上的项目Mario-Level-1&#xff0c;使用pygame来实现&#xff0c;从中学习到了横版过关游戏实现中的一些处理方法。原项目实现了超级玛丽的第一个小…

稀缺森林火险等级预测算法,基于xgboost方法的火险等级预测,共划分5级,依据当前地区月份,降水量,风力等参数进行预测,并提供15000字的报告

森林火险等级预测算法&#xff0c;基于xgboost方法的火险等级预测&#xff0c;共划分5级&#xff0c;依据当前地区月份&#xff0c;降水量&#xff0c;风力等参数进行预测&#xff0c;并提供15000字的报告 森林火险等级预测算法介绍 项目名称 基于XGBoost的森林火险等级预测算…

无环SLAM系统集成后端回环检测模块(loop):SC-A-LOAM以及FAST_LIO_SLAM

最近在研究SLAM目标检测相关知识&#xff0c;看到一篇论文&#xff0c;集成了SC-A-LOAM作为后端回环检测模块&#xff0c;在学习了论文相关内容后决定看一下代码知识&#xff0c;随后将其移植&#xff0c;学习过程中发现我找的论文已经集成了回环检测模块&#xff0c;但是我的另…

mybatis-plus使用总结

基本使用 mybatis-plus依赖 <!-- mybatis-plus开始 --><dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><version>3.5.7</version></dependency><dependency>&l…