mirror of
https://github.com/Azure/MachineLearningNotebooks.git
synced 2025-12-20 09:37:04 -05:00
238 lines
7.7 KiB
Python
238 lines
7.7 KiB
Python
import sys
|
|
import csv
|
|
from azure.mgmt.network import NetworkManagementClient
|
|
|
|
|
|
def check_port_in_port_range(expected_port: str,
|
|
dest_port_range: str):
|
|
"""
|
|
Check if a port is within a port range
|
|
Port range maybe like *, 8080 or 8888-8889
|
|
"""
|
|
|
|
if dest_port_range == '*':
|
|
return True
|
|
|
|
dest_ports = dest_port_range.split('-')
|
|
|
|
if len(dest_ports) == 1 and \
|
|
int(dest_ports[0]) == int(expected_port):
|
|
return True
|
|
|
|
if len(dest_ports) == 2 and \
|
|
int(dest_ports[0]) <= int(expected_port) and \
|
|
int(dest_ports[1]) >= int(expected_port):
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def check_port_in_destination_port_ranges(expected_port: str,
|
|
dest_port_ranges: list):
|
|
"""
|
|
Check if a port is within a given list of port ranges
|
|
i.e. check if port 8080 is in port ranges of 22,80,8080-8090,443
|
|
"""
|
|
|
|
for dest_port_range in dest_port_ranges:
|
|
if check_port_in_port_range(expected_port, dest_port_range) is True:
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def check_ports_in_destination_port_ranges(expected_ports: list,
|
|
dest_port_ranges: list):
|
|
"""
|
|
Check if all ports in a given port list are within a given list
|
|
of port ranges
|
|
i.e. check if port 8080,8081 are in port ranges of 22,80,8080-8090,443
|
|
"""
|
|
|
|
for expected_port in expected_ports:
|
|
if check_port_in_destination_port_ranges(
|
|
expected_port, dest_port_ranges) is False:
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
def check_source_address_prefix(source_address_prefix: str):
|
|
"""Check if source address prefix is BatchNodeManagement or default"""
|
|
|
|
required_prefix = 'BatchNodeManagement'
|
|
default_prefix = 'default'
|
|
|
|
if source_address_prefix.lower() == required_prefix.lower() or \
|
|
source_address_prefix.lower() == default_prefix.lower():
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def check_protocol(protocol: str):
|
|
"""Check if protocol is supported - Tcp/Any"""
|
|
|
|
required_protocol = 'Tcp'
|
|
any_protocol = 'Any'
|
|
|
|
if required_protocol.lower() == protocol.lower() or \
|
|
any_protocol.lower() == protocol.lower():
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def check_direction(direction: str):
|
|
"""Check if port direction is inbound"""
|
|
|
|
required_direction = 'Inbound'
|
|
|
|
if required_direction.lower() == direction.lower():
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def check_provisioning_state(provisioning_state: str):
|
|
"""Check if the provisioning state is succeeded"""
|
|
|
|
required_provisioning_state = 'Succeeded'
|
|
|
|
if required_provisioning_state.lower() == provisioning_state.lower():
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def check_rule_for_Azure_ML(rule):
|
|
"""Check if the ports required for Azure Machine Learning are open"""
|
|
|
|
required_ports = ['29876', '29877']
|
|
|
|
if check_source_address_prefix(rule.source_address_prefix) is False:
|
|
return False
|
|
|
|
if check_protocol(rule.protocol) is False:
|
|
return False
|
|
|
|
if check_direction(rule.direction) is False:
|
|
return False
|
|
|
|
if check_provisioning_state(rule.provisioning_state) is False:
|
|
return False
|
|
|
|
if rule.destination_port_range is not None:
|
|
if check_ports_in_destination_port_ranges(
|
|
required_ports,
|
|
[rule.destination_port_range]) is False:
|
|
return False
|
|
else:
|
|
if check_ports_in_destination_port_ranges(
|
|
required_ports,
|
|
rule.destination_port_ranges) is False:
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
def check_vnet_security_rules(auth_object,
|
|
vnet_subscription_id,
|
|
vnet_resource_group,
|
|
vnet_name,
|
|
save_to_file=False):
|
|
"""
|
|
Check all the rules of virtual network if required ports for Azure Machine
|
|
Learning are open
|
|
"""
|
|
|
|
network_client = NetworkManagementClient(
|
|
auth_object,
|
|
vnet_subscription_id)
|
|
|
|
# get the vnet
|
|
vnet = network_client.virtual_networks.get(
|
|
resource_group_name=vnet_resource_group,
|
|
virtual_network_name=vnet_name)
|
|
|
|
vnet_location = vnet.location
|
|
vnet_info = []
|
|
|
|
if vnet.subnets is None or len(vnet.subnets) == 0:
|
|
print('WARNING: No subnet found for VNet:', vnet_name)
|
|
|
|
# for each subnet of the vnet
|
|
for subnet in vnet.subnets:
|
|
if subnet.network_security_group is None:
|
|
print('WARNING: No network security group found for subnet.',
|
|
'Subnet',
|
|
subnet.id.split("/")[-1])
|
|
else:
|
|
# get all the rules
|
|
network_security_group_name = \
|
|
subnet.network_security_group.id.split("/")[-1]
|
|
network_security_group_resource_group_name = \
|
|
subnet.network_security_group.id.split("/")[4]
|
|
network_security_group_subscription_id = \
|
|
subnet.network_security_group.id.split("/")[2]
|
|
|
|
security_rules = list(network_client.security_rules.list(
|
|
network_security_group_resource_group_name,
|
|
network_security_group_name))
|
|
|
|
rule_matched = None
|
|
for rule in security_rules:
|
|
rule_info = []
|
|
# add vnet details
|
|
rule_info.append(vnet_name)
|
|
rule_info.append(vnet_subscription_id)
|
|
rule_info.append(vnet_resource_group)
|
|
rule_info.append(vnet_location)
|
|
# add subnet details
|
|
rule_info.append(subnet.id.split("/")[-1])
|
|
rule_info.append(network_security_group_name)
|
|
rule_info.append(network_security_group_subscription_id)
|
|
rule_info.append(network_security_group_resource_group_name)
|
|
# add rule details
|
|
rule_info.append(rule.priority)
|
|
rule_info.append(rule.name)
|
|
rule_info.append(rule.source_address_prefix)
|
|
if rule.destination_port_range is not None:
|
|
rule_info.append(rule.destination_port_range)
|
|
else:
|
|
rule_info.append(rule.destination_port_ranges)
|
|
rule_info.append(rule.direction)
|
|
rule_info.append(rule.provisioning_state)
|
|
vnet_info.append(rule_info)
|
|
|
|
if check_rule_for_Azure_ML(rule) is True:
|
|
rule_matched = rule
|
|
|
|
if rule_matched is not None:
|
|
print("INFORMATION: Rule matched with required ports. Subnet:",
|
|
subnet.id.split("/")[-1], "Rule:", rule.name)
|
|
else:
|
|
print("WARNING: No rule matched with required ports. Subnet:",
|
|
subnet.id.split("/")[-1])
|
|
|
|
if save_to_file is True:
|
|
file_name = vnet_name + ".csv"
|
|
with open(file_name, mode='w') as vnet_rule_file:
|
|
vnet_rule_file_writer = csv.writer(
|
|
vnet_rule_file,
|
|
delimiter=',',
|
|
quotechar='"',
|
|
quoting=csv.QUOTE_MINIMAL)
|
|
header = ['VNet_Name', 'VNet_Subscription_ID',
|
|
'VNet_Resource_Group', 'VNet_Location',
|
|
'Subnet_Name', 'NSG_Name',
|
|
'NSG_Subscription_ID', 'NSG_Resource_Group',
|
|
'Rule_Priority', 'Rule_Name', 'Rule_Source',
|
|
'Rule_Destination_Ports', 'Rule_Direction',
|
|
'Rule_Provisioning_State']
|
|
vnet_rule_file_writer.writerow(header)
|
|
vnet_rule_file_writer.writerows(vnet_info)
|
|
|
|
print("INFORMATION: Network security group rules for your virtual \
|
|
network are saved in file", file_name)
|