{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
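    "# Global optimization of the egg-holder function\n",
    "\n",
    "Compares SciPy's global optimizers (dual annealing, differential evolution, basin hopping and SHGO) on the egg-holder (`egg_carton`) test function, whose surface has many local minima.\n",
    "\n",
    "Adapted from: https://docs.scipy.org/doc/scipy/reference/tutorial/optimize.html#global-optimization"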
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from scipy import optimize\n",
    "from scipy.optimize import OptimizeResult"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def egg_carton(x):\n",
    "    \"\"\"Egg-holder test function evaluated at ``(x[0], x[1])``.\n",
    "\n",
    "    ``x[0]`` and ``x[1]`` may be scalars or arrays of the same shape (e.g. the two\n",
    "    layers of a stacked meshgrid); the result then has that same shape.\n",
    "    \"\"\"\n",
    "    px, py = x[0], x[1]\n",
    "    shifted_y = py + 47\n",
    "    first = shifted_y * np.sin(np.sqrt(np.abs(px / 2 + shifted_y)))\n",
    "    second = px * np.sin(np.sqrt(np.abs(px - shifted_y)))\n",
    "    return -first - second"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_egg_carton(bounds):\n",
    "    \"\"\"Draw egg_carton as a 3-D surface over an integer grid spanning ``bounds``.\n",
    "\n",
    "    Each axis runs in unit steps from ``bounds[k][0]`` (inclusive) to\n",
    "    ``bounds[k][1]`` (exclusive), following ``np.arange``.\n",
    "\n",
    "    Returns the stacked ``(xgrid, ygrid)`` meshgrid, shape ``(2, ny, nx)``,\n",
    "    so later plots can reuse the same grid.\n",
    "    \"\"\"\n",
    "    x_axis = np.arange(bounds[0][0], bounds[0][1])\n",
    "    y_axis = np.arange(bounds[1][0], bounds[1][1])\n",
    "    mesh = np.meshgrid(x_axis, y_axis)\n",
    "    grid = np.stack(mesh)\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=[10, 10], subplot_kw={'projection': '3d'})\n",
    "    ax.view_init(45, -45)  # (elevation, azimuth) in degrees\n",
    "    ax.plot_surface(mesh[0], mesh[1], egg_carton(grid), cmap='terrain')\n",
    "    ax.set_xlabel('x')\n",
    "    ax.set_ylabel('y')\n",
    "    ax.set_zlabel('egg_carton(x, y)')\n",
    "\n",
    "    plt.show()\n",
    "    return grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Search box for (x, y): both coordinates range over [-512, 512]\n",
    "bounds = [(-512, 512)] * 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw the 3-D surface and keep the stacked (x, y) grid for the heatmap below\n",
    "xy = plot_egg_carton(bounds=bounds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_heatmap(xy, bounds, results, margin=4):\n",
    "    \"\"\"Show egg_carton over the grid ``xy`` as a heatmap and mark optimizer results.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    xy : array_like of shape (2, ny, nx)\n",
    "        Stacked ``(xgrid, ygrid)`` meshgrid, e.g. as returned by ``plot_egg_carton``.\n",
    "        Grid points are assumed to be evenly spaced.\n",
    "    bounds : sequence of two ``(low, high)`` pairs\n",
    "        Search bounds for x and y; used for the axis limits.\n",
    "    results : dict\n",
    "        Optimizer results keyed by ``'BH'``, ``'DE'``, ``'DA'``, ``'shgo'`` and\n",
    "        ``'shgo_sobol'``; each value needs an ``x`` attribute (``'shgo_sobol'``\n",
    "        also ``xl``). Missing keys are skipped.\n",
    "    margin : float, optional\n",
    "        Padding (in x/y units) added around ``bounds`` for the axis limits.\n",
    "    \"\"\"\n",
    "    xy = np.asarray(xy)\n",
    "    x_vals = xy[0, 0, :]\n",
    "    y_vals = xy[1, :, 0]\n",
    "\n",
    "    # Put pixel centres on the grid coordinates (half a step beyond each end) so\n",
    "    # results are drawn at their true (x, y) for any bounds, symmetric or not.\n",
    "    half_dx = (x_vals[1] - x_vals[0]) / 2 if x_vals.size > 1 else 0.5\n",
    "    half_dy = (y_vals[1] - y_vals[0]) / 2 if y_vals.size > 1 else 0.5\n",
    "    extent = [x_vals[0] - half_dx, x_vals[-1] + half_dx,\n",
    "              y_vals[0] - half_dy, y_vals[-1] + half_dy]\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=[10, 10])\n",
    "    im = ax.imshow(egg_carton(xy), interpolation='bilinear', origin='lower',\n",
    "                   cmap='gray', extent=extent)\n",
    "    fig.colorbar(im, ax=ax, shrink=0.8, label='egg_carton(x, y)')\n",
    "\n",
    "    # Best minimum reported by each optimizer: (results key, legend label, color, marker)\n",
    "    best_styles = [\n",
    "        ('BH', 'basin hopping', 'yellow', 'o'),\n",
    "        ('DE', 'differential evolution', 'cyan', 'o'),\n",
    "        ('DA', 'dual annealing', 'white', 'o'),\n",
    "        ('shgo', 'shgo', 'red', '+'),\n",
    "        ('shgo_sobol', 'shgo (sobol)', 'red', 'x'),\n",
    "    ]\n",
    "    for key, label, color, marker in best_styles:\n",
    "        if key in results:\n",
    "            best_x = results[key].x\n",
    "            ax.plot(best_x[0], best_x[1], linestyle='none', marker=marker,\n",
    "                    color=color, ms=10, label=label)\n",
    "\n",
    "    # SHGO also reports every local minimum it found (``xl``); plot them all with a\n",
    "    # smaller marker. reshape keeps an empty ``xl`` from breaking the column indexing.\n",
    "    if 'shgo_sobol' in results:\n",
    "        local_minima = np.asarray(results['shgo_sobol'].xl).reshape(-1, 2)\n",
    "        ax.plot(local_minima[:, 0], local_minima[:, 1], linestyle='none',\n",
    "                marker='.', color='red', ms=8, label='shgo (sobol) local minima')\n",
    "\n",
    "    ax.set_xlim(bounds[0][0] - margin, bounds[0][1] + margin)\n",
    "    ax.set_ylim(bounds[1][0] - margin, bounds[1][1] + margin)\n",
    "    ax.set_xlabel('x')\n",
    "    ax.set_ylabel('y')\n",
    "    ax.set_title('egg_carton with optimizer results')\n",
    "    ax.legend(loc='lower left', facecolor='lightgray')\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 42  # fixed seed so the stochastic optimizers give the same results on re-run\n",
    "\n",
    "results = dict()\n",
    "\n",
    "results['DA'] = optimize.dual_annealing(egg_carton, bounds, seed=SEED)\n",
    "results['DE'] = optimize.differential_evolution(egg_carton, bounds, seed=SEED)\n",
    "\n",
    "# basinhopping takes a starting point x0, not bounds: start from the centre of\n",
    "# the search box and keep each local minimization inside it via L-BFGS-B bounds.\n",
    "x0 = np.array([(low + high) / 2 for low, high in bounds])\n",
    "results['BH'] = optimize.basinhopping(egg_carton, x0, seed=SEED,\n",
    "                                      minimizer_kwargs={'method': 'L-BFGS-B',\n",
    "                                                        'bounds': bounds})\n",
    "\n",
    "results['shgo']       = optimize.shgo(egg_carton, bounds)\n",
    "results['shgo_sobol'] = optimize.shgo(egg_carton, bounds, n=200, iters=5,\n",
    "                                      sampling_method='sobol')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overlay every optimizer's result on a heatmap of egg_carton\n",
    "plot_heatmap(xy=xy, bounds=bounds, results=results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
