! Created by Andrey Debolskiy on 29.11.2024.

module pbl_solver
    use parkinds, only: rf=>kind_rf, im=>kind_im
    use scm_state_data
    use pbl_turb_data
    implicit none

    public factorize_tridiag
    public solve_tridiag
    public fill_tridiag
    public solve_diffusion
    contains

    !> Forward-sweep factorization of the tridiagonal system documented at
    !! solve_tridiag (rows ktvd..kl; cc(kl) is assumed to be 0).
    !! After the call, the solution satisfies y(k) = prgna(k) y(k+1) + prgnb(k),
    !! where prgnb is built from the right-hand side by solve_tridiag.
    !! @param ktvd   first (topmost) row of the active system
    !! @param kl     last row of the active system
    !! @param aa     sub-diagonal coefficients
    !! @param bb     main-diagonal coefficients (enter the system with a minus sign)
    !! @param cc     super-diagonal coefficients
    !! @param prgna  sweep coefficients multiplying y(k+1)
    !! @param prgnz  reciprocal pivots 1 / (bb(k) - aa(k) prgna(k-1)); prgnz(ktvd) = 1 / bb(ktvd)
    subroutine factorize_tridiag(ktvd, kl, aa, bb, cc, prgna, prgnz)
        implicit none
        integer, intent(in)              :: ktvd, kl
        real, intent(in),  dimension(kl) :: aa, bb, cc
        real, intent(out), dimension(kl) :: prgna, prgnz

        integer :: k

        ! rows above ktvd are not part of the system; give the intent(out)
        ! arrays defined values there instead of leaving them undefined
        prgna(1:ktvd-1) = 0.0
        prgnz(1:ktvd-1) = 0.0

        prgna(ktvd) = cc(ktvd) / bb(ktvd)
        prgnz(ktvd) = 1.0 / bb(ktvd)
        do k = ktvd+1, kl
            prgnz(k) = 1.0e0 / (bb(k) - aa(k) * prgna(k-1))
            prgna(k) = cc(k) * prgnz(k)
        end do
    end subroutine factorize_tridiag

    !> Reduce the tridiagonal system to bidiagonal form using the coefficients
    !! from factorize_tridiag, then solve it by backward substitution.
    !!              - bb(k) y(k) + cc(k) y(k+1) = -f(k), k = ktvd
    !! aa(k) y(k-1) - bb(k) y(k) + cc(k) y(k+1) = -f(k), k = ktvd+1..kl-1
    !! aa(k) y(k-1) - bb(k) y(k)                = -f(k), k = kl
    !! assuming cc(kl) = 0.0
    !! The reduced system is
    !! y(k) - prgna(k) y(k+1) = prgnb(k), k = ktvd..kl-1
    !! y(k)                   = prgnb(k), k = kl
    !! Only y(ktvd:kl) is overwritten; y(1:ktvd-1) is left untouched.
    !! cc is not referenced here; its contribution is already in prgna.
    subroutine solve_tridiag(ktvd, kl, aa, bb, cc, ff, prgna, prgnz, y)
        implicit none
        integer, intent(in)              :: ktvd, kl
        real, intent(in),  dimension(kl) :: aa, bb, cc, ff
        real, intent(in),  dimension(kl) :: prgna, prgnz
        real, intent(inout), dimension(kl) :: y

        integer :: k
        real :: prgnb(kl)

        ! forward sweep of the right-hand side
        prgnb(ktvd) = ff(ktvd) / bb(ktvd)
        do k = ktvd+1, kl
            prgnb(k) = prgnz(k) * (ff(k) + aa(k) * prgnb(k-1))
        end do

        ! backward substitution
        y(kl) = prgnb(kl)
        do k = kl-1, ktvd, -1
            y(k) = prgna(k) * y(k+1) + prgnb(k)
        end do
    end subroutine solve_tridiag

    !> Fill the implicit vertical-diffusion matrix for levels kbltop..grid%kmax
    !! in the sign convention used by factorize_tridiag / solve_tridiag:
    !!   aa(k) y(k-1) - bb(k) y(k) + cc(k) y(k+1) = -y_old(k)
    !! Level index grows downward: kbltop is the top of the diffusion column,
    !! which has a zero-flux condition. grid%kmax is the lowest level, where
    !! dt/dzc(kmax) * cm2u / rho(kmax) is added to the diagonal as an implicit
    !! drag term. Rows 1..kbltop-1 are zeroed.
    !! @param aa, bb, cc  matrix coefficients, filled for indices 1..grid%kmax
    !! @param rho     density per level; kdiff is divided by it
    !!                (presumably kdiff is density-weighted — TODO confirm)
    !! @param kdiff   exchange coefficient; kdiff(k) couples levels k and k+1
    !! @param kbltop  topmost level of the diffusion column
    !! @param cm2u    bottom drag term (callers pass 0.0 for scalars)
    !! @param grid    vertical grid; dzc(k) is used as the thickness of level k,
    !!                dze(k) as the spacing between levels k and k+1 (assumed from usage)
    !! @param dt      time step
    subroutine fill_tridiag(aa, bb, cc, rho, kdiff, kbltop, cm2u, grid, dt)
        use pbl_grid, only: pblgridDataType
        implicit none
        real, dimension(*), intent(in):: rho, kdiff
        real, intent(in):: dt, cm2u
        integer, intent(in):: kbltop
        type(pblgridDataType),intent(in):: grid

        real, dimension(*), intent(inout):: aa, bb, cc

        real:: dtdz
        integer:: k, kmax

        kmax = grid%kmax

        ! rows above the diffusion column are not part of the system
        aa(1:kbltop-1) = 0.0
        bb(1:kbltop-1) = 0.0
        cc(1:kbltop-1) = 0.0

        if (kbltop < kmax) then
            ! top boundary condition: zero flux through the top of level kbltop,
            ! so the row only couples downward and the diagonal is 1 + cc
            k = kbltop
            dtdz = dt / grid%dzc(k)
            aa(k) = 0.0
            cc(k) = (kdiff(k)/rho(k)) * dtdz / grid%dze(k)
            bb(k) = 1.0 + cc(k)

            ! interior levels
            do k = kbltop + 1, kmax - 1
                dtdz = dt / grid%dzc(k)
                aa(k) = (kdiff(k-1)/rho(k)) * dtdz / grid%dze(k-1)
                cc(k) = (kdiff(k)/rho(k)) * dtdz / grid%dze(k)
                bb(k) = 1.0 + aa(k) + cc(k)
            end do
        end if

        ! bottom boundary: no level below, implicit drag on the diagonal
        k = kmax
        dtdz = dt / grid%dzc(k)
        if (kbltop < kmax) then
            aa(k) = (kdiff(k-1)/rho(k)) * dtdz / grid%dze(k-1)
        else
            ! single-level column: no exchange with levels outside the system
            aa(k) = 0.0
        end if
        bb(k) = 1.0 + aa(k) + dtdz * cm2u / rho(k)
        cc(k) = 0.0
    end subroutine fill_tridiag

    !> Advance theta, qv, u and v in bl by one implicit vertical-diffusion step.
    !! The column spans levels ktop..grid%kmax (level index grows downward).
    !! ktop starts at bl%kpbl and is moved up to the smallest k in 1..kpbl-1
    !! with bl%vdcuv(k) > 0. The search does not stop at the first zero
    !! (NOTE(review): confirm this is intended).
    !! - Scalars use bl%vdctq. The surface fluxes bl%surf%hs / bl%surf%es enter
    !!   the lowest level explicitly as dt/dzc(kmax) * flux / rho(kmax).
    !! - Momentum uses bl%vdcuv and an implicit drag bl%surf%cm2u at the lowest level.
    !! bl_old, turb and fluid are currently unused; they are kept for the interface.
    subroutine solve_diffusion(bl, bl_old, turb, fluid, grid, dt)
        use scm_state_data, only : stateBLDataType
        use pbl_turb_data, only : turbBLDataType
        use phys_fluid, only: fluidParamsDataType
        use pbl_grid, only : pblgridDataType
        implicit none
        type(stateBLDataType), intent(inout):: bl
        type(stateBLDataType), intent(in):: bl_old
        type(turbBLDataType), intent(in):: turb
        type(fluidParamsDataType), intent(in) :: fluid
        type(pblgridDataType), intent(in) :: grid
        real, intent(in):: dt

        ! local (non-save) allocatables: unallocated on entry and
        ! automatically deallocated on return
        real, allocatable:: aa(:), bb(:), cc(:), ff(:)
        real, allocatable:: prgna(:), prgnz(:)
        integer :: k, ktop, kmax
        real :: dtdz_sfc

        kmax = grid%kmax
        allocate(aa(kmax), source=0.0)
        allocate(bb(kmax), source=0.0)
        allocate(cc(kmax), source=0.0)
        allocate(ff(kmax), source=0.0)
        allocate(prgna(kmax), source=0.0)
        allocate(prgnz(kmax), source=0.0)

        ! top of the diffusion column
        ktop = bl%kpbl
        do k = bl%kpbl-1, 1, -1
            if (bl%vdcuv(k) > 0.e0) then
                ktop = k
            end if
        end do

        ! surface fluxes are distributed over the lowest level's thickness,
        ! consistent with the drag term built by fill_tridiag
        dtdz_sfc = dt / grid%dzc(kmax)

        ! temperature and specific humidity share one matrix
        call fill_tridiag(aa, bb, cc, bl%rho, bl%vdctq, ktop, 0.0, grid, dt)
        call factorize_tridiag(ktop, kmax, aa, bb, cc, prgna, prgnz)

        ff(ktop:kmax-1) = bl%theta(ktop:kmax-1)
        ff(kmax) = bl%theta(kmax) + dtdz_sfc * bl%surf%hs / bl%rho(kmax)
        call solve_tridiag(ktop, kmax, aa, bb, cc, ff, prgna, prgnz, bl%theta)

        ff(ktop:kmax-1) = bl%qv(ktop:kmax-1)
        ff(kmax) = bl%qv(kmax) + dtdz_sfc * bl%surf%es / bl%rho(kmax)
        call solve_tridiag(ktop, kmax, aa, bb, cc, ff, prgna, prgnz, bl%qv)

        ! velocity components share one matrix (with implicit surface drag)
        call fill_tridiag(aa, bb, cc, bl%rho, bl%vdcuv, ktop, bl%surf%cm2u, grid, dt)
        call factorize_tridiag(ktop, kmax, aa, bb, cc, prgna, prgnz)

        ff(ktop:kmax) = bl%u(ktop:kmax)
        call solve_tridiag(ktop, kmax, aa, bb, cc, ff, prgna, prgnz, bl%u)

        ff(ktop:kmax) = bl%v(ktop:kmax)
        call solve_tridiag(ktop, kmax, aa, bb, cc, ff, prgna, prgnz, bl%v)
    end subroutine solve_diffusion
end module pbl_solver