From 516fe6582d83652cc56e37dc7e3d80be866d43fc Mon Sep 17 00:00:00 2001 From: zhongjie Date: Mon, 8 Jan 2024 23:14:45 +0800 Subject: [PATCH] vmsdk/python/tests: add tests for TDX This patch mainly adds some tests for TDX. And it refactors some corresponding code accordingly. Signed-off-by: zhongjie --- .github/workflows/vmsdk-test-python.yaml | 10 ++- common/python/cctrusted_base/api.py | 7 +- common/python/cctrusted_base/imr.py | 20 +---- common/python/cctrusted_base/tdx/rtmr.py | 23 +++++ common/python/cctrusted_base/tpm/pcr.py | 12 +++ vmsdk/python/cc_imr_cli.py | 3 +- vmsdk/python/cc_quote_cli.py | 2 +- vmsdk/python/cctrusted_vm/cvm.py | 3 +- vmsdk/python/cctrusted_vm/sdk.py | 11 ++- vmsdk/python/tests/conftest.py | 66 ++++++++++++++ vmsdk/python/tests/tdx_check.py | 77 +++++++++++++++++ vmsdk/python/tests/test_sdk.py | 104 +++++++++++------------ 12 files changed, 259 insertions(+), 79 deletions(-) create mode 100644 common/python/cctrusted_base/tdx/rtmr.py create mode 100644 common/python/cctrusted_base/tpm/pcr.py create mode 100644 vmsdk/python/tests/conftest.py create mode 100644 vmsdk/python/tests/tdx_check.py diff --git a/.github/workflows/vmsdk-test-python.yaml b/.github/workflows/vmsdk-test-python.yaml index f9509dab..06a074fa 100644 --- a/.github/workflows/vmsdk-test-python.yaml +++ b/.github/workflows/vmsdk-test-python.yaml @@ -19,4 +19,12 @@ jobs: - name: Run PyTest for VMSDK run: | set -ex - sudo su -c "source setupenv.sh && python3 -m pytest -v ./vmsdk/python/tests/test_sdk.py" + # Set the "PYTHONDONTWRITEBYTECODE" and "no:cacheprovider" to prevent + # generated some intermediate files by root. Othwerwise, these + # files will fail the action/checkout in the next round of running + # due to the permission issue. + sudo su -c "source setupenv.sh && \ + pushd vmsdk/python/tests && \ + export PYTHONDONTWRITEBYTECODE=1 && \ + python3 -m pytest -p no:cacheprovider -v test_sdk.py && \ + popd" diff --git a/common/python/cctrusted_base/api.py b/common/python/cctrusted_base/api.py index be3415c7..91628b89 100644 --- a/common/python/cctrusted_base/api.py +++ b/common/python/cctrusted_base/api.py @@ -66,7 +66,12 @@ def get_measurement(self, imr_select:[int, int]) -> TcgIMR: raise NotImplementedError("Inherited SDK class should implement this.") @abstractmethod - def get_quote(self, nonce: bytearray, data: bytearray, extraArgs=None) -> Quote: + def get_quote( + self, + nonce: bytearray = None, + data: bytearray = None, + extraArgs = None + ) -> Quote: """Get the quote for given nonce and data. The quote is signing of attestation data (IMR values or hashes of IMR diff --git a/common/python/cctrusted_base/imr.py b/common/python/cctrusted_base/imr.py index d1fd357d..53ea0220 100644 --- a/common/python/cctrusted_base/imr.py +++ b/common/python/cctrusted_base/imr.py @@ -3,7 +3,7 @@ """ from abc import ABC, abstractmethod -from cctrusted_base.tcg import TcgDigest, TcgAlgorithmRegistry +from cctrusted_base.tcg import TcgDigest class TcgIMR(ABC): """Common Integrated Measurement Register class.""" @@ -56,21 +56,3 @@ def is_valid(self): """ return self._index != TcgIMR._INVALID_IMR_INDEX and \ self._index <= self.max_index - -class TdxRTMR(TcgIMR): - """RTMR class defined for Intel TDX.""" - - @property - def max_index(self): - return 3 - - def __init__(self, index, digest_hash): - super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384, - digest_hash) - -class TpmPCR(TcgIMR): - """PCR class defined for TPM""" - - @property - def max_index(self): - return 23 diff --git a/common/python/cctrusted_base/tdx/rtmr.py b/common/python/cctrusted_base/tdx/rtmr.py new file mode 100644 index 00000000..27a54ef4 --- /dev/null +++ b/common/python/cctrusted_base/tdx/rtmr.py @@ -0,0 +1,23 @@ +""" +RTMR (Runtime Measurement Register). +""" + +from cctrusted_base.imr import TcgIMR +from cctrusted_base.tcg import TcgAlgorithmRegistry + +class TdxRTMR(TcgIMR): + """RTMR class defined for Intel TDX.""" + + RTMR_COUNT = 4 + """Intel TDX TDREPORT provides the 4 measurement registers by default.""" + + RTMR_LENGTH_BY_BYTES = 48 + """RTMR length by bytes.""" + + @property + def max_index(self): + return 3 + + def __init__(self, index, digest_hash): + super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384, + digest_hash) diff --git a/common/python/cctrusted_base/tpm/pcr.py b/common/python/cctrusted_base/tpm/pcr.py new file mode 100644 index 00000000..e33d61c2 --- /dev/null +++ b/common/python/cctrusted_base/tpm/pcr.py @@ -0,0 +1,12 @@ +""" +PCR (Platform Configuration Register). +""" + +from cctrusted_base.imr import TcgIMR + +class TpmPCR(TcgIMR): + """PCR class defined for TPM""" + + @property + def max_index(self): + return 23 diff --git a/vmsdk/python/cc_imr_cli.py b/vmsdk/python/cc_imr_cli.py index ce181840..bf4452f0 100644 --- a/vmsdk/python/cc_imr_cli.py +++ b/vmsdk/python/cc_imr_cli.py @@ -12,7 +12,8 @@ count = CCTrustedVmSdk.inst().get_measurement_count() for index in range(CCTrustedVmSdk.inst().get_measurement_count()): alg = CCTrustedVmSdk.inst().get_default_algorithms() - digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id]) + imr = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id]) + digest_obj = imr.digest(alg.alg_id) hash_str = "" for hash_item in digest_obj.hash: diff --git a/vmsdk/python/cc_quote_cli.py b/vmsdk/python/cc_quote_cli.py index 34bb97a0..1e4a8587 100644 --- a/vmsdk/python/cc_quote_cli.py +++ b/vmsdk/python/cc_quote_cli.py @@ -42,7 +42,7 @@ def main(): level=logging.NOTSET, format="%(name)s %(levelname)-8s %(message)s" ) - quote = CCTrustedVmSdk.inst().get_quote(None, None, None) + quote = CCTrustedVmSdk.inst().get_quote() if quote is not None: quote.dump(args.out_format == OUT_FORMAT_RAW) else: diff --git a/vmsdk/python/cctrusted_vm/cvm.py b/vmsdk/python/cctrusted_vm/cvm.py index f9386a34..161c0b16 100644 --- a/vmsdk/python/cctrusted_vm/cvm.py +++ b/vmsdk/python/cctrusted_vm/cvm.py @@ -13,10 +13,11 @@ import struct import fcntl from abc import abstractmethod -from cctrusted_base.imr import TdxRTMR,TcgIMR +from cctrusted_base.imr import TcgIMR from cctrusted_base.quote import Quote from cctrusted_base.tcg import TcgAlgorithmRegistry from cctrusted_base.tdx.common import TDX_VERSION_1_0, TDX_VERSION_1_5 +from cctrusted_base.tdx.rtmr import TdxRTMR from cctrusted_base.tdx.quote import TdxQuoteReq10, TdxQuoteReq15 from cctrusted_base.tdx.report import TdxReportReq10, TdxReportReq15 diff --git a/vmsdk/python/cctrusted_vm/sdk.py b/vmsdk/python/cctrusted_vm/sdk.py index 2f08a427..115fbfdf 100644 --- a/vmsdk/python/cctrusted_vm/sdk.py +++ b/vmsdk/python/cctrusted_vm/sdk.py @@ -86,9 +86,14 @@ def get_measurement(self, imr_select:[int, int]) -> TcgIMR: if algo_id is None or algo_id is TcgAlgorithmRegistry.TPM_ALG_ERROR: algo_id = self._cvm.default_algo_id - return self._cvm.imrs[imr_index].digest(algo_id) - - def get_quote(self, nonce: bytearray, data: bytearray, extraArgs=None) -> Quote: + return self._cvm.imrs[imr_index] + + def get_quote( + self, + nonce: bytearray = None, + data: bytearray = None, + extraArgs = None + ) -> Quote: """Get the quote for given nonce and data. The quote is signing of attestation data (IMR values or hashes of IMR diff --git a/vmsdk/python/tests/conftest.py b/vmsdk/python/tests/conftest.py new file mode 100644 index 00000000..cf718b3d --- /dev/null +++ b/vmsdk/python/tests/conftest.py @@ -0,0 +1,66 @@ +"""Local conftest.py containing directory-specific hook implementations.""" + +import pytest +from cctrusted_base.tcg import TcgAlgorithmRegistry +from cctrusted_base.tdx.rtmr import TdxRTMR +from cctrusted_vm.cvm import ConfidentialVM +from cctrusted_vm.sdk import CCTrustedVmSdk +import tdx_check + +cnf_default_alg = { + ConfidentialVM.TYPE_CC_TDX: TcgAlgorithmRegistry.TPM_ALG_SHA384 +} +"""Configurations of default algorithm. +The configurations could be different for different confidential VMs. +e.g. TDX use sha384 as the default. +""" + +cnf_measurement_cnt = { + ConfidentialVM.TYPE_CC_TDX: TdxRTMR.RTMR_COUNT +} +"""Configurations of measurement count. +The configurations could be different for different confidential VMs. +""" + +cnf_measurement_check = { + ConfidentialVM.TYPE_CC_TDX: tdx_check.tdx_check_measurement_imrs +} +"""Configurations of measurement check functions. +The configurations could be different for different confidential VMs. +""" + +cnf_quote_check = { + ConfidentialVM.TYPE_CC_TDX: tdx_check.tdx_check_quote_rtmrs +} +"""Configurations of quote check functions. +The configurations could be different for different confidential VMs. +""" + +@pytest.fixture(scope="module") +def vm_sdk(): + """Get VMSDK instance.""" + return CCTrustedVmSdk.inst() + +@pytest.fixture(scope="module") +def default_alg_id(): + """Get default algorithm.""" + cc_type = ConfidentialVM.detect_cc_type() + return cnf_default_alg[cc_type] + +@pytest.fixture(scope="module") +def measurement_count(): + """Get measurement count.""" + cc_type = ConfidentialVM.detect_cc_type() + return cnf_measurement_cnt[cc_type] + +@pytest.fixture(scope="module") +def check_measurement(): + """Return checker for measurement.""" + cc_type = ConfidentialVM.detect_cc_type() + return cnf_measurement_check[cc_type] + +@pytest.fixture(scope="module") +def check_quote(): + """Return checker for quote.""" + cc_type = ConfidentialVM.detect_cc_type() + return cnf_quote_check[cc_type] diff --git a/vmsdk/python/tests/tdx_check.py b/vmsdk/python/tests/tdx_check.py new file mode 100644 index 00000000..bbd2cf7c --- /dev/null +++ b/vmsdk/python/tests/tdx_check.py @@ -0,0 +1,77 @@ +"""TDX specific test.""" + +from hashlib import sha384 +from cctrusted_base.tcg import TcgAlgorithmRegistry, TcgImrEvent +from cctrusted_base.tdx.quote import TdxQuote, TdxQuoteBody +from cctrusted_base.tdx.rtmr import TdxRTMR +from cctrusted_vm.sdk import CCTrustedVmSdk + +def _replay_eventlog(): + """Get RTMRs from event log by replay.""" + rtmr_len = TdxRTMR.RTMR_LENGTH_BY_BYTES + rtmr_cnt = TdxRTMR.RTMR_COUNT + rtmrs = [bytearray(rtmr_len)] * rtmr_cnt + event_logs = CCTrustedVmSdk.inst().get_eventlog().event_logs + assert event_logs is not None + for event in event_logs: + if isinstance(event, TcgImrEvent): + sha384_algo = sha384() + sha384_algo.update(rtmrs[event.imr_index] + event.digests[0].hash) + rtmrs[event.imr_index] = sha384_algo.digest() + return rtmrs + +def _check_imr(imr_index: int, alg_id: int, rtmr: bytes): + """Check individual IMR. + Compare the 4 IMR hash with the hash derived by replay event log. They are + expected to be same. + Args: + imr_index: an integer specified the IMR index. + alg_id: an integer specified the hash algorithm. + rtmr: bytes of RTMR data for comparison. + """ + assert 0 <= imr_index < TdxRTMR.RTMR_COUNT + assert rtmr is not None + assert alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384 + imr = CCTrustedVmSdk.inst().get_measurement([imr_index, alg_id]) + assert imr is not None + digest_obj = imr.digest(alg_id) + assert digest_obj is not None + digest_alg_id = digest_obj.alg.alg_id + assert digest_alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384 + digest_hash = digest_obj.hash + assert digest_hash is not None + assert digest_hash == rtmr, \ + f"rtmr {rtmr.hex()} doesn't equal digest {digest_hash.hex()}" + +def tdx_check_measurement_imrs(): + """Test measurement result. + The test is done by compare the measurement register against the value + derived by replay eventlog. + """ + alg = CCTrustedVmSdk.inst().get_default_algorithms() + rtmrs = _replay_eventlog() + _check_imr(0, alg.alg_id, rtmrs[0]) + _check_imr(1, alg.alg_id, rtmrs[1]) + _check_imr(2, alg.alg_id, rtmrs[2]) + _check_imr(3, alg.alg_id, rtmrs[3]) + +def tdx_check_quote_rtmrs(): + """Test quote result. + The test is done by compare the RTMRs in quote body against the value + derived by replay eventlog. + """ + quote = CCTrustedVmSdk.inst().get_quote() + assert quote is not None + assert isinstance(quote, TdxQuote) + body = quote.body + assert body is not None + assert isinstance(body, TdxQuoteBody) + rtmrs = _replay_eventlog() + assert body.rtmr0 == rtmrs[0], \ + "RTMR0 doesn't equal the replay from event log!" + assert body.rtmr1 == rtmrs[1], \ + "RTMR1 doesn't equal the replay from event log!" + assert body.rtmr2 == rtmrs[2], \ + "RTMR2 doesn't equal the replay from event log!" + assert body.rtmr3 == rtmrs[3], \ + "RTMR3 doesn't equal the replay from event log!" diff --git a/vmsdk/python/tests/test_sdk.py b/vmsdk/python/tests/test_sdk.py index 7645f104..a4dd512c 100644 --- a/vmsdk/python/tests/test_sdk.py +++ b/vmsdk/python/tests/test_sdk.py @@ -1,55 +1,55 @@ """Containing unit test cases for sdk class""" import pytest -from cctrusted_vm import CCTrustedVmSdk - -class TestCCTrustedVmSdk(): - """Unit tests for CCTrustedVmSdk class.""" - - def test_get_default_algorithms(self): - """Test get_default_algorithms() function.""" - algo = CCTrustedVmSdk.inst().get_default_algorithms() - assert algo is not None - - def test_get_measurement_count(self): - """Test get_measurement_count() function.""" - count = CCTrustedVmSdk.inst().get_measurement_count() - assert count is not None - - def test_get_measurement_with_invalid_input(self): - """Test get_measurement() function with invalid input.""" - # calling get_measurement() with invalid IMR index - measurement = CCTrustedVmSdk.inst().get_measurement([-1, 0xC]) - assert measurement is None - - # calling get_measurement() with invalid algorithm ID - measurement = CCTrustedVmSdk.inst().get_measurement([0, None]) - assert measurement is not None - - def test_get_measurement_with_valid_input(self): - """Test get_measurement() function with valid input.""" - count = CCTrustedVmSdk.inst().get_measurement_count() - for index in range(count): - alg = CCTrustedVmSdk.inst().get_default_algorithms() - digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id]) - assert digest_obj is not None - - def test_get_eventlog_with_invalid_input(self): - """Test get_eventlog() function with invalid input.""" - # calling get_eventlog with count < 0 - with pytest.raises(ValueError): - CCTrustedVmSdk.inst().get_eventlog(start=1, count=-1) - - # calling get_eventlog with start < 1 - with pytest.raises(ValueError): - CCTrustedVmSdk.inst().get_eventlog(start=0) - - def test_get_eventlog_with_valid_input(self): - """Test get_eventlog() funtion with valid input.""" - event_logs = CCTrustedVmSdk.inst().get_eventlog() - assert event_logs is not None - - def test_get_quote_with_valid_input(self): - """Test get_quote() function with valid input.""" - quote = CCTrustedVmSdk.inst().get_quote(None, None, None) - assert quote is not None + +def test_get_default_algorithms(vm_sdk, default_alg_id): + """Test get_default_algorithms() function.""" + algo = vm_sdk.get_default_algorithms() + assert algo is not None + assert algo.alg_id == default_alg_id + +def test_get_measurement_count(vm_sdk, measurement_count): + """Test get_measurement_count() function.""" + count = vm_sdk.get_measurement_count() + assert count is not None + assert count == measurement_count + +def test_get_measurement_with_invalid_input(vm_sdk): + """Test get_measurement() function with invalid input.""" + # calling get_measurement() with invalid IMR index + measurement = vm_sdk.get_measurement([-1, 0xC]) + assert measurement is None + + # calling get_measurement() with invalid algorithm ID + measurement = vm_sdk.get_measurement([0, None]) + assert measurement is not None + +def test_get_measurement_with_valid_input(vm_sdk, check_measurement): + """Test get_measurement() function with valid input.""" + count = vm_sdk.get_measurement_count() + for index in range(count): + alg = vm_sdk.get_default_algorithms() + digest_obj = vm_sdk.get_measurement([index, alg.alg_id]) + assert digest_obj is not None + check_measurement() + +def test_get_eventlog_with_invalid_input(vm_sdk): + """Test get_eventlog() function with invalid input.""" + # calling get_eventlog with count < 0 + with pytest.raises(ValueError): + vm_sdk.get_eventlog(start=1, count=-1) + + # calling get_eventlog with start < 1 + with pytest.raises(ValueError): + vm_sdk.get_eventlog(start=0) + +def test_get_eventlog_with_valid_input(vm_sdk): + """Test get_eventlog() funtion with valid input.""" + event_logs = vm_sdk.get_eventlog() + assert event_logs is not None + +def test_get_quote_with_valid_input(vm_sdk, check_quote): + """Test get_quote() function with valid input.""" + quote = vm_sdk.get_quote(None, None, None) + assert quote is not None + check_quote()