
dccm's Introduction

DCCM

This repository is a PyTorch implementation of Deep Comprehensive Correlation Mining for Image Clustering (accepted to ICCV 2019): https://arxiv.org/abs/1904.06925?context=cs.CV

by Jianlong Wu*, Keyu Long*, Fei Wang, Chen Qian, Cheng Li, Zhouchen Lin and Hongbin Zha.

Citation

If you find DCCM useful in your research, please consider citing:

@inproceedings{DCCM,
    author={Wu, Jianlong and Long, Keyu and Wang, Fei and Qian, Chen and Li, Cheng and Lin, Zhouchen and Zha, Hongbin},
    title={Deep Comprehensive Correlation Mining for Image Clustering},
    booktitle={International Conference on Computer Vision},   
    year={2019},   
}

Introduction

Figure 1. The pipeline of the proposed DCCM.

Usage

To train with CIFAR10/100 datasets, try:

$ python main.py --config cfgs/cifar10.yaml
$ python main.py --config cfgs/cifar100.yaml

To resume from a checkpoint, try:

$ python main.py --config cfgs/xx.yaml --resume xxx.ckpt

Parameters and datapaths can be modified in the config files.

Note that we use meta-files (examples can be found in the 'meta' folder) to load data.

Requirements

  • Python 3.6.5
  • PyTorch 0.4.1
  • Keras 2.0.2
  • the image datasets, downloaded and stored according to the meta-files

Please note that all reported results were obtained under this environment.

Comparisons with SOTAs

Table 1. Clustering performance of different methods on six challenging datasets.

Reference Github Repos

Our group at SenseTime Research is looking for algorithm researchers and engineers. Our research interests include object detection, tracking, classification, and segmentation; automatic network search; network compression and quantization on mobile terminals; 3D gaze tracking; computer-vision-related SDKs; and product platform development. Our group aims to pioneer the computer-vision-based IoT industry. We have many NOI and ACM gold medal winners, and thousands of GPU cards. Our team has won world championships in MegaFace (face recognition) and the VOT challenge (object tracking), and has published many research papers at top conferences such as CVPR, ICCV, ECCV, and NeurIPS. Please feel free to contact us via WeChat#: 18810636695 or Email: [email protected] if you are interested in our group.

dccm's People

Contributors

cory-m, jlwu1992


dccm's Issues

other architectures and config files

Can you please upload the other architectures and the config YAML files for the other datasets, such as STL10 and TinyImageNet?

Also, the supplementary material says average pooling with stride 2 is used for CIFAR10/100, but the code uses stride 4. Which one is correct?

A mistake in dim_Model.py?

Sorry to bother you, but I found a contradiction in your code.
I read your paper and noticed a difference between Eq. (14) and the actual code in dim_Model.py.

In your dim_Model.py, your code is:
# local loss
Ej = -F.softplus(-self.local_D(Y_cat_M)).mean()
Em = -F.softplus(self.local_D(Y_cat_M_fake)).mean()
local_loss = -(Em + Ej)

but for Eq. (14), it should be:
local_loss = -(Ej - Em)

Is something wrong with my analysis, or is it an error in your code?
Looking forward to your reply!
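For what it's worth, the two formulations appear to coincide once the sign baked into Em is accounted for: in the code, Em is defined with a leading minus, so -(Ej + Em) equals -(Ej - softplus(T(neg))), which matches Eq. (14) if the paper defines Em without that minus (as in the original Deep InfoMax JSD estimator). A minimal numeric sketch with made-up scalar discriminator scores (t_pos and t_neg are illustrative, not from the repo):

```python
import math

def softplus(x):
    # log(1 + e^x), fine numerically for the small values used here
    return math.log1p(math.exp(x))

# Made-up discriminator outputs for one positive and one negative pair
t_pos, t_neg = 1.3, -0.7

# As written in dim_Model.py: Em already carries a leading minus
Ej = -softplus(-t_pos)
Em_code = -softplus(t_neg)
loss_code = -(Ej + Em_code)

# As in Eq. (14), assuming Em there is defined without the leading minus
Em_paper = softplus(t_neg)
loss_paper = -(Ej - Em_paper)

assert abs(loss_code - loss_paper) < 1e-12  # the two losses agree
```

So the discrepancy is likely only a matter of where the minus sign is folded in, not a bug.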

ValueError: NumpyArrayIterator is set to use the data format convention "channels_last"

Excuse me, I have a question that I haven't been able to solve for a long time.

I want to use your code for a clustering task on the NUS-WIDE dataset.
I organized the images as you suggest in the 'data' folder and put nus_lable.txt in the 'meta' folder.
Then I added a resize operation to your 'mc_dataset.py':
1. from skimage.transform import resize
2. img = resize(img, (32, 32))  # before line 30

Besides, I only modified cifar10.yaml:
small_bs: 16
workers: 4
num_classes: 4

Other than the above, I didn't modify any other files (except some comments I wrote).

But when I run 'python main_2.py --config cfgs/cifar10.yaml'
(main_2.py is just a copy of main.py with my comments added), it shows:

Traceback (most recent call last):
File "main_2.py", line 309, in
main()
File "main_2.py", line 140, in main
train(dataloader, model, dim_loss, crit_label, crit_graph, crit_c, optimizer, epoch, datagen, tb_logger)
File "main_2.py", line 203, in train
for X_batch_i in datagen.flow(input_bs,batch_size=args.small_bs,shuffle=False):
File "/home/sky/anaconda2/lib/python2.7/site-packages/keras/preprocessing/image.py", line 460, in flow
save_format=save_format)
File "/home/sky/anaconda2/lib/python2.7/site-packages/keras/preprocessing/image.py", line 782, in init
' (' + str(self.x.shape[channels_axis]) + ' channels).')
ValueError: NumpyArrayIterator is set to use the data format convention "channels_last" (channels on axis 3), i.e. expected either 1, 3 or 4 channels on axis 3. However, it was passed an array with shape (16, 3, 32, 32) (32 channels).

I don't know how to solve this problem.
Is something wrong with my modification?
Or does something in your main.py need to be revised?
I would appreciate it if you could help me solve this problem!
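For anyone hitting the same error: Keras's NumpyArrayIterator expects channels-last arrays by default, while PyTorch-style pipelines yield (N, C, H, W). One possible fix is to move the channel axis before handing the batch to Keras; with NumPy this is `np.transpose(x, (0, 2, 3, 1))`, and depending on the Keras version, passing `data_format='channels_first'` to ImageDataGenerator may also work. A dependency-free sketch of the axis move:

```python
def nchw_to_nhwc(batch):
    """Move channels from axis 1 to the last axis: (N, C, H, W) -> (N, H, W, C)."""
    return [
        [[[img[c][h][w] for c in range(len(img))]   # channels become the last axis
          for w in range(len(img[0][0]))]           # width
         for h in range(len(img[0]))]               # height
        for img in batch                            # batch
    ]

# A tiny (1, 3, 2, 2) batch becomes (1, 2, 2, 3)
batch = [[[[1, 2], [3, 4]],
          [[5, 6], [7, 8]],
          [[9, 10], [11, 12]]]]
out = nchw_to_nhwc(batch)
```

In the traceback above, shape (16, 3, 32, 32) would become (16, 32, 32, 3), which Keras accepts.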

Unhandled last batch truncation

It seems that line 145 in main.py causes an index-out-of-range error in subsequent lines when training on the last batch of a dataset whose size is not divisible by large_bs.

Correcting the line as follows should fix the issue:
# index_loc = np.arange(args.large_bs)
# Use the actual batch size instead of the parameter to avoid errors on truncated batches
index_loc = np.arange(input_tensor.size()[0])
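To illustrate the failure mode, here is a minimal sketch with made-up sizes, independent of the repo's code:

```python
large_bs = 4
last_batch = [10, 20, 30]  # truncated: dataset size not divisible by large_bs

# Indexing with the configured batch size overruns the truncated batch
try:
    _ = [last_batch[i] for i in range(large_bs)]
    overran = False
except IndexError:
    overran = True

# Indexing with the actual batch size is always safe
safe = [last_batch[i] for i in range(len(last_batch))]
```

The proposed fix applies the same idea: derive the index range from the tensor actually received, not from the configured batch size.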

Are targets (labels) needed?

Hello Sir,

For clustering, I think targets (labels) should not be needed.
But I saw targets (labels) in your code.

Are targets (labels) needed for training with your code?

Thanks,
Edward Cho.

mutual information loss

Dear author, your research results are great and enlightening. After reading the code, however, I could not find where the mutual information is reflected as a loss function. Could you point it out?

Inferior results on all datasets compared to those reported in the paper

I have been trying to reproduce the results from the paper using the same library versions you mention and the exp configs you provide, but the results I get on a GTX 1080 Ti are completely different. Also, you mention that training takes only 19 hours, but it has taken me around 3 days to train CIFAR100 for just 100 epochs, and it is still training.

Are these config files correct? Has anybody else been able to reproduce the results?

Below are the results I get:
STL10--->
[2020-04-12 17:56:34,340][main.py][line:272][INFO][rank:0] Epoch: [199/200] ARI against ground truth label: 0.182
[2020-04-12 17:56:34,353][main.py][line:273][INFO][rank:0] Epoch: [199/200] NMI against ground truth label: 0.296
[2020-04-12 17:56:34,372][main.py][line:274][INFO][rank:0] Epoch: [199/200] ACC against ground truth label: 0.368

CIFAR10 (training still not finished in 3 days)--->
[2020-04-13 03:57:40,966][main.py][line:272][INFO][rank:0] Epoch: [106/200] ARI against ground truth label: 0.305
[2020-04-13 03:57:40,968][main.py][line:273][INFO][rank:0] Epoch: [106/200] NMI against ground truth label: 0.407
[2020-04-13 03:57:40,968][main.py][line:274][INFO][rank:0] Epoch: [106/200] ACC against ground truth label: 0.463

CIFAR100 (training still not finished in 3 days)--->
[2020-04-13 03:48:53,305][main.py][line:272][INFO][rank:0] Epoch: [105/200] ARI against ground truth label: 0.169
[2020-04-13 03:48:53,307][main.py][line:273][INFO][rank:0] Epoch: [105/200] NMI against ground truth label: 0.282
[2020-04-13 03:48:53,308][main.py][line:274][INFO][rank:0] Epoch: [105/200] ACC against ground truth label: 0.308

How to use another model (e.g. VGG16) with your code?

I want to use a model with more layers (e.g. VGG) for problems in my field, but your code only provides the cifar_c4_L2 model. If I want to use my own model, what modifications should I make?

I noticed the layers, c_layer and classifier entries in cifar10.yaml. Are these related to computing the mutual information between the shallow and deep layers?
And how should these values be set for a custom model?
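Not an authoritative answer, but judging from the config keys, the training loop seems to need a backbone that exposes both an intermediate (shallow) feature map and a deep feature for the DIM-style mutual-information terms, plus a classifier head for cluster assignments. A framework-agnostic sketch of that contract (all names are assumptions, not the repo's actual API):

```python
class TwoLevelBackbone:
    """Sketch of the feature contract a custom model (e.g. VGG-style) might satisfy."""

    def __init__(self, shallow_layers, deep_layers, classifier):
        self.shallow_layers = shallow_layers  # early blocks -> local feature map
        self.deep_layers = deep_layers        # later blocks -> global feature
        self.classifier = classifier          # cluster-assignment head

    def forward(self, x):
        shallow = x
        for layer in self.shallow_layers:
            shallow = layer(shallow)
        deep = shallow
        for layer in self.deep_layers:
            deep = layer(deep)
        logits = self.classifier(deep)
        # a local MI term would compare `shallow` against `deep`;
        # the clustering losses would use `logits`
        return shallow, deep, logits

# Toy usage with stand-in callables instead of real conv blocks
net = TwoLevelBackbone([lambda v: v + 1], [lambda v: v * 10], lambda v: v - 3)
shallow, deep, logits = net.forward(0)
```

Under this reading, the yaml entries would pick which backbone layers play the shallow and deep roles, so a custom VGG would need those indices updated to match its own layer list; checking cifar_c4_L2's forward pass would confirm or correct this guess.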
