From 68b2b387dd426cef1bcd929e01e9872c16c0f35c Mon Sep 17 00:00:00 2001 From: Dongze Sun Date: Tue, 29 Nov 2022 15:43:39 -0800 Subject: [PATCH] Reduce memory usage when applying supertranslations by a factor of 6 --- scri/asymptotic_bondi_data/transformations.py | 151 ++++++++++++++---- 1 file changed, 117 insertions(+), 34 deletions(-) diff --git a/scri/asymptotic_bondi_data/transformations.py b/scri/asymptotic_bondi_data/transformations.py index d45b4e48..27948987 100644 --- a/scri/asymptotic_bondi_data/transformations.py +++ b/scri/asymptotic_bondi_data/transformations.py @@ -336,11 +336,24 @@ def transform(self, **kwargs): # σ(u, θ', ϕ') exp(2iλ) σ = sf.Grid(self.sigma.evaluate(distorted_grid_rotors), spin_weight=2) + # Determine the new time slices. The set timeprime is chosen so that on each slice of constant + # u'_i, the average value of u=(u'/k)+α is precisely =u'γ+<α>=u_i. But then, we have to + # narrow that set down, so that every grid point on all the u'_i' slices correspond to data in + # the range of input data. + timeprime = (u - sf.constant_from_ell_0_mode(supertranslation[0]).real) / γ + timeprime_of_initialtime_directionprime = k * (u[0] - α) + timeprime_of_finaltime_directionprime = k * (u[-1] - α) + earliest_complete_timeprime = np.max(timeprime_of_initialtime_directionprime.view(np.ndarray)) + latest_complete_timeprime = np.min(timeprime_of_finaltime_directionprime.view(np.ndarray)) + timeprime = timeprime[(timeprime >= earliest_complete_timeprime) & (timeprime <= latest_complete_timeprime)] + abdprime = type(self)(timeprime, output_ell_max) + ### The following calculations are done using in-place Horner form. I suspect this will be the ### most efficient form of this calculation, within reason. Note that the factors of exp(isλ) ### were computed automatically by evaluating in terms of quaternions. # - fprime_of_timenaught_directionprime = np.empty((6, self.n_times, n_theta, n_phi), dtype=complex) + fprime_of_timenaught_directionprime = np.empty((1, self.n_times, n_theta, n_phi), dtype=complex) + # ψ0'(u, θ', ϕ') fprime_temp = ψ4.copy() fprime_temp *= ðuprime_over_k @@ -352,7 +365,27 @@ def transform(self, **kwargs): fprime_temp *= ðuprime_over_k fprime_temp += ψ0 fprime_temp *= one_over_k_cubed + # This will store the values of f'(u', θ', ϕ') for the various functions `f` fprime_of_timenaught_directionprime[0] = fprime_temp + fprime_of_timeprime_directionprime = np.zeros((1,timeprime.size, n_theta, n_phi), dtype=complex) + # Interpolate the various transformed function values on the transformed grid from the original + # time coordinate to the new set of time coordinates, independently for each direction. + for i in range(n_theta): + for j in range(n_phi): + k_i_j = k[0, i, j] + α_i_j = α[0, i, j] + # u'(u, θ', ϕ') + timeprime_of_timenaught_directionprime_i_j = k_i_j * (u - α_i_j) + # f'(u', θ', ϕ') + fprime_of_timeprime_directionprime[:, :, i, j] = CubicSpline( + timeprime_of_timenaught_directionprime_i_j, fprime_of_timenaught_directionprime[:, :, i, j], axis=1 + )(timeprime) + # Transform back from the distorted grid to the SWSH mode weights as measured in that + # grid. I'll abuse notation slightly here by indicating those "distorted" mode weights with + # primes, so that f'(u')_{ℓ', m'} = ∫ f'(u', θ', ϕ') sȲ_{ℓ', m'}(θ', ϕ') sin(θ') dθ' dϕ' + # ψ0'(u')_{ℓ', m'} + abdprime.psi0 = spinsfast.map2salm(fprime_of_timeprime_directionprime[0], 2, output_ell_max) + # ψ1'(u, θ', ϕ') fprime_temp = -ψ4 fprime_temp *= ðuprime_over_k @@ -362,7 +395,26 @@ def transform(self, **kwargs): fprime_temp *= ðuprime_over_k fprime_temp += ψ1 fprime_temp *= one_over_k_cubed - fprime_of_timenaught_directionprime[1] = fprime_temp + # This will store the values of f'(u', θ', ϕ') for the various functions `f` + fprime_of_timenaught_directionprime[0] = fprime_temp + # Interpolate the various transformed function values on the transformed grid from the original + # time coordinate to the new set of time coordinates, independently for each direction. + for i in range(n_theta): + for j in range(n_phi): + k_i_j = k[0, i, j] + α_i_j = α[0, i, j] + # u'(u, θ', ϕ') + timeprime_of_timenaught_directionprime_i_j = k_i_j * (u - α_i_j) + # f'(u', θ', ϕ') + fprime_of_timeprime_directionprime[:, :, i, j] = CubicSpline( + timeprime_of_timenaught_directionprime_i_j, fprime_of_timenaught_directionprime[:, :, i, j], axis=1 + )(timeprime) + # Transform back from the distorted grid to the SWSH mode weights as measured in that + # grid. I'll abuse notation slightly here by indicating those "distorted" mode weights with + # primes, so that f'(u')_{ℓ', m'} = ∫ f'(u', θ', ϕ') sȲ_{ℓ', m'}(θ', ϕ') sin(θ') dθ' dϕ' + # ψ1'(u')_{ℓ', m'} + abdprime.psi1 = spinsfast.map2salm(fprime_of_timeprime_directionprime[0], 1, output_ell_max) + # ψ2'(u, θ', ϕ') fprime_temp = ψ4.copy() fprime_temp *= ðuprime_over_k @@ -370,37 +422,80 @@ def transform(self, **kwargs): fprime_temp *= ðuprime_over_k fprime_temp += ψ2 fprime_temp *= one_over_k_cubed - fprime_of_timenaught_directionprime[2] = fprime_temp + # This will store the values of f'(u', θ', ϕ') for the various functions `f` + fprime_of_timenaught_directionprime[0] = fprime_temp + # Interpolate the various transformed function values on the transformed grid from the original + # time coordinate to the new set of time coordinates, independently for each direction. + for i in range(n_theta): + for j in range(n_phi): + k_i_j = k[0, i, j] + α_i_j = α[0, i, j] + # u'(u, θ', ϕ') + timeprime_of_timenaught_directionprime_i_j = k_i_j * (u - α_i_j) + # f'(u', θ', ϕ') + fprime_of_timeprime_directionprime[:, :, i, j] = CubicSpline( + timeprime_of_timenaught_directionprime_i_j, fprime_of_timenaught_directionprime[:, :, i, j], axis=1 + )(timeprime) + # Transform back from the distorted grid to the SWSH mode weights as measured in that + # grid. I'll abuse notation slightly here by indicating those "distorted" mode weights with + # primes, so that f'(u')_{ℓ', m'} = ∫ f'(u', θ', ϕ') sȲ_{ℓ', m'}(θ', ϕ') sin(θ') dθ' dϕ' + # ψ2'(u')_{ℓ', m'} + abdprime.psi2 = spinsfast.map2salm(fprime_of_timeprime_directionprime[0], 0, output_ell_max) + # ψ3'(u, θ', ϕ') fprime_temp = -ψ4 fprime_temp *= ðuprime_over_k fprime_temp += ψ3 fprime_temp *= one_over_k_cubed - fprime_of_timenaught_directionprime[3] = fprime_temp + # This will store the values of f'(u', θ', ϕ') for the various functions `f` + fprime_of_timenaught_directionprime[0] = fprime_temp + # Interpolate the various transformed function values on the transformed grid from the original + # time coordinate to the new set of time coordinates, independently for each direction. + for i in range(n_theta): + for j in range(n_phi): + k_i_j = k[0, i, j] + α_i_j = α[0, i, j] + # u'(u, θ', ϕ') + timeprime_of_timenaught_directionprime_i_j = k_i_j * (u - α_i_j) + # f'(u', θ', ϕ') + fprime_of_timeprime_directionprime[:, :, i, j] = CubicSpline( + timeprime_of_timenaught_directionprime_i_j, fprime_of_timenaught_directionprime[:, :, i, j], axis=1 + )(timeprime) + # Transform back from the distorted grid to the SWSH mode weights as measured in that + # grid. I'll abuse notation slightly here by indicating those "distorted" mode weights with + # primes, so that f'(u')_{ℓ', m'} = ∫ f'(u', θ', ϕ') sȲ_{ℓ', m'}(θ', ϕ') sin(θ') dθ' dϕ' + # ψ3'(u')_{ℓ', m'} + abdprime.psi3 = spinsfast.map2salm(fprime_of_timeprime_directionprime[0], -1, output_ell_max) + # ψ4'(u, θ', ϕ') fprime_temp = ψ4.copy() fprime_temp *= one_over_k_cubed - fprime_of_timenaught_directionprime[4] = fprime_temp + # This will store the values of f'(u', θ', ϕ') for the various functions `f` + fprime_of_timenaught_directionprime[0] = fprime_temp + # Interpolate the various transformed function values on the transformed grid from the original + # time coordinate to the new set of time coordinates, independently for each direction. + for i in range(n_theta): + for j in range(n_phi): + k_i_j = k[0, i, j] + α_i_j = α[0, i, j] + # u'(u, θ', ϕ') + timeprime_of_timenaught_directionprime_i_j = k_i_j * (u - α_i_j) + # f'(u', θ', ϕ') + fprime_of_timeprime_directionprime[:, :, i, j] = CubicSpline( + timeprime_of_timenaught_directionprime_i_j, fprime_of_timenaught_directionprime[:, :, i, j], axis=1 + )(timeprime) + # Transform back from the distorted grid to the SWSH mode weights as measured in that + # grid. I'll abuse notation slightly here by indicating those "distorted" mode weights with + # primes, so that f'(u')_{ℓ', m'} = ∫ f'(u', θ', ϕ') sȲ_{ℓ', m'}(θ', ϕ') sin(θ') dθ' dϕ' + # ψ4'(u')_{ℓ', m'} + abdprime.psi4 = spinsfast.map2salm(fprime_of_timeprime_directionprime[0], -2, output_ell_max) + # σ'(u, θ', ϕ') fprime_temp = σ.copy() fprime_temp -= ððα fprime_temp *= one_over_k - fprime_of_timenaught_directionprime[5] = fprime_temp - - # Determine the new time slices. The set timeprime is chosen so that on each slice of constant - # u'_i, the average value of u=(u'/k)+α is precisely =u'γ+<α>=u_i. But then, we have to - # narrow that set down, so that every grid point on all the u'_i' slices correspond to data in - # the range of input data. - timeprime = (u - sf.constant_from_ell_0_mode(supertranslation[0]).real) / γ - timeprime_of_initialtime_directionprime = k * (u[0] - α) - timeprime_of_finaltime_directionprime = k * (u[-1] - α) - earliest_complete_timeprime = np.max(timeprime_of_initialtime_directionprime.view(np.ndarray)) - latest_complete_timeprime = np.min(timeprime_of_finaltime_directionprime.view(np.ndarray)) - timeprime = timeprime[(timeprime >= earliest_complete_timeprime) & (timeprime <= latest_complete_timeprime)] - # This will store the values of f'(u', θ', ϕ') for the various functions `f` - fprime_of_timeprime_directionprime = np.zeros((6, timeprime.size, n_theta, n_phi), dtype=complex) - + fprime_of_timenaught_directionprime[0] = fprime_temp # Interpolate the various transformed function values on the transformed grid from the original # time coordinate to the new set of time coordinates, independently for each direction. for i in range(n_theta): @@ -413,22 +508,10 @@ def transform(self, **kwargs): fprime_of_timeprime_directionprime[:, :, i, j] = CubicSpline( timeprime_of_timenaught_directionprime_i_j, fprime_of_timenaught_directionprime[:, :, i, j], axis=1 )(timeprime) - - # Finally, transform back from the distorted grid to the SWSH mode weights as measured in that + # Transform back from the distorted grid to the SWSH mode weights as measured in that # grid. I'll abuse notation slightly here by indicating those "distorted" mode weights with # primes, so that f'(u')_{ℓ', m'} = ∫ f'(u', θ', ϕ') sȲ_{ℓ', m'}(θ', ϕ') sin(θ') dθ' dϕ' - abdprime = type(self)(timeprime, output_ell_max) - # ψ0'(u')_{ℓ', m'} - abdprime.psi0 = spinsfast.map2salm(fprime_of_timeprime_directionprime[0], 2, output_ell_max) - # ψ1'(u')_{ℓ', m'} - abdprime.psi1 = spinsfast.map2salm(fprime_of_timeprime_directionprime[1], 1, output_ell_max) - # ψ2'(u')_{ℓ', m'} - abdprime.psi2 = spinsfast.map2salm(fprime_of_timeprime_directionprime[2], 0, output_ell_max) - # ψ3'(u')_{ℓ', m'} - abdprime.psi3 = spinsfast.map2salm(fprime_of_timeprime_directionprime[3], -1, output_ell_max) - # ψ4'(u')_{ℓ', m'} - abdprime.psi4 = spinsfast.map2salm(fprime_of_timeprime_directionprime[4], -2, output_ell_max) # σ'(u')_{ℓ', m'} - abdprime.sigma = spinsfast.map2salm(fprime_of_timeprime_directionprime[5], 2, output_ell_max) + abdprime.sigma = spinsfast.map2salm(fprime_of_timeprime_directionprime[0], 2, output_ell_max) return abdprime