First Neural ODE example#
A neural ODE is an ODE where a neural network defines its derivative function. \(\dot{u} = NN(u)\)
From: https://docs.sciml.ai/DiffEqFlux/stable/examples/neural_ode/
# Packages: Lux for the neural network, DiffEqFlux/DifferentialEquations for
# the neural-ODE machinery, Optimization* for training, Plots for figures.
using Lux, DiffEqFlux, DifferentialEquations, ComponentArrays
using Optimization, OptimizationOptimJL, OptimizationOptimisers
using Random, Plots
# Render plots as PNG (lighter output in notebook exports).
Plots.default(fmt=:png)
# Task-local RNG used below to initialize the network parameters.
rng = Random.default_rng()
TaskLocalRNG()
True solution
# Ground-truth dynamics: du = A * (u.^3) for a fixed 2x2 matrix A, written
# in-place via row-vector products.
# The matrix is a Float32 literal so that solving with the Float32 state `u0`
# does not promote every intermediate to Float64 (and to stay consistent with
# the Float32 `true_A` used later in this file).
function trueODEfunc(du, u, p, t)
    true_A = Float32[-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
trueODEfunc (generic function with 1 method)
The data used for training
# Initial condition (Float32, matching the network's parameter eltype).
u0 = Float32[2.0; 0.0]
# Number of sampled time points.
datasize = 30
# Time span of the simulation.
tspan = (0.0f0, 1.5f0)
# Evenly spaced save points across the time span.
tsteps = range(tspan[begin], tspan[end], length = datasize)
# Solve the ground-truth ODE and collect the trajectory as a 2×30 matrix.
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
2×30 Matrix{Float32}:
2.0 1.9465 1.74178 1.23837 0.577127 … 1.40688 1.37023 1.29214
0.0 0.798832 1.46473 1.80877 1.86465 0.451377 0.728699 0.972102
Make a NeuralODE problem with a neural network defined by Lux.jl.
# Neural network defining du/dt: cube the state elementwise, then a 2→50→2 MLP.
dudt2 = Lux.Chain(
x -> x.^3,
Lux.Dense(2, 50, tanh),
Lux.Dense(50, 2)
)
# Initialize parameters `p` and states `st` for the Lux model.
p, st = Lux.setup(rng, dudt2)
# Wrap the network in a NeuralODE solved with Tsit5, saving at `tsteps`.
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
NeuralODE(
model = Chain(
layer_1 = WrappedFunction(static(:direct_call), var"#1#2"()),
layer_2 = Dense(2 => 50, tanh_fast), # 150 parameters
layer_3 = Dense(50 => 2), # 102 parameters
),
) # Total: 252 parameters,
# plus 0 states.
Define output, loss, and callback functions.
# Run the neural ODE forward from `u0` with parameters `p` and materialize
# the solution as a plain Array (states × time points).
predict_neuralode(p) = Array(first(prob_neuralode(u0, p, st)))
# Sum-of-squares error between the data and the current prediction.
# Returns (loss, prediction) so the callback can also plot the prediction.
function loss_neuralode(p)
    prediction = predict_neuralode(p)
    sse = sum(abs2, ode_data .- prediction)
    return sse, prediction
end
loss_neuralode (generic function with 1 method)
Do not generate plots by default. Users can change doplot=true to see the figures in the callback function.
# Training callback: print the current loss and, when `doplot = true`,
# overlay the current prediction on the data. Returning `false` tells the
# optimizer to keep iterating.
callback = function (p, l, pred; doplot = false)
    println(l)
    doplot || return false
    plt = scatter(tsteps, ode_data[1,:], label = "data")
    scatter!(plt, tsteps, pred[1,:], label = "prediction")
    plot(plt)
    return false
end
#3 (generic function with 1 method)
Try the callback function on the first iteration.
# Flatten the Lux parameter NamedTuple into a ComponentArray so the optimizer
# can treat the parameters as a single flat vector.
pinit = ComponentArray(p)
# Smoke-test the callback on the initial loss/prediction.
callback(pinit, loss_neuralode(pinit)...; doplot=true)
167.19664
false
Use Optimization.jl to solve the problem, with:
- Zygote for automatic differentiation (AD)
- loss_neuralode as the function to be optimized

Make an OptimizationProblem.
# Zygote reverse-mode AD for the gradients.
adtype = Optimization.AutoZygote()
# Optimization.jl passes (x, p); only the optimization variables x are used.
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
OptimizationProblem. In-place: true
u0: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[0.0662788 -0.32771367; -0.33288264 0.2743391; … ; 0.2683026 -0.21706231; -0.091520146 0.24157134], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.1532453 -0.116649 … 0.08572113 0.08022476; -0.22498849 -0.16709 … 0.050608143 0.12155518], bias = Float32[0.0; 0.0;;]))
Solve the OptimizationProblem.
# First optimization pass: ADAM with step size 0.05 for up to 300 iterations.
result_neuralode = Optimization.solve(
optprob,
OptimizationOptimisers.ADAM(0.05),
callback = callback,
maxiters = 300
)
167.19664
120.31914
108.403946
107.46847
105.21311
100.592
96.75333
95.06156
94.56816
94.15621
93.28329
91.80549
89.79514
87.41507
84.811775
82.02457
78.87979
74.708885
68.96615
63.56195
57.726936
50.350815
43.649918
38.29633
35.52835
37.3186
36.656704
29.027027
32.500465
34.886585
25.533499
29.097551
28.418858
20.533684
25.243275
19.16565
16.592937
17.684364
16.928545
14.648718
12.550248
13.505121
11.347705
8.949317
9.846885
8.97159
6.836227
6.975878
7.744865
6.0111246
5.281647
5.767023
5.235621
4.2431607
4.3920913
4.5372124
3.690721
3.5166442
3.7018
3.353897
3.2213786
3.461114
3.148039
3.0739288
3.2057383
2.9637
2.7577891
2.7795935
2.6118565
2.4678175
2.5276616
2.3566623
2.2758777
2.3271666
2.1483185
2.099087
2.1436574
1.9536127
1.9597641
1.9065138
1.7870454
1.8259807
1.7385446
1.7003086
1.6982859
1.6125047
1.634399
1.5904665
1.5435815
1.5764419
1.5038526
1.4906355
1.4613497
1.426169
1.4192605
1.3734976
1.3661429
1.3415122
1.3147414
1.3047757
1.2754083
1.2674935
1.2397871
1.2310768
1.2003013
1.1902609
1.1798377
1.1652381
1.1435666
1.1210132
1.1055727
1.0891864
1.0774424
1.0650972
1.0511982
1.0359436
1.0212802
1.0087494
0.9953654
0.98375773
0.9719267
0.9589668
0.9463374
0.9306667
0.91587317
0.9038944
0.89262575
0.8787781
0.8661327
0.8604386
0.8492976
0.84059197
0.8223171
0.82367826
0.8042032
0.8005909
0.7836791
0.7976124
0.7701272
0.77212745
0.758391
0.7438481
0.74305236
0.7207943
0.72236717
0.7059372
0.6984963
0.6954198
0.6778692
0.6752462
0.6639262
0.6553555
0.6494032
0.6386608
0.63134205
0.6241068
0.6150467
0.6110784
0.60470045
0.59407806
0.5873354
0.58192706
0.5736661
0.5669396
0.56126475
0.5539975
0.5463702
0.5407676
0.5337791
0.52746207
0.52124625
0.514028
0.50932306
0.5027462
0.4967661
0.48984045
0.48452985
0.47828454
0.47207585
0.46709275
0.46134683
0.45644197
0.44976348
0.44624883
0.4402119
0.4340558
0.42862633
0.4235788
0.41733482
0.4143118
0.40865052
0.40317938
0.3979276
0.39296716
0.38751215
0.3839442
0.39341205
0.41484922
0.46488193
0.52884984
0.6560018
0.65819484
0.6265168
0.4334909
0.348474
0.38716778
0.44867402
0.45235953
0.36642152
0.33085352
0.37611735
0.4168111
0.40535602
0.33783332
0.3090127
0.3307818
0.36547258
0.3866476
0.35495284
0.31450218
0.28897926
0.2874467
0.29782757
0.30801246
0.3158679
0.30882594
0.29747495
0.27989316
0.26561207
0.25657293
0.25809216
0.26534176
0.27745903
0.3020192
0.33713692
0.41761467
0.4944152
0.662216
0.63720495
0.60500246
0.33288863
0.23581052
0.3453623
0.4090682
0.3532079
0.23394367
0.26301467
0.353802
0.31930432
0.24198447
0.21395075
0.24781902
0.26383287
0.22576715
0.2028881
0.21761017
0.24375473
0.24763174
0.22536644
0.2155435
0.20329408
0.19133282
0.18749475
0.19132333
0.19779207
0.19545625
0.18790458
0.1780947
0.17350098
0.17474052
0.17898133
0.1779093
0.16959493
0.16213535
0.1660638
0.17686051
0.19306593
0.20414105
0.23039857
0.26192695
0.3471881
0.42528796
0.6083786
0.56232923
0.5059952
0.23989151
0.14987257
0.24800329
0.2980779
0.22881198
0.14531864
0.19605987
0.26088452
0.18836077
0.13768837
0.18005958
0.20056307
0.16222958
0.13290618
0.15851179
0.13290618
retcode: Default
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.35651985 0.26343077; -0.10204345 0.11413392; … ; 0.06733824 -0.09214533; -0.25298274 0.015427607], bias = Float32[0.13203104; -0.343761; … ; 0.6513015; -0.46673098;;]), layer_3 = (weight = Float32[-0.6064762 -0.065846846 … 0.13850327 0.42713645; -0.63914734 -0.5337808 … 0.54750055 -0.11272484], bias = Float32[-0.4069189; -0.19255003;;]))
Use another optimization algorithm Optim.BFGS()
and start from where the ADAM()
algorithm stopped.
# Restart from the ADAM solution and refine with BFGS.
optprob2 = remake(optprob, u0 = result_neuralode.u)
result_neuralode2 = Optimization.solve(
optprob2,
Optim.BFGS(initial_stepnorm=0.01),
callback=callback,
allow_f_increases = false
)
0.13290618
0.13218533
0.13190852
0.12189339
0.11270843
0.0872243
0.06982846
0.05787575
0.05056137
0.042543683
0.037341867
0.033438224
0.02993875
0.02592087
0.025537813
0.020921824
0.018936818
0.0175642
0.017458977
0.015713342
0.014957287
0.013680975
0.012938932
0.012617341
0.012014252
0.010198819
0.009953882
0.008814211
0.008545592
0.0073171053
0.0069292653
0.006661932
0.00621038
0.006055832
0.006055832
retcode: Success
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.14090522 0.23728213; -0.106181554 0.12472089; … ; 0.09093609 -0.20473632; -0.2093878 -0.026541697], bias = Float32[0.236187; -0.43054527; … ; 0.72830635; -0.4551798;;]), layer_3 = (weight = Float32[-0.61769265 -0.14475703 … 0.26489532 0.43320486; -0.6611805 -0.7263577 … 0.8093641 -0.2037413], bias = Float32[-0.3718002; 0.20293607;;]))
Plot the solution to see if it matches the provided data.
# Plot the final fitted trajectory against the data.
callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=true)
0.006055832
false
Animated solving process#
Let’s reset the problem and visualize the training process.
# Reset the RNG and problem data for the animated run.
rng = Random.default_rng()
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[begin], tspan[end], length = datasize)
0.0f0:0.05172414f0:1.5f0
Setup truth values for validation
# Ground-truth dynamics matrix. Declared `const` so the derivative function's
# reference to this global is type-stable (a non-const global is `Any`-typed
# and forces dynamic dispatch on every RHS evaluation).
const true_A = Float32[-0.1 2.0; -2.0 -0.1]

# In-place derivative du = true_A * (u.^3), written via row-vector products.
function trueODEfunc!(du, u, p, t)
    du .= ((u.^3)'true_A)'
end
trueODEfunc! (generic function with 1 method)
# Solve the ground-truth ODE and store the trajectory used for training.
prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
2×30 Matrix{Float32}:
2.0 1.9465 1.74178 1.23837 0.577126 … 1.40688 1.37023 1.29215
0.0 0.798832 1.46473 1.80877 1.86465 0.451358 0.728681 0.972087
# Same architecture as before: cube the state, then a 2→50→2 MLP.
nodeFunc = Lux.Chain(
x -> x.^3,
Lux.Dense(2, 50, tanh),
Lux.Dense(50, 2)
)
# Initialize parameters `p` and states `st` for the Lux model.
p, st = Lux.setup(rng, nodeFunc)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.0017697228 0.075023025; -0.03329629 0.13265085; … ; -0.09379381 -0.27755153; -0.17466469 0.3351102], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.31938246 -0.012167533 … -0.05699885 0.16791472; 0.1971828 -0.011796937 … 0.20698278 -0.12465175], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
Parameters in the neural network:
p
(layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.0017697228 0.075023025; -0.03329629 0.13265085; … ; -0.09379381 -0.27755153; -0.17466469 0.3351102], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.31938246 -0.012167533 … -0.05699885 0.16791472; 0.1971828 -0.011796937 … 0.20698278 -0.12465175], bias = Float32[0.0; 0.0;;]))
Use NeuralODE() to construct the problem.
# Neural ODE over the same time span, saving at the same time points.
prob_node = NeuralODE(nodeFunc, tspan, Tsit5(), saveat = tsteps)
NeuralODE(
model = Chain(
layer_1 = WrappedFunction(static(:direct_call), var"#8#9"()),
layer_2 = Dense(2 => 50, tanh_fast), # 150 parameters
layer_3 = Dense(50 => 2), # 102 parameters
),
) # Total: 252 parameters,
# plus 0 states.
Predicted values.
# Forward-simulate the neural ODE with parameters `p` and materialize the
# trajectory as a plain Array.
predict_neuralode(p) = Array(first(prob_node(u0, p, st)))
predict_neuralode (generic function with 1 method)
The loss function.
# L2 loss against the reference data; the prediction is also returned so the
# callback can plot it.
function loss_neuralode(p)
    pred = predict_neuralode(p)
    return sum(abs2, ode_data .- pred), pred
end
loss_neuralode (generic function with 1 method)
Callback function to observe training process
# Accumulates one animation frame per training iteration.
anim = Animation()
# Training callback: plot data vs. prediction and push the current figure
# into `anim`. Always returns false so optimization continues.
callback = function (p, l, pred; doplot = true)
    doplot || return false
    fig = scatter(tsteps, ode_data[1,:], label = "data")
    scatter!(fig, tsteps, pred[1,:], label = "prediction")
    frame(anim)
    return false
end
#10 (generic function with 1 method)
# Zygote AD; wrap the loss for Optimization.jl and build the problem over
# the flattened parameter vector.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(p))
OptimizationProblem. In-place: true
u0: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.0017697228 0.075023025; -0.03329629 0.13265085; … ; -0.09379381 -0.27755153; -0.17466469 0.3351102], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.31938246 -0.012167533 … -0.05699885 0.16791472; 0.1971828 -0.011796937 … 0.20698278 -0.12465175], bias = Float32[0.0; 0.0;;]))
Solve the problem using the ADAM optimizer
# ADAM pass; each iteration the callback appends a frame to `anim`.
result_neuralode = Optimization.solve(
optprob,
OptimizationOptimisers.ADAM(0.05),
callback = callback,
maxiters = 300
)
retcode: Default
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[0.60198385 0.4751431; 0.91905576 -0.44151247; … ; 0.24059728 -0.6173106; -1.0068108 0.010679541], bias = Float32[-0.053559765; 0.11342683; … ; -0.1158054; -0.27371103;;]), layer_3 = (weight = Float32[0.24937321 0.79998964 … 0.4535362 -0.03008947; 0.7118856 0.7973566 … 0.26106736 -0.58793104], bias = Float32[-0.59332967; -0.11879284;;]))
And then solve the problem using the LBFGS optimizer
# Refine with LBFGS starting from the ADAM solution.
optprob2 = remake(optprob, u0 = result_neuralode.u)
result_neuralode2 = Optimization.solve(
optprob2,
Optim.LBFGS(),
callback = callback,
allow_f_increases = false
)
retcode: Success
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[0.55711615 0.32786527; 0.8842164 -0.41041452; … ; 0.1308622 -0.6185396; -1.0092094 -0.02800535], bias = Float32[-0.056156855; 0.09885336; … ; -0.11396651; -0.21055597;;]), layer_3 = (weight = Float32[0.2577177 0.8306918 … 0.5498645 -0.042159688; 0.6751993 0.73580015 … 0.3830905 -0.5592567], bias = Float32[-0.49717152; 0.147061;;]))
Visualize fitting process
# Encode the recorded frames into an MP4 at 15 frames per second.
mp4(anim, fps=15)
[ Info: Saved animation to /home/runner/work/jl-ude/jl-ude/docs/tmp.mp4