From ccc639383a70f92e20b5356e945cd350c80f1bac Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 10 Jan 2025 21:15:11 +0000 Subject: [PATCH 01/15] Workload ID: Introduce basic rule condition operators (#50940) * Pull condition operators into a one of * Add wrapper structs to support repeated * Adjust for new protos * Fix tflint * Start fixing tests * Add more test cases * More test cases * Fix marshalling of resource * Fix test in lib/services --- .../workloadidentity/v1/resource.pb.go | 492 ++++++++++++++---- .../workloadidentity/v1/resource.proto | 38 +- .../data-sources/workload_identity.mdx | 33 +- .../resources/workload_identity.mdx | 37 +- .../teleport_workload_identity/resource.tf | 4 +- .../fixtures/workload_identity_0_create.tf | 4 +- .../fixtures/workload_identity_1_update.tf | 4 +- .../testlib/workload_identity_test.go | 12 +- .../workloadidentity/v1/resource_terraform.go | 491 ++++++++++++++++- .../machineid/workloadidentityv1/decision.go | 21 +- .../workloadidentityv1/decision_test.go | 285 +++++++++- .../workloadidentityv1_test.go | 30 +- lib/services/workload_identity.go | 12 +- lib/services/workload_identity_test.go | 14 +- tool/tctl/common/collection.go | 2 +- 15 files changed, 1317 insertions(+), 162 deletions(-) diff --git a/api/gen/proto/go/teleport/workloadidentity/v1/resource.pb.go b/api/gen/proto/go/teleport/workloadidentity/v1/resource.pb.go index 1849d7e902173..fa758941db455 100644 --- a/api/gen/proto/go/teleport/workloadidentity/v1/resource.pb.go +++ b/api/gen/proto/go/teleport/workloadidentity/v1/resource.pb.go @@ -121,20 +121,209 @@ func (x *WorkloadIdentity) GetSpec() *WorkloadIdentitySpec { return nil } +// The attribute casted to a string must be equal to the value. +type WorkloadIdentityConditionEq struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The value to compare the attribute against. 
+ Value string `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WorkloadIdentityConditionEq) Reset() { + *x = WorkloadIdentityConditionEq{} + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WorkloadIdentityConditionEq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WorkloadIdentityConditionEq) ProtoMessage() {} + +func (x *WorkloadIdentityConditionEq) ProtoReflect() protoreflect.Message { + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WorkloadIdentityConditionEq.ProtoReflect.Descriptor instead. +func (*WorkloadIdentityConditionEq) Descriptor() ([]byte, []int) { + return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{1} +} + +func (x *WorkloadIdentityConditionEq) GetValue() string { + if x != nil { + return x.Value + } + return "" +} + +// The attribute casted to a string must not be equal to the value. +type WorkloadIdentityConditionNotEq struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The value to compare the attribute against. 
+ Value string `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WorkloadIdentityConditionNotEq) Reset() { + *x = WorkloadIdentityConditionNotEq{} + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WorkloadIdentityConditionNotEq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WorkloadIdentityConditionNotEq) ProtoMessage() {} + +func (x *WorkloadIdentityConditionNotEq) ProtoReflect() protoreflect.Message { + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WorkloadIdentityConditionNotEq.ProtoReflect.Descriptor instead. +func (*WorkloadIdentityConditionNotEq) Descriptor() ([]byte, []int) { + return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{2} +} + +func (x *WorkloadIdentityConditionNotEq) GetValue() string { + if x != nil { + return x.Value + } + return "" +} + +// The attribute casted to a string must be in the list of values. +type WorkloadIdentityConditionIn struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The list of values to compare the attribute against. 
+ Values []string `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WorkloadIdentityConditionIn) Reset() { + *x = WorkloadIdentityConditionIn{} + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WorkloadIdentityConditionIn) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WorkloadIdentityConditionIn) ProtoMessage() {} + +func (x *WorkloadIdentityConditionIn) ProtoReflect() protoreflect.Message { + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WorkloadIdentityConditionIn.ProtoReflect.Descriptor instead. +func (*WorkloadIdentityConditionIn) Descriptor() ([]byte, []int) { + return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{3} +} + +func (x *WorkloadIdentityConditionIn) GetValues() []string { + if x != nil { + return x.Values + } + return nil +} + +// The attribute casted to a string must not be in the list of values. +type WorkloadIdentityConditionNotIn struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The list of values to compare the attribute against. 
+ Values []string `protobuf:"bytes,1,rep,name=values,proto3" json:"values,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WorkloadIdentityConditionNotIn) Reset() { + *x = WorkloadIdentityConditionNotIn{} + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WorkloadIdentityConditionNotIn) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WorkloadIdentityConditionNotIn) ProtoMessage() {} + +func (x *WorkloadIdentityConditionNotIn) ProtoReflect() protoreflect.Message { + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WorkloadIdentityConditionNotIn.ProtoReflect.Descriptor instead. +func (*WorkloadIdentityConditionNotIn) Descriptor() ([]byte, []int) { + return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{4} +} + +func (x *WorkloadIdentityConditionNotIn) GetValues() []string { + if x != nil { + return x.Values + } + return nil +} + // The individual conditions that make up a rule. type WorkloadIdentityCondition struct { state protoimpl.MessageState `protogen:"open.v1"` // The name of the attribute to evaluate the condition against. Attribute string `protobuf:"bytes,1,opt,name=attribute,proto3" json:"attribute,omitempty"` - // An exact string that the attribute must match. 
- Equals string `protobuf:"bytes,2,opt,name=equals,proto3" json:"equals,omitempty"` + // Types that are valid to be assigned to Operator: + // + // *WorkloadIdentityCondition_Eq + // *WorkloadIdentityCondition_NotEq + // *WorkloadIdentityCondition_In + // *WorkloadIdentityCondition_NotIn + Operator isWorkloadIdentityCondition_Operator `protobuf_oneof:"operator"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *WorkloadIdentityCondition) Reset() { *x = WorkloadIdentityCondition{} - mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[1] + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -146,7 +335,7 @@ func (x *WorkloadIdentityCondition) String() string { func (*WorkloadIdentityCondition) ProtoMessage() {} func (x *WorkloadIdentityCondition) ProtoReflect() protoreflect.Message { - mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[1] + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -159,7 +348,7 @@ func (x *WorkloadIdentityCondition) ProtoReflect() protoreflect.Message { // Deprecated: Use WorkloadIdentityCondition.ProtoReflect.Descriptor instead. 
func (*WorkloadIdentityCondition) Descriptor() ([]byte, []int) { - return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{1} + return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{5} } func (x *WorkloadIdentityCondition) GetAttribute() string { @@ -169,13 +358,81 @@ func (x *WorkloadIdentityCondition) GetAttribute() string { return "" } -func (x *WorkloadIdentityCondition) GetEquals() string { +func (x *WorkloadIdentityCondition) GetOperator() isWorkloadIdentityCondition_Operator { if x != nil { - return x.Equals + return x.Operator } - return "" + return nil +} + +func (x *WorkloadIdentityCondition) GetEq() *WorkloadIdentityConditionEq { + if x != nil { + if x, ok := x.Operator.(*WorkloadIdentityCondition_Eq); ok { + return x.Eq + } + } + return nil +} + +func (x *WorkloadIdentityCondition) GetNotEq() *WorkloadIdentityConditionNotEq { + if x != nil { + if x, ok := x.Operator.(*WorkloadIdentityCondition_NotEq); ok { + return x.NotEq + } + } + return nil } +func (x *WorkloadIdentityCondition) GetIn() *WorkloadIdentityConditionIn { + if x != nil { + if x, ok := x.Operator.(*WorkloadIdentityCondition_In); ok { + return x.In + } + } + return nil +} + +func (x *WorkloadIdentityCondition) GetNotIn() *WorkloadIdentityConditionNotIn { + if x != nil { + if x, ok := x.Operator.(*WorkloadIdentityCondition_NotIn); ok { + return x.NotIn + } + } + return nil +} + +type isWorkloadIdentityCondition_Operator interface { + isWorkloadIdentityCondition_Operator() +} + +type WorkloadIdentityCondition_Eq struct { + // The attribute casted to a string must be equal to the value. + Eq *WorkloadIdentityConditionEq `protobuf:"bytes,3,opt,name=eq,proto3,oneof"` +} + +type WorkloadIdentityCondition_NotEq struct { + // The attribute casted to a string must not be equal to the value. 
+ NotEq *WorkloadIdentityConditionNotEq `protobuf:"bytes,4,opt,name=not_eq,json=notEq,proto3,oneof"` +} + +type WorkloadIdentityCondition_In struct { + // The attribute casted to a string must be in the list of values. + In *WorkloadIdentityConditionIn `protobuf:"bytes,5,opt,name=in,proto3,oneof"` +} + +type WorkloadIdentityCondition_NotIn struct { + // The attribute casted to a string must not be in the list of values. + NotIn *WorkloadIdentityConditionNotIn `protobuf:"bytes,6,opt,name=not_in,json=notIn,proto3,oneof"` +} + +func (*WorkloadIdentityCondition_Eq) isWorkloadIdentityCondition_Operator() {} + +func (*WorkloadIdentityCondition_NotEq) isWorkloadIdentityCondition_Operator() {} + +func (*WorkloadIdentityCondition_In) isWorkloadIdentityCondition_Operator() {} + +func (*WorkloadIdentityCondition_NotIn) isWorkloadIdentityCondition_Operator() {} + // An individual rule that is evaluated during the issuance of a WorkloadIdentity. type WorkloadIdentityRule struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -187,7 +444,7 @@ type WorkloadIdentityRule struct { func (x *WorkloadIdentityRule) Reset() { *x = WorkloadIdentityRule{} - mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[2] + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -199,7 +456,7 @@ func (x *WorkloadIdentityRule) String() string { func (*WorkloadIdentityRule) ProtoMessage() {} func (x *WorkloadIdentityRule) ProtoReflect() protoreflect.Message { - mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[2] + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -212,7 +469,7 @@ func (x *WorkloadIdentityRule) ProtoReflect() protoreflect.Message { // Deprecated: Use WorkloadIdentityRule.ProtoReflect.Descriptor instead. 
func (*WorkloadIdentityRule) Descriptor() ([]byte, []int) { - return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{2} + return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{6} } func (x *WorkloadIdentityRule) GetConditions() []*WorkloadIdentityCondition { @@ -235,7 +492,7 @@ type WorkloadIdentityRules struct { func (x *WorkloadIdentityRules) Reset() { *x = WorkloadIdentityRules{} - mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[3] + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[7] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -247,7 +504,7 @@ func (x *WorkloadIdentityRules) String() string { func (*WorkloadIdentityRules) ProtoMessage() {} func (x *WorkloadIdentityRules) ProtoReflect() protoreflect.Message { - mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[3] + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[7] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -260,7 +517,7 @@ func (x *WorkloadIdentityRules) ProtoReflect() protoreflect.Message { // Deprecated: Use WorkloadIdentityRules.ProtoReflect.Descriptor instead. 
func (*WorkloadIdentityRules) Descriptor() ([]byte, []int) { - return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{3} + return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{7} } func (x *WorkloadIdentityRules) GetAllow() []*WorkloadIdentityRule { @@ -284,7 +541,7 @@ type WorkloadIdentitySPIFFEX509 struct { func (x *WorkloadIdentitySPIFFEX509) Reset() { *x = WorkloadIdentitySPIFFEX509{} - mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[4] + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -296,7 +553,7 @@ func (x *WorkloadIdentitySPIFFEX509) String() string { func (*WorkloadIdentitySPIFFEX509) ProtoMessage() {} func (x *WorkloadIdentitySPIFFEX509) ProtoReflect() protoreflect.Message { - mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[4] + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[8] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -309,7 +566,7 @@ func (x *WorkloadIdentitySPIFFEX509) ProtoReflect() protoreflect.Message { // Deprecated: Use WorkloadIdentitySPIFFEX509.ProtoReflect.Descriptor instead. 
func (*WorkloadIdentitySPIFFEX509) Descriptor() ([]byte, []int) { - return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{4} + return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{8} } func (x *WorkloadIdentitySPIFFEX509) GetDnsSans() []string { @@ -341,7 +598,7 @@ type WorkloadIdentitySPIFFE struct { func (x *WorkloadIdentitySPIFFE) Reset() { *x = WorkloadIdentitySPIFFE{} - mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[5] + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -353,7 +610,7 @@ func (x *WorkloadIdentitySPIFFE) String() string { func (*WorkloadIdentitySPIFFE) ProtoMessage() {} func (x *WorkloadIdentitySPIFFE) ProtoReflect() protoreflect.Message { - mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[5] + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[9] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -366,7 +623,7 @@ func (x *WorkloadIdentitySPIFFE) ProtoReflect() protoreflect.Message { // Deprecated: Use WorkloadIdentitySPIFFE.ProtoReflect.Descriptor instead. 
func (*WorkloadIdentitySPIFFE) Descriptor() ([]byte, []int) { - return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{5} + return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{9} } func (x *WorkloadIdentitySPIFFE) GetId() string { @@ -404,7 +661,7 @@ type WorkloadIdentitySpec struct { func (x *WorkloadIdentitySpec) Reset() { *x = WorkloadIdentitySpec{} - mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[6] + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -416,7 +673,7 @@ func (x *WorkloadIdentitySpec) String() string { func (*WorkloadIdentitySpec) ProtoMessage() {} func (x *WorkloadIdentitySpec) ProtoReflect() protoreflect.Message { - mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[6] + mi := &file_teleport_workloadidentity_v1_resource_proto_msgTypes[10] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -429,7 +686,7 @@ func (x *WorkloadIdentitySpec) ProtoReflect() protoreflect.Message { // Deprecated: Use WorkloadIdentitySpec.ProtoReflect.Descriptor instead. 
func (*WorkloadIdentitySpec) Descriptor() ([]byte, []int) { - return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{6} + return file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP(), []int{10} } func (x *WorkloadIdentitySpec) GetRules() *WorkloadIdentityRules { @@ -469,56 +726,91 @@ var file_teleport_workloadidentity_v1_resource_proto_rawDesc = []byte{ 0x20, 0x01, 0x28, 0x0b, 0x32, 0x32, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, - 0x74, 0x69, 0x74, 0x79, 0x53, 0x70, 0x65, 0x63, 0x52, 0x04, 0x73, 0x70, 0x65, 0x63, 0x22, 0x51, - 0x0a, 0x19, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, - 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x61, - 0x74, 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, - 0x61, 0x74, 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x65, 0x71, 0x75, - 0x61, 0x6c, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x65, 0x71, 0x75, 0x61, 0x6c, - 0x73, 0x22, 0x6f, 0x0a, 0x14, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, - 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x57, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, - 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x37, 0x2e, - 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, - 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, - 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e, - 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x0a, 0x63, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, - 0x6e, 0x73, 0x22, 0x61, 0x0a, 0x15, 0x57, 0x6f, 
0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, - 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x48, 0x0a, 0x05, 0x61, - 0x6c, 0x6c, 0x6f, 0x77, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x32, 0x2e, 0x74, 0x65, 0x6c, - 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, - 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, - 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x05, - 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x22, 0x37, 0x0a, 0x1a, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, - 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x53, 0x50, 0x49, 0x46, 0x46, 0x45, 0x58, - 0x35, 0x30, 0x39, 0x12, 0x19, 0x0a, 0x08, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x61, 0x6e, 0x73, 0x18, - 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6e, 0x73, 0x53, 0x61, 0x6e, 0x73, 0x22, 0x8a, - 0x01, 0x0a, 0x16, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, - 0x69, 0x74, 0x79, 0x53, 0x50, 0x49, 0x46, 0x46, 0x45, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x69, 0x6e, - 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x68, 0x69, 0x6e, 0x74, 0x12, 0x4c, 0x0a, - 0x04, 0x78, 0x35, 0x30, 0x39, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x38, 0x2e, 0x74, 0x65, - 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, - 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, - 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x53, 0x50, 0x49, 0x46, 0x46, - 0x45, 0x58, 0x35, 0x30, 0x39, 0x52, 0x04, 0x78, 0x35, 0x30, 0x39, 0x22, 0xaf, 0x01, 0x0a, 0x14, - 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, - 0x53, 0x70, 0x65, 0x63, 0x12, 0x49, 0x0a, 0x05, 0x72, 0x75, 0x6c, 0x65, 0x73, 0x18, 
0x01, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x33, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77, + 0x74, 0x69, 0x74, 0x79, 0x53, 0x70, 0x65, 0x63, 0x52, 0x04, 0x73, 0x70, 0x65, 0x63, 0x22, 0x33, + 0x0a, 0x1b, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, + 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x71, 0x12, 0x14, 0x0a, + 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, + 0x6c, 0x75, 0x65, 0x22, 0x36, 0x0a, 0x1e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, + 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, + 0x4e, 0x6f, 0x74, 0x45, 0x71, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x22, 0x35, 0x0a, 0x1b, 0x57, + 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43, + 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x76, 0x61, + 0x6c, 0x75, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x73, 0x22, 0x38, 0x0a, 0x1e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, + 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x4e, + 0x6f, 0x74, 0x49, 0x6e, 0x12, 0x16, 0x0a, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x22, 0x9b, 0x03, 0x0a, + 0x19, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, + 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x74, + 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x61, + 0x74, 0x74, 0x72, 0x69, 0x62, 0x75, 0x74, 0x65, 0x12, 0x4b, 0x0a, 0x02, 0x65, 0x71, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0b, 
0x32, 0x39, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, + 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, + 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, + 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x71, 0x48, + 0x00, 0x52, 0x02, 0x65, 0x71, 0x12, 0x55, 0x0a, 0x06, 0x6e, 0x6f, 0x74, 0x5f, 0x65, 0x71, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x3c, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, + 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, + 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, + 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x4e, 0x6f, + 0x74, 0x45, 0x71, 0x48, 0x00, 0x52, 0x05, 0x6e, 0x6f, 0x74, 0x45, 0x71, 0x12, 0x4b, 0x0a, 0x02, + 0x69, 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x39, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, + 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, + 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, + 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, + 0x6e, 0x49, 0x6e, 0x48, 0x00, 0x52, 0x02, 0x69, 0x6e, 0x12, 0x55, 0x0a, 0x06, 0x6e, 0x6f, 0x74, + 0x5f, 0x69, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x3c, 0x2e, 0x74, 0x65, 0x6c, 0x65, + 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, + 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, + 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, + 0x6f, 0x6e, 0x4e, 0x6f, 0x74, 0x49, 0x6e, 0x48, 0x00, 0x52, 0x05, 0x6e, 0x6f, 0x74, 0x49, 0x6e, + 0x42, 0x0a, 0x0a, 0x08, 0x6f, 0x70, 0x65, 0x72, 0x61, 0x74, 
0x6f, 0x72, 0x4a, 0x04, 0x08, 0x02, + 0x10, 0x03, 0x52, 0x06, 0x65, 0x71, 0x75, 0x61, 0x6c, 0x73, 0x22, 0x6f, 0x0a, 0x14, 0x57, 0x6f, + 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, 0x75, + 0x6c, 0x65, 0x12, 0x57, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x73, + 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x37, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, + 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, + 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, + 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x43, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x0a, 0x63, 0x6f, 0x6e, 0x64, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x22, 0x61, 0x0a, 0x15, 0x57, + 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, + 0x75, 0x6c, 0x65, 0x73, 0x12, 0x48, 0x0a, 0x05, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x18, 0x01, 0x20, + 0x03, 0x28, 0x0b, 0x32, 0x32, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, - 0x69, 0x74, 0x79, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x52, 0x05, 0x72, 0x75, 0x6c, 0x65, 0x73, 0x12, - 0x4c, 0x0a, 0x06, 0x73, 0x70, 0x69, 0x66, 0x66, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x34, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, - 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, - 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x53, - 0x50, 0x49, 0x46, 0x46, 0x45, 0x52, 0x06, 0x73, 0x70, 0x69, 0x66, 0x66, 0x65, 0x42, 0x64, 0x5a, - 0x62, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x72, 0x61, 0x76, - 0x69, 
0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, - 0x72, 0x74, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x2f, 0x67, 0x6f, 0x2f, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x6f, 0x72, - 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2f, 0x76, 0x31, - 0x3b, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, - 0x79, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x69, 0x74, 0x79, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x05, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x22, 0x37, + 0x0a, 0x1a, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, + 0x74, 0x79, 0x53, 0x50, 0x49, 0x46, 0x46, 0x45, 0x58, 0x35, 0x30, 0x39, 0x12, 0x19, 0x0a, 0x08, + 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x61, 0x6e, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, + 0x64, 0x6e, 0x73, 0x53, 0x61, 0x6e, 0x73, 0x22, 0x8a, 0x01, 0x0a, 0x16, 0x57, 0x6f, 0x72, 0x6b, + 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x53, 0x50, 0x49, 0x46, + 0x46, 0x45, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, + 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x68, 0x69, 0x6e, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x04, 0x68, 0x69, 0x6e, 0x74, 0x12, 0x4c, 0x0a, 0x04, 0x78, 0x35, 0x30, 0x39, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x38, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, + 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, + 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, + 0x74, 0x69, 0x74, 0x79, 0x53, 0x50, 0x49, 0x46, 0x46, 0x45, 0x58, 0x35, 0x30, 0x39, 0x52, 0x04, + 0x78, 0x35, 0x30, 0x39, 0x22, 0xaf, 0x01, 0x0a, 0x14, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, + 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x53, 0x70, 0x65, 
0x63, 0x12, 0x49, 0x0a, + 0x05, 0x72, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x33, 0x2e, 0x74, + 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, + 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, + 0x6c, 0x6f, 0x61, 0x64, 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x52, 0x75, 0x6c, 0x65, + 0x73, 0x52, 0x05, 0x72, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x4c, 0x0a, 0x06, 0x73, 0x70, 0x69, 0x66, + 0x66, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x34, 0x2e, 0x74, 0x65, 0x6c, 0x65, 0x70, + 0x6f, 0x72, 0x74, 0x2e, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, + 0x74, 0x69, 0x74, 0x79, 0x2e, 0x76, 0x31, 0x2e, 0x57, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, + 0x49, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x53, 0x50, 0x49, 0x46, 0x46, 0x45, 0x52, 0x06, + 0x73, 0x70, 0x69, 0x66, 0x66, 0x65, 0x42, 0x64, 0x5a, 0x62, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, + 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x72, 0x61, 0x76, 0x69, 0x74, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x61, 0x6c, 0x2f, 0x74, 0x65, 0x6c, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x61, 0x70, 0x69, 0x2f, + 0x67, 0x65, 0x6e, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x67, 0x6f, 0x2f, 0x74, 0x65, 0x6c, + 0x65, 0x70, 0x6f, 0x72, 0x74, 0x2f, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, 0x61, 0x64, 0x69, 0x64, + 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x2f, 0x76, 0x31, 0x3b, 0x77, 0x6f, 0x72, 0x6b, 0x6c, 0x6f, + 0x61, 0x64, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x74, 0x79, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -533,30 +825,38 @@ func file_teleport_workloadidentity_v1_resource_proto_rawDescGZIP() []byte { return file_teleport_workloadidentity_v1_resource_proto_rawDescData } -var file_teleport_workloadidentity_v1_resource_proto_msgTypes = make([]protoimpl.MessageInfo, 7) +var file_teleport_workloadidentity_v1_resource_proto_msgTypes = make([]protoimpl.MessageInfo, 11) var 
file_teleport_workloadidentity_v1_resource_proto_goTypes = []any{ - (*WorkloadIdentity)(nil), // 0: teleport.workloadidentity.v1.WorkloadIdentity - (*WorkloadIdentityCondition)(nil), // 1: teleport.workloadidentity.v1.WorkloadIdentityCondition - (*WorkloadIdentityRule)(nil), // 2: teleport.workloadidentity.v1.WorkloadIdentityRule - (*WorkloadIdentityRules)(nil), // 3: teleport.workloadidentity.v1.WorkloadIdentityRules - (*WorkloadIdentitySPIFFEX509)(nil), // 4: teleport.workloadidentity.v1.WorkloadIdentitySPIFFEX509 - (*WorkloadIdentitySPIFFE)(nil), // 5: teleport.workloadidentity.v1.WorkloadIdentitySPIFFE - (*WorkloadIdentitySpec)(nil), // 6: teleport.workloadidentity.v1.WorkloadIdentitySpec - (*v1.Metadata)(nil), // 7: teleport.header.v1.Metadata + (*WorkloadIdentity)(nil), // 0: teleport.workloadidentity.v1.WorkloadIdentity + (*WorkloadIdentityConditionEq)(nil), // 1: teleport.workloadidentity.v1.WorkloadIdentityConditionEq + (*WorkloadIdentityConditionNotEq)(nil), // 2: teleport.workloadidentity.v1.WorkloadIdentityConditionNotEq + (*WorkloadIdentityConditionIn)(nil), // 3: teleport.workloadidentity.v1.WorkloadIdentityConditionIn + (*WorkloadIdentityConditionNotIn)(nil), // 4: teleport.workloadidentity.v1.WorkloadIdentityConditionNotIn + (*WorkloadIdentityCondition)(nil), // 5: teleport.workloadidentity.v1.WorkloadIdentityCondition + (*WorkloadIdentityRule)(nil), // 6: teleport.workloadidentity.v1.WorkloadIdentityRule + (*WorkloadIdentityRules)(nil), // 7: teleport.workloadidentity.v1.WorkloadIdentityRules + (*WorkloadIdentitySPIFFEX509)(nil), // 8: teleport.workloadidentity.v1.WorkloadIdentitySPIFFEX509 + (*WorkloadIdentitySPIFFE)(nil), // 9: teleport.workloadidentity.v1.WorkloadIdentitySPIFFE + (*WorkloadIdentitySpec)(nil), // 10: teleport.workloadidentity.v1.WorkloadIdentitySpec + (*v1.Metadata)(nil), // 11: teleport.header.v1.Metadata } var file_teleport_workloadidentity_v1_resource_proto_depIdxs = []int32{ - 7, // 0: 
teleport.workloadidentity.v1.WorkloadIdentity.metadata:type_name -> teleport.header.v1.Metadata - 6, // 1: teleport.workloadidentity.v1.WorkloadIdentity.spec:type_name -> teleport.workloadidentity.v1.WorkloadIdentitySpec - 1, // 2: teleport.workloadidentity.v1.WorkloadIdentityRule.conditions:type_name -> teleport.workloadidentity.v1.WorkloadIdentityCondition - 2, // 3: teleport.workloadidentity.v1.WorkloadIdentityRules.allow:type_name -> teleport.workloadidentity.v1.WorkloadIdentityRule - 4, // 4: teleport.workloadidentity.v1.WorkloadIdentitySPIFFE.x509:type_name -> teleport.workloadidentity.v1.WorkloadIdentitySPIFFEX509 - 3, // 5: teleport.workloadidentity.v1.WorkloadIdentitySpec.rules:type_name -> teleport.workloadidentity.v1.WorkloadIdentityRules - 5, // 6: teleport.workloadidentity.v1.WorkloadIdentitySpec.spiffe:type_name -> teleport.workloadidentity.v1.WorkloadIdentitySPIFFE - 7, // [7:7] is the sub-list for method output_type - 7, // [7:7] is the sub-list for method input_type - 7, // [7:7] is the sub-list for extension type_name - 7, // [7:7] is the sub-list for extension extendee - 0, // [0:7] is the sub-list for field type_name + 11, // 0: teleport.workloadidentity.v1.WorkloadIdentity.metadata:type_name -> teleport.header.v1.Metadata + 10, // 1: teleport.workloadidentity.v1.WorkloadIdentity.spec:type_name -> teleport.workloadidentity.v1.WorkloadIdentitySpec + 1, // 2: teleport.workloadidentity.v1.WorkloadIdentityCondition.eq:type_name -> teleport.workloadidentity.v1.WorkloadIdentityConditionEq + 2, // 3: teleport.workloadidentity.v1.WorkloadIdentityCondition.not_eq:type_name -> teleport.workloadidentity.v1.WorkloadIdentityConditionNotEq + 3, // 4: teleport.workloadidentity.v1.WorkloadIdentityCondition.in:type_name -> teleport.workloadidentity.v1.WorkloadIdentityConditionIn + 4, // 5: teleport.workloadidentity.v1.WorkloadIdentityCondition.not_in:type_name -> teleport.workloadidentity.v1.WorkloadIdentityConditionNotIn + 5, // 6: 
teleport.workloadidentity.v1.WorkloadIdentityRule.conditions:type_name -> teleport.workloadidentity.v1.WorkloadIdentityCondition + 6, // 7: teleport.workloadidentity.v1.WorkloadIdentityRules.allow:type_name -> teleport.workloadidentity.v1.WorkloadIdentityRule + 8, // 8: teleport.workloadidentity.v1.WorkloadIdentitySPIFFE.x509:type_name -> teleport.workloadidentity.v1.WorkloadIdentitySPIFFEX509 + 7, // 9: teleport.workloadidentity.v1.WorkloadIdentitySpec.rules:type_name -> teleport.workloadidentity.v1.WorkloadIdentityRules + 9, // 10: teleport.workloadidentity.v1.WorkloadIdentitySpec.spiffe:type_name -> teleport.workloadidentity.v1.WorkloadIdentitySPIFFE + 11, // [11:11] is the sub-list for method output_type + 11, // [11:11] is the sub-list for method input_type + 11, // [11:11] is the sub-list for extension type_name + 11, // [11:11] is the sub-list for extension extendee + 0, // [0:11] is the sub-list for field type_name } func init() { file_teleport_workloadidentity_v1_resource_proto_init() } @@ -564,13 +864,19 @@ func file_teleport_workloadidentity_v1_resource_proto_init() { if File_teleport_workloadidentity_v1_resource_proto != nil { return } + file_teleport_workloadidentity_v1_resource_proto_msgTypes[5].OneofWrappers = []any{ + (*WorkloadIdentityCondition_Eq)(nil), + (*WorkloadIdentityCondition_NotEq)(nil), + (*WorkloadIdentityCondition_In)(nil), + (*WorkloadIdentityCondition_NotIn)(nil), + } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_teleport_workloadidentity_v1_resource_proto_rawDesc, NumEnums: 0, - NumMessages: 7, + NumMessages: 11, NumExtensions: 0, NumServices: 0, }, diff --git a/api/proto/teleport/workloadidentity/v1/resource.proto b/api/proto/teleport/workloadidentity/v1/resource.proto index b0faf7f94b99e..ad4cc03cf4c24 100644 --- a/api/proto/teleport/workloadidentity/v1/resource.proto +++ b/api/proto/teleport/workloadidentity/v1/resource.proto @@ -38,12 
+38,46 @@ message WorkloadIdentity { WorkloadIdentitySpec spec = 5; } +// The attribute casted to a string must be equal to the value. +message WorkloadIdentityConditionEq { + // The value to compare the attribute against. + string value = 1; +} + +// The attribute casted to a string must not be equal to the value. +message WorkloadIdentityConditionNotEq { + // The value to compare the attribute against. + string value = 1; +} + +// The attribute casted to a string must be in the list of values. +message WorkloadIdentityConditionIn { + // The list of values to compare the attribute against. + repeated string values = 1; +} + +// The attribute casted to a string must not be in the list of values. +message WorkloadIdentityConditionNotIn { + // The list of values to compare the attribute against. + repeated string values = 1; +} + // The individual conditions that make up a rule. message WorkloadIdentityCondition { + reserved 2; + reserved "equals"; // The name of the attribute to evaluate the condition against. string attribute = 1; - // An exact string that the attribute must match. - string equals = 2; + oneof operator { + // The attribute casted to a string must be equal to the value. + WorkloadIdentityConditionEq eq = 3; + // The attribute casted to a string must not be equal to the value. + WorkloadIdentityConditionNotEq not_eq = 4; + // The attribute casted to a string must be in the list of values. + WorkloadIdentityConditionIn in = 5; + // The attribute casted to a string must not be in the list of values. + WorkloadIdentityConditionNotIn not_in = 6; + } } // An individual rule that is evaluated during the issuance of a WorkloadIdentity. 
diff --git a/docs/pages/reference/terraform-provider/data-sources/workload_identity.mdx b/docs/pages/reference/terraform-provider/data-sources/workload_identity.mdx index 7c7e1a05a5af0..6a5f12830bf4f 100644 --- a/docs/pages/reference/terraform-provider/data-sources/workload_identity.mdx +++ b/docs/pages/reference/terraform-provider/data-sources/workload_identity.mdx @@ -55,7 +55,38 @@ Optional: Optional: - `attribute` (String) The name of the attribute to evaluate the condition against. -- `equals` (String) An exact string that the attribute must match. +- `eq` (Attributes) The attribute casted to a string must be equal to the value. (see [below for nested schema](#nested-schema-for-specrulesallowconditionseq)) +- `in` (Attributes) The attribute casted to a string must be in the list of values. (see [below for nested schema](#nested-schema-for-specrulesallowconditionsin)) +- `not_eq` (Attributes) The attribute casted to a string must not be equal to the value. (see [below for nested schema](#nested-schema-for-specrulesallowconditionsnot_eq)) +- `not_in` (Attributes) The attribute casted to a string must not be in the list of values. (see [below for nested schema](#nested-schema-for-specrulesallowconditionsnot_in)) + +### Nested Schema for `spec.rules.allow.conditions.eq` + +Optional: + +- `value` (String) The value to compare the attribute against. + + +### Nested Schema for `spec.rules.allow.conditions.in` + +Optional: + +- `values` (List of String) The list of values to compare the attribute against. + + +### Nested Schema for `spec.rules.allow.conditions.not_eq` + +Optional: + +- `value` (String) The value to compare the attribute against. + + +### Nested Schema for `spec.rules.allow.conditions.not_in` + +Optional: + +- `values` (List of String) The list of values to compare the attribute against. 
+ diff --git a/docs/pages/reference/terraform-provider/resources/workload_identity.mdx b/docs/pages/reference/terraform-provider/resources/workload_identity.mdx index fbbeb1306abd8..6238a0d535b03 100644 --- a/docs/pages/reference/terraform-provider/resources/workload_identity.mdx +++ b/docs/pages/reference/terraform-provider/resources/workload_identity.mdx @@ -23,7 +23,9 @@ resource "teleport_workload_identity" "example" { { conditions = [{ attribute = "user.name" - equals = "noah" + eq = { + value = "my-user" + } }] } ] @@ -80,7 +82,38 @@ Optional: Optional: - `attribute` (String) The name of the attribute to evaluate the condition against. -- `equals` (String) An exact string that the attribute must match. +- `eq` (Attributes) The attribute casted to a string must be equal to the value. (see [below for nested schema](#nested-schema-for-specrulesallowconditionseq)) +- `in` (Attributes) The attribute casted to a string must be in the list of values. (see [below for nested schema](#nested-schema-for-specrulesallowconditionsin)) +- `not_eq` (Attributes) The attribute casted to a string must not be equal to the value. (see [below for nested schema](#nested-schema-for-specrulesallowconditionsnot_eq)) +- `not_in` (Attributes) The attribute casted to a string must not be in the list of values. (see [below for nested schema](#nested-schema-for-specrulesallowconditionsnot_in)) + +### Nested Schema for `spec.rules.allow.conditions.eq` + +Optional: + +- `value` (String) The value to compare the attribute against. + + +### Nested Schema for `spec.rules.allow.conditions.in` + +Optional: + +- `values` (List of String) The list of values to compare the attribute against. + + +### Nested Schema for `spec.rules.allow.conditions.not_eq` + +Optional: + +- `value` (String) The value to compare the attribute against. + + +### Nested Schema for `spec.rules.allow.conditions.not_in` + +Optional: + +- `values` (List of String) The list of values to compare the attribute against. 
+ diff --git a/integrations/terraform/examples/resources/teleport_workload_identity/resource.tf b/integrations/terraform/examples/resources/teleport_workload_identity/resource.tf index e48ab1e5d0dd2..34dee932f430f 100644 --- a/integrations/terraform/examples/resources/teleport_workload_identity/resource.tf +++ b/integrations/terraform/examples/resources/teleport_workload_identity/resource.tf @@ -9,7 +9,9 @@ resource "teleport_workload_identity" "example" { { conditions = [{ attribute = "user.name" - equals = "noah" + eq = { + value = "my-user" + } }] } ] diff --git a/integrations/terraform/testlib/fixtures/workload_identity_0_create.tf b/integrations/terraform/testlib/fixtures/workload_identity_0_create.tf index b5d0ebe8aae08..a506ee5773d06 100644 --- a/integrations/terraform/testlib/fixtures/workload_identity_0_create.tf +++ b/integrations/terraform/testlib/fixtures/workload_identity_0_create.tf @@ -9,7 +9,9 @@ resource "teleport_workload_identity" "test" { { conditions = [{ attribute = "user.name" - equals = "foo" + eq = { + value = "foo" + } }] } ] diff --git a/integrations/terraform/testlib/fixtures/workload_identity_1_update.tf b/integrations/terraform/testlib/fixtures/workload_identity_1_update.tf index cced0a4f8ecdd..bb64491258471 100644 --- a/integrations/terraform/testlib/fixtures/workload_identity_1_update.tf +++ b/integrations/terraform/testlib/fixtures/workload_identity_1_update.tf @@ -9,7 +9,9 @@ resource "teleport_workload_identity" "test" { { conditions = [{ attribute = "user.name" - equals = "foo" + eq = { + value = "foo" + } }] } ] diff --git a/integrations/terraform/testlib/workload_identity_test.go b/integrations/terraform/testlib/workload_identity_test.go index 1e6d84cf6feb9..3e6d5a6ca4342 100644 --- a/integrations/terraform/testlib/workload_identity_test.go +++ b/integrations/terraform/testlib/workload_identity_test.go @@ -55,7 +55,7 @@ func (s *TerraformSuiteOSS) TestWorkloadIdentity() { resource.TestCheckResourceAttr(name, "kind", 
"workload_identity"), resource.TestCheckResourceAttr(name, "spec.spiffe.id", "/test"), resource.TestCheckResourceAttr(name, "spec.rules.allow.0.conditions.0.attribute", "user.name"), - resource.TestCheckResourceAttr(name, "spec.rules.allow.0.conditions.0.equals", "foo"), + resource.TestCheckResourceAttr(name, "spec.rules.allow.0.conditions.0.eq.value", "foo"), ), }, { @@ -68,7 +68,7 @@ func (s *TerraformSuiteOSS) TestWorkloadIdentity() { resource.TestCheckResourceAttr(name, "kind", "workload_identity"), resource.TestCheckResourceAttr(name, "spec.spiffe.id", "/test/updated"), resource.TestCheckResourceAttr(name, "spec.rules.allow.0.conditions.0.attribute", "user.name"), - resource.TestCheckResourceAttr(name, "spec.rules.allow.0.conditions.0.equals", "foo"), + resource.TestCheckResourceAttr(name, "spec.rules.allow.0.conditions.0.eq.value", "foo"), ), }, { @@ -101,7 +101,11 @@ func (s *TerraformSuiteOSS) TestImportWorkloadIdentity() { Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ { Attribute: "user.name", - Equals: "foo", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{ + Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{ + Value: "foo", + }, + }, }, }, }, @@ -133,7 +137,7 @@ func (s *TerraformSuiteOSS) TestImportWorkloadIdentity() { require.Equal(t, types.KindWorkloadIdentity, state[0].Attributes["kind"]) require.Equal(t, "/test", state[0].Attributes["spec.spiffe.id"]) require.Equal(t, "user.name", state[0].Attributes["spec.rules.allow.0.conditions.0.attribute"]) - require.Equal(t, "foo", state[0].Attributes["spec.rules.allow.0.conditions.0.equals"]) + require.Equal(t, "foo", state[0].Attributes["spec.rules.allow.0.conditions.0.eq.value"]) return nil }, diff --git a/integrations/terraform/tfschema/workloadidentity/v1/resource_terraform.go b/integrations/terraform/tfschema/workloadidentity/v1/resource_terraform.go index 5a76525cde345..27c4412b764de 100644 --- 
a/integrations/terraform/tfschema/workloadidentity/v1/resource_terraform.go +++ b/integrations/terraform/tfschema/workloadidentity/v1/resource_terraform.go @@ -107,10 +107,41 @@ func GenSchemaWorkloadIdentity(ctx context.Context) (github_com_hashicorp_terraf Optional: true, Type: github_com_hashicorp_terraform_plugin_framework_types.StringType, }, - "equals": { - Description: "An exact string that the attribute must match.", + "eq": { + Attributes: github_com_hashicorp_terraform_plugin_framework_tfsdk.SingleNestedAttributes(map[string]github_com_hashicorp_terraform_plugin_framework_tfsdk.Attribute{"value": { + Description: "The value to compare the attribute against.", + Optional: true, + Type: github_com_hashicorp_terraform_plugin_framework_types.StringType, + }}), + Description: "The attribute casted to a string must be equal to the value.", + Optional: true, + }, + "in": { + Attributes: github_com_hashicorp_terraform_plugin_framework_tfsdk.SingleNestedAttributes(map[string]github_com_hashicorp_terraform_plugin_framework_tfsdk.Attribute{"values": { + Description: "The list of values to compare the attribute against.", + Optional: true, + Type: github_com_hashicorp_terraform_plugin_framework_types.ListType{ElemType: github_com_hashicorp_terraform_plugin_framework_types.StringType}, + }}), + Description: "The attribute casted to a string must be in the list of values.", + Optional: true, + }, + "not_eq": { + Attributes: github_com_hashicorp_terraform_plugin_framework_tfsdk.SingleNestedAttributes(map[string]github_com_hashicorp_terraform_plugin_framework_tfsdk.Attribute{"value": { + Description: "The value to compare the attribute against.", + Optional: true, + Type: github_com_hashicorp_terraform_plugin_framework_types.StringType, + }}), + Description: "The attribute casted to a string must not be equal to the value.", + Optional: true, + }, + "not_in": { + Attributes: 
github_com_hashicorp_terraform_plugin_framework_tfsdk.SingleNestedAttributes(map[string]github_com_hashicorp_terraform_plugin_framework_tfsdk.Attribute{"values": { + Description: "The list of values to compare the attribute against.", + Optional: true, + Type: github_com_hashicorp_terraform_plugin_framework_types.ListType{ElemType: github_com_hashicorp_terraform_plugin_framework_types.StringType}, + }}), + Description: "The attribute casted to a string must not be in the list of values.", Optional: true, - Type: github_com_hashicorp_terraform_plugin_framework_types.StringType, }, }), Description: "The conditions that must be met for this rule to be considered passed.", @@ -408,6 +439,7 @@ func CopyWorkloadIdentityFromTerraform(_ context.Context, tf github_com_hashicor tf := v t = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition{} obj := t + obj.Operator = nil { a, ok := tf.Attrs["attribute"] if !ok { @@ -426,19 +458,162 @@ func CopyWorkloadIdentityFromTerraform(_ context.Context, tf github_com_hashicor } } { - a, ok := tf.Attrs["equals"] + a, ok := tf.Attrs["eq"] if !ok { - diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.equals"}) + diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq"}) } else { - v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.String) + v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.Object) if !ok { - diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.equals", "github.com/hashicorp/terraform-plugin-framework/types.String"}) + diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq", "github.com/hashicorp/terraform-plugin-framework/types.Object"}) } else { - var t string if !v.Null && !v.Unknown { - t = string(v.Value) + b := 
&github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityConditionEq{} + obj.Operator = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_Eq{Eq: b} + obj := b + tf := v + { + a, ok := tf.Attrs["value"] + if !ok { + diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq.value"}) + } else { + v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.String) + if !ok { + diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq.value", "github.com/hashicorp/terraform-plugin-framework/types.String"}) + } else { + var t string + if !v.Null && !v.Unknown { + t = string(v.Value) + } + obj.Value = t + } + } + } + } + } + } + } + { + a, ok := tf.Attrs["not_eq"] + if !ok { + diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq"}) + } else { + v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.Object) + if !ok { + diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq", "github.com/hashicorp/terraform-plugin-framework/types.Object"}) + } else { + if !v.Null && !v.Unknown { + b := &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityConditionNotEq{} + obj.Operator = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_NotEq{NotEq: b} + obj := b + tf := v + { + a, ok := tf.Attrs["value"] + if !ok { + diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq.value"}) + } else { + v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.String) + if !ok { + diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq.value", "github.com/hashicorp/terraform-plugin-framework/types.String"}) + } else { + var t string + if !v.Null && !v.Unknown { + t = 
string(v.Value) + } + obj.Value = t + } + } + } + } + } + } + } + { + a, ok := tf.Attrs["in"] + if !ok { + diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.in"}) + } else { + v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.Object) + if !ok { + diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.in", "github.com/hashicorp/terraform-plugin-framework/types.Object"}) + } else { + if !v.Null && !v.Unknown { + b := &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityConditionIn{} + obj.Operator = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_In{In: b} + obj := b + tf := v + { + a, ok := tf.Attrs["values"] + if !ok { + diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.in.values"}) + } else { + v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.List) + if !ok { + diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.in.values", "github.com/hashicorp/terraform-plugin-framework/types.List"}) + } else { + obj.Values = make([]string, len(v.Elems)) + if !v.Null && !v.Unknown { + for k, a := range v.Elems { + v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.String) + if !ok { + diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.in.values", "github_com_hashicorp_terraform_plugin_framework_types.String"}) + } else { + var t string + if !v.Null && !v.Unknown { + t = string(v.Value) + } + obj.Values[k] = t + } + } + } + } + } + } + } + } + } + } + { + a, ok := tf.Attrs["not_in"] + if !ok { + diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in"}) + } else { + v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.Object) + if !ok { + 
diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in", "github.com/hashicorp/terraform-plugin-framework/types.Object"}) + } else { + if !v.Null && !v.Unknown { + b := &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityConditionNotIn{} + obj.Operator = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_NotIn{NotIn: b} + obj := b + tf := v + { + a, ok := tf.Attrs["values"] + if !ok { + diags.Append(attrReadMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in.values"}) + } else { + v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.List) + if !ok { + diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in.values", "github.com/hashicorp/terraform-plugin-framework/types.List"}) + } else { + obj.Values = make([]string, len(v.Elems)) + if !v.Null && !v.Unknown { + for k, a := range v.Elems { + v, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.String) + if !ok { + diags.Append(attrReadConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in.values", "github_com_hashicorp_terraform_plugin_framework_types.String"}) + } else { + var t string + if !v.Null && !v.Unknown { + t = string(v.Value) + } + obj.Values[k] = t + } + } + } + } + } + } } - obj.Equals = t } } } @@ -984,25 +1159,297 @@ func CopyWorkloadIdentityToTerraform(ctx context.Context, obj *github_com_gravit } } { - t, ok := tf.AttrTypes["equals"] + a, ok := tf.AttrTypes["eq"] if !ok { - diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.equals"}) + diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq"}) } else { - v, ok := tf.Attrs["equals"].(github_com_hashicorp_terraform_plugin_framework_types.String) + obj, ok := 
obj.Operator.(*github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_Eq) if !ok { - i, err := t.ValueFromTerraform(ctx, github_com_hashicorp_terraform_plugin_go_tftypes.NewValue(t.TerraformType(ctx), nil)) - if err != nil { - diags.Append(attrWriteGeneralError{"WorkloadIdentity.spec.rules.allow.conditions.equals", err}) + obj = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_Eq{} + } + o, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.ObjectType) + if !ok { + diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq", "github.com/hashicorp/terraform-plugin-framework/types.ObjectType"}) + } else { + v, ok := tf.Attrs["eq"].(github_com_hashicorp_terraform_plugin_framework_types.Object) + if !ok { + v = github_com_hashicorp_terraform_plugin_framework_types.Object{ + + AttrTypes: o.AttrTypes, + Attrs: make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(o.AttrTypes)), + } + } else { + if v.Attrs == nil { + v.Attrs = make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(tf.AttrTypes)) + } } - v, ok = i.(github_com_hashicorp_terraform_plugin_framework_types.String) + if obj.Eq == nil { + v.Null = true + } else { + obj := obj.Eq + tf := &v + { + t, ok := tf.AttrTypes["value"] + if !ok { + diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq.value"}) + } else { + v, ok := tf.Attrs["value"].(github_com_hashicorp_terraform_plugin_framework_types.String) + if !ok { + i, err := t.ValueFromTerraform(ctx, github_com_hashicorp_terraform_plugin_go_tftypes.NewValue(t.TerraformType(ctx), nil)) + if err != nil { + diags.Append(attrWriteGeneralError{"WorkloadIdentity.spec.rules.allow.conditions.eq.value", err}) + } + v, ok = i.(github_com_hashicorp_terraform_plugin_framework_types.String) + if !ok { + 
diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.eq.value", "github.com/hashicorp/terraform-plugin-framework/types.String"}) + } + v.Null = string(obj.Value) == "" + } + v.Value = string(obj.Value) + v.Unknown = false + tf.Attrs["value"] = v + } + } + } + v.Unknown = false + tf.Attrs["eq"] = v + } + } + } + { + a, ok := tf.AttrTypes["not_eq"] + if !ok { + diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq"}) + } else { + obj, ok := obj.Operator.(*github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_NotEq) + if !ok { + obj = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_NotEq{} + } + o, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.ObjectType) + if !ok { + diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq", "github.com/hashicorp/terraform-plugin-framework/types.ObjectType"}) + } else { + v, ok := tf.Attrs["not_eq"].(github_com_hashicorp_terraform_plugin_framework_types.Object) if !ok { - diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.equals", "github.com/hashicorp/terraform-plugin-framework/types.String"}) + v = github_com_hashicorp_terraform_plugin_framework_types.Object{ + + AttrTypes: o.AttrTypes, + Attrs: make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(o.AttrTypes)), + } + } else { + if v.Attrs == nil { + v.Attrs = make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(tf.AttrTypes)) + } + } + if obj.NotEq == nil { + v.Null = true + } else { + obj := obj.NotEq + tf := &v + { + t, ok := tf.AttrTypes["value"] + if !ok { + diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq.value"}) + } else { + v, ok := 
tf.Attrs["value"].(github_com_hashicorp_terraform_plugin_framework_types.String) + if !ok { + i, err := t.ValueFromTerraform(ctx, github_com_hashicorp_terraform_plugin_go_tftypes.NewValue(t.TerraformType(ctx), nil)) + if err != nil { + diags.Append(attrWriteGeneralError{"WorkloadIdentity.spec.rules.allow.conditions.not_eq.value", err}) + } + v, ok = i.(github_com_hashicorp_terraform_plugin_framework_types.String) + if !ok { + diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_eq.value", "github.com/hashicorp/terraform-plugin-framework/types.String"}) + } + v.Null = string(obj.Value) == "" + } + v.Value = string(obj.Value) + v.Unknown = false + tf.Attrs["value"] = v + } + } } - v.Null = string(obj.Equals) == "" + v.Unknown = false + tf.Attrs["not_eq"] = v + } + } + } + { + a, ok := tf.AttrTypes["in"] + if !ok { + diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.in"}) + } else { + obj, ok := obj.Operator.(*github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_In) + if !ok { + obj = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_In{} + } + o, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.ObjectType) + if !ok { + diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.in", "github.com/hashicorp/terraform-plugin-framework/types.ObjectType"}) + } else { + v, ok := tf.Attrs["in"].(github_com_hashicorp_terraform_plugin_framework_types.Object) + if !ok { + v = github_com_hashicorp_terraform_plugin_framework_types.Object{ + + AttrTypes: o.AttrTypes, + Attrs: make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(o.AttrTypes)), + } + } else { + if v.Attrs == nil { + v.Attrs = make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(tf.AttrTypes)) + } + } + if obj.In == nil { + v.Null = true + } 
else { + obj := obj.In + tf := &v + { + a, ok := tf.AttrTypes["values"] + if !ok { + diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.in.values"}) + } else { + o, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.ListType) + if !ok { + diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.in.values", "github.com/hashicorp/terraform-plugin-framework/types.ListType"}) + } else { + c, ok := tf.Attrs["values"].(github_com_hashicorp_terraform_plugin_framework_types.List) + if !ok { + c = github_com_hashicorp_terraform_plugin_framework_types.List{ + + ElemType: o.ElemType, + Elems: make([]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(obj.Values)), + Null: true, + } + } else { + if c.Elems == nil { + c.Elems = make([]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(obj.Values)) + } + } + if obj.Values != nil { + t := o.ElemType + if len(obj.Values) != len(c.Elems) { + c.Elems = make([]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(obj.Values)) + } + for k, a := range obj.Values { + v, ok := tf.Attrs["values"].(github_com_hashicorp_terraform_plugin_framework_types.String) + if !ok { + i, err := t.ValueFromTerraform(ctx, github_com_hashicorp_terraform_plugin_go_tftypes.NewValue(t.TerraformType(ctx), nil)) + if err != nil { + diags.Append(attrWriteGeneralError{"WorkloadIdentity.spec.rules.allow.conditions.in.values", err}) + } + v, ok = i.(github_com_hashicorp_terraform_plugin_framework_types.String) + if !ok { + diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.in.values", "github.com/hashicorp/terraform-plugin-framework/types.String"}) + } + v.Null = string(a) == "" + } + v.Value = string(a) + v.Unknown = false + c.Elems[k] = v + } + if len(obj.Values) > 0 { + c.Null = false + } + } + c.Unknown = false + tf.Attrs["values"] = c + } + } + } + } + v.Unknown = false + tf.Attrs["in"] = v + } + } + } + { + a, ok 
:= tf.AttrTypes["not_in"] + if !ok { + diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in"}) + } else { + obj, ok := obj.Operator.(*github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_NotIn) + if !ok { + obj = &github_com_gravitational_teleport_api_gen_proto_go_teleport_workloadidentity_v1.WorkloadIdentityCondition_NotIn{} + } + o, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.ObjectType) + if !ok { + diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in", "github.com/hashicorp/terraform-plugin-framework/types.ObjectType"}) + } else { + v, ok := tf.Attrs["not_in"].(github_com_hashicorp_terraform_plugin_framework_types.Object) + if !ok { + v = github_com_hashicorp_terraform_plugin_framework_types.Object{ + + AttrTypes: o.AttrTypes, + Attrs: make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(o.AttrTypes)), + } + } else { + if v.Attrs == nil { + v.Attrs = make(map[string]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(tf.AttrTypes)) + } + } + if obj.NotIn == nil { + v.Null = true + } else { + obj := obj.NotIn + tf := &v + { + a, ok := tf.AttrTypes["values"] + if !ok { + diags.Append(attrWriteMissingDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in.values"}) + } else { + o, ok := a.(github_com_hashicorp_terraform_plugin_framework_types.ListType) + if !ok { + diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in.values", "github.com/hashicorp/terraform-plugin-framework/types.ListType"}) + } else { + c, ok := tf.Attrs["values"].(github_com_hashicorp_terraform_plugin_framework_types.List) + if !ok { + c = github_com_hashicorp_terraform_plugin_framework_types.List{ + + ElemType: o.ElemType, + Elems: make([]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(obj.Values)), + Null: true, + } + } else { + if c.Elems == nil 
{ + c.Elems = make([]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(obj.Values)) + } + } + if obj.Values != nil { + t := o.ElemType + if len(obj.Values) != len(c.Elems) { + c.Elems = make([]github_com_hashicorp_terraform_plugin_framework_attr.Value, len(obj.Values)) + } + for k, a := range obj.Values { + v, ok := tf.Attrs["values"].(github_com_hashicorp_terraform_plugin_framework_types.String) + if !ok { + i, err := t.ValueFromTerraform(ctx, github_com_hashicorp_terraform_plugin_go_tftypes.NewValue(t.TerraformType(ctx), nil)) + if err != nil { + diags.Append(attrWriteGeneralError{"WorkloadIdentity.spec.rules.allow.conditions.not_in.values", err}) + } + v, ok = i.(github_com_hashicorp_terraform_plugin_framework_types.String) + if !ok { + diags.Append(attrWriteConversionFailureDiag{"WorkloadIdentity.spec.rules.allow.conditions.not_in.values", "github.com/hashicorp/terraform-plugin-framework/types.String"}) + } + v.Null = string(a) == "" + } + v.Value = string(a) + v.Unknown = false + c.Elems[k] = v + } + if len(obj.Values) > 0 { + c.Null = false + } + } + c.Unknown = false + tf.Attrs["values"] = c + } + } + } + } + v.Unknown = false + tf.Attrs["not_in"] = v } - v.Value = string(obj.Equals) - v.Unknown = false - tf.Attrs["equals"] = v } } } diff --git a/lib/auth/machineid/workloadidentityv1/decision.go b/lib/auth/machineid/workloadidentityv1/decision.go index 4e959efe6ee12..ccbbde2967c90 100644 --- a/lib/auth/machineid/workloadidentityv1/decision.go +++ b/lib/auth/machineid/workloadidentityv1/decision.go @@ -176,8 +176,25 @@ ruleLoop: if err != nil { return trace.Wrap(err) } - if val != condition.Equals { - continue ruleLoop + switch c := condition.Operator.(type) { + case *workloadidentityv1pb.WorkloadIdentityCondition_Eq: + if val != c.Eq.Value { + continue ruleLoop + } + case *workloadidentityv1pb.WorkloadIdentityCondition_NotEq: + if val == c.NotEq.Value { + continue ruleLoop + } + case *workloadidentityv1pb.WorkloadIdentityCondition_In: + if 
!slices.Contains(c.In.Values, val) { + continue ruleLoop + } + case *workloadidentityv1pb.WorkloadIdentityCondition_NotIn: + if slices.Contains(c.NotIn.Values, val) { + continue ruleLoop + } + default: + return trace.BadParameter("unsupported operator %T", c) } } return nil diff --git a/lib/auth/machineid/workloadidentityv1/decision_test.go b/lib/auth/machineid/workloadidentityv1/decision_test.go index 5d00bf7595669..3d2b9ed4cff95 100644 --- a/lib/auth/machineid/workloadidentityv1/decision_test.go +++ b/lib/auth/machineid/workloadidentityv1/decision_test.go @@ -263,28 +263,285 @@ func Test_evaluateRules(t *testing.T) { User: &workloadidentityv1pb.UserAttrs{ Name: "foo", }, + Workload: &workloadidentityv1pb.WorkloadAttrs{ + Kubernetes: &workloadidentityv1pb.WorkloadAttrsKubernetes{ + PodName: "pod1", + Namespace: "default", + }, + }, + } + + var noMatchRule require.ErrorAssertionFunc = func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.Contains(t, err.Error(), "no matching rule found") } - wi := &workloadidentityv1pb.WorkloadIdentity{ - Kind: types.KindWorkloadIdentity, - Version: types.V1, - Metadata: &headerv1.Metadata{ - Name: "test", + + tests := []struct { + name string + wid *workloadidentityv1pb.WorkloadIdentity + attrs *workloadidentityv1pb.Attrs + requireErr require.ErrorAssertionFunc + }{ + { + name: "no rules: pass", + wid: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{}, + }, + }, + attrs: attrs, + requireErr: require.NoError, + }, + { + name: "eq: pass", + wid: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: 
&workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "user.name", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{ + Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{ + Value: "foo", + }, + }, + }, + }, + }, + }, + }, + }, + }, + attrs: attrs, + requireErr: require.NoError, + }, + { + name: "eq: fail", + wid: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "user.name", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{ + Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{ + Value: "not-foo", + }, + }, + }, + }, + }, + }, + }, + }, + }, + attrs: attrs, + requireErr: noMatchRule, + }, + { + name: "not_eq: pass", + wid: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "user.name", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_NotEq{ + NotEq: &workloadidentityv1pb.WorkloadIdentityConditionNotEq{ + Value: "bar", + }, + }, + }, + }, + }, + }, + }, + }, + }, + attrs: attrs, + requireErr: require.NoError, + }, + { + name: "not_eq: fail", + wid: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: 
&workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "user.name", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_NotEq{ + NotEq: &workloadidentityv1pb.WorkloadIdentityConditionNotEq{ + Value: "foo", + }, + }, + }, + }, + }, + }, + }, + }, + }, + attrs: attrs, + requireErr: noMatchRule, + }, + { + name: "in: pass", + wid: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "user.name", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_In{ + In: &workloadidentityv1pb.WorkloadIdentityConditionIn{ + Values: []string{"bar", "foo"}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + attrs: attrs, + requireErr: require.NoError, }, - Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ - Rules: &workloadidentityv1pb.WorkloadIdentityRules{ - Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ - { - Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + name: "in: fail", + wid: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ { - Attribute: "user.name", - Equals: "foo", + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "user.name", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_In{ + In: 
&workloadidentityv1pb.WorkloadIdentityConditionIn{ + Values: []string{"bar", "fizz"}, + }, + }, + }, + }, }, }, }, }, }, + attrs: attrs, + requireErr: noMatchRule, }, + { + name: "not_in: pass", + wid: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "user.name", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_NotIn{ + NotIn: &workloadidentityv1pb.WorkloadIdentityConditionNotIn{ + Values: []string{"bar", "fizz"}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + attrs: attrs, + requireErr: require.NoError, + }, + { + name: "in: fail", + wid: &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "test", + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Rules: &workloadidentityv1pb.WorkloadIdentityRules{ + Allow: []*workloadidentityv1pb.WorkloadIdentityRule{ + { + Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ + { + Attribute: "user.name", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_NotIn{ + NotIn: &workloadidentityv1pb.WorkloadIdentityConditionNotIn{ + Values: []string{"bar", "foo"}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + attrs: attrs, + requireErr: noMatchRule, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := evaluateRules(tt.wid, tt.attrs) + tt.requireErr(t, err) + }) } - err := evaluateRules(wi, attrs) - require.NoError(t, err) } diff --git a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go index e5f23dc96216c..1ddf63bcf28d1 100644 --- 
a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go +++ b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go @@ -187,7 +187,11 @@ func TestIssueWorkloadIdentityE2E(t *testing.T) { Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ { Attribute: "join.kubernetes.service_account.namespace", - Equals: "my-namespace", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{ + Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{ + Value: "my-namespace", + }, + }, }, }, }, @@ -402,11 +406,19 @@ func TestIssueWorkloadIdentity(t *testing.T) { Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ { Attribute: "user.name", - Equals: "dog", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{ + Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{ + Value: "dog", + }, + }, }, { Attribute: "workload.kubernetes.namespace", - Equals: "default", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{ + Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{ + Value: "default", + }, + }, }, }, }, @@ -768,7 +780,11 @@ func TestIssueWorkloadIdentities(t *testing.T) { Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ { Attribute: "workload.kubernetes.namespace", - Equals: "default", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{ + Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{ + Value: "default", + }, + }, }, }, }, @@ -798,7 +814,11 @@ func TestIssueWorkloadIdentities(t *testing.T) { Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ { Attribute: "workload.kubernetes.namespace", - Equals: "default", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{ + Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{ + Value: "default", + }, + }, }, }, }, diff --git a/lib/services/workload_identity.go b/lib/services/workload_identity.go index 89b87ba0d2473..826ab6540b0e6 100644 --- a/lib/services/workload_identity.go +++ b/lib/services/workload_identity.go 
@@ -104,16 +104,8 @@ func ValidateWorkloadIdentity(s *workloadidentityv1pb.WorkloadIdentity) error { if condition.Attribute == "" { return trace.BadParameter("spec.rules.allow[%d].conditions[%d].attribute: must be non-empty", i, j) } - // Ensure exactly one operator is set. - operatorsSet := 0 - if condition.Equals != "" { - operatorsSet++ - } - if operatorsSet == 0 || operatorsSet > 1 { - return trace.BadParameter( - "spec.rules.allow[%d].conditions[%d]: exactly one operator must be specified, found %d", - i, j, operatorsSet, - ) + if condition.Operator == nil { + return trace.BadParameter("spec.rules.allow[%d].conditions[%d]: operator must be specified", i, j) } } } diff --git a/lib/services/workload_identity_test.go b/lib/services/workload_identity_test.go index 429612ed48555..27d0e1ec0261b 100644 --- a/lib/services/workload_identity_test.go +++ b/lib/services/workload_identity_test.go @@ -92,7 +92,11 @@ func TestValidateWorkloadIdentity(t *testing.T) { Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ { Attribute: "example", - Equals: "foo", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{ + Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{ + Value: "foo", + }, + }, }, }, }, @@ -180,7 +184,11 @@ func TestValidateWorkloadIdentity(t *testing.T) { Conditions: []*workloadidentityv1pb.WorkloadIdentityCondition{ { Attribute: "", - Equals: "foo", + Operator: &workloadidentityv1pb.WorkloadIdentityCondition_Eq{ + Eq: &workloadidentityv1pb.WorkloadIdentityConditionEq{ + Value: "foo", + }, + }, }, }, }, @@ -218,7 +226,7 @@ func TestValidateWorkloadIdentity(t *testing.T) { }, }, }, - requireErr: errContains("spec.rules.allow[0].conditions[0]: exactly one operator must be specified, found 0"), + requireErr: errContains("spec.rules.allow[0].conditions[0]: operator must be specified"), }, } diff --git a/tool/tctl/common/collection.go b/tool/tctl/common/collection.go index c1ea21addc2b6..4bf5d1629d0c9 100644 --- a/tool/tctl/common/collection.go 
+++ b/tool/tctl/common/collection.go @@ -1792,7 +1792,7 @@ type workloadIdentityCollection struct { func (c *workloadIdentityCollection) resources() []types.Resource { r := make([]types.Resource, 0, len(c.items)) for _, resource := range c.items { - r = append(r, types.Resource153ToLegacy(resource)) + r = append(r, types.ProtoResource153ToLegacy(resource)) } return r } From f6544ed060adf57d9441269d40560b21b1daeadc Mon Sep 17 00:00:00 2001 From: Russell Jones Date: Fri, 10 Jan 2025 14:02:20 -0800 Subject: [PATCH 02/15] Added RFD 0144 - Client Tools Updates (#39805) * Added RFD 0144 - Client Tools Updates * Fix. * Fix. * Fix. * Fix. * Fix. * Fix. * Fix. * Fix. * Fix. * Fix. * Fix. * Fix. * Fix. * Fix. * Fix. * Client tools RFD update (#45515) * Add changes proposed for client autoupdate * Add proxy flag and CDN info * Naming adjustments * Naming adjustments * Update client tools autoupdate RFD (#47498) * Drop watch command for autoupdate * Add modified tctl commands and `find` endpoint response * Add words to spell checker --------- Co-authored-by: Vadym Popov --- rfd/0144-client-tools-updates.md | 309 +++++++++++++++++++++++++++++++ rfd/cspell.json | 5 + 2 files changed, 314 insertions(+) create mode 100644 rfd/0144-client-tools-updates.md diff --git a/rfd/0144-client-tools-updates.md b/rfd/0144-client-tools-updates.md new file mode 100644 index 0000000000000..34ba8062971c8 --- /dev/null +++ b/rfd/0144-client-tools-updates.md @@ -0,0 +1,309 @@ +--- +authors: Russell Jones (rjones@goteleport.com) and Bernard Kim (bernard@goteleport.com) +state: draft +--- + +# RFD 0144 - Client Tools Updates + +## Required Approvers + +* Engineering: @sclevine && @bernardjkim && @r0mant +* Product: @klizhentas || @xinding33 +* Security: @reedloden + +## What/Why + +This RFD describes how client tools like `tsh` and `tctl` can be kept up to +date, either using automatic updates or self-managed updates. 
+
+Keeping client tools updated helps with security (fixes for known security
+vulnerabilities are pushed to endpoints), bugs (fixes for resolved issues are
+pushed to endpoints), and compatibility (users no longer have to learn and
+understand [Teleport component
+compatibility](https://goteleport.com/docs/upgrading/overview/#component-compatibility)
+rules).
+
+## Details
+
+### Summary
+
+Client tools like `tsh` and `tctl` will automatically download and install the
+required version for the Teleport cluster.
+
+Enrollment in automatic updates for client tools will be controlled at the
+cluster level. By default all Cloud clusters will be opted into automatic
+updates for client tools. Cluster administrators using MDM software like Jamf
+will be able to opt out and manually manage updates.
+
+Self-hosted clusters will be opted out, but have the option to use the same
+automatic update mechanism.
+
+Inspiration drawn from https://go.dev/doc/toolchain.
+
+### Implementation
+
+#### Client tools
+
+##### Automatic updates
+
+When `tsh login` is executed, client tools will check `/v1/webapi/find` to
+determine if automatic updates are enabled. If the cluster's required version
+differs from the current binary, client tools will download and re-execute
+using the version required by the cluster. All other `tsh` subcommands (like
+`tsh ssh ...`) will always use the downloaded version.
+
+The original client tools binaries won't be overwritten. Instead, an additional
+binary will be downloaded and stored in `~/.tsh/bin` with `0755` permissions.
+
+To validate the binaries have not been corrupted during download, a hash of the
+archive will be checked against the expected value. The expected hash value
+comes from the archive download path with `.sha256` appended.
+
+To enable concurrent operation of client tools, a locking mechanism utilizing
+[syscall.Flock](https://pkg.go.dev/syscall#Flock) (for Linux and macOS) and
+[LockFileEx](https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-lockfileex)
+(for Windows) will be used.
+
+```
+$ tree ~/.tsh
+~/.tsh
+├── bin
+│ ├── tctl
+│ └── tsh
+├── current-profile
+├── keys
+│ └── proxy.example.com
+│ ├── cas
+│ │ └── example.com.pem
+│ ├── certs.pem
+│ ├── foo
+│ ├── foo-ssh
+│ │ └── example.com-cert.pub
+│ ├── foo-x509.pem
+│ └── foo.pub
+├── known_hosts
+└── proxy.example.com.yaml
+```
+
+Users can cancel client tools updates using `Ctrl-C`. This may be needed if the
+user is on a low bandwidth connection (LTE or public Wi-Fi), if the Teleport
+download server is inaccessible, or the user urgently needs to access the
+cluster and cannot wait for the update to occur.
+
+```
+$ tsh login --proxy=proxy.example.com
+Client tools are out of date, updating to vX.Y.Z.
+Update progress: [▒▒▒▒▒▒ ] (Ctrl-C to cancel update)
+
+[...]
+```
+
+All archive downloads are targeted to the `cdn.teleport.dev` endpoint and depend
+on the operating system, platform, and edition. Where edition must be identified
+by the original client tools binary, URL pattern:
+`https://cdn.teleport.dev/teleport-{, ent-}v15.3.0-{linux, darwin, windows}-{amd64,arm64,arm,386}-{fips-}bin.tar.gz`
+
+An environment variable `TELEPORT_TOOLS_VERSION` will be introduced that can be
+`X.Y.Z` (use specific semver version) or `off` (do not update). This
+environment variable can be used as an emergency workaround for a known issue,
+pinning to a specific version in CI/CD, or for debugging.
+
+During re-execution, the child process will inherit all environment variables
+and flags. `TELEPORT_TOOLS_VERSION=off` will be added during re-execution to
+prevent infinite loops.
+
+When `tctl` is used to connect to Auth Service running on the same host over
+`localhost`, `tctl` assumes a special administrator role that can perform all
+operations on a cluster. In this situation the expectation is for the version
+of `tctl` and `teleport` to match so automatic updates will not be used.
+
+> [!NOTE]
+> If a user connects to multiple root clusters, each running a different
+> version of Teleport, client tools will attempt to download the differing
+> version of Teleport each time the user performs a `tsh login`.
+>
+> In practice, the number of users impacted by this would be small. Customer
+> Cloud tenants would be on the same version and this feature is turned off by
+> default for self-hosted clusters.
+>
+> However, for those people in this situation, the recommendation would be to
+> use self-managed updates.
+
+##### Errors and warnings
+
+If a cluster administrator has chosen not to enroll client tools in automatic
+updates and does not self-manage client tools updates as outlined in
+[Self-managed client tools updates](#self-managed-client-tools-updates), a
+series of warnings and errors with increasing urgency will be shown to the
+user.
+
+If the version of client tools is within the same major version as advertised
+by the cluster, a warning will be shown to urge the user to enroll in automatic
+updates. Warnings will not prevent the user from using client tools that are
+slightly out of date.
+
+```
+$ tsh login --proxy=proxy.example.com
+Warning: Client tools are out of date, update to vX.Y.Z.
+
+Update Teleport to vX.Y.Z from https://goteleport.com/download or your system
+package manager.
+
+Enroll in automatic updates to keep client tools like tsh and tctl
+automatically updated. https://goteleport.com/docs/upgrading/automatic-updates
+
+[...]
+```
+
+If the version of client tools is 1 major version below the version advertised
+by the cluster, a warning will be shown that indicates some functionality may
+not work.
+ +``` +$ tsh login --proxy=proxy.example.com +WARNING: Client tools are 1 major version out of date, update to vX.Y.Z. + +Some functionality may not work. Update Teleport to vX.Y.Z from +https://goteleport.com/download or your system package manager. + +Enroll in automatic updates to keep client tools like tsh and tctl +automatically updated. https://goteleport.com/docs/upgrading/automatic-updates +``` + +If the version of client tools is 2 (or more) versions lower than the version +advertised by the cluster or 1 (or more) version greater than the version +advertised by the cluster, an error will be shown and will require the user to +use the `--skip-version-check` flag. + +``` +$ tsh login --proxy=proxy.example.com +ERROR: Client tools are N major versions out of date, update to vX.Y.Z. + +Your cluster requires {tsh,tctl} vX.Y.Z. Update Teleport from +https://goteleport.com/download or your system package manager. + +Enroll in automatic updates to keep client tools like tsh and tctl +automatically updated. https://goteleport.com/docs/upgrading/automatic-updates + +Use the "--skip-version-check" flag to bypass this check and attempt to connect +to this cluster. +``` + +#### Self-managed client tools updates + +Cluster administrators that want to self-manage client tools updates will be +able to get changes to client tools versions which can then be +used to trigger other integrations (using MDM software like Jamf) to update the +installed version of client tools on endpoints. + +By defining the `proxy` flag, we can use the get command without logging in. + +``` +$ tctl autoupdate client-tools status --proxy proxy.example.com --format json +{ + "mode": "enabled", + "target_version": "X.Y.Z" +} +``` + +##### Cluster configuration + +Enrollment of clients in automatic updates will be enforced at the cluster +level. + +The `autoupdate_config` resource will be updated to allow cluster +administrators to turn client tools automatic updates `on` or `off`. 
+A `autoupdate_version` resource will be added to allow cluster administrators +to manage the version of tools pushed to clients. + +> [!NOTE] +> Client tools configuration is broken into two resources to [prevent +> updates](https://github.com/gravitational/teleport/blob/master/lib/modules/modules.go#L332-L355) +> to `autoupdate_version` on Cloud. +> +> While Cloud customers will be able to use `autoupdate_config` to +> turn client tools automatic updates `off` and self-manage updates, they will +> not be able to control the version of client tools in `autoupdate_version`. +> That will continue to be managed by the Teleport Cloud team. + +Both resources can either be updated directly or by using `tctl` helper +functions. + +```yaml +kind: autoupdate_config +spec: + tools: + # tools mode allows to enable client tools updates or disable at the + # cluster level. Disable client tools automatic updates only if self-managed + # updates are in place. + mode: enabled|disabled +``` +``` +$ tctl autoupdate client-tools enable +client tools auto update mode has been changed + +$ tctl autoupdate client-tools disable +client tools auto update mode has been changed +``` + +By default, all Cloud clusters will be opted into `tools.mode: enabled`. All +self-hosted clusters will be opted into `tools.mode: disabled`. + +```yaml +kind: autoupdate_version +spec: + tools: + # target_version is the semver version of client tools the cluster will + # advertise. + target_version: X.Y.Z +``` +``` +$ tctl autoupdate client-tools target X.Y.Z +client tools auto update target version has been set + +$ tctl autoupdate client-tools target --clear +client tools auto update target version has been cleared +``` + +For Cloud clusters, `target_version` will always be `X.Y.Z`, with the version +controlled by the Cloud team. + +The above configuration will then be available from the unauthenticated +proxy discovery endpoint `/v1/webapi/find` which clients will consult. 
+
+Resources that store information about autoupdate and tools version are cached on
+the proxy side to minimize requests to the auth service. In case of an unhealthy
+cache state, the last known version of the resources should be used for the response.
+
+```
+$ curl https://proxy.example.com/v1/webapi/find | jq .auto_update
+{
+  "tools_auto_update": true,
+  "tools_version": "X.Y.Z"
+}
+```
+
+### Costs
+
+Some additional costs will be incurred as Teleport downloads will increase in
+frequency.
+
+### Out of scope
+
+How Cloud will push changes to `autoupdate_version` is out of scope for this
+RFD and will be handled by a separate Cloud-specific RFD.
+
+Automatic updates for Teleport Connect are out of scope for this RFD as it uses
+a different install/update mechanism. For now it will call `tsh` with
+`TELEPORT_TOOLS_VERSION=off` until automatic updates support can be added to
+Connect.
+
+### Security
+
+The initial version of automatic updates will rely on TLS to establish
+connection authenticity to the Teleport download server. The authenticity of
+assets served from the download server is out of scope for this RFD. Cluster
+administrators concerned with the authenticity of assets served from the
+download server can use self-managed updates with system package managers which
+are signed.
+
+Phase 2 will use The Upgrade Framework (TUF) to implement secure updates.
diff --git a/rfd/cspell.json b/rfd/cspell.json index 9982219bada5e..ee16b81f55872 100644 --- a/rfd/cspell.json +++ b/rfd/cspell.json @@ -201,6 +201,7 @@ "Statfs", "Subconditions", "Submatch", + "Sudia", "Sycqsbqf", "TBLPROPERTIES", "TCSD", @@ -211,6 +212,7 @@ "TPMs", "Tablename", "Teleconsole", + "Teleporter", "Teleporting", "Tiago", "Tkachenko", @@ -297,6 +299,7 @@ "behaviour", "behaviours", "benchtime", + "bernardjkim", "bizz", "bjoerger", "blocklists", @@ -667,6 +670,7 @@ "runtimes", "russjones", "ryanclark", + "sclevine", "secretless", "selfsubjectaccessreviews", "selfsubjectrulesreviews", @@ -731,6 +735,7 @@ "sudoersfile", "supercede", "syft", + "syscall", "tablewriter", "tailscale", "targetting", From aa34c34dad985a33d6607d602969cbf471246fa8 Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Fri, 10 Jan 2025 14:23:38 -0800 Subject: [PATCH 03/15] migrate AWS RDS services to AWS SDK v2 (#50848) --- go.mod | 1 + go.sum | 2 + integrations/terraform/go.sum | 2 + lib/cloud/aws/aws.go | 51 ++- lib/cloud/aws/tags_helpers.go | 29 +- lib/cloud/aws/tags_helpers_test.go | 25 +- lib/cloud/awstesthelpers/tags.go | 20 ++ lib/cloud/clients.go | 27 -- lib/cloud/mocks/aws_config.go | 5 + lib/cloud/mocks/aws_rds.go | 336 ++++++------------ lib/srv/db/access_test.go | 3 +- lib/srv/db/cloud/aws.go | 51 ++- lib/srv/db/cloud/iam.go | 21 +- lib/srv/db/cloud/iam_test.go | 42 ++- lib/srv/db/cloud/meta.go | 128 ++++--- lib/srv/db/cloud/meta_test.go | 43 ++- lib/srv/db/cloud/resource_checker_url.go | 16 +- lib/srv/db/cloud/resource_checker_url_aws.go | 54 +-- .../db/cloud/resource_checker_url_aws_test.go | 32 +- lib/srv/db/common/auth.go | 42 ++- lib/srv/db/common/auth_test.go | 21 +- lib/srv/db/server.go | 7 +- lib/srv/db/watcher_test.go | 3 +- lib/srv/discovery/access_graph.go | 1 + lib/srv/discovery/common/database.go | 81 ++--- lib/srv/discovery/common/database_test.go | 109 +++--- lib/srv/discovery/common/kubernetes_test.go | 5 +- lib/srv/discovery/common/renaming_test.go | 25 +- 
lib/srv/discovery/discovery_test.go | 88 +++-- .../discovery/fetchers/aws-sync/aws-sync.go | 39 +- lib/srv/discovery/fetchers/aws-sync/rds.go | 154 +++++--- .../discovery/fetchers/aws-sync/rds_test.go | 197 ++++++---- lib/srv/discovery/fetchers/db/aws.go | 12 +- lib/srv/discovery/fetchers/db/aws_docdb.go | 58 +-- .../discovery/fetchers/db/aws_docdb_test.go | 53 +-- lib/srv/discovery/fetchers/db/aws_rds.go | 178 ++++++---- .../discovery/fetchers/db/aws_rds_proxy.go | 111 +++--- .../fetchers/db/aws_rds_proxy_test.go | 27 +- lib/srv/discovery/fetchers/db/aws_rds_test.go | 179 ++++++---- lib/srv/discovery/fetchers/db/aws_redshift.go | 5 +- .../fetchers/db/aws_redshift_test.go | 30 +- lib/srv/discovery/fetchers/db/db.go | 47 ++- 42 files changed, 1281 insertions(+), 1079 deletions(-) diff --git a/go.mod b/go.mod index 78f04732806b6..625a780eb3ff6 100644 --- a/go.mod +++ b/go.mod @@ -48,6 +48,7 @@ require ( github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue v1.15.23 github.com/aws/aws-sdk-go-v2/feature/dynamodbstreams/attributevalue v1.14.58 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22 + github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.45 github.com/aws/aws-sdk-go-v2/service/applicationautoscaling v1.34.3 github.com/aws/aws-sdk-go-v2/service/athena v1.49.2 diff --git a/go.sum b/go.sum index 5665c4f7280c7..5bf38ba7fc0c4 100644 --- a/go.sum +++ b/go.sum @@ -866,6 +866,8 @@ github.com/aws/aws-sdk-go-v2/feature/dynamodbstreams/attributevalue v1.14.58/go. 
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.3/go.mod h1:4Q0UFP0YJf0NrsEuEYHpM9fTSEVnD16Z3uyEF7J9JGM= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22 h1:kqOrpojG71DxJm/KDPO+Z/y1phm1JlC8/iT+5XRmAn8= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22/go.mod h1:NtSFajXVVL8TA2QNngagVZmUtXciyrHOt7xgz4faS/M= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2 h1:fo+GuZNME9oGDc7VY+EBT+oCrco6RjRgUp1bKTcaHrU= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2/go.mod h1:fnqb94UO6YCjBIic4WaqDYkNVAEFWOWiReVHitBBWW0= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.45 h1:ZxB8WFVYwolhDZxuZXoesHkl+L9cXLWy0K/G0QkNATc= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.45/go.mod h1:1krrbyoFFDqaNldmltPTP+mK3sAXLHPoaFtISOw2Hkk= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33/go.mod h1:7i0PF1ME/2eUPFcjkVIwq+DOygHEoK92t5cDqNgYbIw= diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum index 106e4e41c759b..da4bca430e263 100644 --- a/integrations/terraform/go.sum +++ b/integrations/terraform/go.sum @@ -798,6 +798,8 @@ github.com/aws/aws-sdk-go-v2/feature/dynamodbstreams/attributevalue v1.14.58 h1: github.com/aws/aws-sdk-go-v2/feature/dynamodbstreams/attributevalue v1.14.58/go.mod h1:1FDesv+tfF2w5mRnLQbB8P33BPfxrngXtfNcdnrtmjw= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22 h1:kqOrpojG71DxJm/KDPO+Z/y1phm1JlC8/iT+5XRmAn8= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.22/go.mod h1:NtSFajXVVL8TA2QNngagVZmUtXciyrHOt7xgz4faS/M= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2 h1:fo+GuZNME9oGDc7VY+EBT+oCrco6RjRgUp1bKTcaHrU= +github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.5.2/go.mod h1:fnqb94UO6YCjBIic4WaqDYkNVAEFWOWiReVHitBBWW0= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.45 h1:ZxB8WFVYwolhDZxuZXoesHkl+L9cXLWy0K/G0QkNATc= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.45/go.mod h1:1krrbyoFFDqaNldmltPTP+mK3sAXLHPoaFtISOw2Hkk= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.26 
h1:I/5wmGMffY4happ8NOCuIUEWGUvvFp5NSeQcXl9RHcI= diff --git a/lib/cloud/aws/aws.go b/lib/cloud/aws/aws.go index 27ea56321b7df..7361ff75f219c 100644 --- a/lib/cloud/aws/aws.go +++ b/lib/cloud/aws/aws.go @@ -22,12 +22,12 @@ import ( "slices" "strings" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/opensearchservice" - "github.com/aws/aws-sdk-go/service/rds" "github.com/coreos/go-semver/semver" "github.com/gravitational/teleport/lib/services" @@ -74,18 +74,51 @@ func IsOpenSearchDomainAvailable(domain *opensearchservice.DomainStatus) bool { } // IsRDSProxyAvailable checks if the RDS Proxy is available. -func IsRDSProxyAvailable(dbProxy *rds.DBProxy) bool { - return IsResourceAvailable(dbProxy, dbProxy.Status) +func IsRDSProxyAvailable(dbProxy *rdstypes.DBProxy) bool { + switch dbProxy.Status { + case + rdstypes.DBProxyStatusAvailable, + rdstypes.DBProxyStatusModifying, + rdstypes.DBProxyStatusReactivating: + return true + case + rdstypes.DBProxyStatusCreating, + rdstypes.DBProxyStatusDeleting, + rdstypes.DBProxyStatusIncompatibleNetwork, + rdstypes.DBProxyStatusInsufficientResourceLimits, + rdstypes.DBProxyStatusSuspended, + rdstypes.DBProxyStatusSuspending: + return false + } + slog.WarnContext(context.Background(), "Assuming RDS Proxy with unknown status is available", + "status", dbProxy.Status, + ) + return true } // IsRDSProxyCustomEndpointAvailable checks if the RDS Proxy custom endpoint is available. 
-func IsRDSProxyCustomEndpointAvailable(customEndpoint *rds.DBProxyEndpoint) bool { - return IsResourceAvailable(customEndpoint, customEndpoint.Status) +func IsRDSProxyCustomEndpointAvailable(customEndpoint *rdstypes.DBProxyEndpoint) bool { + switch customEndpoint.Status { + case + rdstypes.DBProxyEndpointStatusAvailable, + rdstypes.DBProxyEndpointStatusModifying: + return true + case + rdstypes.DBProxyEndpointStatusCreating, + rdstypes.DBProxyEndpointStatusDeleting, + rdstypes.DBProxyEndpointStatusIncompatibleNetwork, + rdstypes.DBProxyEndpointStatusInsufficientResourceLimits: + return false + } + slog.WarnContext(context.Background(), "Assuming RDS Proxy custom endpoint with unknown status is available", + "status", customEndpoint.Status, + ) + return true } // IsRDSInstanceSupported returns true if database supports IAM authentication. // Currently, only MariaDB is being checked. -func IsRDSInstanceSupported(instance *rds.DBInstance) bool { +func IsRDSInstanceSupported(instance *rdstypes.DBInstance) bool { // TODO(jakule): Check other engines. if aws.StringValue(instance.Engine) != services.RDSEngineMariaDB { return true @@ -105,7 +138,7 @@ func IsRDSInstanceSupported(instance *rds.DBInstance) bool { } // IsRDSClusterSupported checks whether the Aurora cluster is supported. -func IsRDSClusterSupported(cluster *rds.DBCluster) bool { +func IsRDSClusterSupported(cluster *rdstypes.DBCluster) bool { switch aws.StringValue(cluster.EngineMode) { // Aurora Serverless v1 does NOT support IAM authentication. 
// https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/aurora-serverless.html#aurora-serverless.limitations @@ -129,7 +162,7 @@ func IsRDSClusterSupported(cluster *rds.DBCluster) bool { } // AuroraMySQLVersion extracts aurora mysql version from engine version -func AuroraMySQLVersion(cluster *rds.DBCluster) string { +func AuroraMySQLVersion(cluster *rdstypes.DBCluster) string { // version guide: https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/AuroraMySQL.Updates.Versions.html // a list of all the available versions: https://docs.aws.amazon.com/cli/latest/reference/rds/describe-db-engine-versions.html // @@ -154,7 +187,7 @@ func AuroraMySQLVersion(cluster *rds.DBCluster) string { // for this DocumentDB cluster. // // https://docs.aws.amazon.com/documentdb/latest/developerguide/iam-identity-auth.html -func IsDocumentDBClusterSupported(cluster *rds.DBCluster) bool { +func IsDocumentDBClusterSupported(cluster *rdstypes.DBCluster) bool { ver, err := semver.NewVersion(aws.StringValue(cluster.EngineVersion)) if err != nil { slog.ErrorContext(context.Background(), "Failed to parse DocumentDB engine version", "version", aws.StringValue(cluster.EngineVersion)) diff --git a/lib/cloud/aws/tags_helpers.go b/lib/cloud/aws/tags_helpers.go index 3e61bd6fc1a42..43f6ba48f61ca 100644 --- a/lib/cloud/aws/tags_helpers.go +++ b/lib/cloud/aws/tags_helpers.go @@ -24,14 +24,13 @@ import ( "slices" ec2TypesV2 "github.com/aws/aws-sdk-go-v2/service/ec2/types" - rdsTypesV2 "github.com/aws/aws-sdk-go-v2/service/rds/types" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/opensearchservice" - "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/redshiftserverless" 
"github.com/aws/aws-sdk-go/service/secretsmanager" "golang.org/x/exp/maps" @@ -43,11 +42,10 @@ import ( type ResourceTag interface { // TODO Go generic does not allow access common fields yet. List all types // here and use a type switch for now. - rdsTypesV2.Tag | + rdstypes.Tag | ec2TypesV2.Tag | redshifttypes.Tag | *ec2.Tag | - *rds.Tag | *elasticache.Tag | *memorydb.Tag | *redshiftserverless.Tag | @@ -76,8 +74,6 @@ func TagsToLabels[Tag ResourceTag](tags []Tag) map[string]string { func resourceTagToKeyValue[Tag ResourceTag](tag Tag) (string, string) { switch v := any(tag).(type) { - case *rds.Tag: - return aws.StringValue(v.Key), aws.StringValue(v.Value) case *ec2.Tag: return aws.StringValue(v.Key), aws.StringValue(v.Value) case *elasticache.Tag: @@ -86,7 +82,7 @@ func resourceTagToKeyValue[Tag ResourceTag](tag Tag) (string, string) { return aws.StringValue(v.Key), aws.StringValue(v.Value) case *redshiftserverless.Tag: return aws.StringValue(v.Key), aws.StringValue(v.Value) - case rdsTypesV2.Tag: + case rdstypes.Tag: return aws.StringValue(v.Key), aws.StringValue(v.Value) case ec2TypesV2.Tag: return aws.StringValue(v.Key), aws.StringValue(v.Value) @@ -123,22 +119,3 @@ func LabelsToTags[T any, PT SettableTag[T]](labels map[string]string) (tags []*T } return } - -// LabelsToRDSV2Tags converts labels into [rdsTypesV2.Tag] list. 
-func LabelsToRDSV2Tags(labels map[string]string) []rdsTypesV2.Tag { - keys := maps.Keys(labels) - slices.Sort(keys) - - ret := make([]rdsTypesV2.Tag, 0, len(keys)) - for _, key := range keys { - key := key - value := labels[key] - - ret = append(ret, rdsTypesV2.Tag{ - Key: &key, - Value: &value, - }) - } - - return ret -} diff --git a/lib/cloud/aws/tags_helpers_test.go b/lib/cloud/aws/tags_helpers_test.go index 228c477a316cb..0bc75677fefbd 100644 --- a/lib/cloud/aws/tags_helpers_test.go +++ b/lib/cloud/aws/tags_helpers_test.go @@ -22,10 +22,10 @@ import ( "testing" rdsTypesV2 "github.com/aws/aws-sdk-go-v2/service/rds/types" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" "github.com/aws/aws-sdk-go/service/elasticache" - "github.com/aws/aws-sdk-go/service/rds" "github.com/stretchr/testify/require" ) @@ -33,7 +33,7 @@ func TestTagsToLabels(t *testing.T) { t.Parallel() t.Run("rds", func(t *testing.T) { - inputTags := []*rds.Tag{ + inputTags := []rdstypes.Tag{ { Key: aws.String("Env"), Value: aws.String("dev"), @@ -135,25 +135,4 @@ func TestLabelsToTags(t *testing.T) { actualTags := LabelsToTags[elasticache.Tag](inputLabels) require.Equal(t, expectTags, actualTags) }) - - t.Run("rdsV2", func(t *testing.T) { - inputLabels := map[string]string{ - "labelB": "valueB", - "labelA": "valueA", - } - - expectTags := []rdsTypesV2.Tag{ - { - Key: aws.String("labelA"), - Value: aws.String("valueA"), - }, - { - Key: aws.String("labelB"), - Value: aws.String("valueB"), - }, - } - - actualTags := LabelsToRDSV2Tags(inputLabels) - require.EqualValues(t, expectTags, actualTags) - }) } diff --git a/lib/cloud/awstesthelpers/tags.go b/lib/cloud/awstesthelpers/tags.go index 5e1f4aa0e0738..28bed6b973f0b 100644 --- a/lib/cloud/awstesthelpers/tags.go +++ b/lib/cloud/awstesthelpers/tags.go @@ -22,6 +22,7 @@ import ( "maps" "slices" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" redshifttypes 
"github.com/aws/aws-sdk-go-v2/service/redshift/types" ) @@ -43,3 +44,22 @@ func LabelsToRedshiftTags(labels map[string]string) []redshifttypes.Tag { return ret } + +// LabelsToRDSTags converts labels into a [rdstypes.Tag] list. +func LabelsToRDSTags(labels map[string]string) []rdstypes.Tag { + keys := slices.Collect(maps.Keys(labels)) + slices.Sort(keys) + + ret := make([]rdstypes.Tag, 0, len(keys)) + for _, key := range keys { + key := key + value := labels[key] + + ret = append(ret, rdstypes.Tag{ + Key: &key, + Value: &value, + }) + } + + return ret +} diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go index 28e8ebabac598..cc50c98c1ba4f 100644 --- a/lib/cloud/clients.go +++ b/lib/cloud/clients.go @@ -49,8 +49,6 @@ import ( "github.com/aws/aws-sdk-go/service/memorydb/memorydbiface" "github.com/aws/aws-sdk-go/service/opensearchservice" "github.com/aws/aws-sdk-go/service/opensearchservice/opensearchserviceiface" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface" "github.com/aws/aws-sdk-go/service/s3" @@ -109,8 +107,6 @@ type GCPClients interface { type AWSClients interface { // GetAWSSession returns AWS session for the specified region and any role(s). GetAWSSession(ctx context.Context, region string, opts ...AWSOptionsFn) (*awssession.Session, error) - // GetAWSRDSClient returns AWS RDS client for the specified region. - GetAWSRDSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (rdsiface.RDSAPI, error) // GetAWSRedshiftServerlessClient returns AWS Redshift Serverless client for the specified region. GetAWSRedshiftServerlessClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftserverlessiface.RedshiftServerlessAPI, error) // GetAWSElastiCacheClient returns AWS ElastiCache client for the specified region. 
@@ -500,15 +496,6 @@ func (c *cloudClients) GetAWSSession(ctx context.Context, region string, opts .. return c.getAWSSessionForRole(ctx, region, options) } -// GetAWSRDSClient returns AWS RDS client for the specified region. -func (c *cloudClients) GetAWSRDSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (rdsiface.RDSAPI, error) { - session, err := c.GetAWSSession(ctx, region, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - return rds.New(session), nil -} - // GetAWSRedshiftServerlessClient returns AWS Redshift Serverless client for the specified region. func (c *cloudClients) GetAWSRedshiftServerlessClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftserverlessiface.RedshiftServerlessAPI, error) { session, err := c.GetAWSSession(ctx, region, opts...) @@ -1005,8 +992,6 @@ var _ Clients = (*TestCloudClients)(nil) // TestCloudClients are used in tests. type TestCloudClients struct { - RDS rdsiface.RDSAPI - RDSPerRegion map[string]rdsiface.RDSAPI RedshiftServerless redshiftserverlessiface.RedshiftServerlessAPI ElastiCache elasticacheiface.ElastiCacheAPI OpenSearch opensearchserviceiface.OpenSearchServiceAPI @@ -1075,18 +1060,6 @@ func (c *TestCloudClients) getAWSSessionForRegion(region string) (*awssession.Se }) } -// GetAWSRDSClient returns AWS RDS client for the specified region. -func (c *TestCloudClients) GetAWSRDSClient(ctx context.Context, region string, opts ...AWSOptionsFn) (rdsiface.RDSAPI, error) { - _, err := c.GetAWSSession(ctx, region, opts...) - if err != nil { - return nil, trace.Wrap(err) - } - if len(c.RDSPerRegion) != 0 { - return c.RDSPerRegion[region], nil - } - return c.RDS, nil -} - // GetAWSRedshiftServerlessClient returns AWS Redshift Serverless client for the specified region. 
func (c *TestCloudClients) GetAWSRedshiftServerlessClient(ctx context.Context, region string, opts ...AWSOptionsFn) (redshiftserverlessiface.RedshiftServerlessAPI, error) { _, err := c.GetAWSSession(ctx, region, opts...) diff --git a/lib/cloud/mocks/aws_config.go b/lib/cloud/mocks/aws_config.go index 819d6ca8f535e..d148e9512c8d4 100644 --- a/lib/cloud/mocks/aws_config.go +++ b/lib/cloud/mocks/aws_config.go @@ -29,11 +29,16 @@ import ( ) type AWSConfigProvider struct { + Err error STSClient *STSClient OIDCIntegrationClient awsconfig.OIDCIntegrationClient } func (f *AWSConfigProvider) GetConfig(ctx context.Context, region string, optFns ...awsconfig.OptionsFn) (aws.Config, error) { + if f.Err != nil { + return aws.Config{}, f.Err + } + stsClt := f.STSClient if stsClt == nil { stsClt = &STSClient{} diff --git a/lib/cloud/mocks/aws_rds.go b/lib/cloud/mocks/aws_rds.go index 50130d668f5c0..9338b8330dc5f 100644 --- a/lib/cloud/mocks/aws_rds.go +++ b/lib/cloud/mocks/aws_rds.go @@ -19,159 +19,156 @@ package mocks import ( + "context" "fmt" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdsv2 "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/arn" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" "github.com/google/uuid" "github.com/gravitational/trace" - libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awstesthelpers" ) -// RDSMock mocks AWS RDS API. 
-type RDSMock struct { - rdsiface.RDSAPI - DBInstances []*rds.DBInstance - DBClusters []*rds.DBCluster - DBProxies []*rds.DBProxy - DBProxyEndpoints []*rds.DBProxyEndpoint - DBEngineVersions []*rds.DBEngineVersion +type RDSClient struct { + Unauth bool + + DBInstances []rdstypes.DBInstance + DBClusters []rdstypes.DBCluster + DBProxies []rdstypes.DBProxy + DBProxyEndpoints []rdstypes.DBProxyEndpoint + DBEngineVersions []rdstypes.DBEngineVersion } -func (m *RDSMock) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) { - if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { +func (c *RDSClient) DescribeDBInstances(_ context.Context, input *rdsv2.DescribeDBInstancesInput, _ ...func(*rdsv2.Options)) (*rdsv2.DescribeDBInstancesOutput, error) { + if c.Unauth { + return nil, trace.AccessDenied("unauthorized") + } + + if err := checkEngineFilters(input.Filters, c.DBEngineVersions); err != nil { return nil, trace.Wrap(err) } - instances, err := applyInstanceFilters(m.DBInstances, input.Filters) + instances, err := applyInstanceFilters(c.DBInstances, input.Filters) if err != nil { return nil, trace.Wrap(err) } if aws.StringValue(input.DBInstanceIdentifier) == "" { - return &rds.DescribeDBInstancesOutput{ + return &rdsv2.DescribeDBInstancesOutput{ DBInstances: instances, }, nil } for _, instance := range instances { if aws.StringValue(instance.DBInstanceIdentifier) == aws.StringValue(input.DBInstanceIdentifier) { - return &rds.DescribeDBInstancesOutput{ - DBInstances: []*rds.DBInstance{instance}, + return &rdsv2.DescribeDBInstancesOutput{ + DBInstances: []rdstypes.DBInstance{instance}, }, nil } } return nil, trace.NotFound("instance %v not found", aws.StringValue(input.DBInstanceIdentifier)) } -func (m *RDSMock) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, 
options ...request.Option) error { - if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { - return trace.Wrap(err) - } - instances, err := applyInstanceFilters(m.DBInstances, input.Filters) - if err != nil { - return trace.Wrap(err) +func (c *RDSClient) DescribeDBClusters(_ context.Context, input *rdsv2.DescribeDBClustersInput, _ ...func(*rdsv2.Options)) (*rdsv2.DescribeDBClustersOutput, error) { + if c.Unauth { + return nil, trace.AccessDenied("unauthorized") } - fn(&rds.DescribeDBInstancesOutput{ - DBInstances: instances, - }, true) - return nil -} -func (m *RDSMock) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) { - if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { + if err := checkEngineFilters(input.Filters, c.DBEngineVersions); err != nil { return nil, trace.Wrap(err) } - clusters, err := applyClusterFilters(m.DBClusters, input.Filters) + clusters, err := applyClusterFilters(c.DBClusters, input.Filters) if err != nil { return nil, trace.Wrap(err) } if aws.StringValue(input.DBClusterIdentifier) == "" { - return &rds.DescribeDBClustersOutput{ + return &rdsv2.DescribeDBClustersOutput{ DBClusters: clusters, }, nil } for _, cluster := range clusters { if aws.StringValue(cluster.DBClusterIdentifier) == aws.StringValue(input.DBClusterIdentifier) { - return &rds.DescribeDBClustersOutput{ - DBClusters: []*rds.DBCluster{cluster}, + return &rdsv2.DescribeDBClustersOutput{ + DBClusters: []rdstypes.DBCluster{cluster}, }, nil } } return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.DBClusterIdentifier)) } -func (m *RDSMock) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error { - if err := checkEngineFilters(input.Filters, m.DBEngineVersions); err != nil { - return trace.Wrap(err) - } 
- clusters, err := applyClusterFilters(m.DBClusters, input.Filters) - if err != nil { - return trace.Wrap(err) +func (c *RDSClient) ModifyDBInstance(ctx context.Context, input *rdsv2.ModifyDBInstanceInput, optFns ...func(*rdsv2.Options)) (*rdsv2.ModifyDBInstanceOutput, error) { + if c.Unauth { + return nil, trace.AccessDenied("unauthorized") } - fn(&rds.DescribeDBClustersOutput{ - DBClusters: clusters, - }, true) - return nil -} -func (m *RDSMock) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) { - for i, instance := range m.DBInstances { + for i, instance := range c.DBInstances { if aws.StringValue(instance.DBInstanceIdentifier) == aws.StringValue(input.DBInstanceIdentifier) { if aws.BoolValue(input.EnableIAMDatabaseAuthentication) { - m.DBInstances[i].IAMDatabaseAuthenticationEnabled = aws.Bool(true) + c.DBInstances[i].IAMDatabaseAuthenticationEnabled = aws.Bool(true) } - return &rds.ModifyDBInstanceOutput{ - DBInstance: m.DBInstances[i], + return &rdsv2.ModifyDBInstanceOutput{ + DBInstance: &c.DBInstances[i], }, nil } } return nil, trace.NotFound("instance %v not found", aws.StringValue(input.DBInstanceIdentifier)) } -func (m *RDSMock) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) { - for i, cluster := range m.DBClusters { +func (c *RDSClient) ModifyDBCluster(ctx context.Context, input *rdsv2.ModifyDBClusterInput, optFns ...func(*rdsv2.Options)) (*rdsv2.ModifyDBClusterOutput, error) { + if c.Unauth { + return nil, trace.AccessDenied("unauthorized") + } + + for i, cluster := range c.DBClusters { if aws.StringValue(cluster.DBClusterIdentifier) == aws.StringValue(input.DBClusterIdentifier) { if aws.BoolValue(input.EnableIAMDatabaseAuthentication) { - m.DBClusters[i].IAMDatabaseAuthenticationEnabled = aws.Bool(true) + c.DBClusters[i].IAMDatabaseAuthenticationEnabled = 
aws.Bool(true) } - return &rds.ModifyDBClusterOutput{ - DBCluster: m.DBClusters[i], + return &rdsv2.ModifyDBClusterOutput{ + DBCluster: &c.DBClusters[i], }, nil } } return nil, trace.NotFound("cluster %v not found", aws.StringValue(input.DBClusterIdentifier)) } -func (m *RDSMock) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) { +func (c *RDSClient) DescribeDBProxies(_ context.Context, input *rdsv2.DescribeDBProxiesInput, _ ...func(*rdsv2.Options)) (*rdsv2.DescribeDBProxiesOutput, error) { + if c.Unauth { + return nil, trace.AccessDenied("unauthorized") + } + if aws.StringValue(input.DBProxyName) == "" { - return &rds.DescribeDBProxiesOutput{ - DBProxies: m.DBProxies, + return &rdsv2.DescribeDBProxiesOutput{ + DBProxies: c.DBProxies, }, nil } - for _, dbProxy := range m.DBProxies { + for _, dbProxy := range c.DBProxies { if aws.StringValue(dbProxy.DBProxyName) == aws.StringValue(input.DBProxyName) { - return &rds.DescribeDBProxiesOutput{ - DBProxies: []*rds.DBProxy{dbProxy}, + return &rdsv2.DescribeDBProxiesOutput{ + DBProxies: []rdstypes.DBProxy{dbProxy}, }, nil } } return nil, trace.NotFound("proxy %v not found", aws.StringValue(input.DBProxyName)) } -func (m *RDSMock) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) { +func (c *RDSClient) DescribeDBProxyEndpoints(_ context.Context, input *rdsv2.DescribeDBProxyEndpointsInput, _ ...func(*rdsv2.Options)) (*rdsv2.DescribeDBProxyEndpointsOutput, error) { + if c.Unauth { + return nil, trace.AccessDenied("unauthorized") + } + inputProxyName := aws.StringValue(input.DBProxyName) inputProxyEndpointName := aws.StringValue(input.DBProxyEndpointName) if inputProxyName == "" && inputProxyEndpointName == "" { - return &rds.DescribeDBProxyEndpointsOutput{ - DBProxyEndpoints: m.DBProxyEndpoints, + return 
&rdsv2.DescribeDBProxyEndpointsOutput{ + DBProxyEndpoints: c.DBProxyEndpoints, }, nil } - var endpoints []*rds.DBProxyEndpoint - for _, dbProxyEndpoiont := range m.DBProxyEndpoints { + var endpoints []rdstypes.DBProxyEndpoint + for _, dbProxyEndpoiont := range c.DBProxyEndpoints { if inputProxyEndpointName != "" && inputProxyEndpointName != aws.StringValue(dbProxyEndpoiont.DBProxyEndpointName) { continue @@ -187,114 +184,15 @@ func (m *RDSMock) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rd if len(endpoints) == 0 { return nil, trace.NotFound("proxy endpoint %v not found", aws.StringValue(input.DBProxyEndpointName)) } - return &rds.DescribeDBProxyEndpointsOutput{DBProxyEndpoints: endpoints}, nil -} - -func (m *RDSMock) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error { - fn(&rds.DescribeDBProxiesOutput{ - DBProxies: m.DBProxies, - }, true) - return nil + return &rdsv2.DescribeDBProxyEndpointsOutput{DBProxyEndpoints: endpoints}, nil } -func (m *RDSMock) DescribeDBProxyEndpointsPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, fn func(*rds.DescribeDBProxyEndpointsOutput, bool) bool, options ...request.Option) error { - fn(&rds.DescribeDBProxyEndpointsOutput{ - DBProxyEndpoints: m.DBProxyEndpoints, - }, true) - return nil -} - -func (m *RDSMock) ListTagsForResourceWithContext(ctx aws.Context, input *rds.ListTagsForResourceInput, options ...request.Option) (*rds.ListTagsForResourceOutput, error) { +func (c *RDSClient) ListTagsForResource(context.Context, *rds.ListTagsForResourceInput, ...func(*rds.Options)) (*rds.ListTagsForResourceOutput, error) { return &rds.ListTagsForResourceOutput{}, nil } -// RDSMockUnauth is a mock RDS client that returns access denied to each call. 
-type RDSMockUnauth struct { - rdsiface.RDSAPI -} - -func (m *RDSMockUnauth) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error { - return trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error { - return trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) { - return nil, trace.AccessDenied("unauthorized") -} - -func (m *RDSMockUnauth) DescribeDBProxiesPagesWithContext(ctx 
aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error { - return trace.AccessDenied("unauthorized") -} - -// RDSMockByDBType is a mock RDS client that mocks API calls by DB type -type RDSMockByDBType struct { - rdsiface.RDSAPI - DBInstances rdsiface.RDSAPI - DBClusters rdsiface.RDSAPI - DBProxies rdsiface.RDSAPI -} - -func (m *RDSMockByDBType) DescribeDBInstancesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, options ...request.Option) (*rds.DescribeDBInstancesOutput, error) { - return m.DBInstances.DescribeDBInstancesWithContext(ctx, input, options...) -} - -func (m *RDSMockByDBType) ModifyDBInstanceWithContext(ctx aws.Context, input *rds.ModifyDBInstanceInput, options ...request.Option) (*rds.ModifyDBInstanceOutput, error) { - return m.DBInstances.ModifyDBInstanceWithContext(ctx, input, options...) -} - -func (m *RDSMockByDBType) DescribeDBInstancesPagesWithContext(ctx aws.Context, input *rds.DescribeDBInstancesInput, fn func(*rds.DescribeDBInstancesOutput, bool) bool, options ...request.Option) error { - return m.DBInstances.DescribeDBInstancesPagesWithContext(ctx, input, fn, options...) -} - -func (m *RDSMockByDBType) DescribeDBClustersWithContext(ctx aws.Context, input *rds.DescribeDBClustersInput, options ...request.Option) (*rds.DescribeDBClustersOutput, error) { - return m.DBClusters.DescribeDBClustersWithContext(ctx, input, options...) -} - -func (m *RDSMockByDBType) ModifyDBClusterWithContext(ctx aws.Context, input *rds.ModifyDBClusterInput, options ...request.Option) (*rds.ModifyDBClusterOutput, error) { - return m.DBClusters.ModifyDBClusterWithContext(ctx, input, options...) -} - -func (m *RDSMockByDBType) DescribeDBClustersPagesWithContext(aws aws.Context, input *rds.DescribeDBClustersInput, fn func(*rds.DescribeDBClustersOutput, bool) bool, options ...request.Option) error { - return m.DBClusters.DescribeDBClustersPagesWithContext(aws, input, fn, options...) 
-} - -func (m *RDSMockByDBType) DescribeDBProxiesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, options ...request.Option) (*rds.DescribeDBProxiesOutput, error) { - return m.DBProxies.DescribeDBProxiesWithContext(ctx, input, options...) -} - -func (m *RDSMockByDBType) DescribeDBProxyEndpointsWithContext(ctx aws.Context, input *rds.DescribeDBProxyEndpointsInput, options ...request.Option) (*rds.DescribeDBProxyEndpointsOutput, error) { - return m.DBProxies.DescribeDBProxyEndpointsWithContext(ctx, input, options...) -} - -func (m *RDSMockByDBType) DescribeDBProxiesPagesWithContext(ctx aws.Context, input *rds.DescribeDBProxiesInput, fn func(*rds.DescribeDBProxiesOutput, bool) bool, options ...request.Option) error { - return m.DBProxies.DescribeDBProxiesPagesWithContext(ctx, input, fn, options...) -} - // checkEngineFilters checks RDS filters to detect unrecognized engine filters. -func checkEngineFilters(filters []*rds.Filter, engineVersions []*rds.DBEngineVersion) error { +func checkEngineFilters(filters []rdstypes.Filter, engineVersions []rdstypes.DBEngineVersion) error { if len(filters) == 0 { return nil } @@ -307,8 +205,8 @@ func checkEngineFilters(filters []*rds.Filter, engineVersions []*rds.DBEngineVer continue } for _, v := range f.Values { - if _, ok := recognizedEngines[aws.StringValue(v)]; !ok { - return trace.Errorf("unrecognized engine name %q", aws.StringValue(v)) + if _, ok := recognizedEngines[v]; !ok { + return trace.Errorf("unrecognized engine name %q", v) } } } @@ -316,11 +214,11 @@ func checkEngineFilters(filters []*rds.Filter, engineVersions []*rds.DBEngineVer } // applyInstanceFilters filters RDS DBInstances using the provided RDS filters. 
-func applyInstanceFilters(in []*rds.DBInstance, filters []*rds.Filter) ([]*rds.DBInstance, error) { +func applyInstanceFilters(in []rdstypes.DBInstance, filters []rdstypes.Filter) ([]rdstypes.DBInstance, error) { if len(filters) == 0 { return in, nil } - var out []*rds.DBInstance + var out []rdstypes.DBInstance efs := engineFilterSet(filters) clusterIDs := clusterIdentifierFilterSet(filters) for _, instance := range in { @@ -336,11 +234,11 @@ func applyInstanceFilters(in []*rds.DBInstance, filters []*rds.Filter) ([]*rds.D } // applyClusterFilters filters RDS DBClusters using the provided RDS filters. -func applyClusterFilters(in []*rds.DBCluster, filters []*rds.Filter) ([]*rds.DBCluster, error) { +func applyClusterFilters(in []rdstypes.DBCluster, filters []rdstypes.Filter) ([]rdstypes.DBCluster, error) { if len(filters) == 0 { return in, nil } - var out []*rds.DBCluster + var out []rdstypes.DBCluster efs := engineFilterSet(filters) for _, cluster := range in { if clusterEngineMatches(cluster, efs) { @@ -351,59 +249,59 @@ func applyClusterFilters(in []*rds.DBCluster, filters []*rds.Filter) ([]*rds.DBC } // engineFilterSet builds a string set of engine names from a list of RDS filters. -func engineFilterSet(filters []*rds.Filter) map[string]struct{} { +func engineFilterSet(filters []rdstypes.Filter) map[string]struct{} { return filterValues(filters, "engine") } // clusterIdentifierFilterSet builds a string set of ClusterIDs from a list of RDS filters. 
-func clusterIdentifierFilterSet(filters []*rds.Filter) map[string]struct{} { +func clusterIdentifierFilterSet(filters []rdstypes.Filter) map[string]struct{} { return filterValues(filters, "db-cluster-id") } -func filterValues(filters []*rds.Filter, filterKey string) map[string]struct{} { +func filterValues(filters []rdstypes.Filter, filterKey string) map[string]struct{} { out := make(map[string]struct{}) for _, f := range filters { if aws.StringValue(f.Name) != filterKey { continue } for _, v := range f.Values { - out[aws.StringValue(v)] = struct{}{} + out[v] = struct{}{} } } return out } // instanceEngineMatches returns whether an RDS DBInstance engine matches any engine name in a filter set. -func instanceEngineMatches(instance *rds.DBInstance, filterSet map[string]struct{}) bool { +func instanceEngineMatches(instance rdstypes.DBInstance, filterSet map[string]struct{}) bool { _, ok := filterSet[aws.StringValue(instance.Engine)] return ok } // instanceClusterIDMatches returns whether an RDS DBInstance ClusterID matches any ClusterID in a filter set. -func instanceClusterIDMatches(instance *rds.DBInstance, filterSet map[string]struct{}) bool { +func instanceClusterIDMatches(instance rdstypes.DBInstance, filterSet map[string]struct{}) bool { _, ok := filterSet[aws.StringValue(instance.DBClusterIdentifier)] return ok } // clusterEngineMatches returns whether an RDS DBCluster engine matches any engine name in a filter set. -func clusterEngineMatches(cluster *rds.DBCluster, filterSet map[string]struct{}) bool { +func clusterEngineMatches(cluster rdstypes.DBCluster, filterSet map[string]struct{}) bool { _, ok := filterSet[aws.StringValue(cluster.Engine)] return ok } -// RDSInstance returns a sample rds.DBInstance. -func RDSInstance(name, region string, labels map[string]string, opts ...func(*rds.DBInstance)) *rds.DBInstance { - instance := &rds.DBInstance{ +// RDSInstance returns a sample rdstypes.DBInstance. 
+func RDSInstance(name, region string, labels map[string]string, opts ...func(*rdstypes.DBInstance)) *rdstypes.DBInstance { + instance := &rdstypes.DBInstance{ DBInstanceArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db:%v", region, name)), DBInstanceIdentifier: aws.String(name), DbiResourceId: aws.String(uuid.New().String()), Engine: aws.String("postgres"), DBInstanceStatus: aws.String("available"), - Endpoint: &rds.Endpoint{ + Endpoint: &rdstypes.Endpoint{ Address: aws.String(fmt.Sprintf("%v.aabbccdd.%v.rds.amazonaws.com", name, region)), - Port: aws.Int64(5432), + Port: aws.Int32(5432), }, - TagList: libcloudaws.LabelsToTags[rds.Tag](labels), + TagList: awstesthelpers.LabelsToRDSTags(labels), } for _, opt := range opts { opt(instance) @@ -411,9 +309,9 @@ func RDSInstance(name, region string, labels map[string]string, opts ...func(*rd return instance } -// RDSCluster returns a sample rds.DBCluster. -func RDSCluster(name, region string, labels map[string]string, opts ...func(*rds.DBCluster)) *rds.DBCluster { - cluster := &rds.DBCluster{ +// RDSCluster returns a sample rdstypes.DBCluster. 
+func RDSCluster(name, region string, labels map[string]string, opts ...func(*rdstypes.DBCluster)) *rdstypes.DBCluster { + cluster := &rdstypes.DBCluster{ DBClusterArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:cluster:%v", region, name)), DBClusterIdentifier: aws.String(name), DbClusterResourceId: aws.String(uuid.New().String()), @@ -422,9 +320,9 @@ func RDSCluster(name, region string, labels map[string]string, opts ...func(*rds Status: aws.String("available"), Endpoint: aws.String(fmt.Sprintf("%v.cluster-aabbccdd.%v.rds.amazonaws.com", name, region)), ReaderEndpoint: aws.String(fmt.Sprintf("%v.cluster-ro-aabbccdd.%v.rds.amazonaws.com", name, region)), - Port: aws.Int64(3306), - TagList: libcloudaws.LabelsToTags[rds.Tag](labels), - DBClusterMembers: []*rds.DBClusterMember{{ + Port: aws.Int32(3306), + TagList: awstesthelpers.LabelsToRDSTags(labels), + DBClusterMembers: []rdstypes.DBClusterMember{{ IsClusterWriter: aws.Bool(true), // One writer by default. }}, } @@ -434,49 +332,49 @@ func RDSCluster(name, region string, labels map[string]string, opts ...func(*rds return cluster } -func WithRDSClusterReader(cluster *rds.DBCluster) { - cluster.DBClusterMembers = append(cluster.DBClusterMembers, &rds.DBClusterMember{ +func WithRDSClusterReader(cluster *rdstypes.DBCluster) { + cluster.DBClusterMembers = append(cluster.DBClusterMembers, rdstypes.DBClusterMember{ IsClusterWriter: aws.Bool(false), // Add reader. 
}) } -func WithRDSClusterCustomEndpoint(name string) func(*rds.DBCluster) { - return func(cluster *rds.DBCluster) { +func WithRDSClusterCustomEndpoint(name string) func(*rdstypes.DBCluster) { + return func(cluster *rdstypes.DBCluster) { parsed, _ := arn.Parse(aws.StringValue(cluster.DBClusterArn)) - cluster.CustomEndpoints = append(cluster.CustomEndpoints, aws.String( + cluster.CustomEndpoints = append(cluster.CustomEndpoints, fmt.Sprintf("%v.cluster-custom-aabbccdd.%v.rds.amazonaws.com", name, parsed.Region), - )) + ) } } -// RDSProxy returns a sample rds.DBProxy. -func RDSProxy(name, region, vpcID string) *rds.DBProxy { - return &rds.DBProxy{ +// RDSProxy returns a sample rdstypes.DBProxy. +func RDSProxy(name, region, vpcID string) *rdstypes.DBProxy { + return &rdstypes.DBProxy{ DBProxyArn: aws.String(fmt.Sprintf("arn:aws:rds:%s:123456789012:db-proxy:prx-%s", region, name)), DBProxyName: aws.String(name), - EngineFamily: aws.String(rds.EngineFamilyMysql), + EngineFamily: aws.String(string(rdstypes.EngineFamilyMysql)), Endpoint: aws.String(fmt.Sprintf("%s.proxy-aabbccdd.%s.rds.amazonaws.com", name, region)), VpcId: aws.String(vpcID), RequireTLS: aws.Bool(true), - Status: aws.String("available"), + Status: "available", } } -// RDSProxyCustomEndpoint returns a sample rds.DBProxyEndpoint. -func RDSProxyCustomEndpoint(rdsProxy *rds.DBProxy, name, region string) *rds.DBProxyEndpoint { - return &rds.DBProxyEndpoint{ +// RDSProxyCustomEndpoint returns a sample rdstypes.DBProxyEndpoint. 
+func RDSProxyCustomEndpoint(rdsProxy *rdstypes.DBProxy, name, region string) *rdstypes.DBProxyEndpoint { + return &rdstypes.DBProxyEndpoint{ Endpoint: aws.String(fmt.Sprintf("%s.endpoint.proxy-aabbccdd.%s.rds.amazonaws.com", name, region)), DBProxyEndpointName: aws.String(name), DBProxyName: rdsProxy.DBProxyName, DBProxyEndpointArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db-proxy-endpoint:prx-endpoint-%v", region, name)), - TargetRole: aws.String(rds.DBProxyEndpointTargetRoleReadOnly), - Status: aws.String("available"), + TargetRole: rdstypes.DBProxyEndpointTargetRoleReadOnly, + Status: "available", } } -// DocumentDBCluster returns a sample rds.DBCluster for DocumentDB. -func DocumentDBCluster(name, region string, labels map[string]string, opts ...func(*rds.DBCluster)) *rds.DBCluster { - cluster := &rds.DBCluster{ +// DocumentDBCluster returns a sample rdstypes.DBCluster for DocumentDB. +func DocumentDBCluster(name, region string, labels map[string]string, opts ...func(*rdstypes.DBCluster)) *rdstypes.DBCluster { + cluster := &rdstypes.DBCluster{ DBClusterArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:cluster:%v", region, name)), DBClusterIdentifier: aws.String(name), DbClusterResourceId: aws.String(uuid.New().String()), @@ -485,9 +383,9 @@ func DocumentDBCluster(name, region string, labels map[string]string, opts ...fu Status: aws.String("available"), Endpoint: aws.String(fmt.Sprintf("%v.cluster-aabbccdd.%v.docdb.amazonaws.com", name, region)), ReaderEndpoint: aws.String(fmt.Sprintf("%v.cluster-ro-aabbccdd.%v.docdb.amazonaws.com", name, region)), - Port: aws.Int64(27017), - TagList: libcloudaws.LabelsToTags[rds.Tag](labels), - DBClusterMembers: []*rds.DBClusterMember{{ + Port: aws.Int32(27017), + TagList: awstesthelpers.LabelsToRDSTags(labels), + DBClusterMembers: []rdstypes.DBClusterMember{{ IsClusterWriter: aws.Bool(true), // One writer by default. 
}}, } @@ -497,6 +395,6 @@ func DocumentDBCluster(name, region string, labels map[string]string, opts ...fu return cluster } -func WithDocumentDBClusterReader(cluster *rds.DBCluster) { +func WithDocumentDBClusterReader(cluster *rdstypes.DBCluster) { WithRDSClusterReader(cluster) } diff --git a/lib/srv/db/access_test.go b/lib/srv/db/access_test.go index 46c6ca1a19f53..906acfd06c7cd 100644 --- a/lib/srv/db/access_test.go +++ b/lib/srv/db/access_test.go @@ -2491,7 +2491,6 @@ func (p *agentParams) setDefaults(c *testContext) { if p.CloudClients == nil { p.CloudClients = &clients.TestCloudClients{ STS: &mocks.STSClientV1{}, - RDS: &mocks.RDSMock{}, RedshiftServerless: &mocks.RedshiftServerlessMock{}, ElastiCache: p.ElastiCache, MemoryDB: p.MemoryDB, @@ -2501,7 +2500,7 @@ func (p *agentParams) setDefaults(c *testContext) { } } if p.AWSConfigProvider == nil { - p.AWSConfigProvider = &mocks.AWSConfigProvider{} + p.AWSConfigProvider = &mocks.AWSConfigProvider{Err: trace.AccessDenied("AWS SDK clients are disabled for tests by default")} } if p.DiscoveryResourceChecker == nil { diff --git a/lib/srv/db/cloud/aws.go b/lib/srv/db/cloud/aws.go index 8222599c318a7..c336cb43230dd 100644 --- a/lib/srv/db/cloud/aws.go +++ b/lib/srv/db/cloud/aws.go @@ -23,21 +23,24 @@ import ( "encoding/json" "log/slog" + "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/iam/iamiface" - "github.com/aws/aws-sdk-go/service/rds" "github.com/gravitational/trace" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud" awslib "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" dbiam "github.com/gravitational/teleport/lib/srv/db/common/iam" ) // awsConfig is the config for the client that configures IAM for AWS databases. 
type awsConfig struct { + // awsConfigProvider provides [aws.Config] for AWS SDK service clients. + awsConfigProvider awsconfig.Provider // clients is an interface for creating AWS clients. clients cloud.Clients // identity is AWS identity this database agent is running as. @@ -46,6 +49,9 @@ type awsConfig struct { database types.Database // policyName is the name of the inline policy for the identity. policyName string + // awsClients is an internal-only AWS SDK client provider that is + // only set in tests. + awsClients awsClientProvider } // Check validates the config. @@ -62,6 +68,12 @@ func (c *awsConfig) Check() error { if c.policyName == "" { return trace.BadParameter("missing parameter policy name") } + if c.awsConfigProvider == nil { + return trace.BadParameter("missing parameter awsConfigProvider") + } + if c.awsClients == nil { + return trace.BadParameter("missing parameter awsClients") + } return nil } @@ -75,7 +87,7 @@ func newAWS(ctx context.Context, config awsConfig) (*awsClient, error) { teleport.ComponentKey, "aws", "db", config.database.GetName(), ) - dbConfigurator, err := getDBConfigurator(logger, config.clients, config.database) + dbConfigurator, err := getDBConfigurator(logger, config) if err != nil { return nil, trace.Wrap(err) } @@ -102,10 +114,14 @@ type dbIAMAuthConfigurator interface { } // getDBConfigurator returns a database IAM Auth configurator. -func getDBConfigurator(logger *slog.Logger, clients cloud.Clients, db types.Database) (dbIAMAuthConfigurator, error) { - if db.IsRDS() { +func getDBConfigurator(logger *slog.Logger, cfg awsConfig) (dbIAMAuthConfigurator, error) { + if cfg.database.IsRDS() { // Only setting for RDS instances and Aurora clusters. - return &rdsDBConfigurator{clients: clients, logger: logger}, nil + return &rdsDBConfigurator{ + awsConfigProvider: cfg.awsConfigProvider, + logger: logger, + awsClients: cfg.awsClients, + }, nil } // IAM Auth for Redshift, ElastiCache, and RDS Proxy is always enabled. 
return &nopDBConfigurator{}, nil @@ -303,8 +319,9 @@ func (r *awsClient) detachIAMPolicy(ctx context.Context) error { } type rdsDBConfigurator struct { - clients cloud.Clients - logger *slog.Logger + awsConfigProvider awsconfig.Provider + logger *slog.Logger + awsClients awsClientProvider } // ensureIAMAuth enables RDS instance IAM auth if it isn't already enabled. @@ -323,30 +340,34 @@ func (r *rdsDBConfigurator) ensureIAMAuth(ctx context.Context, db types.Database func (r *rdsDBConfigurator) enableIAMAuth(ctx context.Context, db types.Database) error { r.logger.DebugContext(ctx, "Enabling IAM auth for RDS") meta := db.GetAWS() - rdsClt, err := r.clients.GetAWSRDSClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + if meta.RDS.ClusterID == "" && meta.RDS.InstanceID == "" { + return trace.BadParameter("no RDS cluster ID or instance ID for %v", db) + } + awsCfg, err := r.awsConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return trace.Wrap(err) } + clt := r.awsClients.getRDSClient(awsCfg) if meta.RDS.ClusterID != "" { - _, err = rdsClt.ModifyDBClusterWithContext(ctx, &rds.ModifyDBClusterInput{ + _, err = clt.ModifyDBCluster(ctx, &rds.ModifyDBClusterInput{ DBClusterIdentifier: aws.String(meta.RDS.ClusterID), EnableIAMDatabaseAuthentication: aws.Bool(true), ApplyImmediately: aws.Bool(true), }) - return awslib.ConvertIAMError(err) + return awslib.ConvertRequestFailureErrorV2(err) } if meta.RDS.InstanceID != "" { - _, err = rdsClt.ModifyDBInstanceWithContext(ctx, &rds.ModifyDBInstanceInput{ + _, err = clt.ModifyDBInstance(ctx, &rds.ModifyDBInstanceInput{ DBInstanceIdentifier: aws.String(meta.RDS.InstanceID), EnableIAMDatabaseAuthentication: aws.Bool(true), ApplyImmediately: aws.Bool(true), }) - return awslib.ConvertIAMError(err) + return awslib.ConvertRequestFailureErrorV2(err) } - return 
trace.BadParameter("no RDS cluster ID or instance ID for %v", db) + return nil } type nopDBConfigurator struct{} diff --git a/lib/srv/db/cloud/iam.go b/lib/srv/db/cloud/iam.go index aa1629157d78f..ef49e061e59b8 100644 --- a/lib/srv/db/cloud/iam.go +++ b/lib/srv/db/cloud/iam.go @@ -35,6 +35,7 @@ import ( "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/cloud" awslib "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/db/common/iam" ) @@ -45,6 +46,8 @@ type IAMConfig struct { Clock clockwork.Clock // AccessPoint is a caching client connected to the Auth Server. AccessPoint authclient.DatabaseAccessPoint + // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. + AWSConfigProvider awsconfig.Provider // Clients is an interface for retrieving cloud clients. Clients cloud.Clients // HostID is the host identified where this agent is running. @@ -52,6 +55,8 @@ type IAMConfig struct { HostID string // onProcessedTask is called after a task is processed. onProcessedTask func(processedTask iamTask, processError error) + // awsClients is an SDK client provider. + awsClients awsClientProvider } // Check validates the IAM configurator config. 
@@ -62,6 +67,9 @@ func (c *IAMConfig) Check() error { if c.AccessPoint == nil { return trace.BadParameter("missing AccessPoint") } + if c.AWSConfigProvider == nil { + return trace.BadParameter("missing AWSConfigProvider") + } if c.Clients == nil { cloudClients, err := cloud.NewClients() if err != nil { @@ -72,6 +80,9 @@ func (c *IAMConfig) Check() error { if c.HostID == "" { return trace.BadParameter("missing HostID") } + if c.awsClients == nil { + c.awsClients = defaultAWSClients{} + } return nil } @@ -233,10 +244,12 @@ func (c *IAM) getAWSConfigurator(ctx context.Context, database types.Database) ( return nil, trace.Wrap(err) } return newAWS(ctx, awsConfig{ - clients: c.cfg.Clients, - policyName: policyName, - identity: identity, - database: database, + awsConfigProvider: c.cfg.AWSConfigProvider, + clients: c.cfg.Clients, + database: database, + identity: identity, + policyName: policyName, + awsClients: c.cfg.awsClients, }) } diff --git a/lib/srv/db/cloud/iam_test.go b/lib/srv/db/cloud/iam_test.go index c3b9ecf3dd716..36397d6a64727 100644 --- a/lib/srv/db/cloud/iam_test.go +++ b/lib/srv/db/cloud/iam_test.go @@ -24,10 +24,10 @@ import ( "testing" "time" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/iam" - "github.com/aws/aws-sdk-go/service/rds" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/stretchr/testify/require" @@ -46,26 +46,28 @@ func TestAWSIAM(t *testing.T) { t.Cleanup(cancel) // Setup AWS database objects. 
- rdsInstance := &rds.DBInstance{ + rdsInstance := &rdstypes.DBInstance{ DBInstanceArn: aws.String("arn:aws:rds:us-west-1:123456789012:db:postgres-rds"), DBInstanceIdentifier: aws.String("postgres-rds"), DbiResourceId: aws.String("db-xyz"), } - auroraCluster := &rds.DBCluster{ + auroraCluster := &rdstypes.DBCluster{ DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:postgres-aurora"), DBClusterIdentifier: aws.String("postgres-aurora"), DbClusterResourceId: aws.String("cluster-xyz"), } // Configure mocks. - stsClient := &mocks.STSClientV1{ - ARN: "arn:aws:iam::123456789012:role/test-role", + stsClient := &mocks.STSClient{ + STSClientV1: mocks.STSClientV1{ + ARN: "arn:aws:iam::123456789012:role/test-role", + }, } - rdsClient := &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance}, - DBClusters: []*rds.DBCluster{auroraCluster}, + clt := &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster}, } iamClient := &mocks.IAMMock{} @@ -152,15 +154,20 @@ func TestAWSIAM(t *testing.T) { } configurator, err := NewIAM(ctx, IAMConfig{ AccessPoint: &mockAccessPoint{}, + AWSConfigProvider: &mocks.AWSConfigProvider{ + STSClient: stsClient, + }, Clients: &clients.TestCloudClients{ - RDS: rdsClient, - STS: stsClient, + STS: &stsClient.STSClientV1, IAM: iamClient, }, HostID: "host-id", onProcessedTask: func(iamTask, error) { taskChan <- struct{}{} }, + awsClients: fakeAWSClients{ + rdsClient: clt, + }, }) require.NoError(t, err) require.NoError(t, configurator.Start(ctx)) @@ -177,6 +184,7 @@ func TestAWSIAM(t *testing.T) { database: rdsDatabase, wantPolicyContains: rdsDatabase.GetAWS().RDS.ResourceID, getIAMAuthEnabled: func() bool { + rdsInstance := &clt.DBInstances[0] out := aws.BoolValue(rdsInstance.IAMDatabaseAuthenticationEnabled) // reset it rdsInstance.IAMDatabaseAuthenticationEnabled = aws.Bool(false) @@ -187,6 +195,7 @@ func TestAWSIAM(t *testing.T) { database: auroraDatabase, 
wantPolicyContains: auroraDatabase.GetAWS().RDS.ResourceID, getIAMAuthEnabled: func() bool { + auroraCluster := &clt.DBClusters[0] out := aws.BoolValue(auroraCluster.IAMDatabaseAuthenticationEnabled) // reset it auroraCluster.IAMDatabaseAuthenticationEnabled = aws.Bool(false) @@ -291,6 +300,16 @@ func TestAWSIAMNoPermissions(t *testing.T) { AccessPoint: &mockAccessPoint{}, Clients: &clients.TestCloudClients{}, // placeholder, HostID: "host-id", + AWSConfigProvider: &mocks.AWSConfigProvider{ + STSClient: &mocks.STSClient{ + STSClientV1: mocks.STSClientV1{ + ARN: "arn:aws:iam::123456789012:role/test-role", + }, + }, + }, + awsClients: fakeAWSClients{ + rdsClient: &mocks.RDSClient{Unauth: true}, + }, }) require.NoError(t, err) @@ -303,7 +322,6 @@ func TestAWSIAMNoPermissions(t *testing.T) { name: "RDS database", meta: types.AWS{Region: "localhost", AccountID: "123456789012", RDS: types.RDS{InstanceID: "postgres-rds", ResourceID: "postgres-rds-resource-id"}}, clients: &clients.TestCloudClients{ - RDS: &mocks.RDSMockUnauth{}, IAM: &mocks.IAMErrorMock{ Error: trace.AccessDenied("unauthorized"), }, @@ -314,7 +332,6 @@ func TestAWSIAMNoPermissions(t *testing.T) { name: "Aurora cluster", meta: types.AWS{Region: "localhost", AccountID: "123456789012", RDS: types.RDS{ClusterID: "postgres-aurora", ResourceID: "postgres-aurora-resource-id"}}, clients: &clients.TestCloudClients{ - RDS: &mocks.RDSMockUnauth{}, IAM: &mocks.IAMErrorMock{ Error: trace.AccessDenied("unauthorized"), }, @@ -325,7 +342,6 @@ func TestAWSIAMNoPermissions(t *testing.T) { name: "RDS database missing metadata", meta: types.AWS{Region: "localhost", RDS: types.RDS{ClusterID: "postgres-aurora"}}, clients: &clients.TestCloudClients{ - RDS: &mocks.RDSMockUnauth{}, IAM: &mocks.IAMErrorMock{ Error: trace.AccessDenied("unauthorized"), }, diff --git a/lib/srv/db/cloud/meta.go b/lib/srv/db/cloud/meta.go index 98e2280fb1db5..0956759422b07 100644 --- a/lib/srv/db/cloud/meta.go +++ b/lib/srv/db/cloud/meta.go @@ -24,14 
+24,14 @@ import ( "strings" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go-v2/service/redshift" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/memorydb/memorydbiface" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface" "github.com/gravitational/trace" @@ -45,13 +45,36 @@ import ( logutils "github.com/gravitational/teleport/lib/utils/log" ) +// rdsClient defines a subset of the AWS RDS client API. +type rdsClient interface { + rds.DescribeDBClustersAPIClient + rds.DescribeDBInstancesAPIClient + rds.DescribeDBProxiesAPIClient + rds.DescribeDBProxyEndpointsAPIClient + ModifyDBCluster(ctx context.Context, params *rds.ModifyDBClusterInput, optFns ...func(*rds.Options)) (*rds.ModifyDBClusterOutput, error) + ModifyDBInstance(ctx context.Context, params *rds.ModifyDBInstanceInput, optFns ...func(*rds.Options)) (*rds.ModifyDBInstanceOutput, error) +} + // redshiftClient defines a subset of the AWS Redshift client API. type redshiftClient interface { redshift.DescribeClustersAPIClient } -// redshiftClientProviderFunc provides a [redshiftClient]. -type redshiftClientProviderFunc func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient +// awsClientProvider is an AWS SDK client provider. 
+type awsClientProvider interface { + getRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient + getRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient +} + +type defaultAWSClients struct{} + +func (defaultAWSClients) getRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient { + return rds.NewFromConfig(cfg, optFns...) +} + +func (defaultAWSClients) getRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient { + return redshift.NewFromConfig(cfg, optFns...) +} // MetadataConfig is the cloud metadata service config. type MetadataConfig struct { @@ -60,9 +83,8 @@ type MetadataConfig struct { // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. AWSConfigProvider awsconfig.Provider - // redshiftClientProviderFn is an internal-only [redshiftClient] provider - // func that is only set in tests. - redshiftClientProviderFn redshiftClientProviderFunc + // awsClients is an SDK client provider. + awsClients awsClientProvider } // Check validates the metadata service config. @@ -78,10 +100,8 @@ func (c *MetadataConfig) Check() error { return trace.BadParameter("missing AWSConfigProvider") } - if c.redshiftClientProviderFn == nil { - c.redshiftClientProviderFn = func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient { - return redshift.NewFromConfig(cfg, optFns...) - } + if c.awsClients == nil { + c.awsClients = defaultAWSClients{} } return nil } @@ -147,20 +167,21 @@ func (m *Metadata) updateAWS(ctx context.Context, database types.Database, fetch // fetchRDSMetadata fetches metadata for the provided RDS or Aurora database. 
func (m *Metadata) fetchRDSMetadata(ctx context.Context, database types.Database) (*types.AWS, error) { meta := database.GetAWS() - rds, err := m.cfg.Clients.GetAWSRDSClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := m.cfg.AWSConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return nil, trace.Wrap(err) } + clt := m.cfg.awsClients.getRDSClient(awsCfg) if meta.RDS.ClusterID != "" { - return fetchRDSClusterMetadata(ctx, rds, meta.RDS.ClusterID) + return fetchRDSClusterMetadata(ctx, clt, meta.RDS.ClusterID) } // Try to fetch the RDS instance fetchedMeta. - fetchedMeta, err := fetchRDSInstanceMetadata(ctx, rds, meta.RDS.InstanceID) + fetchedMeta, err := fetchRDSInstanceMetadata(ctx, clt, meta.RDS.InstanceID) if err != nil && !trace.IsNotFound(err) && !trace.IsAccessDenied(err) { return nil, trace.Wrap(err) } @@ -172,11 +193,11 @@ func (m *Metadata) fetchRDSMetadata(ctx context.Context, database types.Database if clusterID == "" { clusterID = meta.RDS.InstanceID } - return fetchRDSClusterMetadata(ctx, rds, clusterID) + return fetchRDSClusterMetadata(ctx, clt, clusterID) } // If instance was found, it may be a part of an Aurora cluster. if fetchedMeta.RDS.ClusterID != "" { - return fetchRDSClusterMetadata(ctx, rds, fetchedMeta.RDS.ClusterID) + return fetchRDSClusterMetadata(ctx, clt, fetchedMeta.RDS.ClusterID) } return fetchedMeta, nil } @@ -184,18 +205,19 @@ func (m *Metadata) fetchRDSMetadata(ctx context.Context, database types.Database // fetchRDSProxyMetadata fetches metadata for the provided RDS Proxy database. 
func (m *Metadata) fetchRDSProxyMetadata(ctx context.Context, database types.Database) (*types.AWS, error) { meta := database.GetAWS() - rds, err := m.cfg.Clients.GetAWSRDSClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := m.cfg.AWSConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return nil, trace.Wrap(err) } + clt := m.cfg.awsClients.getRDSClient(awsCfg) if meta.RDSProxy.CustomEndpointName != "" { - return fetchRDSProxyCustomEndpointMetadata(ctx, rds, meta.RDSProxy.CustomEndpointName, database.GetURI()) + return fetchRDSProxyCustomEndpointMetadata(ctx, clt, meta.RDSProxy.CustomEndpointName, database.GetURI()) } - return fetchRDSProxyMetadata(ctx, rds, meta.RDSProxy.Name) + return fetchRDSProxyMetadata(ctx, clt, meta.RDSProxy.Name) } // fetchRedshiftMetadata fetches metadata for the provided Redshift database. @@ -208,7 +230,7 @@ func (m *Metadata) fetchRedshiftMetadata(ctx context.Context, database types.Dat if err != nil { return nil, trace.Wrap(err) } - redshift := m.cfg.redshiftClientProviderFn(awsCfg) + redshift := m.cfg.awsClients.getRedshiftClient(awsCfg) cluster, err := describeRedshiftCluster(ctx, redshift, meta.Redshift.ClusterID) if err != nil { return nil, trace.Wrap(err) @@ -275,8 +297,8 @@ func (m *Metadata) fetchMemoryDBMetadata(ctx context.Context, database types.Dat } // fetchRDSInstanceMetadata fetches metadata about specified RDS instance. 
-func fetchRDSInstanceMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, instanceID string) (*types.AWS, error) { - rdsInstance, err := describeRDSInstance(ctx, rdsClient, instanceID) +func fetchRDSInstanceMetadata(ctx context.Context, clt rdsClient, instanceID string) (*types.AWS, error) { + rdsInstance, err := describeRDSInstance(ctx, clt, instanceID) if err != nil { return nil, trace.Wrap(err) } @@ -284,22 +306,22 @@ func fetchRDSInstanceMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, in } // describeRDSInstance returns AWS RDS instance for the specified ID. -func describeRDSInstance(ctx context.Context, rdsClient rdsiface.RDSAPI, instanceID string) (*rds.DBInstance, error) { - out, err := rdsClient.DescribeDBInstancesWithContext(ctx, &rds.DescribeDBInstancesInput{ +func describeRDSInstance(ctx context.Context, clt rdsClient, instanceID string) (*rdstypes.DBInstance, error) { + out, err := clt.DescribeDBInstances(ctx, &rds.DescribeDBInstancesInput{ DBInstanceIdentifier: aws.String(instanceID), }) if err != nil { return nil, common.ConvertError(err) } if len(out.DBInstances) != 1 { - return nil, trace.BadParameter("expected 1 RDS instance for %v, got %+v", instanceID, out.DBInstances) + return nil, trace.BadParameter("expected 1 RDS instance for %v, got %d", instanceID, len(out.DBInstances)) } - return out.DBInstances[0], nil + return &out.DBInstances[0], nil } // fetchRDSClusterMetadata fetches metadata about specified Aurora cluster. 
-func fetchRDSClusterMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, clusterID string) (*types.AWS, error) { - rdsCluster, err := describeRDSCluster(ctx, rdsClient, clusterID) +func fetchRDSClusterMetadata(ctx context.Context, clt rdsClient, clusterID string) (*types.AWS, error) { + rdsCluster, err := describeRDSCluster(ctx, clt, clusterID) if err != nil { return nil, trace.Wrap(err) } @@ -307,8 +329,8 @@ func fetchRDSClusterMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, clu } // describeRDSCluster returns AWS Aurora cluster for the specified ID. -func describeRDSCluster(ctx context.Context, rdsClient rdsiface.RDSAPI, clusterID string) (*rds.DBCluster, error) { - out, err := rdsClient.DescribeDBClustersWithContext(ctx, &rds.DescribeDBClustersInput{ +func describeRDSCluster(ctx context.Context, clt rdsClient, clusterID string) (*rdstypes.DBCluster, error) { + out, err := clt.DescribeDBClusters(ctx, &rds.DescribeDBClustersInput{ DBClusterIdentifier: aws.String(clusterID), }) if err != nil { @@ -317,7 +339,7 @@ func describeRDSCluster(ctx context.Context, rdsClient rdsiface.RDSAPI, clusterI if len(out.DBClusters) != 1 { return nil, trace.BadParameter("expected 1 RDS cluster for %v, got %+v", clusterID, out.DBClusters) } - return out.DBClusters[0], nil + return &out.DBClusters[0], nil } // describeRedshiftCluster returns AWS Redshift cluster for the specified ID. @@ -364,8 +386,8 @@ func describeMemoryDBCluster(ctx context.Context, client memorydbiface.MemoryDBA } // fetchRDSProxyMetadata fetches metadata about specified RDS Proxy name. 
-func fetchRDSProxyMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyName string) (*types.AWS, error) { - rdsProxy, err := describeRDSProxy(ctx, rdsClient, proxyName) +func fetchRDSProxyMetadata(ctx context.Context, clt rdsClient, proxyName string) (*types.AWS, error) { + rdsProxy, err := describeRDSProxy(ctx, clt, proxyName) if err != nil { return nil, trace.Wrap(err) } @@ -373,28 +395,28 @@ func fetchRDSProxyMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, proxy } // describeRDSProxy returns AWS RDS Proxy for the specified RDS Proxy name. -func describeRDSProxy(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyName string) (*rds.DBProxy, error) { - out, err := rdsClient.DescribeDBProxiesWithContext(ctx, &rds.DescribeDBProxiesInput{ +func describeRDSProxy(ctx context.Context, clt rdsClient, proxyName string) (*rdstypes.DBProxy, error) { + out, err := clt.DescribeDBProxies(ctx, &rds.DescribeDBProxiesInput{ DBProxyName: aws.String(proxyName), }) if err != nil { return nil, common.ConvertError(err) } if len(out.DBProxies) != 1 { - return nil, trace.BadParameter("expected 1 RDS Proxy for %v, got %s", proxyName, out.DBProxies) + return nil, trace.BadParameter("expected 1 RDS Proxy for %v, got %d", proxyName, len(out.DBProxies)) } - return out.DBProxies[0], nil + return &out.DBProxies[0], nil } // fetchRDSProxyCustomEndpointMetadata fetches metadata about specified RDS // proxy custom endpoint. 
-func fetchRDSProxyCustomEndpointMetadata(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyEndpointName, uri string) (*types.AWS, error) { - rdsProxyEndpoint, err := describeRDSProxyCustomEndpointAndFindURI(ctx, rdsClient, proxyEndpointName, uri) +func fetchRDSProxyCustomEndpointMetadata(ctx context.Context, clt rdsClient, proxyEndpointName, uri string) (*types.AWS, error) { + rdsProxyEndpoint, err := describeRDSProxyCustomEndpointAndFindURI(ctx, clt, proxyEndpointName, uri) if err != nil { return nil, trace.Wrap(err) } - rdsProxy, err := describeRDSProxy(ctx, rdsClient, aws.ToString(rdsProxyEndpoint.DBProxyName)) + rdsProxy, err := describeRDSProxy(ctx, clt, aws.ToString(rdsProxyEndpoint.DBProxyName)) if err != nil { return nil, trace.Wrap(err) } @@ -404,21 +426,27 @@ func fetchRDSProxyCustomEndpointMetadata(ctx context.Context, rdsClient rdsiface // describeRDSProxyCustomEndpointAndFindURI returns AWS RDS Proxy endpoint for // the specified RDS Proxy custom endpoint. -func describeRDSProxyCustomEndpointAndFindURI(ctx context.Context, rdsClient rdsiface.RDSAPI, proxyEndpointName, uri string) (*rds.DBProxyEndpoint, error) { - out, err := rdsClient.DescribeDBProxyEndpointsWithContext(ctx, &rds.DescribeDBProxyEndpointsInput{ +func describeRDSProxyCustomEndpointAndFindURI(ctx context.Context, clt rdsClient, proxyEndpointName, uri string) (*rdstypes.DBProxyEndpoint, error) { + out, err := clt.DescribeDBProxyEndpoints(ctx, &rds.DescribeDBProxyEndpointsInput{ DBProxyEndpointName: aws.String(proxyEndpointName), }) if err != nil { return nil, common.ConvertError(err) } - for _, customEndpoint := range out.DBProxyEndpoints { + var endpoints []string + for _, e := range out.DBProxyEndpoints { + endpoint := aws.ToString(e.Endpoint) + if endpoint == "" { + continue + } // Double check if it has the same URI in case multiple custom // endpoints have the same name. 
- if strings.Contains(uri, aws.ToString(customEndpoint.Endpoint)) { - return customEndpoint, nil + if strings.Contains(uri, endpoint) { + return &e, nil } + endpoints = append(endpoints, endpoint) } - return nil, trace.BadParameter("could not find RDS Proxy custom endpoint %v with URI %v, got %s", proxyEndpointName, uri, out.DBProxyEndpoints) + return nil, trace.BadParameter("could not find RDS Proxy custom endpoint %v with URI %v, got %s", proxyEndpointName, uri, endpoints) } func fetchRedshiftServerlessWorkgroupMetadata(ctx context.Context, client redshiftserverlessiface.RedshiftServerlessAPI, workgroupName string) (*types.AWS, error) { diff --git a/lib/srv/db/cloud/meta_test.go b/lib/srv/db/cloud/meta_test.go index 9e66a416a2ebb..9c8805f026820 100644 --- a/lib/srv/db/cloud/meta_test.go +++ b/lib/srv/db/cloud/meta_test.go @@ -23,11 +23,12 @@ import ( "testing" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go-v2/service/redshift" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" - "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/stretchr/testify/require" @@ -40,8 +41,8 @@ import ( // TestAWSMetadata tests fetching AWS metadata for RDS and Redshift databases. func TestAWSMetadata(t *testing.T) { // Configure RDS API mock. - rds := &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{ + rdsClt := &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{ // Standalone RDS instance. { DBInstanceArn: aws.String("arn:aws:rds:us-west-1:123456789012:db:postgres-rds"), @@ -56,7 +57,7 @@ func TestAWSMetadata(t *testing.T) { DBClusterIdentifier: aws.String("postgres-aurora"), }, }, - DBClusters: []*rds.DBCluster{ + DBClusters: []rdstypes.DBCluster{ // Aurora cluster. 
{ DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:postgres-aurora"), @@ -64,16 +65,17 @@ func TestAWSMetadata(t *testing.T) { DbClusterResourceId: aws.String("cluster-xyz"), }, }, - DBProxies: []*rds.DBProxy{ + DBProxies: []rdstypes.DBProxy{ { DBProxyArn: aws.String("arn:aws:rds:us-east-1:123456789012:db-proxy:prx-resource-id"), DBProxyName: aws.String("rds-proxy"), }, }, - DBProxyEndpoints: []*rds.DBProxyEndpoint{ + DBProxyEndpoints: []rdstypes.DBProxyEndpoint{ { DBProxyEndpointName: aws.String("rds-proxy-endpoint"), DBProxyName: aws.String("rds-proxy"), + Endpoint: aws.String("localhost"), }, }, } @@ -130,7 +132,6 @@ func TestAWSMetadata(t *testing.T) { // Create metadata fetcher. metadata, err := NewMetadata(MetadataConfig{ Clients: &cloud.TestCloudClients{ - RDS: rds, ElastiCache: elasticache, MemoryDB: memorydb, RedshiftServerless: redshiftServerless, @@ -139,7 +140,10 @@ func TestAWSMetadata(t *testing.T) { AWSConfigProvider: &mocks.AWSConfigProvider{ STSClient: fakeSTS, }, - redshiftClientProviderFn: newFakeRedshiftClientProvider(redshiftClt), + awsClients: fakeAWSClients{ + rdsClient: rdsClt, + redshiftClient: redshiftClt, + }, }) require.NoError(t, err) @@ -407,7 +411,7 @@ func TestAWSMetadata(t *testing.T) { // cause an error. func TestAWSMetadataNoPermissions(t *testing.T) { // Create unauthorized mocks. - rds := &mocks.RDSMockUnauth{} + rdsClt := &mocks.RDSClient{Unauth: true} redshiftClt := &mocks.RedshiftClient{Unauth: true} fakeSTS := &mocks.STSClient{} @@ -415,13 +419,15 @@ func TestAWSMetadataNoPermissions(t *testing.T) { // Create metadata fetcher. 
metadata, err := NewMetadata(MetadataConfig{ Clients: &cloud.TestCloudClients{ - RDS: rds, STS: &fakeSTS.STSClientV1, }, AWSConfigProvider: &mocks.AWSConfigProvider{ STSClient: fakeSTS, }, - redshiftClientProviderFn: newFakeRedshiftClientProvider(redshiftClt), + awsClients: fakeAWSClients{ + rdsClient: rdsClt, + redshiftClient: redshiftClt, + }, }) require.NoError(t, err) @@ -494,8 +500,15 @@ func TestAWSMetadataNoPermissions(t *testing.T) { } } -func newFakeRedshiftClientProvider(c redshiftClient) redshiftClientProviderFunc { - return func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient { - return c - } +type fakeAWSClients struct { + rdsClient rdsClient + redshiftClient redshiftClient +} + +func (f fakeAWSClients) getRDSClient(aws.Config, ...func(*rds.Options)) rdsClient { + return f.rdsClient +} + +func (f fakeAWSClients) getRedshiftClient(aws.Config, ...func(*redshift.Options)) redshiftClient { + return f.redshiftClient } diff --git a/lib/srv/db/cloud/resource_checker_url.go b/lib/srv/db/cloud/resource_checker_url.go index fdc4efdb65fe9..da8dd40fac772 100644 --- a/lib/srv/db/cloud/resource_checker_url.go +++ b/lib/srv/db/cloud/resource_checker_url.go @@ -28,7 +28,6 @@ import ( "sync" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" @@ -42,9 +41,8 @@ import ( type urlChecker struct { // awsConfigProvider provides [aws.Config] for AWS SDK service clients. awsConfigProvider awsconfig.Provider - // redshiftClientProviderFn is an internal-only [redshiftClient] provider - // func that is only set in tests. - redshiftClientProviderFn redshiftClientProviderFunc + // awsClients is an SDK client provider. 
+ awsClients awsClientProvider clients cloud.Clients logger *slog.Logger @@ -61,12 +59,10 @@ type urlChecker struct { func newURLChecker(cfg DiscoveryResourceCheckerConfig) *urlChecker { return &urlChecker{ awsConfigProvider: cfg.AWSConfigProvider, - redshiftClientProviderFn: func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient { - return redshift.NewFromConfig(cfg, optFns...) - }, - clients: cfg.Clients, - logger: cfg.Logger, - warnOnError: getWarnOnError(), + awsClients: defaultAWSClients{}, + clients: cfg.Clients, + logger: cfg.Logger, + warnOnError: getWarnOnError(), } } diff --git a/lib/srv/db/cloud/resource_checker_url_aws.go b/lib/srv/db/cloud/resource_checker_url_aws.go index 336ee197815fb..5b87d643ea7b7 100644 --- a/lib/srv/db/cloud/resource_checker_url_aws.go +++ b/lib/srv/db/cloud/resource_checker_url_aws.go @@ -21,10 +21,9 @@ package cloud import ( "context" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/opensearchservice" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" "github.com/aws/aws-sdk-go/service/redshiftserverless/redshiftserverlessiface" "github.com/gravitational/trace" @@ -82,22 +81,23 @@ func (c *urlChecker) logAWSAccessDeniedError(ctx context.Context, database types func (c *urlChecker) checkRDS(ctx context.Context, database types.Database) error { meta := database.GetAWS() - rdsClient, err := c.clients.GetAWSRDSClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := c.awsConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return trace.Wrap(err) } + clt := c.awsClients.getRDSClient(awsCfg) if meta.RDS.ClusterID != "" { - return trace.Wrap(c.checkRDSCluster(ctx, database, rdsClient, meta.RDS.ClusterID)) + return 
trace.Wrap(c.checkRDSCluster(ctx, database, clt, meta.RDS.ClusterID)) } - return trace.Wrap(c.checkRDSInstance(ctx, database, rdsClient, meta.RDS.InstanceID)) + return trace.Wrap(c.checkRDSInstance(ctx, database, clt, meta.RDS.InstanceID)) } -func (c *urlChecker) checkRDSInstance(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, instanceID string) error { - rdsInstance, err := describeRDSInstance(ctx, rdsClient, instanceID) +func (c *urlChecker) checkRDSInstance(ctx context.Context, database types.Database, clt rdsClient, instanceID string) error { + rdsInstance, err := describeRDSInstance(ctx, clt, instanceID) if err != nil { return trace.Wrap(err) } @@ -107,12 +107,12 @@ func (c *urlChecker) checkRDSInstance(ctx context.Context, database types.Databa return trace.Wrap(requireDatabaseAddressPort(database, rdsInstance.Endpoint.Address, rdsInstance.Endpoint.Port)) } -func (c *urlChecker) checkRDSCluster(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, clusterID string) error { - rdsCluster, err := describeRDSCluster(ctx, rdsClient, clusterID) +func (c *urlChecker) checkRDSCluster(ctx context.Context, database types.Database, clt rdsClient, clusterID string) error { + rdsCluster, err := describeRDSCluster(ctx, clt, clusterID) if err != nil { return trace.Wrap(err) } - databases, err := common.NewDatabasesFromRDSCluster(rdsCluster, []*rds.DBInstance{}) + databases, err := common.NewDatabasesFromRDSCluster(rdsCluster, []rdstypes.DBInstance{}) if err != nil { c.logger.WarnContext(ctx, "Could not convert RDS cluster to database resources", "cluster", aws.StringValue(rdsCluster.DBClusterIdentifier), @@ -130,21 +130,22 @@ func (c *urlChecker) checkRDSCluster(ctx context.Context, database types.Databas func (c *urlChecker) checkRDSProxy(ctx context.Context, database types.Database) error { meta := database.GetAWS() - rdsClient, err := c.clients.GetAWSRDSClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - 
cloud.WithAmbientCredentials(), + awsCfg, err := c.awsConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return trace.Wrap(err) } + clt := c.awsClients.getRDSClient(awsCfg) if meta.RDSProxy.CustomEndpointName != "" { - return trace.Wrap(c.checkRDSProxyCustomEndpoint(ctx, database, rdsClient, meta.RDSProxy.CustomEndpointName)) + return trace.Wrap(c.checkRDSProxyCustomEndpoint(ctx, database, clt, meta.RDSProxy.CustomEndpointName)) } - return trace.Wrap(c.checkRDSProxyPrimaryEndpoint(ctx, database, rdsClient, meta.RDSProxy.Name)) + return trace.Wrap(c.checkRDSProxyPrimaryEndpoint(ctx, database, clt, meta.RDSProxy.Name)) } -func (c *urlChecker) checkRDSProxyPrimaryEndpoint(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, proxyName string) error { - rdsProxy, err := describeRDSProxy(ctx, rdsClient, proxyName) +func (c *urlChecker) checkRDSProxyPrimaryEndpoint(ctx context.Context, database types.Database, clt rdsClient, proxyName string) error { + rdsProxy, err := describeRDSProxy(ctx, clt, proxyName) if err != nil { return trace.Wrap(err) } @@ -153,8 +154,8 @@ func (c *urlChecker) checkRDSProxyPrimaryEndpoint(ctx context.Context, database return requireDatabaseHost(database, aws.StringValue(rdsProxy.Endpoint)) } -func (c *urlChecker) checkRDSProxyCustomEndpoint(ctx context.Context, database types.Database, rdsClient rdsiface.RDSAPI, proxyEndpointName string) error { - _, err := describeRDSProxyCustomEndpointAndFindURI(ctx, rdsClient, proxyEndpointName, database.GetURI()) +func (c *urlChecker) checkRDSProxyCustomEndpoint(ctx context.Context, database types.Database, clt rdsClient, proxyEndpointName string) error { + _, err := describeRDSProxyCustomEndpointAndFindURI(ctx, clt, proxyEndpointName, database.GetURI()) return trace.Wrap(err) } @@ -167,7 +168,7 @@ func (c *urlChecker) checkRedshift(ctx context.Context, database types.Database) 
if err != nil { return trace.Wrap(err) } - redshift := c.redshiftClientProviderFn(awsCfg) + redshift := c.awsClients.getRedshiftClient(awsCfg) cluster, err := describeRedshiftCluster(ctx, redshift, meta.Redshift.ClusterID) if err != nil { return trace.Wrap(err) @@ -290,15 +291,16 @@ func (c *urlChecker) checkOpenSearchEndpoint(ctx context.Context, database types func (c *urlChecker) checkDocumentDB(ctx context.Context, database types.Database) error { meta := database.GetAWS() - rdsClient, err := c.clients.GetAWSRDSClient(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := c.awsConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return trace.Wrap(err) } + clt := c.awsClients.getRDSClient(awsCfg) - cluster, err := describeRDSCluster(ctx, rdsClient, meta.DocumentDB.ClusterID) + cluster, err := describeRDSCluster(ctx, clt, meta.DocumentDB.ClusterID) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/db/cloud/resource_checker_url_aws_test.go b/lib/srv/db/cloud/resource_checker_url_aws_test.go index e8ba24f624c16..40095f7efafe0 100644 --- a/lib/srv/db/cloud/resource_checker_url_aws_test.go +++ b/lib/srv/db/cloud/resource_checker_url_aws_test.go @@ -22,11 +22,11 @@ import ( "context" "testing" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/opensearchservice" - "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/stretchr/testify/require" @@ -54,7 +54,7 @@ func TestURLChecker_AWS(t *testing.T) { mocks.WithRDSClusterReader, mocks.WithRDSClusterCustomEndpoint("my-custom"), ) - rdsClusterDBs, err := 
common.NewDatabasesFromRDSCluster(rdsCluster, []*rds.DBInstance{}) + rdsClusterDBs, err := common.NewDatabasesFromRDSCluster(rdsCluster, []rdstypes.DBInstance{}) require.NoError(t, err) require.Len(t, rdsClusterDBs, 3) // Primary, reader, custom. testCases = append(testCases, append(rdsClusterDBs, rdsInstanceDB)...) @@ -121,12 +121,6 @@ func TestURLChecker_AWS(t *testing.T) { // Mock cloud clients. mockClients := &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance}, - DBClusters: []*rds.DBCluster{rdsCluster, docdbCluster}, - DBProxies: []*rds.DBProxy{rdsProxy}, - DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyCustomEndpoint}, - }, RedshiftServerless: &mocks.RedshiftServerlessMock{ Workgroups: []*redshiftserverless.Workgroup{redshiftServerlessWorkgroup}, Endpoints: []*redshiftserverless.EndpointAccess{redshiftServerlessVPCEndpoint}, @@ -143,7 +137,6 @@ func TestURLChecker_AWS(t *testing.T) { STS: &mocks.STSClientV1{}, } mockClientsUnauth := &cloud.TestCloudClients{ - RDS: &mocks.RDSMockUnauth{}, RedshiftServerless: &mocks.RedshiftServerlessMock{Unauth: true}, ElastiCache: &mocks.ElastiCacheMock{Unauth: true}, MemoryDB: &mocks.MemoryDBMock{Unauth: true}, @@ -158,21 +151,32 @@ func TestURLChecker_AWS(t *testing.T) { name string clients cloud.Clients awsConfigProvider awsconfig.Provider - redshiftClient redshiftClient + awsClients awsClientProvider }{ { name: "API check", clients: mockClients, awsConfigProvider: &mocks.AWSConfigProvider{}, - redshiftClient: &mocks.RedshiftClient{ - Clusters: []redshifttypes.Cluster{redshiftCluster}, + awsClients: fakeAWSClients{ + rdsClient: &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance}, + DBClusters: []rdstypes.DBCluster{*rdsCluster, *docdbCluster}, + DBProxies: []rdstypes.DBProxy{*rdsProxy}, + DBProxyEndpoints: []rdstypes.DBProxyEndpoint{*rdsProxyCustomEndpoint}, + }, + redshiftClient: &mocks.RedshiftClient{ + Clusters: []redshifttypes.Cluster{redshiftCluster}, + }, }, 
}, { name: "basic endpoint check", clients: mockClientsUnauth, awsConfigProvider: &mocks.AWSConfigProvider{}, - redshiftClient: &mocks.RedshiftClient{Unauth: true}, + awsClients: fakeAWSClients{ + rdsClient: &mocks.RDSClient{Unauth: true}, + redshiftClient: &mocks.RedshiftClient{Unauth: true}, + }, }, } @@ -183,7 +187,7 @@ func TestURLChecker_AWS(t *testing.T) { AWSConfigProvider: method.awsConfigProvider, Logger: utils.NewSlogLoggerForTests(), }) - c.redshiftClientProviderFn = newFakeRedshiftClientProvider(method.redshiftClient) + c.awsClients = method.awsClients for _, database := range testCases { t.Run(database.GetName(), func(t *testing.T) { diff --git a/lib/srv/db/common/auth.go b/lib/srv/db/common/auth.go index e567d82d402e0..ad7183e70563b 100644 --- a/lib/srv/db/common/auth.go +++ b/lib/srv/db/common/auth.go @@ -35,12 +35,12 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/aws/aws-sdk-go-v2/aws" + rdsauth "github.com/aws/aws-sdk-go-v2/feature/rds/auth" "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/aws/aws-sdk-go/aws/credentials" v4 "github.com/aws/aws-sdk-go/aws/signer/v4" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" - "github.com/aws/aws-sdk-go/service/rds/rdsutils" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -131,8 +131,16 @@ type redshiftClient interface { GetClusterCredentials(context.Context, *redshift.GetClusterCredentialsInput, ...func(*redshift.Options)) (*redshift.GetClusterCredentialsOutput, error) } -// redshiftClientProviderFunc provides a [redshiftClient]. -type redshiftClientProviderFunc func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient +// awsClientProvider is an AWS SDK client provider. 
+type awsClientProvider interface { + getRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient +} + +type defaultAWSClients struct{} + +func (defaultAWSClients) getRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient { + return redshift.NewFromConfig(cfg, optFns...) +} // AuthConfig is the database access authenticator configuration. type AuthConfig struct { @@ -149,10 +157,8 @@ type AuthConfig struct { // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. AWSConfigProvider awsconfig.Provider - // redshiftClientProviderFn is an internal-only [redshiftClient] provider - // func that defaults to a func that provides a real Redshift client. - // The default is only overridden in tests. - redshiftClientProviderFn redshiftClientProviderFunc + // awsClients is an SDK client provider. + awsClients awsClientProvider } // CheckAndSetDefaults validates the config and sets defaults. @@ -176,10 +182,8 @@ func (c *AuthConfig) CheckAndSetDefaults() error { c.Logger = slog.With(teleport.ComponentKey, "db:auth") } - if c.redshiftClientProviderFn == nil { - c.redshiftClientProviderFn = func(cfg aws.Config, optFns ...func(*redshift.Options)) redshiftClient { - return redshift.NewFromConfig(cfg, optFns...) - } + if c.awsClients == nil { + c.awsClients = defaultAWSClients{} } return nil } @@ -243,9 +247,9 @@ func (a *dbAuth) WithLogger(getUpdatedLogger func(*slog.Logger) *slog.Logger) Au // when connecting to RDS and Aurora databases. 
func (a *dbAuth) GetRDSAuthToken(ctx context.Context, database types.Database, databaseUser string) (string, error) { meta := database.GetAWS() - awsSession, err := a.cfg.Clients.GetAWSSession(ctx, meta.Region, - cloud.WithAssumeRoleFromAWSMeta(meta), - cloud.WithAmbientCredentials(), + awsCfg, err := a.cfg.AWSConfigProvider.GetConfig(ctx, meta.Region, + awsconfig.WithAssumeRole(meta.AssumeRoleARN, meta.ExternalID), + awsconfig.WithAmbientCredentials(), ) if err != nil { return "", trace.Wrap(err) @@ -254,11 +258,13 @@ func (a *dbAuth) GetRDSAuthToken(ctx context.Context, database types.Database, d "database", database, "database_user", databaseUser, ) - token, err := rdsutils.BuildAuthToken( + token, err := rdsauth.BuildAuthToken( + ctx, database.GetURI(), meta.Region, databaseUser, - awsSession.Config.Credentials) + awsCfg.Credentials, + ) if err != nil { policy, getPolicyErr := dbiam.GetReadableAWSPolicyDocument(database) if getPolicyErr != nil { @@ -316,7 +322,7 @@ Make sure that IAM role %q has a trust relationship with Teleport database agent "database_user", databaseUser, "database_name", databaseName, ) - client := a.cfg.redshiftClientProviderFn(awsCfg) + client := a.cfg.awsClients.getRedshiftClient(awsCfg) resp, err := client.GetClusterCredentialsWithIAM(ctx, &redshift.GetClusterCredentialsWithIAMInput{ ClusterIdentifier: aws.String(meta.Redshift.ClusterID), DbName: aws.String(databaseName), @@ -352,7 +358,7 @@ func (a *dbAuth) getRedshiftDBUserAuthToken(ctx context.Context, database types. 
"database_user", databaseUser, "database_name", databaseName, ) - clt := a.cfg.redshiftClientProviderFn(awsCfg) + clt := a.cfg.awsClients.getRedshiftClient(awsCfg) resp, err := clt.GetClusterCredentials(ctx, &redshift.GetClusterCredentialsInput{ ClusterIdentifier: aws.String(meta.Redshift.ClusterID), DbUser: aws.String(databaseUser), diff --git a/lib/srv/db/common/auth_test.go b/lib/srv/db/common/auth_test.go index 63d79af27e500..d85df87c5fd54 100644 --- a/lib/srv/db/common/auth_test.go +++ b/lib/srv/db/common/auth_test.go @@ -609,7 +609,6 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { AccessPoint: new(accessPointMock), Clients: &cloud.TestCloudClients{ STS: &fakeSTS.STSClientV1, - RDS: &mocks.RDSMock{}, RedshiftServerless: &mocks.RedshiftServerlessMock{ GetCredentialsOutput: mocks.RedshiftServerlessGetCredentialsOutput("IAM:some-user", "some-password", clock), }, @@ -617,10 +616,12 @@ func TestAuthGetAWSTokenWithAssumedRole(t *testing.T) { AWSConfigProvider: &mocks.AWSConfigProvider{ STSClient: fakeSTS, }, - redshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{ - GetClusterCredentialsOutput: mocks.RedshiftGetClusterCredentialsOutput("IAM:some-user", "some-password", clock), - GetClusterCredentialsWithIAMOutput: mocks.RedshiftGetClusterCredentialsWithIAMOutput("IAM:some-role", "some-password-for-some-role", clock), - }), + awsClients: fakeAWSClients{ + redshiftClient: &mocks.RedshiftClient{ + GetClusterCredentialsOutput: mocks.RedshiftGetClusterCredentialsOutput("IAM:some-user", "some-password", clock), + GetClusterCredentialsWithIAMOutput: mocks.RedshiftGetClusterCredentialsWithIAMOutput("IAM:some-role", "some-password-for-some-role", clock), + }, + }, }) require.NoError(t, err) @@ -1020,8 +1021,10 @@ func (m *imdsMock) GetType() types.InstanceMetadataType { return m.instanceType } -func newFakeRedshiftClientProvider(c redshiftClient) redshiftClientProviderFunc { - return func(cfg aws.Config, optFns ...func(*redshift.Options)) 
redshiftClient { - return c - } +type fakeAWSClients struct { + redshiftClient redshiftClient +} + +func (f fakeAWSClients) getRedshiftClient(aws.Config, ...func(*redshift.Options)) redshiftClient { + return f.redshiftClient } diff --git a/lib/srv/db/server.go b/lib/srv/db/server.go index 28fcc486bf4db..dfb1a4b164192 100644 --- a/lib/srv/db/server.go +++ b/lib/srv/db/server.go @@ -259,9 +259,10 @@ func (c *Config) CheckAndSetDefaults(ctx context.Context) (err error) { } if c.CloudIAM == nil { c.CloudIAM, err = cloud.NewIAM(ctx, cloud.IAMConfig{ - AccessPoint: c.AccessPoint, - Clients: c.CloudClients, - HostID: c.HostID, + AccessPoint: c.AccessPoint, + AWSConfigProvider: c.AWSConfigProvider, + Clients: c.CloudClients, + HostID: c.HostID, }) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/db/watcher_test.go b/lib/srv/db/watcher_test.go index 8a7750a26a07a..6020547ea9590 100644 --- a/lib/srv/db/watcher_test.go +++ b/lib/srv/db/watcher_test.go @@ -320,7 +320,6 @@ func TestWatcherCloudFetchers(t *testing.T) { reconcileCh <- d }, CloudClients: &clients.TestCloudClients{ - RDS: &mocks.RDSMockUnauth{}, // Access denied error should not affect other fetchers. 
RedshiftServerless: &mocks.RedshiftServerlessMock{ Workgroups: []*redshiftserverless.Workgroup{redshiftServerlessWorkgroup}, }, @@ -358,7 +357,7 @@ func assertReconciledResource(t *testing.T, ch chan types.Databases, databases t cmpopts.IgnoreFields(types.DatabaseStatusV3{}, "CACert"), )) case <-time.After(time.Second): - t.Fatal("Didn't receive reconcile event after 1s.") + require.FailNow(t, "Didn't receive reconcile event after 1s.") } } diff --git a/lib/srv/discovery/access_graph.go b/lib/srv/discovery/access_graph.go index 9d6d344ac9fda..f19d902068daf 100644 --- a/lib/srv/discovery/access_graph.go +++ b/lib/srv/discovery/access_graph.go @@ -501,6 +501,7 @@ func (s *Server) accessGraphFetchersFromMatchers(ctx context.Context, matchers M fetcher, err := aws_sync.NewAWSFetcher( ctx, aws_sync.Config{ + AWSConfigProvider: s.AWSConfigProvider, CloudClients: s.CloudClients, GetEKSClient: s.GetAWSSyncEKSClient, GetEC2Client: s.GetEC2Client, diff --git a/lib/srv/discovery/common/database.go b/lib/srv/discovery/common/database.go index 8afe335f87fcb..dcff7a2c0f614 100644 --- a/lib/srv/discovery/common/database.go +++ b/lib/srv/discovery/common/database.go @@ -35,7 +35,6 @@ import ( "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" "github.com/aws/aws-sdk-go/service/opensearchservice" - "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/gravitational/trace" @@ -286,7 +285,7 @@ func NewDatabaseFromAzurePostgresFlexServer(server *armpostgresqlflexibleservers } // NewDatabaseFromRDSInstance creates a database resource from an RDS instance. 
-func NewDatabaseFromRDSInstance(instance *rds.DBInstance) (types.Database, error) { +func NewDatabaseFromRDSInstance(instance *rdstypes.DBInstance) (types.Database, error) { endpoint := instance.Endpoint if endpoint == nil { return nil, trace.BadParameter("empty endpoint") @@ -307,7 +306,7 @@ func NewDatabaseFromRDSInstance(instance *rds.DBInstance) (types.Database, error }, aws.ToString(instance.DBInstanceIdentifier)), types.DatabaseSpecV3{ Protocol: protocol, - URI: fmt.Sprintf("%v:%v", aws.ToString(endpoint.Address), aws.ToInt64(endpoint.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(endpoint.Address), aws.ToInt32(endpoint.Port)), AWS: *metadata, }) } @@ -492,7 +491,7 @@ func labelsFromRDSV2Cluster(rdsCluster *rdstypes.DBCluster, meta *types.AWS, end } // NewDatabaseFromRDSCluster creates a database resource from an RDS cluster (Aurora). -func NewDatabaseFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Database, error) { +func NewDatabaseFromRDSCluster(cluster *rdstypes.DBCluster, memberInstances []rdstypes.DBInstance) (types.Database, error) { metadata, err := MetadataFromRDSCluster(cluster) if err != nil { return nil, trace.Wrap(err) @@ -508,13 +507,13 @@ func NewDatabaseFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DB }, aws.ToString(cluster.DBClusterIdentifier)), types.DatabaseSpecV3{ Protocol: protocol, - URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint), aws.ToInt64(cluster.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint), aws.ToInt32(cluster.Port)), AWS: *metadata, }) } // NewDatabaseFromRDSClusterReaderEndpoint creates a database resource from an RDS cluster reader endpoint (Aurora). 
-func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Database, error) { +func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rdstypes.DBCluster, memberInstances []rdstypes.DBInstance) (types.Database, error) { metadata, err := MetadataFromRDSCluster(cluster) if err != nil { return nil, trace.Wrap(err) @@ -530,13 +529,13 @@ func NewDatabaseFromRDSClusterReaderEndpoint(cluster *rds.DBCluster, memberInsta }, aws.ToString(cluster.DBClusterIdentifier), apiawsutils.RDSEndpointTypeReader), types.DatabaseSpecV3{ Protocol: protocol, - URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.ReaderEndpoint), aws.ToInt64(cluster.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.ReaderEndpoint), aws.ToInt32(cluster.Port)), AWS: *metadata, }) } // NewDatabasesFromRDSClusterCustomEndpoints creates database resources from RDS cluster custom endpoints (Aurora). -func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Databases, error) { +func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rdstypes.DBCluster, memberInstances []rdstypes.DBInstance) (types.Databases, error) { metadata, err := MetadataFromRDSCluster(cluster) if err != nil { return nil, trace.Wrap(err) @@ -551,7 +550,7 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberIns for _, endpoint := range cluster.CustomEndpoints { // RDS custom endpoint format: // .cluster-custom-. 
- endpointDetails, err := apiawsutils.ParseRDSEndpoint(aws.ToString(endpoint)) + endpointDetails, err := apiawsutils.ParseRDSEndpoint(endpoint) if err != nil { errors = append(errors, trace.Wrap(err)) continue @@ -568,7 +567,7 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberIns }, aws.ToString(cluster.DBClusterIdentifier), apiawsutils.RDSEndpointTypeCustom, endpointDetails.ClusterCustomEndpointName), types.DatabaseSpecV3{ Protocol: protocol, - URI: fmt.Sprintf("%v:%v", aws.ToString(endpoint), aws.ToInt64(cluster.Port)), + URI: fmt.Sprintf("%v:%v", endpoint, aws.ToInt32(cluster.Port)), AWS: *metadata, // Aurora instances update their certificates upon restart, and thus custom endpoint SAN may not be available right @@ -588,14 +587,12 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster, memberIns return databases, trace.NewAggregate(errors...) } -func checkRDSClusterMembers(cluster *rds.DBCluster) (hasWriterInstance, hasReaderInstance bool) { +func checkRDSClusterMembers(cluster *rdstypes.DBCluster) (hasWriterInstance, hasReaderInstance bool) { for _, clusterMember := range cluster.DBClusterMembers { - if clusterMember != nil { - if aws.ToBool(clusterMember.IsClusterWriter) { - hasWriterInstance = true - } else { - hasReaderInstance = true - } + if aws.ToBool(clusterMember.IsClusterWriter) { + hasWriterInstance = true + } else { + hasReaderInstance = true } } return @@ -603,7 +600,7 @@ func checkRDSClusterMembers(cluster *rds.DBCluster) (hasWriterInstance, hasReade // NewDatabasesFromRDSCluster creates all database resources from an RDS Aurora // cluster. 
-func NewDatabasesFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.DBInstance) (types.Databases, error) { +func NewDatabasesFromRDSCluster(cluster *rdstypes.DBCluster, memberInstances []rdstypes.DBInstance) (types.Databases, error) { var errors []error var databases types.Databases @@ -648,7 +645,7 @@ func NewDatabasesFromRDSCluster(cluster *rds.DBCluster, memberInstances []*rds.D // NewDatabasesFromDocumentDBCluster creates all database resources from a // DocumentDB cluster. -func NewDatabasesFromDocumentDBCluster(cluster *rds.DBCluster) (types.Databases, error) { +func NewDatabasesFromDocumentDBCluster(cluster *rdstypes.DBCluster) (types.Databases, error) { var errors []error var databases types.Databases @@ -682,7 +679,7 @@ func NewDatabasesFromDocumentDBCluster(cluster *rds.DBCluster) (types.Databases, // NewDatabaseFromDocumentDBClusterEndpoint creates database resource from // DocumentDB cluster endpoint. -func NewDatabaseFromDocumentDBClusterEndpoint(cluster *rds.DBCluster) (types.Database, error) { +func NewDatabaseFromDocumentDBClusterEndpoint(cluster *rdstypes.DBCluster) (types.Database, error) { endpointType := apiawsutils.DocumentDBClusterEndpoint metadata, err := MetadataFromDocumentDBCluster(cluster, endpointType) if err != nil { @@ -695,14 +692,14 @@ func NewDatabaseFromDocumentDBClusterEndpoint(cluster *rds.DBCluster) (types.Dat }, aws.ToString(cluster.DBClusterIdentifier)), types.DatabaseSpecV3{ Protocol: types.DatabaseProtocolMongoDB, - URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint), aws.ToInt64(cluster.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.Endpoint), aws.ToInt32(cluster.Port)), AWS: *metadata, }) } // NewDatabaseFromDocumentDBReaderEndpoint creates database resource from // DocumentDB reader endpoint. 
-func NewDatabaseFromDocumentDBReaderEndpoint(cluster *rds.DBCluster) (types.Database, error) { +func NewDatabaseFromDocumentDBReaderEndpoint(cluster *rdstypes.DBCluster) (types.Database, error) { endpointType := apiawsutils.DocumentDBClusterReaderEndpoint metadata, err := MetadataFromDocumentDBCluster(cluster, endpointType) if err != nil { @@ -715,13 +712,13 @@ func NewDatabaseFromDocumentDBReaderEndpoint(cluster *rds.DBCluster) (types.Data }, aws.ToString(cluster.DBClusterIdentifier), endpointType), types.DatabaseSpecV3{ Protocol: types.DatabaseProtocolMongoDB, - URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.ReaderEndpoint), aws.ToInt64(cluster.Port)), + URI: fmt.Sprintf("%v:%v", aws.ToString(cluster.ReaderEndpoint), aws.ToInt32(cluster.Port)), AWS: *metadata, }) } // NewDatabaseFromRDSProxy creates database resource from RDS Proxy. -func NewDatabaseFromRDSProxy(dbProxy *rds.DBProxy, tags []*rds.Tag) (types.Database, error) { +func NewDatabaseFromRDSProxy(dbProxy *rdstypes.DBProxy, tags []rdstypes.Tag) (types.Database, error) { metadata, err := MetadataFromRDSProxy(dbProxy) if err != nil { return nil, trace.Wrap(err) @@ -744,7 +741,7 @@ func NewDatabaseFromRDSProxy(dbProxy *rds.DBProxy, tags []*rds.Tag) (types.Datab // NewDatabaseFromRDSProxyCustomEndpoint creates database resource from RDS // Proxy custom endpoint. -func NewDatabaseFromRDSProxyCustomEndpoint(dbProxy *rds.DBProxy, customEndpoint *rds.DBProxyEndpoint, tags []*rds.Tag) (types.Database, error) { +func NewDatabaseFromRDSProxyCustomEndpoint(dbProxy *rdstypes.DBProxy, customEndpoint *rdstypes.DBProxyEndpoint, tags []rdstypes.Tag) (types.Database, error) { metadata, err := MetadataFromRDSProxyCustomEndpoint(dbProxy, customEndpoint) if err != nil { return nil, trace.Wrap(err) @@ -1045,7 +1042,7 @@ func NewDatabaseFromRedshiftServerlessVPCEndpoint(endpoint *redshiftserverless.E } // MetadataFromRDSInstance creates AWS metadata from the provided RDS instance. 
-func MetadataFromRDSInstance(rdsInstance *rds.DBInstance) (*types.AWS, error) { +func MetadataFromRDSInstance(rdsInstance *rdstypes.DBInstance) (*types.AWS, error) { parsedARN, err := arn.Parse(aws.ToString(rdsInstance.DBInstanceArn)) if err != nil { return nil, trace.Wrap(err) @@ -1063,7 +1060,7 @@ func MetadataFromRDSInstance(rdsInstance *rds.DBInstance) (*types.AWS, error) { } // MetadataFromRDSCluster creates AWS metadata from the provided RDS cluster. -func MetadataFromRDSCluster(rdsCluster *rds.DBCluster) (*types.AWS, error) { +func MetadataFromRDSCluster(rdsCluster *rdstypes.DBCluster) (*types.AWS, error) { parsedARN, err := arn.Parse(aws.ToString(rdsCluster.DBClusterArn)) if err != nil { return nil, trace.Wrap(err) @@ -1081,7 +1078,7 @@ func MetadataFromRDSCluster(rdsCluster *rds.DBCluster) (*types.AWS, error) { // MetadataFromDocumentDBCluster creates AWS metadata from the provided // DocumentDB cluster. -func MetadataFromDocumentDBCluster(cluster *rds.DBCluster, endpointType string) (*types.AWS, error) { +func MetadataFromDocumentDBCluster(cluster *rdstypes.DBCluster, endpointType string) (*types.AWS, error) { parsedARN, err := arn.Parse(aws.ToString(cluster.DBClusterArn)) if err != nil { return nil, trace.Wrap(err) @@ -1097,13 +1094,13 @@ func MetadataFromDocumentDBCluster(cluster *rds.DBCluster, endpointType string) } // MetadataFromRDSProxy creates AWS metadata from the provided RDS Proxy. -func MetadataFromRDSProxy(rdsProxy *rds.DBProxy) (*types.AWS, error) { +func MetadataFromRDSProxy(rdsProxy *rdstypes.DBProxy) (*types.AWS, error) { parsedARN, err := arn.Parse(aws.ToString(rdsProxy.DBProxyArn)) if err != nil { return nil, trace.Wrap(err) } - // rds.DBProxy has no resource ID attribute. The resource ID can be found + // rdstypes.DBProxy has no resource ID attribute. 
The resource ID can be found // in the ARN, e.g.: // // arn:aws:rds:ca-central-1:123456789012:db-proxy:prx-xxxyyyzzz @@ -1127,7 +1124,7 @@ func MetadataFromRDSProxy(rdsProxy *rds.DBProxy) (*types.AWS, error) { // MetadataFromRDSProxyCustomEndpoint creates AWS metadata from the provided // RDS Proxy custom endpoint. -func MetadataFromRDSProxyCustomEndpoint(rdsProxy *rds.DBProxy, customEndpoint *rds.DBProxyEndpoint) (*types.AWS, error) { +func MetadataFromRDSProxyCustomEndpoint(rdsProxy *rdstypes.DBProxy, customEndpoint *rdstypes.DBProxyEndpoint) (*types.AWS, error) { // Using resource ID from the default proxy for IAM policies to gain the // RDS connection access. metadata, err := MetadataFromRDSProxy(rdsProxy) @@ -1323,12 +1320,12 @@ func rdsEngineToProtocol(engine string) (string, error) { // rdsEngineFamilyToProtocolAndPort converts RDS engine family to the database protocol and port. func rdsEngineFamilyToProtocolAndPort(engineFamily string) (string, int, error) { - switch engineFamily { - case rds.EngineFamilyMysql: + switch rdstypes.EngineFamily(engineFamily) { + case rdstypes.EngineFamilyMysql: return defaults.ProtocolMySQL, services.RDSProxyMySQLPort, nil - case rds.EngineFamilyPostgresql: + case rdstypes.EngineFamilyPostgresql: return defaults.ProtocolPostgres, services.RDSProxyPostgresPort, nil - case rds.EngineFamilySqlserver: + case rdstypes.EngineFamilySqlserver: return defaults.ProtocolSQLServer, services.RDSProxySQLServerPort, nil } return "", 0, trace.BadParameter("unknown RDS engine family type %q", engineFamily) @@ -1421,7 +1418,7 @@ func labelsFromAzurePostgresFlexServer(server *armpostgresqlflexibleservers.Serv } // labelsFromRDSInstance creates database labels for the provided RDS instance. 
-func labelsFromRDSInstance(rdsInstance *rds.DBInstance, meta *types.AWS) map[string]string { +func labelsFromRDSInstance(rdsInstance *rdstypes.DBInstance, meta *types.AWS) map[string]string { labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelEngine] = aws.ToString(rdsInstance.Engine) labels[types.DiscoveryLabelEngineVersion] = aws.ToString(rdsInstance.EngineVersion) @@ -1433,7 +1430,7 @@ func labelsFromRDSInstance(rdsInstance *rds.DBInstance, meta *types.AWS) map[str } // labelsFromRDSCluster creates database labels for the provided RDS cluster. -func labelsFromRDSCluster(rdsCluster *rds.DBCluster, meta *types.AWS, endpointType string, memberInstances []*rds.DBInstance) map[string]string { +func labelsFromRDSCluster(rdsCluster *rdstypes.DBCluster, meta *types.AWS, endpointType string, memberInstances []rdstypes.DBInstance) map[string]string { labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelEngine] = aws.ToString(rdsCluster.Engine) labels[types.DiscoveryLabelEngineVersion] = aws.ToString(rdsCluster.EngineVersion) @@ -1444,7 +1441,7 @@ func labelsFromRDSCluster(rdsCluster *rds.DBCluster, meta *types.AWS, endpointTy return addLabels(labels, libcloudaws.TagsToLabels(rdsCluster.TagList)) } -func labelsFromDocumentDBCluster(cluster *rds.DBCluster, meta *types.AWS, endpointType string) map[string]string { +func labelsFromDocumentDBCluster(cluster *rdstypes.DBCluster, meta *types.AWS, endpointType string) map[string]string { labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelEngine] = aws.ToString(cluster.Engine) labels[types.DiscoveryLabelEngineVersion] = aws.ToString(cluster.EngineVersion) @@ -1453,8 +1450,8 @@ func labelsFromDocumentDBCluster(cluster *rds.DBCluster, meta *types.AWS, endpoi } // labelsFromRDSProxy creates database labels for the provided RDS Proxy. -func labelsFromRDSProxy(rdsProxy *rds.DBProxy, meta *types.AWS, tags []*rds.Tag) map[string]string { - // rds.DBProxy has no TagList. 
+func labelsFromRDSProxy(rdsProxy *rdstypes.DBProxy, meta *types.AWS, tags []rdstypes.Tag) map[string]string { + // rdstypes.DBProxy has no TagList. labels := labelsFromAWSMetadata(meta) labels[types.DiscoveryLabelVPCID] = aws.ToString(rdsProxy.VpcId) labels[types.DiscoveryLabelEngine] = aws.ToString(rdsProxy.EngineFamily) @@ -1463,9 +1460,9 @@ func labelsFromRDSProxy(rdsProxy *rds.DBProxy, meta *types.AWS, tags []*rds.Tag) // labelsFromRDSProxyCustomEndpoint creates database labels for the provided // RDS Proxy custom endpoint. -func labelsFromRDSProxyCustomEndpoint(rdsProxy *rds.DBProxy, customEndpoint *rds.DBProxyEndpoint, meta *types.AWS, tags []*rds.Tag) map[string]string { +func labelsFromRDSProxyCustomEndpoint(rdsProxy *rdstypes.DBProxy, customEndpoint *rdstypes.DBProxyEndpoint, meta *types.AWS, tags []rdstypes.Tag) map[string]string { labels := labelsFromRDSProxy(rdsProxy, meta, tags) - labels[types.DiscoveryLabelEndpointType] = aws.ToString(customEndpoint.TargetRole) + labels[types.DiscoveryLabelEndpointType] = string(customEndpoint.TargetRole) return labels } diff --git a/lib/srv/discovery/common/database_test.go b/lib/srv/discovery/common/database_test.go index ab2b45fff24bc..891c31a18bc13 100644 --- a/lib/srv/discovery/common/database_test.go +++ b/lib/srv/discovery/common/database_test.go @@ -28,11 +28,10 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql" "github.com/aws/aws-sdk-go-v2/aws" - rdsTypesV2 "github.com/aws/aws-sdk-go-v2/service/rds/types" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/memorydb" - "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/google/go-cmp/cmp" "github.com/google/uuid" @@ -217,7 +216,7 
@@ func TestDatabaseFromAzureRedisEnterprise(t *testing.T) { // TestDatabaseFromRDSInstance tests converting an RDS instance to a database resource. func TestDatabaseFromRDSInstance(t *testing.T) { - instance := &rds.DBInstance{ + instance := &rdstypes.DBInstance{ DBInstanceArn: aws.String("arn:aws:rds:us-west-1:123456789012:db:instance-1"), DBInstanceIdentifier: aws.String("instance-1"), DBClusterIdentifier: aws.String("cluster-1"), @@ -225,11 +224,11 @@ func TestDatabaseFromRDSInstance(t *testing.T) { IAMDatabaseAuthenticationEnabled: aws.Bool(true), Engine: aws.String(services.RDSEnginePostgres), EngineVersion: aws.String("13.0"), - Endpoint: &rds.Endpoint{ + Endpoint: &rdstypes.Endpoint{ Address: aws.String("localhost"), - Port: aws.Int64(5432), + Port: aws.Int32(5432), }, - TagList: []*rds.Tag{{ + TagList: []rdstypes.Tag{{ Key: aws.String("key"), Value: aws.String("val"), }}, @@ -268,7 +267,7 @@ func TestDatabaseFromRDSInstance(t *testing.T) { // TestDatabaseFromRDSV2Instance tests converting an RDS instance (from aws sdk v2/rds) to a database resource. 
func TestDatabaseFromRDSV2Instance(t *testing.T) { - instance := &rdsTypesV2.DBInstance{ + instance := &rdstypes.DBInstance{ DBInstanceArn: aws.String("arn:aws:rds:us-west-1:123456789012:db:instance-1"), DBInstanceIdentifier: aws.String("instance-1"), DBClusterIdentifier: aws.String("cluster-1"), @@ -277,16 +276,16 @@ func TestDatabaseFromRDSV2Instance(t *testing.T) { IAMDatabaseAuthenticationEnabled: aws.Bool(true), Engine: aws.String(services.RDSEnginePostgres), EngineVersion: aws.String("13.0"), - Endpoint: &rdsTypesV2.Endpoint{ + Endpoint: &rdstypes.Endpoint{ Address: aws.String("localhost"), Port: aws.Int32(5432), }, - TagList: []rdsTypesV2.Tag{{ + TagList: []rdstypes.Tag{{ Key: aws.String("key"), Value: aws.String("val"), }}, - DBSubnetGroup: &rdsTypesV2.DBSubnetGroup{ - Subnets: []rdsTypesV2.Subnet{ + DBSubnetGroup: &rdstypes.DBSubnetGroup{ + Subnets: []rdstypes.Subnet{ {SubnetIdentifier: aws.String("")}, {SubnetIdentifier: aws.String("subnet-1234567890abcdef0")}, {SubnetIdentifier: aws.String("subnet-1234567890abcdef1")}, @@ -294,7 +293,7 @@ func TestDatabaseFromRDSV2Instance(t *testing.T) { }, VpcId: aws.String("vpc-asd"), }, - VpcSecurityGroups: []rdsTypesV2.VpcSecurityGroupMembership{ + VpcSecurityGroups: []rdstypes.VpcSecurityGroupMembership{ {VpcSecurityGroupId: aws.String("")}, {VpcSecurityGroupId: aws.String("sg-1")}, {VpcSecurityGroupId: aws.String("sg-2")}, @@ -348,7 +347,7 @@ func TestDatabaseFromRDSV2Instance(t *testing.T) { newName := "override-1" instance := instance instance.TagList = append(instance.TagList, - rdsTypesV2.Tag{ + rdstypes.Tag{ Key: aws.String(overrideLabel), Value: aws.String(newName), }, @@ -365,7 +364,7 @@ func TestDatabaseFromRDSV2Instance(t *testing.T) { // TestDatabaseFromRDSInstance tests converting an RDS instance to a database resource. 
func TestDatabaseFromRDSInstanceNameOverride(t *testing.T) { for _, overrideLabel := range types.AWSDatabaseNameOverrideLabels { - instance := &rds.DBInstance{ + instance := &rdstypes.DBInstance{ DBInstanceArn: aws.String("arn:aws:rds:us-west-1:123456789012:db:instance-1"), DBInstanceIdentifier: aws.String("instance-1"), DBClusterIdentifier: aws.String("cluster-1"), @@ -373,11 +372,11 @@ func TestDatabaseFromRDSInstanceNameOverride(t *testing.T) { IAMDatabaseAuthenticationEnabled: aws.Bool(true), Engine: aws.String(services.RDSEnginePostgres), EngineVersion: aws.String("13.0"), - Endpoint: &rds.Endpoint{ + Endpoint: &rdstypes.Endpoint{ Address: aws.String("localhost"), - Port: aws.Int64(5432), + Port: aws.Int32(5432), }, - TagList: []*rds.Tag{ + TagList: []rdstypes.Tag{ {Key: aws.String("key"), Value: aws.String("val")}, {Key: aws.String(overrideLabel), Value: aws.String("override-1")}, }, @@ -421,8 +420,8 @@ func TestDatabaseFromRDSInstanceNameOverride(t *testing.T) { // TestDatabaseFromRDSCluster tests converting an RDS cluster to a database resource. 
func TestDatabaseFromRDSCluster(t *testing.T) { vpcid := uuid.NewString() - dbInstanceMembers := []*rds.DBInstance{{DBSubnetGroup: &rds.DBSubnetGroup{VpcId: aws.String(vpcid)}}} - cluster := &rds.DBCluster{ + dbInstanceMembers := []rdstypes.DBInstance{{DBSubnetGroup: &rdstypes.DBSubnetGroup{VpcId: aws.String(vpcid)}}} + cluster := &rdstypes.DBCluster{ DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:cluster-1"), DBClusterIdentifier: aws.String("cluster-1"), DbClusterResourceId: aws.String("resource-1"), @@ -431,12 +430,12 @@ func TestDatabaseFromRDSCluster(t *testing.T) { EngineVersion: aws.String("8.0.0"), Endpoint: aws.String("localhost"), ReaderEndpoint: aws.String("reader.host"), - Port: aws.Int64(3306), - CustomEndpoints: []*string{ - aws.String("myendpoint1.cluster-custom-example.us-east-1.rds.amazonaws.com"), - aws.String("myendpoint2.cluster-custom-example.us-east-1.rds.amazonaws.com"), + Port: aws.Int32(3306), + CustomEndpoints: []string{ + "myendpoint1.cluster-custom-example.us-east-1.rds.amazonaws.com", + "myendpoint2.cluster-custom-example.us-east-1.rds.amazonaws.com", }, - TagList: []*rds.Tag{{ + TagList: []rdstypes.Tag{{ Key: aws.String("key"), Value: aws.String("val"), }}, @@ -549,9 +548,9 @@ func TestDatabaseFromRDSCluster(t *testing.T) { t.Run("bad custom endpoints ", func(t *testing.T) { badCluster := *cluster - badCluster.CustomEndpoints = []*string{ - aws.String("badendpoint1"), - aws.String("badendpoint2"), + badCluster.CustomEndpoints = []string{ + "badendpoint1", + "badendpoint2", } _, err := NewDatabasesFromRDSClusterCustomEndpoints(&badCluster, dbInstanceMembers) require.Error(t, err) @@ -561,7 +560,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) { // TestDatabaseFromRDSV2Cluster tests converting an RDS cluster to a database resource. // It uses the V2 of the aws sdk. 
func TestDatabaseFromRDSV2Cluster(t *testing.T) { - cluster := &rdsTypesV2.DBCluster{ + cluster := &rdstypes.DBCluster{ DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:cluster-1"), DBClusterIdentifier: aws.String("cluster-1"), DbClusterResourceId: aws.String("resource-1"), @@ -572,7 +571,7 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) { Endpoint: aws.String("localhost"), ReaderEndpoint: aws.String("reader.host"), Port: aws.Int32(3306), - VpcSecurityGroups: []rdsTypesV2.VpcSecurityGroupMembership{ + VpcSecurityGroups: []rdstypes.VpcSecurityGroupMembership{ {VpcSecurityGroupId: aws.String("")}, {VpcSecurityGroupId: aws.String("sg-1")}, {VpcSecurityGroupId: aws.String("sg-2")}, @@ -581,7 +580,7 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) { "myendpoint1.cluster-custom-example.us-east-1.rds.amazonaws.com", "myendpoint2.cluster-custom-example.us-east-1.rds.amazonaws.com", }, - TagList: []rdsTypesV2.Tag{{ + TagList: []rdstypes.Tag{{ Key: aws.String("key"), Value: aws.String("val"), }}, @@ -630,7 +629,7 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) { newName := "override-1" cluster.TagList = append(cluster.TagList, - rdsTypesV2.Tag{ + rdstypes.Tag{ Key: aws.String(overrideLabel), Value: aws.String(newName), }, @@ -645,10 +644,10 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) { }) t.Run("DB Cluster uses network information from DB Instance when available", func(t *testing.T) { - instance := &rdsTypesV2.DBInstance{ - DBSubnetGroup: &rdsTypesV2.DBSubnetGroup{ + instance := &rdstypes.DBInstance{ + DBSubnetGroup: &rdstypes.DBSubnetGroup{ VpcId: aws.String("vpc-123"), - Subnets: []rdsTypesV2.Subnet{ + Subnets: []rdstypes.Subnet{ {SubnetIdentifier: aws.String("subnet-123")}, {SubnetIdentifier: aws.String("subnet-456")}, }, @@ -699,9 +698,9 @@ func TestDatabaseFromRDSV2Cluster(t *testing.T) { // TestDatabaseFromRDSClusterNameOverride tests converting an RDS cluster to a database resource with overridden name. 
func TestDatabaseFromRDSClusterNameOverride(t *testing.T) { - dbInstanceMembers := []*rds.DBInstance{{DBSubnetGroup: &rds.DBSubnetGroup{VpcId: aws.String("vpc-123")}}} + dbInstanceMembers := []rdstypes.DBInstance{{DBSubnetGroup: &rdstypes.DBSubnetGroup{VpcId: aws.String("vpc-123")}}} for _, overrideLabel := range types.AWSDatabaseNameOverrideLabels { - cluster := &rds.DBCluster{ + cluster := &rdstypes.DBCluster{ DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:cluster-1"), DBClusterIdentifier: aws.String("cluster-1"), DbClusterResourceId: aws.String("resource-1"), @@ -710,12 +709,12 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) { EngineVersion: aws.String("8.0.0"), Endpoint: aws.String("localhost"), ReaderEndpoint: aws.String("reader.host"), - Port: aws.Int64(3306), - CustomEndpoints: []*string{ - aws.String("myendpoint1.cluster-custom-example.us-east-1.rds.amazonaws.com"), - aws.String("myendpoint2.cluster-custom-example.us-east-1.rds.amazonaws.com"), + Port: aws.Int32(3306), + CustomEndpoints: []string{ + "myendpoint1.cluster-custom-example.us-east-1.rds.amazonaws.com", + "myendpoint2.cluster-custom-example.us-east-1.rds.amazonaws.com", }, - TagList: []*rds.Tag{ + TagList: []rdstypes.Tag{ {Key: aws.String("key"), Value: aws.String("val")}, {Key: aws.String(overrideLabel), Value: aws.String("mycluster-2")}, }, @@ -831,9 +830,9 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) { t.Run("bad custom endpoints ", func(t *testing.T) { badCluster := *cluster - badCluster.CustomEndpoints = []*string{ - aws.String("badendpoint1"), - aws.String("badendpoint2"), + badCluster.CustomEndpoints = []string{ + "badendpoint1", + "badendpoint2", } _, err := NewDatabasesFromRDSClusterCustomEndpoints(&badCluster, dbInstanceMembers) require.Error(t, err) @@ -896,7 +895,7 @@ func TestNewDatabasesFromDocumentDBCluster(t *testing.T) { tests := []struct { name string - inputCluster *rds.DBCluster + inputCluster *rdstypes.DBCluster wantDatabases 
types.Databases }{ { @@ -929,26 +928,26 @@ func TestDatabaseFromRDSProxy(t *testing.T) { }{ { desc: "mysql", - engineFamily: rds.EngineFamilyMysql, + engineFamily: string(rdstypes.EngineFamilyMysql), wantProtocol: "mysql", wantPort: 3306, }, { desc: "postgres", - engineFamily: rds.EngineFamilyPostgresql, + engineFamily: string(rdstypes.EngineFamilyPostgresql), wantProtocol: "postgres", wantPort: 5432, }, { desc: "sqlserver", - engineFamily: rds.EngineFamilySqlserver, + engineFamily: string(rdstypes.EngineFamilySqlserver), wantProtocol: "sqlserver", wantPort: 1433, }, } for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - dbProxy := &rds.DBProxy{ + dbProxy := &rdstypes.DBProxy{ DBProxyArn: aws.String("arn:aws:rds:ca-central-1:123456789012:db-proxy:prx-abcdef"), DBProxyName: aws.String("testproxy"), EngineFamily: aws.String(test.engineFamily), @@ -956,15 +955,15 @@ func TestDatabaseFromRDSProxy(t *testing.T) { VpcId: aws.String("test-vpc-id"), } - dbProxyEndpoint := &rds.DBProxyEndpoint{ + dbProxyEndpoint := &rdstypes.DBProxyEndpoint{ Endpoint: aws.String("custom.proxy.rds.test"), DBProxyEndpointName: aws.String("custom"), DBProxyName: aws.String("testproxy"), DBProxyEndpointArn: aws.String("arn:aws:rds:ca-central-1:123456789012:db-proxy-endpoint:prx-endpoint-abcdef"), - TargetRole: aws.String(rds.DBProxyEndpointTargetRoleReadOnly), + TargetRole: rdstypes.DBProxyEndpointTargetRoleReadOnly, } - tags := []*rds.Tag{{ + tags := []rdstypes.Tag{{ Key: aws.String("key"), Value: aws.String("val"), }} @@ -1059,7 +1058,7 @@ func TestAuroraMySQLVersion(t *testing.T) { } for _, test := range tests { t.Run(test.engineVersion, func(t *testing.T) { - require.Equal(t, test.expectedMySQLVersion, libcloudaws.AuroraMySQLVersion(&rds.DBCluster{EngineVersion: aws.String(test.engineVersion)})) + require.Equal(t, test.expectedMySQLVersion, libcloudaws.AuroraMySQLVersion(&rdstypes.DBCluster{EngineVersion: aws.String(test.engineVersion)})) }) } } @@ -1099,7 +1098,7 @@ func 
TestIsRDSClusterSupported(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - cluster := &rds.DBCluster{ + cluster := &rdstypes.DBCluster{ DBClusterArn: aws.String("arn:aws:rds:us-east-1:123456789012:cluster:test"), DBClusterIdentifier: aws.String(test.name), DbClusterResourceId: aws.String(uuid.New().String()), @@ -1149,7 +1148,7 @@ func TestIsRDSInstanceSupported(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - cluster := &rds.DBInstance{ + cluster := &rdstypes.DBInstance{ DBInstanceArn: aws.String("arn:aws:rds:us-east-1:123456789012:instance:test"), DBClusterIdentifier: aws.String(test.name), DbiResourceId: aws.String(uuid.New().String()), diff --git a/lib/srv/discovery/common/kubernetes_test.go b/lib/srv/discovery/common/kubernetes_test.go index 868f9dfac9370..bd69eccaa4676 100644 --- a/lib/srv/discovery/common/kubernetes_test.go +++ b/lib/srv/discovery/common/kubernetes_test.go @@ -98,9 +98,8 @@ func TestNewKubeClusterFromAWSEKS(t *testing.T) { require.NoError(t, err) cluster := &ekstypes.Cluster{ - Name: aws.String("cluster1"), - Arn: aws.String("arn:aws:eks:eu-west-1:123456789012:cluster/cluster1"), - Status: ekstypes.ClusterStatusActive, + Name: aws.String("cluster1"), + Arn: aws.String("arn:aws:eks:eu-west-1:123456789012:cluster/cluster1"), Tags: map[string]string{ overrideLabel: "override-1", "env": "prod", diff --git a/lib/srv/discovery/common/renaming_test.go b/lib/srv/discovery/common/renaming_test.go index 5be2c13f3b3c4..7bb64f9f01bab 100644 --- a/lib/srv/discovery/common/renaming_test.go +++ b/lib/srv/discovery/common/renaming_test.go @@ -28,14 +28,14 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redis/armredis/v3" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/redisenterprise/armredisenterprise" ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go/aws" - 
"github.com/aws/aws-sdk-go/service/rds" "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" azureutils "github.com/gravitational/teleport/api/utils/azure" - libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awstesthelpers" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/cloud/gcp" "github.com/gravitational/teleport/lib/services" @@ -365,7 +365,7 @@ func requireOverrideLabelSkipsRenaming(t *testing.T, r types.ResourceWithLabels, func makeAuroraPrimaryDB(t *testing.T, name, region, accountID, overrideLabel string) types.Database { t.Helper() - cluster := &rds.DBCluster{ + cluster := &rdstypes.DBCluster{ DBClusterArn: aws.String(fmt.Sprintf("arn:aws:rds:%s:%s:cluster:%v", region, accountID, name)), DBClusterIdentifier: aws.String("cluster-1"), DbClusterResourceId: aws.String("resource-1"), @@ -373,29 +373,29 @@ func makeAuroraPrimaryDB(t *testing.T, name, region, accountID, overrideLabel st Engine: aws.String("aurora-mysql"), EngineVersion: aws.String("8.0.0"), Endpoint: aws.String("localhost"), - Port: aws.Int64(3306), - TagList: libcloudaws.LabelsToTags[rds.Tag](map[string]string{ + Port: aws.Int32(3306), + TagList: awstesthelpers.LabelsToRDSTags(map[string]string{ overrideLabel: name, }), } - database, err := NewDatabaseFromRDSCluster(cluster, []*rds.DBInstance{}) + database, err := NewDatabaseFromRDSCluster(cluster, []rdstypes.DBInstance{}) require.NoError(t, err) return database } func makeRDSInstanceDB(t *testing.T, name, region, accountID, overrideLabel string) types.Database { t.Helper() - instance := &rds.DBInstance{ + instance := &rdstypes.DBInstance{ DBInstanceArn: aws.String(fmt.Sprintf("arn:aws:rds:%s:%s:db:%v", region, accountID, name)), DBInstanceIdentifier: aws.String(name), DbiResourceId: aws.String(uuid.New().String()), Engine: aws.String(services.RDSEnginePostgres), DBInstanceStatus: 
aws.String("available"), - Endpoint: &rds.Endpoint{ + Endpoint: &rdstypes.Endpoint{ Address: aws.String("localhost"), - Port: aws.Int64(5432), + Port: aws.Int32(5432), }, - TagList: libcloudaws.LabelsToTags[rds.Tag](map[string]string{ + TagList: awstesthelpers.LabelsToRDSTags(map[string]string{ overrideLabel: name, }), } @@ -499,9 +499,8 @@ func labelsToAzureTags(labels map[string]string) map[string]*string { func makeEKSKubeCluster(t *testing.T, name, region, accountID, overrideLabel string) types.KubeCluster { t.Helper() eksCluster := &ekstypes.Cluster{ - Name: aws.String(name), - Arn: aws.String(fmt.Sprintf("arn:aws:eks:%s:%s:cluster/%s", region, accountID, name)), - Status: ekstypes.ClusterStatusActive, + Name: aws.String(name), + Arn: aws.String(fmt.Sprintf("arn:aws:eks:%s:%s:cluster/%s", region, accountID, name)), Tags: map[string]string{ overrideLabel: name, }, diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 3eea560f67174..2948e10cdb916 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -41,11 +41,12 @@ import ( ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/aws/aws-sdk-go-v2/service/eks" ekstypes "github.com/aws/aws-sdk-go-v2/service/eks/types" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go-v2/service/redshift" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/aws/aws-sdk-go-v2/service/ssm" ssmtypes "github.com/aws/aws-sdk-go-v2/service/ssm/types" - "github.com/aws/aws-sdk-go/service/rds" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" @@ -2063,13 +2064,7 @@ func TestDiscoveryDatabase(t *testing.T) { } testCloudClients := &cloud.TestCloudClients{ - STS: &mocks.STSClientV1{}, - RDS: &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{awsRDSInstance}, - DBEngineVersions: []*rds.DBEngineVersion{ 
- {Engine: aws.String(services.RDSEnginePostgres)}, - }, - }, + STS: &mocks.STSClientV1{}, MemoryDB: &mocks.MemoryDBMock{}, AzureRedis: azure.NewRedisClientByAPI(&azure.ARMRedisMock{ Servers: []*armredis.ResourceInfo{azRedisResource}, @@ -2407,9 +2402,17 @@ func TestDiscoveryDatabase(t *testing.T) { dbFetcherFactory, err := db.NewAWSFetcherFactory(db.AWSFetcherFactoryConfig{ AWSConfigProvider: fakeConfigProvider, CloudClients: testCloudClients, - RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{ - Clusters: []redshifttypes.Cluster{*awsRedshiftResource}, - }), + AWSClients: fakeAWSClients{ + rdsClient: &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*awsRDSInstance}, + DBEngineVersions: []rdstypes.DBEngineVersion{ + {Engine: aws.String(services.RDSEnginePostgres)}, + }, + }, + redshiftClient: &mocks.RedshiftClient{ + Clusters: []redshifttypes.Cluster{*awsRedshiftResource}, + }, + }, }) require.NoError(t, err) @@ -2503,15 +2506,25 @@ func TestDiscoveryDatabaseRemovingDiscoveryConfigs(t *testing.T) { awsRDSInstance, awsRDSDB := makeRDSInstance(t, "aws-rds", "us-west-1", rewriteDiscoveryLabelsParams{discoveryConfigName: dc2Name, discoveryGroup: mainDiscoveryGroup}) + fakeConfigProvider := &mocks.AWSConfigProvider{ + STSClient: &mocks.STSClient{}, + } testCloudClients := &cloud.TestCloudClients{ - STS: &mocks.STSClientV1{}, - RDS: &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{awsRDSInstance}, - DBEngineVersions: []*rds.DBEngineVersion{ - {Engine: aws.String(services.RDSEnginePostgres)}, + STS: &fakeConfigProvider.STSClient.STSClientV1, + } + dbFetcherFactory, err := db.NewAWSFetcherFactory(db.AWSFetcherFactoryConfig{ + AWSConfigProvider: fakeConfigProvider, + CloudClients: testCloudClients, + AWSClients: fakeAWSClients{ + rdsClient: &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*awsRDSInstance}, + DBEngineVersions: []rdstypes.DBEngineVersion{ + {Engine: aws.String(services.RDSEnginePostgres)}, + }, }, }, - } + }) + 
require.NoError(t, err) ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) @@ -2539,14 +2552,16 @@ func TestDiscoveryDatabaseRemovingDiscoveryConfigs(t *testing.T) { srv, err := New( authz.ContextWithUser(ctx, identity.I), &Config{ - CloudClients: testCloudClients, - ClusterFeatures: func() proto.Features { return proto.Features{} }, - KubernetesClient: fake.NewSimpleClientset(), - AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), - Matchers: Matchers{}, - Emitter: authClient, - DiscoveryGroup: mainDiscoveryGroup, - clock: clock, + AWSConfigProvider: fakeConfigProvider, + AWSDatabaseFetcherFactory: dbFetcherFactory, + CloudClients: testCloudClients, + ClusterFeatures: func() proto.Features { return proto.Features{} }, + KubernetesClient: fake.NewSimpleClientset(), + AccessPoint: getDiscoveryAccessPoint(tlsServer.Auth(), authClient), + Matchers: Matchers{}, + Emitter: authClient, + DiscoveryGroup: mainDiscoveryGroup, + clock: clock, }) require.NoError(t, err) @@ -2669,16 +2684,16 @@ func makeEKSCluster(t *testing.T, name, region string, discoveryParams rewriteDi return eksAWSCluster, actual } -func makeRDSInstance(t *testing.T, name, region string, discoveryParams rewriteDiscoveryLabelsParams) (*rds.DBInstance, types.Database) { - instance := &rds.DBInstance{ +func makeRDSInstance(t *testing.T, name, region string, discoveryParams rewriteDiscoveryLabelsParams) (*rdstypes.DBInstance, types.Database) { + instance := &rdstypes.DBInstance{ DBInstanceArn: aws.String(fmt.Sprintf("arn:aws:rds:%v:123456789012:db:%v", region, name)), DBInstanceIdentifier: aws.String(name), DbiResourceId: aws.String(uuid.New().String()), Engine: aws.String(services.RDSEnginePostgres), DBInstanceStatus: aws.String("available"), - Endpoint: &rds.Endpoint{ + Endpoint: &rdstypes.Endpoint{ Address: aws.String("localhost"), - Port: aws.Int64(5432), + Port: aws.Int32(5432), }, } database, err := common.NewDatabaseFromRDSInstance(instance) @@ -3748,8 +3763,15 
@@ func newPopulatedGCPProjectsMock() *mockProjectsAPI { } } -func newFakeRedshiftClientProvider(c redshift.DescribeClustersAPIClient) db.RedshiftClientProviderFunc { - return func(cfg aws.Config, optFns ...func(*redshift.Options)) db.RedshiftClient { - return c - } +type fakeAWSClients struct { + rdsClient db.RDSClient + redshiftClient db.RedshiftClient +} + +func (f fakeAWSClients) GetRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) db.RDSClient { + return f.rdsClient +} + +func (f fakeAWSClients) GetRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) db.RedshiftClient { + return f.redshiftClient } diff --git a/lib/srv/discovery/fetchers/aws-sync/aws-sync.go b/lib/srv/discovery/fetchers/aws-sync/aws-sync.go index adc450ece9fbc..a65742fc38856 100644 --- a/lib/srv/discovery/fetchers/aws-sync/aws-sync.go +++ b/lib/srv/discovery/fetchers/aws-sync/aws-sync.go @@ -24,9 +24,9 @@ import ( "sync" "time" - awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/retry" - "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go/aws/client" "github.com/aws/aws-sdk-go/service/sts" "github.com/gravitational/trace" @@ -45,6 +45,8 @@ const pageSize int64 = 500 // Config is the configuration for the AWS fetcher. type Config struct { + // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. + AWSConfigProvider awsconfig.Provider // CloudClients is the cloud clients to use when fetching AWS resources. CloudClients cloud.Clients // GetEKSClient gets an AWS EKS client for the given region. @@ -61,6 +63,32 @@ type Config struct { Integration string // DiscoveryConfigName if set, will be used to report the Discovery Config Status to the Auth Server. DiscoveryConfigName string + + // awsClients provides AWS SDK clients. 
+ awsClients awsClientProvider +} + +func (c *Config) CheckAndSetDefaults() error { + if c.AWSConfigProvider == nil { + return trace.BadParameter("missing AWSConfigProvider") + } + + if c.awsClients == nil { + c.awsClients = defaultAWSClients{} + } + return nil +} + +// awsClientProvider provides AWS service API clients. +type awsClientProvider interface { + // getRDSClient provides an [RDSClient]. + getRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient +} + +type defaultAWSClients struct{} + +func (defaultAWSClients) getRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient { + return rds.NewFromConfig(cfg, optFns...) } // AssumeRole is the configuration for assuming an AWS role. @@ -184,6 +212,9 @@ func (r *Resources) UsageReport(numberAccounts int) *usageeventsv1.AccessGraphAW // NewAWSFetcher creates a new AWS fetcher. func NewAWSFetcher(ctx context.Context, cfg Config) (AWSSync, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } a := &awsFetcher{ Config: cfg, lastResult: &Resources{}, @@ -337,7 +368,7 @@ func (a *awsFetcher) getAWSV2Options() []awsconfig.OptionsFn { opts = append(opts, awsconfig.WithAssumeRole(a.Config.AssumeRole.RoleARN, a.Config.AssumeRole.ExternalID)) } const maxRetries = 10 - opts = append(opts, awsconfig.WithRetryer(func() awsv2.Retryer { + opts = append(opts, awsconfig.WithRetryer(func() aws.Retryer { return retry.NewStandard(func(so *retry.StandardOptions) { so.MaxAttempts = maxRetries so.Backoff = retry.NewExponentialJitterBackoff(300 * time.Second) @@ -363,7 +394,7 @@ func (a *awsFetcher) getAccountId(ctx context.Context) (string, error) { return "", trace.Wrap(err) } - return aws.StringValue(req.Account), nil + return aws.ToString(req.Account), nil } func (a *awsFetcher) DiscoveryConfigName() string { diff --git a/lib/srv/discovery/fetchers/aws-sync/rds.go b/lib/srv/discovery/fetchers/aws-sync/rds.go index 08195e2132e82..f163c49f6b6d3 100644 --- 
a/lib/srv/discovery/fetchers/aws-sync/rds.go +++ b/lib/srv/discovery/fetchers/aws-sync/rds.go @@ -22,8 +22,9 @@ import ( "context" "sync" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/gravitational/trace" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/types/known/timestamppb" @@ -31,12 +32,18 @@ import ( accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" ) +// rdsClient defines a subset of the AWS RDS client API. +type rdsClient interface { + rds.DescribeDBClustersAPIClient + rds.DescribeDBInstancesAPIClient +} + // pollAWSRDSDatabases is a function that returns a function that fetches // RDS instances and clusters. func (a *awsFetcher) pollAWSRDSDatabases(ctx context.Context, result *Resources, collectErr func(error)) func() error { return func() error { var err error - result.RDSDatabases, err = a.fetchAWSRDSDatabases(ctx, a.lastResult) + result.RDSDatabases, err = a.fetchAWSRDSDatabases(ctx) if err != nil { collectErr(trace.Wrap(err, "failed to fetch databases")) } @@ -45,7 +52,7 @@ func (a *awsFetcher) pollAWSRDSDatabases(ctx context.Context, result *Resources, } // fetchAWSRDSDatabases fetches RDS databases from all regions. -func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context, existing *Resources) ( +func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context) ( []*accessgraphv1alpha.AWSRDSDatabaseV1, error, ) { @@ -59,14 +66,14 @@ func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context, existing *Resourc // This is a temporary solution until we have a better way to limit the // number of concurrent requests. 
eG.SetLimit(5) - collectDBs := func(db *accessgraphv1alpha.AWSRDSDatabaseV1, err error) { + collectDBs := func(db []*accessgraphv1alpha.AWSRDSDatabaseV1, err error) { hostsMu.Lock() defer hostsMu.Unlock() if err != nil { errs = append(errs, err) } if db != nil { - dbs = append(dbs, db) + dbs = append(dbs, db...) } } @@ -74,42 +81,14 @@ func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context, existing *Resourc for _, region := range a.Regions { region := region eG.Go(func() error { - rdsClient, err := a.CloudClients.GetAWSRDSClient(ctx, region, a.getAWSOptions()...) + awsCfg, err := a.AWSConfigProvider.GetConfig(ctx, region, a.getAWSV2Options()...) if err != nil { collectDBs(nil, trace.Wrap(err)) return nil } - err = rdsClient.DescribeDBInstancesPagesWithContext(ctx, &rds.DescribeDBInstancesInput{}, - func(output *rds.DescribeDBInstancesOutput, lastPage bool) bool { - for _, db := range output.DBInstances { - // if instance belongs to a cluster, skip it as we want to represent the cluster itself - // and we pull it using DescribeDBClustersPagesWithContext instead. 
- if aws.StringValue(db.DBClusterIdentifier) != "" { - continue - } - protoRDS := awsRDSInstanceToRDS(db, region, a.AccountID) - collectDBs(protoRDS, nil) - } - return !lastPage - }, - ) - if err != nil { - collectDBs(nil, trace.Wrap(err)) - } - - err = rdsClient.DescribeDBClustersPagesWithContext(ctx, &rds.DescribeDBClustersInput{}, - func(output *rds.DescribeDBClustersOutput, lastPage bool) bool { - for _, db := range output.DBClusters { - protoRDS := awsRDSClusterToRDS(db, region, a.AccountID) - collectDBs(protoRDS, nil) - } - return !lastPage - }, - ) - if err != nil { - collectDBs(nil, trace.Wrap(err)) - } - + clt := a.awsClients.getRDSClient(awsCfg) + a.collectDBInstances(ctx, clt, region, collectDBs) + a.collectDBClusters(ctx, clt, region, collectDBs) return nil }) } @@ -118,60 +97,123 @@ func (a *awsFetcher) fetchAWSRDSDatabases(ctx context.Context, existing *Resourc return dbs, trace.NewAggregate(append(errs, err)...) } -// awsRDSInstanceToRDS converts an rds.DBInstance to accessgraphv1alpha.AWSRDSDatabaseV1 +// awsRDSInstanceToRDS converts an rdstypes.DBInstance to accessgraphv1alpha.AWSRDSDatabaseV1 // representation. 
-func awsRDSInstanceToRDS(instance *rds.DBInstance, region, accountID string) *accessgraphv1alpha.AWSRDSDatabaseV1 { +func awsRDSInstanceToRDS(instance *rdstypes.DBInstance, region, accountID string) *accessgraphv1alpha.AWSRDSDatabaseV1 { var tags []*accessgraphv1alpha.AWSTag for _, v := range instance.TagList { tags = append(tags, &accessgraphv1alpha.AWSTag{ - Key: aws.StringValue(v.Key), + Key: aws.ToString(v.Key), Value: strPtrToWrapper(v.Value), }) } return &accessgraphv1alpha.AWSRDSDatabaseV1{ - Name: aws.StringValue(instance.DBInstanceIdentifier), - Arn: aws.StringValue(instance.DBInstanceArn), + Name: aws.ToString(instance.DBInstanceIdentifier), + Arn: aws.ToString(instance.DBInstanceArn), CreatedAt: awsTimeToProtoTime(instance.InstanceCreateTime), - Status: aws.StringValue(instance.DBInstanceStatus), + Status: aws.ToString(instance.DBInstanceStatus), Region: region, AccountId: accountID, Tags: tags, EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{ - Engine: aws.StringValue(instance.Engine), - Version: aws.StringValue(instance.EngineVersion), + Engine: aws.ToString(instance.Engine), + Version: aws.ToString(instance.EngineVersion), }, IsCluster: false, - ResourceId: aws.StringValue(instance.DbiResourceId), + ResourceId: aws.ToString(instance.DbiResourceId), LastSyncTime: timestamppb.Now(), } } -// awsRDSInstanceToRDS converts an rds.DBCluster to accessgraphv1alpha.AWSRDSDatabaseV1 +// awsRDSInstanceToRDS converts an rdstypes.DBCluster to accessgraphv1alpha.AWSRDSDatabaseV1 // representation. 
-func awsRDSClusterToRDS(instance *rds.DBCluster, region, accountID string) *accessgraphv1alpha.AWSRDSDatabaseV1 { +func awsRDSClusterToRDS(instance *rdstypes.DBCluster, region, accountID string) *accessgraphv1alpha.AWSRDSDatabaseV1 { var tags []*accessgraphv1alpha.AWSTag for _, v := range instance.TagList { tags = append(tags, &accessgraphv1alpha.AWSTag{ - Key: aws.StringValue(v.Key), + Key: aws.ToString(v.Key), Value: strPtrToWrapper(v.Value), }) } return &accessgraphv1alpha.AWSRDSDatabaseV1{ - Name: aws.StringValue(instance.DBClusterIdentifier), - Arn: aws.StringValue(instance.DBClusterArn), + Name: aws.ToString(instance.DBClusterIdentifier), + Arn: aws.ToString(instance.DBClusterArn), CreatedAt: awsTimeToProtoTime(instance.ClusterCreateTime), - Status: aws.StringValue(instance.Status), + Status: aws.ToString(instance.Status), Region: region, AccountId: accountID, Tags: tags, EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{ - Engine: aws.StringValue(instance.Engine), - Version: aws.StringValue(instance.EngineVersion), + Engine: aws.ToString(instance.Engine), + Version: aws.ToString(instance.EngineVersion), }, IsCluster: true, - ResourceId: aws.StringValue(instance.DbClusterResourceId), + ResourceId: aws.ToString(instance.DbClusterResourceId), LastSyncTime: timestamppb.Now(), } } + +func (a *awsFetcher) collectDBInstances(ctx context.Context, + clt rdsClient, + region string, + collectDBs func([]*accessgraphv1alpha.AWSRDSDatabaseV1, error), +) { + pager := rds.NewDescribeDBInstancesPaginator(clt, + &rds.DescribeDBInstancesInput{}, + func(ddpo *rds.DescribeDBInstancesPaginatorOptions) { + ddpo.StopOnDuplicateToken = true + }, + ) + var instances []*accessgraphv1alpha.AWSRDSDatabaseV1 + for pager.HasMorePages() { + page, err := pager.NextPage(ctx) + if err != nil { + old := sliceFilter(a.lastResult.RDSDatabases, func(db *accessgraphv1alpha.AWSRDSDatabaseV1) bool { + return !db.IsCluster && db.Region == region && db.AccountId == a.AccountID + }) + collectDBs(old, 
trace.Wrap(err)) + return + } + for _, db := range page.DBInstances { + // if instance belongs to a cluster, skip it as we want to represent the cluster itself + // and we pull it using DescribeDBClustersPaginator instead. + if aws.ToString(db.DBClusterIdentifier) != "" { + continue + } + protoRDS := awsRDSInstanceToRDS(&db, region, a.AccountID) + instances = append(instances, protoRDS) + } + } + collectDBs(instances, nil) +} + +func (a *awsFetcher) collectDBClusters( + ctx context.Context, + clt rdsClient, + region string, + collectDBs func([]*accessgraphv1alpha.AWSRDSDatabaseV1, error), +) { + pager := rds.NewDescribeDBClustersPaginator(clt, &rds.DescribeDBClustersInput{}, + func(ddpo *rds.DescribeDBClustersPaginatorOptions) { + ddpo.StopOnDuplicateToken = true + }, + ) + var clusters []*accessgraphv1alpha.AWSRDSDatabaseV1 + for pager.HasMorePages() { + page, err := pager.NextPage(ctx) + if err != nil { + old := sliceFilter(a.lastResult.RDSDatabases, func(db *accessgraphv1alpha.AWSRDSDatabaseV1) bool { + return db.IsCluster && db.Region == region && db.AccountId == a.AccountID + }) + collectDBs(old, trace.Wrap(err)) + return + } + for _, db := range page.DBClusters { + protoRDS := awsRDSClusterToRDS(&db, region, a.AccountID) + clusters = append(clusters, protoRDS) + } + } + collectDBs(clusters, nil) +} diff --git a/lib/srv/discovery/fetchers/aws-sync/rds_test.go b/lib/srv/discovery/fetchers/aws-sync/rds_test.go index bed0811d88e1d..b228264b7c834 100644 --- a/lib/srv/discovery/fetchers/aws-sync/rds_test.go +++ b/lib/srv/discovery/fetchers/aws-sync/rds_test.go @@ -20,19 +20,19 @@ package aws_sync import ( "context" - "sync" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "google.golang.org/protobuf/testing/protocmp" 
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" + "github.com/gravitational/teleport/api/types" accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" - "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/mocks" ) @@ -44,86 +44,117 @@ func TestPollAWSRDS(t *testing.T) { regions = []string{"eu-west-1"} ) - tests := []struct { - name string - want *Resources - }{ - { - name: "poll rds databases", - want: &Resources{ - RDSDatabases: []*accessgraphv1alpha.AWSRDSDatabaseV1{ + awsOIDCIntegration, err := types.NewIntegrationAWSOIDC( + types.Metadata{Name: "integration-test"}, + &types.AWSOIDCIntegrationSpecV1{ + RoleARN: "arn:aws:sts::123456789012:role/TestRole", + }, + ) + require.NoError(t, err) + + resourcesFixture := Resources{ + RDSDatabases: []*accessgraphv1alpha.AWSRDSDatabaseV1{ + { + Arn: "arn:us-west1:rds:instance1", + Status: string(rdstypes.DBProxyStatusAvailable), + Name: "db1", + EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{ + Engine: string(rdstypes.EngineFamilyMysql), + Version: "v1.1", + }, + CreatedAt: timestamppb.New(date), + Tags: []*accessgraphv1alpha.AWSTag{ { - Arn: "arn:us-west1:rds:instance1", - Status: rds.DBProxyStatusAvailable, - Name: "db1", - EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{ - Engine: rds.EngineFamilyMysql, - Version: "v1.1", - }, - CreatedAt: timestamppb.New(date), - Tags: []*accessgraphv1alpha.AWSTag{ - { - Key: "tag", - Value: wrapperspb.String("val"), - }, - }, - Region: "eu-west-1", - IsCluster: false, - AccountId: "12345678", - ResourceId: "db1", + Key: "tag", + Value: wrapperspb.String("val"), }, + }, + Region: "eu-west-1", + IsCluster: false, + AccountId: "12345678", + ResourceId: "db1", + }, + { + Arn: "arn:us-west1:rds:cluster1", + Status: string(rdstypes.DBProxyStatusAvailable), + Name: "cluster1", + EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{ + Engine: 
string(rdstypes.EngineFamilyMysql), + Version: "v1.1", + }, + CreatedAt: timestamppb.New(date), + Tags: []*accessgraphv1alpha.AWSTag{ { - Arn: "arn:us-west1:rds:cluster1", - Status: rds.DBProxyStatusAvailable, - Name: "cluster1", - EngineDetails: &accessgraphv1alpha.AWSRDSEngineV1{ - Engine: rds.EngineFamilyMysql, - Version: "v1.1", - }, - CreatedAt: timestamppb.New(date), - Tags: []*accessgraphv1alpha.AWSTag{ - { - Key: "tag", - Value: wrapperspb.String("val"), - }, - }, - Region: "eu-west-1", - IsCluster: true, - AccountId: "12345678", - ResourceId: "cluster1", + Key: "tag", + Value: wrapperspb.String("val"), }, }, + Region: "eu-west-1", + IsCluster: true, + AccountId: "12345678", + ResourceId: "cluster1", + }, + }, + } + + tests := []struct { + name string + fetcherConfigOpt func(*awsFetcher) + want *Resources + checkError func(*testing.T, error) + }{ + { + name: "poll rds databases", + want: &resourcesFixture, + fetcherConfigOpt: func(a *awsFetcher) { + a.awsClients = fakeAWSClients{ + rdsClient: &mocks.RDSClient{ + DBInstances: dbInstances(), + DBClusters: dbClusters(), + }, + } + }, + checkError: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, + { + name: "reuse last synced databases on failure", + want: &resourcesFixture, + fetcherConfigOpt: func(a *awsFetcher) { + a.awsClients = fakeAWSClients{ + rdsClient: &mocks.RDSClient{Unauth: true}, + } + a.lastResult = &resourcesFixture + }, + checkError: func(t *testing.T, err error) { + require.Error(t, err) + require.ErrorContains(t, err, "failed to fetch databases") }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockedClients := &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBInstances: dbInstances(), - DBClusters: dbClusters(), - }, - } - - var ( - errs []error - mu sync.Mutex - ) - - collectErr := func(err error) { - mu.Lock() - defer mu.Unlock() - errs = append(errs, err) - } a := &awsFetcher{ Config: Config{ - AccountID: accountID, - CloudClients: 
mockedClients, - Regions: regions, - Integration: accountID, + AccountID: accountID, + AWSConfigProvider: &mocks.AWSConfigProvider{ + OIDCIntegrationClient: &mocks.FakeOIDCIntegrationClient{ + Integration: awsOIDCIntegration, + Token: "fake-oidc-token", + }, + }, + Regions: regions, + Integration: awsOIDCIntegration.GetName(), }, } + if tt.fetcherConfigOpt != nil { + tt.fetcherConfigOpt(a) + } result := &Resources{} + collectErr := func(err error) { + tt.checkError(t, err) + } execFunc := a.pollAWSRDSDatabases(context.Background(), result, collectErr) require.NoError(t, execFunc()) require.Empty(t, cmp.Diff( @@ -144,16 +175,16 @@ func TestPollAWSRDS(t *testing.T) { } } -func dbInstances() []*rds.DBInstance { - return []*rds.DBInstance{ +func dbInstances() []rdstypes.DBInstance { + return []rdstypes.DBInstance{ { DBInstanceIdentifier: aws.String("db1"), DBInstanceArn: aws.String("arn:us-west1:rds:instance1"), InstanceCreateTime: aws.Time(date), - Engine: aws.String(rds.EngineFamilyMysql), - DBInstanceStatus: aws.String(rds.DBProxyStatusAvailable), + Engine: aws.String(string(rdstypes.EngineFamilyMysql)), + DBInstanceStatus: aws.String(string(rdstypes.DBProxyStatusAvailable)), EngineVersion: aws.String("v1.1"), - TagList: []*rds.Tag{ + TagList: []rdstypes.Tag{ { Key: aws.String("tag"), Value: aws.String("val"), @@ -164,16 +195,16 @@ func dbInstances() []*rds.DBInstance { } } -func dbClusters() []*rds.DBCluster { - return []*rds.DBCluster{ +func dbClusters() []rdstypes.DBCluster { + return []rdstypes.DBCluster{ { DBClusterIdentifier: aws.String("cluster1"), DBClusterArn: aws.String("arn:us-west1:rds:cluster1"), ClusterCreateTime: aws.Time(date), - Engine: aws.String(rds.EngineFamilyMysql), - Status: aws.String(rds.DBProxyStatusAvailable), + Engine: aws.String(string(rdstypes.EngineFamilyMysql)), + Status: aws.String(string(rdstypes.DBProxyStatusAvailable)), EngineVersion: aws.String("v1.1"), - TagList: []*rds.Tag{ + TagList: []rdstypes.Tag{ { Key: aws.String("tag"), 
Value: aws.String("val"), @@ -183,3 +214,11 @@ func dbClusters() []*rds.DBCluster { }, } } + +type fakeAWSClients struct { + rdsClient rdsClient +} + +func (f fakeAWSClients) getRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) rdsClient { + return f.rdsClient +} diff --git a/lib/srv/discovery/fetchers/db/aws.go b/lib/srv/discovery/fetchers/db/aws.go index d6d70912d7092..24de91e83e309 100644 --- a/lib/srv/discovery/fetchers/db/aws.go +++ b/lib/srv/discovery/fetchers/db/aws.go @@ -23,8 +23,6 @@ import ( "fmt" "log/slog" - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/gravitational/trace" "github.com/gravitational/teleport" @@ -74,8 +72,8 @@ type awsFetcherConfig struct { // ie teleport.yaml/discovery_service.. DiscoveryConfigName string - // redshiftClientProviderFn provides an AWS Redshift client. - redshiftClientProviderFn RedshiftClientProviderFunc + // awsClients provides AWS SDK v2 clients. + awsClients AWSClientProvider } // CheckAndSetDefaults validates the config and sets defaults. @@ -109,10 +107,8 @@ func (cfg *awsFetcherConfig) CheckAndSetDefaults(component string) error { ) } - if cfg.redshiftClientProviderFn == nil { - cfg.redshiftClientProviderFn = func(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient { - return redshift.NewFromConfig(cfg, optFns...) 
- } + if cfg.awsClients == nil { + cfg.awsClients = defaultAWSClients{} } return nil } diff --git a/lib/srv/discovery/fetchers/db/aws_docdb.go b/lib/srv/discovery/fetchers/db/aws_docdb.go index a6a604be340eb..ef1920d83d6b8 100644 --- a/lib/srv/discovery/fetchers/db/aws_docdb.go +++ b/lib/srv/discovery/fetchers/db/aws_docdb.go @@ -21,14 +21,14 @@ package db import ( "context" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/srv/discovery/common" ) @@ -39,13 +39,6 @@ func newDocumentDBFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { } // rdsDocumentDBFetcher retrieves DocumentDB clusters. -// -// Note that AWS DocumentDB internally uses the RDS APIs: -// https://github.com/aws/aws-sdk-go/blob/3248e69e16aa601ffa929be53a52439425257e5e/service/docdb/service.go#L33 -// The interfaces/structs in "services/docdb" are usually a subset of those in -// "services/rds". -// -// TODO(greedy52) switch to aws-sdk-go-v2/services/docdb. type rdsDocumentDBFetcher struct{} func (f *rdsDocumentDBFetcher) ComponentShortName() string { @@ -54,21 +47,22 @@ func (f *rdsDocumentDBFetcher) ComponentShortName() string { // GetDatabases returns a list of database resources representing DocumentDB endpoints. 
func (f *rdsDocumentDBFetcher) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { - rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region, - cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), - cloud.WithCredentialsMaybeIntegration(cfg.Integration), + awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region, + awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), + awsconfig.WithCredentialsMaybeIntegration(cfg.Integration), ) if err != nil { return nil, trace.Wrap(err) } - clusters, err := f.getAllDBClusters(ctx, rdsClient) + clt := cfg.awsClients.GetRDSClient(awsCfg) + clusters, err := f.getAllDBClusters(ctx, clt) if err != nil { - return nil, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) + return nil, trace.Wrap(err) } databases := make(types.Databases, 0) for _, cluster := range clusters { - if !libcloudaws.IsDocumentDBClusterSupported(cluster) { + if !libcloudaws.IsDocumentDBClusterSupported(&cluster) { cfg.Logger.DebugContext(ctx, "DocumentDB cluster doesn't support IAM authentication. 
Skipping.", "cluster", aws.StringValue(cluster.DBClusterIdentifier), "engine_version", aws.StringValue(cluster.EngineVersion)) @@ -82,7 +76,7 @@ func (f *rdsDocumentDBFetcher) GetDatabases(ctx context.Context, cfg *awsFetcher continue } - dbs, err := common.NewDatabasesFromDocumentDBCluster(cluster) + dbs, err := common.NewDatabasesFromDocumentDBCluster(&cluster) if err != nil { cfg.Logger.WarnContext(ctx, "Could not convert DocumentDB cluster to database resources.", "cluster", aws.StringValue(cluster.DBClusterIdentifier), @@ -93,15 +87,23 @@ func (f *rdsDocumentDBFetcher) GetDatabases(ctx context.Context, cfg *awsFetcher return databases, nil } -func (f *rdsDocumentDBFetcher) getAllDBClusters(ctx context.Context, rdsClient rdsiface.RDSAPI) ([]*rds.DBCluster, error) { - var pageNum int - var clusters []*rds.DBCluster - err := rdsClient.DescribeDBClustersPagesWithContext(ctx, &rds.DescribeDBClustersInput{ - Filters: rdsEngineFilter([]string{"docdb"}), - }, func(ddo *rds.DescribeDBClustersOutput, lastPage bool) bool { - pageNum++ - clusters = append(clusters, ddo.DBClusters...) - return pageNum <= maxAWSPages - }) - return clusters, trace.Wrap(err) +func (f *rdsDocumentDBFetcher) getAllDBClusters(ctx context.Context, clt RDSClient) ([]rdstypes.DBCluster, error) { + pager := rds.NewDescribeDBClustersPaginator(clt, + &rds.DescribeDBClustersInput{ + Filters: rdsEngineFilter([]string{"docdb"}), + }, + func(pagerOpts *rds.DescribeDBClustersPaginatorOptions) { + pagerOpts.StopOnDuplicateToken = true + }, + ) + + var clusters []rdstypes.DBCluster + for i := 0; i < maxAWSPages && pager.HasMorePages(); i++ { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, trace.Wrap(libcloudaws.ConvertRequestFailureErrorV2(err)) + } + clusters = append(clusters, page.DBClusters...) 
+ } + return clusters, nil } diff --git a/lib/srv/discovery/fetchers/db/aws_docdb_test.go b/lib/srv/discovery/fetchers/db/aws_docdb_test.go index 4ae7cfee582f0..5f71a805f8131 100644 --- a/lib/srv/discovery/fetchers/db/aws_docdb_test.go +++ b/lib/srv/discovery/fetchers/db/aws_docdb_test.go @@ -21,12 +21,11 @@ package db import ( "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" + "github.com/aws/aws-sdk-go-v2/aws" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/srv/discovery/common" ) @@ -34,16 +33,16 @@ import ( func TestDocumentDBFetcher(t *testing.T) { t.Parallel() - docdbEngine := &rds.DBEngineVersion{ + docdbEngine := &rdstypes.DBEngineVersion{ Engine: aws.String("docdb"), } clusterProd := mocks.DocumentDBCluster("cluster1", "us-east-1", envProdLabels, mocks.WithDocumentDBClusterReader) clusterDev := mocks.DocumentDBCluster("cluster2", "us-east-1", envDevLabels) - clusterNotAvailable := mocks.DocumentDBCluster("cluster3", "us-east-1", envDevLabels, func(cluster *rds.DBCluster) { + clusterNotAvailable := mocks.DocumentDBCluster("cluster3", "us-east-1", envDevLabels, func(cluster *rdstypes.DBCluster) { cluster.Status = aws.String("creating") }) - clusterNotSupported := mocks.DocumentDBCluster("cluster4", "us-east-1", envDevLabels, func(cluster *rds.DBCluster) { + clusterNotSupported := mocks.DocumentDBCluster("cluster4", "us-east-1", envDevLabels, func(cluster *rdstypes.DBCluster) { cluster.EngineVersion = aws.String("4.0.0") }) @@ -53,10 +52,12 @@ func TestDocumentDBFetcher(t *testing.T) { tests := []awsFetcherTest{ { name: "fetch all", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBClusters: []*rds.DBCluster{clusterProd, clusterDev}, - DBEngineVersions: 
[]*rds.DBEngineVersion{docdbEngine}, + fetcherCfg: AWSFetcherFactoryConfig{ + AWSClients: fakeAWSClients{ + rdsClient: &mocks.RDSClient{ + DBClusters: []rdstypes.DBCluster{*clusterProd, *clusterDev}, + DBEngineVersions: []rdstypes.DBEngineVersion{*docdbEngine}, + }, }, }, inputMatchers: []types.AWSMatcher{ @@ -70,10 +71,12 @@ func TestDocumentDBFetcher(t *testing.T) { }, { name: "filter by labels", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBClusters: []*rds.DBCluster{clusterProd, clusterDev}, - DBEngineVersions: []*rds.DBEngineVersion{docdbEngine}, + fetcherCfg: AWSFetcherFactoryConfig{ + AWSClients: fakeAWSClients{ + rdsClient: &mocks.RDSClient{ + DBClusters: []rdstypes.DBCluster{*clusterProd, *clusterDev}, + DBEngineVersions: []rdstypes.DBEngineVersion{*docdbEngine}, + }, }, }, inputMatchers: []types.AWSMatcher{ @@ -87,10 +90,12 @@ func TestDocumentDBFetcher(t *testing.T) { }, { name: "skip unsupported databases", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBClusters: []*rds.DBCluster{clusterProd, clusterNotSupported}, - DBEngineVersions: []*rds.DBEngineVersion{docdbEngine}, + fetcherCfg: AWSFetcherFactoryConfig{ + AWSClients: fakeAWSClients{ + rdsClient: &mocks.RDSClient{ + DBClusters: []rdstypes.DBCluster{*clusterProd, *clusterNotSupported}, + DBEngineVersions: []rdstypes.DBEngineVersion{*docdbEngine}, + }, }, }, inputMatchers: []types.AWSMatcher{ @@ -104,10 +109,12 @@ func TestDocumentDBFetcher(t *testing.T) { }, { name: "skip unavailable databases", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBClusters: []*rds.DBCluster{clusterProd, clusterNotAvailable}, - DBEngineVersions: []*rds.DBEngineVersion{docdbEngine}, + fetcherCfg: AWSFetcherFactoryConfig{ + AWSClients: fakeAWSClients{ + rdsClient: &mocks.RDSClient{ + DBClusters: []rdstypes.DBCluster{*clusterProd, *clusterNotAvailable}, + DBEngineVersions: []rdstypes.DBEngineVersion{*docdbEngine}, + }, }, }, inputMatchers: []types.AWSMatcher{ @@ 
-123,7 +130,7 @@ func TestDocumentDBFetcher(t *testing.T) { testAWSFetchers(t, tests...) } -func mustMakeDocumentDBDatabases(t *testing.T, cluster *rds.DBCluster) types.Databases { +func mustMakeDocumentDBDatabases(t *testing.T, cluster *rdstypes.DBCluster) types.Databases { t.Helper() databases, err := common.NewDatabasesFromDocumentDBCluster(cluster) diff --git a/lib/srv/discovery/fetchers/db/aws_rds.go b/lib/srv/discovery/fetchers/db/aws_rds.go index 639835f2b75a2..1b438873c8726 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds.go +++ b/lib/srv/discovery/fetchers/db/aws_rds.go @@ -23,18 +23,27 @@ import ( "log/slog" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" ) +// RDSClient is a subset of the AWS RDS API. +type RDSClient interface { + rds.DescribeDBClustersAPIClient + rds.DescribeDBInstancesAPIClient + rds.DescribeDBProxiesAPIClient + rds.DescribeDBProxyEndpointsAPIClient + ListTagsForResource(context.Context, *rds.ListTagsForResourceInput, ...func(*rds.Options)) (*rds.ListTagsForResourceOutput, error) +} + // newRDSDBInstancesFetcher returns a new AWS fetcher for RDS databases. func newRDSDBInstancesFetcher(cfg awsFetcherConfig) (common.Fetcher, error) { return newAWSFetcher(cfg, &rdsDBInstancesPlugin{}) @@ -49,40 +58,41 @@ func (f *rdsDBInstancesPlugin) ComponentShortName() string { // GetDatabases returns a list of database resources representing RDS instances. 
func (f *rdsDBInstancesPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { - rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region, - cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), - cloud.WithCredentialsMaybeIntegration(cfg.Integration), + awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region, + awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), + awsconfig.WithCredentialsMaybeIntegration(cfg.Integration), ) if err != nil { return nil, trace.Wrap(err) } - instances, err := getAllDBInstances(ctx, rdsClient, maxAWSPages, cfg.Logger) + clt := cfg.awsClients.GetRDSClient(awsCfg) + instances, err := getAllDBInstances(ctx, clt, maxAWSPages, cfg.Logger) if err != nil { - return nil, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) + return nil, trace.Wrap(err) } databases := make(types.Databases, 0, len(instances)) for _, instance := range instances { - if !libcloudaws.IsRDSInstanceSupported(instance) { + if !libcloudaws.IsRDSInstanceSupported(&instance) { cfg.Logger.DebugContext(ctx, "Skipping RDS instance that does not support IAM authentication", - "instance", aws.StringValue(instance.DBInstanceIdentifier), - "engine_mode", aws.StringValue(instance.Engine), - "engine_version", aws.StringValue(instance.EngineVersion), + "instance", aws.ToString(instance.DBInstanceIdentifier), + "engine_mode", aws.ToString(instance.Engine), + "engine_version", aws.ToString(instance.EngineVersion), ) continue } if !libcloudaws.IsRDSInstanceAvailable(instance.DBInstanceStatus, instance.DBInstanceIdentifier) { cfg.Logger.DebugContext(ctx, "Skipping unavailable RDS instance", - "instance", aws.StringValue(instance.DBInstanceIdentifier), - "status", aws.StringValue(instance.DBInstanceStatus), + "instance", aws.ToString(instance.DBInstanceIdentifier), + "status", aws.ToString(instance.DBInstanceStatus), ) continue } - database, err := 
common.NewDatabaseFromRDSInstance(instance) + database, err := common.NewDatabaseFromRDSInstance(&instance) if err != nil { cfg.Logger.WarnContext(ctx, "Could not convert RDS instance to database resource", - "instance", aws.StringValue(instance.DBInstanceIdentifier), + "instance", aws.ToString(instance.DBInstanceIdentifier), "error", err, ) } else { @@ -94,36 +104,40 @@ func (f *rdsDBInstancesPlugin) GetDatabases(ctx context.Context, cfg *awsFetcher // getAllDBInstances fetches all RDS instances using the provided client, up // to the specified max number of pages. -func getAllDBInstances(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int, logger *slog.Logger) ([]*rds.DBInstance, error) { - return getAllDBInstancesWithFilters(ctx, rdsClient, maxPages, rdsInstanceEngines(), rdsEmptyFilter(), logger) +func getAllDBInstances(ctx context.Context, clt RDSClient, maxPages int, logger *slog.Logger) ([]rdstypes.DBInstance, error) { + return getAllDBInstancesWithFilters(ctx, clt, maxPages, rdsInstanceEngines(), rdsEmptyFilter(), logger) } // findDBInstancesForDBCluster returns the DBInstances associated with a given DB Cluster Identifier -func findDBInstancesForDBCluster(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int, dbClusterIdentifier string, logger *slog.Logger) ([]*rds.DBInstance, error) { - return getAllDBInstancesWithFilters(ctx, rdsClient, maxPages, auroraEngines(), rdsClusterIDFilter(dbClusterIdentifier), logger) +func findDBInstancesForDBCluster(ctx context.Context, clt RDSClient, maxPages int, dbClusterIdentifier string, logger *slog.Logger) ([]rdstypes.DBInstance, error) { + return getAllDBInstancesWithFilters(ctx, clt, maxPages, auroraEngines(), rdsClusterIDFilter(dbClusterIdentifier), logger) } // getAllDBInstancesWithFilters fetches all RDS instances matching the filters using the provided client, up // to the specified max number of pages. 
-func getAllDBInstancesWithFilters(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int, engines []string, baseFilters []*rds.Filter, logger *slog.Logger) ([]*rds.DBInstance, error) { - var instances []*rds.DBInstance - err := retryWithIndividualEngineFilters(ctx, logger, engines, func(engineFilters []*rds.Filter) error { - var pageNum int - var out []*rds.DBInstance - err := rdsClient.DescribeDBInstancesPagesWithContext(ctx, &rds.DescribeDBInstancesInput{ - Filters: append(engineFilters, baseFilters...), - }, func(ddo *rds.DescribeDBInstancesOutput, lastPage bool) bool { - pageNum++ - instances = append(instances, ddo.DBInstances...) - return pageNum <= maxPages - }) - if err == nil { - // only append to instances on nil error, just in case we have to retry. - instances = append(instances, out...) +func getAllDBInstancesWithFilters(ctx context.Context, clt RDSClient, maxPages int, engines []string, baseFilters []rdstypes.Filter, logger *slog.Logger) ([]rdstypes.DBInstance, error) { + var out []rdstypes.DBInstance + err := retryWithIndividualEngineFilters(ctx, logger, engines, func(engineFilters []rdstypes.Filter) error { + pager := rds.NewDescribeDBInstancesPaginator(clt, + &rds.DescribeDBInstancesInput{ + Filters: append(engineFilters, baseFilters...), + }, + func(dcpo *rds.DescribeDBInstancesPaginatorOptions) { + dcpo.StopOnDuplicateToken = true + }, + ) + var instances []rdstypes.DBInstance + for i := 0; i < maxPages && pager.HasMorePages(); i++ { + page, err := pager.NextPage(ctx) + if err != nil { + return trace.Wrap(err) + } + instances = append(instances, page.DBInstances...) 
} - return trace.Wrap(err) + out = instances + return nil }) - return instances, trace.Wrap(err) + return out, trace.Wrap(libcloudaws.ConvertRequestFailureErrorV2(err)) } // newRDSAuroraClustersFetcher returns a new AWS fetcher for RDS Aurora @@ -141,48 +155,49 @@ func (f *rdsAuroraClustersPlugin) ComponentShortName() string { // GetDatabases returns a list of database resources representing RDS clusters. func (f *rdsAuroraClustersPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { - rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region, - cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), - cloud.WithCredentialsMaybeIntegration(cfg.Integration), + awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region, + awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), + awsconfig.WithCredentialsMaybeIntegration(cfg.Integration), ) if err != nil { return nil, trace.Wrap(err) } - clusters, err := getAllDBClusters(ctx, rdsClient, maxAWSPages, cfg.Logger) + clt := cfg.awsClients.GetRDSClient(awsCfg) + clusters, err := getAllDBClusters(ctx, clt, maxAWSPages, cfg.Logger) if err != nil { - return nil, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) + return nil, trace.Wrap(err) } databases := make(types.Databases, 0, len(clusters)) for _, cluster := range clusters { - if !libcloudaws.IsRDSClusterSupported(cluster) { + if !libcloudaws.IsRDSClusterSupported(&cluster) { cfg.Logger.DebugContext(ctx, "Skipping Aurora cluster that does not support IAM authentication", - "cluster", aws.StringValue(cluster.DBClusterIdentifier), - "engine_mode", aws.StringValue(cluster.EngineMode), - "engine_version", aws.StringValue(cluster.EngineVersion), + "cluster", aws.ToString(cluster.DBClusterIdentifier), + "engine_mode", aws.ToString(cluster.EngineMode), + "engine_version", aws.ToString(cluster.EngineVersion), ) continue } if !libcloudaws.IsDBClusterAvailable(cluster.Status, 
cluster.DBClusterIdentifier) { cfg.Logger.DebugContext(ctx, "Skipping unavailable Aurora cluster", - "instance", aws.StringValue(cluster.DBClusterIdentifier), - "status", aws.StringValue(cluster.Status), + "instance", aws.ToString(cluster.DBClusterIdentifier), + "status", aws.ToString(cluster.Status), ) continue } - rdsDBInstances, err := findDBInstancesForDBCluster(ctx, rdsClient, maxAWSPages, aws.StringValue(cluster.DBClusterIdentifier), cfg.Logger) + rdsDBInstances, err := findDBInstancesForDBCluster(ctx, clt, maxAWSPages, aws.ToString(cluster.DBClusterIdentifier), cfg.Logger) if err != nil || len(rdsDBInstances) == 0 { cfg.Logger.WarnContext(ctx, "Could not fetch Member Instance for DB Cluster", - "instance", aws.StringValue(cluster.DBClusterIdentifier), + "instance", aws.ToString(cluster.DBClusterIdentifier), "error", err, ) } - dbs, err := common.NewDatabasesFromRDSCluster(cluster, rdsDBInstances) + dbs, err := common.NewDatabasesFromRDSCluster(&cluster, rdsDBInstances) if err != nil { cfg.Logger.WarnContext(ctx, "Could not convert RDS cluster to database resources", - "identifier", aws.StringValue(cluster.DBClusterIdentifier), + "identifier", aws.ToString(cluster.DBClusterIdentifier), "error", err, ) } @@ -193,25 +208,30 @@ func (f *rdsAuroraClustersPlugin) GetDatabases(ctx context.Context, cfg *awsFetc // getAllDBClusters fetches all RDS clusters using the provided client, up to // the specified max number of pages. -func getAllDBClusters(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int, logger *slog.Logger) ([]*rds.DBCluster, error) { - var clusters []*rds.DBCluster - err := retryWithIndividualEngineFilters(ctx, logger, auroraEngines(), func(filters []*rds.Filter) error { - var pageNum int - var out []*rds.DBCluster - err := rdsClient.DescribeDBClustersPagesWithContext(ctx, &rds.DescribeDBClustersInput{ - Filters: filters, - }, func(ddo *rds.DescribeDBClustersOutput, lastPage bool) bool { - pageNum++ - out = append(out, ddo.DBClusters...) 
- return pageNum <= maxPages - }) - if err == nil { - // only append to clusters on nil error, just in case we have to retry. - clusters = append(clusters, out...) +func getAllDBClusters(ctx context.Context, clt RDSClient, maxPages int, logger *slog.Logger) ([]rdstypes.DBCluster, error) { + var out []rdstypes.DBCluster + err := retryWithIndividualEngineFilters(ctx, logger, auroraEngines(), func(filters []rdstypes.Filter) error { + pager := rds.NewDescribeDBClustersPaginator(clt, + &rds.DescribeDBClustersInput{ + Filters: filters, + }, + func(pagerOpts *rds.DescribeDBClustersPaginatorOptions) { + pagerOpts.StopOnDuplicateToken = true + }, + ) + + var clusters []rdstypes.DBCluster + for i := 0; i < maxPages && pager.HasMorePages(); i++ { + page, err := pager.NextPage(ctx) + if err != nil { + return trace.Wrap(err) + } + clusters = append(clusters, page.DBClusters...) } - return trace.Wrap(err) + out = clusters + return nil }) - return clusters, trace.Wrap(err) + return out, trace.Wrap(libcloudaws.ConvertRequestFailureErrorV2(err)) } // rdsInstanceEngines returns engines to make sure DescribeDBInstances call returns @@ -234,28 +254,28 @@ func auroraEngines() []string { } // rdsEngineFilter is a helper func to construct an RDS filter for engine names. -func rdsEngineFilter(engines []string) []*rds.Filter { - return []*rds.Filter{{ +func rdsEngineFilter(engines []string) []rdstypes.Filter { + return []rdstypes.Filter{{ Name: aws.String("engine"), - Values: aws.StringSlice(engines), + Values: engines, }} } // rdsClusterIDFilter is a helper func to construct an RDS DB Instances for returning Instances of a specific DB Cluster. 
-func rdsClusterIDFilter(clusterIdentifier string) []*rds.Filter { - return []*rds.Filter{{ +func rdsClusterIDFilter(clusterIdentifier string) []rdstypes.Filter { + return []rdstypes.Filter{{ Name: aws.String("db-cluster-id"), - Values: aws.StringSlice([]string{clusterIdentifier}), + Values: []string{clusterIdentifier}, }} } // rdsEmptyFilter is a helper func to construct an empty RDS filter. -func rdsEmptyFilter() []*rds.Filter { - return []*rds.Filter{} +func rdsEmptyFilter() []rdstypes.Filter { + return []rdstypes.Filter{} } // rdsFilterFn is a function that takes RDS filters and performs some operation with them, returning any error encountered. -type rdsFilterFn func([]*rds.Filter) error +type rdsFilterFn func([]rdstypes.Filter) error // retryWithIndividualEngineFilters is a helper error handling function for AWS RDS unrecognized engine name filter errors, // that will call the provided RDS querying function with filters, check the returned error, diff --git a/lib/srv/discovery/fetchers/db/aws_rds_proxy.go b/lib/srv/discovery/fetchers/db/aws_rds_proxy.go index dde1a1a189940..59adf7f7f5b88 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_proxy.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_proxy.go @@ -21,14 +21,14 @@ package db import ( "context" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/cloud/awsconfig" "github.com/gravitational/teleport/lib/srv/discovery/common" ) @@ -47,56 +47,57 @@ func (f *rdsDBProxyPlugin) ComponentShortName() string { // GetDatabases returns a list of database resources representing RDS 
// Proxies and custom endpoints. func (f *rdsDBProxyPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig) (types.Databases, error) { - rdsClient, err := cfg.AWSClients.GetAWSRDSClient(ctx, cfg.Region, - cloud.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), - cloud.WithCredentialsMaybeIntegration(cfg.Integration), + awsCfg, err := cfg.AWSConfigProvider.GetConfig(ctx, cfg.Region, + awsconfig.WithAssumeRole(cfg.AssumeRole.RoleARN, cfg.AssumeRole.ExternalID), + awsconfig.WithCredentialsMaybeIntegration(cfg.Integration), ) if err != nil { return nil, trace.Wrap(err) } + clt := cfg.awsClients.GetRDSClient(awsCfg) // Get a list of all RDS Proxies. Each RDS Proxy has one "default" // endpoint. - rdsProxies, err := getRDSProxies(ctx, rdsClient, maxAWSPages) + rdsProxies, err := getRDSProxies(ctx, clt, maxAWSPages) if err != nil { return nil, trace.Wrap(err) } // Get all RDS Proxy custom endpoints sorted by the name of the RDS Proxy // that owns the custom endpoints. 
- customEndpointsByProxyName, err := getRDSProxyCustomEndpoints(ctx, rdsClient, maxAWSPages) + customEndpointsByProxyName, err := getRDSProxyCustomEndpoints(ctx, clt, maxAWSPages) if err != nil { cfg.Logger.DebugContext(ctx, "Failed to get RDS Proxy endpoints", "error", err) } var databases types.Databases for _, dbProxy := range rdsProxies { - if !aws.BoolValue(dbProxy.RequireTLS) { - cfg.Logger.DebugContext(ctx, "Skipping RDS Proxy that doesn't support TLS", "rds_proxy", aws.StringValue(dbProxy.DBProxyName)) + if !aws.ToBool(dbProxy.RequireTLS) { + cfg.Logger.DebugContext(ctx, "Skipping RDS Proxy that doesn't support TLS", "rds_proxy", aws.ToString(dbProxy.DBProxyName)) continue } - if !libcloudaws.IsRDSProxyAvailable(dbProxy) { + if !libcloudaws.IsRDSProxyAvailable(&dbProxy) { cfg.Logger.DebugContext(ctx, "Skipping unavailable RDS Proxy", - "rds_proxy", aws.StringValue(dbProxy.DBProxyName), - "status", aws.StringValue(dbProxy.Status)) + "rds_proxy", aws.ToString(dbProxy.DBProxyName), + "status", dbProxy.Status) continue } - // rds.DBProxy has no tags information. An extra SDK call is made to + // rdstypes.DBProxy has no tags information. An extra SDK call is made to // fetch the tags. If failed, keep going without the tags. - tags, err := listRDSResourceTags(ctx, rdsClient, dbProxy.DBProxyArn) + tags, err := listRDSResourceTags(ctx, clt, dbProxy.DBProxyArn) if err != nil { cfg.Logger.DebugContext(ctx, "Failed to get tags for RDS Proxy", - "rds_proxy", aws.StringValue(dbProxy.DBProxyName), + "rds_proxy", aws.ToString(dbProxy.DBProxyName), "error", err, ) } // Add a database from RDS Proxy (default endpoint). 
- database, err := common.NewDatabaseFromRDSProxy(dbProxy, tags) + database, err := common.NewDatabaseFromRDSProxy(&dbProxy, tags) if err != nil { cfg.Logger.DebugContext(ctx, "Could not convert RDS Proxy to database resource", - "rds_proxy", aws.StringValue(dbProxy.DBProxyName), + "rds_proxy", aws.ToString(dbProxy.DBProxyName), "error", err, ) } else { @@ -104,21 +105,21 @@ func (f *rdsDBProxyPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConf } // Add custom endpoints. - for _, customEndpoint := range customEndpointsByProxyName[aws.StringValue(dbProxy.DBProxyName)] { - if !libcloudaws.IsRDSProxyCustomEndpointAvailable(customEndpoint) { + for _, customEndpoint := range customEndpointsByProxyName[aws.ToString(dbProxy.DBProxyName)] { + if !libcloudaws.IsRDSProxyCustomEndpointAvailable(&customEndpoint) { cfg.Logger.DebugContext(ctx, "Skipping unavailable custom endpoint of RDS Proxy", - "endpoint", aws.StringValue(customEndpoint.DBProxyEndpointName), - "rds_proxy", aws.StringValue(customEndpoint.DBProxyName), - "status", aws.StringValue(customEndpoint.Status), + "endpoint", aws.ToString(customEndpoint.DBProxyEndpointName), + "rds_proxy", aws.ToString(customEndpoint.DBProxyName), + "status", customEndpoint.Status, ) continue } - database, err = common.NewDatabaseFromRDSProxyCustomEndpoint(dbProxy, customEndpoint, tags) + database, err = common.NewDatabaseFromRDSProxyCustomEndpoint(&dbProxy, &customEndpoint, tags) if err != nil { cfg.Logger.DebugContext(ctx, "Could not convert custom endpoint for RDS Proxy to database resource", - "endpoint", aws.StringValue(customEndpoint.DBProxyEndpointName), - "rds_proxy", aws.StringValue(customEndpoint.DBProxyName), + "endpoint", aws.ToString(customEndpoint.DBProxyEndpointName), + "rds_proxy", aws.ToString(customEndpoint.DBProxyName), "error", err, ) continue @@ -132,46 +133,54 @@ func (f *rdsDBProxyPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConf // getRDSProxies fetches all RDS Proxies using the provided 
client, up to the // specified max number of pages. -func getRDSProxies(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int) (rdsProxies []*rds.DBProxy, err error) { - var pageNum int - err = rdsClient.DescribeDBProxiesPagesWithContext( - ctx, +func getRDSProxies(ctx context.Context, clt RDSClient, maxPages int) ([]rdstypes.DBProxy, error) { + pager := rds.NewDescribeDBProxiesPaginator(clt, &rds.DescribeDBProxiesInput{}, - func(ddo *rds.DescribeDBProxiesOutput, lastPage bool) bool { - pageNum++ - rdsProxies = append(rdsProxies, ddo.DBProxies...) - return pageNum <= maxPages + func(dcpo *rds.DescribeDBProxiesPaginatorOptions) { + dcpo.StopOnDuplicateToken = true }, ) - return rdsProxies, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) + + var rdsProxies []rdstypes.DBProxy + for i := 0; i < maxPages && pager.HasMorePages(); i++ { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, trace.Wrap(libcloudaws.ConvertRequestFailureErrorV2(err)) + } + rdsProxies = append(rdsProxies, page.DBProxies...) + } + return rdsProxies, nil } // getRDSProxyCustomEndpoints fetches all RDS Proxy custom endpoints using the // provided client. 
-func getRDSProxyCustomEndpoints(ctx context.Context, rdsClient rdsiface.RDSAPI, maxPages int) (map[string][]*rds.DBProxyEndpoint, error) { - customEndpointsByProxyName := make(map[string][]*rds.DBProxyEndpoint) - var pageNum int - err := rdsClient.DescribeDBProxyEndpointsPagesWithContext( - ctx, +func getRDSProxyCustomEndpoints(ctx context.Context, clt RDSClient, maxPages int) (map[string][]rdstypes.DBProxyEndpoint, error) { + customEndpointsByProxyName := make(map[string][]rdstypes.DBProxyEndpoint) + pager := rds.NewDescribeDBProxyEndpointsPaginator(clt, &rds.DescribeDBProxyEndpointsInput{}, - func(ddo *rds.DescribeDBProxyEndpointsOutput, lastPage bool) bool { - pageNum++ - for _, customEndpoint := range ddo.DBProxyEndpoints { - customEndpointsByProxyName[aws.StringValue(customEndpoint.DBProxyName)] = append(customEndpointsByProxyName[aws.StringValue(customEndpoint.DBProxyName)], customEndpoint) - } - return pageNum <= maxPages + func(ddepo *rds.DescribeDBProxyEndpointsPaginatorOptions) { + ddepo.StopOnDuplicateToken = true }, ) - return customEndpointsByProxyName, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) + for i := 0; i < maxPages && pager.HasMorePages(); i++ { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, trace.Wrap(libcloudaws.ConvertRequestFailureErrorV2(err)) + } + for _, customEndpoint := range page.DBProxyEndpoints { + customEndpointsByProxyName[aws.ToString(customEndpoint.DBProxyName)] = append(customEndpointsByProxyName[aws.ToString(customEndpoint.DBProxyName)], customEndpoint) + } + } + return customEndpointsByProxyName, nil } // listRDSResourceTags returns tags for provided RDS resource. 
-func listRDSResourceTags(ctx context.Context, rdsClient rdsiface.RDSAPI, resourceName *string) ([]*rds.Tag, error) { - output, err := rdsClient.ListTagsForResourceWithContext(ctx, &rds.ListTagsForResourceInput{ +func listRDSResourceTags(ctx context.Context, clt RDSClient, resourceName *string) ([]rdstypes.Tag, error) { + output, err := clt.ListTagsForResource(ctx, &rds.ListTagsForResourceInput{ ResourceName: resourceName, }) if err != nil { - return nil, trace.Wrap(libcloudaws.ConvertRequestFailureError(err)) + return nil, trace.Wrap(libcloudaws.ConvertRequestFailureErrorV2(err)) } return output.TagList, nil } diff --git a/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go b/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go index b92ff2a439eda..99af538f590f4 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_proxy_test.go @@ -21,11 +21,10 @@ package db import ( "testing" - "github.com/aws/aws-sdk-go/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/srv/discovery/common" ) @@ -41,10 +40,12 @@ func TestRDSDBProxyFetcher(t *testing.T) { tests := []awsFetcherTest{ { name: "fetch all", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBProxies: []*rds.DBProxy{rdsProxyVpc1, rdsProxyVpc2}, - DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyEndpointVpc1, rdsProxyEndpointVpc2}, + fetcherCfg: AWSFetcherFactoryConfig{ + AWSClients: fakeAWSClients{ + rdsClient: &mocks.RDSClient{ + DBProxies: []rdstypes.DBProxy{*rdsProxyVpc1, *rdsProxyVpc2}, + DBProxyEndpoints: []rdstypes.DBProxyEndpoint{*rdsProxyEndpointVpc1, *rdsProxyEndpointVpc2}, + }, }, }, inputMatchers: makeAWSMatchersForType(types.AWSMatcherRDSProxy, "us-east-1", wildcardLabels), @@ -52,10 +53,12 @@ 
func TestRDSDBProxyFetcher(t *testing.T) { }, { name: "fetch vpc1", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBProxies: []*rds.DBProxy{rdsProxyVpc1, rdsProxyVpc2}, - DBProxyEndpoints: []*rds.DBProxyEndpoint{rdsProxyEndpointVpc1, rdsProxyEndpointVpc2}, + fetcherCfg: AWSFetcherFactoryConfig{ + AWSClients: fakeAWSClients{ + rdsClient: &mocks.RDSClient{ + DBProxies: []rdstypes.DBProxy{*rdsProxyVpc1, *rdsProxyVpc2}, + DBProxyEndpoints: []rdstypes.DBProxyEndpoint{*rdsProxyEndpointVpc1, *rdsProxyEndpointVpc2}, + }, }, }, inputMatchers: makeAWSMatchersForType(types.AWSMatcherRDSProxy, "us-east-1", map[string]string{"vpc-id": "vpc1"}), @@ -65,7 +68,7 @@ func TestRDSDBProxyFetcher(t *testing.T) { testAWSFetchers(t, tests...) } -func makeRDSProxy(t *testing.T, name, region, vpcID string) (*rds.DBProxy, types.Database) { +func makeRDSProxy(t *testing.T, name, region, vpcID string) (*rdstypes.DBProxy, types.Database) { rdsProxy := mocks.RDSProxy(name, region, vpcID) rdsProxyDatabase, err := common.NewDatabaseFromRDSProxy(rdsProxy, nil) require.NoError(t, err) @@ -73,7 +76,7 @@ func makeRDSProxy(t *testing.T, name, region, vpcID string) (*rds.DBProxy, types return rdsProxy, rdsProxyDatabase } -func makeRDSProxyCustomEndpoint(t *testing.T, rdsProxy *rds.DBProxy, name, region string) (*rds.DBProxyEndpoint, types.Database) { +func makeRDSProxyCustomEndpoint(t *testing.T, rdsProxy *rdstypes.DBProxy, name, region string) (*rdstypes.DBProxyEndpoint, types.Database) { rdsProxyEndpoint := mocks.RDSProxyCustomEndpoint(rdsProxy, name, region) rdsProxyEndpointDatabase, err := common.NewDatabaseFromRDSProxyCustomEndpoint(rdsProxy, rdsProxyEndpoint, nil) require.NoError(t, err) diff --git a/lib/srv/discovery/fetchers/db/aws_rds_test.go b/lib/srv/discovery/fetchers/db/aws_rds_test.go index 9dfc658268eeb..db4aeeb376cc3 100644 --- a/lib/srv/discovery/fetchers/db/aws_rds_test.go +++ b/lib/srv/discovery/fetchers/db/aws_rds_test.go @@ -21,13 +21,13 @@ package db import ( 
"testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/rds" - "github.com/aws/aws-sdk-go/service/rds/rdsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdstypes "github.com/aws/aws-sdk-go-v2/service/rds/types" + "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cloud/mocks" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/discovery/common" @@ -38,8 +38,8 @@ import ( func TestRDSFetchers(t *testing.T) { t.Parallel() - auroraMySQLEngine := &rds.DBEngineVersion{Engine: aws.String(services.RDSEngineAuroraMySQL)} - postgresEngine := &rds.DBEngineVersion{Engine: aws.String(services.RDSEnginePostgres)} + auroraMySQLEngine := &rdstypes.DBEngineVersion{Engine: aws.String(services.RDSEngineAuroraMySQL)} + postgresEngine := &rdstypes.DBEngineVersion{Engine: aws.String(services.RDSEnginePostgres)} rdsInstance1, rdsDatabase1 := makeRDSInstance(t, "instance-1", "us-east-1", envProdLabels) rdsInstance2, rdsDatabase2 := makeRDSInstance(t, "instance-2", "us-east-2", envProdLabels) @@ -58,19 +58,19 @@ func TestRDSFetchers(t *testing.T) { tests := []awsFetcherTest{ { name: "fetch all", - inputClients: &cloud.TestCloudClients{ - RDSPerRegion: map[string]rdsiface.RDSAPI{ - "us-east-1": &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance1, rdsInstance3, auroraCluster1MemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster1}, - DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine, postgresEngine}, + fetcherCfg: AWSFetcherFactoryConfig{ + AWSClients: newRegionalFakeRDSClientProvider(map[string]RDSClient{ + "us-east-1": &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance1, *rdsInstance3, *auroraCluster1MemberInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster1}, + 
DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine, *postgresEngine}, }, - "us-east-2": &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance2, auroraCluster2MemberInstance, auroraCluster3MemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster2, auroraCluster3}, - DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine, postgresEngine}, + "us-east-2": &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance2, *auroraCluster2MemberInstance, *auroraCluster3MemberInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster2, *auroraCluster3}, + DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine, *postgresEngine}, }, - }, + }), }, inputMatchers: []types.AWSMatcher{ { @@ -91,19 +91,19 @@ func TestRDSFetchers(t *testing.T) { }, { name: "fetch different labels for different regions", - inputClients: &cloud.TestCloudClients{ - RDSPerRegion: map[string]rdsiface.RDSAPI{ - "us-east-1": &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance1, rdsInstance3, auroraCluster1MemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster1}, - DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine, postgresEngine}, + fetcherCfg: AWSFetcherFactoryConfig{ + AWSClients: newRegionalFakeRDSClientProvider(map[string]RDSClient{ + "us-east-1": &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance1, *rdsInstance3, *auroraCluster1MemberInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster1}, + DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine, *postgresEngine}, }, - "us-east-2": &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance2, auroraCluster2MemberInstance, auroraCluster3MemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster2, auroraCluster3}, - DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine, postgresEngine}, + "us-east-2": &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance2, *auroraCluster2MemberInstance, *auroraCluster3MemberInstance}, + 
DBClusters: []rdstypes.DBCluster{*auroraCluster2, *auroraCluster3}, + DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine, *postgresEngine}, }, - }, + }), }, inputMatchers: []types.AWSMatcher{ { @@ -124,19 +124,19 @@ func TestRDSFetchers(t *testing.T) { }, { name: "skip unrecognized engines", - inputClients: &cloud.TestCloudClients{ - RDSPerRegion: map[string]rdsiface.RDSAPI{ - "us-east-1": &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance1, rdsInstance3, auroraCluster1MemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster1}, - DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine}, + fetcherCfg: AWSFetcherFactoryConfig{ + AWSClients: newRegionalFakeRDSClientProvider(map[string]RDSClient{ + "us-east-1": &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance1, *rdsInstance3, *auroraCluster1MemberInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster1}, + DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine}, }, - "us-east-2": &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance2, auroraCluster2MemberInstance, auroraCluster3MemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster2, auroraCluster3}, - DBEngineVersions: []*rds.DBEngineVersion{postgresEngine}, + "us-east-2": &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance2, *auroraCluster2MemberInstance, *auroraCluster3MemberInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster2, *auroraCluster3}, + DBEngineVersions: []rdstypes.DBEngineVersion{*postgresEngine}, }, - }, + }), }, inputMatchers: []types.AWSMatcher{ { @@ -154,14 +154,14 @@ func TestRDSFetchers(t *testing.T) { }, { name: "skip unsupported databases", - inputClients: &cloud.TestCloudClients{ - RDSPerRegion: map[string]rdsiface.RDSAPI{ - "us-east-1": &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{auroraCluster1MemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster1, auroraClusterUnsupported}, - DBEngineVersions: 
[]*rds.DBEngineVersion{auroraMySQLEngine}, + fetcherCfg: AWSFetcherFactoryConfig{ + AWSClients: newRegionalFakeRDSClientProvider(map[string]RDSClient{ + "us-east-1": &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*auroraCluster1MemberInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster1, *auroraClusterUnsupported}, + DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine}, }, - }, + }), }, inputMatchers: []types.AWSMatcher{{ Types: []string{types.AWSMatcherRDS}, @@ -172,11 +172,13 @@ func TestRDSFetchers(t *testing.T) { }, { name: "skip unavailable databases", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBInstances: []*rds.DBInstance{rdsInstance1, rdsInstanceUnavailable, rdsInstanceUnknownStatus, auroraCluster1MemberInstance, auroraClusterUnknownStatusMemberInstance}, - DBClusters: []*rds.DBCluster{auroraCluster1, auroraClusterUnavailable, auroraClusterUnknownStatus}, - DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine, postgresEngine}, + fetcherCfg: AWSFetcherFactoryConfig{ + AWSClients: fakeAWSClients{ + rdsClient: &mocks.RDSClient{ + DBInstances: []rdstypes.DBInstance{*rdsInstance1, *rdsInstanceUnavailable, *rdsInstanceUnknownStatus, *auroraCluster1MemberInstance, *auroraClusterUnknownStatusMemberInstance}, + DBClusters: []rdstypes.DBCluster{*auroraCluster1, *auroraClusterUnavailable, *auroraClusterUnknownStatus}, + DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine, *postgresEngine}, + }, }, }, inputMatchers: []types.AWSMatcher{{ @@ -188,11 +190,13 @@ func TestRDSFetchers(t *testing.T) { }, { name: "Aurora cluster without writer", - inputClients: &cloud.TestCloudClients{ - RDS: &mocks.RDSMock{ - DBClusters: []*rds.DBCluster{auroraClusterNoWriter}, - DBInstances: []*rds.DBInstance{auroraClusterMemberNoWriter}, - DBEngineVersions: []*rds.DBEngineVersion{auroraMySQLEngine}, + fetcherCfg: AWSFetcherFactoryConfig{ + AWSClients: fakeAWSClients{ + rdsClient: &mocks.RDSClient{ + DBClusters: 
[]rdstypes.DBCluster{*auroraClusterNoWriter}, + DBInstances: []rdstypes.DBInstance{*auroraClusterMemberNoWriter}, + DBEngineVersions: []rdstypes.DBEngineVersion{*auroraMySQLEngine}, + }, }, }, inputMatchers: []types.AWSMatcher{{ @@ -206,7 +210,7 @@ func TestRDSFetchers(t *testing.T) { testAWSFetchers(t, tests...) } -func makeRDSInstance(t *testing.T, name, region string, labels map[string]string, opts ...func(*rds.DBInstance)) (*rds.DBInstance, types.Database) { +func makeRDSInstance(t *testing.T, name, region string, labels map[string]string, opts ...func(*rdstypes.DBInstance)) (*rdstypes.DBInstance, types.Database) { instance := mocks.RDSInstance(name, region, labels, opts...) database, err := common.NewDatabaseFromRDSInstance(instance) require.NoError(t, err) @@ -214,21 +218,21 @@ func makeRDSInstance(t *testing.T, name, region string, labels map[string]string return instance, database } -func makeRDSCluster(t *testing.T, name, region string, labels map[string]string, opts ...func(*rds.DBCluster)) (*rds.DBCluster, *rds.DBInstance, types.Database) { +func makeRDSCluster(t *testing.T, name, region string, labels map[string]string, opts ...func(*rdstypes.DBCluster)) (*rdstypes.DBCluster, *rdstypes.DBInstance, types.Database) { cluster := mocks.RDSCluster(name, region, labels, opts...) 
dbInstanceMember := makeRDSMemberForCluster(t, name, region, "vpc-123", *cluster.Engine, labels) - database, err := common.NewDatabaseFromRDSCluster(cluster, []*rds.DBInstance{dbInstanceMember}) + database, err := common.NewDatabaseFromRDSCluster(cluster, []rdstypes.DBInstance{*dbInstanceMember}) require.NoError(t, err) common.ApplyAWSDatabaseNameSuffix(database, types.AWSMatcherRDS) return cluster, dbInstanceMember, database } -func makeRDSMemberForCluster(t *testing.T, name, region, vpcid, engine string, labels map[string]string) *rds.DBInstance { - instanceRDSMember, _ := makeRDSInstance(t, name+"-instance-1", region, labels, func(d *rds.DBInstance) { +func makeRDSMemberForCluster(t *testing.T, name, region, vpcid, engine string, labels map[string]string) *rdstypes.DBInstance { + instanceRDSMember, _ := makeRDSInstance(t, name+"-instance-1", region, labels, func(d *rdstypes.DBInstance) { if d.DBSubnetGroup == nil { - d.DBSubnetGroup = &rds.DBSubnetGroup{} + d.DBSubnetGroup = &rdstypes.DBSubnetGroup{} } - d.DBSubnetGroup.SetVpcId(vpcid) + d.DBSubnetGroup.VpcId = aws.String(vpcid) d.DBClusterIdentifier = aws.String(name) d.Engine = aws.String(engine) }) @@ -236,9 +240,9 @@ func makeRDSMemberForCluster(t *testing.T, name, region, vpcid, engine string, l return instanceRDSMember } -func makeRDSClusterWithExtraEndpoints(t *testing.T, name, region string, labels map[string]string, hasWriter bool) (*rds.DBCluster, *rds.DBInstance, types.Databases) { +func makeRDSClusterWithExtraEndpoints(t *testing.T, name, region string, labels map[string]string, hasWriter bool) (*rdstypes.DBCluster, *rdstypes.DBInstance, types.Databases) { cluster := mocks.RDSCluster(name, region, labels, - func(cluster *rds.DBCluster) { + func(cluster *rdstypes.DBCluster) { // Disable writer by default. If hasWriter, writer endpoint will be added below. 
cluster.DBClusterMembers = nil }, @@ -249,11 +253,11 @@ func makeRDSClusterWithExtraEndpoints(t *testing.T, name, region string, labels var databases types.Databases - instanceRDSMember := makeRDSMemberForCluster(t, name, region, "vpc-123", aws.StringValue(cluster.Engine), labels) - dbInstanceMembers := []*rds.DBInstance{instanceRDSMember} + instanceRDSMember := makeRDSMemberForCluster(t, name, region, "vpc-123", aws.ToString(cluster.Engine), labels) + dbInstanceMembers := []rdstypes.DBInstance{*instanceRDSMember} if hasWriter { - cluster.DBClusterMembers = append(cluster.DBClusterMembers, &rds.DBClusterMember{ + cluster.DBClusterMembers = append(cluster.DBClusterMembers, rdstypes.DBClusterMember{ IsClusterWriter: aws.Bool(true), // Add writer. }) @@ -277,22 +281,49 @@ func makeRDSClusterWithExtraEndpoints(t *testing.T, name, region string, labels } // withRDSInstanceStatus returns an option function for makeRDSInstance to overwrite status. -func withRDSInstanceStatus(status string) func(*rds.DBInstance) { - return func(instance *rds.DBInstance) { +func withRDSInstanceStatus(status string) func(*rdstypes.DBInstance) { + return func(instance *rdstypes.DBInstance) { instance.DBInstanceStatus = aws.String(status) } } // withRDSClusterEngineMode returns an option function for makeRDSCluster to overwrite engine mode. -func withRDSClusterEngineMode(mode string) func(*rds.DBCluster) { - return func(cluster *rds.DBCluster) { +func withRDSClusterEngineMode(mode string) func(*rdstypes.DBCluster) { + return func(cluster *rdstypes.DBCluster) { cluster.EngineMode = aws.String(mode) } } // withRDSClusterStatus returns an option function for makeRDSCluster to overwrite status. 
-func withRDSClusterStatus(status string) func(*rds.DBCluster) { - return func(cluster *rds.DBCluster) { +func withRDSClusterStatus(status string) func(*rdstypes.DBCluster) { + return func(cluster *rdstypes.DBCluster) { cluster.Status = aws.String(status) } } + +// provides a client specific to each region, where the map keys are regions. +func newRegionalFakeRDSClientProvider(cs map[string]RDSClient) fakeRegionalRDSClients { + return fakeRegionalRDSClients{rdsClients: cs} +} + +type fakeAWSClients struct { + rdsClient RDSClient + redshiftClient RedshiftClient +} + +func (f fakeAWSClients) GetRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient { + return f.rdsClient +} + +func (f fakeAWSClients) GetRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient { + return f.redshiftClient +} + +type fakeRegionalRDSClients struct { + AWSClientProvider + rdsClients map[string]RDSClient +} + +func (f fakeRegionalRDSClients) GetRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient { + return f.rdsClients[cfg.Region] +} diff --git a/lib/srv/discovery/fetchers/db/aws_redshift.go b/lib/srv/discovery/fetchers/db/aws_redshift.go index 0cda0b478e67b..b6a17f32ede5e 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift.go @@ -32,9 +32,6 @@ import ( "github.com/gravitational/teleport/lib/srv/discovery/common" ) -// RedshiftClientProviderFunc provides a [RedshiftClient]. -type RedshiftClientProviderFunc func(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient - // RedshiftClient is a subset of the AWS Redshift API. 
type RedshiftClient interface { redshift.DescribeClustersAPIClient @@ -57,7 +54,7 @@ func (f *redshiftPlugin) GetDatabases(ctx context.Context, cfg *awsFetcherConfig if err != nil { return nil, trace.Wrap(err) } - clusters, err := getRedshiftClusters(ctx, cfg.redshiftClientProviderFn(awsCfg)) + clusters, err := getRedshiftClusters(ctx, cfg.awsClients.GetRedshiftClient(awsCfg)) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/discovery/fetchers/db/aws_redshift_test.go b/lib/srv/discovery/fetchers/db/aws_redshift_test.go index ded47035e96e3..8e95641f7931c 100644 --- a/lib/srv/discovery/fetchers/db/aws_redshift_test.go +++ b/lib/srv/discovery/fetchers/db/aws_redshift_test.go @@ -22,7 +22,6 @@ import ( "testing" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/redshift" redshifttypes "github.com/aws/aws-sdk-go-v2/service/redshift/types" "github.com/stretchr/testify/require" @@ -31,11 +30,6 @@ import ( "github.com/gravitational/teleport/lib/srv/discovery/common" ) -func newFakeRedshiftClientProvider(c RedshiftClient) RedshiftClientProviderFunc { - return func(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient { - return c - } -} func TestRedshiftFetcher(t *testing.T) { t.Parallel() @@ -48,9 +42,11 @@ func TestRedshiftFetcher(t *testing.T) { { name: "fetch all", fetcherCfg: AWSFetcherFactoryConfig{ - RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{ - Clusters: []redshifttypes.Cluster{*redshiftUse1Prod, *redshiftUse1Dev}, - }), + AWSClients: fakeAWSClients{ + redshiftClient: &mocks.RedshiftClient{ + Clusters: []redshifttypes.Cluster{*redshiftUse1Prod, *redshiftUse1Dev}, + }, + }, }, inputMatchers: makeAWSMatchersForType(types.AWSMatcherRedshift, "us-east-1", wildcardLabels), wantDatabases: types.Databases{redshiftDatabaseUse1Prod, redshiftDatabaseUse1Dev}, @@ -58,9 +54,11 @@ func TestRedshiftFetcher(t *testing.T) { { name: "fetch prod", fetcherCfg: AWSFetcherFactoryConfig{ - 
RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{ - Clusters: []redshifttypes.Cluster{*redshiftUse1Prod, *redshiftUse1Dev}, - }), + AWSClients: fakeAWSClients{ + redshiftClient: &mocks.RedshiftClient{ + Clusters: []redshifttypes.Cluster{*redshiftUse1Prod, *redshiftUse1Dev}, + }, + }, }, inputMatchers: makeAWSMatchersForType(types.AWSMatcherRedshift, "us-east-1", envProdLabels), wantDatabases: types.Databases{redshiftDatabaseUse1Prod}, @@ -68,9 +66,11 @@ func TestRedshiftFetcher(t *testing.T) { { name: "skip unavailable", fetcherCfg: AWSFetcherFactoryConfig{ - RedshiftClientProviderFn: newFakeRedshiftClientProvider(&mocks.RedshiftClient{ - Clusters: []redshifttypes.Cluster{*redshiftUse1Prod, *redshiftUse1Unavailable, *redshiftUse1UnknownStatus}, - }), + AWSClients: fakeAWSClients{ + redshiftClient: &mocks.RedshiftClient{ + Clusters: []redshifttypes.Cluster{*redshiftUse1Prod, *redshiftUse1Unavailable, *redshiftUse1UnknownStatus}, + }, + }, }, inputMatchers: makeAWSMatchersForType(types.AWSMatcherRedshift, "us-east-1", wildcardLabels), wantDatabases: types.Databases{redshiftDatabaseUse1Prod, redshiftDatabaseUnknownStatus}, diff --git a/lib/srv/discovery/fetchers/db/db.go b/lib/srv/discovery/fetchers/db/db.go index 8d79bc2bb65bc..cd4df7269a14e 100644 --- a/lib/srv/discovery/fetchers/db/db.go +++ b/lib/srv/discovery/fetchers/db/db.go @@ -23,6 +23,7 @@ import ( "log/slog" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/rds" "github.com/aws/aws-sdk-go-v2/service/redshift" "github.com/gravitational/trace" "golang.org/x/exp/maps" @@ -67,14 +68,32 @@ func IsAzureMatcherType(matcherType string) bool { return len(makeAzureFetcherFuncs[matcherType]) > 0 } +// AWSClientProvider provides AWS service API clients. +type AWSClientProvider interface { + // GetRDSClient provides an [RDSClient]. + GetRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient + // GetRedshiftClient provides an [RedshiftClient]. 
+ GetRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient +} + +type defaultAWSClients struct{} + +func (defaultAWSClients) GetRDSClient(cfg aws.Config, optFns ...func(*rds.Options)) RDSClient { + return rds.NewFromConfig(cfg, optFns...) +} + +func (defaultAWSClients) GetRedshiftClient(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient { + return redshift.NewFromConfig(cfg, optFns...) +} + // AWSFetcherFactoryConfig is the configuration for an [AWSFetcherFactory]. type AWSFetcherFactoryConfig struct { // AWSConfigProvider provides [aws.Config] for AWS SDK service clients. AWSConfigProvider awsconfig.Provider + // AWSClients provides AWS SDK clients. + AWSClients AWSClientProvider // CloudClients is an interface for retrieving AWS SDK v1 cloud clients. CloudClients cloud.AWSClients - // RedshiftClientProviderFn is an optional function that provides - RedshiftClientProviderFn RedshiftClientProviderFunc } func (c *AWSFetcherFactoryConfig) checkAndSetDefaults() error { @@ -84,10 +103,8 @@ func (c *AWSFetcherFactoryConfig) checkAndSetDefaults() error { if c.AWSConfigProvider == nil { return trace.BadParameter("missing AWSConfigProvider") } - if c.RedshiftClientProviderFn == nil { - c.RedshiftClientProviderFn = func(cfg aws.Config, optFns ...func(*redshift.Options)) RedshiftClient { - return redshift.NewFromConfig(cfg, optFns...) 
- } + if c.AWSClients == nil { + c.AWSClients = defaultAWSClients{} } return nil } @@ -125,15 +142,15 @@ func (f *AWSFetcherFactory) MakeFetchers(ctx context.Context, matchers []types.A for _, makeFetcher := range makeFetchers { for _, region := range matcher.Regions { fetcher, err := makeFetcher(awsFetcherConfig{ - AWSClients: f.cfg.CloudClients, - Type: matcherType, - AssumeRole: assumeRole, - Labels: matcher.Tags, - Region: region, - Integration: matcher.Integration, - DiscoveryConfigName: discoveryConfigName, - AWSConfigProvider: f.cfg.AWSConfigProvider, - redshiftClientProviderFn: f.cfg.RedshiftClientProviderFn, + AWSClients: f.cfg.CloudClients, + Type: matcherType, + AssumeRole: assumeRole, + Labels: matcher.Tags, + Region: region, + Integration: matcher.Integration, + DiscoveryConfigName: discoveryConfigName, + AWSConfigProvider: f.cfg.AWSConfigProvider, + awsClients: f.cfg.AWSClients, }) if err != nil { return nil, trace.Wrap(err) From b507ebed18a94be8467b54b9b9a881f911c280b4 Mon Sep 17 00:00:00 2001 From: Steven Martin Date: Mon, 13 Jan 2025 07:37:46 -0500 Subject: [PATCH 04/15] Fix spelling in log messages (#50974) --- lib/backend/memory/memory.go | 2 +- lib/events/complete.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/backend/memory/memory.go b/lib/backend/memory/memory.go index cd00a6bb6efaa..4adb2b0779803 100644 --- a/lib/backend/memory/memory.go +++ b/lib/backend/memory/memory.go @@ -472,7 +472,7 @@ func (m *Memory) removeExpired() int { } m.heap.PopEl() m.tree.Delete(item) - m.logger.DebugContext(m.ctx, "Removed expired item.", "key", item.Key.String(), "epiry", item.Expires) + m.logger.DebugContext(m.ctx, "Removed expired item.", "key", item.Key.String(), "expiry", item.Expires) removed++ event := backend.Event{ diff --git a/lib/events/complete.go b/lib/events/complete.go index de9022391a533..881e02e80ce62 100644 --- a/lib/events/complete.go +++ b/lib/events/complete.go @@ -271,7 +271,7 @@ func (u *UploadCompleter) 
CheckUploads(ctx context.Context) error { continue } - log.DebugContext(ctx, "foud upload with parts", "part_count", len(parts)) + log.DebugContext(ctx, "found upload with parts", "part_count", len(parts)) if err := u.cfg.Uploader.CompleteUpload(ctx, upload, parts); trace.IsNotFound(err) { log.DebugContext(ctx, "Upload not found, moving on to next upload", "error", err) From 8e513b4643176d3a5285c4c05b75e8c95c2b6d44 Mon Sep 17 00:00:00 2001 From: Steven Martin Date: Mon, 13 Jan 2025 07:37:55 -0500 Subject: [PATCH 05/15] Fix spelling for log messages (#50917) --- tool/tsh/common/git_config.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tool/tsh/common/git_config.go b/tool/tsh/common/git_config.go index 89771735b30b3..6b703af251cee 100644 --- a/tool/tsh/common/git_config.go +++ b/tool/tsh/common/git_config.go @@ -124,11 +124,11 @@ func (c *gitConfigCommand) doUpdate(cf *CLIConf) error { for _, url := range strings.Split(urls, "\n") { u, err := parseGitSSHURL(url) if err != nil { - logger.DebugContext(cf.Context, "Skippig URL", "error", err, "url", url) + logger.DebugContext(cf.Context, "Skipping URL", "error", err, "url", url) continue } if !u.isGitHub() { - logger.DebugContext(cf.Context, "Skippig non-GitHub host", "host", u.Host) + logger.DebugContext(cf.Context, "Skipping non-GitHub host", "host", u.Host) continue } From 3c87bea3cf768258c9ac059e84553e43918e266f Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Mon, 13 Jan 2025 11:54:42 -0500 Subject: [PATCH 06/15] Fix incorrectly named dynamo events scaling policy (#50907) The read scaling policy name was incorrectly changed to match the write scaling policy. This prevents upgrading from a v16 cluster with a dynamo audit backend configured to use autoscaling to v17. To resolve, when conflicts are found the incorrectly named scaling policy is removed, and replaced by the correctly named one. 
--- lib/events/dynamoevents/dynamoevents.go | 49 +++++++++++++++++-------- 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/lib/events/dynamoevents/dynamoevents.go b/lib/events/dynamoevents/dynamoevents.go index 5c2036336b278..9e8a18c92b4c8 100644 --- a/lib/events/dynamoevents/dynamoevents.go +++ b/lib/events/dynamoevents/dynamoevents.go @@ -437,14 +437,14 @@ func (l *Log) configureTable(ctx context.Context, svc *applicationautoscaling.Cl readDimension: autoscalingtypes.ScalableDimensionDynamoDBTableReadCapacityUnits, writeDimension: autoscalingtypes.ScalableDimensionDynamoDBTableWriteCapacityUnits, resourceID: fmt.Sprintf("table/%s", l.Tablename), - readPolicy: fmt.Sprintf("%s-write-target-tracking-scaling-policy", l.Tablename), + readPolicy: fmt.Sprintf("%s-read-target-tracking-scaling-policy", l.Tablename), writePolicy: fmt.Sprintf("%s-write-target-tracking-scaling-policy", l.Tablename), }, { readDimension: autoscalingtypes.ScalableDimensionDynamoDBIndexReadCapacityUnits, writeDimension: autoscalingtypes.ScalableDimensionDynamoDBIndexWriteCapacityUnits, resourceID: fmt.Sprintf("table/%s/index/%s", l.Tablename, indexTimeSearchV2), - readPolicy: fmt.Sprintf("%s/index/%s-write-target-tracking-scaling-policy", l.Tablename, indexTimeSearchV2), + readPolicy: fmt.Sprintf("%s/index/%s-read-target-tracking-scaling-policy", l.Tablename, indexTimeSearchV2), writePolicy: fmt.Sprintf("%s/index/%s-write-target-tracking-scaling-policy", l.Tablename, indexTimeSearchV2), }, } @@ -472,20 +472,39 @@ func (l *Log) configureTable(ctx context.Context, svc *applicationautoscaling.Cl // Define scaling policy. Defines the ratio of {read,write} consumed capacity to // provisioned capacity DynamoDB will try and maintain. 
- if _, err := svc.PutScalingPolicy(ctx, &applicationautoscaling.PutScalingPolicyInput{ - PolicyName: aws.String(p.readPolicy), - PolicyType: autoscalingtypes.PolicyTypeTargetTrackingScaling, - ResourceId: aws.String(p.resourceID), - ScalableDimension: p.readDimension, - ServiceNamespace: autoscalingtypes.ServiceNamespaceDynamodb, - TargetTrackingScalingPolicyConfiguration: &autoscalingtypes.TargetTrackingScalingPolicyConfiguration{ - PredefinedMetricSpecification: &autoscalingtypes.PredefinedMetricSpecification{ - PredefinedMetricType: autoscalingtypes.MetricTypeDynamoDBReadCapacityUtilization, + for i := 0; i < 2; i++ { + if _, err := svc.PutScalingPolicy(ctx, &applicationautoscaling.PutScalingPolicyInput{ + PolicyName: aws.String(p.readPolicy), + PolicyType: autoscalingtypes.PolicyTypeTargetTrackingScaling, + ResourceId: aws.String(p.resourceID), + ScalableDimension: p.readDimension, + ServiceNamespace: autoscalingtypes.ServiceNamespaceDynamodb, + TargetTrackingScalingPolicyConfiguration: &autoscalingtypes.TargetTrackingScalingPolicyConfiguration{ + PredefinedMetricSpecification: &autoscalingtypes.PredefinedMetricSpecification{ + PredefinedMetricType: autoscalingtypes.MetricTypeDynamoDBReadCapacityUtilization, + }, + TargetValue: aws.Float64(l.ReadTargetValue), }, - TargetValue: aws.Float64(l.ReadTargetValue), - }, - }); err != nil { - return trace.Wrap(convertError(err)) + }); err != nil { + // The read policy name was accidentally changed to match the write policy in 17.0.0-17.1.4. This + // prevented upgrading a cluster with autoscaling enabled from v16 to v17. To resolve in + // a backwards compatible way, the read policy name was restored, however, any new clusters that + // were created between 17.0.0 and 17.1.4 need to have the misnamed policy deleted and recreated + // with the correct name. 
+ if i == 1 || !strings.Contains(err.Error(), "ValidationException: Only one TargetTrackingScaling policy for a given metric specification is allowed.") { + return trace.Wrap(convertError(err)) + } + + l.logger.DebugContext(ctx, "Fixing incorrectly named scaling policy") + if _, err := svc.DeleteScalingPolicy(ctx, &applicationautoscaling.DeleteScalingPolicyInput{ + PolicyName: aws.String(p.writePolicy), + ResourceId: aws.String(p.resourceID), + ScalableDimension: p.readDimension, + ServiceNamespace: autoscalingtypes.ServiceNamespaceDynamodb, + }); err != nil { + return trace.Wrap(convertError(err)) + } + } } if _, err := svc.PutScalingPolicy(ctx, &applicationautoscaling.PutScalingPolicyInput{ From 3653dda62282a3632229db7c10420064fefb552c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rafa=C5=82=20Cie=C5=9Blak?= Date: Mon, 13 Jan 2025 17:55:41 +0100 Subject: [PATCH 07/15] Add basic support for target port in gateways in Connect (#50912) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update type for targetSubresourceName on DocumentGateway The way DocumentsService.createGatewayDocument is implemented means that the targetSubresourceName property is always present, but it can be undefined. * Use "local port" instead of "port" in DocumentGatewapApp * Rewrite gateway FieldInputs to use styled components * Update comments in protos * useGateway: Stabilize useAsync functions of ports * Add padding to menu label if it's first child * Add support for required prop to Input and FieldInput * Add UI for changing target port * ActionButtons: Show ports of multi-port apps when VNet is not supported Now that we have support for the target port in Connect's gateways, we can show the ports and then open a gateway for that specific port on click. * Add RWMutex to gateways * Clear app gateway cert on target port change * Remove gateways/app.LocalProxyURL It was used only in tests and it made sense only for web apps anyway. 
* TestTCP: Close connections when test ends * Create context with timeout in testGatewayCertRenewal …instead of in each function that uses it. * Add tests for changing the target port of a TCP gateway * Parallelize app gateway tests within MFA/non-MFA groups * Make testGatewayConnection take ctx as first arg This will be needed in tests that check target port validation. * Validate target port of app gateways * Increase timeouts in app gateway tests * Change icons from medium to small * Use consistent spacing in AppGateway * Add godoc for ValidateTargetPort * Add retry with relogin to change target port --- .../go/teleport/lib/teleterm/v1/gateway.pb.go | 5 +- .../ts/teleport/lib/teleterm/v1/gateway_pb.ts | 5 +- integration/appaccess/appaccess_test.go | 1 + integration/appaccess/pack.go | 56 ++++ integration/proxy/proxy_helpers.go | 53 +++- integration/proxy/proxy_test.go | 30 ++- integration/proxy/teleterm_test.go | 254 +++++++++++++++--- .../apiserver/handler/handler_gateways.go | 2 +- lib/teleterm/clusters/cluster_apps.go | 31 ++- lib/teleterm/clusters/cluster_gateways.go | 31 ++- lib/teleterm/daemon/daemon.go | 24 +- lib/teleterm/gateway/app.go | 11 - lib/teleterm/gateway/app_middleware.go | 6 +- lib/teleterm/gateway/base.go | 28 +- lib/teleterm/gateway/config.go | 5 + lib/teleterm/gateway/interfaces.go | 5 +- lib/teleterm/gateway/kube.go | 2 + proto/teleport/lib/teleterm/v1/gateway.proto | 5 +- web/packages/design/src/Input/Input.tsx | 3 + web/packages/design/src/Menu/Menu.story.tsx | 12 + web/packages/design/src/Menu/MenuItem.tsx | 30 ++- web/packages/design/src/keyframes.ts | 4 + .../components/FieldInput/FieldInput.tsx | 5 +- .../teleterm/src/services/tshd/testHelpers.ts | 2 +- .../src/ui/DocumentCluster/ActionButtons.tsx | 97 ++++--- .../src/ui/DocumentGateway/useGateway.ts | 63 +++-- .../src/ui/DocumentGatewayApp/AppGateway.tsx | 208 +++++++++++--- .../DocumentGatewayApp.story.tsx | 47 +++- .../DocumentGatewayApp/DocumentGatewayApp.tsx | 13 +- 
.../src/ui/TabHost/useTabShortcuts.test.tsx | 2 + .../src/ui/components/FieldInputs.tsx | 37 +-- .../src/ui/components/OfflineGateway.tsx | 4 +- .../documentsService/connectToApp.test.ts | 28 +- .../documentsService/connectToApp.ts | 20 +- .../documentsService/documentsService.test.ts | 2 + .../documentsService/types.ts | 6 +- 36 files changed, 904 insertions(+), 233 deletions(-) diff --git a/gen/proto/go/teleport/lib/teleterm/v1/gateway.pb.go b/gen/proto/go/teleport/lib/teleterm/v1/gateway.pb.go index 612afa9d557af..17d5fd1e02179 100644 --- a/gen/proto/go/teleport/lib/teleterm/v1/gateway.pb.go +++ b/gen/proto/go/teleport/lib/teleterm/v1/gateway.pb.go @@ -59,10 +59,11 @@ type Gateway struct { LocalAddress string `protobuf:"bytes,5,opt,name=local_address,json=localAddress,proto3" json:"local_address,omitempty"` // local_port is the gateway address on localhost LocalPort string `protobuf:"bytes,6,opt,name=local_port,json=localPort,proto3" json:"local_port,omitempty"` - // protocol is the gateway protocol + // protocol is the protocol used by the gateway. For databases, it matches the type of the + // database that the gateway targets. For apps, it's either "HTTP" or "TCP". Protocol string `protobuf:"bytes,7,opt,name=protocol,proto3" json:"protocol,omitempty"` // target_subresource_name points at a subresource of the remote resource, for example a - // database name on a database server. + // database name on a database server or a target port of a multi-port TCP app. TargetSubresourceName string `protobuf:"bytes,9,opt,name=target_subresource_name,json=targetSubresourceName,proto3" json:"target_subresource_name,omitempty"` // gateway_cli_client represents a command that the user can execute to connect to the resource // through the gateway. 
diff --git a/gen/proto/ts/teleport/lib/teleterm/v1/gateway_pb.ts b/gen/proto/ts/teleport/lib/teleterm/v1/gateway_pb.ts index f6523f7cc2210..194cc93867671 100644 --- a/gen/proto/ts/teleport/lib/teleterm/v1/gateway_pb.ts +++ b/gen/proto/ts/teleport/lib/teleterm/v1/gateway_pb.ts @@ -80,14 +80,15 @@ export interface Gateway { */ localPort: string; /** - * protocol is the gateway protocol + * protocol is the protocol used by the gateway. For databases, it matches the type of the + * database that the gateway targets. For apps, it's either "HTTP" or "TCP". * * @generated from protobuf field: string protocol = 7; */ protocol: string; /** * target_subresource_name points at a subresource of the remote resource, for example a - * database name on a database server. + * database name on a database server or a target port of a multi-port TCP app. * * @generated from protobuf field: string target_subresource_name = 9; */ diff --git a/integration/appaccess/appaccess_test.go b/integration/appaccess/appaccess_test.go index dffd5f8aa1912..8bb73e091754b 100644 --- a/integration/appaccess/appaccess_test.go +++ b/integration/appaccess/appaccess_test.go @@ -831,6 +831,7 @@ func TestTCP(t *testing.T) { conn, err := net.Dial("tcp", localProxyAddress) require.NoError(t, err) + defer conn.Close() buf := make([]byte, 1024) n, err := conn.Read(buf) diff --git a/integration/appaccess/pack.go b/integration/appaccess/pack.go index 5deabac208c4c..5a5de08691da4 100644 --- a/integration/appaccess/pack.go +++ b/integration/appaccess/pack.go @@ -184,6 +184,34 @@ func (p *Pack) RootAppPublicAddr() string { return p.rootAppPublicAddr } +func (p *Pack) RootTCPAppName() string { + return p.rootTCPAppName +} + +func (p *Pack) RootTCPMessage() string { + return p.rootTCPMessage +} + +func (p *Pack) RootTCPMultiPortAppName() string { + return p.rootTCPMultiPortAppName +} + +func (p *Pack) RootTCPMultiPortAppPortAlpha() int { + return p.rootTCPMultiPortAppPortAlpha +} + +func (p *Pack) 
RootTCPMultiPortMessageAlpha() string { + return p.rootTCPMultiPortMessageAlpha +} + +func (p *Pack) RootTCPMultiPortAppPortBeta() int { + return p.rootTCPMultiPortAppPortBeta +} + +func (p *Pack) RootTCPMultiPortMessageBeta() string { + return p.rootTCPMultiPortMessageBeta +} + func (p *Pack) RootAuthServer() *auth.Server { return p.rootCluster.Process.GetAuthServer() } @@ -200,6 +228,34 @@ func (p *Pack) LeafAppPublicAddr() string { return p.leafAppPublicAddr } +func (p *Pack) LeafTCPAppName() string { + return p.leafTCPAppName +} + +func (p *Pack) LeafTCPMessage() string { + return p.leafTCPMessage +} + +func (p *Pack) LeafTCPMultiPortAppName() string { + return p.leafTCPMultiPortAppName +} + +func (p *Pack) LeafTCPMultiPortAppPortAlpha() int { + return p.leafTCPMultiPortAppPortAlpha +} + +func (p *Pack) LeafTCPMultiPortMessageAlpha() string { + return p.leafTCPMultiPortMessageAlpha +} + +func (p *Pack) LeafTCPMultiPortAppPortBeta() int { + return p.leafTCPMultiPortAppPortBeta +} + +func (p *Pack) LeafTCPMultiPortMessageBeta() string { + return p.leafTCPMultiPortMessageBeta +} + func (p *Pack) LeafAuthServer() *auth.Server { return p.leafCluster.Process.GetAuthServer() } diff --git a/integration/proxy/proxy_helpers.go b/integration/proxy/proxy_helpers.go index 789ab0f4f577f..b5796110eb53d 100644 --- a/integration/proxy/proxy_helpers.go +++ b/integration/proxy/proxy_helpers.go @@ -28,6 +28,7 @@ import ( "net/http" "net/url" "path/filepath" + "strconv" "strings" "testing" "time" @@ -684,7 +685,7 @@ func mustFindKubePod(t *testing.T, tc *client.TeleportClient) { require.Equal(t, types.KindKubePod, response.Resources[0].Kind) } -func mustConnectDatabaseGateway(t *testing.T, _ *daemon.Service, gw gateway.Gateway) { +func mustConnectDatabaseGateway(ctx context.Context, t *testing.T, _ *daemon.Service, gw gateway.Gateway) { t.Helper() dbGateway, err := gateway.AsDatabase(gw) @@ -705,15 +706,15 @@ func mustConnectDatabaseGateway(t *testing.T, _ *daemon.Service, gw 
gateway.Gate require.NoError(t, client.Close()) } -// mustConnectAppGateway verifies that the gateway acts as an unauthenticated proxy that forwards -// requests to the app behind it. -func mustConnectAppGateway(t *testing.T, _ *daemon.Service, gw gateway.Gateway) { +// mustConnectWebAppGateway verifies that the gateway acts as an unauthenticated proxy that forwards +// requests to the web app behind it. +func mustConnectWebAppGateway(ctx context.Context, t *testing.T, _ *daemon.Service, gw gateway.Gateway) { t.Helper() - appGw, err := gateway.AsApp(gw) - require.NoError(t, err) + gatewayAddress := net.JoinHostPort(gw.LocalAddress(), gw.LocalPort()) + gatewayURL := fmt.Sprintf("http://%s", gatewayAddress) - req, err := http.NewRequest(http.MethodGet, appGw.LocalProxyURL(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, gatewayURL, nil) require.NoError(t, err) client := &http.Client{} @@ -724,6 +725,44 @@ func mustConnectAppGateway(t *testing.T, _ *daemon.Service, gw gateway.Gateway) require.Equal(t, http.StatusOK, resp.StatusCode) } +func makeMustConnectMultiPortTCPAppGateway(wantMessage string, otherTargetPort int, otherWantMessage string) testGatewayConnectionFunc { + return func(ctx context.Context, t *testing.T, d *daemon.Service, gw gateway.Gateway) { + t.Helper() + + gwURI := gw.URI().String() + originalTargetPort := gw.TargetSubresourceName() + makeMustConnectTCPAppGateway(wantMessage)(ctx, t, d, gw) + + _, err := d.SetGatewayTargetSubresourceName(ctx, gwURI, strconv.Itoa(otherTargetPort)) + require.NoError(t, err) + makeMustConnectTCPAppGateway(otherWantMessage)(ctx, t, d, gw) + + // Restore the original port, so that the next time the test calls this function after certs + // expire, wantMessage is going to match the port that the gateway points to. 
+ _, err = d.SetGatewayTargetSubresourceName(ctx, gwURI, originalTargetPort) + require.NoError(t, err) + makeMustConnectTCPAppGateway(wantMessage)(ctx, t, d, gw) + } +} + +func makeMustConnectTCPAppGateway(wantMessage string) testGatewayConnectionFunc { + return func(ctx context.Context, t *testing.T, _ *daemon.Service, gw gateway.Gateway) { + t.Helper() + + gatewayAddress := net.JoinHostPort(gw.LocalAddress(), gw.LocalPort()) + conn, err := net.Dial("tcp", gatewayAddress) + require.NoError(t, err) + defer conn.Close() + + buf := make([]byte, 1024) + n, err := conn.Read(buf) + require.NoError(t, err) + + resp := strings.TrimSpace(string(buf[:n])) + require.Equal(t, wantMessage, resp) + } +} + func kubeClientForLocalProxy(t *testing.T, kubeconfigPath, teleportCluster, kubeCluster string) *kubernetes.Clientset { t.Helper() diff --git a/integration/proxy/proxy_test.go b/integration/proxy/proxy_test.go index cf75bfd5f146b..262b8c1046726 100644 --- a/integration/proxy/proxy_test.go +++ b/integration/proxy/proxy_test.go @@ -54,6 +54,7 @@ import ( "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/auth/testauthority" libclient "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/multiplexer" @@ -1315,18 +1316,29 @@ func TestALPNSNIProxyAppAccess(t *testing.T) { }) t.Run("teleterm app gateways cert renewal", func(t *testing.T) { - user, _ := pack.CreateUser(t) - tc := pack.MakeTeleportClient(t, user.GetName()) - - // test without per session MFA. 
- testTeletermAppGateway(t, pack, tc) + t.Run("without per-session MFA", func(t *testing.T) { + makeTC := func(t *testing.T) (*libclient.TeleportClient, mfa.WebauthnLoginFunc) { + user, _ := pack.CreateUser(t) + tc := pack.MakeTeleportClient(t, user.GetName()) + return tc, nil + } + testTeletermAppGateway(t, pack, makeTC) + testTeletermAppGatewayTargetPortValidation(t, pack, makeTC) + }) - t.Run("per session MFA", func(t *testing.T) { - // They update user's authentication to Webauthn so they must run after tests which do not use MFA. + t.Run("per-session MFA", func(t *testing.T) { + // They update clusters authentication to Webauthn so they must run after tests which do not use MFA. requireSessionMFAAuthPref(ctx, t, pack.RootAuthServer(), "127.0.0.1") requireSessionMFAAuthPref(ctx, t, pack.LeafAuthServer(), "127.0.0.1") - tc.WebauthnLogin = setupUserMFA(ctx, t, pack.RootAuthServer(), user.GetName(), "127.0.0.1") - testTeletermAppGateway(t, pack, tc) + makeTCAndWebauthnLogin := func(t *testing.T) (*libclient.TeleportClient, mfa.WebauthnLoginFunc) { + // Create a separate user for each tests to enable parallel tests that use per-session MFA. + // See the comment for webauthnLogin in setupUserMFA for more details. 
+ user, _ := pack.CreateUser(t) + tc := pack.MakeTeleportClient(t, user.GetName()) + webauthnLogin := setupUserMFA(ctx, t, pack.RootAuthServer(), user.GetName(), "127.0.0.1") + return tc, webauthnLogin + } + testTeletermAppGateway(t, pack, makeTCAndWebauthnLogin) }) }) } diff --git a/integration/proxy/teleterm_test.go b/integration/proxy/teleterm_test.go index 67feeda87944c..18b0efd4884c7 100644 --- a/integration/proxy/teleterm_test.go +++ b/integration/proxy/teleterm_test.go @@ -19,9 +19,11 @@ package proxy import ( + "cmp" "context" "errors" "net" + "strconv" "sync" "sync/atomic" "testing" @@ -50,9 +52,9 @@ import ( "github.com/gravitational/teleport/lib/auth/mocku2f" wancli "github.com/gravitational/teleport/lib/auth/webauthncli" wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" - "github.com/gravitational/teleport/lib/client" libclient "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/client/clientcache" + "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/service" "github.com/gravitational/teleport/lib/service/servicecfg" @@ -168,8 +170,8 @@ func testDBGatewayCertRenewal(ctx context.Context, t *testing.T, params dbGatewa TargetURI: params.databaseURI.String(), TargetUser: params.pack.Root.User.GetName(), }, - testGatewayConnectionFunc: mustConnectDatabaseGateway, - webauthnLogin: params.webauthnLogin, + testGatewayConnection: mustConnectDatabaseGateway, + webauthnLogin: params.webauthnLogin, generateAndSetupUserCreds: func(t *testing.T, tc *libclient.TeleportClient, ttl time.Duration) { creds, err := helpers.GenerateUserCreds(helpers.UserCredsRequest{ Process: params.pack.Root.Cluster.Process, @@ -184,7 +186,7 @@ func testDBGatewayCertRenewal(ctx context.Context, t *testing.T, params dbGatewa ) } -type testGatewayConnectionFunc func(*testing.T, *daemon.Service, gateway.Gateway) +type testGatewayConnectionFunc 
func(context.Context, *testing.T, *daemon.Service, gateway.Gateway) type generateAndSetupUserCredsFunc func(t *testing.T, tc *libclient.TeleportClient, ttl time.Duration) @@ -192,14 +194,19 @@ type gatewayCertRenewalParams struct { tc *libclient.TeleportClient albAddr string createGatewayParams daemon.CreateGatewayParams - testGatewayConnectionFunc testGatewayConnectionFunc + testGatewayConnection testGatewayConnectionFunc webauthnLogin libclient.WebauthnLoginFunc generateAndSetupUserCreds generateAndSetupUserCredsFunc + wantPromptMFACallCount int } func testGatewayCertRenewal(ctx context.Context, t *testing.T, params gatewayCertRenewalParams) { t.Helper() + // The test can potentially hang forever if something is wrong with the MFA prompt, add a timeout. + ctx, cancel := context.WithTimeout(ctx, time.Minute) + t.Cleanup(cancel) + tc := params.tc // Save the profile yaml file to disk as test helpers like helpers.NewClientWithCreds don't do @@ -273,7 +280,7 @@ func testGatewayCertRenewal(ctx context.Context, t *testing.T, params gatewayCer gateway, err := daemonService.CreateGateway(ctx, params.createGatewayParams) require.NoError(t, err, trace.DebugReport(err)) - params.testGatewayConnectionFunc(t, daemonService, gateway) + params.testGatewayConnection(ctx, t, daemonService, gateway) // Advance the fake clock to simulate the db cert expiry inside the middleware. fakeClock.Advance(time.Hour * 48) @@ -286,16 +293,17 @@ func testGatewayCertRenewal(ctx context.Context, t *testing.T, params gatewayCer // and then it will attempt to reissue the user cert using an expired user cert. // The mocked tshdEventsClient will issue a valid user cert, save it to disk, and the middleware // will let the connection through. 
- params.testGatewayConnectionFunc(t, daemonService, gateway) + params.testGatewayConnection(ctx, t, daemonService, gateway) require.Equal(t, uint32(1), tshdEventsService.reloginCallCount.Load(), "Unexpected number of calls to TSHDEventsClient.Relogin") require.Equal(t, uint32(0), tshdEventsService.sendNotificationCallCount.Load(), "Unexpected number of calls to TSHDEventsClient.SendNotification") if params.webauthnLogin != nil { - // There are two calls, one to issue the certs when creating the gateway and then another to - // reissue them after relogin. - require.Equal(t, uint32(2), tshdEventsService.promptMFACallCount.Load(), + // By default, there are two calls, one to issue the certs when creating the gateway and then + // another to reissue them after relogin. + wantCallCount := cmp.Or(params.wantPromptMFACallCount, 2) + require.Equal(t, uint32(wantCallCount), tshdEventsService.promptMFACallCount.Load(), "Unexpected number of calls to TSHDEventsClient.PromptMFA") } } @@ -474,9 +482,6 @@ func TestTeletermKubeGateway(t *testing.T) { t.Run("root with per-session MFA", func(t *testing.T) { profileName := mustGetProfileName(t, suite.root.Web) kubeURI := uri.NewClusterURI(profileName).AppendKube(kubeClusterName) - // The test can potentially hang forever if something is wrong with the MFA prompt, add a timeout. - ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - t.Cleanup(cancel) testKubeGatewayCertRenewal(ctx, t, kubeGatewayCertRenewalParams{ suite: suite, kubeURI: kubeURI, @@ -486,9 +491,6 @@ func TestTeletermKubeGateway(t *testing.T) { t.Run("leaf with per-session MFA", func(t *testing.T) { profileName := mustGetProfileName(t, suite.root.Web) kubeURI := uri.NewClusterURI(profileName).AppendLeafCluster(suite.leaf.Secrets.SiteName).AppendKube(kubeClusterName) - // The test can potentially hang forever if something is wrong with the MFA prompt, add a timeout. 
- ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - t.Cleanup(cancel) testKubeGatewayCertRenewal(ctx, t, kubeGatewayCertRenewalParams{ suite: suite, kubeURI: kubeURI, @@ -523,7 +525,7 @@ func testKubeGatewayCertRenewal(ctx context.Context, t *testing.T, params kubeGa }) require.NoError(t, err) - testKubeConnection := func(t *testing.T, daemonService *daemon.Service, gw gateway.Gateway) { + testKubeConnection := func(ctx context.Context, t *testing.T, daemonService *daemon.Service, gw gateway.Gateway) { t.Helper() clientOnce.Do(func() { @@ -548,8 +550,8 @@ func testKubeGatewayCertRenewal(ctx context.Context, t *testing.T, params kubeGa createGatewayParams: daemon.CreateGatewayParams{ TargetURI: params.kubeURI.String(), }, - testGatewayConnectionFunc: testKubeConnection, - webauthnLogin: params.webauthnLogin, + testGatewayConnection: testKubeConnection, + webauthnLogin: params.webauthnLogin, generateAndSetupUserCreds: func(t *testing.T, tc *libclient.TeleportClient, ttl time.Duration) { creds, err := helpers.GenerateUserCreds(helpers.UserCredsRequest{ Process: params.suite.root.Process, @@ -614,6 +616,10 @@ func setupUserMFA(ctx context.Context, t *testing.T, authServer *auth.Server, us }) require.NoError(t, err) + // webauthnLogin is not safe for concurrent use, partly due to the implementation of device, but + // mostly because Teleport itself doesn't allow for more than one in-flight MFA challenge. This is + // an arbitrary limitation which in theory we could change. But for now, parallel tests that use + // webauthnLogin must use a separate user for each test and not trigger parallel MFA prompts. 
webauthnLogin := func(ctx context.Context, origin string, assertion *wantypes.CredentialAssertion, prompt wancli.LoginPrompt, opts *wancli.LoginOpts) (*proto.MFAAuthenticateResponse, string, error) { car, err := device.SignAssertion(origin, assertion) if err != nil { @@ -676,34 +682,210 @@ func requireSessionMFARole(ctx context.Context, t *testing.T, authServer *auth.S require.NoError(t, err) } -func testTeletermAppGateway(t *testing.T, pack *appaccess.Pack, tc *client.TeleportClient) { +type makeTCAndWebauthnLoginFunc func(t *testing.T) (*libclient.TeleportClient, mfa.WebauthnLoginFunc) + +func testTeletermAppGateway(t *testing.T, pack *appaccess.Pack, makeTCAndWebauthnLogin makeTCAndWebauthnLoginFunc) { ctx := context.Background() t.Run("root cluster", func(t *testing.T) { - profileName := mustGetProfileName(t, pack.RootWebAddr()) - appURI := uri.NewClusterURI(profileName).AppendApp(pack.RootAppName()) + t.Parallel() - // The test can potentially hang forever if something is wrong with the MFA prompt, add a timeout. 
- ctx, cancel := context.WithTimeout(ctx, 10*time.Second) - t.Cleanup(cancel) - testAppGatewayCertRenewal(ctx, t, pack, tc, appURI) + t.Run("web app", func(t *testing.T) { + t.Parallel() + + profileName := mustGetProfileName(t, pack.RootWebAddr()) + appURI := uri.NewClusterURI(profileName).AppendApp(pack.RootAppName()) + + testAppGatewayCertRenewal(ctx, t, pack, makeTCAndWebauthnLogin, appURI) + }) + + t.Run("TCP app", func(t *testing.T) { + t.Parallel() + + profileName := mustGetProfileName(t, pack.RootWebAddr()) + appURI := uri.NewClusterURI(profileName).AppendApp(pack.RootTCPAppName()) + + tc, webauthnLogin := makeTCAndWebauthnLogin(t) + + testGatewayCertRenewal( + ctx, + t, + gatewayCertRenewalParams{ + tc: tc, + createGatewayParams: daemon.CreateGatewayParams{TargetURI: appURI.String()}, + testGatewayConnection: makeMustConnectTCPAppGateway(pack.RootTCPMessage()), + generateAndSetupUserCreds: pack.GenerateAndSetupUserCreds, + webauthnLogin: webauthnLogin, + }, + ) + }) + + t.Run("multi-port TCP app", func(t *testing.T) { + t.Parallel() + profileName := mustGetProfileName(t, pack.RootWebAddr()) + appURI := uri.NewClusterURI(profileName).AppendApp(pack.RootTCPMultiPortAppName()) + + tc, webauthnLogin := makeTCAndWebauthnLogin(t) + + testGatewayCertRenewal( + ctx, + t, + gatewayCertRenewalParams{ + tc: tc, + createGatewayParams: daemon.CreateGatewayParams{ + TargetURI: appURI.String(), + TargetSubresourceName: strconv.Itoa(pack.RootTCPMultiPortAppPortAlpha()), + }, + testGatewayConnection: makeMustConnectMultiPortTCPAppGateway( + pack.RootTCPMultiPortMessageAlpha(), pack.RootTCPMultiPortAppPortBeta(), pack.RootTCPMultiPortMessageBeta(), + ), + generateAndSetupUserCreds: pack.GenerateAndSetupUserCreds, + webauthnLogin: webauthnLogin, + // First MFA prompt is made when creating the gateway. Then makeMustConnectMultiPortTCPAppGateway + // changes the target port twice, which means two more prompts. 
+ // + // Then testGatewayCertRenewal expires the certs and calls + // makeMustConnectMultiPortTCPAppGateway. The first connection refreshes the expired cert, + // then the function changes the target port twice again, resulting in two more prompts. + wantPromptMFACallCount: 3 + 3, + }, + ) + }) }) t.Run("leaf cluster", func(t *testing.T) { - profileName := mustGetProfileName(t, pack.RootWebAddr()) - appURI := uri.NewClusterURI(profileName). - AppendLeafCluster(pack.LeafAppClusterName()). - AppendApp(pack.LeafAppName()) + t.Parallel() + + t.Run("web app", func(t *testing.T) { + t.Parallel() + + profileName := mustGetProfileName(t, pack.RootWebAddr()) + appURI := uri.NewClusterURI(profileName). + AppendLeafCluster(pack.LeafAppClusterName()). + AppendApp(pack.LeafAppName()) + + testAppGatewayCertRenewal(ctx, t, pack, makeTCAndWebauthnLogin, appURI) + }) + + t.Run("TCP app", func(t *testing.T) { + t.Parallel() + + profileName := mustGetProfileName(t, pack.RootWebAddr()) + appURI := uri.NewClusterURI(profileName).AppendLeafCluster(pack.LeafAppClusterName()).AppendApp(pack.LeafTCPAppName()) - // The test can potentially hang forever if something is wrong with the MFA prompt, add a timeout. 
- ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + tc, webauthnLogin := makeTCAndWebauthnLogin(t) + + testGatewayCertRenewal( + ctx, + t, + gatewayCertRenewalParams{ + tc: tc, + createGatewayParams: daemon.CreateGatewayParams{TargetURI: appURI.String()}, + testGatewayConnection: makeMustConnectTCPAppGateway(pack.LeafTCPMessage()), + generateAndSetupUserCreds: pack.GenerateAndSetupUserCreds, + webauthnLogin: webauthnLogin, + }, + ) + }) + + t.Run("multi-port TCP app", func(t *testing.T) { + t.Parallel() + + profileName := mustGetProfileName(t, pack.RootWebAddr()) + appURI := uri.NewClusterURI(profileName).AppendLeafCluster(pack.LeafAppClusterName()).AppendApp(pack.LeafTCPMultiPortAppName()) + + tc, webauthnLogin := makeTCAndWebauthnLogin(t) + + testGatewayCertRenewal( + ctx, + t, + gatewayCertRenewalParams{ + tc: tc, + createGatewayParams: daemon.CreateGatewayParams{ + TargetURI: appURI.String(), + TargetSubresourceName: strconv.Itoa(pack.LeafTCPMultiPortAppPortAlpha()), + }, + testGatewayConnection: makeMustConnectMultiPortTCPAppGateway( + pack.LeafTCPMultiPortMessageAlpha(), pack.LeafTCPMultiPortAppPortBeta(), pack.LeafTCPMultiPortMessageBeta(), + ), + generateAndSetupUserCreds: pack.GenerateAndSetupUserCreds, + webauthnLogin: webauthnLogin, + // First MFA prompt is made when creating the gateway. Then makeMustConnectMultiPortTCPAppGateway + // changes the target port twice, which means two more prompts. + // + // Then testGatewayCertRenewal expires the certs and calls + // makeMustConnectMultiPortTCPAppGateway. The first connection refreshes the expired cert, + // then the function changes the target port twice again, resulting in two more prompts. 
+ wantPromptMFACallCount: 3 + 3, + }, + ) + }) + }) +} + +func testTeletermAppGatewayTargetPortValidation(t *testing.T, pack *appaccess.Pack, makeTCAndWebauthnLogin makeTCAndWebauthnLoginFunc) { + t.Run("target port validation", func(t *testing.T) { + t.Parallel() + + tc, _ := makeTCAndWebauthnLogin(t) + err := tc.SaveProfile(false /* makeCurrent */) + require.NoError(t, err) + + storage, err := clusters.NewStorage(clusters.Config{ + Dir: tc.KeysDir, + InsecureSkipVerify: tc.InsecureSkipVerify, + HardwareKeyPromptConstructor: func(rootClusterURI uri.ResourceURI) keys.HardwareKeyPrompt { + return nil + }, + }) + require.NoError(t, err) + daemonService, err := daemon.New(daemon.Config{ + Storage: storage, + CreateTshdEventsClientCredsFunc: func() (grpc.DialOption, error) { + return grpc.WithTransportCredentials(insecure.NewCredentials()), nil + }, + CreateClientCacheFunc: func(newClient clientcache.NewClientFunc) (daemon.ClientCache, error) { + return clientcache.NewNoCache(newClient), nil + }, + KubeconfigsDir: t.TempDir(), + AgentsDir: t.TempDir(), + }) + require.NoError(t, err) + t.Cleanup(func() { + daemonService.Stop() + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) t.Cleanup(cancel) - testAppGatewayCertRenewal(ctx, t, pack, tc, appURI) + + // Here the test setup ends and actual test code starts. + profileName := mustGetProfileName(t, pack.RootWebAddr()) + appURI := uri.NewClusterURI(profileName).AppendApp(pack.RootTCPMultiPortAppName()) + + _, err = daemonService.CreateGateway(ctx, daemon.CreateGatewayParams{ + TargetURI: appURI.String(), + // 42 shouldn't be handed out to a non-root user when creating a listener on port 0, so it's + // unlikely that 42 is going to end up in the app spec. 
+ TargetSubresourceName: "42", + }) + require.True(t, trace.IsBadParameter(err), "Expected BadParameter, got %v", err) + require.ErrorContains(t, err, "not included in target ports") + + gateway, err := daemonService.CreateGateway(ctx, daemon.CreateGatewayParams{ + TargetURI: appURI.String(), + TargetSubresourceName: strconv.Itoa(pack.RootTCPMultiPortAppPortAlpha()), + }) + require.NoError(t, err) + + _, err = daemonService.SetGatewayTargetSubresourceName(ctx, gateway.URI().String(), "42") + require.True(t, trace.IsBadParameter(err), "Expected BadParameter, got %v", err) + require.ErrorContains(t, err, "not included in target ports") }) } -func testAppGatewayCertRenewal(ctx context.Context, t *testing.T, pack *appaccess.Pack, tc *libclient.TeleportClient, appURI uri.ResourceURI) { +func testAppGatewayCertRenewal(ctx context.Context, t *testing.T, pack *appaccess.Pack, makeTCAndWebauthnLogin makeTCAndWebauthnLoginFunc, appURI uri.ResourceURI) { t.Helper() + tc, webauthnLogin := makeTCAndWebauthnLogin(t) testGatewayCertRenewal( ctx, @@ -713,9 +895,9 @@ func testAppGatewayCertRenewal(ctx context.Context, t *testing.T, pack *appacces createGatewayParams: daemon.CreateGatewayParams{ TargetURI: appURI.String(), }, - testGatewayConnectionFunc: mustConnectAppGateway, + testGatewayConnection: mustConnectWebAppGateway, generateAndSetupUserCreds: pack.GenerateAndSetupUserCreds, - webauthnLogin: tc.WebauthnLogin, + webauthnLogin: webauthnLogin, }, ) } diff --git a/lib/teleterm/apiserver/handler/handler_gateways.go b/lib/teleterm/apiserver/handler/handler_gateways.go index 5a303e8e45c78..dbb0de52c9363 100644 --- a/lib/teleterm/apiserver/handler/handler_gateways.go +++ b/lib/teleterm/apiserver/handler/handler_gateways.go @@ -119,7 +119,7 @@ func makeGatewayCLICommand(cmds cmd.Cmds) *api.GatewayCLICommand { // // In Connect this is used to update the db name of a db connection along with the CLI command. 
func (s *Handler) SetGatewayTargetSubresourceName(ctx context.Context, req *api.SetGatewayTargetSubresourceNameRequest) (*api.Gateway, error) { - gateway, err := s.DaemonService.SetGatewayTargetSubresourceName(req.GatewayUri, req.TargetSubresourceName) + gateway, err := s.DaemonService.SetGatewayTargetSubresourceName(ctx, req.GatewayUri, req.TargetSubresourceName) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/teleterm/clusters/cluster_apps.go b/lib/teleterm/clusters/cluster_apps.go index 5b92788cb15b4..cfdecb8a62f66 100644 --- a/lib/teleterm/clusters/cluster_apps.go +++ b/lib/teleterm/clusters/cluster_apps.go @@ -25,6 +25,7 @@ import ( apiclient "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/client" @@ -55,11 +56,11 @@ type SAMLIdPServiceProvider struct { Provider types.SAMLIdPServiceProvider } -func (c *Cluster) getApp(ctx context.Context, authClient authclient.ClientI, appName string) (types.Application, error) { +func GetApp(ctx context.Context, authClient authclient.ClientI, appName string) (types.Application, error) { var app types.Application err := AddMetadataToRetryableError(ctx, func() error { apps, err := apiclient.GetAllResources[types.AppServer](ctx, authClient, &proto.ListResourcesRequest{ - Namespace: c.clusterClient.Namespace, + Namespace: apidefaults.Namespace, ResourceType: types.KindAppServer, PredicateExpression: fmt.Sprintf(`name == "%s"`, appName), }) @@ -143,3 +144,29 @@ func (c *Cluster) GetAWSRoles(app types.Application) aws.Roles { } return aws.Roles{} } + +// ValidateTargetPort parses rawTargetPort to uint32 and checks if it's included in TCP ports of app. +// It also returns an error if app doesn't have any TCP ports defined. 
+func ValidateTargetPort(app types.Application, rawTargetPort string) (uint32, error) { + if rawTargetPort == "" { + return 0, nil + } + + targetPort, err := parseTargetPort(rawTargetPort) + if err != nil { + return 0, trace.Wrap(err) + } + + tcpPorts := app.GetTCPPorts() + if len(tcpPorts) == 0 { + return 0, trace.BadParameter("cannot specify target port %d because app %s does not provide access to multiple ports", + targetPort, app.GetName()) + } + + if !tcpPorts.Contains(int(targetPort)) { + return 0, trace.BadParameter("port %d is not included in target ports of app %s", + targetPort, app.GetName()) + } + + return targetPort, nil +} diff --git a/lib/teleterm/clusters/cluster_gateways.go b/lib/teleterm/clusters/cluster_gateways.go index 64577c35cf7dd..61c5fa7f38df4 100644 --- a/lib/teleterm/clusters/cluster_gateways.go +++ b/lib/teleterm/clusters/cluster_gateways.go @@ -21,6 +21,7 @@ package clusters import ( "context" "crypto/tls" + "strconv" "github.com/gravitational/trace" @@ -160,7 +161,7 @@ func (c *Cluster) createKubeGateway(ctx context.Context, params CreateGatewayPar func (c *Cluster) createAppGateway(ctx context.Context, params CreateGatewayParams) (gateway.Gateway, error) { appName := params.TargetURI.GetAppName() - app, err := c.getApp(ctx, params.ClusterClient.AuthClient, appName) + app, err := GetApp(ctx, params.ClusterClient.AuthClient, appName) if err != nil { return nil, trace.Wrap(err) } @@ -170,6 +171,13 @@ func (c *Cluster) createAppGateway(ctx context.Context, params CreateGatewayPara ClusterName: c.clusterClient.SiteName, URI: app.GetURI(), } + if params.TargetSubresourceName != "" { + targetPort, err := ValidateTargetPort(app, params.TargetSubresourceName) + if err != nil { + return nil, trace.Wrap(err) + } + routeToApp.TargetPort = targetPort + } var cert tls.Certificate if err := AddMetadataToRetryableError(ctx, func() error { @@ -182,6 +190,7 @@ func (c *Cluster) createAppGateway(ctx context.Context, params CreateGatewayPara gw, err := 
gateway.New(gateway.Config{ LocalPort: params.LocalPort, TargetURI: params.TargetURI, + TargetSubresourceName: params.TargetSubresourceName, TargetName: appName, Cert: cert, Protocol: app.GetProtocol(), @@ -195,6 +204,9 @@ func (c *Cluster) createAppGateway(ctx context.Context, params CreateGatewayPara RootClusterCACertPoolFunc: c.clusterClient.RootClusterCACertPool, ClusterName: c.Name, Username: c.status.Username, + // For multi-port TCP apps, the target port is stored in the target subresource name. Whenever + // that field is updated, the local proxy needs to generate a new cert which includes that port. + ClearCertsOnTargetSubresourceNameChange: true, }) return gw, trace.Wrap(err) } @@ -214,7 +226,7 @@ func (c *Cluster) ReissueGatewayCerts(ctx context.Context, clusterClient *client return cert, trace.Wrap(err) case g.TargetURI().IsApp(): appName := g.TargetURI().GetAppName() - app, err := c.getApp(ctx, clusterClient.AuthClient, appName) + app, err := GetApp(ctx, clusterClient.AuthClient, appName) if err != nil { return tls.Certificate{}, trace.Wrap(err) } @@ -224,6 +236,13 @@ func (c *Cluster) ReissueGatewayCerts(ctx context.Context, clusterClient *client ClusterName: c.clusterClient.SiteName, URI: app.GetURI(), } + if g.TargetSubresourceName() != "" { + targetPort, err := parseTargetPort(g.TargetSubresourceName()) + if err != nil { + return tls.Certificate{}, trace.BadParameter(err.Error()) + } + routeToApp.TargetPort = targetPort + } // The cert is returned from this function and finally set on LocalProxy by the middleware. 
cert, err := c.ReissueAppCert(ctx, clusterClient, routeToApp) @@ -232,3 +251,11 @@ func (c *Cluster) ReissueGatewayCerts(ctx context.Context, clusterClient *client return tls.Certificate{}, trace.NotImplemented("ReissueGatewayCerts does not support this gateway kind %v", g.TargetURI().String()) } } + +func parseTargetPort(rawTargetPort string) (uint32, error) { + targetPort, err := strconv.ParseUint(rawTargetPort, 10, 32) + if err != nil { + return 0, trace.BadParameter(err.Error()) + } + return uint32(targetPort), nil +} diff --git a/lib/teleterm/daemon/daemon.go b/lib/teleterm/daemon/daemon.go index d3528793a4b99..b27ded1ba205c 100644 --- a/lib/teleterm/daemon/daemon.go +++ b/lib/teleterm/daemon/daemon.go @@ -511,7 +511,7 @@ func (s *Service) GetGatewayCLICommand(ctx context.Context, gateway gateway.Gate // SetGatewayTargetSubresourceName updates the TargetSubresourceName field of a gateway stored in // s.gateways. -func (s *Service) SetGatewayTargetSubresourceName(gatewayURI, targetSubresourceName string) (gateway.Gateway, error) { +func (s *Service) SetGatewayTargetSubresourceName(ctx context.Context, gatewayURI, targetSubresourceName string) (gateway.Gateway, error) { s.mu.Lock() defer s.mu.Unlock() @@ -520,6 +520,28 @@ func (s *Service) SetGatewayTargetSubresourceName(gatewayURI, targetSubresourceN return nil, trace.Wrap(err) } + targetURI := gateway.TargetURI() + switch { + case targetURI.IsApp(): + clusterClient, err := s.GetCachedClient(ctx, targetURI) + if err != nil { + return nil, trace.Wrap(err) + } + + var app types.Application + if err := clusters.AddMetadataToRetryableError(ctx, func() error { + var err error + app, err = clusters.GetApp(ctx, clusterClient.CurrentCluster(), targetURI.GetAppName()) + return trace.Wrap(err) + }); err != nil { + return nil, trace.Wrap(err) + } + + if _, err := clusters.ValidateTargetPort(app, targetSubresourceName); err != nil { + return nil, trace.Wrap(err) + } + } + 
gateway.SetTargetSubresourceName(targetSubresourceName) return gateway, nil diff --git a/lib/teleterm/gateway/app.go b/lib/teleterm/gateway/app.go index 110d36604aeff..57b2753269a5a 100644 --- a/lib/teleterm/gateway/app.go +++ b/lib/teleterm/gateway/app.go @@ -19,8 +19,6 @@ package gateway import ( "context" "crypto/tls" - "net/url" - "strings" "github.com/gravitational/trace" @@ -33,15 +31,6 @@ type app struct { *base } -// LocalProxyURL returns the URL of the local proxy. -func (a *app) LocalProxyURL() string { - proxyURL := url.URL{ - Scheme: strings.ToLower(a.Protocol()), - Host: a.LocalAddress() + ":" + a.LocalPort(), - } - return proxyURL.String() -} - func makeAppGateway(cfg Config) (Gateway, error) { base, err := newBase(cfg) if err != nil { diff --git a/lib/teleterm/gateway/app_middleware.go b/lib/teleterm/gateway/app_middleware.go index 9b58de8624016..8f47425142d80 100644 --- a/lib/teleterm/gateway/app_middleware.go +++ b/lib/teleterm/gateway/app_middleware.go @@ -43,12 +43,12 @@ func (m *appMiddleware) OnNewConnection(ctx context.Context, lp *alpn.LocalProxy return nil } - // Return early and don't fire onExpiredCert if certs are invalid but not due to expiry. - if !errors.As(err, &x509.CertificateInvalidError{}) { + // Return early and don't fire onExpiredCert if certs are invalid but not due to expiry or removal. 
+ if !errors.As(err, &x509.CertificateInvalidError{}) && !trace.IsNotFound(err) { return trace.Wrap(err) } - m.logger.DebugContext(ctx, "Gateway certificates have expired", "error", err) + m.logger.DebugContext(ctx, "Gateway certificates have expired or been removed", "error", err) cert, err := m.onExpiredCert(ctx) if err != nil { diff --git a/lib/teleterm/gateway/base.go b/lib/teleterm/gateway/base.go index 3a8b076307c60..41d407ca0d8d7 100644 --- a/lib/teleterm/gateway/base.go +++ b/lib/teleterm/gateway/base.go @@ -20,10 +20,12 @@ package gateway import ( "context" + "crypto/tls" "fmt" "log/slog" "net" "strconv" + "sync" "github.com/gravitational/trace" @@ -89,6 +91,9 @@ func newBase(cfg Config) (*base, error) { // Close terminates gateway connection. Fails if called on an already closed gateway. func (b *base) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + b.closeCancel() var errs []error @@ -158,17 +163,29 @@ func (b *base) TargetUser() string { } func (b *base) TargetSubresourceName() string { + b.mu.RLock() + defer b.mu.RUnlock() + return b.cfg.TargetSubresourceName } func (b *base) SetTargetSubresourceName(value string) { + b.mu.Lock() + defer b.mu.Unlock() b.cfg.TargetSubresourceName = value + + if b.cfg.ClearCertsOnTargetSubresourceNameChange { + b.Log().InfoContext(b.closeContext, "Clearing cert") + b.localProxy.SetCert(tls.Certificate{}) + } } func (b *base) Log() *slog.Logger { return b.cfg.Logger } +// LocalAddress returns the local host in the net package terms (localhost or 127.0.0.1, depending +// on the platform). func (b *base) LocalAddress() string { return b.cfg.LocalAddress } @@ -187,15 +204,13 @@ func (b *base) LocalPortInt() int { } func (b *base) cloneConfig() Config { + b.mu.RLock() + defer b.mu.RUnlock() + return *b.cfg } -// Gateway describes local proxy that creates a gateway to the remote Teleport resource. -// -// Gateway is not safe for concurrent use in itself. 
However, all access to gateways is gated by -// daemon.Service which obtains a lock for any operation pertaining to gateways. -// -// In the future if Gateway becomes more complex it might be worthwhile to add an RWMutex to it. +// Gateway is a local proxy to a remote Teleport resource. type base struct { cfg *Config localProxy *alpn.LocalProxy @@ -206,6 +221,7 @@ type base struct { // that the local proxy is now closed and to release any resources. closeContext context.Context closeCancel context.CancelFunc + mu sync.RWMutex } type TCPPortAllocator interface { diff --git a/lib/teleterm/gateway/config.go b/lib/teleterm/gateway/config.go index 67768d05900db..a4877d20d8394 100644 --- a/lib/teleterm/gateway/config.go +++ b/lib/teleterm/gateway/config.go @@ -91,6 +91,11 @@ type Config struct { RootClusterCACertPoolFunc alpnproxy.GetClusterCACertPoolFunc // KubeconfigsDir is the directory containing kubeconfigs for kube gateways. KubeconfigsDir string + // ClearCertsOnTargetSubresourceNameChange is useful in situations where TargetSubresourceName is + // used to generate a cert. In that case, after TargetSubresourceName is changed, the gateway will + // clear the cert from the local proxy and the middleware is going to request a new cert on the + // next connection. + ClearCertsOnTargetSubresourceNameChange bool } // OnExpiredCertFunc is the type of a function that is called when a new downstream connection is diff --git a/lib/teleterm/gateway/interfaces.go b/lib/teleterm/gateway/interfaces.go index 27bc6735a2b9d..9d102d788f041 100644 --- a/lib/teleterm/gateway/interfaces.go +++ b/lib/teleterm/gateway/interfaces.go @@ -43,6 +43,8 @@ type Gateway interface { TargetSubresourceName() string SetTargetSubresourceName(value string) Log() *slog.Logger + // LocalAddress returns the local host in the net package terms (localhost or 127.0.0.1, depending + // on the platform). 
LocalAddress() string LocalPort() string LocalPortInt() int @@ -95,7 +97,4 @@ type Kube interface { // App defines an app gateway. type App interface { Gateway - - // LocalProxyURL returns the URL of the local proxy. - LocalProxyURL() string } diff --git a/lib/teleterm/gateway/kube.go b/lib/teleterm/gateway/kube.go index 1dccb4189accc..d39f925bd75bf 100644 --- a/lib/teleterm/gateway/kube.go +++ b/lib/teleterm/gateway/kube.go @@ -187,6 +187,8 @@ func (k *kube) makeForwardProxyForKube() error { } func (k *kube) writeKubeconfig(key *keys.PrivateKey, cas map[string]tls.Certificate) error { + k.base.mu.RLock() + defer k.base.mu.RUnlock() ca, ok := cas[k.cfg.ClusterName] if !ok { return trace.BadParameter("CA for teleport cluster %q is missing", k.cfg.ClusterName) diff --git a/proto/teleport/lib/teleterm/v1/gateway.proto b/proto/teleport/lib/teleterm/v1/gateway.proto index 7661a6bf31f4a..4399fcc307e26 100644 --- a/proto/teleport/lib/teleterm/v1/gateway.proto +++ b/proto/teleport/lib/teleterm/v1/gateway.proto @@ -43,12 +43,13 @@ message Gateway { string local_address = 5; // local_port is the gateway address on localhost string local_port = 6; - // protocol is the gateway protocol + // protocol is the protocol used by the gateway. For databases, it matches the type of the + // database that the gateway targets. For apps, it's either "HTTP" or "TCP". string protocol = 7; reserved 8; reserved "cli_command"; // target_subresource_name points at a subresource of the remote resource, for example a - // database name on a database server. + // database name on a database server or a target port of a multi-port TCP app. string target_subresource_name = 9; // gateway_cli_client represents a command that the user can execute to connect to the resource // through the gateway. 
diff --git a/web/packages/design/src/Input/Input.tsx b/web/packages/design/src/Input/Input.tsx index 3cf50e9d1009b..fe3b7feca968c 100644 --- a/web/packages/design/src/Input/Input.tsx +++ b/web/packages/design/src/Input/Input.tsx @@ -70,6 +70,7 @@ interface InputProps extends ColorProps, SpaceProps, WidthProps, HeightProps { inputMode?: InputMode; spellCheck?: boolean; style?: React.CSSProperties; + required?: boolean; 'aria-invalid'?: HTMLAttributes<'input'>['aria-invalid']; 'aria-describedby'?: HTMLAttributes<'input'>['aria-describedby']; @@ -170,6 +171,7 @@ const Input = forwardRef((props, ref) => { inputMode, spellCheck, style, + required, 'aria-invalid': ariaInvalid, 'aria-describedby': ariaDescribedBy, @@ -222,6 +224,7 @@ const Input = forwardRef((props, ref) => { inputMode, spellCheck, style, + required, 'aria-invalid': ariaInvalid, 'aria-describedby': ariaDescribedBy, diff --git a/web/packages/design/src/Menu/Menu.story.tsx b/web/packages/design/src/Menu/Menu.story.tsx index c7b0726ea414b..c3ba4ae802762 100644 --- a/web/packages/design/src/Menu/Menu.story.tsx +++ b/web/packages/design/src/Menu/Menu.story.tsx @@ -107,6 +107,18 @@ export const MenuItems = () => ( Amet nisi tempor + +

Label as first child

+ + Tempus ut libero + Lorem ipsum + Dolor sit amet + + Leo vitae arcu + Donec volutpat + Mauris sit + +
); diff --git a/web/packages/design/src/Menu/MenuItem.tsx b/web/packages/design/src/Menu/MenuItem.tsx index 5ccbae227c835..a9b06373bd787 100644 --- a/web/packages/design/src/Menu/MenuItem.tsx +++ b/web/packages/design/src/Menu/MenuItem.tsx @@ -71,37 +71,39 @@ const MenuItemBase = styled(Flex)` ${fromThemeBase} `; -export const MenuItemSectionLabel = styled(MenuItemBase).attrs({ - px: 2, +export const MenuItemSectionSeparator = styled.hr.attrs({ onClick: event => { // Make sure that clicks on this element don't trigger onClick set on MenuList. event.stopPropagation(); }, })` - font-weight: bold; - min-height: 16px; + background: ${props => props.theme.colors.interactive.tonal.neutral[1]}; + height: 1px; + border: 0; + font-size: 0; `; -export const MenuItemSectionSeparator = styled.hr.attrs({ +export const MenuItemSectionLabel = styled(MenuItemBase).attrs({ + px: 2, onClick: event => { // Make sure that clicks on this element don't trigger onClick set on MenuList. event.stopPropagation(); }, })` - background: ${props => props.theme.colors.interactive.tonal.neutral[1]}; - height: 1px; - border: 0; - font-size: 0; + font-weight: bold; + min-height: 16px; - // Add padding to the label for extra visual space, but only when it follows a separator. - // If a separator follows a MenuItem, there's already enough visual space, so no extra space is - // needed. The hover state of MenuItem highlights everything right from the separator start to the - // end of MenuItem. + // Add padding to the label for extra visual space, but only when it follows a separator or is the + // first child. + // + // If a separator follows a MenuItem, there's already enough visual space between MenuItem and + // separator, so no extra space is needed. The hover state of MenuItem highlights everything right + // from the separator start to the end of MenuItem. 
// // Padding is used instead of margin here on purpose, so that there's no empty transparent space // between Separator and Label – otherwise clicking on that space would count as a click on // MenuList and not trigger onClick set on Separator or Label. - & + ${MenuItemSectionLabel} { + ${MenuItemSectionSeparator} + &, &:first-child { padding-top: ${props => props.theme.space[1]}px; } `; diff --git a/web/packages/design/src/keyframes.ts b/web/packages/design/src/keyframes.ts index c49799db9f67f..a3a7bf96f7245 100644 --- a/web/packages/design/src/keyframes.ts +++ b/web/packages/design/src/keyframes.ts @@ -46,3 +46,7 @@ export const blink = keyframes` opacity: 100%; } `; + +export const disappear = keyframes` +to { opacity: 0; } +`; diff --git a/web/packages/shared/components/FieldInput/FieldInput.tsx b/web/packages/shared/components/FieldInput/FieldInput.tsx index 2ac28f54e810c..2f3a3eb012550 100644 --- a/web/packages/shared/components/FieldInput/FieldInput.tsx +++ b/web/packages/shared/components/FieldInput/FieldInput.tsx @@ -59,6 +59,7 @@ const FieldInput = forwardRef( toolTipContent = null, disabled = false, markAsError = false, + required = false, ...styles }, ref @@ -94,6 +95,7 @@ const FieldInput = forwardRef( size={size} aria-invalid={hasError || markAsError} aria-describedby={helperTextId} + required={required} /> ); @@ -219,7 +221,7 @@ export type FieldInputProps = BoxProps & { id?: string; name?: string; value?: string; - label?: string; + label?: React.ReactNode; helperText?: React.ReactNode; icon?: React.ComponentType; size?: InputSize; @@ -245,4 +247,5 @@ export type FieldInputProps = BoxProps & { // input box as error color before validator // runs (which marks it as error) markAsError?: boolean; + required?: boolean; }; diff --git a/web/packages/teleterm/src/services/tshd/testHelpers.ts b/web/packages/teleterm/src/services/tshd/testHelpers.ts index b19fc95725192..8cb15ec3e3701 100644 --- a/web/packages/teleterm/src/services/tshd/testHelpers.ts +++ 
b/web/packages/teleterm/src/services/tshd/testHelpers.ts @@ -290,7 +290,7 @@ export const makeAppGateway = ( targetUri: appUri, localAddress: 'localhost', localPort: '1337', - targetSubresourceName: 'bar', + targetSubresourceName: undefined, gatewayCliCommand: { path: '', preview: 'curl http://localhost:1337', diff --git a/web/packages/teleterm/src/ui/DocumentCluster/ActionButtons.tsx b/web/packages/teleterm/src/ui/DocumentCluster/ActionButtons.tsx index c16dd1d5fa779..39147660e843e 100644 --- a/web/packages/teleterm/src/ui/DocumentCluster/ActionButtons.tsx +++ b/web/packages/teleterm/src/ui/DocumentCluster/ActionButtons.tsx @@ -23,7 +23,7 @@ import { MenuItemSectionLabel, MenuItemSectionSeparator, } from 'design/Menu/MenuItem'; -import { App } from 'gen-proto-ts/teleport/lib/teleterm/v1/app_pb'; +import { App, PortRange } from 'gen-proto-ts/teleport/lib/teleterm/v1/app_pb'; import { Cluster } from 'gen-proto-ts/teleport/lib/teleterm/v1/cluster_pb'; import { Database } from 'gen-proto-ts/teleport/lib/teleterm/v1/database_pb'; import { Kube } from 'gen-proto-ts/teleport/lib/teleterm/v1/kube_pb'; @@ -125,8 +125,11 @@ export function ConnectAppActionButton(props: { app: App }): React.JSX.Element { connectToAppWithVnet(appContext, launchVnet, props.app, targetPort); } - function setUpGateway(): void { - setUpAppGateway(appContext, props.app, { origin: 'resource_table' }); + function setUpGateway(targetPort?: number): void { + setUpAppGateway(appContext, props.app, { + telemetry: { origin: 'resource_table' }, + targetPort, + }); } const rootCluster = appContext.clustersService.findCluster( @@ -229,7 +232,7 @@ function AppButton(props: { cluster: Cluster; rootCluster: Cluster; connectWithVnet(targetPort?: number): void; - setUpGateway(): void; + setUpGateway(targetPort?: number): void; onLaunchUrl(): void; isVnetSupported: boolean; }) { @@ -285,37 +288,15 @@ function AppButton(props: { target="_blank" title="Launch the app in the browser" > - Set up connection + 
props.setUpGateway()}> + Set up connection + ); } // TCP app with VNet. if (props.isVnetSupported) { - let $targetPorts: JSX.Element; - if (props.app.tcpPorts.length) { - $targetPorts = ( - <> - - Available target ports - {props.app.tcpPorts.map((portRange, index) => ( - props.connectWithVnet(portRange.port)} - > - {formatPortRange(portRange)} - - ))} - - ); - } - return ( props.connectWithVnet()} > - Connect without VNet - {$targetPorts} + props.setUpGateway()}> + Connect without VNet + + {!!props.app.tcpPorts.length && ( + <> + + props.connectWithVnet(port)} + /> + + )} + + ); + } + + // Multi-port TCP app without VNet. + if (props.app.tcpPorts.length) { + return ( + props.setUpGateway()} + > + props.setUpGateway(port)} + /> ); } - // TCP app without VNet. + // Single-port TCP app without VNet. return ( props.setUpGateway()} textTransform="none" > Connect @@ -341,6 +349,29 @@ function AppButton(props: { ); } +const AvailableTargetPorts = (props: { + tcpPorts: PortRange[]; + onItemClick: (portRangePort: number) => void; +}) => ( + <> + Available target ports + {props.tcpPorts.map((portRange, index) => ( + props.onItemClick(portRange.port)} + > + {formatPortRange(portRange)} + + ))} + +); + export function AccessRequestButton(props: { isResourceAdded: boolean; requestStarted: boolean; diff --git a/web/packages/teleterm/src/ui/DocumentGateway/useGateway.ts b/web/packages/teleterm/src/ui/DocumentGateway/useGateway.ts index 1c08bae058742..743667ebfa662 100644 --- a/web/packages/teleterm/src/ui/DocumentGateway/useGateway.ts +++ b/web/packages/teleterm/src/ui/DocumentGateway/useGateway.ts @@ -30,6 +30,7 @@ import { retryWithRelogin } from 'teleterm/ui/utils'; export function useGateway(doc: DocumentGateway) { const ctx = useAppContext(); + const { clustersService } = ctx; const { documentsService } = useWorkspaceContext(); // The port to show as default in the input field in case creating a gateway fails. 
// This is typically the case if someone reopens the app and the port of the gateway is already @@ -51,7 +52,7 @@ export function useGateway(doc: DocumentGateway) { try { gw = await retryWithRelogin(ctx, doc.targetUri, () => - ctx.clustersService.createGateway({ + clustersService.createGateway({ targetUri: doc.targetUri, localPort: port, targetUser: doc.targetUser, @@ -92,34 +93,52 @@ export function useGateway(doc: DocumentGateway) { }); const [disconnectAttempt, disconnect] = useAsync(async () => { - await ctx.clustersService.removeGateway(doc.gatewayUri); + await clustersService.removeGateway(doc.gatewayUri); documentsService.close(doc.uri); }); const [changeTargetSubresourceNameAttempt, changeTargetSubresourceName] = - useAsync(async (name: string) => { - const updatedGateway = - await ctx.clustersService.setGatewayTargetSubresourceName( - doc.gatewayUri, - name - ); + useAsync( + useCallback( + (name: string) => + retryWithRelogin(ctx, doc.targetUri, async () => { + const updatedGateway = + await clustersService.setGatewayTargetSubresourceName( + doc.gatewayUri, + name + ); - documentsService.update(doc.uri, { - targetSubresourceName: updatedGateway.targetSubresourceName, - }); - }); - - const [changePortAttempt, changePort] = useAsync(async (port: string) => { - const updatedGateway = await ctx.clustersService.setGatewayLocalPort( - doc.gatewayUri, - port + documentsService.update(doc.uri, { + targetSubresourceName: updatedGateway.targetSubresourceName, + }); + }), + [ + clustersService, + documentsService, + doc.uri, + doc.gatewayUri, + ctx, + doc.targetUri, + ] + ) ); - documentsService.update(doc.uri, { - targetSubresourceName: updatedGateway.targetSubresourceName, - port: updatedGateway.localPort, - }); - }); + const [changePortAttempt, changePort] = useAsync( + useCallback( + async (port: string) => { + const updatedGateway = await clustersService.setGatewayLocalPort( + doc.gatewayUri, + port + ); + + documentsService.update(doc.uri, { + 
targetSubresourceName: updatedGateway.targetSubresourceName, + port: updatedGateway.localPort, + }); + }, + [clustersService, documentsService, doc.uri, doc.gatewayUri] + ) + ); useEffect( function createGatewayOnMount() { diff --git a/web/packages/teleterm/src/ui/DocumentGatewayApp/AppGateway.tsx b/web/packages/teleterm/src/ui/DocumentGatewayApp/AppGateway.tsx index 1c2981ce9f42b..bd31e84d80035 100644 --- a/web/packages/teleterm/src/ui/DocumentGatewayApp/AppGateway.tsx +++ b/web/packages/teleterm/src/ui/DocumentGatewayApp/AppGateway.tsx @@ -16,18 +16,27 @@ * along with this program. If not, see . */ -import { useMemo, useRef } from 'react'; +import { + ChangeEvent, + ChangeEventHandler, + PropsWithChildren, + useEffect, + useMemo, + useState, +} from 'react'; +import styled from 'styled-components'; import { Alert, - Box, ButtonSecondary, + disappear, Flex, H1, - Indicator, Link, + rotate360, Text, } from 'design'; +import { Check, Spinner } from 'design/Icon'; import { Gateway } from 'gen-proto-ts/teleport/lib/teleterm/v1/gateway_pb'; import { TextSelectCopy } from 'shared/components/TextSelectCopy'; import Validation from 'shared/components/Validation'; @@ -39,68 +48,110 @@ import { PortFieldInput } from '../components/FieldInputs'; export function AppGateway(props: { gateway: Gateway; disconnectAttempt: Attempt; - changePort(port: string): void; - changePortAttempt: Attempt; + changeLocalPort(port: string): void; + changeLocalPortAttempt: Attempt; + changeTargetPort(port: string): void; + changeTargetPortAttempt: Attempt; disconnect(): void; }) { const { gateway } = props; - const formRef = useRef(); - const { changePort } = props; - const handleChangePort = useMemo(() => { - return debounce((value: string) => { - if (formRef.current.reportValidity()) { - changePort(value); - } - }, 1000); - }, [changePort]); + const { + changeLocalPort, + changeLocalPortAttempt, + changeTargetPort, + changeTargetPortAttempt, + disconnectAttempt, + } = props; + // It must be 
possible to update local port while target port is invalid, hence why + // useDebouncedPortChangeHandler checks the validity of only one input at a time. Otherwise the UI + // would lose updates to the local port while the target port was invalid. + const handleLocalPortChange = useDebouncedPortChangeHandler(changeLocalPort); + const handleTargetPortChange = + useDebouncedPortChangeHandler(changeTargetPort); let address = `${gateway.localAddress}:${gateway.localPort}`; if (gateway.protocol === 'HTTP') { address = `http://${address}`; } + // AppGateway doesn't have access to the app resource itself, so it has to decide whether the + // app is multi-port or not in some other way. + // For multi-port apps, DocumentGateway comes with targetSubresourceName prefilled to the first + // port number found in TCP ports. Single-port apps have this field empty. + // So, if targetSubresourceName is present, then the app must be multi-port. In this case, the + // user is free to change it and can never provide an empty targetSubresourceName. + // When the app is not multi-port, targetSubresourceName is empty and the user cannot change it. + const isMultiPort = + gateway.protocol === 'TCP' && gateway.targetSubresourceName; + return ( - - + +

App Connection

Close Connection
- {props.disconnectAttempt.status === 'error' && ( - + {disconnectAttempt.status === 'error' && ( + Could not close the connection )} - + + } defaultValue={gateway.localPort} - onChange={e => handleChangePort(e.target.value)} - mb={2} + onChange={handleLocalPortChange} + mb={0} /> + {isMultiPort && ( + + } + required + defaultValue={gateway.targetSubresourceName} + onChange={handleTargetPortChange} + mb={0} + /> + )} - {props.changePortAttempt.status === 'processing' && ( - - )} - Access the app at: - +
+ Access the app at: + +
- {props.changePortAttempt.status === 'error' && ( - - Could not change the port number + {changeLocalPortAttempt.status === 'error' && ( + + Could not change the local port + + )} + + {changeTargetPortAttempt.status === 'error' && ( + + Could not change the target port )} @@ -115,6 +166,89 @@ export function AppGateway(props: { {' '} for more details. -
+ ); } + +const LabelWithAttemptStatus = (props: { + text: string; + attempt: Attempt; +}) => ( + + {props.text} + {props.attempt.status === 'processing' && ( + + )} + {props.attempt.status === 'success' && ( + // CSS animations are repeated whenever the parent goes from `display: none` to something + // else. As a result, we need to unmount the animated check so that the animation is not + // repeated when the user switches to this tab. + // https://www.w3.org/TR/css-animations-1/#example-4e34d7ba + + + + )} + +); + +/** + * useDebouncedPortChangeHandler returns a debounced change handler that calls the change function + * only if the input from which the event originated is valid. + */ +const useDebouncedPortChangeHandler = ( + changeFunc: (port: string) => void +): ChangeEventHandler => + useMemo( + () => + debounce((event: ChangeEvent) => { + if (event.target.reportValidity()) { + changeFunc(event.target.value); + } + }, 1000), + [changeFunc] + ); + +const AnimatedSpinner = styled(Spinner)` + animation: ${rotate360} 1.5s infinite linear; + // The spinner needs to be positioned absolutely so that the fact that it's spinning + // doesn't affect the size of the parent. + position: absolute; + right: 0; + top: 0; +`; + +const disappearanceDelayMs = 1000; +const disappearanceDurationMs = 200; + +const DisappearingCheck = styled(Check)` + opacity: 1; + animation: ${disappear}; + animation-delay: ${disappearanceDelayMs}ms; + animation-duration: ${disappearanceDurationMs}ms; + animation-fill-mode: forwards; +`; + +const UnmountAfter = ({ + timeoutMs, + children, +}: PropsWithChildren<{ timeoutMs: number }>) => { + const [isMounted, setIsMounted] = useState(true); + + useEffect(() => { + const timeout = setTimeout(() => { + setIsMounted(false); + }, timeoutMs); + + return () => { + clearTimeout(timeout); + }; + }, [timeoutMs]); + + return isMounted ? 
children : null; +}; diff --git a/web/packages/teleterm/src/ui/DocumentGatewayApp/DocumentGatewayApp.story.tsx b/web/packages/teleterm/src/ui/DocumentGatewayApp/DocumentGatewayApp.story.tsx index c0b4ec802b28e..936f1c8a399b1 100644 --- a/web/packages/teleterm/src/ui/DocumentGatewayApp/DocumentGatewayApp.story.tsx +++ b/web/packages/teleterm/src/ui/DocumentGatewayApp/DocumentGatewayApp.story.tsx @@ -30,9 +30,10 @@ import { MockWorkspaceContextProvider } from 'teleterm/ui/fixtures/MockWorkspace import * as types from 'teleterm/ui/services/workspacesService'; type StoryProps = { - appType: 'web' | 'tcp'; + appType: 'web' | 'tcp' | 'tcp-multi-port'; online: boolean; - changePort: 'succeed' | 'throw-error'; + changeLocalPort: 'succeed' | 'throw-error'; + changeTargetPort: 'succeed' | 'throw-error'; disconnect: 'succeed' | 'throw-error'; }; @@ -42,9 +43,14 @@ const meta: Meta = { argTypes: { appType: { control: { type: 'radio' }, - options: ['web', 'tcp'], + options: ['web', 'tcp', 'tcp-multi-port'], }, - changePort: { + changeLocalPort: { + if: { arg: 'online' }, + control: { type: 'radio' }, + options: ['succeed', 'throw-error'], + }, + changeTargetPort: { if: { arg: 'online' }, control: { type: 'radio' }, options: ['succeed', 'throw-error'], @@ -58,7 +64,8 @@ const meta: Meta = { args: { appType: 'web', online: true, - changePort: 'succeed', + changeLocalPort: 'succeed', + changeTargetPort: 'succeed', disconnect: 'succeed', }, }; @@ -70,6 +77,10 @@ export function Story(props: StoryProps) { if (props.appType === 'tcp') { gateway.protocol = 'TCP'; } + if (props.appType === 'tcp-multi-port') { + gateway.protocol = 'TCP'; + gateway.targetSubresourceName = '4242'; + } const documentGateway: types.DocumentGateway = { kind: 'doc.gateway', targetUri: '/clusters/bar/apps/quux', @@ -80,10 +91,14 @@ export function Story(props: StoryProps) { targetUser: '', status: '', targetName: 'quux', + targetSubresourceName: undefined, }; if (!props.online) { documentGateway.gatewayUri = 
undefined; } + if (props.appType === 'tcp-multi-port') { + documentGateway.targetSubresourceName = '4242'; + } const appContext = new MockAppContext(); appContext.workspacesService.setState(draftState => { @@ -105,8 +120,26 @@ export function Story(props: StoryProps) { wait(1000).then( () => new MockedUnaryCall( - { ...gateway, localPort }, - props.changePort === 'throw-error' + { + ...appContext.clustersService.findGateway(gateway.uri), + localPort, + }, + props.changeLocalPort === 'throw-error' + ? new Error('something went wrong') + : undefined + ) + ); + appContext.tshd.setGatewayTargetSubresourceName = ({ + targetSubresourceName, + }) => + wait(1000).then( + () => + new MockedUnaryCall( + { + ...appContext.clustersService.findGateway(gateway.uri), + targetSubresourceName, + }, + props.changeTargetPort === 'throw-error' ? new Error('something went wrong') : undefined ) diff --git a/web/packages/teleterm/src/ui/DocumentGatewayApp/DocumentGatewayApp.tsx b/web/packages/teleterm/src/ui/DocumentGatewayApp/DocumentGatewayApp.tsx index ba70a7dfbdbe3..24db9f673be64 100644 --- a/web/packages/teleterm/src/ui/DocumentGatewayApp/DocumentGatewayApp.tsx +++ b/web/packages/teleterm/src/ui/DocumentGatewayApp/DocumentGatewayApp.tsx @@ -29,13 +29,15 @@ export function DocumentGatewayApp(props: { const { doc } = props; const { gateway, - changePort, - changePortAttempt, + changePort: changeLocalPort, + changePortAttempt: changeLocalPortAttempt, connected, connectAttempt, disconnect, disconnectAttempt, reconnect, + changeTargetSubresourceName: changeTargetPort, + changeTargetSubresourceNameAttempt: changeTargetPortAttempt, } = useGateway(doc); return ( @@ -47,14 +49,17 @@ export function DocumentGatewayApp(props: { targetName={doc.targetName} gatewayPort={{ isSupported: true, defaultPort: doc.port }} reconnect={reconnect} + portFieldLabel="Local Port (optional)" /> ) : ( )} diff --git a/web/packages/teleterm/src/ui/TabHost/useTabShortcuts.test.tsx 
b/web/packages/teleterm/src/ui/TabHost/useTabShortcuts.test.tsx index ce65290c2eb1f..b8fe467178b54 100644 --- a/web/packages/teleterm/src/ui/TabHost/useTabShortcuts.test.tsx +++ b/web/packages/teleterm/src/ui/TabHost/useTabShortcuts.test.tsx @@ -55,6 +55,7 @@ function getMockDocuments(): Document[] { targetUri: '/clusters/bar/dbs/foobar', targetName: 'foobar', targetUser: 'foo', + targetSubresourceName: undefined, origin: 'resource_table', status: '', }, @@ -66,6 +67,7 @@ function getMockDocuments(): Document[] { targetUri: '/clusters/bar/dbs/foobar', targetName: 'foobar', targetUser: 'bar', + targetSubresourceName: undefined, origin: 'resource_table', status: '', }, diff --git a/web/packages/teleterm/src/ui/components/FieldInputs.tsx b/web/packages/teleterm/src/ui/components/FieldInputs.tsx index 21086d8f9bb23..7e7d57e4ec40f 100644 --- a/web/packages/teleterm/src/ui/components/FieldInputs.tsx +++ b/web/packages/teleterm/src/ui/components/FieldInputs.tsx @@ -16,23 +16,26 @@ * along with this program. If not, see . */ -import { forwardRef } from 'react'; +import styled from 'styled-components'; -import FieldInput, { FieldInputProps } from 'shared/components/FieldInput'; +import FieldInput from 'shared/components/FieldInput'; -export const ConfigFieldInput = forwardRef( - (props, ref) => -); +export const ConfigFieldInput = styled(FieldInput).attrs({ size: 'small' })` + input { + &:invalid, + &:invalid:hover { + border-color: ${props => + props.theme.colors.interactive.solid.danger.default}; + } + } +`; -export const PortFieldInput = forwardRef( - (props, ref) => ( - - ) -); +export const PortFieldInput = styled(ConfigFieldInput).attrs({ + type: 'number', + min: 1, + max: 65535, + // Without a min width, the stepper controls end up being too close to a long port number such + // as 65535. minWidth instead of width allows the field to grow with the label, so that e.g. + // a custom label of "Local Port (optional)" is displayed on a single line. 
+ minWidth: '110px', +})``; diff --git a/web/packages/teleterm/src/ui/components/OfflineGateway.tsx b/web/packages/teleterm/src/ui/components/OfflineGateway.tsx index 500a85951ba9a..2dbc1027565ee 100644 --- a/web/packages/teleterm/src/ui/components/OfflineGateway.tsx +++ b/web/packages/teleterm/src/ui/components/OfflineGateway.tsx @@ -36,7 +36,9 @@ export function OfflineGateway(props: { targetName: string; /** Gateway kind displayed in the UI, for example, 'database'. */ gatewayKind: string; + portFieldLabel?: string; }) { + const portFieldLabel = props.portFieldLabel || 'Port (optional)'; const defaultPort = props.gatewayPort.isSupported ? props.gatewayPort.defaultPort : undefined; @@ -88,7 +90,7 @@ export function OfflineGateway(props: { {props.gatewayPort.isSupported && ( { describe('setUpAppGateway', () => { test.each([ { - name: 'creates tunnel for a tcp app', + name: 'creates tunnel for a single-port TCP app', app: makeApp({ endpointUri: 'tcp://localhost:3000', }), }, + { + name: 'creates tunnel for a multi-port TCP app', + app: makeApp({ + endpointUri: 'tcp://localhost', + tcpPorts: [{ port: 1234, endPort: 0 }], + }), + expectedTargetSubresourceName: '1234', + }, + { + name: 'creates tunnel for a multi-port TCP app with a preselected target port', + app: makeApp({ + endpointUri: 'tcp://localhost', + tcpPorts: [{ port: 1234, endPort: 0 }], + }), + targetPort: 1234, + }, { name: 'creates tunnel for a web app', app: makeApp({ endpointUri: 'http://localhost:3000', }), }, - ])('$name', async ({ app }) => { + ])('$name', async ({ app, targetPort, expectedTargetSubresourceName }) => { const appContext = new MockAppContext(); setTestCluster(appContext); - await setUpAppGateway(appContext, app, { origin: 'resource_table' }); + await setUpAppGateway(appContext, app, { + telemetry: { origin: 'resource_table' }, + targetPort, + }); const documents = appContext.workspacesService .getActiveWorkspaceDocumentService() .getGatewayDocuments(); @@ -147,7 +166,8 @@ 
describe('setUpAppGateway', () => { port: undefined, status: '', targetName: 'foo', - targetSubresourceName: undefined, + targetSubresourceName: + expectedTargetSubresourceName || targetPort?.toString() || undefined, targetUri: '/clusters/teleport-local/apps/foo', targetUser: '', title: 'foo', diff --git a/web/packages/teleterm/src/ui/services/workspacesService/documentsService/connectToApp.ts b/web/packages/teleterm/src/ui/services/workspacesService/documentsService/connectToApp.ts index 93aee047a7341..2711bae403b0d 100644 --- a/web/packages/teleterm/src/ui/services/workspacesService/documentsService/connectToApp.ts +++ b/web/packages/teleterm/src/ui/services/workspacesService/documentsService/connectToApp.ts @@ -115,13 +115,21 @@ export async function connectToApp( return; } - await setUpAppGateway(ctx, target, telemetry); + await setUpAppGateway(ctx, target, { telemetry }); } export async function setUpAppGateway( ctx: IAppContext, target: App, - telemetry: { origin: DocumentOrigin } + options: { + telemetry: { origin: DocumentOrigin }; + /** + * targetPort allows the caller to preselect the target port for the gateway. Works only with + * multi-port TCP apps. If it's not specified and the app is multi-port, the first port from + * its TCP ports is used instead. + */ + targetPort?: number; + } ) { const rootClusterUri = routing.ensureRootClusterUri(target.uri); @@ -129,16 +137,20 @@ export async function setUpAppGateway( ctx.workspacesService.getWorkspaceDocumentService(rootClusterUri); const doc = documentsService.createGatewayDocument({ targetUri: target.uri, - origin: telemetry.origin, + origin: options.telemetry.origin, targetName: routing.parseAppUri(target.uri).params.appId, targetUser: '', + targetSubresourceName: + target.tcpPorts.length > 0 + ? 
(options.targetPort || target.tcpPorts[0].port).toString() + : undefined, }); const connectionToReuse = ctx.connectionTracker.findConnectionByDocument(doc); if (connectionToReuse) { await ctx.connectionTracker.activateItem(connectionToReuse.id, { - origin: telemetry.origin, + origin: options.telemetry.origin, }); } else { await ctx.workspacesService.setActiveWorkspace(rootClusterUri); diff --git a/web/packages/teleterm/src/ui/services/workspacesService/documentsService/documentsService.test.ts b/web/packages/teleterm/src/ui/services/workspacesService/documentsService/documentsService.test.ts index b50989a4273ff..96d1f3129ea24 100644 --- a/web/packages/teleterm/src/ui/services/workspacesService/documentsService/documentsService.test.ts +++ b/web/packages/teleterm/src/ui/services/workspacesService/documentsService/documentsService.test.ts @@ -79,6 +79,7 @@ describe('document should be added', () => { targetUri: '/clusters/bar/dbs/quux', targetName: 'quux', targetUser: 'foo', + targetSubresourceName: undefined, origin: 'resource_table', status: '', }; @@ -155,6 +156,7 @@ test('only gateway documents should be returned', () => { targetUri: '/clusters/bar/dbs/quux', targetName: 'quux', targetUser: 'foo', + targetSubresourceName: undefined, origin: 'resource_table', status: '', }; diff --git a/web/packages/teleterm/src/ui/services/workspacesService/documentsService/types.ts b/web/packages/teleterm/src/ui/services/workspacesService/documentsService/types.ts index 970fb09d22ba7..e975d8e268ae7 100644 --- a/web/packages/teleterm/src/ui/services/workspacesService/documentsService/types.ts +++ b/web/packages/teleterm/src/ui/services/workspacesService/documentsService/types.ts @@ -109,7 +109,11 @@ export interface DocumentGateway extends DocumentBase { targetUri: uri.DatabaseUri | uri.AppUri; targetUser: string; targetName: string; - targetSubresourceName?: string; + /** + * targetSubresourceName contains database name for db gateways and target port for TCP app + * gateways. 
+ */ + targetSubresourceName: string | undefined; port?: string; origin: DocumentOrigin; } From 852fc7d6053459d1a54089c4c1daa724163caba4 Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Mon, 13 Jan 2025 11:58:27 -0500 Subject: [PATCH 08/15] Remove unused servicecfg.Config.Console field (#50944) Depends on https://github.com/gravitational/teleport.e/pull/5829 --- e | 2 +- e2e/aws/fixtures_test.go | 1 - integration/appaccess/fixtures.go | 2 -- integration/appaccess/pack.go | 2 -- integration/helpers/instance.go | 4 +--- integration/hostuser_test.go | 4 ++-- integration/integration_test.go | 23 +++++++++-------------- integration/kube_integration_test.go | 2 -- integration/port_forwarding_test.go | 1 - integration/proxy/proxy_helpers.go | 1 - lib/client/api_login_test.go | 4 ---- lib/config/configuration.go | 6 ------ lib/service/servicecfg/config.go | 9 --------- tool/teleport/common/teleport_test.go | 1 - 14 files changed, 13 insertions(+), 49 deletions(-) diff --git a/e b/e index 498f643ea9033..65fa473e50c72 160000 --- a/e +++ b/e @@ -1 +1 @@ -Subproject commit 498f643ea9033b1235359d83c310caadb18305d2 +Subproject commit 65fa473e50c72d8f79261033a1298cc2955ca15c diff --git a/e2e/aws/fixtures_test.go b/e2e/aws/fixtures_test.go index 1b30f64f382a5..95373466c237e 100644 --- a/e2e/aws/fixtures_test.go +++ b/e2e/aws/fixtures_test.go @@ -198,7 +198,6 @@ func newTeleportConfig(t *testing.T) *servicecfg.Config { tconf := servicecfg.MakeDefaultConfig() // Replace the default auth and proxy listeners with the ones so we can // run multiple tests in parallel. 
- tconf.Console = nil tconf.Proxy.DisableWebInterface = true tconf.PollingPeriod = 500 * time.Millisecond tconf.Testing.ClientTimeout = time.Second diff --git a/integration/appaccess/fixtures.go b/integration/appaccess/fixtures.go index eee6390c471f5..e6876e7cbdaec 100644 --- a/integration/appaccess/fixtures.go +++ b/integration/appaccess/fixtures.go @@ -346,7 +346,6 @@ func SetupWithOptions(t *testing.T, opts AppTestOptions) *Pack { p.leafCluster = helpers.NewInstance(t, leafCfg) rcConf := servicecfg.MakeDefaultConfig() - rcConf.Console = nil rcConf.Logger = log rcConf.DataDir = t.TempDir() rcConf.Auth.Enabled = true @@ -364,7 +363,6 @@ func SetupWithOptions(t *testing.T, opts AppTestOptions) *Pack { rcConf.Clock = opts.Clock lcConf := servicecfg.MakeDefaultConfig() - lcConf.Console = nil lcConf.Logger = log lcConf.DataDir = t.TempDir() lcConf.Auth.Enabled = true diff --git a/integration/appaccess/pack.go b/integration/appaccess/pack.go index 5a5de08691da4..24eb1e9a5dde2 100644 --- a/integration/appaccess/pack.go +++ b/integration/appaccess/pack.go @@ -759,7 +759,6 @@ func (p *Pack) startRootAppServers(t *testing.T, count int, opts AppTestOptions) for i := 0; i < count; i++ { raConf := servicecfg.MakeDefaultConfig() raConf.Clock = opts.Clock - raConf.Console = nil raConf.Logger = utils.NewSlogLoggerForTests() raConf.DataDir = t.TempDir() raConf.SetToken("static-token-value") @@ -929,7 +928,6 @@ func (p *Pack) startLeafAppServers(t *testing.T, count int, opts AppTestOptions) for i := 0; i < count; i++ { laConf := servicecfg.MakeDefaultConfig() laConf.Clock = opts.Clock - laConf.Console = nil laConf.Logger = utils.NewSlogLoggerForTests() laConf.DataDir = t.TempDir() laConf.SetToken("static-token-value") diff --git a/integration/helpers/instance.go b/integration/helpers/instance.go index 6d375387a02f6..7e7deb03567a8 100644 --- a/integration/helpers/instance.go +++ b/integration/helpers/instance.go @@ -447,10 +447,9 @@ func (i *TeleInstance) GetSiteAPI(siteName 
string) authclient.ClientI { // Create creates a new instance of Teleport which trusts a list of other clusters (other // instances) -func (i *TeleInstance) Create(t *testing.T, trustedSecrets []*InstanceSecrets, enableSSH bool, console io.Writer) error { +func (i *TeleInstance) Create(t *testing.T, trustedSecrets []*InstanceSecrets, enableSSH bool) error { tconf := servicecfg.MakeDefaultConfig() tconf.SSH.Enabled = enableSSH - tconf.Console = console tconf.Logger = i.Log tconf.Proxy.DisableWebService = true tconf.Proxy.DisableWebInterface = true @@ -1129,7 +1128,6 @@ func (i *TeleInstance) StartProxy(cfg ProxyConfig, opts ...Option) (reversetunne i.tempDirs = append(i.tempDirs, dataDir) tconf := servicecfg.MakeDefaultConfig() - tconf.Console = nil tconf.Logger = i.Log authServer := utils.MustParseAddr(i.Auth) tconf.SetAuthServerAddress(*authServer) diff --git a/integration/hostuser_test.go b/integration/hostuser_test.go index 02145bc38274e..540ae35a59c48 100644 --- a/integration/hostuser_test.go +++ b/integration/hostuser_test.go @@ -661,7 +661,7 @@ func TestRootLoginAsHostUser(t *testing.T) { Roles: []types.Role{role}, } - require.NoError(t, instance.Create(t, nil, true, nil)) + require.NoError(t, instance.Create(t, nil, true)) require.NoError(t, instance.Start()) t.Cleanup(func() { require.NoError(t, instance.StopAll()) @@ -740,7 +740,7 @@ func TestRootStaticHostUsers(t *testing.T) { Logger: utils.NewSlogLoggerForTests(), }) - require.NoError(t, instance.Create(t, nil, false, nil)) + require.NoError(t, instance.Create(t, nil, false)) require.NoError(t, instance.Start()) t.Cleanup(func() { require.NoError(t, instance.StopAll()) diff --git a/integration/integration_test.go b/integration/integration_test.go index 0b48c90b46f39..4e2c4bed4974e 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -711,7 +711,7 @@ func (s *integrationTestSuite) newUnstartedTeleport(t *testing.T, logins []strin for _, login := range logins { 
teleport.AddUser(login, []string{login}) } - require.NoError(t, teleport.Create(t, nil, enableSSH, nil)) + require.NoError(t, teleport.Create(t, nil, enableSSH)) return teleport } @@ -2564,9 +2564,9 @@ func testTwoClustersProxy(t *testing.T, suite *integrationTestSuite) { a.AddUser(username, []string{username}) b.AddUser(username, []string{username}) - require.NoError(t, b.Create(t, a.Secrets.AsSlice(), false, nil)) + require.NoError(t, b.Create(t, a.Secrets.AsSlice(), false)) defer b.StopAll() - require.NoError(t, a.Create(t, b.Secrets.AsSlice(), true, nil)) + require.NoError(t, a.Create(t, b.Secrets.AsSlice(), true)) defer a.StopAll() require.NoError(t, b.Start()) @@ -2602,8 +2602,8 @@ func testHA(t *testing.T, suite *integrationTestSuite) { a.AddUser(username, []string{username}) b.AddUser(username, []string{username}) - require.NoError(t, b.Create(t, a.Secrets.AsSlice(), true, nil)) - require.NoError(t, a.Create(t, b.Secrets.AsSlice(), true, nil)) + require.NoError(t, b.Create(t, a.Secrets.AsSlice(), true)) + require.NoError(t, a.Create(t, b.Secrets.AsSlice(), true)) require.NoError(t, b.Start()) require.NoError(t, a.Start()) @@ -3950,13 +3950,13 @@ func testDiscoveryRecovers(t *testing.T, suite *integrationTestSuite) { remote.AddUser(username, []string{username}) main.AddUser(username, []string{username}) - require.NoError(t, main.Create(t, remote.Secrets.AsSlice(), false, nil)) + require.NoError(t, main.Create(t, remote.Secrets.AsSlice(), false)) mainSecrets := main.Secrets // switch listen address of the main cluster to load balancer mainProxyAddr := *utils.MustParseAddr(mainSecrets.TunnelAddr) lb.AddBackend(mainProxyAddr) mainSecrets.TunnelAddr = lb.Addr().String() - require.NoError(t, remote.Create(t, mainSecrets.AsSlice(), true, nil)) + require.NoError(t, remote.Create(t, mainSecrets.AsSlice(), true)) require.NoError(t, main.Start()) require.NoError(t, remote.Start()) @@ -4085,13 +4085,13 @@ func testDiscovery(t *testing.T, suite *integrationTestSuite) { 
remote.AddUser(username, []string{username}) main.AddUser(username, []string{username}) - require.NoError(t, main.Create(t, remote.Secrets.AsSlice(), false, nil)) + require.NoError(t, main.Create(t, remote.Secrets.AsSlice(), false)) mainSecrets := main.Secrets // switch listen address of the main cluster to load balancer mainProxyAddr := *utils.MustParseAddr(mainSecrets.TunnelAddr) lb.AddBackend(mainProxyAddr) mainSecrets.TunnelAddr = lb.Addr().String() - require.NoError(t, remote.Create(t, mainSecrets.AsSlice(), true, nil)) + require.NoError(t, remote.Create(t, mainSecrets.AsSlice(), true)) require.NoError(t, main.Start()) require.NoError(t, remote.Start()) @@ -7223,7 +7223,6 @@ func WithListeners(setupFn helpers.InstanceListenerSetupFunc) InstanceConfigOpti func (s *integrationTestSuite) defaultServiceConfig() *servicecfg.Config { cfg := servicecfg.MakeDefaultConfig() - cfg.Console = nil cfg.Logger = s.Log cfg.CircuitBreakerConfig = breaker.NoopBreakerConfig() cfg.InstanceMetadataClient = imds.NewDisabledIMDSClient() @@ -8572,7 +8571,6 @@ func TestConnectivityWithoutAuth(t *testing.T) { // Create auth config. authCfg := servicecfg.MakeDefaultConfig() - authCfg.Console = nil authCfg.Logger = utils.NewSlogLoggerForTests() authCfg.CircuitBreakerConfig = breaker.NoopBreakerConfig() authCfg.InstanceMetadataClient = imds.NewDisabledIMDSClient() @@ -8635,7 +8633,6 @@ func TestConnectivityWithoutAuth(t *testing.T) { nodeCfg.SetToken("token") nodeCfg.CachePolicy.Enabled = true nodeCfg.DataDir = t.TempDir() - nodeCfg.Console = nil nodeCfg.Logger = utils.NewSlogLoggerForTests() nodeCfg.CircuitBreakerConfig = breaker.NoopBreakerConfig() nodeCfg.InstanceMetadataClient = imds.NewDisabledIMDSClient() @@ -8716,7 +8713,6 @@ func TestConnectivityDuringAuthRestart(t *testing.T) { // Create auth config. 
authCfg := servicecfg.MakeDefaultConfig() - authCfg.Console = nil authCfg.Logger = utils.NewSlogLoggerForTests() authCfg.CircuitBreakerConfig = breaker.NoopBreakerConfig() authCfg.InstanceMetadataClient = imds.NewDisabledIMDSClient() @@ -8776,7 +8772,6 @@ func TestConnectivityDuringAuthRestart(t *testing.T) { nodeCfg.SetToken("token") nodeCfg.CachePolicy.Enabled = true nodeCfg.DataDir = t.TempDir() - nodeCfg.Console = nil nodeCfg.Logger = utils.NewSlogLoggerForTests() nodeCfg.CircuitBreakerConfig = breaker.NoopBreakerConfig() nodeCfg.InstanceMetadataClient = imds.NewDisabledIMDSClient() diff --git a/integration/kube_integration_test.go b/integration/kube_integration_test.go index 264bbfdf50706..51568d0e6bc7e 100644 --- a/integration/kube_integration_test.go +++ b/integration/kube_integration_test.go @@ -1833,7 +1833,6 @@ type sessionMetadataResponse struct { // teleKubeConfig sets up teleport with kubernetes turned on func (s *KubeSuite) teleKubeConfig(hostname string) *servicecfg.Config { tconf := servicecfg.MakeDefaultConfig() - tconf.Console = nil tconf.Logger = s.log tconf.SSH.Enabled = true tconf.Proxy.DisableWebInterface = true @@ -1854,7 +1853,6 @@ func (s *KubeSuite) teleKubeConfig(hostname string) *servicecfg.Config { // teleKubeConfig sets up teleport with kubernetes turned on func (s *KubeSuite) teleAuthConfig(hostname string) *servicecfg.Config { tconf := servicecfg.MakeDefaultConfig() - tconf.Console = nil tconf.Logger = s.log tconf.PollingPeriod = 500 * time.Millisecond tconf.Testing.ClientTimeout = time.Second diff --git a/integration/port_forwarding_test.go b/integration/port_forwarding_test.go index 88af150695872..cdef9e9b6f35a 100644 --- a/integration/port_forwarding_test.go +++ b/integration/port_forwarding_test.go @@ -205,7 +205,6 @@ func testPortForwarding(t *testing.T, suite *integrationTestSuite) { nodeCfg.SetToken("token") nodeCfg.CachePolicy.Enabled = true nodeCfg.DataDir = t.TempDir() - nodeCfg.Console = nil nodeCfg.Auth.Enabled = false 
nodeCfg.Proxy.Enabled = false nodeCfg.SSH.Enabled = true diff --git a/integration/proxy/proxy_helpers.go b/integration/proxy/proxy_helpers.go index b5796110eb53d..e4e0823c2cefd 100644 --- a/integration/proxy/proxy_helpers.go +++ b/integration/proxy/proxy_helpers.go @@ -196,7 +196,6 @@ func newSuite(t *testing.T, opts ...proxySuiteOptionsFunc) *Suite { func (p *Suite) addNodeToLeafCluster(t *testing.T, tunnelNodeHostname string) { nodeConfig := func() *servicecfg.Config { tconf := servicecfg.MakeDefaultConfig() - tconf.Console = nil tconf.Logger = utils.NewSlogLoggerForTests() tconf.Hostname = tunnelNodeHostname tconf.SetToken("token") diff --git a/lib/client/api_login_test.go b/lib/client/api_login_test.go index e06e73c6ce648..15d09fc03b671 100644 --- a/lib/client/api_login_test.go +++ b/lib/client/api_login_test.go @@ -516,8 +516,6 @@ type standaloneBundle struct { func newStandaloneTeleport(t *testing.T, clock clockwork.Clock) *standaloneBundle { randomAddr := utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"} - console := io.Discard - staticToken := uuid.New().String() // Prepare role and user. 
@@ -549,7 +547,6 @@ func newStandaloneTeleport(t *testing.T, clock clockwork.Clock) *standaloneBundl cfg.DataDir = makeDataDir() cfg.Hostname = "localhost" cfg.Clock = clock - cfg.Console = console cfg.Logger = utils.NewSlogLoggerForTests() cfg.SetAuthServerAddress(randomAddr) // must be present cfg.Auth.Preference, err = types.NewAuthPreferenceFromConfigFile(types.AuthPreferenceSpecV2{ @@ -633,7 +630,6 @@ func newStandaloneTeleport(t *testing.T, clock clockwork.Clock) *standaloneBundl cfg.Hostname = "localhost" cfg.SetToken(staticToken) cfg.Clock = clock - cfg.Console = console cfg.Logger = utils.NewSlogLoggerForTests() cfg.SetAuthServerAddress(*authAddr) cfg.Auth.Enabled = false diff --git a/lib/config/configuration.go b/lib/config/configuration.go index dda2ac6859cf4..45d3544012cfe 100644 --- a/lib/config/configuration.go +++ b/lib/config/configuration.go @@ -781,11 +781,6 @@ func applyAuthOrProxyAddress(fc *FileConfig, cfg *servicecfg.Config) error { } func applyLogConfig(loggerConfig Log, cfg *servicecfg.Config) error { - switch loggerConfig.Output { - case "stderr", "error", "2", "stdout", "out", "1": - cfg.Console = io.Discard // disable console printing - } - logger, level, err := logutils.Initialize(logutils.Config{ Output: loggerConfig.Output, Severity: loggerConfig.Severity, @@ -2514,7 +2509,6 @@ func Configure(clf *CommandLineFlags, cfg *servicecfg.Config, legacyAppFlags boo // apply --debug flag to config: if clf.Debug { - cfg.Console = io.Discard cfg.Debug = clf.Debug } diff --git a/lib/service/servicecfg/config.go b/lib/service/servicecfg/config.go index a89e29a2c7b54..a7841b4d10db4 100644 --- a/lib/service/servicecfg/config.go +++ b/lib/service/servicecfg/config.go @@ -21,7 +21,6 @@ package servicecfg import ( "context" - "io" "log/slog" "net" "net/http" @@ -133,9 +132,6 @@ type Config struct { // a teleport cluster). 
It's automatically generated on 1st start HostUUID string - // Console writer to speak to a user - Console io.Writer - // ReverseTunnels is a list of reverse tunnels to create on the // first cluster start ReverseTunnels []types.ReverseTunnel @@ -551,7 +547,6 @@ func ApplyDefaults(cfg *Config) { // Global defaults. cfg.Hostname = hostname cfg.DataDir = defaults.DataDir - cfg.Console = os.Stdout cfg.CipherSuites = utils.DefaultCipherSuites() cfg.Ciphers = sc.Ciphers cfg.KEXAlgorithms = kex @@ -695,10 +690,6 @@ func applyDefaults(cfg *Config) { cfg.Version = defaults.TeleportConfigVersionV1 } - if cfg.Console == nil { - cfg.Console = io.Discard - } - if cfg.Logger == nil { cfg.Logger = slog.Default() } diff --git a/tool/teleport/common/teleport_test.go b/tool/teleport/common/teleport_test.go index 7b1292f1e625c..fbf449fe37bc1 100644 --- a/tool/teleport/common/teleport_test.go +++ b/tool/teleport/common/teleport_test.go @@ -84,7 +84,6 @@ func TestTeleportMain(t *testing.T) { require.True(t, conf.Auth.Enabled) require.True(t, conf.SSH.Enabled) require.True(t, conf.Proxy.Enabled) - require.Equal(t, os.Stdout, conf.Console) require.True(t, slog.Default().Handler().Enabled(context.Background(), slog.LevelError)) }) From 9948ed4e2557820076b0d305f465f8f06d788941 Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Mon, 13 Jan 2025 12:03:07 -0500 Subject: [PATCH 09/15] Use locally scoped slog loggers instead of default (#50950) Cleans up some todos left during the logrus to slog conversion process. 
--- lib/reversetunnel/srv.go | 17 +++++++---------- lib/srv/discovery/database_watcher.go | 4 +--- lib/srv/discovery/discovery.go | 5 ++--- lib/srv/discovery/kube_services_watcher.go | 4 +--- 4 files changed, 11 insertions(+), 19 deletions(-) diff --git a/lib/reversetunnel/srv.go b/lib/reversetunnel/srv.go index eb7483eec6477..e83efccf31166 100644 --- a/lib/reversetunnel/srv.go +++ b/lib/reversetunnel/srv.go @@ -303,8 +303,7 @@ func NewServer(cfg Config) (reversetunnelclient.Server, error) { ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: cfg.Component, Client: cfg.LocalAccessPoint, - // TODO(tross): update this after converting to slog here - // Logger: cfg.Log, + Logger: cfg.Logger, }, ProxiesC: make(chan []types.Server, 10), ProxyGetter: cfg.LocalAccessPoint, @@ -1211,10 +1210,9 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, remoteSite.remoteAccessPoint = accessPoint nodeWatcher, err := services.NewNodeWatcher(closeContext, services.NodeWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: srv.Component, - Client: accessPoint, - // TODO(tross) update this after converting to use slog - // Logger: srv.Log, + Component: srv.Component, + Client: accessPoint, + Logger: srv.Logger, MaxStaleness: time.Minute, }, NodesGetter: accessPoint, @@ -1247,10 +1245,9 @@ func newRemoteSite(srv *server, domainName string, sconn ssh.Conn) (*remoteSite, remoteWatcher, err := services.NewCertAuthorityWatcher(srv.ctx, services.CertAuthorityWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ Component: teleport.ComponentProxy, - // TODO(tross): update this after converting to slog - // Logger: srv.log, - Clock: srv.Clock, - Client: remoteSite.remoteAccessPoint, + Logger: srv.logger, + Clock: srv.Clock, + Client: remoteSite.remoteAccessPoint, }, Types: []types.CertAuthType{types.HostCA}, }) diff --git a/lib/srv/discovery/database_watcher.go b/lib/srv/discovery/database_watcher.go index 
b14332a8f9bb4..297b9a7dfd8cd 100644 --- a/lib/srv/discovery/database_watcher.go +++ b/lib/srv/discovery/database_watcher.go @@ -20,7 +20,6 @@ package discovery import ( "context" - "log/slog" "sync" "github.com/gravitational/trace" @@ -54,8 +53,7 @@ func (s *Server) startDatabaseWatchers() error { defer mu.Unlock() return utils.FromSlice(newDatabases, types.Database.GetName) }, - // TODO(tross): update to use the server logger once it is converted to use slog - Logger: slog.With("kind", types.KindDatabase), + Logger: s.Log.With("kind", types.KindDatabase), OnCreate: s.onDatabaseCreate, OnUpdate: s.onDatabaseUpdate, OnDelete: s.onDatabaseDelete, diff --git a/lib/srv/discovery/discovery.go b/lib/srv/discovery/discovery.go index 047553edeabde..f27f60a112e2b 100644 --- a/lib/srv/discovery/discovery.go +++ b/lib/srv/discovery/discovery.go @@ -1908,9 +1908,8 @@ func (s *Server) getAzureSubscriptions(ctx context.Context, subs []string) ([]st func (s *Server) initTeleportNodeWatcher() (err error) { s.nodeWatcher, err = services.NewNodeWatcher(s.ctx, services.NodeWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: teleport.ComponentDiscovery, - // TODO(tross): update this after converting logging to use slog - // Logger: s.Logger, + Component: teleport.ComponentDiscovery, + Logger: s.Log, Client: s.AccessPoint, MaxStaleness: time.Minute, }, diff --git a/lib/srv/discovery/kube_services_watcher.go b/lib/srv/discovery/kube_services_watcher.go index 8e63c6947242a..9940734248588 100644 --- a/lib/srv/discovery/kube_services_watcher.go +++ b/lib/srv/discovery/kube_services_watcher.go @@ -20,7 +20,6 @@ package discovery import ( "context" - "log/slog" "sync" "time" @@ -62,8 +61,7 @@ func (s *Server) startKubeAppsWatchers() error { defer mu.Unlock() return utils.FromSlice(appResources, types.Application.GetName) }, - // TODO(tross): update to use the server logger once it is converted to use slog - Logger: slog.With("kind", types.KindApp), + Logger: 
s.Log.With("kind", types.KindApp), OnCreate: s.onAppCreate, OnUpdate: s.onAppUpdate, OnDelete: s.onAppDelete, From 646329da7db47dd83472ebc7bc0e36e8f67479e5 Mon Sep 17 00:00:00 2001 From: Bernard Kim Date: Mon, 13 Jan 2025 09:21:09 -0800 Subject: [PATCH 10/15] operator: Support trusted_cluster resources (#49920) * Add UpsertTrustedClusterV2 rpc This supersedes UpsertTrustedCluster rpc. V2 performs resource name validation. * Replace confusing UpsertValidationTrustedCluster name * Use UpsertTrustedClusterV2 in tests * Address feedback - Remove unnecessary ping - Update error messages - Use skipNameValidation consts - Validate cluster name before establishing trust - Do not reveal cluster name in error message - Use BadParameter instead of CompareFailed * Use webclient.Find * Fix test/lint * Allow label updates * Fix test * Update CRDs 1. Run `make manifests`. 2. Run `make -C crdgen update-protos`. 3. Run `make -C crdgen update-snapshot`. * Implement trusted_cluster CRD * Update docs * Support secret lookup * Update secret lookup docs * Fix error handling * Use V2 * Implement CreateTrustedClusterV2 and UpdateTrustedClusterV2 * Address feedback * Minor fixes * Use Create/Update instead of Upsert * Update crdgen * Update trusted_cluster tests * Move V2 RPCs to the trust service * crdgen * Remove V2 suffix * 2024 -> 2025 * Use slog --- .../teleport-operator/secret-lookup.mdx | 5 +- .../teleport-operator/teleport-operator.mdx | 19 +- ...sources.teleport.dev_trustedclustersv2.mdx | 41 ++ ...ources.teleport.dev_trustedclustersv2.yaml | 149 +++++++ .../teleport-operator/templates/role.yaml | 2 + .../templates/auth/config.yaml | 8 + integrations/operator/README.md | 4 + .../resources/v1/trusted_cluster_types.go | 96 +++++ .../resources/v1/zz_generated.deepcopy.go | 69 ++++ ...ources.teleport.dev_trustedclustersv2.yaml | 149 +++++++ .../legacy_resource_without_labels.go | 2 +- .../operator/controllers/resources/setup.go | 1 + .../controllers/resources/testlib/env.go | 1 + 
.../resources/trusted_cluster_controller.go | 91 +++++ .../trusted_clusterv2_controller_test.go | 356 +++++++++++++++++ .../operator/crdgen/additional_doc.go | 3 + integrations/operator/crdgen/handlerequest.go | 1 + integrations/operator/crdgen/ignored.go | 3 + ...ces.teleport.dev_openssheiceserversv2.yaml | 14 + ...sources.teleport.dev_opensshserversv2.yaml | 14 + .../golden/resources.teleport.dev_roles.yaml | 92 +++++ .../resources.teleport.dev_rolesv6.yaml | 46 +++ .../resources.teleport.dev_rolesv7.yaml | 46 +++ ...ources.teleport.dev_trustedclustersv2.yaml | 149 +++++++ .../golden/resources.teleport.dev_users.yaml | 12 + .../legacy/client/proto/authservice.proto | 124 +++++- .../teleport/legacy/client/proto/event.proto | 5 + .../teleport/legacy/types/events/events.proto | 378 ++++++++++++++++++ .../types/trusted_device_requirement.proto | 37 ++ .../teleport/legacy/types/types.proto | 181 ++++++++- .../operator/hack/fixture-operator-role.yaml | 8 + lib/auth/trustedcluster.go | 1 - 32 files changed, 2074 insertions(+), 33 deletions(-) create mode 100644 docs/pages/reference/operator-resources/resources.teleport.dev_trustedclustersv2.mdx create mode 100644 examples/chart/teleport-cluster/charts/teleport-operator/operator-crds/resources.teleport.dev_trustedclustersv2.yaml create mode 100644 integrations/operator/apis/resources/v1/trusted_cluster_types.go create mode 100644 integrations/operator/config/crd/bases/resources.teleport.dev_trustedclustersv2.yaml create mode 100644 integrations/operator/controllers/resources/trusted_cluster_controller.go create mode 100644 integrations/operator/controllers/resources/trusted_clusterv2_controller_test.go create mode 100644 integrations/operator/crdgen/testdata/golden/resources.teleport.dev_trustedclustersv2.yaml create mode 100644 integrations/operator/crdgen/testdata/protofiles/teleport/legacy/types/trusted_device_requirement.proto diff --git 
a/docs/pages/admin-guides/infrastructure-as-code/teleport-operator/secret-lookup.mdx b/docs/pages/admin-guides/infrastructure-as-code/teleport-operator/secret-lookup.mdx index a23e4935c5051..ca8f2b20e5b60 100644 --- a/docs/pages/admin-guides/infrastructure-as-code/teleport-operator/secret-lookup.mdx +++ b/docs/pages/admin-guides/infrastructure-as-code/teleport-operator/secret-lookup.mdx @@ -11,7 +11,10 @@ of the Teleport Kubernetes operator CRs. Some Teleport resources might contain sensitive values. Select CR fields can reference an existing Kubernetes secret and the operator will retrieve the value from the secret when reconciling. -Currently only the GithubConnector and OIDCConnector `client_secret` field support secret lookup. +Currently supported fields for secret lookup: +- GithubConnector `client_secret` +- OIDCConnector `client_secret` +- TrustedClusterV2 `token` ## Prerequisites diff --git a/docs/pages/admin-guides/infrastructure-as-code/teleport-operator/teleport-operator.mdx b/docs/pages/admin-guides/infrastructure-as-code/teleport-operator/teleport-operator.mdx index e8dec4b877a13..890421acf2742 100644 --- a/docs/pages/admin-guides/infrastructure-as-code/teleport-operator/teleport-operator.mdx +++ b/docs/pages/admin-guides/infrastructure-as-code/teleport-operator/teleport-operator.mdx @@ -26,16 +26,21 @@ could cause instability and non-deterministic behaviour. 
Currently supported Teleport resources are: -- users (`TeleportUser`) -- roles +- Users (`TeleportUser`) +- Roles - `TeleportRole` creates role v5 - `TeleportRoleV6` creates role v6 - `TeleportRoleV7` creates role v7 -- OIDC connectors (`TeleportOIDCConnector`) -- SAML connectors (`TeleportSAMLConnector`) -- GitHub connectors (`TeleportGithubConnector`) -- provision tokens (`TeleportProvisionToken`) -- Login Rules (`TeleportLoginRules`) +- OIDC Connectors (`TeleportOIDCConnector`) +- SAML Connectors (`TeleportSAMLConnector`) +- GitHub Connectors (`TeleportGithubConnector`) +- Provision Tokens (`TeleportProvisionToken`) +- Login Rules (`TeleportLoginRule`) +- Access Lists (`TeleportAccessList`) +- Okta Import Rules (`TeleportOktaImportRule`) +- OpenSSHEICE Servers (`TeleportOpenSSHEICEServerV2`) +- OpenSSH Servers (`TeleportOpenSSHServerV2`) +- Trusted Clusters (`TeleportTrustedClusterV2`) ### Setting up the operator diff --git a/docs/pages/reference/operator-resources/resources.teleport.dev_trustedclustersv2.mdx b/docs/pages/reference/operator-resources/resources.teleport.dev_trustedclustersv2.mdx new file mode 100644 index 0000000000000..8728b51b2ab5c --- /dev/null +++ b/docs/pages/reference/operator-resources/resources.teleport.dev_trustedclustersv2.mdx @@ -0,0 +1,41 @@ +--- +title: TeleportTrustedClusterV2 +description: Provides a comprehensive list of fields in the TeleportTrustedClusterV2 resource available through the Teleport Kubernetes operator +tocDepth: 3 +--- + +{/*Auto-generated file. Do not edit.*/} +{/*To regenerate, navigate to integrations/operator and run "make crd-docs".*/} + +This guide is a comprehensive reference to the fields in the `TeleportTrustedClusterV2` +resource, which you can apply after installing the Teleport Kubernetes operator. 
+ + +## resources.teleport.dev/v1 + +**apiVersion:** resources.teleport.dev/v1 + +|Field|Type|Description| +|---|---|---| +|apiVersion|string|APIVersion defines the versioned schema of this representation of an object. Servers should convert recognized schemas to the latest internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources| +|kind|string|Kind is a string value representing the REST resource this object represents. Servers may infer this from the endpoint the client submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds| +|metadata|object|| +|spec|[object](#spec)|TrustedCluster resource definition v2 from Teleport| + +### spec + +|Field|Type|Description| +|---|---|---| +|enabled|boolean|Enabled is a bool that indicates if the TrustedCluster is enabled or disabled. Setting Enabled to false has a side effect of deleting the user and host certificate authority (CA).| +|role_map|[][object](#specrole_map-items)|RoleMap specifies role mappings to remote roles.| +|token|string|Token is the authorization token provided by another cluster needed by this cluster to join. This field supports secret lookup. See the operator documentation for more details.| +|tunnel_addr|string|ReverseTunnelAddress is the address of the SSH proxy server of the cluster to join. If not set, it is derived from `:`.| +|web_proxy_addr|string|ProxyAddress is the address of the web proxy server of the cluster to join. 
If not set, it is derived from `:`.| + +### spec.role_map items + +|Field|Type|Description| +|---|---|---| +|local|[]string|Local specifies local roles to map to| +|remote|string|Remote specifies remote role name to map from| + diff --git a/examples/chart/teleport-cluster/charts/teleport-operator/operator-crds/resources.teleport.dev_trustedclustersv2.yaml b/examples/chart/teleport-cluster/charts/teleport-operator/operator-crds/resources.teleport.dev_trustedclustersv2.yaml new file mode 100644 index 0000000000000..4cf1410472b64 --- /dev/null +++ b/examples/chart/teleport-cluster/charts/teleport-operator/operator-crds/resources.teleport.dev_trustedclustersv2.yaml @@ -0,0 +1,149 @@ +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + creationTimestamp: null + name: teleporttrustedclustersv2.resources.teleport.dev +spec: + group: resources.teleport.dev + names: + kind: TeleportTrustedClusterV2 + listKind: TeleportTrustedClusterV2List + plural: teleporttrustedclustersv2 + shortNames: + - trustedclusterv2 + - trustedclustersv2 + singular: teleporttrustedclusterv2 + scope: Namespaced + versions: + - name: v1 + schema: + openAPIV3Schema: + description: TrustedClusterV2 is the Schema for the trustedclustersv2 API + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. Cannot be updated. In CamelCase. 
More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: TrustedCluster resource definition v2 from Teleport + properties: + enabled: + description: Enabled is a bool that indicates if the TrustedCluster + is enabled or disabled. Setting Enabled to false has a side effect + of deleting the user and host certificate authority (CA). + type: boolean + role_map: + description: RoleMap specifies role mappings to remote roles. + items: + properties: + local: + description: Local specifies local roles to map to + items: + type: string + nullable: true + type: array + remote: + description: Remote specifies remote role name to map from + type: string + type: object + type: array + token: + description: Token is the authorization token provided by another + cluster needed by this cluster to join. This field supports secret + lookup. See the operator documentation for more details. + type: string + tunnel_addr: + description: ReverseTunnelAddress is the address of the SSH proxy + server of the cluster to join. If not set, it is derived from `:`. + type: string + web_proxy_addr: + description: ProxyAddress is the address of the web proxy server of + the cluster to join. If not set, it is derived from `:`. + type: string + type: object + status: + description: Status defines the observed state of the Teleport resource + properties: + conditions: + description: Conditions represent the latest available observations + of an object's state + items: + description: Condition contains details for one aspect of the current + state of this API Resource. + properties: + lastTransitionTime: + description: |- + lastTransitionTime is the last time the condition transitioned from one status to another. + This should be when the underlying condition changed. If that is not known, then using the time when the API field changed is acceptable. 
+ format: date-time + type: string + message: + description: |- + message is a human readable message indicating details about the transition. + This may be an empty string. + maxLength: 32768 + type: string + observedGeneration: + description: |- + observedGeneration represents the .metadata.generation that the condition was set based upon. + For instance, if .metadata.generation is currently 12, but the .status.conditions[x].observedGeneration is 9, the condition is out of date + with respect to the current state of the instance. + format: int64 + minimum: 0 + type: integer + reason: + description: |- + reason contains a programmatic identifier indicating the reason for the condition's last transition. + Producers of specific condition types may define expected values and meanings for this field, + and whether the values are considered a guaranteed API. + The value should be a CamelCase string. + This field may not be empty. + maxLength: 1024 + minLength: 1 + pattern: ^[A-Za-z]([A-Za-z0-9_,:]*[A-Za-z0-9_])?$ + type: string + status: + description: status of the condition, one of True, False, Unknown. + enum: + - "True" + - "False" + - Unknown + type: string + type: + description: type of condition in CamelCase or in foo.example.com/CamelCase. 
+ maxLength: 316 + pattern: ^([a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*/)?(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])$ + type: string + required: + - lastTransitionTime + - message + - reason + - status + - type + type: object + type: array + teleportResourceID: + format: int64 + type: integer + type: object + type: object + served: true + storage: true + subresources: + status: {} +status: + acceptedNames: + kind: "" + plural: "" + conditions: null + storedVersions: null diff --git a/examples/chart/teleport-cluster/charts/teleport-operator/templates/role.yaml b/examples/chart/teleport-cluster/charts/teleport-operator/templates/role.yaml index 25b8c72416dc6..1b7c21935ce5c 100644 --- a/examples/chart/teleport-cluster/charts/teleport-operator/templates/role.yaml +++ b/examples/chart/teleport-cluster/charts/teleport-operator/templates/role.yaml @@ -36,6 +36,8 @@ rules: - teleportopensshserversv2/status - teleportopenssheiceserversv2 - teleportopenssheiceserversv2/status + - teleporttrustedclustersv2 + - teleporttrustedclustersv2/status verbs: - get - list diff --git a/examples/chart/teleport-cluster/templates/auth/config.yaml b/examples/chart/teleport-cluster/templates/auth/config.yaml index 99fe59e061c9c..d1c4bffcb5cf6 100644 --- a/examples/chart/teleport-cluster/templates/auth/config.yaml +++ b/examples/chart/teleport-cluster/templates/auth/config.yaml @@ -131,6 +131,14 @@ data: - read - update - delete + - resources: + - trusted_cluster + verbs: + - list + - create + - read + - update + - delete deny: {} version: v7 --- diff --git a/integrations/operator/README.md b/integrations/operator/README.md index 8e91c62d6d46c..d240ca82da84b 100644 --- a/integrations/operator/README.md +++ b/integrations/operator/README.md @@ -20,6 +20,10 @@ The operator supports reconciling the following Kubernetes CRs: - TeleportRoleV7 (creates role v7) - TeleportProvisionToken - TeleportGithubConnector +- TeleportAccessList +- TeleportOpenSSHEICEServerV2 +- 
TeleportOpenSSHServerV2 +- TeleportTrustedClusterV2 - TeleportSAMLConnector [1] - TeleportOIDCConnector [1] - TeleportLoginRule [1] diff --git a/integrations/operator/apis/resources/v1/trusted_cluster_types.go b/integrations/operator/apis/resources/v1/trusted_cluster_types.go new file mode 100644 index 0000000000000..0f6b8f753fac2 --- /dev/null +++ b/integrations/operator/apis/resources/v1/trusted_cluster_types.go @@ -0,0 +1,96 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . 
+ */ + +package v1 + +import ( + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/integrations/operator/apis/resources" +) + +func init() { + SchemeBuilder.Register(&TeleportTrustedClusterV2{}, &TeleportTrustedClusterV2List{}) +} + +//+kubebuilder:object:root=true +//+kubebuilder:subresource:status + +// TeleportTrustedClusterV2 is the Schema for the trusted_clusters API +type TeleportTrustedClusterV2 struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata,omitempty"` + + Spec TeleportTrustedClusterV2Spec `json:"spec,omitempty"` + Status resources.Status `json:"status,omitempty"` +} + +//+kubebuilder:object:root=true + +// TeleportTrustedClusterV2List contains a list of TeleportTrustedClusterV2 +type TeleportTrustedClusterV2List struct { + metav1.TypeMeta `json:",inline"` + metav1.ListMeta `json:"metadata,omitempty"` + Items []TeleportTrustedClusterV2 `json:"items"` +} + +// ToTeleport converts the resource to the teleport trusted_cluster API type. +func (r TeleportTrustedClusterV2) ToTeleport() types.TrustedCluster { + return &types.TrustedClusterV2{ + Kind: types.KindTrustedCluster, + Version: types.V2, + Metadata: types.Metadata{ + Name: r.Name, + Labels: r.Labels, + Description: r.Annotations[resources.DescriptionKey], + }, + Spec: types.TrustedClusterSpecV2(r.Spec), + } +} + +// TeleportTrustedClusterV2Spec defines the desired state of TeleportTrustedClusterV2 +type TeleportTrustedClusterV2Spec types.TrustedClusterSpecV2 + +// Marshal serializes a spec into binary data. +func (spec *TeleportTrustedClusterV2Spec) Marshal() ([]byte, error) { + return (*types.TrustedClusterSpecV2)(spec).Marshal() +} + +// Unmarshal deserializes a spec from binary data. +func (spec *TeleportTrustedClusterV2Spec) Unmarshal(data []byte) error { + return (*types.TrustedClusterSpecV2)(spec).Unmarshal(data) +} + +// DeepCopyInto deep-copies one trusted_cluster spec into another. 
+// Required to satisfy runtime.Object interface. +func (spec *TeleportTrustedClusterV2Spec) DeepCopyInto(out *TeleportTrustedClusterV2Spec) { + data, err := spec.Marshal() + if err != nil { + panic(err) + } + *out = TeleportTrustedClusterV2Spec{} + if err = out.Unmarshal(data); err != nil { + panic(err) + } +} + +// StatusConditions returns a pointer to Status.Conditions slice. +func (r *TeleportTrustedClusterV2) StatusConditions() *[]metav1.Condition { + return &r.Status.Conditions +} diff --git a/integrations/operator/apis/resources/v1/zz_generated.deepcopy.go b/integrations/operator/apis/resources/v1/zz_generated.deepcopy.go index e2f6b7ce932c1..6b803d79d2577 100644 --- a/integrations/operator/apis/resources/v1/zz_generated.deepcopy.go +++ b/integrations/operator/apis/resources/v1/zz_generated.deepcopy.go @@ -605,3 +605,72 @@ func (in *TeleportRoleV7Spec) DeepCopy() *TeleportRoleV7Spec { in.DeepCopyInto(out) return out } + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *TeleportTrustedClusterV2) DeepCopyInto(out *TeleportTrustedClusterV2) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) + in.Spec.DeepCopyInto(&out.Spec) + in.Status.DeepCopyInto(&out.Status) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TeleportTrustedClusterV2. +func (in *TeleportTrustedClusterV2) DeepCopy() *TeleportTrustedClusterV2 { + if in == nil { + return nil + } + out := new(TeleportTrustedClusterV2) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *TeleportTrustedClusterV2) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
+func (in *TeleportTrustedClusterV2List) DeepCopyInto(out *TeleportTrustedClusterV2List) { + *out = *in + out.TypeMeta = in.TypeMeta + in.ListMeta.DeepCopyInto(&out.ListMeta) + if in.Items != nil { + in, out := &in.Items, &out.Items + *out = make([]TeleportTrustedClusterV2, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TeleportTrustedClusterV2List. +func (in *TeleportTrustedClusterV2List) DeepCopy() *TeleportTrustedClusterV2List { + if in == nil { + return nil + } + out := new(TeleportTrustedClusterV2List) + in.DeepCopyInto(out) + return out +} + +// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. +func (in *TeleportTrustedClusterV2List) DeepCopyObject() runtime.Object { + if c := in.DeepCopy(); c != nil { + return c + } + return nil +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TeleportTrustedClusterV2Spec. 
+func (in *TeleportTrustedClusterV2Spec) DeepCopy() *TeleportTrustedClusterV2Spec { + if in == nil { + return nil + } + out := new(TeleportTrustedClusterV2Spec) + in.DeepCopyInto(out) + return out +} diff --git a/integrations/operator/config/crd/bases/resources.teleport.dev_trustedclustersv2.yaml b/integrations/operator/config/crd/bases/resources.teleport.dev_trustedclustersv2.yaml new file mode 100644 index 0000000000000..4cf1410472b64 --- /dev/null +++ b/integrations/operator/config/crd/bases/resources.teleport.dev_trustedclustersv2.yaml @@ -0,0 +1,149 @@ +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + creationTimestamp: null + name: teleporttrustedclustersv2.resources.teleport.dev +spec: + group: resources.teleport.dev + names: + kind: TeleportTrustedClusterV2 + listKind: TeleportTrustedClusterV2List + plural: teleporttrustedclustersv2 + shortNames: + - trustedclusterv2 + - trustedclustersv2 + singular: teleporttrustedclusterv2 + scope: Namespaced + versions: + - name: v1 + schema: + openAPIV3Schema: + description: TrustedClusterV2 is the Schema for the trustedclustersv2 API + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. Cannot be updated. In CamelCase. 
More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: TrustedCluster resource definition v2 from Teleport + properties: + enabled: + description: Enabled is a bool that indicates if the TrustedCluster + is enabled or disabled. Setting Enabled to false has a side effect + of deleting the user and host certificate authority (CA). + type: boolean + role_map: + description: RoleMap specifies role mappings to remote roles. + items: + properties: + local: + description: Local specifies local roles to map to + items: + type: string + nullable: true + type: array + remote: + description: Remote specifies remote role name to map from + type: string + type: object + type: array + token: + description: Token is the authorization token provided by another + cluster needed by this cluster to join. This field supports secret + lookup. See the operator documentation for more details. + type: string + tunnel_addr: + description: ReverseTunnelAddress is the address of the SSH proxy + server of the cluster to join. If not set, it is derived from `:`. + type: string + web_proxy_addr: + description: ProxyAddress is the address of the web proxy server of + the cluster to join. If not set, it is derived from `:`. + type: string + type: object + status: + description: Status defines the observed state of the Teleport resource + properties: + conditions: + description: Conditions represent the latest available observations + of an object's state + items: + description: Condition contains details for one aspect of the current + state of this API Resource. + properties: + lastTransitionTime: + description: |- + lastTransitionTime is the last time the condition transitioned from one status to another. + This should be when the underlying condition changed. If that is not known, then using the time when the API field changed is acceptable. 
+ format: date-time + type: string + message: + description: |- + message is a human readable message indicating details about the transition. + This may be an empty string. + maxLength: 32768 + type: string + observedGeneration: + description: |- + observedGeneration represents the .metadata.generation that the condition was set based upon. + For instance, if .metadata.generation is currently 12, but the .status.conditions[x].observedGeneration is 9, the condition is out of date + with respect to the current state of the instance. + format: int64 + minimum: 0 + type: integer + reason: + description: |- + reason contains a programmatic identifier indicating the reason for the condition's last transition. + Producers of specific condition types may define expected values and meanings for this field, + and whether the values are considered a guaranteed API. + The value should be a CamelCase string. + This field may not be empty. + maxLength: 1024 + minLength: 1 + pattern: ^[A-Za-z]([A-Za-z0-9_,:]*[A-Za-z0-9_])?$ + type: string + status: + description: status of the condition, one of True, False, Unknown. + enum: + - "True" + - "False" + - Unknown + type: string + type: + description: type of condition in CamelCase or in foo.example.com/CamelCase. 
+ maxLength: 316 + pattern: ^([a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*/)?(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])$ + type: string + required: + - lastTransitionTime + - message + - reason + - status + - type + type: object + type: array + teleportResourceID: + format: int64 + type: integer + type: object + type: object + served: true + storage: true + subresources: + status: {} +status: + acceptedNames: + kind: "" + plural: "" + conditions: null + storedVersions: null diff --git a/integrations/operator/controllers/reconcilers/legacy_resource_without_labels.go b/integrations/operator/controllers/reconcilers/legacy_resource_without_labels.go index c7079240a9d7b..307c1283ee398 100644 --- a/integrations/operator/controllers/reconcilers/legacy_resource_without_labels.go +++ b/integrations/operator/controllers/reconcilers/legacy_resource_without_labels.go @@ -62,7 +62,7 @@ func (a ResourceWithoutLabelsAdapter[T]) SetResourceRevision(res T, revision str } // SetResourceLabels implements the Adapter interface. As the resource does not -// // support labels, it only sets the origin label. +// support labels, it only sets the origin label. 
func (a ResourceWithoutLabelsAdapter[T]) SetResourceLabels(res T, labels map[string]string) { // We don't set all labels as the Resource doesn't support them // Only the origin diff --git a/integrations/operator/controllers/resources/setup.go b/integrations/operator/controllers/resources/setup.go index a2e78a8cdc68c..fffceccbf8c39 100644 --- a/integrations/operator/controllers/resources/setup.go +++ b/integrations/operator/controllers/resources/setup.go @@ -47,6 +47,7 @@ func SetupAllControllers(log logr.Logger, mgr manager.Manager, teleportClient *c {"TeleportProvisionToken", NewProvisionTokenReconciler}, {"TeleportOpenSSHServerV2", NewOpenSSHServerV2Reconciler}, {"TeleportOpenSSHEICEServerV2", NewOpenSSHEICEServerV2Reconciler}, + {"TeleportTrustedClusterV2", NewTrustedClusterV2Reconciler}, } oidc := modules.GetProtoEntitlement(features, entitlements.OIDC) diff --git a/integrations/operator/controllers/resources/testlib/env.go b/integrations/operator/controllers/resources/testlib/env.go index 9de19230826c3..e41a4f22677eb 100644 --- a/integrations/operator/controllers/resources/testlib/env.go +++ b/integrations/operator/controllers/resources/testlib/env.go @@ -139,6 +139,7 @@ func defaultTeleportServiceConfig(t *testing.T) (*helpers.TeleInstance, string) types.NewRule(types.KindOktaImportRule, unrestricted), types.NewRule(types.KindAccessList, unrestricted), types.NewRule(types.KindNode, unrestricted), + types.NewRule(types.KindTrustedCluster, unrestricted), }, }, }) diff --git a/integrations/operator/controllers/resources/trusted_cluster_controller.go b/integrations/operator/controllers/resources/trusted_cluster_controller.go new file mode 100644 index 0000000000000..a3154bed00b42 --- /dev/null +++ b/integrations/operator/controllers/resources/trusted_cluster_controller.go @@ -0,0 +1,91 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. 
+ * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package resources + +import ( + "context" + + "github.com/gravitational/trace" + kclient "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/types" + resourcesv1 "github.com/gravitational/teleport/integrations/operator/apis/resources/v1" + "github.com/gravitational/teleport/integrations/operator/controllers" + "github.com/gravitational/teleport/integrations/operator/controllers/reconcilers" + "github.com/gravitational/teleport/integrations/operator/controllers/resources/secretlookup" +) + +// trustedClusterClient implements TeleportResourceClient and offers CRUD +// methods needed to reconcile trusted_clusters. +type trustedClusterClient struct { + teleportClient *client.Client + kubeClient kclient.Client +} + +// Get gets the Teleport trusted_cluster of a given name. +func (r trustedClusterClient) Get(ctx context.Context, name string) (types.TrustedCluster, error) { + trustedCluster, err := r.teleportClient.GetTrustedCluster(ctx, name) + return trustedCluster, trace.Wrap(err) +} + +// Create creates a Teleport trusted_cluster. 
+func (r trustedClusterClient) Create(ctx context.Context, trustedCluster types.TrustedCluster) error { + _, err := r.teleportClient.CreateTrustedCluster(ctx, trustedCluster) + return trace.Wrap(err) +} + +// Update updates a Teleport trusted_cluster. +func (r trustedClusterClient) Update(ctx context.Context, trustedCluster types.TrustedCluster) error { + _, err := r.teleportClient.UpdateTrustedCluster(ctx, trustedCluster) + return trace.Wrap(err) +} + +// Delete deletes a Teleport trusted_cluster. +func (r trustedClusterClient) Delete(ctx context.Context, name string) error { + return trace.Wrap(r.teleportClient.DeleteTrustedCluster(ctx, name)) +} + +// Mutate mutates a Teleport trusted_cluster. +func (r trustedClusterClient) Mutate(ctx context.Context, new, existing types.TrustedCluster, crKey kclient.ObjectKey) error { + secret := new.GetToken() + if secretlookup.IsNeeded(secret) { + resolvedSecret, err := secretlookup.Try(ctx, r.kubeClient, crKey.Name, crKey.Namespace, secret) + if err != nil { + return trace.Wrap(err) + } + new.SetToken(resolvedSecret) + } + return nil +} + +// NewTrustedClusterV2Reconciler instantiates a new Kubernetes controller reconciling trusted_cluster v2 resources +func NewTrustedClusterV2Reconciler(client kclient.Client, tClient *client.Client) (controllers.Reconciler, error) { + trustedClusterClient := &trustedClusterClient{ + teleportClient: tClient, + kubeClient: client, + } + + resourceReconciler, err := reconcilers.NewTeleportResourceWithoutLabelsReconciler[types.TrustedCluster, *resourcesv1.TeleportTrustedClusterV2]( + client, + trustedClusterClient, + ) + + return resourceReconciler, trace.Wrap(err, "building teleport resource reconciler") +} diff --git a/integrations/operator/controllers/resources/trusted_clusterv2_controller_test.go b/integrations/operator/controllers/resources/trusted_clusterv2_controller_test.go new file mode 100644 index 0000000000000..a3b1dc98de5ba --- /dev/null +++ 
b/integrations/operator/controllers/resources/trusted_clusterv2_controller_test.go @@ -0,0 +1,356 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package resources_test + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + v1 "k8s.io/api/core/v1" + kerrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/util/retry" + kclient "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/integration/helpers" + resourcesv1 "github.com/gravitational/teleport/integrations/operator/apis/resources/v1" + "github.com/gravitational/teleport/integrations/operator/controllers/reconcilers" + "github.com/gravitational/teleport/integrations/operator/controllers/resources/secretlookup" + "github.com/gravitational/teleport/integrations/operator/controllers/resources/testlib" + "github.com/gravitational/teleport/lib" + "github.com/gravitational/teleport/lib/service/servicecfg" + "github.com/gravitational/teleport/lib/utils" +) + +type trustedClusterV2TestingPrimitives struct { + // remoteCluster specifies the remote trusted cluster 
instance. + remoteCluster *helpers.TeleInstance + // trustedClusterSpec specifies the trusted cluster specs. + trustedClusterSpec types.TrustedClusterSpecV2 + + setup *testSetup + reconcilers.ResourceWithoutLabelsAdapter[types.TrustedCluster] +} + +func (r *trustedClusterV2TestingPrimitives) Init(setup *testSetup) { + r.setup = setup +} + +func (r *trustedClusterV2TestingPrimitives) SetupTeleportFixtures(ctx context.Context) error { + return nil +} + +func (r *trustedClusterV2TestingPrimitives) CreateTeleportResource(ctx context.Context, name string) error { + trustedCluster, err := types.NewTrustedCluster(name, r.trustedClusterSpec) + if err != nil { + return trace.Wrap(err) + } + trustedCluster.SetOrigin(types.OriginKubernetes) + _, err = r.setup.TeleportClient.CreateTrustedCluster(ctx, trustedCluster) + return trace.Wrap(err) +} + +func (r *trustedClusterV2TestingPrimitives) GetTeleportResource(ctx context.Context, name string) (types.TrustedCluster, error) { + return r.setup.TeleportClient.GetTrustedCluster(ctx, name) +} + +func (r *trustedClusterV2TestingPrimitives) DeleteTeleportResource(ctx context.Context, name string) error { + return trace.Wrap(r.setup.TeleportClient.DeleteTrustedCluster(ctx, name)) +} + +func (r *trustedClusterV2TestingPrimitives) CreateKubernetesResource(ctx context.Context, name string) error { + trustedCluster := &resourcesv1.TeleportTrustedClusterV2{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: r.setup.Namespace.Name, + }, + Spec: resourcesv1.TeleportTrustedClusterV2Spec(r.trustedClusterSpec), + } + return trace.Wrap(r.setup.K8sClient.Create(ctx, trustedCluster)) +} + +func (r *trustedClusterV2TestingPrimitives) DeleteKubernetesResource(ctx context.Context, name string) error { + trustedCluster := &resourcesv1.TeleportTrustedClusterV2{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: r.setup.Namespace.Name, + }, + } + return trace.Wrap(r.setup.K8sClient.Delete(ctx, trustedCluster)) +} + +func (r 
*trustedClusterV2TestingPrimitives) GetKubernetesResource(ctx context.Context, name string) (*resourcesv1.TeleportTrustedClusterV2, error) { + trustedCluster := &resourcesv1.TeleportTrustedClusterV2{} + obj := kclient.ObjectKey{ + Name: name, + Namespace: r.setup.Namespace.Name, + } + err := r.setup.K8sClient.Get(ctx, obj, trustedCluster) + return trustedCluster, trace.Wrap(err) +} + +func (r *trustedClusterV2TestingPrimitives) ModifyKubernetesResource(ctx context.Context, name string) error { + trustedCluster, err := r.GetKubernetesResource(ctx, name) + if err != nil { + return trace.Wrap(err) + } + trustedCluster.Spec.RoleMap[0] = types.RoleMapping{ + Remote: "remote-admin", + Local: []string{"local-dev"}, + } + return trace.Wrap(r.setup.K8sClient.Update(ctx, trustedCluster)) +} + +func (r *trustedClusterV2TestingPrimitives) CompareTeleportAndKubernetesResource(tResource types.TrustedCluster, kubeResource *resourcesv1.TeleportTrustedClusterV2) (bool, string) { + diff := cmp.Diff(tResource, kubeResource.ToTeleport(), testlib.CompareOptions()...) + return diff == "", diff +} + +// setupTest initializes a remote cluster for testing trusted clusters. 
+func (r *trustedClusterV2TestingPrimitives) setupTest(t *testing.T, clusterName string) { + ctx := context.Background() + + remoteCluster := helpers.NewInstance(t, helpers.InstanceConfig{ + ClusterName: clusterName, + HostID: uuid.New().String(), + NodeName: helpers.Loopback, + Logger: utils.NewSlogLoggerForTests(), + }) + r.remoteCluster = remoteCluster + + rcConf := servicecfg.MakeDefaultConfig() + rcConf.DataDir = t.TempDir() + rcConf.Auth.Enabled = true + rcConf.Proxy.Enabled = true + rcConf.Proxy.DisableWebInterface = true + rcConf.Version = "v2" + + lib.SetInsecureDevMode(true) + t.Cleanup(func() { lib.SetInsecureDevMode(false) }) + + require.NoError(t, remoteCluster.CreateEx(t, nil, rcConf)) + require.NoError(t, remoteCluster.Start()) + t.Cleanup(func() { require.NoError(t, remoteCluster.StopAll()) }) + + // Create trusted cluster join token + token := "secret_token" + tokenResource, err := types.NewProvisionToken(token, []types.SystemRole{types.RoleTrustedCluster}, time.Time{}) + require.NoError(t, err) + remoteCluster.Process.GetAuthServer().UpsertToken(ctx, tokenResource) + + // Create required role + localDev := "local-dev" + require.NoError(t, teleportCreateDummyRole(ctx, localDev, r.setup.TeleportClient)) + + r.trustedClusterSpec = types.TrustedClusterSpecV2{ + Enabled: true, + Token: token, + ProxyAddress: remoteCluster.Web, + ReverseTunnelAddress: remoteCluster.ReverseTunnel, + RoleMap: []types.RoleMapping{ + { + Remote: "remote-dev", + Local: []string{localDev}, + }, + }, + } +} + +func TestTrustedClusterV2Creation(t *testing.T) { + test := &trustedClusterV2TestingPrimitives{} + setup := testlib.SetupTestEnv(t) + test.Init(setup) + ctx := context.Background() + + resourceName := "remote.example.com" + test.setupTest(t, resourceName) + + require.NoError(t, test.CreateKubernetesResource(ctx, resourceName)) + + var resource types.TrustedCluster + var err error + testlib.FastEventually(t, func() bool { + resource, err = test.GetTeleportResource(ctx, 
resourceName)
+ return !trace.IsNotFound(err)
+ })
+ require.NoError(t, err)
+ require.Equal(t, resourceName, test.GetResourceName(resource))
+ require.Equal(t, types.OriginKubernetes, test.GetResourceOrigin(resource))
+
+ err = test.DeleteKubernetesResource(ctx, resourceName)
+ require.NoError(t, err)
+
+ testlib.FastEventually(t, func() bool {
+ _, err = test.GetTeleportResource(ctx, resourceName)
+ return trace.IsNotFound(err)
+ })
+}
+
+func TestTrustedClusterV2DeletionDrift(t *testing.T) {
+ test := &trustedClusterV2TestingPrimitives{}
+ setup := testlib.SetupTestEnv(t)
+ test.Init(setup)
+ ctx := context.Background()
+
+ resourceName := "remote.example.com"
+ test.setupTest(t, resourceName)
+
+ require.NoError(t, test.CreateKubernetesResource(ctx, resourceName))
+
+ var resource types.TrustedCluster
+ var err error
+ testlib.FastEventually(t, func() bool {
+ resource, err = test.GetTeleportResource(ctx, resourceName)
+ return !trace.IsNotFound(err)
+ })
+ require.NoError(t, err)
+ require.Equal(t, resourceName, test.GetResourceName(resource))
+ require.Equal(t, types.OriginKubernetes, test.GetResourceOrigin(resource))
+
+ // We cause a drift by deleting the Teleport resource.
+ // To make sure the operator does not reconcile before we're finished, we suspend the operator
+ setup.StopKubernetesOperator()
+
+ err = test.DeleteTeleportResource(ctx, resourceName)
+ require.NoError(t, err)
+ testlib.FastEventually(t, func() bool {
+ _, err = test.GetTeleportResource(ctx, resourceName)
+ return trace.IsNotFound(err)
+ })
+
+ // We flag the resource for deletion in Kubernetes (it won't be fully removed until the operator has processed it and removed the finalizer)
+ err = test.DeleteKubernetesResource(ctx, resourceName)
+ require.NoError(t, err)
+
+ // Test section: We resume the operator, it should reconcile and recover from the drift
+ setup.StartKubernetesOperator(t)
+
+ // The operator should handle the failed Teleport deletion gracefully and unlock the Kubernetes resource deletion
+ testlib.FastEventually(t, func() bool {
+ _, err = test.GetKubernetesResource(ctx, resourceName)
+ return kerrors.IsNotFound(err)
+ })
+}
+
+func TestTrustedClusterV2Update(t *testing.T) {
+ test := &trustedClusterV2TestingPrimitives{}
+ setup := testlib.SetupTestEnv(t)
+ test.Init(setup)
+ ctx := context.Background()
+
+ resourceName := "remote.example.com"
+ test.setupTest(t, resourceName)
+
+ // The resource is created in Teleport
+ require.NoError(t, test.CreateTeleportResource(ctx, resourceName))
+
+ // The resource is created in Kubernetes, with at least a field altered
+ require.NoError(t, test.CreateKubernetesResource(ctx, resourceName))
+
+ // Check the resource was updated in Teleport
+ testlib.FastEventuallyWithT(t, func(c *assert.CollectT) {
+ tResource, err := test.GetTeleportResource(ctx, resourceName)
+ require.NoError(c, err)
+
+ kubeResource, err := test.GetKubernetesResource(ctx, resourceName)
+ require.NoError(c, err)
+
+ // Kubernetes and Teleport resources are in-sync
+ equal, diff := test.CompareTeleportAndKubernetesResource(tResource, kubeResource)
+ if !equal {
+ t.Logf("Kubernetes and Teleport resources not sync-ed yet: %s", diff)
+ }
+ 
assert.True(c, equal) + }) + + // Updating the resource in Kubernetes + // The modification can fail because of a conflict with the resource controller. We retry if that happens. + err := retry.RetryOnConflict(retry.DefaultRetry, func() error { + return test.ModifyKubernetesResource(ctx, resourceName) + }) + require.NoError(t, err) + + // Check the resource was updated in Teleport + testlib.FastEventuallyWithT(t, func(c *assert.CollectT) { + kubeResource, err := test.GetKubernetesResource(ctx, resourceName) + require.NoError(c, err) + + tResource, err := test.GetTeleportResource(ctx, resourceName) + require.NoError(c, err) + + // Kubernetes and Teleport resources are in-sync + equal, diff := test.CompareTeleportAndKubernetesResource(tResource, kubeResource) + if !equal { + t.Logf("Kubernetes and Teleport resources not sync-ed yet: %s", diff) + } + assert.True(c, equal) + }) + + // Delete the resource to avoid leftover state. + err = test.DeleteTeleportResource(ctx, resourceName) + require.NoError(t, err) +} + +func TestTrustedClusterV2SecretLookup(t *testing.T) { + test := &trustedClusterV2TestingPrimitives{} + setup := testlib.SetupTestEnv(t) + test.Init(setup) + ctx := context.Background() + + resourceName := "remote.example.com" + test.setupTest(t, resourceName) + + secretName := validRandomResourceName("trusted-cluster-secret") + secretKey := "token" + secretValue := test.trustedClusterSpec.Token + + secret := &v1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: secretName, + Namespace: setup.Namespace.Name, + Annotations: map[string]string{ + secretlookup.AllowLookupAnnotation: resourceName, + }, + }, + StringData: map[string]string{ + secretKey: secretValue, + }, + Type: v1.SecretTypeOpaque, + } + kubeClient := setup.K8sClient + require.NoError(t, kubeClient.Create(ctx, secret)) + + test.trustedClusterSpec.Token = "secret://" + secretName + "/" + secretKey + require.NoError(t, test.CreateKubernetesResource(ctx, resourceName)) + + testlib.FastEventually(t, 
func() bool { + trustedCluster, err := test.GetTeleportResource(ctx, resourceName) + if err != nil { + return false + } + return trustedCluster.GetToken() == secretValue + }) +} diff --git a/integrations/operator/crdgen/additional_doc.go b/integrations/operator/crdgen/additional_doc.go index 2458af6779bbc..396c94410808f 100644 --- a/integrations/operator/crdgen/additional_doc.go +++ b/integrations/operator/crdgen/additional_doc.go @@ -29,4 +29,7 @@ var additionalDescription = map[string]map[string]string{ "OIDCConnectorSpecV3": { "ClientSecret": supportsSecretLookupDescription, }, + "TrustedClusterSpecV2": { + "Token": supportsSecretLookupDescription, + }, } diff --git a/integrations/operator/crdgen/handlerequest.go b/integrations/operator/crdgen/handlerequest.go index 66d90324cc0f4..57f479de185e3 100644 --- a/integrations/operator/crdgen/handlerequest.go +++ b/integrations/operator/crdgen/handlerequest.go @@ -213,6 +213,7 @@ func generateSchema(file *File, groupName string, format crdFormatFunc, resp *go withAdditionalColumns(serverColumns), }, }, + {name: "TrustedClusterV2", opts: []resourceSchemaOption{withVersionInKindOverride()}}, } for _, resource := range resources { diff --git a/integrations/operator/crdgen/ignored.go b/integrations/operator/crdgen/ignored.go index 596e79b4d291a..7b647756f12cc 100644 --- a/integrations/operator/crdgen/ignored.go +++ b/integrations/operator/crdgen/ignored.go @@ -44,4 +44,7 @@ var ignoredFields = map[string]stringSet{ // allows remote exec on agentful nodes. "CmdLabels": struct{}{}, }, + "TrustedClusterSpecV2": { + "Roles": struct{}{}, // Deprecated, use RoleMap instead. 
+ }, } diff --git a/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_openssheiceserversv2.yaml b/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_openssheiceserversv2.yaml index 3617909ae6a67..bad8469a76fb6 100644 --- a/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_openssheiceserversv2.yaml +++ b/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_openssheiceserversv2.yaml @@ -88,6 +88,20 @@ spec: type: string type: object type: object + github: + description: GitHub contains info about GitHub proxies where each + server represents a GitHub organization. + nullable: true + properties: + integration: + description: Integration is the integration that is associated + with this Server. + type: string + organization: + description: Organization specifies the name of the organization + for the GitHub integration. + type: string + type: object hostname: description: Hostname is server hostname type: string diff --git a/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_opensshserversv2.yaml b/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_opensshserversv2.yaml index ad7dfd4174776..fe3d76a8db7a4 100644 --- a/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_opensshserversv2.yaml +++ b/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_opensshserversv2.yaml @@ -87,6 +87,20 @@ spec: type: string type: object type: object + github: + description: GitHub contains info about GitHub proxies where each + server represents a GitHub organization. + nullable: true + properties: + integration: + description: Integration is the integration that is associated + with this Server. + type: string + organization: + description: Organization specifies the name of the organization + for the GitHub integration. 
+ type: string + type: object hostname: description: Hostname is server hostname type: string diff --git a/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_roles.yaml b/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_roles.yaml index 5d1c5ddfb9809..9e3a0f46e9334 100644 --- a/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_roles.yaml +++ b/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_roles.yaml @@ -157,6 +157,18 @@ spec: type: string nullable: true type: array + github_permissions: + description: GitHubPermissions defines GitHub integration related + permissions. + items: + properties: + orgs: + items: + type: string + nullable: true + type: array + type: object + type: array group_labels: additionalProperties: x-kubernetes-preserve-unknown-fields: true @@ -595,6 +607,17 @@ spec: type: string nullable: true type: array + workload_identity_labels: + additionalProperties: + x-kubernetes-preserve-unknown-fields: true + description: WorkloadIdentityLabels controls whether or not specific + WorkloadIdentity resources can be invoked. Further authorization + controls exist on the WorkloadIdentity resource itself. + type: object + workload_identity_labels_expression: + description: WorkloadIdentityLabelsExpression is a predicate expression + used to allow/deny access to issuing a WorkloadIdentity. + type: string type: object deny: description: Deny is the set of conditions evaluated to deny access. @@ -722,6 +745,18 @@ spec: type: string nullable: true type: array + github_permissions: + description: GitHubPermissions defines GitHub integration related + permissions. 
+ items: + properties: + orgs: + items: + type: string + nullable: true + type: array + type: object + type: array group_labels: additionalProperties: x-kubernetes-preserve-unknown-fields: true @@ -1160,6 +1195,17 @@ spec: type: string nullable: true type: array + workload_identity_labels: + additionalProperties: + x-kubernetes-preserve-unknown-fields: true + description: WorkloadIdentityLabels controls whether or not specific + WorkloadIdentity resources can be invoked. Further authorization + controls exist on the WorkloadIdentity resource itself. + type: object + workload_identity_labels_expression: + description: WorkloadIdentityLabelsExpression is a predicate expression + used to allow/deny access to issuing a WorkloadIdentity. + type: string type: object options: description: Options is for OpenSSH options like agent forwarding. @@ -1584,6 +1630,18 @@ spec: type: string nullable: true type: array + github_permissions: + description: GitHubPermissions defines GitHub integration related + permissions. + items: + properties: + orgs: + items: + type: string + nullable: true + type: array + type: object + type: array group_labels: additionalProperties: x-kubernetes-preserve-unknown-fields: true @@ -2022,6 +2080,17 @@ spec: type: string nullable: true type: array + workload_identity_labels: + additionalProperties: + x-kubernetes-preserve-unknown-fields: true + description: WorkloadIdentityLabels controls whether or not specific + WorkloadIdentity resources can be invoked. Further authorization + controls exist on the WorkloadIdentity resource itself. + type: object + workload_identity_labels_expression: + description: WorkloadIdentityLabelsExpression is a predicate expression + used to allow/deny access to issuing a WorkloadIdentity. + type: string type: object deny: description: Deny is the set of conditions evaluated to deny access. 
@@ -2149,6 +2218,18 @@ spec: type: string nullable: true type: array + github_permissions: + description: GitHubPermissions defines GitHub integration related + permissions. + items: + properties: + orgs: + items: + type: string + nullable: true + type: array + type: object + type: array group_labels: additionalProperties: x-kubernetes-preserve-unknown-fields: true @@ -2587,6 +2668,17 @@ spec: type: string nullable: true type: array + workload_identity_labels: + additionalProperties: + x-kubernetes-preserve-unknown-fields: true + description: WorkloadIdentityLabels controls whether or not specific + WorkloadIdentity resources can be invoked. Further authorization + controls exist on the WorkloadIdentity resource itself. + type: object + workload_identity_labels_expression: + description: WorkloadIdentityLabelsExpression is a predicate expression + used to allow/deny access to issuing a WorkloadIdentity. + type: string type: object options: description: Options is for OpenSSH options like agent forwarding. diff --git a/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_rolesv6.yaml b/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_rolesv6.yaml index f0af70fc7cf2f..5e1ff2a359184 100644 --- a/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_rolesv6.yaml +++ b/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_rolesv6.yaml @@ -160,6 +160,18 @@ spec: type: string nullable: true type: array + github_permissions: + description: GitHubPermissions defines GitHub integration related + permissions. 
+ items: + properties: + orgs: + items: + type: string + nullable: true + type: array + type: object + type: array group_labels: additionalProperties: x-kubernetes-preserve-unknown-fields: true @@ -598,6 +610,17 @@ spec: type: string nullable: true type: array + workload_identity_labels: + additionalProperties: + x-kubernetes-preserve-unknown-fields: true + description: WorkloadIdentityLabels controls whether or not specific + WorkloadIdentity resources can be invoked. Further authorization + controls exist on the WorkloadIdentity resource itself. + type: object + workload_identity_labels_expression: + description: WorkloadIdentityLabelsExpression is a predicate expression + used to allow/deny access to issuing a WorkloadIdentity. + type: string type: object deny: description: Deny is the set of conditions evaluated to deny access. @@ -725,6 +748,18 @@ spec: type: string nullable: true type: array + github_permissions: + description: GitHubPermissions defines GitHub integration related + permissions. + items: + properties: + orgs: + items: + type: string + nullable: true + type: array + type: object + type: array group_labels: additionalProperties: x-kubernetes-preserve-unknown-fields: true @@ -1163,6 +1198,17 @@ spec: type: string nullable: true type: array + workload_identity_labels: + additionalProperties: + x-kubernetes-preserve-unknown-fields: true + description: WorkloadIdentityLabels controls whether or not specific + WorkloadIdentity resources can be invoked. Further authorization + controls exist on the WorkloadIdentity resource itself. + type: object + workload_identity_labels_expression: + description: WorkloadIdentityLabelsExpression is a predicate expression + used to allow/deny access to issuing a WorkloadIdentity. + type: string type: object options: description: Options is for OpenSSH options like agent forwarding. 
diff --git a/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_rolesv7.yaml b/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_rolesv7.yaml index 88056b0b54a53..fb682402d11e3 100644 --- a/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_rolesv7.yaml +++ b/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_rolesv7.yaml @@ -160,6 +160,18 @@ spec: type: string nullable: true type: array + github_permissions: + description: GitHubPermissions defines GitHub integration related + permissions. + items: + properties: + orgs: + items: + type: string + nullable: true + type: array + type: object + type: array group_labels: additionalProperties: x-kubernetes-preserve-unknown-fields: true @@ -598,6 +610,17 @@ spec: type: string nullable: true type: array + workload_identity_labels: + additionalProperties: + x-kubernetes-preserve-unknown-fields: true + description: WorkloadIdentityLabels controls whether or not specific + WorkloadIdentity resources can be invoked. Further authorization + controls exist on the WorkloadIdentity resource itself. + type: object + workload_identity_labels_expression: + description: WorkloadIdentityLabelsExpression is a predicate expression + used to allow/deny access to issuing a WorkloadIdentity. + type: string type: object deny: description: Deny is the set of conditions evaluated to deny access. @@ -725,6 +748,18 @@ spec: type: string nullable: true type: array + github_permissions: + description: GitHubPermissions defines GitHub integration related + permissions. 
+ items: + properties: + orgs: + items: + type: string + nullable: true + type: array + type: object + type: array group_labels: additionalProperties: x-kubernetes-preserve-unknown-fields: true @@ -1163,6 +1198,17 @@ spec: type: string nullable: true type: array + workload_identity_labels: + additionalProperties: + x-kubernetes-preserve-unknown-fields: true + description: WorkloadIdentityLabels controls whether or not specific + WorkloadIdentity resources can be invoked. Further authorization + controls exist on the WorkloadIdentity resource itself. + type: object + workload_identity_labels_expression: + description: WorkloadIdentityLabelsExpression is a predicate expression + used to allow/deny access to issuing a WorkloadIdentity. + type: string type: object options: description: Options is for OpenSSH options like agent forwarding. diff --git a/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_trustedclustersv2.yaml b/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_trustedclustersv2.yaml new file mode 100644 index 0000000000000..4cf1410472b64 --- /dev/null +++ b/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_trustedclustersv2.yaml @@ -0,0 +1,149 @@ +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + creationTimestamp: null + name: teleporttrustedclustersv2.resources.teleport.dev +spec: + group: resources.teleport.dev + names: + kind: TeleportTrustedClusterV2 + listKind: TeleportTrustedClusterV2List + plural: teleporttrustedclustersv2 + shortNames: + - trustedclusterv2 + - trustedclustersv2 + singular: teleporttrustedclusterv2 + scope: Namespaced + versions: + - name: v1 + schema: + openAPIV3Schema: + description: TrustedClusterV2 is the Schema for the trustedclustersv2 API + properties: + apiVersion: + description: 'APIVersion defines the versioned schema of this representation + of an object. 
Servers should convert recognized schemas to the latest + internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources' + type: string + kind: + description: 'Kind is a string value representing the REST resource this + object represents. Servers may infer this from the endpoint the client + submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds' + type: string + metadata: + type: object + spec: + description: TrustedCluster resource definition v2 from Teleport + properties: + enabled: + description: Enabled is a bool that indicates if the TrustedCluster + is enabled or disabled. Setting Enabled to false has a side effect + of deleting the user and host certificate authority (CA). + type: boolean + role_map: + description: RoleMap specifies role mappings to remote roles. + items: + properties: + local: + description: Local specifies local roles to map to + items: + type: string + nullable: true + type: array + remote: + description: Remote specifies remote role name to map from + type: string + type: object + type: array + token: + description: Token is the authorization token provided by another + cluster needed by this cluster to join. This field supports secret + lookup. See the operator documentation for more details. + type: string + tunnel_addr: + description: ReverseTunnelAddress is the address of the SSH proxy + server of the cluster to join. If not set, it is derived from `:`. + type: string + web_proxy_addr: + description: ProxyAddress is the address of the web proxy server of + the cluster to join. If not set, it is derived from `:`. 
+ type: string + type: object + status: + description: Status defines the observed state of the Teleport resource + properties: + conditions: + description: Conditions represent the latest available observations + of an object's state + items: + description: Condition contains details for one aspect of the current + state of this API Resource. + properties: + lastTransitionTime: + description: |- + lastTransitionTime is the last time the condition transitioned from one status to another. + This should be when the underlying condition changed. If that is not known, then using the time when the API field changed is acceptable. + format: date-time + type: string + message: + description: |- + message is a human readable message indicating details about the transition. + This may be an empty string. + maxLength: 32768 + type: string + observedGeneration: + description: |- + observedGeneration represents the .metadata.generation that the condition was set based upon. + For instance, if .metadata.generation is currently 12, but the .status.conditions[x].observedGeneration is 9, the condition is out of date + with respect to the current state of the instance. + format: int64 + minimum: 0 + type: integer + reason: + description: |- + reason contains a programmatic identifier indicating the reason for the condition's last transition. + Producers of specific condition types may define expected values and meanings for this field, + and whether the values are considered a guaranteed API. + The value should be a CamelCase string. + This field may not be empty. + maxLength: 1024 + minLength: 1 + pattern: ^[A-Za-z]([A-Za-z0-9_,:]*[A-Za-z0-9_])?$ + type: string + status: + description: status of the condition, one of True, False, Unknown. + enum: + - "True" + - "False" + - Unknown + type: string + type: + description: type of condition in CamelCase or in foo.example.com/CamelCase. 
+ maxLength: 316 + pattern: ^([a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*/)?(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])$ + type: string + required: + - lastTransitionTime + - message + - reason + - status + - type + type: object + type: array + teleportResourceID: + format: int64 + type: integer + type: object + type: object + served: true + storage: true + subresources: + status: {} +status: + acceptedNames: + kind: "" + plural: "" + conditions: null + storedVersions: null diff --git a/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_users.yaml b/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_users.yaml index 504c3695c4532..0c68b6dec714f 100644 --- a/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_users.yaml +++ b/integrations/operator/crdgen/testdata/golden/resources.teleport.dev_users.yaml @@ -57,6 +57,10 @@ spec: description: SAMLSingleLogoutURL is the SAML Single log-out URL to initiate SAML SLO (single log-out), if applicable. type: string + user_id: + description: UserID is the ID of the identity. Some connectors + like GitHub have an unique ID apart from the username. + type: string username: description: Username is username supplied by external identity provider @@ -76,6 +80,10 @@ spec: description: SAMLSingleLogoutURL is the SAML Single log-out URL to initiate SAML SLO (single log-out), if applicable. type: string + user_id: + description: UserID is the ID of the identity. Some connectors + like GitHub have an unique ID apart from the username. + type: string username: description: Username is username supplied by external identity provider @@ -101,6 +109,10 @@ spec: description: SAMLSingleLogoutURL is the SAML Single log-out URL to initiate SAML SLO (single log-out), if applicable. type: string + user_id: + description: UserID is the ID of the identity. Some connectors + like GitHub have an unique ID apart from the username. 
+ type: string username: description: Username is username supplied by external identity provider diff --git a/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/client/proto/authservice.proto b/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/client/proto/authservice.proto index 03b6f9ac35439..fc6dc146ff248 100644 --- a/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/client/proto/authservice.proto +++ b/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/client/proto/authservice.proto @@ -333,7 +333,13 @@ message RouteToApp { // GCPServiceAccount is the GCP service account to assume when accessing GCP API. string GCPServiceAccount = 7 [(gogoproto.jsontag) = "gcp_service_account,omitempty"]; // URI is the URI of the app. This is the internal endpoint where the application is running and isn't user-facing. + // Used merely for audit events and mirrors the URI from the app spec. Not used as a source of + // truth when routing connections. string URI = 8 [(gogoproto.jsontag) = "uri,omitempty"]; + // TargetPort signifies that the cert grants access to a specific port in a multi-port TCP app, as + // long as the port is defined in the app spec. When specified, it must be between 1 and 65535. + // Used only for routing, should not be used in other contexts (e.g., access requests). + uint32 TargetPort = 9 [(gogoproto.jsontag) = "target_port,omitempty"]; } // GetUserRequest specifies parameters for the GetUser method. 
@@ -1934,6 +1940,60 @@ message CreateRegisterChallengeRequest { DeviceUsage DeviceUsage = 3 [(gogoproto.jsontag) = "device_usage,omitempty"]; } +// IdentityCenterAccount holds information about an Identity Center account +// within an IdentityCenterAccountAssignment +message IdentityCenterAccount { + // ID is the AWS-assigned account ID + string ID = 1; + + // ARN is the full Amazon Resource Name for the AWS account + string ARN = 2; + + // AccountName is the human-readable name of the account + string AccountName = 3; + + // Description is a free text description of the account + string Description = 4; +} + +// IdentityCenterPermissionSet holds information about an Identity Center +// permission set within an IdentityCenterAccountAssignment +message IdentityCenterPermissionSet { + // ARN is the full Amazon Resource Name for the Permission Set + string ARN = 1; + + // Name is the human readable name for the Permission Set + string Name = 2; +} + +// IdentityCenterAccountAssignment represents a requestable Identity Center +// Account Assignment. This is strictly a wire-format object for use with the +// Unfied resource cache, and the types defined in the `identitycenter` package +// should be used for actual processing. +message IdentityCenterAccountAssignment { + // Kind is the database server resource kind. + string Kind = 1 [(gogoproto.jsontag) = "kind"]; + // SubKind is an optional resource subkind. + string SubKind = 2 [(gogoproto.jsontag) = "sub_kind,omitempty"]; + // Version is the resource version. + string Version = 3 [(gogoproto.jsontag) = "version"]; + // Metadata is the account metadata. 
+ types.Metadata Metadata = 4 [ + (gogoproto.nullable) = false, + (gogoproto.jsontag) = "metadata" + ]; + + // DisplayName is a human-readable name for the Account assignment + string DisplayName = 5; + + // Account is the Identity Center Account this assigment references + IdentityCenterAccount Account = 6; + + // PermissionSet is the Identity Center Permission Set this assignment + // references + IdentityCenterPermissionSet PermissionSet = 7; +} + // PaginatedResource represents one of the supported resources. message PaginatedResource { // Resource is the resource itself. @@ -1962,6 +2022,12 @@ message PaginatedResource { types.AppServerOrSAMLIdPServiceProviderV1 AppServerOrSAMLIdPServiceProvider = 11 [deprecated = true]; // SAMLIdPServiceProvider represents a SAML IdP service provider resource. types.SAMLIdPServiceProviderV1 SAMLIdPServiceProvider = 12 [(gogoproto.jsontag) = "saml_idp_service_provider,omitempty"]; + // GitServer represents a Git server resource. + types.ServerV2 git_server = 15; + + // IdentityCenterAccountAssignment represents a requestable Identity Center + // Account Assignment + IdentityCenterAccountAssignment IdentityCenterAccountAssignment = 16 [(gogoproto.jsontag) = "identity_center_account_assignment,omitempty"]; } // Logins allowed for the included resource. Only to be populated for SSH and Desktops. @@ -2079,6 +2145,36 @@ message ListResourcesRequest { bool IncludeLogins = 13 [(gogoproto.jsontag) = "include_logins,omitempty"]; } +// ResolveSSHTargetRequest provides details about a server to be resolved in +// an equivalent manner to a ssh dial request. +// +// Resolution can happen in two modes: +// 1) searching for hosts based on labels, a predicate expression, or keywords +// 2) searching based on hostname +// +// If a Host is provided, resolution will only operate in the second mode and +// will not perform any resolution based on labels. In order to resolve via +// labels the Host must not be populated. 
+message ResolveSSHTargetRequest { + // The target host as would be sent to the proxy during a dial request. + string host = 1; + // The ssh port. This value is optional, and both empty string and "0" are typically + // treated as meaning that any port should match. + string port = 2; + // If not empty, a label-based matcher. + map labels = 3; + // Boolean conditions that will be matched against the resource. + string predicate_expression = 4; + // A list of search keywords to match against resource field values. + repeated string search_keywords = 5; +} + +// GetSSHTargetsResponse holds ssh servers that match an ssh targets request. +message ResolveSSHTargetResponse { + // The target matching the supplied request. + types.ServerV2 server = 1; +} + // GetSSHTargetsRequest gets all servers that might match an equivalent ssh dial request. message GetSSHTargetsRequest { // Host is the target host as would be sent to the proxy during a dial request. @@ -2349,16 +2445,21 @@ message DownstreamInventoryOneOf { } } -// DownstreamInventoryPing is sent down the inventory control stream for testing/debug -// purposes. +// DownstreamInventoryPing is sent down the inventory control stream. message DownstreamInventoryPing { uint64 ID = 1; } // UpstreamInventoryPong is sent up the inventory control stream in response to a downstream -// ping (used for testing/debug purposes). +// ping including the system clock of the downstream. message UpstreamInventoryPong { uint64 ID = 1; + // SystemClock advertises the system clock of the upstream. + google.protobuf.Timestamp SystemClock = 2 [ + (gogoproto.stdtime) = true, + (gogoproto.nullable) = false, + (gogoproto.jsontag) = "system_clock,omitempty" + ]; } // UpstreamInventoryHello is the hello message sent up the inventory control stream. @@ -2553,10 +2654,10 @@ message InventoryPingRequest { // ServerID is the ID of the instance to ping. 
string ServerID = 1; - // ControlLog forces the ping to use the standard "commit then act" model of control log synchronization - // for the ping. This significantly increases the amount of time it takes for the ping request to - // complete, but is useful for testing/debugging control log issues. - bool ControlLog = 2; + // ControlLog used to signal that the ping should use the control log synchronization. + // + // Deprecated: the control log is unsupported and unsound to use. + bool ControlLog = 2 [deprecated = true]; } // InventoryPingResponse returns the result of an inventory ping initiated via an @@ -3207,7 +3308,11 @@ service AuthService { // GetTrustedClusters gets all current Trusted Cluster resources. rpc GetTrustedClusters(google.protobuf.Empty) returns (types.TrustedClusterV2List); // UpsertTrustedCluster upserts a Trusted Cluster in a backend. - rpc UpsertTrustedCluster(types.TrustedClusterV2) returns (types.TrustedClusterV2); + // + // Deprecated: Use [teleport.trust.v1.UpsertTrustedCluster] instead. + rpc UpsertTrustedCluster(types.TrustedClusterV2) returns (types.TrustedClusterV2) { + option deprecated = true; + } // DeleteTrustedCluster deletes an existing Trusted Cluster in a backend by name. rpc DeleteTrustedCluster(types.ResourceRequest) returns (google.protobuf.Empty); @@ -3482,6 +3587,9 @@ service AuthService { // but may result in confusing behavior if it is used outside of those contexts. rpc GetSSHTargets(GetSSHTargetsRequest) returns (GetSSHTargetsResponse); + // ResolveSSHTarget returns the server that would be resolved in an equivalent ssh dial request. 
+ rpc ResolveSSHTarget(ResolveSSHTargetRequest) returns (ResolveSSHTargetResponse); + // GetDomainName returns local auth domain of the current auth server rpc GetDomainName(google.protobuf.Empty) returns (GetDomainNameResponse); // GetClusterCACert returns the PEM-encoded TLS certs for the local cluster diff --git a/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/client/proto/event.proto b/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/client/proto/event.proto index 7c0cd043eb13d..b8c39fec6054e 100644 --- a/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/client/proto/event.proto +++ b/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/client/proto/event.proto @@ -34,6 +34,7 @@ import "teleport/secreports/v1/secreports.proto"; import "teleport/userloginstate/v1/userloginstate.proto"; import "teleport/userprovisioning/v2/statichostuser.proto"; import "teleport/usertasks/v1/user_tasks.proto"; +import "teleport/workloadidentity/v1/resource.proto"; option go_package = "github.com/gravitational/teleport/api/client/proto"; @@ -206,5 +207,9 @@ message Event { // IdentityCenterAccountlAssignment is a resource representing a potential // Permission Set grant on a specific AWS account. teleport.identitycenter.v1.AccountAssignment IdentityCenterAccountAssignment = 74; + // PluginStaticCredentials is filled in PluginStaticCredentials related events + types.PluginStaticCredentialsV1 PluginStaticCredentials = 75; + // WorkloadIdentity is a resource for workload identity. 
+ teleport.workloadidentity.v1.WorkloadIdentity WorkloadIdentity = 76; } } diff --git a/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/types/events/events.proto b/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/types/events/events.proto index bd61c99381b62..c82a6e6976e0b 100644 --- a/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/types/events/events.proto +++ b/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/types/events/events.proto @@ -1547,6 +1547,33 @@ message AccessRequestCreate { ]; } +// AccessRequestExpire is emitted when access request has expired. +message AccessRequestExpire { + // Metadata is a common event metadata + Metadata Metadata = 1 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // ResourceMetadata is a common resource event metadata + ResourceMetadata Resource = 2 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // RequestID is access request ID + string RequestID = 3 [(gogoproto.jsontag) = "id"]; + + // ResourceExpiry is the time at which the access request resource will expire. + google.protobuf.Timestamp ResourceExpiry = 4 [ + (gogoproto.stdtime) = true, + (gogoproto.nullable) = true, + (gogoproto.jsontag) = "expiry,omitempty" + ]; +} + // ResourceID is a unique identifier for a teleport resource. This is duplicated // from api/types/types.proto to decouple the api and events types and because // neither file currently imports the other. @@ -1617,6 +1644,21 @@ message PortForward { // Addr is a target port forwarding address string Addr = 5 [(gogoproto.jsontag) = "addr"]; + + // KubernetesCluster has information about a kubernetes cluster, if + // applicable. + KubernetesClusterMetadata KubernetesCluster = 6 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // KubernetesPod has information about a kubernetes pod, if applicable. 
+ KubernetesPodMetadata KubernetesPod = 7 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; } // X11Forward is emitted when a user requests X11 protocol forwarding @@ -2675,6 +2717,9 @@ message AppMetadata { ]; // AppName is the configured application name. string AppName = 4 [(gogoproto.jsontag) = "app_name,omitempty"]; + // AppTargetPort signifies that the app is a multi-port TCP app and says which port was used to + // access the app. This field is not set for other types of apps, including single-port TCP apps. + uint32 AppTargetPort = 5 [(gogoproto.jsontag) = "app_target_port,omitempty"]; } // AppCreate is emitted when a new application resource is created. @@ -3105,6 +3150,12 @@ message DatabaseSessionStart { // connection. This can be useful for backend process cancellation or // termination and it is not a sensitive or secret value. uint32 PostgresPID = 8 [(gogoproto.jsontag) = "postgres_pid,omitempty"]; + // Client is the common client event metadata. + ClientMetadata Client = 9 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; } // DatabaseSessionQuery is emitted when a user executes a database query. @@ -4330,6 +4381,8 @@ message IntegrationMetadata { AWSOIDCIntegrationMetadata AWSOIDC = 2 [(gogoproto.jsontag) = "aws_oidc,omitempty"]; // AzureOIDC contains metadata for Azure OIDC integrations. AzureOIDCIntegrationMetadata AzureOIDC = 3 [(gogoproto.jsontag) = "azure_oidc,omitempty"]; + // GitHub contains metadata for GitHub integrations. + GitHubIntegrationMetadata GitHub = 4 [(gogoproto.jsontag) = "github,omitempty"]; } // AWSOIDCIntegrationMetadata contains metadata for AWS OIDC integrations. @@ -4351,6 +4404,12 @@ message AzureOIDCIntegrationMetadata { string ClientID = 2 [(gogoproto.jsontag) = "client_id,omitempty"]; } +// GitHubIntegrationMetadata contains metadata for GitHub integrations. 
+message GitHubIntegrationMetadata { + // Organization specifies the name of the organization for the GitHub integration. + string Organization = 1 [(gogoproto.jsontag) = "organization,omitempty"]; +} + // PluginCreate is emitted when a plugin resource is created. message PluginCreate { // Metadata is a common event metadata. @@ -4676,6 +4735,14 @@ message OneOf { events.UserTaskUpdate UserTaskUpdate = 189; events.UserTaskDelete UserTaskDelete = 190; events.SFTPSummary SFTPSummary = 191; + events.ContactCreate ContactCreate = 192; + events.ContactDelete ContactDelete = 193; + events.WorkloadIdentityCreate WorkloadIdentityCreate = 194; + events.WorkloadIdentityUpdate WorkloadIdentityUpdate = 195; + events.WorkloadIdentityDelete WorkloadIdentityDelete = 196; + events.GitCommand GitCommand = 197; + events.UserLoginAccessListInvalid UserLoginAccessListInvalid = 198; + events.AccessRequestExpire AccessRequestExpire = 199; } } @@ -4829,6 +4896,9 @@ message RouteToApp { string GCPServiceAccount = 7 [(gogoproto.jsontag) = "gcp_service_account,omitempty"]; // URI is the application URI. string URI = 8 [(gogoproto.jsontag) = "uri,omitempty"]; + // TargetPort signifies that the user accessed a specific port in a multi-port TCP app. The value + // must be between 1 and 65535. + uint32 TargetPort = 9 [(gogoproto.jsontag) = "target_port,omitempty"]; } // RouteToDatabase combines parameters for database service routing information. @@ -6697,6 +6767,12 @@ message SPIFFESVIDIssued { // Audiences is the list of audiences in the issued SVID. // Only present if the SVID is a JWT. repeated string Audiences = 11 [(gogoproto.jsontag) = "audiences,omitempty"]; + // The WorkloadIdentity resource that was used to issue the SVID, this will + // be empty if the legacy RPCs were used. + string WorkloadIdentity = 12 [(gogoproto.jsontag) = "workload_identity,omitempty"]; + // The revision of the WorkloadIdentity resource that was used to issue the + // SVID. 
This will be empty if the legacy RPCs were used.
+  string WorkloadIdentityRevision = 13 [(gogoproto.jsontag) = "workload_identity_revision,omitempty"];
 }
 
 // AuthPreferenceUpdate is emitted when the auth preference is updated.
@@ -7566,3 +7642,305 @@ message UserTaskDelete {
     (gogoproto.jsontag) = ""
   ];
 }
+
+// ContactCreate is emitted when a contact is created.
+message ContactCreate {
+  // Metadata is a common event metadata
+  Metadata Metadata = 1 [
+    (gogoproto.nullable) = false,
+    (gogoproto.embed) = true,
+    (gogoproto.jsontag) = ""
+  ];
+
+  // ResourceMetadata is a common resource event metadata
+  ResourceMetadata Resource = 2 [
+    (gogoproto.nullable) = false,
+    (gogoproto.embed) = true,
+    (gogoproto.jsontag) = ""
+  ];
+
+  // User is a common user event metadata
+  UserMetadata User = 3 [
+    (gogoproto.nullable) = false,
+    (gogoproto.embed) = true,
+    (gogoproto.jsontag) = ""
+  ];
+
+  // ConnectionMetadata holds information about the connection
+  ConnectionMetadata Connection = 4 [
+    (gogoproto.nullable) = false,
+    (gogoproto.embed) = true,
+    (gogoproto.jsontag) = ""
+  ];
+
+  // Status indicates whether the creation was successful.
+  Status Status = 5 [
+    (gogoproto.nullable) = false,
+    (gogoproto.embed) = true,
+    (gogoproto.jsontag) = ""
+  ];
+
+  // Email is the Email of the contact being created
+  string Email = 6 [(gogoproto.jsontag) = "email"];
+
+  // ContactType is the type of the contact being created ('Business' or 'Security')
+  ContactType ContactType = 7 [(gogoproto.jsontag) = "contact_type"];
+}
+
+// ContactDelete is emitted when a contact is deleted.
+message ContactDelete { + // Metadata is a common event metadata + Metadata Metadata = 1 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // ResourceMetadata is a common resource event metadata + ResourceMetadata Resource = 2 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // User is a common user event metadata + UserMetadata User = 3 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // ConnectionMetadata holds information about the connection + ConnectionMetadata Connection = 4 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // Status indicates whether the deletion was successful. + Status Status = 5 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // Email is the Email of the contact being deleted + string Email = 6 [(gogoproto.jsontag) = "email"]; + + // ContactType is the type of the contact being deleted ('Business' or 'Security') + ContactType ContactType = 7 [(gogoproto.jsontag) = "contact_type"]; +} + +// ContactType is the type of contact being added. +enum ContactType { + CONTACT_TYPE_UNSPECIFIED = 0; + CONTACT_TYPE_BUSINESS = 1; + CONTACT_TYPE_SECURITY = 2; +} + +// WorkloadIdentityCreate is emitted when a WorkloadIdentity is created. 
+message WorkloadIdentityCreate { + // Metadata is a common event metadata + Metadata Metadata = 1 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // ResourceMetadata is a common resource event metadata + ResourceMetadata Resource = 2 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // User is a common user event metadata + UserMetadata User = 3 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // ConnectionMetadata holds information about the connection + ConnectionMetadata Connection = 4 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // WorkloadIdentityData is a copy of the WorkloadIdentity resource + google.protobuf.Struct WorkloadIdentityData = 5 [ + (gogoproto.jsontag) = "workload_identity_data,omitempty", + (gogoproto.casttype) = "Struct" + ]; +} + +// WorkloadIdentityUpdate is emitted when a WorkloadIdentity is updated. 
+message WorkloadIdentityUpdate { + // Metadata is a common event metadata + Metadata Metadata = 1 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // ResourceMetadata is a common resource event metadata + ResourceMetadata Resource = 2 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // User is a common user event metadata + UserMetadata User = 3 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // ConnectionMetadata holds information about the connection + ConnectionMetadata Connection = 4 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // WorkloadIdentityData is a copy of the WorkloadIdentity resource + google.protobuf.Struct WorkloadIdentityData = 5 [ + (gogoproto.jsontag) = "workload_identity_data,omitempty", + (gogoproto.casttype) = "Struct" + ]; +} + +// WorkloadIdentityDelete is emitted when a WorkloadIdentity is deleted. +message WorkloadIdentityDelete { + // Metadata is a common event metadata + Metadata Metadata = 1 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // ResourceMetadata is a common resource event metadata + ResourceMetadata Resource = 2 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // User is a common user event metadata + UserMetadata User = 3 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // ConnectionMetadata holds information about the connection + ConnectionMetadata Connection = 4 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; +} + +// GitCommand is emitted when a user performs a Git fetch or push command. 
+message GitCommand { + // Metadata is a common event metadata + Metadata Metadata = 1 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // User is a common user event metadata + UserMetadata User = 2 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // ConnectionMetadata holds information about the connection + ConnectionMetadata Connection = 3 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // SessionMetadata is a common event session metadata + SessionMetadata Session = 4 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // ServerMetadata is a common server metadata + ServerMetadata Server = 5 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // CommandMetadata is a common command metadata + CommandMetadata Command = 6 [ + (gogoproto.nullable) = false, + (gogoproto.embed) = true, + (gogoproto.jsontag) = "" + ]; + + // Service is the type of the git request like git-upload-pack or + // git-receive-pack. + string service = 8 [(gogoproto.jsontag) = "service"]; + // Path is the Git repo path, usually /. + string path = 9 [(gogoproto.jsontag) = "path"]; + + // Actions defines details for a Git push. + repeated GitCommandAction actions = 10 [(gogoproto.jsontag) = "actions,omitempty"]; +} + +// GitCommandAction defines details for a Git push. +message GitCommandAction { + // Action type like create or update. + string Action = 1 [(gogoproto.jsontag) = "action,omitempty"]; + // Reference name like ref/main/my_branch. + string Reference = 2 [(gogoproto.jsontag) = "reference,omitempty"]; + // Old is the old hash. + string Old = 3 [(gogoproto.jsontag) = "old,omitempty"]; + // New is the new hash. + string New = 4 [(gogoproto.jsontag) = "new,omitempty"]; +} + +// AccessListInvalidMetadata contains metadata about invalid access lists. 
+message AccessListInvalidMetadata {
+  // AccessListName is the name of the invalid access list.
+  string AccessListName = 1 [(gogoproto.jsontag) = "access_list_name,omitempty"];
+  // User is the username of the access list member who attempted to log in.
+  string User = 2 [(gogoproto.jsontag) = "user,omitempty"];
+  // MissingRoles are the names of the non-existent roles being referenced by the access list, causing it to be invalid.
+  repeated string MissingRoles = 3 [(gogoproto.jsontag) = "missing_roles,omitempty"];
+}
+
+// UserLoginAccessListInvalid is emitted when a user who is a member of an invalid
+// access list logs in. It is used to indicate that the access list could not be
+// applied to the user's session.
+message UserLoginAccessListInvalid {
+  // Metadata is common event metadata
+  Metadata Metadata = 1 [
+    (gogoproto.nullable) = false,
+    (gogoproto.embed) = true,
+    (gogoproto.jsontag) = ""
+  ];
+
+  // AccessListInvalidMetadata is the metadata for this access list invalid event.
+  AccessListInvalidMetadata AccessListInvalidMetadata = 2 [
+    (gogoproto.nullable) = false,
+    (gogoproto.embed) = true,
+    (gogoproto.jsontag) = ""
+  ];
+
+  // Status contains fields to indicate whether attempt was successful or not.
+  Status Status = 3 [
+    (gogoproto.nullable) = false,
+    (gogoproto.embed) = true,
+    (gogoproto.jsontag) = ""
+  ];
+}
diff --git a/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/types/trusted_device_requirement.proto b/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/types/trusted_device_requirement.proto
new file mode 100644
index 0000000000000..9f074c9e76465
--- /dev/null
+++ b/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/types/trusted_device_requirement.proto
@@ -0,0 +1,37 @@
+// Copyright 2024 Gravitational, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package types; + +import "gogoproto/gogo.proto"; + +option go_package = "github.com/gravitational/teleport/api/types"; +option (gogoproto.goproto_getters_all) = false; +option (gogoproto.marshaler_all) = true; +option (gogoproto.unmarshaler_all) = true; + +// TrustedDeviceRequirement indicates whether access may be hindered by the lack +// of a trusted device. +enum TrustedDeviceRequirement { + // Device requirement not determined. + // Does not mean that a device is not required, only that the necessary data + // was not considered. + TRUSTED_DEVICE_REQUIREMENT_UNSPECIFIED = 0; + // Trusted device not required. + TRUSTED_DEVICE_REQUIREMENT_NOT_REQUIRED = 1; + // Trusted device required by either cluster mode or user roles. 
+ TRUSTED_DEVICE_REQUIREMENT_REQUIRED = 2; +} diff --git a/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/types/types.proto b/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/types/types.proto index 7739ecad6c7a0..f241f6501956e 100644 --- a/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/types/types.proto +++ b/integrations/operator/crdgen/testdata/protofiles/teleport/legacy/types/types.proto @@ -21,6 +21,7 @@ import "google/protobuf/duration.proto"; import "google/protobuf/timestamp.proto"; import "google/protobuf/wrappers.proto"; import "teleport/attestation/v1/attestation.proto"; +import "teleport/legacy/types/trusted_device_requirement.proto"; import "teleport/legacy/types/wrappers/wrappers.proto"; option go_package = "github.com/gravitational/teleport/api/types"; @@ -711,6 +712,31 @@ message InstanceSpecV1 { // ExternalUpgraderVersion identifies the external upgrader version. Empty if no upgrader is defined. string ExternalUpgraderVersion = 8 [(gogoproto.jsontag) = "ext_upgrader_version,omitempty"]; + + // LastMeasurement stores information about the latest measurement between services. + SystemClockMeasurement LastMeasurement = 9; +} + +// SystemClockMeasurement represents the measurement state of the systems clock difference. +message SystemClockMeasurement { + // ControllerSystemClock is the system clock of the inventory controller. + google.protobuf.Timestamp ControllerSystemClock = 1 [ + (gogoproto.stdtime) = true, + (gogoproto.nullable) = false, + (gogoproto.jsontag) = "controller_system_clock,omitempty" + ]; + // SystemClock is the system clock of the upstream. + google.protobuf.Timestamp SystemClock = 2 [ + (gogoproto.stdtime) = true, + (gogoproto.nullable) = false, + (gogoproto.jsontag) = "system_clock,omitempty" + ]; + // RequestDuration stores information about the request duration between auth and remote service. 
+ google.protobuf.Duration RequestDuration = 3 [ + (gogoproto.jsontag) = "request_duration", + (gogoproto.nullable) = false, + (gogoproto.stdduration) = true + ]; } // InstanceControlLogEntry represents an entry in a given instance's control log. The control log of @@ -842,6 +868,9 @@ message ServerSpecV2 { // CloudMetadata contains info about the cloud instance the server is running // on, if any. CloudMetadata CloudMetadata = 14 [(gogoproto.jsontag) = "cloud_metadata,omitempty"]; + // GitHub contains info about GitHub proxies where each server represents a + // GitHub organization. + GitHubServerMetadata git_hub = 15 [(gogoproto.jsontag) = "github,omitempty"]; reserved 8; reserved 10; @@ -875,6 +904,15 @@ message CloudMetadata { AWSInfo AWS = 1 [(gogoproto.jsontag) = "aws,omitempty"]; } +// GitHubServerMetadata contains info about GitHub proxies where each server +// represents a GitHub organization. +message GitHubServerMetadata { + // Organization specifies the name of the organization for the GitHub integration. + string organization = 1 [(gogoproto.jsontag) = "organization,omitempty"]; + // Integration is the integration that is associated with this Server. + string integration = 2 [(gogoproto.jsontag) = "integration,omitempty"]; +} + // AppServerV3 represents a single proxied web app. message AppServerV3 { option (gogoproto.goproto_stringer) = false; @@ -971,6 +1009,10 @@ message IdentityCenterPermissionSet { // Name is the human-readable name of the Permission Set. string Name = 2 [(gogoproto.jsontag) = "name,omitempty"]; + + // AssignmentID is the ID of the Teelport Account Assignment resource that + // represents this permission being assigned on the enclosing Account. + string AssignmentID = 3 [(gogoproto.jsontag) = "assignment_name,omitempty"]; } // AppIdentityCenter encapsulates information about an AWS Identity Center @@ -1016,6 +1058,11 @@ message AppSpecV3 { // IdentityCenter encasulates AWS identity-center specific information. 
Only // valid for Identity Center account apps. AppIdentityCenter IdentityCenter = 12 [(gogoproto.jsontag) = "identity_center,omitempty"]; + // TCPPorts is a list of ports and port ranges that an app agent can forward connections to. + // Only applicable to TCP App Access. + // If this field is not empty, URI is expected to contain no port number and start with the tcp + // protocol. + repeated PortRange TCPPorts = 13 [(gogoproto.jsontag) = "tcp_ports,omitempty"]; } // AppServerOrSAMLIdPServiceProviderV1 holds either an AppServerV3 or a SAMLIdPServiceProviderV1 resource (never both). @@ -1057,6 +1104,20 @@ message Header { string Value = 2 [(gogoproto.jsontag) = "value"]; } +// PortRange describes a port range for TCP apps. The range starts with Port and ends with EndPort. +// PortRange can be used to describe a single port in which case the Port field is the port and the +// EndPort field is 0. +message PortRange { + option (gogoproto.goproto_stringer) = false; + option (gogoproto.stringer) = false; + // Port describes the start of the range. It must be between 1 and 65535. + uint32 Port = 1 [(gogoproto.jsontag) = "port"]; + // EndPort describes the end of the range, inclusive. If set, it must be between 2 and 65535 and + // be greater than Port when describing a port range. When omitted or set to zero, it signifies + // that the port range defines a single port. + uint32 EndPort = 2 [(gogoproto.jsontag) = "end_port,omitempty"]; +} + // CommandLabelV2 is a label that has a value as a result of the // output generated by running command, e.g. hostname message CommandLabelV2 { @@ -2668,6 +2729,13 @@ message AccessRequestSpecV3 { (gogoproto.nullable) = true, (gogoproto.jsontag) = "assume_start_time,omitempty" ]; + + // ResourceExpiry is the time at which the access request resource will expire. 
+ google.protobuf.Timestamp ResourceExpiry = 22 [ + (gogoproto.stdtime) = true, + (gogoproto.nullable) = true, + (gogoproto.jsontag) = "expiry,omitempty" + ]; } enum AccessRequestScope { @@ -2759,6 +2827,7 @@ message RequestKubernetesResource { } // ResourceID is a unique identifier for a teleport resource. +// Must be kept in sync with teleport.decision.v1alpha1.ResourceId. message ResourceID { // ClusterName is the name of the cluster the resource is in. string ClusterName = 1 [(gogoproto.jsontag) = "cluster"]; @@ -3372,6 +3441,24 @@ message RoleConditions { (gogoproto.nullable) = false, (gogoproto.jsontag) = "account_assignments,omitempty" ]; + + // GitHubPermissions defines GitHub integration related permissions. + repeated GitHubPermission git_hub_permissions = 43 [ + (gogoproto.nullable) = false, + (gogoproto.jsontag) = "github_permissions,omitempty" + ]; + + // WorkloadIdentityLabels controls whether or not specific WorkloadIdentity + // resources can be invoked. Further authorization controls exist on the + // WorkloadIdentity resource itself. + wrappers.LabelValues WorkloadIdentityLabels = 44 [ + (gogoproto.nullable) = false, + (gogoproto.jsontag) = "workload_identity_labels,omitempty", + (gogoproto.customtype) = "Labels" + ]; + // WorkloadIdentityLabelsExpression is a predicate expression used to + // allow/deny access to issuing a WorkloadIdentity. + string WorkloadIdentityLabelsExpression = 45 [(gogoproto.jsontag) = "workload_identity_labels_expression,omitempty"]; } // IdentityCenterAccountAssignment captures an AWS Identity Center account @@ -3381,6 +3468,11 @@ message IdentityCenterAccountAssignment { string Account = 2 [(gogoproto.jsontag) = "account,omitempty"]; } +// GitHubPermission defines GitHub integration related permissions. +message GitHubPermission { + repeated string organizations = 1 [(gogoproto.jsontag) = "orgs,omitempty"]; +} + // SPIFFERoleCondition sets out which SPIFFE identities this role is allowed or // denied to generate. 
The Path matcher is required, and is evaluated first. If, // the Path does not match then the other matcher fields are not evaluated. @@ -3803,6 +3895,10 @@ message ExternalIdentity { // SAMLSingleLogoutURL is the SAML Single log-out URL to initiate SAML SLO (single log-out), if applicable. string SAMLSingleLogoutURL = 3 [(gogoproto.jsontag) = "samlSingleLogoutUrl,omitempty"]; + + // UserID is the ID of the identity. Some connectors like GitHub have an + // unique ID apart from the username. + string UserID = 4 [(gogoproto.jsontag) = "user_id,omitempty"]; } // LoginStatus is a login status of the user @@ -4246,19 +4342,6 @@ message WebSessionSpecV2 { bytes TLSPriv = 15 [(gogoproto.jsontag) = "tls_priv,omitempty"]; } -// TrustedDeviceRequirement indicates whether access may be hindered by the lack -// of a trusted device. -enum TrustedDeviceRequirement { - // Device requirement not determined. - // Does not mean that a device is not required, only that the necessary data - // was not considered. - TRUSTED_DEVICE_REQUIREMENT_UNSPECIFIED = 0; - // Trusted device not required. - TRUSTED_DEVICE_REQUIREMENT_NOT_REQUIRED = 1; - // Trusted device required by either cluster mode or user roles. - TRUSTED_DEVICE_REQUIREMENT_REQUIRED = 2; -} - // Web-focused view of teleport.devicetrust.v1.DeviceWebToken. message DeviceWebToken { // Opaque token identifier. @@ -5193,7 +5276,7 @@ message GithubAuthRequest { string KubernetesCluster = 13 [(gogoproto.jsontag) = "kubernetes_cluster,omitempty"]; // SSOTestFlow indicates if the request is part of the test flow. bool SSOTestFlow = 14 [(gogoproto.jsontag) = "sso_test_flow"]; - // ConnectorSpec is embedded connector spec for use in test flow. + // ConnectorSpec is embedded connector spec for use in test flow or authenticated user flow. GithubConnectorSpecV3 ConnectorSpec = 15 [(gogoproto.jsontag) = "connector_spec,omitempty"]; // AttestationStatement is an attestation statement for the given public key. 
// @@ -5217,6 +5300,10 @@ message GithubAuthRequest { teleport.attestation.v1.AttestationStatement ssh_attestation_statement = 21 [(gogoproto.jsontag) = "ssh_attestation_statement,omitempty"]; // TlsAttestationStatement is an attestation statement for the given TLS public key. teleport.attestation.v1.AttestationStatement tls_attestation_statement = 22 [(gogoproto.jsontag) = "tls_attestation_statement,omitempty"]; + // AuthenticatedUser is the username of an authenticated Teleport user. This + // OAuth flow is used to retrieve GitHub identity info which will be added to + // the existing user. + string authenticated_user = 23 [(gogoproto.jsontag) = "authenticated_user,omitempty"]; } // SSOWarnings conveys a user-facing main message along with auxiliary warnings. @@ -5387,6 +5474,12 @@ message GithubClaims { // Teams is the users team membership repeated string Teams = 3 [(gogoproto.jsontag) = "teams"]; + + // UserID is a global unique integer that is assigned to each GitHub user. The + // user ID is immutable (unlike the GitHub username) and can be found in APIs + // like get user. + // https://docs.github.com/en/rest/users/users + string UserID = 4 [(gogoproto.jsontag) = "user_id,omitempty"]; } // TeamMapping represents a single team membership mapping. @@ -6384,6 +6477,8 @@ message PluginSpecV1 { PluginEmailSettings email = 17; // Settings for the Microsoft Teams plugin PluginMSTeamsSettings msteams = 18; + // Settings for the OpenTex NetIQ plugin. + PluginNetIQSettings net_iq = 19; } // generation contains a unique ID that should: @@ -6810,6 +6905,18 @@ message PluginMSTeamsSettings { string default_recipient = 5; } +// PluginNetIQSettings defines the settings for a NetIQ integration plugin +message PluginNetIQSettings { + option (gogoproto.equal) = true; + // oauth_issuer_endpoint is the NetIQ Oauth Issuer endpoint. 
+ // Usually, it's equal to https://osp.domain.ext/a/idm/auth/oauth2 + string oauth_issuer_endpoint = 1; + // api_endpoint is the IDM PROV Rest API location. + string api_endpoint = 2; + // insecure_skip_verify controls whether the NetIQ certificate validation should be skipped. + bool insecure_skip_verify = 3; +} + message PluginBootstrapCredentialsV1 { oneof credentials { PluginOAuth2AuthorizationCodeCredentials oauth2_authorization_code = 1; @@ -6848,6 +6955,8 @@ message PluginStatusV1 { PluginOktaStatusV1 okta = 7; // AWSIC holds status details for the AWS Identity Center plugin. PluginAWSICStatusV1 aws_ic = 8; + // NetIQ holds status details for the NetIQ plugin. + PluginNetIQStatusV1 net_iq = 9; } // last_raw_error variable stores the most recent raw error message received from an API or service. @@ -6857,6 +6966,18 @@ message PluginStatusV1 { string last_raw_error = 6; } +// PluginNetIQStatusV1 is the status details for the NetIQ plugin. +message PluginNetIQStatusV1 { + // imported_users is the number of users imported from NetIQ eDirectory. + uint32 imported_users = 1; + // imported_groups is the number of groups imported from NetIQ eDirectory. + uint32 imported_groups = 2; + // imported_roles is the number of roles imported from NetIQ eDirectory. + uint32 imported_roles = 3; + // imported_resources is the number of resources imported from NetIQ eDirectory. + uint32 imported_resources = 4; +} + // PluginGitlabStatusV1 is the status details for the Gitlab plugin. message PluginGitlabStatusV1 { // imported_users is the number of users imported from Gitlab. 
@@ -7119,6 +7240,7 @@ message PluginStaticCredentialsSpecV1 { string APIToken = 1; PluginStaticCredentialsBasicAuth BasicAuth = 2; PluginStaticCredentialsOAuthClientSecret OAuthClientSecret = 3; + PluginStaticCredentialsSSHCertAuthorities SSHCertAuthorities = 4; } } @@ -7140,6 +7262,14 @@ message PluginStaticCredentialsOAuthClientSecret { string ClientSecret = 2 [(gogoproto.jsontag) = "client_secret"]; } +// PluginStaticCredentialsSSHCertAuthorities contains the active SSH CAs used +// for the integration or plugin. +message PluginStaticCredentialsSSHCertAuthorities { + // CertAuthorities contains the active SSH CAs used for the integration or + // plugin. + repeated SSHKeyPair cert_authorities = 1; +} + // SAMLIdPServiceProviderV1 is the representation of a SAML IdP service provider. message SAMLIdPServiceProviderV1 { option (gogoproto.goproto_stringer) = false; @@ -7487,7 +7617,12 @@ message IntegrationSpecV1 { AWSOIDCIntegrationSpecV1 AWSOIDC = 1 [(gogoproto.jsontag) = "aws_oidc,omitempty"]; // AzureOIDC contains the specific fields to handle the Azure OIDC Integration subkind AzureOIDCIntegrationSpecV1 AzureOIDC = 2 [(gogoproto.jsontag) = "azure_oidc,omitempty"]; + // GitHub contains the specific fields to handle the GitHub integration subkind. + GitHubIntegrationSpecV1 GitHub = 3 [(gogoproto.jsontag) = "github,omitempty"]; } + + // Credentials contains credentials for the integration. + PluginCredentialsV1 credentials = 4; } // AWSOIDCIntegrationSpecV1 contains the spec properties for the AWS OIDC SubKind Integration. @@ -7532,6 +7667,12 @@ message AzureOIDCIntegrationSpecV1 { string ClientID = 2 [(gogoproto.jsontag) = "client_id,omitempty"]; } +// GitHubIntegrationSpecV1 contains the specific fields to handle the GitHub integration subkind. +message GitHubIntegrationSpecV1 { + // Organization specifies the name of the organization for the GitHub integration. 
+  string Organization = 1 [(gogoproto.jsontag) = "organization,omitempty"];
+}
+
 // HeadlessAuthentication holds data for an ongoing headless authentication attempt.
 message HeadlessAuthentication {
   // Header is the resource header.
@@ -7920,12 +8061,14 @@ message OktaOptions {
 message AccessGraphSync {
   // AWS is a configuration for AWS Access Graph service poll service.
   repeated AccessGraphAWSSync AWS = 1 [(gogoproto.jsontag) = "aws,omitempty"];
-  // PollInterval is the frequency at which to poll for AWS resources
+  // PollInterval is the frequency at which to poll for resources
   google.protobuf.Duration PollInterval = 2 [
     (gogoproto.jsontag) = "poll_interval,omitempty",
     (gogoproto.nullable) = false,
     (gogoproto.stdduration) = true
   ];
+  // Azure is a configuration for Azure Access Graph service poll service.
+  repeated AccessGraphAzureSync Azure = 3 [(gogoproto.jsontag) = "azure,omitempty"];
 }
 
 // AccessGraphAWSSync is a configuration for AWS Access Graph service poll service.
@@ -7937,3 +8080,11 @@ message AccessGraphAWSSync {
   // Integration is the integration name used to generate credentials to interact with AWS APIs.
   string Integration = 4 [(gogoproto.jsontag) = "integration,omitempty"];
 }
+
+// AccessGraphAzureSync is a configuration for Azure Access Graph service poll service.
+message AccessGraphAzureSync {
+  // SubscriptionID is the ID of the Azure subscription to sync resources from
+  string SubscriptionID = 1 [(gogoproto.jsontag) = "subscription_id,omitempty"];
+  // Integration is the integration name used to generate credentials to interact with Azure APIs.
+ string Integration = 2 [(gogoproto.jsontag) = "integration,omitempty"]; +} diff --git a/integrations/operator/hack/fixture-operator-role.yaml b/integrations/operator/hack/fixture-operator-role.yaml index e9925b19a106c..ac6e88a6dfbd1 100644 --- a/integrations/operator/hack/fixture-operator-role.yaml +++ b/integrations/operator/hack/fixture-operator-role.yaml @@ -73,5 +73,13 @@ spec: - read - update - delete + - resources: + - trusted_cluster + verbs: + - list + - create + - read + - update + - delete deny: {} version: v7 diff --git a/lib/auth/trustedcluster.go b/lib/auth/trustedcluster.go index a02e8f4b74de6..c6a11a6d5e5db 100644 --- a/lib/auth/trustedcluster.go +++ b/lib/auth/trustedcluster.go @@ -84,7 +84,6 @@ func (a *Server) UpdateTrustedCluster(ctx context.Context, tc types.TrustedClust if err != nil { return nil, trace.Wrap(err) } - updated, err := a.updateTrustedCluster(ctx, tc, existingCluster) return updated, trace.Wrap(err) } From b6e2badbe969786d69d6c119a4472cbf4005cb61 Mon Sep 17 00:00:00 2001 From: Brian Joerger Date: Mon, 13 Jan 2025 09:57:54 -0800 Subject: [PATCH 11/15] Fix Per-session MFA for desktops (#50793) * Add sendChallengeResponse implementation for desktop sessions. * Rename useMfaTty to useMfaEmitter. 
--- .../teleport/src/Console/DocumentDb/DocumentDb.tsx | 4 ++-- .../src/Console/DocumentKubeExec/DocumentKubeExec.tsx | 4 ++-- .../teleport/src/Console/DocumentSsh/DocumentSsh.tsx | 4 ++-- .../teleport/src/DesktopSession/useDesktopSession.tsx | 4 ++-- web/packages/teleport/src/lib/tdp/client.ts | 9 +++++++++ web/packages/teleport/src/lib/useMfa.ts | 2 +- 6 files changed, 18 insertions(+), 9 deletions(-) diff --git a/web/packages/teleport/src/Console/DocumentDb/DocumentDb.tsx b/web/packages/teleport/src/Console/DocumentDb/DocumentDb.tsx index e17bed66fe6b2..780f03e1d788f 100644 --- a/web/packages/teleport/src/Console/DocumentDb/DocumentDb.tsx +++ b/web/packages/teleport/src/Console/DocumentDb/DocumentDb.tsx @@ -24,7 +24,7 @@ import AuthnDialog from 'teleport/components/AuthnDialog'; import Document from 'teleport/Console/Document'; import { Terminal, TerminalRef } from 'teleport/Console/DocumentSsh/Terminal'; import * as stores from 'teleport/Console/stores/types'; -import { useMfaTty } from 'teleport/lib/useMfa'; +import { useMfaEmitter } from 'teleport/lib/useMfa'; import { ConnectDialog } from './ConnectDialog'; import { useDbSession } from './useDbSession'; @@ -37,7 +37,7 @@ type Props = { export function DocumentDb({ doc, visible }: Props) { const terminalRef = useRef(); const { tty, status, closeDocument, sendDbConnectData } = useDbSession(doc); - const mfa = useMfaTty(tty); + const mfa = useMfaEmitter(tty); useEffect(() => { // when switching tabs or closing tabs, focus on visible terminal terminalRef.current?.focus(); diff --git a/web/packages/teleport/src/Console/DocumentKubeExec/DocumentKubeExec.tsx b/web/packages/teleport/src/Console/DocumentKubeExec/DocumentKubeExec.tsx index 5a250c9d4b3f1..1d382b40dc91c 100644 --- a/web/packages/teleport/src/Console/DocumentKubeExec/DocumentKubeExec.tsx +++ b/web/packages/teleport/src/Console/DocumentKubeExec/DocumentKubeExec.tsx @@ -25,7 +25,7 @@ import Document from 'teleport/Console/Document'; import 
useKubeExecSession from 'teleport/Console/DocumentKubeExec/useKubeExecSession'; import { Terminal, TerminalRef } from 'teleport/Console/DocumentSsh/Terminal'; import * as stores from 'teleport/Console/stores/types'; -import { useMfaTty } from 'teleport/lib/useMfa'; +import { useMfaEmitter } from 'teleport/lib/useMfa'; import KubeExecData from './KubeExecDataDialog'; @@ -38,7 +38,7 @@ export default function DocumentKubeExec({ doc, visible }: Props) { const terminalRef = useRef(); const { tty, status, closeDocument, sendKubeExecData } = useKubeExecSession(doc); - const mfa = useMfaTty(tty); + const mfa = useMfaEmitter(tty); useEffect(() => { // when switching tabs or closing tabs, focus on visible terminal terminalRef.current?.focus(); diff --git a/web/packages/teleport/src/Console/DocumentSsh/DocumentSsh.tsx b/web/packages/teleport/src/Console/DocumentSsh/DocumentSsh.tsx index 6cf952ccfc292..b7a2b93534f84 100644 --- a/web/packages/teleport/src/Console/DocumentSsh/DocumentSsh.tsx +++ b/web/packages/teleport/src/Console/DocumentSsh/DocumentSsh.tsx @@ -30,7 +30,7 @@ import { TerminalSearch } from 'shared/components/TerminalSearch'; import AuthnDialog from 'teleport/components/AuthnDialog'; import * as stores from 'teleport/Console/stores'; -import { useMfa, useMfaTty } from 'teleport/lib/useMfa'; +import { useMfa, useMfaEmitter } from 'teleport/lib/useMfa'; import { MfaChallengeScope } from 'teleport/services/auth/auth'; import { useConsoleContext } from '../consoleContextProvider'; @@ -54,7 +54,7 @@ function DocumentSsh({ doc, visible }: PropTypes) { const { tty, status, closeDocument, session } = useSshSession(doc); const [showSearch, setShowSearch] = useState(false); - const ttyMfa = useMfaTty(tty); + const ttyMfa = useMfaEmitter(tty); const ftMfa = useMfa({ isMfaRequired: ttyMfa.required, req: { diff --git a/web/packages/teleport/src/DesktopSession/useDesktopSession.tsx b/web/packages/teleport/src/DesktopSession/useDesktopSession.tsx index 
75367eeae955e..e5b1446b09b4a 100644 --- a/web/packages/teleport/src/DesktopSession/useDesktopSession.tsx +++ b/web/packages/teleport/src/DesktopSession/useDesktopSession.tsx @@ -24,7 +24,7 @@ import useAttempt from 'shared/hooks/useAttemptNext'; import type { UrlDesktopParams } from 'teleport/config'; import { ButtonState } from 'teleport/lib/tdp'; -import { useMfaTty } from 'teleport/lib/useMfa'; +import { useMfaEmitter } from 'teleport/lib/useMfa'; import desktopService from 'teleport/services/desktops'; import userService from 'teleport/services/user'; @@ -129,7 +129,7 @@ export default function useDesktopSession() { }); const tdpClient = clientCanvasProps.tdpClient; - const mfa = useMfaTty(tdpClient); + const mfa = useMfaEmitter(tdpClient); const onShareDirectory = () => { try { diff --git a/web/packages/teleport/src/lib/tdp/client.ts b/web/packages/teleport/src/lib/tdp/client.ts index 5434a504631cd..83250b8bddbc6 100644 --- a/web/packages/teleport/src/lib/tdp/client.ts +++ b/web/packages/teleport/src/lib/tdp/client.ts @@ -25,6 +25,7 @@ import init, { import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import { EventEmitterMfaSender } from 'teleport/lib/EventEmitterMfaSender'; import { TermEvent, WebsocketCloseCode } from 'teleport/lib/term/enums'; +import { MfaChallengeResponse } from 'teleport/services/mfa'; import Codec, { FileType, @@ -619,6 +620,14 @@ export default class Client extends EventEmitterMfaSender { this.send(this.codec.encodeClipboardData(clipboardData)); } + sendChallengeResponse(data: MfaChallengeResponse) { + const msg = this.codec.encodeMfaJson({ + mfaType: 'n', + jsonString: JSON.stringify(data), + }); + this.send(msg); + } + addSharedDirectory(sharedDirectory: FileSystemDirectoryHandle) { try { this.sdManager.add(sharedDirectory); diff --git a/web/packages/teleport/src/lib/useMfa.ts b/web/packages/teleport/src/lib/useMfa.ts index 54d1299c65648..4d014f10e23ba 100644 --- a/web/packages/teleport/src/lib/useMfa.ts 
+++ b/web/packages/teleport/src/lib/useMfa.ts @@ -172,7 +172,7 @@ export function useMfa({ req, isMfaRequired }: MfaProps): MfaState { }; } -export function useMfaTty(emitterSender: EventEmitterMfaSender): MfaState { +export function useMfaEmitter(emitterSender: EventEmitterMfaSender): MfaState { const [mfaRequired, setMfaRequired] = useState(false); const mfa = useMfa({ isMfaRequired: mfaRequired }); From 47f4498b76c049b7a3ad250bc3df4fc297bd017d Mon Sep 17 00:00:00 2001 From: Matt Brock Date: Mon, 13 Jan 2025 13:27:58 -0600 Subject: [PATCH 12/15] Adding the Azure sync module functions along with new cloud client functionality (#50366) * Protobuf and configuration for Access Graph Azure Discovery * Adding the Azure sync module functions along with new cloud client functionality * Forgot to decouple role definitions fetching function from the fetcher * Moving reconciliation to the upstream azure sync PR * Moving reconciliation test to the upstream azure sync PR * Updating go.sum * Fixing rebase after protobuf gen * Nolinting until upstream PRs * Updating to use existing msgraph client * Adding protection around nil values * PR feedback * Updating principal fetching to incorporate metadata from principal subtypes * Updating opts to not leak URL parameters * Conformant package name * Using variadic options * PR feedback * Removing memberOf expansion * Expanding memberships by calling memberOf on each user * Also returning expanded principals for improved readability * Removing ptrToList * PR feedback * Rebase go.sum stuff * Go mod tidy * Linting * Linting * Collecting errors from fetching memberships and using a WithContext error group * Fixing go.mod * Update lib/msgraph/paginated.go Co-authored-by: Tiago Silva * PR feedback * e ref update * e ref update * Fixing method * Fetching group members from groups rather than memberships of each principal * Linting --------- Co-authored-by: Tiago Silva --- go.mod | 1 + go.sum | 2 + integrations/event-handler/go.mod | 1 + 
integrations/event-handler/go.sum | 2 + integrations/terraform/go.mod | 1 + integrations/terraform/go.sum | 2 + lib/cloud/azure/roleassignments.go | 57 ++++++++++++ lib/cloud/azure/roledefinitions.go | 57 ++++++++++++ lib/cloud/clients.go | 28 +++++- lib/msgraph/paginated.go | 8 ++ .../fetchers/azure-sync/memberships.go | 65 ++++++++++++++ .../fetchers/azure-sync/principals.go | 87 +++++++++++++++++++ .../fetchers/azure-sync/roleassignments.go | 68 +++++++++++++++ .../fetchers/azure-sync/roledefinitions.go | 78 +++++++++++++++++ .../fetchers/azure-sync/virtualmachines.go | 61 +++++++++++++ 15 files changed, 517 insertions(+), 1 deletion(-) create mode 100644 lib/cloud/azure/roleassignments.go create mode 100644 lib/cloud/azure/roledefinitions.go create mode 100644 lib/srv/discovery/fetchers/azure-sync/memberships.go create mode 100644 lib/srv/discovery/fetchers/azure-sync/principals.go create mode 100644 lib/srv/discovery/fetchers/azure-sync/roleassignments.go create mode 100644 lib/srv/discovery/fetchers/azure-sync/roledefinitions.go create mode 100644 lib/srv/discovery/fetchers/azure-sync/virtualmachines.go diff --git a/go.mod b/go.mod index 625a780eb3ff6..c5594219a47bc 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( connectrpc.com/connect v1.18.0 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2 v2.2.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6 v6.2.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v6 v6.3.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi v1.2.0 diff --git a/go.sum b/go.sum index 5bf38ba7fc0c4..af8ff9ce4acc7 100644 --- a/go.sum +++ b/go.sum @@ -668,6 +668,8 @@ github.com/Azure/azure-sdk-for-go/sdk/internal v1.0.0/go.mod h1:eWRD7oawr1Mu1sLC github.com/Azure/azure-sdk-for-go/sdk/internal 
v1.1.1/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2 v2.2.0 h1:Hp+EScFOu9HeCbeW8WU2yQPJd4gGwhMgKxWe+G6jNzw= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2 v2.2.0/go.mod h1:/pz8dyNQe+Ey3yBp/XuYz7oqX8YDNWVpPB0hH3XWfbc= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6 v6.2.0 h1:JAebRMoc3vL+Nd97GBprHYHucO4+wlW+tNbBIumqJlk= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6 v6.2.0/go.mod h1:zflC9v4VfViJrSvcvplqws/yGXVbUEMZi/iHpZdSPWA= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.0.0 h1:5n7dPVqsWfVKw+ZiEKSd3Kzu7gwBkbEBkeXb8rgaE9Q= diff --git a/integrations/event-handler/go.mod b/integrations/event-handler/go.mod index 19d919b359e39..2a4ec93e2f6ac 100644 --- a/integrations/event-handler/go.mod +++ b/integrations/event-handler/go.mod @@ -37,6 +37,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2 v2.2.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6 v6.2.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v6 v6.3.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi v1.2.0 // indirect diff --git a/integrations/event-handler/go.sum b/integrations/event-handler/go.sum index 1f0435df0d184..fbd5df9b4923f 100644 --- a/integrations/event-handler/go.sum +++ 
b/integrations/event-handler/go.sum @@ -631,6 +631,8 @@ github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.0 h1:+m0M/LFxN43KvUL github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.0/go.mod h1:PwOyop78lveYMRs6oCxjiVyBdyCgIYH6XHIVZO9/SFQ= github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2 v2.2.0 h1:Hp+EScFOu9HeCbeW8WU2yQPJd4gGwhMgKxWe+G6jNzw= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2 v2.2.0/go.mod h1:/pz8dyNQe+Ey3yBp/XuYz7oqX8YDNWVpPB0hH3XWfbc= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6 v6.2.0 h1:JAebRMoc3vL+Nd97GBprHYHucO4+wlW+tNbBIumqJlk= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6 v6.2.0/go.mod h1:zflC9v4VfViJrSvcvplqws/yGXVbUEMZi/iHpZdSPWA= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.0.0 h1:5n7dPVqsWfVKw+ZiEKSd3Kzu7gwBkbEBkeXb8rgaE9Q= diff --git a/integrations/terraform/go.mod b/integrations/terraform/go.mod index 5222dc914a105..3f0a69be92443 100644 --- a/integrations/terraform/go.mod +++ b/integrations/terraform/go.mod @@ -43,6 +43,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.8.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2 v2.2.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6 v6.2.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v6 v6.3.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi v1.2.0 // indirect 
diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum index da4bca430e263..6c0be667fb1a2 100644 --- a/integrations/terraform/go.sum +++ b/integrations/terraform/go.sum @@ -644,6 +644,8 @@ github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.0 h1:+m0M/LFxN43KvUL github.com/Azure/azure-sdk-for-go/sdk/azidentity/cache v0.3.0/go.mod h1:PwOyop78lveYMRs6oCxjiVyBdyCgIYH6XHIVZO9/SFQ= github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2 v2.2.0 h1:Hp+EScFOu9HeCbeW8WU2yQPJd4gGwhMgKxWe+G6jNzw= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2 v2.2.0/go.mod h1:/pz8dyNQe+Ey3yBp/XuYz7oqX8YDNWVpPB0hH3XWfbc= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6 v6.2.0 h1:JAebRMoc3vL+Nd97GBprHYHucO4+wlW+tNbBIumqJlk= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6 v6.2.0/go.mod h1:zflC9v4VfViJrSvcvplqws/yGXVbUEMZi/iHpZdSPWA= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/containerservice/armcontainerservice/v5 v5.0.0 h1:5n7dPVqsWfVKw+ZiEKSd3Kzu7gwBkbEBkeXb8rgaE9Q= diff --git a/lib/cloud/azure/roleassignments.go b/lib/cloud/azure/roleassignments.go new file mode 100644 index 0000000000000..114bceef88b96 --- /dev/null +++ b/lib/cloud/azure/roleassignments.go @@ -0,0 +1,57 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package azure + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2" + "github.com/gravitational/trace" +) + +// RoleAssignmentsClient wraps the Azure API to provide a high level subset of functionality +type RoleAssignmentsClient struct { + cli *armauthorization.RoleAssignmentsClient +} + +// NewRoleAssignmentsClient creates a new client for a given subscription and credentials +func NewRoleAssignmentsClient(subscription string, cred azcore.TokenCredential, options *arm.ClientOptions) (*RoleAssignmentsClient, error) { + clientFactory, err := armauthorization.NewClientFactory(subscription, cred, options) + if err != nil { + return nil, trace.Wrap(err) + } + roleDefCli := clientFactory.NewRoleAssignmentsClient() + return &RoleAssignmentsClient{cli: roleDefCli}, nil +} + +// ListRoleAssignments returns role assignments for a given scope +func (c *RoleAssignmentsClient) ListRoleAssignments(ctx context.Context, scope string) ([]*armauthorization.RoleAssignment, error) { + pager := c.cli.NewListForScopePager(scope, nil) + var roleDefs []*armauthorization.RoleAssignment + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + roleDefs = append(roleDefs, page.Value...) 
+ } + return roleDefs, nil +} diff --git a/lib/cloud/azure/roledefinitions.go b/lib/cloud/azure/roledefinitions.go new file mode 100644 index 0000000000000..cdc46196aa530 --- /dev/null +++ b/lib/cloud/azure/roledefinitions.go @@ -0,0 +1,57 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package azure + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2" + "github.com/gravitational/trace" +) + +// RoleDefinitionsClient wraps the Azure API to provide a high level subset of functionality +type RoleDefinitionsClient struct { + cli *armauthorization.RoleDefinitionsClient +} + +// NewRoleDefinitionsClient creates a new client for a given subscription and credentials +func NewRoleDefinitionsClient(subscription string, cred azcore.TokenCredential, options *arm.ClientOptions) (*RoleDefinitionsClient, error) { + clientFactory, err := armauthorization.NewClientFactory(subscription, cred, options) + if err != nil { + return nil, trace.Wrap(err) + } + roleDefCli := clientFactory.NewRoleDefinitionsClient() + return &RoleDefinitionsClient{cli: roleDefCli}, nil +} + +// ListRoleDefinitions returns role definitions for a given scope +func (c 
*RoleDefinitionsClient) ListRoleDefinitions(ctx context.Context, scope string) ([]*armauthorization.RoleDefinition, error) { + pager := c.cli.NewListPager(scope, nil) + var roleDefs []*armauthorization.RoleDefinition + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + roleDefs = append(roleDefs, page.Value...) + } + return roleDefs, nil +} diff --git a/lib/cloud/clients.go b/lib/cloud/clients.go index cc50c98c1ba4f..638658e761e48 100644 --- a/lib/cloud/clients.go +++ b/lib/cloud/clients.go @@ -344,6 +344,10 @@ type azureClients struct { azurePostgresFlexServersClients azure.ClientMap[azure.PostgresFlexServersClient] // azureRunCommandClients contains the cached Azure Run Command clients. azureRunCommandClients azure.ClientMap[azure.RunCommandClient] + // azureRoleDefinitionsClients contains the cached Azure Role Definitions clients. + azureRoleDefinitionsClients azure.ClientMap[azure.RoleDefinitionsClient] + // azureRoleAssignmentsClients contains the cached Azure Role Assignments clients. + azureRoleAssignmentsClients azure.ClientMap[azure.RoleAssignmentsClient] } // credentialsSource defines where the credentials must come from. @@ -717,6 +721,16 @@ func (c *cloudClients) GetAzureRunCommandClient(subscription string) (azure.RunC return c.azureRunCommandClients.Get(subscription, c.GetAzureCredential) } +// GetAzureRoleDefinitionsClient returns an Azure Role Definitions client +func (c *cloudClients) GetAzureRoleDefinitionsClient(subscription string) (azure.RoleDefinitionsClient, error) { + return c.azureRoleDefinitionsClients.Get(subscription, c.GetAzureCredential) +} + +// GetAzureRoleAssignmentsClient returns an Azure Role Assignments client +func (c *cloudClients) GetAzureRoleAssignmentsClient(subscription string) (azure.RoleAssignmentsClient, error) { + return c.azureRoleAssignmentsClients.Get(subscription, c.GetAzureCredential) +} + // Close closes all initialized clients. 
func (c *cloudClients) Close() (err error) { c.mtx.Lock() @@ -1021,6 +1035,8 @@ type TestCloudClients struct { AzureMySQLFlex azure.MySQLFlexServersClient AzurePostgresFlex azure.PostgresFlexServersClient AzureRunCommand azure.RunCommandClient + AzureRoleDefinitions azure.RoleDefinitionsClient + AzureRoleAssignments azure.RoleAssignmentsClient } // GetAWSSession returns AWS session for the specified region, optionally @@ -1244,11 +1260,21 @@ func (c *TestCloudClients) GetAzurePostgresFlexServersClient(subscription string return c.AzurePostgresFlex, nil } -// GetAzureRunCommand returns an Azure Run Command client for the given subscription. +// GetAzureRunCommandClient returns an Azure Run Command client for the given subscription. func (c *TestCloudClients) GetAzureRunCommandClient(subscription string) (azure.RunCommandClient, error) { return c.AzureRunCommand, nil } +// GetAzureRoleDefinitionsClient returns an Azure Role Definitions client for the given subscription. +func (c *TestCloudClients) GetAzureRoleDefinitionsClient(subscription string) (azure.RoleDefinitionsClient, error) { + return c.AzureRoleDefinitions, nil +} + +// GetAzureRoleAssignmentsClient returns an Azure Role Assignments client for the given subscription. +func (c *TestCloudClients) GetAzureRoleAssignmentsClient(subscription string) (azure.RoleAssignmentsClient, error) { + return c.AzureRoleAssignments, nil +} + // Close closes all initialized clients. func (c *TestCloudClients) Close() error { return nil diff --git a/lib/msgraph/paginated.go b/lib/msgraph/paginated.go index 51c587f19d074..a0b9488af9d70 100644 --- a/lib/msgraph/paginated.go +++ b/lib/msgraph/paginated.go @@ -101,6 +101,14 @@ func (c *Client) IterateUsers(ctx context.Context, f func(*User) bool) error { return iterateSimple(c, ctx, "users", f) } +// IterateServicePrincipals lists all service principals in the Entra ID directory using pagination. +// `f` will be called for each object in the result set. 
+// if `f` returns `false`, the iteration is stopped (equivalent to `break` in a normal loop). +// Ref: [https://learn.microsoft.com/en-us/graph/api/serviceprincipal-list]. +func (c *Client) IterateServicePrincipals(ctx context.Context, f func(principal *ServicePrincipal) bool) error { + return iterateSimple(c, ctx, "servicePrincipals", f) +} + // IterateGroupMembers lists all members for the given Entra ID group using pagination. // `f` will be called for each object in the result set. // if `f` returns `false`, the iteration is stopped (equivalent to `break` in a normal loop). diff --git a/lib/srv/discovery/fetchers/azure-sync/memberships.go b/lib/srv/discovery/fetchers/azure-sync/memberships.go new file mode 100644 index 0000000000000..f05be8f72567c --- /dev/null +++ b/lib/srv/discovery/fetchers/azure-sync/memberships.go @@ -0,0 +1,65 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . 
+ */ + +package azuresync + +import ( + "context" + + "github.com/gravitational/trace" + "golang.org/x/sync/errgroup" + + accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" + "github.com/gravitational/teleport/lib/msgraph" +) + +const parallelism = 10 //nolint:unused // invoked in a dependent PR + +// expandMemberships adds membership data to AzurePrincipal objects by querying the Graph API for group memberships +func expandMemberships(ctx context.Context, cli *msgraph.Client, principals []*accessgraphv1alpha.AzurePrincipal) ([]*accessgraphv1alpha.AzurePrincipal, error) { //nolint:unused // invoked in a dependent PR + // Map principals by ID + var principalsMap = make(map[string]*accessgraphv1alpha.AzurePrincipal) + for _, principal := range principals { + principalsMap[principal.Id] = principal + } + // Iterate through the Azure groups and add the group ID as a membership for its corresponding principal + eg, _ := errgroup.WithContext(ctx) + eg.SetLimit(parallelism) + errCh := make(chan error, len(principals)) + for _, principal := range principals { + if principal.ObjectType != "group" { + continue + } + group := principal + eg.Go(func() error { + err := cli.IterateGroupMembers(ctx, group.Id, func(member msgraph.GroupMember) bool { + if memberPrincipal, ok := principalsMap[*member.GetID()]; ok { + memberPrincipal.MemberOf = append(memberPrincipal.MemberOf, group.Id) + } + return true + }) + if err != nil { + errCh <- err + } + return nil + }) + } + _ = eg.Wait() + close(errCh) + return principals, trace.NewAggregateFromChannel(errCh, ctx) +} diff --git a/lib/srv/discovery/fetchers/azure-sync/principals.go b/lib/srv/discovery/fetchers/azure-sync/principals.go new file mode 100644 index 0000000000000..073d6c4713e0c --- /dev/null +++ b/lib/srv/discovery/fetchers/azure-sync/principals.go @@ -0,0 +1,87 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. 
+ * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package azuresync + +import ( + "context" + + "github.com/gravitational/trace" + "google.golang.org/protobuf/types/known/timestamppb" + + accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" + "github.com/gravitational/teleport/lib/msgraph" +) + +type dirObjMetadata struct { //nolint:unused // invoked in a dependent PR + objectType string +} + +type queryResult struct { //nolint:unused // invoked in a dependent PR + metadata dirObjMetadata + dirObj msgraph.DirectoryObject +} + +// fetchPrincipals fetches the Azure principals (users, groups, and service principals) using the Graph API +func fetchPrincipals(ctx context.Context, subscriptionID string, cli *msgraph.Client) ([]*accessgraphv1alpha.AzurePrincipal, error) { //nolint: unused // invoked in a dependent PR + // Fetch the users, groups, and service principals as directory objects + var queryResults []queryResult + err := cli.IterateUsers(ctx, func(user *msgraph.User) bool { + res := queryResult{metadata: dirObjMetadata{objectType: "user"}, dirObj: user.DirectoryObject} + queryResults = append(queryResults, res) + return true + }) + if err != nil { + return nil, trace.Wrap(err) + } + err = cli.IterateGroups(ctx, func(group *msgraph.Group) bool { + res := queryResult{metadata: dirObjMetadata{objectType: "group"}, dirObj: 
group.DirectoryObject} + queryResults = append(queryResults, res) + return true + }) + if err != nil { + return nil, trace.Wrap(err) + } + err = cli.IterateServicePrincipals(ctx, func(servicePrincipal *msgraph.ServicePrincipal) bool { + res := queryResult{metadata: dirObjMetadata{objectType: "servicePrincipal"}, dirObj: servicePrincipal.DirectoryObject} + queryResults = append(queryResults, res) + return true + }) + if err != nil { + return nil, trace.Wrap(err) + } + + // Return the users, groups, and service principals as protobuf messages + var fetchErrs []error + var pbPrincipals []*accessgraphv1alpha.AzurePrincipal + for _, res := range queryResults { + if res.dirObj.ID == nil || res.dirObj.DisplayName == nil { + fetchErrs = append(fetchErrs, + trace.BadParameter("nil values on msgraph directory object: %v", res.dirObj)) + continue + } + pbPrincipals = append(pbPrincipals, &accessgraphv1alpha.AzurePrincipal{ + Id: *res.dirObj.ID, + SubscriptionId: subscriptionID, + LastSyncTime: timestamppb.Now(), + DisplayName: *res.dirObj.DisplayName, + ObjectType: res.metadata.objectType, + }) + } + return pbPrincipals, trace.NewAggregate(fetchErrs...) +} diff --git a/lib/srv/discovery/fetchers/azure-sync/roleassignments.go b/lib/srv/discovery/fetchers/azure-sync/roleassignments.go new file mode 100644 index 0000000000000..a97fe69727ef8 --- /dev/null +++ b/lib/srv/discovery/fetchers/azure-sync/roleassignments.go @@ -0,0 +1,68 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package azuresync + +import ( + "context" + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2" + "github.com/gravitational/trace" + "google.golang.org/protobuf/types/known/timestamppb" + + accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" +) + +// RoleAssignmentsClient specifies the methods used to fetch role assignments from Azure +type RoleAssignmentsClient interface { + ListRoleAssignments(ctx context.Context, scope string) ([]*armauthorization.RoleAssignment, error) +} + +// fetchRoleAssignments fetches Azure role assignments using the Azure role assignments API +func fetchRoleAssignments(ctx context.Context, subscriptionID string, cli RoleAssignmentsClient) ([]*accessgraphv1alpha.AzureRoleAssignment, error) { //nolint:unused // invoked in a dependent PR + // List the role definitions + roleAssigns, err := cli.ListRoleAssignments(ctx, fmt.Sprintf("/subscriptions/%s", subscriptionID)) + if err != nil { + return nil, trace.Wrap(err) + } + + // Convert to protobuf format + pbRoleAssigns := make([]*accessgraphv1alpha.AzureRoleAssignment, 0, len(roleAssigns)) + var fetchErrs []error + for _, roleAssign := range roleAssigns { + if roleAssign.ID == nil || + roleAssign.Properties == nil || + roleAssign.Properties.PrincipalID == nil || + roleAssign.Properties.Scope == nil { + fetchErrs = append(fetchErrs, + trace.BadParameter("nil values on AzureRoleAssignment object: %v", roleAssign)) + continue + } + pbRoleAssign := &accessgraphv1alpha.AzureRoleAssignment{ + Id: *roleAssign.ID, + SubscriptionId: subscriptionID, + LastSyncTime: timestamppb.Now(), + PrincipalId: *roleAssign.Properties.PrincipalID, + RoleDefinitionId: *roleAssign.Properties.RoleDefinitionID, + Scope: 
*roleAssign.Properties.Scope, + } + pbRoleAssigns = append(pbRoleAssigns, pbRoleAssign) + } + return pbRoleAssigns, trace.NewAggregate(fetchErrs...) +} diff --git a/lib/srv/discovery/fetchers/azure-sync/roledefinitions.go b/lib/srv/discovery/fetchers/azure-sync/roledefinitions.go new file mode 100644 index 0000000000000..485117f898b81 --- /dev/null +++ b/lib/srv/discovery/fetchers/azure-sync/roledefinitions.go @@ -0,0 +1,78 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . 
+ */ + +package azuresync + +import ( + "context" + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v2" + "github.com/gravitational/trace" + "google.golang.org/protobuf/types/known/timestamppb" + + accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" + "github.com/gravitational/teleport/lib/utils/slices" +) + +// RoleDefinitionsClient specifies the methods used to fetch roles from Azure +type RoleDefinitionsClient interface { + ListRoleDefinitions(ctx context.Context, scope string) ([]*armauthorization.RoleDefinition, error) +} + +func fetchRoleDefinitions(ctx context.Context, subscriptionID string, cli RoleDefinitionsClient) ([]*accessgraphv1alpha.AzureRoleDefinition, error) { //nolint:unused // used in a dependent PR + // List the role definitions + roleDefs, err := cli.ListRoleDefinitions(ctx, fmt.Sprintf("/subscriptions/%s", subscriptionID)) + if err != nil { + return nil, trace.Wrap(err) + } + + // Convert to protobuf format + pbRoleDefs := make([]*accessgraphv1alpha.AzureRoleDefinition, 0, len(roleDefs)) + var fetchErrs []error + for _, roleDef := range roleDefs { + if roleDef.ID == nil || + roleDef.Properties == nil || + roleDef.Properties.Permissions == nil || + roleDef.Properties.RoleName == nil { + fetchErrs = append(fetchErrs, trace.BadParameter("nil values on AzureRoleDefinition object: %v", roleDef)) + continue + } + pbPerms := make([]*accessgraphv1alpha.AzureRBACPermission, 0, len(roleDef.Properties.Permissions)) + for _, perm := range roleDef.Properties.Permissions { + if perm.Actions == nil && perm.NotActions == nil { + fetchErrs = append(fetchErrs, trace.BadParameter("nil values on Permission object: %v", perm)) + continue + } + pbPerm := accessgraphv1alpha.AzureRBACPermission{ + Actions: slices.FromPointers(perm.Actions), + NotActions: slices.FromPointers(perm.NotActions), + } + pbPerms = append(pbPerms, &pbPerm) + } + pbRoleDef := 
&accessgraphv1alpha.AzureRoleDefinition{ + Id: *roleDef.ID, + Name: *roleDef.Properties.RoleName, + SubscriptionId: subscriptionID, + LastSyncTime: timestamppb.Now(), + Permissions: pbPerms, + } + pbRoleDefs = append(pbRoleDefs, pbRoleDef) + } + return pbRoleDefs, trace.NewAggregate(fetchErrs...) +} diff --git a/lib/srv/discovery/fetchers/azure-sync/virtualmachines.go b/lib/srv/discovery/fetchers/azure-sync/virtualmachines.go new file mode 100644 index 0000000000000..cf0d068db7b0c --- /dev/null +++ b/lib/srv/discovery/fetchers/azure-sync/virtualmachines.go @@ -0,0 +1,61 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . 
+ */ + +package azuresync + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" + "github.com/gravitational/trace" + "google.golang.org/protobuf/types/known/timestamppb" + + accessgraphv1alpha "github.com/gravitational/teleport/gen/proto/go/accessgraph/v1alpha" +) + +const allResourceGroups = "*" //nolint:unused // invoked in a dependent PR + +// VirtualMachinesClient specifies the methods used to fetch virtual machines from Azure +type VirtualMachinesClient interface { + ListVirtualMachines(ctx context.Context, resourceGroup string) ([]*armcompute.VirtualMachine, error) +} + +func fetchVirtualMachines(ctx context.Context, subscriptionID string, cli VirtualMachinesClient) ([]*accessgraphv1alpha.AzureVirtualMachine, error) { //nolint:unused // invoked in a dependent PR + vms, err := cli.ListVirtualMachines(ctx, allResourceGroups) + if err != nil { + return nil, trace.Wrap(err) + } + + // Return the VMs as protobuf messages + pbVms := make([]*accessgraphv1alpha.AzureVirtualMachine, 0, len(vms)) + var fetchErrs []error + for _, vm := range vms { + if vm.ID == nil || vm.Name == nil { + fetchErrs = append(fetchErrs, trace.BadParameter("nil values on AzureVirtualMachine object: %v", vm)) + continue + } + pbVm := accessgraphv1alpha.AzureVirtualMachine{ + Id: *vm.ID, + SubscriptionId: subscriptionID, + LastSyncTime: timestamppb.Now(), + Name: *vm.Name, + } + pbVms = append(pbVms, &pbVm) + } + return pbVms, trace.NewAggregate(fetchErrs...) +} From 62ad3fe6c7bdfa23fe770e4f3c842d4252f49485 Mon Sep 17 00:00:00 2001 From: Marco Dinis Date: Mon, 13 Jan 2025 19:57:38 +0000 Subject: [PATCH 13/15] Fix EKS Discover User Task reporting (#50989) * Fix EKS Discover User Task reporting The `clusterNames` slice and `clusterByNames` key set must be the same. When there was two groups of EKS Clusters, one with App Discovery enabled and another one with it disabled, we had different set of clusters being processed. 
`clusterNames` had all the EKS Clusters, while `clusterByNames` only had the EKS Clusters for one of the processing groups (either AppDiscovery=on or AppDiscovery=off). This meant that when the `EnrollEKSClusters` returned an error, we looked up the map, but it might be the case that that particular EKS Cluster was not configured for the current processing group. So, the `clusterByNames[r.EksClusterName]` returned a nil value, which resulted in a panic. * add test * check if cluster exists in local map --- lib/srv/discovery/discovery_test.go | 108 +++++++++++++++++- lib/srv/discovery/kube_integration_watcher.go | 10 +- 2 files changed, 113 insertions(+), 5 deletions(-) diff --git a/lib/srv/discovery/discovery_test.go b/lib/srv/discovery/discovery_test.go index 2948e10cdb916..5e9d3d1acf7e6 100644 --- a/lib/srv/discovery/discovery_test.go +++ b/lib/srv/discovery/discovery_test.go @@ -322,6 +322,31 @@ func TestDiscoveryServer(t *testing.T) { ) require.NoError(t, err) + discoveryConfigWithAndWithoutAppDiscoveryTestName := uuid.NewString() + discoveryConfigWithAndWithoutAppDiscovery, err := discoveryconfig.NewDiscoveryConfig( + header.Metadata{Name: discoveryConfigWithAndWithoutAppDiscoveryTestName}, + discoveryconfig.Spec{ + DiscoveryGroup: defaultDiscoveryGroup, + AWS: []types.AWSMatcher{ + { + Types: []string{"eks"}, + Regions: []string{"eu-west-2"}, + Tags: map[string]utils.Strings{"EnableAppDiscovery": {"No"}}, + Integration: "my-integration", + KubeAppDiscovery: false, + }, + { + Types: []string{"eks"}, + Regions: []string{"eu-west-2"}, + Tags: map[string]utils.Strings{"EnableAppDiscovery": {"Yes"}}, + Integration: "my-integration", + KubeAppDiscovery: true, + }, + }, + }, + ) + require.NoError(t, err) + tcs := []struct { name string // presentInstances is a list of servers already present in teleport. 
@@ -754,6 +779,74 @@ func TestDiscoveryServer(t *testing.T) { require.Equal(t, defaultDiscoveryGroup, taskCluster.DiscoveryGroup) }, }, + { + name: "multiple EKS clusters with different KubeAppDiscovery setting failed to autoenroll and user tasks are created", + presentInstances: []types.Server{}, + foundEC2Instances: []ec2types.Instance{}, + ssm: &mockSSMClient{}, + eksClusters: []*ekstypes.Cluster{ + { + Name: aws.String("cluster01"), + Arn: aws.String("arn:aws:eks:us-west-2:123456789012:cluster/cluster01"), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + "EnableAppDiscovery": "Yes", + }, + }, + { + Name: aws.String("cluster02"), + Arn: aws.String("arn:aws:eks:us-west-2:123456789012:cluster/cluster02"), + Status: ekstypes.ClusterStatusActive, + Tags: map[string]string{ + "EnableAppDiscovery": "No", + }, + }, + }, + eksEnroller: &mockEKSClusterEnroller{ + resp: &integrationpb.EnrollEKSClustersResponse{ + Results: []*integrationpb.EnrollEKSClusterResult{ + { + EksClusterName: "cluster01", + Error: "access endpoint is not reachable", + IssueType: "eks-cluster-unreachable", + }, + { + EksClusterName: "cluster02", + Error: "access endpoint is not reachable", + IssueType: "eks-cluster-unreachable", + }, + }, + }, + err: nil, + }, + emitter: &mockEmitter{}, + staticMatchers: Matchers{}, + discoveryConfig: discoveryConfigWithAndWithoutAppDiscovery, + wantInstalledInstances: []string{}, + userTasksDiscoverCheck: func(t require.TestingT, i1 interface{}, i2 ...interface{}) { + existingTasks, ok := i1.([]*usertasksv1.UserTask) + require.True(t, ok, "failed to get existing tasks: %T", i1) + require.Len(t, existingTasks, 2) + existingTask := existingTasks[0] + if existingTask.Spec.DiscoverEks.AppAutoDiscover == false { + existingTask = existingTasks[1] + } + + require.Equal(t, "OPEN", existingTask.GetSpec().State) + require.Equal(t, "my-integration", existingTask.GetSpec().Integration) + require.Equal(t, "eks-cluster-unreachable", 
existingTask.GetSpec().IssueType) + require.Equal(t, "123456789012", existingTask.GetSpec().GetDiscoverEks().GetAccountId()) + require.Equal(t, "us-west-2", existingTask.GetSpec().GetDiscoverEks().GetRegion()) + + taskClusters := existingTask.GetSpec().GetDiscoverEks().Clusters + require.Contains(t, taskClusters, "cluster01") + taskCluster := taskClusters["cluster01"] + + require.Equal(t, "cluster01", taskCluster.Name) + require.Equal(t, discoveryConfigWithAndWithoutAppDiscoveryTestName, taskCluster.DiscoveryConfig) + require.Equal(t, defaultDiscoveryGroup, taskCluster.DiscoveryGroup) + }, + }, } for _, tc := range tcs { @@ -3528,8 +3621,19 @@ type mockEKSClusterEnroller struct { err error } -func (m *mockEKSClusterEnroller) EnrollEKSClusters(context.Context, *integrationpb.EnrollEKSClustersRequest, ...grpc.CallOption) (*integrationpb.EnrollEKSClustersResponse, error) { - return m.resp, m.err +func (m *mockEKSClusterEnroller) EnrollEKSClusters(ctx context.Context, req *integrationpb.EnrollEKSClustersRequest, opt ...grpc.CallOption) (*integrationpb.EnrollEKSClustersResponse, error) { + ret := &integrationpb.EnrollEKSClustersResponse{ + Results: []*integrationpb.EnrollEKSClusterResult{}, + } + // Filter out non-requested clusters. 
+ for _, clusterName := range req.EksClusterNames { + for _, mockClusterResult := range m.resp.Results { + if clusterName == mockClusterResult.EksClusterName { + ret.Results = append(ret.Results, mockClusterResult) + } + } + } + return ret, m.err } type combinedDiscoveryClient struct { diff --git a/lib/srv/discovery/kube_integration_watcher.go b/lib/srv/discovery/kube_integration_watcher.go index ffbecf6497359..88d89f258f8c4 100644 --- a/lib/srv/discovery/kube_integration_watcher.go +++ b/lib/srv/discovery/kube_integration_watcher.go @@ -21,6 +21,7 @@ package discovery import ( "context" "fmt" + "maps" "slices" "strings" "sync" @@ -243,14 +244,13 @@ func (s *Server) enrollEKSClusters(region, integration, discoveryConfigName stri } ctx, cancel := context.WithTimeout(s.ctx, time.Duration(len(clusters))*30*time.Second) defer cancel() - var clusterNames []string for _, kubeAppDiscovery := range []bool{true, false} { clustersByName := make(map[string]types.DiscoveredEKSCluster) for _, c := range batchedClusters[kubeAppDiscovery] { - clusterNames = append(clusterNames, c.GetAWSConfig().Name) clustersByName[c.GetAWSConfig().Name] = c } + clusterNames := slices.Collect(maps.Keys(clustersByName)) if len(clusterNames) == 0 { continue } @@ -283,7 +283,11 @@ func (s *Server) enrollEKSClusters(region, integration, discoveryConfigName stri s.Log.DebugContext(ctx, "EKS cluster already has installed kube agent", "cluster_name", r.EksClusterName) } - cluster := clustersByName[r.EksClusterName] + cluster, ok := clustersByName[r.EksClusterName] + if !ok { + s.Log.WarnContext(ctx, "Received an EnrollEKSCluster result for a cluster which was not part of the requested clusters", "cluster_name", r.EksClusterName, "clusters_install_request", clusterNames) + continue + } s.awsEKSTasks.addFailedEnrollment( awsEKSTaskKey{ integration: integration, From ce30037005c6cf04e22894fd9fca130f5375332f Mon Sep 17 00:00:00 2001 From: Brian Joerger Date: Mon, 13 Jan 2025 12:13:12 -0800 Subject: [PATCH 
14/15] Add SSO MFA docs (#50533) * Add SSO MFA docs. * Address comments from zmb3. * Fix links; minor style fix. * Address comments. * Try removing leading / in example links. * Address Nic's comments. --- .../admin-guides/access-controls/sso/sso.mdx | 112 +++++++++++++++++- examples/resources/oidc-connector-mfa.yaml | 33 ++++++ examples/resources/saml-connector-mfa.yaml | 29 +++++ 3 files changed, 173 insertions(+), 1 deletion(-) create mode 100644 examples/resources/oidc-connector-mfa.yaml create mode 100644 examples/resources/saml-connector-mfa.yaml diff --git a/docs/pages/admin-guides/access-controls/sso/sso.mdx b/docs/pages/admin-guides/access-controls/sso/sso.mdx index 26c0003ea9128..76cd88c08b182 100644 --- a/docs/pages/admin-guides/access-controls/sso/sso.mdx +++ b/docs/pages/admin-guides/access-controls/sso/sso.mdx @@ -213,7 +213,7 @@ spec: - '2001:db8::/96' ``` -## Configuring SSO +## Configuring SSO for login Teleport works with SSO providers by relying on the concept of an **authentication connector**. An authentication connector is a configuration @@ -411,6 +411,116 @@ values to match your identity provider: At this time, the `spec.provider` field should not be set for any other identity providers. +## Configuring SSO for MFA checks + +Teleport administrators can configure Teleport to delegate MFA checks to an +SSO provider as an alternative to registering MFA devices directly with the Teleport cluster. 
+This allows Teleport users to use MFA devices and custom flows configured in the SSO provider +to carry out privileged actions in Teleport, such as: + +- [Per-session MFA](../guides/per-session-mfa.mdx) +- [Moderated sessions](../guides/moderated-sessions.mdx) +- [Admin actions](../guides/mfa-for-admin-actions.mdx) + +Administrators may want to consider enabling this feature in order to: + +- Make all authentication (login and MFA) go through the IDP, reducing administrative overhead +- Make custom MFA flows, such as prompting for 2 distinct devices for a single MFA check +- Integrate with non-webauthn devices supported directly by your IDP + + + SSO MFA is an enterprise feature. Only OIDC and SAML auth connectors are supported. + + +### Configure the IDP App / Client + +There is no standardized MFA flow unlike there is with SAML/OIDC +login, so each IDP may offer zero, one, or more ways to offer MFA checks. + +Generally, these offerings will fall under one of the following cases: + +1. Use a separate IDP app for MFA: + +You can create a separate IDP app with a custom MFA flow. For example, with +Auth0 (OIDC), you can create a separate app with a custom [Auth0 Action](https://auth0.com/docs/customize/actions) +which prompts for MFA for an active OIDC session. + +2. Use the same IDP app for MFA: + +Some IDPs provide a way to fork to different flows using the same IDP app. +For example, with Okta (OIDC), you can provide `acr_values: ["phr"]` to +[enforce phishing resistant authentication](https://developer.okta.com/docs/guides/step-up-authentication/main/#predefined-parameter-values). + +For a simpler approach, you could use the same IDP app for both login and MFA +with no adjustments. For Teleport MFA checks, the user will be required to +relogin through the IDP with username, password, and MFA if required. 
+ + + While the customizability of SSO MFA presents multiple secure options previously + unavailable to administrators, it also presents the possibility of insecure + misconfigurations. Therefore, we strongly advice administrators to incorporate + strict, phishing-resistant checks with WebAuthn, Device Trust, or some similar + security features into their custom SSO MFA flow. + + +### Updating your authentication connector to enable MFA checks + +Take the authentication connector file `connector.yaml` created in [Configuring SSO for login](#configuring-sso-for-login) +and add MFA settings. + + + + +```yaml +(!examples/resources/oidc-connector-mfa.yaml!) +``` + + + + +```yaml +(!examples/resources/saml-connector-mfa.yaml!) +``` + +You may use `entity_descriptor_url` in lieu of `entity_descriptor` to fetch +the entity descriptor from your IDP. + +We recommend "pinning" the entity descriptor by including the XML rather than +fetching from a URL. + + + + +Update the connector: + +```code +$ tctl create -f connector.yaml +``` + +### Allowing SSO as an MFA method in your cluster + +Before you can use the SSO MFA flow we created above, you need to enable SSO +as a second factor in your cluster settings. Modify the dynamic config resource +using the following command: + +```code +$ tctl edit cluster_auth_preference +``` + +Make the following change: + +```diff +kind: cluster_auth_preference +version: v2 +metadata: + name: cluster-auth-preference +spec: + # ... + second_factors: + - webauthn ++ - sso +``` + ## Working with an external email identity Along with sending groups, an SSO provider will also provide a user's email address. 
diff --git a/examples/resources/oidc-connector-mfa.yaml b/examples/resources/oidc-connector-mfa.yaml new file mode 100644 index 0000000000000..ca56b727d1487 --- /dev/null +++ b/examples/resources/oidc-connector-mfa.yaml @@ -0,0 +1,33 @@ +kind: oidc +version: v3 +metadata: + name: oidc_connector +spec: + # Login settings + client_id: + client_secret: + # issuer_url and redirect_url are shared by both login and MFA, meaning the same OIDC provider must be used. + issuer_url: https://idp.example.com/ + redirect_url: https://mytenant.teleport.sh:443/v1/webapi/oidc/callback + # ... + + # MFA settings + mfa: + # Enabled specified whether this OIDC connector supports MFA checks. + enabled: true + # client_id and client_secret should point to an IdP configured + # app configured to handle MFA checks. In most cases, these values + # should be different from your login client ID and Secret above. + client_id: + client_secret: + # prompt can be set to request a specific prompt flow from the IdP. Supported + # values depend on the IdP. + prompt: none + # acr_values are Authentication Context Class Reference values. These values + # are context-specific and vary depending on the IdP. + acr_values: [] + # max_age is the amount of time in seconds that an IdP session is valid for. + # Defaults to 0 to always force re-authentication for MFA checks. This should + # only be set to a non-zero value if the IdP is setup to perform MFA checks on + # top of active user sessions. 
+ max_age: 0 diff --git a/examples/resources/saml-connector-mfa.yaml b/examples/resources/saml-connector-mfa.yaml new file mode 100644 index 0000000000000..9c58802ec0ace --- /dev/null +++ b/examples/resources/saml-connector-mfa.yaml @@ -0,0 +1,29 @@ +# +# Example resource for a SAML connector +# This connector can be used for SAML endpoints like Okta +# +kind: saml +version: v2 +metadata: + # the name of the connector + name: okta +spec: + # Login settings + display: Okta + entity_descriptor_url: https://example.okta.com/app//sso/saml/metadata + # acs is shared by both login and MFA, meaning the same SAML provider must be used. + acs: https:///v1/webapi/saml/acs/new_saml_connector + # ... + + # MFA settings + mfa: + # Enabled specifies whether this SAML connector supports MFA checks. + enabled: true + # entity_descriptor_url should point to an IdP configured app that handles MFA checks. + # In most cases, this value should be different from the entity_descriptor_url above. + entity_descriptor_url: https://example.okta.com/app//sso/saml/metadata + # force_reauth determines whether existing login sessions are accepted or if + # re-authentication is always required. Defaults to "yes". This should only be + # set to false if the app described above is setup to perform MFA checks on top + # of active user sessions. + force_reauth: yes \ No newline at end of file From 4034d7c3cdccef8fd2c16f082776fecfbe5811e0 Mon Sep 17 00:00:00 2001 From: Brian Joerger Date: Mon, 13 Jan 2025 15:04:06 -0800 Subject: [PATCH 15/15] Fix data race in x11 forwarding test. 
(#50997) --- lib/srv/regular/sshserver_test.go | 34 ++++++++++++++++++------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 5b095961c3c68..eb7d3283b4508 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -1349,24 +1349,30 @@ func x11EchoSession(ctx context.Context, t *testing.T, clt *tracessh.Client) x11 os.Remove(tmpFile.Name()) }) - // type 'printenv DISPLAY > /path/to/tmp/file' into the session (dumping the value of DISPLAY into the temp file) - _, err = keyboard.Write([]byte(fmt.Sprintf("printenv %v >> %s\n\r", x11.DisplayEnv, tmpFile.Name()))) - require.NoError(t, err) + // Reading the display may fail if the session is not fully initialized + // and the write to stdin is swallowed. + display := make(chan string, 1) + require.EventuallyWithT(t, func(t *assert.CollectT) { + // enter 'printenv DISPLAY > /path/to/tmp/file' into the session (dumping the value of DISPLAY into the temp file) + _, err = keyboard.Write([]byte(fmt.Sprintf("printenv %v > %s\n\r", x11.DisplayEnv, tmpFile.Name()))) + assert.NoError(t, err) - // wait for the output - var display string - require.Eventually(t, func() bool { - output, err := os.ReadFile(tmpFile.Name()) - if err == nil && len(output) != 0 { - display = strings.TrimSpace(string(output)) - return true - } - return false - }, 10*time.Second, 100*time.Millisecond, "failed to read display") + assert.Eventually(t, func() bool { + output, err := os.ReadFile(tmpFile.Name()) + if err == nil && len(output) != 0 { + select { + case display <- strings.TrimSpace(string(output)): + default: + } + return true + } + return false + }, time.Second, 100*time.Millisecond, "failed to read display") + }, 10*time.Second, 1*time.Second) // Make a new connection to the XServer proxy, the client // XServer should echo back anything written on it. 
- serverDisplay, err := x11.ParseDisplay(display) + serverDisplay, err := x11.ParseDisplay(<-display) require.NoError(t, err) return serverDisplay