Files

41 lines
997 B
Python

import ray
import ray.tune as tune
from ray.rllib import train
import os
import sys
from azureml.core import Run
from utils import callbacks
DEFAULT_RAY_ADDRESS = 'localhost:6379'
if __name__ == "__main__":
# Parse arguments
train_parser = train.create_parser()
args = train_parser.parse_args()
print("Algorithm config:", args.config)
if args.ray_address is None:
args.ray_address = DEFAULT_RAY_ADDRESS
ray.init(address=args.ray_address)
tune.run(
run_or_experiment=args.run,
config={
"env": args.env,
"num_gpus": args.config["num_gpus"],
"num_workers": args.config["num_workers"],
"callbacks": {"on_train_result": callbacks.on_train_result},
"sample_batch_size": 50,
"train_batch_size": 1000,
"num_sgd_iter": 2,
"num_data_loader_buffers": 2,
"model": {"dim": 42},
},
stop=args.stop,
local_dir='./logs')