JAX implementation of deep learning time series modeling of river discharge. This code includes methods for:
- training (single and ensemble)
- hyperparameter search
- finetuning
- inference
- evaluation
which can be fully automated using configuration files and/or arguments to the main interface, run.py.
To train and run these models on your machine, you can use uv to create a virtual environment and install dependencies from pyproject.toml. Open a terminal and navigate to the project directory.
Create a new virtual environment (venv) with the default name:
uv venv
Activate the environment (for mac/linux):
source .venv/bin/activate
Install the dependencies to your new venv in editable mode with dev packages:
pip install -e .[dev]
JAX with CUDA support will be installed automatically from pyproject.toml, but your system will still need the correct CUDA drivers.
You can check that JAX can locate your GPU using this command (while the virtual environment is activated):
python -c "import jax; print(jax.devices())"
Which should show something like [CudaDevice(id=0)] for a single GPU. If JAX can only find a CPU, it will print [CpuDevice(id=0)].
If you want to use the code with jupyter notebooks, install the new environment as an ipython kernel:
python -m ipykernel install --user --name=swot-ml
All options for dataset creation, model hyperparameters, training progress, logging, etc. can be configured in a yaml file. These options are not exhaustively documented, but the example config files show some typical uses. The full list of available options is documented only in the Config validation class.
You will need, at a minimum, two types of data to train the model, with up to two more depending on the model configuration:
- time series data: NetCDF (.nc) files containing time-varying data at a regular interval (time steps with missing observations must still be present, as NA), with a separate file for each site. Each dataset must contain a 'date' index. No particular frequency is required. However, even if you are modeling hourly features, the datetime index needs to be named 'date' unless you edit the hard-coded index in data/hydrodata.py.
- site lists: Text (.txt) files indicating which sites (basins, gauges, reaches, etc.) will be used for training and testing. Each file lists one site per line with no other delimiters.
- attributes (optional): a comma-separated value (.csv) file with static attributes for each site. Only required if you specify static attributes in your configuration.
- graph network (optional): A networkx directed graph in json format. The only data used from this graph are the edge definitions, although the node names (and thus the edge end points) need to be identical to the site names. Must contain all sites. Only required for certain models, and should only be specified in the configuration file if needed by the model.
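Before training, it can be useful to confirm that every site in a site list also appears as a node in the graph file. The sketch below is a standalone check, not part of the package. The function names are hypothetical, and it assumes the graph JSON follows the networkx node-link layout, with a top-level "nodes" list whose entries carry an "id" field (the "links" key for edges is likewise an assumption and may be named differently in your file).

```python
import json
import tempfile
from pathlib import Path


def read_site_list(path):
    """Return the site IDs in a site-list file, one per line, skipping blanks."""
    lines = Path(path).read_text().splitlines()
    return [line.strip() for line in lines if line.strip()]


def graph_node_ids(path):
    """Return the node IDs stored in a node-link style JSON graph (assumed layout)."""
    graph = json.loads(Path(path).read_text())
    return {str(node["id"]) for node in graph["nodes"]}


def missing_from_graph(sites, node_ids):
    """Sites that appear in the site list but have no node in the graph."""
    return sorted(set(sites) - set(node_ids))


# Tiny demonstration with made-up site IDs.
with tempfile.TemporaryDirectory() as tmp:
    site_file = Path(tmp) / "train.txt"
    site_file.write_text("site_a\nsite_b\n\nsite_c\n")

    graph_file = Path(tmp) / "graph.json"
    graph_file.write_text(json.dumps({
        "directed": True,
        "nodes": [{"id": "site_a"}, {"id": "site_b"}],
        "links": [{"source": "site_a", "target": "site_b"}],
    }))

    sites = read_site_list(site_file)
    missing = missing_from_graph(sites, graph_node_ids(graph_file))

print(sites)    # ['site_a', 'site_b', 'site_c']
print(missing)  # ['site_c']
```

An empty result from missing_from_graph means the graph covers every listed site.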
The recommended structure of the data directory is as follows, with some allowance for differences defined in the Config validator.
<data_dir>/
├── attributes/
│ └── attributes.csv
├── time_series/
│ ├── xxxx.nc
│ └── ... (one .nc file for each site)
├── metadata/
│ ├── site_lists/
│ │ ├── train.txt
│ │ └── test.txt
│ └── graph.json
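The layout above can be checked with a few lines of standard-library Python before starting a long training run. This is a sketch under the assumption that your data directory follows the recommended structure exactly; the function name check_data_dir and its flags are hypothetical and do not exist in the package.

```python
import tempfile
from pathlib import Path


def check_data_dir(data_dir, require_attributes=False, require_graph=False):
    """Return a list of problems found in a data directory (empty if none)."""
    data_dir = Path(data_dir)
    problems = []

    # At least one per-site NetCDF file is always needed.
    ts_dir = data_dir / "time_series"
    if not ts_dir.is_dir() or not any(ts_dir.glob("*.nc")):
        problems.append("no .nc files in time_series/")

    # Site lists for training and testing are always needed.
    for name in ("train.txt", "test.txt"):
        if not (data_dir / "metadata" / "site_lists" / name).is_file():
            problems.append(f"missing metadata/site_lists/{name}")

    # The remaining files depend on the model configuration.
    if require_attributes and not (data_dir / "attributes" / "attributes.csv").is_file():
        problems.append("missing attributes/attributes.csv")
    if require_graph and not (data_dir / "metadata" / "graph.json").is_file():
        problems.append("missing metadata/graph.json")

    return problems


# Demonstration on a throwaway directory containing only the required files.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "time_series").mkdir()
    (root / "time_series" / "xxxx.nc").touch()
    (root / "metadata" / "site_lists").mkdir(parents=True)
    (root / "metadata" / "site_lists" / "train.txt").write_text("xxxx\n")
    (root / "metadata" / "site_lists" / "test.txt").write_text("xxxx\n")

    minimal_problems = check_data_dir(root)
    graph_problems = check_data_dir(root, require_graph=True)

print(minimal_problems)
print(graph_problems)
```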
Once the environment, configuration file, and data files are in place, simply call:
uv run python ./src/run.py train <path_to_configuration_file.yml>
which will train a new model according to your config and then produce files with final predictions, error metrics, and some very basic error distribution figures.
Models are implemented in Equinox and have to be compatible with the sharp bits of jit'd JAX code. Each model is a subclass of the BaseModel class, which helps to standardize the models for integration with the training and inference code. The forward pass of each model assumes that the model is called through jax.vmap over the batch dimension of the data (i.e. the forward pass itself only operates on a single sequence).
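The single-sequence convention can be illustrated with plain JAX. The sketch below is not one of the package's models: the forward function, its parameters, and the array shapes are made up for illustration. It only shows how a function written for one [sequence_length, n_features] input is mapped over a batch with jax.vmap.

```python
import jax
import jax.numpy as jnp


def forward(params, x):
    """Forward pass for ONE sequence; x has shape [sequence_length, n_features]."""
    w, b = params
    # Average the features over time, then apply a linear head -> scalar.
    return jnp.mean(x, axis=0) @ w + b


# Hypothetical parameters and data.
params = (jnp.ones((3,)), jnp.array(0.5))
batch = jnp.ones((4, 10, 3))  # [batch_size, sequence_length, n_features]

# Share params across the batch (None) and map over axis 0 of the data.
batched_forward = jax.jit(jax.vmap(forward, in_axes=(None, 0)))
preds = batched_forward(params, batch)

print(preds.shape)  # (4,)
```

Writing the forward pass this way keeps the model code free of batch-axis bookkeeping; the batch dimension only appears at the call site.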
To integrate a new model with this package:
- Write the initialization and forward pass logic as a subclass of the BaseModel.
- Add your model to the model.make() function
- Define the required arguments as a new class within config/model_args.py
- (Optionally) define arguments to be set based on the loaded dataset within the set_model_data_args() function
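The steps above follow a common factory pattern, sketched below in plain Python. Everything here is a stand-in: the real BaseModel is an Equinox module, and the actual signatures of make() and set_model_data_args() are not documented here, so the ones shown are guesses. MyModel, MyModelArgs, and MODEL_REGISTRY are hypothetical names.

```python
from dataclasses import dataclass


class BaseModel:
    """Stand-in for the package's BaseModel (the real one is an Equinox module)."""

    def __init__(self, args):
        self.args = args

    def __call__(self, x):
        raise NotImplementedError


@dataclass
class MyModelArgs:
    """Hypothetical argument class, analogous to one in config/model_args.py."""
    hidden_size: int = 64
    n_features: int = 0  # filled in from the dataset, not from the config file


class MyModel(BaseModel):
    def __call__(self, x):
        # Placeholder forward pass for a single sequence: the sum of the inputs.
        return sum(x)


# Maps the model name used in a config to its class and argument class.
MODEL_REGISTRY = {"my_model": (MyModel, MyModelArgs)}


def set_model_data_args(args, n_features):
    """Set arguments that are only known once the dataset is loaded (assumed signature)."""
    args.n_features = n_features
    return args


def make(name, n_features, **config_args):
    """Build a model by name (assumed signature, for illustration only)."""
    model_cls, args_cls = MODEL_REGISTRY[name]
    args = set_model_data_args(args_cls(**config_args), n_features)
    return model_cls(args)


model = make("my_model", n_features=5, hidden_size=32)
print(type(model).__name__, model.args)
```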
The HydroDataset and HydroDataLoader classes are implementations of PyTorch datasets and dataloaders. The HydroDataset class does several jobs to prepare the data:
- reads in list of training and testing basins
- reads in static basin attributes
- reads in dynamic timeseries data (compiled from per-basin netcdf files)
- normalizes the data based on the training subset
- encodes categorical or bitmask features (as defined in the yml)
- (optionally) caches compiled dynamic data based on the data configuration
- batches data depending on model config, i.e. time series data are shaped [batch_size, sequence_length, n_features] for lumped modeling or [batch_size, sequence_length, n_locations, n_features] for distributed modeling.
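Two of the jobs listed above, normalizing with training-subset statistics and shaping lumped batches, can be sketched with NumPy. This is not the HydroDataset implementation: the site names, array sizes, and the make_windows helper are made up, and the real class may window and batch differently.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data: 3 sites, 20 time steps, 2 dynamic features each.
data = {site: rng.normal(loc=5.0, scale=2.0, size=(20, 2)) for site in ["a", "b", "c"]}
train_sites = ["a", "b"]  # site "c" is held out for testing

# Statistics come from the training sites only, so no test data leaks in.
train_stack = np.concatenate([data[s] for s in train_sites], axis=0)
mean = np.nanmean(train_stack, axis=0)
std = np.nanstd(train_stack, axis=0)

# The same training statistics are applied to every site, including test sites.
normalized = {s: (x - mean) / std for s, x in data.items()}


def make_windows(x, sequence_length):
    """Slice a [time, n_features] array into overlapping sequences."""
    n_windows = x.shape[0] - sequence_length + 1
    return np.stack([x[i:i + sequence_length] for i in range(n_windows)])


windows = make_windows(normalized["c"], sequence_length=5)
batch = windows[:4]

print(batch.shape)  # (4, 5, 2) -> [batch_size, sequence_length, n_features]
```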
The HydroDataLoader is mostly a thin wrapper on the PyTorch DataLoader. The other side of 'mostly' is that the HydroDataLoader also sets the sharding and device settings of each batch to ensure JAX takes full advantage of GPU(s).
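As an illustration of what placing a batch with a sharding can look like, the sketch below splits the batch dimension across all visible devices using jax.sharding. It is an assumption about the general approach, not a copy of what HydroDataLoader does; the mesh axis name "batch" and the array shape are made up. On a machine with a single device it still runs, with one shard holding the whole batch.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over every device JAX can see (GPUs, or a single CPU).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("batch",))

# Split the leading (batch) axis across the mesh; other axes are replicated.
sharding = NamedSharding(mesh, PartitionSpec("batch"))

# The batch size must be divisible by the number of devices.
n_dev = len(devices)
host_batch = np.zeros((4 * n_dev, 10, 3), dtype=np.float32)

# Move the host array onto the devices with the requested layout.
device_batch = jax.device_put(host_batch, sharding)

print(device_batch.shape)
print(len(device_batch.addressable_shards))
```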
The Trainer class steps through the dataloader, making predictions, calculating loss, and updating the model. There are also methods for logging training progress to a file, monitoring gradients, and saving and loading model states.
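The predict, loss, update cycle can be written as a single jit-compiled step in JAX. The sketch below is a minimal stand-in for that loop, not the Trainer itself: it uses a linear model, mean squared error, and plain gradient descent, all of which are assumptions chosen to keep the example short.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # arbitrary value for this toy problem


def loss_fn(params, x, y):
    """Mean squared error of a linear model (stand-in for a real model and loss)."""
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(params, x, y):
    """One update: compute loss and gradients, then take a gradient descent step."""
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Toy data following y = 2x + 1.
x = jnp.linspace(0.0, 1.0, 8).reshape(8, 1)
y = 2.0 * x[:, 0] + 1.0

params = {"w": jnp.zeros((1,)), "b": jnp.array(0.0)}
losses = []
for _ in range(200):
    params, loss = train_step(params, x, y)
    losses.append(float(loss))

print(losses[0], losses[-1])
```

Returning the loss from the step function makes it straightforward to log training progress without a second forward pass.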
This module contains methods for making predictions, calculating error statistics, and making plots. The prediction methods are similar to those used for training, except there is no requirement of target data, and they are significantly faster since there are no loss or gradient calculations. The outputs are also collected and saved. The error metrics are defined in metrics.py for both bulk sample statistics and per-basin statistics. The plots are, admittedly, poorly implemented as they rely on some hardcoded ranges and lists of metrics that were useful when written for a manuscript.
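To make the distinction between bulk and per-basin statistics concrete, the sketch below computes the Nash-Sutcliffe efficiency (NSE), a standard hydrology score, both ways. Whether metrics.py includes NSE, and how it handles missing values, is an assumption; the function names and data are made up.

```python
import numpy as np


def nse(obs, sim):
    """Nash-Sutcliffe efficiency, ignoring time steps where either value is NaN."""
    obs = np.asarray(obs, dtype=float)
    sim = np.asarray(sim, dtype=float)
    valid = ~np.isnan(obs) & ~np.isnan(sim)
    obs, sim = obs[valid], sim[valid]
    return 1.0 - np.sum((obs - sim) ** 2) / np.sum((obs - obs.mean()) ** 2)


def per_basin_nse(obs_by_basin, sim_by_basin):
    """One score per basin, so a few large basins cannot dominate the summary."""
    return {b: nse(obs_by_basin[b], sim_by_basin[b]) for b in obs_by_basin}


# Made-up observations and predictions for two basins.
obs = {
    "a": np.array([1.0, 2.0, 3.0, np.nan]),  # last observation is missing
    "b": np.array([2.0, 4.0, 6.0, 8.0]),
}
sim = {
    "a": np.array([1.0, 2.0, 3.0, 5.0]),  # perfect where observed
    "b": np.full(4, 5.0),                 # always predicts the observed mean
}

scores = per_basin_nse(obs, sim)

# Bulk statistic: pool every sample from every basin into one score.
bulk = nse(np.concatenate(list(obs.values())), np.concatenate(list(sim.values())))

print(scores)
print(bulk)
```

Here basin "a" scores 1 (a perfect match on the observed steps) and basin "b" scores 0 (no better than predicting the mean), while the bulk score blends the two.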