Coverage for pygeodyn/augkf/analyser.py: 88%
329 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-22 13:43 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-22 13:43 +0000
1import numpy as np
2import os.path
3import pygeodyn
4import logging
5from scipy import linalg
6from mpi4py import MPI
7from .. import common
8from ..inout import observations as observations_module
9from ..generic.computer import GenericComputer
10from sklearn import covariance as skcov
11from .. import common
# Module-wide MPI communicator used for all the Allreduce gathers of this module,
# and the rank of the current process within it.
comm = MPI.COMM_WORLD
rank = comm.rank
class AugkfAnalyserAR1(GenericComputer):
    """
    Class that handles the analyses of the Augmented State Kalman Filter algorithm with DIFF treated as a contribution to ER.
    """

    def __init__(self, algo):
        """
        Sets all internal members, observation operator and covariance matrices for Z=UE.

        :param algo: Algorithm object
        :type algo: Algo
        """
        super().__init__(algo)

        # Date-based dicts for observations, stored in measure_observations dict
        # under the keys "MF" and "SV" (filled by extract_observations)
        self.measure_observations = {}

        nb_obs_mf, nb_obs_sv = self.extract_observations()
        # Each missing measure type is reported on its own: an empty MF set must not
        # hide the fact that the SV set is empty as well.
        if nb_obs_mf == 0:
            logging.error(
                "No observation was extracted for MF! Analyses on b will be completely skipped"
            )
        if nb_obs_sv == 0:
            logging.error(
                "No observation was extracted for SV! Analyses on u, e and d/dt b will be completely skipped"
            )
        if nb_obs_mf != 0 and nb_obs_sv != 0:
            logging.info(
                "Finished extracting the observations: MF ({}) and SV ({})!".format(
                    nb_obs_mf, nb_obs_sv
                )
            )

        # Dict to store the misfits of the current analysis, all initialised to NaN
        self.current_misfits = {}
        self.current_misfits = self.invalid_misfits(keys=["MF", "SV", "U", "ER", "Z"])

    def invalid_misfits(self, keys):
        """
        Set the misfits of the given keys to NaN (used when an analysis is skipped).

        :param keys: misfit keys to invalidate, each one of "MF", "SV", "U", "ER", "Z"
        :type keys: iterable of str
        :return: self.current_misfits updated in place when it exists, otherwise a new dict
            holding only the given keys
        :rtype: dict
        """
        keys = list(keys)  # allow any iterable, it is traversed twice below
        possible_keys = ["MF", "SV", "U", "ER", "Z"]
        err = 'invalid key passed in invalid_misfits(), keys: {}, possible keys: {}'
        assert all(k in possible_keys for k in keys), err.format(keys, possible_keys)

        # A missing instance attribute raises AttributeError (not NameError), so test for it explicitly
        misfits = getattr(self, "current_misfits", None)
        if misfits is None:
            return {k: np.nan for k in keys}
        for k in keys:
            misfits[k] = np.nan
        return misfits

    def extract_observations(self):
        """
        Extracts the observations for all obs types in the config. Updates the internal
        dictionary measure_observations for the keys "SV" and "MF".

        If no building function exists for self.cfg.obs_type, an error is logged and both
        measure types are left with no observation.

        :return: The numbers of dates for MF and SV for eventual checking.
        :rtype: int, int
        """
        # building function can be either build_go_vo_observations, build_covobs_observations
        # or build_covobs_hdf5_observations
        building_function = getattr(
            observations_module,
            "build_{}_observations".format(self.cfg.obs_type.lower()),
            None,
        )

        if building_function is None:
            logging.error(
                "No loading function was defined for {}. The extraction of observations was skipped for this type.".format(
                    self.cfg.obs_type
                )
            )
            # Calling None would crash: keep empty dicts so later lookups simply find no date
            for measure_type in ["SV", "MF"]:
                self.measure_observations[measure_type] = {}
            return 0, 0

        logging.info("Reading {} data as observations...".format(self.cfg.obs_type))
        for measure_type in ["SV", "MF"]:
            self.measure_observations[measure_type] = building_function(
                self.cfg, self.algo.nb_realisations, measure_type, self.algo.seed
            )
            if hasattr(self.cfg, "obs_mod_errors_dir"):
                self.add_modelisation_errors(
                    os.path.join(
                        pygeodyn._package_directory, self.cfg.obs_mod_errors_dir
                    ),
                    measure_type,
                )
        return (
            len(self.measure_observations["MF"].keys()),
            len(self.measure_observations["SV"].keys()),
        )

    def add_modelisation_errors(self, obs_mod_errors_dir, measure_type):
        """
        Add another contribution to the observation errors. The errors have to be stored in two text files,
        one for the MF and one for the SV, see for instance 'pygeodyn/data/observations/COVOBS-x2_maglat/mod_errors'.
        The first No values of the file are added to the diagonal of Rxx for every date,
        No being the size of Rxx.

        :param obs_mod_errors_dir: directory containing the "<measure_type>.err" file
        :type obs_mod_errors_dir: str
        :param measure_type: "MF" or "SV"
        :type measure_type: str
        """
        if self.cfg.obs_type == 'GO_VO':
            logging.warning('mod_error is made for synthetic data, not for GVO observatories')
            return None

        try:
            mod_errors = np.genfromtxt(
                os.path.join(obs_mod_errors_dir, "{}.err".format(measure_type))
            )
        except IOError:
            logging.error(
                "No {0} modelisation error file ({0}.err) was found in {1} ! Modelisation errors will be set to 0 for {0}.".format(
                    measure_type, obs_mod_errors_dir
                )
            )
            return None

        logging.debug(
            "Adding modelisation errors stored in {1} for {0}.".format(
                measure_type, obs_mod_errors_dir
            )
        )
        for date, obs in self.measure_observations[measure_type].items():
            No = len(obs.Rxx)
            # Diagnostic information (was a bare print executed on every MPI rank)
            logging.debug(
                "No: %s, date: %s, measure: %s, mod_errors shape: %s",
                No, date, measure_type, mod_errors.shape,
            )
            # obs is the object stored in measure_observations, so this updates it in place
            obs.Rxx = obs.Rxx + np.diag(mod_errors[:No])

    def sv_analysis(self):
        """
        Shortcut to check if sv_analysis is performed
        """
        # True only if SV observations exist at every time involved in the analysis
        return np.all(self.ana_sv)

    def mf_analysis(self):
        """
        Shortcut to check if mf_analysis is performed
        """
        # True only if MF observations exist at every time involved in the analysis
        return np.all(self.ana_mf)

    def _gather_observations(self, measure_type, times):
        """
        Return the observation data, operators and errors of measure_type at the given times.

        :param measure_type: "MF" or "SV"
        :param times: analysis iterations, used as str keys of measure_observations[measure_type]
        :return: lists of X, H and Rxx with one entry per time
        """
        observations = [self.measure_observations[measure_type][str(t)] for t in times]
        return ([obs.X for obs in observations],
                [obs.H for obs in observations],
                [obs.Rxx for obs in observations])

    def check_if_analysis_data(self, i_analysis):
        """
        check if there is mf and/or sv observation at next analysis time (times if AR3)
        and prepare observation data for either AR1 or AR3 analysis.

        For AR3 the times i_analysis, i_analysis + 1 and i_analysis + 2 are checked and
        sv_X/sv_H/sv_Rxx (resp. mf_*) are lists of 3 entries; for AR1 only i_analysis + 1
        is checked and they hold the single entry.

        :param i_analysis: analysis time iteration
        :type i_analysis: int
        :return: self.ana_mf, self.ana_sv (also stored as attributes)
        :rtype: np.array of bool, np.array of bool
        """
        AR_type = self.algo.config.AR_type
        is_ar3 = AR_type == "AR3"

        # number of times (Nt) involved in the analysis depends on AR_type
        Nt = 3 if is_ar3 else 1

        # one boolean per time: True if observations exist at that time
        self.ana_sv = np.zeros((Nt,), dtype=bool)
        self.ana_mf = np.zeros((Nt,), dtype=bool)

        if not i_analysis < self.algo.config.nb_analyses:
            return self.ana_mf, self.ana_sv

        times = [i_analysis, i_analysis + 1, i_analysis + 2] if is_ar3 else [i_analysis + 1]

        for i, t in enumerate(times):
            self.ana_sv[i] = str(t) in self.measure_observations["SV"]
            self.ana_mf[i] = str(t) in self.measure_observations["MF"]

        if self.sv_analysis():
            sv_X, sv_H, sv_Rxx = self._gather_observations("SV", times)
            if is_ar3:
                self.sv_X, self.sv_H, self.sv_Rxx = sv_X, sv_H, sv_Rxx
            else:
                self.sv_X, self.sv_H, self.sv_Rxx = sv_X[0], sv_H[0], sv_Rxx[0]

        if self.mf_analysis():
            mf_X, mf_H, mf_Rxx = self._gather_observations("MF", times)
            if is_ar3:
                self.mf_X, self.mf_H, self.mf_Rxx = mf_X, mf_H, mf_Rxx
            else:
                self.mf_X, self.mf_H, self.mf_Rxx = mf_X[0], mf_H[0], mf_Rxx[0]

        if not self.mf_analysis():
            logging.critical(
                "Skipping MF analysis"
            )
            self.current_misfits = self.invalid_misfits(keys=["MF"])
        if not self.sv_analysis():
            logging.critical(
                "Skipping SV analysis"
            )
            self.current_misfits = self.invalid_misfits(keys=["SV"])

        return self.ana_mf, self.ana_sv

    def analysis_step(self, input_core_state):
        """ Does the analysis at time t on the B and Z=[UE] part of the input_core_state.
        Updates SV = A(B)U + ER in consequence.

        :param input_core_state: Core state at time t
        :type input_core_state: corestates.CoreState (dim: nb_realisations x Ncorestate)
        :return: the analysed core state, restricted to the models attributed to this process
        :rtype: corestates.CoreState (dim: nb_realisations x Ncorestate)
        :raises ValueError: if the core state is not 2D or has the wrong number of realisations
        """
        if input_core_state.B.ndim != 2:
            raise ValueError("Corestate must be 2D but is ({}D)".format(input_core_state.B.ndim))

        if input_core_state.B.shape[0] != self.algo.nb_realisations:
            raise ValueError(
                "Number of realisations in input_core_state ({}) is not equal to expected number ({})".format(
                    input_core_state.B.shape[0], self.algo.nb_realisations))

        ana_core_state = input_core_state.copy()

        if self.mf_analysis():
            ana_core_state.B = self.analyse_B(ana_core_state.B, self.mf_X, self.mf_H, self.mf_Rxx)
        if self.sv_analysis():
            # uses the (possibly just analysed) B of ana_core_state to build A(B)
            (ana_core_state.Z,
             ana_core_state.U,
             ana_core_state.ER,
             ana_core_state.SV) = self.analyse_Z(ana_core_state, self.sv_X, self.sv_H, self.sv_Rxx)
        return ana_core_state[self.algo.attributed_models]

    def analyse_B(self, input_core_state, mf_X, mf_H, mf_Rxx):
        """
        Returns the analysed data for B by a BLUE (kalman_norm 'l2') or a Huber-norm estimate
        (kalman_norm 'huber') given the observations. Also stores the MF misfit.

        :param input_core_state: NumPy array containing the coefficient data of B
        :type input_core_state: np.array (dim: nb_realisations x Nb)
        :param mf_X: Observation data to use for the analysis, one row per realisation
        :type mf_X: np.array
        :param mf_H: Observation operator to use for the analysis
        :type mf_H: np.array
        :param mf_Rxx: Observation error covariance to use for the analysis
        :type mf_Rxx: np.array
        :return: NumPy array containing the analysed coefficient data of B
        :rtype: np.array (dim: nb_realisations x Nb)
        :raises ValueError: if self.cfg.kalman_norm is neither 'l2' nor 'huber'
        """
        Hb = mf_H
        Rbb = mf_Rxx
        # forecast covariance of B, possibly regularised (see remove_small_correlations)
        Pbb_forecast = self.remove_small_correlations(input_core_state)
        logging.debug("Getting best linear unbiased estimate of B...")
        # Each process fills its attributed realisations only; the others stay at 0 for the
        # MPI sum below (assumes attributed_models partition the realisations between processes)
        analysed_B = np.zeros((self.algo.nb_realisations, self.algo.config.Nb))

        if self.cfg.kalman_norm == 'l2':
            # the gain does not depend on the realisation: compute it once
            Kbb = common.compute_Kalman_gain_matrix(
                Pbb_forecast, Hb, Rbb
            )
            for i_real in self.algo.attributed_models:
                analysed_B[i_real] = common.get_BLUE(
                    input_core_state[i_real],
                    mf_X[i_real],
                    Pbb_forecast,
                    Hb,
                    Rbb,
                    Kbb,
                )

        elif self.cfg.kalman_norm == 'huber':  # for non least square norm, iterations are needed
            # compute inverse of P_bb before loop on reals using its symmetry
            P_eig_val, P_eig_vec = np.linalg.eigh(Pbb_forecast)
            # floor the eigenvalues in case the matrix is not full rank, which should not happen
            P_eig_val[P_eig_val < 1e-10] = 1e-10
            Pbb_inv = P_eig_vec @ np.diag(1 / P_eig_val) @ P_eig_vec.T
            for i_real in self.algo.attributed_models:
                analysed_B[i_real] = common.compute_Kalman_huber(input_core_state[i_real], mf_X[i_real],
                                                                 Pbb_inv, Hb, Rbb)
        else:
            raise ValueError('Invalid value of param kalman_norm, should be equal to huber or l2, got {}'.format(self.cfg.kalman_norm))

        comm.Allreduce(MPI.IN_PLACE, analysed_B, op=MPI.SUM)

        # Compute the misfits for B (Y - HX)
        HX_b = np.transpose(np.matmul(Hb, np.transpose(analysed_B)))
        self.current_misfits["MF"] = common.compute_misfit(
            mf_X, HX_b, linalg.inv(Rbb)
        )

        return analysed_B

    def setup_Hz(self, Ab, Nu):
        """
        Compute the matrix [Ab | compute_e * I_e] where Ab contains the Gaunt and Elsasser integrals,
        while I_e is the identity matrix of size Nsv.

        :param Ab: matrix A(B) (Nsv x Nu)
        :type Ab: numpy array
        :param Nu: dimension of the flow
        :type Nu: int
        :return: matrix Nsv x (Nu + Nsv)
        :rtype: numpy array
        """
        assert Ab.shape[1] == Nu, "mismatch in dimension of Ab {} and flow {} ".format(Ab.shape[1], Nu)
        Nsv = self.cfg.Nsv

        Hz = np.zeros((Nsv, Nu + Nsv), order="F")
        # Set the observation operator to A(B) for U
        Hz[:Nsv, :Nu] = Ab
        # and identity for E, scaled by cfg.compute_e
        Hz[:Nsv, Nu:Nu + Nsv] = np.identity(Nsv) * self.cfg.compute_e
        return Hz

    def remove_small_correlations(self, input_core_state, eps=1e-10):
        """
        Apply the graphical lasso to the correlation matrix. The correlation matrix is computed from
        the sample covariance matrix P (Pzz or Pbb in practice) with C[i, j] = P[i, j] / sqrt(P[i, i] P[j, j]).
        Warning: in some cases some variance elements can be zero, for instance if the initialisation
        parameter, core_state_init, is set to constant. Those zero variances are replaced by eps before
        computing the correlation matrix.

        If the glasso parameter, self.cfg.remove_spurious, is set to 0, the sample covariance matrix is
        returned; if it is set to np.inf, only the diagonal of the sample covariance matrix is kept.
        Otherwise the glasso is applied on the correlation matrix and the result is scaled back to a covariance.

        :param input_core_state: Corestate which can either be Z or B,
            at a given time for all realizations
        :param eps: value replacing the variances that are exactly zero
        :type eps: float
        :return: the regularised covariance matrix
        """
        # computation of the empirical covariance
        P_forecast = common.cov(input_core_state)

        if self.cfg.remove_spurious == 0:  # extreme case, keep the sample covariance matrix
            return P_forecast
        if self.cfg.remove_spurious == np.inf:  # other extreme case, take only the diagonal of the covariance matrix
            return np.diag(np.diag(P_forecast))

        if np.any(np.diag(P_forecast) == 0):
            # avoid division by zeros (exact zeros only, small numbers are left) by regularization;
            # with many zeros the graphical lasso may have a hard time to converge.
            zeros = np.squeeze(np.argwhere(np.diag(P_forecast) == 0))
            logging.warning('Some coefficients in the diagonal sample covariance matrix are very close to zero at {}'.format(zeros))
            diag_mask = np.zeros_like(P_forecast, dtype=bool)
            np.fill_diagonal(diag_mask, True)
            P_forecast[np.logical_and(P_forecast == 0, diag_mask)] = eps  # only modifies zeros present in the diagonal

        # regularised zero variances still give a 1 in the diagonal of the correlation matrix
        diag_sq_inv = lambda M: np.diag(1 / np.sqrt(np.diag(M)))
        C_forecast = diag_sq_inv(P_forecast) @ P_forecast @ diag_sq_inv(P_forecast)

        # Compute the lasso approximation
        C_lasso = skcov.graphical_lasso(C_forecast, self.cfg.remove_spurious, max_iter=100)[0]

        # scale the lasso correlation back to a covariance with the sample standard deviations
        diag_sq = lambda M: np.diag(np.sqrt(np.diag(M)))
        return diag_sq(P_forecast) @ C_lasso @ diag_sq(P_forecast)

    def analyse_Z(self, input_core_state, sv_X, sv_H, sv_Rxx):
        """
        Returns the analysed data for the augmented state Z = [U ER] and SV given the observations.
        Also stores the SV misfit.

        :param input_core_state: 2D CoreState containing the coefficient data
        :type input_core_state: CoreState
        :param sv_X: Observation data to use for the analysis, one row per realisation
        :type sv_X: np.array
        :param sv_H: Observation operator to use for the analysis
        :type sv_H: np.array
        :param sv_Rxx: Observation error covariance to use for the analysis
        :type sv_Rxx: np.array
        :return: 2D analysed Z U ER SV
        :rtype: 2D arrays (Nreal x Ncoef)
        """
        Pzz_forecast = self.remove_small_correlations(input_core_state.Z)

        nb_reals = self.algo.nb_realisations
        analysed_Z = np.zeros((nb_reals, self.algo.config.Nz))
        analysed_SV = np.zeros((nb_reals, self.algo.config.Nsv))
        analysed_ER = np.zeros((nb_reals, self.algo.config.Nsv))
        analysed_U = np.zeros((nb_reals, self.algo.config.Nu2))
        sv_X_real = np.zeros(sv_X.shape)
        # H Z of each realisation, for the misfit: H depends on the realisation through A(B),
        # so it must be applied realisation by realisation
        HX_z = np.zeros(sv_X.shape)
        # contribution of the prior mean of ER to the observations, the same for every realisation
        H_ER0 = sv_H @ self.algo.avg_prior["ER"]

        for i_real in self.algo.attributed_models:
            Ab = self.compute_Ab(input_core_state[i_real])

            # The complete H operator (No x Ncoefs of Z) is:
            # Hsv (No x Nsv) * Hz (Nsv x Ncoefs of Z)
            if self.algo.check_PCA():
                complete_H = sv_H @ self.setup_Hz(Ab @ self.cfg.pcaU_operator.S_u, self.cfg.N_pca_u)
            else:
                complete_H = sv_H @ self.setup_Hz(Ab, self.cfg.Nu2)

            PzzHT = Pzz_forecast @ complete_H.T
            HPzzHT = complete_H @ PzzHT

            # Z is centered on 0 so we must remove the mean from the observation data
            # Y = Ab (U+U0) + (ER + ER0) => Y - Ab U0 - ER0 = Ab U + ER
            sv_X_real[i_real] = sv_X[i_real] - (sv_H @ Ab @ self.algo.avg_prior["U"] + H_ER0)

            analysed_Z[i_real] = common.compute_Kalman_huber_parameter_basis(input_core_state.Z[i_real],
                                                                             sv_X_real[i_real],
                                                                             HPzzHT,
                                                                             PzzHT,
                                                                             complete_H,
                                                                             sv_Rxx)
            analysed_U[i_real], analysed_ER[i_real] = self.algo.Z_to_U_ER(analysed_Z[i_real], 1)

            analysed_SV[i_real] = Ab @ analysed_U[i_real] + analysed_ER[i_real]
            HX_z[i_real] = complete_H @ analysed_Z[i_real]

        # gather the realisations of all processes
        comm.Allreduce(MPI.IN_PLACE, analysed_Z, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, analysed_U, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, analysed_ER, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, analysed_SV, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, sv_X_real, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, HX_z, op=MPI.SUM)

        # Compute the misfits for SV (Y - HX)
        self.current_misfits["SV"] = common.compute_misfit(
            sv_X_real, HX_z, linalg.inv(sv_Rxx)
        )

        return analysed_Z, analysed_U, analysed_ER, analysed_SV
class AugkfAnalyserAR3(AugkfAnalyserAR1):
    """
    Analyser for the AR3 version of the Augmented State Kalman Filter: the analysis is performed on the
    states and their first and second time derivatives at the analysis time ta, using the observations
    at ta- = ta - dt_a, ta and ta+ = ta + dt_a.
    """

    def analysis_step(self, input_core_state, Z_AR3):
        """ Does the analysis at time t on the B and Z=[UE] part of the input_core_state.
        Updates SV = A(B)U + ER in consequence.

        :param input_core_state: Core state on the forecast times around the analysis time
        :type input_core_state: corestates.CoreState (3D: nb_realisations x nb_forecasts x Ncorestate)
        :param Z_AR3: AR3 forecast state
        :type Z_AR3: np.array (dim: nb_realisations x 3 x Nz)
        :return: the analysed core state and updated Z_AR3, restricted to the models attributed to this process
            (Z_AR3 is returned unchanged when no SV analysis is performed)
        :rtype: corestates.CoreState (dim: nb_realisations x Ncorestate), np.array (dim: nb_realisations x 3 x Nz)
        :raises ValueError: if the core state is not 3D or has the wrong number of realisations
        """
        if input_core_state.B.ndim != 3:
            raise ValueError("Corestate must be 3D but is ({}D)".format(input_core_state.B.ndim))

        if input_core_state.B.shape[0] != self.algo.nb_realisations:
            raise ValueError(
                "Number of realisations in input_core_state ({}) is not equal to expected number ({})".format(
                    input_core_state.B.shape[0], self.algo.nb_realisations
                )
            )

        ratio = self.algo.config.dt_a_f_ratio

        # forecast state at the analysis time (index ratio of the forecast times 0 .. 2*ratio)
        ana_core_state = input_core_state[:, ratio].copy()

        if not self.sv_analysis() and not self.mf_analysis():  # no analysis of mf and sv
            return ana_core_state[self.algo.attributed_models], Z_AR3

        if self.mf_analysis():
            # compute the full state B B' B'' at analysis time by fitting B forecast curves
            B_full = self.compute_full_state(input_core_state.B)
            # analysis on B B' B'' returning b at ta-, ta and ta+
            b_minus, b, b_plus = self.analyse_B(B_full, self.mf_X, self.mf_H, self.mf_Rxx)
            ana_core_state.B = b
        else:
            # pick b_minus b b_plus from the forecast values
            b_minus = input_core_state.B[:, 0]
            b = input_core_state.B[:, ratio]
            b_plus = input_core_state.B[:, 2 * ratio]

        if not self.sv_analysis():
            return ana_core_state[self.algo.attributed_models], Z_AR3

        # compute the full state Z Z' Z'' at analysis time by fitting Z forecast curves
        Z_full = self.compute_full_state(input_core_state.Z)

        (ana_Z,
         ana_dZ,
         ana_d2Z,
         ana_core_state.U,
         ana_core_state.ER,
         ana_core_state.SV) = self.analyse_Z(Z_full, self.sv_X, self.sv_H, self.sv_Rxx,
                                             b_minus, b, b_plus, input_core_state[0, ratio])

        ana_core_state.Z = ana_Z
        ana_core_state.dUdt, ana_core_state.dERdt = self.algo.dZ_to_dU_dER(ana_dZ)
        ana_core_state.d2Udt2, ana_core_state.d2ERdt2 = self.algo.dZ_to_dU_dER(ana_d2Z)

        Z_AR3 = self.init_forecast(ana_Z, ana_dZ, ana_d2Z)

        return ana_core_state[self.algo.attributed_models], Z_AR3[self.algo.attributed_models]

    def build_H_operator(self, X, row_diff=False):
        """
        Build the operator that maps (Y, Y', Y'') at ta to X applied to Y at ta - dt_a, ta and ta + dt_a
        with a second order Taylor expansion:

            [[X-, -dt_a X-, dt_a**2/2 X-],
             [X0,  0,        0          ],
             [X+,  dt_a X+,  dt_a**2/2 X+]]

        :param X: a single 2D operator used for the three times, or (row_diff=True) a sequence of the
            operators X-, X0, X+ at ta-, ta and ta+
        :param row_diff: True if X holds one operator per time, False if the same X is used everywhere
        :return: the block operator
        :rtype: 2D np array
        """
        dt_a = self.algo.config.dt_a
        if row_diff:
            X_minus, X_0, X_plus = X[0], X[1], X[2]
        else:
            X_minus = X_0 = X_plus = X
        o = np.zeros(X_0.shape)
        return np.concatenate(
            (np.concatenate((X_minus, -dt_a * X_minus, dt_a**2 / 2 * X_minus), axis=1),
             np.concatenate((X_0, o, o), axis=1),
             np.concatenate((X_plus, dt_a * X_plus, dt_a**2 / 2 * X_plus), axis=1)),
            axis=0,
        )

    def build_obs_operator(self, obs):
        """
        Concatenate the observation data at ta-, ta and ta+ along the observation axis.

        :param obs: sequence of the 3 observation data arrays (realisations x observations)
        :return: 2D array (realisations x total number of observations)
        """
        return np.concatenate((obs[0], obs[1], obs[2]), axis=1)

    def build_err_operator(self, error):
        """
        Build the block-diagonal observation error covariance of the observations at ta-, ta and ta+.

        :param error: sequence of the 3 observation error covariance matrices
        :return: block diagonal matrix
        """
        return linalg.block_diag(error[0], error[1], error[2])

    def compute_full_state(self, X_state):
        """
        Least-square fit of X_ta, dX_ta and d2X_ta from the forecasts around the analysis time.

        The second order expansion X(ta + t) = X + t X' + t**2/2 X'' is fitted (l2 norm) on the
        2*ratio+1 forecast times t = -ratio*dt_f, ..., ratio*dt_f.

        :param X_state: X_state
        :type X_state: 3D np array (dim: n_reals x nb_forecasts x ncoefs)
        :return: Xfull = [X, X', X''] for each realisation, gathered over all processes
        :rtype: 2D np array (dim: n_reals x 3*ncoefs)
        """
        # np.math was only an alias of the standard math module and is removed in NumPy 2.0
        from math import factorial

        N_coef = X_state.shape[2]
        ratio = self.algo.config.dt_a_f_ratio
        dt_f = self.algo.config.dt_f

        dt = (np.arange(start=-ratio, stop=ratio + 1, step=1) * dt_f).reshape(-1, 1)

        # build G operator, columns t**k / k! for k = 0 .. max_order
        max_order = 2
        G = np.concatenate([np.power(dt, k) / factorial(k) for k in range(max_order + 1)], axis=1)

        # least-square operator (G^T G)^-1 G^T, the same for every realisation
        G_lsq = np.linalg.inv(G.T @ G) @ G.T

        X_full = np.zeros((self.algo.nb_realisations, 3 * N_coef))
        for i_real in self.algo.attributed_models:
            S = G_lsq @ X_state[i_real]
            # rows of S are X, X', X'': flatten them one after the other
            X_full[i_real] = S[:3].flatten()

        comm.Allreduce(MPI.IN_PLACE, X_full, op=MPI.SUM)

        return X_full

    def analyse_B(self, B_full, mf_X, mf_H, mf_Rxx):
        """
        Returns the analysed B, B', B'' (BLUE for kalman_norm 'l2', Huber-norm estimate for 'huber')
        given the observations at ta-, ta and ta+, converted to b at ta-, ta and ta+.
        Also stores the MF misfit.

        :param B_full: B, B' and B'' of each realisation (as returned by compute_full_state)
        :type B_full: np.array (dim: nb_realisations x 3*Nb)
        :param mf_X: Observation data at ta-, ta and ta+
        :type mf_X: list of 3 np.array
        :param mf_H: Observation operators at ta-, ta and ta+
        :type mf_H: list of 3 np.array
        :param mf_Rxx: Observation error covariances at ta-, ta and ta+
        :type mf_Rxx: list of 3 np.array
        :return: b_minus, b, b_plus
        :rtype: np.array (dim: nb_realisations x Nb), np.array (dim: nb_realisations x Nb), np.array (dim: nb_realisations x Nb)
        :raises ValueError: if self.cfg.kalman_norm is neither 'l2' nor 'huber'
        """
        Nb = self.cfg.Nb

        Hb = self.build_H_operator(mf_H, row_diff=True)
        Yb = self.build_obs_operator(mf_X)
        Rbb = self.build_err_operator(mf_Rxx)
        # compute Pbb from B_full state (B B' B'')
        Pbb = self.remove_small_correlations(B_full)

        logging.debug("Getting best linear unbiased estimate of B...")
        analysed_B = np.zeros((self.algo.nb_realisations, 3 * Nb))

        if self.cfg.kalman_norm == 'l2':
            Kbb = common.compute_Kalman_gain_matrix(
                Pbb, Hb, Rbb
            )
            for i_real in self.algo.attributed_models:
                analysed_B[i_real] = common.get_BLUE(
                    B_full[i_real],
                    Yb[i_real],
                    Pbb,
                    Hb,
                    Rbb,
                    K=Kbb,
                )

        elif self.cfg.kalman_norm == 'huber':  # for non least square norm, iterations are needed
            # compute inverse of P_bb before loop on reals using its symmetry
            P_eig_val, P_eig_vec = np.linalg.eigh(Pbb)
            # floor the eigenvalues in case the matrix is not full rank, which should not happen
            P_eig_val[P_eig_val < 1e-10] = 1e-10
            Pbb_inv = P_eig_vec @ np.diag(1 / P_eig_val) @ P_eig_vec.T
            for i_real in self.algo.attributed_models:
                analysed_B[i_real] = common.compute_Kalman_huber(B_full[i_real], Yb[i_real],
                                                                 Pbb_inv, Hb, Rbb)
        else:
            raise ValueError('Invalid value of param kalman_norm, should be equal to huber or l2, got {}'.format(self.cfg.kalman_norm))

        # gather process
        comm.Allreduce(MPI.IN_PLACE, analysed_B, op=MPI.SUM)

        # Compute the misfits for B (Y - HX)
        HX_b = np.transpose(np.matmul(Hb, np.transpose(analysed_B)))
        self.current_misfits["MF"] = common.compute_misfit(
            Yb, HX_b, linalg.inv(Rbb)
        )

        # compute b at ta-, ta and ta+ from B, B', B''
        B_a = analysed_B @ self.build_H_operator(np.identity(Nb)).T

        return B_a[:, :Nb], B_a[:, Nb:2 * Nb], B_a[:, 2 * Nb:]

    def build_full_Hz(self, input_core_state, b_minus, b, b_plus, sv_H):
        """
        Build the observation matrices H for Z at ta-, ta and ta+.

        Note: input_core_state.B is overwritten (successively by b_minus, b and b_plus) to compute A(B).

        :return: list of the 3 observation matrices and list of the 3 A(B) matrices
        """
        H = []
        Ab_out = []
        for i, b_state in enumerate((b_minus, b, b_plus)):
            input_core_state.B = b_state
            Ab = self.compute_Ab(input_core_state)
            if self.algo.check_PCA():
                h = sv_H[i] @ self.setup_Hz(Ab @ self.cfg.pcaU_operator.S_u, self.cfg.N_pca_u)
            else:
                h = sv_H[i] @ self.setup_Hz(Ab, self.cfg.Nu2)
            H.append(h)
            Ab_out.append(Ab)
        return H, Ab_out

    def analyse_Z(self, Z_full, sv_X, sv_H, sv_Rxx, b_minus, b, b_plus, input_core_state):
        """
        Returns the analysed data for the augmented state Z and SV that maximizes the
        likelihood (of the Gaussian or Huber distribution). Also stores the SV misfit.

        :param Z_full: Z, Z' and Z'' of each realisation (as returned by compute_full_state)
        :type Z_full: np.array (dim: nb_realisations x 3*Nz)
        :param sv_X: Observation data at ta-, ta and ta+
        :type sv_X: list of 3 np.array
        :param sv_H: Observation operators at ta-, ta and ta+
        :type sv_H: list of 3 np.array
        :param sv_Rxx: Observation error covariances at ta-, ta and ta+
        :type sv_Rxx: list of 3 np.array
        :param b_minus: b at ta-
        :type b_minus: np.array (dim: nb_realisations x Nb)
        :param b: b at ta
        :type b: np.array (dim: nb_realisations x Nb)
        :param b_plus: b at ta+
        :type b_plus: np.array (dim: nb_realisations x Nb)
        :param input_core_state: core state used as template to compute A(B); its B is overwritten
        :type input_core_state: CoreState
        :return: ana_Z, ana_dZ, ana_d2Z, ana_U, ana_ER, ana_SV
        :rtype: np.array (dim: nb_realisations x Ncoef)
        """
        Nz = self.cfg.Nz

        Rzz = self.build_err_operator(sv_Rxx)
        PZZ = self.remove_small_correlations(Z_full)

        nb_reals = self.algo.nb_realisations
        analysed_Z = np.zeros((nb_reals, 3 * Nz))
        analysed_SV = np.zeros((nb_reals, self.algo.config.Nsv))
        analysed_ER = np.zeros((nb_reals, self.algo.config.Nsv))
        analysed_U = np.zeros((nb_reals, self.algo.config.Nu2))
        sv_X_real = [np.zeros(sv_X[i].shape) for i in range(3)]
        Yz = np.zeros((nb_reals, sv_X[0].shape[1] + sv_X[1].shape[1] + sv_X[2].shape[1]))
        # H Z of each realisation, for the misfit: Hz depends on the realisation through A(b)
        HX_z = np.zeros(Yz.shape)
        # contribution of the prior mean of ER to the observations at each time
        H_ER0 = [sv_H[i] @ self.algo.avg_prior["ER"] for i in range(3)]

        for i_real in self.algo.attributed_models:
            H, Ab = self.build_full_Hz(input_core_state, b_minus[i_real], b[i_real], b_plus[i_real], sv_H)
            Hz = self.build_H_operator(H, row_diff=True)

            # Z is centered on 0 so we must remove the mean from the observation data
            # Y = Ab (U+U0) + (ER + ER0) => Y - Ab U0 - ER0 = Ab U + ER
            for i in range(3):
                sv_X_real[i][i_real] = sv_X[i][i_real] - (sv_H[i] @ Ab[i] @ self.algo.avg_prior["U"] + H_ER0[i])

            Yz[i_real] = np.concatenate((sv_X_real[0][i_real], sv_X_real[1][i_real], sv_X_real[2][i_real]), axis=0)

            PzzHT = PZZ @ Hz.T
            HPzzHT = Hz @ PzzHT

            analysed_Z[i_real] = common.compute_Kalman_huber_parameter_basis(Z_full[i_real],
                                                                             Yz[i_real],
                                                                             HPzzHT,
                                                                             PzzHT,
                                                                             Hz,
                                                                             Rzz)

            analysed_U[i_real], analysed_ER[i_real] = self.algo.Z_to_U_ER(analysed_Z[i_real, :Nz], 1)

            # SV at ta, with A(b) at ta
            analysed_SV[i_real] = Ab[1] @ analysed_U[i_real] + analysed_ER[i_real]
            HX_z[i_real] = Hz @ analysed_Z[i_real]

        # gather process
        comm.Allreduce(MPI.IN_PLACE, analysed_U, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, analysed_ER, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, analysed_SV, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, analysed_Z, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, Yz, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, HX_z, op=MPI.SUM)

        # Compute the misfits for SV (Y - HX)
        self.current_misfits["SV"] = common.compute_misfit(
            Yz, HX_z, linalg.inv(Rzz)
        )

        # extract ana_Z ana_dZ ana_d2Z
        ana_Z = analysed_Z[:, :Nz]
        ana_dZ = analysed_Z[:, Nz:2 * Nz]
        ana_d2Z = analysed_Z[:, 2 * Nz:]

        return ana_Z, ana_dZ, ana_d2Z, analysed_U, analysed_ER, analysed_SV

    def init_forecast(self, Z, dZ, d2Z):
        """
        init ZAR3 state for forecast from a second order expansion of Z at the analysis time:
        index 2 holds Z, index 1 Z extrapolated at -dt_f and index 0 Z extrapolated at -2*dt_f.

        :param Z: Z state
        :type Z: 2D np array (dim: n_reals x nz)
        :param dZ: dZ state
        :type dZ: 2D np array (dim: n_reals x nz)
        :param d2Z: d2Z state
        :type d2Z: 2D np array (dim: n_reals x nz)
        :return: init_Z
        :rtype: 3D np array (dim: n_reals x 3 x nz)
        """
        init_Z = np.zeros((self.algo.nb_realisations, 3, self.algo.config.Nz))
        dt_f = self.algo.config.dt_f

        init_Z[:, 2] = Z
        init_Z[:, 1] = Z - dt_f * dZ + dt_f**2 / 2 * d2Z
        init_Z[:, 0] = Z - (2 * dt_f) * dZ + (2 * dt_f)**2 / 2 * d2Z

        return init_Z