
neural-lam's Introduction


Neural-LAM is a repository of graph-based neural weather prediction models for Limited Area Modeling (LAM). The code uses PyTorch and PyTorch Lightning. Graph Neural Networks are implemented using PyG and logging is set up through Weights & Biases.

The repository contains LAM versions of:

For more information see our paper: Graph-based Neural Weather Prediction for Limited Area Modeling. If you use Neural-LAM in your work, please cite:

@inproceedings{oskarsson2023graphbased,
    title={Graph-based Neural Weather Prediction for Limited Area Modeling},
    author={Oskarsson, Joel and Landelius, Tomas and Lindsten, Fredrik},
    booktitle={NeurIPS 2023 Workshop on Tackling Climate Change with Machine Learning},
    year={2023}
}

As the code in the repository is continuously evolving, the latest version might feature some small differences from what was used in the paper. See the branch ccai_paper_2023 for a revision of the code that reproduces the workshop paper.

We plan to continue updating this repository as we improve existing models and develop new ones. Collaborations around this implementation are very welcome. If you are working with Neural-LAM feel free to get in touch and/or submit pull requests to the repository.

Modularity

The Neural-LAM code is designed to modularize the different components involved in training and evaluating neural weather prediction models. Models, graphs and data are stored separately and it should be possible to swap out individual components. Still, some restrictions are inevitable:

  • The graph used has to be compatible with what the model expects. E.g. a hierarchical model requires a hierarchical graph.
  • The graph and data are specific to the limited area under consideration. This is of course true for the data, but also the graph should be created with the exact geometry of the area in mind.

A note on the limited area setting

Currently we are using these models on a limited area covering the Nordic region, the so-called MEPS area (see paper). There are still some parts of the code that are quite specific to the MEPS area use case. This is in particular true for the mesh graph creation (create_mesh.py) and some of the constants set in a data_config.yaml file (path specified via the --data_config flag of train_model.py). If there is interest in using Neural-LAM for other areas it is not a substantial undertaking to refactor the code to be fully area-agnostic. We would be happy to support such enhancements. See the issues #2, #3 and #4 for some initial ideas on how this could be done.

Using Neural-LAM

Below follows instructions on how to use Neural-LAM to train and evaluate models.

Installation

Follow the steps below to create the necessary python environment.

  1. Install GEOS for your system. For example with sudo apt-get install libgeos-dev. This is necessary for the Cartopy requirement.
  2. Use Python 3.9.
  3. Install version 2.0.1 of PyTorch. Follow instructions on the PyTorch webpage for how to set this up with GPU support on your system.
  4. Install required packages specified in requirements.txt.
  5. Install PyTorch Geometric version 2.3.1 (matching the versions pinned below). This can be done by running
TORCH="2.0.1"
CUDA="cu117"

pip install pyg-lib==0.2.0 torch-scatter==2.1.1 torch-sparse==0.6.17 torch-cluster==1.6.1\
    torch-geometric==2.3.1 -f https://pytorch-geometric.com/whl/torch-${TORCH}+${CUDA}.html

You will have to adjust the CUDA variable to match the CUDA version on your system (or set it to cpu to run on CPU). See the installation webpage for more information.

Data

Datasets should be stored in a directory called data. See the repository format section for details on the directory structure.

The full MEPS dataset can be shared with other researchers on request, contact us for this. A tiny subset of the data (named meps_example) is available in example_data.zip, which can be downloaded from here. Download the file and unzip in the neural-lam directory. All graphs used in the paper are also available for download at the same link (but can as easily be re-generated using create_mesh.py). Note that this is far too little data to train any useful models, but all scripts can be run with it. It should thus be useful to make sure that your python environment is set up correctly and that all the code can be run without any issues.

Pre-processing

An overview of how the different scripts and files depend on each other is given in this figure:

In order to start training models at least three pre-processing scripts have to be run:
  • create_mesh.py
  • create_grid_features.py
  • create_parameter_weights.py

Create graph

Run create_mesh.py with suitable options to generate the graph you want to use (see python create_mesh.py --help for a list of options). The graphs used for the different models in the paper can be created as:

  • GC-LAM: python create_mesh.py --graph multiscale
  • Hi-LAM: python create_mesh.py --graph hierarchical --hierarchical 1 (also works for Hi-LAM-Parallel)
  • L1-LAM: python create_mesh.py --graph 1level --levels 1

The graph-related files are stored in a directory called graphs.

Create remaining static features

To create the remaining static files run the scripts create_grid_features.py and create_parameter_weights.py.

Weights & Biases Integration

The project is fully integrated with Weights & Biases (W&B) for logging and visualization, but can just as easily be used without it. When W&B is used, training configuration, training/test statistics and plots are sent to the W&B servers and made available in an interactive web interface. If W&B is turned off, logging instead saves everything locally to a directory like wandb/dryrun.... The W&B project name is set to neural-lam, but this can be changed in the flags of train_model.py (using argparse). See the W&B documentation for details.

If you would like to login and use W&B, run:

wandb login

If you would like to turn off W&B and just log things locally, run:

wandb off

Train Models

Models can be trained using train_model.py. Run python train_model.py --help for a full list of training options. A few of the key ones are outlined below:

  • --dataset: Which data to train on
  • --model: Which model to train
  • --graph: Which graph to use with the model
  • --processor_layers: Number of GNN layers to use in the processing part of the model
  • --ar_steps: Number of time steps to unroll for when making predictions and computing the loss

Checkpoints of trained models are stored in the saved_models directory. The implemented models are:

Graph-LAM

This is the basic graph-based LAM model. The encode-process-decode framework is used with a mesh graph in order to make one-step predictions. This model class is used both for the L1-LAM and GC-LAM models from the paper, only with different graphs.

To train L1-LAM use

python train_model.py --model graph_lam --graph 1level ...

To train GC-LAM use

python train_model.py --model graph_lam --graph multiscale ...

Hi-LAM

A version of Graph-LAM that uses a hierarchical mesh graph and performs sequential message passing through the hierarchy during processing.

To train Hi-LAM use

python train_model.py --model hi_lam --graph hierarchical ...

Hi-LAM-Parallel

A version of Hi-LAM where all message passing in the hierarchical mesh (up, down, inter-level) is run in parallel. Not included in the paper as initial experiments showed worse results than Hi-LAM, but could be interesting to try in more settings.

To train Hi-LAM-Parallel use

python train_model.py --model hi_lam_parallel --graph hierarchical ...

Checkpoint files for our models trained on the MEPS data are available upon request.

Evaluate Models

Evaluation is also done using train_model.py, but using the --eval option. Use --eval val to evaluate the model on the validation set and --eval test to evaluate on test data. Most of the training options are also relevant for evaluation (not --ar_steps; evaluation always unrolls full forecasts). Some options specifically important for evaluation are:

  • --load: Path to model checkpoint file (.ckpt) to load parameters from
  • --n_example_pred: Number of example predictions to plot during evaluation.

Note: While it is technically possible to use multiple GPUs for running evaluation, this is strongly discouraged. If using multiple devices the DistributedSampler will replicate some samples to make sure all devices have the same batch size, meaning that evaluation metrics will be unreliable. This issue stems from PyTorch Lightning. See for example this draft PR for more discussion and ongoing work to remedy this.

Repository Structure

Except for training and pre-processing scripts all the source code can be found in the neural_lam directory. Model classes, including abstract base classes, are located in neural_lam/models.

Format of data directory

It is possible to store multiple datasets in the data directory. Each dataset contains a set of files with static features and a set of samples. The samples are split into different sub-directories for training, validation and testing. The directory structure is shown with examples below. Script names within parenthesis denote the script used to generate the file.

data
├── dataset1
│   ├── samples                             - Directory with data samples
│   │   ├── train                           - Training data
│   │   │   ├── nwp_2022040100_mbr000.npy  - A time series sample
│   │   │   ├── nwp_2022040100_mbr001.npy
│   │   │   ├── ...
│   │   │   ├── nwp_2022043012_mbr001.npy
│   │   │   ├── nwp_toa_downwelling_shortwave_flux_2022040100.npy   - Solar flux forcing
│   │   │   ├── nwp_toa_downwelling_shortwave_flux_2022040112.npy
│   │   │   ├── ...
│   │   │   ├── nwp_toa_downwelling_shortwave_flux_2022043012.npy
│   │   │   ├── wtr_2022040100.npy          - Open water features for one sample
│   │   │   ├── wtr_2022040112.npy
│   │   │   ├── ...
│   │   │   └── wtr_2022043012.npy
│   │   ├── val                             - Validation data
│   │   └── test                            - Test data
│   └── static                              - Directory with graph information and static features
│       ├── nwp_xy.npy                      - Coordinates of grid nodes (part of dataset)
│       ├── surface_geopotential.npy        - Geopotential at surface of grid nodes (part of dataset)
│       ├── border_mask.npy                 - Mask with True for grid nodes that are part of border (part of dataset)
│       ├── grid_features.pt                - Static features of grid nodes (create_grid_features.py)
│       ├── parameter_mean.pt               - Means of state parameters (create_parameter_weights.py)
│       ├── parameter_std.pt                - Std.-dev. of state parameters (create_parameter_weights.py)
│       ├── diff_mean.pt                    - Means of one-step differences (create_parameter_weights.py)
│       ├── diff_std.pt                     - Std.-dev. of one-step differences (create_parameter_weights.py)
│       ├── flux_stats.pt                   - Mean and std.-dev. of solar flux forcing (create_parameter_weights.py)
│       └── parameter_weights.npy           - Loss weights for different state parameters (create_parameter_weights.py)
├── dataset2
├── ...
└── datasetN
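
For reference, the static files above can be loaded with numpy and torch. A minimal sketch, assuming the meps_example dataset has been unpacked under data/:

import numpy as np
import torch

static_dir = "data/meps_example/static"

grid_xy = np.load(f"{static_dir}/nwp_xy.npy")                   # grid node coordinates
border_mask = np.load(f"{static_dir}/border_mask.npy")          # True for border nodes
parameter_mean = torch.load(f"{static_dir}/parameter_mean.pt")  # per-variable means
parameter_std = torch.load(f"{static_dir}/parameter_std.pt")    # per-variable std.-dev.

print(grid_xy.shape, border_mask.shape, parameter_mean.shape)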

Format of graph directory

The graphs directory contains generated graph structures that can be used by different graph-based models. The structure is shown with examples below:

graphs
├── graph1                                  - Directory with a graph definition
│   ├── m2m_edge_index.pt                   - Edges in mesh graph (create_mesh.py)
│   ├── g2m_edge_index.pt                   - Edges from grid to mesh (create_mesh.py)
│   ├── m2g_edge_index.pt                   - Edges from mesh to grid (create_mesh.py)
│   ├── m2m_features.pt                     - Static features of mesh edges (create_mesh.py)
│   ├── g2m_features.pt                     - Static features of grid to mesh edges (create_mesh.py)
│   ├── m2g_features.pt                     - Static features of mesh to grid edges (create_mesh.py)
│   └── mesh_features.pt                    - Static features of mesh nodes (create_mesh.py)
├── graph2
├── ...
└── graphN

Mesh hierarchy format

To keep track of levels in the mesh graph, a list format is used for the files with mesh graph information. In particular, the files

│   ├── m2m_edge_index.pt                   - Edges in mesh graph (create_mesh.py)
│   ├── m2m_features.pt                     - Static features of mesh edges (create_mesh.py)
│   ├── mesh_features.pt                    - Static features of mesh nodes (create_mesh.py)

all contain lists of length L, for a hierarchical mesh graph with L levels. For non-hierarchical graphs L == 1 and these are all just single-entry lists. Each entry in the list contains the corresponding edge set or features of that level. Note that the first level (index 0 in these lists) corresponds to the lowest level in the hierarchy.

In addition, hierarchical mesh graphs (L > 1) feature a few additional files with static data:

├── graph1
│   ├── ...
│   ├── mesh_down_edge_index.pt             - Downward edges in mesh graph (create_mesh.py)
│   ├── mesh_up_edge_index.pt               - Upward edges in mesh graph (create_mesh.py)
│   ├── mesh_down_features.pt               - Static features of downward mesh edges (create_mesh.py)
│   ├── mesh_up_features.pt                 - Static features of upward mesh edges (create_mesh.py)
│   ├── ...

These files have the same list format as the ones above, but each list has length L-1 (as these edges describe connections between levels). Entry 0 in these lists describes edges between the two lowest levels (1 and 2).
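
As a minimal sketch of the list format, assuming a hierarchical graph stored under graphs/hierarchical and the usual PyG convention that each edge index tensor has shape (2, num_edges):

import torch

m2m_edge_index = torch.load("graphs/hierarchical/m2m_edge_index.pt")          # list of length L
mesh_up_edge_index = torch.load("graphs/hierarchical/mesh_up_edge_index.pt")  # list of length L-1

num_levels = len(m2m_edge_index)
assert len(mesh_up_edge_index) == num_levels - 1

for level, edge_index in enumerate(m2m_edge_index):
    print(f"level {level}: {edge_index.shape[1]} intra-level mesh edges")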

Development and Contributing

Any push or Pull-Request to the main branch will trigger a selection of pre-commit hooks. These hooks will run a series of checks on the code, like formatting and linting. If any of these checks fail the push or PR will be rejected. To test whether your code passes these checks before pushing, run

pre-commit run --all-files

from the root directory of the repository.

Furthermore, all tests in the tests directory will be run upon pushing changes by a github action. Failure in any of the tests will also reject the push/PR.

Contact

If you are interested in machine learning models for LAM, have questions about our implementation or ideas for extending it, feel free to get in touch. You can open a github issue on this page, or (if more suitable) send an email to [email protected].

neural-lam's People

Contributors

joeloskarsson, leifdenby, sadamov, simonkamuk


neural-lam's Issues

Merge Graph-EFM model from `prob_model_lam` branch

This issue is to keep track of merging of the prob_model_lam branch into main. This branch contains the Graph-EFM model from Oskarsson et al. (2024) for LAM.

The main changes in this branch include:

I think it is best to do this merge (or at least the addition of the Graph-EFM model) after #49 is done, so the new model can fit into the new class hierarchy.

Creation of Static and Forcing Fields

Summary

For international collaboration a selection of state variables, as well as static and forcing features, was made (see here for reference: Sheet). While the variables are provided by the user in a zarr archive (usually NWP model output), the additional forcing and static features are often not available. I suggest providing the user with the option to generate these fields using a new script, create_forcings.py. The script works for state and boundary domains alike; simple visualizations are provided to make this issue easier to understand. Please feel free to improve on any part of the script and help writing PRs 🚀

New fields

  • datetime forcings,
  • land-sea masks,
  • boundary masks,
  • top-of-atmosphere (TOA) radiation forcings.

Questions / Issues:

  • Does such a script belong in the Neural-LAM repo or in a pre-processing step?
  • Should it be one or multiple scripts?
  • I have a script version for COSMO and MEPS data - should become domain agnostic
  • See below for questions specific to each field
  • Currently the fields are stored in one .zarr archive. It probably makes sense to specify the zarr archive to which each field should be appended.

Function Details and Rationale

calculate_datetime_forcing(ds, args)

This function calculates the datetime forcing for the neural LAM model by generating sinusoidal representations of the hour of the day and the time of the year. These features help the model understand diurnal and seasonal variations. Previously these were part of the PyTorch Dataset, provided in __get_item__; it certainly makes more sense to calculate these in advance, should be more efficient and allows for cleaner Dataset code.
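
A minimal sketch of such an encoding, assuming an xarray Dataset ds with a time coordinate (the output variable names are illustrative only):

import numpy as np
import pandas as pd
import xarray as xr


def calculate_datetime_forcing(ds):
    """Sin/cos encodings of hour of day and day of year for each time step."""
    time = pd.to_datetime(ds.time.values)
    hour_angle = 2 * np.pi * np.asarray(time.hour) / 24
    year_angle = 2 * np.pi * (np.asarray(time.dayofyear) - 1) / 365.25
    return xr.Dataset(
        {
            "hour_sin": ("time", np.sin(hour_angle)),
            "hour_cos": ("time", np.cos(hour_angle)),
            "year_sin": ("time", np.sin(year_angle)),
            "year_cos": ("time", np.cos(year_angle)),
        },
        coords={"time": ds.time},
    )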


generate_land_sea_mask(xy_state, tempdir, projection, high_res_factor=10)

This function generates a land-sea mask for the neural LAM model. The shapefile for the land-sea mask comes from Natural Earth at 50m resolution. The percentage of each grid cell covered by land is calculated by rasterization of the high-resolution shapefile. This generates a static field, although for some regions the mask would actually vary in time as parts of the sea freeze.
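
A minimal sketch of the rasterization idea, assuming the grid extent is known in lon/lat and the Natural Earth 50m land shapefile has been downloaded; the paths, extents and high_res_factor below are illustrative only:

import geopandas as gpd
from rasterio import features
from rasterio.transform import from_bounds

land = gpd.read_file("ne_50m_land.shp")  # Natural Earth 50m land polygons

ny, nx, high_res_factor = 268, 238, 10
lon_min, lon_max, lat_min, lat_max = 0.0, 40.0, 50.0, 75.0
height, width = ny * high_res_factor, nx * high_res_factor

transform = from_bounds(lon_min, lat_min, lon_max, lat_max, width, height)
land_hr = features.rasterize(
    land.geometry, out_shape=(height, width), transform=transform, fill=0
)

# Fraction of each coarse grid cell covered by land (block average over the
# high-resolution rasterization)
land_fraction = land_hr.reshape(ny, high_res_factor, nx, high_res_factor).mean(axis=(1, 3))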

create_boundary_mask(lonlat_state, lonlat_boundary, boundary_thickness, overlap=False)

This function creates a boundary mask for the neural LAM model. The boundary mask defines the edges of the model's domain and helps in handling boundary conditions, which are essential for limited area modeling. It allows for overlapping (Flatbread) or non-overlapping (Donut) boundaries. The borders of the interior state domain are identified using anemoi-datasets functionality for lat-lon matching of boundary and state. The thickness of the boundary is applied using binary_dilation, which is a terrible approach -> it should be calculated in meters in the projection of the state.
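
A minimal sketch of the dilation-based thickness described above (the lat-lon matching itself is omitted; the interior mask here is just a dummy array):

import numpy as np
from scipy.ndimage import binary_dilation

# Hypothetical boolean mask, True where the interior (state) domain sits
# within the larger boundary grid
interior_mask = np.zeros((268, 238), dtype=bool)
interior_mask[60:200, 50:190] = True

boundary_thickness = 10  # in grid cells, not meters (hence the criticism above)
dilated = binary_dilation(interior_mask, iterations=boundary_thickness)
boundary_mask = dilated & ~interior_mask  # the ring of boundary cells around the interior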


generate_toa_radiation_forcing(ds, lonlat)

This function pre-computes all static features related to the grid nodes for top-of-atmosphere (TOA) radiation. This functionality is provided by the graphcast repository. It is however quite slow for large domains and again depends on lat lon. Maybe there are better options.


Usage

The example script is based on MEPS (since we all have that data) and ERA5, directly from the Google WeatherBench cloud storage. To run it, store the script at the top level of the neural-lam repo, with meps_example in the data folder:

pip install gcsfs 
pip install rasterio
pip install git+https://github.com/google-deepmind/graphcast.git
python create_forcings.py

Use logging module instead of print statements

Hi Everyone,

What is your opinion on using the logging module instead of print statements for communicating with the user?
This might become important if the code were to be used in an operational setting down the line.

I guess for the training, some of the logging is done by wandb, but it could be useful for logging error messages when loading datasets etc.

Kind regards,
Michiel

Designing a `Datastore` class (for handling different storage backends for datasets)

I am in the process of reading through @sadamov's PR #54 on using zarr-based datasets in neural-lam and I am going to use this issue to write down some notes. Everyone is free to read along, but this will only become a coherent piece of information over time, so it is probably best to wait until I comment directly on #54.

Uses of current neural_lam.config.Config attributes and methods outside the class itself:

$> grep -r 'config\.' *.py       
calculate_statistics.py:        default="neural_lam/data_config.yaml",
calculate_statistics.py:        help="Path to data config file (default: neural_lam/data_config.yaml)",
calculate_statistics.py:    data_config = config.Config.from_file(args.data_config)
calculate_statistics.py:    state_data = data_config.process_dataset("state", split="train")
calculate_statistics.py:    forcing_data = data_config.process_dataset(
create_boundary_mask.py:        default="neural_lam/data_config.yaml",
create_boundary_mask.py:        help="Path to data config file (default: neural_lam/data_config.yaml)",
create_boundary_mask.py:    data_config = config.Config.from_file(args.data_config)
create_boundary_mask.py:    mask = np.zeros(list(data_config.grid_shape_state.values.values()))
create_forcings.py:        "--data_config", type=str, default="neural_lam/data_config.yaml"
create_forcings.py:    data_config = config.Config.from_file(args.data_config)
create_forcings.py:    dataset = data_config.open_zarrs("state")
create_mesh.py:        default="neural_lam/data_config.yaml",
create_mesh.py:        help="Path to data config file (default: neural_lam/data_config.yaml)",
create_mesh.py:    data_config = config.Config.from_file(args.data_config)
create_mesh.py:    xy = data_config.get_xy("static")  # (2, N_y, N_x)
plot_graph.py:        default="neural_lam/data_config.yaml",
plot_graph.py:        help="Path to data config file (default: neural_lam/data_config.yaml)",
plot_graph.py:    data_config = config.Config.from_file(args.data_config)
plot_graph.py:    xy = data_config.get_xy("state")  # (2, N_y, N_x)
train_model.py:        default="neural_lam/data_config.yaml",
train_model.py:        help="Path to data config file (default: neural_lam/data_config.yaml)",
$> grep -r 'config\.' neural_lam 
neural_lam/models/ar_model.py:        self.data_config = config.Config.from_file(args.data_config)
neural_lam/models/ar_model.py:        static = self.data_config.process_dataset("static")
neural_lam/models/ar_model.py:        state_stats = self.data_config.load_normalization_stats(
neural_lam/models/ar_model.py:        self.grid_output_dim = self.data_config.num_data_vars("state")
neural_lam/models/ar_model.py:            self.grid_output_dim = 2 * self.data_config.num_data_vars("state")
neural_lam/models/ar_model.py:            self.grid_output_dim = self.data_config.num_data_vars("state")
neural_lam/models/ar_model.py:            + self.data_config.num_data_vars("forcing")
neural_lam/models/ar_model.py:            * self.data_config.forcing.window
neural_lam/models/ar_model.py:        boundary_mask = self.data_config.load_boundary_mask()
neural_lam/models/ar_model.py:        self.step_length = self.data_config.step_length
neural_lam/models/ar_model.py:                            self.data_config.vars_names("state"),
neural_lam/models/ar_model.py:                            self.data_config.vars_units("state"),
neural_lam/models/ar_model.py:                            self.data_config.vars_names("state"), var_figs
neural_lam/models/ar_model.py:                var = self.data_config.vars_names("state")[var_i]
neural_lam/config.py:        proj_params = proj_config.get("kwargs", {})
neural_lam/config.py:                "vars based on zarr config...\033[0m"
neural_lam/config.py:                    "vars based on zarr config.\033[0m"
neural_lam/config.py:                    "vars based on zarr config.\033[0m"
grep: neural_lam/__pycache__/config.cpython-310.pyc: binary file matches
neural_lam/vis.py:            data_config.vars_names("state"), data_config.vars_units("state")
neural_lam/vis.py:    extent = data_config.get_xy_extent("state")
neural_lam/vis.py:        list(data_config.grid_shape_state.values.values())
neural_lam/vis.py:        subplot_kw={"projection": data_config.coords_projection},
neural_lam/vis.py:            data.reshape(list(data_config.grid_shape_state.values.values()))
neural_lam/vis.py:    extent = data_config.get_xy_extent("state")
neural_lam/vis.py:        list(data_config.grid_shape_state.values.values())
neural_lam/vis.py:        subplot_kw={"projection": data_config.coords_projection},
neural_lam/vis.py:        error.reshape(list(data_config.grid_shape_state.values.values()))
neural_lam/weather_dataset.py:        data_config="neural_lam/data_config.yaml",
neural_lam/weather_dataset.py:        self.data_config = config.Config.from_file(data_config)
neural_lam/weather_dataset.py:        self.state = self.data_config.process_dataset("state", self.split)
neural_lam/weather_dataset.py:        self.forcing = self.data_config.process_dataset("forcing", self.split)
neural_lam/weather_dataset.py:            state_stats = self.data_config.load_normalization_stats(
neural_lam/weather_dataset.py:                forcing_stats = self.data_config.load_normalization_stats(

Add test for MEPS data example

Aim: Create a test setup where we ensure that npy-file based datasets can be read into neural-lam and the training items contain tensors of the right shape.

TODO

  • reduce number of variables, size of domain etc in Joel's MEPS data example so that the zip file is less than 500MB. Calling it meps_example_reduced
  • create test-data zip file and upload to EWC (credentials from @leifdenby)
  • implement test using pytest, downloading and unpacking the test data using pooch
  • Implement testing of:
    • initiation of neural_lam.weather_dataset.WeatherDataset from downloaded data
    • check shapes of returned parts of training item
    • create new graph in tests for reduced dataset
    • feed single batch through model and check shape of output
  • add github action to run tests during ci/cd

Links:

Replace constants.py with data + region specification from yaml-file

This supersedes #2 and #3.

Motivation

It is currently very hard to work with Neural-LAM on different regions, because everything related to data and the forecast region is hard-coded in constants.py. It would be much better to specify this in a config file that you can then point to. Yaml seems like a suitable format for this.

Proposition

The main training/eval script takes a flag --spec that should be given a path to a yaml-file. This yaml file specifies all the data that goes into the model and information about the region the model should be working with.

Current options in constants.py that relate to what to plot should all be turned into flags.

The yaml file should be read in and turned into a single object that contains all useful information and can be passed around in the code (since this is needed almost everywhere). Having this as an object means that it can also compute things not directly in the yaml file, such as units of variables that can be retrieved from loaded xarray data.

Design

Here is an idea for how the yaml file could be laid out with examples:

Start of file:

# Data config for MEPS dataset
---

Some comments to keep track of what this specification is about. We don't enforce any content in there.

Forecast area data configuration

This describes the data configuration for the actual limited area that you are training on. Explicitly, the "inner region", not the "boundary". What is specified is what zarrs to load state, forcing and static grid features from and which variables in these to use for each.

forecast_area:
  zarrs: # List of zarrs containing fields related to state
    fields: # Names on this level are arbitrary
      path: data/meps_example/fields.zarr # Path to zarr
      dims: # Name of dimensions in zarr, to be used for indexing
        time: time
        level: level
        grid: grid_node # Either give "grid" (flattened) dimension or "x" and "y"
    forcing: 
      path: data/meps_example/forcing.zarr
      dims:
        time: time
        level: level
        x: x_coord 
        y: y_coord
  state: # Variables forecasted by the model
    surface: # Single-field variables
      -  2t
      - 10u
      - ...
    atmospheric: # Variables with vertical levels
      - z
      - u
      - v
      - ...
    levels: # Levels to use for atmospheric variables 
      - 700
      - 850
      - ...
    units:
      z: m2/s2
      2t: K
      ...
  forcing: # Forcing variables, dynamic inputs to the model
      surface:
        - land_sea_mask # Dynamic in MEPS
        - ...
      atmospheric:
        - ... # Nothing for MEPS, but we allow it
      levels:
        - ...
  static: # Static inputs
      surface:
        - topography 
        - ...
      atmospheric:
        - ... # Nothing for MEPS, but we allow it
      levels:
        - ...
  lat_lon_names: # Name of variables/coordinates in zarrs specifying latitude and longitude of grid cells
    lat: latitude
    lon: longitude
  • One or more zarrs can be specified. These are all opened as xarray Datasets, coordinates renamed to common names, then joined into one dataset (object in memory).
  • When specifying the names of dimensions in the zarr, the user needs to either specify the name of a flattened grid dimension or the names of x and y dimensions. If x and y are given we will flatten them ourselves in the data loading code (see the sketch after this list). If only grid is specified, a grid_shape entry has to be given (see below).
  • Names of surface and atmospheric fields listed in the different sections should map exactly to variable names in the joined xarray Dataset.
  • State, forcing and static specifications have exact same structure. Static fields should just lack time-dimension.
  • Entries not specified are just assumed to not be used, or take default values (variable dimensions). Naturally some fields are required (state variables, at least one zarr).
  • For units the variable names correspond to both surface and atmospheric variables (we do not allow naming clashes between these). Format for units is up to the user. We could require that all variables need a specified unit, or we accept that some have unspecified units.
  • We use the lat-lon information to know where the forecasting area is positioned within the projection. Convert these to coordinates in the projection and take the min/max to get the extent of the forecasting area. These lat-lons are read from the xarray Dataset, from the given variables (which can be coordinates, as these are also variables).
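
A minimal sketch of the flattening, assuming an opened dataset with spatial dimensions named "y" and "x" after the renaming step (the path is from the example above):

import xarray as xr

ds = xr.open_zarr("data/meps_example/forcing.zarr")
ds_flat = ds.stack(grid=("y", "x"))  # creates a flattened "grid" dimension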

Boundary data configuration

The boundary data configuration follows exactly the same structure as the forecast_area, with two differences:

  1. No state entry is allowed, as we do not forecast the boundary nodes atm.
  2. There is a mask entry, specifying what grid cells of the boundary to include.
    The boundary has its own list of zarrs, to avoid variable name clashes with the forecast area zarrs. Note that we enforce no spatial structure of the boundary w.r.t. forecast area. The boundary nodes can be placed anywhere.
boundary:
  zarrs:
  ...
  mask: boundary_mask # Name of variable containing boolean mask, true for grid nodes to be used in boundary.

Grid shape

If the zarrs already contain flattened grid dimensions we need knowledge of the original 2d spatial shape in order to be able to plot data. For such cases this can be given by an optional grid_shape entry:

grid_shape:
   x: 238
   y: 268

Subset splits

The train/val/test split is defined based on timestamps:

splits:
  train:
    start: 2021-04-01T00
    end: 2022-05-31T23
  val:
    start: 2022-06-01T00
    end: 2022-06-30T23
  test:
    start: 2022-07-01T00
    end: 2023-03-31T23

Used by the dataset class to .sel the different subsets.
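
A minimal sketch of how the dataset class could apply this, assuming an xarray Dataset with a time coordinate and a split entry parsed from the yaml:

def select_split(ds, split_config):
    """Select the subset of ds belonging to one split.

    split_config is e.g. {"start": "2021-04-01T00", "end": "2022-05-31T23"}.
    """
    return ds.sel(time=slice(split_config["start"], split_config["end"]))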

Forecast area projection

In order to be able to plot data in the forecasting area we need to know what projection the area is defined in. By plotting in this projection we end up with a flat rectangular area where the data sits. This should be specified as a reference to a cartopy.crs object.

projection: LambertConformal # Name of class in cartopy.crs
kwargs: # Parsed and used directly as kwargs to class above
    central_longitude: 15.0
    central_latitude: 63.3
    standard_parallels:
       - 63.3
       - 63.3
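
A minimal sketch of resolving this specification into a cartopy projection object (the class name and kwargs are exactly those from the example above):

import cartopy.crs as ccrs

proj_name = "LambertConformal"
proj_kwargs = {
    "central_longitude": 15.0,
    "central_latitude": 63.3,
    "standard_parallels": [63.3, 63.3],
}
coords_projection = getattr(ccrs, proj_name)(**proj_kwargs)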

Normalization zarr

We also need information about statistics of variables, boundary and forcing for normalization (mean and std). Additionally we need the inverse variances used in the loss computation. As we compute and save these in a pre-processing script we can enforce a specific format, so let's put all of those in their own zarr as well. Then we only need to specify a path here to that zarr to load it from.

normalization_zarr: data/meps_example/norm.zarr

Feature Request: Add Functionality to Apply Constraints to Predictions

I am proposing the addition of a new method to our model class, designed to apply constraints to predictions to ensure that the values fall within specified bounds. This functionality would be useful for maintaining the integrity of our model's predictions in scenarios where certain variables have predefined limits.

Proposed Method:

def apply_constraints(self, prediction):
    """
    Apply constraints to prediction to ensure values are within the
    specified bounds
    """
    for param, (min_val, max_val) in constants.PARAM_CONSTRAINTS.items():
        indices = self.variable_indices[param]
        for index in indices:
            # Apply clamping to ensure values are within the specified
            # bounds
            prediction[:, :, :, index] = torch.clamp(
                prediction[:, :, :, index],
                min=min_val,
                max=max_val if max_val is not None else float("inf"),
            )
    return prediction

Rationale:

  • Data integrity: Ensuring that predictions adhere to real-world constraints is essential for the reliability of our model's outputs. This method would allow us to enforce these constraints directly on the prediction tensor.
  • Flexibility: By dynamically applying constraints based on the variable indices, we can maintain a high degree of flexibility in how we handle different variables with varying constraints.

The method could be added to the ARModel class, which is our primary model class.
The constants.PARAM_CONSTRAINTS dictionary, which maps variable names to their minimum and maximum values, should be used to determine the constraints for each variable.

PARAM_CONSTRAINTS = {
    "RELHUM": (0, 100),
    "CLCT": (0, 100),
    "TOT_PREC": (0, None),
}
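
As a small illustration of the clamping itself (the channel indices here are hypothetical; in the model they would come from self.variable_indices):

import torch

prediction = torch.randn(1, 4, 100, 3) * 80  # (batch, time, grid, features)

# Assume channel 0 is RELHUM (bounded to [0, 100]) and channel 1 is TOT_PREC (>= 0)
prediction[..., 0] = torch.clamp(prediction[..., 0], min=0, max=100)
prediction[..., 1] = torch.clamp(prediction[..., 1], min=0)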

This feature is closely related to #18

Merge global forecasting capabilities from `prob_model_global`

Global forecasting can be viewed as a special case of LAM, with no boundary conditions to consider. Therefore it is not a hard task to allow for global forecasting within the Neural-LAM framework. This has the benefit of being able to use any models in here also for global forecasting.

Global forecasting, as done in Oskarsson et al. (2024), is currently implemented on the branch prob_model_global. This issue is to keep track of merging of the prob_model_global branch into main.

The main changes in this branch include:

Multi-GPU training

I realized that multi-GPU training is currently broken. Luckily I believe this should be a simple fix, just making sure that logging + the storage of tensors in model classes conforms to the lightning setup properly.

Adapt graph generation script to general limited areas

The graph generation script create_mesh.py is currently written for the MEPS area. Without huge changes this should be possible to change to a generic script that can work for general (quadratic in their projection) areas. Some more arguments would probably have to be introduced.

It might be a good idea to turn the graph generation into a function, and let create_mesh.py only be a utility script for calling this function with some command line arguments.

The area definition object proposed in #2 could be useful as input to such a generic script.

Tracking of Input Channel Indices <-> Variable Name and Level

Description

Tracking which feature channel corresponds to which variable and vertical level of the input data is beneficial for plotting, verification and more...

  1. precompute_variable_indices() method to precompute the indices for each variable in the input tensor.
  2. selected_vars_units attribute to store the short names and units of the selected variables.

Implementation

Suggestion to add the following code to the ARModel class:

self.variable_indices = self.precompute_variable_indices()
self.selected_vars_units = list(
    zip(constants.PARAM_NAMES_SHORT, constants.PARAM_UNITS)
)

def precompute_variable_indices(self):
    """
    Precompute indices for each variable in the input tensor
    """
    variable_indices = {}
    all_vars = []
    index = 0
    # Create a list of tuples for all variables, using level 0 for 2D
    # variables
    for var_name in constants.PARAM_NAMES_SHORT:
        if constants.IS_3D[var_name]:
            for level in constants.VERTICAL_LEVELS:
                all_vars.append((var_name, level))
        else:
            all_vars.append((var_name, 0))  # Use level 0 for 2D variables

    # Sort the variables based on the tuples
    sorted_vars = sorted(all_vars)

    for var in sorted_vars:
        var_name, level = var
        if var_name not in variable_indices:
            variable_indices[var_name] = []
        variable_indices[var_name].append(index)
        index += 1

    return variable_indices

Benefits

  • The precompute_variable_indices() method will allow efficient lookup of indices for each variable in the input tensor, avoiding the need to recompute them during runtime. This is very flexible as the user can define it in constants.py
  • The selected_vars_units attribute will provide easy access to the short names and units of the selected variables, which can be useful for plotting, logging, or other purposes.

Feature Request: Enhance Visualization Plots with Map Features

Description: To improve the visualization plots in the vis.py module, we should add the following features to the plots:

Country Borders:
Add country borders to the plots using cartopy.feature.BORDERS.
Set the linestyle to a solid line ("-") and the edge color to black.

Coastlines:
Include coastlines in the plots using cartopy.feature.COASTLINE.
Set the linestyle to a solid line ("-") and the edge color to black.

Gridlines:
Add gridlines to the plots using ax.gridlines().
Use the projection specified in constants.SELECTED_PROJ for the gridlines.
Disable drawing labels for the gridlines by setting draw_labels=False.
Set the linewidth of the gridlines to 0.5 and the alpha (transparency) to 0.5.

Implementation:

import cartopy.feature as cf
ax.add_feature(cf.BORDERS, linestyle="-", edgecolor="black")
ax.add_feature(cf.COASTLINE, linestyle="-", edgecolor="black")
ax.gridlines(
    crs=constants.SELECTED_PROJ,
    draw_labels=False,
    linewidth=0.5,
    alpha=0.5,
)

Benefits:

  • Enhancing the plots with country borders and coastlines will provide better geographical context and make the visualizations more informative.
  • Adding gridlines will improve the readability of the plots and help in understanding the spatial distribution of the data.
  • Customizing the linestyle, edge color, linewidth, and alpha allows for better visual aesthetics and clarity in the plots.

Notes: Since we are already depending on cartopy for the projections we might as well use some additional functionalities.

Handling checkpoint-breaking changes

Background

As we make more changes to the code there will be points where checkpoints from saved models can not be directly loaded in a newer version of neural-lam. This happens in particular if we start making changes to variable names of nn.Module attributes and the overall structure of the model classes. It would be good to have a policy of how we want to handle such breaking changes. This issue is for discussing this.

Proposals

I see three main options:

  1. Ignore this issue, and only guarantee that checkpoints trained in a specific version of neural-lam works with that version. If you upgrade you have to re-train models or do some "surgery" on your checkpoints files yourself.
  2. Make sure that we can load checkpoints from all previous versions. This is doable as long as the same neural network parameters are in there, just with different names. We have an example of this already, in the current ARModel:
     def on_load_checkpoint(self, checkpoint):
         """
         Perform any changes to state dict before loading checkpoint
         """
         loaded_state_dict = checkpoint["state_dict"]
         # Fix for loading older models after InteractionNet refactoring, where
         # the grid MLP was moved outside the encoder InteractionNet class
         if "g2m_gnn.grid_mlp.0.weight" in loaded_state_dict:
             replace_keys = list(
                 filter(
                     lambda key: key.startswith("g2m_gnn.grid_mlp"),
                     loaded_state_dict.keys(),
                 )
             )
             for old_key in replace_keys:
                 new_key = old_key.replace(
                     "g2m_gnn.grid_mlp", "encoding_grid_mlp"
                 )
                 loaded_state_dict[new_key] = loaded_state_dict[old_key]
                 del loaded_state_dict[old_key]
  3. Create a separate script for converting checkpoint files from one version to another. The required logic for this is the same as in point 2, but here moved to a separate script that takes a checkpoint file as input and saves a new checkpoint file, now compatible with the new neural-lam version.

Considerations for point 2 and 3

  • This requires that as soon as we make such a checkpoint-breaking change we also write the code for handling checkpoints from before that change.
  • It would likely be useful to keep track of which version a certain checkpoint was created with, so we know if it needs to be converted. A simple way to do this could be to handle it similarly to Lightning, which stores the version of the package in the checkpoint file (e.g. `ckpt["pytorch-lightning_version"] = "2.2.1"`).

My view

  • I expect these kinds of changes to not happen that often, and they will maybe mostly happen right now as we are refactoring some things. That could be a reason to just go for alternative 1, but it also means that alternatives 2/3 are less work. I think I am leaning towards alternative 3, as I would like to be able to use my existing trained checkpoints.
  • I prefer 3 over 2 as I think on_load_checkpoint would get unnecessarily complicated and I'd rather just do the conversion once and have a set of new checkpoint files. It is also easy to do both 2 and 3: if you try to load an old checkpoint you just convert it before loading.
  • While I think it makes sense to keep some track of which version a checkpoint was created with, I would like to avoid building any complex system around this. At the end of the day it is up to the user to make sure that they are using their checkpoint in a training/eval configuration that makes sense. With good tracking in e.g. W&B this is entirely doable. But it is nice to provide some tools to easily upgrade checkpoints if possible.

Tagging @leifdenby and @sadamov to get your input.

How could we add precipitation forecast in your neural-lam?

I've sent an email to you asking about removing precipitation from the original GraphCast code, sir.
Thank you for your quick reply. In the email you explained that the precipitation part is difficult because ERA5
precipitation is not accurate enough. So, for a specific region like China, where we have high-quality precipitation
reanalysis data, maybe we could use these precipitation data instead of ERA5 precipitation to make the forecast better.

Now, how could we add the precipitation part in your neural-lam codes?
Thank you.

Structured documentation for Neural-LAM

As Neural-LAM grows and we make changes to it there is an increasing need for documentation. Already now the README is quite long, and some things are not that easy to find in it. I propose that we should set up a more robust and structured documentation solution. The idea of this issue is to start a discussion about how to do this.

When thinking about documentation for different open source python projects, one that I really like is the pytorch geometric documentation. This uses Read the Docs + Sphinx. In particular, you can create documentation that contains both:

  1. Manually written tutorials, quickstart-guides, installation instructions, etc.
  2. A full API reference for all classes and functions in the codebase (descriptions + arguments + what is returned), generated directly from the docstrings in the python code.

I think having both of these is desirable, and think we could set up something like this for Neural-LAM.

I know that in weather-model-graphs you have these nice Jupyter books for documentation @leifdenby. Is there a way to do these things in those as well? In particular, is it possible to do the automatic reference generation from docstrings?

Generate a flow direction map

Dear Mr. Joel Oskarsson, in the 'Graph Construction' section of your paper, you aggregate neighbor nodes in eight directions. A student would like to ask for your advice on how to selectively aggregate neighbor nodes. Could you also explain how to generate a flow direction map, for example, as shown below: Based on neighbors in the D8 or D4 directions, generate a flow direction map from the DEM image, and then aggregate the neighbor features of upstream river nodes. Article source: https://eartharxiv.org/repository/view/3018/, but this article does not share its code, which is very challenging for beginners.

Keeping a changelog and versioning

I would like to propose that we start keeping a changelog. This would mean that from now on changes are only added through pull-requests, and every pull-request should add an entry to CHANGELOG.md (typically this addition will be done as a final commit once a pull-request has passed tests and review).

My notes below are based on https://keepachangelog.com/en/1.0.0/, please read it. It is very succinct.

Why keep a changelog?

As the code keeps evolving it will be important to be able to communicate with each other how the codebase is changing. Having a single point of reference, where changes are grouped by feature additions, bugfixes, breaking changes and maintenance will make it a lot easier (and a lot more fun!) to work with the codebase.

For reference here is xarray's changelog: https://github.com/pydata/xarray/blob/main/doc/whats-new.rst, here is one for a different project I've worked on: https://github.com/cloudsci/cloudmetrics/blob/master/CHANGELOG.md

Connected to this I would also like to introduce versioning, using semantic versioning. I would like to tag the commit 2378ed7 @joeloskarsson as v0.1.0, and create a changelog relative to there. My reasoning for this is that I think that commit was the most recent commit when you shared this repository publicly. We can then start the changelog by including the commit that you made for recent loss-function additions (new feature) and the maintenance @sadamov is doing in #6 by setting up linting.

In that case the CHANGELOG would look something like:

# Changelog

## [Unreleased](https://github.com/joeloskarsson/neural-lam/HEAD)

[Full Changelog](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD)

*new features*

- additional loss-functions 
  [2378ed7](https://github.com/joeloskarsson/neural-lam/commit/2378ed7eddf8da5bfec6f57c41cadf310d191dee), [c14b6b4](https://github.com/joeloskarsson/neural-lam/commit/c14b6b4323e6b56f1f18632b6ca8b0d65c3ce36a) By Joel Oskarsson
  (@joeloskarsson )

*maintenance*

- set up linting with pyflake8, black with pre-commit with github action to run in cicd [\#1](https://github.com/joeloskarsson/neural-lam/pull/6/) by Simon Adamov (@sadamov)

## [v0.1.0](https://github.com/joeloskarsson/neural-lam/HEAD)

First public release of Neural-LAM, including functionality to train hierarchical and ...

Let me know your thoughts. I would be happy to set this up if you are happy with it. But I just wanted to get the discussion going before the codebase starts changing a lot

Variable names and weights as part of static dataset files

Currently the variables in the dataset are listed in constants.py. This is bad if the code is to be used with other datasets.

Proposition

Create a file variables.json in data/my_dataset/static that describe all variables. This includes:

  • Weather state variables (e.g. u_65)
  • Forcing variables for the full grid
  • Batch-static forcing variables (static during one forecast, but changing throughout the dataset. i.e. open water currently)

All of these should be listed in order with names. For the weather state variables, their weighting (as in parameter_weights.npy currently) should also be listed with them. We can then remove the lines https://github.com/joeloskarsson/neural-lam/blob/89a4c63370201c9ea1a5f04d4cf1e5e75b7cc83e/create_parameter_weights.py#L26-L31 that generate this weighting file. It is better to let this be something that is set manually when preparing a dataset.

Such a variables.json file could then be loaded into a VariableDescription object and used in the models. The variable dimensions https://github.com/joeloskarsson/neural-lam/blob/89a4c63370201c9ea1a5f04d4cf1e5e75b7cc83e/neural_lam/models/ar_model.py#L22-L24 should then be read from this object rather than hard-coded in a model definition.
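
A minimal sketch of what such a VariableDescription object could look like; the exact keys in variables.json ("state", "forcing", "batch_static") are only an assumption here:

import json
from dataclasses import dataclass


@dataclass
class VariableDescription:
    state_names: list
    state_weights: list
    forcing_names: list
    batch_static_names: list

    @classmethod
    def from_file(cls, path):
        with open(path, "r", encoding="utf-8") as f:
            spec = json.load(f)
        return cls(
            state_names=[v["name"] for v in spec["state"]],
            state_weights=[v["weight"] for v in spec["state"]],
            forcing_names=spec.get("forcing", []),
            batch_static_names=spec.get("batch_static", []),
        )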

Idea for how to define neural-networks operating on graph

I've been thinking about how we could construct the different neural networks which update the graph. The ideas below are very tentative and I may also have fundamentally misunderstood parts of how this is supposed to work, so all feedback is very welcome.

To my mind there are two types of neural networks we use: 1) MLPs for creating embeddings of node/edge features into a common-sized latent space across all embeddings, and 2) networks for updating node embeddings using message-passing.

What I would like to achieve is code that:

  • is closer to the mathematical notation used in the neural-lam publication for the expressions that do the embedding and message-passing, i.e. shows which nodes/edges are being operated on
  • allows me to easily see in one place the complete set of neural networks used in a given architecture
  • is flexible, by making it easy to create new message passing operations

I think there are basically three steps to this:

  1. Define which embedding networks to create; the number of embedders will depend on whether nodes/edges share the same features or not
  2. Define the message passing operations, i.e. which nodes the message passing communicates between (giving each operation a unique identifier)
  3. Define the message-passing order

Below is my code example that tries to encapsulate all the information I think is needed to later in the code actually instantiate the neural-network models that do the numerical operations.

import pytorch_lightning as pl


class NewGraphModel(pl.LightningModule):
    def __init__(self):
        n_mesh_node_features = 3
        n_grid_node_features = 18
        n_edge_features = 3
        n_hidden_features = 64

        # create node and edge feature embedding networks, all will project
        # into an embedding space of size n_hidden_features
        embedders = [
            dict(
                node=dict(component="mesh"),
                n_input_features=n_mesh_node_features,
            ),
            dict(
                node=dict(component="grid"),
                n_input_features=n_grid_node_features,
            ),
            # use the same edge embedding network for all edges, assuming
            # that all edges have the same set of features
            dict(edge=dict(), n_input_features=n_edge_features),
        ]

        # create message-passing networks that update node embeddings,
        # these all assume that the embedding vectors are of the same size
        message_passers = dict(
            g2m=dict(src=dict(component="grid"), dst=dict(component="mesh")),
            m2m_up_1to2=dict(
                src=dict(component="mesh", level=1),
                dst=dict(component="mesh", level=2),
            ),
            m2m_down_2to1=dict(
                src=dict(component="mesh", level=2),
                dst=dict(component="mesh", level=1),
            ),
            m2m_inlevel_1=dict(
                src=dict(component="mesh", level=1),
                dst=dict(component="mesh", level=1),
            ),
            m2m_inlevel_2=dict(
                src=dict(component="mesh", level=2),
                dst=dict(component="mesh", level=2),
            ),
            m2g=dict(src=dict(component="mesh"), dst=dict(component="grid")),
        )

        # define the order in which messages are passed
        # here we do the up/down twice before decoding back to the grid
        message_passing_order = [
            "g2m",
            # m2m pass 1
            "m2m_up_1to2",
            "m2m_inlevel_2",
            "m2m_down_2to1",
            "m2m_inlevel_1",
            # m2m pass 2
            "m2m_up_1to2",
            "m2m_inlevel_2",
            "m2m_down_2to1",
            "m2m_inlevel_1",
            "m2g",
        ]

A few notes on this to explain what is going on:

  • the three sections implement those three steps 1) embedders, 2) message passers, 3) message passing order
  • when defining the embedders or the message-passers then the arguments in the dict in effect define filters that select edges
  • I have made the message passing order explicit because this would create a single place to refer to for the ordering of message passing operations. This feels a bit like nn.Sequential, but it isn't of course, because it's not as if the output from one step feeds into the next.

I hope this isn't total nonsense @joeloskarsson and @sadamov 😆 just trying to get the ball rolling

Grid Dimension Ordering

Hi, I ran into some issues using create_mesh.py for my dataset. I'm wondering if there might be a bug in the documentation for the code?

The code comments consistently say that a sample has dimensions (N_t', dim_x, dim_y, d_features').

full_sample = torch.tensor(
    np.load(sample_path), dtype=torch.float32
)  # (N_t', dim_x, dim_y, d_features')

Similarly, nwp_xy.npy, the coordinates of each grid point also has dimensions in the same order (2, N_x, N_y).

# -- Static grid node features --
grid_xy = torch.tensor(
np.load(os.path.join(static_dir_path, "nwp_xy.npy"))
) # (2, N_x, N_y)

Now, these lines in create_mesh.py imply that the first slice of nwp_xy.npy hold the x coordinates, and vice versa.

neural-lam/create_mesh.py

Lines 110 to 112 in 9d558d1

def mk_2d_graph(xy, nx, ny):
    xm, xM = np.amin(xy[0][0, :]), np.amax(xy[0][0, :])
    ym, yM = np.amin(xy[1][:, 0]), np.amax(xy[1][:, 0])

I now refer to nwp_xy.npy as grid_xy.
Therefore, given the dimension ordering (2, N_x, N_y), I would reason that grid_xy[0][0] is a length-N_y vector holding the x coordinates of a vertical slice of the dataset. I would expect all entries of grid_xy[0][0] to have the same x coordinate. However, I see that isn't the case.

grid_xy = np.load("data/meps_example/static/nwp_xy.npy")
print(grid_xy.shape)
print(grid_xy[0][0].shape)
print(grid_xy[0][0])

>>>
(2, 268, 238)
(238,)
array([-1.0601221e+06, -1.0501221e+06, -1.0401222e+06, -1.0301222e+06,
...

Trying to create a mesh with my own dataset with the ordering (2, N_x, N_y) gives me a buggy graph. However, when changing the ordering to (2, N_y, N_x), the graph seems fine.

Is this an issue with the code comments? Should the correct dimension ordering be N_y, N_x in all cases instead of N_x, N_y?

wandb not initialized in train_model.py

Hi,

I'm currently doing some first tests with the train-model.py script.
I'm quite new to wandb so I might have missed something, but it seems that
wandb.init('neural-lam') is never called, which leads to the following error:
wandb.errors.Error: You must call wandb.init() before wandb.define_metric() which traces back to /neural_lam/utils.py", line 203, in init_wandb_metrics.

Adding wandb.init('neural-lam') here:

    if trainer.global_rank == 0:
        wandb.init(project="neural-lam")
        utils.init_wandb_metrics() # Do after wandb.init

seems to work, but I guess the name of the project should be read from constants.py.

Kind regards,
Michiel

Regression tests for model outputs

Something that would be really nice to have is regression testing for model outputs. In short, whenever we refactor something in models we want them to still be able to load checkpoints (or well, see #48 ) and give exactly the same output when being fed with the same data.

One way to achieve this could be to

  1. Check out main branch
  2. Run some example data through the model and save the predictions (potentially also some internal representation tensors, but that is likely unnecessary and hard to do in practice)
  3. Check out PR
  4. Run the same example data through the model and compare outputs to saved predictions.

I'm not too familiar with pytest and the github workflows to know all the details of how to do this. @SimonKamuk, @leifdenby do you think something like this is doable? Or are there any better ways to achieve this?
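As a starting point, a rough pytest-style sketch of steps 2 and 4 (the reference file path, the tolerance and the model_prediction fixture are assumptions):

    import numpy as np

    REFERENCE_FILE = "tests/reference_predictions.npy"  # saved from the main branch (step 2)

    def test_model_output_regression(model_prediction):
        """`model_prediction` is an assumed pytest fixture: the prediction tensor
        obtained by running the saved example data through the current model (step 4)."""
        reference = np.load(REFERENCE_FILE)
        # Ideally an exact match; a tiny tolerance guards against harmless numeric noise
        np.testing.assert_allclose(model_prediction, reference, rtol=0, atol=1e-6)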

Pytorch Dataset Class that Reads From Zarr Archive

Summary
Since the weather community, and especially ECMWF, has moved towards a single zarr archive that contains all the data for the state (interior domain), and one that contains all the data for the boundary, this project should follow the same approach. Zarr has many advantages, like parallel computing with Dask, lazy loading with xarray, efficient compression with different algorithms, and chunking.

Specifics
There are three main data-processing steps in the current pipeline. This is a proposal for how the work would be split between the three:

  • Data-Preprocessing
    • Usually some format like GRIB2 is converted into xarray->zarr. This step is out of scope
    • Pre-computation of forcings, static and grid features
    • Computation of normalization constants (stats) and inverse variances
    • Generating the boundary mask
  • Pytorch Dataset [on CPU]:
    • Reshaping of 3D variables into stacked 2D variables
    • Split data into train/val/test based on some indicator (e.g. time)
    • Generate the windowed indices for forcing and boundary
  • Pytorch Model [on GPU]
    • Normalization of the data

Interfaces

  • Data-Preprocessing
    • Input: out of scope
    • Output: one or multiple zarr files
  • Pytorch Dataset [on CPU]:
    • Input: one or multiple zarr files
    • Output: 5 pytorch tensors with the following dimensions:
      init_states: (2, N_grid, features_dim), 
      target_states: (n_lead_times, N_grid, features_dim), 
      forcing: (n_lead_times, N_grid, forcing_windowed_dim) # window_steps * n_forcing
      boundary: (n_lead_times, N_grid, boundary_windowed_dim) # windowed_steps * n_boundary
      batch_times: (2 + n_lead_times)[str]
  • Pytorch Model [on GPU]
    • Input: 5 pytorch tensors (batched with Pytorch DataLoader)
    • Output: out of scope

Implementation
One example of such a pytorch dataset and dataloader can be found here for inspiration: https://github.com/MeteoSwiss/neural-lam/blob/main/neural_lam/weather_dataset.py It would however need quite a bit of work:

  • Handle multiple zarrs properly
  • Use the YAML from #23
  • Remove temporal forcing calculation
  • Add boundary to dataflow
  • ...

[draw.io diagram: dataset]
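As a very rough illustration of the dataset part, a minimal sketch of a zarr-backed PyTorch dataset (the zarr layout, the variable name "state" and the dimension names are assumptions, not the proposed format):

    import torch
    import xarray as xr

    class ZarrWeatherDataset(torch.utils.data.Dataset):
        """Minimal sketch: lazily reads samples from a zarr archive via xarray."""

        def __init__(self, zarr_path, n_lead_times=19):
            self.ds = xr.open_zarr(zarr_path)  # lazy; nothing is loaded yet
            self.n_lead_times = n_lead_times

        def __len__(self):
            # One sample per possible start time (2 initial states + lead times)
            return self.ds.sizes["time"] - (2 + self.n_lead_times) + 1

        def __getitem__(self, idx):
            window = self.ds.isel(time=slice(idx, idx + 2 + self.n_lead_times))
            # Assumed variable "state" with dims (time, grid_node, feature)
            state = torch.tensor(window["state"].values, dtype=torch.float32)
            init_states = state[:2]      # (2, N_grid, d_features)
            target_states = state[2:]    # (n_lead_times, N_grid, d_features)
            return init_states, target_states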

Refactor model class hierarchy

Background

The different models that can be trained in Neural-LAM are currently all sub-classes of pytorch_lightning.LightningModule. In particular, much of the functionality sits in the first subclass, ARModel. The current hierarchy looks like this:
[diagram: current class hierarchy]
(I am making these rather than some more fancy UML-diagrams since I think this should be enough for the level of detail we need to discuss here).

The problem

In the current hierarchy everything is a subclass of ARModel. This has a number of drawbacks:

  • All functionality ends up in one class. There is no clear division of responsibility; instead we end up with one class that does everything, from logging, device handling and unrolling forecasts to the actual forward and backward pass through GNNs.
  • The subclasses do not utilize the standard torch forward calls, but must instead resort to our own similar construction (e.g. predict_step).
  • This limits us to deterministic, auto-regressive models.
  • This is hard to extend for ensemble models.

Proposed new hierarchy

I propose to split up the current class hierarchy into classes that have clear responsibilities. These should not just all inherit from ARModel, but rather be members of each other as suitable. A first idea for this is shown below, including also potential future classes for new models (to show how this is more extensible):

[diagram: proposed class hierarchy]

The core components are (I here count static features as part of forcing):

  • ForecasterModule: Takes over much of the responsibility of the old ARModel. Handles things not directly related to the neural network components, such as plotting, logging and moving batches to the right device. This inherits pytorch_lightning.LightningModule and has the different train/val/test steps. In each step (train/val/test) it unpacks the batch of tensors and uses a Forecaster to produce a full forecast. Also responsible for computing the loss based on a produced forecast (could also be in Forecaster, not entirely sure about this).
  • Forecaster: A generic forecaster capable of mapping from a set of initial states, forcing and boundary forcing into a full forecast of the requested length.
    • ARForecaster: Subclass of Forecaster that uses an auto-regressive strategy to unroll a forecast. Makes use of a StepPredictor at each AR step.
  • StepPredictor: A model mapping from the two previous time steps + forcing + boundary forcing to a prediction of the next state. Corresponds to the $\hat{f}$ function in Oskarsson et al..
    • The existing graph models then become subclasses of StepPredictor.

In the figure above we can also see how new kinds of models could fit into this hierarchy:

  • Ensemble models using an auto-regressive strategy to sample the state at each time step.
  • Other types of single-step-predictors, e.g. using CNNs or Vision-Transformers.
  • Non-autoregressive strategies for creating a full forecast, in particular direct forecasting.
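To make the proposal more concrete, a rough Python sketch of how these classes could relate (class responsibilities follow the list above, but method names, signatures and the assumed batch/tensor shapes are placeholders, not a final design):

    import pytorch_lightning as pl
    import torch
    import torch.nn as nn

    class StepPredictor(nn.Module):
        """Maps two previous states + forcing + boundary to the next state (the f-hat function)."""
        def forward(self, prev_state, prev_prev_state, forcing, boundary):
            raise NotImplementedError  # implemented by the graph models

    class Forecaster(nn.Module):
        """Maps initial states, forcing and boundary forcing to a full forecast."""
        def forward(self, init_states, forcing, boundary):
            raise NotImplementedError

    class ARForecaster(Forecaster):
        """Unrolls a forecast auto-regressively using a StepPredictor."""
        def __init__(self, step_predictor):
            super().__init__()
            self.step_predictor = step_predictor

        def forward(self, init_states, forcing, boundary):
            # Assumed shapes: init_states (B, 2, N_grid, d), forcing/boundary (B, T, ...)
            prev_prev_state, prev_state = init_states[:, 0], init_states[:, 1]
            predictions = []
            for t in range(forcing.shape[1]):  # one AR step per lead time
                next_state = self.step_predictor(
                    prev_state, prev_prev_state, forcing[:, t], boundary[:, t]
                )
                predictions.append(next_state)
                prev_prev_state, prev_state = prev_state, next_state
            return torch.stack(predictions, dim=1)

    class ForecasterModule(pl.LightningModule):
        """Handles logging, plotting and device handling; delegates to a Forecaster."""
        def __init__(self, forecaster):
            super().__init__()
            self.forecaster = forecaster

        def training_step(self, batch, batch_idx):
            init_states, target_states, forcing, boundary = batch
            prediction = self.forecaster(init_states, forcing, boundary)
            loss = nn.functional.mse_loss(prediction, target_states)
            self.log("train_loss", loss)
            return loss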

This is supposed to be a starting point for discussion and there will likely be things I have not thought about. Some parts of this will have to be hammered out when actually writing these classes, but I'd rather have the discussion whether this is a good direction to take things before starting to do too much work. Tagging @leifdenby and @sadamov for visibility.

Modifying `WeatherDataset.__getitem__` to return timestamps

I'm making good progress on #54 and in going through it I noticed that @sadamov you modified the return signature of WeatherDataset.__getitem__ to also return batch_times (which it looks like are np.datetime64 converted to strings). I can see the use of this for e.g. being able to plot the input and predictions from the model with timestamps. I think if we want to be able to make these plots with timestamps we can't avoid returning the time here too. I'm not sure about using strings though...

What are your thoughts on this @sadamov and @joeloskarsson?

Area definition as part of static dataset files

In order to make the code reusable with different limited area configurations, the exact definition of the model area should not be hard-coded in constants.py.

Proposition

Let a file area_def.json, stored in data/my_dataset/static, describe the model area. This description should include:

  • The projection used for defining the area (useful for plotting).
  • The extent of the area in the coordinate system of the projection.
  • The grid shape (N x M).

These things are currently contained in constants.py. The json file could then be used to create an AreaDefinition object that could be passed to the model, plotting scripts etc.
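As a sketch of how such a file could be consumed (the key names, the Lambert conformal projection and the file path are assumptions, not a proposed schema):

    import json
    import cartopy.crs as ccrs

    with open("data/my_dataset/static/area_def.json") as f:
        area_def = json.load(f)

    # Assumed keys: "proj_params" (projection parameters), "extent"
    # [x_min, x_max, y_min, y_max] in projection coordinates, and "grid_shape" [N, M]
    projection = ccrs.LambertConformal(
        central_longitude=area_def["proj_params"]["lon_0"],
        central_latitude=area_def["proj_params"]["lat_0"],
    )
    extent = area_def["extent"]
    grid_shape = tuple(area_def["grid_shape"])  # (N, M)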

Mesh Node Update: Paper vs. Code

Hi, I wanted to ask about a small discrepancy I may have found between the paper and the code.

[figure: mesh node update equation from the paper]

In the paper, the G2M GNN updates the mesh embeddings using the current mesh embedding, and the edge embeddings of any G2M edges.

However, in the code, the grid node embeddings are also passed into the InteractionNet, which concatenates the grid and mesh embeddings.

Am I misunderstanding anything here? If the discrepancy is intentional, could someone please give a brief justification why the change was done?
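For reference, a schematic of the two variants being compared (this is not the actual neural-lam code; dimensions and aggregation are heavily simplified):

    import torch
    import torch.nn as nn

    D = 32       # embedding dimension (illustrative)
    N_MESH = 10  # number of mesh nodes (illustrative)

    mesh_emb = torch.randn(N_MESH, D)  # current mesh node embeddings
    grid_agg = torch.randn(N_MESH, D)  # aggregated grid (sender) node embeddings per mesh node
    edge_agg = torch.randn(N_MESH, D)  # aggregated G2M edge embeddings per mesh node

    # Paper formulation: update from mesh embedding + aggregated edge embeddings
    mlp_paper = nn.Linear(2 * D, D)
    updated_paper = mlp_paper(torch.cat([mesh_emb, edge_agg], dim=-1))

    # Code formulation: the grid (sender) node embeddings are concatenated in as well
    mlp_code = nn.Linear(3 * D, D)
    updated_code = mlp_code(torch.cat([mesh_emb, grid_agg, edge_agg], dim=-1))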

Normalize Data on GPU

Motivation
Data normalization can be done on the fly on the GPU for each batch. It's faster on GPU than on CPU and cleans up the dataset's init method.

Implementation
Could very nicely use https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#on-after-batch-transfer to normalize once data is on GPU. Makes sure that you never forget about it (all batches on GPU are normalized).

The stats could be provided by a YAML config handler that can be accessed in the model's init.
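A minimal sketch of what that could look like (the batch structure and buffer names are assumptions):

    import pytorch_lightning as pl

    class NormalizingModel(pl.LightningModule):
        def __init__(self, data_mean, data_std):
            super().__init__()
            # Buffers move with the module, so the stats end up on the GPU as well
            self.register_buffer("data_mean", data_mean)
            self.register_buffer("data_std", data_std)

        def on_after_batch_transfer(self, batch, dataloader_idx):
            # Called by Lightning right after the batch has been moved to the device
            init_states, target_states, forcing = batch
            init_states = (init_states - self.data_mean) / self.data_std
            target_states = (target_states - self.data_mean) / self.data_std
            return init_states, target_states, forcing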

Feature Request: Generate GIFs from PNG Images for Each Variable and Level

Description: Currently, the code generates individual PNG images for each variable and level combination at different time steps. To enhance visualization and make it easier to observe changes over time, we propose generating GIF animations from these PNG images.

Proposed Changes:

Iterate over each unique combination of variable name and level (using #18):

# Assumes `glob`, `imageio` and `constants` are imported at module level; runs inside
# the model's plotting code where `self.selected_vars_units`, `self.variable_indices`
# and `plot_dir_path` are available.
for var_name, _ in self.selected_vars_units:
    var_indices = self.variable_indices[var_name]
    for lvl_i, _ in enumerate(var_indices):
        # Calculate var_vrange for each index
        lvl = constants.VERTICAL_LEVELS[lvl_i]

        # Get all the images for the current variable and index
        images = sorted(
            glob.glob(
                f"{plot_dir_path}/"
                f"{var_name}_prediction_lvl_{lvl:02}_t_*.png"
            )
        )
        # Generate the GIF
        with imageio.get_writer(
            f"{plot_dir_path}/{var_name}_prediction_lvl_{lvl:02}.gif",
            mode="I",
            fps=1,
        ) as writer:
            for filename in images:
                image = imageio.imread(filename)
                writer.append_data(image)

Benefits:

  • Improved visualization of changes over time for each variable and level.
  • Easier comparison of different variables and levels through animated GIFs.
  • Consolidated visualization format that combines multiple PNG images into a single GIF file.

Considerations:

  • Ensure that the necessary dependencies, such as the imageio library, are installed.
  • Verify that the PNG images are generated correctly and follow the expected naming convention.
  • Consider the storage requirements for the generated GIF files, as they may be larger than individual PNG images.
