强化学习A2C¶
本教程作者:lucifer2859
本节教程使用SNN重新实现 actor-critic.py。 请确保你已经阅读了原版代码以及相关论文,因为本教程是对原代码的扩展。
状态输入
同DQN一样我们使用另一种常用的使SNN输出浮点值的方法:将神经元的阈值设置成无穷大,使其不发放脉冲,用神经元最后时刻的电压作为输出值。神经元实现这
种神经元非常简单,只需要继承已有神经元,重写 forward
函数即可。LIF神经元的电压不像IF神经元那样是简单的积分,因此我们使用LIF
神经元来改写:
class NonSpikingLIFNode(neuron.LIFNode):
def forward(self, dv: torch.Tensor):
self.neuronal_charge(dv)
# self.neuronal_fire()
# self.neuronal_reset()
return self.v
接下来,搭建我们的Spiking Actor-Critic Network,网络的结构非常简单,全连接-IF神经元-全连接-NonSpikingLIF神经元,全连接-IF神经元起到 编码器的作用,而全连接-NonSpikingLIF神经元则可以看作一个决策器:
class ActorCritic(nn.Module):
def __init__(self, num_inputs, num_outputs, hidden_size, T=16):
super(ActorCritic, self).__init__()
self.critic = nn.Sequential(
nn.Linear(num_inputs, hidden_size),
neuron.IFNode(),
nn.Linear(hidden_size, 1),
NonSpikingLIFNode(tau=2.0)
)
self.actor = nn.Sequential(
nn.Linear(num_inputs, hidden_size),
neuron.IFNode(),
nn.Linear(hidden_size, num_outputs),
NonSpikingLIFNode(tau=2.0)
)
self.T = T
def forward(self, x):
for t in range(self.T):
self.critic(x)
self.actor(x)
value = self.critic[-1].v
probs = F.softmax(self.actor[-1].v, dim=1)
dist = Categorical(probs)
return dist, value
训练网络¶
训练部分的代码,与ANN版本几乎相同,使用env返回的Observation作为输入。
SNN的训练代码如下,我们会保存训练过程中使得奖励最大的模型参数:
while step_idx < max_steps:
log_probs = []
values = []
rewards = []
masks = []
entropy = 0
for _ in range(num_steps):
state = torch.FloatTensor(state).to(device)
dist, value = model(state)
functional.reset_net(model)
action = dist.sample()
next_state, reward, done, _ = envs.step(action.cpu().numpy())
log_prob = dist.log_prob(action)
entropy += dist.entropy().mean()
log_probs.append(log_prob)
values.append(value)
rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(device))
masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))
state = next_state
step_idx += 1
if step_idx % 1000 == 0:
test_reward = test_env()
print('Step: %d, Reward: %.2f' % (step_idx, test_reward))
writer.add_scalar('Spiking-A2C-multi_env-' + env_name + '/Reward', test_reward, step_idx)
next_state = torch.FloatTensor(next_state).to(device)
_, next_value = model(next_state)
functional.reset_net(model)
returns = compute_returns(next_value, rewards, masks)
log_probs = torch.cat(log_probs)
returns = torch.cat(returns).detach()
values = torch.cat(values)
advantage = returns - values
actor_loss = - (log_probs * advantage.detach()).mean()
critic_loss = advantage.pow(2).mean()
loss = actor_loss + 0.5 * critic_loss - 0.001 * entropy
optimizer.zero_grad()
loss.backward()
optimizer.step()
另外一个需要注意的地方是,SNN是有状态的,因此每次前向传播后,不要忘了将网络 reset
。
完整的代码可见于 clock_driven/examples/Spiking_A2C.py。可以从命令行直接启动训练:
>>> python Spiking_A2C.py