"""CarbonTracker Data Assimilation Shell (CTDAS) Copyright (C) 2017 Wouter Peters. 
Users are recommended to contact the developers (wouter.peters@wur.nl) to receive
updates of the code. See also: http://www.carbontracker.eu. 

This program is free software: you can redistribute it and/or modify it under the
terms of the GNU General Public License as published by the Free Software Foundation, 
version 3. This program is distributed in the hope that it will be useful, but 
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS 
FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. 

You should have received a copy of the GNU General Public License along with this 
program. If not, see <http://www.gnu.org/licenses/>."""
#!/usr/bin/env python
# test_optimizer.py

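"""
Manual comparison tests for the CTDAS optimizer (``CtOptimizer``): the serial
Python solution versus the CT2009 Fortran solution, and the serial versus the
bulk least-squares solution.

Author: peters

Revision History:
File created on 04 Aug 2010.

"""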

def serial_py_against_serial_fortran(
        savefile='/data/CO2/peters/carbontracker/raw/ct09rc0i/20000108/savestate.hdf'):
    """ Test the solution of the serial algorithm against the CT cy2 fortran generated one.

    Reads the prior state (``xpc_NN``), prior ensemble deviations (``pdX_NN``),
    forecast observations and ``dF`` from ``savefile`` into a ``CtOptimizer``.
    It then runs ``serial_minimum_least_squares`` and compares the result with
    the Fortran analysis stored in the same file (``xac_NN``, ``adX_NN``,
    ``co2_sim_ana``). Maximum differences and correlations are printed, and
    three overlay figures are shown.

    :param savefile: path of the ``savestate.hdf`` file to read. Defaults to
        the location the original script hard-coded (first cycle of the
        CarbonTracker 2009 release).

    NOTE(review): this function relies on ``Nio``, ``np``, ``plt``,
    ``CtOptimizer`` and a module-level ``dacycle``. None of them is defined
    here; confirm they are imported or created elsewhere in this file.
    """
    print("WARNING: The optimization algorithm has changed from the CT2009 release because of a bug")
    print("WARNING: in the fortran code. Hence, the two solutions calculated are no longer the same.")
    print("WARNING: To change the python algorithm so that it corresponds to the fortran, change the")
    print("WARNING: loop from m=n+1,nfcast to m=1,nfcast")

    print(savefile)

    f = Nio.open_file(savefile, 'r')
    try:
        # The observation dimension is the number of forecast obs in the file.
        nobs = f.variables['co2_obs_fcast'].get_value().shape[0]

        dims = (int(dacycle.da_settings['time.nlag']),
                int(dacycle.da_settings['forecast.nmembers']),
                int(dacycle.dasystem.da_settings['nparameters']),
                nobs,)
        nlag, _, nparams, nobs = dims

        opt = CtOptimizer(dims)
        opt.set_localization('CT2007')

        opt.obs = f.variables['co2_obs_fcast'].get_value()[0:nobs]
        opt.Hx = f.variables['co2_sim_fcast'].get_value()[0:nobs]
        opt.flags = f.variables['flag_sim_fcast'].get_value()[0:nobs]
        error = f.variables['error_sim_fcast'].get_value()[0:nobs]
        simana = f.variables['co2_sim_ana'].get_value()[0:nobs]

        # Diagonal observation-error covariance: variance = error ** 2.
        for n in range(nobs):
            opt.R[n, n] = np.double(error[n] ** 2)

        xac = []  # Fortran analysis state, concatenated over all lags
        adX = []  # Fortran analysis deviations, one array per lag
        for lag in range(nlag):
            tag = '%02d' % (lag + 1)
            rows = slice(lag * nparams, (lag + 1) * nparams)
            opt.x[rows] = f.variables['xpc_' + tag].get_value()
            opt.X_prime[rows, :] = np.transpose(f.variables['pdX_' + tag].get_value())
            # Keep the Fortran analysis for the comparison below.
            xac.extend(f.variables['xac_' + tag].get_value())
            adX.append(f.variables['adX_' + tag].get_value())

        # 'dF' does not depend on the lag, so it is read once. The original
        # re-read it and re-assigned the same value on every lag.
        opt.HX_prime[:, :] = np.transpose(f.variables['dF'][:, 0:nobs])

        xac = np.array(xac)
        # Reorder the per-lag Fortran deviations into the optimizer's
        # (nparams * nlag, nmembers) layout.
        X_prime_fortran = np.array(adX).swapaxes(1, 2).reshape(
            (opt.nparams * opt.nlag, opt.nmembers))

        opt.serial_minimum_least_squares()

        def _plot_pair(fignum, python_values, fortran_values, title):
            """Overlay the Python and Fortran series in figure ``fignum``."""
            plt.figure(fignum)
            plt.plot(python_values, label='SerialPy')
            plt.plot(fortran_values, label='SerialFortran')
            plt.grid(True)
            plt.legend(loc=0)
            plt.title(title)

        print("Maximum differences and correlation of 2 state vectors:")
        print("%s %s" % (np.abs(xac - opt.x).max(), np.corrcoef(xac, opt.x)[0, 1]))
        _plot_pair(1, opt.x, xac, 'Analysis of state vector')

        print("Maximum differences of 2 state vector deviations:")
        print(np.abs(X_prime_fortran - opt.X_prime).max())
        _plot_pair(2, opt.X_prime.flatten(), X_prime_fortran.flatten(),
                   'Analysis of state vector deviations')

        print("Maximum differences and correlation of 2 simulated obs vectors:")
        print("%s %s" % (np.abs(simana - opt.Hx).max(), np.corrcoef(simana, opt.Hx)[0, 1]))
        _plot_pair(3, opt.Hx, simana, 'Analysis of CO2 mole fractions')

        plt.show()
    finally:
        # Close the savestate file even if reading or solving fails.
        f.close()

def serial_vs_bulk(
        savefile='/data/CO2/peters/carbontracker/raw/ct09rc0i/20000108/savestate.hdf',
        nobs=77):
    """ A test of the two algorithms currently implemented: serial vs bulk solution.

    Loads identical inputs from ``savefile`` into two ``CtOptimizer``
    instances. One is solved with ``bulk_minimum_least_squares`` and the
    other with ``serial_minimum_least_squares``. The analysed state vector,
    state deviations and simulated observations are then compared: maximum
    differences and correlations are printed, and three overlay figures are
    shown.

    :param savefile: path of the ``savestate.hdf`` file to read. Defaults to
        the location the original script hard-coded (first cycle of the
        CarbonTracker 2009 release).
    :param nobs: number of forecast observations to use, taken from the start
        of each observation array. Defaults to the original hard-coded 77.

    NOTE(review): this function relies on ``Nio``, ``np``, ``plt``,
    ``CtOptimizer`` and a module-level ``dacycle``. None of them is defined
    here; confirm they are imported or created elsewhere in this file.
    """
    print(savefile)

    f = Nio.open_file(savefile, 'r')
    try:
        dims = (int(dacycle.da_settings['time.nlag']),
                int(dacycle.da_settings['forecast.nmembers']),
                int(dacycle.dasystem.da_settings['nparameters']),
                nobs,)
        nlag, _, nparams, nobs = dims

        def _load_inputs(opt):
            """Fill ``opt`` with the prior state, deviations and obs from the savestate file."""
            opt.set_localization('CT2007')

            opt.obs = f.variables['co2_obs_fcast'].get_value()[0:nobs]
            opt.Hx = f.variables['co2_sim_fcast'].get_value()[0:nobs]
            opt.flags = f.variables['flag_sim_fcast'].get_value()[0:nobs]
            error = f.variables['error_sim_fcast'].get_value()[0:nobs]

            # Diagonal observation-error covariance: variance = error ** 2.
            for n in range(nobs):
                opt.R[n, n] = np.double(error[n] ** 2)

            for lag in range(nlag):
                tag = '%02d' % (lag + 1)
                rows = slice(lag * nparams, (lag + 1) * nparams)
                opt.x[rows] = f.variables['xpc_' + tag].get_value()
                opt.X_prime[rows, :] = np.transpose(f.variables['pdX_' + tag].get_value())

            # 'dF' does not depend on the lag, so it is read once. The
            # original re-read it on every lag.
            opt.HX_prime[:, :] = np.transpose(f.variables['dF'][:, 0:nobs])

        optbulk = CtOptimizer(dims)
        optserial = CtOptimizer(dims)

        _load_inputs(optbulk)
        optbulk.bulk_minimum_least_squares()

        _load_inputs(optserial)
        optserial.serial_minimum_least_squares()

        # Suffix 1 = bulk solution, suffix 2 = serial solution.
        x1, xp1, hx1 = optbulk.x, optbulk.X_prime, optbulk.Hx
        x2, xp2, hx2 = optserial.x, optserial.X_prime, optserial.Hx

        def _plot_pair(fignum, bulk_values, serial_values, title):
            """Overlay the bulk and serial series in figure ``fignum``."""
            plt.figure(fignum)
            # The original labelled these two curves the wrong way round.
            plt.plot(bulk_values, label='Bulk')
            plt.plot(serial_values, label='Serial')
            plt.grid(True)
            plt.legend(loc=0)
            plt.title(title)

        print("Maximum differences and correlation of 2 state vectors:")
        print("%s %s" % (np.abs(x2 - x1).max(), np.corrcoef(x2, x1)[0, 1]))
        _plot_pair(1, x1, x2, 'Analysis of state vector')

        print("Maximum differences of 2 state vector deviations:")
        print(np.abs(xp2 - xp1).max())
        _plot_pair(2, xp1.flatten(), xp2.flatten(), 'Analysis of state vector deviations')

        print("Maximum differences and correlation of 2 simulated obs vectors:")
        print("%s %s" % (np.abs(hx2 - hx1).max(), np.corrcoef(hx2, hx1)[0, 1]))
        _plot_pair(3, hx1, hx2, 'Analysis of CO2 mole fractions')

        plt.show()
    finally:
        # Close the savestate file even if reading or solving fails.
        f.close()



if __name__ == "__main__":
    # Nothing is executed when this module is run as a script. The comparison
    # functions above are meant to be called by hand.
    pass