Commit d1905891 authored by Liesbeth Florentie

parallel updating of samples

parent 35f21401
@@ -205,12 +205,10 @@ class ctsfTM5ObsOperator(TM5ObservationOperator):
logging.debug('Resetting TM5 to perform restart')
# If neither one is true, simply take the istart value from the tm5.rc file that was read
self.modify_rc(new_items)
self.write_rc(self.rc_filename)
-        # Define the name of the file that will contain the modeled output of each observation
        # Define the name of the file that will contain the modeled output of each observation
self.simulated_file = os.path.join(self.outputdir, 'flask_output.%s.nc' % self.dacycle['time.sample.stamp'])
@@ -32,6 +32,7 @@ class FluxObsOperator(ObservationOperator):
self.ID = identifier # the identifier gives the model name
self.version = version # the model version used
self.output_filelist = []
logging.info('Observation Operator initialized: %s (%s)' % (self.ID, self.version))
@@ -40,8 +41,8 @@ class FluxObsOperator(ObservationOperator):
def setup(self,dacycle):
""" Perform all steps necessary to start the observation operator through a simple Run() call """
-        self.dacycle = dacycle
-        self.refdate = to_datetime(dacycle.dasystem['time.reference'])
+        self.dacycle = dacycle
+        self.refdate = to_datetime(dacycle.dasystem['time.reference'])
@@ -51,7 +52,7 @@ class FluxObsOperator(ObservationOperator):
self.prepare_run(postprocessing)
self.validate_input() # perform checks
# calculate and write fluxmap for ensemble
-        self.get_flux_samples(fluxmodel)
+        self.get_flux_samples(fluxmodel, postprocessing=postprocessing)
@@ -68,7 +69,7 @@ class FluxObsOperator(ObservationOperator):
for em in range(int(self.dacycle['da.optimizer.nmembers'])):
# get flux on grid timerange, only read anomalies first time
-            nee = fluxmodel.calc_flux(self.timevec, member=em, updateano=updateano, write=False)
+            nee = fluxmodel.calc_flux(self.timevec, member=em, updateano=updateano, write=False, postprocessing=postprocessing)
updateano = False
# extract flux samples for member
@@ -178,7 +179,7 @@ class FluxObsOperator(ObservationOperator):
# Define the name of the file that will contain the modeled output of each observation
if postprocessing:
-            self.simulated_file = os.path.join(self.outputdir, 'samples_simulated.%s.nc' % self.dacycle['time.sample.stamp'])
+            self.simulated_file = os.path.join(self.dacycle['dir.output'], 'samples_simulated.%s.nc' % self.dacycle['time.sample.stamp'])
else:
self.simulated_file = self.dacycle['dir.da_run'] + ('/input/samples_simulated.%s.nc' % self.dacycle['time.sample.stamp'])
self.forecast_nmembers = int(self.dacycle['da.optimizer.nmembers'])
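For orientation, the two branches above resolve to file names like the following (the `time.sample.stamp` value is illustrative):

```
postprocessing=True  -> <dir.output>/samples_simulated.20100101_20100108.nc
postprocessing=False -> <dir.da_run>/input/samples_simulated.20100101_20100108.nc
```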
@@ -222,7 +222,7 @@ class NEEobservations(Observations):
-    def add_model_data_mismatch(self, filename): #filename = dacycle.dasystem['obs.sites.rc']
+    def add_model_data_mismatch(self):#, filename): #filename = dacycle.dasystem['obs.sites.rc']
"""
Get the model-data mismatch values for this cycle.
@@ -231,7 +231,7 @@ class ctsfObsPackObservations(ObsPackObservations):
nr_obs_per_day = 1
else:
nr_obs_per_day = len([c.code for c in self.datalist if c.code == obs.code and c.xdate.day == obs.xdate.day and c.flag == 0])
# logging.debug("Observation found (%s, %d), mdm category is: %0.2f, scaled with number of observations per day (%i), final mdm applied is: %0.2f." % (identifier, obs.id, site_info[identifier]['error'],nr_obs_per_day,site_info[identifier]['error']*sqrt(nr_obs_per_day)))
logging.debug("Observation found (%s, %d), mdm category is: %0.2f, scaled with number of observations per day (%i), final mdm applied is: %0.2f." % (identifier, obs.id, site_info[identifier]['error'],nr_obs_per_day,site_info[identifier]['error']*sqrt(nr_obs_per_day)))
obs.mdm = site_info[identifier]['error'] * sqrt(nr_obs_per_day) * self.global_R_scaling * self.obs_scaling
obs.may_localize = site_info[identifier]['may_localize']
obs.may_reject = site_info[identifier]['may_reject']
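The `sqrt(nr_obs_per_day)` inflation makes a day's worth of repeated observations jointly carry the weight of a single one: n observations with variance n·e² contribute n/(n·e²) = 1/e² of information, the same as one observation with variance e². A worked example (values are illustrative, not from the commit):

```python
from math import sqrt

error, nr_obs_per_day = 0.5, 4        # hypothetical site error category and same-day count
mdm = error * sqrt(nr_obs_per_day)    # 0.5 * 2 = 1.0
# four observations at mdm 1.0 jointly carry the same weight as one at mdm 0.5
```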
@@ -248,14 +248,14 @@ class ctsfObsPackObservations(ObsPackObservations):
obs.lat = obs.lat + movelat
obs.lon = obs.lon + movelon
# logging.warning("Observation location for (%s, %d), is moved by %3.2f degrees latitude and %3.2f degrees longitude" % (identifier, obs.id, movelat, movelon))
logging.warning("Observation location for (%s, %d), is moved by %3.2f degrees latitude and %3.2f degrees longitude" % (identifier, obs.id, movelat, movelon))
if site_incalt.has_key(identifier):
incalt = site_incalt[identifier]
obs.height = obs.height + incalt
# logging.warning("Observation location for (%s, %d), is moved by %3.2f meters in altitude" % (identifier, obs.id, incalt))
logging.warning("Observation location for (%s, %d), is moved by %3.2f meters in altitude" % (identifier, obs.id, incalt))
# Only keep obs in datalist that will be used in assimilation
self.datalist = obs_to_keep
@@ -33,9 +33,26 @@ version = '0.0'
################### Begin Class CO2Optimizer ###################
def update_hx_for_obs_n(m, n, obsn, Hx, nmembers, HX_prime, HPHR, alpha):
res = obsn - Hx[n]
fac = 1.0 / (nmembers - 1) * (HX_prime[n,:] * HX_prime[m, :]).sum() / HPHR
Hx_m = Hx[m] + fac * res
HX_prime_m = HX_prime[m, :] - alpha * fac * HX_prime[n, :]
return HX_prime_m, Hx_m
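For readers scanning the diff: `update_hx_for_obs_n` is the per-row body of the standard serial EnSRF update of the simulated observations, factored out so it can be mapped over the row index `m` in parallel. A minimal serial equivalent, a sketch using the commit's variable names (`obsn` is the assimilated value of observation `n`) rather than the committed code itself:

```python
import numpy as np

def update_rows_serial(n, obsn, Hx, HX_prime, HPHR, alpha, nmembers):
    """Serial equivalent of mapping update_hx_for_obs_n over m = n+1 .. nobs-1."""
    for m in range(n + 1, Hx.shape[0]):
        # ensemble covariance between simulated observations n and m, over HPH^T + R
        fac = (HX_prime[n, :] * HX_prime[m, :]).sum() / ((nmembers - 1) * HPHR)
        Hx[m] += fac * (obsn - Hx[n])                   # update of the ensemble mean
        HX_prime[m, :] -= alpha * fac * HX_prime[n, :]  # update of the deviations
    return Hx, HX_prime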
class ctsfOptimizer(CO2Optimizer):
-    def serial_minimum_least_squares(self):
def setup(self, dims, ntasks):
self.nlag = dims[0]
self.nmembers = dims[1]
self.nparams = dims[2]
self.nobs = dims[3]
self.ntasks = ntasks
self.create_matrices()
def serial_minimum_least_squares_serial(self):
""" Make minimum least squares solution by looping over obs"""
# calculate prior value cost function
@@ -107,3 +124,84 @@ class ctsfOptimizer(CO2Optimizer):
logging.info('Observation part cost function: prior = %s, posterior = %s' % (J_prior, J_post))
logging.info('Squared residual: prior = %s, posterior = %s' % (res_2_prior, res_2_post))
def serial_minimum_least_squares(self):
""" Make minimum least squares solution by looping over obs"""
from multiprocessing import Pool
from functools import partial
# calculate prior value cost function
J_prior, res_2_prior = 0, 0
for n in range(self.nobs):
J_prior += (self.obs[n]-self.Hx[n])**2/self.R[n]
res_2_prior += (self.obs[n]-self.Hx[n])**2
del n
# create pool of worker processes for parallel updating of Hx and HX_prime
pool = Pool(processes=self.ntasks)
for n in range(self.nobs):
# Screen for flagged observations (for instance site not found, or no sample written from model)
if self.flags[n] != 0:
logging.debug('Skipping observation (%s,%i) because of flag value %d' % (self.sitecode[n], self.obs_ids[n], self.flags[n]))
continue
# Screen for outliers greater than 3x model-data mismatch, only apply if obs may be rejected
res = self.obs[n] - self.Hx[n]
if self.may_reject[n]:
threshold = self.rejection_threshold * np.sqrt(self.R[n])
if np.abs(res) > threshold:
logging.debug('Rejecting observation (%s,%i) because residual (%f) exceeds threshold (%f)' % (self.sitecode[n], self.obs_ids[n], res, threshold))
self.flags[n] = 2
continue
logging.debug('Proceeding to assimilate observation %s, %i' % (self.sitecode[n], self.obs_ids[n]))
# LF: might be necessary to rewrite PHt calculation to loop with summation if statevector becomes too big
PHt = 1. / (self.nmembers - 1) * np.dot(self.X_prime, self.HX_prime[n, :])
self.HPHR[n] = 1. / (self.nmembers - 1) * (self.HX_prime[n, :] * self.HX_prime[n, :]).sum() + self.R[n]
self.KG[:] = PHt / self.HPHR[n]
if self.may_localize[n]:
logging.debug('Trying to localize observation %s, %i' % (self.sitecode[n], self.obs_ids[n]))
self.localize(n)
else:
logging.debug('Not allowed to localize observation %s, %i' % (self.sitecode[n], self.obs_ids[n]))
alpha = np.double(1.0) / (np.double(1.0) + np.sqrt((self.R[n]) / self.HPHR[n]))
self.x[:] = self.x + self.KG[:] * res
for r in range(self.nmembers):
self.X_prime[:, r] = self.X_prime[:, r] - alpha * self.KG[:] * (self.HX_prime[n, r])
del r
# update samples to account for update of statevector based on observation n
#WP !!!! Very important to first update all observations from n+1 through the end, and only then 0,...,n-1. The current observation
#WP should always be updated last because it features in the adjustment of all the others !!!
if n+1 < self.nobs:
void = np.array(pool.map(partial(update_hx_for_obs_n, n=n, obsn=self.obs[n], Hx=self.Hx, nmembers=self.nmembers, HX_prime=self.HX_prime, HPHR=self.HPHR[n], alpha=alpha),range(n+1,self.nobs)))
self.Hx[n+1:self.nobs] = void[:,1]
self.HX_prime[n+1:self.nobs,:] = [np.array(hxm) for hxm in void[:,0]]
if n > 0:
void = np.array(pool.map(partial(update_hx_for_obs_n, n=n, obsn=self.obs[n], Hx=self.Hx, nmembers=self.nmembers, HX_prime=self.HX_prime, HPHR=self.HPHR[n], alpha=alpha),range(n)))
self.Hx[:n] = void[:,1]
self.HX_prime[:n,:] = [np.array(hxm) for hxm in void[:,0]]
self.HX_prime[n,:], self.Hx[n] = update_hx_for_obs_n(n,n=n, obsn=self.obs[n], Hx=self.Hx, nmembers=self.nmembers, HX_prime=self.HX_prime, HPHR=self.HPHR[n], alpha=alpha)
del n, void
pool.close()
# calculate posterior value cost function
J_post, res_2_post = 0, 0
for n in range(self.nobs):
J_post += (self.obs[n]-self.Hx[n])**2/self.R[n]
res_2_post += (self.obs[n]-self.Hx[n])**2
del n
logging.info('Observation part cost function: prior = %s, posterior = %s' % (J_prior, J_post))
logging.info('Squared residual: prior = %s, posterior = %s' % (res_2_prior, res_2_post))
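Incidentally, the prior/posterior cost-function loops above have a compact vectorized equivalent. A sketch, assuming `obs`, `Hx`, and `R` are 1-D numpy arrays of length `nobs` (as `self.obs`, `self.Hx`, `self.R` are here):

```python
import numpy as np

def obs_cost(obs, Hx, R):
    """Vectorized equivalent of the prior/posterior cost-function loops."""
    resid = obs - Hx
    return np.sum(resid ** 2 / R), np.sum(resid ** 2)   # (J, squared residual)
```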
@@ -266,10 +266,11 @@ def sample_step(dacycle, samples, statevector, fluxmodel, obsoperator, lag, adva
def invert(dacycle, statevector, optimizer):
""" Perform the inverse calculation """
logging.info(header + "starting invert" + footer)
-    dims = (int(dacycle['time.nlag']),
+    dims = (int(dacycle['time.nlag']),
int(dacycle['da.optimizer.nmembers']),
statevector.nparams,
statevector.nobs)
ntasks = int(dacycle['da.resources.ntasks'])
if not dacycle.dasystem.has_key('opt.algorithm'):
logging.info("There was no minimum least squares algorithm specified in the DA System rc file (key : opt.algorithm)")
@@ -282,7 +283,7 @@ def invert(dacycle, statevector, optimizer):
logging.info("Using the bulk minimum least squares algorithm to solve ENKF equations")
optimizer.set_algorithm('Bulk')
-    optimizer.setup(dims)
+    optimizer.setup(dims, ntasks)
optimizer.state_to_matrix(statevector)
diagnostics_file = os.path.join(dacycle['dir.output'], 'optimizer.%s.nc' % dacycle['time.start'].strftime('%Y%m%d'))
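The pool size now has to be present in the cycle configuration, since `invert()` reads `dacycle['da.resources.ntasks']`. A hypothetical rc excerpt (key names taken from the code above, values illustrative):

```
! hypothetical excerpt of the DA rc file read into dacycle
da.resources.ntasks     : 8
da.optimizer.nmembers   : 150
```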