! Created by Andrey Debolskiy on 29.11.2024.

module pbl_solver
    use parkinds, only: rf=>kind_rf, im=>kind_im
    use scm_state_data
    use pbl_turb_data
    implicit none

    public factorize_tridiag
    public solve_tridiag
    public fill_tridiag
    public solve_diffusion
    public apply_subsidence
    contains
    !> Forward factorization of the tridiagonal system solved by solve_tridiag.
    !!
    !! For rows ktvd..kl of
    !!   aa(k) y(k-1) - bb(k) y(k) + cc(k) y(k+1) = -f(k)
    !! compute the sweep coefficients of the reduced form
    !!   y(k) = prgna(k) y(k+1) + prgnb(k)
    !! as
    !!   prgnz(ktvd) = 1 / bb(ktvd),  prgna(ktvd) = cc(ktvd) / bb(ktvd)
    !!   prgnz(k)    = 1 / (bb(k) - aa(k) prgna(k-1)),  prgna(k) = cc(k) prgnz(k)
    !! aa(ktvd) is not referenced: the first row has no lower neighbour.
    !! No pivoting is done; bb must dominate the row (true for diffusion matrices).
    !! Only entries ktvd..kl of prgna and prgnz are defined on return.
    subroutine factorize_tridiag(ktvd, kl, aa, bb, cc, prgna, prgnz)
        implicit none
        integer, intent(in)              :: ktvd, kl
        real, intent(in),  dimension(kl) :: aa, bb, cc
        real, intent(out), dimension(kl) :: prgna, prgnz

        integer :: k

        ! first row: no sub-diagonal contribution; define prgnz(ktvd) too so the
        ! intent(out) factor array is fully defined over ktvd..kl
        prgnz(ktvd) = 1.0 / bb(ktvd)
        prgna(ktvd) = cc(ktvd) / bb(ktvd)
        do k = ktvd+1, kl
            prgnz(k) = 1.0 / (bb(k) - aa(k) * prgna(k-1))
            prgna(k) = cc(k) * prgnz(k)
        end do
    end subroutine factorize_tridiag

    !> Solve a tridiagonal system previously factorized by factorize_tridiag.
    !!
    !! Rows ktvd..kl of
    !!                 - bb(k) y(k) + cc(k) y(k+1) = -ff(k),  k = ktvd
    !!   aa(k) y(k-1) - bb(k) y(k) + cc(k) y(k+1) = -ff(k),  k = ktvd+1..kl-1
    !!   aa(k) y(k-1) - bb(k) y(k)                = -ff(k),  k = kl
    !! (cc(kl) is assumed to be zero) are reduced by a forward sweep to
    !!   y(k) - prgna(k) y(k+1) = prgnb(k),  k = ktvd..kl-1
    !!   y(kl)                  = prgnb(kl)
    !! and y is recovered by back substitution.
    !!
    !! cc is not referenced here (its effect is carried by prgna), nor are
    !! aa(ktvd) and prgnz(ktvd). Entries of y outside ktvd..kl are left untouched.
    subroutine solve_tridiag(ktvd, kl, aa, bb, cc, ff, prgna, prgnz, y)
        implicit none
        integer, intent(in)              :: ktvd, kl
        real, intent(in),  dimension(kl) :: aa, bb, cc, ff
        real, intent(in),  dimension(kl) :: prgna, prgnz
        real, intent(inout), dimension(kl) :: y

        integer :: lev

        ! forward sweep: y(lev) temporarily stores the reduced right-hand side prgnb(lev)
        y(ktvd) = ff(ktvd) / bb(ktvd)
        do lev = ktvd + 1, kl
            y(lev) = prgnz(lev) * (ff(lev) + aa(lev) * y(lev-1))
        end do

        ! back substitution: y(kl) already equals prgnb(kl); walking upward in k
        ! order, y(lev+1) is final while y(lev) still holds prgnb(lev)
        do lev = kl - 1, ktvd, -1
            y(lev) = prgna(lev) * y(lev+1) + y(lev)
        end do
    end subroutine solve_tridiag

    !> Fill the coefficients of an implicit vertical-diffusion step for solve_tridiag.
    !!
    !! Rows kbltop..grid%kmax take the form
    !!   aa(k) y(k-1) - bb(k) y(k) + cc(k) y(k+1) = -ff(k)
    !! with, for dtdz = dt / (z_edge(k) - z_edge(k-1)),
    !!   aa(k) = kdiff(k-1)/rho(k) * dtdz / (z_cell(k)   - z_cell(k-1))
    !!   cc(k) = kdiff(k)  /rho(k) * dtdz / (z_cell(k+1) - z_cell(k))
    !!   bb(k) = 1 + aa(k) + cc(k)
    !! Boundary rows:
    !!  - k = kbltop: zero flux through the upper face (aa = 0, bb = 1 + cc);
    !!  - k = kmax:   cc = 0 and bb = 1 + aa - dtdz*cm2u/rho(kmax), i.e. cm2u
    !!    enters the diagonal implicitly (pass 0.0 when no such term is wanted).
    !! If kbltop >= kmax the column is a single row with no neighbour coupling.
    !! All coefficients outside kbltop..kmax are zero.
    !! kdiff(k) is used on the interface between levels k and k+1.
    subroutine fill_tridiag(aa, bb, cc, rho, kdiff, kbltop, cm2u, grid, dt)
        use pbl_grid, only: pblgridDataType
        implicit none
        real, dimension(*), intent(in):: rho, kdiff
        real, intent(in):: dt, cm2u
        integer, intent(in):: kbltop
        type(pblgridDataType),intent(in):: grid

        real, dimension(*), intent(inout):: aa, bb, cc

        real:: dtdz
        integer:: k, kmax

        kmax = grid%kmax

        ! clear every row; rows above kbltop stay zero
        aa(1:kmax) = 0.0
        bb(1:kmax) = 0.0
        cc(1:kmax) = 0.0

        ! top row (zero flux above). Skipped for a single-level column: it would
        ! read z_cell(kmax+1) and be overwritten by the bottom row anyway.
        if (kbltop < kmax) then
            k = kbltop
            dtdz = dt / (grid%z_edge(k) - grid%z_edge(k-1) )
            aa(k) = 0.0
            cc(k) = (kdiff(k)/rho(k)) * dtdz / (grid%z_cell(k+1) - grid%z_cell(k))
            bb(k) = 1.0 + cc(k)
        end if

        ! interior rows
        do k = kbltop + 1, kmax - 1
            dtdz = dt / (grid%z_edge(k) - grid%z_edge(k-1) )
            aa(k) = (kdiff(k - 1)/rho(k)) * dtdz / (grid%z_cell(k) - grid%z_cell(k-1))
            cc(k) = (kdiff(k)/rho(k)) * dtdz /(grid%z_cell(k+1) - grid%z_cell(k))
            bb(k) = 1.0 + aa(k) + cc(k)
        end do

        ! bottom row
        k = kmax
        dtdz = dt / (grid%z_edge(k) - grid%z_edge(k-1) )
        if (kbltop < kmax) then
            aa(k) = (kdiff(k-1)/rho(k)) * dtdz / (grid%z_cell(k) - grid%z_cell(k-1))
        else
            ! single-level column: its upper face is the zero-flux top boundary,
            ! so no coupling to level kmax-1 (the solver would ignore aa anyway,
            ! leaving a spurious diagonal damping term)
            aa(k) = 0.0
        end if
        bb(k) = 1.0 + aa(k) - dtdz * cm2u / rho(k)
        cc(k) = 0.0
    end subroutine fill_tridiag

    !> Advance theta, qv, u and v by one implicit vertical-diffusion step of length dt.
    !!
    !! The column is solved on levels ktop = bl%kpbl .. grid%kmax, with a zero-flux
    !! condition at ktop (see fill_tridiag).
    !!  - theta and qv share one matrix built from bl%vdctq; the surface fluxes
    !!    bl%surf%hs and bl%surf%es are added to the right-hand side of level kmax.
    !!  - u and v share one matrix built from bl%vdcuv; bl%surf%cm2u enters the
    !!    diagonal of level kmax implicitly.
    !! bl_old, turb and fluid are currently not referenced.
    !!
    !! @param[inout] bl    column state; theta, qv, u, v are updated in place
    !! @param[in]    grid  vertical grid (kmax, dze and the fields used by fill_tridiag)
    !! @param[in]    dt    time step
    subroutine solve_diffusion(bl, bl_old, turb, fluid, grid, dt)
        use scm_state_data, only : stateBLDataType
        use pbl_turb_data, only : turbBLDataType
        use phys_fluid, only: fluidParamsDataType
        use pbl_grid, only : pblgridDataType
        implicit none
        type(stateBLDataType), intent(inout):: bl
        type(stateBLDataType), intent(in):: bl_old
        type(turbBLDataType), intent(in):: turb
        type(fluidParamsDataType), intent(in) :: fluid
        type(pblgridDataType), intent(in) :: grid
        real, intent(in):: dt

        ! Work arrays are plain locals (not SAVEd): they are sized from the current
        ! grid%kmax on every call and released automatically on return.
        real, allocatable :: aa(:), bb(:), cc(:), ff(:)
        real, allocatable :: prgna(:), prgnz(:)
        integer :: ktop, kmax

        kmax = grid%kmax
        allocate(aa(kmax), source=0.0)
        allocate(bb(kmax), source=0.0)
        allocate(cc(kmax), source=0.0)
        allocate(ff(kmax), source=0.0)
        allocate(prgna(kmax), source=0.0)
        allocate(prgnz(kmax), source=0.0)

        ktop = bl%kpbl
        !this loop is generally not needed
        !do k = bl%kpbl-1,1,-1
        !    if(bl%vdcuv(k) > 0.e0) then
        !        ktop = k
        !    end if
        !enddo

        ! ---- scalars: one factorization reused for theta and qv
        call fill_tridiag(aa, bb, cc, bl%rho, bl%vdctq, ktop, 0.0,  grid, dt)
        call factorize_tridiag(ktop, kmax, aa, bb, cc, prgna, prgnz)

        ff(ktop:kmax-1) = bl%theta(ktop:kmax-1)
        ! surface flux hs is applied explicitly to the lowest level
        ff(kmax) = bl%theta(kmax)  &
                    + (dt/grid%dze(kmax)) * bl%surf%hs / (bl%rho(kmax))
        call solve_tridiag(ktop, kmax, aa, bb, cc, ff, prgna, prgnz, bl%theta)

        ff(ktop:kmax-1) = bl%qv(ktop:kmax-1)
        ! surface flux es is applied explicitly to the lowest level
        ff(kmax) = bl%qv(kmax) + (dt/grid%dze(kmax)) * bl%surf%es /bl%rho(kmax)
        call solve_tridiag(ktop, kmax, aa, bb, cc, ff, prgna, prgnz, bl%qv)

        ! ---- velocity: surface term cm2u is carried in the matrix, not in ff
        call fill_tridiag(aa, bb, cc, bl%rho, bl%vdcuv, ktop, bl%surf%cm2u, grid, dt)
        call factorize_tridiag(ktop, kmax, aa, bb, cc, prgna, prgnz)

        ff(ktop:kmax) = bl%u(ktop:kmax)
        call solve_tridiag(ktop, kmax, aa, bb, cc, ff, prgna, prgnz, bl%u)

        ff(ktop:kmax) = bl%v(ktop:kmax)
        call solve_tridiag(ktop, kmax, aa, bb, cc, ff, prgna, prgnz, bl%v)

    end subroutine solve_diffusion

    !> Apply large-scale subsidence (vertical advection by w_s) to u, v, theta and qv.
    !!
    !! The tendency w_s * d(phi)/dz is evaluated for every field from the values on
    !! entry, then subtracted with one forward-Euler step of length dt.
    !! w_s(k) is used as the vertical velocity on the interface between levels
    !! k and k+1:
    !!  - interior levels average the two one-sided gradients, weighted by the
    !!    interface densities and divided by bl%rho(k);
    !!  - level 1 uses the one-sided gradient across interface 1 (w_s(1));
    !!  - level kmax uses the one-sided gradient across interface kmax-1.
    !! Columns with fewer than two levels are left unchanged.
    !! fluid is currently not referenced.
    !!
    !! @param[inout] bl    column state; u, v, theta, qv updated, rho read
    !! @param[in]    w_s   subsidence velocity on level interfaces
    !! @param[in]    grid  vertical grid (kmax, z_cell)
    !! @param[in]    dt    time step
    subroutine apply_subsidence(bl, w_s, fluid, grid, dt)
        use scm_state_data, only : stateBLDataType
        use pbl_turb_data, only : turbBLDataType
        use phys_fluid, only: fluidParamsDataType
        use pbl_grid, only : pblgridDataType
        implicit none
        type(stateBLDataType), intent(inout):: bl
        type(fluidParamsDataType), intent(in) :: fluid
        type(pblgridDataType), intent(in) :: grid
        real, intent(in):: dt
        real, dimension(*), intent(in):: w_s

        ! plain locals (not SAVEd) so they always match the current grid%kmax
        real, allocatable :: subs_u(:), subs_v(:), subs_theta(:), subs_qv(:)
        integer :: kmax

        kmax = grid%kmax
        ! a vertical gradient needs at least two levels
        if (kmax < 2) return

        allocate(subs_u(kmax), subs_v(kmax), subs_theta(kmax), subs_qv(kmax))

        ! evaluate all tendencies before any field is modified
        call subsidence_tendency(bl%u(1:kmax), subs_u)
        call subsidence_tendency(bl%v(1:kmax), subs_v)
        call subsidence_tendency(bl%theta(1:kmax), subs_theta)
        call subsidence_tendency(bl%qv(1:kmax), subs_qv)

        !apply
        bl%u(1:kmax)     = bl%u(1:kmax)     - subs_u * dt
        bl%v(1:kmax)     = bl%v(1:kmax)     - subs_v * dt
        bl%theta(1:kmax) = bl%theta(1:kmax) - subs_theta * dt
        bl%qv(1:kmax)    = bl%qv(1:kmax)    - subs_qv * dt

    contains

        !> Subsidence tendency w_s * d(phi)/dz on levels 1..kmax
        !! (phi(k) corresponds to level k of the host field).
        subroutine subsidence_tendency(phi, tend)
            real, intent(in)  :: phi(:)
            real, intent(out) :: tend(:)

            real :: rhop, rhom
            integer :: k

            ! central differences on interior levels
            do k = kmax-1, 2, -1
                rhop = 0.5*(bl%rho(k) + bl%rho(k+1))
                rhom = 0.5*(bl%rho(k) + bl%rho(k-1))
                tend(k) = (0.5/bl%rho(k))  &
                        * (rhop * w_s(k) * (phi(k)-phi(k+1)) &
                                /(grid%z_cell(k)-grid%z_cell(k+1)) &
                                + &
                            rhom * w_s(k-1) * (phi(k)-phi(k-1)) &
                                /(grid%z_cell(k)-grid%z_cell(k-1)))
            end do

            ! bottom boundary: gradient across interface kmax-1
            tend(kmax) = w_s(kmax-1) * (phi(kmax)-phi(kmax-1)) &
                        /(grid%z_cell(kmax)-grid%z_cell(kmax-1))

            ! top boundary: gradient across interface 1, so use w_s(1)
            ! (the original w_s(2) belonged to the interface between levels 2 and 3)
            tend(1) = w_s(1) * (phi(1)-phi(2)) &
                        /(grid%z_cell(1)-grid%z_cell(2))
        end subroutine subsidence_tendency

    end subroutine apply_subsidence
end module pbl_solver