Commit 8a23838e authored by Florentie, Liesbeth's avatar Florentie, Liesbeth
Browse files

update serial least squares optimizer: array operations instead of loop for update of samples

parent d1905891
......@@ -14,11 +14,11 @@ program. If not, see <http://www.gnu.org/licenses/>."""
# obs.py
"""
Author : peters
Authors: peters, liesbeth
Revision History:
File created on 28 Jul 2010.
File updated on 16 April 2019.
"""
import os
import sys
......@@ -31,7 +31,7 @@ sys.path.append(os.getcwd())
sys.path.append('../../')
identifier = 'CarbonTracker CO2 mole fractions'
version = '0.0'
version = '2.0'
from da.carbondioxide.obspack_globalviewplus2 import ObsPackObservations, MoleFractionSample
import da.tools.io4 as io
......@@ -57,7 +57,10 @@ class ctsfObsPackObservations(ObsPackObservations):
self.obspack_dir = op_dir
self.obspack_id = op_id
self.obs_scaling = float(dacycle.dasystem['obs.scaling.factor'])
try:
self.obs_scaling = float(dacycle.dasystem['obs.scaling.factor'])
except KeyError:
self.obs_scaling = 1.0
self.datalist = []
......@@ -155,7 +158,7 @@ class ctsfObsPackObservations(ObsPackObservations):
def add_model_data_mismatch(self):
def add_model_data_mismatch(self, filename=None, advance=False):
"""
Get the model-data mismatch values for this cycle.
(1) Open a sites_weights file
......@@ -164,13 +167,6 @@ class ctsfObsPackObservations(ObsPackObservations):
(4) Take care of double sites, etc
"""
# if not os.path.exists(filename):
# msg = 'Could not find the required sites.rc input file (%s) ' % filename
# logging.error(msg)
# raise IOError, msg
# else:
# self.sites_file = filename
sites_weights = rc.read(self.sites_file)
self.rejection_threshold = int(sites_weights['obs.rejection.threshold'])
......@@ -223,7 +219,11 @@ class ctsfObsPackObservations(ObsPackObservations):
for obs in self.datalist:
identifier = obs.code
species, site, method, lab, datasetnr = identifier.split('_')
if identifier.endswith('_MERGED'):
species, site, method, lab, datasetnr, merged = identifier.split('_')
elif identifier.endswith( '_NRT'):
species, site, method, lab, datasetnr, nrt = identifier.split('_')
else: species, site, method, lab, datasetnr = identifier.split('_')
if site_info.has_key(identifier):
if site_info[identifier]['category'] != 'do-not-use' and obs.flag != 99:
......@@ -243,27 +243,29 @@ class ctsfObsPackObservations(ObsPackObservations):
logging.warning("Observation NOT found (%s, %d), please check sites.rc file (%s) !!!" % (identifier, obs.id, self.sites_file))
if site_move.has_key(identifier):
movelat, movelon = site_move[identifier]
obs.lat = obs.lat + movelat
obs.lon = obs.lon + movelon
logging.warning("Observation location for (%s, %d), is moved by %3.2f degrees latitude and %3.2f degrees longitude" % (identifier, obs.id, movelat, movelon))
if site_incalt.has_key(identifier):
incalt = site_incalt[identifier]
obs.height = obs.height + incalt
logging.warning("Observation location for (%s, %d), is moved by %3.2f meters in altitude" % (identifier, obs.id, incalt))
if advance == False:
if obs.flag > 80:
self.datalist.remove(obs)
logging.debug('Dropped observation from datalist, as it is not to be assimilated')
# Only keep obs in datalist that will be used in assimilation
self.datalist = obs_to_keep
logging.info('Removed unused observations, observation list now holds %s values' % len(self.datalist))
# Add site_info dictionary to the Observations object for future use
self.site_info = site_info
self.site_move = site_move
self.site_info = site_info
self.site_move = site_move
self.site_incalt = site_incalt
logging.debug("Added Model Data Mismatch to all samples ")
......
......@@ -14,22 +14,24 @@ program. If not, see <http://www.gnu.org/licenses/>."""
# optimizer.py
"""
Author : peters
Author : peters, liesbeth
Revision History:
File created on 28 Jul 2010.
File updated on 16 April 2019
"""
import os
import sys
import logging
import time
import numpy as np
sys.path.append(os.getcwd())
from da.carbondioxide.optimizer import CO2Optimizer
identifier = 'Ensemble Square Root Filter'
version = '0.0'
version = '2.0'
################### Begin Class CO2Optimizer ###################
......@@ -44,95 +46,17 @@ def update_hx_for_obs_n(m, n, obsn, Hx, nmembers, HX_prime, HPHR, alpha):
class ctsfOptimizer(CO2Optimizer):
def setup(self, dims, ntasks):
def setup(self, dims):
self.nlag = dims[0]
self.nmembers = dims[1]
self.nparams = dims[2]
self.nobs = dims[3]
self.ntasks = ntasks
self.create_matrices()
def serial_minimum_least_squares_serial(self):
    """Solve the minimum least squares (ensemble square root filter) problem
    by assimilating one observation at a time.

    For each accepted observation n this updates, in place:
      - self.x        : the state estimate (via the gain KG),
      - self.X_prime  : the ensemble deviations,
      - self.Hx / self.HX_prime : the simulated observations and their
        ensemble deviations, so that later observations see the effect of
        this update.
    Observations are skipped when flagged (self.flags != 0) and rejected
    (flag set to 2) when the residual exceeds the rejection threshold.
    Prior/posterior values of the observation cost function and the squared
    residuals are computed and logged for diagnostics.
    """
    # Prior value of the observation part of the cost function,
    # J = sum_n (y_n - Hx_n)^2 / R_n, plus the plain squared residual.
    J_prior, res_2_prior = 0, 0
    for n in range(self.nobs):
        J_prior += (self.obs[n]-self.Hx[n])**2/self.R[n]
        res_2_prior += (self.obs[n]-self.Hx[n])**2
    del n

    for n in range(self.nobs):
        # Screen for flagged observations (for instance site not found, or no sample written from model)
        if self.flags[n] != 0:
            logging.debug('Skipping observation (%s,%i) because of flag value %d' % (self.sitecode[n], self.obs_ids[n], self.flags[n]))
            continue

        # Screen for outliers greater than rejection_threshold times the model-data
        # mismatch std dev; only applied when the obs is allowed to be rejected.
        res = self.obs[n] - self.Hx[n]
        if self.may_reject[n]:
            threshold = self.rejection_threshold * np.sqrt(self.R[n])
            if np.abs(res) > threshold:
                logging.debug('Rejecting observation (%s,%i) because residual (%f) exceeds threshold (%f)' % (self.sitecode[n], self.obs_ids[n], res, threshold))
                self.flags[n] = 2  # mark as rejected
                continue

        logging.debug('Proceeding to assimilate observation %s, %i' % (self.sitecode[n], self.obs_ids[n]))

        # PHt  = cov(X', HX'_n): covariance between state deviations and the
        #        simulated-observation deviations for obs n.
        # HPHR = var(HX'_n) + R_n: innovation variance for this single obs.
        # KG   = PHt / HPHR: the (serial, scalar-obs) Kalman gain.
        PHt = 1. / (self.nmembers - 1) * np.dot(self.X_prime, self.HX_prime[n, :])
        self.HPHR[n] = 1. / (self.nmembers - 1) * (self.HX_prime[n, :] * self.HX_prime[n, :]).sum() + self.R[n]

        self.KG[:] = PHt / self.HPHR[n]

        # Optionally taper the gain (covariance localization); self.localize
        # is provided by the optimizer class — see its definition for details.
        if self.may_localize[n]:
            logging.debug('Trying to localize observation %s, %i' % (self.sitecode[n], self.obs_ids[n]))
            self.localize(n)
        else:
            logging.debug('Not allowed to localize observation %s, %i' % (self.sitecode[n], self.obs_ids[n]))

        # alpha is the ensemble square root filter deviation-update factor
        # (Whitaker & Hamill, 2002), 1 / (1 + sqrt(R/HPHR)).
        alpha = np.double(1.0) / (np.double(1.0) + np.sqrt((self.R[n]) / self.HPHR[n]))

        # Update the state mean with the gain times the residual, then scale
        # each ensemble member's deviation by the square-root-filter factor.
        self.x[:] = self.x + self.KG[:] * res

        for r in range(self.nmembers):
            self.X_prime[:, r] = self.X_prime[:, r] - alpha * self.KG[:] * (self.HX_prime[n, r])
        del r

        # Propagate the update of observation n into the simulated values of
        # all other observations.
        #WP !!!! Very important to first do all observations from n=1 through the end, and only then update 1,...,n. The current observation
        #WP should always be updated last because it features in the loop of the adjustments !!!!
        for m in range(n + 1, self.nobs):
            res = self.obs[n] - self.Hx[n]
            # fac = cov(HX'_n, HX'_m) / HPHR_n: how strongly obs m responds
            # to the update from obs n.
            fac = 1.0 / (self.nmembers - 1) * (self.HX_prime[n, :] * self.HX_prime[m, :]).sum() / self.HPHR[n]
            self.Hx[m] = self.Hx[m] + fac * res
            self.HX_prime[m, :] = self.HX_prime[m, :] - alpha * fac * self.HX_prime[n, :]
        # Second pass over 0..n; m == n comes last so Hx[n] and HX_prime[n]
        # (used in res and fac above) are only overwritten at the very end.
        for m in range(n + 1):
            res = self.obs[n] - self.Hx[n]
            fac = 1.0 / (self.nmembers - 1) * (self.HX_prime[n, :] * self.HX_prime[m, :]).sum() / self.HPHR[n]
            self.Hx[m] = self.Hx[m] + fac * res
            self.HX_prime[m, :] = self.HX_prime[m, :] - alpha * fac * self.HX_prime[n, :]
        del m
    del n

    # Posterior value of the cost function and squared residual, for
    # comparison against the prior diagnostics computed above.
    J_post, res_2_post = 0, 0
    for n in range(self.nobs):
        J_post += (self.obs[n]-self.Hx[n])**2/self.R[n]
        res_2_post += (self.obs[n]-self.Hx[n])**2
    del n

    logging.info('Observation part cost function: prior = %s, posterior = %s' % (J_prior, J_post))
    logging.info('Squared residual: prior = %s, posterior = %s' % (res_2_prior, res_2_post))
def serial_minimum_least_squares(self):
""" Make minimum least squares solution by looping over obs"""
from multiprocessing import Pool
from functools import partial
# calculate prior value cost function
J_prior, res_2_prior = 0, 0
for n in range(self.nobs):
......@@ -140,9 +64,7 @@ class ctsfOptimizer(CO2Optimizer):
res_2_prior += (self.obs[n]-self.Hx[n])**2
del n
# create pools for parallel updating of Hx and HX_prime
pool = Pool(processes=self.ntasks)
tic = time.clock()
for n in range(self.nobs):
# Screen for flagged observations (for instance site not found, or no sample written from model)
......@@ -180,21 +102,15 @@ class ctsfOptimizer(CO2Optimizer):
del r
# update samples to account for update of statevector based on observation n
#WP !!!! Very important to first do all observations from n=1 through the end, and only then update 1,...,n. The current observation
#WP should always be updated last because it features in the loop of the adjustments !!!
if n+1 < self.nobs:
void = np.array(pool.map(partial(update_hx_for_obs_n, n=n, obsn=self.obs[n], Hx=self.Hx, nmembers=self.nmembers, HX_prime=self.HX_prime, HPHR=self.HPHR[n], alpha=alpha),range(n+1,self.nobs)))
self.Hx[n+1:self.nobs] = void[:,1]
self.HX_prime[n+1:self.nobs,:] = [np.array(hxm) for hxm in void[:,0]]
if n > 0:
void = np.array(pool.map(partial(update_hx_for_obs_n, n=n, obsn=self.obs[n], Hx=self.Hx, nmembers=self.nmembers, HX_prime=self.HX_prime, HPHR=self.HPHR[n], alpha=alpha),range(n)))
self.Hx[:n] = void[:,1]
self.HX_prime[:n,:] = [np.array(hxm) for hxm in void[:,0]]
self.HX_prime[n,:], self.Hx[n] = update_hx_for_obs_n(n,n=n, obsn=self.obs[n], Hx=self.Hx, nmembers=self.nmembers, HX_prime=self.HX_prime, HPHR=self.HPHR[n], alpha=alpha)
del n, void
pool.close()
HXprime_n = self.HX_prime[n,:].copy()
res = self.obs[n] - self.Hx[n]
fac = 1.0 / (self.nmembers - 1) * np.sum(HXprime_n[np.newaxis,:] * self.HX_prime, axis=1) / self.HPHR[n]
self.Hx = self.Hx + fac*res
self.HX_prime = self.HX_prime - alpha* fac[:,np.newaxis]*HXprime_n
del n, HXprime_n
toc = time.clock()
logging.debug('Minimum least squares update finished in %s seconds' %(toc-tic))
# calculate posterior value cost function
J_post, res_2_post = 0, 0
......
......@@ -270,7 +270,6 @@ def invert(dacycle, statevector, optimizer):
int(dacycle['da.optimizer.nmembers']),
statevector.nparams,
statevector.nobs)
ntasks = int(dacycle['da.resources.ntasks'])
if not dacycle.dasystem.has_key('opt.algorithm'):
logging.info("There was no minimum least squares algorithm specified in the DA System rc file (key : opt.algorithm)")
......@@ -283,7 +282,7 @@ def invert(dacycle, statevector, optimizer):
logging.info("Using the bulk minimum least squares algorithm to solve ENKF equations")
optimizer.set_algorithm('Bulk')
optimizer.setup(dims, ntasks)
optimizer.setup(dims)
optimizer.state_to_matrix(statevector)
diagnostics_file = os.path.join(dacycle['dir.output'], 'optimizer.%s.nc' % dacycle['time.start'].strftime('%Y%m%d'))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment