Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vmsdk/python/tests: add tests for TDX #57

Merged
merged 3 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion .github/workflows/vmsdk-test-python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,12 @@ jobs:
- name: Run PyTest for VMSDK
run: |
set -ex
sudo su -c "source setupenv.sh && python3 -m pytest -v ./vmsdk/python/tests/test_sdk.py"
# Set the "PYTHONDONTWRITEBYTECODE" and "no:cacheprovider" to prevent
# generated some intermediate files by root. Othwerwise, these
# files will fail the action/checkout in the next round of running
# due to the permission issue.
sudo su -c "source setupenv.sh && \
pushd vmsdk/python/tests && \
export PYTHONDONTWRITEBYTECODE=1 && \
python3 -m pytest -p no:cacheprovider -v test_sdk.py && \
popd"
7 changes: 6 additions & 1 deletion common/python/cctrusted_base/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ def get_measurement(self, imr_select:[int, int]) -> TcgIMR:
raise NotImplementedError("Inherited SDK class should implement this.")

@abstractmethod
def get_quote(self, nonce: bytearray, data: bytearray, extraArgs=None) -> Quote:
def get_quote(
self,
nonce: bytearray = None,
data: bytearray = None,
extraArgs = None
) -> Quote:
"""Get the quote for given nonce and data.

The quote is signing of attestation data (IMR values or hashes of IMR
Expand Down
20 changes: 1 addition & 19 deletions common/python/cctrusted_base/imr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from abc import ABC, abstractmethod
from cctrusted_base.tcg import TcgDigest, TcgAlgorithmRegistry
from cctrusted_base.tcg import TcgDigest

class TcgIMR(ABC):
"""Common Integrated Measurement Register class."""
Expand Down Expand Up @@ -56,21 +56,3 @@ def is_valid(self):
"""
return self._index != TcgIMR._INVALID_IMR_INDEX and \
self._index <= self.max_index

class TdxRTMR(TcgIMR):
"""RTMR class defined for Intel TDX."""

@property
def max_index(self):
return 3

def __init__(self, index, digest_hash):
super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384,
digest_hash)

class TpmPCR(TcgIMR):
"""PCR class defined for TPM"""

@property
def max_index(self):
return 23
23 changes: 23 additions & 0 deletions common/python/cctrusted_base/tdx/rtmr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
RTMR (Runtime Measurement Register).
"""

from cctrusted_base.imr import TcgIMR
from cctrusted_base.tcg import TcgAlgorithmRegistry

class TdxRTMR(TcgIMR):
"""RTMR class defined for Intel TDX."""

RTMR_COUNT = 4
"""Intel TDX TDREPORT provides the 4 measurement registers by default."""

RTMR_LENGTH_BY_BYTES = 48
"""RTMR length by bytes."""

@property
def max_index(self):
return 3

def __init__(self, index, digest_hash):
super().__init__(index, TcgAlgorithmRegistry.TPM_ALG_SHA384,
digest_hash)
12 changes: 12 additions & 0 deletions common/python/cctrusted_base/tpm/pcr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""
PCR (Platform Configuration Register).
"""

from cctrusted_base.imr import TcgIMR

class TpmPCR(TcgIMR):
"""PCR class defined for TPM"""

@property
def max_index(self):
return 23
3 changes: 2 additions & 1 deletion vmsdk/python/cc_imr_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
count = CCTrustedVmSdk.inst().get_measurement_count()
for index in range(CCTrustedVmSdk.inst().get_measurement_count()):
alg = CCTrustedVmSdk.inst().get_default_algorithms()
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
imr = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
digest_obj = imr.digest(alg.alg_id)

hash_str = ""
for hash_item in digest_obj.hash:
Expand Down
2 changes: 1 addition & 1 deletion vmsdk/python/cc_quote_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def main():
level=logging.NOTSET,
format="%(name)s %(levelname)-8s %(message)s"
)
quote = CCTrustedVmSdk.inst().get_quote(None, None, None)
quote = CCTrustedVmSdk.inst().get_quote()
if quote is not None:
quote.dump(args.out_format == OUT_FORMAT_RAW)
else:
Expand Down
3 changes: 2 additions & 1 deletion vmsdk/python/cctrusted_vm/cvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
import struct
import fcntl
from abc import abstractmethod
from cctrusted_base.imr import TdxRTMR,TcgIMR
from cctrusted_base.imr import TcgIMR
from cctrusted_base.quote import Quote
from cctrusted_base.tcg import TcgAlgorithmRegistry
from cctrusted_base.tdx.common import TDX_VERSION_1_0, TDX_VERSION_1_5
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_base.tdx.quote import TdxQuoteReq10, TdxQuoteReq15
from cctrusted_base.tdx.report import TdxReportReq10, TdxReportReq15

Expand Down
11 changes: 8 additions & 3 deletions vmsdk/python/cctrusted_vm/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,14 @@ def get_measurement(self, imr_select:[int, int]) -> TcgIMR:
if algo_id is None or algo_id is TcgAlgorithmRegistry.TPM_ALG_ERROR:
algo_id = self._cvm.default_algo_id

return self._cvm.imrs[imr_index].digest(algo_id)

def get_quote(self, nonce: bytearray, data: bytearray, extraArgs=None) -> Quote:
return self._cvm.imrs[imr_index]

def get_quote(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change to this style, looks wierd.

Copy link
Contributor Author

@intelzhongjie intelzhongjie Jan 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After adding default values for parameters "nonce" and "data", the line became too long. So I followed the https://google.github.io/styleguide/pyguide.html#3192-line-breaking 3.19.2 Line Breaking for that style.

self,
nonce: bytearray = None,
data: bytearray = None,
extraArgs = None
) -> Quote:
"""Get the quote for given nonce and data.

The quote is signing of attestation data (IMR values or hashes of IMR
Expand Down
66 changes: 66 additions & 0 deletions vmsdk/python/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Local conftest.py containing directory-specific hook implementations."""

import pytest
from cctrusted_base.tcg import TcgAlgorithmRegistry
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_vm.cvm import ConfidentialVM
from cctrusted_vm.sdk import CCTrustedVmSdk
import tdx_check

cnf_default_alg = {
ConfidentialVM.TYPE_CC_TDX: TcgAlgorithmRegistry.TPM_ALG_SHA384
}
"""Configurations of default algorithm.
The configurations could be different for different confidential VMs.
e.g. TDX use sha384 as the default.
"""

cnf_measurement_cnt = {
ConfidentialVM.TYPE_CC_TDX: TdxRTMR.RTMR_COUNT
}
"""Configurations of measurement count.
The configurations could be different for different confidential VMs.
"""

cnf_measurement_check = {
ConfidentialVM.TYPE_CC_TDX: tdx_check.tdx_check_measurement_imrs
}
"""Configurations of measurement check functions.
The configurations could be different for different confidential VMs.
"""

cnf_quote_check = {
ConfidentialVM.TYPE_CC_TDX: tdx_check.tdx_check_quote_rtmrs
}
"""Configurations of quote check functions.
The configurations could be different for different confidential VMs.
"""

@pytest.fixture(scope="module")
def vm_sdk():
"""Get VMSDK instance."""
return CCTrustedVmSdk.inst()

@pytest.fixture(scope="module")
def default_alg_id():
"""Get default algorithm."""
cc_type = ConfidentialVM.detect_cc_type()
return cnf_default_alg[cc_type]

@pytest.fixture(scope="module")
def measurement_count():
"""Get measurement count."""
cc_type = ConfidentialVM.detect_cc_type()
return cnf_measurement_cnt[cc_type]

@pytest.fixture(scope="module")
def check_measurement():
"""Return checker for measurement."""
cc_type = ConfidentialVM.detect_cc_type()
return cnf_measurement_check[cc_type]

@pytest.fixture(scope="module")
def check_quote():
"""Return checker for quote."""
cc_type = ConfidentialVM.detect_cc_type()
return cnf_quote_check[cc_type]
77 changes: 77 additions & 0 deletions vmsdk/python/tests/tdx_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""TDX specific test."""

from hashlib import sha384
from cctrusted_base.tcg import TcgAlgorithmRegistry, TcgImrEvent
from cctrusted_base.tdx.quote import TdxQuote, TdxQuoteBody
from cctrusted_base.tdx.rtmr import TdxRTMR
from cctrusted_vm.sdk import CCTrustedVmSdk

def _replay_eventlog():
"""Get RTMRs from event log by replay."""
rtmr_len = TdxRTMR.RTMR_LENGTH_BY_BYTES
rtmr_cnt = TdxRTMR.RTMR_COUNT
rtmrs = [bytearray(rtmr_len)] * rtmr_cnt
event_logs = CCTrustedVmSdk.inst().get_eventlog().event_logs
assert event_logs is not None
for event in event_logs:
if isinstance(event, TcgImrEvent):
sha384_algo = sha384()
sha384_algo.update(rtmrs[event.imr_index] + event.digests[0].hash)
rtmrs[event.imr_index] = sha384_algo.digest()
return rtmrs

def _check_imr(imr_index: int, alg_id: int, rtmr: bytes):
"""Check individual IMR.
Compare the 4 IMR hash with the hash derived by replay event log. They are
expected to be same.
Args:
imr_index: an integer specified the IMR index.
alg_id: an integer specified the hash algorithm.
rtmr: bytes of RTMR data for comparison.
"""
assert 0 <= imr_index < TdxRTMR.RTMR_COUNT
assert rtmr is not None
assert alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384
imr = CCTrustedVmSdk.inst().get_measurement([imr_index, alg_id])
assert imr is not None
digest_obj = imr.digest(alg_id)
assert digest_obj is not None
digest_alg_id = digest_obj.alg.alg_id
assert digest_alg_id == TcgAlgorithmRegistry.TPM_ALG_SHA384
digest_hash = digest_obj.hash
assert digest_hash is not None
assert digest_hash == rtmr, \
f"rtmr {rtmr.hex()} doesn't equal digest {digest_hash.hex()}"

def tdx_check_measurement_imrs():
"""Test measurement result.
The test is done by compare the measurement register against the value
derived by replay eventlog.
"""
alg = CCTrustedVmSdk.inst().get_default_algorithms()
rtmrs = _replay_eventlog()
_check_imr(0, alg.alg_id, rtmrs[0])
_check_imr(1, alg.alg_id, rtmrs[1])
_check_imr(2, alg.alg_id, rtmrs[2])
_check_imr(3, alg.alg_id, rtmrs[3])

def tdx_check_quote_rtmrs():
"""Test quote result.
The test is done by compare the RTMRs in quote body against the value
derived by replay eventlog.
"""
quote = CCTrustedVmSdk.inst().get_quote()
assert quote is not None
assert isinstance(quote, TdxQuote)
body = quote.body
assert body is not None
assert isinstance(body, TdxQuoteBody)
rtmrs = _replay_eventlog()
assert body.rtmr0 == rtmrs[0], \
"RTMR0 doesn't equal the replay from event log!"
assert body.rtmr1 == rtmrs[1], \
"RTMR1 doesn't equal the replay from event log!"
assert body.rtmr2 == rtmrs[2], \
"RTMR2 doesn't equal the replay from event log!"
assert body.rtmr3 == rtmrs[3], \
"RTMR3 doesn't equal the replay from event log!"
104 changes: 52 additions & 52 deletions vmsdk/python/tests/test_sdk.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,55 @@
"""Containing unit test cases for sdk class"""

import pytest
from cctrusted_vm import CCTrustedVmSdk

class TestCCTrustedVmSdk():
"""Unit tests for CCTrustedVmSdk class."""

def test_get_default_algorithms(self):
"""Test get_default_algorithms() function."""
algo = CCTrustedVmSdk.inst().get_default_algorithms()
assert algo is not None

def test_get_measurement_count(self):
"""Test get_measurement_count() function."""
count = CCTrustedVmSdk.inst().get_measurement_count()
assert count is not None

def test_get_measurement_with_invalid_input(self):
"""Test get_measurement() function with invalid input."""
# calling get_measurement() with invalid IMR index
measurement = CCTrustedVmSdk.inst().get_measurement([-1, 0xC])
assert measurement is None

# calling get_measurement() with invalid algorithm ID
measurement = CCTrustedVmSdk.inst().get_measurement([0, None])
assert measurement is not None

def test_get_measurement_with_valid_input(self):
"""Test get_measurement() function with valid input."""
count = CCTrustedVmSdk.inst().get_measurement_count()
for index in range(count):
alg = CCTrustedVmSdk.inst().get_default_algorithms()
digest_obj = CCTrustedVmSdk.inst().get_measurement([index, alg.alg_id])
assert digest_obj is not None

def test_get_eventlog_with_invalid_input(self):
"""Test get_eventlog() function with invalid input."""
# calling get_eventlog with count < 0
with pytest.raises(ValueError):
CCTrustedVmSdk.inst().get_eventlog(start=1, count=-1)

# calling get_eventlog with start < 1
with pytest.raises(ValueError):
CCTrustedVmSdk.inst().get_eventlog(start=0)

def test_get_eventlog_with_valid_input(self):
"""Test get_eventlog() funtion with valid input."""
event_logs = CCTrustedVmSdk.inst().get_eventlog()
assert event_logs is not None

def test_get_quote_with_valid_input(self):
"""Test get_quote() function with valid input."""
quote = CCTrustedVmSdk.inst().get_quote(None, None, None)
assert quote is not None

def test_get_default_algorithms(vm_sdk, default_alg_id):
"""Test get_default_algorithms() function."""
algo = vm_sdk.get_default_algorithms()
assert algo is not None
assert algo.alg_id == default_alg_id

def test_get_measurement_count(vm_sdk, measurement_count):
"""Test get_measurement_count() function."""
count = vm_sdk.get_measurement_count()
assert count is not None
assert count == measurement_count

def test_get_measurement_with_invalid_input(vm_sdk):
"""Test get_measurement() function with invalid input."""
# calling get_measurement() with invalid IMR index
measurement = vm_sdk.get_measurement([-1, 0xC])
assert measurement is None

# calling get_measurement() with invalid algorithm ID
measurement = vm_sdk.get_measurement([0, None])
assert measurement is not None

def test_get_measurement_with_valid_input(vm_sdk, check_measurement):
"""Test get_measurement() function with valid input."""
count = vm_sdk.get_measurement_count()
for index in range(count):
alg = vm_sdk.get_default_algorithms()
digest_obj = vm_sdk.get_measurement([index, alg.alg_id])
assert digest_obj is not None
check_measurement()

def test_get_eventlog_with_invalid_input(vm_sdk):
"""Test get_eventlog() function with invalid input."""
# calling get_eventlog with count < 0
with pytest.raises(ValueError):
vm_sdk.get_eventlog(start=1, count=-1)

# calling get_eventlog with start < 1
with pytest.raises(ValueError):
vm_sdk.get_eventlog(start=0)

def test_get_eventlog_with_valid_input(vm_sdk):
"""Test get_eventlog() funtion with valid input."""
event_logs = vm_sdk.get_eventlog()
assert event_logs is not None

def test_get_quote_with_valid_input(vm_sdk, check_quote):
"""Test get_quote() function with valid input."""
quote = vm_sdk.get_quote(None, None, None)
assert quote is not None
check_quote()