Source code for spikingjelly.clock_driven.examples.dqn_cart_pole

import gym
import math
import random
from collections import namedtuple
from itertools import count
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from spikingjelly.clock_driven import neuron, functional
import os


Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))
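
# Hedged sketch, not part of the original example: a Transition stores one
# step of experience, with fields accessible by name. The tensors below are
# hypothetical demo values shaped like CartPole data.
def _demo_transition():
    s0 = torch.zeros([1, 4])                    # state: 4 observation values
    a0 = torch.tensor([[0]], dtype=torch.long)  # action: 0 (left) or 1 (right)
    s1 = torch.rand([1, 4])                     # next_state
    r0 = torch.tensor([1.])                     # reward
    tr = Transition(s0, a0, s1, r0)
    assert tr.state is s0 and tr.reward is r0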


class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
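
# Hedged sketch, not part of the original example: ReplayMemory is a fixed-size
# ring buffer, so once `capacity` transitions are stored the oldest ones are
# overwritten. `Transition(*zip(*transitions))` transposes a list of
# Transitions into one Transition of tuples, the same trick `optimize_model`
# uses below. `_demo_replay_memory` is a hypothetical helper.
def _demo_replay_memory():
    memory = ReplayMemory(capacity=4)
    for k in range(6):  # push 6 transitions into a buffer of capacity 4
        memory.push(torch.full([1, 4], float(k)),
                    torch.tensor([[k % 2]], dtype=torch.long),
                    torch.full([1, 4], float(k + 1)),
                    torch.tensor([1.]))
    assert len(memory) == 4  # the two oldest transitions were overwritten
    transitions = memory.sample(2)
    batch = Transition(*zip(*transitions))  # batch.state is a tuple of 2 states
    state_batch = torch.cat(batch.state)
    assert state_batch.shape == (2, 4)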
class NonSpikingLIFNode(neuron.LIFNode):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, dv: torch.Tensor):
        self.neuronal_charge(dv)
        # self.neuronal_fire()
        # self.neuronal_reset()
        return self.v
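
# Hedged sketch, not part of the original example: NonSpikingLIFNode only
# charges and never fires or resets, so repeated calls integrate the input
# into the membrane potential. Assuming the usual clock_driven LIF update
# v <- v + (dv - (v - v_reset)) / tau, the potential rises monotonically
# toward dv for a constant positive input.
def _demo_non_spiking_lif():
    node = NonSpikingLIFNode(tau=2.0)
    dv = torch.ones([1, 2])
    v_prev = 0.
    for t in range(4):
        v = node(dv)              # returns the membrane potential, not spikes
        assert v.min() > v_prev   # keeps charging; there is no reset
        v_prev = v.min().item()
    functional.reset_net(node)    # clear the state before reusing the node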
class DQSN(nn.Module):
    def __init__(self, hidden_num):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, hidden_num),
            neuron.IFNode(),
            nn.Linear(hidden_num, 2),
            NonSpikingLIFNode(tau=2.0)
        )
        self.T = 16

    def forward(self, x):
        for t in range(self.T):
            self.fc(x)

        return self.fc[-1].v
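
# Hedged sketch, not part of the original example: DQSN maps a 4-dimensional
# CartPole observation to 2 Q-values by simulating the network for T=16 time
# steps and reading the membrane potential of the final non-spiking layer.
# The network is stateful, so it must be reset before unrelated inputs.
def _demo_dqsn(device='cpu'):
    net = DQSN(hidden_num=32).to(device)      # hypothetical small network
    state = torch.rand([1, 4], device=device)
    q = net(state)                            # shape [1, 2]: Q(left), Q(right)
    action = q.max(1)[1].item()               # greedy action
    functional.reset_net(net)                 # clear membrane potentials
    return action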
def train(device, root, hidden_num=128, num_episodes=256):
    BATCH_SIZE = 128
    GAMMA = 0.999
    EPS_START = 0.9
    EPS_END = 0.05
    EPS_DECAY = 200
    TARGET_UPDATE = 10

    steps_done = 0

    env = gym.make('CartPole-v0').unwrapped

    policy_net = DQSN(hidden_num).to(device)
    target_net = DQSN(hidden_num).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters())
    memory = ReplayMemory(10000)

    def select_action(state, steps_done):
        sample = random.random()
        eps_threshold = EPS_END + (EPS_START - EPS_END) * \
            math.exp(-1. * steps_done / EPS_DECAY)
        if sample > eps_threshold:
            with torch.no_grad():
                ac = policy_net(state).max(1)[1].view(1, 1)
                functional.reset_net(policy_net)
                return ac
        else:
            return torch.tensor([[random.randrange(2)]], device=device, dtype=torch.long)

    def optimize_model():
        if len(memory) < BATCH_SIZE:
            return
        transitions = memory.sample(BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)),
                                      device=device, dtype=torch.bool)
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        state_action_values = policy_net(state_batch).gather(1, action_batch)

        next_state_values = torch.zeros(BATCH_SIZE, device=device)
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
        functional.reset_net(target_net)

        expected_state_action_values = (next_state_values * GAMMA) + reward_batch

        loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

        optimizer.zero_grad()
        loss.backward()
        for param in policy_net.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        optimizer.step()
        functional.reset_net(policy_net)

    max_duration = 0
    max_pt_path = os.path.join(root, f'policy_net_{hidden_num}_max.pt')
    pt_path = os.path.join(root, f'policy_net_{hidden_num}.pt')

    episode_durations = []

    for i_episode in range(num_episodes):
        # Initialize the environment and state
        env.reset()
        state = torch.zeros([1, 4], dtype=torch.float, device=device)

        for t in count():
            action = select_action(state, steps_done)
            steps_done += 1
            next_state, reward, done, _ = env.step(action.item())
            next_state = torch.from_numpy(next_state).float().to(device).unsqueeze(0)
            reward = torch.tensor([reward], device=device)

            if done:
                next_state = None

            memory.push(state, action, next_state, reward)
            state = next_state

            if done and t + 1 > max_duration:
                max_duration = t + 1
                torch.save(policy_net.state_dict(), max_pt_path)
                print(f'max_duration={max_duration}, save models')

            optimize_model()

            if done:
                print(f'i_episode={i_episode}, duration={t + 1}')
                episode_durations.append(t + 1)
                break

        if i_episode % TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())

    print('complete')
    torch.save(policy_net.state_dict(), pt_path)
    print('state_dict path is', pt_path)
# i_episode=492, duration=169
# i_episode=493, duration=356
# i_episode=494, duration=296
# i_episode=495, duration=301
# i_episode=496, duration=287
# i_episode=497, duration=303
# i_episode=498, duration=285
# i_episode=499, duration=329
# i_episode=500, duration=303
# i_episode=501, duration=389
# i_episode=502, duration=544
# i_episode=503, duration=619
# i_episode=504, duration=443
# i_episode=505, duration=441
# i_episode=506, duration=322
# i_episode=507, duration=575
# i_episode=508, duration=384
# i_episode=509, duration=715
# i_episode=510, duration=3051
# i_episode=511, duration=571
# complete
# state_dict path is ./policy_net_256.pt
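
# Hedged usage sketch, not part of the original example: the log above appears
# to come from a run with hidden_num=256 and a budget of 512 episodes, saving
# state dicts into the current directory; these arguments are an assumption.
def _demo_train():
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    train(device=device, root='./', hidden_num=256, num_episodes=512)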
def play(device, pt_path, hidden_num, played_frames=60, save_fig_num=0, fig_dir=None, figsize=(12, 6), firing_rates_plot_type='bar', heatmap_shape=None):
    import numpy as np
    from matplotlib import pyplot as plt
    import matplotlib.ticker
    plt.rcParams['figure.figsize'] = figsize
    # plt.rcParams['figure.dpi'] = 200
    plt.ion()
    env = gym.make('CartPole-v0').unwrapped

    policy_net = DQSN(hidden_num).to(device)
    policy_net.load_state_dict(torch.load(pt_path, map_location=device))

    env.reset()
    state = torch.zeros([1, 4], dtype=torch.float, device=device)

    with torch.no_grad():
        functional.set_monitor(policy_net, True)
        delta_lim = 0
        over_score = 1e9
        for i in count():
            # plt.clf()
            LIF_v = policy_net(state)  # shape=[1, 2]
            action = LIF_v.max(1)[1].view(1, 1).item()
            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (1, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (1, 0))

            plt.xticks(np.arange(2), ('Left', 'Right'))
            plt.ylabel('Voltage')
            plt.title('Voltage of LIF neurons at last time step')
            delta_lim = (LIF_v.max() - LIF_v.min()) * 0.5
            plt.ylim(LIF_v.min() - delta_lim, LIF_v.max() + delta_lim)
            plt.yticks([])
            plt.text(0, LIF_v[0][0], str(round(LIF_v[0][0].item(), 2)), ha='center')
            plt.text(1, LIF_v[0][1], str(round(LIF_v[0][1].item(), 2)), ha='center')

            plt.bar(np.arange(2), LIF_v.squeeze(), color=['r', 'gray'] if action == 0 else ['gray', 'r'], width=0.5)
            if LIF_v.min() - delta_lim < 0:
                plt.axhline(0, color='black', linewidth=0.1)

            IF_spikes = np.asarray(policy_net.fc[1].monitor['s'])  # shape=[16, 1, 256]
            firing_rates = IF_spikes.mean(axis=0).squeeze()
            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 4), rowspan=2, colspan=5)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 1), rowspan=2, colspan=2)
            plt.title('Firing rates of IF neurons')

            if firing_rates_plot_type == 'bar':
                # plot the firing rates as a bar chart
                plt.xlabel('Neuron index')
                plt.ylabel('Firing rate')
                plt.xlim(0, firing_rates.size)
                plt.ylim(0, 1.01)
                plt.bar(np.arange(firing_rates.size), firing_rates, width=0.5)
            elif firing_rates_plot_type == 'heatmap':
                # plot the firing rates as a heatmap
                heatmap = plt.imshow(firing_rates.reshape(heatmap_shape), vmin=0, vmax=1, cmap='ocean')
                plt.gca().invert_yaxis()
                cbar = heatmap.figure.colorbar(heatmap)
                cbar.ax.set_ylabel('Magnitude', rotation=90, va='top')

            functional.reset_net(policy_net)
            subtitle = f'Position={state[0][0].item(): .2f}, Velocity={state[0][1].item(): .2f}, Pole Angle={state[0][2].item(): .2f}, Pole Velocity At Tip={state[0][3].item(): .2f}, Score={i}'

            state, reward, done, _ = env.step(action)
            if done:
                over_score = min(over_score, i)
                subtitle = f'Game over, Score={over_score}'
            plt.suptitle(subtitle)

            state = torch.from_numpy(state).float().to(device).unsqueeze(0)
            screen = env.render(mode='rgb_array').copy()
            screen[300, :, :] = 0  # draw a black line

            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 0))
            plt.xticks([])
            plt.yticks([])
            plt.title('Game screen')
            plt.imshow(screen, interpolation='bicubic')
            plt.pause(0.001)

            if i < save_fig_num:
                plt.savefig(os.path.join(fig_dir, f'{i}.png'))

            if done and i >= played_frames:
                env.close()
                plt.close()
                break
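
# Hedged usage sketch, not part of the original example: replay a trained
# policy and visualize spikes. `pt_path` must point to a state dict saved by
# `train` with a matching hidden_num; the path below is hypothetical.
def _demo_play():
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    play(device=device, pt_path='./policy_net_256.pt', hidden_num=256,
         played_frames=300, firing_rates_plot_type='bar')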