daceml.pytorch

class DaceModule(module, dummy_inputs=None, cuda=None, training=False, backward=False, apply_strict=True, auto_optimize=True, debug_transients=False, sdfg_name=None)[source]

Bases: torch.nn.modules.module.Module

A wrapper that converts a PyTorch nn.Module into a PyTorch-compatible, data-centric nn.Module.

Parameters
  • module (Module) – the model to wrap.

  • dummy_inputs (Optional[Tuple[Tensor]]) – a tuple of tensors to use as input when tracing the model.

  • cuda (Optional[bool]) – if True, the module will execute using CUDA. If None, CUDA support is detected from the wrapped module.

  • training (bool) – whether to use train mode when tracing the model.

  • backward – whether to enable the backward pass.

  • apply_strict (bool) – whether to apply strict transforms after conversion (this generally improves performance, but can be slow).

  • sdfg_name (Optional[str]) – the name to give to the sdfg (defaults to dace_model).

  • auto_optimize (bool) – whether to apply automatic optimizations.

  • debug_transients (bool) – if True, the module will have all transients as outputs.

Example
>>> import torch
>>> import torch.nn as nn
>>> from daceml.pytorch import DaceModule
>>> class MyModule(nn.Module):
...     def forward(self, x):
...        x = torch.log(x)
...        x = torch.sqrt(x)
...        return x
>>> module = MyModule()
>>> module(torch.ones(2))
tensor([0., 0.])
>>> dace_module = DaceModule(module)
>>> dace_module(torch.ones(2))
tensor([0., 0.])
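
When backward=True, the wrapped module can take part in an ordinary PyTorch training loop: a backpropagation SDFG is generated (see post_autodiff_hooks below), so the standard autograd API applies. A minimal, hypothetical sketch (the layer, shapes, and loss below are illustrative and not part of this reference):
>>> import torch
>>> import torch.nn as nn
>>> from daceml.pytorch import DaceModule
>>> # Wrap a small linear layer with the backward pass enabled.
>>> dace_linear = DaceModule(nn.Linear(4, 2), backward=True)
>>> x = torch.rand(8, 4)
>>> y = dace_linear(x)
>>> loss = y.sum()
>>> loss.backward()  # gradients are computed through the generated SDFGs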


append_post_autodiff_hook(name, func)[source]
Parameters
  • name (str) – the name of the hook.

  • func (Callable[[dace.sdfg.sdfg.SDFG, dace.sdfg.sdfg.SDFG], None]) – the hook to call after the backpropagation sdfg has been created.

append_post_compile_hook(name, func)[source]
Parameters
  • name (str) – the name of the hook.

  • func (Callable[[dace.codegen.compiled_sdfg.CompiledSDFG], None]) – the hook to call after the sdfg is compiled.

append_post_onnx_hook(name, func)[source]
Parameters
  • name (str) – the name of the hook.

  • func (Callable[[DaceModule], None]) – the hook to call after the onnx graph is imported to an SDFG.

forward(*actual_inputs)[source]

Execute the forward pass using the traced module.

post_autodiff_hooks: OrderedDict[str, Callable[[dace.SDFG, dace.SDFG], None]]

hooks that are executed after the backpropagation sdfg has been created

post_compile_hooks: OrderedDict[str, Callable[[compiled_sdfg.CompiledSDFG], None]]

hooks that are executed after the sdfg is compiled

post_onnx_hooks: OrderedDict[str, Callable[[DaceModule], None]]

hooks that are executed after onnx graph is imported to an SDFG
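
Hooks are registered through the append_* and prepend_* methods above. A hedged sketch that saves the SDFGs produced by autodiff to disk (SDFG.save is the standard dace API; the argument order, forward SDFG first, and the file names are assumptions for illustration):
>>> import torch.nn as nn
>>> from daceml.pytorch import DaceModule
>>> dace_module = DaceModule(nn.Linear(4, 2), backward=True)
>>> def save_sdfgs(forward_sdfg, backward_sdfg):
...     # Persist both SDFGs so they can be inspected offline (argument order assumed).
...     forward_sdfg.save('forward.sdfg')
...     backward_sdfg.save('backward.sdfg')
>>> dace_module.append_post_autodiff_hook("save_sdfgs", save_sdfgs)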

prepend_post_autodiff_hook(name, func)[source]
Parameters
  • name (str) – the name of the hook.

  • func (Callable[[dace.sdfg.sdfg.SDFG, dace.sdfg.sdfg.SDFG], None]) – the hook to call after the backpropagation sdfg has been created.

prepend_post_compile_hook(name, func)[source]
Parameters
  • name (str) – the name of the hook.

  • func (Callable[[dace.codegen.compiled_sdfg.CompiledSDFG], None]) – the hook to call after the sdfg is compiled.

prepend_post_onnx_hook(name, func)[source]
Parameters
  • name (str) – the name of the hook.

  • func (Callable[[DaceModule], None]) – the hook to call after the onnx graph is imported to an SDFG.

reset_sdfg()[source]

Clear the sdfg so that optimizations are reapplied.
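
A hypothetical illustration (assuming that, after a reset, conversion and optimization re-run lazily on the next call): if a hook is registered on a module that has already been converted, resetting the SDFG lets the hook take effect on the next call.
>>> import torch
>>> import torch.nn as nn
>>> from daceml.pytorch import DaceModule
>>> dace_module = DaceModule(nn.Linear(4, 2))
>>> _ = dace_module(torch.rand(1, 4))                          # converted and compiled here
>>> dace_module.append_post_onnx_hook("noop", lambda m: None)  # registered after conversion
>>> dace_module.reset_sdfg()                                   # clear the cached SDFG
>>> _ = dace_module(torch.rand(1, 4))                          # conversion re-runs; the hook fires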

training: bool

@dace_module(moduleclass, dummy_inputs=None, cuda=None, training=False, backward=False, apply_strict=True, auto_optimize=True, sdfg_name=None, debug_transients=False)[source]

Decorator to apply to a torch.nn.Module class definition; instances of the class are converted to data-centric modules upon construction.

Example
>>> import torch
>>> import torch.nn as nn
>>> from daceml.pytorch import dace_module
>>> @dace_module
... class MyDecoratedModule(nn.Module):
...     def forward(self, x):
...        x = torch.log(x)
...        x = torch.sqrt(x)
...        return x
>>> module = MyDecoratedModule()
>>> module(torch.ones(2))
tensor([0., 0.])

Parameters
  • moduleclass – the nn.Module class to wrap.

  • dummy_inputs (Optional[Tuple[Tensor]]) – a tuple of tensors to use as input when tracing the model.

  • cuda (Optional[bool]) – if True, the module will execute using CUDA. If None, CUDA support is detected from the wrapped module.

  • training (bool) – whether to use train mode when tracing the model.

  • backward – whether to enable the backward pass.

  • apply_strict (bool) – whether to apply strict transforms after conversion (this generally improves performance, but can be slow).

  • auto_optimize (bool) – whether to apply automatic optimizations.

  • sdfg_name (Optional[str]) – the name to give to the sdfg (defaults to dace_model).

  • debug_transients (bool) – if True, the module will have all transients as outputs.

Return type

Type[DaceModule]
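
Because the decorator returns a DaceModule type, instances of the decorated class expose the DaceModule API documented above (hooks, reset_sdfg, and so on). A brief, hypothetical continuation of the example:
>>> module = MyDecoratedModule()
>>> # The instance is a DaceModule, so hooks can be attached before the first call.
>>> module.append_post_onnx_hook("noop", lambda m: None)
>>> module(torch.ones(2))
tensor([0., 0.])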