提升PyTorch DataLoader效率:避免重复实例化
在PyTorch深度学习训练中,高效的数据加载至关重要。 反复创建DataLoader实例会导致进程池的重复创建和销毁,严重影响训练速度。本文介绍如何复用DataLoader,避免这种低效的重复实例化操作。
问题:许多代码在每次迭代中都重新创建DataLoader:DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)。 这会造成性能瓶颈,因为DataLoader初始化需要创建进程池,频繁地创建和销毁进程池会消耗大量资源。
解决方案:将DataLoader的创建移至训练循环之外。 只需在训练开始前创建一次DataLoader实例,并在训练循环中重复使用它即可。 以下代码演示了改进后的方法:
import torch from torch.utils.data import DataLoader, Dataset from math import sqrt from typing import List, Tuple, Union from numpy import ndarray from PIL import Image from torchvision import transforms preprocess = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) ]) class PreprocessImageDataset(Dataset): def __init__(self, images: Union[List[ndarray], Tuple[ndarray]]): self.images = images def __len__(self): return len(self.images) def __getitem__(self, idx): image = self.images[idx] image = Image.fromarray(image) preprocessed_image: torch.Tensor = preprocess(image) unsqueezed_image = preprocessed_image return unsqueezed_image if __name__=='__main__': data = list(range(10000000)) batch_size = 10 num_workers = 16 dataset = PreprocessImageDataset(data) dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) for epoch in range(5): print(f"Epoch {epoch + 1}:") for batch_data in dataloader: batch_data print("Batch data:", batch_data) print("Batch data type :", type(batch_data)) print("Batch data shape:", batch_data.shape)
通过将DataLoader的实例化放在循环外,并在多个epoch中复用同一个实例,我们避免了重复创建进程池,显著提高了数据加载效率,减少了系统开销,从而提升了训练性能。
以上就是PyTorch DataLoader 如何避免重复实例化以提升训练效率?的详细内容,更多请关注知识资源分享宝库其它相关文章!
发表评论:
◎欢迎参与讨论,请在这里发表您的看法、交流您的观点。