Comments (5)

fafafafafafafa commented on June 2, 2024

I found what was wrong. Changing
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)).cuda()
to
self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
fixes it, but I don't understand the difference between the two.

from pytorch-center-loss.
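The difference comes down to what nn.Module registers: only values that are nn.Parameter instances end up in module.parameters(). nn.Parameter(t).cuda() first creates a leaf Parameter and then copies it to the GPU; that copy is an ordinary, non-leaf Tensor with a grad_fn, so self.centers becomes a plain attribute that no optimizer ever sees. nn.Parameter(t.cuda()) wraps the GPU tensor itself, so it is registered. Below is a small sketch, assuming PyTorch is installed; the class names are hypothetical, and on a CPU-only machine it substitutes .double() for .cuda(), which likewise returns a new tensor.

```python
import torch
import torch.nn as nn

# Use the real .cuda() when a GPU is present; otherwise .double(), which also
# returns a new tensor and so reproduces the same problem on a CPU-only machine.
convert = (lambda t: t.cuda()) if torch.cuda.is_available() else (lambda t: t.double())


class WrapThenConvert(nn.Module):
    def __init__(self, num_classes=10, feat_dim=2):
        super().__init__()
        # nn.Parameter(...) is a leaf; converting it afterwards returns a plain,
        # non-leaf Tensor, and nn.Module only registers nn.Parameter values.
        self.centers = convert(nn.Parameter(torch.randn(num_classes, feat_dim)))


class ConvertThenWrap(nn.Module):
    def __init__(self, num_classes=10, feat_dim=2):
        super().__init__()
        # Converting first and wrapping the result keeps self.centers an nn.Parameter.
        self.centers = nn.Parameter(convert(torch.randn(num_classes, feat_dim)))


for module in (WrapThenConvert(), ConvertThenWrap()):
    c = module.centers
    print(type(module).__name__, type(c).__name__, "leaf:", c.is_leaf,
          "registered:", [name for name, _ in module.named_parameters()])
```

The first module reports a non-leaf Tensor and an empty parameter list; the second reports a leaf Parameter registered under the name centers.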

mohamedr002 commented on June 2, 2024

Can you provide a code snippet showing how criterion_cent is defined? You also need to make sure you are creating the centers like this:

self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())

fafafafafafafa commented on June 2, 2024

SoftmaxLoss = torch.nn.CrossEntropyLoss()
centerLoss = center_losses.CenterLoss(classes=train_class, feature_dims=1024)
optimizer_centerloss = torch.optim.SGD(list(centerLoss.parameters()), lr=0.5)

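import torch
import torch.nn as nn

class CenterLoss(nn.Module):
    def __init__(self, classes, feature_dims, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.classes = classes
        self.feature_dims = feature_dims
        self.use_gpu = use_gpu
        if use_gpu:
            # NOTE: .cuda() is applied to the nn.Parameter, which returns a plain
            # (non-leaf) Tensor, so self.centers is not registered as a parameter
            # and centerLoss.parameters() is empty.
            centers = nn.Parameter(torch.randn(self.classes, self.feature_dims)).cuda()
        else:
            centers = nn.Parameter(torch.randn(self.classes, self.feature_dims))
        self.centers = centers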

    def forward(self, x, labels):

        # labels: [N_way, K_shot]
        batch_size = x.shape[0]    # x_shape: torch.Size([N_way*K_shot, 1024])
        # dist_mat: torch.Size([batch_size, classes])
        # print('x: ', x)
        print('centers: ', self.centers)
        dist_mat = torch.sum(torch.square(x), 1, keepdim=True).expand(batch_size, self.classes) + \
             torch.sum(torch.square(self.centers), 1, keepdim=True).expand(self.classes, batch_size).t()
        dist_mat = dist_mat - 2*torch.matmul(x, self.centers.t())

        # dist_mat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.classes) + \
        #    torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.classes, batch_size).t()
        # dist_mat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        labels = torch.reshape(labels, (-1, 1)).expand(batch_size, self.classes)
        classes_mat = torch.arange(self.classes).expand(batch_size, self.classes).long()
        if self.use_gpu:
            # print('use_gpu:', self.use_gpu)
            classes_mat = classes_mat.cuda()
        mask = labels.eq(classes_mat).float()
        dist_mat = dist_mat*mask
        center_loss = torch.sum(dist_mat.clamp(min=1e-12, max=1e+12))/(batch_size*self.feature_dims)
        # get support set centers
        mask1 = torch.sum(mask, 0).bool()
        support_centers = self.centers[mask1, :]
        # print('support_centers:', support_centers.shape)
        return center_loss, support_centers
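For comparison, here is a minimal sketch of the same loss with the centers registered correctly and no hard-coded .cuda() calls. It assumes the centers can be created on the CPU and moved together with the module via criterion.to(device), which moves registered parameters without unregistering them; the use_gpu flag is dropped, while the distance formula, masking, clamping, normalization by batch_size*feature_dims, and support-center selection follow the code above. The batch sizes in the usage lines are illustrative only.

```python
import torch
import torch.nn as nn


class CenterLoss(nn.Module):
    """Sketch of the loss above with the centers registered as a parameter.

    The centers are created on the CPU; calling .to(device) on the module moves
    them and keeps them registered, so no .cuda() call is needed inside.
    """

    def __init__(self, classes: int, feature_dims: int):
        super().__init__()
        self.classes = classes
        self.feature_dims = feature_dims
        self.centers = nn.Parameter(torch.randn(classes, feature_dims))

    def forward(self, x: torch.Tensor, labels: torch.Tensor):
        batch_size = x.shape[0]
        # Squared Euclidean distances ||x||^2 + ||c||^2 - 2 x.c, shape [batch_size, classes]
        dist_mat = (
            x.pow(2).sum(dim=1, keepdim=True)
            + self.centers.pow(2).sum(dim=1).unsqueeze(0)
            - 2 * x @ self.centers.t()
        )
        # One-hot class mask built on the labels' device instead of via .cuda()
        classes_mat = torch.arange(self.classes, device=labels.device)
        mask = labels.reshape(-1, 1).eq(classes_mat).float()
        # Same clamping and normalization as the code above
        center_loss = (dist_mat * mask).clamp(min=1e-12, max=1e12).sum() / (
            batch_size * self.feature_dims
        )
        # Centers of the classes that occur in this batch
        support_centers = self.centers[mask.sum(dim=0).bool()]
        return center_loss, support_centers


torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
criterion = CenterLoss(classes=5, feature_dims=1024).to(device)
optimizer = torch.optim.SGD(criterion.parameters(), lr=0.5)  # centers are registered

x = torch.randn(10, 1024, device=device)
labels = torch.arange(5, device=device).repeat_interleave(2)  # two samples per class
loss, support = criterion(x, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item(), support.shape)
```

Moving the module before building the optimizer matches PyTorch's guidance to move a model to the GPU before constructing optimizers for it.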

mohamedr002 commented on June 2, 2024

It seems there are no issues with your code, but you can try removing list(...) from the line below, changing

optimizer_centerloss = torch.optim.SGD(list(centerLoss.parameters()), lr=0.5)

to

optimizer_centerloss = torch.optim.SGD(centerLoss.parameters(), lr=0.5)
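Whether the generator is wrapped in list() should make no difference: torch.optim.SGD turns whatever iterable it receives into a list of parameters either way. What matters is whether the module has any registered parameters at all. The sketch below assumes PyTorch is installed; build_sgd is a hypothetical helper, and .double() stands in for .cuda() so that it runs without a GPU.

```python
import torch
import torch.nn as nn


def build_sgd(params):
    """Return "ok", or the ValueError message raised by torch.optim.SGD."""
    try:
        torch.optim.SGD(params, lr=0.5)
        return "ok"
    except ValueError as err:
        return str(err)


# Hypothetical stand-ins for the CenterLoss module. In `broken`, the conversion
# (.double() here, .cuda() in the thread) happens after nn.Parameter, so nothing
# is registered; in `fixed`, the converted tensor is wrapped and registered.
broken = nn.Module()
broken.centers = nn.Parameter(torch.randn(5, 8)).double()

fixed = nn.Module()
fixed.centers = nn.Parameter(torch.randn(5, 8).double())

results = {
    "broken, list": build_sgd(list(broken.parameters())),
    "broken, generator": build_sgd(broken.parameters()),
    "fixed, list": build_sgd(list(fixed.parameters())),
    "fixed, generator": build_sgd(fixed.parameters()),
}
for case, outcome in results.items():
    print(f"{case:18} -> {outcome}")
```

Both broken cases fail with the same "optimizer got an empty parameter list" message, and both fixed cases succeed, so dropping list() alone would not be expected to change the outcome.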

fafafafafafafa commented on June 2, 2024

I have removed list(...), but I still get the same error as above.
