ELSR-torch

Implementation of the paper "ELSR: Extreme Low-Power Super Resolution Network For Mobile Devices" in PyTorch. The code replicates the method proposed by the paper, but it is meant to be trained on limited hardware: for that purpose the dataset is drastically smaller and the training procedure is much simpler.

Requirements

  • pytorch=1.13.1
  • opencv=4.7.0
  • pillow=9.4.0
  • matplotlib

If you use Anaconda on Windows, you can simply run:

conda create -n elsr --file requirements.txt 

Once the required packages are installed, download the dataset I used to run the training. Alternatively, you can download the entire REDS dataset from here.

Dataset

ELSR is trained on the REDS dataset, which is composed of sets of 300 videos, each set with a different degradation. My model is trained on a drastically reduced version of the dataset, containing only 30 videos at lower resolution (the original dataset was too big for me to train on). The dataset (h5 files) is available at the following link: https://drive.google.com/drive/folders/158bbeXr6EtCiuLI5wSh3SYRWMaWxK0Mq?usp=sharing.

Data augmentation

To prevent overfitting and achieve better training results, I applied some random data augmentation (see augment_data() in preprocessing.py). An example of augmentation by rotation is shown below:
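A minimal sketch of this kind of flip/rotation augmentation is shown below. It is an illustrative stand-in for augment_data() in preprocessing.py, not the repo's actual code:

import random
import numpy as np

def augment_pair(lr, hr):
    # Apply the same random flip/rotation to an LR/HR image pair (numpy arrays).
    # Illustrative stand-in for augment_data() in preprocessing.py.
    if random.random() < 0.5:
        lr, hr = np.fliplr(lr), np.fliplr(hr)      # horizontal flip
    if random.random() < 0.5:
        lr, hr = np.flipud(lr), np.flipud(hr)      # vertical flip
    k = random.randint(0, 3)                       # rotation by 0/90/180/270 degrees
    lr, hr = np.rot90(lr, k), np.rot90(hr, k)
    return np.ascontiguousarray(lr), np.ascontiguousarray(hr)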

Model

The ELSR model is a small sub-pixel convolutional neural network with 6 layers. Only 5 of them have learnable parameters. The architecture is shown in the image below:
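A plausible PyTorch sketch of such a network follows. The hidden width (6 channels) and the activation choice are assumptions made for illustration; the exact layout is in the architecture figure and in the paper:

import torch
import torch.nn as nn

class ELSR(nn.Module):
    # Sketch: 5 learnable 3x3 convolutions plus a parameter-free PixelShuffle
    # as the 6th layer. Hidden width and activations are assumptions.
    def __init__(self, scale=4, hidden=6):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, hidden, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, hidden, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, 3 * scale ** 2, 3, padding=1),
            nn.PixelShuffle(scale),   # 6th layer: no learnable parameters
        )

    def forward(self, x):
        return self.net(x)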

PixelShuffle

The PixelShuffle block (also known as depth2space) performs computationally efficient upsampling by rearranging pixels in an image to increase its spatial resolution. Formally, let x be a tensor of size (batch_size, C_in, H_in, W_in), where C_in is the number of input channels and H_in and W_in are the height and width of the input, respectively. PixelShuffle upsamples the spatial resolution of x by a factor r, so the output is a tensor of size (batch_size, C_out, H_in * r, W_in * r), where C_out = C_in // r^2.
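For example, with r = 4 and C_in = 48:

import torch
import torch.nn as nn

r = 4
x = torch.randn(1, 3 * r ** 2, 32, 32)    # C_in = 48
y = nn.PixelShuffle(r)(x)                 # C_out = 48 // r^2 = 3
print(y.shape)                            # torch.Size([1, 3, 128, 128])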

Usage

To train the model, run:

python training.py \
	--train <training_dataset_path> \
	--val <validation_dataset_path> \
	--out <path_for_best_model> \
	--weights <weights_path (optional)>

To test the model, run:

python training.py \
	--weights <weights_path (optional)> \
	--input <input_frames_path>

Training

The training of the ELSR model is split into 6 steps in the paper, using different loss functions and different frame patch sizes. However, in this implementation the images in the dataset are much smaller, so only 3 steps are needed, since full-size images can be used. Note that the number of epochs is reduced and the learning-rate scheduler of the first training step is reused in the other steps as well.
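The skeleton below shows what each training step boils down to. The Adam optimizer, the StepLR scheduler, the stand-in model, and the synthetic data are assumptions made for illustration only; the actual settings live in training.py:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical training skeleton; synthetic tensors stand in for the H5
# dataset and a dummy sub-pixel model stands in for ELSR.
scale = 2
model = nn.Sequential(nn.Conv2d(3, 3 * scale ** 2, 3, padding=1),
                      nn.PixelShuffle(scale))
loader = DataLoader(TensorDataset(torch.randn(8, 3, 32, 32),
                                  torch.randn(8, 3, 64, 64)), batch_size=4)
criterion = nn.L1Loss()                                    # --loss "mae" (nn.MSELoss() for "mse")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)  # --lr 0.01; optimizer choice is an assumption
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)  # reused in every step

for epoch in range(3):                                     # --epochs (shortened here)
    for lr_img, hr_img in loader:
        optimizer.zero_grad()
        loss = criterion(model(lr_img), hr_img)
        loss.backward()
        optimizer.step()
    scheduler.step()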

Training step 1

Train the model on the x2 dataset using the L1 loss:

python training.py \
	--train "datasets/h5/train_X2.h5" \
	--val "datasets/h5/val_X2.h5" \
	--out "checkpoints/" \
	--scale 2 \
	--epochs 300 \
	--loss "mae" \
	--lr 0.01

Training step 2

Fine-tune the pre-trained model from step 1 on the x4 dataset, using the L1 loss and a higher learning rate. In the paper this is done in 2 steps, using different patch sizes.

python training.py \
	--train "datasets/h5/train_X4.h5" \
	--val "datasets/h5/val_X4.h5" \
	--out "checkpoints/" \
	--scale 4 \
	--epochs 50 \
	--loss "mae" \
	--lr 0.05 \
	--weights "best_X2_model.pth"

Training step 3

Fine-tune the pre-trained model from step 2 on the x4 dataset, using the MSE loss and a lower learning rate. In the paper this is done in 3 steps, using different patch sizes.

python training.py \
	--train "datasets/h5/train_X4.h5" \
	--val "datasets/h5/val_X4.h5" \
	--out "checkpoints/" \
	--scale 4 \
	--epochs 250 \
	--loss "mse" \
	--lr 5e-3 \
	--weights "best_X4_model.pth"

Results

Due to the limited size of the dataset, I wasn't able to replicate the paper's results, but the results are still interesting and show that video super-resolution can be achieved with such a small model. The graphs below show the training losses for each training step:

Tests

The testing of single-frame super-resolution is done as follows (video SR is achieved by iterating SR over every frame); a sketch of this pipeline follows the list:

  1. Resize the input image to (image.height // upscale_factor, image.width // upscale_factor) using bicubic interpolation
  2. Compute the bicubic_upsampled image by upscaling the low-resolution image from step 1 by the same factor, again using bicubic interpolation
  3. Use the low-resolution image to predict the sr_image
  4. Calculate the PSNR between sr_image and bicubic_upsampled; the results are shown below
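A minimal sketch of this procedure, assuming a hypothetical input path ("frame.png") and a dummy stand-in model in place of the trained ELSR weights (the psnr() helper is also written just for this sketch):

import numpy as np
import torch
from PIL import Image

def psnr(a, b):
    # Standard PSNR on 8-bit images; helper written for this sketch.
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    return 10 * np.log10(255.0 ** 2 / mse)

scale = 4
model = torch.nn.Sequential(torch.nn.Conv2d(3, 3 * scale ** 2, 3, padding=1),
                            torch.nn.PixelShuffle(scale)).eval()  # stand-in for the trained ELSR

image = Image.open("frame.png").convert("RGB")                    # hypothetical input frame
low_res = image.resize((image.width // scale, image.height // scale), Image.BICUBIC)
bicubic_up = low_res.resize((low_res.width * scale, low_res.height * scale), Image.BICUBIC)

x = torch.from_numpy(np.asarray(low_res)).permute(2, 0, 1).float().unsqueeze(0) / 255.0
with torch.no_grad():
    sr_image = model(x).clamp(0, 1).squeeze(0).permute(1, 2, 0).numpy() * 255.0

print("PSNR:", psnr(sr_image, np.asarray(bicubic_up)))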

The PSNR of the generated images turned out to be lower, but the images themselves are smoother, which makes larger images look better:

Blurring stands out in pixelated images:

Low-power real-time video super-resolution

Tests on videos have been done as well. To achieve "real-time" video SR the model should produce at least 30 FPS on edge devices. I couldn't test the model on mobile, but on GPU the video is produced at 2500+ FPS (see project_report.ipynb). GIFs below:

Bicubic GIF: 28.20 dB ELSR GIF: 28.45 dB
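For reference, a throughput measurement of this kind can be sketched as below. This is a hypothetical benchmark with a stand-in model and an assumed 320x180 LR frame size; the actual measurement is in project_report.ipynb:

import time
import torch

scale = 4
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(torch.nn.Conv2d(3, 3 * scale ** 2, 3, padding=1),
                            torch.nn.PixelShuffle(scale)).to(device).eval()
frame = torch.randn(1, 3, 180, 320, device=device)   # assumed LR frame size

with torch.no_grad():
    for _ in range(10):                               # warm-up iterations
        model(frame)
    if device == "cuda":
        torch.cuda.synchronize()
    start = time.time()
    n = 300
    for _ in range(n):
        model(frame)
    if device == "cuda":
        torch.cuda.synchronize()

print(f"{n / (time.time() - start):.0f} FPS")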

Project report

You can find a complete project report in this notebook.
