Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HP-682 Feat/multiple paymodels #44

Merged
merged 26 commits into from
May 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/golang-ci-workflow.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on: push

jobs:
ci:
name: golang-ci
runs-on: ubuntu-latest
env:
COVERAGE_PROFILE_OUTPUT_LOCATION: "./profile.cov"
Expand Down
26 changes: 26 additions & 0 deletions hatchery/alb.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,29 @@ func (creds *CREDS) CreateLoadBalancer(userName string) (*elbv2.CreateLoadBalanc
}
return loadBalancer, targetGroup.TargetGroups[0].TargetGroupArn, listener, nil
}

func (creds *CREDS) terminateLoadBalancer(userName string) error {
svc := elbv2.New(session.Must(session.NewSession(&aws.Config{
Credentials: creds.creds,
Region: aws.String("us-east-1"),
})))
albName := truncateString(strings.ReplaceAll(userToResourceName(userName, "service")+os.Getenv("GEN3_ENDPOINT"), ".", "-")+"alb", 32)

getInput := &elbv2.DescribeLoadBalancersInput{
Names: []*string{aws.String(albName)},
}
result, err := svc.DescribeLoadBalancers(getInput)
if err != nil {
return err
}
if len(result.LoadBalancers) == 1 {
delInput := &elbv2.DeleteLoadBalancerInput{
LoadBalancerArn: result.LoadBalancers[0].LoadBalancerArn,
}
_, err := svc.DeleteLoadBalancer(delInput)
if err != nil {
return err
}
}
return nil
}
24 changes: 17 additions & 7 deletions hatchery/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,29 @@ type AppConfigInfo struct {

// TODO remove PayModel from config once DynamoDB contains all necessary data
type PayModel struct {
Name string `json:"name"`
User string `json:"user_id"`
AWSAccountId string `json:"aws_account_id"`
Region string `json:"region"`
Ecs string `json:"ecs"`
VpcId string `json:"vpcid"`
Subnet int `json:"subnet"`
Id string `json:"bmh_workspace_id"`
Name string `json:"workspace_type"`
User string `json:"user_id"`
AWSAccountId string `json:"account_id"`
Region string `json:"region"`
Ecs bool `json:"ecs"`
Subnet int `json:"subnet"`
HardLimit float32 `json:"hard-limit"`
SoftLimit float32 `json:"soft-limit"`
TotalUsage float32 `json:"total-usage"`
mfshao marked this conversation as resolved.
Show resolved Hide resolved
CurrentPayModel bool `json:"current_pay_model"`
}

type AllPayModels struct {
CurrentPayModel *PayModel `json:"current_pay_model"`
PayModels []PayModel `json:"all_pay_models"`
}

// HatcheryConfig is the root of all the configuration
type HatcheryConfig struct {
UserNamespace string `json:"user-namespace"`
DefaultPayModel PayModel `json:"default-pay-model"`
DisableLocalWS bool `json:"disable-local-ws"`
mfshao marked this conversation as resolved.
Show resolved Hide resolved
PayModels []PayModel `json:"pay-models"`
PayModelsDynamodbTable string `json:"pay-models-dynamodb-table"`
SubDir string `json:"sub-dir"`
Expand Down
1 change: 0 additions & 1 deletion hatchery/ec2.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ func (creds *CREDS) describeWorkspaceNetwork(userName string) (*NetworkInfo, err
}
Config.Logger.Printf("Create Security Group: %s", *newSecurityGroup.GroupId)

// TODO: Make this secure. Right now it's wide open
ingressRules := ec2.AuthorizeSecurityGroupIngressInput{
GroupId: newSecurityGroup.GroupId,
IpPermissions: []*ec2.IpPermission{
Expand Down
15 changes: 9 additions & 6 deletions hatchery/ecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ func (input *CreateTaskDefinitionInput) Environment() []*ecs.KeyValuePair {
}

// Create ECS cluster
// TODO: Evaluate if this is still this needed..
func (sess *CREDS) launchEcsCluster(userName string) (*ecs.Cluster, error) {
svc := sess.svc
clusterName := strings.ReplaceAll(os.Getenv("GEN3_ENDPOINT"), ".", "-") + "-cluster"
Expand Down Expand Up @@ -201,7 +200,7 @@ func (sess *CREDS) statusEcsWorkspace(ctx context.Context, userName string, acce
if err != nil {
return &status, err
}

// TODO: Check TransitGatewayAttachment is not in Deleting state (Can't create new one until it's deleted).
var taskDefName string
if len(service.Services) > 0 {
statusMessage = *service.Services[0].Status
Expand Down Expand Up @@ -330,7 +329,13 @@ func terminateEcsWorkspace(ctx context.Context, userName string, accessToken str
if err != nil {
return "", err
}
// TODO: Terminate ALB + target group here too

// Terminate load balancer
err = svc.terminateLoadBalancer(userName)
if err != nil {
return "", err
}

err = teardownTransitGateway(userName)
if err != nil {
return "", err
Expand All @@ -339,8 +344,6 @@ func terminateEcsWorkspace(ctx context.Context, userName string, accessToken str
}

func launchEcsWorkspace(ctx context.Context, userName string, hash string, accessToken string, payModel PayModel) error {
// TODO: Setup EBS volume as pd
// Must create volume using SDK too.. :(
roleARN := "arn:aws:iam::" + payModel.AWSAccountId + ":role/csoc_adminvm"
sess := session.Must(session.NewSession(&aws.Config{
// TODO: Make this configurable
Expand Down Expand Up @@ -486,6 +489,7 @@ func launchEcsWorkspace(ctx context.Context, userName string, hash string, acces
}
return err
}

err = setupTransitGateway(userName)
if err != nil {
return err
Expand All @@ -499,7 +503,6 @@ func launchEcsWorkspace(ctx context.Context, userName string, hash string, acces
}
return err
}

fmt.Printf("Launched ECS workspace service at %s for user %s\n", launchTask, userName)
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions hatchery/efs.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ func (creds *CREDS) createAccessPoint(FileSystemId string, userName string, svc
if err != nil {
return nil, err
}

ap := userToResourceName(userName, "service") + "-" + strings.ReplaceAll(os.Getenv("GEN3_ENDPOINT"), ".", "-") + "-accesspoint"
if len(exResult.AccessPoints) == 0 {
input := &efs.CreateAccessPointInput{
ClientToken: aws.String(fmt.Sprintf("ap-%s", userToResourceName(userName, "pod"))),
ClientToken: aws.String(ap),
FileSystemId: aws.String(FileSystemId),
PosixUser: &efs.PosixUser{
Gid: aws.Int64(100),
Expand Down
102 changes: 85 additions & 17 deletions hatchery/hatchery.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ func RegisterHatchery(mux *httptrace.ServeMux) {
mux.HandleFunc("/status", status)
mux.HandleFunc("/options", options)
mux.HandleFunc("/paymodels", paymodels)
mux.HandleFunc("/setpaymodel", setpaymodel)
mux.HandleFunc("/allpaymodels", allpaymodels)

// ECS functions
mux.HandleFunc("/create-ecs-cluster", createECSCluster)
Expand Down Expand Up @@ -55,16 +57,65 @@ func paymodels(w http.ResponseWriter, r *http.Request) {
return
}
userName := getCurrentUserName(r)
payModel, err := getPayModelForUser(userName)

payModel, err := getCurrentPayModel(userName)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if payModel == nil {
http.Error(w, err.Error(), http.StatusNotFound)
http.Error(w, "Current paymodel not set", http.StatusNotFound)
return
}
out, err := json.Marshal(payModel)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
out, err := json.Marshal(payModel)
fmt.Fprint(w, string(out))
}

func allpaymodels(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
http.Error(w, "Not Found", http.StatusNotFound)
return
}
userName := getCurrentUserName(r)

payModels, err := getPayModelsForUser(userName)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if payModels == nil {
http.Error(w, "No paymodel set", http.StatusNotFound)
mfshao marked this conversation as resolved.
Show resolved Hide resolved
return
}
out, err := json.Marshal(payModels)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
fmt.Fprint(w, string(out))
}

func setpaymodel(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "Not Found", http.StatusNotFound)
return
}
userName := getCurrentUserName(r)
id := r.URL.Query().Get("id")
if id == "" {
http.Error(w, "Missing ID argument", http.StatusBadRequest)
return
}
pm, err := setCurrentPaymodel(userName, id)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
out, err := json.Marshal(pm)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
Expand All @@ -76,19 +127,35 @@ func status(w http.ResponseWriter, r *http.Request) {
userName := getCurrentUserName(r)
accessToken := getBearerToken(r)

payModel, err := getPayModelForUser(userName)
payModel, err := getCurrentPayModel(userName)
if err != nil {
Config.Logger.Printf(err.Error())
if err != NopaymodelsError {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
var result *WorkspaceStatus
if payModel != nil && payModel.Ecs == "true" {
result, err = statusEcs(r.Context(), userName, accessToken, payModel.AWSAccountId)
} else {

if payModel == nil {
result, err = statusK8sPod(r.Context(), userName, accessToken, payModel)
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
} else {
if payModel.Ecs {
result, err = statusEcs(r.Context(), userName, accessToken, payModel.AWSAccountId)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
} else {
result, err = statusK8sPod(r.Context(), userName, accessToken, payModel)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
}
}

out, err := json.Marshal(result)
Expand Down Expand Up @@ -154,18 +221,19 @@ func launch(w http.ResponseWriter, r *http.Request) {
}

userName := getCurrentUserName(r)
payModel, err := getPayModelForUser(userName)
payModel, err := getCurrentPayModel(userName)
if err != nil {
Config.Logger.Printf(err.Error())
}
if payModel == nil {
err = createLocalK8sPod(r.Context(), hash, userName, accessToken)
} else if payModel.Ecs == "true" {
} else if payModel.Ecs {
err = launchEcsWorkspace(r.Context(), userName, hash, accessToken, *payModel)
} else {
err = createExternalK8sPod(r.Context(), hash, userName, accessToken, *payModel)
}
if err != nil {
Config.Logger.Printf("error during launch: %-v", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
Expand All @@ -179,11 +247,11 @@ func terminate(w http.ResponseWriter, r *http.Request) {
}
accessToken := getBearerToken(r)
userName := getCurrentUserName(r)
payModel, err := getPayModelForUser(userName)
payModel, err := getCurrentPayModel(userName)
if err != nil {
Config.Logger.Printf(err.Error())
}
if payModel != nil && payModel.Ecs == "true" {
if payModel != nil && payModel.Ecs {
svc, err := terminateEcsWorkspace(r.Context(), userName, accessToken, payModel.AWSAccountId)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -219,7 +287,7 @@ func getBearerToken(r *http.Request) string {
// TODO: NEED TO CALL THIS FUNCTION IF IT DOESN'T EXIST!!!
func createECSCluster(w http.ResponseWriter, r *http.Request) {
userName := getCurrentUserName(r)
payModel, err := getPayModelForUser(userName)
payModel, err := getCurrentPayModel(userName)
if payModel == nil {
http.Error(w, "Paymodel has not been setup for user", http.StatusNotFound)
return
Expand Down
10 changes: 7 additions & 3 deletions hatchery/iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ func (creds *CREDS) taskRole(userName string) (*string, error) {
Credentials: creds.creds,
Region: aws.String("us-east-1"),
})))
pm := Config.PayModelMap[userName]
pm, err := getCurrentPayModel(userName)
if err != nil {
return nil, err
}
policyArn := fmt.Sprintf("arn:aws:iam::%s:policy/%s", pm.AWSAccountId, fmt.Sprintf("ws-task-policy-%s", userName))
taskRoleInput := &iam.GetRoleInput{
RoleName: aws.String(userToResourceName(userName, "pod")),
Expand Down Expand Up @@ -96,8 +99,9 @@ func (creds *CREDS) taskRole(userName string) (*string, error) {
}

}
// https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task_execution_IAM_role.html
// The task execution role grants the Amazon ECS container and Fargate agents permission to make AWS API calls on your behalf.

// https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task_execution_IAM_role.html
// The task execution role grants the Amazon ECS container and Fargate agents permission to make AWS API calls on your behalf.
const ecsTaskExecutionRoleName = "ecsTaskExecutionRole"
const ecsTaskExecutionPolicyArn = "arn:aws:iam::aws:policy/service-role/AmazonECSTaskExecutionRolePolicy"
const ecsTaskExecutionRoleAssumeRolePolicyDocument = `{
Expand Down
Loading