DDQN代码实现
代码分析
def replay(self):
for _ in range(10):
states, actions, rewards, next_states, done = self.buffer.sample()
next_target = self.target_model(next_states).numpy()
next_q_value = next_target[
range(args.batch_size), np.argmax(self.model(next_states), axis=1)
]
target[range(args.batch_size), actions] = rewards + (1 - done) * args.gamma * next_q_value
with tf.GradientTape() as tape:
q_pred = self.model(states)
loss = tf.losses.mean_squared_error(target, q_pred)
grads = tape.gradient(loss, self.model.trainable_weights)
self.model_optim.apply_gradients(zip(grads, self.model.trainable_weights))
- DDQN的实现和DQN只有求next_q_value这一行代码不同。在DDQN中,是使用Q网络来选择最优动作,再使用target网络来计算下一时刻的Q值,而在DQN with Target中两步使用的都是target网络。
- np.argmax(self.model(next_states), axis=1)语句的作用就是使用Q网络选择batch_size个state的最优动作,并返回一个长batch_size的一维数组,数组的每个值对应其中一个state的最有动作。
训练结果
1000次
DQN with Target代码实现