In [1]:
# Imports
from __future__ import division
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sb
import scipy.optimize as opt
from mpl_toolkits.mplot3d import Axes3D, proj3d
from matplotlib.colors import LogNorm

# Magics
%matplotlib inline

Optimization: An Introduction to the Nelder-Mead Algorithm

Author: Chase Coleman

One of the most widely used optimization algorithms today is the Nelder-Mead algorithm. It has become a workhorse of many numerical libraries' minimization routines: it lies behind Matlab's fminsearch and is the default for Scipy's fmin function. One of its key benefits is that it requires no information about first or second derivatives. The Nelder-Mead algorithm searches for the minimum of an objective function $f : \mathbb{R}^n \to \mathbb{R}$ by applying simple operations to a simplex of $n+1$ points in $\mathbb{R}^n$. The algorithm is simple, and a basic understanding of it provides valuable intuition for when it is (and, more importantly, when it isn't) an appropriate minimization technique.
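Before building the algorithm ourselves, it is easy to see it in action through Scipy. Here is a minimal sketch (the quadratic objective is just an illustrative choice and is not used elsewhere in this notebook):

import numpy as np
import scipy.optimize as opt

# A simple convex objective with its minimum at (3, -2)
def quadratic(x):
    return (x[0] - 3.0)**2 + (x[1] + 2.0)**2

# opt.fmin uses the Nelder-Mead simplex algorithm
xopt = opt.fmin(quadratic, x0=np.zeros(2))
print(xopt)  # approximately [3., -2.]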

Key Components of the Algorithm

The algorithm relies on 4 main operations on a simplex of points. Before presenting the main algorithm, we will discuss these operations to simplify the process later. The four operations are:

  • Reflection
  • Expansion
  • Contraction
  • Shrink

We will discuss these operations in the context of a concrete example. Consider the following simplex of points in 2-D. Let $\Delta$ be a simplex consisting of the points $x_1 = (0, 0)$, $x_2 = (2, 3)$, and $x_3 = (4, 0)$. We can then compute the center of mass $\bar{x} = \frac{1}{3} \sum_{i=1}^{3} x_i = (2, 1)$. We graph the points of our simplex and their center of mass below.

In [2]:
alpha, beta, gamma, delta = 1., 2., .5, .75
x1, x2, x3 = np.array([0., 0.]), np.array([2., 3.]), np.array([4., 0.])
xbar = np.array([2, 1])
Delta = np.vstack([x1, x2, x3])
offsets = [(-15, 10), (5, 5), (5, 5)]

fig, ax = plt.subplots(1, figsize=(10, 8))
ax.set_xlim((-2, 7))
ax.set_ylim((-5, 5))

# Plot and label each vertex of the simplex
for ind, point in enumerate(Delta):
    curr_x = r"$x_{}$"
    curr_offset = offsets[ind]
    ax.scatter(point[0], point[1], color="k")
    ax.annotate(curr_x.format(ind+1), xy=point, xytext=curr_offset,
                textcoords='offset points', size=18)

ax.scatter(xbar[0], xbar[1], color="k")
ax.annotate(r"$\bar{x}$", xy=xbar, xytext=xbar,
            textcoords='offset points', size=16)

plt.show()

To facilitate a simple description of these operations, we sweep several formalities under the rug. First, the point that has these operations applied to it is chosen within the algorithm, but here we simply perform all of our operations on $x_2$. Second, and perhaps more importantly, the center of mass used within the algorithm is not the center of mass of all $n+1$ points; it is the center of mass of the $n$ points that remain after excluding the point at which the function achieves its highest value. Also, for aesthetic purposes, we use some nonstandard parameter values when graphing these operations ($\alpha = 1$, $\beta = 2$, $\gamma = .5$, and $\delta = .75$).

Reflection

The reflection operation creates a new point by reflecting a point across the center of mass of the simplex: $x_r := \bar{x} + \alpha(\bar{x} - x_i)$, where $x_i$ is the point we are reflecting. Thus in our example, $x_r = \bar{x} + \alpha(\bar{x} - x_2) = (2, 1) + \alpha\big((2, 1) - (2, 3)\big) = (2, -1)$. We add the reflected point below.

In [3]:
xr = xbar + alpha*(xbar - x2)

ax.scatter(xr[0], xr[1], color="b")
ax.annotate(r"$x_r$", xy=xr, xytext=xr,
            textcoords='offset points', size=16)

fig
Out[3]:

Expansion

The expansion operation creates a new point by pushing the reflected point farther away from $\bar{x}$. It is created by $x_e := \bar{x} + \beta(x_r - \bar{x})$. In our example, $x_e = \bar{x} + \beta(x_r - \bar{x}) = (2, 1) + \beta\big((2, -1) - (2, 1)\big) = (2, -3)$. We add the expanded point below.

In [4]:
xe = xbar + beta*(xr - xbar)

ax.scatter(xe[0], xe[1], color="g")
ax.annotate(r"$x_e$", xy=xe, xytext=xe,
            textcoords='offset points', size=16)

fig
Out[4]:

Contraction

There are two types of contractions: outside and inside. Contraction is the opposite of expansion in the sense that, instead of pushing the reflected point farther out, it draws it closer to the center of mass.

The outside contraction creates a new point by contracting from the reflected point towards the center of mass and is defined by $x_{oc} := \bar{x} + \gamma(x_r - \bar{x})$.

The inside contraction creates a new point by contracting from the point $x_i$ that we reflected towards the center of mass and is defined by $x_{ic} := \bar{x} - \gamma(x_r - \bar{x})$.

In [5]:
xoc = xbar + gamma*(xr - xbar)
xic = xbar - gamma*(xr - xbar)

ax.scatter(xoc[0], xoc[1], color="r")
ax.scatter(xic[0], xic[1], color="r")
ax.annotate(r"$x_{oc}$", xy=xoc, xytext=xoc,
            textcoords='offset points', size=16)
ax.annotate(r"$x_{ic}$", xy=xic, xytext=xic,
            textcoords='offset points', size=16)

fig
Out[5]:

Shrink

The shrink operation takes all but one of the points and draws them closer to that one point. In the algorithm, we won't be shrinking the points towards the point that we perform the reflection/expansion/contraction on, so we will use $x_1$ as the point towards which the points "shrink." For every point except $x_1$, we create a new point $x_i^s := x_1 + \delta(x_i - x_1)$.

In [6]:
xs2 = x1 + delta*(x2 - x1)
xs3 = x1 + delta*(x3 - x1)

ax.scatter(xs2[0], xs2[1], color="DarkOrange")
ax.scatter(xs3[0], xs3[1], color="DarkOrange")
ax.annotate(r"$x_{2}^s$", xy=xs2, xytext=xs2,
            textcoords='offset points', size=16)
ax.annotate(r"$x_{3}^s$", xy=xs3, xytext=xs3,
            textcoords='offset points', size=16)

fig
Out[6]:

The Nelder-Mead Algorithm

Now that we understand what each of the four operations within the Nelder-Mead algorithm does, we can discuss the actual algorithm. As you read through it, note that the main idea is very simple: order the points, create some new points, replace the point with the largest function value, repeat.

  1. There are two ways to obtain the initial simplex. The first is to simply pass in a simplex $\Delta$ as the guess. The second is to pass in a single point, $x_0$, and build the simplex around that point (we do this by using $x_0$ as one point of the simplex and $x_i := x_0 + e_i \varepsilon$, where $e_i$ is the $i$-th unit vector, for the $n$ other points); a short sketch of this construction appears after this list.

  2. Once we have the simplex, we evaluate the function at each of its points, sort the points such that $f(x_1) \leq f(x_2) \leq \dots \leq f(x_{n+1})$, and find $\bar{x} := \frac{1}{n} \sum_{i=1}^{n} x_i$ (notice, as we mentioned earlier, that we only take the center of mass of the $n$ points with the smallest function values).

  3. If $|f(\bar{x}) - f(x_1)| < \varepsilon$ (or another convergence metric of your choosing), return $x_1$ as the minimum value; otherwise proceed.

  4. Create a reflected point $x_r$.

    4.1 If $f(x_1) \leq f(x_r) < f(x_n)$, replace $x_{n+1}$ with $x_r$.

    4.2 Return to step 2.

  5. Else if $f(x_r) < f(x_1)$, create the expanded point $x_e$.

    5.1 If $f(x_e) < f(x_r)$, replace $x_{n+1}$ with $x_e$.

    5.2 Else if $f(x_r) \leq f(x_e)$, replace $x_{n+1}$ with $x_r$.

    5.3 Return to step 2.

  6. Else if $f(x_n) \leq f(x_r) < f(x_{n+1})$, create the outside contraction point $x_{oc}$.

    6.1 If $f(x_{oc}) \leq f(x_r)$, replace $x_{n+1}$ with $x_{oc}$.

    6.2 Else, shrink the points towards $x_1$.

    6.3 Return to step 2.

  7. Else if $f(x_{n+1}) \leq f(x_r)$, create the inside contraction point $x_{ic}$.

    7.1 If $f(x_{ic}) \leq f(x_r)$, replace $x_{n+1}$ with $x_{ic}$.

    7.2 Else, shrink the points towards $x_1$.

    7.3 Return to step 2.
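As referenced in step 1, here is a minimal sketch of building the initial simplex around a single point. The step size eps is an arbitrary illustrative choice; the full implementation below instead perturbs each coordinate by 5% of its value, with a small absolute step for entries near zero.

import numpy as np

def build_simplex(x0, eps=0.1):
    # Start from n+1 copies of x0, then perturb one coordinate per row
    n = x0.size
    simplex = np.tile(x0.astype(float), (n + 1, 1))
    for i in range(n):
        simplex[i + 1, i] += eps  # row i+1 is x0 + eps * e_i
    return simplex

print(build_simplex(np.array([1.0, 2.0])))
# [[ 1.   2. ]
#  [ 1.1  2. ]
#  [ 1.   2.1]]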

That is the entire algorithm. As previously stated, you can see that it simply applies our main operations repeatedly until we converge. I have written a simple implementation of the algorithm below.

In [7]:
"""
    		Author: Chase Coleman
    		Date: August 13, 2014
    		
    		This is a simple implementation of the Nelder-Mead algorithm
    		
    		"""
    		
    		
    		def nelder_mead(f, x0, method="ANMS", tol=1e-8, maxit=1e4, iter_returns=None):
    		    """
    		    This is a naive python implementation of the nelder-mead algorithm.
    		
    		    Parameters
    		    ----------
    		
    		    f : callable
    		        Function to minimize
    		    x0 : scalar(float) or array_like(float, ndim=1)
    		        The initial guess for minimizing
    		    method : string or tuple(floats)
    		        If a string, should specify ANMS or NMS then will use specific
    		        parameter values, but also can pass in a tuple of parameters in
    		        order (alpha, beta, gamma, delta), which are the reflection,
    		        expansion, contraction, and contraction parameters
    		    tol : scalar(float)
    		        The tolerance level to achieve convergence
    		    maxit : scalar(int)
    		        The maximimum number of iterations allowed
    		
    		
    		    References :
    		
    		    Nelder, J. A. and R. Mead, "A Simplex Method for Function
    		    Minimization." 1965. Vol 7(4). Computer Journal
    		
    		    F. Gao, L. Han, "Implementing the Nelder-Mead simplex algorithm with
    		    adaptive parameters", Comput. Optim. Appl.,
    		
    		    http://www.brnt.eu/phd/node10.html#SECTION00622200000000000000
    		
    		
    		    TODO:
    		      * Check to see whether we can use an array instead of a list of
    		      tuples
    		      * Write some tests
    		    """
    		    #-----------------------------------------------------------------#
    		    # Set some parameter values
    		    #-----------------------------------------------------------------#
    		    init_guess = x0
    		    fx0 = f(x0)
    		    dist = 10.
    		    curr_it = 0
    		
    		    # Get the number of dimensions we are optimizing
    		    n = np.size(x0)
    		
    		    # Will use the Adaptive Nelder-Mead Simplex paramters by default
    		    if method is "ANMS":
    		        alpha = 1.
    		        beta = 1. + (2./n)
    		        gamma = .75 - 1./(2.*n)
    		        delta = 1. - (1./n)
    		    # Otherwise can use standard parameters
    		    elif method is "NMS":
    		        alpha = 1.
    		        beta = 2.
    		        gamma = .5
    		        delta = .5
    		    elif type(method) is tuple:
    		        alpha, beta, gamma, delta = method
    		
    		
    		    #-----------------------------------------------------------------#
    		    # Create the simplex points and do the initial sort
    		    #-----------------------------------------------------------------#
    		    simplex_points = np.empty((n+1, n))
    		
    		    pt_fval = [(x0, fx0)]
    		
    		    simplex_points[0, :] = x0
    		
    		    for ind, elem in enumerate(x0):
    		
    		        if np.abs(elem) < 1e-14:
    		            curr_tau = 0.00025
    		        else:
    		            curr_tau = 0.05
    		
    		        curr_point = np.squeeze(np.eye(1, M=n, k=ind)*curr_tau + x0)
    		
    		        simplex_points[ind, :] = curr_point
    		        pt_fval.append((curr_point, f(curr_point)))
    		        
    		    if iter_returns is not None:
    		        ret_points = []
    		    else:
    		        ret_points = None
    		
    		
    		    #-----------------------------------------------------------------#
    		    # The Core of The Nelder-Mead Algorithm
    		    #-----------------------------------------------------------------#
    		    while dist>tol and curr_it<maxit:
    		
    		        # 1: Sort and find new center point (excluding worst point)
    		        pt_fval = sorted(pt_fval, key=lambda v: v[1])
    		        xbar = x0*0
    		
    		        for i in range(n):
    		            xbar = xbar + (pt_fval[i][0])/(n)
    		            
    		        if iter_returns is not None and curr_it in iter_returns:
    		            ret_points.append(pt_fval)
    		
    		        # Define useful variables
    		        x1, f1 = pt_fval[0]
    		        xn, fn = pt_fval[n-1]
    		        xnp1, fnp1 = pt_fval[n]
    		
    		
    		        # 2: Reflect
    		        xr = xbar + alpha*(xbar - pt_fval[-1][0])
    		        fr = f(xr)
    		
    		        if f1 <= fr < fn:
    		            # Replace the n+1 point
    		            xnp1, fnp1 = (xr, fr)
    		            pt_fval[n] = (xnp1, fnp1)
    		
    		        elif fr < f1:
    		            # 3: expand
    		            xe = xbar + beta*(xr - xbar)
    		            fe = f(xe)
    		
    		            if fe < fr:
    		                xnp1, fnp1 = (xe, fe)
    		                pt_fval[n] = (xnp1, fnp1)
    		            else:
    		                xnp1, fnp1 = (xr, fr)
    		                pt_fval[n] = (xnp1, fnp1)
    		
    		        elif fn <= fr <= fnp1:
    		            # 4: outside contraction
    		            xoc = xbar + gamma*(xr - xbar)
    		            foc = f(xoc)
    		
    		            if foc <= fr:
    		                xnp1, fnp1 = (xoc, foc)
    		                pt_fval[n] = (xnp1, fnp1)
    		            else:
    		                # 6: Shrink
    		                for i in range(1, n+1):
    		                    curr_pt, curr_f = pt_fval[i]
    		                    # Shrink the points
    		                    new_pt = x1 + delta*(curr_pt - x1)
    		                    new_f = f(new_pt)
    		                    # Replace
    		                    pt_fval[i] = new_pt, new_f
    		
    		        elif fr >= fnp1:
    		            # 5: inside contraction
    		            xic = xbar - gamma*(xr - xbar)
    		            fic = f(xic)
    		
    		            if fic <= fr:
    		                xnp1, fnp1 = (xic, fic)
    		                pt_fval[n] = (xnp1, fnp1)
    		            else:
    		                # 6: Shrink
    		                for i in range(1, n+1):
    		                    curr_pt, curr_f = pt_fval[i]
    		                    # Shrink the points
    		                    new_pt = x1 + delta*(curr_pt - x1)
    		                    new_f = f(new_pt)
    		                    # Replace
    		                    pt_fval[i] = new_pt, new_f
    		
    		        # Compute the distance and increase iteration counter
    		        dist = abs(fn - f1)
    		        curr_it = curr_it + 1
    		
    		    if curr_it == maxit:
    		        raise ValueError("Max iterations; Convergence failed.")
    		        
    		    if ret_points:
    		        return x1, f1, curr_it, ret_points
    		    else:
    		        return x1, f1, curr_it
    		

Example: Rosenbrock Function

One of the key tests for an optimization algorithm is Rosenbrock's "banana function," $f(x, y) := (a - x)^2 + b(y - x^2)^2$, which has a minimum at $(a, a^2)$. It is a tricky function because of its nonconvexities, and there are many points that are close to being a minimum. I graph the function below:

In [8]:
# Define Rosenbrock Function
def rosenbrock(x, a=1, b=100):
    """
    The minimum of the Rosenbrock function is at (a, a**2)
    """
    y = x[1]
    x = x[0]
    return (a - x)**2 + b*(y - x**2)**2
In [9]:
x = np.linspace(-2.5, 2.5, 500)
y = np.linspace(-2.5, 2.5, 500)

X, Y = np.meshgrid(x, y)

Z = rosenbrock([X, Y])

fig = plt.figure(figsize=(14, 8))
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122, projection="3d")

fig.suptitle("Rosenbrock Function", size=24)

# Color mesh
ax1.set_facecolor("white")
ax1.pcolormesh(X, Y, Z, cmap=matplotlib.cm.viridis,
               norm=LogNorm())
ax1.scatter(1, 1, color="k")
ax1.annotate('Global Min', xy=(1, 1), xytext=(-0.5, 1.25),
             arrowprops=dict(facecolor='black', shrink=0.05))

# Surface plot
ax2.set_facecolor("white")
ax2.plot_surface(X, Y, Z, norm=LogNorm(), cmap=matplotlib.cm.viridis,
                 linewidth=0)
ax2.view_init(azim=65, elev=25)
ax2.scatter(1., 1., 0., color="k")
xa, ya, _ = proj3d.proj_transform(1, 1, 0, ax2.get_proj())
ax2.annotate("Global Min", xy=(xa, ya), xytext=(-20, 30),
             textcoords='offset points', ha='right', va='bottom',
             arrowprops=dict(facecolor='black', shrink=0.05))

plt.tight_layout()
plt.show()

Performance of the Nelder-Mead Algorithm

Now that we have seen the objective function, I will use our algorithm to find its minimum. To show the progress, I will plot some of the steps below; in particular, iterations 0, 1, 2, 3, 4, 5, 10, 15, 30, 45, 75, and 90, as set by iterstosee.

In [10]:
iterstosee = [0, 1, 2, 3, 4, 5, 10, 15, 30, 45, 75, 90]
x, fx, its, ret_tris = nelder_mead(rosenbrock, x0=np.array([-1.5, -1.]),
                                   tol=1e-12, iter_returns=iterstosee)

fig, axs = plt.subplots(nrows=6, ncols=2, figsize=(16, 24))
axs = axs.flatten()

# Color mesh with the current simplex overlaid
for i, curr_ax in enumerate(axs):
    curr_simplex = np.vstack([ret_tris[i][0][0], ret_tris[i][1][0],
                              ret_tris[i][2][0]])
    curr_ax.pcolormesh(X, Y, Z, cmap=matplotlib.cm.viridis,
                       norm=LogNorm())
    curr_ax.set_title("Simplex for iteration %i" % iterstosee[i])
    curr_ax.scatter(curr_simplex[:, 0], curr_simplex[:, 1])

plt.tight_layout()
plt.show()

With the default parameters $a = 1$ and $b = 100$, the minimum of the Rosenbrock banana function is $f(1, 1) = 0$. I print the $x$ and $f(x)$ values below. It seems that our minimization algorithm worked!

In [11]:
printst = "The minimization algorithm converged to x={} with a function value of f(x)={}"
print("\n" + printst.format(x, fx))
The minimization algorithm converged to x=[ 1.00000141  1.00000287] with a function value of f(x)=2.286020259542178e-12
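As a final sanity check, we could compare against Scipy's implementation; a minimal sketch, reusing the opt alias from the imports cell (the tolerances here are illustrative):

# Scipy's fmin also uses Nelder-Mead, so the answers should agree
x_scipy = opt.fmin(rosenbrock, x0=np.array([-1.5, -1.]),
                   xtol=1e-12, ftol=1e-12)
print(x_scipy)  # should also be close to [1., 1.]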