Line data Source code
1 : module penta_lu_solvers
2 :
3 : ! Description:
4 : ! These routines solve lhs*soln=rhs using LU decomp.
5 : !
6 : ! LHS is stored in band diagonal form.
7 : ! lhs = | lhs(0,1) lhs(-1,1) lhs(-2,1) 0 0 0 0
8 : ! | lhs(1,2) lhs( 0,2) lhs(-1,2) lhs(-2,2) 0 0 0
9 : ! | lhs(2,3) lhs( 1,3) lhs( 0,3) lhs(-1,3) lhs(-2,3) 0 0
10 : ! | 0 lhs( 2,4) lhs( 1,4) lhs( 0,4) lhs(-1,4) lhs(-2,4) 0
11 : ! | 0 0 lhs( 2,5) lhs( 1,5) lhs( 0,5) lhs(-1,5) lhs(-2,5) ...
12 : ! | ...
13 : !
14 : ! U is stored in band diagonal form
15 : ! U = | 1 upper_1(1) upper_2(1) 0 0 0 0
16 : ! | 0 1 upper_1(2) upper_2(2) 0 0 0
17 : ! | 0 0 1 upper_1(3) upper_2(3) 0 0
18 : ! | 0 0 0 1 upper_1(4) upper_2(4) 0
19 : ! | 0 0 0 0 1 upper_1(5) upper_2(5) ...
20 : ! | ...
21 : !
22 : ! L is also stored in band diagonal form, but the lowermost band of L is identical to the
23 : ! lowermost band of LHS, thus we don't need to store it
24 : ! L = | l_diag(1) 0 0 0 0 0
25 : ! | lower_1(2) l_diag(2) 0 0 0 0
26 : ! | l_2(3) lower_1(3) l_diag(3) 0 0 0
27 : ! | 0 l_2(4) lower_1(4) l_diag(4) 0 0
28 : ! | 0 0 l_2(5) lower_1(5) l_diag(5) 0 ...
29 : ! | ...
30 : !
31 : !
32 : ! To perform the LU decomposition, we go element by element.
33 : ! First we start by noting that we want lhs=LU, so the first step of calculating
34 : ! L*U, by multiplying the first row of L by the columns of U, gives us
35 : !
36 : ! l_diag(1)*1 = lhs( 0,1) => l_diag(1) = lhs( 0,1)
37 : ! l_diag(1)*upper_1(1) = lhs(-1,1) => upper_1(1) = lhs(-1,1) / l_diag(1)
38 : ! l_diag(1)*upper_2(1) = lhs(-2,1) => upper_2(1) = lhs(-2,1) / l_diag(1)
39 : !
40 : ! Multiplying the second row of L by U now we get
41 : !
42 : ! lower_1(2)*1 = lhs(1,2) => lower_1(2) = lhs(1,2)
43 : ! lower_1(2)*upper_1(1)+l_diag(2)*1 = lhs(0,2) => l_diag(2) = lhs(0,2) - lower_1(2)*upper_1(1)
44 : ! lower_1(2)*upper_2(1)+l_diag(2)*upper_1(2) = lhs(-1,2) => upper_1(2) = ( lhs(-1,2)-lower_1(2)*upper_2(1) )
45 : ! / l_diag(2)
46 : ! l_diag(2)*upper_2(2) = lhs(-2,2) => upper_2(2) = lhs(-2,2) / l_diag(2)
47 : !
48 : ! Now that we're past the k=1 and k=2 steps, each following step uses all the bands,
49 : ! allowing us to write the general step
50 : !
51 : ! l_2(k)*1 = lhs(2,k) => l_2(k) = lhs(2,k)
52 : ! l_2(k)*upper_1(k-2)+lower_1(k)*1 = lhs(1,k) => lower_1(k) = lhs(1,k) - l_2(k)*upper_1(k-2)
53 : ! l_2(k)*upper_2(k-2)+lower_1(k)*upper_1(k-1)+l_diag(k)*1 = lhs( 0,k) => l_diag(k) = lhs(0,k) - l_2(k)*upper_2(k-2)
54 : ! - lower_1(k)*upper_1(k-1)
55 : !
56 : ! lower_1(k)*upper_2(k-1)
57 : ! + l_diag(k)*upper_1(k) = lhs(-1,k) => upper_1(k) = ( lhs(-1,k) - lower_1(k)*upper_2(k-1) )
58 : ! / l_diag(k)
59 : ! l_diag(k)*upper_2(k) = lhs(-2,k) => upper_2(k) = lhs(-2,k) / l_diag(k)
60 : !
61 : !
62 : ! This general step is done for k from 3 to ndim-2 (do k = 3, ndim-2), and the last two
63 : ! steps are tweaked similarly to the first two, where we exclude one then two bands
64 : ! since they are no longer relevant. Note from this general step that the l_2 band
65 : ! is always equivalent to the second subdiagonal band of lhs, thus we do not need to
66 : ! calculate or store l_2. Also note that we only ever need l_diag so that we can divide
67 : ! by it, so instead we compute lower_diag_invrs to reduce divide operations.
68 : !
69 : ! After L and U are computed, normally we do forward substitution using L,
70 : ! then backward substitution using U to find the solution. This is replicated
71 : ! for every right hand side we want to solve for.
72 : !
73 : !
74 : ! References:
75 : ! none
76 : !------------------------------------------------------------------------
77 :
78 :
79 : use clubb_precision, only: &
80 : core_rknd ! Variable(s)
81 :
82 : implicit none
83 :
84 : public :: penta_lu_solve
85 :
86 : private :: penta_lu_solve_single_rhs_multiple_lhs, penta_lu_solve_multiple_rhs_lhs
87 :
88 : interface penta_lu_solve
89 : module procedure penta_lu_solve_single_rhs_multiple_lhs
90 : module procedure penta_lu_solve_multiple_rhs_lhs
91 : end interface
92 :
93 : private ! Default scope
94 :
95 : contains
96 :
97 : !=============================================================================
subroutine penta_lu_solve_single_rhs_multiple_lhs( ndim, ngrdcol, lhs, rhs, &
                                                   soln )
  ! Description:
  !   Solves lhs*soln = rhs for ngrdcol independent pentadiagonal systems,
  !   each with a single right hand side, using an LU decomposition.
  !
  !   The factorization is lhs = L*U where U has a unit diagonal. For row k:
  !     lower_2(k) = lhs(2,k)
  !     lower_1(k) = lhs(1,k) - lower_2(k)*upper_1(k-2)
  !     l_diag(k)  = lhs(0,k) - lower_2(k)*upper_2(k-2) - lower_1(k)*upper_1(k-1)
  !     upper_1(k) = ( lhs(-1,k) - lower_1(k)*upper_2(k-1) ) / l_diag(k)
  !     upper_2(k) = lhs(-2,k) / l_diag(k)
  !   with the terms that reference rows/columns outside 1..ndim dropped for
  !   the first two and last two rows. Only 1/l_diag is stored.
  !
  ! Notes:
  !   - Rows 1, 2, ndim-1 and ndim are handled outside the general loop, and
  !     row ndim-1 reads row ndim-3, so ndim must be at least 4.
  !   - No pivoting is performed; a zero pivot results in a division by zero.
  !   - lhs is declared intent(inout) but is only read by this routine.
  !------------------------------------------------------------------------

  implicit none

  ! ----------------------- Input Variables -----------------------
  integer, intent(in) :: &
    ndim,   & ! Matrix size (must be >= 4)
    ngrdcol   ! Number of grid columns

  real( kind = core_rknd ), intent(in), dimension(ngrdcol,ndim) :: &
    rhs       ! Right hand side vector, one per grid column

  ! ----------------------- Input/Output Variables -----------------------
  real( kind = core_rknd ), intent(inout), dimension(-2:2,ngrdcol,ndim) :: &
    lhs       ! Matrices to solve, stored using band diagonal vectors
              ! -2 is the uppermost band, 2 is the lower most band, 0 is diagonal

  ! ----------------------- Output Variables -----------------------
  real( kind = core_rknd ), intent(out), dimension(ngrdcol,ndim) :: &
    soln      ! Solution vector, one per grid column

  ! ----------------------- Local Variables -----------------------
  real( kind = core_rknd ), dimension(ngrdcol,ndim) :: &
    upper_1,          & ! First U band
    upper_2,          & ! Second U band
    lower_diag_invrs, & ! Inverse of the diagonal of L
    lower_1,          & ! First L band
    lower_2             ! Second L band

  integer :: i, k    ! Loop variables

  ! ----------------------- Begin Code -----------------------

  !$acc data create( upper_1, upper_2, lower_1, lower_2, lower_diag_invrs ) &
  !$acc      copyin( rhs, lhs ) &
  !$acc      copyout( soln )

  ! Factorization, rows 1 and 2: the bands that would reference rows
  ! before the first one are absent.
  !$acc parallel loop gang vector default(present)
  do i = 1, ngrdcol
    lower_diag_invrs(i,1) = 1.0_core_rknd / lhs(0,i,1)
    upper_1(i,1)          = lower_diag_invrs(i,1) * lhs(-1,i,1)
    upper_2(i,1)          = lower_diag_invrs(i,1) * lhs(-2,i,1)

    lower_1(i,2)          = lhs(1,i,2)
    lower_diag_invrs(i,2) = 1.0_core_rknd / ( lhs(0,i,2) - lower_1(i,2) * upper_1(i,1) )
    upper_1(i,2)          = lower_diag_invrs(i,2) * ( lhs(-1,i,2) - lower_1(i,2) * upper_2(i,1) )
    upper_2(i,2)          = lower_diag_invrs(i,2) * lhs(-2,i,2)
  end do
  !$acc end parallel loop

  ! Factorization, general rows. The k loop is sequential because row k
  ! depends on rows k-1 and k-2; columns i are independent.
  !$acc parallel loop gang vector default(present)
  do i = 1, ngrdcol
    do k = 3, ndim-2
      lower_2(i,k) = lhs(2,i,k)
      lower_1(i,k) = lhs(1,i,k) - lower_2(i,k) * upper_1(i,k-2)

      lower_diag_invrs(i,k) = 1.0_core_rknd / ( lhs(0,i,k) - lower_2(i,k) * upper_2(i,k-2) &
                                                           - lower_1(i,k) * upper_1(i,k-1) )

      upper_1(i,k) = lower_diag_invrs(i,k) * ( lhs(-1,i,k) - lower_1(i,k) * upper_2(i,k-1) )
      upper_2(i,k) = lower_diag_invrs(i,k) * lhs(-2,i,k)
    end do
  end do
  !$acc end parallel loop

  ! Factorization, rows ndim-1 and ndim: upper_2 (and for the last row also
  ! upper_1) would lie outside the matrix and is not computed.
  !$acc parallel loop gang vector default(present)
  do i = 1, ngrdcol
    lower_2(i,ndim-1) = lhs(2,i,ndim-1)
    lower_1(i,ndim-1) = lhs(1,i,ndim-1) - lower_2(i,ndim-1) * upper_1(i,ndim-3)

    lower_diag_invrs(i,ndim-1) = 1.0_core_rknd  &
                                 / ( lhs(0,i,ndim-1) - lower_2(i,ndim-1) * upper_2(i,ndim-3) &
                                                     - lower_1(i,ndim-1) * upper_1(i,ndim-2) )

    upper_1(i,ndim-1) = lower_diag_invrs(i,ndim-1) * ( lhs(-1,i,ndim-1) - lower_1(i,ndim-1) &
                                                                          * upper_2(i,ndim-2) )

    lower_2(i,ndim) = lhs(2,i,ndim)
    lower_1(i,ndim) = lhs(1,i,ndim) - lower_2(i,ndim) * upper_1(i,ndim-2)

    ! The pivot of the last row is built from the diagonal element of row
    ! ndim itself, l_diag(k) = lhs(0,k) - ..., not from row ndim-1.
    lower_diag_invrs(i,ndim) = 1.0_core_rknd  &
                               / ( lhs(0,i,ndim) - lower_2(i,ndim) * upper_2(i,ndim-2) &
                                                 - lower_1(i,ndim) * upper_1(i,ndim-1) )
  end do
  !$acc end parallel loop

  ! Forward substitution with L; the intermediate result is held in soln.
  !$acc parallel loop gang vector default(present)
  do i = 1, ngrdcol

    soln(i,1) = lower_diag_invrs(i,1) * rhs(i,1)

    soln(i,2) = lower_diag_invrs(i,2) * ( rhs(i,2) - lower_1(i,2) * soln(i,1) )

    do k = 3, ndim
      soln(i,k) = lower_diag_invrs(i,k) * ( rhs(i,k) - lower_2(i,k) * soln(i,k-2) &
                                                     - lower_1(i,k) * soln(i,k-1) )
    end do
  end do
  !$acc end parallel loop

  ! Backward substitution with U (unit diagonal, so soln(i,ndim) is final).
  !$acc parallel loop gang vector default(present)
  do i = 1, ngrdcol
    soln(i,ndim-1) = soln(i,ndim-1) - upper_1(i,ndim-1) * soln(i,ndim)

    do k = ndim-2, 1, -1
      soln(i,k) = soln(i,k) - upper_1(i,k) * soln(i,k+1) - upper_2(i,k) * soln(i,k+2)
    end do

  end do
  !$acc end parallel loop

  !$acc end data

end subroutine penta_lu_solve_single_rhs_multiple_lhs
216 :
217 :
218 : !=============================================================================
subroutine penta_lu_solve_multiple_rhs_lhs( ndim, nrhs, ngrdcol, lhs, rhs, &
                                            soln )
  ! Description:
  !   Solves lhs*soln = rhs for ngrdcol independent pentadiagonal systems,
  !   each with nrhs right hand sides, using an LU decomposition. The
  !   factorization is computed once per grid column and reused for every
  !   right hand side.
  !
  !   The factorization is lhs = L*U where U has a unit diagonal. For row k:
  !     lower_2(k) = lhs(2,k)
  !     lower_1(k) = lhs(1,k) - lower_2(k)*upper_1(k-2)
  !     l_diag(k)  = lhs(0,k) - lower_2(k)*upper_2(k-2) - lower_1(k)*upper_1(k-1)
  !     upper_1(k) = ( lhs(-1,k) - lower_1(k)*upper_2(k-1) ) / l_diag(k)
  !     upper_2(k) = lhs(-2,k) / l_diag(k)
  !   with the terms that reference rows/columns outside 1..ndim dropped for
  !   the first two and last two rows. Only 1/l_diag is stored.
  !
  ! Notes:
  !   - Rows 1, 2, ndim-1 and ndim are handled outside the general loop, and
  !     row ndim-1 reads row ndim-3, so ndim must be at least 4.
  !   - No pivoting is performed; a zero pivot results in a division by zero.
  !   - lhs is declared intent(inout) but is only read by this routine.
  !------------------------------------------------------------------------

  implicit none

  ! ----------------------- Input Variables -----------------------
  integer, intent(in) :: &
    ndim,   & ! Matrix size (must be >= 4)
    nrhs,   & ! Number of right hand sides
    ngrdcol   ! Number of grid columns

  real( kind = core_rknd ), intent(in), dimension(ngrdcol,ndim,nrhs) :: &
    rhs       ! Right hand side vectors, nrhs per grid column

  ! ----------------------- Input/Output Variables -----------------------
  real( kind = core_rknd ), intent(inout), dimension(-2:2,ngrdcol,ndim) :: &
    lhs       ! Matrices to solve, stored using band diagonal vectors
              ! -2 is the uppermost band, 2 is the lower most band, 0 is diagonal

  ! ----------------------- Output Variables -----------------------
  real( kind = core_rknd ), intent(out), dimension(ngrdcol,ndim,nrhs) :: &
    soln      ! Solution vectors, nrhs per grid column

  ! ----------------------- Local Variables -----------------------
  real( kind = core_rknd ), dimension(ngrdcol,ndim) :: &
    upper_1,          & ! First U band
    upper_2,          & ! Second U band
    lower_diag_invrs, & ! Inverse of the diagonal of L
    lower_1,          & ! First L band
    lower_2             ! Second L band

  integer :: i, k, j    ! Loop variables

  ! ----------------------- Begin Code -----------------------

  !$acc data create( upper_1, upper_2, lower_1, lower_2, lower_diag_invrs ) &
  !$acc      copyin( rhs, lhs ) &
  !$acc      copyout( soln )

  ! Factorization, rows 1 and 2: the bands that would reference rows
  ! before the first one are absent.
  !$acc parallel loop gang vector default(present)
  do i = 1, ngrdcol
    lower_diag_invrs(i,1) = 1.0_core_rknd / lhs(0,i,1)
    upper_1(i,1)          = lower_diag_invrs(i,1) * lhs(-1,i,1)
    upper_2(i,1)          = lower_diag_invrs(i,1) * lhs(-2,i,1)

    lower_1(i,2)          = lhs(1,i,2)
    lower_diag_invrs(i,2) = 1.0_core_rknd / ( lhs(0,i,2) - lower_1(i,2) * upper_1(i,1) )
    upper_1(i,2)          = lower_diag_invrs(i,2) * ( lhs(-1,i,2) - lower_1(i,2) * upper_2(i,1) )
    upper_2(i,2)          = lower_diag_invrs(i,2) * lhs(-2,i,2)
  end do
  !$acc end parallel loop

  ! Factorization, general rows. The k loop is sequential because row k
  ! depends on rows k-1 and k-2; columns i are independent.
  !$acc parallel loop gang vector default(present)
  do i = 1, ngrdcol
    do k = 3, ndim-2
      lower_2(i,k) = lhs(2,i,k)
      lower_1(i,k) = lhs(1,i,k) - lower_2(i,k) * upper_1(i,k-2)

      lower_diag_invrs(i,k) = 1.0_core_rknd / ( lhs(0,i,k) - lower_2(i,k) * upper_2(i,k-2) &
                                                           - lower_1(i,k) * upper_1(i,k-1) )

      upper_1(i,k) = lower_diag_invrs(i,k) * ( lhs(-1,i,k) - lower_1(i,k) * upper_2(i,k-1) )
      upper_2(i,k) = lower_diag_invrs(i,k) * lhs(-2,i,k)
    end do
  end do
  !$acc end parallel loop

  ! Factorization, rows ndim-1 and ndim: upper_2 (and for the last row also
  ! upper_1) would lie outside the matrix and is not computed.
  !$acc parallel loop gang vector default(present)
  do i = 1, ngrdcol
    lower_2(i,ndim-1) = lhs(2,i,ndim-1)
    lower_1(i,ndim-1) = lhs(1,i,ndim-1) - lower_2(i,ndim-1) * upper_1(i,ndim-3)

    lower_diag_invrs(i,ndim-1) = 1.0_core_rknd  &
                                 / ( lhs(0,i,ndim-1) - lower_2(i,ndim-1) * upper_2(i,ndim-3) &
                                                     - lower_1(i,ndim-1) * upper_1(i,ndim-2) )

    upper_1(i,ndim-1) = lower_diag_invrs(i,ndim-1) * ( lhs(-1,i,ndim-1) - lower_1(i,ndim-1) &
                                                                          * upper_2(i,ndim-2) )

    lower_2(i,ndim) = lhs(2,i,ndim)
    lower_1(i,ndim) = lhs(1,i,ndim) - lower_2(i,ndim) * upper_1(i,ndim-2)

    ! The pivot of the last row is built from the diagonal element of row
    ! ndim itself, l_diag(k) = lhs(0,k) - ..., not from row ndim-1.
    lower_diag_invrs(i,ndim) = 1.0_core_rknd  &
                               / ( lhs(0,i,ndim) - lower_2(i,ndim) * upper_2(i,ndim-2) &
                                                 - lower_1(i,ndim) * upper_1(i,ndim-1) )
  end do
  !$acc end parallel loop

  ! Forward substitution with L for every (right hand side, column) pair;
  ! the intermediate result is held in soln.
  !$acc parallel loop gang vector collapse(2) default(present)
  do j = 1, nrhs
    do i = 1, ngrdcol

      soln(i,1,j) = lower_diag_invrs(i,1) * rhs(i,1,j)

      soln(i,2,j) = lower_diag_invrs(i,2) * ( rhs(i,2,j) - lower_1(i,2) * soln(i,1,j) )

      do k = 3, ndim
        soln(i,k,j) = lower_diag_invrs(i,k) * ( rhs(i,k,j) - lower_2(i,k) * soln(i,k-2,j) &
                                                           - lower_1(i,k) * soln(i,k-1,j) )
      end do
    end do
  end do
  !$acc end parallel loop

  ! Backward substitution with U (unit diagonal, so soln(i,ndim,j) is final).
  !$acc parallel loop gang vector collapse(2) default(present)
  do j = 1, nrhs
    do i = 1, ngrdcol
      soln(i,ndim-1,j) = soln(i,ndim-1,j) - upper_1(i,ndim-1) * soln(i,ndim,j)

      do k = ndim-2, 1, -1
        soln(i,k,j) = soln(i,k,j) - upper_1(i,k) * soln(i,k+1,j) - upper_2(i,k) * soln(i,k+2,j)
      end do

    end do
  end do
  !$acc end parallel loop

  !$acc end data

end subroutine penta_lu_solve_multiple_rhs_lhs
342 :
343 : end module penta_lu_solvers
|