Multiple Shooting

Multiple Shooting

Docs: https://docs.sciml.ai/DiffEqFlux/dev/examples/multiple_shooting/

In Multiple Shooting, the training data is split into overlapping intervals. The model is fitted on the individual intervals (using the optimizer OptimizationPolyalgorithms.PolyOpt()), and the per-interval results are stitched together.

This simple method assumes no noise in the data. A more robust version can be found in JuliaSimModelOptimizer.jl, which is proprietary software.

using Lux
using ComponentArrays
using DiffEqFlux
using DiffEqFlux: group_ranges
using Optimization
using OptimizationPolyalgorithms
using OrdinaryDiffEq
using Plots
using Random
# Seeded random number generator; it is passed to `Lux.setup` below to initialise the network.
rng = Random.Xoshiro(0)
Random.Xoshiro(0xdb2fa90498613fdf, 0x48d73dc42d195740, 0x8c49bc52dc8a77ea, 0x1911b814c02405e8, 0x22a21880af5dc689)

Define initial conditions and time steps

# Number of sample points taken along the trajectory.
datasize = 51
# Initial state of the two-component system.
u0 = [2.0, 0.0]
# Integration window as (start, stop).
tspan = (0.0, 5.0)
# Evenly spaced sample times spanning `tspan`.
tsteps = range(first(tspan), last(tspan); length = datasize)
0.0:0.1:5.0

True values

# Coefficient matrix of the ground-truth dynamics.
true_A = [
    -0.1  2.0
    -2.0 -0.1
]
2×2 Matrix{Float64}:
 -0.1   2.0
 -2.0  -0.1

Generate data from the true dynamics \(\dot{u} = A^\top u^3\), where the cube is taken element-wise:

# In-place right-hand side of the ground-truth ODE.
# The state is cubed element-wise, the resulting row vector is multiplied by the global
# `true_A`, and the transposed product is written into `du`. `p` and `t` are unused.
function trueODEfunc!(du, u, p, t)
    cubed = u .^ 3
    du .= (cubed' * true_A)'
    return du
end
# Integrate the ground-truth system with Tsit5, saving the state at every entry of
# `tsteps`, and convert the solution to a plain matrix.
prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = solve(prob_trueode, Tsit5(); saveat = tsteps) |> Array
2×51 Matrix{Float64}:
 2.0  1.76453  0.666824  -0.580558  …  0.0306948  -0.138592  -0.302836
 0.0  1.4286   1.86579    1.80634      0.950209    0.941635   0.930954

Define the neural network using Lux.jl. Notice that the network is smaller than in the first example.

# Network: element-wise cube of the input, a 16-unit tanh layer, then a linear layer
# mapping back to 2 outputs.
nn = Lux.Chain(x -> x .^ 3, Lux.Dense(2, 16, tanh), Lux.Dense(16, 2))

# Initial parameters and layer states, passed through `f64`.
p_init, st = f64(Lux.setup(rng, nn))

# `ps` holds the structured parameters; `pd` is their underlying flat data and `pax` the
# axes used later to rebuild a ComponentArray from a flat vector.
ps = ComponentArray(p_init)
pd, pax = getdata(ps), getaxes(ps)
([-1.8019577264785767, -0.18273845314979553, 1.6776520013809204, 0.19449931383132935, 0.7557111978530884, 1.1159610748291016, -1.581186056137085, 1.7986798286437988, -0.36156967282295227, -1.9202053546905518  …  0.3558240234851837, -0.2906492352485657, 0.32653868198394775, 0.3687601387500763, -0.3538714349269867, 0.12959939241409302, 0.256054550409317, -0.20957911014556885, 0.10817152261734009, -0.20544955134391785], (Axis(layer_1 = ViewAxis(1:0, Shaped1DAxis((0,))), layer_2 = ViewAxis(1:48, Axis(weight = ViewAxis(1:32, ShapedAxis((16, 2))), bias = ViewAxis(33:48, Shaped1DAxis((16,))))), layer_3 = ViewAxis(49:82, Axis(weight = ViewAxis(1:32, ShapedAxis((2, 16))), bias = ViewAxis(33:34, Shaped1DAxis((2,)))))),))

Define the NeuralODE problem

# NeuralODE layer wrapping `nn`; it is not referenced again in the visible code.
neuralode = NeuralODE(nn, tspan, Tsit5(), saveat = tsteps)
# ODE problem whose right-hand side is the network evaluated at (u, p) with the fixed
# layer state `st`; only the first element of the network's return value is used.
prob_node = ODEProblem((u,p,t)->nn(u,p,st)[1], u0, tspan, ComponentArray(p_init))
ODEProblem with uType Vector{Float64} and tType Float64. In-place: false
Non-trivial mass matrix: false
timespan: (0.0, 5.0)
u0: 2-element Vector{Float64}:
 2.0
 0.0

Animate training process in the callback function

# Add one series per shooting group to `plt`.
# For the i-th range returned by `group_ranges(datasize, group_size)`, the first row of
# `preds[i]` is drawn against the matching slice of the global `tsteps`.
function plot_multiple_shoot(plt, preds, group_size)
    for (idx, span) in enumerate(group_ranges(datasize, group_size))
        first_component = preds[idx][1, :]
        plot!(
            plt, tsteps[span], first_component;
            markershape = :circle, label = "Group $(idx)",
        )
    end
    return nothing
end

# Animation that receives one frame for every plotted optimizer iteration.
anim = Animation()
# Loss value of every callback invocation, in call order.
lossrecord = Float64[]

# Optimization callback.
#   state  - optimizer state passed in by the solver (not used here)
#   l      - current loss value
#   doplot - when true, append a frame showing the data and the latest group predictions
# Reads the globals `preds` and `group_size`, which must be defined before the solver runs.
# Always returns `false`.
callback = function (state, l; doplot = true)
    # Record the loss unconditionally so the history does not depend on plotting.
    push!(lossrecord, l)
    if doplot
        plt = scatter(tsteps, ode_data[1, :]; label = "Data")
        plot_multiple_shoot(plt, preds, group_size)
        # Hand the figure over explicitly instead of relying on the implicit current plot.
        frame(anim, plt)
    end
    return false
end
#9 (generic function with 1 method)

Parameters for Multiple Shooting

# Group size handed to `multiple_shoot` and `group_ranges`.
group_size = 3
# Penalty for discontinuity
continuity_term = 200

# Sum of squared element-wise differences between `data` and `pred`.
function loss_function(data, pred)
    residual = data .- pred
    return sum(abs2, residual)
end

# Evaluate the multiple-shooting loss once with the initial parameters `ps`.
# This also defines the global `preds` that the callback reads.
l1, preds = multiple_shoot(
    ps, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size;
    continuity_term = continuity_term,
)

# Objective handed to the optimizer.
# Rebuilds a ComponentArray from the flat vector `p` using the axes `pax`, evaluates
# `multiple_shoot` on it, stores the returned predictions in the global `preds`, and
# returns the loss.
function loss_multiple_shooting(p)
    structured = ComponentArray(p, pax)
    total, group_preds = multiple_shoot(
        structured, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size;
        continuity_term = continuity_term,
    )
    # Keep the latest predictions in global scope, where the callback reads them.
    global preds = group_preds
    return total
end
loss_multiple_shooting (generic function with 1 method)

Solve the problem using OptimizationPolyalgorithms.PolyOpt().

# Zygote is selected as the automatic-differentiation backend; the second argument that
# Optimization passes to the objective is ignored.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, _) -> loss_multiple_shooting(x), adtype)
# Optimize starting from the flat parameter vector `pd`, calling `callback` along the way.
optprob = Optimization.OptimizationProblem(optf, pd)
res_ms = Optimization.solve(optprob, PolyOpt(); callback)

# Re-evaluate the objective at the optimizer's result and report it.
println("Loss is ", loss_multiple_shooting(res_ms.u)[1])
Loss is 1.6006485439536657

Visualize the fitting process

# Write the collected frames to an MP4 file at 15 frames per second.
mp4(anim; fps = 15)
[ Info: Saved animation to /tmp/jl_9u1MHADyp4.mp4

Runtime environment

# List the packages and versions of the active project.
using Pkg
Pkg.status()
Status `~/work/jl-pde/jl-pde/Project.toml`
  [b0b7db55] ComponentArrays v0.15.30
⌃ [aae7a2af] DiffEqFlux v4.4.0
  [5b8099bc] DomainSets v0.7.16
  [7da242da] Enzyme v0.13.108
  [f6369f11] ForwardDiff v1.3.0
  [d3d80556] LineSearches v7.5.1
⌃ [b2108857] Lux v1.21.1
  [94925ecb] MethodOfLines v0.11.9
  [961ee093] ModelingToolkit v10.31.0
  [315f7962] NeuralPDE v5.20.0
  [8913a72c] NonlinearSolve v4.12.0
⌅ [7f7a1694] Optimization v4.8.0
⌃ [36348300] OptimizationOptimJL v0.4.5
⌃ [42dfb2eb] OptimizationOptimisers v0.3.11
⌃ [500b13db] OptimizationPolyalgorithms v0.3.1
  [1dea7af3] OrdinaryDiffEq v6.103.0
  [91a5bcdd] Plots v1.41.2
  [ce78b400] SimpleUnPack v1.1.0
  [37e2e46d] LinearAlgebra v1.12.0
  [9a3f8284] Random v1.11.0
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated`
# Print the Julia version and platform information.
using InteractiveUtils
InteractiveUtils.versioninfo()
Julia Version 1.12.2
Commit ca9b6662be4 (2025-11-20 16:25 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
  GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
  JULIA_CPU_TARGET = generic;icelake-server,clone_all;znver3,clone_all
  JULIA_CONDAPKG_OFFLINE = true
  JULIA_CONDAPKG_BACKEND = Null
  JULIA_CI = true
  LD_LIBRARY_PATH = /opt/hostedtoolcache/Python/3.13.9/x64/lib
  JULIA_NUM_THREADS = auto

This notebook was generated using Literate.jl.