-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
vmsdk/python/tests: add tests for TDX
This patch mainly adds some tests for TDX. And it refactors some corresponding code accordingly. Signed-off-by: zhongjie <zhongjie.shi@intel.com>
- Loading branch information
1 parent
904333a
commit 4cc0cc4
Showing
12 changed files
with
259 additions
and
79 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
""" | ||
RTMR (Runtime Measurement Register). | ||
""" | ||
|
||
from cctrusted_base.imr import TcgIMR | ||
from cctrusted_base.tcg import TcgAlgorithmRegistry | ||
|
||
class TdxRTMR(TcgIMR): | ||
"""RTMR class defined for Intel TDX.""" | ||
|
||
RTMR_COUNT = 4 | ||
"""Intel TDX TDREPORT provides the 4 measurement registers by default.""" | ||
|
||
RTMR_LENGTH_BY_BYTES = 48 | ||
"""RTMR length by bytes.""" | ||
|
||
@property | ||
def max_index(self): | ||
return 3 | ||
|
||
def __init__(self, index, digest_hash): | ||
super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384, | ||
digest_hash) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
""" | ||
PCR (Platform Configuration Register). | ||
""" | ||
|
||
from cctrusted_base.imr import TcgIMR | ||
|
||
class TpmPCR(TcgIMR): | ||
"""PCR class defined for TPM""" | ||
|
||
@property | ||
def max_index(self): | ||
return 23 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
"""Local conftest.py containing directory-specific hook implementations.""" | ||
|
||
import pytest | ||
from cctrusted_base.tcg import TcgAlgorithmRegistry | ||
from cctrusted_base.tdx.rtmr import TdxRTMR | ||
from cctrusted_vm.cvm import ConfidentialVM | ||
from cctrusted_vm.sdk import CCTrustedVmSdk | ||
import tdx_check | ||
|
||
cnf_default_alg = { | ||
ConfidentialVM.TYPE_CC_TDX: TcgAlgorithmRegistry.TPM_ALG_SHA384 | ||
} | ||
"""Configurations of default algorithm. | ||
The configurations could be different for different confidential VMs. | ||
e.g. TDX use sha384 as the default. | ||
""" | ||
|
||
cnf_measurement_cnt = { | ||
ConfidentialVM.TYPE_CC_TDX: TdxRTMR.RTMR_COUNT | ||
} | ||
"""Configurations of measurement count. | ||
The configurations could be different for different confidential VMs. | ||
""" | ||
|
||
cnf_measurement_check = { | ||
ConfidentialVM.TYPE_CC_TDX: tdx_check.tdx_check_measurement_imrs | ||
} | ||
"""Configurations of measurement check functions. | ||
The configurations could be different for different confidential VMs. | ||
""" | ||
|
||
cnf_quote_check = { | ||
ConfidentialVM.TYPE_CC_TDX: tdx_check.tdx_check_quote_rtmrs | ||
} | ||
"""Configurations of quote check functions. | ||
The configurations could be different for different confidential VMs. | ||
""" | ||
|
||
@pytest.fixture(scope="module") | ||
def vm_sdk(): | ||
"""Get VMSDK instance.""" | ||
return CCTrustedVmSdk.inst() | ||
|
||
@pytest.fixture(scope="module") | ||
def default_alg_id(): | ||
"""Get default algorithm.""" | ||
cc_type = ConfidentialVM.detect_cc_type() | ||
return cnf_default_alg[cc_type] | ||
|
||
@pytest.fixture(scope="module") | ||
def measurement_count(): | ||
"""Get measurement count.""" | ||
cc_type = ConfidentialVM.detect_cc_type() | ||
return cnf_measurement_cnt[cc_type] | ||
|
||
@pytest.fixture(scope="module") | ||
def check_measurement(): | ||
"""Return checker for measurement.""" | ||
cc_type = ConfidentialVM.detect_cc_type() | ||
return cnf_measurement_check[cc_type] | ||
|
||
@pytest.fixture(scope="module") | ||
def check_quote(): | ||
"""Return checker for quote.""" | ||
cc_type = ConfidentialVM.detect_cc_type() | ||
return cnf_quote_check[cc_type] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
"""TDX specific test.""" | ||
|
||
from hashlib import sha384 | ||
from cctrusted_base.tcg import TcgAlgorithmRegistry, TcgImrEvent | ||
from cctrusted_base.tdx.quote import TdxQuote, TdxQuoteBody | ||
from cctrusted_base.tdx.rtmr import TdxRTMR | ||
from cctrusted_vm.sdk import CCTrustedVmSdk | ||
|
||
def _replay_eventlog(): | ||
"""Get RTMRs from event log by replay.""" | ||
rtmr_len = TdxRTMR.RTMR_LENGTH_BY_BYTES | ||
rtmr_cnt = TdxRTMR.RTMR_COUNT | ||
rtmrs = [bytearray(rtmr_len)] * rtmr_cnt | ||
event_logs = CCTrustedVmSdk.inst().get_eventlog().event_logs | ||
assert event_logs is not None | ||
for event in event_logs: | ||
if isinstance(event, TcgImrEvent): | ||
sha384_algo = sha384() | ||
sha384_algo.update(rtmrs[event.imr_index] + event.digests[0].hash) | ||
rtmrs[event.imr_index] = sha384_algo.digest() | ||
return rtmrs | ||
|
||
def _check_imr(imr_index: int, alg_id: int, rtmr: bytes): | ||
"""Check individual IMR. | ||
Compare the 4 IMR hash with the hash derived by replay event log. They are | ||
expected to be same. | ||
Args: | ||
imr_index: an integer specified the IMR index. | ||
alg_id: an integer specified the hash algorithm. | ||
rtmr: bytes of RTMR data for comparison. | ||
""" | ||
assert 0 <= imr_index < TdxRTMR.RTMR_COUNT | ||
assert rtmr is not None | ||
assert alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384 | ||
imr = CCTrustedVmSdk.inst().get_measurement([imr_index, alg_id]) | ||
assert imr is not None | ||
digest_obj = imr.digest(alg_id) | ||
assert digest_obj is not None | ||
digest_alg_id = digest_obj.alg.alg_id | ||
assert digest_alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384 | ||
digest_hash = digest_obj.hash | ||
assert digest_hash is not None | ||
assert digest_hash == rtmr, \ | ||
f"rtmr {rtmr.hex()} doesn't equal digest {digest_hash.hex()}" | ||
|
||
def tdx_check_measurement_imrs(): | ||
"""Test measurement result. | ||
The test is done by compare the measurement register against the value | ||
derived by replay eventlog. | ||
""" | ||
alg = CCTrustedVmSdk.inst().get_default_algorithms() | ||
rtmrs = _replay_eventlog() | ||
_check_imr(0, alg.alg_id, rtmrs[0]) | ||
_check_imr(1, alg.alg_id, rtmrs[1]) | ||
_check_imr(2, alg.alg_id, rtmrs[2]) | ||
_check_imr(3, alg.alg_id, rtmrs[3]) | ||
|
||
def tdx_check_quote_rtmrs(): | ||
"""Test quote result. | ||
The test is done by compare the RTMRs in quote body against the value | ||
derived by replay eventlog. | ||
""" | ||
quote = CCTrustedVmSdk.inst().get_quote() | ||
assert quote is not None | ||
assert isinstance(quote, TdxQuote) | ||
body = quote.body | ||
assert body is not None | ||
assert isinstance(body, TdxQuoteBody) | ||
rtmrs = _replay_eventlog() | ||
assert body.rtmr0 == rtmrs[0], \ | ||
"RTMR0 doesn't equal the replay from event log!" | ||
assert body.rtmr1 == rtmrs[1], \ | ||
"RTMR1 doesn't equal the replay from event log!" | ||
assert body.rtmr2 == rtmrs[2], \ | ||
"RTMR2 doesn't equal the replay from event log!" | ||
assert body.rtmr3 == rtmrs[3], \ | ||
"RTMR3 doesn't equal the replay from event log!" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,55 @@ | ||
"""Containing unit test cases for sdk class""" | ||
|
||
import pytest | ||
from cctrusted_vm import CCTrustedVmSdk | ||
|
||
class TestCCTrustedVmSdk(): | ||
"""Unit tests for CCTrustedVmSdk class.""" | ||
|
||
def test_get_default_algorithms(self): | ||
"""Test get_default_algorithms() function.""" | ||
algo = CCTrustedVmSdk.inst().get_default_algorithms() | ||
assert algo is not None | ||
|
||
def test_get_measurement_count(self): | ||
"""Test get_measurement_count() function.""" | ||
count = CCTrustedVmSdk.inst().get_measurement_count() | ||
assert count is not None | ||
|
||
def test_get_measurement_with_invalid_input(self): | ||
"""Test get_measurement() function with invalid input.""" | ||
# calling get_measurement() with invalid IMR index | ||
measurement = CCTrustedVmSdk.inst().get_measurement([-1, 0xC]) | ||
assert measurement is None | ||
|
||
# calling get_measurement() with invalid algorithm ID | ||
measurement = CCTrustedVmSdk.inst().get_measurement([0, None]) | ||
assert measurement is not None | ||
|
||
def test_get_measurement_with_valid_input(self): | ||
"""Test get_measurement() function with valid input.""" | ||
count = CCTrustedVmSdk.inst().get_measurement_count() | ||
for index in range(count): | ||
alg = CCTrustedVmSdk.inst().get_default_algorithms() | ||
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id]) | ||
assert digest_obj is not None | ||
|
||
def test_get_eventlog_with_invalid_input(self): | ||
"""Test get_eventlog() function with invalid input.""" | ||
# calling get_eventlog with count < 0 | ||
with pytest.raises(ValueError): | ||
CCTrustedVmSdk.inst().get_eventlog(start=1, count=-1) | ||
|
||
# calling get_eventlog with start < 1 | ||
with pytest.raises(ValueError): | ||
CCTrustedVmSdk.inst().get_eventlog(start=0) | ||
|
||
def test_get_eventlog_with_valid_input(self): | ||
"""Test get_eventlog() funtion with valid input.""" | ||
event_logs = CCTrustedVmSdk.inst().get_eventlog() | ||
assert event_logs is not None | ||
|
||
def test_get_quote_with_valid_input(self): | ||
"""Test get_quote() function with valid input.""" | ||
quote = CCTrustedVmSdk.inst().get_quote(None, None, None) | ||
assert quote is not None | ||
|
||
def test_get_default_algorithms(vm_sdk, default_alg_id): | ||
"""Test get_default_algorithms() function.""" | ||
algo = vm_sdk.get_default_algorithms() | ||
assert algo is not None | ||
assert algo.alg_id == default_alg_id | ||
|
||
def test_get_measurement_count(vm_sdk, measurement_count): | ||
"""Test get_measurement_count() function.""" | ||
count = vm_sdk.get_measurement_count() | ||
assert count is not None | ||
assert count == measurement_count | ||
|
||
def test_get_measurement_with_invalid_input(vm_sdk): | ||
"""Test get_measurement() function with invalid input.""" | ||
# calling get_measurement() with invalid IMR index | ||
measurement = vm_sdk.get_measurement([-1, 0xC]) | ||
assert measurement is None | ||
|
||
# calling get_measurement() with invalid algorithm ID | ||
measurement = vm_sdk.get_measurement([0, None]) | ||
assert measurement is not None | ||
|
||
def test_get_measurement_with_valid_input(vm_sdk, check_measurement): | ||
"""Test get_measurement() function with valid input.""" | ||
count = vm_sdk.get_measurement_count() | ||
for index in range(count): | ||
alg = vm_sdk.get_default_algorithms() | ||
digest_obj = vm_sdk.get_measurement([index, alg.alg_id]) | ||
assert digest_obj is not None | ||
check_measurement() | ||
|
||
def test_get_eventlog_with_invalid_input(vm_sdk): | ||
"""Test get_eventlog() function with invalid input.""" | ||
# calling get_eventlog with count < 0 | ||
with pytest.raises(ValueError): | ||
vm_sdk.get_eventlog(start=1, count=-1) | ||
|
||
# calling get_eventlog with start < 1 | ||
with pytest.raises(ValueError): | ||
vm_sdk.get_eventlog(start=0) | ||
|
||
def test_get_eventlog_with_valid_input(vm_sdk): | ||
"""Test get_eventlog() funtion with valid input.""" | ||
event_logs = vm_sdk.get_eventlog() | ||
assert event_logs is not None | ||
|
||
def test_get_quote_with_valid_input(vm_sdk, check_quote): | ||
"""Test get_quote() function with valid input.""" | ||
quote = vm_sdk.get_quote(None, None, None) | ||
assert quote is not None | ||
check_quote() |