Skip to content

Commit

Permalink
vmsdk/python/tests: add tests for TDX
Browse files Browse the repository at this point in the history
This patch mainly adds some tests for TDX.
And it refactors some corresponding code accordingly.

Signed-off-by: zhongjie <zhongjie.shi@intel.com>
  • Loading branch information
intelzhongjie committed Jan 16, 2024
1 parent 904333a commit 4cc0cc4
Show file tree
Hide file tree
Showing 12 changed files with 259 additions and 79 deletions.
10 changes: 9 additions & 1 deletion .github/workflows/vmsdk-test-python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,12 @@ jobs:
- name: Run PyTest for VMSDK
run: |
set -ex
sudo su -c "source setupenv.sh && python3 -m pytest -v ./vmsdk/python/tests/test_sdk.py"
# Set PYTHONDONTWRITEBYTECODE and --no-cacheprovider to prevent
# generated some intermediate files by root. Othwerwise, these
# files will fail the action/checkout in the next round of running
# due to the permission issue.
sudo su -c "source setupenv.sh && \
pushd vmsdk/python/tests && \
export PYTHONDONTWRITEBYTECODE=1 && \
python3 -m pytest -p no:cacheprovider -v test_sdk.py && \
popd"
7 changes: 6 additions & 1 deletion common/python/cctrusted_base/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ def get_measurement(self, imr_select:[int, int]) -> TcgIMR:
raise NotImplementedError("Inherited SDK class should implement this.")

@abstractmethod
def get_quote(self, nonce: bytearray, data: bytearray, extraArgs=None) -> Quote:
def get_quote(
self,
nonce: bytearray = None,
data: bytearray = None,
extraArgs = None
) -> Quote:
"""Get the quote for given nonce and data.
The quote is signing of attestation data (IMR values or hashes of IMR
Expand Down
20 changes: 1 addition & 19 deletions common/python/cctrusted_base/imr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from abc import ABC, abstractmethod
from cctrusted_base.tcg import TcgDigest, TcgAlgorithmRegistry
from cctrusted_base.tcg import TcgDigest

class TcgIMR(ABC):
"""Common Integrated Measurement Register class."""
Expand Down Expand Up @@ -56,21 +56,3 @@ def is_valid(self):
"""
return self._index != TcgIMR._INVALID_IMR_INDEX and \
self._index <= self.max_index

class TdxRTMR(TcgIMR):
"""RTMR class defined for Intel TDX."""

@property
def max_index(self):
return 3

def __init__(self, index, digest_hash):
super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384,
digest_hash)

class TpmPCR(TcgIMR):
"""PCR class defined for TPM"""

@property
def max_index(self):
return 23
23 changes: 23 additions & 0 deletions common/python/cctrusted_base/tdx/rtmr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
RTMR (Runtime Measurement Register).
"""

from cctrusted_base.imr import TcgIMR
from cctrusted_base.tcg import TcgAlgorithmRegistry

class TdxRTMR(TcgIMR):
"""RTMR class defined for Intel TDX."""

RTMR_COUNT = 4
"""Intel TDX TDREPORT provides the 4 measurement registers by default."""

RTMR_LENGTH_BY_BYTES = 48
"""RTMR length by bytes."""

@property
def max_index(self):
return 3

def __init__(self, index, digest_hash):
super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384,
digest_hash)
12 changes: 12 additions & 0 deletions common/python/cctrusted_base/tpm/pcr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
PCR (Platform Configuration Register).
"""

from cctrusted_base.imr import TcgIMR

class TpmPCR(TcgIMR):
"""PCR class defined for TPM"""

@property
def max_index(self):
return 23
3 changes: 2 additions & 1 deletion vmsdk/python/cc_imr_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
count = CCTrustedVmSdk.inst().get_measurement_count()
for index in range(CCTrustedVmSdk.inst().get_measurement_count()):
alg = CCTrustedVmSdk.inst().get_default_algorithms()
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
imr = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
digest_obj = imr.digest(alg.alg_id)

hash_str = ""
for hash_item in digest_obj.hash:
Expand Down
2 changes: 1 addition & 1 deletion vmsdk/python/cc_quote_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def main():
level=logging.NOTSET,
format="%(name)s %(levelname)-8s %(message)s"
)
quote = CCTrustedVmSdk.inst().get_quote(None, None, None)
quote = CCTrustedVmSdk.inst().get_quote()
if quote is not None:
quote.dump(args.out_format == OUT_FORMAT_RAW)
else:
Expand Down
3 changes: 2 additions & 1 deletion vmsdk/python/cctrusted_vm/cvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
import struct
import fcntl
from abc import abstractmethod
from cctrusted_base.imr import TdxRTMR,TcgIMR
from cctrusted_base.imr import TcgIMR
from cctrusted_base.quote import Quote
from cctrusted_base.tcg import TcgAlgorithmRegistry
from cctrusted_base.tdx.common import TDX_VERSION_1_0, TDX_VERSION_1_5
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_base.tdx.quote import TdxQuoteReq10, TdxQuoteReq15
from cctrusted_base.tdx.report import TdxReportReq10, TdxReportReq15

Expand Down
11 changes: 8 additions & 3 deletions vmsdk/python/cctrusted_vm/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,14 @@ def get_measurement(self, imr_select:[int, int]) -> TcgIMR:
if algo_id is None or algo_id is TcgAlgorithmRegistry.TPM_ALG_ERROR:
algo_id = self._cvm.default_algo_id

return self._cvm.imrs[imr_index].digest(algo_id)

def get_quote(self, nonce: bytearray, data: bytearray, extraArgs=None) -> Quote:
return self._cvm.imrs[imr_index]

def get_quote(
self,
nonce: bytearray = None,
data: bytearray = None,
extraArgs = None
) -> Quote:
"""Get the quote for given nonce and data.
The quote is signing of attestation data (IMR values or hashes of IMR
Expand Down
66 changes: 66 additions & 0 deletions vmsdk/python/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Local conftest.py containing directory-specific hook implementations."""

import pytest
from cctrusted_base.tcg import TcgAlgorithmRegistry
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_vm.cvm import ConfidentialVM
from cctrusted_vm.sdk import CCTrustedVmSdk
import tdx_check

cnf_default_alg = {
ConfidentialVM.TYPE_CC_TDX: TcgAlgorithmRegistry.TPM_ALG_SHA384
}
"""Configurations of default algorithm.
The configurations could be different for different confidential VMs.
e.g. TDX use sha384 as the default.
"""

cnf_measurement_cnt = {
ConfidentialVM.TYPE_CC_TDX: TdxRTMR.RTMR_COUNT
}
"""Configurations of measurement count.
The configurations could be different for different confidential VMs.
"""

cnf_measurement_check = {
ConfidentialVM.TYPE_CC_TDX: tdx_check.tdx_check_measurement_imrs
}
"""Configurations of measurement check functions.
The configurations could be different for different confidential VMs.
"""

cnf_quote_check = {
ConfidentialVM.TYPE_CC_TDX: tdx_check.tdx_check_quote_rtmrs
}
"""Configurations of quote check functions.
The configurations could be different for different confidential VMs.
"""

@pytest.fixture(scope="module")
def vm_sdk():
"""Get VMSDK instance."""
return CCTrustedVmSdk.inst()

@pytest.fixture(scope="module")
def default_alg_id():
"""Get default algorithm."""
cc_type = ConfidentialVM.detect_cc_type()
return cnf_default_alg[cc_type]

@pytest.fixture(scope="module")
def measurement_count():
"""Get measurement count."""
cc_type = ConfidentialVM.detect_cc_type()
return cnf_measurement_cnt[cc_type]

@pytest.fixture(scope="module")
def check_measurement():
"""Return checker for measurement."""
cc_type = ConfidentialVM.detect_cc_type()
return cnf_measurement_check[cc_type]

@pytest.fixture(scope="module")
def check_quote():
"""Return checker for quote."""
cc_type = ConfidentialVM.detect_cc_type()
return cnf_quote_check[cc_type]
77 changes: 77 additions & 0 deletions vmsdk/python/tests/tdx_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""TDX specific test."""

from hashlib import sha384
from cctrusted_base.tcg import TcgAlgorithmRegistry, TcgImrEvent
from cctrusted_base.tdx.quote import TdxQuote, TdxQuoteBody
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_vm.sdk import CCTrustedVmSdk

def _replay_eventlog():
"""Get RTMRs from event log by replay."""
rtmr_len = TdxRTMR.RTMR_LENGTH_BY_BYTES
rtmr_cnt = TdxRTMR.RTMR_COUNT
rtmrs = [bytearray(rtmr_len)] * rtmr_cnt
event_logs = CCTrustedVmSdk.inst().get_eventlog().event_logs
assert event_logs is not None
for event in event_logs:
if isinstance(event, TcgImrEvent):
sha384_algo = sha384()
sha384_algo.update(rtmrs[event.imr_index] + event.digests[0].hash)
rtmrs[event.imr_index] = sha384_algo.digest()
return rtmrs

def _check_imr(imr_index: int, alg_id: int, rtmr: bytes):
"""Check individual IMR.
Compare the 4 IMR hash with the hash derived by replay event log. They are
expected to be same.
Args:
imr_index: an integer specified the IMR index.
alg_id: an integer specified the hash algorithm.
rtmr: bytes of RTMR data for comparison.
"""
assert 0 <= imr_index < TdxRTMR.RTMR_COUNT
assert rtmr is not None
assert alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384
imr = CCTrustedVmSdk.inst().get_measurement([imr_index, alg_id])
assert imr is not None
digest_obj = imr.digest(alg_id)
assert digest_obj is not None
digest_alg_id = digest_obj.alg.alg_id
assert digest_alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384
digest_hash = digest_obj.hash
assert digest_hash is not None
assert digest_hash == rtmr, \
f"rtmr {rtmr.hex()} doesn't equal digest {digest_hash.hex()}"

def tdx_check_measurement_imrs():
"""Test measurement result.
The test is done by compare the measurement register against the value
derived by replay eventlog.
"""
alg = CCTrustedVmSdk.inst().get_default_algorithms()
rtmrs = _replay_eventlog()
_check_imr(0, alg.alg_id, rtmrs[0])
_check_imr(1, alg.alg_id, rtmrs[1])
_check_imr(2, alg.alg_id, rtmrs[2])
_check_imr(3, alg.alg_id, rtmrs[3])

def tdx_check_quote_rtmrs():
"""Test quote result.
The test is done by compare the RTMRs in quote body against the value
derived by replay eventlog.
"""
quote = CCTrustedVmSdk.inst().get_quote()
assert quote is not None
assert isinstance(quote, TdxQuote)
body = quote.body
assert body is not None
assert isinstance(body, TdxQuoteBody)
rtmrs = _replay_eventlog()
assert body.rtmr0 == rtmrs[0], \
"RTMR0 doesn't equal the replay from event log!"
assert body.rtmr1 == rtmrs[1], \
"RTMR1 doesn't equal the replay from event log!"
assert body.rtmr2 == rtmrs[2], \
"RTMR2 doesn't equal the replay from event log!"
assert body.rtmr3 == rtmrs[3], \
"RTMR3 doesn't equal the replay from event log!"
104 changes: 52 additions & 52 deletions vmsdk/python/tests/test_sdk.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,55 @@
"""Containing unit test cases for sdk class"""

import pytest
from cctrusted_vm import CCTrustedVmSdk

class TestCCTrustedVmSdk():
"""Unit tests for CCTrustedVmSdk class."""

def test_get_default_algorithms(self):
"""Test get_default_algorithms() function."""
algo = CCTrustedVmSdk.inst().get_default_algorithms()
assert algo is not None

def test_get_measurement_count(self):
"""Test get_measurement_count() function."""
count = CCTrustedVmSdk.inst().get_measurement_count()
assert count is not None

def test_get_measurement_with_invalid_input(self):
"""Test get_measurement() function with invalid input."""
# calling get_measurement() with invalid IMR index
measurement = CCTrustedVmSdk.inst().get_measurement([-1, 0xC])
assert measurement is None

# calling get_measurement() with invalid algorithm ID
measurement = CCTrustedVmSdk.inst().get_measurement([0, None])
assert measurement is not None

def test_get_measurement_with_valid_input(self):
"""Test get_measurement() function with valid input."""
count = CCTrustedVmSdk.inst().get_measurement_count()
for index in range(count):
alg = CCTrustedVmSdk.inst().get_default_algorithms()
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
assert digest_obj is not None

def test_get_eventlog_with_invalid_input(self):
"""Test get_eventlog() function with invalid input."""
# calling get_eventlog with count < 0
with pytest.raises(ValueError):
CCTrustedVmSdk.inst().get_eventlog(start=1, count=-1)

# calling get_eventlog with start < 1
with pytest.raises(ValueError):
CCTrustedVmSdk.inst().get_eventlog(start=0)

def test_get_eventlog_with_valid_input(self):
"""Test get_eventlog() funtion with valid input."""
event_logs = CCTrustedVmSdk.inst().get_eventlog()
assert event_logs is not None

def test_get_quote_with_valid_input(self):
"""Test get_quote() function with valid input."""
quote = CCTrustedVmSdk.inst().get_quote(None, None, None)
assert quote is not None

def test_get_default_algorithms(vm_sdk, default_alg_id):
"""Test get_default_algorithms() function."""
algo = vm_sdk.get_default_algorithms()
assert algo is not None
assert algo.alg_id == default_alg_id

def test_get_measurement_count(vm_sdk, measurement_count):
"""Test get_measurement_count() function."""
count = vm_sdk.get_measurement_count()
assert count is not None
assert count == measurement_count

def test_get_measurement_with_invalid_input(vm_sdk):
"""Test get_measurement() function with invalid input."""
# calling get_measurement() with invalid IMR index
measurement = vm_sdk.get_measurement([-1, 0xC])
assert measurement is None

# calling get_measurement() with invalid algorithm ID
measurement = vm_sdk.get_measurement([0, None])
assert measurement is not None

def test_get_measurement_with_valid_input(vm_sdk, check_measurement):
"""Test get_measurement() function with valid input."""
count = vm_sdk.get_measurement_count()
for index in range(count):
alg = vm_sdk.get_default_algorithms()
digest_obj = vm_sdk.get_measurement([index, alg.alg_id])
assert digest_obj is not None
check_measurement()

def test_get_eventlog_with_invalid_input(vm_sdk):
"""Test get_eventlog() function with invalid input."""
# calling get_eventlog with count < 0
with pytest.raises(ValueError):
vm_sdk.get_eventlog(start=1, count=-1)

# calling get_eventlog with start < 1
with pytest.raises(ValueError):
vm_sdk.get_eventlog(start=0)

def test_get_eventlog_with_valid_input(vm_sdk):
"""Test get_eventlog() funtion with valid input."""
event_logs = vm_sdk.get_eventlog()
assert event_logs is not None

def test_get_quote_with_valid_input(vm_sdk, check_quote):
"""Test get_quote() function with valid input."""
quote = vm_sdk.get_quote(None, None, None)
assert quote is not None
check_quote()

0 comments on commit 4cc0cc4

Please sign in to comment.