{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Optimizing the Mish Operator\n\nDaCeML allows users to optimize DNN modules at all levels of granularity, from operators to full models. In this\nexample, we optimize the Mish operator [1]_, a relatively novel activation function that,\namong other uses, has been applied successfully in object detection [2]_.\n\nDue to its novelty, it had not yet been implemented in PyTorch, ONNX or ONNX Runtime at the time of writing. We\ndemonstrate how DaCeML can be used to optimize this operator.\n\n.. [1] Diganta Misra. Mish: A self regularized non-monotonic activation function. In 31st British Machine Vision\n   Conference 2020, BMVC 2020, Virtual Event, UK, September 7-10, 2020. BMVA Press, 2020.\n.. [2] Alexey Bochkovskiy, Chien-Yao Wang, and Hong-Yuan Mark Liao. YOLOv4: Optimal speed and accuracy of object\n   detection. CoRR, abs/2004.10934, 2020.\n"
      ]
    },
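    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, Mish is defined elementwise as\n\n$$\\mathrm{Mish}(x) = x \\tanh\\left(\\mathrm{softplus}(x)\\right), \\qquad \\mathrm{softplus}(x) = \\ln\\left(1 + e^{x}\\right).$$\n\nApplying the chain rule, and using the fact that the derivative of softplus is the logistic sigmoid $\\sigma(x) = 1 / (1 + e^{-x})$, gives the derivative that the backward pass has to compute:\n\n$$\\frac{d}{dx}\\mathrm{Mish}(x) = \\tanh\\left(\\mathrm{softplus}(x)\\right) + x \\, \\sigma(x) \\left(1 - \\tanh^2\\left(\\mathrm{softplus}(x)\\right)\\right).$$\n"
      ]
    },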
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
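        "We begin with the PyTorch module and import it into DaCeML by annotating it with the ``@dace_module``\ndecorator. Passing ``cuda=True`` compiles the module for the GPU, and ``backward=True`` enables the backward pass, so that gradients are computed with DaCeML's automatic differentiation.\n\n"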
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom daceml.pytorch import dace_module\n\n\n@dace_module(cuda=True, backward=True)\nclass DaCeMish(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x):\n        x = x * (torch.tanh(F.softplus(x)))\n        return x"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The module works immediately with DaCeML for the forward pass.\n\nThe first time we tested this, however, automatic differentiation failed because :class:`~daceml.onnx.nodes.onnx_op.ONNXSoftplus`\nlacked a pure implementation. Fortunately, such implementations are easy to add using the DaCe Python frontend. The following code shows the pure implementation that was added.\n\n.. code-block:: python\n\n   @python_pure_op_implementation\n   def Softplus(X, Y):\n       Y[:] = np.log(1 + np.exp(X))\n\n"
      ]
    },
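    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The quoted implementation evaluates ``np.log(1 + np.exp(X))`` directly. This is fine for moderate inputs such as the ones used in this example, but in float32 ``np.exp`` overflows to ``inf`` for inputs above roughly 88. The following NumPy sketch, which is independent of DaCeML, contrasts this form with the mathematically equivalent and numerically stable ``np.logaddexp(0, X)`` (shown only as an illustration, not as what DaCeML uses). It also checks numerically that the derivative of softplus, which automatic differentiation needs for the backward pass, is the logistic sigmoid."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "x = np.array([-100.0, 0.0, 100.0], dtype=np.float32)\n",
        "\n",
        "# naive form, as in the pure implementation above\n",
        "with np.errstate(over='ignore'):\n",
        "    naive = np.log(1 + np.exp(x))\n",
        "\n",
        "# stable form: logaddexp(0, x) = log(exp(0) + exp(x)) without overflow\n",
        "stable = np.logaddexp(0.0, x)\n",
        "\n",
        "print('naive: ', naive)\n",
        "print('stable:', stable)\n",
        "\n",
        "# softplus'(x) should equal sigmoid(x); compare with central differences\n",
        "xs = np.linspace(-5.0, 5.0, 101)\n",
        "eps = 1e-6\n",
        "numeric = (np.logaddexp(0.0, xs + eps) - np.logaddexp(0.0, xs - eps)) / (2 * eps)\n",
        "sigmoid = 1.0 / (1.0 + np.exp(-xs))\n",
        "print('max abs error:', np.max(np.abs(numeric - sigmoid)))"
      ]
    },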
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's test the operator and compare its output and input gradient against an equivalent plain PyTorch module.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class Mish(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x):\n        x = x * (torch.tanh(F.softplus(x)))\n        return x\n\n\n# create test inputs (size taken from YOLOv4)\nwith torch.no_grad():\n    dace_input = torch.rand(8, 32, 224, 224).cuda()\n    torch_input = torch.clone(dace_input)\n    dace_dy = torch.rand(8, 32, 224, 224).cuda()\n    torch_dy = torch.clone(dace_dy)\n\ndace_input.requires_grad = True\ntorch_input.requires_grad = True\n\ntorch_mish = Mish().cuda()\ndace_mish = DaCeMish()\n\ndace_output = dace_mish(dace_input)\ndace_output.backward(dace_dy)\ntorch_output = torch_mish(torch_input)\ntorch_output.backward(torch_dy)\n\nassert torch.allclose(dace_output, torch_output)\nassert torch.allclose(dace_input.grad, torch_input.grad)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's profile this implementation and compare it with PyTorch. Each timed call runs one forward and one backward pass.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from daceml.testing.profiling import time_funcs, print_time_statistics\n\n\ndef run_dace():\n    out = dace_mish(dace_input)\n    out.backward(dace_dy)\n\n\ndef run_torch():\n    out = torch_mish(torch_input)\n    out.backward(torch_dy)\n\n\ntimes = time_funcs([run_dace, run_torch],\n                   func_names=[\"dace\", \"torch\"],\n                   warmups=5,\n                   num_iters=100)\nprint_time_statistics(times, [\"dace\", \"torch\"])"
      ]
    },
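    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Judging by its arguments, ``time_funcs`` runs each function for a number of untimed warm-up iterations and then for ``num_iters`` timed iterations. The cell below is a hypothetical, pure-Python sketch of this kind of measurement; ``time_funcs_sketch`` and ``report`` are illustrative names, not DaCeML functions, and DaCeML's own implementation may differ. One detail matters for GPU code: CUDA kernels are launched asynchronously, so a host-side timer only captures kernel execution if the device is synchronized, for example with ``torch.cuda.synchronize()``, before the clock is read."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import statistics\n",
        "import time\n",
        "\n",
        "\n",
        "def time_funcs_sketch(funcs, warmups=5, num_iters=100, sync=None):\n",
        "    # Hypothetical stand-in for time_funcs, not DaCeML's implementation.\n",
        "    # Returns one list of per-iteration runtimes (in ms) per function.\n",
        "    sync = sync or (lambda: None)  # e.g. torch.cuda.synchronize for GPU work\n",
        "    all_times = []\n",
        "    for func in funcs:\n",
        "        for _ in range(warmups):\n",
        "            func()  # untimed: absorbs one-time costs such as compilation\n",
        "        sync()\n",
        "        times = []\n",
        "        for _ in range(num_iters):\n",
        "            start = time.perf_counter()\n",
        "            func()\n",
        "            sync()  # wait for asynchronous work before stopping the clock\n",
        "            times.append((time.perf_counter() - start) * 1000.0)\n",
        "        all_times.append(times)\n",
        "    return all_times\n",
        "\n",
        "\n",
        "def report(all_times, names):\n",
        "    for name, times in zip(names, all_times):\n",
        "        print(f'{name}: median {statistics.median(times):.3f} ms, '\n",
        "              f'min {min(times):.3f} ms over {len(times)} runs')\n",
        "\n",
        "\n",
        "# usage with two trivial CPU-only workloads\n",
        "calls = []\n",
        "res = time_funcs_sketch([lambda: calls.append(1), lambda: sum(range(1000))],\n",
        "                        warmups=2, num_iters=10)\n",
        "report(res, ['append', 'sum'])"
      ]
    },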
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Inspection\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Let's inspect the forward pass SDFG first.\ndace_mish.forward_sdfg"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
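        "We can see that there is a lot of unnecessary data movement in the forward pass: each operator is a separate map that writes its result to an intermediate array in GPU global memory, which the next map then reads back. Fusing the maps into a single kernel would\nremove these round-trips and should greatly improve runtime.\n\nNow let's look at the backward pass.\n\n"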
      ]
    },
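    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To put a rough number on this, consider the forward pass on the test input of shape (8, 32, 224, 224) in float32, and assume, as a simplification, one map per ONNX operator: ``Softplus`` reads ``x`` and writes a temporary, ``Tanh`` reads that temporary and writes another, and ``Mul`` reads ``x`` and the second temporary and writes the output. The back-of-the-envelope sketch below counts the resulting global-memory traffic and compares it with a single fused kernel that reads ``x`` once and writes the output once. It ignores caching, and the pure Softplus expansion may in practice produce more than one map."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "shape = (8, 32, 224, 224)  # test input shape used above\n",
        "bytes_per_elem = 4  # float32\n",
        "\n",
        "num_elems = 1\n",
        "for dim in shape:\n",
        "    num_elems *= dim\n",
        "tensor_bytes = num_elems * bytes_per_elem\n",
        "\n",
        "# (full-tensor reads, full-tensor writes) per map; assumes one map per operator\n",
        "unfused = {'Softplus': (1, 1), 'Tanh': (1, 1), 'Mul': (2, 1)}\n",
        "fused = {'Mish': (1, 1)}\n",
        "\n",
        "\n",
        "def traffic(maps):\n",
        "    # total bytes moved to and from global memory, ignoring caches\n",
        "    return sum(reads + writes for reads, writes in maps.values()) * tensor_bytes\n",
        "\n",
        "\n",
        "mib = 2 ** 20\n",
        "print(f'tensor size:     {tensor_bytes / mib:.1f} MiB')\n",
        "print(f'unfused traffic: {traffic(unfused) / mib:.1f} MiB')\n",
        "print(f'fused traffic:   {traffic(fused) / mib:.1f} MiB')\n",
        "print(f'reduction:       {traffic(unfused) / traffic(fused):.1f}x')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Under these assumptions, the unfused forward pass moves 7 full tensors (343 MiB) and the fused kernel moves 2 (98 MiB), a 3.5x reduction in memory traffic for a memory-bound operator. Actual savings depend on caching and on how the operators are expanded."
      ]
    },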
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dace_mish.backward_sdfg"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We also see another opportunity for optimization: the DaCeML autodiff engine is \"forwarding\" intermediate values to\nperform the differentiation. This means that the intermediate values have to be written out in the forward pass and\nread back in the backward pass.\n\n## Optimization\n\nTo improve the runtime, we'll apply three transformations.\n\nFirst, we'll use ``SubgraphFusion`` to fuse all the maps into a single kernel.\nTo avoid forwarding intermediate values to the backward pass, we'll use the :class:`~daceml.transformation.TaskletFusion` transformation. By fusing the\ntasklets into a single tasklet before running automatic differentiation, the engine will differentiate the whole\nexpression at once, so the intermediate values are recomputed instead of being stored and reloaded. This is an easy way to tune recomputation\nvs. storage in automatic differentiation.\n\nFinally, we'll apply ``Vectorization`` to make our kernels operate on more than one element at once.\n\n"
      ]
    },
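    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The recomputation vs. storage trade-off can be illustrated without DaCeML. In the sketch below (a NumPy illustration with hypothetical function names, not a model of DaCeML's generated code), ``forward_storing`` keeps ``tanh(softplus(x))`` for the backward pass, mirroring the forwarded intermediates of the unfused version, while ``backward_recompute`` needs only ``x`` and recomputes everything, mirroring a single fused expression. Both apply the chain rule through ``Mul``, ``Tanh`` and ``Softplus`` and must produce the same input gradient."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def forward_storing(x):\n",
        "    # unfused style: keep intermediates around for the backward pass\n",
        "    t = np.tanh(np.log1p(np.exp(x)))  # tanh(softplus(x))\n",
        "    y = x * t\n",
        "    return y, (x, t)\n",
        "\n",
        "\n",
        "def backward_from_stored(dy, saved):\n",
        "    # chain rule through Mul, Tanh and Softplus, reusing the stored tanh value\n",
        "    x, t = saved\n",
        "    d_softplus = dy * x * (1.0 - t * t)\n",
        "    return dy * t + d_softplus / (1.0 + np.exp(-x))  # softplus' = sigmoid\n",
        "\n",
        "\n",
        "def backward_recompute(dy, x):\n",
        "    # fused style: only x is needed, intermediates are recomputed\n",
        "    t = np.tanh(np.log1p(np.exp(x)))\n",
        "    sigmoid = 1.0 / (1.0 + np.exp(-x))\n",
        "    return dy * (t + x * sigmoid * (1.0 - t * t))\n",
        "\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "x = rng.random(1000)  # same [0, 1) range as the test inputs above\n",
        "dy = rng.random(1000)\n",
        "\n",
        "y, saved = forward_storing(x)\n",
        "g_stored = backward_from_stored(dy, saved)\n",
        "g_recomputed = backward_recompute(dy, x)\n",
        "print('max abs difference:', np.max(np.abs(g_stored - g_recomputed)))"
      ]
    },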
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from daceml.transformation import TaskletFusion\nfrom dace.transformation.dataflow import Vectorization, TrivialMapRangeElimination\nfrom dace.transformation.subgraph import SubgraphFusion\nfrom daceml.util import utils\nfrom dace.library import change_default\nfrom daceml import onnx as donnx\n\n# reset the compiled SDFG so that the hooks below take effect on the next call\ndace_mish.reset_sdfg()\n\n\n# expand the ONNX nodes and apply automatic transformations such as inlining\ndef expand_and_strict_transforms(module):\n    # use the pure expansions of operators\n    with change_default(donnx, \"pure\"):\n        utils.auto_optimize(module.sdfg, cuda=True, apply_strict=True)\n\n\ndace_mish.append_post_onnx_hook(\"auto_optimize\", expand_and_strict_transforms)\n\n\n# apply subgraph fusion to all nodes of the first state\ndef fuse_sg(module):\n    sdfg = module.sdfg\n    sdfg.apply_transformations_repeated(TrivialMapRangeElimination)\n    SubgraphFusion.apply_to(sdfg, *sdfg.node(0).nodes())\n\n\ndace_mish.append_post_onnx_hook(\"subgraph_fusion\", fuse_sg)\n\n\n# apply tasklet fusion, so that autodiff sees a single fused expression\ndef fuse_tasklets(module):\n    module.dace_model.sdfg.apply_transformations_repeated(TaskletFusion)\n\n\ndace_mish.append_post_onnx_hook(\"fuse_tasklets\", fuse_tasklets)\n\n\n# apply vectorization to the forward and backward SDFGs produced by autodiff\ndef vectorize(fwd, bwd):\n    fwd.apply_transformations(Vectorization)\n    bwd.apply_transformations(Vectorization)\n\n\ndace_mish.append_post_autodiff_hook(\"vectorize\", vectorize)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's check that the new SDFG is still correct.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# clear the gradients accumulated by the previous runs\ndace_input.grad = None\ntorch_input.grad = None\n\ndace_output = dace_mish(dace_input)\ndace_output.backward(dace_dy)\ntorch_output = torch_mish(torch_input)\ntorch_output.backward(torch_dy)\n\nassert torch.allclose(dace_output, torch_output)\nassert torch.allclose(dace_input.grad, torch_input.grad)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "After running the module once, we can also inspect the compiled SDFGs for the forward and backward passes.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dace_mish.forward_sdfg"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "dace_mish.backward_sdfg"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now we can profile the optimized module.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "times = time_funcs([run_dace, run_torch],\n                   func_names=[\"dace\", \"torch\"],\n                   warmups=5,\n                   num_iters=100)\nprint_time_statistics(times, [\"dace\", \"torch\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's also try PyTorch JIT compilation.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch.jit\n\ntorch_jit = torch.jit.trace(Mish(), torch_input)\n\n\ndef run_torch_jit():\n    out = torch_jit(torch_input)\n    out.backward(torch_dy)\n\n\ntimes = time_funcs([run_dace, run_torch, run_torch_jit],\n                   func_names=[\"dace\", \"torch\", \"torch_jit\"],\n                   warmups=5,\n                   num_iters=100)\nprint_time_statistics(times, [\"dace\", \"torch\", \"torch_jit\"])"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}