daceml.autodiff

Generating Backward Passes

add_backward_pass(sdfg, state, outputs, inputs)

Experimental: Add a backward pass to state using reverse-mode automatic differentiation.

inputs and outputs can be provided either as AccessNode nodes, or as str, in which case the graph will be searched for exactly one matching AccessNode whose data name matches the string.

The SDFG should not contain any inplace operations. It may contain the following nodes:

  • Maps

  • AccessNodes

  • Reductions (Sum, Min, Max)

  • ONNXOps

  • NestedSDFGs containing a single computational SDFGState (subject to the same constraints). A NestedSDFG may contain multiple states only if all states other than the computational one are used purely for zero initialization.

When differentiating an ONNXOp, the ONNXBackward registry will be checked for any matching backward pass implementations. If none are found, the ONNXForward registry will be checked for matching pure implementations. If one is found, symbolic differentiation of the pure implementation will be attempted. If this fails, or no pure forward implementation is found, the method will fail.

Parameters
  • sdfg (SDFG) – the parent SDFG of state.

  • state (SDFGState) – the state to add the backward pass to. This is also the state of the forward pass.

  • outputs (List[Union[AccessNode, str]]) – the forward pass outputs of the function to differentiate.

  • inputs (List[Union[AccessNode, str]]) – the inputs w.r.t. which the gradient will be returned.
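
A minimal sketch of typical usage (the dace program, array names, and shapes below are illustrative, and it is assumed that the program lowers to a single SDFG state):

    import dace
    import numpy as np

    from daceml.autodiff import add_backward_pass

    # Hypothetical forward computation producing a scalar loss S.
    @dace.program
    def forward(X: dace.float64[5, 4], Y: dace.float64[4, 3],
                S: dace.float64[1]):
        S[0] = np.sum(X @ Y)

    sdfg = forward.to_sdfg()

    # Extend the SDFG with the computation of the gradient of S w.r.t. X.
    add_backward_pass(sdfg=sdfg, state=sdfg.nodes()[0],
                      outputs=["S"], inputs=["X"])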

make_backward_function(model, apply_strict=False)

Convert an ONNXModel to a PyTorch differentiable function. This method should not be used on its own; instead, use the backward=True parameter of daceml.pytorch.DaceModule.

Parameters
  • model (ONNXModel) – the model to convert.

  • apply_strict – whether to apply strict transformations before creating the backward pass.

Return type

Tuple[SDFG, SDFG, BackwardResult, Dict[str, Data]]

Returns

A 4-tuple of the forward SDFG, the backward SDFG, the backward result, and the input arrays for the backward pass (as a mapping of names to DaCe data descriptors).
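
As noted above, this function is normally reached through DaceModule rather than called directly. A minimal sketch of that entry point (the wrapped module and shapes are illustrative):

    import torch
    import torch.nn as nn

    from daceml.pytorch import DaceModule

    # backward=True makes DaceModule build the backward SDFG via
    # make_backward_function, so PyTorch autograd works as usual.
    model = DaceModule(nn.Linear(8, 4), backward=True)

    x = torch.randn(2, 8, requires_grad=True)
    y = model(x)
    y.sum().backward()
    print(x.grad.shape)  # torch.Size([2, 8])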

Extending Autodiff

class BackwardImplementation

ABC for backward pass implementations of nodes.

This registry accepts two types of registrations: the register function either takes an argument node_type=TYPE, where TYPE is the type of node that this backward implementation supports, or an argument op=node_name, where node_name is the string name of the ONNX op it supports, e.g. "Conv".

It also expects a name argument that names the implementation.
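
For instance, registration might look like the following sketch (autoregister_params is the registry decorator from dace.registry; the class bodies are elided, and the op and implementation names are illustrative):

    import dace.sdfg.nodes as nd
    from dace.registry import autoregister_params

    from daceml.autodiff import BackwardImplementation

    # Register a backward implementation for an ONNX op by name ...
    @autoregister_params(op="Conv", name="my_conv_backward")
    class MyConvBackward(BackwardImplementation):
        ...

    # ... or for a node type.
    @autoregister_params(node_type=nd.LibraryNode, name="my_libnode_backward")
    class MyLibraryNodeBackward(BackwardImplementation):
        ...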

abstract static backward(forward_node, context, given_gradients, required_gradients)

Add the reverse node for a node from the forward pass to the backward pass, and return it.

For each forward-node input connector with name n in required_gradients, the returned backward node must have an output connector named required_grad_names[n] (see BackwardResult) that outputs the gradient for that input.

If any input value from the forward pass is required, simply add an input connector with the same name as the connector on the forward node; it will later be connected as required (see the combined sketch after backward_can_be_applied below).

Parameters
  • forward_node (Node) – the node for which the backward pass should be generated.

  • context (BackwardContext) – the context for this node (see BackwardContext).

  • given_gradients (List[Optional[str]]) – the names of the node's output connectors for which gradients will be connected.

  • required_gradients (List[Optional[str]]) – the names of the node's input connectors for which gradients should be generated.

Return type

Tuple[Node, BackwardResult]

Returns

the reverse node and gradient names (see BackwardResult).

static backward_can_be_applied(node, state, sdfg)

Return whether this backward implementation can be applied to the given node.

Parameters
  • node (Node) – the candidate node.

  • state (SDFGState) – the candidate state.

  • sdfg (SDFG) – the candidate sdfg.

Return type

bool
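
As an illustration, a backward implementation for a hypothetical elementwise "Square" op (Y = X * X, with a single input connector "X" and a single output connector "Y") might look roughly like this; a single tasklet is used to keep the sketch short:

    from dace.registry import autoregister_params

    from daceml.autodiff import BackwardImplementation, BackwardResult

    @autoregister_params(op="Square", name="sketch")  # "Square" is hypothetical
    class SquareBackwardSketch(BackwardImplementation):
        @staticmethod
        def backward(forward_node, context, given_gradients,
                     required_gradients):
            # dL/dX = 2 * X * dL/dY. The input connector "X" matches the
            # forward node's input connector, so the forward value of X
            # will be connected to it automatically.
            reverse_node = context.backward_state.add_tasklet(
                "square_backward",
                inputs={"X", "Y_grad"},
                outputs={"X_grad"},
                code="X_grad = 2 * X * Y_grad",
            )
            result = BackwardResult(
                required_grad_names={"X": "X_grad"},
                given_grad_names={"Y": "Y_grad"})
            return reverse_node, result

        @staticmethod
        def backward_can_be_applied(node, state, sdfg):
            return True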

class BackwardContext(forward_sdfg, forward_state, backward_sdfg, backward_state, backward_generator)

Bases: object

A tuple holding the graph context required to construct reverse nodes.

Parameters
  • forward_sdfg (dace.sdfg.sdfg.SDFG) –

  • forward_state (dace.sdfg.state.SDFGState) –

  • backward_sdfg (dace.sdfg.sdfg.SDFG) –

  • backward_state (dace.sdfg.state.SDFGState) –

  • backward_generator (daceml.autodiff.backward_pass_generator.BackwardPassGenerator) –

Return type

None

backward_generator: daceml.autodiff.backward_pass_generator.BackwardPassGenerator

the backward pass generator

backward_sdfg: dace.sdfg.sdfg.SDFG

the backward SDFG

backward_state: dace.sdfg.state.SDFGState

the backward SDFG state

forward_sdfg: dace.sdfg.sdfg.SDFG

the forward SDFG

forward_state: dace.sdfg.state.SDFGState

the forward SDFG state

class BackwardResult(required_grad_names, given_grad_names)

Bases: object

The return type of a differentiated node. It contains the names of the gradients the node calculates and requires.

Parameters
  • required_grad_names (Dict[Optional[str], Optional[str]]) –

  • given_grad_names (Dict[Optional[str], Optional[str]]) –

Return type

None

given_grad_names: Dict[Optional[str], Optional[str]]

mapping from names of output connectors to the connector name of the gradient for that connector.

required_grad_names: Dict[Optional[str], Optional[str]]

mapping from names of input connectors to the connector name of the gradient for that connector.
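
For example, for a node with a forward input connector "X" and a forward output connector "Y" (connector names are illustrative):

    from daceml.autodiff import BackwardResult

    result = BackwardResult(
        required_grad_names={"X": "X_grad"},  # dL/dX is produced on "X_grad"
        given_grad_names={"Y": "Y_grad"},     # dL/dY is received on "Y_grad"
    )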