diff --git a/src/confcom/azext_confcom/_help.py b/src/confcom/azext_confcom/_help.py index 1244f749474..80ba15b9b90 100644 --- a/src/confcom/azext_confcom/_help.py +++ b/src/confcom/azext_confcom/_help.py @@ -67,7 +67,7 @@ - name: --diff -d type: boolean - short-summary: 'When combined with an input ARM Template, verifies the policy present in the ARM Template under "ccePolicy" and the containers within the ARM Template are compatible. If they are incompatible, a list of reasons is given and the exit status code will be 2.' + short-summary: 'When combined with an input ARM Template file (or YAML file for Virtual Node policy generation), verifies the policy present in the ARM Template under "ccePolicy" and the containers within the file are compatible. If they are incompatible, a list of reasons is given and the exit status code will be 2.' - name: --outraw type: boolean diff --git a/src/confcom/azext_confcom/_params.py b/src/confcom/azext_confcom/_params.py index a7c11dbb1af..1429abb7664 100644 --- a/src/confcom/azext_confcom/_params.py +++ b/src/confcom/azext_confcom/_params.py @@ -100,7 +100,7 @@ def load_arguments(self, _): "diff", options_list=("--diff", "-d"), required=False, - help="Compare the CCE Policy field in the ARM Template to the containers in the ARM Template and make sure they are compatible", + help="Compare the CCE Policy field in the ARM Template file (or YAML file for Virtual Node) to the containers in the file and make sure they are compatible", validator=validate_diff ) c.argument( diff --git a/src/confcom/azext_confcom/custom.py b/src/confcom/azext_confcom/custom.py index b03421adeba..d127404e03a 100644 --- a/src/confcom/azext_confcom/custom.py +++ b/src/confcom/azext_confcom/custom.py @@ -110,6 +110,7 @@ def acipolicygen_confcom( debug_mode=debug_mode, disable_stdio=disable_stdio, approve_wildcards=approve_wildcards, + diff_mode=diff ) elif image_name: container_group_policies = security_policy.load_policy_from_image_name( @@ -121,6 +122,7 @@ def acipolicygen_confcom( debug_mode=debug_mode, disable_stdio=disable_stdio, approve_wildcards=approve_wildcards, + diff_mode=diff ) virtual_node_yaml = list(os_util.load_multiple_yaml_from_file(virtual_node_yaml_path)) filtered_yaml = filter_non_pod_resources(virtual_node_yaml) diff --git a/src/confcom/azext_confcom/os_util.py b/src/confcom/azext_confcom/os_util.py index 8541a0dd24d..8fb9edeeae7 100644 --- a/src/confcom/azext_confcom/os_util.py +++ b/src/confcom/azext_confcom/os_util.py @@ -30,8 +30,8 @@ def base64_to_str(data: str) -> str: try: data_bytes = base64.b64decode(data) data_str = data_bytes.decode("ascii") - except binascii.Error: - eprint(f"Invalid base64 string: {data}") + except binascii.Error as e: + raise ValueError(f"Invalid base64 string: {data}") from e return data_str diff --git a/src/confcom/azext_confcom/security_policy.py b/src/confcom/azext_confcom/security_policy.py index 93d4470ab56..a30539b2cc9 100644 --- a/src/confcom/azext_confcom/security_policy.py +++ b/src/confcom/azext_confcom/security_policy.py @@ -533,6 +533,7 @@ def load_policy_from_arm_template_str( debug_mode: bool = False, disable_stdio: bool = False, approve_wildcards: bool = False, + diff_mode: bool = False ) -> List[AciPolicy]: """Function that converts ARM template string to an ACI Policy""" input_arm_json = os_util.load_json_from_str(template_data) @@ -600,9 +601,17 @@ def load_policy_from_arm_template_str( if init_container_list: container_list.extend(init_container_list) - existing_containers, fragments = extract_confidential_properties( - container_group_properties - ) + try: + existing_containers, fragments = extract_confidential_properties( + container_group_properties + ) + except ValueError as e: + if diff_mode: + # In diff mode, we raise an error if the base64 policy is malformed + eprint(f"Unable to decode existing policy. Please check the base64 encoding.\n{e}") + else: + # In non-diff mode, we ignore the error and proceed without the policy + existing_containers, fragments = ([], []) rego_fragments = copy.deepcopy(config.DEFAULT_REGO_FRAGMENTS) if infrastructure_svn: @@ -693,6 +702,7 @@ def load_policy_from_arm_template_file( debug_mode: bool = False, disable_stdio: bool = False, approve_wildcards: bool = False, + diff_mode: bool = False ) -> List[AciPolicy]: """Utility function: generate policy object from given arm template and parameter file paths""" input_arm_json = os_util.load_str_from_file(template_path) @@ -702,6 +712,7 @@ def load_policy_from_arm_template_file( return load_policy_from_arm_template_str( input_arm_json, input_parameter_json, infrastructure_svn, debug_mode=debug_mode, disable_stdio=disable_stdio, approve_wildcards=approve_wildcards, + diff_mode=diff_mode ) @@ -849,14 +860,16 @@ def load_policy_from_virtual_node_yaml_file( virtual_node_yaml_path: str, debug_mode: bool = False, disable_stdio: bool = False, - approve_wildcards: bool = False + approve_wildcards: bool = False, + diff_mode: bool = False ) -> List[AciPolicy]: yaml_contents_str = os_util.load_str_from_file(virtual_node_yaml_path) return load_policy_from_virtual_node_yaml_str( yaml_contents_str, debug_mode=debug_mode, disable_stdio=disable_stdio, - approve_wildcards=approve_wildcards + approve_wildcards=approve_wildcards, + diff_mode=diff_mode ) @@ -864,7 +877,8 @@ def load_policy_from_virtual_node_yaml_str( yaml_contents_str: List[str], debug_mode: bool = False, disable_stdio: bool = False, - approve_wildcards: bool = False + approve_wildcards: bool = False, + diff_mode: bool = False ) -> List[AciPolicy]: """ Load a virtual node yaml file and generate a policy object @@ -892,9 +906,16 @@ def load_policy_from_virtual_node_yaml_str( metadata = case_insensitive_dict_get(yaml, "metadata") annotations = case_insensitive_dict_get(metadata, config.VIRTUAL_NODE_YAML_ANNOTATIONS) existing_policy = case_insensitive_dict_get(annotations, config.VIRTUAL_NODE_YAML_POLICY) - if existing_policy: - (existing_containers, existing_fragments) = decompose_confidential_properties(existing_policy) - + try: + if existing_policy: + existing_containers, existing_fragments = decompose_confidential_properties(existing_policy) + except ValueError as e: + if diff_mode: + # In diff mode, we raise an error if the base64 policy is malformed + eprint(f"Unable to decode existing policy. Please check the base64 encoding.\n{e}") + else: + # In non-diff mode, we ignore the error and proceed without the policy + existing_containers, existing_fragments = ([], []) # because there are many ways to get pod information, we normalize them so the interface is the same normalized_yaml = convert_to_pod_spec(yaml) @@ -946,7 +967,10 @@ def load_policy_from_virtual_node_yaml_str( mount_path = case_insensitive_dict_get(mount, "mountPath") # find the corresponding volume - volume = next((vol for vol in volumes if case_insensitive_dict_get(vol, "name") == mount_name), None) + volume = next( + (vol for vol in volumes if case_insensitive_dict_get(vol, "name") == mount_name), + None + ) # determine if this volume is one of the read-only types read_only_default = any(key in read_only_types for key in volume.keys()) @@ -955,14 +979,17 @@ def load_policy_from_virtual_node_yaml_str( # log warning if readOnly is explicitly set to false for a read-only volume type if case_insensitive_dict_get(mount, "readOnly") is False: logger.warning( - "Volume '%s' in container '%s' is of a type that requires readOnly access (%s), but readOnly: false was specified. Enforcing readOnly: true for policy generation.", - mount_name, case_insensitive_dict_get(container, "name"), ', '.join(read_only_types) + "Volume '%s' in container '%s' is of a type that requires readOnly access (%s), " + "but readOnly: false was specified. Enforcing readOnly: true for policy generation.", + mount_name, + case_insensitive_dict_get(container, "name"), + ', '.join(read_only_types) ) mount_readonly = True else: # use the readOnly field or default to False for non-read-only volumes mount_readonly = case_insensitive_dict_get(mount, "readOnly") or False - + mounts.append({ config.ACI_FIELD_CONTAINERS_MOUNTS_TYPE: config.ACI_FIELD_YAML_MOUNT_TYPE, config.ACI_FIELD_CONTAINERS_MOUNTS_PATH: mount_path,