-
Notifications
You must be signed in to change notification settings - Fork 19
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
vmsdk/python/tests: add tests for TDX #57
Merged
Merged
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
""" | ||
RTMR (Runtime Measurement Register). | ||
""" | ||
|
||
from cctrusted_base.imr import TcgIMR | ||
from cctrusted_base.tcg import TcgAlgorithmRegistry | ||
|
||
class TdxRTMR(TcgIMR): | ||
"""RTMR class defined for Intel TDX.""" | ||
|
||
RTMR_COUNT = 4 | ||
"""Intel TDX TDREPORT provides the 4 measurement registers by default.""" | ||
|
||
RTMR_LENGTH_BY_BYTES = 48 | ||
"""RTMR length by bytes.""" | ||
|
||
@property | ||
def max_index(self): | ||
return 3 | ||
|
||
def __init__(self, index, digest_hash): | ||
super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384, | ||
digest_hash) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
""" | ||
PCR (Platform Configuration Register). | ||
""" | ||
|
||
from cctrusted_base.imr import TcgIMR | ||
|
||
class TpmPCR(TcgIMR): | ||
"""PCR class defined for TPM""" | ||
|
||
@property | ||
def max_index(self): | ||
return 23 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
"""Local conftest.py containing directory-specific hook implementations.""" | ||
|
||
import pytest | ||
from cctrusted_base.tcg import TcgAlgorithmRegistry | ||
from cctrusted_base.tdx.rtmr import TdxRTMR | ||
from cctrusted_vm.cvm import ConfidentialVM | ||
from cctrusted_vm.sdk import CCTrustedVmSdk | ||
import tdx_check | ||
|
||
cnf_default_alg = { | ||
ConfidentialVM.TYPE_CC_TDX: TcgAlgorithmRegistry.TPM_ALG_SHA384 | ||
} | ||
"""Configurations of default algorithm. | ||
The configurations could be different for different confidential VMs. | ||
e.g. TDX use sha384 as the default. | ||
""" | ||
|
||
cnf_measurement_cnt = { | ||
ConfidentialVM.TYPE_CC_TDX: TdxRTMR.RTMR_COUNT | ||
} | ||
"""Configurations of measurement count. | ||
The configurations could be different for different confidential VMs. | ||
""" | ||
|
||
cnf_measurement_check = { | ||
ConfidentialVM.TYPE_CC_TDX: tdx_check.tdx_check_measurement_imrs | ||
} | ||
"""Configurations of measurement check functions. | ||
The configurations could be different for different confidential VMs. | ||
""" | ||
|
||
cnf_quote_check = { | ||
ConfidentialVM.TYPE_CC_TDX: tdx_check.tdx_check_quote_rtmrs | ||
} | ||
"""Configurations of quote check functions. | ||
The configurations could be different for different confidential VMs. | ||
""" | ||
|
||
@pytest.fixture(scope="module") | ||
def vm_sdk(): | ||
"""Get VMSDK instance.""" | ||
return CCTrustedVmSdk.inst() | ||
|
||
@pytest.fixture(scope="module") | ||
def default_alg_id(): | ||
"""Get default algorithm.""" | ||
cc_type = ConfidentialVM.detect_cc_type() | ||
return cnf_default_alg[cc_type] | ||
|
||
@pytest.fixture(scope="module") | ||
def measurement_count(): | ||
"""Get measurement count.""" | ||
cc_type = ConfidentialVM.detect_cc_type() | ||
return cnf_measurement_cnt[cc_type] | ||
|
||
@pytest.fixture(scope="module") | ||
def check_measurement(): | ||
"""Return checker for measurement.""" | ||
cc_type = ConfidentialVM.detect_cc_type() | ||
return cnf_measurement_check[cc_type] | ||
|
||
@pytest.fixture(scope="module") | ||
def check_quote(): | ||
"""Return checker for quote.""" | ||
cc_type = ConfidentialVM.detect_cc_type() | ||
return cnf_quote_check[cc_type] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
"""TDX specific test.""" | ||
|
||
from hashlib import sha384 | ||
from cctrusted_base.tcg import TcgAlgorithmRegistry, TcgImrEvent | ||
from cctrusted_base.tdx.quote import TdxQuote, TdxQuoteBody | ||
from cctrusted_base.tdx.rtmr import TdxRTMR | ||
from cctrusted_vm.sdk import CCTrustedVmSdk | ||
|
||
def _replay_eventlog(): | ||
"""Get RTMRs from event log by replay.""" | ||
rtmr_len = TdxRTMR.RTMR_LENGTH_BY_BYTES | ||
rtmr_cnt = TdxRTMR.RTMR_COUNT | ||
rtmrs = [bytearray(rtmr_len)] * rtmr_cnt | ||
event_logs = CCTrustedVmSdk.inst().get_eventlog().event_logs | ||
assert event_logs is not None | ||
for event in event_logs: | ||
if isinstance(event, TcgImrEvent): | ||
sha384_algo = sha384() | ||
sha384_algo.update(rtmrs[event.imr_index] + event.digests[0].hash) | ||
rtmrs[event.imr_index] = sha384_algo.digest() | ||
return rtmrs | ||
|
||
def _check_imr(imr_index: int, alg_id: int, rtmr: bytes): | ||
"""Check individual IMR. | ||
Compare the 4 IMR hash with the hash derived by replay event log. They are | ||
expected to be same. | ||
Args: | ||
imr_index: an integer specified the IMR index. | ||
alg_id: an integer specified the hash algorithm. | ||
rtmr: bytes of RTMR data for comparison. | ||
""" | ||
assert 0 <= imr_index < TdxRTMR.RTMR_COUNT | ||
assert rtmr is not None | ||
assert alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384 | ||
imr = CCTrustedVmSdk.inst().get_measurement([imr_index, alg_id]) | ||
assert imr is not None | ||
digest_obj = imr.digest(alg_id) | ||
assert digest_obj is not None | ||
digest_alg_id = digest_obj.alg.alg_id | ||
assert digest_alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384 | ||
digest_hash = digest_obj.hash | ||
assert digest_hash is not None | ||
assert digest_hash == rtmr, \ | ||
f"rtmr {rtmr.hex()} doesn't equal digest {digest_hash.hex()}" | ||
|
||
def tdx_check_measurement_imrs(): | ||
"""Test measurement result. | ||
The test is done by compare the measurement register against the value | ||
derived by replay eventlog. | ||
""" | ||
alg = CCTrustedVmSdk.inst().get_default_algorithms() | ||
rtmrs = _replay_eventlog() | ||
_check_imr(0, alg.alg_id, rtmrs[0]) | ||
_check_imr(1, alg.alg_id, rtmrs[1]) | ||
_check_imr(2, alg.alg_id, rtmrs[2]) | ||
_check_imr(3, alg.alg_id, rtmrs[3]) | ||
|
||
def tdx_check_quote_rtmrs(): | ||
"""Test quote result. | ||
The test is done by compare the RTMRs in quote body against the value | ||
derived by replay eventlog. | ||
""" | ||
quote = CCTrustedVmSdk.inst().get_quote() | ||
assert quote is not None | ||
assert isinstance(quote, TdxQuote) | ||
body = quote.body | ||
assert body is not None | ||
assert isinstance(body, TdxQuoteBody) | ||
rtmrs = _replay_eventlog() | ||
assert body.rtmr0 == rtmrs[0], \ | ||
"RTMR0 doesn't equal the replay from event log!" | ||
assert body.rtmr1 == rtmrs[1], \ | ||
"RTMR1 doesn't equal the replay from event log!" | ||
assert body.rtmr2 == rtmrs[2], \ | ||
"RTMR2 doesn't equal the replay from event log!" | ||
assert body.rtmr3 == rtmrs[3], \ | ||
"RTMR3 doesn't equal the replay from event log!" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,55 +1,55 @@ | ||
"""Containing unit test cases for sdk class""" | ||
|
||
import pytest | ||
from cctrusted_vm import CCTrustedVmSdk | ||
|
||
class TestCCTrustedVmSdk(): | ||
"""Unit tests for CCTrustedVmSdk class.""" | ||
|
||
def test_get_default_algorithms(self): | ||
"""Test get_default_algorithms() function.""" | ||
algo = CCTrustedVmSdk.inst().get_default_algorithms() | ||
assert algo is not None | ||
|
||
def test_get_measurement_count(self): | ||
"""Test get_measurement_count() function.""" | ||
count = CCTrustedVmSdk.inst().get_measurement_count() | ||
assert count is not None | ||
|
||
def test_get_measurement_with_invalid_input(self): | ||
"""Test get_measurement() function with invalid input.""" | ||
# calling get_measurement() with invalid IMR index | ||
measurement = CCTrustedVmSdk.inst().get_measurement([-1, 0xC]) | ||
assert measurement is None | ||
|
||
# calling get_measurement() with invalid algorithm ID | ||
measurement = CCTrustedVmSdk.inst().get_measurement([0, None]) | ||
assert measurement is not None | ||
|
||
def test_get_measurement_with_valid_input(self): | ||
"""Test get_measurement() function with valid input.""" | ||
count = CCTrustedVmSdk.inst().get_measurement_count() | ||
for index in range(count): | ||
alg = CCTrustedVmSdk.inst().get_default_algorithms() | ||
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id]) | ||
assert digest_obj is not None | ||
|
||
def test_get_eventlog_with_invalid_input(self): | ||
"""Test get_eventlog() function with invalid input.""" | ||
# calling get_eventlog with count < 0 | ||
with pytest.raises(ValueError): | ||
CCTrustedVmSdk.inst().get_eventlog(start=1, count=-1) | ||
|
||
# calling get_eventlog with start < 1 | ||
with pytest.raises(ValueError): | ||
CCTrustedVmSdk.inst().get_eventlog(start=0) | ||
|
||
def test_get_eventlog_with_valid_input(self): | ||
"""Test get_eventlog() funtion with valid input.""" | ||
event_logs = CCTrustedVmSdk.inst().get_eventlog() | ||
assert event_logs is not None | ||
|
||
def test_get_quote_with_valid_input(self): | ||
"""Test get_quote() function with valid input.""" | ||
quote = CCTrustedVmSdk.inst().get_quote(None, None, None) | ||
assert quote is not None | ||
|
||
def test_get_default_algorithms(vm_sdk, default_alg_id): | ||
"""Test get_default_algorithms() function.""" | ||
algo = vm_sdk.get_default_algorithms() | ||
assert algo is not None | ||
assert algo.alg_id == default_alg_id | ||
|
||
def test_get_measurement_count(vm_sdk, measurement_count): | ||
"""Test get_measurement_count() function.""" | ||
count = vm_sdk.get_measurement_count() | ||
assert count is not None | ||
assert count == measurement_count | ||
|
||
def test_get_measurement_with_invalid_input(vm_sdk): | ||
"""Test get_measurement() function with invalid input.""" | ||
# calling get_measurement() with invalid IMR index | ||
measurement = vm_sdk.get_measurement([-1, 0xC]) | ||
assert measurement is None | ||
|
||
# calling get_measurement() with invalid algorithm ID | ||
measurement = vm_sdk.get_measurement([0, None]) | ||
assert measurement is not None | ||
|
||
def test_get_measurement_with_valid_input(vm_sdk, check_measurement): | ||
"""Test get_measurement() function with valid input.""" | ||
count = vm_sdk.get_measurement_count() | ||
for index in range(count): | ||
alg = vm_sdk.get_default_algorithms() | ||
digest_obj = vm_sdk.get_measurement([index, alg.alg_id]) | ||
assert digest_obj is not None | ||
check_measurement() | ||
|
||
def test_get_eventlog_with_invalid_input(vm_sdk): | ||
"""Test get_eventlog() function with invalid input.""" | ||
# calling get_eventlog with count < 0 | ||
with pytest.raises(ValueError): | ||
vm_sdk.get_eventlog(start=1, count=-1) | ||
|
||
# calling get_eventlog with start < 1 | ||
with pytest.raises(ValueError): | ||
vm_sdk.get_eventlog(start=0) | ||
|
||
def test_get_eventlog_with_valid_input(vm_sdk): | ||
"""Test get_eventlog() funtion with valid input.""" | ||
event_logs = vm_sdk.get_eventlog() | ||
assert event_logs is not None | ||
|
||
def test_get_quote_with_valid_input(vm_sdk, check_quote): | ||
"""Test get_quote() function with valid input.""" | ||
quote = vm_sdk.get_quote(None, None, None) | ||
assert quote is not None | ||
check_quote() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why change to this style, looks wierd.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After adding default values for parameters "nonce" and "data", the line became too long. So I followed the https://google.github.io/styleguide/pyguide.html#3192-line-breaking 3.19.2 Line Breaking for that style.