/*******************************************************************************
* Copyright (C) 2020 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of oneAPI Math Kernel Library (oneMKL)
*       SPARSE BLAS APIs with sycl::buffers to solve a system of linear equations (Ax=b)
*       by preconditioned CG method with symmetric Gauss-Seidel preconditioner:
*
*       Solve Ax = b:
*
*       x_0 initial guess
*       r_0 = b - A*x_0
*       solve M*z_0 = r_0 for z_0
*       p_1 = z_0
*       k = 0
*       while (||z_k|| / ||z_0|| > relTol and k < maxIter )
*           Ap_{k+1} = A*p_{k+1}
*           alpha_{k+1} = (r_k, z_k) / (p_{k+1}, Ap_{k+1})
*
*           x_{k+1} = x_k + alpha_{k+1} * p_{k+1}
*           r_{k+1} = r_k - alpha_{k+1} * Ap_{k+1}
*
*           solve M*z_{k+1} = r_{k+1} for z_{k+1}
*           if (||z_{k+1}|| < absTol) break with convergence
*
*           k=k+1
*           beta_k = (r_k, z_k) / (r_{k-1}, z_{k-1})
*           p_{k+1} = z_k + beta_k * p_k
*       end
*
*       where A = L+D+L^T; M = (D+L)*D^{-1}*(D+L^T).
*
*       Note that
*
*         x is solution
*         r is residual
*         z is preconditioned residual
*         p is search direction
*
*       and in this example, we are using norm of z and z_0 for stopping criteria.
*
*
*       The supported floating point data types for matrix data in this example
*       are:
*           float
*           double
*
*       This example uses a matrix in CSR format.
*
*/

// stl includes
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iomanip>
#include <iterator>
#include <limits>
#include <list>
#include <vector>

#include "mkl.h"
#include "oneapi/mkl.hpp"
#include <sycl/sycl.hpp>

// local includes
#include "common_for_examples.hpp"
#include "./include/common_for_sparse_examples.hpp"

// Kernel name types for the SYCL parallel_for launches in this file.
// They are only forward-declared (never defined): each one is used solely as
// the template argument that names a kernel, and is parameterized on the
// value/index types so every instantiation of a helper gets its own name.

// names the kernel launched by extract_diagonal()
template <typename dataType, typename intType>
class extractDiagonalClass;

// names the kernel launched by modify_diagonal()
template <typename dataType, typename intType>
class modifyDiagonalClass;

// names the kernel launched by diagonal_mv()
template <typename dataType, typename intType>
class diagonalMVClass;

//
// Copy the diagonal of a CSR matrix into a dense vector
//
// d[row] <- A(row, row)
//
// One work-item per row scans that row's column indices for the entry whose
// column equals the row id (so the CSR arrays are treated as zero-based).
// Rows that store no diagonal entry leave d[row] untouched.
//
template <typename dataType, typename intType>
static void extract_diagonal(sycl::queue q,
                             const intType n,
                             sycl::buffer<intType, 1> &ia_buffer,
                             sycl::buffer<intType, 1> &ja_buffer,
                             sycl::buffer<dataType, 1> &a_buffer,
                             sycl::buffer<dataType, 1> &d_buffer)
{
    q.submit([&](sycl::handler &cgh) {
        auto rowPtr = ia_buffer.template get_access<sycl::access::mode::read>(cgh);
        auto colInd = ja_buffer.template get_access<sycl::access::mode::read>(cgh);
        auto vals   = a_buffer.template get_access<sycl::access::mode::read>(cgh);
        auto diag   = d_buffer.template get_access<sycl::access::mode::write>(cgh);

        cgh.parallel_for<class extractDiagonalClass<dataType, intType>>(
                sycl::range<1>(n), [=](sycl::item<1> item) {
                    const int row = item.get_id(0);
                    const intType rowEnd = rowPtr[row + 1];

                    // advance to the diagonal entry, or to the end of the row
                    intType pos = rowPtr[row];
                    while (pos < rowEnd && colInd[pos] != row) {
                        ++pos;
                    }

                    if (pos < rowEnd) {
                        diag[row] = vals[pos];
                    }
                });
    });
}


//
// Overwrite the stored diagonal of a CSR matrix with a constant
//
// D <- new_diagVal * I
//
// Both the matrix values (a_buffer) and the dense copy of the diagonal
// (d_buffer) are updated. Only rows that already store a diagonal entry are
// touched; no new nonzeros are created. A zero new_diagVal is rejected by the
// assert (active only when NDEBUG is not defined).
//
template <typename dataType, typename intType>
static void modify_diagonal(sycl::queue q,
                            const dataType new_diagVal,
                            const intType n,
                            sycl::buffer<intType, 1> &ia_buffer,
                            sycl::buffer<intType, 1> &ja_buffer,
                            sycl::buffer<dataType, 1> &a_buffer, // to be modified
                            sycl::buffer<dataType, 1> &d_buffer) // to be modified
{
    assert(new_diagVal != dataType(0.0) );

    q.submit([&](sycl::handler &cgh) {
        auto rowPtr = ia_buffer.template get_access<sycl::access::mode::read>(cgh);
        auto colInd = ja_buffer.template get_access<sycl::access::mode::read>(cgh);
        auto vals   = a_buffer.template get_access<sycl::access::mode::write>(cgh);
        auto diag   = d_buffer.template get_access<sycl::access::mode::write>(cgh);

        cgh.parallel_for<class modifyDiagonalClass<dataType, intType>>(
                sycl::range<1>(n), [=](sycl::item<1> item) {
                    const int row = item.get_id(0);
                    const intType rowEnd = rowPtr[row + 1];

                    for (intType pos = rowPtr[row]; pos < rowEnd; ++pos) {
                        if (colInd[pos] != row) {
                            continue; // off-diagonal entry, keep looking
                        }
                        vals[pos] = new_diagVal;
                        diag[row] = new_diagVal;
                        return; // the diagonal entry is handled, row is done
                    }
                });
    });
}



//
// Scale by diagonal
//
// t = D * t
//
// d_buffer holds the n diagonal entries of D and is only read;
// t_buffer is scaled in place, one work-item per entry.
//
template <typename dataType, typename intType>
static void diagonal_mv(sycl::queue q,
                        const intType n,
                        sycl::buffer<dataType, 1> &d_buffer,
                        sycl::buffer<dataType, 1> &t_buffer)
{
    q.submit([&](sycl::handler &cgh) {
        // The kernel never stores to d, so request read access: a write
        // accessor would register this launch as a writer of d_buffer and
        // read through a write-only accessor.
        auto d = (d_buffer).template get_access<sycl::access::mode::read>(cgh);
        auto t = (t_buffer).template get_access<sycl::access::mode::read_write>(cgh);
        auto kernel = [=](sycl::item<1> item) {
            const auto row = item.get_id(0);
            t[row] *= d[row];
        };
        cgh.parallel_for<class diagonalMVClass<dataType, intType>>(
                sycl::range<1>(n), kernel);
    });
}


//
// Symmetric Gauss-Seidel preconditioner
//
// Solves M z = r with M = (L+D)*inv(D)*(D+U) in three steps:
//
//   t = inv(D+L) * r    forward triangular solve on the lower part of csrA
//   t = D * t           scaling by the diagonal held in d_buffer
//   z = inv(D+U) * t    backward triangular solve on the upper part of csrA
//
// t_buffer is scratch space and is overwritten; the result lands in z_buffer.
//
template <typename dataType, typename intType>
static void precon_gauss_seidel(sycl::queue q,
                                const intType n,
                                oneapi::mkl::sparse::matrix_handle_t csrA,
                                sycl::buffer<dataType, 1> &d_buffer,
                                sycl::buffer<dataType, 1> &r_buffer,
                                sycl::buffer<dataType, 1> &t_buffer, // temporary workspace
                                sycl::buffer<dataType, 1> &z_buffer) // output
{
    // arguments shared by both triangular solves
    const auto noTrans   = oneapi::mkl::transpose::nontrans;
    const auto fullDiag  = oneapi::mkl::diag::nonunit;
    const dataType unity = dataType(1.0); // alpha

    // forward sweep: t = inv(D+L) * r
    oneapi::mkl::sparse::trsv(q, oneapi::mkl::uplo::lower, noTrans, fullDiag,
                              unity, csrA, r_buffer, t_buffer);

    // t = D * t
    diagonal_mv<dataType, intType>(q, n, d_buffer, t_buffer);

    // backward sweep: z = inv(D+U) * t
    oneapi::mkl::sparse::trsv(q, oneapi::mkl::uplo::upper, noTrans, fullDiag,
                              unity, csrA, t_buffer, z_buffer);
}



//
// Run the preconditioned CG example for one floating point / index type pair.
//
// Builds the CSR matrix with generate_sparse_matrix(), overwrites its diagonal,
// wraps all data in sycl::buffers and solves A*x = b on queue q with PCG using
// the symmetric Gauss-Seidel preconditioner from precon_gauss_seidel().
//
// Convergence is measured on the preconditioned residual z. The solver keeps
// the squared norms
//
//     zTz   = (z_k, z_k)
//     zTz_0 = (z_0, z_0)
//
// and stops when sqrt(zTz / zTz_0) <= relTol, when sqrt(zTz) < absTol, or after
// maxIter iterations.
//
// Returns 0 if the iteration converged, 1 if it did not or if an exception was
// caught. SYCL and std exceptions are reported on stdout, not propagated.
//
template <typename dataType, typename intType>
int run_sparse_blas_example(sycl::queue &q)
{
    bool good = true;

    // handle for sparse matrix
    oneapi::mkl::sparse::matrix_handle_t csrA = nullptr;

    try {
        // Matrix data size
        intType size  = 4;
        intType n = size * size * size; // A is n x n

        // PCG settings
        const intType maxIter = 500;
        const dataType relTol = 1.0e-5;
        const dataType absTol = 1.0e-3;

        // scalar coefficients in the working precision
        const dataType one  = dataType(1.0);
        const dataType zero = dataType(0.0);

        // Input matrix in CSR format
        std::vector<intType, mkl_allocator<intType, 64>> ia;
        std::vector<intType, mkl_allocator<intType, 64>> ja;
        std::vector<dataType, mkl_allocator<dataType, 64>> a;

        ia.resize(n + 1);
        ja.resize(27 * n);
        a.resize(27 * n);

        generate_sparse_matrix<dataType, intType>(size, ia, ja, a);

        const intType a_nnz = ia[n]; // 0 based indexing

        // Vectors x and b
        std::vector<dataType, mkl_allocator<dataType, 64>> x;
        std::vector<dataType, mkl_allocator<dataType, 64>> b;
        x.resize(n);
        b.resize(n);

        // Init right hand side and vector x
        for (intType i = 0; i < n; i++) {
            b[i] = set_fp_value(dataType(1.0), dataType(0.0));
            x[i] = set_fp_value(dataType(0.0), dataType(0.0));
        }

        //
        // Execute CG solver on Ax = b
        //

        std::cout << "\n\t\tSparse PCG parameters:\n";

        std::cout << "\t\t\tA size: (" << n << ", " << n << ") with nnz = " << a_nnz << " elements stored" << std::endl;
        std::cout << "\t\t\tPreconditioner = Symmetric Gauss-Seidel" << std::endl;
        std::cout << "\t\t\tmax iterations = " << maxIter << std::endl;
        std::cout << "\t\t\trelative tolerance = " << relTol << std::endl;
        std::cout << "\t\t\tabsolute tolerance = " << absTol << std::endl;

        sycl::buffer<intType, 1> ia_buffer(ia.data(), ia.data() + n + 1);
        sycl::buffer<intType, 1> ja_buffer(ja.data(), ja.data() + a_nnz);
        sycl::buffer<dataType, 1> a_buffer(a.data(), a.data() + a_nnz);
        sycl::buffer<dataType, 1> x_buffer(x.data(), x.data() + n);
        sycl::buffer<dataType, 1> b_buffer(b.data(), b.data() + n);
        sycl::buffer<dataType, 1> r_buffer((sycl::range<1>(n)));
        sycl::buffer<dataType, 1> z_buffer((sycl::range<1>(n)));
        sycl::buffer<dataType, 1> p_buffer((sycl::range<1>(n)));
        sycl::buffer<dataType, 1> t_buffer((sycl::range<1>(n)));
        sycl::buffer<dataType, 1> d_buffer((sycl::range<1>(n)));
        sycl::buffer<dataType, 1> temp_buffer((sycl::range<1>(1)));

        // Fetch the scalar result of the last reduction (dot) from temp_buffer.
        // The host accessor lives only for the duration of the call, so the
        // buffer is free again for the next device-side reduction.
        auto read_scalar = [&temp_buffer]() -> dataType {
            auto temp_accessor = temp_buffer.get_host_access(sycl::read_only);
            return temp_accessor[0];
        };

        extract_diagonal(q, n, ia_buffer, ja_buffer, a_buffer, d_buffer);

        modify_diagonal(q, dataType(52.0), n, ia_buffer, ja_buffer, a_buffer, d_buffer);

        oneapi::mkl::sparse::init_matrix_handle(&csrA);

        oneapi::mkl::sparse::set_csr_data(q, csrA, n, n, oneapi::mkl::index_base::zero,
                                          ia_buffer, ja_buffer, a_buffer);

        // properties set on csrA which aide in internal optimizations (especially for trsv)
        oneapi::mkl::sparse::set_matrix_property(csrA, oneapi::mkl::sparse::property::symmetric);
        oneapi::mkl::sparse::set_matrix_property(csrA, oneapi::mkl::sparse::property::sorted);

        oneapi::mkl::sparse::optimize_trsv(q, oneapi::mkl::uplo::lower,
                                           oneapi::mkl::transpose::nontrans,
                                           oneapi::mkl::diag::nonunit, csrA);
        oneapi::mkl::sparse::optimize_trsv(q, oneapi::mkl::uplo::upper,
                                           oneapi::mkl::transpose::nontrans,
                                           oneapi::mkl::diag::nonunit, csrA);
        oneapi::mkl::sparse::optimize_gemv(q, oneapi::mkl::transpose::nontrans, csrA);


        // initial residual equal to r_0 = b - A * x_0
        oneapi::mkl::sparse::gemv(q, oneapi::mkl::transpose::nontrans, one, csrA,
                x_buffer, zero, r_buffer); // r = A * x_0
        oneapi::mkl::blas::axpby(q, n, one, b_buffer, 1, -one, r_buffer, 1); // r_0 = b - r (= b-A*x_0)

        // Calculation z_0 = M^{-1}*r_0
        precon_gauss_seidel(q, n, csrA, d_buffer, r_buffer, t_buffer, z_buffer);

        // p_0 = z_0
        oneapi::mkl::blas::copy(q, n, z_buffer, 1, p_buffer, 1);

        // Calculate initial squared norm of preconditioned residual,
        // zTz_0 = (z_0, z_0) = ||z_0||^2.
        //
        // dot(z, z) is used rather than nrm2: nrm2 already returns ||z||, and
        // every use of zTz / zTz_0 below takes the square root itself.
        oneapi::mkl::blas::dot(q, n, z_buffer, 1, z_buffer, 1, temp_buffer);
        const dataType zTz_0 = read_scalar();
        dataType zTz = zTz_0;

        // Start of main PCG algorithm
        std::int32_t k = 0;

        // rTz = dot(r_0, z_0)
        oneapi::mkl::blas::dot(q, n, r_buffer, 1, z_buffer, 1, temp_buffer);
        dataType rTz = read_scalar();

        while ( std::sqrt(zTz / zTz_0) > relTol && k < maxIter) {
            // Calculate t_k = A*p_k
            oneapi::mkl::sparse::gemv(q, oneapi::mkl::transpose::nontrans, one, csrA,
                                      p_buffer, zero, t_buffer);

            // Calculate alpha_k  = (r_k, z_k) / (p_k, Ap_k)
            //
            // temp_buffer = (p_k, Ap_k)
            oneapi::mkl::blas::dot(q, n, p_buffer, 1, t_buffer, 1, temp_buffer);
            const dataType alpha = rTz / read_scalar();

            // Calculate x_{k+1} = x_k + alpha_k*p_k
            oneapi::mkl::blas::axpy(q, n, alpha, p_buffer, 1, x_buffer, 1);
            // Calculate r_{k+1} = r_k - alpha_k*A*p_k (note that t = A*p_k right now so it can be reused here)
            oneapi::mkl::blas::axpy(q, n, -alpha, t_buffer, 1, r_buffer, 1);

            // Calculate z_{k+1} = M^{-1}r_{k+1}
            precon_gauss_seidel(q, n, csrA, d_buffer, r_buffer, t_buffer, z_buffer);

            // Calculate current squared norm of correction
            //
            // temp_buffer = ||z_{k+1}||^2 = (z_{k+1}, z_{k+1})
            oneapi::mkl::blas::dot(q, n, z_buffer, 1, z_buffer, 1, temp_buffer);
            zTz = read_scalar();

            k++; // increment k counter
            std::cout << "\t\t\t\trelative norm of residual on " << std::setw(4) << k  // output in 1 base indexing
                      << " iteration: " << std::sqrt(zTz / zTz_0)
                      << std::endl;
            // Same strict comparison as the convergence classification after
            // the loop, so a break here is always reported as converged.
            if (std::sqrt(zTz) < absTol) {
                std::cout << "\t\t\t\tabsolute norm of residual on " << std::setw(4) << k // output in 1-based indexing
                    << " iteration: " <<  std::sqrt(zTz) << std::endl;
                break;
            }

            // Calculate beta_{k+1} = (r_{k+1}, z_{k+1}) / (r_k, z_k)
            //
            // temp_buffer = (r_{k+1}, z_{k+1})
            oneapi::mkl::blas::dot(q, n, r_buffer, 1, z_buffer, 1, temp_buffer);
            const dataType rTz_next = read_scalar();
            const dataType beta = rTz_next / rTz;
            // rTz = (r_{k+1}, z_{k+1})
            rTz = rTz_next;

            // Calculate p_{k+1} = z_{k+1} + beta_{k+1} * p_k
            oneapi::mkl::blas::axpby(q, n, one, z_buffer, 1, beta, p_buffer, 1);

        } // while sqrt(zTz / zTz_0) > relTol && k < maxIter

        //
        // Determine if we converged or not based on relative and absolute Errors
        //
        if (std::sqrt(zTz) < absTol) {
            std::cout << "" << std::endl;
            std::cout << "\t\tPreconditioned CG process has successfully converged in absolute error in " << k << " steps with" << std::endl;

            good = true;
        }
        else if (k <= maxIter && std::sqrt(zTz / zTz_0) <= relTol) {
            std::cout << "" << std::endl;
            std::cout << "\t\tPreconditioned CG process has successfully converged in relative error in " << k << " steps with" << std::endl;

            good = true;
        } else {
            std::cout << "" << std::endl;
            std::cout << "\t\tPreconditioned CG process has not converged after " << k << " steps with" << std::endl;

            good = false;
        }
        std::cout << "\t\t relative error ||z||/||z_0|| = " << std::sqrt(zTz / zTz_0) << (std::sqrt(zTz / zTz_0) < relTol ? " < " : " > ") << relTol << std::endl;
        std::cout << "\t\t absolute error ||z||         = " << std::sqrt(zTz) << (std::sqrt(zTz) < absTol ? " < " : " > ") << absTol << std::endl;
        std::cout << "" << std::endl;

        oneapi::mkl::sparse::release_matrix_handle(q, &csrA);

        q.wait_and_throw();
    }
    catch (sycl::exception const &e) {
        std::cout << "\t\tCaught synchronous SYCL exception:\n" << e.what() << std::endl;
        good = false;
    }
    catch (std::exception const &e) {
        std::cout << "\t\tCaught std exception:\n" << e.what() << std::endl;
        good = false;
    }

    q.wait();

    // backup cleaning of matrix handle and others for if exceptions happened
    oneapi::mkl::sparse::release_matrix_handle(q, &csrA);

    q.wait();

    return good ? 0 : 1;
}

//
// Print the example banner to stdout: what is solved, the storage format of
// the operands and the floating point precisions the example runs with.
//
void print_example_banner()
{
    // horizontal rule framing the banner
    const char *const rule =
            "###############################################################"
            "#########";

    // banner text, one entry per output line
    const char *const lines[] = {
        "",
        rule,
        "# Sparse Preconditioned CG Example with sycl buffers: ",
        "# ",
        "# A * x = b",
        "# ",
        "# where A is a sparse matrix in CSR format, x and b are "
        "dense vectors",
        "# ",
        "# Supported floating point type precisions:",
        "#   float",
        "#   double",
        "# ",
        rule,
        "",
    };

    for (const char *line : lines) {
        std::cout << line << std::endl;
    }
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_cpu -- only runs SYCL CPU implementation
// -DSYCL_DEVICES_gpu -- only runs SYCL GPU implementation
// -DSYCL_DEVICES_all (default) -- runs on all: cpu and gpu devices
//
//  For each device selected and each supported data type,
//  run_sparse_blas_example is run with all supported data types,
//  if any fail, we move on to the next device.
//

int main(int argc, char **argv)
{

    print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    int status = 0;
    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        try {
            sycl::device my_dev;
            bool my_dev_is_found = false;
            get_sycl_device(my_dev, my_dev_is_found, *it);

            if (my_dev_is_found) {
                std::cout << "Running tests on " << sycl_device_names[*it] << ".\n";

                // Catch asynchronous exceptions
                auto exception_handler = [](sycl::exception_list exceptions) {
                    for (std::exception_ptr const &e : exceptions) {
                        try {
                            std::rethrow_exception(e);
                        }
                        catch (sycl::exception const &e) {
                            std::cout << "Caught asynchronous SYCL exception: \n"
                                << e.what() << std::endl;
                        }
                    }
                };

                sycl::queue q(my_dev, exception_handler);

                std::cout << "\tRunning with single precision real data type:" << std::endl;
                status |= run_sparse_blas_example<float, std::int32_t>(q);

                if (my_dev.get_info<sycl::info::device::double_fp_config>().size() != 0) {
                    std::cout << "\tRunning with double precision real data type:" << std::endl;
                    status |= run_sparse_blas_example<double, std::int32_t>(q);
                }

            }
            else {
#ifdef FAIL_ON_MISSING_DEVICES
                std::cout << "No " << sycl_device_names[*it]
                    << " devices found; Fail on missing devices "
                    "is enabled.\n";
                return 1;
#else
                std::cout << "No " << sycl_device_names[*it] << " devices found; skipping "
                    << sycl_device_names[*it] << " tests.\n";
#endif
            }
        }
        catch (sycl::exception const &e) {
            std::cout << "\t\tCaught SYCL exception at driver level: \n" << e.what() << std::endl;
            continue; // stop with device, but move on to other devices
        }
        catch (std::exception const &e) {
            std::cout << "\t\tCaught std exception at driver level: \n" << e.what() << std::endl;
            continue; // stop with device, but move on to other devices
        }


    } // for loop over devices

    mkl_free_buffers();
    return status;
}

