Skip to content
118 changes: 72 additions & 46 deletions iron/operators/transpose/design.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from aie.iron.controlflow import range_


def shuffle_transpose(dev, M, N, num_columns, num_channels, m, n, s, func_prefix=""):
def shuffle_transpose(
dev, M, N, num_columns, num_channels, m, n, s, num_batches=1, func_prefix=""
):
num_elements = M * N
per_tile_elements = m * n
dtype = bfloat16
Expand All @@ -34,8 +36,9 @@ def shuffle_transpose(dev, M, N, num_columns, num_channels, m, n, s, func_prefix
if s == 8 and (m <= 16 or n <= 16):
raise ValueError(f"Kernel tile {s} needs AIE tile rows > 16 and columns > 16.")

# Define tensor types
tensor_ty = np.ndarray[(num_elements,), np.dtype[dtype]]
# Define tensor types. The runtime tensor spans all batches (contiguous matrices);
# per-tile work on the cores is identical regardless of batch count.
tensor_ty = np.ndarray[(num_batches * num_elements,), np.dtype[dtype]]
tile_ty = np.ndarray[(per_tile_elements,), np.dtype[dtype]]

fifodepth = 1 if per_tile_elements > 4096 else 2
Expand All @@ -47,13 +50,25 @@ def shuffle_transpose(dev, M, N, num_columns, num_channels, m, n, s, func_prefix
# and channels. Partially transposes the input
# data so that the kernel only needs to
# transpose s*s-sized sub-tiles.
# The L3 tensors hold num_batches contiguous (M,N) matrices stacked along the row
# dimension: in-dims (num_batches*M, N), out-dims (num_batches*N, M); at num_batches==1
# these are simply (M,N)/(N,M). Each (i,j) column/channel emits one TAP per batch, offset
# by batch*num_elements; the per-batch internal sizes/strides are the same for every batch
# because each matrix is contiguous and row-major.
in_dims = (num_batches * M, N)
out_dims = (num_batches * N, M)
taps_in_L3L2 = [
TensorAccessPattern(
(M, N),
(M // num_channels) * j * N + (N // num_columns) * i,
[M // num_channels // m, N // num_columns // n, m, n],
[m * N, n, N, 1],
)
[
TensorAccessPattern(
in_dims,
batch * num_elements
+ (M // num_channels) * j * N
+ (N // num_columns) * i,
[M // num_channels // m, N // num_columns // n, m, n],
[m * N, n, N, 1],
)
for batch in range(num_batches)
]
for i in range(num_columns)
for j in range(num_channels)
]
Expand All @@ -68,12 +83,17 @@ def shuffle_transpose(dev, M, N, num_columns, num_channels, m, n, s, func_prefix
for j in range(num_channels)
]
taps_out_L1L3 = [
TensorAccessPattern(
(N, M),
(N // num_columns) * i * M + (M // num_channels) * j,
[M // num_channels // m, N // num_columns // n, n, m],
[m, n * M, M, 1],
)
[
TensorAccessPattern(
out_dims,
batch * num_elements
+ (N // num_columns) * i * M
+ (M // num_channels) * j,
[M // num_channels // m, N // num_columns // n, n, m],
[m, n * M, M, 1],
)
for batch in range(num_batches)
]
for i in range(num_columns)
for j in range(num_channels)
]
Expand Down Expand Up @@ -106,14 +126,17 @@ def shuffle_transpose(dev, M, N, num_columns, num_channels, m, n, s, func_prefix

# Define a task that will run on a compute tile
def core_body(of_in1, of_out, transpose_kernel):
    """Run the transpose kernel once per tile, for every batch.

    Args:
        of_in1: Input ObjectFifo handle; one element is acquired per tile.
        of_out: Output ObjectFifo handle; one element is acquired per tile.
        transpose_kernel: Callable invoked as ``transpose_kernel(in, out)``.

    The batch count and the per-matrix tile counts are taken from the
    enclosing scope (``num_batches``, ``M``, ``N``, ``m``, ``n``,
    ``num_columns``, ``num_channels``).
    """
    # All batches stream through the same pair of FIFOs, so the batch loop
    # simply repeats the per-matrix tile iteration num_batches times.
    for _ in range_(num_batches):
        # Number of sub-matrix "tile" iterations for one matrix.
        for _ in range_(N // n // num_columns):
            for _ in range_(M // m // num_channels):
                elem_in1 = of_in1.acquire(1)
                elem_out = of_out.acquire(1)
                transpose_kernel(elem_in1, elem_out)
                # Release in the same order as before: output first, then input.
                of_out.release(1)
                of_in1.release(1)

# Create a worker to run the task on a compute tile
my_workers = [
Expand All @@ -134,29 +157,32 @@ def core_body(of_in1, of_out, transpose_kernel):
with rt.sequence(tensor_ty, tensor_ty) as (A, C):
rt.start(*my_workers)

# Initialize a group for parallel drain tasks, with fill resources free'd when drains complete.
tg = rt.task_group()

# Fill the input objectFIFOs with data
for i in range(num_columns):
for j in range(num_channels):
rt.fill(
of_in1s_L3L2[i * num_channels + j].prod(),
A,
taps_in_L3L2[i * num_channels + j],
task_group=tg,
)
# Drain the output objectFIFOs with data
for i in range(num_columns):
for j in range(num_channels):
rt.drain(
of_outs[i * num_channels + j].cons(),
C,
taps_out_L1L3[i * num_channels + j],
wait=True, # wait for the transfer to complete and data to be available
task_group=tg,
)
rt.finish_task_group(tg)
# One task group per batch (each a parallel fill+drain over all columns/channels), so the
# num_batches contiguous matrices stream through the same FIFOs in sequence.
for batch in range(num_batches):
# Initialize a group for parallel drain tasks, with fill resources freed when drains complete.
tg = rt.task_group()

# Fill the input objectFIFOs with data
for i in range(num_columns):
for j in range(num_channels):
rt.fill(
of_in1s_L3L2[i * num_channels + j].prod(),
A,
taps_in_L3L2[i * num_channels + j][batch],
task_group=tg,
)
# Drain the output objectFIFOs of data
for i in range(num_columns):
for j in range(num_channels):
rt.drain(
of_outs[i * num_channels + j].cons(),
C,
taps_out_L1L3[i * num_channels + j][batch],
wait=True, # wait for the transfer to complete and data to be available
task_group=tg,
)
rt.finish_task_group(tg)

# Place program components (assign them resources on the device) and generate an MLIR module
return Program(dev, rt).resolve_program(SequentialPlacer())
21 changes: 18 additions & 3 deletions iron/operators/transpose/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field
from typing import ClassVar, Dict

import aie.utils as aie_utils
from iron.common import (
Expand All @@ -16,7 +17,13 @@

@dataclass
class Transpose(MLIROperator):
"""AIE-accelerated transpose operator"""
"""AIE-accelerated transpose operator.

``num_batches`` > 1 performs that many independent (M,N)->(N,M) transposes on
contiguous matrices laid back-to-back in memory (results concatenated), mirroring
GEMV's batching — the per-batch tile work rides the same ObjectFifos, so B batched
transposes cost ONE dispatch instead of B unrolled ones.
"""

M: int
N: int
Expand All @@ -25,8 +32,14 @@ class Transpose(MLIROperator):
m: int
n: int
s: int
num_batches: int = 1
context: object = field(default=None, repr=False)

_name_aliases: ClassVar[Dict[str, str]] = {
**MLIROperator._name_aliases,
"num_batches": "batch",
}

def __post_init__(self):
if self.M % self.m != 0:
raise ValueError(f"Matrix rows ({self.M}) must be a multiple of {self.m}")
Expand Down Expand Up @@ -66,6 +79,7 @@ def get_mlir_artifact(self):
self.m,
self.n,
self.s,
self.num_batches,
),
),
)
Expand All @@ -90,7 +104,8 @@ def get_kernel_artifacts(self):
]

def get_arg_spec(self):
    """Return the runtime argument specs: the input buffer, then the output.

    Both buffers hold ``M * N`` elements per matrix. When ``num_batches`` is
    greater than 1 a leading batch dimension of that size is prepended;
    otherwise the shape stays one-dimensional.
    """
    # A transpose does not change the element count, so input and output
    # share the same flattened per-matrix size.
    flat = (self.M * self.N,)
    shape = (self.num_batches,) + flat if self.num_batches > 1 else flat
    return [
        AIERuntimeArgSpec("in", shape),
        AIERuntimeArgSpec("out", shape),
    ]
17 changes: 14 additions & 3 deletions iron/operators/transpose/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,20 @@
from iron.common.test_utils import torch_dtype_map


def generate_golden_reference(
    rows: int, cols: int, dtype="bf16", seed=42, num_batches=1
):
    """Build a random input and its transposed golden output.

    Args:
        rows: Row count of each input matrix.
        cols: Column count of each input matrix.
        dtype: Key looked up in ``torch_dtype_map`` to pick the torch dtype.
        seed: Value passed to ``torch.manual_seed`` before sampling.
        num_batches: Number of independent matrices to generate.

    Returns:
        A dict with keys ``"input"`` and ``"output"``. With ``num_batches == 1``
        the tensors have shapes ``(rows, cols)`` and ``(cols, rows)``; otherwise
        ``(num_batches, rows, cols)`` and ``(num_batches, cols, rows)``.

    Raises:
        ValueError: If ``num_batches`` is smaller than 1.
    """
    if num_batches < 1:
        raise ValueError(f"num_batches must be >= 1, got {num_batches}")

    torch.manual_seed(seed)
    val_range = 4
    input_tensor = (
        torch.rand(num_batches, rows, cols, dtype=torch_dtype_map[dtype]) * val_range
    )
    # Swapping the last two axes transposes every matrix independently while
    # keeping batch order; contiguous() materialises the result so the output
    # is laid out as back-to-back (cols, rows) matrices.
    output_tensor = input_tensor.transpose(1, 2).contiguous()
    # Drop the batch dimension if num_batches == 1 (squeeze is a no-op otherwise).
    input_tensor = torch.squeeze(input_tensor, 0)
    output_tensor = torch.squeeze(output_tensor, 0)
    return {"input": input_tensor, "output": output_tensor}
26 changes: 23 additions & 3 deletions iron/operators/transpose/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,39 @@ def get_params():
m,
n,
s,
1,
marks=marks,
)
)

# num_batches>1: batch B independent same-shape transposes into one dispatch
# (regular shape, single column/channel). num_batches=2 runs in the default
# suite; the larger batch is extensive.
for nb in (2, 4):
params.append(
pytest.param(
2048,
64,
1,
1,
m,
n,
8,
nb,
marks=[] if nb == 2 else [pytest.mark.extensive],
)
)

return params


@pytest.mark.metrics(
Latency=r"Latency \(us\): (?P<value>[\d\.]+)",
Bandwidth=r"Effective Bandwidth: (?P<value>[\d\.e\+-]+) GB/s",
)
@pytest.mark.parametrize("M,N,aie_columns,channels,m,n,s", get_params())
def test_transpose(M, N, aie_columns, channels, m, n, s, aie_context):
golden_ref = generate_golden_reference(rows=M, cols=N)
@pytest.mark.parametrize("M,N,aie_columns,channels,m,n,s,num_batches", get_params())
def test_transpose(M, N, aie_columns, channels, m, n, s, num_batches, aie_context):
golden_ref = generate_golden_reference(rows=M, cols=N, num_batches=num_batches)

operator = Transpose(
M=M,
Expand All @@ -70,6 +89,7 @@ def test_transpose(M, N, aie_columns, channels, m, n, s, aie_context):
m=m,
n=n,
s=s,
num_batches=num_batches,
context=aie_context,
)

Expand Down