Comments (5)
I found the problem. Changing
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)).cuda()
to
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
fixes it, but I don't understand the difference between the two.
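On the difference between the two lines: a minimal sketch, assuming a recent PyTorch release. Calling a conversion on an nn.Parameter returns an ordinary tensor produced by an autograd op, so the attribute is no longer a Parameter and the module does not register it. Wrapping the already converted tensor keeps it a registered leaf Parameter. The class names Wrong and Right are hypothetical, and .double() is used as a CPU-friendly stand-in for .cuda(), which should behave the same way when it actually copies the data to another device.

```python
import torch
import torch.nn as nn


class Wrong(nn.Module):  # hypothetical name, mirrors the failing line
    def __init__(self):
        super().__init__()
        # Conversion applied AFTER wrapping: the result is a plain tensor with
        # a grad_fn, so the module does not register it as a parameter.
        # .double() stands in for .cuda() here (assumption, see lead-in).
        self.centers = nn.Parameter(torch.randn(3, 2)).double()


class Right(nn.Module):  # hypothetical name, mirrors the working line
    def __init__(self):
        super().__init__()
        # Conversion applied BEFORE wrapping: the attribute is a leaf
        # nn.Parameter and shows up in parameters().
        self.centers = nn.Parameter(torch.randn(3, 2).double())


wrong, right = Wrong(), Right()
for name, module in (("wrong", wrong), ("right", right)):
    print(
        name,
        type(module.centers).__name__,
        "is_leaf:", module.centers.is_leaf,
        "registered:", len(list(module.parameters())),
    )
```

With the first form, an optimizer built from parameters() has nothing to update, because the list of registered parameters is empty.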
from pytorch-center-loss.
Can you provide a code snippet for the definition of criterion_cent? Also, you need to make sure that you are creating the centers like this:
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
SoftmaxLoss = torch.nn.CrossEntropyLoss()
centerLoss = center_losses.CenterLoss(classes=train_class, feature_dims=1024)
optimizer_centerloss = torch.optim.SGD(list(centerLoss.parameters()), lr=0.5)
class CenterLoss(nn.Module):
    def __init__(self, classes, feature_dims, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.classes = classes
        self.feature_dims = feature_dims
        self.use_gpu = use_gpu
        if use_gpu:
            centers = nn.Parameter(torch.randn(self.classes, self.feature_dims)).cuda()
        else:
            centers = nn.Parameter(torch.randn(self.classes, self.feature_dims))
        self.centers = centers

    def forward(self, x, labels):
        # labels: [N_way, K_shot]
        batch_size = x.shape[0]  # x_shape: torch.Size([N_way*K_shot, 1024])
        # dist_mat: torch.Size([batch_size, classes])
        # print('x: ', x)
        print('centers: ', self.centers)
        dist_mat = torch.sum(torch.square(x), 1, keepdim=True).expand(batch_size, self.classes) + \
            torch.sum(torch.square(self.centers), 1, keepdim=True).expand(self.classes, batch_size).t()
        dist_mat = dist_mat - 2*torch.matmul(x, self.centers.t())
        # dist_mat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.classes) + \
        #     torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.classes, batch_size).t()
        # dist_mat.addmm_(x, self.centers.t(), beta=1, alpha=-2)
        labels = torch.reshape(labels, (-1, 1)).expand(batch_size, self.classes)
        classes_mat = torch.arange(self.classes).expand(batch_size, self.classes).long()
        if self.use_gpu:
            # print('use_gpu:', self.use_gpu)
            classes_mat = classes_mat.cuda()
        mask = labels.eq(classes_mat).float()
        dist_mat = dist_mat*mask
        center_loss = torch.sum(dist_mat.clamp(min=1e-12, max=1e+12))/(batch_size*self.feature_dims)
        # get support set centers
        mask1 = torch.sum(mask, 0).bool()
        support_centers = self.centers[mask1, :]
        # print('support_centers:', support_centers.shape)
        return center_loss, support_centers
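For comparison, here is a hedged sketch of how the class above could create its centers so that they stay a registered parameter. It is not the repository's implementation: the name CenterLossSketch and the device argument are assumptions introduced for illustration, and the distance computation is condensed with broadcasting while keeping the same masking and clamping as the posted code.

```python
import torch
import torch.nn as nn


class CenterLossSketch(nn.Module):
    """Hypothetical variant of the CenterLoss class posted above."""

    def __init__(self, classes, feature_dims, device="cpu"):
        super().__init__()
        self.classes = classes
        self.feature_dims = feature_dims
        # Build the tensor on the target device first, then wrap it, so the
        # attribute stays a leaf nn.Parameter registered with the module.
        self.centers = nn.Parameter(torch.randn(classes, feature_dims, device=device))

    def forward(self, x, labels):
        batch_size = x.shape[0]
        # Squared Euclidean distance from every sample to every center,
        # shape [batch_size, classes].
        dist_mat = (
            x.pow(2).sum(dim=1, keepdim=True)
            + self.centers.pow(2).sum(dim=1).unsqueeze(0)
            - 2 * x @ self.centers.t()
        )
        # Keep only the distance between each sample and its own class center.
        classes_mat = torch.arange(self.classes, device=x.device)
        mask = labels.reshape(-1, 1).eq(classes_mat).float()
        dist_mat = dist_mat * mask
        center_loss = dist_mat.clamp(min=1e-12, max=1e12).sum() / (batch_size * self.feature_dims)
        # Centers of the classes that actually occur in this batch.
        support_centers = self.centers[mask.sum(dim=0).bool()]
        return center_loss, support_centers


# Usage with made-up shapes: 5 classes, 8-dimensional features, batch of 4.
criterion = CenterLossSketch(classes=5, feature_dims=8)
optimizer = torch.optim.SGD(criterion.parameters(), lr=0.5)
features = torch.randn(4, 8)
labels = torch.tensor([0, 1, 1, 3])
loss, support = criterion(features, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Passing device="cuda" would place the centers on the GPU, provided the features and labels are on the same device.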
There don't seem to be any issues with your code, but you can try removing list from this line:
optimizer_centerloss = torch.optim.SGD(list(centerLoss.parameters()), lr=0.5)
so that it becomes
optimizer_centerloss = torch.optim.SGD(centerLoss.parameters(), lr=0.5)
I have removed list, but I still get the same error as above.
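Since removing list did not help, one way to narrow this down is to check whether the loss module has any tensor that requires gradients but is not registered as a parameter. The helper below is a hypothetical diagnostic, not part of the repository. The Demo module is made up for illustration, and .double() is used as a stand-in for .cuda() so the check also runs on a machine without a GPU.

```python
import torch
import torch.nn as nn


def unregistered_grad_tensors(module):
    """Names of tensor attributes that require grad but are not registered parameters."""
    registered = {name for name, _ in module.named_parameters(recurse=False)}
    return [
        name
        for name, value in vars(module).items()
        if isinstance(value, torch.Tensor) and value.requires_grad and name not in registered
    ]


class Demo(nn.Module):  # hypothetical module, only for illustration
    def __init__(self):
        super().__init__()
        # Registered: wrapped last, so it stays an nn.Parameter.
        self.good = nn.Parameter(torch.randn(3, 2))
        # Not registered: the conversion after wrapping returns a plain tensor.
        self.bad = nn.Parameter(torch.randn(3, 2)).double()


suspects = unregistered_grad_tensors(Demo())
print(suspects)
```

Any name this reports on the real loss module is a tensor the optimizer will never see through parameters().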