#!/usr/bin/env python
# optimizer.py

"""
Peters, Wouter's avatar
Peters, Wouter committed
5
6
.. module:: optimizer
.. moduleauthor:: Wouter Peters 
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66

Revision History:
File created on 28 Jul 2010.

"""

import os
import sys
import logging
import datetime

# Name and version string returned by Optimizer.getid() / Optimizer.getversion()
identifier, version = 'Optimizer baseclass', '0.0'

################### Begin Class Optimizer ###################

class Optimizer(object):
    """
        This creates an instance of an optimization object. It handles the minimum least squares optimization
        of the state vector given a set of sample objects. Two routines are implemented: one where the optimization
        is sequential (SerialMinimumLeastSquares) and one where it is the equivalent matrix solution
        (BulkMinimumLeastSquares). The choice can be made based on considerations of speed and efficiency.

        Typical use: Initialize(dims) -> StateToMatrix(StateVector) -> one of the *MinimumLeastSquares
        methods -> MatrixToState(StateVector).
    """

    def __init__(self):
        self.Identifier = self.getid()
        self.Version    = self.getversion()

        logging.info('Optimizer object initialized: %s' % self.Identifier)

    def getid(self):
        """ Return the module-level identifier string of this optimizer. """
        return identifier

    def getversion(self):
        """ Return the module-level version string of this optimizer. """
        return version

    def Initialize(self, dims):
        """ Store the problem dimensions and allocate all matrices.

            :param dims: sequence (nlag, nmembers, nparams, nobs)
            :returns: None
        """
        self.nlag     = dims[0]
        self.nmembers = dims[1]
        self.nparams  = dims[2]
        self.nobs     = dims[3]
        self.CreateMatrices()

        return None

    def CreateMatrices(self):
        """ Create Matrix space needed in optimization routine.

            The state dimension is nlag*nparams. All numeric arrays start as zeros.
        """
        import numpy as np

        nstate = self.nlag*self.nparams

        # mean state  [X]
        self.x            = np.zeros((nstate,), float)
        # deviations from mean state  [X']
        self.X_prime      = np.zeros((nstate, self.nmembers), float)
        # mean state, transported to observation space [ H(X) ]
        self.Hx           = np.zeros((self.nobs,), float)
        # deviations from mean state, transported to observation space [ H(X') ]
        self.HX_prime     = np.zeros((self.nobs, self.nmembers), float)
        # observations
        self.obs          = np.zeros((self.nobs,), float)
        # observation ids
        self.obs_ids      = np.zeros((self.nobs,), float)
        # covariance of observations; StateToMatrix fills its diagonal with mdm**2
        self.R            = np.zeros((self.nobs, self.nobs), float)
        # localization of obs
        self.may_localize = np.zeros(self.nobs, bool)
        # rejection of obs
        self.may_reject   = np.zeros(self.nobs, bool)
        # flags of obs
        self.flags        = np.zeros(self.nobs, int)
        # species type. np.zeros(n, str) would get a length-1 string dtype and silently
        # truncate every species name to its first character; an object array keeps full names.
        self.species      = np.empty(self.nobs, object)
        # species mask
        self.speciesmask  = {}
        # Total covariance of fluxes and obs in units of obs [H P H^t + R]
        self.HPHR         = np.zeros((self.nobs, self.nobs), float)
        # Kalman Gain matrix
        self.KG           = np.zeros((nstate, self.nobs), float)

    def StateToMatrix(self, StateVector):
        """ Copy the prior ensemble and the sampled observations into the optimizer matrices.

            For every lag n the members in StateVector.EnsembleMembers[n] fill x (from member 0) and X'.
            The samples in StateVector.ObsToAssimmilate[n] are appended, in lag order, to the
            observation-space arrays; a lag whose samples are None contributes no observations.
            Afterwards X' and H(X') are turned into deviations from member 0, and the diagonal
            of R is set to mdm**2.

            :param StateVector: object exposing per-lag EnsembleMembers and ObsToAssimmilate
            :returns: None
        """
        import numpy as np

        # Values collected over all lags n=1,..,nlag, in lag order
        allobs       = []   # observed values
        allmdm       = []   # model-data mismatch, squared into the diagonal of R
        allids       = []   # observation ids
        allreject    = []   # may_reject flags
        alllocalize  = []   # may_localize flags
        allflags     = []   # observation flags
        allspecies   = []   # species names
        allsimulated = []   # one 2-D block (obs x members) of simulated values per lag

        for n in range(self.nlag):
            Samples  = StateVector.ObsToAssimmilate[n]
            members  = StateVector.EnsembleMembers[n]
            lagslice = slice(n*self.nparams, (n+1)*self.nparams)

            self.x[lagslice]         = members[0].ParameterValues
            self.X_prime[lagslice, :] = np.transpose(np.array([m.ParameterValues for m in members]))

            if Samples is None:
                continue

            self.rejection_threshold = Samples.rejection_threshold

            data = Samples.Data
            allreject.extend(data.getvalues('may_reject'))
            alllocalize.extend(data.getvalues('may_localize'))
            allflags.extend(data.getvalues('flag'))
            allspecies.extend(data.getvalues('species'))
            allobs.extend(data.getvalues('obs'))
            allmdm.extend(data.getvalues('mdm'))
            allids.extend(data.getvalues('id'))
            # Concatenated once after the loop; avoids ndarray "== None" comparisons
            allsimulated.append(np.array(data.getvalues('simulated')))

        self.obs[:]     = np.array(allobs)
        self.obs_ids[:] = np.array(allids)

        if allsimulated:
            self.HX_prime[:, :] = np.concatenate(allsimulated, axis=0)
        # Member 0 provides the mean state, so its simulated values are H(x)
        self.Hx[:] = self.HX_prime[:, 0]

        self.may_reject[:]   = np.array(allreject)
        self.may_localize[:] = np.array(alllocalize)
        self.flags[:]        = np.array(allflags)
        self.species[:]      = np.array(allspecies)

        self.X_prime  = self.X_prime - self.x[:, np.newaxis]    # make into a deviation matrix
        self.HX_prime = self.HX_prime - self.Hx[:, np.newaxis]  # make a deviation matrix

        for i, mdm in enumerate(allmdm):
            self.R[i, i] = mdm**2

        return None

    def MatrixToState(self, StateVector):
        """ Write the (optimized) ensemble back into the StateVector members.

            Member m of lag n receives X'[:, m] + x for that lag's parameter slice.
            Afterwards StateVector.isOptimized is set to True.

            :returns: None
        """
        for n in range(self.nlag):
            members  = StateVector.EnsembleMembers[n]
            lagslice = slice(n*self.nparams, (n+1)*self.nparams)
            for m, mem in enumerate(members):
                mem.ParameterValues[:] = self.X_prime[lagslice, m] + self.x[lagslice]

        StateVector.isOptimized = True
        logging.debug('Returning optimized data to the StateVector, setting "StateVector.isOptimized = True" ')

        return None

    def WriteDiagnostics(self, DaCycle, StateVector, type='prior'):
        """
            Open a NetCDF file and write diagnostic output from the optimization process.

            Written for both types: state vector mean and deviations, and modeled sample mean
            and deviations.
            Written for type 'prior' only: observations, ObsPack ids, and the model-data
            mismatch covariance R.
            Written for type 'optimized' only: HPH^T+R, observation flags, and the Kalman gain matrix.

            The type designation refers to the writing of prior or posterior data and is used in
            naming the variables. 'prior' creates the file; 'optimized' reopens it for writing.

            :raises ValueError: if type is neither 'prior' nor 'optimized'
            :returns: None
        """
        if type not in ('prior', 'optimized'):
            raise ValueError("WriteDiagnostics: type must be 'prior' or 'optimized', got %r" % (type,))

        import da.tools.io4 as io

        outdir   = DaCycle['dir.diagnostics']
        filename = os.path.join(outdir, 'optimizer.%s.nc' % DaCycle['time.start'].strftime('%Y%m%d'))
        DaCycle.OutputFileList += (filename,)

        # Open or create file
        if type == 'prior':
            f = io.CT_CDF(filename, method='create')
            logging.debug('Creating new diagnostics file for optimizer (%s)' % filename)
        else:
            f = io.CT_CDF(filename, method='write')
            logging.debug('Opening existing diagnostics file for optimizer (%s)' % filename)

        def _add(data, name, long_name, units, dims, comment, dtype=None):
            """ Write one variable to f, starting from a fresh copy of io.std_savedict. """
            savedict = io.std_savedict.copy()
            savedict['name'] = name
            if dtype is not None:
                savedict['dtype'] = dtype
            savedict['long_name'] = long_name
            savedict['units']     = units
            savedict['dims']      = dims
            savedict['values']    = data.tolist()
            savedict['comment']   = comment
            f.AddData(savedict)

        try:
            # Add dimensions
            dimparams  = f.AddParamsDim(self.nparams)
            dimmembers = f.AddMembersDim(self.nmembers)
            dimlag     = f.AddLagDim(self.nlag, unlimited=False)
            dimobs     = f.AddObsDim(self.nobs)
            dimstate   = f.AddDim('nstate', self.nparams*self.nlag)

            # Add data, first the ones that are written both before and after the optimization
            _add(self.x, "statevectormean_%s" % type, "full_statevector_mean_%s" % type,
                 "unitless", dimstate, 'Full %s state vector mean ' % type)
            _add(self.X_prime, "statevectordeviations_%s" % type, "full_statevector_deviations_%s" % type,
                 "unitless", dimstate + dimmembers,
                 'Full state vector %s deviations as resulting from the optimizer' % type)
            _add(self.Hx, "modelsamplesmean_%s" % type, "modelsamplesforecastmean_%s" % type,
                 "mol mol-1", dimobs, '%s mean mixing ratios based on %s state vector' % (type, type,))
            _add(self.HX_prime, "modelsamplesdeviations_%s" % type, "modelsamplesforecastdeviations_%s" % type,
                 "mol mol-1", dimobs + dimmembers,
                 '%s mixing ratio deviations based on %s state vector' % (type, type,))

            if type == 'prior':
                # Continue with prior only data
                _add(self.obs, "observed", "observedvalues", "mol mol-1", dimobs,
                     'Observations used in optimization')
                _add(self.obs_ids, "obspack_num", "Unique_ObsPack_observation_number", "", dimobs,
                     'Unique observation number across the entire ObsPack distribution', dtype="int64")
                _add(self.R, "modeldatamismatch", "modeldatamismatch", "[mol mol-1]^2", dimobs + dimobs,
                     'Variance of mole fractions resulting from model-data mismatch')
            else:
                # Continue with posterior only data
                _add(self.HPHR, "molefractionvariance", "molefractionvariance", "[mol mol-1]^2",
                     dimobs + dimobs,
                     'Variance of mole fractions resulting from prior state and model-data mismatch')
                _add(self.flags, "flag", "flag_for_obs_model", "None", dimobs,
                     'Flag (0/1/2/99) for observation value, 0 means okay, 1 means QC error, '
                     '2 means rejected, 99 means not sampled')
                _add(self.KG, "kalmangainmatrix", "kalmangainmatrix", "unitless molefraction-1",
                     dimstate + dimobs, 'Kalman gain matrix of all obs and state vector elements')
        finally:
            # Always release the file, also when writing a variable fails
            f.close()

        logging.debug('Diagnostics file closed ')

        return None

    def SerialMinimumLeastSquares(self):
        """ Make minimum least squares solution by looping over obs.

            Observations are assimilated one at a time, updating x, X', and the observation-space
            H(x), H(X') of all observations after each one. Observations with a non-zero flag are
            skipped. Observations that may be rejected and whose residual exceeds
            rejection_threshold*sqrt(R[n,n]) are flagged 2 and skipped.
            self.rejection_threshold is set by StateToMatrix.
        """
        import numpy as np

        nm1 = self.nmembers - 1

        for n in range(self.nobs):

            # Screen for flagged observations (for instance site not found, or no sample written from model)
            if self.flags[n] != 0:
                logging.debug('Skipping observation %d because of flag value %d' % (n, self.flags[n]))
                continue

            # Screen for outliers greater than rejection_threshold x model-data mismatch,
            # only apply if obs may be rejected
            res = self.obs[n] - self.Hx[n]

            if self.may_reject[n]:
                threshold = self.rejection_threshold*np.sqrt(self.R[n, n])
                if np.abs(res) > threshold:
                    logging.debug('Rejecting observation %d because residual (%f) exceeds threshold (%f)' % (n, res, threshold))
                    self.flags[n] = 2
                    continue

            PHt            = 1./nm1*np.dot(self.X_prime, self.HX_prime[n, :])
            self.HPHR[n, n] = 1./nm1*(self.HX_prime[n, :]*self.HX_prime[n, :]).sum() + self.R[n, n]

            self.KG[:, n] = PHt/self.HPHR[n, n]

            if self.may_localize[n]:
                self.Localize(n)
                logging.debug('Localized observation %d' % (n,))
            else:
                logging.debug('Not allowed to Localize observation %d' % (n,))

            # Square-root filter reduction factor for the deviations
            alpha = np.double(1.0)/(np.double(1.0) + np.sqrt((self.R[n, n])/self.HPHR[n, n]))

            self.x[:] = self.x + self.KG[:, n]*res

            # X'[:, r] -= alpha*K[:, n]*HX'[n, r] for every member r
            self.X_prime = self.X_prime - np.outer(alpha*self.KG[:, n], self.HX_prime[n, :])

            # Very important to first update observations n+1,...,nobs-1 and only then 0,...,n
            # (0-based). The current observation n must be updated last because its H(X')
            # features in every adjustment.
            res = self.obs[n] - self.Hx[n]
            for m in list(range(n+1, self.nobs)) + list(range(0, n+1)):
                fac                = 1.0/nm1*(self.HX_prime[n, :]*self.HX_prime[m, :]).sum()/self.HPHR[n, n]
                self.Hx[m]         = self.Hx[m] + fac*res
                self.HX_prime[m, :] = self.HX_prime[m, :] - alpha*fac*self.HX_prime[n, :]

    def BulkMinimumLeastSquares(self):
        """ Make minimum least squares solution by solving matrix equations.

            All observations are used at once (flags and may_reject are not consulted here).
            Updates x, X', H(x), H(X'), HPHR and KG in place.
        """
        import numpy as np
        import numpy.linalg as la

        nm1 = self.nmembers - 1

        # Create full solution, first calculate the mean of the posterior analysis
        HPH            = np.dot(self.HX_prime, np.transpose(self.HX_prime))/nm1   # HPH  = 1/(N-1) * HX' * (HX')^T
        self.HPHR[:, :] = HPH + self.R                                             # HPHR = HPH + R
        PHt            = np.dot(self.X_prime, np.transpose(self.HX_prime))/nm1    # PH^T = 1/(N-1) * X' * (HX')^T
        HPHR_inv       = la.inv(self.HPHR)
        self.KG[:, :]   = np.dot(PHt, HPHR_inv)                                    # K = PH^T (HPH+R)^-1

        for n in range(self.nobs):
            self.Localize(n)

        self.x[:] = self.x + np.dot(self.KG, self.obs - self.Hx)                   # xa = xp + K (y-Hx)

        # And next make the updated ensemble deviations. Note that we calculate P by using the full equation (10) at once, and
        # not in a serial update fashion as described in Whitaker and Hamill.
        # For the current problem with limited N_obs this is easier, or at least more straightforward to do.

        sHPHR      = la.cholesky(self.HPHR)                 # square root of HPH+R (lower-triangular Cholesky factor)
        sHPHR_invT = np.transpose(la.inv(sHPHR))            # (sqrt(HPH+R))^-T
        part2      = la.inv(sHPHR + np.sqrt(self.R))        # (sqrt(HPH+R)+sqrt(R))^-1
        Kw         = np.dot(np.dot(PHt, sHPHR_invT), part2) # K~
        self.X_prime[:, :] = self.X_prime - np.dot(Kw, self.HX_prime)  # X' = X' - K~ * HX'

        # Now do the adjustments of the modeled mixing ratios using the linearized ensemble. These are not strictly needed but can be used
        # for diagnosis.

        Kw_obs = np.dot(np.dot(HPH, sHPHR_invT), part2)                                 # K~ in observation space
        self.Hx[:] = self.Hx + np.dot(np.dot(HPH, HPHR_inv), self.obs - self.Hx)        # Hx  = Hx + HPH/(HPH+R) (y-Hx)
        self.HX_prime[:, :] = self.HX_prime - np.dot(Kw_obs, self.HX_prime)             # HX' = HX' - K~ * HX'

        logging.info('Minimum Least Squares solution was calculated, returning')

        return None

    def SetLocalization(self):
        """ determine which localization to use """

        self.localization = True
        self.localizetype = "None"

        logging.info("Current localization option is set to %s" % self.localizetype)

    def Localize(self, n):
        """ localize the Kalman Gain matrix

            The base class applies no localization. A missing self.localization (SetLocalization
            never called) is treated as "localization off" instead of raising AttributeError.
            NOTE(review): subclasses presumably override this to modify column n of self.KG.
        """
        if not getattr(self, 'localization', False):
            return

        return

################### End Class Optimizer ###################



if __name__ == "__main__":

    sys.path.append('../../')

    import os
    import sys
    from da.tools.general import StartLogger 
    from da.tools.initexit import CycleControl 
    from da.ct.statevector import CtStateVector, PrepareState
    from da.ct.obs import CtObservations
    import numpy as np
    import datetime
    import da.tools.rc as rc

    opts = ['-v']
    args = {'rc':'../../da.rc','logfile':'da_initexit.log','jobrcfilename':'test.rc'}

    StartLogger()
    DaCycle = CycleControl(opts,args)

    DaCycle.Initialize()
    print DaCycle

    StateVector = PrepareState(DaCycle)

    samples = CtObservations(DaCycle.DaSystem,datetime.datetime(2005,3,5))
    dummy = samples.AddObs()
    dummy = samples.AddSimulations('/Users/peters/tmp/test_da/output/20050305/samples.000.nc')


    nobs      = len(samples.Data)
    dims      = ( int(DaCycle['time.nlag']),
490
                  int(DaCycle['da.optimizer.nmembers']),
491
492
493
494
495
496
497
498
499
500
                  int(DaCycle.DaSystem['nparameters']),
                  nobs,  )

    opt = CtOptimizer(dims)

    opt.StateToMatrix(StateVector)

    opt.MinimumLeastSquares()

    opt.MatrixToState(StateVector)