Line data Source code
1 : module tridiag_lu_solvers
2 :
3 : ! Description:
4 : ! These routines solve lhs*soln=rhs using LU decomp.
5 : !
6 : ! LHS is stored in band diagonal form.
7 : ! lhs = | lhs(0,1) lhs(-1,1) 0 0 0 0 0
8 : ! | lhs(1,2) lhs( 0,2) lhs(-1,2) 0 0 0 0
9 : ! | 0 lhs( 1,3) lhs( 0,3) lhs(-1,3) 0 0 0
10 : ! | 0 0 lhs( 1,4) lhs( 0,4) lhs(-1,4) 0 0
11 : ! | 0 0 0 lhs( 1,5) lhs( 0,5) lhs(-1,5) 0 ...
12 : ! | ...
13 : !
14 : ! U is stored in band diagonal form
15 : ! U = | 1 upper(1) 0 0 0 0 0
16 : ! | 0 1 upper(2) 0 0 0 0
17 : ! | 0 0 1 upper(3) 0 0 0
18 : ! | 0 0 0 1 upper(4) 0 0
19 : ! | 0 0 0 0 1 upper(5) 0 ...
20 : ! | ...
21 : !
22 : ! L is also stored in band diagonal form, but its lowermost band is identical to the
23 : ! lowermost band of LHS, thus we don't need to store it
24 : ! L = | l_diag(1) 0 0 0 0 0
25 : ! | lower(2) l_diag(2) 0 0 0 0
26 : ! | 0 lower(3) l_diag(3) 0 0 0
27 : ! | 0 0 lower(4) l_diag(4) 0 0
28 : ! | 0 0 0 lower(5) l_diag(5) 0 ...
29 : ! | ...
30 : !
31 : !
32 : ! To perform the LU decomposition, we go element by element.
33 : ! First we start by noting that we want lhs=LU, so the first step of calculating
34 : ! L*U, by multiplying the first row of L by the columns of U, gives us
35 : !
36 : ! l_diag(1)*1 = lhs( 0,1) => l_diag(1) = lhs( 0,1)
37 : ! l_diag(1)*upper(1) = lhs(-1,1) => upper(1) = lhs(-1,1) / l_diag(1)
38 : !
39 : ! Now that we're past the k=1 step, each following step uses all the bands,
40 : ! allowing us to write the general step
41 : !
42 : ! lower(k)*1 = lhs(1,k) => lower(k) = lhs( 1,k)
43 : ! lower(k)*upper(k-1)+l_diag(k)*1 = lhs( 0,k) => l_diag(k) = lhs( 0,k) - lower(k)*upper(k-1)
44 : ! l_diag(k)*upper(k) = lhs(-1,k) => upper(k) = lhs(-1,k) / l_diag(k)
45 : !
46 : ! This general step is done for k from 2 to ndim-1 (do k = 2, ndim-1), and the last
47 : ! step is tweaked similarly to the first one, where we exclude the upper band
48 : ! since it is no longer relevant. Note from this general step that the lower band
49 : ! is always equivalent to the first subdiagonal band of lhs, thus we do not need to
50 : ! calculate or store lower. Also note that we only ever need l_diag so that we can divide
51 : ! by it, so instead we compute lower_diag_invrs to reduce divide operations.
52 : !
53 : ! After L and U are computed, normally we do forward substitution using L,
54 : ! then backward substitution using U to find the solution. This is replicated
55 : ! for every right hand side we want to solve for.
56 : !
57 : !
58 : ! References:
59 : ! none
60 : !------------------------------------------------------------------------
61 :
62 :
63 : use clubb_precision, only: &
64 : core_rknd ! Variable(s)
65 :
implicit none

! Everything in this module is private unless explicitly made public below
private ! Default scope

! The generic name is the only entity exported by this module
public :: tridiag_lu_solve

! The specific solvers are reachable only through the generic interface
private :: tridiag_lu_solve_single_rhs_lhs, &
           tridiag_lu_solve_single_rhs_multiple_lhs, &
           tridiag_lu_solve_multiple_rhs_lhs

! Generic interface; the specific procedures are distinguished by their
! argument lists (4, 5, and 6 arguments respectively)
interface tridiag_lu_solve
  module procedure tridiag_lu_solve_single_rhs_lhs, &
                   tridiag_lu_solve_single_rhs_multiple_lhs, &
                   tridiag_lu_solve_multiple_rhs_lhs
end interface tridiag_lu_solve
79 :
80 : contains
81 :
82 : !=============================================================================
subroutine tridiag_lu_solve_single_rhs_lhs( ndim, lhs, rhs, &
                                            soln )
  ! Description:
  !   Solves the tridiagonal system lhs*soln = rhs for a single LHS matrix
  !   and a single RHS vector. The matrix is factored as lhs = L*U without
  !   pivoting, then soln is obtained by forward substitution with L
  !   followed by backward substitution with U.
  !
  !   A 1x1 system (ndim = 1) is solved without referencing the nonexistent
  !   element upper(0), and the routine returns immediately when ndim < 1.
  !------------------------------------------------------------------------

  implicit none

  ! ----------------------- Input Variables -----------------------
  integer, intent(in) :: &
    ndim  ! Matrix size

  real( kind = core_rknd ), intent(in), dimension(ndim) :: &
    rhs   ! Right hand side vector

  ! ----------------------- Input/Output Variables -----------------------
  real( kind = core_rknd ), intent(inout), dimension(-1:1,ndim) :: &
    lhs   ! Matrix to solve, stored using band diagonal vectors
          ! -1 is the superdiagonal band, 0 is the diagonal, 1 is the subdiagonal band
          ! (lhs is only read by this routine)

  ! ----------------------- Output Variables -----------------------
  real( kind = core_rknd ), intent(out), dimension(ndim) :: &
    soln  ! Solution vector

  ! ----------------------- Local Variables -----------------------
  real( kind = core_rknd ), dimension(ndim) :: &
    upper,            & ! First U band
    lower_diag_invrs    ! Inverse of the diagonal of L

  integer :: k  ! Loop variable

  ! ----------------------- Begin Code -----------------------

  ! Nothing to solve; also avoids touching element 1 of zero-size arrays
  if ( ndim < 1 ) return

  !$acc data create( upper, lower_diag_invrs ) &
  !$acc      copyin( rhs, lhs ) &
  !$acc      copyout( soln )

  !$acc kernels

  ! LU decomposition, first row (no subdiagonal entry)
  lower_diag_invrs(1) = 1.0_core_rknd / lhs(0,1)
  upper(1)            = lower_diag_invrs(1) * lhs(-1,1)

  ! LU decomposition, general rows
  do k = 2, ndim-1
    lower_diag_invrs(k) = 1.0_core_rknd / ( lhs(0,k) - lhs(1,k) * upper(k-1) )
    upper(k)            = lower_diag_invrs(k) * lhs(-1,k)
  end do

  ! LU decomposition, last row (no superdiagonal entry). Skipped when
  ! ndim = 1, since row 1 was already handled above and upper(0) does
  ! not exist.
  if ( ndim > 1 ) then
    lower_diag_invrs(ndim) = 1.0_core_rknd / ( lhs(0,ndim) - lhs(1,ndim) * upper(ndim-1) )
  end if

  ! Forward substitution using L
  soln(1) = lower_diag_invrs(1) * rhs(1)

  do k = 2, ndim
    soln(k) = lower_diag_invrs(k) * ( rhs(k) - lhs(1,k) * soln(k-1) )
  end do

  ! Backward substitution using U (unit diagonal, so no division needed)
  do k = ndim-1, 1, -1
    soln(k) = soln(k) - upper(k) * soln(k+1)
  end do

  !$acc end kernels

  !$acc end data

end subroutine tridiag_lu_solve_single_rhs_lhs
147 :
148 : !=============================================================================
subroutine tridiag_lu_solve_single_rhs_multiple_lhs( ndim, ngrdcol, lhs, rhs, &
                                                     soln )
  ! Description:
  !   Solves lhs*soln = rhs for ngrdcol independent tridiagonal systems,
  !   each with its own LHS matrix and a single RHS vector. Every system is
  !   factored as lhs = L*U without pivoting, then solved by forward
  !   substitution with L followed by backward substitution with U.
  !
  !   1x1 systems (ndim = 1) are solved without referencing the nonexistent
  !   element upper(i,0), and the routine returns immediately when ndim < 1.
  !------------------------------------------------------------------------

  implicit none

  ! ----------------------- Input Variables -----------------------
  integer, intent(in) :: &
    ndim,   & ! Matrix size
    ngrdcol   ! Number of grid columns

  real( kind = core_rknd ), intent(in), dimension(ngrdcol,ndim) :: &
    rhs   ! Right hand side vector for each grid column

  ! ----------------------- Input/Output Variables -----------------------
  real( kind = core_rknd ), intent(inout), dimension(-1:1,ngrdcol,ndim) :: &
    lhs   ! Matrices to solve, stored using band diagonal vectors
          ! -1 is the superdiagonal band, 0 is the diagonal, 1 is the subdiagonal band
          ! (lhs is only read by this routine)

  ! ----------------------- Output Variables -----------------------
  real( kind = core_rknd ), intent(out), dimension(ngrdcol,ndim) :: &
    soln  ! Solution vector for each grid column

  ! ----------------------- Local Variables -----------------------
  real( kind = core_rknd ), dimension(ngrdcol,ndim) :: &
    upper,            & ! First U band
    lower_diag_invrs    ! Inverse of the diagonal of L

  integer :: i, k  ! Loop variables

  ! ----------------------- Begin Code -----------------------

  ! Nothing to solve; also avoids touching element 1 of zero-size arrays
  if ( ndim < 1 ) return

  !$acc data create( upper, lower_diag_invrs ) &
  !$acc      copyin( rhs, lhs ) &
  !$acc      copyout( soln )

  ! LU decomposition, first row (no subdiagonal entry)
  !$acc parallel loop gang vector default(present)
  do i = 1, ngrdcol
    lower_diag_invrs(i,1) = 1.0_core_rknd / lhs(0,i,1)
    upper(i,1)            = lower_diag_invrs(i,1) * lhs(-1,i,1)
  end do
  !$acc end parallel loop

  ! LU decomposition, general rows. Row k depends on row k-1, so only the
  ! grid column loop is parallelized.
  !$acc parallel loop gang vector default(present)
  do i = 1, ngrdcol
    do k = 2, ndim-1
      lower_diag_invrs(i,k) = 1.0_core_rknd / ( lhs(0,i,k) - lhs(1,i,k) * upper(i,k-1) )
      upper(i,k)            = lower_diag_invrs(i,k) * lhs(-1,i,k)
    end do
  end do
  !$acc end parallel loop

  ! LU decomposition, last row (no superdiagonal entry). Skipped when
  ! ndim = 1, since row 1 was already handled above and upper(i,0) does
  ! not exist.
  if ( ndim > 1 ) then
    !$acc parallel loop gang vector default(present)
    do i = 1, ngrdcol
      lower_diag_invrs(i,ndim) = 1.0_core_rknd &
                                 / ( lhs(0,i,ndim) - lhs(1,i,ndim) * upper(i,ndim-1) )
    end do
    !$acc end parallel loop
  end if

  ! Forward substitution using L
  !$acc parallel loop gang vector default(present)
  do i = 1, ngrdcol

    soln(i,1) = lower_diag_invrs(i,1) * rhs(i,1)

    do k = 2, ndim
      soln(i,k) = lower_diag_invrs(i,k) * ( rhs(i,k) - lhs(1,i,k) * soln(i,k-1) )
    end do
  end do
  !$acc end parallel loop

  ! Backward substitution using U (unit diagonal, so no division needed)
  !$acc parallel loop gang vector default(present)
  do i = 1, ngrdcol
    do k = ndim-1, 1, -1
      soln(i,k) = soln(i,k) - upper(i,k) * soln(i,k+1)
    end do
  end do
  !$acc end parallel loop

  !$acc end data

end subroutine tridiag_lu_solve_single_rhs_multiple_lhs
232 :
233 :
234 : !=============================================================================
subroutine tridiag_lu_solve_multiple_rhs_lhs( ndim, nrhs, ngrdcol, lhs, rhs, &
                                              soln )
  ! Description:
  !   Solves lhs*soln = rhs for ngrdcol independent tridiagonal systems,
  !   each with its own LHS matrix and nrhs RHS vectors. Every matrix is
  !   factored once as lhs = L*U without pivoting; the factors are then
  !   reused for the forward (L) and backward (U) substitution of each of
  !   the nrhs right hand sides.
  !
  !   1x1 systems (ndim = 1) are solved without referencing the nonexistent
  !   element upper(i,0), and the routine returns immediately when ndim < 1.
  !------------------------------------------------------------------------

  implicit none

  ! ----------------------- Input Variables -----------------------
  integer, intent(in) :: &
    ndim,   & ! Matrix size
    nrhs,   & ! Number of right hand sides
    ngrdcol   ! Number of grid columns

  real( kind = core_rknd ), intent(in), dimension(ngrdcol,ndim,nrhs) :: &
    rhs   ! Right hand side vectors for each grid column

  ! ----------------------- Input/Output Variables -----------------------
  real( kind = core_rknd ), intent(inout), dimension(-1:1,ngrdcol,ndim) :: &
    lhs   ! Matrices to solve, stored using band diagonal vectors
          ! -1 is the superdiagonal band, 0 is the diagonal, 1 is the subdiagonal band
          ! (lhs is only read by this routine)

  ! ----------------------- Output Variables -----------------------
  real( kind = core_rknd ), intent(out), dimension(ngrdcol,ndim,nrhs) :: &
    soln  ! Solution vectors for each grid column

  ! ----------------------- Local Variables -----------------------
  real( kind = core_rknd ), dimension(ngrdcol,ndim) :: &
    upper,            & ! First U band
    lower_diag_invrs    ! Inverse of the diagonal of L

  integer :: i, k, j  ! Loop variables

  ! ----------------------- Begin Code -----------------------

  ! Nothing to solve; also avoids touching element 1 of zero-size arrays
  if ( ndim < 1 ) return

  !$acc data create( upper, lower_diag_invrs ) &
  !$acc      copyin( rhs, lhs ) &
  !$acc      copyout( soln )

  ! LU decomposition, first row (no subdiagonal entry)
  !$acc parallel loop gang vector default(present)
  do i = 1, ngrdcol
    lower_diag_invrs(i,1) = 1.0_core_rknd / lhs(0,i,1)
    upper(i,1)            = lower_diag_invrs(i,1) * lhs(-1,i,1)
  end do
  !$acc end parallel loop

  ! LU decomposition, general rows. Row k depends on row k-1, so only the
  ! grid column loop is parallelized.
  !$acc parallel loop gang vector default(present)
  do i = 1, ngrdcol
    do k = 2, ndim-1
      lower_diag_invrs(i,k) = 1.0_core_rknd / ( lhs(0,i,k) - lhs(1,i,k) * upper(i,k-1) )
      upper(i,k)            = lower_diag_invrs(i,k) * lhs(-1,i,k)
    end do
  end do
  !$acc end parallel loop

  ! LU decomposition, last row (no superdiagonal entry). Skipped when
  ! ndim = 1, since row 1 was already handled above and upper(i,0) does
  ! not exist.
  if ( ndim > 1 ) then
    !$acc parallel loop gang vector default(present)
    do i = 1, ngrdcol
      lower_diag_invrs(i,ndim) = 1.0_core_rknd &
                                 / ( lhs(0,i,ndim) - lhs(1,i,ndim) * upper(i,ndim-1) )
    end do
    !$acc end parallel loop
  end if

  ! Forward substitution using L, independently for every (rhs, column) pair
  !$acc parallel loop gang vector collapse(2) default(present)
  do j = 1, nrhs
    do i = 1, ngrdcol

      soln(i,1,j) = lower_diag_invrs(i,1) * rhs(i,1,j)

      do k = 2, ndim
        soln(i,k,j) = lower_diag_invrs(i,k) * ( rhs(i,k,j) - lhs(1,i,k) * soln(i,k-1,j) )
      end do
    end do
  end do
  !$acc end parallel loop

  ! Backward substitution using U (unit diagonal, so no division needed)
  !$acc parallel loop gang vector collapse(2) default(present)
  do j = 1, nrhs
    do i = 1, ngrdcol
      do k = ndim-1, 1, -1
        soln(i,k,j) = soln(i,k,j) - upper(i,k) * soln(i,k+1,j)
      end do
    end do
  end do
  !$acc end parallel loop

  !$acc end data

end subroutine tridiag_lu_solve_multiple_rhs_lhs
322 :
323 : end module tridiag_lu_solvers
|