Coverage for pygeodyn/run.py: 77% (190 statements)
coverage.py v7.2.7, created at 2023-12-22 13:43 +0000

import os

# to improve performance in parallel
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
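# (these variables are read when MKL/OpenMP initialise their thread pools,
# which is why they are set before numpy is imported)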

import time
import h5py
import logging
import subprocess
import numpy as np
from mpi4py import MPI
from .inout import writes
from . import corestates as cs
from .augkf.algo import create_augkf_algo
from .inout.config import ComputationConfig
from .shear.shear_algo import ShearInversion

def send_receive(corestate, comm, rank, handler, attributed_models):
    """
    Send the corestate of each rank to rank 0 and write it to file through
    the hdf5 handler, without comm.Gather that had issues with large matrices

    :param corestate: CoreState to be sent
    :type corestate: cs.CoreState
    :param comm: MPI communicator
    :type comm: MPI.Comm
    :param rank: process rank
    :type rank: int
    :param handler: hdf5 handler
    :param attributed_models: models handled by the rank process
    :type attributed_models: 1D numpy array
    """
    if rank == 0:
        handler.update_all(corestate, attributed_models)
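    # each rank i > 0 sends its chunk in turn and rank 0 writes it to the hdf5
    # file as it arrives, so the full ensemble never travels in one message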
    for i in range(1, comm.Get_size()):
        if rank == i:
            comm.send(corestate, dest=0, tag=0)
            comm.send(attributed_models, dest=0, tag=1)
        if rank == 0:
            corestate_recv = comm.recv(source=i, tag=0)
            attributed_models_recv = comm.recv(source=i, tag=1)
            handler.update_all(corestate_recv, attributed_models_recv)
        comm.Barrier()

def gather_states(corestate, attributed_models, comm, rank, do_bcast=True):
    """
    Gather corestate to rank 0 (and broadcast it to all ranks if do_bcast=True)

    :param corestate: CoreState to be gathered
    :type corestate: cs.CoreState
    :param attributed_models: models handled by the rank process
    :type attributed_models: 1D numpy array
    :param comm: MPI communicator
    :type comm: MPI.Comm
    :param rank: process rank
    :type rank: int
    :param do_bcast: controls whether the gathered corestate is broadcast to all processes
    :type do_bcast: bool
    """
    # synchronize all processes
    comm.Barrier()
    # gather from all cores to rank 0
    attributed_models_gather = comm.gather(attributed_models, root=0)
    # in rank 0
    if rank == 0:
        # concatenate the contributions from all cores
        attributed_models_gather_concat = np.array(attributed_models_gather[0])
        for i in range(1, len(attributed_models_gather)):
            attributed_models_gather_concat = np.concatenate((attributed_models_gather_concat, attributed_models_gather[i]), axis=0)
        # get sorting indices
        attributed_models_gather_sorted = np.argsort(attributed_models_gather_concat)
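        # e.g. 4 models on 2 processes gather as model indices [0, 2, 1, 3];
        # argsort gives [0, 2, 1, 3], and indexing the gathered rows with it
        # restores model order [0, 1, 2, 3]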
    # for each measure in corestate
    for meas in corestate._measures:
        # synchronize all processes
        comm.Barrier()
        # gather from all cores to rank 0
        X_gather = comm.gather(corestate._measures[meas], root=0)
        # in rank 0
        if rank == 0:
            X_gather_concat = np.array(X_gather[0])
            for i in range(1, len(X_gather)):
                X_gather_concat = np.concatenate((X_gather_concat, X_gather[i]), axis=0)
            # order X_gather_concat using the sorting indices
            X_gather_concat = X_gather_concat[attributed_models_gather_sorted]
        else:
            X_gather_concat = None
        if do_bcast:
            # synchronize all processes
            comm.Barrier()
            # broadcast the state from rank 0 to all ranks
            corestate._measures[meas] = comm.bcast(X_gather_concat, root=0)
        else:
            # synchronize all processes
            comm.Barrier()
            corestate._measures[meas] = X_gather_concat
    return corestate

def choose_algorithm(algo_name, config_file, nb_models, global_seed, attributed_models, do_shear):
    if algo_name == 'augkf':
        logging.info('Using augmented state Kalman filter algorithm')
        return create_augkf_algo(ComputationConfig(do_shear, config_file), nb_models, global_seed, attributed_models)
    else:
        raise ValueError('Algorithm name {} was not recognised. Aborting computation.'.format(algo_name))

def algorithm(output_path, computation_name, config_file, nb_models, do_shear, seed=None, log_file=None, logging_level=None, algo_name='augkf'):
    """
    Runs the chosen algorithm: takes care of running the forecasts and the analyses, of logging, and of saving the data

    :param output_path: path where the data should be saved
    :type output_path: str
    :param computation_name: will create a folder of this name to store the output files
    :type computation_name: str
    :param config_file: path to the configuration file
    :type config_file: str
    :param nb_models: number of realisations/models to consider
    :type nb_models: int
    :param do_shear: control parameter of the shear computation
    :type do_shear: int
    :param seed: seed to use for the random number generation (drawn at random if None)
    :type seed: int
    :param log_file: name of the log file (logs are printed to the console if None)
    :type log_file: str
    :param logging_level: level of logging (https://docs.python.org/3.6/library/logging.html?highlight=logging#levels)
    :type logging_level: int
    :param algo_name: name of the algorithm; the only supported algorithm is 'augkf'
    :type algo_name: str
    :return: CoreStates containing all the results of the computation, the forecasts and the analyses
    :rtype: CoreState, CoreState, CoreState
    """
    # Init MPI variables
    if not MPI.Is_initialized():
        MPI.Init()
    comm = MPI.COMM_WORLD
    nb_proc = comm.Get_size()
    rank = comm.Get_rank()

    def first_process():
        return rank == 0

    def compute_shear():
        return do_shear == 1

    # set output folder
    folder_name = os.path.join(output_path, computation_name)
    os.makedirs(folder_name, exist_ok=True)

    # Set log
    if logging_level is None:
        logging_level = logging.INFO
    log_format = '%(asctime)s - Process {} - %(levelname)s - %(message)s'.format(rank)
    if log_file is None:
        logging.basicConfig(format=log_format, level=logging_level)
        print("Logs will be displayed in the console.")
    else:
        logging.basicConfig(format=log_format, level=logging_level,
                            filename=os.path.join(folder_name, log_file))
        print("Logs will be saved in {}".format(os.path.join(folder_name, log_file)))

    # Start time
    begin_time = time.time()

    # INITIALISATION OF ALGO

    # Set seed
    process_seeds = None
    if first_process():
        if seed is None:
            seed = np.random.randint(0, 50000)
        logging.info("Global seed = {}".format(seed))
        # Get sha1 and human-readable info about the git commit and branch
        current_dir = os.path.dirname(os.path.realpath(__file__))
        git_branch = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"], cwd=current_dir).strip().decode("utf-8")
        git_commit_mess = subprocess.check_output(["git", "log", "-1", "--pretty=%B"], cwd=current_dir).strip().decode("utf-8")
        git_description = subprocess.check_output(["git", "describe", "--always"], cwd=current_dir).strip().decode("utf-8")
        logging.info("Git branch: " + git_branch)
        logging.info("Git last commit message: \"" + git_commit_mess + "\" and description: " + git_description)
        # Generate a seed for each MPI process from the global seed.
        # If the same seed was given to each process, the same random numbers
        # would appear on each process.
        process_seeds = np.random.RandomState(seed).randint(0, 50000, size=nb_proc)

    # Broadcast the global seed and the process seeds to all MPI processes
    seed = comm.bcast(seed, root=0)
    process_seeds = comm.bcast(process_seeds, root=0)
    pseed = process_seeds[rank]
    logging.info("Process {} seed = {}".format(rank, pseed))
    # Initialise the random state using the process seed
    process_rstate = np.random.RandomState(pseed)

    # Build the attribution of realisations if several processes in parallel
    attributed_models = np.array(list(range(rank, nb_models, nb_proc)))
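    # round-robin attribution, e.g. nb_models=10 on nb_proc=3 gives
    # rank 0 -> [0, 3, 6, 9], rank 1 -> [1, 4, 7], rank 2 -> [2, 5, 8]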
    logging.debug('Process {} will process realisations: {}'.format(rank, list(attributed_models)))

    # Set algo
    algo = choose_algorithm(algo_name, config_file, nb_models, pseed, attributed_models, do_shear)

    # Initialisation of the states is done in each process.
    # Each process has its own attributed models so that there is less transfer of arrays.
    computed_states, forecast_states, analysed_states, misfits, Z_AR3 = algo.init_corestates(random_state=process_rstate)

    if first_process():
        logging.info("Forecasts will be performed at the following times: {}"
                     .format(algo.config.t_forecasts))
        logging.info("Analyses will be performed at the following times: {}"
                     .format(algo.config.t_analyses))

    # LAUNCH THE ALGORITHM
    # INIT
    i_t = 0
    i_analysis = 0
    ratio = algo.config.dt_a_f_ratio
    idx_max = algo.config.t_forecasts.shape[0] - 1

    # Loop over all time indices
    while i_t < idx_max:

        # PREPARE
        # Check if obs data is available and if the next analysis on mf and/or sv is performed
        algo.analyser.check_if_analysis_data(i_analysis)

        # Adapt the forecast range R to an eventual analysis and to the AR type
        if algo.analyser.sv_analysis() and algo.config.AR_type == "AR3":
            R = 2*ratio
        else:
            R = ratio
        if i_t + R > idx_max:
            R = idx_max - i_t
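        # with AR3, the forecast runs ratio steps past the analysis time so
        # that states on both sides of t_a are available for the analysis
        # window below; i_t is then rewound from t_a+ to t_a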

        # FORECAST
        for i in range(R):
            # Increment i_t
            i_t += 1
            t = algo.config.t_forecasts[i_t]

            # Compute the forecast
            if algo.config.AR_type == "AR3":
                forecast_states[:, i_t], Z_AR3 = algo.forecaster.parallel_forecast_step(computed_states[:, i_t-1], Z_AR3, pseed, i_t)
            else:
                forecast_states[:, i_t] = algo.forecaster.parallel_forecast_step(computed_states[:, i_t-1], pseed, i_t)

            # Update the computed core_state array with the forecast result
            computed_states[:, i_t] = forecast_states[:, i_t]

        # ANALYSIS
        if algo.analyser.sv_analysis() and algo.config.AR_type == "AR3":
            # Set i_t back from t_a+ to t_a
            i_t = i_t - ratio
            t = algo.config.t_forecasts[i_t]

        # If at least an mf or an sv analysis is due
        if algo.analyser.mf_analysis() or algo.analyser.sv_analysis():
            logging.info("Starting analysis #{} at time {}...".format(i_analysis+1, t))

            # Compute the analysis
            if algo.config.AR_type == "AR3":
                # Gather computed_states to get all realisations for the analysis
                gather_computed_states = gather_states(computed_states[:, i_t-ratio:i_t+ratio+1], attributed_models, comm, rank)
                analysed_states[:, i_analysis], Z_AR3 = algo.analyser.analysis_step(gather_computed_states, Z_AR3)
            else:
                # Gather computed_states to get all realisations for the analysis
                gather_computed_states = gather_states(computed_states[:, i_t], attributed_models, comm, rank)
                analysed_states[:, i_analysis] = algo.analyser.analysis_step(gather_computed_states)

            if first_process():
                print('A', t, cs.coef_print(analysed_states[0, i_analysis], 0))

            # Update the computed core_state array with the analysis result (overwrites the forecast result)
            computed_states[:, i_t] = analysed_states[:, i_analysis]

            # Update the misfits
            for measure in misfits.measures:
                misfits[measure][:, i_analysis] = algo.get_current_misfits(measure)

            logging.info("Analysis #{} finished!".format(i_analysis+1))

        if i_analysis < algo.config.nb_analyses:
            if abs(t - algo.config.t_analyses[i_analysis]) < algo.config.dt_f/2:  # if at an analysis time
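                # comparing to half a forecast step keeps the match robust to
                # floating-point rounding of the time values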
                # Update the analysed core_state array with the computed core_state
                analysed_states[:, i_analysis] = computed_states[:, i_t]
                # Increment i_analysis
                i_analysis += 1

    # END OF ALGO TIME LOOP

    # START SHEAR COMPUTATION
    if compute_shear():

        logging.info("Initializing shear inversion")

        # Init the class ShearInversion
        shear = ShearInversion(algo.config)

        # Precompute operators
        shear.precompute_operators()

        # Build prior
        (
            Pee_y1_tilda_with_glasso,
            Pee_y2_tilda_with_glasso,
            Pss_inv_with_glasso
        ) = shear.build_prior()

        # - Sign toroidal convention for shear
        U_state = np.copy(analysed_states.U)
        U_state[:, :, :algo.config.Nu] = -U_state[:, :, :algo.config.Nu]

        # Loop over analyses
        for i_t, t in enumerate(algo.config.out_analysis):

            logging.info("Starting shear computation #{} at time {}...".format(i_t, t))

            store_shear = np.zeros(analysed_states.S[:, 0, :].shape)

            # Compute the shear and store every realisation
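            # DU approximates dU/dt: a one-sided difference at the first and
            # last analysis times, a centred difference in between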
            for i_idx, i_real in enumerate(attributed_models):
                if i_t == 0:
                    DU = (U_state[i_idx, i_t+1] - U_state[i_idx, i_t]) / algo.config.dt_a
                elif i_t == algo.config.out_analysis.shape[0] - 1:
                    DU = (U_state[i_idx, i_t] - U_state[i_idx, i_t-1]) / algo.config.dt_a
                else:
                    DU = (U_state[i_idx, i_t+1] - U_state[i_idx, i_t-1]) / (2*algo.config.dt_a)

                U = U_state[i_idx, i_t, :]
                MF = analysed_states.B[i_idx, i_t]
                SV = analysed_states.SV[i_idx, i_t]
                store_shear[i_idx, :] = shear.compute_shear(U, DU, MF, SV, Pee_y1_tilda_with_glasso, Pee_y2_tilda_with_glasso, Pss_inv_with_glasso)

            analysed_states.S[:, i_t] = store_shear

            if first_process():
                print('A', t, cs.coef_print(analysed_states[0, i_t], 0))

    # END SHEAR COMPUTATION

    if first_process():
        # Create FileHandlers for saving
        output_file = h5py.File(os.path.join(folder_name, computation_name) + '.hdf5', mode='w')
        algo.config.save_hdf5(output_file)
        # Save data if not done beforehand
        if algo.config.out_computed:
            computed_handler = writes.Hdf5GroupHandler(
                algo.config.t_forecasts, nb_models, computed_states, output_file, "computed", algo.config.out_format,
                exclude=["Z", "S", "dUdt", "dERdt", "d2Udt2", "d2ERdt2"])
        if algo.config.out_forecast:
            forecast_handler = writes.Hdf5GroupHandler(
                algo.config.t_forecasts, nb_models, forecast_states, output_file, "forecast", algo.config.out_format,
                exclude=["Z", "S", "dUdt", "dERdt", "d2Udt2", "d2ERdt2"])
        if algo.config.out_analysis:
            analysis_handler = writes.Hdf5GroupHandler(
                algo.config.t_analyses, nb_models, analysed_states, output_file, "analysed", algo.config.out_format,
                exclude=["Z"])
        if algo.config.out_misfits:
            misfits_handler = writes.Hdf5GroupHandler(
                algo.config.t_analyses, 1, misfits, output_file, "misfits", algo.config.out_format)
            misfits_handler.update_all(misfits, [0])
    else:
        computed_handler, forecast_handler, analysis_handler = None, None, None

    if algo.config.out_computed:
        send_receive(computed_states, comm, rank, computed_handler, attributed_models)
    if algo.config.out_forecast:
        send_receive(forecast_states, comm, rank, forecast_handler, attributed_models)
    if algo.config.out_analysis:
        send_receive(analysed_states, comm, rank, analysis_handler, attributed_models)

    # Final saving tasks (closing files and writing the readme)
    if first_process():
        # End time
        end_time = time.time()
        elapsed_time = end_time - begin_time
        logging.info("Elapsed time: {:.2f}".format(elapsed_time))

        # Close the file handlers
        output_file.close()

        logging.info("Data are saved in {}".format(os.path.join(folder_name, computation_name) + '.hdf5'))

        git_info = (git_branch, git_commit_mess, git_description)

        # Write Readme file
        writes.saveReadme(folder_name, algo.config, algo.name, nb_models, git_info, elapsed_time, nb_proc, seed)

        return computed_states, forecast_states, analysed_states
    else:
        return None, None, None
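
# Example usage (sketch; the paths and the config file below are hypothetical):
#   from pygeodyn.run import algorithm
#   algorithm('/tmp/pygeodyn_out', 'test_run', 'my_config.ini', nb_models=20, do_shear=0)
# run under MPI with e.g.: mpiexec -n 4 python my_run_script.py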