c语言编程笔录

首页 >   > 笔记大全

笔记大全

一文告诉你为选择pytorch

更新时间:2023-09-08

什么是PyTorch

PyTorch 是由 Facebook 发布的深度学习框架,是 Torch 的 Python 版本。PyTorch 的使用和 Python 语言具有非常高的相似度,因此对于 Python 开发者而言非常友好。PyTorch 不仅支持 Tensor 张量和动态计算图,还可以更好的实现多种神经网络算法,这让它成为了许多初学者和专业人士的首选。

import torch

# 创建张量
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])

# 张量加法
z = x + y
print(z)    # tensor([5, 7, 9])

动态图与静态图的比较

PyTorch 采用动态图,TensorFlow 采用静态图,二者的主要区别在于:
1. 编程方式:PyTorch 使用动态定义计算图,需要进行反向传播时动态执行,而 TensorFlow 先定义计算图,再根据计算图生成静态计算图,最后执行。
2. 运行方式:PyTorch 每次迭代计算时会重新构建计算图,相对比较灵活,能够快速调试代码,而 TensorFlow 先构建计算图再执行,可以优化运行效率。

import torch

# 动态图
x = torch.randn(3, 5)
y = torch.randn(5, 4)
z = torch.matmul(x, y)
z.backward(torch.randn(3, 4))

# 静态图
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 5])
y = tf.placeholder(tf.float32, shape=[5, 4])
z = tf.matmul(x, y)
loss = tf.reduce_mean(tf.square(z - 2))
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

PyTorch 的优点

PyTorch 具有如下优点:
1. 简单易用:PyToch 的代码简洁易懂,特别适合初学者使用。
2. 动态计算图:PyTorch 采用动态图的方式进行计算,使用和调试都相对方便。
3. 广泛的应用领域:PyTorch 在自然语言处理、计算机视觉、语音识别和强化学习等领域都有广泛的应用场景。
4. 强大的社区支持:由 Facebook 支持,拥有庞大的社区和优秀的生态系统,持续迭代更新。

import torch

# 自定义神经网络模型
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)
        
    def forward(self, x):
        x = torch.relu(self.hidden(x))
        x = self.predict(x)
        return x

# 实例化模型
net = Net(2, 10, 2)
print(net)

PyTorch 的应用场景

PyTorch 的应用场景非常丰富,特别在以下领域有较为广泛的应用:
1. 自然语言处理:PyTorch 在机器翻译、文本分类、命名实体识别等领域表现突出;
2. 计算机视觉:PyTorch 在图像分类、目标检测、图像分割等方面有很好的表现;
3. 语音识别:PyTorch 在声学模型、语言模型等领域有很好的应用;
4. 强化学习:PyTorch 也广泛用于构建和训练强化学习模型。

import torch
import torchvision
import torchvision.transforms as transforms

# 加载 CIFAR-10 数据集
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

# 测试数据
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

# 类别
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')