Coverage for pygeodyn/augkf/analyser.py: 88%

329 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-22 13:43 +0000

1import numpy as np 

2import os.path 

3import pygeodyn 

4import logging 

5from scipy import linalg 

6from mpi4py import MPI 

7from .. import common 

8from ..inout import observations as observations_module 

9from ..generic.computer import GenericComputer 

10from sklearn import covariance as skcov 

11from .. import common 

12 

# Global MPI communicator shared by every analyser in this module; realisations
# are split across processes and results are merged with Allreduce.
comm = MPI.COMM_WORLD
# Rank (index) of the current MPI process within COMM_WORLD.
rank = MPI.COMM_WORLD.Get_rank()

15 

class AugkfAnalyserAR1(GenericComputer):
    """
    Class that handles the analyses of the Augmented State Kalman Filter algorithm with DIFF treated as a contribution to ER.

    Observations are read once at construction. Before each analysis,
    :meth:`check_if_analysis_data` selects the observations available at the
    analysis time; :meth:`analysis_step` then analyses B (with MF observations)
    and Z = [U ER] (with SV observations). The misfits of the latest analysis
    are kept in ``self.current_misfits``.
    """

    # Keys accepted by invalid_misfits() and stored in current_misfits
    MISFIT_KEYS = ("MF", "SV", "U", "ER", "Z")

    def __init__(self, algo):
        """
        Sets all internal members, observation operator and covariance matrices for Z=UE.

        :param algo: Algorithm object
        :type algo: Algo
        """
        super().__init__(algo)

        # Observations per measure type ("MF", "SV"); each entry is a dict whose
        # keys are looked up as str(analysis index) in check_if_analysis_data
        self.measure_observations = {}

        nb_obs_mf, nb_obs_sv = self.extract_observations()
        # Report each missing measure type independently: both may be missing
        if nb_obs_mf == 0:
            logging.error(
                "No observation was extracted for MF! Analyses on b will be completely skipped"
            )
        if nb_obs_sv == 0:
            logging.error(
                "No observation was extracted for SV! Analyses on u, e and d/dt b will be completely skipped"
            )
        if nb_obs_mf and nb_obs_sv:
            logging.info(
                "Finished extracting the observations: MF ({}) and SV ({})!".format(
                    nb_obs_mf, nb_obs_sv
                )
            )

        # Dict to store the misfits of the current analysis, initialised to NaN
        self.current_misfits = {}
        self.current_misfits = self.invalid_misfits(keys=list(self.MISFIT_KEYS))

    def invalid_misfits(self, keys):
        """
        Set the misfits of the given keys to NaN (meaning "no valid analysis").

        :param keys: misfit keys to invalidate, each among MISFIT_KEYS
        :type keys: iterable of str
        :return: ``self.current_misfits`` updated in place, or a new dict of NaNs
            if ``current_misfits`` does not exist yet
        :rtype: dict
        """
        keys = list(keys)
        possible_keys = list(self.MISFIT_KEYS)
        err = 'invalid key passed in invalid_misfits(), keys: {}, possible keys: {}'
        assert all(k in possible_keys for k in keys), err.format(keys, possible_keys)

        # A missing instance attribute raises AttributeError (not NameError),
        # so test for it explicitly rather than relying on an except clause.
        misfits = getattr(self, "current_misfits", None)
        if misfits is None:
            return {k: np.nan for k in keys}
        for k in keys:
            misfits[k] = np.nan
        return misfits

    def extract_observations(self):
        """
        Extracts the observations for all obs types in the config. Updates the internal dict measure_observations
        (entries "SV" and "MF").

        If no building function exists for the configured obs type, an error is logged,
        empty observation dicts are stored and (0, 0) is returned.

        :return: The numbers of dates for MF and SV for eventual checking.
        :rtype: int, int
        """

        # building function can be either build_go_vo_observations, build_covobs_observations
        # or build_covobs_hdf5_observations (resolved from cfg.obs_type)
        building_function = getattr(
            observations_module,
            "build_{}_observations".format(self.cfg.obs_type.lower()),
            None,
        )

        if building_function is None:
            logging.error(
                "No loading function was defined for {}. The extraction of observations was skipped for this type.".format(
                    self.cfg.obs_type
                )
            )
            # Actually skip the extraction (calling None would raise TypeError):
            # empty dicts make every later lookup find no observation.
            self.measure_observations["SV"] = {}
            self.measure_observations["MF"] = {}
            return 0, 0

        logging.info("Reading {} data as observations...".format(self.cfg.obs_type))

        for measure_type in ["SV", "MF"]:
            self.measure_observations[measure_type] = building_function(
                self.cfg, self.algo.nb_realisations, measure_type, self.algo.seed
            )
            if hasattr(self.cfg, "obs_mod_errors_dir"):
                self.add_modelisation_errors(
                    os.path.join(
                        pygeodyn._package_directory, self.cfg.obs_mod_errors_dir
                    ),
                    measure_type,
                )
        return (
            len(self.measure_observations["MF"]),
            len(self.measure_observations["SV"]),
        )

    def add_modelisation_errors(self, obs_mod_errors_dir, measure_type):
        """
        Add another contribution to the observation errors. The errors have to be stored in two text files,
        one for the MF and one for the SV, see for instance 'pygeodyn/data/observations/COVOBS-x2_maglat/mod_errors'

        The first No values of ``<measure_type>.err`` are added to the diagonal of the
        (No x No) error matrix Rxx of every observation of that measure type.
        Nothing is done for GO_VO observations or if the file cannot be read.

        :param obs_mod_errors_dir: directory containing the MF.err and SV.err files
        :type obs_mod_errors_dir: str
        :param measure_type: "MF" or "SV"
        :type measure_type: str
        """
        if self.cfg.obs_type == 'GO_VO':
            logging.warning('mod_error is made for synthetic data, not for GVO observatories')
            return None

        err_file = os.path.join(obs_mod_errors_dir, "{}.err".format(measure_type))
        try:
            mod_errors = np.genfromtxt(err_file)
        except IOError:
            logging.error(
                "No {0} modelisation error file ({0}.err) was found in {1} ! Modelisation errors will be set to 0 for {0}.".format(
                    measure_type, obs_mod_errors_dir
                )
            )
            return

        logging.debug(
            "Adding modelisation errors stored in {1} for {0}.".format(
                measure_type, obs_mod_errors_dir
            )
        )
        for date, obs in self.measure_observations[measure_type].items():
            No = len(obs.Rxx)
            # Was a leftover print(); keep the diagnostic but route it through logging
            logging.debug('No, measure, mod_errors %s %s %s %s', No, date, measure_type, mod_errors.shape)
            obs.Rxx = obs.Rxx + np.diag(mod_errors[:No])

    def sv_analysis(self):
        """
        Shortcut to check if sv_analysis is performed, i.e. if SV observations
        are available at every time of the current analysis (self.ana_sv all True).
        """
        return np.all(self.ana_sv)

    def mf_analysis(self):
        """
        Shortcut to check if mf_analysis is performed, i.e. if MF observations
        are available at every time of the current analysis (self.ana_mf all True).
        """
        return np.all(self.ana_mf)

    @staticmethod
    def _collect_observations(obs_dict, keys, as_lists):
        """
        Gather (X, H, Rxx) of the observations stored under ``keys``.

        :param obs_dict: observations of one measure type, keyed by str(analysis index)
        :param keys: keys of the analysis times
        :param as_lists: if True return lists (one entry per key), else the arrays of keys[0]
        :return: X, H, Rxx
        """
        selected = [obs_dict[k] for k in keys]
        if not as_lists:
            obs = selected[0]
            return obs.X, obs.H, obs.Rxx
        return ([o.X for o in selected],
                [o.H for o in selected],
                [o.Rxx for o in selected])

    def check_if_analysis_data(self, i_analysis):
        """
        check if there is mf and/or sv observation at next analysis time (times if AR3)
        and prepare observation data for either AR1 or AR3 analysis.

        For AR3 the times are i_analysis, i_analysis + 1 and i_analysis + 2 and the
        observation data (sv_X, sv_H, sv_Rxx, mf_X, mf_H, mf_Rxx) are lists of 3 items;
        otherwise the single time is i_analysis + 1.
        Misfits of a skipped analysis (MF and/or SV) are set to NaN.

        :param i_analysis: analysis time iteration
        :type i_analysis: int
        :return: self.ana_mf, self.ana_sv (boolean arrays, one entry per time)
        :rtype: np.array, np.array
        """
        AR_type = self.algo.config.AR_type

        if AR_type == "AR3":
            times = [i_analysis, i_analysis + 1, i_analysis + 2]
        else:
            times = [i_analysis + 1]
        Nt = len(times)

        # init boolean vectors
        self.ana_sv = np.zeros((Nt,), dtype=bool)
        self.ana_mf = np.zeros((Nt,), dtype=bool)

        if not i_analysis < self.algo.config.nb_analyses:
            return self.ana_mf, self.ana_sv

        keys = [str(t) for t in times]
        sv_obs = self.measure_observations["SV"]
        mf_obs = self.measure_observations["MF"]
        for i, key in enumerate(keys):
            self.ana_sv[i] = key in sv_obs
            self.ana_mf[i] = key in mf_obs

        # setup the observation data used by analysis_step
        as_lists = AR_type == "AR3"
        if self.sv_analysis():
            self.sv_X, self.sv_H, self.sv_Rxx = self._collect_observations(sv_obs, keys, as_lists)
        if self.mf_analysis():
            self.mf_X, self.mf_H, self.mf_Rxx = self._collect_observations(mf_obs, keys, as_lists)

        if not self.mf_analysis():
            logging.critical(
                "Skipping MF analysis"
            )
            self.current_misfits = self.invalid_misfits(keys=["MF"])
        if not self.sv_analysis():
            logging.critical(
                "Skipping SV analysis"
            )
            self.current_misfits = self.invalid_misfits(keys=["SV"])

        return self.ana_mf, self.ana_sv

    def analysis_step(self, input_core_state):
        """ Does the analysis at time t on the B and Z=[UE] part of the input_core_state.
        Updates SV = A(B)U + ER in consequence.

        :param input_core_state: Core state at time t
        :type input_core_state: corestates.CoreState (dim: nb_realisations x Ncorestate)
        :return: the analysed core state, restricted to the realisations attributed to this process
        :rtype: corestates.CoreState (dim: nb_realisations x Ncorestate)
        :raises ValueError: if the core state is not 2D or has the wrong number of realisations
        """

        # Check corestate is 2D
        if input_core_state.B.ndim != 2:
            raise ValueError("Corestate must be 2D but is ({}D)".format(input_core_state.B.ndim))

        # Check that number of reals match
        if input_core_state.B.shape[0] != self.algo.nb_realisations:
            raise ValueError(
                "Number of realisations in input_core_state ({}) is not equal to expected number ({})".format(
                    input_core_state.B.shape[0], self.algo.nb_realisations))

        ana_core_state = input_core_state.copy()

        # nothing to analyse: return the forecast unchanged
        if not self.sv_analysis() and not self.mf_analysis():
            return ana_core_state[self.algo.attributed_models]

        if self.mf_analysis():
            ana_core_state.B = self.analyse_B(ana_core_state.B, self.mf_X, self.mf_H, self.mf_Rxx)
        if self.sv_analysis():
            # B may just have been analysed: analyse_Z uses it through compute_Ab
            (ana_core_state.Z,
             ana_core_state.U,
             ana_core_state.ER,
             ana_core_state.SV) = self.analyse_Z(ana_core_state, self.sv_X, self.sv_H, self.sv_Rxx)
        return ana_core_state[self.algo.attributed_models]

    def analyse_B(self, input_core_state, mf_X, mf_H, mf_Rxx):
        """
        Returns the analysed data for B by a BLUE (l2 norm) or a Huber-norm estimate given the observations.
        Also stores the MF misfit in self.current_misfits["MF"].

        :param input_core_state: NumPy array containing the coefficient data of B
        :type input_core_state: np.array (dim: nb_realisations x Nb)
        :param mf_X: Observation data to use for the BLUE
        :type mf_X: np.array (dim: nb_realisations x No)
        :param mf_H: Observation matrix to use for the BLUE
        :type mf_H: np.array (dim: No x Nb)
        :param mf_Rxx: Observation error to use for the BLUE
        :type mf_Rxx: np.array (dim: No x No)
        :return: NumPy array containing the analysed coefficient data of B
        :rtype: np.array (dim: nb_realisations x Nb)
        :raises ValueError: if cfg.kalman_norm is neither 'l2' nor 'huber'
        """
        Hb = mf_H
        Rbb = mf_Rxx
        # forecast covariance of B, possibly sparsified by the graphical lasso
        Pbb_forecast = self.remove_small_correlations(input_core_state)
        logging.debug("Getting best linear unbiased estimate of B...")
        # each process fills its own realisations, the rest is gathered by Allreduce
        analysed_B = np.zeros((self.algo.nb_realisations, self.algo.config.Nb))

        if self.cfg.kalman_norm == 'l2':
            # the gain does not depend on the realisation: compute it once
            Kbb = common.compute_Kalman_gain_matrix(
                Pbb_forecast, Hb, Rbb
            )
            for i_real in self.algo.attributed_models:
                analysed_B[i_real] = common.get_BLUE(
                    input_core_state[i_real],
                    mf_X[i_real],
                    Pbb_forecast,
                    Hb,
                    Rbb,
                    Kbb,
                )

        elif self.cfg.kalman_norm == 'huber':  # non least-square norm, iterations are needed
            # inverse of the symmetric P_bb through its eigen-decomposition,
            # with eigenvalues floored in case the matrix is not full rank
            P_eig_val, P_eig_vec = np.linalg.eigh(Pbb_forecast)
            P_eig_val[P_eig_val < 1e-10] = 1e-10
            Pbb_inv = P_eig_vec @ np.diag(1 / P_eig_val) @ P_eig_vec.T
            for i_real in self.algo.attributed_models:
                analysed_B[i_real] = common.compute_Kalman_huber(input_core_state[i_real], mf_X[i_real],
                                                                 Pbb_inv, Hb, Rbb)
        else:
            raise ValueError('Invalid value of param kalman_norm, should be equal to huber or l2, got {}'.format(self.cfg.kalman_norm))

        comm.Allreduce(MPI.IN_PLACE, analysed_B, op=MPI.SUM)

        # Compute the misfits for B (Y - HX)
        HX_b = np.transpose(np.matmul(Hb, np.transpose(analysed_B)))
        self.current_misfits["MF"] = common.compute_misfit(
            mf_X, HX_b, linalg.inv(Rbb)
        )

        return analysed_B

    def setup_Hz(self, Ab, Nu):
        """
        Compute the matrix [Ab | I_e * compute_e] where Ab contains the Gaunt-Elsasser integrals,
        while I_e is the identity matrix of size Nsv.

        :param Ab: matrix (Nsv x Nu) acting on the flow
        :type Ab: numpy array
        :param Nu: dimension of the flow
        :type Nu: int
        :return: matrix Nsv x (Nu + Nsv)
        :rtype: numpy array
        """
        assert Ab.shape[1] == Nu, "mismatch in dimension of Ab {} and flow {} ".format(Ab.shape[1], Nu)
        Nsv = self.cfg.Nsv

        Hz = np.zeros((Nsv, Nu + Nsv), order="F")
        # Set the observation operator to A(B) for U
        Hz[:Nsv, :Nu] = Ab
        # and identity for E (zeroed out when cfg.compute_e is 0)
        Hz[:Nsv, Nu:Nu + Nsv] = np.identity(Nsv) * self.cfg.compute_e
        return Hz

    def remove_small_correlations(self, input_core_state, eps=1e-10):
        """
        Apply the graphical lasso to the correlation matrix. The correlation matrix is computed from
        the covariance matrix, either Pzz or Pbb in practice, with C[i, j] = P[i, j] / sqrt(P[i, i] P[j, j]).
        Warning: In some cases, some variance elements can be zero, for instance if the initialisation
        parameter, core_state_init, is set to constant. Those exact zeros are replaced by eps so that the
        correlation matrix can still be computed.

        If the glasso parameter, self.cfg.remove_spurious, is set to 0 (np.inf), then the resp. (diagonal) empirical
        covariance matrix is returned.
        Otherwise the glasso is applied on the correlation matrix and the result is scaled back to a covariance.

        :param input_core_state: Corestate which can either be Z or B,
                                 at a given time for all realizations
        :param eps: value that replaces exact zeros in the diagonal of the covariance matrix before the
                    division that gives the correlation matrix
        :type eps: float
        :return: the (regularised) covariance matrix
        :rtype: np.array
        """
        # empirical covariance over the realisations
        P_forecast = common.cov(input_core_state)

        if self.cfg.remove_spurious == 0:  # extreme case, keep the sample covariance matrix
            return P_forecast
        if self.cfg.remove_spurious == np.inf:  # other extreme case, take only the diagonal of the covariance matrix
            return np.diag(np.diag(P_forecast))

        zero_idx = np.flatnonzero(np.diag(P_forecast) == 0)
        if zero_idx.size:
            # avoid division by exact zeros by regularization; with many zeros the
            # graphical lasso may have a hard time converging.
            logging.warning('Some coefficients in the diagonal sample covariance matrix are very close to zero at {}'.format(np.squeeze(zero_idx)))
            P_forecast[zero_idx, zero_idx] = eps  # only modifies zeros present in the diagonal

        # correlation matrix D^-1 P D^-1 with D = diag(std), computed by broadcasting
        # (same values as the dense diagonal products, in O(n^2) instead of O(n^3))
        std = np.sqrt(np.diag(P_forecast))
        inv_std = 1 / std
        C_forecast = inv_std[:, None] * P_forecast * inv_std

        # Compute the lasso approximation
        C_lasso = skcov.graphical_lasso(C_forecast, self.cfg.remove_spurious, max_iter=100)[0]

        # back to a covariance: D C_lasso D
        return std[:, None] * C_lasso * std

    def analyse_Z(self, input_core_state, sv_X, sv_H, sv_Rxx):
        """
        Returns the analysed data for the augmented state Z = [U ER] and SV given the observations
        (estimate computed by common.compute_Kalman_huber_parameter_basis).
        Also stores the SV misfit in self.current_misfits["SV"].

        :param input_core_state: 2D CoreState containing the coefficient data
        :type input_core_state: CoreState
        :param sv_X: Observation data to use for the BLUE
        :type sv_X: np.array (dim: nb_realisations x No)
        :param sv_H: Observation matrix to use for the BLUE
        :type sv_H: np.array (dim: No x Nsv)
        :param sv_Rxx: Observation error to use for the BLUE
        :type sv_Rxx: np.array (dim: No x No)
        :return: 2D analysed Z U ER SV
        :rtype: 2D arrays (Nreal x Ncoef)
        """
        # compute necessary matrices for Kalman filter
        Pzz_forecast = self.remove_small_correlations(input_core_state.Z)

        n_reals = self.algo.nb_realisations
        analysed_Z = np.zeros((n_reals, self.algo.config.Nz))
        analysed_SV = np.zeros((n_reals, self.algo.config.Nsv))
        analysed_ER = np.zeros((n_reals, self.algo.config.Nsv))
        analysed_U = np.zeros((n_reals, self.algo.config.Nu2))
        sv_X_real = np.zeros(sv_X.shape)
        # H X of each realisation, computed with that realisation's own operator
        # (H depends on B through Ab, so it differs between realisations)
        HX_z = np.zeros(sv_X.shape)

        for i_real in self.algo.attributed_models:

            Ab = self.compute_Ab(input_core_state[i_real])

            # The complete H operator is Hsv @ [Ab | I_e] (Ab projected on the PCA basis if PCA is used)
            if self.algo.check_PCA():
                complete_H = sv_H @ self.setup_Hz(Ab @ self.cfg.pcaU_operator.S_u, self.cfg.N_pca_u)
            else:
                complete_H = sv_H @ self.setup_Hz(Ab, self.cfg.Nu2)

            PzzHT = Pzz_forecast @ complete_H.T
            HPzzHT = complete_H @ PzzHT

            # Z is centered on 0 so we must remove the mean from the observation data
            # Y = Ab (U+U0) + (ER + ER0) => Y - Ab U0 - ER0 = Ab U + ER
            sv_X_real[i_real] = sv_X[i_real] - (sv_H @ Ab @ self.algo.avg_prior["U"] + sv_H @ self.algo.avg_prior["ER"])

            analysed_Z[i_real] = common.compute_Kalman_huber_parameter_basis(input_core_state.Z[i_real],
                                                                             sv_X_real[i_real],
                                                                             HPzzHT,
                                                                             PzzHT,
                                                                             complete_H,
                                                                             sv_Rxx)
            analysed_U[i_real], analysed_ER[i_real] = self.algo.Z_to_U_ER(analysed_Z[i_real], 1)

            analysed_SV[i_real] = Ab @ analysed_U[i_real] + analysed_ER[i_real]
            HX_z[i_real] = complete_H @ analysed_Z[i_real]

        # gather the realisations computed by every process
        comm.Allreduce(MPI.IN_PLACE, analysed_Z, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, analysed_U, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, analysed_ER, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, analysed_SV, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, sv_X_real, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, HX_z, op=MPI.SUM)

        # Compute the misfits for SV (Y - HX)
        self.current_misfits["SV"] = common.compute_misfit(
            sv_X_real, HX_z, linalg.inv(sv_Rxx)
        )

        return analysed_Z, analysed_U, analysed_ER, analysed_SV

481 

482 

class AugkfAnalyserAR3(AugkfAnalyserAR1):
    """
    Analyser for the AR3 configuration. The analysis is performed on the full states
    [X, X', X''] at the analysis time ta, constrained by the observations at
    ta - dt_a, ta and ta + dt_a through a second-order Taylor expansion in time.
    """

    def analysis_step(self, input_core_state, Z_AR3):
        """ Does the analysis at time t on the B and Z=[UE] part of the input_core_state.
        Updates SV = A(B)U + ER in consequence.

        :param input_core_state: Core state around time t (realisations x forecast times x coefs)
        :type input_core_state: corestates.CoreState (dim: nb_realisations x Ncorestate)
        :param Z_AR3: AR3 forecast state
        :type Z_AR3: np.array (dim: nb_realisations x 3 x Nz)
        :return: the analysed core state and updated Z_AR3
        :rtype: corestates.CoreState (dim: nb_realisations x Ncorestate), np.array (dim: nb_realisations x 3 x Nz)
        :raises ValueError: if the core state is not 3D or has the wrong number of realisations
        """

        # Check corestate is 3D
        if input_core_state.B.ndim != 3:
            raise ValueError("Corestate must be 3D but is ({}D)".format(input_core_state.B.ndim))

        # Check the dim of core_state
        if input_core_state.B.shape[0] != self.algo.nb_realisations:
            raise ValueError(
                "Number of realisations in input_core_state ({}) is not equal to expected number ({})".format(
                    input_core_state.B.shape[0], self.algo.nb_realisations
                )
            )

        ratio = self.algo.config.dt_a_f_ratio

        # the analysis time ta is the middle forecast index (ratio)
        ana_core_state = input_core_state[:, ratio].copy()

        # nothing to analyse: return the forecast unchanged
        if not self.sv_analysis() and not self.mf_analysis():
            return ana_core_state[self.algo.attributed_models], Z_AR3

        if self.mf_analysis():
            # compute the full state B B' B'' at analysis time by fitting B forecast curves
            B_full = self.compute_full_state(input_core_state.B)
            # compute analysis on B B' B'' and return b_minus b b_plus
            b_minus, b, b_plus = self.analyse_B(B_full, self.mf_X, self.mf_H, self.mf_Rxx)
            ana_core_state.B = b
        else:
            # else we simply pick b_minus b b_plus from forecast values
            b_minus = input_core_state.B[:, 0]
            b = input_core_state.B[:, ratio]
            b_plus = input_core_state.B[:, 2 * ratio]

        if not self.sv_analysis():
            return ana_core_state[self.algo.attributed_models], Z_AR3

        # compute the full state Z Z' Z'' at analysis time by fitting Z forecast curves
        Z_full = self.compute_full_state(input_core_state.Z)

        # perform the analysis on Z Z' Z''
        (ana_Z,
         ana_dZ,
         ana_d2Z,
         ana_core_state.U,
         ana_core_state.ER,
         ana_core_state.SV) = self.analyse_Z(Z_full, self.sv_X, self.sv_H, self.sv_Rxx, b_minus, b, b_plus, input_core_state[0, ratio])

        ana_core_state.Z = ana_Z
        ana_core_state.dUdt, ana_core_state.dERdt = self.algo.dZ_to_dU_dER(ana_dZ)
        ana_core_state.d2Udt2, ana_core_state.d2ERdt2 = self.algo.dZ_to_dU_dER(ana_d2Z)

        # restart the AR3 forecast from the analysed Z Z' Z''
        Z_AR3 = self.init_forecast(ana_Z, ana_dZ, ana_d2Z)

        return ana_core_state[self.algo.attributed_models], Z_AR3[self.algo.attributed_models]

    def build_H_operator(self, X, row_diff=False):
        """
        Compute observation H operator mapping a full state [x, x', x''] (at ta) onto the
        three times ta - dt_a, ta, ta + dt_a with a second-order Taylor expansion:
            row 0: [X0, -dt_a X0, dt_a**2/2 X0]
            row 1: [X1,        0,            0]
            row 2: [X2,  dt_a X2, dt_a**2/2 X2]

        :param X: X matrix (2D) or sequence of 3 matrices (one per time)
        :param row_diff: False if X is the same for the three times (2D X), True if it differs (3 matrices)
        :return: the block operator (3 row blocks x 3 column blocks)
        :rtype: np.array
        """
        dt_a = self.algo.config.dt_a
        X0, X1, X2 = X if row_diff else (X, X, X)
        o = np.zeros(X1.shape)
        return np.concatenate(
            (np.concatenate((X0, -dt_a * X0, dt_a**2 / 2 * X0), axis=1),
             np.concatenate((X1, o, o), axis=1),
             np.concatenate((X2, dt_a * X2, dt_a**2 / 2 * X2), axis=1)),
            axis=0)

    def build_obs_operator(self, obs):
        """
        Compute observation data operator: concatenate the observations of the
        three times along the observation axis (axis 1).

        :param obs: Observation data (3 arrays of dim nb_realisations x No_i)
        :return: np.array (dim: nb_realisations x sum(No_i))
        """
        return np.concatenate((obs[0], obs[1], obs[2]), axis=1)

    def build_err_operator(self, error):
        """
        Compute error operator: block-diagonal matrix of the errors of the three times.

        :param error: Observation error (3 square matrices)
        :return: np.array (block diagonal)
        """
        return linalg.block_diag(error[0], error[1], error[2])

    def compute_full_state(self, X_state):
        """
        Compute Least Square fit of X_ta, dX_ta and d2X_ta: a second-order polynomial in time
        is fitted on the 2*dt_a_f_ratio + 1 forecast times centered on ta.

        :param X_state: X_state
        :type X_state: 3D np array (dim: n_reals x nb_forecasts x ncoefs)
        :return: Xfull = [X, X', X''] at ta for each realisation
        :rtype: 2D np array (dim: n_reals x 3*ncoefs)
        :raises ValueError: if nb_forecasts != 2*dt_a_f_ratio + 1
        """
        # np.math was deprecated in NumPy 1.25 and removed in NumPy 2.0
        import math

        N_coef = X_state.shape[2]
        ratio = self.algo.config.dt_a_f_ratio
        dt_f = self.algo.config.dt_f

        # forecast times relative to ta, as a column vector
        dt = (np.arange(-ratio, ratio + 1) * dt_f).reshape(-1, 1)
        if X_state.shape[1] != dt.shape[0]:
            raise ValueError(
                "Expected {} forecast times in X_state, got {}".format(dt.shape[0], X_state.shape[1]))

        # Taylor design matrix G = [1, dt, dt**2/2]
        max_order = 2
        G = np.concatenate(
            [np.power(dt, k) / math.factorial(k) for k in range(max_order + 1)], axis=1)
        # least-squares solution operator, identical for every realisation
        G_pinv = np.linalg.inv(G.T @ G) @ G.T

        X_full = np.zeros((self.algo.nb_realisations, 3 * N_coef))
        for i_real in self.algo.attributed_models:
            S = G_pinv @ X_state[i_real]
            X_full[i_real] = S[:3].flatten()

        comm.Allreduce(MPI.IN_PLACE, X_full, op=MPI.SUM)

        return X_full

    def analyse_B(self, B_full, mf_X, mf_H, mf_Rxx):
        """
        Returns the analysed data for B by a BLUE (l2 norm) or a Huber-norm estimate given the observations.
        Also stores the MF misfit in self.current_misfits["MF"].

        :param B_full: NumPy array containing the coefficient data of B, B', B'' at ta
        :type B_full: np.array (dim: nb_realisations x 3*Nb)
        :param mf_X: Observation data of the three times
        :type mf_X: list of 3 np.array
        :param mf_H: Observation matrices of the three times
        :type mf_H: list of 3 np.array
        :param mf_Rxx: Observation errors of the three times
        :type mf_Rxx: list of 3 np.array
        :return: b_minus, b, b_plus
        :rtype: np.array (dim: nb_realisations x Nb), np.array (dim: nb_realisations x Nb), np.array (dim: nb_realisations x Nb)
        :raises ValueError: if cfg.kalman_norm is neither 'l2' nor 'huber'
        """
        Nb = self.cfg.Nb

        Hb = self.build_H_operator(mf_H, row_diff=True)
        Yb = self.build_obs_operator(mf_X)
        Rbb = self.build_err_operator(mf_Rxx)
        # compute Pbb from B_full state (B B' B'')
        Pbb = self.remove_small_correlations(B_full)

        logging.debug("Getting best linear unbiased estimate of B...")
        analysed_B = np.zeros((self.algo.nb_realisations, 3 * Nb))

        if self.cfg.kalman_norm == 'l2':
            Kbb = common.compute_Kalman_gain_matrix(
                Pbb, Hb, Rbb
            )
            for i_real in self.algo.attributed_models:
                analysed_B[i_real] = common.get_BLUE(
                    B_full[i_real],
                    Yb[i_real],
                    Pbb,
                    Hb,
                    Rbb,
                    K=Kbb,
                )

        elif self.cfg.kalman_norm == 'huber':  # was misspelled, making the huber norm unusable
            # inverse of the symmetric P_bb through its eigen-decomposition,
            # with eigenvalues floored in case the matrix is not full rank
            P_eig_val, P_eig_vec = np.linalg.eigh(Pbb)
            P_eig_val[P_eig_val < 1e-10] = 1e-10
            Pbb_inv = P_eig_vec @ np.diag(1 / P_eig_val) @ P_eig_vec.T
            for i_real in self.algo.attributed_models:
                analysed_B[i_real] = common.compute_Kalman_huber(B_full[i_real], Yb[i_real],
                                                                 Pbb_inv, Hb, Rbb)
        else:
            raise ValueError('Invalid value of param kalman_norm, should be equal to huber or l2, got {}'.format(self.cfg.kalman_norm))

        # gather process
        comm.Allreduce(MPI.IN_PLACE, analysed_B, op=MPI.SUM)

        # Compute the misfits for B (Y - HX)
        HX_b = np.transpose(np.matmul(Hb, np.transpose(analysed_B)))
        self.current_misfits["MF"] = common.compute_misfit(
            Yb, HX_b, linalg.inv(Rbb)
        )

        # compute b_minus b b_plus from B, B', B''
        B_a = analysed_B @ self.build_H_operator(np.identity(Nb)).T

        b_minus = B_a[:, :Nb]
        b = B_a[:, Nb:2 * Nb]
        b_plus = B_a[:, 2 * Nb:]

        return b_minus, b, b_plus

    def build_full_Hz(self, input_core_state, b_minus, b, b_plus, sv_H):
        """
        Build the full observation matrix H for Z at the three times.

        Side effect: ``input_core_state.B`` is overwritten (it ends equal to b_plus),
        because compute_Ab reads B from the core state.

        :return: H (list of 3 matrices sv_H[i] @ Hz_i) and Ab (list of the 3 Ab matrices)
        """
        H = []
        Ab_out = []
        for i, b_i in enumerate((b_minus, b, b_plus)):
            input_core_state.B = b_i
            Ab = self.compute_Ab(input_core_state)
            if self.algo.check_PCA():
                h = sv_H[i] @ self.setup_Hz(Ab @ self.cfg.pcaU_operator.S_u, self.cfg.N_pca_u)
            else:
                h = sv_H[i] @ self.setup_Hz(Ab, self.cfg.Nu2)
            H.append(h)
            Ab_out.append(Ab)
        return H, Ab_out

    def analyse_Z(self, Z_full, sv_X, sv_H, sv_Rxx, b_minus, b, b_plus, input_core_state):
        """
        Returns the analysed data for the augmented state Z and SV that maximizes the
        likelihood (of the Gaussian or Huber distribution).
        Also stores the SV misfit in self.current_misfits["SV"].

        :param Z_full: Z, Z', Z'' at ta
        :type Z_full: np.array (dim: nb_realisations x 3*Nz)
        :param sv_X: Observation data of the three times
        :type sv_X: list of 3 np.array
        :param sv_H: Observation matrices of the three times
        :type sv_H: list of 3 np.array
        :param sv_Rxx: Observation errors of the three times
        :type sv_Rxx: list of 3 np.array
        :param b_minus: b at ta-
        :type b_minus: np.array (dim: nb_realisations x Nb)
        :param b: b at ta
        :type b: np.array (dim: nb_realisations x Nb)
        :param b_plus: b at ta+
        :type b_plus: np.array (dim: nb_realisations x Nb)
        :param input_core_state: core state used as template by compute_Ab (its B is overwritten)
        :type input_core_state: CoreState
        :return: ana_Z, ana_dZ, ana_d2Z, ana_U, ana_ER, ana_SV
        :rtype: np.array (dim: nb_realisations x Ncoef)
        """
        Nz = self.cfg.Nz
        n_reals = self.algo.nb_realisations

        Rzz = self.build_err_operator(sv_Rxx)
        PZZ = self.remove_small_correlations(Z_full)

        analysed_Z = np.zeros((n_reals, 3 * Nz))
        analysed_SV = np.zeros((n_reals, self.algo.config.Nsv))
        analysed_ER = np.zeros((n_reals, self.algo.config.Nsv))
        analysed_U = np.zeros((n_reals, self.algo.config.Nu2))
        Yz = np.zeros((n_reals, sv_X[0].shape[1] + sv_X[1].shape[1] + sv_X[2].shape[1]))
        # H X of each realisation, computed with that realisation's own operator
        # (Hz depends on b through Ab, so it differs between realisations)
        HX_z = np.zeros_like(Yz)

        U0 = self.algo.avg_prior["U"]
        ER0 = self.algo.avg_prior["ER"]
        for i_real in self.algo.attributed_models:
            H, Ab = self.build_full_Hz(input_core_state, b_minus[i_real], b[i_real], b_plus[i_real], sv_H)
            Hz = self.build_H_operator(H, row_diff=True)

            # Z is centered on 0 so we must remove the mean from the observation data
            # Y = Ab (U+U0) + (ER + ER0) => Y - Ab U0 - ER0 = Ab U + ER
            Yz[i_real] = np.concatenate(
                [sv_X[i][i_real] - (sv_H[i] @ Ab[i] @ U0 + sv_H[i] @ ER0) for i in range(3)],
                axis=0)

            PzzHT = PZZ @ Hz.T
            HPzzHT = Hz @ PzzHT

            analysed_Z[i_real] = common.compute_Kalman_huber_parameter_basis(Z_full[i_real],
                                                                             Yz[i_real],
                                                                             HPzzHT,
                                                                             PzzHT,
                                                                             Hz,
                                                                             Rzz)

            analysed_U[i_real], analysed_ER[i_real] = self.algo.Z_to_U_ER(analysed_Z[i_real, :Nz], 1)
            # SV at ta uses the central-time Ab
            analysed_SV[i_real] = Ab[1] @ analysed_U[i_real] + analysed_ER[i_real]
            HX_z[i_real] = Hz @ analysed_Z[i_real]

        # gather process
        comm.Allreduce(MPI.IN_PLACE, analysed_U, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, analysed_ER, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, analysed_SV, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, analysed_Z, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, Yz, op=MPI.SUM)
        comm.Allreduce(MPI.IN_PLACE, HX_z, op=MPI.SUM)

        # Compute the misfits for SV (Y - HX)
        self.current_misfits["SV"] = common.compute_misfit(
            Yz, HX_z, linalg.inv(Rzz)
        )

        ana_Z = analysed_Z[:, :Nz]
        ana_dZ = analysed_Z[:, Nz:2 * Nz]
        ana_d2Z = analysed_Z[:, 2 * Nz:]

        return ana_Z, ana_dZ, ana_d2Z, analysed_U, analysed_ER, analysed_SV

    def init_forecast(self, Z, dZ, d2Z):
        """
        init ZAR3 state for forecast: Z at the three last forecast times (ta - 2 dt_f,
        ta - dt_f, ta) obtained by a backward second-order Taylor expansion from ta.

        :param Z: Z state
        :type Z: 2D np array (dim: n_reals x nz)
        :param dZ: dZ state
        :type dZ: 2D np array (dim: n_reals x nz)
        :param d2Z: d2Z state
        :type d2Z: 2D np array (dim: n_reals x nz)
        :return: init_Z
        :rtype: 3D np array (dim: n_reals x 3 x nz)
        """
        init_Z = np.zeros((self.algo.nb_realisations, 3, self.algo.config.Nz))
        dt_f = self.algo.config.dt_f

        init_Z[:, 2] = Z
        init_Z[:, 1] = Z - dt_f * dZ + dt_f**2 / 2 * d2Z
        init_Z[:, 0] = Z - (2 * dt_f) * dZ + (2 * dt_f)**2 / 2 * d2Z

        return init_Z