Implementation of the paper "ELSR: Extreme Low-Power Super Resolution Network For Mobile Devices" using PyTorch. The code replicates the method proposed in the paper, but it is meant to be trained on devices with limited resources. For this purpose the dataset is drastically smaller and the training procedure is much simpler.
- pytorch=1.13.1
- opencv=4.7.0
- pillow=9.4.0
- matplotlib
If you use Anaconda on Windows, you can simply run:
conda create -n elsr --file requirements.txt
Once the required packages are installed, download the dataset I used for training. Alternatively, you can download the entire REDS dataset from here.
ELSR is trained on the REDS dataset, which consists of several sets of 300 videos, each set with a different type of degradation. My model is trained on a drastically reduced version of the dataset containing only 30 lower-resolution videos (the original dataset was too big for me to train on). The dataset (h5 files) is available at the following link: https://drive.google.com/drive/folders/158bbeXr6EtCiuLI5wSh3SYRWMaWxK0Mq?usp=sharing.
To reduce overfitting and improve training results, I applied random data augmentation (see augment_data() in preprocessing.py). An example of augmentation by rotation is shown below:
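As an illustration only (this is not the actual augment_data() from preprocessing.py), paired augmentation has to apply the same random transform to the low-resolution input and to its high-resolution target, otherwise the pair is no longer aligned. The minimal NumPy sketch below assumes frames are stored as (H, W, C) arrays; the function name augment_pair is hypothetical.

```python
import numpy as np


def augment_pair(lr: np.ndarray, hr: np.ndarray, rng: np.random.Generator):
    """Hypothetical paired augmentation (random 90-degree rotation + flip).

    The SAME transform is applied to the LR input and the HR target so that
    they stay spatially aligned. Both arrays have shape (H, W, C).
    """
    k = int(rng.integers(0, 4))            # number of 90-degree rotations
    lr = np.rot90(lr, k, axes=(0, 1))
    hr = np.rot90(hr, k, axes=(0, 1))
    if rng.random() < 0.5:                 # random horizontal flip
        lr = lr[:, ::-1]
        hr = hr[:, ::-1]
    # rot90 and slicing return views; make contiguous copies for PyTorch
    return np.ascontiguousarray(lr), np.ascontiguousarray(hr)
```

Typical use is `lr_aug, hr_aug = augment_pair(lr, hr, np.random.default_rng(seed))` inside the data-loading loop. Rotations by multiples of 90 degrees and flips are lossless, so no interpolation artifacts are introduced.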
The ELSR model is a small sub-pixel convolutional neural network with 6 layers, only 5 of which have learnable parameters. The architecture is shown in the image below:
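The sketch below shows the general shape of a sub-pixel CNN of this kind in PyTorch. The exact configuration (6 feature channels, a conv/PReLU/conv residual branch, and a final convolution producing 3·r² channels for the shuffle) is an assumption chosen to match the "6 layers, 5 learnable" description; refer to the architecture figure and the repository code for the real design. The class name TinySubPixelSR is hypothetical.

```python
import torch
import torch.nn as nn


class TinySubPixelSR(nn.Module):
    """Hypothetical ELSR-style network: 5 learnable layers + PixelShuffle.

    Layer layout is an assumption, not taken from the repository code.
    """

    def __init__(self, scale: int = 4, channels: int = 6):
        super().__init__()
        self.head = nn.Conv2d(3, channels, 3, padding=1)          # 1: learnable
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),          # 2: learnable
            nn.PReLU(channels),                                   # 3: learnable slope
            nn.Conv2d(channels, channels, 3, padding=1),          # 4: learnable
        )
        # 5: learnable; produces 3 * scale^2 channels for the shuffle
        self.tail = nn.Conv2d(channels, 3 * scale ** 2, 3, padding=1)
        self.shuffle = nn.PixelShuffle(scale)                     # 6: no parameters

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feat = self.head(x)
        feat = feat + self.body(feat)      # residual connection
        return self.shuffle(self.tail(feat))


model = TinySubPixelSR(scale=4)
out = model(torch.rand(1, 3, 32, 32))
print(out.shape)  # torch.Size([1, 3, 128, 128])
```

All convolutions run at the low input resolution and upsampling happens only in the final, parameter-free PixelShuffle, which is what keeps this family of networks cheap.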
The PixelShuffle block (also known as depth2space) performs computationally efficient upsampling by rearranging elements from the channel dimension into the spatial dimensions, which increases the spatial resolution of the image. Formally, let x be a tensor of size (batch_size, C_in, H_in, W_in), where C_in is the number of input channels and H_in and W_in are the height and width of the input, respectively. PixelShuffle upsamples the spatial resolution of x by a factor of r, so the output is a tensor of size (batch_size, C_out, H_in * r, W_in * r), where C_out = C_in / r^2 (C_in must therefore be divisible by r^2).
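To make the rearrangement concrete, the following sketch reimplements depth-to-space with `reshape` and `permute` and checks it against PyTorch's built-in `torch.nn.functional.pixel_shuffle`. Output pixel (c, h·r + i, w·r + j) is taken from input channel c·r² + i·r + j at position (h, w). The helper name `depth_to_space` is hypothetical.

```python
import torch
import torch.nn.functional as F


def depth_to_space(x: torch.Tensor, r: int) -> torch.Tensor:
    """Manual PixelShuffle: (N, C*r^2, H, W) -> (N, C, H*r, W*r)."""
    n, c_in, h, w = x.shape
    assert c_in % (r * r) == 0, "C_in must be divisible by r^2"
    c_out = c_in // (r * r)
    # Split channel index c*r^2 + i*r + j into three axes (c, i, j)
    x = x.reshape(n, c_out, r, r, h, w)
    # Reorder to (n, c, h, i, w, j): rows become h*r + i, columns w*r + j
    x = x.permute(0, 1, 4, 2, 5, 3)
    return x.reshape(n, c_out, h * r, w * r)


x = torch.arange(2 * 12 * 3 * 5, dtype=torch.float32).reshape(2, 12, 3, 5)
y = depth_to_space(x, 2)
print(y.shape)                                # torch.Size([2, 3, 6, 10])
print(torch.equal(y, F.pixel_shuffle(x, 2)))  # True
```

No arithmetic is performed, only a memory reordering, which is why the block is considered essentially free compared with transposed convolutions.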
To train the model, run:
python training.py \
--train <training_dataset_path> \
--val <validation_dataset_path> \
--out <path_for_best_model> \
--weights <weights_path(not required)>
To test the model, run:
python training.py \
--weights <weights_path(not required)> \
--input <input_frames_path>
In the paper, the training of the ELSR model is split into 6 steps that use different loss functions and different frame patch sizes. In this implementation, however, the images in the dataset are much smaller, so full-size images can be used and only 3 steps are needed. Note that the number of epochs is reduced and that the learning-rate scheduler of the first training step is reused in the other steps.
Step 1: train the model on the x2 dataset using the L1 loss:
python training.py \
--train "datasets/h5/train_X2.h5" \
--val "datasets/h5/val_X2.h5" \
--out "checkpoints/" \
--scale 2 \
--epochs 300 \
--loss "mae" \
--lr 0.01
Step 2: fine-tune the pre-trained model from step 1 on the x4 dataset, using the L1 loss with a higher learning rate. In the paper this is done in 2 steps with different patch sizes.
python training.py \
--train "datasets/h5/train_X4.h5" \
--val "datasets/h5/val_X4.h5" \
--out "checkpoints/" \
--scale 4 \
--epochs 50 \
--loss "mae" \
--lr 0.05 \
--weights "best_X2_model.pth"
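The commands do not show how the x2 weights are mapped onto the x4 model. If the last convolution's output channel count depends on the scale (as in a sub-pixel network, where it produces 3·r² channels), that layer cannot be reused as-is. One common approach, shown here as an assumption rather than as what training.py does, is to copy only the tensors whose name and shape match. The helper `load_matching_weights` and the stand-in `make` model are hypothetical.

```python
import torch
import torch.nn as nn


def load_matching_weights(model: nn.Module, state_dict: dict) -> list:
    """Copy every tensor whose name and shape match; return the skipped keys.

    Hypothetical helper: layers that changed shape (e.g. a final conv with
    3 * r^2 output channels) keep their fresh initialization.
    """
    own = model.state_dict()
    compatible = {
        k: v for k, v in state_dict.items()
        if k in own and own[k].shape == v.shape
    }
    model.load_state_dict(compatible, strict=False)
    return sorted(k for k in state_dict if k not in compatible)


def make(scale: int) -> nn.Module:
    """Tiny stand-in sub-pixel model (not the real ELSR)."""
    return nn.Sequential(nn.Conv2d(3, 6, 3, padding=1),
                         nn.Conv2d(6, 3 * scale ** 2, 3, padding=1),
                         nn.PixelShuffle(scale))


x2_model, x4_model = make(2), make(4)
# With a real checkpoint (assuming it stores a plain state_dict):
# state = torch.load("best_X2_model.pth", map_location="cpu")
skipped = load_matching_weights(x4_model, x2_model.state_dict())
print(skipped)  # ['1.bias', '1.weight']
```

Only the scale-independent feature layers are transferred; the skipped layer is learned from scratch during fine-tuning, which is one reason a higher learning rate can make sense in this step.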
Step 3: fine-tune the model from step 2 on the x4 dataset, using the MSE loss with a lower learning rate. In the paper this is done in 3 steps with different patch sizes.
python training.py \
--train "datasets/h5/train_X4.h5" \
--val "datasets/h5/val_X4.h5" \
--out "checkpoints/" \
--scale 4 \
--epochs 250 \
--loss "mse" \
--lr 5e-3 \
--weights "best_X4_model.pth"
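The three training steps can also be summarized as data. The sketch below copies the values from the commands, maps the `--loss` values ("mae", "mse") to the PyTorch losses named in the step descriptions (L1 and MSE), and rebuilds each command line. The names `LOSSES`, `STEPS` and `build_command` are hypothetical and do not come from training.py.

```python
import torch
import torch.nn as nn

# Hypothetical mapping from the --loss flag to a PyTorch criterion
LOSSES = {"mae": nn.L1Loss, "mse": nn.MSELoss}

# The three training steps described above (values copied from the commands)
STEPS = [
    {"scale": 2, "epochs": 300, "loss": "mae", "lr": 0.01, "weights": None},
    {"scale": 4, "epochs": 50, "loss": "mae", "lr": 0.05,
     "weights": "best_X2_model.pth"},
    {"scale": 4, "epochs": 250, "loss": "mse", "lr": 5e-3,
     "weights": "best_X4_model.pth"},
]


def build_command(step: dict) -> list:
    """Return the training.py argument list for one step."""
    tag = f"X{step['scale']}"
    cmd = ["python", "training.py",
           "--train", f"datasets/h5/train_{tag}.h5",
           "--val", f"datasets/h5/val_{tag}.h5",
           "--out", "checkpoints/",
           "--scale", str(step["scale"]),
           "--epochs", str(step["epochs"]),
           "--loss", step["loss"],
           "--lr", str(step["lr"])]
    if step["weights"] is not None:
        cmd += ["--weights", step["weights"]]
    return cmd


# L1 averages absolute errors, MSE averages squared errors
pred, target = torch.zeros(2), torch.tensor([1.0, 3.0])
print(LOSSES["mae"]()(pred, target).item())  # 2.0
print(LOSSES["mse"]()(pred, target).item())  # 5.0
# To run all steps: for s in STEPS: subprocess.run(build_command(s), check=True)
```

Since PSNR = 10·log10(MAX²/MSE), finishing with the MSE loss directly optimizes the metric used in the evaluation, while the L1 loss in the earlier steps penalizes large errors less heavily.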
Due to the limited size of the dataset I wasn't able to replicate the paper's results, but the results are still interesting and show that video super-resolution can be done with such a small model. The graphs below show the training losses for each training step:
Single-frame super-resolution is tested as follows (video super-resolution is achieved by applying the model to every frame):
- Resize the input image to (image.height // upscale_factor, image.width // upscale_factor) using bicubic interpolation
- Upscale the resulting low-resolution image by the same factor using bicubic interpolation to obtain bicubic_upsampled
- Feed the low-resolution image to the model to predict sr_image
- Calculate the PSNR of sr_image and of bicubic_upsampled against the original image
The results are shown below:
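The evaluation steps listed above can be sketched with Pillow and NumPy. This is not the repository's test code: `evaluate_frame`, `psnr` and the `sr_fn` callable (a wrapper around the trained model) are hypothetical. Note that Pillow's `resize` expects the size as (width, height).

```python
import numpy as np
from PIL import Image


def psnr(reference: np.ndarray, estimate: np.ndarray, max_val: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB: 10 * log10(MAX^2 / MSE)."""
    diff = reference.astype(np.float64) - estimate.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(max_val ** 2 / mse))


def evaluate_frame(hr: Image.Image, scale: int, sr_fn) -> dict:
    """Hypothetical single-frame protocol following the steps above.

    sr_fn: callable mapping the low-resolution PIL image to an upscaled image
    (PIL image or uint8 array of the same size as the bicubic baseline).
    """
    # 1. Bicubic downscale of the original frame
    lr = hr.resize((hr.width // scale, hr.height // scale), Image.BICUBIC)
    size = (lr.width * scale, lr.height * scale)
    # 2. Bicubic upscale of the LR frame: the baseline
    bicubic = np.asarray(lr.resize(size, Image.BICUBIC))
    # 3. Model prediction from the LR frame
    sr = np.asarray(sr_fn(lr))
    # 4. Compare both against the original, cropped to the upscaled size
    reference = np.asarray(hr)[: size[1], : size[0]]
    return {"bicubic": psnr(reference, bicubic), "sr": psnr(reference, sr)}
```

This sketch computes PSNR over all RGB channels; many super-resolution papers report PSNR on the Y (luma) channel only, so numbers are comparable only when the same convention is used.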
The PSNR of the generated image turned out to be lower, but the output images are smoother, which makes larger images look better:
Blurring stands out in pixelated images:
Tests on videos were also carried out. To achieve "real-time" video super-resolution, the model should produce at least 30 FPS on edge devices. I couldn't test the model on a mobile device, but on a GPU the video is produced at more than 2500 FPS (see project_report.ipynb). The resulting GIFs are shown below:
| Bicubic GIF (PSNR: 28.20 dB) | ELSR GIF (PSNR: 28.45 dB) |
|---|---|
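FPS figures depend heavily on how they are measured. The following is a rough sketch of one way to time per-frame inference in PyTorch; it is not the code from project_report.ipynb, and the default frame size (180×320, the x4 low-resolution size of the original 720×1280 REDS frames) is an assumption. The function name `measure_fps` is hypothetical.

```python
import time

import torch
import torch.nn as nn


@torch.no_grad()
def measure_fps(model: nn.Module, frame_shape=(1, 3, 180, 320),
                n_frames: int = 200, warmup: int = 20) -> float:
    """Rough frames-per-second estimate for single-frame inference.

    Uses a random input of frame_shape on the model's device; assumes the
    model has at least one parameter (used to find the device).
    """
    device = next(model.parameters()).device
    model.eval()
    x = torch.rand(frame_shape, device=device)
    for _ in range(warmup):            # warm-up: kernel selection, caches
        model(x)
    if device.type == "cuda":
        torch.cuda.synchronize()       # CUDA kernels run asynchronously
    start = time.perf_counter()
    for _ in range(n_frames):
        model(x)
    if device.type == "cuda":
        torch.cuda.synchronize()       # wait for the last frame to finish
    return n_frames / (time.perf_counter() - start)
```

Without the `torch.cuda.synchronize()` calls the timer would only measure how fast kernels are queued, which can inflate GPU FPS considerably. The result can then be compared against the 30 FPS real-time threshold.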
You can find a complete project report in this notebook.