
hats's Introduction

HATS

This repository contains the source code of HATS, a Hierarchical Graph Attention Network for Stock Movement Prediction. Because we conducted experiments on two different tasks, node classification and graph classification, we provide a separate version of the code for each task. Please refer to our paper, HATS: A Hierarchical Graph Attention Network for Stock Movement Prediction, for further details.

Requirements

NumPy 1.15.1
TensorFlow 1.11.0

Dataset

Price-related data and corporate relation data are used for HATS. We gathered both for companies listed in the S&P 500 from 2013/02/08 to 2019/06/17 (1174 trading days in total). Price data were collected from Yahoo Finance, and corporate relation data were built from information on Wikidata. Both datasets can be downloaded with the command below.

Usage

Download Data

bash download.sh

Execute the model with the Makefile

You need to pass the following arguments:
test_phase : the phase you want to test
save_dir : name of the directory where outputs are saved
data_type (graph_classification only) : one of ['S5CONS', 'S5ENRS', 'S5UTIL', 'S5FINL', 'S5INFT']

e.g.
make test_phase=1 save_dir=save data_type='S5CONS'
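The following Python sketch shows how the three arguments above combine into a make invocation. The helper name build_make_command is hypothetical and not part of the repository; it only checks data_type against the list given above and assembles the command, so it makes no assumption about what the Makefile does with each argument.

```python
import shlex

# Sector codes accepted by graph_classification, as listed in this README.
DATA_TYPES = ["S5CONS", "S5ENRS", "S5UTIL", "S5FINL", "S5INFT"]


def build_make_command(test_phase, save_dir, data_type=None):
    """Return the argument list for a make call (hypothetical helper).

    data_type only applies to graph_classification; pass None for
    node_classification.
    """
    args = ["make", f"test_phase={test_phase}", f"save_dir={save_dir}"]
    if data_type is not None:
        if data_type not in DATA_TYPES:
            raise ValueError(f"data_type must be one of {DATA_TYPES}, got {data_type!r}")
        args.append(f"data_type={data_type}")
    return args


if __name__ == "__main__":
    cmd = build_make_command(1, "save", "S5CONS")
    print(shlex.join(cmd))  # make test_phase=1 save_dir=save data_type=S5CONS
```

To launch a run, the returned list could be passed to subprocess.run(cmd, check=True) from the directory that contains the Makefile. The shell strips the quotes around 'S5CONS' in the example above, so both forms pass the same value.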

hats's People

Contributors

raehyuns


hats's Issues

Error when running code

Hello, I downloaded your code and data, but an error is raised when running main.py. The error message is TypeError: 'NoneType' object is not iterable. I don't understand how to fix it. Can you guide me?

TypeError: 'NoneType' object is not iterable

Run Command : (in /hats/graph_classification/src directory)

python main.py --data_type='S5CONS' --label_proportion=5

Error :

2022-08-18 13:46:56.025164: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-08-18 13:46:56.025202: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
08/18/2022 01:46:57 PM: Namespace(GNN_model='HATS', adj_train_phase='adj_sum', assignment_layer=[10, 1], batch_size=32, data_type='S5CONS', dev_size=50, dropout=0.3, early_stop_type='acc', eval_step=10, feat_att=False, feature_list=None, grad_max_norm=2.0, inference_model='CNN', label_proportion=[5], lookback=50, lr=5e-05, market_data_dir='./data/price/', max_relations=5, max_to_keep=10, min_train_period=300, mode='train', model_dir='./model', model_id='epoch200.pt', model_type='graph-HATS', momentum=0.9, n_epochs=200, n_iter_per_epoch=50, node_features=64, num_layer=1, num_relations=0, optimizer='Adam', pred_threshold=0.6, preprocess=False, pretrained_dir=None, price_model='LSTM', print_step=10, rel_att=False, relation_data_dir='./data/relation/', save_dir='./out', save_log=False, stack_layer=1, test_phase=None, test_size=100, train_on_stock=False, train_proportion=3.0, use_bias=True, use_rel_list=[2, 3, 6, 11, 13, 17, 18, 28, 30, 32, 37, 38, 39, 40, 41, 48, 57, 84, 85, 86], weight_decay=1e-05) 
Traceback (most recent call last):
  File "main.py", line 78, in <module>
    main()
  File "main.py", line 29, in main
    dataset = StockDataset(config)
  File "/home/hemang/Downloads/notebook_scripts/hats/graph_classification/src/dataset.py", line 20, in __init__
    self.scaling_feats = [idx for idx, name in enumerate(self.feature_list) if name != 'return']
TypeError: 'NoneType' object is not iterable

https://hjlabs.in/
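The Namespace printed in the log shows feature_list=None, and dataset.py line 20 iterates over self.feature_list, which is what raises the TypeError. A minimal sketch of a guard, assuming the same list comprehension as in the traceback, is shown below; the function name scaling_feature_indices and the example feature names are hypothetical. It does not supply the correct feature list, which is not stated here; it only turns the TypeError into an explicit message.

```python
def scaling_feature_indices(feature_list):
    """Return indices of features to scale, skipping the 'return' feature.

    Hypothetical helper mirroring the list comprehension at dataset.py line 20,
    with an explicit check for the feature_list=None case seen in the log.
    """
    if feature_list is None:
        raise ValueError(
            "feature_list is None: set it in the config before constructing "
            "StockDataset (the log shows feature_list=None in the Namespace)."
        )
    return [idx for idx, name in enumerate(feature_list) if name != "return"]


if __name__ == "__main__":
    # Hypothetical feature names, for illustration only.
    print(scaling_feature_indices(["close", "return", "volume"]))  # [0, 2]
```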

correspondence between relation type & index?

Hi,

In the data there are 85 relations, 14 of which are marked as not used. However, the paper presents 72 relations in total, which is a bit confusing.

Now I want to use only the few most important relations and extract them from the adjacency matrix. Can you give any suggestions?
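One way to pull a few relation types out of the relation data is sketched below with NumPy. It assumes, without confirmation from the repository, that the relations are stored as a stacked tensor of shape (num_relations, num_companies, num_companies) with a nonzero entry where two companies share a relation. Ranking by edge count is only a simple heuristic for "importance", not the relation attention that HATS learns; select_relations, rank_relations_by_edges, and merge_relations are hypothetical names.

```python
import numpy as np


def select_relations(adj, relation_ids):
    """Return the sub-tensor for the given relation indices.

    Assumes adj has shape (num_relations, num_companies, num_companies).
    """
    adj = np.asarray(adj)
    relation_ids = np.asarray(relation_ids, dtype=int)
    if relation_ids.size and (relation_ids.min() < 0 or relation_ids.max() >= adj.shape[0]):
        raise IndexError("relation index out of range for this adjacency tensor")
    return adj[relation_ids]


def rank_relations_by_edges(adj):
    """Relation indices ordered by number of nonzero entries, densest first."""
    adj = np.asarray(adj)
    edge_counts = np.count_nonzero(adj.reshape(adj.shape[0], -1), axis=1)
    return np.argsort(-edge_counts, kind="stable")


def merge_relations(adj_subset):
    """Collapse several relation types into one binary adjacency matrix."""
    return (np.asarray(adj_subset) != 0).any(axis=0).astype(np.int8)
```

For example, keeping the k densest relation types would be top = rank_relations_by_edges(adj)[:k] followed by merge_relations(select_relations(adj, top)). The Namespace in the first issue above lists use_rel_list=[2, 3, 6, ...], which suggests the code already selects relations by index; whether those indices refer to the same axis as this sketch assumes has not been checked.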

NaN loss for graph classification

Hi, thanks for the code. When I ran the graph classification code with make test_phase=1 save_dir=save data_type='S5CONS', I found that the loss is always NaN for both training and validation, and the performance figures didn't change. Do you have any idea why this might be happening? Thank you!
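A common first check when the loss is NaN from the very first step is whether the model inputs already contain NaN or infinite values, for instance from missing prices or from a division by zero during scaling. The sketch below is a generic NumPy debugging helper, not a diagnosis of this issue; report_non_finite is a hypothetical name, and it would need to be called on whatever arrays the data pipeline produces.

```python
import numpy as np


def report_non_finite(name, array):
    """Count NaN and +/-inf entries in array, printing a summary if any exist."""
    values = np.asarray(array, dtype=float)
    n_bad = int(np.count_nonzero(~np.isfinite(values)))
    if n_bad:
        n_nan = int(np.count_nonzero(np.isnan(values)))
        print(f"{name}: {n_bad} of {values.size} values are non-finite ({n_nan} NaN)")
    return n_bad


if __name__ == "__main__":
    prices = np.array([[101.0, 102.5], [np.nan, 99.0]])
    report_non_finite("prices", prices)  # prints: prices: 1 of 4 values are non-finite (1 NaN)
```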

Performance lower than reported

Hi, thank you for your great work and code!

I ran your node classification training code with make test_phase=1 save_dir=save, but found that the performance on the validation set is relatively low (accuracy: 0.4146, hit ratio: 0.5090, macro F1: 0.2762, expected return: -0.0087), which seems much lower than the figures reported in the paper.

Are there any configurations that should be modified to reproduce your results? Thank you very much!
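When comparing against the paper, it may help to recompute accuracy and macro F1 directly from saved predictions. Macro F1 averages the per-class F1 scores without weighting, so predictions concentrated on one class can keep accuracy moderate while macro F1 drops. The sketch below assumes integer class labels 0 to num_classes - 1; the function name accuracy_and_macro_f1 is hypothetical, and hit ratio and expected return are left out because their definitions are given in the paper, not here.

```python
import numpy as np


def accuracy_and_macro_f1(y_true, y_pred, num_classes):
    """Return (accuracy, macro F1) for integer labels in [0, num_classes)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    acc = float(np.mean(y_true == y_pred))
    f1_scores = []
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); a class never predicted nor present scores 0.
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return acc, float(np.mean(f1_scores))


if __name__ == "__main__":
    acc, macro_f1 = accuracy_and_macro_f1([0, 1, 2, 0], [0, 0, 0, 0], num_classes=3)
    print(round(acc, 4), round(macro_f1, 4))  # 0.5 0.2222
```

In this toy example, always predicting class 0 still gives 50% accuracy, while macro F1 falls to about 0.22 because classes 1 and 2 each score 0.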
