Initial commit — Elden Ring RL agent
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
70
train.py
Normal file
70
train.py
Normal file
@@ -0,0 +1,70 @@
|
||||
import torch
|
||||
from stable_baselines3 import PPO
|
||||
from elden_env import EldenRingEnv
|
||||
|
||||
|
||||
def get_system_info():
|
||||
with open('/proc/meminfo') as f:
|
||||
for line in f:
|
||||
if line.startswith('MemTotal'):
|
||||
ram_gb = int(line.split()[1]) / (1024 ** 2)
|
||||
break
|
||||
|
||||
if torch.cuda.is_available():
|
||||
props = torch.cuda.get_device_properties(0)
|
||||
vram_gb = props.total_memory / (1024 ** 3)
|
||||
device = "cuda"
|
||||
gpu_name = props.name
|
||||
else:
|
||||
vram_gb = 0
|
||||
device = "cpu"
|
||||
gpu_name = "None"
|
||||
|
||||
return ram_gb, vram_gb, device, gpu_name
|
||||
|
||||
|
||||
def get_hyperparams(ram_gb, vram_gb):
|
||||
# n_steps: rollout buffer size, lives in RAM.
|
||||
# Each obs is (640,640,12) float32 = ~18.75MB. Use at most 20% of RAM.
|
||||
if ram_gb >= 64:
|
||||
n_steps = 2048
|
||||
elif ram_gb >= 32:
|
||||
n_steps = 512
|
||||
elif ram_gb >= 16:
|
||||
n_steps = 256
|
||||
else:
|
||||
n_steps = 128
|
||||
|
||||
# batch_size: minibatch for the gradient update, lives in VRAM.
|
||||
if vram_gb >= 16:
|
||||
batch_size = 128
|
||||
elif vram_gb >= 8:
|
||||
batch_size = 32
|
||||
elif vram_gb >= 4:
|
||||
batch_size = 16
|
||||
else:
|
||||
batch_size = 8
|
||||
|
||||
return n_steps, batch_size
|
||||
|
||||
|
||||
def train():
|
||||
ram_gb, vram_gb, device, gpu_name = get_system_info()
|
||||
n_steps, batch_size = get_hyperparams(ram_gb, vram_gb)
|
||||
|
||||
print(f"[HW] RAM: {ram_gb:.1f} GB | VRAM: {vram_gb:.1f} GB ({gpu_name}) | Device: {device}")
|
||||
print(f"[HW] n_steps={n_steps}, batch_size={batch_size}")
|
||||
|
||||
env = EldenRingEnv()
|
||||
model = PPO("CnnPolicy", env, verbose=1,
|
||||
device=device,
|
||||
n_steps=n_steps,
|
||||
batch_size=batch_size)
|
||||
|
||||
print("Starting Training...")
|
||||
model.learn(total_timesteps=100000)
|
||||
model.save("elden_ai_model")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
Reference in New Issue
Block a user