diff --git a/deep_q_network.py b/deep_q_network.py
index 1294f96..91acb77 100755
--- a/deep_q_network.py
+++ b/deep_q_network.py
@@ -8,8 +8,11 @@
 import wrapped_flappy_bird as game
 import random
 import numpy as np
+import cvlib as cv
 from collections import deque
 
+BOUNDARY = 0.038 # vertical move of the bird between the first and last of 4 consecutive frames, as a fraction of the image height; larger moves are penalized
+DETECT_SIZE = 480 # side length in pixels that frames are resized to before object detection
 GAME = 'bird' # the name of the game being played for log files
 ACTIONS = 2 # number of valid actions
 GAMMA = 0.99 # decay rate of past observations
@@ -21,6 +24,70 @@
 BATCH = 32 # size of minibatch
 FRAME_PER_ACTION = 1
 
+def _find_bird_center_y(image):
+    """Return the vertical center of the most confident 'bird' detection.
+
+    The frame is resized to DETECT_SIZE x DETECT_SIZE before detection, so the
+    result is expressed in that coordinate system.
+
+    Args:
+        image: a 2-D (single-channel) or 3-channel frame, or None.
+
+    Returns:
+        The center y of the highest-confidence 'bird' box, or None when
+        `image` is None or no 'bird' is detected.
+    """
+    if image is None:
+        return None
+
+    frame = cv2.resize(np.asarray(image, dtype=np.uint8),
+                       (DETECT_SIZE, DETECT_SIZE))
+    if frame.ndim == 2:
+        # A single-channel frame (the caller passes a 2-D slice of the state
+        # stack) cannot go through COLOR_RGB2BGR; expand it to 3 channels.
+        frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
+    else:
+        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+
+    bbox, label, conf = cv.detect_common_objects(frame)
+
+    birds = [(c, box) for box, name, c in zip(bbox, label, conf)
+             if name == 'bird']
+    if not birds:
+        return None
+    # max() returns the first of equally confident boxes.
+    _, (_x1, y1, _x2, y2) = max(birds, key=lambda item: item[0])
+    # 2.0 forces true division even without `from __future__ import division`.
+    return (y1 + y2) / 2.0
+
+def get_highest_conf_bird_center(image):
+    """Return the center y of the most confident 'bird' detection, or 0 if none.
+
+    Use _find_bird_center_y() to tell "no bird" apart from a bird at y == 0.
+    """
+    center_y = _find_bird_center_y(image)
+    return 0 if center_y is None else center_y
+
+def new_reward(first_frame, last_frame):
+    """Return a shaping reward from the bird's vertical movement between frames.
+
+    The displacement is normalized by DETECT_SIZE. Below BOUNDARY the reward is
+    positive and grows as movement shrinks; at or above BOUNDARY it is the
+    negated displacement. Returns 0.0 when the bird is not detected in either
+    frame, since no displacement can be measured.
+    """
+    y_first = _find_bird_center_y(first_frame)
+    y_last = _find_bird_center_y(last_frame)
+    if y_first is None or y_last is None:
+        return 0.0
+
+    vertical_change = abs(y_first - y_last) / float(DETECT_SIZE)
+    if vertical_change >= BOUNDARY:
+        # Moving too fast: the penalty grows with the displacement.
+        return -vertical_change
+    # Smooth movement: the reward grows as the displacement shrinks.
+    return BOUNDARY - vertical_change
+
 def weight_variable(shape):
     initial = tf.truncated_normal(shape, stddev = 0.01)
     return tf.Variable(initial)
@@ -136,6 +203,9 @@ def trainNetwork(s, readout, h_fc1, sess):
 
         # run the selected action and observe next state and reward
         x_t1_colored, r_t, terminal = game_state.frame_step(a_t)
+        # Refine the reward: add the smoothness term on top of the reward
+        # returned by frame_step() instead of discarding it.
+        r_t += new_reward(x_t1_colored, s_t[:, :, -1])
         x_t1 = cv2.cvtColor(cv2.resize(x_t1_colored, (80, 80)), cv2.COLOR_BGR2GRAY)
         ret, x_t1 = cv2.threshold(x_t1, 1, 255, cv2.THRESH_BINARY)
         x_t1 = np.reshape(x_t1, (80, 80, 1))