

UniST

A PyTorch implementation of the paper: UniST: A Prompt-Empowered Universal Model for Urban Spatio-Temporal Prediction.

Yuan Yuan, Jingtao Ding, Jie Feng, Depeng Jin, Yong Li

FIBLAB@Tsinghua University


The repo currently includes code implementations for the following tasks:

Short-term Prediction: We provide all scripts for the reproduction of short-term prediction results in this repo.

Long-term Prediction: We provide all scripts for the reproduction of long-term prediction results in this repo.

Few-shot Prediction: UniST generalizes well to scenarios with limited training data, which makes it data-efficient.

Zero-shot Prediction: UniST generalizes well to unseen spatio-temporal scenarios, which makes it a strong candidate for the backbone of a spatio-temporal foundation model.

🎉 Updates

📢 News (2024.06): Introductions of our work by 量子位 (QbitAI) and 时空探索之旅 ("Journey of Spatio-Temporal Exploration") are available.

📢 News (2024.05): UniST has been accepted to KDD 2024.

Introduction

πŸ† By capturing the underlying commonalities across multiple spatio-temporal scenarios, UniST breaks the conventional practice that train separate models for different datasets, and has demonstrated superior performance and powerful generalization capability across diverse urban scenarios. UniST

Overall Architecture

🌟 The training of UniST consists of two stages: (i) large-scale spatio-temporal pre-training, and (ii) spatio-temporal knowledge-guided prompt tuning.

The pseudo-code of UniST is as simple as the following:

[Figure: UniST pseudo-code]
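For intuition about stage (ii), here is a minimal, hypothetical sketch of reading a prompt vector from a learnable key-value memory pool, the kind of structure whose size the --num_memory_spatial and --num_memory_temporal options control (they set the number of embeddings in the memory pools). The function names, the vector size, and the scaled dot-product attention readout are illustrative assumptions written in plain Python; this is not UniST's actual prompt network, which lives in the PyTorch code of this repo.

```python
import math
import random
from typing import List

Vector = List[float]


def softmax(scores: Vector) -> Vector:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def memory_prompt(query: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Hypothetical prompt readout: attend over memory keys, mix the values.

    Assumption: the query is derived from the input's spatio-temporal features,
    and keys/values are learnable embeddings, one pair per memory slot.
    """
    scale = 1.0 / math.sqrt(len(query))
    weights = softmax([dot(query, k) * scale for k in keys])
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]


if __name__ == "__main__":
    random.seed(0)
    num_memory, dim = 512, 8  # 512 mirrors the example value of --num_memory_spatial
    keys = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(num_memory)]
    values = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(num_memory)]
    query = [random.gauss(0, 1) for _ in range(dim)]
    prompt = memory_prompt(query, keys, values)
    print(len(prompt))  # 8
```

Because the attention weights sum to one, the prompt is always a convex combination of the memory values; a larger pool simply offers more patterns to mix from.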

⚖ Foundation models for spatio-temporal prediction

Model             Data Format   Data Scalability   Few-shot   Zero-shot   Computation Cost   Memory Cost
PromptST [1]      Grid          ✗                  ✗          ✗           Low                Low
GPT-ST [2]        Graph         ✗                  ✗          ✗           Low                Low
STEP [3]          Graph         ✗                  ✗          ✗           Low                Low
ST-SSL [4]        Graph         ✗                  ✗          ✗           Low                Low
TrafficBERT [5]   Grid/Graph    ✓                  ✗          ✗           Low                Low
TFM [6]           Graph         ✗                  ✗          ✗           Low                Low
UrbanGPT [7]      Grid          ✓(a)               ✓(a)       ✓(a)        High               High
STG-LLM [8]       Graph         ✗                  ✗          ✗           High               High
UniST             Grid/Graph    ✓                  ✓          ✓           Low                Low

(a) Still restricted to the same city.

[1] PromptST: Prompt-Enhanced Spatio-Temporal Multi-Attribute Prediction, CIKM 2023

[2] GPT-ST: Generative Pre-Training of Spatio-Temporal Graph Neural Networks, NeurIPS 2023

[3] Pre-training enhanced spatial-temporal graph neural network for multivariate time series forecasting, KDD 2022

[4] Spatio-Temporal Self-Supervised Learning for Traffic Flow Prediction, AAAI 2023

[5] TrafficBERT: Pre-trained model with large-scale data for long-range traffic flow forecasting, Expert Systems with Applications

[6] Building transportation foundation model via generative graph transformer, ITSC 2023

[7] UrbanGPT: Spatio-Temporal Large Language Models, KDD 2024

[8] How can large language models understand spatial-temporal data?, arXiv 2024

Data

We evaluate UniST on multiple datasets spanning various cities and domains. To access the datasets, please refer to the data README.

βš™οΈ Installation

Environment

  • Tested OS: Linux
  • Python >= 3.9
  • torch == 2.0.0
  • TensorBoard

Dependencies:

  1. Install PyTorch built for the CUDA version on your machine.
  2. Run pip install -r requirements.txt to install the Python modules and packages used in this project.
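To confirm that an environment matches the tested setup listed above, a small stdlib-only check like the one below may help. It is a hypothetical helper, not part of the repo: it reads installed package versions through importlib.metadata, compares them with the Python >= 3.9 and torch == 2.0.0 requirements, and ignores local build tags such as +cu118.

```python
import sys
from importlib import metadata
from typing import List, Optional, Tuple


def parse_release(version: str) -> Tuple[int, ...]:
    """Turn '2.0.0+cu118' into (2, 0, 0); build tags and pre-release suffixes are dropped."""
    release = version.split("+")[0]
    parts = []
    for piece in release.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def check_environment() -> List[str]:
    """Collect mismatches against the tested setup (empty list means all good)."""
    problems = []
    if sys.version_info < (3, 9):
        problems.append(f"Python >= 3.9 required, found {sys.version.split()[0]}")
    torch_version = installed_version("torch")
    if torch_version is None:
        problems.append("torch is not installed")
    elif parse_release(torch_version)[:3] != (2, 0, 0):
        problems.append(f"tested with torch == 2.0.0, found {torch_version}")
    if installed_version("tensorboard") is None:
        problems.append("tensorboard is not installed")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("\n".join(issues) if issues else "Environment matches the tested setup.")
```

Run it with the same interpreter you will use for main.py. Once torch is installed, torch.cuda.is_available() tells you whether the installed build can see a GPU.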

πŸƒ Model Training

Please first navigate to the src directory by using the cd command: cd src

Then please create a folder named experiments to record the training process: mkdir experiments

Stage-1: Pre-training

We provide the pre-training scripts in ./scripts/pretrain.sh. For example, you can pre-train UniST on the Crowd dataset as follows:

python main.py --device_id 3 --machine machine --dataset Crowd --task short --size middle --mask_strategy_random 'batch' --lr 3e-4 --used_data 'single' --prompt_ST 0

Once your model is trained, you will find the logs of the training process in the ./logs/ directory, in a folder named Pretrain_Dataset_<dataset>_task_<task>. The trained model, model_best.pkl, is saved in ./experiments/Pretrain_Dataset_<dataset>_task_<task>/model_save/.
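The naming pattern above can be turned into a small path helper, for example to locate the checkpoint you will later pass to --file_load_path. This is a hypothetical sketch: the folder names are copied from this README, paths are relative to src/, and placing the log folder directly under ./logs/ is an assumption.

```python
from pathlib import Path
from typing import Dict


def pretrain_paths(dataset: str, task: str, root: str = ".") -> Dict[str, Path]:
    """Expected stage-1 outputs (hypothetical helper; naming copied from the README).

    Assumption: the log folder sits directly under ./logs/.
    """
    run = f"Pretrain_Dataset_{dataset}_task_{task}"
    base = Path(root)
    return {
        "logs": base / "logs" / run,
        "checkpoint": base / "experiments" / run / "model_save" / "model_best.pkl",
    }


if __name__ == "__main__":
    ckpt = pretrain_paths("Crowd", "short")["checkpoint"]
    print(ckpt.as_posix())  # experiments/Pretrain_Dataset_Crowd_task_short/model_save/model_best.pkl
    print("found" if ckpt.is_file() else "not trained yet")
```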

In our experiments, we pre-train UniST on multiple datasets. To use multiple datasets, separate their names with an asterisk (*), e.g., --dataset 'Crowd*Cellular*TaxiNYC*TaxiBike*TrafficSH'. Quoting the value keeps the shell from treating * as a wildcard.
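The asterisk-separated format is easy to build and check programmatically. The sketch below is hypothetical (split_datasets is not a function from the repo, and main.py may parse the value differently); it also shows that launching main.py with an argument list avoids shell globbing, while shlex.join gives the correctly quoted command line for a shell.

```python
import shlex
from typing import List


def split_datasets(spec: str) -> List[str]:
    """Split an asterisk-separated --dataset value into dataset names (hypothetical helper)."""
    names = [name.strip() for name in spec.split("*")]
    if not all(names):
        raise ValueError(f"empty dataset name in {spec!r}")
    return names


if __name__ == "__main__":
    datasets = ["Crowd", "Cellular", "TaxiNYC", "TaxiBike", "TrafficSH"]
    spec = "*".join(datasets)
    assert split_datasets(spec) == datasets
    # An argv list (e.g. for subprocess.run) never passes through a shell, so no globbing;
    # shlex.join shows how the same command has to be quoted when typed into a shell.
    argv = ["python", "main.py", "--dataset", spec, "--task", "short", "--size", "middle"]
    print(shlex.join(argv))
    # prints: python main.py --dataset 'Crowd*Cellular*TaxiNYC*TaxiBike*TrafficSH' --task short --size middle
```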

Stage-2: Prompt-tuning

We provide the prompt-tuning scripts in ./scripts/prompt_tuning.sh. For example, you can prompt-tune UniST on the Crowd dataset as follows:

python main.py --device_id 2 --machine machine --task short --size middle --prompt_ST 1 --pred_len 6 --his_len 6 --num_memory_spatial 512 --num_memory_temporal 512 --prompt_content 's_p_c' --dataset Crowd --lr 3e-4 --used_data 'single' --file_load_path pretrained_model_path

This stage introduces the following new parameters:

  • his_len specifies the input sequence length.
  • pred_len specifies the prediction horizon.
  • file_load_path specifies the path of the pre-trained model; the default is ./experiments/Dataset_<dataset>_task_<task>/model_save/model_best.pkl.
  • num_memory_spatial and num_memory_temporal specify the number of embeddings in the spatial and temporal memory pools, respectively.
  • prompt_ST specifies whether to perform prompt-tuning: 0 for no prompt, 1 for prompt-tuning.
  • prompt_content specifies the type of prompt, which can be selected from ['s_p_c','s','c','p','s_c','s_p','p_c'].
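The parameters above map directly onto command-line options. The sketch below re-declares them with Python's argparse to make their types and allowed values explicit. It is a hypothetical subset, not the parser in main.py, and its defaults are copied from the stage-2 example command rather than from the repo.

```python
import argparse

# Allowed prompt types, as listed in the README.
PROMPT_CONTENT_CHOICES = ["s_p_c", "s", "c", "p", "s_c", "s_p", "p_c"]


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical re-declaration of the prompt-tuning options (not main.py's parser)."""
    parser = argparse.ArgumentParser(description="UniST prompt-tuning options (sketch)")
    parser.add_argument("--his_len", type=int, default=6, help="input sequence length")
    parser.add_argument("--pred_len", type=int, default=6, help="prediction horizon")
    parser.add_argument("--file_load_path", type=str, default=None,
                        help="path of the pre-trained model (model_best.pkl)")
    parser.add_argument("--num_memory_spatial", type=int, default=512,
                        help="number of embeddings in the spatial memory pool")
    parser.add_argument("--num_memory_temporal", type=int, default=512,
                        help="number of embeddings in the temporal memory pool")
    parser.add_argument("--prompt_ST", type=int, choices=[0, 1], default=1,
                        help="0: no prompt, 1: prompt-tuning")
    parser.add_argument("--prompt_content", choices=PROMPT_CONTENT_CHOICES, default="s_p_c",
                        help="type of prompt")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--pred_len", "12", "--prompt_content", "s_c"])
    print(args.his_len, args.pred_len, args.prompt_content)  # 6 12 s_c
```

With choices declared, a value outside the list (for example --prompt_content p_s) makes argparse exit with an error that lists the allowed options, which catches typos before training starts.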

Once your model is trained, you will find the logs of the training process in the ./logs/ directory, in a folder named Prompt_Dataset_<dataset>_His_<his_len>_Pred_<pred_len>. The fine-tuned model, model_best.pkl, is saved in ./experiments/Prompt_Dataset_<dataset>_His_<his_len>_Pred_<pred_len>/model_save/.

The evaluation results on the test set can be found in ./experiments/Prompt_Mode_finetuning_Dataset_<dataset>_His_<his_len>_Pred_<pred_len>/result.txt.
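A small hypothetical helper can collect the prompt-tuning outputs for a given run. The folder names are copied from the two paragraphs above, which use different prefixes (Prompt_Dataset_… for the checkpoint and Prompt_Mode_finetuning_Dataset_… for result.txt), so compare against what actually appears in your experiments directory. Paths are relative to src/, and placing the log folder directly under ./logs/ is an assumption.

```python
from pathlib import Path
from typing import Dict, Optional


def prompt_paths(dataset: str, his_len: int, pred_len: int, root: str = ".") -> Dict[str, Path]:
    """Expected stage-2 outputs (hypothetical helper; naming copied from the README)."""
    base = Path(root)
    run = f"Prompt_Dataset_{dataset}_His_{his_len}_Pred_{pred_len}"
    result_run = f"Prompt_Mode_finetuning_Dataset_{dataset}_His_{his_len}_Pred_{pred_len}"
    return {
        "logs": base / "logs" / run,  # assumption: log folder directly under ./logs/
        "checkpoint": base / "experiments" / run / "model_save" / "model_best.pkl",
        "results": base / "experiments" / result_run / "result.txt",
    }


def read_results(path: Path) -> Optional[str]:
    """Return the text of result.txt, or None if the file does not exist yet."""
    return path.read_text() if path.is_file() else None


if __name__ == "__main__":
    paths = prompt_paths("Crowd", his_len=6, pred_len=6)
    for name, path in paths.items():
        print(f"{name}: {path.as_posix()}")
    print(read_results(paths["results"]) or "result.txt not found yet")
```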

Model Weights

Model weights will be available for download on xxx. Coming soon.

👀 Citation

If you find this repo helpful, please cite our paper.

@article{yuan2024unist,
  title={UniST: A Prompt-Empowered Universal Model for Urban Spatio-Temporal Prediction},
  author={Yuan, Yuan and Ding, Jingtao and Feng, Jie and Jin, Depeng and Li, Yong},
  journal={arXiv preprint arXiv:2402.11838},
  year={2024}
}

πŸ™‡β€ Acknowledgement

We appreciate the following GitHub repos a lot for their valuable code and efforts.

📧 Contact

If you have any questions or want to use the code, feel free to contact:

Contributors

davymorgan, yuanyuan98


unist's Issues

Question about four mask strategies.

Hi, thank you for your nice work and for sharing the code.

I would like to know more about the mask strategies. Are the four strategies (['random', 'causal', 'frame', 'tube']) used together or individually? In main.py you set mask_strategy = 'random', which seems to use only the random strategy and none of the others. Which mask strategy works best?

Look forward to your reply. Thanks in advance!

Model weights

Hello,

Thank you for the code release. Are you planning to release the pre-trained model weights soon?
