c语言编程笔录

首页 >   > 笔记大全

笔记大全

Pytorch中TensorDataset与DataLoader怎么联合使用

更新时间:2023-10-06

一、TensorDataset 与 DataLoader 简介

TensorDataset 是一个 Pytorch 中的数据集类,它是将输入数据 Tensor 和标签 Tensor 组成一个新的 Tensor, 并将其组合成一个可以迭代的数据集。而 DataLoader 则是 Pytorch 提供的一个用于数据加载和批量构建的工具类,它可以自动将数据切分成一小批一小批进行处理。

import torch
from torch.utils.data import TensorDataset, DataLoader

# 构造 TensorDataset
inputs = torch.tensor([[1, 2], [3, 4], [5, 6]])
labels = torch.tensor([0, 1, 2])
dataset = TensorDataset(inputs, labels)

# 构造 DataLoader
dataloader = DataLoader(dataset)

二、TensorDataset 与 DataLoader 的传参方式

在 TensorDataset 与 DataLoader 的传参方式中,除了传入张量数据和标签数据之外,还可以传入一些其他参数,比如 batch_size、shuffle 等。其中 batch_size 表示每个 batch 中的样本数量,shuffle 表示在加载前是否将样本随机打乱。

# 构造 DataLoader
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

三、TensorDataset 与 DataLoader 的迭代方式

TensorDataset 与 DataLoader 的迭代方式可以使用 for 循环进行数据迭代,也可以使用 iter 和 next 进行数据迭代。

# 使用 for 进行迭代
for data in dataloader:
    inputs, labels = data
    print(inputs, labels)

# 使用 iter 和 next 进行迭代
dataiter = iter(dataloader)
inputs, labels = next(dataiter)
print(inputs, labels)

四、TensorDataset 与 DataLoader 联合使用示例

下面是 Pytorch 中使用 TensorDataset 和 DataLoader 联合进行数据加载、模型训练和测试的示例代码。

import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn

# 构造数据集
X_train = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32)
y_train = torch.tensor([0, 1, 2], dtype=torch.int64)
dataset = TensorDataset(X_train, y_train)

# 构造 DataLoader
dataloader = DataLoader(dataset, batch_size=2, shuffle=False)

# 构造模型
model = nn.Linear(2, 3)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 模型训练
for epoch in range(10):
    for data in dataloader:
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# 模型测试
X_test = torch.tensor([[7, 8], [9, 10]], dtype=torch.float32)
y_test = torch.tensor([2, 2], dtype=torch.int64)
test_dataset = TensorDataset(X_test, y_test)
test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=False)
with torch.no_grad():
    for data in test_dataloader:
        inputs, labels = data
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        print(predicted)

总结

本文介绍了 Pytorch 中 TensorDataset 与 DataLoader 的使用方法,包括简介、传参方式、迭代方式和联合使用示例。在实践中,可以通过构造具有 TensorDataset 形式的数据集,并使用 DataLoader 对其进行批量加载,来进行模型的训练和测试。