Comments (6)
#23 contains a notebook with a good example.
I think putting it together with the README instructions looks like this:
import torch
from retro_pytorch import RETRO, TrainingWrapper
# instantiate RETRO, fit it into the TrainingWrapper with correct settings
retro = RETRO(
    max_seq_len = 2048,                   # max sequence length
    enc_dim = 896,                        # encoder model dimension
    enc_depth = 3,                        # encoder depth
    dec_dim = 768,                        # decoder model dimensions
    dec_depth = 12,                       # decoder depth
    dec_cross_attn_layers = (1, 3, 6, 9), # decoder cross attention layers (with causal chunk cross attention)
    heads = 8,                            # attention heads
    dim_head = 64,                        # dimension per head
    dec_attn_dropout = 0.25,              # decoder attention dropout
    dec_ff_dropout = 0.25                 # decoder feedforward dropout
).cuda()
wrapper = TrainingWrapper(
    retro = retro,                                # the RETRO instance from above
    knn = 2,                                      # knn (2 in paper was sufficient)
    chunk_size = 64,                              # chunk size (64 in paper)
    documents_path = './text_folder',             # path to folder of text
    glob = '**/*.txt',                            # text glob
    chunks_memmap_path = './train.chunks.dat',    # path to chunks
    seqs_memmap_path = './train.seq.dat',         # path to sequence data
    doc_ids_memmap_path = './train.doc_ids.dat',  # path to document ids per chunk (used for filtering neighbors belonging to same document)
    max_chunks = 1_000_000,                       # maximum cap to chunks
    max_seqs = 100_000,                           # maximum seqs
    knn_extra_neighbors = 100,                    # num extra neighbors to fetch
    max_index_memory_usage = '100m',
    current_memory_available = '1G'
)
# get the dataloader and optimizer (AdamW with all the correct settings)
train_dl = iter(wrapper.get_dataloader(batch_size = 2, shuffle = True))
optim = wrapper.get_optimizer(lr = 3e-4, wd = 0.01)
# now do your training
# ex. one gradient step
seq, retrieved = map(lambda t: t.cuda(), next(train_dl))
# seq - (2, 2049) - 1 extra token since split by seq[:, :-1], seq[:, 1:]
# retrieved - (2, 32, 2, 128) - 128 since chunk + continuation, each 64 tokens
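# shape arithmetic, from the settings above: 2048 max_seq_len / 64 chunk_size = 32 chunks,
# knn = 2 retrieved neighbors per chunk, each neighbor = chunk + continuation = 64 + 64 = 128 tokens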
loss = retro(
    seq,
    retrieved,
    return_loss = True
)
# one gradient step
loss.backward()
optim.step()
optim.zero_grad()
# do above for many steps, then ...
# encode prompt
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
prompt_str = "The movie Dune was released in"
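# [1:-1] drops the [CLS] and [SEP] special tokens that the BERT tokenizer adds around the input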
prompt_ids = tokenizer(prompt_str)['input_ids'][1:-1]
prompt = torch.tensor([prompt_ids])
sampled = wrapper.generate(prompt, filter_thres = 0.9, temperature = 1.0)
# decode sample
decoded = tokenizer.decode(sampled.tolist()[0])
print(decoded)
The notebook also had code that trained for many more steps, though, which is probably needed for good results.
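For many steps, something like the following might work. This is just my guess built from the one-gradient-step example above; the step count and log interval are arbitrary, not values from the notebook:

num_steps = 10_000  # assumption: pick a step budget that suits your data

dl = iter(wrapper.get_dataloader(batch_size = 2, shuffle = True))

for step in range(num_steps):
    try:
        seq, retrieved = map(lambda t: t.cuda(), next(dl))
    except StopIteration:
        # dataloader exhausted, so start a fresh pass over the data
        dl = iter(wrapper.get_dataloader(batch_size = 2, shuffle = True))
        seq, retrieved = map(lambda t: t.cuda(), next(dl))

    loss = retro(seq, retrieved, return_loss = True)
    loss.backward()
    optim.step()
    optim.zero_grad()

    if step % 100 == 0:
        print(f'step {step}: loss {loss.item():.4f}')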
@filipesilva Can you please share notebook which you are referencing, its not accessible. or if you can share code for training multiple epochs, will be really very helpful. Thanks
@aakashgoel12 looks like the notebook that was in #23 is not there anymore. I don't have a copy of it, unfortunately. All the code I have is what I put in the comment.
Thanks @filipesilva. Can you please check whether what I have written below is correct or needs some modification? Thanks in advance.
from tqdm import tqdm

num_epochs = 3
# keep the dataloader itself (no iter()), so each epoch starts a fresh iteration
train_dl = wrapper.get_dataloader(batch_size = 4, shuffle = True)

for epoch in range(num_epochs):
    counter = 0
    for batch in tqdm(train_dl):
        seq, retrieved = map(lambda t: t.cuda(), batch)
        loss = retro(
            seq,
            retrieved,
            return_loss = True)
        # one gradient step
        loss.backward()
        optim.step()
        optim.zero_grad()
        if counter % 10 == 0:
            print("Epoch:{}, BatchNo:{}, Loss:{}".format(epoch, counter, loss.item()))
        counter += 1
    print("After epoch - {}, loss: {}".format(epoch, loss.item()))
I really can't tell 😅 I only played around with this a couple of months ago and never really tried again.
Related Issues (20)
- Error Reconstructing FAISS Index HOT 18
- Unable to do non-retrieval sampling
- Extra layer encoder_output_to_decoder_dim cause issue with distributed training HOT 2
- TrainingWrapper does not support line breaks HOT 8
- RuntimeError: Error in void faiss::gpu::GpuIndexIVFPQ::verifySettings_() HOT 3
- Double [CLS] token in the first doc chunk HOT 1
- Retro-fitting a pretrained model HOT 7
- Clarification on Architecture
- Scann vs faiss HOT 6
- 'NoneType' object is not callable HOT 1
- Is there any pre-trained RETRO model released yet? HOT 4
- Huggingface model
- I am revising the model to solve QA task.. HOT 1
- Why are there so many position embeddings? HOT 5
- Causal mask in Chunked Cross Attention
- Error # could not open .tmp/.index/knn.index for reading: No such file or directory
- Question-Answer Dataset Format ?
- AttributeError: module 'faiss' has no attribute 'GpuParameterSpace' HOT 2
- Question: residual connect after `ChunkedCrossAttention`? HOT 5