Pytorch中TensorDataset与DataLoader怎么联合使用
更新时间:2023-10-06一、TensorDataset 与 DataLoader 简介
TensorDataset 是一个 Pytorch 中的数据集类,它是将输入数据 Tensor 和标签 Tensor 组成一个新的 Tensor, 并将其组合成一个可以迭代的数据集。而 DataLoader 则是 Pytorch 提供的一个用于数据加载和批量构建的工具类,它可以自动将数据切分成一小批一小批进行处理。
import torch from torch.utils.data import TensorDataset, DataLoader # 构造 TensorDataset inputs = torch.tensor([[1, 2], [3, 4], [5, 6]]) labels = torch.tensor([0, 1, 2]) dataset = TensorDataset(inputs, labels) # 构造 DataLoader dataloader = DataLoader(dataset)
二、TensorDataset 与 DataLoader 的传参方式
在 TensorDataset 与 DataLoader 的传参方式中,除了传入张量数据和标签数据之外,还可以传入一些其他参数,比如 batch_size、shuffle 等。其中 batch_size 表示每个 batch 中的样本数量,shuffle 表示在加载前是否将样本随机打乱。
# 构造 DataLoader dataloader = DataLoader(dataset, batch_size=2, shuffle=False)
三、TensorDataset 与 DataLoader 的迭代方式
TensorDataset 与 DataLoader 的迭代方式可以使用 for 循环进行数据迭代,也可以使用 iter 和 next 进行数据迭代。
# 使用 for 进行迭代 for data in dataloader: inputs, labels = data print(inputs, labels) # 使用 iter 和 next 进行迭代 dataiter = iter(dataloader) inputs, labels = next(dataiter) print(inputs, labels)
四、TensorDataset 与 DataLoader 联合使用示例
下面是 Pytorch 中使用 TensorDataset 和 DataLoader 联合进行数据加载、模型训练和测试的示例代码。
import torch from torch.utils.data import TensorDataset, DataLoader import torch.nn as nn # 构造数据集 X_train = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32) y_train = torch.tensor([0, 1, 2], dtype=torch.int64) dataset = TensorDataset(X_train, y_train) # 构造 DataLoader dataloader = DataLoader(dataset, batch_size=2, shuffle=False) # 构造模型 model = nn.Linear(2, 3) # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.01) # 模型训练 for epoch in range(10): for data in dataloader: inputs, labels = data optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() # 模型测试 X_test = torch.tensor([[7, 8], [9, 10]], dtype=torch.float32) y_test = torch.tensor([2, 2], dtype=torch.int64) test_dataset = TensorDataset(X_test, y_test) test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False) with torch.no_grad(): for data in test_dataloader: inputs, labels = data outputs = model(inputs) _, predicted = torch.max(outputs.data, 1) print(predicted)
总结
本文介绍了 Pytorch 中 TensorDataset 与 DataLoader 的使用方法,包括简介、传参方式、迭代方式和联合使用示例。在实践中,可以通过构造具有 TensorDataset 形式的数据集,并使用 DataLoader 对其进行批量加载,来进行模型的训练和测试。