Nelder-Mead 算法

                     

贡献者: addis

预备知识 Matlab 的函数

   Nelder-Mead 算法是一种求多元函数局部最小值的算法,其优点是不需要函数可导并能较快收敛到局部最小值。Matlab 自带的 fminsearch 函数就是使用该算法。对 $N$ 元函数 $f( \boldsymbol{\mathbf{x}} )$(这里把函数自变量用 $N$ 维矢量来表示),该算法需要提供函数自变量空间中的一个初始点,$ \boldsymbol{\mathbf{x}} _1$,算法从该点出发寻找局部最小值,以下是具体步骤。

   我们先根据初始点另外生成 $N$ 个初始点 $ \boldsymbol{\mathbf{x}} _2\dots \boldsymbol{\mathbf{x}} _{N + 1}$,使 $ \boldsymbol{\mathbf{x}} _{1 + i}$ 在第 $i$ 个分量比 $ \boldsymbol{\mathbf{x}} _1$ 的大 5%,其他分量保持相同。如果 $ \boldsymbol{\mathbf{x}} _1$ 的第 $i$ 个分量为零,那么 $ \boldsymbol{\mathbf{x}} _{1 + i}$ 的第 $i$ 个分量设为 $0.00025$。得到 $N+1$ 个初始点后,开始按照以下步骤进行循环,直到满足特定的精度条件时退出循环。

  1. 先给 $ \boldsymbol{\mathbf{x}} _i$ 点按照 $f( \boldsymbol{\mathbf{x}} _i)$ 从小到大的顺序重新排序,使 $i$ 越大 $f( \boldsymbol{\mathbf{x}} _i)$ 越大。
  2. 计算前 $N$ 个点的平均位置为
    \begin{equation} \boldsymbol{\mathbf{m}} = \frac1N \sum_{i=1}^N \boldsymbol{\mathbf{x}} _i~. \end{equation}
  3. 计算 $ \boldsymbol{\mathbf{x}} _{N + 1}$ 关于点 $ \boldsymbol{\mathbf{m}} $ 的对称点
    \begin{equation} \boldsymbol{\mathbf{r}} = 2 \boldsymbol{\mathbf{m}} - \boldsymbol{\mathbf{x}} _{N + 1}~. \end{equation}
  4. 如果 $f( \boldsymbol{\mathbf{x}} _1) \leqslant f( \boldsymbol{\mathbf{r}} ) < f( \boldsymbol{\mathbf{x}} _N)$,令 $ \boldsymbol{\mathbf{x}} _{N+1} = \boldsymbol{\mathbf{r}} $,并进入下一个循环。
  5. 如果 $f( \boldsymbol{\mathbf{r}} ) < f( \boldsymbol{\mathbf{x}} _1)$,计算拓展点
    \begin{equation} \boldsymbol{\mathbf{s}} = \boldsymbol{\mathbf{m}} + 2( \boldsymbol{\mathbf{m}} - \boldsymbol{\mathbf{x}} _{N+1})~. \end{equation}
    如果 $f( \boldsymbol{\mathbf{s}} ) < f( \boldsymbol{\mathbf{r}} )$,令 $ \boldsymbol{\mathbf{x}} _{N+1} = \boldsymbol{\mathbf{s}} $ 并进入下一个循环,否则令 $ \boldsymbol{\mathbf{x}} _{N+1} = \boldsymbol{\mathbf{r}} $。并进入下一循环。
  6. 如果 $f( \boldsymbol{\mathbf{x}} _N) \leqslant f( \boldsymbol{\mathbf{r}} ) < f( \boldsymbol{\mathbf{x}} _{N+1})$,令
    \begin{equation} \boldsymbol{\mathbf{c}} _1 = \boldsymbol{\mathbf{m}} + ( \boldsymbol{\mathbf{r}} - \boldsymbol{\mathbf{m}} )/2~. \end{equation}
    如果 $f( \boldsymbol{\mathbf{c}} _1) < f( \boldsymbol{\mathbf{r}} )$,令 $ \boldsymbol{\mathbf{x}} _{N + 1} = \boldsymbol{\mathbf{c}} _1$ 并进入下一循环,否则执行最后一步。
  7. 如果 $f( \boldsymbol{\mathbf{x}} _{N+1}) \leqslant f( \boldsymbol{\mathbf{r}} )$ 令
    \begin{equation} \boldsymbol{\mathbf{c}} _2 = \boldsymbol{\mathbf{m}} + ( \boldsymbol{\mathbf{x}} _{N+1} - \boldsymbol{\mathbf{m}} )/2~. \end{equation}
    如果 $f( \boldsymbol{\mathbf{c}} _2) < f( \boldsymbol{\mathbf{x}} _{N+1})$,令 $ \boldsymbol{\mathbf{x}} _{N+1} = \boldsymbol{\mathbf{c}} _2$ 并进入下一循环,否则执行最后一步。
  8. \begin{equation} \boldsymbol{\mathbf{v}} _i = \boldsymbol{\mathbf{x}} _1 + ( \boldsymbol{\mathbf{x}} _i - \boldsymbol{\mathbf{x}} _1)/2 \qquad (i = 2\dots N+1)~. \end{equation}
    并用 $ \boldsymbol{\mathbf{v}} _i$ 赋值给 $ \boldsymbol{\mathbf{x}} _i$,进入下一循环。

   观察以上步骤可知,当局部最小值的位置在 $N+1$ 个围成的图形以外时,图形倾向于变大且加速向最小值移动。当最小值的位置在图形内部时,图形倾向于缩小。随着循环次数增加,这 $N+1$ 个点最终将向局部最小值聚拢。

   我们可以在每个循环的第一步之后计算 $ \boldsymbol{\mathbf{x}} _1$ 和 $ \boldsymbol{\mathbf{x}} _{N+1}$ 的距离来估算自变量的误差,如果该误差小于某个值,即可结束循环并使用 $ \boldsymbol{\mathbf{x}} _1$ 作为最终结果。作为另一种方法,我们也可以在每个循环的第一步之后计算 $f( \boldsymbol{\mathbf{x}} _{N+1}) - f( \boldsymbol{\mathbf{x}} _1)$ 来估算最小值的误差。

   以下是该算法的 Matlab 代码。

代码 1:NelderMead.m
% f 是函数句柄,只接受一个 N 维行矢量作为输入变量, 并返回一个函数值
% x0 是 N 维行矢量, xerr 是 xmin 各个元素的绝对误差
function [xmin, fmin] = NelderMead(f, x0, xerr)
N = numel(x0); % f 是 N 元函数
x = zeros(N+1,N); % 预赋值
y = zeros(1,N+1);
% 计算 N+1 个初始点
x(1,:) = x0;
for ii = 1:N
    x(ii+1,:) = x(1,:);
    if x(1,ii) == 0
        x(ii+1,ii) = 0.00025;
    else
        x(ii+1,ii) = 1.05 * x(1,ii);
    end
end
% 主循环
x_last = x*0;
mask = true(1, N+1); % 改变的顶点
while true
    if max(max(abs(x(2:end,:) - x(1,:)))) < xerr % 判断误差
        break;
    elseif all(x(:) == x_last(:))
        warning('NelderMead: abs err too small, machine precision reached');
        break;
    else
        x_last = x;
    end
    % 求值并排序
    for ii = find(mask)
        y(ii) = f(x(ii,:));
    end
    [y, order] = sort(y);
    x = x(order,:);
    m = mean(x(1:N,:)); % 平均位置
    r = 2*m - x(N+1,:); % 反射点
    f_r = f(r);
    mask(:) = false;
    mask(end) = true;
    if y(1) <= f_r && f_r < y(N) % 第 4 步
        x(N+1,:) = r; continue;
    elseif f_r < y(1) % 第 5 步
        s = m + 2*(m - x(N+1,:));
        if f(s) < f_r
            x(N+1,:) = s;
        else
            x(N+1,:) = r;
        end
        continue;
    elseif f_r < y(N+1) % 第 6 步
        c1 = m + (r - m)*0.5;
        if f(c1) < f_r
            x(N+1,:) = c1; continue;
        end
    else % 第 7 步
        c2 = m + (x(N+1,:) - m)*0.5;
        if f(c2) < y(N+1)
            x(N+1,:) = c2; continue;
        end
    end
    for jj = 2:N+1 % 第 8 步
        x(jj,:) = x(1,:) + (x(jj,:) - x(1,:))*0.5;
        mask(jj) = true;
    end
end
% 输出变量
xmin = x(1,:);
fmin = f(xmin);
end

   该程序中有几个需要注意的地方。这是为了避免少数情况下可能发生的死循环(例如 $f( \boldsymbol{\mathbf{x}} )$ 在某个区域中的值处处相等时)。第二,4-7 步中对 $f( \boldsymbol{\mathbf{r}} )$ 的判断有且仅有一个成立,所以我们可以用 if...elseif...else 结构来选择。最后,4-5 步的情况下程序必定会执行 continue 语句而跳过第 8 步,只有 6-7 步中的 if 判断失败程序才会执行第 8 步。


致读者: 小时百科一直以来坚持所有内容免费无广告,这导致我们处于严重的亏损状态。 长此以往很可能会最终导致我们不得不选择大量广告以及内容付费等。 因此,我们请求广大读者热心打赏 ,使网站得以健康发展。 如果看到这条信息的每位读者能慷慨打赏 20 元,我们一周就能脱离亏损, 并在接下来的一年里向所有读者继续免费提供优质内容。 但遗憾的是只有不到 1% 的读者愿意捐款, 他们的付出帮助了 99% 的读者免费获取知识, 我们在此表示感谢。

                     

友情链接: 超理论坛 | ©小时科技 保留一切权利