在 PyTorch 中,有几种常见的方法可以导入数据集:
使用torchvision.datasets 模块导入常见的计算机视觉数据集,例如 CIFAR10、MNIST 等。可以使用 torchvision.datasets.CIFAR10、torchvision.datasets.MNIST 等类来实例化数据集对象。示例代码:
import torchvision.datasets as datasets# 导入 CIFAR10 数据集train_dataset = datasets.CIFAR10(root='path/to/dataset', train=True, transform=None, download=True)# 导入 MNIST 数据集test_dataset = datasets.MNIST(root='path/to/dataset', train=False, transform=None, download=True)自定义数据集类,继承自 torch.utils.data.Dataset,并实现 __len__ 和 __getitem__ 方法。这允许您以自定义方式加载和处理数据集。示例代码:
from torch.utils.data import Datasetclass MyDataset(Dataset): def __init__(self, data, labels, transform=None): self.data = data self.labels = labels self.transform = transform def __len__(self): return len(self.data) def __getitem__(self, index): x = self.data[index] y = self.labels[index] if self.transform: x = self.transform(x) return x, y# 使用自定义数据集my_dataset = MyDataset(data, labels, transform=None)使用 torch.utils.data.DataLoader 类将数据集包装成可迭代的数据加载器。数据加载器可以用于批量加载数据、多线程加载数据等。示例代码:
from torch.utils.data import DataLoader# 创建数据加载器train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)以上是 PyTorch 中导入数据集的几种常见方法。具体的选择取决于数据集的类型和需求。

