gordicaleksa / pytorch-original-transformer

My implementation of the original transformer model (Vaswani et al.). I've additionally included the playground.py file for visualizing otherwise seemingly hard concepts. Currently included IWSLT pretrained models.

Home Page: https://youtube.com/c/TheAIEpiphany

License: MIT License

Python 46.88% Jupyter Notebook 53.12%
transformer transformers pytorch-transformer pytorch-transformers attention attention-mechanism attention-is-all-you-need pytorch python jupyter

pytorch-original-transformer's Issues

An environment problem

Thanks for your work.
I ran into a problem while creating the conda environment. The error is shown below. Please tell me how to solve this problem.

$ conda env create

Channels:
 - defaults
 - pytorch
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: failed
Channels:
 - defaults
 - pytorch
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: failed

LibMambaUnsatisfiableError: Encountered problems while solving:
  - package pytorch-1.5.0-cpu_py37hd91cbb3_0 requires python >=3.7,<3.8.0a0, but none of the providers can be installed

Could not solve for environment specs
The following packages are incompatible
├─ python 3.8.3  is requested and can be installed;
└─ pytorch 1.5.0  is not installable because it requires
   └─ python >=3.7,<3.8.0a0 , which conflicts with any installable versions previously reported.
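Reading the solver output: the pytorch 1.5.0 build it found (cpu_py37) only accepts Python 3.7.x, while the environment requests Python 3.8.3, so one of the two pins would have to change. The check the solver performs can be sketched in a few lines. This is a simplified illustration, not conda's actual matching logic: the helper names are made up, the pre-release suffix in `<3.8.0a0` is approximated by the bound `(3, 8)`, and `3.7.9` is only an example version.

```python
def parse_version(text):
    """Turn a plain dotted version such as '3.8.3' into a tuple of ints."""
    return tuple(int(part) for part in text.split("."))

# Simplified stand-in for the constraint 'python >=3.7,<3.8.0a0':
# the pre-release suffix 'a0' is dropped, so the upper bound becomes (3, 8).
LOWER_BOUND = (3, 7)
UPPER_BOUND = (3, 8)

def python_fits_pytorch_build(python_version):
    """Return True if the Python version lies inside [3.7, 3.8)."""
    version = parse_version(python_version)
    return LOWER_BOUND <= version < UPPER_BOUND

# The version requested by the environment file in the log above.
requested_ok = python_fits_pytorch_build("3.8.3")
# A hypothetical 3.7.x version, used only for comparison.
alternative_ok = python_fits_pytorch_build("3.7.9")

print("python 3.8.3 fits:", requested_ok)
print("python 3.7.9 fits:", alternative_ok)
```

Any 3.8.x version falls outside the range, which matches the "conflicts with any installable versions" message.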

sharing weight matrix between the two embedding layers and the pre-softmax linear transformation

Hi, thanks for your repo, it helps a lot!
In the paper, the weight matrix is shared between the two embedding layers and the pre-softmax linear transformation:
"In our model, we share the same weight matrix between the two embedding layers and the pre-softmax linear transformation, similar to [30]." (page 5, section 3.4, Embeddings and Softmax)
Would it be correct to modify the following rows in transformer_model.py to something like this?
rows 32-33 -> self.src_embedding = self.trg_embedding = Embedding(src_vocab_size, model_dimension)
row 50 -> self.decoder_generator = DecoderGenerator(self.src_embedding.embeddings_table.weight)
row 221 -> def __init__(self, shared_embedding_weights):
row 224 -> self.linear = nn.Linear(shared_embedding_weights.size()[1], shared_embedding_weights.size()[0], bias=False)
del self.linear.weight
self.shared_embedding_weights = shared_embedding_weights
row 232 -> self.linear.weight = self.shared_embedding_weights
row 233 -> return self.log_softmax(self.linear(trg_representations_batch) * math.sqrt(self.shared_embedding_weights.size()[1]))
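For comparison, here is a minimal standalone sketch of weight tying in PyTorch. It is not the repo's code: the class name `TiedEmbeddingGenerator` and its methods are hypothetical, it uses `nn.Embedding` directly instead of the repo's `Embedding` wrapper, and it assumes a single vocabulary shared by source and target (sharing one table between the two embedding layers only makes sense in that case). Following the paper's note that the embedding weights are multiplied by sqrt(d_model), the scaling here is applied on the embedding side rather than before the softmax.

```python
import math

import torch
import torch.nn as nn


class TiedEmbeddingGenerator(nn.Module):
    """Hypothetical module: one weight matrix for embedding and output projection."""

    def __init__(self, vocab_size, model_dimension):
        super().__init__()
        self.model_dimension = model_dimension
        # nn.Embedding stores a (vocab_size, model_dimension) weight.
        self.embedding = nn.Embedding(vocab_size, model_dimension)
        # nn.Linear(in, out) stores an (out, in) weight, i.e. the same shape.
        self.linear = nn.Linear(model_dimension, vocab_size, bias=False)
        # Tie the weights: both layers now point at the same Parameter.
        self.linear.weight = self.embedding.weight
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def embed(self, token_ids):
        # Scale embeddings by sqrt(d_model), as described in the paper.
        return self.embedding(token_ids) * math.sqrt(self.model_dimension)

    def generate(self, representations):
        # Project back to vocabulary size using the shared matrix.
        return self.log_softmax(self.linear(representations))


model = TiedEmbeddingGenerator(vocab_size=11, model_dimension=8)
token_ids = torch.tensor([[1, 2, 3]])
log_probs = model.generate(model.embed(token_ids))
print(log_probs.shape)
```

Because the parameter is shared, gradients from both the embedding lookup and the output projection accumulate into the same tensor, and `model.parameters()` reports it only once.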

torchtext.data import not working in the latest versions of pytorch.

# Data manipulation related imports

from torchtext.data import Dataset, BucketIterator, Field, Example
from torchtext.data.utils import interleave_keys
from torchtext import datasets
from torchtext.data import Example

The imports listed above do not work, because torchtext.data classes such as Dataset, BucketIterator, Field, and Example have been removed in the latest versions of torchtext.

Would it be possible to migrate the pytorch-original-transformer code to the new version of torchtext?
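As a stopgap while the code is not migrated, one can probe which module still exposes the old classes. The sketch below assumes, to the best of my knowledge, that some intermediate torchtext releases kept these classes under `torchtext.legacy.data` before dropping them entirely; the helper name `find_legacy_torchtext_data` is made up for illustration. If neither location works, the remaining option is to install an older torchtext that matches the repo's environment file.

```python
import importlib

# Class names the repo imports from torchtext.data.
REQUIRED_NAMES = ("Dataset", "BucketIterator", "Field", "Example")

# Candidate module paths, tried in order (the second one is an assumption
# about where intermediate torchtext releases kept the old API).
CANDIDATE_MODULES = ("torchtext.data", "torchtext.legacy.data")


def find_legacy_torchtext_data():
    """Return the first module exposing the old API, or None if there is none."""
    for module_name in CANDIDATE_MODULES:
        try:
            module = importlib.import_module(module_name)
        except Exception:
            # torchtext missing, module removed, or a broken install.
            continue
        if all(hasattr(module, name) for name in REQUIRED_NAMES):
            return module
    return None


legacy_data = find_legacy_torchtext_data()
if legacy_data is None:
    print("No installed torchtext exposes Field/BucketIterator/Dataset/Example.")
else:
    print("Old data API found in:", legacy_data.__name__)
```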

Frequency in the positional encodings

What does the frequency represent in the positional encoding?
Why do we need to multiply it by the position values?

frequencies = torch.pow(10000., -torch.arange(0, model_dimension, 2, dtype=torch.float) / model_dimension)
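One way to read that line: each entry of `frequencies` is the rate (in radians per position step) at which one sine/cosine pair of the encoding oscillates. Multiplying by the position turns that rate into an angle, so low dimensions change quickly from token to token and high dimensions change slowly. The sketch below builds the full table under the assumption that `model_dimension` is even; the function name `positional_encodings` and the variable `max_len` are illustrative and not taken from the repo.

```python
import torch


def positional_encodings(max_len, model_dimension):
    """Return the (max_len, model_dimension) sinusoidal table and its frequencies."""
    # Column vector of positions 0, 1, ..., max_len - 1 with shape (max_len, 1).
    positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    # One frequency per sine/cosine pair, shape (model_dimension / 2,).
    frequencies = torch.pow(
        10000.0,
        -torch.arange(0, model_dimension, 2, dtype=torch.float) / model_dimension,
    )
    # Broadcasting gives an angle for every (position, frequency) combination.
    angles = positions * frequencies
    table = torch.zeros(max_len, model_dimension)
    table[:, 0::2] = torch.sin(angles)  # even dimensions
    table[:, 1::2] = torch.cos(angles)  # odd dimensions
    return table, frequencies


table, frequencies = positional_encodings(max_len=50, model_dimension=8)
# For model_dimension=8 the frequencies are approximately 1, 0.1, 0.01, 0.001.
print(frequencies)
print(table.shape)
```

The first frequency is always 1 and the values decay geometrically toward 1/10000, so the wavelengths (2π divided by the frequency) range from 2π up to roughly 10000 · 2π, as stated in the paper.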

Issue regarding "9.1 Download pretrained transformers automatically"

While running translate_a_single_sentence(translation_config), I encountered an error saying that the file en-de.tgz is not recognized as a gzip file. How can I fix this?
Here is the error output:
` downloading en-de.tgz
C:\Users..\pytorch-original-transformer\data\iwslt\en-de.tgz: 97.4kB [00:00, 1.60MB/s]

BadGzipFile Traceback (most recent call last)
~\anaconda3\envs\pytorch-transformer\lib\tarfile.py in gzopen(cls, name, mode, fileobj, compresslevel, **kwargs)
1669 try:
-> 1670 t = cls.taropen(name, mode, fileobj, **kwargs)
1671 except OSError:

~\anaconda3\envs\pytorch-transformer\lib\tarfile.py in taropen(cls, name, mode, fileobj, **kwargs)
1646 raise ValueError("mode must be 'r', 'a', 'w' or 'x'")
-> 1647 return cls(name, mode, fileobj, **kwargs)
1648

~\anaconda3\envs\pytorch-transformer\lib\tarfile.py in __init__(self, name, mode, fileobj, format, tarinfo, dereference, ignore_zeros, encoding, errors, pax_headers, debug, errorlevel, copybufsize)
1509 self.firstmember = None
-> 1510 self.firstmember = self.next()
1511

~\anaconda3\envs\pytorch-transformer\lib\tarfile.py in next(self)
2310 try:
-> 2311 tarinfo = self.tarinfo.fromtarfile(self)
2312 except EOFHeaderError as e:

~\anaconda3\envs\pytorch-transformer\lib\tarfile.py in fromtarfile(cls, tarfile)
1101 """
-> 1102 buf = tarfile.fileobj.read(BLOCKSIZE)
1103 obj = cls.frombuf(buf, tarfile.encoding, tarfile.errors)

~\anaconda3\envs\pytorch-transformer\lib\gzip.py in read(self, size)
291 raise OSError(errno.EBADF, "read() on write-only GzipFile object")
--> 292 return self._buffer.read(size)
293

~\anaconda3\envs\pytorch-transformer\lib\_compression.py in readinto(self, b)
67 with memoryview(b) as view, view.cast("B") as byte_view:
---> 68 data = self.read(len(byte_view))
69 byte_view[:len(data)] = data

~\anaconda3\envs\pytorch-transformer\lib\gzip.py in read(self, size)
478 self._init_read()
--> 479 if not self._read_gzip_header():
480 self._size = self._pos

~\anaconda3\envs\pytorch-transformer\lib\gzip.py in _read_gzip_header(self)
426 if magic != b'\037\213':
--> 427 raise BadGzipFile('Not a gzipped file (%r)' % magic)
428

BadGzipFile: Not a gzipped file (b'<!')

During handling of the above exception, another exception occurred:

ReadError Traceback (most recent call last)
in
85
86 # Translate the given source sentence
---> 87 translate_a_single_sentence(translation_config)

in translate_a_single_sentence(translation_config)
5 print(2)
6 # Step 1: Prepare the field processor (tokenizer, numericalizer)
----> 7 _, _, src_field_processor, trg_field_processor = get_datasets_and_vocabs(
8 translation_config['dataset_path'],
9 translation_config['language_direction'],

in get_datasets_and_vocabs(dataset_path, language_direction, use_iwslt, use_caching_mechanism)
41 dataset_split_fn = datasets.IWSLT.splits if use_iwslt else datasets.WMT14.splits
42
---> 43 train_dataset, val_dataset, test_dataset = dataset_split_fn(
44 exts=(src_ext, trg_ext),
45 fields=fields,

~\anaconda3\envs\pytorch-transformer\lib\site-packages\torchtext\datasets\translation.py in splits(cls, exts, fields, root, train, validation, test, **kwargs)
142 cls.urls = [cls.base_url.format(exts[0][1:], exts[1][1:], cls.dirname)]
143 check = os.path.join(root, cls.name, cls.dirname)
--> 144 path = cls.download(root, check=check)
145
146 train = '.'.join([train, cls.dirname])

~\anaconda3\envs\pytorch-transformer\lib\site-packages\torchtext\data\dataset.py in download(cls, root, check)
189 # tarfile cannot handle bare .gz files
190 elif ext == '.tgz' or ext == '.gz' and ext_inner == '.tar':
--> 191 with tarfile.open(zpath, 'r:gz') as tar:
192 dirs = [member for member in tar.getmembers()]
193 tar.extractall(path=path, members=dirs)

~\anaconda3\envs\pytorch-transformer\lib\tarfile.py in open(cls, name, mode, fileobj, bufsize, **kwargs)
1615 else:
1616 raise CompressionError("unknown compression type %r" % comptype)
-> 1617 return func(name, filemode, fileobj, **kwargs)
1618
1619 elif "|" in mode:

~\anaconda3\envs\pytorch-transformer\lib\tarfile.py in gzopen(cls, name, mode, fileobj, compresslevel, **kwargs)
1672 fileobj.close()
1673 if mode == 'r':
-> 1674 raise ReadError("not a gzip file")
1675 raise
1676 except:

ReadError: not a gzip file`

Thank you very much!
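A note on the traceback: the bytes reported by BadGzipFile, b'<!', are how an HTML document typically starts (for example `<!DOCTYPE html>`), and the download was only 97.4kB. That suggests the URL returned a web page instead of the archive, although the log alone cannot confirm why. A small stdlib check of the first two bytes, sketched below with made-up helper names and temporary demo files, can tell the two cases apart before tarfile is involved.

```python
import gzip
import os
import tempfile

GZIP_MAGIC = b"\x1f\x8b"  # the same two bytes gzip.py writes as b'\037\213'


def describe_download(path):
    """Classify a downloaded file by its first two bytes."""
    with open(path, "rb") as f:
        head = f.read(2)
    if head == GZIP_MAGIC:
        return "gzip archive"
    if head == b"<!":
        return "HTML page (the server probably returned an error or redirect page)"
    return "unknown format"


# Demo on two temporary files: a real gzip file and a fake HTML "download".
with tempfile.TemporaryDirectory() as tmp:
    good_path = os.path.join(tmp, "good.tgz")
    bad_path = os.path.join(tmp, "bad.tgz")
    with gzip.open(good_path, "wb") as f:
        f.write(b"payload")
    with open(bad_path, "wb") as f:
        f.write(b"<!DOCTYPE html><html></html>")
    good_kind = describe_download(good_path)
    bad_kind = describe_download(bad_path)

print(good_kind)
print(bad_kind)
```

If the cached en-de.tgz turns out to be an HTML page, deleting it and obtaining the archive from a working source is the likely next step, since the file on disk is not the dataset.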

Error when running "python training_script.py --batch_size 100 --dataset_name IWSLT --language_direction G2E"

Not sure what is going on here but the best that I can tell is that there is a gzip file that seems to be missing.

Thank You
Tom

Traceback (most recent call last):
File "/home/tom/anaconda3/envs/pytorch-transformer/lib/python3.8/tarfile.py", line 1670, in gzopen
t = cls.taropen(name, mode, fileobj, **kwargs)
File "/home/tom/anaconda3/envs/pytorch-transformer/lib/python3.8/tarfile.py", line 1647, in taropen
return cls(name, mode, fileobj, **kwargs)
File "/home/tom/anaconda3/envs/pytorch-transformer/lib/python3.8/tarfile.py", line 1510, in __init__
self.firstmember = self.next()
File "/home/tom/anaconda3/envs/pytorch-transformer/lib/python3.8/tarfile.py", line 2311, in next
tarinfo = self.tarinfo.fromtarfile(self)
File "/home/tom/anaconda3/envs/pytorch-transformer/lib/python3.8/tarfile.py", line 1102, in fromtarfile
buf = tarfile.fileobj.read(BLOCKSIZE)
File "/home/tom/anaconda3/envs/pytorch-transformer/lib/python3.8/gzip.py", line 292, in read
return self._buffer.read(size)
File "/home/tom/anaconda3/envs/pytorch-transformer/lib/python3.8/_compression.py", line 68, in readinto
data = self.read(len(byte_view))
File "/home/tom/anaconda3/envs/pytorch-transformer/lib/python3.8/gzip.py", line 479, in read
if not self._read_gzip_header():
File "/home/tom/anaconda3/envs/pytorch-transformer/lib/python3.8/gzip.py", line 427, in _read_gzip_header
raise BadGzipFile('Not a gzipped file (%r)' % magic)
gzip.BadGzipFile: Not a gzipped file (b'<!')

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "training_script.py", line 192, in <module>
train_transformer(training_config)
File "training_script.py", line 103, in train_transformer
train_token_ids_loader, val_token_ids_loader, src_field_processor, trg_field_processor = get_data_loaders(
File "/home/tom/Downloads/pytorch-original-transformer/utils/data_utils.py", line 223, in get_data_loaders
train_dataset, val_dataset, src_field_processor, trg_field_processor = get_datasets_and_vocabs(dataset_path, language_direction, dataset_name == DatasetType.IWSLT.name)
File "/home/tom/Downloads/pytorch-original-transformer/utils/data_utils.py", line 151, in get_datasets_and_vocabs
train_dataset, val_dataset, test_dataset = dataset_split_fn(
File "/home/tom/anaconda3/envs/pytorch-transformer/lib/python3.8/site-packages/torchtext/datasets/translation.py", line 144, in splits
path = cls.download(root, check=check)
File "/home/tom/anaconda3/envs/pytorch-transformer/lib/python3.8/site-packages/torchtext/data/dataset.py", line 191, in download
with tarfile.open(zpath, 'r:gz') as tar:
File "/home/tom/anaconda3/envs/pytorch-transformer/lib/python3.8/tarfile.py", line 1617, in open
return func(name, filemode, fileobj, **kwargs)
File "/home/tom/anaconda3/envs/pytorch-transformer/lib/python3.8/tarfile.py", line 1674, in gzopen
raise ReadError("not a gzip file")
tarfile.ReadError: not a gzip file
