Coverage for pygeodyn/augkf/forecaster.py: 100%
56 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-22 13:43 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-22 13:43 +0000
1import numpy as np
2from .. import common
3from ..generic.computer import GenericComputer
4from pygeodyn.corestates import with_core_state_of_dimensions
5from mpi4py import MPI
6from .. import corestates as cs
# Module-level MPI handles: the world communicator and the rank of the current process.
# `rank` is used below to prefix the progress lines printed by parallel_forecast_step.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
class AugkfForecasterAR1(GenericComputer):
    """
    Forecaster of the AugKF (Augmented state Kalman Filter) algorithm, with DIFF treated as a
    contribution to ER. The Z state is propagated with an AR-1 process.
    """

    def __init__(self, algo):
        """
        Build the forecaster and check that every prior covariance matrix it needs is available.

        :param algo: Algorithm object
        :type algo: Algo
        :raises KeyError: if a matrix listed by needed_covariance_matrices() is absent from algo.cov_prior
        """
        super().__init__(algo)

        error_template = '{} matrix not found in covariance matrices while needed for Forecast !'
        for matrix_name in self.needed_covariance_matrices():
            if matrix_name in algo.cov_prior:
                continue
            # Fail on the first missing matrix rather than later, in the middle of a forecast
            raise KeyError(error_template.format(matrix_name))

        # Bool to deactivate checks on AR processes
        self.Cholesky_AR_check = False

    def needed_covariance_matrices(self):
        """
        List the keys of algo.cov_prior that the forecast step reads.

        :return: names of the needed covariance matrices
        :rtype: list of str
        """
        return ['A', 'Chol']

    @with_core_state_of_dimensions(1)
    def forecast_step(self, input_core_state, Z_AR, seed, i_real, i_t):
        """
        Forecasts the input_core_state: AR process for Z, then U and ER from Z, then SV, and
        finally an Euler step for B.

        :param input_core_state: core_state of a single realisation at a single date
        :type input_core_state: corestates.CoreState
        :param Z_AR: forecast state AR
        :type Z_AR: np.array 1D if AR1 (Ncoef) or 2D if AR3 (3 x Ncoef)
        :param seed: random seed
        :type seed: int
        :param i_real: model realisation index
        :type i_real: int
        :param i_t: time index
        :type i_t: int
        :return: CoreState containing the result from the forecast
        :rtype: corestates.CoreState
        """
        # Work on a copy so that the state at t is left untouched
        forecasted = input_core_state.copy()

        # One generator per (realisation, time) pair: the offset i_real + N * i_t is added to the base seed
        N = self.algo.nb_realisations
        generator = np.random.default_rng(seed + i_real + N * i_t)

        # Z(t+1), from which U(t+1) and ER(t+1) are derived
        forecasted.Z = self.forecast_Z(Z_AR, rstate=generator)
        flow, subgrid_errors = self.algo.Z_to_U_ER(forecasted.Z, 1)
        forecasted.U = flow
        forecasted.ER = subgrid_errors

        # A(b), evaluated on the state being built
        Ab = self.compute_Ab(forecasted)

        # SV(t+1) = A(b)U(t+1) + E(t+1)
        advection = np.matmul(Ab, forecasted.U)
        forecasted.SV = advection + forecasted.ER

        # B(t+1) = B(t) + dt*SV(t+1) (Euler scheme as SV=dB/dt)
        forecasted.B = input_core_state.B + self.cfg.dt_f * forecasted.SV

        return forecasted

    def forecast_Z(self, Z_AR1, rstate=None):
        """
        Forecast Z state with AR-1 process.

        :param Z_AR1: AR-1 forecast state
        :type Z_AR1: np.array (Ncoef)
        :param rstate: random generator handed to the AR-1 process (forecast_step passes the
            np.random.Generator it creates with np.random.default_rng)
        :type rstate: np.random.Generator or None
        :return: forecasted Z state
        :rtype: np.array(Ncoef)
        """
        priors = self.algo.cov_prior
        return common.ar1_process(Z_AR1, priors['A'], priors['Chol'], random_state=rstate)

    def parallel_forecast_step(self, input_core_state, seed, i_t):
        """
        Run the AR1 forecast step for every model attributed to the current process.

        :param input_core_state: input_core_state at time t
        :type input_core_state: corestates.CoreState
        :param seed: random seed
        :type seed: int
        :param i_t: time index
        :type i_t: int
        :return: CoreState containing the result from the forecast
        :rtype: corestates.CoreState
        """
        t = self.algo.config.t_forecasts[i_t]

        # Start from a zeroed copy of the input: same layout, every measure reset
        forecast_at_t = input_core_state.copy()
        forecast_at_t[:] = 0

        # Each process computes its attributed models
        # NOTE(review): the model id yielded by attributed_models is not used; the position in
        # the list is what indexes the states and offsets the seed - confirm this is intended.
        for i_idx, _model in enumerate(self.algo.attributed_models):
            state_at_t = input_core_state[i_idx]
            z_at_t = input_core_state.Z[i_idx]
            forecast_at_t[i_idx] = self.forecast_step(state_at_t, z_at_t, seed, i_idx, i_t)

        # NOTE(review): assumed to sit after the loop (one progress line per call, showing the
        # last realisation handled) - confirm against the original indentation.
        print('P' + str(rank), t, cs.coef_print(forecast_at_t[i_idx], 0))

        return forecast_at_t
class AugkfForecasterAR3(AugkfForecasterAR1):
    """
    AR-3 variant of AugkfForecasterAR1: Z is forecast with common.ar3_process from its three
    lagged states, and the lag window is shifted after each parallel forecast step.
    """

    def needed_covariance_matrices(self):
        """
        List the keys of algo.cov_prior that the AR-3 forecast step reads.

        :return: names of the needed covariance matrices
        :rtype: list of str
        """
        return ['A', 'B', 'C', 'Chol']

    def forecast_Z(self, Z_AR3, rstate=None):
        """
        Forecast Z state with AR-3 process.

        :param Z_AR3: AR-3 forecast state of a single realisation
        :type Z_AR3: np.array (3 x Ncoef)
        :param rstate: random generator handed to the AR-3 process
        :type rstate: np.random.Generator or None
        :return: forecasted Z state
        :rtype: np.array(Ncoef)
        """
        # AR3 process for Z
        return common.ar3_process(Z_AR3,
                                  self.algo.cov_prior['A'],
                                  self.algo.cov_prior['B'],
                                  self.algo.cov_prior['C'],
                                  self.algo.cov_prior['Chol'],
                                  random_state=rstate)

    def update_Z_AR3(self, Z, Z_AR3):
        """
        Shift the AR-3 lag window: append the newly forecasted Z and drop the oldest lag.

        :param Z: forecasted Z
        :type Z: np.array (nreal x Ncoef)
        :param Z_AR3: AR-3 forecast state, lags stored along axis 1
        :type Z_AR3: np.array (nreal x 3 x Ncoef)
        :return: updated AR-3 forecast state
        :rtype: np.array (nreal x 3 x Ncoef)
        :raises ValueError: if the shifted window does not hold exactly 3 lags
        """
        # Append Z as a new lag on axis 1, then discard the oldest one
        extended = np.concatenate((Z_AR3, Z[:, np.newaxis, :]), axis=1)
        shifted = extended[:, 1:, :]

        # Explicit check instead of `assert`, so it is still enforced when Python runs with -O
        if shifted.shape[1] != 3:
            raise ValueError(
                'Z_AR3 must hold exactly 3 lagged states along axis 1, got {}'.format(shifted.shape[1]))
        return shifted

    def parallel_forecast_step(self, input_core_state, Z_AR3, seed, i_t):
        """
        Run the AR3 forecast step for every model attributed to the current process, then
        shift the AR-3 lag window with the new Z.

        :param input_core_state: input_core_state at time t
        :type input_core_state: corestates.CoreState
        :param Z_AR3: AR-3 forecast state, one (3 x Ncoef) history per realisation
        :type Z_AR3: np.array (nreal x 3 x Ncoef)
        :param seed: random seed
        :type seed: int
        :param i_t: time index
        :type i_t: int
        :return: CoreState containing the result from the forecast, and the updated AR-3 forecast state
        :rtype: (corestates.CoreState, np.array (nreal x 3 x Ncoef))
        """
        t = self.algo.config.t_forecasts[i_t]

        # copy input core state
        forecast_at_t = input_core_state.copy()
        # set all measures to 0
        forecast_at_t[:] = 0

        # Each process computes its attributed models
        # NOTE(review): i_model is not used; the position in the list is what indexes the
        # states and offsets the seed - confirm this is intended.
        for i_idx, i_model in enumerate(self.algo.attributed_models):
            forecast_at_t[i_idx] = self.forecast_step(input_core_state[i_idx], Z_AR3[i_idx], seed, i_idx, i_t)

        # NOTE(review): assumed to sit after the loop (one progress line per call, showing the
        # last realisation handled) - confirm against the original indentation.
        print('P'+str(rank), t, cs.coef_print(forecast_at_t[i_idx], 0))

        # update Z_AR3
        Z_AR3 = self.update_Z_AR3(forecast_at_t.Z, Z_AR3)

        return forecast_at_t, Z_AR3