#include "../include/matvec.cuh"

namespace icethermo_kernel
{
    // Element-wise product kernel: res[i] = vec1[i] * vec2[i] for every i < size.
    // Each thread handles at most one element, selected by its flat 1D
    // (x-dimension only) global thread index; threads with index >= size do nothing.
    template <typename NumType>
    __global__ void mul_vec(const NumType* vec1, const NumType* vec2, const int size, NumType* res);

    // In-place scaling kernel: vec[i] *= num for every i < size.
    // Uses the same one-element-per-thread indexing as the overload above.
    template <typename NumType>
    __global__ void mul_vec(NumType* vec, const NumType num, const int size);
}

// Element-wise product kernel: writes res[i] = vec1[i] * vec2[i].
// One thread processes at most one element; only the x dimension of the
// grid and block is used to form the element index.
template <typename NumType>
__global__ void icethermo_kernel::mul_vec(const NumType* vec1, const NumType* vec2, const int size, NumType* res)
{
    // Flat global thread index along x.
    const int tid = threadIdx.x + blockDim.x * blockIdx.x;

    // Threads past the end of the vector (tail of the last block) have no work.
    if(tid >= size)
        return;

    const NumType lhs = vec1[tid];
    const NumType rhs = vec2[tid];
    res[tid] = lhs * rhs;
}

// In-place scaling kernel: multiplies vec[i] by the scalar num.
// One thread processes at most one element; only the x dimension of the
// grid and block is used to form the element index.
template <typename NumType>
__global__ void icethermo_kernel::mul_vec(NumType* vec, const NumType num, const int size)
{
    // Flat global thread index along x.
    const int tid = threadIdx.x + blockDim.x * blockIdx.x;

    // Threads past the end of the vector (tail of the last block) have no work.
    if(tid >= size)
        return;

    vec[tid] *= num;
}

namespace icethermo_gpu
{
    namespace
    {
        // Number of threads per block used by every launch in this file.
        const int MulVecThreadsPerBlock = 1024;

        // Number of blocks needed to cover `size` elements (size must be > 0).
        // Pure integer ceil-division: the previous float-based computation lost
        // precision for size > 2^24 and could launch one block too few.
        // Written as (size - 1) / T + 1 so it cannot overflow for large sizes.
        inline int mul_vec_block_count(const int size)
        {
            return (size - 1) / MulVecThreadsPerBlock + 1;
        }
    }

    // Launches the element-wise product kernel: res[i] = vec1[i] * vec2[i], i in [0, size).
    //
    // vec1, vec2 - input vectors; dereferenced inside the kernel, so they must be
    //              addressable from device code
    // size       - number of elements; size <= 0 is a no-op
    // res        - output vector, also dereferenced inside the kernel
    //
    // The launch uses the default stream; this function neither synchronizes
    // nor checks for launch errors.
    template <typename NumType> void mul_vec(const NumType* vec1, const NumType* vec2, const int size, NumType* res)
    {
        // An empty grid is an invalid launch configuration, so skip the launch entirely.
        if (size <= 0)
            return;

        const dim3 cuBlock(MulVecThreadsPerBlock, 1, 1);
        const dim3 cuGrid(mul_vec_block_count(size), 1, 1);

        icethermo_kernel::mul_vec<<<cuGrid, cuBlock>>>(vec1, vec2, size, res);
    }

    // Launches the in-place scaling kernel: vec[i] *= num, i in [0, size).
    //
    // vec  - vector to scale; dereferenced inside the kernel, so it must be
    //        addressable from device code
    // num  - scalar multiplier
    // size - number of elements; size <= 0 is a no-op
    //
    // The launch uses the default stream; this function neither synchronizes
    // nor checks for launch errors.
    template <typename NumType> void mul_vec(NumType* vec, const NumType num, const int size)
    {
        // An empty grid is an invalid launch configuration, so skip the launch entirely.
        if (size <= 0)
            return;

        const dim3 cuBlock(MulVecThreadsPerBlock, 1, 1);
        const dim3 cuGrid(mul_vec_block_count(size), 1, 1);

        icethermo_kernel::mul_vec<<<cuGrid, cuBlock>>>(vec, num, size);
    }


    // explicit instantiation
    template void mul_vec(const float* vec1, const float* vec2, const int size, float* res);
    template void mul_vec(const double* vec1, const double* vec2, const int size, double* res);

    template void mul_vec(float* vec, const float num, const int size);
    template void mul_vec(double* vec, const double num, const int size);
}

// Explicit instantiations of the kernels for float and double.

// element-wise product: res[i] = vec1[i] * vec2[i]
template __global__ void icethermo_kernel::mul_vec<float>(const float* vec1, const float* vec2, const int size, float* res);
template __global__ void icethermo_kernel::mul_vec<double>(const double* vec1, const double* vec2, const int size, double* res);

// in-place scaling: vec[i] *= num
template __global__ void icethermo_kernel::mul_vec<float>(float* vec, const float num, const int size);
template __global__ void icethermo_kernel::mul_vec<double>(double* vec, const double num, const int size);