diff --git a/src/horizondb/HISTORY.rst b/src/horizondb/HISTORY.rst index 86632532ea6..d68c92975c5 100644 --- a/src/horizondb/HISTORY.rst +++ b/src/horizondb/HISTORY.rst @@ -3,6 +3,11 @@ Release History =============== +1.0.0b5 ++++++++ +* Add support for configuring public access on HorizonDB clusters through `az horizondb create --public-access` and `az horizondb update --public-access`. Supplying an IP address or range automatically creates a firewall rule. +* Add the `az horizondb firewall-rule` command group (`create`, `show`, `list`, `update`, `delete`) to manage cluster firewall rules. + 1.0.0b4 +++++++ * Update validation checks for commands. Add short form arguments for user convenience. diff --git a/src/horizondb/azext_horizondb/_client_factory.py b/src/horizondb/azext_horizondb/_client_factory.py index fe3ddcce26f..9fd760831f0 100644 --- a/src/horizondb/azext_horizondb/_client_factory.py +++ b/src/horizondb/azext_horizondb/_client_factory.py @@ -36,3 +36,7 @@ def resource_client_factory(cli_ctx, subscription_id=None): def cf_horizondb_clusters(cli_ctx, _): return get_horizondb_management_client(cli_ctx).horizon_db_clusters + + +def cf_horizondb_firewall_rules(cli_ctx, _): + return get_horizondb_management_client(cli_ctx).horizon_db_firewall_rules diff --git a/src/horizondb/azext_horizondb/_help.py b/src/horizondb/azext_horizondb/_help.py index 89eb45a7077..44656c6c429 100644 --- a/src/horizondb/azext_horizondb/_help.py +++ b/src/horizondb/azext_horizondb/_help.py @@ -23,6 +23,12 @@ text: az horizondb create --name examplecluster --resource-group exampleresourcegroup --location centralus --administrator-login myadmin --administrator-login-password examplepassword --version 17 --v-cores 4 --replica-count 3 - name: Create a HorizonDB cluster with zone placement policy. text: az horizondb create --name examplecluster --resource-group exampleresourcegroup --location centralus --administrator-login myadmin --administrator-login-password examplepassword --version 17 --v-cores 4 --replica-count 3 --zone-placement-policy Strict + - name: Create a HorizonDB cluster and allow public access from a single IP address (creates a firewall rule). + text: az horizondb create --name examplecluster --resource-group exampleresourcegroup --location centralus --administrator-login myadmin --administrator-login-password examplepassword --version 17 --v-cores 4 --public-access 12.12.12.12 + - name: Create a HorizonDB cluster and allow public access from a range of IP addresses. + text: az horizondb create --name examplecluster --resource-group exampleresourcegroup --location centralus --administrator-login myadmin --administrator-login-password examplepassword --version 17 --v-cores 4 --public-access 12.12.12.0-12.12.12.255 + - name: Create a HorizonDB cluster and allow public access from all IP addresses. + text: az horizondb create --name examplecluster --resource-group exampleresourcegroup --location centralus --administrator-login myadmin --administrator-login-password examplepassword --version 17 --v-cores 4 --public-access All """ @@ -34,6 +40,8 @@ text: az horizondb update --name examplecluster --resource-group exampleresourcegroup --v-cores 6 - name: Assign a parameter group to an existing HorizonDB cluster. text: az horizondb update --name examplecluster --resource-group exampleresourcegroup --parameter-group /subscriptions/{subscriptionId}/resourceGroups/{resourceGroup}/providers/Microsoft.HorizonDb/parameterGroups/{parameterGroup} + - name: Enable public access on an existing HorizonDB cluster (detects your client IP and prompts to create a firewall rule). + text: az horizondb update --name examplecluster --resource-group exampleresourcegroup --public-access Enabled """ @@ -64,3 +72,61 @@ - name: List Azure HorizonDB clusters in a resource group. text: az horizondb list --resource-group exampleresourcegroup """ + + +helps['horizondb firewall-rule'] = """ +type: group +short-summary: Manage firewall rules for an Azure HorizonDB cluster. +long-summary: > + Firewall rules control public access to a HorizonDB cluster and are applied to the cluster's + default pool. Use these commands to allow inbound connections from specific IP addresses or ranges. +""" + + +helps['horizondb firewall-rule create'] = """ +type: command +short-summary: Create a firewall rule for an Azure HorizonDB cluster. +examples: + - name: Create a firewall rule allowing a single IP address. + text: az horizondb firewall-rule create --resource-group exampleresourcegroup --cluster-name examplecluster --name allowclientip --start-ip-address 12.12.12.12 + - name: Create a firewall rule allowing a range of IP addresses. + text: az horizondb firewall-rule create --resource-group exampleresourcegroup --cluster-name examplecluster --name allowrange --start-ip-address 12.12.12.0 --end-ip-address 12.12.12.255 + - name: Create a firewall rule allowing access from all Azure-internal IP addresses. + text: az horizondb firewall-rule create --resource-group exampleresourcegroup --cluster-name examplecluster --name allowazure --start-ip-address 0.0.0.0 --end-ip-address 0.0.0.0 +""" + + +helps['horizondb firewall-rule update'] = """ +type: command +short-summary: Update a firewall rule for an Azure HorizonDB cluster. +examples: + - name: Update the IP range of an existing firewall rule. + text: az horizondb firewall-rule update --resource-group exampleresourcegroup --cluster-name examplecluster --name allowrange --start-ip-address 12.12.12.0 --end-ip-address 12.12.12.128 +""" + + +helps['horizondb firewall-rule show'] = """ +type: command +short-summary: Show the details of a firewall rule for an Azure HorizonDB cluster. +examples: + - name: Show a firewall rule. + text: az horizondb firewall-rule show --resource-group exampleresourcegroup --cluster-name examplecluster --name allowclientip +""" + + +helps['horizondb firewall-rule list'] = """ +type: command +short-summary: List the firewall rules for an Azure HorizonDB cluster. +examples: + - name: List all firewall rules for a cluster. + text: az horizondb firewall-rule list --resource-group exampleresourcegroup --cluster-name examplecluster +""" + + +helps['horizondb firewall-rule delete'] = """ +type: command +short-summary: Delete a firewall rule for an Azure HorizonDB cluster. +examples: + - name: Delete a firewall rule. + text: az horizondb firewall-rule delete --resource-group exampleresourcegroup --cluster-name examplecluster --name allowclientip +""" diff --git a/src/horizondb/azext_horizondb/_params.py b/src/horizondb/azext_horizondb/_params.py index 5aa027a94be..9b86c3c7fd9 100644 --- a/src/horizondb/azext_horizondb/_params.py +++ b/src/horizondb/azext_horizondb/_params.py @@ -14,7 +14,9 @@ get_enum_type) from azure.cli.core.local_context import LocalContextAttribute, LocalContextAction from .utils.validators import ( - validate_replica_count) + validate_replica_count, + public_access_validator, + ip_address_validator) def load_arguments(self, _): # pylint: disable=too-many-statements, too-many-locals @@ -72,6 +74,52 @@ def _horizondb_params(): options_list=['--parameter-group'], help='The resource ID of the parameter group.') + public_access_create_arg_type = CLIArgumentType( + options_list=['--public-access'], + validator=public_access_validator, + help="Determines the public access for the cluster by creating a firewall rule on the " + "default pool. Enter a single IP address or a range of IP addresses (dash-separated, " + "no spaces) to be included in the allowed list of IPs. Specifying 'All' allows public " + "access from any IP (0.0.0.0-255.255.255.255). 'Enabled' detects your current client " + "IP and prompts to allow it. 'None' and 'Disabled' do not create a firewall rule. " + "Acceptable values: 'Enabled', 'Disabled', 'All', 'None', '{startIP}' and " + "'{startIP}-{endIP}' where each IP ranges from 0.0.0.0 to 255.255.255.255.") + + public_access_update_arg_type = CLIArgumentType( + options_list=['--public-access'], + arg_type=get_enum_type(['Enabled', 'Disabled']), + help="Enable or disable public access on the cluster. 'Enabled' detects your current " + "client IP and prompts to create a firewall rule on the default pool. 'Disabled' " + "points you to the 'az horizondb firewall-rule' commands to remove public access.") + + firewall_cluster_name_arg_type = CLIArgumentType( + options_list=['--cluster-name', '-c'], + id_part=None, + help='Name of the HorizonDB cluster.') + + firewall_rule_name_arg_type = CLIArgumentType( + options_list=['--name', '-n'], + id_part=None, + help='The name of the firewall rule.') + + pool_name_arg_type = CLIArgumentType( + options_list=['--pool-name'], + help='The name of the pool the firewall rule targets. Defaults to the default pool.') + + start_ip_address_arg_type = CLIArgumentType( + options_list=['--start-ip-address'], + help='The start IP address of the firewall rule (IPv4). Must be dotted-quad format. Use ' + '0.0.0.0 to represent all Azure-internal IP addresses.') + + end_ip_address_arg_type = CLIArgumentType( + options_list=['--end-ip-address'], + help='The end IP address of the firewall rule (IPv4). Must be dotted-quad format. Use ' + '0.0.0.0 to represent all Azure-internal IP addresses.') + + firewall_rule_description_arg_type = CLIArgumentType( + options_list=['--description'], + help='The description of the firewall rule.') + with self.argument_context('horizondb') as c: c.argument('resource_group_name', arg_type=resource_group_name_type) c.argument('cluster_name', arg_type=cluster_name_arg_type) @@ -85,14 +133,37 @@ def _horizondb_params(): c.argument('replica_count', arg_type=replica_count_arg_type) c.argument('v_cores', arg_type=v_cores_arg_type) c.argument('zone_placement_policy', arg_type=zone_placement_policy_arg_type) + c.argument('public_access', arg_type=public_access_create_arg_type) + c.argument('yes', arg_type=yes_arg_type) with self.argument_context('horizondb update') as c: c.argument('tags', tags_type) c.argument('administrator_login_password', arg_type=administrator_login_password_arg_type) c.argument('v_cores', arg_type=v_cores_arg_type) c.argument('parameter_group', arg_type=parameter_group_arg_type) + c.argument('public_access', arg_type=public_access_update_arg_type) + c.argument('yes', arg_type=yes_arg_type) with self.argument_context('horizondb delete') as c: c.argument('yes', arg_type=yes_arg_type) + with self.argument_context('horizondb firewall-rule') as c: + c.argument('resource_group_name', arg_type=resource_group_name_type) + c.argument('cluster_name', arg_type=firewall_cluster_name_arg_type) + c.argument('firewall_rule_name', arg_type=firewall_rule_name_arg_type) + c.argument('pool_name', arg_type=pool_name_arg_type) + + with self.argument_context('horizondb firewall-rule create') as c: + c.argument('start_ip_address', arg_type=start_ip_address_arg_type, validator=ip_address_validator) + c.argument('end_ip_address', arg_type=end_ip_address_arg_type) + c.argument('description', arg_type=firewall_rule_description_arg_type) + + with self.argument_context('horizondb firewall-rule update') as c: + c.argument('start_ip_address', arg_type=start_ip_address_arg_type, validator=ip_address_validator) + c.argument('end_ip_address', arg_type=end_ip_address_arg_type) + c.argument('description', arg_type=firewall_rule_description_arg_type) + + with self.argument_context('horizondb firewall-rule delete') as c: + c.argument('yes', arg_type=yes_arg_type) + _horizondb_params() diff --git a/src/horizondb/azext_horizondb/cluster_commands.py b/src/horizondb/azext_horizondb/cluster_commands.py index a7ba67a7f9b..6cc4cd3a9ab 100644 --- a/src/horizondb/azext_horizondb/cluster_commands.py +++ b/src/horizondb/azext_horizondb/cluster_commands.py @@ -5,7 +5,8 @@ from azure.cli.core.commands import CliCommandType from azext_horizondb._client_factory import ( - cf_horizondb_clusters) + cf_horizondb_clusters, + cf_horizondb_firewall_rules) from azext_horizondb.utils._transformers import ( table_transform_output) @@ -19,6 +20,11 @@ def load_command_table(self, _): custom_commands = CliCommandType( operations_tmpl='azext_horizondb.commands.custom_commands#{}') + + firewall_rule_custom = CliCommandType( + operations_tmpl='azext_horizondb.commands.firewall_rule_commands#{}', + client_factory=cf_horizondb_firewall_rules) + with self.command_group('horizondb', horizondb_clusters_sdk, custom_command_type=custom_commands, client_factory=cf_horizondb_clusters) as g: @@ -27,3 +33,12 @@ def load_command_table(self, _): g.custom_command('delete', 'horizondb_cluster_delete') g.custom_command('list', 'horizondb_cluster_list') g.show_command('show', 'get') + + with self.command_group('horizondb firewall-rule', firewall_rule_custom, + custom_command_type=firewall_rule_custom, + client_factory=cf_horizondb_firewall_rules) as g: + g.custom_command('create', 'horizondb_firewall_rule_create') + g.custom_command('update', 'horizondb_firewall_rule_update') + g.custom_command('delete', 'horizondb_firewall_rule_delete') + g.custom_show_command('show', 'horizondb_firewall_rule_get') + g.custom_command('list', 'horizondb_firewall_rule_list') diff --git a/src/horizondb/azext_horizondb/commands/custom_commands.py b/src/horizondb/azext_horizondb/commands/custom_commands.py index 0b64828d039..875b6e67022 100644 --- a/src/horizondb/azext_horizondb/commands/custom_commands.py +++ b/src/horizondb/azext_horizondb/commands/custom_commands.py @@ -15,6 +15,7 @@ from ..utils._util import ( check_resource_group, generate_missing_cluster_parameters) +from ..utils._network import resolve_public_access_range logger = get_logger(__name__) @@ -26,6 +27,7 @@ def horizondb_cluster_create(cmd, client, resource_group_name=None, cluster_name tags=None, version=None, replica_count=None, v_cores=None, zone_placement_policy=None, + public_access=None, yes=False, no_wait=False): from azext_horizondb.vendored_sdks.models import HorizonDbCluster, HorizonDbClusterProperties @@ -67,16 +69,46 @@ def horizondb_cluster_create(cmd, client, resource_group_name=None, cluster_name properties=properties, ) - return sdk_no_wait(no_wait, client.begin_create_or_update, - resource_group_name=resource_group_name, - cluster_name=cluster_name, - resource=resource) + result = sdk_no_wait(no_wait, client.begin_create_or_update, + resource_group_name=resource_group_name, + cluster_name=cluster_name, + resource=resource) + # When --public-access supplies an IP range, create a firewall rule once the cluster exists. + # HorizonDB's publicNetworkAccess flag is service-computed (read-only), so a firewall rule is + # the mechanism for opening public access. + if public_access is None: + return result -def horizondb_cluster_update(client, resource_group_name, cluster_name, + cluster = result.result() if hasattr(result, 'result') else result + _apply_public_access(cmd, resource_group_name, cluster_name, public_access, yes) + return cluster + + +def _apply_public_access(cmd, resource_group_name, cluster_name, public_access, yes): + val = str(public_access).lower() + if val == 'disabled': + logger.warning("HorizonDB public network access is managed through firewall rules. To remove " + "public access, delete rules with 'az horizondb firewall-rule delete' " + "(list them with 'az horizondb firewall-rule list').") + return + + start_ip, end_ip = resolve_public_access_range(public_access, yes) + if start_ip == -1 or end_ip == -1: + return + + from .._client_factory import cf_horizondb_firewall_rules + from .firewall_rule_commands import create_firewall_rule + firewall_client = cf_horizondb_firewall_rules(cmd.cli_ctx, None) + create_firewall_rule(cmd, firewall_client, resource_group_name, cluster_name, + start_ip_address=start_ip, end_ip_address=end_ip).result() + + +def horizondb_cluster_update(cmd, client, resource_group_name, cluster_name, administrator_login_password=None, tags=None, v_cores=None, parameter_group=None, + public_access=None, yes=False, no_wait=False): from azext_horizondb.vendored_sdks.models import ( HorizonDbClusterForPatchUpdate, @@ -100,15 +132,23 @@ def horizondb_cluster_update(client, resource_group_name, cluster_name, if cluster_properties: patch_properties["properties"] = HorizonDbClusterPropertiesForPatchUpdate(**cluster_properties) - if not patch_properties: + if not patch_properties and public_access is None: raise ArgumentUsageError("Specify at least one argument to update.") - properties = HorizonDbClusterForPatchUpdate(**patch_properties) + update_result = None + if patch_properties: + properties = HorizonDbClusterForPatchUpdate(**patch_properties) + update_result = sdk_no_wait(no_wait, client.begin_update, + resource_group_name=resource_group_name, + cluster_name=cluster_name, + properties=properties) + + if public_access is not None: + _apply_public_access(cmd, resource_group_name, cluster_name, public_access, yes) - return sdk_no_wait(no_wait, client.begin_update, - resource_group_name=resource_group_name, - cluster_name=cluster_name, - properties=properties) + if update_result is not None: + return update_result + return client.get(resource_group_name=resource_group_name, cluster_name=cluster_name) def horizondb_cluster_delete(cmd, client, resource_group_name, cluster_name, no_wait=False, yes=False): diff --git a/src/horizondb/azext_horizondb/commands/firewall_rule_commands.py b/src/horizondb/azext_horizondb/commands/firewall_rule_commands.py new file mode 100644 index 00000000000..0199117cc4e --- /dev/null +++ b/src/horizondb/azext_horizondb/commands/firewall_rule_commands.py @@ -0,0 +1,137 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# pylint: disable=unused-argument, line-too-long, too-many-locals + +from datetime import datetime +from knack.log import get_logger +from azure.cli.core.azclierror import ArgumentUsageError +from azure.cli.core.util import user_confirmation +from ..utils.validators import validate_resource_group +from ..utils._network import DEFAULT_POOL_NAME + +logger = get_logger(__name__) + + +def _generate_firewall_rule_name(start_ip_address, end_ip_address): + now = datetime.now() + suffix = '{}-{}-{}_{}-{}-{}'.format(now.year, now.month, now.day, now.hour, now.minute, now.second) + if start_ip_address == '0.0.0.0' and end_ip_address == '255.255.255.255': + logger.warning("Configuring firewall rule to accept connections from all IPs...") + return 'AllowAll_{}'.format(suffix) + if start_ip_address == end_ip_address: + logger.warning("Configuring firewall rule to accept connections from '%s'...", start_ip_address) + else: + logger.warning("Configuring firewall rule to accept connections from '%s' to '%s'...", + start_ip_address, end_ip_address) + return 'FirewallIPAddress_{}'.format(suffix) + + +def _build_firewall_rule(start_ip_address, end_ip_address, description=None): + from azext_horizondb.vendored_sdks.models import ( + HorizonDbFirewallRule, + HorizonDbFirewallRuleProperties, + ) + return HorizonDbFirewallRule( + properties=HorizonDbFirewallRuleProperties( + start_ip_address=start_ip_address, + end_ip_address=end_ip_address, + description=description)) + + +def create_firewall_rule(cmd, client, resource_group_name, cluster_name, start_ip_address, + end_ip_address, pool_name=None, firewall_rule_name=None, description=None): + """Create or update a firewall rule on a cluster pool. Shared by the firewall-rule create + command and by ``horizondb create``/``update`` when a --public-access value produces an IP range.""" + pool_name = pool_name or DEFAULT_POOL_NAME + + if end_ip_address is None and start_ip_address is not None: + end_ip_address = start_ip_address + elif start_ip_address is None and end_ip_address is not None: + start_ip_address = end_ip_address + elif start_ip_address is None and end_ip_address is None: + raise ArgumentUsageError( + "Need to pass in a value for either '--start-ip-address' or '--end-ip-address'.") + + if firewall_rule_name is None: + firewall_rule_name = _generate_firewall_rule_name(start_ip_address, end_ip_address) + + resource = _build_firewall_rule(start_ip_address, end_ip_address, description) + + return client.begin_create_or_update( + resource_group_name=resource_group_name, + cluster_name=cluster_name, + pool_name=pool_name, + firewall_rule_name=firewall_rule_name, + resource=resource) + + +def horizondb_firewall_rule_create(cmd, client, resource_group_name, cluster_name, firewall_rule_name=None, + start_ip_address=None, end_ip_address=None, pool_name=None, description=None): + validate_resource_group(resource_group_name) + return create_firewall_rule(cmd, client, resource_group_name, cluster_name, + start_ip_address=start_ip_address, end_ip_address=end_ip_address, + pool_name=pool_name, firewall_rule_name=firewall_rule_name, + description=description) + + +def horizondb_firewall_rule_update(cmd, client, resource_group_name, cluster_name, firewall_rule_name, + start_ip_address=None, end_ip_address=None, pool_name=None, description=None): + validate_resource_group(resource_group_name) + pool_name = pool_name or DEFAULT_POOL_NAME + + existing = client.get( + resource_group_name=resource_group_name, + cluster_name=cluster_name, + pool_name=pool_name, + firewall_rule_name=firewall_rule_name) + existing_props = existing.properties + + new_start = start_ip_address if start_ip_address is not None else existing_props.start_ip_address + new_end = end_ip_address if end_ip_address is not None else existing_props.end_ip_address + new_description = description if description is not None else existing_props.description + + resource = _build_firewall_rule(new_start, new_end, new_description) + + return client.begin_create_or_update( + resource_group_name=resource_group_name, + cluster_name=cluster_name, + pool_name=pool_name, + firewall_rule_name=firewall_rule_name, + resource=resource) + + +def horizondb_firewall_rule_delete(cmd, client, resource_group_name, cluster_name, firewall_rule_name, + pool_name=None, yes=False): + validate_resource_group(resource_group_name) + pool_name = pool_name or DEFAULT_POOL_NAME + if not yes: + user_confirmation( + "Are you sure you want to delete the firewall rule '{0}' in cluster '{1}', resource group " + "'{2}'".format(firewall_rule_name, cluster_name, resource_group_name), yes=yes) + return client.begin_delete( + resource_group_name=resource_group_name, + cluster_name=cluster_name, + pool_name=pool_name, + firewall_rule_name=firewall_rule_name) + + +def horizondb_firewall_rule_get(cmd, client, resource_group_name, cluster_name, firewall_rule_name, pool_name=None): + validate_resource_group(resource_group_name) + pool_name = pool_name or DEFAULT_POOL_NAME + return client.get( + resource_group_name=resource_group_name, + cluster_name=cluster_name, + pool_name=pool_name, + firewall_rule_name=firewall_rule_name) + + +def horizondb_firewall_rule_list(cmd, client, resource_group_name, cluster_name, pool_name=None): + validate_resource_group(resource_group_name) + pool_name = pool_name or DEFAULT_POOL_NAME + return client.list( + resource_group_name=resource_group_name, + cluster_name=cluster_name, + pool_name=pool_name) diff --git a/src/horizondb/azext_horizondb/tests/latest/test_horizondb_firewall_rule.py b/src/horizondb/azext_horizondb/tests/latest/test_horizondb_firewall_rule.py new file mode 100644 index 00000000000..f24f670f5b8 --- /dev/null +++ b/src/horizondb/azext_horizondb/tests/latest/test_horizondb_firewall_rule.py @@ -0,0 +1,216 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +import unittest +from argparse import Namespace +from unittest import mock + +from knack.util import CLIError +from azure.cli.core.azclierror import ArgumentUsageError, InvalidArgumentValueError + +from azext_horizondb.utils.validators import ( + public_access_validator, + ip_address_validator, + _validate_ip, + _validate_ranges_in_ip, + _valid_range, + _validate_start_and_end_ip_address_order, +) +from azext_horizondb.utils._network import ( + DEFAULT_POOL_NAME, + parse_public_access_input, + resolve_public_access_range, +) +from azext_horizondb.commands.firewall_rule_commands import ( + _generate_firewall_rule_name, + create_firewall_rule, + horizondb_firewall_rule_create, + horizondb_firewall_rule_list, +) + + +class HorizonDBPublicAccessValidatorTests(unittest.TestCase): + + def test_public_access_valid_keywords(self): + for value in ['Enabled', 'Disabled', 'All', 'None', 'enabled', 'disabled']: + public_access_validator(Namespace(public_access=value)) + + def test_public_access_valid_single_ip(self): + public_access_validator(Namespace(public_access='12.12.12.12')) + + def test_public_access_valid_range(self): + public_access_validator(Namespace(public_access='12.12.12.0-12.12.12.255')) + + def test_public_access_none_value_is_noop(self): + # An unset value (None) should not raise. + public_access_validator(Namespace(public_access=None)) + + def test_public_access_invalid_keyword(self): + with self.assertRaises(CLIError): + public_access_validator(Namespace(public_access='sometimes')) + + def test_public_access_invalid_ip_octet(self): + with self.assertRaises(CLIError): + public_access_validator(Namespace(public_access='999.0.0.1')) + + def test_public_access_reversed_range_raises(self): + with self.assertRaises(ArgumentUsageError): + public_access_validator(Namespace(public_access='12.12.12.255-12.12.12.0')) + + +class HorizonDBIpAddressValidatorTests(unittest.TestCase): + + def test_valid_start_and_end(self): + ip_address_validator(Namespace(start_ip_address='1.1.1.1', end_ip_address='2.2.2.2')) + + def test_invalid_ip_raises(self): + with self.assertRaises(CLIError): + ip_address_validator(Namespace(start_ip_address='1.1.1', end_ip_address=None)) + + def test_reversed_order_raises(self): + with self.assertRaises(ArgumentUsageError): + ip_address_validator(Namespace(start_ip_address='2.2.2.2', end_ip_address='1.1.1.1')) + + def test_only_start_ip_valid(self): + ip_address_validator(Namespace(start_ip_address='1.1.1.1', end_ip_address=None)) + + def test_range_value_in_single_ip_field_raises_clierror(self): + # A dash-separated range is invalid for the single-IP start/end fields and must surface a + # clean CLIError rather than a raw ValueError. + with self.assertRaises(CLIError): + ip_address_validator(Namespace(start_ip_address='1.1.1.1-2.2.2.2', end_ip_address='3.3.3.3')) + + +class HorizonDBIpHelpersTests(unittest.TestCase): + + def test_valid_range(self): + self.assertTrue(_valid_range('0')) + self.assertTrue(_valid_range('255')) + self.assertFalse(_valid_range('256')) + self.assertFalse(_valid_range('-1')) + self.assertFalse(_valid_range('abc')) + + def test_validate_ranges_in_ip(self): + self.assertTrue(_validate_ranges_in_ip('192.168.0.1')) + self.assertFalse(_validate_ranges_in_ip('192.168.0')) + self.assertFalse(_validate_ranges_in_ip('192.168.0.256')) + + def test_validate_ip_single_and_range(self): + self.assertTrue(_validate_ip('10.0.0.1')) + self.assertTrue(_validate_ip('10.0.0.1-10.0.0.5')) + self.assertFalse(_validate_ip('10.0.0.1-10.0.0.5-10.0.0.9')) + + def test_start_end_order_ok(self): + _validate_start_and_end_ip_address_order('10.0.0.1', '10.0.0.5') + + def test_start_end_order_bad(self): + with self.assertRaises(ArgumentUsageError): + _validate_start_and_end_ip_address_order('10.0.0.5', '10.0.0.1') + + +class HorizonDBParsePublicAccessTests(unittest.TestCase): + + def test_single_ip(self): + self.assertEqual(parse_public_access_input('10.0.0.1'), ('10.0.0.1', '10.0.0.1')) + + def test_range(self): + self.assertEqual(parse_public_access_input('10.0.0.1-10.0.0.9'), ('10.0.0.1', '10.0.0.9')) + + def test_none(self): + self.assertEqual(parse_public_access_input(None), (None, None)) + + def test_too_many_parts(self): + with self.assertRaises(InvalidArgumentValueError): + parse_public_access_input('10.0.0.1-10.0.0.9-10.0.0.20') + + +class HorizonDBResolvePublicAccessRangeTests(unittest.TestCase): + + def test_all(self): + self.assertEqual(resolve_public_access_range('All', yes=True), ('0.0.0.0', '255.255.255.255')) + + def test_none_and_disabled(self): + self.assertEqual(resolve_public_access_range('None', yes=True), (-1, -1)) + self.assertEqual(resolve_public_access_range('Disabled', yes=True), (-1, -1)) + + def test_single_ip(self): + self.assertEqual(resolve_public_access_range('10.0.0.1', yes=True), ('10.0.0.1', '10.0.0.1')) + + def test_range(self): + self.assertEqual(resolve_public_access_range('10.0.0.1-10.0.0.9', yes=True), + ('10.0.0.1', '10.0.0.9')) + + @mock.patch('azext_horizondb.utils._network.get') + def test_enabled_resolves_client_ip(self, mock_get): + response = mock.MagicMock() + response.text = '13.13.13.13' + response.raise_for_status.return_value = None + mock_get.return_value = response + + self.assertEqual(resolve_public_access_range('Enabled', yes=True), + ('13.13.13.13', '13.13.13.13')) + + @mock.patch('azext_horizondb.utils._network.get') + def test_enabled_detection_failure_raises(self, mock_get): + mock_get.side_effect = Exception('network down') + with self.assertRaises(CLIError): + resolve_public_access_range('Enabled', yes=True) + + +class HorizonDBFirewallRuleNameTests(unittest.TestCase): + + def test_allow_all_name(self): + name = _generate_firewall_rule_name('0.0.0.0', '255.255.255.255') + self.assertTrue(name.startswith('AllowAll_')) + + def test_single_ip_name(self): + name = _generate_firewall_rule_name('10.0.0.1', '10.0.0.1') + self.assertTrue(name.startswith('FirewallIPAddress_')) + + def test_range_name(self): + name = _generate_firewall_rule_name('10.0.0.1', '10.0.0.9') + self.assertTrue(name.startswith('FirewallIPAddress_')) + + +class HorizonDBFirewallRuleCommandTests(unittest.TestCase): + + def test_create_defaults_end_ip_to_start_and_targets_default_pool(self): + client = mock.MagicMock() + create_firewall_rule(cmd=None, client=client, resource_group_name='rg', + cluster_name='c', start_ip_address='10.0.0.1', end_ip_address=None, + firewall_rule_name='rule1') + + _, kwargs = client.begin_create_or_update.call_args + self.assertEqual(kwargs['pool_name'], DEFAULT_POOL_NAME) + self.assertEqual(kwargs['firewall_rule_name'], 'rule1') + resource = kwargs['resource'] + self.assertEqual(resource.properties.start_ip_address, '10.0.0.1') + self.assertEqual(resource.properties.end_ip_address, '10.0.0.1') + + def test_create_requires_at_least_one_ip(self): + client = mock.MagicMock() + with self.assertRaises(ArgumentUsageError): + create_firewall_rule(cmd=None, client=client, resource_group_name='rg', + cluster_name='c', start_ip_address=None, end_ip_address=None) + + def test_command_create_passes_description(self): + client = mock.MagicMock() + horizondb_firewall_rule_create(cmd=None, client=client, resource_group_name='rg', + cluster_name='c', firewall_rule_name='rule1', + start_ip_address='10.0.0.1', end_ip_address='10.0.0.9', + description='corp network') + _, kwargs = client.begin_create_or_update.call_args + self.assertEqual(kwargs['resource'].properties.description, 'corp network') + self.assertEqual(kwargs['resource'].properties.end_ip_address, '10.0.0.9') + + def test_list_targets_default_pool(self): + client = mock.MagicMock() + horizondb_firewall_rule_list(cmd=None, client=client, resource_group_name='rg', cluster_name='c') + _, kwargs = client.list.call_args + self.assertEqual(kwargs['pool_name'], DEFAULT_POOL_NAME) + + +if __name__ == '__main__': + unittest.main() diff --git a/src/horizondb/azext_horizondb/tests/latest/test_horizondb_firewall_rule_scenario.py b/src/horizondb/azext_horizondb/tests/latest/test_horizondb_firewall_rule_scenario.py new file mode 100644 index 00000000000..569ae8ce698 --- /dev/null +++ b/src/horizondb/azext_horizondb/tests/latest/test_horizondb_firewall_rule_scenario.py @@ -0,0 +1,71 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +from azure.cli.testsdk.scenario_tests import AllowLargeResponse +from azure.cli.testsdk import ( + JMESPathCheck, + ResourceGroupPreparer, + ScenarioTest, + live_only) +from .constants import DEFAULT_LOCATION, CLUSTER_NAME_PREFIX, CLUSTER_NAME_MAX_LENGTH, PASSWORD_PREFIX + + +# These scenario tests exercise the live control plane end-to-end. They are marked live-only because +# firewall-rule provisioning is not yet captured in a committed recording; run with `--live` to +# generate a cassette. +class HorizonDBFirewallRuleScenarioTest(ScenarioTest): + + location = DEFAULT_LOCATION + + @live_only() + @AllowLargeResponse() + @ResourceGroupPreparer(location=location) + def test_horizondb_firewall_rule_mgmt(self, resource_group): + cluster_name = self.create_random_name(CLUSTER_NAME_PREFIX, CLUSTER_NAME_MAX_LENGTH) + admin_user = 'horizonadmin' + admin_password = self.create_random_name(PASSWORD_PREFIX, 20) + + self.kwargs.update({ + 'rg': resource_group, + 'cluster': cluster_name, + 'location': self.location, + 'admin_user': admin_user, + 'admin_password': admin_password, + 'rule': 'allowrange', + }) + + # Create a cluster and open public access to a single IP; this should auto-create a rule. + self.cmd('horizondb create -g {rg} -n {cluster} --location {location} ' + '--administrator-login {admin_user} --administrator-login-password {admin_password} ' + '--version 17 --v-cores 4 --public-access 12.12.12.12 --yes', + checks=[JMESPathCheck('name', cluster_name)]) + + # The auto-created rule should be present on the default pool. + self.cmd('horizondb firewall-rule list -g {rg} --cluster-name {cluster}', + checks=[JMESPathCheck("length([?properties.startIpAddress=='12.12.12.12'])", 1)]) + + # Create an explicit range rule. + self.cmd('horizondb firewall-rule create -g {rg} --cluster-name {cluster} --name {rule} ' + '--start-ip-address 10.0.0.0 --end-ip-address 10.0.0.255', + checks=[ + JMESPathCheck('name', 'allowrange'), + JMESPathCheck('properties.startIpAddress', '10.0.0.0'), + JMESPathCheck('properties.endIpAddress', '10.0.0.255'), + ]) + + # Show the rule. + self.cmd('horizondb firewall-rule show -g {rg} --cluster-name {cluster} --name {rule}', + checks=[JMESPathCheck('name', 'allowrange')]) + + # Update the rule's end IP. + self.cmd('horizondb firewall-rule update -g {rg} --cluster-name {cluster} --name {rule} ' + '--end-ip-address 10.0.0.128', + checks=[JMESPathCheck('properties.endIpAddress', '10.0.0.128')]) + + # Delete the rule. + self.cmd('horizondb firewall-rule delete -g {rg} --cluster-name {cluster} --name {rule} --yes') + + # Clean up the cluster. + self.cmd('horizondb delete -g {rg} -n {cluster} --yes') diff --git a/src/horizondb/azext_horizondb/utils/_network.py b/src/horizondb/azext_horizondb/utils/_network.py new file mode 100644 index 00000000000..4d1d515d14c --- /dev/null +++ b/src/horizondb/azext_horizondb/utils/_network.py @@ -0,0 +1,83 @@ +# -------------------------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for license information. +# -------------------------------------------------------------------------------------------- + +# pylint: disable=raise-missing-from + +from requests import get +from knack.log import get_logger +from knack.prompting import NoTTYException, prompt_y_n +from knack.util import CLIError +from azure.cli.core.azclierror import InvalidArgumentValueError +from .validators import _validate_ranges_in_ip + +logger = get_logger(__name__) + +# The reserved, service-seeded default pool. HorizonDB firewall rules are pool-scoped +# (.../clusters/{cluster}/pools/{pool}/firewallRules/{name}); Public Preview clusters use a +# single default pool named "DefaultPool". +DEFAULT_POOL_NAME = 'DefaultPool' + +# Service used to detect the caller's outbound public IP address. +IP_ADDRESS_CHECKER = 'https://api.ipify.org' + + +def parse_public_access_input(public_access): + if public_access is not None: + parsed_input = public_access.split('-') + if len(parsed_input) == 1: + return parsed_input[0], parsed_input[0] + if len(parsed_input) == 2: + return parsed_input[0], parsed_input[1] + raise InvalidArgumentValueError( + "incorrect usage: --public-access. Acceptable values are 'All', 'None', '' and " + "'-' where startIP and destinationIP range from 0.0.0.0 to " + "255.255.255.255") + return None, None + + +def _get_user_confirmation(message, yes=False): + if yes: + return True + try: + return bool(prompt_y_n(message)) + except NoTTYException: + raise CLIError('Unable to prompt for confirmation as no tty available. Use --yes.') + + +def _resolve_client_ip_range(yes): + try: + response = get(IP_ADDRESS_CHECKER, timeout=5) + response.raise_for_status() + ip_address = response.text.strip() + if not _validate_ranges_in_ip(ip_address): + raise ValueError('The detection service returned an invalid IPv4 address.') + except Exception as ex: + raise CLIError('Unable to detect your current IP address. Please provide a valid IP address or ' + 'range for the --public-access parameter, or set --public-access Disabled. ' + 'Error: {}'.format(ex)) + + logger.warning('Detected current client IP : %s', ip_address) + if _get_user_confirmation('Do you want to enable access to client {0}'.format(ip_address), yes=yes): + return ip_address, ip_address + if _get_user_confirmation('Do you want to enable access for all IPs', yes=yes): + return '0.0.0.0', '255.255.255.255' + return -1, -1 + + +def resolve_public_access_range(public_access, yes): + """Map a --public-access value to a (start_ip, end_ip) pair. + + Returns (-1, -1) when no firewall rule should be created. ``Enabled`` triggers client-IP + auto-detection because HorizonDB's ``publicNetworkAccess`` flag is service-computed (read-only), + so a firewall rule is the only way to open public access. + """ + val = str(public_access).lower() + if val == 'enabled': + return _resolve_client_ip_range(yes) + if val == 'all': + return '0.0.0.0', '255.255.255.255' + if val in ['none', 'disabled']: + return -1, -1 + return parse_public_access_input(public_access) diff --git a/src/horizondb/azext_horizondb/utils/validators.py b/src/horizondb/azext_horizondb/utils/validators.py index 2441502ddee..bed5944acb5 100644 --- a/src/horizondb/azext_horizondb/utils/validators.py +++ b/src/horizondb/azext_horizondb/utils/validators.py @@ -5,6 +5,7 @@ from knack.prompting import NoTTYException, prompt_pass from knack.util import CLIError +from azure.cli.core.azclierror import ArgumentUsageError from azure.cli.core.commands.validators import ( get_default_location_from_resource_group, validate_tags) from typing import Any, Dict, Iterable, Optional @@ -70,3 +71,59 @@ def validate_replica_count(ns): return if ns.replica_count < 1 or ns.replica_count > 16: raise CLIError('Replica count must be between 1 and 16, inclusive.') + + +def ip_address_validator(ns): + if (ns.end_ip_address and not _validate_ranges_in_ip(ns.end_ip_address)) or \ + (ns.start_ip_address and not _validate_ranges_in_ip(ns.start_ip_address)): + raise CLIError('Invalid IP address. Provide an IPv4 address, for example 12.12.12.12.') + if ns.start_ip_address and ns.end_ip_address: + _validate_start_and_end_ip_address_order(ns.start_ip_address, ns.end_ip_address) + + +def public_access_validator(ns): + if ns.public_access: + val = ns.public_access.lower() + if not (val in ['disabled', 'enabled', 'all', 'none'] or + (len(val.split('-')) == 1 and _validate_ip(val)) or + (len(val.split('-')) == 2 and _validate_ip(val))): + raise CLIError('Invalid value for --public-access. ' + 'Allowed values: \'Disabled\', \'Enabled\', \'All\', \'None\', \'\', ' + 'or \'-\', where each IP ranges from 0.0.0.0 to 255.255.255.255.') + if len(val.split('-')) == 2: + vals = val.split('-') + _validate_start_and_end_ip_address_order(vals[0], vals[1]) + + +def _validate_start_and_end_ip_address_order(start_ip, end_ip): + start_ip_elements = [int(octet) for octet in start_ip.split('.')] + end_ip_elements = [int(octet) for octet in end_ip.split('.')] + + for idx in range(4): + if start_ip_elements[idx] < end_ip_elements[idx]: + break + if start_ip_elements[idx] > end_ip_elements[idx]: + raise ArgumentUsageError('The end IP address is smaller than the start IP address.') + + +def _validate_ip(ips): + parsed_input = ips.split('-') + if len(parsed_input) == 1: + return _validate_ranges_in_ip(parsed_input[0]) + if len(parsed_input) == 2: + return _validate_ranges_in_ip(parsed_input[0]) and _validate_ranges_in_ip(parsed_input[1]) + return False + + +def _validate_ranges_in_ip(ip): + parsed_ip = ip.split('.') + if len(parsed_ip) == 4 and _valid_range(parsed_ip[0]) and _valid_range(parsed_ip[1]) \ + and _valid_range(parsed_ip[2]) and _valid_range(parsed_ip[3]): + return True + return False + + +def _valid_range(addr_range): + if addr_range.isdigit() and 0 <= int(addr_range) <= 255: + return True + return False diff --git a/src/horizondb/setup.py b/src/horizondb/setup.py index 3200409a61d..fe22055b0e4 100644 --- a/src/horizondb/setup.py +++ b/src/horizondb/setup.py @@ -14,7 +14,7 @@ from distutils import log as logger logger.warn("Wheel is not available, disabling bdist_wheel hook") -VERSION = '1.0.0b4' +VERSION = '1.0.0b5' CLASSIFIERS = [ 'Development Status :: 4 - Beta',