Skip to content

Add wave equation example (PDE via method of lines)#279

Open
gpartin wants to merge 2 commits intortqichen:masterfrom
gpartin:examples/wave-equation
Open

Add wave equation example (PDE via method of lines)#279
gpartin wants to merge 2 commits intortqichen:masterfrom
gpartin:examples/wave-equation

Conversation

@gpartin
Copy link
Copy Markdown

@gpartin gpartin commented Mar 12, 2026

Summary

Adds a new example demonstrating how to solve the 1D wave equation using torchdiffeq. This fills a gap in the examples directory — currently all examples are ODE systems, but torchdiffeq is equally useful for solving PDEs via the method of lines: discretize in space, integrate in time with odeint.

What the example does

The wave equation u_tt = c^2 * u_xx is converted to a first-order ODE system:

  • du/dt = v
  • dv/dt = c^2 * Laplacian(u)

where the spatial Laplacian uses second-order finite differences with periodic boundary conditions.

Two modes of operation

Mode Command Description
Forward solve python wave_equation.py --viz Solve the wave equation and visualize space-time heatmap, snapshots, and energy conservation
Neural ODE training python wave_equation.py --train --viz Train a neural network to learn wave dynamics from data

Results

  • Forward solve: 128-point grid, dopri5 adaptive solver, GPU — completes in ~3s with 0.0000% energy drift
  • Training: Neural ODE learns short-horizon wave dynamics from batched trajectory data
  • Supports --adjoint for O(1)-memory backpropagation
  • Supports --method (dopri5, adams, rk4), --gpu, --n_grid, --c (wave speed)

Files changed

  • examples/wave_equation.py — New example (305 lines)
  • examples/README.md — Added Wave Equation section

Notes

  • Follows the same code style and argument pattern as ode_demo.py
  • Uses nn.Module for both physics and neural dynamics (compatible with odeint_adjoint)
  • The discrete energy (kinetic + gradient, consistent with the finite-difference Laplacian) is conserved to machine precision by the adaptive solver

Copilot AI review requested due to automatic review settings March 12, 2026 02:38
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds a new torchdiffeq example demonstrating solving a 1D PDE (wave equation) via the method of lines (spatial finite differences + odeint time integration), plus documentation in the examples README.

Changes:

  • Added examples/wave_equation.py forward-solve + visualization + optional Neural ODE training loop.
  • Documented how to run the new example in examples/README.md.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
examples/wave_equation.py New wave-equation method-of-lines example with optional Neural ODE training and visualization utilities.
examples/README.md Added a section describing how to run/visualize/train the new wave equation example.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

You can also share your feedback on Copilot code review. Take the survey.

Comment thread examples/wave_equation.py Outdated
"""Two Gaussian pulses traveling in opposite directions."""
L = 2.0 * np.pi
u0 = torch.exp(-40.0 * (x - L / 3.0) ** 2) + 0.5 * torch.exp(-40.0 * (x - 2.0 * L / 3.0) ** 2)
v0 = torch.zeros(n_grid)
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

initial_condition creates u0 from x (float64 in generate_data) but creates v0 with torch.zeros(n_grid) (default float32). torch.cat([u0, v0]) requires matching dtypes and will raise an error. Create v0 with the same dtype/device as u0 (e.g., using torch.zeros_like(u0) or specifying dtype=u0.dtype and device=u0.device).

Suggested change
v0 = torch.zeros(n_grid)
v0 = torch.zeros(n_grid, dtype=u0.dtype, device=u0.device)

Copilot uses AI. Check for mistakes.
Match v0 dtype and device to u0 to avoid errors when u0 is
float64 or on GPU. Addresses Copilot review suggestion.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants