Coverage for pygeodyn/corestates.py: 89%

272 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-22 13:43 +0000

1import numpy as np 

2import scipy as sp 

3from mpi4py import MPI 

4from functools import wraps 

5import pygeodyn.common as common 

6from pygeodyn.utilities import date_to_decimal 

7from pygeodyn.inout.reads import read_analysed_states_hdf5 

8 

# Module-level MPI handles: the world communicator and the rank of the
# current process inside that communicator.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()

11 

class CoreState:
    """
    CoreState object. It stores the measure data in a dict of arrays, keyed by
    the measure name (e.g. 'MF', 'SV', 'U', 'ER', 'Z'), together with the max
    spherical-harmonic degree associated with each measure.
    """
    def __init__(self, init_measures=None):
        """
        Initiates the CoreState. Initial measures can be given in the init_measures arg:
            - Either with only the data (max_degree will be inferred from the data)
                Ex: CoreState({SV: np.zeros(224), ...})
            - Or by giving also the max_degree as a second member of a 2-tuple/list (the first being the data):
                Ex: CoreState({SV: [np.zeros(224), 14], ...})

        Anything that is not a dict (including None) leaves the CoreState empty.

        :param init_measures: The measures to add in a dict whose members can be np.ndarray or 2-tuple/lists.
        :type init_measures: dict or None
        """
        self._measures = {}
        self._max_degrees = {}

        if isinstance(init_measures, dict):
            for meas_id, measure_data in init_measures.items():
                # If 2-tuple/list, assume data in first member and Lmax in second.
                # The type is checked first so that len() is only called on lists/tuples.
                if isinstance(measure_data, (list, tuple)) and len(measure_data) == 2:
                    self.addMeasure(meas_id, measure_data[0], measure_data[1])
                # Else add the data with Lmax inferred from the data
                else:
                    self.addMeasure(meas_id, measure_data)

    def create_CoreState(self):
        """
        Create an instantiation of the class CoreState.
        Convenience function created for inheritance purposes.
        """
        return CoreState()

    def addMeasure(self, meas_id, meas_data, meas_max_degree=None):
        """
        Adds a measure to the CoreState (replacing any measure stored under the same name).

        :param meas_id: name of the measure. Used as key of dict for internal storing.
        :type meas_id: str
        :param meas_data: data of the measure. It is stored as given (no copy, no conversion).
        :type meas_data: np.ndarray or list
        :param meas_max_degree: Max degree of the measure. If None (default), it will be inferred
            from the last dimension of the data (-1 is stored for the measure 'Z').
        :type meas_max_degree: int or None
        :raises ValueError: if the inferred max degree is not an integer.
        """
        # Try to infer the max degree from last dimension of data N=L(L+2)
        if meas_max_degree is None:
            if meas_id == 'Z':
                # No degree is inferred for the augmented state
                meas_max_degree = -1
            else:
                # np.shape also handles plain (nested) lists, which have no .shape attribute
                last_dim = np.shape(meas_data)[-1]
                nb_coeffs = last_dim
                # If a measure derived from U or S, the equation is N/2 = L(L+2)
                if meas_id in ('U', 'dUdt', 'd2Udt2', 'S'):
                    assert nb_coeffs % 2 == 0
                    nb_coeffs = nb_coeffs // 2
                computed_Lmax = np.sqrt(nb_coeffs + 1) - 1
                if computed_Lmax != int(computed_Lmax):
                    raise ValueError('Last dimension {} of the data given for {} does not lead to an integer max degree ! Got {} instead.'
                                     .format(last_dim, meas_id, computed_Lmax))
                meas_max_degree = int(computed_Lmax)

        assert isinstance(meas_max_degree, int)
        self._max_degrees[meas_id] = meas_max_degree
        self._measures[meas_id] = meas_data

    def __getitem__(self, item):
        """
        Dict-like access when item is a measure name, ndarray-like slicing otherwise.

        :param item: measure name, or any index/slice accepted by the stored measures
        :return: the measure data (str item) or a new CoreState holding every measure indexed by item
        """
        # If the item is a measure name, return the measure (dict-like behaviour)
        if isinstance(item, str):
            return self.getMeasure(item)
        # Else, delegate the __getitem__ to the measures stored to return a sliced CoreState (ndarray-like behaviour)
        new_cs = self.create_CoreState()
        for id_qty, qty in self._measures.items():
            new_cs.addMeasure(id_qty, qty.__getitem__(item), self._max_degrees[id_qty])
        return new_cs

    def __setitem__(self, key, value):
        """
        Assigns value at index key in every stored measure.

        :param key: index/slice forwarded to each stored measure
        :param value: a CoreState (assigned measure by measure) or anything assignable to the measures
        """
        for meas_id, meas_data in self._measures.items():
            # If the value is a CoreState, set the items measure by measure
            if isinstance(value, CoreState):
                meas_data.__setitem__(key, value.getMeasure(meas_id))
            # Else delegate the __setitem__ to the measures as it is
            else:
                meas_data.__setitem__(key, value)

    def update_corestate(self, i_t, corestate, exclude=("S",)):
        """
        Copies the measures of corestate into index i_t of the second axis of the
        measures of self.

        :param i_t: time_iteration
        :type i_t: integer
        :param corestate: corestate to update from
        :type corestate: Corestate
        :param exclude: measures not updated (any container supporting the `in` operator)
        :type exclude: list or tuple
        """
        for meas_id in self._measures:
            if meas_id not in exclude:
                self._measures[meas_id][:, i_t, :] = corestate._measures[meas_id]

    def copy(self):
        """
        :return: a CoreState with copied measures
        :rtype: CoreState
        """
        new_cs = self.create_CoreState()
        for meas_id, meas_data in self._measures.items():
            new_cs.addMeasure(meas_id, meas_data.copy(), self._max_degrees[meas_id])
        return new_cs

    def mean(self, *args, **kwargs):
        """
        Arguments are forwarded to the mean method of each stored measure.

        :return: a new CoreState with averaged measures
        :rtype: CoreState
        """
        new_cs = self.create_CoreState()
        for meas_id, meas_data in self._measures.items():
            new_cs.addMeasure(meas_id, meas_data.mean(*args, **kwargs), self._max_degrees[meas_id])
        return new_cs

    def std(self, *args, **kwargs):
        """
        Arguments are forwarded to the std method of each stored measure.

        :return: a new CoreState with the standard deviation of the measures
        :rtype: CoreState
        """
        new_cs = self.create_CoreState()
        for meas_id, meas_data in self._measures.items():
            new_cs.addMeasure(meas_id, meas_data.std(*args, **kwargs), self._max_degrees[meas_id])
        return new_cs

    @property
    def measures(self):
        """
        :return: A *sorted* list of the measures of the CoreState
        :rtype: list
        """
        return sorted(self._measures.keys())

    def keys(self):
        """
        Convenience function to use CoreState as a dict

        :return: A *sorted* list of the measures of the CoreState
        :rtype: list
        """
        return self.measures

    def _getMaxDegree(self, measure_type):
        """
        :param measure_type: type of the measure
        :type measure_type: str
        :return: the max degree of the asked type or 0 if the measure is not loaded
        :rtype: int
        """
        return self._max_degrees.get(measure_type, 0)

    @property
    def Lb(self):
        """Max degree stored for 'MF' (0 if not loaded)."""
        return self._getMaxDegree('MF')

    @property
    def Nb(self):
        """Number of coefficients Lb*(Lb+2)."""
        return self.Lb*(self.Lb + 2)

    @property
    def Lu(self):
        """Max degree stored for 'U' (0 if not loaded)."""
        return self._getMaxDegree('U')

    @property
    def Nu2(self):
        """Number of coefficients 2*Lu*(Lu+2)."""
        return 2*self.Lu*(self.Lu + 2)

    @property
    def Lsv(self):
        """Max degree stored for 'SV' (0 if not loaded)."""
        return self._getMaxDegree('SV')

    @property
    def Nsv(self):
        """Number of coefficients Lsv*(Lsv+2)."""
        return self.Lsv*(self.Lsv + 2)

    # Getters/Setters for CoreState quantities
    def getMeasure(self, measure_type):
        """
        Generic getter for measures. Raises a KeyError if the measure is not loaded.

        :param measure_type: type of the measure to get
        :type measure_type: str
        :return: the measure data
        :rtype: np.ndarray
        """
        if measure_type not in self._measures:
            raise KeyError('{} is not a valid measure to get from CoreState.'.format(measure_type))
        return self._measures[measure_type]

    def setMeasure(self, measure_type, value):
        """
        Generic setter for measures. Raises a KeyError if the measure is not loaded.
        The max degree stored for the measure is left untouched.

        :param measure_type: type of the measure to set
        :type measure_type: str
        :param value: the new measure data
        :type value: np.ndarray
        """
        if measure_type not in self._measures:
            raise KeyError('{} is not a valid measure to set from CoreState.'.format(measure_type))
        self._measures[measure_type] = value

    @property
    def B(self):
        """
        Get the magnetic field (B) part of the CoreState (stored under 'MF').

        :return: Magnetic field data
        :rtype: np.ndarray
        """
        return self._measures['MF']

    @B.setter
    def B(self, mag_field):
        self._measures['MF'] = mag_field

    @property
    def U(self):
        """
        Get the core flow (U) part of the CoreState.

        :return: Core flow data
        :rtype: np.ndarray
        """
        return self._measures['U']

    @U.setter
    def U(self, core_flow):
        self._measures['U'] = core_flow

    @property
    def dUdt(self):
        """
        Get time derivative of core flow (dUdt) part of the CoreState.

        :return: time derivative of core flow
        :rtype: np.ndarray
        """
        return self._measures['dUdt']

    @dUdt.setter
    def dUdt(self, time_derivative_core_flow):
        self._measures['dUdt'] = time_derivative_core_flow

    @property
    def d2Udt2(self):
        """
        Get second time derivative of core flow (d2Udt2) part of the CoreState.

        :return: second time derivative of core flow
        :rtype: np.ndarray
        """
        return self._measures['d2Udt2']

    @d2Udt2.setter
    def d2Udt2(self, second_time_derivative_core_flow):
        self._measures['d2Udt2'] = second_time_derivative_core_flow

    @property
    def S(self):
        """
        Get the shear (S) part of the CoreState.

        :return: Shear data
        :rtype: np.ndarray
        """
        return self._measures['S']

    @S.setter
    def S(self, shear):
        self._measures['S'] = shear

    @property
    def ER(self):
        """
        Get the subgrid errors (ER) part of the CoreState.

        :return: Subgrid errors data
        :rtype: np.ndarray
        """
        return self._measures['ER']

    @ER.setter
    def ER(self, subgrid_error):
        self._measures['ER'] = subgrid_error

    @property
    def dERdt(self):
        """
        Get the derivative of subgrid errors (dERdt) part of the CoreState.

        :return: derivative of subgrid errors
        :rtype: np.ndarray
        """
        return self._measures['dERdt']

    @dERdt.setter
    def dERdt(self, derivative_subgrid_error):
        self._measures['dERdt'] = derivative_subgrid_error

    @property
    def d2ERdt2(self):
        """
        Get the second derivative of subgrid errors (d2ERdt2) part of the CoreState.

        :return: second derivative of subgrid errors
        :rtype: np.ndarray
        """
        return self._measures['d2ERdt2']

    @d2ERdt2.setter
    def d2ERdt2(self, second_derivative_subgrid_error):
        self._measures['d2ERdt2'] = second_derivative_subgrid_error

    @property
    def Z(self):
        """
        Get the Augmented state (Z) part of the CoreState.

        :return: Coreflow and subgrid errors data
        :rtype: np.ndarray
        """
        return self._measures['Z']

    @Z.setter
    def Z(self, augmented_state):
        self._measures['Z'] = augmented_state

    @property
    def SV(self):
        """
        Get the Core secular variation (SV) part of the CoreState.

        :return: Secular variation data
        :rtype: np.ndarray
        """
        return self._measures['SV']

    @SV.setter
    def SV(self, secular_variation):
        self._measures['SV'] = secular_variation

    def initialise_from_noised_priors(self, algo, random_state=None):
        """
        Initialise the core state Z MF U ER SV at t=0 with a random normal draw around the average priors.
        self is updated in place.

        :param algo: algo instance
        :type algo: Augkf.algo
        :param random_state: if not None, its `normal` method is used for the draws; otherwise np.random.normal is used.
        :type random_state: numpy.RandomState or None
        :return: if algo.config.AR_type is "AR3", an array of shape (number of attributed models, 3, Nz)
            holding Z at three successive steps (index 2 being Z at t=0); None otherwise.
        :rtype: np.ndarray or None
        """

        # THE CORESTATE HAS A NUMBER OF REALISATIONS THAT MATCHES THE NUMBER OF ATTRIBUTED MODELS

        # Averages
        avg_b = algo.avg_prior['B']
        AR_type = algo.config.AR_type
        Nz = algo.config.Nz
        dt_f = algo.config.dt_f

        # Lower Cholesky matrices
        L_bb = sp.linalg.cholesky(algo.cov_prior['B,B'], lower=True)
        L_zz = sp.linalg.cholesky(algo.cov_prior['Z,Z'], lower=True)

        if AR_type == "AR3":
            L_dzdz = sp.linalg.cholesky(algo.cov_prior['dZ,dZ'], lower=True)
            L_d2zd2z = sp.linalg.cholesky(algo.cov_prior['d2Z,d2Z'], lower=True)

        # Set random draw
        if random_state is not None:
            normal_draw = random_state.normal
        else:
            normal_draw = np.random.normal

        # Loop over attributed models (only the position in the list is needed here).
        # The order of the draws is kept as is so that seeded runs stay reproducible.
        for i_idx, _ in enumerate(algo.attributed_models):
            # Set normal draw
            w_b = normal_draw(size=algo.config.Nb)
            w_z = normal_draw(size=Nz)

            # Initialise B part of core state by normal distrib N(mean_b, sigma_b)
            self.B[i_idx, 0] = avg_b + np.matmul(L_bb, w_b)
            # Initialise Z part of core state by normal distrib N(0, sigma_z)
            self.Z[i_idx, 0] = np.matmul(L_zz, w_z)
            # Z to U ER
            self.U[i_idx, 0], self.ER[i_idx, 0] = algo.Z_to_U_ER(self.Z[i_idx, 0], 1)
            # SV = A(b) U + ER
            self.SV[i_idx, 0] = algo.analyser.compute_Ab(self[i_idx, 0]) @ self.U[i_idx, 0] + self.ER[i_idx, 0]

        if AR_type == "AR3":
            # Compute Z_AR3
            nb_models = algo.attributed_models.shape[0]
            Z_AR3 = np.zeros((nb_models, 3, Nz))
            w_dz = normal_draw(size=(nb_models, Nz))
            w_d2z = normal_draw(size=(nb_models, Nz))
            # Set dZ d2Z normal distribution
            dZ = w_dz @ L_dzdz.T
            d2Z = w_d2z @ L_d2zd2z.T
            # Taylor series of order 2
            Z_AR3[:, 2] = self.Z[:, 0]
            Z_AR3[:, 1] = self.Z[:, 0] - dt_f * dZ + dt_f**2 / 2 * d2Z
            Z_AR3[:, 0] = self.Z[:, 0] - (2*dt_f) * dZ + (2*dt_f)**2 / 2 * d2Z

            return Z_AR3
        else:
            return None

    def initialise_from_file(self, algo):
        """
        Initialise the core state Z MF U ER SV at t=0 from the CoreState in a file at a given date.
        The file path and the date are read from algo.config.init_file and algo.config.init_date.
        self is updated in place.

        :param algo: algo instance
        :type algo: Augkf.algo
        :return: if algo.config.AR_type is "AR3", an array of shape (number of attributed models, 3, Nz)
            holding Z at three successive steps (index 2 being Z at t=0); None otherwise.
        :rtype: np.ndarray or None
        :raises ValueError: if no analysed state is found within 0.1 of the requested date.
        """

        # THE CORESTATE HAS A NUMBER OF REALISATIONS THAT MATCHES THE NUMBER OF ATTRIBUTED MODELS

        file_path = algo.config.init_file
        decimal_date = algo.config.init_date
        AR_type = algo.config.AR_type
        file_data = read_analysed_states_hdf5(file_path)
        Nz = algo.config.Nz

        # Indexes of the file times in the window ]date - 0.1, date + 0.1]
        date_indexes = (np.where((file_data['times'] <= decimal_date + 0.1) & (file_data['times'] > decimal_date - 0.1)))[0]

        if date_indexes.shape[0] == 0:
            # Instantiate the exception: `raise (ValueError, msg)` would raise a TypeError instead
            raise ValueError("{} analysed state not found in {}, check the input date".format(decimal_date, file_path))

        # do some checks
        i_date = date_indexes[0]

        for measure in ["MF", "SV", "U", "ER"]:
            # Check that measures match
            assert measure in file_data
            cs_nb_reals = algo.nb_realisations
            file_nb_reals = file_data[measure][:, i_date].shape[0]
            assert cs_nb_reals <= file_nb_reals
            cs_nb_coefs = self._measures[measure][:, 0].shape[1]
            file_nb_coefs = file_data[measure][:, i_date].shape[1]
            assert cs_nb_coefs <= file_nb_coefs

            # If so, initialise with the file_data
            self._measures[measure][:, 0] = file_data[measure][algo.attributed_models, i_date, :cs_nb_coefs]

        # build Z
        Z = np.zeros((algo.attributed_models.shape[0], Nz))
        for i_idx, i_real in enumerate(algo.attributed_models):
            if algo.check_PCA():
                # perform PCA
                U = algo.config.pcaU_operator.transform(file_data["U"][i_real, i_date])
            else:
                U = file_data["U"][i_real, i_date] - algo.avg_prior["U"]
            ER = file_data["ER"][i_real, i_date] - algo.avg_prior["ER"]
            Z[i_idx] = np.concatenate((U, ER))
        self._measures["Z"][:, 0] = Z

        if AR_type == "AR3":
            dt_f = algo.config.dt_f

            # If derivatives in file_data
            if 'dUdt' in file_data.keys() and 'dERdt' in file_data.keys() and 'd2Udt2' in file_data.keys() and 'd2ERdt2' in file_data.keys():

                Z_AR3 = np.zeros((algo.attributed_models.shape[0], 3, Nz))
                for i_idx, i_real in enumerate(algo.attributed_models):
                    # compute dZ d2Z
                    if algo.check_PCA():
                        # perform PCA
                        dUdt = algo.config.pcaU_operator.transform_deriv(file_data["dUdt"][i_real, i_date])
                        d2Udt2 = algo.config.pcaU_operator.transform_deriv(file_data["d2Udt2"][i_real, i_date])
                    else:
                        dUdt = file_data["dUdt"][i_real, i_date]
                        d2Udt2 = file_data["d2Udt2"][i_real, i_date]
                    dERdt = file_data["dERdt"][i_real, i_date]
                    d2ERdt2 = file_data["d2ERdt2"][i_real, i_date]
                    dZdt = np.concatenate((dUdt, dERdt))
                    d2Zdt2 = np.concatenate((d2Udt2, d2ERdt2))
                    # computes Z_AR3
                    # taylor series of order 2
                    Z_AR3[i_idx, 2] = Z[i_idx]
                    Z_AR3[i_idx, 1] = Z_AR3[i_idx, 2] - dt_f * dZdt + dt_f**2 / 2 * d2Zdt2
                    Z_AR3[i_idx, 0] = Z_AR3[i_idx, 2] - (2*dt_f) * dZdt + (2*dt_f)**2 / 2 * d2Zdt2

            # If no derivatives in file_data NOT TESTED YET
            else:
                # NOTE(review): date_indexes only counts the states found around the requested
                # date, not the total number of states in the file - confirm this is the intended check.
                if len(date_indexes) < 3:
                    raise ValueError('Initialisation from file requires at least 3 analysed states in {} but contains only {}'.format(file_path, len(date_indexes)))

                # compute dZ and d2Z
                dt_a = file_data['times'][1] - file_data['times'][0]
                # file_data is accessed as a mapping everywhere else in this method, so Z is
                # read by key as well (attribute access would fail on a dict).
                # NOTE(review): cs_nb_coefs is the coefficient count of the last measure checked
                # above ('ER'), presumably Nz was intended for the slice - confirm.
                # NOTE(review): i_date - 1 is negative when i_date == 0, which would make the
                # time slice empty - confirm the expected behaviour on the first date of the file.
                dZdt, d2Zdt2 = common.compute_derivative(file_data["Z"][algo.attributed_models, i_date - 1 : i_date + 2, :cs_nb_coefs], dt_a)

                # compute Z_AR3
                Z_AR3 = np.zeros((algo.attributed_models.shape[0], 3, Nz))

                # taylor series of order 2
                Z_AR3[:, 2] = Z
                Z_AR3[:, 1] = Z - dt_f * dZdt + dt_f**2 / 2 * d2Zdt2
                Z_AR3[:, 0] = Z - (2*dt_f) * dZdt + (2*dt_f)**2 / 2 * d2Zdt2

            return Z_AR3
        else:
            return None

530 

531 

def with_core_state_of_dimensions(ndim_desired, arg_pos=1):
    """
    Decorator factory: the decorated callable is only executed if the positional
    argument found at arg_pos is a non-empty CoreState whose measures all have
    exactly ndim_desired dimensions.

    :param ndim_desired: number of desired dimensions for the CoreState
    :type ndim_desired: int
    :param arg_pos: position of the CoreState arg in the decorated function
    :type arg_pos: int
    :raises TypeError: (at call time) if the argument is not a CoreState
    :raises ValueError: (at call time) if the CoreState is empty or a measure has another number of dimensions
    """
    def checking_decorator(method):
        @wraps(method)
        def checked_method(*args, **kwargs):
            # The CoreState is looked up among the positional arguments only
            candidate = args[arg_pos]

            if not isinstance(candidate, CoreState):
                raise TypeError('Expected CoreState object in {}, got {} instead'.format(method.__name__, type(candidate)))

            if not candidate.measures:
                raise ValueError('Expected {} dimensions for the core_state, got an empty core state instead'.format(ndim_desired))

            # Every stored measure must have the requested number of dimensions
            for name, data in candidate._measures.items():
                found_ndim = data.ndim
                if found_ndim == ndim_desired:
                    continue
                raise ValueError('Expected {} dimensions for the core_state, got {} instead for {}'.format(ndim_desired, found_ndim, name))

            return method(*args, **kwargs)

        return checked_method

    return checking_decorator

558 

559 

@with_core_state_of_dimensions(1, arg_pos=0)
def coef_print(core_state_1D, n_coef):
    """
    Convenience function building a one-line summary of the coefficient of index
    n_coef for every measure of a 1D core state. An empty string is returned when
    n_coef is not lower than Nb, Nu2 and Nsv of the core state.

    :param core_state_1D: 1D Core state with all quantities
    :type core_state_1D: corestates.CoreState
    :param n_coef: index of the coef to print
    :type n_coef: int
    :returns: a string with the 'n_coef'-th coefficients of all measures of the input Corestate
    :rtype: str
    """
    assert type(n_coef) is int and n_coef >= 0, 'Coef_print : {} is not a valid index'.format(n_coef)

    # The index must fit in the MF, U and SV parts of the core state
    within_bounds = (n_coef < core_state_1D.Nb
                     and n_coef < core_state_1D.Nu2
                     and n_coef < core_state_1D.Nsv)
    if not within_bounds:
        return ''

    # One "name: value" entry per measure, separated by single spaces
    entries = ["{}: {}".format(name, core_state_1D.getMeasure(name)[n_coef])
               for name in core_state_1D.measures]
    return " ".join(entries) + "\n"