Giter VIP home page Giter VIP logo

Comments (13)

songshuhan avatar songshuhan commented on July 3, 2024

就非常的奇怪,而且这个mask的值我发现很敏感,如果不用0,用一些其他的值,发现非常的不work,而且只在decoder是gat的时候看似有效,其他的一律没有作用

from graphmae.

songshuhan avatar songshuhan commented on July 3, 2024

希望作者可以给一点解释,感谢

from graphmae.

THINK2TRY avatar THINK2TRY commented on July 3, 2024

@songshuhan 非常感谢关注GraphMAE!

  1. 关于loss function用 SCE而非 MSE,在GraphMAE的论文中进行了解释。cosine loss 可以视为一种归一化之后的 MSE,在表示学习中具有优势。MSE本身可能会受到 feature中极值的干扰。
  2. MASK的值可以用0或者高斯初始化。在 cora 等数据集上,由于 input feature 绝大部分维度的值都是0,所以用 0来初始化 MASK 是一个好的选择。我们在后续工作的实验中也显示,在OGB数据集这种特征为连续分布的数据上,0和高斯初始化没有显著的差距。
  3. Decoder使用GAT是一个较好的选择,使用不同GNN作为 decoder时需要进行超参数的调整。

希望有帮助!

from graphmae.

songshuhan avatar songshuhan commented on July 3, 2024

感谢您的回复!另外想问下有mini batch的训练代码么,因为受制于实验室显存的限制,Reddit不利用mini batch的方式是不能跑的,会cuda out of memery 作者可以提供一下不?

from graphmae.

zzzsh2000 avatar zzzsh2000 commented on July 3, 2024

你好,我运行以下命令,无法复现论文的结果,能帮我解释下为什么吗?非常感谢
ogb=1.3.6
pytorch=1.12.0
dgl-cuda11.3=0.9.1
python main_transductive.py --dataset cora --encoder gat --decoder gat --seed 0 --device 0

--- TestAcc: 0.5650, early-stopping-TestAcc: 0.5650, Best ValAcc: 0.5840 in epoch 29 ---

final_acc: 0.5650±0.0000
early-stopping_acc: 0.5650±0.0000

from graphmae.

THINK2TRY avatar THINK2TRY commented on July 3, 2024

感谢您的回复!另外想问下有mini batch的训练代码么,因为受制于实验室显存的限制,Reddit不利用mini batch的方式是不能跑的,会cuda out of memery 作者可以提供一下不?

@songshuhan 我们在这个repo中提供了mini-batch的实现,包括GraphSAINT 和 localclustering,可以参考。如果要使用 GraphMAE,可以对模型部分进行调整即可。

from graphmae.

songshuhan avatar songshuhan commented on July 3, 2024

@THINK2TRY 我其实自己来利用dgl的MultiLayerFullNeighborSampler和NodeDataLoader实现了一个minibatch,但是跑起来太太太慢了,感觉运行一个epoch的reddit估计就要半个小时,500个epoch感觉要很久很久的样子。。。我看您提供的repo是给每个节点生成了一个子图存放在文件中做预处理,预处理之后有会更快一点么?但是貌似在dataloader的时候也会dgl.subgraph然后dgl.batch,想咨询下您运行一个epoch的reddit大概用多久呢?

另外可以加您的联系方式不?关注你们的工作,感觉很有意思,想多和您交流!我的wx是ShuHann1997

from graphmae.

songshuhan avatar songshuhan commented on July 3, 2024

不对,我采用的这个方式,感觉运行一个epoch不止半个小时。。。估计得俩小时。。

logging.info("start mini batch training..")

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
          graph,train_index, sampler,
          batch_size=100,
          shuffle=True,
          drop_last=False,
          num_workers=1)

total_epoch = tqdm(range(max_epoch))
for epoch in total_epoch:
    epoch_iter = tqdm(dataloader)
    loss_list = []

    for input_nodes, output_nodes, _ in epoch_iter:
    
        model.train()
        subgraph = dgl.node_subgraph(graph, input_nodes).to(device)
        loss, loss_dict = model(subgraph, subgraph.ndata["feat"])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_list.append(loss.item())

        if scheduler is not None:
            scheduler.step()

    train_loss = np.mean(loss_list)
    total_epoch.set_description(f"# Epoch {epoch} | train_loss: {train_loss:.4f}")
    if logger is not None:
        loss_dict["lr"] = get_current_lr(optimizer)
        logger.note(loss_dict, step=epoch)

from graphmae.

THINK2TRY avatar THINK2TRY commented on July 3, 2024

@songshuhan 建议使用 GraphSAINT 等基于子图采样的方法训练。mini-batch training 一般只需要几十个 epoch 就可以完成。如果后续有问题,可以通过 GitHub issue 或者 邮件 沟通都可以的。

from graphmae.

songshuhan avatar songshuhan commented on July 3, 2024

@THINK2TRY 好的,感谢,另外我看到您用的reddit是处理过的嘛,看您论文里的reddit是11,606,919条边,但是我在dgl里加载的貌似有一亿条边,可以问下这个数据集可以从哪里得到么?

from graphmae.

THINK2TRY avatar THINK2TRY commented on July 3, 2024

@songshuhan 使用 DGL 提供的 Reddit 数据集,在 code 里面有实现可以参考。

from graphmae.

songshuhan avatar songshuhan commented on July 3, 2024

code里貌似没有提供reddidt,我的DGL版本里提供的边数是114615892,是论文里数量的十倍了。然后我参考您的建议发现了dgl里面提供的SAINTSample,速度确实很快,想问下如果能得到论文里reddit的结果您是用了多少个子图做一次epoch呢,每个子图的budget大概是多少呢?

num_iters = 1000
sampler = SAINTSampler(
mode='node', # Can be 'node', 'edge' or 'walk'
budget=2000,
prefetch_ndata=['feat', 'label'] # optionally, specify data to prefetch
)
dataloader = DataLoader(graph, torch.arange(num_iters), sampler, num_workers=1)

from graphmae.

THINK2TRY avatar THINK2TRY commented on July 3, 2024

Reddit 可以参考这里。关于采样的超参数可以在能计算资源和时间能接受的范围内进行设置,

from graphmae.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.