diff --git a/da/optimizers/optimizer_wrfchem.py b/da/optimizers/optimizer_wrfchem.py
new file mode 100644
index 0000000000000000000000000000000000000000..67af25302000f56c45226b006c055d890872d8ce
--- /dev/null
+++ b/da/optimizers/optimizer_wrfchem.py
@@ -0,0 +1,295 @@
+#!/usr/bin/env python
+# optimizer_wrfchem.py
+"""CarbonTracker Data Assimilation Shell (CTDAS) Copyright (C) 2017 Wouter Peters.
+Users are recommended to contact the developers (wouter.peters@wur.nl) to receive
+updates of the code. See also: http://www.carbontracker.eu.
+
+This program is free software: you can redistribute it and/or modify it under the
+terms of the GNU General Public License as published by the Free Software Foundation,
+version 3. This program is distributed in the hope that it will be useful, but
+WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
+
+You should have received a copy of the GNU General Public License along with this
+program. If not, see <http://www.gnu.org/licenses/>.
+
+Author : peters
+
+Revision History:
+File created on 28 Jul 2010.
+"""
+
+import os
+import sys
+import logging
+import numpy as np
+import da.tools.io4 as io  # used by the (currently commented-out) coefficient-file reader in set_localization
+sys.path.append(os.getcwd())
+from da.optimizers.optimizer_baseclass import Optimizer
+
+
+identifier = 'Ensemble Square Root Filter'
+version = '0.0'
+
+################### Begin Class CO2Optimizer ###################
+
+class CO2Optimizer(Optimizer):
+    """
+    This creates an instance of a CarbonTracker optimization object. It derives from the Optimizer base class.
+    Additionally, this CO2Optimizer implements a special localization option following the CT2007 method.
+
+    All other methods are inherited from the base class Optimizer.
+    """
+
+    def create_matrices(self):
+        """ Create the matrix space needed by the optimization routine """
+
+        # mean state [X]
+        self.x = np.zeros((self.nlag * self.nparams,), float)
+        # deviations from mean state [X']
+        self.X_prime = np.zeros((self.nlag * self.nparams, self.nmembers,), float)
+        # mean state, transported to observation space [ H(X) ]
+        self.Hx = np.zeros((self.nobs,), float)
+        # deviations from mean state, transported to observation space [ H(X') ]
+        self.HX_prime = np.zeros((self.nobs, self.nmembers), float)
+        # observations
+        self.obs = np.zeros((self.nobs,), float)
+        # observation ids
+        self.obs_ids = np.zeros((self.nobs,), float)
+        # covariance of observations:
+        # total covariance of fluxes and obs in units of obs [H P H^t + R]
+        if self.algorithm == 'Serial':
+            # serial algorithm: R is diagonal, so store variances only
+            self.R = np.zeros((self.nobs,), float)
+            self.HPHR = np.zeros((self.nobs,), float)
+        else:
+            self.R = np.zeros((self.nobs, self.nobs,), float)
+            self.HPHR = np.zeros((self.nobs, self.nobs,), float)
+        # localization of obs
+        self.may_localize = np.zeros(self.nobs, bool)
+        self.loc_L = np.zeros(self.nobs, int)
+        # rejection of obs
+        self.may_reject = np.zeros(self.nobs, bool)
+        # flags of obs
+        self.flags = np.zeros(self.nobs, int)
+        # species type (object dtype: np.zeros(n, str) yields dtype '<U1' and would truncate names to one character)
+        self.species = np.zeros(self.nobs, object)
+        # site code
+        self.sitecode = np.zeros(self.nobs, object)
+        # rejection_threshold
+        self.rejection_threshold = np.zeros(self.nobs, float)
+        # lat and lon
+        self.latitude = np.zeros(self.nobs, float)
+        self.longitude = np.zeros(self.nobs, float)
+        self.date = np.zeros((self.nobs, 6), float)
+
+        # species mask
+        self.speciesmask = {}
+
+        # Kalman Gain matrix, one column at a time (the full
+        # (nlag * nparams, nobs) version is commented out)
+        #self.KG = np.zeros((self.nlag * self.nparams, self.nobs,), float)
+        self.KG = np.zeros((self.nlag * self.nparams,), float)
+
+    def state_to_matrix(self, statevector):
+        allsites = []      # collect all site codes for n=1,..,nlag
+        alllats = []       # collect all latitudes for n=1,..,nlag
+        alllons = []       # collect all longitudes for n=1,..,nlag
+        alldates = []      # collect all date vectors for n=1,..,nlag
+        allobs = []        # collect all obs for n=1,..,nlag
+        allmdm = []        # collect all model-data mismatches for n=1,..,nlag
+        allids = []        # collect all obs ids for n=1,..,nlag
+        allreject = []     # collect all rejection switches for n=1,..,nlag
+        alllocalize = []   # collect all localization switches for n=1,..,nlag
+        allloc_L = []      # collect all localization lengths for n=1,..,nlag
+        allflags = []      # collect all flags for n=1,..,nlag
+        allspecies = []    # collect all species names for n=1,..,nlag
+        allsimulated = []  # collect all members' model samples for n=1,..,nlag
+        allrej_thres = []  # collect all rejection_thresholds, the same for all samples of the same source
+
+        for n in range(self.nlag):
+
+            samples = statevector.obs_to_assimilate[n]
+            members = statevector.ensemble_members[n]
+            self.x[n * self.nparams:(n + 1) * self.nparams] = members[0].param_values
+            self.X_prime[n * self.nparams:(n + 1) * self.nparams, :] = np.transpose(np.array([m.param_values for m in members]))
+
+            # Add observation data for all sample objects
+            if samples is not None:
+                if not isinstance(samples, list):
+                    samples = [samples]
+                for m in range(len(samples)):
+                    sample = samples[m]
+                    logging.debug('Lag %i, sample %i: rejection_threshold = %i, nobs = %i' % (n, m, sample.rejection_threshold, sample.getlength()))
+
+                    allrej_thres.extend([sample.rejection_threshold] * sample.getlength())
+                    allreject.extend(sample.getvalues('may_reject'))
+                    alllocalize.extend(sample.getvalues('may_localize'))
+                    allloc_L.extend(sample.getvalues('loc_L'))
+                    allflags.extend(sample.getvalues('flag'))
+                    allspecies.extend(sample.getvalues('species'))
+                    allobs.extend(sample.getvalues('obs'))
+                    allsites.extend(sample.getvalues('code'))
+                    alllats.extend(sample.getvalues('lat'))
+                    alllons.extend(sample.getvalues('lon'))
+                    alldates.extend([[d.year, d.month, d.day, d.hour, d.minute, d.second] for d in sample.getvalues('xdate')])
+                    allmdm.extend(sample.getvalues('mdm'))
+                    allids.extend(sample.getvalues('id'))
+
+                    simulatedensemble = sample.getvalues('simulated')
+                    for s in range(simulatedensemble.shape[0]):
+                        allsimulated.append(simulatedensemble[s])
+
+        self.rejection_threshold[:] = np.array(allrej_thres)
+
+        self.obs[:] = np.array(allobs)
+        self.obs_ids[:] = np.array(allids)
+        self.HX_prime[:, :] = np.array(allsimulated)
+        self.Hx[:] = self.HX_prime[:, 0]  # member 0 carries the ensemble mean
+
+        self.may_reject[:] = np.array(allreject)
+        self.may_localize[:] = np.array(alllocalize)
+        self.loc_L[:] = np.array(allloc_L)
+        self.flags[:] = np.array(allflags)
+        self.species[:] = np.array(allspecies)
+        self.sitecode = allsites
+        self.latitude[:] = np.array(alllats)
+        self.longitude[:] = np.array(alllons)
+        self.date[:, :] = np.array(alldates)
+
+        self.X_prime = self.X_prime - self.x[:, np.newaxis]     # make into a deviation matrix
+        self.HX_prime = self.HX_prime - self.Hx[:, np.newaxis]  # make into a deviation matrix
+
+        if self.algorithm == 'Serial':
+            for i, mdm in enumerate(allmdm):
+                self.R[i] = mdm ** 2
+        else:
+            for i, mdm in enumerate(allmdm):
+                self.R[i, i] = mdm ** 2
+
+        # For spatial localization: we cannot check here whether the run
+        # actually needs these attributes without altering the pipeline,
+        # so we rely on the statevector to provide them whenever spatial
+        # localization is requested.
+        if hasattr(statevector, "coords"):
+            self.sv_coords = statevector.coords
+            self.cdist = statevector.cdist
+            self.sv_n_emis_proc = statevector.n_emis_proc
+            self.sv_nbgparams = statevector.nbgparams
+
+    def set_localization(self, loctype='None'):
+        """ determine which localization to use """
+
+        def represents_number(s):
+            try:
+                float(s)
+                return True
+            except ValueError:
+                return False
+
+        if loctype == 'CT2007':
+            self.localization = True
+            self.localizetype = 'CT2007'
+            # Critical values of a two-tailed Student's t-test at the 95% confidence level for supported ensemble sizes
+            if self.nmembers == 50:
+                self.tvalue = 2.0086
+            elif self.nmembers == 100:
+                self.tvalue = 1.9840
+            elif self.nmembers == 150:
+                self.tvalue = 1.97591
+            elif self.nmembers == 200:
+                self.tvalue = 1.9719
+            else:
+                self.tvalue = 0
+        elif loctype == 'None':
+            self.localization = False
+            self.localizetype = 'None'
+        elif represents_number(loctype):
+            # a numeric loctype is interpreted as a Gaussian localization length scale
+            self.localization = True
+            self.localizetype = 'spatial'
+            self.localization_length = float(loctype)
+            # Check that the coordinates were copied over from the statevector in state_to_matrix
+            if not hasattr(self, "sv_coords"):
+                raise ValueError("Statevector coordinates needed for spatial localization were not found")
+            # import glob
+            # import re
+            # # Read all files with localization coefficients (one per length scale) and save them to a dict.
+            # # In this variant, loctype contains a filename template: a full path ending in ..._*km.nc
+            # self.loc_coeff = {}
+            # files_all = sorted(glob.glob(loctype))
+            # for file in files_all:
+            #     L = re.findall(r"(\d+)km.nc", file)[0]
+            #     logging.debug('L = %s' % (str(L)))
+            #     ncf = io.ct_read(file, 'read')
+            #     self.loc_coeff[L] = ncf.get_variable('loc_coeff').filled(fill_value=0)
+            #     ncf.close()
+            #     logging.info('Added spatial localization coefficients for L=%skm from file %s' % (str(L), file))
+            # del file
+        else:
+            raise ValueError("Unknown loctype: " + str(loctype))
logging.info("Current localization option is set to %s" % self.localizetype) + if self.localization == True and self.localizetype == 'CT2007': + if self.tvalue == 0: + logging.error("Critical tvalue for localization not set for %i ensemble members"%(self.nmembers)) + sys.exit(2) + else: logging.info("Used critical tvalue %0.05f is based on 95%% probability and %i ensemble members in a two-tailed student's T-test"%(self.tvalue,self.nmembers)) + + def localize(self, n): + """ localize the Kalman Gain matrix """ + import numpy as np + + if not self.localization: + logging.debug('Not localized observation %i' % self.obs_ids[n]) + return + if self.localizetype == 'CT2007': + count_localized = 0 + for r in range(self.nlag * self.nparams): + corr = np.corrcoef(self.HX_prime[n, :], self.X_prime[r, :].squeeze())[0, 1] + prob = corr / np.sqrt((1.000000001 - corr ** 2) / (self.nmembers - 2)) + if abs(prob) < self.tvalue: + self.KG[r] = 0.0 + count_localized = count_localized + 1 + logging.debug('Localized observation %i, %i%% of values set to 0' % (self.obs_ids[n],count_localized*100/(self.nlag * self.nparams))) + elif self.localizetype == 'spatial': + distances = self.cdist(self.sv_coords, np.array([[self.latitude[n], self.longitude[n]]])).squeeze() + loc_coeff_reg = np.exp(-(distances/self.localization_length)**2) + loc_coeff_emis = np.repeat(loc_coeff_reg, self.sv_n_emis_proc) + loc_coeff = np.concatenate((loc_coeff_emis, np.repeat(1.0, self.sv_nbgparams))) + for l in range(self.nlag): + self.KG[l*self.nparams:(l+1)*self.nparams] = np.multiply(self.KG[l*self.nparams:(l+1)*self.nparams], loc_coeff) + logging.debug('Localized observation %i with localization length %s' %(self.obs_ids[n], self.localization_length)) + + + def find_coord_index(self,lat,lon,nlat,nlon): + """ + Find index of coord in 2D lat-lon array (nlat x nlon) based on coordinates. 
+        Assumptions: lat between -90 and 90, lon between -180 and 180;
+        lat and lon are numpy arrays (astype is called on the result).
+        """
+        dlat = 180 / nlat
+        dlon = 360 / nlon
+
+        # center of the south-west grid cell
+        lat0 = -90 + dlat / 2
+        lon0 = -180 + dlon / 2
+
+        lat_index = (lat - lat0) / dlat
+        lon_index = (lon - lon0) / dlon
+
+        return lat_index.astype(int), lon_index.astype(int)
+
+    def set_algorithm(self, algorithm='Serial'):
+        """ determine which minimum least squares algorithm to use """
+
+        if algorithm == 'Serial':
+            self.algorithm = 'Serial'
+        else:
+            self.algorithm = 'Bulk'
+
+        logging.info("Current minimum least squares algorithm is set to %s" % self.algorithm)
+
+################### End Class CO2Optimizer ###################
+
+if __name__ == "__main__":
+    pass
diff --git a/da/statevectors/statevector_wrfchem.py b/da/statevectors/statevector_wrfchem.py
index c4770c83c0e7dc7da873c941a51064507d6a0232..bc2f91dae8bb8c29d9446ef95dd6bff7e9215141 100644
--- a/da/statevectors/statevector_wrfchem.py
+++ b/da/statevectors/statevector_wrfchem.py
@@ -129,11 +129,29 @@ class WRFChemStateVector(StateVector):
         ncf = io.ct_read(mapfile, 'read')
         self.gridmap = ncf.get_variable('regions').astype(int)
 
-        # For computing spatial covariances, read latitude
-        # and longitude from regionsfile
-        if self.spatial_correlations:
-            self.lon = ncf.get_variable('longitude')
-            self.lat = ncf.get_variable('latitude')
+        # For computing spatial covariances or spatial localization,
+        # read latitude and longitude from regionsfile
+        def represents_number(s):
+            try:
+                float(s)
+                return True
+            except ValueError:
+                return False
+        is_spatial = represents_number  # a numeric localization option selects spatial localization
+
+        if self.spatial_correlations or \
+           is_spatial(dacycle["da.system.localization"]):
+            lons = ncf.get_variable('longitude')
+            lats = ncf.get_variable('latitude')
+            # Compute center point of each region
+            lats_reg = np.zeros((self.nparams_proc,))*np.nan
+            lons_reg = lats_reg.copy()
+            for nr in range(self.nparams_proc):
+                is_nr = self.gridmap==nr+1 # Regions are 1-based
+                lats_reg[nr] = lats[is_nr].mean()
+                lons_reg[nr] = lons[is_nr].mean()
+            self.coords = np.vstack([lats_reg, lons_reg]).T
+
         # >>> edit freum: Not using transcom regions mapping in regional setup with wrfchem
         #self.tcmap = ncf.get_variable('transcom_regions')
         # <<< edit freum
@@ -342,7 +360,7 @@ class WRFChemStateVector(StateVector):
         # Fill flux covariances
         if self.spatial_correlations:
             if not hasattr(self, "dm"):
-                self.dm = self.get_distance_matrix()
+                self.dm = self.cdist(self.coords, self.coords)
             for nproc in range(self.n_emis_proc):
                 i0 = self.nparams_proc*nproc
                 i1 = self.nparams_proc*(nproc+1)
@@ -372,23 +390,6 @@ class WRFChemStateVector(StateVector):
 
         return covariancematrix
 
-    def get_distance_matrix(self):
-        """Distance matrix of regions
-        Based on 'lat' and 'lon' in regionsfile"""
-
-        # Compute center point of each region
-        lats = np.zeros((self.nparams_proc,))*np.nan
-        lons = lats.copy()
-        for nr in range(self.nparams_proc):
-            is_nr = self.gridmap==nr+1 # Regions are 1-based
-            lats[nr] = self.lat[is_nr].mean()
-            lons[nr] = self.lon[is_nr].mean()
-        coords = np.vstack([lats, lons]).T
-        geodist=lambda p1, p2: distance.distance(p1, p2).km
-        dm = self.cdist(coords, coords)
-
-        return dm
-
     @staticmethod
     def cdist(c1, c2):
         """
@@ -426,7 +427,6 @@ class WRFChemStateVector(StateVector):
             h[h>1.0] = 1.0
             d = 2*EARTH_RADIUS*np.arcsin(np.sqrt(h))
             return d
-
         # Ensure dimensions
         c1 = np.array(c1, ndmin=2)
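
Note on the spatial localization introduced above: `localize` tapers the Kalman gain column by exp(-(d/L)^2), where d is the great-circle distance between a region center and the observation and L is the length scale parsed from `da.system.localization`; background parameters keep a coefficient of 1.0. Below is a minimal, self-contained sketch of that computation. All inputs (`region_coords`, `obs_latlon`, `L`, `n_emis_proc`, `nbgparams`) are illustrative values, not taken from the patch, and `haversine_km` is a stand-in for the statevector's `cdist`.

```python
import numpy as np

EARTH_RADIUS = 6371.0  # km, matching the haversine-based cdist in the statevector

def haversine_km(p1, p2):
    # great-circle distance between two (lat, lon) points given in degrees
    lat1, lon1 = np.radians(p1)
    lat2, lon2 = np.radians(p2)
    h = np.sin((lat2 - lat1) / 2.0) ** 2 + \
        np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2.0) ** 2
    return 2.0 * EARTH_RADIUS * np.arcsin(np.sqrt(min(h, 1.0)))

# illustrative inputs, not from the patch
region_coords = np.array([[52.0, 5.0], [48.0, 11.0]])  # region centers (lat, lon)
obs_latlon = (50.0, 8.0)                               # observation location
L = 300.0                                              # localization length, km
n_emis_proc, nbgparams = 2, 1                          # emission processes / background params

distances = np.array([haversine_km(c, obs_latlon) for c in region_coords])
loc_coeff_reg = np.exp(-(distances / L) ** 2)          # Gaussian taper, one value per region
# one coefficient per (region, process) pair; background parameters keep weight 1.0
loc_coeff = np.concatenate((np.repeat(loc_coeff_reg, n_emis_proc),
                            np.repeat(1.0, nbgparams)))

KG = np.ones(len(loc_coeff))                           # stand-in gain column
print(KG * loc_coeff)                                  # tapered gain
```

In the patch itself the distances come from `statevector.cdist(self.sv_coords, ...)` and the same coefficient vector is applied once per lag window.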