"""CarbonTracker Data Assimilation Shell (CTDAS) Copyright (C) 2017 Wouter Peters.
Users are recommended to contact the developers (wouter.peters@wur.nl) to receive
updates of the code. See also: http://www.carbontracker.eu.

This program is free software: you can redistribute it and/or modify it under the
terms of the GNU General Public License as published by the Free Software Foundation,
version 3. This program is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.

You should have received a copy of the GNU General Public License along with this
program. If not, see <http://www.gnu.org/licenses/>."""
#!/usr/bin/env python
# test_optimizer.py

"""
Author : peters

Revision History:
File created on 04 Aug 2010.

"""


def serial_py_against_serial_fortran(
        savefile='/data/CO2/peters/carbontracker/raw/ct09rc0i/20000108/savestate.hdf'):
    """Test the Python serial solution against the CT cy2 Fortran generated one.

    Loads the forecast state (``xpc_*``), its ensemble deviations (``pdX_*``),
    the simulated-observation deviations (``dF``) and the observation vectors
    from a savestate file, runs ``serial_minimum_least_squares`` on a
    ``CtOptimizer``, and then prints and plots the differences with the
    analysis stored in the same file (``xac_*``, ``adX_*``, ``co2_sim_ana``).

    The two solutions are not expected to be identical any more; the warnings
    printed at the start explain why.

    Args:
        savefile: path of the savestate file to read. The default is the file
            from the first cycle of the CarbonTracker 2009 release.

    NOTE(review): ``Nio``, ``np``, ``plt``, ``CtOptimizer`` and ``dacycle``
    are used here but no import or definition for them is visible in this
    file; they have to be available at module scope for this to run.
    """

    print("WARNING: The optimization algorithm has changed from the CT2009 release because of a bug")
    print("WARNING: in the fortran code. Hence, the two solutions calculated are no longer the same.")
    print("WARNING: To change the python algorithm so that it corresponds to the fortran, change the")
    print("WARNING: loop from m=n+1,nfcast to m=1,nfcast")

    print(savefile)

    f = Nio.open_file(savefile, 'r')
    try:
        # The number of observations in the file fixes the last dimension.
        obs = f.variables['co2_obs_fcast'].get_value()
        sel_obs = obs.shape[0]

        dims = (int(dacycle.da_settings['time.nlag']),
                int(dacycle.da_settings['forecast.nmembers']),
                int(dacycle.dasystem.da_settings['nparameters']),
                sel_obs,)
        nlag, nmembers, nparams, nobs = dims

        opt = CtOptimizer(dims)
        opt.set_localization('CT2007')

        # Observation-space inputs of the optimizer.
        opt.obs = obs[0:nobs]
        opt.Hx = f.variables['co2_sim_fcast'].get_value()[0:nobs]
        error = f.variables['error_sim_fcast'].get_value()[0:nobs]
        opt.flags = f.variables['flag_sim_fcast'].get_value()[0:nobs]
        # Fortran analysis of the simulated observations, kept for comparison.
        simana = f.variables['co2_sim_ana'].get_value()[0:nobs]

        # Only the diagonal of R is filled, with the squared errors.
        for n in range(nobs):
            opt.R[n, n] = np.double(error[n] ** 2)

        # 'dF' does not depend on the lag, so it is read once instead of
        # being re-read and re-assigned on every pass of the loop below.
        HX = f.variables['dF'][:, 0:sel_obs]
        opt.HX_prime[:, :] = np.transpose(HX)

        xac = []
        adX = []
        for lag in range(nlag):
            first = lag * nparams
            last = (lag + 1) * nparams
            suffix = '%02d' % (lag + 1)  # variables in the file are numbered from 01

            opt.x[first:last] = f.variables['xpc_' + suffix].get_value()
            X = f.variables['pdX_' + suffix].get_value()
            opt.X_prime[first:last, :] = np.transpose(X)

            # Also create arrays of the analysis of the fortran code for later comparison
            xac.extend(f.variables['xac_' + suffix].get_value())
            adX.append(f.variables['adX_' + suffix].get_value())

        xac = np.array(xac)
        # Bring the stacked per-lag deviations into the layout of opt.X_prime.
        X_prime = np.array(adX).swapaxes(1, 2).reshape((opt.nparams * opt.nlag, opt.nmembers))

        opt.serial_minimum_least_squares()

        print("Maximum differences and correlation of 2 state vectors:")
        print("%s %s" % (np.abs(xac - opt.x).max(), np.corrcoef(xac, opt.x)[0, 1]))

        plt.figure(1)
        plt.plot(opt.x, label='SerialPy')
        plt.plot(xac, label='SerialFortran')
        plt.grid(True)
        plt.legend(loc=0)
        plt.title('Analysis of state vector')

        print("Maximum differences of 2 state vector deviations:")
        print(np.abs(X_prime - opt.X_prime).max())

        plt.figure(2)
        plt.plot(opt.X_prime.flatten(), label='SerialPy')
        plt.plot(X_prime.flatten(), label='SerialFortran')
        plt.grid(True)
        plt.legend(loc=0)
        plt.title('Analysis of state vector deviations')

        print("Maximum differences and correlation of 2 simulated obs vectors:")
        print("%s %s" % (np.abs(simana - opt.Hx).max(), np.corrcoef(simana, opt.Hx)[0, 1]))

        plt.figure(3)
        plt.plot(opt.Hx, label='SerialPy')
        plt.plot(simana, label='SerialFortran')
        plt.grid(True)
        plt.legend(loc=0)
        plt.title('Analysis of CO2 mole fractions')

        plt.show()
    finally:
        # Release the file handle even when reading or optimizing fails.
        f.close()


def serial_vs_bulk(
        savefile='/data/CO2/peters/carbontracker/raw/ct09rc0i/20000108/savestate.hdf',
        nobs=77):
    """A test of the two algorithms currently implemented: serial vs bulk solution.

    Two ``CtOptimizer`` instances are filled with the same forecast data from
    a savestate file. The first is solved with ``bulk_minimum_least_squares``,
    the second with ``serial_minimum_least_squares``. The resulting state
    vector, state vector deviations and simulated observations are then
    compared in printed statistics and in three figures.

    Args:
        savefile: path of the savestate file to read. The default is the file
            from the first cycle of the CarbonTracker 2009 release.
        nobs: number of leading observations taken from the file.

    NOTE(review): ``Nio``, ``np``, ``plt``, ``CtOptimizer`` and ``dacycle``
    are used here but no import or definition for them is visible in this
    file; they have to be available at module scope for this to run.
    """

    print(savefile)

    f = Nio.open_file(savefile, 'r')
    try:
        dims = (int(dacycle.da_settings['time.nlag']),
                int(dacycle.da_settings['forecast.nmembers']),
                int(dacycle.dasystem.da_settings['nparameters']),
                nobs,)
        nlag, nmembers, nparams, nobs = dims

        optbulk = CtOptimizer(dims)
        optserial = CtOptimizer(dims)
        runs = ((optbulk, optbulk.bulk_minimum_least_squares),
                (optserial, optserial.serial_minimum_least_squares))

        solutions = []
        for opt, solve in runs:
            opt.set_localization('CT2007')

            # Every optimizer gets freshly read arrays from the file.
            opt.obs = f.variables['co2_obs_fcast'].get_value()[0:nobs]
            opt.Hx = f.variables['co2_sim_fcast'].get_value()[0:nobs]
            error = f.variables['error_sim_fcast'].get_value()[0:nobs]
            opt.flags = f.variables['flag_sim_fcast'].get_value()[0:nobs]

            # Only the diagonal of R is filled, with the squared errors.
            for n in range(nobs):
                opt.R[n, n] = np.double(error[n] ** 2)

            # 'dF' does not depend on the lag, so it is read once per
            # optimizer instead of once per lag.
            HX = f.variables['dF'][:, 0:nobs]
            opt.HX_prime[:, :] = np.transpose(HX)

            for lag in range(nlag):
                first = lag * nparams
                last = (lag + 1) * nparams
                suffix = '%02d' % (lag + 1)  # variables in the file are numbered from 01

                opt.x[first:last] = f.variables['xpc_' + suffix].get_value()
                X = f.variables['pdX_' + suffix].get_value()
                opt.X_prime[first:last, :] = np.transpose(X)

            # Solve straight after loading, before the next optimizer is filled.
            solve()
            solutions.append((opt.x, opt.X_prime, opt.Hx))

        (x_bulk, xp_bulk, hx_bulk), (x_serial, xp_serial, hx_serial) = solutions

        print("Maximum differences and correlation of 2 state vectors:")
        print("%s %s" % (np.abs(x_serial - x_bulk).max(), np.corrcoef(x_serial, x_bulk)[0, 1]))

        # The labels now name the solver that actually produced each curve;
        # they used to be attached the wrong way round.
        plt.figure(1)
        plt.plot(x_bulk, label='Bulk')
        plt.plot(x_serial, label='Serial')
        plt.grid(True)
        plt.legend(loc=0)
        plt.title('Analysis of state vector')

        print("Maximum differences of 2 state vector deviations:")
        print(np.abs(xp_serial - xp_bulk).max())

        plt.figure(2)
        plt.plot(xp_bulk.flatten(), label='Bulk')
        plt.plot(xp_serial.flatten(), label='Serial')
        plt.grid(True)
        plt.legend(loc=0)
        plt.title('Analysis of state vector deviations')

        print("Maximum differences and correlation of 2 simulated obs vectors:")
        print("%s %s" % (np.abs(hx_serial - hx_bulk).max(), np.corrcoef(hx_serial, hx_bulk)[0, 1]))

        plt.figure(3)
        plt.plot(hx_bulk, label='Bulk')
        plt.plot(hx_serial, label='Serial')
        plt.title('Analysis of CO2 mole fractions')
        plt.grid(True)
        plt.legend(loc=0)

        plt.show()
    finally:
        # Release the file handle even when reading or optimizing fails.
        f.close()



if __name__ == "__main__":
    # Nothing is run automatically: the comparisons in this file read a
    # savestate file from a fixed path and open plot windows, so they are
    # meant to be called by hand.
    pass