class QNetwork(nn.Module):
    """Convolutional Q-value network over stacked image frames.

    The model is three convolution + ReLU stages followed by two linear
    layers; the final layer emits one value per action of the module-level
    ``env``.

    Args:
        frames: Number of input channels fed to the first convolution
            (default 4).
    """

    def __init__(self, frames=4):
        super(QNetwork, self).__init__()
        self.network = nn.Sequential(
            # NOTE(review): Scale is defined elsewhere; presumably it multiplies
            # the input by 1/255 to normalise 8-bit pixel values -- confirm.
            Scale(1/255),
            nn.Conv2d(frames, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            # 3136 = 64 * 7 * 7, the flattened size this conv stack produces
            # for 84x84 inputs; other spatial sizes will not match this layer.
            nn.Linear(3136, 512),
            nn.ReLU(),
            # NOTE(review): Linear0 and env are resolved from module scope at
            # construction time; env must exist before instantiating.
            Linear0(512, env.action_space.n)
        )

    def forward(self, x, device=None):
        """Run ``x`` through the network and return its output.

        The class previously defined ``forward`` twice; only the second
        (two-argument) definition was live, so ``net(x)`` failed with a
        missing-argument error. This single method accepts both call shapes.

        Args:
            x: Input data accepted by ``torch.Tensor(...)``.
            device: Device to move the input to. When omitted, the device of
                this module's first parameter is used (CPU if the module has
                no parameters), so the input always lands where the weights
                live.

        Returns:
            The output of ``self.network`` for the converted input.
        """
        if device is None:
            first_param = next(self.parameters(), None)
            device = (
                first_param.device
                if first_param is not None
                else torch.device("cpu")
            )
        x = torch.Tensor(x).to(device)
        return self.network(x)
It would be great if someone were willing to take some time to refactor the files. This problem is present in almost all of them.