
graphcast's Introduction

GraphCast: Learning skillful medium-range global weather forecasting

This package contains example code to run and train GraphCast. It also provides three pretrained models:

  1. GraphCast, the high-resolution model used in the GraphCast paper (0.25 degree resolution, 37 pressure levels), trained on ERA5 data from 1979 to 2017,

  2. GraphCast_small, a smaller, low-resolution version of GraphCast (1 degree resolution, 13 pressure levels, and a smaller mesh), trained on ERA5 data from 1979 to 2015, useful to run a model with lower memory and compute constraints,

  3. GraphCast_operational, a high-resolution model (0.25 degree resolution, 13 pressure levels) pre-trained on ERA5 data from 1979 to 2017 and fine-tuned on HRES data from 2016 to 2021. This model can be initialized from HRES data (does not require precipitation inputs).

The model weights, normalization statistics, and example inputs are available in the `dm_graphcast` Google Cloud Storage bucket.

Full model training requires downloading the ERA5 dataset, available from ECMWF. The most convenient way to access it is as Zarr, via WeatherBench 2's copy of ERA5 (see the 6-hourly downsampled versions).

Overview of files

The best starting point is to open graphcast_demo.ipynb in Colaboratory. It gives an example of loading data, generating random weights or loading a pre-trained snapshot, generating predictions, computing the loss, and computing gradients. The one-step implementation of the GraphCast architecture is provided in graphcast.py.

Brief description of library files:

  • autoregressive.py: Wrapper used to run (and train) the one-step GraphCast to produce a sequence of predictions by auto-regressively feeding the outputs back as inputs at each step, in a JAX-differentiable way.
  • casting.py: Wrapper used around GraphCast to make it work using BFloat16 precision.
  • checkpoint.py: Utils to serialize and deserialize trees.
  • data_utils.py: Utils for data preprocessing.
  • deep_typed_graph_net.py: General-purpose deep graph neural network (GNN) that operates on TypedGraphs where both inputs and outputs are flat vectors of features for each of the nodes and edges. graphcast.py uses three of these: the Grid2Mesh GNN, the Multi-mesh GNN and the Mesh2Grid GNN.
  • graphcast.py: The main GraphCast model architecture for a single prediction step.
  • grid_mesh_connectivity.py: Tools for converting between regular grids on a sphere and triangular meshes.
  • icosahedral_mesh.py: Definition of an icosahedral multi-mesh.
  • losses.py: Loss computations, including latitude-weighting.
  • model_utils.py: Utilities to produce flat node and edge vector features from input grid data, and to convert the node output vectors back into multilevel grid data.
  • normalization.py: Wrapper for the one-step GraphCast used to normalize inputs according to historical values, and targets according to historical time differences.
  • predictor_base.py: Defines the interface of the predictor, which GraphCast and all of the wrappers implement.
  • rollout.py: Similar to autoregressive.py, but used only at inference time; it uses a Python loop to produce longer, but non-differentiable, trajectories.
  • solar_radiation.py: Computes Top-Of-the-Atmosphere (TOA) incident solar radiation compatible with ERA5. This is used as a forcing variable and thus needs to be computed for target lead times in an operational setting.
  • typed_graph.py: Definition of TypedGraphs.
  • typed_graph_net.py: Implementation of simple graph neural network building blocks defined over TypedGraphs that can be combined to build deeper models.
  • xarray_jax.py: A wrapper to let JAX work with xarrays.
  • xarray_tree.py: An implementation of tree.map_structure that works with xarrays.
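The feedback loop described for autoregressive.py and rollout.py can be illustrated with a toy sketch. `toy_one_step` and `rollout` are hypothetical stand-ins, not the package's API. Like GraphCast, the toy model consumes the two most recent states plus a forcing for the target time, and each prediction is shifted into the input window for the next step. The plain Python loop mirrors the rollout.py style; autoregressive.py instead keeps the loop inside JAX so it stays differentiable.

```python
import numpy as np


def toy_one_step(prev_state, curr_state, forcing):
    """Hypothetical one-step model: predicts the next state from two past states."""
    return curr_state + 0.5 * (curr_state - prev_state) + forcing


def rollout(initial_states, forcings):
    """Feeds each prediction back in as an input for the next step."""
    prev, curr = initial_states
    predictions = []
    for forcing in forcings:  # one forcing value per target time
        nxt = toy_one_step(prev, curr, forcing)
        predictions.append(nxt)
        prev, curr = curr, nxt  # slide the two-step input window forward
    return np.stack(predictions)


preds = rollout((np.array(0.0), np.array(1.0)), np.zeros(3))
print(preds)
```

Each extra step reuses the model's own output, which is why errors can accumulate over long rollouts.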

Dependencies

Chex, Dask, Haiku, JAX, JAXline, Jraph, Numpy, Pandas, Python, SciPy, Tree, Trimesh and XArray.

License and attribution

The Colab notebook and the associated code are licensed under the Apache License, Version 2.0. You may obtain a copy of the License at: https://www.apache.org/licenses/LICENSE-2.0.

The model weights are made available for use under the terms of the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You may obtain a copy of the License at: https://creativecommons.org/licenses/by-nc-sa/4.0/.

The weights were trained on ECMWF's ERA5 and HRES data. The Colab includes a few examples of ERA5 and HRES data that can be used as inputs to the models. ECMWF data products are subject to the following terms:

  1. Copyright statement: Copyright "© 2023 European Centre for Medium-Range Weather Forecasts (ECMWF)".
  2. Source www.ecmwf.int
  3. Licence Statement: ECMWF data is published under a Creative Commons Attribution 4.0 International (CC BY 4.0). https://creativecommons.org/licenses/by/4.0/
  4. Disclaimer: ECMWF does not accept any liability whatsoever for any error or omission in the data, their availability, or for any loss or damage arising from their use.

Disclaimer

This is not an officially supported Google product.

Copyright 2023 DeepMind Technologies Limited.

Citation

If you use this work, consider citing our paper:

@article{lam2022graphcast,
      title={GraphCast: Learning skillful medium-range global weather forecasting},
      author={Remi Lam and Alvaro Sanchez-Gonzalez and Matthew Willson and Peter Wirnsberger and Meire Fortunato and Alexander Pritzel and Suman Ravuri and Timo Ewalds and Ferran Alet and Zach Eaton-Rosen and Weihua Hu and Alexander Merose and Stephan Hoyer and George Holland and Jacklynn Stott and Oriol Vinyals and Shakir Mohamed and Peter Battaglia},
      year={2022},
      eprint={2212.12794},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

graphcast's People

Contributors

alvarosg, geraedts-google, lewington-pitsos, mjwillson, shoyer, tewalds, voctav


graphcast's Issues

Haiku requires all `hk.Module`s to be initialized inside an `hk.transform`

Hi. I am trying to run the GraphCast model in a conda environment built with the same package versions as a working run on Google Colab, but whenever I try to build the model, the construct_wrapped_graphcast function raises this error:

Traceback (most recent call last):
  File "/home/eloy.anguiano/repos/graphcast/0.get_model.py", line 76, in <module>
    model = construct_wrapped_graphcast(model_config, task_config)
  File "/home/eloy.anguiano/repos/graphcast/0.get_model.py", line 58, in construct_wrapped_graphcast
    predictor = graphcast.GraphCast(model_config, task_config)
  File "/home/eloy.anguiano/repos/graphcast/graphcast/graphcast.py", line 261, in __init__
    self._grid2mesh_gnn = deep_typed_graph_net.DeepTypedGraphNet(
  File "/home/eloy.anguiano/miniconda3/envs/graphcast_iic/lib/python3.10/site-packages/haiku/_src/module.py", line 139, in __call__
    init(module, *args, **kwargs)
  File "/home/eloy.anguiano/miniconda3/envs/graphcast_iic/lib/python3.10/site-packages/haiku/_src/module.py", line 433, in wrapped
    raise ValueError(
ValueError: All `hk.Module`s must be initialized inside an `hk.transform`.

I checked that the dm-haiku versions on Colaboratory and locally are both 0.0.11. Is there a Dockerfile or something similar for building a working environment? It is very difficult to reproduce the Colab environment locally.

How to reproduce:

from google.cloud import storage
from graphcast import autoregressive
from graphcast import casting
from graphcast import checkpoint
from graphcast import graphcast
from graphcast import normalization
import xarray


MODEL_VERSION = 'GraphCast_operational - ERA5-HRES 1979-2021 - resolution 0.25 - pressure levels 13 - mesh 2to6 - precipitation output only.npz'

# @title Authenticate with Google Cloud Storage
gcs_client = storage.Client.create_anonymous_client()
gcs_bucket = gcs_client.get_bucket("dm_graphcast")

with gcs_bucket.blob(f"params/{MODEL_VERSION}").open("rb") as f:
    ckpt = checkpoint.load(f, graphcast.CheckPoint)
params = ckpt.params
state = {}

model_config = ckpt.model_config
task_config = ckpt.task_config
print("Model description:\n", ckpt.description, "\n")
print("Model license:\n", ckpt.license, "\n")



with gcs_bucket.blob("stats/diffs_stddev_by_level.nc").open("rb") as f:
  diffs_stddev_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/mean_by_level.nc").open("rb") as f:
  mean_by_level = xarray.load_dataset(f).compute()
with gcs_bucket.blob("stats/stddev_by_level.nc").open("rb") as f:
  stddev_by_level = xarray.load_dataset(f).compute()

def construct_wrapped_graphcast(
    model_config: graphcast.ModelConfig,
    task_config: graphcast.TaskConfig):
    """Constructs and wraps the GraphCast Predictor."""
    # Deeper one-step predictor.
    predictor = graphcast.GraphCast(model_config, task_config)

    # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
    # from/to float32 to/from BFloat16.
    predictor = casting.Bfloat16Cast(predictor)

    # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
    # BFloat16 happens after applying normalization to the inputs/targets.
    predictor = normalization.InputsAndResiduals(
        predictor,
        diffs_stddev_by_level=diffs_stddev_by_level,
        mean_by_level=mean_by_level,
        stddev_by_level=stddev_by_level)

    # Wraps everything so the one-step model can produce trajectories.
    predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
    return predictor

model = construct_wrapped_graphcast(model_config, task_config)
print("Done")

Request for help: cropping to my area of interest gives no results

Dear GraphCast experts, I loaded a shapefile of my own research area into your demo code and used it to crop the example_batch data. The final output shows only the Targets; there are no results for Predictions and Diff. When I inspect the Predictions, the latitude and longitude ranges are correct, but all variables have values of NaN. I would be very grateful for an answer. Thank you very much.
Here is my cropping code:

import geopandas
import rioxarray  # registers the .rio accessor on xarray objects
from shapely.geometry import mapping

shp = geopandas.read_file('/content/drive/MyDrive/graphcast-main/SHP/roi.shp')
example_batch.rio.set_spatial_dims(x_dim="lon", y_dim="lat", inplace=True)
example_batch.rio.write_crs("WGS1984", inplace=True)
example_batch = example_batch.rio.clip(shp.geometry.apply(mapping), shp.crs)

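A likely cause, not confirmed in this thread: `rio.clip` keeps the bounding box of the shapefile but sets cells outside the polygon to NaN, while GraphCast operates on a global grid and mesh, so NaNs in the inputs can spread through message passing to every prediction. A safer pattern is to run the model on the full global `example_batch` and crop the predictions afterwards. The sketch below assumes ascending latitudes from -90 to 90 and longitudes in [0, 360), which I believe matches the example data; `crop_box` is a hypothetical helper and the random field is a stand-in for real predictions.

```python
import numpy as np
import xarray as xr


def crop_box(ds, lat_min, lat_max, lon_min, lon_max):
    """Hypothetical helper: select a lat/lon box from a global dataset.

    Assumes `lat` is ascending and `lon` is in [0, 360); boxes that cross
    the 0/360 meridian would need two selections.
    """
    return ds.sel(lat=slice(lat_min, lat_max), lon=slice(lon_min, lon_max))


# Stand-in for global predictions on a 1-degree grid.
lat = np.arange(-90.0, 91.0, 1.0)
lon = np.arange(0.0, 360.0, 1.0)
rng = np.random.default_rng(0)
predictions = xr.Dataset(
    {"2m_temperature": (("lat", "lon"), rng.normal(280.0, 5.0, (lat.size, lon.size)))},
    coords={"lat": lat, "lon": lon},
)

# Crop only after the global forecast has been made.
region = crop_box(predictions, 20.0, 40.0, 110.0, 130.0)
print(region.sizes)
```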

Are forcing variables repeated?

Tasks are defined in graphcast.py. For example, the task for 1-degree resolution and 13 pressure levels has these variables:

input_variables=(
    '2m_temperature',
    'mean_sea_level_pressure',
    '10m_v_component_of_wind',
    '10m_u_component_of_wind',
    'total_precipitation_6hr',
    'temperature',
    'geopotential',
    'u_component_of_wind',
    'v_component_of_wind',
    'vertical_velocity',
    'specific_humidity',
    'toa_incident_solar_radiation',
    'year_progress_sin',
    'year_progress_cos',
    'day_progress_sin',
    'day_progress_cos',
    'geopotential_at_surface',
    'land_sea_mask',
),
forcing_variables=(
    'toa_incident_solar_radiation',
    'year_progress_sin',
    'year_progress_cos',
    'day_progress_sin',
    'day_progress_cos',
),

The forcing variables are included in the input variables. However, inside GraphCast, the function that transforms the xarrays into a NumPy array is _inputs_to_grid_node_features:

def _inputs_to_grid_node_features(
    self,
    inputs: xarray.Dataset,
    forcings: xarray.Dataset,
) -> chex.Array:
    """xarrays -> [num_grid_nodes, batch, num_channels]."""

    # xarray `Dataset` (batch, time, lat, lon, level, multiple vars)
    # to xarray `DataArray` (batch, lat, lon, channels)
    stacked_inputs = model_utils.dataset_to_stacked(inputs)
    stacked_forcings = model_utils.dataset_to_stacked(forcings)
    stacked_inputs = xarray.concat(
        [stacked_inputs, stacked_forcings], dim="channels"
    )

    # xarray `DataArray` (batch, lat, lon, channels)
    # to single numpy array with shape [lat_lon_node, batch, channels]
    grid_xarray_lat_lon_leading = model_utils.lat_lon_to_leading_axes(
        stacked_inputs
    )
    return xarray_jax.unwrap(grid_xarray_lat_lon_leading.data).reshape(
        (-1,) + grid_xarray_lat_lon_leading.data.shape[2:]
    )

The question is: why are the input and forcing variables concatenated, if the forcings are already included in the inputs?
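My reading of `data_utils.extract_inputs_targets_forcings` (worth checking against the source) is that the two datasets cover different times: `inputs` holds every input variable, forcings included, at the input time steps, while `forcings` holds only the forcing variables at the target time. Concatenating them would then add the target-time forcings rather than duplicate anything. The toy split below illustrates that idea with hypothetical values and only two variables; it is not the library's code.

```python
import pandas as pd
import xarray as xr

# Toy dataset with two input steps (-6h, 0h) and one target step (+6h).
ds = xr.Dataset(
    {
        "2m_temperature": ("time", [280.0, 281.0, 282.0]),
        "toa_incident_solar_radiation": ("time", [0.0, 100.0, 200.0]),
    },
    coords={"time": pd.to_timedelta([-6, 0, 6], unit="h")},
)
input_variables = ("2m_temperature", "toa_incident_solar_radiation")
forcing_variables = ("toa_incident_solar_radiation",)

# Inputs: all input variables (forcings included) at the input times.
inputs = ds[list(input_variables)].sel(time=slice(None, pd.Timedelta(0)))
# Forcings: only the forcing variables, taken at the target time.
targets = ds.sel(time=slice(pd.Timedelta(hours=6), None))
forcings = targets[list(forcing_variables)]

print(inputs["toa_incident_solar_radiation"].values)
print(forcings["toa_incident_solar_radiation"].values)
```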

GraphCast's forecasts beyond 10 days

Hi authors of GraphCast,

Congrats on the amazing work and thank you for open-sourcing the model. Do you have the forecasts of GraphCast over 10 days of lead time by any chance? We're working on a project and hope to compare our method with GraphCast at longer lead times.

JAX error while training

Hello, I wonder whether anyone has encountered the same problem. When I run graphcast_demo.ipynb, every step works until training, where I get the error shown in two screenshots (not reproduced here).
I use the example data from the Google Cloud bucket and have not changed any code. Does anyone know how to solve this? My JAX version is 0.4.20, with CUDA 11.8.

weights license - use of graphcast

I've read through the license (https://creativecommons.org/licenses/by-nc-sa/4.0/) and all of the terms are clear except the NonCommercial use. I work for the US Government (National Weather Service), and we're not going to use the GraphCast weights directly for commercial purposes. However, the output data and images we create are planned to be on public servers. Once they are on Government public servers, corporations can use them to make a profit (i.e. for commercial purposes). Is this OK, since we're not directly using the model weights for commercial purposes?

Division of grid points

I recommend that grid points at 1-degree resolution start at 89.5 degrees, because at 90N and at 90S every longitude corresponds to the same point (the North Pole and the South Pole, respectively).
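The arithmetic behind this suggestion can be checked quickly. An endpoint-inclusive 1-degree grid has 181 latitude rows (the shape `[..., 181, 360]` in the TracerArrayConversionError log below), and each polar row holds 360 copies of a single physical point; a cell-centred grid from -89.5 to 89.5 has 180 rows and no polar duplicates. `redundant_points` is a hypothetical helper, and this is plain NumPy arithmetic that makes no claim about which convention GraphCast's code expects.

```python
import numpy as np

# Endpoint-inclusive grid: 181 rows, including both poles.
lat_inclusive = np.linspace(-90.0, 90.0, 181)
# Cell-centred grid: 180 rows from -89.5 to 89.5.
lat_centred = np.arange(-89.5, 90.0, 1.0)
lon = np.arange(0.0, 360.0, 1.0)


def redundant_points(lats, lons):
    """Grid points that duplicate a pole: each polar row has len(lons) copies of one point."""
    polar_rows = int(np.sum(np.isclose(np.abs(lats), 90.0)))
    return polar_rows * (len(lons) - 1)


print(redundant_points(lat_inclusive, lon), redundant_points(lat_centred, lon))
```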

Model Predict

Hello, can this model predict the weather in the region 20N-40N, 110E-130E? I see that the model can only be trained on data with latitudes running from -90 to 90. How can I train on data with latitudes between 20 and 40? When I train on latitudes 20 to 40, errors are reported when computing the losses.

Constructing datasets from CDS products

My understanding is that each input data file is constructed from a couple of CDS data products for 'atmospheric' and 'single-level' variables. For example, the two queries below pull the required data (for a specific date), as I understand it:

import cdsapi

c = cdsapi.Client()

c.retrieve(
    'reanalysis-era5-pressure-levels',
    {
        'product_type': 'reanalysis',
        'format': 'netcdf',
        'variable': [
            'geopotential', 'specific_humidity', 'temperature',
            'u_component_of_wind', 'v_component_of_wind', 'vertical_velocity',
        ],
        'year': '2023',
        'month': '11',
        'day': '01',
        'time': [
            '00:00', '06:00',
        ],
        'pressure_level': [
            '1', '2', '3',
            '5', '7', '10',
            '20', '30', '50',
            '70', '100', '125',
            '150', '175', '200',
            '225', '250', '300',
            '350', '400', '450',
            '500', '550', '600',
            '650', '700', '750',
            '775', '800', '825',
            '850', '875', '900',
            '925', '950', '975',
            '1000',
        ],
    },
    'atmospheric.nc')

c.retrieve(
    'reanalysis-era5-single-levels',
    {
        'product_type': 'reanalysis',
        'variable': [
            '10m_u_component_of_wind', '10m_v_component_of_wind', '2m_temperature',
            'geopotential', 'land_sea_mask', 'mean_sea_level_pressure',
            'toa_incident_solar_radiation', 'total_precipitation',
        ],
        'year': '2023',
        'month': '11',
        'day': '01',
        'time': [
            '00:00', '06:00',
        ],
        'format': 'netcdf',
    },
    'single-level.nc')

Presumably, these files are then spliced together to form an input file. While I'm able to combine the files easily enough, there are subtleties that I'm clearly messing up, like constructing the datetime and batch coordinates, which seem to be mandatory (e.g., ValueError: 'datetime' must be in data coordinates.)

Do you plan to publish a script or instructions for constructing input files from these data products in the format expected by the model? It would be a fantastic help for applying the model!

Many thanks in advance,
Dan
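Until an official script exists, the sketch below shows the kind of reshaping that seems to be needed, based on the layout of the example files: a `batch` dimension, a `time` coordinate of offsets from the first step, and a `datetime` coordinate over `(batch, time)` holding the absolute timestamps. It uses a tiny synthetic dataset in place of the merged CDS files. `to_graphcast_layout` is a hypothetical helper, the coordinate rename mapping is an assumption (newer CDS downloads may also use `valid_time`), and CDS variable short names such as `t` or `z` would still need renaming to the long names GraphCast uses. Hourly `total_precipitation` would presumably also need summing into `total_precipitation_6hr`. Compare the result against one of the example files in the bucket before relying on it.

```python
import numpy as np
import pandas as pd
import xarray as xr


def to_graphcast_layout(ds):
    """Hypothetical helper: reshape a merged ERA5 dataset towards the example-file layout."""
    # Assumed ERA5/CDS coordinate names -> names used in the example files.
    renames = {"latitude": "lat", "longitude": "lon",
               "pressure_level": "level", "valid_time": "time"}
    ds = ds.rename({k: v for k, v in renames.items() if k in ds.coords})
    ds = ds.sortby("lat")  # ERA5 latitudes are stored north to south
    # Absolute timestamps become a (batch, time) `datetime` coordinate;
    # `time` itself becomes the offset from the first step.
    datetimes = ds["time"].values
    ds = ds.assign_coords(time=datetimes - datetimes[0])
    ds = ds.expand_dims("batch")
    return ds.assign_coords(datetime=(("batch", "time"), datetimes[None, :]))


# Tiny synthetic stand-in for the merged CDS files.
raw = xr.Dataset(
    {"t2m": (("time", "latitude", "longitude"), np.zeros((2, 3, 4)))},
    coords={
        "time": pd.to_datetime(["2023-11-01T00:00", "2023-11-01T06:00"]),
        "latitude": [90.0, 0.0, -90.0],
        "longitude": [0.0, 90.0, 180.0, 270.0],
    },
)
ds = to_graphcast_layout(raw)
print(ds)
```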

Develop an hourly model

Thanks for and congratulations on the great work here!

I'm probably not the only person excited to use GraphCast but limited by the 6 hour resolution. I work in the energy sector. Forecasting the shape of electricity prices across the day matters a lot to battery revenue, and is highly dependent on weather. Accuracy at ~hourly resolution probably determines when generic weather apps can switch to new models, too.

So, the feature request is to develop a model with hourly rather than 6-hourly resolution. I'm sure you're already considering this but thought I'd submit a formal issue so those of us interested can follow along.

How to save the model

I have successfully run graphcast_demo.ipynb locally, but how do I save the model parameters after training on the sample dataset?

I tried the following code:

# @title Autoregressive rollout (keep the loop in JAX)
print("Inputs:  ", train_inputs.dims.mapping)
print("Targets: ", train_targets.dims.mapping)
print("Forcings:", train_forcings.dims.mapping)

predictions = run_forward_jitted(
    rng=jax.random.PRNGKey(0),
    inputs=train_inputs,
    targets_template=train_targets * np.nan,
    forcings=train_forcings)
predictions

with open(f"dm_graphcast/params/test.npz", "wb") as f:
    checkpoint.dump(f, graphcast.CheckPoint)

But it doesn't work.
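The snippet passes the class `graphcast.CheckPoint` itself to `checkpoint.dump`, so there is nothing to serialize; the call needs an instance holding the parameters, presumably something like `graphcast.CheckPoint(params=params, model_config=model_config, task_config=task_config, description=..., license=...)`, using the field names read from the loaded `ckpt` in the first issue above. Two further points: `run_forward_jitted` only makes predictions and does not change `params`, and the `dm_graphcast/params/` directory must exist locally. The self-contained sketch below reproduces the class-versus-instance mistake with a hypothetical `ToyCheckPoint` and `toy_dump` built on `dataclasses` and `np.savez`; it does not use the graphcast package.

```python
import dataclasses
import io

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyCheckPoint:
    """Hypothetical stand-in for graphcast.CheckPoint with fewer fields."""
    params: dict
    description: str


def toy_dump(dest, ckpt):
    """Hypothetical serializer: writes a dataclass *instance* to an .npz file."""
    fields = dataclasses.asdict(ckpt)  # TypeError if `ckpt` is the class itself
    arrays = {f"params__{k}": np.asarray(v) for k, v in fields["params"].items()}
    arrays["description"] = np.asarray(fields["description"])
    np.savez(dest, **arrays)


params = {"w": np.ones((2, 2)), "b": np.zeros(2)}

# Mistake from the snippet above: passing the class instead of an instance.
try:
    toy_dump(io.BytesIO(), ToyCheckPoint)
    class_call_failed = False
except TypeError:
    class_call_failed = True

# Fix: wrap the parameters you want to keep in an instance, then dump it.
buffer = io.BytesIO()
toy_dump(buffer, ToyCheckPoint(params=params, description="test run"))
buffer.seek(0)
restored = np.load(buffer)
print(class_call_failed, restored.files)
```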

Python 3.9 error

Though the code works fine with Python 3.10, Python 3.9 throws an error:

Traceback (most recent call last):
File "/xxxxxxxgraphcast/graphcast_example.py", line 44, in
ckpt=checkpoint.load(f, graphcast.CheckPoint)
File "/xxxxxxxgraphcast/graphcast/checkpoint.py", line 54, in load
return _convert_types(typ, _unflatten(np.load(source)))
File "/xxxxxxxgraphcast/graphcast/checkpoint.py", line 117, in _convert_types
if isinstance(f.type, (types.UnionType, type(Optional[int]))):
AttributeError: module 'types' has no attribute 'UnionType'

OS: Ubuntu 20.04.6
Python: 3.9.17
jax: 0.4.14
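`types.UnionType`, the type of `X | Y` annotations, was added in Python 3.10, so the simplest fix is to use Python 3.10 or newer. If 3.9 is unavoidable, a check along the lines of the hypothetical `is_union` below avoids the missing attribute on older interpreters. Other 3.10-only constructs elsewhere in the package could still fail, so treat this as a sketch rather than a supported patch for checkpoint.py.

```python
import types
from typing import Optional, Union

# typing.Union[...] / Optional[...] exist on all these versions;
# types.UnionType (the type of `int | None`) only exists on Python >= 3.10.
_UNION_TYPES = (type(Optional[int]),)
if hasattr(types, "UnionType"):
    _UNION_TYPES += (types.UnionType,)


def is_union(annotation) -> bool:
    """Hypothetical version-tolerant replacement for the isinstance check in _convert_types."""
    return isinstance(annotation, _UNION_TYPES)


print(is_union(Optional[int]), is_union(int))
```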

Predicting Forecast for 10 Days , 5 Days

Hi,
I am very new to GraphCast. I want to know how to produce 10-day and 5-day forecasts. I ran the notebook (graphcast_demo.ipynb) successfully, but I did not understand the changes required to forecast up to a different, user-defined number of time steps.
Right now, I am confused about how to use a pre-trained GraphCast model to produce a 10-day or 5-day forecast.
I would really appreciate it if someone could explain the necessary steps clearly so that I can follow them.
Thank you very much!
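Because each GraphCast step advances 6 hours, the lead time only determines the number of autoregressive steps: 5 days is 20 steps and 10 days is 40. As I understand the demo notebook, the number of steps is set by the `target_lead_times` slice passed to `data_utils.extract_inputs_targets_forcings`, so with that helper the source dataset must contain enough time steps for every target (the example files such as `..._steps-04.nc` are short). The helpers below are a hypothetical sketch of the arithmetic only.

```python
import datetime

STEP = datetime.timedelta(hours=6)  # length of one GraphCast step


def num_steps(lead_time: datetime.timedelta) -> int:
    """Number of 6-hour autoregressive steps needed to reach `lead_time`."""
    steps, remainder = divmod(lead_time, STEP)
    if remainder:
        raise ValueError(f"{lead_time} is not a multiple of {STEP}")
    return steps


def lead_time_slice(days: int) -> slice:
    """Hypothetical helper: a target_lead_times-style slice covering `days` days."""
    return slice("6h", f"{num_steps(datetime.timedelta(days=days)) * 6}h")


print(num_steps(datetime.timedelta(days=5)), lead_time_slice(10))
```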

Jax cannot access GPU when generating predictions

When generating predictions, this warning appears:

WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Using an A100 environment on colab with the stock notebook except with the modifications mentioned here.
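That warning means JAX could not find a usable GPU backend; one common cause is a CPU-only `jax`/`jaxlib` build ending up installed in the runtime. A quick diagnostic, using only standard JAX calls, is to print the backend and visible devices before running the model. On a working A100 runtime one would expect a GPU backend and a CUDA device, but the exact output depends on the install.

```python
import jax

# Which backend did JAX pick, and which devices can it see?
backend = jax.default_backend()  # e.g. "cpu" or "gpu"
devices = jax.devices()
print("JAX", jax.__version__, "backend:", backend)
for device in devices:
    print(" ", device.platform, device)
```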

evaluation on selected region

Hi. Thanks for the great work!
One quick question: how can one evaluate GraphCast on local regions, e.g., North America, by selecting lat/lon points and measuring the error over that region?
@tewalds
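One way to do this, sketched under assumptions: run the model globally, then compute the error only over a lat/lon box, with cos(latitude) area weights so that cells nearer the poles, which cover less area, count less. `regional_weighted_rmse` is a hypothetical helper; it assumes ascending `lat` and `lon` in [0, 360) (North America is roughly 190 to 310 in those units), and its weighting is a simple stand-in similar in spirit to, but not copied from, the latitude weighting in losses.py.

```python
import numpy as np
import xarray as xr


def regional_weighted_rmse(pred, target, lat_range, lon_range):
    """Hypothetical helper: cos(lat)-weighted RMSE over a lat/lon box."""
    box = dict(lat=slice(*lat_range), lon=slice(*lon_range))
    squared_error = ((pred - target) ** 2).sel(**box)
    weights = np.cos(np.deg2rad(squared_error["lat"]))
    return np.sqrt(squared_error.weighted(weights).mean(("lat", "lon")))


lat = np.arange(-90.0, 91.0, 1.0)
lon = np.arange(0.0, 360.0, 1.0)
target = xr.DataArray(np.zeros((lat.size, lon.size)),
                      coords={"lat": lat, "lon": lon}, dims=("lat", "lon"))
# Toy error of 2 inside North America, 0 elsewhere.
pred = target.copy()
pred.loc[dict(lat=slice(15, 70), lon=slice(190, 310))] = 2.0

rmse_na = regional_weighted_rmse(pred, target, (15, 70), (190, 310))
rmse_eu = regional_weighted_rmse(pred, target, (35, 70), (0, 40))
print(float(rmse_na), float(rmse_eu))
```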

About loss weights

Great job! I have a question: How are the weights of different variables in the loss function determined? In other words, how did you obtain the current loss function weights?
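As far as I can tell from reading graphcast.py and losses.py (worth verifying against the current source), the loss combines two kinds of weights: per-variable weights, where 2m_temperature and the pressure-level variables get 1.0 and the other surface variables get 0.1, and per-level weights proportional to pressure and normalised to mean 1. The sketch below only shows how such weights would combine in a toy weighted MSE without latitude weighting; the helper names are hypothetical, and it does not explain how the values were chosen.

```python
import numpy as np

# Values as I read them from graphcast.py; unlisted variables default to 1.0.
PER_VARIABLE_WEIGHTS = {
    "2m_temperature": 1.0,
    "10m_u_component_of_wind": 0.1,
    "10m_v_component_of_wind": 0.1,
    "mean_sea_level_pressure": 0.1,
    "total_precipitation_6hr": 0.1,
}


def normalized_level_weights(levels):
    """Weights proportional to pressure level, rescaled to have mean 1."""
    levels = np.asarray(levels, dtype=float)
    return levels / levels.mean()


def toy_weighted_mse(surface_errors, level_errors, levels):
    """Hypothetical: sum over variables of variable weight * (level-weighted) MSE.

    `level_errors` arrays have a leading level axis matching `levels`.
    """
    level_w = normalized_level_weights(levels)
    total = 0.0
    for name, err in surface_errors.items():
        total += PER_VARIABLE_WEIGHTS.get(name, 1.0) * float(np.mean(err))
    for name, err in level_errors.items():
        per_level = np.asarray(err, dtype=float).reshape(len(levels), -1).mean(axis=1)
        total += PER_VARIABLE_WEIGHTS.get(name, 1.0) * float(np.mean(level_w * per_level))
    return total


w = normalized_level_weights([50, 500, 1000])
loss = toy_weighted_mse(
    {"2m_temperature": np.ones(4), "mean_sea_level_pressure": np.ones(4)},
    {"temperature": np.ones((3, 4))},
    [50, 500, 1000],
)
print(w, loss)
```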

About Memory Use

Using the configuration resolution = 0.25°, mesh_size = 6, latent_size = 512 and gnn_msg_steps = 16, an error message is displayed when the dataset_source-era5_date-2022-01-01_res-0.25_levels-13_steps-04.nc data is used for training.

2023-09-15 15:37:29.263873: W external/xla/xla/service/hlo_rematerialization.cc:2202] Can't reduce memory use below 23.81GiB (25567444992 bytes) by rematerialization; only reduced to 86.14GiB (92489619062 bytes), down from 95.71GiB (102770769072 bytes) originally
2023-09-15 15:37:48.396554: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 85.87GiB (rounded to 92198188544)requested by op
2023-09-15 15:37:48.398429: W external/tsl/tsl/framework/bfc_allocator.cc:497] ******______________________________________________________________________________________________
2023-09-15 15:37:48.404245: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2593] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 92198188384 bytes.

Even if I reduce mesh_size to 5, 38.26 GB of memory is needed on a Tesla V100. What kind of hardware can run this configuration for training, or what other ways are there to reduce memory use?
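A back-of-the-envelope calculation shows why the 0.25° configuration is so much heavier than the 1° one: the grid has about 16 times as many nodes, and every latent tensor on the grid scales with it. The sketch counts only a single grid-node latent array of size 512 in bfloat16 (2 bytes per value) for a batch of one, so it describes one tensor, not peak memory. The real peak also includes edge latents, every message-passing step, and everything kept for backpropagation through the autoregressive steps, which is why gradient checkpointing (the `gradient_checkpointing=True` option to `autoregressive.Predictor` in the first issue above) and fewer training steps are the main levers I know of.

```python
def latent_bytes(num_nodes: int, latent_size: int = 512, bytes_per_value: int = 2) -> int:
    """Bytes for one [num_nodes, latent_size] array (bfloat16 by default)."""
    return num_nodes * latent_size * bytes_per_value


GIB = 2**30
grid_025 = 721 * 1440  # 0.25-degree lat/lon grid, poles included
grid_100 = 181 * 360   # 1-degree lat/lon grid, poles included

for name, nodes in [("0.25 deg", grid_025), ("1 deg", grid_100)]:
    print(f"{name}: {nodes:,} grid nodes -> {latent_bytes(nodes) / GIB:.2f} GiB per latent array")
```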

Data required for prediction

In order to get the predictions for 6 hours later, we need to pass the current weather values and the values from 6 hours earlier. But what about the spatial extent: is data for all coordinates required as input, or can a subset of coordinates be passed, depending on the use case?

Possible standalone distribution of `xarray_jax`?

Hi! I'm a big fan of the boilerplate code used to wrap xarray into a JAX-compatible entity -- I think this could have wider usage were it more well-known, especially for this kind of deep learning + weather project.

Would you consider distributing this code as a stand-alone helper module? I'm happy to volunteer refactoring this code into a small library, since I'm probably going to use it anyway -- let me know if you'd like me to take the initiative on that (and where it should live).

Sharing models on the Hugging Face Hub

Hi there!

At the moment, models are downloaded from a public GCP bucket. It would be useful to have the models on the Hugging Face Hub as well (maybe under the https://hf.co/google organization?). This would have some benefits, such as:

  • More visibility + discoverability of the existing models
  • Programmatic model access (either via wget or huggingface_hub if you want some of our caching mechanisms)
  • Download stats

Here is a guide in case you're interested, and I'm happy to guide you in the process.

Omar
🤗

[GraphCast Operational Model] Issue with Negative Precipitation Data in GraphCast Operational Model Output

Background:
I've been using the GraphCast Operational model to generate 6-hour cumulative precipitation forecast data and have attempted to visualize it with NASA's Panoply software. Despite the model's ability to predict precipitation without using rainfall as an input, as described by the authors in a Science article, I couldn't find a parameter explicitly labeled as precipitation amount in Panoply. Instead, I came across a parameter named "Mixed intervals Accumulation," which raises my suspicion that Panoply might not be compatible with displaying this type of data.

Questions:
By programming my way through the GRIB files, I managed to access the data marked as "Total precipitation." However, I've noticed that some of the data includes negative values, which seems counterintuitive for precipitation metrics. As I am not a professional meteorologist, I would like to understand the following:

Is it normal to encounter negative values in precipitation data, or could this indicate an error or some other issue?
If negative values are normal, what do they signify?
Should I apply any special treatment to these negative values when analyzing precipitation data?

Additional Information:
The code snippet I used is as follows:

import pygrib  # Library for reading GRIB format data
import datetime  # Library for handling dates and times
import numpy as np  # Library for mathematical operations
import matplotlib.pyplot as plt  # Library for plotting
import cartopy.crs as ccrs  # Library for map projections
import cartopy.feature as cfeature  # Library for map features
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER  # Formatting for map gridline labels

# Open the GRIB2 file
file_path = 'graphcast.grib'  # Replace with your GRIB2 file path
grbs = pygrib.open(file_path)

# Set the target date and time
target_date = datetime.datetime(2023, 12, 30, 18, 0)  # Example date and time
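# Find data matching the specific variable and date
for grb in grbs:
    if grb.name == "Total precipitation" and grb.validDate == target_date:
        data = grb.values  # Read the data
        lats, lons = grb.latlons()  # Get the latitudes and longitudes corresponding to the data
        break
else:
    # No message matched: fail here instead of with a NameError on `data` below
    grbs.close()
    raise ValueError(f"No 'Total precipitation' field valid at {target_date}")
grbs.close()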

# Print out statistical information about the data
print(f'Minimum value: {data.min()}')  # Minimum value
print(f'Maximum value: {data.max()}')  # Maximum value
print(f'Mean value: {data.mean()}')  # Mean value
print(f'Median value: {np.median(data)}')  # Median value
print(f'Standard deviation: {data.std()}')  # Standard deviation


# Create the map
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())
# Create color levels for the contour plot
levels = np.linspace(-0.0002, 0.12, 10)  # Create levels from a bit below the minimum to above the maximum value

# Plot the contour map using the color levels
precipitation = ax.contourf(lons, lats, data, levels=levels, transform=ccrs.PlateCarree(), cmap='viridis')

# Add map features like coastlines and borders
ax.coastlines()
ax.add_feature(cfeature.BORDERS, linestyle=':')

# Add gridlines and labels for longitude and latitude
gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True)
gl.top_labels = False  # Disable top labels
gl.right_labels = False  # Disable right labels
gl.xformatter = LONGITUDE_FORMATTER  # Format for longitude
gl.yformatter = LATITUDE_FORMATTER  # Format for latitude

# Add a colorbar to explain the color encoding of precipitation levels
plt.colorbar(precipitation, ax=ax, orientation='horizontal', pad=0.05, aspect=50, label='Precipitation (m/6hr)', ticks=levels)
plt.title(f'6-Hour Accumulated Precipitation Forecast (Ending on {target_date.strftime("%Y-%m-%d %H:%M")})')

plt.show()  # Display the map
# python3.10 requirements.txt
Cartopy==0.22.0
certifi==2023.11.17
contourpy==1.2.0
cycler==0.12.1
fonttools==4.47.0
kiwisolver==1.4.5
matplotlib==3.8.2
mplcursors==0.5.2
numpy==1.26.2
packaging==23.2
pandas==2.1.4
Pillow==10.1.0
pygrib==2.1.5
pyparsing==3.1.1
pyproj==3.6.1
pyshp==2.3.1
python-dateutil==2.8.2
pytz==2023.3.post1
shapely==2.0.2
six==1.16.0
tzdata==2023.3
xarray==2023.12.0

A visualization of the precipitation data produced by the Python code above: [screenshot not reproduced]
A Panoply visualization of the precipitation data: [two screenshots not reproduced]

My data file can be downloaded from the Google Drive link in the attachments below.

I am grateful for any explanation and advice.
Attachments:
Download link for Panoply Software: https://www.giss.nasa.gov/tools/panoply/download/
Download link for the GRIB data file (6.5 GB):
https://drive.google.com/file/d/1JrsCXZcRBXgEQg-Xu0rd7EsvR_XLARPU/view?usp=drive_link
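GraphCast's precipitation output comes from a regression model, and as far as I know nothing in it forces predictions to be non-negative, so small negative values are most likely a model artefact rather than something physically meaningful. A common treatment, offered here as a suggestion and not an official recommendation, is to report how many values are negative and clip them to zero before analysis or plotting. `clip_negative_precip` is a hypothetical helper, and the small synthetic array stands in for `data` from the script above.

```python
import numpy as np


def clip_negative_precip(data):
    """Hypothetical helper: returns (clipped array, fraction of negative values)."""
    data = np.asarray(data, dtype=float)
    negative_fraction = float(np.mean(data < 0))
    return np.clip(data, 0.0, None), negative_fraction


# Synthetic stand-in for one 6-hour precipitation field, in metres.
field = np.array([[-0.0002, 0.0], [0.0015, 0.12]])
clipped, frac = clip_negative_precip(field)
print(f"{frac:.0%} of values were negative; minimum after clipping: {clipped.min()}")
```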

How to solve TracerArrayConversionError (about xarray_jax)?

When I run the cells "Loss computation (autoregressive loss over multiple steps)" and "Gradient computation (backprop through time)" locally, I encounter this problem.

What is the cause of the problem?
How should this problem be solved locally?
Is there something wrong with my environment setup? Why is there no problem when running in Google Colab?
Oh, I have so many questions...
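For context, JAX raises this error whenever code inside a `jax.jit`-compiled function tries to turn a traced value into a concrete NumPy array, for example via `np.asarray` or an implicit `__array__()` call. The log below points into xarray_jax.py, so one plausible but unconfirmed explanation is a library version mismatch between the local environment and Colab (for example a different xarray or JAX version) that makes some xarray operation call `__array__()` on a traced array. The minimal reproduction below, unrelated to GraphCast's own code, shows the error and the usual fix of staying within `jax.numpy`.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def bad_mean(x):
    # np.asarray needs a concrete NumPy array, which a tracer cannot provide.
    return np.asarray(x).mean()


@jax.jit
def good_mean(x):
    # jax.numpy operations stay traceable under jit.
    return jnp.mean(x)


x = jnp.arange(4.0)
try:
    bad_mean(x)
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True
result = good_mean(x)
print(raised, float(result))
```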

Below is the complete error log:

{
	"name": "TracerArrayConversionError",
	"message": "The numpy.ndarray conversion method __array__() was called on traced array with shape bfloat16[1].
The error occurred while tracing the function apply_fn at c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\haiku\\_src\\transform.py:440 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:f32[1,2,181,360] = sub b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = div b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = sub b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = div b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = sub b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

(Additional originating lines are not shown.)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError",
	"stack": "---------------------------------------------------------------------------
TracerArrayConversionError                Traceback (most recent call last)
Cell In[33], line 3
      1 # @title Loss computation (autoregressive loss over multiple steps)
      2 # 
----> 3 loss, diagnostics = loss_fn_jitted(
      4     rng=jax.random.PRNGKey(0),
      5     inputs=train_inputs,
      6     targets=train_targets,
      7     forcings=train_forcings)
      8 print(\"Loss:\", float(loss))

Cell In[29], line 68, in drop_state.<locals>.<lambda>(**kw)
     67 def drop_state(fn):
---> 68   return lambda **kw: fn(**kw)[0]

    [... skipping hidden 12 frame]

File c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\haiku\\_src\\transform.py:456, in transform_with_state.<locals>.apply_fn(params, state, rng, *args, **kwargs)
    454 with base.new_context(params=params, state=state, rng=rng) as ctx:
    455   try:
--> 456     out = f(*args, **kwargs)
    457   except jax.errors.UnexpectedTracerError as e:
    458     raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e

Cell In[29], line 37, in loss_fn(model_config, task_config, inputs, targets, forcings)
     34 @hk.transform_with_state
     35 def loss_fn(model_config, task_config, inputs, targets, forcings):
     36   predictor = construct_wrapped_graphcast(model_config, task_config)
---> 37   loss, diagnostics = predictor.loss(inputs, targets, forcings)
     38   return xarray_tree.map_structure(
     39       lambda x: xarray_jax.unwrap_data(x.mean(), require_jax=True),
     40       (loss, diagnostics))

File d:\\code\\graphcast-0.1\\graphcast\\autoregressive.py:236, in Predictor.loss(self, inputs, targets, forcings, **kwargs)
    230 \"\"\"The mean of the per-timestep losses of the underlying predictor.\"\"\"
    231 if targets.sizes['time'] == 1:
    232   # If there is only a single target timestep then we don't need any
    233   # autoregressive feedback and can delegate the loss directly to the
    234   # underlying single-step predictor. This means the underlying predictor
    235   # doesn't need to implement .loss_and_predictions.
--> 236   return self._predictor.loss(inputs, targets, forcings, **kwargs)
    238 constant_inputs = self._get_and_validate_constant_inputs(
    239     inputs, targets, forcings)
    240 self._validate_targets_and_forcings(targets, forcings)

File d:\\code\\graphcast-0.1\\graphcast\\normalization.py:174, in InputsAndResiduals.loss(self, inputs, targets, forcings, **kwargs)
    170 norm_forcings = normalize(forcings, self._scales, self._locations)
    171 norm_target_residuals = xarray_tree.map_structure(
    172     lambda t: self._subtract_input_and_normalize_target(inputs, t),
    173     targets)
--> 174 return self._predictor.loss(
    175     norm_inputs, norm_target_residuals, forcings=norm_forcings, **kwargs)

File d:\\code\\graphcast-0.1\\graphcast\\casting.py:77, in Bfloat16Cast.loss(self, inputs, targets, forcings, **kwargs)
     74   return self._predictor.loss(inputs, targets, forcings, **kwargs)
     76 with bfloat16_variable_view():
---> 77   loss, scalars = self._predictor.loss(
     78       *_all_inputs_to_bfloat16(inputs, targets, forcings), **kwargs)
     80 if loss.dtype != jnp.bfloat16:
     81   raise ValueError(f'Expected bfloat16 loss, got {loss.dtype}')

File d:\\code\\graphcast-0.1\\graphcast\\graphcast.py:424, in GraphCast.loss(self, inputs, targets, forcings)
    418 def loss(  # pytype: disable=signature-mismatch  # jax-ndarray
    419     self,
    420     inputs: xarray.Dataset,
    421     targets: xarray.Dataset,
    422     forcings: xarray.Dataset,
    423     ) -> predictor_base.LossAndDiagnostics:
--> 424   loss, _ = self.loss_and_predictions(inputs, targets, forcings)
    425   return loss

File d:\\code\\graphcast-0.1\\graphcast\\graphcast.py:400, in GraphCast.loss_and_predictions(self, inputs, targets, forcings)
    397 predictions = self(
    398     inputs, targets_template=targets, forcings=forcings, is_training=True)
    399 # Compute loss.
--> 400 loss = losses.weighted_mse_per_level(
    401     predictions, targets,
    402     per_variable_weights={
    403         # Any variables not specified here are weighted as 1.0.
    404         # A single-level variable, but an important headline variable
    405         # and also one which we have struggled to get good performance
    406         # on at short lead times, so leaving it weighted at 1.0, equal
    407         # to the multi-level variables:
    408         \"2m_temperature\": 1.0,
    409         # New single-level variables, which we don't weight too highly
    410         # to avoid hurting performance on other variables.
    411         \"10m_u_component_of_wind\": 0.1,
    412         \"10m_v_component_of_wind\": 0.1,
    413         \"mean_sea_level_pressure\": 0.1,
    414         \"total_precipitation_6hr\": 0.1,
    415     })
    416 return loss, predictions

File d:\\code\\graphcast-0.1\\graphcast\\losses.py:69, in weighted_mse_per_level(predictions, targets, per_variable_weights)
     66     loss *= normalized_level_weights(target).astype(loss.dtype)
     67   return _mean_preserving_batch(loss)
---> 69 losses = xarray_tree.map_structure(loss, predictions, targets)
     70 return sum_per_variable_losses(losses, per_variable_weights)

File d:\\code\\graphcast-0.1\\graphcast\\xarray_tree.py:56, in map_structure(func, *structures)
     54 first = structures[0]
     55 if isinstance(first, xarray.Dataset):
---> 56   data = {k: func(*[s[k] for s in structures]) for k in first.keys()}
     57   if all(isinstance(a, (type(None), xarray.DataArray))
     58          for a in data.values()):
     59     data_arrays = [v.rename(k) for k, v in data.items() if v is not None]

File d:\\code\\graphcast-0.1\\graphcast\\xarray_tree.py:56, in <dictcomp>(.0)
     54 first = structures[0]
     55 if isinstance(first, xarray.Dataset):
---> 56   data = {k: func(*[s[k] for s in structures]) for k in first.keys()}
     57   if all(isinstance(a, (type(None), xarray.DataArray))
     58          for a in data.values()):
     59     data_arrays = [v.rename(k) for k, v in data.items() if v is not None]

File d:\\code\\graphcast-0.1\\graphcast\\losses.py:67, in weighted_mse_per_level.<locals>.loss(prediction, target)
     65 if 'level' in target.dims:
     66   loss *= normalized_level_weights(target).astype(loss.dtype)
---> 67 return _mean_preserving_batch(loss)

File d:\\code\\graphcast-0.1\\graphcast\\losses.py:74, in _mean_preserving_batch(x)
     73 def _mean_preserving_batch(x: xarray.DataArray) -> xarray.DataArray:
---> 74   return x.mean([d for d in x.dims if d != 'batch'], skipna=False)

File c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\xarray\\core\\_aggregations.py:1663, in DataArrayAggregations.mean(self, dim, skipna, keep_attrs, **kwargs)
   1588 def mean(
   1589     self,
   1590     dim: Dims = None,
   (...)
   1594     **kwargs: Any,
   1595 ) -> Self:
   1596     \"\"\"
   1597     Reduce this DataArray's data by applying ``mean`` along some dimension(s).
   1598 
   (...)
   1661     array(nan)
   1662     \"\"\"
-> 1663     return self.reduce(
   1664         duck_array_ops.mean,
   1665         dim=dim,
   1666         skipna=skipna,
   1667         keep_attrs=keep_attrs,
   1668         **kwargs,
   1669     )

File c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\xarray\\core\\dataarray.py:3760, in DataArray.reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)
   3716 def reduce(
   3717     self,
   3718     func: Callable[..., Any],
   (...)
   3724     **kwargs: Any,
   3725 ) -> Self:
   3726     \"\"\"Reduce this array by applying `func` along some dimension(s).
   3727 
   3728     Parameters
   (...)
   3757         summarized data and the indicated dimension(s) removed.
   3758     \"\"\"
-> 3760     var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)
   3761     return self._replace_maybe_drop_dims(var)

File c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\xarray\\core\\variable.py:1756, in Variable.reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)
   1749 keep_attrs_ = (
   1750     _get_keep_attrs(default=False) if keep_attrs is None else keep_attrs
   1751 )
   1753 # Noe that the call order for Variable.mean is
   1754 #    Variable.mean -> NamedArray.mean -> Variable.reduce
   1755 #    -> NamedArray.reduce
-> 1756 result = super().reduce(
   1757     func=func, dim=dim, axis=axis, keepdims=keepdims, **kwargs
   1758 )
   1760 # return Variable always to support IndexVariable
   1761 return Variable(
   1762     result.dims, result._data, attrs=result._attrs if keep_attrs_ else None
   1763 )

File c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\xarray\\namedarray\\core.py:789, in NamedArray.reduce(self, func, dim, axis, keepdims, **kwargs)
    784         dims = tuple(
    785             adim for n, adim in enumerate(self.dims) if n not in removed_axes
    786         )
    788 # Return NamedArray to handle IndexVariable when data is nD
--> 789 return from_array(dims, data, attrs=self._attrs)

File c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\xarray\\namedarray\\core.py:203, in from_array(dims, data, attrs)
    200     return NamedArray(dims, to_0d_object_array(data), attrs)
    202 # validate whether the data is valid data types.
--> 203 return NamedArray(dims, np.asarray(data), attrs)

File d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:468, in JaxArrayWrapper.__array__(self, dtype, context)
    467 def __array__(self, dtype=None, context=None):
--> 468   return np.asarray(self.jax_array, dtype=dtype)

File c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\jax\\_src\\core.py:668, in Tracer.__array__(self, *args, **kw)
    667 def __array__(self, *args, **kw):
--> 668   raise TracerArrayConversionError(self)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape bfloat16[1].
The error occurred while tracing the function apply_fn at c:\\Users\\90566\\miniconda3\\envs\\graphcast\\lib\\site-packages\\haiku\\_src\\transform.py:440 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:f32[1,2,181,360] = sub b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = div b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = sub b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = div b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

  operation a:f32[1,2,181,360] = sub b c
    from line d:\\code\\graphcast-0.1\\graphcast\\xarray_jax.py:353 (wrapped_func)

(Additional originating lines are not shown.)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError"
}
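The last frames point at the cause: while reducing, xarray's `namedarray` code calls `np.asarray` on the wrapped data, and under `jax.jit` that data is a tracer, which cannot be converted to a NumPy array. The minimal reproduction below is independent of GraphCast and only illustrates the failure mode; it is a sketch, not a fix for `xarray_jax`. In practice the options are an xarray release whose reductions do not take this `np.asarray` path, or a graphcast version updated for newer xarray (which exact versions work is an assumption to verify).

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def mean_via_numpy(x):
    # Fails when traced: np.asarray() needs a concrete array, but under jit
    # `x` is an abstract tracer. This mirrors the np.asarray call in the traceback.
    return np.asarray(x).mean()


@jax.jit
def mean_via_jnp(x):
    # Works: jnp operations stay inside the traced computation.
    return jnp.mean(x)


x = jnp.ones((2, 3))
print(mean_via_jnp(x))  # 1.0
try:
    mean_via_numpy(x)
except jax.errors.TracerArrayConversionError as err:
    print("raised", type(err).__name__)
```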

jax version error

Imports section fails:

RuntimeError Traceback (most recent call last)
in <cell line: 12>()
10 import cartopy.crs as ccrs
11 from google.cloud import storage
---> 12 from graphcast import autoregressive
13 from graphcast import casting
14 from graphcast import checkpoint

7 frames
/usr/local/lib/python3.10/dist-packages/jax/_src/lib/init.py in check_jaxlib_version(jax_version, jaxlib_version, minimum_jaxlib_version)
61 msg = (f'jaxlib is version {jaxlib_version}, but this version '
62 f'of jax requires version >= {minimum_jaxlib_version}.')
---> 63 raise RuntimeError(msg)
64
65 if _jaxlib_version > _jax_version:

RuntimeError: jaxlib is version 0.3.25, but this version of jax requires version >= 0.4.14.
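The error means the installed `jaxlib` is older than what the installed `jax` requires, so the two packages have to be upgraded together. A small diagnostic sketch using only the standard library (`importlib.metadata`), so it runs even when importing jax fails; the simplified version parsing is an assumption and ignores pre-release tags:

```python
from importlib import metadata


def version_tuple(version: str) -> tuple:
    """Parse the leading numeric components, e.g. '0.4.14.dev1' -> (0, 4, 14)."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def jaxlib_is_too_old(jaxlib_version: str, minimum_jaxlib_version: str) -> bool:
    """Simplified version of the comparison behind the RuntimeError above."""
    return version_tuple(jaxlib_version) < version_tuple(minimum_jaxlib_version)


def installed_versions() -> dict:
    """Report installed jax/jaxlib versions without importing jax itself."""
    found = {}
    for pkg in ("jax", "jaxlib"):
        try:
            found[pkg] = metadata.version(pkg)
        except metadata.PackageNotFoundError:
            found[pkg] = None
    return found


print(installed_versions())
```

For example, `jaxlib_is_too_old("0.3.25", "0.4.14")` returns `True`, matching the error above. On Colab, reinstall `jax` and `jaxlib` together; for TPU runtimes, follow the `jax[tpu]` instructions in the JAX installation guide.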

Dataset for 10-day forecast

Hi there,
I'm trying to run GraphCast for a 10-day forecast, as was done in the paper. Unless I'm mistaken, the example data provided in the Google Cloud Bucket only covers up to 3 days. Where can I find the dataset that was used for the 10-day run?

Obtaining successive forecasts based on previous predictions

Hi all,

I am wondering how to use the predictions from GraphCast as inputs to produce forecasts for the next time steps. I’m also a beginner in GraphCast and some of the tools and data this code uses.

I’ve run the graphcast_demo.ipynb code with data from ERA5 (using the function download_era5_data from #22), downloading data for one day (let’s say 2022-12-31, with time values ['2022-12-31T00:00:00.000000000', '2022-12-31T06:00:00.000000000', '2022-12-31T12:00:00.000000000', '2022-12-31T18:00:00.000000000']) to use as 'example_batch' instead of the example data. That way I can obtain some predictions. Now I’d like to use those predictions to produce successive forecasts for the following days (up to 2023-1-10, to obtain a 10-day forecast).

The 'predictions' xarray.Dataset is somewhat different from the input data. The 'time' coordinate is now a timedelta variable, and the datetime coordinate is missing from the predictions. Also, the following variables are missing: 'toa_incident_solar_radiation', 'geopotential_at_surface' and 'land_sea_mask'. What is the process to complete and reformat the 'predictions' so they can be used again as input to produce the next forecast values? Or should this be done in a different way?

Thanks in advance!
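The demo's `rollout.chunked_prediction` already does this feedback: the traceback in a later issue on this page shows it merging each chunk's predictions with the forcings and passing the result to `_get_next_inputs`. So the simplest route is usually to build targets and forcings covering all the lead times you want and let the rollout feed predictions back. If you do want to step manually, the sketch below shows the idea. `next_inputs` is a hypothetical helper; it assumes `time` is a timedelta coordinate (`[-6h, 0h]` in the inputs), that you supply the forcings (such as `toa_incident_solar_radiation`) for the new time yourself because the model does not predict them, and that static fields like `land_sea_mask` and `geopotential_at_surface` have no `time` dimension and can simply be carried over.

```python
import numpy as np
import xarray

STEP = np.timedelta64(6, "h")  # GraphCast's 6-hour step


def next_inputs(inputs: xarray.Dataset, predictions: xarray.Dataset,
                forcings: xarray.Dataset) -> xarray.Dataset:
    """Slide a two-frame inputs window forward by one step (hypothetical helper)."""
    # New frame = model outputs + forcings computed for the new time,
    # which the model does not predict.
    new_frame = xarray.merge([predictions, forcings]).isel(time=[-1])
    dynamic = [v for v in inputs.data_vars if "time" in inputs[v].dims]
    static = [v for v in inputs.data_vars if "time" not in inputs[v].dims]
    # `datetime` only exists on the inputs; drop it so concat does not fail,
    # and re-create it afterwards if your pipeline needs it.
    last_frame = inputs[dynamic].isel(time=[-1]).drop_vars("datetime", errors="ignore")
    window = xarray.concat([last_frame, new_frame[dynamic]], dim="time")
    # Relabel so the newest frame sits at lag 0, as in the original inputs.
    window = window.assign_coords(time=[-STEP, np.timedelta64(0, "h")])
    # Carry static fields (e.g. land_sea_mask) over unchanged.
    return xarray.merge([window, inputs[static]])
```

Repeating this with each new prediction, plus freshly computed forcings, gives the manual autoregressive loop.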

Jax Error only when TPU-enabled runtime selected

I'm getting the following error, but only when I have a TPU in my runtime.

It works fine without a TPU, or with a GPU hardware accelerator.

RuntimeError Traceback (most recent call last)
in <cell line: 12>()
10 import cartopy.crs as ccrs
11 from google.cloud import storage
---> 12 from graphcast import autoregressive
13 from graphcast import casting
14 from graphcast import checkpoint

7 frames
/usr/local/lib/python3.10/dist-packages/jax/_src/lib/init.py in check_jaxlib_version(jax_version, jaxlib_version, minimum_jaxlib_version)
61 msg = (f'jaxlib is version {jaxlib_version}, but this version '
62 f'of jax requires version >= {minimum_jaxlib_version}.')
---> 63 raise RuntimeError(msg)
64
65 if _jaxlib_version > _jax_version:

RuntimeError: jaxlib is version 0.3.25, but this version of jax requires version >= 0.4.19.

ERA5-HRES Model Humidity Data Displaying as Zero — Seeking Insights and Clarification on Model Naming

Issue Summary:
While analyzing meteorological data with the GraphCast_operational - ERA5-HRES 1979-2021 model, I encountered an issue: the humidity data displays as zero. The model has a resolution of 0.25 degrees, covers 13 pressure levels, and uses a 2to6 mesh, with its output focused on major surface variables such as temperature, humidity, wind speed, wind direction, and mean sea level pressure.

Detailed Description:
Global historical meteorological data was obtained through the CDS API, and running inference with the model generated a 6.5GB GRIB file. When examining the data with NASA's Panoply software, all variables except humidity displayed normally. Puzzlingly, the humidity values were uniformly zero, which is significantly different from what was expected.

As a non-professional individual deeply interested in the analysis of meteorological data, I am concerned that there might be a misunderstanding in my interpretation of the data or an unknown issue encountered during the data processing.

Therefore, I urgently require assistance from professionals to explain why the humidity data is abnormally displaying as zero, and whether this suggests an error in the model's output or the data conversion process.

Additionally, I am curious about the model naming "mesh2 to 6" and "precipitation output only.npz". Does this imply that the model is primarily used for predicting precipitation? I wish to learn more about the background of the model's naming and design focus.

Attachments:

Screenshot of humidity data observed in Panoply software
Download link for Panoply Software: https://www.giss.nasa.gov/tools/panoply/download/
Download link for the grib(6.5GB) data file :
https://drive.google.com/file/d/1yVgTWT1DJNewRepib4qkOLbBz0HGIARY/view?usp=drive_link
I would greatly appreciate it if professionals or those knowledgeable in this area could provide answers and assistance. Thank you!

How to calculate toa_incident_solar_radiation?

I looked at the sample code. To predict future weather, you also need to compute the variable toa_incident_solar_radiation. How can this variable be calculated? I tried pysolar, but I can’t get values that even approximately match.
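One way to get a value in the right ballpark is to integrate the top-of-atmosphere irradiance yourself. The sketch below is a simplified astronomical model, not the method used to produce the ERA5 field, so it will not match ERA5 exactly. Its assumptions: a solar constant of 1361 W m^-2, Cooper's approximation for the solar declination, no equation-of-time correction, and (as for hourly ERA5) a 1-hour accumulation in J m^-2 ending at the valid time; check that convention against your own data.

```python
import numpy as np

SOLAR_CONSTANT = 1361.0  # W m^-2; assumed total solar irradiance at 1 AU


def toa_irradiance(lat_deg, lon_deg, day_of_year, utc_hours):
    """Instantaneous TOA irradiance (W m^-2) on a horizontal surface."""
    lat = np.deg2rad(lat_deg)
    # Solar declination, Cooper's approximation (radians).
    decl = np.deg2rad(23.45) * np.sin(2.0 * np.pi * (284 + day_of_year) / 365.0)
    # Local solar time from UTC and longitude (15 degrees per hour);
    # the equation-of-time correction is ignored.
    solar_time = utc_hours + np.asarray(lon_deg) / 15.0
    hour_angle = np.deg2rad(15.0 * (solar_time - 12.0))
    cos_zenith = (np.sin(lat) * np.sin(decl)
                  + np.cos(lat) * np.cos(decl) * np.cos(hour_angle))
    # Earth-Sun distance correction for orbital eccentricity.
    distance_factor = 1.0 + 0.033 * np.cos(2.0 * np.pi * day_of_year / 365.0)
    # No sunlight when the Sun is below the horizon.
    return SOLAR_CONSTANT * distance_factor * np.maximum(cos_zenith, 0.0)


def toa_incident_solar_radiation(lat_deg, lon_deg, day_of_year, utc_hour_end,
                                 n_substeps=60):
    """Radiation (J m^-2) accumulated over the hour ending at utc_hour_end.

    Midpoint-rule integration; day_of_year is held fixed within the window.
    """
    fractions = (np.arange(n_substeps) + 0.5) / n_substeps
    times = utc_hour_end - 1.0 + fractions
    values = [toa_irradiance(lat_deg, lon_deg, day_of_year, t) for t in times]
    return np.mean(values, axis=0) * 3600.0  # mean W m^-2 times 3600 s
```

`lat_deg` and `lon_deg` may be scalars or same-shaped arrays, e.g. from `np.meshgrid(lats, lons, indexing="ij")`, so one call per valid time (such as `utc_hour_end=6.0` for 06 UTC) fills a whole grid.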

Regarding training

Hi,

I am trying to train GraphCast on my own data. Since the repo has no main training loop, I am following the e.ipynb file to create one. The demo file only computes the loss for one iteration over a small number of forecasting steps.

How do you train it on a large number of steps, such as two months of data, where the steps presumably need to be batched? I cannot find a function for batching long trajectories with 'data_utils.extract_inputs_targets_forcings' to backpropagate the gradient.

Regards,
Yogesh
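The repo does not ship a training loop, so the following is only a sketch of one common approach, not the authors' pipeline: rather than backpropagating through two months at once, cut the long time series into many short overlapping windows (2 input frames plus a few target frames) and sample them in random mini-batches. Each `(start, stop)` pair would select `dataset.isel(time=slice(start, stop))`, which could then be split into inputs, targets and forcings as the demo does with `data_utils.extract_inputs_targets_forcings`. `sample_windows` and its default sizes are hypothetical.

```python
import numpy as np


def sample_windows(num_times, num_input_steps=2, num_target_steps=4,
                   batch_size=8, seed=0):
    """Yield mini-batches of (start, stop) index pairs over a long time axis.

    Each window spans num_input_steps + num_target_steps consecutive frames.
    One call is one pass over all possible windows in random order; the last
    incomplete batch is dropped.
    """
    window = num_input_steps + num_target_steps
    if num_times < window:
        raise ValueError("time axis is shorter than one window")
    rng = np.random.default_rng(seed)
    starts = rng.permutation(num_times - window + 1)
    for i in range(0, len(starts) - batch_size + 1, batch_size):
        yield [(int(s), int(s) + window) for s in starts[i:i + batch_size]]
```

With 240 frames (60 days of 6-hourly data), the defaults give 235 possible windows and 29 full batches of 8 per pass. The loss and gradients would be computed per window and averaged over the batch before each optimizer update.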

GPU / TPU memory requirements for training

The paper states that you used TPU v4 chips, which I believe have 32 GB of memory accessible per TPU core. When trying to train the high-resolution version (currently on NVIDIA GPUs), I seem to use > 48 GB of VRAM for a batch size of 1 per GPU.

I'm using the code from the example notebook with minimal modification.

Is there anything else that needs to be done to get memory usage below 32 GB for the high-resolution model? If I understand it correctly, gradient checkpointing and bfloat16 are already set up in the example notebook in this function:

def construct_wrapped_graphcast(model_config: graphcast.ModelConfig, task_config: graphcast.TaskConfig):
        # Deeper one-step predictor.
        predictor = graphcast.GraphCast(model_config, task_config)

        # Modify inputs/outputs to `graphcast.GraphCast` to handle conversion to
        # from/to float32 to/from BFloat16.
        predictor = casting.Bfloat16Cast(predictor)

        # Modify inputs/outputs to `casting.Bfloat16Cast` so the casting to/from
        # BFloat16 happens after applying normalization to the inputs/targets.
        predictor = normalization.InputsAndResiduals(
            predictor,
            diffs_stddev_by_level=diffs_stddev_by_level,
            mean_by_level=mean_by_level,
            stddev_by_level=stddev_by_level)

        # Wraps everything so the one-step model can produce trajectories.
        predictor = autoregressive.Predictor(predictor, gradient_checkpointing=True)
        return predictor

Do the model inputs (xarray datasets) need to also be cast to a different precision instead of float32?

I'm unsure how to get the memory usage down further so that it could run on a TPU v4 or something like a 40 GB A100.
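A back-of-the-envelope estimate suggests the inputs themselves are not the problem. The sketch counts only the 227 predicted channels (6 atmospheric variables on 37 levels plus 5 surface variables) on the 721 × 1440 grid of the 0.25-degree model, ignoring forcings and static fields, so it is a lower bound. Note also that `casting.Bfloat16Cast` already casts inputs, targets and forcings to bfloat16 (see `_all_inputs_to_bfloat16` in a traceback elsewhere on this page), so casting the xarray inputs beforehand should save little.

```python
# Grid and variable counts for the 0.25-degree, 37-level GraphCast configuration.
NUM_LAT, NUM_LON = 721, 1440   # 0.25-degree global grid, both poles included
NUM_LEVELS = 37
NUM_ATMOS_VARS = 6             # temperature, u, v, geopotential, humidity, vertical velocity
NUM_SURFACE_VARS = 5           # 2m temperature, 10m u/v wind, MSLP, precipitation


def frame_bytes(bytes_per_value: int) -> int:
    """Bytes for one time frame of the predicted variables on the full grid."""
    channels = NUM_ATMOS_VARS * NUM_LEVELS + NUM_SURFACE_VARS  # 227
    return NUM_LAT * NUM_LON * channels * bytes_per_value


for name, nbytes in [("float32", 4), ("bfloat16", 2)]:
    two_frames = 2 * frame_bytes(nbytes)  # GraphCast consumes 2 input frames
    print(f"{name}: {two_frames / 1e9:.2f} GB for two input frames")
# float32: 1.89 GB for two input frames
# bfloat16: 0.94 GB for two input frames
```

About 2 GB is small next to 32 GB, which suggests most of the memory goes to intermediate activations of the GNN during training; that is an inference, not a measurement.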

colab install error

When installing packages, the google-colab package installation throws the following error:

dist.fetch_build_eggs(dist.setup_requires)

error in pandas setup command: 'install_requires' must be a string or list of strings containing valid project/version requirement specifiers; Expected end or semicolon (after version specifier)
pytz >= 2011k

Create a grib file of the predictions

Hi everybody,

I am trying to create a GRIB file with a subset of the predictions (10-meter wind and sea-level pressure) but have not succeeded.
Can you help me creating a grib file?

Thanks,

Henri

Can the pre-trained model work with CPU devices?

I want to run GraphCast on a CPU platform with 40 GB of memory, but the program raises the following error:
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory allocating 124867611608 bytes.

I was wondering whether the pre-trained model can work with CPU devices.
I used the graphcast_small model.

Data downloaded from ECMWF is different from the sample data

I have downloaded ERA5 data following the instructions from ECMWF. To check the accuracy of the data, I downloaded data for 2022-01-01, the same date as the example data provided in the Google Cloud Bucket, but the two were slightly different. May I ask which API should be used to download these data, or could you publish a script or instructions for constructing the input data?

Thanks a lot !

Regarding Normalization data

Hi,

I have a question about how these normalization variables in the demo.ipynb file are computed:

diffs_stddev_by_level = xarray.load_dataset(f).compute()
mean_by_level = xarray.load_dataset(f).compute()
stddev_by_level = xarray.load_dataset(f).compute()

I assume that mean_by_level and stddev_by_level are the global, time-averaged mean and standard deviation over the grid. Could you explain how you calculated diffs_stddev_by_level? I couldn't find it in the paper either. If the code that computes it is in the repo, could you point me to it?

Regards,
Yogesh
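The repo does not document how these files were produced, so the following is only a guess based on the variable names, not the authors' code: `mean_by_level` and `stddev_by_level` as statistics over time, latitude and longitude kept separately per pressure level, and `diffs_stddev_by_level` as the standard deviation of the change between consecutive time steps (consistent with the model predicting residuals, as the `InputsAndResiduals` wrapper in the demo suggests). The hypothetical `normalization_stats` applies no area weighting, which the real statistics may well use.

```python
import numpy as np
import xarray


def normalization_stats(ds: xarray.Dataset,
                        reduce_dims=("batch", "time", "lat", "lon")):
    """Return (mean_by_level, stddev_by_level, diffs_stddev_by_level).

    Variables with a `level` dimension keep it; every dimension in
    `reduce_dims` that exists is averaged away. No area weighting (assumption).
    """
    # Only time-varying fields get statistics; static fields are skipped.
    ds = ds[[v for v in ds.data_vars if "time" in ds[v].dims]]
    dims = [d for d in reduce_dims if d in ds.sizes]
    mean_by_level = ds.mean(dims)
    stddev_by_level = ds.std(dims)
    # Standard deviation of x(t + 1) - x(t), i.e. of the 6-hour change
    # when the data are 6-hourly.
    diffs_stddev_by_level = ds.diff("time").std(dims)
    return mean_by_level, stddev_by_level, diffs_stddev_by_level
```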

Inference on multiple TPU cores / GPUs

I have the low-resolution model running inference locally on a GPU (RTX 4090), and can also run the high-resolution model (37 pressure levels) for a couple of timesteps before running out of memory.
Does anyone have any advice on parallelising across multiple GPUs or using a TPU v3-8 instance in GCP and utilising all TPU cores?
I see there is the xarray_jax.pmap function which I assume can be used for this, but I'm not sure how to use it properly.
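`jax.pmap` (which `xarray_jax.pmap` presumably wraps for xarray inputs; its exact signature is not shown here) maps a function over the leading axis of its inputs, one slice per local device. Every device runs the full function on its own slice, so each needs a complete copy of the model parameters plus the activations for that slice. This parallelizes independent forecasts (for example several initialization dates) but does not by itself shrink the memory needed for one high-resolution forecast. A minimal sketch, where `forecast_step` is a hypothetical stand-in for the jitted GraphCast forward pass:

```python
import jax
import jax.numpy as jnp


def forecast_step(state):
    # Hypothetical stand-in for one forward pass on a single device's slice.
    return jnp.tanh(state) * 2.0


n_devices = jax.local_device_count()
# Leading axis = one entry per device (e.g. one initialization date each).
batch = jnp.arange(n_devices * 3.0).reshape(n_devices, 3)
outputs = jax.pmap(forecast_step)(batch)
print(n_devices, outputs.shape)  # outputs keeps the (n_devices, 3) layout
```

On a TPU v3-8 host, `jax.local_device_count()` should report 8 cores. Splitting a single forecast across devices would instead need model or spatial sharding, which this sketch does not attempt.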

Error while forecasting using processed GFS data in Autoregressive rollout code

I'm attempting to use processed GFS (Global Forecast System) data for forecasting. The processed data's coordinates, dimensions, and variables are nearly identical to the example data provided. However, when running the Autoregressive rollout code snippet, specifically this segment:

# Autoregressive rollout code snippet
assert model_config.resolution in (0, 360. / eval_inputs.sizes["lon"]), (
  "Model resolution doesn't match the data resolution. You likely want to "
  "re-filter the dataset list, and download the correct data.")

print("Inputs:  ", eval_inputs.dims.mapping)
print("Targets: ", eval_targets.dims.mapping)
print("Forcings:", eval_forcings.dims.mapping)

predictions = rollout.chunked_prediction(
    run_forward_jitted,
    rng=jax.random.PRNGKey(0),
    inputs=eval_inputs,
    targets_template=eval_targets * np.nan,
    forcings=eval_forcings)
predictions

I encounter the following error:

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/rollout.py:69, in chunked_prediction(predictor_fn, rng, inputs, targets_template, forcings, num_steps_per_chunk, verbose)
     67 #   print(inputs)
     68   chunks_list = []
---> 69   for prediction_chunk in chunked_prediction_generator(
     70       predictor_fn=predictor_fn,
     71       rng=rng,
     72       inputs=inputs,
     73       targets_template=targets_template,
     74       forcings=forcings,
     75       num_steps_per_chunk=num_steps_per_chunk,
     76       verbose=verbose):
     77     chunks_list.append(jax.device_get(prediction_chunk))
     78   return xarray.concat(chunks_list, dim="time")

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/rollout.py:165, in chunked_prediction_generator(predictor_fn, rng, inputs, targets_template, forcings, num_steps_per_chunk, verbose)
    163 # Make predictions for the chunk.
    164 rng, this_rng = jax.random.split(rng)
--> 165 predictions = predictor_fn(
    166     rng=this_rng,
    167     inputs=current_inputs,
    168     targets_template=current_targets_template,
    169     forcings=current_forcings)
    171 next_frame = xarray.merge([predictions, current_forcings])
    173 current_inputs = _get_next_inputs(current_inputs, next_frame)

/mnt/d/kt/project/yl/google-graphcast/code/test_graphcast.ipynb Cell 15 line 6
     64 def drop_state(fn):
---> 65   return lambda **kw: fn(**kw)[0]

    [... skipping hidden 12 frame]

File ~/soft/anaconda3/lib/python3.10/site-packages/haiku/_src/transform.py:456, in transform_with_state.<locals>.apply_fn(params, state, rng, *args, **kwargs)
    454 with base.new_context(params=params, state=state, rng=rng) as ctx:
    455   try:
--> 456     out = f(*args, **kwargs)
    457   except jax.errors.UnexpectedTracerError as e:
    458     raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e

/mnt/d/kt/project/yl/google-graphcast/code/test_graphcast.ipynb Cell 15 line 3
     27 @hk.transform_with_state
     28 def run_forward(model_config, task_config, inputs, targets_template, forcings):
     29   predictor = construct_wrapped_graphcast(model_config, task_config)
---> 30   return predictor(inputs, targets_template=targets_template, forcings=forcings)

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/autoregressive.py:212, in Predictor.__call__(self, inputs, targets_template, forcings, **kwargs)
    209     one_step_prediction = hk.remat(one_step_prediction)
    211 # Loop (without unroll) with hk states in cell (jax.lax.scan won't do).
--> 212 _, flat_preds = hk.scan(one_step_prediction, inputs, scan_variables)
    214 # The result of scan will have an extra leading axis on all arrays,
    215 # corresponding to the target times in this case. We need to be prepared for
    216 # it when unflattening the arrays back into a Dataset:
    217 scan_result_template = (
    218     target_template.squeeze('time', drop=True)
    219     .expand_dims(time=targets_template.coords['time'], axis=0))

File ~/soft/anaconda3/lib/python3.10/site-packages/haiku/_src/stateful.py:643, in scan(f, init, xs, length, reverse, unroll)
    637 # We know that we don't need to thread params in and out, since for init we
    638 # have already created them (given that above we unroll one step of the scan)
    639 # and for apply we know they are immutable. As such we only need to thread the
    640 # state and rng in and out.
    642 init = (init, internal_state(params=False))
--> 643 (carry, state), ys = jax.lax.scan(
    644     stateful_fun, init, xs, length, reverse, unroll=unroll)
    645 update_internal_state(state)
    647 if running_init_fn:

    [... skipping hidden 9 frame]

File ~/soft/anaconda3/lib/python3.10/site-packages/haiku/_src/stateful.py:626, in scan.<locals>.stateful_fun(carry, x)
    623 with temporary_internal_state(state):
    624   with base.assert_no_new_parameters(), \
    625        base.push_jax_trace_level():
--> 626     carry, out = f(carry, x)
    627   reserve_up_to_full_rng_block()
    628   carry = (carry, internal_state(params=False))

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/autoregressive.py:183, in Predictor.__call__.<locals>.one_step_prediction(inputs, scan_variables)
    181 # Add constant inputs:
    182 all_inputs = xarray.merge([constant_inputs, inputs])
--> 183 predictions: xarray.Dataset = self._predictor(
    184     all_inputs, target_template,
    185     forcings=forcings,
    186     **kwargs)
    188 next_frame = xarray.merge([predictions, forcings])
    189 next_inputs = self._update_inputs(inputs, next_frame)

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/normalization.py:156, in InputsAndResiduals.__call__(self, inputs, targets_template, forcings, **kwargs)
    154 norm_inputs = normalize(inputs, self._scales, self._locations)
    155 norm_forcings = normalize(forcings, self._scales, self._locations)
--> 156 norm_predictions = self._predictor(
    157     norm_inputs, targets_template, forcings=norm_forcings, **kwargs)
    158 return xarray_tree.map_structure(
    159     lambda pred: self._unnormalize_prediction_and_add_input(inputs, pred),
    160     norm_predictions)

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/casting.py:56, in Bfloat16Cast.__call__(self, inputs, targets_template, forcings, **kwargs)
     52   return self._predictor(inputs, targets_template, forcings, **kwargs)
     54 with bfloat16_variable_view():
     55   predictions = self._predictor(
---> 56       *_all_inputs_to_bfloat16(inputs, targets_template, forcings),
     57       **kwargs,)
     59 predictions_dtype = infer_floating_dtype(predictions)
     60 if predictions_dtype != jnp.bfloat16:

File ~/soft/anaconda3/lib/python3.10/site-packages/graphcast/casting.py:179, in _all_inputs_to_bfloat16(inputs, targets, forcings)
    164 def _all_inputs_to_bfloat16(
    165     inputs: xarray.Dataset,
    166     targets: xarray.Dataset,
   (...)
    177 #   data_vars = {key: value.values if hasattr(value, 'values') else value for key, value in inputs.items()}
    178 #   dataset = xr.Dataset(data_vars)
--> 179     return (inputs.astype(jnp.bfloat16),
    180             jax.tree_map(lambda x: x.astype(jnp.bfloat16), targets),
    181             forcings.astype(jnp.bfloat16))
AttributeError: 'dict' object has no attribute 'astype'

However, running the code with the example data works correctly; the error above only occurs when forecasting with my processed GFS data. I would appreciate any assistance in resolving this issue.
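The `AttributeError` says `inputs` reached `_all_inputs_to_bfloat16` as a plain `dict` instead of an `xarray.Dataset`. A structural mismatch between the GFS data and the example data is one plausible reason (an assumption, not a confirmed diagnosis), and the commented-out lines in the traceback suggest `casting.py` was edited locally, which is worth reverting while debugging. Since the data are described as only *nearly* identical, a hypothetical helper like the one below (not part of graphcast) can list the differences: missing or extra variables, mismatched dimension orders and dtypes, coordinates present in only one Dataset (for example leftover GRIB coordinates such as `step` or `valid_time`, if the data were read with cfgrib), and coordinates whose values differ, such as a latitude axis stored in the opposite order. Differences in date-valued coordinates like `datetime` are expected.

```python
import numpy as np
import xarray


def structural_differences(reference: xarray.Dataset,
                           candidate: xarray.Dataset) -> list:
    """List human-readable structural differences between two Datasets."""
    problems = []
    ref_vars, cand_vars = set(reference.data_vars), set(candidate.data_vars)
    for name in sorted(ref_vars - cand_vars):
        problems.append(f"missing variable: {name}")
    for name in sorted(cand_vars - ref_vars):
        problems.append(f"extra variable: {name}")
    for name in sorted(ref_vars & cand_vars):
        r, c = reference[name], candidate[name]
        if r.dims != c.dims:
            problems.append(f"{name}: dims {c.dims} != {r.dims}")
        if r.dtype != c.dtype:
            problems.append(f"{name}: dtype {c.dtype} != {r.dtype}")
    for coord in sorted(set(reference.coords) | set(candidate.coords)):
        if coord not in reference.coords or coord not in candidate.coords:
            problems.append(f"coordinate {coord!r} present in only one dataset")
            continue
        r, c = reference[coord].values, candidate[coord].values
        if r.shape != c.shape or not np.array_equal(r, c):
            problems.append(f"coordinate {coord!r} values differ")
    return problems
```

For example, `for p in structural_differences(example_inputs, gfs_inputs): print(p)`, where both names are hypothetical placeholders for the demo's inputs and the processed GFS inputs.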

Different loss values for seemingly same forecast

I performed "1 Eval Step" forecast with Graphcast small using dataset_source-era5_date-2022-01-01_res-1.0_levels-13_steps-01.nc, steps-12.nc, and steps-40.nc.

The loss is different for each case even though we are only performing a 6 hr forecast in each case. Why might this be? As I understand it, the prediction should be the same in all these cases and the target should also be the same (ERA5 Reanalysis 6hr into the future).

Loss Values:
0.9296875 for 6hr forecast 1-step data 01 Jan 2022
0.69140625 for 6hr forecast 12-step data 01 Jan 2022
0.66015625 for 6hr forecast 40-step data 01 Jan 2022

Accessing ECMWF Data

For some reason, right when your team started collaborating with ECMWF, they changed their licensing structure on July 1st, 2023. As a result, my order placed in April was canceled without any notification. How can we access the data and carry out research? Please see this issue:
ecmwf/ecmwf-opendata#30
