daceml.autodiff
Generating Backward Passes
add_backward_pass(sdfg, state, outputs, inputs)

Experimental: add a backward pass to state using reverse-mode automatic differentiation.

inputs and outputs can be provided either as AccessNode nodes or as str. In the latter case, the graph is searched for exactly one AccessNode whose data matches the string.

The SDFG should not contain any in-place operations. It may contain the following nodes:
Maps
AccessNodes
Reductions (Sum, Min, Max)
ONNXOps
NestedSDFGs containing a single SDFGState (subject to the same constraints). A NestedSDFG may also contain additional states, as long as those states are only used for zero initialization.
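Reverse-mode differentiation propagates gradients from the chosen outputs back to the chosen inputs, visiting operations in the reverse of their forward order and accumulating contributions wherever a value is used more than once. The scalar sketch below illustrates only that principle with a hand-built tape. The names forward, backward and Tape are hypothetical and unrelated to daceml, which instead adds reverse nodes to the SDFG state.

```python
from typing import Dict, List, Tuple

# Hypothetical tape entry: (output name, [(operand name, local partial derivative), ...]).
Tape = List[Tuple[str, List[Tuple[str, float]]]]

def forward(x: float, w: float) -> Tuple[Dict[str, float], Tape]:
    """Compute y = x * w + x and record local derivatives on a tape."""
    vals = {"x": x, "w": w}
    tape: Tape = []
    vals["t"] = x * w
    tape.append(("t", [("x", w), ("w", x)]))        # dt/dx = w, dt/dw = x
    vals["y"] = vals["t"] + x
    tape.append(("y", [("t", 1.0), ("x", 1.0)]))    # dy/dt = 1, dy/dx = 1
    return vals, tape

def backward(tape: Tape, output: str, inputs: List[str]) -> Dict[str, float]:
    """Seed d(output)/d(output) = 1, then walk the tape in reverse,
    adding each contribution to the gradient of the corresponding operand."""
    grads: Dict[str, float] = {output: 1.0}
    for out, partials in reversed(tape):
        g = grads.get(out, 0.0)
        for name, local in partials:
            grads[name] = grads.get(name, 0.0) + g * local
    return {name: grads.get(name, 0.0) for name in inputs}

vals, tape = forward(x=3.0, w=5.0)
print(backward(tape, "y", ["x", "w"]))  # {'x': 6.0, 'w': 3.0}
```

Because x feeds both the multiplication and the addition, its gradient is the sum of two contributions (w + 1 = 6).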
When differentiating an ONNXOp, the ONNXBackward registry is checked first for a matching backward pass implementation. If none is found, the ONNXForward registry is checked for a matching pure implementation; if one exists, symbolic differentiation of that pure implementation is attempted. If this fails, or no pure forward implementation is found, the method fails.

Parameters:
- sdfg (SDFG) – the parent SDFG of state.
- state (SDFGState) – the state to add the backward pass to. This is also the state of the forward pass.
- outputs (List[Union[AccessNode, str]]) – the forward pass outputs of the function to differentiate.
- inputs (List[Union[AccessNode, str]]) – the inputs with respect to which the gradient will be returned.
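The lookup order for ONNXOp nodes can be summarized as a short decision procedure. In the sketch below, the two registries are plain dictionaries and symbolically_differentiate is a placeholder that reports failure by returning None. All of these names and the registry contents are hypothetical; they are not the ONNXBackward or ONNXForward APIs.

```python
from typing import Callable, Dict, Optional

def _loop_forward() -> str:
    # Hypothetical pure implementation that the differentiator cannot handle.
    raise ValueError("unsupported construct")

# Hypothetical registry contents keyed by ONNX op name (plain dicts, not daceml objects).
BACKWARD_REGISTRY: Dict[str, Callable[[], str]] = {
    "Relu": lambda: "hand-written Relu backward",
}
PURE_FORWARD_REGISTRY: Dict[str, Callable[[], str]] = {
    "Softplus": lambda: "pure Softplus forward",
    "Loop": _loop_forward,
}

def symbolically_differentiate(pure_impl: Callable[[], str]) -> Optional[str]:
    """Placeholder for symbolic differentiation: returns None if it fails."""
    try:
        return "derived from " + pure_impl()
    except ValueError:
        return None

def find_backward(op_name: str) -> str:
    # 1. Prefer a registered backward implementation.
    if op_name in BACKWARD_REGISTRY:
        return BACKWARD_REGISTRY[op_name]()
    # 2. Otherwise, try to differentiate a registered pure forward implementation.
    if op_name in PURE_FORWARD_REGISTRY:
        derived = symbolically_differentiate(PURE_FORWARD_REGISTRY[op_name])
        if derived is not None:
            return derived
    # 3. Differentiation failed, or no pure implementation exists.
    raise NotImplementedError(f"cannot differentiate {op_name}")

print(find_backward("Relu"))      # hand-written Relu backward
print(find_backward("Softplus"))  # derived from pure Softplus forward
```

Both "Loop" (symbolic differentiation fails) and an unregistered op such as "Conv" end in the failure branch.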
make_backward_function(model, apply_strict=False)

Convert an ONNXModel to a PyTorch differentiable function. This method should not be used on its own; instead, use the backward=True parameter of daceml.pytorch.DaceModule.

Parameters:
- model (ONNXModel) – the model to convert.
- apply_strict – whether to apply strict transformations before creating the backward pass.

Return type:
Tuple[SDFG, SDFG, BackwardResult, Dict[str, Data]]

Returns:
A 4-tuple of the forward SDFG, the backward SDFG, the backward result, and the input arrays for the backward pass (as a mapping from names to DaCe data descriptors).
Extending Autodiff
class BackwardImplementation

ABC for backward implementations of nodes and ONNX ops.

Implementations are registered in a registry that accepts two kinds of registration. The register function takes either an argument node_type=TYPE, where TYPE is the type of node that the backward implementation supports, or an argument op=node_name, where node_name is the name of the ONNX op it supports, e.g. "Conv". In both cases it also expects a name argument that names the implementation.
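The two registration forms can be pictured with a small keyed registry. The register and lookup functions and the class names below are hypothetical stand-ins written from the description above, not the daceml registration API. Rejecting calls that pass both or neither of node_type and op is an assumption of this sketch.

```python
from typing import Dict, Optional, Tuple

# Hypothetical registry: (node_type, op) -> {implementation name: implementation class}.
_REGISTRY: Dict[Tuple[Optional[type], Optional[str]], Dict[str, type]] = {}

def register(impl: type, *, name: str, node_type: Optional[type] = None,
             op: Optional[str] = None) -> type:
    # Assumption: exactly one of node_type / op identifies what impl supports.
    if (node_type is None) == (op is None):
        raise ValueError("pass exactly one of node_type= or op=")
    _REGISTRY.setdefault((node_type, op), {})[name] = impl
    return impl

def lookup(*, node_type: Optional[type] = None, op: Optional[str] = None) -> Dict[str, type]:
    """Return all implementations registered under the given key, by name."""
    return _REGISTRY.get((node_type, op), {})

class MapEntry: ...          # hypothetical node type
class ConvBackward: ...      # hypothetical backward implementations
class MapEntryBackward: ...

register(ConvBackward, op="Conv", name="default")
register(MapEntryBackward, node_type=MapEntry, name="default")
print(list(lookup(op="Conv")))  # ['default']
```

Keying by name as well as by node type or op allows several alternative implementations to coexist for the same node.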
abstract static backward(forward_node, context, given_gradients, required_gradients)

Add the reverse node for a node from the forward pass to the backward pass, and return it.

For each input connector n of the forward node listed in required_gradients, the returned backward node must add an output connector that outputs the gradient for that input; the name of that connector is reported as required_grad_names[n] in the returned BackwardResult.

If the backward computation needs an input of the forward node, simply add an input connector with the same name as the connector on the forward node. The input will be connected as required later.

Parameters:
- forward_node (Node) – the node for which the backward pass should be generated.
- context (BackwardContext) – the context for this node (see BackwardContext).
- given_gradients (List[Optional[str]]) – the names of the forward node's outputs for which gradients will be connected.
- required_gradients (List[Optional[str]]) – the names of the forward node's input connectors for which gradients should be generated.

Return type:
Tuple[Node, BackwardResult]

Returns:
The reverse node and the gradient connector names (see BackwardResult).
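The connector contract can be illustrated for a hypothetical node Y = Relu(X). Node, BackwardResult and ReluBackward below are simplified stand-ins, not the DaCe or daceml classes, and the "_grad" suffix is an arbitrary choice. The sketch only performs the connector bookkeeping: a real implementation would also place the computation dX = dY * (X > 0) inside the returned node and add it to the backward state.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple

@dataclass
class Node:
    """Hypothetical stand-in for a DaCe node: a label plus connector names."""
    label: str
    in_connectors: Set[str] = field(default_factory=set)
    out_connectors: Set[str] = field(default_factory=set)

@dataclass
class BackwardResult:
    """Stand-in with the two fields documented for BackwardResult."""
    required_grad_names: Dict[Optional[str], Optional[str]]
    given_grad_names: Dict[Optional[str], Optional[str]]

class ReluBackward:
    """Hypothetical backward implementation for Y = Relu(X)."""

    @staticmethod
    def backward(forward_node: Node, context: object,
                 given_gradients: List[Optional[str]],
                 required_gradients: List[Optional[str]]) -> Tuple[Node, BackwardResult]:
        rev = Node(label=forward_node.label + "_backward")
        # An input connector for each forward output whose gradient is given.
        given = {n: n + "_grad" for n in given_gradients if n is not None}
        # An output connector for each forward input whose gradient is required.
        required = {n: n + "_grad" for n in required_gradients if n is not None}
        rev.in_connectors.update(given.values())
        rev.out_connectors.update(required.values())
        # dX = dY * (X > 0) also reads the forward input X, so add a connector
        # with the same name as on the forward node; it is connected later.
        rev.in_connectors.add("X")
        return rev, BackwardResult(required_grad_names=required, given_grad_names=given)

fwd = Node("relu", in_connectors={"X"}, out_connectors={"Y"})
rev, result = ReluBackward.backward(fwd, None, ["Y"], ["X"])
print(sorted(rev.in_connectors), sorted(rev.out_connectors))  # ['X', 'Y_grad'] ['X_grad']
```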
class BackwardContext(forward_sdfg, forward_state, backward_sdfg, backward_state, backward_generator)

Bases: object

A container holding the graph context required to construct reverse nodes.

Parameters:
- forward_sdfg (dace.sdfg.sdfg.SDFG)
- forward_state (dace.sdfg.state.SDFGState)
- backward_sdfg (dace.sdfg.sdfg.SDFG)
- backward_state (dace.sdfg.state.SDFGState)
- backward_generator (daceml.autodiff.backward_pass_generator.BackwardPassGenerator)

Return type:
None
Attributes:
- backward_generator: daceml.autodiff.backward_pass_generator.BackwardPassGenerator – the backward pass generator.
- backward_sdfg: dace.sdfg.sdfg.SDFG – the backward SDFG.
- backward_state: dace.sdfg.state.SDFGState – the backward SDFG state.
- forward_sdfg: dace.sdfg.sdfg.SDFG – the forward SDFG.
- forward_state: dace.sdfg.state.SDFGState – the forward SDFG state.
class BackwardResult(required_grad_names, given_grad_names)

Bases: object

The return type of a differentiated node. It contains the names of the gradients the node calculates and requires.

Parameters:
- required_grad_names (Dict[Optional[str], Optional[str]])
- given_grad_names (Dict[Optional[str], Optional[str]])

Return type:
None
Attributes:
- given_grad_names: Dict[Optional[str], Optional[str]] – mapping from the names of the forward node's output connectors to the names of the reverse node's input connectors that receive the gradients for those outputs.
- required_grad_names: Dict[Optional[str], Optional[str]] – mapping from the names of the forward node's input connectors to the names of the reverse node's output connectors that produce the gradients for those inputs.
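As an illustration of how a consumer might use these two mappings, the hypothetical wire_gradients helper below turns them into edges for the reverse node of C = A * B. It assumes the dictionary keys are the forward node's connector names and that grad_arrays names one gradient array per forward connector. daceml's backward pass generator performs its own wiring, which is not reproduced here.

```python
from typing import Dict, List, Optional, Tuple

# (source node, source connector, destination node, destination connector)
Edge = Tuple[str, Optional[str], str, Optional[str]]

def wire_gradients(rev_node: str,
                   required_grad_names: Dict[str, str],
                   given_grad_names: Dict[str, str],
                   grad_arrays: Dict[str, str]) -> List[Edge]:
    """Hypothetical helper: connect gradient arrays to a reverse node."""
    edges: List[Edge] = []
    # Gradients of forward outputs flow into the reverse node.
    for fwd_out, rev_in in given_grad_names.items():
        edges.append((grad_arrays[fwd_out], None, rev_node, rev_in))
    # Gradients of forward inputs flow out of the reverse node.
    for fwd_in, rev_out in required_grad_names.items():
        edges.append((rev_node, rev_out, grad_arrays[fwd_in], None))
    return edges

edges = wire_gradients(
    "mul_backward",
    required_grad_names={"A": "A_grad", "B": "B_grad"},
    given_grad_names={"C": "C_grad"},
    grad_arrays={"A": "gA", "B": "gB", "C": "gC"},
)
for edge in edges:
    print(edge)
# ('gC', None, 'mul_backward', 'C_grad')
# ('mul_backward', 'A_grad', 'gA', None)
# ('mul_backward', 'B_grad', 'gB', None)
```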