# Multiple Shooting
Docs: https://docs.sciml.ai/DiffEqFlux/dev/examples/multiple_shooting/
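In multiple shooting, the training data is split into overlapping intervals, and the neural ODE is fitted on each interval individually. A continuity penalty keeps the predictions of neighboring intervals consistent where they overlap. The optimization is performed with `OptimizationPolyalgorithms.PolyOpt()`.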
using Lux
using ComponentArrays
using DiffEqFlux
using Optimization
using OptimizationPolyalgorithms
using OrdinaryDiffEq
using DiffEqFlux: group_ranges
using Plots
using Random
# Use the task-local default RNG, but seed it so that the randomly initialized
# network weights (and therefore the whole training run) are reproducible.
rng = Random.default_rng()
Random.seed!(rng, 0)
TaskLocalRNG()
Define the initial condition, the time span, and the time points at which data is saved.
# Training grid: `datasize` equally spaced save points covering `tspan`,
# starting from the initial state `u0`. Everything is Float32.
datasize = 30
u0 = [2.0f0, 0.0f0]
tspan = (0.0f0, 5.0f0)
tsteps = range(first(tspan), last(tspan); length = datasize)
0.0f0:0.1724138f0:5.0f0
Define the ground-truth dynamics matrix.
# Ground-truth matrix applied to the cubed state in `trueODEfunc!`.
# Declared `const` because the ODE right-hand side reads it as a global on
# every evaluation; a non-const global would make that read type-unstable.
const true_A = Float32[-0.1 2.0; -2.0 -0.1]
2×2 Matrix{Float32}:
-0.1 2.0
-2.0 -0.1
Generate the training data by solving the ground-truth ODE.
"""
    trueODEfunc!(du, u, p, t)

In-place right-hand side of the ground-truth ODE, `du = true_A' * u.^3`.
Overwrites `du` and returns it; `p` and `t` are ignored.
"""
function trueODEfunc!(du, u, p, t)
    cubed = u .^ 3
    du .= true_A' * cubed
    return du
end
# Ground-truth problem, solved with Tsit5 and sampled at `tsteps`.
# The solution is collected into a states × time-points matrix.
prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = solve(prob_trueode, Tsit5(); saveat = tsteps) |> Array
2×30 Matrix{Float32}:
2.0 1.02407 -1.07771 -1.70875 … 0.317698 0.018593 -0.266894
0.0 1.84867 1.72464 0.323419 0.958795 0.946583 0.930795
Define the neural network.
# Neural right-hand side: cube the state (mirroring the `u.^3` in the true
# dynamics), then apply a 2 → 16 → 2 MLP with a tanh hidden layer.
nn = Lux.Chain(x -> x .^ 3, Lux.Dense(2, 16, tanh), Lux.Dense(16, 2))
# Initial parameters and layer states, drawn from `rng`.
p_init, st = Lux.setup(rng, nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.44528675 0.43915123; 0.21514416 0.16796206; … ; 0.010940022 -0.22445899; -0.13767284 -0.17843497], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.4005401 0.4104821 … -0.4990906 0.34355137; -0.39049003 -0.5041938 … -0.06933709 0.3783746], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
Define the neural ODE problem.
# `NeuralODE` wrapper around the same network. The multiple-shooting loss
# below does not use it; that loss works with the plain `prob_node` instead.
neuralode = NeuralODE(nn, tspan, Tsit5(), saveat = tsteps)

# Out-of-place ODE problem whose right-hand side evaluates the Lux network
# with parameters `p` and discards the returned state. `nn` and `st` are
# bound via `let` so the closure captures concretely typed values instead of
# looking up non-const globals on every right-hand-side call.
prob_node = let nn = nn, st = st
    ODEProblem((u, p, t) -> first(nn(u, p, st)), u0, tspan, ComponentArray(p_init))
end
ODEProblem with uType Vector{Float32} and tType Float32. In-place: false
timespan: (0.0f0, 5.0f0)
u0: 2-element Vector{Float32}:
2.0
0.0
"""
    plot_multiple_shoot(plt, preds, group_size)

Overlay the per-interval predictions of a multiple-shooting fit on `plt`.

`preds[i]` is the prediction for the `i`-th index range returned by
`group_ranges(length(tsteps), group_size)`. Only its first row (first state
component) is drawn, against the matching slice of the global `tsteps`.
Returns `plt` so the call can be chained.
"""
function plot_multiple_shoot(plt, preds, group_size)
    # Derive the ranges from the time vector that is actually indexed below,
    # so the two can never disagree in length.
    ranges = group_ranges(length(tsteps), group_size)
    for (i, rg) in enumerate(ranges)
        plot!(plt, tsteps[rg], preds[i][1, :];
              markershape = :circle, label = "Group $(i)")
    end
    return plt
end
plot_multiple_shoot (generic function with 1 method)
Animate the training process from within the optimization callback.
# Collects one frame per callback invocation; rendered at the end with `mp4`.
anim = Animation()

# Optimization callback. When `doplot` is true, it draws the training data
# (first state component) plus each interval's current prediction, and
# appends the figure as a frame of `anim`. It always returns `false`, so the
# optimizer never stops early because of the callback.
callback = function (p, l, preds; doplot = true)
    if doplot
        # Original data first, then one series per shooting interval on top.
        plt = scatter(tsteps, ode_data[1, :], label = "Data")
        plot_multiple_shoot(plt, preds, group_size)
        # Record `plt` explicitly rather than relying on Plots' implicit
        # "current plot", which another plot call could have replaced.
        frame(anim, plt)
    end
    return false
end
#5 (generic function with 1 method)
Define the multiple-shooting parameters and the loss function.
# Multiple-shooting settings, both forwarded to `multiple_shoot`:
#   group_size      – number of save points per shooting interval
#   continuity_term – weight of the continuity penalty between intervals
group_size, continuity_term = 3, 200
"""
    loss_function(data, pred)

Sum of squared residuals between the observations `data` and the
prediction `pred` (arrays of the same shape).
"""
loss_function(data, pred) = sum(abs2, data - pred)
# Flatten the initial parameters into the plain vector `pd` handed to the
# optimizer, and keep the axes `pax` needed to rebuild the structured view.
ps = ComponentArray(p_init)
pd = getdata(ps)
pax = getaxes(ps)

"""
    loss_multiple_shooting(p)

Multiple-shooting loss for the flat parameter vector `p`. It re-attaches the
component axes `pax`, then calls `multiple_shoot` on `ode_data` with
intervals of `group_size` points and the configured `continuity_term`.
Returns whatever `multiple_shoot` returns (presumably `(loss, preds)`,
matching the callback's `(p, l, preds)` signature).
"""
loss_multiple_shooting(p) = multiple_shoot(
    ComponentArray(p, pax), ode_data, tsteps, prob_node, loss_function,
    Tsit5(), group_size; continuity_term)
loss_multiple_shooting (generic function with 1 method)
Solve the problem
# Gradients via forward-mode AD. The objective ignores Optimization's
# second (hyperparameter) argument and depends only on the flat vector `x`.
adtype = Optimization.AutoForwardDiff()
optf = Optimization.OptimizationFunction(adtype) do x, _
    loss_multiple_shooting(x)
end
optprob = Optimization.OptimizationProblem(optf, pd)
OptimizationProblem. In-place: true
u0: 82-element Vector{Float32}:
-0.44528675
0.21514416
0.07682035
-0.20898351
0.051997196
-0.5316986
0.1930279
0.26410097
0.23659101
-0.060885802
0.5507948
-0.0071215825
-0.23172091
⋮
0.10886517
-0.08881182
0.002469184
-0.09091162
0.21303624
-0.17846352
-0.4990906
-0.06933709
0.34355137
0.3783746
0.0
0.0
# Run the polyalgorithm optimizer, recording an animation frame per iteration.
res_ms = Optimization.solve(optprob, PolyOpt(); callback)
retcode: Success
u: 82-element Vector{Float32}:
-0.1541229
-0.20206842
-0.14872874
-0.0017683604
0.12575443
-0.2708575
0.12641384
-0.081900075
0.20303464
0.2473409
0.021652604
-0.115582466
-0.28945136
⋮
1.3281618
-1.809715
1.4031996
-1.5424356
1.2134823
-1.4007381
-1.6079698
1.5099764
1.8800535
0.5374997
-0.7142393
-0.047968287
Visualize the fitting process.
# Render the collected frames as a 15 fps video.
mp4(anim; fps = 15)
[ Info: Saved animation to /home/runner/work/jl-ude/jl-ude/docs/tmp.mp4