oneforall's Introduction

Code for One for All: Towards Training One Graph Model for All Classification Tasks

Paper: https://arxiv.org/abs/2310.00149

Authors: Hao Liu, Jiarui Feng, Lecheng Kong, Ningyue Liang, Dacheng Tao, Yixin Chen, Muhan Zhang

OFA Pipeline

OFA is a general graph classification framework that can solve a wide range of graph classification tasks with a single model and a single set of parameters. The tasks are cross-domain (e.g., citation networks, molecular graphs, ...) and cross-task (e.g., few-shot, zero-shot, graph-level, node-level, ...).

OFA uses natural language to describe all graphs and uses an LLM to embed all descriptions into the same embedding space, which enables cross-domain training with a single model.

OFA proposes a prompting paradigm in which all task information is converted into a prompt graph, so the subsequent model can read the task information and predict the relevant target accordingly, without adjusting the model parameters or architecture. Hence, a single model can work across tasks.

OFA curates a list of graph datasets from different sources and domains and describes the nodes/edges in these graphs with a systematic description protocol. We thank previous works, including OGB, GIMLET, MoleculeNet, GraphLLM, and villmow, for providing wonderful raw graph/text data that made our work possible.

🔥News

Update 04/14

  • Multi-GPU training is now supported; it uses all visible GPUs to train the model.
  • Fixed the bug reported in #11; few-shot results should now be correct.
  • Updated the ArXiv split so that the baselines and OFA use the same split.
  • Fixed the prompt edge connections to align with the paper. If you cloned our repo earlier, please update it first and then reproduce our results.

Old

OneForAll underwent a major revision, where we cleaned up the code and fixed several reported bugs. The major updates are:

  • Use YAML configs to specify tasks; see the Configuration section for details.
  • Updated the graph prompting logic so that users can design their own prompting more freely.
  • Use only one few-shot dataset for few-shot prompting of tasks at different levels.

If you previously used our repository, please pull, delete the old generated feature/text files, and regenerate them. Sorry for the inconvenience.

Requirements

To install the project requirements using conda:

conda env create -f environment.yml

E2E experiments

For joint end-to-end experiments on all collected datasets, run

python run_cdm.py --override e2e_all_config.yaml

All arguments can be overridden with space-separated key/value pairs, for example

python run_cdm.py --override e2e_all_config.yaml num_layers 7 batch_size 512 dropout 0.15 JK none
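As a rough illustration of how space-separated overrides like the ones above could be merged into a config loaded from YAML, here is a minimal sketch. The function names, the default values, and the int/float coercion rule are hypothetical assumptions for illustration; this is not the repository's actual argument parser.

```python
def coerce(value: str):
    """Convert a command-line string to int or float when possible (assumed rule); otherwise keep the string."""
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            pass
    return value


def apply_overrides(config: dict, argv: list[str]) -> dict:
    """Merge alternating key/value tokens (e.g. 'num_layers 7') into a copy of config."""
    if len(argv) % 2 != 0:
        raise ValueError("overrides must come in key/value pairs")
    merged = dict(config)  # leave the original config untouched
    for key, raw in zip(argv[0::2], argv[1::2]):
        merged[key] = coerce(raw)
    return merged


# Hypothetical defaults standing in for values read from e2e_all_config.yaml.
defaults = {"num_layers": 5, "batch_size": 128, "dropout": 0.0, "JK": "last"}
overrides = "num_layers 7 batch_size 512 dropout 0.15 JK none".split()
print(apply_overrides(defaults, overrides))
# {'num_layers': 7, 'batch_size': 512, 'dropout': 0.15, 'JK': 'none'}
```

Values that are neither integers nor floats, such as `none` for JK, are kept as plain strings in this sketch.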

Users can modify the task_names variable in ./e2e_all_config.yaml to control which datasets are included during training. The lengths of task_names, d_multiple, and d_min_ratio should be the same. They can also be specified on the command line as comma-separated values.

e.g.

python run_cdm.py task_names cora_link,arxiv d_multiple 1,1 d_min_ratio 1,1
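The length rule for the three comma-separated lists can be checked as in the sketch below. `parse_list` and `check_task_lists` are hypothetical helpers, and treating d_multiple and d_min_ratio as floats is an assumption; the repository may parse and validate these values differently.

```python
def parse_list(raw: str) -> list[str]:
    """Split a comma-separated command-line value such as 'cora_link,arxiv'."""
    return [item.strip() for item in raw.split(",") if item.strip()]


def check_task_lists(task_names: str, d_multiple: str, d_min_ratio: str):
    """Parse the three per-task lists and verify that they have the same length."""
    names = parse_list(task_names)
    multiples = [float(x) for x in parse_list(d_multiple)]
    min_ratios = [float(x) for x in parse_list(d_min_ratio)]
    if not (len(names) == len(multiples) == len(min_ratios)):
        raise ValueError(
            f"length mismatch: {len(names)} task_names, "
            f"{len(multiples)} d_multiple, {len(min_ratios)} d_min_ratio"
        )
    # One (task, d_multiple, d_min_ratio) triple per dataset included in training.
    return list(zip(names, multiples, min_ratios))


print(check_task_lists("cora_link,arxiv", "1,1", "1,1"))
# [('cora_link', 1.0, 1.0), ('arxiv', 1.0, 1.0)]
```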

OFA-ind (OFA trained on an individual dataset) can be run with

python run_cdm.py task_names cora_link d_multiple 1 d_min_ratio 1

Low resource experiments

To run the few-shot and zero-shot experiments:

python run_cdm.py --override lr_all_config.yaml

Configuration explained

We define a configuration for each task; each task configuration contains several dataset configurations.

Task configurations are stored in ./configs/task_config.yaml. A task usually consists of several dataset splits (not necessarily from the same dataset). For example, a regular end-to-end Cora node classification task uses the train split of the Cora dataset as the training dataset, the valid split of Cora as one of the validation datasets, and likewise the test split as a test dataset. You can also add more validation/test datasets, for example by using the train split of Cora as one of the validation/test datasets. Specifically, a task configuration looks like this:

arxiv:
  eval_pool_mode: mean
  dataset: arxiv             # dataset name
  eval_set_constructs:
    - stage: train           # a task should have one and only one train stage dataset
      split_name: train
    - stage: valid
      split_name: valid
      dataset: cora          # replace the default dataset for zero-shot tasks
    - stage: valid
      split_name: valid
    - stage: test
      split_name: test
    - stage: test
      split_name: train      # test the train split
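The rules stated in the comments above (each construct falls back to the task-level `dataset` unless it names its own, and a task has exactly one train stage) can be sketched as follows. `resolve_eval_sets` is a hypothetical helper written for illustration, not the repository's config loader.

```python
def resolve_eval_sets(task_cfg: dict) -> list[tuple[str, str, str]]:
    """Expand eval_set_constructs into (stage, dataset, split_name) triples.

    A construct without its own 'dataset' key falls back to the task-level dataset.
    """
    default_dataset = task_cfg["dataset"]
    resolved = [
        (c["stage"], c.get("dataset", default_dataset), c["split_name"])
        for c in task_cfg["eval_set_constructs"]
    ]
    # A task should have one and only one train-stage dataset.
    n_train = sum(1 for stage, _, _ in resolved if stage == "train")
    if n_train != 1:
        raise ValueError(f"expected exactly one train stage, found {n_train}")
    return resolved


# The 'arxiv' task from the YAML example above, written as a Python dict.
arxiv_task = {
    "eval_pool_mode": "mean",
    "dataset": "arxiv",
    "eval_set_constructs": [
        {"stage": "train", "split_name": "train"},
        {"stage": "valid", "split_name": "valid", "dataset": "cora"},
        {"stage": "valid", "split_name": "valid"},
        {"stage": "test", "split_name": "test"},
        {"stage": "test", "split_name": "train"},
    ],
}
for triple in resolve_eval_sets(arxiv_task):
    print(triple)
```

Under these assumptions, the second construct resolves to the Cora validation split (the zero-shot case), while every other construct uses arxiv.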

Dataset configurations are stored in ./configs/data_config.yaml. A dataset configuration defines how a dataset is constructed. Specifically,

arxiv:
  task_level: e2e_node
  preprocess: null                       # name of the preprocess function defined in task_constructor.py
  construct: ConstructNodeCls            # name of the dataset construction function defined in task_constructor.py
  args: # additional arguments to construct function
    walk_length: null
    single_prompt_edge: True
  eval_metric: acc                       # evaluation metric
  eval_func: classification_func         # evaluation function that processes the model output and batch into the evaluator input
  eval_mode: max                         # evaluation mode (min/max)
  dataset_name: arxiv                    # name of the OFAPygDataset
  dataset_splitter: ArxivSplitter        # splitting function defined in task_constructor.py
  process_label_func: process_pth_label  # name of the label-processing function that transforms the original labels into binary labels
  num_classes: 40 
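Several fields above (preprocess, construct, eval_func, dataset_splitter, process_label_func) hold the *name* of a function rather than the function itself. A minimal sketch of resolving such names through a lookup table follows. `REGISTRY`, `resolve`, and the placeholder function bodies are hypothetical, and the repository may resolve names differently (for example, by looking them up in task_constructor.py directly).

```python
from typing import Callable, Optional


# Hypothetical placeholders; only the names come from the config above.
def process_pth_label(embs, label):
    """Placeholder body; the real label processing lives in the repository."""
    return label


def classification_func(output, batch):
    """Placeholder body for an evaluation function."""
    return output, batch


REGISTRY: dict[str, Callable] = {
    "process_pth_label": process_pth_label,
    "classification_func": classification_func,
}


def resolve(name: Optional[str]) -> Optional[Callable]:
    """Look up a function by its config name; YAML null (None) means 'no function'."""
    if name is None:
        return None
    if name not in REGISTRY:
        raise KeyError(f"{name!r} is not a registered function")
    return REGISTRY[name]


cfg = {"preprocess": None, "eval_func": "classification_func", "process_label_func": "process_pth_label"}
resolved = {key: resolve(value) for key, value in cfg.items()}
print(resolved["preprocess"])          # None
print(resolved["eval_func"].__name__)  # classification_func
```

Keeping functions behind string names lets a YAML file select behavior without importing code, at the cost of failing only at lookup time when a name is misspelled.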

Add your own datasets

If you are implementing a dataset like Cora/PubMed/Arxiv, we recommend adding a directory for your data under data/single_graph/$customized_data$ and implementing gen_data.py in that directory; you can use data/Cora/gen_data.py as an example.

After the data is constructed, you need to register your dataset name here and implement a splitter like the one here. If you are doing zero-shot/few-shot tasks, you can construct the zero-shot/few-shot splits here too.

Lastly, register a config entry in configs/data_config.yaml. For example, for end-to-end node classification

$data_name$:
  <<: *E2E-node
  dataset_name: $data_name$
  dataset_splitter: $splitter$
  process_label_func: ... # usually process_pth_label should work
  num_classes: $number of classes$
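The `<<: *E2E-node` line is a YAML merge key: the entry starts from every key defined under the E2E-node anchor, and any key written explicitly in the entry overrides the inherited value. The Python sketch below emulates that behavior with dict unpacking. The anchor contents are hypothetical, borrowed from the arxiv example above, and the real anchor in configs/data_config.yaml may define different keys.

```python
# Hypothetical contents of the E2E-node anchor; the real anchor may differ.
E2E_NODE = {
    "task_level": "e2e_node",
    "preprocess": None,
    "construct": "ConstructNodeCls",
    "eval_metric": "acc",
    "eval_mode": "max",
}


def merge_entry(base: dict, local: dict) -> dict:
    """Emulate a YAML merge key (<<): start from base, then let local keys win."""
    return {**base, **local}


my_dataset = merge_entry(E2E_NODE, {
    "dataset_name": "my_graph",          # hypothetical dataset name
    "dataset_splitter": "MySplitter",    # hypothetical splitter name
    "process_label_func": "process_pth_label",
    "num_classes": 7,
})
print(my_dataset["task_level"], my_dataset["num_classes"])  # e2e_node 7
```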

process_label_func converts the target label to a binary label and, if the task is zero-shot/few-shot (where the number of class nodes is not fixed), transforms the class embeddings. A list of available process_label_func functions is here. It takes all class embeddings and the correct label as input and outputs a tuple: (label, class_node_embedding, binary/one-hot label).
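To make the (label, class_node_embedding, binary/one-hot label) contract concrete, here is a hypothetical label processor written with plain Python lists instead of tensors. The `selected_classes` argument and the re-indexing of the label are assumptions made for illustration. The idea is that a few-shot/zero-shot episode keeps only a subset of classes as class nodes. This is not the repository's process_pth_label or any other listed function.

```python
from typing import Optional, Sequence

Embedding = list[float]


def process_label_sketch(
    class_embs: Sequence[Embedding],
    label: int,
    selected_classes: Optional[Sequence[int]] = None,
) -> tuple[int, list[Embedding], list[int]]:
    """Hypothetical label processor returning (label, class_node_embedding, one_hot).

    selected_classes models a few-shot/zero-shot episode that uses only a subset
    of classes; None means every class gets a class node (the end-to-end case).
    """
    if selected_classes is None:
        selected_classes = list(range(len(class_embs)))
    if label not in selected_classes:
        raise ValueError("label must be one of the selected classes")
    # Only the selected classes become class nodes, so their number can vary.
    class_node_emb = [list(class_embs[c]) for c in selected_classes]
    # Re-index the label relative to the selected classes (an assumption).
    local_label = list(selected_classes).index(label)
    # One binary target per class node: 1 for the correct class, 0 otherwise.
    one_hot = [int(i == local_label) for i in range(len(selected_classes))]
    return local_label, class_node_emb, one_hot


class_embs = [[0.1, 0.0], [0.0, 0.2], [0.3, 0.3], [0.5, 0.1]]
print(process_label_sketch(class_embs, label=2))
# (2, [[0.1, 0.0], [0.0, 0.2], [0.3, 0.3], [0.5, 0.1]], [0, 0, 1, 0])
print(process_label_sketch(class_embs, label=2, selected_classes=[3, 2]))
# (1, [[0.5, 0.1], [0.3, 0.3]], [0, 1])
```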

If you want more flexibility, adding a customized dataset requires implementing a customized subclass of OFAPygDataset. A template follows:

class CustomizedOFADataset(OFAPygDataset):
    def gen_data(self):
        """
        Returns a tuple of the following format
        (data, text, extra) 
        data: a list of PyG Data objects; if you only have one large graph, you should still wrap it in a list.
        text: a list of lists of texts, e.g. [node_text, edge_text, label_text]; these will be converted to pooled vector representations.
        extra: any extra data (e.g. split information) you want to save.
        """

    def add_text_emb(self, data_list, text_emb):
        """
        This function assigns the generated embeddings to member variables of the graph.

        data_list: data list returned in self.gen_data.
        text_emb: list of torch text tensors corresponding to the texts returned by self.gen_data, i.e. text_emb[0] = llm_encode(text[0]).

        
        """
        data_list[0].node_text_feat = ...     # corresponding node features
        data_list[0].edge_text_feat = ...      # corresponding edge features
        data_list[0].class_node_text_feat = ...      # class node features
        data_list[0].prompt_edge_text_feat = ...     # text features of the prompt edges
        data_list[0].noi_node_text_feat = ...       # noi node features, refer to the paper for the definition
        return self.collate(data_list)

    def get_idx_split(self):
        """
        Returns the split information required to split the dataset. This is optional; you can further split the dataset in task_constructor.py.
        
        """

    def get_task_map(self):
        """
        A dataset can have multiple tasks that require different prompt/class text embeddings, so this function returns a task map that maps each task name to the desired text embeddings. Specifically, a task map has the following format:

        prompt_text_map = {task_name1: {"noi_node_text_feat": ["noi_node_text_feat", [$Index in data[0].noi_node_text_feat$]],
                                    "class_node_text_feat": ["class_node_text_feat",
                                                             [$Index in data[0].class_node_text_feat$]],
                                    "prompt_edge_text_feat": ["prompt_edge_text_feat", [$Index in data[0].prompt_edge_text_feat$]]},
                       task_name2: similar to task_name 1}
        Please refer to examples in data/ for details.
        """
        return self.side_data[-1]

    def get_edge_list(self, mode="e2e"):
        """
        Defines how to construct prompt graph
        f2n: noi nodes to noi prompt node
        n2f: noi prompt node to noi nodes
        n2c: noi prompt node to class nodes
        c2n: class nodes to noi prompt node
        For different tasks/modes you might want different prompt graph constructions; you can do so by returning a dictionary. For example,
        {"f2n": [1, 0], "n2c": [2, 0]} means you only want f2n and n2c edges; f2n edges have edge type 1, and their text embedding feature is data[0].prompt_edge_text_feat[0].
        """
        if mode == "e2e_link":
            return {"f2n": [1, 0], "n2f": [3, 0], "n2c": [2, 0], "c2n": [4, 0]}
        elif mode == "lr_link":
            return {"f2n": [1, 0], "n2f": [3, 0]}
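To make the four edge kinds concrete, here is a hypothetical helper that expands a get_edge_list()-style dictionary into explicit (source, target, edge type, text-feature index) edges. The node ids are made up. The real prompt-graph construction in the repository, including how the NOI prompt node and class nodes are indexed, may differ.

```python
from typing import NamedTuple


class PromptEdge(NamedTuple):
    src: int
    dst: int
    edge_type: int      # first entry of the [type, index] pair
    text_feat_idx: int  # second entry: index into prompt_edge_text_feat


def build_prompt_edges(noi_nodes: list[int], prompt_node: int,
                       class_nodes: list[int], spec: dict[str, list[int]]) -> list[PromptEdge]:
    """Hypothetical expansion of a get_edge_list()-style dict into explicit edges."""
    endpoints = {
        "f2n": [(v, prompt_node) for v in noi_nodes],    # NOI nodes -> NOI prompt node
        "n2f": [(prompt_node, v) for v in noi_nodes],    # NOI prompt node -> NOI nodes
        "n2c": [(prompt_node, c) for c in class_nodes],  # NOI prompt node -> class nodes
        "c2n": [(c, prompt_node) for c in class_nodes],  # class nodes -> NOI prompt node
    }
    edges = []
    # Only the edge kinds present in spec are generated.
    for kind, (edge_type, feat_idx) in spec.items():
        edges.extend(PromptEdge(s, d, edge_type, feat_idx) for s, d in endpoints[kind])
    return edges


# Hypothetical ids: NOI nodes 0 and 1, NOI prompt node 5, class nodes 6 and 7.
e2e_link = {"f2n": [1, 0], "n2f": [3, 0], "n2c": [2, 0], "c2n": [4, 0]}
for edge in build_prompt_edges([0, 1], 5, [6, 7], e2e_link):
    print(edge)
```

Passing the "lr_link" dictionary instead would produce only the f2n and n2f edges, because kinds missing from the dictionary are skipped.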

oneforall's People

Contributors

haoliu-cola, jiaruifeng, lechengkong

oneforall's Issues

Split of Arxiv

Hi! Thanks for sharing the elegant codebase. I have a question regarding the dataset split of ogbn-arxiv. I notice that the split is given by the ArxivSplitter, which is a 10-fold split in which 80% of the data is used as the training set; this differs from the original split. However, the performance of GCN in Table 3 seems to be taken from OGB's leaderboard. I wonder whether the split affects the overall performance of OFA.

Could you please provide the run command for llama2-7b and llama2-13b?

> Hi @noah-yxk, thank you for your interest in our work!

I ran your command and it gives me ~ 0.66 test accuracy, can you also share your replication results? Thanks.

It appears to me that the resulting model is not well-trained. The following command

python run_cdm.py task_names cora_node num_epochs 20 d_multiple 10.0 d_min_ratio 10.0 lr 0.0001 JK none batch_size 64

gives a test accuracy of 0.745. d_multiple 10.0 means that we sample 10*len(dataset) data points for training in one epoch. If you still can't get similar results, please let us know. I think what happened with the command you provided is that our learning-rate scheduler decreases the learning rate every 15 epochs, which hurts performance on the cora-node dataset. With the command I provided, the learning rate decreases only once, at epoch 15.

For the individual experiments, especially on small datasets like cora-node, extra care is needed to make sure the model is well-trained and not overfitted; we set the learning rate and batch_size to 0.0001 and 64 to reduce the likelihood of overfitting. We hope to add hyperparameter settings for individual experiments in our next revision.
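The two quantities in the reply above can be written down directly. This is a minimal sketch, assuming d_multiple simply scales the number of training samples drawn per epoch and that the scheduler is a plain step decay (the same rule as PyTorch's torch.optim.lr_scheduler.StepLR). The decay factor gamma=0.5 and the dataset size of 1000 are assumptions, not values taken from the repository.

```python
def samples_per_epoch(dataset_len: int, d_multiple: float) -> int:
    """Training samples drawn per epoch: d_multiple * len(dataset), truncated to an int."""
    return int(d_multiple * dataset_len)


def step_lr(base_lr: float, epoch: int, step_size: int = 15, gamma: float = 0.5) -> float:
    """Step decay: multiply the rate by gamma once every step_size epochs.

    gamma=0.5 is an assumed value, not taken from the repository.
    """
    return base_lr * gamma ** (epoch // step_size)


# With num_epochs 20, the rate changes only once, at epoch 15.
lrs = [step_lr(1e-4, e) for e in range(20)]
print(sorted(set(lrs), reverse=True))  # [0.0001, 5e-05]
# With a hypothetical 1000-example training split and d_multiple 10.0:
print(samples_per_epoch(1000, 10.0))  # 10000
```

With num_epochs 100, the same rule would lower the rate six times, which matches the maintainer's explanation of why the longer run degrades.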

Cheers

Hi @LechengKong, when I use llama2-7b or llama2-13b as the LLM encoder, the accuracy is around 0.68. Could you please provide the run command for llama2-7b and llama2-13b?

Originally posted by @1957787636 in #6 (comment)

Replicating the results of Table 12

Hi! Thanks for your brilliant work! I'm pretty interested in the results of Table 12 in the appendix. I wonder if there are any configs for us to replicate the results of Table 12? Thank you.

A question about the data.

Is the 'Data(raw_text=[2708], y=[2708], label_names=[7], edge_index=[2, 10858], train_masks=[10], val_masks=[10], test_masks=[10], x=[2708, 384], raw_texts=[2708], category_names=[2708])' output in the .pt files from /data the raw data? It seems like the 'x' has undergone some processing. Is it normalized? Thanks!

problem on environment.yml

When I run the command "conda env create -f environment.yml", there is a problem: Could not find a version that satisfies the requirement en-core-web-sm==3.5.0.

problem of replicating the outcomes of 'ofa-ind-st'

Hello, I tried to reproduce the results of OFA but could not match the performance reported in the paper.
This is my setting:
python run_cdm.py task_names cora_node num_epochs 100 num_layers 6 dropout 0.15
Could you advise me on replicating the 'ofa-ind-st' results reported in the paper?
Which command-line parameters should I use to run run_cdm.py?

For the node type

Hi! Thanks for your brilliant work! There are some parts in the code that I am confused about, I hope you can help me answer them!
What is the logic node that I found in the data processing for the Arxiv dataset? It does not seem to appear in the paper. Also, why is there no NOI node text in FB15K237_fs? I hope you can answer these questions. Thank you!

Encountering KeyError and NaN label issues with end-to-end experiments

Hi lecheng,

Thanks for the open-source code. I ran into errors when running the end-to-end experiments.

First, there is a KeyError: 'prompt_node_edge_feat' on the cora_link and cora_node datasets.

Second, there are NaN values in the labels of the chempcba, chemhiv, and chemblpre datasets.

Could you please help me understand what might be causing these issues and how to resolve them?

How to replicate the results in Table 4?

Thanks for your outstanding work. I'm very interested in the few-shot and zero-shot learning experiments, and I have a few questions on which I would like your guidance:

  1. What does the parameter d_min_ratio signify?
  2. How should I adjust the parameters to reproduce the results in Table 4? I tried running python run_cdm.py --override lr_all_config.yaml directly, but it was hard to achieve the reported results. For instance, I only obtained a test performance of 38.60 on ogbn-arxiv 5-way 5-shot.

I would be immensely grateful for your response.

molecular graph data problem?

Hello, how can I use the HIV molecular graph? The usual procedure is to download the data first and then read it, but I cannot download it due to network problems. Thanks.

About zero shot task

I think this is fantastic work! And it's been incredibly interesting for me!
However, I have a question that I'd like to trouble you with. I'm curious about the transferability of the model, for instance, training the model on the training set from arXiv in a supervised learning setup, and then testing it on the test sets from PubMed and Cora. Can the existing code achieve this? Specifically, how should I modify the configuration file? Looking forward to your reply! Wishing you smooth work and a happy life!

How can I test other settings on few-shot ability of OFA?

Hi! First of all, thanks for your amazing work!
In Section E.2 you provide experimental results spanning more ways and shots on the ogbn-arxiv and FB15K237 datasets, such as 3/5-way on ogbn-arxiv and 10/20-way on FB15K237. I want to test other settings on these datasets, such as 10/20/30/40-way. How should I do this?

Can I just modify config/data_config.yaml and config/task_config.yaml like this:

config/data_config.yaml

FB15K237_fs_403:
  <<: *FB15K237_fs
  args:
    walk_length: null
    single_prompt_edge: True
    n_way: 40
    k_shot: 3
    base_construct: ConstructKG
    no_class_node: True
    remove_edge: True
  num_classes: 40

config/task_config.yaml

FB15K237_fs: &FB15K237_fs
  <<: *LR-link
  dataset: FB15K237_fs
  eval_set_constructs:
    - stage: train
      split_name: train
      dataset: FB15K237_fs

...
    - stage: valid
      split_name: valid
      dataset: FB15K237_fs_403
    - stage: valid
      split_name: test
      dataset: FB15K237_fs_403

If not, what should I do to test other few-shot settings?

Inquiry about Cora data

Hi, thanks for your excellent work; it is very inspiring! I have a question about the Cora dataset in your work. cora.pt is a newly organized Cora TAG that you built yourselves, and the order of its nodes is not consistent with the official dataset provided by PyG. Do you still keep the correspondence between each node and its paper ID for cora.pt? Thanks!
