/*******************************************************************************
* Copyright (C) 2023 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

#pragma once

#include "../Helpers.hpp"
#include "../EsimdHelpers.hpp"
#include "esb_unrolls.hpp"

// Kernel-name tags for the SpTRSV launches below.
// These class templates are only forward-declared here; each one is used as
// the kernel-name template argument of cgh.parallel_for and is parameterized
// on the same <block_size, isFused> pair as the launching function template.

// Name of the forward-solve kernel (see sparse_esb4_trsv_fwd_esimd).
template <int block_size, bool isFused>
class esb4_trsv_fwd_esimd_kernel;

// Name of the backward-solve kernel (see sparse_esb4_trsv_bwd_esimd).
template <int block_size, bool isFused>
class esb4_trsv_bwd_esimd_kernel;

// Enables the unrolled TRSV kernels. Comment out the define below to disable
// unrolling; when the unrolled kernels are not used, the large-GRF mode should
// be disabled as well.
// NOTE(review): this macro is not referenced in the code visible in this file;
// it is presumably consumed by trsv_esb4_esimd_kernel.hxx, which is included
// after this point -- confirm.
#define USE_TRSV_UNROLL_KERNELS

//
// esb forward solve accepting ESB4 arrays for start/stop
//
/// Forward sparse triangular solve on ESB4 data, launched color by color.
///
/// One ESIMD kernel is submitted per color, in ascending color order. The
/// first submission depends on `dependencies`; every later submission depends
/// on the event of the previous color, so the colors execute one after another.
/// Each work-item handles one block of `block_size` consecutive rows.
///
/// @tparam block_size   Number of rows per block (also the SIMD width of the row mask).
/// @tparam isFused      Only used here to distinguish the kernel name.
/// @param queue         Queue the kernels are submitted to.
/// @param nrows         Not used by the code in this function.
/// @param nBlocks       Total number of blocks; work-items with block >= nBlocks exit.
/// @param blockptr_st   Per-block start value, read in the kernel at index `block`.
/// @param blockptr_en   Per-block end value, read in the kernel at index `block`.
/// @param colind        Forwarded unchanged to the per-block solve kernel.
/// @param values        Forwarded unchanged to the per-block solve kernel.
/// @param diag          Forwarded unchanged to the per-block solve kernel.
/// @param nColors       Number of colors; xcolors_host must hold nColors + 1 entries.
/// @param xcolors_host  Dereferenced on the host: color c covers the rows
///                      [xcolors_host[c], xcolors_host[c + 1]).
/// @param x             Forwarded unchanged to the per-block solve kernel.
/// @param y             Forwarded unchanged to the per-block solve kernel.
/// @param dependencies  Events the first color's kernel must wait for.
/// @return Event of the last submitted kernel; a default-constructed event if
///         nColors <= 0 (in which case `dependencies` are not waited on).
///
/// NOTE(review): trsv_esb4_esimd_kernel, trsv_esb4_masked_esimd_kernel and `nc`
/// are not defined in this file; they are expected to come from the textual
/// include of trsv_esb4_esimd_kernel.hxx (built here with USE_LOWER defined) or
/// from the helper headers. Local variable names are kept stable because that
/// include is expanded inside the submit lambda and may refer to them.
template <int block_size, bool isFused>
sycl::event sparse_esb4_trsv_fwd_esimd(sycl::queue &queue,
                                       const local_int_t nrows,
                                       const local_int_t nBlocks,
                                       const local_int_t *blockptr_st,
                                       const local_int_t *blockptr_en,
                                       const local_int_t *colind,
                                       const double *values,
                                       const double *diag,
                                       const local_int_t nColors,
                                       const local_int_t *xcolors_host,
                                       double *x,
                                       double *y,
                                       const std::vector<sycl::event> &dependencies)
{
    sycl::event last;

//    printf("trsv esb4 fwd:  nrows = %d, nBlocks = %d, nColors = %d\n", nrows, nBlocks, nColors);

    for (local_int_t color = 0; color < nColors; color++) {

        // Row interval [colorStart, colorEnd) owned by this color.
        const local_int_t colorStart = xcolors_host[color];
        const local_int_t colorEnd   = xcolors_host[color + 1];

        // Block interval [blockStart, blockEnd) that covers those rows. The
        // first/last block may also contain rows of a neighbouring color.
        const local_int_t blockStart = floor_div(colorStart, block_size);
        const local_int_t blockEnd   = ceil_div(colorEnd, block_size);
        const local_int_t blockRange = blockEnd - blockStart;
        const bool firstBlockFullyInColor = (blockStart * block_size) == colorStart;
        const bool lastBlockFullyInColor  = (blockEnd * block_size) == colorEnd;

        // Global size must be a multiple of the work-group size, so round up.
        const local_int_t nWG = 2;
        const local_int_t blockRangeRd = ceil_div(blockRange, nWG) * nWG;

        last = queue.submit([&](sycl::handler &cgh) {
            // Chain the colors: the first waits on the caller's events, every
            // other one waits on the previous color's kernel.
            if (color == 0)
                cgh.depends_on(dependencies);
            else
                cgh.depends_on(last);

#define USE_LOWER
#include "trsv_esb4_esimd_kernel.hxx"

            auto kernel = [=](sycl::nd_item<1> item) SYCL_ESIMD_KERNEL {
                const local_int_t offset = item.get_global_id(0);
                const local_int_t block  = blockStart + offset;

                // offset == blockRange is the padding work-item introduced by
                // rounding the global size up; it must not touch block blockEnd,
                // which lies outside this color's block range.
                if (offset >= blockRange || block >= nBlocks) return;

                const local_int_t vec_st = esimd_lsc_scalar_load<local_int_t, local_int_t, nc, nc>(blockptr_st, block);
                const local_int_t vec_en = esimd_lsc_scalar_load<local_int_t, local_int_t, nc, nc>(blockptr_en, block);

                // A row mask is required only for an edge block that is NOT
                // entirely inside this color; interior and fully-owned edge
                // blocks take the unmasked path.
                const bool firstNeedsMask = (offset == 0) && !firstBlockFullyInColor;
                const bool lastNeedsMask  = (offset == blockRange - 1) && !lastBlockFullyInColor;
                const bool use_locmask    = firstNeedsMask || lastNeedsMask;

                if (use_locmask) {
                    // Block overlaps a color boundary: enable only the lanes
                    // whose row index falls in [colorStart, colorEnd).
                    const local_int_t row_st = block * block_size;
                    esimd::simd<local_int_t, block_size> iota(0,1);
                    esimd::simd_mask<block_size> locmask( ( row_st + iota >= colorStart) && (row_st + iota < colorEnd) );

                    trsv_esb4_masked_esimd_kernel(block, locmask, vec_st, vec_en, colind, values, x, y, diag);
                }
                else {
                    trsv_esb4_esimd_kernel(block, vec_st, vec_en, colind, values, x, y, diag);
                }
            };
            cgh.parallel_for<class esb4_trsv_fwd_esimd_kernel<block_size, isFused>>(sycl::nd_range<1>(blockRangeRd, nWG), kernel);
        });

    } // for color < nColors

    return last;
}



//
// esb backward solve accepting ESB4 arrays for start/stop
//
/// Backward sparse triangular solve on ESB4 data, launched color by color.
///
/// One ESIMD kernel is submitted per color, in descending color order
/// (nColors - 1 down to 0). The first submission depends on `dependencies`;
/// every later submission depends on the event of the previously submitted
/// color, so the colors execute one after another. Each work-item handles one
/// block of `block_size` consecutive rows.
///
/// @tparam block_size   Number of rows per block (also the SIMD width of the row mask).
/// @tparam isFused      Only used here to distinguish the kernel name.
/// @param queue         Queue the kernels are submitted to.
/// @param nrows         Not used by the code in this function.
/// @param nBlocks       Total number of blocks; work-items with block >= nBlocks exit.
/// @param blockptr_st   Per-block start value, read in the kernel at index `block`.
/// @param blockptr_en   Per-block end value, read in the kernel at index `block`.
/// @param colind        Forwarded unchanged to the per-block solve kernel.
/// @param values        Forwarded unchanged to the per-block solve kernel.
/// @param diag          Forwarded unchanged to the per-block solve kernel.
/// @param nColors       Number of colors; xcolors_host must hold nColors + 1 entries.
/// @param xcolors_host  Dereferenced on the host: color c covers the rows
///                      [xcolors_host[c], xcolors_host[c + 1]).
/// @param x             Forwarded unchanged to the per-block solve kernel.
/// @param y             Forwarded unchanged to the per-block solve kernel.
/// @param dependencies  Events the first submitted (highest) color must wait for.
/// @return Event of the last submitted kernel; a default-constructed event if
///         nColors <= 0 (in which case `dependencies` are not waited on).
///
/// NOTE(review): trsv_esb4_esimd_kernel, trsv_esb4_masked_esimd_kernel and `nc`
/// are not defined in this file; they are expected to come from the textual
/// include of trsv_esb4_esimd_kernel.hxx (built here with USE_LOWER undefined)
/// or from the helper headers. Local variable names are kept stable because
/// that include is expanded inside the submit lambda and may refer to them.
template <int block_size, bool isFused>
sycl::event sparse_esb4_trsv_bwd_esimd(sycl::queue &queue,
                                       const local_int_t nrows,
                                       const local_int_t nBlocks,
                                       const local_int_t *blockptr_st,
                                       const local_int_t *blockptr_en,
                                       const local_int_t *colind,
                                       const double *values,
                                       const double *diag,
                                       const local_int_t nColors,
                                       const local_int_t *xcolors_host,
                                       double *x,
                                       double *y,
                                       const std::vector<sycl::event> &dependencies)
{
    sycl::event last;

//    printf("trsv esb4 bwd:  nrows = %d, nBlocks = %d, nColors = %d\n", nrows, nBlocks, nColors);

    for (local_int_t color = nColors - 1; color >= 0; color--) {

        // Row interval [colorStart, colorEnd) owned by this color.
        const local_int_t colorStart = xcolors_host[color];
        const local_int_t colorEnd   = xcolors_host[color + 1];

        // Block interval [blockStart, blockEnd) that covers those rows. The
        // first/last block may also contain rows of a neighbouring color.
        const local_int_t blockStart = floor_div(colorStart, block_size);
        const local_int_t blockEnd   = ceil_div(colorEnd, block_size);
        const local_int_t blockRange = blockEnd - blockStart;
        const bool firstBlockFullyInColor = (blockStart * block_size) == colorStart;
        const bool lastBlockFullyInColor  = (blockEnd * block_size) == colorEnd;

        // Global size must be a multiple of the work-group size, so round up.
        const local_int_t nWG = 2;
        const local_int_t blockRangeRd = ceil_div(blockRange, nWG) * nWG;

        last = queue.submit([&](sycl::handler &cgh) {
            // Chain the colors: the first submission waits on the caller's
            // events, every other one waits on the previously submitted color.
            if (color == nColors - 1)
                cgh.depends_on(dependencies);
            else
                cgh.depends_on(last);

#undef USE_LOWER
#include "trsv_esb4_esimd_kernel.hxx"

            auto kernel = [=](sycl::nd_item<1> item) SYCL_ESIMD_KERNEL {
                const local_int_t offset = item.get_global_id(0);
                const local_int_t block  = blockStart + offset;

                // offset == blockRange is the padding work-item introduced by
                // rounding the global size up; it must not touch block blockEnd,
                // which lies outside this color's block range.
                if (offset >= blockRange || block >= nBlocks) return;

                const local_int_t vec_st = esimd_lsc_scalar_load<local_int_t, local_int_t, nc, nc>(blockptr_st, block);
                const local_int_t vec_en = esimd_lsc_scalar_load<local_int_t, local_int_t, nc, nc>(blockptr_en, block);

                // A row mask is required only for an edge block that is NOT
                // entirely inside this color; interior and fully-owned edge
                // blocks take the unmasked path.
                const bool firstNeedsMask = (offset == 0) && !firstBlockFullyInColor;
                const bool lastNeedsMask  = (offset == blockRange - 1) && !lastBlockFullyInColor;
                const bool use_locmask    = firstNeedsMask || lastNeedsMask;

                if (use_locmask) {
                    // Block overlaps a color boundary: enable only the lanes
                    // whose row index falls in [colorStart, colorEnd).
                    const local_int_t row_st = block * block_size;
                    esimd::simd<local_int_t, block_size> iota(0,1);
                    esimd::simd_mask<block_size> locmask( ( row_st + iota >= colorStart) && (row_st + iota < colorEnd) );

                    trsv_esb4_masked_esimd_kernel(block, locmask, vec_st, vec_en, colind, values, x, y, diag);
                }
                else {
                    trsv_esb4_esimd_kernel(block, vec_st, vec_en, colind, values, x, y, diag);
                }
            };
            cgh.parallel_for<class esb4_trsv_bwd_esimd_kernel<block_size, isFused>>(sycl::nd_range<1>(blockRangeRd, nWG), kernel);
        });

    } // for color >= 0

    return last;
}
