First Neural ODE example#

A neural ODE is an ODE whose derivative function is defined by a neural network: \(\dot{u} = NN(u)\).
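Solving the ODE forward in time from the initial state \(u(0)\) yields the model's prediction:

\[
u(t) = u(0) + \int_0^t NN(u(\tau)) \, d\tau
\]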

From: https://docs.sciml.ai/DiffEqFlux/stable/examples/neural_ode/

using Lux, DiffEqFlux, DifferentialEquations, ComponentArrays
using Optimization, OptimizationOptimJL, OptimizationOptimisers
using Random, Plots

rng = Random.default_rng()
TaskLocalRNG()

Define the true ODE that generates the training data.

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    # Cubic dynamics; equivalent to du .= true_A' * (u .^ 3)
    du .= ((u.^3)'true_A)'
end
trueODEfunc (generic function with 1 method)
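Written out, this computes the cubic system

\[
\dot{u} = A^{\top} u^{3}, \qquad
A = \begin{bmatrix} -0.1 & 2.0 \\ -2.0 & -0.1 \end{bmatrix},
\]

where the cube is applied elementwise.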

The data used for training

u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[begin], tspan[end], length = datasize)
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
2×30 Matrix{Float32}:
 2.0  1.9465    1.74178  1.23837  0.577127  …  1.40688   1.37023   1.29214
 0.0  0.798832  1.46473  1.80877  1.86465      0.451377  0.728699  0.972102
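To eyeball the training data before fitting, both state components can be plotted with the Plots.jl functions imported above (an optional sketch, not part of the original tutorial):

plt = scatter(tsteps, ode_data[1, :], label = "u1 data")   # first state component
scatter!(plt, tsteps, ode_data[2, :], label = "u2 data")   # second state component
plot(plt)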

Make a NeuralODE problem with a neural network defined by Lux.jl.

dudt2 = Lux.Chain(
    x -> x.^3,               # cubic feature map, mirroring the true dynamics
    Lux.Dense(2, 50, tanh),
    Lux.Dense(50, 2)
)

p, st = Lux.setup(rng, dudt2)
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
NeuralODE(
    model = Chain(
        layer_1 = WrappedFunction(#1),
        layer_2 = Dense(2 => 50, tanh_fast),  # 150 parameters
        layer_3 = Dense(50 => 2),       # 102 parameters
    ),
)         # Total: 252 parameters,
          #        plus 0 states.
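Before training, the NeuralODE can already be evaluated by calling it like a function. A minimal sketch, assuming (as the prediction function below does) that the parameters are passed as the second argument; the first element of the returned tuple is the ODE solution, the second the updated layer states:

pred_untrained, _ = prob_neuralode(u0, ComponentArray(p), st)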

Define the prediction, loss, and callback functions.

function predict_neuralode(p)
    Array(prob_neuralode(u0, p, st)[1])
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end
loss_neuralode (generic function with 1 method)

Plots are not generated by default; set doplot = true to display the figures inside the callback function.

callback = function (p, l, pred; doplot = false)
    println(l)
    # plot current prediction against data
    if doplot
      plt = scatter(tsteps, ode_data[1,:], label = "data")
      scatter!(plt, tsteps, pred[1,:], label = "prediction")
      plot(plt)
    end
    return false
end
#3 (generic function with 1 method)

Test the callback function with the initial parameters.

pinit = ComponentArray(p)
callback(pinit, loss_neuralode(pinit)...; doplot=true)
118.93152
false

Use Optimization.jl to set up and solve the problem:

  • Zygote for automatic differentiation (AD)

  • loss_neuralode as the function to be optimized

  • an OptimizationProblem combining the two

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
OptimizationProblem. In-place: true
u0: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.13173558 -0.26862946; -0.21219468 -0.29113472; … ; 0.23871331 0.32163706; 0.2673218 -0.17830189], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.039003026 0.28549364 … -0.24841379 0.102247044; 0.23711038 -0.06729417 … 0.070725866 0.0989487], bias = Float32[0.0; 0.0;;]))

Solve the OptimizationProblem.

result_neuralode = Optimization.solve(
    optprob,
    OptimizationOptimisers.ADAM(0.05),
    callback = callback,
    maxiters = 300
)
118.93152
107.588326
106.10397
⋮    (one loss value printed per iteration, 300 iterations in total)
0.2833498
0.27903017
0.27903017
retcode: Default
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[0.2608399 -0.7103775; -0.16067556 -1.1602527; … ; 0.15277196 1.2721673; 0.21492538 0.2944236], bias = Float32[0.36572495; 0.44362512; … ; -0.44807923; -0.53865325;;]), layer_3 = (weight = Float32[0.6236109 0.5328911 … -0.50796896 0.35644993; 0.051700193 -0.3421964 … 0.34809616 0.63404644], bias = Float32[-0.21957095; 0.017598046;;]))

Use another optimization algorithm, Optim.BFGS(), starting from where the ADAM() run stopped.

optprob2 = remake(optprob, u0 = result_neuralode.u)

result_neuralode2 = Optimization.solve(
    optprob2,
    Optim.BFGS(initial_stepnorm=0.01),
    callback=callback,
    allow_f_increases = false
)
0.27903017
0.27872407
0.27863026
⋮    (loss printed at each BFGS iteration)
0.0026384164
0.0026384164
0.0026384164
retcode: Success
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[0.08741521 -0.7041453; -0.19451857 -0.99229836; … ; 0.19416788 1.120761; 0.31404698 0.026229795], bias = Float32[0.3758842; 0.6552975; … ; -0.6631577; -1.0059776;;]), layer_3 = (weight = Float32[0.665346 0.35397014 … -0.31324288 0.48092267; 0.0966155 -0.45088494 … 0.44247928 1.1319721], bias = Float32[-0.1843225; 0.8225856;;]))

Plot the solution to see if it matches the provided data.

callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=true)
0.0026384164
false
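For a fuller check, both state components can be compared against the data; a sketch reusing the functions defined above (final_pred is a name introduced here for illustration):

final_pred = predict_neuralode(result_neuralode2.u)
plt = scatter(tsteps, ode_data[1, :], label = "u1 data")
scatter!(plt, tsteps, final_pred[1, :], label = "u1 prediction")
scatter!(plt, tsteps, ode_data[2, :], label = "u2 data")
scatter!(plt, tsteps, final_pred[2, :], label = "u2 prediction")
plot(plt)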

Animated solving process#

Let’s reset the problem and visualize the training process.

rng = Random.default_rng()
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[begin], tspan[end], length = datasize)
0.0f0:0.05172414f0:1.5f0

Set up the ground-truth values for validation.

true_A = Float32[-0.1 2.0; -2.0 -0.1]

function trueODEfunc!(du, u, p, t)
    # Same cubic dynamics as above: du .= true_A' * (u .^ 3)
    du .= ((u.^3)'true_A)'
end
trueODEfunc! (generic function with 1 method)

prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
2×30 Matrix{Float32}:
 2.0  1.9465    1.74178  1.23837  0.577126  …  1.40688   1.37023   1.29215
 0.0  0.798832  1.46473  1.80877  1.86465      0.451358  0.728681  0.972087

Define the neural network with Lux.jl, as before.

nodeFunc = Lux.Chain(
    x -> x.^3,
    Lux.Dense(2, 50, tanh),
    Lux.Dense(50, 2)
)

p, st = Lux.setup(rng, nodeFunc)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.2273688 -0.20968895; 0.054255627 0.20530364; … ; 0.15356942 0.24286065; -0.11287057 0.30330524], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.18407616 0.17299011 … -0.22141114 -0.07978176; 0.16965075 0.03569014 … 0.15526423 -0.3377969], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))

Parameters in the neural network:

p
(layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.2273688 -0.20968895; 0.054255627 0.20530364; … ; 0.15356942 0.24286065; -0.11287057 0.30330524], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.18407616 0.17299011 … -0.22141114 -0.07978176; 0.16965075 0.03569014 … 0.15526423 -0.3377969], bias = Float32[0.0; 0.0;;]))
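The parameter count can be confirmed programmatically with the Lux API (parameterlength is exported by Lux):

Lux.parameterlength(nodeFunc)  # 252 = (2*50 + 50) + (50*2 + 2)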

Use NeuralODE() to construct the problem.

prob_node = NeuralODE(nodeFunc, tspan, Tsit5(), saveat = tsteps)
NeuralODE(
    model = Chain(
        layer_1 = WrappedFunction(#8),
        layer_2 = Dense(2 => 50, tanh_fast),  # 150 parameters
        layer_3 = Dense(50 => 2),       # 102 parameters
    ),
)         # Total: 252 parameters,
          #        plus 0 states.

The prediction function.

function predict_neuralode(p)
    Array(prob_node(u0, p, st)[1])
end
predict_neuralode (generic function with 1 method)

The loss function.

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end
loss_neuralode (generic function with 1 method)

The callback function records one animation frame per iteration to visualize the training process.

anim = Animation()
callback = function (p, l, pred; doplot = true)
    if doplot
        plt = scatter(tsteps, ode_data[1,:], label = "data")
        scatter!(plt, tsteps, pred[1,:], label = "prediction")
        frame(anim)
    end
    return false
end
#10 (generic function with 1 method)

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(p))
OptimizationProblem. In-place: true
u0: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.2273688 -0.20968895; 0.054255627 0.20530364; … ; 0.15356942 0.24286065; -0.11287057 0.30330524], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.18407616 0.17299011 … -0.22141114 -0.07978176; 0.16965075 0.03569014 … 0.15526423 -0.3377969], bias = Float32[0.0; 0.0;;]))

Solve the problem using the ADAM optimizer

result_neuralode = Optimization.solve(
    optprob,
    OptimizationOptimisers.ADAM(0.05),
    callback = callback,
    maxiters = 300
)
retcode: Default
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[1.1471213 0.074072815; -0.33308297 0.38328815; … ; 0.10337194 1.6461912; 0.116916105 0.70201075], bias = Float32[-0.027155323; -0.05846334; … ; -0.2410494; 0.23793979;;]), layer_3 = (weight = Float32[0.39916682 -0.42204866 … -0.39229235 0.077795036; 0.6139497 -0.46670642 … 0.24933042 0.09802464], bias = Float32[-0.4860557; -0.10765326;;]))

Then refine with the LBFGS optimizer, again warm-starting from the ADAM result.

optprob2 = remake(optprob, u0 = result_neuralode.u)

result_neuralode2 = Optimization.solve(
    optprob2,
    Optim.LBFGS(),
    callback = callback,
    allow_f_increases = false
)
retcode: Success
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[1.1471213 0.07407282; -0.33308297 0.38328815; … ; 0.10337194 1.6461912; 0.116916105 0.70201075], bias = Float32[-0.027155323; -0.05846334; … ; -0.2410494; 0.23793979;;]), layer_3 = (weight = Float32[0.39916682 -0.42204866 … -0.39229235 0.07779504; 0.6139497 -0.46670642 … 0.24933042 0.09802464], bias = Float32[-0.4860557; -0.10765325;;]))

Visualize the fitting process.

mp4(anim, fps=15)
[ Info: Saved animation to /tmp/docs/ude/tmp.mp4
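The animation can also be saved as a GIF instead of an mp4; gif is likewise exported by Plots.jl:

gif(anim, fps = 15)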