combra-lab / pop-spiking-deep-rl
DRL with population coded spiking neural network for optimal and energy-efficient continuous control.
License: MIT License
```python
import math
import pickle
import sys

import gym
import numpy as np
import torch

sys.path.append('../../')
from popsan_drl.popsan_td3.replay_buffer_norm import ReplayBuffer
from popsan_drl.popsan_td3.popsan import PopSpikeActor


@torch.no_grad()  # disable autograd for the whole test run; .numpy() fails on tensors that require grad
def test_best_models(env_name, model_num, result_dir, spike_ts=5, encoder_pop_dim=10, decoder_pop_dim=10,
                     std=math.sqrt(0.15), max_ep_len=1000):
    # Initialize test environment
    env = gym.make(env_name)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    act_limit = env.action_space.high[0]
    # Set device
    device = torch.device("cpu")
    reward_list = []
    for m in range(model_num):
        with open(result_dir + '/model' + str(m) + '_test_rewards.p', 'rb') as f:
            test_reward, _ = pickle.load(f)
        # Checkpoints exist every 5 epochs (5, 10, ..., 100); pick the one with the HIGHEST test reward
        best_epoch_idx = 5
        best_epoch_reward = -math.inf
        for idx in range(20):
            if test_reward[(idx + 1) * 5 - 1] > best_epoch_reward:
                best_epoch_reward = test_reward[(idx + 1) * 5 - 1]
                best_epoch_idx = (idx + 1) * 5
        print("Train Model: ", m, " Best Epoch: ", best_epoch_idx, " Reward: ", best_epoch_reward)
        model_dir = result_dir + '/model' + str(m) + '_e' + str(best_epoch_idx) + '.pt'
        buffer_dir = result_dir + '/model' + str(m) + '_e' + str(best_epoch_idx) + '_mean_var.p'
        # Load the saved input normalization mean and var into a
        # replay buffer that is only used for running z-score norm
        with open(buffer_dir, "rb") as f:
            b_mean_var = pickle.load(f)
        replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=1,
                                     clip_limit=3, norm_update_every=1)
        replay_buffer.mean = b_mean_var[0]
        replay_buffer.var = b_mean_var[1]
        # PopSAN; map_location lets weights saved from a CUDA run load on the CPU
        popsan = PopSpikeActor(obs_dim, act_dim, encoder_pop_dim, decoder_pop_dim, (256, 256),
                               (-3, 3), std, spike_ts, act_limit, device, False)
        popsan.load_state_dict(torch.load(model_dir, map_location=device))
        print(model_dir)

        def get_action(o):
            a = popsan(torch.as_tensor(o, dtype=torch.float32, device=device), 1).to('cpu').numpy()
            return np.clip(a, -act_limit, act_limit)

        # Run one test episode; take deterministic actions at test time (noise_scale=0)
        o, d, ep_ret, ep_len = env.reset(), False, 0, 0
        while not (d or (ep_len == max_ep_len)):
            env.render()
            o, r, d, _ = env.step(get_action(replay_buffer.normalize_obs(o)))
            ep_ret += r
            ep_len += 1
        reward_list.append(ep_ret)
    return reward_list


if __name__ == '__main__':
    # Raw string so the Windows backslashes are not treated as escape sequences
    data_dir = r'E:\pop-spiking-deep-rl-main\popsan_drl\popsan_td3\params\spike-td3_td3-popsan-HalfCheetah-v3-encoder-dim-10-decoder-dim-10'
    env_name = "HalfCheetah-v3"
    model_num = 10
    reward_list = test_best_models(env_name, model_num, data_dir)
```
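The test script restores the saved `mean` and `var` into a `ReplayBuffer` only so that it can call `normalize_obs` with `clip_limit=3`. As a rough guide to what that step does, here is a minimal NumPy sketch of z-score normalization followed by clipping. It assumes `normalize_obs` computes `(o - mean) / sqrt(var)` and clips to `±clip_limit`; the `eps` term and the standalone function are hypothetical, not taken from `replay_buffer_norm.py`.

```python
import numpy as np


def normalize_obs(o, mean, var, clip_limit=3.0, eps=1e-8):
    """Z-score an observation with saved running statistics, then clip it.

    Sketch only: assumes the same idea as ReplayBuffer.normalize_obs;
    the eps term is a hypothetical guard against zero variance.
    """
    z = (np.asarray(o, dtype=np.float64) - mean) / np.sqrt(var + eps)
    return np.clip(z, -clip_limit, clip_limit)


# Made-up statistics for a 3-dimensional observation
mean = np.array([0.0, 1.0, -2.0])
var = np.array([1.0, 4.0, 0.25])
print(normalize_obs([0.5, 9.0, -2.0], mean, var))
```

In this example the second component has a z-score of (9 - 1) / 2 = 4, so it is clipped to 3; with `clip_limit=3`, every value fed to the encoder stays inside [-3, 3].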
Hi, I'm interested in your work, but I'm confused by part of td3_cuda_norm.py.
I don't quite understand what the parameters “encoder_var” and “mean_range” mean, or why you chose 0.15 and (-3, 3) as their values.
Thank you very much!
```python
if __name__ == '__main__':
    import math
    import argparse
    import torch
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Ant-v3')
    parser.add_argument('--encoder_pop_dim', type=int, default=10)
    parser.add_argument('--decoder_pop_dim', type=int, default=10)
    parser.add_argument('--encoder_var', type=float, default=0.15)  # <-- encoder_var
    parser.add_argument('--start_model_idx', type=int, default=0)
    parser.add_argument('--num_model', type=int, default=10)
    parser.add_argument('--epochs', type=int, default=100)
    args = parser.parse_args()
    START_MODEL = args.start_model_idx
    NUM_MODEL = args.num_model
    USE_POISSON = False
    if args.env == 'Hopper-v3' or args.env == 'Walker2d-v3':
        USE_POISSON = True
    AC_KWARGS = dict(hidden_sizes=[256, 256],
                     encoder_pop_dim=args.encoder_pop_dim,
                     decoder_pop_dim=args.decoder_pop_dim,
                     mean_range=(-3, 3),  # <-- mean_range
                     std=math.sqrt(args.encoder_var),
                     spike_ts=5,
                     device=torch.device('cuda'),
                     use_poisson=USE_POISSON)
    COMMENT = "td3-popsan-" + args.env + "-encoder-dim-" + str(AC_KWARGS['encoder_pop_dim']) + \
              "-decoder-dim-" + str(AC_KWARGS['decoder_pop_dim'])
```
Hi Guangzhi,
Thank you very much for sharing your work; I am very interested in it. However, the training code for the RateSAN model on MuJoCo is not public. I know it comes from your IROS work, but our own implementation does not perform well. Could you share your RateSAN training and test code? Thanks.
WARNING:DRV: elementType would be deprecated in 0.9 in favor of messageSize, which provides more flexibility
WARNING:DRV: elementType would be deprecated in 0.9 in favor of messageSize, which provides more flexibility
INFO:DRV: SLURM is being run in background
INFO:HST: srun: error: Unable to allocate resources: Invalid node name specified
WARNING:DRV: Connection is taking longer than usual.
WARNING:DRV: Boards might be busy or consider reasons below:
WARNING:DRV: 1. If you are working on INRC Cloud, ensure setting SLURM=1.
WARNING:DRV: 2. Run sinfo to check if all boards are down.
WARNING:DRV: 3. Check squeue below for any unfinished jobs. Run scancel to cancel hung jobs.
WARNING:DRV: JOBID NAME PARTITION TIME NODELIST(REASON) USER
Do you know how to take care of the first two items?
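For the first two items in the log, one option is to set the `SLURM` variable in the environment of the process that starts the NxSDK driver, then check board status with the SLURM commands `sinfo` and `squeue` that the warnings mention. The shell equivalent is launching the script as `SLURM=1 python your_script.py`, where the script name is a placeholder. Below is a minimal Python sketch, assuming NxSDK reads `SLURM` from the process environment when it connects. `enable_slurm` and `slurm_status` are hypothetical helper names.

```python
import os
import shutil
import subprocess


def enable_slurm():
    """Set SLURM=1 for this process and any child processes it launches.

    Assumption: the NxSDK driver reads this when it connects, so call it
    before building or running the network.
    """
    os.environ["SLURM"] = "1"


def slurm_status():
    """Return the stdout of `sinfo` and `squeue`, or None if either tool is missing."""
    if shutil.which("sinfo") is None or shutil.which("squeue") is None:
        return None
    results = {}
    for cmd in ("sinfo", "squeue"):
        proc = subprocess.run([cmd], capture_output=True, text=True, check=False)
        results[cmd] = proc.stdout
    return results


if __name__ == "__main__":
    enable_slurm()
    status = slurm_status()
    if status is None:
        print("sinfo/squeue not found on PATH")
    else:
        print(status["sinfo"])   # check whether all boards are reported as down
        print(status["squeue"])  # unfinished jobs listed here can be cancelled with scancel
```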