! Created by Andrey Debolskiy on 29.11.2024.

module pbl_solver
    use parkinds, only: rf=>kind_rf, im=>kind_im
    use scm_state_data
    use pbl_turb_data
    implicit none

    public factorize_tridiag
    public solve_tridiag
    public fill_tridiag
    public solve_diffusion
    contains

    !> Forward-elimination (factorization) sweep for the tridiagonal system
    !!              - bb(k) y(k) + cc(k) y(k+1) = -f(k), k = ktvd
    !! aa(k) y(k-1) - bb(k) y(k) + cc(k) y(k+1) = -f(k), k = ktvd+1..kl-1
    !! aa(k) y(k-1) - bb(k) y(k)                = -f(k), k = kl
    !! Computes the right-hand-side independent coefficients
    !!   prgnz(ktvd) = 1 / bb(ktvd),  prgna(ktvd) = cc(ktvd) / bb(ktvd)
    !!   prgnz(k)    = 1 / (bb(k) - aa(k) prgna(k-1)),  k = ktvd+1..kl
    !!   prgna(k)    = cc(k) prgnz(k),                   k = ktvd+1..kl
    !! which solve_tridiag then combines with a right-hand side f.
    !! Only indices ktvd..kl of prgna and prgnz are written.
    !! @param ktvd   first (top) row of the system
    !! @param kl     last (bottom) row of the system
    !! @param aa     sub-diagonal coefficients (aa(ktvd) is not referenced)
    !! @param bb     diagonal coefficients, entering the system with a minus sign
    !! @param cc     super-diagonal coefficients
    !! @param prgna  elimination coefficients multiplying y(k+1)
    !! @param prgnz  reciprocal pivots
    pure subroutine factorize_tridiag(ktvd, kl, aa, bb, cc, prgna, prgnz)
        implicit none
        integer, intent(in)              :: ktvd, kl
        real, intent(in),  dimension(kl) :: aa, bb, cc
        real, intent(out), dimension(kl) :: prgna, prgnz

        integer :: k

        ! The first row has no sub-diagonal term, so its pivot is bb(ktvd) itself.
        ! Storing it keeps prgnz defined on the whole ktvd..kl range.
        prgnz(ktvd) = 1.0e0 / bb(ktvd)
        prgna(ktvd) = cc(ktvd) / bb(ktvd)
        do k = ktvd+1, kl
            prgnz(k) = 1.0e0 / (bb(k) - aa(k) * prgna(k-1))
            prgna(k) = cc(k) * prgnz(k)
        end do
    end subroutine factorize_tridiag

    !> Solve the tridiagonal system already factorized by factorize_tridiag
    !! for one right-hand side ff:
    !!              - bb(k) y(k) + cc(k) y(k+1) = -ff(k), k = ktvd
    !! aa(k) y(k-1) - bb(k) y(k) + cc(k) y(k+1) = -ff(k), k = ktvd+1..kl-1
    !! aa(k) y(k-1) - bb(k) y(k)                = -ff(k), k = kl
    !! assuming cc(kl) = 0.0.
    !! The forward sweep reduces the system to the bidiagonal form
    !! y(k) - prgna(k) y(k+1) = prgnb(k), k = ktvd..kl-1
    !! y(k)                   = prgnb(k), k = kl
    !! which is then solved by backward substitution.
    !! prgna and prgnz must come from factorize_tridiag called with the same
    !! ktvd, kl, aa, bb and cc. The cc argument itself is not referenced here.
    !! Only y(ktvd:kl) is written.
    pure subroutine solve_tridiag(ktvd, kl, aa, bb, cc, ff, prgna, prgnz, y)
        implicit none
        integer, intent(in)              :: ktvd, kl
        real, intent(in),  dimension(kl) :: aa, bb, cc, ff
        real, intent(in),  dimension(kl) :: prgna, prgnz
        real, intent(out), dimension(kl) :: y

        integer :: k
        real :: prgnb(kl)

        ! forward sweep: right-hand side of the bidiagonal system
        prgnb(ktvd) = ff(ktvd) / bb(ktvd)
        do k = ktvd+1, kl
            prgnb(k) = prgnz(k) * (ff(k) + aa(k) * prgnb(k-1))
        end do
        ! backward substitution
        y(kl) = prgnb(kl)
        do k = kl-1, ktvd, -1
            y(k) = prgna(k) * y(k+1) + prgnb(k)
        end do
    end subroutine solve_tridiag

    !> Fill the tridiagonal coefficients of one implicit (backward-Euler)
    !! vertical diffusion step on levels kbltop..grid%kmax.
    !! The coefficients follow the sign convention of solve_tridiag:
    !!   aa(k) y(k-1) - bb(k) y(k) + cc(k) y(k+1) = -f(k),
    !!   aa(k) = dt * kdiff(k-1) / (rho(k) * dzc(k) * dze(k-1))
    !!   cc(k) = dt * kdiff(k)   / (rho(k) * dzc(k) * dze(k))
    !!   bb(k) = 1 + aa(k) + cc(k)
    !! so f is presumably the field at the previous time step
    !! (NOTE(review): confirm with the caller).
    !! kdiff(k) / dze(k) couples level k with level k+1.
    !! Zero flux is imposed at the top level kbltop (aa = 0) and at the
    !! bottom level grid%kmax (cc = 0). Entries 1..kbltop-1 are set to zero
    !! and must not be passed to the solver (the first solver row is kbltop).
    !! @param aa, bb, cc  output sub-, main- and super-diagonal coefficients
    !! @param rho         density per level (presumably at cell centres; TODO confirm)
    !! @param kdiff       diffusion coefficient per interface (see above)
    !! @param kbltop      index of the top level of the diffused column
    !! @param grid        vertical grid supplying kmax, dzc and dze
    !! @param dt          time step
    subroutine fill_tridiag(aa, bb, cc, rho, kdiff, kbltop, grid, dt)
        use pbl_grid, only: pblgridDataType
        implicit none
        real, dimension(*), intent(in):: rho, kdiff
        real, intent(in):: dt
        integer, intent(in):: kbltop
        type(pblgridDataType),intent(in):: grid

        real, dimension(*), intent(out):: aa, bb, cc

        real:: dtdz
        integer:: k

        ! nullify everything down to and including the top boundary
        aa(1:kbltop) = 0.0
        bb(1:kbltop) = 0.0
        cc(1:kbltop) = 0.0

        ! Single-level column: there is no interior interface. With zero flux
        ! at both boundaries the step must leave the field unchanged
        ! (identity row), and no diffusivity from above kbltop may be used.
        if (kbltop == grid%kmax) then
            bb(kbltop) = 1.0
            return
        end if

        ! top boundary condition: flux = 0 (no coupling to level kbltop-1)
        k = kbltop
        dtdz = dt / grid%dzc(k)
        aa(k) = 0.0
        cc(k) = (kdiff(k) / rho(k)) * dtdz / grid%dze(k)
        ! The diagonal must carry the implicit term. Leaving it at 0.0 would
        ! make factorize_tridiag/solve_tridiag divide by zero in their first row.
        bb(k) = 1.0 + cc(k)

        ! interior levels
        do k = kbltop + 1, grid%kmax - 1
            dtdz = dt / grid%dzc(k)
            aa(k) = (kdiff(k - 1) / rho(k)) * dtdz / grid%dze(k - 1)
            cc(k) = (kdiff(k) / rho(k)) * dtdz / grid%dze(k)
            bb(k) = 1.0 + aa(k) + cc(k)
        end do

        ! bottom boundary condition: flux = 0 (no coupling below grid%kmax)
        k = grid%kmax
        dtdz = dt / grid%dzc(k)
        aa(k) = (kdiff(k - 1) / rho(k)) * dtdz / grid%dze(k - 1)
        bb(k) = 1.0 + aa(k)
        cc(k) = 0.0
    end subroutine fill_tridiag

    !> Diffusion step for the boundary-layer state.
    !! NOTE(review): this is currently a stub with an empty body. Because bl is
    !! intent(out), it is returned undefined. If stateBLDataType has
    !! allocatable components, they are deallocated on entry. Callers must
    !! not rely on bl until this routine is implemented.
    !! @param bl      new boundary-layer state (output)
    !! @param bl_old  boundary-layer state at the previous step
    !! @param turb    turbulence data
    !! @param fluid   fluid parameters
    !! @param grid    vertical grid
    subroutine solve_diffusion(bl, bl_old, turb, fluid, grid)
        use scm_state_data, only : stateBLDataType
        use pbl_turb_data, only : turbBLDataType
        use phys_fluid, only: fluidParamsDataType
        use pbl_grid, only : pblgridDataType
        implicit none
        type(stateBLDataType), intent(out):: bl
        type(stateBLDataType), intent(in):: bl_old
        type(turbBLDataType), intent(in):: turb
        type(fluidParamsDataType), intent(in) :: fluid
        type(pblgridDataType), intent(in) :: grid

        ! TODO: not implemented yet
    end subroutine solve_diffusion
end module pbl_solver