turtleBot/train_bot.py
2025-08-10 17:04:19 +02:00

51 lines
1.7 KiB
Python

import os
import time
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
from flower_game_env import FlowerGameEnv
# ---- Your game area on screen (adjust to your setup!) ----
# Pixel rectangle passed to FlowerGameEnv; presumably the screen region the env
# observes — confirm against flower_game_env.
monitor_area = dict(top=120, left=330, width=1900, height=1263)

# Base name of the checkpoint; the script checks for "<name>.zip" on startup.
saved_model_name = "flower_bot"
zip_file = f"{saved_model_name}.zip"

env = FlowerGameEnv(monitor_area)
class TimeBasedCheckpoint(BaseCallback):
    """Periodically save the model during training, based on elapsed time.

    Every ``save_every_secs`` seconds the model is written to
    ``<save_prefix>.zip``. The same file is overwritten each time (no timestamp
    is appended), so it always holds the most recent checkpoint.

    Each save is first written to ``<save_prefix>.zip.tmp`` and then moved over
    the real checkpoint with ``os.replace``, so if the process is interrupted
    mid-save the previous checkpoint stays intact instead of being truncated.

    Args:
        save_every_secs: Minimum number of seconds between two saves.
        save_prefix: Checkpoint path without the ``.zip`` extension.
        verbose: If truthy, print a message on every save.
    """

    def __init__(self, save_every_secs=60, save_prefix=saved_model_name, verbose=1):
        super().__init__(verbose)
        self.save_every_secs = save_every_secs
        self.save_prefix = save_prefix
        # Monotonic clock: wall-clock adjustments must not skip or force saves.
        self._last_save = time.monotonic()

    def _on_step(self) -> bool:
        """Save if the interval has elapsed; always return True to keep training."""
        now = time.monotonic()
        if now - self._last_save < self.save_every_secs:
            return True

        target = f"{self.save_prefix}.zip"
        if self.verbose:
            print(f"[Autosave] Saving model to {target}")

        # model.save() accepts an open binary file; writing to a temp file and
        # swapping it in keeps the resume checkpoint valid at all times.
        tmp_path = target + ".tmp"
        with open(tmp_path, "wb") as fh:
            self.model.save(fh)
        os.replace(tmp_path, target)

        self._last_save = now
        return True
# --- Resume from the existing checkpoint if the file is present ---
if os.path.exists(zip_file):
    print(f"Lade existierendes Modell aus {zip_file}")
    # Continue training, attaching the freshly created env.
    model = PPO.load(zip_file, env=env)
else:
    print("Starte neues Modell")
    # CNN + dict observation -> use 'MultiInputPolicy'.
    model = PPO("MultiInputPolicy", env, verbose=2)

# Train with autosave every 100 seconds. The autosave target is derived from
# saved_model_name so it is always the same file the resume check above loads.
AUTOSAVE_EVERY_SECS = 100
checkpoint = TimeBasedCheckpoint(AUTOSAVE_EVERY_SECS, saved_model_name)
try:
    model.learn(total_timesteps=500_000, callback=checkpoint)
except KeyboardInterrupt:
    # Stopped with Ctrl+C: keep the progress made since the last autosave so
    # the next run resumes from here, then propagate the interrupt.
    model.save(saved_model_name)
    print(f"Abgebrochen. Zwischenstand gespeichert: {saved_model_name}.zip")
    raise

# Final save after training completed.
final_name = f"{saved_model_name}_final"
model.save(final_name)
print(f"Training fertig. Modell gespeichert: {final_name}.zip")