Comments (5)
For multi-label classification you should use sigmoid rather than softmax, because softmax forces the outputs to sum to one. The classes then compete with each other, so it is effectively limited to predicting a single label. Sigmoid has no such restriction: each output is scored independently.
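To illustrate the difference, here is a minimal sketch. The logits are made-up values for a single example with three labels, and the 0.5 threshold is an assumption rather than something taken from the tutorial.

```python
import torch

# Hypothetical raw scores (logits) for one example with three labels.
logits = torch.tensor([[2.0, 1.0, -1.0]])

softmax_probs = torch.softmax(logits, dim=1)
sigmoid_probs = torch.sigmoid(logits)

# Softmax probabilities compete: they always sum to 1 across the labels.
softmax_sum = softmax_probs.sum(dim=1)

# Sigmoid scores each label independently, so the sum is unconstrained.
sigmoid_sum = sigmoid_probs.sum(dim=1)

# With an (assumed) 0.5 threshold, sigmoid can switch on several labels at once.
predicted_labels = (sigmoid_probs > 0.5).int()
```

Here the first two labels both pass the threshold under sigmoid, which softmax followed by argmax could never express.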
from pytorch-sentiment-analysis.
@greed2411
Thanks for the prompt answer. So in that case, should the model look like the following?
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # one convolution per filter size, each spanning the full embedding dimension
        self.conv_0 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[0], embedding_dim))
        self.conv_1 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[1], embedding_dim))
        self.conv_2 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[2], embedding_dim))
        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x = [sent len, batch size]
        x = x.permute(1, 0)
        # x = [batch size, sent len]
        embedded = self.embedding(x)
        # embedded = [batch size, sent len, emb dim]
        embedded = embedded.unsqueeze(1)
        # embedded = [batch size, 1, sent len, emb dim]
        conved_0 = F.relu(self.conv_0(embedded).squeeze(3))
        conved_1 = F.relu(self.conv_1(embedded).squeeze(3))
        conved_2 = F.relu(self.conv_2(embedded).squeeze(3))
        # conved_n = [batch size, n_filters, sent len - filter_sizes[n] + 1]
        pooled_0 = F.max_pool1d(conved_0, conved_0.shape[2]).squeeze(2)
        pooled_1 = F.max_pool1d(conved_1, conved_1.shape[2]).squeeze(2)
        pooled_2 = F.max_pool1d(conved_2, conved_2.shape[2]).squeeze(2)
        # pooled_n = [batch size, n_filters]
        cat = self.dropout(torch.cat((pooled_0, pooled_1, pooled_2), dim=1))
        # cat = [batch size, n_filters * len(filter_sizes)]
        # sigmoid gives an independent score in (0, 1) for each label
        return torch.sigmoid(self.fc(cat))
Yes, that's how it should be.
On another note, I can't think of a reason why @bentrevett used Conv2d instead of Conv1d for text data.
@enod, I believe that implementation is correct.
@greed2411, I'm not sure why I used it either. I believe I found it easier to get my head around when thinking about it as 2d.
I'm currently in the process of updating these tutorials to TorchText 0.3 (as it has better integration with PyTorch 0.4) and will change to Conv1d.
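For reference, a Conv1d version might look roughly like the sketch below. This is not the tutorial's updated code. The class name `CNN1d`, the use of `nn.ModuleList`, the toy hyperparameters, and returning raw logits are all assumptions made for illustration. The idea is to treat the embedding dimension as the input channels, so each filter slides along the sentence only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN1d(nn.Module):  # hypothetical name, not from the tutorial
    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Embedding dimensions act as input channels; kernels slide over tokens.
        self.convs = nn.ModuleList([
            nn.Conv1d(in_channels=embedding_dim, out_channels=n_filters, kernel_size=fs)
            for fs in filter_sizes
        ])
        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x = [sent len, batch size]
        embedded = self.embedding(x.permute(1, 0))
        # embedded = [batch size, sent len, emb dim]
        embedded = embedded.permute(0, 2, 1)
        # embedded = [batch size, emb dim, sent len]
        conved = [F.relu(conv(embedded)) for conv in self.convs]
        # conved[n] = [batch size, n_filters, sent len - filter_sizes[n] + 1]
        pooled = [F.max_pool1d(c, c.shape[2]).squeeze(2) for c in conved]
        # pooled[n] = [batch size, n_filters]
        cat = self.dropout(torch.cat(pooled, dim=1))
        # Return raw logits; apply sigmoid (or a logits-based loss) outside.
        return self.fc(cat)


# Toy usage with made-up sizes: 12 tokens per sentence, batch of 4, 5 labels.
model = CNN1d(vocab_size=100, embedding_dim=16, n_filters=8,
              filter_sizes=[2, 3, 4], output_dim=5, dropout=0.5)
tokens = torch.randint(0, 100, (12, 4))
logits = model(tokens)
```

The sentence length must be at least as large as the largest filter size, otherwise the convolution has nothing to slide over.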
@greed2411 @bentrevett
Thanks for the feedback. Using sigmoid didn't work.
It turns out BCEWithLogitsLoss already has a sigmoid layer built in. So, if I'm not wrong, there is no need to add a sigmoid layer to the model.
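A quick way to check this is the sketch below, which uses random logits and made-up multi-label targets (both are assumptions for illustration only). `nn.BCEWithLogitsLoss` on raw logits should match `nn.BCELoss` on sigmoid outputs, whereas feeding already-sigmoided outputs into `nn.BCEWithLogitsLoss` applies sigmoid twice and gives a different loss.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical batch: 4 examples, 3 labels each, raw model outputs (logits).
logits = torch.randn(4, 3)
targets = torch.tensor([[1., 0., 1.],
                        [0., 0., 1.],
                        [1., 1., 0.],
                        [0., 1., 0.]])

# Sigmoid is applied inside the loss.
loss_with_logits = nn.BCEWithLogitsLoss()(logits, targets)

# Equivalent: apply sigmoid yourself, then use plain BCELoss.
loss_manual = nn.BCELoss()(torch.sigmoid(logits), targets)

# Mistake: sigmoid in the model AND BCEWithLogitsLoss (sigmoid applied twice).
loss_double = nn.BCEWithLogitsLoss()(torch.sigmoid(logits), targets)
```

So a model trained with BCEWithLogitsLoss can return `self.fc(cat)` directly, and sigmoid only needs to be applied at prediction time to turn logits into per-label probabilities.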