
First Neural ODE example#

A neural ODE is an ODE whose derivative function is defined by a neural network: \(\dot{u} = NN(u)\)

From: https://docs.sciml.ai/DiffEqFlux/stable/examples/neural_ode/
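
Conceptually, any Lux model can serve as the derivative function of a regular ODEProblem. A minimal sketch of the idea (hypothetical names; it assumes a Lux model nn with parameters p and state st from Lux.setup, plus u0 and tspan):

rhs(u, p, t) = first(nn(u, p, st))  # Lux models return (output, state)
prob = ODEProblem(rhs, u0, tspan, p)
sol = solve(prob, Tsit5())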

using Lux, DiffEqFlux, OrdinaryDiffEq, ComponentArrays
using Optimization, OptimizationOptimJL, OptimizationOptimisers
using Random, Plots
rng = Random.Xoshiro(0)
Random.Xoshiro(0xdb2fa90498613fdf, 0x48d73dc42d195740, 0x8c49bc52dc8a77ea, 0x1911b814c02405e8, 0x22a21880af5dc689)

The true dynamics cube the state elementwise and multiply by a matrix: \(\dot{u} = A^{\top} u^{3}\)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
trueODEfunc (generic function with 1 method)
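
Note the transpose trick above: since \(((u^3)^\top A)^\top = A^\top u^3\), the dynamics can also be written directly. A sketch of the equivalent form (not used below):

function trueODEfunc_alt(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= true_A' * (u .^ 3)  # same dynamics written without transposes
end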

Generate data from the true function

u0 = [2.0; 0.0]
datasize = 31
tspan = (0.0, 1.5)
tsteps = range(tspan[begin], tspan[end], length = datasize)
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
2×31 Matrix{Float64}:
 2.0  1.94946   1.76453  1.29973  0.666824  …  1.40811  1.36939   1.28907
 0.0  0.773427  1.4286   1.79062  1.86579      0.48727  0.755306  0.98924
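
As a quick sanity check, one could scatter the generated trajectories before fitting (a sketch; the training callback below produces similar plots):

scatter(tsteps, ode_data[1, :], label = "u1 data")
scatter!(tsteps, ode_data[2, :], label = "u2 data")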

Define a NeuralODE problem with a neural network from Lux.jl.

dudt2 = Lux.Chain(
    x -> x.^3,
    Lux.Dense(2, 50, tanh),
    Lux.Dense(50, 2)
)

p, st = Lux.setup(rng, dudt2) |> f64
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
NeuralODE(
    model = Chain(
        layer_1 = WrappedFunction(#2),
        layer_2 = Dense(2 => 50, tanh),           # 150 parameters
        layer_3 = Dense(50 => 2),                 # 102 parameters
    ),
)         # Total: 252 parameters,
          #        plus 0 states.
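
Before training, the NeuralODE object is already callable; a sketch of one untrained forward pass, which returns the ODE solution together with the layer state:

sol0, _ = prob_neuralode(u0, ComponentArray(p), st)
size(Array(sol0))  # expect (2, 31): two states at 31 saved time points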

Prediction function. Calling the NeuralODE object returns a tuple of the ODE solution and the layer state; we keep the solution and convert it to an array.

predict_neuralode(p) = Array(prob_neuralode(u0, p, st)[1])
predict_neuralode (generic function with 1 method)

Loss function. Optimization.jl v4 only accepts a scalar output from the objective.

function loss_neuralode(p)
    pred = predict_neuralode(p)
    l2loss = sum(abs2, ode_data .- pred)
    return l2loss
end
loss_neuralode (generic function with 1 method)
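
Since Zygote.jl provides the gradients below, one can verify up front that the loss is differentiable with respect to the parameters (a sketch, assuming Zygote is installed; the adjoint machinery is loaded by DiffEqFlux):

using Zygote
grad = Zygote.gradient(loss_neuralode, ComponentArray(p))[1]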

Callback function to record the loss and animate the fit. Returning `false` tells the optimizer to continue.

anim = Animation()
lossrecord = Float64[]
callback = function (state, l; doplot = true)
    if doplot
        pred = predict_neuralode(state.u)
        plt = scatter(tsteps, ode_data[1,:], label = "data")
        scatter!(plt, tsteps, pred[1,:], label = "prediction")
        frame(anim)
        push!(lossrecord, l)
    else
        println(l)
    end
    return false
end
#6 (generic function with 1 method)
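
Optimization.jl halts as soon as a callback returns `true`, so a hypothetical early-stopping variant could look like:

callback_earlystop = function (state, l)
    push!(lossrecord, l)
    return l < 1e-3  # returning `true` halts the optimization
end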

Try the callback function to see if it works.

pinit = ComponentArray(p)
callback((; u = pinit), loss_neuralode(pinit); doplot=false)
120.03290214006641
false

Use SciML/Optimization.jl to solve the problem and FluxML/Zygote.jl for automatic differentiation (AD).

adtype = Optimization.AutoZygote()
ADTypes.AutoZygote()
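
Other AD backends could be swapped in here (a sketch; each one requires its backing package to be loaded):

adtype_fwd = Optimization.AutoForwardDiff()  # forward-mode; suits small parameter counts
adtype_fin = Optimization.AutoFiniteDiff()   # finite differences; derivative-free fallback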

Wrap the loss in an OptimizationFunction so it can be differentiated with the chosen AD backend.

optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
SciMLBase.OptimizationFunction{true, ADTypes.AutoZygote, Main.var"##277".var"#9#10", Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED_NO_TIME), Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}(Main.var"##277".var"#9#10"(), ADTypes.AutoZygote(), nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, SciMLBase.DEFAULT_OBSERVED_NO_TIME, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing, nothing)

Define an OptimizationProblem

optprob = Optimization.OptimizationProblem(optf, pinit)
OptimizationProblem. In-place: true
u0: ComponentVector{Float64}(layer_1 = Float64[], layer_2 = (weight = [-1.8019577264785767 1.509717345237732; -0.18273845314979553 -0.46764108538627625; … ; 0.37099915742874146 -0.27108314633369446; -0.34856587648391724 -0.6062840819358826], bias = [-0.5224840044975281, -0.6805992722511292, -0.21060703694820404, 0.5093754529953003, 0.336392879486084, 0.22010256350040436, -0.12450861930847168, 0.38843590021133423, 0.5799375176429749, 0.3984285593032837  …  0.10401319712400436, 0.009969078004360199, -0.460673987865448, 0.210310161113739, 0.5280858278274536, 0.7054404020309448, 0.0009628869011066854, 0.4056747257709503, 0.30830612778663635, 0.17590543627738953]), layer_3 = (weight = [0.22905384004116058 -0.23547108471393585 … 0.0332123264670372 0.13550478219985962; 0.22466984391212463 -0.148941770195961 … -0.1966829150915146 0.10960526019334793], bias = [-0.02694704197347164, -0.03700210154056549]))

Solve the OptimizationProblem using the Adam optimizer first to get a rough estimate; first-order methods like Adam are robust far from the optimum.

result_neuralode = Optimization.solve(
    optprob,
    OptimizationOptimisers.Adam(0.05),
    callback = callback,
    maxiters = 300
)

println("Loss is: ", loss_neuralode(result_neuralode.u))
Loss is: 0.10195836527600548

Refine the solution with another optimizer, BFGS, a quasi-Newton method that converges quickly once close to a minimum.

optprob2 = remake(optprob; u0 = result_neuralode.u)

result_neuralode2 = Optimization.solve(
    optprob2,
    Optim.BFGS(; initial_stepnorm = 0.01),
    callback = callback,
    allow_f_increases = false
)

println("Loss is: ", loss_neuralode(result_neuralode2.u))
Loss is: 0.011016032931972099
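
The refined parameters can be compared against the data directly (a sketch; the animation below shows the same comparison evolving over training):

pred = predict_neuralode(result_neuralode2.u)
plt = scatter(tsteps, ode_data[1, :], label = "data (u1)")
scatter!(plt, tsteps, pred[1, :], label = "prediction (u1)")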

Visualize the fitting process

mp4(anim, fps=15)
[ Info: Saved animation to /tmp/jl_B2t2vQ5cRC.mp4
lossrecord
plot(lossrecord[1:300], xlabel="Iters", ylabel="Loss", lab="Adam", yscale=:log10)
plot!(300:length(lossrecord), lossrecord[300:end], lab="BFGS")

Runtime environment#

using Pkg
Pkg.status()
Status `~/work/jl-pde/jl-pde/Project.toml`
  [b0b7db55] ComponentArrays v0.15.30
⌃ [aae7a2af] DiffEqFlux v4.4.0
  [5b8099bc] DomainSets v0.7.16
  [7da242da] Enzyme v0.13.108
  [f6369f11] ForwardDiff v1.3.0
  [d3d80556] LineSearches v7.5.1
⌃ [b2108857] Lux v1.21.1
  [94925ecb] MethodOfLines v0.11.9
  [961ee093] ModelingToolkit v10.31.0
  [315f7962] NeuralPDE v5.20.0
  [8913a72c] NonlinearSolve v4.12.0
⌅ [7f7a1694] Optimization v4.8.0
⌃ [36348300] OptimizationOptimJL v0.4.5
⌃ [42dfb2eb] OptimizationOptimisers v0.3.11
⌃ [500b13db] OptimizationPolyalgorithms v0.3.1
  [1dea7af3] OrdinaryDiffEq v6.103.0
  [91a5bcdd] Plots v1.41.2
  [ce78b400] SimpleUnPack v1.1.0
  [37e2e46d] LinearAlgebra v1.12.0
  [9a3f8284] Random v1.11.0
Info Packages marked with ⌃ and ⌅ have new versions available. Those with ⌃ may be upgradable, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated`
using InteractiveUtils
InteractiveUtils.versioninfo()
Julia Version 1.12.2
Commit ca9b6662be4 (2025-11-20 16:25 UTC)
Build Info:
  Official https://julialang.org release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-18.1.7 (ORCJIT, znver3)
  GC: Built with stock GC
Threads: 4 default, 1 interactive, 4 GC (on 4 virtual cores)
Environment:
  JULIA_CPU_TARGET = generic;icelake-server,clone_all;znver3,clone_all
  JULIA_CONDAPKG_OFFLINE = true
  JULIA_CONDAPKG_BACKEND = Null
  JULIA_CI = true
  LD_LIBRARY_PATH = /opt/hostedtoolcache/Python/3.13.9/x64/lib
  JULIA_NUM_THREADS = auto

This notebook was generated using Literate.jl.