LCOV - code coverage report
Current view: top level - physics/clubb/src/CLUBB_core - penta_lu_solver.F90 (source / functions) Hit Total Coverage
Test: coverage.info Lines: 0 84 0.0 %
Date: 2024-12-17 17:57:11 Functions: 0 2 0.0 %

          Line data    Source code
       1             : module penta_lu_solvers
       2             : 
       3             :   ! Description:
       4             :   !   These routines solve lhs*soln=rhs using LU decomp. 
       5             :   !   
       6             :   !   LHS is stored in band diagonal form. 
       7             :   !     lhs = | lhs(0,1)  lhs(-1,1)  lhs(-2,1)     0           0          0          0
       8             :   !           | lhs(1,2)  lhs( 0,2)  lhs(-1,2)  lhs(-2,2)      0          0          0
       9             :   !           | lhs(2,3)  lhs( 1,3)  lhs( 0,3)  lhs(-1,3)  lhs(-2,3)      0          0
      10             :   !           |     0     lhs( 2,4)  lhs( 1,4)  lhs( 0,4)  lhs(-1,4)  lhs(-2,4)      0
      11             :   !           |     0         0      lhs( 2,5)  lhs( 1,5)  lhs( 0,5)  lhs(-1,5)  lhs(-2,5) ...
      12             :   !           |    ...                                                                   
      13             :   !
      14             :   !    U is stored in band diagonal form 
      15             :   !     U = |   1    upper_1(1)  upper_2(1)      0           0           0           0
      16             :   !         |   0        1       upper_1(2)  upper_2(2)      0           0           0
      17             :   !         |   0        0           1       upper_1(3)  upper_2(3)      0           0
      18             :   !         |   0        0           0           1       upper_1(4)  upper_2(4)      0
      19             :   !         |   0        0           0           0           1       upper_1(5)  upper_2(5)  ...
      20             :   !         |  ...    
      21             :   !
      22             :   !   L is also stored in band diagonal form, but the lowest most band is equivalent to the 
      23             :   !   lowermost band of LHS, thus we don't need to store it
      24             :   !     L = | l_diag(1)        0            0           0             0       0   
      25             :   !         | lower_1(2)    l_diag(2)       0           0             0       0    
      26             :   !         |   l_2(3)      lower_1(3)   l_diag(3)      0             0       0          
      27             :   !         |     0           l_2(4)     lower_1(4)  l_diag(4)        0       0     
      28             :   !         |     0             0        l_2(5)      lower_1(5)    l_diag(5)  0   ...
      29             :   !         |  ...   
      30             :   ! 
      31             :   !
      32             :   !   To perform the LU decomposition, we go element by element. 
      33             :   !   First we start by noting that we want lhs=LU, so the first step of calculating
      34             :   !   L*U, by multiplying the first row of L by the columns of U, gives us
      35             :   !
      36             :   !     l_diag(1)*1          = lhs( 0,1)  =>  l_diag(1)  = lhs( 0,1)
      37             :   !     l_diag(1)*upper_1(1) = lhs(-1,1)  =>  upper_1(1) = lhs(-1,1) / l_diag(1)
      38             :   !     l_diag(1)*upper_2(1) = lhs(-2,1)  =>  upper_2(1) = lhs(-2,1) / l_diag(1)
      39             :   !
      40             :   !   Multiplying the second row of L by U now we get
      41             :   !     
      42             :   !     lower_1(2)*1                               = lhs(1,2)   =>  lower_1(2)  = lhs(1,2)
      43             :   !     lower_1(2)*upper_1(1)+l_diag(2)*1          = lhs(0,2)   =>  l_diag(2)   = lhs(0,2) - lower_1(2)*upper_1(1)
      44             :   !     lower_1(2)*upper_2(1)+l_diag(2)*upper_1(2) = lhs(-1,2)  =>  upper_1(2)  = ( lhs(-1,2)-lower_1(2)*upper_2(1) )
      45             :   !                                                                               / l_diag(2)
      46             :   !     l_diag(2)*upper_2(2)                       = lhs(-2,2)  =>  upper_2(2)  = lhs(-2,2) / l_diag(2)
      47             :   !
      48             :   !   Now that we're passed the k=1 and k=2 steps, each following step uses all the bands,
      49             :   !   allowing us to write the general step
      50             :   !     
      51             :   !     l_2(k)*1                                    = lhs(2,k)   =>  l_2(k)     = lhs(2,k)
      52             :   !     l_2(k)*upper_1(k-2)+lower_1(k)*1            = lhs(1,k)   =>  lower_1(k) = lhs(1,k) - l_2(k)*upper_1(k-2)
      53             :   !     l_2(k)*upper_2(k-2)+lower_1(k)*upper_1(k-1) = lhs( 0,k)  =>  l_diag(k)  = lhs(0,k) - l_2(k)*upper_2(k-2)
      54             :   !                    +l_diag(k)*1                                               + lower_1(k)*upper_1(k-1) 
      55             :   !                    
      56             :   !     lower_1(k)*upper_2(k-1)
      57             :   !     + l_diag(k)*upper_1(k) = lhs(-1,k)  =>  upper_1(k) = ( lhs(-1,k) - lower_1(k)*upper_2(k-1) )
      58             :   !                                                          / l_diag(k)
      59             :   !     l_diag(k)*upper_2(k)   = lhs(-2,k)  =>  upper_2(k) = lhs(-2,k) / l_diag(k)
      60             :   !
      61             :   !
      62             :   !   This general step is done for k from 3 to ndim-2 (do k = 3, ndim-2), and the last two
      63             :   !   steps are tweaked similarly to the first two, where we disclude one then two bands
      64             :   !   since they become no longer relevant. Note from this general step that the l_2 band
      65             :   !   is always equivalent to second subdiagonal band of lhs, thus we do not need to 
      66             :   !   calculate or store l_2. Also note that we only ever need l_diag so that we can divide 
      67             :   !   by it, so instead we compute lower_diag_invrs to reduce divide operations.
      68             :   !
      69             :   !   After L and U are computed, normally we do forward substitution using L,
      70             :   !   then backward substitution using U to find the solution. This is replicated
      71             :   !   for every right hand side we want to solve for.
      72             :   !
      73             :   !
      74             :   ! References:
      75             :   !   none
      76             :   !------------------------------------------------------------------------
      77             : 
      78             : 
      79             :   use clubb_precision, only:  &
      80             :     core_rknd ! Variable(s)
      81             : 
      82             :   implicit none
      83             : 
      84             :   public :: penta_lu_solve
      85             : 
      86             :   private :: penta_lu_solve_single_rhs_multiple_lhs, penta_lu_solve_multiple_rhs_lhs
      87             : 
      88             :   interface penta_lu_solve
      89             :     module procedure penta_lu_solve_single_rhs_multiple_lhs
      90             :     module procedure penta_lu_solve_multiple_rhs_lhs
      91             :   end interface
      92             : 
      93             :   private ! Default scope
      94             : 
      95             :   contains
      96             : 
      97             :   !=============================================================================
      98           0 :   subroutine penta_lu_solve_single_rhs_multiple_lhs( ndim, ngrdcol, lhs, rhs, &
      99           0 :                                                      soln )
     100             :     ! Description:
     101             :     !   Written for single RHS and multiple LHS.
     102             :     !------------------------------------------------------------------------
     103             : 
     104             :     implicit none
     105             :  
     106             :     ! ----------------------- Input Variables -----------------------
     107             :     integer, intent(in) :: &
     108             :       ndim,   & ! Matrix size 
     109             :       ngrdcol   ! Number of grid columns
     110             : 
     111             :     real( kind = core_rknd ), intent(in), dimension(ngrdcol,ndim) ::  &
     112             :       rhs       ! 
     113             : 
     114             :     ! ----------------------- Input/Output Variables -----------------------
     115             :     real( kind = core_rknd ), intent(inout), dimension(-2:2,ngrdcol,ndim) :: &
     116             :       lhs   ! Matrices to solve, stored using band diagonal vectors
     117             :             ! -2 is the uppermost band, 2 is the lower most band, 0 is diagonal
     118             : 
     119             :     ! ----------------------- Output Variables -----------------------
     120             :     real( kind = core_rknd ), intent(out), dimension(ngrdcol,ndim) ::  &
     121             :       soln     ! Solution vector
     122             : 
     123             :     ! ----------------------- Local Variables -----------------------
     124             :     real( kind = core_rknd ), dimension(ngrdcol,ndim) ::  &
     125           0 :       upper_1,          & ! First U band 
     126           0 :       upper_2,          & ! Second U band 
     127           0 :       lower_diag_invrs, & ! Inverse of the diagonal of L
     128           0 :       lower_1,          & ! First L band
     129           0 :       lower_2             ! Second L band
     130             : 
     131             :     integer :: i, k, j    ! Loop variables
     132             : 
     133             :     ! ----------------------- Begin Code -----------------------
     134             :        
     135             :     !$acc data create( upper_1, upper_2, lower_1, lower_2, lower_diag_invrs ) &
     136             :     !$acc      copyin( rhs, lhs ) &
     137             :     !$acc      copyout( soln )
     138             : 
     139             :     !$acc parallel loop gang vector default(present)
     140           0 :     do i = 1, ngrdcol
     141           0 :       lower_diag_invrs(i,1) = 1.0_core_rknd / lhs(0,i,1)
     142           0 :       upper_1(i,1)          = lower_diag_invrs(i,1) * lhs(-1,i,1) 
     143           0 :       upper_2(i,1)          = lower_diag_invrs(i,1) * lhs(-2,i,1) 
     144             : 
     145           0 :       lower_1(i,2)          = lhs(1,i,2)
     146           0 :       lower_diag_invrs(i,2) = 1.0_core_rknd / ( lhs(0,i,2) - lower_1(i,2) * upper_1(i,1) )
     147           0 :       upper_1(i,2)          = lower_diag_invrs(i,2) * ( lhs(-1,i,2) - lower_1(i,2) * upper_2(i,1) )
     148           0 :       upper_2(i,2)          = lower_diag_invrs(i,2) * lhs(-2,i,2)
     149             :     end do
     150             :     !$acc end parallel loop
     151             : 
     152             :     !$acc parallel loop gang vector default(present)
     153           0 :     do i = 1, ngrdcol
     154           0 :       do k = 3, ndim-2
     155           0 :         lower_2(i,k) = lhs(2,i,k)
     156           0 :         lower_1(i,k) = lhs(1,i,k) - lower_2(i,k) * upper_1(i,k-2)
     157             : 
     158             :         lower_diag_invrs(i,k) = 1.0_core_rknd / ( lhs(0,i,k) - lower_2(i,k) * upper_2(i,k-2) &
     159           0 :                                                              - lower_1(i,k) * upper_1(i,k-1) )
     160             : 
     161           0 :         upper_1(i,k) = lower_diag_invrs(i,k) * ( lhs(-1,i,k) - lower_1(i,k) * upper_2(i,k-1) )
     162           0 :         upper_2(i,k) = lower_diag_invrs(i,k) * lhs(-2,i,k)
     163             :       end do
     164             :     end do
     165             :     !$acc end parallel loop
     166             : 
     167             :     !$acc parallel loop gang vector default(present)
     168           0 :     do i = 1, ngrdcol
     169           0 :       lower_2(i,ndim-1) = lhs(2,i,ndim-1)
     170           0 :       lower_1(i,ndim-1) = lhs(1,i,ndim-1) - lower_2(i,ndim-1) * upper_1(i,ndim-3)
     171             : 
     172             :       lower_diag_invrs(i,ndim-1) = 1.0_core_rknd  &
     173             :                                    / ( lhs(0,i,ndim-1) - lower_2(i,ndim-1) * upper_2(i,ndim-3) &
     174           0 :                                                        - lower_1(i,ndim-1) * upper_1(i,ndim-2) )
     175             : 
     176             :       upper_1(i,ndim-1)  = lower_diag_invrs(i,ndim-1) * ( lhs(-1,i,ndim-1) - lower_1(i,ndim-1) &
     177           0 :                                                                              * upper_2(i,ndim-2) )
     178             : 
     179           0 :       lower_2(i,ndim) = lhs(2,i,ndim)
     180           0 :       lower_1(i,ndim) = lhs(1,i,ndim) - lower_2(i,ndim) * upper_1(i,ndim-2)
     181             : 
     182             :       lower_diag_invrs(i,ndim) = 1.0_core_rknd  &
     183             :                                  / ( lhs(0,i,ndim-1) - lower_2(i,ndim) * upper_2(i,ndim-2) &
     184           0 :                                                      - lower_1(i,ndim) * upper_1(i,ndim-1) )
     185             :     end do
     186             :     !$acc end parallel loop
     187             :     
     188             :     !$acc parallel loop gang vector default(present)
     189           0 :     do i = 1, ngrdcol 
     190             : 
     191           0 :       soln(i,1)   = lower_diag_invrs(i,1) * rhs(i,1) 
     192             : 
     193           0 :       soln(i,2)   = lower_diag_invrs(i,2) * ( rhs(i,2) - lower_1(i,2) * soln(i,1) )
     194             : 
     195           0 :       do k = 3, ndim
     196           0 :         soln(i,k) = lower_diag_invrs(i,k) * ( rhs(i,k) - lower_2(i,k) * soln(i,k-2) &
     197           0 :                                                        - lower_1(i,k) * soln(i,k-1) )
     198             :       end do
     199             :     end do
     200             :     !$acc end parallel loop
     201             : 
     202             :     !$acc parallel loop gang vector default(present)
     203           0 :     do i = 1, ngrdcol 
     204           0 :       soln(i,ndim-1) = soln(i,ndim-1) - upper_1(i,ndim-1) * soln(i,ndim)
     205             : 
     206           0 :       do k = ndim-2, 1, -1
     207           0 :         soln(i,k) = soln(i,k) - upper_1(i,k) * soln(i,k+1) - upper_2(i,k) * soln(i,k+2)
     208             :       end do
     209             : 
     210             :     end do
     211             :     !$acc end parallel loop
     212             : 
     213             :     !$acc end data
     214             : 
     215           0 :   end subroutine penta_lu_solve_single_rhs_multiple_lhs
     216             : 
     217             :   
     218             :   !=============================================================================
     219           0 :   subroutine penta_lu_solve_multiple_rhs_lhs( ndim, nrhs, ngrdcol, lhs, rhs, &
     220           0 :                                               soln )
     221             :     ! Description:
     222             :     !   Written for multiple RHS and multiple LHS.
     223             :     !------------------------------------------------------------------------
     224             : 
     225             :     implicit none
     226             :  
     227             :     ! ----------------------- Input Variables -----------------------
     228             :     integer, intent(in) :: &
     229             :       ndim,   & ! Matrix size 
     230             :       nrhs,   & ! Number of right hand sides
     231             :       ngrdcol   ! Number of grid columns
     232             : 
     233             :     real( kind = core_rknd ), intent(in), dimension(ngrdcol,ndim,nrhs) ::  &
     234             :       rhs       ! 
     235             : 
     236             :     ! ----------------------- Input/Output Variables -----------------------
     237             :     real( kind = core_rknd ), intent(inout), dimension(-2:2,ngrdcol,ndim) :: &
     238             :       lhs   ! Matrices to solve, stored using band diagonal vectors
     239             :             ! -2 is the uppermost band, 2 is the lower most band, 0 is diagonal
     240             : 
     241             :     ! ----------------------- Output Variables -----------------------
     242             :     real( kind = core_rknd ), intent(out), dimension(ngrdcol,ndim,nrhs) ::  &
     243             :       soln     ! Solution vector
     244             : 
     245             :     ! ----------------------- Local Variables -----------------------
     246             :     real( kind = core_rknd ), dimension(ngrdcol,ndim) ::  &
     247           0 :       upper_1,          & ! First U band 
     248           0 :       upper_2,          & ! Second U band 
     249           0 :       lower_diag_invrs, & ! Inverse of the diagonal of L
     250           0 :       lower_1,          & ! First L band
     251           0 :       lower_2             ! Second L band
     252             : 
     253             :     integer :: i, k, j    ! Loop variables
     254             : 
     255             :     ! ----------------------- Begin Code -----------------------
     256             :        
     257             :     !$acc data create( upper_1, upper_2, lower_1, lower_2, lower_diag_invrs ) &
     258             :     !$acc      copyin( rhs, lhs ) &
     259             :     !$acc      copyout( soln )
     260             : 
     261             :     !$acc parallel loop gang vector default(present)
     262           0 :     do i = 1, ngrdcol
     263           0 :       lower_diag_invrs(i,1) = 1.0_core_rknd / lhs(0,i,1)
     264           0 :       upper_1(i,1)          = lower_diag_invrs(i,1) * lhs(-1,i,1) 
     265           0 :       upper_2(i,1)          = lower_diag_invrs(i,1) * lhs(-2,i,1) 
     266             : 
     267           0 :       lower_1(i,2)          = lhs(1,i,2)
     268           0 :       lower_diag_invrs(i,2) = 1.0_core_rknd / ( lhs(0,i,2) - lower_1(i,2) * upper_1(i,1) )
     269           0 :       upper_1(i,2)          = lower_diag_invrs(i,2) * ( lhs(-1,i,2) - lower_1(i,2) * upper_2(i,1) )
     270           0 :       upper_2(i,2)          = lower_diag_invrs(i,2) * lhs(-2,i,2)
     271             :     end do
     272             :     !$acc end parallel loop
     273             : 
     274             :     !$acc parallel loop gang vector default(present)
     275           0 :     do i = 1, ngrdcol
     276           0 :       do k = 3, ndim-2
     277           0 :         lower_2(i,k) = lhs(2,i,k)
     278           0 :         lower_1(i,k) = lhs(1,i,k) - lower_2(i,k) * upper_1(i,k-2)
     279             : 
     280             :         lower_diag_invrs(i,k) = 1.0_core_rknd / ( lhs(0,i,k) - lower_2(i,k) * upper_2(i,k-2) &
     281           0 :                                                              - lower_1(i,k) * upper_1(i,k-1) )
     282             : 
     283           0 :         upper_1(i,k) = lower_diag_invrs(i,k) * ( lhs(-1,i,k) - lower_1(i,k) * upper_2(i,k-1) )
     284           0 :         upper_2(i,k) = lower_diag_invrs(i,k) * lhs(-2,i,k)
     285             :       end do
     286             :     end do
     287             :     !$acc end parallel loop
     288             : 
     289             :     !$acc parallel loop gang vector default(present)
     290           0 :     do i = 1, ngrdcol
     291           0 :       lower_2(i,ndim-1) = lhs(2,i,ndim-1)
     292           0 :       lower_1(i,ndim-1) = lhs(1,i,ndim-1) - lower_2(i,ndim-1) * upper_1(i,ndim-3)
     293             : 
     294             :       lower_diag_invrs(i,ndim-1) = 1.0_core_rknd  &
     295             :                                    / ( lhs(0,i,ndim-1) - lower_2(i,ndim-1) * upper_2(i,ndim-3) &
     296           0 :                                                        - lower_1(i,ndim-1) * upper_1(i,ndim-2) )
     297             : 
     298             :       upper_1(i,ndim-1) = lower_diag_invrs(i,ndim-1) * ( lhs(-1,i,ndim-1) - lower_1(i,ndim-1) &
     299           0 :                                                                             * upper_2(i,ndim-2) )
     300             : 
     301           0 :       lower_2(i,ndim) = lhs(2,i,ndim)
     302           0 :       lower_1(i,ndim) = lhs(1,i,ndim) - lower_2(i,ndim) * upper_1(i,ndim-2)
     303             : 
     304             :       lower_diag_invrs(i,ndim) = 1.0_core_rknd  &
     305             :                                  / ( lhs(0,i,ndim-1) - lower_2(i,ndim) * upper_2(i,ndim-2) &
     306           0 :                                                      - lower_1(  i,ndim) * upper_1(i,ndim-1) )
     307             :     end do
     308             :     !$acc end parallel loop
     309             : 
     310             :     !$acc parallel loop gang vector collapse(2) default(present)
     311           0 :     do j = 1, nrhs
     312           0 :       do i = 1, ngrdcol 
     313             : 
     314           0 :         soln(i,1,j)   = lower_diag_invrs(i,1) * rhs(i,1,j) 
     315             : 
     316           0 :         soln(i,2,j)   = lower_diag_invrs(i,2) * ( rhs(i,2,j) - lower_1(i,2) * soln(i,1,j) )
     317             : 
     318           0 :         do k = 3, ndim
     319           0 :           soln(i,k,j) = lower_diag_invrs(i,k) * ( rhs(i,k,j) - lower_2(i,k) * soln(i,k-2,j) &
     320           0 :                                                              - lower_1(i,k) * soln(i,k-1,j) )
     321             :         end do
     322             :       end do
     323             :     end do
     324             :     !$acc end parallel loop
     325             : 
     326             :     !$acc parallel loop gang vector collapse(2) default(present)
     327           0 :     do j = 1, nrhs
     328           0 :       do i = 1, ngrdcol 
     329           0 :         soln(i,ndim-1,j) = soln(i,ndim-1,j) - upper_1(i,ndim-1) * soln(i,ndim,j)
     330             : 
     331           0 :         do k = ndim-2, 1, -1
     332           0 :           soln(i,k,j) = soln(i,k,j) - upper_1(i,k) * soln(i,k+1,j) - upper_2(i,k) * soln(i,k+2,j)
     333             :         end do
     334             : 
     335             :       end do
     336             :     end do
     337             :     !$acc end parallel loop
     338             : 
     339             :     !$acc end data
     340             : 
     341           0 :   end subroutine penta_lu_solve_multiple_rhs_lhs
     342             : 
     343             : end module penta_lu_solvers

Generated by: LCOV version 1.14