mirror of
https://github.com/Azure/MachineLearningNotebooks.git
synced 2025-12-20 01:27:06 -05:00
41 lines
997 B
Python
41 lines
997 B
Python
import ray
|
|
import ray.tune as tune
|
|
from ray.rllib import train
|
|
|
|
import os
|
|
import sys
|
|
|
|
from azureml.core import Run
|
|
from utils import callbacks
|
|
|
|
# Fallback address passed to ray.init() when args.ray_address is None
# (presumably a Ray head node already running on this machine -- confirm).
DEFAULT_RAY_ADDRESS = 'localhost:6379'
|
|
|
|
def select_resource_settings(cli_config):
    """Pick the resource-related keys out of the parsed ``--config`` value.

    Args:
        cli_config: Mapping parsed from the command-line config option, or
            ``None``/empty when nothing was supplied.

    Returns:
        A new dict containing ``num_gpus`` and/or ``num_workers`` for
        whichever of the two keys is present in ``cli_config``. Missing keys
        are simply left out instead of raising ``KeyError``.
    """
    if not cli_config:
        return {}
    return {
        key: cli_config[key]
        for key in ("num_gpus", "num_workers")
        if key in cli_config
    }


if __name__ == "__main__":
    # Parse arguments
    train_parser = train.create_parser()

    args = train_parser.parse_args()
    print("Algorithm config:", args.config)

    # Use the default address when none was parsed from the command line.
    if args.ray_address is None:
        args.ray_address = DEFAULT_RAY_ADDRESS

    ray.init(address=args.ray_address)

    tune.run(
        run_or_experiment=args.run,
        config={
            "env": args.env,
            # Forward only the resource settings that were actually supplied;
            # omitted ones presumably fall back to the trainer's own defaults
            # rather than crashing here with KeyError.
            **select_resource_settings(args.config),
            "callbacks": {"on_train_result": callbacks.on_train_result},
            "sample_batch_size": 50,
            "train_batch_size": 1000,
            "num_sgd_iter": 2,
            "num_data_loader_buffers": 2,
            "model": {"dim": 42},
        },
        stop=args.stop,
        local_dir='./logs')
|