Skip to content

Commit

Permalink
vmsdk/python/tests: add tests for TDX
Browse files Browse the repository at this point in the history
Signed-off-by: zhongjie <zhongjie.shi@intel.com>
  • Loading branch information
intelzhongjie committed Jan 12, 2024
1 parent 904333a commit 6913295
Show file tree
Hide file tree
Showing 13 changed files with 223 additions and 77 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/vmsdk-test-python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ jobs:
- name: Run PyTest for VMSDK
run: |
set -ex
sudo su -c "source setupenv.sh && python3 -m pytest -v ./vmsdk/python/tests/test_sdk.py"
sudo su -c "source setupenv.sh && pushd vmsdk/python/tests && ./run.sh && popd"
7 changes: 6 additions & 1 deletion common/python/cctrusted_base/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ def get_measurement(self, imr_select:[int, int]) -> TcgIMR:
raise NotImplementedError("Inherited SDK class should implement this.")

@abstractmethod
def get_quote(self, nonce: bytearray, data: bytearray, extraArgs=None) -> Quote:
def get_quote(
self,
nonce: bytearray = None,
data: bytearray = None,
extraArgs = None
) -> Quote:
"""Get the quote for given nonce and data.
The quote is signing of attestation data (IMR values or hashes of IMR
Expand Down
20 changes: 1 addition & 19 deletions common/python/cctrusted_base/imr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from abc import ABC, abstractmethod
from cctrusted_base.tcg import TcgDigest, TcgAlgorithmRegistry
from cctrusted_base.tcg import TcgDigest

class TcgIMR(ABC):
"""Common Integrated Measurement Register class."""
Expand Down Expand Up @@ -56,21 +56,3 @@ def is_valid(self):
"""
return self._index != TcgIMR._INVALID_IMR_INDEX and \
self._index <= self.max_index

class TdxRTMR(TcgIMR):
"""RTMR class defined for Intel TDX."""

@property
def max_index(self):
return 3

def __init__(self, index, digest_hash):
super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384,
digest_hash)

class TpmPCR(TcgIMR):
"""PCR class defined for TPM"""

@property
def max_index(self):
return 23
23 changes: 23 additions & 0 deletions common/python/cctrusted_base/tdx/rtmr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
RTMR (Runtime Measurement Register).
"""

from cctrusted_base.imr import TcgIMR
from cctrusted_base.tcg import TcgAlgorithmRegistry

class TdxRTMR(TcgIMR):
"""RTMR class defined for Intel TDX."""

RTMR_COUNT = 4
"""Intel TDX TDREPORT provides the 4 measurement registers by default."""

RTMR_LENGTH_BY_BYTES = 48
"""RTMR length by bytes."""

@property
def max_index(self):
return 3

def __init__(self, index, digest_hash):
super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384,
digest_hash)
12 changes: 12 additions & 0 deletions common/python/cctrusted_base/tpm/pcr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
PCR (Platform Configuration Register).
"""

from cctrusted_base.imr import TcgIMR

class TpmPCR(TcgIMR):
"""PCR class defined for TPM"""

@property
def max_index(self):
return 23
3 changes: 2 additions & 1 deletion vmsdk/python/cc_imr_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
count = CCTrustedVmSdk.inst().get_measurement_count()
for index in range(CCTrustedVmSdk.inst().get_measurement_count()):
alg = CCTrustedVmSdk.inst().get_default_algorithms()
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
imr = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
digest_obj = imr.digest(alg.alg_id)

hash_str = ""
for hash_item in digest_obj.hash:
Expand Down
2 changes: 1 addition & 1 deletion vmsdk/python/cc_quote_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def main():
level=logging.NOTSET,
format="%(name)s %(levelname)-8s %(message)s"
)
quote = CCTrustedVmSdk.inst().get_quote(None, None, None)
quote = CCTrustedVmSdk.inst().get_quote()
if quote is not None:
quote.dump(args.out_format == OUT_FORMAT_RAW)
else:
Expand Down
3 changes: 2 additions & 1 deletion vmsdk/python/cctrusted_vm/cvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
import struct
import fcntl
from abc import abstractmethod
from cctrusted_base.imr import TdxRTMR,TcgIMR
from cctrusted_base.imr import TcgIMR
from cctrusted_base.quote import Quote
from cctrusted_base.tcg import TcgAlgorithmRegistry
from cctrusted_base.tdx.common import TDX_VERSION_1_0, TDX_VERSION_1_5
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_base.tdx.quote import TdxQuoteReq10, TdxQuoteReq15
from cctrusted_base.tdx.report import TdxReportReq10, TdxReportReq15

Expand Down
11 changes: 8 additions & 3 deletions vmsdk/python/cctrusted_vm/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,14 @@ def get_measurement(self, imr_select:[int, int]) -> TcgIMR:
if algo_id is None or algo_id is TcgAlgorithmRegistry.TPM_ALG_ERROR:
algo_id = self._cvm.default_algo_id

return self._cvm.imrs[imr_index].digest(algo_id)

def get_quote(self, nonce: bytearray, data: bytearray, extraArgs=None) -> Quote:
return self._cvm.imrs[imr_index]

def get_quote(
self,
nonce: bytearray = None,
data: bytearray = None,
extraArgs = None
) -> Quote:
"""Get the quote for given nonce and data.
The quote is signing of attestation data (IMR values or hashes of IMR
Expand Down
4 changes: 4 additions & 0 deletions vmsdk/python/tests/pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[pytest]
markers =
basic: Select the test functions for basic testing
tdx: Select the test functions for TDX testing
16 changes: 16 additions & 0 deletions vmsdk/python/tests/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/bash

echo Run basic tests ...
python3 -m pytest -v . -m basic

cc_type=$(python3 -c '
from cctrusted_vm.cvm import ConfidentialVM
cc_type = ConfidentialVM.detect_cc_type()
print(ConfidentialVM.TYPE_CC_STRING[cc_type])
')
echo CC type is ${cc_type}

if [[ ${cc_type} == "TDX" ]]; then
echo Run TDX specific tests ...
python3 -m pytest -v . -m tdx
fi
104 changes: 54 additions & 50 deletions vmsdk/python/tests/test_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,53 +3,57 @@
import pytest
from cctrusted_vm import CCTrustedVmSdk

class TestCCTrustedVmSdk():
"""Unit tests for CCTrustedVmSdk class."""

def test_get_default_algorithms(self):
"""Test get_default_algorithms() function."""
algo = CCTrustedVmSdk.inst().get_default_algorithms()
assert algo is not None

def test_get_measurement_count(self):
"""Test get_measurement_count() function."""
count = CCTrustedVmSdk.inst().get_measurement_count()
assert count is not None

def test_get_measurement_with_invalid_input(self):
"""Test get_measurement() function with invalid input."""
# calling get_measurement() with invalid IMR index
measurement = CCTrustedVmSdk.inst().get_measurement([-1, 0xC])
assert measurement is None

# calling get_measurement() with invalid algorithm ID
measurement = CCTrustedVmSdk.inst().get_measurement([0, None])
assert measurement is not None

def test_get_measurement_with_valid_input(self):
"""Test get_measurement() function with valid input."""
count = CCTrustedVmSdk.inst().get_measurement_count()
for index in range(count):
alg = CCTrustedVmSdk.inst().get_default_algorithms()
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
assert digest_obj is not None

def test_get_eventlog_with_invalid_input(self):
"""Test get_eventlog() function with invalid input."""
# calling get_eventlog with count < 0
with pytest.raises(ValueError):
CCTrustedVmSdk.inst().get_eventlog(start=1, count=-1)

# calling get_eventlog with start < 1
with pytest.raises(ValueError):
CCTrustedVmSdk.inst().get_eventlog(start=0)

def test_get_eventlog_with_valid_input(self):
"""Test get_eventlog() funtion with valid input."""
event_logs = CCTrustedVmSdk.inst().get_eventlog()
assert event_logs is not None

def test_get_quote_with_valid_input(self):
"""Test get_quote() function with valid input."""
quote = CCTrustedVmSdk.inst().get_quote(None, None, None)
assert quote is not None
@pytest.mark.basic
def test_get_default_algorithms():
"""Test get_default_algorithms() function."""
algo = CCTrustedVmSdk.inst().get_default_algorithms()
assert algo is not None

@pytest.mark.basic
def test_get_measurement_count():
"""Test get_measurement_count() function."""
count = CCTrustedVmSdk.inst().get_measurement_count()
assert count is not None

@pytest.mark.basic
def test_get_measurement_with_invalid_input():
"""Test get_measurement() function with invalid input."""
# calling get_measurement() with invalid IMR index
measurement = CCTrustedVmSdk.inst().get_measurement([-1, 0xC])
assert measurement is None

# calling get_measurement() with invalid algorithm ID
measurement = CCTrustedVmSdk.inst().get_measurement([0, None])
assert measurement is not None

@pytest.mark.basic
def test_get_measurement_with_valid_input():
"""Test get_measurement() function with valid input."""
count = CCTrustedVmSdk.inst().get_measurement_count()
for index in range(count):
alg = CCTrustedVmSdk.inst().get_default_algorithms()
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
assert digest_obj is not None

@pytest.mark.basic
def test_get_eventlog_with_invalid_input():
"""Test get_eventlog() function with invalid input."""
# calling get_eventlog with count < 0
with pytest.raises(ValueError):
CCTrustedVmSdk.inst().get_eventlog(start=1, count=-1)

# calling get_eventlog with start < 1
with pytest.raises(ValueError):
CCTrustedVmSdk.inst().get_eventlog(start=0)

@pytest.mark.basic
def test_get_eventlog_with_valid_input():
"""Test get_eventlog() funtion with valid input."""
event_logs = CCTrustedVmSdk.inst().get_eventlog()
assert event_logs is not None

@pytest.mark.basic
def test_get_quote_with_valid_input():
"""Test get_quote() function with valid input."""
quote = CCTrustedVmSdk.inst().get_quote(None, None, None)
assert quote is not None
93 changes: 93 additions & 0 deletions vmsdk/python/tests/test_sdk_tdx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""TDX specific test."""

from hashlib import sha384
import pytest
from cctrusted_base.tcg import TcgAlgorithmRegistry, TcgImrEvent
from cctrusted_base.tdx.quote import TdxQuote, TdxQuoteBody
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_vm.sdk import CCTrustedVmSdk

@pytest.mark.tdx
def test_tdx_get_default_algorithms():
"""Test default algorithm is supported."""
algo = CCTrustedVmSdk.inst().get_default_algorithms()
assert algo is not None
assert algo.alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384

@pytest.mark.tdx
def test_tdx_get_measurement_count():
"""Test measurement count is 4 (RTMR count)."""
count = CCTrustedVmSdk.inst().get_measurement_count()
assert count == TdxRTMR.RTMR_COUNT

def _replay_eventlog():
"""Get RTMRs from event log by replay."""
rtmr_len = TdxRTMR.RTMR_LENGTH_BY_BYTES
rtmr_cnt = TdxRTMR.RTMR_COUNT
rtmrs = [bytearray(rtmr_len)] * rtmr_cnt
event_logs = CCTrustedVmSdk.inst().get_eventlog().event_logs
assert event_logs is not None
for event in event_logs:
if isinstance(event, TcgImrEvent):
sha384_algo = sha384()
sha384_algo.update(rtmrs[event.imr_index] + event.digests[0].hash)
rtmrs[event.imr_index] = sha384_algo.digest()
return rtmrs

def _check_imr(imr_index: int, alg_id: int, rtmr: bytes):
"""Check individual IMR.
Compare the 4 IMR hash with the hash derived by replay event log. They are
expected to be same.
Args:
imr_index: an integer specified the IMR index.
alg_id: an integer specified the hash algorithm.
rtmr: bytes of RTMR data for comparison.
"""
assert 0 <= imr_index < TdxRTMR.RTMR_COUNT
assert rtmr is not None
assert alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384
imr = CCTrustedVmSdk.inst().get_measurement([imr_index, alg_id])
assert imr is not None
digest_obj = imr.digest(alg_id)
assert digest_obj is not None
digest_alg_id = digest_obj.alg.alg_id
assert digest_alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384
digest_hash = digest_obj.hash
assert digest_hash is not None
assert digest_hash == rtmr, \
f"rtmr {rtmr.hex()} doesn't equal digest {digest_hash.hex()}"

@pytest.mark.tdx
def test_tdx_get_measurement_imrs():
"""Test measurement result.
The test is done by compare the measurement register against the value
derived by replay eventlog.
"""
alg = CCTrustedVmSdk.inst().get_default_algorithms()
rtmrs = _replay_eventlog()
_check_imr(0, alg.alg_id, rtmrs[0])
_check_imr(1, alg.alg_id, rtmrs[1])
_check_imr(2, alg.alg_id, rtmrs[2])
_check_imr(3, alg.alg_id, rtmrs[3])

@pytest.mark.tdx
def test_tdx_get_quote_rtmrs():
"""Test quote result.
The test is done by compare the RTMRs in quote body against the value
derived by replay eventlog.
"""
quote = CCTrustedVmSdk.inst().get_quote()
assert quote is not None
assert isinstance(quote, TdxQuote)
body = quote.body
assert body is not None
assert isinstance(body, TdxQuoteBody)
rtmrs = _replay_eventlog()
assert body.rtmr0 == rtmrs[0], \
"RTMR0 doesn't equal the replay from event log!"
assert body.rtmr1 == rtmrs[1], \
"RTMR1 doesn't equal the replay from event log!"
assert body.rtmr2 == rtmrs[2], \
"RTMR2 doesn't equal the replay from event log!"
assert body.rtmr3 == rtmrs[3], \
"RTMR3 doesn't equal the replay from event log!"

0 comments on commit 6913295

Please sign in to comment.