`timescale 1ns/1ps
`default_nettype none

// freq_filter: frequency-domain filter.
//   Each complex input sample (one FFT bin per din_valid) is multiplied by a
//   real, unsigned fixed-point coefficient read from freq_filter_coef_rom.
//   The ROM index is mirrored around FFT_N/2 (bin k > FFT_N/2 uses index
//   FFT_N - k), so a single half-length table serves both halves of the
//   spectrum.
//
//   Pipeline (latency din -> dout = 3 clock cycles):
//     Stage 0: ROM read + input register
//     Stage 1: coefficient multiply
//     Stage 2: fixed-point rescale (arithmetic right shift) + saturation
//
//   The bin counter only advances on din_valid and is only cleared by nrst;
//   there is no frame-start input, so this assumes samples arrive as
//   contiguous frames of FFT_N bins in order 0..FFT_N-1 after reset.
module freq_filter #(
    parameter DATA_BITS = 12,       // Bit width of input/output data
    parameter COEF_TOTAL_BITS = 8,  // Total coefficient bit width (integer part + fractional part)
    parameter COEF_FRAC_BITS = 4,   // Coefficient fractional bit width
    parameter FFT_N = 2048          // Number of FFT points
) (
    input wire clk,      // Clock input
    input wire nrst,     // Reset input (active low, synchronous to clk)

    // Input samples
    input wire din_valid,
    input wire signed [DATA_BITS-1:0] din_re_data,
    input wire signed [DATA_BITS-1:0] din_im_data,

    // Output samples
    output logic dout_valid,
    output logic signed [DATA_BITS-1:0] dout_re_data,
    output logic signed [DATA_BITS-1:0] dout_im_data
);

    //===== Stage 0: coefficient ROM read ==========
    logic [$clog2(FFT_N)-1:0] fft_idx;         // Index of the current FFT bin
    logic [$clog2(FFT_N)-1:0] coef_rom_idx;    // Read index into the coefficient ROM
    logic [COEF_TOTAL_BITS-1:0] coef_st0;      // Coefficient value (unsigned fixed-point)

    // Bin counter: advances once per accepted input sample, wraps after FFT_N-1.
    always_ff @(posedge clk) begin
        if (~nrst) begin
            fft_idx <= 'b0;
        end else begin
            if (din_valid) begin
                if (fft_idx == (FFT_N - 1))
                    fft_idx <= 'b0;
                else
                    fft_idx <= fft_idx + 'b1;
            end
        end
    end

    // Mirror the bin index around FFT_N/2 so only half of the table is stored.
    // NOTE(review): this yields index FFT_N/2 when fft_idx == FFT_N/2, while the
    // ROM below is parameterised with N = FFT_N/2 -- confirm that
    // freq_filter_coef_rom holds N+1 entries (indices 0..N).
    always_comb begin
        coef_rom_idx = (fft_idx <= FFT_N/2) ? fft_idx : (FFT_N - fft_idx);
    end

    // Coefficient ROM
    //   - "coef" is output one clock cycle after "index" is presented, which
    //     lines it up with the stage-0 input registers below.
    freq_filter_coef_rom # (
        .COEF_TOTAL_BITS(COEF_TOTAL_BITS),
        .N(FFT_N/2)
    ) coef_rom (
        .clk(clk),
        .index(coef_rom_idx),
        .coef(coef_st0)
    );

    // Input pipeline registers (delay the sample by one cycle to match the ROM latency)
    logic din_valid_st0;
    logic signed [DATA_BITS-1:0] din_re_data_st0;
    logic signed [DATA_BITS-1:0] din_im_data_st0;

    always_ff @(posedge clk) begin
        if (~nrst) begin
            din_valid_st0 <= 'b0;
            din_re_data_st0 <= 'b0;
            din_im_data_st0 <= 'b0;
        end else begin
            din_valid_st0 <= din_valid;
            din_re_data_st0 <= din_re_data;
            din_im_data_st0 <= din_im_data;
        end
    end

    //==== Stage 1: coefficient multiply =================
    // The coefficient is unsigned; prepending a 0 bit before $signed() makes the
    // multiply a signed x non-negative product. Its full range fits in
    // DATA_BITS + COEF_TOTAL_BITS bits, so the result width is exact.
    logic mod_data_valid_st1;
    logic signed [COEF_TOTAL_BITS+DATA_BITS-1:0] mod_re_data_st1;
    logic signed [COEF_TOTAL_BITS+DATA_BITS-1:0] mod_im_data_st1;

    always_ff @(posedge clk) begin
        if (~nrst) begin
            mod_data_valid_st1 <= 'b0;
            mod_re_data_st1 <= 'b0;
            mod_im_data_st1 <= 'b0;
        end else begin
            mod_data_valid_st1 <= din_valid_st0;
            mod_re_data_st1 <= din_re_data_st0 * $signed({1'b0, coef_st0});
            mod_im_data_st1 <= din_im_data_st0 * $signed({1'b0, coef_st0});
        end
    end

    //==== Stage 2: rescale + saturation ========
    // Width of the product after dropping the COEF_FRAC_BITS fractional bits.
    localparam int SCALED_BITS = COEF_TOTAL_BITS - COEF_FRAC_BITS + DATA_BITS;

    logic dout_valid_reg;
    logic signed [DATA_BITS-1:0] dout_re_data_reg;
    logic signed [DATA_BITS-1:0] dout_im_data_reg;

    // Representable range of a signed DATA_BITS-wide output.
    localparam signed OUT_DATA_RANGE_MIN = -(1<<(DATA_BITS-1));
    localparam signed OUT_DATA_RANGE_MAX = (1<<(DATA_BITS-1)) - 1;

    // Clamp a rescaled value into [OUT_DATA_RANGE_MIN, OUT_DATA_RANGE_MAX].
    // The result keeps SCALED_BITS width; after clamping it fits in DATA_BITS bits.
    function automatic signed [SCALED_BITS-1:0] saturation;
        input signed [SCALED_BITS-1:0] data;
        begin
            if (data < OUT_DATA_RANGE_MIN)
                saturation = OUT_DATA_RANGE_MIN;
            else if (data > OUT_DATA_RANGE_MAX)
                saturation = OUT_DATA_RANGE_MAX;
            else
                saturation = data;
        end
    endfunction

    always_ff @(posedge clk) begin
        logic signed [SCALED_BITS-1:0] re_data_scaled, im_data_scaled;

        if (~nrst) begin
            // Reset the data registers too, matching stages 0 and 1, so the
            // outputs are defined (not X) while reset is asserted.
            dout_valid_reg <= 'b0;
            dout_re_data_reg <= 'b0;
            dout_im_data_reg <= 'b0;
        end else begin
            // Rescale: drop the coefficient's fractional bits. The operand is
            // signed, so >>> is an arithmetic shift (rounds toward -infinity).
            re_data_scaled = mod_re_data_st1 >>> COEF_FRAC_BITS;
            im_data_scaled = mod_im_data_st1 >>> COEF_FRAC_BITS;

            // Saturate to the output range, then narrow explicitly; the
            // truncation is lossless because the value is already clamped.
            dout_valid_reg <= mod_data_valid_st1;
            dout_re_data_reg <= DATA_BITS'(saturation(re_data_scaled));
            dout_im_data_reg <= DATA_BITS'(saturation(im_data_scaled));
        end
    end

    //==== Outputs ===============================
    // Pure wiring from the stage-2 registers; continuous assignment instead of
    // nonblocking assignments inside always_comb.
    assign dout_valid   = dout_valid_reg;
    assign dout_re_data = dout_re_data_reg;
    assign dout_im_data = dout_im_data_reg;

endmodule

`default_nettype wire