Solving ODEs with NeuralPDE.jl
From https://docs.sciml.ai/NeuralPDE/stable/tutorials/ode/
using NeuralPDE
using Lux
using OptimizationOptimisers
using OrdinaryDiffEq
using LinearAlgebra
using Random
using Plots
# Seeded generator so that the network initialisation below is reproducible.
rng = Xoshiro(0)
Random.Xoshiro(0xdb2fa90498613fdf, 0x48d73dc42d195740, 0x8c49bc52dc8a77ea, 0x1911b814c02405e8, 0x22a21880af5dc689)
Solve ODEs
The ODE to solve: \(u^{\prime} = \cos(2 \pi t)\) with \(u(0) = 0\) on \(t \in [0, 1]\); its exact solution is \(u(t) = \sin(2 \pi t) / (2 \pi)\).
"""
    model(u, p, t)

Right-hand side of the scalar ODE `u' = cos(2πt)`. The state `u` and the
parameters `p` are not used. `cospi(x)` evaluates `cos(πx)` without first
rounding the product `π * x`.
"""
function model(u, p, t)
    return cospi(2 * t)
end
model (generic function with 1 method)
Define the ODE problem
tspan = (0.0, 1.0)   # integrate over t ∈ [0, 1]
u0 = 0.0             # initial condition u(0) = 0
prob = ODEProblem(model, u0, tspan)   # out-of-place scalar problem, no parameters
ODEProblem with uType Float64 and tType Float64. In-place: false
timespan: (0.0, 1.0)
u0: 0.0
Construct a neural network to solve the problem.
# One hidden layer of 5 sigmoid units mapping t ↦ u(t).
chain = Lux.Chain(Lux.Dense(1 => 5, σ), Lux.Dense(5 => 1))
# Initial parameters (and state) converted to Float64 to match the Float64 time grid.
ps, st = f64(Lux.setup(rng, chain))
((layer_1 = (weight = [-0.04929668828845024; -0.3266667425632477; … ; -1.4946011304855347; -1.0391809940338135;;], bias = [-0.458548903465271, -0.8280583620071411, -0.38509929180145264, 0.32322537899017334, -0.32623517513275146]), layer_2 = (weight = [0.5656673908233643 -0.605137288570404 … 0.3129439055919647 0.22128699719905853], bias = [-0.11007555574178696])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
Solve the ODE as you would with DifferentialEquations.jl; just change the solver algorithm to `NeuralPDE.NNODE()`.
# Adam with learning rate 0.1 trains the network weights.
optimizer = OptimizationOptimisers.Adam(0.1)
# `init_params` is the third *positional* argument of `NNODE` (as used in the
# parameter-estimation example below). Passing it as a keyword (`init_params = ps`)
# only stores it among the solver's extra keyword arguments, so the Float64 `ps`
# built above would not be used as the starting point. The network would instead be
# re-initialised in Float32, which is presumably what triggers the Float32 × Float64
# mixed-precision matmul warning.
alg = NeuralPDE.NNODE(chain, optimizer, ps)
# Train for at most 2000 iterations and sample the trained network every 0.01.
sol = solve(prob, alg; maxiters = 2000, saveat = 0.01)
┌ Warning: Mixed-Precision `matmul_cpu_fallback!` detected and Octavian.jl cannot be used for this set of inputs (C [Matrix{Float64}]: A [Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}}] x B [Matrix{Float64}]). Falling back to generic implementation. This may be slow.
└ @ LuxLib.Impl ~/.julia/packages/LuxLib/wiiF1/src/impl/matmul.jl:145
retcode: Success
Interpolation: Trained neural network interpolation
t: 0.0:0.01:1.0
u: 101-element Vector{Float64}:
0.0
0.009771322357526077
0.019375211198980107
0.028801092979314113
0.038037487503250804
0.04707196216939724
0.05589109111631943
0.06448042081942142
0.07282444386868113
0.08090658283417815
⋮
-0.08099146256434428
-0.07258245857715788
-0.06377591675137832
-0.05457955558280915
-0.04500102880385768
-0.035047922960984725
-0.024727756107609727
-0.014047977406816814
-0.003015967462660716
Compare with a conventional ODE solver (`Tsit5`)
# Reference solution from the explicit Runge–Kutta solver Tsit5, sampled on the NNODE grid.
sol2 = solve(prob, Tsit5(); saveat = sol.t)
retcode: Success
Interpolation: 1st order linear
t: 101-element Vector{Float64}:
0.0
0.01
0.02
0.03
0.04
0.05
0.06
0.07
0.08
0.09
⋮
0.92
0.93
0.94
0.95
0.96
0.97
0.98
0.99
1.0
u: 101-element Vector{Float64}:
0.0
0.009993421557959134
0.019947410479672478
0.029822662260300302
0.03958022071476466
0.04918159908446656
0.05858894195530296
0.0677649690474175
0.07667347363816583
0.08527940825421805
⋮
-0.07670552085637702
-0.06779339219172714
-0.05861083135542737
-0.049195336960505105
-0.039586192500697205
-0.029824466350448206
-0.019949141853139615
-0.009995163933671121
-1.7408033452440449e-6
# Overlay the classical reference solution and the trained network.
plot(sol2; label = "Tsit5")
plot!(sol.t, sol.u; label = "NNODE")

Parameter estimation
using NeuralPDE, OrdinaryDiffEq, Lux, Random, OptimizationOptimJL, LineSearches, Plots
# Re-seed so this section's initialisation does not depend on the section above.
rng = Xoshiro(0)
Random.Xoshiro(0xdb2fa90498613fdf, 0x48d73dc42d195740, 0x8c49bc52dc8a77ea, 0x1911b814c02405e8, 0x22a21880af5dc689)
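NNODE only supports out-of-place ODE functions: the right-hand side must return a new derivative vector instead of writing into a preallocated `du`.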
"""
    lv(u, p, t)

Out-of-place Lotka–Volterra right-hand side. `u` holds the two state components
`(u₁, u₂)` and `p` the four parameters `(α, β, γ, δ)`. A freshly allocated vector
`[du₁, du₂]` is returned, where

    du₁ = α u₁ - β u₁ u₂
    du₂ = δ u₁ u₂ - γ u₂

`t` is unused (the system is autonomous).
"""
function lv(u, p, t)
    x, y = u
    α, β, γ, δ = p
    return [α * x - β * x * y,
            δ * x * y - γ * y]
end
lv (generic function with 1 method)
Generate data
tspan = (0.0, 5.0)
u0 = [5.0, 5.0]
true_p = [1.5, 1.0, 3.0, 1.0]   # (α, β, γ, δ) used to generate the synthetic data
prob = ODEProblem(lv, u0, tspan, true_p)
# Reference trajectory sampled every 0.01 time units; these samples act as the data.
sol_data = solve(prob, Tsit5(); saveat = 0.01)
t_ = sol_data.t
u_ = Array(sol_data)   # one row per state component, one column per saved time
2×501 Matrix{Float64}:
5.0 4.82567 4.65308 4.48283 4.31543 … 1.01959 1.03094 1.04248
5.0 5.09656 5.18597 5.26791 5.34212 0.397663 0.389887 0.382307
Define a neural network
n = 15   # width of each hidden layer
# Three sigmoid hidden layers mapping t ↦ (u₁(t), u₂(t)).
chain = Chain(Dense(1 => n, σ), Dense(n => n, σ), Dense(n => n, σ), Dense(n => 2))
ps, st = f64(Lux.setup(rng, chain))
((layer_1 = (weight = [-0.04929668828845024; -0.3266667425632477; … ; -1.3531280755996704; -0.2917589843273163;;], bias = [0.28568029403686523, -0.4209803342819214, -0.24613642692565918, -0.9429000616073608, -0.3618292808532715, 0.077278733253479, 0.9969245195388794, 0.7939795255661011, 0.45440757274627686, -0.4830443859100342, -0.6861011981964111, -0.3221019506454468, -0.5597391128540039, -0.15051674842834473, 0.9440881013870239]), layer_2 = (weight = [-0.08606008440256119 -0.2168799191713333 … -0.3507671356201172 0.07374405115842819; 0.24009405076503754 -0.2372819483280182 … 0.34944412112236023 -0.21207459270954132; … ; 0.3976286053657532 0.28444960713386536 … -0.32817620038986206 0.396392285823822; -0.07926429808139801 0.35875916481018066 … -0.03593128174543381 -0.28511112928390503], bias = [-0.065037302672863, 0.18384626507759094, 0.17181798815727234, -0.17310386896133423, 0.06428726017475128, 0.09600061178207397, -0.08703552931547165, 0.06890828162431717, -0.16194558143615723, -0.14649711549282074, -0.14649459719657898, -0.04401325806975365, -0.015492657199501991, 0.1046019047498703, 0.15015578269958496]), layer_3 = (weight = [-0.2995997369289398 0.14921274781227112 … -0.011808237060904503 -0.3409591019153595; 0.4351722002029419 0.1286778748035431 … -0.20781198143959045 -0.030425485223531723; … ; -0.02206072397530079 0.14348538219928741 … -0.05763476341962814 -0.2672235071659088; 0.2975636124610901 -0.06781639903783798 … 0.4012162387371063 0.12123444676399231], bias = [0.04135546088218689, -0.2398381233215332, 0.1595604568719864, 0.08355490118265152, -0.06149742379784584, -0.06998120248317719, -0.008059235289692879, -0.10936713218688965, -0.18340998888015747, 0.06297893822193146, 0.04081515222787857, -0.04258332401514053, 0.11171907186508179, -0.21218737959861755, 0.07965957373380661]), layer_4 = (weight = [0.3909371793270111 -0.23473049700260162 … 0.07385867089033127 0.31727129220962524; -0.04396385699510574 0.1817844659090042 … -0.26729491353034973 
0.24492913484573364], bias = [0.04966225475072861, -0.04299044609069824])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))
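Additional loss: squared error between the network prediction and the data, summed over both state components and divided by the number of time points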
"""
    additional_loss(phi, θ)

Data-fitting term added to the NNODE loss. `phi(t_, θ)` evaluates the network
(with trainable parameters `θ`) at the data times `t_`. The squared residual
against the data matrix `u_` is summed and then divided by the number of
columns of `u_` (the number of time samples). Reads the globals `t_` and `u_`.
"""
function additional_loss(phi, θ)
    residual = phi(t_, θ) .- u_
    return sum(abs2, residual) / size(u_, 2)
end
additional_loss (generic function with 1 method)
NNODE solver
# Quasi-Newton optimiser with a backtracking line search.
opt = LBFGS(; linesearch = BackTracking())
# `ps` is passed positionally as the initial parameters. `param_estim = true` also
# trains the ODE parameters `p`, and `additional_loss` adds the data misfit.
alg = NNODE(chain, opt, ps;
            strategy = WeightedIntervalTraining([0.7, 0.2, 0.1], 500),
            param_estim = true,
            additional_loss = additional_loss)
NeuralPDE.NNODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_4::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Optim.LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.BackTracking{Float64, Int64}, Optim.var"#20#22"}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}, layer_2::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}, layer_3::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}, layer_4::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}}, Bool, NeuralPDE.WeightedIntervalTraining{Float64}, Bool, typeof(Main.var"##230".additional_loss), Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}}(Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_4::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 15, σ), layer_2 = Dense(15 => 15, σ), layer_3 = Dense(15 => 15, σ), layer_4 = Dense(15 => 2)), nothing), Optim.LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.BackTracking{Float64, Int64}, Optim.var"#20#22"}(10, LineSearches.InitialStatic{Float64}
alpha: Float64 1.0
scaled: Bool false
, LineSearches.BackTracking{Float64, Int64}
c_1: Float64 0.0001
ρ_hi: Float64 0.5
ρ_lo: Float64 0.1
iterations: Int64 1000
order: Int64 3
maxstep: Float64 Inf
cache: Nothing nothing
, nothing, Optim.var"#20#22"(), Optim.Flat(), true), (layer_1 = (weight = [-0.04929668828845024; -0.3266667425632477; … ; -1.3531280755996704; -0.2917589843273163;;], bias = [0.28568029403686523, -0.4209803342819214, -0.24613642692565918, -0.9429000616073608, -0.3618292808532715, 0.077278733253479, 0.9969245195388794, 0.7939795255661011, 0.45440757274627686, -0.4830443859100342, -0.6861011981964111, -0.3221019506454468, -0.5597391128540039, -0.15051674842834473, 0.9440881013870239]), layer_2 = (weight = [-0.08606008440256119 -0.2168799191713333 … -0.3507671356201172 0.07374405115842819; 0.24009405076503754 -0.2372819483280182 … 0.34944412112236023 -0.21207459270954132; … ; 0.3976286053657532 0.28444960713386536 … -0.32817620038986206 0.396392285823822; -0.07926429808139801 0.35875916481018066 … -0.03593128174543381 -0.28511112928390503], bias = [-0.065037302672863, 0.18384626507759094, 0.17181798815727234, -0.17310386896133423, 0.06428726017475128, 0.09600061178207397, -0.08703552931547165, 0.06890828162431717, -0.16194558143615723, -0.14649711549282074, -0.14649459719657898, -0.04401325806975365, -0.015492657199501991, 0.1046019047498703, 0.15015578269958496]), layer_3 = (weight = [-0.2995997369289398 0.14921274781227112 … -0.011808237060904503 -0.3409591019153595; 0.4351722002029419 0.1286778748035431 … -0.20781198143959045 -0.030425485223531723; … ; -0.02206072397530079 0.14348538219928741 … -0.05763476341962814 -0.2672235071659088; 0.2975636124610901 -0.06781639903783798 … 0.4012162387371063 0.12123444676399231], bias = [0.04135546088218689, -0.2398381233215332, 0.1595604568719864, 0.08355490118265152, -0.06149742379784584, -0.06998120248317719, -0.008059235289692879, -0.10936713218688965, -0.18340998888015747, 0.06297893822193146, 0.04081515222787857, -0.04258332401514053, 0.11171907186508179, -0.21218737959861755, 0.07965957373380661]), layer_4 = (weight = [0.3909371793270111 -0.23473049700260162 … 0.07385867089033127 0.31727129220962524; -0.04396385699510574 
0.1817844659090042 … -0.26729491353034973 0.24492913484573364], bias = [0.04966225475072861, -0.04299044609069824])), false, true, NeuralPDE.WeightedIntervalTraining{Float64}([0.7, 0.2, 0.1], 500), true, Main.var"##230".additional_loss, Base.Pairs{Symbol, Union{}, Tuple{}, @NamedTuple{}}())
Solve the problem
Use `verbose = true` to see the fitting process.
sol = solve(prob, alg;
            verbose = false,
            abstol = 1e-8,
            maxiters = 5000,
            saveat = t_)   # sample the trained network on the data times
retcode: Success
Interpolation: Trained neural network interpolation
t: 501-element Vector{Float64}:
0.0
0.01
0.02
0.03
0.04
0.05
0.06
0.07
0.08
0.09
⋮
4.92
4.93
4.94
4.95
4.96
4.97
4.98
4.99
5.0
u: 501-element Vector{Vector{Float64}}:
[5.0, 5.0]
[4.823494070081761, 5.0952379540439265]
[4.649900975571517, 5.183749986601806]
[4.4794811247429385, 5.2651316137181885]
[4.312490368784844, 5.339042108250796]
[4.1491764389183245, 5.405211030683965]
[3.989775465102224, 5.46344221839134]
[3.8345086170683755, 5.513615223495945]
[3.6835789154909477, 5.5556843400091624]
[3.537168271934487, 5.589675487867298]
⋮
[0.95554599728922, 0.45058865358502054]
[0.9650871358232944, 0.441433417758482]
[0.9747165108833604, 0.4324718020267113]
[0.9844287389140209, 0.4236963931926958]
[0.9942184066131823, 0.41509994768065894]
[1.004080079860517, 0.4066753909863179]
[1.0140083124764518, 0.39841581694788086]
[1.0239976547851524, 0.39031448684521575]
[1.0340426619578147, 0.3823648283366765]
See the fitted parameters (the data were generated with `true_p = [1.5, 1.0, 3.0, 1.0]`)
# `sol.k` holds the optimisation result; the `p` component of its trained vector
# holds the estimated ODE parameters (α, β, γ, δ).
sol.k.u.p |> println
[1.4996301170948254, 0.9996517238787163, 2.999157217484517, 0.9988930562743459]
Visualize the fit
# Trained network trajectories overlaid on the data-generating solution.
plot(sol; labels = ["u1_pinn" "u2_pinn"])
plot!(sol_data; labels = ["u1_data" "u2_data"])

This notebook was generated using Literate.jl.