PyTorch深度学习中的自动求导一文读懂
在PyTorch这个深度学习框架中,自动求导是它的一个重要特点。我们在用PyTorch训练模型的过程中,自动求导能帮我们显著提升效率,它可以自动计算模型的梯度和损失函数从而缩短训练的时间。本篇将为各位同学介绍一下自动求导这个创新而又高效的机制。
01 什么是自动求导
假如我们需要解决一个数学问题,计算一个复杂的函数 f(x, y, z) 的值,首先想知道这个函数对于每个变量 x, y, z 的变化有多敏感,也就是求它的偏导数 ∂f/∂x, ∂f/∂y, ∂f/∂z。
PyTorch 的自动求导就是帮我们自动完成这个求偏导数的过程。它主要依赖两个核心概念:
1. 计算图 (Computational Graph):当用户用 PyTorch 进行一系列运算时,PyTorch 会默默地记录下这些运算的顺序和关系,构建一个有向无环图。图中的节点表示张量(数据)和运算,带箭头的边表示数据流动方向。例如,如果计算 `y=3x^2 + 4x + 2`,计算图会记录下 `x`, 是如何通过乘法和加法运算得到 `y` 的。
2. 反向传播 (Backpropagation):构建好计算图后,如果你想求某个标量(通常是损失函数)对各个变量的导数,PyTorch 会从这个标量出发,沿着计算图反向传播,根据链式法则自动计算出每个节点的梯度(导数)。
那我们如何更简单地理解这两个概念呢?举个例子,就像一个人要学习骑自行车,而人的大脑就是PyTorch的自动求导系统。当人开始骑车时,大脑会不断地调整平衡,让身体保持直立。这个过程中,大脑实际上是在做很多复杂的计算,比如从身体的角度、速度、加速度等等,但是这些计算对人脑来说是自动的,不需要主动去想每一步该怎么做。
在PyTorch中,自动求导也是这样工作的。当用户定义了一个计算图(就像人在骑车时的一系列动作),PyTorch会帮你记录下这个图中的每一个操作(比如加法、乘法)。这些操作就像是骑车时的每一个小步骤。
当完成了一次前向传播(也就是从起点骑到终点),PyTorch会帮自动计算出损失(比如你偏离目标的距离)。然后,如果想要改进(比如骑得更直),PyTorch会使用自动求导功能来计算出每一步操作对最终损失的影响(这就像是大脑计算出需要如何调整身体动作来保持平衡)。
这个过程叫做反向传播。PyTorch会从损失开始,逆向遍历整个计算图,计算出每个操作对损失的“贡献”(也就是梯度)。这样,用户就可以知道如何调整参数(比如骑车姿势),来减少损失了(更接近目标)。
PyTorch的自动求导功能就像是人大脑在骑车时的自动平衡系统,它能帮人记录下每一个操作,然后在需要的时候,计算出如何调整这些操作来达到更好的结果。这样,用户就不需要自己去手动计算每一步的梯度,PyTorch都自动搞定了。
02 举例
让我们通过一个简单的例子来进一步理解PyTorch中的自动求导过程。假设我们有一个简单的数学问题:计算函数 f(x)=x2f(x)=x2 在 x=3x=3 时的导数。在数学上,我们知道 f(x)f(x) 的导数是 f′(x)=2xf′(x)=2x,所以 f′(3)=6f′(3)=6。现在,我们用PyTorch来自动求导计算这个导数。
首先,我们需要导入PyTorch库,并创建一个变量 xx,告诉PyTorch这个变量需要被跟踪(即需要计算梯度):
import torch
# 创建一个需要梯度的变量x,并初始化为3
x = torch.tensor(3.0, requires_grad=True)
接下来,我们定义函数 f(x)=x2f(x)=x2 并计算其值:
# 定义函数f(x) = x^2
f_x = x ** 2
现在,我们计算 f(x)f(x) 的值:
# 计算f(x)的值
print(f_x) # 输出: tensor(9., grad_fn=<PowBackward0>)
注意到输出中有一个 `grad_fn=<PowBackward0>`,这表示PyTorch已经记录了这个操作,并且知道如何对这个操作进行求导。
接下来,我们计算损失(在这里,损失就是函数的值,因为我们只是想要计算导数):
# 计算损失,这里我们直接使用f(x)作为损失
loss = f_x
然后,我们告诉PyTorch我们需要计算损失相对于 xx 的梯度:
# 告诉PyTorch我们需要计算梯度
loss.backward()
最后,我们可以获取 xx 的梯度:
# 获取x的梯度
print(x.grad) # 输出: tensor(6.)
这个梯度就是 f(x)f(x) 在 x=3x=3 时的导数,正如我们之前手动计算的那样。
通过这个例子,我们可以看到PyTorch是如何自动记录操作并计算梯度的。这个过程对于复杂的神经网络来说也是一样的,PyTorch会记录下所有的操作,并在调用 `backward()` 方法时计算出每个参数的梯度。这些梯度随后可以被用来更新网络的权重,以最小化损失函数。
03 总结
PyTorch 的自动求导机制通过构建计算图和反向传播,实现了自动计算梯度的功能。这使得用户可以专注于模型的设计和训练,而无需手动计算复杂的导数,极大地提高了开发效率。它就像一个聪明的“导数计算器”,帮助我们更好地理解和优化模型。
来源: