Giter VIP home page Giter VIP logo

picanet-implementation's Introduction

PiCANet-Implementation

Pytorch Implementation of PiCANet: Learning Pixel-wise Contextual Attention for Saliency Detection

New method on implementing PiCANet

  • Issue#9
  • Conv3d version is deleted.

input image target_image

  • batchsize 1 training_result
  • batchsize 4 training_result

PPT(Korean)

https://www.slideshare.net/JaehoonYoo5/picanet-pytorch-implementation-korean

Top 10 Performance Test with F-score (beta-square = 0.3)

batchsize:4 The results with wrong measurement.

  • Issue#9
    Will be updated soon with correct measurement (still may not correct)
    (Using pytorch/measure_test.py)
Step Value Threshold MAE
214000 0.8520 0.6980 0.0504
259000 0.8518 0.6510 0.0512
275000 0.8533 0.6627 0.0536
281000 0.8540 0.7451 0.0515
307000 0.8518 0.8078 0.0523
383000 0.8546 0.6627 0.0532
399000 0.8561 0.7882 0.0523
400000 0.8544 0.7804 0.0512
408000 0.8535 0.5922 0.0550
410000 0.8518 0.7882 0.0507

Execution Guideline

Requirements

Pillow==4.3.0
pytorch==0.4.1
tensorboardX==1.1
torchvision==0.2.1
numpy==1.14.2

My Environment

S/W
Windows 10
CUDA 9.0
cudnn 7.0
python 3.5
H/W
AMD Ryzen 1700
Nvidia gtx 1080ti
32GB RAM

Docker environment

docker run --name=torch -it -rm pytorch/pytorch
pip install tensorboardX
pip install datetime
git clone https://github.com/Ugness/PiCANet-Implementation

Docker image(Unstable)

https://hub.docker.com/r/wogns98/picanet/
based on pytorch/pytorch
codes in /workspace/PiCANet-Implementation
You can run code by add images and download models from google drive

You can run the file by following the descriptions in -h option.

python train.py -h
    usage: train.py [-h] [--load LOAD] [--dataset DATASET] [--cuda CUDA]
                [--batch_size BATCH_SIZE] [--epoch EPOCH] [-lr LEARNING_RATE]
                [--lr_decay LR_DECAY] [--decay_step DECAY_STEP]
optional arguments:
-h, --help            show this help message and exit
--load LOAD           Directory of pre-trained model, you can download at 
                    https://drive.google.com/file/d/109a0hLftRZ5at5hwpteRfO1
                    A6xLzf8Na/view?usp=sharing
                    None --> Do not use pre-trained model. 
                    Training will start from random initialized model
--dataset DATASET     Directory of your DUTS dataset "folder"
--cuda CUDA           'cuda' for cuda, 'cpu' for cpu, default = cuda
--batch_size BATCH_SIZE
                    batchsize, default = 1
--epoch EPOCH         # of epochs. default = 20
-lr LEARNING_RATE, --learning_rate LEARNING_RATE
                    learning_rate. default = 0.001
--lr_decay LR_DECAY   Learning rate decrease by lr_decay time per decay_step, default = 0.1
--decay_step DECAY_STEP
                    Learning rate decrease by lr_decay time per decay_step, default = 7000
python Image_Test.py -h
    usage: Image_Test.py [-h] [--model_dir MODEL_DIR] [-img --image_dir IMAGE_DIR]
                         [--cuda CUDA] [--batch_size BATCH_SIZE]
optional arguments:
  -h, --help            show this help message and exit
  --model_dir MODEL_DIR
                        Directory of pre-trained model, you can download at
                        https://drive.google.com/drive/folders/1s4M-_SnCPMj_2rsMkSy3pLnLQcgRakAe?usp=sharing
  -img IMAGE_DIR, --image_dir IMAGE_DIR
                        Directory of your test_image ""folder""
  --cuda CUDA           'cuda' for cuda, 'cpu' for cpu, default = cuda
  --batch_size BATCH_SIZE
                        batchsize, default = 4

Detailed Guideline

Pretrained Model

You can download pre-trained models from https://drive.google.com/drive/folders/1s4M-_SnCPMj_2rsMkSy3pLnLQcgRakAe?usp=sharing

Dataset

I used DUTS dataset as Training dataset and Test dataset.
You can download dataset from http://saliencydetection.net/duts/#outline-container-orgab269ec.

  • Caution: You should check the dataset's Image and GT are matched or not. (ex. # of images, name, ...)
  • You can match the file names and automatically remove un-matched datas by using DUTSdataset.arrange(self) method

Execution Example

Assume you train the model with

  • current dir: Pytorch/
  • Dataset dir: Pytorch/DUTS-TE
  • Pretrained model dir: Pytorch/models/state_dict/07261950/10epo_1000000step.ckpt
  • Goal Epoch : 100
python train.py --load models/state_dict/07261950/10epo_1000000step.ckpt --dataset DUTS-TE --epoch 100

Assume you test the model with

  • current dir: Pytorch/
  • Testset dir: Pytorch/test
  • Pretrained model dir: Pytorch/models/state_dict/07261950/10epo_1000000step.ckpt
  • CPU mode
python Image_test.py --model_dir models/state_dict/07261950/10epo_1000000step.ckpt --image_dir test --cuda cpu

Directory & Name Format of .ckpt files

"models/state_dict//<#epo_#step>.ckpt"
  • The step is accumulated step from epoch 0.
  • If you want to change the format of pre-trained model, you should change the code in train.py line 61-65
    start_iter = int(load.split('epo_')[1].strip('step.ckpt')) + 1
    start_epo = int(load.split('/')[3].split('epo')[0])
    now = datetime.datetime.strptime(load.split('/')[2], '%m%d%H%M')
    

Test with Custom_Images

  • When you run Image_Test.py with your own Images, the images will saved in tensorboard log file.

  • Log files will saved in log/Image_test

  • You can see the images by execute

    tensorboard --logdir=./log/Image_test
    

    and browse 127.0.0.1:6006 with your browser(ex. Chrome, IE, etc)

picanet-implementation's People

Contributors

ugness avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.