From 4826f26722315a5e82527818f0045c33f14d7dcc Mon Sep 17 00:00:00 2001 From: d-kor Date: Wed, 8 Apr 2026 13:54:44 +0000 Subject: [PATCH 01/17] Initial MVAU-Tiling commit into dev. - Added MM infrastructure for tiling. - Fetch weight fixes. --- finn-rtllib/dwc/hdl/axis_dwc.sv | 90 ++++ finn-rtllib/fetch_weights/fetch_weights.sv | 318 ++++++++++++ .../fetch_weights/fetch_weights_wrapper.v | 219 ++++++++ .../fetch_weights/local_weight_buffer.sv | 305 ++++++++++++ finn-rtllib/mvu_tiled/acc_stage.sv | 172 +++++++ finn-rtllib/mvu_tiled/add_tree.sv | 145 ++++++ finn-rtllib/mvu_tiled/cu_mvau_tiled.sv | 295 +++++++++++ finn-rtllib/mvu_tiled/mvu_tiled_axi.sv | 279 +++++++++++ finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v | 97 ++++ finn-rtllib/mvu_tiled/reorder_out.sv | 341 +++++++++++++ finn-rtllib/mvu_tiled/replay_buff_tile.sv | 392 +++++++++++++++ finn-rtllib/mvu_tiled/weights_buff_tile.sv | 260 ++++++++++ finn-rtllib/ram/ram_p_c.sv | 7 +- finn_xsi/finn_xsi/sim_engine.py | 2 +- src/finn/core/rtlsim_exec.py | 18 + src/finn/custom_op/fpgadataflow/hwcustomop.py | 14 +- .../fpgadataflow/matrixvectoractivation.py | 468 ++++++++++-------- .../rtl/matrixvectoractivation_rtl.py | 135 +++-- .../fpgadataflow/convert_to_hw_layers.py | 4 +- .../fpgadataflow/create_stitched_ip.py | 116 ++--- .../fpgadataflow/set_fifo_depths.py | 96 +++- .../test_fpgadataflow_finnloop.py | 3 + tests/fpgadataflow/test_fpgadataflow_mvau.py | 103 +++- 23 files changed, 3529 insertions(+), 350 deletions(-) create mode 100644 finn-rtllib/dwc/hdl/axis_dwc.sv create mode 100644 finn-rtllib/fetch_weights/fetch_weights.sv create mode 100644 finn-rtllib/fetch_weights/fetch_weights_wrapper.v create mode 100644 finn-rtllib/fetch_weights/local_weight_buffer.sv create mode 100644 finn-rtllib/mvu_tiled/acc_stage.sv create mode 100644 finn-rtllib/mvu_tiled/add_tree.sv create mode 100644 finn-rtllib/mvu_tiled/cu_mvau_tiled.sv create mode 100644 finn-rtllib/mvu_tiled/mvu_tiled_axi.sv create mode 100644 finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v create mode 100644 finn-rtllib/mvu_tiled/reorder_out.sv create mode 100644 finn-rtllib/mvu_tiled/replay_buff_tile.sv create mode 100644 finn-rtllib/mvu_tiled/weights_buff_tile.sv diff --git a/finn-rtllib/dwc/hdl/axis_dwc.sv b/finn-rtllib/dwc/hdl/axis_dwc.sv new file mode 100644 index 0000000000..a482ebd5c9 --- /dev/null +++ b/finn-rtllib/dwc/hdl/axis_dwc.sv @@ -0,0 +1,90 @@ +// Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +// +// This file is subject to the Xilinx Design License Agreement located +// in the LICENSE.md file in the root directory of this repository. +// +// This file contains confidential and proprietary information of Xilinx, Inc. +// and is protected under U.S. and international copyright and other +// intellectual property laws. +// +// DISCLAIMER +// This disclaimer is not a license and does not grant any rights to the materials +// distributed herewith. Except as otherwise provided in a valid license issued to +// you by Xilinx, and to the maximum extent permitted by applicable law: (1) THESE +// MATERIALS ARE MADE AVAILABLE "AS IS" AND WITH ALL FAULTS, AND XILINX HEREBY +// DISCLAIMS ALL WARRANTIES AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, +// INCLUDING BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NONINFRINGEMENT, OR +// FITNESS FOR ANY PARTICULAR PURPOSE; and (2) Xilinx shall not be liable (whether +// in contract or tort, including negligence, or under any other theory of +// liability) for any loss or damage of any kind or nature related to, arising +// under or in connection with these materials, including for any direct, or any +// indirect, special, incidental, or consequential loss or damage (including loss +// of data, profits, goodwill, or any type of loss or damage suffered as a result +// of any action brought by a third party) even if such damage or loss was +// reasonably foreseeable or Xilinx had been advised of the possibility of the +// same. +// +// CRITICAL APPLICATIONS +// Xilinx products are not designed or intended to be fail-safe, or for use in +// any application requiring failsafe performance, such as life-support or safety +// devices or systems, Class III medical devices, nuclear facilities, applications +// related to the deployment of airbags, or any other applications that could lead +// to death, personal injury, or severe property or environmental damage +// (individually and collectively, "Critical Applications"). Customer assumes the +// sole risk and liability of any use of Xilinx products in Critical Applications, +// subject only to applicable laws and regulations governing limitations on product +// liability. +// +// THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS PART OF THIS FILE AT ALL TIMES. + +module axis_dwc #( + parameter integer DEPTH = 512, + parameter integer S_DATA_BITS = 32, + parameter integer M_DATA_BITS = 8 +) ( + input logic aclk, + input logic aresetn, + + input logic s_axis_tvalid, + output logic s_axis_tready, + input logic [S_DATA_BITS-1:0] s_axis_tdata, + input logic [S_DATA_BITS/8-1:0] s_axis_tkeep, + input logic s_axis_tlast, + + output logic m_axis_tvalid, + input logic m_axis_tready, + output logic [M_DATA_BITS-1:0] m_axis_tdata, + output logic [M_DATA_BITS/8-1:0] m_axis_tkeep, + output logic m_axis_tlast +); + +axis_fifo_adapter #( + .DEPTH(DEPTH), + .S_DATA_WIDTH(S_DATA_BITS), + .M_DATA_WIDTH(M_DATA_BITS) +) inst_fifo_adapter ( + .clk (aclk), + .rst (~aresetn), + + .s_axis_tdata (s_axis_tdata), + .s_axis_tkeep (s_axis_tkeep), + .s_axis_tvalid (s_axis_tvalid), + .s_axis_tready (s_axis_tready), + .s_axis_tlast (s_axis_tlast), + .s_axis_tid ('0), + .s_axis_tdest ('0), + .s_axis_tuser ('0), + + .pause_req('0), + + .m_axis_tdata (m_axis_tdata), + .m_axis_tkeep (m_axis_tkeep), + .m_axis_tvalid (m_axis_tvalid), + .m_axis_tready (m_axis_tready), + .m_axis_tlast (m_axis_tlast), + .m_axis_tid (), + .m_axis_tdest (), + .m_axis_tuser () +); + +endmodule diff --git a/finn-rtllib/fetch_weights/fetch_weights.sv b/finn-rtllib/fetch_weights/fetch_weights.sv new file mode 100644 index 0000000000..e5b740c164 --- /dev/null +++ b/finn-rtllib/fetch_weights/fetch_weights.sv @@ -0,0 +1,318 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module fetch_weights #( + int unsigned PE, + int unsigned SIMD, + int unsigned TH = 1, + int unsigned MH, + int unsigned MW, + int unsigned N_REPS, + int unsigned WEIGHT_WIDTH = 8, + + int unsigned ADDR_BITS = 64, + int unsigned DATA_BITS = 256, + int unsigned LEN_BITS = 32, + int unsigned IDX_BITS = 16, + + int unsigned N_LAYERS, + + int unsigned EN_MLO = 1, + + int unsigned QDEPTH = 8, + int unsigned EN_OREG = 1, + int unsigned N_DCPL_STGS = 1, + int unsigned DBG = 0, + + // Safely deducible parameters + int unsigned IWSIMD = (TH > 1) ? ((PE*SIMD)/TH) : SIMD, + int unsigned OWSIMD = (PE * SIMD) / TH, + int unsigned DS_BITS_BA = (IWSIMD*WEIGHT_WIDTH+7)/8 * 8, + int unsigned WS_BITS_BA = (OWSIMD*WEIGHT_WIDTH+7)/8 * 8, + logic[ADDR_BITS-1:0] LAYER_OFFS = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7 // 8-byte aligned +) ( + input wire aclk, + input wire aresetn, + + output logic m_done, + + // AXI + output logic[ADDR_BITS-1:0] m_axi_ddr_araddr, + output logic[1:0] m_axi_ddr_arburst, + output logic[3:0] m_axi_ddr_arcache, + output logic[1:0] m_axi_ddr_arid, + output logic[7:0] m_axi_ddr_arlen, + output logic[0:0] m_axi_ddr_arlock, + output logic[2:0] m_axi_ddr_arprot, + output logic[2:0] m_axi_ddr_arsize, + input logic m_axi_ddr_arready, + output logic m_axi_ddr_arvalid, + output logic[ADDR_BITS-1:0] m_axi_ddr_awaddr, + output logic[1:0] m_axi_ddr_awburst, + output logic[3:0] m_axi_ddr_awcache, + output logic[1:0] m_axi_ddr_awid, + output logic[7:0] m_axi_ddr_awlen, + output logic[0:0] m_axi_ddr_awlock, + output logic[2:0] m_axi_ddr_awprot, + output logic[2:0] m_axi_ddr_awsize, + input logic m_axi_ddr_awready, + output logic m_axi_ddr_awvalid, + input logic[DATA_BITS-1:0] m_axi_ddr_rdata, + input logic[1:0] m_axi_ddr_rid, + input logic m_axi_ddr_rlast, + input logic[1:0] m_axi_ddr_rresp, + output logic m_axi_ddr_rready, + input logic m_axi_ddr_rvalid, + output logic[DATA_BITS-1:0] m_axi_ddr_wdata, + output logic m_axi_ddr_wlast, + output logic[DATA_BITS/8-1:0] m_axi_ddr_wstrb, + input logic m_axi_ddr_wready, + output logic m_axi_ddr_wvalid, + input logic[1:0] m_axi_ddr_bid, + input logic[1:0] m_axi_ddr_bresp, + output logic m_axi_ddr_bready, + input logic m_axi_ddr_bvalid, + + // Index + input logic s_idx_tvalid, + output logic s_idx_tready, + input logic[IDX_BITS-1:0] s_idx_tdata, + + // DMA stream out (to external width converter) + output logic axis_dma_tvalid, + input logic axis_dma_tready, + output logic[DATA_BITS-1:0] axis_dma_tdata, + output logic[DATA_BITS/8-1:0] axis_dma_tkeep, + output logic axis_dma_tlast, + + // DWC stream in (from external width converter) + input logic axis_dwc_tvalid, + output logic axis_dwc_tready, + input logic[DS_BITS_BA-1:0] axis_dwc_tdata, + input logic[(DS_BITS_BA)/8-1:0] axis_dwc_tkeep, + input logic axis_dwc_tlast, + + // Stream + // TODO: Should we reg this? Would be quite wide ... + output logic m_axis_tvalid, + input logic m_axis_tready, + output logic[WS_BITS_BA-1:0] m_axis_tdata +); + +// Offsets +logic [N_LAYERS-1:0][ADDR_BITS-1:0] l_offsets; +for(genvar i = 0; i < N_LAYERS; i++) begin + assign l_offsets[i] = (i * LAYER_OFFS); +end + +// +// Indexes and DMA +// + +logic dma_tvalid; +logic dma_tready; +logic [ADDR_BITS-1:0] dma_addr; +logic [LEN_BITS-1:0] dma_len; + +if(TH > 1) begin + + // Consts + localparam integer REPS_BITS = (N_REPS == 1) ? 1 : $clog2(N_REPS); + + // Reps + typedef enum logic[0:0] {ST_IDLE, ST_DMA} state_t; + state_t state_C = ST_IDLE, state_N; + + logic [REPS_BITS-1:0] cnt_dma_C = '0, cnt_dma_N; + logic [IDX_BITS-1:0] idx_C = '0, idx_N; + + logic q_idx_out_tvalid, q_idx_out_tready; + logic [IDX_BITS-1:0] q_idx_out_tdata; + + // Idx queue + Q_srl #( + .depth(QDEPTH), + .width(IDX_BITS) + ) inst_queue_in ( + .clock(aclk), .reset(!aresetn), + .count(), .maxcount(), + .i_d(s_idx_tdata), .i_v(s_idx_tvalid), .i_r(s_idx_tready), + .o_d(q_idx_out_tdata), .o_v(q_idx_out_tvalid), .o_r(q_idx_out_tready) + ); + + assign dma_addr = l_offsets[idx_C]; + assign dma_len = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7; + + always_ff @( posedge aclk ) begin: REG + if(~aresetn) begin + state_C <= ST_IDLE; + + cnt_dma_C <= '0; + idx_C <= 'X; + end else begin + state_C <= state_N; + + cnt_dma_C <= cnt_dma_N; + idx_C <= idx_N; + end + end + + always_comb begin: NSL + state_N = state_C; + + case (state_C) + ST_IDLE: + state_N = q_idx_out_tvalid ? ST_DMA : ST_IDLE; + + ST_DMA: + state_N = (cnt_dma_C == N_REPS-1) && dma_tready ? ST_IDLE : ST_DMA; + + endcase + end + + always_comb begin: DP + cnt_dma_N = cnt_dma_C; + idx_N = idx_C; + + q_idx_out_tready = 1'b0; + dma_tvalid = 1'b0; + + case (state_C) + ST_IDLE: begin + q_idx_out_tready = 1'b1; + cnt_dma_N = 0; + if(q_idx_out_tvalid) begin + idx_N = q_idx_out_tdata; + end + end + + ST_DMA: begin + dma_tvalid = 1'b1; + if(dma_tready) begin + cnt_dma_N = cnt_dma_C + 1; + end + end + + endcase + end + +end else begin + + // Idx queue + logic [IDX_BITS-1:0] q_idx_out_tdata; + + Q_srl #( + .depth(QDEPTH), + .width(IDX_BITS) + ) inst_idx_queue ( + .clock(aclk), .reset(!aresetn), + .count(), .maxcount(), + .i_d(s_idx_tdata), .i_v(s_idx_tvalid), .i_r(s_idx_tready), + .o_d(q_idx_out_tdata), .o_v(dma_tvalid), .o_r(dma_tready) + ); + + assign dma_addr = l_offsets[q_idx_out_tdata]; + assign dma_len = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7; + +end + +cdma_u_rd #( + .DATA_BITS(DATA_BITS), + .ADDR_BITS(ADDR_BITS), + .LEN_BITS(LEN_BITS) +) inst_dma ( + .aclk(aclk), .aresetn(aresetn), + + .rd_valid(dma_tvalid), .rd_ready(dma_tready), + .rd_paddr(dma_addr), .rd_len(dma_len), + .rd_done(m_done), + + .m_axi_ddr_arvalid(m_axi_ddr_arvalid), + .m_axi_ddr_arready(m_axi_ddr_arready), + .m_axi_ddr_araddr(m_axi_ddr_araddr), + .m_axi_ddr_arid(m_axi_ddr_arid), + .m_axi_ddr_arlen(m_axi_ddr_arlen), + .m_axi_ddr_arsize(m_axi_ddr_arsize), + .m_axi_ddr_arburst(m_axi_ddr_arburst), + .m_axi_ddr_arlock(m_axi_ddr_arlock), + .m_axi_ddr_arcache(m_axi_ddr_arcache), + .m_axi_ddr_arprot(m_axi_ddr_arprot), + .m_axi_ddr_rvalid(m_axi_ddr_rvalid), + .m_axi_ddr_rready(m_axi_ddr_rready), + .m_axi_ddr_rdata(m_axi_ddr_rdata), + .m_axi_ddr_rlast(m_axi_ddr_rlast), + .m_axi_ddr_rid(m_axi_ddr_rid), + .m_axi_ddr_rresp(m_axi_ddr_rresp), + + .m_axis_ddr_tvalid(axis_dma_tvalid), + .m_axis_ddr_tready(axis_dma_tready), + .m_axis_ddr_tdata(axis_dma_tdata), + .m_axis_ddr_tkeep(axis_dma_tkeep), + .m_axis_ddr_tlast(axis_dma_tlast) +); + +// Local weight buffer +// Only for non-tiled nodes +logic axis_lwb_tvalid; +logic axis_lwb_tready; +logic[WS_BITS_BA-1:0] axis_lwb_tdata; + +if(TH == 1) begin + local_weight_buffer #( + .PE(PE), .SIMD(SIMD), .MH(MH), .MW(MW), .N_REPS(N_REPS), .WEIGHT_WIDTH(WEIGHT_WIDTH), .DBG(DBG) + ) inst_weight_buff ( + .clk(aclk), .rst(~aresetn), + .ivld(axis_dwc_tvalid), .irdy(axis_dwc_tready), .idat(axis_dwc_tdata), + .ovld(axis_lwb_tvalid), .ordy(axis_lwb_tready), .odat(axis_lwb_tdata) + ); +end else begin + assign axis_lwb_tvalid = axis_dwc_tvalid; + assign axis_dwc_tready = axis_lwb_tready; + assign axis_lwb_tdata = axis_dwc_tdata; +end + +// Reg slice +if(EN_OREG) begin + skid #( + .DATA_WIDTH(WS_BITS_BA), .FEED_STAGES(N_DCPL_STGS) + ) inst_oreg ( + .clk(aclk), .rst(!aresetn), + .ivld(axis_lwb_tvalid), .irdy(axis_lwb_tready), .idat(axis_lwb_tdata), + .ovld(m_axis_tvalid), .ordy(m_axis_tready), .odat(m_axis_tdata) + ); +end else begin + assign m_axis_tvalid = axis_lwb_tvalid; + assign axis_lwb_tready = m_axis_tready; + assign m_axis_tdata = axis_lwb_tdata; +end + +endmodule diff --git a/finn-rtllib/fetch_weights/fetch_weights_wrapper.v b/finn-rtllib/fetch_weights/fetch_weights_wrapper.v new file mode 100644 index 0000000000..06fa66031d --- /dev/null +++ b/finn-rtllib/fetch_weights/fetch_weights_wrapper.v @@ -0,0 +1,219 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * @brief Verilog AXI-lite wrapper for MVU & VVU. + *****************************************************************************/ + +`define $EN_MLO$ + +module $MODULE_NAME_AXI_WRAPPER$ #( + parameter MW = $MW$, + parameter MH = $MH$, + parameter PE = $PE$, + parameter SIMD = $SIMD$, + parameter TH = $TH$, + parameter N_REPS = $N_REPS$, + parameter WEIGHT_WIDTH = $WEIGHT_WIDTH$, + parameter N_LAYERS = $N_LAYERS$, + + parameter ADDR_BITS = 64, + parameter DATA_BITS = 256, + parameter LEN_BITS = 32, + parameter IDX_BITS = 16, + + // Safely deducible parameters + parameter IWSIMD = (TH > 1) ? ((PE*SIMD)/TH) : SIMD, + parameter WSIMD = (PE * SIMD) / TH, + parameter DS_BITS_BA = (IWSIMD*WEIGHT_WIDTH+7)/8 * 8, + parameter WS_BITS_BA = (WSIMD*WEIGHT_WIDTH+7)/8 * 8 +)( + // Global Control + (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF axi_mm:in_idx0_V:out0_V, ASSOCIATED_RESET ap_rst_n" *) + (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *) + input ap_clk, + (* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *) + input ap_rst_n, + + // Completion + output wire out_done, + + // AXI + (* X_INTERFACE_INFO = "xilinx.com:interface:aximm:1.0 axi_mm" *) + output wire[ADDR_BITS-1:0] axi_mm_araddr, + output wire[1:0] axi_mm_arburst, + output wire[3:0] axi_mm_arcache, + output wire[1:0] axi_mm_arid, + output wire[7:0] axi_mm_arlen, + output wire[0:0] axi_mm_arlock, + output wire[2:0] axi_mm_arprot, + output wire[2:0] axi_mm_arsize, + input wire axi_mm_arready, + output wire axi_mm_arvalid, + output wire[ADDR_BITS-1:0] axi_mm_awaddr, + output wire[1:0] axi_mm_awburst, + output wire[3:0] axi_mm_awcache, + output wire[1:0] axi_mm_awid, + output wire[7:0] axi_mm_awlen, + output wire[0:0] axi_mm_awlock, + output wire[2:0] axi_mm_awprot, + output wire[2:0] axi_mm_awsize, + input wire axi_mm_awready, + output wire axi_mm_awvalid, + input wire[DATA_BITS-1:0] axi_mm_rdata, + input wire[1:0] axi_mm_rid, + input wire axi_mm_rlast, + input wire[1:0] axi_mm_rresp, + output wire axi_mm_rready, + input wire axi_mm_rvalid, + output wire[DATA_BITS-1:0] axi_mm_wdata, + output wire axi_mm_wlast, + output wire[DATA_BITS/8-1:0] axi_mm_wstrb, + input wire axi_mm_wready, + output wire axi_mm_wvalid, + input wire[1:0] axi_mm_bid, + input wire[1:0] axi_mm_bresp, + output wire axi_mm_bready, + input wire axi_mm_bvalid, + +`ifdef EN_MLO + // Index + input wire in_idx0_V_tvalid, + output wire in_idx0_V_tready, + input wire[IDX_BITS-1:0] in_idx0_V_tdata, +`endif + + // Stream + output wire out0_V_tvalid, + input wire out0_V_tready, + output wire[WS_BITS_BA-1:0] out0_V_tdata +); + +`ifndef EN_MLO + wire in_idx0_V_tvalid; + wire in_idx0_V_tready; + wire [IDX_BITS-1:0] in_idx0_V_tdata; + + assign in_idx0_V_tvalid = 1'b1; + assign in_idx0_V_tdata = 0; +`endif + +// DMA <-> DWC internal wires +wire axis_dma_tvalid; +wire axis_dma_tready; +wire [DATA_BITS-1:0] axis_dma_tdata; +wire [DATA_BITS/8-1:0] axis_dma_tkeep; +wire axis_dma_tlast; + +wire axis_dwc_tvalid; +wire axis_dwc_tready; +wire [DS_BITS_BA-1:0] axis_dwc_tdata; +wire [(DS_BITS_BA)/8-1:0] axis_dwc_tkeep; +wire axis_dwc_tlast; + +// Width converter +$DWC_MODULE_NAME$ inst_dwc ( + .aclk(ap_clk), .aresetn(ap_rst_n), + .s_axis_tvalid(axis_dma_tvalid), .s_axis_tready(axis_dma_tready), .s_axis_tdata(axis_dma_tdata), .s_axis_tkeep(axis_dma_tkeep), .s_axis_tlast(axis_dma_tlast), + .m_axis_tvalid(axis_dwc_tvalid), .m_axis_tready(axis_dwc_tready), .m_axis_tdata(axis_dwc_tdata), .m_axis_tkeep(axis_dwc_tkeep), .m_axis_tlast(axis_dwc_tlast) +); + +fetch_weights #( + .PE(PE), .SIMD(SIMD), .TH(TH), + .MH(MH), .MW(MW), .N_REPS(N_REPS), + .WEIGHT_WIDTH(WEIGHT_WIDTH), + .ADDR_BITS(ADDR_BITS), .DATA_BITS(DATA_BITS), .LEN_BITS(LEN_BITS), .IDX_BITS(IDX_BITS), + .N_LAYERS(N_LAYERS), +`ifdef EN_MLO + .EN_MLO(1) +`else + .EN_MLO(0) +`endif +) inst ( + .aclk (ap_clk), + .aresetn (ap_rst_n), + + .m_axi_ddr_araddr (axi_mm_araddr), + .m_axi_ddr_arburst (axi_mm_arburst), + .m_axi_ddr_arcache (axi_mm_arcache), + .m_axi_ddr_arid (axi_mm_arid), + .m_axi_ddr_arlen (axi_mm_arlen), + .m_axi_ddr_arlock (axi_mm_arlock), + .m_axi_ddr_arprot (axi_mm_arprot), + .m_axi_ddr_arsize (axi_mm_arsize), + .m_axi_ddr_arready (axi_mm_arready), + .m_axi_ddr_arvalid (axi_mm_arvalid), + .m_axi_ddr_awaddr (axi_mm_awaddr), + .m_axi_ddr_awburst (axi_mm_awburst), + .m_axi_ddr_awcache (axi_mm_awcache), + .m_axi_ddr_awid (axi_mm_awid), + .m_axi_ddr_awlen (axi_mm_awlen), + .m_axi_ddr_awlock (axi_mm_awlock), + .m_axi_ddr_awprot (axi_mm_awprot), + .m_axi_ddr_awsize (axi_mm_awsize), + .m_axi_ddr_awready (axi_mm_awready), + .m_axi_ddr_awvalid (axi_mm_awvalid), + .m_axi_ddr_rdata (axi_mm_rdata), + .m_axi_ddr_rid (axi_mm_rid), + .m_axi_ddr_rlast (axi_mm_rlast), + .m_axi_ddr_rresp (axi_mm_rresp), + .m_axi_ddr_rready (axi_mm_rready), + .m_axi_ddr_rvalid (axi_mm_rvalid), + .m_axi_ddr_wdata (axi_mm_wdata), + .m_axi_ddr_wlast (axi_mm_wlast), + .m_axi_ddr_wstrb (axi_mm_wstrb), + .m_axi_ddr_wready (axi_mm_wready), + .m_axi_ddr_wvalid (axi_mm_wvalid), + .m_axi_ddr_bid (axi_mm_bid), + .m_axi_ddr_bresp (axi_mm_bresp), + .m_axi_ddr_bready (axi_mm_bready), + .m_axi_ddr_bvalid (axi_mm_bvalid), + + .s_idx_tvalid (in_idx0_V_tvalid), + .s_idx_tready (in_idx0_V_tready), + .s_idx_tdata (in_idx0_V_tdata), + + .axis_dma_tvalid (axis_dma_tvalid), + .axis_dma_tready (axis_dma_tready), + .axis_dma_tdata (axis_dma_tdata), + .axis_dma_tkeep (axis_dma_tkeep), + .axis_dma_tlast (axis_dma_tlast), + + .axis_dwc_tvalid (axis_dwc_tvalid), + .axis_dwc_tready (axis_dwc_tready), + .axis_dwc_tdata (axis_dwc_tdata), + .axis_dwc_tkeep (axis_dwc_tkeep), + .axis_dwc_tlast (axis_dwc_tlast), + + .m_axis_tvalid (out0_V_tvalid), + .m_axis_tready (out0_V_tready), + .m_axis_tdata (out0_V_tdata) +); + +endmodule // $MODULE_NAME_AXI_WRAPPER$ diff --git a/finn-rtllib/fetch_weights/local_weight_buffer.sv b/finn-rtllib/fetch_weights/local_weight_buffer.sv new file mode 100644 index 0000000000..aec4a8ab0d --- /dev/null +++ b/finn-rtllib/fetch_weights/local_weight_buffer.sv @@ -0,0 +1,305 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module local_weight_buffer #( + int unsigned PE, + int unsigned SIMD, + int unsigned WEIGHT_WIDTH = 8, + int unsigned MH, + int unsigned MW, + int unsigned N_REPS, + int unsigned DBG = 0 +) ( + input logic clk, + input logic rst, + + input logic ivld, + output logic irdy, + input logic [SIMD-1:0][WEIGHT_WIDTH-1:0] idat, + + output logic ovld, + input logic ordy, + output logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat +); + +// ---------------------------------------------------------------------------- +// Consts and types +// ---------------------------------------------------------------------------- + +localparam int unsigned SF = MW/SIMD; +localparam int unsigned NF = MH/PE; +localparam int unsigned N_TLS = SF * NF; + +localparam int unsigned SIMD_BITS = (SIMD == 1) ? 1 : $clog2(SIMD); +localparam int unsigned PE_BITS = (PE == 1) ? 1 : $clog2(PE); +localparam int unsigned WGT_ADDR_BITS = $clog2(NF * SF); +localparam int unsigned RAM_BITS = (SIMD*WEIGHT_WIDTH + 7)/8 * 8; +localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; +localparam int unsigned N_TLS_BITS = $clog2(N_TLS); +localparam int unsigned N_REPS_BITS = $clog2(N_REPS); + +typedef enum logic[1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_t; +typedef enum logic {ST_RD_0, ST_RD_1} state_rd_t; + +// ---------------------------------------------------------------------------- +// Writer +// ---------------------------------------------------------------------------- + +// -- Regs +state_wr_t state_wr_C = ST_WR_0, state_wr_N; +state_rd_t state_rd_C = ST_RD_0, state_rd_N; + +logic[N_TLS_BITS-1:0] wr_pntr_C = '0, wr_pntr_N; +logic[PE_BITS-1:0] curr_pe_C = '0, curr_pe_N; + +// -- Signals +logic [1:0][PE-1:0][WGT_EN_BITS-1:0] a_we; // Bank enables +logic [1:0][WGT_ADDR_BITS-1:0] a_addr; +logic [1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] a_data_in; + +// -- REG +always_ff @( posedge clk ) begin : REG_PROC_WR + if(rst) begin + state_wr_C <= ST_WR_0; + + wr_pntr_C <= '0; + curr_pe_C <= '0; + end + else begin + state_wr_C <= state_wr_N; + + wr_pntr_C <= wr_pntr_N; + curr_pe_C <= curr_pe_N; + end +end + +// -- NSL +always_comb begin : NSL_PROC_WR + state_wr_N = state_wr_C; + + case (state_wr_C) + ST_WR_0: + if((curr_pe_C == PE - 1) && (wr_pntr_C == N_TLS - 1) && ivld) begin + state_wr_N = (state_rd_C == ST_RD_0) ? ST_WR_1 : ST_WR_0_WAIT; + end + + ST_WR_0_WAIT: + state_wr_N = (state_rd_C == ST_RD_0) ? ST_WR_1 : ST_WR_0_WAIT; + + ST_WR_1: + if((curr_pe_C == PE - 1) && (wr_pntr_C == N_TLS - 1) && ivld) begin + state_wr_N = (state_rd_C == ST_RD_1) ? ST_WR_0 : ST_WR_1_WAIT; + end + + ST_WR_1_WAIT: + state_wr_N = (state_rd_C == ST_RD_1) ? ST_WR_0 : ST_WR_1_WAIT; + + endcase +end + +// -- DP +always_comb begin : DP_PROC_WR + wr_pntr_N = wr_pntr_C; + curr_pe_N = curr_pe_C; + + // Input + irdy = 1'b0; + + // Buffers + a_we = '0; + for(int i = 0; i < 2; i++) begin + a_addr[i] = wr_pntr_C; + a_data_in[i] = idat; + end + + // Write and count + case (state_wr_C) + ST_WR_0, ST_WR_1: begin + irdy = 1'b1; + + if(ivld) begin + for(int i = 0; i < PE; i++) begin + if(curr_pe_C == i) begin + a_we[state_wr_C == ST_WR_1][i] = '1; + end + end + + curr_pe_N = (curr_pe_C == PE-1) ? 0 : curr_pe_C + 1; + wr_pntr_N = (curr_pe_C == PE-1) ? ((wr_pntr_C == N_TLS-1) ? 0 : wr_pntr_C + 1) : wr_pntr_C; + end + end + endcase + +end + +// ---------------------------------------------------------------------------- +// Reader +// ---------------------------------------------------------------------------- + +// -- Regs +logic [N_TLS_BITS-1:0] rd_pntr_C = '0, rd_pntr_N; +logic [N_REPS_BITS-1:0] reps_C = '0, reps_N; + +//logic [15:0] rd_pntr_C = '0, rd_pntr_N; +//logic [15:0] reps_C = '0, reps_N; + +logic [1:0] vld_s0_C = '0, vld_s0_N; +logic [1:0] vld_s1_C = '0, vld_s1_N; + +logic vld_C = '0, vld_N; +logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_C = '0, odat_N; + +// -- Signals +logic [1:0][WGT_ADDR_BITS-1:0] b_addr; +logic [1:0][PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_ram; + +// -- REG +always_ff @( posedge clk ) begin : REG_PROC_RD + if(rst) begin + state_rd_C <= ST_RD_0; + + rd_pntr_C <= '0; + reps_C <= '0; + + vld_s0_C <= '0; + vld_s1_C <= '0; + vld_C <= '0; + odat_C <= 'X; + end + else begin + state_rd_C <= state_rd_N; + + rd_pntr_C <= rd_pntr_N; + reps_C <= reps_N; + + vld_s0_C <= vld_s0_N; + vld_s1_C <= vld_s1_N; + vld_C <= vld_N; + odat_C <= odat_N; + end +end + +// -- NSL +always_comb begin : NSL_PROC_RD + state_rd_N = state_rd_C; + + case (state_rd_C) + ST_RD_0: + if(ordy && ((state_wr_C == ST_WR_0) ? (wr_pntr_C > rd_pntr_C) : 1'b1)) begin + if((rd_pntr_C == N_TLS-1) && (reps_C == N_REPS-1)) begin + state_rd_N = ST_RD_1; + end + end + + ST_RD_1: + if(ordy && ((state_wr_C == ST_WR_1) ? (wr_pntr_C > rd_pntr_C) : 1'b1)) begin + if((rd_pntr_C == N_TLS-1) && (reps_C == N_REPS-1)) begin + state_rd_N = ST_RD_0; + end + end + endcase +end + +// -- DP +always_comb begin : DP_PROC_RD + rd_pntr_N = rd_pntr_C; + reps_N = reps_C; + + for(int i = 0; i < 2; i++) begin + vld_s0_N[i] = ordy ? 1'b0 : vld_s0_C[i]; + vld_s1_N[i] = ordy ? vld_s0_C[i] : vld_s1_C[i]; + end + + vld_N = ordy ? |vld_s1_C : vld_C; + odat_N = ordy ? (vld_s1_C[0] ? odat_ram[0] : odat_ram[1]) : odat_C; + + for(int i = 0; i < 2; i++) begin + b_addr[i] = rd_pntr_C; + end + + case(state_rd_C) + ST_RD_0: begin + if(ordy) begin + if((state_wr_C == ST_WR_0) ? (wr_pntr_C > rd_pntr_C) : 1'b1) begin + + vld_s0_N[0] = 1'b1; + + rd_pntr_N = (rd_pntr_C == N_TLS-1) ? 0 : rd_pntr_C + 1; + reps_N = (rd_pntr_C == N_TLS-1) ? ((reps_C == N_REPS-1) ? 0 : reps_C + 1) : reps_C; + end + end + end + + ST_RD_1: begin + if(ordy) begin + if((state_wr_C == ST_WR_1) ? (wr_pntr_C > rd_pntr_C) : 1'b1) begin + + vld_s0_N[1] = 1'b1; + + rd_pntr_N = (rd_pntr_C == N_TLS-1) ? 0 : rd_pntr_C + 1; + reps_N = (rd_pntr_C == N_TLS-1) ? ((reps_C == N_REPS-1) ? 0 : reps_C + 1) : reps_C; + end + end + end + + endcase + +end + +assign ovld = vld_C; +assign odat = odat_C; + +// ---------------------------------------------------------------------------- +// Weights +// ---------------------------------------------------------------------------- + +for(genvar i = 0; i < 2; i++) begin + for(genvar j = 0; j < PE; j++) begin + ram_p_c #( + .ADDR_BITS(WGT_ADDR_BITS), + .DATA_BITS(RAM_BITS), + .RAM_STYLE("block") + ) inst_ram_tp_c ( + .clk(clk), + .a_en(1'b1), + .a_we(a_we[i][j]), + .a_addr(a_addr[i]), + .b_en(ordy), + .b_addr(b_addr[i]), + .a_data_in(a_data_in[i]), + .a_data_out(), + .b_data_out(odat_ram[i][j]) + ); + end +end + +endmodule diff --git a/finn-rtllib/mvu_tiled/acc_stage.sv b/finn-rtllib/mvu_tiled/acc_stage.sv new file mode 100644 index 0000000000..87214f0c1c --- /dev/null +++ b/finn-rtllib/mvu_tiled/acc_stage.sv @@ -0,0 +1,172 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + + +module acc_stage #( + int unsigned CHAINLEN, + int unsigned PE, + int unsigned ACCU_WIDTH, + int unsigned TH, + int unsigned TH_MAX = 2*TH +) ( + input logic clk, + input logic rst, + input logic en, + + input logic [PE-1:0][CHAINLEN-1:0][ACCU_WIDTH-1:0] idat, + input logic ival, + input logic ilast, + + output logic [PE-1:0][ACCU_WIDTH-1:0] odat, + output logic oval +); + +// +// Adder tree +// + +localparam integer TREE_HEIGHT = $clog2(CHAINLEN); +localparam integer ADD_LAT = TREE_HEIGHT + 1; + +logic [PE-1:0][ACCU_WIDTH-1:0] dat_acc; +logic [PE-1:0][ACCU_WIDTH-1:0] dat_int; + +for(genvar i = 0; i < PE; i++) begin + add_tree #( + .CHAINLEN(CHAINLEN), + .ACCU_WIDTH(ACCU_WIDTH), + .TREE_HEIGHT(TREE_HEIGHT) + ) inst_add_stage ( + .clk(clk), + .rst(rst), + .en(en), + + .idat(idat[i]), + .iacc(dat_acc[i]), + .odat(dat_int[i]) + ); +end + +// REG +logic [ADD_LAT:0] val; +logic [ADD_LAT:0] last; + +assign val[0] = ival; +assign last[0] = ilast; + +always_ff @(posedge clk) begin + if(rst) begin + for(int i = 1; i <= ADD_LAT; i++) begin + val[i] <= 1'b0; + last[i] <= 'X; + end + end + else begin + if(en) begin + for(int i = 1; i <= ADD_LAT; i++) begin + val[i] <= val[i-1]; + last[i] <= last[i-1]; + end + end + end +end + +logic val_int; +logic last_int; +logic inc_acc; + +assign val_int = val[ADD_LAT]; +assign last_int = last[ADD_LAT]; +assign inc_acc = val[ADD_LAT-1]; + +// +// Accumulation +// + +localparam integer TH_BITS = $clog2(TH); + +logic [TH_BITS-1:0] cnt_prep = 0; +logic prep = 1'b1; + +logic fifo_in_tvalid, fifo_in_tready; +logic fifo_out_tvalid, fifo_out_tready; +logic [PE*ACCU_WIDTH-1:0] fifo_in_tdata, fifo_out_tdata; + +Q_srl #( + .depth(TH_MAX), + .width(PE*ACCU_WIDTH) +) inst_acc ( + .clock(clk), + .reset(rst), + .i_d(fifo_in_tdata), + .i_v(fifo_in_tvalid), + .i_r(fifo_in_tready), + .o_d(fifo_out_tdata), + .o_v(fifo_out_tvalid), + .o_r(fifo_out_tready), + .count(), + .maxcount() +); + +always_ff @(posedge clk) begin + if(rst) begin + cnt_prep <= 0; + prep <= 1'b1; + + odat <= 'X; + oval <= 1'b0; + end + else begin + if(cnt_prep == TH-1) begin + prep <= 1'b0; + cnt_prep <= 0; + end + else begin + cnt_prep <= cnt_prep + 1; + end + + if(en) begin + odat <= dat_int; + oval <= val_int && last_int; + end + end +end + +always_comb begin + fifo_in_tvalid = prep ? 1'b1 : (en ? val_int : 1'b0); + fifo_in_tdata = prep ? 0 : (last_int ? 0 : dat_int); +end + +assign dat_acc = fifo_out_tdata; +assign fifo_out_tready = en & inc_acc; + +endmodule : acc_stage \ No newline at end of file diff --git a/finn-rtllib/mvu_tiled/add_tree.sv b/finn-rtllib/mvu_tiled/add_tree.sv new file mode 100644 index 0000000000..32861e3519 --- /dev/null +++ b/finn-rtllib/mvu_tiled/add_tree.sv @@ -0,0 +1,145 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module add_tree #( + parameter CHAINLEN, + parameter ACCU_WIDTH, + parameter TREE_HEIGHT +) ( + input logic clk, + input logic rst, + input logic en, + + input logic [CHAINLEN-1:0][ACCU_WIDTH-1:0] idat, + + input logic [ACCU_WIDTH-1:0] iacc, + + output logic [ACCU_WIDTH-1:0] odat +); + +//-------------------- Adder tree + +function automatic int level_len (int lvl); + return (CHAINLEN + (1 << lvl) - 1) >> lvl; // ceil(CHAINLEN / 2^lvl) +endfunction + +logic signed [ACCU_WIDTH-1:0] add_sf; + +if(CHAINLEN == 1) begin + assign add_sf = idat[0]; +end +else begin + logic signed [TREE_HEIGHT:0][CHAINLEN-1:0][ACCU_WIDTH-1:0] add_s; + + for(genvar i = 0; i < CHAINLEN; i++) begin + assign add_s[0][i] = signed'(idat[i]); + end + + /* + always_ff @(posedge clk) begin + if(rst) begin + for(int i = 1; i <= TREE_HEIGHT; i++) begin + add_s[i] <= '0; + end + end + else begin + if(en) begin + for(int i = 0; i < TREE_HEIGHT; i++) begin + for(int j = 0; j < (CHAINLEN/2 + (2**i-1))/(2**i); j++) begin + add_s[i+1][j] <= $signed(add_s[i][2*j+0]) + $signed(add_s[i][2*j+1]); + end + end + end + end + end + */ + + always_ff @(posedge clk) begin + if (rst) begin + // Clear all levels (safe for unused slots too) + for (int i = 1; i <= TREE_HEIGHT; i++) begin + for (int j = 0; j < CHAINLEN; j++) begin + add_s[i][j] <= '0; + end + end + end else if (en) begin + // For each level i, produce next level i+1 + for (int i = 0; i < TREE_HEIGHT; i++) begin + int src_len = level_len(i); // live elems at level i + int dst_len = level_len(i + 1); // live elems at level i+1 (ceil(src_len/2)) + + // Compute valid outputs only (0..dst_len-1). Leave rest zero. + for (int j = 0; j < dst_len; j++) begin + int a_idx = 2*j; + int b_idx = 2*j + 1; + + // Cases: + // - both indices in range -> sum + // - only a_idx in range -> pass-through + // - neither in range -> zero (shouldn't happen for j 256 (with a loop-carried dependency) cannot be handled out-of-the-box with PyVerilator + +//-------------------- Shift register for last and valid signals --------------------\\ + localparam int unsigned DSP_PIPELINE_STAGES = 1; + logic L [0:1+DSP_PIPELINE_STAGES] = '{default: 0}; + logic V [0:1+DSP_PIPELINE_STAGES] = '{default: 0}; + + always_ff @(posedge clk) begin + if(rst) begin + L <= '{default: 0}; + V <= '{default: 0}; + end + else if(en) begin + L[1+DSP_PIPELINE_STAGES] <= ilast; + L[0:DSP_PIPELINE_STAGES] <= L[1:1+DSP_PIPELINE_STAGES]; + + V[1+DSP_PIPELINE_STAGES] <= ivld; + V[0:DSP_PIPELINE_STAGES] <= V[1:1+DSP_PIPELINE_STAGES]; + end + end + + logic last; + logic vld; + assign last = L[0]; + assign vld = V[0]; + +//-------------------- Buffer for input activations --------------------\\ + localparam int unsigned PAD_BITS_ACT = 9 - ACTIVATION_WIDTH; + for (genvar i=0; i 8) begin + $error("Weight width of %0d-bits exceeds maximum of 8-bits", WEIGHT_WIDTH); + $finish; + end + if (ACTIVATION_WIDTH > 8) begin + $error("Activation width of %0d-bits exceeds maximum of 8-bits", ACTIVATION_WIDTH); + $finish; + end + end + + uwire rst = !ap_rst_n; + + //- Replay to Accommodate Neuron Fold ----------------------------------- + typedef logic [SIMD-1:0][ACTIVATION_WIDTH-1:0] mvu_flatin_t; + uwire mvu_flatin_t amvau; + uwire alast; + uwire avld; + uwire ardy; + + replay_buff_tile #(.XC(MW/SIMD), .YC(TH), .W($bits(mvu_flatin_t)), .N_REPS(MH/PE), .IO_TILED(IN_TILED)) activation_replay ( + .clk(ap_clk), .rst(rst), + .ivld(s_axis_input_tvalid), .irdy(s_axis_input_tready), .idat(mvu_flatin_t'(s_axis_input_tdata)), + .ovld(avld), .ordy(ardy), .odat(amvau), .olast(alast) + ); + + //- Unflatten weights --------------------------------------------------- + typedef logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] mvu_w_t; + uwire mvu_w_t wdat; + uwire wvld; + uwire wrdy; + + weights_buff_tile #( + .WEIGHT_WIDTH(WEIGHT_WIDTH), + .SIMD(SIMD), .PE(PE), + .TH(TH), .WSIMD(WSIMD), + .N_DCPL_STAGES(N_DCPL_STAGES) + ) inst_weights_buff_tile ( + .clk(ap_clk), .rst(rst), + .ivld(s_axis_weights_tvalid), .irdy(s_axis_weights_tready), .idat(s_axis_weights_tdata), + .ovld(wvld), .ordy(wrdy), .odat(wdat) + ); + + //- Flow Control Bracket around Compute Core ---------------------------- + uwire en; + uwire istb = avld && wvld; + assign ardy = en && wvld; + assign wrdy = en && avld; + + //- Conditionally Pumped DSP Compute ------------------------------------ + typedef logic [PE-1:0][ACCU_WIDTH-1:0] dsp_p_t; + uwire ovld; + uwire dsp_p_t odat; + if(1) begin : blkDsp + localparam int unsigned EFFECTIVE_SIMD = SIMD_UNEVEN && PUMPED_COMPUTE ? SIMD+1 : SIMD; + localparam int unsigned DSP_SIMD = EFFECTIVE_SIMD/(PUMPED_COMPUTE+1); + typedef logic [PE -1:0][DSP_SIMD-1:0][WEIGHT_WIDTH -1:0] dsp_w_t; + typedef logic [DSP_SIMD-1:0][ACTIVATION_WIDTH-1:0] dsp_a_t; + + uwire dsp_last; + uwire dsp_zero; + uwire dsp_w_t dsp_w; + uwire dsp_a_t dsp_a; + + uwire dsp_vld; + uwire dsp_p_t dsp_p; + + // TODO: No double-pumping in the initial implementation + assign dsp_en = en; + + assign dsp_last = alast && istb; + assign dsp_zero = !istb; + assign dsp_w = wdat; + assign dsp_a = amvau; + + assign ovld = dsp_vld; + assign odat = dsp_p; + + // + // Compute Unit + // + + case(COMPUTE_CORE) + "mvu_vvu_8sx9_dsp58": begin : core + cu_mvau_tiled #( + .PE(PE), .SIMD(SIMD), + .TH(TH), + .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACTIVATION_WIDTH(ACTIVATION_WIDTH), .ACCU_WIDTH(ACCU_WIDTH), + .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS) + ) inst_cu_mvau_tiled ( + .clk(ap_clk), .rst(rst), .en(dsp_en), + .ivld(istb), .ilast(dsp_last), .w(dsp_w), .a(dsp_a), + .ovld(dsp_vld), .p(dsp_p) + ); + end + default: initial begin + $error("Unrecognized COMPUTE_CORE '%s'", COMPUTE_CORE); + $finish; + end + endcase + + end : blkDsp + + //-------------------- Output register slice --------------------\\ + // Make `en`computation independent from external inputs. + // Drive all outputs from registers. + + logic m_axis_int_tvalid; + logic m_axis_int_tready; + logic [OUTPUT_STREAM_WIDTH_BA-1:0] m_axis_int_tdata; + + struct packed { + logic rdy; + logic [PE-1:0][ACCU_WIDTH-1:0] dat; + } A = '{ rdy: 1, default: 'x }; // side-step register used when encountering backpressure + struct packed { + logic vld; + logic [PE-1:0][ACCU_WIDTH-1:0] dat; + } B = '{ vld: 0, default: 'x }; // ultimate output register + + assign en = A.rdy; + uwire b_load = !B.vld || m_axis_int_tready; + + always_ff @(posedge ap_clk) begin + if(rst) begin + A <= '{ rdy: 1, default: 'x }; + B <= '{ vld: 0, default: 'x }; + end + else begin + if(A.rdy) A.dat <= odat; + A.rdy <= (A.rdy && !ovld) || b_load; + + if(b_load) begin + B <= '{ + vld: ovld || !A.rdy, + dat: A.rdy? odat : A.dat + }; + end + end + end + assign m_axis_int_tvalid = B.vld; + // Why would we need a sign extension here potentially creating a higher signal load into the next FIFO? + // These extra bits should never be used. Why not 'x them out? + assign m_axis_int_tdata = { {(OUTPUT_STREAM_WIDTH_BA-OUTPUT_STREAM_WIDTH){B.dat[PE-1][ACCU_WIDTH-1]}}, B.dat}; + + //-------------------- Output reordering --------------------\\ + + if(OUT_TILED == 0) begin + reorder_out #(.W(OUTPUT_STREAM_WIDTH_BA), .XC(MH/PE), .YC(TH)) inst_reorder_out ( + .clk(ap_clk), .rst(rst), + .ivld(m_axis_int_tvalid), .irdy(m_axis_int_tready), .idat(m_axis_int_tdata), + .ovld(m_axis_output_tvalid), .ordy(m_axis_output_tready), .odat(m_axis_output_tdata) + ); + end else begin + assign m_axis_output_tvalid = m_axis_int_tvalid; + assign m_axis_int_tready = m_axis_output_tready; + assign m_axis_output_tdata = m_axis_int_tdata; + end + +endmodule : mvu_tiled_axi diff --git a/finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v b/finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v new file mode 100644 index 0000000000..fe3a73bc8b --- /dev/null +++ b/finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v @@ -0,0 +1,97 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module $MODULE_NAME_AXI_WRAPPER$ #( + parameter PE = $PE$, + parameter SIMD = $SIMD$, + parameter ACTIVATION_WIDTH = $ACTIVATION_WIDTH$, + parameter WEIGHT_WIDTH = $WEIGHT_WIDTH$, + parameter ACCU_WIDTH = $ACCU_WIDTH$, + parameter MW = $MW$, + parameter MH = $MH$, + parameter TH = $TH$, + parameter NARROW_WEIGHTS = $NARROW_WEIGHTS$, + parameter SIGNED_ACTIVATIONS = $SIGNED_ACTIVATIONS$, + parameter PUMPED_COMPUTE = $PUMPED_COMPUTE$, + + // Safely deducible parameters + parameter WSIMD = (PE * SIMD) / TH, + parameter WEIGHT_STREAM_WIDTH_BA = (WSIMD * WEIGHT_WIDTH + 7)/8 * 8, + parameter INPUT_STREAM_WIDTH_BA = (SIMD * ACTIVATION_WIDTH + 7) / 8 * 8, + parameter OUTPUT_STREAM_WIDTH_BA = (PE * ACCU_WIDTH + 7)/8 * 8 +)( + // Global Control + (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in1_V:in0_V:out0_V, ASSOCIATED_RESET ap_rst_n" *) + (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *) + input ap_clk, + (* X_INTERFACE_PARAMETER = "ASSOCIATED_RESET ap_rst_n" *) + (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk2x CLK" *) + input ap_clk2x, + (* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *) + input ap_rst_n, + + // Weight Stream + input [WEIGHT_STREAM_WIDTH_BA-1:0] in1_V_TDATA, + input in1_V_TVALID, + output in1_V_TREADY, + // Input Stream + input [INPUT_STREAM_WIDTH_BA-1:0] in0_V_TDATA, + input in0_V_TVALID, + output in0_V_TREADY, + // Output Stream + output [OUTPUT_STREAM_WIDTH_BA-1:0] out0_V_TDATA, + output out0_V_TVALID, + input out0_V_TREADY +); + +mvu_tiled_axi #( + .PE(PE), .SIMD(SIMD), + .ACTIVATION_WIDTH(ACTIVATION_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACCU_WIDTH(ACCU_WIDTH), + .MW(MW), .MH(MH), .TH(TH), + .NARROW_WEIGHTS(NARROW_WEIGHTS), .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .PUMPED_COMPUTE(PUMPED_COMPUTE), + .FORCE_BEHAVIORAL(0) + ) inst ( + .ap_clk(ap_clk), + .ap_clk2x(ap_clk2x), + .ap_rst_n(ap_rst_n), + .s_axis_weights_tdata(in1_V_TDATA), + .s_axis_weights_tvalid(in1_V_TVALID), + .s_axis_weights_tready(in1_V_TREADY), + .s_axis_input_tdata(in0_V_TDATA), + .s_axis_input_tvalid(in0_V_TVALID), + .s_axis_input_tready(in0_V_TREADY), + .m_axis_output_tdata(out0_V_TDATA), + .m_axis_output_tvalid(out0_V_TVALID), + .m_axis_output_tready(out0_V_TREADY) +); + +endmodule // $MODULE_NAME_AXI_WRAPPER$ diff --git a/finn-rtllib/mvu_tiled/reorder_out.sv b/finn-rtllib/mvu_tiled/reorder_out.sv new file mode 100644 index 0000000000..19421937e4 --- /dev/null +++ b/finn-rtllib/mvu_tiled/reorder_out.sv @@ -0,0 +1,341 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module reorder_out #( + int unsigned W, + int unsigned XC, + int unsigned YC +)( + input logic clk, + input logic rst, + + input logic ivld, + output logic irdy, + input logic [W-1:0] idat, + + output logic ovld, + input logic ordy, + output logic [W-1:0] odat +); + +// ---------------------------------------------------------------------------- +// Consts and types +// ---------------------------------------------------------------------------- + +localparam int unsigned RAM_BITS = (W + 7)/8 * 8; +localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; +localparam int unsigned XYC = XC * YC; +localparam int unsigned XCNT_BITS = (XC == 1) ? 1 : $clog2(XC); +localparam int unsigned YCNT_BITS = (YC == 1) ? 1 : $clog2(YC); +localparam int unsigned XYCNT_BITS = (XYC == 1) ? 1 : $clog2(XYC); + +typedef enum logic[1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_t; +typedef enum logic {ST_RD_0, ST_RD_1} state_rd_t; + +// ---------------------------------------------------------------------------- +// Writer +// ---------------------------------------------------------------------------- + +// -- Regs +state_wr_t state_wr_C = ST_WR_0, state_wr_N; +state_rd_t state_rd_C = ST_RD_0, state_rd_N; + +logic [XCNT_BITS-1:0] curr_wrX_C = '0, curr_wrX_N; +logic [YCNT_BITS-1:0] curr_wrY_C = '0, curr_wrY_N; + +// -- Ram +logic [1:0][WGT_EN_BITS-1:0] a_we; // Bank enables +logic [1:0][XYCNT_BITS-1:0] a_addr; +logic [1:0][W-1:0] a_data_in; + +// -- Offsets +logic [XC-1:0][XYCNT_BITS-1:0] x_offsets; +for(genvar i = 0; i < XC; i++) begin + assign x_offsets[i] = i*YC; +end + +// -- IPC +logic done; + +// -- REG +always_ff @( posedge clk ) begin : REG_PROC_WR + if(rst) begin + state_wr_C <= ST_WR_0; + + curr_wrX_C <= 0; + curr_wrY_C <= 0; + end + else begin + state_wr_C <= state_wr_N; + + curr_wrX_C <= curr_wrX_N; + curr_wrY_C <= curr_wrY_N; + end +end + +// -- NSL +always_comb begin : NSL_PROC_WR + state_wr_N = state_wr_C; + + case (state_wr_C) + ST_WR_0: + if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld) begin + state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; + end + + ST_WR_0_WAIT: + state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; + + ST_WR_1: + if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld) begin + state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; + end + + ST_WR_1_WAIT: + state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; + + endcase +end + +// -- DP +always_comb begin : DP_PROC_WR + curr_wrX_N = curr_wrX_C; + curr_wrY_N = curr_wrY_C; + + // Input + irdy = 1'b0; + + // Buffer control + a_we = '0; + for(int i = 0; i < 2; i++) begin + a_addr[i] = x_offsets[curr_wrX_C] + curr_wrY_C; + a_data_in[i] = idat; + end + + // Write and count + case (state_wr_C) + ST_WR_0, ST_WR_1: begin + irdy = 1'b1; + + if(ivld) begin + if(state_wr_C == ST_WR_0) a_we[0] = '1; else a_we[1] = '1; + + curr_wrY_N = (curr_wrY_C == YC-1) ? 0 : curr_wrY_C + 1; + curr_wrX_N = (curr_wrY_C == YC-1) ? ((curr_wrX_C == XC-1) ? 0 : curr_wrX_C + 1) : curr_wrX_C; + end + end + endcase + +end + + +// ---------------------------------------------------------------------------- +// Reader +// ---------------------------------------------------------------------------- + +// -- Regs +logic [XCNT_BITS-1:0] curr_rdX_C = '0, curr_rdX_N; +logic [YCNT_BITS-1:0] curr_rdY_C = '0, curr_rdY_N; + +// -- Ram +logic [1:0] vld_s0_C = '0, vld_s0_N; +logic [1:0] vld_s1_C = '0, vld_s1_N; +logic vld_C = '0, vld_N; +logic [W-1:0] odat_C = '0, odat_N; + +logic [1:0][XYCNT_BITS-1:0] b_addr; +logic [1:0][W-1:0] odat_ram; + +// -- Cond +logic cond_go; + +// -- Oreg +logic [W-1:0] odat_int; +logic ovld_int; +logic ordy_int; + +// -- REG +always_ff @( posedge clk ) begin : REG_PROC_RD + if(rst) begin + state_rd_C <= ST_RD_0; + + curr_rdX_C <= 0; + curr_rdY_C <= 0; + + vld_s0_C <= 0; + vld_s1_C <= 0; + vld_C <= 0; + odat_C <= 0; + end + else begin + state_rd_C <= state_rd_N; + + curr_rdX_C <= curr_rdX_N; + curr_rdY_C <= curr_rdY_N; + + vld_s0_C <= vld_s0_N; + vld_s1_C <= vld_s1_N; + vld_C <= vld_N; + odat_C <= odat_N; + end +end + +// -- NSL +always_comb begin : NSL_PROC_RD + state_rd_N = state_rd_C; + + case (state_rd_C) + ST_RD_0: + if(ordy_int && ((state_wr_C == ST_WR_0) ? cond_go : 1'b1)) begin + if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1)) begin + state_rd_N = ST_RD_1; + end + end + + ST_RD_1: + if(ordy_int && ((state_wr_C == ST_WR_1) ? cond_go : 1'b1)) begin + if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1)) begin + state_rd_N = ST_RD_0; + end + end + + endcase +end + +// -- DP cond +always_comb begin + cond_go = 1'b0; + + if(curr_wrX_C > curr_rdX_C) begin + cond_go = 1'b1; + end + else if(curr_wrX_C == curr_rdX_C) begin + if(curr_wrY_C > curr_rdY_C) begin + cond_go = 1'b1; + end + end +end + +// -- DP +always_comb begin : DP_PROC_RD + curr_rdX_N = curr_rdX_C; + curr_rdY_N = curr_rdY_C; + + for(int i = 0; i < 2; i++) begin + vld_s0_N[i] = ordy_int ? 1'b0 : vld_s0_C[i]; + vld_s1_N[i] = ordy_int ? vld_s0_C[i] : vld_s1_C[i]; + end + + vld_N = ordy_int ? |vld_s1_C : vld_C; + odat_N = ordy_int ? (vld_s1_C[0] ? odat_ram[0] : odat_ram[1]) : odat_C; + + for(int i = 0; i < 2; i++) begin + b_addr[i] = x_offsets[curr_rdX_C] + curr_rdY_C; + end + + done = 1'b0; + + case(state_rd_C) + ST_RD_0: begin + if(ordy_int) begin + if((state_wr_C == ST_WR_0) ? cond_go : 1'b1) begin + vld_s0_N[0] = 1'b1; + + curr_rdX_N = (curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1; + curr_rdY_N = (curr_rdX_C == XC-1) ? ((curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1) : curr_rdY_C; + done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)); + end + end + end + + ST_RD_1: begin + if(ordy_int) begin + if((state_wr_C == ST_WR_1) ? cond_go : 1'b1) begin + vld_s0_N[1] = 1'b1; + + curr_rdX_N = (curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1; + curr_rdY_N = (curr_rdX_C == XC-1) ? ((curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1) : curr_rdY_C; + done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)); + end + end + end + + endcase + +end + +assign ovld_int = vld_C; +assign odat_int = odat_C; + +// ---------------------------------------------------------------------------- +// Matrix +// ---------------------------------------------------------------------------- + +for(genvar i = 0; i < 2; i++) begin + ram_p_c #( + .ADDR_BITS(XYCNT_BITS), + .DATA_BITS(RAM_BITS), + .RAM_STYLE("distributed") + ) inst_ram_tp_c ( + .clk(clk), + .a_en(1'b1), + .a_we(a_we[i]), + .a_addr(a_addr[i]), + .b_en(ordy_int), + .b_addr(b_addr[i]), + .a_data_in(a_data_in[i]), + .a_data_out(), + .b_data_out(odat_ram[i]) + ); +end + +// ---------------------------------------------------------------------------- +// Output +// ---------------------------------------------------------------------------- + +Q_srl #( + .depth(2), .width(W) +) inst_out_fifo ( + .clock(clk), + .reset(rst), + .count(), + .maxcount(), + .i_d(odat_int), + .i_v(ovld_int), + .i_r(ordy_int), + .o_d(odat), + .o_v(ovld), + .o_r(ordy) +); + + +endmodule \ No newline at end of file diff --git a/finn-rtllib/mvu_tiled/replay_buff_tile.sv b/finn-rtllib/mvu_tiled/replay_buff_tile.sv new file mode 100644 index 0000000000..467e17a3af --- /dev/null +++ b/finn-rtllib/mvu_tiled/replay_buff_tile.sv @@ -0,0 +1,392 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module replay_buff_tile #( + int unsigned W, + int unsigned XC, + int unsigned YC, + int unsigned N_REPS, + int unsigned IO_TILED = 0 +)( + input logic clk, + input logic rst, + + input logic ivld, + output logic irdy, + input logic [W-1:0] idat, + + output logic ovld, + input logic ordy, + output logic [W-1:0] odat, + output logic olast +); + +// ---------------------------------------------------------------------------- +// Consts and types +// ---------------------------------------------------------------------------- + +localparam int unsigned RAM_BITS = (W + 7)/8 * 8; +localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; +localparam int unsigned XYC = XC * YC; +localparam int unsigned XCNT_BITS = (XC == 1) ? 1 : $clog2(XC); +localparam int unsigned YCNT_BITS = (YC == 1) ? 1 : $clog2(YC); +localparam int unsigned XYCNT_BITS = (XYC == 1) ? 1 : $clog2(XYC); +localparam int unsigned REPS_BITS = (N_REPS == 1) ? 1 : $clog2(N_REPS); + +typedef enum logic[1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_t; +typedef enum logic {ST_RD_0, ST_RD_1} state_rd_t; + +// ---------------------------------------------------------------------------- +// Ireg +// ---------------------------------------------------------------------------- +logic [W-1:0] idat_int; +logic ivld_int; +logic irdy_int; + +skid #(.DATA_WIDTH(W), .FEED_STAGES(1)) isnt_input_reg ( + .clk(clk), .rst(rst), + .ivld(ivld), .irdy(irdy), .idat(idat), + .ovld(ivld_int), .ordy(irdy_int), .odat(idat_int) +); + +// ---------------------------------------------------------------------------- +// Writer +// ---------------------------------------------------------------------------- + +// -- Regs +state_wr_t state_wr_C = ST_WR_0, state_wr_N; +state_rd_t state_rd_C = ST_RD_0, state_rd_N; + +logic [XCNT_BITS-1:0] curr_wrX_C = '0, curr_wrX_N; +logic [YCNT_BITS-1:0] curr_wrY_C = '0, curr_wrY_N; + +// -- Ram +logic [1:0][WGT_EN_BITS-1:0] a_we; // Bank enables +logic [1:0][XYCNT_BITS-1:0] a_addr; +logic [1:0][W-1:0] a_data_in; + +// -- Offsets +logic [XC-1:0][XYCNT_BITS-1:0] x_offsets; +for(genvar i = 0; i < XC; i++) begin + assign x_offsets[i] = i*YC; +end + +// -- IPC +logic done; + +// -- REG +always_ff @( posedge clk ) begin : REG_PROC_WR + if(rst) begin + state_wr_C <= ST_WR_0; + + curr_wrX_C <= 0; + curr_wrY_C <= 0; + end + else begin + state_wr_C <= state_wr_N; + + curr_wrX_C <= curr_wrX_N; + curr_wrY_C <= curr_wrY_N; + end +end + +// -- NSL +always_comb begin : NSL_PROC_WR + state_wr_N = state_wr_C; + + case (state_wr_C) + ST_WR_0: + if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld_int) begin + state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; + end + + ST_WR_0_WAIT: + state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; + + ST_WR_1: + if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld_int) begin + state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; + end + + ST_WR_1_WAIT: + state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; + + endcase +end + +// -- DP +always_comb begin : DP_PROC_WR + curr_wrX_N = curr_wrX_C; + curr_wrY_N = curr_wrY_C; + + // Input + irdy_int = 1'b0; + + // Buffer control + a_we = '0; + for(int i = 0; i < 2; i++) begin + a_addr[i] = x_offsets[curr_wrX_C] + curr_wrY_C; + a_data_in[i] = idat_int; + end + + // Write and count + case (state_wr_C) + ST_WR_0, ST_WR_1: begin + irdy_int = 1'b1; + + if(ivld_int) begin + if(state_wr_C == ST_WR_0) a_we[0] = '1; else a_we[1] = '1; + + curr_wrX_N = (curr_wrX_C == XC-1) ? 0 : curr_wrX_C + 1; + curr_wrY_N = (curr_wrX_C == XC-1) ? ((curr_wrY_C == YC-1) ? 0 : curr_wrY_C + 1) : curr_wrY_C; + end + end + endcase + +end + +// ---------------------------------------------------------------------------- +// Reader +// ---------------------------------------------------------------------------- + +// -- Regs +logic [XCNT_BITS-1:0] curr_rdX_C = '0, curr_rdX_N; +logic [YCNT_BITS-1:0] curr_rdY_C = '0, curr_rdY_N; +logic [REPS_BITS-1:0] curr_reps_C = '0, curr_reps_N; + +// -- Ram +logic [1:0] vld_s0_C = '0, vld_s0_N; +logic [1:0] vld_s1_C = '0, vld_s1_N; +logic vld_C = '0, vld_N; +logic last_s0_C = '0, last_s0_N; +logic last_s1_C = '0, last_s1_N; +logic last_C = '0, last_N; +logic [W-1:0] odat_C = '0, odat_N; + +logic [1:0][XYCNT_BITS-1:0] b_addr; +logic [1:0][W-1:0] odat_ram; + +// -- Cond +logic cond_go; + +// -- Oreg +logic [W-1:0] odat_int; +logic ovld_int; +logic ordy_int; +logic olast_int; + +// -- REG +always_ff @( posedge clk ) begin : REG_PROC_RD + if(rst) begin + state_rd_C <= ST_RD_0; + + curr_rdX_C <= 0; + curr_rdY_C <= 0; + curr_reps_C <= 0; + + vld_s0_C <= 0; + vld_s1_C <= 0; + vld_C <= 0; + odat_C <= 0; + last_s0_C <= 0; + last_s1_C <= 0; + last_C <= 0; + end + else begin + state_rd_C <= state_rd_N; + + curr_rdX_C <= curr_rdX_N; + curr_rdY_C <= curr_rdY_N; + curr_reps_C <= curr_reps_N; + + vld_s0_C <= vld_s0_N; + vld_s1_C <= vld_s1_N; + vld_C <= vld_N; + odat_C <= odat_N; + last_s0_C <= last_s0_N; + last_s1_C <= last_s1_N; + last_C <= last_N; + end +end + +// -- NSL +always_comb begin : NSL_PROC_RD + state_rd_N = state_rd_C; + + case (state_rd_C) + ST_RD_0: + if(ordy_int && ((state_wr_C == ST_WR_0) ? cond_go : 1'b1)) begin + if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1) && (curr_reps_C == N_REPS-1)) begin + state_rd_N = ST_RD_1; + end + end + + ST_RD_1: + if(ordy_int && ((state_wr_C == ST_WR_1) ? cond_go : 1'b1)) begin + if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1) && (curr_reps_C == N_REPS-1)) begin + state_rd_N = ST_RD_0; + end + end + + endcase +end + +// -- DP cond +always_comb begin + cond_go = 1'b0; + + if(IO_TILED) begin + if(curr_wrX_C > curr_rdX_C) begin + cond_go = 1'b1; + end + else if(curr_wrX_C == curr_rdX_C) begin + if(curr_wrY_C > curr_rdY_C) begin + cond_go = 1'b1; + end + end + end else begin + if(curr_wrY_C > curr_rdY_C) begin + cond_go = 1'b1; + end + else if(curr_wrY_C == curr_rdY_C) begin + if(curr_wrX_C > curr_rdX_C) begin + cond_go = 1'b1; + end + end + end +end + +// -- DP +always_comb begin : DP_PROC_RD + curr_rdX_N = curr_rdX_C; + curr_rdY_N = curr_rdY_C; + curr_reps_N = curr_reps_C; + + for(int i = 0; i < 2; i++) begin + vld_s0_N[i] = ordy_int ? 1'b0 : vld_s0_C[i]; + vld_s1_N[i] = ordy_int ? vld_s0_C[i] : vld_s1_C[i]; + end + + vld_N = ordy_int ? |vld_s1_C : vld_C; + odat_N = ordy_int ? (vld_s1_C[0] ? odat_ram[0] : odat_ram[1]) : odat_C; + + last_s0_N = ordy_int ? 1'b0 : last_s0_C; + last_s1_N = ordy_int ? last_s0_C : last_s1_C; + last_N = ordy_int ? last_s1_C : last_C; + + for(int i = 0; i < 2; i++) begin + b_addr[i] = x_offsets[curr_rdX_C] + curr_rdY_C; + end + + done = 1'b0; + + case(state_rd_C) + ST_RD_0: begin + if(ordy_int) begin + if((state_wr_C == ST_WR_0) ? cond_go : 1'b1) begin + vld_s0_N[0] = 1'b1; + + last_s0_N = (curr_rdX_C == XC-1); + + curr_rdY_N = (curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1; + curr_rdX_N = (curr_rdY_C == YC-1) ? ((curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1) : curr_rdX_C; + curr_reps_N = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)) ? ((curr_reps_C == N_REPS-1) ? 0 : curr_reps_C + 1) : curr_reps_C; + done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1) && (curr_reps_C == N_REPS-1)); + end + end + end + + ST_RD_1: begin + if(ordy_int) begin + if((state_wr_C == ST_WR_1) ? cond_go : 1'b1) begin + vld_s0_N[1] = 1'b1; + + last_s0_N = (curr_rdX_C == XC-1); + + curr_rdY_N = (curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1; + curr_rdX_N = (curr_rdY_C == YC-1) ? ((curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1) : curr_rdX_C; + curr_reps_N = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)) ? ((curr_reps_C == N_REPS-1) ? 0 : curr_reps_C + 1) : curr_reps_C; + done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1) && (curr_reps_C == N_REPS-1)); + end + end + end + + endcase + +end + +assign ovld_int = vld_C; +assign odat_int = odat_C; +assign olast_int = last_C; + +// ---------------------------------------------------------------------------- +// BRAM +// ---------------------------------------------------------------------------- + +for(genvar i = 0; i < 2; i++) begin + ram_p_c #( + .ADDR_BITS(XYCNT_BITS), + .DATA_BITS(RAM_BITS), + .RAM_STYLE("distributed") + ) inst_ram_tp_c ( + .clk(clk), + .a_en(1'b1), + .a_we(a_we[i]), + .a_addr(a_addr[i]), + .b_en(ordy_int), + .b_addr(b_addr[i]), + .a_data_in(a_data_in[i]), + .a_data_out(), + .b_data_out(odat_ram[i]) + ); +end + +// ---------------------------------------------------------------------------- +// Output +// ---------------------------------------------------------------------------- + +Q_srl #( + .depth(2), .width(1+W) +) inst_out_fifo ( + .clock(clk), + .reset(rst), + .count(), + .maxcount(), + .i_d({olast_int, odat_int}), + .i_v(ovld_int), + .i_r(ordy_int), + .o_d({olast, odat}), + .o_v(ovld), + .o_r(ordy) +); + +endmodule diff --git a/finn-rtllib/mvu_tiled/weights_buff_tile.sv b/finn-rtllib/mvu_tiled/weights_buff_tile.sv new file mode 100644 index 0000000000..1426411c24 --- /dev/null +++ b/finn-rtllib/mvu_tiled/weights_buff_tile.sv @@ -0,0 +1,260 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module weights_buff_tile #( + int unsigned WEIGHT_WIDTH = 8, + int unsigned SIMD, + int unsigned PE, + int unsigned TH, + int unsigned WSIMD, + int unsigned NW = (PE*SIMD)/WSIMD, + int unsigned N_DCPL_STAGES +)( + input logic clk, + input logic rst, + + input logic ivld, + output logic irdy, + input logic [WSIMD-1:0][WEIGHT_WIDTH-1:0] idat, + + output logic ovld, + input logic ordy, + output logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat +); + + +//-------------------- Parameter sanity checks -------------------------------- + +initial begin + if ((PE*SIMD) % WSIMD != 0) begin + $error("Weight stream width not set properly (WSIMD: %0d, PE %0d, SIMD %0d).", WSIMD, PE, SIMD); + $finish; + end +end + +// ---------------------------------------------------------------------------- +// Consts and types +// ---------------------------------------------------------------------------- + +localparam integer NW_BITS = (NW == 1) ? 1 : $clog2(NW); +localparam integer TH_BITS = (TH == 1) ? 1 : $clog2(TH); + +typedef enum logic[1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_t; +typedef enum logic {ST_RD_0, ST_RD_1} state_rd_t; + +// ---------------------------------------------------------------------------- +// Slice +// ---------------------------------------------------------------------------- + +logic ivld_int; +logic irdy_int; +logic [WSIMD-1:0][WEIGHT_WIDTH-1:0] idat_int; + +// Ireg +skid #(.DATA_WIDTH(WSIMD*WEIGHT_WIDTH), .FEED_STAGES(1)) inst_ireg ( + .clk(clk), .rst(rst), + .ivld(ivld), .irdy(irdy), .idat(idat), + .ovld(ivld_int), .ordy(irdy_int), .odat(idat_int) +); + +// ---------------------------------------------------------------------------- +// Writer +// ---------------------------------------------------------------------------- + +// -- Regs +state_wr_t state_wr_C = ST_WR_0, state_wr_N; +state_rd_t state_rd_C = ST_RD_0, state_rd_N; + +logic [NW_BITS-1:0] curr_C = '0, curr_N; + +logic done; + +logic ovld_int; +logic ordy_int; +logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_int; + +// -- Mem +logic [1:0][NW-1:0][WSIMD*WEIGHT_WIDTH-1:0] mem_C, mem_N; + +// -- REG +always_ff @( posedge clk ) begin : REG_PROC_WR + if(rst) begin + state_wr_C <= ST_WR_0; + + curr_C <= '0; + mem_C <= '0; + end + else begin + state_wr_C <= state_wr_N; + + curr_C <= curr_N; + mem_C <= mem_N; + end +end + +// -- NSL +always_comb begin : NSL_PROC_WR + state_wr_N = state_wr_C; + + case (state_wr_C) + ST_WR_0: + if ((curr_C == NW - 1) && ivld_int) begin + state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; + end + + ST_WR_0_WAIT: + state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; + + ST_WR_1: + if ((curr_C == NW - 1) && ivld_int) begin + state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; + end + + ST_WR_1_WAIT: + state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; + + endcase +end + +// -- DP +always_comb begin : DP_PROC_WR + curr_N = curr_C; + mem_N = mem_C; + + // Input + irdy_int = 1'b0; + + // Write and count + case (state_wr_C) + ST_WR_0, ST_WR_1: begin + irdy_int = 1'b1; + + if(ivld_int) begin + if(state_wr_C == ST_WR_0) begin + mem_N[0] = (mem_C[0] >> WSIMD*WEIGHT_WIDTH); + mem_N[0][NW-1] = idat_int; + end + else begin + mem_N[1] = (mem_C[1] >> WSIMD*WEIGHT_WIDTH); + mem_N[1][NW-1] = idat_int; + end + + curr_N = (curr_C == NW-1) ? 0 : curr_C + 1; + end + end + endcase +end + +// ---------------------------------------------------------------------------- +// Reader +// ---------------------------------------------------------------------------- + +// -- Regs +logic [TH_BITS-1:0] cons_r_C = '0, cons_r_N; + +// -- REG +always_ff @( posedge clk ) begin : REG_PROC_RD + if(rst) begin + state_rd_C <= ST_RD_0; + + cons_r_C <= 0; + end + else begin + state_rd_C <= state_rd_N; + + cons_r_C <= cons_r_N; + end +end + +// -- NSL +always_comb begin : NSL_PROC_RD + state_rd_N = state_rd_C; + + case (state_rd_C) + ST_RD_0: + if(ordy_int && (state_wr_C != ST_WR_0)) begin + if(cons_r_C == TH-1) begin + state_rd_N = ST_RD_1; + end + end + + ST_RD_1: + if(ordy_int && (state_wr_C != ST_WR_1)) begin + if(cons_r_C == TH-1) begin + state_rd_N = ST_RD_0; + end + end + + endcase +end + +// -- DP +always_comb begin : DP_PROC_RD + cons_r_N = cons_r_C; + + done = 1'b0; + + ovld_int = 1'b0; + odat_int = 0; + + case (state_rd_C) + ST_RD_0: begin + if(ordy_int && (state_wr_C != ST_WR_0)) begin + ovld_int = 1'b1; + odat_int = mem_C[0]; + + done = (cons_r_C == TH-1); + cons_r_N = (cons_r_C == TH-1) ? 0 : cons_r_C + 1; + end + end + + ST_RD_1: begin + if(ordy_int && (state_wr_C != ST_WR_1)) begin + ovld_int = 1'b1; + odat_int = mem_C[1]; + + done = (cons_r_C == TH-1); + cons_r_N = (cons_r_C == TH-1) ? 0 : cons_r_C + 1; + end + end + + endcase +end + +// Oreg +skid #(.DATA_WIDTH(PE*SIMD*WEIGHT_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg ( + .clk(clk), .rst(rst), + .ivld(ovld_int), .irdy(ordy_int), .idat(odat_int), + .ovld(ovld), .ordy(ordy), .odat(odat) +); + +endmodule \ No newline at end of file diff --git a/finn-rtllib/ram/ram_p_c.sv b/finn-rtllib/ram/ram_p_c.sv index 553121f2f8..db47fafe18 100644 --- a/finn-rtllib/ram/ram_p_c.sv +++ b/finn-rtllib/ram/ram_p_c.sv @@ -50,11 +50,8 @@ module ram_p_c #( (* ram_style = RAM_STYLE *) logic [DATA_BITS-1:0] ram[DEPTH]; - reg [DATA_BITS-1:0] a_data_reg = 0; - reg [DATA_BITS-1:0] b_data_reg = 0; - - reg [DATA_BITS-1:0] a_data_q = 0; - reg [DATA_BITS-1:0] b_data_q = 0; + logic [DATA_BITS-1:0] a_data_reg = 0; + logic [DATA_BITS-1:0] b_data_reg = 0; always_ff @(posedge clk) begin if(a_en) begin diff --git a/finn_xsi/finn_xsi/sim_engine.py b/finn_xsi/finn_xsi/sim_engine.py index 0d17e581af..7288650624 100644 --- a/finn_xsi/finn_xsi/sim_engine.py +++ b/finn_xsi/finn_xsi/sim_engine.py @@ -109,7 +109,7 @@ def run(self, cycles=None): # Execute Cycle self.ticks += 1 - print(f"Cycle {self.ticks}") + # print(f"Cycle {self.ticks}") strong = False for task in self.tasks: # Tasks read signals and derive updates to schedule for after the clock cycle diff --git a/src/finn/core/rtlsim_exec.py b/src/finn/core/rtlsim_exec.py index 3993b7cd62..dad9359819 100644 --- a/src/finn/core/rtlsim_exec.py +++ b/src/finn/core/rtlsim_exec.py @@ -358,6 +358,24 @@ def rtlsim_exec_finnxsi(model, execution_context, pre_hook=None, post_hook=None) # reset and call rtlsim, including any pre/post hooks finnxsi.reset_rtlsim(sim) + + # automatically load AXI-MM weight images for external_mem nodes + aximm_weights_json = model.get_metadata_prop("vivado_stitch_aximm_weights") + if aximm_weights_json is not None: + import json + + aximm_weights = json.loads(aximm_weights_json) + for aximm_name, npy_path in aximm_weights.items(): + weight_npy = np.load(npy_path) + # Pack weight values (int8 etc.) into a flat byte array for AXI-MM + # Each element is one weight; pack them LSB-first per line + weight_bytes = [] + for line in weight_npy.reshape(-1, weight_npy.shape[-1]): + for val in line: + weight_bytes.append(int(val) & 0xFF) + weight_data = np.array(weight_bytes, dtype=np.uint8) + sim.aximm_ro_image(aximm_name, 0, weight_data.flatten()) + if pre_hook is not None: pre_hook(sim) n_cycles = finnxsi.rtlsim_multi_io( diff --git a/src/finn/custom_op/fpgadataflow/hwcustomop.py b/src/finn/custom_op/fpgadataflow/hwcustomop.py index 5b4b6e7757..efd252207b 100644 --- a/src/finn/custom_op/fpgadataflow/hwcustomop.py +++ b/src/finn/custom_op/fpgadataflow/hwcustomop.py @@ -344,12 +344,14 @@ def generate_hdl_memstream(self, fpgapart, pumped_memory=0): else: pass - def generate_hdl_fetch_weights(self, fpgapart): + def generate_hdl_fetch_weights(self): """Helper function to generate verilog code for fetch_weights component. Currently utilized by MVAU.""" ops = ["MVAU_hls", "MVAU_rtl"] if self.onnx_node.op_type in ops or self.onnx_node.op_type.startswith("Elementwise"): - template_path = os.environ["FINN_ROOT"] + "/finn-rtllib/mlo/fetch_weights_wrapper.v" + template_path = ( + os.environ["FINN_ROOT"] + "/finn-rtllib/fetch_weights/fetch_weights_wrapper.v" + ) mname = self.onnx_node.name wdt = self.get_input_datatype(1) if self.onnx_node.op_type in ops: @@ -358,6 +360,8 @@ def generate_hdl_fetch_weights(self, fpgapart): pe = self.get_nodeattr("PE") simd = self.get_nodeattr("SIMD") n_reps = np.prod(self.get_nodeattr("numInputVectors")) + theight = self.get_nodeattr("TH") + en_mlo = "EN_MLO" if self.get_nodeattr("mlo_max_iter") else "NO_MLO" else: # Eltwise layers only have one parallelism parameter mw = 1 @@ -366,10 +370,13 @@ def generate_hdl_fetch_weights(self, fpgapart): simd = 1 # TODO use broadcast rhs shape here n_reps = np.prod(self.get_nodeattr("rhs_shape")[:-1]) + theight = 1 + en_mlo = "EN_MLO" if self.get_nodeattr("mlo_max_iter") else "NO_MLO" layer_offs = mw * mh # upper bound on how many layers can be supported, set to 64 for now n_max_layers = 64 code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + code_gen_dict = { "$MODULE_NAME_AXI_WRAPPER$": [mname + "_fetch_weights_wrapper"], "$MW$": [str(mw)], @@ -380,6 +387,9 @@ def generate_hdl_fetch_weights(self, fpgapart): "$WEIGHT_WIDTH$": [str(wdt.bitwidth())], "$LAYER_OFFS$": [str(layer_offs)], "$N_LAYERS$": [str(n_max_layers)], + "$TH$": [str(theight)], + "$EN_MLO$": [en_mlo], + "$DWC_MODULE_NAME$": [mname + "_dwc"], } # apply code generation to template with open(template_path, "r") as f: diff --git a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py index b27f82ec93..1534cfb09e 100644 --- a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py +++ b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py @@ -40,6 +40,7 @@ ) from finn.custom_op.fpgadataflow.hwcustomop import HWCustomOp +from finn.util.basic import is_versal from finn.util.data_packing import numpy_to_hls_code, pack_innermost_dim_as_hex_string # ONNX i/o tensor shape assumptions for MatrixVectorActivation: @@ -88,7 +89,7 @@ def get_nodeattr_types(self): "s", False, "internal_decoupled", - {"internal_embedded", "internal_decoupled", "external"}, + {"internal_embedded", "internal_decoupled", "external", "external_mem", "dynamic"}, ), # FPGA resource type for memories in internal_decoupled mode # auto -- let Vivado decide @@ -123,8 +124,15 @@ def get_nodeattr_types(self): # weight data from the weight FIFOs. "runtime_writeable_weights": ("i", False, 0, {0, 1}), "pumpedMemory": ("i", False, 0, {0, 1}), - # dynamic input - "dynamic_input": ("i", False, 0, {0, 1}), + # Matrix unit activation type + "mua_type": ( + "s", + False, + "vector", + {"mv", "mvt", "mm"}, + ), + # tiling + "TH": ("i", False, 1), } my_attrs.update(super().get_nodeattr_types()) return my_attrs @@ -190,6 +198,8 @@ def verify_node(self): except Exception: info_messages.append("""The required MatrixVectorActivation attributes do not exist.""") + # TODO: Verify matrix unit type + # verify the number of inputs depending on noActivation value # check noActivation value to determine the number of inputs no_act = self.get_nodeattr("noActivation") @@ -253,26 +263,24 @@ def get_output_datatype(self, ind=0): """Returns FINN DataType of output.""" return DataType[self.get_nodeattr("outputDataType")] - def get_instream_width(self, ind=0): + def get_instream_width(self, ind=0): # TODO: Hacky, need to clean these calls ... if ind == 0: i_bits = self.get_input_datatype(0).bitwidth() width = i_bits * self.get_nodeattr("SIMD") elif ind == 1: - if self.get_nodeattr("dynamic_input"): - width = ( - self.get_folded_input_shape(ind)[-1] * self.get_input_datatype(ind).bitwidth() - ) - elif ( - self.get_nodeattr("mem_mode") == "internal_decoupled" - or self.get_nodeattr("mem_mode") == "external" - or self.get_nodeattr("mlo_max_iter") - ): - pe = self.get_nodeattr("PE") - simd = self.get_nodeattr("SIMD") - wp = self.get_input_datatype(1).bitwidth() - width = pe * simd * wp - else: - width = 0 + pe = self.get_nodeattr("PE") + simd = self.get_nodeattr("SIMD") + wp = self.get_input_datatype(1).bitwidth() + mem_mode = self.get_nodeattr("mem_mode") + theight = self.get_nodeattr("TH") + + match mem_mode: + case "dynamic": + width = pe * wp + case "external" | "external_mem" | "internal_decoupled": + width = ((pe * simd) * wp) // theight + case _: + width = 0 elif ind == 2: # check if integrated thresholding and return 0 # because threshold values are always embedded @@ -297,22 +305,26 @@ def get_folded_input_shape(self, ind=0): mh = self.get_nodeattr("MH") simd = self.get_nodeattr("SIMD") pe = self.get_nodeattr("PE") + mem_mode = self.get_nodeattr("mem_mode") sf = mw // simd nf = mh // pe vecs = list(self.get_nodeattr("numInputVectors")) + n_vecs = int(np.prod(vecs)) + theight = self.get_nodeattr("TH") if ind == 0: # calculate shape of input 0 folded_input_shape = tuple(vecs + [sf, simd]) elif ind == 1: - if self.get_nodeattr("dynamic_input"): - # calculate shape of input 1 (weights dynamic) - folded_input_shape = tuple(vecs[:2] + [mw] + [nf, pe]) - elif self.get_nodeattr("mem_mode") == "external" or self.get_nodeattr("mlo_max_iter"): - # calculate shape of input 1 (weights) - folded_input_shape = tuple(vecs + [sf * nf, simd * pe]) - else: - raise Exception("Undefined input shape for requested input") + match mem_mode: + case "dynamic": + folded_input_shape = (1, mw, nf, pe) + case "external" | "external_mem" | "internal_decoupled": + folded_input_shape = (n_vecs, sf * nf, (simd * pe) // theight) + case _: + raise Exception("Undefined input shape for requested input") + else: + raise Exception("Undefined input shape for requested input") return folded_input_shape @@ -377,7 +389,7 @@ def uram_estimation(self): (mmode == "internal_decoupled" and mstyle != "ultra") or (mmode == "internal_embedded" and self.calc_wmem() <= 128) or (mmode == "external") - or self.get_nodeattr("mlo_max_iter") + or (mmode == "external_mem") ): return 0 width_multiplier = math.ceil(mem_width / 72) @@ -407,7 +419,7 @@ def bram_estimation(self): (mmode == "internal_decoupled" and mstyle in ["distributed", "ultra"]) or (mmode == "internal_embedded" and self.calc_wmem() <= 128) or (mmode == "external") - or self.get_nodeattr("mlo_max_iter") + or (mmode == "external_mem") ): return 0 # assuming SDP mode RAMB18s (see UG573 Table 1-10) @@ -478,13 +490,12 @@ def minimize_accumulator_width(self, model): idt = self.get_input_datatype(0) - # if runtime-writeable weights or mem_mode=external, then the values of the weights can - # change and we need to use the worst-case values from the datatypes + # if runtime-writeable weights, mem_mode=external, or weights are absent (MLO), + # then we need to use the worst-case values from the datatypes if ( self.get_nodeattr("runtime_writeable_weights") - or self.get_nodeattr("mem_mode") == "external" - or self.get_nodeattr("mlo_max_iter") - or self.get_nodeattr("dynamic_input") + or self.get_nodeattr("mem_mode") in ["external", "external_mem", "dynamic"] + or weights is None ): mw = self.get_nodeattr("MW") mh = self.get_nodeattr("MH") @@ -559,11 +570,11 @@ def minimize_weight_bit_width(self, model): """Minimize the bit width based on the values of the weights""" if not ( self.get_nodeattr("runtime_writeable_weights") - or self.get_nodeattr("mem_mode") == "external" - or self.get_nodeattr("mlo_max_iter") - or self.get_nodeattr("dynamic_input") + or self.get_nodeattr("mem_mode") in ["external", "external_mem", "dynamic"] ): weights = model.get_initializer(self.onnx_node.input[1]) + if weights is None: + return DataType[self.get_nodeattr("weightDataType")] w_min = weights.min() w_max = weights.max() if w_min < 0: @@ -724,13 +735,17 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name): # flipped weight_tensor_pe_flipped = weight_tensor_pe_flipped.reshape(1, -1, pe * simd) weight_tensor_pe_flipped = weight_tensor_pe_flipped.copy() + # tiling + tinner = (pe * simd) // self.get_nodeattr("TH") + weight_tensor_simd_flipped = weight_tensor_simd_flipped.reshape(1, -1, tinner) + weight_tensor_pe_flipped = weight_tensor_pe_flipped.reshape(1, -1, tinner) if weight_file_mode == "decoupled_npy": # save weight stream into npy for cppsim np.save(weight_file_name, weight_tensor_simd_flipped) elif weight_file_mode == "decoupled_verilog_dat": # convert weight values into hexstring weight_width = self.get_instream_width(1) - if self.get_nodeattr("dynamic_input"): + if self.get_nodeattr("mem_mode") == "dynamic": weight_width = weight_width * simd # pad to nearest 4 bits to get hex strings weight_width_padded = roundup_to_integer_multiple(weight_width, 4) @@ -762,7 +777,7 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name): # memstream axi-lite interface will map each mem line to # one or multiple 32-bit words weight_width = self.get_instream_width(1) - if self.get_nodeattr("dynamic_input"): + if self.get_nodeattr("mem_mode") == "dynamic": weight_width = weight_width * simd words_per_memwidth = 2 ** math.ceil(math.log2(weight_width / 32)) if words_per_memwidth < 1: @@ -793,25 +808,22 @@ def generate_params(self, model, path): # weights, if not external weights = model.get_initializer(self.onnx_node.input[1]) if weights is not None: - if mem_mode == "internal_embedded": - # save hlslib-compatible weights in params.h - weight_filename = "{}/params.h".format(code_gen_dir) - self.make_weight_file(weights, "hls_header", weight_filename) - elif mem_mode == "internal_decoupled" or mem_mode == "external": - weight_filename_sim = "{}/input_1.npy".format(code_gen_dir) - # save internal_decoupled weights for cppsim - self.make_weight_file(weights, "decoupled_npy", weight_filename_sim) - if mem_mode == "internal_decoupled": + match mem_mode: + case "internal_embedded": + # save hlslib-compatible weights in params.h + weight_filename = "{}/params.h".format(code_gen_dir) + self.make_weight_file(weights, "hls_header", weight_filename) + case "internal_decoupled" | "external" | "external_mem": + weight_filename_sim = "{}/input_1.npy".format(code_gen_dir) + # save internal_decoupled weights for cppsim + self.make_weight_file(weights, "decoupled_npy", weight_filename_sim) + # if mem_mode == "internal_decoupled": # also save weights as Verilog .dat file # This file will be ignored when synthesizing UltraScale memory. weight_filename_rtl = "{}/memblock.dat".format(code_gen_dir) self.make_weight_file(weights, "decoupled_verilog_dat", weight_filename_rtl) else: - if not ( - mem_mode == "external" - or self.get_nodeattr("mlo_max_iter") - or self.get_nodeattr("dynamic_input") - ): + if mem_mode not in ["external", "dynamic", "external_mem"]: raise Exception( """Invalid setting found, weight values not initialized, but neither "external" case nor MLO.""" @@ -926,33 +938,48 @@ def get_verilog_top_module_intf_names(self): if pumped_compute or self.get_nodeattr("pumpedMemory"): intf_names["clk2x"] = ["ap_clk2x"] - if self.get_nodeattr("mlo_max_iter"): - intf_names["aximm"].append(("axi_mm", 64)) - intf_names["s_axis"].append(("in_idx0_V", 32)) - else: - dynamic_input = self.get_nodeattr("dynamic_input") - mem_mode = self.get_nodeattr("mem_mode") - if dynamic_input: - weight_width = self.get_instream_width(1) - weight_width = weight_width * self.get_nodeattr("SIMD") - intf_names["s_axis"].append(("in1_V", roundup_to_integer_multiple(weight_width, 8))) - else: - if mem_mode == "external": - intf_names["s_axis"].append(("in1_V", self.get_instream_width_padded(1))) - elif mem_mode == "internal_decoupled": - # only expose axilite interface if attribute is set - runtime_writeable = self.get_nodeattr("runtime_writeable_weights") - if runtime_writeable: - intf_names["axilite"] = ["s_axilite"] + match self.get_nodeattr("mem_mode"): + case "external_mem": + intf_names["aximm"].append(("axi_mm", 64)) + if self.get_nodeattr("mlo_max_iter") > 0: + intf_names["s_axis"].append(("in_idx0_V", 32)) + case "dynamic" | "external": + intf_names["s_axis"].append(("in1_V", self.get_instream_width_padded(1))) + case "internal_decoupled": + # only expose axilite interface if attribute is set + if self.get_nodeattr("runtime_writeable_weights"): + intf_names["axilite"] = ["s_axilite"] + return intf_names + def generate_hdl(self, fpgapart): + mem_mode = self.get_nodeattr("mem_mode") + + match mem_mode: + case "dynamic": + self.generate_hdl_dynload() + case "external_mem": + self.generate_hdl_fetch_weights() + case "internal_decoupled": + if self.get_nodeattr("ram_style") == "ultra" and not is_versal(fpgapart): + assert ( + self.get_nodeattr("runtime_writeable_weights") == 1 + ), """Layer with URAM weights must have runtime_writeable_weights=1 + if Ultrascale device is targeted.""" + self.generate_hdl_memstream( + fpgapart, pumped_memory=self.get_nodeattr("pumpedMemory") + ) + def code_generation_ipi(self): source_target = "./ip/verilog/rtl_ops/%s" % self.onnx_node.name cmd = ["file mkdir %s" % source_target] - dyn_input = self.get_nodeattr("dynamic_input") - mem_mode = self.get_nodeattr("mem_mode") + + # # check if additional components are needed - if mem_mode == "internal_decoupled" or self.get_nodeattr("mlo_max_iter") or dyn_input: + mem_mode = self.get_nodeattr("mem_mode") + if mem_mode in ["internal_decoupled", "dynamic", "external_mem"]: + # + # Base runtime_writeable = self.get_nodeattr("runtime_writeable_weights") node_name = self.onnx_node.name # create a hierarchy for this layer, with the same port names @@ -984,133 +1011,177 @@ def code_generation_ipi(self): "-vlnv xilinx.com:interface:axis_rtl:1.0 /%s/%s" % (node_name, din_name) ) - if self.get_nodeattr("mlo_max_iter"): - cmd.append( - "create_bd_intf_pin -mode Slave " - "-vlnv xilinx.com:interface:axis_rtl:1.0 /%s/%s" % (node_name, "in_idx0_V") - ) - cmd.append( - "create_bd_intf_pin -mode Master " - "-vlnv xilinx.com:interface:aximm_rtl:1.0 /%s/%s" % (node_name, "axi_mm") - ) - + # # Instantiate either the HLS or RTL IP depending on operator self.instantiate_ip(cmd) code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") - if dyn_input: - # additional dynamic input - win_name = self.get_verilog_top_module_intf_names()["s_axis"][1][0] - cmd.append( - "create_bd_intf_pin -mode Slave " - "-vlnv xilinx.com:interface:axis_rtl:1.0 /%s/%s" % (node_name, win_name) - ) - # dynamic loader - ram_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/ram/") - dyn_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/dynload/hdl/") - file_suffix = "_dynamic_load_wrapper.v" - # automatically find memstream verilog component in code generation directory - for fname in os.listdir(code_gen_dir): - if fname.endswith(file_suffix): - strm_tmpl = fname - strm_tmpl_name = strm_tmpl[:-2] - sourcefiles = [ - os.path.join(code_gen_dir, strm_tmpl), - ram_rtllib_dir + "ram_p_c.sv", - dyn_rtllib_dir + "dynamic_load.sv", - ] - for f in sourcefiles: - cmd += ["add_files -copy_to %s -norecurse %s" % (source_target, f)] - strm_inst = node_name + "_wdynld" - strm_out_name = "m_axis_0" - elif self.get_nodeattr("mlo_max_iter"): - # instantiate a fetch weights component and connect it to the IP - mlo_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mlo/") - reg_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/skid/") - ram_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/ram/") - dwc_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/dwc/hdl/") - dma_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/cdma/") - file_suffix = "_fetch_weights_wrapper.v" - # automatically find memstream verilog component in code generation directory - for fname in os.listdir(code_gen_dir): - if fname.endswith(file_suffix): - strm_tmpl = fname - strm_tmpl_name = strm_tmpl[:-2] - sourcefiles = [ - os.path.join(code_gen_dir, strm_tmpl), - reg_rtllib_dir + "skid.sv", - ram_rtllib_dir + "ram_p_c.sv", - dwc_rtllib_dir + "axis_adapter.v", - dwc_rtllib_dir + "axis_fifo_adapter.sv", - dwc_rtllib_dir + "axis_fifo.v", - mlo_rtllib_dir + "fetch_weights.sv", - mlo_rtllib_dir + "local_weight_buffer.sv", - ] - # add files from cdma dir - for file in os.listdir(dma_rtllib_dir): - if file.endswith(".sv") or file.endswith(".svh"): - sourcefiles.append(os.path.join(dma_rtllib_dir, file)) - for file in os.listdir(dma_rtllib_dir + "cdma_a/"): - if file.endswith(".sv") or file.endswith(".svh"): - sourcefiles.append(os.path.join(dma_rtllib_dir + "cdma_a/", file)) - for file in os.listdir(dma_rtllib_dir + "cdma_u/"): - if file.endswith(".sv") or file.endswith(".svh"): - sourcefiles.append(os.path.join(dma_rtllib_dir + "cdma_u/", file)) - for file in os.listdir(dma_rtllib_dir + "cdma_x/"): - if file.endswith(".sv") or file.endswith(".svh"): - sourcefiles.append(os.path.join(dma_rtllib_dir + "cdma_x/", file)) - - for f in sourcefiles: - cmd += ["add_files -copy_to %s -norecurse %s" % (source_target, f)] - strm_inst = node_name + "_fetch_weights" - strm_out_name = "out0_V" - # update intf dict to remove weights input and replace with index/tap input - self.get_verilog_top_module_intf_names()["s_axis"] - - elif mem_mode == "internal_decoupled": - # instantiate a streamer and connect it to the IP - axi_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/axi/hdl/") - ms_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/memstream/hdl/") - file_suffix = "_memstream_wrapper.v" - # automatically find memstream verilog component in code generation directory - for fname in os.listdir(code_gen_dir): - if fname.endswith(file_suffix): - strm_tmpl = fname - strm_tmpl_name = strm_tmpl[:-2] - sourcefiles = [ - os.path.join(code_gen_dir, strm_tmpl), - axi_dir + "axilite.sv", - ms_rtllib_dir + "memstream_axi.sv", - ms_rtllib_dir + "memstream.sv", - ] - for f in sourcefiles: - cmd += ["add_files -copy_to %s -norecurse %s" % (source_target, f)] - strm_inst = node_name + "_wstrm" - strm_out_name = "m_axis_0" + + match mem_mode: + # + # Dynamic loader instantiation + case "dynamic": + # additional dynamic input + win_name = self.get_verilog_top_module_intf_names()["s_axis"][1][0] + cmd.append( + "create_bd_intf_pin -mode Slave " + "-vlnv xilinx.com:interface:axis_rtl:1.0 /%s/%s" % (node_name, win_name) + ) + + # dynamic loader + ram_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/ram/") + dyn_rtllib_dir = os.path.join( + os.environ["FINN_ROOT"], "finn-rtllib/dynload/hdl/" + ) + file_suffix = "_dynamic_load_wrapper.v" + # automatically find memstream verilog component in code generation directory + for fname in os.listdir(code_gen_dir): + if fname.endswith(file_suffix): + strm_tmpl = fname + strm_tmpl_name = strm_tmpl[:-2] + sourcefiles = [ + os.path.join(code_gen_dir, strm_tmpl), + ram_rtllib_dir + "ram_p_c.sv", + dyn_rtllib_dir + "dynamic_load.sv", + ] + for f in sourcefiles: + cmd += ["add_files -copy_to %s -norecurse %s" % (source_target, f)] + strm_inst = node_name + "_wdynld" + strm_out_name = "m_axis_0" + + # + # Fetch weights instantiation (MLO or TODO: tiling) + case "external_mem": + # additional inputs + cmd.append( + "create_bd_intf_pin -mode Master " + "-vlnv xilinx.com:interface:aximm_rtl:1.0 /%s/%s" % (node_name, "axi_mm") + ) + if self.get_nodeattr("mlo_max_iter") > 0: + cmd.append( + "create_bd_intf_pin -mode Slave " + "-vlnv xilinx.com:interface:axis_rtl:1.0 /%s/%s" + % (node_name, "in_idx0_V") + ) + + # instantiate a fetch weights component and connect it to the IP + ram_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/ram/") + reg_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/skid/") + que_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/fifo/hdl/") + fwg_rtllib_dir = os.path.join( + os.environ["FINN_ROOT"], "finn-rtllib/fetch_weights/" + ) + dma_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/cdma/") + file_suffix = "_fetch_weights_wrapper.v" + # automatically find memstream verilog component in code generation directory + for fname in os.listdir(code_gen_dir): + if fname.endswith(file_suffix): + strm_tmpl = fname + strm_tmpl_name = strm_tmpl[:-2] + sourcefiles = [ + os.path.join(code_gen_dir, strm_tmpl), + ram_rtllib_dir + "ram_p_c.sv", + reg_rtllib_dir + "skid.sv", + que_rtllib_dir + "Q_srl.v", + fwg_rtllib_dir + "fetch_weights.sv", + fwg_rtllib_dir + "local_weight_buffer.sv", + ] + # Create Vivado axis_dwidth_converter IP + theight = self.get_nodeattr("TH") + wdt = self.get_input_datatype(1) + iwsimd = ( + (self.get_nodeattr("PE") * self.get_nodeattr("SIMD")) // theight + if theight > 1 + else self.get_nodeattr("SIMD") + ) + ds_bits_ba = ((iwsimd * wdt.bitwidth() + 7) // 8) * 8 + dwc_ip_name = node_name + "_dwc" + s_bytes = 256 // 8 + m_bytes = ds_bits_ba // 8 + cmd += [ + "create_ip -name axis_dwidth_converter -vendor xilinx.com " + "-library ip -version 1.1 -module_name %s" % dwc_ip_name, + "set_property -dict [list " + "CONFIG.S_TDATA_NUM_BYTES {%d} " + "CONFIG.M_TDATA_NUM_BYTES {%d} " + "CONFIG.HAS_TLAST {1} " + "CONFIG.HAS_TKEEP {1} " + "] [get_ips %s]" % (s_bytes, m_bytes, dwc_ip_name), + "generate_target all [get_ips %s]" % dwc_ip_name, + ] + + # add files from cdma dir + for file in os.listdir(dma_rtllib_dir): + if file.endswith(".sv") or file.endswith(".svh"): + sourcefiles.append(os.path.join(dma_rtllib_dir, file)) + for file in os.listdir(dma_rtllib_dir + "cdma_a/"): + if file.endswith(".sv") or file.endswith(".svh"): + sourcefiles.append(os.path.join(dma_rtllib_dir + "cdma_a", file)) + for file in os.listdir(dma_rtllib_dir + "cdma_u/"): + if file.endswith(".sv") or file.endswith(".svh"): + sourcefiles.append(os.path.join(dma_rtllib_dir + "cdma_u/", file)) + for file in os.listdir(dma_rtllib_dir + "cdma_x/"): + if file.endswith(".sv") or file.endswith(".svh"): + sourcefiles.append(os.path.join(dma_rtllib_dir + "cdma_x/", file)) + for f in sourcefiles: + cmd += ["add_files -copy_to %s -norecurse %s" % (source_target, f)] + strm_inst = node_name + "_fetch_weights" + strm_out_name = "out0_V" + # update intf dict to remove weights input and replace with index/tap input + self.get_verilog_top_module_intf_names()["s_axis"] + + # + # Memstream instantiation + case "internal_decoupled": + # instantiate a streamer and connect it to the IP + axi_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/axi/hdl/") + ms_rtllib_dir = os.path.join( + os.environ["FINN_ROOT"], "finn-rtllib/memstream/hdl/" + ) + file_suffix = "_memstream_wrapper.v" + # automatically find memstream verilog component in code generation directory + for fname in os.listdir(code_gen_dir): + if fname.endswith(file_suffix): + strm_tmpl = fname + strm_tmpl_name = strm_tmpl[:-2] + sourcefiles = [ + os.path.join(code_gen_dir, strm_tmpl), + axi_dir + "axilite.sv", + ms_rtllib_dir + "memstream_axi.sv", + ms_rtllib_dir + "memstream.sv", + ] + for f in sourcefiles: + cmd += ["add_files -copy_to %s -norecurse %s" % (source_target, f)] + strm_inst = node_name + "_wstrm" + strm_out_name = "m_axis_0" cmd.append( "create_bd_cell -type hier -reference %s /%s/%s" % (strm_tmpl_name, node_name, strm_inst) ) - if self.get_nodeattr("mlo_max_iter"): - cmd.append( - "connect_bd_intf_net [get_bd_intf_pins %s/%s] " - "[get_bd_intf_pins %s/%s/%s]" - % (node_name, "in_idx0_V", node_name, strm_inst, "in_idx0_V") - ) + # + # Connect + match mem_mode: + case "dynamic": + cmd.append( + "connect_bd_intf_net [get_bd_intf_pins %s/%s] " + "[get_bd_intf_pins %s/%s/s_axis_0]" + % (node_name, win_name, node_name, strm_inst) + ) - cmd.append( - "connect_bd_intf_net [get_bd_intf_pins %s/%s] " - "[get_bd_intf_pins %s/%s/%s]" - % (node_name, "axi_mm", node_name, strm_inst, "axi_mm") - ) + case "external_mem": + cmd.append( + "connect_bd_intf_net [get_bd_intf_pins %s/%s] " + "[get_bd_intf_pins %s/%s/%s]" + % (node_name, "axi_mm", node_name, strm_inst, "axi_mm") + ) + if self.get_nodeattr("mlo_max_iter") > 0: + cmd.append( + "connect_bd_intf_net [get_bd_intf_pins %s/%s] " + "[get_bd_intf_pins %s/%s/%s]" + % (node_name, "in_idx0_V", node_name, strm_inst, "in_idx0_V") + ) - if dyn_input: - cmd.append( - "connect_bd_intf_net [get_bd_intf_pins %s/%s] " - "[get_bd_intf_pins %s/%s/s_axis_0]" - % (node_name, win_name, node_name, strm_inst) - ) cmd.append( "connect_bd_intf_net [get_bd_intf_pins %s/%s/%s] " "[get_bd_intf_pins %s/%s/in1_V]" @@ -1124,11 +1195,10 @@ def code_generation_ipi(self): "connect_bd_net [get_bd_pins %s/%s] [get_bd_pins %s/%s/ap_clk]" % (node_name, clk_name, node_name, strm_inst) ) + # if using 2x pumped memory, connect the memstreamer's 2x clk input # to the 2x clock port. otherwise connect it to the regular clock port. - if mem_mode == "internal_decoupled" and not ( - self.get_nodeattr("mlo_max_iter") or dyn_input - ): + if mem_mode == "internal_decoupled": if self.get_nodeattr("pumpedMemory"): cmd.append( "connect_bd_net [get_bd_pins %s/%s] [get_bd_pins %s/%s/ap_clk2x]" @@ -1139,8 +1209,8 @@ def code_generation_ipi(self): "connect_bd_net [get_bd_pins %s/%s] [get_bd_pins %s/%s/ap_clk2x]" % (node_name, clk_name, node_name, strm_inst) ) - # runtime writeable weights - if runtime_writeable: + # runtime writeable weights (skip for MLO nodes) + if runtime_writeable and not self.get_nodeattr("mlo_max_iter"): axilite_name = self.get_verilog_top_module_intf_names()["axilite"][0] cmd.append( "create_bd_intf_pin -mode Slave " @@ -1154,6 +1224,7 @@ def code_generation_ipi(self): ) # TODO calculate and pass in segment size here cmd.append("assign_bd_address") + cmd.append( "connect_bd_net [get_bd_pins %s/%s] [get_bd_pins %s/%s/%s]" % (node_name, rst_name, node_name, node_name, rst_name) @@ -1175,11 +1246,12 @@ def code_generation_ipi(self): # save bd cmd.append("save_bd_design") - elif (mem_mode == "internal_embedded" or mem_mode == "external") and not self.get_nodeattr( - "mlo_max_iter" - ): + + elif mem_mode in ["internal_embedded", "external"]: # base class impl sufficient for internal_embedded/external modes self.instantiate_ip(cmd) + else: raise Exception("Unrecognized mem_mode for MatrixVectorActivation") + return cmd diff --git a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py index 9cd6fc2a9d..f3aa269789 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py @@ -31,7 +31,7 @@ from finn.custom_op.fpgadataflow.matrixvectoractivation import MVAU from finn.custom_op.fpgadataflow.rtlbackend import RTLBackend -from finn.util.basic import get_dsp_block, is_versal +from finn.util.basic import get_dsp_block from finn.util.data_packing import npy_to_rtlsim_input, rtlsim_output_to_npy # ONNX i/o tensor shape assumptions for MatrixVectorActivation_rtl: @@ -58,7 +58,6 @@ def get_nodeattr_types(self): def execute_node(self, context, graph): mode = self.get_nodeattr("exec_mode") - dynamic_input = self.get_nodeattr("dynamic_input") mem_mode = self.get_nodeattr("mem_mode") node = self.onnx_node @@ -91,7 +90,7 @@ def execute_node(self, context, graph): ) if in_ind == 1: - if dynamic_input or self.get_nodeattr("mlo_max_iter"): + if mem_mode in ["dynamic", "external"]: reshaped_input = context[inputs].reshape(-1, context[inputs].shape[-1]) self.make_weight_file( reshaped_input, "decoupled_npy", "{}/input_1.npy".format(code_gen_dir) @@ -101,18 +100,15 @@ def execute_node(self, context, graph): nbits = self.get_instream_width() inp = npy_to_rtlsim_input("{}/input_0.npy".format(code_gen_dir), export_idt, nbits) super().reset_rtlsim(sim) - if ( - dynamic_input - or mem_mode in ["external", "internal_decoupled"] - or self.get_nodeattr("mlo_max_iter") - ): + if mem_mode in ["external", "dynamic", "internal_decoupled", "external_mem"]: wnbits = self.get_instream_width(1) - if dynamic_input: + if mem_mode == "dynamic": wnbits = wnbits * self.get_nodeattr("SIMD") export_wdt = self.get_input_datatype(1) wei = npy_to_rtlsim_input("{}/input_1.npy".format(code_gen_dir), export_wdt, wnbits) num_w_reps = np.prod(self.get_nodeattr("numInputVectors")) + num_w_reps = num_w_reps // self.get_nodeattr("TH") io_dict = { "inputs": {"in0": inp, "in1": wei * num_w_reps}, @@ -164,23 +160,43 @@ def instantiate_ip(self, cmd): # instantiate the RTL IP node_name = self.onnx_node.name code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") - rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu/") - sourcefiles = [ - "mvu_pkg.sv", - "mvu_vvu_axi.sv", - "replay_buffer.sv", - "mvu.sv", - "mvu_vvu_8sx9_dsp58.sv", - "add_multi.sv", - ] + + theight = self.get_nodeattr("TH") + if theight > 1: + rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu_tiled/") + sourcefiles = [ + "../fifo/hdl/Q_srl.v", + "../skid/skid.sv", + "../ram/ram_p_c.sv", + "mvu_tiled_axi.sv", + "cu_mvau_tiled.sv", + "acc_stage.sv", + "add_tree.sv", + "reorder_out.sv", + "replay_buff_tile.sv", + "weights_buff_tile.sv", + ] + else: + rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu/") + sourcefiles = [ + "mvu_pkg.sv", + "mvu_vvu_axi.sv", + "replay_buffer.sv", + "mvu.sv", + "mvu_vvu_8sx9_dsp58.sv", + "add_multi.sv", + ] sourcefiles = [ os.path.join(code_gen_dir, self.get_nodeattr("gen_top_module") + "_wrapper.v") ] + [rtllib_dir + _ for _ in sourcefiles] for f in sourcefiles: cmd.append("add_files -norecurse %s" % (f)) - mem_mode = self.get_nodeattr("mem_mode") - if mem_mode == "internal_decoupled" or self.get_nodeattr("mlo_max_iter"): + if self.get_nodeattr("mem_mode") in [ + "internal_decoupled", + "dynamic", + "external_mem", + ] or self.get_nodeattr("mlo_max_iter"): cmd.append( "create_bd_cell -type hier -reference %s /%s/%s" % ( @@ -280,9 +296,9 @@ def generate_hdl(self, model, fpgapart, clk): wdt = self.get_input_datatype(1) narrow_weights = ( 0 - if np.min(weights) == wdt.min() - or self.get_nodeattr("dynamic_input") - or (self.get_nodeattr("mlo_max_iter") > 1) + if weights is None + or np.min(weights) == wdt.min() + or self.get_nodeattr("mem_mode") in ["dynamic", "external_mem"] else 1 ) code_gen_dict["$NARROW_WEIGHTS$"] = str(narrow_weights) @@ -305,28 +321,20 @@ def generate_hdl(self, model, fpgapart, clk): ) as f: f.write(template_wrapper) - dynamic_input = self.get_nodeattr("dynamic_input") - mem_mode = self.get_nodeattr("mem_mode") + super().generate_hdl(fpgapart) - if dynamic_input: - self.generate_hdl_dynload() - elif mem_mode == "internal_decoupled" and not self.get_nodeattr("mlo_max_iter"): - if self.get_nodeattr("ram_style") == "ultra" and not is_versal(fpgapart): - runtime_writeable = self.get_nodeattr("runtime_writeable_weights") - assert ( - runtime_writeable == 1 - ), """Layer with URAM weights must have runtime_writeable_weights=1 - if Ultrascale device is targeted.""" - self.generate_hdl_memstream(fpgapart, pumped_memory=self.get_nodeattr("pumpedMemory")) - elif self.get_nodeattr("mlo_max_iter"): - self.generate_hdl_fetch_weights(fpgapart) # set ipgen_path and ip_path so that HLS-Synth transformation # and stich_ip transformation do not complain self.set_nodeattr("ipgen_path", code_gen_dir) self.set_nodeattr("ip_path", code_gen_dir) def prepare_codegen_default(self, fpgapart, clk): - template_path = os.environ["FINN_ROOT"] + "/finn-rtllib/mvu/mvu_vvu_axi_wrapper.v" + if self.get_nodeattr("TH") > 1: + template_path = ( + os.environ["FINN_ROOT"] + "/finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v" + ) + else: + template_path = os.environ["FINN_ROOT"] + "/finn-rtllib/mvu/mvu_vvu_axi_wrapper.v" # check if settings are valid pumped_compute = self.get_nodeattr("pumpedCompute") @@ -335,6 +343,7 @@ def prepare_codegen_default(self, fpgapart, clk): raise Exception( "Clock pumping an input of SIMD=1 is not meaningful. Please increase SIMD." ) + dsp_block = get_dsp_block(fpgapart) code_gen_dict = {} code_gen_dict["$IS_MVU$"] = [str(1)] @@ -344,6 +353,7 @@ def prepare_codegen_default(self, fpgapart, clk): code_gen_dict["$MH$"] = [str(self.get_nodeattr("MH"))] code_gen_dict["$PE$"] = [str(self.get_nodeattr("PE"))] code_gen_dict["$SIMD$"] = [str(simd)] + code_gen_dict["$TH$"] = [str(self.get_nodeattr("TH"))] code_gen_dict["$ACTIVATION_WIDTH$"] = [str(self.get_input_datatype(0).bitwidth())] code_gen_dict["$WEIGHT_WIDTH$"] = [str(self.get_input_datatype(1).bitwidth())] code_gen_dict["$ACCU_WIDTH$"] = [str(self.get_output_datatype().bitwidth())] @@ -357,26 +367,49 @@ def prepare_codegen_default(self, fpgapart, clk): def get_rtl_file_list(self, abspath=False): if abspath: code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + "/" - rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu/") + if self.get_nodeattr("TH") > 1: + rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu_tiled/") + else: + rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu/") else: code_gen_dir = "" rtllib_dir = "" - verilog_files = [ - "mvu_pkg.sv", - "mvu_vvu_axi.sv", - "replay_buffer.sv", - "mvu.sv", - "mvu_vvu_8sx9_dsp58.sv", - "add_multi.sv", - ] - verilog_files = [ - os.path.join(code_gen_dir, self.get_nodeattr("gen_top_module") + "_wrapper.v") - ] + [rtllib_dir + _ for _ in verilog_files] + if self.get_nodeattr("TH") > 1: + verilog_files = [ + "../fifo/hdl/Q_srl.v", + "../skid/skid.sv", + "../ram/ram_p_c.sv", + "acc_stage.sv", + "add_tree.sv", + "replay_buff_tile.sv", + "weights_buff_tile.sv", + "reorder_out.sv", + "cu_mvau_tiled.sv", + "mvu_tiled_axi.sv", + ] + verilog_files = [ + os.path.join(code_gen_dir, self.get_nodeattr("gen_top_module") + "_wrapper.v") + ] + [rtllib_dir + _ for _ in verilog_files] + else: + verilog_files = [ + "mvu_pkg.sv", + "mvu_vvu_axi.sv", + "replay_buffer.sv", + "mvu.sv", + "mvu_vvu_8sx9_dsp58.sv", + "add_multi.sv", + ] + verilog_files = [ + os.path.join(code_gen_dir, self.get_nodeattr("gen_top_module") + "_wrapper.v") + ] + [rtllib_dir + _ for _ in verilog_files] return verilog_files def get_verilog_paths(self): verilog_paths = super().get_verilog_paths() - verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mvu") + if self.get_nodeattr("TH") > 1: + verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mvu_tiled") + else: + verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mvu") return verilog_paths diff --git a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py index 9bc150982f..4a8e4ab18b 100644 --- a/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py +++ b/src/finn/transformation/fpgadataflow/convert_to_hw_layers.py @@ -1572,7 +1572,7 @@ def apply(self, model): noActivation=0, numInputVectors=list(mm_in_shape[:-1]), name="MVAU_" + n.name, - dynamic_input=W is None, + mem_mode="dynamic" if W is None else "internal_decoupled", inFIFODepths=[2, 2] if W is None else [2], ) graph.node.insert(node_ind, new_node) @@ -1604,7 +1604,7 @@ def apply(self, model): noActivation=1, numInputVectors=list(mm_in_shape[:-1]), name="MVAU_" + n.name, - dynamic_input=W is None, + mem_mode="dynamic" if W is None else "internal_decoupled", inFIFODepths=[2, 2] if W is None else [2], ) graph.node.insert(node_ind, new_node) diff --git a/src/finn/transformation/fpgadataflow/create_stitched_ip.py b/src/finn/transformation/fpgadataflow/create_stitched_ip.py index a2a3fde233..8544e1e6b5 100644 --- a/src/finn/transformation/fpgadataflow/create_stitched_ip.py +++ b/src/finn/transformation/fpgadataflow/create_stitched_ip.py @@ -95,6 +95,7 @@ def __init__(self, fpgapart, clk_ns, ip_name="finn_design", vitis=False, signatu self.vitis = vitis self.signature = signature self.has_aximm = False + self.aximm_weight_files = {} self.aximm_idx = 0 self.has_m_axis = False self.m_axis_idx = 0 @@ -186,94 +187,51 @@ def connect_axi(self, node, model): ) self.intf_names["axilite"].append(ext_if_name) - if not node_inst.get_nodeattr("mlo_max_iter"): - if node.op_type == "FINNLoop": - for mm_intf_name in aximm_intf_name: - self.connect_cmds.append( - "make_bd_intf_pins_external [get_bd_intf_pins %s/%s]" - % (inst_name, mm_intf_name[0]) - ) - self.connect_cmds.append( - "set_property name %s [get_bd_intf_ports %s_0]" - % (mm_intf_name[0], mm_intf_name[0]) - ) - self.connect_cmds.append("assign_bd_address") - - if mm_intf_name[0] == "m_axi_hbm": - seg_name = "%s/%s/SEG_%s_Reg" % ( - inst_name, - mm_intf_name[0], - mm_intf_name[0], - ) - else: - seg_name = "%s/%s/SEG_%s_Reg" % ( - inst_name, - mm_intf_name[0], - mm_intf_name[0], - ) - self.connect_cmds.append( - "set_property offset 0 [get_bd_addr_segs {%s}]" % (seg_name) - ) - # TODO should propagate this information from the node instead of 256M - self.connect_cmds.append( - "set_property range 256M [get_bd_addr_segs {%s}]" % (seg_name) - ) - self.intf_names["aximm"].append((mm_intf_name[0], mm_intf_name[1])) - self.has_aximm = True - self.aximm_idx += 1 - - elif len(aximm_intf_name) != 0: - self.connect_cmds.append( - "make_bd_intf_pins_external [get_bd_intf_pins %s/%s]" - % (inst_name, aximm_intf_name[0][0]) - ) - ext_if_name = "m_axi_gmem%d" % (self.aximm_idx) - self.connect_cmds.append( - "set_property name %s [get_bd_intf_ports m_axi_gmem_0]" % ext_if_name - ) - self.connect_cmds.append("assign_bd_address") - seg_name = "%s/Data_m_axi_gmem/SEG_%s_Reg" % (inst_name, ext_if_name) - self.connect_cmds.append( - "set_property offset 0 [get_bd_addr_segs {%s}]" % (seg_name) - ) - # TODO should propagate this information from the node instead of 4G - self.connect_cmds.append( - "set_property range 4G [get_bd_addr_segs {%s}]" % (seg_name) - ) - self.intf_names["aximm"].append((ext_if_name, aximm_intf_name[0][1])) - self.has_aximm = True - self.aximm_idx += 1 - else: + is_mlo = node_inst.get_nodeattr("mlo_max_iter") + if is_mlo: self.is_mlo = True - for mm_intf_name in aximm_intf_name: + + for mm_intf_name in aximm_intf_name: + self.connect_cmds.append( + "make_bd_intf_pins_external [get_bd_intf_pins %s/%s]" % (inst_name, mm_intf_name[0]) + ) + + # Determine external interface name and address segment path + if node.op_type == "FINNLoop": + ext_if_name = mm_intf_name[0] self.connect_cmds.append( - "make_bd_intf_pins_external [get_bd_intf_pins %s/%s]" - % (inst_name, mm_intf_name[0]) + "set_property name %s [get_bd_intf_ports %s_0]" % (ext_if_name, ext_if_name) ) - # ext_if_name = "m_axi_gmem%d" % (self.aximm_idx) - # ext_if_name = f"m_axi_{inst_name}" - idx = inputs.index(node.input[1]) - ext_if_name = f"m_axi_MVAU_id_{idx}" + seg_name = "%s/%s/SEG_%s_Reg" % (inst_name, ext_if_name, ext_if_name) + else: + # Derive a unique name from graph input index or instance name + if node.input[1] in inputs: + idx = inputs.index(node.input[1]) + ext_if_name = f"m_axi_MVAU_id_{idx}" + else: + ext_if_name = f"m_axi_{inst_name}_{self.aximm_idx}" self.connect_cmds.append( "set_property name %s [get_bd_intf_ports axi_mm_0]" % (ext_if_name) ) - self.connect_cmds.append("assign_bd_address") - seg_name = "%s/%s_fetch_weights/axi_mm/SEG_%s_Reg" % ( inst_name, inst_name, ext_if_name, ) - self.connect_cmds.append( - "set_property offset 0 [get_bd_addr_segs {%s}]" % (seg_name) - ) - # TODO should propagate this information from the node instead of 256M - self.connect_cmds.append( - "set_property range 256M [get_bd_addr_segs {%s}]" % (seg_name) - ) - self.intf_names["aximm"].append((ext_if_name, mm_intf_name[1])) - self.has_aximm = True - self.aximm_idx += 1 + + self.connect_cmds.append("assign_bd_address") + self.connect_cmds.append("set_property offset 0 [get_bd_addr_segs {%s}]" % (seg_name)) + # TODO should propagate this information from the node instead of 256M + self.connect_cmds.append("set_property range 256M [get_bd_addr_segs {%s}]" % (seg_name)) + self.intf_names["aximm"].append((ext_if_name, mm_intf_name[1])) + # Track weight data files for AXI-MM simulation + if not node.op_type == "FINNLoop": + code_gen_dir = node_inst.get_nodeattr("code_gen_dir_ipgen") + npy_path = os.path.join(code_gen_dir, "input_1.npy") + if os.path.isfile(npy_path): + self.aximm_weight_files[ext_if_name] = npy_path + self.has_aximm = True + self.aximm_idx += 1 def connect_m_axis_external(self, node, idx=None): inst_name = node.name @@ -536,6 +494,10 @@ def apply(self, model): block_vlnv = "%s:%s:%s:1.0" % (block_vendor, block_library, block_name) model.set_metadata_prop("vivado_stitch_vlnv", block_vlnv) model.set_metadata_prop("vivado_stitch_ifnames", json.dumps(self.intf_names)) + if self.aximm_weight_files: + model.set_metadata_prop( + "vivado_stitch_aximm_weights", json.dumps(self.aximm_weight_files) + ) tcl.append( ( "ipx::package_project -root_dir %s/ip -vendor %s " diff --git a/src/finn/transformation/fpgadataflow/set_fifo_depths.py b/src/finn/transformation/fpgadataflow/set_fifo_depths.py index ba41def8d8..27a4202cb3 100644 --- a/src/finn/transformation/fpgadataflow/set_fifo_depths.py +++ b/src/finn/transformation/fpgadataflow/set_fifo_depths.py @@ -311,7 +311,7 @@ def apply(self, model): "ElementwiseAdd_hls", "ElementwiseMul_hls", ] - modified_mlo_nodes = [] + modified_mlo_nodes = {} for node in model.graph.node: # verify assumptions assert is_hls_node(node) or is_rtl_node(node), "Found non-fpgadataflow node: " + str( @@ -351,14 +351,47 @@ def apply(self, model): "Changed mem_mode from external to internal_decoupled for " + node.onnx_node.name ) - # do necessary temporary settings for mlo nodes + # do necessary temporary settings for external_mem nodes if node.onnx_node.op_type in mlo_optypes: mlo_max_iter = node.get_nodeattr("mlo_max_iter") - if mlo_max_iter: - modified_mlo_nodes.append(node.onnx_node.name) + has_mem_mode = "mem_mode" in node.get_nodeattr_types() + mmode = node.get_nodeattr("mem_mode") if has_mem_mode else None + if mlo_max_iter or mmode == "external_mem": + node_mlo_info = { + "orig_mem_mode": mmode, + "orig_mlo_max_iter": mlo_max_iter, + "saved_initializer": None, + } node.set_nodeattr("mlo_max_iter", 0) if node.onnx_node.op_type.startswith("MVAU"): node.set_nodeattr("mem_mode", "external") + # If the weight tensor has an initializer and is not + # already a graph input, we must promote it so that + # InsertFIFO / CreateStitchedIP treat it as a streaming + # input during FIFO sizing simulation. + param_input = node.onnx_node.input[1] + input_names_set = {inp.name for inp in model.graph.input} + if param_input not in input_names_set: + # Save and remove initializer + node_mlo_info["saved_initializer"] = model.get_initializer(param_input) + model.del_initializer(param_input) + # Move value_info to graph.input + param_vi = model.get_tensor_valueinfo(param_input) + if param_vi is not None and param_vi in model.graph.value_info: + model.graph.value_info.remove(param_vi) + else: + param_shape = model.get_tensor_shape(param_input) + param_vi = helper.make_tensor_value_info( + param_input, TensorProto.FLOAT, param_shape + ) + model.graph.input.append(param_vi) + # Ensure inFIFODepths covers the weight stream (index 1) + # since it was computed while mem_mode was external_mem + ifd = node.get_nodeattr("inFIFODepths") + if len(ifd) <= 1: + w_size = np.prod(node.get_folded_input_shape(1)[:-1]) + ifd.append(int(w_size) if w_size > 1 else 2) + node.set_nodeattr("inFIFODepths", ifd) elif ( node.onnx_node.op_type == "Thresholding_rtl" or node.onnx_node.op_type.startswith("Elementwise") @@ -375,7 +408,7 @@ def apply(self, model): dummy_threshs = np.sort(dummy_threshs, axis=1) model.set_initializer(param_input, dummy_threshs) self.ind_map[node.onnx_node.name] = ind - self.mlo_max_iter = mlo_max_iter + modified_mlo_nodes[node.onnx_node.name] = node_mlo_info reset_implementation(node) # insert stream infrastructure (DWC/FIFO) model = model.transform(InsertDWC()) @@ -431,6 +464,7 @@ def apply(self, model): # Apply depths back into the model; # also set in/outFIFODepths to zero for non-FIFO # nodes, preventing further FIFO insertion + weight_fifos_to_remove = [] for node in model.graph.node: # set FIFO depth, reset FIFO implementation, # and set implementation/ram styles @@ -465,20 +499,60 @@ def apply(self, model): node_inst.set_nodeattr("mem_mode", "external") reset_implementation(node_inst) modified_extw_nodes.remove(node.name) - # do the same resetting for mlo nodes + # do the same resetting for mlo / external_mem nodes if node.op_type in mlo_optypes: if node.name in modified_mlo_nodes and node.op_type.startswith("MVAU"): node_inst = getCustomOp(node) - node_inst.set_nodeattr("mlo_max_iter", self.mlo_max_iter) - node_inst.set_nodeattr("mem_mode", "internal_decoupled") + node_mlo_info = modified_mlo_nodes[node.name] + node_inst.set_nodeattr("mlo_max_iter", node_mlo_info["orig_mlo_max_iter"]) + node_inst.set_nodeattr("mem_mode", node_mlo_info["orig_mem_mode"]) + # Remove the weight-stream FIFO that was inserted during + # FIFO sizing (input index 1) and restore the original + # weight tensor connection. + if node_mlo_info["saved_initializer"] is not None: + # node.input[1] now points to FIFO output; find the FIFO + # param_input = node.input[1] + weight_fifo_out = node.input[1] + weight_fifo = model.find_producer(weight_fifo_out) + if weight_fifo is not None and weight_fifo.op_type.startswith( + "StreamingFIFO" + ): + # The original weight tensor is the FIFO's input + orig_weight_name = weight_fifo.input[0] + # Reconnect MVAU directly to original weight tensor + node.input[1] = orig_weight_name + # Defer removal of the FIFO node until after iteration + weight_fifos_to_remove.append((weight_fifo, weight_fifo_out)) + else: + orig_weight_name = weight_fifo_out + # Restore initializer and demote weight from graph input + model.set_initializer( + orig_weight_name, node_mlo_info["saved_initializer"] + ) + for gi in list(model.graph.input): + if gi.name == orig_weight_name: + model.graph.input.remove(gi) + model.graph.value_info.append(gi) + break reset_implementation(node_inst) - modified_mlo_nodes.remove(node.name) + del modified_mlo_nodes[node.name] + + # Remove weight-stream FIFOs that were deferred during the loop + for weight_fifo, weight_fifo_out in weight_fifos_to_remove: + model.graph.node.remove(weight_fifo) + for vi in list(model.graph.value_info): + if vi.name == weight_fifo_out: + model.graph.value_info.remove(vi) + break + if weight_fifo.name in fifos: + del fifos[weight_fifo.name] sorted_ind_map = dict(sorted(self.ind_map.items(), key=lambda item: item[1])) for k, v in sorted_ind_map.items(): node = model.get_node_from_name(k) node_inst = getCustomOp(node) - node_inst.set_nodeattr("mlo_max_iter", self.mlo_max_iter) + node_mlo_info = modified_mlo_nodes[node.name] + node_inst.set_nodeattr("mlo_max_iter", node_mlo_info["orig_mlo_max_iter"]) # remove initializer again param_input = node.input[1] param_input_vi = model.get_tensor_valueinfo(param_input) @@ -488,7 +562,7 @@ def apply(self, model): if node.op_type.startswith("Elementwise"): node_inst.set_nodeattr("rhs_style", "input") reset_implementation(node_inst) - modified_mlo_nodes.remove(node.name) + del modified_mlo_nodes[node.name] assert ( len(modified_extw_nodes) == 0 and len(fifos.keys()) == 0 diff --git a/tests/fpgadataflow/test_fpgadataflow_finnloop.py b/tests/fpgadataflow/test_fpgadataflow_finnloop.py index 15182a181f..7d6ab494bd 100644 --- a/tests/fpgadataflow/test_fpgadataflow_finnloop.py +++ b/tests/fpgadataflow/test_fpgadataflow_finnloop.py @@ -178,6 +178,7 @@ def make_loop_modelwrapper( "ActVal": 0, "binaryXnorMode": 0, "noActivation": 1, + "mem_mode": "external_mem", }, ), create_node( @@ -211,6 +212,7 @@ def make_loop_modelwrapper( "ActVal": 0, "binaryXnorMode": 0, "noActivation": 1, + "mem_mode": "external_mem", }, ), create_node( @@ -244,6 +246,7 @@ def make_loop_modelwrapper( "ActVal": 0, "binaryXnorMode": 0, "noActivation": 1, + "mem_mode": "external_mem", }, ), create_node( diff --git a/tests/fpgadataflow/test_fpgadataflow_mvau.py b/tests/fpgadataflow/test_fpgadataflow_mvau.py index 780efd170f..c828b4717e 100644 --- a/tests/fpgadataflow/test_fpgadataflow_mvau.py +++ b/tests/fpgadataflow/test_fpgadataflow_mvau.py @@ -856,14 +856,111 @@ def test_fpgadataflow_rtl_mvau( ).all(), "Output of ONNX model not matching output of stitched-IP RTL model!" +@pytest.mark.parametrize("mh", [18]) +@pytest.mark.parametrize("mw", [36]) +@pytest.mark.parametrize("pe", [6]) +@pytest.mark.parametrize("simd", [3]) +@pytest.mark.parametrize("th", [3]) +@pytest.mark.parametrize("idt_wdt", [[DataType["UINT8"], DataType["INT8"]]]) +@pytest.mark.parametrize("clk_ns", [4]) +@pytest.mark.fpgadataflow +@pytest.mark.slow +@pytest.mark.vivado +def test_fpgadataflow_rtl_tiled_mvau(mh, mw, pe, simd, th, idt_wdt, clk_ns): + # Tiled MVAU only supported on Versal (DSP58) + part = "xcvc1902-vsva2197-2MP-e-S" + + if (pe * simd) % th != 0: + pytest.skip("(PE * SIMD) must be divisible by TH") + + if mw % simd != 0: + pytest.skip("MW must be divisible by SIMD") + + if mh % pe != 0: + pytest.skip("MH must be divisible by PE") + + idt, wdt = idt_wdt + # Create test input vector (produced by SWG) + ofm_shape = (3, 3) + ofm_h, ofm_w = ofm_shape + ifm = helper.make_tensor_value_info("ifm", TensorProto.FLOAT, [1, ofm_h, ofm_w, mw]) + ofm = helper.make_tensor_value_info("ofm", TensorProto.FLOAT, (1, ofm_h, ofm_w, mh)) + W = gen_finn_dt_tensor(wdt, (mw, mh)) + model = make_single_matmul_modelwrapper(ifm, ofm, idt, wdt, W) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(GiveReadableTensorNames()) + + # Create MatMul & obtain golden reference output + A = gen_finn_dt_tensor( + model.get_tensor_datatype("global_in"), model.get_tensor_shape("global_in") + ) + input_dict = prepare_inputs(A, idt, wdt, inp_name="global_in") + + # Execute ONNX model + output_matmul = oxe.execute_onnx(model, input_dict)["global_out"] + + # Create MVAU + model = model.transform(to_hw.InferQuantizedMatrixVectorActivation()) + model = model.transform(GiveUniqueNodeNames()) + + # Apply convert-to-rtl step + model = model.transform(SpecializeLayers(part)) + model = model.transform(GiveUniqueNodeNames()) + + assert model.graph.node[0].op_type == "MVAU_rtl" + # Apply folding with TH for tiled implementation + folding_config = { + "Defaults": {}, + "MVAU_rtl_0": { + "PE": pe, + "SIMD": simd, + "TH": th, + "resType": "dsp", + "mem_mode": "external_mem", + }, + } + model = model.transform(ApplyConfig(folding_config)) + model = model.transform(MinimizeWeightBitWidth()) + model = model.transform(MinimizeAccumulatorWidth()) + # make sure the changed datatypes are propagated through the network + model = model.transform(InferDataTypes()) + + # Run CPPsim + model = model.transform(SetExecMode("cppsim")) + model = model.transform(PrepareCppSim()) + model = model.transform(CompileCppSim()) + output_mvau_hls = oxe.execute_onnx(model, input_dict)["global_out"] + assert ( + output_matmul == output_mvau_hls + ).all(), "Output of ONNX model not matching output of node-by-node CPPsim!" + + # Run node-by-node RTLsim + model = model.transform(SetExecMode("rtlsim")) + model = model.transform(PrepareIP(part, clk_ns)) + model = model.transform(HLSSynthIP()) + model = model.transform(PrepareRTLSim()) + output_mvau_rtl = oxe.execute_onnx(model, input_dict)["global_out"] + assert ( + output_matmul == output_mvau_rtl + ).all(), "Output of ONNX model not matching output of node-by-node RTLsim!" + + # Run stitched-ip RTLsim + model = model.transform(InsertAndSetFIFODepths(part, clk_ns)) + model = model.transform(PrepareIP(part, clk_ns)) + model = model.transform(HLSSynthIP()) + model = model.transform(CreateStitchedIP(part, clk_ns)) + output_mvau_rtl_stitch = oxe.execute_onnx(model, input_dict)["global_out"] + assert ( + output_matmul == output_mvau_rtl_stitch + ).all(), "Output of ONNX model not matching output of tiled stitched-IP RTL model!" + + @pytest.mark.parametrize("mh", [32]) @pytest.mark.parametrize("mw", [16]) @pytest.mark.parametrize("n_vectors", [32]) @pytest.mark.parametrize("pe", [1, 16, 32]) @pytest.mark.parametrize("simd", [1, 8, 16]) -@pytest.mark.parametrize( - "idt_wdt", [[DataType["INT8"], DataType["INT8"]], [DataType["INT4"], DataType["INT4"]]] -) +@pytest.mark.parametrize("idt_wdt", [[DataType["INT4"], DataType["INT4"]]]) @pytest.mark.parametrize( "part", ["xcvc1902-vsva2197-2MP-e-S", "xcku3p-ffva676-1-e", "xc7z020clg400-1"] ) From 64d52f77c86ce425babab844c379cb810295f14d Mon Sep 17 00:00:00 2001 From: dkorolij Date: Tue, 14 Apr 2026 13:41:52 +0100 Subject: [PATCH 02/17] MLO GEMM fixes - MVAU tiled passing tests. - MMAU 1D/2D added. --- finn-rtllib/fetch_weights/fetch_weights.sv | 2 +- .../fetch_weights/fetch_weights_wrapper.v | 5 +- finn-rtllib/mlo/fetch_weights.sv | 226 --------- finn-rtllib/mlo/fetch_weights_wrapper.v | 164 ------- .../mlo/infrastructure/intermediate_frames.sv | 46 +- finn-rtllib/mlo/local_weight_buffer.sv | 305 ------------- finn-rtllib/mmu/1d/collect_out_1d.sv | 111 +++++ finn-rtllib/mmu/1d/cu_mmau_1d.sv | 403 ++++++++++++++++ finn-rtllib/mmu/1d/sched_weights_1d.sv | 141 ++++++ finn-rtllib/mmu/1d/sft_reg.sv | 25 + finn-rtllib/mmu/2d/collect_out_2d.sv | 120 +++++ finn-rtllib/mmu/2d/cu_mmau_2d.sv | 432 ++++++++++++++++++ finn-rtllib/mmu/2d/sched_weights_2d.sv | 165 +++++++ finn-rtllib/mmu/en_global.sv | 99 ++++ finn-rtllib/mmu/mmu_axi.sv | 252 ++++++++++ finn-rtllib/mmu/mmu_axi_wrapper.v | 100 ++++ finn-rtllib/mmu/q_writer.sv | 124 +++++ finn-rtllib/mmu/reorder_out.sv | 341 ++++++++++++++ finn-rtllib/mmu/replay_buff_mmau.sv | 397 ++++++++++++++++ finn-rtllib/mmu/sched_activations.sv | 199 ++++++++ src/finn/custom_op/fpgadataflow/hwcustomop.py | 27 +- .../fpgadataflow/matrixvectoractivation.py | 56 ++- .../custom_op/fpgadataflow/rtl/finn_loop.py | 50 ++ .../rtl/matrixvectoractivation_rtl.py | 97 +++- src/finn/custom_op/fpgadataflow/templates.py | 3 - .../fpgadataflow/loop_rolling.py | 12 +- src/finn/util/mlo_sim.py | 35 +- .../test_fpgadataflow_finnloop.py | 294 +++++++++++- tests/fpgadataflow/test_fpgadataflow_mvau.py | 104 ++++- 29 files changed, 3552 insertions(+), 783 deletions(-) delete mode 100644 finn-rtllib/mlo/fetch_weights.sv delete mode 100644 finn-rtllib/mlo/fetch_weights_wrapper.v delete mode 100644 finn-rtllib/mlo/local_weight_buffer.sv create mode 100644 finn-rtllib/mmu/1d/collect_out_1d.sv create mode 100644 finn-rtllib/mmu/1d/cu_mmau_1d.sv create mode 100644 finn-rtllib/mmu/1d/sched_weights_1d.sv create mode 100644 finn-rtllib/mmu/1d/sft_reg.sv create mode 100644 finn-rtllib/mmu/2d/collect_out_2d.sv create mode 100644 finn-rtllib/mmu/2d/cu_mmau_2d.sv create mode 100644 finn-rtllib/mmu/2d/sched_weights_2d.sv create mode 100644 finn-rtllib/mmu/en_global.sv create mode 100644 finn-rtllib/mmu/mmu_axi.sv create mode 100644 finn-rtllib/mmu/mmu_axi_wrapper.v create mode 100644 finn-rtllib/mmu/q_writer.sv create mode 100644 finn-rtllib/mmu/reorder_out.sv create mode 100644 finn-rtllib/mmu/replay_buff_mmau.sv create mode 100644 finn-rtllib/mmu/sched_activations.sv diff --git a/finn-rtllib/fetch_weights/fetch_weights.sv b/finn-rtllib/fetch_weights/fetch_weights.sv index e5b740c164..4573fe53bf 100644 --- a/finn-rtllib/fetch_weights/fetch_weights.sv +++ b/finn-rtllib/fetch_weights/fetch_weights.sv @@ -58,7 +58,7 @@ module fetch_weights #( int unsigned OWSIMD = (PE * SIMD) / TH, int unsigned DS_BITS_BA = (IWSIMD*WEIGHT_WIDTH+7)/8 * 8, int unsigned WS_BITS_BA = (OWSIMD*WEIGHT_WIDTH+7)/8 * 8, - logic[ADDR_BITS-1:0] LAYER_OFFS = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7 // 8-byte aligned + logic[ADDR_BITS-1:0] LAYER_OFFS = ((MH*MW*WEIGHT_WIDTH+7)/8 + (DATA_BITS/8-1)) & ~(DATA_BITS/8-1) // AXI bus-width aligned ) ( input wire aclk, input wire aresetn, diff --git a/finn-rtllib/fetch_weights/fetch_weights_wrapper.v b/finn-rtllib/fetch_weights/fetch_weights_wrapper.v index 06fa66031d..92827edc4e 100644 --- a/finn-rtllib/fetch_weights/fetch_weights_wrapper.v +++ b/finn-rtllib/fetch_weights/fetch_weights_wrapper.v @@ -49,8 +49,8 @@ module $MODULE_NAME_AXI_WRAPPER$ #( parameter IDX_BITS = 16, // Safely deducible parameters - parameter IWSIMD = (TH > 1) ? ((PE*SIMD)/TH) : SIMD, - parameter WSIMD = (PE * SIMD) / TH, + parameter IWSIMD = $IWSIMD$, + parameter WSIMD = $WSIMD$, parameter DS_BITS_BA = (IWSIMD*WEIGHT_WIDTH+7)/8 * 8, parameter WS_BITS_BA = (WSIMD*WEIGHT_WIDTH+7)/8 * 8 )( @@ -148,6 +148,7 @@ fetch_weights #( .PE(PE), .SIMD(SIMD), .TH(TH), .MH(MH), .MW(MW), .N_REPS(N_REPS), .WEIGHT_WIDTH(WEIGHT_WIDTH), + .IWSIMD(IWSIMD), .OWSIMD(WSIMD), .ADDR_BITS(ADDR_BITS), .DATA_BITS(DATA_BITS), .LEN_BITS(LEN_BITS), .IDX_BITS(IDX_BITS), .N_LAYERS(N_LAYERS), `ifdef EN_MLO diff --git a/finn-rtllib/mlo/fetch_weights.sv b/finn-rtllib/mlo/fetch_weights.sv deleted file mode 100644 index fda40e45d8..0000000000 --- a/finn-rtllib/mlo/fetch_weights.sv +++ /dev/null @@ -1,226 +0,0 @@ -/****************************************************************************** - * Copyright (C) 2024, Advanced Micro Devices, Inc. - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, - * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR - * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR - * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF - * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - *****************************************************************************/ - -module fetch_weights #( - int unsigned PE, - int unsigned SIMD, - int unsigned MH, - int unsigned MW, - int unsigned N_REPS, - int unsigned WEIGHT_WIDTH = 8, - - int unsigned ADDR_BITS = 64, - int unsigned DATA_BITS = 256, - int unsigned LEN_BITS = 32, - int unsigned IDX_BITS = 16, - - int unsigned N_LAYERS, - - int unsigned QDEPTH = 8, - int unsigned EN_OREG = 1, - int unsigned N_DCPL_STGS = 1, - int unsigned DBG = 0, - - // Safely deducible parameters - int unsigned DS_BITS_BA = (SIMD*WEIGHT_WIDTH+7)/8 * 8, - int unsigned WS_BITS_BA = (PE*SIMD*WEIGHT_WIDTH+7)/8 * 8, - logic[ADDR_BITS-1:0] LAYER_OFFS = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7 // 8-byte aligned -) ( - input logic aclk, - input logic aresetn, - - output logic m_done, - - // AXI - output logic[ADDR_BITS-1:0] m_axi_ddr_araddr, - output logic[1:0] m_axi_ddr_arburst, - output logic[3:0] m_axi_ddr_arcache, - output logic[1:0] m_axi_ddr_arid, - output logic[7:0] m_axi_ddr_arlen, - output logic[0:0] m_axi_ddr_arlock, - output logic[2:0] m_axi_ddr_arprot, - output logic[2:0] m_axi_ddr_arsize, - input logic m_axi_ddr_arready, - output logic m_axi_ddr_arvalid, - output logic[ADDR_BITS-1:0] m_axi_ddr_awaddr, - output logic[1:0] m_axi_ddr_awburst, - output logic[3:0] m_axi_ddr_awcache, - output logic[1:0] m_axi_ddr_awid, - output logic[7:0] m_axi_ddr_awlen, - output logic[0:0] m_axi_ddr_awlock, - output logic[2:0] m_axi_ddr_awprot, - output logic[2:0] m_axi_ddr_awsize, - input logic m_axi_ddr_awready, - output logic m_axi_ddr_awvalid, - input logic[DATA_BITS-1:0] m_axi_ddr_rdata, - input logic[1:0] m_axi_ddr_rid, - input logic m_axi_ddr_rlast, - input logic[1:0] m_axi_ddr_rresp, - output logic m_axi_ddr_rready, - input logic m_axi_ddr_rvalid, - output logic[DATA_BITS-1:0] m_axi_ddr_wdata, - output logic m_axi_ddr_wlast, - output logic[DATA_BITS/8-1:0] m_axi_ddr_wstrb, - input logic m_axi_ddr_wready, - output logic m_axi_ddr_wvalid, - input logic[1:0] m_axi_ddr_bid, - input logic[1:0] m_axi_ddr_bresp, - output logic m_axi_ddr_bready, - input logic m_axi_ddr_bvalid, - - // Index - input logic s_idx_tvalid, - output logic s_idx_tready, - input logic[IDX_BITS-1:0] s_idx_tdata, - - // Stream - // TODO: Should we reg this? Would be quite wide ... - output logic m_axis_tvalid, - input logic m_axis_tready, - output logic[WS_BITS_BA-1:0] m_axis_tdata -); - -localparam int unsigned WMAT_SIZE = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7; - -// Offsets -logic [N_LAYERS-1:0][ADDR_BITS-1:0] l_offsets; -for(genvar i = 0; i < N_LAYERS; i++) begin - assign l_offsets[i] = (i * LAYER_OFFS); -end - -logic q_idx_out_tvalid, q_idx_out_tready; -logic [IDX_BITS-1:0] q_idx_out_tdata; -logic [ADDR_BITS-1:0] q_dma_addr; -logic [LEN_BITS-1:0] q_dma_len; - -// Queues -Q_srl #( - .depth(QDEPTH), - .width(IDX_BITS) -) inst_queue_in ( - .clock(aclk), .reset(!aresetn), - .count(), .maxcount(), - .i_d(s_idx_tdata), .i_v(s_idx_tvalid), .i_r(s_idx_tready), - .o_d(q_idx_out_tdata), .o_v(q_idx_out_tvalid), .o_r(q_idx_out_tready) -); - -assign q_dma_addr = l_offsets[q_idx_out_tdata]; -assign q_dma_len = WMAT_SIZE; - -// DMA -logic axis_dma_tvalid; -logic axis_dma_tready; -logic[DATA_BITS-1:0] axis_dma_tdata; -logic[DATA_BITS/8-1:0] axis_dma_tkeep; -logic axis_dma_tlast; - -cdma_u_rd #( - .DATA_BITS(DATA_BITS), - .ADDR_BITS(ADDR_BITS), - .LEN_BITS(LEN_BITS) -) inst_dma ( - .aclk(aclk), .aresetn(aresetn), - - .rd_valid(q_idx_out_tvalid), .rd_ready(q_idx_out_tready), - .rd_paddr(q_dma_addr), .rd_len(q_dma_len), - .rd_done(m_done), - - .m_axi_ddr_arvalid(m_axi_ddr_arvalid), - .m_axi_ddr_arready(m_axi_ddr_arready), - .m_axi_ddr_araddr(m_axi_ddr_araddr), - .m_axi_ddr_arid(m_axi_ddr_arid), - .m_axi_ddr_arlen(m_axi_ddr_arlen), - .m_axi_ddr_arsize(m_axi_ddr_arsize), - .m_axi_ddr_arburst(m_axi_ddr_arburst), - .m_axi_ddr_arlock(m_axi_ddr_arlock), - .m_axi_ddr_arcache(m_axi_ddr_arcache), - .m_axi_ddr_arprot(m_axi_ddr_arprot), - .m_axi_ddr_rvalid(m_axi_ddr_rvalid), - .m_axi_ddr_rready(m_axi_ddr_rready), - .m_axi_ddr_rdata(m_axi_ddr_rdata), - .m_axi_ddr_rlast(m_axi_ddr_rlast), - .m_axi_ddr_rid(m_axi_ddr_rid), - .m_axi_ddr_rresp(m_axi_ddr_rresp), - - .m_axis_ddr_tvalid(axis_dma_tvalid), - .m_axis_ddr_tready(axis_dma_tready), - .m_axis_ddr_tdata(axis_dma_tdata), - .m_axis_ddr_tkeep(axis_dma_tkeep), - .m_axis_ddr_tlast(axis_dma_tlast) -); - -// Width conversion -logic axis_dwc_tvalid; -logic axis_dwc_tready; -logic[DS_BITS_BA-1:0] axis_dwc_tdata; -logic[(DS_BITS_BA)/8-1:0] axis_dwc_tkeep; -logic axis_dwc_tlast; - -axis_fifo_adapter #( - .S_DATA_WIDTH(DATA_BITS), .M_DATA_WIDTH(DS_BITS_BA) -) inst_dwc ( - .clk(aclk), .rst(~aresetn), - .pause_req('0), .s_axis_tid('0), .s_axis_tdest('0), .s_axis_tuser('0), - .s_axis_tvalid(axis_dma_tvalid), .s_axis_tready(axis_dma_tready), .s_axis_tdata(axis_dma_tdata), .s_axis_tkeep(axis_dma_tkeep), .s_axis_tlast(axis_dma_tlast), - .pause_ack(), .m_axis_tid(), .m_axis_tdest(), .m_axis_tuser(), - .m_axis_tvalid(axis_dwc_tvalid), .m_axis_tready(axis_dwc_tready), .m_axis_tdata(axis_dwc_tdata), .m_axis_tkeep(axis_dwc_tkeep), .m_axis_tlast(axis_dwc_tlast) -); - -// Double buffer -logic axis_lwb_tvalid; -logic axis_lwb_tready; -logic[WS_BITS_BA-1:0] axis_lwb_tdata; - -local_weight_buffer #( - .PE(PE), .SIMD(SIMD), .MH(MH), .MW(MW), .N_REPS(N_REPS), .WEIGHT_WIDTH(WEIGHT_WIDTH), .DBG(DBG) -) inst_weight_buff ( - .clk(aclk), .rst(~aresetn), - .ivld(axis_dwc_tvalid), .irdy(axis_dwc_tready), .idat(axis_dwc_tdata), - .ovld(axis_lwb_tvalid), .ordy(axis_lwb_tready), .odat(axis_lwb_tdata) -); - -// Reg slice -if(EN_OREG) begin - skid #( - .DATA_WIDTH(WS_BITS_BA), .FEED_STAGES(N_DCPL_STGS) - ) inst_oreg ( - .clk(aclk), .rst(~aresetn), - .ivld(axis_lwb_tvalid), .irdy(axis_lwb_tready), .idat(axis_lwb_tdata), - .ovld(m_axis_tvalid), .ordy(m_axis_tready), .odat(m_axis_tdata) - ); -end else begin - assign m_axis_tvalid = axis_lwb_tvalid; - assign axis_lwb_tready = m_axis_tready; - assign m_axis_tdata = axis_lwb_tdata; -end - -endmodule diff --git a/finn-rtllib/mlo/fetch_weights_wrapper.v b/finn-rtllib/mlo/fetch_weights_wrapper.v deleted file mode 100644 index dc92478b6c..0000000000 --- a/finn-rtllib/mlo/fetch_weights_wrapper.v +++ /dev/null @@ -1,164 +0,0 @@ -/****************************************************************************** - * Copyright (C) 2024, Advanced Micro Devices, Inc. - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, - * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR - * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR - * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF - * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - *****************************************************************************/ - -module $MODULE_NAME_AXI_WRAPPER$ #( - parameter MW = $MW$, - parameter MH = $MH$, - parameter PE = $PE$, - parameter SIMD = $SIMD$, - parameter N_REPS = $N_REPS$, - parameter WEIGHT_WIDTH = $WEIGHT_WIDTH$, - parameter N_LAYERS = $N_LAYERS$, - - parameter ADDR_BITS = 64, - parameter DATA_BITS = 256, - parameter LEN_BITS = 32, - parameter IDX_BITS = 16, - - // Safely deducible parameters - parameter WS_BITS_BA = (PE*SIMD*WEIGHT_WIDTH+7)/8 * 8 -)( - // Global Control - (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF axi_mm:in_idx0_V:out0_V, ASSOCIATED_RESET ap_rst_n" *) - (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *) - input ap_clk, - (* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *) - input ap_rst_n, - - // Completion - output wire out_done, - - // AXI - (* X_INTERFACE_INFO = "xilinx.com:interface:aximm:1.0 axi_mm" *) - output wire[ADDR_BITS-1:0] axi_mm_araddr, - output wire[1:0] axi_mm_arburst, - output wire[3:0] axi_mm_arcache, - output wire[1:0] axi_mm_arid, - output wire[7:0] axi_mm_arlen, - output wire[0:0] axi_mm_arlock, - output wire[2:0] axi_mm_arprot, - output wire[2:0] axi_mm_arsize, - input wire axi_mm_arready, - output wire axi_mm_arvalid, - output wire[ADDR_BITS-1:0] axi_mm_awaddr, - output wire[1:0] axi_mm_awburst, - output wire[3:0] axi_mm_awcache, - output wire[1:0] axi_mm_awid, - output wire[7:0] axi_mm_awlen, - output wire[0:0] axi_mm_awlock, - output wire[2:0] axi_mm_awprot, - output wire[2:0] axi_mm_awsize, - input wire axi_mm_awready, - output wire axi_mm_awvalid, - input wire[DATA_BITS-1:0] axi_mm_rdata, - input wire[1:0] axi_mm_rid, - input wire axi_mm_rlast, - input wire[1:0] axi_mm_rresp, - output wire axi_mm_rready, - input wire axi_mm_rvalid, - output wire[DATA_BITS-1:0] axi_mm_wdata, - output wire axi_mm_wlast, - output wire[DATA_BITS/8-1:0] axi_mm_wstrb, - input wire axi_mm_wready, - output wire axi_mm_wvalid, - input wire[1:0] axi_mm_bid, - input wire[1:0] axi_mm_bresp, - output wire axi_mm_bready, - input wire axi_mm_bvalid, - - // Index - input wire in_idx0_V_tvalid, - output wire in_idx0_V_tready, - input wire[IDX_BITS-1:0] in_idx0_V_tdata, - - // Stream - output wire out0_V_tvalid, - input wire out0_V_tready, - output wire[WS_BITS_BA-1:0] out0_V_tdata -); - - -fetch_weights #( - .PE(PE), .SIMD(SIMD), .MH(MH), .MW(MW), .N_REPS(N_REPS), - .WEIGHT_WIDTH(WEIGHT_WIDTH), - .ADDR_BITS(ADDR_BITS), .DATA_BITS(DATA_BITS), .LEN_BITS(LEN_BITS), .IDX_BITS(IDX_BITS), - .N_LAYERS(N_LAYERS) -) inst ( - .aclk (ap_clk), - .aresetn (ap_rst_n), - - .m_axi_ddr_araddr (axi_mm_araddr), - .m_axi_ddr_arburst (axi_mm_arburst), - .m_axi_ddr_arcache (axi_mm_arcache), - .m_axi_ddr_arid (axi_mm_arid), - .m_axi_ddr_arlen (axi_mm_arlen), - .m_axi_ddr_arlock (axi_mm_arlock), - .m_axi_ddr_arprot (axi_mm_arprot), - .m_axi_ddr_arsize (axi_mm_arsize), - .m_axi_ddr_arready (axi_mm_arready), - .m_axi_ddr_arvalid (axi_mm_arvalid), - .m_axi_ddr_awaddr (axi_mm_awaddr), - .m_axi_ddr_awburst (axi_mm_awburst), - .m_axi_ddr_awcache (axi_mm_awcache), - .m_axi_ddr_awid (axi_mm_awid), - .m_axi_ddr_awlen (axi_mm_awlen), - .m_axi_ddr_awlock (axi_mm_awlock), - .m_axi_ddr_awprot (axi_mm_awprot), - .m_axi_ddr_awsize (axi_mm_awsize), - .m_axi_ddr_awready (axi_mm_awready), - .m_axi_ddr_awvalid (axi_mm_awvalid), - .m_axi_ddr_rdata (axi_mm_rdata), - .m_axi_ddr_rid (axi_mm_rid), - .m_axi_ddr_rlast (axi_mm_rlast), - .m_axi_ddr_rresp (axi_mm_rresp), - .m_axi_ddr_rready (axi_mm_rready), - .m_axi_ddr_rvalid (axi_mm_rvalid), - .m_axi_ddr_wdata (axi_mm_wdata), - .m_axi_ddr_wlast (axi_mm_wlast), - .m_axi_ddr_wstrb (axi_mm_wstrb), - .m_axi_ddr_wready (axi_mm_wready), - .m_axi_ddr_wvalid (axi_mm_wvalid), - .m_axi_ddr_bid (axi_mm_bid), - .m_axi_ddr_bresp (axi_mm_bresp), - .m_axi_ddr_bready (axi_mm_bready), - .m_axi_ddr_bvalid (axi_mm_bvalid), - - .s_idx_tvalid (in_idx0_V_tvalid), - .s_idx_tready (in_idx0_V_tready), - .s_idx_tdata (in_idx0_V_tdata), - - .m_axis_tvalid (out0_V_tvalid), - .m_axis_tready (out0_V_tready), - .m_axis_tdata (out0_V_tdata) -); - -endmodule // $MODULE_NAME_AXI_WRAPPER$ diff --git a/finn-rtllib/mlo/infrastructure/intermediate_frames.sv b/finn-rtllib/mlo/infrastructure/intermediate_frames.sv index c36cb5cd30..c924007955 100644 --- a/finn-rtllib/mlo/infrastructure/intermediate_frames.sv +++ b/finn-rtllib/mlo/infrastructure/intermediate_frames.sv @@ -399,7 +399,7 @@ logic s_axis_int_tvalid, s_axis_int_tready; logic [OLEN_BITS-1:0] s_axis_int_tdata; logic m_axis_int_tvalid, m_axis_int_tready; -logic [OLEN_BITS-1:0] m_axis_int_tdata; +logic [ILEN_BITS-1:0] m_axis_int_tdata; logic [FM_BEATS_IN_BITS-1:0] cnt_dwc_C = '0; always_ff @(posedge aclk) begin @@ -410,42 +410,18 @@ end logic last_dwc_in; assign last_dwc_in = (cnt_dwc_C == FM_BEATS_IN-1); -axis_fifo_adapter #(.S_DATA_WIDTH(OLEN_BITS), .M_DATA_WIDTH(DATA_BITS)) inst_dwc_wr ( - .clk(aclk), - .rst(~aresetn), - - .pause_req('0), .s_axis_tid('0), .s_axis_tdest('0), .s_axis_tuser('0), - .s_axis_tvalid(s_axis_int_tvalid), - .s_axis_tready(s_axis_int_tready), - .s_axis_tdata (s_axis_int_tdata), - .s_axis_tkeep ('1), - .s_axis_tlast (last_dwc_in), - - .pause_ack(), .m_axis_tid(), .m_axis_tdest(), .m_axis_tuser(), - .m_axis_tvalid(axis_dma_wr_tvalid), - .m_axis_tready(axis_dma_wr_tready), - .m_axis_tdata (axis_dma_wr_tdata), - .m_axis_tkeep (axis_dma_wr_tkeep), - .m_axis_tlast (axis_dma_wr_tlast) +// DWC write: OLEN_BITS -> DATA_BITS (body output -> DMA) +if_dwc_sink inst_dwc_wr ( + .aclk(aclk), .aresetn(aresetn), + .s_axis_tvalid(s_axis_int_tvalid), .s_axis_tready(s_axis_int_tready), .s_axis_tdata(s_axis_int_tdata), .s_axis_tkeep({(OLEN_BITS/8){1'b1}}), .s_axis_tlast(last_dwc_in), + .m_axis_tvalid(axis_dma_wr_tvalid), .m_axis_tready(axis_dma_wr_tready), .m_axis_tdata(axis_dma_wr_tdata), .m_axis_tkeep(axis_dma_wr_tkeep), .m_axis_tlast(axis_dma_wr_tlast) ); -axis_fifo_adapter #(.S_DATA_WIDTH(DATA_BITS), .M_DATA_WIDTH(ILEN_BITS)) inst_dwc_rd ( - .clk(aclk), - .rst(~aresetn), - - .pause_req('0), .s_axis_tid('0), .s_axis_tdest('0), .s_axis_tuser('0), - .s_axis_tvalid(axis_dma_rd_tvalid), - .s_axis_tready(axis_dma_rd_tready), - .s_axis_tdata (axis_dma_rd_tdata), - .s_axis_tkeep (axis_dma_rd_tkeep), - .s_axis_tlast (axis_dma_rd_tlast), - - .pause_ack(), .m_axis_tid(), .m_axis_tdest(), .m_axis_tuser(), - .m_axis_tvalid(m_axis_int_tvalid), - .m_axis_tready(m_axis_int_tready), - .m_axis_tdata (m_axis_int_tdata), - .m_axis_tkeep (), - .m_axis_tlast () +// DWC read: DATA_BITS -> ILEN_BITS (DMA -> body input) +if_dwc_source inst_dwc_rd ( + .aclk(aclk), .aresetn(aresetn), + .s_axis_tvalid(axis_dma_rd_tvalid), .s_axis_tready(axis_dma_rd_tready), .s_axis_tdata(axis_dma_rd_tdata), .s_axis_tkeep(axis_dma_rd_tkeep), .s_axis_tlast(axis_dma_rd_tlast), + .m_axis_tvalid(m_axis_int_tvalid), .m_axis_tready(m_axis_int_tready), .m_axis_tdata(m_axis_int_tdata), .m_axis_tkeep(), .m_axis_tlast() ); // REG diff --git a/finn-rtllib/mlo/local_weight_buffer.sv b/finn-rtllib/mlo/local_weight_buffer.sv deleted file mode 100644 index cdc2a9eca3..0000000000 --- a/finn-rtllib/mlo/local_weight_buffer.sv +++ /dev/null @@ -1,305 +0,0 @@ -/****************************************************************************** - * Copyright (C) 2024, Advanced Micro Devices, Inc. - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, - * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR - * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR - * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF - * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - *****************************************************************************/ - -module local_weight_buffer #( - int unsigned PE, - int unsigned SIMD, - int unsigned WEIGHT_WIDTH, - int unsigned MH, - int unsigned MW, - int unsigned N_REPS, - int unsigned DBG = 0 -) ( - input logic clk, - input logic rst, - - input logic ivld, - output logic irdy, - input logic [SIMD-1:0][WEIGHT_WIDTH-1:0] idat, - - output logic ovld, - input logic ordy, - output logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat -); - -// ---------------------------------------------------------------------------- -// Consts and types -// ---------------------------------------------------------------------------- - -localparam int unsigned SF = MW/SIMD; -localparam int unsigned NF = MH/PE; -localparam int unsigned N_TLS = SF * NF; - -localparam int unsigned SIMD_BITS = (SIMD == 1) ? 1 : $clog2(SIMD); -localparam int unsigned PE_BITS = (PE == 1) ? 1 : $clog2(PE); -localparam int unsigned WGT_ADDR_BITS = $clog2(NF * SF); -localparam int unsigned RAM_BITS = (SIMD*WEIGHT_WIDTH + 7)/8 * 8; -localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; -localparam int unsigned N_TLS_BITS = $clog2(N_TLS); -localparam int unsigned N_REPS_BITS = $clog2(N_REPS); - -typedef enum logic[1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_t; -typedef enum logic {ST_RD_0, ST_RD_1} state_rd_t; - -// ---------------------------------------------------------------------------- -// Writer -// ---------------------------------------------------------------------------- - -// -- Regs -state_wr_t state_wr_C = ST_WR_0, state_wr_N; -state_rd_t state_rd_C = ST_RD_0, state_rd_N; - -logic[N_TLS_BITS-1:0] wr_pntr_C = '0, wr_pntr_N; -logic[PE_BITS-1:0] curr_pe_C = '0, curr_pe_N; - -// -- Signals -logic [1:0][PE-1:0][WGT_EN_BITS-1:0] a_we; // Bank enables -logic [1:0][WGT_ADDR_BITS-1:0] a_addr; -logic [1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] a_data_in; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_WR - if(rst) begin - state_wr_C <= ST_WR_0; - - wr_pntr_C <= '0; - curr_pe_C <= '0; - end - else begin - state_wr_C <= state_wr_N; - - wr_pntr_C <= wr_pntr_N; - curr_pe_C <= curr_pe_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_WR - state_wr_N = state_wr_C; - - case (state_wr_C) - ST_WR_0: - if((curr_pe_C == PE - 1) && (wr_pntr_C == N_TLS - 1) && ivld) begin - state_wr_N = (state_rd_C == ST_RD_0) ? ST_WR_1 : ST_WR_0_WAIT; - end - - ST_WR_0_WAIT: - state_wr_N = (state_rd_C == ST_RD_0) ? ST_WR_1 : ST_WR_0_WAIT; - - ST_WR_1: - if((curr_pe_C == PE - 1) && (wr_pntr_C == N_TLS - 1) && ivld) begin - state_wr_N = (state_rd_C == ST_RD_1) ? ST_WR_0 : ST_WR_1_WAIT; - end - - ST_WR_1_WAIT: - state_wr_N = (state_rd_C == ST_RD_1) ? ST_WR_0 : ST_WR_1_WAIT; - - endcase -end - -// -- DP -always_comb begin : DP_PROC_WR - wr_pntr_N = wr_pntr_C; - curr_pe_N = curr_pe_C; - - // Input - irdy = 1'b0; - - // Buffers - a_we = '0; - for(int i = 0; i < 2; i++) begin - a_addr[i] = wr_pntr_C; - a_data_in[i] = idat; - end - - // Write and count - case (state_wr_C) - ST_WR_0, ST_WR_1: begin - irdy = 1'b1; - - if(ivld) begin - for(int i = 0; i < PE; i++) begin - if(curr_pe_C == i) begin - a_we[state_wr_C == ST_WR_1][i] = '1; - end - end - - curr_pe_N = (curr_pe_C == PE-1) ? 0 : curr_pe_C + 1; - wr_pntr_N = (curr_pe_C == PE-1) ? ((wr_pntr_C == N_TLS-1) ? 0 : wr_pntr_C + 1) : wr_pntr_C; - end - end - endcase - -end - -// ---------------------------------------------------------------------------- -// Reader -// ---------------------------------------------------------------------------- - -// -- Regs -logic [N_TLS_BITS-1:0] rd_pntr_C = '0, rd_pntr_N; -logic [N_REPS_BITS-1:0] reps_C = '0, reps_N; - -//logic [15:0] rd_pntr_C = '0, rd_pntr_N; -//logic [15:0] reps_C = '0, reps_N; - -logic [1:0] vld_s0_C = '0, vld_s0_N; -logic [1:0] vld_s1_C = '0, vld_s1_N; - -logic vld_C = '0, vld_N; -logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_C = '0, odat_N; - -// -- Signals -logic [1:0][WGT_ADDR_BITS-1:0] b_addr; -logic [1:0][PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_ram; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_RD - if(rst) begin - state_rd_C <= ST_RD_0; - - rd_pntr_C <= '0; - reps_C <= '0; - - vld_s0_C <= '0; - vld_s1_C <= '0; - vld_C <= '0; - odat_C <= 'X; - end - else begin - state_rd_C <= state_rd_N; - - rd_pntr_C <= rd_pntr_N; - reps_C <= reps_N; - - vld_s0_C <= vld_s0_N; - vld_s1_C <= vld_s1_N; - vld_C <= vld_N; - odat_C <= odat_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_RD - state_rd_N = state_rd_C; - - case (state_rd_C) - ST_RD_0: - if(ordy && ((state_wr_C == ST_WR_0) ? (wr_pntr_C > rd_pntr_C) : 1'b1)) begin - if((rd_pntr_C == N_TLS-1) && (reps_C == N_REPS-1)) begin - state_rd_N = ST_RD_1; - end - end - - ST_RD_1: - if(ordy && ((state_wr_C == ST_WR_1) ? (wr_pntr_C > rd_pntr_C) : 1'b1)) begin - if((rd_pntr_C == N_TLS-1) && (reps_C == N_REPS-1)) begin - state_rd_N = ST_RD_0; - end - end - endcase -end - -// -- DP -always_comb begin : DP_PROC_RD - rd_pntr_N = rd_pntr_C; - reps_N = reps_C; - - for(int i = 0; i < 2; i++) begin - vld_s0_N[i] = ordy ? 1'b0 : vld_s0_C[i]; - vld_s1_N[i] = ordy ? vld_s0_C[i] : vld_s1_C[i]; - end - - vld_N = ordy ? |vld_s1_C : vld_C; - odat_N = ordy ? (vld_s1_C[0] ? odat_ram[0] : odat_ram[1]) : odat_C; - - for(int i = 0; i < 2; i++) begin - b_addr[i] = rd_pntr_C; - end - - case(state_rd_C) - ST_RD_0: begin - if(ordy) begin - if((state_wr_C == ST_WR_0) ? (wr_pntr_C > rd_pntr_C) : 1'b1) begin - - vld_s0_N[0] = 1'b1; - - rd_pntr_N = (rd_pntr_C == N_TLS-1) ? 0 : rd_pntr_C + 1; - reps_N = (rd_pntr_C == N_TLS-1) ? ((reps_C == N_REPS-1) ? 0 : reps_C + 1) : reps_C; - end - end - end - - ST_RD_1: begin - if(ordy) begin - if((state_wr_C == ST_WR_1) ? (wr_pntr_C > rd_pntr_C) : 1'b1) begin - - vld_s0_N[1] = 1'b1; - - rd_pntr_N = (rd_pntr_C == N_TLS-1) ? 0 : rd_pntr_C + 1; - reps_N = (rd_pntr_C == N_TLS-1) ? ((reps_C == N_REPS-1) ? 0 : reps_C + 1) : reps_C; - end - end - end - - endcase - -end - -assign ovld = vld_C; -assign odat = odat_C; - -// ---------------------------------------------------------------------------- -// Weights -// ---------------------------------------------------------------------------- - -for(genvar i = 0; i < 2; i++) begin - for(genvar j = 0; j < PE; j++) begin - ram_p_c #( - .ADDR_BITS(WGT_ADDR_BITS), - .DATA_BITS(RAM_BITS), - .RAM_STYLE("block") - ) inst_ram_tp_c ( - .clk(clk), - .a_en(1'b1), - .a_we(a_we[i][j]), - .a_addr(a_addr[i]), - .b_en(ordy), - .b_addr(b_addr[i]), - .a_data_in(a_data_in[i]), - .a_data_out(), - .b_data_out(odat_ram[i][j]) - ); - end -end - -endmodule diff --git a/finn-rtllib/mmu/1d/collect_out_1d.sv b/finn-rtllib/mmu/1d/collect_out_1d.sv new file mode 100644 index 0000000000..1451d9a837 --- /dev/null +++ b/finn-rtllib/mmu/1d/collect_out_1d.sv @@ -0,0 +1,111 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module collect_out_1d #( + int unsigned PE, + int unsigned ACCU_WIDTH, + + int unsigned QDEPTH = 2 * PE, + int unsigned QCNT_BITS = $clog2(QDEPTH), + int unsigned Q_MAX = PE, + + int unsigned N_DCPL_STAGES = 2 +)( + // Global Control + input logic clk, + input logic rst, + + output logic en, + + // Input Stream + input logic [PE-1:0][ACCU_WIDTH-1:0] p_tdata, + input logic p_tvalid, + + // Output Stream + output logic [PE-1:0][ACCU_WIDTH-1:0] m_axis_tdata, + output logic m_axis_tvalid, + input logic m_axis_tready +); + +// Queueing +// --------------------------------------------------------------------- + logic q_in_tready; + logic q_out_tready, q_out_tvalid; + logic [PE-1:0][ACCU_WIDTH-1:0] q_out_tdata; + logic [QCNT_BITS-1:0] q_count; + logic en_int; + + for(genvar i = 0; i < PE; i++) begin + if(i == 0) begin + Q_srl #( + .depth(QDEPTH), + .width(ACCU_WIDTH) + ) inst_queue ( + .clock(clk), .reset(rst), + .count(q_count), .maxcount(), + .i_v(p_tvalid), .i_r(q_in_tready), .i_d(p_tdata[i]), + .o_v(q_out_tvalid), .o_r(q_out_tready), .o_d(q_out_tdata[i]) + ); + end else begin + Q_srl #( + .depth(QDEPTH), + .width(ACCU_WIDTH) + ) inst_queue ( + .clock(clk), .reset(rst), + .count(), .maxcount(), + .i_v(p_tvalid), .i_r(), .i_d(p_tdata[i]), + .o_v(), .o_r(q_out_tready), .o_d(q_out_tdata[i]) + ); + end + end + + // Global enable + assign en_int = !(q_count > Q_MAX); + + always_ff @( posedge clk ) begin + if(rst) begin + en <= 1'b0; + end + else begin + en <= en_int; + end + end + +// Output +// --------------------------------------------------------------------- + skid #(.DATA_WIDTH(PE*ACCU_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg ( + .clk(clk), .rst(rst), + .idat(q_out_tdata), .ivld(q_out_tvalid), .irdy(q_out_tready), + .odat(m_axis_tdata), .ovld(m_axis_tvalid), .ordy(m_axis_tready) + ); + +endmodule \ No newline at end of file diff --git a/finn-rtllib/mmu/1d/cu_mmau_1d.sv b/finn-rtllib/mmu/1d/cu_mmau_1d.sv new file mode 100644 index 0000000000..601c4bb2db --- /dev/null +++ b/finn-rtllib/mmu/1d/cu_mmau_1d.sv @@ -0,0 +1,403 @@ +/****************************************************************************** + * Copyright (C) 2025, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * @brief Compute unit (DSP grid) - MMAU + * @author Dario Korolija + *****************************************************************************/ + +module cu_mmau_1d #( + int unsigned PE, + int unsigned CLEN, + int unsigned CU_SIMD, + + int unsigned ACTIVATION_WIDTH, + int unsigned WEIGHT_WIDTH, + int unsigned ACCU_WIDTH, + + bit SIGNED_ACTIVATIONS = 1, + int unsigned FORCE_BEHAVIOURAL = 0 + ) ( + // Global Control + input logic clk, + input logic rst, + + // Enable + output logic en, + + // Input + input logic ivld, + input logic [CLEN-1:0] ilast, + input logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] a, + input logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] w, + + // Ouput + output logic m_axis_tvalid, + input logic m_axis_tready, + output logic [PE-1:0][ACCU_WIDTH-1:0] m_axis_tdata + ); + +// Startup Recovery Watchdog +// The DSP slice needs 100ns of recovery time after initial startup before +// being able to ingest input properly. This watchdog discovers violating +// stimuli during simulation and produces a corresponding warning. +//------------------------------------------------------------------------------------ + if(1) begin : blkRecoveryWatch + logic Dirty = 1; + initial begin + #100ns; + Dirty <= 0; + end + + always_ff @(posedge clk) begin + assert(!Dirty || rst) else begin + $warning("%m: Feeding input during DSP startup recovery. Expect functional errors."); + end + end + end : blkRecoveryWatch + +// Shifts - activations and weights +//------------------------------------------------------------------------------------ + localparam int unsigned PAD_BITS_ACT = 9 - ACTIVATION_WIDTH; + localparam int unsigned PAD_BITS_WEIGHT = 8 - WEIGHT_WIDTH; + + logic [CLEN:0][PE-1:0][CU_SIMD*WEIGHT_WIDTH-1:0] Wc; + logic [CLEN-1:0][PE-1:0][23:0] Wc_int; + + for(genvar i = 0; i < PE; i++) begin + assign Wc[0][i] = w[i]; + + for (genvar k = 0; k < CU_SIMD; k++) begin + assign Wc_int[0][i][8*k +: 8] = + PAD_BITS_WEIGHT == 0 ? Wc[0][i][WEIGHT_WIDTH*k+:WEIGHT_WIDTH] : { {PAD_BITS_WEIGHT{Wc[0][i][k*WEIGHT_WIDTH+WEIGHT_WIDTH-1]}}, Wc[0][i][k*WEIGHT_WIDTH+:WEIGHT_WIDTH] }; + end + end + + /* + always_ff @(posedge clk) begin + if(rst) begin + for(int i = 1; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + Wc[i][j] <= 'X; + end + end + end + for(int i = 1; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + if(ivld) begin + Wc[i][j] <= Wc[i-1][j]; + end + end + end + end + */ + +// Shifts - per DSP +//------------------------------------------------------------------------------------ + localparam int unsigned DSP_PIPELINE_STAGES = 3; + logic [CLEN-1:0][DSP_PIPELINE_STAGES:0] Lc; + + for(genvar i = 0; i < CLEN; i++) begin + assign Lc[i][0] = ilast[i]; + end + + always_ff @(posedge clk) begin + if(rst) begin + for(int i = 0; i < CLEN; i++) begin + for(int k = 1; k <= DSP_PIPELINE_STAGES; k++) begin + Lc[i][k] <= 'X; + end + end + end + else begin + for(int i = 0; i < CLEN; i++) begin + for(int k = 1; k <= DSP_PIPELINE_STAGES; k++) begin + if(ivld) begin + Lc[i][k] <= Lc[i][k-1]; + end + end + end + end + end + +// Instantiate PE x CLEN DSPs +//------------------------------------------------------------------------------------ + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] pout; + + /* if(FORCE_BEHAVIOURAL == 1) begin + logic [CLEN-1:0][CU_SIMD*ACTIVATION_WIDTH-1:0] Ac_int; + logic [CLEN-1:0][PE-1:0][CU_SIMD*WEIGHT_WIDTH-1:0] Wc_int; + logic [CLEN-1:0][PE-1:0][CU_SIMD-1:0][ACCU_WIDTH-1:0] Mc_int_part; + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] Mc_int_sum; + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] Mc_int; + + + for (genvar i = 0; i < CLEN; i++) begin + always_ff @(posedge clk) begin + if(rst) begin + Ac_int[i] <= 'X; + end else begin + if(ivld) begin + Ac_int[i] <= a[i]; + end + end + end + + for (genvar j = 0; j < PE; j++) begin + always_comb begin + Mc_int_sum[i][j] = 0; + + for(int k = 0; k < CU_SIMD; k++) begin + Mc_int_part[i][j][k] = $signed(Ac_int[i][k*ACTIVATION_WIDTH+:ACTIVATION_WIDTH]) * $signed(Wc_int[i][j][k*WEIGHT_WIDTH+:WEIGHT_WIDTH]); + Mc_int_sum[i][j] = $signed(Mc_int_sum[i][j]) + $signed(Mc_int_part[i][j][k]); + end + end + + always_ff @(posedge clk) begin + if(rst) begin + Wc_int[i][j] <= '0; + Mc_int[i][j] <= '0; + pout[i][j] <= '0; + end else begin + if(ivld) begin + Wc_int[i][j] <= Wc[i][j]; + Mc_int[i][j] <= $signed(Mc_int_sum[i][j]); + pout[i][j] <= Lc[i][DSP_PIPELINE_STAGES] ? $signed(Mc_int[i][j]) : $signed(Mc_int[i][j]) + $signed(pout[i][j]); + end + end + end + end + end + end else begin */ + localparam int INTERNAL_REGS = 1; // 1 : 0 + localparam bit PREG = 1; + localparam int CC_LEN = CLEN / 4; + + logic [CLEN-1:0][26:0] Ac_int; + //logic [CLEN-1:0][PE-1:0][23:0] Wc_int; + logic [CLEN-1:0][PE-1:0][23:0] tmp_cc; + + for (genvar i = 0; i < CLEN; i++) begin + for (genvar k = 0; k < CU_SIMD; k++) begin + assign Ac_int[i][9*k +: 9] = + SIGNED_ACTIVATIONS ? PAD_BITS_ACT == 0 ? a[i][k] : { {PAD_BITS_ACT{a[i][k][ACTIVATION_WIDTH-1]}}, a[i][k] } + : PAD_BITS_ACT == 0 ? a[i][k] : { {PAD_BITS_ACT{1'b0}}, a[i][k] } ; + end + + + for (genvar j = 0; j < PE; j++) begin + /* for (genvar k = 0; k < CU_SIMD; k++) begin + assign Wc_int[i][j][8*k +: 8] = + PAD_BITS_WEIGHT == 0 ? Wc[i][j][WEIGHT_WIDTH*k+:WEIGHT_WIDTH] : { {PAD_BITS_WEIGHT{Wc[i][j][k*WEIGHT_WIDTH+WEIGHT_WIDTH-1]}}, Wc[i][j][k*WEIGHT_WIDTH+:WEIGHT_WIDTH] }; + end */ + + + DSP58 #( + // Feature Control Attributes: Data Path Selection + .AMULTSEL("A"), // Selects A input to multiplier (A, AD) + .A_INPUT("DIRECT"), // Selects A input source, "DIRECT" (A port) or "CASCADE" (ACIN port) + .BMULTSEL("B"), // Selects B input to multiplier (AD, B) + .B_INPUT((i % CC_LEN == 0) ? "DIRECT" : "CASCADE"), // Selects B input source, "DIRECT" (B port) or "CASCADE" (BCIN port) + .DSP_MODE("INT8"), // Configures DSP to a particular mode of operation. Set to INT24 for + // legacy mode. + .PREADDINSEL("A"), // Selects input to pre-adder (A, B) + .RND(58'h000000000000000), // Rounding Constant + .USE_MULT("MULTIPLY"), // Select multiplier usage (DYNAMIC, MULTIPLY, NONE) + .USE_SIMD("ONE58"), // SIMD selection (FOUR12, ONE58, TWO24) + .USE_WIDEXOR("FALSE"), // Use the Wide XOR function (FALSE, TRUE) + .XORSIMD("XOR24_34_58_116"), // Mode of operation for the Wide XOR (XOR12_22, XOR24_34_58_116) + // Pattern Detector Attributes: Pattern Detection Configuration + .AUTORESET_PATDET("NO_RESET"), // NO_RESET, RESET_MATCH, RESET_NOT_MATCH + .AUTORESET_PRIORITY("RESET"), // Priority of AUTORESET vs. CEP (CEP, RESET). + .MASK(58'h0ffffffffffffff), // 58-bit mask value for pattern detect (1=ignore) + .PATTERN(58'h000000000000000), // 58-bit pattern match for pattern detect + .SEL_MASK("MASK"), // C, MASK, ROUNDING_MODE1, ROUNDING_MODE2 + .SEL_PATTERN("PATTERN"), // Select pattern value (C, PATTERN) + .USE_PATTERN_DETECT("NO_PATDET"), // Enable pattern detect (NO_PATDET, PATDET) + // Programmable Inversion Attributes: Specifies built-in programmable inversion on specific pins + .IS_ALUMODE_INVERTED(4'b0000), // Optional inversion for ALUMODE + .IS_CARRYIN_INVERTED(1'b0), // Optional inversion for CARRYIN + .IS_CLK_INVERTED(1'b0), // Optional inversion for CLK + .IS_INMODE_INVERTED(5'b00000), // Optional inversion for INMODE + .IS_NEGATE_INVERTED(3'b000), // Optional inversion for NEGATE + .IS_OPMODE_INVERTED({2'b00, // W: LAST ? 0 : P + 3'b000, // Z: 0 + 2'b01, // Y : M + 2'b01 // X: M + }), // Optional inversion for OPMODE + .IS_RSTALLCARRYIN_INVERTED(1'b0), // Optional inversion for RSTALLCARRYIN + .IS_RSTALUMODE_INVERTED(1'b0), // Optional inversion for RSTALUMODE + .IS_RSTA_INVERTED(1'b0), // Optional inversion for RSTA + .IS_RSTB_INVERTED(1'b0), // Optional inversion for RSTB + .IS_RSTCTRL_INVERTED(1'b0), // Optional inversion for STCONJUGATE_A + .IS_RSTC_INVERTED(1'b0), // Optional inversion for RSTC + .IS_RSTD_INVERTED(1'b0), // Optional inversion for RSTD + .IS_RSTINMODE_INVERTED(1'b0), // Optional inversion for RSTINMODE + .IS_RSTM_INVERTED(1'b0), // Optional inversion for RSTM + .IS_RSTP_INVERTED(1'b0), // Optional inversion for RSTP + // Register Control Attributes: Pipeline Register Configuration + .ACASCREG(INTERNAL_REGS), // Number of pipeline stages between A/ACIN and ACOUT (0-2) + .ADREG(0), // Pipeline stages for pre-adder (0-1) + .ALUMODEREG(0), // Pipeline stages for ALUMODE (0-1) + .AREG(INTERNAL_REGS), // Pipeline stages for A (0-2) + .BCASCREG(INTERNAL_REGS), // Number of pipeline stages between B/BCIN and BCOUT (0-2) + .BREG(INTERNAL_REGS), // Pipeline stages for B (0-2) + .CARRYINREG(0), // Pipeline stages for CARRYIN (0-1) + .CARRYINSELREG(0), // Pipeline stages for CARRYINSEL (0-1) + .CREG(0), // Pipeline stages for C (0-1) + .DREG(0), // Pipeline stages for D (0-1) + .INMODEREG(1), // Pipeline stages for INMODE (0-1) + .MREG(1), // Multiplier pipeline stages (0-1) + .OPMODEREG(1), // Pipeline stages for OPMODE (0-1) + .PREG(PREG), // Number of pipeline stages for P (0-1) + .RESET_MODE("SYNC") // Selection of synchronous or asynchronous reset. (ASYNC, SYNC). + ) + DSP58_inst ( + // Cascade outputs: Cascade Ports + .ACOUT(), // 34-bit output: A port cascade + .BCOUT((i % CC_LEN == CC_LEN-1) ? tmp_cc[i+1][j] : Wc_int[i+1][j]), // 24-bit output: B cascade + .CARRYCASCOUT(), // 1-bit output: Cascade carry + .MULTSIGNOUT(), // 1-bit output: Multiplier sign cascade + .PCOUT() , // 58-bit output: Cascade output + // Control outputs: Control Inputs/Status Bits + .OVERFLOW(), // 1-bit output: Overflow in add/acc + .PATTERNBDETECT(), // 1-bit output: Pattern bar detect + .PATTERNDETECT(), // 1-bit output: Pattern detect + .UNDERFLOW(), // 1-bit output: Underflow in add/acc + // Data outputs: Data Ports + .CARRYOUT(), // 4-bit output: Carry + .P(pout[i][j]), // 58-bit output: Primary data + .XOROUT(), // 8-bit output: XOR data + // Cascade inputs: Cascade Ports + .ACIN('x), // 34-bit input: A cascade data + .BCIN((i % CC_LEN == 0) ? 'x : Wc_int[i][j]), // 24-bit input: B cascade + .CARRYCASCIN('x), // 1-bit input: Cascade carry + .MULTSIGNIN('x), // 1-bit input: Multiplier sign cascade + .PCIN('x), // 58-bit input: P cascade + // Control inputs: Control Inputs/Status Bits + .ALUMODE(4'h0), // 4-bit input: ALU control + .CARRYINSEL('0), // 3-bit input: Carry select + .CLK(clk), // 1-bit input: Clock + .INMODE({5'b10001}), // 5-bit input: INMODE control + .NEGATE('0), // 3-bit input: Negates the input of the multiplier + .OPMODE({ + Lc[i][DSP_PIPELINE_STAGES-1] ? 2'b00 : 2'b01, + 7'b000_0000 + }), // 9-bit input: Operation mode + // Data inputs: Data Ports + .A({ 7'b0, Ac_int[i] }), // 34-bit input: A data + .B((i % CC_LEN == 0) ? Wc_int[i][j] : 'x), // 24-bit input: B data + .C('x), // 58-bit input: C data + .CARRYIN('0), // 1-bit input: Carry-in + .D('x), // 27-bit input: D data + // Reset/Clock Enable inputs: Reset/Clock Enable Inputs + .ASYNC_RST('0), // 1-bit input: Asynchronous reset for all registers. + .CEA1(ivld), // 1-bit input: Clock enable for 1st stage AREG + .CEA2('0), // 1-bit input: Clock enable for 2nd stage AREG + .CEAD('0), // 1-bit input: Clock enable for ADREG + .CEALUMODE('0), // 1-bit input: Clock enable for ALUMODE + .CEB1(ivld), // 1-bit input: Clock enable for 1st stage BREG + .CEB2('0), // 1-bit input: Clock enable for 2nd stage BREG + .CEC('0), // 1-bit input: Clock enable for CREG + .CECARRYIN('0), // 1-bit input: Clock enable for CARRYINREG + .CECTRL(ivld), // 1-bit input: Clock enable for OPMODEREG and CARRYINSELREG + .CED('0), // 1-bit input: Clock enable for DREG + .CEINMODE('1), // 1-bit input: Clock enable for INMODEREG + .CEM(ivld), // 1-bit input: Clock enable for MREG + .CEP(ivld), // 1-bit input: Clock enable for PREG + .RSTA(rst), // 1-bit input: Reset for AREG + .RSTALLCARRYIN('0), // 1-bit input: Reset for CARRYINREG + .RSTALUMODE('0), // 1-bit input: Reset for ALUMODEREG + .RSTB(rst), // 1-bit input: Reset for BREG + .RSTC('0), // 1-bit input: Reset for CREG + .RSTCTRL(rst), // 1-bit input: Reset for OPMODEREG and CARRYINSELREG + .RSTD('0), // 1-bit input: Reset for DREG and ADREG + .RSTINMODE(rst), // 1-bit input: Reset for INMODE register + .RSTM(rst), // 1-bit input: Reset for MREG + .RSTP(rst) // 1-bit input: Reset for PREG + ); + + + if(i % CC_LEN == CC_LEN-1) begin + sft_reg #( + .N(CC_LEN) + ) inst_sft_reg ( + .clk(clk), + .ivld(ivld), + .din(Wc_int[i-(CC_LEN-1)][j]), + .dout(Wc_int[i+1][j]) + ); + end + + end + end + // end + +// Collect +//------------------------------------------------------------------------------------ + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] Pc; + logic [CLEN-1:0] Pc_vld; + + always_ff @(posedge clk) begin + if(rst) begin + for(int i = 0; i < CLEN; i++) begin + Pc[i] <= '0; + end + end else begin + for(int i = 0; i < CLEN; i++) begin + if(ivld) begin + if(i == CLEN-1) begin + Pc[i] <= pout[i]; + Pc_vld[i] <= Lc[i][DSP_PIPELINE_STAGES]; + end else begin + Pc[i] <= Lc[i][DSP_PIPELINE_STAGES] ? pout[i] : Pc[i+1]; + Pc_vld[i] <= Lc[i][DSP_PIPELINE_STAGES] ? 1'b1 : Pc_vld[i+1]; + end + end + end + end + end + + logic ovld; + logic [PE-1:0][ACCU_WIDTH-1:0] p; + + assign ovld = Pc_vld[0]; + assign p = Pc[0]; + + collect_out_1d #( + .PE(PE), + .ACCU_WIDTH(ACCU_WIDTH) + ) inst_collect_out ( + .clk(clk), .rst(rst), + .en(en), + .p_tdata(p), .p_tvalid(ovld), + .m_axis_tdata(m_axis_tdata), .m_axis_tvalid(m_axis_tvalid), .m_axis_tready(m_axis_tready) + ); + +endmodule \ No newline at end of file diff --git a/finn-rtllib/mmu/1d/sched_weights_1d.sv b/finn-rtllib/mmu/1d/sched_weights_1d.sv new file mode 100644 index 0000000000..5046f63d84 --- /dev/null +++ b/finn-rtllib/mmu/1d/sched_weights_1d.sv @@ -0,0 +1,141 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module sched_weights_1d #( + int unsigned CU_SIMD, + int unsigned PE, + int unsigned WEIGHT_WIDTH, + + int unsigned N_BEATS_OP, + int unsigned N_BEATS_EP, + + int unsigned N_DCPL_STAGES = 2 +)( + // Global Control + input logic clk, + input logic rst, + + // Input Stream + input logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] s_axis_tdata, + input logic s_axis_tvalid, + output logic s_axis_tready, + + // Output Stream + output logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] m_axis_tdata, + output logic m_axis_tvalid, + input logic m_axis_tready +); + +// Params +// --------------------------------------------------------------------- + localparam integer CNT_EPLG_BITS = (N_BEATS_OP > N_BEATS_EP) ? + (N_BEATS_OP == 1) ? 1 : $clog2(N_BEATS_OP) : + (N_BEATS_EP == 1) ? 1 : $clog2(N_BEATS_EP); + +// Queueing +// --------------------------------------------------------------------- + logic s_out_tready, s_out_tvalid; + logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] s_out_tdata; + + skid #(.DATA_WIDTH(PE*CU_SIMD*WEIGHT_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_ireg ( + .clk(clk), .rst(rst), + .idat(s_axis_tdata), .ivld(s_axis_tvalid), .irdy(s_axis_tready), + .odat(s_out_tdata), .ovld(s_out_tvalid), .ordy(s_out_tready) + ); + +// Shifting +// --------------------------------------------------------------------- + logic valid_C = '0, valid_N; + logic eplg_C = '0, eplg_N; + logic [CNT_EPLG_BITS-1:0] cnt_eplg_C = '0, cnt_eplg_N; + + logic ovld, ordy; + logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] odat; + + // REG + always_ff @(posedge clk) begin + if(rst) begin + valid_C <= '1; + eplg_C <= 1'b0; + cnt_eplg_C <= '0; + end else begin + valid_C <= valid_N; + eplg_C <= eplg_N; + cnt_eplg_C <= cnt_eplg_N; + end + end + + // DP + always_comb begin + valid_N = valid_C; + eplg_N = eplg_C; + cnt_eplg_N = cnt_eplg_C; + + // Read + if (ovld && ordy) begin + // Shift ctrl + if(eplg_C) begin + if(cnt_eplg_C == N_BEATS_EP-1) begin + eplg_N = 1'b0; + cnt_eplg_N = 0; + valid_N = 1'b1; + end else begin + cnt_eplg_N = cnt_eplg_C + 1; + valid_N = 1'b0; + end + end else begin + if(cnt_eplg_C == N_BEATS_OP-1) begin + eplg_N = 1'b1; + cnt_eplg_N = 0; + valid_N = 1'b0; + end else begin + cnt_eplg_N = cnt_eplg_C + 1; + valid_N = 1'b1; + end + end + end + end + + // Output valid + assign ovld = !((s_out_tvalid && valid_C) != valid_C); + assign s_out_tready = (ovld && ordy) && valid_C; + assign odat = valid_C ? s_out_tdata : '0; + +// Oreg +// --------------------------------------------------------------------- + skid #(.DATA_WIDTH(PE*CU_SIMD*WEIGHT_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg ( + .clk(clk), .rst(rst), + .idat(odat), .ivld(ovld), .irdy(ordy), + .odat(m_axis_tdata), .ovld(m_axis_tvalid), .ordy(m_axis_tready) + ); + +endmodule \ No newline at end of file diff --git a/finn-rtllib/mmu/1d/sft_reg.sv b/finn-rtllib/mmu/1d/sft_reg.sv new file mode 100644 index 0000000000..7fa8f5eeef --- /dev/null +++ b/finn-rtllib/mmu/1d/sft_reg.sv @@ -0,0 +1,25 @@ +module sft_reg #( + int N = 4, + int DATA_BITS = 24 +)( + input logic clk, + input logic ivld, + input logic [DATA_BITS-1:0] din, + output logic [DATA_BITS-1:0] dout +); + + // A 2D array representing the shift stages + logic [N-1:0][DATA_BITS-1:0] shift_pipe; + + always_ff @(posedge clk) begin + if (ivld) begin + // Shift the bits in + shift_pipe <= {shift_pipe[N-2:0], din}; + end + end + + // The tool sees this lack of reset and constant index + // and maps it to an SRL16 automatically. + assign dout = shift_pipe[N-1]; + +endmodule \ No newline at end of file diff --git a/finn-rtllib/mmu/2d/collect_out_2d.sv b/finn-rtllib/mmu/2d/collect_out_2d.sv new file mode 100644 index 0000000000..1f0b60b319 --- /dev/null +++ b/finn-rtllib/mmu/2d/collect_out_2d.sv @@ -0,0 +1,120 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module collect_out_2d #( + int unsigned PE, + int unsigned ACCU_WIDTH, + + int unsigned QDEPTH = 2 * PE, + int unsigned QCNT_BITS = $clog2(QDEPTH), + int unsigned Q_MAX = PE, + + int unsigned N_DCPL_STAGES = 2 +)( + // Global Control + input logic clk, + input logic rst, + + output logic en, + + // Input Stream + input logic [PE-1:0][ACCU_WIDTH-1:0] p_tdata, + input logic [PE-1:0] p_tvalid, + + // Output Stream + output logic [PE-1:0][ACCU_WIDTH-1:0] m_axis_tdata, + output logic m_axis_tvalid, + input logic m_axis_tready +); + +// Queueing +// --------------------------------------------------------------------- + logic [PE-1:0] q_in_tready, q_in_tvalid; + logic [PE-1:0][ACCU_WIDTH-1:0] q_in_tdata; + logic [PE-1:0] q_out_tready, q_out_tvalid; + logic [PE-1:0][ACCU_WIDTH-1:0] q_out_tdata; + logic [PE-1:0][QCNT_BITS-1:0] q_count; + logic en_int; + + assign q_in_tvalid = p_tvalid; + assign q_in_tdata = p_tdata; + + for(genvar i = 0; i < PE; i++) begin + Q_srl #( + .depth(QDEPTH), + .width(ACCU_WIDTH) + ) inst_queue ( + .clock(clk), .reset(rst), + .count(q_count[i]), .maxcount(), + .i_v(q_in_tvalid[i]), .i_r(q_in_tready[i]), .i_d(q_in_tdata[i]), + .o_v(q_out_tvalid[i]), .o_r(q_out_tready[i]), .o_d(q_out_tdata[i]) + ); + end + + // Global enable + always_comb begin + en_int = 1'b1; + + for(int i = 0; i < PE; i++) begin + if(q_count[i] > Q_MAX) + en_int = 1'b0; + end + end + + always_ff @( posedge clk ) begin + if(rst) begin + en <= 1'b0; + end + else begin + en <= en_int; + end + end + +// Output +// --------------------------------------------------------------------- + logic ovld; + logic ordy; + logic [PE-1:0][ACCU_WIDTH-1:0] odat; + + assign odat = q_out_tdata; + assign ovld = &q_out_tvalid; + for(genvar i = 0; i < PE; i++) begin + assign q_out_tready[i] = ovld && ordy; + end + + skid #(.DATA_WIDTH(PE*ACCU_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg ( + .clk(clk), .rst(rst), + .idat(odat), .ivld(ovld), .irdy(ordy), + .odat(m_axis_tdata), .ovld(m_axis_tvalid), .ordy(m_axis_tready) + ); + +endmodule \ No newline at end of file diff --git a/finn-rtllib/mmu/2d/cu_mmau_2d.sv b/finn-rtllib/mmu/2d/cu_mmau_2d.sv new file mode 100644 index 0000000000..e4f246872c --- /dev/null +++ b/finn-rtllib/mmu/2d/cu_mmau_2d.sv @@ -0,0 +1,432 @@ +/****************************************************************************** + * Copyright (C) 2025, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * @brief Compute unit (DSP grid) - MMAU + * @author Dario Korolija + *****************************************************************************/ + +module cu_mmau_2d #( + int unsigned PE, + int unsigned CLEN, + int unsigned CU_SIMD, + + int unsigned ACTIVATION_WIDTH, + int unsigned WEIGHT_WIDTH, + int unsigned ACCU_WIDTH, + + bit SIGNED_ACTIVATIONS = 1, + int unsigned FORCE_BEHAVIOURAL = 0 + ) ( + // Global Control + input logic clk, + input logic rst, + + // Enable + output logic en, + + // Input + input logic ivld, + input logic [CLEN-1:0] ilast, + input logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] a, + input logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] w, + + // Ouput + output logic m_axis_tvalid, + input logic m_axis_tready, + output logic [PE-1:0][ACCU_WIDTH-1:0] m_axis_tdata + ); + + +// Startup Recovery Watchdog +// The DSP slice needs 100ns of recovery time after initial startup before +// being able to ingest input properly. This watchdog discovers violating +// stimuli during simulation and produces a corresponding warning. +//------------------------------------------------------------------------------------ + if(1) begin : blkRecoveryWatch + logic Dirty = 1; + initial begin + #100ns; + Dirty <= 0; + end + + always_ff @(posedge clk) begin + assert(!Dirty || rst) else begin + $warning("%m: Feeding input during DSP startup recovery. Expect functional errors."); + end + end + end : blkRecoveryWatch + +// Shifts - activations and weights +//------------------------------------------------------------------------------------ + localparam int unsigned PAD_BITS_ACT = 9 - ACTIVATION_WIDTH; + localparam int unsigned PAD_BITS_WEIGHT = 8 - WEIGHT_WIDTH; + + logic [CLEN-1:0][PE-1:0][CU_SIMD*ACTIVATION_WIDTH-1:0] Ac; + logic [CLEN-1:0][PE-1:0] Ac_last; + logic [CLEN-1:0][PE-1:0][CU_SIMD*WEIGHT_WIDTH-1:0] Wc; + + for(genvar i = 0; i < CLEN; i++) begin + assign Ac[i][0] = a[i]; + assign Ac_last[i][0] = ilast[i]; + end + + for(genvar i = 0; i < PE; i++) begin + assign Wc[0][i] = w[i]; + end + + always_ff @(posedge clk) begin + if(rst) begin + for(int i = 0; i < CLEN; i++) begin + for(int j = 1; j < PE; j++) begin + Ac[i][j] <= '0; + Ac_last[i][j] <= 1'b0; + end + end + + for(int i = 1; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + Wc[i][j] <= '0; + end + end + end + for(int i = 0; i < CLEN; i++) begin + for(int j = 1; j < PE; j++) begin + if(ivld) begin + Ac[i][j] <= Ac[i][j-1]; + Ac_last[i][j] <= Ac_last[i][j-1]; + end + end + end + + for(int i = 1; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + if(ivld) begin + Wc[i][j] <= Wc[i-1][j]; + end + end + end + end + +// Shifts - per DSP +//------------------------------------------------------------------------------------ + localparam int unsigned DSP_PIPELINE_STAGES = 3; + logic [CLEN-1:0][PE-1:0][DSP_PIPELINE_STAGES:0] Lc; + + for(genvar i = 0; i < CLEN; i++) begin + for(genvar j = 0; j < PE; j++) begin + assign Lc[i][j][0] = Ac_last[i][j]; + end + end + + always_ff @(posedge clk) begin + if(rst) begin + for(int i = 0; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + for(int k = 1; k <= DSP_PIPELINE_STAGES; k++) begin + Lc[i][j][k] <= 1'b0; + end + end + end + end + else begin + for(int i = 0; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + for(int k = 1; k <= DSP_PIPELINE_STAGES; k++) begin + if(ivld) begin + Lc[i][j][k] <= Lc[i][j][k-1]; + end + end + end + end + end + end + +// Instantiate PE x CLEN DSPs +//------------------------------------------------------------------------------------ + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] pout; + + if(FORCE_BEHAVIOURAL == 1) begin + logic [CLEN-1:0][PE-1:0][CU_SIMD*ACTIVATION_WIDTH-1:0] Ac_int; + logic [CLEN-1:0][PE-1:0][CU_SIMD*WEIGHT_WIDTH-1:0] Wc_int; + logic [CLEN-1:0][PE-1:0][CU_SIMD-1:0][ACCU_WIDTH-1:0] Mc_int_part; + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] Mc_int_sum; + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] Mc_int; + + + for (genvar i = 0; i < CLEN; i++) begin + for (genvar j = 0; j < PE; j++) begin + always_comb begin + Mc_int_sum[i][j] = 0; + + for(int k = 0; k < CU_SIMD; k++) begin + Mc_int_part[i][j][k] = $signed(Ac_int[i][j][k*ACTIVATION_WIDTH+:ACTIVATION_WIDTH]) * $signed(Wc_int[i][j][k*WEIGHT_WIDTH+:WEIGHT_WIDTH]); + Mc_int_sum[i][j] = $signed(Mc_int_sum[i][j]) + $signed(Mc_int_part[i][j][k]); + end + end + + always_ff @(posedge clk) begin + if(rst) begin + Ac_int[i][j] <= '0; + Wc_int[i][j] <= '0; + Mc_int[i][j] <= '0; + pout[i][j] <= '0; + end else begin + if(ivld) begin + Ac_int[i][j] <= Ac[i][j]; + Wc_int[i][j] <= Wc[i][j]; + Mc_int[i][j] <= $signed(Mc_int_sum[i][j]); + pout[i][j] <= Lc[i][j][DSP_PIPELINE_STAGES] ? $signed(Mc_int[i][j]) : $signed(Mc_int[i][j]) + $signed(pout[i][j]); + end + end + end + end + end + end else begin + localparam int INTERNAL_REGS = 1; // 1 : 0 + localparam bit PREG = 1; + + logic [CLEN-1:0][PE-1:0][26:0] Ac_int; + logic [CLEN-1:0][PE-1:0][23:0] Wc_int; + + for (genvar i = 0; i < CLEN; i++) begin + for (genvar j = 0; j < PE; j++) begin + + for (genvar k = 0; k < CU_SIMD; k++) begin + assign Ac_int[i][j][9*k +: 9] = + SIGNED_ACTIVATIONS ? PAD_BITS_ACT == 0 ? Ac[i][j][ACTIVATION_WIDTH*k+:ACTIVATION_WIDTH] : { {PAD_BITS_ACT{Ac[i][j][k*ACTIVATION_WIDTH+ACTIVATION_WIDTH-1]}}, Ac[i][j][k*ACTIVATION_WIDTH+:ACTIVATION_WIDTH] } + : PAD_BITS_ACT == 0 ? Ac[i][j][ACTIVATION_WIDTH*k+:ACTIVATION_WIDTH] : { {PAD_BITS_ACT{1'b0}}, Ac[i][j][k*ACTIVATION_WIDTH+:ACTIVATION_WIDTH] } ; + assign Wc_int[i][j][8*k +: 8] = + PAD_BITS_WEIGHT == 0 ? Wc[i][j][WEIGHT_WIDTH*k+:WEIGHT_WIDTH] : { {PAD_BITS_WEIGHT{Wc[i][j][k*WEIGHT_WIDTH+WEIGHT_WIDTH-1]}}, Wc[i][j][k*WEIGHT_WIDTH+:WEIGHT_WIDTH] }; + end + + + DSP58 #( + // Feature Control Attributes: Data Path Selection + .AMULTSEL("A"), // Selects A input to multiplier (A, AD) + .A_INPUT("DIRECT"), // Selects A input source, "DIRECT" (A port) or "CASCADE" (ACIN port) + .BMULTSEL("B"), // Selects B input to multiplier (AD, B) + .B_INPUT("DIRECT"), // Selects B input source, "DIRECT" (B port) or "CASCADE" (BCIN port) + .DSP_MODE("INT8"), // Configures DSP to a particular mode of operation. Set to INT24 for + // legacy mode. + .PREADDINSEL("A"), // Selects input to pre-adder (A, B) + .RND(58'h000000000000000), // Rounding Constant + .USE_MULT("MULTIPLY"), // Select multiplier usage (DYNAMIC, MULTIPLY, NONE) + .USE_SIMD("ONE58"), // SIMD selection (FOUR12, ONE58, TWO24) + .USE_WIDEXOR("FALSE"), // Use the Wide XOR function (FALSE, TRUE) + .XORSIMD("XOR24_34_58_116"), // Mode of operation for the Wide XOR (XOR12_22, XOR24_34_58_116) + // Pattern Detector Attributes: Pattern Detection Configuration + .AUTORESET_PATDET("NO_RESET"), // NO_RESET, RESET_MATCH, RESET_NOT_MATCH + .AUTORESET_PRIORITY("RESET"), // Priority of AUTORESET vs. CEP (CEP, RESET). + .MASK(58'h0ffffffffffffff), // 58-bit mask value for pattern detect (1=ignore) + .PATTERN(58'h000000000000000), // 58-bit pattern match for pattern detect + .SEL_MASK("MASK"), // C, MASK, ROUNDING_MODE1, ROUNDING_MODE2 + .SEL_PATTERN("PATTERN"), // Select pattern value (C, PATTERN) + .USE_PATTERN_DETECT("NO_PATDET"), // Enable pattern detect (NO_PATDET, PATDET) + // Programmable Inversion Attributes: Specifies built-in programmable inversion on specific pins + .IS_ALUMODE_INVERTED(4'b0000), // Optional inversion for ALUMODE + .IS_CARRYIN_INVERTED(1'b0), // Optional inversion for CARRYIN + .IS_CLK_INVERTED(1'b0), // Optional inversion for CLK + .IS_INMODE_INVERTED(5'b00000), // Optional inversion for INMODE + .IS_NEGATE_INVERTED(3'b000), // Optional inversion for NEGATE + .IS_OPMODE_INVERTED({2'b00, // W: LAST ? 0 : P + 3'b000, // Z: 0 + 2'b01, // Y : M + 2'b01 // X: M + }), // Optional inversion for OPMODE + .IS_RSTALLCARRYIN_INVERTED(1'b0), // Optional inversion for RSTALLCARRYIN + .IS_RSTALUMODE_INVERTED(1'b0), // Optional inversion for RSTALUMODE + .IS_RSTA_INVERTED(1'b0), // Optional inversion for RSTA + .IS_RSTB_INVERTED(1'b0), // Optional inversion for RSTB + .IS_RSTCTRL_INVERTED(1'b0), // Optional inversion for STCONJUGATE_A + .IS_RSTC_INVERTED(1'b0), // Optional inversion for RSTC + .IS_RSTD_INVERTED(1'b0), // Optional inversion for RSTD + .IS_RSTINMODE_INVERTED(1'b0), // Optional inversion for RSTINMODE + .IS_RSTM_INVERTED(1'b0), // Optional inversion for RSTM + .IS_RSTP_INVERTED(1'b0), // Optional inversion for RSTP + // Register Control Attributes: Pipeline Register Configuration + .ACASCREG(INTERNAL_REGS), // Number of pipeline stages between A/ACIN and ACOUT (0-2) + .ADREG(0), // Pipeline stages for pre-adder (0-1) + .ALUMODEREG(0), // Pipeline stages for ALUMODE (0-1) + .AREG(INTERNAL_REGS), // Pipeline stages for A (0-2) + .BCASCREG(INTERNAL_REGS), // Number of pipeline stages between B/BCIN and BCOUT (0-2) + .BREG(INTERNAL_REGS), // Pipeline stages for B (0-2) + .CARRYINREG(0), // Pipeline stages for CARRYIN (0-1) + .CARRYINSELREG(0), // Pipeline stages for CARRYINSEL (0-1) + .CREG(0), // Pipeline stages for C (0-1) + .DREG(0), // Pipeline stages for D (0-1) + .INMODEREG(1), // Pipeline stages for INMODE (0-1) + .MREG(1), // Multiplier pipeline stages (0-1) + .OPMODEREG(1), // Pipeline stages for OPMODE (0-1) + .PREG(PREG), // Number of pipeline stages for P (0-1) + .RESET_MODE("SYNC") // Selection of synchronous or asynchronous reset. (ASYNC, SYNC). + ) + DSP58_inst ( + // Cascade outputs: Cascade Ports + .ACOUT(), // 34-bit output: A port cascade + .BCOUT(), // 24-bit output: B cascade + .CARRYCASCOUT(), // 1-bit output: Cascade carry + .MULTSIGNOUT(), // 1-bit output: Multiplier sign cascade + .PCOUT(), // 58-bit output: Cascade output + // Control outputs: Control Inputs/Status Bits + .OVERFLOW(), // 1-bit output: Overflow in add/acc + .PATTERNBDETECT(), // 1-bit output: Pattern bar detect + .PATTERNDETECT(), // 1-bit output: Pattern detect + .UNDERFLOW(), // 1-bit output: Underflow in add/acc + // Data outputs: Data Ports + .CARRYOUT(), // 4-bit output: Carry + .P(pout[i][j]), // 58-bit output: Primary data + .XOROUT(), // 8-bit output: XOR data + // Cascade inputs: Cascade Ports + .ACIN('x), // 34-bit input: A cascade data + .BCIN('x), // 24-bit input: B cascade + .CARRYCASCIN('x), // 1-bit input: Cascade carry + .MULTSIGNIN('x), // 1-bit input: Multiplier sign cascade + .PCIN('0), // 58-bit input: P cascade + // Control inputs: Control Inputs/Status Bits + .ALUMODE(4'h0), // 4-bit input: ALU control + .CARRYINSEL('0), // 3-bit input: Carry select + .CLK(clk), // 1-bit input: Clock + .INMODE({5'b10001}), // 5-bit input: INMODE control + .NEGATE('0), // 3-bit input: Negates the input of the multiplier + .OPMODE({ + Lc[i][j][DSP_PIPELINE_STAGES-1] ? 2'b00 : 2'b01, + 7'b000_0000 + }), // 9-bit input: Operation mode + // Data inputs: Data Ports + .A({ 7'b0, Ac_int[i][j] }), // 34-bit input: A data + .B(Wc_int[i][j]), // 24-bit input: B data + .C('x), // 58-bit input: C data + .CARRYIN('0), // 1-bit input: Carry-in + .D('x), // 27-bit input: D data + // Reset/Clock Enable inputs: Reset/Clock Enable Inputs + .ASYNC_RST('0), // 1-bit input: Asynchronous reset for all registers. + .CEA1(ivld), // 1-bit input: Clock enable for 1st stage AREG + .CEA2('0), // 1-bit input: Clock enable for 2nd stage AREG + .CEAD('0), // 1-bit input: Clock enable for ADREG + .CEALUMODE('0), // 1-bit input: Clock enable for ALUMODE + .CEB1(ivld), // 1-bit input: Clock enable for 1st stage BREG + .CEB2('0), // 1-bit input: Clock enable for 2nd stage BREG + .CEC('0), // 1-bit input: Clock enable for CREG + .CECARRYIN('0), // 1-bit input: Clock enable for CARRYINREG + .CECTRL(ivld), // 1-bit input: Clock enable for OPMODEREG and CARRYINSELREG + .CED('0), // 1-bit input: Clock enable for DREG + .CEINMODE('1), // 1-bit input: Clock enable for INMODEREG + .CEM(ivld), // 1-bit input: Clock enable for MREG + .CEP(ivld), // 1-bit input: Clock enable for PREG + .RSTA(rst), // 1-bit input: Reset for AREG + .RSTALLCARRYIN('0), // 1-bit input: Reset for CARRYINREG + .RSTALUMODE('0), // 1-bit input: Reset for ALUMODEREG + .RSTB(rst), // 1-bit input: Reset for BREG + .RSTC('0), // 1-bit input: Reset for CREG + .RSTCTRL(rst), // 1-bit input: Reset for OPMODEREG and CARRYINSELREG + .RSTD('0), // 1-bit input: Reset for DREG and ADREG + .RSTINMODE(rst), // 1-bit input: Reset for INMODE register + .RSTM(rst), // 1-bit input: Reset for MREG + .RSTP(rst) // 1-bit input: Reset for PREG + ); + + end + end + end + +// Collect +//------------------------------------------------------------------------------------ + logic [CLEN-1:0][PE-1:0][ACCU_WIDTH-1:0] Pc_C = '0, Pc_N; + logic [CLEN-1:0][PE-1:0] Pc_vld_C = '0, Pc_vld_N; + + always_ff @(posedge clk) begin + if(rst) begin + Pc_C <= 'X; + Pc_vld_C <= '0; + end else begin + if(ivld) begin + Pc_C <= Pc_N; + Pc_vld_C <= Pc_vld_N; + end + end + end + + for(genvar i = 0; i < CLEN; i++) begin + for(genvar j = 0; j < PE; j++) begin + if(i == CLEN-1) begin + assign Pc_N[i][j] = pout[i][j]; + assign Pc_vld_N[i][j] = Lc[i][j][DSP_PIPELINE_STAGES]; + end else begin + assign Pc_N[i][j] = Lc[i][j][DSP_PIPELINE_STAGES] ? pout[i][j] : Pc_C[i+1][j]; + assign Pc_vld_N[i][j] = Lc[i][j][DSP_PIPELINE_STAGES] ? 1'b1 : Pc_vld_C[i+1][j]; + end + end + end + /* + always_ff @(posedge clk) begin + if(rst) begin + for(int i = 0; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + Pc[i][j] <= '0; + Pc_vld[i][j] <= '0; + end + end + end else begin + for(int i = 0; i < CLEN; i++) begin + for(int j = 0; j < PE; j++) begin + if(ivld) begin + if(i == CLEN-1) begin + Pc[i][j] <= pout[i][j]; + Pc_vld[i][j] <= Lc[i][j][DSP_PIPELINE_STAGES]; + end else begin + Pc[i][j] <= Lc[i][j][DSP_PIPELINE_STAGES] ? pout[i][j] : Pc[i+1][j]; + Pc_vld[i][j] <= Lc[i][j][DSP_PIPELINE_STAGES] ? 1'b1 : Pc_vld[i+1][j]; + end + end + end + end + end + end + */ + + logic [PE-1:0] ovld; + logic [PE-1:0][ACCU_WIDTH-1:0] p; + + for(genvar i = 0; i < PE; i++) begin + assign ovld[i] = Pc_vld_C[0][i]; + assign p[i] = Pc_C[0][i]; + end + + collect_out_2d #( + .PE(PE), + .ACCU_WIDTH(ACCU_WIDTH) + ) inst_collect_out ( + .clk(clk), .rst(rst), + .en(en), + .p_tdata(p), .p_tvalid(ovld), + .m_axis_tdata(m_axis_tdata), .m_axis_tvalid(m_axis_tvalid), .m_axis_tready(m_axis_tready) + ); + +endmodule \ No newline at end of file diff --git a/finn-rtllib/mmu/2d/sched_weights_2d.sv b/finn-rtllib/mmu/2d/sched_weights_2d.sv new file mode 100644 index 0000000000..a61c5bbaa6 --- /dev/null +++ b/finn-rtllib/mmu/2d/sched_weights_2d.sv @@ -0,0 +1,165 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module sched_weights_2d #( + int unsigned CU_SIMD, + int unsigned PE, + int unsigned WEIGHT_WIDTH, + + int unsigned N_BEATS_OP, + int unsigned N_BEATS_EP, + + int unsigned N_DCPL_STAGES = 2 +)( + // Global Control + input logic clk, + input logic rst, + + // Input Stream + input logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] s_axis_tdata, + input logic s_axis_tvalid, + output logic s_axis_tready, + + // Output Stream + output logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] m_axis_tdata, + output logic m_axis_tvalid, + input logic m_axis_tready +); + +// Params +// --------------------------------------------------------------------- + localparam integer CNT_EPLG_BITS = (N_BEATS_OP > N_BEATS_EP) ? + (N_BEATS_OP == 1) ? 1 : $clog2(N_BEATS_OP) : + (N_BEATS_EP == 1) ? 1 : $clog2(N_BEATS_EP); + +// Queueing +// --------------------------------------------------------------------- + logic [PE-1:0] q_in_tready, q_in_tvalid; + logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] q_in_tdata; + logic [PE-1:0] q_out_tready, q_out_tvalid; + logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] q_out_tdata; + + assign s_axis_tready = &q_in_tready; + + for(genvar i = 0; i < PE; i++) begin + assign q_in_tdata[i] = s_axis_tdata[i]; + assign q_in_tvalid[i] = s_axis_tvalid && s_axis_tready; + + Q_srl #( + .depth(2*PE), .width(CU_SIMD*WEIGHT_WIDTH) + ) inst_queue ( + .clock(clk), .reset(rst), + .i_v(q_in_tvalid[i]), .i_r(q_in_tready[i]), .i_d(q_in_tdata[i]), + .o_v(q_out_tvalid[i]), .o_r(q_out_tready[i]), .o_d(q_out_tdata[i]) + ); + end + +// Shifting +// --------------------------------------------------------------------- + logic [PE-1:0] valid_C = '0, valid_N; + logic eplg_C = '0, eplg_N; + logic [CNT_EPLG_BITS-1:0] cnt_eplg_C = '0, cnt_eplg_N; + + logic ovld, ordy; + logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] odat; + + // REG + always_ff @(posedge clk) begin + if(rst) begin + valid_C[0] <= 1'b1; + valid_C[PE-1:1] <= '0; + eplg_C <= 1'b0; + cnt_eplg_C <= '0; + end else begin + valid_C <= valid_N; + eplg_C <= eplg_N; + cnt_eplg_C <= cnt_eplg_N; + end + end + + // DP + always_comb begin + valid_N = valid_C; + eplg_N = eplg_C; + cnt_eplg_N = cnt_eplg_C; + + // Read + if (ovld && ordy) begin + // Shift ctrl + valid_N[PE-1:1] = valid_C[PE-2:0]; + if(eplg_C) begin + if(cnt_eplg_C == N_BEATS_EP-1) begin + eplg_N = 1'b0; + cnt_eplg_N = 0; + valid_N[0] = 1'b1; + end else begin + cnt_eplg_N = cnt_eplg_C + 1; + valid_N[0] = 1'b0; + end + end else begin + if(cnt_eplg_C == N_BEATS_OP-1) begin + eplg_N = 1'b1; + cnt_eplg_N = 0; + valid_N[0] = 1'b0; + end else begin + cnt_eplg_N = cnt_eplg_C + 1; + valid_N[0] = 1'b1; + end + end + end + end + + // Output valid + always_comb begin + ovld = 1'b1; + + for(int i = 0; i < PE; i++) begin + if((valid_C[i] & q_out_tvalid[i]) != valid_C[i]) begin + ovld = 1'b0; + end + end + end + + for(genvar i = 0; i < PE; i++) begin + assign q_out_tready[i] = (ovld && ordy) && valid_C[i]; + assign odat[i] = valid_C[i] ? q_out_tdata[i] : '0; + end + +// Oreg +// --------------------------------------------------------------------- + skid #(.DATA_WIDTH(PE*CU_SIMD*WEIGHT_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg ( + .clk(clk), .rst(rst), + .idat(odat), .ivld(ovld), .irdy(ordy), + .odat(m_axis_tdata), .ovld(m_axis_tvalid), .ordy(m_axis_tready) + ); + +endmodule \ No newline at end of file diff --git a/finn-rtllib/mmu/en_global.sv b/finn-rtllib/mmu/en_global.sv new file mode 100644 index 0000000000..fe9525112e --- /dev/null +++ b/finn-rtllib/mmu/en_global.sv @@ -0,0 +1,99 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module en_global #( + int unsigned PE, + int unsigned CLEN, + int unsigned CU_SIMD = 3, + + int unsigned WEIGHT_WIDTH, + int unsigned ACTIVATION_WIDTH, + + int unsigned N_DCPL_STAGES = 2 +)( + // Global Control + input logic clk, + input logic rst, + input logic en, + + // Activation Stream + input logic s_act_tvalid, + output logic s_act_tready, + input logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] s_act_tdata, + input logic [CLEN-1:0] s_act_tlast, + + output logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] m_act_tdata, + output logic [CLEN-1:0] m_act_tlast, + + // Weight Stream + input logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] s_wgt_tdata, + input logic s_wgt_tvalid, + output logic s_wgt_tready, + + output logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] m_wgt_tdata, + + output logic m_tvalid +); + +// Global enable +// --------------------------------------------------------------------- +logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] act_tdata; +logic [CLEN-1:0] act_tlast; +logic [PE-1:0][CU_SIMD-1:0][WEIGHT_WIDTH-1:0] wgt_tdata; +logic ovld; +logic ordy; + +assign ovld = en && s_act_tvalid && s_wgt_tvalid; +assign s_act_tready = en && s_wgt_tvalid; +assign s_wgt_tready = en && s_act_tvalid; + + +assign act_tdata = ovld ? s_act_tdata : '0; +assign act_tlast = ovld ? s_act_tlast : '0; +assign wgt_tdata = ovld ? s_wgt_tdata : '0; + + +// Output +// --------------------------------------------------------------------- +skid #(.DATA_WIDTH(PE*CU_SIMD*WEIGHT_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg_weights ( + .clk(clk), .rst(rst), + .idat(wgt_tdata), .ivld(ovld), .irdy(), + .odat(m_wgt_tdata), .ovld(), .ordy(1'b1) +); + +skid #(.DATA_WIDTH(CLEN*CU_SIMD*ACTIVATION_WIDTH+CLEN), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg_activations ( + .clk(clk), .rst(rst), + .idat({act_tlast, act_tdata}), .ivld(ovld), .irdy(ordy), + .odat({m_act_tlast, m_act_tdata}), .ovld(m_tvalid), .ordy(1'b1) +); + +endmodule \ No newline at end of file diff --git a/finn-rtllib/mmu/mmu_axi.sv b/finn-rtllib/mmu/mmu_axi.sv new file mode 100644 index 0000000000..92ec9213a8 --- /dev/null +++ b/finn-rtllib/mmu/mmu_axi.sv @@ -0,0 +1,252 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module mmu_axi #( + string GEMM_TYPE = "mmau", + int unsigned PE, + int unsigned SIMD, + int unsigned CU_SIMD = 3, + + int unsigned MW, + int unsigned MH, + int unsigned N_VECTORS, + + int unsigned WEIGHT_WIDTH, + int unsigned ACTIVATION_WIDTH, + int unsigned ACCU_WIDTH, + + int unsigned IN_TILED = 0, + int unsigned OUT_TILED = 0, + + int unsigned DSP_STAGES = 3, + bit SIGNED_ACTIVATIONS = 1, + bit PUMPED_COMPUTE = 0, // Not used + bit FORCE_BEHAVIOURAL = 0, + + int unsigned N_DCPL_STAGES = 2, + + // Safely deducible parameters + localparam int unsigned CLEN = (SIMD + CU_SIMD-1)/ CU_SIMD, + localparam int unsigned WSIMD = PE * CU_SIMD, + localparam int unsigned ASIMD = CLEN * CU_SIMD, + + localparam int unsigned WEIGHT_STREAM_WIDTH = WSIMD * WEIGHT_WIDTH, + localparam int unsigned WEIGHT_STREAM_WIDTH_BA = (WEIGHT_STREAM_WIDTH + 7)/8 * 8, + localparam int unsigned INPUT_STREAM_WIDTH = SIMD * ACTIVATION_WIDTH, + localparam int unsigned INPUT_STREAM_WIDTH_BA = (INPUT_STREAM_WIDTH + 7)/8 * 8, + localparam int unsigned OUTPUT_STREAM_WIDTH = PE * ACCU_WIDTH, + localparam int unsigned OUTPUT_STREAM_WIDTH_BA = (OUTPUT_STREAM_WIDTH + 7)/8 * 8 +)( + // Global Control + input logic ap_clk, + input logic ap_clk2x, + input logic ap_rst_n, + + // Weight Stream + input logic [WEIGHT_STREAM_WIDTH_BA-1:0] s_axis_weights_tdata, + input logic s_axis_weights_tvalid, + output logic s_axis_weights_tready, + + // Input Stream + input logic [INPUT_STREAM_WIDTH_BA-1:0] s_axis_input_tdata, + input logic s_axis_input_tvalid, + output logic s_axis_input_tready, + + // Output Stream + output logic [OUTPUT_STREAM_WIDTH_BA-1:0] m_axis_output_tdata, + output logic m_axis_output_tvalid, + input logic m_axis_output_tready +); + +// Checks and params +// --------------------------------------------------------------------- + initial begin + if (SIMD != CLEN * CU_SIMD) begin + $error("%m: SIMD (%0d) should be a multiple of CU_SIMD and CLEN. (TODO: Needs testing)", SIMD); + $finish; + end + if (MW % SIMD != 0) begin + $error("%m: MW (%0d) is not a multiple of SIMD (%0d).", MW, SIMD); + $finish; + end + if (MH % PE != 0) begin + $error("%m: MH (%0d) is not a multiple of PE (%0d).", MH, PE); + $finish; + end + if (WEIGHT_WIDTH > 8) begin + $error("Weight width of %0d-bits exceeds maximum of 8-bits", WEIGHT_WIDTH); + $finish; + end + if (ACTIVATION_WIDTH > 8) begin + $error("Activation width of %0d-bits exceeds maximum of 8-bits", ACTIVATION_WIDTH); + $finish; + end + end + + localparam int unsigned SF = MW / SIMD; + localparam int unsigned NF = MH / PE; + localparam int unsigned N_TRS_OP = SF * NF * N_VECTORS; + localparam int unsigned N_TRS_EP = (GEMM_TYPE == "mmau_1d") ? CLEN-1 + CLEN-1 + DSP_STAGES + 2 : + CLEN-1 + CLEN-1 + DSP_STAGES + PE; + +// Input replay +// --------------------------------------------------------------------- + logic [SIMD-1:0][ACTIVATION_WIDTH-1:0] adat_s0; + logic [ASIMD-1:0][ACTIVATION_WIDTH-1:0] adat_s0_wd; + logic alast_s0; + logic avld_s0, ardy_s0; + + logic [SIMD-1:0][ACTIVATION_WIDTH-1:0] act_s0_tdata; + logic [ASIMD-1:0][ACTIVATION_WIDTH-1:0] act_s0_tdata_mod; + logic act_s0_tlast; + logic act_s0_tvalid, act_s0_tready; + + replay_buff_mmau #(.XC(SF), .YC(CLEN), .W(SIMD*ACTIVATION_WIDTH), .N_REPS(NF), .IO_TILED(IN_TILED)) activation_replay ( + .clk(ap_clk), .rst(~ap_rst_n), + .ivld(s_axis_input_tvalid), .irdy(s_axis_input_tready), .idat(s_axis_input_tdata), + .ovld(act_s0_tvalid), .ordy(act_s0_tready), .odat(act_s0_tdata), .olast(act_s0_tlast) + ); + + if (ASIMD > SIMD) + assign act_s0_tdata_mod[ASIMD-1:SIMD] = '0; + assign act_s0_tdata_mod[SIMD-1:0] = act_s0_tdata[SIMD-1:0]; + +// Activation scheduling +// --------------------------------------------------------------------- + logic [ASIMD-1:0][ACTIVATION_WIDTH-1:0] act_s1_tdata; + logic [CLEN-1:0] act_s1_tlast; + logic act_s1_tvalid, act_s1_tready; + + sched_activations #( + .CU_SIMD(CU_SIMD), .CLEN(CLEN), + .ACTIVATION_WIDTH(ACTIVATION_WIDTH), + .N_BEATS_OP(N_TRS_OP), .N_BEATS_EP(N_TRS_EP) + ) inst_sched_act ( + .clk(ap_clk), .rst(~ap_rst_n), + .s_axis_tdata(act_s0_tdata_mod), .s_axis_tvalid(act_s0_tvalid), .s_axis_tready(act_s0_tready), .s_axis_tlast(act_s0_tlast), + .m_axis_tdata(act_s1_tdata), .m_axis_tvalid(act_s1_tvalid), .m_axis_tready(act_s1_tready), .m_axis_tlast(act_s1_tlast) + ); + +// Weight scheduling +// --------------------------------------------------------------------- + logic [WSIMD-1:0][WEIGHT_WIDTH-1:0] wgt_s1_tdata; + logic wgt_s1_tvalid, wgt_s1_tready; + +if(GEMM_TYPE == "mmau_1d") begin + sched_weights_1d #( + .CU_SIMD(CU_SIMD), .PE(PE), + .WEIGHT_WIDTH(WEIGHT_WIDTH), + .N_BEATS_OP(N_TRS_OP), .N_BEATS_EP(N_TRS_EP) + ) inst_sched_wgt ( + .clk(ap_clk), .rst(~ap_rst_n), + .s_axis_tdata(s_axis_weights_tdata), .s_axis_tvalid(s_axis_weights_tvalid), .s_axis_tready(s_axis_weights_tready), + .m_axis_tdata(wgt_s1_tdata), .m_axis_tvalid(wgt_s1_tvalid), .m_axis_tready(wgt_s1_tready) + ); +end else begin + sched_weights_2d #( + .CU_SIMD(CU_SIMD), .PE(PE), + .WEIGHT_WIDTH(WEIGHT_WIDTH), + .N_BEATS_OP(N_TRS_OP), .N_BEATS_EP(N_TRS_EP) + ) inst_sched_wgt ( + .clk(ap_clk), .rst(~ap_rst_n), + .s_axis_tdata(s_axis_weights_tdata), .s_axis_tvalid(s_axis_weights_tvalid), .s_axis_tready(s_axis_weights_tready), + .m_axis_tdata(wgt_s1_tdata), .m_axis_tvalid(wgt_s1_tvalid), .m_axis_tready(wgt_s1_tready) + ); +end + + +// Global enable +// --------------------------------------------------------------------- + logic en; + logic [ASIMD-1:0][ACTIVATION_WIDTH-1:0] act_s2_tdata; + logic [CLEN-1:0] act_s2_tlast; + logic [WSIMD-1:0][WEIGHT_WIDTH-1:0] wgt_s2_tdata; + logic s2_tvalid; + + en_global #( + .PE(PE), .CLEN(CLEN), .CU_SIMD(CU_SIMD), + .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACTIVATION_WIDTH(ACTIVATION_WIDTH) + ) inst_en_global ( + .clk(ap_clk), .rst(~ap_rst_n), + .en(en), + .s_act_tvalid(act_s1_tvalid), .s_act_tready(act_s1_tready), .s_act_tdata(act_s1_tdata), .s_act_tlast(act_s1_tlast), + .m_act_tdata(act_s2_tdata), .m_act_tlast(act_s2_tlast), + .s_wgt_tvalid(wgt_s1_tvalid), .s_wgt_tready(wgt_s1_tready), .s_wgt_tdata(wgt_s1_tdata), + .m_wgt_tdata(wgt_s2_tdata), + .m_tvalid(s2_tvalid) + ); + +// CU +// --------------------------------------------------------------------- + logic p_tvalid, p_tready; + logic [PE-1:0][ACCU_WIDTH-1:0] p_tdata; + +if(GEMM_TYPE == "mmau_1d") begin + cu_mmau_1d #( + .PE(PE), .CLEN(CLEN), .CU_SIMD(CU_SIMD), + .ACTIVATION_WIDTH(ACTIVATION_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACCU_WIDTH(ACCU_WIDTH), + .FORCE_BEHAVIOURAL(FORCE_BEHAVIOURAL) + ) inst_cu_mmau ( + .clk(ap_clk), .rst(~ap_rst_n), + .en(en), + .ivld(s2_tvalid), .a(act_s2_tdata), .ilast(act_s2_tlast), .w(wgt_s2_tdata), + .m_axis_tvalid(p_tvalid), .m_axis_tready(p_tready), .m_axis_tdata(p_tdata) + ); +end else begin + cu_mmau_2d #( + .PE(PE), .CLEN(CLEN), .CU_SIMD(CU_SIMD), + .ACTIVATION_WIDTH(ACTIVATION_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACCU_WIDTH(ACCU_WIDTH), + .FORCE_BEHAVIOURAL(FORCE_BEHAVIOURAL) + ) inst_cu_mmau ( + .clk(ap_clk), .rst(~ap_rst_n), + .en(en), + .ivld(s2_tvalid), .a(act_s2_tdata), .ilast(act_s2_tlast), .w(wgt_s2_tdata), + .m_axis_tvalid(p_tvalid), .m_axis_tready(p_tready), .m_axis_tdata(p_tdata) + ); +end + + +// Reorder +// --------------------------------------------------------------------- + if(OUT_TILED == 0) begin + reorder_out #(.W(OUTPUT_STREAM_WIDTH_BA), .XC(NF), .YC(CLEN)) inst_reorder_out ( + .clk(ap_clk), .rst(~ap_rst_n), + .ivld(p_tvalid), .irdy(p_tready), .idat(p_tdata), + .ovld(m_axis_output_tvalid), .ordy(m_axis_output_tready), .odat(m_axis_output_tdata) + ); + end else begin + assign m_axis_output_tvalid = p_tvalid; + assign p_tready = m_axis_output_tready; + assign m_axis_output_tdata = p_tdata; + end + +endmodule \ No newline at end of file diff --git a/finn-rtllib/mmu/mmu_axi_wrapper.v b/finn-rtllib/mmu/mmu_axi_wrapper.v new file mode 100644 index 0000000000..70c75193ba --- /dev/null +++ b/finn-rtllib/mmu/mmu_axi_wrapper.v @@ -0,0 +1,100 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module $MODULE_NAME_AXI_WRAPPER$ #( + parameter PE = $PE$, + parameter SIMD = $SIMD$, + parameter CU_SIMD = 3, + parameter ACTIVATION_WIDTH = $ACTIVATION_WIDTH$, + parameter WEIGHT_WIDTH = $WEIGHT_WIDTH$, + parameter ACCU_WIDTH = $ACCU_WIDTH$, + parameter MW = $MW$, + parameter MH = $MH$, + parameter N_VECTORS = $N_VECTORS$, + parameter SIGNED_ACTIVATIONS = $SIGNED_ACTIVATIONS$, + parameter PUMPED_COMPUTE = $PUMPED_COMPUTE$, + + // Safely deducible parameters + parameter WSIMD = PE * CU_SIMD, + parameter WEIGHT_STREAM_WIDTH_BA = (WSIMD * WEIGHT_WIDTH + 7)/8 * 8, + parameter INPUT_STREAM_WIDTH_BA = (SIMD * ACTIVATION_WIDTH + 7) / 8 * 8, + parameter OUTPUT_STREAM_WIDTH_BA = (PE * ACCU_WIDTH + 7)/8 * 8 +)( + // Global Control + (* X_INTERFACE_PARAMETER = "ASSOCIATED_BUSIF in1_V:in0_V:out0_V, ASSOCIATED_RESET ap_rst_n" *) + (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk CLK" *) + input ap_clk, + (* X_INTERFACE_PARAMETER = "ASSOCIATED_RESET ap_rst_n" *) + (* X_INTERFACE_INFO = "xilinx.com:signal:clock:1.0 ap_clk2x CLK" *) + input ap_clk2x, + (* X_INTERFACE_PARAMETER = "POLARITY ACTIVE_LOW" *) + input ap_rst_n, + + // Weight Stream + input [WEIGHT_STREAM_WIDTH_BA-1:0] in1_V_TDATA, + input in1_V_TVALID, + output in1_V_TREADY, + // Input Stream + input [INPUT_STREAM_WIDTH_BA-1:0] in0_V_TDATA, + input in0_V_TVALID, + output in0_V_TREADY, + // Output Stream + output [OUTPUT_STREAM_WIDTH_BA-1:0] out0_V_TDATA, + output out0_V_TVALID, + input out0_V_TREADY +); + +// NOTE: MW and MH are swapped -- FINN convention (MW=input features, MH=output features) +// is opposite to the MMU RTL convention. +mmu_axi #( + .GEMM_TYPE("$GEMM_TYPE$"), + .PE(PE), .SIMD(SIMD), .CU_SIMD(CU_SIMD), + .ACTIVATION_WIDTH(ACTIVATION_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACCU_WIDTH(ACCU_WIDTH), + .MW(MW), .MH(MH), .N_VECTORS(N_VECTORS), + .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .PUMPED_COMPUTE(PUMPED_COMPUTE), + .FORCE_BEHAVIOURAL(0) +) inst ( + .ap_clk(ap_clk), + .ap_clk2x(ap_clk2x), + .ap_rst_n(ap_rst_n), + .s_axis_weights_tdata(in1_V_TDATA), + .s_axis_weights_tvalid(in1_V_TVALID), + .s_axis_weights_tready(in1_V_TREADY), + .s_axis_input_tdata(in0_V_TDATA), + .s_axis_input_tvalid(in0_V_TVALID), + .s_axis_input_tready(in0_V_TREADY), + .m_axis_output_tdata(out0_V_TDATA), + .m_axis_output_tvalid(out0_V_TVALID), + .m_axis_output_tready(out0_V_TREADY) +); + +endmodule // $MODULE_NAME_AXI_WRAPPER$ diff --git a/finn-rtllib/mmu/q_writer.sv b/finn-rtllib/mmu/q_writer.sv new file mode 100644 index 0000000000..8d974a7f43 --- /dev/null +++ b/finn-rtllib/mmu/q_writer.sv @@ -0,0 +1,124 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module q_writer #( + int unsigned CU_SIMD, + int unsigned CLEN, + int unsigned ACTIVATION_WIDTH +)( + // Global Control + input logic clk, + input logic rst, + + // Input Stream + input logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] s_axis_tdata, + input logic s_axis_tlast, + input logic s_axis_tvalid, + output logic s_axis_tready, + + // Output Stream + output logic [CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] m_axis_tdata, + output logic m_axis_tlast, + output logic m_axis_tvalid, + input logic m_axis_tready +); + +// Params +// --------------------------------------------------------------------- + localparam integer CLEN_BITS = (CLEN == 1) ? 1 : $clog2(CLEN); + +// Skid +// --------------------------------------------------------------------- + logic axis_s0_tvalid, axis_s0_tready; + logic axis_s0_tlast; + logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] axis_s0_tdata; + + skid #(.DATA_WIDTH(CLEN*CU_SIMD*ACTIVATION_WIDTH + 1), .FEED_STAGES(1)) inst_reg ( + .clk(clk), .rst(rst), + .idat({s_axis_tlast, s_axis_tdata}), .ivld(s_axis_tvalid), .irdy(s_axis_tready), + .odat({axis_s0_tlast, axis_s0_tdata}), .ovld(axis_s0_tvalid), .ordy(axis_s0_tready) + ); + +// PtoS +// --------------------------------------------------------------------- + logic [CLEN_BITS-1:0] wr_ptr_C = '0, wr_ptr_N; + + logic axis_s1_tvalid, axis_s1_tready; + logic axis_s1_tlast; + logic [CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] axis_s1_tdata; + + // REG + always_ff @(posedge clk) begin + if(rst) begin + wr_ptr_C <= 0; + end else begin + wr_ptr_C <= wr_ptr_N; + end + end + + // DP + always_comb begin + wr_ptr_N = wr_ptr_C; + + axis_s0_tready = 1'b0; + axis_s1_tvalid = 1'b0; + axis_s1_tdata = axis_s0_tdata[wr_ptr_C]; + axis_s1_tlast = 1'b0; + + if(axis_s0_tvalid) begin + axis_s1_tvalid = 1'b1; + + if(axis_s1_tready) begin + if(wr_ptr_C == CLEN-1) begin + wr_ptr_N = 0; + axis_s0_tready = 1'b1; + axis_s1_tlast = axis_s0_tlast; + end else begin + wr_ptr_N = wr_ptr_C + 1; + end + end + end + end + +// Queue +// --------------------------------------------------------------------- + Q_srl #( + .depth(CLEN), + .width(CU_SIMD*ACTIVATION_WIDTH+1) + ) inst_queue ( + .clock(clk), .reset(rst), + .count(), .maxcount(), + .i_v(axis_s1_tvalid), .i_r(axis_s1_tready), .i_d({axis_s1_tlast, axis_s1_tdata}), + .o_v(m_axis_tvalid), .o_r(m_axis_tready), .o_d({m_axis_tlast, m_axis_tdata}) + ); + +endmodule \ No newline at end of file diff --git a/finn-rtllib/mmu/reorder_out.sv b/finn-rtllib/mmu/reorder_out.sv new file mode 100644 index 0000000000..19421937e4 --- /dev/null +++ b/finn-rtllib/mmu/reorder_out.sv @@ -0,0 +1,341 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module reorder_out #( + int unsigned W, + int unsigned XC, + int unsigned YC +)( + input logic clk, + input logic rst, + + input logic ivld, + output logic irdy, + input logic [W-1:0] idat, + + output logic ovld, + input logic ordy, + output logic [W-1:0] odat +); + +// ---------------------------------------------------------------------------- +// Consts and types +// ---------------------------------------------------------------------------- + +localparam int unsigned RAM_BITS = (W + 7)/8 * 8; +localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; +localparam int unsigned XYC = XC * YC; +localparam int unsigned XCNT_BITS = (XC == 1) ? 1 : $clog2(XC); +localparam int unsigned YCNT_BITS = (YC == 1) ? 1 : $clog2(YC); +localparam int unsigned XYCNT_BITS = (XYC == 1) ? 1 : $clog2(XYC); + +typedef enum logic[1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_t; +typedef enum logic {ST_RD_0, ST_RD_1} state_rd_t; + +// ---------------------------------------------------------------------------- +// Writer +// ---------------------------------------------------------------------------- + +// -- Regs +state_wr_t state_wr_C = ST_WR_0, state_wr_N; +state_rd_t state_rd_C = ST_RD_0, state_rd_N; + +logic [XCNT_BITS-1:0] curr_wrX_C = '0, curr_wrX_N; +logic [YCNT_BITS-1:0] curr_wrY_C = '0, curr_wrY_N; + +// -- Ram +logic [1:0][WGT_EN_BITS-1:0] a_we; // Bank enables +logic [1:0][XYCNT_BITS-1:0] a_addr; +logic [1:0][W-1:0] a_data_in; + +// -- Offsets +logic [XC-1:0][XYCNT_BITS-1:0] x_offsets; +for(genvar i = 0; i < XC; i++) begin + assign x_offsets[i] = i*YC; +end + +// -- IPC +logic done; + +// -- REG +always_ff @( posedge clk ) begin : REG_PROC_WR + if(rst) begin + state_wr_C <= ST_WR_0; + + curr_wrX_C <= 0; + curr_wrY_C <= 0; + end + else begin + state_wr_C <= state_wr_N; + + curr_wrX_C <= curr_wrX_N; + curr_wrY_C <= curr_wrY_N; + end +end + +// -- NSL +always_comb begin : NSL_PROC_WR + state_wr_N = state_wr_C; + + case (state_wr_C) + ST_WR_0: + if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld) begin + state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; + end + + ST_WR_0_WAIT: + state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; + + ST_WR_1: + if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld) begin + state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; + end + + ST_WR_1_WAIT: + state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; + + endcase +end + +// -- DP +always_comb begin : DP_PROC_WR + curr_wrX_N = curr_wrX_C; + curr_wrY_N = curr_wrY_C; + + // Input + irdy = 1'b0; + + // Buffer control + a_we = '0; + for(int i = 0; i < 2; i++) begin + a_addr[i] = x_offsets[curr_wrX_C] + curr_wrY_C; + a_data_in[i] = idat; + end + + // Write and count + case (state_wr_C) + ST_WR_0, ST_WR_1: begin + irdy = 1'b1; + + if(ivld) begin + if(state_wr_C == ST_WR_0) a_we[0] = '1; else a_we[1] = '1; + + curr_wrY_N = (curr_wrY_C == YC-1) ? 0 : curr_wrY_C + 1; + curr_wrX_N = (curr_wrY_C == YC-1) ? ((curr_wrX_C == XC-1) ? 0 : curr_wrX_C + 1) : curr_wrX_C; + end + end + endcase + +end + + +// ---------------------------------------------------------------------------- +// Reader +// ---------------------------------------------------------------------------- + +// -- Regs +logic [XCNT_BITS-1:0] curr_rdX_C = '0, curr_rdX_N; +logic [YCNT_BITS-1:0] curr_rdY_C = '0, curr_rdY_N; + +// -- Ram +logic [1:0] vld_s0_C = '0, vld_s0_N; +logic [1:0] vld_s1_C = '0, vld_s1_N; +logic vld_C = '0, vld_N; +logic [W-1:0] odat_C = '0, odat_N; + +logic [1:0][XYCNT_BITS-1:0] b_addr; +logic [1:0][W-1:0] odat_ram; + +// -- Cond +logic cond_go; + +// -- Oreg +logic [W-1:0] odat_int; +logic ovld_int; +logic ordy_int; + +// -- REG +always_ff @( posedge clk ) begin : REG_PROC_RD + if(rst) begin + state_rd_C <= ST_RD_0; + + curr_rdX_C <= 0; + curr_rdY_C <= 0; + + vld_s0_C <= 0; + vld_s1_C <= 0; + vld_C <= 0; + odat_C <= 0; + end + else begin + state_rd_C <= state_rd_N; + + curr_rdX_C <= curr_rdX_N; + curr_rdY_C <= curr_rdY_N; + + vld_s0_C <= vld_s0_N; + vld_s1_C <= vld_s1_N; + vld_C <= vld_N; + odat_C <= odat_N; + end +end + +// -- NSL +always_comb begin : NSL_PROC_RD + state_rd_N = state_rd_C; + + case (state_rd_C) + ST_RD_0: + if(ordy_int && ((state_wr_C == ST_WR_0) ? cond_go : 1'b1)) begin + if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1)) begin + state_rd_N = ST_RD_1; + end + end + + ST_RD_1: + if(ordy_int && ((state_wr_C == ST_WR_1) ? cond_go : 1'b1)) begin + if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1)) begin + state_rd_N = ST_RD_0; + end + end + + endcase +end + +// -- DP cond +always_comb begin + cond_go = 1'b0; + + if(curr_wrX_C > curr_rdX_C) begin + cond_go = 1'b1; + end + else if(curr_wrX_C == curr_rdX_C) begin + if(curr_wrY_C > curr_rdY_C) begin + cond_go = 1'b1; + end + end +end + +// -- DP +always_comb begin : DP_PROC_RD + curr_rdX_N = curr_rdX_C; + curr_rdY_N = curr_rdY_C; + + for(int i = 0; i < 2; i++) begin + vld_s0_N[i] = ordy_int ? 1'b0 : vld_s0_C[i]; + vld_s1_N[i] = ordy_int ? vld_s0_C[i] : vld_s1_C[i]; + end + + vld_N = ordy_int ? |vld_s1_C : vld_C; + odat_N = ordy_int ? (vld_s1_C[0] ? odat_ram[0] : odat_ram[1]) : odat_C; + + for(int i = 0; i < 2; i++) begin + b_addr[i] = x_offsets[curr_rdX_C] + curr_rdY_C; + end + + done = 1'b0; + + case(state_rd_C) + ST_RD_0: begin + if(ordy_int) begin + if((state_wr_C == ST_WR_0) ? cond_go : 1'b1) begin + vld_s0_N[0] = 1'b1; + + curr_rdX_N = (curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1; + curr_rdY_N = (curr_rdX_C == XC-1) ? ((curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1) : curr_rdY_C; + done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)); + end + end + end + + ST_RD_1: begin + if(ordy_int) begin + if((state_wr_C == ST_WR_1) ? cond_go : 1'b1) begin + vld_s0_N[1] = 1'b1; + + curr_rdX_N = (curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1; + curr_rdY_N = (curr_rdX_C == XC-1) ? ((curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1) : curr_rdY_C; + done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)); + end + end + end + + endcase + +end + +assign ovld_int = vld_C; +assign odat_int = odat_C; + +// ---------------------------------------------------------------------------- +// Matrix +// ---------------------------------------------------------------------------- + +for(genvar i = 0; i < 2; i++) begin + ram_p_c #( + .ADDR_BITS(XYCNT_BITS), + .DATA_BITS(RAM_BITS), + .RAM_STYLE("distributed") + ) inst_ram_tp_c ( + .clk(clk), + .a_en(1'b1), + .a_we(a_we[i]), + .a_addr(a_addr[i]), + .b_en(ordy_int), + .b_addr(b_addr[i]), + .a_data_in(a_data_in[i]), + .a_data_out(), + .b_data_out(odat_ram[i]) + ); +end + +// ---------------------------------------------------------------------------- +// Output +// ---------------------------------------------------------------------------- + +Q_srl #( + .depth(2), .width(W) +) inst_out_fifo ( + .clock(clk), + .reset(rst), + .count(), + .maxcount(), + .i_d(odat_int), + .i_v(ovld_int), + .i_r(ordy_int), + .o_d(odat), + .o_v(ovld), + .o_r(ordy) +); + + +endmodule \ No newline at end of file diff --git a/finn-rtllib/mmu/replay_buff_mmau.sv b/finn-rtllib/mmu/replay_buff_mmau.sv new file mode 100644 index 0000000000..ab99bdad5c --- /dev/null +++ b/finn-rtllib/mmu/replay_buff_mmau.sv @@ -0,0 +1,397 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module replay_buff_mmau #( + int unsigned W, + int unsigned XC, + int unsigned YC, + int unsigned N_REPS, + int unsigned IO_TILED = 0 +)( + input logic clk, + input logic rst, + + input logic ivld, + output logic irdy, + input logic [W-1:0] idat, + + output logic ovld, + input logic ordy, + output logic [W-1:0] odat, + output logic olast +); + +// ---------------------------------------------------------------------------- +// Consts and types +// ---------------------------------------------------------------------------- + +localparam int unsigned RAM_BITS = (W + 7)/8 * 8; +localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; +localparam int unsigned XYC = XC * YC; +localparam int unsigned XCNT_BITS = (XC == 1) ? 1 : $clog2(XC); +localparam int unsigned YCNT_BITS = (YC == 1) ? 1 : $clog2(YC); +localparam int unsigned XYCNT_BITS = (XYC == 1) ? 1 : $clog2(XYC); +localparam int unsigned REPS_BITS = (N_REPS == 1) ? 1 : $clog2(N_REPS); + +typedef enum logic[1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_t; +typedef enum logic {ST_RD_0, ST_RD_1} state_rd_t; + +// ---------------------------------------------------------------------------- +// Ireg +// ---------------------------------------------------------------------------- +logic [W-1:0] idat_int; +logic ivld_int; +logic irdy_int; + +skid #(.DATA_WIDTH(W), .FEED_STAGES(1)) isnt_input_reg ( + .clk(clk), .rst(rst), + .ivld(ivld), .irdy(irdy), .idat(idat), + .ovld(ivld_int), .ordy(irdy_int), .odat(idat_int) +); + +// ---------------------------------------------------------------------------- +// Writer +// ---------------------------------------------------------------------------- + +// -- Regs +state_wr_t state_wr_C = ST_WR_0, state_wr_N; +state_rd_t state_rd_C = ST_RD_0, state_rd_N; + +logic [XCNT_BITS-1:0] curr_wrX_C = '0, curr_wrX_N; +logic [YCNT_BITS-1:0] curr_wrY_C = '0, curr_wrY_N; + +// -- Ram +logic [1:0][WGT_EN_BITS-1:0] a_we; // Bank enables +logic [1:0][XYCNT_BITS-1:0] a_addr; +logic [1:0][W-1:0] a_data_in; + +// -- Offsets +logic [XC-1:0][XYCNT_BITS-1:0] x_offsets; +for(genvar i = 0; i < XC; i++) begin + assign x_offsets[i] = i*YC; +end + +// -- IPC +logic done; + +// -- REG +always_ff @( posedge clk ) begin : REG_PROC_WR + if(rst) begin + state_wr_C <= ST_WR_0; + + curr_wrX_C <= 0; + curr_wrY_C <= 0; + end + else begin + state_wr_C <= state_wr_N; + + curr_wrX_C <= curr_wrX_N; + curr_wrY_C <= curr_wrY_N; + end +end + +// -- NSL +always_comb begin : NSL_PROC_WR + state_wr_N = state_wr_C; + + case (state_wr_C) + ST_WR_0: + if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld_int) begin + state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; + end + + ST_WR_0_WAIT: + state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; + + ST_WR_1: + if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld_int) begin + state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; + end + + ST_WR_1_WAIT: + state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; + + endcase +end + +// -- DP +always_comb begin : DP_PROC_WR + curr_wrX_N = curr_wrX_C; + curr_wrY_N = curr_wrY_C; + + // Input + irdy_int = 1'b0; + + // Buffer control + a_we = '0; + for(int i = 0; i < 2; i++) begin + a_addr[i] = x_offsets[curr_wrX_C] + curr_wrY_C; + a_data_in[i] = idat_int; + end + + // Write and count + case (state_wr_C) + ST_WR_0, ST_WR_1: begin + irdy_int = 1'b1; + + if(ivld_int) begin + if(state_wr_C == ST_WR_0) a_we[0] = '1; else a_we[1] = '1; + + if(IO_TILED == 1) begin + curr_wrY_N = (curr_wrY_C == YC-1) ? 0 : curr_wrY_C + 1; + curr_wrX_N = (curr_wrY_C == YC-1) ? ((curr_wrX_C == XC-1) ? 0 : curr_wrX_C + 1) : curr_wrX_C; + end else begin + curr_wrX_N = (curr_wrX_C == XC-1) ? 0 : curr_wrX_C + 1; + curr_wrY_N = (curr_wrX_C == XC-1) ? ((curr_wrY_C == YC-1) ? 0 : curr_wrY_C + 1) : curr_wrY_C; + end + end + end + endcase + +end + +// ---------------------------------------------------------------------------- +// Reader +// ---------------------------------------------------------------------------- + +// -- Regs +logic [XCNT_BITS-1:0] curr_rdX_C = '0, curr_rdX_N; +logic [YCNT_BITS-1:0] curr_rdY_C = '0, curr_rdY_N; +logic [REPS_BITS-1:0] curr_reps_C = '0, curr_reps_N; + +// -- Ram +logic [1:0] vld_s0_C = '0, vld_s0_N; +logic [1:0] vld_s1_C = '0, vld_s1_N; +logic vld_C = '0, vld_N; +logic last_s0_C = '0, last_s0_N; +logic last_s1_C = '0, last_s1_N; +logic last_C = '0, last_N; +logic [W-1:0] odat_C = '0, odat_N; + +logic [1:0][XYCNT_BITS-1:0] b_addr; +logic [1:0][W-1:0] odat_ram; + +// -- Cond +logic cond_go; + +// -- Oreg +logic [W-1:0] odat_int; +logic ovld_int; +logic ordy_int; +logic olast_int; + +// -- REG +always_ff @( posedge clk ) begin : REG_PROC_RD + if(rst) begin + state_rd_C <= ST_RD_0; + + curr_rdX_C <= 0; + curr_rdY_C <= 0; + curr_reps_C <= 0; + + vld_s0_C <= 0; + vld_s1_C <= 0; + vld_C <= 0; + odat_C <= 0; + last_s0_C <= 0; + last_s1_C <= 0; + last_C <= 0; + end + else begin + state_rd_C <= state_rd_N; + + curr_rdX_C <= curr_rdX_N; + curr_rdY_C <= curr_rdY_N; + curr_reps_C <= curr_reps_N; + + vld_s0_C <= vld_s0_N; + vld_s1_C <= vld_s1_N; + vld_C <= vld_N; + odat_C <= odat_N; + last_s0_C <= last_s0_N; + last_s1_C <= last_s1_N; + last_C <= last_N; + end +end + +// -- NSL +always_comb begin : NSL_PROC_RD + state_rd_N = state_rd_C; + + case (state_rd_C) + ST_RD_0: + if(ordy_int && ((state_wr_C == ST_WR_0) ? cond_go : 1'b1)) begin + if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1) && (curr_reps_C == N_REPS-1)) begin + state_rd_N = ST_RD_1; + end + end + + ST_RD_1: + if(ordy_int && ((state_wr_C == ST_WR_1) ? cond_go : 1'b1)) begin + if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1) && (curr_reps_C == N_REPS-1)) begin + state_rd_N = ST_RD_0; + end + end + + endcase +end + +// -- DP cond +always_comb begin + cond_go = 1'b0; + + if(IO_TILED) begin + if(curr_wrX_C > curr_rdX_C) begin + cond_go = 1'b1; + end + else if(curr_wrX_C == curr_rdX_C) begin + if(curr_wrY_C > curr_rdY_C) begin + cond_go = 1'b1; + end + end + end else begin + if(curr_wrY_C > curr_rdY_C) begin + cond_go = 1'b1; + end + else if(curr_wrY_C == curr_rdY_C) begin + if(curr_wrX_C > curr_rdX_C) begin + cond_go = 1'b1; + end + end + end +end + +// -- DP +always_comb begin : DP_PROC_RD + curr_rdX_N = curr_rdX_C; + curr_rdY_N = curr_rdY_C; + curr_reps_N = curr_reps_C; + + for(int i = 0; i < 2; i++) begin + vld_s0_N[i] = ordy_int ? 1'b0 : vld_s0_C[i]; + vld_s1_N[i] = ordy_int ? vld_s0_C[i] : vld_s1_C[i]; + end + + vld_N = ordy_int ? |vld_s1_C : vld_C; + odat_N = ordy_int ? (vld_s1_C[0] ? odat_ram[0] : odat_ram[1]) : odat_C; + + last_s0_N = ordy_int ? 1'b0 : last_s0_C; + last_s1_N = ordy_int ? last_s0_C : last_s1_C; + last_N = ordy_int ? last_s1_C : last_C; + + for(int i = 0; i < 2; i++) begin + b_addr[i] = x_offsets[curr_rdX_C] + curr_rdY_C; + end + + done = 1'b0; + + case(state_rd_C) + ST_RD_0: begin + if(ordy_int) begin + if((state_wr_C == ST_WR_0) ? cond_go : 1'b1) begin + vld_s0_N[0] = 1'b1; + + last_s0_N = (curr_rdX_C == XC-1); + + curr_rdY_N = (curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1; + curr_rdX_N = (curr_rdY_C == YC-1) ? ((curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1) : curr_rdX_C; + curr_reps_N = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)) ? ((curr_reps_C == N_REPS-1) ? 0 : curr_reps_C + 1) : curr_reps_C; + done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1) && (curr_reps_C == N_REPS-1)); + end + end + end + + ST_RD_1: begin + if(ordy_int) begin + if((state_wr_C == ST_WR_1) ? cond_go : 1'b1) begin + vld_s0_N[1] = 1'b1; + + last_s0_N = (curr_rdX_C == XC-1); + + curr_rdY_N = (curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1; + curr_rdX_N = (curr_rdY_C == YC-1) ? ((curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1) : curr_rdX_C; + curr_reps_N = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)) ? ((curr_reps_C == N_REPS-1) ? 0 : curr_reps_C + 1) : curr_reps_C; + done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1) && (curr_reps_C == N_REPS-1)); + end + end + end + + endcase + +end + +assign ovld_int = vld_C; +assign odat_int = odat_C; +assign olast_int = last_C; + +// ---------------------------------------------------------------------------- +// BRAM +// ---------------------------------------------------------------------------- + +for(genvar i = 0; i < 2; i++) begin + ram_p_c #( + .ADDR_BITS(XYCNT_BITS), + .DATA_BITS(RAM_BITS), + .RAM_STYLE("distributed") + ) inst_ram_tp_c ( + .clk(clk), + .a_en(1'b1), + .a_we(a_we[i]), + .a_addr(a_addr[i]), + .b_en(ordy_int), + .b_addr(b_addr[i]), + .a_data_in(a_data_in[i]), + .a_data_out(), + .b_data_out(odat_ram[i]) + ); +end + +// ---------------------------------------------------------------------------- +// Output +// ---------------------------------------------------------------------------- + +Q_srl #( + .depth(2), .width(1+W) +) inst_out_fifo ( + .clock(clk), + .reset(rst), + .count(), + .maxcount(), + .i_d({olast_int, odat_int}), + .i_v(ovld_int), + .i_r(ordy_int), + .o_d({olast, odat}), + .o_v(ovld), + .o_r(ordy) +); + +endmodule \ No newline at end of file diff --git a/finn-rtllib/mmu/sched_activations.sv b/finn-rtllib/mmu/sched_activations.sv new file mode 100644 index 0000000000..bbb989ff4e --- /dev/null +++ b/finn-rtllib/mmu/sched_activations.sv @@ -0,0 +1,199 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR + * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF + * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + *****************************************************************************/ + +module sched_activations #( + int unsigned CU_SIMD, + int unsigned CLEN, + int unsigned ACTIVATION_WIDTH, + + int unsigned N_BEATS_OP, + int unsigned N_BEATS_EP, + + int unsigned N_DCPL_STAGES = 2 +)( + // Global Control + input logic clk, + input logic rst, + + // Input Stream + input logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] s_axis_tdata, + input logic s_axis_tlast, + input logic s_axis_tvalid, + output logic s_axis_tready, + + // Output Stream + output logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] m_axis_tdata, + output logic [CLEN-1:0] m_axis_tlast, + output logic m_axis_tvalid, + input logic m_axis_tready +); + +// Params +// --------------------------------------------------------------------- + localparam integer CLEN_BITS = (CLEN == 1) ? 1 : $clog2(CLEN); + localparam integer CNT_EPLG_BITS = (N_BEATS_OP > N_BEATS_EP) ? + (N_BEATS_OP == 1) ? 1 : $clog2(N_BEATS_OP) : + (N_BEATS_EP == 1) ? 1 : $clog2(N_BEATS_EP); + +// Shifting +// --------------------------------------------------------------------- + logic [CLEN_BITS-1:0] wr_ptr_C = '0, wr_ptr_N; + + logic [CLEN-1:0] valid_C = '0, valid_N; + logic eplg_C = '0, eplg_N; + logic [CNT_EPLG_BITS-1:0] cnt_eplg_C = '0, cnt_eplg_N; + + logic [CLEN-1:0] q_in_tvalid, q_in_tready; + logic [CLEN-1:0] q_in_tlast; + logic [CLEN-1:0][CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] q_in_tdata; + + logic [CLEN-1:0] q_out_tvalid, q_out_tready; + logic [CLEN-1:0] q_out_tlast; + logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] q_out_tdata; + + logic ovld, ordy; + logic [CLEN-1:0][CU_SIMD-1:0][ACTIVATION_WIDTH-1:0] odat; + logic [CLEN-1:0] olast; + + + for(genvar i = 0; i < CLEN; i++) begin + q_writer #( + .CLEN(CLEN), .CU_SIMD(CU_SIMD), .ACTIVATION_WIDTH(ACTIVATION_WIDTH) + ) inst_queue_writer ( + .clk(clk), .rst(rst), + .s_axis_tvalid(q_in_tvalid[i]), .s_axis_tready(q_in_tready[i]), .s_axis_tdata(q_in_tdata[i]), .s_axis_tlast(q_in_tlast[i]), + .m_axis_tvalid(q_out_tvalid[i]), .m_axis_tready(q_out_tready[i]), .m_axis_tdata(q_out_tdata[i]), .m_axis_tlast(q_out_tlast[i]) + ); + end + + // REG + always_ff @(posedge clk) begin + if(rst) begin + wr_ptr_C <= '0; + valid_C[0] <= 1'b1; + valid_C[CLEN-1:1] <= '0; + eplg_C <= '0; + cnt_eplg_C <= '0; + end else begin + wr_ptr_C <= wr_ptr_N; + valid_C <= valid_N; + eplg_C <= eplg_N; + cnt_eplg_C <= cnt_eplg_N; + end + end + + // DP + always_comb begin + // Read + valid_N = valid_C; + eplg_N = eplg_C; + cnt_eplg_N = cnt_eplg_C; + + q_out_tready = '0; + + // Read + if (ovld && ordy) begin + // Read from queue + for(int i = 0; i < CLEN; i++) begin + q_out_tready[i] = valid_C[i]; + end + + // Shift ctrl + valid_N[CLEN-1:1] = valid_C[CLEN-2:0]; + if(eplg_C) begin + if(cnt_eplg_C == N_BEATS_EP-1) begin + eplg_N = 1'b0; + cnt_eplg_N = 0; + valid_N[0] = 1'b1; + end else begin + cnt_eplg_N = cnt_eplg_C + 1; + valid_N[0] = 1'b0; + end + end else begin + if(cnt_eplg_C == N_BEATS_OP-1) begin + eplg_N = 1'b1; + cnt_eplg_N = 0; + valid_N[0] = 1'b0; + end else begin + cnt_eplg_N = cnt_eplg_C + 1; + valid_N[0] = 1'b1; + end + end + end + + // Write + wr_ptr_N = wr_ptr_C; + + s_axis_tready = 1'b0; + q_in_tvalid = '0; + for(int i = 0; i < CLEN; i++) begin + q_in_tdata[i] = s_axis_tdata; + q_in_tlast[i] = s_axis_tlast; + end + + if(s_axis_tvalid) begin + q_in_tvalid[wr_ptr_C] = 1'b1; + + if(q_in_tready[wr_ptr_C]) begin + s_axis_tready = 1'b1; + wr_ptr_N = (wr_ptr_C == CLEN-1) ? 0 : wr_ptr_C + 1; + end + end + + end + + // Output valid + always_comb begin + ovld = 1'b1; + + for(int i = 0; i < CLEN; i++) begin + if((valid_C[i] & q_out_tvalid[i]) != valid_C[i]) begin + ovld = 1'b0; + end + end + end + + for(genvar i = 0; i < CLEN; i++) begin + assign odat[i] = valid_C[i] ? q_out_tdata[i] : '0; + assign olast[i] = valid_C[i] ? q_out_tlast[i] : '0; + end + +// Oreg +// --------------------------------------------------------------------- + skid #(.DATA_WIDTH(CLEN*(CU_SIMD*ACTIVATION_WIDTH + 1)), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg ( + .clk(clk), .rst(rst), + .idat({olast, odat}), .ivld(ovld), .irdy(ordy), + .odat({m_axis_tlast, m_axis_tdata}), .ovld(m_axis_tvalid), .ordy(m_axis_tready) + ); + + +endmodule \ No newline at end of file diff --git a/src/finn/custom_op/fpgadataflow/hwcustomop.py b/src/finn/custom_op/fpgadataflow/hwcustomop.py index efd252207b..6fb7abb709 100644 --- a/src/finn/custom_op/fpgadataflow/hwcustomop.py +++ b/src/finn/custom_op/fpgadataflow/hwcustomop.py @@ -359,8 +359,14 @@ def generate_hdl_fetch_weights(self): mh = self.get_nodeattr("MH") pe = self.get_nodeattr("PE") simd = self.get_nodeattr("SIMD") - n_reps = np.prod(self.get_nodeattr("numInputVectors")) theight = self.get_nodeattr("TH") + gemm_type = self.get_nodeattr("gemm_type") + if gemm_type in ("mmau_1d", "mmau_2d"): + cu_simd = self.get_nodeattr("CU_SIMD") + clen = (simd + cu_simd - 1) // cu_simd + n_reps = np.prod(self.get_nodeattr("numInputVectors")) // clen + else: + n_reps = np.prod(self.get_nodeattr("numInputVectors")) // theight en_mlo = "EN_MLO" if self.get_nodeattr("mlo_max_iter") else "NO_MLO" else: # Eltwise layers only have one parallelism parameter @@ -377,6 +383,23 @@ def generate_hdl_fetch_weights(self): n_max_layers = 64 code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + # Compute IWSIMD and WSIMD for the fetch_weights wrapper + if self.onnx_node.op_type in ops: + gemm_type = self.get_nodeattr("gemm_type") + if gemm_type in ("mmau_1d", "mmau_2d"): + cu_simd = self.get_nodeattr("CU_SIMD") + iwsimd = pe * cu_simd + wsimd = pe * cu_simd + elif theight > 1: + iwsimd = (pe * simd) // theight + wsimd = (pe * simd) // theight + else: + iwsimd = simd + wsimd = (pe * simd) // theight + else: + iwsimd = simd + wsimd = (pe * simd) // theight + code_gen_dict = { "$MODULE_NAME_AXI_WRAPPER$": [mname + "_fetch_weights_wrapper"], "$MW$": [str(mw)], @@ -388,6 +411,8 @@ def generate_hdl_fetch_weights(self): "$LAYER_OFFS$": [str(layer_offs)], "$N_LAYERS$": [str(n_max_layers)], "$TH$": [str(theight)], + "$IWSIMD$": [str(iwsimd)], + "$WSIMD$": [str(wsimd)], "$EN_MLO$": [en_mlo], "$DWC_MODULE_NAME$": [mname + "_dwc"], } diff --git a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py index 1534cfb09e..f3bb6d402a 100644 --- a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py +++ b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py @@ -124,15 +124,16 @@ def get_nodeattr_types(self): # weight data from the weight FIFOs. "runtime_writeable_weights": ("i", False, 0, {0, 1}), "pumpedMemory": ("i", False, 0, {0, 1}), - # Matrix unit activation type - "mua_type": ( + # tiling + "TH": ("i", False, 1), + # MMU parameters + "gemm_type": ( "s", False, - "vector", - {"mv", "mvt", "mm"}, + "mvau", + {"mvau", "mvau_tiled", "mmau_1d", "mmau_2d"}, ), - # tiling - "TH": ("i", False, 1), + "CU_SIMD": ("i", False, 3), } my_attrs.update(super().get_nodeattr_types()) return my_attrs @@ -273,12 +274,17 @@ def get_instream_width(self, ind=0): # TODO: Hacky, need to clean these calls . wp = self.get_input_datatype(1).bitwidth() mem_mode = self.get_nodeattr("mem_mode") theight = self.get_nodeattr("TH") + gemm_type = self.get_nodeattr("gemm_type") match mem_mode: case "dynamic": width = pe * wp case "external" | "external_mem" | "internal_decoupled": - width = ((pe * simd) * wp) // theight + if gemm_type in ("mmau_1d", "mmau_2d"): + cu_simd = self.get_nodeattr("CU_SIMD") + width = pe * cu_simd * wp + else: + width = ((pe * simd) * wp) // theight case _: width = 0 elif ind == 2: @@ -311,6 +317,7 @@ def get_folded_input_shape(self, ind=0): vecs = list(self.get_nodeattr("numInputVectors")) n_vecs = int(np.prod(vecs)) theight = self.get_nodeattr("TH") + gemm_type = self.get_nodeattr("gemm_type") if ind == 0: # calculate shape of input 0 @@ -320,7 +327,11 @@ def get_folded_input_shape(self, ind=0): case "dynamic": folded_input_shape = (1, mw, nf, pe) case "external" | "external_mem" | "internal_decoupled": - folded_input_shape = (n_vecs, sf * nf, (simd * pe) // theight) + if gemm_type in ("mmau_1d", "mmau_2d"): + cu_simd = self.get_nodeattr("CU_SIMD") + folded_input_shape = (n_vecs, sf * nf, pe * cu_simd) + else: + folded_input_shape = (n_vecs, sf * nf, (simd * pe) // theight) case _: raise Exception("Undefined input shape for requested input") else: @@ -467,12 +478,21 @@ def uram_efficiency_estimation(self): def get_exp_cycles(self): pe = self.get_nodeattr("PE") simd = self.get_nodeattr("SIMD") + th = self.get_nodeattr("TH") num_inp_vec = self.get_nodeattr("numInputVectors") mh = self.get_nodeattr("MH") mw = self.get_nodeattr("MW") # since mmv != 1 is not supported yet, we set mmv for now to 1 mmv = 1 - exp_cycles = (mh / pe) * (mw / simd) * np.prod(num_inp_vec) / mmv + # Tiling/systolic reduces throughput + gemm_type = self.get_nodeattr("gemm_type") + if gemm_type in ("mmau_1d", "mmau_2d"): + cu_simd = self.get_nodeattr("CU_SIMD") + clen = (simd + cu_simd - 1) // cu_simd + exp_cycles = (mh / pe) * (mw / simd) * np.prod(num_inp_vec) * clen / mmv + else: + # TH>1 (tiling) reduces throughput by factor TH (tinner = PE*SIMD/TH) + exp_cycles = (mh / pe) * (mw / simd) * np.prod(num_inp_vec) * th / mmv return int(exp_cycles) def minimize_accumulator_width(self, model): @@ -736,7 +756,11 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name): weight_tensor_pe_flipped = weight_tensor_pe_flipped.reshape(1, -1, pe * simd) weight_tensor_pe_flipped = weight_tensor_pe_flipped.copy() # tiling - tinner = (pe * simd) // self.get_nodeattr("TH") + gemm_type = self.get_nodeattr("gemm_type") + if gemm_type in ("mmau_1d", "mmau_2d"): + tinner = pe * self.get_nodeattr("CU_SIMD") + else: + tinner = (pe * simd) // self.get_nodeattr("TH") weight_tensor_simd_flipped = weight_tensor_simd_flipped.reshape(1, -1, tinner) weight_tensor_pe_flipped = weight_tensor_pe_flipped.reshape(1, -1, tinner) if weight_file_mode == "decoupled_npy": @@ -1087,12 +1111,14 @@ def code_generation_ipi(self): ] # Create Vivado axis_dwidth_converter IP theight = self.get_nodeattr("TH") + gemm_type = self.get_nodeattr("gemm_type") wdt = self.get_input_datatype(1) - iwsimd = ( - (self.get_nodeattr("PE") * self.get_nodeattr("SIMD")) // theight - if theight > 1 - else self.get_nodeattr("SIMD") - ) + if gemm_type in ("mmau_1d", "mmau_2d"): + iwsimd = self.get_nodeattr("PE") * self.get_nodeattr("CU_SIMD") + elif theight > 1: + iwsimd = (self.get_nodeattr("PE") * self.get_nodeattr("SIMD")) // theight + else: + iwsimd = self.get_nodeattr("SIMD") ds_bits_ba = ((iwsimd * wdt.bitwidth() + 7) // 8) * 8 dwc_ip_name = node_name + "_dwc" s_bytes = 256 // 8 diff --git a/src/finn/custom_op/fpgadataflow/rtl/finn_loop.py b/src/finn/custom_op/fpgadataflow/rtl/finn_loop.py index c0c381a042..7e49aef1a6 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/finn_loop.py +++ b/src/finn/custom_op/fpgadataflow/rtl/finn_loop.py @@ -404,6 +404,11 @@ def generate_params(self, model, path): ): # rename so it doesn't get overwritten shutil.move(param_file, new_param_file) + # also rename simd-flipped npy for external_mem MVAU nodes + npy_file = "{}/input_1.npy".format(path) + if os.path.isfile(npy_file): + new_npy_file = "{}/{}_input1_{}.npy".format(path, param_node.op_type, iter) + shutil.move(npy_file, new_npy_file) elif param_node.op_type.startswith("Thresholding"): # get all generated Thresholding dat files pe = inst.get_nodeattr("PE") @@ -444,6 +449,17 @@ def generate_params(self, model, path): for line in infile: outfile.write(line) os.remove(memblock_file) + # concatenate all .npy files together (simd-flipped for AXI-MM sim) + npy_parts = [] + for iter in range(iteration): + npy_file = "{}/{}_input1_{}.npy".format(path, param_node.op_type, iter) + if os.path.isfile(npy_file): + npy_parts.append(np.load(npy_file)) + os.remove(npy_file) + if npy_parts: + combined_npy = np.concatenate(npy_parts, axis=1) + npy_out = "{}/input1_{}_id_{}.npy".format(path, param_node.op_type, i + 1) + np.save(npy_out, combined_npy) # Replace the path for the dat files in the ipgen files if Eltwise # Adapted from transformations.fpgadataflow.replace_verilog_relpaths if param_node.op_type.startswith("Elementwise"): @@ -563,6 +579,40 @@ def ipgen_singlenode_code(self, fpgapart=None): vivado_stitch_proj_dir = self.get_nodeattr("code_gen_dir_ipgen") cmd = [] + + # Create Vivado axis_dwidth_converter IPs for intermediate_frames DWCs + olen_bits = self.get_outstream_width(0) + ilen_bits = self.get_instream_width(0) + data_bits = 256 + # DWC write path: body output width -> DMA width (256) + dwc_sink_s_bytes = (olen_bits + 7) // 8 + dwc_sink_m_bytes = data_bits // 8 + cmd += [ + "create_ip -name axis_dwidth_converter -vendor xilinx.com " + "-library ip -version 1.1 -module_name if_dwc_sink", + "set_property -dict [list " + "CONFIG.S_TDATA_NUM_BYTES {%d} " + "CONFIG.M_TDATA_NUM_BYTES {%d} " + "CONFIG.HAS_TLAST {1} " + "CONFIG.HAS_TKEEP {1} " + "] [get_ips if_dwc_sink]" % (dwc_sink_s_bytes, dwc_sink_m_bytes), + "generate_target all [get_ips if_dwc_sink]", + ] + # DWC read path: DMA width (256) -> body input width + dwc_source_s_bytes = data_bits // 8 + dwc_source_m_bytes = (ilen_bits + 7) // 8 + cmd += [ + "create_ip -name axis_dwidth_converter -vendor xilinx.com " + "-library ip -version 1.1 -module_name if_dwc_source", + "set_property -dict [list " + "CONFIG.S_TDATA_NUM_BYTES {%d} " + "CONFIG.M_TDATA_NUM_BYTES {%d} " + "CONFIG.HAS_TLAST {1} " + "CONFIG.HAS_TKEEP {1} " + "] [get_ips if_dwc_source]" % (dwc_source_s_bytes, dwc_source_m_bytes), + "generate_target all [get_ips if_dwc_source]", + ] + # add all the generated IP dirs to ip_repo_paths ip_dirs = ["list"] # add RTL streamer IP diff --git a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py index f3aa269789..1b5d95913a 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py @@ -108,7 +108,14 @@ def execute_node(self, context, graph): wei = npy_to_rtlsim_input("{}/input_1.npy".format(code_gen_dir), export_wdt, wnbits) num_w_reps = np.prod(self.get_nodeattr("numInputVectors")) - num_w_reps = num_w_reps // self.get_nodeattr("TH") + gemm_type = self.get_nodeattr("gemm_type") + if gemm_type in ("mmau_1d", "mmau_2d"): + simd = self.get_nodeattr("SIMD") + cu_simd = self.get_nodeattr("CU_SIMD") + clen = (simd + cu_simd - 1) // cu_simd + num_w_reps = num_w_reps // clen + else: + num_w_reps = num_w_reps // self.get_nodeattr("TH") io_dict = { "inputs": {"in0": inp, "in1": wei * num_w_reps}, @@ -161,8 +168,36 @@ def instantiate_ip(self, cmd): node_name = self.onnx_node.name code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + gemm_type = self.get_nodeattr("gemm_type") theight = self.get_nodeattr("TH") - if theight > 1: + + if gemm_type in ("mmau_1d", "mmau_2d"): + rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mmu/") + sourcefiles = [ + "../fifo/hdl/Q_srl.v", + "../skid/skid.sv", + "../ram/ram_p_c.sv", + "mmu_axi.sv", + "replay_buff_mmau.sv", + "sched_activations.sv", + "en_global.sv", + "q_writer.sv", + "reorder_out.sv", + ] + if gemm_type == "mmau_1d": + sourcefiles += [ + "1d/cu_mmau_1d.sv", + "1d/sched_weights_1d.sv", + "1d/collect_out_1d.sv", + "1d/sft_reg.sv", + ] + else: + sourcefiles += [ + "2d/cu_mmau_2d.sv", + "2d/sched_weights_2d.sv", + "2d/collect_out_2d.sv", + ] + elif theight > 1: rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu_tiled/") sourcefiles = [ "../fifo/hdl/Q_srl.v", @@ -329,7 +364,12 @@ def generate_hdl(self, model, fpgapart, clk): self.set_nodeattr("ip_path", code_gen_dir) def prepare_codegen_default(self, fpgapart, clk): - if self.get_nodeattr("TH") > 1: + gemm_type = self.get_nodeattr("gemm_type") + if gemm_type in ("mmau_1d", "mmau_2d"): + template_path = ( + os.environ["FINN_ROOT"] + "/finn-rtllib/mmu/mmu_axi_wrapper.v" + ) + elif self.get_nodeattr("TH") > 1: template_path = ( os.environ["FINN_ROOT"] + "/finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v" ) @@ -362,12 +402,22 @@ def prepare_codegen_default(self, fpgapart, clk): ) code_gen_dict["$SEGMENTLEN$"] = [str(self._resolve_segment_len(clk))] + # MMU-specific template variables + if gemm_type in ("mmau_1d", "mmau_2d"): + code_gen_dict["$GEMM_TYPE$"] = [gemm_type] + code_gen_dict["$CU_SIMD$"] = [str(self.get_nodeattr("CU_SIMD"))] + n_vectors = int(np.prod(self.get_nodeattr("numInputVectors"))) + code_gen_dict["$N_VECTORS$"] = [str(n_vectors)] + return template_path, code_gen_dict def get_rtl_file_list(self, abspath=False): + gemm_type = self.get_nodeattr("gemm_type") if abspath: code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + "/" - if self.get_nodeattr("TH") > 1: + if gemm_type in ("mmau_1d", "mmau_2d"): + rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mmu/") + elif self.get_nodeattr("TH") > 1: rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu_tiled/") else: rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu/") @@ -375,7 +425,35 @@ def get_rtl_file_list(self, abspath=False): code_gen_dir = "" rtllib_dir = "" - if self.get_nodeattr("TH") > 1: + if gemm_type in ("mmau_1d", "mmau_2d"): + verilog_files = [ + "../fifo/hdl/Q_srl.v", + "../skid/skid.sv", + "../ram/ram_p_c.sv", + "mmu_axi.sv", + "replay_buff_mmau.sv", + "sched_activations.sv", + "en_global.sv", + "q_writer.sv", + "reorder_out.sv", + ] + if gemm_type == "mmau_1d": + verilog_files += [ + "1d/cu_mmau_1d.sv", + "1d/sched_weights_1d.sv", + "1d/collect_out_1d.sv", + "1d/sft_reg.sv", + ] + else: + verilog_files += [ + "2d/cu_mmau_2d.sv", + "2d/sched_weights_2d.sv", + "2d/collect_out_2d.sv", + ] + verilog_files = [ + os.path.join(code_gen_dir, self.get_nodeattr("gen_top_module") + "_wrapper.v") + ] + [rtllib_dir + _ for _ in verilog_files] + elif self.get_nodeattr("TH") > 1: verilog_files = [ "../fifo/hdl/Q_srl.v", "../skid/skid.sv", @@ -407,8 +485,15 @@ def get_rtl_file_list(self, abspath=False): return verilog_files def get_verilog_paths(self): + gemm_type = self.get_nodeattr("gemm_type") verilog_paths = super().get_verilog_paths() - if self.get_nodeattr("TH") > 1: + if gemm_type in ("mmau_1d", "mmau_2d"): + verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mmu") + if gemm_type == "mmau_1d": + verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mmu/1d") + else: + verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mmu/2d") + elif self.get_nodeattr("TH") > 1: verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mvu_tiled") else: verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mvu") diff --git a/src/finn/custom_op/fpgadataflow/templates.py b/src/finn/custom_op/fpgadataflow/templates.py index 6ce0cac42c..18b3d708a1 100644 --- a/src/finn/custom_op/fpgadataflow/templates.py +++ b/src/finn/custom_op/fpgadataflow/templates.py @@ -343,9 +343,6 @@ add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/cdma/cdma_x/cdma_x.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/cdma/cdma_x/cdma_x_rd.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/cdma/cdma_x/cdma_x_wr.sv" -add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/dwc/hdl/axis_adapter.v" -add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/dwc/hdl/axis_fifo.v" -add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/dwc/hdl/axis_fifo_adapter.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/skid/skid.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/ram/ram_p_c.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/mlo/infrastructure/intermediate_frames.sv" diff --git a/src/finn/transformation/fpgadataflow/loop_rolling.py b/src/finn/transformation/fpgadataflow/loop_rolling.py index 62287ccc2c..4782a74aed 100644 --- a/src/finn/transformation/fpgadataflow/loop_rolling.py +++ b/src/finn/transformation/fpgadataflow/loop_rolling.py @@ -254,12 +254,12 @@ def apply(self, model: ModelWrapper) -> Tuple[ModelWrapper, bool]: print("error: could not find metadata for node") exit(1) - node.metadata_props["pkg.torch.onnx.name_scopes"] = mnode.metadata_props[ - "pkg.torch.onnx.name_scopes" - ] - node.metadata_props["pkg.torch.onnx.class_hierarchy"] = mnode.metadata_props[ - "pkg.torch.onnx.class_hierarchy" - ] + node.metadata_props["pkg.torch.onnx.name_scopes"] = mnode.metadata_props.get( + "pkg.torch.onnx.name_scopes", "" + ) + node.metadata_props["pkg.torch.onnx.class_hierarchy"] = mnode.metadata_props.get( + "pkg.torch.onnx.class_hierarchy", "" + ) assert P.add_node(node) graph.sort() diff --git a/src/finn/util/mlo_sim.py b/src/finn/util/mlo_sim.py index 9f906767f3..65f7c35d97 100644 --- a/src/finn/util/mlo_sim.py +++ b/src/finn/util/mlo_sim.py @@ -31,6 +31,7 @@ # aximm simulation tasks for handling the aximm interfaces. import numpy as np +import os from qonnx.core.modelwrapper import ModelWrapper from qonnx.custom_op.registry import getCustomOp from typing import Callable @@ -81,10 +82,36 @@ def mlo_prehook_func_factory(node) -> Callable[[SimEngine], None]: if downstream.op_type.startswith("MVAU"): mvau_hbm_weights[idx] = {} mvau_hbm_weights[idx]["name"] = lb_inp.name - datfile = ( - f"{finnloop_op.get_nodeattr('code_gen_dir_ipgen')}/memblock_MVAU_rtl_id_{idx}.dat" - ) - mvau_hbm_weights[idx]["value"] = dat_file_to_numpy_array(datfile) + code_gen_dir = finnloop_op.get_nodeattr('code_gen_dir_ipgen') + npy_file = f"{code_gen_dir}/input1_MVAU_rtl_id_{idx}.npy" + datfile = f"{code_gen_dir}/memblock_MVAU_rtl_id_{idx}.dat" + mvau_op = getCustomOp(downstream) + mh = mvau_op.get_nodeattr("MH") + mw = mvau_op.get_nodeattr("MW") + wdt_width = mvau_op.get_input_datatype(1).bitwidth() + # Must match RTL LAYER_OFFS: align to AXI bus width (256 bits = 32 bytes) + axi_bytes = 32 + raw_layer_bytes = (mh * mw * wdt_width + 7) // 8 + layer_offs = (raw_layer_bytes + axi_bytes - 1) & ~(axi_bytes - 1) + if os.path.isfile(npy_file): + weight_npy = np.load(npy_file) + # Pack npy values into byte array for AXI-MM + # Memory byte order matches npy_to_rtlsim_input packing + tinner = weight_npy.shape[-1] + words_per_iter = raw_layer_bytes // tinner + flat = weight_npy.reshape(-1, tinner) + n_iters = len(flat) // words_per_iter + weight_bytes = [] + for it in range(n_iters): + for row in flat[it * words_per_iter:(it + 1) * words_per_iter]: + for val in row: + weight_bytes.append(int(val) & 0xFF) + # Pad to layer_offs boundary + pad = layer_offs - raw_layer_bytes + weight_bytes.extend([0] * pad) + mvau_hbm_weights[idx]["value"] = np.array(weight_bytes, dtype=np.uint8) + else: + mvau_hbm_weights[idx]["value"] = dat_file_to_numpy_array(datfile) mvau_hbm_weights[idx]["extern_idx"] = extern_idx mvau_hbm_weights[idx]["extern_name"] = f"m_axi_MVAU_id_{idx}" extern_idx += 1 diff --git a/tests/fpgadataflow/test_fpgadataflow_finnloop.py b/tests/fpgadataflow/test_fpgadataflow_finnloop.py index 7d6ab494bd..e48b7afd76 100644 --- a/tests/fpgadataflow/test_fpgadataflow_finnloop.py +++ b/tests/fpgadataflow/test_fpgadataflow_finnloop.py @@ -100,6 +100,10 @@ def make_loop_modelwrapper( rhs_shape=[1], eltw_param_dtype="INT8", name_suffix="", + mvau_pe=2, + mvau_simd=2, + mvau_th=1, + helper_pe=2, ): elemwise_output_dtype = ( DataType["FLOAT32"] if eltw_param_dtype == "FLOAT32" else DataType["INT32"] @@ -155,7 +159,7 @@ def make_loop_modelwrapper( { "NumChannels": mh, "NumOutputStreams": 2, - "PE": 2, + "PE": helper_pe, "inputDataType": dtype.name, "outFIFODepths": [2, 2], "cpp_interface": "hls_vector", @@ -170,8 +174,9 @@ def make_loop_modelwrapper( { "MW": mw, "MH": mh, - "SIMD": 2, - "PE": 2, + "SIMD": mvau_simd, + "PE": mvau_pe, + "TH": mvau_th, "inputDataType": "INT8", "weightDataType": "INT8", "outputDataType": "INT32", @@ -188,7 +193,7 @@ def make_loop_modelwrapper( f"Thresholding_rtl_0{name_suffix}", { "NumChannels": mh, - "PE": 2, + "PE": helper_pe, "inputDataType": "INT32", "weightDataType": "INT33", "outputDataType": dtype.name, @@ -204,8 +209,9 @@ def make_loop_modelwrapper( { "MW": mw, "MH": mh, - "SIMD": 2, - "PE": 2, + "SIMD": mvau_simd, + "PE": mvau_pe, + "TH": mvau_th, "inputDataType": "INT8", "weightDataType": "INT8", "outputDataType": "INT32", @@ -222,7 +228,7 @@ def make_loop_modelwrapper( f"Thresholding_rtl_1{name_suffix}", { "NumChannels": mh, - "PE": 2, + "PE": helper_pe, "inputDataType": "INT32", "weightDataType": "INT33", "outputDataType": dtype.name, @@ -238,8 +244,9 @@ def make_loop_modelwrapper( { "MW": mw, "MH": mh, - "SIMD": 2, - "PE": 2, + "SIMD": mvau_simd, + "PE": mvau_pe, + "TH": mvau_th, "inputDataType": "INT8", "weightDataType": "INT8", "outputDataType": "INT32", @@ -256,7 +263,7 @@ def make_loop_modelwrapper( f"Thresholding_rtl_2{name_suffix}", { "NumChannels": mh, - "PE": 2, + "PE": helper_pe, "inputDataType": "INT32", "weightDataType": "INT33", "outputDataType": dtype.name, @@ -269,7 +276,7 @@ def make_loop_modelwrapper( [f"mt2_out{name_suffix}", f"mt1_out{name_suffix}"], [f"ofm{name_suffix}"], f"AddStreams_hls_0{name_suffix}", - {"NumChannels": mh, "PE": 2, "inputDataTypes": [dtype.name, dtype.name]}, + {"NumChannels": mh, "PE": helper_pe, "inputDataTypes": [dtype.name, dtype.name]}, ), create_node( elemwise_optype, @@ -292,7 +299,7 @@ def make_loop_modelwrapper( f"Thresholding_rtl4{name_suffix}", { "NumChannels": mh, - "PE": 2, + "PE": helper_pe, "numSteps": dtype.get_num_possible_values() - 1, "inputDataType": thresholding_input_dtype.name, "weightDataType": thresholding_input_dtype.name, @@ -362,8 +369,94 @@ def make_loop_modelwrapper( return loop_body_model +def make_single_mvau_loop_body( + mw, + mh, + dtype=DataType["INT8"], + name_suffix="", + mvau_pe=2, + mvau_simd=2, + mvau_th=1, + helper_pe=2, +): + """Create a minimal loop body with just MVAU_rtl -> Thresholding_rtl.""" + + W0 = gen_finn_dt_tensor(dtype, (mw, mh)) + T0 = np.sort( + generate_random_threshold_values(dtype, 1, dtype.get_num_possible_values() - 1), axis=1 + ) + + nodes = [ + create_node( + "MVAU_rtl", + [f"ifm{name_suffix}", f"weights0{name_suffix}"], + [f"mm0_out{name_suffix}"], + f"MVAU_rtl_0{name_suffix}", + { + "MW": mw, + "MH": mh, + "SIMD": mvau_simd, + "PE": mvau_pe, + "TH": mvau_th, + "inputDataType": "INT8", + "weightDataType": "INT8", + "outputDataType": "INT32", + "ActVal": 0, + "binaryXnorMode": 0, + "noActivation": 1, + "mem_mode": "external_mem", + }, + ), + create_node( + "Thresholding_rtl", + [f"mm0_out{name_suffix}", f"thresh0{name_suffix}"], + [f"ofm{name_suffix}"], + f"Thresholding_rtl_0{name_suffix}", + { + "NumChannels": mh, + "PE": helper_pe, + "inputDataType": "INT32", + "weightDataType": "INT33", + "outputDataType": dtype.name, + "ActVal": int(dtype.min()), + "numSteps": dtype.get_num_possible_values() - 1, + }, + ), + ] + + loop_body = helper.make_graph( + nodes=nodes, + name=f"single_mvau_graph{name_suffix}", + inputs=[ + create_tensor_info(f"ifm{name_suffix}", [1, 3, 3, mw]), + create_threshold(f"thresh0{name_suffix}", (1, dtype.get_num_possible_values() - 1)), + ], + outputs=[create_tensor_info(f"ofm{name_suffix}", (1, 3, 3, mh))], + value_info=[ + create_tensor_info(f"mm0_out{name_suffix}", [1, 3, 3, mh]), + ], + ) + + loop_body_model = qonnx_make_model(loop_body, producer_name=f"single-mvau-body{name_suffix}") + loop_body_model = ModelWrapper(loop_body_model) + + loop_body_model.set_initializer(f"weights0{name_suffix}", W0) + loop_body_model.set_initializer(f"thresh0{name_suffix}", T0) + + for tensor in [ + f"weights0{name_suffix}", + f"thresh0{name_suffix}", + f"ifm{name_suffix}", + f"ofm{name_suffix}", + ]: + loop_body_model.set_tensor_datatype(tensor, dtype) + + return loop_body_model + + def create_chained_loop_bodies( - mw, mh, num_copies, elemwise_optype="ElementwiseMul_hls", rhs_shape=[1], eltw_param_dtype="INT8" + mw, mh, num_copies, elemwise_optype="ElementwiseMul_hls", rhs_shape=[1], eltw_param_dtype="INT8", + mvau_pe=2, mvau_simd=2, mvau_th=1, helper_pe=2, ): loop_body_models = [] @@ -378,6 +471,10 @@ def create_chained_loop_bodies( rhs_shape=rhs_shape, eltw_param_dtype=eltw_param_dtype, name_suffix=name_suffix, + mvau_pe=mvau_pe, + mvau_simd=mvau_simd, + mvau_th=mvau_th, + helper_pe=helper_pe, ) loop_body_models.append(loop_body_model) @@ -550,6 +647,177 @@ def test_finnloop_end2end_mlo( ), f"Check vivado.log in {tmp_output_dir}/stitched_ip" +# iteration count, number of models chained together +@pytest.mark.parametrize("iteration", [3]) +# elementwise operation +@pytest.mark.parametrize("elemwise_optype", ["ElementwiseMul_hls"]) +# elementwise shape +@pytest.mark.parametrize("rhs_shape", [[1]]) +# eltwise param dtype +@pytest.mark.parametrize("eltw_param_dtype", ["INT8"]) +# tail node +@pytest.mark.parametrize("tail_node", [False]) +@pytest.mark.fpgadataflow +@pytest.mark.vivado +@pytest.mark.slow +def test_finnloop_end2end_mlo_tiled( + iteration, elemwise_optype, rhs_shape, eltw_param_dtype, tail_node +): + """End-to-end MLO test with tiled MVAUs (TH>1).""" + dim = 12 + mvau_pe = 6 + mvau_simd = 3 + mvau_th = 3 + helper_pe = 6 + + # Check vivado version + vivado_path = os.environ.get("XILINX_VIVADO") + match = re.search(r"\b(20\d{2})\.(1|2)\b", vivado_path) + year, minor = int(match.group(1)), int(match.group(2)) + if (year, minor) < (2024, 2): + pytest.skip("""At least Vivado version 2024.2 needed for MLO.""") + loop_body_models = create_chained_loop_bodies( + dim, dim, iteration, elemwise_optype, rhs_shape, eltw_param_dtype, + mvau_pe=mvau_pe, mvau_simd=mvau_simd, mvau_th=mvau_th, helper_pe=helper_pe, + ) + model = loop_body_models[0] + for m in loop_body_models[1:]: + model = model.transform(MergeONNXModels(m)) + + if tail_node: + tail_outp = create_tensor_info("tail_outp", [1, 3, 3, dim]) + tr_node = create_node( + "ElementwiseAdd_hls", + [model.graph.output[0].name, "tail_add"], + ["tail_outp"], + "Add_tail", + { + "lhs_shape": [1, 3, 3, dim], + "rhs_shape": [1], + "out_shape": [1, 3, 3, dim], + "lhs_dtype": "INT8", + "rhs_dtype": "INT8", + "out_dtype": "INT9", + }, + ) + model.graph.node.insert(len(model.graph.node), tr_node) + model.graph.value_info.append(model.graph.output[0]) + model.graph.output.pop(0) + model.graph.output.append(tail_outp) + AddtailParam = gen_finn_dt_tensor(DataType["INT8"], [1]) + model.set_initializer("tail_add", AddtailParam) + model.set_tensor_datatype("tail_add", DataType["INT8"]) + + # cleanup + model = model.transform(RemoveUnusedTensors()) + model = model.transform(InferShapes()) + model = model.transform(InferDataTypes()) + + # Generate reference by first copying the model and running cppsim + model_ref = model.transform(PrepareCppSim()) + model_ref = model_ref.transform(CompileCppSim()) + model_ref = model_ref.transform(SetExecMode("cppsim")) + + # generate reference io pair + x = gen_finn_dt_tensor(DataType["INT8"], (1, 3, 3, dim)) + io_dict = {model_ref.graph.input[0].name: x} + y_dict = oxe.execute_onnx(model_ref, io_dict) + y_ref = y_dict[model_ref.graph.output[0].name] + + tmp_output_dir = make_build_dir("build_mlo_tiled") + + np.save(tmp_output_dir + "/input.npy", x) + np.save(tmp_output_dir + "/expected_output.npy", y_ref) + + model.save(tmp_output_dir + "/mlo_model.onnx") + + # steps - skip step_target_fps_parallelization since PE/SIMD/TH already set + steps = [ + "step_create_dataflow_partition", + "step_loop_rolling", + "step_apply_folding_config", + "step_minimize_bit_width", + "step_generate_estimate_reports", + "step_hw_codegen", + "step_hw_ipgen", + "step_set_fifo_depths", + "step_create_stitched_ip", + ] + + cfg = build_cfg.DataflowBuildConfig( + output_dir=tmp_output_dir, + steps=steps, + target_fps=1000, + synth_clk_period_ns=10.0, + board="V80", + rtlsim_batch_size=100, + standalone_thresholds=True, + mlo=True, + loop_body_hierarchy=[["", "layers.0"]], + loop_body_range=(model.graph.node[0], model.graph.node[9]), + verify_steps=verif_steps, + verify_input_npy=tmp_output_dir + "/input.npy", + verify_expected_output_npy=tmp_output_dir + "/expected_output.npy", + verify_save_rtlsim_waveforms=True, + generate_outputs=[ + build_cfg.DataflowOutputType.ESTIMATE_REPORTS, + build_cfg.DataflowOutputType.STITCHED_IP, + ], + ) + build.build_dataflow_cfg(tmp_output_dir + "/mlo_model.onnx", cfg) + + # Dump weight files for hardware debug + import glob + + built_model = ModelWrapper(tmp_output_dir + "/mlo_model.onnx") + for node in built_model.graph.node: + if node.op_type == "FINNLoop": + fl_op = getCustomOp(node) + code_gen_dir = fl_op.get_nodeattr("code_gen_dir_ipgen") + if code_gen_dir and os.path.isdir(code_gen_dir): + for f in sorted(glob.glob(code_gen_dir + "/input1_*.npy")): + arr = np.load(f) + base = os.path.basename(f).replace(".npy", "") + txt_path = tmp_output_dir + f"/weights_{base}.txt" + with open(txt_path, "w") as tf: + tf.write(f"# {f}\n# shape: {arr.shape}, dtype: {arr.dtype}\n") + tf.write("# row | decimal_values | hex_bytes | bus_word\n\n") + flat = arr.reshape(-1, arr.shape[-1]) + for i, row in enumerate(flat): + dec = " ".join(f"{int(v):5d}" for v in row) + hx = " ".join(f"{int(v) & 0xFF:02x}" for v in row) + bus = 0 + for j, v in enumerate(row): + bus |= (int(v) & 0xFF) << (j * 8) + tf.write( + f"[{i:3d}] {dec} | {hx} | 0x{bus:0{arr.shape[-1]*2}x}\n" + ) + print(f"DEBUG: weight dump -> {txt_path}") + for f in sorted(glob.glob(code_gen_dir + "/memblock_*.dat")): + base = os.path.basename(f).replace(".dat", "") + txt_path = tmp_output_dir + f"/weights_{base}.txt" + with open(txt_path, "w") as tf: + tf.write(f"# {f}\n") + with open(f) as df: + for i, line in enumerate(df): + tf.write(f"[{i:3d}] {line.strip()}\n") + print(f"DEBUG: dat dump -> {txt_path}") + + # check if expected files are there + assert os.path.isfile(tmp_output_dir + "/loop-body-template.onnx") + assert os.path.isfile(tmp_output_dir + "/stitched_ip/ip/component.xml") + + verif_dir = tmp_output_dir + "/verification_output" + assert os.path.isfile( + verif_dir + "/verify_folded_hls_cppsim_0_SUCCESS.npy" + ), f"Check npy files in {verif_dir}" + assert os.path.isfile( + verif_dir + "/verify_node_by_node_rtlsim_0_SUCCESS.npy" + ), f"Check npy files in {verif_dir}" + assert os.path.isfile( + verif_dir + "/verify_stitched_ip_rtlsim_0_SUCCESS.npy" + ), f"Check npy files in {verif_dir}" + # Debug test for manual loop transformation steps below # This test is intentionally not marked for CI # Use test_finnloop_end2end_mlo instead diff --git a/tests/fpgadataflow/test_fpgadataflow_mvau.py b/tests/fpgadataflow/test_fpgadataflow_mvau.py index c828b4717e..f1016f092c 100644 --- a/tests/fpgadataflow/test_fpgadataflow_mvau.py +++ b/tests/fpgadataflow/test_fpgadataflow_mvau.py @@ -856,8 +856,8 @@ def test_fpgadataflow_rtl_mvau( ).all(), "Output of ONNX model not matching output of stitched-IP RTL model!" -@pytest.mark.parametrize("mh", [18]) -@pytest.mark.parametrize("mw", [36]) +@pytest.mark.parametrize("mh", [12]) +@pytest.mark.parametrize("mw", [12]) @pytest.mark.parametrize("pe", [6]) @pytest.mark.parametrize("simd", [3]) @pytest.mark.parametrize("th", [3]) @@ -955,6 +955,106 @@ def test_fpgadataflow_rtl_tiled_mvau(mh, mw, pe, simd, th, idt_wdt, clk_ns): ).all(), "Output of ONNX model not matching output of tiled stitched-IP RTL model!" +@pytest.mark.parametrize("mh", [12]) +@pytest.mark.parametrize("mw", [12]) +@pytest.mark.parametrize("pe", [6]) +@pytest.mark.parametrize("simd", [6]) +@pytest.mark.parametrize("cu_simd", [3]) +@pytest.mark.parametrize("idt_wdt", [[DataType["UINT8"], DataType["INT8"]]]) +@pytest.mark.parametrize("clk_ns", [4]) +@pytest.mark.fpgadataflow +@pytest.mark.slow +@pytest.mark.vivado +def test_fpgadataflow_rtl_mmau_2d(mh, mw, pe, simd, cu_simd, idt_wdt, clk_ns): + # MMU only supported on Versal (DSP58) + part = "xcvc1902-vsva2197-2MP-e-S" + + if simd % cu_simd != 0: + pytest.skip("SIMD must be divisible by CU_SIMD") + + if mw % simd != 0: + pytest.skip("MW must be divisible by SIMD") + + if mh % pe != 0: + pytest.skip("MH must be divisible by PE") + + idt, wdt = idt_wdt + # Create test input vector (produced by SWG) + ofm_shape = (3, 3) + ofm_h, ofm_w = ofm_shape + ifm = helper.make_tensor_value_info("ifm", TensorProto.FLOAT, [1, ofm_h, ofm_w, mw]) + ofm = helper.make_tensor_value_info("ofm", TensorProto.FLOAT, (1, ofm_h, ofm_w, mh)) + W = gen_finn_dt_tensor(wdt, (mw, mh)) + model = make_single_matmul_modelwrapper(ifm, ofm, idt, wdt, W) + model = model.transform(GiveUniqueNodeNames()) + model = model.transform(GiveReadableTensorNames()) + + # Create MatMul & obtain golden reference output + A = gen_finn_dt_tensor( + model.get_tensor_datatype("global_in"), model.get_tensor_shape("global_in") + ) + input_dict = prepare_inputs(A, idt, wdt, inp_name="global_in") + + # Execute ONNX model + output_matmul = oxe.execute_onnx(model, input_dict)["global_out"] + + # Create MVAU + model = model.transform(to_hw.InferQuantizedMatrixVectorActivation()) + model = model.transform(GiveUniqueNodeNames()) + + # Apply convert-to-rtl step + model = model.transform(SpecializeLayers(part)) + model = model.transform(GiveUniqueNodeNames()) + + assert model.graph.node[0].op_type == "MVAU_rtl" + # Apply folding with MMU systolic array + folding_config = { + "Defaults": {}, + "MVAU_rtl_0": { + "PE": pe, + "SIMD": simd, + "CU_SIMD": cu_simd, + "gemm_type": "mmau_2d", + "resType": "dsp", + "mem_mode": "external_mem", + }, + } + model = model.transform(ApplyConfig(folding_config)) + model = model.transform(MinimizeWeightBitWidth()) + model = model.transform(MinimizeAccumulatorWidth()) + # make sure the changed datatypes are propagated through the network + model = model.transform(InferDataTypes()) + + # Run CPPsim + model = model.transform(SetExecMode("cppsim")) + model = model.transform(PrepareCppSim()) + model = model.transform(CompileCppSim()) + output_mvau_hls = oxe.execute_onnx(model, input_dict)["global_out"] + assert ( + output_matmul == output_mvau_hls + ).all(), "Output of ONNX model not matching output of node-by-node CPPsim!" + + # Run node-by-node RTLsim + model = model.transform(SetExecMode("rtlsim")) + model = model.transform(PrepareIP(part, clk_ns)) + model = model.transform(HLSSynthIP()) + model = model.transform(PrepareRTLSim()) + output_mvau_rtl = oxe.execute_onnx(model, input_dict)["global_out"] + assert ( + output_matmul == output_mvau_rtl + ).all(), "Output of ONNX model not matching output of node-by-node RTLsim!" + + # Run stitched-ip RTLsim + model = model.transform(InsertAndSetFIFODepths(part, clk_ns)) + model = model.transform(PrepareIP(part, clk_ns)) + model = model.transform(HLSSynthIP()) + model = model.transform(CreateStitchedIP(part, clk_ns)) + output_mvau_rtl_stitch = oxe.execute_onnx(model, input_dict)["global_out"] + assert ( + output_matmul == output_mvau_rtl_stitch + ).all(), "Output of ONNX model not matching output of MMU stitched-IP RTL model!" + + @pytest.mark.parametrize("mh", [32]) @pytest.mark.parametrize("mw", [16]) @pytest.mark.parametrize("n_vectors", [32]) From 60e8eabbde2722a83d222b21534c023f8bcbe68e Mon Sep 17 00:00:00 2001 From: dkorolij Date: Tue, 14 Apr 2026 13:16:41 +0000 Subject: [PATCH 03/17] precommit run. --- finn-rtllib/fetch_weights/fetch_weights.sv | 8 ++--- finn-rtllib/mmu/1d/collect_out_1d.sv | 2 +- finn-rtllib/mmu/1d/cu_mmau_1d.sv | 20 ++++++------- finn-rtllib/mmu/1d/sched_weights_1d.sv | 2 +- finn-rtllib/mmu/1d/sft_reg.sv | 4 +-- finn-rtllib/mmu/2d/collect_out_2d.sv | 2 +- finn-rtllib/mmu/2d/cu_mmau_2d.sv | 20 ++++++------- finn-rtllib/mmu/2d/sched_weights_2d.sv | 2 +- finn-rtllib/mmu/en_global.sv | 2 +- finn-rtllib/mmu/mmu_axi.sv | 14 ++++----- finn-rtllib/mmu/q_writer.sv | 2 +- finn-rtllib/mmu/reorder_out.sv | 6 ++-- finn-rtllib/mmu/replay_buff_mmau.sv | 8 ++--- finn-rtllib/mmu/sched_activations.sv | 4 +-- finn-rtllib/mvu_tiled/acc_stage.sv | 14 ++++----- finn-rtllib/mvu_tiled/add_tree.sv | 4 +-- finn-rtllib/mvu_tiled/mvu_tiled_axi.sv | 10 +++---- finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v | 2 +- finn-rtllib/mvu_tiled/reorder_out.sv | 6 ++-- finn-rtllib/mvu_tiled/replay_buff_tile.sv | 6 ++-- finn-rtllib/mvu_tiled/weights_buff_tile.sv | 6 ++-- .../rtl/matrixvectoractivation_rtl.py | 4 +-- src/finn/util/mlo_sim.py | 4 +-- .../test_fpgadataflow_finnloop.py | 29 ++++++++++++++----- 24 files changed, 97 insertions(+), 84 deletions(-) diff --git a/finn-rtllib/fetch_weights/fetch_weights.sv b/finn-rtllib/fetch_weights/fetch_weights.sv index 4573fe53bf..73aff27039 100644 --- a/finn-rtllib/fetch_weights/fetch_weights.sv +++ b/finn-rtllib/fetch_weights/fetch_weights.sv @@ -29,7 +29,7 @@ * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * *****************************************************************************/ - + module fetch_weights #( int unsigned PE, int unsigned SIMD, @@ -195,7 +195,7 @@ if(TH > 1) begin ST_DMA: state_N = (cnt_dma_C == N_REPS-1) && dma_tready ? ST_IDLE : ST_DMA; - + endcase end @@ -213,7 +213,7 @@ if(TH > 1) begin if(q_idx_out_tvalid) begin idx_N = q_idx_out_tdata; end - end + end ST_DMA: begin dma_tvalid = 1'b1; @@ -221,7 +221,7 @@ if(TH > 1) begin cnt_dma_N = cnt_dma_C + 1; end end - + endcase end diff --git a/finn-rtllib/mmu/1d/collect_out_1d.sv b/finn-rtllib/mmu/1d/collect_out_1d.sv index 1451d9a837..ab874aac47 100644 --- a/finn-rtllib/mmu/1d/collect_out_1d.sv +++ b/finn-rtllib/mmu/1d/collect_out_1d.sv @@ -108,4 +108,4 @@ module collect_out_1d #( .odat(m_axis_tdata), .ovld(m_axis_tvalid), .ordy(m_axis_tready) ); -endmodule \ No newline at end of file +endmodule diff --git a/finn-rtllib/mmu/1d/cu_mmau_1d.sv b/finn-rtllib/mmu/1d/cu_mmau_1d.sv index 601c4bb2db..0f13083d09 100644 --- a/finn-rtllib/mmu/1d/cu_mmau_1d.sv +++ b/finn-rtllib/mmu/1d/cu_mmau_1d.sv @@ -36,7 +36,7 @@ module cu_mmau_1d #( int unsigned PE, int unsigned CLEN, int unsigned CU_SIMD, - + int unsigned ACTIVATION_WIDTH, int unsigned WEIGHT_WIDTH, int unsigned ACCU_WIDTH, @@ -62,7 +62,7 @@ module cu_mmau_1d #( input logic m_axis_tready, output logic [PE-1:0][ACCU_WIDTH-1:0] m_axis_tdata ); - + // Startup Recovery Watchdog // The DSP slice needs 100ns of recovery time after initial startup before // being able to ingest input properly. This watchdog discovers violating @@ -97,7 +97,7 @@ module cu_mmau_1d #( assign Wc_int[0][i][8*k +: 8] = PAD_BITS_WEIGHT == 0 ? Wc[0][i][WEIGHT_WIDTH*k+:WEIGHT_WIDTH] : { {PAD_BITS_WEIGHT{Wc[0][i][k*WEIGHT_WIDTH+WEIGHT_WIDTH-1]}}, Wc[0][i][k*WEIGHT_WIDTH+:WEIGHT_WIDTH] }; end - end + end /* always_ff @(posedge clk) begin @@ -169,7 +169,7 @@ module cu_mmau_1d #( end end - for (genvar j = 0; j < PE; j++) begin + for (genvar j = 0; j < PE; j++) begin always_comb begin Mc_int_sum[i][j] = 0; @@ -209,9 +209,9 @@ module cu_mmau_1d #( SIGNED_ACTIVATIONS ? PAD_BITS_ACT == 0 ? a[i][k] : { {PAD_BITS_ACT{a[i][k][ACTIVATION_WIDTH-1]}}, a[i][k] } : PAD_BITS_ACT == 0 ? a[i][k] : { {PAD_BITS_ACT{1'b0}}, a[i][k] } ; end - - for (genvar j = 0; j < PE; j++) begin + + for (genvar j = 0; j < PE; j++) begin /* for (genvar k = 0; k < CU_SIMD; k++) begin assign Wc_int[i][j][8*k +: 8] = PAD_BITS_WEIGHT == 0 ? Wc[i][j][WEIGHT_WIDTH*k+:WEIGHT_WIDTH] : { {PAD_BITS_WEIGHT{Wc[i][j][k*WEIGHT_WIDTH+WEIGHT_WIDTH-1]}}, Wc[i][j][k*WEIGHT_WIDTH+:WEIGHT_WIDTH] }; @@ -343,7 +343,7 @@ module cu_mmau_1d #( .RSTP(rst) // 1-bit input: Reset for PREG ); - + if(i % CC_LEN == CC_LEN-1) begin sft_reg #( .N(CC_LEN) @@ -378,7 +378,7 @@ module cu_mmau_1d #( end else begin Pc[i] <= Lc[i][DSP_PIPELINE_STAGES] ? pout[i] : Pc[i+1]; Pc_vld[i] <= Lc[i][DSP_PIPELINE_STAGES] ? 1'b1 : Pc_vld[i+1]; - end + end end end end @@ -391,7 +391,7 @@ module cu_mmau_1d #( assign p = Pc[0]; collect_out_1d #( - .PE(PE), + .PE(PE), .ACCU_WIDTH(ACCU_WIDTH) ) inst_collect_out ( .clk(clk), .rst(rst), @@ -400,4 +400,4 @@ module cu_mmau_1d #( .m_axis_tdata(m_axis_tdata), .m_axis_tvalid(m_axis_tvalid), .m_axis_tready(m_axis_tready) ); -endmodule \ No newline at end of file +endmodule diff --git a/finn-rtllib/mmu/1d/sched_weights_1d.sv b/finn-rtllib/mmu/1d/sched_weights_1d.sv index 5046f63d84..de2689e957 100644 --- a/finn-rtllib/mmu/1d/sched_weights_1d.sv +++ b/finn-rtllib/mmu/1d/sched_weights_1d.sv @@ -138,4 +138,4 @@ module sched_weights_1d #( .odat(m_axis_tdata), .ovld(m_axis_tvalid), .ordy(m_axis_tready) ); -endmodule \ No newline at end of file +endmodule diff --git a/finn-rtllib/mmu/1d/sft_reg.sv b/finn-rtllib/mmu/1d/sft_reg.sv index 7fa8f5eeef..65e674e53b 100644 --- a/finn-rtllib/mmu/1d/sft_reg.sv +++ b/finn-rtllib/mmu/1d/sft_reg.sv @@ -18,8 +18,8 @@ module sft_reg #( end end - // The tool sees this lack of reset and constant index + // The tool sees this lack of reset and constant index // and maps it to an SRL16 automatically. assign dout = shift_pipe[N-1]; -endmodule \ No newline at end of file +endmodule diff --git a/finn-rtllib/mmu/2d/collect_out_2d.sv b/finn-rtllib/mmu/2d/collect_out_2d.sv index 1f0b60b319..b1949e0620 100644 --- a/finn-rtllib/mmu/2d/collect_out_2d.sv +++ b/finn-rtllib/mmu/2d/collect_out_2d.sv @@ -117,4 +117,4 @@ module collect_out_2d #( .odat(m_axis_tdata), .ovld(m_axis_tvalid), .ordy(m_axis_tready) ); -endmodule \ No newline at end of file +endmodule diff --git a/finn-rtllib/mmu/2d/cu_mmau_2d.sv b/finn-rtllib/mmu/2d/cu_mmau_2d.sv index e4f246872c..40379ea16f 100644 --- a/finn-rtllib/mmu/2d/cu_mmau_2d.sv +++ b/finn-rtllib/mmu/2d/cu_mmau_2d.sv @@ -36,7 +36,7 @@ module cu_mmau_2d #( int unsigned PE, int unsigned CLEN, int unsigned CU_SIMD, - + int unsigned ACTIVATION_WIDTH, int unsigned WEIGHT_WIDTH, int unsigned ACCU_WIDTH, @@ -63,7 +63,7 @@ module cu_mmau_2d #( output logic [PE-1:0][ACCU_WIDTH-1:0] m_axis_tdata ); - + // Startup Recovery Watchdog // The DSP slice needs 100ns of recovery time after initial startup before // being able to ingest input properly. This watchdog discovers violating @@ -95,11 +95,11 @@ module cu_mmau_2d #( for(genvar i = 0; i < CLEN; i++) begin assign Ac[i][0] = a[i]; assign Ac_last[i][0] = ilast[i]; - end + end for(genvar i = 0; i < PE; i++) begin assign Wc[0][i] = w[i]; - end + end always_ff @(posedge clk) begin if(rst) begin @@ -115,7 +115,7 @@ module cu_mmau_2d #( Wc[i][j] <= '0; end end - end + end for(int i = 0; i < CLEN; i++) begin for(int j = 1; j < PE; j++) begin if(ivld) begin @@ -181,7 +181,7 @@ module cu_mmau_2d #( for (genvar i = 0; i < CLEN; i++) begin - for (genvar j = 0; j < PE; j++) begin + for (genvar j = 0; j < PE; j++) begin always_comb begin Mc_int_sum[i][j] = 0; @@ -216,7 +216,7 @@ module cu_mmau_2d #( logic [CLEN-1:0][PE-1:0][23:0] Wc_int; for (genvar i = 0; i < CLEN; i++) begin - for (genvar j = 0; j < PE; j++) begin + for (genvar j = 0; j < PE; j++) begin for (genvar k = 0; k < CU_SIMD; k++) begin assign Ac_int[i][j][9*k +: 9] = @@ -372,7 +372,7 @@ module cu_mmau_2d #( end end end - + for(genvar i = 0; i < CLEN; i++) begin for(genvar j = 0; j < PE; j++) begin if(i == CLEN-1) begin @@ -420,7 +420,7 @@ module cu_mmau_2d #( end collect_out_2d #( - .PE(PE), + .PE(PE), .ACCU_WIDTH(ACCU_WIDTH) ) inst_collect_out ( .clk(clk), .rst(rst), @@ -429,4 +429,4 @@ module cu_mmau_2d #( .m_axis_tdata(m_axis_tdata), .m_axis_tvalid(m_axis_tvalid), .m_axis_tready(m_axis_tready) ); -endmodule \ No newline at end of file +endmodule diff --git a/finn-rtllib/mmu/2d/sched_weights_2d.sv b/finn-rtllib/mmu/2d/sched_weights_2d.sv index a61c5bbaa6..93cfad09b5 100644 --- a/finn-rtllib/mmu/2d/sched_weights_2d.sv +++ b/finn-rtllib/mmu/2d/sched_weights_2d.sv @@ -162,4 +162,4 @@ module sched_weights_2d #( .odat(m_axis_tdata), .ovld(m_axis_tvalid), .ordy(m_axis_tready) ); -endmodule \ No newline at end of file +endmodule diff --git a/finn-rtllib/mmu/en_global.sv b/finn-rtllib/mmu/en_global.sv index fe9525112e..9547660d73 100644 --- a/finn-rtllib/mmu/en_global.sv +++ b/finn-rtllib/mmu/en_global.sv @@ -96,4 +96,4 @@ skid #(.DATA_WIDTH(CLEN*CU_SIMD*ACTIVATION_WIDTH+CLEN), .FEED_STAGES(N_DCPL_STAG .odat({m_act_tlast, m_act_tdata}), .ovld(m_tvalid), .ordy(1'b1) ); -endmodule \ No newline at end of file +endmodule diff --git a/finn-rtllib/mmu/mmu_axi.sv b/finn-rtllib/mmu/mmu_axi.sv index 92ec9213a8..7e0b3f14bb 100644 --- a/finn-rtllib/mmu/mmu_axi.sv +++ b/finn-rtllib/mmu/mmu_axi.sv @@ -53,12 +53,12 @@ module mmu_axi #( bit FORCE_BEHAVIOURAL = 0, int unsigned N_DCPL_STAGES = 2, - + // Safely deducible parameters localparam int unsigned CLEN = (SIMD + CU_SIMD-1)/ CU_SIMD, localparam int unsigned WSIMD = PE * CU_SIMD, localparam int unsigned ASIMD = CLEN * CU_SIMD, - + localparam int unsigned WEIGHT_STREAM_WIDTH = WSIMD * WEIGHT_WIDTH, localparam int unsigned WEIGHT_STREAM_WIDTH_BA = (WEIGHT_STREAM_WIDTH + 7)/8 * 8, localparam int unsigned INPUT_STREAM_WIDTH = SIMD * ACTIVATION_WIDTH, @@ -117,7 +117,7 @@ module mmu_axi #( localparam int unsigned N_TRS_OP = SF * NF * N_VECTORS; localparam int unsigned N_TRS_EP = (GEMM_TYPE == "mmau_1d") ? CLEN-1 + CLEN-1 + DSP_STAGES + 2 : CLEN-1 + CLEN-1 + DSP_STAGES + PE; - + // Input replay // --------------------------------------------------------------------- logic [SIMD-1:0][ACTIVATION_WIDTH-1:0] adat_s0; @@ -136,7 +136,7 @@ module mmu_axi #( .ovld(act_s0_tvalid), .ordy(act_s0_tready), .odat(act_s0_tdata), .olast(act_s0_tlast) ); - if (ASIMD > SIMD) + if (ASIMD > SIMD) assign act_s0_tdata_mod[ASIMD-1:SIMD] = '0; assign act_s0_tdata_mod[SIMD-1:0] = act_s0_tdata[SIMD-1:0]; @@ -200,7 +200,7 @@ end .en(en), .s_act_tvalid(act_s1_tvalid), .s_act_tready(act_s1_tready), .s_act_tdata(act_s1_tdata), .s_act_tlast(act_s1_tlast), .m_act_tdata(act_s2_tdata), .m_act_tlast(act_s2_tlast), - .s_wgt_tvalid(wgt_s1_tvalid), .s_wgt_tready(wgt_s1_tready), .s_wgt_tdata(wgt_s1_tdata), + .s_wgt_tvalid(wgt_s1_tvalid), .s_wgt_tready(wgt_s1_tready), .s_wgt_tdata(wgt_s1_tdata), .m_wgt_tdata(wgt_s2_tdata), .m_tvalid(s2_tvalid) ); @@ -233,7 +233,7 @@ end else begin .m_axis_tvalid(p_tvalid), .m_axis_tready(p_tready), .m_axis_tdata(p_tdata) ); end - + // Reorder // --------------------------------------------------------------------- @@ -249,4 +249,4 @@ end assign m_axis_output_tdata = p_tdata; end -endmodule \ No newline at end of file +endmodule diff --git a/finn-rtllib/mmu/q_writer.sv b/finn-rtllib/mmu/q_writer.sv index 8d974a7f43..2a796c0f8d 100644 --- a/finn-rtllib/mmu/q_writer.sv +++ b/finn-rtllib/mmu/q_writer.sv @@ -121,4 +121,4 @@ module q_writer #( .o_v(m_axis_tvalid), .o_r(m_axis_tready), .o_d({m_axis_tlast, m_axis_tdata}) ); -endmodule \ No newline at end of file +endmodule diff --git a/finn-rtllib/mmu/reorder_out.sv b/finn-rtllib/mmu/reorder_out.sv index 19421937e4..1b4b75183e 100644 --- a/finn-rtllib/mmu/reorder_out.sv +++ b/finn-rtllib/mmu/reorder_out.sv @@ -29,7 +29,7 @@ * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * *****************************************************************************/ - + module reorder_out #( int unsigned W, int unsigned XC, @@ -242,7 +242,7 @@ always_comb begin if(curr_wrY_C > curr_rdY_C) begin cond_go = 1'b1; end - end + end end // -- DP @@ -338,4 +338,4 @@ Q_srl #( ); -endmodule \ No newline at end of file +endmodule diff --git a/finn-rtllib/mmu/replay_buff_mmau.sv b/finn-rtllib/mmu/replay_buff_mmau.sv index ab99bdad5c..121471f163 100644 --- a/finn-rtllib/mmu/replay_buff_mmau.sv +++ b/finn-rtllib/mmu/replay_buff_mmau.sv @@ -73,7 +73,7 @@ logic ivld_int; logic irdy_int; skid #(.DATA_WIDTH(W), .FEED_STAGES(1)) isnt_input_reg ( - .clk(clk), .rst(rst), + .clk(clk), .rst(rst), .ivld(ivld), .irdy(irdy), .idat(idat), .ovld(ivld_int), .ordy(irdy_int), .odat(idat_int) ); @@ -277,7 +277,7 @@ always_comb begin if(curr_wrY_C > curr_rdY_C) begin cond_go = 1'b1; end - end + end end else begin if(curr_wrY_C > curr_rdY_C) begin cond_go = 1'b1; @@ -286,7 +286,7 @@ always_comb begin if(curr_wrX_C > curr_rdX_C) begin cond_go = 1'b1; end - end + end end end @@ -394,4 +394,4 @@ Q_srl #( .o_r(ordy) ); -endmodule \ No newline at end of file +endmodule diff --git a/finn-rtllib/mmu/sched_activations.sv b/finn-rtllib/mmu/sched_activations.sv index bbb989ff4e..29f832fcc1 100644 --- a/finn-rtllib/mmu/sched_activations.sv +++ b/finn-rtllib/mmu/sched_activations.sv @@ -124,7 +124,7 @@ module sched_activations #( if (ovld && ordy) begin // Read from queue for(int i = 0; i < CLEN; i++) begin - q_out_tready[i] = valid_C[i]; + q_out_tready[i] = valid_C[i]; end // Shift ctrl @@ -196,4 +196,4 @@ module sched_activations #( ); -endmodule \ No newline at end of file +endmodule diff --git a/finn-rtllib/mvu_tiled/acc_stage.sv b/finn-rtllib/mvu_tiled/acc_stage.sv index 87214f0c1c..afbb5e1d1b 100644 --- a/finn-rtllib/mvu_tiled/acc_stage.sv +++ b/finn-rtllib/mvu_tiled/acc_stage.sv @@ -50,9 +50,9 @@ module acc_stage #( output logic oval ); -// +// // Adder tree -// +// localparam integer TREE_HEIGHT = $clog2(CHAINLEN); localparam integer ADD_LAT = TREE_HEIGHT + 1; @@ -69,7 +69,7 @@ for(genvar i = 0; i < PE; i++) begin .clk(clk), .rst(rst), .en(en), - + .idat(idat[i]), .iacc(dat_acc[i]), .odat(dat_int[i]) @@ -108,9 +108,9 @@ assign val_int = val[ADD_LAT]; assign last_int = last[ADD_LAT]; assign inc_acc = val[ADD_LAT-1]; -// +// // Accumulation -// +// localparam integer TH_BITS = $clog2(TH); @@ -163,10 +163,10 @@ end always_comb begin fifo_in_tvalid = prep ? 1'b1 : (en ? val_int : 1'b0); - fifo_in_tdata = prep ? 0 : (last_int ? 0 : dat_int); + fifo_in_tdata = prep ? 0 : (last_int ? 0 : dat_int); end assign dat_acc = fifo_out_tdata; assign fifo_out_tready = en & inc_acc; -endmodule : acc_stage \ No newline at end of file +endmodule : acc_stage diff --git a/finn-rtllib/mvu_tiled/add_tree.sv b/finn-rtllib/mvu_tiled/add_tree.sv index 32861e3519..33eba29087 100644 --- a/finn-rtllib/mvu_tiled/add_tree.sv +++ b/finn-rtllib/mvu_tiled/add_tree.sv @@ -124,7 +124,7 @@ else begin end assign add_sf = add_s[TREE_HEIGHT][0]; -end +end logic signed [ACCU_WIDTH-1:0] odat_int = '0; @@ -142,4 +142,4 @@ end assign odat = odat_int; -endmodule : add_tree \ No newline at end of file +endmodule : add_tree diff --git a/finn-rtllib/mvu_tiled/mvu_tiled_axi.sv b/finn-rtllib/mvu_tiled/mvu_tiled_axi.sv index 8854fa7e22..28f0b9eafc 100644 --- a/finn-rtllib/mvu_tiled/mvu_tiled_axi.sv +++ b/finn-rtllib/mvu_tiled/mvu_tiled_axi.sv @@ -67,7 +67,7 @@ module mvu_tiled_axi #( parameter COMPUTE_CORE = "mvu_vvu_8sx9_dsp58", int unsigned N_DCPL_STAGES = 2, - + // Safely deducible parameters localparam int unsigned WSIMD = (PE * SIMD) / TH, localparam int unsigned WEIGHT_STREAM_WIDTH = WSIMD * WEIGHT_WIDTH, @@ -155,7 +155,7 @@ module mvu_tiled_axi #( .TH(TH), .WSIMD(WSIMD), .N_DCPL_STAGES(N_DCPL_STAGES) ) inst_weights_buff_tile ( - .clk(ap_clk), .rst(rst), + .clk(ap_clk), .rst(rst), .ivld(s_axis_weights_tvalid), .irdy(s_axis_weights_tready), .idat(s_axis_weights_tdata), .ovld(wvld), .ordy(wrdy), .odat(wdat) ); @@ -198,7 +198,7 @@ module mvu_tiled_axi #( // // Compute Unit // - + case(COMPUTE_CORE) "mvu_vvu_8sx9_dsp58": begin : core cu_mvau_tiled #( @@ -210,7 +210,7 @@ module mvu_tiled_axi #( .clk(ap_clk), .rst(rst), .en(dsp_en), .ivld(istb), .ilast(dsp_last), .w(dsp_w), .a(dsp_a), .ovld(dsp_vld), .p(dsp_p) - ); + ); end default: initial begin $error("Unrecognized COMPUTE_CORE '%s'", COMPUTE_CORE); @@ -263,7 +263,7 @@ module mvu_tiled_axi #( assign m_axis_int_tdata = { {(OUTPUT_STREAM_WIDTH_BA-OUTPUT_STREAM_WIDTH){B.dat[PE-1][ACCU_WIDTH-1]}}, B.dat}; //-------------------- Output reordering --------------------\\ - + if(OUT_TILED == 0) begin reorder_out #(.W(OUTPUT_STREAM_WIDTH_BA), .XC(MH/PE), .YC(TH)) inst_reorder_out ( .clk(ap_clk), .rst(rst), diff --git a/finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v b/finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v index fe3a73bc8b..36ce60c3b6 100644 --- a/finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v +++ b/finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v @@ -77,7 +77,7 @@ mvu_tiled_axi #( .PE(PE), .SIMD(SIMD), .ACTIVATION_WIDTH(ACTIVATION_WIDTH), .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACCU_WIDTH(ACCU_WIDTH), .MW(MW), .MH(MH), .TH(TH), - .NARROW_WEIGHTS(NARROW_WEIGHTS), .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .PUMPED_COMPUTE(PUMPED_COMPUTE), + .NARROW_WEIGHTS(NARROW_WEIGHTS), .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS), .PUMPED_COMPUTE(PUMPED_COMPUTE), .FORCE_BEHAVIORAL(0) ) inst ( .ap_clk(ap_clk), diff --git a/finn-rtllib/mvu_tiled/reorder_out.sv b/finn-rtllib/mvu_tiled/reorder_out.sv index 19421937e4..1b4b75183e 100644 --- a/finn-rtllib/mvu_tiled/reorder_out.sv +++ b/finn-rtllib/mvu_tiled/reorder_out.sv @@ -29,7 +29,7 @@ * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * *****************************************************************************/ - + module reorder_out #( int unsigned W, int unsigned XC, @@ -242,7 +242,7 @@ always_comb begin if(curr_wrY_C > curr_rdY_C) begin cond_go = 1'b1; end - end + end end // -- DP @@ -338,4 +338,4 @@ Q_srl #( ); -endmodule \ No newline at end of file +endmodule diff --git a/finn-rtllib/mvu_tiled/replay_buff_tile.sv b/finn-rtllib/mvu_tiled/replay_buff_tile.sv index 467e17a3af..ec4618e58d 100644 --- a/finn-rtllib/mvu_tiled/replay_buff_tile.sv +++ b/finn-rtllib/mvu_tiled/replay_buff_tile.sv @@ -73,7 +73,7 @@ logic ivld_int; logic irdy_int; skid #(.DATA_WIDTH(W), .FEED_STAGES(1)) isnt_input_reg ( - .clk(clk), .rst(rst), + .clk(clk), .rst(rst), .ivld(ivld), .irdy(irdy), .idat(idat), .ovld(ivld_int), .ordy(irdy_int), .odat(idat_int) ); @@ -272,7 +272,7 @@ always_comb begin if(curr_wrY_C > curr_rdY_C) begin cond_go = 1'b1; end - end + end end else begin if(curr_wrY_C > curr_rdY_C) begin cond_go = 1'b1; @@ -281,7 +281,7 @@ always_comb begin if(curr_wrX_C > curr_rdX_C) begin cond_go = 1'b1; end - end + end end end diff --git a/finn-rtllib/mvu_tiled/weights_buff_tile.sv b/finn-rtllib/mvu_tiled/weights_buff_tile.sv index 1426411c24..258365f74e 100644 --- a/finn-rtllib/mvu_tiled/weights_buff_tile.sv +++ b/finn-rtllib/mvu_tiled/weights_buff_tile.sv @@ -81,7 +81,7 @@ logic [WSIMD-1:0][WEIGHT_WIDTH-1:0] idat_int; // Ireg skid #(.DATA_WIDTH(WSIMD*WEIGHT_WIDTH), .FEED_STAGES(1)) inst_ireg ( - .clk(clk), .rst(rst), + .clk(clk), .rst(rst), .ivld(ivld), .irdy(irdy), .idat(idat), .ovld(ivld_int), .ordy(irdy_int), .odat(idat_int) ); @@ -252,9 +252,9 @@ end // Oreg skid #(.DATA_WIDTH(PE*SIMD*WEIGHT_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg ( - .clk(clk), .rst(rst), + .clk(clk), .rst(rst), .ivld(ovld_int), .irdy(ordy_int), .idat(odat_int), .ovld(ovld), .ordy(ordy), .odat(odat) ); -endmodule \ No newline at end of file +endmodule diff --git a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py index 1b5d95913a..7bae191b83 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py @@ -366,9 +366,7 @@ def generate_hdl(self, model, fpgapart, clk): def prepare_codegen_default(self, fpgapart, clk): gemm_type = self.get_nodeattr("gemm_type") if gemm_type in ("mmau_1d", "mmau_2d"): - template_path = ( - os.environ["FINN_ROOT"] + "/finn-rtllib/mmu/mmu_axi_wrapper.v" - ) + template_path = os.environ["FINN_ROOT"] + "/finn-rtllib/mmu/mmu_axi_wrapper.v" elif self.get_nodeattr("TH") > 1: template_path = ( os.environ["FINN_ROOT"] + "/finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v" diff --git a/src/finn/util/mlo_sim.py b/src/finn/util/mlo_sim.py index 65f7c35d97..8a1f47abf0 100644 --- a/src/finn/util/mlo_sim.py +++ b/src/finn/util/mlo_sim.py @@ -82,7 +82,7 @@ def mlo_prehook_func_factory(node) -> Callable[[SimEngine], None]: if downstream.op_type.startswith("MVAU"): mvau_hbm_weights[idx] = {} mvau_hbm_weights[idx]["name"] = lb_inp.name - code_gen_dir = finnloop_op.get_nodeattr('code_gen_dir_ipgen') + code_gen_dir = finnloop_op.get_nodeattr("code_gen_dir_ipgen") npy_file = f"{code_gen_dir}/input1_MVAU_rtl_id_{idx}.npy" datfile = f"{code_gen_dir}/memblock_MVAU_rtl_id_{idx}.dat" mvau_op = getCustomOp(downstream) @@ -103,7 +103,7 @@ def mlo_prehook_func_factory(node) -> Callable[[SimEngine], None]: n_iters = len(flat) // words_per_iter weight_bytes = [] for it in range(n_iters): - for row in flat[it * words_per_iter:(it + 1) * words_per_iter]: + for row in flat[it * words_per_iter : (it + 1) * words_per_iter]: for val in row: weight_bytes.append(int(val) & 0xFF) # Pad to layer_offs boundary diff --git a/tests/fpgadataflow/test_fpgadataflow_finnloop.py b/tests/fpgadataflow/test_fpgadataflow_finnloop.py index e48b7afd76..1952a69c67 100644 --- a/tests/fpgadataflow/test_fpgadataflow_finnloop.py +++ b/tests/fpgadataflow/test_fpgadataflow_finnloop.py @@ -455,8 +455,16 @@ def make_single_mvau_loop_body( def create_chained_loop_bodies( - mw, mh, num_copies, elemwise_optype="ElementwiseMul_hls", rhs_shape=[1], eltw_param_dtype="INT8", - mvau_pe=2, mvau_simd=2, mvau_th=1, helper_pe=2, + mw, + mh, + num_copies, + elemwise_optype="ElementwiseMul_hls", + rhs_shape=[1], + eltw_param_dtype="INT8", + mvau_pe=2, + mvau_simd=2, + mvau_th=1, + helper_pe=2, ): loop_body_models = [] @@ -677,8 +685,16 @@ def test_finnloop_end2end_mlo_tiled( if (year, minor) < (2024, 2): pytest.skip("""At least Vivado version 2024.2 needed for MLO.""") loop_body_models = create_chained_loop_bodies( - dim, dim, iteration, elemwise_optype, rhs_shape, eltw_param_dtype, - mvau_pe=mvau_pe, mvau_simd=mvau_simd, mvau_th=mvau_th, helper_pe=helper_pe, + dim, + dim, + iteration, + elemwise_optype, + rhs_shape, + eltw_param_dtype, + mvau_pe=mvau_pe, + mvau_simd=mvau_simd, + mvau_th=mvau_th, + helper_pe=helper_pe, ) model = loop_body_models[0] for m in loop_body_models[1:]: @@ -789,9 +805,7 @@ def test_finnloop_end2end_mlo_tiled( bus = 0 for j, v in enumerate(row): bus |= (int(v) & 0xFF) << (j * 8) - tf.write( - f"[{i:3d}] {dec} | {hx} | 0x{bus:0{arr.shape[-1]*2}x}\n" - ) + tf.write(f"[{i:3d}] {dec} | {hx} | 0x{bus:0{arr.shape[-1]*2}x}\n") print(f"DEBUG: weight dump -> {txt_path}") for f in sorted(glob.glob(code_gen_dir + "/memblock_*.dat")): base = os.path.basename(f).replace(".dat", "") @@ -818,6 +832,7 @@ def test_finnloop_end2end_mlo_tiled( verif_dir + "/verify_stitched_ip_rtlsim_0_SUCCESS.npy" ), f"Check npy files in {verif_dir}" + # Debug test for manual loop transformation steps below # This test is intentionally not marked for CI # Use test_finnloop_end2end_mlo instead From 4d3eaf8989e87bf917e32bff634360352840fae2 Mon Sep 17 00:00:00 2001 From: dkorolij Date: Tue, 19 May 2026 10:35:50 +0000 Subject: [PATCH 04/17] removed MMAU from python code. Kept RTL under mvu_tiled for future integration. --- .../{ => mvu_tiled}/mmu/1d/collect_out_1d.sv | 0 .../{ => mvu_tiled}/mmu/1d/cu_mmau_1d.sv | 0 .../mmu/1d/sched_weights_1d.sv | 0 finn-rtllib/{ => mvu_tiled}/mmu/1d/sft_reg.sv | 0 .../{ => mvu_tiled}/mmu/2d/collect_out_2d.sv | 0 .../{ => mvu_tiled}/mmu/2d/cu_mmau_2d.sv | 0 .../mmu/2d/sched_weights_2d.sv | 0 finn-rtllib/{ => mvu_tiled}/mmu/en_global.sv | 0 finn-rtllib/{ => mvu_tiled}/mmu/mmu_axi.sv | 0 .../{ => mvu_tiled}/mmu/mmu_axi_wrapper.v | 0 finn-rtllib/{ => mvu_tiled}/mmu/q_writer.sv | 0 .../{ => mvu_tiled}/mmu/reorder_out.sv | 0 .../{ => mvu_tiled}/mmu/replay_buff_mmau.sv | 0 .../{ => mvu_tiled}/mmu/sched_activations.sv | 0 src/finn/custom_op/fpgadataflow/hwcustomop.py | 15 +-- .../fpgadataflow/matrixvectoractivation.py | 38 ++----- .../rtl/matrixvectoractivation_rtl.py | 94 ++-------------- tests/fpgadataflow/test_fpgadataflow_mvau.py | 101 ------------------ 18 files changed, 15 insertions(+), 233 deletions(-) rename finn-rtllib/{ => mvu_tiled}/mmu/1d/collect_out_1d.sv (100%) rename finn-rtllib/{ => mvu_tiled}/mmu/1d/cu_mmau_1d.sv (100%) rename finn-rtllib/{ => mvu_tiled}/mmu/1d/sched_weights_1d.sv (100%) rename finn-rtllib/{ => mvu_tiled}/mmu/1d/sft_reg.sv (100%) rename finn-rtllib/{ => mvu_tiled}/mmu/2d/collect_out_2d.sv (100%) rename finn-rtllib/{ => mvu_tiled}/mmu/2d/cu_mmau_2d.sv (100%) rename finn-rtllib/{ => mvu_tiled}/mmu/2d/sched_weights_2d.sv (100%) rename finn-rtllib/{ => mvu_tiled}/mmu/en_global.sv (100%) rename finn-rtllib/{ => mvu_tiled}/mmu/mmu_axi.sv (100%) rename finn-rtllib/{ => mvu_tiled}/mmu/mmu_axi_wrapper.v (100%) rename finn-rtllib/{ => mvu_tiled}/mmu/q_writer.sv (100%) rename finn-rtllib/{ => mvu_tiled}/mmu/reorder_out.sv (100%) rename finn-rtllib/{ => mvu_tiled}/mmu/replay_buff_mmau.sv (100%) rename finn-rtllib/{ => mvu_tiled}/mmu/sched_activations.sv (100%) diff --git a/finn-rtllib/mmu/1d/collect_out_1d.sv b/finn-rtllib/mvu_tiled/mmu/1d/collect_out_1d.sv similarity index 100% rename from finn-rtllib/mmu/1d/collect_out_1d.sv rename to finn-rtllib/mvu_tiled/mmu/1d/collect_out_1d.sv diff --git a/finn-rtllib/mmu/1d/cu_mmau_1d.sv b/finn-rtllib/mvu_tiled/mmu/1d/cu_mmau_1d.sv similarity index 100% rename from finn-rtllib/mmu/1d/cu_mmau_1d.sv rename to finn-rtllib/mvu_tiled/mmu/1d/cu_mmau_1d.sv diff --git a/finn-rtllib/mmu/1d/sched_weights_1d.sv b/finn-rtllib/mvu_tiled/mmu/1d/sched_weights_1d.sv similarity index 100% rename from finn-rtllib/mmu/1d/sched_weights_1d.sv rename to finn-rtllib/mvu_tiled/mmu/1d/sched_weights_1d.sv diff --git a/finn-rtllib/mmu/1d/sft_reg.sv b/finn-rtllib/mvu_tiled/mmu/1d/sft_reg.sv similarity index 100% rename from finn-rtllib/mmu/1d/sft_reg.sv rename to finn-rtllib/mvu_tiled/mmu/1d/sft_reg.sv diff --git a/finn-rtllib/mmu/2d/collect_out_2d.sv b/finn-rtllib/mvu_tiled/mmu/2d/collect_out_2d.sv similarity index 100% rename from finn-rtllib/mmu/2d/collect_out_2d.sv rename to finn-rtllib/mvu_tiled/mmu/2d/collect_out_2d.sv diff --git a/finn-rtllib/mmu/2d/cu_mmau_2d.sv b/finn-rtllib/mvu_tiled/mmu/2d/cu_mmau_2d.sv similarity index 100% rename from finn-rtllib/mmu/2d/cu_mmau_2d.sv rename to finn-rtllib/mvu_tiled/mmu/2d/cu_mmau_2d.sv diff --git a/finn-rtllib/mmu/2d/sched_weights_2d.sv b/finn-rtllib/mvu_tiled/mmu/2d/sched_weights_2d.sv similarity index 100% rename from finn-rtllib/mmu/2d/sched_weights_2d.sv rename to finn-rtllib/mvu_tiled/mmu/2d/sched_weights_2d.sv diff --git a/finn-rtllib/mmu/en_global.sv b/finn-rtllib/mvu_tiled/mmu/en_global.sv similarity index 100% rename from finn-rtllib/mmu/en_global.sv rename to finn-rtllib/mvu_tiled/mmu/en_global.sv diff --git a/finn-rtllib/mmu/mmu_axi.sv b/finn-rtllib/mvu_tiled/mmu/mmu_axi.sv similarity index 100% rename from finn-rtllib/mmu/mmu_axi.sv rename to finn-rtllib/mvu_tiled/mmu/mmu_axi.sv diff --git a/finn-rtllib/mmu/mmu_axi_wrapper.v b/finn-rtllib/mvu_tiled/mmu/mmu_axi_wrapper.v similarity index 100% rename from finn-rtllib/mmu/mmu_axi_wrapper.v rename to finn-rtllib/mvu_tiled/mmu/mmu_axi_wrapper.v diff --git a/finn-rtllib/mmu/q_writer.sv b/finn-rtllib/mvu_tiled/mmu/q_writer.sv similarity index 100% rename from finn-rtllib/mmu/q_writer.sv rename to finn-rtllib/mvu_tiled/mmu/q_writer.sv diff --git a/finn-rtllib/mmu/reorder_out.sv b/finn-rtllib/mvu_tiled/mmu/reorder_out.sv similarity index 100% rename from finn-rtllib/mmu/reorder_out.sv rename to finn-rtllib/mvu_tiled/mmu/reorder_out.sv diff --git a/finn-rtllib/mmu/replay_buff_mmau.sv b/finn-rtllib/mvu_tiled/mmu/replay_buff_mmau.sv similarity index 100% rename from finn-rtllib/mmu/replay_buff_mmau.sv rename to finn-rtllib/mvu_tiled/mmu/replay_buff_mmau.sv diff --git a/finn-rtllib/mmu/sched_activations.sv b/finn-rtllib/mvu_tiled/mmu/sched_activations.sv similarity index 100% rename from finn-rtllib/mmu/sched_activations.sv rename to finn-rtllib/mvu_tiled/mmu/sched_activations.sv diff --git a/src/finn/custom_op/fpgadataflow/hwcustomop.py b/src/finn/custom_op/fpgadataflow/hwcustomop.py index 6fb7abb709..8bd6dbb456 100644 --- a/src/finn/custom_op/fpgadataflow/hwcustomop.py +++ b/src/finn/custom_op/fpgadataflow/hwcustomop.py @@ -360,13 +360,7 @@ def generate_hdl_fetch_weights(self): pe = self.get_nodeattr("PE") simd = self.get_nodeattr("SIMD") theight = self.get_nodeattr("TH") - gemm_type = self.get_nodeattr("gemm_type") - if gemm_type in ("mmau_1d", "mmau_2d"): - cu_simd = self.get_nodeattr("CU_SIMD") - clen = (simd + cu_simd - 1) // cu_simd - n_reps = np.prod(self.get_nodeattr("numInputVectors")) // clen - else: - n_reps = np.prod(self.get_nodeattr("numInputVectors")) // theight + n_reps = np.prod(self.get_nodeattr("numInputVectors")) // theight en_mlo = "EN_MLO" if self.get_nodeattr("mlo_max_iter") else "NO_MLO" else: # Eltwise layers only have one parallelism parameter @@ -385,12 +379,7 @@ def generate_hdl_fetch_weights(self): # Compute IWSIMD and WSIMD for the fetch_weights wrapper if self.onnx_node.op_type in ops: - gemm_type = self.get_nodeattr("gemm_type") - if gemm_type in ("mmau_1d", "mmau_2d"): - cu_simd = self.get_nodeattr("CU_SIMD") - iwsimd = pe * cu_simd - wsimd = pe * cu_simd - elif theight > 1: + if theight > 1: iwsimd = (pe * simd) // theight wsimd = (pe * simd) // theight else: diff --git a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py index f3bb6d402a..22a247c8db 100644 --- a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py +++ b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py @@ -131,9 +131,8 @@ def get_nodeattr_types(self): "s", False, "mvau", - {"mvau", "mvau_tiled", "mmau_1d", "mmau_2d"}, + {"mvau", "mvau_tiled"}, ), - "CU_SIMD": ("i", False, 3), } my_attrs.update(super().get_nodeattr_types()) return my_attrs @@ -274,17 +273,12 @@ def get_instream_width(self, ind=0): # TODO: Hacky, need to clean these calls . wp = self.get_input_datatype(1).bitwidth() mem_mode = self.get_nodeattr("mem_mode") theight = self.get_nodeattr("TH") - gemm_type = self.get_nodeattr("gemm_type") match mem_mode: case "dynamic": width = pe * wp case "external" | "external_mem" | "internal_decoupled": - if gemm_type in ("mmau_1d", "mmau_2d"): - cu_simd = self.get_nodeattr("CU_SIMD") - width = pe * cu_simd * wp - else: - width = ((pe * simd) * wp) // theight + width = ((pe * simd) * wp) // theight case _: width = 0 elif ind == 2: @@ -317,7 +311,6 @@ def get_folded_input_shape(self, ind=0): vecs = list(self.get_nodeattr("numInputVectors")) n_vecs = int(np.prod(vecs)) theight = self.get_nodeattr("TH") - gemm_type = self.get_nodeattr("gemm_type") if ind == 0: # calculate shape of input 0 @@ -327,11 +320,7 @@ def get_folded_input_shape(self, ind=0): case "dynamic": folded_input_shape = (1, mw, nf, pe) case "external" | "external_mem" | "internal_decoupled": - if gemm_type in ("mmau_1d", "mmau_2d"): - cu_simd = self.get_nodeattr("CU_SIMD") - folded_input_shape = (n_vecs, sf * nf, pe * cu_simd) - else: - folded_input_shape = (n_vecs, sf * nf, (simd * pe) // theight) + folded_input_shape = (n_vecs, sf * nf, (simd * pe) // theight) case _: raise Exception("Undefined input shape for requested input") else: @@ -485,14 +474,8 @@ def get_exp_cycles(self): # since mmv != 1 is not supported yet, we set mmv for now to 1 mmv = 1 # Tiling/systolic reduces throughput - gemm_type = self.get_nodeattr("gemm_type") - if gemm_type in ("mmau_1d", "mmau_2d"): - cu_simd = self.get_nodeattr("CU_SIMD") - clen = (simd + cu_simd - 1) // cu_simd - exp_cycles = (mh / pe) * (mw / simd) * np.prod(num_inp_vec) * clen / mmv - else: - # TH>1 (tiling) reduces throughput by factor TH (tinner = PE*SIMD/TH) - exp_cycles = (mh / pe) * (mw / simd) * np.prod(num_inp_vec) * th / mmv + # TH>1 (tiling) reduces throughput by factor TH (tinner = PE*SIMD/TH) + exp_cycles = (mh / pe) * (mw / simd) * np.prod(num_inp_vec) * th / mmv return int(exp_cycles) def minimize_accumulator_width(self, model): @@ -756,11 +739,7 @@ def make_weight_file(self, weights, weight_file_mode, weight_file_name): weight_tensor_pe_flipped = weight_tensor_pe_flipped.reshape(1, -1, pe * simd) weight_tensor_pe_flipped = weight_tensor_pe_flipped.copy() # tiling - gemm_type = self.get_nodeattr("gemm_type") - if gemm_type in ("mmau_1d", "mmau_2d"): - tinner = pe * self.get_nodeattr("CU_SIMD") - else: - tinner = (pe * simd) // self.get_nodeattr("TH") + tinner = (pe * simd) // self.get_nodeattr("TH") weight_tensor_simd_flipped = weight_tensor_simd_flipped.reshape(1, -1, tinner) weight_tensor_pe_flipped = weight_tensor_pe_flipped.reshape(1, -1, tinner) if weight_file_mode == "decoupled_npy": @@ -1111,11 +1090,8 @@ def code_generation_ipi(self): ] # Create Vivado axis_dwidth_converter IP theight = self.get_nodeattr("TH") - gemm_type = self.get_nodeattr("gemm_type") wdt = self.get_input_datatype(1) - if gemm_type in ("mmau_1d", "mmau_2d"): - iwsimd = self.get_nodeattr("PE") * self.get_nodeattr("CU_SIMD") - elif theight > 1: + if theight > 1: iwsimd = (self.get_nodeattr("PE") * self.get_nodeattr("SIMD")) // theight else: iwsimd = self.get_nodeattr("SIMD") diff --git a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py index 7bae191b83..f744b0bca8 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py @@ -108,15 +108,7 @@ def execute_node(self, context, graph): wei = npy_to_rtlsim_input("{}/input_1.npy".format(code_gen_dir), export_wdt, wnbits) num_w_reps = np.prod(self.get_nodeattr("numInputVectors")) - gemm_type = self.get_nodeattr("gemm_type") - if gemm_type in ("mmau_1d", "mmau_2d"): - simd = self.get_nodeattr("SIMD") - cu_simd = self.get_nodeattr("CU_SIMD") - clen = (simd + cu_simd - 1) // cu_simd - num_w_reps = num_w_reps // clen - else: - num_w_reps = num_w_reps // self.get_nodeattr("TH") - + num_w_reps = num_w_reps // self.get_nodeattr("TH") io_dict = { "inputs": {"in0": inp, "in1": wei * num_w_reps}, "outputs": {"out0": []}, @@ -168,36 +160,9 @@ def instantiate_ip(self, cmd): node_name = self.onnx_node.name code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") - gemm_type = self.get_nodeattr("gemm_type") theight = self.get_nodeattr("TH") - if gemm_type in ("mmau_1d", "mmau_2d"): - rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mmu/") - sourcefiles = [ - "../fifo/hdl/Q_srl.v", - "../skid/skid.sv", - "../ram/ram_p_c.sv", - "mmu_axi.sv", - "replay_buff_mmau.sv", - "sched_activations.sv", - "en_global.sv", - "q_writer.sv", - "reorder_out.sv", - ] - if gemm_type == "mmau_1d": - sourcefiles += [ - "1d/cu_mmau_1d.sv", - "1d/sched_weights_1d.sv", - "1d/collect_out_1d.sv", - "1d/sft_reg.sv", - ] - else: - sourcefiles += [ - "2d/cu_mmau_2d.sv", - "2d/sched_weights_2d.sv", - "2d/collect_out_2d.sv", - ] - elif theight > 1: + if theight > 1: rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu_tiled/") sourcefiles = [ "../fifo/hdl/Q_srl.v", @@ -364,10 +329,7 @@ def generate_hdl(self, model, fpgapart, clk): self.set_nodeattr("ip_path", code_gen_dir) def prepare_codegen_default(self, fpgapart, clk): - gemm_type = self.get_nodeattr("gemm_type") - if gemm_type in ("mmau_1d", "mmau_2d"): - template_path = os.environ["FINN_ROOT"] + "/finn-rtllib/mmu/mmu_axi_wrapper.v" - elif self.get_nodeattr("TH") > 1: + if self.get_nodeattr("TH") > 1: template_path = ( os.environ["FINN_ROOT"] + "/finn-rtllib/mvu_tiled/mvu_tiled_axi_wrapper.v" ) @@ -400,22 +362,13 @@ def prepare_codegen_default(self, fpgapart, clk): ) code_gen_dict["$SEGMENTLEN$"] = [str(self._resolve_segment_len(clk))] - # MMU-specific template variables - if gemm_type in ("mmau_1d", "mmau_2d"): - code_gen_dict["$GEMM_TYPE$"] = [gemm_type] - code_gen_dict["$CU_SIMD$"] = [str(self.get_nodeattr("CU_SIMD"))] - n_vectors = int(np.prod(self.get_nodeattr("numInputVectors"))) - code_gen_dict["$N_VECTORS$"] = [str(n_vectors)] - return template_path, code_gen_dict def get_rtl_file_list(self, abspath=False): gemm_type = self.get_nodeattr("gemm_type") if abspath: code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + "/" - if gemm_type in ("mmau_1d", "mmau_2d"): - rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mmu/") - elif self.get_nodeattr("TH") > 1: + if self.get_nodeattr("TH") > 1: rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu_tiled/") else: rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/mvu/") @@ -423,35 +376,7 @@ def get_rtl_file_list(self, abspath=False): code_gen_dir = "" rtllib_dir = "" - if gemm_type in ("mmau_1d", "mmau_2d"): - verilog_files = [ - "../fifo/hdl/Q_srl.v", - "../skid/skid.sv", - "../ram/ram_p_c.sv", - "mmu_axi.sv", - "replay_buff_mmau.sv", - "sched_activations.sv", - "en_global.sv", - "q_writer.sv", - "reorder_out.sv", - ] - if gemm_type == "mmau_1d": - verilog_files += [ - "1d/cu_mmau_1d.sv", - "1d/sched_weights_1d.sv", - "1d/collect_out_1d.sv", - "1d/sft_reg.sv", - ] - else: - verilog_files += [ - "2d/cu_mmau_2d.sv", - "2d/sched_weights_2d.sv", - "2d/collect_out_2d.sv", - ] - verilog_files = [ - os.path.join(code_gen_dir, self.get_nodeattr("gen_top_module") + "_wrapper.v") - ] + [rtllib_dir + _ for _ in verilog_files] - elif self.get_nodeattr("TH") > 1: + if self.get_nodeattr("TH") > 1: verilog_files = [ "../fifo/hdl/Q_srl.v", "../skid/skid.sv", @@ -483,15 +408,8 @@ def get_rtl_file_list(self, abspath=False): return verilog_files def get_verilog_paths(self): - gemm_type = self.get_nodeattr("gemm_type") verilog_paths = super().get_verilog_paths() - if gemm_type in ("mmau_1d", "mmau_2d"): - verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mmu") - if gemm_type == "mmau_1d": - verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mmu/1d") - else: - verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mmu/2d") - elif self.get_nodeattr("TH") > 1: + if self.get_nodeattr("TH") > 1: verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mvu_tiled") else: verilog_paths.append(os.environ["FINN_ROOT"] + "/finn-rtllib/mvu") diff --git a/tests/fpgadataflow/test_fpgadataflow_mvau.py b/tests/fpgadataflow/test_fpgadataflow_mvau.py index f1016f092c..89d90fcb99 100644 --- a/tests/fpgadataflow/test_fpgadataflow_mvau.py +++ b/tests/fpgadataflow/test_fpgadataflow_mvau.py @@ -954,107 +954,6 @@ def test_fpgadataflow_rtl_tiled_mvau(mh, mw, pe, simd, th, idt_wdt, clk_ns): output_matmul == output_mvau_rtl_stitch ).all(), "Output of ONNX model not matching output of tiled stitched-IP RTL model!" - -@pytest.mark.parametrize("mh", [12]) -@pytest.mark.parametrize("mw", [12]) -@pytest.mark.parametrize("pe", [6]) -@pytest.mark.parametrize("simd", [6]) -@pytest.mark.parametrize("cu_simd", [3]) -@pytest.mark.parametrize("idt_wdt", [[DataType["UINT8"], DataType["INT8"]]]) -@pytest.mark.parametrize("clk_ns", [4]) -@pytest.mark.fpgadataflow -@pytest.mark.slow -@pytest.mark.vivado -def test_fpgadataflow_rtl_mmau_2d(mh, mw, pe, simd, cu_simd, idt_wdt, clk_ns): - # MMU only supported on Versal (DSP58) - part = "xcvc1902-vsva2197-2MP-e-S" - - if simd % cu_simd != 0: - pytest.skip("SIMD must be divisible by CU_SIMD") - - if mw % simd != 0: - pytest.skip("MW must be divisible by SIMD") - - if mh % pe != 0: - pytest.skip("MH must be divisible by PE") - - idt, wdt = idt_wdt - # Create test input vector (produced by SWG) - ofm_shape = (3, 3) - ofm_h, ofm_w = ofm_shape - ifm = helper.make_tensor_value_info("ifm", TensorProto.FLOAT, [1, ofm_h, ofm_w, mw]) - ofm = helper.make_tensor_value_info("ofm", TensorProto.FLOAT, (1, ofm_h, ofm_w, mh)) - W = gen_finn_dt_tensor(wdt, (mw, mh)) - model = make_single_matmul_modelwrapper(ifm, ofm, idt, wdt, W) - model = model.transform(GiveUniqueNodeNames()) - model = model.transform(GiveReadableTensorNames()) - - # Create MatMul & obtain golden reference output - A = gen_finn_dt_tensor( - model.get_tensor_datatype("global_in"), model.get_tensor_shape("global_in") - ) - input_dict = prepare_inputs(A, idt, wdt, inp_name="global_in") - - # Execute ONNX model - output_matmul = oxe.execute_onnx(model, input_dict)["global_out"] - - # Create MVAU - model = model.transform(to_hw.InferQuantizedMatrixVectorActivation()) - model = model.transform(GiveUniqueNodeNames()) - - # Apply convert-to-rtl step - model = model.transform(SpecializeLayers(part)) - model = model.transform(GiveUniqueNodeNames()) - - assert model.graph.node[0].op_type == "MVAU_rtl" - # Apply folding with MMU systolic array - folding_config = { - "Defaults": {}, - "MVAU_rtl_0": { - "PE": pe, - "SIMD": simd, - "CU_SIMD": cu_simd, - "gemm_type": "mmau_2d", - "resType": "dsp", - "mem_mode": "external_mem", - }, - } - model = model.transform(ApplyConfig(folding_config)) - model = model.transform(MinimizeWeightBitWidth()) - model = model.transform(MinimizeAccumulatorWidth()) - # make sure the changed datatypes are propagated through the network - model = model.transform(InferDataTypes()) - - # Run CPPsim - model = model.transform(SetExecMode("cppsim")) - model = model.transform(PrepareCppSim()) - model = model.transform(CompileCppSim()) - output_mvau_hls = oxe.execute_onnx(model, input_dict)["global_out"] - assert ( - output_matmul == output_mvau_hls - ).all(), "Output of ONNX model not matching output of node-by-node CPPsim!" - - # Run node-by-node RTLsim - model = model.transform(SetExecMode("rtlsim")) - model = model.transform(PrepareIP(part, clk_ns)) - model = model.transform(HLSSynthIP()) - model = model.transform(PrepareRTLSim()) - output_mvau_rtl = oxe.execute_onnx(model, input_dict)["global_out"] - assert ( - output_matmul == output_mvau_rtl - ).all(), "Output of ONNX model not matching output of node-by-node RTLsim!" - - # Run stitched-ip RTLsim - model = model.transform(InsertAndSetFIFODepths(part, clk_ns)) - model = model.transform(PrepareIP(part, clk_ns)) - model = model.transform(HLSSynthIP()) - model = model.transform(CreateStitchedIP(part, clk_ns)) - output_mvau_rtl_stitch = oxe.execute_onnx(model, input_dict)["global_out"] - assert ( - output_matmul == output_mvau_rtl_stitch - ).all(), "Output of ONNX model not matching output of MMU stitched-IP RTL model!" - - @pytest.mark.parametrize("mh", [32]) @pytest.mark.parametrize("mw", [16]) @pytest.mark.parametrize("n_vectors", [32]) From ad4cb78cc8b67494c856a8b134abd6c7f8f26377 Mon Sep 17 00:00:00 2001 From: dkorolij Date: Tue, 19 May 2026 10:47:09 +0000 Subject: [PATCH 05/17] pre-commit cleanup. --- src/finn/core/rtlsim_exec.py | 3 +-- .../custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py | 1 - tests/fpgadataflow/test_fpgadataflow_finnloop.py | 3 +-- tests/fpgadataflow/test_fpgadataflow_mvau.py | 1 + 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/finn/core/rtlsim_exec.py b/src/finn/core/rtlsim_exec.py index a7555fb183..db19e3f730 100644 --- a/src/finn/core/rtlsim_exec.py +++ b/src/finn/core/rtlsim_exec.py @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import json import numpy as np import os from qonnx.custom_op.registry import getCustomOp @@ -362,8 +363,6 @@ def rtlsim_exec_finnxsi(model, execution_context, pre_hook=None, post_hook=None) # automatically load AXI-MM weight images for external_mem nodes aximm_weights_json = model.get_metadata_prop("vivado_stitch_aximm_weights") if aximm_weights_json is not None: - import json - aximm_weights = json.loads(aximm_weights_json) for aximm_name, npy_path in aximm_weights.items(): weight_npy = np.load(npy_path) diff --git a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py index f744b0bca8..3d1f0d764f 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py @@ -365,7 +365,6 @@ def prepare_codegen_default(self, fpgapart, clk): return template_path, code_gen_dict def get_rtl_file_list(self, abspath=False): - gemm_type = self.get_nodeattr("gemm_type") if abspath: code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") + "/" if self.get_nodeattr("TH") > 1: diff --git a/tests/fpgadataflow/test_fpgadataflow_finnloop.py b/tests/fpgadataflow/test_fpgadataflow_finnloop.py index 447b4ce0e3..ec822caa68 100644 --- a/tests/fpgadataflow/test_fpgadataflow_finnloop.py +++ b/tests/fpgadataflow/test_fpgadataflow_finnloop.py @@ -1,5 +1,6 @@ import pytest +import glob import numpy as np import os import re @@ -858,8 +859,6 @@ def test_finnloop_end2end_mlo_tiled( build.build_dataflow_cfg(tmp_output_dir + "/mlo_model.onnx", cfg) # Dump weight files for hardware debug - import glob - built_model = ModelWrapper(tmp_output_dir + "/mlo_model.onnx") for node in built_model.graph.node: if node.op_type == "FINNLoop": diff --git a/tests/fpgadataflow/test_fpgadataflow_mvau.py b/tests/fpgadataflow/test_fpgadataflow_mvau.py index 0afe15b3a7..378ad17b22 100644 --- a/tests/fpgadataflow/test_fpgadataflow_mvau.py +++ b/tests/fpgadataflow/test_fpgadataflow_mvau.py @@ -955,6 +955,7 @@ def test_fpgadataflow_rtl_tiled_mvau(mh, mw, pe, simd, th, idt_wdt, clk_ns): output_matmul == output_mvau_rtl_stitch ).all(), "Output of ONNX model not matching output of tiled stitched-IP RTL model!" + @pytest.mark.parametrize("mh", [32]) @pytest.mark.parametrize("mw", [16]) @pytest.mark.parametrize("n_vectors", [32]) From 4a756f91bfa0e5bd7ecd990dc6fc4d1be22e0539 Mon Sep 17 00:00:00 2001 From: dkorolij Date: Wed, 20 May 2026 12:07:53 +0100 Subject: [PATCH 06/17] hls mvau interface update. --- .../hls/matrixvectoractivation_hls.py | 68 ++++++------------- 1 file changed, 19 insertions(+), 49 deletions(-) diff --git a/src/finn/custom_op/fpgadataflow/hls/matrixvectoractivation_hls.py b/src/finn/custom_op/fpgadataflow/hls/matrixvectoractivation_hls.py index de73a54464..cacdf45134 100644 --- a/src/finn/custom_op/fpgadataflow/hls/matrixvectoractivation_hls.py +++ b/src/finn/custom_op/fpgadataflow/hls/matrixvectoractivation_hls.py @@ -139,11 +139,10 @@ def dsp_estimation(self, fpgapart): def code_generation_ipgen(self, model, fpgapart, clk): """Generates c++ code and tcl script for ip generation.""" super().code_generation_ipgen(model, fpgapart, clk) - dynamic_input = self.get_nodeattr("dynamic_input") mem_mode = self.get_nodeattr("mem_mode") - if dynamic_input: + if mem_mode == "dynamic": self.generate_hdl_dynload() - if mem_mode == "internal_decoupled" and not self.get_nodeattr("mlo_max_iter"): + if mem_mode == "internal_decoupled": if self.get_nodeattr("ram_style") == "ultra" and not is_versal(fpgapart): runtime_writeable = self.get_nodeattr("runtime_writeable_weights") assert ( @@ -151,7 +150,7 @@ def code_generation_ipgen(self, model, fpgapart, clk): ), """Layer with URAM weights must have runtime_writeable_weights=1 if Ultrascale device is targeted.""" self.generate_hdl_memstream(fpgapart, pumped_memory=self.get_nodeattr("pumpedMemory")) - elif self.get_nodeattr("mlo_max_iter"): + elif mem_mode == "external_mem": self.generate_hdl_fetch_weights(fpgapart) def get_template_param_values(self): @@ -236,11 +235,7 @@ def defines(self, var): numReps, ) ] - if ( - mem_mode == "internal_decoupled" - or mem_mode == "external" - or self.get_nodeattr("mlo_max_iter") - ): + if mem_mode in ["internal_decoupled", "external", "external_mem", "dynamic"]: wdt = self.get_input_datatype(1) self.code_gen_dict["$DEFINES$"].append("#define WP1 {}\n".format(wdt.bitwidth())) @@ -270,15 +265,11 @@ def read_npy_data(self): ) mem_mode = self.get_nodeattr("mem_mode") - if ( - mem_mode == "internal_decoupled" - or mem_mode == "external" - or self.get_nodeattr("mlo_max_iter") - ): + if mem_mode in ["internal_decoupled", "external", "external_mem", "dynamic"]: wdt = self.get_input_datatype(1) elem_bits = wdt.bitwidth() packed_bits = self.get_instream_width(1) - if self.get_nodeattr("dynamic_input"): + if mem_mode == "dynamic": packed_bits = packed_bits * self.get_nodeattr("SIMD") packed_hls_type = "ap_uint<%d>" % packed_bits elem_hls_type = wdt.get_hls_datatype_str() @@ -306,13 +297,9 @@ def strm_decl(self): 'hls::stream> out0_V ("out0_V");'.format(self.get_outstream_width()) ) - if ( - mem_mode == "internal_decoupled" - or mem_mode == "external" - or self.get_nodeattr("mlo_max_iter") - ): + if mem_mode in ["internal_decoupled", "external", "external_mem", "dynamic"]: iwidth = self.get_instream_width(1) - if self.get_nodeattr("dynamic_input"): + if mem_mode == "dynamic": iwidth = iwidth * self.get_nodeattr("SIMD") self.code_gen_dict["$STREAMDECLARATIONS$"].append( 'hls::stream> in1_V ("in1_V");'.format(iwidth) @@ -342,11 +329,7 @@ def docompute(self): map_to_hls_mult_style[self.get_nodeattr("resType")], ) ] - elif ( - mem_mode == "internal_decoupled" - or mem_mode == "external" - or self.get_nodeattr("mlo_max_iter") - ): + elif mem_mode in ["internal_decoupled", "external", "external_mem", "dynamic"]: wdt = self.get_input_datatype(1) if wdt == DataType["BIPOLAR"]: export_wdt = DataType["BINARY"] @@ -414,13 +397,9 @@ def blackboxfunction(self): self.get_outstream_width(), ) ] - elif ( - mem_mode == "internal_decoupled" - or mem_mode == "external" - or self.get_nodeattr("mlo_max_iter") - ): + elif mem_mode in ["internal_decoupled", "external", "external_mem", "dynamic"]: wwidth = self.get_instream_width(1) - if self.get_nodeattr("dynamic_input"): + if mem_mode == "dynamic": wwidth = wwidth * self.get_nodeattr("SIMD") self.code_gen_dict["$BLACKBOXFUNCTION$"] = [ """void {}( @@ -455,11 +434,7 @@ def pragmas(self): self.code_gen_dict["$PRAGMAS$"].append( ("#pragma HLS ARRAY_PARTITION variable=weights.m_weights " "complete dim=1") ) - elif ( - mem_mode == "internal_decoupled" - or mem_mode == "external" - or self.get_nodeattr("mlo_max_iter") - ): + elif mem_mode in ["internal_decoupled", "external", "external_mem", "dynamic"]: self.code_gen_dict["$PRAGMAS$"].append("#pragma HLS INTERFACE axis port=in1_V") else: @@ -500,7 +475,8 @@ def get_ap_int_max_w(self): # internal_decoupled mode weight stream weightstream = self.get_instream_width(1) simd = self.get_nodeattr("SIMD") - if self.get_nodeattr("dynamic_input"): + mem_mode = self.get_nodeattr("mem_mode") + if mem_mode == "dynamic": weightstream = weightstream * simd # single PE weight entry weight_bits = self.get_input_datatype(1).bitwidth() @@ -509,7 +485,6 @@ def get_ap_int_max_w(self): def execute_node(self, context, graph): mode = self.get_nodeattr("exec_mode") - dynamic_input = self.get_nodeattr("dynamic_input") mem_mode = self.get_nodeattr("mem_mode") node = self.onnx_node @@ -553,7 +528,7 @@ def execute_node(self, context, graph): ) if in_ind == 1: - if dynamic_input: + if mem_mode in ["dynamic", "external", "internal_decoupled", "external_mem"]: reshaped_input = context[inputs].reshape(-1, context[inputs].shape[-1]) self.make_weight_file( reshaped_input, "decoupled_npy", "{}/input_1.npy".format(code_gen_dir) @@ -578,13 +553,9 @@ def execute_node(self, context, graph): inp = npy_to_rtlsim_input("{}/input_0.npy".format(code_gen_dir), export_idt, nbits) self.reset_rtlsim(sim) - if ( - dynamic_input - or mem_mode in ["external", "internal_decoupled"] - or self.get_nodeattr("mlo_max_iter") - ): + if mem_mode in ["external", "internal_decoupled", "external_mem", "dynamic"]: wnbits = self.get_instream_width(1) - if self.get_nodeattr("dynamic_input"): + if mem_mode == "dynamic": wnbits = wnbits * self.get_nodeattr("SIMD") export_wdt = self.get_input_datatype(1) @@ -688,9 +659,8 @@ def instantiate_ip(self, cmd): # instantiate the HLS IP vlnv = self.get_nodeattr("ip_vlnv") node_name = self.onnx_node.name - if self.get_nodeattr("mem_mode") == "internal_decoupled" or self.get_nodeattr( - "mlo_max_iter" - ): + mem_mode = self.get_nodeattr("mem_mode") + if mem_mode in ["internal_decoupled", "external_mem", "dynamic"]: cmd.append("create_bd_cell -type ip -vlnv %s /%s/%s" % (vlnv, node_name, node_name)) else: cmd.append("create_bd_cell -type ip -vlnv %s %s" % (vlnv, node_name)) From c23b1a1ebe07001ade31bb8ad3a0c7b7fa6a8863 Mon Sep 17 00:00:00 2001 From: dkorolij Date: Wed, 20 May 2026 12:11:26 +0100 Subject: [PATCH 07/17] revert the folded_input_shape for dynamic matmuls. --- src/finn/custom_op/fpgadataflow/matrixvectoractivation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py index 82f59dd94b..444f2420dd 100644 --- a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py +++ b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py @@ -318,7 +318,7 @@ def get_folded_input_shape(self, ind=0): elif ind == 1: match mem_mode: case "dynamic": - folded_input_shape = (1, mw, nf, pe) + folded_input_shape = tuple(vecs[:2] + [mw] + [nf, pe]) case "external" | "external_mem" | "internal_decoupled": folded_input_shape = (n_vecs, sf * nf, (simd * pe) // theight) case _: From f7c9a2a10f4273d9262e97994c394cfaaf59bf21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20B=2E=20Preu=C3=9Fer?= Date: Mon, 1 Jun 2026 15:08:14 +0100 Subject: [PATCH 08/17] Refine replacing replay and transposes by input_gen and adding a testbench. --- finn-rtllib/mvu_tiled/acc_stage.sv | 35 +- finn-rtllib/mvu_tiled/cu_mvau_tiled.sv | 19 +- finn-rtllib/mvu_tiled/input_gen.sv | 265 ++++++++++++ finn-rtllib/mvu_tiled/mmu/mmu_axi.sv | 29 +- finn-rtllib/mvu_tiled/mmu/reorder_out.sv | 341 --------------- finn-rtllib/mvu_tiled/mmu/replay_buff_mmau.sv | 397 ------------------ finn-rtllib/mvu_tiled/mvu_tiled_axi.sv | 34 +- finn-rtllib/mvu_tiled/reorder_out.sv | 341 --------------- finn-rtllib/mvu_tiled/replay_buff_tile.sv | 392 ----------------- finn-rtllib/mvu_tiled/tb/mvu_tiled_axi_tb.sv | 341 +++++++++++++++ .../rtl/matrixvectoractivation_rtl.py | 8 +- 11 files changed, 680 insertions(+), 1522 deletions(-) create mode 100644 finn-rtllib/mvu_tiled/input_gen.sv delete mode 100644 finn-rtllib/mvu_tiled/mmu/reorder_out.sv delete mode 100644 finn-rtllib/mvu_tiled/mmu/replay_buff_mmau.sv delete mode 100644 finn-rtllib/mvu_tiled/reorder_out.sv delete mode 100644 finn-rtllib/mvu_tiled/replay_buff_tile.sv create mode 100644 finn-rtllib/mvu_tiled/tb/mvu_tiled_axi_tb.sv diff --git a/finn-rtllib/mvu_tiled/acc_stage.sv b/finn-rtllib/mvu_tiled/acc_stage.sv index afbb5e1d1b..d20607a906 100644 --- a/finn-rtllib/mvu_tiled/acc_stage.sv +++ b/finn-rtllib/mvu_tiled/acc_stage.sv @@ -112,11 +112,6 @@ assign inc_acc = val[ADD_LAT-1]; // Accumulation // -localparam integer TH_BITS = $clog2(TH); - -logic [TH_BITS-1:0] cnt_prep = 0; -logic prep = 1'b1; - logic fifo_in_tvalid, fifo_in_tready; logic fifo_out_tvalid, fifo_out_tready; logic [PE*ACCU_WIDTH-1:0] fifo_in_tdata, fifo_out_tdata; @@ -137,27 +132,21 @@ Q_srl #( .maxcount() ); +logic signed [$clog2(TH):0] cnt_prep = -TH; +uwire prep = cnt_prep[$left(cnt_prep)]; always_ff @(posedge clk) begin - if(rst) begin - cnt_prep <= 0; - prep <= 1'b1; + if(rst) cnt_prep <= -TH; + else cnt_prep <= cnt_prep + prep; +end - odat <= 'X; - oval <= 1'b0; +always_ff @(posedge clk) begin + if(rst) begin + odat <= 'x; + oval <= 0; end - else begin - if(cnt_prep == TH-1) begin - prep <= 1'b0; - cnt_prep <= 0; - end - else begin - cnt_prep <= cnt_prep + 1; - end - - if(en) begin - odat <= dat_int; - oval <= val_int && last_int; - end + else if(en) begin + odat <= dat_int; + oval <= val_int && last_int; end end diff --git a/finn-rtllib/mvu_tiled/cu_mvau_tiled.sv b/finn-rtllib/mvu_tiled/cu_mvau_tiled.sv index 43f697dff7..b716fe22b4 100644 --- a/finn-rtllib/mvu_tiled/cu_mvau_tiled.sv +++ b/finn-rtllib/mvu_tiled/cu_mvau_tiled.sv @@ -173,11 +173,12 @@ module cu_mvau_tiled #( .IS_CLK_INVERTED(1'b0), // Optional inversion for CLK .IS_INMODE_INVERTED(5'b00000), // Optional inversion for INMODE .IS_NEGATE_INVERTED(3'b000), // Optional inversion for NEGATE - .IS_OPMODE_INVERTED({ 2'b00 , // W: LAST ? (L[1] ? 0 : P) : 0 - 3'b000, // Z: FIRST ? 0 : PCIN - 2'b01, // Y: M - 2'b01 // X: M - }), // Optional inversion for OPMODE + .IS_OPMODE_INVERTED({ + 2'b00, // W: 0 (unused, accumulation is external) + 3'b000, // Z: 0 (unused) + 2'b01, // Y: M (multiply) + 2'b01 // X: M (multiply) + }), // Static OPMODE='0 inverted to select P = M (multiply-only) .IS_RSTALLCARRYIN_INVERTED(1'b0), // Optional inversion for RSTALLCARRYIN .IS_RSTALUMODE_INVERTED(1'b0), // Optional inversion for RSTALUMODE .IS_RSTA_INVERTED(1'b0), // Optional inversion for RSTA @@ -201,7 +202,7 @@ module cu_mvau_tiled #( .DREG(0), // Pipeline stages for D (0-1) .INMODEREG(1), // Pipeline stages for INMODE (0-1) .MREG(1), // Multiplier pipeline stages (0-1) - .OPMODEREG(1), // Pipeline stages for OPMODE (0-1) + .OPMODEREG(0), // No register needed: OPMODE is static .PREG(PREG), // Number of pipeline stages for P (0-1) .RESET_MODE("SYNC") // Selection of synchronous or asynchronous reset. (ASYNC, SYNC). ) @@ -238,7 +239,7 @@ module cu_mvau_tiled #( INTERNAL_REGS==2 ? 1'b0 : 1'b1 }), // 5-bit input: INMODE control .NEGATE('0), // 3-bit input: Negates the input of the multiplier - .OPMODE('0), // 9-bit input: Operation mode + .OPMODE('0), // 9-bit input: Static (inverted to X=Y=M, W=Z=0) // Data inputs: Data Ports .A({ 7'bx, a_in_i[j] }), // 34-bit input: A data .B(b_in_i[i][j]), // 24-bit input: B data @@ -255,7 +256,7 @@ module cu_mvau_tiled #( .CEB2(INTERNAL_REGS==2 ? en : '0), // 1-bit input: Clock enable for 2nd stage BREG .CEC('0), // 1-bit input: Clock enable for CREG .CECARRYIN('0), // 1-bit input: Clock enable for CARRYINREG - .CECTRL(en), // 1-bit input: Clock enable for OPMODEREG and CARRYINSELREG + .CECTRL('0), // 1-bit input: Clock enable for OPMODEREG and CARRYINSELREG .CED('0), // 1-bit input: Clock enable for DREG .CEINMODE(en), // 1-bit input: Clock enable for INMODEREG .CEM(en), // 1-bit input: Clock enable for MREG @@ -265,7 +266,7 @@ module cu_mvau_tiled #( .RSTALUMODE('0), // 1-bit input: Reset for ALUMODEREG .RSTB('0), // 1-bit input: Reset for BREG .RSTC('0), // 1-bit input: Reset for CREG - .RSTCTRL(rst), // 1-bit input: Reset for OPMODEREG and CARRYINSELREG + .RSTCTRL('0), // 1-bit input: Reset for OPMODEREG and CARRYINSELREG .RSTD('0), // 1-bit input: Reset for DREG and ADREG .RSTINMODE(rst), // 1-bit input: Reset for INMODE register .RSTM('0), // 1-bit input: Reset for MREG diff --git a/finn-rtllib/mvu_tiled/input_gen.sv b/finn-rtllib/mvu_tiled/input_gen.sv new file mode 100644 index 0000000000..16e2a443fd --- /dev/null +++ b/finn-rtllib/mvu_tiled/input_gen.sv @@ -0,0 +1,265 @@ +/**************************************************************************** + * Copyright Advanced Micro Devices, Inc. + * SPDX-License-Identifier: BSD-3-Clause + * + * @author Thomas B. Preußer + * @brief + * Generic sliding window / input generator driven by a perfect loop nest. + * + * A loop nest: + * + * for(i0 = 0; i0 < DIMS[0]; i0++) + * for(i1 = 0; i1 < DIMS[1]; i1++) + * ... + * for(in = 0; in < DIMS[D-1]; in++) + * emit(buf[COEFS[0]*i0 + COEFS[1]*i1 + ... + COEFS[D-1]*in]) + * + * is encoded by the array parameters DIMS and COEFS. The module reads + * a linear input stream into a circular buffer and replays elements + * according to the loop nest addressing, supporting arbitrary strides, + * dilations, and transpositions. + * + * FM_SIZE is the number of input elements per feature map (period of the + * input stream). The olst output exposes the level-completion cascade + * term[D-1:0] synchronous with each output beat. + ***************************************************************************/ + +module input_gen #( + int unsigned DATA_WIDTH, + int unsigned FM_SIZE, + int unsigned D, + int unsigned DIMS[D], + int unsigned COEFS[D] +)( + input logic clk, + input logic rst, + + // Input Stream + input logic [DATA_WIDTH-1:0] idat, + input logic ivld, + output logic irdy, + + // Output Stream + output logic [DATA_WIDTH-1:0] odat, + output logic ovld, + output logic [D-1:0] olst, + output logic [D-1:0] odone, + input logic ordy +); + + //=== Parameter Validation ============================================== + initial begin + if(D == 0) begin + $error("%m: D must be at least 1."); + $finish; + end + for(int unsigned i = 0; i < D; i++) begin + if(DIMS[i] == 0) begin + $error("%m: DIMS[%0d] must be positive.", i); + $finish; + end + end + end + + //=== Elaboration-Time Nest Computations ================================ + // Parent coefficient per level (W in the HLS Nest<> encoding): + // W[0] = FM_SIZE, W[i>0] = COEFS[i-1]. + typedef int unsigned w_arr_t[D+1]; + function automatic w_arr_t INIT_W(); + automatic w_arr_t a; + a[0] = FM_SIZE; + for(int unsigned i = 0; i < D; i++) a[i+1] = COEFS[i]; + return a; + endfunction : INIT_W + localparam w_arr_t W = INIT_W(); + + // Free-pointer responsibility flag per level. + // R_FLAG[i] is the R flag passed into level i from its parent. + typedef bit r_flag_arr_t[D+1]; + function automatic r_flag_arr_t INIT_R_FLAG(); + automatic r_flag_arr_t a; + a[0] = 1; + for(int unsigned i = 1; i <= D; i++) + a[i] = a[i-1] && (COEFS[i-1] > 0) + && (COEFS[i-1] * DIMS[i-1] <= W[i-1]); + return a; + endfunction : INIT_R_FLAG + localparam r_flag_arr_t R_FLAG = INIT_R_FLAG(); + + // Terminal read-pointer increment when level i completes. + // Index D covers the default innermost-advance case. + typedef int rp_inc_arr_t[D+1]; + function automatic rp_inc_arr_t INIT_RP_INC(); + automatic rp_inc_arr_t a; + automatic int unsigned rw = 0; // cumulative rp_rewind, built inside out + for(int i = D; i >= 0; i--) begin + if(i < int'(D)) rw = (DIMS[i]-1) * COEFS[i] + rw; + a[i] = int'(W[i]) - int'(rw); + end + return a; + endfunction : INIT_RP_INC + localparam rp_inc_arr_t TERMINAL_RP_INC = INIT_RP_INC(); + + // Negated terminal free-pointer increment when level i completes. + // Stored negated for direct use in the negated capacity counter. + // Index D covers the default innermost-advance case. + typedef int fp_inc_arr_t[D+1]; + function automatic fp_inc_arr_t INIT_FP_INC(); + automatic fp_inc_arr_t a; + automatic int unsigned fw = 0; // cumulative fp_rewind, built inside out + for(int i = D; i >= 0; i--) begin + if(i < int'(D)) fw = R_FLAG[i+1]? (DIMS[i]-1) * COEFS[i] + fw : 0; + a[i] = R_FLAG[i]? int'(fw) - int'(W[i]) : 0; + end + return a; + endfunction : INIT_FP_INC + localparam fp_inc_arr_t TERMINAL_FP_INC = INIT_FP_INC(); + + // Maximum buffer capacity requirement: the larger of the max backward + // read-pointer retraction and the max read-free pointer gap. + function automatic int unsigned INIT_MAX_OCCUPANCY(); + automatic int unsigned m = 0; + automatic int unsigned rw = 0; + automatic int unsigned fw = 0; + for(int unsigned i = 0; i < D; i++) begin + automatic int t = -TERMINAL_RP_INC[i]; + if(t > int'(m)) m = t; + end + for(int i = D-1; i >= 0; i--) begin + rw = (DIMS[i]-1) * COEFS[i] + rw; + fw = R_FLAG[i+1]? (DIMS[i]-1) * COEFS[i] + fw : 0; + if(rw - fw > m) m = rw - fw; + end + return m; + endfunction : INIT_MAX_OCCUPANCY + + //=== Buffer Sizing ===================================================== + localparam int unsigned WP_DELAY = 1; + localparam int unsigned MAX_OCCUPANCY = INIT_MAX_OCCUPANCY(); + localparam int unsigned ADDR_BITS = $clog2(MAX_OCCUPANCY + WP_DELAY + 2); + localparam int unsigned BUF_SIZE = 1 << ADDR_BITS; + + // Pointer type: one extra bit for signed wrap-around detection. + typedef logic signed [ADDR_BITS:0] ptr_t; + + // Pointer increment type: must accommodate the largest absolute increment. + function automatic int unsigned INIT_MAX_ABS_INC(); + automatic int unsigned m = 0; + for(int unsigned i = 0; i <= D; i++) begin + automatic int unsigned rp_abs = TERMINAL_RP_INC[i] < 0? -TERMINAL_RP_INC[i] : TERMINAL_RP_INC[i]; + automatic int unsigned fp_abs = TERMINAL_FP_INC[i] < 0? -TERMINAL_FP_INC[i] : TERMINAL_FP_INC[i]; + if(rp_abs > m) m = rp_abs; + if(fp_abs > m) m = fp_abs; + end + return m; + endfunction : INIT_MAX_ABS_INC + localparam int unsigned INC_BITS = 1 + $clog2(INIT_MAX_ABS_INC() + 1); + typedef logic signed [INC_BITS-1:0] inc_t; + + //=== Nest Counters ===================================================== + // done[i]: level i has exhausted its iterations (sign-bit of Cnt). + // term[i]: level i and all inner levels completed simultaneously. + uwire [D:0] done; + uwire [D:0] term; + assign done[D] = 1; + assign term[D] = 1; + + uwire advance; // forward-declared, defined in output section + + for(genvar i = 0; i < D; i++) begin : genCnt + uwire step = advance && term[i+1]; + + if(DIMS[i] == 1) begin : genTrivial + assign done[i] = 1; + end : genTrivial + else begin : genCounter + logic signed [$clog2(DIMS[i]-1):0] Cnt = DIMS[i]-2; // DIMS[i]-2, ..., 1, 0, -1 (done) + always_ff @(posedge clk) begin + if(rst) Cnt <= DIMS[i]-2; + else if(step) Cnt <= Cnt + (done[i]? $signed(DIMS[i])-1 : -1); + end + assign done[i] = Cnt[$left(Cnt)]; + end : genCounter + + assign term[i] = term[i+1] && done[i]; + end : genCnt + + //=== Pointer Increment Mux (Combinational) ============================= + inc_t rp_inc; + inc_t fp_inc; + always_comb begin + rp_inc = 0; + fp_inc = 0; + for(int i = D; i >= 0; i--) begin + if(term[i]) begin + rp_inc = TERMINAL_RP_INC[i]; + if(R_FLAG[i]) fp_inc = TERMINAL_FP_INC[i]; + end + end + end + + //=== Circular Buffer and Pointer Management ============================ + logic [DATA_WIDTH-1:0] Buf[BUF_SIZE]; + ptr_t Wp = 0; + ptr_t WpZ = 0; + ptr_t Rp = 0; + ptr_t Cap = -BUF_SIZE+1; // -BUF_SIZE+1, ..., -1, 0 (full) + + uwire has_data = $signed(Rp - WpZ) < 0; + + assign irdy = Cap[$left(Cap)]; + + // Buffer memory — one write port, one registered read port. + // Speculative pre-fetch: on advance, read from the next Rp so that + // BufRd is ready without a settling cycle. + logic [DATA_WIDTH-1:0] BufRd; + uwire ptr_t rd_ptr = Rp + (advance? ptr_t'(rp_inc) : ptr_t'(0)); + always_ff @(posedge clk) begin + if(irdy) Buf[Wp[ADDR_BITS-1:0]] <= idat; + BufRd <= Buf[rd_ptr[ADDR_BITS-1:0]]; + end + + always_ff @(posedge clk) begin + if(rst) begin + Wp <= 0; + WpZ <= 0; + Rp <= 0; + Cap <= -BUF_SIZE+1; + end + else begin + automatic logic istep = irdy && ivld; + WpZ <= Wp; + Wp <= Wp + istep; + Cap <= Cap + (advance? ptr_t'(fp_inc) : ptr_t'(0)) + istep; + if(advance) Rp <= Rp + ptr_t'(rp_inc); + end + end + + //=== Output Stage ====================================================== + logic OVld = 0; + logic [DATA_WIDTH-1:0] OBuf = 'x; + logic [D-1:0] OLst = 'x; + logic [D-1:0] ODone = 'x; + always_ff @(posedge clk) begin + if(rst) begin + OVld <= 0; + OBuf <= 'x; + OLst <= 'x; + ODone <= 'x; + end + else if(!OVld || ordy) begin + OVld <= has_data; + OBuf <= BufRd; + OLst <= term[D-1:0]; + ODone <= done[D-1:0]; + end + end + + assign advance = has_data && (!OVld || ordy); + + assign odat = OBuf; + assign ovld = OVld; + assign olst = OLst; + assign odone = ODone; + +endmodule : input_gen diff --git a/finn-rtllib/mvu_tiled/mmu/mmu_axi.sv b/finn-rtllib/mvu_tiled/mmu/mmu_axi.sv index 7e0b3f14bb..77d07d9db2 100644 --- a/finn-rtllib/mvu_tiled/mmu/mmu_axi.sv +++ b/finn-rtllib/mvu_tiled/mmu/mmu_axi.sv @@ -130,11 +130,20 @@ module mmu_axi #( logic act_s0_tlast; logic act_s0_tvalid, act_s0_tready; - replay_buff_mmau #(.XC(SF), .YC(CLEN), .W(SIMD*ACTIVATION_WIDTH), .N_REPS(NF), .IO_TILED(IN_TILED)) activation_replay ( + uwire [2:0] act_done; + input_gen #( + .DATA_WIDTH(SIMD*ACTIVATION_WIDTH), + .FM_SIZE(SF * CLEN), + .D(3), + .DIMS('{NF, SF, CLEN}), + .COEFS(IN_TILED ? '{0, CLEN, 1} : '{0, 1, SF}) + ) activation_replay ( .clk(ap_clk), .rst(~ap_rst_n), - .ivld(s_axis_input_tvalid), .irdy(s_axis_input_tready), .idat(s_axis_input_tdata), - .ovld(act_s0_tvalid), .ordy(act_s0_tready), .odat(act_s0_tdata), .olast(act_s0_tlast) + .idat(s_axis_input_tdata), + .ivld(s_axis_input_tvalid), .irdy(s_axis_input_tready), + .odat(act_s0_tdata), .ovld(act_s0_tvalid), .olst(), .odone(act_done), .ordy(act_s0_tready) ); + assign act_s0_tlast = act_done[1]; if (ASIMD > SIMD) assign act_s0_tdata_mod[ASIMD-1:SIMD] = '0; @@ -238,10 +247,18 @@ end // Reorder // --------------------------------------------------------------------- if(OUT_TILED == 0) begin - reorder_out #(.W(OUTPUT_STREAM_WIDTH_BA), .XC(NF), .YC(CLEN)) inst_reorder_out ( + input_gen #( + .DATA_WIDTH(OUTPUT_STREAM_WIDTH_BA), + .FM_SIZE(NF * CLEN), + .D(2), + .DIMS('{CLEN, NF}), + .COEFS('{1, CLEN}) + ) inst_reorder_out ( .clk(ap_clk), .rst(~ap_rst_n), - .ivld(p_tvalid), .irdy(p_tready), .idat(p_tdata), - .ovld(m_axis_output_tvalid), .ordy(m_axis_output_tready), .odat(m_axis_output_tdata) + .idat(p_tdata), + .ivld(p_tvalid), .irdy(p_tready), + .odat(m_axis_output_tdata), .ovld(m_axis_output_tvalid), + .olst(), .odone(), .ordy(m_axis_output_tready) ); end else begin assign m_axis_output_tvalid = p_tvalid; diff --git a/finn-rtllib/mvu_tiled/mmu/reorder_out.sv b/finn-rtllib/mvu_tiled/mmu/reorder_out.sv deleted file mode 100644 index 1b4b75183e..0000000000 --- a/finn-rtllib/mvu_tiled/mmu/reorder_out.sv +++ /dev/null @@ -1,341 +0,0 @@ -/****************************************************************************** - * Copyright (C) 2024, Advanced Micro Devices, Inc. - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, - * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR - * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR - * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF - * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - *****************************************************************************/ - -module reorder_out #( - int unsigned W, - int unsigned XC, - int unsigned YC -)( - input logic clk, - input logic rst, - - input logic ivld, - output logic irdy, - input logic [W-1:0] idat, - - output logic ovld, - input logic ordy, - output logic [W-1:0] odat -); - -// ---------------------------------------------------------------------------- -// Consts and types -// ---------------------------------------------------------------------------- - -localparam int unsigned RAM_BITS = (W + 7)/8 * 8; -localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; -localparam int unsigned XYC = XC * YC; -localparam int unsigned XCNT_BITS = (XC == 1) ? 1 : $clog2(XC); -localparam int unsigned YCNT_BITS = (YC == 1) ? 1 : $clog2(YC); -localparam int unsigned XYCNT_BITS = (XYC == 1) ? 1 : $clog2(XYC); - -typedef enum logic[1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_t; -typedef enum logic {ST_RD_0, ST_RD_1} state_rd_t; - -// ---------------------------------------------------------------------------- -// Writer -// ---------------------------------------------------------------------------- - -// -- Regs -state_wr_t state_wr_C = ST_WR_0, state_wr_N; -state_rd_t state_rd_C = ST_RD_0, state_rd_N; - -logic [XCNT_BITS-1:0] curr_wrX_C = '0, curr_wrX_N; -logic [YCNT_BITS-1:0] curr_wrY_C = '0, curr_wrY_N; - -// -- Ram -logic [1:0][WGT_EN_BITS-1:0] a_we; // Bank enables -logic [1:0][XYCNT_BITS-1:0] a_addr; -logic [1:0][W-1:0] a_data_in; - -// -- Offsets -logic [XC-1:0][XYCNT_BITS-1:0] x_offsets; -for(genvar i = 0; i < XC; i++) begin - assign x_offsets[i] = i*YC; -end - -// -- IPC -logic done; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_WR - if(rst) begin - state_wr_C <= ST_WR_0; - - curr_wrX_C <= 0; - curr_wrY_C <= 0; - end - else begin - state_wr_C <= state_wr_N; - - curr_wrX_C <= curr_wrX_N; - curr_wrY_C <= curr_wrY_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_WR - state_wr_N = state_wr_C; - - case (state_wr_C) - ST_WR_0: - if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld) begin - state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; - end - - ST_WR_0_WAIT: - state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; - - ST_WR_1: - if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld) begin - state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; - end - - ST_WR_1_WAIT: - state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; - - endcase -end - -// -- DP -always_comb begin : DP_PROC_WR - curr_wrX_N = curr_wrX_C; - curr_wrY_N = curr_wrY_C; - - // Input - irdy = 1'b0; - - // Buffer control - a_we = '0; - for(int i = 0; i < 2; i++) begin - a_addr[i] = x_offsets[curr_wrX_C] + curr_wrY_C; - a_data_in[i] = idat; - end - - // Write and count - case (state_wr_C) - ST_WR_0, ST_WR_1: begin - irdy = 1'b1; - - if(ivld) begin - if(state_wr_C == ST_WR_0) a_we[0] = '1; else a_we[1] = '1; - - curr_wrY_N = (curr_wrY_C == YC-1) ? 0 : curr_wrY_C + 1; - curr_wrX_N = (curr_wrY_C == YC-1) ? ((curr_wrX_C == XC-1) ? 0 : curr_wrX_C + 1) : curr_wrX_C; - end - end - endcase - -end - - -// ---------------------------------------------------------------------------- -// Reader -// ---------------------------------------------------------------------------- - -// -- Regs -logic [XCNT_BITS-1:0] curr_rdX_C = '0, curr_rdX_N; -logic [YCNT_BITS-1:0] curr_rdY_C = '0, curr_rdY_N; - -// -- Ram -logic [1:0] vld_s0_C = '0, vld_s0_N; -logic [1:0] vld_s1_C = '0, vld_s1_N; -logic vld_C = '0, vld_N; -logic [W-1:0] odat_C = '0, odat_N; - -logic [1:0][XYCNT_BITS-1:0] b_addr; -logic [1:0][W-1:0] odat_ram; - -// -- Cond -logic cond_go; - -// -- Oreg -logic [W-1:0] odat_int; -logic ovld_int; -logic ordy_int; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_RD - if(rst) begin - state_rd_C <= ST_RD_0; - - curr_rdX_C <= 0; - curr_rdY_C <= 0; - - vld_s0_C <= 0; - vld_s1_C <= 0; - vld_C <= 0; - odat_C <= 0; - end - else begin - state_rd_C <= state_rd_N; - - curr_rdX_C <= curr_rdX_N; - curr_rdY_C <= curr_rdY_N; - - vld_s0_C <= vld_s0_N; - vld_s1_C <= vld_s1_N; - vld_C <= vld_N; - odat_C <= odat_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_RD - state_rd_N = state_rd_C; - - case (state_rd_C) - ST_RD_0: - if(ordy_int && ((state_wr_C == ST_WR_0) ? cond_go : 1'b1)) begin - if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1)) begin - state_rd_N = ST_RD_1; - end - end - - ST_RD_1: - if(ordy_int && ((state_wr_C == ST_WR_1) ? cond_go : 1'b1)) begin - if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1)) begin - state_rd_N = ST_RD_0; - end - end - - endcase -end - -// -- DP cond -always_comb begin - cond_go = 1'b0; - - if(curr_wrX_C > curr_rdX_C) begin - cond_go = 1'b1; - end - else if(curr_wrX_C == curr_rdX_C) begin - if(curr_wrY_C > curr_rdY_C) begin - cond_go = 1'b1; - end - end -end - -// -- DP -always_comb begin : DP_PROC_RD - curr_rdX_N = curr_rdX_C; - curr_rdY_N = curr_rdY_C; - - for(int i = 0; i < 2; i++) begin - vld_s0_N[i] = ordy_int ? 1'b0 : vld_s0_C[i]; - vld_s1_N[i] = ordy_int ? vld_s0_C[i] : vld_s1_C[i]; - end - - vld_N = ordy_int ? |vld_s1_C : vld_C; - odat_N = ordy_int ? (vld_s1_C[0] ? odat_ram[0] : odat_ram[1]) : odat_C; - - for(int i = 0; i < 2; i++) begin - b_addr[i] = x_offsets[curr_rdX_C] + curr_rdY_C; - end - - done = 1'b0; - - case(state_rd_C) - ST_RD_0: begin - if(ordy_int) begin - if((state_wr_C == ST_WR_0) ? cond_go : 1'b1) begin - vld_s0_N[0] = 1'b1; - - curr_rdX_N = (curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1; - curr_rdY_N = (curr_rdX_C == XC-1) ? ((curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1) : curr_rdY_C; - done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)); - end - end - end - - ST_RD_1: begin - if(ordy_int) begin - if((state_wr_C == ST_WR_1) ? cond_go : 1'b1) begin - vld_s0_N[1] = 1'b1; - - curr_rdX_N = (curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1; - curr_rdY_N = (curr_rdX_C == XC-1) ? ((curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1) : curr_rdY_C; - done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)); - end - end - end - - endcase - -end - -assign ovld_int = vld_C; -assign odat_int = odat_C; - -// ---------------------------------------------------------------------------- -// Matrix -// ---------------------------------------------------------------------------- - -for(genvar i = 0; i < 2; i++) begin - ram_p_c #( - .ADDR_BITS(XYCNT_BITS), - .DATA_BITS(RAM_BITS), - .RAM_STYLE("distributed") - ) inst_ram_tp_c ( - .clk(clk), - .a_en(1'b1), - .a_we(a_we[i]), - .a_addr(a_addr[i]), - .b_en(ordy_int), - .b_addr(b_addr[i]), - .a_data_in(a_data_in[i]), - .a_data_out(), - .b_data_out(odat_ram[i]) - ); -end - -// ---------------------------------------------------------------------------- -// Output -// ---------------------------------------------------------------------------- - -Q_srl #( - .depth(2), .width(W) -) inst_out_fifo ( - .clock(clk), - .reset(rst), - .count(), - .maxcount(), - .i_d(odat_int), - .i_v(ovld_int), - .i_r(ordy_int), - .o_d(odat), - .o_v(ovld), - .o_r(ordy) -); - - -endmodule diff --git a/finn-rtllib/mvu_tiled/mmu/replay_buff_mmau.sv b/finn-rtllib/mvu_tiled/mmu/replay_buff_mmau.sv deleted file mode 100644 index 121471f163..0000000000 --- a/finn-rtllib/mvu_tiled/mmu/replay_buff_mmau.sv +++ /dev/null @@ -1,397 +0,0 @@ -/****************************************************************************** - * Copyright (C) 2024, Advanced Micro Devices, Inc. - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, - * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR - * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR - * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF - * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - *****************************************************************************/ - -module replay_buff_mmau #( - int unsigned W, - int unsigned XC, - int unsigned YC, - int unsigned N_REPS, - int unsigned IO_TILED = 0 -)( - input logic clk, - input logic rst, - - input logic ivld, - output logic irdy, - input logic [W-1:0] idat, - - output logic ovld, - input logic ordy, - output logic [W-1:0] odat, - output logic olast -); - -// ---------------------------------------------------------------------------- -// Consts and types -// ---------------------------------------------------------------------------- - -localparam int unsigned RAM_BITS = (W + 7)/8 * 8; -localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; -localparam int unsigned XYC = XC * YC; -localparam int unsigned XCNT_BITS = (XC == 1) ? 1 : $clog2(XC); -localparam int unsigned YCNT_BITS = (YC == 1) ? 1 : $clog2(YC); -localparam int unsigned XYCNT_BITS = (XYC == 1) ? 1 : $clog2(XYC); -localparam int unsigned REPS_BITS = (N_REPS == 1) ? 1 : $clog2(N_REPS); - -typedef enum logic[1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_t; -typedef enum logic {ST_RD_0, ST_RD_1} state_rd_t; - -// ---------------------------------------------------------------------------- -// Ireg -// ---------------------------------------------------------------------------- -logic [W-1:0] idat_int; -logic ivld_int; -logic irdy_int; - -skid #(.DATA_WIDTH(W), .FEED_STAGES(1)) isnt_input_reg ( - .clk(clk), .rst(rst), - .ivld(ivld), .irdy(irdy), .idat(idat), - .ovld(ivld_int), .ordy(irdy_int), .odat(idat_int) -); - -// ---------------------------------------------------------------------------- -// Writer -// ---------------------------------------------------------------------------- - -// -- Regs -state_wr_t state_wr_C = ST_WR_0, state_wr_N; -state_rd_t state_rd_C = ST_RD_0, state_rd_N; - -logic [XCNT_BITS-1:0] curr_wrX_C = '0, curr_wrX_N; -logic [YCNT_BITS-1:0] curr_wrY_C = '0, curr_wrY_N; - -// -- Ram -logic [1:0][WGT_EN_BITS-1:0] a_we; // Bank enables -logic [1:0][XYCNT_BITS-1:0] a_addr; -logic [1:0][W-1:0] a_data_in; - -// -- Offsets -logic [XC-1:0][XYCNT_BITS-1:0] x_offsets; -for(genvar i = 0; i < XC; i++) begin - assign x_offsets[i] = i*YC; -end - -// -- IPC -logic done; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_WR - if(rst) begin - state_wr_C <= ST_WR_0; - - curr_wrX_C <= 0; - curr_wrY_C <= 0; - end - else begin - state_wr_C <= state_wr_N; - - curr_wrX_C <= curr_wrX_N; - curr_wrY_C <= curr_wrY_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_WR - state_wr_N = state_wr_C; - - case (state_wr_C) - ST_WR_0: - if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld_int) begin - state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; - end - - ST_WR_0_WAIT: - state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; - - ST_WR_1: - if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld_int) begin - state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; - end - - ST_WR_1_WAIT: - state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; - - endcase -end - -// -- DP -always_comb begin : DP_PROC_WR - curr_wrX_N = curr_wrX_C; - curr_wrY_N = curr_wrY_C; - - // Input - irdy_int = 1'b0; - - // Buffer control - a_we = '0; - for(int i = 0; i < 2; i++) begin - a_addr[i] = x_offsets[curr_wrX_C] + curr_wrY_C; - a_data_in[i] = idat_int; - end - - // Write and count - case (state_wr_C) - ST_WR_0, ST_WR_1: begin - irdy_int = 1'b1; - - if(ivld_int) begin - if(state_wr_C == ST_WR_0) a_we[0] = '1; else a_we[1] = '1; - - if(IO_TILED == 1) begin - curr_wrY_N = (curr_wrY_C == YC-1) ? 0 : curr_wrY_C + 1; - curr_wrX_N = (curr_wrY_C == YC-1) ? ((curr_wrX_C == XC-1) ? 0 : curr_wrX_C + 1) : curr_wrX_C; - end else begin - curr_wrX_N = (curr_wrX_C == XC-1) ? 0 : curr_wrX_C + 1; - curr_wrY_N = (curr_wrX_C == XC-1) ? ((curr_wrY_C == YC-1) ? 0 : curr_wrY_C + 1) : curr_wrY_C; - end - end - end - endcase - -end - -// ---------------------------------------------------------------------------- -// Reader -// ---------------------------------------------------------------------------- - -// -- Regs -logic [XCNT_BITS-1:0] curr_rdX_C = '0, curr_rdX_N; -logic [YCNT_BITS-1:0] curr_rdY_C = '0, curr_rdY_N; -logic [REPS_BITS-1:0] curr_reps_C = '0, curr_reps_N; - -// -- Ram -logic [1:0] vld_s0_C = '0, vld_s0_N; -logic [1:0] vld_s1_C = '0, vld_s1_N; -logic vld_C = '0, vld_N; -logic last_s0_C = '0, last_s0_N; -logic last_s1_C = '0, last_s1_N; -logic last_C = '0, last_N; -logic [W-1:0] odat_C = '0, odat_N; - -logic [1:0][XYCNT_BITS-1:0] b_addr; -logic [1:0][W-1:0] odat_ram; - -// -- Cond -logic cond_go; - -// -- Oreg -logic [W-1:0] odat_int; -logic ovld_int; -logic ordy_int; -logic olast_int; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_RD - if(rst) begin - state_rd_C <= ST_RD_0; - - curr_rdX_C <= 0; - curr_rdY_C <= 0; - curr_reps_C <= 0; - - vld_s0_C <= 0; - vld_s1_C <= 0; - vld_C <= 0; - odat_C <= 0; - last_s0_C <= 0; - last_s1_C <= 0; - last_C <= 0; - end - else begin - state_rd_C <= state_rd_N; - - curr_rdX_C <= curr_rdX_N; - curr_rdY_C <= curr_rdY_N; - curr_reps_C <= curr_reps_N; - - vld_s0_C <= vld_s0_N; - vld_s1_C <= vld_s1_N; - vld_C <= vld_N; - odat_C <= odat_N; - last_s0_C <= last_s0_N; - last_s1_C <= last_s1_N; - last_C <= last_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_RD - state_rd_N = state_rd_C; - - case (state_rd_C) - ST_RD_0: - if(ordy_int && ((state_wr_C == ST_WR_0) ? cond_go : 1'b1)) begin - if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1) && (curr_reps_C == N_REPS-1)) begin - state_rd_N = ST_RD_1; - end - end - - ST_RD_1: - if(ordy_int && ((state_wr_C == ST_WR_1) ? cond_go : 1'b1)) begin - if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1) && (curr_reps_C == N_REPS-1)) begin - state_rd_N = ST_RD_0; - end - end - - endcase -end - -// -- DP cond -always_comb begin - cond_go = 1'b0; - - if(IO_TILED) begin - if(curr_wrX_C > curr_rdX_C) begin - cond_go = 1'b1; - end - else if(curr_wrX_C == curr_rdX_C) begin - if(curr_wrY_C > curr_rdY_C) begin - cond_go = 1'b1; - end - end - end else begin - if(curr_wrY_C > curr_rdY_C) begin - cond_go = 1'b1; - end - else if(curr_wrY_C == curr_rdY_C) begin - if(curr_wrX_C > curr_rdX_C) begin - cond_go = 1'b1; - end - end - end -end - -// -- DP -always_comb begin : DP_PROC_RD - curr_rdX_N = curr_rdX_C; - curr_rdY_N = curr_rdY_C; - curr_reps_N = curr_reps_C; - - for(int i = 0; i < 2; i++) begin - vld_s0_N[i] = ordy_int ? 1'b0 : vld_s0_C[i]; - vld_s1_N[i] = ordy_int ? vld_s0_C[i] : vld_s1_C[i]; - end - - vld_N = ordy_int ? |vld_s1_C : vld_C; - odat_N = ordy_int ? (vld_s1_C[0] ? odat_ram[0] : odat_ram[1]) : odat_C; - - last_s0_N = ordy_int ? 1'b0 : last_s0_C; - last_s1_N = ordy_int ? last_s0_C : last_s1_C; - last_N = ordy_int ? last_s1_C : last_C; - - for(int i = 0; i < 2; i++) begin - b_addr[i] = x_offsets[curr_rdX_C] + curr_rdY_C; - end - - done = 1'b0; - - case(state_rd_C) - ST_RD_0: begin - if(ordy_int) begin - if((state_wr_C == ST_WR_0) ? cond_go : 1'b1) begin - vld_s0_N[0] = 1'b1; - - last_s0_N = (curr_rdX_C == XC-1); - - curr_rdY_N = (curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1; - curr_rdX_N = (curr_rdY_C == YC-1) ? ((curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1) : curr_rdX_C; - curr_reps_N = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)) ? ((curr_reps_C == N_REPS-1) ? 0 : curr_reps_C + 1) : curr_reps_C; - done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1) && (curr_reps_C == N_REPS-1)); - end - end - end - - ST_RD_1: begin - if(ordy_int) begin - if((state_wr_C == ST_WR_1) ? cond_go : 1'b1) begin - vld_s0_N[1] = 1'b1; - - last_s0_N = (curr_rdX_C == XC-1); - - curr_rdY_N = (curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1; - curr_rdX_N = (curr_rdY_C == YC-1) ? ((curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1) : curr_rdX_C; - curr_reps_N = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)) ? ((curr_reps_C == N_REPS-1) ? 0 : curr_reps_C + 1) : curr_reps_C; - done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1) && (curr_reps_C == N_REPS-1)); - end - end - end - - endcase - -end - -assign ovld_int = vld_C; -assign odat_int = odat_C; -assign olast_int = last_C; - -// ---------------------------------------------------------------------------- -// BRAM -// ---------------------------------------------------------------------------- - -for(genvar i = 0; i < 2; i++) begin - ram_p_c #( - .ADDR_BITS(XYCNT_BITS), - .DATA_BITS(RAM_BITS), - .RAM_STYLE("distributed") - ) inst_ram_tp_c ( - .clk(clk), - .a_en(1'b1), - .a_we(a_we[i]), - .a_addr(a_addr[i]), - .b_en(ordy_int), - .b_addr(b_addr[i]), - .a_data_in(a_data_in[i]), - .a_data_out(), - .b_data_out(odat_ram[i]) - ); -end - -// ---------------------------------------------------------------------------- -// Output -// ---------------------------------------------------------------------------- - -Q_srl #( - .depth(2), .width(1+W) -) inst_out_fifo ( - .clock(clk), - .reset(rst), - .count(), - .maxcount(), - .i_d({olast_int, odat_int}), - .i_v(ovld_int), - .i_r(ordy_int), - .o_d({olast, odat}), - .o_v(ovld), - .o_r(ordy) -); - -endmodule diff --git a/finn-rtllib/mvu_tiled/mvu_tiled_axi.sv b/finn-rtllib/mvu_tiled/mvu_tiled_axi.sv index 28f0b9eafc..1040825daa 100644 --- a/finn-rtllib/mvu_tiled/mvu_tiled_axi.sv +++ b/finn-rtllib/mvu_tiled/mvu_tiled_axi.sv @@ -137,11 +137,23 @@ module mvu_tiled_axi #( uwire avld; uwire ardy; - replay_buff_tile #(.XC(MW/SIMD), .YC(TH), .W($bits(mvu_flatin_t)), .N_REPS(MH/PE), .IO_TILED(IN_TILED)) activation_replay ( + localparam int unsigned SF = MW / SIMD; + localparam int unsigned NF = MH / PE; + + uwire [2:0] act_done; + input_gen #( + .DATA_WIDTH($bits(mvu_flatin_t)), + .FM_SIZE(SF * TH), + .D(3), + .DIMS('{NF, SF, TH}), + .COEFS('{0, 1, SF}) + ) activation_replay ( .clk(ap_clk), .rst(rst), - .ivld(s_axis_input_tvalid), .irdy(s_axis_input_tready), .idat(mvu_flatin_t'(s_axis_input_tdata)), - .ovld(avld), .ordy(ardy), .odat(amvau), .olast(alast) + .idat(mvu_flatin_t'(s_axis_input_tdata)), + .ivld(s_axis_input_tvalid), .irdy(s_axis_input_tready), + .odat(amvau), .ovld(avld), .olst(), .odone(act_done), .ordy(ardy) ); + assign alast = act_done[1]; //- Unflatten weights --------------------------------------------------- typedef logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] mvu_w_t; @@ -185,7 +197,7 @@ module mvu_tiled_axi #( uwire dsp_p_t dsp_p; // TODO: No double-pumping in the initial implementation - assign dsp_en = en; + uwire dsp_en = en; assign dsp_last = alast && istb; assign dsp_zero = !istb; @@ -265,10 +277,18 @@ module mvu_tiled_axi #( //-------------------- Output reordering --------------------\\ if(OUT_TILED == 0) begin - reorder_out #(.W(OUTPUT_STREAM_WIDTH_BA), .XC(MH/PE), .YC(TH)) inst_reorder_out ( + input_gen #( + .DATA_WIDTH(OUTPUT_STREAM_WIDTH_BA), + .FM_SIZE(NF * TH), + .D(2), + .DIMS('{TH, NF}), + .COEFS('{1, TH}) + ) inst_reorder_out ( .clk(ap_clk), .rst(rst), - .ivld(m_axis_int_tvalid), .irdy(m_axis_int_tready), .idat(m_axis_int_tdata), - .ovld(m_axis_output_tvalid), .ordy(m_axis_output_tready), .odat(m_axis_output_tdata) + .idat(m_axis_int_tdata), + .ivld(m_axis_int_tvalid), .irdy(m_axis_int_tready), + .odat(m_axis_output_tdata), .ovld(m_axis_output_tvalid), + .olst(), .odone(), .ordy(m_axis_output_tready) ); end else begin assign m_axis_output_tvalid = m_axis_int_tvalid; diff --git a/finn-rtllib/mvu_tiled/reorder_out.sv b/finn-rtllib/mvu_tiled/reorder_out.sv deleted file mode 100644 index 1b4b75183e..0000000000 --- a/finn-rtllib/mvu_tiled/reorder_out.sv +++ /dev/null @@ -1,341 +0,0 @@ -/****************************************************************************** - * Copyright (C) 2024, Advanced Micro Devices, Inc. - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, - * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR - * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR - * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF - * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - *****************************************************************************/ - -module reorder_out #( - int unsigned W, - int unsigned XC, - int unsigned YC -)( - input logic clk, - input logic rst, - - input logic ivld, - output logic irdy, - input logic [W-1:0] idat, - - output logic ovld, - input logic ordy, - output logic [W-1:0] odat -); - -// ---------------------------------------------------------------------------- -// Consts and types -// ---------------------------------------------------------------------------- - -localparam int unsigned RAM_BITS = (W + 7)/8 * 8; -localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; -localparam int unsigned XYC = XC * YC; -localparam int unsigned XCNT_BITS = (XC == 1) ? 1 : $clog2(XC); -localparam int unsigned YCNT_BITS = (YC == 1) ? 1 : $clog2(YC); -localparam int unsigned XYCNT_BITS = (XYC == 1) ? 1 : $clog2(XYC); - -typedef enum logic[1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_t; -typedef enum logic {ST_RD_0, ST_RD_1} state_rd_t; - -// ---------------------------------------------------------------------------- -// Writer -// ---------------------------------------------------------------------------- - -// -- Regs -state_wr_t state_wr_C = ST_WR_0, state_wr_N; -state_rd_t state_rd_C = ST_RD_0, state_rd_N; - -logic [XCNT_BITS-1:0] curr_wrX_C = '0, curr_wrX_N; -logic [YCNT_BITS-1:0] curr_wrY_C = '0, curr_wrY_N; - -// -- Ram -logic [1:0][WGT_EN_BITS-1:0] a_we; // Bank enables -logic [1:0][XYCNT_BITS-1:0] a_addr; -logic [1:0][W-1:0] a_data_in; - -// -- Offsets -logic [XC-1:0][XYCNT_BITS-1:0] x_offsets; -for(genvar i = 0; i < XC; i++) begin - assign x_offsets[i] = i*YC; -end - -// -- IPC -logic done; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_WR - if(rst) begin - state_wr_C <= ST_WR_0; - - curr_wrX_C <= 0; - curr_wrY_C <= 0; - end - else begin - state_wr_C <= state_wr_N; - - curr_wrX_C <= curr_wrX_N; - curr_wrY_C <= curr_wrY_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_WR - state_wr_N = state_wr_C; - - case (state_wr_C) - ST_WR_0: - if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld) begin - state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; - end - - ST_WR_0_WAIT: - state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; - - ST_WR_1: - if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld) begin - state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; - end - - ST_WR_1_WAIT: - state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; - - endcase -end - -// -- DP -always_comb begin : DP_PROC_WR - curr_wrX_N = curr_wrX_C; - curr_wrY_N = curr_wrY_C; - - // Input - irdy = 1'b0; - - // Buffer control - a_we = '0; - for(int i = 0; i < 2; i++) begin - a_addr[i] = x_offsets[curr_wrX_C] + curr_wrY_C; - a_data_in[i] = idat; - end - - // Write and count - case (state_wr_C) - ST_WR_0, ST_WR_1: begin - irdy = 1'b1; - - if(ivld) begin - if(state_wr_C == ST_WR_0) a_we[0] = '1; else a_we[1] = '1; - - curr_wrY_N = (curr_wrY_C == YC-1) ? 0 : curr_wrY_C + 1; - curr_wrX_N = (curr_wrY_C == YC-1) ? ((curr_wrX_C == XC-1) ? 0 : curr_wrX_C + 1) : curr_wrX_C; - end - end - endcase - -end - - -// ---------------------------------------------------------------------------- -// Reader -// ---------------------------------------------------------------------------- - -// -- Regs -logic [XCNT_BITS-1:0] curr_rdX_C = '0, curr_rdX_N; -logic [YCNT_BITS-1:0] curr_rdY_C = '0, curr_rdY_N; - -// -- Ram -logic [1:0] vld_s0_C = '0, vld_s0_N; -logic [1:0] vld_s1_C = '0, vld_s1_N; -logic vld_C = '0, vld_N; -logic [W-1:0] odat_C = '0, odat_N; - -logic [1:0][XYCNT_BITS-1:0] b_addr; -logic [1:0][W-1:0] odat_ram; - -// -- Cond -logic cond_go; - -// -- Oreg -logic [W-1:0] odat_int; -logic ovld_int; -logic ordy_int; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_RD - if(rst) begin - state_rd_C <= ST_RD_0; - - curr_rdX_C <= 0; - curr_rdY_C <= 0; - - vld_s0_C <= 0; - vld_s1_C <= 0; - vld_C <= 0; - odat_C <= 0; - end - else begin - state_rd_C <= state_rd_N; - - curr_rdX_C <= curr_rdX_N; - curr_rdY_C <= curr_rdY_N; - - vld_s0_C <= vld_s0_N; - vld_s1_C <= vld_s1_N; - vld_C <= vld_N; - odat_C <= odat_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_RD - state_rd_N = state_rd_C; - - case (state_rd_C) - ST_RD_0: - if(ordy_int && ((state_wr_C == ST_WR_0) ? cond_go : 1'b1)) begin - if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1)) begin - state_rd_N = ST_RD_1; - end - end - - ST_RD_1: - if(ordy_int && ((state_wr_C == ST_WR_1) ? cond_go : 1'b1)) begin - if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1)) begin - state_rd_N = ST_RD_0; - end - end - - endcase -end - -// -- DP cond -always_comb begin - cond_go = 1'b0; - - if(curr_wrX_C > curr_rdX_C) begin - cond_go = 1'b1; - end - else if(curr_wrX_C == curr_rdX_C) begin - if(curr_wrY_C > curr_rdY_C) begin - cond_go = 1'b1; - end - end -end - -// -- DP -always_comb begin : DP_PROC_RD - curr_rdX_N = curr_rdX_C; - curr_rdY_N = curr_rdY_C; - - for(int i = 0; i < 2; i++) begin - vld_s0_N[i] = ordy_int ? 1'b0 : vld_s0_C[i]; - vld_s1_N[i] = ordy_int ? vld_s0_C[i] : vld_s1_C[i]; - end - - vld_N = ordy_int ? |vld_s1_C : vld_C; - odat_N = ordy_int ? (vld_s1_C[0] ? odat_ram[0] : odat_ram[1]) : odat_C; - - for(int i = 0; i < 2; i++) begin - b_addr[i] = x_offsets[curr_rdX_C] + curr_rdY_C; - end - - done = 1'b0; - - case(state_rd_C) - ST_RD_0: begin - if(ordy_int) begin - if((state_wr_C == ST_WR_0) ? cond_go : 1'b1) begin - vld_s0_N[0] = 1'b1; - - curr_rdX_N = (curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1; - curr_rdY_N = (curr_rdX_C == XC-1) ? ((curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1) : curr_rdY_C; - done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)); - end - end - end - - ST_RD_1: begin - if(ordy_int) begin - if((state_wr_C == ST_WR_1) ? cond_go : 1'b1) begin - vld_s0_N[1] = 1'b1; - - curr_rdX_N = (curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1; - curr_rdY_N = (curr_rdX_C == XC-1) ? ((curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1) : curr_rdY_C; - done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)); - end - end - end - - endcase - -end - -assign ovld_int = vld_C; -assign odat_int = odat_C; - -// ---------------------------------------------------------------------------- -// Matrix -// ---------------------------------------------------------------------------- - -for(genvar i = 0; i < 2; i++) begin - ram_p_c #( - .ADDR_BITS(XYCNT_BITS), - .DATA_BITS(RAM_BITS), - .RAM_STYLE("distributed") - ) inst_ram_tp_c ( - .clk(clk), - .a_en(1'b1), - .a_we(a_we[i]), - .a_addr(a_addr[i]), - .b_en(ordy_int), - .b_addr(b_addr[i]), - .a_data_in(a_data_in[i]), - .a_data_out(), - .b_data_out(odat_ram[i]) - ); -end - -// ---------------------------------------------------------------------------- -// Output -// ---------------------------------------------------------------------------- - -Q_srl #( - .depth(2), .width(W) -) inst_out_fifo ( - .clock(clk), - .reset(rst), - .count(), - .maxcount(), - .i_d(odat_int), - .i_v(ovld_int), - .i_r(ordy_int), - .o_d(odat), - .o_v(ovld), - .o_r(ordy) -); - - -endmodule diff --git a/finn-rtllib/mvu_tiled/replay_buff_tile.sv b/finn-rtllib/mvu_tiled/replay_buff_tile.sv deleted file mode 100644 index ec4618e58d..0000000000 --- a/finn-rtllib/mvu_tiled/replay_buff_tile.sv +++ /dev/null @@ -1,392 +0,0 @@ -/****************************************************************************** - * Copyright (C) 2024, Advanced Micro Devices, Inc. - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, - * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR - * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR - * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF - * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - *****************************************************************************/ - -module replay_buff_tile #( - int unsigned W, - int unsigned XC, - int unsigned YC, - int unsigned N_REPS, - int unsigned IO_TILED = 0 -)( - input logic clk, - input logic rst, - - input logic ivld, - output logic irdy, - input logic [W-1:0] idat, - - output logic ovld, - input logic ordy, - output logic [W-1:0] odat, - output logic olast -); - -// ---------------------------------------------------------------------------- -// Consts and types -// ---------------------------------------------------------------------------- - -localparam int unsigned RAM_BITS = (W + 7)/8 * 8; -localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; -localparam int unsigned XYC = XC * YC; -localparam int unsigned XCNT_BITS = (XC == 1) ? 1 : $clog2(XC); -localparam int unsigned YCNT_BITS = (YC == 1) ? 1 : $clog2(YC); -localparam int unsigned XYCNT_BITS = (XYC == 1) ? 1 : $clog2(XYC); -localparam int unsigned REPS_BITS = (N_REPS == 1) ? 1 : $clog2(N_REPS); - -typedef enum logic[1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_t; -typedef enum logic {ST_RD_0, ST_RD_1} state_rd_t; - -// ---------------------------------------------------------------------------- -// Ireg -// ---------------------------------------------------------------------------- -logic [W-1:0] idat_int; -logic ivld_int; -logic irdy_int; - -skid #(.DATA_WIDTH(W), .FEED_STAGES(1)) isnt_input_reg ( - .clk(clk), .rst(rst), - .ivld(ivld), .irdy(irdy), .idat(idat), - .ovld(ivld_int), .ordy(irdy_int), .odat(idat_int) -); - -// ---------------------------------------------------------------------------- -// Writer -// ---------------------------------------------------------------------------- - -// -- Regs -state_wr_t state_wr_C = ST_WR_0, state_wr_N; -state_rd_t state_rd_C = ST_RD_0, state_rd_N; - -logic [XCNT_BITS-1:0] curr_wrX_C = '0, curr_wrX_N; -logic [YCNT_BITS-1:0] curr_wrY_C = '0, curr_wrY_N; - -// -- Ram -logic [1:0][WGT_EN_BITS-1:0] a_we; // Bank enables -logic [1:0][XYCNT_BITS-1:0] a_addr; -logic [1:0][W-1:0] a_data_in; - -// -- Offsets -logic [XC-1:0][XYCNT_BITS-1:0] x_offsets; -for(genvar i = 0; i < XC; i++) begin - assign x_offsets[i] = i*YC; -end - -// -- IPC -logic done; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_WR - if(rst) begin - state_wr_C <= ST_WR_0; - - curr_wrX_C <= 0; - curr_wrY_C <= 0; - end - else begin - state_wr_C <= state_wr_N; - - curr_wrX_C <= curr_wrX_N; - curr_wrY_C <= curr_wrY_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_WR - state_wr_N = state_wr_C; - - case (state_wr_C) - ST_WR_0: - if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld_int) begin - state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; - end - - ST_WR_0_WAIT: - state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; - - ST_WR_1: - if ((curr_wrY_C == YC - 1) && (curr_wrX_C == XC - 1) && ivld_int) begin - state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; - end - - ST_WR_1_WAIT: - state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; - - endcase -end - -// -- DP -always_comb begin : DP_PROC_WR - curr_wrX_N = curr_wrX_C; - curr_wrY_N = curr_wrY_C; - - // Input - irdy_int = 1'b0; - - // Buffer control - a_we = '0; - for(int i = 0; i < 2; i++) begin - a_addr[i] = x_offsets[curr_wrX_C] + curr_wrY_C; - a_data_in[i] = idat_int; - end - - // Write and count - case (state_wr_C) - ST_WR_0, ST_WR_1: begin - irdy_int = 1'b1; - - if(ivld_int) begin - if(state_wr_C == ST_WR_0) a_we[0] = '1; else a_we[1] = '1; - - curr_wrX_N = (curr_wrX_C == XC-1) ? 0 : curr_wrX_C + 1; - curr_wrY_N = (curr_wrX_C == XC-1) ? ((curr_wrY_C == YC-1) ? 0 : curr_wrY_C + 1) : curr_wrY_C; - end - end - endcase - -end - -// ---------------------------------------------------------------------------- -// Reader -// ---------------------------------------------------------------------------- - -// -- Regs -logic [XCNT_BITS-1:0] curr_rdX_C = '0, curr_rdX_N; -logic [YCNT_BITS-1:0] curr_rdY_C = '0, curr_rdY_N; -logic [REPS_BITS-1:0] curr_reps_C = '0, curr_reps_N; - -// -- Ram -logic [1:0] vld_s0_C = '0, vld_s0_N; -logic [1:0] vld_s1_C = '0, vld_s1_N; -logic vld_C = '0, vld_N; -logic last_s0_C = '0, last_s0_N; -logic last_s1_C = '0, last_s1_N; -logic last_C = '0, last_N; -logic [W-1:0] odat_C = '0, odat_N; - -logic [1:0][XYCNT_BITS-1:0] b_addr; -logic [1:0][W-1:0] odat_ram; - -// -- Cond -logic cond_go; - -// -- Oreg -logic [W-1:0] odat_int; -logic ovld_int; -logic ordy_int; -logic olast_int; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_RD - if(rst) begin - state_rd_C <= ST_RD_0; - - curr_rdX_C <= 0; - curr_rdY_C <= 0; - curr_reps_C <= 0; - - vld_s0_C <= 0; - vld_s1_C <= 0; - vld_C <= 0; - odat_C <= 0; - last_s0_C <= 0; - last_s1_C <= 0; - last_C <= 0; - end - else begin - state_rd_C <= state_rd_N; - - curr_rdX_C <= curr_rdX_N; - curr_rdY_C <= curr_rdY_N; - curr_reps_C <= curr_reps_N; - - vld_s0_C <= vld_s0_N; - vld_s1_C <= vld_s1_N; - vld_C <= vld_N; - odat_C <= odat_N; - last_s0_C <= last_s0_N; - last_s1_C <= last_s1_N; - last_C <= last_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_RD - state_rd_N = state_rd_C; - - case (state_rd_C) - ST_RD_0: - if(ordy_int && ((state_wr_C == ST_WR_0) ? cond_go : 1'b1)) begin - if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1) && (curr_reps_C == N_REPS-1)) begin - state_rd_N = ST_RD_1; - end - end - - ST_RD_1: - if(ordy_int && ((state_wr_C == ST_WR_1) ? cond_go : 1'b1)) begin - if((curr_rdX_C == XC-1) && (curr_rdY_C == YC-1) && (curr_reps_C == N_REPS-1)) begin - state_rd_N = ST_RD_0; - end - end - - endcase -end - -// -- DP cond -always_comb begin - cond_go = 1'b0; - - if(IO_TILED) begin - if(curr_wrX_C > curr_rdX_C) begin - cond_go = 1'b1; - end - else if(curr_wrX_C == curr_rdX_C) begin - if(curr_wrY_C > curr_rdY_C) begin - cond_go = 1'b1; - end - end - end else begin - if(curr_wrY_C > curr_rdY_C) begin - cond_go = 1'b1; - end - else if(curr_wrY_C == curr_rdY_C) begin - if(curr_wrX_C > curr_rdX_C) begin - cond_go = 1'b1; - end - end - end -end - -// -- DP -always_comb begin : DP_PROC_RD - curr_rdX_N = curr_rdX_C; - curr_rdY_N = curr_rdY_C; - curr_reps_N = curr_reps_C; - - for(int i = 0; i < 2; i++) begin - vld_s0_N[i] = ordy_int ? 1'b0 : vld_s0_C[i]; - vld_s1_N[i] = ordy_int ? vld_s0_C[i] : vld_s1_C[i]; - end - - vld_N = ordy_int ? |vld_s1_C : vld_C; - odat_N = ordy_int ? (vld_s1_C[0] ? odat_ram[0] : odat_ram[1]) : odat_C; - - last_s0_N = ordy_int ? 1'b0 : last_s0_C; - last_s1_N = ordy_int ? last_s0_C : last_s1_C; - last_N = ordy_int ? last_s1_C : last_C; - - for(int i = 0; i < 2; i++) begin - b_addr[i] = x_offsets[curr_rdX_C] + curr_rdY_C; - end - - done = 1'b0; - - case(state_rd_C) - ST_RD_0: begin - if(ordy_int) begin - if((state_wr_C == ST_WR_0) ? cond_go : 1'b1) begin - vld_s0_N[0] = 1'b1; - - last_s0_N = (curr_rdX_C == XC-1); - - curr_rdY_N = (curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1; - curr_rdX_N = (curr_rdY_C == YC-1) ? ((curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1) : curr_rdX_C; - curr_reps_N = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)) ? ((curr_reps_C == N_REPS-1) ? 0 : curr_reps_C + 1) : curr_reps_C; - done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1) && (curr_reps_C == N_REPS-1)); - end - end - end - - ST_RD_1: begin - if(ordy_int) begin - if((state_wr_C == ST_WR_1) ? cond_go : 1'b1) begin - vld_s0_N[1] = 1'b1; - - last_s0_N = (curr_rdX_C == XC-1); - - curr_rdY_N = (curr_rdY_C == YC-1) ? 0 : curr_rdY_C + 1; - curr_rdX_N = (curr_rdY_C == YC-1) ? ((curr_rdX_C == XC-1) ? 0 : curr_rdX_C + 1) : curr_rdX_C; - curr_reps_N = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1)) ? ((curr_reps_C == N_REPS-1) ? 0 : curr_reps_C + 1) : curr_reps_C; - done = ((curr_rdY_C == YC-1) && (curr_rdX_C == XC-1) && (curr_reps_C == N_REPS-1)); - end - end - end - - endcase - -end - -assign ovld_int = vld_C; -assign odat_int = odat_C; -assign olast_int = last_C; - -// ---------------------------------------------------------------------------- -// BRAM -// ---------------------------------------------------------------------------- - -for(genvar i = 0; i < 2; i++) begin - ram_p_c #( - .ADDR_BITS(XYCNT_BITS), - .DATA_BITS(RAM_BITS), - .RAM_STYLE("distributed") - ) inst_ram_tp_c ( - .clk(clk), - .a_en(1'b1), - .a_we(a_we[i]), - .a_addr(a_addr[i]), - .b_en(ordy_int), - .b_addr(b_addr[i]), - .a_data_in(a_data_in[i]), - .a_data_out(), - .b_data_out(odat_ram[i]) - ); -end - -// ---------------------------------------------------------------------------- -// Output -// ---------------------------------------------------------------------------- - -Q_srl #( - .depth(2), .width(1+W) -) inst_out_fifo ( - .clock(clk), - .reset(rst), - .count(), - .maxcount(), - .i_d({olast_int, odat_int}), - .i_v(ovld_int), - .i_r(ordy_int), - .o_d({olast, odat}), - .o_v(ovld), - .o_r(ordy) -); - -endmodule diff --git a/finn-rtllib/mvu_tiled/tb/mvu_tiled_axi_tb.sv b/finn-rtllib/mvu_tiled/tb/mvu_tiled_axi_tb.sv new file mode 100644 index 0000000000..11c973f67e --- /dev/null +++ b/finn-rtllib/mvu_tiled/tb/mvu_tiled_axi_tb.sv @@ -0,0 +1,341 @@ +/****************************************************************************** + * Copyright (C) 2024, Advanced Micro Devices, Inc. + * All rights reserved. + * + * SPDX-License-Identifier: BSD-3-Clause + * + * @brief Testbench for MVU-Tiled AXI wrapper module. + * @details + * Adapted from mvu/tb/mvu_axi_tb.sv for the tiled architecture. + * Exercises mvu_tiled_axi with multiple parameter configurations in parallel. + * + * Data flow under test: + * activations -> replay_buff_tile (replays MH/PE times) + * weights -> weights_buff_tile (collects NW=TH words, replays TH times) + * compute -> cu_mvau_tiled (DSP58 INT8, 3 MACs/DSP) + * accumulate -> acc_stage + add_tree (pipelined tree + circular FIFO) + * reorder -> reorder_out (transpose tiled -> sequential NF order) + * + * Weight feed order: + * For each neuron fold (h), for each SIMD fold (w): + * send TH chunks of WSIMD weights (PE*SIMD total per tile). + * + * Activation feed order: + * For each TH tile (y), for each SIMD fold (x): + * send one SIMD-wide activation word. + * Replay buffer handles repetition across neuron folds. + * + * Output order (after reorder_out): + * Sequential neuron folds: nf=0..MH/PE-1, one PE-wide word per fold, + * repeated for each input vector. + *****************************************************************************/ + +module mvu_tiled_axi_tb; + + // Test Configurations + localparam int unsigned ROUNDS = 7; + + typedef struct { + int unsigned mh; + int unsigned mw; + int unsigned pe; + int unsigned simd; + int unsigned th; + int unsigned weight_width; + int unsigned activation_width; + int unsigned accu_width; + bit signed_activations; + bit narrow_weights; + } cfg_t; + + // Constraints enforced by mvu_tiled_axi: + // - MW % SIMD == 0 + // - MH % PE == 0 + // - (PE * SIMD) % TH == 0 + // - WEIGHT_WIDTH <= 8 + // - ACTIVATION_WIDTH <= 8 (9 for signed -- uses full 9-bit A port) + // - TH >= 2 (TH=1 uses the non-tiled path) + // + // Test selection rationale: + // 0: Baseline -- balanced PE/SIMD, TH=2, signed activations, narrow weights + // 1: Larger TH (=3), odd SIMD (=3) -> CHAINLEN=1 (3 lanes in one DSP) + // 2: TH = PE*SIMD (maximum tiling, WSIMD=1) -- edge case: 1 weight/cycle + // 3: High PE (=6), low SIMD (=2) -> wide PE fanout, CHAINLEN=1 + // 4: PE=MH (no replay, single neuron fold) -- tests replay bypass + // 5: Large matrix, moderate tiling -- closer to real workload + // 6: SIMD=6 (CHAINLEN=2), TH=2 -- multi-DSP chain with tiling + // 7: Unsigned activations, small bitwidths -- corner case for sign extension + // 8: TH=6 (high tiling), PE=2, SIMD=3 -- stress accumulator depth + localparam int unsigned TEST_COUNT = 9; + // mh mw pe simd th ww aw accw sa nw + localparam cfg_t TESTS[TEST_COUNT] = '{ + '{ 12, 12, 6, 3, 2, 8, 8, 24, 1, 1 }, + '{ 12, 12, 6, 3, 3, 4, 4, 16, 1, 0 }, + '{ 12, 8, 2, 4, 8, 8, 8, 24, 1, 0 }, + '{ 12, 10, 6, 2, 3, 4, 8, 20, 0, 0 }, + '{ 4, 12, 4, 3, 2, 8, 4, 20, 1, 1 }, + '{ 24, 18, 6, 6, 3, 4, 4, 18, 1, 0 }, + '{ 16, 12, 4, 6, 2, 8, 8, 24, 0, 1 }, + '{ 8, 12, 4, 3, 2, 2, 2, 12, 0, 1 }, + '{ 6, 9, 2, 3, 6, 4, 4, 16, 1, 0 } + }; + + //----------------------------------------------------------------------- + // Global Control + logic clk = 0; + always #5ns clk = !clk; + logic clk2x = 0; + always #2.5ns clk2x = !clk2x; + + logic rst = 1; + initial begin + repeat(16) @(posedge clk); + rst <= 0; + // Allow 100ns DSP startup recovery before any input + #100ns; + end + + bit [TEST_COUNT-1:0] done = '0; + always_comb begin + if(&done) $finish; + end + + //----------------------------------------------------------------------- + // Parallel Test Instantiation + for(genvar t = 0; t < TEST_COUNT; t++) begin : genTests + localparam cfg_t CFG = TESTS[t]; + localparam int unsigned MH = CFG.mh; + localparam int unsigned MW = CFG.mw; + localparam int unsigned PE = CFG.pe; + localparam int unsigned SIMD = CFG.simd; + localparam int unsigned TH = CFG.th; + localparam int unsigned WEIGHT_WIDTH = CFG.weight_width; + localparam int unsigned ACTIVATION_WIDTH = CFG.activation_width; + localparam int unsigned ACCU_WIDTH = CFG.accu_width; + + // Derived + localparam int unsigned SF = MW / SIMD; // SIMD folds + localparam int unsigned NF = MH / PE; // neuron folds + localparam int unsigned WSIMD = (PE * SIMD) / TH; + + typedef logic signed [WEIGHT_WIDTH -1:0] weight_t; + typedef logic [ACTIVATION_WIDTH-1:0] activation_t; + typedef logic signed [ACCU_WIDTH -1:0] accu_t; + + // Stream widths (matching mvu_tiled_axi localparams) + localparam int unsigned WEIGHT_STREAM_WIDTH = WSIMD * WEIGHT_WIDTH; + localparam int unsigned WEIGHT_STREAM_WIDTH_BA = (WEIGHT_STREAM_WIDTH + 7)/8 * 8; + localparam int unsigned INPUT_STREAM_WIDTH = SIMD * ACTIVATION_WIDTH; + localparam int unsigned INPUT_STREAM_WIDTH_BA = (INPUT_STREAM_WIDTH + 7)/8 * 8; + localparam int unsigned OUTPUT_STREAM_WIDTH = PE * ACCU_WIDTH; + localparam int unsigned OUTPUT_STREAM_WIDTH_BA = (OUTPUT_STREAM_WIDTH + 7)/8 * 8; + + // DUT signals + logic [WEIGHT_STREAM_WIDTH_BA-1:0] wdat; + logic wvld; + uwire wrdy; + logic [INPUT_STREAM_WIDTH_BA-1:0] idat; + logic ivld; + uwire irdy; + uwire [OUTPUT_STREAM_WIDTH_BA-1:0] odat; + uwire ovld; + logic ordy; + + mvu_tiled_axi #( + .PE(PE), .SIMD(SIMD), + .WEIGHT_WIDTH(WEIGHT_WIDTH), + .ACTIVATION_WIDTH(ACTIVATION_WIDTH), + .ACCU_WIDTH(ACCU_WIDTH), + .MW(MW), .MH(MH), .TH(TH), + .SIGNED_ACTIVATIONS(CFG.signed_activations), + .NARROW_WEIGHTS(CFG.narrow_weights), + .PUMPED_COMPUTE(0), + .FORCE_BEHAVIORAL(0) + ) dut ( + .ap_clk(clk), + .ap_clk2x(clk2x), + .ap_rst_n(!rst), + .s_axis_weights_tdata(wdat), + .s_axis_weights_tvalid(wvld), + .s_axis_weights_tready(wrdy), + .s_axis_input_tdata(idat), + .s_axis_input_tvalid(ivld), + .s_axis_input_tready(irdy), + .m_axis_output_tdata(odat), + .m_axis_output_tvalid(ovld), + .m_axis_output_tready(ordy) + ); + + //--------------------------------------------------------------- + // Input Feed & Reference Generation + //--------------------------------------------------------------- + // TH input vectors are batched per round. The replay buffer + // stores TH*SF activation words (TH vectors, each SF folds) + // and the weight buffer replays each tile TH times internally. + // + // reorder_out output order: for each TH slot, all NF neuron + // folds in sequence. + //--------------------------------------------------------------- + accu_t [PE-1:0] Q[$]; + initial begin + wdat = 'x; wvld = 0; + idat = 'x; ivld = 0; + @(posedge clk iff !rst); + + // Wait for DSP startup recovery + repeat(20) @(posedge clk); + + repeat(ROUNDS) begin + // TH activation vectors per batch + automatic activation_t [TH-1:0][MW-1:0] ivecs; + automatic weight_t [MH-1:0][MW-1:0] iwgt; + automatic accu_t [TH-1:0][MH-1:0] ovecs; + + // Randomize all inputs + void'(std::randomize(ivecs, iwgt)); + + // Sanitize weights (narrow + overflow) using first vector + for(int unsigned h = 0; h < MH; h++) begin + automatic accu_t p = 0; + for(int unsigned w = 0; w < MW; w++) begin + automatic weight_t w0 = iwgt[h][w]; + automatic accu_t m0, p0; + + if(CFG.narrow_weights && (w0 == weight_t'(1 << (WEIGHT_WIDTH-1)))) w0++; + m0 = w0 * $signed({CFG.signed_activations && ivecs[0][w][ACTIVATION_WIDTH-1], ivecs[0][w]}); + p0 = p + m0; + if(((m0 < 0) == (p < 0)) && ((m0 < 0) != (p0 < 0))) w0 = 0; + else p = p0; + + iwgt[h][w] = w0; + end + end + + // Compute golden reference for each of TH vectors + for(int unsigned y = 0; y < TH; y++) begin + for(int unsigned h = 0; h < MH; h++) begin + automatic accu_t p = 0; + for(int unsigned w = 0; w < MW; w++) begin + p += $signed(iwgt[h][w]) * $signed({CFG.signed_activations && ivecs[y][w][ACTIVATION_WIDTH-1], ivecs[y][w]}); + end + ovecs[y][h] = p; + end + end + + // Enqueue expected outputs in reorder_out order: + // for each TH slot, all NF neuron folds + for(int unsigned y = 0; y < TH; y++) begin + for(int unsigned h = 0; h < MH; h += PE) begin + Q.push_back(ovecs[y][h+:PE]); + end + end + + // Feed activations and weights concurrently + fork + //-- Activation feed -- + // Replay buffer write order: X (SF) inner, Y (TH) outer. + // Feed TH vectors, each SF SIMD-wide words. + begin : blkActFeed + for(int unsigned y = 0; y < TH; y++) begin + for(int unsigned x = 0; x < SF; x++) begin + while($urandom()%19 == 0) @(posedge clk); + idat <= ivecs[y][x*SIMD +: SIMD]; + ivld <= 1; + @(posedge clk iff irdy); + idat <= 'x; + ivld <= 0; + end + end + end : blkActFeed + + //-- Weight feed -- + // One weight matrix, chunked: for each NF, for each SF, + // send TH chunks of WSIMD weights. + begin : blkWgtFeed + for(int unsigned h = 0; h < MH; h += PE) begin + for(int unsigned w = 0; w < MW; w += SIMD) begin + // Build full PE*SIMD weight tile + automatic weight_t [PE-1:0][SIMD-1:0] wtile; + for(int unsigned pe = 0; pe < PE; pe++) begin + for(int unsigned simd = 0; simd < SIMD; simd++) begin + wtile[pe][simd] = iwgt[h+pe][w+simd]; + end + end + + // Slice into TH chunks of WSIMD weights + for(int unsigned chunk = 0; chunk < TH; chunk++) begin + automatic logic [WEIGHT_STREAM_WIDTH_BA-1:0] wword = '0; + for(int unsigned k = 0; k < WSIMD; k++) begin + automatic int unsigned flat_idx = chunk * WSIMD + k; + automatic int unsigned pe_idx = flat_idx / SIMD; + automatic int unsigned simd_idx = flat_idx % SIMD; + wword[k*WEIGHT_WIDTH +: WEIGHT_WIDTH] = wtile[pe_idx][simd_idx]; + end + + while($urandom()%23 == 0) @(posedge clk); + wdat <= wword; + wvld <= 1; + @(posedge clk iff wrdy); + wdat <= 'x; + wvld <= 0; + end + end + end + end : blkWgtFeed + join + end + + repeat(256) @(posedge clk); + assert(Q.size == 0) else begin + $error("Test #%0d: Missing %0d outputs.", t, Q.size); + $stop; + end + done[t] = 1; + end + + //--------------------------------------------------------------- + // Output Checker + //--------------------------------------------------------------- + int unsigned Checks = 0; + initial begin + ordy = 0; + @(posedge clk iff !rst); + + forever begin + automatic accu_t [PE-1:0] exp; + automatic accu_t [PE-1:0] p; + + while(($urandom() % 59) == 0) @(posedge clk); + + // Drain one output + ordy <= 1; + @(posedge clk iff ovld); + ordy <= 0; + + p = odat; + assert(Q.size > 0) else begin + $error("Test #%0d: Spurious output: %0p.", t, p); + $stop; + end + + exp = Q.pop_front(); + assert(p === exp) else begin + $error("Test #%0d: Output mismatch %0p instead of %0p.", t, p, exp); + $stop; + end + + Checks <= Checks + 1; + end + end + + final begin + assert(Checks == ROUNDS * NF * TH) + $display("Test #%0d: OK -- %0d checks (MH=%0d MW=%0d PE=%0d SIMD=%0d TH=%0d).", + t, Checks, MH, MW, PE, SIMD, TH); + else + $error("Test #%0d: Unexpected check count: %0d instead of %0d.", t, Checks, ROUNDS * NF * TH); + end + + end : genTests + +endmodule : mvu_tiled_axi_tb diff --git a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py index 3d1f0d764f..5e4662317e 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py @@ -167,13 +167,11 @@ def instantiate_ip(self, cmd): sourcefiles = [ "../fifo/hdl/Q_srl.v", "../skid/skid.sv", - "../ram/ram_p_c.sv", "mvu_tiled_axi.sv", "cu_mvau_tiled.sv", "acc_stage.sv", "add_tree.sv", - "reorder_out.sv", - "replay_buff_tile.sv", + "input_gen.sv", "weights_buff_tile.sv", ] else: @@ -379,12 +377,10 @@ def get_rtl_file_list(self, abspath=False): verilog_files = [ "../fifo/hdl/Q_srl.v", "../skid/skid.sv", - "../ram/ram_p_c.sv", "acc_stage.sv", "add_tree.sv", - "replay_buff_tile.sv", + "input_gen.sv", "weights_buff_tile.sv", - "reorder_out.sv", "cu_mvau_tiled.sv", "mvu_tiled_axi.sv", ] From 1c3a9cb232679b8de5c87594d4a9d637696b0a95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20B=2E=20Preu=C3=9Fer?= Date: Mon, 1 Jun 2026 15:13:53 +0100 Subject: [PATCH 09/17] Replace add_tree by add_multi. --- finn-rtllib/mvu_tiled/acc_stage.sv | 45 +++--- finn-rtllib/mvu_tiled/add_tree.sv | 145 ------------------ .../rtl/matrixvectoractivation_rtl.py | 6 +- 3 files changed, 31 insertions(+), 165 deletions(-) delete mode 100644 finn-rtllib/mvu_tiled/add_tree.sv diff --git a/finn-rtllib/mvu_tiled/acc_stage.sv b/finn-rtllib/mvu_tiled/acc_stage.sv index d20607a906..a44e3f13ab 100644 --- a/finn-rtllib/mvu_tiled/acc_stage.sv +++ b/finn-rtllib/mvu_tiled/acc_stage.sv @@ -51,30 +51,39 @@ module acc_stage #( ); // -// Adder tree +// Adder tree + accumulator add // -localparam integer TREE_HEIGHT = $clog2(CHAINLEN); -localparam integer ADD_LAT = TREE_HEIGHT + 1; +localparam integer TREE_DEPTH = $clog2(CHAINLEN); +localparam integer ADD_LAT = TREE_DEPTH + 1; logic [PE-1:0][ACCU_WIDTH-1:0] dat_acc; logic [PE-1:0][ACCU_WIDTH-1:0] dat_int; -for(genvar i = 0; i < PE; i++) begin - add_tree #( - .CHAINLEN(CHAINLEN), - .ACCU_WIDTH(ACCU_WIDTH), - .TREE_HEIGHT(TREE_HEIGHT) - ) inst_add_stage ( - .clk(clk), - .rst(rst), - .en(en), - - .idat(idat[i]), - .iacc(dat_acc[i]), - .odat(dat_int[i]) - ); -end +for(genvar i = 0; i < PE; i++) begin : genAdd + // Tree reduction of CHAINLEN DSP partial products + logic [ACCU_WIDTH-1:0] add_arg[CHAINLEN]; + for(genvar k = 0; k < CHAINLEN; k++) + assign add_arg[k] = idat[i][k]; + + localparam int unsigned SUM_WIDTH = $clog2(CHAINLEN) + ACCU_WIDTH; + uwire [SUM_WIDTH-1:0] tree_sum; + add_multi #( + .N(CHAINLEN), + .DEPTH(TREE_DEPTH), + .ARG_WIDTH(ACCU_WIDTH) + ) inst_add ( + .clk(clk), .rst(rst), .en(en), + .arg(add_arg), + .sum(tree_sum) + ); + + // Accumulator add (1 registered stage) + always_ff @(posedge clk) begin + if(rst) dat_int[i] <= 'x; + else if(en) dat_int[i] <= tree_sum[ACCU_WIDTH-1:0] + dat_acc[i]; + end +end : genAdd // REG logic [ADD_LAT:0] val; diff --git a/finn-rtllib/mvu_tiled/add_tree.sv b/finn-rtllib/mvu_tiled/add_tree.sv deleted file mode 100644 index 33eba29087..0000000000 --- a/finn-rtllib/mvu_tiled/add_tree.sv +++ /dev/null @@ -1,145 +0,0 @@ -/****************************************************************************** - * Copyright (C) 2024, Advanced Micro Devices, Inc. - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, - * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR - * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR - * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF - * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - *****************************************************************************/ - -module add_tree #( - parameter CHAINLEN, - parameter ACCU_WIDTH, - parameter TREE_HEIGHT -) ( - input logic clk, - input logic rst, - input logic en, - - input logic [CHAINLEN-1:0][ACCU_WIDTH-1:0] idat, - - input logic [ACCU_WIDTH-1:0] iacc, - - output logic [ACCU_WIDTH-1:0] odat -); - -//-------------------- Adder tree - -function automatic int level_len (int lvl); - return (CHAINLEN + (1 << lvl) - 1) >> lvl; // ceil(CHAINLEN / 2^lvl) -endfunction - -logic signed [ACCU_WIDTH-1:0] add_sf; - -if(CHAINLEN == 1) begin - assign add_sf = idat[0]; -end -else begin - logic signed [TREE_HEIGHT:0][CHAINLEN-1:0][ACCU_WIDTH-1:0] add_s; - - for(genvar i = 0; i < CHAINLEN; i++) begin - assign add_s[0][i] = signed'(idat[i]); - end - - /* - always_ff @(posedge clk) begin - if(rst) begin - for(int i = 1; i <= TREE_HEIGHT; i++) begin - add_s[i] <= '0; - end - end - else begin - if(en) begin - for(int i = 0; i < TREE_HEIGHT; i++) begin - for(int j = 0; j < (CHAINLEN/2 + (2**i-1))/(2**i); j++) begin - add_s[i+1][j] <= $signed(add_s[i][2*j+0]) + $signed(add_s[i][2*j+1]); - end - end - end - end - end - */ - - always_ff @(posedge clk) begin - if (rst) begin - // Clear all levels (safe for unused slots too) - for (int i = 1; i <= TREE_HEIGHT; i++) begin - for (int j = 0; j < CHAINLEN; j++) begin - add_s[i][j] <= '0; - end - end - end else if (en) begin - // For each level i, produce next level i+1 - for (int i = 0; i < TREE_HEIGHT; i++) begin - int src_len = level_len(i); // live elems at level i - int dst_len = level_len(i + 1); // live elems at level i+1 (ceil(src_len/2)) - - // Compute valid outputs only (0..dst_len-1). Leave rest zero. - for (int j = 0; j < dst_len; j++) begin - int a_idx = 2*j; - int b_idx = 2*j + 1; - - // Cases: - // - both indices in range -> sum - // - only a_idx in range -> pass-through - // - neither in range -> zero (shouldn't happen for j Date: Mon, 1 Jun 2026 15:32:11 +0100 Subject: [PATCH 10/17] Style fixes. --- finn-rtllib/mvu_tiled/acc_stage.sv | 244 +++++------ finn-rtllib/mvu_tiled/cu_mvau_tiled.sv | 437 +++++++++---------- finn-rtllib/mvu_tiled/mvu_tiled_axi.sv | 181 ++++---- finn-rtllib/mvu_tiled/tb/mvu_tiled_axi_tb.sv | 23 +- finn-rtllib/mvu_tiled/weights_buff_tile.sv | 407 ++++++++--------- 5 files changed, 606 insertions(+), 686 deletions(-) diff --git a/finn-rtllib/mvu_tiled/acc_stage.sv b/finn-rtllib/mvu_tiled/acc_stage.sv index a44e3f13ab..108ec9155e 100644 --- a/finn-rtllib/mvu_tiled/acc_stage.sv +++ b/finn-rtllib/mvu_tiled/acc_stage.sv @@ -23,7 +23,7 @@ * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. @@ -32,139 +32,117 @@ module acc_stage #( - int unsigned CHAINLEN, - int unsigned PE, - int unsigned ACCU_WIDTH, - int unsigned TH, - int unsigned TH_MAX = 2*TH -) ( - input logic clk, - input logic rst, - input logic en, - - input logic [PE-1:0][CHAINLEN-1:0][ACCU_WIDTH-1:0] idat, - input logic ival, - input logic ilast, - - output logic [PE-1:0][ACCU_WIDTH-1:0] odat, - output logic oval + int unsigned CHAINLEN, + int unsigned PE, + int unsigned ACCU_WIDTH, + int unsigned TH, + int unsigned TH_MAX = 2*TH +)( + input logic clk, + input logic rst, + input logic en, + + input logic [PE-1:0][CHAINLEN-1:0][ACCU_WIDTH-1:0] idat, + input logic ival, + input logic ilast, + + output logic [PE-1:0][ACCU_WIDTH-1:0] odat, + output logic oval ); -// -// Adder tree + accumulator add -// - -localparam integer TREE_DEPTH = $clog2(CHAINLEN); -localparam integer ADD_LAT = TREE_DEPTH + 1; - -logic [PE-1:0][ACCU_WIDTH-1:0] dat_acc; -logic [PE-1:0][ACCU_WIDTH-1:0] dat_int; - -for(genvar i = 0; i < PE; i++) begin : genAdd - // Tree reduction of CHAINLEN DSP partial products - logic [ACCU_WIDTH-1:0] add_arg[CHAINLEN]; - for(genvar k = 0; k < CHAINLEN; k++) - assign add_arg[k] = idat[i][k]; - - localparam int unsigned SUM_WIDTH = $clog2(CHAINLEN) + ACCU_WIDTH; - uwire [SUM_WIDTH-1:0] tree_sum; - add_multi #( - .N(CHAINLEN), - .DEPTH(TREE_DEPTH), - .ARG_WIDTH(ACCU_WIDTH) - ) inst_add ( - .clk(clk), .rst(rst), .en(en), - .arg(add_arg), - .sum(tree_sum) - ); - - // Accumulator add (1 registered stage) - always_ff @(posedge clk) begin - if(rst) dat_int[i] <= 'x; - else if(en) dat_int[i] <= tree_sum[ACCU_WIDTH-1:0] + dat_acc[i]; - end -end : genAdd - -// REG -logic [ADD_LAT:0] val; -logic [ADD_LAT:0] last; - -assign val[0] = ival; -assign last[0] = ilast; - -always_ff @(posedge clk) begin - if(rst) begin - for(int i = 1; i <= ADD_LAT; i++) begin - val[i] <= 1'b0; - last[i] <= 'X; - end - end - else begin - if(en) begin - for(int i = 1; i <= ADD_LAT; i++) begin - val[i] <= val[i-1]; - last[i] <= last[i-1]; - end - end - end -end - -logic val_int; -logic last_int; -logic inc_acc; - -assign val_int = val[ADD_LAT]; -assign last_int = last[ADD_LAT]; -assign inc_acc = val[ADD_LAT-1]; - -// -// Accumulation -// - -logic fifo_in_tvalid, fifo_in_tready; -logic fifo_out_tvalid, fifo_out_tready; -logic [PE*ACCU_WIDTH-1:0] fifo_in_tdata, fifo_out_tdata; - -Q_srl #( - .depth(TH_MAX), - .width(PE*ACCU_WIDTH) -) inst_acc ( - .clock(clk), - .reset(rst), - .i_d(fifo_in_tdata), - .i_v(fifo_in_tvalid), - .i_r(fifo_in_tready), - .o_d(fifo_out_tdata), - .o_v(fifo_out_tvalid), - .o_r(fifo_out_tready), - .count(), - .maxcount() -); - -logic signed [$clog2(TH):0] cnt_prep = -TH; -uwire prep = cnt_prep[$left(cnt_prep)]; -always_ff @(posedge clk) begin - if(rst) cnt_prep <= -TH; - else cnt_prep <= cnt_prep + prep; -end - -always_ff @(posedge clk) begin - if(rst) begin - odat <= 'x; - oval <= 0; - end - else if(en) begin - odat <= dat_int; - oval <= val_int && last_int; - end -end - -always_comb begin - fifo_in_tvalid = prep ? 1'b1 : (en ? val_int : 1'b0); - fifo_in_tdata = prep ? 0 : (last_int ? 0 : dat_int); -end - -assign dat_acc = fifo_out_tdata; -assign fifo_out_tready = en & inc_acc; + //=== Adder Tree + Accumulator Add ====================================== + localparam int unsigned TREE_DEPTH = $clog2(CHAINLEN); + localparam int unsigned ADD_LAT = TREE_DEPTH + 1; + + logic [PE-1:0][ACCU_WIDTH-1:0] Acc; + logic [PE-1:0][ACCU_WIDTH-1:0] DatInt; + + for(genvar i = 0; i < PE; i++) begin : genAdd + // Tree reduction of CHAINLEN DSP partial products + logic [ACCU_WIDTH-1:0] add_arg[CHAINLEN]; + for(genvar k = 0; k < CHAINLEN; k++) + assign add_arg[k] = idat[i][k]; + + localparam int unsigned SUM_WIDTH = $clog2(CHAINLEN) + ACCU_WIDTH; + uwire [SUM_WIDTH-1:0] tree_sum; + add_multi #( + .N(CHAINLEN), + .DEPTH(TREE_DEPTH), + .ARG_WIDTH(ACCU_WIDTH) + ) inst_add ( + .clk(clk), .rst(rst), .en(en), + .arg(add_arg), + .sum(tree_sum) + ); + + // Accumulator add (1 registered stage) + always_ff @(posedge clk) begin + if(rst) DatInt[i] <= 'x; + else if(en) DatInt[i] <= tree_sum[ACCU_WIDTH-1:0] + Acc[i]; + end + end : genAdd + + //=== Valid/Last Pipeline =============================================== + logic [ADD_LAT:0] Val; + logic [ADD_LAT:0] Last; + + assign Val[0] = ival; + assign Last[0] = ilast; + + always_ff @(posedge clk) begin + if(rst) begin + for(int i = 1; i <= ADD_LAT; i++) begin + Val [i] <= 0; + Last[i] <= 'x; + end + end + else if(en) begin + for(int i = 1; i <= ADD_LAT; i++) begin + Val [i] <= Val [i-1]; + Last[i] <= Last[i-1]; + end + end + end + + uwire val_out = Val[ADD_LAT]; + uwire last_out = Last[ADD_LAT]; + uwire inc_acc = Val[ADD_LAT-1]; + + //=== Prep Counter ====================================================== + logic signed [$clog2(TH):0] CntPrep = -TH; + uwire prep = CntPrep[$left(CntPrep)]; + always_ff @(posedge clk) begin + if(rst) CntPrep <= -TH; + else CntPrep <= CntPrep + prep; + end + + //=== Accumulation FIFO ================================================= + Q_srl #( + .depth(TH_MAX), + .width(PE*ACCU_WIDTH) + ) inst_acc ( + .clock(clk), + .reset(rst), + .i_d(prep? {PE*ACCU_WIDTH{1'b0}} : (last_out? {PE*ACCU_WIDTH{1'b0}} : DatInt)), + .i_v(prep? 1 : (en? val_out : 0)), + .i_r(), + .o_d(Acc), + .o_v(), + .o_r(en & inc_acc), + .count(), + .maxcount() + ); + + //=== Output Stage ====================================================== + always_ff @(posedge clk) begin + if(rst) begin + odat <= 'x; + oval <= 0; + end + else if(en) begin + odat <= DatInt; + oval <= val_out && last_out; + end + end endmodule : acc_stage diff --git a/finn-rtllib/mvu_tiled/cu_mvau_tiled.sv b/finn-rtllib/mvu_tiled/cu_mvau_tiled.sv index b716fe22b4..02412ad52f 100644 --- a/finn-rtllib/mvu_tiled/cu_mvau_tiled.sv +++ b/finn-rtllib/mvu_tiled/cu_mvau_tiled.sv @@ -23,7 +23,7 @@ * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. @@ -31,37 +31,33 @@ *****************************************************************************/ module cu_mvau_tiled #( - int unsigned PE, - int unsigned SIMD, - int unsigned TH, - int unsigned WEIGHT_WIDTH, - int unsigned ACTIVATION_WIDTH, - int unsigned ACCU_WIDTH, + int unsigned PE, + int unsigned SIMD, + int unsigned TH, + int unsigned WEIGHT_WIDTH, + int unsigned ACTIVATION_WIDTH, + int unsigned ACCU_WIDTH, - bit SIGNED_ACTIVATIONS = 1, - localparam int unsigned WEIGHT_ELEMENTS = PE*SIMD - ) ( - // Global Control - input logic clk, - input logic rst, - input logic en, + bit SIGNED_ACTIVATIONS = 1, + localparam int unsigned WEIGHT_ELEMENTS = PE*SIMD +)( + input logic clk, + input logic rst, + input logic en, - // Input - input logic ilast, - input logic ivld, - input logic [WEIGHT_ELEMENTS-1:0][WEIGHT_WIDTH-1:0] w, // weights - input logic [SIMD-1:0][ACTIVATION_WIDTH-1:0] a, // activations + input logic ilast, + input logic ivld, + input logic [WEIGHT_ELEMENTS-1:0][WEIGHT_WIDTH-1:0] w, + input logic [SIMD-1:0][ACTIVATION_WIDTH-1:0] a, - // Ouput - output logic ovld, - output logic [PE-1:0][ACCU_WIDTH-1:0] p - ); + output logic ovld, + output logic [PE-1:0][ACCU_WIDTH-1:0] p +); - //----------------------------------------------------------------------- - // Startup Recovery Watchdog - // The DSP slice needs 100ns of recovery time after initial startup before - // being able to ingest input properly. This watchdog discovers violating - // stimuli during simulation and produces a corresponding warning. + //=== Startup Recovery Watchdog ========================================= + // The DSP slice needs 100ns of recovery time after initial startup before + // being able to ingest input properly. This watchdog discovers violating + // stimuli during simulation and produces a corresponding warning. if(1) begin : blkRecoveryWatch logic Dirty = 1; initial begin @@ -76,221 +72,218 @@ module cu_mvau_tiled #( end end : blkRecoveryWatch -//-------------------- Declare global signals --------------------\\ - localparam int unsigned CHAINLEN = (SIMD+2)/3; - uwire [26:0] a_in_i [CHAINLEN]; - uwire [23:0] b_in_i [PE][CHAINLEN]; - uwire [PE-1:0][CHAINLEN-1:0][ACCU_WIDTH-1:0] pout; // Array with packed dimension > 256 (with a loop-carried dependency) cannot be handled out-of-the-box with PyVerilator + //=== Input Formatting ================================================== + localparam int unsigned CHAINLEN = (SIMD+2)/3; + uwire [26:0] a_in_i[CHAINLEN]; + uwire [23:0] b_in_i[PE][CHAINLEN]; + // Array with packed dimension > 256 cannot be handled out-of-the-box with PyVerilator + uwire [PE-1:0][CHAINLEN-1:0][ACCU_WIDTH-1:0] pout; -//-------------------- Shift register for last and valid signals --------------------\\ - localparam int unsigned DSP_PIPELINE_STAGES = 1; - logic L [0:1+DSP_PIPELINE_STAGES] = '{default: 0}; - logic V [0:1+DSP_PIPELINE_STAGES] = '{default: 0}; + //--- Valid/Last Pipeline ----------------------------------------------- + localparam int unsigned DSP_PIPELINE_STAGES = 1; + logic L[0:1+DSP_PIPELINE_STAGES] = '{default: 0}; + logic V[0:1+DSP_PIPELINE_STAGES] = '{default: 0}; always_ff @(posedge clk) begin if(rst) begin - L <= '{default: 0}; - V <= '{default: 0}; - end + L <= '{default: 0}; + V <= '{default: 0}; + end else if(en) begin L[1+DSP_PIPELINE_STAGES] <= ilast; L[0:DSP_PIPELINE_STAGES] <= L[1:1+DSP_PIPELINE_STAGES]; - V[1+DSP_PIPELINE_STAGES] <= ivld; + V[1+DSP_PIPELINE_STAGES] <= ivld; V[0:DSP_PIPELINE_STAGES] <= V[1:1+DSP_PIPELINE_STAGES]; end end - logic last; - logic vld; - assign last = L[0]; - assign vld = V[0]; - -//-------------------- Buffer for input activations --------------------\\ - localparam int unsigned PAD_BITS_ACT = 9 - ACTIVATION_WIDTH; - for (genvar i=0; i 8) begin + if(WEIGHT_WIDTH > 8) begin $error("Weight width of %0d-bits exceeds maximum of 8-bits", WEIGHT_WIDTH); $finish; end - if (ACTIVATION_WIDTH > 8) begin + if(ACTIVATION_WIDTH > 8) begin $error("Activation width of %0d-bits exceeds maximum of 8-bits", ACTIVATION_WIDTH); $finish; end @@ -130,12 +128,12 @@ module mvu_tiled_axi #( uwire rst = !ap_rst_n; - //- Replay to Accommodate Neuron Fold ----------------------------------- + //=== Activation Replay ================================================= typedef logic [SIMD-1:0][ACTIVATION_WIDTH-1:0] mvu_flatin_t; - uwire mvu_flatin_t amvau; - uwire alast; - uwire avld; - uwire ardy; + uwire mvu_flatin_t amvau; + uwire alast; + uwire avld; + uwire ardy; localparam int unsigned SF = MW / SIMD; localparam int unsigned NF = MH / PE; @@ -155,11 +153,11 @@ module mvu_tiled_axi #( ); assign alast = act_done[1]; - //- Unflatten weights --------------------------------------------------- + //=== Weight Buffering ================================================== typedef logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] mvu_w_t; - uwire mvu_w_t wdat; - uwire wvld; - uwire wrdy; + uwire mvu_w_t wdat; + uwire wvld; + uwire wrdy; weights_buff_tile #( .WEIGHT_WIDTH(WEIGHT_WIDTH), @@ -172,58 +170,52 @@ module mvu_tiled_axi #( .ovld(wvld), .ordy(wrdy), .odat(wdat) ); - //- Flow Control Bracket around Compute Core ---------------------------- - uwire en; - uwire istb = avld && wvld; - assign ardy = en && wvld; - assign wrdy = en && avld; + //=== Flow Control ====================================================== + uwire en; + uwire istb = avld && wvld; + assign ardy = en && wvld; + assign wrdy = en && avld; - //- Conditionally Pumped DSP Compute ------------------------------------ + //=== DSP Compute ======================================================= typedef logic [PE-1:0][ACCU_WIDTH-1:0] dsp_p_t; uwire ovld; - uwire dsp_p_t odat; + uwire dsp_p_t odat; if(1) begin : blkDsp - localparam int unsigned EFFECTIVE_SIMD = SIMD_UNEVEN && PUMPED_COMPUTE ? SIMD+1 : SIMD; - localparam int unsigned DSP_SIMD = EFFECTIVE_SIMD/(PUMPED_COMPUTE+1); + localparam int unsigned EFFECTIVE_SIMD = SIMD_UNEVEN && PUMPED_COMPUTE? SIMD+1 : SIMD; + localparam int unsigned DSP_SIMD = EFFECTIVE_SIMD / (PUMPED_COMPUTE+1); typedef logic [PE -1:0][DSP_SIMD-1:0][WEIGHT_WIDTH -1:0] dsp_w_t; typedef logic [DSP_SIMD-1:0][ACTIVATION_WIDTH-1:0] dsp_a_t; uwire dsp_last; - uwire dsp_zero; - uwire dsp_w_t dsp_w; - uwire dsp_a_t dsp_a; + uwire dsp_w_t dsp_w; + uwire dsp_a_t dsp_a; uwire dsp_vld; - uwire dsp_p_t dsp_p; + uwire dsp_p_t dsp_p; // TODO: No double-pumping in the initial implementation - uwire dsp_en = en; + uwire dsp_en = en; assign dsp_last = alast && istb; - assign dsp_zero = !istb; assign dsp_w = wdat; assign dsp_a = amvau; assign ovld = dsp_vld; assign odat = dsp_p; - // - // Compute Unit - // - - case(COMPUTE_CORE) - "mvu_vvu_8sx9_dsp58": begin : core - cu_mvau_tiled #( - .PE(PE), .SIMD(SIMD), - .TH(TH), - .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACTIVATION_WIDTH(ACTIVATION_WIDTH), .ACCU_WIDTH(ACCU_WIDTH), - .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS) - ) inst_cu_mvau_tiled ( - .clk(ap_clk), .rst(rst), .en(dsp_en), + case(COMPUTE_CORE) + "mvu_vvu_8sx9_dsp58": begin : core + cu_mvau_tiled #( + .PE(PE), .SIMD(SIMD), + .TH(TH), + .WEIGHT_WIDTH(WEIGHT_WIDTH), .ACTIVATION_WIDTH(ACTIVATION_WIDTH), .ACCU_WIDTH(ACCU_WIDTH), + .SIGNED_ACTIVATIONS(SIGNED_ACTIVATIONS) + ) inst_cu_mvau_tiled ( + .clk(ap_clk), .rst(rst), .en(dsp_en), .ivld(istb), .ilast(dsp_last), .w(dsp_w), .a(dsp_a), .ovld(dsp_vld), .p(dsp_p) - ); - end + ); + end default: initial begin $error("Unrecognized COMPUTE_CORE '%s'", COMPUTE_CORE); $finish; @@ -232,25 +224,25 @@ module mvu_tiled_axi #( end : blkDsp - //-------------------- Output register slice --------------------\\ - // Make `en`computation independent from external inputs. + //=== Output Register Slice ============================================= + // Make `en` computation independent from external inputs. // Drive all outputs from registers. - logic m_axis_int_tvalid; - logic m_axis_int_tready; - logic [OUTPUT_STREAM_WIDTH_BA-1:0] m_axis_int_tdata; + logic MIntVld; + uwire m_int_rdy; + logic [OUTPUT_STREAM_WIDTH_BA-1:0] MIntDat; struct packed { - logic rdy; - logic [PE-1:0][ACCU_WIDTH-1:0] dat; - } A = '{ rdy: 1, default: 'x }; // side-step register used when encountering backpressure + logic rdy; + logic [PE-1:0][ACCU_WIDTH-1:0] dat; + } A = '{ rdy: 1, default: 'x }; // side-step register used when encountering backpressure struct packed { - logic vld; - logic [PE-1:0][ACCU_WIDTH-1:0] dat; - } B = '{ vld: 0, default: 'x }; // ultimate output register + logic vld; + logic [PE-1:0][ACCU_WIDTH-1:0] dat; + } B = '{ vld: 0, default: 'x }; // ultimate output register assign en = A.rdy; - uwire b_load = !B.vld || m_axis_int_tready; + uwire b_load = !B.vld || m_int_rdy; always_ff @(posedge ap_clk) begin if(rst) begin @@ -269,14 +261,12 @@ module mvu_tiled_axi #( end end end - assign m_axis_int_tvalid = B.vld; - // Why would we need a sign extension here potentially creating a higher signal load into the next FIFO? - // These extra bits should never be used. Why not 'x them out? - assign m_axis_int_tdata = { {(OUTPUT_STREAM_WIDTH_BA-OUTPUT_STREAM_WIDTH){B.dat[PE-1][ACCU_WIDTH-1]}}, B.dat}; + assign MIntVld = B.vld; + assign MIntDat = { {(OUTPUT_STREAM_WIDTH_BA-OUTPUT_STREAM_WIDTH){B.dat[PE-1][ACCU_WIDTH-1]}}, B.dat }; - //-------------------- Output reordering --------------------\\ + //=== Output Reordering ================================================= - if(OUT_TILED == 0) begin + if(OUT_TILED == 0) begin : genReorder input_gen #( .DATA_WIDTH(OUTPUT_STREAM_WIDTH_BA), .FM_SIZE(NF * TH), @@ -285,15 +275,16 @@ module mvu_tiled_axi #( .COEFS('{1, TH}) ) inst_reorder_out ( .clk(ap_clk), .rst(rst), - .idat(m_axis_int_tdata), - .ivld(m_axis_int_tvalid), .irdy(m_axis_int_tready), + .idat(MIntDat), + .ivld(MIntVld), .irdy(m_int_rdy), .odat(m_axis_output_tdata), .ovld(m_axis_output_tvalid), .olst(), .odone(), .ordy(m_axis_output_tready) ); - end else begin - assign m_axis_output_tvalid = m_axis_int_tvalid; - assign m_axis_int_tready = m_axis_output_tready; - assign m_axis_output_tdata = m_axis_int_tdata; - end + end : genReorder + else begin : genPassthru + assign m_axis_output_tvalid = MIntVld; + assign m_int_rdy = m_axis_output_tready; + assign m_axis_output_tdata = MIntDat; + end : genPassthru endmodule : mvu_tiled_axi diff --git a/finn-rtllib/mvu_tiled/tb/mvu_tiled_axi_tb.sv b/finn-rtllib/mvu_tiled/tb/mvu_tiled_axi_tb.sv index 11c973f67e..a45fa7db50 100644 --- a/finn-rtllib/mvu_tiled/tb/mvu_tiled_axi_tb.sv +++ b/finn-rtllib/mvu_tiled/tb/mvu_tiled_axi_tb.sv @@ -10,11 +10,11 @@ * Exercises mvu_tiled_axi with multiple parameter configurations in parallel. * * Data flow under test: - * activations -> replay_buff_tile (replays MH/PE times) + * activations -> input_gen (replays MH/PE times) * weights -> weights_buff_tile (collects NW=TH words, replays TH times) * compute -> cu_mvau_tiled (DSP58 INT8, 3 MACs/DSP) - * accumulate -> acc_stage + add_tree (pipelined tree + circular FIFO) - * reorder -> reorder_out (transpose tiled -> sequential NF order) + * accumulate -> acc_stage (pipelined add_multi tree + circular FIFO) + * reorder -> input_gen (transpose tiled -> sequential NF order) * * Weight feed order: * For each neuron fold (h), for each SIMD fold (w): @@ -80,8 +80,7 @@ module mvu_tiled_axi_tb; '{ 6, 9, 2, 3, 6, 4, 4, 16, 1, 0 } }; - //----------------------------------------------------------------------- - // Global Control + //=== Global Control ==================================================== logic clk = 0; always #5ns clk = !clk; logic clk2x = 0; @@ -100,8 +99,7 @@ module mvu_tiled_axi_tb; if(&done) $finish; end - //----------------------------------------------------------------------- - // Parallel Test Instantiation + //=== Parallel Test Instantiation ======================================= for(genvar t = 0; t < TEST_COUNT; t++) begin : genTests localparam cfg_t CFG = TESTS[t]; localparam int unsigned MH = CFG.mh; @@ -166,16 +164,13 @@ module mvu_tiled_axi_tb; .m_axis_output_tready(ordy) ); - //--------------------------------------------------------------- - // Input Feed & Reference Generation - //--------------------------------------------------------------- + //=== Input Feed & Reference Generation ============================= // TH input vectors are batched per round. The replay buffer // stores TH*SF activation words (TH vectors, each SF folds) // and the weight buffer replays each tile TH times internally. // - // reorder_out output order: for each TH slot, all NF neuron + // Output reorder order: for each TH slot, all NF neuron // folds in sequence. - //--------------------------------------------------------------- accu_t [PE-1:0] Q[$]; initial begin wdat = 'x; wvld = 0; @@ -293,9 +288,7 @@ module mvu_tiled_axi_tb; done[t] = 1; end - //--------------------------------------------------------------- - // Output Checker - //--------------------------------------------------------------- + //=== Output Checker ================================================ int unsigned Checks = 0; initial begin ordy = 0; diff --git a/finn-rtllib/mvu_tiled/weights_buff_tile.sv b/finn-rtllib/mvu_tiled/weights_buff_tile.sv index 258365f74e..f959140cd3 100644 --- a/finn-rtllib/mvu_tiled/weights_buff_tile.sv +++ b/finn-rtllib/mvu_tiled/weights_buff_tile.sv @@ -23,7 +23,7 @@ * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. @@ -31,230 +31,195 @@ *****************************************************************************/ module weights_buff_tile #( - int unsigned WEIGHT_WIDTH = 8, - int unsigned SIMD, - int unsigned PE, - int unsigned TH, - int unsigned WSIMD, - int unsigned NW = (PE*SIMD)/WSIMD, - int unsigned N_DCPL_STAGES + int unsigned WEIGHT_WIDTH = 8, + int unsigned SIMD, + int unsigned PE, + int unsigned TH, + int unsigned WSIMD, + int unsigned NW = (PE*SIMD)/WSIMD, + int unsigned N_DCPL_STAGES )( input logic clk, input logic rst, - input logic ivld, - output logic irdy, - input logic [WSIMD-1:0][WEIGHT_WIDTH-1:0] idat, + input logic ivld, + output logic irdy, + input logic [WSIMD-1:0][WEIGHT_WIDTH-1:0] idat, - output logic ovld, - input logic ordy, - output logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat + output logic ovld, + input logic ordy, + output logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat ); - -//-------------------- Parameter sanity checks -------------------------------- - -initial begin - if ((PE*SIMD) % WSIMD != 0) begin - $error("Weight stream width not set properly (WSIMD: %0d, PE %0d, SIMD %0d).", WSIMD, PE, SIMD); - $finish; - end -end - -// ---------------------------------------------------------------------------- -// Consts and types -// ---------------------------------------------------------------------------- - -localparam integer NW_BITS = (NW == 1) ? 1 : $clog2(NW); -localparam integer TH_BITS = (TH == 1) ? 1 : $clog2(TH); - -typedef enum logic[1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_t; -typedef enum logic {ST_RD_0, ST_RD_1} state_rd_t; - -// ---------------------------------------------------------------------------- -// Slice -// ---------------------------------------------------------------------------- - -logic ivld_int; -logic irdy_int; -logic [WSIMD-1:0][WEIGHT_WIDTH-1:0] idat_int; - -// Ireg -skid #(.DATA_WIDTH(WSIMD*WEIGHT_WIDTH), .FEED_STAGES(1)) inst_ireg ( - .clk(clk), .rst(rst), - .ivld(ivld), .irdy(irdy), .idat(idat), - .ovld(ivld_int), .ordy(irdy_int), .odat(idat_int) -); - -// ---------------------------------------------------------------------------- -// Writer -// ---------------------------------------------------------------------------- - -// -- Regs -state_wr_t state_wr_C = ST_WR_0, state_wr_N; -state_rd_t state_rd_C = ST_RD_0, state_rd_N; - -logic [NW_BITS-1:0] curr_C = '0, curr_N; - -logic done; - -logic ovld_int; -logic ordy_int; -logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_int; - -// -- Mem -logic [1:0][NW-1:0][WSIMD*WEIGHT_WIDTH-1:0] mem_C, mem_N; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_WR - if(rst) begin - state_wr_C <= ST_WR_0; - - curr_C <= '0; - mem_C <= '0; - end - else begin - state_wr_C <= state_wr_N; - - curr_C <= curr_N; - mem_C <= mem_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_WR - state_wr_N = state_wr_C; - - case (state_wr_C) - ST_WR_0: - if ((curr_C == NW - 1) && ivld_int) begin - state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; - end - - ST_WR_0_WAIT: - state_wr_N = (done || (state_rd_C == ST_RD_0)) ? ST_WR_1 : ST_WR_0_WAIT; - - ST_WR_1: - if ((curr_C == NW - 1) && ivld_int) begin - state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; - end - - ST_WR_1_WAIT: - state_wr_N = (done || (state_rd_C == ST_RD_1)) ? ST_WR_0 : ST_WR_1_WAIT; - - endcase -end - -// -- DP -always_comb begin : DP_PROC_WR - curr_N = curr_C; - mem_N = mem_C; - - // Input - irdy_int = 1'b0; - - // Write and count - case (state_wr_C) - ST_WR_0, ST_WR_1: begin - irdy_int = 1'b1; - - if(ivld_int) begin - if(state_wr_C == ST_WR_0) begin - mem_N[0] = (mem_C[0] >> WSIMD*WEIGHT_WIDTH); - mem_N[0][NW-1] = idat_int; - end - else begin - mem_N[1] = (mem_C[1] >> WSIMD*WEIGHT_WIDTH); - mem_N[1][NW-1] = idat_int; - end - - curr_N = (curr_C == NW-1) ? 0 : curr_C + 1; - end - end - endcase -end - -// ---------------------------------------------------------------------------- -// Reader -// ---------------------------------------------------------------------------- - -// -- Regs -logic [TH_BITS-1:0] cons_r_C = '0, cons_r_N; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_RD - if(rst) begin - state_rd_C <= ST_RD_0; - - cons_r_C <= 0; - end - else begin - state_rd_C <= state_rd_N; - - cons_r_C <= cons_r_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_RD - state_rd_N = state_rd_C; - - case (state_rd_C) - ST_RD_0: - if(ordy_int && (state_wr_C != ST_WR_0)) begin - if(cons_r_C == TH-1) begin - state_rd_N = ST_RD_1; - end - end - - ST_RD_1: - if(ordy_int && (state_wr_C != ST_WR_1)) begin - if(cons_r_C == TH-1) begin - state_rd_N = ST_RD_0; - end - end - - endcase -end - -// -- DP -always_comb begin : DP_PROC_RD - cons_r_N = cons_r_C; - - done = 1'b0; - - ovld_int = 1'b0; - odat_int = 0; - - case (state_rd_C) - ST_RD_0: begin - if(ordy_int && (state_wr_C != ST_WR_0)) begin - ovld_int = 1'b1; - odat_int = mem_C[0]; - - done = (cons_r_C == TH-1); - cons_r_N = (cons_r_C == TH-1) ? 0 : cons_r_C + 1; - end - end - - ST_RD_1: begin - if(ordy_int && (state_wr_C != ST_WR_1)) begin - ovld_int = 1'b1; - odat_int = mem_C[1]; - - done = (cons_r_C == TH-1); - cons_r_N = (cons_r_C == TH-1) ? 0 : cons_r_C + 1; - end - end - - endcase -end - -// Oreg -skid #(.DATA_WIDTH(PE*SIMD*WEIGHT_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg ( - .clk(clk), .rst(rst), - .ivld(ovld_int), .irdy(ordy_int), .idat(odat_int), - .ovld(ovld), .ordy(ordy), .odat(odat) -); - -endmodule + //=== Parameter Validation ============================================== + initial begin + if((PE*SIMD) % WSIMD != 0) begin + $error("Weight stream width not set properly (WSIMD: %0d, PE %0d, SIMD %0d).", WSIMD, PE, SIMD); + $finish; + end + end + + //=== Constants and Types =============================================== + localparam int unsigned NW_BITS = (NW == 1)? 1 : $clog2(NW); + localparam int unsigned TH_BITS = (TH == 1)? 1 : $clog2(TH); + + typedef enum logic [1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_e; + typedef enum logic {ST_RD_0, ST_RD_1} state_rd_e; + + //=== Input Slice ======================================================= + uwire ivld_int; + logic irdy_int; + uwire [WSIMD-1:0][WEIGHT_WIDTH-1:0] idat_int; + + skid #(.DATA_WIDTH(WSIMD*WEIGHT_WIDTH), .FEED_STAGES(1)) inst_ireg ( + .clk(clk), .rst(rst), + .ivld(ivld), .irdy(irdy), .idat(idat), + .ovld(ivld_int), .ordy(irdy_int), .odat(idat_int) + ); + + //=== Writer ============================================================ + state_wr_e StateWr = ST_WR_0; + state_wr_e state_wr_n; + state_rd_e StateRd = ST_RD_0; + state_rd_e state_rd_n; + + logic [NW_BITS-1:0] Curr = '0; + logic [NW_BITS-1:0] curr_n; + + logic done; + + logic ovld_int; + logic ordy_int; + logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_int; + + logic [1:0][NW-1:0][WSIMD*WEIGHT_WIDTH-1:0] Mem = '0; + logic [1:0][NW-1:0][WSIMD*WEIGHT_WIDTH-1:0] mem_n; + + //--- Writer Sequential ------------------------------------------------- + always_ff @(posedge clk) begin + if(rst) begin + StateWr <= ST_WR_0; + Curr <= '0; + Mem <= '0; + end + else begin + StateWr <= state_wr_n; + Curr <= curr_n; + Mem <= mem_n; + end + end + + //--- Writer Next State ------------------------------------------------- + always_comb begin + state_wr_n = StateWr; + + case(StateWr) + ST_WR_0: + if((Curr == NW - 1) && ivld_int) + state_wr_n = (done || (StateRd == ST_RD_0))? ST_WR_1 : ST_WR_0_WAIT; + + ST_WR_0_WAIT: + state_wr_n = (done || (StateRd == ST_RD_0))? ST_WR_1 : ST_WR_0_WAIT; + + ST_WR_1: + if((Curr == NW - 1) && ivld_int) + state_wr_n = (done || (StateRd == ST_RD_1))? ST_WR_0 : ST_WR_1_WAIT; + + ST_WR_1_WAIT: + state_wr_n = (done || (StateRd == ST_RD_1))? ST_WR_0 : ST_WR_1_WAIT; + endcase + end + + //--- Writer Datapath --------------------------------------------------- + always_comb begin + curr_n = Curr; + mem_n = Mem; + irdy_int = 0; + + case(StateWr) + ST_WR_0, ST_WR_1: begin + irdy_int = 1; + + if(ivld_int) begin + if(StateWr == ST_WR_0) begin + mem_n[0] = (Mem[0] >> WSIMD*WEIGHT_WIDTH); + mem_n[0][NW-1] = idat_int; + end + else begin + mem_n[1] = (Mem[1] >> WSIMD*WEIGHT_WIDTH); + mem_n[1][NW-1] = idat_int; + end + + curr_n = (Curr == NW-1)? 0 : Curr + 1; + end + end + endcase + end + + //=== Reader ============================================================ + logic [TH_BITS-1:0] ConsR = '0; + logic [TH_BITS-1:0] cons_r_n; + + //--- Reader Sequential ------------------------------------------------- + always_ff @(posedge clk) begin + if(rst) begin + StateRd <= ST_RD_0; + ConsR <= 0; + end + else begin + StateRd <= state_rd_n; + ConsR <= cons_r_n; + end + end + + //--- Reader Next State ------------------------------------------------- + always_comb begin + state_rd_n = StateRd; + + case(StateRd) + ST_RD_0: + if(ordy_int && (StateWr != ST_WR_0)) + if(ConsR == TH-1) + state_rd_n = ST_RD_1; + + ST_RD_1: + if(ordy_int && (StateWr != ST_WR_1)) + if(ConsR == TH-1) + state_rd_n = ST_RD_0; + endcase + end + + //--- Reader Datapath --------------------------------------------------- + always_comb begin + cons_r_n = ConsR; + done = 0; + ovld_int = 0; + odat_int = 0; + + case(StateRd) + ST_RD_0: + if(ordy_int && (StateWr != ST_WR_0)) begin + ovld_int = 1; + odat_int = Mem[0]; + done = (ConsR == TH-1); + cons_r_n = (ConsR == TH-1)? 0 : ConsR + 1; + end + + ST_RD_1: + if(ordy_int && (StateWr != ST_WR_1)) begin + ovld_int = 1; + odat_int = Mem[1]; + done = (ConsR == TH-1); + cons_r_n = (ConsR == TH-1)? 0 : ConsR + 1; + end + endcase + end + + //=== Output Slice ====================================================== + skid #(.DATA_WIDTH(PE*SIMD*WEIGHT_WIDTH), .FEED_STAGES(N_DCPL_STAGES)) inst_oreg ( + .clk(clk), .rst(rst), + .ivld(ovld_int), .irdy(ordy_int), .idat(odat_int), + .ovld(ovld), .ordy(ordy), .odat(odat) + ); + +endmodule : weights_buff_tile From 0dacdd8cb9868e4ddecb3bb0643c0c8fcbc0da63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20B=2E=20Preu=C3=9Fer?= Date: Tue, 2 Jun 2026 06:50:10 +0100 Subject: [PATCH 11/17] Review of fetch_weights. --- finn-rtllib/fetch_weights/fetch_weights.sv | 576 +++++++++--------- .../fetch_weights/fetch_weights_wrapper.v | 8 +- .../fetch_weights/local_weight_buffer.sv | 513 ++++++++-------- finn-rtllib/mvu_tiled/acc_stage.sv | 6 +- finn-rtllib/mvu_tiled/cu_mvau_tiled.sv | 7 +- src/finn/custom_op/fpgadataflow/hwcustomop.py | 2 - 6 files changed, 540 insertions(+), 572 deletions(-) diff --git a/finn-rtllib/fetch_weights/fetch_weights.sv b/finn-rtllib/fetch_weights/fetch_weights.sv index 73aff27039..bad9304c79 100644 --- a/finn-rtllib/fetch_weights/fetch_weights.sv +++ b/finn-rtllib/fetch_weights/fetch_weights.sv @@ -31,288 +31,298 @@ *****************************************************************************/ module fetch_weights #( - int unsigned PE, - int unsigned SIMD, - int unsigned TH = 1, - int unsigned MH, - int unsigned MW, - int unsigned N_REPS, - int unsigned WEIGHT_WIDTH = 8, - - int unsigned ADDR_BITS = 64, - int unsigned DATA_BITS = 256, - int unsigned LEN_BITS = 32, - int unsigned IDX_BITS = 16, - - int unsigned N_LAYERS, - - int unsigned EN_MLO = 1, - - int unsigned QDEPTH = 8, - int unsigned EN_OREG = 1, - int unsigned N_DCPL_STGS = 1, - int unsigned DBG = 0, - - // Safely deducible parameters - int unsigned IWSIMD = (TH > 1) ? ((PE*SIMD)/TH) : SIMD, - int unsigned OWSIMD = (PE * SIMD) / TH, - int unsigned DS_BITS_BA = (IWSIMD*WEIGHT_WIDTH+7)/8 * 8, - int unsigned WS_BITS_BA = (OWSIMD*WEIGHT_WIDTH+7)/8 * 8, - logic[ADDR_BITS-1:0] LAYER_OFFS = ((MH*MW*WEIGHT_WIDTH+7)/8 + (DATA_BITS/8-1)) & ~(DATA_BITS/8-1) // AXI bus-width aligned -) ( - input wire aclk, - input wire aresetn, - - output logic m_done, - - // AXI - output logic[ADDR_BITS-1:0] m_axi_ddr_araddr, - output logic[1:0] m_axi_ddr_arburst, - output logic[3:0] m_axi_ddr_arcache, - output logic[1:0] m_axi_ddr_arid, - output logic[7:0] m_axi_ddr_arlen, - output logic[0:0] m_axi_ddr_arlock, - output logic[2:0] m_axi_ddr_arprot, - output logic[2:0] m_axi_ddr_arsize, - input logic m_axi_ddr_arready, - output logic m_axi_ddr_arvalid, - output logic[ADDR_BITS-1:0] m_axi_ddr_awaddr, - output logic[1:0] m_axi_ddr_awburst, - output logic[3:0] m_axi_ddr_awcache, - output logic[1:0] m_axi_ddr_awid, - output logic[7:0] m_axi_ddr_awlen, - output logic[0:0] m_axi_ddr_awlock, - output logic[2:0] m_axi_ddr_awprot, - output logic[2:0] m_axi_ddr_awsize, - input logic m_axi_ddr_awready, - output logic m_axi_ddr_awvalid, - input logic[DATA_BITS-1:0] m_axi_ddr_rdata, - input logic[1:0] m_axi_ddr_rid, - input logic m_axi_ddr_rlast, - input logic[1:0] m_axi_ddr_rresp, - output logic m_axi_ddr_rready, - input logic m_axi_ddr_rvalid, - output logic[DATA_BITS-1:0] m_axi_ddr_wdata, - output logic m_axi_ddr_wlast, - output logic[DATA_BITS/8-1:0] m_axi_ddr_wstrb, - input logic m_axi_ddr_wready, - output logic m_axi_ddr_wvalid, - input logic[1:0] m_axi_ddr_bid, - input logic[1:0] m_axi_ddr_bresp, - output logic m_axi_ddr_bready, - input logic m_axi_ddr_bvalid, - - // Index - input logic s_idx_tvalid, - output logic s_idx_tready, - input logic[IDX_BITS-1:0] s_idx_tdata, - - // DMA stream out (to external width converter) - output logic axis_dma_tvalid, - input logic axis_dma_tready, - output logic[DATA_BITS-1:0] axis_dma_tdata, - output logic[DATA_BITS/8-1:0] axis_dma_tkeep, - output logic axis_dma_tlast, - - // DWC stream in (from external width converter) - input logic axis_dwc_tvalid, - output logic axis_dwc_tready, - input logic[DS_BITS_BA-1:0] axis_dwc_tdata, - input logic[(DS_BITS_BA)/8-1:0] axis_dwc_tkeep, - input logic axis_dwc_tlast, - - // Stream - // TODO: Should we reg this? Would be quite wide ... - output logic m_axis_tvalid, - input logic m_axis_tready, - output logic[WS_BITS_BA-1:0] m_axis_tdata + int unsigned PE, + int unsigned SIMD, + int unsigned TH = 1, + int unsigned MH, + int unsigned MW, + int unsigned N_REPS, + int unsigned WEIGHT_WIDTH = 8, + + int unsigned ADDR_BITS = 64, + int unsigned DATA_BITS = 256, + int unsigned LEN_BITS = 32, + int unsigned IDX_BITS = 16, + + int unsigned N_LAYERS, + + int unsigned QDEPTH = 8, + int unsigned EN_OREG = 1, + int unsigned N_DCPL_STGS = 1, + int unsigned DBG = 0, + + // Safely deducible parameters + int unsigned IWSIMD = (TH > 1)? ((PE*SIMD)/TH) : SIMD, + int unsigned OWSIMD = (PE * SIMD) / TH, + int unsigned DS_BITS_BA = (IWSIMD*WEIGHT_WIDTH+7)/8 * 8, + int unsigned WS_BITS_BA = (OWSIMD*WEIGHT_WIDTH+7)/8 * 8, + logic[ADDR_BITS-1:0] LAYER_OFFS = ((MH*MW*WEIGHT_WIDTH+7)/8 + (DATA_BITS/8-1)) & ~(DATA_BITS/8-1) // AXI bus-width aligned +)( + input logic aclk, + input logic aresetn, + + output logic m_done, + + // AXI + output logic[ADDR_BITS-1:0] m_axi_ddr_araddr, + output logic[1:0] m_axi_ddr_arburst, + output logic[3:0] m_axi_ddr_arcache, + output logic[1:0] m_axi_ddr_arid, + output logic[7:0] m_axi_ddr_arlen, + output logic[0:0] m_axi_ddr_arlock, + output logic[2:0] m_axi_ddr_arprot, + output logic[2:0] m_axi_ddr_arsize, + input logic m_axi_ddr_arready, + output logic m_axi_ddr_arvalid, + output logic[ADDR_BITS-1:0] m_axi_ddr_awaddr, + output logic[1:0] m_axi_ddr_awburst, + output logic[3:0] m_axi_ddr_awcache, + output logic[1:0] m_axi_ddr_awid, + output logic[7:0] m_axi_ddr_awlen, + output logic[0:0] m_axi_ddr_awlock, + output logic[2:0] m_axi_ddr_awprot, + output logic[2:0] m_axi_ddr_awsize, + input logic m_axi_ddr_awready, + output logic m_axi_ddr_awvalid, + input logic[DATA_BITS-1:0] m_axi_ddr_rdata, + input logic[1:0] m_axi_ddr_rid, + input logic m_axi_ddr_rlast, + input logic[1:0] m_axi_ddr_rresp, + output logic m_axi_ddr_rready, + input logic m_axi_ddr_rvalid, + output logic[DATA_BITS-1:0] m_axi_ddr_wdata, + output logic m_axi_ddr_wlast, + output logic[DATA_BITS/8-1:0] m_axi_ddr_wstrb, + input logic m_axi_ddr_wready, + output logic m_axi_ddr_wvalid, + input logic[1:0] m_axi_ddr_bid, + input logic[1:0] m_axi_ddr_bresp, + output logic m_axi_ddr_bready, + input logic m_axi_ddr_bvalid, + + // Index + input logic s_idx_tvalid, + output logic s_idx_tready, + input logic[IDX_BITS-1:0] s_idx_tdata, + + // DMA stream out (to external width converter) + output logic axis_dma_tvalid, + input logic axis_dma_tready, + output logic[DATA_BITS-1:0] axis_dma_tdata, + output logic[DATA_BITS/8-1:0] axis_dma_tkeep, + output logic axis_dma_tlast, + + // DWC stream in (from external width converter) + input logic axis_dwc_tvalid, + output logic axis_dwc_tready, + input logic[DS_BITS_BA-1:0] axis_dwc_tdata, + input logic[(DS_BITS_BA)/8-1:0] axis_dwc_tkeep, + input logic axis_dwc_tlast, + + // Stream + output logic m_axis_tvalid, + input logic m_axis_tready, + output logic[WS_BITS_BA-1:0] m_axis_tdata ); -// Offsets -logic [N_LAYERS-1:0][ADDR_BITS-1:0] l_offsets; -for(genvar i = 0; i < N_LAYERS; i++) begin - assign l_offsets[i] = (i * LAYER_OFFS); -end - -// -// Indexes and DMA -// - -logic dma_tvalid; -logic dma_tready; -logic [ADDR_BITS-1:0] dma_addr; -logic [LEN_BITS-1:0] dma_len; - -if(TH > 1) begin - - // Consts - localparam integer REPS_BITS = (N_REPS == 1) ? 1 : $clog2(N_REPS); - - // Reps - typedef enum logic[0:0] {ST_IDLE, ST_DMA} state_t; - state_t state_C = ST_IDLE, state_N; - - logic [REPS_BITS-1:0] cnt_dma_C = '0, cnt_dma_N; - logic [IDX_BITS-1:0] idx_C = '0, idx_N; - - logic q_idx_out_tvalid, q_idx_out_tready; - logic [IDX_BITS-1:0] q_idx_out_tdata; - - // Idx queue - Q_srl #( - .depth(QDEPTH), - .width(IDX_BITS) - ) inst_queue_in ( - .clock(aclk), .reset(!aresetn), - .count(), .maxcount(), - .i_d(s_idx_tdata), .i_v(s_idx_tvalid), .i_r(s_idx_tready), - .o_d(q_idx_out_tdata), .o_v(q_idx_out_tvalid), .o_r(q_idx_out_tready) - ); - - assign dma_addr = l_offsets[idx_C]; - assign dma_len = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7; - - always_ff @( posedge aclk ) begin: REG - if(~aresetn) begin - state_C <= ST_IDLE; - - cnt_dma_C <= '0; - idx_C <= 'X; - end else begin - state_C <= state_N; - - cnt_dma_C <= cnt_dma_N; - idx_C <= idx_N; - end - end - - always_comb begin: NSL - state_N = state_C; - - case (state_C) - ST_IDLE: - state_N = q_idx_out_tvalid ? ST_DMA : ST_IDLE; - - ST_DMA: - state_N = (cnt_dma_C == N_REPS-1) && dma_tready ? ST_IDLE : ST_DMA; - - endcase - end - - always_comb begin: DP - cnt_dma_N = cnt_dma_C; - idx_N = idx_C; - - q_idx_out_tready = 1'b0; - dma_tvalid = 1'b0; - - case (state_C) - ST_IDLE: begin - q_idx_out_tready = 1'b1; - cnt_dma_N = 0; - if(q_idx_out_tvalid) begin - idx_N = q_idx_out_tdata; - end - end - - ST_DMA: begin - dma_tvalid = 1'b1; - if(dma_tready) begin - cnt_dma_N = cnt_dma_C + 1; - end - end - - endcase - end - -end else begin - - // Idx queue - logic [IDX_BITS-1:0] q_idx_out_tdata; - - Q_srl #( - .depth(QDEPTH), - .width(IDX_BITS) - ) inst_idx_queue ( - .clock(aclk), .reset(!aresetn), - .count(), .maxcount(), - .i_d(s_idx_tdata), .i_v(s_idx_tvalid), .i_r(s_idx_tready), - .o_d(q_idx_out_tdata), .o_v(dma_tvalid), .o_r(dma_tready) - ); - - assign dma_addr = l_offsets[q_idx_out_tdata]; - assign dma_len = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7; - -end - -cdma_u_rd #( - .DATA_BITS(DATA_BITS), - .ADDR_BITS(ADDR_BITS), - .LEN_BITS(LEN_BITS) -) inst_dma ( - .aclk(aclk), .aresetn(aresetn), - - .rd_valid(dma_tvalid), .rd_ready(dma_tready), - .rd_paddr(dma_addr), .rd_len(dma_len), - .rd_done(m_done), - - .m_axi_ddr_arvalid(m_axi_ddr_arvalid), - .m_axi_ddr_arready(m_axi_ddr_arready), - .m_axi_ddr_araddr(m_axi_ddr_araddr), - .m_axi_ddr_arid(m_axi_ddr_arid), - .m_axi_ddr_arlen(m_axi_ddr_arlen), - .m_axi_ddr_arsize(m_axi_ddr_arsize), - .m_axi_ddr_arburst(m_axi_ddr_arburst), - .m_axi_ddr_arlock(m_axi_ddr_arlock), - .m_axi_ddr_arcache(m_axi_ddr_arcache), - .m_axi_ddr_arprot(m_axi_ddr_arprot), - .m_axi_ddr_rvalid(m_axi_ddr_rvalid), - .m_axi_ddr_rready(m_axi_ddr_rready), - .m_axi_ddr_rdata(m_axi_ddr_rdata), - .m_axi_ddr_rlast(m_axi_ddr_rlast), - .m_axi_ddr_rid(m_axi_ddr_rid), - .m_axi_ddr_rresp(m_axi_ddr_rresp), - - .m_axis_ddr_tvalid(axis_dma_tvalid), - .m_axis_ddr_tready(axis_dma_tready), - .m_axis_ddr_tdata(axis_dma_tdata), - .m_axis_ddr_tkeep(axis_dma_tkeep), - .m_axis_ddr_tlast(axis_dma_tlast) -); - -// Local weight buffer -// Only for non-tiled nodes -logic axis_lwb_tvalid; -logic axis_lwb_tready; -logic[WS_BITS_BA-1:0] axis_lwb_tdata; - -if(TH == 1) begin - local_weight_buffer #( - .PE(PE), .SIMD(SIMD), .MH(MH), .MW(MW), .N_REPS(N_REPS), .WEIGHT_WIDTH(WEIGHT_WIDTH), .DBG(DBG) - ) inst_weight_buff ( - .clk(aclk), .rst(~aresetn), - .ivld(axis_dwc_tvalid), .irdy(axis_dwc_tready), .idat(axis_dwc_tdata), - .ovld(axis_lwb_tvalid), .ordy(axis_lwb_tready), .odat(axis_lwb_tdata) - ); -end else begin - assign axis_lwb_tvalid = axis_dwc_tvalid; - assign axis_dwc_tready = axis_lwb_tready; - assign axis_lwb_tdata = axis_dwc_tdata; -end - -// Reg slice -if(EN_OREG) begin - skid #( - .DATA_WIDTH(WS_BITS_BA), .FEED_STAGES(N_DCPL_STGS) - ) inst_oreg ( - .clk(aclk), .rst(!aresetn), - .ivld(axis_lwb_tvalid), .irdy(axis_lwb_tready), .idat(axis_lwb_tdata), - .ovld(m_axis_tvalid), .ordy(m_axis_tready), .odat(m_axis_tdata) - ); -end else begin - assign m_axis_tvalid = axis_lwb_tvalid; - assign axis_lwb_tready = m_axis_tready; - assign m_axis_tdata = axis_lwb_tdata; -end - -endmodule + //=== Layer Offsets ===================================================== + logic [N_LAYERS-1:0][ADDR_BITS-1:0] l_offsets; + for(genvar i = 0; i < N_LAYERS; i++) begin : genOffs + assign l_offsets[i] = i * LAYER_OFFS; + end : genOffs + + //=== Index Handling & DMA Control ====================================== + logic dma_tvalid; + logic dma_tready; + logic [ADDR_BITS-1:0] dma_addr; + logic [ LEN_BITS-1:0] dma_len; + + if(TH > 1) begin : genTiled + + localparam int unsigned REPS_BITS = (N_REPS == 1)? 1 : $clog2(N_REPS); + + typedef enum logic [0:0] {ST_IDLE, ST_DMA} state_e; + + //--- Registers ----------------------------------------------------- + state_e State = ST_IDLE; + state_e state_n; + + logic [REPS_BITS-1:0] CntDma = '0; + logic [REPS_BITS-1:0] cnt_dma_n; + + logic [IDX_BITS-1:0] Idx = '0; + logic [IDX_BITS-1:0] idx_n; + + //--- Index Queue --------------------------------------------------- + uwire q_idx_vld; + logic q_idx_rdy; + uwire [IDX_BITS-1:0] q_idx_dat; + + Q_srl #(.depth(QDEPTH), .width(IDX_BITS)) inst_queue_in ( + .clock(aclk), .reset(!aresetn), + .count(), .maxcount(), + .i_d(s_idx_tdata), .i_v(s_idx_tvalid), .i_r(s_idx_tready), + .o_d(q_idx_dat), .o_v(q_idx_vld), .o_r(q_idx_rdy) + ); + + assign dma_addr = l_offsets[Idx]; + assign dma_len = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7; + + //--- Sequential ---------------------------------------------------- + always_ff @(posedge aclk) begin + if(~aresetn) begin + State <= ST_IDLE; + CntDma <= '0; + Idx <= 'x; + end + else begin + State <= state_n; + CntDma <= cnt_dma_n; + Idx <= idx_n; + end + end + + //--- Next State ---------------------------------------------------- + always_comb begin + state_n = State; + + case(State) + ST_IDLE: + state_n = q_idx_vld? ST_DMA : ST_IDLE; + + ST_DMA: + state_n = ((CntDma == N_REPS-1) && dma_tready)? ST_IDLE : ST_DMA; + endcase + end + + //--- Datapath ------------------------------------------------------ + always_comb begin + cnt_dma_n = CntDma; + idx_n = Idx; + + q_idx_rdy = 0; + dma_tvalid = 0; + + case(State) + ST_IDLE: begin + q_idx_rdy = 1; + cnt_dma_n = 0; + if(q_idx_vld) + idx_n = q_idx_dat; + end + + ST_DMA: begin + dma_tvalid = 1; + if(dma_tready) + cnt_dma_n = CntDma + 1; + end + endcase + end + + end : genTiled + else begin : genDirect + + uwire [IDX_BITS-1:0] q_idx_dat; + + Q_srl #(.depth(QDEPTH), .width(IDX_BITS)) inst_idx_queue ( + .clock(aclk), .reset(!aresetn), + .count(), .maxcount(), + .i_d(s_idx_tdata), .i_v(s_idx_tvalid), .i_r(s_idx_tready), + .o_d(q_idx_dat), .o_v(dma_tvalid), .o_r(dma_tready) + ); + + assign dma_addr = l_offsets[q_idx_dat]; + assign dma_len = ((MH*MW*WEIGHT_WIDTH+7)/8) & ~7; + + end : genDirect + + //=== Write Channel Tie-off (read-only DMA) ============================= + assign m_axi_ddr_awaddr = '0; + assign m_axi_ddr_awburst = '0; + assign m_axi_ddr_awcache = '0; + assign m_axi_ddr_awid = '0; + assign m_axi_ddr_awlen = '0; + assign m_axi_ddr_awlock = '0; + assign m_axi_ddr_awprot = '0; + assign m_axi_ddr_awsize = '0; + assign m_axi_ddr_awvalid = 0; + assign m_axi_ddr_wdata = '0; + assign m_axi_ddr_wlast = 0; + assign m_axi_ddr_wstrb = '0; + assign m_axi_ddr_wvalid = 0; + assign m_axi_ddr_bready = 0; + + //=== DMA Engine ======================================================== + cdma_u_rd #( + .DATA_BITS(DATA_BITS), + .ADDR_BITS(ADDR_BITS), + .LEN_BITS(LEN_BITS) + ) inst_dma ( + .aclk(aclk), .aresetn(aresetn), + + .rd_valid(dma_tvalid), .rd_ready(dma_tready), + .rd_paddr(dma_addr), .rd_len(dma_len), + .rd_done(m_done), + + .m_axi_ddr_arvalid(m_axi_ddr_arvalid), + .m_axi_ddr_arready(m_axi_ddr_arready), + .m_axi_ddr_araddr(m_axi_ddr_araddr), + .m_axi_ddr_arid(m_axi_ddr_arid), + .m_axi_ddr_arlen(m_axi_ddr_arlen), + .m_axi_ddr_arsize(m_axi_ddr_arsize), + .m_axi_ddr_arburst(m_axi_ddr_arburst), + .m_axi_ddr_arlock(m_axi_ddr_arlock), + .m_axi_ddr_arcache(m_axi_ddr_arcache), + .m_axi_ddr_arprot(m_axi_ddr_arprot), + .m_axi_ddr_rvalid(m_axi_ddr_rvalid), + .m_axi_ddr_rready(m_axi_ddr_rready), + .m_axi_ddr_rdata(m_axi_ddr_rdata), + .m_axi_ddr_rlast(m_axi_ddr_rlast), + .m_axi_ddr_rid(m_axi_ddr_rid), + .m_axi_ddr_rresp(m_axi_ddr_rresp), + + .m_axis_ddr_tvalid(axis_dma_tvalid), + .m_axis_ddr_tready(axis_dma_tready), + .m_axis_ddr_tdata(axis_dma_tdata), + .m_axis_ddr_tkeep(axis_dma_tkeep), + .m_axis_ddr_tlast(axis_dma_tlast) + ); + + //=== Local Weight Buffer =============================================== + logic axis_lwb_tvalid; + logic axis_lwb_tready; + logic [WS_BITS_BA-1:0] axis_lwb_tdata; + + if(TH == 1) begin : genLwb + local_weight_buffer #( + .PE(PE), .SIMD(SIMD), .MH(MH), .MW(MW), + .N_REPS(N_REPS), .WEIGHT_WIDTH(WEIGHT_WIDTH), .DBG(DBG) + ) inst_weight_buff ( + .clk(aclk), .rst(~aresetn), + .ivld(axis_dwc_tvalid), .irdy(axis_dwc_tready), .idat(axis_dwc_tdata), + .ovld(axis_lwb_tvalid), .ordy(axis_lwb_tready), .odat(axis_lwb_tdata) + ); + end : genLwb + else begin : genLwbPassthru + assign axis_lwb_tvalid = axis_dwc_tvalid; + assign axis_dwc_tready = axis_lwb_tready; + assign axis_lwb_tdata = axis_dwc_tdata; + end : genLwbPassthru + + //=== Output Register Slice ============================================= + if(EN_OREG) begin : genOreg + skid #( + .DATA_WIDTH(WS_BITS_BA), .FEED_STAGES(N_DCPL_STGS) + ) inst_oreg ( + .clk(aclk), .rst(!aresetn), + .ivld(axis_lwb_tvalid), .irdy(axis_lwb_tready), .idat(axis_lwb_tdata), + .ovld(m_axis_tvalid), .ordy(m_axis_tready), .odat(m_axis_tdata) + ); + end : genOreg + else begin : genOregPassthru + assign m_axis_tvalid = axis_lwb_tvalid; + assign axis_lwb_tready = m_axis_tready; + assign m_axis_tdata = axis_lwb_tdata; + end : genOregPassthru + +endmodule : fetch_weights diff --git a/finn-rtllib/fetch_weights/fetch_weights_wrapper.v b/finn-rtllib/fetch_weights/fetch_weights_wrapper.v index 92827edc4e..cf79afb7c3 100644 --- a/finn-rtllib/fetch_weights/fetch_weights_wrapper.v +++ b/finn-rtllib/fetch_weights/fetch_weights_wrapper.v @@ -150,15 +150,11 @@ fetch_weights #( .WEIGHT_WIDTH(WEIGHT_WIDTH), .IWSIMD(IWSIMD), .OWSIMD(WSIMD), .ADDR_BITS(ADDR_BITS), .DATA_BITS(DATA_BITS), .LEN_BITS(LEN_BITS), .IDX_BITS(IDX_BITS), - .N_LAYERS(N_LAYERS), -`ifdef EN_MLO - .EN_MLO(1) -`else - .EN_MLO(0) -`endif + .N_LAYERS(N_LAYERS) ) inst ( .aclk (ap_clk), .aresetn (ap_rst_n), + .m_done (out_done), .m_axi_ddr_araddr (axi_mm_araddr), .m_axi_ddr_arburst (axi_mm_arburst), diff --git a/finn-rtllib/fetch_weights/local_weight_buffer.sv b/finn-rtllib/fetch_weights/local_weight_buffer.sv index aec4a8ab0d..722c32f93d 100644 --- a/finn-rtllib/fetch_weights/local_weight_buffer.sv +++ b/finn-rtllib/fetch_weights/local_weight_buffer.sv @@ -31,275 +31,248 @@ *****************************************************************************/ module local_weight_buffer #( - int unsigned PE, - int unsigned SIMD, - int unsigned WEIGHT_WIDTH = 8, - int unsigned MH, - int unsigned MW, - int unsigned N_REPS, - int unsigned DBG = 0 -) ( - input logic clk, - input logic rst, - - input logic ivld, - output logic irdy, - input logic [SIMD-1:0][WEIGHT_WIDTH-1:0] idat, - - output logic ovld, - input logic ordy, - output logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat + int unsigned PE, + int unsigned SIMD, + int unsigned WEIGHT_WIDTH = 8, + int unsigned MH, + int unsigned MW, + int unsigned N_REPS, + int unsigned DBG = 0 +)( + input logic clk, + input logic rst, + + input logic ivld, + output logic irdy, + input logic [SIMD-1:0][WEIGHT_WIDTH-1:0] idat, + + output logic ovld, + input logic ordy, + output logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat ); -// ---------------------------------------------------------------------------- -// Consts and types -// ---------------------------------------------------------------------------- - -localparam int unsigned SF = MW/SIMD; -localparam int unsigned NF = MH/PE; -localparam int unsigned N_TLS = SF * NF; - -localparam int unsigned SIMD_BITS = (SIMD == 1) ? 1 : $clog2(SIMD); -localparam int unsigned PE_BITS = (PE == 1) ? 1 : $clog2(PE); -localparam int unsigned WGT_ADDR_BITS = $clog2(NF * SF); -localparam int unsigned RAM_BITS = (SIMD*WEIGHT_WIDTH + 7)/8 * 8; -localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; -localparam int unsigned N_TLS_BITS = $clog2(N_TLS); -localparam int unsigned N_REPS_BITS = $clog2(N_REPS); - -typedef enum logic[1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_t; -typedef enum logic {ST_RD_0, ST_RD_1} state_rd_t; - -// ---------------------------------------------------------------------------- -// Writer -// ---------------------------------------------------------------------------- - -// -- Regs -state_wr_t state_wr_C = ST_WR_0, state_wr_N; -state_rd_t state_rd_C = ST_RD_0, state_rd_N; - -logic[N_TLS_BITS-1:0] wr_pntr_C = '0, wr_pntr_N; -logic[PE_BITS-1:0] curr_pe_C = '0, curr_pe_N; - -// -- Signals -logic [1:0][PE-1:0][WGT_EN_BITS-1:0] a_we; // Bank enables -logic [1:0][WGT_ADDR_BITS-1:0] a_addr; -logic [1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] a_data_in; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_WR - if(rst) begin - state_wr_C <= ST_WR_0; - - wr_pntr_C <= '0; - curr_pe_C <= '0; - end - else begin - state_wr_C <= state_wr_N; - - wr_pntr_C <= wr_pntr_N; - curr_pe_C <= curr_pe_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_WR - state_wr_N = state_wr_C; - - case (state_wr_C) - ST_WR_0: - if((curr_pe_C == PE - 1) && (wr_pntr_C == N_TLS - 1) && ivld) begin - state_wr_N = (state_rd_C == ST_RD_0) ? ST_WR_1 : ST_WR_0_WAIT; - end - - ST_WR_0_WAIT: - state_wr_N = (state_rd_C == ST_RD_0) ? ST_WR_1 : ST_WR_0_WAIT; - - ST_WR_1: - if((curr_pe_C == PE - 1) && (wr_pntr_C == N_TLS - 1) && ivld) begin - state_wr_N = (state_rd_C == ST_RD_1) ? ST_WR_0 : ST_WR_1_WAIT; - end - - ST_WR_1_WAIT: - state_wr_N = (state_rd_C == ST_RD_1) ? ST_WR_0 : ST_WR_1_WAIT; - - endcase -end - -// -- DP -always_comb begin : DP_PROC_WR - wr_pntr_N = wr_pntr_C; - curr_pe_N = curr_pe_C; - - // Input - irdy = 1'b0; - - // Buffers - a_we = '0; - for(int i = 0; i < 2; i++) begin - a_addr[i] = wr_pntr_C; - a_data_in[i] = idat; - end - - // Write and count - case (state_wr_C) - ST_WR_0, ST_WR_1: begin - irdy = 1'b1; - - if(ivld) begin - for(int i = 0; i < PE; i++) begin - if(curr_pe_C == i) begin - a_we[state_wr_C == ST_WR_1][i] = '1; - end - end - - curr_pe_N = (curr_pe_C == PE-1) ? 0 : curr_pe_C + 1; - wr_pntr_N = (curr_pe_C == PE-1) ? ((wr_pntr_C == N_TLS-1) ? 0 : wr_pntr_C + 1) : wr_pntr_C; - end - end - endcase - -end - -// ---------------------------------------------------------------------------- -// Reader -// ---------------------------------------------------------------------------- - -// -- Regs -logic [N_TLS_BITS-1:0] rd_pntr_C = '0, rd_pntr_N; -logic [N_REPS_BITS-1:0] reps_C = '0, reps_N; - -//logic [15:0] rd_pntr_C = '0, rd_pntr_N; -//logic [15:0] reps_C = '0, reps_N; - -logic [1:0] vld_s0_C = '0, vld_s0_N; -logic [1:0] vld_s1_C = '0, vld_s1_N; - -logic vld_C = '0, vld_N; -logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_C = '0, odat_N; - -// -- Signals -logic [1:0][WGT_ADDR_BITS-1:0] b_addr; -logic [1:0][PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_ram; - -// -- REG -always_ff @( posedge clk ) begin : REG_PROC_RD - if(rst) begin - state_rd_C <= ST_RD_0; - - rd_pntr_C <= '0; - reps_C <= '0; - - vld_s0_C <= '0; - vld_s1_C <= '0; - vld_C <= '0; - odat_C <= 'X; - end - else begin - state_rd_C <= state_rd_N; - - rd_pntr_C <= rd_pntr_N; - reps_C <= reps_N; - - vld_s0_C <= vld_s0_N; - vld_s1_C <= vld_s1_N; - vld_C <= vld_N; - odat_C <= odat_N; - end -end - -// -- NSL -always_comb begin : NSL_PROC_RD - state_rd_N = state_rd_C; - - case (state_rd_C) - ST_RD_0: - if(ordy && ((state_wr_C == ST_WR_0) ? (wr_pntr_C > rd_pntr_C) : 1'b1)) begin - if((rd_pntr_C == N_TLS-1) && (reps_C == N_REPS-1)) begin - state_rd_N = ST_RD_1; - end - end - - ST_RD_1: - if(ordy && ((state_wr_C == ST_WR_1) ? (wr_pntr_C > rd_pntr_C) : 1'b1)) begin - if((rd_pntr_C == N_TLS-1) && (reps_C == N_REPS-1)) begin - state_rd_N = ST_RD_0; - end - end - endcase -end - -// -- DP -always_comb begin : DP_PROC_RD - rd_pntr_N = rd_pntr_C; - reps_N = reps_C; - - for(int i = 0; i < 2; i++) begin - vld_s0_N[i] = ordy ? 1'b0 : vld_s0_C[i]; - vld_s1_N[i] = ordy ? vld_s0_C[i] : vld_s1_C[i]; - end - - vld_N = ordy ? |vld_s1_C : vld_C; - odat_N = ordy ? (vld_s1_C[0] ? odat_ram[0] : odat_ram[1]) : odat_C; - - for(int i = 0; i < 2; i++) begin - b_addr[i] = rd_pntr_C; - end - - case(state_rd_C) - ST_RD_0: begin - if(ordy) begin - if((state_wr_C == ST_WR_0) ? (wr_pntr_C > rd_pntr_C) : 1'b1) begin - - vld_s0_N[0] = 1'b1; - - rd_pntr_N = (rd_pntr_C == N_TLS-1) ? 0 : rd_pntr_C + 1; - reps_N = (rd_pntr_C == N_TLS-1) ? ((reps_C == N_REPS-1) ? 0 : reps_C + 1) : reps_C; - end - end - end - - ST_RD_1: begin - if(ordy) begin - if((state_wr_C == ST_WR_1) ? (wr_pntr_C > rd_pntr_C) : 1'b1) begin - - vld_s0_N[1] = 1'b1; - - rd_pntr_N = (rd_pntr_C == N_TLS-1) ? 0 : rd_pntr_C + 1; - reps_N = (rd_pntr_C == N_TLS-1) ? ((reps_C == N_REPS-1) ? 0 : reps_C + 1) : reps_C; - end - end - end - - endcase - -end - -assign ovld = vld_C; -assign odat = odat_C; - -// ---------------------------------------------------------------------------- -// Weights -// ---------------------------------------------------------------------------- - -for(genvar i = 0; i < 2; i++) begin - for(genvar j = 0; j < PE; j++) begin - ram_p_c #( - .ADDR_BITS(WGT_ADDR_BITS), - .DATA_BITS(RAM_BITS), - .RAM_STYLE("block") - ) inst_ram_tp_c ( - .clk(clk), - .a_en(1'b1), - .a_we(a_we[i][j]), - .a_addr(a_addr[i]), - .b_en(ordy), - .b_addr(b_addr[i]), - .a_data_in(a_data_in[i]), - .a_data_out(), - .b_data_out(odat_ram[i][j]) - ); - end -end - -endmodule + //=== Constants and Types =============================================== + localparam int unsigned SF = MW / SIMD; + localparam int unsigned NF = MH / PE; + localparam int unsigned N_TLS = SF * NF; + localparam int unsigned PE_BITS = (PE == 1)? 1 : $clog2(PE); + localparam int unsigned WGT_ADDR_BITS = $clog2(NF * SF); + localparam int unsigned RAM_BITS = (SIMD*WEIGHT_WIDTH + 7)/8 * 8; + localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; + localparam int unsigned N_TLS_BITS = $clog2(N_TLS); + localparam int unsigned N_REPS_BITS = $clog2(N_REPS); + + typedef enum logic [1:0] {ST_WR_0, ST_WR_0_WAIT, ST_WR_1, ST_WR_1_WAIT} state_wr_e; + typedef enum logic {ST_RD_0, ST_RD_1} state_rd_e; + + //=== Writer ============================================================ + + //--- Registers --------------------------------------------------------- + state_wr_e StateWr = ST_WR_0; + state_wr_e state_wr_n; + state_rd_e StateRd = ST_RD_0; + state_rd_e state_rd_n; + + logic [N_TLS_BITS-1:0] WrPntr = '0; + logic [N_TLS_BITS-1:0] wr_pntr_n; + + logic [PE_BITS-1:0] CurrPe = '0; + logic [PE_BITS-1:0] curr_pe_n; + + //--- Signals ----------------------------------------------------------- + logic [1:0][PE-1:0][WGT_EN_BITS-1:0] a_we; + logic [1:0][WGT_ADDR_BITS-1:0] a_addr; + logic [1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] a_data_in; + + //--- Sequential -------------------------------------------------------- + always_ff @(posedge clk) begin + if(rst) begin + StateWr <= ST_WR_0; + WrPntr <= '0; + CurrPe <= '0; + end + else begin + StateWr <= state_wr_n; + WrPntr <= wr_pntr_n; + CurrPe <= curr_pe_n; + end + end + + //--- Next State -------------------------------------------------------- + always_comb begin + state_wr_n = StateWr; + + case(StateWr) + ST_WR_0: + if((CurrPe == PE-1) && (WrPntr == N_TLS-1) && ivld) + state_wr_n = (StateRd == ST_RD_0)? ST_WR_1 : ST_WR_0_WAIT; + + ST_WR_0_WAIT: + state_wr_n = (StateRd == ST_RD_0)? ST_WR_1 : ST_WR_0_WAIT; + + ST_WR_1: + if((CurrPe == PE-1) && (WrPntr == N_TLS-1) && ivld) + state_wr_n = (StateRd == ST_RD_1)? ST_WR_0 : ST_WR_1_WAIT; + + ST_WR_1_WAIT: + state_wr_n = (StateRd == ST_RD_1)? ST_WR_0 : ST_WR_1_WAIT; + endcase + end + + //--- Datapath ---------------------------------------------------------- + always_comb begin + wr_pntr_n = WrPntr; + curr_pe_n = CurrPe; + + irdy = 0; + + a_we = '0; + for(int i = 0; i < 2; i++) begin + a_addr[i] = WrPntr; + a_data_in[i] = idat; + end + + case(StateWr) + ST_WR_0, ST_WR_1: begin + irdy = 1; + + if(ivld) begin + for(int i = 0; i < PE; i++) + if(CurrPe == i) + a_we[StateWr == ST_WR_1][i] = '1; + + curr_pe_n = (CurrPe == PE-1)? 0 : CurrPe + 1; + wr_pntr_n = (CurrPe == PE-1)? ((WrPntr == N_TLS-1)? 0 : WrPntr + 1) : WrPntr; + end + end + endcase + end + + //=== Reader ============================================================ + + //--- Registers --------------------------------------------------------- + logic [N_TLS_BITS-1:0] RdPntr = '0; + logic [N_TLS_BITS-1:0] rd_pntr_n; + + logic [N_REPS_BITS-1:0] Reps = '0; + logic [N_REPS_BITS-1:0] reps_n; + + logic [1:0] VldS0 = '0; + logic [1:0] vld_s0_n; + + logic [1:0] VldS1 = '0; + logic [1:0] vld_s1_n; + + logic Vld = 0; + logic vld_n; + + logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] Odat = '0; + logic [PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_n; + + //--- Signals ----------------------------------------------------------- + logic [1:0][WGT_ADDR_BITS-1:0] b_addr; + logic [1:0][PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] odat_ram; + + //--- Sequential -------------------------------------------------------- + always_ff @(posedge clk) begin + if(rst) begin + StateRd <= ST_RD_0; + RdPntr <= '0; + Reps <= '0; + VldS0 <= '0; + VldS1 <= '0; + Vld <= 0; + Odat <= 'x; + end + else begin + StateRd <= state_rd_n; + RdPntr <= rd_pntr_n; + Reps <= reps_n; + VldS0 <= vld_s0_n; + VldS1 <= vld_s1_n; + Vld <= vld_n; + Odat <= odat_n; + end + end + + //--- Next State -------------------------------------------------------- + always_comb begin + state_rd_n = StateRd; + + case(StateRd) + ST_RD_0: + if(ordy && ((StateWr == ST_WR_0)? (WrPntr > RdPntr) : 1)) + if((RdPntr == N_TLS-1) && (Reps == N_REPS-1)) + state_rd_n = ST_RD_1; + + ST_RD_1: + if(ordy && ((StateWr == ST_WR_1)? (WrPntr > RdPntr) : 1)) + if((RdPntr == N_TLS-1) && (Reps == N_REPS-1)) + state_rd_n = ST_RD_0; + endcase + end + + //--- Datapath ---------------------------------------------------------- + always_comb begin + rd_pntr_n = RdPntr; + reps_n = Reps; + + for(int i = 0; i < 2; i++) begin + vld_s0_n[i] = ordy? 0 : VldS0[i]; + vld_s1_n[i] = ordy? VldS0[i] : VldS1[i]; + end + + vld_n = ordy? |VldS1 : Vld; + odat_n = ordy? (VldS1[0]? odat_ram[0] : odat_ram[1]) : Odat; + + for(int i = 0; i < 2; i++) + b_addr[i] = RdPntr; + + case(StateRd) + ST_RD_0: begin + if(ordy) begin + if((StateWr == ST_WR_0)? (WrPntr > RdPntr) : 1) begin + vld_s0_n[0] = 1; + rd_pntr_n = (RdPntr == N_TLS-1)? 0 : RdPntr + 1; + reps_n = (RdPntr == N_TLS-1)? ((Reps == N_REPS-1)? 0 : Reps + 1) : Reps; + end + end + end + + ST_RD_1: begin + if(ordy) begin + if((StateWr == ST_WR_1)? (WrPntr > RdPntr) : 1) begin + vld_s0_n[1] = 1; + rd_pntr_n = (RdPntr == N_TLS-1)? 0 : RdPntr + 1; + reps_n = (RdPntr == N_TLS-1)? ((Reps == N_REPS-1)? 0 : Reps + 1) : Reps; + end + end + end + endcase + end + + assign ovld = Vld; + assign odat = Odat; + + //=== Weight RAMs ======================================================= + for(genvar i = 0; i < 2; i++) begin : genBank + for(genvar j = 0; j < PE; j++) begin : genPe + ram_p_c #( + .ADDR_BITS(WGT_ADDR_BITS), + .DATA_BITS(RAM_BITS), + .RAM_STYLE("block") + ) inst_ram_tp_c ( + .clk(clk), + .a_en(1), + .a_we(a_we[i][j]), + .a_addr(a_addr[i]), + .b_en(ordy), + .b_addr(b_addr[i]), + .a_data_in(a_data_in[i]), + .a_data_out(), + .b_data_out(odat_ram[i][j]) + ); + end : genPe + end : genBank + +endmodule : local_weight_buffer diff --git a/finn-rtllib/mvu_tiled/acc_stage.sv b/finn-rtllib/mvu_tiled/acc_stage.sv index 108ec9155e..7ab3704492 100644 --- a/finn-rtllib/mvu_tiled/acc_stage.sv +++ b/finn-rtllib/mvu_tiled/acc_stage.sv @@ -65,11 +65,7 @@ module acc_stage #( localparam int unsigned SUM_WIDTH = $clog2(CHAINLEN) + ACCU_WIDTH; uwire [SUM_WIDTH-1:0] tree_sum; - add_multi #( - .N(CHAINLEN), - .DEPTH(TREE_DEPTH), - .ARG_WIDTH(ACCU_WIDTH) - ) inst_add ( + add_multi #(.N(CHAINLEN), .DEPTH(TREE_DEPTH), .ARG_WIDTH(ACCU_WIDTH)) inst_add ( .clk(clk), .rst(rst), .en(en), .arg(add_arg), .sum(tree_sum) diff --git a/finn-rtllib/mvu_tiled/cu_mvau_tiled.sv b/finn-rtllib/mvu_tiled/cu_mvau_tiled.sv index 02412ad52f..c42cd63643 100644 --- a/finn-rtllib/mvu_tiled/cu_mvau_tiled.sv +++ b/finn-rtllib/mvu_tiled/cu_mvau_tiled.sv @@ -270,12 +270,7 @@ module cu_mvau_tiled #( end : genPE //=== Accumulation ====================================================== - acc_stage #( - .CHAINLEN(CHAINLEN), - .PE(PE), - .ACCU_WIDTH(ACCU_WIDTH), - .TH(TH) - ) inst_acc_stage ( + acc_stage #(.CHAINLEN(CHAINLEN), .PE(PE), .ACCU_WIDTH(ACCU_WIDTH), .TH(TH)) inst_acc_stage ( .clk(clk), .rst(rst), .en(en), diff --git a/src/finn/custom_op/fpgadataflow/hwcustomop.py b/src/finn/custom_op/fpgadataflow/hwcustomop.py index 4a1d9ce1fa..1db619cea2 100644 --- a/src/finn/custom_op/fpgadataflow/hwcustomop.py +++ b/src/finn/custom_op/fpgadataflow/hwcustomop.py @@ -376,7 +376,6 @@ def generate_hdl_fetch_weights(self): n_reps = np.prod(self.get_nodeattr("rhs_shape")[:-1]) theight = 1 en_mlo = "EN_MLO" if self.get_nodeattr("mlo_max_iter") else "NO_MLO" - layer_offs = mw * mh # upper bound on how many layers can be supported, set to 64 for now n_max_layers = 64 code_gen_dir = self.get_nodeattr("code_gen_dir_ipgen") @@ -401,7 +400,6 @@ def generate_hdl_fetch_weights(self): "$SIMD$": [str(simd)], "$N_REPS$": [str(n_reps)], "$WEIGHT_WIDTH$": [str(wdt.bitwidth())], - "$LAYER_OFFS$": [str(layer_offs)], "$N_LAYERS$": [str(n_max_layers)], "$TH$": [str(theight)], "$IWSIMD$": [str(iwsimd)], From ad2a3dfdbce642e9966bf7cad26ec6443f302576 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20B=2E=20Preu=C3=9Fer?= Date: Tue, 2 Jun 2026 09:16:20 +0100 Subject: [PATCH 12/17] Reviewed load buffers. --- finn-rtllib/dynload/hdl/dynamic_load.sv | 61 +++++----------- .../fetch_weights/local_weight_buffer.sv | 35 ++++----- finn-rtllib/ram/ram_p_c.sv | 73 ------------------- 3 files changed, 35 insertions(+), 134 deletions(-) delete mode 100644 finn-rtllib/ram/ram_p_c.sv diff --git a/finn-rtllib/dynload/hdl/dynamic_load.sv b/finn-rtllib/dynload/hdl/dynamic_load.sv index 1b41b310a1..1918f4e7db 100644 --- a/finn-rtllib/dynload/hdl/dynamic_load.sv +++ b/finn-rtllib/dynload/hdl/dynamic_load.sv @@ -36,7 +36,8 @@ module dynamic_load #( int unsigned WEIGHT_WIDTH, int unsigned MH, int unsigned MW, - int unsigned N_REPS + int unsigned N_REPS, + parameter RAM_STYLE = "distributed" )( input logic ap_clk, input logic ap_rst_n, @@ -60,8 +61,6 @@ localparam int unsigned N_TLS = SF*NF; localparam int unsigned SIMD_BITS = (SIMD == 1) ? 1 : $clog2(SIMD); localparam int unsigned WGT_ADDR_BITS = (N_TLS == 1) ? 1 : $clog2(N_TLS); -localparam int unsigned RAM_BITS = (WEIGHT_WIDTH + 7)/8 * 8; -localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; localparam int unsigned NF_BITS = (NF == 1) ? 1 : $clog2(NF); localparam int unsigned SF_BITS = (SF == 1) ? 1 : $clog2(SF); localparam int unsigned N_TLS_BITS = (N_TLS == 1) ? 1 : $clog2(N_TLS); @@ -85,9 +84,8 @@ logic[N_TLS_BITS-1:0] curr_sf_C = '0, curr_sf_N; logic[SIMD_BITS-1:0] curr_simd_C = '0, curr_simd_N; // -- Signals -logic [1:0][PE-1:0][SIMD-1:0][WGT_EN_BITS-1:0] a_we; // Bank enables +logic [1:0][SIMD-1:0] a_we; logic [1:0][WGT_ADDR_BITS-1:0] a_addr; -logic [1:0][PE-1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] a_data_in; // -- Offsets for(genvar i = 0; i < NF; i++) begin @@ -147,12 +145,8 @@ always_comb begin : DP_PROC_WR // Buffers a_we = '0; - for(int i = 0; i < 2; i++) begin + for(int i = 0; i < 2; i++) a_addr[i] = offsets[curr_nf_C] + curr_sf_C; - for(int j = 0; j < PE; j++) - for(int k = 0; k < SIMD; k++) - a_data_in[i][j][k] = idat[j]; - end // Write and count case (state_wr_C) @@ -160,16 +154,7 @@ always_comb begin : DP_PROC_WR irdy = 1'b1; if(ivld) begin - for(int i = 0; i < PE; i++) begin - for(int j = 0; j < SIMD; j++) begin - if(curr_simd_C == j) begin - if(state_wr_C == ST_WR_0) - a_we[0][i][j] = '1; - else - a_we[1][i][j] = '1; - end - end - end + a_we[state_wr_C == ST_WR_1][curr_simd_C] = 1; curr_nf_N = (curr_nf_C == NF-1) ? 0 : curr_nf_C + 1; curr_simd_N = (curr_nf_C == NF-1) ? ((curr_simd_C == SIMD-1) ? 0 : curr_simd_C + 1) : curr_simd_C; @@ -295,29 +280,23 @@ assign ovld = vld_C; assign odat = odat_C; // ---------------------------------------------------------------------------- -// Matrix +// Weight RAMs // ---------------------------------------------------------------------------- -for(genvar i = 0; i < 2; i++) begin - for(genvar j = 0; j < PE; j++) begin - for(genvar k = 0; k < SIMD; k++) begin - ram_p_c #( - .ADDR_BITS(WGT_ADDR_BITS), - .DATA_BITS(RAM_BITS), - .RAM_STYLE("distributed") - ) inst_ram_tp_c ( - .clk(ap_clk), - .a_en(1'b1), - .a_we(a_we[i][j][k]), - .a_addr(a_addr[i]), - .b_en(ordy), - .b_addr(b_addr[i]), - .a_data_in(a_data_in[i][j][k]), - .a_data_out(), - .b_data_out(odat_ram[i][j][k]) - ); +for(genvar i = 0; i < 2; i++) begin : genBank + for(genvar k = 0; k < SIMD; k++) begin : genSimd + (* RAM_STYLE = RAM_STYLE *) + logic [PE-1:0][WEIGHT_WIDTH-1:0] Ram[2**WGT_ADDR_BITS]; + logic [PE-1:0][WEIGHT_WIDTH-1:0] RdReg; + + always_ff @(posedge ap_clk) begin + if(a_we[i][k]) Ram[a_addr[i]] <= idat; + if(ordy) begin + RdReg <= Ram[b_addr[i]]; + foreach(RdReg[p]) odat_ram[i][p][k] <= RdReg[p]; + end end - end -end + end : genSimd +end : genBank endmodule : dynamic_load diff --git a/finn-rtllib/fetch_weights/local_weight_buffer.sv b/finn-rtllib/fetch_weights/local_weight_buffer.sv index 722c32f93d..71dbc14024 100644 --- a/finn-rtllib/fetch_weights/local_weight_buffer.sv +++ b/finn-rtllib/fetch_weights/local_weight_buffer.sv @@ -37,7 +37,8 @@ module local_weight_buffer #( int unsigned MH, int unsigned MW, int unsigned N_REPS, - int unsigned DBG = 0 + int unsigned DBG = 0, + parameter RAM_STYLE = "block" )( input logic clk, input logic rst, @@ -57,8 +58,6 @@ module local_weight_buffer #( localparam int unsigned N_TLS = SF * NF; localparam int unsigned PE_BITS = (PE == 1)? 1 : $clog2(PE); localparam int unsigned WGT_ADDR_BITS = $clog2(NF * SF); - localparam int unsigned RAM_BITS = (SIMD*WEIGHT_WIDTH + 7)/8 * 8; - localparam int unsigned WGT_EN_BITS = RAM_BITS / 8; localparam int unsigned N_TLS_BITS = $clog2(N_TLS); localparam int unsigned N_REPS_BITS = $clog2(N_REPS); @@ -80,7 +79,7 @@ module local_weight_buffer #( logic [PE_BITS-1:0] curr_pe_n; //--- Signals ----------------------------------------------------------- - logic [1:0][PE-1:0][WGT_EN_BITS-1:0] a_we; + logic [1:0][PE-1:0] a_we; logic [1:0][WGT_ADDR_BITS-1:0] a_addr; logic [1:0][SIMD-1:0][WEIGHT_WIDTH-1:0] a_data_in; @@ -139,7 +138,7 @@ module local_weight_buffer #( if(ivld) begin for(int i = 0; i < PE; i++) if(CurrPe == i) - a_we[StateWr == ST_WR_1][i] = '1; + a_we[StateWr == ST_WR_1][i] = 1; curr_pe_n = (CurrPe == PE-1)? 0 : CurrPe + 1; wr_pntr_n = (CurrPe == PE-1)? ((WrPntr == N_TLS-1)? 0 : WrPntr + 1) : WrPntr; @@ -257,21 +256,17 @@ module local_weight_buffer #( //=== Weight RAMs ======================================================= for(genvar i = 0; i < 2; i++) begin : genBank for(genvar j = 0; j < PE; j++) begin : genPe - ram_p_c #( - .ADDR_BITS(WGT_ADDR_BITS), - .DATA_BITS(RAM_BITS), - .RAM_STYLE("block") - ) inst_ram_tp_c ( - .clk(clk), - .a_en(1), - .a_we(a_we[i][j]), - .a_addr(a_addr[i]), - .b_en(ordy), - .b_addr(b_addr[i]), - .a_data_in(a_data_in[i]), - .a_data_out(), - .b_data_out(odat_ram[i][j]) - ); + (* RAM_STYLE = RAM_STYLE *) + logic [SIMD-1:0][WEIGHT_WIDTH-1:0] Ram[2**WGT_ADDR_BITS]; + logic [SIMD-1:0][WEIGHT_WIDTH-1:0] RdReg; + + always_ff @(posedge clk) begin + if(a_we[i][j]) Ram[a_addr[i]] <= a_data_in[i]; + if(ordy) begin + RdReg <= Ram[b_addr[i]]; + odat_ram[i][j] <= RdReg; + end + end end : genPe end : genBank diff --git a/finn-rtllib/ram/ram_p_c.sv b/finn-rtllib/ram/ram_p_c.sv deleted file mode 100644 index db47fafe18..0000000000 --- a/finn-rtllib/ram/ram_p_c.sv +++ /dev/null @@ -1,73 +0,0 @@ -/****************************************************************************** - * Copyright (C) 2024, Advanced Micro Devices, Inc. - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, - * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR - * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR - * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, - * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, - * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR - * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF - * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - *****************************************************************************/ - -module ram_p_c #( - int unsigned ADDR_BITS, - int unsigned DATA_BITS, - parameter RAM_STYLE = "block" -) ( - input logic clk, - input logic a_en, - input logic [(DATA_BITS/8)-1:0] a_we, - input logic [ADDR_BITS-1:0] a_addr, - input logic b_en, - input logic [ADDR_BITS-1:0] b_addr, - input logic [DATA_BITS-1:0] a_data_in, - output logic [DATA_BITS-1:0] a_data_out, - output logic [DATA_BITS-1:0] b_data_out -); - - localparam int unsigned DEPTH = 2**ADDR_BITS; - - (* ram_style = RAM_STYLE *) logic [DATA_BITS-1:0] ram[DEPTH]; - - logic [DATA_BITS-1:0] a_data_reg = 0; - logic [DATA_BITS-1:0] b_data_reg = 0; - - always_ff @(posedge clk) begin - if(a_en) begin - for (int i = 0; i < (DATA_BITS/8); i++) begin - if(a_we[i]) begin - ram[a_addr][(i*8)+:8] <= a_data_in[(i*8)+:8]; - end - end - a_data_reg <= ram[a_addr]; - a_data_out <= a_data_reg; - end - if(b_en) begin - b_data_reg <= ram[b_addr]; - b_data_out <= b_data_reg; - end - //end - end - -endmodule : ram_p_c From c11d51815bd8d8e6234f8e16af57347c32b1044b Mon Sep 17 00:00:00 2001 From: auphelia Date: Fri, 5 Jun 2026 08:31:43 +0100 Subject: [PATCH 13/17] [CustomOp] Remove other occurrences of obsolete rtl file --- src/finn/custom_op/fpgadataflow/matrixvectoractivation.py | 1 - src/finn/custom_op/fpgadataflow/templates.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py index 444f2420dd..d197992077 100644 --- a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py +++ b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py @@ -1014,7 +1014,6 @@ def code_generation_ipi(self): strm_tmpl_name = strm_tmpl[:-2] sourcefiles = [ os.path.join(code_gen_dir, strm_tmpl), - ram_rtllib_dir + "ram_p_c.sv", dyn_rtllib_dir + "dynamic_load.sv", ] for f in sourcefiles: diff --git a/src/finn/custom_op/fpgadataflow/templates.py b/src/finn/custom_op/fpgadataflow/templates.py index 18b3d708a1..5ec9159c2b 100644 --- a/src/finn/custom_op/fpgadataflow/templates.py +++ b/src/finn/custom_op/fpgadataflow/templates.py @@ -344,7 +344,6 @@ add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/cdma/cdma_x/cdma_x_rd.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/cdma/cdma_x/cdma_x_wr.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/skid/skid.sv" -add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/ram/ram_p_c.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/mlo/infrastructure/intermediate_frames.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/mlo/infrastructure/mux.sv" add_files -norecurse "$::env(FINN_ROOT)/finn-rtllib/mlo/infrastructure/demux.sv" From aea144e0b74e171ad1537d3dec44e7ec142d6bbc Mon Sep 17 00:00:00 2001 From: dkorolij Date: Tue, 9 Jun 2026 11:44:43 +0100 Subject: [PATCH 14/17] small fix, remove ram_p_c. --- src/finn/custom_op/fpgadataflow/matrixvectoractivation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py index d197992077..0d25732016 100644 --- a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py +++ b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py @@ -1052,7 +1052,6 @@ def code_generation_ipi(self): strm_tmpl_name = strm_tmpl[:-2] sourcefiles = [ os.path.join(code_gen_dir, strm_tmpl), - ram_rtllib_dir + "ram_p_c.sv", reg_rtllib_dir + "skid.sv", que_rtllib_dir + "Q_srl.v", fwg_rtllib_dir + "fetch_weights.sv", From 7a1007f4f7cb75e72a509eaa4d45728abc8c3039 Mon Sep 17 00:00:00 2001 From: Shane Fleming Date: Wed, 17 Jun 2026 09:02:02 +0100 Subject: [PATCH 15/17] [Tiled MVU] Adding an assertion to guard against configurations that caused issues in the bert build flow --- .../fpgadataflow/rtl/matrixvectoractivation_rtl.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py index 69c8c50126..fb30d0861e 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py @@ -343,6 +343,20 @@ def prepare_codegen_default(self, fpgapart, clk): "Clock pumping an input of SIMD=1 is not meaningful. Please increase SIMD." ) + # Check to make sure that tile size divides the number of input vectors evenly + # otherwise we hit issues with output being silently dropped resulting in an + # eventual stall. + theight = self.get_nodeattr("TH") + if theight > 1: + num_inp_vec = int(np.prod(self.get_nodeattr("numInputVectors"))) + if num_inp_vec % theight != 0: + valid_th = [t for t in range(1, num_inp_vec + 1) if num_inp_vec % t == 0] + raise Exception( + "%s: TH=%d does not divide numInputVectors=%d. The tiled MVU " + "requires TH | numInputVectors; choose a divisor (valid: %s)." + % (self.onnx_node.name, theight, num_inp_vec, valid_th) + ) + dsp_block = get_dsp_block(fpgapart) code_gen_dict = {} code_gen_dict["$IS_MVU$"] = [str(1)] From 826ff8cdd90172d9e96878ceeb2a0edd61279ee3 Mon Sep 17 00:00:00 2001 From: Shane Fleming Date: Thu, 18 Jun 2026 10:19:35 +0100 Subject: [PATCH 16/17] [Tiled MVAU] cleaning up code comment (copilot suggestion) Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .../custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py index fb30d0861e..047c653ec2 100644 --- a/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py +++ b/src/finn/custom_op/fpgadataflow/rtl/matrixvectoractivation_rtl.py @@ -343,9 +343,8 @@ def prepare_codegen_default(self, fpgapart, clk): "Clock pumping an input of SIMD=1 is not meaningful. Please increase SIMD." ) - # Check to make sure that tile size divides the number of input vectors evenly - # otherwise we hit issues with output being silently dropped resulting in an - # eventual stall. + # Check to make sure that tile size divides the number of input vectors evenly; + # otherwise the final partial tile can cause output to be dropped and eventually stall. theight = self.get_nodeattr("TH") if theight > 1: num_inp_vec = int(np.prod(self.get_nodeattr("numInputVectors"))) From 4bdb7f830058750d6c9576ec288abf76a854c819 Mon Sep 17 00:00:00 2001 From: auphelia Date: Mon, 22 Jun 2026 11:20:19 +0100 Subject: [PATCH 17/17] Linting --- src/finn/custom_op/fpgadataflow/matrixvectoractivation.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py index 0d25732016..370b6580a3 100644 --- a/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py +++ b/src/finn/custom_op/fpgadataflow/matrixvectoractivation.py @@ -1002,7 +1002,6 @@ def code_generation_ipi(self): ) # dynamic loader - ram_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/ram/") dyn_rtllib_dir = os.path.join( os.environ["FINN_ROOT"], "finn-rtllib/dynload/hdl/" ) @@ -1037,7 +1036,6 @@ def code_generation_ipi(self): ) # instantiate a fetch weights component and connect it to the IP - ram_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/ram/") reg_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/skid/") que_rtllib_dir = os.path.join(os.environ["FINN_ROOT"], "finn-rtllib/fifo/hdl/") fwg_rtllib_dir = os.path.join(