Giter VIP home page Giter VIP logo

pop-spiking-deep-rl's Issues

About example test code for popsan_ddpg on GPU

import numpy as np
import torch
import torch.nn as nn
import gym
import pickle
import sys
import math

sys.path.append('../../')
from popsan_drl.popsan_td3.replay_buffer_norm import ReplayBuffer
from popsan_drl.popsan_td3.popsan import PopSpikeActor

def _select_best_epoch(test_reward, save_freq=5):
    """Return ``(epoch, reward)`` for the saved checkpoint with the highest test reward.

    Checkpoints are taken to exist at epochs ``save_freq, 2*save_freq, ...`` and
    ``test_reward[epoch - 1]`` is the test reward recorded after ``epoch`` (the
    same indexing the checkpoint file names ``..._e<epoch>.pt`` rely on).
    Ties resolve to the earliest epoch.

    Raises:
        ValueError: if ``test_reward`` is too short to cover any checkpoint.
    """
    epochs = range(save_freq, len(test_reward) + 1, save_freq)
    if len(epochs) == 0:
        raise ValueError("test_reward has no entry for any saved checkpoint epoch")
    # max() keeps the first maximal element, so ties go to the earliest epoch.
    best_epoch = max(epochs, key=lambda e: test_reward[e - 1])
    return best_epoch, test_reward[best_epoch - 1]


def _run_episode(env, popsan, replay_buffer, act_limit, device, max_ep_len, render):
    """Roll out one episode with the PopSAN actor (no exploration noise); return (ep_ret, ep_len).

    Uses the gym API style seen in this script: ``reset()`` returns the
    observation and ``step()`` returns a 4-tuple.
    """
    o, d, ep_ret, ep_len = env.reset(), False, 0, 0
    while not (d or ep_len == max_ep_len):
        if render:
            env.render()
        # no_grad has to wrap the forward pass at call time; wrapping a
        # function *definition* does not disable autograd for later calls.
        with torch.no_grad():
            obs = torch.as_tensor(replay_buffer.normalize_obs(o), dtype=torch.float32, device=device)
            a = popsan(obs, 1).to('cpu').numpy()
        o, r, d, _ = env.step(np.clip(a, -act_limit, act_limit))
        ep_ret += r
        ep_len += 1
    return ep_ret, ep_len


def test_best_models(env_name, model_num, result_dir, spike_ts=5, encoder_pop_dim=10, decoder_pop_dim=10,
                     std=math.sqrt(0.15), max_ep_len=1000, save_freq=5, device='cpu', render=True):
    """Load the best checkpoint of each trained PopSAN model and run one test episode with it.

    For every model index ``m`` in ``range(model_num)`` this reads
    ``<result_dir>/model<m>_test_rewards.p``, picks the saved epoch with the
    highest test reward, loads ``model<m>_e<epoch>.pt`` and its observation
    normalization statistics ``model<m>_e<epoch>_mean_var.p``, and rolls out
    one episode of at most ``max_ep_len`` steps.

    Args:
        env_name: gym environment id, e.g. ``"HalfCheetah-v3"``.
        model_num: number of independently trained models to evaluate.
        result_dir: directory holding the pickled rewards and checkpoints.
        spike_ts, encoder_pop_dim, decoder_pop_dim, std: PopSpikeActor
            hyper-parameters; they must match the values used for training.
        max_ep_len: maximum number of environment steps per test episode.
        save_freq: checkpoint interval in epochs (5 reproduces the original search).
        device: torch device (or device string) used for inference.
        render: whether to call ``env.render()`` on every step.

    Returns:
        ``(test_returns, best_epochs, best_train_rewards)``: three lists with one
        entry per model. They hold the episode return obtained in this test, the
        selected checkpoint epoch, and the recorded test reward of that epoch.

    Raises:
        ValueError: if a model's reward history covers no checkpoint epoch.
    """
    device = torch.device(device) if isinstance(device, str) else device

    # Initialize test environment
    env = gym.make(env_name)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    act_limit = env.action_space.high[0]

    test_returns, best_epochs, best_train_rewards = [], [], []
    try:
        for m in range(model_num):
            prefix = f'{result_dir}/model{m}'
            # NOTE: pickle can execute arbitrary code; only load result files you produced yourself.
            with open(prefix + '_test_rewards.p', 'rb') as f:
                test_reward, _ = pickle.load(f)
            best_epoch_idx, best_epoch_reward = _select_best_epoch(test_reward, save_freq)
            print("Train Model: ", m, " Best Epoch: ", best_epoch_idx, " Reward: ", best_epoch_reward)
            model_dir = f'{prefix}_e{best_epoch_idx}.pt'
            buffer_dir = f'{prefix}_e{best_epoch_idx}_mean_var.p'

            # Replay buffer is used only for its z-score observation normalization,
            # restored from the saved mean/var of the selected checkpoint.
            with open(buffer_dir, 'rb') as f:
                b_mean_var = pickle.load(f)
            replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=1,
                                         clip_limit=3, norm_update_every=1)
            replay_buffer.mean = b_mean_var[0]
            replay_buffer.var = b_mean_var[1]

            # PopSAN
            popsan = PopSpikeActor(obs_dim, act_dim, encoder_pop_dim, decoder_pop_dim, (256, 256),
                                   (-3, 3), std, spike_ts, act_limit, device, False)
            # map_location lets checkpoints saved from a CUDA run load on a CPU-only machine.
            popsan.load_state_dict(torch.load(model_dir, map_location=device))
            popsan.to(device)
            print(model_dir)

            ep_ret, ep_len = _run_episode(env, popsan, replay_buffer, act_limit, device, max_ep_len, render)
            print("Test Model: ", m, " Return: ", ep_ret, " Length: ", ep_len)
            test_returns.append(ep_ret)
            best_epochs.append(best_epoch_idx)
            best_train_rewards.append(best_epoch_reward)
    finally:
        env.close()
    return test_returns, best_epochs, best_train_rewards

if __name__ == '__main__':
    # Raw string: the Windows path contains backslash sequences such as "\p" and
    # "\s" that are invalid escapes in a normal string literal (SyntaxWarning on 3.12+).
    data_dir = r'E:\pop-spiking-deep-rl-main\popsan_drl\popsan_td3\params\spike-td3_td3-popsan-HalfCheetah-v3-encoder-dim-10-decoder-dim-10'
    env_name = "HalfCheetah-v3"
    model_num = 10
    r_list, i_list, in_mem = test_best_models(env_name, model_num, data_dir)

About training PopSAN with the TD3 algorithm

Hi, I'm interested in your work, but some parts of td3_cuda_norm.py confuse me.
I don't quite understand what the parameters “encoder_var” and “mean_range” mean, or why you chose 0.15 and (-3, 3) as their values.
Thank you very much!

if __name__ == '__main__':
    import math
    import argparse
    import torch

    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Ant-v3')
    parser.add_argument('--encoder_pop_dim', type=int, default=10)
    parser.add_argument('--decoder_pop_dim', type=int, default=10)
    # encoder_var is a variance: AC_KWARGS below passes sqrt(encoder_var) as the encoder `std`.
    parser.add_argument('--encoder_var', type=float, default=0.15)
    parser.add_argument('--start_model_idx', type=int, default=0)
    parser.add_argument('--num_model', type=int, default=10)
    parser.add_argument('--epochs', type=int, default=100)
    args = parser.parse_args()

    START_MODEL = args.start_model_idx
    NUM_MODEL = args.num_model
    USE_POISSON = False
    if args.env == 'Hopper-v3' or args.env == 'Walker2d-v3':
        USE_POISSON = True
    AC_KWARGS = dict(hidden_sizes=[256, 256],
                     encoder_pop_dim=args.encoder_pop_dim,
                     decoder_pop_dim=args.decoder_pop_dim,
                     # NOTE(review): presumably the interval over which the encoder
                     # population means are spread; (-3, 3) appears to match a z-score
                     # observation clip of +/-3 -- confirm against PopSpikeActor.
                     mean_range=(-3, 3),
                     std=math.sqrt(args.encoder_var),
                     spike_ts=5,
                     device=torch.device('cuda'),
                     use_poisson=USE_POISSON)
    COMMENT = "td3-popsan-" + args.env + "-encoder-dim-" + str(AC_KWARGS['encoder_pop_dim']) + \
              "-decoder-dim-" + str(AC_KWARGS['decoder_pop_dim'])

About the RateSAN model

Hi Guangzhi,
Thank you very much for sharing your work; I am very interested in it. However, the training code for the RateSAN model on MuJoCo is not public. I know it comes from your IROS paper, but our own implementation does not perform well. Could you please share your RateSAN training and test code? Thank you for your help.

loihi_realization test_lohi.py

image

WARNING:DRV: elementType would be deprecated in 0.9 in favor of messageSize, which provides more flexibility
WARNING:DRV: elementType would be deprecated in 0.9 in favor of messageSize, which provides more flexibility
INFO:DRV: SLURM is being run in background
INFO:HST: srun: error: Unable to allocate resources: Invalid node name specified
WARNING:DRV: Connection is taking longer than usual.
WARNING:DRV: Boards might be busy or consider reasons below:
WARNING:DRV: 1. If you are working on INRC Cloud, ensure setting SLURM=1.
WARNING:DRV: 2. Run sinfo to check if all boards are down.
WARNING:DRV: 3. Check squeue below for any unfinished jobs. Run scancel to cancel hung jobs.
WARNING:DRV: JOBID NAME PARTITION TIME NODELIST(REASON) USER

Do you know how to check or satisfy the first two items in the list above?

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.