#!/usr/bin/env python
# optimizer.py

"""
.. module:: optimizer
.. moduleauthor:: Wouter Peters

Revision History:
File created on 28 Jul 2010.

"""

import os
import sys
import logging
import datetime

identifier = 'Optimizer baseclass'
version    = '0.0'

################### Begin Class Optimizer ###################

class Optimizer(object):
    """
        This creates an instance of an optimization object. It handles the minimum least squares optimization
        of the state vector given a set of sample objects. Two routines are implemented: one where the optimization
        is sequential (:meth:`SerialMinimumLeastSquares`) and one where it is the equivalent matrix solution
        (:meth:`BulkMinimumLeastSquares`). The choice can be made based on considerations of speed and efficiency.

        Typical call order, as used by the methods below:
        ``Initialize(dims)`` -> ``StateToMatrix(StateVector)`` -> one of the two solvers -> ``MatrixToState(StateVector)``.
    """

    def __init__(self):
        self.Identifier = self.getid()
        self.Version = self.getversion()

        logging.info('Optimizer object initialized: %s' % self.Identifier)

    def getid(self):
        """ Return the module-level identifier string of this optimizer """
        return identifier

    def getversion(self):
        """ Return the module-level version string of this optimizer """
        return version

    def Initialize(self, dims):
        """
            Store the problem dimensions and allocate all work arrays.

            :param dims: indexable with at least four entries, read as (nlag, nmembers, nparams, nobs)
            :rtype: None
        """
        self.nlag = dims[0]
        self.nmembers = dims[1]
        self.nparams = dims[2]
        self.nobs = dims[3]
        self.CreateMatrices()

        return None

    def CreateMatrices(self):
        """ Create Matrix space needed in optimization routine """
        import numpy as np

        nstate = self.nlag * self.nparams

        # mean state  [X]
        self.x = np.zeros((nstate,), float)
        # deviations from mean state  [X']
        self.X_prime = np.zeros((nstate, self.nmembers,), float)
        # mean state, transported to observation space [ H(X) ]
        self.Hx = np.zeros((self.nobs,), float)
        # deviations from mean state, transported to observation space [ H(X') ]
        self.HX_prime = np.zeros((self.nobs, self.nmembers), float)
        # observations
        self.obs = np.zeros((self.nobs,), float)
        # observation ids
        self.obs_ids = np.zeros((self.nobs,), float)
        # covariance of observations
        self.R = np.zeros((self.nobs, self.nobs,), float)
        # localization of obs
        self.may_localize = np.zeros(self.nobs, bool)
        # rejection of obs
        self.may_reject = np.zeros(self.nobs, bool)
        # flags of obs
        self.flags = np.zeros(self.nobs, int)
        # species type (placeholder only: StateToMatrix rebinds this attribute)
        self.species = np.zeros(self.nobs, str)
        # site code (placeholder only: StateToMatrix rebinds this attribute)
        self.sitecode = np.zeros(self.nobs, str)

        # species mask
        self.speciesmask = {}

        # Total covariance of fluxes and obs in units of obs [H P H^t + R]
        self.HPHR = np.zeros((self.nobs, self.nobs,), float)
        # Kalman Gain matrix
        self.KG = np.zeros((nstate, self.nobs,), float)

    def StateToMatrix(self, StateVector):
        """
            Copy the contents of a StateVector into the matrices created by :meth:`CreateMatrices`.

            For each lag ``n`` this reads ``StateVector.EnsembleMembers[n]`` (each member's ``ParameterValues``;
            member 0 is taken as the mean state) and ``StateVector.ObsToAssimmilate[n]``. When the latter is not
            None, its ``rejection_threshold`` attribute and the columns 'may_reject', 'may_localize', 'flag',
            'species', 'obs', 'code', 'mdm', 'id' and 'simulated' of its ``Data.getvalues`` are collected.

            After the copy, ``X_prime`` and ``HX_prime`` are turned into deviations from ``x`` and ``Hx``, and the
            diagonal of ``R`` is filled with the squared model-data mismatch.

            :rtype: None
        """
        import numpy as np

        allsites = []        # site codes for n=1,..,nlag
        allobs = []          # observed values for n=1,..,nlag
        allmdm = []          # model-data mismatch for n=1,..,nlag
        allids = []          # observation ids for n=1,..,nlag
        allreject = []       # may_reject switches for n=1,..,nlag
        alllocalize = []     # may_localize switches for n=1,..,nlag
        allflags = []        # flags for n=1,..,nlag
        allspecies = []      # species names for n=1,..,nlag
        simulated_blocks = []  # one array of simulated ensemble samples per lag that has observations

        for n in range(self.nlag):

            Samples = StateVector.ObsToAssimmilate[n]
            members = StateVector.EnsembleMembers[n]
            rows = slice(n * self.nparams, (n + 1) * self.nparams)

            self.x[rows] = members[0].ParameterValues
            self.X_prime[rows, :] = np.transpose(np.array([m.ParameterValues for m in members]))

            # Identity test: comparing with == / != is wrong here and breaks on array-like objects
            if Samples is None:
                continue

            self.rejection_threshold = Samples.rejection_threshold

            allreject.extend(Samples.Data.getvalues('may_reject'))
            alllocalize.extend(Samples.Data.getvalues('may_localize'))
            allflags.extend(Samples.Data.getvalues('flag'))
            allspecies.extend(Samples.Data.getvalues('species'))
            allobs.extend(Samples.Data.getvalues('obs'))
            allsites.extend(Samples.Data.getvalues('code'))
            allmdm.extend(Samples.Data.getvalues('mdm'))
            allids.extend(Samples.Data.getvalues('id'))

            simulated_blocks.append(np.array(Samples.Data.getvalues('simulated')))

        self.obs[:] = np.array(allobs)
        self.obs_ids[:] = np.array(allids)
        if simulated_blocks:
            # Stack the per-lag blocks along the observation axis
            self.HX_prime[:, :] = np.concatenate(simulated_blocks, axis=0)
        self.Hx[:] = self.HX_prime[:, 0]

        self.may_reject[:] = np.array(allreject)
        self.may_localize[:] = np.array(alllocalize)
        self.flags[:] = np.array(allflags)
        # Rebind instead of slice-assigning: the placeholder array holds strings of width 1, and
        # writing into it would silently truncate every species name to its first character.
        self.species = np.array(allspecies, dtype=str)
        self.sitecode = allsites

        self.X_prime = self.X_prime - self.x[:, np.newaxis]      # make into a deviation matrix
        self.HX_prime = self.HX_prime - self.Hx[:, np.newaxis]   # make a deviation matrix

        for i, mdm in enumerate(allmdm):
            self.R[i, i] = mdm ** 2

        return None

    def MatrixToState(self, StateVector):
        """
            Write mean plus deviations back into ``ParameterValues`` of every ensemble member of every lag,
            and set ``StateVector.isOptimized`` to True.

            :rtype: None
        """
        for n in range(self.nlag):
            rows = slice(n * self.nparams, (n + 1) * self.nparams)
            for m, mem in enumerate(StateVector.EnsembleMembers[n]):
                mem.ParameterValues[:] = self.X_prime[rows, m] + self.x[rows]

        StateVector.isOptimized = True
        logging.debug('Returning optimized data to the StateVector, setting "StateVector.isOptimized = True" ')

        return None

    def WriteDiagnostics(self, DaCycle, StateVector, type='prior'):
        """
            Open a NetCDF file and write diagnostic output from optimization process:

                - calculated residuals
                - model-data mismatches
                - HPH^T
                - prior ensemble of samples
                - posterior ensemble of samples
                - prior ensemble of fluxes
                - posterior ensemble of fluxes

            The type designation refers to the writing of prior or posterior data and is used in naming the variables.

            :param DaCycle: provides 'dir.diagnostics' and 'time.start' by key, and an ``OutputFileList`` attribute
                            to which the diagnostics file name is appended
            :param StateVector: not used by this routine
            :param type: 'prior' (the file is created) or 'optimized' (the existing file is opened for writing)
            :raises ValueError: if ``type`` is neither 'prior' nor 'optimized'
            :rtype: None
        """
        import da.tools.io4 as io

        # Fail early with a clear message; otherwise no file handle would exist further down
        if type not in ('prior', 'optimized'):
            raise ValueError("WriteDiagnostics type must be 'prior' or 'optimized', got %r" % (type,))

        outdir = DaCycle['dir.diagnostics']
        filename = os.path.join(outdir, 'optimizer.%s.nc' % DaCycle['time.start'].strftime('%Y%m%d'))
        DaCycle.OutputFileList += (filename,)

        # Open or create file

        if type == 'prior':
            f = io.CT_CDF(filename, method='create')
            logging.debug('Creating new diagnostics file for optimizer (%s)' % filename)
        else:
            f = io.CT_CDF(filename, method='write')
            logging.debug('Opening existing diagnostics file for optimizer (%s)' % filename)

        def add_variable(**fields):
            # Start from a copy of the io module's template dictionary and override the given keys
            savedict = io.std_savedict.copy()
            savedict.update(fields)
            return f.AddData(savedict)

        try:
            # Add dimensions (the params and lag dimensions are created but not referenced below)

            f.AddParamsDim(self.nparams)
            dimmembers = f.AddMembersDim(self.nmembers)
            f.AddLagDim(self.nlag, unlimited=False)
            dimobs = f.AddObsDim(self.nobs)
            dimstate = f.AddDim('nstate', self.nparams * self.nlag)
            dim200char = f.AddDim('string_of200chars', 200)

            # Add data, first the ones that are written both before and after the optimization

            add_variable(name="statevectormean_%s" % type,
                         long_name="full_statevector_mean_%s" % type,
                         units="unitless",
                         dims=dimstate,
                         values=self.x.tolist(),
                         comment='Full %s state vector mean ' % type)

            add_variable(name="statevectordeviations_%s" % type,
                         long_name="full_statevector_deviations_%s" % type,
                         units="unitless",
                         dims=dimstate + dimmembers,
                         values=self.X_prime.tolist(),
                         comment='Full state vector %s deviations as resulting from the optimizer' % type)

            add_variable(name="modelsamplesmean_%s" % type,
                         long_name="modelsamplesforecastmean_%s" % type,
                         units="mol mol-1",
                         dims=dimobs,
                         values=self.Hx.tolist(),
                         comment='%s mean mixing ratios based on %s state vector' % (type, type,))

            add_variable(name="modelsamplesdeviations_%s" % type,
                         long_name="modelsamplesforecastdeviations_%s" % type,
                         units="mol mol-1",
                         dims=dimobs + dimmembers,
                         values=self.HX_prime.tolist(),
                         comment='%s mixing ratio deviations based on %s state vector' % (type, type,))

            if type == 'prior':

                # Continue with prior only data

                add_variable(name="sitecode",
                             long_name="site code propagated from observation file",
                             dtype="char",
                             dims=dimobs + dim200char,
                             values=self.sitecode,
                             missing_value='!')

                add_variable(name="observed",
                             long_name="observedvalues",
                             units="mol mol-1",
                             dims=dimobs,
                             values=self.obs.tolist(),
                             comment='Observations used in optimization')

                add_variable(name="obspack_num",
                             dtype="int64",
                             long_name="Unique_ObsPack_observation_number",
                             units="",
                             dims=dimobs,
                             values=self.obs_ids.tolist(),
                             comment='Unique observation number across the entire ObsPack distribution')

                add_variable(name="modeldatamismatchvariance",
                             long_name="modeldatamismatch variance",
                             units="[mol mol-1]^2",
                             dims=dimobs + dimobs,
                             values=self.R.tolist(),
                             comment='Variance of mole fractions resulting from model-data mismatch')

            else:

                # Continue with posterior only data

                add_variable(name="totalmolefractionvariance",
                             long_name="totalmolefractionvariance",
                             units="[mol mol-1]^2",
                             dims=dimobs + dimobs,
                             values=self.HPHR.tolist(),
                             comment='Variance of mole fractions resulting from prior state and model-data mismatch')

                add_variable(name="flag",
                             long_name="flag_for_obs_model",
                             units="None",
                             dims=dimobs,
                             values=self.flags.tolist(),
                             comment='Flag (0/1/2/99) for observation value, 0 means okay, 1 means QC error, 2 means rejected, 99 means not sampled')

                add_variable(name="kalmangainmatrix",
                             long_name="kalmangainmatrix",
                             units="unitless molefraction-1",
                             dims=dimstate + dimobs,
                             values=self.KG.tolist(),
                             comment='Kalman gain matrix of all obs and state vector elements')
        finally:
            # Close the file even when one of the writes above fails
            f.close()
            logging.debug('Diagnostics file closed ')

        return None

    def SerialMinimumLeastSquares(self):
        """
            Make minimum least squares solution by looping over obs.

            Observations are processed one at a time. An observation with a non-zero flag is skipped. An
            observation that may be rejected and whose absolute residual exceeds ``rejection_threshold`` times
            the square root of its entry on the diagonal of ``R`` gets flag 2 and is skipped. For every other
            observation the mean state, the state deviations, and the mean and deviations of all observations
            in observation space are updated in place.
        """
        import numpy as np

        for n in range(self.nobs):

            # Screen for flagged observations (for instance site not found, or no sample written from model)

            if self.flags[n] != 0:
                logging.debug('Skipping observation %d because of flag value %d' % (n, self.flags[n]))
                continue

            # Screen for outliers larger than rejection_threshold x model-data mismatch,
            # only apply if obs may be rejected

            res = self.obs[n] - self.Hx[n]

            if self.may_reject[n]:
                threshold = self.rejection_threshold * np.sqrt(self.R[n, n])
                if np.abs(res) > threshold:
                    logging.debug('Rejecting observation %d because residual (%f) exceeds threshold (%f)' % (n, res, threshold))
                    self.flags[n] = 2
                    continue

            PHt = 1. / (self.nmembers - 1) * np.dot(self.X_prime, self.HX_prime[n, :])
            self.HPHR[n, n] = 1. / (self.nmembers - 1) * (self.HX_prime[n, :] * self.HX_prime[n, :]).sum() + self.R[n, n]

            self.KG[:, n] = PHt / self.HPHR[n, n]

            if self.may_localize[n]:
                self.Localize(n)
                logging.debug('Localized observation %d' % (n,))
            else:
                logging.debug('Not allowed to Localize observation %d' % (n,))

            alpha = np.double(1.0) / (np.double(1.0) + np.sqrt((self.R[n, n]) / self.HPHR[n, n]))

            self.x[:] = self.x + self.KG[:, n] * res

            # Same as updating column r with alpha*KG[:,n]*HX_prime[n,r] for every member r
            self.X_prime -= np.outer(alpha * self.KG[:, n], self.HX_prime[n, :])

            # WP !!!! Very important to first do all observations after the current one through the end, and only
            # WP      then update the first through the current one. The current observation should always be
            # WP      updated last because it features in the loop of the adjustments !!!!
            #
            # The second range starts at index 0: starting at 1 would mean the first observation is never
            # updated in observation space. The residual stays fixed during the loop because Hx[n] is only
            # modified in the final iteration (m == n).

            update_order = list(range(n + 1, self.nobs)) + list(range(0, n + 1))

            for m in update_order:
                fac = 1.0 / (self.nmembers - 1) * (self.HX_prime[n, :] * self.HX_prime[m, :]).sum() / self.HPHR[n, n]
                self.Hx[m] = self.Hx[m] + fac * res
                self.HX_prime[m, :] = self.HX_prime[m, :] - alpha * fac * self.HX_prime[n, :]

    def BulkMinimumLeastSquares(self):
        """
            Make minimum least squares solution by solving matrix equations.

            All observations are used at once; flags and rejection switches are not consulted here.
            :meth:`Localize` is called for every observation.

            :rtype: None
        """
        import numpy as np
        import numpy.linalg as la

        # Create full solution, first calculate the mean of the posterior analysis

        HPH = np.dot(self.HX_prime, np.transpose(self.HX_prime)) / (self.nmembers - 1)    # HPH = 1/(N-1) * HX' * (HX')^T
        self.HPHR[:, :] = HPH + self.R                                                    # HPHR = HPH + R
        HPb = np.dot(self.X_prime, np.transpose(self.HX_prime)) / (self.nmembers - 1)     # PH^T = 1/(N-1) * X' * (HX')^T
        self.KG[:, :] = np.dot(HPb, la.inv(self.HPHR))                                    # K = PH^T / (HPH+R)

        for n in range(self.nobs):
            self.Localize(n)

        self.x[:] = self.x + np.dot(self.KG, self.obs - self.Hx)                          # xa = xp + K (y-Hx)

        # And next make the updated ensemble deviations. Note that we calculate P by using the full equation (10) at once, and
        # not in a serial update fashion as described in Whitaker and Hamill.
        # For the current problem with limited N_obs this is easier, or at least more straightforward to do.

        sHPHR = la.cholesky(self.HPHR)                                # square root of HPH+R
        sHPHR_inv_t = np.transpose(la.inv(sHPHR))
        # np.sqrt works element by element, so this is a matrix square root only for a diagonal R
        part2 = la.inv(sHPHR + np.sqrt(self.R))                       # (sqrt(HPH+R)+sqrt(R))^-1
        Kw_state = np.dot(np.dot(HPb, sHPHR_inv_t), part2)            # K~ for the state
        self.X_prime[:, :] = self.X_prime - np.dot(Kw_state, self.HX_prime)               # X' = X' - K~ * HX'

        # Now do the adjustments of the modeled mixing ratios using the linearized ensemble. These are not strictly needed but can be used
        # for diagnosis.

        Kw_obs = np.dot(np.dot(HPH, sHPHR_inv_t), part2)              # K~ in observation space
        self.Hx[:] = self.Hx + np.dot(np.dot(HPH, la.inv(self.HPHR)), self.obs - self.Hx)  # Hx  = Hx + HPH/(HPH+R) (y-Hx)
        self.HX_prime[:, :] = self.HX_prime - np.dot(Kw_obs, self.HX_prime)               # HX' = HX' - K~ * HX'

        logging.info('Minimum Least Squares solution was calculated, returning')

        return None

    def SetLocalization(self):
        """ determine which localization to use """

        self.localization = True
        self.localizetype = "None"

        logging.info("Current localization option is set to %s" % self.localizetype)

    def Localize(self, n):
        """
            localize the Kalman Gain matrix

            In this class the routine does nothing for any observation index ``n``, whether or not
            :meth:`SetLocalization` was called before.
        """
        return None

################### End Class Optimizer ###################



if __name__ == "__main__":

    # Stand-alone demo: build a state vector and observations from the da.* packages,
    # then run one serial optimization with the Optimizer class of this module.

    import os
    import sys

    # Make the da package importable when running from within its source tree
    sys.path.append('../../')

    from da.tools.general import StartLogger
    from da.tools.initexit import CycleControl
    from da.ct.statevector import CtStateVector, PrepareState
    from da.ct.obs import CtObservations
    import numpy as np
    import datetime
    import da.tools.rc as rc

    opts = ['-v']
    args = {'rc': '../../da.rc', 'logfile': 'da_initexit.log', 'jobrcfilename': 'test.rc'}

    StartLogger()
    DaCycle = CycleControl(opts, args)

    DaCycle.Initialize()
    print(DaCycle)

    StateVector = PrepareState(DaCycle)

    samples = CtObservations(DaCycle.DaSystem, datetime.datetime(2005, 3, 5))
    dummy = samples.AddObs()
    dummy = samples.AddSimulations('/Users/peters/tmp/test_da/output/20050305/samples.000.nc')

    nobs = len(samples.Data)
    dims = (int(DaCycle['time.nlag']),
            int(DaCycle['da.optimizer.nmembers']),
            int(DaCycle.DaSystem['nparameters']),
            nobs,)

    # Optimizer takes no constructor arguments; the dimensions go through Initialize
    opt = Optimizer()
    opt.Initialize(dims)
    opt.SetLocalization()

    opt.StateToMatrix(StateVector)

    opt.SerialMinimumLeastSquares()

    opt.MatrixToState(StateVector)