贡献者: addis
Nelder-Mead 算法是一种求多元函数局部最小值的算法,其优点是不需要函数可导并能较快收敛到局部最小值。Matlab 自带的 fminsearch
函数就是使用该算法。对 $N$ 元函数 $f( \boldsymbol{\mathbf{x}} )$(这里把函数自变量用 $N$ 维矢量来表示),该算法需要提供函数自变量空间中的一个初始点,$ \boldsymbol{\mathbf{x}} _1$,算法从该点出发寻找局部最小值,以下是具体步骤。
我们先根据初始点另外生成 $N$ 个初始点 $ \boldsymbol{\mathbf{x}} _2\dots \boldsymbol{\mathbf{x}} _{N + 1}$,使 $ \boldsymbol{\mathbf{x}} _{1 + i}$ 在第 $i$ 个分量比 $ \boldsymbol{\mathbf{x}} _1$ 的大 5%,其他分量保持相同。如果 $ \boldsymbol{\mathbf{x}} _1$ 的第 $i$ 个分量为零,那么 $ \boldsymbol{\mathbf{x}} _{1 + i}$ 的第 $i$ 个分量设为 $0.00025$。得到 $N+1$ 个初始点后,开始按照以下步骤进行循环,直到满足特定的精度条件时退出循环。
观察以上步骤可知,当局部最小值的位置在 $N+1$ 个围成的图形以外时,图形倾向于变大且加速向最小值移动。当最小值的位置在图形内部时,图形倾向于缩小。随着循环次数增加,这 $N+1$ 个点最终将向局部最小值聚拢。
我们可以在每个循环的第一步之后计算 $ \boldsymbol{\mathbf{x}} _1$ 和 $ \boldsymbol{\mathbf{x}} _{N+1}$ 的距离来估算自变量的误差,如果该误差小于某个值,即可结束循环并使用 $ \boldsymbol{\mathbf{x}} _1$ 作为最终结果。作为另一种方法,我们也可以在每个循环的第一步之后计算 $f( \boldsymbol{\mathbf{x}} _{N+1}) - f( \boldsymbol{\mathbf{x}} _1)$ 来估算最小值的误差。
以下是该算法的 Matlab 代码。
% f 是函数句柄,只接受一个 N 维行矢量作为输入变量, 并返回一个函数值
% x0 是 N 维行矢量, xerr 是 xmin 各个元素的绝对误差
function [xmin, fmin] = NelderMead(f, x0, xerr)
N = numel(x0); % f 是 N 元函数
x = zeros(N+1,N); % 预赋值
y = zeros(1,N+1);
% 计算 N+1 个初始点
x(1,:) = x0;
for ii = 1:N
x(ii+1,:) = x(1,:);
if x(1,ii) == 0
x(ii+1,ii) = 0.00025;
else
x(ii+1,ii) = 1.05 * x(1,ii);
end
end
% 主循环
x_last = x*0;
mask = true(1, N+1); % 改变的顶点
while true
if max(max(abs(x(2:end,:) - x(1,:)))) < xerr % 判断误差
break;
elseif all(x(:) == x_last(:))
warning('NelderMead: abs err too small, machine precision reached');
break;
else
x_last = x;
end
% 求值并排序
for ii = find(mask)
y(ii) = f(x(ii,:));
end
[y, order] = sort(y);
x = x(order,:);
m = mean(x(1:N,:)); % 平均位置
r = 2*m - x(N+1,:); % 反射点
f_r = f(r);
mask(:) = false;
mask(end) = true;
if y(1) <= f_r && f_r < y(N) % 第 4 步
x(N+1,:) = r; continue;
elseif f_r < y(1) % 第 5 步
s = m + 2*(m - x(N+1,:));
if f(s) < f_r
x(N+1,:) = s;
else
x(N+1,:) = r;
end
continue;
elseif f_r < y(N+1) % 第 6 步
c1 = m + (r - m)*0.5;
if f(c1) < f_r
x(N+1,:) = c1; continue;
end
else % 第 7 步
c2 = m + (x(N+1,:) - m)*0.5;
if f(c2) < y(N+1)
x(N+1,:) = c2; continue;
end
end
for jj = 2:N+1 % 第 8 步
x(jj,:) = x(1,:) + (x(jj,:) - x(1,:))*0.5;
mask(jj) = true;
end
end
% 输出变量
xmin = x(1,:);
fmin = f(xmin);
end
该程序中有几个需要注意的地方。这是为了避免少数情况下可能发生的死循环(例如 $f( \boldsymbol{\mathbf{x}} )$ 在某个区域中的值处处相等时)。第二,4-7 步中对 $f( \boldsymbol{\mathbf{r}} )$ 的判断有且仅有一个成立,所以我们可以用 if...elseif...else
结构来选择。最后,4-5 步的情况下程序必定会执行 continue
语句而跳过第 8 步,只有 6-7 步中的 if
判断失败程序才会执行第 8 步。