Segmentation Models

Segmentation Models is a Python library with neural networks for image segmentation, based on the Keras (TensorFlow) framework.

The main features of this library are:

  • High-level API (just two lines to create a neural network)
  • 4 model architectures for binary and multiclass segmentation (including the legendary Unet)
  • 25 available backbones for each architecture
  • All backbones have pre-trained weights for faster and better convergence

Quick start

Since the library is built on the Keras framework, a created segmentation model is just a Keras Model, which can be created as easily as:

from segmentation_models import Unet

model = Unet()

Depending on the task, you can change the network architecture by choosing a backbone with fewer or more parameters, and use pretrained weights to initialize it:

model = Unet('resnet34', encoder_weights='imagenet')

Change the number of output classes in the model (choose your case):

# binary segmentation (these parameters are the defaults when you call Unet('resnet34'))
model = Unet('resnet34', classes=1, activation='sigmoid')
# multiclass segmentation with non-overlapping class masks (your classes + background)
model = Unet('resnet34', classes=3, activation='softmax')
# multiclass segmentation with independent overlapping/non-overlapping class masks
model = Unet('resnet34', classes=3, activation='sigmoid')
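
The cases above differ in how the target masks are laid out. The sketch below is an illustration only and is not part of the library: `one_hot_mask` is a hypothetical helper that turns an integer label map into the per-class channels a softmax head with `classes=3` would be trained against. It uses plain lists so that it runs without dependencies; a real pipeline would more likely use array operations.

```python
def one_hot_mask(label_map, num_classes):
    """Convert an H x W map of integer class ids into an H x W x num_classes mask."""
    mask = []
    for row in label_map:
        mask_row = []
        for class_id in row:
            if not 0 <= class_id < num_classes:
                raise ValueError(
                    "class id {} outside [0, {})".format(class_id, num_classes)
                )
            # exactly one channel is 1 per pixel, so classes cannot overlap
            mask_row.append([1 if c == class_id else 0 for c in range(num_classes)])
        mask.append(mask_row)
    return mask


# 2 x 2 label map: 0 = background, 1 and 2 = object classes (made-up data)
labels = [[0, 1],
          [2, 1]]
softmax_target = one_hot_mask(labels, num_classes=3)
```

For the sigmoid multiclass case, each channel is instead an independent binary mask, so a pixel may be 1 in several channels or in none.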

Change the input shape of the model:

# if the number of input channels is not 3, you have to set encoder_weights=None
# how to handle this case with encoder_weights='imagenet' is described in the docs
model = Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)

Simple training pipeline

from segmentation_models import Unet
from segmentation_models.backbones import get_preprocessing
from segmentation_models.losses import bce_jaccard_loss
from segmentation_models.metrics import iou_score

BACKBONE = 'resnet34'
preprocess_input = get_preprocessing(BACKBONE)

# load your data
x_train, y_train, x_val, y_val = load_data(...)

# preprocess input
x_train = preprocess_input(x_train)
x_val = preprocess_input(x_val)

# define model
model = Unet(BACKBONE, encoder_weights='imagenet')
model.compile('Adam', loss=bce_jaccard_loss, metrics=[iou_score])

# fit model
# if you use a data generator, use model.fit_generator(...) instead of model.fit(...)
# more about `fit_generator` here: https://keras.io/models/sequential/#fit_generator
model.fit(
    x=x_train,
    y=y_train,
    batch_size=16,
    epochs=100,
    validation_data=(x_val, y_val),
)
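
The `iou_score` metric used above refers to intersection over union (the Jaccard index) between predicted and ground-truth masks. As a rough, dependency-free illustration of the quantity being measured, and not the library's implementation, the hypothetical `iou` function below computes it for two flat binary masks. The `threshold` and `smooth` parameters are assumptions made for this sketch; `smooth` only avoids division by zero on empty masks.

```python
def iou(y_true, y_pred, threshold=0.5, smooth=1e-12):
    """Intersection over union of two flat masks with values in [0, 1]."""
    if len(y_true) != len(y_pred):
        raise ValueError("masks must have the same number of pixels")
    # binarize predictions; ground truth is assumed to be 0/1 already
    pred = [1 if p > threshold else 0 for p in y_pred]
    intersection = sum(t * p for t, p in zip(y_true, pred))
    union = sum(y_true) + sum(pred) - intersection
    return (intersection + smooth) / (union + smooth)


truth = [1, 1, 0, 0]
probs = [0.9, 0.2, 0.8, 0.1]  # thresholded to [1, 0, 1, 0]
score = iou(truth, probs)     # 1 overlapping pixel, 3 pixels in the union
```

A Jaccard loss is commonly defined as one minus this score, so that a perfect overlap gives a loss of zero.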

The same manipulations can be done with Linknet, PSPNet and FPN. For more detailed information about the models API and use cases, see Read the Docs.

Models and Backbones

Models

  • Unet
  • Linknet
  • PSPNet
  • FPN

Backbones

Type          Names
VGG           'vgg16' 'vgg19'
ResNet        'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152'
SE-ResNet     'seresnet18' 'seresnet34' 'seresnet50' 'seresnet101' 'seresnet152'
ResNeXt       'resnext50' 'resnext101'
SE-ResNeXt    'seresnext50' 'seresnext101'
SENet154      'senet154'
DenseNet      'densenet121' 'densenet169' 'densenet201'
Inception     'inceptionv3' 'inceptionresnetv2'
MobileNet     'mobilenet' 'mobilenetv2'

All backbones have weights trained on the 2012 ILSVRC ImageNet dataset (encoder_weights='imagenet').
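
When the backbone name comes from a config file or the command line, it can help to validate it before building a model. The sketch below copies the names from the table above into a plain dictionary; `BACKBONES` and `check_backbone` are hypothetical names made up for this example and are not provided by the library.

```python
# backbone names copied from the table above, grouped by family
BACKBONES = {
    "VGG": ["vgg16", "vgg19"],
    "ResNet": ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"],
    "SE-ResNet": ["seresnet18", "seresnet34", "seresnet50",
                  "seresnet101", "seresnet152"],
    "ResNeXt": ["resnext50", "resnext101"],
    "SE-ResNeXt": ["seresnext50", "seresnext101"],
    "SENet154": ["senet154"],
    "DenseNet": ["densenet121", "densenet169", "densenet201"],
    "Inception": ["inceptionv3", "inceptionresnetv2"],
    "MobileNet": ["mobilenet", "mobilenetv2"],
}


def check_backbone(name):
    """Return the family of a listed backbone, or raise ValueError otherwise."""
    for family, names in BACKBONES.items():
        if name in names:
            return family
    raise ValueError("unknown backbone: {!r}".format(name))


family = check_backbone("resnet34")
```

A checked name can then be passed on as the first argument, as in `Unet('resnet34', encoder_weights='imagenet')`.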

Installation

Requirements

  1. Python 3.5+
  2. Keras >= 2.2.0
  3. Keras Applications >= 1.0.7
  4. Image Classifiers == 0.2.0
  5. TensorFlow 1.9 (tested)
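
To compare an existing environment against this list, one option is to print the installed versions. The following is a minimal sketch under two assumptions that are not stated in the original: it uses `importlib.metadata`, which needs Python 3.8 or newer (newer than the minimum listed above), and the distribution names in `REQUIRED` are the names the packages are assumed to be published under on PyPI.

```python
from importlib import metadata  # standard library, Python 3.8+

# distribution names are an assumption of this sketch
REQUIRED = ["keras", "keras-applications", "image-classifiers", "tensorflow"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


for name, version in installed_versions(REQUIRED).items():
    print("{:<20} {}".format(name, version or "not installed"))
```

The sketch only reports versions; comparing them against the constraints above is left to the reader.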

Pip package

$ pip install segmentation-models

Latest version

$ pip install git+https://github.com/qubvel/segmentation_models

Documentation

The latest documentation is available on Read the Docs.

Change Log

To see important changes between versions, look at CHANGELOG.md.

Citing

@misc{Yakubovskiy:2019,
  Author = {Pavel Yakubovskiy},
  Title = {Segmentation Models},
  Year = {2019},
  Publisher = {GitHub},
  Journal = {GitHub repository},
  Howpublished = {\url{https://github.com/qubvel/segmentation_models}}
}

License

The project is distributed under the MIT License.
