import argparse
import os
import re

from rllib_multiagent_particle_env import CUSTOM_SCENARIOS


def parse_args():
    parser = argparse.ArgumentParser('MADDPG with OpenAI MPE')
    # Environment
    parser.add_argument('--scenario', type=str, default='simple',
                        choices=['simple', 'simple_speaker_listener',
                                 'simple_crypto', 'simple_push',
                                 'simple_tag', 'simple_spread',
                                 'simple_adversary'] + CUSTOM_SCENARIOS,
                        help='name of the scenario script')
    parser.add_argument('--max-episode-len', type=int, default=25,
                        help='maximum episode length')
    parser.add_argument('--num-episodes', type=int, default=60000,
                        help='number of episodes')
    parser.add_argument('--num-adversaries', type=int, default=0,
                        help='number of adversaries')
    parser.add_argument('--good-policy', type=str, default='maddpg',
                        help='policy for good agents')
    parser.add_argument('--adv-policy', type=str, default='maddpg',
                        help='policy of adversaries')
    # Core training parameters
    parser.add_argument('--lr', type=float, default=1e-2,
                        help='learning rate for Adam optimizer')
    parser.add_argument('--gamma', type=float, default=0.95,
                        help='discount factor')
    # NOTE: 1 iteration = sample_batch_size * num_workers * num_envs_per_worker
    # timesteps.
    parser.add_argument('--sample-batch-size', type=int, default=25,
                        help='number of data points sampled per update per worker')
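    # Worked example with the defaults in this file: 25 samples/worker
    # * 1 worker * 4 envs/worker = 100 environment timesteps per iteration.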
    parser.add_argument('--train-batch-size', type=int, default=1024,
                        help='number of data points per update')
    parser.add_argument('--n-step', type=int, default=1,
                        help='length of multi-step value backup')
    parser.add_argument('--num-units', type=int, default=64,
                        help='number of units in the MLP')
    parser.add_argument('--final-reward', type=int, default=-400,
                        help='final reward after which to stop training')
    # Checkpoint
    parser.add_argument('--checkpoint-freq', type=int, default=200,
                        help='save the model once every this many iterations')
    parser.add_argument('--local-dir', type=str, default='./logs',
                        help='path to save checkpoints')
    parser.add_argument('--restore', type=str, default=None,
                        help='directory from which training state and model are loaded')
    # Parallelism
    parser.add_argument('--num-workers', type=int, default=1)
    parser.add_argument('--num-envs-per-worker', type=int, default=4)
    parser.add_argument('--num-gpus', type=int, default=0)
    return parser.parse_args()
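

# Example invocation (train.py is a hypothetical entry point that calls
# parse_args() from this module):
#
#     python train.py --scenario simple_spread --num-workers 4 --num-gpus 1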


def find_final_checkpoint(start_dir):
    """Return the path of the highest-numbered checkpoint file under start_dir."""
    def find(pattern, path):
        # Walk the directory tree and collect every file whose name matches.
        result = []
        for root, _, files in os.walk(path):
            for name in files:
                if pattern.match(name):
                    result.append(os.path.join(root, name))
        return result

    # Checkpoint files are named 'checkpoint-<iteration>'.
    cp_pattern = re.compile(r'.*checkpoint-\d+$')
    checkpoint_files = find(cp_pattern, start_dir)

    # The latest checkpoint is the one with the largest trailing iteration
    # number, e.g. '.../checkpoint-200' -> 200.
    return max(checkpoint_files,
               key=lambda filename: int(filename.split('-')[-1]))
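

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original file): parse the
    # CLI flags and, if --restore points at a previous run, resolve its latest
    # checkpoint.
    args = parse_args()
    print(vars(args))
    if args.restore is not None:
        print('restoring from', find_final_checkpoint(args.restore))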