TensorIR Transformation

In the previous post, we explored how to write primitive functions in TensorIR. Here, we will see how to transform a TensorIR function into other (potentially more performant) variants. The content is derived from the MLC course taught by Tianqi Chen.

Batched Matmul ReLU

A batched matrix multiplication followed by a ReLU operation can be expressed using numpy as:

import numpy as np

def lnumpy_mm_relu_v2(A: np.ndarray, B: np.ndarray, C: np.ndarray):
    Y = np.empty((16, 128, 128), dtype="float32")
    for n in range(16):
        for i in range(128):
            for j in range(128):
                for k in range(128):
                    if k == 0:
                        Y[n, i, j] = 0
                    Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
    for n in range(16):
        for i in range(128):
            for j in range(128):
                C[n, i, j] = max(Y[n, i, j], 0)
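In index notation, the two loop nests above compute the following, where \(n\) ranges over the 16 batches and \(i\), \(j\), \(k\) each range over 128 values:

\[
Y_{n,i,j} = \sum_{k=0}^{127} A_{n,i,k}\, B_{n,k,j}, \qquad C_{n,i,j} = \max\left(Y_{n,i,j},\ 0\right)
\]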

Translating the numpy code into TensorIR we get:

import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class MyBmmRule:
  @T.prim_func
  def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"],
               B: T.Buffer[(16, 128, 128), "float32"],
               C: T.Buffer[(16, 128, 128), "float32"]):
    T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
    # The intermediate result needs its own buffer, so we must allocate it here.
    Y = T.alloc_buffer([16, 128, 128], dtype="float32")
    # Block "M": batched matrix multiplication, reducing over k.
    for n, i, j, k in T.grid(16, 128, 128, 128):
      with T.block("M"):
        vn = T.axis.spatial(16, n)
        vi = T.axis.spatial(128, i)
        vj = T.axis.spatial(128, j)
        vk = T.axis.reduce(128, k)
        with T.init():
          Y[vn, vi, vj] = T.float32(0)
        Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
    # Block "R": element-wise ReLU on the intermediate result.
    for n, i, j in T.grid(16, 128, 128):
      with T.block("R"):
        vn = T.axis.spatial(16, n)
        vi = T.axis.spatial(128, i)
        vj = T.axis.spatial(128, j)
        C[vn, vi, vj] = T.max(Y[vn, vi, vj], T.float32(0))

Our ultimate goal is to transform the TensorIR above to the following form:

@tvm.script.ir_module
class TargetModule:
    @T.prim_func
    def bmm_relu(A: T.Buffer[(16, 128, 128), "float32"], B: T.Buffer[(16, 128, 128), "float32"], C: T.Buffer[(16, 128, 128), "float32"]) -> None:
        T.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = T.alloc_buffer([16, 128, 128], dtype="float32")
        for i0 in T.parallel(16):
            for i1, i2_0 in T.grid(128, 16):
                for ax0_init in T.vectorized(8):
                    with T.block("M_init"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
                        Y[n, i, j] = T.float32(0)
                for ax1_0 in T.serial(32):
                    for ax1_1 in T.unroll(4):
                        for ax0 in T.serial(8):
                            with T.block("M_update"):
                                n, i = T.axis.remap("SS", [i0, i1])
                                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]
                for i2_1 in T.vectorized(8):
                    with T.block("R"):
                        n, i = T.axis.remap("SS", [i0, i1])
                        j = T.axis.spatial(128, i2_0 * 8 + i2_1)
                        C[n, i, j] = T.max(Y[n, i, j], T.float32(0))

Before we perform the transformation, let’s understand what the transformed TensorIR is doing by looking at several loops here.

First, take a look at:

for i1, i2_0 in T.grid(128, 16):
    for ax0_init in T.vectorized(8):
        with T.block("M_init"):
            n, i = T.axis.remap("SS", [i0, i1])
            j = T.axis.spatial(128, i2_0 * 8 + ax0_init)
            Y[n, i, j] = T.float32(0)

This block initializes Y to 0. Rather than writing one element at a time, it zeroes 8 consecutive elements of a row of Y with a single vectorized operation, which might be faster.
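As a rough analogy, the difference between a serial and a vectorized initialization can be sketched in plain Python for a single row of Y. The helper names below are made up for illustration, and the slice assignment only stands in for one vector store; what T.vectorized actually emits depends on the target.

```python
TILES, LANES = 16, 8  # 16 tiles of 8 consecutive columns cover all 128 columns


def init_row_scalar(row):
    """Zero the row one element at a time, as a serial loop would."""
    for i2_0 in range(TILES):
        for ax0_init in range(LANES):
            row[i2_0 * LANES + ax0_init] = 0.0
    return row


def init_row_vectorized(row):
    """Zero the row eight contiguous elements at a time.

    The slice assignment is only an analogy for a single SIMD store.
    """
    for i2_0 in range(TILES):
        start = i2_0 * LANES
        row[start:start + LANES] = [0.0] * LANES
    return row


# Both versions touch every column exactly once and give the same row.
assert init_row_scalar([1.0] * 128) == init_row_vectorized([1.0] * 128)
```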

The next loop is a bit tricky:

for ax1_0 in T.serial(32):
    for ax1_1 in T.unroll(4):
        for ax0 in T.serial(8):
            with T.block("M_update"):
                n, i = T.axis.remap("SS", [i0, i1])
                j = T.axis.spatial(128, i2_0 * 8 + ax0)
                k = T.axis.reduce(128, ax1_0 * 4 + ax1_1)
                Y[n, i, j] = Y[n, i, j] + A[n, i, k] * B[n, k, j]

This loop performs the actual matrix multiplication of A and B. We multiply a row of A with a column of B and sum the products into a single number.

Here, i is mapped to i1, which means we process A one row at a time. k = T.axis.reduce(128, ax1_0 * 4 + ax1_1) means that during the multiplication we walk along one row of A and one column of B sequentially, with the inner loop unrolled in the hope of better efficiency (\(128 = 32 \times 4\)). j = T.axis.spatial(128, i2_0 * 8 + ax0) simply visits the 8 columns of the current tile one after another, nothing special.
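To make the role of the split reduction axis concrete, here is a small plain-Python sketch of one dot product, first with a single loop over k and then with k split into 32 x 4 and the inner loop written out by hand. The function names are hypothetical, and manual unrolling in Python is only an illustration of what T.unroll requests from the code generator.

```python
import random

K_OUTER, K_INNER = 32, 4  # 128 = 32 * 4, matching ax1_0 and ax1_1


def dot_plain(a_row, b_col):
    """Reduce over k with a single loop of extent 128."""
    acc = 0.0
    for k in range(K_OUTER * K_INNER):
        acc += a_row[k] * b_col[k]
    return acc


def dot_split_unrolled(a_row, b_col):
    """Same reduction with k = ax1_0 * 4 + ax1_1 and the inner loop unrolled."""
    acc = 0.0
    for ax1_0 in range(K_OUTER):
        base = ax1_0 * K_INNER
        # The body of the ax1_1 loop, repeated four times.
        acc += a_row[base + 0] * b_col[base + 0]
        acc += a_row[base + 1] * b_col[base + 1]
        acc += a_row[base + 2] * b_col[base + 2]
        acc += a_row[base + 3] * b_col[base + 3]
    return acc


rng = random.Random(0)
a_row = [rng.uniform(-1, 1) for _ in range(128)]
b_col = [rng.uniform(-1, 1) for _ in range(128)]
# k is visited in the same order, so the results match exactly.
assert dot_plain(a_row, b_col) == dot_split_unrolled(a_row, b_col)
```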

Perform Transformation

To perform a transformation on any TensorIR function, it’s very important to follow the steps listed below:

  1. Get block
  2. Get loops
  3. Organize loops by split, reorder, compute_at/reverse_compute_at
  4. Decompose reduction
  5. vectorize/unroll/parallel

Applying steps 1, 2, and 3, we first get the block and its loops from the original TensorIR, then split the loops:

sch = tvm.tir.Schedule(MyBmmRule)
# Step 1. Get blocks
block_M = sch.get_block("M", func_name="bmm_relu")

# Step 2. Get loops
n, i, j, k = sch.get_loops(block_M)

# Step 3. Organize loops
k0, k1 = sch.split(k, factors=[32, 4])
j0, j1 = sch.split(j, factors=[16, 8])

We split k and j this way for two reasons. As mentioned above, the k dimension is traversed sequentially with an unroll factor of 4, so k is split into [32, 4]. When Y is initialized, a vectorized operation covers 8 elements along dimension j, so j is split into [16, 8]. Those 8 elements are consecutive within one row, and since TVM buffers are row-major they are contiguous in memory, which is why vectorizing them might be faster.
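The index arithmetic behind a split can be sketched in a few lines of plain Python. The helper split_indices below is hypothetical and assumes the extent is divisible by the factor, which holds for both splits used here.

```python
def split_indices(extent, factor):
    """Yield (outer, inner, original) for a loop of `extent` split by `factor`.

    Assumes `extent` is a multiple of `factor`.
    """
    assert extent % factor == 0
    for outer in range(extent // factor):
        for inner in range(factor):
            yield outer, inner, outer * factor + inner


# Splitting j into [16, 8] and k into [32, 4] still visits 0..127 in order.
assert [orig for _, _, orig in split_indices(128, 8)] == list(range(128))
assert [orig for _, _, orig in split_indices(128, 4)] == list(range(128))

# The (outer, inner) pair is just the quotient and remainder of the original index.
assert all(divmod(orig, 8) == (o, i) for o, i, orig in split_indices(128, 8))
```

A split on its own therefore changes nothing about the iteration order; it only exposes the outer and inner loops so that later steps can move or annotate them separately.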

But the next question is: how do we reorder the split loops? I spent a lot of time trying to figure that out. It turns out the simplest way is to write out the implementation in numpy and proceed from there. Remember, we’ve already split k and j, which are the loops used during matrix multiplication, so our new matrix multiplication in numpy (for a fixed row i, with the batch index omitted) would be:

for j0 in range(16):
    for k0 in range(32):
        for k1 in range(4):
            for j1 in range(8):
                Y[i, 8*j0+j1] += A[i, 4*k0 + k1] * B[4*k0+k1, 8*j0+j1]

In the target, for each value of k we sweep across the 8 columns of the current tile of B before moving on to the next k, so we put j1 in the innermost loop. Therefore, the transformation for TensorIR would be:

sch.reorder(j0, k0, k1, j1)
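One way to convince yourself that this reordering is safe is to run both loop orders on the same data. The following plain-Python sketch does that for a single row of A; the function names are made up for illustration, and it checks only this particular loop order rather than reordering in general.

```python
import random

J, K = 128, 128


def row_original(a_row, B):
    """Loop order j, k: the order in the untransformed block M."""
    y = [0.0] * J
    for j in range(J):
        for k in range(K):
            y[j] += a_row[k] * B[k][j]
    return y


def row_reordered(a_row, B):
    """Loop order j0, k0, k1, j1: the order after split and reorder."""
    y = [0.0] * J
    for j0 in range(16):
        for k0 in range(32):
            for k1 in range(4):
                for j1 in range(8):
                    k = 4 * k0 + k1
                    j = 8 * j0 + j1
                    y[j] += a_row[k] * B[k][j]
    return y


rng = random.Random(0)
a_row = [rng.uniform(-1, 1) for _ in range(K)]
B = [[rng.uniform(-1, 1) for _ in range(J)] for _ in range(K)]
# Each y[j] still accumulates over k in ascending order, so results match exactly.
assert row_original(a_row, B) == row_reordered(a_row, B)
```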

We can print out the transformed TensorIR with print(sch.mod.script()):

@tvm.script.ir_module
class Module:
    @tir.prim_func
    def bmm_relu(A: tir.Buffer[(16, 128, 128), "float32"], B: tir.Buffer[(16, 128, 128), "float32"], C: tir.Buffer[(16, 128, 128), "float32"]) -> None:
        tir.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        Y = tir.alloc_buffer([16, 128, 128], dtype="float32")
        for n in tir.parallel(16):
            for i, j_0, k_0, k_1, j_1 in tir.grid(128, 16, 32, 4, 8):
                with tir.block("M"):
                    vn, vi = tir.axis.remap("SS", [n, i])
                    vj = tir.axis.spatial(128, j_0 * 8 + j_1)
                    vk = tir.axis.reduce(128, k_0 * 4 + k_1)
                    tir.reads(A[vn, vi, vk], B[vn, vk, vj])
                    tir.writes(Y[vn, vi, vj])
                    with tir.init():
                        Y[vn, vi, vj] = tir.float32(0)
                    Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
        for n, i, j in tir.grid(16, 128, 128):
            with tir.block("R"):
                vn, vi, vj = tir.axis.remap("SSS", [n, i, j])
                tir.reads(Y[vn, vi, vj])
                tir.writes(C[vn, vi, vj])
                C[vn, vi, vj] = tir.max(Y[vn, vi, vj], tir.float32(0))

Now, we just need to move the ReLU block (the second loop nest, for n, i, j in tir.grid(16, 128, 128):) under the j_0 loop of the nest above:

# The block being moved is the ReLU block "R", not "M".
block_R = sch.get_block("R", func_name="bmm_relu")
sch.reverse_compute_at(block_R, j0)
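The effect of reverse_compute_at here can be pictured as follows: instead of running the ReLU over the whole row after the matrix multiplication has finished, the ReLU for each tile of 8 columns runs as soon as that tile is complete. Below is a plain-Python sketch for one row, with hypothetical function names; it emulates the loop structure only and says nothing about how TVM implements the primitive.

```python
import random

J, K, TILE = 128, 128, 8


def two_pass(a_row, B):
    """Producer and consumer as two separate loop nests."""
    y = [0.0] * J
    for j in range(J):
        for k in range(K):
            y[j] += a_row[k] * B[k][j]
    return [max(v, 0.0) for v in y]


def fused_per_tile(a_row, B):
    """Consumer moved under the j0 loop, after the producer of the same tile."""
    y = [0.0] * J
    c = [0.0] * J
    for j0 in range(J // TILE):
        # Producer: finish the whole reduction for this tile of 8 columns.
        for k in range(K):
            for j1 in range(TILE):
                j = j0 * TILE + j1
                y[j] += a_row[k] * B[k][j]
        # Consumer: ReLU over the tile that was just completed.
        for ax0 in range(TILE):
            j = j0 * TILE + ax0
            c[j] = max(y[j], 0.0)
    return c


rng = random.Random(0)
a_row = [rng.uniform(-1, 1) for _ in range(K)]
B = [[rng.uniform(-1, 1) for _ in range(J)] for _ in range(K)]
assert two_pass(a_row, B) == fused_per_tile(a_row, B)
```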

Step 4 involves separating the initialization from the matrix multiplication, so we use M_init = sch.decompose_reduction(block_M, k0), which results in:

@tvm.script.ir_module
class Module:
    @tir.prim_func
    def bmm_relu(A: tir.Buffer[(16, 128, 128), "float32"], B: tir.Buffer[(16, 128, 128), "float32"], C: tir.Buffer[(16, 128, 128), "float32"]) -> None:
        # function attr dict
        tir.func_attr({"global_symbol": "bmm_relu", "tir.noalias": True})
        # body
        # with tir.block("root")
        Y = tir.alloc_buffer([16, 128, 128], dtype="float32")
        for n in tir.parallel(16):
            for i, j_0 in tir.grid(128, 16):
                for j_1_init in tir.serial(8):
                    with tir.block("M_init"):
                        vn, vi = tir.axis.remap("SS", [n, i])
                        vj = tir.axis.spatial(128, j_0 * 8 + j_1_init)
                        tir.reads()
                        tir.writes(Y[vn, vi, vj])
                        Y[vn, vi, vj] = tir.float32(0)
                for k_0, k_1, j_1 in tir.grid(32, 4, 8):
                    with tir.block("M_update"):
                        vn, vi = tir.axis.remap("SS", [n, i])
                        vj = tir.axis.spatial(128, j_0 * 8 + j_1)
                        vk = tir.axis.reduce(128, k_0 * 4 + k_1)
                        tir.reads(Y[vn, vi, vj], A[vn, vi, vk], B[vn, vk, vj])
                        tir.writes(Y[vn, vi, vj])
                        Y[vn, vi, vj] = Y[vn, vi, vj] + A[vn, vi, vk] * B[vn, vk, vj]
                for ax0 in tir.serial(8):
                    with tir.block("R"):
                        vn, vi = tir.axis.remap("SS", [n, i])
                        vj = tir.axis.spatial(128, j_0 * 8 + ax0)
                        tir.reads(Y[vn, vi, vj])
                        tir.writes(C[vn, vi, vj])
                        C[vn, vi, vj] = tir.max(Y[vn, vi, vj], tir.float32(0))
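The difference between the block before and after decompose_reduction can be emulated in plain Python for one tile of 8 columns. This is a sketch with made-up function names: the first version performs the initialization on the first reduction step, the second hoists it into its own loop in front of the reduction loops.

```python
import random

K, TILE = 128, 8


def tile_inline_init(a_row, B, j0):
    """Block M before decomposition: init happens on the first reduction step."""
    y = [None] * TILE  # deliberately uninitialized
    for k in range(K):
        for j1 in range(TILE):
            if k == 0:  # plays the role of `with T.init()`
                y[j1] = 0.0
            y[j1] += a_row[k] * B[k][j0 * TILE + j1]
    return y


def tile_decomposed(a_row, B, j0):
    """After decomposition at k_0: an M_init loop, then the M_update loops."""
    y = [None] * TILE
    for j1_init in range(TILE):  # M_init
        y[j1_init] = 0.0
    for k_0 in range(K // 4):  # M_update
        for k_1 in range(4):
            for j1 in range(TILE):
                k = k_0 * 4 + k_1
                y[j1] += a_row[k] * B[k][j0 * TILE + j1]
    return y


rng = random.Random(0)
a_row = [rng.uniform(-1, 1) for _ in range(K)]
B = [[rng.uniform(-1, 1) for _ in range(128)] for _ in range(K)]
assert all(tile_inline_init(a_row, B, j0) == tile_decomposed(a_row, B, j0)
           for j0 in range(16))
```

With the initialization in its own loop, that loop can be vectorized independently of the reduction, which is what the final step relies on.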

The final step is easy: just apply vectorize/parallel/unroll to the corresponding loops:

n, i, j_0, j_1_init = sch.get_loops(M_init)
sch.vectorize(j_1_init)

n, i, j_0, i2_1 = sch.get_loops(block_R)
sch.vectorize(i2_1)

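block_M_update = sch.get_block("M_update", func_name="bmm_relu")
n, i, j_0, k_0, k_1, j_1 = sch.get_loops(block_M_update)
# Unroll the inner reduction loop and parallelize over the batch dimension,
# matching T.unroll(4) and T.parallel(16) in the target module.
sch.unroll(k_1)
sch.parallel(n)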

Print out the final TensorIR to find out its final form ( ͡❛ ͜ʖ ͡❛).