Thank you for your great tutorial! I would like to adapt it for my own work, where I need to use some transforms from MONAI. However, I found that the training loss stops changing after a few epochs. Do you have any suggestions?
Thanks in advance!
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import os, sys, glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from monai.data import CacheDataset, ThreadDataLoader
from monai.transforms import (
Compose,
EnsureType,
ToDevice,
RandSpatialCropSamples,
)
from torchvision.models import resnet18
from torchvision.datasets import STL10
from torchvision import transforms
class ContrastiveTransformations(object):
    def __init__(self, base_transforms, n_views=2):
        self.base_transforms = base_transforms
        self.n_views = n_views

    def __call__(self, x):
        return [self.base_transforms(x) for _ in range(self.n_views)]
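
# NOTE: ContrastiveTransformations is kept from the original tutorial but is not
# used in this MONAI version; the two crops produced by RandSpatialCropSamples
# act as the two views instead.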

class SimCLR(LightningModule):
    def __init__(self, hidden_dim, lr, temperature, weight_decay, batch_size, max_epochs=500):
        super().__init__()
        self.save_hyperparameters()
        assert self.hparams.temperature > 0.0, 'The temperature must be a positive float!'
        # Base model f(.)
        self.convnet = resnet18(pretrained=False, num_classes=4*hidden_dim)  # Output of last linear layer
        # The MLP for g(.) consists of Linear->ReLU->Linear
        self.convnet.fc = nn.Sequential(
            self.convnet.fc,  # Linear(ResNet output, 4*hidden_dim)
            nn.ReLU(inplace=True),
            nn.Linear(4*hidden_dim, hidden_dim)
        )

    def prepare_data(self):
        unlabeled_data = STL10(root='datasets', split='unlabeled', download=False,
                               transform=transforms.Compose([transforms.ToTensor()]))
        train_data_contrast = STL10(root='datasets', split='train', download=False,
                                    transform=transforms.Compose([transforms.ToTensor()]))
        train_files = list()
        test_files = list()
        for i, data in enumerate(unlabeled_data):
            if i >= 10000:
                break
            img, _ = data
            train_files.append(img)
        test_files = [img for img, _ in train_data_contrast]
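        # MONAI pipeline: ensure each cached image is a tensor, move it to the
        # GPU, then take two random 50x50 crops per image as the contrastive views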
        contrast_transforms = [
            EnsureType(),
            ToDevice(device='cuda:0'),
            RandSpatialCropSamples(roi_size=(50, 50), num_samples=2, random_size=False, random_center=True),
        ]
        self.train_ds = CacheDataset(
            data=train_files,
            transform=Compose(contrast_transforms),
            cache_rate=1.0,
            copy_cache=False,
            num_workers=4
        )
        self.test_ds = CacheDataset(
            data=test_files,
            transform=Compose(contrast_transforms),
            cache_rate=1.0,
            copy_cache=False,
            num_workers=4
        )
    def train_dataloader(self):
        return ThreadDataLoader(self.train_ds,
                                num_workers=0,
                                batch_size=self.hparams.batch_size,
                                shuffle=True)

    def val_dataloader(self):
        return ThreadDataLoader(self.test_ds,
                                num_workers=0,
                                batch_size=self.hparams.batch_size,
                                shuffle=False)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                            T_max=self.hparams.max_epochs,
                                                            eta_min=self.hparams.lr/50)
        return [optimizer], [lr_scheduler]
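
    # The InfoNCE loss below assumes the positive of example i sits at
    # i + batch_size//2, i.e. that the first half of the batch holds view 1
    # and the second half holds view 2 of the same images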
    def info_nce_loss(self, batch, mode='train'):
        # imgs = torch.cat(batch['image'], dim=0)
        imgs = batch
        # Encode all images
        feats = self.convnet(imgs)
        # Calculate cosine similarity
        cos_sim = F.cosine_similarity(feats[:, None, :], feats[None, :, :], dim=-1)
        # Mask out cosine similarity to itself
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
        cos_sim.masked_fill_(self_mask, -9e15)
        # Find positive example -> batch_size//2 away from the original example
        pos_mask = self_mask.roll(shifts=cos_sim.shape[0]//2, dims=0)
        # InfoNCE loss
        cos_sim = cos_sim / self.hparams.temperature
        nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
        nll = nll.mean()
        # Logging loss
        self.log(mode+'_loss', nll)
        # Get ranking position of positive example
        comb_sim = torch.cat([cos_sim[pos_mask][:, None],  # First position positive example
                              cos_sim.masked_fill(pos_mask, -9e15)],
                             dim=-1)
        sim_argsort = comb_sim.argsort(dim=-1, descending=True).argmin(dim=-1)
        # Logging ranking metrics
        self.log(mode+'_acc_top1', (sim_argsort == 0).float().mean())
        self.log(mode+'_acc_top5', (sim_argsort < 5).float().mean())
        self.log(mode+'_acc_mean_pos', 1+sim_argsort.float().mean())
        return nll

    def training_step(self, batch, batch_idx):
        return self.info_nce_loss(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        self.info_nce_loss(batch, mode='val')

if __name__ == '__main__':
    seed_everything(42)
    tb_logger = TensorBoardLogger(save_dir='logs', name='SimCLR')
    checkpoint_dir = os.path.join(tb_logger.save_dir, tb_logger.name, 'version_%d' % tb_logger.version, 'checkpoints')
    max_epochs = 500
    trainer = Trainer(gpus=[0],
                      max_epochs=max_epochs,
                      logger=tb_logger,
                      enable_progress_bar=True,
                      enable_checkpointing=True,
                      num_sanity_val_steps=1,
                      callbacks=[ModelCheckpoint(save_weights_only=True,
                                                 save_top_k=5,
                                                 mode='max',
                                                 monitor='val_acc_top5',
                                                 dirpath=checkpoint_dir,
                                                 filename='{epoch:04d}-{val_acc_top5:.2f}'),
                                 LearningRateMonitor('epoch')])
    net = SimCLR(
        batch_size=128,
        hidden_dim=128,
        lr=5e-4,
        temperature=0.07,
        weight_decay=1e-4,
        max_epochs=max_epochs)
    trainer.fit(net)
The same thing happens even without MONAI: just splitting the STL10 transforms into two parts also results in a loss that does not change.
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import os, sys, glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from torchvision.models import resnet18
from torchvision.datasets import STL10
from torchvision import transforms
from torch.utils.data import DataLoader

class ContrastiveTransformations(object):
    def __init__(self, base_transforms, n_views=2):
        self.base_transforms = base_transforms
        self.n_views = n_views

    def __call__(self, x):
        return [self.base_transforms(x) for _ in range(self.n_views)]

class SimCLR(LightningModule):
    def __init__(self, hidden_dim, lr, temperature, weight_decay, batch_size, max_epochs=500):
        super().__init__()
        self.save_hyperparameters()
        assert self.hparams.temperature > 0.0, 'The temperature must be a positive float!'
        # Base model f(.)
        self.convnet = resnet18(pretrained=False, num_classes=4*hidden_dim)  # Output of last linear layer
        # The MLP for g(.) consists of Linear->ReLU->Linear
        self.convnet.fc = nn.Sequential(
            self.convnet.fc,  # Linear(ResNet output, 4*hidden_dim)
            nn.ReLU(inplace=True),
            nn.Linear(4*hidden_dim, hidden_dim)
        )

    def prepare_data(self):
        self.unlabeled_data = STL10(root='datasets', split='unlabeled', download=False,
                                    transform=transforms.Compose([transforms.ToTensor()]))
        self.train_data_contrast = STL10(root='datasets', split='train', download=False,
                                         transform=transforms.Compose([transforms.ToTensor()]))
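        # Only Normalize is applied here; it is run twice per image inside
        # info_nce_loss to build the two views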
        self.contrast_transforms = ContrastiveTransformations(base_transforms=transforms.Compose([
            transforms.Normalize((0.5,), (0.5,))
        ]))

    def train_dataloader(self):
        return DataLoader(self.unlabeled_data, batch_size=self.hparams.batch_size, shuffle=True,
                          drop_last=True, pin_memory=True, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.train_data_contrast, batch_size=self.hparams.batch_size, shuffle=False,
                          drop_last=False, pin_memory=True, num_workers=4)

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(),
                                lr=self.hparams.lr,
                                weight_decay=self.hparams.weight_decay)
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                            T_max=self.hparams.max_epochs,
                                                            eta_min=self.hparams.lr/50)
        return [optimizer], [lr_scheduler]

    def info_nce_loss(self, batch, mode='train'):
        imgs, _ = batch
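        # Apply the transform n_views=2 times per image and concatenate, so the
        # two views of each image end up next to each other in the batch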
        _imgs = list()
        for i in imgs:
            img = self.contrast_transforms(i)
            _imgs.append(img[0].unsqueeze(0))
            _imgs.append(img[1].unsqueeze(0))
        imgs = torch.cat(_imgs, dim=0)
        # Encode all images
        feats = self.convnet(imgs)
        # Calculate cosine similarity
        cos_sim = F.cosine_similarity(feats[:, None, :], feats[None, :, :], dim=-1)
        # Mask out cosine similarity to itself
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
        cos_sim.masked_fill_(self_mask, -9e15)
        # Find positive example -> batch_size//2 away from the original example
        pos_mask = self_mask.roll(shifts=cos_sim.shape[0]//2, dims=0)
        # InfoNCE loss
        cos_sim = cos_sim / self.hparams.temperature
        nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
        nll = nll.mean()
        # Logging loss
        self.log(mode+'_loss', nll)
        # Get ranking position of positive example
        comb_sim = torch.cat([cos_sim[pos_mask][:, None],  # First position positive example
                              cos_sim.masked_fill(pos_mask, -9e15)],
                             dim=-1)
        sim_argsort = comb_sim.argsort(dim=-1, descending=True).argmin(dim=-1)
        # Logging ranking metrics
        self.log(mode+'_acc_top1', (sim_argsort == 0).float().mean())
        self.log(mode+'_acc_top5', (sim_argsort < 5).float().mean())
        self.log(mode+'_acc_mean_pos', 1+sim_argsort.float().mean())
        return nll

    def training_step(self, batch, batch_idx):
        return self.info_nce_loss(batch, mode='train')

    def validation_step(self, batch, batch_idx):
        self.info_nce_loss(batch, mode='val')

if __name__ == '__main__':
    seed_everything(42)
    tb_logger = TensorBoardLogger(save_dir='logs', name='SimCLR')
    checkpoint_dir = os.path.join(tb_logger.save_dir, tb_logger.name, 'version_%d' % tb_logger.version, 'checkpoints')
    max_epochs = 500
    trainer = Trainer(gpus=[0],
                      max_epochs=max_epochs,
                      logger=tb_logger,
                      enable_progress_bar=True,
                      enable_checkpointing=True,
                      num_sanity_val_steps=1,
                      callbacks=[ModelCheckpoint(save_weights_only=True,
                                                 save_top_k=5,
                                                 mode='max',
                                                 monitor='val_acc_top5',
                                                 dirpath=checkpoint_dir,
                                                 filename='{epoch:04d}-{val_acc_top5:.2f}'),
                                 LearningRateMonitor('epoch')])
    net = SimCLR(
        batch_size=128,
        hidden_dim=128,
        lr=5e-4,
        temperature=0.07,
        weight_decay=1e-4,
        max_epochs=max_epochs)
    trainer.fit(net)