Automatic Differentiation
Warning
The symbolic automatic differentiation feature is still experimental.
DaCeML takes a different approach to automatic differentiation than most deep learning frameworks. Instead of hand-writing backward passes for all differentiable operators, DaCeML has a symbolic reverse-mode differentiation engine.
Using Autodiff
There are two main ways to generate backward passes in DaCeML.
DaceModule
This class includes a backward parameter. If True, the autodiff engine will be used to add a backward pass to the PyTorch module, and the resulting module can be seamlessly used with other PyTorch code. For example:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from daceml.pytorch import dace_module

@dace_module(backward=True)
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 120)
        self.fc2 = nn.Linear(120, 32)
        self.fc3 = nn.Linear(32, 10)
        self.ls = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = self.ls(x)
        return x

x = torch.randn(8, 784)
y = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.long)

model = Net()
criterion = nn.NLLLoss()

prediction = model(x)
loss = criterion(prediction, y)
print(f"gradients before: {model.model.fc3.weight.grad}")

# gradients can flow through model!
loss.backward()

print(f"gradients after: {model.model.fc3.weight.grad}")
```
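Once loss.backward() has populated the gradients, they can be consumed like any other PyTorch gradients. A minimal sketch of a training step, continuing the example above and assuming the wrapped module remains reachable via model.model (as in the print statements):

```python
import torch.optim as optim

# Continuing the example above: take one SGD step on the wrapped module's parameters.
optimizer = optim.SGD(model.model.parameters(), lr=0.1)

optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
```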
The autodiff engine can also be run on plain SDFGs. Here, the output S of the dace function/SDFG is differentiated w.r.t. X and Y.

```python
import dace
from daceml.autodiff import add_backward_pass

@dace.program
def dace_gemm(
    X: dace.float32[5, 4],
    Y: dace.float32[4, 3],
    Z: dace.float32[5, 3],
    S: dace.float32[1],
):
    Z[:] = X @ Y

    @dace.map(_[0:5, 0:3])
    def summap(i, j):
        s >> S(1, lambda x, y: x + y)[0]
        z << Z[i, j]
        s = z

sdfg = dace_gemm.to_sdfg()

add_backward_pass(sdfg=sdfg, state=sdfg.nodes()[0], inputs=["X", "Y"], outputs=["S"])
```
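Because S = sum(X @ Y), the expected gradients can be written down in closed form. The following plain-NumPy snippet (not part of the DaCeML API) computes them as a reference against which the generated backward pass can be checked:

```python
import numpy as np

# Reference gradients for S = sum(X @ Y):
#   dS/dZ is all ones, so dS/dX = ones(5, 3) @ Y.T and dS/dY = X.T @ ones(5, 3).
X = np.random.rand(5, 4).astype(np.float32)
Y = np.random.rand(4, 3).astype(np.float32)
dZ = np.ones((5, 3), dtype=np.float32)

dX = dZ @ Y.T   # shape (5, 4), matches X
dY = X.T @ dZ   # shape (4, 3), matches Y
```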
Architecture
At its core, the automatic differentiation engine attempts to lift the SymPy scalar differentiation engine to tensor programs. The SDFG IR is especially suitable for this for two reasons:
- In most SDFGs, computation (i.e. Tasklets) operates on scalars, which can often be differentiated symbolically by SymPy (see the sketch below).
- The SDFG IR precisely specifies which Tasklets read and write to which memory locations. This information makes it simple to correctly sum the gradient contribution from each tasklet.
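To illustrate the first point, the body of a typical tasklet is a small scalar expression whose partial derivatives SymPy can produce directly. A minimal sketch (plain SymPy, not DaCeML's internal code):

```python
import sympy as sp

# A tasklet like `out = tanh(a * b)` is a scalar expression in its inputs,
# so its partial derivatives can be obtained symbolically.
a, b = sp.symbols("a b")
out = sp.tanh(a * b)

d_out_da = sp.diff(out, a)   # b*(1 - tanh(a*b)**2)
d_out_db = sp.diff(out, b)   # a*(1 - tanh(a*b)**2)
```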
At a high level, it operates as follows:
1. Find the AccessNode for each input and output of the SDFGState. Use these to determine the subgraph to differentiate.
2. Traverse the subgraph in reverse topological order (a toy sketch of this traversal follows the list). For each node:
   1. Call a function that reverses the node. To reverse the node, the engine checks the BackwardImplementation repository for a registered & applicable backward implementation for that node. If no such function exists and the node is a LibraryNode, attempt to differentiate the pure expanded version of the node. Otherwise, call the relevant function on BackwardGenerator. Main subtleties here are clarified in Extending the Engine. Note that this includes a recursive call for NestedSDFG nodes (forwarding intermediate values is a source of complexity here).
   2. Connect required inputs. This includes gradients of outputs of the node, as well as the values of inputs of the node (which potentially need to be routed through reversed maps, or through the hierarchy of NestedSDFGs).
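The toy sketch below (plain SymPy, with a made-up node structure, not DaCeML's implementation) mimics this reverse traversal: each "node" is a scalar assignment, and the gradient contributions from every consumer are summed, as in step 2 above.

```python
import sympy as sp

# Toy reverse-mode accumulation over a tiny "graph" of scalar assignments,
# processed in reverse topological order.
x, y, a, b, out = sp.symbols("x y a b out")

nodes = [
    (a, x * y),        # a = x * y
    (b, sp.sin(a)),    # b = sin(a)
    (out, a + b),      # out = a + b
]

# Seed the gradient of the final output, then walk the nodes in reverse,
# summing each node's contribution into the gradients of the symbols it reads.
grads = {out: sp.Integer(1)}
for lhs, rhs in reversed(nodes):
    for read in rhs.free_symbols:
        grads[read] = grads.get(read, 0) + grads[lhs] * sp.diff(rhs, read)

print(grads[x])   # y*(cos(a) + 1), i.e. y*(cos(x*y) + 1) after substituting a = x*y
print(grads[y])   # x*(cos(a) + 1)
```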
Extending the Engine
When attempting to differentiate a LibraryNode, the engine will recursively expand the node until it is in a form that the engine can differentiate. Usually, this means that the engine will expand the node down to the “pure” implementation consisting of simple tasklets and maps.
However, it is sometimes desirable to “exit” this expansion process at a stage earlier than the lowest level. For instance, consider differentiating the ONNXMatMul library node. Since no backward implementation exists for this node, it will be expanded to its pure version, an ONNXEinsum. Fully expanding this node into its pure form would result in a mapped tasklet, which we could differentiate. However, we would like to use BLAS nodes on the forward and backward pass where possible. To achieve this, a custom backward implementation is registered for ONNXEinsum, which returns a NestedSDFG containing other einsums. Since we avoid lowering to the lowest level, we are able to preserve information, and can later potentially expand both the forward and backward pass einsums to more efficient BLAS calls.
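As a concrete illustration (plain NumPy, not DaCeML code): the backward pass of a matrix-multiplication einsum consists of einsums again, which is precisely what keeps BLAS expansion possible for both passes.

```python
import numpy as np

# Forward: C = einsum('ik,kj->ij', A, B), i.e. C = A @ B.
A = np.random.rand(5, 4)
B = np.random.rand(4, 3)
dC = np.random.rand(5, 3)           # incoming gradient dL/dC

# The gradients are einsums again, so they stay eligible for BLAS expansion:
dA = np.einsum('ij,kj->ik', dC, B)  # dL/dA = dC @ B.T
dB = np.einsum('ik,ij->kj', A, dC)  # dL/dB = A.T @ dC

assert np.allclose(dA, dC @ B.T)
assert np.allclose(dB, A.T @ dC)
```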
Another example is ONNXSoftmax: a typical implementation includes a maximum operation for numerical stability. Differentiating this implementation results in several argmax calls, which is not desirable.
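For reference, the standard closed-form backward rule avoids the max entirely: with y = softmax(x) and incoming gradient dy, dx = y * (dy - sum(dy * y)). A quick NumPy sketch (illustrative only, not DaCeML's implementation):

```python
import numpy as np

def softmax_backward(y, dy):
    # y = softmax(x) from the forward pass, dy = dL/dy.
    # dL/dx = y * (dy - sum(dy * y)); no max/argmax appears in the backward pass.
    return y * (dy - np.sum(dy * y, axis=-1, keepdims=True))

x = np.random.rand(8, 10)
e = np.exp(x - x.max(axis=-1, keepdims=True))   # numerically stable forward pass
y = e / e.sum(axis=-1, keepdims=True)

dx = softmax_backward(y, np.random.rand(8, 10))
```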
In situations like these, it makes sense to provide a custom backward pass implementation. These implementations are registered using BackwardImplementation, which requires implementing backward(). Examples of this are daceml.autodiff.implementations.onnx_ops.DefaultEinsumBackward and daceml.autodiff.implementations.onnx_ops.DefaultSoftmaxBackward.