Reinforcement Learning: Proximal Policy Optimization (PPO)

Author: lucifer2859

Translator: LiutaoYu

This tutorial uses a spiking neural network to reproduce ppo.py. Please make sure that you have read the original tutorial and the corresponding code before proceeding.

Here, we apply the same method as in the previous DQN tutorial to make the SNN output floating-point numbers: we set a neuron's firing threshold to infinity so that it never fires, and we use its final membrane potential to represent the continuous network output (here the state value and the action distribution, rather than the Q function of the DQN tutorial). Implementing such a neuron in the SpikingJelly framework is convenient: simply inherit everything from the LIF neuron neuron.LIFNode and override the forward function.

class NonSpikingLIFNode(neuron.LIFNode):
    def forward(self, dv: torch.Tensor):
        # Only charge the membrane potential; firing and resetting are skipped,
        # so the neuron never spikes and self.v simply integrates the input.
        self.neuronal_charge(dv)
        # self.neuronal_fire()
        # self.neuronal_reset()
        return self.v
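
As a quick check of this neuron's behaviour, here is a minimal usage sketch (illustration only; the constant input, its shape and the number of simulation steps are arbitrary): the same input is fed for several time steps and the accumulated membrane potential is read out.

import torch
from spikingjelly.clock_driven import functional

node = NonSpikingLIFNode(tau=2.0)
x = torch.ones(1, 4)          # arbitrary constant input
for t in range(16):           # simulate 16 time steps
    v = node(x)               # forward() returns the membrane potential
print(v)                      # the final potential is the "analog" output
functional.reset_net(node)    # clear the state before the next sequence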

The basic structure of the spiking Actor-Critic network is very simple: an input layer, an IF neuron layer, and a NonSpikingLIF neuron layer, connected by fully connected (linear) layers. The IF neuron layer acts as an encoder that converts CartPole's state variables into spikes, while the NonSpikingLIF neuron layer can be regarded as the decision-making unit.

class ActorCritic(nn.Module):
    def __init__(self, num_inputs, num_outputs, hidden_size, T=16, std=0.0):
        super(ActorCritic, self).__init__()

        self.critic = nn.Sequential(
            nn.Linear(num_inputs, hidden_size),
            neuron.IFNode(),
            nn.Linear(hidden_size, 1),
            NonSpikingLIFNode(tau=2.0)
        )

        self.actor = nn.Sequential(
            nn.Linear(num_inputs, hidden_size),
            neuron.IFNode(),
            nn.Linear(hidden_size, num_outputs),
            NonSpikingLIFNode(tau=2.0)
        )

        self.log_std = nn.Parameter(torch.ones(1, num_outputs) * std)

        self.T = T

    def forward(self, x):
        # Simulate the network for T time steps; the NonSpikingLIF output neurons
        # integrate the input, and their final membrane potentials are read out below.
        for t in range(self.T):
            self.critic(x)
            self.actor(x)
        value = self.critic[-1].v
        mu    = self.actor[-1].v
        std   = self.log_std.exp().expand_as(mu)
        dist  = Normal(mu, std)
        return dist, value
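
Before moving on to training, here is a minimal usage sketch of the network (illustration only; the dimensions are assumptions for CartPole-v0 with 4 observation variables and 2 actions, and hidden_size, T and the batch size are arbitrary).

import torch
from spikingjelly.clock_driven import functional

model = ActorCritic(num_inputs=4, num_outputs=2, hidden_size=64, T=16)

state = torch.rand(8, 4)       # a batch of 8 dummy observations
dist, value = model(state)     # dist: Normal over per-action scores, value: state values
action = dist.sample()         # shape [8, 2]; argmax over dim 1 gives the discrete action
functional.reset_net(model)    # reset all membrane potentials before the next forward pass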

Training the network

The code of this part is almost the same as the ANN version. Note that the SNN version here directly takes the observation returned by env as the network input.

The following is the training code of the SNN version. During training, we save the model parameters that achieve the largest reward.

# Generalized Advantage Estimation (GAE)
def compute_gae(next_value, rewards, masks, values, gamma=0.99, tau=0.95):
    values = values + [next_value]
    gae = 0
    returns = []
    for step in reversed(range(len(rewards))):
        delta = rewards[step] + gamma * values[step + 1] * masks[step] - values[step]
        gae = delta + gamma * tau * masks[step] * gae
        returns.insert(0, gae + values[step])
    return returns
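
# Illustration only (not part of the training script): a tiny hand-checkable call
# of compute_gae with arbitrary numbers (plain floats behave like the tensors used below).
#   compute_gae(next_value=0.3, rewards=[1.0, 1.0], masks=[1.0, 0.0], values=[0.5, 0.4])
#   # -> [~1.96, 1.0]; the final mask of 0 stops bootstrapping from next_value,
#   #    so the last return is simply the last reward.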

# Proximal Policy Optimization (PPO)
# arXiv: https://arxiv.org/abs/1707.06347
def ppo_iter(mini_batch_size, states, actions, log_probs, returns, advantage):
    batch_size = states.size(0)
    ids = np.random.permutation(batch_size)
    ids = np.split(ids[:batch_size // mini_batch_size * mini_batch_size], batch_size // mini_batch_size)
    for i in range(len(ids)):
        yield states[ids[i], :], actions[ids[i], :], log_probs[ids[i], :], returns[ids[i], :], advantage[ids[i], :]
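
# Illustration only (not part of the training script): with, e.g., 128 collected
# transitions and mini_batch_size=64, ppo_iter yields 2 shuffled mini-batches of
# 64 rows each; any remainder that does not fill a full mini-batch is dropped.
#   for s, a, lp, r, adv in ppo_iter(64, states, actions, log_probs, returns, advantage):
#       ...  # one gradient update per mini-batch, see ppo_update below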

def ppo_update(ppo_epochs, mini_batch_size, states, actions, log_probs, returns, advantages, clip_param=0.2):
    for _ in range(ppo_epochs):
        for state, action, old_log_probs, return_, advantage in ppo_iter(mini_batch_size, states, actions, log_probs, returns, advantages):
            dist, value = model(state)
            functional.reset_net(model)
            entropy = dist.entropy().mean()
            new_log_probs = dist.log_prob(action)

            ratio = (new_log_probs - old_log_probs).exp()
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantage

            actor_loss  = - torch.min(surr1, surr2).mean()
            critic_loss = (return_ - value).pow(2).mean()

            loss = 0.5 * critic_loss + actor_loss - 0.001 * entropy

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

while step_idx < max_steps:

    log_probs = []
    values    = []
    states    = []
    actions   = []
    rewards   = []
    masks     = []
    entropy = 0

    for _ in range(num_steps):
        state = torch.FloatTensor(state).to(device)
        dist, value = model(state)
        functional.reset_net(model)

        action = dist.sample()
        next_state, reward, done, _ = envs.step(torch.max(action, 1)[1].cpu().numpy())

        log_prob = dist.log_prob(action)
        entropy += dist.entropy().mean()

        log_probs.append(log_prob)
        values.append(value)
        rewards.append(torch.FloatTensor(reward).unsqueeze(1).to(device))
        masks.append(torch.FloatTensor(1 - done).unsqueeze(1).to(device))

        states.append(state)
        actions.append(action)

        state = next_state
        step_idx += 1

        if step_idx % 100 == 0:
            test_reward = test_env()
            print('Step: %d, Reward: %.2f' % (step_idx, test_reward))
            writer.add_scalar('Spiking-PPO-' + env_name + '/Reward', test_reward, step_idx)

    next_state = torch.FloatTensor(next_state).to(device)
    _, next_value = model(next_state)
    functional.reset_net(model)
    returns = compute_gae(next_value, rewards, masks, values)

    returns   = torch.cat(returns).detach()
    log_probs = torch.cat(log_probs).detach()
    values    = torch.cat(values).detach()
    states    = torch.cat(states)
    actions   = torch.cat(actions)
    advantage = returns - values

    ppo_update(ppo_epochs, mini_batch_size, states, actions, log_probs, returns, advantage)
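
The test_env() function called every 100 steps above is defined in the full script; a minimal sketch of what it might look like is given below (illustration only: the single non-vectorized evaluation environment env, the device, and the action selection mirroring the training loop are assumptions).

def test_env():
    # Run one episode in a single CartPole environment and return its total reward.
    state = env.reset()
    done = False
    total_reward = 0.0
    while not done:
        state = torch.FloatTensor(state).unsqueeze(0).to(device)
        dist, _ = model(state)
        functional.reset_net(model)  # clear the SNN state after every forward pass
        action = torch.max(dist.sample(), 1)[1].cpu().numpy()[0]
        state, reward, done, _ = env.step(action)
        total_reward += reward
    return total_reward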

It should be emphasized that we need to reset the network after each forward pass: an SNN is stateful (its membrane potentials persist between calls), while each inference should start from a clean network state.

The complete script can be found at clock_driven/examples/Spiking_PPO.py. We can start the training process from the command line as follows.

$ python Spiking_PPO.py

Performance comparison between ANN and SNN

Here is the reward curve obtained during 1e5 training steps:

../_images/Spiking-PPO-CartPole-v0.svg

And here is the result of the ANN version with the same settings; the complete code can be found at clock_driven/examples/PPO.py.

../_images/PPO-CartPole-v0.svg