Coverage for pygeodyn/run.py: 77%

190 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-22 13:43 +0000

1import os 

2 

# to improve performance in parallel

4os.environ["MKL_NUM_THREADS"] = "1" 

5os.environ["NUMEXPR_NUM_THREADS"] = "1" 

6os.environ["OMP_NUM_THREADS"] = "1" 

7 

8import time 

9import h5py 

10import logging 

11import subprocess 

12import numpy as np 

13from mpi4py import MPI 

14from .inout import writes 

15from . import corestates as cs 

16from .augkf.algo import create_augkf_algo 

17from .inout.config import ComputationConfig 

18from .shear.shear_algo import ShearInversion 

19 

def send_receive(corestate, comm, rank, handler, attributed_models):
    """
    Send every process' corestate to rank 0, where it is written out through
    the hdf5 handler. Uses point-to-point send/recv instead of comm.Gather
    that had issues with large matrices.

    :param corestate: Corestate to be sent to rank 0
    :type corestate: Cs.Corestate
    :param comm: MPI communicator
    :type comm: MPI.comm
    :param rank: process rank
    :type rank: int
    :param handler: hdf5 handler used on rank 0 to save the received states
        (may be None on other ranks, where it is never used)
    :param attributed_models: models handled by the rank process
    :type attributed_models: 1D numpy array
    """
    # Rank 0 first saves its own contribution directly.
    if rank == 0:
        handler.update_all(corestate, attributed_models)

    # Every other rank sends its contribution to rank 0, which saves it.
    # tag 0 carries the corestate, tag 1 the matching model indices.
    for i in range(1, comm.Get_size()):
        if rank == i:
            comm.send(corestate, dest=0, tag=0)
            comm.send(attributed_models, dest=0, tag=1)
        if rank == 0:
            corestate_recv = comm.recv(source=i, tag=0)
            attributed_models_recv = comm.recv(source=i, tag=1)
            handler.update_all(corestate_recv, attributed_models_recv)
    comm.Barrier()

47 

def gather_states(corestate, attributed_models, comm, rank, do_bcast=True):
    """
    Gather corestate to rank 0 (and broadcast to all ranks if do_bcast=True)

    :param corestate: Corestate to be gathered
    :type corestate: Cs.Corestate
    :param attributed_models: models handled by rank process
    :type attributed_models: 1D numpy array
    :param comm: MPI communicator
    :type comm: MPI.comm
    :param rank: process rank
    :type rank: int
    :param do_bcast: Controls whether the gathered corestate is broadcast to all processes
    :type do_bcast: boolean
    :return: corestate whose measures hold all realisations in global model
        order; on non-root ranks with do_bcast=False the measures are None
    :rtype: Cs.Corestate
    """
    # synchronize all processes
    comm.Barrier()
    # gather the model indices handled by each process on rank 0
    attributed_models_gather = comm.gather(attributed_models, root=0)
    if rank == 0:
        # concatenate the contributions from all processes and compute the
        # indices that restore the global model ordering
        sort_indices = np.argsort(np.concatenate(attributed_models_gather, axis=0))
    # for each measure in corestate
    for meas in corestate._measures:
        # synchronize all processes
        comm.Barrier()
        # gather the measure from all processes on rank 0
        X_gather = comm.gather(corestate._measures[meas], root=0)
        if rank == 0:
            # stack realisations from all processes and restore global order
            X_gather_concat = np.concatenate(X_gather, axis=0)[sort_indices]
        else:
            X_gather_concat = None
        if do_bcast:
            # synchronize all processes
            comm.Barrier()
            # broadcast the gathered measure from rank 0 to all ranks
            corestate._measures[meas] = comm.bcast(X_gather_concat, root=0)
        else:
            # synchronize all processes
            comm.Barrier()
            corestate._measures[meas] = X_gather_concat
    return corestate

100 

def choose_algorithm(algo_name, config_file, nb_models, global_seed, attributed_models, do_shear):
    """
    Instantiate the assimilation algorithm matching algo_name.

    :param algo_name: name of the algorithm; only 'augkf' is supported
    :type algo_name: str
    :param config_file: path to the configuration file
    :type config_file: str
    :param nb_models: number of realisations/models to consider
    :type nb_models: int
    :param global_seed: seed passed to the algorithm
    :param attributed_models: models handled by the calling process
    :type attributed_models: 1D numpy array
    :param do_shear: control parameter of the shear computation
    :type do_shear: int
    :raises ValueError: if algo_name is not recognised
    """
    if algo_name == 'augkf':
        logging.info('Using augmented state Kalman filter algorithm')
        return create_augkf_algo(ComputationConfig(do_shear, config_file), nb_models, global_seed, attributed_models)
    raise ValueError(f'Algorithm name {algo_name} was not recognised. Aborting computation.')

107 

def algorithm(output_path, computation_name, config_file, nb_models, do_shear, seed=None, log_file=None, logging_level=None, algo_name='augkf'):
    """
    Runs the chosen algorithm : takes care of running the forecasts, the analysis, of logging and of saving the data

    :param output_path: path where the data should be saved
    :type output_path: str
    :param computation_name: will create a folder of this name to store the output files
    :type computation_name: str
    :param config_file: path to the configuration file
    :type config_file: str
    :param nb_models: number of realisations/models to consider
    :type nb_models: int
    :param do_shear: control parameter of the shear computation (1 enables the shear inversion)
    :type do_shear: int
    :param seed: global seed used to derive one seed per MPI process; drawn at random if None
    :type seed: int or None
    :param log_file: name of the log file created in the output folder; if None, logs go to the console
    :type log_file: str or None
    :param logging_level: logging level (https://docs.python.org/3.6/library/logging.html?highlight=logging#levels); defaults to logging.INFO
    :param algo_name: name of the algorithm. Supported algorithms is 'augkf'
    :type algo_name: str
    :return: CoreStates containing all the results of the computation, the forecasts and analysis
        (on MPI rank 0; all other ranks return (None, None, None)).
    :rtype: CoreState, CoreState, CoreState
    """
    # Init MPI variables
    if not MPI.Is_initialized():
        MPI.Init()
    comm = MPI.COMM_WORLD
    nb_proc = comm.Get_size()
    rank = comm.Get_rank()

    def first_process():
        # True only on MPI rank 0 (the rank that logs git info, saves and returns data)
        return rank == 0

    def compute_shear():
        # Shear inversion is requested with do_shear == 1
        return do_shear == 1

    # set output folder
    folder_name = os.path.join(output_path, computation_name)
    os.makedirs(folder_name, exist_ok=True)

    # Set log
    if logging_level is None:
        logging_level = logging.INFO
    log_format = '%(asctime)s - Process {} - %(levelname)s - %(message)s'.format(rank)
    if log_file is None:
        logging.basicConfig(format=log_format, level=logging_level)
        print("Logs will be displayed in the console.")
    else:
        logging.basicConfig(format=log_format, level=logging_level,
                            filename=os.path.join(folder_name, log_file))
        print("Logs will be saved in {}".format(os.path.join(folder_name, log_file)))

    # Start time
    begin_time = time.time()

    # INITIALISATION OF ALGO

    # Set seed
    process_seeds = None
    if first_process():
        if seed is None:
            seed = np.random.randint(0, 50000)
        logging.info("Global seed = {}".format(seed))
        # Get sha1 and human readable info about git commit and branch
        current_dir = os.path.dirname(os.path.realpath(__file__))
        git_branch = (subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=current_dir).strip()).decode("utf-8")
        git_commit_mess = subprocess.check_output(["git", "log", "-1", "--pretty=%B"], cwd=current_dir).strip().decode("utf-8")
        git_description = subprocess.check_output(["git", "describe", "--always"], cwd=current_dir).strip().decode("utf-8")
        logging.info("Git branch: " + git_branch)
        logging.info("Git last commit message: \"" + git_commit_mess + "\" and description: " + git_description)
        # Generate a seed for each MPI processes from the global seed
        # If the same seed was given to each process, then the same random numbers would appear
        # on each process
        process_seeds = np.random.RandomState(seed).randint(0, 50000, size=nb_proc)

    # Broadcast the process seed for each MPI process
    seed = comm.bcast(seed, root=0)
    process_seeds = comm.bcast(process_seeds, root=0)
    pseed = process_seeds[rank]
    logging.info("Process {} seed = {}".format(rank, pseed))
    # Initialise the random state using the process seed
    process_rstate = np.random.RandomState(pseed)

    # Build the attribution of realisations if several processes in parallel
    # (round-robin: rank r handles models r, r+nb_proc, r+2*nb_proc, ...)
    attributed_models = np.array(list(range(rank, nb_models, nb_proc)))
    logging.debug('Process {} will process realisations: {}'.format(rank, list(attributed_models)))

    # Set algo
    algo = choose_algorithm(algo_name, config_file, nb_models, pseed, attributed_models, do_shear)

    # Initialization of states is done in each process
    # Each process has its own attributed models so that there is less transfer of arrays
    computed_states, forecast_states, analysed_states, misfits, Z_AR3 = algo.init_corestates(random_state=process_rstate)

    if first_process():
        logging.info("Forecast will be performed at following times: {}"
                     .format(algo.config.t_forecasts))

        logging.info("Analyses will be performed at following times: {}"
                     .format(algo.config.t_analyses))

    # LAUNCH THE ALGORITHM
    # INIT
    i_t = 0
    i_analysis = 0
    ratio = algo.config.dt_a_f_ratio
    idx_max = algo.config.t_forecasts.shape[0] - 1

    # Loop over all time indices
    while i_t < idx_max:

        # PREPARE
        # Check if obs data is available and if next analysis on mf and/or sv is performed
        algo.analyser.check_if_analysis_data(i_analysis)

        # Adapt forecast range R to eventual analysis and AR type
        # (the AR3 branch forecasts past the analysis time, hence the doubled range)
        if algo.analyser.sv_analysis() and algo.config.AR_type == "AR3":
            R = 2*ratio
        else:
            R = ratio
        # Clamp R so the forecast never runs past the last time index
        if i_t + R > idx_max:
            R = idx_max - i_t

        # FORECAST
        for i in range(R):
            # Increment i_t
            i_t += 1
            t = algo.config.t_forecasts[i_t]

            # Compute forecast
            if algo.config.AR_type == "AR3":
                forecast_states[:, i_t], Z_AR3 = algo.forecaster.parallel_forecast_step(computed_states[:,i_t-1], Z_AR3, pseed, i_t)
            else:
                forecast_states[:, i_t] = algo.forecaster.parallel_forecast_step(computed_states[:,i_t-1], pseed, i_t)

            # Update the computed core_state array with the forecast result
            computed_states[:, i_t] = forecast_states[:, i_t]

        # ANALYSIS
        if algo.analyser.sv_analysis() and algo.config.AR_type == "AR3":
            # Set i_t back from t_a+ to t_a
            i_t = i_t - ratio
            t = algo.config.t_forecasts[i_t]

        # If at least mf or sv analysis
        if algo.analyser.mf_analysis() or algo.analyser.sv_analysis():
            logging.info("Starting analysis #{} at time {}...".format(i_analysis+1, t))

            # Compute analysis
            if algo.config.AR_type == "AR3":
                # Gather computed_states to get all realisations for analysis
                # (AR3 needs the window [t_a - dt_a, t_a + dt_a] around the analysis time)
                gather_computed_states = gather_states(computed_states[:, i_t-ratio:i_t+ratio+1], attributed_models, comm, rank)
                analysed_states[:, i_analysis], Z_AR3 = algo.analyser.analysis_step(gather_computed_states, Z_AR3)
            else:
                # Gather computed_states to get all realisations for analysis
                gather_computed_states = gather_states(computed_states[:, i_t], attributed_models, comm, rank)
                analysed_states[:, i_analysis] = algo.analyser.analysis_step(gather_computed_states)

            if first_process():
                print('A', t, cs.coef_print(analysed_states[0, i_analysis], 0))

            # Update the computed core_state array with the analysis result (overwrites forecast result)
            computed_states[:, i_t] = analysed_states[:, i_analysis]

            # Update the misfits
            for measure in misfits.measures:
                misfits[measure][:, i_analysis] = algo.get_current_misfits(measure)

            logging.info("Analysis #{} finished !".format(i_analysis+1))

        if i_analysis < algo.config.nb_analyses:
            if abs(t - algo.config.t_analyses[i_analysis]) < algo.config.dt_f/2: # if analysis time
                # Update the analysed core_state array with the computed core_state
                analysed_states[:, i_analysis] = computed_states[:, i_t]
                # Increment i_analysis
                i_analysis += 1

    # END OF ALGO TIME LOOP

    # START SHEAR COMPUTATION
    if compute_shear():

        logging.info("Initializing shear inversion")

        # Init the Class ShearInversion
        shear = ShearInversion(algo.config)

        # Precompute operators
        shear.precompute_operators()

        # Build prior
        (
            Pee_y1_tilda_with_glasso,
            Pee_y2_tilda_with_glasso,
            Pss_inv_with_glasso
        ) = shear.build_prior()

        # - Sign Toroidal convention for shear
        U_state = np.copy(analysed_states.U)
        U_state[:, :, :algo.config.Nu] = -U_state[:, :, :algo.config.Nu]

        # Loop over analyses (note: i_t is reused here as the analysis index)
        for i_t, t in enumerate(algo.config.out_analysis):

            logging.info("Starting shear computation #{} at time {}...".format(i_t, t))

            store_shear = np.zeros((analysed_states.S[:, 0,:].shape))

            # Compute shear and store every realisation
            for i_idx, i_real in enumerate(attributed_models):
                # Time derivative of U: one-sided difference at the first and
                # last analysis, centred difference everywhere in between
                if i_t == 0:
                    DU = (U_state[i_idx, i_t+1] - U_state[i_idx, i_t]) / (algo.config.dt_a)
                elif i_t == algo.config.out_analysis.shape[0] - 1:
                    DU = (U_state[i_idx, i_t] - U_state[i_idx, i_t-1]) / (algo.config.dt_a)
                else :
                    DU = (U_state[i_idx, i_t+1] - U_state[i_idx, i_t-1]) / (2*algo.config.dt_a)

                U = U_state[i_idx, i_t,:]
                MF = analysed_states.B[i_idx, i_t]
                SV = analysed_states.SV[i_idx, i_t]
                store_shear[i_idx,:] = shear.compute_shear(U, DU, MF, SV, Pee_y1_tilda_with_glasso, Pee_y2_tilda_with_glasso, Pss_inv_with_glasso)

            analysed_states.S[:, i_t] = store_shear

            if first_process():
                print('A', t, cs.coef_print(analysed_states[0, i_t], 0))

    # END SHEAR COMPUTATION

    if first_process():
        # Create FileHandlers for saving
        output_file = h5py.File(os.path.join(folder_name, computation_name)+'.hdf5', mode='w')
        algo.config.save_hdf5(output_file)
        # Save data if not done beforehand
        if algo.config.out_computed:
            computed_handler = writes.Hdf5GroupHandler(
                algo.config.t_forecasts, nb_models, computed_states, output_file, "computed", algo.config.out_format,
                exclude=["Z","S","dUdt","dERdt","d2Udt2","d2ERdt2"])
        if algo.config.out_forecast:
            forecast_handler = writes.Hdf5GroupHandler(
                algo.config.t_forecasts, nb_models, forecast_states, output_file, "forecast", algo.config.out_format,
                exclude=["Z","S","dUdt","dERdt","d2Udt2","d2ERdt2"])
        # NOTE(review): out_analysis is iterated in the shear loop above but used
        # as a truthiness flag here and below — confirm it is a plain list (a
        # multi-element numpy array would raise on truth testing).
        if algo.config.out_analysis:
            analysis_handler = writes.Hdf5GroupHandler(
                algo.config.t_analyses, nb_models, analysed_states, output_file, "analysed", algo.config.out_format,
                exclude=["Z"])
        if algo.config.out_misfits:
            misfits_handler = writes.Hdf5GroupHandler(
                algo.config.t_analyses, 1, misfits, output_file, "misfits", algo.config.out_format)
            misfits_handler.update_all(misfits, [0])
    else:
        computed_handler, forecast_handler, analysis_handler= None, None, None

    # All ranks funnel their states to rank 0, which writes them through the handlers
    if algo.config.out_computed:
        send_receive(computed_states, comm, rank, computed_handler, attributed_models)
    if algo.config.out_forecast:
        send_receive(forecast_states, comm, rank, forecast_handler, attributed_models)
    if algo.config.out_analysis:
        send_receive(analysed_states, comm, rank, analysis_handler, attributed_models)

    # Final saving tasks (closing files and writing readme)
    if first_process():
        # End time
        end_time = time.time()
        elapsed_time = end_time - begin_time
        logging.info("Elapsed time : {:.2f}".format(elapsed_time))

        # Close file handlers
        output_file.close()

        logging.info("Data are saved in {}".format(os.path.join(folder_name, computation_name)+'.hdf5'))

        # git_* variables were set in the first first_process() block above
        git_info = (git_branch, git_commit_mess, git_description)

        # Write Readme file
        writes.saveReadme(folder_name, algo.config, algo.name, nb_models, git_info, elapsed_time, nb_proc, seed)

        return computed_states, forecast_states, analysed_states
    else:
        return None, None, None