# Multiple Shooting

Docs: https://docs.sciml.ai/DiffEqFlux/dev/examples/multiple_shooting/

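In multiple shooting, the training data is split into overlapping intervals. The neural ODE is solved on each interval separately, and a continuity penalty at the interval boundaries keeps the pieces consistent with one another.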

The optimization is carried out with `OptimizationPolyalgorithms.PolyOpt()`.
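
The split into overlapping intervals can be sketched in plain Julia. The helper `overlapping_ranges` below is hypothetical and written only for illustration. It assumes that each group shares its first index with the last index of the previous group, which is how `group_ranges` is used later in this notebook; it is not the DiffEqFlux source.

```julia
# Hypothetical helper for illustration; this is not DiffEqFlux's own code.
function overlapping_ranges(datasize::Integer, group_size::Integer)
    2 <= group_size <= datasize ||
        throw(DomainError(group_size, "need 2 <= group_size <= datasize"))
    # Each group starts on the last index of the previous group.
    starts = 1:(group_size - 1):(datasize - 1)
    # The final group is clipped so it never runs past the data.
    return [i:min(datasize, i + group_size - 1) for i in starts]
end

ranges = overlapping_ranges(30, 3)
println(length(ranges))  # number of groups
println(ranges[1:3])     # first three groups
println(ranges[end])     # last (possibly shorter) group
```

With 30 data points and a group size of 3, as used below, this sketch produces 15 groups; the last one, `29:30`, holds only two points because the data runs out.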

In [1]:
using Lux
using ComponentArrays
using DiffEqFlux
using Optimization
using OptimizationPolyalgorithms
using OrdinaryDiffEq
using DiffEqFlux: group_ranges
using Plots
using Random
rng = Random.default_rng()

TaskLocalRNG()

Define initial conditions and time steps

In [2]:
datasize = 30
u0 = Float32[2.0, 0.0]
tspan = (0.0f0, 5.0f0)
tsteps = range(tspan[begin], tspan[end], length = datasize)

0.0f0:0.1724138f0:5.0f0

True values

In [3]:
true_A = Float32[-0.1 2.0; -2.0 -0.1]

2×2 Matrix{Float32}:
 -0.1   2.0
 -2.0  -0.1

Generate data from the true ODE function.

In [4]:
function trueODEfunc!(du, u, p, t)
    du .= ((u.^3)'true_A)'
end

prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

2×30 Matrix{Float32}:
 2.0  1.02407  -1.07771  -1.70875   …  0.317698  0.018593  -0.266894
 0.0  1.84867   1.72464   0.323419     0.958795  0.946583   0.930795
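
Written out, `trueODEfunc!` computes $A^\top u^{3}$ with the cube taken elementwise and $A$ equal to `true_A`, so the system being integrated is:

```latex
$$
\begin{aligned}
\frac{du_1}{dt} &= -0.1\,u_1^{3} - 2.0\,u_2^{3},\\
\frac{du_2}{dt} &= \phantom{-}2.0\,u_1^{3} - 0.1\,u_2^{3}.
\end{aligned}
$$
```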

Define the Neural Network

In [5]:
nn = Lux.Chain(
    x -> x.^3,
    Lux.Dense(2, 16, tanh),
    Lux.Dense(16, 2)
)
p_init, st = Lux.setup(rng, nn)

((layer_1 = NamedTuple(), layer_2 = (weight = Float32[0.4253354 -0.5123715; 0.3929301 -0.17033812; … ; 0.25619218 0.26161644; -0.27258378 0.56273407], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.032363545 0.08579245 … 0.13629021 0.3513656; -0.3327108 -0.5358798 … 0.42824 -0.3265682], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))
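
As a quick sanity check, the number of trainable parameters can be counted by hand. The sketch below uses a hypothetical helper, `dense_param_count`, which is not part of Lux; it assumes each dense layer stores one weight matrix and one bias vector, as the printed `weight` and `bias` fields suggest. The cubing layer has no parameters.

```julia
# Hypothetical helper: a dense layer with `nin` inputs and `nout` outputs
# stores an nout-by-nin weight matrix plus a bias vector of length nout.
dense_param_count(nin, nout) = nin * nout + nout

# Sizes of the two Lux.Dense layers defined above.
layer_sizes = [(2, 16), (16, 2)]
total_params = sum(dense_param_count(nin, nout) for (nin, nout) in layer_sizes)
println(total_params)
```

The total is 48 + 34 = 82, which matches the 82-element parameter vector that appears when the optimization problem is built.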

Define the `NeuralODE` problem

In [6]:
neuralode = NeuralODE(nn, tspan, Tsit5(), saveat = tsteps)
prob_node = ODEProblem((u,p,t)->nn(u,p,st)[1], u0, tspan, ComponentArray(p_init))

ODEProblem with uType Vector{Float32} and tType Float32. In-place: false
timespan: (0.0f0, 5.0f0)
u0: 2-element Vector{Float32}:
 2.0
 0.0

In [7]:
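function plot_multiple_shoot(plt, preds, group_size)
    # Index ranges of the overlapping groups (relies on the globals `datasize` and `tsteps`)
    ranges = group_ranges(datasize, group_size)

    # Overlay each group's prediction of the first state variable on the plot
    for (i, rg) in enumerate(ranges)
        plot!(plt, tsteps[rg], preds[i][1,:], markershape=:circle, label="Group $(i)")
    end
end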

plot_multiple_shoot (generic function with 1 method)

Animate the training process with the callback function

In [8]:
anim = Animation()
callback = function (p, l, preds; doplot = true)
    # display(l)
    if doplot
        # plot the original data
        plt = scatter(tsteps, ode_data[1,:], label = "Data")
        # plot the different predictions for individual shoot
        plot_multiple_shoot(plt, preds, group_size)
        frame(anim)
        # display(plot(plt))
    end
    return false
end

#5 (generic function with 1 method)

Define parameters for Multiple Shooting

In [9]:
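group_size = 3          # number of data points in each shooting interval
continuity_term = 200   # weight of the penalty on jumps between neighbouring intervals

# Squared-error loss between the data and the prediction on one interval
function loss_function(data, pred)
    return sum(abs2, data - pred)
end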

ps = ComponentArray(p_init)
pd, pax = getdata(ps), getaxes(ps)

function loss_multiple_shooting(p)
    ps = ComponentArray(p, pax)
    return multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size; continuity_term)
end

loss_multiple_shooting (generic function with 1 method)
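
To make the roles of `loss_function` and `continuity_term` concrete, here is a minimal sketch of a multiple-shooting loss on toy data. The function `sketch_multiple_shoot_loss` is hypothetical and is not the DiffEqFlux implementation. In particular, measuring the jump at a boundary as a sum of absolute differences is an assumption made for this illustration, and the predictions are supplied directly instead of coming from an ODE solve.

```julia
# Same per-interval loss as in the notebook.
loss_function(data, pred) = sum(abs2, data - pred)

# Hypothetical sketch, not DiffEqFlux's `multiple_shoot`.
function sketch_multiple_shoot_loss(data, preds, ranges, loss_function, continuity_term)
    loss = zero(eltype(data))
    for (i, rg) in enumerate(ranges)
        u = data[:, rg]
        # Data-fit term for this interval.
        loss += loss_function(u, preds[i])
        if i > 1
            # Assumed continuity penalty: gap between the end of the previous
            # interval's prediction and the data point that starts this interval.
            loss += continuity_term * sum(abs, preds[i - 1][:, end] - u[:, 1])
        end
    end
    return loss
end

# Toy data: 2 states, 5 time points, two overlapping intervals.
data = Float32[1 2 3 4 5; 0 0 0 0 0]
ranges = [1:3, 3:5]

# Predictions that reproduce the data exactly give zero loss.
preds_exact = [data[:, rg] for rg in ranges]
loss_exact = sketch_multiple_shoot_loss(data, preds_exact, ranges, loss_function, 200)

# Shift the end of the first interval's prediction by 0.5 to open a gap.
preds_shifted = deepcopy(preds_exact)
preds_shifted[1][1, end] += 0.5f0
loss_shifted = sketch_multiple_shoot_loss(data, preds_shifted, ranges, loss_function, 200)

println((loss_exact, loss_shifted))
```

In this toy case the shifted prediction adds 0.25 to the data-fit term but 100 to the continuity term, which shows how a large `continuity_term` pushes the optimizer to close gaps between neighbouring intervals.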

Solve the problem

In [10]:
adtype = Optimization.AutoForwardDiff()
optf = Optimization.OptimizationFunction((x,p) -> loss_multiple_shooting(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pd)

OptimizationProblem. In-place: true
u0: 82-element Vector{Float32}:
  0.4253354
  0.3929301
 -0.3915948
  0.47407365
 -0.4691419
 -0.45513412
 -0.5348858
  0.058694467
  0.021952586
 -0.27349222
  0.0229849
  0.5718287
  0.37022027
  ⋮
  0.2037728
  0.1541715
 -0.0024424798
 -0.078487374
 -0.55349237
 -0.5158906
  0.13629021
  0.42824
  0.3513656
 -0.3265682
  0.0
  0.0

In [11]:
res_ms = Optimization.solve(optprob, PolyOpt(), callback = callback)

retcode: Success
u: 82-element Vector{Float32}:
  0.11198329
 -0.15426649
 -0.18752277
  0.13290909
 -0.26372623
  0.31958792
 -0.13469425
  0.12247226
  0.18524134
  0.26363507
  0.13074891
  0.14819647
  0.3990616
  ⋮
  1.1465963
  1.1108792
 -1.2939013
  1.5562027
 -2.4652734
 -0.2997964
 -0.71357197
  1.7802604
  0.08945161
 -1.570344
 -0.6255056
  0.20779555

Visualize the fitting process

In [12]:
mp4(anim, fps=15)

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mSaved animation to /home/runner/work/jl-ude/jl-ude/docs/tmp.mp4
