Solving ODEs with Physics-Informed Neural Networks: https://
using NeuralPDE
using Lux
using OptimizationOptimisers
using OrdinaryDiffEq
using LinearAlgebra
using Random
using Plots
rng = Random.default_rng()
Random.seed!(rng, 42)
Random.TaskLocalRNG()
Solve ODEs
The true function:
"""
    model(u, p, t)

Right-hand side of the test ODE: returns `cospi(2t)`, i.e. `cos(2πt)`.

The state `u` and the parameters `p` are accepted only to match the
out-of-place `f(u, p, t)` signature and are not used.
"""
model(u, p, t) = cospi(2t)
# Output: model (generic function with 1 method)

# Prepare data
tspan = (0.0, 1.0)
u0 = 0.0
prob = ODEProblem(model, u0, tspan)
ODEProblem with uType Float64 and tType Float64. In-place: false
Non-trivial mass matrix: false
timespan: (0.0, 1.0)
u0: 0.0
Construct a neural network to solve the problem.
chain = Lux.Chain(Lux.Dense(1, 5, σ), Lux.Dense(5, 1))
ps, st = Lux.setup(rng, chain) |> Lux.f64((layer_1 = (weight = [1.0205187797546387; 0.4480646252632141; … ; -0.17203108966350555; 1.4574639797210693;;], bias = [-0.37966299057006836, 0.4062596559524536, -0.1814206838607788, 0.34669220447540283, -0.9501688480377197]), layer_2 = (weight = [-0.4852765202522278 0.1757996529340744 … 0.26079171895980835 -0.3895817995071411], bias = [-0.20251299440860748])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))Solve the ODE with NeuralPDE.NNODE().
# Adam optimiser with step size 0.1 used to train the network.
optimizer = OptimizationOptimisers.Adam(0.1)
# Pass the seeded initial parameters `ps` positionally, matching the
# `NNODE(chain, opt, ps; ...)` call in the parameter-estimation section.
# NOTE(review): `init_params` appears to be NNODE's third positional argument,
# so `init_params = ps` as a keyword would be swallowed by `kwargs...` and `ps`
# would not be used; confirm against the installed NeuralPDE version.
alg = NeuralPDE.NNODE(chain, optimizer, ps)
@time sol = solve(prob, alg, maxiters=2000, saveat = 0.01)
retcode: Success
Interpolation: Trained neural network interpolation
t: 0.0:0.01:1.0
u: 101-element Vector{Float64}:
0.0
0.009151421211917572
0.01825691104070158
0.027299171299466487
0.036259632563473705
0.04511843088055512
0.053854394660181805
0.06244504363196118
0.07086660188482889
0.07909402709230419
⋮
-0.07731899338316414
-0.06912985447088438
-0.060584330881115894
-0.05169195568108881
-0.04246213655204361
-0.03290414929420665
-0.023027132883315104
-0.012840085860246172
-0.0023518638577649416
Comparing to the regular solver
sol2 = solve(prob, Tsit5(), saveat=sol.t)
plot(sol2, label = "Tsit5")
plot!(sol.t, sol.u, label = "NNODE")
Parameter estimation
using NeuralPDE
using OrdinaryDiffEq
using Lux
using Random
using OptimizationOptimJL
using LineSearches
using Plots
rng = Random.default_rng()
Random.seed!(rng, 0)
Random.TaskLocalRNG()
NNODE only supports out-of-place functions f(u, p, t)
"""
    lv(u, p, t)

Out-of-place Lotka–Volterra right-hand side.

`u` is destructured into the two states `(u₁, u₂)` and `p` into the four
parameters `(α, β, γ, δ)`; `t` is not used. Returns a freshly allocated
two-element vector `[du₁, du₂]`.
"""
function lv(u, p, t)
    u₁, u₂ = u
    α, β, γ, δ = p
    du₁ = α * u₁ - β * u₁ * u₂
    du₂ = δ * u₁ * u₂ - γ * u₂
    return [du₁, du₂]
end
# Output: lv (generic function with 1 method)

# Generate data
tspan = (0.0, 5.0)
u0 = [5.0, 5.0]
true_p = [1.5, 1.0, 3.0, 1.0]
prob = ODEProblem(lv, u0, tspan, true_p)
sol_data = solve(prob, Tsit5(), saveat = 0.01)
t_ = sol_data.t
u_ = Array(sol_data)
2×501 Matrix{Float64}:
5.0 4.82567 4.65308 4.48283 4.31543 … 1.01959 1.03094 1.04248
5.0 5.09656 5.18597 5.26791 5.34212 0.397663 0.389887 0.382307
Define a neural network
n = 15
chain = Chain(Dense(1, n, σ), Dense(n, n, σ), Dense(n, n, σ), Dense(n, 2))
ps, st = Lux.setup(rng, chain) |> Lux.f64((layer_1 = (weight = [-0.04929668828845024; -0.3266667425632477; … ; -1.3531280755996704; -0.2917589843273163;;], bias = [0.28568029403686523, -0.4209803342819214, -0.24613642692565918, -0.9429000616073608, -0.3618292808532715, 0.077278733253479, 0.9969245195388794, 0.7939795255661011, 0.45440757274627686, -0.4830443859100342, -0.6861011981964111, -0.3221019506454468, -0.5597391128540039, -0.15051674842834473, 0.9440881013870239]), layer_2 = (weight = [-0.08606009185314178 -0.2168799340724945 … -0.3507671356201172 0.07374405860900879; 0.24009406566619873 -0.23728196322917938 … 0.3494441509246826 -0.21207460761070251; … ; 0.3976286053657532 0.28444960713386536 … -0.32817623019218445 0.3963923156261444; -0.07926430553197861 0.35875919461250305 … -0.035931285470724106 -0.2851111590862274], bias = [-0.065037302672863, 0.18384626507759094, 0.17181798815727234, -0.17310386896133423, 0.06428726017475128, 0.09600061178207397, -0.08703552931547165, 0.06890828162431717, -0.16194558143615723, -0.14649711549282074, -0.14649459719657898, -0.04401325806975365, -0.015492657199501991, 0.1046019047498703, 0.15015578269958496]), layer_3 = (weight = [-0.2995997667312622 0.1492127627134323 … -0.011808237992227077 -0.3409591317176819; 0.4351722300052643 0.1286778748035431 … -0.20781199634075165 -0.030425487086176872; … ; -0.02206072397530079 0.1434853971004486 … -0.05763476714491844 -0.2672235369682312; 0.2975636124610901 -0.06781639903783798 … 0.4012162387371063 0.1212344542145729], bias = [0.04135546088218689, -0.2398381233215332, 0.1595604568719864, 0.08355490118265152, -0.06149742379784584, -0.06998120248317719, -0.008059235289692879, -0.10936713218688965, -0.18340998888015747, 0.06297893822193146, 0.04081515222787857, -0.04258332401514053, 0.11171907186508179, -0.21218737959861755, 0.07965957373380661]), layer_4 = (weight = [0.3909372091293335 -0.23473051190376282 … 0.07385867834091187 0.31727132201194763; -0.04396386072039604 
0.1817844808101654 … -0.26729491353034973 0.24492914974689484], bias = [0.04966225475072861, -0.04299044609069824])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple(), layer_4 = NamedTuple()))Loss function
"""
    additional_loss(phi, θ)

Data-misfit term passed to `NNODE`: the sum of squared differences between the
network prediction `phi(t_, θ)` and the data `u_`, divided by the number of
columns of `u_`.

Reads the file-level globals `t_` and `u_`.
"""
additional_loss(phi, θ) = sum(abs2, phi(t_, θ) .- u_) / size(u_, 2)
# Output: additional_loss (generic function with 1 method)

# NNODE solver
opt = LBFGS(linesearch = BackTracking())
alg = NNODE(chain, opt, ps; strategy = WeightedIntervalTraining([0.7, 0.2, 0.1], 500), param_estim = true, additional_loss)NeuralPDE.NNODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_4::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, Optim.LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.BackTracking{Float64, Int64}, Returns{Nothing}}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}, layer_2::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}, layer_3::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}, layer_4::@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}}}, Bool, NeuralPDE.WeightedIntervalTraining{Float64}, Bool, typeof(Main.var"##277".additional_loss), Vector{Any}, Base.Pairs{Symbol, Union{}, Nothing, @NamedTuple{}}}(Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(NNlib.σ), Int64, Int64, Nothing, Nothing, Static.True}, layer_4::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 15, σ), layer_2 = Dense(15 => 15, σ), layer_3 = Dense(15 => 15, σ), layer_4 = Dense(15 => 2)), nothing), Optim.LBFGS{Nothing, LineSearches.InitialStatic{Float64}, LineSearches.BackTracking{Float64, Int64}, Returns{Nothing}}(10, LineSearches.InitialStatic{Float64}(1.0, false), LineSearches.BackTracking{Float64, Int64}(0.0001, 0.5, 0.1, 1000, 3, Inf, nothing), nothing, Returns{Nothing}(nothing), Optim.Flat(), true), (layer_1 = (weight = [-0.04929668828845024; -0.3266667425632477; … ; -1.3531280755996704; -0.2917589843273163;;], bias = 
[0.28568029403686523, -0.4209803342819214, -0.24613642692565918, -0.9429000616073608, -0.3618292808532715, 0.077278733253479, 0.9969245195388794, 0.7939795255661011, 0.45440757274627686, -0.4830443859100342, -0.6861011981964111, -0.3221019506454468, -0.5597391128540039, -0.15051674842834473, 0.9440881013870239]), layer_2 = (weight = [-0.08606009185314178 -0.2168799340724945 … -0.3507671356201172 0.07374405860900879; 0.24009406566619873 -0.23728196322917938 … 0.3494441509246826 -0.21207460761070251; … ; 0.3976286053657532 0.28444960713386536 … -0.32817623019218445 0.3963923156261444; -0.07926430553197861 0.35875919461250305 … -0.035931285470724106 -0.2851111590862274], bias = [-0.065037302672863, 0.18384626507759094, 0.17181798815727234, -0.17310386896133423, 0.06428726017475128, 0.09600061178207397, -0.08703552931547165, 0.06890828162431717, -0.16194558143615723, -0.14649711549282074, -0.14649459719657898, -0.04401325806975365, -0.015492657199501991, 0.1046019047498703, 0.15015578269958496]), layer_3 = (weight = [-0.2995997667312622 0.1492127627134323 … -0.011808237992227077 -0.3409591317176819; 0.4351722300052643 0.1286778748035431 … -0.20781199634075165 -0.030425487086176872; … ; -0.02206072397530079 0.1434853971004486 … -0.05763476714491844 -0.2672235369682312; 0.2975636124610901 -0.06781639903783798 … 0.4012162387371063 0.1212344542145729], bias = [0.04135546088218689, -0.2398381233215332, 0.1595604568719864, 0.08355490118265152, -0.06149742379784584, -0.06998120248317719, -0.008059235289692879, -0.10936713218688965, -0.18340998888015747, 0.06297893822193146, 0.04081515222787857, -0.04258332401514053, 0.11171907186508179, -0.21218737959861755, 0.07965957373380661]), layer_4 = (weight = [0.3909372091293335 -0.23473051190376282 … 0.07385867834091187 0.31727132201194763; -0.04396386072039604 0.1817844808101654 … -0.26729491353034973 0.24492914974689484], bias = [0.04966225475072861, -0.04299044609069824])), false, true, 
NeuralPDE.WeightedIntervalTraining{Float64}([0.7, 0.2, 0.1], 500), true, Main.var"##277".additional_loss, Any[], false, Base.Pairs{Symbol, Union{}, Nothing, @NamedTuple{}}())Solve the problem
Set verbose = true for the fitting process
@time sol = solve(prob, alg, verbose = true, abstol = 1e-8, maxiters = 5000, saveat = t_)
retcode: Success
Interpolation: Trained neural network interpolation
t: 501-element Vector{Float64}:
0.0
0.01
0.02
0.03
0.04
0.05
0.06
0.07
0.08
0.09
⋮
4.92
4.93
4.94
4.95
4.96
4.97
4.98
4.99
5.0
u: 501-element Vector{Vector{Float64}}:
[5.0, 5.0]
[4.825514567298723, 5.0952013299678]
[4.653708188543102, 5.1835075718585415]
[4.4848434013700365, 5.264596948339136]
[4.319183810417474, 5.338196866582742]
[4.156989738913435, 5.40409050642591]
[3.9985137725719384, 5.462121251731317]
[3.8439963061895135, 5.512194920322765]
[3.693661211689003, 5.554279907009037]
[3.5477117550223385, 5.588405480876789]
⋮
[0.7069350525005849, 0.4132966912819649]
[0.70300831020357, 0.3996656742747229]
[0.6989673525132138, 0.3861682770149617]
[0.6948142871037986, 0.3728002282406928]
[0.6905512294107803, 0.35955738751303734]
[0.686180298972408, 0.34643574149927225]
[0.6817036160285923, 0.33343140034542085]
[0.6771232983641813, 0.3205405941374728]
[0.6724414583848493, 0.30775966944970623]
See the fitted parameters
println(sol.k.u.p)
[1.4871093675125466, 0.9939011081101585, 2.9935648111590876, 0.994850461329788]
Visualize the fit
plot(sol, labels = ["u1_pinn" "u2_pinn"])
plot!(sol_data, labels = ["u1_data" "u2_data"])
Bayesian inference for PINNs
https://
using NeuralPDE
using AdvancedHMC
using MCMCChains
using LogDensityProblems
using Lux
using Plots
using OrdinaryDiffEq
using Distributions
using Random
NNODE only supports out-of-place functions f(u, p, t)
"""
    lotka_volterra(u, p, t)

Out-of-place Lotka–Volterra right-hand side.

`p` is destructured into `(α, β, γ, δ)` and `u` into the prey/predator states
`(x, y)`; `t` is not used. Returns a freshly allocated vector `[dx, dy]`.
"""
function lotka_volterra(u, p, t)
    # Model parameters.
    α, β, γ, δ = p
    # Current state.
    x, y = u
    # Evaluate differential equations.
    dx = (α - β * y) * x  # prey
    dy = (δ * x - γ) * y  # predator
    return [dx, dy]
end
# Output: lotka_volterra (generic function with 1 method)

u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]
tspan = (0.0, 4.0)
prob = ODEProblem(lotka_volterra, u0, tspan, p)
dt = 0.01
solution = solve(prob, Tsit5(); saveat = dt)
retcode: Success
Interpolation: 1st order linear
t: 401-element Vector{Float64}:
0.0
0.01
0.02
0.03
0.04
0.05
0.06
0.07
0.08
0.09
⋮
3.92
3.93
3.94
3.95
3.96
3.97
3.98
3.99
4.0
u: 401-element Vector{Vector{Float64}}:
[1.0, 1.0]
[1.0051122697054304, 0.9802235489841001]
[1.0104482482084793, 0.9608884029133249]
[1.0160067852516195, 0.9419859539931972]
[1.0217868581271055, 0.9235077034160883]
[1.0277875716769742, 0.9054452613612181]
[1.0340081582930438, 0.887790346994655]
[1.040447977916915, 0.870534788469315]
[1.0471065141055134, 0.8536705240234226]
[1.0539833005834183, 0.837189627503347]
⋮
[1.7188602790655703, 0.35603128520071003]
[1.7386754231380195, 0.35153232204585927]
[1.7587969344457686, 0.3471593575148184]
[1.779227949981121, 0.34291022060855736]
[1.7999716537968578, 0.3387828301940905]
[1.821031277006269, 0.3347751950044705]
[1.8424100977831226, 0.3308854136387941]
[1.8641114413617017, 0.3271116745621956]
[1.886138680036769, 0.32345225610585315]
Dataset creation for parameter estimation (plus 30% noise)
# Build the noisy training set `[x, y, time]` from the reference solution.
time = solution.t
# Stack the saved states column-wise; `reduce(hcat, ...)` avoids splatting every
# saved state vector into a varargs call.
u = reduce(hcat, solution.u)
# Perturb each state with multiplicative Gaussian noise (factor 0.3). The rows
# are read through views so no intermediate row copies are made; the order of
# the two `randn` draws is kept (x first, then y).
x = @views u[1, :] .+ u[1, :] .* (0.3 .* randn(size(u, 2)))
y = @views u[2, :] .+ u[2, :] .* (0.3 .* randn(size(u, 2)))
dataset = [x, y, time]
# Plotting the data which will be used
plot(time, x, label = "noisy x")
plot!(time, y, label = "noisy y")
plot!(solution, labels = ["x" "y"])
Define a PINN neural network. The input is time, and the output is the state of the system (x and y).
chain = Chain(Dense(1, 6, tanh), Dense(6, 6, tanh), Dense(6, 2))
Chain(
    layer_1 = Dense(1 => 6, tanh), # 12 parameters
    layer_2 = Dense(6 => 6, tanh), # 42 parameters
    layer_3 = Dense(6 => 2), # 14 parameters
) # Total: 68 parameters,
# plus 0 states.
Use BNNODE for Bayesian inference. The parameters of the model are estimated with the dataset, and the uncertainty of the estimation is quantified with the posterior distribution.
alg = BNNODE(chain;
dataset = dataset,
draw_samples = 1000,
l2std = [0.1, 0.1],
phystd = [0.1, 0.1],
priorsNNw = (0.0, 3.0),
param = [
Normal(1, 2),
Normal(2, 2),
Normal(2, 2),
Normal(0, 2)],
progress = false)NeuralPDE.BNNODE{Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}, UnionAll, Nothing, Vector{Distributions.Normal{Float64}}, NeuralPDEBPINNExt.var"#27#28", Vector{Vector{Float64}}, @NamedTuple{n_leapfrog::Int64}, Nothing, @NamedTuple{Adaptor::UnionAll, Metric::UnionAll, targetacceptancerate::Float64}, @NamedTuple{Integrator::UnionAll}}(Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_3::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 6, tanh), layer_2 = Dense(6 => 6, tanh), layer_3 = Dense(6 => 2)), nothing), AdvancedHMC.HMC, nothing, 1000, (0.0, 3.0), Distributions.Normal{Float64}[Distributions.Normal{Float64}(μ=1.0, σ=2.0), Distributions.Normal{Float64}(μ=2.0, σ=2.0), Distributions.Normal{Float64}(μ=2.0, σ=2.0), Distributions.Normal{Float64}(μ=0.0, σ=2.0)], [0.1, 0.1], [0.1, 0.1], NeuralPDEBPINNExt.var"#27#28"(), [[1.2851868809809888, 0.5668624861813534, 0.8692827494388045, 1.2703239000381619, 1.0229898563484339, 1.0796127868752652, 1.1968103708358069, 0.9346133351093979, 0.8645468892557517, 1.3002304052818947 … 1.8937571685134968, 2.3234438970949496, 1.5517589163868453, 1.3835260874285784, 1.4423538083935412, 1.6155095842154217, 1.279029171368105, 2.6512401501432126, 2.2737619927092663, 1.8244650767845423], [0.55994853647493, 1.1253839806193784, 0.6505735514781006, 0.6572951817244799, 0.7125375754086148, 0.757175701983382, 1.1911208727563234, 0.9825034758755204, 0.6321579226703791, 1.025142029452657 … 0.5328221048670227, 0.2966955254330875, 0.23712570605530775, 0.3193652785414433, 0.437225868604619, 0.38881462990240134, 
0.39010767249893336, 0.17212185495970153, 0.35576426416301443, 0.40059798592471374], [0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09 … 3.91, 3.92, 3.93, 3.94, 3.95, 3.96, 3.97, 3.98, 3.99, 4.0]], 0.05, (n_leapfrog = 30,), 1, nothing, (Adaptor = AdvancedHMC.Adaptation.StanHMCAdaptor, Metric = AdvancedHMC.DiagEuclideanMetric, targetacceptancerate = 0.8), (Integrator = AdvancedHMC.Leapfrog,), 333, false, false, false, false)Solve the problem
@time sol_pestim = solve(prob, alg; saveat = dt)
sol_pestim.estimated_de_params
380.976346 seconds (306.32 M allocations: 1.105 TiB, 25.04% gc time, 4.27% compilation time)
4-element Vector{MonteCarloMeasurements.Particles{Float64, 334}}:
0.368 ± 0.12
0.129 ± 0.037
0.292 ± 0.17
0.224 ± 0.065
plot(time, sol_pestim.ensemblesol[1], label = "estimated x")
plot!(time, sol_pestim.ensemblesol[2], label = "estimated y")
plot!(solution, labels = ["true x" "true y"])
This notebook was generated using Literate.jl.