Source code for spikingjelly.clock_driven.examples.Spiking_DQN_state

import gym
import math
import random
import numpy as np
from collections import namedtuple
from itertools import count
import matplotlib
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from spikingjelly.clock_driven import neuron, functional
import os

from torch.utils.tensorboard import SummaryWriter

Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

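# A minimal usage sketch of ReplayMemory (illustrative only, not part of the
# original script; the tensor shapes are hypothetical and follow the CartPole
# setup used further below):
#
#   memory = ReplayMemory(capacity=10000)
#   memory.push(torch.zeros([1, 4]),              # state
#               torch.tensor([[0]]),              # action
#               torch.zeros([1, 4]),              # next_state (None at episode end)
#               torch.tensor([1.0]))              # reward
#   if len(memory) >= 1:
#       batch = memory.sample(1)                  # list of Transition namedtuples
#       states = torch.cat([tr.state for tr in batch])
#
# push() overwrites the oldest entry once capacity is reached, so the buffer
# always holds at most `capacity` of the most recent transitions.
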
class NonSpikingLIFNode(neuron.LIFNode):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, dv: torch.Tensor):
        self.neuronal_charge(dv)
        # self.neuronal_fire()
        # self.neuronal_reset()
        return self.v

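# NonSpikingLIFNode only charges its membrane potential and never fires or
# resets, so self.v accumulates a leaky integral of the input over the T
# simulation steps. In this version of SpikingJelly, LIFNode.neuronal_charge
# updates the potential roughly as
#   v <- v + (dv - (v - v_reset)) / tau        (assuming v_reset is not None)
# which is why the final v can be read out directly as a real-valued Q-value.
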
# Spiking DQN algorithm
class DQSN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, T=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            neuron.IFNode(),
            nn.Linear(hidden_size, output_size),
            NonSpikingLIFNode(tau=2.0)
        )
        self.T = T

    def forward(self, x):
        for t in range(self.T):
            self.fc(x)
        return self.fc[-1].v

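# How DQSN is queried (a sketch, not part of the original script; the sizes
# below are hypothetical and match the CartPole configuration used later).
# The network is stateful, so every forward pass over T time steps should be
# followed by functional.reset_net(), otherwise membrane potentials from one
# call leak into the next:
#
#   net = DQSN(input_size=4, hidden_size=256, output_size=2, T=16)
#   with torch.no_grad():
#       q_values = net(torch.zeros([1, 4]))    # shape [1, 2], read from the last layer's v
#       action = q_values.max(1)[1].item()
#   functional.reset_net(net)
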
def train(use_cuda, model_dir, log_dir, env_name, hidden_size, num_episodes, seed):
    BATCH_SIZE = 128
    GAMMA = 0.999
    EPS_START = 0.9
    EPS_END = 0.05
    EPS_DECAY = 200
    TARGET_UPDATE = 10
    T = 16

    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

    device = torch.device("cuda" if use_cuda else "cpu")

    steps_done = 0

    # torch.utils.tensorboard.SummaryWriter takes `log_dir`, not `logdir`
    writer = SummaryWriter(log_dir=log_dir)

    env = gym.make(env_name).unwrapped
    env.seed(seed)

    n_states = env.observation_space.shape[0]
    n_actions = env.action_space.n

    policy_net = DQSN(n_states, hidden_size, n_actions, T).to(device)
    target_net = DQSN(n_states, hidden_size, n_actions, T).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()

    optimizer = optim.Adam(policy_net.parameters())
    memory = ReplayMemory(10000)

    def select_action(state, steps_done):
        sample = random.random()
        eps_threshold = EPS_END + (EPS_START - EPS_END) * \
            math.exp(-1. * steps_done / EPS_DECAY)
        if sample > eps_threshold:
            with torch.no_grad():
                ac = policy_net(state).max(1)[1].view(1, 1)
                functional.reset_net(policy_net)
                return ac
        else:
            return torch.tensor([[random.randrange(env.action_space.n)]], device=device, dtype=torch.long)

    def optimize_model():
        if len(memory) < BATCH_SIZE:
            return
        transitions = memory.sample(BATCH_SIZE)
        batch = Transition(*zip(*transitions))

        non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)),
                                      device=device, dtype=torch.bool)
        non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        state_action_values = policy_net(state_batch).gather(1, action_batch)

        next_state_values = torch.zeros(BATCH_SIZE, device=device)
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
        functional.reset_net(target_net)

        expected_state_action_values = (next_state_values * GAMMA) + reward_batch

        loss = F.smooth_l1_loss(state_action_values, expected_state_action_values.unsqueeze(1))

        optimizer.zero_grad()
        loss.backward()
        for param in policy_net.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        optimizer.step()
        functional.reset_net(policy_net)

    max_reward = 0
    max_pt_path = os.path.join(model_dir, f'policy_net_{hidden_size}_max.pt')
    pt_path = os.path.join(model_dir, f'policy_net_{hidden_size}.pt')

    for i_episode in range(num_episodes):
        # Initialize the environment and state
        env.reset()
        state = torch.zeros([1, n_states], dtype=torch.float, device=device)
        total_reward = 0

        for t in count():
            action = select_action(state, steps_done)
            steps_done += 1
            next_state, reward, done, _ = env.step(action.item())
            total_reward += reward
            next_state = torch.from_numpy(next_state).float().to(device).unsqueeze(0)
            reward = torch.tensor([reward], device=device)

            if done:
                next_state = None

            memory.push(state, action, next_state, reward)
            state = next_state

            if done and total_reward > max_reward:
                max_reward = total_reward
                torch.save(policy_net.state_dict(), max_pt_path)
                print(f'max_reward={max_reward}, save models')

            optimize_model()

            if done:
                print(f'Episode: {i_episode}, Reward: {total_reward}')
                writer.add_scalar('Spiking-DQN-state-' + env_name + '/Reward', total_reward, i_episode)
                break

        if i_episode % TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())

    print('complete')
    torch.save(policy_net.state_dict(), pt_path)
    print('state_dict path is', pt_path)
    writer.close()

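# Note on the exploration schedule used in train() above: with EPS_START=0.9,
# EPS_END=0.05 and EPS_DECAY=200, the epsilon-greedy threshold follows
#   eps(steps) = 0.05 + 0.85 * exp(-steps / 200)
# so it is 0.90 at step 0, about 0.36 after 200 steps, and decays toward 0.05
# as training proceeds.
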
def play(use_cuda, pt_path, env_name, hidden_size, played_frames=60, save_fig_num=0, fig_dir=None,
         figsize=(12, 6), firing_rates_plot_type='bar', heatmap_shape=None):
    T = 16
    plt.rcParams['figure.figsize'] = figsize
    plt.ion()
    env = gym.make(env_name).unwrapped

    device = torch.device("cuda" if use_cuda else "cpu")

    n_states = env.observation_space.shape[0]
    n_actions = env.action_space.n

    policy_net = DQSN(n_states, hidden_size, n_actions, T).to(device)
    policy_net.load_state_dict(torch.load(pt_path, map_location=device))

    env.reset()
    state = torch.zeros([1, n_states], dtype=torch.float, device=device)

    with torch.no_grad():
        functional.set_monitor(policy_net, True)
        delta_lim = 0
        over_score = 1e9
        for i in count():
            LIF_v = policy_net(state)  # shape=[1, 2]
            action = LIF_v.max(1)[1].view(1, 1).item()
            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (1, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (1, 0))

            plt.xticks(np.arange(2), ('Left', 'Right'))
            plt.ylabel('Voltage')
            plt.title('Voltage of LIF neurons at last time step')
            delta_lim = (LIF_v.max() - LIF_v.min()) * 0.5
            plt.ylim(LIF_v.min() - delta_lim, LIF_v.max() + delta_lim)
            plt.yticks([])
            plt.text(0, LIF_v[0][0], str(round(LIF_v[0][0].item(), 2)), ha='center')
            plt.text(1, LIF_v[0][1], str(round(LIF_v[0][1].item(), 2)), ha='center')

            plt.bar(np.arange(2), LIF_v.squeeze(), color=['r', 'gray'] if action == 0 else ['gray', 'r'], width=0.5)
            if LIF_v.min() - delta_lim < 0:
                plt.axhline(0, color='black', linewidth=0.1)

            IF_spikes = np.asarray(policy_net.fc[1].monitor['s'])  # shape=[16, 1, 256]
            firing_rates = IF_spikes.mean(axis=0).squeeze()
            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 4), rowspan=2, colspan=5)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 1), rowspan=2, colspan=2)
            plt.title('Firing rates of IF neurons')

            if firing_rates_plot_type == 'bar':
                # plot the firing rates as a bar chart
                plt.xlabel('Neuron index')
                plt.ylabel('Firing rate')
                plt.xlim(0, firing_rates.size)
                plt.ylim(0, 1.01)
                plt.bar(np.arange(firing_rates.size), firing_rates, width=0.5)

            elif firing_rates_plot_type == 'heatmap':
                # plot the firing rates as a heatmap
                heatmap = plt.imshow(firing_rates.reshape(heatmap_shape), vmin=0, vmax=1, cmap='ocean')
                plt.gca().invert_yaxis()
                cbar = heatmap.figure.colorbar(heatmap)
                cbar.ax.set_ylabel('Magnitude', rotation=90, va='top')

            functional.reset_net(policy_net)
            subtitle = f'Position={state[0][0].item(): .2f}, Velocity={state[0][1].item(): .2f}, ' \
                       f'Pole Angle={state[0][2].item(): .2f}, Pole Velocity At Tip={state[0][3].item(): .2f}, Score={i}'

            state, reward, done, _ = env.step(action)

            if done:
                over_score = min(over_score, i)
                subtitle = f'Game over, Score={over_score}'
            plt.suptitle(subtitle)

            state = torch.from_numpy(state).float().to(device).unsqueeze(0)
            screen = env.render(mode='rgb_array').copy()
            screen[300, :, :] = 0  # draw a black horizontal line on the screen
            if firing_rates_plot_type == 'bar':
                plt.subplot2grid((2, 9), (0, 0), colspan=3)
            elif firing_rates_plot_type == 'heatmap':
                plt.subplot2grid((2, 3), (0, 0))
            plt.xticks([])
            plt.yticks([])
            plt.title('Game screen')
            plt.imshow(screen, interpolation='bicubic')
            plt.pause(0.001)

            if i < save_fig_num:
                plt.savefig(os.path.join(fig_dir, f'{i}.png'))

            if done and i >= played_frames:
                env.close()
                plt.close()
                break


'''
train(use_cuda=False, model_dir='./model/CartPole-v0/state', log_dir='./log', env_name='CartPole-v0', \
    hidden_size=256, num_episodes=500, seed=1)
'''
play(use_cuda=False, pt_path='./model/CartPole-v0/policy_net_256_max.pt', env_name='CartPole-v0', \
    hidden_size=256, played_frames=300)