
Soft Mixture of Experts

PyTorch implementation of Soft Mixture of Experts (Soft-MoE) from "From Sparse to Soft Mixtures of Experts". This implementation extends the timm library's VisionTransformer class to support Soft-MoE MLP layers.

Installation

pip install soft-moe

Or install the entire repo with:

git clone https://github.com/bwconrad/soft-moe
cd soft-moe/
pip install -r requirements.txt

Usage

Initializing a Soft Mixture of Experts Vision Transformer

import torch
from soft_moe import SoftMoEVisionTransformer

net = SoftMoEVisionTransformer(
    num_experts=128,
    slots_per_expert=1,
    moe_layer_index=6, 
    img_size=224,
    patch_size=32,
    num_classes=1000,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4,
)

img = torch.randn(1, 3, 224, 224)
preds = net(img)

Functions are also available to initialize default network configurations:

from soft_moe import (soft_moe_vit_base, soft_moe_vit_huge,
                      soft_moe_vit_large, soft_moe_vit_small,
                      soft_moe_vit_tiny)

net = soft_moe_vit_tiny()
net = soft_moe_vit_small()
net = soft_moe_vit_base()
net = soft_moe_vit_large()
net = soft_moe_vit_huge()

net = soft_moe_vit_tiny(num_experts=64, slots_per_expert=2, img_size=128)

Setting the Mixture of Experts Layers

The moe_layer_index argument sets which blocks use MoE MLP layers instead of regular MLP layers. When an int is given, the block at that index and every block after it will use MoE layers.

net = SoftMoEVisionTransformer(
    moe_layer_index=6,  # Blocks 6-11
    depth=12,
)

When a list is given, all specified layers will be MoE layers.

net = SoftMoEVisionTransformer(
    moe_layer_index=[0, 2, 4], # Blocks 0, 2 and 4
    depth=12,
)
  • Note: moe_layer_index uses 0-index convention.
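The int and list rules above can be expressed as a small function that returns the 0-based indices of the blocks that get a Soft-MoE MLP. moe_block_indices is a hypothetical helper for illustration only, not part of the soft_moe API, and it assumes that an int at or beyond depth simply selects no blocks.

```python
from typing import List, Union


def moe_block_indices(moe_layer_index: Union[int, List[int]], depth: int) -> List[int]:
    """Hypothetical helper: 0-based indices of the blocks that use a Soft-MoE MLP."""
    if isinstance(moe_layer_index, list):
        # A list selects exactly the listed blocks
        assert len(moe_layer_index) > 0
        assert all(0 <= i < depth for i in moe_layer_index)
        return sorted(set(moe_layer_index))
    # An int selects that block and every block after it (none if it is >= depth)
    return [i for i in range(depth) if i >= moe_layer_index]


print(moe_block_indices(6, depth=12))          # [6, 7, 8, 9, 10, 11]
print(moe_block_indices([0, 2, 4], depth=12))  # [0, 2, 4]
```

With depth=12 the last block has index 11, which is why an int of 6 covers blocks 6 through 11.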

Creating a Soft Mixture of Experts Layer

The SoftMoELayerWrapper class can turn any network layer that takes a tensor of shape [batch, length, dim] into a Soft Mixture of Experts layer.

import torch
import torch.nn as nn

from soft_moe import SoftMoELayerWrapper

x = torch.rand(1, 16, 128)

layer = SoftMoELayerWrapper(
    dim=128,
    slots_per_expert=2,
    num_experts=32,
    layer=nn.Linear,
    # nn.Linear arguments
    in_features=128,
    out_features=32,
)
y = layer(x)

layer = SoftMoELayerWrapper(
    dim=128,
    slots_per_expert=1,
    num_experts=16,
    layer=nn.TransformerEncoderLayer,
    # nn.TransformerEncoderLayer arguments
    d_model=128,
    nhead=8,
    batch_first=True,  # treat dim 0 of each expert's input as the batch dimension
)
y = layer(x)
  • Note: If a layer argument has the same name as one of SoftMoELayerWrapper's own arguments (e.g. dim), pass a partial function to layer instead.
    • e.g. layer=partial(MyCustomLayer, dim=128), with partial imported from functools
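To make the routing concrete, here is a minimal, self-contained sketch of the Soft MoE forward pass as described in the paper: each slot is a softmax-weighted average of all tokens (dispatch weights, normalized over tokens), each expert processes only its own slots, and each output token is a softmax-weighted average of all slot outputs (combine weights, normalized over slots). The class name MinimalSoftMoE, the two-layer MLP experts, and the initialization scale are assumptions for illustration. This is not the SoftMoELayerWrapper implementation, and it omits refinements the paper describes, such as normalizing the inputs and slot parameters.

```python
import torch
import torch.nn as nn


class MinimalSoftMoE(nn.Module):
    """Hypothetical minimal Soft MoE layer operating on [batch, length, dim] tensors."""

    def __init__(self, dim: int, num_experts: int, slots_per_expert: int, hidden_dim: int):
        super().__init__()
        self.num_experts = num_experts
        self.slots_per_expert = slots_per_expert
        # phi holds one learnable dim-sized vector per slot (num_experts * slots_per_expert slots)
        self.phi = nn.Parameter(torch.randn(dim, num_experts * slots_per_expert) * dim**-0.5)
        # Each expert is a small MLP mapping dim -> dim (an assumption for this sketch)
        self.experts = nn.ModuleList(
            [
                nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, dim))
                for _ in range(num_experts)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, m, d = x.shape
        logits = torch.einsum("bmd,dn->bmn", x, self.phi)  # [b, m, num_slots]
        dispatch = logits.softmax(dim=1)  # normalized over tokens: each slot mixes all tokens
        combine = logits.softmax(dim=-1)  # normalized over slots: each token mixes all slots
        slots = torch.einsum("bmd,bmn->bnd", x, dispatch)  # [b, num_slots, d]
        slots = slots.reshape(b, self.num_experts, self.slots_per_expert, d)
        # Expert i only sees its own slots_per_expert slots
        outs = torch.stack([f(slots[:, i]) for i, f in enumerate(self.experts)], dim=1)
        outs = outs.reshape(b, self.num_experts * self.slots_per_expert, d)
        return torch.einsum("bnd,bmn->bmd", outs, combine)  # [b, m, d]


torch.manual_seed(0)
x = torch.rand(2, 16, 32)
layer = MinimalSoftMoE(dim=32, num_experts=4, slots_per_expert=2, hidden_dim=64)
y = layer(x)  # same shape as x: [2, 16, 32]
```

Because every slot is a weighted average over all tokens, no token is dropped and the routing is fully differentiable; permuting the input tokens simply permutes the output tokens.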

Citation

@article{puigcerver2023sparse,
  title={From Sparse to Soft Mixtures of Experts},
  author={Puigcerver, Joan and Riquelme, Carlos and Mustafa, Basil and Houlsby, Neil},
  journal={arXiv preprint arXiv:2308.00951},
  year={2023}
}


soft-moe's Issues

Request to add registers and position embedding interpolation

Lol, I had fun with this code, but it wasn't well suited to my use without position embedding interpolation and some registers, so I added them. See if you'd like to include them in your code:

import math
from functools import partial
from typing import Callable

import torch
import torch.jit
import torch.nn as nn
import torch.utils.checkpoint
from timm.layers import Mlp, PatchDropout, trunc_normal_
from timm.models._manipulate import checkpoint_seq, named_apply
from timm.models.vision_transformer import (Block, _load_weights,
                                            get_init_weights_vit,
                                            init_weights_vit_timm)

from soft_moe.soft_moe import SoftMoELayerWrapper

class PatchEmbed(nn.Module):
    # Converts an image into patch embeddings using non-overlapping patch_size x patch_size crops.
    # An image containing n patches yields n embedding vectors, i.e. an n x embed_dim matrix.
    def __init__(self, img_size, patch_size, in_channels=3, embed_dim=256):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.n_patches = (img_size // patch_size) ** 2
        # A strided convolution applies the same linear projection to every patch
        self.project = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        # x: tensor of shape [B, C, H, W] (batch, channels, height, width)
        x = self.project(x)  # [B, embed_dim, H // patch_size, W // patch_size]
        x = x.flatten(2)  # [B, embed_dim, N_patches]
        x = x.transpose(1, 2)  # [B, N_patches, embed_dim]
        return x


class SoftMoEVisionTransformer(nn.Module):
    """Vision Transformer with Soft Mixture of Experts MLP layers.

    From the paper "From Sparse to Soft Mixtures of Experts"
    https://arxiv.org/pdf/2308.00951.pdf

    Code modified from:
    https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
    """

    def __init__(
        self,
        num_experts: int = 128,
        slots_per_expert: int = 1,
        moe_layer_index: int | list[int] = 6,
        img_size: int | tuple[int, int] = 224,
        patch_size: int | tuple[int, int] = 16,
        in_chans: int = 3,
        global_pool: str = "token",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        init_values: float | None = None,
        class_token: bool = True,
        no_embed_class: bool = False,
        pre_norm: bool = False,
        fc_norm: bool | None = None,
        drop_rate: float = 0.0,
        pos_drop_rate: float = 0.0,
        patch_drop_rate: float = 0.0,
        proj_drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        weight_init: str = "",
        embed_layer: Callable = PatchEmbed,
        norm_layer: Callable | None = None,
        act_layer: Callable | None = None,
        block_fn: Callable = Block,
        mlp_layer: Callable = Mlp,
    ):
        """
        Args:
            num_experts (int): Number of experts in MoE layers.
            slots_per_expert (int): Number of token slots per expert.
            moe_layer_index (int or list[int]): Block depth indices where MoE layers are used.
                Either an int which denotes where MoE layers are used from to the end, or a list
                of ints denoting the specific blocks (both use 0-indexing).
            img_size (int or tuple[int, int]): Input image size.
            patch_size (int or tuple[int, int]): Patch size.
            in_chans (int): Number of image input channels.
            global_pool (str): Type of global pooling for the final sequence (default: 'token').
            embed_dim (int): Transformer embedding dimension.
            depth (int): Depth of the transformer.
            num_heads (int): Number of attention heads.
            mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
            qkv_bias (bool): Enable bias for qkv projections if True.
            qk_norm (bool): Enable normalization of query and key in self-attention.
            init_values (float or None): Layer-scale init values (layer-scale enabled if not None).
            class_token (bool): Use a class token.
            no_embed_class (bool): Do not embed class tokens in the patch embedding.
            pre_norm (bool): Apply normalization before self-attention in the transformer block.
            fc_norm (bool or None): Pre-head norm after pool (instead of before).
                If None, enabled when global_pool == 'avg'.
            drop_rate (float): Head dropout rate.
            pos_drop_rate (float): Position embedding dropout rate.
            attn_drop_rate (float): Attention dropout rate.
            drop_path_rate (float): Stochastic depth rate.
            weight_init (str): Weight initialization scheme.
            embed_layer (Callable): Patch embedding layer.
            norm_layer (Callable or None): Normalization layer.
            act_layer (Callable or None): MLP activation layer.
            block_fn (Callable): Transformer block layer.
            mlp_layer (Callable): MLP layer.
        """
        super().__init__()
        assert global_pool in ("", "avg", "token")
        assert class_token or global_pool != "token"
        use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.global_pool = global_pool
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim
        self.num_prefix_tokens = 1 if class_token else 0 
        self.no_embed_class = no_embed_class
        self.grad_checkpointing = False

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_chans,
            embed_dim=embed_dim,
        )
        self.patch_embed.project.bias = None

        num_patches = (img_size//patch_size)**2

        self.cls_token = (
            nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        )

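        # Learnable register tokens, inserted after the class token without position embeddings
        self.num_registers = 4
        self.registers = nn.Parameter(torch.zeros(1, self.num_registers, embed_dim))

        embed_len = (
            num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        )
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            # Registers sit right after the class token, so count them as prefix
            # tokens to keep PatchDropout from dropping them
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens + self.num_registers,
            )
        else:
            self.patch_drop = nn.Identity()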
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

        # Wrap the mlp_layer in a soft-moe wrapper
        self.num_experts = num_experts
        self.slots_per_expert = slots_per_expert

        moe_mlp_layer = partial(
            SoftMoELayerWrapper,
            layer=mlp_layer,
            dim=embed_dim,
            num_experts=self.num_experts,
            slots_per_expert=self.slots_per_expert,
        )

        # Create a list where each index is the mlp layer class to
        # use at that depth
        self.moe_layer_index = moe_layer_index
        if isinstance(moe_layer_index, list):
            # Only the specified layers in moe_layer_index
            assert len(moe_layer_index) > 0
            assert all([0 <= l < depth for l in moe_layer_index])

            mlp_layers_list = [
                moe_mlp_layer if i in moe_layer_index else mlp_layer
                for i in range(depth)
            ]
        else:
            if moe_layer_index < depth:
                # All layers including and after moe_layer_index
                mlp_layers_list = [
                    moe_mlp_layer if i >= moe_layer_index else mlp_layer
                    for i in range(depth)
                ]
            else:  # moe_layer_index >= depth: every block uses the regular MLP
                mlp_layers_list = [mlp_layer for _ in range(depth)]

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.Sequential(
            *[
                block_fn(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_norm=qk_norm,
                    init_values=init_values,
                    proj_drop=proj_drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    act_layer=act_layer,
                    mlp_layer=mlp_layers_list[i],
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

        # fc_norm is kept for parity with timm but is not applied in forward();
        # this model has no classifier head
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()

        if weight_init != "skip":
            self.init_weights(weight_init)

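    def init_weights(self, mode=""):
        assert mode in ("jax", "jax_nlhb", "moco", "")
        trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        if self.registers is not None:
            nn.init.normal_(self.registers, std=1e-6)
        # Initialize the remaining submodules (e.g. Linear layers) with timm's ViT scheme;
        # head_bias is 0.0 because this model has no classifier head
        named_apply(get_init_weights_vit(mode, 0.0), self)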

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        init_weights_vit_timm(m)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=""):
        _load_weights(self, checkpoint_path, prefix)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token", "dist_token", "registers"}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r"^cls_token|pos_embed|patch_embed",  # stem and embed
            blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))],
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable


    def pos_embedding_interp(self, x, h, w):
        # x: [B, 1 + N_patches, dim] (class token first); h, w: input image height and width.
        # Assumes class_token=True, no_embed_class=False and a square original patch grid.
        N = self.pos_embed.shape[1] - 1  # number of patch positions the ViT was built for
        side = int(math.sqrt(N))
        h0 = h // self.patch_embed.patch_size
        w0 = w // self.patch_embed.patch_size
        if h0 == side and w0 == side:
            return self.pos_embed  # grid matches, no interpolation needed

        # Separate the class-token embedding from the patch-position embeddings
        class_pos_embed = self.pos_embed[:, 0]
        patch_pos_embed = self.pos_embed[:, 1:]
        dim = x.shape[-1]  # patch embedding dimensionality

        # Resize the [side x side] grid to [h0 x w0]; passing size= instead of
        # scale_factor avoids floating-point rounding of the output size
        patch_pos_embed = torch.nn.functional.interpolate(
            patch_pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2),
            size=(h0, w0),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

    def _pos_embed(self, x):
        # original timm, JAX, and deit vit impl:
        # pos_embed has an entry for the class token, so concat first, then add
        B, _, H, W = x.shape  # [B, C, H, W]
        x = self.patch_embed(x)
        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)

        # Interpolate pos_embed to the input resolution; adding it before the
        # registers are inserted keeps the registers free of position embeddings
        x = x + self.pos_embedding_interp(x, H, W)

        if self.registers is not None:
            # Token order: [class token, registers, patch tokens]
            x = torch.cat(
                (
                    x[:, :1],
                    self.registers.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return self.pos_drop(x)

    def forward_features(self, x):
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        return x[:, 0]  # class-token features (no classifier head)
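The two additions in this patch can be checked in isolation, without timm or soft_moe installed. The sketch below restates them as free functions, interpolate_pos_embed and add_registers (hypothetical names, not part of soft_moe or timm), assuming a [1, 1 + N, dim] position embedding with the class token first and a square original grid. It shows the resulting token counts for a non-square 10 x 20 patch grid.

```python
import math

import torch
import torch.nn.functional as F


def interpolate_pos_embed(pos_embed: torch.Tensor, grid_h: int, grid_w: int) -> torch.Tensor:
    """Resize a [1, 1 + N, dim] position embedding (class token first) to a grid_h x grid_w grid."""
    cls_pos, patch_pos = pos_embed[:, :1], pos_embed[:, 1:]
    n, dim = patch_pos.shape[1], patch_pos.shape[2]
    side = math.isqrt(n)  # assumes the original patch grid was square
    if (grid_h, grid_w) == (side, side):
        return pos_embed  # nothing to resize
    patch_pos = patch_pos.reshape(1, side, side, dim).permute(0, 3, 1, 2)  # [1, dim, side, side]
    patch_pos = F.interpolate(patch_pos, size=(grid_h, grid_w), mode="bicubic", align_corners=False)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, grid_h * grid_w, dim)
    return torch.cat((cls_pos, patch_pos), dim=1)


def add_registers(tokens: torch.Tensor, registers: torch.Tensor) -> torch.Tensor:
    """Insert register tokens (which get no position embedding) right after the class token."""
    regs = registers.expand(tokens.shape[0], -1, -1)
    return torch.cat((tokens[:, :1], regs, tokens[:, 1:]), dim=1)


pos_embed = torch.randn(1, 1 + 14 * 14, 768)  # e.g. a 224px input with 16px patches
resized = interpolate_pos_embed(pos_embed, 10, 20)  # e.g. a 160 x 320 input: [1, 201, 768]
tokens = torch.randn(2, 1 + 10 * 20, 768) + resized  # class token + patch tokens
tokens = add_registers(tokens, torch.zeros(1, 4, 768))  # [2, 205, 768]
```

Only the patch positions are resized; the class-token embedding is copied unchanged, and the four registers add four tokens without touching the position embedding.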
