Commit 35f21401 authored by Florentie, Liesbeth's avatar Florentie, Liesbeth
Browse files

change in reading of observations: sites.rc file read already in...

change in reading of observations: sites.rc file read already in add_observations + cost function in optimizer calculated as sum instead of dot product
parent 2077b844
......@@ -61,6 +61,13 @@ class ctsfObsPackObservations(ObsPackObservations):
self.datalist = []
if not os.path.exists(dacycle.dasystem['obs.sites.rc']):
msg = 'Could not find the required sites.rc input file (%s) ' % dacycle.dasystem['obs.sites.rc']
logging.error(msg)
raise IOError, msg
else:
self.sites_file = dacycle.dasystem['obs.sites.rc']
def add_observations(self):
......@@ -85,10 +92,25 @@ class ctsfObsPackObservations(ObsPackObservations):
ncfilelist += [ncfile]
del line
# Step 2: Read valid sites from site_weights file
valid_sites = []
sites_weights = rc.read(self.sites_file)
for key, value in sites_weights.iteritems():
if 'co2_' in key:
sitename, sitecategory = key, value
sitename = sitename.strip()
sitecategory = sitecategory.split()[0].strip().lower()
if sitecategory != 'do-not-use':
valid_sites.append(sitename)
del key, value
# Step 3: Read observations from valid sites only
logging.debug("ObsPack dataset info read, proceeding with %d netcdf files" % len(ncfilelist))
for ncfile in ncfilelist:
if ncfile not in valid_sites: continue
infile = os.path.join(self.obspack_dir, 'data', 'nc', ncfile + '.nc')
ncf = io.ct_read(infile, 'read')
idates = ncf.get_variable('time_components')
......@@ -133,7 +155,7 @@ class ctsfObsPackObservations(ObsPackObservations):
def add_model_data_mismatch(self, filename):
def add_model_data_mismatch(self):
"""
Get the model-data mismatch values for this cycle.
(1) Open a sites_weights file
......@@ -142,12 +164,12 @@ class ctsfObsPackObservations(ObsPackObservations):
(4) Take care of double sites, etc
"""
if not os.path.exists(filename):
msg = 'Could not find the required sites.rc input file (%s) ' % filename
logging.error(msg)
raise IOError, msg
else:
self.sites_file = filename
# if not os.path.exists(filename):
# msg = 'Could not find the required sites.rc input file (%s) ' % filename
# logging.error(msg)
# raise IOError, msg
# else:
# self.sites_file = filename
sites_weights = rc.read(self.sites_file)
......@@ -178,6 +200,7 @@ class ctsfObsPackObservations(ObsPackObservations):
sitename, sitecategory = key, value
sitename = sitename.strip()
sitecategory = sitecategory.split()[0].strip().lower()
if sitecategory == 'do-not-use': continue
site_info[sitename] = site_categories[sitecategory]
if 'site.move' in key:
identifier, latmove, lonmove = value.split(';')
......@@ -208,7 +231,7 @@ class ctsfObsPackObservations(ObsPackObservations):
nr_obs_per_day = 1
else:
nr_obs_per_day = len([c.code for c in self.datalist if c.code == obs.code and c.xdate.day == obs.xdate.day and c.flag == 0])
logging.debug("Observation found (%s, %d), mdm category is: %0.2f, scaled with number of observations per day (%i), final mdm applied is: %0.2f." % (identifier, obs.id, site_info[identifier]['error'],nr_obs_per_day,site_info[identifier]['error']*sqrt(nr_obs_per_day)))
# logging.debug("Observation found (%s, %d), mdm category is: %0.2f, scaled with number of observations per day (%i), final mdm applied is: %0.2f." % (identifier, obs.id, site_info[identifier]['error'],nr_obs_per_day,site_info[identifier]['error']*sqrt(nr_obs_per_day)))
obs.mdm = site_info[identifier]['error'] * sqrt(nr_obs_per_day) * self.global_R_scaling * self.obs_scaling
obs.may_localize = site_info[identifier]['may_localize']
obs.may_reject = site_info[identifier]['may_reject']
......@@ -225,14 +248,14 @@ class ctsfObsPackObservations(ObsPackObservations):
obs.lat = obs.lat + movelat
obs.lon = obs.lon + movelon
logging.warning("Observation location for (%s, %d), is moved by %3.2f degrees latitude and %3.2f degrees longitude" % (identifier, obs.id, movelat, movelon))
# logging.warning("Observation location for (%s, %d), is moved by %3.2f degrees latitude and %3.2f degrees longitude" % (identifier, obs.id, movelat, movelon))
if site_incalt.has_key(identifier):
incalt = site_incalt[identifier]
obs.height = obs.height + incalt
logging.warning("Observation location for (%s, %d), is moved by %3.2f meters in altitude" % (identifier, obs.id, incalt))
# logging.warning("Observation location for (%s, %d), is moved by %3.2f meters in altitude" % (identifier, obs.id, incalt))
# Only keep obs in datalist that will be used in assimilation
self.datalist = obs_to_keep
......
......@@ -38,21 +38,22 @@ class ctsfOptimizer(CO2Optimizer):
def serial_minimum_least_squares(self):
""" Make minimum least squares solution by looping over obs"""
J_prior = np.dot(np.dot(np.transpose(self.obs - self.Hx),np.diag(1/self.R)),(self.obs - self.Hx))
res_2_prior = np.dot(np.transpose(self.obs - self.Hx),(self.obs - self.Hx))
# calculate prior value cost function
J_prior, res_2_prior = 0, 0
for n in range(self.nobs):
J_prior += (self.obs[n]-self.Hx[n])**2/self.R[n]
res_2_prior += (self.obs[n]-self.Hx[n])**2
del n
for n in range(self.nobs):
# Screen for flagged observations (for instance site not found, or no sample written from model)
if self.flags[n] != 0:
logging.debug('Skipping observation (%s,%i) because of flag value %d' % (self.sitecode[n], self.obs_ids[n], self.flags[n]))
continue
# Screen for outliers greater than 3x model-data mismatch, only apply if obs may be rejected
res = self.obs[n] - self.Hx[n]
if self.may_reject[n]:
threshold = self.rejection_threshold * np.sqrt(self.R[n])
if np.abs(res) > threshold:
......@@ -62,10 +63,9 @@ class ctsfOptimizer(CO2Optimizer):
logging.debug('Proceeding to assimilate observation %s, %i' % (self.sitecode[n], self.obs_ids[n]))
PHt = 1. / (self.nmembers - 1) * np.dot(self.X_prime, self.HX_prime[n, :])
self.HPHR[n] = 1. / (self.nmembers - 1) * (self.HX_prime[n, :] * self.HX_prime[n, :]).sum() + self.R[n]
self.KG[:] = PHt / self.HPHR[n]
PHt = 1. / (self.nmembers - 1) * np.dot(self.X_prime, self.HX_prime[n, :])
self.HPHR[n] = 1. / (self.nmembers - 1) * (self.HX_prime[n, :] * self.HX_prime[n, :]).sum() + self.R[n]
self.KG[:] = PHt / self.HPHR[n]
if self.may_localize[n]:
logging.debug('Trying to localize observation %s, %i' % (self.sitecode[n], self.obs_ids[n]))
......@@ -73,8 +73,7 @@ class ctsfOptimizer(CO2Optimizer):
else:
logging.debug('Not allowed to localize observation %s, %i' % (self.sitecode[n], self.obs_ids[n]))
alpha = np.double(1.0) / (np.double(1.0) + np.sqrt((self.R[n]) / self.HPHR[n]))
alpha = np.double(1.0) / (np.double(1.0) + np.sqrt((self.R[n]) / self.HPHR[n]))
self.x[:] = self.x + self.KG[:] * res
for r in range(self.nmembers):
......@@ -99,7 +98,12 @@ class ctsfOptimizer(CO2Optimizer):
del n
J_post = np.dot(np.dot(np.transpose(self.obs - self.Hx),np.diag(1/self.R)),(self.obs - self.Hx))
res_2_post = np.dot(np.transpose(self.obs - self.Hx),(self.obs - self.Hx))
# calculate posterior value cost function
J_post, res_2_post = 0, 0
for n in range(self.nobs):
J_post += (self.obs[n]-self.Hx[n])**2/self.R[n]
res_2_post += (self.obs[n]-self.Hx[n])**2
del n
logging.info('Observation part cost function: prior = %s, posterior = %s' % (J_prior, J_post))
logging.info('Squared residual: prior = %s, posterior = %s' % (res_2_prior, res_2_post))
......@@ -80,7 +80,7 @@ def check_setup(dacycle, platform, dasystem, samples, statevector, fluxmodel, ob
# Create observation vector for simulation interval
samples.setup(dacycle)
samples.add_observations()
samples.add_model_data_mismatch(dacycle.dasystem['obs.sites.rc'])
samples.add_model_data_mismatch()#dacycle.dasystem['obs.sites.rc'])
sampling_coords_file = os.path.join(dacycle['dir.input'], 'sample_coordinates_%s.nc' % dacycle['time.sample.stamp'])
samples.write_sample_coords(sampling_coords_file)
......@@ -222,7 +222,7 @@ def sample_step(dacycle, samples, statevector, fluxmodel, obsoperator, lag, adva
samples.add_observations()
# Add model-data mismatch to all samples, this *might* use output from the ensemble in the future??
samples.add_model_data_mismatch(dacycle.dasystem['obs.sites.rc'])
samples.add_model_data_mismatch()#dacycle.dasystem['obs.sites.rc'])
sampling_coords_file = os.path.join(dacycle['dir.input'], 'sample_coordinates_%s.nc' % dacycle['time.sample.stamp'])
samples.write_sample_coords(sampling_coords_file)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment