I would like to test the code with more hidden layers. I tried the following code:
# weights initialization
nb_inputs = 28*28
nb_hidden = 100
nb_hidden2 = 50
nb_outputs = 10
weight_scale = 0.2
w1 = torch.empty((nb_inputs, nb_hidden), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(w1, mean=0.0, std=weight_scale/np.sqrt(nb_inputs))
wh = torch.empty((nb_hidden, nb_hidden2), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(wh, mean=0.0, std=weight_scale/np.sqrt(nb_hidden))
w2 = torch.empty((nb_hidden2, nb_outputs), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(w2, mean=0.0, std=weight_scale/np.sqrt(nb_hidden2))
print("init done")
def run_snn_n(inputs):
    h1 = torch.einsum("abc,cd->abd", (inputs, w1))
    syn = torch.zeros((batch_size, nb_hidden), device=device, dtype=dtype)
    mem = torch.zeros((batch_size, nb_hidden), device=device, dtype=dtype)
    mem_rec = []
    spk_rec = []

    # Compute first hidden layer activity
    for t in range(nb_steps):
        mthr = mem - 1.0
        out = spike_fn(mthr)
        rst = out.detach()  # We do not want to backprop through the reset

        new_syn = alpha*syn + h1[:,t]
        new_mem = (beta*mem + syn)*(1.0-rst)

        mem_rec.append(mem)
        spk_rec.append(out)

        mem = new_mem
        syn = new_syn

    mem_rec = torch.stack(mem_rec, dim=1)
    spk_rec = torch.stack(spk_rec, dim=1)
    # Second hidden layer
    ht = torch.einsum("abc,cd->abd", (spk_rec, wh))
    syn = torch.zeros((batch_size, nb_hidden2), device=device, dtype=dtype)
    mem = torch.zeros((batch_size, nb_hidden2), device=device, dtype=dtype)
    mem_rec = []
    spk_rec = []

    # Compute second hidden layer activity
    for t in range(nb_steps):
        mthr = mem - 1.0
        out = spike_fn(mthr)
        rst = out.detach()  # We do not want to backprop through the reset

        new_syn = alpha*syn + ht[:,t]
        new_mem = (beta*mem + syn)*(1.0-rst)

        mem_rec.append(mem)
        spk_rec.append(out)

        mem = new_mem
        syn = new_syn

    mem_rec = torch.stack(mem_rec, dim=1)
    spk_rec = torch.stack(spk_rec, dim=1)
    # Readout layer
    h2 = torch.einsum("abc,cd->abd", (spk_rec, w2))
    flt = torch.zeros((batch_size, nb_outputs), device=device, dtype=dtype)
    out = torch.zeros((batch_size, nb_outputs), device=device, dtype=dtype)
    out_rec = [out]
    for t in range(nb_steps):
        new_flt = alpha*flt + h2[:,t]
        new_out = beta*out + flt

        flt = new_flt
        out = new_out

        out_rec.append(out)

    out_rec = torch.stack(out_rec, dim=1)
    other_recs = [mem_rec, spk_rec]
    return out_rec, other_recs
def train(x_data, y_data, lr=1e-3, nb_epochs=10):
    params = [w1, wh, w2]
    ...
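To see where the gradient stops, the flow into each weight matrix can be inspected after a single backward pass. The snippet below is only a rough sketch: the dummy inputs and the placeholder loss are purely for illustration, and it assumes that spike_fn, alpha, beta, batch_size, nb_steps, device and dtype from the original code are still defined.

# Gradient-flow check (illustrative sketch only, not the real training loop)
dummy_inputs = torch.zeros((batch_size, nb_steps, nb_inputs), device=device, dtype=dtype)
dummy_inputs.bernoulli_(0.1)  # random spike trains with 10% firing probability

out_rec, _ = run_snn_n(dummy_inputs)
m, _ = torch.max(out_rec, 1)  # max readout potential over time
loss = m.sum()                # placeholder scalar, not the actual loss
loss.backward()

for name, w in [("w1", w1), ("wh", wh), ("w2", w2)]:
    grad_max = None if w.grad is None else w.grad.abs().max().item()
    print(name, "max |grad| =", grad_max)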
But the gradients do not seem to propagate back across the layers.
Can you please tell me where the problem is?