
mixture_of_experts_keras's Introduction

Mixture of Experts on Convolutional Neural Network

Mixture of experts is an ensemble model of neural networks consisting of expert networks and gating networks. Each expert is a neural network specialized in a certain kind of inference, such as classifying among artificial objects or among natural objects. The gating network is a discriminator network that decides which expert, or experts, to use for a given input, along with the importance of each expert.
A mixture of experts can use a single gating network, if it only needs to decide the importance of the experts, or multiple gating networks that probabilistically split the decision into a hierarchical order, much like a decision tree.
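As a minimal sketch of that weighted combination (plain NumPy, with illustrative function names that are not taken from the notebook):

'''
import numpy as np

def moe_predict(x, experts, gating):
    """Combine expert predictions using importance weights from a gating network.

    experts: list of callables, each returning a softmax class distribution for x
    gating:  callable returning one importance weight per expert (softmax, sums to 1)
    """
    importance = gating(x)                           # shape: (n_experts,)
    predictions = np.stack([e(x) for e in experts])  # shape: (n_experts, n_classes)
    return (importance[:, None] * predictions).sum(axis=0)  # shape: (n_classes,)
'''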

The expert models are pretrained and perform only feed-forward inference inside the mixture of experts model.
The training phase of the mixture of experts trains the gating networks to better decide which experts to use and how much weight to give to each of them.
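A hedged Keras sketch of this setup, assuming the pretrained experts have been saved to files (the file names here are illustrative, not the notebook's):

'''
from keras.models import load_model

# Illustrative paths; the notebook saves its own pretrained expert VGGs.
experts = [load_model(p) for p in ("base_vgg.h5", "artificial_vgg.h5", "natural_vgg.h5")]

# Freeze the experts so that only the gating networks receive gradient updates
# when the combined mixture of experts model is trained.
for expert in experts:
    expert.trainable = False
    for layer in expert.layers:
        layer.trainable = False
'''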

This notebook shows one way to build a mixture of experts model with deep learning. The objective is to classify images from Cifar10 using convolutional neural networks. The mixture of experts model uses hierarchical gating: the first gating network decides whether the input image is an artificial object or a natural object, and the next gating network decides the importance of each expert model.

There are three expert models (a data-splitting sketch follows the list):

basic VGG, which is trained to classify all 10 classes
artificial expert VGG, which is trained to classify only artificial objects, i.e. those with a label in 0, 1, 8 and 9
natural expert VGG, which is trained to classify only natural objects, i.e. those with a label in 2, 3, 4, 5, 6 and 7
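A minimal sketch of splitting the Cifar10 training data into the two groups used to train the specialist experts (array names are illustrative):

'''
import numpy as np
from keras.datasets import cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

artificial_labels = [0, 1, 8, 9]        # airplane, automobile, ship, truck
natural_labels = [2, 3, 4, 5, 6, 7]     # bird, cat, deer, dog, frog, horse

art_mask = np.isin(y_train.ravel(), artificial_labels)
nat_mask = np.isin(y_train.ravel(), natural_labels)

# Training subsets for the artificial expert VGG and the natural expert VGG.
x_art, y_art = x_train[art_mask], y_train[art_mask]
x_nat, y_nat = x_train[nat_mask], y_train[nat_mask]
'''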



The overview of the mixture of experts model

0.png

The first gating network, which decides which branch to take, artificial or natural, is a pretrained VGG network that classifies the input data.

The second gating layer consists of two gating networks that decide the importance of each expert.
The artificial gating network routes the classification job to the artificial expert VGG and the base VGG, and is only activated when the first gating network decides the input is an artificial object.
The natural gating network routes the classification job to the natural expert VGG and the base VGG, and is only activated when the first gating network decides the input is a natural object.

The classification output is the sum of the softmax outputs of the expert VGG and the base VGG, each multiplied by the importance from the preceding gating network.
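A hedged Keras sketch of that weighted sum; tiny Dense heads stand in for the VGG experts and the gating network, and the exact wiring in the notebook may differ:

'''
from keras.layers import Input, Flatten, Dense, Lambda
from keras.models import Model

inputs = Input(shape=(32, 32, 3))
features = Flatten()(inputs)

# Stand-ins for the expert VGG, the base VGG and the 2-way gating network.
expert_out = Dense(10, activation='softmax', name='expert_vgg')(features)
base_out = Dense(10, activation='softmax', name='base_vgg')(features)
gate_out = Dense(2, activation='softmax', name='gating')(features)

# Sum of the two softmax outputs, each multiplied by its gating importance.
def weighted_sum(tensors):
    expert, base, gate = tensors
    return gate[:, 0:1] * expert + gate[:, 1:2] * base

moe_out = Lambda(weighted_sum, output_shape=(10,))([expert_out, base_out, gate_out])
moe_branch = Model(inputs=inputs, outputs=moe_out)
'''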

Routing the networks

For instance, if the input image is a cat, the first gating network identifies it as a natural object and routes it to the natural gating network in the second gating layer. The natural gating network predicts the importance of the expert networks, the base VGG and the natural expert VGG, as softmax probabilities. Each expert network infers the image class as a softmax distribution, which is multiplied by its importance to produce the final inference.
The routing of the gating networks is as follows:
5.png
The notebook uses Keras with the Tensorflow backend to implement the mixture of experts model for classifying Cifar10.
Cifar10 consists of 10 classes of images, with the label of each class representing the following (see the sketch after the label list).

0 airplane
1 automobile
2 bird
3 cat
4 deer
5 dog
6 frog
7 horse
8 ship
9 truck
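A minimal sketch of deriving the targets implied by these labels: a 2-way artificial/natural target for the first gating network and 10-class one-hot targets for the experts (the notebook may construct these differently):

'''
import numpy as np
from keras.datasets import cifar10
from keras.utils import to_categorical

(x_train, y_train), _ = cifar10.load_data()

# 1 = artificial (airplane, automobile, ship, truck), 0 = natural.
is_artificial = np.isin(y_train.ravel(), [0, 1, 8, 9]).astype(int)
gate_targets = to_categorical(is_artificial, num_classes=2)   # for the first gating network
class_targets = to_categorical(y_train, num_classes=10)       # for the expert VGGs
'''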

The mixture of experts neural network

model.png

mixture_of_experts_keras's People

Contributors

shibuiwilliam


mixture_of_experts_keras's Issues

Ask about RuntimeError: Graph disconnected

Dear Shibuiwilliam,
Thank you so much for your GitHub repository about mixture of experts for Keras.
I learned a lot from it.
I have a problem related to "Graph disconnected" when I try to use it in my system.

  • I have 6 experts that are 6 different parallel CNN structures using the same input. The input to the experts is a pair of 4D arrays, each of size (None, 128, 50, 1). Each expert includes an Input layer as the first layer of its model and returns model = Model([input1, input2], output) - these are the differences from your expert structures.
  • I get the outputs of these experts, e0, e1, e2, e3, e4, e5, by calling "model.output".
  • I define a gating model that includes an Input layer like the other experts, followed by layers like yours (dense, dropout, reshape).
  • The code below is my MoE_output and MoE_model.
    '''
    from keras.layers import Lambda

    MoE_output = Lambda(lambda gx: (gx[0] * gx[6][:, :, 0]) + (gx[1] * gx[6][:, :, 1]) +
                                   (gx[2] * gx[6][:, :, 2]) + (gx[3] * gx[6][:, :, 3]) +
                                   (gx[4] * gx[6][:, :, 4]) + (gx[5] * gx[6][:, :, 5]),
                        output_shape=(10,))([e0, e1, e2, e3, e4, e5, gating_model.output])

    X1_tmp = Input(shape=(X1_train.shape[1:]))
    X2_tmp = Input(shape=(X2_train.shape[1:]))
    MoE_model = Model(inputs=[X1_tmp, X2_tmp], outputs=MoE_output)
    '''
    --->If I do not use Input layers for these array inputs, they cause an "unhashable" error for numpy.ndarray,
    since in topology.py of Keras there is an unsatisfied condition "len(set(self.inputs)) != len(self.inputs)",
    so I changed these array inputs to tensors by using the Input layer of Keras.

-->If I use Input layers and then build MoE_model as above, it raises a graph disconnected error.
Although I check the Input of the experts, they are the same, e.g. <Tensor("input_2_6:0", shape=(?, 128, 50, 1), dtype=float32)>.
'''
*** RuntimeError: Graph disconnected: cannot obtain value for tensor Tensor("input_2_6:0", shape=(?, 128, 50, 1), dtype=float32) at layer "input_2". The following previous layers were accessed without issue: []
'''
Please give me some advice to solve this problem.
Looking forward to hearing from you soon.
Thank you so much,
Truc
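This error usually appears when the final Model is given fresh Input tensors (X1_tmp and X2_tmp here) that are not connected to the expert and gating outputs, which were built on their own, separate Input layers. A hedged sketch of one connected pattern: define the shared Input tensors once and call each pretrained expert Model and the gating Model on them (expert_models and gating_model here stand for the models described in the question; a flat (None, 6) gating output is assumed):

'''
from keras.layers import Input, Lambda
from keras.models import Model

# Shared inputs, defined once and passed both to every sub-model and to the
# final Model, so the graph from the inputs to MoE_output is connected.
x1 = Input(shape=(128, 50, 1))
x2 = Input(shape=(128, 50, 1))

# Call the pretrained models on the shared tensors instead of reusing their
# original .output tensors (which belong to their own Input layers).
expert_outs = [expert([x1, x2]) for expert in expert_models]   # 6 expert Models
gate_out = gating_model([x1, x2])                              # shape (None, 6) assumed

def combine(tensors):
    experts, gate = tensors[:-1], tensors[-1]
    out = gate[:, 0:1] * experts[0]
    for i in range(1, len(experts)):
        out = out + gate[:, i:i + 1] * experts[i]
    return out

MoE_output = Lambda(combine, output_shape=(10,))(expert_outs + [gate_out])
MoE_model = Model(inputs=[x1, x2], outputs=MoE_output)
'''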
