Multiple Shooting

Multiple Shooting

Docs: https://docs.sciml.ai/DiffEqFlux/dev/examples/multiple_shooting/

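In multiple shooting, the training time series is split into overlapping intervals (groups). Each interval is integrated separately from its own starting point, and the neural network is trained on all intervals at once. A continuity penalty keeps the end of one interval consistent with the start of the next.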

The optimization is performed with `PolyOpt()` from OptimizationPolyalgorithms.

using Lux
using ComponentArrays
using DiffEqFlux
using Optimization
using OptimizationPolyalgorithms
using OrdinaryDiffEq
using DiffEqFlux: group_ranges
using Plots
using Random
# Explicitly seeded generator: the network's initial weights (and therefore every
# result printed below) are reproducible from run to run. The global task-local
# RNG is left untouched.
rng = Random.Xoshiro(1234)
[ Info: Precompiling IJuliaExt [2f4121a4-3b3a-5ce6-9c5e-1f2673ce168a]
TaskLocalRNG()

Define the initial condition, time span, and save points

# Number of save points, initial state, time span and the evenly spaced save
# times. Everything is Float32 to match the network parameters.
datasize = 30
u0 = Float32[2, 0]
tspan = (0.0f0, 5.0f0)
tsteps = range(first(tspan), last(tspan); length = datasize)
0.0f0:0.1724138f0:5.0f0

Ground-truth coefficient matrix

# Ground-truth coefficient matrix (each literal is converted to Float32).
true_A = Float32[-0.1 2; -2 -0.1]
2×2 Matrix{Float32}:
 -0.1   2.0
 -2.0  -0.1

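Generate training data by solving the ground-truth ODE $\dot{u} = A^\top u^{\circ 3}$ (the cube is taken element-wise) at the save points.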

# In-place right-hand side of the ground-truth system du/dt = true_Aᵀ * u.^3.
# `p` and `t` are unused. `true_A` is read from global scope. Returns `du`.
function trueODEfunc!(du, u, p, t)
    cubed = u .^ 3
    du .= true_A' * cubed
    return du
end

# Solve the ground-truth ODE, sampled at `tsteps`. `Array` converts the solution
# into a matrix with one column per saved time point.
prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = solve(prob_trueode, Tsit5(); saveat = tsteps) |> Array
2×30 Matrix{Float32}:
 2.0  1.02407  -1.07771  -1.70875   …  0.317698  0.018593  -0.266894
 0.0  1.84867   1.72464   0.323419     0.958795  0.946583   0.930795

Define the Neural Network

# Network: cube the input element-wise (mirroring the cubic ground truth), then a
# 2 => 16 `tanh` layer and a linear 16 => 2 output layer.
nn = Lux.Chain(x -> x .^ 3, Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 2))
# Lux keeps the parameters and state outside the model; initialise both from `rng`.
p_init, st = Lux.setup(rng, nn)
((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.4577508 -0.3664843; -0.28135064 -0.37555605; … ; -0.016999144 -0.078374915; -0.21808596 0.10003073], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.11594153 -0.2533249 … 0.10465174 0.36619797; -0.33947787 0.4672015 … -0.1425245 0.3905742], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))

Define the NeuralODE problem

# NOTE(review): `neuralode` is not referenced anywhere else in the visible source;
# training below uses the hand-built `prob_node` instead.
neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps)
# Out-of-place ODE whose right-hand side is the network output. The Lux call
# returns `(output, state)`, so keep only the first element.
prob_node = ODEProblem(
    (u, p, t) -> first(nn(u, p, st)), u0, tspan, ComponentArray(p_init)
)
ODEProblem with uType Vector{Float32} and tType Float32. In-place: false
timespan: (0.0f0, 5.0f0)
u0: 2-element Vector{Float32}:
 2.0
 0.0
"""
    plot_multiple_shoot(plt, preds, group_size)

Overlay each multiple-shooting segment prediction on `plt` and return `plt`.

The index ranges come from `group_ranges(datasize, group_size)`. For the `i`-th
range `rg`, the first row of `preds[i]` is drawn against `tsteps[rg]` as the series
"Group i". Relies on the globals `datasize` and `tsteps`.
"""
function plot_multiple_shoot(plt, preds, group_size)
    ranges = group_ranges(datasize, group_size)
    for (i, rg) in enumerate(ranges)
        plot!(
            plt, tsteps[rg], preds[i][1, :];
            markershape = :circle, label = "Group $(i)",
        )
    end
    # Return the plot so callers can chain or display it directly.
    return plt
end
plot_multiple_shoot (generic function with 1 method)

Animate the training process with a callback function

anim = Animation()

# Optimization callback. `p` is the current parameter vector, `l` the loss, and
# `preds` the per-group predictions (presumably returned alongside the loss by
# `multiple_shoot`). When `doplot` is true, it records one animation frame showing
# the first state of the data plus every segment's prediction. It always returns
# `false`, which presumably tells the optimizer to keep going.
# NOTE: `group_size` is a global assigned later in the file. It only has to exist
# by the time the optimizer first invokes this callback.
callback = function (p, l, preds; doplot = true)
    doplot || return false
    plt = scatter(tsteps, ode_data[1, :]; label = "Data")
    plot_multiple_shoot(plt, preds, group_size)
    frame(anim)
    return false
end
#5 (generic function with 1 method)

Define parameters for Multiple Shooting

# `group_size`: number of consecutive save points in each shooting segment.
# `continuity_term`: presumably the weight of the penalty on the mismatch between
# the end of one segment and the start of the next.
group_size, continuity_term = 3, 200

# Sum of squared errors between observed `data` and predicted `pred`. Used by
# `multiple_shoot` as the data-fit loss. The arrays must have the same size:
# plain `-` (not `.-`) throws a `DimensionMismatch` otherwise.
function loss_function(data, pred)
    residual = data - pred
    return sum(abs2, residual)
end

# Flatten the Lux parameters into a ComponentArray, then split it into:
# - `pd`: the plain data vector handed to the optimizer;
# - `pax`: the axis metadata used to rebuild the structured view inside the loss.
ps = ComponentArray(p_init)
pd = getdata(ps)
pax = getaxes(ps)

# Multiple-shooting objective for the flat parameter vector `p` (the optimizer's
# view). `pax` re-attaches the layer structure that the network closure in
# `prob_node` expects. Data, save points, group size and continuity weight are all
# read from globals.
function loss_multiple_shooting(p)
    θ = ComponentArray(p, pax)
    return multiple_shoot(
        θ, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size;
        continuity_term = continuity_term,
    )
end
loss_multiple_shooting (generic function with 1 method)

Solve the problem

# Differentiate the objective with forward-mode AD. The objective's second
# argument (Optimization.jl's extra-parameter slot) is ignored here.
adtype = Optimization.AutoForwardDiff()
optf = Optimization.OptimizationFunction((θ, _) -> loss_multiple_shooting(θ), adtype)
optprob = Optimization.OptimizationProblem(optf, pd)
OptimizationProblem. In-place: true
u0: 82-element Vector{Float32}:
 -0.4577508
 -0.28135064
 -0.16281886
 -0.34549278
  0.029121796
 -0.17042898
  0.44449416
 -0.06095814
  0.41254535
 -0.086986296
 -0.45373997
 -0.43513548
 -0.29576224
  ⋮
  0.53160167
 -0.3236842
 -0.12832902
  0.23629871
 -0.496907
 -0.4534503
  0.10465174
 -0.1425245
  0.36619797
  0.3905742
  0.0
  0.0
# Train from the initial flat parameters. `callback` records an animation frame
# each time the optimizer invokes it.
res_ms = Optimization.solve(optprob, PolyOpt(); callback)
retcode: Success
u: 82-element Vector{Float32}:
 -0.1261983
  0.1885271
 -0.15074702
  0.40682003
 -0.052960582
 -0.20547068
 -0.116966546
 -0.033597942
  0.16122933
 -0.2504757
 -0.39861146
 -0.14041175
 -0.49000475
  ⋮
  1.2710185
 -1.5773426
  1.5103456
 -0.570371
 -2.1141646
 -0.273884
  1.5661979
 -1.4449673
  1.5384969
  1.5453333
 -0.39173642
 -0.2685625

Visualize the fitting process

# Encode the recorded frames as an MP4 at 15 frames per second.
mp4(anim; fps = 15)
[ Info: Saved animation to /tmp/docs/ude/tmp.mp4