57 lines
1.7 KiB
Python
57 lines
1.7 KiB
Python
import os
|
|
import time
|
|
from stable_baselines3 import PPO
|
|
from stable_baselines3.common.callbacks import BaseCallback
|
|
from flower_game_env import FlowerGameEnv
|
|
|
|
# ---- Play area ----
# Screen region passed to the environment as its first argument.
monitor_area = dict(top=120, left=330, width=1900, height=1263)
# Optional UI rectangles to exclude: in pixels or normalised to [0..1].
ui_exclude_rects = list()

# Pause three seconds before the environment is created.
# NOTE(review): presumably gives the user time to bring the game window to
# the foreground -- confirm.
time.sleep(3)

# Environment with a fixed reference size (baseline).
env = FlowerGameEnv(monitor_area, ui_exclude_rects=ui_exclude_rects, ref_size=(1900, 1263))

# Base name of the autosave checkpoint and the zip file it corresponds to.
saved_model_name = "flower_bot"
zip_file = saved_model_name + ".zip"
|
|
|
|
|
|
class TimeBasedCheckpoint(BaseCallback):
    """Save the model to ``<save_prefix>.zip`` every ``save_every_secs`` seconds.

    The elapsed time is checked on every ``_on_step`` call, so a save happens
    on the first step after the interval has passed. Each save goes to the
    same ``save_prefix`` path.

    Args:
        save_every_secs: Minimum number of seconds between two saves.
        save_prefix: Path passed to ``model.save``; the progress message
            reports it with a ``.zip`` suffix.
        verbose: When truthy, print a message before each save.
    """

    def __init__(
        self,
        save_every_secs: float = 60,
        save_prefix: str = saved_model_name,
        verbose: int = 1,
    ):
        super().__init__(verbose)
        self.save_every_secs = save_every_secs
        self.save_prefix = save_prefix
        # Monotonic clock: interval measurement must not be affected by
        # system clock adjustments, which can shift time.time() arbitrarily.
        self._last_save = time.monotonic()

    def _on_training_start(self) -> None:
        """Restart the interval so it counts from training start, not construction."""
        self._last_save = time.monotonic()

    def _on_step(self) -> bool:
        """Save the model if the interval has elapsed; always return True."""
        now = time.monotonic()
        if now - self._last_save >= self.save_every_secs:
            fname = self.save_prefix
            if self.verbose:
                print(f"[Autosave] Saving model to {fname}.zip")
            self.model.save(fname)
            self._last_save = now
        # This callback never returns False, so it never requests a stop.
        return True
|
|
|
|
|
|
# --- Load or start ---
# Resume from the autosave checkpoint when it exists; otherwise create a
# fresh PPO model on the same environment.
if os.path.exists(zip_file):
    print(f"Lade existierendes Modell aus {zip_file}")
    model = PPO.load(zip_file, env=env)  # continue training
else:
    print("Starte neues Modell")
    model = PPO("MultiInputPolicy", env, verbose=0)

# Train with autosave: checkpoint to saved_model_name every 100 seconds.
model.learn(
    total_timesteps=500_000,
    callback=TimeBasedCheckpoint(save_every_secs=100, save_prefix=saved_model_name),
)

# Final snapshot. The name is derived from saved_model_name so it stays in
# sync with the autosave prefix instead of repeating a hard-coded literal.
final_model_name = f"{saved_model_name}_final"
model.save(final_model_name)
print(f"Training fertig. Modell gespeichert: {final_model_name}.zip")
|