
ruc's Introduction

Improving Unsupervised Image Clustering With Robust Learning

This repository contains the PyTorch code for "Improving Unsupervised Image Clustering With Robust Learning (RUC)".


Sungwon Park, Sungwon Han, Sundong Kim, Danu Kim, Sungkyu Park, Seunghoon Hong, Meeyoung Cha.

Highlight

  1. RUC is an add-on module that enhances the performance of any off-the-shelf unsupervised learning algorithm. Inspired by robust learning, RUC first divides the clustered data points into a clean set and a noisy set, and then refines the clustering results. With RUC, the state-of-the-art unsupervised clustering methods SCAN and TSUC show large performance improvements (STL-10: 86.7%, CIFAR-10: 90.3%, CIFAR-20: 54.3%, CIFAR-100: 36.5%, ImageNet-50: 78.5%).

  2. The predictions of existing unsupervised learning algorithms are overconfident. RUC softens these predictions and improves their calibration.

  3. Robustness to adversarially crafted samples. ERM-based unsupervised clustering algorithms can be vulnerable to adversarial attacks; adding RUC to a clustering model improves its robustness against adversarial noise.
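The calibration claim above can be quantified with the expected calibration error (ECE), a standard calibration metric. The README does not say how calibration was measured, so the following is a generic NumPy sketch, not the repository's evaluation code. It assumes softmax outputs and reference labels that are already aligned to the cluster indices; the function name and the default of 15 bins are illustrative choices.

```python
import numpy as np


def expected_calibration_error(probs, labels, n_bins=15):
    """Expected calibration error with equal-width confidence bins.

    probs:  (N, K) softmax probabilities.
    labels: (N,) reference class indices, assumed to be already matched
            to the cluster indices (e.g. via Hungarian matching).
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    confidences = probs.max(axis=1)        # top-1 confidence per sample
    predictions = probs.argmax(axis=1)     # top-1 predicted cluster
    correct = (predictions == labels).astype(float)

    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        # Samples whose confidence falls in the half-open bin (lo, hi]
        in_bin = (confidences > lo) & (confidences <= hi)
        if not in_bin.any():
            continue
        gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
        ece += in_bin.mean() * gap         # weight by the fraction of samples in the bin
    return ece
```

A lower ECE means confidence tracks accuracy more closely. For example, a model that predicts `[[0.99, 0.01], [0.99, 0.01]]` for samples labeled `[0, 1]` is 99% confident but only 50% accurate, and scores an ECE of 0.49.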

Required packages

  • python == 3.6.10
  • pytorch == 1.1.0
  • scikit-learn == 0.21.2
  • scipy == 1.3.0
  • numpy == 1.18.5
  • pillow == 7.1.2

Overall model architecture

Usage

usage: main_ruc_[dataset].py [-h] [--lr LR] [--momentum M] [--weight_decay W]
                         [--epochs EPOCHS] [--batch_size B] [--s_thr S_THR]
                         [--n_num N_NUM] [--o_model O_MODEL]
                         [--e_model E_MODEL] [--seed SEED]

config for RUC

optional arguments:
  -h, --help            show this help message and exit
  --lr LR               initial learning rate
  --momentum M          momentum
  --weight_decay W      weight decay
  --epochs EPOCHS       maximum number of epochs per round (default: 200)
  --batch_size B        training batch size
  --s_thr S_THR         confidence sampling threshold
  --n_num N_NUM         number of neighbors for metric sampling
  --o_model O_MODEL     original model path
  --e_model E_MODEL     embedding model path
  --seed SEED           random seed
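To illustrate how these flags fit together, here is a hypothetical re-creation of the parser with Python's standard `argparse` module. It is not the parser in `main_ruc_[dataset].py`: only `--epochs` has a documented default (200), so every other flag defaults to `None` instead of guessing the script's real values. The example paths are the checkpoint file names mentioned in the issues section, assuming the SCAN self-label model serves as `o_model` and the SimCLR model as `e_model`.

```python
import argparse


def build_parser():
    """Hypothetical re-creation of the RUC command-line interface.

    Flag names, metavars, and help strings follow the usage text; only
    --epochs has a documented default (200), so the other flags default
    to None here rather than guessing the script's actual values.
    """
    parser = argparse.ArgumentParser(description="config for RUC")
    parser.add_argument("--lr", type=float, help="initial learning rate")
    parser.add_argument("--momentum", type=float, metavar="M", help="momentum")
    parser.add_argument("--weight_decay", type=float, metavar="W", help="weight decay")
    parser.add_argument("--epochs", type=int, default=200,
                        help="maximum number of epochs per round")
    parser.add_argument("--batch_size", type=int, metavar="B", help="training batch size")
    parser.add_argument("--s_thr", type=float, help="confidence sampling threshold")
    parser.add_argument("--n_num", type=int, help="number of neighbors for metric sampling")
    parser.add_argument("--o_model", type=str, help="original model path")
    parser.add_argument("--e_model", type=str, help="embedding model path")
    parser.add_argument("--seed", type=int, help="random seed")
    return parser


# Example: pass the two checkpoint paths (assumed mapping: SCAN self-label -> o_model,
# SimCLR -> e_model).
args = build_parser().parse_args(
    ["--o_model", "selflabel_cifar-10.pth.tar", "--e_model", "simclr_cifar-10.pth.tar"]
)
print(args.epochs, args.o_model)
```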

Model ZOO

We currently provide pretrained RUC models. We used the pretrained SCAN and SimCLR models from the SCAN GitHub repository (o_model: SCAN, e_model: SimCLR).

SCAN

Dataset    o_model (SCAN)   e_model (SimCLR)
CIFAR-10   Download         Download
CIFAR-20   Download         Download
STL-10     Download         Download

Ours

Dataset    Download link
CIFAR-10   Download
CIFAR-20   Download
STL-10     Download

Citation

If you find this repo useful for your research, please consider citing our paper:

@inproceedings{park2021improving,
  title={Improving Unsupervised Image Clustering With Robust Learning},
  author={Park, Sungwon and Han, Sungwon and Kim, Sundong and Kim, Danu and Park, Sungkyu and Hong, Seunghoon and Cha, Meeyoung},
  booktitle={CVPR},
  year={2021}
}

ruc's Issues

About cluster visualization

Could you tell me how you visualized the clusters as in calibration.png?
I would really appreciate it if you could let me know the method. Thank you in advance.

Tabular data / noisy instances

Hi,
Thanks for sharing your implementation. I have two questions about it:

  1. Does it also work on tabular data?
  2. Is it possible to identify the noisy instances (i.e., return the IDs of the noisy samples or the clean set)?

Thanks!
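A possible approach to question 2, not from the authors: the README states that RUC first divides clustered samples into clean and noisy sets, and the CLI exposes `--s_thr` (confidence sampling threshold) and `--n_num` (neighbors for metric sampling). The NumPy sketch below, which is not the repository's implementation, returns clean and noisy indices: a sample counts as clean when its top softmax probability reaches `s_thr` and its pseudo-label agrees with the majority pseudo-label of its `n_num` nearest neighbors in the embedding space. Combining the two checks with a logical AND and using cosine similarity are assumptions; `split_clean_noisy` is a hypothetical name.

```python
import numpy as np


def split_clean_noisy(probs, embeddings, s_thr=0.99, n_num=10):
    """Return (clean_idx, noisy_idx) for a clustering output (hypothetical helper).

    probs:      (N, K) softmax outputs of the original clustering model.
    embeddings: (N, D) features from the embedding model.
    s_thr:      confidence threshold (cf. --s_thr).
    n_num:      number of neighbors for the metric check (cf. --n_num), must be < N.
    """
    probs = np.asarray(probs, dtype=float)
    feats = np.asarray(embeddings, dtype=float)
    pseudo = probs.argmax(axis=1)

    # Confidence check: keep samples the clustering model is sure about.
    confident = probs.max(axis=1) >= s_thr

    # Metric check: the pseudo-label should match the majority pseudo-label
    # among the n_num nearest neighbors (cosine similarity).
    feats = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    sim = feats @ feats.T
    np.fill_diagonal(sim, -np.inf)        # a sample is not its own neighbor
    neighbors = np.argsort(-sim, axis=1)[:, :n_num]
    k = probs.shape[1]
    votes = np.array([np.bincount(pseudo[nb], minlength=k).argmax() for nb in neighbors])
    agrees = votes == pseudo

    # Assumed combination rule: a sample is clean only if both checks pass.
    clean = confident & agrees
    return np.flatnonzero(clean), np.flatnonzero(~clean)
```

The full N-by-N similarity matrix is only practical for small N; a whole dataset would need a batched or approximate nearest-neighbor search. The sketch only consumes probabilities and feature vectors, but the README does not say whether RUC as a whole works on tabular data.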

Original model vs. embedding model

Hello,
What is the difference between
--o_model O_MODEL original model path
--e_model E_MODEL embedding model path?

Where do I find them?
There is only one model.


Thanks

How do I get my clustering results?

Hi, thanks for your excellent work.
I trained RUC on CIFAR-10, but I only get the trained model.
How do I get the clustering results?
If I want to save the label of each sample to a txt file, what should I do?
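One way to export labels, not a script shipped with the repository: assuming you have already collected the model's softmax outputs for every sample in dataset order (for example with a data loader created with `shuffle=False`), take the argmax per row and write one `index<TAB>cluster` line per sample. `cluster_labels_to_text` and `all_probs` are hypothetical names.

```python
import numpy as np


def cluster_labels_to_text(probs):
    """Return one 'index<TAB>cluster' line per sample, using the argmax cluster.

    probs: (N, K) softmax outputs, row i belonging to dataset sample i.
    """
    labels = np.asarray(probs).argmax(axis=1)
    return "".join(f"{i}\t{c}\n" for i, c in enumerate(labels))


# Usage, assuming `all_probs` holds the softmax outputs for the whole dataset,
# collected in eval mode with a non-shuffled data loader:
# with open("cluster_labels.txt", "w") as fh:
#     fh.write(cluster_labels_to_text(all_probs))
```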

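Comparison with the DivideMix paper

I took a quick look at the code and the paper. It seems that, apart from the pseudo-labels being generated by the Hungarian matching algorithm, everything else is basically the same as DivideMix (whose code is also open source), and the training strategy is essentially identical. I wonder how the authors would explain this; perhaps I haven't studied it deeply enough, so I would appreciate your guidance...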
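For readers unfamiliar with the Hungarian matching mentioned above: in clustering, it is commonly used to map each cluster index to a reference class so that agreement is maximized, which is how clustering accuracy is usually computed. The sketch below uses `scipy.optimize.linear_sum_assignment` (SciPy is among the required packages); it is a generic illustration, not code taken from RUC, and `hungarian_match` is a hypothetical name.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def hungarian_match(cluster_ids, labels, n_classes):
    """Map cluster indices to reference labels so that agreement is maximized.

    Returns (mapping, accuracy), where mapping[c] is the label assigned to cluster c.
    """
    cluster_ids = np.asarray(cluster_ids)
    labels = np.asarray(labels)
    # counts[c, l] = number of samples in cluster c whose reference label is l
    counts = np.zeros((n_classes, n_classes), dtype=np.int64)
    np.add.at(counts, (cluster_ids, labels), 1)
    # linear_sum_assignment minimizes total cost, so negate counts to maximize matches
    rows, cols = linear_sum_assignment(-counts)
    mapping = dict(zip(rows.tolist(), cols.tolist()))
    accuracy = counts[rows, cols].sum() / len(labels)
    return mapping, accuracy
```

For example, clusters `[0, 0, 1, 1, 2, 2]` against labels `[1, 1, 2, 2, 0, 0]` yield the mapping `{0: 1, 1: 2, 2: 0}` with accuracy 1.0, even though no cluster index equals its label.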

Retrain

There is no documentation on how to retrain the model.

ImageNet dataset

Thank you so much for your amazing work!

Could you please provide the script used for the ImageNet dataset?

I would like to train RUC on my own dataset, whose images are similar in size to ImageNet's. However, I cannot train the SCAN-pretrained model (MoCo) on an 8 GB GPU, even if I resize the images to 96x96 and set the batch size to 2. Which GPUs did you use in your ImageNet experiments, and how much memory do they have? If I understand your paper correctly, you resize the images to 256x256, right? And what batch size do you use?

Many thanks in advance!

Mismatched module names

Hi, I trained the model with the SCAN GitHub repository and generated two models under the scan and self-label directories, as instructed there. When I use those models to run main_ruc_cifar10.py, I get the following error:

Traceback (most recent call last):
  File "main_ruc_cifar10.py", line 376, in <module>
    main()
  File "main_ruc_cifar10.py", line 317, in main
    net_uc.load_state_dict(state_dict)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1605, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ClusteringModel:
Missing key(s) in state_dict: "backbone.conv1.weight", "backbone.bn1.weight", "backbone.bn1.bias", "backbone.bn1.running_mean", "backbone.bn1.running_var", "backbone.layer1.0.conv1.weight", "backbone.layer1.0.bn1.weight", "backbone.layer1.0.bn1.bias", "backbone.layer1.0.bn1.running_mean", "backbone.layer1.0.bn1.running_var", "backbone.layer1.0.conv2.weight", "backbone.layer1.0.bn2.weight", "backbone.layer1.0.bn2.bias", "backbone.layer1.0.bn2.running_mean", "backbone.layer1.0.bn2.running_var", "backbone.layer1.1.conv1.weight", "backbone.layer1.1.bn1.weight", "backbone.layer1.1.bn1.bias", "backbone.layer1.1.bn1.running_mean", "backbone.layer1.1.bn1.running_var", "backbone.layer1.1.conv2.weight", "backbone.layer1.1.bn2.weight", "backbone.layer1.1.bn2.bias", "backbone.layer1.1.bn2.running_mean", "backbone.layer1.1.bn2.running_var", "backbone.layer2.0.conv1.weight", "backbone.layer2.0.bn1.weight", "backbone.layer2.0.bn1.bias", "backbone.layer2.0.bn1.running_mean", "backbone.layer2.0.bn1.running_var", "backbone.layer2.0.conv2.weight", "backbone.layer2.0.bn2.weight", "backbone.layer2.0.bn2.bias", "backbone.layer2.0.bn2.running_mean", "backbone.layer2.0.bn2.running_var", "backbone.layer2.0.shortcut.0.weight", "backbone.layer2.0.shortcut.1.weight", "backbone.layer2.0.shortcut.1.bias", "backbone.layer2.0.shortcut.1.running_mean", "backbone.layer2.0.shortcut.1.running_var", "backbone.layer2.1.conv1.weight", "backbone.layer2.1.bn1.weight", "backbone.layer2.1.bn1.bias", "backbone.layer2.1.bn1.running_mean", "backbone.layer2.1.bn1.running_var", "backbone.layer2.1.conv2.weight", "backbone.layer2.1.bn2.weight", "backbone.layer2.1.bn2.bias", "backbone.layer2.1.bn2.running_mean", "backbone.layer2.1.bn2.running_var", "backbone.layer3.0.conv1.weight", "backbone.layer3.0.bn1.weight", "backbone.layer3.0.bn1.bias", "backbone.layer3.0.bn1.running_mean", "backbone.layer3.0.bn1.running_var", "backbone.layer3.0.conv2.weight", "backbone.layer3.0.bn2.weight", "backbone.layer3.0.bn2.bias", 
"backbone.layer3.0.bn2.running_mean", "backbone.layer3.0.bn2.running_var", "backbone.layer3.0.shortcut.0.weight", "backbone.layer3.0.shortcut.1.weight", "backbone.layer3.0.shortcut.1.bias", "backbone.layer3.0.shortcut.1.running_mean", "backbone.layer3.0.shortcut.1.running_var", "backbone.layer3.1.conv1.weight", "backbone.layer3.1.bn1.weight", "backbone.layer3.1.bn1.bias", "backbone.layer3.1.bn1.running_mean", "backbone.layer3.1.bn1.running_var", "backbone.layer3.1.conv2.weight", "backbone.layer3.1.bn2.weight", "backbone.layer3.1.bn2.bias", "backbone.layer3.1.bn2.running_mean", "backbone.layer3.1.bn2.running_var", "backbone.layer4.0.conv1.weight", "backbone.layer4.0.bn1.weight", "backbone.layer4.0.bn1.bias", "backbone.layer4.0.bn1.running_mean", "backbone.layer4.0.bn1.running_var", "backbone.layer4.0.conv2.weight", "backbone.layer4.0.bn2.weight", "backbone.layer4.0.bn2.bias", "backbone.layer4.0.bn2.running_mean", "backbone.layer4.0.bn2.running_var", "backbone.layer4.0.shortcut.0.weight", "backbone.layer4.0.shortcut.1.weight", "backbone.layer4.0.shortcut.1.bias", "backbone.layer4.0.shortcut.1.running_mean", "backbone.layer4.0.shortcut.1.running_var", "backbone.layer4.1.conv1.weight", "backbone.layer4.1.bn1.weight", "backbone.layer4.1.bn1.bias", "backbone.layer4.1.bn1.running_mean", "backbone.layer4.1.bn1.running_var", "backbone.layer4.1.conv2.weight", "backbone.layer4.1.bn2.weight", "backbone.layer4.1.bn2.bias", "backbone.layer4.1.bn2.running_mean", "backbone.layer4.1.bn2.running_var", "cluster_head.0.weight", "cluster_head.0.bias".
Unexpected key(s) in state_dict: "model", "head".

Please help me resolve this
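A possible fix, offered as an assumption rather than an answer from the authors: the unexpected keys "model" and "head" suggest that the file passed to RUC is a checkpoint dictionary wrapping the state_dict together with the index of the best clustering head, while RUC's ClusteringModel expects a bare state_dict with a single head named `cluster_head.0`. The hypothetical helper below unwraps such a checkpoint and keeps only the best head. It is modeled on how SCAN's self-labeling step is recalled to select the best head, and has not been verified against either repository.

```python
def select_best_head(checkpoint):
    """Turn a SCAN-style {'model': state_dict, 'head': best_head} checkpoint
    into a single-head state_dict with keys 'cluster_head.0.*'.

    Assumption: the checkpoint layout matches the 'model'/'head' keys in the
    error message; a plain state_dict is returned unchanged.
    """
    if not ("model" in checkpoint and "head" in checkpoint):
        return checkpoint                      # already a bare state_dict
    state = dict(checkpoint["model"])          # copy; leave the input untouched
    best = checkpoint["head"]
    weight = state[f"cluster_head.{best}.weight"]
    bias = state[f"cluster_head.{best}.bias"]
    # Drop every clustering head, then re-insert the best one under index 0.
    for key in [k for k in state if k.startswith("cluster_head.")]:
        del state[key]
    state["cluster_head.0.weight"] = weight
    state["cluster_head.0.bias"] = bias
    return state


# Usage inside the RUC script (net_uc as in the traceback):
# checkpoint = torch.load(args.o_model, map_location="cpu")
# net_uc.load_state_dict(select_best_head(checkpoint))
```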

Does RUC separate the train and test sets?

Hi,
I'm unsure whether the reported performance uses separate train and test sets, i.e., training only on the train split and evaluating on the test split (I believe this is how SCAN reports its results). In the code, both the training and test datasets seem to be loaded with train=True, which would mean the model uses only the train split for both training and reporting performance?!

Thank you for your time and for open-sourcing the code. ✌️

Final model selection

Hi, thank you again for sharing the code!

I wonder which model you used for the evaluation in the paper: Model 1, Model 2, or some kind of ensemble (for example, prediction3 in the function test_ruc in lib/protocols.py)?
And how did you select the best model? By the best accuracy of prediction3 on the training set or on a validation set?

Thank you so much!

About "checkpoints"

Excuse me, could you please provide the files selflabel_cifar-10.pth.tar and simclr_cifar-10.pth.tar?

Can I train the model from scratch?

Hi, I found that the code does not work if I don't use the pretrained model. Does this mean I can't train from scratch? If I want to train on another dataset such as ImageNet, or use another backbone such as ResNet-50, where can I find a pretrained model? Thank you.
