

🌟 TensorRVEA: Tensorized RVEA for GPU-accelerated Evolutionary Multi-objective Optimization 🌟

TensorRVEA Paper on arXiv

The Tensorized Reference Vector Guided Evolutionary Algorithm (TensorRVEA) aims to improve the scalability and efficiency of evolutionary multi-objective optimization through GPU acceleration. By recasting RVEA's key data structures and operators as tensors, TensorRVEA exploits GPU parallelism to tackle complex optimization problems more efficiently. TensorRVEA is implemented in JAX and is compatible with the EvoX framework.
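To make "recasting operators as tensors" concrete, here is a minimal NumPy sketch of RVEA's angle-penalized distance (APD) selection, written with whole-population array operations instead of per-individual loops. It follows the standard RVEA formulation (translate objectives by the ideal point, assign each individual to its closest reference vector by angle, keep the individual with the smallest APD per vector); it is not TensorRVEA's actual code. The function name `apd_select`, the `alpha=2.0` default, and the `-1` marker for empty reference vectors are illustrative assumptions. Because it avoids in-place updates, swapping `numpy` for `jax.numpy` should be enough to trace it under `jax.jit`.

```python
import numpy as np


def apd_select(objs, ref_vecs, gen, max_gen, alpha=2.0):
    """Tensorized RVEA-style selection (sketch, minimization).

    objs:     (N, M) objective values of the merged parent + offspring population.
    ref_vecs: (K, M) unit-length reference vectors.
    Returns a length-K index array; -1 marks a reference vector with no individual.
    """
    n, m = objs.shape
    k = ref_vecs.shape[0]
    # Translate so the ideal point (per-objective minimum) is the origin.
    translated = objs - objs.min(axis=0)
    norms = np.linalg.norm(translated, axis=1)                        # (N,)
    # Cosine of the angle between every individual and every reference vector.
    cosine = translated @ ref_vecs.T / np.maximum(norms[:, None], 1e-12)
    cosine = np.clip(cosine, -1.0, 1.0)                               # (N, K)
    assigned = np.argmax(cosine, axis=1)                              # closest vector per individual
    theta = np.arccos(cosine[np.arange(n), assigned])                 # (N,)
    # gamma_v: smallest angle between vector v and any other reference vector.
    ref_cos = ref_vecs @ ref_vecs.T - 2.0 * np.eye(k)                 # push self-pairs below -1
    gamma = np.arccos(np.clip(ref_cos.max(axis=1), -1.0, 1.0))        # (K,)
    # Angle-penalized distance; the angle penalty grows as the search progresses.
    penalty = m * (gen / max_gen) ** alpha * theta / gamma[assigned]
    apd = (1.0 + penalty) * norms                                     # (N,)
    # (N, K) matrix holding each individual's APD only in its assigned column.
    member = assigned[:, None] == np.arange(k)[None, :]
    apd_matrix = np.where(member, apd[:, None], np.inf)
    best = np.argmin(apd_matrix, axis=0)                              # (K,)
    empty = ~member.any(axis=0)
    return np.where(empty, -1, best)


refs = np.array([[1.0, 0.0], [0.0, 1.0], [np.sqrt(0.5), np.sqrt(0.5)]])
objs = np.array([[1.0, 0.1], [2.0, 0.2], [0.1, 1.0]])
print(apd_select(objs, refs, gen=0, max_gen=100).tolist())  # -> [0, 2, -1]
```

Returning a fixed-length index array (with a sentinel for empty vectors) keeps every array shape static, which JIT compilation requires; data-dependent shapes would force recompilation.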

Demonstrations

Below are demonstrations of TensorRVEA on several simulated multiobjective robotics environments. TensorRVEA optimizes the parameters of an MLP, which is then used as a policy model to visualize the robot's behavior in the simulated environment.

  • MoHalfcheetah: Optimizing for speed and control cost.
  • MoHopper-m2: Aiming for maximum speed and jumping height.
  • MoSwimmer: Enhancing movement efficiency in fluid environments.
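In these demonstrations a whole MLP policy is treated as one flat decision vector. The NumPy sketch below shows how a population of such vectors can be split into batched weight tensors and evaluated in a single pass with `einsum`. The layer sizes mirror the `Model` in the MoBrax example further down (8 inputs, a 16-unit tanh hidden layer, 2 tanh outputs), giving 8×16 + 16 + 16×2 + 2 = 178 parameters. The flat layout used here (kernel before bias, layer by layer) and the helper names are assumptions for illustration; EvoX's `TreeAndVector` may order the parameter leaves differently.

```python
import numpy as np

# Hypothetical flat layout: [W1 (8x16), b1 (16), W2 (16x2), b2 (2)], mirroring the
# Dense(16) -> tanh -> Dense(2) -> tanh model in the MoBrax example (8 inputs).
SHAPES = [(8, 16), (16,), (16, 2), (2,)]
N_PARAMS = sum(int(np.prod(s)) for s in SHAPES)  # 128 + 16 + 32 + 2 = 178


def unflatten(pop):
    """Split a (P, N_PARAMS) population matrix into batched weight tensors."""
    tensors, offset = [], 0
    for shape in SHAPES:
        size = int(np.prod(shape))
        tensors.append(pop[:, offset:offset + size].reshape((pop.shape[0],) + shape))
        offset += size
    return tensors


def batched_policy(pop, obs):
    """Run P policies on P observations in one pass: (P, N_PARAMS), (P, 8) -> (P, 2)."""
    w1, b1, w2, b2 = unflatten(pop)
    hidden = np.tanh(np.einsum("pi,pij->pj", obs, w1) + b1)
    return np.tanh(np.einsum("pi,pij->pj", hidden, w2) + b2)


rng = np.random.default_rng(0)
pop = rng.uniform(-8.0, 8.0, size=(100, N_PARAMS))   # bounds as in the MoBrax example
obs = rng.normal(size=(100, 8))
actions = batched_policy(pop, obs)                  # shape (100, 2), values in [-1, 1]
```

Evaluating all policies with one batched contraction, rather than looping over individuals, is the same idea that lets population-level work map onto a GPU.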

Key Features

  • GPU Acceleration 💻: Uses GPUs to parallelize computation across the whole population.
  • Large-Scale Optimization 📈: Well suited to large population sizes and high-dimensional problems.
  • Flexibility 🔨: Compatible with a variety of tensor-based reproduction operators, including GA, DE, PSO, and CSO.
  • Real-World Applications: Suited to complex tasks such as multiobjective robotic control (MoBrax), with a special emphasis on neuroevolution.

Requirements

TensorRVEA requires:

  • evox (version == 0.8.1)
  • jax (version >= 0.4.16)
  • jaxlib (version >= 0.3.0)
  • brax (version == 0.10.3)
  • flax
  • Visualization tools: plotly, pandas
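Since some of these pins are exact (`==`), a quick environment check can save debugging time. The helper below is a hypothetical convenience script, not part of TensorRVEA: it reads installed versions through `importlib.metadata` and compares only the leading numeric release components, so pre-release, dev, or local version suffixes are ignored.

```python
import re
from importlib import metadata

# Pins from the requirements list: (distribution name, operator, version); None = any version.
REQUIREMENTS = [
    ("evox", "==", "0.8.1"),
    ("jax", ">=", "0.4.16"),
    ("jaxlib", ">=", "0.3.0"),
    ("brax", "==", "0.10.3"),
    ("flax", None, None),
    ("plotly", None, None),
    ("pandas", None, None),
]


def release_tuple(version):
    """'0.4.16' -> (0, 4, 16); trailing pre/post/dev/local suffixes are ignored."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    return tuple(int(part) for part in match.group().split(".")) if match else ()


def check_requirements(get_version=metadata.version):
    """Return a list of human-readable problems; an empty list means all pins are met."""
    issues = []
    for name, op, wanted in REQUIREMENTS:
        try:
            installed = get_version(name)
        except metadata.PackageNotFoundError:
            issues.append(f"{name}: not installed")
            continue
        if op == "==" and release_tuple(installed) != release_tuple(wanted):
            issues.append(f"{name}: found {installed}, need =={wanted}")
        elif op == ">=" and release_tuple(installed) < release_tuple(wanted):
            issues.append(f"{name}: found {installed}, need >={wanted}")
    return issues


if __name__ == "__main__":
    for issue in check_requirements() or ["all requirements satisfied"]:
        print(issue)
```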

Example Usage

Example on a DTLZ problem (DTLZ2 with three objectives):

from evox import workflows, problems
import algorithms
from evox.monitors import PopMonitor
from evox.metrics import IGD
import jax
import jax.numpy as jnp
import numpy as np
import time


def run_moea(algorithm, key):
    monitor = PopMonitor()

    problem = problems.numerical.DTLZ2(m=3)
    workflow = workflows.StdWorkflow(
        algorithm=algorithm,
        problem=problem,
        monitor=monitor,
    )

    state = workflow.init(key)

    true_pf = problem.pf()

    igd = IGD(true_pf)

    for i in range(100):
        state = workflow.step(state)

        # Drop rows containing NaN before computing IGD.
        fit = state.get_child_state("algorithm").fitness
        non_nan_rows = fit[~np.isnan(fit).any(axis=1)]
        print(f'Generation {i+1}, IGD: {igd(non_nan_rows)}')
    fig = monitor.plot()
    fig.show()


if __name__ == '__main__':
    # DTLZ2 with 3 objectives: 12 decision variables, each in [0, 1] (float bounds)
    lb = jnp.full(shape=(12,), fill_value=0.0)
    ub = jnp.full(shape=(12,), fill_value=1.0)

    algorithm = algorithms.TensorRVEA(
        lb=lb,
        ub=ub,
        n_objs=3,
        pop_size=100,
    )
    key = jax.random.PRNGKey(42)

    start = time.time()
    run_moea(algorithm, key)
    end = time.time()
    print(f"time: {end-start}s")
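The loop above evaluates DTLZ2 and reports the inverted generational distance (IGD) through EvoX. For reference, the NumPy sketch below spells out the standard textbook definitions of both; it is not EvoX's implementation, and the function names are illustrative. With 12 decision variables and three objectives, the first M − 1 = 2 variables set the position on the front and the remaining 10 enter the distance function g, so the Pareto front is the positive octant of the unit sphere (squared objectives sum to 1). IGD is taken here as the plain mean of Euclidean distances from each reference-front point to its nearest obtained solution.

```python
import numpy as np


def dtlz2(x, m=3):
    """Standard DTLZ2 objectives for an (N, D) batch of decision vectors in [0, 1]^D."""
    # Distance function over the last D - M + 1 variables.
    g = np.sum((x[:, m - 1:] - 0.5) ** 2, axis=1, keepdims=True)             # (N, 1)
    angles = x[:, : m - 1] * np.pi / 2                                         # (N, M-1)
    ones = np.ones((x.shape[0], 1))
    # cos_prefix[:, j] = cos(a_1) * ... * cos(a_j), with cos_prefix[:, 0] = 1.
    cos_prefix = np.concatenate([ones, np.cumprod(np.cos(angles), axis=1)], axis=1)
    # Objective k (1-based) uses cos_prefix[:, M-k] and sin(a_{M-k+1}) (no sine for k = 1).
    sines = np.concatenate([ones, np.sin(angles)[:, ::-1]], axis=1)
    return (1.0 + g) * cos_prefix[:, ::-1] * sines                            # (N, M)


def igd(objs, pf):
    """Mean Euclidean distance from each reference-front point to its nearest solution."""
    objs = objs[~np.isnan(objs).any(axis=1)]            # drop NaN rows, as the example does
    dists = np.linalg.norm(pf[:, None, :] - objs[None, :, :], axis=-1)   # (|PF|, N)
    return float(dists.min(axis=1).mean())


rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, size=(50, 12))
x[:, 2:] = 0.5                      # g = 0, so every point lies on the Pareto front
f = dtlz2(x)                        # shape (50, 3); each row satisfies sum(f**2) == 1
```

The pairwise distance matrix in `igd` is the tensorized form of the metric: one broadcasted subtraction replaces a double loop over front points and solutions.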

Example for MoBrax (here, the mo_swimmer environment):

from algorithms import TensorRVEA
from evox.workflows import StdWorkflow
from evox.monitors import StdMOMonitor
from evox.utils import TreeAndVector
import jax
import jax.numpy as jnp
from flax import linen as nn
import time
import problems
from evox.operators.sampling import UniformSampling
from evox.metrics import HV
from metrics.expected_utility import ExpectedUtility

env_name = "mo_swimmer"


class Model(nn.Module):

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(16)(x)
        x = nn.tanh(x)
        x = nn.Dense(2)(x)
        x = nn.tanh(x)
        return x


def main():
    key = jax.random.PRNGKey(43)
    model_key, workflow_key = jax.random.split(key)
    model = Model()
    params = model.init(model_key, jnp.zeros((8,)))
    adapter = TreeAndVector(params)
    monitor = StdMOMonitor(record_pf=False)

    problem = problems.MoBrax(
        policy=jax.jit(model.apply),
        env_name=env_name,
        cap_episode=1000,
        num_obj=2,
    )
    center = adapter.to_vector(params)

    workflow = StdWorkflow(
        algorithm=TensorRVEA(
            lb=jnp.full_like(center, -8),
            ub=jnp.full_like(center, 8),
            n_objs=2,
            pop_size=100,
            uniform_init=False,
        ),
        problem=problem,
        monitor=monitor,
        num_objectives=2,
        pop_transform=adapter.batched_to_tree,
        opt_direction="max",
    )

    state = workflow.init(workflow_key)
    step_func = jax.jit(workflow.step).lower(state).compile()
    state = step_func(state)
    w = UniformSampling(100, 2)()[0]
    ref = jnp.array([0, -1])
    hv_metric = HV(ref=-ref)
    eu_metric = ExpectedUtility(w=w)
    start = time.time()
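    for i in range(100):
        state = step_func(state)
        # Negate the stored fitness to recover the (maximized) objective values.
        f = -state.get_child_state("algorithm").fitness
        # Keep valid rows that are no worse than the reference point in every objective.
        f = f[~jnp.isnan(f).any(axis=1)]
        current_f = f[jnp.all(f >= ref, axis=1)]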
        if current_f.shape[0] == 0:
            hv = 0
            eu = 0
        else:
            hv = hv_metric(jax.random.split(workflow_key)[1], -current_f)
            eu = eu_metric(current_f)
        print(f'Generation {i+1}, HV: {hv}, EU: {eu}')
    end = time.time()
    print(f"Total time: {end - start}s")


if __name__ == "__main__":
    main()
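The MoBrax loop reports hypervolume (HV) and expected utility (EU). The EvoX `HV` metric is called with a PRNG key, which suggests a sampling-based estimate; with only two objectives the exact value is cheap to compute with a sorted sweep. The NumPy sketch below works directly in the maximization setting and applies the same reference-point filter as the example. `expected_utility` assumes the common multi-objective RL definition (average over weight vectors of the best linear scalarization); the repository's `metrics.expected_utility.ExpectedUtility` may differ. The evenly spaced weights are a hypothetical stand-in for `UniformSampling(100, 2)`.

```python
import numpy as np


def hv_2d_max(points, ref):
    """Exact hypervolume of a 2-objective maximization set w.r.t. reference point `ref`."""
    q = points[np.all(points >= ref, axis=1)] - ref      # keep points no worse than ref
    if q.shape[0] == 0:
        return 0.0
    q = q[np.argsort(-q[:, 0], kind="stable")]           # sort by first objective, descending
    best_f2 = np.maximum.accumulate(q[:, 1])              # running max of the second objective
    gains = np.diff(best_f2, prepend=0.0)                 # extra height each point contributes
    return float(np.sum(q[:, 0] * gains))                 # dominated points contribute 0


def expected_utility(points, weights):
    """Mean over weight vectors of the best linear utility w . f achieved by any point."""
    return float(np.mean(np.max(weights @ points.T, axis=1)))


# Hypothetical stand-in for UniformSampling(100, 2): evenly spaced 2-D weights.
t = np.linspace(0.0, 1.0, 100)
weights = np.stack([t, 1.0 - t], axis=1)
front = np.array([[2.0, 1.0], [1.0, 2.0], [1.0, 0.5]])
hv = hv_2d_max(front, np.array([0.0, -1.0]))             # 5.0
eu = expected_utility(front, weights)
```

Replacing the explicit sweep with a running maximum keeps the computation loop-free; in JAX the same step could use `jax.lax.cummax`.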

Community & Support

Citing TensorRVEA

If you use TensorRVEA in your research, please cite:

@inproceedings{tensorrvea,
    author = {Liang, Zhenyu and Jiang, Tao and Sun, Kebin and Cheng, Ran},
    title = {GPU-accelerated Evolutionary Multiobjective Optimization Using Tensorized RVEA},
    year = {2024},
    doi = {10.1145/3638529.3654223},
    booktitle = {Proceedings of the Genetic and Evolutionary Computation Conference},
    pages = {566--575},
    numpages = {10},
    location = {Melbourne, VIC, Australia},
    series = {GECCO '24}
}
