
pytorch-center-loss's Introduction

pytorch-center-loss

PyTorch implementation of center loss: Wen et al., "A Discriminative Feature Learning Approach for Deep Face Recognition", ECCV 2016.

This loss function is also used by deep-person-reid.
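For reference, the center loss term from the paper pulls each deep feature toward its class center; in the paper's notation, with $x_i$ the feature of sample $i$, $c_{y_i}$ its class center, and $m$ the mini-batch size:

$L_C = \frac{1}{2} \sum_{i=1}^{m} \lVert x_i - c_{y_i} \rVert_2^2$

The total objective is the softmax (cross-entropy) loss plus this term scaled by a weight.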

Get started

Clone this repo and run the code

$ git clone https://github.com/KaiyangZhou/pytorch-center-loss
$ cd pytorch-center-loss
$ python main.py --eval-freq 1 --gpu 0 --save-dir log/ --plot

You will see the following info in your terminal

Currently using GPU: 0
Creating dataset: mnist
Creating model: cnn
==> Epoch 1/100
Batch 50/469     Loss 2.332793 (2.557837) XentLoss 2.332744 (2.388296) CenterLoss 0.000048 (0.169540)
Batch 100/469    Loss 2.354638 (2.463851) XentLoss 2.354637 (2.379078) CenterLoss 0.000001 (0.084773)
Batch 150/469    Loss 2.361732 (2.434477) XentLoss 2.361732 (2.377962) CenterLoss 0.000000 (0.056515)
Batch 200/469    Loss 2.336701 (2.417842) XentLoss 2.336700 (2.375455) CenterLoss 0.000001 (0.042386)
Batch 250/469    Loss 2.404814 (2.407015) XentLoss 2.404813 (2.373106) CenterLoss 0.000001 (0.033909)
Batch 300/469    Loss 2.338753 (2.398546) XentLoss 2.338752 (2.370288) CenterLoss 0.000001 (0.028258)
Batch 350/469    Loss 2.367068 (2.390672) XentLoss 2.367059 (2.366450) CenterLoss 0.000009 (0.024221)
Batch 400/469    Loss 2.344178 (2.384820) XentLoss 2.344142 (2.363620) CenterLoss 0.000036 (0.021199)
Batch 450/469    Loss 2.329708 (2.379460) XentLoss 2.329661 (2.360611) CenterLoss 0.000047 (0.018848)
==> Test
Accuracy (%): 10.32  Error rate (%): 89.68
... ...
==> Epoch 30/100
Batch 50/469     Loss 0.141117 (0.155986) XentLoss 0.084169 (0.091617) CenterLoss 0.056949 (0.064369)
Batch 100/469    Loss 0.138201 (0.151291) XentLoss 0.089146 (0.092839) CenterLoss 0.049055 (0.058452)
Batch 150/469    Loss 0.151055 (0.151985) XentLoss 0.090816 (0.092405) CenterLoss 0.060239 (0.059580)
Batch 200/469    Loss 0.150803 (0.153333) XentLoss 0.092857 (0.092156) CenterLoss 0.057946 (0.061176)
Batch 250/469    Loss 0.162954 (0.154971) XentLoss 0.094889 (0.092099) CenterLoss 0.068065 (0.062872)
Batch 300/469    Loss 0.162895 (0.156038) XentLoss 0.093100 (0.092034) CenterLoss 0.069795 (0.064004)
Batch 350/469    Loss 0.146187 (0.156491) XentLoss 0.082508 (0.091787) CenterLoss 0.063679 (0.064704)
Batch 400/469    Loss 0.171533 (0.157390) XentLoss 0.092526 (0.091674) CenterLoss 0.079007 (0.065716)
Batch 450/469    Loss 0.209196 (0.158371) XentLoss 0.098388 (0.091560) CenterLoss 0.110808 (0.066811)
==> Test
Accuracy (%): 98.51  Error rate (%): 1.49
... ...

Please run python main.py -h for more details regarding input arguments.

Results

We visualize the feature learning process below.

Softmax only. Left: training set. Right: test set.

[figure pair: 2-D feature embeddings, softmax only]

Softmax + center loss. Left: training set. Right: test set.

[figure pair: 2-D feature embeddings, softmax + center loss]

How to use center loss in your own project

  1. All you need is the center_loss.py file:

from center_loss import CenterLoss

  2. Initialize the center loss in the main function:

center_loss = CenterLoss(num_classes=10, feat_dim=2, use_gpu=True)

  3. Construct an optimizer for the center loss:

optimizer_centloss = torch.optim.SGD(center_loss.parameters(), lr=0.5)

Alternatively, you can merge the optimizers of the model and the center loss, like:

params = list(model.parameters()) + list(center_loss.parameters())
optimizer = torch.optim.SGD(params, lr=0.1) # here lr is the overall learning rate
  4. Update the class centers just like you update a PyTorch model:
# features (torch tensor): a 2D torch float tensor with shape (batch_size, feat_dim)
# labels (torch long tensor): 1D torch long tensor with shape (batch_size)
# alpha (float): weight for center loss
loss = center_loss(features, labels) * alpha + other_loss
optimizer_centloss.zero_grad()
loss.backward()
# multiply the gradients by (1./alpha) to remove the effect of alpha on the center updates
for param in center_loss.parameters():
    param.grad.data *= (1./alpha)
optimizer_centloss.step()

If you adopt the second way (i.e. one optimizer for both the model and the center loss), the update code would look like:

loss = center_loss(features, labels) * alpha + other_loss
optimizer.zero_grad()
loss.backward()
for param in center_loss.parameters():
    # lr_cent is learning rate for center loss, e.g. lr_cent = 0.5
    param.grad.data *= (lr_cent / (alpha * lr))
optimizer.step()
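Putting the four steps together, here is a minimal end-to-end sketch using the two-optimizer variant. The linear backbone, classifier, and random batch are toy placeholders standing in for a real network and data loader, not code from this repo:

import torch
import torch.nn as nn
from center_loss import CenterLoss

num_classes, feat_dim, alpha = 10, 2, 1.0
backbone = nn.Linear(784, feat_dim)            # toy feature extractor
classifier = nn.Linear(feat_dim, num_classes)  # toy classification head
xent = nn.CrossEntropyLoss()
center_loss = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=False)

optimizer_model = torch.optim.SGD(
    list(backbone.parameters()) + list(classifier.parameters()), lr=0.01)
optimizer_centloss = torch.optim.SGD(center_loss.parameters(), lr=0.5)

# one training step on a random batch (stands in for a real data loader)
data = torch.randn(64, 784)
labels = torch.randint(0, num_classes, (64,))

features = backbone(data)
logits = classifier(features)
loss = xent(logits, labels) + alpha * center_loss(features, labels)

optimizer_model.zero_grad()
optimizer_centloss.zero_grad()
loss.backward()
for param in center_loss.parameters():
    param.grad.data *= (1. / alpha)  # undo the alpha scaling on center gradients
optimizer_model.step()
optimizer_centloss.step()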

pytorch-center-loss's People

Contributors

chuhanxx, kaiyangzhou


pytorch-center-loss's Issues

optimizer got an empty parameter list

When I put the center loss's parameters into the optimizer, it raises ValueError("optimizer got an empty parameter list"):
optimizer_centloss = torch.optim.SGD(criterion_cent.parameters(), lr=args.lr_cent)
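One likely cause (an assumption based on the error, not confirmed in the issue): the centers tensor must be registered as an nn.Parameter for .parameters() to see it. A minimal check, with both class names hypothetical:

import torch
import torch.nn as nn

class BrokenCenterLoss(nn.Module):
    def __init__(self, num_classes=10, feat_dim=2):
        super().__init__()
        # plain tensor attribute: NOT registered, so parameters() is empty
        self.centers = torch.randn(num_classes, feat_dim)

class WorkingCenterLoss(nn.Module):
    def __init__(self, num_classes=10, feat_dim=2):
        super().__init__()
        # nn.Parameter: registered, so the optimizer can see it
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

print(len(list(BrokenCenterLoss().parameters())))   # 0 -> "empty parameter list" error
print(len(list(WorkingCenterLoss().parameters())))  # 1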

center loss

May I ask: can center loss be used on its own, without the softmax loss? How well does it work?

gradient exploding

The gradient explodes when I set 'weight-cent' to 200+, and then the loss becomes NaN. Why?

about loss

In my experiment, the center loss does not decrease with the number of iterations; it changes seemingly at random. I don't quite understand what's going on and would like to ask you about it. Thanks.

UserWarning: This overload of addmm_ is deprecated

Packages:
torch 1.5.0
torchvision 0.6.0

Code:
distmat.addmm_(1, -2, x, self.centers.t())

Warning:
..\torch\csrc\utils\python_arg_parser.cpp:756: UserWarning: This overload of addmm_ is deprecated:
     addmm_(Number beta, Number alpha, Tensor mat1, Tensor mat2)
Consider using one of the following signatures instead:
     addmm_(Tensor mat1, Tensor mat2, *, Number beta, Number alpha)

Fix:
distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

Plots to visualize

What are the axes in the plots shown in the README file?
How can we achieve this?

Inappropriate description

In center_loss.py, the documented shape of the ground-truth labels should be (batch_size), not (num_classes).
Also, I think it would be good to reduce the use of .cuda() calls in center_loss.py.

About updating centers

Hi~
May I ask whether the implementation in this repo is equivalent to the original implementation in the paper?

# by doing so, weight_cent does not affect the learning of the centers
for param in criterion_cent.parameters():
    param.grad.data *= (1. / args.weight_cent)

Simply letting alpha = alpha / lambda (where alpha is lr_cent and lambda is weight_cent), will it be equivalent to the previous implementation?

It seems the paper's authors do not take the gradient w.r.t. $c_j$ directly and instead use the delta rule below (the embedded equation images did not survive; this is the update rule from Wen et al.):

$\frac{\partial L_C}{\partial x_i} = x_i - c_{y_i}$

$\Delta c_j = \frac{\sum_{i=1}^{m} \delta(y_i = j)\,(c_j - x_i)}{1 + \sum_{i=1}^{m} \delta(y_i = j)}$

Center loss in small batch multiclass training

Hello authors and contributors of center loss, thanks for the impressive work. I have a question: I noticed the center update is based on each mini-batch, and if I have a small batch and many classes (far more than the batch size), the center update may become tricky. I wonder whether using momentum or another optimizer is necessary for training. Thanks.

question about center loss

I don't know how Xi changes with depth.
In the formula, Xi is described as a feature whose dimensionality varies with the network depth. How is this reflected in the code? I couldn't see it. Is feat_dim set to the input dimensionality, and can it be guaranteed that what gets optimized is the final output dimensionality? My data is not an image dataset. Please answer my question, thanks.

How many classes?

Hello!
I am trying to train a model with center loss. My dataset has 901 classes. I am creating mini-batches after shuffling the training data, with a batch size of 128, such that:

feats_dim = [128,512] per batch
labels_dim =[128] per batch

Then:
center_loss = CenterLoss(num_classes = ????, feat_dim = 512, use_gpu=False)

What should I pass in num_classes?
901: actual number of classes
128: classes in current batch
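For what it's worth, a sketch of how the initialization relates to this question, based on how the centers parameter is indexed in this implementation:

from center_loss import CenterLoss

# centers has shape (num_classes, feat_dim) and labels index into its rows,
# so num_classes must be the total number of classes in the dataset (901 here),
# not the number of classes present in a single batch
center_loss = CenterLoss(num_classes=901, feat_dim=512, use_gpu=False)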

Nonetype of param.grad.data

Hi, I tried to use your package directly in my network. However, param.grad is None, so param.grad.data cannot be accessed. Could you help explain this? Thanks!

cross entropy loss

I'm afraid you're writing the wrong code for the classification loss. In PyTorch, the softmax cross-entropy loss should be

NLL + LogSoftmax

or

nn.CrossEntropyLoss + logits
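A minimal sketch of the equivalence this issue refers to, using only standard PyTorch APIs (nothing repo-specific):

import torch
import torch.nn as nn

logits = torch.randn(4, 10)            # raw, unnormalized scores
labels = torch.randint(0, 10, (4,))

# option 1: CrossEntropyLoss on raw logits (applies log-softmax internally)
loss_a = nn.CrossEntropyLoss()(logits, labels)

# option 2: LogSoftmax followed by NLLLoss
loss_b = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), labels)

print(torch.allclose(loss_a, loss_b))  # True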

about use own database

Hello, thanks for your project. When I use my own database, something goes wrong. The problem is:

Traceback (most recent call last):
  File "main.py", line 213, in <module>
    main()
  File "main.py", line 96, in main
    train(model, criterion_xent, criterion_cent, optimizer_model, optimizer_centloss, train_loader, use_gpu, 2, epoch)
  File "main.py", line 120, in train
    for batch_idx, (data, labels) in enumerate(trainloader):
ValueError: too many values to unpack (expected 2)

I don't know how to fix it. Thank you very much.
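A likely cause (an assumption based on the traceback, not confirmed by the reporter): the custom Dataset returns more than two values per item, while the training loop unpacks exactly two. A sketch of a __getitem__ that matches the loop; MyDataset is hypothetical:

from torch.utils.data import Dataset

class MyDataset(Dataset):
    """Hypothetical dataset that returns exactly (data, label) pairs."""
    def __init__(self, samples, labels):
        self.samples, self.labels = samples, labels

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # exactly two return values, so the training loop's
        # `for batch_idx, (data, labels) in enumerate(trainloader)` can unpack
        return self.samples[idx], self.labels[idx]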

Doesn't anyone think the author's center loss is too complicated?

A concise and easy-to-understand version:

import torch
import torch.nn as nn

class CenterLoss(nn.Module):
    def __init__(self, num_class=10, num_feature=2):
        super(CenterLoss, self).__init__()
        self.num_class = num_class
        self.num_feature = num_feature
        # one learnable center per class
        self.centers = nn.Parameter(torch.randn(self.num_class, self.num_feature))

    def forward(self, x, labels):
        center = self.centers[labels]            # (batch_size, num_feature)
        dist = (x - center).pow(2).sum(dim=-1)   # squared distance to own center
        loss = torch.clamp(dist, min=1e-12, max=1e+12).mean(dim=-1)
        return loss
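A quick usage check for the concise version above (toy tensors, just to show the call signature):

features = torch.randn(8, 2)          # (batch_size, num_feature)
labels = torch.randint(0, 10, (8,))   # class indices in [0, num_class)
criterion = CenterLoss(num_class=10, num_feature=2)
print(criterion(features, labels))    # scalar loss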

accuracy not increasing

I use your center loss with my own VGG-19. Although my loss decreases slowly, the accuracy on the test set never changes. I'm sure I update both optimizers' params. What's wrong with my code?

High dimension feature embedding visualization

Hi, Kaiyang. The feature embedding is 2-D in this repository, so it is straightforward to visualize. But if the feature lives in a high-dimensional space (e.g. 2048, the ResNet-50 GAP feature), how do I visualize it in 2-D feature space? Should I use approaches such as PCA or t-SNE?
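One way to do what the issue suggests, a sketch using scikit-learn's t-SNE (scikit-learn, matplotlib, and the random stand-in arrays are assumptions; PCA would work the same way):

import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

# stand-ins for real (num_samples, 2048) GAP features and their labels
features = np.random.randn(500, 2048)
labels = np.random.randint(0, 10, size=500)

# project the 2048-D embeddings down to 2-D for plotting
embedded = TSNE(n_components=2).fit_transform(features)

plt.scatter(embedded[:, 0], embedded[:, 1], c=labels, cmap='tab10', s=5)
plt.colorbar()
plt.savefig('tsne_features.png')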

Pre-trained Model

Hi, would it be possible for you to share your pre-trained model? Thank you.

Don't know how to apply center loss to my own project

I want to use ResNet-50 as my model, with a changed fully connected layer, to classify images.

# Create model:
import torch.nn as nn
from torchvision import models

model = models.resnet50(pretrained=True)
fc_inputs = model.fc.in_features

class FClayer(nn.Module):
    def __init__(self):
        super(FClayer, self).__init__()
        self.fc1 = nn.Linear(fc_inputs, 2)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(2, num_classes)

    def forward(self, x):
        # return both the 2-D feature (for center loss) and the logits
        x = self.relu1(self.fc1(x))
        y = self.fc2(x)
        return x, y

model.fc = FClayer()
...
# In the train function:
features, outputs = model(inputs)
alpha = 1
loss = loss_criterion(outputs, labels) + (loss_criterion_cent(features, labels) * alpha)
optimizer.zero_grad()
optimizer_cent.zero_grad()
loss.backward()
optimizer.step()
for param in loss_criterion_cent.parameters():
    param.grad.data *= (1. / alpha)
optimizer_cent.step()

I don't know why, during training, the accuracy increased in the very first epochs, but soon after nothing happened. Here is my accuracy curve:
[figure: Sign_accuracy_curve]

Did I do something wrong?

The center loss is very large

At the beginning of training, the center loss is very large, for example 520, even though I use a pre-trained model.
PS: I use ArcFace for pre-training before adding the center loss.

about the feat_dim

Thanks to the author for the implementation of center loss. But I have a question: can the feat_dim in center_loss have a negative effect on the result? My experiment was on ResNet-18 with a feat_dim of 2048; after adding the center loss, the accuracy dropped sharply from 80% to 60%.
