import argparse
import os
import re

from rllib_multiagent_particle_env import CUSTOM_SCENARIOS


def parse_args():
    parser = argparse.ArgumentParser('MADDPG with OpenAI MPE')

    # Environment
    parser.add_argument('--scenario', type=str, default='simple',
                        choices=['simple', 'simple_speaker_listener',
                                 'simple_crypto', 'simple_push',
                                 'simple_tag', 'simple_spread',
                                 'simple_adversary'] + CUSTOM_SCENARIOS,
                        help='name of the scenario script')
    parser.add_argument('--max-episode-len', type=int, default=25,
                        help='maximum episode length')
    parser.add_argument('--num-episodes', type=int, default=60000,
                        help='number of episodes')
    parser.add_argument('--num-adversaries', type=int, default=0,
                        help='number of adversaries')
    parser.add_argument('--good-policy', type=str, default='maddpg',
                        help='policy for good agents')
    parser.add_argument('--adv-policy', type=str, default='maddpg',
                        help='policy for adversaries')

    # Core training parameters
    parser.add_argument('--lr', type=float, default=1e-2,
                        help='learning rate for the Adam optimizer')
    parser.add_argument('--gamma', type=float, default=0.95,
                        help='discount factor')
    # NOTE: 1 iteration = sample_batch_size * num_workers * num_envs_per_worker timesteps
    parser.add_argument('--sample-batch-size', type=int, default=25,
                        help='number of data points sampled per update, per worker')
    parser.add_argument('--train-batch-size', type=int, default=1024,
                        help='number of data points per update')
    parser.add_argument('--n-step', type=int, default=1,
                        help='length of the multi-step value backup')
    parser.add_argument('--num-units', type=int, default=64,
                        help='number of units in the MLP')
    parser.add_argument('--final-reward', type=int, default=-400,
                        help='final reward after which to stop training')

    # Checkpoint
    parser.add_argument('--checkpoint-freq', type=int, default=200,
                        help='save the model every this many iterations')
    parser.add_argument('--local-dir', type=str, default='./logs',
                        help='path to save checkpoints')
    parser.add_argument('--restore', type=str, default=None,
                        help='directory from which training state and model are restored')

    # Parallelism
    parser.add_argument('--num-workers', type=int, default=1,
                        help='number of rollout workers')
    parser.add_argument('--num-envs-per-worker', type=int, default=4,
                        help='number of environments per worker')
    parser.add_argument('--num-gpus', type=int, default=0,
                        help='number of GPUs to use for training')

    return parser.parse_args()
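
# Example invocation (a sketch, not from the original file; the entry-point
# script name below is assumed and may differ in the actual sample):
#
#   python train_maddpg.py --scenario simple_spread --num-episodes 10000 \
#       --num-workers 2 --num-envs-per-worker 4
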
def find_final_checkpoint(start_dir):
    def find(pattern, path):
        # Walk the directory tree and collect every file whose name
        # matches the compiled regex.
        result = []
        for root, _, files in os.walk(path):
            for name in files:
                if pattern.match(name):
                    result.append(os.path.join(root, name))
        return result

    # RLlib names checkpoint files 'checkpoint-<iteration>'.
    cp_pattern = re.compile(r'.*checkpoint-\d+$')
    checkpoint_files = find(cp_pattern, start_dir)

    checkpoint_numbers = []
    for file in checkpoint_files:
        checkpoint_numbers.append(int(file.split('-')[-1]))

    # Raises ValueError if no checkpoint exists under start_dir.
    final_checkpoint_number = max(checkpoint_numbers)

    # The suffix match is unambiguous: any other checkpoint number that ends
    # in the digits of the maximum would have more digits, and so would be
    # larger than the maximum.
    return next(
        checkpoint_file for checkpoint_file in checkpoint_files
        if checkpoint_file.endswith(str(final_checkpoint_number)))
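
# A minimal usage sketch (not part of the original file) tying the two helpers
# together: resume from the newest checkpoint when --restore is given. The
# training loop itself is assumed to live elsewhere and is not shown here.
if __name__ == '__main__':
    args = parse_args()
    if args.restore is not None:
        latest = find_final_checkpoint(args.restore)
        print('Would restore from:', latest)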