# Mirror of https://github.com/Azure/MachineLearningNotebooks.git
# Synced 2025-12-20 09:37:04 -05:00
# Original listing: 36 lines, 1023 B, Python
from ray_on_aml.core import Ray_On_AML

import ray.tune as tune
from ray.rllib import train

from utils import callbacks

# Script entry point: obtain a Ray handle through Ray_On_AML and, when one is
# returned, launch an RLlib experiment via ``tune.run``. When no handle is
# returned the script only prints a message and exits.
if __name__ == "__main__":

    ray_on_aml = Ray_On_AML()
    ray = ray_on_aml.getRay()

    if ray:  # in the headnode
        # Parse arguments with RLlib's own command-line parser; the code below
        # reads `run`, `env`, `config` and `stop` from the parsed result.
        train_parser = train.create_parser()
        args = train_parser.parse_args()
        print("Algorithm config:", args.config)

        tune.run(
            run_or_experiment=args.run,
            config={
                "env": args.env,
                # NOTE(review): plain indexing -- presumably `args.config` is
                # a dict, in which case "num_gpus" and "num_workers" must be
                # supplied on the command line or a KeyError is raised here.
                "num_gpus": args.config["num_gpus"],
                "num_workers": args.config["num_workers"],
                "callbacks": {"on_train_result": callbacks.on_train_result},
                # The remaining settings are fixed in this script rather than
                # taken from the command line.
                "sample_batch_size": 50,
                "train_batch_size": 1000,
                "num_sgd_iter": 2,
                "num_data_loader_buffers": 2,
                "model": {"dim": 42},
            },
            stop=args.stop,
            local_dir="./logs",
        )
    else:
        # No Ray handle was returned: nothing to launch from this process.
        print("in worker node")