Comments (4)
training: 0%| | 0/2000000 [00:00<?, ?it/s]
Traceback (most recent call last):
File "train_qa_webtext2.py", line 164, in <module>
accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)
File "/home/wac/anaconda3/envs/py38/lib/python3.8/site-packages/accelerate/accelerator.py", line 1683, in backward
loss.backward(**kwargs)
File "/home/wac/anaconda3/envs/py38/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
torch.autograd.backward(
File "/home/wac/anaconda3/envs/py38/lib/python3.8/site-packages/torch/autograd/__init__.py", line 200, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
from palm-rlhf-pytorch.
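As the error message suggests, the asynchronous CUDA error reporting can make the traceback point at the wrong call. A minimal sketch of the suggested debugging step: set `CUDA_LAUNCH_BLOCKING=1` in the environment (it must be set before CUDA is initialized, i.e. before importing torch) so kernel launches become synchronous and the traceback points at the actual failing call.

```python
import os

# Must run before CUDA initializes (i.e. before "import torch"):
# makes kernel launches synchronous, so the Python traceback points
# at the failing CUDA call instead of a later, unrelated API call.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
```

Equivalently, it can be set on the command line when launching the training script.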
model = PaLM(
    num_tokens=256,   # 512 1024
    dim=2048,         # dim_head*heads
    depth=24,
    dim_head=256,     # always 256
    heads=8,
    flash_attn=True
).to(device)
Which type of GPU are you using? Are you using PyTorch 2.0? Flash Attention requires an A100. Also, I do not believe Flash Attention supports a dim_head larger than 128.
FlashAttention currently supports:
1. Turing, Ampere, Ada, or Hopper GPUs (e.g., A100).
2. fp16 and bf16
3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ..., 128). Head dim > 64 backward requires A100 or H100.
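The constraints above can be checked with a small helper. A minimal sketch; the function name, arguments, and arch strings are illustrative, not part of FlashAttention's API:

```python
# GPU architectures FlashAttention runs on, per the list above.
SUPPORTED_ARCHS = {"turing", "ampere", "ada", "hopper"}

def flash_attn_supported(head_dim, dtype, arch, backward=False, a100_or_h100=False):
    if arch not in SUPPORTED_ARCHS:
        return False  # pre-Turing GPUs are unsupported
    if dtype not in ("fp16", "bf16"):
        return False  # fp32 paths are not provided
    if head_dim % 8 != 0 or head_dim > 128:
        return False  # head dim must be a multiple of 8, at most 128
    if backward and head_dim > 64 and not a100_or_h100:
        return False  # backward with head dim > 64 needs an A100/H100
    return True

# dim_head = 256, as in the config in this thread, is rejected outright:
print(flash_attn_supported(256, "fp16", "ampere"))                 # False
# dim_head = 128 works for the forward pass on an A6000 (Ampere)...
print(flash_attn_supported(128, "fp16", "ampere"))                 # True
# ...but its backward pass would still need an A100/H100:
print(flash_attn_supported(128, "fp16", "ampere", backward=True))  # False
```

Note the last case: the A6000 is an Ampere card, so the forward pass is fine, but the head-dim-above-64 backward pass is restricted to A100/H100.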
from palm-rlhf-pytorch.
I'm using an A6000; can't Flash Attention be enabled?
How can I set up a dim_head as large as 256?
from palm-rlhf-pytorch.
It runs well with dim_head = 256 once I comment out flash_attn.
from palm-rlhf-pytorch.
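Since dim = dim_head * heads (per the comment in the PaLM config posted earlier in this thread), the two workable options can be sketched as plain config dicts. The keyword names are the ones from the PaLM(...) call above; whether halving dim_head and doubling heads preserves model quality is an assumption, not something verified in this thread:

```python
# Option A: keep dim_head = 256 but turn FlashAttention off
# (as the comment above reports, training then runs fine).
config_no_flash = dict(num_tokens=256, dim=2048, depth=24,
                       dim_head=256, heads=8, flash_attn=False)

# Option B (untested assumption): halve dim_head and double heads so the
# total width dim = dim_head * heads is unchanged, while dim_head = 128
# falls inside FlashAttention's supported range.
config_flash = dict(num_tokens=256, dim=2048, depth=24,
                    dim_head=128, heads=16, flash_attn=True)

for cfg in (config_no_flash, config_flash):
    # Both options preserve the dim = dim_head * heads relationship.
    assert cfg["dim"] == cfg["dim_head"] * cfg["heads"]
```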
Related Issues (20)
- Value function
- Can not train the model using PyTorch version 2? HOT 1
- train your reward model issue HOT 1
- KL divergence loss HOT 1
- mask raised error HOT 2
- Confusion about KL divergence calculation for human feedback policies HOT 13
- Reason for using pooled critic embedding instead of the last embedding for value head HOT 3
- Calculating the kl loss seems has a mistake. HOT 1
- Column and Row Parallel Linear for Apex Tensor Parallel HOT 1
- norm.gamma not used during backprop HOT 2
- speed up with flash attn in A6000? HOT 2
- memory-efficient attention is default opened? if i dont use flash attn HOT 3
- Model Name HOT 3
- I looked at the llama source code and there is an intermedie layer
- Flash Attention 2
- Possible incorrect creation of Rotary Embeddinigs HOT 1
- Should critic's input be prompt only?
- How to use lora?
- Is there any documentation to train this on my own data ?