First Neural ODE example
A neural ODE is an ODE where a neural network defines its derivative function. \(\dot{u} = NN(u)\)
From: https://docs.sciml.ai/DiffEqFlux/stable/examples/neural_ode/
# Lux: neural-network layers; DiffEqFlux: NeuralODE; DifferentialEquations: ODE
# solvers; ComponentArrays: flat, optimizer-friendly parameter vectors.
using Lux, DiffEqFlux, DifferentialEquations, ComponentArrays
# Optimization.jl front end plus Optim (BFGS/LBFGS) and Optimisers (ADAM) back ends.
using Optimization, OptimizationOptimJL, OptimizationOptimisers
using Random, Plots
# Task-local default RNG, used below to initialize the network weights.
rng = Random.default_rng()
[ Info: Precompiling IJuliaExt [2f4121a4-3b3a-5ce6-9c5e-1f2673ce168a]
TaskLocalRNG()
True solution
# In-place derivative of the ground-truth ODE, in the standard
# DifferentialEquations.jl f(du, u, p, t) form: du = A' * u.^3.
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    # ((u.^3)' * A)' is the same product written as A' * u.^3.
    du .= true_A' * (u .^ 3)
end
trueODEfunc (generic function with 1 method)
The data used for training
# Initial condition and the 30-point time grid on which training data is sampled.
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[begin], tspan[end], length = datasize)
# Solve the true ODE with Tsit5 and collect the trajectory at `tsteps`
# as a 2×30 matrix — this is the training target.
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
2×30 Matrix{Float32}:
2.0 1.9465 1.74178 1.23837 0.577127 … 1.40688 1.37023 1.29214
0.0 0.798832 1.46473 1.80877 1.86465 0.451377 0.728699 0.972102
Make a NeuralODE problem with a neural network defined by Lux.jl.
# Network for the derivative: cube the state elementwise, then a 2→50→2 tanh MLP.
dudt2 = Lux.Chain(
x -> x.^3,
Lux.Dense(2, 50, tanh),
Lux.Dense(50, 2)
)
# Initialize trainable parameters `p` and (empty) layer states `st`.
p, st = Lux.setup(rng, dudt2)
# Wrap the network as a neural ODE solved with Tsit5 on the same grid as the data.
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
NeuralODE(
model = Chain(
layer_1 = WrappedFunction(#1),
layer_2 = Dense(2 => 50, tanh_fast), # 150 parameters
layer_3 = Dense(50 => 2), # 102 parameters
),
) # Total: 252 parameters,
# plus 0 states.
Define output, loss, and callback functions.
# Integrate the neural ODE from `u0` with parameters `p` and return the
# trajectory as a plain Array (state × time point).
function predict_neuralode(p)
    return Array(first(prob_neuralode(u0, p, st)))
end
# Sum-of-squared-errors loss against the training data; also return the
# prediction so the callback can plot it without recomputing.
function loss_neuralode(p)
    trajectory = predict_neuralode(p)
    sse = sum(abs2, ode_data .- trajectory)
    return sse, trajectory
end
loss_neuralode (generic function with 1 method)
Do not generate plots by default. Users can pass doplot=true to see the figures in the callback function.
# Optimizer callback: print the current loss each iteration and optionally
# plot the first state component of data vs. prediction.
# Returning `false` tells Optimization.solve to continue iterating.
callback = function (p, l, pred; doplot = false)
    println(l)
    # plot current prediction against data
    if doplot
        plt = scatter(tsteps, ode_data[1, :], label = "data")
        scatter!(plt, tsteps, pred[1, :], label = "prediction")
        # `display` is required for the figure to actually appear when called
        # from inside a function; a bare `plot(plt)` only returns the object.
        display(plot(plt))
    end
    return false
end
#3 (generic function with 1 method)
Try the callback function on the first iteration.
# Flatten the Lux parameter NamedTuple into a ComponentArray so generic
# optimizers can treat it as a single vector.
pinit = ComponentArray(p)
# Sanity check: evaluate the loss and callback once before training.
callback(pinit, loss_neuralode(pinit)...; doplot=true)
118.93152
false
Use Optimization.jl to solve the problem, with Zygote for automatic differentiation (AD) and loss_neuralode as the function to be optimized. Make an OptimizationProblem.
# Reverse-mode AD with Zygote, differentiating through the ODE solve.
adtype = Optimization.AutoZygote()
# Optimization.jl expects f(x, p); our loss ignores the second argument.
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
OptimizationProblem. In-place: true
u0: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.13173558 -0.26862946; -0.21219468 -0.29113472; … ; 0.23871331 0.32163706; 0.2673218 -0.17830189], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.039003026 0.28549364 … -0.24841379 0.102247044; 0.23711038 -0.06729417 … 0.070725866 0.0989487], bias = Float32[0.0; 0.0;;]))
Solve the OptimizationProblem.
# First optimization pass: 300 iterations of ADAM with step size 0.05.
result_neuralode = Optimization.solve(
optprob,
OptimizationOptimisers.ADAM(0.05),
callback = callback,
maxiters = 300
)
118.93152
107.588326
106.10397
104.19737
97.24725
92.15314
90.36655
90.212326
89.86787
87.88174
84.30611
79.794754
74.50199
67.97798
60.834373
49.852478
43.669773
43.41168
49.646446
52.95736
48.760284
40.29796
32.5933
36.968784
43.266697
37.661457
28.308996
26.99663
27.117113
18.224869
96.5595
19.507513
25.880632
28.43808
31.63586
34.80781
37.45008
39.565304
41.24104
42.545956
43.48191
43.99707
44.060825
43.736874
43.136986
42.329556
41.334167
40.15575
38.84882
37.48766
36.146152
34.874454
33.717518
32.696476
31.7733
30.857851
29.862919
28.808237
27.685902
26.28656
24.606424
22.672087
20.167015
17.113039
13.772392
10.939747
9.860489
13.007938
14.8824415
10.949396
9.325867
8.410077
8.960457
9.799699
8.659706
9.288935
8.292451
8.015104
6.900593
6.666691
5.8868365
5.243168
5.1806083
5.4701586
5.4545035
4.8181467
4.3685756
3.9891233
3.7556217
3.8554635
3.8198745
3.5525782
3.191701
2.9124403
2.6297073
2.5695117
2.544964
2.3150454
2.1894524
2.0074775
1.8385435
1.8549367
1.8052906
1.7279646
1.7551826
1.6504538
1.5874145
1.6074078
1.5670396
1.5367639
1.5248336
1.4728632
1.4370981
1.3992316
1.3706714
1.3412848
1.305206
1.2548811
1.2194672
1.1932594
1.1558425
1.1368406
1.1092584
1.0908762
1.0662757
1.0483215
1.0350976
1.01947
1.0049571
0.9824165
0.97018325
0.9649516
0.95416635
0.941921
0.925175
0.9117411
0.9005054
0.8941272
0.87876767
0.871321
0.8557811
0.845482
0.83529013
0.8299453
0.81286347
0.8054415
0.8019792
0.78728247
0.78193
0.77108365
0.76428175
0.75724345
0.75075823
0.7411033
0.7392033
0.7296305
0.7196565
0.71166456
0.7058496
0.6999039
0.6956167
0.6877581
0.68062913
0.67444295
0.6687909
0.666411
0.6535747
0.6511683
0.6435649
0.63974446
0.63729036
0.63559884
0.6321424
0.6239543
0.61493677
0.6183981
0.6096776
0.6073146
0.5993138
0.59358007
0.5949444
0.5872536
0.58704114
0.5821489
0.57660115
0.57140446
0.56809527
0.5658814
0.5591117
0.55308956
0.5483318
0.5481536
0.5453068
0.54355454
0.5411277
0.5394622
0.5287483
0.53205395
0.5267897
0.52042353
0.51914907
0.5121167
0.5132708
0.5056295
0.5021682
0.5043211
0.50082844
0.49712506
0.49049583
0.488807
0.4836612
0.48589966
0.47914726
0.47941613
0.4763267
0.4696518
0.47023976
0.46876737
0.46203262
0.45890146
0.45814407
0.45125443
0.45364237
0.45154297
0.44755158
0.4429472
0.44174865
0.43937814
0.43662018
0.43433928
0.42819756
0.4290631
0.4241396
0.42134872
0.41700017
0.41936338
0.41664764
0.41054168
0.4094148
0.40712437
0.40857816
0.40158877
0.4031048
0.40195826
0.3939447
0.39499843
0.3902107
0.39367834
0.38610053
0.38710114
0.390421
0.37968183
0.38185933
0.381964
0.36958292
0.37819535
0.3724618
0.36656803
0.3674983
0.36255923
0.35879308
0.35838723
0.35515106
0.35121173
0.35158592
0.34738743
0.3486128
0.34715125
0.340986
0.3471296
0.34075132
0.33429274
0.34008977
0.335349
0.3268921
0.33041143
0.3265921
0.3224478
0.3244342
0.32013506
0.31521055
0.3153871
0.31209573
0.31137058
0.31365484
0.30831978
0.30435055
0.30782023
0.30088735
0.30174026
0.3167247
0.30817467
0.29165807
0.2964253
0.29488084
0.28661618
0.286528
0.28760532
0.2833498
0.27903017
0.27903017
retcode: Default
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[0.2608399 -0.7103775; -0.16067556 -1.1602527; … ; 0.15277196 1.2721673; 0.21492538 0.2944236], bias = Float32[0.36572495; 0.44362512; … ; -0.44807923; -0.53865325;;]), layer_3 = (weight = Float32[0.6236109 0.5328911 … -0.50796896 0.35644993; 0.051700193 -0.3421964 … 0.34809616 0.63404644], bias = Float32[-0.21957095; 0.017598046;;]))
Use another optimization algorithm Optim.BFGS()
and start from where the ADAM()
algorithm stopped.
# Warm-start BFGS from the ADAM result for local fine-tuning; stop when the
# objective would increase.
optprob2 = remake(optprob, u0 = result_neuralode.u)
result_neuralode2 = Optimization.solve(
optprob2,
Optim.BFGS(initial_stepnorm=0.01),
callback=callback,
allow_f_increases = false
)
0.27903017
0.27872407
0.27863026
0.25121066
0.2420656
0.21821216
0.19049883
0.15103383
0.12721382
0.07557778
0.06568295
0.061216295
0.060143154
0.050319873
0.041621387
0.038229752
0.033189513
0.031393625
0.028698342
0.026024546
0.023069603
0.022244453
0.019386347
0.015208791
0.013807998
0.0113597335
0.010351848
0.010138731
0.008896519
0.007957236
0.007272527
0.0069238544
0.0061710677
0.0061710677
0.0058267526
0.0051714745
0.0044511706
0.004431646
0.003977054
0.003935433
0.0035506398
0.0033168232
0.00292864
0.002692305
0.002690089
0.0026872384
0.0026772963
0.0026385426
0.0026384164
0.0026384164
0.0026384164
retcode: Success
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[0.08741521 -0.7041453; -0.19451857 -0.99229836; … ; 0.19416788 1.120761; 0.31404698 0.026229795], bias = Float32[0.3758842; 0.6552975; … ; -0.6631577; -1.0059776;;]), layer_3 = (weight = Float32[0.665346 0.35397014 … -0.31324288 0.48092267; 0.0966155 -0.45088494 … 0.44247928 1.1319721], bias = Float32[-0.1843225; 0.8225856;;]))
Plot the solution to see if it matches the provided data.
# Evaluate the final loss and plot the fitted trajectory against the data.
callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=true)
0.0026384164
false
Animated solving process
Let’s reset the problem and visualize the training process.
# Reset the RNG, initial condition, and time grid for the animated run.
rng = Random.default_rng()
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[begin], tspan[end], length = datasize)
0.0f0:0.05172414f0:1.5f0
Setup truth values for validation
# Ground-truth system matrix. Declared `const` so the global is type-stable
# when read inside the derivative function (non-const globals are `Any`-typed
# and slow down every call).
const true_A = Float32[-0.1 2.0; -2.0 -0.1]
# In-place true derivative in f!(du, u, p, t) form: du = true_A' * u.^3.
function trueODEfunc!(du, u, p, t)
    du .= ((u.^3)'true_A)'
end
trueODEfunc! (generic function with 1 method)
# Regenerate the training target from the true ODE on the sampling grid.
prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
2×30 Matrix{Float32}:
2.0 1.9465 1.74178 1.23837 0.577126 … 1.40688 1.37023 1.29215
0.0 0.798832 1.46473 1.80877 1.86465 0.451358 0.728681 0.972087
# Same architecture as before: elementwise cube followed by a 2→50→2 tanh MLP.
nodeFunc = Lux.Chain(
x -> x.^3,
Lux.Dense(2, 50, tanh),
Lux.Dense(50, 2)
)
# Fresh parameters `p` and states `st` for this run.
p, st = Lux.setup(rng, nodeFunc)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.2273688 -0.20968895; 0.054255627 0.20530364; … ; 0.15356942 0.24286065; -0.11287057 0.30330524], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.18407616 0.17299011 … -0.22141114 -0.07978176; 0.16965075 0.03569014 … 0.15526423 -0.3377969], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
Parameters in the neural network:
# Display the freshly initialized network parameters.
p
(layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.2273688 -0.20968895; 0.054255627 0.20530364; … ; 0.15356942 0.24286065; -0.11287057 0.30330524], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.18407616 0.17299011 … -0.22141114 -0.07978176; 0.16965075 0.03569014 … 0.15526423 -0.3377969], bias = Float32[0.0; 0.0;;]))
Use NeuralODE() to construct the problem.
# Wrap the network in a NeuralODE solved with Tsit5, saving at `tsteps`.
prob_node = NeuralODE(nodeFunc, tspan, Tsit5(), saveat = tsteps)
NeuralODE(
model = Chain(
layer_1 = WrappedFunction(#8),
layer_2 = Dense(2 => 50, tanh_fast), # 150 parameters
layer_3 = Dense(50 => 2), # 102 parameters
),
) # Total: 252 parameters,
# plus 0 states.
Predicted values.
# Integrate the neural ODE from `u0` and return the trajectory as an Array.
function predict_neuralode(p)
    return Array(first(prob_node(u0, p, st)))
end
predict_neuralode (generic function with 1 method)
The loss function.
# Sum of squared errors against the training data; return (loss, prediction)
# so the callback can reuse the trajectory.
function loss_neuralode(p)
    fitted = predict_neuralode(p)
    return sum(abs2, ode_data .- fitted), fitted
end
loss_neuralode (generic function with 1 method)
Callback function to observe training process
# Animation buffer; the callback appends one frame per optimizer iteration.
anim = Animation()
# Record a frame comparing the data with the current prediction (first state
# component only). Returning `false` keeps the optimization running.
callback = function (p, l, pred; doplot = true)
    if doplot
        fig = scatter(tsteps, ode_data[1, :], label = "data")
        scatter!(fig, tsteps, pred[1, :], label = "prediction")
        frame(anim)
    end
    return false
end
#10 (generic function with 1 method)
# Zygote AD and the optimization problem over the flattened parameters.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(p))
OptimizationProblem. In-place: true
u0: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.2273688 -0.20968895; 0.054255627 0.20530364; … ; 0.15356942 0.24286065; -0.11287057 0.30330524], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.18407616 0.17299011 … -0.22141114 -0.07978176; 0.16965075 0.03569014 … 0.15526423 -0.3377969], bias = Float32[0.0; 0.0;;]))
Solve the problem using the ADAM optimizer
# First pass: 300 ADAM iterations; each callback invocation records a frame.
result_neuralode = Optimization.solve(
optprob,
OptimizationOptimisers.ADAM(0.05),
callback = callback,
maxiters = 300
)
retcode: Default
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[1.1471213 0.074072815; -0.33308297 0.38328815; … ; 0.10337194 1.6461912; 0.116916105 0.70201075], bias = Float32[-0.027155323; -0.05846334; … ; -0.2410494; 0.23793979;;]), layer_3 = (weight = Float32[0.39916682 -0.42204866 … -0.39229235 0.077795036; 0.6139497 -0.46670642 … 0.24933042 0.09802464], bias = Float32[-0.4860557; -0.10765326;;]))
And then solve the problem using the LBFGS optimizer
# Fine-tune with LBFGS starting from the ADAM solution.
optprob2 = remake(optprob, u0 = result_neuralode.u)
result_neuralode2 = Optimization.solve(
optprob2,
Optim.LBFGS(),
callback = callback,
allow_f_increases = false
)
retcode: Success
u: ComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[1.1471213 0.07407282; -0.33308297 0.38328815; … ; 0.10337194 1.6461912; 0.116916105 0.70201075], bias = Float32[-0.027155323; -0.05846334; … ; -0.2410494; 0.23793979;;]), layer_3 = (weight = Float32[0.39916682 -0.42204866 … -0.39229235 0.07779504; 0.6139497 -0.46670642 … 0.24933042 0.09802464], bias = Float32[-0.4860557; -0.10765325;;]))
Visualize fitting process
# Render the recorded frames to an MP4 at 15 frames per second.
mp4(anim, fps=15)
[ Info: Saved animation to /tmp/docs/ude/tmp.mp4