Coverage for pygeodyn/augkf/forecaster.py: 100%

56 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-22 13:43 +0000

1import numpy as np 

2from .. import common 

3from ..generic.computer import GenericComputer 

4from pygeodyn.corestates import with_core_state_of_dimensions 

5from mpi4py import MPI 

6from .. import corestates as cs 

7 

# Communicator spanning every MPI process of the run.
comm = MPI.COMM_WORLD

# Rank of the current process in that communicator (used to tag printed output).
rank = comm.Get_rank()

10 

class AugkfForecasterAR1(GenericComputer):
    """
    Class that implements the forecasts using AugKF (Augmented state Kalman Filter) algorithm with DIFF treated as a contribution to ER.

    Z is propagated with an AR-1 process. Subclasses can switch to another
    process by overriding needed_covariance_matrices() and forecast_Z().
    """

    def __init__(self, algo):
        """
        Check that every covariance matrix needed by the forecast is available.

        :param algo: Algorithm object
        :type algo: Algo
        :raises KeyError: if a matrix listed by needed_covariance_matrices() is missing from algo.cov_prior
        """
        super().__init__(algo)

        for k in self.needed_covariance_matrices():
            if k not in algo.cov_prior:
                raise KeyError('{} matrix not found in covariance matrices while needed for Forecast !'.format(k))

        # Bool to deactivate checks on AR processes
        self.Cholesky_AR_check = False

    def needed_covariance_matrices(self):
        """
        Keys of algo.cov_prior that must be present for the forecast step.

        :return: names of the needed covariance matrices
        :rtype: list of str
        """
        return ['A', 'Chol']

    @with_core_state_of_dimensions(1)
    def forecast_step(self, input_core_state, Z_AR, seed, i_real, i_t):
        """
        Forecasts the input_core_state using AR processes for Z, computation of SV and Euler scheme for B.

        :param input_core_state: core_state of a single realisation at a single date
        :type input_core_state: corestates.CoreState
        :param Z_AR: forecast state AR
        :type Z_AR: np.array 1D if AR1 (Ncoef) or 2D if AR3 (3 x Ncoef)
        :param seed: random seed
        :type seed: int
        :param i_real: model realisation index
        :type i_real: int
        :param i_t: time index
        :type i_t: int
        :return: CoreState containing the result from the forecast
        :rtype: corestates.CoreState
        """
        # Work on a copy so that the input core state is left untouched
        next_core_state = input_core_state.copy()

        # Dedicated random generator for this (realisation, time step) pair:
        # the base seed is offset by i_real + nb_realisations * i_t
        N = self.algo.nb_realisations
        rstate = np.random.default_rng(seed + i_real + N * i_t)

        # Compute Z(t+1), then derive U(t+1) and ER(t+1) from it
        next_core_state.Z = self.forecast_Z(Z_AR, rstate=rstate)
        next_core_state.U, next_core_state.ER = self.algo.Z_to_U_ER(next_core_state.Z, 1)

        # Compute A(b)
        Ab = self.compute_Ab(next_core_state)

        # Compute SV(t+1) = A(b)U(t+1) + E(t+1)
        next_core_state.SV = np.matmul(Ab, next_core_state.U) + next_core_state.ER

        # Compute B(t+1) = B(t) + dt*SV(t+1) (Euler scheme as SV=dB/dt)
        next_core_state.B = input_core_state.B + self.cfg.dt_f * next_core_state.SV

        # Returns the updated core_state at t+1
        return next_core_state

    def forecast_Z(self, Z_AR1, rstate=None):
        """
        Forecast Z state with AR-1 process.

        :param Z_AR1: AR-1 forecast state
        :type Z_AR1: np.array (Ncoef)
        :param rstate: Random generator to use for the AR-1 process (forecast_step passes a np.random.default_rng instance)
        :type rstate: np.random.Generator or None
        :return: forecasted Z state
        :rtype: np.array(Ncoef)
        """
        # AR1 process for Z, driven by the 'A' and 'Chol' prior matrices
        return common.ar1_process(Z_AR1,
                                  self.algo.cov_prior['A'],
                                  self.algo.cov_prior['Chol'],
                                  random_state=rstate)

    def parallel_forecast_step(self, input_core_state, seed, i_t):
        """
        Run the AR1 forecast step for every model attributed to the current process.

        The result has the same layout as input_core_state; entries that are
        not computed by this process are left at 0.

        :param input_core_state: input_core_state at time t
        :type input_core_state: corestates.CoreState
        :param seed: random seed
        :type seed: int
        :param i_t: time index
        :type i_t: int
        :return: CoreState containing the result from the forecast
        :rtype: corestates.CoreState
        """
        t = self.algo.config.t_forecasts[i_t]

        # copy input core state
        forecast_at_t = input_core_state.copy()
        # set all measures to 0
        forecast_at_t[:] = 0

        # Each process computes its attributed models.
        # NOTE(review): the position i_idx inside attributed_models (not the
        # value stored there) is used both to index the core state and to
        # offset the seed - confirm this is intended when realisations are
        # split across several processes.
        last_idx = None
        for i_idx, _model in enumerate(self.algo.attributed_models):
            forecast_at_t[i_idx] = self.forecast_step(input_core_state[i_idx], input_core_state.Z[i_idx], seed, i_idx, i_t)
            last_idx = i_idx

        # Log the last realisation computed by this process; skipped when no
        # model was attributed to it, so that there is always a valid index.
        if last_idx is not None:
            print('P'+str(rank), t, cs.coef_print(forecast_at_t[last_idx], 0))

        return forecast_at_t

122 

123 

class AugkfForecasterAR3(AugkfForecasterAR1):
    """
    Variant of AugkfForecasterAR1 in which Z is propagated with an AR-3 process.

    The AR-3 process needs the three latest Z states of each realisation;
    they are carried between calls in the Z_AR3 array.
    """

    def needed_covariance_matrices(self):
        """
        Keys of algo.cov_prior that must be present for the AR-3 forecast step.

        :return: names of the needed covariance matrices
        :rtype: list of str
        """
        return ['A', 'B', 'C', 'Chol']

    def forecast_Z(self, Z_AR3, rstate=None):
        """
        Forecast Z state with AR-3 process.

        :param Z_AR3: AR-3 forecast state of a single realisation
        :type Z_AR3: np.array (3 x Ncoef)
        :param rstate: Random generator to use for the AR-3 process
        :type rstate: np.random.Generator or None
        :return: forecasted Z state
        :rtype: np.array(Ncoef)
        """
        # AR3 process for Z, driven by the 'A', 'B', 'C' and 'Chol' prior matrices
        return common.ar3_process(Z_AR3,
                                  self.algo.cov_prior['A'],
                                  self.algo.cov_prior['B'],
                                  self.algo.cov_prior['C'],
                                  self.algo.cov_prior['Chol'],
                                  random_state=rstate)

    def update_Z_AR3(self, Z, Z_AR3):
        """
        Update forecast Z state for AR3.

        The new Z is appended along axis 1 and the oldest state is dropped,
        so the returned array keeps three states per realisation.

        :param Z: forecasted Z
        :type Z: np.array (nreal x Ncoef)
        :param Z_AR3: AR-3 forecast state
        :type Z_AR3: np.array (nreal x 3 x Ncoef)
        :return: updated AR-3 forecast state
        :rtype: np.array (nreal x 3 x Ncoef)
        """
        # Append the new state, then drop the oldest one (index 0 on axis 1)
        shifted = np.concatenate((Z_AR3, Z[:, np.newaxis, :]), axis=1)[:, 1:, :]
        assert shifted.shape[1] == 3, 'AR-3 forecast state must hold 3 states per realisation'
        return shifted

    def parallel_forecast_step(self, input_core_state, Z_AR3, seed, i_t):
        """
        Run the AR3 forecast step for every model attributed to the current process.

        The forecast has the same layout as input_core_state; entries that are
        not computed by this process are left at 0.

        :param input_core_state: input_core_state at time t
        :type input_core_state: corestates.CoreState
        :param Z_AR3: AR-3 forecast state
        :type Z_AR3: np.array (nreal x 3 x Ncoef)
        :param seed: random seed
        :type seed: int
        :param i_t: time index
        :type i_t: int
        :return: CoreState containing the result from the forecast, and the updated AR-3 forecast state
        :rtype: (corestates.CoreState, np.array (nreal x 3 x Ncoef))
        """
        t = self.algo.config.t_forecasts[i_t]

        # copy input core state
        forecast_at_t = input_core_state.copy()
        # set all measures to 0
        forecast_at_t[:] = 0

        # Each process computes its attributed models.
        # NOTE(review): the position i_idx inside attributed_models (not the
        # value stored there) is used to index the core state and Z_AR3 and
        # to offset the seed - confirm this is intended when realisations
        # are split across several processes.
        last_idx = None
        for i_idx, _model in enumerate(self.algo.attributed_models):
            forecast_at_t[i_idx] = self.forecast_step(input_core_state[i_idx], Z_AR3[i_idx], seed, i_idx, i_t)
            last_idx = i_idx

        # Log the last realisation computed by this process; skipped when no
        # model was attributed to it, so that there is always a valid index.
        if last_idx is not None:
            print('P'+str(rank), t, cs.coef_print(forecast_at_t[last_idx], 0))

        # update Z_AR3 with the freshly forecasted Z of all realisations
        Z_AR3 = self.update_Z_AR3(forecast_at_t.Z, Z_AR3)

        return forecast_at_t, Z_AR3