一文搞懂梯度下降算法(附MATLAB完整代码)

1.梯度下降算法

梯度下降是一种用于优化函数的迭代方法,常用于机器学习和深度学习模型的参数训练。这个算法的目标是找到一个使得损失函数(即需要最小化的函数)取值最小的参数值。

梯度下降法的基本思想是:在每一次迭代中,以当前位置为出发点,沿着函数值下降最快的方向(即负梯度方向)前进一步,然后在新的位置再次计算梯度,并沿新的梯度方向前进,如此反复迭代,直到找到函数的最小值位置。

具体的梯度下降算法可以描述如下:

初始化参数值(例如随机初始化)。

计算损失函数的梯度(即对每个参数求偏导数)。

更新参数:新的参数值 = 旧的参数值 - 学习率 * 梯度。这里的学习率(或称为步长)是一个超参数,用于控制每次更新的步长。

重复上述步骤2和3,直到满足某个停止准则(比如梯度趋近于0,或者达到预设的最大迭代次数)。

以下是用数学公式表示的参数更新步骤:

假设我们要最小化的函数为 f(θ),其中 θ 是我们要找的参数。我们希望找到一个 θ,使得 f(θ) 取得最小值。

在每一步迭代中,我们都会计算 f 在当前 θ 值处的梯度 ∇f(θ),然后按照以下方式更新 θ:

\theta_{new} = \theta_{old} - \alpha \triangledown f(\theta_{old})

其中,α 是学习率,控制我们每一步移动的距离。这个等式的含义是:新的 θ 值等于旧的 θ 值减去梯度乘以学习率。

这个更新规则的推导来自于泰勒级数的一阶近似:

f(\theta + \Delta\theta ) = f(\theta ) + \bigtriangledown f(\theta )* \Delta \theta

我们希望找到一个 ∆θ,使得 f(θ + ∆θ) 尽可能小。因此我们可以选择 ∆θ = -α * ∇f(θ),这样就可以确保 f(θ + ∆θ) < f(θ),即新的函数值比旧的函数值小,从而实现函数值的下降。

这就是梯度下降算法的基本推导过程。需要注意的是,这个推导过程假设了 f 是凸函数,或者至少在我们要找的最小值附近是凸的。对于非凸函数,梯度下降算法可能只能找到局部最小值,而非全局最小值。

2.MATLAB代码


clear all;clc;close all;


% 以下是一个使用梯度下降算法进行优化的基本 MATLAB 代码示例。在这个例子中,我们尝试找到函数 f(x) = x.^3/3 + 3*x.^2/2 - 5*x 的最小值。
% 在这个例子中,我们设定初始值为10,学习率为0.1,最大迭代次数为100,精度为0.00001。梯度函数对应于 f(x) 的导数,即 f’(x) = x.^2+3*x-5。
% 在每次迭代时,我们按照梯度下降的更新规则来更新 x 的值。如果 x 的值在两次迭代之间的变化小于设定的精度,那么我们就认为梯度下降已经收敛,并结束迭代。


% 定义初始值
x = 10; 

% 定义学习率
alpha = 0.1; 

% 定义迭代次数
n_iterations = 100; 

% 定义精度
precision = 0.0000001; 

% 定义梯度函数
fxfun= @(x) (x.^3/3 + 3*x.^2/2 - 5*x);% 定义待求极值函数
gradient = @(x) (x.^2+3*x-5);% 定义待求极值函数的导函数

% 梯度下降过程
xmat2=[x];
for i = 1:n_iterations
    x_old = x;
    
    % 更新 x
    x = x - alpha * gradient(x_old);
    
    xmat2=[xmat2,x];
    
    fprintf('Iteration: %d, x value: %f\n', i, x);
    
    % 检查收敛性
    if abs(x - x_old) < precision
        fprintf('Gradient descent has converged after %d iterations.', i);
        break;
    end
end

% 绘图
xmat=-10:0.01:10;
ymat=fxfun(xmat);


ymat2=fxfun(xmat2);


figure;
plot(xmat,ymat,'b-','linewidth',1);
hold on;
plot(xmat2,ymat2,'r*');
legend({'函数','梯度下降轨迹'},'fontname','宋体');
xlabel('x','fontname','宋体');
ylabel('y','fontname','宋体');
title('','fontname','宋体');

相关推荐

  1. OPC质量码

    2024-02-16 19:04:01       67 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-02-16 19:04:01       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-02-16 19:04:01       106 阅读
  3. 在Django里面运行非项目文件

    2024-02-16 19:04:01       87 阅读
  4. Python语言-面向对象

    2024-02-16 19:04:01       96 阅读

热门阅读

  1. 【图论经典题目讲解】洛谷 P2149 Elaxia的路线

    2024-02-16 19:04:01       60 阅读
  2. 应急响应实战笔记02日志分析篇(2)

    2024-02-16 19:04:01       49 阅读
  3. MySQL双写机制

    2024-02-16 19:04:01       61 阅读
  4. coredns 状态为running但not ready

    2024-02-16 19:04:01       53 阅读
  5. Acwing---869. 试除法求约数

    2024-02-16 19:04:01       44 阅读