From c053adef80688fe1fd2e03b5689826bf10662cca Mon Sep 17 00:00:00 2001 From: bzs666 <74452652+bzs666@users.noreply.github.com> Date: Thu, 17 Oct 2024 17:21:21 +0800 Subject: [PATCH] Update deep_q_network.py I hope that the reward system will further provide specific rewards and punishments for each movement of the bird, and encourage the bird not only to successfully pass the next level, but also to try to be smooth (not too fast). --- deep_q_network.py | 50 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/deep_q_network.py b/deep_q_network.py index 1294f96..91acb77 100755 --- a/deep_q_network.py +++ b/deep_q_network.py @@ -8,8 +8,11 @@ import wrapped_flappy_bird as game import random import numpy as np +import numpy as np +import cvlib as cv from collections import deque +BOUNDARY = 0.038 # 表示连续4帧中,首尾帧中小鸟y轴变化距离与图片高的比值,高于这个比值则需要惩罚 GAME = 'bird' # the name of the game being played for log files ACTIONS = 2 # number of valid actions GAMMA = 0.99 # decay rate of past observations @@ -21,6 +24,51 @@ BATCH = 32 # size of minibatch FRAME_PER_ACTION = 1 +def get_highest_conf_bird_center(image): + """ + 处理图像并获取置信度最高的 'bird' 标签的中心坐标。 + 如果没有找到 'bird' 标签,返回 (0, 0)。 + """ + # 调整图像大小为 480x480,转换为 BGR 确保图像数据类型是 uint8 + resized_image = cv2.resize(image, (480, 480)) + bgr_image = cv2.cvtColor(resized_image, cv2.COLOR_RGB2BGR) + bgr_image = bgr_image.astype(np.uint8) + + # 进行物体检测 + bbox, label, conf = cv.detect_common_objects(bgr_image) + + max_conf = 0 + center_x, center_y = 0, 0 + # 遍历检测到的物体,找到置信度最高的 'bird' + for i in range(len(label)): + if label[i] == 'bird': + if conf[i] > max_conf: + max_conf = conf[i] + # 获取边界框并计算中心坐标 + x1, y1, x2, y2 = bbox[i] + center_x = (x1 + x2) / 2 + center_y = (y1 + y2) / 2 + + return center_y + +def new_reward(first_frame, last_frame): + """ + 计算两帧图像中 'bird' 标签的纵坐标变化。 + 并根据变化率,得到具体奖惩数值。 + """ + # 获取 first_frame last_frame 中置信度最高的 'bird' 中心纵坐标 + center_y_first = get_highest_conf_bird_center(first_frame) + center_y_last = get_highest_conf_bird_center(last_frame) + + # 计算纵坐标变化,取绝对值并除以 480 + vertical_change = abs(center_y_first - center_y_last) / 480 + if vertical_change >= BOUNDARY: # 小鸟移动速度过快,实施惩罚与速度正相关 + reward = -vertical_change + else: # 鼓励小鸟平稳缓慢移动,实施奖励与速度负相关 + reward = BOUNDARY - vertical_change + + return reward + def weight_variable(shape): initial = tf.truncated_normal(shape, stddev = 0.01) return tf.Variable(initial) @@ -136,6 +184,8 @@ def trainNetwork(s, readout, h_fc1, sess): # run the selected action and observe next state and reward x_t1_colored, r_t, terminal = game_state.frame_step(a_t) + # 细化奖励 + r_t = new_reward(x_t1_colored, s_t[:, :, -1]) x_t1 = cv2.cvtColor(cv2.resize(x_t1_colored, (80, 80)), cv2.COLOR_BGR2GRAY) ret, x_t1 = cv2.threshold(x_t1, 1, 255, cv2.THRESH_BINARY) x_t1 = np.reshape(x_t1, (80, 80, 1))