Multiple Shooting
Docs: https://docs.sciml.ai/DiffEqFlux/dev/examples/multiple_shooting/
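In multiple shooting, the training data is split into overlapping intervals, and the neural ODE is fitted on each interval separately; a continuity penalty keeps the predictions of adjacent intervals consistent with each other.
The optimization is performed with `OptimizationPolyalgorithms.PolyOpt()`.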
using Lux
using ComponentArrays
using DiffEqFlux
using Optimization
using OptimizationPolyalgorithms
using OrdinaryDiffEq
using DiffEqFlux: group_ranges
using Plots
using Random
# Random number generator used by `Lux.setup` to initialise the network weights.
# Seeding it makes the initial parameters (and therefore the whole tutorial run)
# reproducible instead of changing on every execution.
rng = Random.default_rng()
Random.seed!(rng, 0)
[ Info: Precompiling IJuliaExt [2f4121a4-3b3a-5ce6-9c5e-1f2673ce168a]
TaskLocalRNG()
Define the initial condition, the time span, and the save points
# Problem setup (everything in Float32): initial state, integration window,
# and `datasize` evenly spaced save points covering the whole window.
datasize = 30
u0 = [2.0f0, 0.0f0]
tspan = (0.0f0, 5.0f0)
tsteps = range(first(tspan), last(tspan); length = datasize)
0.0f0:0.1724138f0:5.0f0
True system matrix
# Ground-truth system matrix of the data-generating ODE (converted to Float32).
true_A = Float32.([-0.1 2.0; -2.0 -0.1])
2×2 Matrix{Float32}:
-0.1 2.0
-2.0 -0.1
Generate the training data by solving the true ODE.
"""
    trueODEfunc!(du, u, p, t)

In-place right-hand side of the ground-truth ODE: overwrite `du` with
`true_A' * (u .^ 3)` and return `du`. `p` and `t` are ignored; `true_A` is read
from global scope.
"""
function trueODEfunc!(du, u, p, t)
    cubed = u .^ 3
    du .= true_A' * cubed
    return du
end
# Integrate the true system and keep the states at `tsteps` as a (state × time) matrix.
prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = solve(prob_trueode, Tsit5(); saveat = tsteps) |> Array
2×30 Matrix{Float32}:
2.0 1.02407 -1.07771 -1.70875 … 0.317698 0.018593 -0.266894
0.0 1.84867 1.72464 0.323419 0.958795 0.946583 0.930795
Define the Neural Network
# Network: cube the input elementwise, then a 2 → 16 (tanh) → 2 MLP.
nn = let cube = x -> x .^ 3
    Lux.Chain(cube, Lux.Dense(2, 16, tanh), Lux.Dense(16, 2))
end
# Initial parameters and (empty) layer states, drawn from `rng`.
p_init, st = Lux.setup(rng, nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.4577508 -0.3664843; -0.28135064 -0.37555605; … ; -0.016999144 -0.078374915; -0.21808596 0.10003073], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.11594153 -0.2533249 … 0.10465174 0.36619797; -0.33947787 0.4672015 … -0.1425245 0.3905742], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
Define the NeuralODE problem
# High-level NeuralODE wrapper; the multiple-shooting loss below works with
# `prob_node` instead.
neuralode = NeuralODE(nn, tspan, Tsit5(), saveat = tsteps)
# Out-of-place ODE whose right-hand side is the network output (first element of the
# Lux `(y, state)` tuple). The `let` binds `nn` and `st` as locals so the closure does
# not look up untyped globals on every RHS evaluation (a hot path under the solver/AD).
prob_node = let nn = nn, st = st
    ODEProblem((u, p, t) -> first(nn(u, p, st)), u0, tspan, ComponentArray(p_init))
end
ODEProblem with uType Vector{Float32} and tType Float32. In-place: false
timespan: (0.0f0, 5.0f0)
u0: 2-element Vector{Float32}:
2.0
0.0
"""
    plot_multiple_shoot(plt, preds, group_size)

Overlay the first state component of each multiple-shooting segment onto `plt`.

The time points are split with `group_ranges(length(tsteps), group_size)`, and
`preds[i]` is plotted against `tsteps` over the `i`-th range. `preds[i]` is
presumably the prediction for that segment. The function reads the global `tsteps`
and returns `plt`.
"""
function plot_multiple_shoot(plt, preds, group_size)
    # Count the points from `tsteps` itself so every range indexes it in-bounds.
    ranges = group_ranges(length(tsteps), group_size)
    for (i, rg) in enumerate(ranges)
        plot!(plt, tsteps[rg], preds[i][1, :];
              markershape = :circle, label = "Group $(i)")
    end
    return plt
end
plot_multiple_shoot (generic function with 1 method)
Animate the training process with a callback function
# Collects one frame per optimizer iteration.
anim = Animation()
# Optimizer callback: receives the parameters, the loss and the segment predictions.
# When `doplot` is set, it draws the observed first state plus every segment's
# prediction and records the figure as an animation frame. Always returns `false`,
# so training is never stopped early.
callback = function (p, l, preds; doplot = true)
    doplot || return false
    plt = scatter(tsteps, ode_data[1, :], label = "Data")
    plot_multiple_shoot(plt, preds, group_size)
    frame(anim)
    return false
end
#5 (generic function with 1 method)
Define the multiple-shooting parameters and the loss
# Points per shooting segment, and the weight of the continuity penalty.
group_size, continuity_term = 3, 200
"""
    loss_function(data, pred)

Return the sum of squared differences between `data` and `pred`.
"""
function loss_function(data, pred)
    residual = data - pred
    return sum(abs2, residual)
end
# Structured parameters, the flat vector handed to the optimizer, and the axes
# needed to rebuild the structure from that vector.
ps = ComponentArray(p_init)
pd = getdata(ps)
pax = getaxes(ps)
"""
    loss_multiple_shooting(p)

Objective for the optimizer. Rebuild the structured parameters from the flat vector
`p` using the global axes `pax`, then return the result of `multiple_shoot` on the
global data, problem and settings. The callback's `(p, l, preds)` signature suggests
that result is `(loss, preds)`.
"""
function loss_multiple_shooting(p)
    θ = ComponentArray(p, pax)
    return multiple_shoot(θ, ode_data, tsteps, prob_node, loss_function, Tsit5(),
                          group_size; continuity_term = continuity_term)
end
loss_multiple_shooting (generic function with 1 method)
Solve the problem
# Gradients via forward-mode automatic differentiation.
adtype = Optimization.AutoForwardDiff()
# Optimization.jl calls the objective as f(x, p); only the decision vector is used.
optf = Optimization.OptimizationFunction((θ, _) -> loss_multiple_shooting(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, pd)
OptimizationProblem. In-place: true
u0: 82-element Vector{Float32}:
-0.4577508
-0.28135064
-0.16281886
-0.34549278
0.029121796
-0.17042898
0.44449416
-0.06095814
0.41254535
-0.086986296
-0.45373997
-0.43513548
-0.29576224
⋮
0.53160167
-0.3236842
-0.12832902
0.23629871
-0.496907
-0.4534503
0.10465174
-0.1425245
0.36619797
0.3905742
0.0
0.0
# Train with the polyalgorithm, recording animation frames through `callback`.
res_ms = Optimization.solve(optprob, PolyOpt(); callback)
retcode: Success
u: 82-element Vector{Float32}:
-0.1261983
0.1885271
-0.15074702
0.40682003
-0.052960582
-0.20547068
-0.116966546
-0.033597942
0.16122933
-0.2504757
-0.39861146
-0.14041175
-0.49000475
⋮
1.2710185
-1.5773426
1.5103456
-0.570371
-2.1141646
-0.273884
1.5661979
-1.4449673
1.5384969
1.5453333
-0.39173642
-0.2685625
Visualize the fitting process
# Render the recorded frames as an MP4 at 15 frames per second.
mp4(anim; fps = 15)
[ Info: Saved animation to /tmp/docs/ude/tmp.mp4