PyTorch Integration
A PyTorch nn.Module can be imported using the DaceModule wrapper or the dace_module decorator.
import torch
import torch.nn.functional as F
from daceml.pytorch import DaceModule, dace_module
# Input and size definition
B = 2     # middle (second) dimension of every input tensor
H = 16    # head count, multiplied with P to form N
P = 64    # per-head width
SM = 512  # leading dimension of K and V
SN = 512  # leading dimension of Q
N = H * P  # full feature width: heads x per-head width

# Random inputs: Q is (SN, B, N); K and V are both (SM, B, N).
Q = torch.randn(SN, B, N)
K = torch.randn(SM, B, N)
V = torch.randn(SM, B, N)
# DaCe module used as a wrapper:
# build a regular PyTorch attention layer (N-wide embedding, H heads, no bias),
# wrap it in DaceModule, then call the wrapper like the original module.
ptmodel = torch.nn.MultiheadAttention(embed_dim=N, num_heads=H, bias=False)
dace_model = DaceModule(ptmodel)
outputs_wrapped = dace_model(Q, K, V)
# DaCe module used as a decorator
@dace_module
class Model(torch.nn.Module):
    """Two stacked 2-D convolutions, each followed by a ReLU.

    The layers are referenced through ``torch.nn`` because this example only
    imports ``torch`` and ``torch.nn.functional as F``; a bare ``nn`` name is
    not in scope.
    """

    def __init__(self, kernel_size):
        """Create the two convolution layers.

        Args:
            kernel_size: Kernel size passed to both ``Conv2d`` layers.
        """
        super().__init__()
        # 1 input channel -> 4 channels, then 4 -> 4 channels.
        self.conv1 = torch.nn.Conv2d(1, 4, kernel_size)
        self.conv2 = torch.nn.Conv2d(4, 4, kernel_size)

    def forward(self, x):
        """Apply ``conv1`` and ``conv2`` in sequence, with a ReLU after each.

        Args:
            x: Input tensor accepted by ``conv1`` (one input channel).

        Returns:
            The ReLU-activated output of ``conv2``.
        """
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
# Build the decorated model with kernel size 3 and run it once on a random
# tensor of shape (1, 1, 8, 8).
dace_model = Model(3)
outputs_dec = dace_model(torch.rand((1, 1, 8, 8)))