First Neural ODE example

A neural ODE is an ODE whose derivative function is defined by a neural network: \(\dot{u} = NN(u)\)
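Conceptually, the right-hand side of the ODE is just a neural network applied to the state. A minimal sketch of the idea, where nn, ps, and st are hypothetical stand-ins for a Lux model, its parameters, and its state (none of them are defined in this document):

```julia
# Sketch only: `nn`, `ps`, and `st` stand in for a Lux model,
# its parameters, and its state (illustrative names, not defined here).
function nn_rhs!(du, u, ps, t)
    # A Lux model returns (output, state); only the output is needed here.
    du .= first(nn(u, ps, st))  # du/dt = NN(u)
end
```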

From: https://docs.sciml.ai/DiffEqFlux/stable/examples/neural_ode/

using Lux, DiffEqFlux, DifferentialEquations, ComponentArrays
using Optimization, OptimizationOptimJL, OptimizationOptimisers
using Random, Plots
Plots.default(fmt=:png)

rng = Random.default_rng()
TaskLocalRNG()

The ground-truth model used to generate the training data

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
trueODEfunc (generic function with 1 method)

The data used for training

u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[begin], tspan[end], length = datasize)
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
2×30 Matrix{Float32}:
 2.0  1.9465    1.74178  1.23837  0.577127  …  1.40688   1.37023   1.29214
 0.0  0.798832  1.46473  1.80877  1.86465      0.451377  0.728699  0.972102

Make a NeuralODE problem with a neural network defined by Lux.jl.

dudt2 = Lux.Chain(
    x -> x.^3,
    Lux.Dense(2, 50, tanh),
    Lux.Dense(50, 2)
)

p, st = Lux.setup(rng, dudt2)
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
NeuralODE(
    model = Chain(
        layer_1 = WrappedFunction(static(:direct_call), var"#1#2"()),
        layer_2 = Dense(2 => 50, tanh_fast),  # 150 parameters
        layer_3 = Dense(50 => 2),       # 102 parameters
    ),
)         # Total: 252 parameters,
          #        plus 0 states.

Define the prediction, loss, and callback functions.

function predict_neuralode(p)
    Array(prob_neuralode(u0, p, st)[1])
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end
loss_neuralode (generic function with 1 method)

Do not generate plots by default. Users can set doplot = true to show the figures in the callback function.

callback = function (p, l, pred; doplot = false)
    println(l)
    # plot current prediction against data
    if doplot
      plt = scatter(tsteps, ode_data[1,:], label = "data")
      scatter!(plt, tsteps, pred[1,:], label = "prediction")
      plot(plt)
    end
    return false
end
#3 (generic function with 1 method)

Run the callback function once with the initial parameters.

pinit = ComponentArray(p)
callback(pinit, loss_neuralode(pinit)...; doplot=true)
167.19664
false

Use Optimization.jl to solve the problem.

  • Zygote for automatic differentiation (AD)

  • loss_neuralode as the function to be optimized

  • Make an OptimizationProblem

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
OptimizationProblem. In-place: true
u0: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[0.0662788 -0.32771367; -0.33288264 0.2743391; … ; 0.2683026 -0.21706231; -0.091520146 0.24157134], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.1532453 -0.116649 … 0.08572113 0.08022476; -0.22498849 -0.16709 … 0.050608143 0.12155518], bias = Float32[0.0; 0.0;;]))

Solve the OptimizationProblem.

result_neuralode = Optimization.solve(
    optprob,
    OptimizationOptimisers.ADAM(0.05),
    callback = callback,
    maxiters = 300
)
167.19664
120.31914
108.403946
⋮
0.15851179
0.13290618
retcode: Default
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.35651985 0.26343077; -0.10204345 0.11413392; … ; 0.06733824 -0.09214533; -0.25298274 0.015427607], bias = Float32[0.13203104; -0.343761; … ; 0.6513015; -0.46673098;;]), layer_3 = (weight = Float32[-0.6064762 -0.065846846 … 0.13850327 0.42713645; -0.63914734 -0.5337808 … 0.54750055 -0.11272484], bias = Float32[-0.4069189; -0.19255003;;]))

Switch to another optimization algorithm, Optim.BFGS(), starting from where the ADAM() algorithm stopped.

optprob2 = remake(optprob, u0 = result_neuralode.u)

result_neuralode2 = Optimization.solve(
    optprob2,
    Optim.BFGS(initial_stepnorm=0.01),
    callback=callback,
    allow_f_increases = false
)
0.13290618
0.13218533
⋮
0.006055832
0.006055832
retcode: Success
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.14090522 0.23728213; -0.106181554 0.12472089; … ; 0.09093609 -0.20473632; -0.2093878 -0.026541697], bias = Float32[0.236187; -0.43054527; … ; 0.72830635; -0.4551798;;]), layer_3 = (weight = Float32[-0.61769265 -0.14475703 … 0.26489532 0.43320486; -0.6611805 -0.7263577 … 0.8093641 -0.2037413], bias = Float32[-0.3718002; 0.20293607;;]))

Plot the solution to see if it matches the provided data.

callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=true)
0.006055832
false
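The callback only plots the first state component. A sketch for checking both components against the data, reusing tsteps, ode_data, and predict_neuralode defined above:

```julia
# Compare both state components of the trained model with the data.
pred = predict_neuralode(result_neuralode2.u)
plt = scatter(tsteps, ode_data[1, :], label = "u1 data")
scatter!(plt, tsteps, ode_data[2, :], label = "u2 data")
plot!(plt, tsteps, pred[1, :], label = "u1 prediction")
plot!(plt, tsteps, pred[2, :], label = "u2 prediction")
```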

Animated solving process

Let’s reset the problem and visualize the training process.

rng = Random.default_rng()
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[begin], tspan[end], length = datasize)
0.0f0:0.05172414f0:1.5f0

Set up the ground-truth values for validation.

true_A = Float32[-0.1 2.0; -2.0 -0.1]

function trueODEfunc!(du, u, p, t)
    du .= ((u.^3)'true_A)'
end
trueODEfunc! (generic function with 1 method)
prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
2×30 Matrix{Float32}:
 2.0  1.9465    1.74178  1.23837  0.577126  …  1.40688   1.37023   1.29215
 0.0  0.798832  1.46473  1.80877  1.86465      0.451358  0.728681  0.972087
nodeFunc = Lux.Chain(
    x -> x.^3,
    Lux.Dense(2, 50, tanh),
    Lux.Dense(50, 2)
)

p, st = Lux.setup(rng, nodeFunc)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.0017697228 0.075023025; -0.03329629 0.13265085; … ; -0.09379381 -0.27755153; -0.17466469 0.3351102], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.31938246 -0.012167533 … -0.05699885 0.16791472; 0.1971828 -0.011796937 … 0.20698278 -0.12465175], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))

Parameters in the neural network:

p
(layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.0017697228 0.075023025; -0.03329629 0.13265085; … ; -0.09379381 -0.27755153; -0.17466469 0.3351102], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.31938246 -0.012167533 … -0.05699885 0.16791472; 0.1971828 -0.011796937 … 0.20698278 -0.12465175], bias = Float32[0.0; 0.0;;]))
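The parameter count reported by NeuralODE (252 in total) can be verified directly with Lux.parameterlength, which is part of the Lux API:

```julia
# 2*50 weights + 50 biases (layer_2) plus 50*2 weights + 2 biases (layer_3)
Lux.parameterlength(nodeFunc)  # 252
```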

Use NeuralODE() to construct the problem.

prob_node = NeuralODE(nodeFunc, tspan, Tsit5(), saveat = tsteps)
NeuralODE(
    model = Chain(
        layer_1 = WrappedFunction(static(:direct_call), var"#8#9"()),
        layer_2 = Dense(2 => 50, tanh_fast),  # 150 parameters
        layer_3 = Dense(50 => 2),       # 102 parameters
    ),
)         # Total: 252 parameters,
          #        plus 0 states.

The prediction function.

function predict_neuralode(p)
    Array(prob_node(u0, p, st)[1])
end
predict_neuralode (generic function with 1 method)

The loss function.

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end
loss_neuralode (generic function with 1 method)

The callback function records an animation frame at each iteration to visualize the training process.

anim = Animation()
callback = function (p, l, pred; doplot = true)
    if doplot
        plt = scatter(tsteps, ode_data[1,:], label = "data")
        scatter!(plt, tsteps, pred[1,:], label = "prediction")
        frame(anim)
    end
    return false
end
#10 (generic function with 1 method)
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(p))
OptimizationProblem. In-place: true
u0: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.0017697228 0.075023025; -0.03329629 0.13265085; … ; -0.09379381 -0.27755153; -0.17466469 0.3351102], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.31938246 -0.012167533 … -0.05699885 0.16791472; 0.1971828 -0.011796937 … 0.20698278 -0.12465175], bias = Float32[0.0; 0.0;;]))

Solve the problem using the ADAM optimizer

result_neuralode = Optimization.solve(
    optprob,
    OptimizationOptimisers.ADAM(0.05),
    callback = callback,
    maxiters = 300
)
retcode: Default
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[0.60198385 0.4751431; 0.91905576 -0.44151247; … ; 0.24059728 -0.6173106; -1.0068108 0.010679541], bias = Float32[-0.053559765; 0.11342683; … ; -0.1158054; -0.27371103;;]), layer_3 = (weight = Float32[0.24937321 0.79998964 … 0.4535362 -0.03008947; 0.7118856 0.7973566 … 0.26106736 -0.58793104], bias = Float32[-0.59332967; -0.11879284;;]))

And then solve the problem using the LBFGS optimizer

optprob2 = remake(optprob, u0 = result_neuralode.u)

result_neuralode2 = Optimization.solve(
    optprob2,
    Optim.LBFGS(),
    callback = callback,
    allow_f_increases = false
)
retcode: Success
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[0.55711615 0.32786527; 0.8842164 -0.41041452; … ; 0.1308622 -0.6185396; -1.0092094 -0.02800535], bias = Float32[-0.056156855; 0.09885336; … ; -0.11396651; -0.21055597;;]), layer_3 = (weight = Float32[0.2577177 0.8306918 … 0.5498645 -0.042159688; 0.6751993 0.73580015 … 0.3830905 -0.5592567], bias = Float32[-0.49717152; 0.147061;;]))

Visualize the fitting process.

mp4(anim, fps=15)
[ Info: Saved animation to /home/runner/work/jl-ude/jl-ude/docs/tmp.mp4