Solving ODEs with NeuralPDE.jl

From https://neuralpde.sciml.ai/dev/tutorials/ode/

As an example, we solve the ODE

\[ u^{\prime} = \cos(2 \pi t) \]
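
With the initial condition u(0) = 0 used below, this ODE has the closed-form solution

\[ u(t) = \frac{\sin(2 \pi t)}{2 \pi} \]

which serves as an exact reference for checking the trained network.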
using NeuralPDE
using Lux
using OptimizationOptimisers
using OrdinaryDiffEq
using LinearAlgebra
using Random
rng = Random.default_rng()
TaskLocalRNG()
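
The initial network weights drawn below depend on this RNG state, so loss values and final parameters will differ between runs. For reproducible results you can seed the RNG first (optional; the seed value 0 here is arbitrary):

Random.seed!(rng, 0)  # fix the RNG state so Lux.setup draws the same initial weights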

Define the true right-hand side of the ODE. The u and p arguments are unused here; cospi(2t) evaluates cos(2πt) accurately.

model(u, p, t) = cospi(2t)

tspan = (0.0f0, 1.0f0)
u0 = 0.0f0
prob = ODEProblem(model, u0, tspan)
ODEProblem with uType Float32 and tType Float32. In-place: false
timespan: (0.0f0, 1.0f0)
u0: 0.0f0
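
As a quick sanity check of the right-hand side, the derivative at t = 0 should be cos(0) = 1 (the u and p arguments are unused by this model):

model(u0, nothing, 0.0f0)  # returns 1.0f0, since cospi(0) = 1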

Construct a neural network to solve the problem.

chain = Lux.Chain(Lux.Dense(1, 5, tanh), Lux.Dense(5, 1))
p, st = Lux.setup(rng, chain)
((layer_1 = (weight = Float32[0.64618576; -0.3263662; … ; -0.9749539; -0.62931883;;], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[0.44560993 -0.75582933 … -0.36735737 -0.7745962], bias = Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
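
Before training, the chain can be evaluated directly as a sanity check: a Lux model is called with an input, the parameters, and the state, returning the output together with the updated state. The input 0.5f0 here is an arbitrary test point:

y, _ = chain([0.5f0], p, st)  # forward pass of the untrained network at t = 0.5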

We solve the ODE as before; the only change is the solver algorithm, NeuralPDE.NNODE, which takes the neural network and an optimizer.

optimizer = OptimizationOptimisers.Adam(0.1)
alg = NeuralPDE.NNODE(chain, optimizer)
sol = solve(prob, alg, verbose=true, abstol=1f-6, maxiters=300)
Current loss is: 0.5548269, Iteration: 1
Current loss is: 0.7455408, Iteration: 2
Current loss is: 0.48976097, Iteration: 3
Current loss is: 0.3507415, Iteration: 4
⋮
Current loss is: 9.21054e-5, Iteration: 299
Current loss is: 9.169603e-5, Iteration: 300
retcode: Success
Interpolation: Trained neural network interpolation
t: 0.0f0:0.01010101f0:1.0f0
u: 100-element Vector{Float32}:
  0.0
  0.008816629
  0.017586738
  0.026295807
  0.03492842
  0.043468114
  0.051897418
  0.060197696
  0.068349086
  0.07633047
  0.08411939
  0.09169207
  0.099023305
  ⋮
 -0.09626564
 -0.08953475
 -0.08246241
 -0.07505787
 -0.06733017
 -0.05928744
 -0.050938442
 -0.04229089
 -0.03335245
 -0.024130536
 -0.014632374
 -0.0048647225
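
Because the interpolant is the trained network itself, the solution can be queried at any time inside the training span, not only at the saved points (t = 0.5 here is an arbitrary query):

sol(0.5f0)  # evaluate the neural-network interpolation at t = 0.5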

Compare against a standard numerical solver (Tsit5), saving at the same time points:

sol2 = solve(prob, Tsit5(), abstol=1f-6, saveat=sol.t)

norm(sol.u .- sol2.u, Inf)
0.012175649f0
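
The maximum pointwise difference is about 1.2e-2, so the network matches the reference solution to roughly two decimal places. To inspect the agreement visually, one can plot both solutions against the exact one (a sketch assuming Plots.jl is installed):

using Plots
plot(sol2, label = "Tsit5", lw = 2)                  # reference solution
plot!(sol.t, sol.u, label = "NNODE", ls = :dash)     # neural-network solution
plot!(t -> sinpi(2t) / (2π), 0, 1, label = "exact")  # analytical solution sin(2πt)/(2π)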