Nelder-Mead 算法

                     

贡献者: addis

预备知识 Matlab 的函数

   Nelder-Mead 算法是一种求多元函数局部最小值的算法,其优点是不需要函数可导并能较快收敛到局部最小值。Matlab 自带的 fminsearch 函数就是使用该算法。对 $N$ 元函数 $f( \boldsymbol{\mathbf{x}} )$(这里把函数自变量用 $N$ 维矢量来表示),该算法需要提供函数自变量空间中的一个初始点,$ \boldsymbol{\mathbf{x}} _1$,算法从该点出发寻找局部最小值,以下是具体步骤。

   我们先根据初始点另外生成 $N$ 个初始点 $ \boldsymbol{\mathbf{x}} _2\dots \boldsymbol{\mathbf{x}} _{N + 1}$,使 $ \boldsymbol{\mathbf{x}} _{1 + i}$ 在第 $i$ 个分量比 $ \boldsymbol{\mathbf{x}} _1$ 的大 5%,其他分量保持相同。如果 $ \boldsymbol{\mathbf{x}} _1$ 的第 $i$ 个分量为零,那么 $ \boldsymbol{\mathbf{x}} _{1 + i}$ 的第 $i$ 个分量设为 $0.00025$。得到 $N+1$ 个初始点后,开始按照以下步骤进行循环,直到满足特定的精度条件时退出循环。

  1. 先给 $ \boldsymbol{\mathbf{x}} _i$ 点按照 $f( \boldsymbol{\mathbf{x}} _i)$ 从小到大的顺序重新排序,使 $i$ 越大 $f( \boldsymbol{\mathbf{x}} _i)$ 越大。
  2. 计算前 $N$ 个点的平均位置为
    \begin{equation} \boldsymbol{\mathbf{m}} = \frac1N \sum_{i=1}^N \boldsymbol{\mathbf{x}} _i~. \end{equation}
  3. 计算 $ \boldsymbol{\mathbf{x}} _{N + 1}$ 关于点 $ \boldsymbol{\mathbf{m}} $ 的对称点
    \begin{equation} \boldsymbol{\mathbf{r}} = 2 \boldsymbol{\mathbf{m}} - \boldsymbol{\mathbf{x}} _{N + 1}~. \end{equation}
  4. 如果 $f( \boldsymbol{\mathbf{x}} _1) \leqslant f( \boldsymbol{\mathbf{r}} ) < f( \boldsymbol{\mathbf{x}} _N)$,令 $ \boldsymbol{\mathbf{x}} _{N+1} = \boldsymbol{\mathbf{r}} $,并进入下一个循环。
  5. 如果 $f( \boldsymbol{\mathbf{r}} ) < f( \boldsymbol{\mathbf{x}} _1)$,计算拓展点
    \begin{equation} \boldsymbol{\mathbf{s}} = \boldsymbol{\mathbf{m}} + 2( \boldsymbol{\mathbf{m}} - \boldsymbol{\mathbf{x}} _{N+1})~. \end{equation}
    如果 $f( \boldsymbol{\mathbf{s}} ) < f( \boldsymbol{\mathbf{r}} )$,令 $ \boldsymbol{\mathbf{x}} _{N+1} = \boldsymbol{\mathbf{s}} $ 并进入下一个循环,否则令 $ \boldsymbol{\mathbf{x}} _{N+1} = \boldsymbol{\mathbf{r}} $。并进入下一循环。
  6. 如果 $f( \boldsymbol{\mathbf{x}} _N) \leqslant f( \boldsymbol{\mathbf{r}} ) < f( \boldsymbol{\mathbf{x}} _{N+1})$,令
    \begin{equation} \boldsymbol{\mathbf{c}} _1 = \boldsymbol{\mathbf{m}} + ( \boldsymbol{\mathbf{r}} - \boldsymbol{\mathbf{m}} )/2~. \end{equation}
    如果 $f( \boldsymbol{\mathbf{c}} _1) < f( \boldsymbol{\mathbf{r}} )$,令 $ \boldsymbol{\mathbf{x}} _{N + 1} = \boldsymbol{\mathbf{c}} _1$ 并进入下一循环,否则执行最后一步。
  7. 如果 $f( \boldsymbol{\mathbf{x}} _{N+1}) \leqslant f( \boldsymbol{\mathbf{r}} )$ 令
    \begin{equation} \boldsymbol{\mathbf{c}} _2 = \boldsymbol{\mathbf{m}} + ( \boldsymbol{\mathbf{x}} _{N+1} - \boldsymbol{\mathbf{m}} )/2~. \end{equation}
    如果 $f( \boldsymbol{\mathbf{c}} _2) < f( \boldsymbol{\mathbf{x}} _{N+1})$,令 $ \boldsymbol{\mathbf{x}} _{N+1} = \boldsymbol{\mathbf{c}} _2$ 并进入下一循环,否则执行最后一步。
  8. \begin{equation} \boldsymbol{\mathbf{v}} _i = \boldsymbol{\mathbf{x}} _1 + ( \boldsymbol{\mathbf{x}} _i - \boldsymbol{\mathbf{x}} _1)/2 \qquad (i = 2\dots N+1)~. \end{equation}
    并用 $ \boldsymbol{\mathbf{v}} _i$ 赋值给 $ \boldsymbol{\mathbf{x}} _i$,进入下一循环。

   观察以上步骤可知,当局部最小值的位置在 $N+1$ 个围成的图形以外时,图形倾向于变大且加速向最小值移动。当最小值的位置在图形内部时,图形倾向于缩小。随着循环次数增加,这 $N+1$ 个点最终将向局部最小值聚拢。

   我们可以在每个循环的第一步之后计算 $ \boldsymbol{\mathbf{x}} _1$ 和 $ \boldsymbol{\mathbf{x}} _{N+1}$ 的距离来估算自变量的误差,如果该误差小于某个值,即可结束循环并使用 $ \boldsymbol{\mathbf{x}} _1$ 作为最终结果。作为另一种方法,我们也可以在每个循环的第一步之后计算 $f( \boldsymbol{\mathbf{x}} _{N+1}) - f( \boldsymbol{\mathbf{x}} _1)$ 来估算最小值的误差。

   以下是该算法的 Matlab 代码。

代码 1:NelderMead.m
% f 是函数句柄,只接受一个 N 维行矢量作为输入变量, 并返回一个函数值
% x0 是 N 维行矢量, xerr 是 xmin 各个元素的绝对误差
function [xmin, fmin] = NelderMead(f, x0, xerr)
N = numel(x0); % f 是 N 元函数
x = zeros(N+1,N); % 预赋值
y = zeros(1,N+1);
% 计算 N+1 个初始点
x(1,:) = x0;
for ii = 1:N
    x(ii+1,:) = x(1,:);
    if x(1,ii) == 0
        x(ii+1,ii) = 0.00025;
    else
        x(ii+1,ii) = 1.05 * x(1,ii);
    end
end
% 主循环
x_last = x*0;
mask = true(1, N+1); % 改变的顶点
while true
    if max(max(abs(x(2:end,:) - x(1,:)))) < xerr % 判断误差
        break;
    elseif all(x(:) == x_last(:))
        warning('NelderMead: abs err too small, machine precision reached');
        break;
    else
        x_last = x;
    end
    % 求值并排序
    for ii = find(mask)
        y(ii) = f(x(ii,:));
    end
    [y, order] = sort(y);
    x = x(order,:);
    m = mean(x(1:N,:)); % 平均位置
    r = 2*m - x(N+1,:); % 反射点
    f_r = f(r);
    mask(:) = false;
    mask(end) = true;
    if y(1) <= f_r && f_r < y(N) % 第 4 步
        x(N+1,:) = r; continue;
    elseif f_r < y(1) % 第 5 步
        s = m + 2*(m - x(N+1,:));
        if f(s) < f_r
            x(N+1,:) = s;
        else
            x(N+1,:) = r;
        end
        continue;
    elseif f_r < y(N+1) % 第 6 步
        c1 = m + (r - m)*0.5;
        if f(c1) < f_r
            x(N+1,:) = c1; continue;
        end
    else % 第 7 步
        c2 = m + (x(N+1,:) - m)*0.5;
        if f(c2) < y(N+1)
            x(N+1,:) = c2; continue;
        end
    end
    for jj = 2:N+1 % 第 8 步
        x(jj,:) = x(1,:) + (x(jj,:) - x(1,:))*0.5;
        mask(jj) = true;
    end
end
% 输出变量
xmin = x(1,:);
fmin = f(xmin);
end

   该程序中有几个需要注意的地方。这是为了避免少数情况下可能发生的死循环(例如 $f( \boldsymbol{\mathbf{x}} )$ 在某个区域中的值处处相等时)。第二,4-7 步中对 $f( \boldsymbol{\mathbf{r}} )$ 的判断有且仅有一个成立,所以我们可以用 if...elseif...else 结构来选择。最后,4-5 步的情况下程序必定会执行 continue 语句而跳过第 8 步,只有 6-7 步中的 if 判断失败程序才会执行第 8 步。

                     

© 小时科技 保留一切权利