Coverage for pygeodyn/corestates.py: 89%
272 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-22 13:43 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-22 13:43 +0000
1import numpy as np
2import scipy as sp
3from mpi4py import MPI
4from functools import wraps
5import pygeodyn.common as common
6from pygeodyn.utilities import date_to_decimal
7from pygeodyn.inout.reads import read_analysed_states_hdf5
# Module-level MPI handles: the world communicator and the rank of this process within it.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
class CoreState:
    """
    CoreState object. It stores the measure data in a dict of ndarrays, along with the
    max degree associated with each measure.
    """

    def __init__(self, init_measures=None):
        """
        Initiates the CoreState. Initial measures can be given in the init_measures arg:
            - Either with only the data (max_degree will be inferred from the data)
                Ex: CoreState({SV: np.zeros(224), ...})
            - Or by giving also the max_degree as a second member of a 2-tuple/list (the first being the data):
                Ex: CoreState({SV: [np.zeros(224), 14], ...})

        :param init_measures: The measures to add, in a dict whose values can be np.ndarray or 2-tuple/lists.
        :type init_measures: dict or None
        """
        self._measures = {}
        self._max_degrees = {}

        if init_measures is not None and isinstance(init_measures, dict):
            for meas_id, measure_data in init_measures.items():
                # If 2-tuple/list, assume data in first member and Lmax in second
                if isinstance(measure_data, (list, tuple)) and len(measure_data) == 2:
                    self.addMeasure(meas_id, measure_data[0], measure_data[1])
                # Else add the data with Lmax inferred from the data
                else:
                    self.addMeasure(meas_id, measure_data)

    def create_CoreState(self):
        """
        Create an instantiation of the class CoreState.
        Convenience function created for inheritance purposes (subclasses can override it
        so that slicing/copying methods return instances of the subclass).
        """
        return CoreState()

    def addMeasure(self, meas_id, meas_data, meas_max_degree=None):
        """
        Adds a measure to the CoreState.

        :param meas_id: name of the measure. Used as key of dict for internal storing.
        :type meas_id: str
        :param meas_data: data of the measure. Stored as is (no copy).
        :type meas_data: np.ndarray
        :param meas_max_degree: Max degree of the measure. If None (default), it is inferred from the
            last dimension of the data (N = L(L+2), or N/2 = L(L+2) for U, dUdt, d2Udt2 and S).
            The measure 'Z' gets -1 when not given.
        :type meas_max_degree: int, numpy integer or None
        :raises ValueError: if the last dimension of the data does not correspond to an integer max degree.
        """
        if meas_max_degree is None:
            if meas_id == 'Z':
                # Z is built by concatenating U and ER parts: it has no single max degree
                meas_max_degree = -1
            else:
                nb_coeffs = meas_data.shape[-1]
                # Measures derived from U or S hold two sets of coefficients: N/2 = L(L+2)
                if meas_id in ('U', 'dUdt', 'd2Udt2', 'S'):
                    assert nb_coeffs % 2 == 0
                    nb_coeffs = nb_coeffs // 2
                # Invert N = L(L+2) = (L+1)^2 - 1
                computed_Lmax = np.sqrt(nb_coeffs + 1) - 1
                if computed_Lmax != int(computed_Lmax):
                    raise ValueError('Last dimension {} of the data given for {} does not lead to an integer max degree ! Got {} instead.'
                                     .format(meas_data.shape[-1], meas_id, computed_Lmax))
                meas_max_degree = int(computed_Lmax)

        # Accept numpy integers (e.g. read from files) but always store a builtin int
        assert isinstance(meas_max_degree, (int, np.integer))
        self._max_degrees[meas_id] = int(meas_max_degree)
        self._measures[meas_id] = meas_data

    def __getitem__(self, item):
        """
        Dict-like access when item is a measure name, ndarray-like slicing otherwise.

        :param item: measure name (str) or any index accepted by the measure arrays
        :return: the measure data for a str, else a new CoreState holding the sliced measures
            (views for basic slicing, as returned by the arrays' own __getitem__)
        """
        if isinstance(item, str):
            return self.getMeasure(item)
        new_cs = self.create_CoreState()
        for id_qty, qty in self._measures.items():
            new_cs.addMeasure(id_qty, qty[item], self._max_degrees[id_qty])
        return new_cs

    def __setitem__(self, key, value):
        """
        Set the items at `key` of every measure, in place.

        :param key: any index accepted by the measure arrays
        :param value: a CoreState (set measure by measure, it must hold all measures of self)
            or any value assignable to every measure array
        """
        for meas_id, meas_data in self._measures.items():
            if isinstance(value, CoreState):
                meas_data[key] = value.getMeasure(meas_id)
            else:
                meas_data[key] = value

    def update_corestate(self, i_t, corestate, exclude=("S",)):
        """
        Copy the measures of `corestate` into time index `i_t` of the measures of self.

        Measures are indexed as [:, i_t, :], so self is expected to hold 3D measures
        and `corestate` the matching 2D ones.

        :param i_t: time iteration
        :type i_t: int
        :param corestate: corestate to update from
        :type corestate: CoreState
        :param exclude: measures not updated (default: "S")
        :type exclude: list or tuple of str
        """
        for meas_id in self._measures:
            if meas_id not in exclude:
                self._measures[meas_id][:, i_t, :] = corestate._measures[meas_id]

    def copy(self):
        """
        :return: a CoreState with copied measures
        :rtype: CoreState
        """
        new_cs = self.create_CoreState()
        for meas_id, meas_data in self._measures.items():
            new_cs.addMeasure(meas_id, meas_data.copy(), self._max_degrees[meas_id])
        return new_cs

    def mean(self, *args, **kwargs):
        """
        Arguments are forwarded to np.ndarray.mean of each measure.

        :return: a new CoreState with averaged measures
        :rtype: CoreState
        """
        new_cs = self.create_CoreState()
        for meas_id, meas_data in self._measures.items():
            new_cs.addMeasure(meas_id, meas_data.mean(*args, **kwargs), self._max_degrees[meas_id])
        return new_cs

    def std(self, *args, **kwargs):
        """
        Arguments are forwarded to np.ndarray.std of each measure (positional arguments
        such as the axis are accepted, consistently with `mean`).

        :return: a new CoreState with the standard deviation of the measures
        :rtype: CoreState
        """
        new_cs = self.create_CoreState()
        for meas_id, meas_data in self._measures.items():
            new_cs.addMeasure(meas_id, meas_data.std(*args, **kwargs), self._max_degrees[meas_id])
        return new_cs

    @property
    def measures(self):
        """
        :return: A *sorted* list of the measures of the CoreState
        :rtype: list
        """
        return sorted(self._measures.keys())

    def keys(self):
        """
        Convenience function to use CoreState as a dict

        :return: A *sorted* list of the measures of the CoreState
        :rtype: list
        """
        return self.measures

    def _getMaxDegree(self, measure_type):
        """
        :param measure_type: type of the measure
        :type measure_type: str
        :return: the max degree of the asked type or 0 if the measure is not loaded
        :rtype: int
        """
        return self._max_degrees.get(measure_type, 0)

    @property
    def Lb(self):
        """Max degree of the magnetic field (0 if not loaded)."""
        return self._getMaxDegree('MF')

    @property
    def Nb(self):
        """Number of magnetic field coefficients Lb(Lb+2)."""
        return self.Lb * (self.Lb + 2)

    @property
    def Lu(self):
        """Max degree of the core flow (0 if not loaded)."""
        return self._getMaxDegree('U')

    @property
    def Nu2(self):
        """Number of core flow coefficients 2Lu(Lu+2)."""
        return 2 * self.Lu * (self.Lu + 2)

    @property
    def Lsv(self):
        """Max degree of the secular variation (0 if not loaded)."""
        return self._getMaxDegree('SV')

    @property
    def Nsv(self):
        """Number of secular variation coefficients Lsv(Lsv+2)."""
        return self.Lsv * (self.Lsv + 2)

    # Getters/Setters for CoreState quantities
    def getMeasure(self, measure_type):
        """
        Generic getter for measures. Raises a KeyError if the measure is not loaded.

        :param measure_type: type of the measure to get
        :type measure_type: str
        :return: the measure data
        :rtype: np.ndarray
        """
        if measure_type not in self._measures:
            raise KeyError('{} is not a valid measure to get from CoreState.'.format(measure_type))
        return self._measures[measure_type]

    def setMeasure(self, measure_type, value):
        """
        Generic setter for measures. Raises a KeyError if the measure is not loaded.

        :param measure_type: type of the measure to set
        :type measure_type: str
        :param value: the new measure data
        :type value: np.ndarray
        """
        if measure_type not in self._measures:
            raise KeyError('{} is not a valid measure to set from CoreState.'.format(measure_type))
        self._measures[measure_type] = value

    @property
    def B(self):
        """
        Get the magnetic field (B) part of the CoreState.

        :return: Magnetic field data
        :rtype: np.ndarray
        """
        return self._measures['MF']

    @B.setter
    def B(self, mag_field):
        self._measures['MF'] = mag_field

    @property
    def U(self):
        """
        Get the core flow (U) part of the CoreState.

        :return: Core flow data
        :rtype: np.ndarray
        """
        return self._measures['U']

    @U.setter
    def U(self, core_flow):
        self._measures['U'] = core_flow

    @property
    def dUdt(self):
        """
        Get time derivative of core flow (dUdt) part of the CoreState.

        :return: time derivative of core flow
        :rtype: np.ndarray
        """
        return self._measures['dUdt']

    @dUdt.setter
    def dUdt(self, time_derivative_core_flow):
        self._measures['dUdt'] = time_derivative_core_flow

    @property
    def d2Udt2(self):
        """
        Get second time derivative of core flow (d2Udt2) part of the CoreState.

        :return: second time derivative of core flow
        :rtype: np.ndarray
        """
        return self._measures['d2Udt2']

    @d2Udt2.setter
    def d2Udt2(self, second_time_derivative_core_flow):
        self._measures['d2Udt2'] = second_time_derivative_core_flow

    @property
    def S(self):
        """
        Get the shear (S) part of the CoreState.

        :return: Shear data
        :rtype: np.ndarray
        """
        return self._measures['S']

    @S.setter
    def S(self, shear):
        self._measures['S'] = shear

    @property
    def ER(self):
        """
        Get the subgrid errors (ER) part of the CoreState.

        :return: Subgrid errors data
        :rtype: np.ndarray
        """
        return self._measures['ER']

    @ER.setter
    def ER(self, subgrid_error):
        self._measures['ER'] = subgrid_error

    @property
    def dERdt(self):
        """
        Get the derivative of subgrid errors (dERdt) part of the CoreState.

        :return: derivative of subgrid errors
        :rtype: np.ndarray
        """
        return self._measures['dERdt']

    @dERdt.setter
    def dERdt(self, derivative_subgrid_error):
        self._measures['dERdt'] = derivative_subgrid_error

    @property
    def d2ERdt2(self):
        """
        Get the second derivative of subgrid errors (d2ERdt2) part of the CoreState.

        :return: second derivative of subgrid errors
        :rtype: np.ndarray
        """
        return self._measures['d2ERdt2']

    @d2ERdt2.setter
    def d2ERdt2(self, second_derivative_subgrid_error):
        self._measures['d2ERdt2'] = second_derivative_subgrid_error

    @property
    def Z(self):
        """
        Get the Augmented state (Z) part of the CoreState.

        :return: Coreflow and subgrid errors data
        :rtype: np.ndarray
        """
        return self._measures['Z']

    @Z.setter
    def Z(self, augmented_state):
        self._measures['Z'] = augmented_state

    @property
    def SV(self):
        """
        Get the Core secular variation (SV) part of the CoreState.

        :return: Secular variation data
        :rtype: np.ndarray
        """
        return self._measures['SV']

    @SV.setter
    def SV(self, secular_variation):
        self._measures['SV'] = secular_variation

    def initialise_from_noised_priors(self, algo, random_state=None):
        """
        Initialise the core state Z MF U ER SV at t=0 with a random normal draw around the average priors.

        B is drawn from N(avg_prior['B'], cov_prior['B,B']) and Z from N(0, cov_prior['Z,Z']);
        U and ER are deduced from Z and SV = A(B) U + ER.

        :param algo: algo instance
        :type algo: Augkf.algo
        :param random_state: if not None, its `normal` method is used for the random draws;
            otherwise np.random.normal is used.
        :type random_state: numpy.RandomState or None
        :return: for AR_type "AR3", an array of shape (nb attributed models, 3, Nz) holding Z at
            t=0 - 2 dt_f, t=0 - dt_f and t=0 (Taylor series of order 2 with random dZ, d2Z);
            None otherwise. self is updated in place in both cases.
        :rtype: np.ndarray or None
        """
        # THE CORESTATE HAS A NUMBER OF REALISATIONS THAT MATCHES THE NUMBER OF ATTRIBUTED MODELS
        avg_b = algo.avg_prior['B']
        AR_type = algo.config.AR_type
        Nz = algo.config.Nz
        dt_f = algo.config.dt_f
        n_models = algo.attributed_models.shape[0]

        # Lower Cholesky matrices
        L_bb = sp.linalg.cholesky(algo.cov_prior['B,B'], lower=True)
        L_zz = sp.linalg.cholesky(algo.cov_prior['Z,Z'], lower=True)

        if AR_type == "AR3":
            L_dzdz = sp.linalg.cholesky(algo.cov_prior['dZ,dZ'], lower=True)
            L_d2zd2z = sp.linalg.cholesky(algo.cov_prior['d2Z,d2Z'], lower=True)

        normal_draw = random_state.normal if random_state is not None else np.random.normal

        for i_idx in range(n_models):
            w_b = normal_draw(size=algo.config.Nb)
            w_z = normal_draw(size=Nz)

            # B part: N(mean_b, sigma_b)
            self.B[i_idx, 0] = avg_b + np.matmul(L_bb, w_b)
            # Z part: N(0, sigma_z)
            self.Z[i_idx, 0] = np.matmul(L_zz, w_z)
            # Z to U ER
            self.U[i_idx, 0], self.ER[i_idx, 0] = algo.Z_to_U_ER(self.Z[i_idx, 0], 1)
            # SV = A(b) U + ER
            self.SV[i_idx, 0] = algo.analyser.compute_Ab(self[i_idx, 0]) @ self.U[i_idx, 0] + self.ER[i_idx, 0]

        if AR_type != "AR3":
            return None

        Z_AR3 = np.zeros((n_models, 3, Nz))
        w_dz = normal_draw(size=(n_models, Nz))
        w_d2z = normal_draw(size=(n_models, Nz))
        # dZ, d2Z drawn from their prior normal distributions
        dZ = w_dz @ L_dzdz.T
        d2Z = w_d2z @ L_d2zd2z.T
        # Taylor series of order 2
        Z_AR3[:, 2] = self.Z[:, 0]
        Z_AR3[:, 1] = self.Z[:, 0] - dt_f * dZ + dt_f**2 / 2 * d2Z
        Z_AR3[:, 0] = self.Z[:, 0] - (2*dt_f) * dZ + (2*dt_f)**2 / 2 * d2Z
        return Z_AR3

    @staticmethod
    def _Z_from_file_data(algo, file_data, i_real, i_date):
        """
        Build the augmented state Z of one realisation at one time index of file_data,
        by concatenating its (PCA-transformed or de-averaged) U part and de-averaged ER part.

        :return: Z for realisation i_real at time index i_date
        :rtype: np.ndarray
        """
        if algo.check_PCA():
            U = algo.config.pcaU_operator.transform(file_data["U"][i_real, i_date])
        else:
            U = file_data["U"][i_real, i_date] - algo.avg_prior["U"]
        ER = file_data["ER"][i_real, i_date] - algo.avg_prior["ER"]
        return np.concatenate((U, ER))

    def initialise_from_file(self, algo):
        """
        Initialise the core state Z MF U ER SV at t=0 from the analysed states stored in a file.

        The file path and the decimal date are read from algo.config.init_file and
        algo.config.init_date. Among the analysed states whose time lies within
        (date - 0.1, date + 0.1], the one closest to the date is used.

        :param algo: algo instance
        :type algo: Augkf.algo
        :return: for AR_type "AR3", an array of shape (nb attributed models, 3, Nz) holding Z at
            t=0 - 2 dt_f, t=0 - dt_f and t=0 (Taylor series of order 2); None otherwise.
            self is updated in place in both cases.
        :rtype: np.ndarray or None
        :raises ValueError: if no analysed state is found around the date, or if the analysed
            states before and after it (needed to estimate derivatives when the file has none)
            are missing.
        """
        # THE CORESTATE HAS A NUMBER OF REALISATIONS THAT MATCHES THE NUMBER OF ATTRIBUTED MODELS
        file_path = algo.config.init_file
        decimal_date = algo.config.init_date
        AR_type = algo.config.AR_type
        file_data = read_analysed_states_hdf5(file_path)
        Nz = algo.config.Nz
        n_models = algo.attributed_models.shape[0]

        times = file_data['times']
        date_indexes = np.where((times <= decimal_date + 0.1) & (times > decimal_date - 0.1))[0]
        if date_indexes.shape[0] == 0:
            raise ValueError("{} analysed state not found in {}, check the input date".format(decimal_date, file_path))
        # Several analysed states may lie in the window: take the closest to the asked date
        i_date = int(date_indexes[np.argmin(np.abs(times[date_indexes] - decimal_date))])

        for measure in ["MF", "SV", "U", "ER"]:
            # Check that measures match
            assert measure in file_data
            cs_nb_reals = algo.nb_realisations
            file_nb_reals = file_data[measure][:, i_date].shape[0]
            assert cs_nb_reals <= file_nb_reals
            cs_nb_coefs = self._measures[measure][:, 0].shape[1]
            file_nb_coefs = file_data[measure][:, i_date].shape[1]
            assert cs_nb_coefs <= file_nb_coefs

            # If so, initialise with the file_data
            self._measures[measure][:, 0] = file_data[measure][algo.attributed_models, i_date, :cs_nb_coefs]

        # build Z
        Z = np.zeros((n_models, Nz))
        for i_idx, i_real in enumerate(algo.attributed_models):
            Z[i_idx] = self._Z_from_file_data(algo, file_data, i_real, i_date)
        self._measures["Z"][:, 0] = Z

        if AR_type != "AR3":
            return None

        dt_f = algo.config.dt_f
        derivative_keys = ('dUdt', 'dERdt', 'd2Udt2', 'd2ERdt2')

        if all(key in file_data.keys() for key in derivative_keys):
            # Derivatives are stored in file_data
            dZdt = np.zeros((n_models, Nz))
            d2Zdt2 = np.zeros((n_models, Nz))
            for i_idx, i_real in enumerate(algo.attributed_models):
                if algo.check_PCA():
                    dUdt = algo.config.pcaU_operator.transform_deriv(file_data["dUdt"][i_real, i_date])
                    d2Udt2 = algo.config.pcaU_operator.transform_deriv(file_data["d2Udt2"][i_real, i_date])
                else:
                    dUdt = file_data["dUdt"][i_real, i_date]
                    d2Udt2 = file_data["d2Udt2"][i_real, i_date]
                dERdt = file_data["dERdt"][i_real, i_date]
                d2ERdt2 = file_data["d2ERdt2"][i_real, i_date]
                dZdt[i_idx] = np.concatenate((dUdt, dERdt))
                d2Zdt2[i_idx] = np.concatenate((d2Udt2, d2ERdt2))
        else:
            # No derivatives in file_data: estimate them from the analysed states around i_date
            # NOT TESTED YET
            if i_date < 1 or i_date + 1 >= times.shape[0]:
                raise ValueError('Initialisation from file without derivatives requires analysed states before and after {} in {}'
                                 .format(decimal_date, file_path))
            dt_a = times[1] - times[0]
            # Z at i_date - 1, i_date and i_date + 1, built like Z at i_date above
            Z_window = np.zeros((n_models, 3, Nz))
            for i_idx, i_real in enumerate(algo.attributed_models):
                for i_w, i_t in enumerate(range(i_date - 1, i_date + 2)):
                    Z_window[i_idx, i_w] = self._Z_from_file_data(algo, file_data, i_real, i_t)
            dZdt, d2Zdt2 = common.compute_derivative(Z_window, dt_a)

        # Taylor series of order 2
        Z_AR3 = np.zeros((n_models, 3, Nz))
        Z_AR3[:, 2] = Z
        Z_AR3[:, 1] = Z - dt_f * dZdt + dt_f**2 / 2 * d2Zdt2
        Z_AR3[:, 0] = Z - (2*dt_f) * dZdt + (2*dt_f)**2 / 2 * d2Zdt2
        return Z_AR3
def with_core_state_of_dimensions(ndim_desired, arg_pos=1):
    """
    Build a decorator validating the CoreState handed to the decorated function.

    Before the wrapped function runs, its positional argument at index ``arg_pos`` must
    be a non-empty CoreState whose measures all have ``ndim_desired`` dimensions.

    :param ndim_desired: number of desired dimensions for the CoreState
    :type ndim_desired: int
    :param arg_pos: position of the CoreState arg in the decorated function
    :type arg_pos: int
    :raises TypeError: (from the wrapped call) if the argument is not a CoreState
    :raises ValueError: (from the wrapped call) if the CoreState is empty or a measure has the wrong ndim
    """
    def decorator(method):
        @wraps(method)
        def wrapper(*args, **kwargs):
            candidate = args[arg_pos]
            if not isinstance(candidate, CoreState):
                raise TypeError('Expected CoreState object in {}, got {} instead'.format(method.__name__, type(candidate)))
            if not candidate.measures:
                raise ValueError('Expected {} dimensions for the core_state, got an empty core state instead'.format(ndim_desired))
            # First measure (in storage order) whose number of dimensions is not the desired one
            mismatch = next(((name, data.ndim) for name, data in candidate._measures.items()
                             if data.ndim != ndim_desired), None)
            if mismatch is not None:
                bad_name, bad_ndim = mismatch
                raise ValueError('Expected {} dimensions for the core_state, got {} instead for {}'.format(ndim_desired, bad_ndim, bad_name))
            return method(*args, **kwargs)
        return wrapper

    return decorator
@with_core_state_of_dimensions(1, arg_pos=0)
def coef_print(core_state_1D, n_coef):
    """
    Convenience function to print the coef of all core state quantities of a certain index.
    The index must be lower than the length of the smallest stored quantity; otherwise
    an empty string is returned.

    :param core_state_1D: 1D Core state with all quantities
    :type core_state_1D: corestates.CoreState
    :param n_coef: index of the coef to print
    :type n_coef: int
    :returns: a string "id: value id: value ...\\n" with the 'n_coef'-th coefficients of all
        measures of the input CoreState (sorted by measure name), or '' if n_coef is out of bounds
    :rtype: str
    """
    assert type(n_coef) is int and n_coef >= 0, 'Coef_print : {} is not a valid index'.format(n_coef)

    measure_ids = core_state_1D.measures
    # Check n_coef against every stored measure (not only MF/U/SV, whose sizes are 0 when absent)
    if any(n_coef >= core_state_1D.getMeasure(measure_id).shape[-1] for measure_id in measure_ids):
        return ''

    return ' '.join('{}: {}'.format(measure_id, core_state_1D.getMeasure(measure_id)[n_coef])
                    for measure_id in measure_ids) + '\n'