Using ONNX Library Nodes

This example demonstrates using ONNX library nodes.

The easiest way to use ONNX library nodes is through the DaCe Python frontend:

import dace
import daceml.onnx as donnx
import numpy as np


@dace.program
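def conv_program(X_arr: dace.float32[5, 3, 10, 10],
                 W_arr: dace.float32[16, 3, 3, 3]):
    # Preallocated output buffer: batch 5, 16 output channels, 4x4 spatial
    output = np.ndarray([5, 16, 4, 4], dtype=np.float32)
    # X and W are the ONNX input connectors; the result is written to Y
    donnx.ONNXConv(X=X_arr, W=W_arr, Y=output, strides=[2, 2])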
    return output
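The shape [5, 16, 4, 4] of the output buffer in `conv_program` follows from the usual output-size formula for a convolution without padding or dilation: floor((H - k) / s) + 1 along each spatial axis. The helper below is a small sketch of that arithmetic. The name `conv_output_shape` is hypothetical (it is not part of DaCe or DaCeML), and it assumes a single group and ignores the ONNX `pads`, `dilations`, and `auto_pad` attributes.

```python
def conv_output_shape(x_shape, w_shape, strides=(1, 1)):
    """Output shape of an NCHW convolution with no padding, no dilation, group=1.

    Hypothetical helper for illustration; not a DaCe/DaCeML API.
    """
    n, c, h, w = x_shape
    m, c_w, kh, kw = w_shape
    if c != c_w:
        raise ValueError("input and weight channel counts differ")
    sh, sw = strides
    # Number of valid kernel positions along each spatial axis
    out_h = (h - kh) // sh + 1
    out_w = (w - kw) // sw + 1
    return (n, m, out_h, out_w)


print(conv_output_shape((5, 3, 10, 10), (16, 3, 3, 3), strides=(2, 2)))
# (5, 16, 4, 4)
```

With a 10x10 input, a 3x3 kernel, and stride 2, each spatial axis gets (10 - 3) // 2 + 1 = 4 kernel positions.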

The resulting SDFG contains an instance of the ONNXConv library node.

conv_program.to_sdfg()


We can now execute the program with some example inputs:

X = np.random.rand(5, 3, 10, 10).astype(np.float32)
W = np.random.rand(16, 3, 3, 3).astype(np.float32)

result = conv_program(X_arr=X, W_arr=W)

Let’s check the result for correctness against PyTorch:

import torch
import torch.nn.functional as F

torch_result = F.conv2d(torch.from_numpy(X), torch.from_numpy(W),
                        stride=2).numpy()

assert np.allclose(torch_result, result)
np.linalg.norm(torch_result - result)

Out:

2.3904955e-05
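If PyTorch is not available, a plain NumPy reference can serve the same purpose. ONNX `Conv`, like PyTorch's `conv2d`, computes a cross-correlation (the kernel is not flipped). The sketch below assumes NCHW layout, no padding, no dilation, a single group, and no bias; `conv2d_reference` is a hypothetical helper, not part of DaCeML.

```python
import numpy as np


def conv2d_reference(x, w, stride=(1, 1)):
    """Naive NCHW cross-correlation: no padding, no dilation, group=1, no bias.

    Hypothetical helper for illustration; not a DaCeML API.
    """
    n, c, h, wd = x.shape
    m, c_w, kh, kw = w.shape
    if c != c_w:
        raise ValueError("input and weight channel counts differ")
    sh, sw = stride
    out_h = (h - kh) // sh + 1
    out_w = (wd - kw) // sw + 1
    y = np.zeros((n, m, out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            # Input patch under the kernel at output position (i, j): shape (N, C, kH, kW)
            patch = x[:, :, i * sh:i * sh + kh, j * sw:j * sw + kw]
            # Contract over C, kH, kW against weights (M, C, kH, kW): result (N, M)
            y[:, :, i, j] = np.tensordot(patch, w, axes=([1, 2, 3], [1, 2, 3]))
    return y
```

With the arrays from above, `np.allclose(conv2d_reference(X, W, stride=(2, 2)), result)` performs the same check as the PyTorch comparison. The explicit loop over output positions is slow for large inputs, but it keeps the indexing easy to audit.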

We can also add ONNX nodes to an SDFG directly through the SDFG Python API.

from daceml.onnx import ONNXConv

sdfg = dace.SDFG("conv_example")
sdfg.add_array("X_arr", (5, 3, 10, 10), dace.float32)
sdfg.add_array("W_arr", (16, 3, 3, 3), dace.float32)
sdfg.add_array("Z_arr", (5, 16, 4, 4), dace.float32)

state = sdfg.add_state()
access_X = state.add_access("X_arr")
access_W = state.add_access("W_arr")
access_Z = state.add_access("Z_arr")

conv = ONNXConv("MyConvNode", strides=[2, 2])

state.add_node(conv)
state.add_edge(access_X, None, conv, "X", sdfg.make_array_memlet("X_arr"))
state.add_edge(access_W, None, conv, "W", sdfg.make_array_memlet("W_arr"))
state.add_edge(conv, "Y", access_Z, None, sdfg.make_array_memlet("Z_arr"))

sdfg


The SDFG looks the same as the one above. Now let’s try running it:

Z = np.zeros((5, 16, 4, 4), dtype=np.float32)
sdfg(X_arr=X, W_arr=W, Z_arr=Z)
assert np.allclose(torch_result, Z)
np.linalg.norm(torch_result - Z)

Out:

2.3904955e-05
