297 lines
11 KiB
Python
297 lines
11 KiB
Python
import gymnasium as gym
|
||
from gymnasium import spaces
|
||
import numpy as np
|
||
import mss
|
||
import cv2
|
||
import pyautogui
|
||
import time
|
||
from typing import Tuple, Optional, List
|
||
|
||
|
||
def _centroid_from_mask(mask: np.ndarray) -> Tuple[Optional[int], Optional[int], int]:
    """Return the intensity-weighted centroid of a single-channel mask.

    Args:
        mask: Single-channel image; in this module it is always derived from
            ``cv2.inRange`` output, i.e. pixels are 0 or 255.

    Returns:
        ``(cx, cy, count)``, where ``count`` is the number of non-zero pixels.
        ``cx``/``cy`` are the centroid coordinates truncated with ``int``, or
        ``None`` when the mask is empty or its zeroth moment is zero.
    """
    pixel_count = int(cv2.countNonZero(mask))
    if not pixel_count:
        return None, None, 0

    moments = cv2.moments(mask)
    mass = moments["m00"]
    if mass == 0:
        # Defensive guard; a non-negative mask with non-zero pixels should
        # always have a positive zeroth moment.
        return None, None, pixel_count

    centre_x = int(moments["m10"] / mass)
    centre_y = int(moments["m01"] / mass)
    return centre_x, centre_y, pixel_count
|
||
|
||
|
||
def _centroids_from_contours(mask: np.ndarray,
                             ui_exclude_rects: List[Tuple[int, int, int, int]],
                             min_area: int,
                             circ_min: float,
                             aspect_tol: float,
                             extent_min: float,
                             solidity_min: float) -> List[Tuple[int, int, int]]:
    """Return ``(cx, cy, area)`` for every external contour passing the bomb shape filters.

    The input ``mask`` is NOT modified; all work happens on a private copy.

    Args:
        mask: Single-channel binary mask (e.g. ``cv2.inRange`` output).
        ui_exclude_rects: ``(x0, y0, x1, y1)`` pixel rectangles, relative to the
            captured monitor area, that are zeroed before contour search.
            Coordinates are truncated to ``int`` and clamped to the mask bounds.
            A rectangle with ``x0 >= x1`` or ``y0 >= y1`` has no effect.
        min_area: Minimum contour area in pixels.
        circ_min: Minimum circularity ``4*pi*area / perimeter**2`` (1.0 = circle).
        aspect_tol: Bounding-box ``width/height`` must lie in
            ``[1 - aspect_tol, 1 + aspect_tol]``.
        extent_min: Minimum ``area / bounding-box area``.
        solidity_min: Minimum ``area / convex-hull area``.

    Returns:
        One ``(cx, cy, area)`` tuple per accepted contour. ``cx``/``cy`` are
        the contour-moment centroid and ``area`` is the contour area, all
        truncated with ``int``.
    """
    # Work on a private copy: the UI masking below writes into the array, and
    # OpenCV < 3.2's findContours also modified its input in place.
    mask = mask.copy()

    # Zero out UI zones (e.g. score text) so they cannot produce contours.
    if ui_exclude_rects:
        mask_h, mask_w = mask.shape[:2]
        for (x0, y0, x1, y1) in ui_exclude_rects:
            x0 = max(0, min(mask_w, int(x0)))
            x1 = max(0, min(mask_w, int(x1)))
            y0 = max(0, min(mask_h, int(y0)))
            y1 = max(0, min(mask_h, int(y1)))
            mask[y0:y1, x0:x1] = 0

    # OpenCV 3.x returns (image, contours, hierarchy) and 4.x returns
    # (contours, hierarchy); the contour list is second-to-last in both.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

    out: List[Tuple[int, int, int]] = []
    for c in contours:
        area = float(cv2.contourArea(c))
        if area < min_area:
            continue

        # Bounding box must be roughly square.
        _bx, _by, bw, bh = cv2.boundingRect(c)
        if bw == 0 or bh == 0:
            continue
        aspect = bw / float(bh)
        if not (1.0 - aspect_tol <= aspect <= 1.0 + aspect_tol):
            continue

        # Circularity 4*pi*A/P^2 (1.0 for a perfect circle).
        perimeter = float(cv2.arcLength(c, True))
        if perimeter <= 0:
            continue
        circularity = 4.0 * np.pi * area / (perimeter * perimeter)
        if circularity < circ_min:
            continue

        # Solidity (area / hull area) rejects ring- or glyph-like outlines;
        # extent (area / bbox area) rejects sparse shapes.
        hull_area = float(cv2.contourArea(cv2.convexHull(c)))
        if hull_area <= 0:
            continue
        solidity = area / hull_area
        extent = area / float(bw * bh)
        if solidity < solidity_min or extent < extent_min:
            continue

        m = cv2.moments(c)
        if m["m00"] == 0:
            continue
        cx = int(m["m10"] / m["m00"])
        cy = int(m["m01"] / m["m00"])
        out.append((cx, cy, int(area)))
    return out
|
||
|
||
|
||
class FlowerGameEnv(gym.Env):
    """Gymnasium environment that plays a screen-captured flower game via key presses.

    Observation (Dict):
        "image": (84, 84, 1) uint8 grayscale downscale of the captured area.
        "state": float32 [tx, ty, fx, fy, bx, by], each coordinate divided by
            the captured frame's width/height. (tx, ty) = turtle centroid,
            (fx, fy) = flower centroid, (bx, by) = detected bomb closest to the
            turtle. A pair is (0.0, 0.0) when that entity was not detected
            (for the bomb: also when the turtle was not detected).

    Actions: 0=W, 1=A, 2=S, 3=D (one key press per step; other values press nothing).

    Rewards:
        +0.6 on contact (<= 95 px) with the flower; further contacts are then
             ignored until the contact cooldown has run out.
        +0.10 * decrease in turtle-flower distance, normalized by the frame diagonal.
        -5.0 on every step in which the nearest bomb is within 115 px.

    ``step`` always returns ``terminated=False`` and ``truncated=False``;
    episode limits must be imposed externally (e.g. a TimeLimit wrapper).
    """

    metadata = {"render_modes": []}

    # Key sent by pyautogui for each discrete action.
    _ACTION_KEYS = {0: "w", 1: "a", 2: "s", 3: "d"}

    def __init__(self, monitor_area, ui_exclude_rects: Optional[List[Tuple[int, int, int, int]]] = None):
        """Create the environment and open a screen-capture handle.

        Args:
            monitor_area: Region passed unchanged to ``mss.mss().grab``
                (presumably an mss monitor dict / bbox -- confirm against caller).
            ui_exclude_rects: Optional ``(x0, y0, x1, y1)`` pixel rectangles,
                relative to ``monitor_area``, in which bomb detection is
                suppressed (e.g. score text).
        """
        super().__init__()
        self.monitor_area = monitor_area
        self.sct = mss.mss()

        # --- Observation & action spaces ---
        self.observation_space = spaces.Dict({
            "image": spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8),
            "state": spaces.Box(low=0.0, high=1.0, shape=(6,), dtype=np.float32),
        })
        self.action_space = spaces.Discrete(4)

        # --- HSV bounds (OpenCV 8-bit HSV, hue in [0, 180)) ---
        self.yellow_lower = np.array([15, 40, 200], dtype=np.uint8)
        self.yellow_upper = np.array([25, 120, 255], dtype=np.uint8)
        self.white_lower = np.array([0, 0, 220], dtype=np.uint8)
        self.white_upper = np.array([180, 50, 255], dtype=np.uint8)
        self.black_lower = np.array([0, 0, 0], dtype=np.uint8)
        self.black_upper = np.array([180, 80, 60], dtype=np.uint8)
        self.green1_lower = np.array([30, 80, 80], dtype=np.uint8)
        self.green1_upper = np.array([45, 255, 255], dtype=np.uint8)
        self.green2_lower = np.array([65, 100, 80], dtype=np.uint8)
        self.green2_upper = np.array([90, 255, 255], dtype=np.uint8)

        # 3x3 structuring element for all morphology operations.
        self.kernel = np.ones((3, 3), np.uint8)

        # --- Reward / heuristic parameters ---
        self.eat_radius_px = 95
        self.collision_dist_px = 115
        self.shaping_scale = 0.10
        self.eat_reward = 0.6
        self.collision_penalty = 5.0

        self.contact_cooldown_frames = 8
        self._cooldown = 0
        self.prev_dist_to_flower = None
        self.flowers_eaten = 0

        # --- Bomb filter parameters (to reject the score text) ---
        self.bomb_min_area = 400        # text glyphs are usually smaller
        self.bomb_circ_min = 0.60       # circularity (1.0 is a perfect circle)
        self.bomb_aspect_tol = 0.35     # allows aspect ratio 0.65-1.35
        self.bomb_extent_min = 0.60     # fill ratio of the bounding rect
        self.bomb_solidity_min = 0.85   # rejects ring- / text-like contours

        # Optional UI exclusion zones: [(x0, y0, x1, y1), ...] relative to monitor_area.
        self.ui_exclude_rects = ui_exclude_rects or []

        # Detection results of the most recent frame (see _detect_entities).
        self._last_cache = {
            "turtle_xy": (None, None),
            "flower_xy": (None, None),
            "bombs_xy": [],
            "turtle_found": False,
            "flower_found": False,
            "frame_hw": (1, 1),
        }

    # ---------------- Gymnasium API ----------------
    def reset(self, seed=None, options=None):
        """Reset episode counters and return ``(observation, {})``.

        Only internal bookkeeping is reset; no input is sent to the game.
        """
        super().reset(seed=seed)
        self.prev_dist_to_flower = None
        self._cooldown = 0
        self.flowers_eaten = 0
        return self._build_observation(), {}

    def step(self, action):
        """Press the key for ``action``, wait briefly, then observe and score.

        Returns:
            ``(obs, reward, False, False, info)``, where ``info`` contains
            ``flowers_eaten``, ``bombs_expected`` and ``bombs_detected``.
        """
        key = self._ACTION_KEYS.get(int(action))
        if key is not None:
            pyautogui.press(key)

        # Give the game time to react before capturing the next frame.
        time.sleep(0.05)

        obs = self._build_observation()
        reward = self._calculate_reward()
        if self._cooldown > 0:
            self._cooldown -= 1

        info = {
            "flowers_eaten": self.flowers_eaten,
            "bombs_expected": self._bombs_expected(self.flowers_eaten),
            "bombs_detected": len(self._last_cache["bombs_xy"])
        }
        return obs, reward, False, False, info

    def close(self):
        """Release the mss screen-capture handle, then defer to ``gym.Env.close``."""
        sct = getattr(self, "sct", None)
        if sct is not None:
            sct.close()
            self.sct = None
        super().close()

    # ---------------- Detection ----------------
    def _grab_bgr(self):
        """Capture ``monitor_area`` and return it as a 3-channel BGR array."""
        raw = np.array(self.sct.grab(self.monitor_area))  # BGRA
        return cv2.cvtColor(raw, cv2.COLOR_BGRA2BGR)

    def _detect_entities(self, frame_bgr):
        """Locate turtle, flower and bombs in ``frame_bgr`` via HSV colour masks.

        Returns:
            Dict with ``state_norm`` (float32 vector of 6 values),
            ``turtle_xy``, ``flower_xy`` (pixel centroids or ``(None, None)``),
            ``bombs_xy`` (list of ``(x, y, area)``), ``turtle_found``,
            ``flower_found`` and ``frame_hw``.
        """
        h, w, _ = frame_bgr.shape
        hsv = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2HSV)

        # Flower: pixels where the dilated white and yellow masks overlap.
        mw = cv2.inRange(hsv, self.white_lower, self.white_upper)
        my = cv2.inRange(hsv, self.yellow_lower, self.yellow_upper)
        mw = cv2.morphologyEx(mw, cv2.MORPH_DILATE, self.kernel, iterations=1)
        my = cv2.morphologyEx(my, cv2.MORPH_DILATE, self.kernel, iterations=1)
        mf = cv2.bitwise_and(mw, my)
        mf = cv2.morphologyEx(mf, cv2.MORPH_CLOSE, self.kernel, iterations=1)
        fx, fy, _ = _centroid_from_mask(mf)
        flower_found = (fx is not None and fy is not None)

        # Bombs: dark blobs filtered by shape, with UI zones excluded.
        mb = cv2.inRange(hsv, self.black_lower, self.black_upper)
        bombs_xy = _centroids_from_contours(
            mb.copy(),
            self.ui_exclude_rects,
            min_area=self.bomb_min_area,
            circ_min=self.bomb_circ_min,
            aspect_tol=self.bomb_aspect_tol,
            extent_min=self.bomb_extent_min,
            solidity_min=self.bomb_solidity_min,
        )

        # Turtle: union of two green hue bands, opened (denoise) then dilated.
        g1 = cv2.inRange(hsv, self.green1_lower, self.green1_upper)
        g2 = cv2.inRange(hsv, self.green2_lower, self.green2_upper)
        mg = cv2.bitwise_or(g1, g2)
        mg = cv2.morphologyEx(mg, cv2.MORPH_OPEN, self.kernel, iterations=1)
        mg = cv2.morphologyEx(mg, cv2.MORPH_DILATE, self.kernel, iterations=1)
        tx, ty, _ = _centroid_from_mask(mg)
        turtle_found = (tx is not None and ty is not None)

        def nxy(x, y):
            # Normalize pixel coords by frame size; missing entity -> (0, 0).
            if x is None or y is None:
                return 0.0, 0.0
            return x / float(w), y / float(h)

        # State: bomb nearest to the turtle (first one on ties).
        nbx, nby = 0.0, 0.0
        if turtle_found and bombs_xy:
            txf, tyf = float(tx), float(ty)
            nearest = min(bombs_xy, key=lambda b: np.hypot(b[0] - txf, b[1] - tyf))
            nbx, nby = nxy(nearest[0], nearest[1])

        n_tx, n_ty = nxy(tx, ty)
        n_fx, n_fy = nxy(fx, fy)

        return {
            "state_norm": np.array([n_tx, n_ty, n_fx, n_fy, nbx, nby], dtype=np.float32),
            "turtle_xy": (tx, ty),
            "flower_xy": (fx, fy),
            "bombs_xy": bombs_xy,
            "turtle_found": turtle_found,
            "flower_found": flower_found,
            "frame_hw": (h, w),
        }

    def _build_observation(self):
        """Capture a frame, cache its detections and return the Dict observation."""
        frame_bgr = self._grab_bgr()
        gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
        gray = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
        gray = np.expand_dims(gray, axis=-1)
        det = self._detect_entities(frame_bgr)
        self._last_cache = det
        return {"image": gray, "state": det["state_norm"]}

    # ---------------- Reward ----------------
    def _calculate_reward(self) -> float:
        """Compute the step reward from the cached detections of the latest frame.

        Side effects: updates ``prev_dist_to_flower``. On a scoring contact it
        also sets ``_cooldown`` and increments ``flowers_eaten``.
        """
        det = self._last_cache
        reward = 0.0

        tx, ty = det["turtle_xy"]
        fx, fy = det["flower_xy"]
        turtle_found = det["turtle_found"]
        flower_found = det["flower_found"]
        bombs_xy = det["bombs_xy"]
        h, w = det["frame_hw"]

        # Distance shaping: reward progress toward the flower, normalized by
        # the frame diagonal so the term is resolution-independent.
        dist_to_flower = None
        if turtle_found and flower_found:
            dist_to_flower = float(np.hypot(tx - fx, ty - fy))
            if self.prev_dist_to_flower is not None:
                delta = self.prev_dist_to_flower - dist_to_flower
                reward += self.shaping_scale * (delta / max(1.0, np.hypot(h, w)))
            self.prev_dist_to_flower = dist_to_flower
        else:
            self.prev_dist_to_flower = None

        # Eat event, rate-limited by the contact cooldown.
        if (self._cooldown == 0 and dist_to_flower is not None
                and dist_to_flower <= self.eat_radius_px):
            reward += self.eat_reward
            self._cooldown = self.contact_cooldown_frames
            self.flowers_eaten += 1
            # Forget the old distance so that a flower reappearing elsewhere
            # (presumably what the game does after contact) is not punished by
            # the shaping term as "moving away".
            self.prev_dist_to_flower = None

        # Collision with the nearest bomb (penalized on every step it holds).
        if turtle_found and bombs_xy:
            min_dist = min(np.hypot(tx - bx, ty - by) for (bx, by, _area) in bombs_xy)
            if min_dist <= self.collision_dist_px:
                reward -= self.collision_penalty

        return float(reward)

    # ---------------- Helper info ----------------
    @staticmethod
    def _bombs_expected(flowers_eaten: int) -> int:
        """Heuristic expected bomb count: one per 5 flowers eaten, never negative."""
        return max(0, flowers_eaten // 5)
|