initial commit
This commit is contained in:
commit
c45b902e8a
28
coordinate_viewer.py
Normal file
28
coordinate_viewer.py
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
import pyautogui
|
||||||
|
import mss
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# while True:
|
||||||
|
# print(pyautogui.position())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Screen region handed to mss for capture.
monitor_area = {"top": 120, "left": 330, "width": 1900, "height": 1263}

# Use mss as a context manager so the capture handle is released on exit,
# and always tear down the OpenCV window, even if the loop raises
# (for example on Ctrl+C).
with mss.mss() as sct:
    try:
        while True:
            # Take a screenshot and drop the alpha channel (BGRA -> BGR).
            img = np.array(sct.grab(monitor_area))
            img_bgr = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)

            # Display
            cv2.imshow("Monitor-Ausschnitt", img_bgr)

            # Quit with 'q'
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cv2.destroyAllWindows()
|
||||||
176
debug_viewer.py
Normal file
176
debug_viewer.py
Normal file
|
|
@ -0,0 +1,176 @@
|
||||||
|
import mss
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
|
||||||
|
# ========= Your game region =========
monitor_area = {"top": 120, "left": 330, "width": 1900, "height": 1263}

# ========= HSV bounds =========
yellow_lower = np.array([15, 40, 200], dtype=np.uint8)
yellow_upper = np.array([25, 120, 255], dtype=np.uint8)
white_lower = np.array([0, 0, 220], dtype=np.uint8)
white_upper = np.array([180, 50, 255], dtype=np.uint8)

black_lower = np.array([0, 0, 0], dtype=np.uint8)
black_upper = np.array([180, 80, 60], dtype=np.uint8)

green1_lower = np.array([30, 80, 80], dtype=np.uint8)
green1_upper = np.array([45, 255, 255], dtype=np.uint8)
green2_lower = np.array([65, 100, 80], dtype=np.uint8)
green2_upper = np.array([90, 255, 255], dtype=np.uint8)

# 3x3 structuring element shared by all morphology operations.
kernel = np.ones((3, 3), dtype=np.uint8)

# Radii (pixels) of the circles drawn around the turtle.
EAT_RADIUS = 95
COLL_RADIUS = 115

# Bomb shape filter
BOMB_MIN_AREA = 400  # adjusted!
BOMB_CIRC_MIN = 0.60
BOMB_ASPECT_TOL = 0.35
BOMB_EXTENT_MIN = 0.60
BOMB_SOLIDITY_MIN = 0.85

# Window scaling (e.g. 0.7 = 70 % of the captured size)
WINDOW_SCALE = 0.8

# Display modes
MODE_OVERLAY = 0
MODE_FLOWER_MASK = 1
MODE_BOMB_MASK = 2
MODE_TURTLE_MASK = 3
mode = MODE_OVERLAY
|
||||||
|
|
||||||
|
|
||||||
|
def centroid(mask):
    """Return (cx, cy, count) for a mask.

    ``count`` is the number of non-zero pixels. ``cx``/``cy`` are the
    integer-truncated centroid from the image moments, or ``None`` when the
    mask is empty or its zeroth moment is zero.
    """
    nonzero = int(cv2.countNonZero(mask))
    if not nonzero:
        return None, None, 0

    moments = cv2.moments(mask)
    mass = moments["m00"]
    if mass == 0:
        return None, None, nonzero

    center_x = int(moments["m10"] / mass)
    center_y = int(moments["m01"] / mass)
    return center_x, center_y, nonzero
|
||||||
|
|
||||||
|
|
||||||
|
def bomb_centroids_filtered(mask):
    """Return ``(cx, cy, area)`` for each external contour shaped like a bomb.

    A contour is kept only if it passes every module-level threshold:
    minimum area, near-square bounding box, circularity, solidity
    (area / convex-hull area) and extent (area / bounding-box area).
    """
    found, _hierarchy = cv2.findContours(
        mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    aspect_lo = 1.0 - BOMB_ASPECT_TOL
    aspect_hi = 1.0 + BOMB_ASPECT_TOL

    accepted = []
    for contour in found:
        area = float(cv2.contourArea(contour))
        if area < BOMB_MIN_AREA:
            continue

        _x, _y, box_w, box_h = cv2.boundingRect(contour)
        if not (box_w and box_h):
            continue
        ratio = box_w / float(box_h)
        if ratio < aspect_lo or ratio > aspect_hi:
            continue

        perimeter = float(cv2.arcLength(contour, True))
        if perimeter <= 0:
            continue
        # 4*pi*A / P^2 is 1.0 for a perfect circle.
        if 4.0 * np.pi * area / (perimeter * perimeter) < BOMB_CIRC_MIN:
            continue

        hull_area = float(cv2.contourArea(cv2.convexHull(contour)))
        if hull_area <= 0:
            continue
        if area / hull_area < BOMB_SOLIDITY_MIN:
            continue
        if area / float(box_w * box_h) < BOMB_EXTENT_MIN:
            continue

        moments = cv2.moments(contour)
        mass = moments["m00"]
        if mass == 0:
            continue
        accepted.append(
            (int(moments["m10"] / mass), int(moments["m01"] / mass), int(area))
        )
    return accepted
|
||||||
|
|
||||||
|
|
||||||
|
def detect_all(frame_bgr):
    """Detect flower, bombs and turtle in a BGR frame using HSV thresholds.

    Returns ``((fx, fy), bombs, (tx, ty), masks)`` where ``bombs`` is the
    list produced by ``bomb_centroids_filtered`` and ``masks`` maps
    ``"flower"``, ``"bomb"`` and ``"turtle"`` to their binary masks.
    Coordinates are ``None`` when the corresponding mask is empty.
    """
    def dilated(binary):
        return cv2.morphologyEx(binary, cv2.MORPH_DILATE, kernel, iterations=1)

    hsv = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2HSV)

    # Flower: overlap of the dilated white and yellow masks, then closed.
    white = dilated(cv2.inRange(hsv, white_lower, white_upper))
    yellow = dilated(cv2.inRange(hsv, yellow_lower, yellow_upper))
    flower_mask = cv2.morphologyEx(
        cv2.bitwise_and(white, yellow), cv2.MORPH_CLOSE, kernel, iterations=1
    )
    flower_x, flower_y, _ = centroid(flower_mask)

    # Bombs: dark pixels, filtered by contour shape.
    bomb_mask = cv2.inRange(hsv, black_lower, black_upper)
    bombs = bomb_centroids_filtered(bomb_mask)

    # Turtle: union of both green ranges, opened and then dilated.
    green = cv2.bitwise_or(
        cv2.inRange(hsv, green1_lower, green1_upper),
        cv2.inRange(hsv, green2_lower, green2_upper),
    )
    turtle_mask = dilated(
        cv2.morphologyEx(green, cv2.MORPH_OPEN, kernel, iterations=1)
    )
    turtle_x, turtle_y, _ = centroid(turtle_mask)

    masks = {"flower": flower_mask, "bomb": bomb_mask, "turtle": turtle_mask}
    return (flower_x, flower_y), bombs, (turtle_x, turtle_y), masks
|
||||||
|
|
||||||
|
|
||||||
|
def draw_overlay(frame, flower, bombs, turtle, fps):
    """Draw detection markers and the FPS counter onto ``frame`` and return it.

    ``flower`` and ``turtle`` are ``(x, y)`` pairs whose entries may be
    ``None``; ``bombs`` is a list of ``(x, y, area)``. The bomb closest to
    the turtle is highlighted in red with a thicker outline.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    flower_x, flower_y = flower
    turtle_x, turtle_y = turtle
    turtle_seen = turtle_x is not None

    if flower_x is not None:
        cv2.circle(frame, (flower_x, flower_y), 8, (0, 255, 255), 2)
        cv2.putText(frame, "Flower", (flower_x + 10, flower_y - 10),
                    font, 0.6, (0, 255, 255), 2)

    # Position of the bomb nearest to the turtle (None if not applicable).
    closest_xy = None
    if bombs and turtle_seen:
        closest = min(
            bombs, key=lambda b: np.hypot(b[0] - turtle_x, b[1] - turtle_y)
        )
        closest_xy = (closest[0], closest[1])

    for bomb_x, bomb_y, _area in bombs:
        if (bomb_x, bomb_y) == closest_xy:
            color, thick = (0, 0, 255), 3
        else:
            color, thick = (60, 60, 60), 2
        cv2.circle(frame, (bomb_x, bomb_y), 10, color, thick)

    if turtle_seen:
        center = (turtle_x, turtle_y)
        cv2.circle(frame, center, 8, (0, 200, 0), 2)
        cv2.putText(frame, "Turtle", (turtle_x + 10, turtle_y - 10),
                    font, 0.6, (0, 200, 0), 2)
        cv2.circle(frame, center, EAT_RADIUS, (0, 255, 0), 1)
        cv2.circle(frame, center, COLL_RADIUS, (0, 0, 255), 1)

    cv2.putText(frame, f"FPS: {fps:.1f}", (20, 40), font, 0.7, (255, 255, 255), 2)
    return frame
|
||||||
|
|
||||||
|
|
||||||
|
def colorize(mask):
    """Min-max scale ``mask`` to 0..255 and render it with the JET colormap."""
    stretched = cv2.normalize(mask, None, 0, 255, cv2.NORM_MINMAX)
    as_bytes = stretched.astype(np.uint8)
    return cv2.applyColorMap(as_bytes, cv2.COLORMAP_JET)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the live debug viewer until 'q' is pressed.

    Keys: '0' overlay, '1' flower mask, '2' bomb mask, '3' turtle mask,
    'q' quit. The selected view is stored in the module-level ``mode``.
    """
    global mode

    key_to_mode = {
        ord('0'): MODE_OVERLAY,
        ord('1'): MODE_FLOWER_MASK,
        ord('2'): MODE_BOMB_MASK,
        ord('3'): MODE_TURTLE_MASK,
    }
    mode_to_mask = {
        MODE_FLOWER_MASK: "flower",
        MODE_BOMB_MASK: "bomb",
        MODE_TURTLE_MASK: "turtle",
    }

    prev = time.time()
    fps = 0.0

    # The context manager releases the capture handle; the finally clause
    # closes the window even when the loop is interrupted by an exception.
    with mss.mss() as sct:
        try:
            while True:
                raw = np.array(sct.grab(monitor_area))
                frame = cv2.cvtColor(raw, cv2.COLOR_BGRA2BGR)
                flower, bombs, turtle, masks = detect_all(frame)

                now = time.time()
                dt = now - prev
                prev = now
                if dt > 0:
                    fps = 1.0 / dt

                # Any mode without a mask (including an unexpected value)
                # falls back to the overlay, so `out` is always defined.
                mask_name = mode_to_mask.get(mode)
                if mask_name is None:
                    out = draw_overlay(frame.copy(), flower, bombs, turtle, fps)
                else:
                    out = colorize(masks[mask_name])

                # --- scale for display ---
                if WINDOW_SCALE != 1.0:
                    out = cv2.resize(
                        out,
                        (int(out.shape[1] * WINDOW_SCALE),
                         int(out.shape[0] * WINDOW_SCALE)),
                    )
                cv2.imshow("Debug Viewer", out)

                key = cv2.waitKey(1) & 0xFF
                if key == ord('q'):
                    break
                mode = key_to_mode.get(key, mode)
        finally:
            cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
|
||||||
296
flower_game_env.py
Normal file
296
flower_game_env.py
Normal file
|
|
@ -0,0 +1,296 @@
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium import spaces
|
||||||
|
import numpy as np
|
||||||
|
import mss
|
||||||
|
import cv2
|
||||||
|
import pyautogui
|
||||||
|
import time
|
||||||
|
from typing import Tuple, Optional, List
|
||||||
|
|
||||||
|
|
||||||
|
def _centroid_from_mask(mask: np.ndarray) -> Tuple[Optional[int], Optional[int], int]:
    """Return ``(cx, cy, count)`` for a mask.

    ``count`` is the number of non-zero pixels. ``cx``/``cy`` are the
    integer-truncated centroid from the image moments, or ``None`` when the
    mask is empty or its zeroth moment is zero.
    """
    nonzero = int(cv2.countNonZero(mask))
    if not nonzero:
        return None, None, 0

    moments = cv2.moments(mask)
    mass = moments["m00"]
    if mass == 0:
        return None, None, nonzero

    center_x = int(moments["m10"] / mass)
    center_y = int(moments["m01"] / mass)
    return center_x, center_y, nonzero
|
||||||
|
|
||||||
|
|
||||||
|
def _centroids_from_contours(mask: np.ndarray,
                             ui_exclude_rects: List[Tuple[int,int,int,int]],
                             min_area: int,
                             circ_min: float,
                             aspect_tol: float,
                             extent_min: float,
                             solidity_min: float) -> List[Tuple[int,int,int]]:
    """Return ``(cx, cy, area)`` for contours that satisfy the bomb shape criteria.

    Args:
        mask: Single-channel binary mask to search. It is not modified.
        ui_exclude_rects: List of ``(x0, y0, x1, y1)`` in pixels relative to
            the monitor area; these regions are blanked before the search.
            Coordinates are clamped to the mask bounds.
        min_area: Minimum contour area.
        circ_min: Minimum circularity ``4*pi*A / P^2`` (1.0 is a perfect circle).
        aspect_tol: Allowed deviation of the bounding-box width/height ratio
            from 1.0.
        extent_min: Minimum ratio of contour area to bounding-box area.
        solidity_min: Minimum ratio of contour area to convex-hull area.

    Returns:
        One ``(cx, cy, int(area))`` tuple per accepted contour.
    """
    # Mask out the UI zones on a private copy so the caller's array is left
    # untouched (the original implementation wrote into the argument).
    if ui_exclude_rects:
        mask = mask.copy()
        mask_h, mask_w = mask.shape
        for (x0, y0, x1, y1) in ui_exclude_rects:
            x0 = max(0, min(mask_w, x0)); x1 = max(0, min(mask_w, x1))
            y0 = max(0, min(mask_h, y0)); y1 = max(0, min(mask_h, y1))
            mask[y0:y1, x0:x1] = 0

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    out = []
    for c in contours:
        area = float(cv2.contourArea(c))
        if area < min_area:
            continue

        # Bounding box must be roughly square.
        x, y, w, h = cv2.boundingRect(c)
        if h == 0 or w == 0:
            continue
        aspect = w / float(h)
        if not (1.0 - aspect_tol <= aspect <= 1.0 + aspect_tol):
            continue

        per = float(cv2.arcLength(c, True))
        if per <= 0:
            continue
        circularity = 4.0 * np.pi * area / (per * per)
        if circularity < circ_min:
            continue

        hull = cv2.convexHull(c)
        hull_area = float(cv2.contourArea(hull))
        if hull_area <= 0:
            continue
        solidity = area / hull_area
        extent = area / float(w * h)

        if solidity < solidity_min or extent < extent_min:
            continue

        M = cv2.moments(c)
        if M["m00"] == 0:
            continue
        cx = int(M["m10"] / M["m00"])
        cy = int(M["m01"] / M["m00"])
        out.append((cx, cy, int(area)))
    return out
|
||||||
|
|
||||||
|
|
||||||
|
class FlowerGameEnv(gym.Env):
    """Gymnasium environment that plays the flower game via screen capture.

    Observation (Dict):
        "image": (84, 84, 1) uint8 grayscale, down-scaled screenshot.
        "state": float32 [tx, ty, fx, fy, bx, by], each divided by the frame
            width/height (bx, by = valid bomb nearest to the turtle).
            Entities that were not detected are reported as 0.0.

    Actions: 0=W, 1=A, 2=S, 3=D (sent as key presses through pyautogui).

    Rewards:
        +0.6   on contact (<= 95 px) with the flower
        +0.10  * reduction of the distance to the flower (normalized by the
               frame diagonal)
        -5.0   when the distance to the nearest bomb is <= 115 px

    Episodes never terminate or truncate on their own; ``step`` always
    returns ``False`` for both flags.
    """

    metadata = {"render_modes": []}

    # Key pressed for each discrete action index.
    _ACTION_KEYS = ("w", "a", "s", "d")

    def __init__(self, monitor_area, ui_exclude_rects: Optional[List[Tuple[int,int,int,int]]] = None):
        """Create the environment.

        Args:
            monitor_area: Region dict passed to ``mss`` ``grab`` for every frame.
            ui_exclude_rects: Optional list of ``(x0, y0, x1, y1)`` rectangles,
                relative to ``monitor_area``, that are ignored by bomb detection.
        """
        super().__init__()
        self.monitor_area = monitor_area
        self.sct = mss.mss()

        # --- Observation & actions ---
        self.observation_space = spaces.Dict({
            "image": spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8),
            "state": spaces.Box(low=0.0, high=1.0, shape=(6,), dtype=np.float32),
        })
        self.action_space = spaces.Discrete(4)

        # --- HSV bounds ---
        self.yellow_lower = np.array([15, 40, 200], dtype=np.uint8)
        self.yellow_upper = np.array([25, 120, 255], dtype=np.uint8)
        self.white_lower = np.array([0, 0, 220], dtype=np.uint8)
        self.white_upper = np.array([180, 50, 255], dtype=np.uint8)
        self.black_lower = np.array([0, 0, 0], dtype=np.uint8)
        self.black_upper = np.array([180, 80, 60], dtype=np.uint8)
        self.green1_lower = np.array([30, 80, 80], dtype=np.uint8)
        self.green1_upper = np.array([45, 255, 255], dtype=np.uint8)
        self.green2_lower = np.array([65, 100, 80], dtype=np.uint8)
        self.green2_upper = np.array([90, 255, 255], dtype=np.uint8)

        self.kernel = np.ones((3, 3), np.uint8)

        # --- Reward / heuristic parameters ---
        self.eat_radius_px = 95
        self.collision_dist_px = 115
        self.shaping_scale = 0.10
        self.eat_reward = 0.6
        self.collision_penalty = 5.0

        # Steps during which no further eat event is counted after one fired.
        self.contact_cooldown_frames = 8
        self._cooldown = 0
        self.prev_dist_to_flower = None
        self.flowers_eaten = 0

        # --- Bomb filter parameters (guard against the score text) ---
        self.bomb_min_area = 400       # text glyphs are usually smaller
        self.bomb_circ_min = 0.60      # circularity (1.0 is a perfect circle)
        self.bomb_aspect_tol = 0.35    # allows aspect ratios 0.65-1.35
        self.bomb_extent_min = 0.60    # fill ratio inside the bounding rect
        self.bomb_solidity_min = 0.85  # rejects ring-/text-like contours

        # UI exclusion zones (optional): [(x0,y0,x1,y1), ...] relative to monitor_area
        self.ui_exclude_rects = ui_exclude_rects or []

        # Result of the most recent detection; read by the reward function.
        self._last_cache = {
            "turtle_xy": (None, None),
            "flower_xy": (None, None),
            "bombs_xy": [],
            "turtle_found": False,
            "flower_found": False,
            "frame_hw": (1, 1),
        }

    # ---------------- Gymnasium API ----------------
    def reset(self, seed=None, options=None):
        """Reset the reward bookkeeping and return ``(observation, {})``.

        Only internal counters are reset; no input is sent to the game.
        """
        super().reset(seed=seed)
        self.prev_dist_to_flower = None
        self._cooldown = 0
        self.flowers_eaten = 0
        return self._build_observation(), {}

    def step(self, action):
        """Press the key for ``action``, wait 50 ms, then observe and score.

        Returns ``(obs, reward, False, False, info)``. Values of ``action``
        outside 0..3 press no key.
        """
        for index, key in enumerate(self._ACTION_KEYS):
            if action == index:
                pyautogui.press(key)
                break

        time.sleep(0.05)

        obs = self._build_observation()
        reward = self._calculate_reward()
        if self._cooldown > 0:
            self._cooldown -= 1

        info = {
            "flowers_eaten": self.flowers_eaten,
            "bombs_expected": self._bombs_expected(self.flowers_eaten),
            "bombs_detected": len(self._last_cache["bombs_xy"]),
        }
        return obs, reward, False, False, info

    def close(self):
        """Release the screen-capture handle held by this environment."""
        sct = getattr(self, "sct", None)
        if sct is not None:
            sct.close()
            self.sct = None
        super().close()

    # ---------------- Detection ----------------
    def _grab_bgr(self):
        """Capture ``monitor_area`` and return it as a BGR image."""
        raw = np.array(self.sct.grab(self.monitor_area))  # BGRA
        return cv2.cvtColor(raw, cv2.COLOR_BGRA2BGR)

    def _detect_entities(self, frame_bgr):
        """Locate flower, bombs and turtle in ``frame_bgr``.

        Returns a dict with the normalized 6-element ``state_norm`` vector,
        the pixel coordinates of each entity, the found flags and the frame
        size as ``(h, w)``.
        """
        h, w, _ = frame_bgr.shape
        hsv = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2HSV)

        # Flower: overlap of the dilated white and yellow masks
        mw = cv2.inRange(hsv, self.white_lower, self.white_upper)
        my = cv2.inRange(hsv, self.yellow_lower, self.yellow_upper)
        mw = cv2.morphologyEx(mw, cv2.MORPH_DILATE, self.kernel, iterations=1)
        my = cv2.morphologyEx(my, cv2.MORPH_DILATE, self.kernel, iterations=1)
        mf = cv2.bitwise_and(mw, my)
        mf = cv2.morphologyEx(mf, cv2.MORPH_CLOSE, self.kernel, iterations=1)
        fx, fy, _ = _centroid_from_mask(mf)
        flower_found = (fx is not None and fy is not None)

        # Bombs (with strict shape filtering & UI exclusion)
        mb = cv2.inRange(hsv, self.black_lower, self.black_upper)
        bombs_xy = _centroids_from_contours(
            mb.copy(),
            self.ui_exclude_rects,
            min_area=self.bomb_min_area,
            circ_min=self.bomb_circ_min,
            aspect_tol=self.bomb_aspect_tol,
            extent_min=self.bomb_extent_min,
            solidity_min=self.bomb_solidity_min,
        )

        # Turtle: union of both green ranges, opened and then dilated
        g1 = cv2.inRange(hsv, self.green1_lower, self.green1_upper)
        g2 = cv2.inRange(hsv, self.green2_lower, self.green2_upper)
        mg = cv2.bitwise_or(g1, g2)
        mg = cv2.morphologyEx(mg, cv2.MORPH_OPEN, self.kernel, iterations=1)
        mg = cv2.morphologyEx(mg, cv2.MORPH_DILATE, self.kernel, iterations=1)
        tx, ty, _ = _centroid_from_mask(mg)
        turtle_found = (tx is not None and ty is not None)

        def nxy(x, y):
            """Normalize pixel coordinates by the frame size; (0, 0) if missing."""
            if x is None or y is None:
                return 0.0, 0.0
            return x / float(w), y / float(h)

        # State: bomb nearest to the turtle (stays at 0,0 without turtle/bombs)
        nbx, nby = 0.0, 0.0
        if turtle_found and bombs_xy:
            txf, tyf = float(tx), float(ty)
            nearest = min(bombs_xy, key=lambda b: np.hypot(b[0] - txf, b[1] - tyf))
            nbx, nby = nxy(nearest[0], nearest[1])

        n_tx, n_ty = nxy(tx, ty)
        n_fx, n_fy = nxy(fx, fy)

        return {
            "state_norm": np.array([n_tx, n_ty, n_fx, n_fy, nbx, nby], dtype=np.float32),
            "turtle_xy": (tx, ty),
            "flower_xy": (fx, fy),
            "bombs_xy": bombs_xy,
            "turtle_found": turtle_found,
            "flower_found": flower_found,
            "frame_hw": (h, w),
        }

    def _build_observation(self):
        """Grab a frame, cache its detections and return the observation dict."""
        frame_bgr = self._grab_bgr()
        gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
        gray = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
        gray = np.expand_dims(gray, axis=-1)
        det = self._detect_entities(frame_bgr)
        self._last_cache = det
        return {"image": gray, "state": det["state_norm"]}

    # ---------------- Reward ----------------
    def _calculate_reward(self) -> float:
        """Compute the reward from the detections cached by the last observation.

        Combines distance shaping towards the flower, the eat bonus (subject
        to the cooldown) and the bomb proximity penalty.
        """
        det = self._last_cache
        reward = 0.0

        tx, ty = det["turtle_xy"]
        fx, fy = det["flower_xy"]
        tf = det["turtle_found"]
        ff = det["flower_found"]
        bombs_xy = det["bombs_xy"]
        h, w = det["frame_hw"]

        # Distance shaping
        if tf and ff:
            txy = np.array([tx, ty], dtype=np.float32)
            fxy = np.array([fx, fy], dtype=np.float32)
            dist = float(np.linalg.norm(txy - fxy))
            if self.prev_dist_to_flower is not None:
                delta = self.prev_dist_to_flower - dist
                reward += self.shaping_scale * (delta / max(1.0, np.hypot(h, w)))
            self.prev_dist_to_flower = dist
        else:
            self.prev_dist_to_flower = None

        # Eat event with cooldown
        if self._cooldown == 0 and tf and ff:
            if np.linalg.norm(np.array([tx - fx, ty - fy], dtype=np.float32)) <= self.eat_radius_px:
                reward += self.eat_reward
                self._cooldown = self.contact_cooldown_frames
                self.flowers_eaten += 1
                # Restart shaping after an eat event: the next detected
                # flower is presumably a new one at a different position
                # (TODO confirm against the game), so comparing against the
                # old distance would penalize the agent for eating.
                self.prev_dist_to_flower = None

        # Collision with the nearest bomb
        if tf and bombs_xy:
            min_dist = min(np.hypot(tx - bx, ty - by) for (bx, by, _a) in bombs_xy)
            if min_dist <= self.collision_dist_px:
                reward -= self.collision_penalty

        return float(reward)

    # ---------------- Helper info ----------------
    @staticmethod
    def _bombs_expected(flowers_eaten: int) -> int:
        """Return ``flowers_eaten // 5``, clamped to be non-negative."""
        return max(0, flowers_eaten // 5)
|
||||||
50
train_bot.py
Normal file
50
train_bot.py
Normal file
|
|
@ -0,0 +1,50 @@
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
from stable_baselines3.common.callbacks import BaseCallback
|
||||||
|
from flower_game_env import FlowerGameEnv
|
||||||
|
|
||||||
|
|
||||||
|
# ---- Your game region (adjust!) ----
monitor_area = {"top": 120, "left": 330, "width": 1900, "height": 1263}
env = FlowerGameEnv(monitor_area)

# Base name of the resumable checkpoint and the file it is loaded from.
saved_model_name = "flower_bot"
zip_file = f"{saved_model_name}.zip"
|
||||||
|
|
||||||
|
|
||||||
|
class TimeBasedCheckpoint(BaseCallback):
    """Save the model every ``save_every_secs`` seconds of wall-clock time.

    Every autosave goes to the same path, ``save_prefix``, so the previous
    autosave is overwritten; no timestamp is appended. The timer starts when
    the callback is constructed.
    """

    def __init__(self, save_every_secs=60, save_prefix=saved_model_name, verbose=1):
        """Store the interval (seconds) and the path passed to ``model.save``."""
        super().__init__(verbose)
        self.save_every_secs = save_every_secs
        self.save_prefix = save_prefix
        self._last_save = time.time()

    def _on_step(self) -> bool:
        """Save the model if the interval has elapsed; always return True."""
        now = time.time()
        if now - self._last_save >= self.save_every_secs:
            fname = self.save_prefix
            if self.verbose:
                print(f"[Autosave] Saving model to {fname}.zip")
            self.model.save(fname)
            self._last_save = now
        return True
|
||||||
|
|
||||||
|
|
||||||
|
# --- Load the model if a checkpoint file exists ---
if os.path.exists(zip_file):
    print(f"Lade existierendes Modell aus {zip_file}")
    model = PPO.load(zip_file, env=env)  # continue training with the new env
else:
    print("Starte neues Modell")
    # CNN + Dict observation -> use 'MultiInputPolicy'
    model = PPO("MultiInputPolicy", env, verbose=2)

# Train with an autosave every 100 seconds. The autosave is written to
# saved_model_name so it always matches zip_file, the file that is loaded
# above when training is resumed.
model.learn(total_timesteps=500_000, callback=TimeBasedCheckpoint(100, saved_model_name))

# Final checkpoint
model.save("flower_bot_final")
print("Training fertig. Modell gespeichert: flower_bot_final.zip")
|
||||||
Loading…
Reference in New Issue
Block a user