Using ONNX Library Nodes
This example demonstrates how to use ONNX library nodes.
The easiest way to use ONNX library nodes is through the DaCe Python frontend:
import dace
import daceml.onnx as donnx
import numpy as np
@dace.program
def conv_program(X_arr: dace.float32[5, 3, 10, 10],
                 W_arr: dace.float32[16, 3, 3, 3]):
    output = np.ndarray([5, 16, 4, 4], dtype=np.float32)
    donnx.ONNXConv(X=X_arr, W=W_arr, Y=output, strides=[2, 2])
    return output
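The output shape `[5, 16, 4, 4]` in the program above follows from the usual convolution size arithmetic: a 3x3 kernel sliding over a 10x10 input with stride 2 and no padding fits in 4 positions per axis. The sketch below spells that calculation out. The helper name `conv_output_shape` is hypothetical (it is not part of DaCeML), and it assumes symmetric padding and a single group.

```python
def conv_output_shape(input_shape, weight_shape, strides=(1, 1),
                      pads=(0, 0), dilations=(1, 1)):
    """Return the (N, M, H_out, W_out) shape of a 2D convolution.

    input_shape is (N, C, H, W) and weight_shape is (M, C, kH, kW).
    Assumptions of this sketch: the same padding on both sides of each
    spatial axis, and group == 1.
    """
    n, c_in, *spatial = input_shape
    m, c_w, *kernel = weight_shape
    if c_in != c_w:
        raise ValueError(f"channel mismatch: input has {c_in}, weights expect {c_w}")

    out = []
    for size, k, s, p, d in zip(spatial, kernel, strides, pads, dilations):
        effective_k = d * (k - 1) + 1  # kernel extent once dilation is applied
        out.append((size + 2 * p - effective_k) // s + 1)
    return (n, m, *out)


# Shapes used in conv_program: X is (5, 3, 10, 10), W is (16, 3, 3, 3), stride 2.
print(conv_output_shape((5, 3, 10, 10), (16, 3, 3, 3), strides=(2, 2)))
# (5, 16, 4, 4)
```

Per axis this is (10 - 3) // 2 + 1 = 4, which matches the shape the example allocates for `output`.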
The resulting SDFG contains an instance of the ONNXConv library node.
conv_program.to_sdfg()
We can now execute the program with some example inputs:
X = np.random.rand(5, 3, 10, 10).astype(np.float32)
W = np.random.rand(16, 3, 3, 3).astype(np.float32)
result = conv_program(X_arr=X, W_arr=W)
Let’s check the result against PyTorch:
import torch
import torch.nn.functional as F
torch_result = F.conv2d(torch.from_numpy(X), torch.from_numpy(W),
                        stride=2).numpy()
assert np.allclose(torch_result, result)
np.linalg.norm(torch_result - result)
Out:
2.3904955e-05
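If PyTorch is not available, the same check can be made against a hand-written reference. The sketch below is a naive strided 2D convolution (cross-correlation, which is what the PyTorch comparison above relies on) without padding or bias. It is written in plain Python on nested lists so that it has no dependencies; `conv2d_reference` is a hypothetical helper for illustration, not a DaCeML or PyTorch function.

```python
def conv2d_reference(x, w, stride=1):
    """Naive 2D cross-correlation with no padding and no bias.

    x is a nested list indexed [N][C][H][W];
    w is a nested list indexed [M][C][kH][kW].
    Returns a nested list indexed [N][M][H_out][W_out].
    """
    n, c, h, wd = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    m, kh, kw = len(w), len(w[0][0]), len(w[0][0][0])
    h_out = (h - kh) // stride + 1
    w_out = (wd - kw) // stride + 1

    y = [[[[0.0] * w_out for _ in range(h_out)] for _ in range(m)]
         for _ in range(n)]
    for b in range(n):                      # batch
        for o in range(m):                  # output channel
            for i in range(h_out):          # output row
                for j in range(w_out):      # output column
                    acc = 0.0
                    for ch in range(c):
                        for di in range(kh):
                            for dj in range(kw):
                                acc += (x[b][ch][i * stride + di][j * stride + dj]
                                        * w[o][ch][di][dj])
                    y[b][o][i][j] = acc
    return y


# A 4x4 ramp convolved with a 2x2 kernel of ones at stride 2:
# each output is the sum of one non-overlapping 2x2 block.
x = [[[[float(4 * r + c) for c in range(4)] for r in range(4)]]]
w = [[[[1.0, 1.0], [1.0, 1.0]]]]
print(conv2d_reference(x, w, stride=2))
# [[[[10.0, 18.0], [42.0, 50.0]]]]
```

With the arrays from this example, something like `np.asarray(conv2d_reference(X.tolist(), W.tolist(), stride=2))` should be comparable to `result` via `np.allclose`, although the reference accumulates in double precision, so small differences of the same order as the norm above are expected.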
We can also create ONNX nodes through the SDFG Python API.
from daceml.onnx import ONNXConv
sdfg = dace.SDFG("conv_example")
sdfg.add_array("X_arr", (5, 3, 10, 10), dace.float32)
sdfg.add_array("W_arr", (16, 3, 3, 3), dace.float32)
sdfg.add_array("Z_arr", (5, 16, 4, 4), dace.float32)
state = sdfg.add_state()
access_X = state.add_access("X_arr")
access_W = state.add_access("W_arr")
access_Z = state.add_access("Z_arr")
conv = ONNXConv("MyConvNode", strides=[2, 2])
state.add_node(conv)
state.add_edge(access_X, None, conv, "X", sdfg.make_array_memlet("X_arr"))
state.add_edge(access_W, None, conv, "W", sdfg.make_array_memlet("W_arr"))
state.add_edge(conv, "Y", access_Z, None, sdfg.make_array_memlet("Z_arr"))
sdfg
The SDFG looks the same as the one above. Now let’s try running it:
Z = np.zeros((5, 16, 4, 4), dtype=np.float32)
sdfg(X_arr=X, W_arr=W, Z_arr=Z)
assert np.allclose(torch_result, Z)
np.linalg.norm(torch_result - Z)
Out:
2.3904955e-05
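Both assertions pass even though the norm of the difference is not zero, because `np.allclose` compares element-wise with a tolerance: by default it accepts `|a - b| <= atol + rtol * |b|` with `rtol=1e-5` and `atol=1e-8`. The sketch below mirrors that criterion in plain Python for flat sequences (the function here is an illustrative stand-in, not NumPy's implementation) and puts the reported norm in per-element terms.

```python
import math


def allclose(actual, expected, rtol=1e-5, atol=1e-8):
    """Element-wise check |a - b| <= atol + rtol * |b| over two flat sequences.

    The tolerance scales with the magnitude of the expected value, so the
    test is not symmetric in its two arguments.
    """
    return all(abs(a - b) <= atol + rtol * abs(b)
               for a, b in zip(actual, expected))


print(allclose([1.00001], [1.0]))  # True: the difference is within rtol * |b|
print(allclose([1e-7], [0.0]))     # False: near zero only atol applies

# The norm printed above is taken over all 5 * 16 * 4 * 4 output elements,
# so the root-mean-square error per element is considerably smaller.
norm = 2.3904955e-05
rms_per_element = norm / math.sqrt(5 * 16 * 4 * 4)
print(rms_per_element)
```

A total norm on the order of 1e-5 spread over 1280 elements corresponds to a per-element error below 1e-6, which is in the range one would expect from float32 rounding when two implementations sum the products in a different order.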