diff --git a/cmd/setup-gh.go b/cmd/setup-gh.go index e75ad9f5..621fadaa 100644 --- a/cmd/setup-gh.go +++ b/cmd/setup-gh.go @@ -7,6 +7,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" "github.com/Azure/azure-sdk-for-go/sdk/azidentity" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription" "github.com/Azure/draft/pkg/cred" "github.com/manifoldco/promptui" @@ -50,6 +51,13 @@ application and service principle, and will configure that application to trust sc.AzClient.GraphClient = graphClient + roleAssignmentClient, err := armauthorization.NewRoleAssignmentsClient(sc.SubscriptionID, azCred, nil) + if err != nil { + return fmt.Errorf("getting role assignment client: %w", err) + } + + sc.AzClient.RoleAssignClient = roleAssignmentClient + fillSetUpConfig(sc) s := spinner.CreateSpinner("--> Setting up Github OIDC...") diff --git a/go.mod b/go.mod index eb0d7b60..6ff1a307 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22.0 require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization v1.0.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription v1.2.0 github.com/briandowns/spinner v1.23.0 github.com/cenkalti/backoff/v4 v4.3.0 diff --git a/go.sum b/go.sum index 96b9e145..d853ad21 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,8 @@ github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2 h1:FDif4R1+UUR+00q6wquyX github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.5.2/go.mod h1:aiYBYui4BJ/BJCAIKs92XiPyQfTaBWqvHujDwKb6CBU= github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2 h1:LqbJ/WzJUwBf8UiaSzgX7aMclParm9/5Vgp+TY51uBQ= github.com/Azure/azure-sdk-for-go/sdk/internal v1.5.2/go.mod h1:yInRyqWXAuaPrgI7p70+lDDgh3mlBohis29jGMISnmc= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization v1.0.0 h1:qtRcg5Y7jNJ4jEzPq4GpWLfTspHdNe2ZK6LjwGcjgmU= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization v1.0.0/go.mod h1:lPneRe3TwsoDRKY4O6YDLXHhEWrD+TIRa8XrV/3/fqw= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription v1.2.0 h1:UrGzkHueDwAWDdjQxC+QaXHd4tVCkISYE9j7fSSXF8k= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription v1.2.0/go.mod h1:qskvSQeW+cxEE2bcKYyKimB1/KiQ9xpJ99bcHY0BX6c= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= diff --git a/pkg/providers/az-client.go b/pkg/providers/az-client.go index 0c483435..7971304c 100644 --- a/pkg/providers/az-client.go +++ b/pkg/providers/az-client.go @@ -4,14 +4,18 @@ import ( "context" "errors" "fmt" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription" + msgraph "github.com/microsoftgraph/msgraph-sdk-go" ) type AzClient struct { - AzTenantClient azTenantClient - GraphClient GraphClient + AzTenantClient azTenantClient + GraphClient GraphClient + RoleAssignClient RoleAssignClient } //go:generate mockgen -source=./az-client.go -destination=./mock/az-client.go . @@ -43,3 +47,9 @@ func (g *GraphServiceClient) GetApplicationObjectId(ctx context.Context, appId s } return *appObjectId, nil } + +type RoleAssignClient interface { + CreateByID(ctx context.Context, roleAssignmentID string, parameters armauthorization.RoleAssignmentCreateParameters, options *armauthorization.RoleAssignmentsClientCreateByIDOptions) (armauthorization.RoleAssignmentsClientCreateByIDResponse, error) +} + +var _ RoleAssignClient = &armauthorization.RoleAssignmentsClient{} diff --git a/pkg/providers/azure.go b/pkg/providers/azure.go index d1f3affd..d0920ea2 100644 --- a/pkg/providers/azure.go +++ b/pkg/providers/azure.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription" "os/exec" "time" @@ -61,7 +62,7 @@ func InitiateAzureOIDCFlow(ctx context.Context, sc *SetUpCmd, s spinner.Spinner) return err } - if err := sc.assignSpRole(); err != nil { + if err := sc.assignSpRole(ctx); err != nil { return err } @@ -165,14 +166,22 @@ func (sc *SetUpCmd) CreateServicePrincipal() error { return nil } -func (sc *SetUpCmd) assignSpRole() error { +func (sc *SetUpCmd) assignSpRole(ctx context.Context) error { log.Debug("Assigning contributor role to service principal...") - scope := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s", sc.SubscriptionID, sc.ResourceGroupName) - assignSpRoleCmd := exec.Command("az", "role", "assignment", "create", "--role", "contributor", "--subscription", sc.SubscriptionID, "--assignee-object-id", sc.spObjectId, "--assignee-principal-type", "ServicePrincipal", "--scope", scope, "--only-show-errors") - out, err := assignSpRoleCmd.CombinedOutput() + + objectID := sc.spObjectId + roleID := "contributor" + + parameters := armauthorization.RoleAssignmentCreateParameters{ + Properties: &armauthorization.RoleAssignmentProperties{ + PrincipalID: &objectID, + RoleDefinitionID: &roleID, + }, + } + + _, err := sc.AzClient.RoleAssignClient.CreateByID(ctx, roleID, parameters, nil) if err != nil { - log.Printf("%s\n", out) - return err + return fmt.Errorf("creating role assignment: %w", err) } log.Debug("Role assigned successfully!") diff --git a/pkg/providers/azure_test.go b/pkg/providers/azure_test.go index 6cc3c284..0dab8658 100644 --- a/pkg/providers/azure_test.go +++ b/pkg/providers/azure_test.go @@ -5,6 +5,7 @@ import ( "errors" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription" mock_providers "github.com/Azure/draft/pkg/providers/mock" "go.uber.org/mock/gomock" @@ -243,3 +244,66 @@ func TestGetAppObjectId_EmptyAppIdFromGraphClient(t *testing.T) { t.Errorf("Expected error '%v', got '%v'", expectedError, err) } } + +var principalId = "mockPrincipalID" +var roleDefId = "mockRoleDefinitionID" +var Id = "mockID" +var name = "mockName" +var Idtype = "mocktype" + +func TestAssignSpRole(t *testing.T) { + tests := []struct { + name string + expectedError error + mockResponse armauthorization.RoleAssignmentsClientCreateByIDResponse + }{ + { + name: "Success", + expectedError: nil, + mockResponse: armauthorization.RoleAssignmentsClientCreateByIDResponse{ + RoleAssignment: armauthorization.RoleAssignment{ + Properties: &armauthorization.RoleAssignmentPropertiesWithScope{ + PrincipalID: &principalId, + RoleDefinitionID: &roleDefId, + }, + ID: &Id, + Name: &name, + Type: &Idtype, + }, + }, + }, + { + name: "Error", + expectedError: errors.New("error"), + mockResponse: armauthorization.RoleAssignmentsClientCreateByIDResponse{}, + }, + { + name: "ErrorDuringRoleAssignment", + expectedError: errors.New("error during role assignment"), + mockResponse: armauthorization.RoleAssignmentsClientCreateByIDResponse{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockRoleAssignClient := mock_providers.NewMockRoleAssignClient(ctrl) + + mockRoleAssignClient.EXPECT().CreateByID(gomock.Any(), "contributor", gomock.Any(), gomock.Any()).Return(tt.mockResponse, tt.expectedError) + + sc := &SetUpCmd{ + AzClient: AzClient{ + RoleAssignClient: mockRoleAssignClient, + }, + spObjectId: "testObjectId", + } + + err := sc.assignSpRole(context.Background()) + if !errors.Is(err, tt.expectedError) { + t.Errorf("Expected error: %v, got: %v", tt.expectedError, err) + } + }) + } +} diff --git a/pkg/providers/mock/az-client.go b/pkg/providers/mock/az-client.go index bcf69751..7514ffc9 100644 --- a/pkg/providers/mock/az-client.go +++ b/pkg/providers/mock/az-client.go @@ -14,6 +14,7 @@ import ( reflect "reflect" runtime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + armauthorization "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization" armsubscription "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/subscription/armsubscription" gomock "go.uber.org/mock/gomock" ) @@ -92,3 +93,41 @@ func (mr *MockGraphClientMockRecorder) GetApplicationObjectId(ctx, appId any) *g mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetApplicationObjectId", reflect.TypeOf((*MockGraphClient)(nil).GetApplicationObjectId), ctx, appId) } + +// MockRoleAssignClient is a mock of RoleAssignClient interface. +type MockRoleAssignClient struct { + ctrl *gomock.Controller + recorder *MockRoleAssignClientMockRecorder +} + +// MockRoleAssignClientMockRecorder is the mock recorder for MockRoleAssignClient. +type MockRoleAssignClientMockRecorder struct { + mock *MockRoleAssignClient +} + +// NewMockRoleAssignClient creates a new mock instance. +func NewMockRoleAssignClient(ctrl *gomock.Controller) *MockRoleAssignClient { + mock := &MockRoleAssignClient{ctrl: ctrl} + mock.recorder = &MockRoleAssignClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRoleAssignClient) EXPECT() *MockRoleAssignClientMockRecorder { + return m.recorder +} + +// CreateByID mocks base method. +func (m *MockRoleAssignClient) CreateByID(ctx context.Context, roleAssignmentID string, parameters armauthorization.RoleAssignmentCreateParameters, options *armauthorization.RoleAssignmentsClientCreateByIDOptions) (armauthorization.RoleAssignmentsClientCreateByIDResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateByID", ctx, roleAssignmentID, parameters, options) + ret0, _ := ret[0].(armauthorization.RoleAssignmentsClientCreateByIDResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateByID indicates an expected call of CreateByID. +func (mr *MockRoleAssignClientMockRecorder) CreateByID(ctx, roleAssignmentID, parameters, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateByID", reflect.TypeOf((*MockRoleAssignClient)(nil).CreateByID), ctx, roleAssignmentID, parameters, options) +}