71 lines
1.8 KiB
Python
71 lines
1.8 KiB
Python
import torch
|
|
from stable_baselines3 import PPO
|
|
from elden_env import EldenRingEnv
|
|
|
|
|
|
def get_system_info():
|
|
with open('/proc/meminfo') as f:
|
|
for line in f:
|
|
if line.startswith('MemTotal'):
|
|
ram_gb = int(line.split()[1]) / (1024 ** 2)
|
|
break
|
|
|
|
if torch.cuda.is_available():
|
|
props = torch.cuda.get_device_properties(0)
|
|
vram_gb = props.total_memory / (1024 ** 3)
|
|
device = "cuda"
|
|
gpu_name = props.name
|
|
else:
|
|
vram_gb = 0
|
|
device = "cpu"
|
|
gpu_name = "None"
|
|
|
|
return ram_gb, vram_gb, device, gpu_name
|
|
|
|
|
|
def get_hyperparams(ram_gb, vram_gb):
    """Pick the PPO rollout length and minibatch size from available memory.

    Args:
        ram_gb: Total host RAM in GiB.
        vram_gb: Total GPU memory in GiB (0 when training on CPU).

    Returns:
        A tuple ``(n_steps, batch_size)``. Both are powers of two and
        ``batch_size <= n_steps`` for every tier, so ``n_steps`` is always a
        whole multiple of ``batch_size``.
    """
    # n_steps sizes the rollout buffer, which lives in host RAM.
    # Observations are taken to be (640, 640, 12) float32, i.e.
    # 640*640*12*4 bytes ~= 18.75 MiB each, so the buffer needs roughly
    #   2048 -> ~37.5 GiB, 512 -> ~9.4 GiB, 256 -> ~4.7 GiB, 128 -> ~2.3 GiB
    # of RAM for observations alone, about 30-60% of each tier's RAM.
    # NOTE(review): this is well above a 20%-of-RAM budget. Confirm the
    # headroom is acceptable while the game runs on the same machine, and
    # lower the tiers if the process is OOM-killed.
    ram_tiers = ((64, 2048), (32, 512), (16, 256))  # (min RAM GiB, n_steps)
    n_steps = next(
        (steps for min_gb, steps in ram_tiers if ram_gb >= min_gb), 128
    )

    # batch_size is the gradient-update minibatch, which lives in VRAM.
    vram_tiers = ((16, 128), (8, 32), (4, 16))  # (min VRAM GiB, batch_size)
    batch_size = next(
        (size for min_gb, size in vram_tiers if vram_gb >= min_gb), 8
    )

    return n_steps, batch_size
|
|
|
|
|
|
def train(total_timesteps=100000, save_path="elden_ai_model"):
    """Train a PPO agent on EldenRingEnv and save it.

    Hardware is probed with get_system_info(), and get_hyperparams() derives
    PPO's ``n_steps`` / ``batch_size`` from the detected RAM and VRAM.

    Args:
        total_timesteps: Number of environment steps passed to
            ``model.learn``.
        save_path: Path passed to ``model.save`` once training finishes.
    """
    ram_gb, vram_gb, device, gpu_name = get_system_info()
    n_steps, batch_size = get_hyperparams(ram_gb, vram_gb)

    print(f"[HW] RAM: {ram_gb:.1f} GB | VRAM: {vram_gb:.1f} GB ({gpu_name}) | Device: {device}")
    print(f"[HW] n_steps={n_steps}, batch_size={batch_size}")

    env = EldenRingEnv()
    try:
        model = PPO(
            "CnnPolicy",
            env,
            verbose=1,
            device=device,
            n_steps=n_steps,
            batch_size=batch_size,
        )

        print("Starting Training...")
        model.learn(total_timesteps=total_timesteps)
        model.save(save_path)
    finally:
        # Release the environment even if model construction or training
        # raises or is interrupted. Assumes EldenRingEnv follows the
        # Gym/Gymnasium Env API (which SB3 requires) and so provides close().
        env.close()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: launch a training run (not executed on import).
    train()
|