LCOV - code coverage report
Current view: top level - physics/clubb/src/CLUBB_core - tridiag_lu_solver.F90 (source / functions) Hit Total Coverage
Test: coverage.info Lines: 0 60 0.0 %
Date: 2024-12-17 17:57:11 Functions: 0 3 0.0 %

          Line data    Source code
       1             : module tridiag_lu_solvers
       2             : 
       3             :   ! Description:
       4             :   !   These routines solve lhs*soln=rhs using LU decomp. 
       5             :   !   
       6             :   !   LHS is stored in band diagonal form. 
       7             :   !     lhs = | lhs(0,1)  lhs(-1,1)      0          0          0          0      0
       8             :   !           | lhs(1,2)  lhs( 0,2)  lhs(-1,2)      0          0          0      0
       9             :   !           |     0     lhs( 1,3)  lhs( 0,3)  lhs(-1,3)      0          0      0
      10             :   !           |     0         0      lhs( 1,4)  lhs( 0,4)  lhs(-1,4)      0      0
      11             :   !           |     0         0          0      lhs( 1,5)  lhs( 0,5)  lhs(-1,5)  0 ...
      12             :   !           |    ...                                                                   
      13             :   !
      14             :   !    U is stored in band diagonal form 
      15             :   !     U = |   1    upper(1)  0       0       0       0       0
      16             :   !         |   0      1     upper(2)  0       0       0       0
      17             :   !         |   0      0       1     upper(3)  0       0       0
      18             :   !         |   0      0       0       1     upper(4)  0       0
      19             :   !         |   0      0       0       0       1     upper(5)  0  ...
      20             :   !         |  ...   
      21             :   !
      22             :   !   L is also stored in band diagonal form, but its lowermost band is identical to the 
      23             :   !   lowermost band of LHS, thus we don't need to store it
      24             :   !     L = | l_diag(1)       0           0            0          0        0   
      25             :   !         |  lower(2)      l_diag(2)    0            0          0        0    
      26             :   !         |    0           lower(3)   l_diag(3)      0          0        0          
      27             :   !         |    0            0         lower(4)    l_diag(4)     0        0     
      28             :   !         |    0            0           0         lower(5)    l_diag(5)   0   ...
      29             :   !         |  ...   
      30             :   ! 
      31             :   !
      32             :   !   To perform the LU decomposition, we go element by element. 
      33             :   !   First we start by noting that we want lhs=LU, so the first step of calculating
      34             :   !   L*U, by multiplying the first row of L by the columns of U, gives us
      35             :   !
      36             :   !     l_diag(1)*1        = lhs( 0,1)  =>  l_diag(1)   = lhs( 0,1)
      37             :   !     l_diag(1)*upper(1) = lhs(-1,1)  =>  upper(1)    = lhs(-1,1) / l_diag(1)
      38             :   !
      39             :   !   Now that we're past the k=1 step, each following step uses all the bands,
      40             :   !   allowing us to write the general step
      41             :   !     
      42             :   !     lower(k)*1 = lhs(1,k)                         =>  lower(k)    = lhs( 1,k)
      43             :   !     lower(k)*upper(k-1)+l_diag(k)*1  = lhs( 0,k)  =>  l_diag(k) = lhs( 0,k) - lower(k)*upper(k-1)
      44             :   !     l_diag(k)*upper(k)               = lhs(-1,k)  =>  upper(k)  = lhs(-1,k) / l_diag(k)
      45             :   !
      46             :   !   This general step is done for k from 2 to ndim-1 (do k = 2, ndim-1), and the last 
      47             :   !   step is tweaked similarly to the first one, where we exclude the upper band
      48             :   !   since it is no longer relevant. Note from this general step that the lower band
      49             :   !   is always equivalent to the first subdiagonal band of lhs, thus we do not need to 
      50             :   !   calculate or store lower. Also note that we only ever need l_diag so that we can divide 
      51             :   !   by it, so instead we compute lower_diag_invrs to reduce divide operations.
      52             :   !
      53             :   !   After L and U are computed, normally we do forward substitution using L,
      54             :   !   then backward substitution using U to find the solution. This is replicated
      55             :   !   for every right hand side we want to solve for.
      56             :   !
      57             :   !
      58             :   ! References:
      59             :   !   none
      60             :   !------------------------------------------------------------------------
      61             : 
      62             : 
      63             :   use clubb_precision, only:  &
      64             :     core_rknd ! Variable(s)
      65             : 
      66             :   implicit none
      67             : 
      68             :   public :: tridiag_lu_solve
      69             : 
      70             :   private :: tridiag_lu_solve_single_rhs_multiple_lhs, tridiag_lu_solve_multiple_rhs_lhs
      71             : 
      72             :   interface tridiag_lu_solve
      73             :     module procedure tridiag_lu_solve_single_rhs_lhs
      74             :     module procedure tridiag_lu_solve_single_rhs_multiple_lhs
      75             :     module procedure tridiag_lu_solve_multiple_rhs_lhs
      76             :   end interface
      77             : 
      78             :   private ! Default scope
      79             : 
      80             :   contains
      81             : 
      82             :   !=============================================================================
      83           0 :   subroutine tridiag_lu_solve_single_rhs_lhs( ndim, lhs, rhs, &
      84           0 :                                               soln )
      85             :     ! Description:
      86             :     !   Solves one tridiagonal system lhs*soln = rhs (single RHS, single LHS) by LU decomposition, then forward and backward substitution.
      87             :     !------------------------------------------------------------------------
      88             : 
      89             :     implicit none
      90             :  
      91             :     ! ----------------------- Input Variables -----------------------
      92             :     integer, intent(in) :: &
      93             :       ndim  ! Matrix size (number of rows/unknowns)
      94             : 
      95             :     real( kind = core_rknd ), intent(in), dimension(ndim) ::  &
      96             :       rhs       ! Right-hand side vector
      97             : 
      98             :     ! ----------------------- Input/Output Variables -----------------------
      99             :     real( kind = core_rknd ), intent(inout), dimension(-1:1,ndim) :: &
     100             :       lhs   ! Matrix to solve, stored using band diagonal vectors
     101             :             ! -1 is the upper band, 1 is the lower band, 0 is the diagonal; only read in this routine
     102             : 
     103             :     ! ----------------------- Output Variables -----------------------
     104             :     real( kind = core_rknd ), intent(out), dimension(ndim) ::  &
     105             :       soln     ! Solution vector
     106             : 
     107             :     ! ----------------------- Local Variables -----------------------
     108             :     real( kind = core_rknd ), dimension(ndim) ::  &
     109           0 :       upper,         & ! Superdiagonal band of U; element ndim is never set or used
     110           0 :       lower_diag_invrs ! Inverse of the diagonal of L
     111             : 
     112             :     integer :: k    ! Loop index over matrix rows
     113             : 
     114             :     ! ----------------------- Begin Code -----------------------
     115             :        
     116             :     !$acc data create( upper, lower_diag_invrs ) &
     117             :     !$acc      copyin( rhs, lhs ) &
     118             :     !$acc      copyout( soln )
     119             :     
     120             :     !$acc kernels
     121             :     
     122           0 :     lower_diag_invrs(1) = 1.0_core_rknd / lhs(0,1) ! First row of the factorization; no zero-pivot check is made
     123           0 :     upper(1)            = lower_diag_invrs(1) * lhs(-1,1) 
     124             : 
     125           0 :     do k = 2, ndim-1 ! General LU step; row k depends on upper(k-1)
     126           0 :       lower_diag_invrs(k) = 1.0_core_rknd / ( lhs(0,k) - lhs(1,k) * upper(k-1)  )
     127           0 :       upper(k)            = lower_diag_invrs(k) * lhs(-1,k) 
     128             :     end do
     129             : 
     130           0 :     lower_diag_invrs(ndim) = 1.0_core_rknd / ( lhs(0,ndim) - lhs(1,ndim) * upper(ndim-1)  ) ! Last row, no upper entry. NOTE(review): reads upper(0) if ndim = 1
     131             : 
     132           0 :     soln(1)   = lower_diag_invrs(1) * rhs(1) ! Forward substitution with L starts here
     133             : 
     134           0 :     do k = 2, ndim
     135           0 :       soln(k) = lower_diag_invrs(k) * ( rhs(k) - lhs(1,k) * soln(k-1) )
     136             :     end do
     137             : 
     138           0 :     do k = ndim-1, 1, -1 ! Backward substitution with U; soln(ndim) is already final
     139           0 :       soln(k) = soln(k) - upper(k) * soln(k+1)
     140             :     end do
     141             : 
     142             :     !$acc end kernels
     143             : 
     144             :     !$acc end data
     145             : 
     146           0 :   end subroutine tridiag_lu_solve_single_rhs_lhs
     147             : 
     148             :   !=============================================================================
     149           0 :   subroutine tridiag_lu_solve_single_rhs_multiple_lhs( ndim, ngrdcol, lhs, rhs, &
     150           0 :                                                        soln )
     151             :     ! Description:
     152             :     !   Solves ngrdcol independent tridiagonal systems, each with its own LHS and a single RHS, by LU decomposition and substitution.
     153             :     !------------------------------------------------------------------------
     154             : 
     155             :     implicit none
     156             :  
     157             :     ! ----------------------- Input Variables -----------------------
     158             :     integer, intent(in) :: &
     159             :       ndim,   & ! Matrix size (number of rows/unknowns per system)
     160             :       ngrdcol   ! Number of grid columns
     161             : 
     162             :     real( kind = core_rknd ), intent(in), dimension(ngrdcol,ndim) ::  &
     163             :       rhs       ! Right-hand side vectors, indexed (column, row)
     164             : 
     165             :     ! ----------------------- Input/Output Variables -----------------------
     166             :     real( kind = core_rknd ), intent(inout), dimension(-1:1,ngrdcol,ndim) :: &
     167             :       lhs   ! Matrices to solve, stored using band diagonal vectors, indexed (band, column, row)
     168             :             ! -1 is the upper band, 1 is the lower band, 0 is the diagonal; only read in this routine
     169             : 
     170             :     ! ----------------------- Output Variables -----------------------
     171             :     real( kind = core_rknd ), intent(out), dimension(ngrdcol,ndim) ::  &
     172             :       soln     ! Solution vectors, indexed (column, row)
     173             : 
     174             :     ! ----------------------- Local Variables -----------------------
     175             :     real( kind = core_rknd ), dimension(ngrdcol,ndim) ::  &
     176           0 :       upper,         & ! Superdiagonal band of U per column; row ndim is never set or used
     177           0 :       lower_diag_invrs ! Inverse of the diagonal of L, per column
     178             : 
     179             :     integer :: i, k    ! Loop indices: i over grid columns, k over matrix rows
     180             : 
     181             :     ! ----------------------- Begin Code -----------------------
     182             :        
     183             :     !$acc data create( upper, lower_diag_invrs ) &
     184             :     !$acc      copyin( rhs, lhs ) &
     185             :     !$acc      copyout( soln )
     186             :     
     187             :     !$acc parallel loop gang vector default(present)
     188           0 :     do i = 1, ngrdcol ! First row of the factorization for every column
     189           0 :       lower_diag_invrs(i,1) = 1.0_core_rknd / lhs(0,i,1) ! No zero-pivot check is made
     190           0 :       upper(i,1)            = lower_diag_invrs(i,1) * lhs(-1,i,1) 
     191             :     end do
     192             :     !$acc end parallel loop
     193             : 
     194             :     !$acc parallel loop gang vector default(present)
     195           0 :     do i = 1, ngrdcol
     196           0 :       do k = 2, ndim-1 ! General LU step; row k depends on upper(i,k-1), so k runs in order within a column
     197           0 :         lower_diag_invrs(i,k) = 1.0_core_rknd / ( lhs(0,i,k) - lhs(1,i,k) * upper(i,k-1)  )
     198           0 :         upper(i,k)            = lower_diag_invrs(i,k) * lhs(-1,i,k) 
     199             :       end do
     200             :     end do
     201             :     !$acc end parallel loop
     202             : 
     203             :     !$acc parallel loop gang vector default(present)
     204           0 :     do i = 1, ngrdcol ! Last row, no upper entry. NOTE(review): reads upper(i,0) if ndim = 1
     205           0 :       lower_diag_invrs(i,ndim) = 1.0_core_rknd / ( lhs(0,i,ndim) - lhs(1,i,ndim) * upper(i,ndim-1)  )
     206             :     end do
     207             :     !$acc end parallel loop
     208             : 
     209             :     !$acc parallel loop gang vector default(present)
     210           0 :     do i = 1, ngrdcol 
     211             : 
     212           0 :       soln(i,1)   = lower_diag_invrs(i,1) * rhs(i,1) ! Forward substitution with L starts here
     213             : 
     214           0 :       do k = 2, ndim
     215           0 :         soln(i,k) = lower_diag_invrs(i,k) * ( rhs(i,k) - lhs(1,i,k) * soln(i,k-1) )
     216             :       end do
     217             :     end do
     218             :     !$acc end parallel loop
     219             : 
     220             :     !$acc parallel loop gang vector default(present)
     221           0 :     do i = 1, ngrdcol 
     222           0 :       do k = ndim-1, 1, -1 ! Backward substitution with U; soln(i,ndim) is already final
     223           0 :         soln(i,k) = soln(i,k) - upper(i,k) * soln(i,k+1)
     224             :       end do
     225             : 
     226             :     end do
     227             :     !$acc end parallel loop
     228             : 
     229             :     !$acc end data
     230             : 
     231           0 :   end subroutine tridiag_lu_solve_single_rhs_multiple_lhs
     232             : 
     233             :   
     234             :   !=============================================================================
     235           0 :   subroutine tridiag_lu_solve_multiple_rhs_lhs( ndim, nrhs, ngrdcol, lhs, rhs, &
     236           0 :                                                 soln )
     237             :     ! Description:
     238             :     !   Solves ngrdcol tridiagonal systems, each with nrhs right-hand sides; each column's LU factors are computed once and reused for every RHS.
     239             :     !------------------------------------------------------------------------
     240             : 
     241             :     implicit none
     242             :  
     243             :     ! ----------------------- Input Variables -----------------------
     244             :     integer, intent(in) :: &
     245             :       ndim,   & ! Matrix size (number of rows/unknowns per system)
     246             :       nrhs,   & ! Number of right hand sides
     247             :       ngrdcol   ! Number of grid columns
     248             : 
     249             :     real( kind = core_rknd ), intent(in), dimension(ngrdcol,ndim,nrhs) ::  &
     250             :       rhs       ! Right-hand side vectors, indexed (column, row, rhs)
     251             : 
     252             :     ! ----------------------- Input/Output Variables -----------------------
     253             :     real( kind = core_rknd ), intent(inout), dimension(-1:1,ngrdcol,ndim) :: &
     254             :       lhs   ! Matrices to solve, stored using band diagonal vectors, indexed (band, column, row)
     255             :             ! -1 is the upper band, 1 is the lower band, 0 is the diagonal; only read in this routine
     256             : 
     257             :     ! ----------------------- Output Variables -----------------------
     258             :     real( kind = core_rknd ), intent(out), dimension(ngrdcol,ndim,nrhs) ::  &
     259             :       soln     ! Solution vectors, indexed (column, row, rhs)
     260             : 
     261             :     ! ----------------------- Local Variables -----------------------
     262             :     real( kind = core_rknd ), dimension(ngrdcol,ndim) ::  &
     263           0 :       upper,            & ! Superdiagonal band of U per column; row ndim is never set or used
     264           0 :       lower_diag_invrs    ! Inverse of the diagonal of L, per column
     265             : 
     266             :     integer :: i, k, j    ! Loop indices: i over grid columns, k over matrix rows, j over right hand sides
     267             : 
     268             :     ! ----------------------- Begin Code -----------------------
     269             :        
     270             :     !$acc data create( upper, lower_diag_invrs ) &
     271             :     !$acc      copyin( rhs, lhs ) &
     272             :     !$acc      copyout( soln )
     273             : 
     274             :     !$acc parallel loop gang vector default(present)
     275           0 :     do i = 1, ngrdcol ! First row of the factorization for every column
     276           0 :       lower_diag_invrs(i,1) = 1.0_core_rknd / lhs(0,i,1) ! No zero-pivot check is made
     277           0 :       upper(i,1)            = lower_diag_invrs(i,1) * lhs(-1,i,1) 
     278             :     end do
     279             :     !$acc end parallel loop
     280             : 
     281             :     !$acc parallel loop gang vector default(present)
     282           0 :     do i = 1, ngrdcol
     283           0 :       do k = 2, ndim-1 ! General LU step; row k depends on upper(i,k-1), so k runs in order within a column
     284           0 :         lower_diag_invrs(i,k) = 1.0_core_rknd / ( lhs(0,i,k) - lhs(1,i,k) * upper(i,k-1)  )
     285           0 :         upper(i,k)            = lower_diag_invrs(i,k) * lhs(-1,i,k) 
     286             :       end do
     287             :     end do
     288             :     !$acc end parallel loop
     289             : 
     290             :     !$acc parallel loop gang vector default(present)
     291           0 :     do i = 1, ngrdcol ! Last row, no upper entry. NOTE(review): reads upper(i,0) if ndim = 1
     292           0 :       lower_diag_invrs(i,ndim) = 1.0_core_rknd / ( lhs(0,i,ndim) - lhs(1,i,ndim) * upper(i,ndim-1)  )
     293             :     end do
     294             :     !$acc end parallel loop
     295             : 
     296             :     !$acc parallel loop gang vector collapse(2) default(present)
     297           0 :     do j = 1, nrhs
     298           0 :       do i = 1, ngrdcol 
     299             : 
     300           0 :         soln(i,1,j)   = lower_diag_invrs(i,1) * rhs(i,1,j) ! Forward substitution with L, reusing the factors for each RHS j
     301             : 
     302           0 :         do k = 2, ndim
     303           0 :           soln(i,k,j) = lower_diag_invrs(i,k) * ( rhs(i,k,j) - lhs(1,i,k) * soln(i,k-1,j) )
     304             :         end do
     305             :       end do
     306             :     end do
     307             :     !$acc end parallel loop
     308             : 
     309             :     !$acc parallel loop gang vector collapse(2) default(present)
     310           0 :     do j = 1, nrhs
     311           0 :       do i = 1, ngrdcol 
     312           0 :         do k = ndim-1, 1, -1 ! Backward substitution with U; soln(i,ndim,j) is already final
     313           0 :           soln(i,k,j) = soln(i,k,j) - upper(i,k) * soln(i,k+1,j)
     314             :         end do
     315             :       end do
     316             :     end do
     317             :     !$acc end parallel loop
     318             : 
     319             :     !$acc end data
     320             : 
     321           0 :   end subroutine tridiag_lu_solve_multiple_rhs_lhs
     322             : 
     323             : end module tridiag_lu_solvers

Generated by: LCOV version 1.14