Multiple Shooting
Docs: https://docs.sciml.ai/DiffEqFlux/dev/examples/multiple_shooting/
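In multiple shooting, the training data is split into overlapping intervals. The neural ODE is solved on each interval separately, and the optimizer (OptimizationPolyalgorithms.PolyOpt()) trains the network on all intervals at once. A continuity penalty pushes the per-interval solutions to stitch together into a single trajectory.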
This simple method assumes there is no noise in the data. A more robust version is available in JuliaSimModelOptimizer.jl, which is proprietary software.
using Lux
using ComponentArrays
using DiffEqFlux
using DiffEqFlux: group_ranges
using Optimization
using OptimizationPolyalgorithms
using OrdinaryDiffEq
using Plots
using Random
# Seeded RNG so that the Lux parameter initialisation below is reproducible.
rng = Xoshiro(0)
Random.Xoshiro(0xdb2fa90498613fdf, 0x48d73dc42d195740, 0x8c49bc52dc8a77ea, 0x1911b814c02405e8, 0x22a21880af5dc689)
Define initial conditions and time steps
# Integrate on t ∈ [0, 5] starting from u(0) = (2, 0), and save the solution at
# `datasize` equally spaced time points (endpoints included).
tspan = (0.0, 5.0)
u0 = [2.0, 0.0]
datasize = 51
tsteps = range(first(tspan), last(tspan); length = datasize)
0.0:0.1:5.0
True values
# Ground-truth system matrix used by `trueODEfunc!` to generate the training data.
# The global is type-annotated (Julia ≥ 1.8) because the ODE right-hand side reads it
# on every evaluation. An untyped global would be `Any`-typed there and force a
# dynamic lookup/dispatch on each call.
true_A::Matrix{Float64} = [-0.1 2.0; -2.0 -0.1]
2×2 Matrix{Float64}:
-0.1 2.0
-2.0 -0.1
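Generate data from the true dynamics \(\dot{u} = \left((u^{3})^{\top} A\right)^{\top} = A^{\top} u^{3}\), where \(u^{3}\) denotes the elementwise cube of the state \(u\).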
"""
    trueODEfunc!(du, u, p, t)

In-place right-hand side of the ground-truth ODE used to generate the training data.

Writes `transpose(true_A) * (u .^ 3)` into `du`. This equals `((u .^ 3)' * true_A)'`
for the real matrix `true_A`. The parameters `p` and time `t` are unused, and
`true_A` is read from global scope. Returns `nothing`, following the convention for
in-place ODE functions (the solver only uses the mutated `du`).
"""
function trueODEfunc!(du, u, p, t)
    du .= transpose(true_A) * (u .^ 3)
    return nothing
end
# Integrate the ground-truth system with Tsit5 and keep only the solution at
# `tsteps`, converted to a plain matrix (one row per state, one column per time point).
prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = solve(prob_trueode, Tsit5(); saveat = tsteps) |> Array
2×51 Matrix{Float64}:
2.0 1.76453 0.666824 -0.580558 … 0.0306948 -0.138592 -0.302836
0.0 1.4286 1.86579 1.80634 0.950209 0.941635 0.930954
Define the neural network using Lux.jl. Notice that this network is smaller than the one in the first example.
# Model: cube the input elementwise, then a 2 → 16 (tanh) → 2 dense stack.
nn = Lux.Chain(
    x -> x .^ 3,
    Lux.Dense(2, 16, tanh),
    Lux.Dense(16, 2),
)
# Initialise parameters and state with the seeded RNG, converted to Float64.
p_init, st = f64(Lux.setup(rng, nn))
# Structured parameter vector, plus its flat data (what the optimizer sees) and its
# axes (used to rebuild the structured view from a flat vector during training).
ps = ComponentArray(p_init)
pd, pax = getdata(ps), getaxes(ps)
([-1.8019577264785767, -0.18273845314979553, 1.6776520013809204, 0.19449931383132935, 0.7557111978530884, 1.1159610748291016, -1.581186056137085, 1.7986798286437988, -0.36156967282295227, -1.9202053546905518 … 0.3558240234851837, -0.2906492352485657, 0.32653868198394775, 0.3687601387500763, -0.3538714349269867, 0.12959939241409302, 0.256054550409317, -0.20957911014556885, 0.10817152261734009, -0.20544955134391785], (Axis(layer_1 = 1:0, layer_2 = ViewAxis(1:48, Axis(weight = ViewAxis(1:32, ShapedAxis((16, 2))), bias = 33:48)), layer_3 = ViewAxis(49:82, Axis(weight = ViewAxis(1:32, ShapedAxis((2, 16))), bias = 33:34))),))
Define the NeuralODE problem
# Standard DiffEqFlux wrapper around the network.
# NOTE(review): `neuralode` is not referenced anywhere else in this script; the
# multiple-shooting loss works on `prob_node` below instead.
neuralode = NeuralODE(nn, tspan, Tsit5(), saveat = tsteps)

# Plain out-of-place ODEProblem whose right-hand side is the network output.
# The Lux model returns a tuple; `first` takes the prediction, as `[1]` did before.
# The `let` block rebinds the untyped globals `nn` and `st` as locals, so the
# closure captures concretely typed values. Without it, they would be looked up as
# globals on every RHS evaluation during training.
prob_node = let model = nn, state = st
    ODEProblem((u, p, t) -> first(model(u, p, state)), u0, tspan,
               ComponentArray(p_init))
end
ODEProblem with uType Vector{Float64} and tType Float64. In-place: false
timespan: (0.0, 5.0)
u0: 2-element Vector{Float64}:
2.0
0.0
Animate the training process in the callback function.
# Overlay the prediction of every multiple-shooting segment on `plt`.
#
# Segment index ranges are recomputed with `group_ranges` from the global `datasize`
# and `group_size`. For segment i, the first state component of `preds[i]` is plotted
# against the matching slice of the global `tsteps`. Mutates `plt` and returns `nothing`.
function plot_multiple_shoot(plt, preds, group_size)
    segments = group_ranges(datasize, group_size)
    for i in eachindex(segments)
        seg = segments[i]
        plot!(plt, tsteps[seg], preds[i][1, :];
              markershape = :circle, label = "Group $(i)")
    end
end
# Frames of the training animation (encoded with `mp4` at the end of the script).
anim = Animation()
# Loss value reported to the callback at each optimizer iteration.
lossrecord = Float64[]

# Optimization.jl callback, called as `callback(state, l)` with the optimizer state
# and the current loss `l`.
#  - Always appends `l` to `lossrecord`, so the loss history is complete regardless
#    of whether plotting is enabled.
#  - When `doplot` is true, draws the data and the current per-segment predictions
#    and appends that plot to `anim` as a frame.
# Reads the globals `preds` (updated by `loss_multiple_shooting`), `group_size`,
# `tsteps` and `ode_data`. Returning `false` tells the optimizer to keep going.
callback = function (state, l; doplot = true)
    push!(lossrecord, l)
    if doplot
        plt = scatter(tsteps, ode_data[1, :]; label = "Data")
        plot_multiple_shoot(plt, preds, group_size)
        # Record `plt` explicitly instead of relying on the implicit "current" plot.
        frame(anim, plt)
    end
    return false
end
#5 (generic function with 1 method)
Parameters for Multiple Shooting
# Multiple-shooting hyperparameters:
#   group_size      – segment size passed to `group_ranges` / `multiple_shoot`
#   continuity_term – weight of the penalty on discontinuities between segments
group_size = 3
continuity_term = 200
"""
    loss_function(data, pred)

Return the sum of squared residuals between `data` and `pred`. The residual
`data .- pred` is formed by broadcasting, so the arguments must have
broadcast-compatible shapes.
"""
function loss_function(data, pred)
    residual = data .- pred
    return sum(abs2, residual)
end
# Evaluate the multiple-shooting loss once at the initial parameters. Besides the
# initial loss `l1`, this binds the global `preds` that the callback reads, so it
# exists before the first callback invocation.
l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size; continuity_term)

# Objective handed to the optimizer.
#
# `p` is the flat parameter vector. It is rebuilt into a structured ComponentArray
# with the axes `pax` saved at setup. Returns the scalar loss from `multiple_shoot`.
# Side effect: stores the per-segment predictions in the global `preds` so the
# callback can plot them.
function loss_multiple_shooting(p)
    # Local `ps`; shadows the global of the same name inside this function.
    ps = ComponentArray(p, pax)
    # `; continuity_term` is keyword shorthand for `continuity_term = continuity_term`.
    loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
        Tsit5(), group_size; continuity_term)
    global preds = currpred
    return loss
end
loss_multiple_shooting (generic function with 1 method)
Solve the problem using OptimizationPolyalgorithms.PolyOpt().
# Gradients via Zygote. The objective ignores Optimization.jl's hyperparameter
# argument `p` and optimizes the flat parameter vector, starting from `pd`.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_multiple_shooting(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pd)
res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback)

# `loss_multiple_shooting` returns a scalar, so no `[1]` indexing is needed.
# This call also refreshes the global `preds` with the predictions at the optimum.
println("Loss is ", loss_multiple_shooting(res_ms.u))
Loss is 1.2367094321016447
Visualize the fitting process.
# Encode the frames collected by the callback as an MP4 at 15 frames per second.
mp4(anim; fps = 15)
[ Info: Saved animation to /home/runner/work/jl-ude/jl-ude/.cache/docs/tmp.mp4
This notebook was generated using Literate.jl.