Multiple Shooting
Docs: https://docs.sciml.ai/DiffEqFlux/dev/examples/multiple_shooting/
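In multiple shooting, the training data is split into overlapping intervals, and the model is solved separately on each interval. The optimizer (`OptimizationPolyalgorithms.PolyOpt()`) fits the network on all intervals at once, and a continuity penalty (`continuity_term` below) stitches the per-interval trajectories together.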
This simple method assumes no noise in the data. A more robust version can be found in JuliaSimModelOptimizer.jl, which is proprietary software.
using Lux
using ComponentArrays
using DiffEqFlux
using DiffEqFlux: group_ranges
using Optimization
using OptimizationPolyalgorithms
using OrdinaryDiffEq
using Plots
using Random
rng = Xoshiro(0)  # fixed seed so the network's initial weights are reproducible
Random.Xoshiro(0xdb2fa90498613fdf, 0x48d73dc42d195740, 0x8c49bc52dc8a77ea, 0x1911b814c02405e8, 0x22a21880af5dc689)
Define initial conditions and time steps
# Simulation setup: `datasize` equally spaced save points over `tspan`, starting at `u0`.
datasize = 51
u0 = [2.0, 0.0]
tspan = (0.0, 5.0)
tsteps = range(first(tspan), last(tspan); length = datasize)
0.0:0.1:5.0
True values
# Ground-truth system matrix. It is declared `const` because `trueODEfunc!` reads it on
# every right-hand-side evaluation; an untyped (non-const) global there would make the
# ODE function type-unstable. (Re-running this line in the same session with a new array
# triggers a constant-redefinition warning.)
const true_A = [-0.1 2.0; -2.0 -0.1]
2×2 Matrix{Float64}:
-0.1 2.0
-2.0 -0.1
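Generate data from the true dynamics \(\dot{u} = \left((u^{\circ 3})^{\top} A\right)^{\top}\), where \(u^{\circ 3}\) denotes the element-wise cube of \(u\).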
"""
    trueODEfunc!(du, u, p, t)

In-place right-hand side of the true system: write `((u .^ 3)' * true_A)'` into `du`.
`p` and `t` are unused. Returns `du`.
"""
function trueODEfunc!(du, u, p, t)
    ucubed = u .^ 3                  # element-wise cube of the state
    du .= (ucubed' * true_A)'        # row vector times matrix, transposed back to a column
    return du
end
# Integrate the true system and sample it at `tsteps`. The result is a
# (state × time) matrix that serves as training data.
prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = solve(prob_trueode, Tsit5(); saveat = tsteps) |> Array
2×51 Matrix{Float64}:
2.0 1.76453 0.666824 -0.580558 … 0.0306948 -0.138592 -0.302836
0.0 1.4286 1.86579 1.80634 0.950209 0.941635 0.930954
Define the neural network using Lux.jl. Notice that the network is smaller than in the first example.
# Model: cube the input element-wise (mirroring the true dynamics), then a
# 2 → 16 → 2 MLP with a tanh hidden layer.
nn = Lux.Chain(
    x -> x .^ 3,
    Lux.Dense(2 => 16, tanh),
    Lux.Dense(16 => 2),
)
# Initialize parameters and state with the seeded RNG, converted to Float64.
p_init, st = f64(Lux.setup(rng, nn))
ps = ComponentArray(p_init)
# Flat parameter vector for the optimizer, plus the axes needed to rebuild `ps` from it.
pd = getdata(ps)
pax = getaxes(ps)
([-1.8019577264785767, -0.18273845314979553, 1.6776520013809204, 0.19449931383132935, 0.7557111978530884, 1.1159610748291016, -1.581186056137085, 1.7986798286437988, -0.36156967282295227, -1.9202053546905518 … 0.3558240234851837, -0.2906492352485657, 0.32653868198394775, 0.3687601387500763, -0.3538714349269867, 0.12959939241409302, 0.256054550409317, -0.20957911014556885, 0.10817152261734009, -0.20544955134391785], (Axis(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:48, Axis(weight = ViewAxis(1:32, ShapedAxis((16, 2))), bias = ViewAxis(33:48, Shaped1DAxis((16,))))), layer_3 = ViewAxis(49:82, Axis(weight = ViewAxis(1:32, ShapedAxis((2, 16))), bias = ViewAxis(33:34, Shaped1DAxis((2,)))))),))
Define the NeuralODE problem
# `neuralode` (the DiffEqFlux layer form of the model) is not used by the
# multiple-shooting loss below, which works with the plain `prob_node` instead.
neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps)
# Out-of-place RHS: run the Lux network and keep the first element of the returned
# tuple (its output), discarding the layer state.
prob_node = ODEProblem((u, p, t) -> first(nn(u, p, st)), u0, tspan,
    ComponentArray(p_init))
ODEProblem with uType Vector{Float64} and tType Float64. In-place: false
Non-trivial mass matrix: false
timespan: (0.0, 5.0)
u0: 2-element Vector{Float64}:
2.0
0.0
Animate the training process in the callback function.
"""
    plot_multiple_shoot(plt, preds, group_size)

Overlay, on `plt`, the first state component of every shooting segment's prediction.
Segment `k` covers the time indices `group_ranges(datasize, group_size)[k]` and is
labelled `"Group k"`. Reads the globals `datasize` and `tsteps`. Returns `nothing`.
"""
function plot_multiple_shoot(plt, preds, group_size)
    segments = group_ranges(datasize, group_size)
    foreach(eachindex(segments)) do k
        idx = segments[k]
        plot!(plt, tsteps[idx], preds[k][1, :];
            markershape = :circle, label = "Group $(k)")
    end
    return nothing
end
anim = Animation()
lossrecord = Float64[]
# Optimization.jl callback, invoked with the optimizer state and the current loss `l`.
# - Appends `l` to `lossrecord`.
# - When `doplot` is true, draws the data against the latest per-group predictions and
#   adds a frame to `anim`. The predictions come from the global `preds`, which
#   `loss_multiple_shooting` refreshes.
# - Returns `false`, so the optimization is never halted from here.
callback = function (state, l; doplot = true)
    # Record unconditionally so `lossrecord` stays complete even with `doplot = false`.
    push!(lossrecord, l)
    if doplot
        plt = scatter(tsteps, ode_data[1, :], label = "Data")
        plot_multiple_shoot(plt, preds, group_size)
        frame(anim)
    end
    return false
end
#9 (generic function with 1 method)
Parameters for Multiple Shooting
# Multiple-shooting settings: time points per (overlapping) segment, and the weight of
# the penalty for discontinuities between neighbouring segments.
group_size, continuity_term = 3, 200
"""
    loss_function(data, pred)

Return the sum of squared element-wise differences between `data` and `pred`
(the two are broadcast against each other).
"""
loss_function(data, pred) = sum(abs2, data .- pred)
# Evaluate the multiple-shooting loss once at the initial parameters. This also creates
# the global `preds` that the plotting callback reads.
l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, Tsit5(),
    group_size; continuity_term = continuity_term)
"""
    loss_multiple_shooting(p)

Multiple-shooting loss for the flat parameter vector `p`. The vector is rewrapped with
the axes `pax` and passed to `multiple_shoot` together with `loss_function` and
`continuity_term`. As a side effect, the global `preds` is set to the per-group
predictions so the callback can plot them. Returns the scalar loss.
"""
function loss_multiple_shooting(p)
    θ = ComponentArray(p, pax)
    total_loss, group_preds = multiple_shoot(θ, ode_data, tsteps, prob_node,
        loss_function, Tsit5(), group_size; continuity_term)
    global preds = group_preds
    return total_loss
end
loss_multiple_shooting (generic function with 1 method)
Solve the problem using OptimizationPolyalgorithms.PolyOpt().
# Gradients come from Zygote. The optimizer works on the flat vector `pd`, and
# `loss_multiple_shooting` re-attaches the ComponentArray axes; the second argument
# (hyperparameters) is ignored.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((θ, _) -> loss_multiple_shooting(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, pd)
res_ms = Optimization.solve(optprob, PolyOpt(); callback)
# `loss_multiple_shooting` returns a scalar, so indexing it with `[1]` is unnecessary.
# The call also refreshes the global `preds` with predictions at the final parameters.
println("Loss is ", loss_multiple_shooting(res_ms.u))
Loss is 1.6362929930673293
Visualize the fitting process.
mp4(anim; fps = 15)  # encode the frames collected by the callback as an MP4 at 15 fps
[ Info: Saved animation to /tmp/jl_9QBcN8RgFO.mp4
This notebook was generated using Literate.jl.