Optimizing the Mish Operator

DaCeML allows users to optimize DNN modules at all levels of granularity, from operators to full models. In this example, we optimize the Mish operator [1], a relatively novel activation function that, among other uses, has been applied successfully in object detection [2].

Because Mish is so new, it had not been implemented in PyTorch, ONNX, or ONNX Runtime at the time of writing. We demonstrate how DaCeML can be used to optimize this operator.
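
For reference, the operator can be written in closed form as below. This is the same expression that the PyTorch modules in this example compute in their forward methods; the notation is ours.

```latex
\operatorname{softplus}(x) = \ln\left(1 + e^{x}\right), \qquad
\operatorname{Mish}(x) = x \cdot \tanh\left(\operatorname{softplus}(x)\right)
```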

[1] Diganta Misra. Mish: A self regularized non-monotonic activation function. In 31st British Machine Vision Conference 2020, BMVC 2020, Virtual Event, UK, September 7-10, 2020. BMVA Press, 2020.

[2] Alexey Bochkovskiy, Chien-Yao Wang, and Hong-Yuan Mark Liao. YOLOv4: Optimal speed and accuracy of object detection. CoRR, abs/2004.10934, 2020.

We begin with code for the PyTorch Module, and import it into DaCeML by annotating it with the @dace_module decorator.

import torch
from torch import nn
from torch.nn import functional as F

from daceml.pytorch import dace_module


@dace_module(cuda=True, backward=True)
class DaCeMish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x * (torch.tanh(F.softplus(x)))
        return x

The module works immediately with DaCeML for the forward pass.

The first time we tested this, automatic differentiation failed because of a missing pure implementation for ONNXSoftplus. Fortunately, such implementations are easy to add using the DaCe Python frontend. The following code shows the pure implementation that was added.

@python_pure_op_implementation
def Softplus(X, Y):
    Y[:] = np.log(1 + np.exp(X))
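
The expression above is the textbook form of softplus. In plain NumPy, np.exp(X) overflows for large inputs, so a common alternative is the rearranged form sketched below. This is only an illustration in standalone NumPy: the names softplus_naive and softplus_stable are hypothetical, and we have not checked whether the DaCe Python frontend accepts this variant inside a pure implementation.

```python
import numpy as np


def softplus_naive(x):
    # Direct translation of log(1 + exp(x)); exp(x) overflows for large x.
    return np.log(1 + np.exp(x))


def softplus_stable(x):
    # Uses the identity log(1 + exp(x)) = max(x, 0) + log1p(exp(-|x|)),
    # so the argument of exp is never positive.
    return np.maximum(x, 0) + np.log1p(np.exp(-np.abs(x)))


x = np.array([-5.0, 0.0, 5.0])
# Both forms agree on moderate inputs.
agree = np.allclose(softplus_naive(x), softplus_stable(x))
```

For inputs in the range produced by torch.rand, as used in the test below, the two forms are numerically equivalent; the difference only matters for inputs of large magnitude.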

Let’s test the operator and compare it with a PyTorch version.

class Mish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = x * (torch.tanh(F.softplus(x)))
        return x


# create test inputs (size taken from YOLOv4)
with torch.no_grad():
    dace_input = torch.rand(8, 32, 224, 224).cuda()
    torch_input = torch.clone(dace_input)
    dace_dy = torch.rand(8, 32, 224, 224).cuda()
    torch_dy = torch.clone(dace_dy)

dace_input.requires_grad = True
torch_input.requires_grad = True

torch_mish = Mish().cuda()
dace_mish = DaCeMish()

dace_output = dace_mish(dace_input)
dace_output.backward(dace_dy)
torch_output = torch_mish(torch_input)
torch_output.backward(torch_dy)

assert torch.allclose(dace_output, torch_output)
assert torch.allclose(dace_input.grad, torch_input.grad)

Let’s profile this implementation.

from daceml.testing.profiling import time_funcs, print_time_statistics


def run_dace():
    out = dace_mish(dace_input)
    out.backward(dace_dy)


def run_torch():
    out = torch_mish(torch_input)
    out.backward(torch_dy)


times = time_funcs([run_dace, run_torch],
                   func_names=["dace", "torch"],
                   warmups=5,
                   num_iters=100)
print_time_statistics(times, ["dace", "torch"])

Out:

| Name   |     Min |    Mean |   Median |   Stdev |     Max |
|--------|---------|---------|----------|---------|---------|
| dace   | 11.6968 | 12.6615 |  12.2345 |  1.7286 | 26.7186 |
| torch  |  4.7779 |  4.8643 |   4.8738 |  0.0264 |  4.8814 |

Inspection

# Let's inspect the forward pass SDFG first.
dace_mish.forward_sdfg


We can see that there is a lot of unnecessary data movement on the forward pass. Fusing the different maps would greatly improve runtime.

Now let’s look at the backward pass.

dace_mish.backward_sdfg


We also see another opportunity for optimization: the DaCeML autodiff engine is “forwarding” intermediate values to perform the differentiation. This means that the intermediate values have to be written out in the forward pass and read back in the backward pass.

Optimization

To improve the runtime, we’ll apply three transformations.

Firstly, we’ll use SubgraphFusion to fuse all the maps into a single kernel. To tackle the issue of forwarding intermediate values in backprop, we’ll use the TaskletFusion transformation. By fusing the tasklets into a single tasklet before running automatic differentiation, the engine will differentiate the whole expression at once, eliminating the need to access the intermediate values. This is an easy way to tune recomputation vs. storage in automatic differentiation.

Finally, we’ll apply Vectorization to make our kernels operate on more than one element at once.
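
The recomputation versus storage trade-off described above can be sketched in plain NumPy. This is not the code DaCeML generates; the function names are hypothetical and serve only to contrast a backward pass that reads a stored intermediate with one that recomputes everything from the input.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def sigmoid(x):
    # Derivative of softplus, written via tanh to avoid overflow.
    return 0.5 * (1.0 + np.tanh(0.5 * x))


def forward_store(x):
    # Storage variant: the intermediate t = tanh(softplus(x)) is written out
    # so that the backward pass can read it.
    t = np.tanh(softplus(x))
    return x * t, t


def backward_store(x, t, dy):
    # d/dx [x * t] = t + x * (1 - t^2) * sigmoid(x), using the stored t.
    return dy * (t + x * (1.0 - t * t) * sigmoid(x))


def forward_recompute(x):
    # Recomputation variant: only the output leaves the forward pass.
    return x * np.tanh(softplus(x))


def backward_recompute(x, dy):
    # The whole expression is differentiated at once; t is recomputed from x.
    t = np.tanh(softplus(x))
    return dy * (t + x * (1.0 - t * t) * sigmoid(x))


x = np.linspace(-3.0, 3.0, 7)
dy = np.ones_like(x)
y, t = forward_store(x)
grad_store = backward_store(x, t, dy)
grad_recompute = backward_recompute(x, dy)
```

Both variants produce the same gradient. The first one moves an extra array the size of the input between the two passes, while the second one trades that memory traffic for a few additional arithmetic operations.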

from daceml.transformation import TaskletFusion
from dace.transformation.dataflow import Vectorization, TrivialMapRangeElimination
from dace.transformation.subgraph import SubgraphFusion
from daceml.util import utils
from dace.library import change_default
from daceml import onnx as donnx

# reset the compiled sdfg
dace_mish.reset_sdfg()


# expand the onnx nodes, and apply automatic transformations like inlining
def expand_and_strict_transforms(module):
    # use the pure expansions of operators
    with change_default(donnx, "pure"):
        utils.auto_optimize(module.sdfg, cuda=True, apply_strict=True)


dace_mish.append_post_onnx_hook("auto_optimize", expand_and_strict_transforms)


# apply subgraph fusion
def fuse_sg(module):
    sdfg = module.sdfg
    sdfg.apply_transformations_repeated(TrivialMapRangeElimination)
    SubgraphFusion.apply_to(sdfg, *sdfg.node(0).nodes())


dace_mish.append_post_onnx_hook("subgraph_fusion", fuse_sg)

# apply tasklet fusion
dace_mish.append_post_onnx_hook(
    "fuse_tasklets",
    lambda x: x.dace_model.sdfg.apply_transformations_repeated(TaskletFusion))


# apply vectorization
def vectorize(fwd, bwd):
    fwd.apply_transformations(Vectorization)
    bwd.apply_transformations(Vectorization)


dace_mish.append_post_autodiff_hook("vectorize", vectorize)

Let’s check that the new SDFG is still correct.

dace_output = dace_mish(dace_input)
dace_output.backward(dace_dy)
torch_output = torch_mish(torch_input)
torch_output.backward(torch_dy)

assert torch.allclose(dace_output, torch_output)
assert torch.allclose(dace_input.grad, torch_input.grad)

Out:

/opt/actions-runner/_work/daceml/daceml/venv/lib/python3.7/site-packages/dace/sdfg/sdfg.py:1755: UserWarning: SDFG "DaCeMish" is already loaded by another object, recompiling under a different name.
  'recompiling under a different name.' % self.name)
/opt/actions-runner/_work/daceml/daceml/venv/lib/python3.7/site-packages/dace/sdfg/sdfg.py:1755: UserWarning: SDFG "DaCeMish_backward" is already loaded by another object, recompiling under a different name.
  'recompiling under a different name.' % self.name)

After running the module once, we can also inspect the compiled SDFG for the forward and backward pass.

dace_mish.forward_sdfg


dace_mish.backward_sdfg


Now we can profile the optimized module.

times = time_funcs([run_dace, run_torch],
                   func_names=["dace", "torch"],
                   warmups=5,
                   num_iters=100)
print_time_statistics(times, ["dace", "torch"])

Out:

| Name   |    Min |   Mean |   Median |   Stdev |    Max |
|--------|--------|--------|----------|---------|--------|
| dace   | 1.5453 | 1.5819 |   1.5854 |  0.0067 | 1.5911 |
| torch  | 4.2580 | 4.3152 |   4.3436 |  0.0398 | 4.3531 |

Let’s also try PyTorch JIT compilation.

import torch.jit

torch_jit = torch.jit.trace(Mish(), torch_input)


def run_torch_jit():
    out = torch_jit(torch_input)
    out.backward(torch_dy)


times = time_funcs([run_dace, run_torch, run_torch_jit],
                   func_names=["dace", "torch", "torch_jit"],
                   warmups=5,
                   num_iters=100)
print_time_statistics(times, ["dace", "torch", "torch_jit"])

Out:

| Name      |    Min |   Mean |   Median |   Stdev |    Max |
|-----------|--------|--------|----------|---------|--------|
| dace      | 1.5571 | 1.5885 |   1.5888 |  0.0035 | 1.5989 |
| torch     | 4.3418 | 4.3462 |   4.3459 |  0.0015 | 4.3517 |
| torch_jit | 3.8157 | 3.8183 |   3.8183 |  0.0010 | 3.8216 |
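
To put the tables in perspective, the short calculation below derives speedup factors from the reported mean runtimes. The numbers are copied from the outputs above; the unoptimized DaCe figure comes from the first profiling run, so the comparison spans separate measurements and should be read as approximate.

```python
# Mean runtimes copied from the profiling tables above (same unit in each).
mean_times = {
    "dace_unoptimized": 12.6615,  # first table, before the transformations
    "torch": 4.3462,              # final table
    "torch_jit": 3.8183,          # final table
}
dace_optimized = 1.5885           # final table

# Speedup of the optimized DaCe module relative to each baseline.
speedups = {name: t / dace_optimized for name, t in mean_times.items()}

for name, factor in speedups.items():
    print(f"{name}: {factor:.2f}x")
```

By this measure the optimized module is roughly 2.7 times faster than eager PyTorch, 2.4 times faster than the traced PyTorch module, and about 8 times faster than the unoptimized DaCe version.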

