From 2adec8dd31fc9d96851bf460787ce813837e6791 Mon Sep 17 00:00:00 2001 From: Flora Thiebaut Date: Thu, 16 May 2024 19:06:55 +0200 Subject: [PATCH] feat: support connected services (#1874) Add support to mount repositories from external sources in session. This requires `connected services` support from `renku-data-services`. Also, the API for starting Renku 2.0 sessions has been modified to accept external repositories. Changes in `notebooks`: * Refactored the `UserServer` class to be an abstract class. `Renku1UserServer` and `Renku2UserServer` inherit from `UserServer` and contain the corresponding adaptations needed to support Renku 1.0 and Renku 2.0 sessions respectively. * Changed the `amalthea` patches to reflect changes in `git-clone` and `git-proxy`. * Updated the `patch_statefulset_tokens()` method invoked when resuming sessions. Changes in `git-clone`: * Refactored the `git-clone` container to support cloning repositories from any source. * Cloning is done optimistically. The `git-clone` container will not crash if a repository cannot be cloned. * For private repositories, a `git provider` must be configured and will be used to clone. * The container now only uses the `renku_access_token`. Changes in `git-proxy`: * Refactored the `git-proxy` container to support injecting credentials from external services. * The `git-proxy` is a simple pass-through for anonymous sessions (meaning it should probably not run for anonymous sessions). * Repositories matching a configured `git provider` will have credentials injected: - From `renku-gateway-auth` for the internal GitLab - From `renku-data-services` for other services * The container now only uses the `renku_access_token` and the `renku_refresh_token`. --- git-https-proxy/config/config.go | 401 ++++------- git-https-proxy/go.mod | 26 +- git-https-proxy/go.sum | 74 +- git-https-proxy/main.go | 134 +--- git-https-proxy/main_test.go | 147 ---- git-https-proxy/proxy/main.go | 122 ++++ git-https-proxy/tokenstore/main.go | 233 ++++++ .../main_test.go} | 134 ++-- git_services/git_services/cli/__init__.py | 5 +- git_services/git_services/init/clone.py | 13 +- git_services/git_services/init/cloner.py | 176 +++-- git_services/git_services/init/config.py | 60 +- git_services/git_services/init/errors.py | 4 + git_services/git_services/sidecar/errors.py | 14 +- git_services/tests/test_init_clone.py | 58 +- .../api/amalthea_patches/git_proxy.py | 88 +-- .../api/amalthea_patches/git_sidecar.py | 18 +- .../api/amalthea_patches/init_containers.py | 65 +- .../amalthea_patches/inject_certificates.py | 6 +- .../api/amalthea_patches/jupyter_server.py | 10 +- renku_notebooks/api/classes/data_service.py | 85 +++ renku_notebooks/api/classes/k8s_client.py | 193 ++--- renku_notebooks/api/classes/repository.py | 58 ++ renku_notebooks/api/classes/server.py | 667 ++++++++---------- renku_notebooks/api/notebooks.py | 86 ++- renku_notebooks/api/schemas/repository.py | 13 +- renku_notebooks/api/schemas/servers_post.py | 2 +- renku_notebooks/config/__init__.py | 22 + renku_notebooks/util/kubernetes_.py | 98 ++- tests/unit/test_server_class/test_manifest.py | 29 + 30 files changed, 1745 insertions(+), 1296 deletions(-) delete mode 100644 git-https-proxy/main_test.go create mode 100644 git-https-proxy/proxy/main.go create mode 100644 git-https-proxy/tokenstore/main.go rename git-https-proxy/{config/config_test.go => tokenstore/main_test.go} (50%) create mode 100644 renku_notebooks/api/classes/repository.py diff --git a/git-https-proxy/config/config.go b/git-https-proxy/config/config.go index 0468bf1fb..8cab96974 100644 --- a/git-https-proxy/config/config.go +++ b/git-https-proxy/config/config.go @@ -1,307 +1,206 @@ -// Package config stores the configuration for a git proxy. -// It also manages and refreshes the renku and git oauth tokens it stores. package config import ( - "encoding/base64" "encoding/json" "fmt" - "log" - "net/http" "net/url" - "os" - "strconv" - "strings" - "sync" + "reflect" "time" - "github.com/golang-jwt/jwt/v4" + "github.com/mitchellh/mapstructure" + "github.com/spf13/viper" ) +type GitRepository struct { + Url string `json:"url"` + Provider string `json:"provider"` +} + +type GitProvider struct { + Id string `json:"id"` + AccessTokenUrl string `json:"access_token_url"` +} + type GitProxyConfig struct { // The port where the proxy is listening on - ProxyPort string + ProxyPort int `mapstructure:"port"` // The port (separate from the proxy) where the proxy will respond to status probes - HealthPort string + HealthPort int `mapstructure:"health_port"` // True if this is an anonymous session - AnonymousSession bool - // The Git oauth token injected in Git requests by the proxy - not guaranteed to be a JWT. - // Gitlab oauth tokens are not JWT tokens. - gitAccessToken string - // The unix epoch timestamp (in seconds) when the Git Oauth token expires - gitAccessTokenExpiresAt int64 + AnonymousSession bool `mapstructure:"anonymous_session"` // The oauth access token issued by Keycloak to a logged in Renku user - renkuAccessToken string + RenkuAccessToken string `mapstructure:"renku_access_token"` // The oauth refresh token issued by Keycloak to a logged in Renku user // It is assumed that the refresh tokens do not expire after use and can be reused. // This means that the 'Revoke Refresh Token' setting in the Renku realm in Keycloak // is not enabled. - renkuRefreshToken string + RenkuRefreshToken string `mapstructure:"renku_refresh_token"` + // The url of the renku deployment + RenkuURL *url.URL `mapstructure:"renku_url"` // The name of the Renku realm in Keycloak - renkuRealm string + RenkuRealm string `mapstructure:"renku_realm"` // The Keycloak client ID to which the access token and refresh tokens were issued to - renkuClientID string + RenkuClientID string `mapstructure:"renku_client_id"` // The client secret for the client ID - renkuClientSecret string - // The URL of the project repository for the session - RepoURL *url.URL - // The url of the renku deployment - RenkuURL *url.URL - // Used when the Git oauth token is refreshed. Ensures that the token is not refereshed - // twice at the same time. It also ensures that all other threads that need to simply - // read the token will wait until the refresh (write) is complete. - gitAccessTokenLock *sync.RWMutex - renkuAccessTokenLock *sync.RWMutex - // Safety margin for when to consider a token expired. For example if this is set to - // 30 seconds then the token is considered expired if it expires in the next 30 seconds. - expiredLeeway time.Duration - // Channel that is populated by the timer that triggers the automated renku access token refresh - refreshTicker *time.Ticker + RenkuClientSecret string `mapstructure:"renku_client_secret"` + // The git repositories to proxy + Repositories []GitRepository `mapstructure:"repositories"` + // The git providers + Providers []GitProvider `mapstructure:"providers"` + // The time interval used for refreshing renku tokens + RefreshCheckPeriodSeconds int64 `mapstructure:"refresh_check_period_seconds"` } -// Parse the environment variables used as the configuration for the proxy. -func ParseEnv() *GitProxyConfig { - var ok, anonymousSession bool - var gitOauthToken, proxyPort, healthPort, anonymousSessionStr, renkuAccessToken, renkuClientID, renkuRealm, renkuClientSecret, renkuRefreshToken, renkuURL, gitOauthTokenExpiresAtRaw, refreshCheckPeriodSeconds, repoURL string - var parsedRepoURL *url.URL - var err error - var gitOauthTokenExpiresAt int64 - if proxyPort, ok = os.LookupEnv("GIT_PROXY_PORT"); !ok { - proxyPort = "8080" - } - if healthPort, ok = os.LookupEnv("GIT_PROXY_HEALTH_PORT"); !ok { - healthPort = "8081" - } - if anonymousSessionStr, ok = os.LookupEnv("ANONYMOUS_SESSION"); !ok { - anonymousSessionStr = "true" - } - anonymousSession = anonymousSessionStr == "true" - if renkuAccessToken, ok = os.LookupEnv("RENKU_ACCESS_TOKEN"); !ok { - log.Fatal("Cannot find required 'RENKU_ACCESS_TOKEN' environment variable\n") - } - if renkuRefreshToken, ok = os.LookupEnv("RENKU_REFRESH_TOKEN"); !ok { - log.Fatal("Cannot find required 'RENKU_REFRESH_TOKEN' environment variable\n") - } - if renkuClientID, ok = os.LookupEnv("RENKU_CLIENT_ID"); !ok { - log.Fatal("Cannot find required 'RENKU_CLIENT_ID' environment variable\n") - } - if renkuClientSecret, ok = os.LookupEnv("RENKU_CLIENT_SECRET"); !ok { - log.Fatal("Cannot find required 'RENKU_CLIENT_SECRET' environment variable\n") - } - if renkuRealm, ok = os.LookupEnv("RENKU_REALM"); !ok { - log.Fatal("Cannot find required 'RENKU_REALM' environment variable\n") - } - if gitOauthToken, ok = os.LookupEnv("GITLAB_OAUTH_TOKEN"); !ok { - log.Fatal("Cannot find required 'GITLAB_OAUTH_TOKEN' environment variable\n") - } - if gitOauthTokenExpiresAtRaw, ok = os.LookupEnv("GITLAB_OAUTH_TOKEN_EXPIRES_AT"); !ok { - log.Fatal("Cannot find required 'GITLAB_OAUTH_TOKEN_EXPIRES_AT' environment variable\n") +func GetConfig() (GitProxyConfig, error) { + v := viper.New() + v.SetConfigType("env") + v.SetEnvPrefix("git_proxy") + v.AutomaticEnv() + + v.SetDefault("port", 8080) + v.SetDefault("health_port", 8081) + v.SetDefault("anonymous_session", true) + v.SetDefault("renku_access_token", "") + v.SetDefault("renku_refresh_token", "") + v.SetDefault("renku_url", nil) + v.SetDefault("renku_realm", "") + v.SetDefault("renku_client_id", "") + v.SetDefault("renku_client_secret", "") + v.SetDefault("repositories", []GitRepository{}) + v.SetDefault("providers", []GitProvider{}) + v.SetDefault("refresh_check_period_seconds", 600) + + var config GitProxyConfig + dh := viper.DecodeHook(mapstructure.ComposeDecodeHookFunc( + parseStringAsURL(), + parseJsonArray(), + parseJsonVariable(), + )) + if err := v.Unmarshal(&config, dh); err != nil { + return GitProxyConfig{}, err } - if gitOauthTokenExpiresAt, err = strconv.ParseInt(gitOauthTokenExpiresAtRaw, 10, 64); err != nil { - log.Fatalf("Cannot convert 'GITLAB_OAUTH_TOKEN_EXPIRES_AT' environment variable %s to integer\n", gitOauthTokenExpiresAtRaw) + + return config, nil +} + +func (c *GitProxyConfig) Validate() error { + //? INFO: The proxy is a pass-through for anonymous sessions, so no config is required. + if c.AnonymousSession { + return nil } - if repoURL, ok = os.LookupEnv("REPOSITORY_URL"); !ok { - log.Fatalln("Cannot find required 'REPOSITORY_URL' environment variable") + if c.RenkuAccessToken == "" { + return fmt.Errorf("the renku access token is not defined") } - parsedRepoURL, err = url.Parse(repoURL) - if err != nil { - log.Fatalf("Cannot parse 'REPOSITORY_URL': %s", err.Error()) + if c.RenkuRefreshToken == "" { + return fmt.Errorf("the renku refresh token is not defined") } - if renkuURL, ok = os.LookupEnv("RENKU_URL"); !ok { - log.Fatal("Cannot find required 'RENKU_URL' environment variable\n") + if c.RenkuURL == nil { + return fmt.Errorf("the renku URL is not defined") } - parsedRenkuURL, err := url.Parse(renkuURL) - if err != nil { - log.Fatalf("Cannot parse 'RENKU_URL' %s: %s", renkuURL, err.Error()) + if c.RenkuRealm == "" { + return fmt.Errorf("the renku realm is not defined") } - if refreshCheckPeriodSeconds, ok = os.LookupEnv("REFRESH_CHECK_PERIOD_SECONDS"); !ok { - refreshCheckPeriodSeconds = "600" + if c.RenkuClientID == "" { + return fmt.Errorf("the renku client id is not defined") } - refreshCheckPeriodSecondsParsed, err := strconv.ParseInt(refreshCheckPeriodSeconds, 10, 64) - if err != nil { - log.Fatalf("Cannot parse refresh period as integer %s: %s\n", refreshCheckPeriodSeconds, err.Error()) + if c.RenkuClientSecret == "" { + return fmt.Errorf("the renku client secret is not defined") } - config := GitProxyConfig{ - ProxyPort: proxyPort, - HealthPort: healthPort, - AnonymousSession: anonymousSession, - gitAccessToken: gitOauthToken, - gitAccessTokenExpiresAt: gitOauthTokenExpiresAt, - renkuAccessToken: renkuAccessToken, - renkuRefreshToken: renkuRefreshToken, - renkuClientID: renkuClientID, - renkuClientSecret: renkuClientSecret, - renkuRealm: renkuRealm, - RepoURL: parsedRepoURL, - RenkuURL: parsedRenkuURL, - gitAccessTokenLock: &sync.RWMutex{}, - renkuAccessTokenLock: &sync.RWMutex{}, - expiredLeeway: time.Second * time.Duration(refreshCheckPeriodSecondsParsed) * 4, - refreshTicker: time.NewTicker(time.Second * time.Duration(refreshCheckPeriodSecondsParsed)), + if c.RefreshCheckPeriodSeconds <= 0 { + return fmt.Errorf("the refresh token period is invalid") } - // Start a go routine to keep the refresh token valid - go config.periodicTokenRefresh() - return &config + return nil } -func (c *GitProxyConfig) getRenkuAccessToken() string { - c.renkuAccessTokenLock.RLock() - defer c.renkuAccessTokenLock.RUnlock() - return c.renkuAccessToken +func (c *GitProxyConfig) GetRefreshCheckPeriod() time.Duration { + return time.Duration(c.RefreshCheckPeriodSeconds) * time.Second } -// getRenkuAccessToken checks if the token is expired and if it is it will renew the token -// and return a new valid access token. If the token is valid it simply returns the access token. -func (c *GitProxyConfig) getAndRefreshRenkuAccessToken() (string, error) { - isExpired, err := c.isJWTExpired(c.getRenkuAccessToken()) - if err != nil { - return "", err - } - if isExpired { - err = c.refreshRenkuAccessToken() - if err != nil { - return "", err - } - } - return c.getRenkuAccessToken(), nil +func (c *GitProxyConfig) GetExpirationLeeway() time.Duration { + return 4 * c.GetRefreshCheckPeriod() } -// GetGitAccessToken will return a valid gitlab access token. If the token is expired -// it will call the gateway to get a new valid gitlab access token. -func (c *GitProxyConfig) GetGitAccessToken(encode bool) (string, error) { - c.gitAccessTokenLock.RLock() - accessTokenExpiresAt := c.gitAccessTokenExpiresAt - c.gitAccessTokenLock.RUnlock() - if accessTokenExpiresAt > 0 && time.Now().Unix() >= accessTokenExpiresAt-(c.expiredLeeway.Milliseconds()/1000) { - log.Println("Refreshing git token") - err := c.refreshGitAccessToken() +func parseStringAsURL() mapstructure.DecodeHookFuncType { + return func(f reflect.Type, t reflect.Type, data any) (interface{}, error) { + // Check that the data is string + if f.Kind() != reflect.String { + return data, nil + } + + // Check that the target type is our custom type + if t != reflect.TypeOf(url.URL{}) { + return data, nil + } + + // Return the parsed value + dataStr, ok := data.(string) + if !ok { + return nil, fmt.Errorf("cannot cast URL value to string") + } + if dataStr == "" { + return nil, fmt.Errorf("empty values are not allowed for URLs") + } + url, err := url.Parse(dataStr) if err != nil { - return "", err + return nil, err } + return url, nil } - c.gitAccessTokenLock.RLock() - defer c.gitAccessTokenLock.RUnlock() - if encode { - return encodeGitCredentials(c.gitAccessToken), nil - } - return c.gitAccessToken, nil } -func encodeGitCredentials(token string) string { - return base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("oauth2:%s", token))) -} +func parseJsonArray() mapstructure.DecodeHookFuncType { + return func(f reflect.Type, t reflect.Type, data any) (interface{}, error) { + // Check that the data is a string + if f.Kind() != reflect.String { + return data, nil + } -type gitTokenRefreshResponse struct { - AccessToken string `json:"access_token"` - ExpiresAt int64 `json:"expires_at"` -} + // Check that the target type is a slice + if t.Kind() != reflect.Slice { + return data, nil + } -// Exchange the keycloak access token for a gitlab access token -func (c *GitProxyConfig) refreshGitAccessToken() (err error) { - c.gitAccessTokenLock.Lock() - defer c.gitAccessTokenLock.Unlock() - req, err := http.NewRequest(http.MethodGet, c.RenkuURL.JoinPath("api/auth/gitlab/exchange").String(), nil) - if err != nil { - return - } - renkuAccessToken, err := c.getAndRefreshRenkuAccessToken() - if err != nil { - return - } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", renkuAccessToken)) - res, err := http.DefaultClient.Do(req) - if err != nil { - return - } - if res.StatusCode != 200 { - err = fmt.Errorf("cannot exchange keycloak oauth token for git token, failed with staus code: %d", res.StatusCode) - return - } - var resParsed gitTokenRefreshResponse - err = json.NewDecoder(res.Body).Decode(&resParsed) - if err != nil { - return - } - c.gitAccessToken = resParsed.AccessToken - c.gitAccessTokenExpiresAt = resParsed.ExpiresAt - return nil -} + raw := data.(string) + if raw == "" { + return nil, fmt.Errorf("cannot parse empty string as a slice") + } -type renkuTokenRefreshResponse struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` -} + var slice []json.RawMessage + if err := json.Unmarshal([]byte(raw), &slice); err != nil { + return data, nil + } -// refreshRenkuAccessToken calls keycloak with a refresh token to get a new access token -func (c *GitProxyConfig) refreshRenkuAccessToken() (err error) { - c.renkuAccessTokenLock.Lock() - defer c.renkuAccessTokenLock.Unlock() - payload := url.Values{} - payload.Add("grant_type", "refresh_token") - payload.Add("refresh_token", c.renkuRefreshToken) - body := strings.NewReader(payload.Encode()) - req, err := http.NewRequest(http.MethodPost, c.RenkuURL.JoinPath(fmt.Sprintf("auth/realms/%s/protocol/openid-connect/token", c.renkuRealm)).String(), body) - if err != nil { - return - } - req.SetBasicAuth(c.renkuClientID, c.renkuClientSecret) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - res, err := http.DefaultClient.Do(req) - if err != nil { - return - } - if res.StatusCode != 200 { - err = fmt.Errorf("cannot refresh keycloak token, failed with staus code: %d", res.StatusCode) - return - } - var resParsed renkuTokenRefreshResponse - err = json.NewDecoder(res.Body).Decode(&resParsed) - if err != nil { - return - } - c.renkuAccessToken = resParsed.AccessToken - if resParsed.RefreshToken != "" { - c.renkuRefreshToken = resParsed.RefreshToken - } - return nil -} + var value []string + for _, v := range slice { + value = append(value, string(v)) + } -// Checks if the expiry of the token has passed or is coming up soon based on a predefined threshold. -// NOTE: no signature validation is performed at all. All of the tokens in the proxy are trusted implicitly -// because they comes from trusted/controlled sources. -func (c *GitProxyConfig) isJWTExpired(token string) (isExpired bool, err error) { - parser := jwt.NewParser() - claims := jwt.RegisteredClaims{} - isExpired = true - _, _, err = parser.ParseUnverified(token, &claims) - if err != nil { - log.Printf("Cannot parse token claims, assuming token is expired: %s\n", err.Error()) - return + return value, nil } - // VerifyExpiresAt returns cmp.Before(exp) if exp is set, otherwise !req if exp is not set. - // Here we have it setup so that if the exp claim is not defined we assume the token is not expired. - // Keycloak does not set the `exp` claim on tokens that have the offline access grant - because they do not expire. - jwtIsNotExpired := claims.VerifyExpiresAt(time.Now().Add(c.expiredLeeway), false) - return !jwtIsNotExpired, nil } -// Periodically refreshes the renku acces token. Used to make sure the refresh token does not expire. -func (c *GitProxyConfig) periodicTokenRefresh() { - for { - <-c.refreshTicker.C - c.renkuAccessTokenLock.RLock() - renkuRefreshToken := c.renkuRefreshToken - c.renkuAccessTokenLock.RUnlock() - refreshTokenIsExpired, err := c.isJWTExpired(renkuRefreshToken) - if err != nil { - log.Printf("Could not check if renku refresh token is expired: %s\n", err.Error()) +func parseJsonVariable() mapstructure.DecodeHookFuncType { + return func(f reflect.Type, t reflect.Type, data any) (interface{}, error) { + // Check that the data is a string + if f.Kind() != reflect.String { + return data, nil + } + + // Check that the target type is a struct + if t.Kind() != reflect.Struct { + return data, nil } - if refreshTokenIsExpired { - log.Println("Getting a new renku refresh token from automatic checks") - err = c.refreshRenkuAccessToken() - if err != nil { - log.Printf("Could not refresh renku token: %s\n", err.Error()) - } + + raw := data.(string) + if raw == "" { + return nil, fmt.Errorf("cannot parse empty string as a struct") } + + value := reflect.New(t) + if err := json.Unmarshal([]byte(raw), value.Interface()); err != nil { + return data, nil + } + + return value.Interface(), nil } } diff --git a/git-https-proxy/go.mod b/git-https-proxy/go.mod index 3ab0e1bd3..2579f4e8e 100644 --- a/git-https-proxy/go.mod +++ b/git-https-proxy/go.mod @@ -1,15 +1,33 @@ module github.com/SwissDataScienceCenter/renku-notebooks/git-https-proxy -go 1.19 +go 1.21 require ( - github.com/elazarl/goproxy v0.0.0-20220328115640-894aeddb713e + github.com/elazarl/goproxy v0.0.0-20231117061959-7cc037d33fb5 github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/mitchellh/mapstructure v1.5.0 + github.com/spf13/viper v1.18.2 github.com/stretchr/testify v1.9.0 ) require ( - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/sagikazarmark/locafero v0.4.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.6.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 // indirect + golang.org/x/sys v0.20.0 // indirect + golang.org/x/text v0.15.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/git-https-proxy/go.sum b/git-https-proxy/go.sum index 333468caa..5d1ce12d2 100644 --- a/git-https-proxy/go.sum +++ b/git-https-proxy/go.sum @@ -1,17 +1,81 @@ -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/elazarl/goproxy v0.0.0-20220328115640-894aeddb713e h1:99KFda6F/mw8xSfceY2JEVCrYWX7l+Ms6BcO5wEct+Q= -github.com/elazarl/goproxy v0.0.0-20220328115640-894aeddb713e/go.mod h1:Ro8st/ElPeALwNFlcTpWmkr6IoMFfkjXAvTHpevnDsM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/elazarl/goproxy v0.0.0-20231117061959-7cc037d33fb5 h1:m62nsMU279qRD9PQSWD1l66kmkXzuYcnVJqL4XLeV2M= +github.com/elazarl/goproxy v0.0.0-20231117061959-7cc037d33fb5/go.mod h1:Ro8st/ElPeALwNFlcTpWmkr6IoMFfkjXAvTHpevnDsM= github.com/elazarl/goproxy/ext v0.0.0-20190711103511-473e67f1d7d2 h1:dWB6v3RcOy03t/bUadywsbyrQwCqZeNIEX6M1OtSZOM= github.com/elazarl/goproxy/ext v0.0.0-20190711103511-473e67f1d7d2/go.mod h1:gNh8nYJoAm43RfaxurUnxr+N1PwuFV3ZMl/efxlIlY8= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-charset v0.0.0-20180617210344-2471d30d28b4/go.mod h1:qgYeAmZ5ZIpBWTGllZSQnw97Dj+woV0toclVaRGI8pc= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= +github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= +github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= +github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f h1:99ci1mjWVBWwJiEKYY6jWa4d2nTQVIEhZIptnrVb1XY= +golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 h1:vr/HnozRka3pE4EsMEg1lgkXJkTFJCVUX+S/ZT6wYzM= +golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/git-https-proxy/main.go b/git-https-proxy/main.go index a496d4c91..ef49af635 100644 --- a/git-https-proxy/main.go +++ b/git-https-proxy/main.go @@ -9,138 +9,66 @@ import ( "net/url" "os" "os/signal" - "regexp" - "strings" "syscall" - "github.com/SwissDataScienceCenter/renku-notebooks/git-https-proxy/config" configLib "github.com/SwissDataScienceCenter/renku-notebooks/git-https-proxy/config" - "github.com/elazarl/goproxy" + "github.com/SwissDataScienceCenter/renku-notebooks/git-https-proxy/proxy" ) func main() { - config := configLib.ParseEnv() - // INFO: Make a channel that will receive the SIGTERM on shutdown + config, err := configLib.GetConfig() + if err != nil { + log.Fatalln(err) + } + + if err := config.Validate(); err != nil { + log.Fatalln(err) + } + + if config.AnonymousSession { + log.Println("Warning: Starting the git-proxy for an anonymous session, which is essentially useless.") + } + + //? INFO: Make a channel that will receive the SIGTERM on shutdown sigTerm := make(chan os.Signal, 1) signal.Notify(sigTerm, syscall.SIGTERM, syscall.SIGINT) ctx := context.Background() - // INFO: Setup servers - proxyHandler := getProxyHandler(config) + //? INFO: Setup servers + proxyHandler := proxy.GetProxyHandler(config) proxyServer := http.Server{ - Addr: fmt.Sprintf(":%s", config.ProxyPort), + Addr: fmt.Sprintf(":%d", config.ProxyPort), Handler: proxyHandler, } healthHandler := getHealthHandler(config) healthServer := http.Server{ - Addr: fmt.Sprintf(":%s", config.HealthPort), + Addr: fmt.Sprintf(":%d", config.HealthPort), Handler: healthHandler, } - // INFO: Run servers in the background + //? INFO: Run servers in the background go func() { - log.Printf("Health server active on port %s\n", config.HealthPort) + log.Printf("Health server active on port %d\n", config.HealthPort) log.Fatalln(healthServer.ListenAndServe()) }() go func() { - log.Printf("Git proxy active on port %s\n", config.ProxyPort) - log.Printf("Repo Url: %v, anonymous session: %v\n", config.RepoURL, config.AnonymousSession) + log.Printf("Git proxy active on port %d\n", config.ProxyPort) log.Fatalln(proxyServer.ListenAndServe()) }() - // INFO: Block until you receive sigTerm to shutdown. All of this is necessary - // because the proxy has to shut down only after all the other containers do so in case - // any other containers (i.e. session or sidecar) need git right before shutting down. + //? INFO: Block until you receive sigTerm to shutdown. All of this is necessary + //? INFO: because the proxy has to shut down only after all the other containers do so in case + //? INFO: any other containers (i.e. session or sidecar) need git right before shutting down. <-sigTerm - log.Print("SIGTERM received. Shutting down servers.\n") - err := healthServer.Shutdown(ctx) - if err != nil { - log.Fatalln(err) - } - err = proxyServer.Shutdown(ctx) - if err != nil { - log.Fatalln(err) - } -} - -// Infer port if not explicitly specified -func getPort(urlAddress *url.URL) string { - if urlAddress.Port() == "" { - if urlAddress.Scheme == "http" { - return "80" - } else if urlAddress.Scheme == "https" { - return "443" - } - } - return urlAddress.Port() -} - -// Ensure that hosts name watch with/without. I.e. -// ensure www.hostname.com matches hostname.com and vice versa -func hostsMatch(url1 *url.URL, url2 *url.URL) bool { - var err error - var url1ContainsWww, url2ContainsWww bool - wwwRegex := fmt.Sprintf("^%s", regexp.QuoteMeta("www.")) - url1ContainsWww, err = regexp.MatchString(wwwRegex, url1.Hostname()) + log.Print("SIGTERM received. Shutting down servers.\n") + err = healthServer.Shutdown(ctx) if err != nil { log.Fatalln(err) } - url2ContainsWww, err = regexp.MatchString(wwwRegex, url2.Hostname()) + err = proxyServer.Shutdown(ctx) if err != nil { log.Fatalln(err) } - if url1ContainsWww && !url2ContainsWww { - return url1.Hostname() == fmt.Sprintf("www.%s", url2.Hostname()) - } else if !url1ContainsWww && url2ContainsWww { - return fmt.Sprintf("www.%s", url1.Hostname()) == url2.Hostname() - } else { - return url1.Hostname() == url2.Hostname() - } -} - -// Return a server handler that contains the proxy that injects the Git aithorization header when -// the conditions for doing so are met. -func getProxyHandler(config *configLib.GitProxyConfig) *goproxy.ProxyHttpServer { - proxyHandler := goproxy.NewProxyHttpServer() - proxyHandler.Verbose = false - gitRepoHostWithWww := fmt.Sprintf("www.%s", config.RepoURL.Hostname()) - handlerFunc := func(r *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) { - var validGitRequest bool - validGitRequest = r.URL.Scheme == config.RepoURL.Scheme && - hostsMatch(r.URL, config.RepoURL) && - getPort(r.URL) == getPort(config.RepoURL) && - strings.HasPrefix(strings.TrimLeft(r.URL.Path, "/"), strings.TrimLeft(config.RepoURL.Path, "/")) - if config.AnonymousSession { - log.Print("Anonymous session, not adding auth headers, letting request through without adding auth headers.\n") - return r, nil - } - if !validGitRequest { - // Skip logging healthcheck requests - if r.URL.Path != "/ping" && r.URL.Path != "/ping/" { - log.Printf("The request %s does not match the git repository %s letting request through without adding auth headers\n", r.URL.String(), config.RepoURL.String()) - } - return r, nil - } - log.Printf("The request %s matches the git repository %s, adding auth headers\n", r.URL.String(), config.RepoURL.String()) - gitToken, err := config.GetGitAccessToken(true) - if err != nil { - log.Printf("The git token cannot be refreshed, returning 401, error: %s\n", err.Error()) - return r, goproxy.NewResponse(r, goproxy.ContentTypeText, 401, "The git token could not be refreshed") - } - r.Header.Set("Authorization", fmt.Sprintf("Basic %s", gitToken)) - return r, nil - } - // NOTE: We need to eavesdrop on the HTTPS connection to insert the Auth header - // we do this only for the case where the request host matches the host of the git repo - // in all other cases we leave the request alone. - proxyHandler.OnRequest(goproxy.ReqHostIs( - config.RepoURL.Hostname(), - gitRepoHostWithWww, - fmt.Sprintf("%s:443", config.RepoURL.Hostname()), - fmt.Sprintf("%s:443", gitRepoHostWithWww), - )).HandleConnect(goproxy.AlwaysMitm) - proxyHandler.OnRequest().DoFunc(handlerFunc) - return proxyHandler } // The proxy does not expose a health endpoint. Therefore the purpose of this server @@ -148,7 +76,7 @@ func getProxyHandler(config *configLib.GitProxyConfig) *goproxy.ProxyHttpServer // and running the health server will use the proxy as a proxy for the health endpoint. // This is necessary because sending any requests directly to the proxy results in a 500 // with a message that the proxy only accepts proxy requests and no direct requests. -func getHealthHandler(config *config.GitProxyConfig) *http.ServeMux { +func getHealthHandler(config configLib.GitProxyConfig) *http.ServeMux { handler := http.NewServeMux() handler.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -161,12 +89,12 @@ func getHealthHandler(config *config.GitProxyConfig) *http.ServeMux { w.Write(jsonResp) }) handler.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { - proxyUrl, err := url.Parse(fmt.Sprintf("http://localhost:%s", config.ProxyPort)) + proxyUrl, err := url.Parse(fmt.Sprintf("http://localhost:%d", config.ProxyPort)) if err != nil { log.Fatalln(err) } client := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}} - resp, err := client.Get(fmt.Sprintf("http://localhost:%s/ping", config.HealthPort)) + resp, err := client.Get(fmt.Sprintf("http://localhost:%d/ping", config.HealthPort)) if err != nil { log.Println("The GET request to /ping from within /health failed with:", err) w.WriteHeader(http.StatusBadRequest) diff --git a/git-https-proxy/main_test.go b/git-https-proxy/main_test.go deleted file mode 100644 index 60706e1c2..000000000 --- a/git-https-proxy/main_test.go +++ /dev/null @@ -1,147 +0,0 @@ -package main - -import ( - "fmt" - "io" - "log" - "net/http" - "net/http/httptest" - "net/url" - "os" - "testing" - "time" - - configLib "github.com/SwissDataScienceCenter/renku-notebooks/git-https-proxy/config" - "github.com/stretchr/testify/assert" -) - -const gitAuthToken string = "verySecretToken" -const renkuJWT string = "verySecretRenkuJWT" - -// This is a dummy server meant to mimic the final -// destionation the proxy will route a request to. The -// server just returns information about the received request -// and nothing else. Used to confirm that the proxy is properly -// routing things and injecting the right headers. -func setUpGitServer() (*url.URL, func()) { - handler := http.NewServeMux() - handlerFunc := func(w http.ResponseWriter, r *http.Request) { - var body []byte - var err error - body, err = io.ReadAll(r.Body) - if err != nil { - log.Fatalln("Cannot read body from response") - } - for name, values := range r.Header { - w.Header().Set(name, values[0]) - } - w.WriteHeader(http.StatusOK) - w.Write(body) - } - handler.HandleFunc("/", handlerFunc) - return setUpTestServer(handler) -} - -func setUpGitProxy(c *configLib.GitProxyConfig) (*url.URL, func()) { - proxyHandler := getProxyHandler(c) - return setUpTestServer(proxyHandler) -} - -func setUpTestServer(handler http.Handler) (*url.URL, func()) { - ts := httptest.NewServer(handler) - tsUrl, err := url.Parse(ts.URL) - if err != nil { - log.Fatalln(err) - } - return tsUrl, ts.Close -} - -func getTestConfig(isSessionAnonymous bool, token string, injectionURL *url.URL) *configLib.GitProxyConfig { - os.Setenv("GITLAB_OAUTH_TOKEN", token) - defer os.Unsetenv("GITLAB_OAUTH_TOKEN") - os.Setenv("GITLAB_OAUTH_TOKEN_EXPIRES_AT", fmt.Sprintf("%d", time.Now().Unix()+9999999999)) - defer os.Unsetenv("GITLAB_OAUTH_TOKEN_EXPIRES_AT") - os.Setenv("RENKU_ACCESS_TOKEN", renkuJWT) - defer os.Unsetenv("RENKU_ACCESS_TOKEN") - os.Setenv("RENKU_REFRESH_TOKEN", "renkuRefreshToken") - defer os.Unsetenv("RENKU_REFRESH_TOKEN") - os.Setenv("RENKU_URL", "https://dummy.renku.com") - defer os.Unsetenv("RENKU_URL") - os.Setenv("REPOSITORY_URL", injectionURL.String()) - defer os.Unsetenv("REPOSITORY_URL") - os.Setenv("ANONYMOUS_SESSION", fmt.Sprint(isSessionAnonymous)) - defer os.Unsetenv("ANONYMOUS_SESSION") - os.Setenv("RENKU_REALM", "Renku") - defer os.Unsetenv("RENKU_REALM") - os.Setenv("RENKU_CLIENT_ID", "RenkuClientID") - defer os.Unsetenv("RENKU_CLIENT_ID") - os.Setenv("RENKU_CLIENT_SECRET", "RenkuClientSecret") - defer os.Unsetenv("RENKU_CLIENT_SECRET") - return configLib.ParseEnv() -} - -func getTestClient(proxyUrl *url.URL) *http.Client { - return &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyUrl)}} -} - -type testEntry struct { - Url string - AuthHeader string -} - -// Ensure token is not sent when user is anonymous -func TestProxyAnonymous(t *testing.T) { - gitServerUrl, gitServerClose := setUpGitServer() - defer gitServerClose() - injectionPath := &url.URL{ - Scheme: gitServerUrl.Scheme, - Host: gitServerUrl.Host, - Path: "injection/path", - } - config := getTestConfig(true, gitAuthToken, injectionPath) - proxyServerUrl, proxyServerClose := setUpGitProxy(config) - defer proxyServerClose() - testClient := getTestClient(proxyServerUrl) - tests := []testEntry{ - {Url: gitServerUrl.String(), AuthHeader: ""}, - {Url: injectionPath.String(), AuthHeader: ""}, - } - for _, test := range tests { - resp, err := testClient.Get(test.Url) - assert.Nil(t, err) - assert.Equal(t, resp.Header.Get("Authorization"), test.AuthHeader) - } -} - -// Ensure token is sent in header only when urls match -func TestProxyRegistered(t *testing.T) { - gitServerUrl, gitServerClose := setUpGitServer() - defer gitServerClose() - injectionPath := &url.URL{ - Scheme: gitServerUrl.Scheme, - Host: gitServerUrl.Host, - Path: "injection/path", - } - config := getTestConfig(false, gitAuthToken, injectionPath) - proxyServerUrl, proxyServerClose := setUpGitProxy(config) - defer proxyServerClose() - testClient := getTestClient(proxyServerUrl) - token, err := config.GetGitAccessToken(true) - assert.Nil(t, err) - authHeaderValue := fmt.Sprintf("Basic %s", token) - tests := []testEntry{ - // Path is root and does not match repo url - {Url: gitServerUrl.String(), AuthHeader: ""}, - // Path is not root and does not match repo url - {Url: fmt.Sprintf("%s/%s", gitServerUrl.String(), "some/subpath"), AuthHeader: ""}, - // Path exactly matches repo url - {Url: injectionPath.String(), AuthHeader: authHeaderValue}, - // Path begins with repo url - {Url: fmt.Sprintf("%s/%s", injectionPath.String(), "some/subpath"), AuthHeader: authHeaderValue}, - } - for _, test := range tests { - resp, err := testClient.Get(test.Url) - assert.Nil(t, err) - assert.Equal(t, resp.Header.Get("Authorization"), test.AuthHeader) - } -} diff --git a/git-https-proxy/proxy/main.go b/git-https-proxy/proxy/main.go new file mode 100644 index 000000000..570aa78a8 --- /dev/null +++ b/git-https-proxy/proxy/main.go @@ -0,0 +1,122 @@ +package proxy + +import ( + "fmt" + "log" + "net/http" + "net/url" + "regexp" + "strings" + + configLib "github.com/SwissDataScienceCenter/renku-notebooks/git-https-proxy/config" + "github.com/SwissDataScienceCenter/renku-notebooks/git-https-proxy/tokenstore" + "github.com/elazarl/goproxy" +) + +// Returns a server handler that contains the proxy that injects the Git aithorization header when +// the conditions for doing so are met. +func GetProxyHandler(config configLib.GitProxyConfig) *goproxy.ProxyHttpServer { + proxyHandler := goproxy.NewProxyHttpServer() + proxyHandler.Verbose = false + + if config.AnonymousSession { + return proxyHandler + } + + tokenStore := tokenstore.New(&config) + + providers := make(map[string]configLib.GitProvider, len(config.Providers)) + for _, p := range config.Providers { + providers[p.Id] = p + } + + for _, repo := range config.Repositories { + repoURL, err := url.Parse(repo.Url) + if err != nil { + log.Printf("Cannot parse repository URL (%s), skipping proxy setup.", repo.Url) + continue + } + provider := repo.Provider + if provider == "" { + log.Printf("Repository (%s) has no provider, skipping proxy setup.", repo.Url) + continue + } + if _, providerExists := providers[provider]; !providerExists { + log.Printf("The provider (%s) for repository (%s) is not configured, skipping proxy setup.", provider, repo.Url) + continue + } + log.Printf("Setting up proxy for repository: %s [%s]", repo.Url, provider) + + gitRepoHostWithWww := fmt.Sprintf("www.%s", repoURL.Hostname()) + + handlerFunc := func(r *http.Request, ctx *goproxy.ProxyCtx) (*http.Request, *http.Response) { + validGitRequest := r.URL.Scheme == repoURL.Scheme && + hostsMatch(r.URL, repoURL) && + getPort(r.URL) == getPort(repoURL) && + strings.HasPrefix(strings.TrimLeft(r.URL.Path, "/"), strings.TrimLeft(repoURL.Path, "/")) + if !validGitRequest { + // Skip logging healthcheck requests + if r.URL.Path != "/ping" && r.URL.Path != "/ping/" { + log.Printf("The request %s does not match the git repository %s letting request through without adding auth headers\n", r.URL.String(), repoURL.String()) + } + return r, nil + } + log.Printf("The request %s matches the git repository %s [%s], adding auth headers\n", r.URL.String(), repoURL.String(), provider) + gitToken, err := tokenStore.GetGitAccessToken(provider, true) + if err != nil { + log.Printf("The git token cannot be refreshed, returning 401, error: %s\n", err.Error()) + return r, goproxy.NewResponse(r, goproxy.ContentTypeText, 401, "The git token could not be refreshed") + } + r.Header.Set("Authorization", fmt.Sprintf("Basic %s", gitToken)) + return r, nil + } + + conditions := goproxy.ReqHostIs( + repoURL.Hostname(), + gitRepoHostWithWww, + fmt.Sprintf("%s:443", repoURL.Hostname()), + fmt.Sprintf("%s:443", gitRepoHostWithWww), + ) + // NOTE: We need to eavesdrop on the HTTPS connection to insert the Auth header + // we do this only for the case where the request host matches the host of the git repo + // in all other cases we leave the request alone. + proxyHandler.OnRequest(conditions).HandleConnect(goproxy.AlwaysMitm) + proxyHandler.OnRequest(conditions).DoFunc(handlerFunc) + } + return proxyHandler +} + +// Ensure that hosts name match with/without www. I.e. +// ensure www.hostname.com matches hostname.com and vice versa +func hostsMatch(url1 *url.URL, url2 *url.URL) bool { + var err error + var url1ContainsWww, url2ContainsWww bool + wwwRegex := fmt.Sprintf("^%s", regexp.QuoteMeta("www.")) + url1ContainsWww, err = regexp.MatchString(wwwRegex, url1.Hostname()) + if err != nil { + log.Fatalln(err) + } + url2ContainsWww, err = regexp.MatchString(wwwRegex, url2.Hostname()) + if err != nil { + log.Fatalln(err) + } + if url1ContainsWww && !url2ContainsWww { + return url1.Hostname() == fmt.Sprintf("www.%s", url2.Hostname()) + } else if !url1ContainsWww && url2ContainsWww { + return fmt.Sprintf("www.%s", url1.Hostname()) == url2.Hostname() + } else { + return url1.Hostname() == url2.Hostname() + } +} + +// Infer port if not explicitly specified +func getPort(urlAddress *url.URL) string { + if urlAddress.Port() == "" { + if urlAddress.Scheme == "http" { + return "80" + } else if urlAddress.Scheme == "https" { + return "443" + } + } + return urlAddress.Port() +} diff --git a/git-https-proxy/tokenstore/main.go b/git-https-proxy/tokenstore/main.go new file mode 100644 index 000000000..674b30740 --- /dev/null +++ b/git-https-proxy/tokenstore/main.go @@ -0,0 +1,233 @@ +package tokenstore + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "log" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/SwissDataScienceCenter/renku-notebooks/git-https-proxy/config" + "github.com/golang-jwt/jwt/v4" +) + +type TokenSet struct { + AccessToken string + ExpiresAt int64 +} + +type TokenStore struct { + // The git proxy config + Config *config.GitProxyConfig + // The git providers + Providers map[string]config.GitProvider + // Period used to refresh renku tokens + RefreshTickerPeriod time.Duration + // Safety margin for when to consider a token expired. For example if this is set to + // 30 seconds then the token is considered expired if it expires in the next 30 seconds. + ExpirationLeeway time.Duration + + // The current renku access token + renkuAccessToken string + // The current renku refresh token + renkuRefreshToken string + // Ensures that the renku token is not refereshed + // twice at the same time. It also ensures that all other threads that need to simply + // read the token will wait until the refresh (write) is complete. + renkuAccessTokenLock *sync.RWMutex + // Channel that is populated by the timer that triggers the automated renku access token refresh + refreshTicker *time.Ticker + // The current git access tokens for each provider + gitAccessTokens map[string]TokenSet + // Ensures that the git access token are not refreshed twice at the same time. + // Note: We use one lock for all tokens for simplicity. + gitAccessTokensLock *sync.RWMutex +} + +func New(c *config.GitProxyConfig) *TokenStore { + providers := make(map[string]config.GitProvider, len(c.Providers)) + for _, p := range c.Providers { + providers[p.Id] = p + } + + store := TokenStore{ + Config: c, + Providers: providers, + RefreshTickerPeriod: c.GetRefreshCheckPeriod(), + ExpirationLeeway: c.GetExpirationLeeway(), + renkuAccessToken: c.RenkuAccessToken, + renkuRefreshToken: c.RenkuRefreshToken, + renkuAccessTokenLock: &sync.RWMutex{}, + refreshTicker: time.NewTicker(c.GetRefreshCheckPeriod()), + gitAccessTokens: make(map[string]TokenSet, len(c.Providers)), + gitAccessTokensLock: &sync.RWMutex{}, + } + // Start a go routine to keep the refresh token valid + go store.periodicTokenRefresh() + return &store +} + +// Returns a valid access token for the corresponding git provider. +// If the token is expired, a new one will be retrieved using the renku access token. +func (s *TokenStore) GetGitAccessToken(provider string, encode bool) (string, error) { + s.gitAccessTokensLock.RLock() + tokenSet, accessTokenExists := s.gitAccessTokens[provider] + accessTokenExpiresAt := tokenSet.ExpiresAt + s.gitAccessTokensLock.RUnlock() + + if !accessTokenExists || (0 < accessTokenExpiresAt && accessTokenExpiresAt < time.Now().Add(s.ExpirationLeeway).Unix()) { + log.Printf("Getting a fresh token for git provider: %s", provider) + if err := s.refreshGitAccessToken(provider); err != nil { + return "", err + } + } + + s.gitAccessTokensLock.RLock() + defer s.gitAccessTokensLock.RUnlock() + if encode { + return encodeGitCredentials(s.gitAccessTokens[provider].AccessToken), nil + } + return s.gitAccessTokens[provider].AccessToken, nil +} + +type gitTokenRefreshResponse struct { + AccessToken string `json:"access_token"` + ExpiresAt int64 `json:"expires_at"` +} + +// Exchange the renku access token for the access token of the corresponding provider +func (s *TokenStore) refreshGitAccessToken(provider string) error { + s.gitAccessTokensLock.Lock() + defer s.gitAccessTokensLock.Unlock() + + providerURL := s.Providers[provider].AccessTokenUrl + + req, err := http.NewRequest(http.MethodGet, providerURL, nil) + if err != nil { + return err + } + renkuAccessToken, err := s.getValidRenkuAccessToken() + if err != nil { + return err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", renkuAccessToken)) + res, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + if res.StatusCode != 200 { + return fmt.Errorf("cannot exchange renku token for git token, failed with staus code: %d", res.StatusCode) + } + var resParsed gitTokenRefreshResponse + err = json.NewDecoder(res.Body).Decode(&resParsed) + if err != nil { + return err + } + s.gitAccessTokens[provider] = TokenSet(resParsed) + return nil +} + +// Returns a valid renku access token. If the token is expired, the token will be refreshed first. +func (s *TokenStore) getValidRenkuAccessToken() (string, error) { + isExpired, err := s.isJWTExpired(s.getRenkuAccessToken()) + if err != nil { + return "", err + } + if isExpired { + if err = s.refreshRenkuAccessToken(); err != nil { + return "", err + } + } + return s.getRenkuAccessToken(), nil +} + +func (s *TokenStore) getRenkuAccessToken() string { + s.renkuAccessTokenLock.RLock() + defer s.renkuAccessTokenLock.RUnlock() + return s.renkuAccessToken +} + +// Checks if the expiry of the token has passed or is coming up soon based on a predefined threshold. +// NOTE: no signature validation is performed at all. All of the tokens in the proxy are trusted implicitly +// because they come from trusted/controlled sources. +func (s *TokenStore) isJWTExpired(token string) (bool, error) { + parser := jwt.NewParser() + claims := jwt.RegisteredClaims{} + if _, _, err := parser.ParseUnverified(token, &claims); err != nil { + log.Printf("Cannot parse token claims, assuming token is expired: %s\n", err.Error()) + return true, err + } + // VerifyExpiresAt returns cmp.Before(exp) if exp is set, otherwise !req if exp is not set. + // Here we have it setup so that if the exp claim is not defined we assume the token is not expired. + // Keycloak does not set the `exp` claim on tokens that have the offline access grant - because they do not expire. + jwtIsNotExpired := claims.VerifyExpiresAt(time.Now().Add(s.ExpirationLeeway), false) + return !jwtIsNotExpired, nil +} + +type renkuTokenRefreshResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` +} + +// Refreshes the renku access token. +func (s *TokenStore) refreshRenkuAccessToken() error { + s.renkuAccessTokenLock.Lock() + defer s.renkuAccessTokenLock.Unlock() + payload := url.Values{} + payload.Add("grant_type", "refresh_token") + payload.Add("refresh_token", s.renkuRefreshToken) + body := strings.NewReader(payload.Encode()) + req, err := http.NewRequest(http.MethodPost, s.Config.RenkuURL.JoinPath(fmt.Sprintf("auth/realms/%s/protocol/openid-connect/token", s.Config.RenkuRealm)).String(), body) + if err != nil { + return err + } + req.SetBasicAuth(s.Config.RenkuClientID, s.Config.RenkuClientSecret) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + res, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + if res.StatusCode != 200 { + err = fmt.Errorf("cannot refresh renku access token, failed with staus code: %d", res.StatusCode) + return err + } + var resParsed renkuTokenRefreshResponse + err = json.NewDecoder(res.Body).Decode(&resParsed) + if err != nil { + return err + } + s.renkuAccessToken = resParsed.AccessToken + if resParsed.RefreshToken != "" { + s.renkuRefreshToken = resParsed.RefreshToken + } + return nil +} + +// Periodically refreshes the renku access token. Used to make sure the refresh token does not expire. +func (s *TokenStore) periodicTokenRefresh() { + for { + <-s.refreshTicker.C + s.renkuAccessTokenLock.RLock() + renkuRefreshToken := s.renkuRefreshToken + s.renkuAccessTokenLock.RUnlock() + refreshTokenIsExpired, err := s.isJWTExpired(renkuRefreshToken) + if err != nil { + log.Printf("Could not check if renku refresh token is expired: %s\n", err.Error()) + } + if refreshTokenIsExpired { + log.Println("Getting a new renku refresh token from automatic checks") + err = s.refreshRenkuAccessToken() + if err != nil { + log.Printf("Could not refresh renku token: %s\n", err.Error()) + } + } + } +} + +func encodeGitCredentials(token string) string { + return base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("oauth2:%s", token))) +} diff --git a/git-https-proxy/config/config_test.go b/git-https-proxy/tokenstore/main_test.go similarity index 50% rename from git-https-proxy/config/config_test.go rename to git-https-proxy/tokenstore/main_test.go index 71e3d0daf..8c61e4000 100644 --- a/git-https-proxy/config/config_test.go +++ b/git-https-proxy/tokenstore/main_test.go @@ -1,45 +1,52 @@ -package config +package tokenstore import ( "encoding/base64" "encoding/json" - "fmt" "log" "net/http" "net/http/httptest" "net/url" - "os" "testing" "time" + configLib "github.com/SwissDataScienceCenter/renku-notebooks/git-https-proxy/config" "github.com/golang-jwt/jwt/v4" "github.com/stretchr/testify/assert" ) -func getTestConfig(renkuUrl string, gitAccessToken string, gitAccessTokenExpiresAt int64, renkuAccessToken string, renkuRefreshToken string, refreshPeriodSeconds string) *GitProxyConfig { - os.Setenv("GITLAB_OAUTH_TOKEN", gitAccessToken) - defer os.Unsetenv("GITLAB_OAUTH_TOKEN") - os.Setenv("GITLAB_OAUTH_TOKEN_EXPIRES_AT", fmt.Sprintf("%d", gitAccessTokenExpiresAt)) - defer os.Unsetenv("GITLAB_OAUTH_TOKEN_EXPIRES_AT") - os.Setenv("RENKU_ACCESS_TOKEN", renkuAccessToken) - defer os.Unsetenv("RENKU_ACCESS_TOKEN") - os.Setenv("RENKU_URL", renkuUrl) - defer os.Unsetenv("RENKU_URL") - os.Setenv("REPOSITORY_URL", "https://dummy.renku.com") - defer os.Unsetenv("REPOSITORY_URL") - os.Setenv("ANONYMOUS_SESSION", "false") - defer os.Unsetenv("ANONYMOUS_SESSION") - os.Setenv("RENKU_REALM", "Renku") - defer os.Unsetenv("RENKU_REALM") - os.Setenv("RENKU_REFRESH_TOKEN", renkuRefreshToken) - defer os.Unsetenv("RENKU_REFRESH_TOKEN") - os.Setenv("RENKU_CLIENT_ID", "RenkuClientID") - defer os.Unsetenv("RENKU_CLIENT_ID") - os.Setenv("RENKU_CLIENT_SECRET", "RenkuClientSecret") - defer os.Unsetenv("RENKU_CLIENT_SECRET") - os.Setenv("REFRESH_CHECK_PERIOD_SECONDS", refreshPeriodSeconds) - defer os.Unsetenv("REFRESH_CHECK_PERIOD_SECONDS") - return ParseEnv() +func getTestConfig(renkuURL string, renkuAccessToken string, renkuRefreshToken string) configLib.GitProxyConfig { + parsedRenkuURL, err := url.Parse(renkuURL) + if err != nil { + log.Fatalln(err) + } + + providers := []configLib.GitProvider{ + { + Id: "example", + AccessTokenUrl: parsedRenkuURL.JoinPath("/api/oauth2/token").String(), + }, + } + + return configLib.GitProxyConfig{ + ProxyPort: 8080, + HealthPort: 8081, + AnonymousSession: false, + RenkuAccessToken: renkuAccessToken, + RenkuRefreshToken: renkuRefreshToken, + RenkuURL: parsedRenkuURL, + RenkuRealm: "Renku", + RenkuClientID: "RenkuClientID", + RenkuClientSecret: "RenkuClientSecret", + Repositories: []configLib.GitRepository{}, + Providers: providers, + RefreshCheckPeriodSeconds: 600, + } +} + +func getTestTokenStore(renkuURL string, renkuAccessToken string, renkuRefreshToken string) *TokenStore { + config := getTestConfig(renkuURL, renkuAccessToken, renkuRefreshToken) + return New(&config) } func setUpTestServer(handler http.Handler) (*url.URL, func()) { @@ -69,7 +76,7 @@ func setUpDummyRefreshEndpoints(gitRefreshResponse *gitTokenRefreshResponse, ren } json.NewEncoder(w).Encode(renkuRefreshResponse) } - handler.HandleFunc("/api/auth/gitlab/exchange", gitHandlerFunc) + handler.HandleFunc("/api/oauth2/token", gitHandlerFunc) handler.HandleFunc("/auth/realms/Renku/protocol/openid-connect/token", renkuHandlerFunc) return setUpTestServer(handler) } @@ -88,27 +95,26 @@ func (d DummySigningMethod) Alg() string { return "none" } func getDummyAccessToken(expiresAt int64) (token string, err error) { t := jwt.New(DummySigningMethod{}) - t.Claims = &jwt.StandardClaims{ - ExpiresAt: expiresAt, + t.Claims = &jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Unix(expiresAt, 0)), } return t.SignedString(nil) } func TestSuccessfulRefresh(t *testing.T) { newGitToken := "newGitToken" - oldGitToken := "oldGitToken" - newRenkuToken, err := getDummyAccessToken(time.Now().Unix() + 3600) + newRenkuToken, err := getDummyAccessToken(time.Now().Add(time.Hour).Unix()) assert.Nil(t, err) - oldRenkuAccessToken, err := getDummyAccessToken(time.Now().Unix() - 3600) + oldRenkuAccessToken, err := getDummyAccessToken(time.Now().Add(-time.Hour).Unix()) assert.Nil(t, err) - oldRenkuRefreshToken, err := getDummyAccessToken(time.Now().Unix() + 7200) + oldRenkuRefreshToken, err := getDummyAccessToken(time.Now().Add(2 * time.Hour).Unix()) assert.Nil(t, err) gitRefreshResponse := &gitTokenRefreshResponse{ AccessToken: newGitToken, - ExpiresAt: time.Now().Unix() + 3600, + ExpiresAt: time.Now().Add(time.Hour).Unix(), } renkuRefreshResponse := &renkuTokenRefreshResponse{ - AccessToken: newRenkuToken, + AccessToken: newRenkuToken, RefreshToken: oldRenkuRefreshToken, } authServerURL, authServerClose := setUpDummyRefreshEndpoints(gitRefreshResponse, renkuRefreshResponse) @@ -116,69 +122,75 @@ func TestSuccessfulRefresh(t *testing.T) { defer authServerClose() // token refresh is needed and succeeds - config := getTestConfig(authServerURL.String(), oldGitToken, time.Now().Unix()-9999999, oldRenkuAccessToken, oldRenkuRefreshToken, "600") - gitToken, err := config.GetGitAccessToken(false) + store := getTestTokenStore(authServerURL.String(), oldRenkuAccessToken, oldRenkuRefreshToken) + gitToken, err := store.GetGitAccessToken("example", false) assert.Nil(t, err) assert.Equal(t, gitToken, newGitToken) - renkuAccessToken, err := config.getAndRefreshRenkuAccessToken() + renkuAccessToken, err := store.getValidRenkuAccessToken() assert.Nil(t, err) assert.Equal(t, renkuAccessToken, newRenkuToken) // change token in server response // assert that immediately after the refresh the token is valid and is not refreshed again gitRefreshResponse.AccessToken = "SomethingElse" - evenNewerRenkuToken, err := getDummyAccessToken(time.Now().Unix() + 7200) + evenNewerRenkuToken, err := getDummyAccessToken(time.Now().Add(2 * time.Hour).Unix()) assert.Nil(t, err) renkuRefreshResponse.AccessToken = evenNewerRenkuToken - gitToken, err = config.GetGitAccessToken(false) + gitToken, err = store.GetGitAccessToken("example", false) assert.Nil(t, err) assert.Equal(t, gitToken, newGitToken) - renkuAccessToken, err = config.getAndRefreshRenkuAccessToken() + renkuAccessToken, err = store.getValidRenkuAccessToken() assert.Nil(t, err) assert.Equal(t, renkuAccessToken, newRenkuToken) } func TestNoRefreshNeeded(t *testing.T) { - oldGitToken := "oldGitToken" - oldRenkuAccessToken, err := getDummyAccessToken(time.Now().Unix() + 3600) + newGitToken := "newGitToken" + oldRenkuAccessToken, err := getDummyAccessToken(time.Now().Add(time.Hour).Unix()) assert.Nil(t, err) - oldRenkuRefreshToken, err := getDummyAccessToken(time.Now().Unix() + 7200) + oldRenkuRefreshToken, err := getDummyAccessToken(time.Now().Add(2 * time.Hour).Unix()) assert.Nil(t, err) + gitRefreshResponse := &gitTokenRefreshResponse{ + AccessToken: newGitToken, + ExpiresAt: time.Now().Add(time.Hour).Unix(), + } // Passing nil means that if the any tokens are attempted to be refreshed errors will be returned - authServerURL, authServerClose := setUpDummyRefreshEndpoints(nil, nil) + authServerURL, authServerClose := setUpDummyRefreshEndpoints(gitRefreshResponse, nil) defer authServerClose() - config := getTestConfig(authServerURL.String(), oldGitToken, time.Now().Unix()+99999, oldRenkuAccessToken, oldRenkuRefreshToken, "600") - gitToken, err := config.GetGitAccessToken(false) + store := getTestTokenStore(authServerURL.String(), oldRenkuAccessToken, oldRenkuRefreshToken) + gitToken, err := store.GetGitAccessToken("example", false) assert.Nil(t, err) - assert.Equal(t, gitToken, oldGitToken) - renkuAccessToken, err := config.getAndRefreshRenkuAccessToken() + assert.Equal(t, newGitToken, gitToken) + renkuAccessToken, err := store.getValidRenkuAccessToken() assert.Nil(t, err) assert.Equal(t, renkuAccessToken, oldRenkuAccessToken) } func TestAutomatedRefreshTokenRenewal(t *testing.T) { - newRenkuAccessToken, err := getDummyAccessToken(time.Now().Unix() + 3600) + newRenkuAccessToken, err := getDummyAccessToken(time.Now().Add(time.Hour).Unix()) assert.Nil(t, err) - newRenkuRefreshToken, err := getDummyAccessToken(time.Now().Unix() + (3600 * 24)) + newRenkuRefreshToken, err := getDummyAccessToken(time.Now().Add(24 * time.Hour).Unix()) assert.Nil(t, err) - oldRenkuAccessToken, err := getDummyAccessToken(time.Now().Unix() - 3600) + oldRenkuAccessToken, err := getDummyAccessToken(time.Now().Add(-time.Hour).Unix()) assert.Nil(t, err) - oldRenkuRefreshToken, err := getDummyAccessToken(time.Now().Unix() + 10) + oldRenkuRefreshToken, err := getDummyAccessToken(time.Now().Add(10 * time.Second).Unix()) assert.Nil(t, err) renkuRefreshResponse := &renkuTokenRefreshResponse{ - AccessToken: newRenkuAccessToken, + AccessToken: newRenkuAccessToken, RefreshToken: newRenkuRefreshToken, } authServerURL, authServerClose := setUpDummyRefreshEndpoints(nil, renkuRefreshResponse) log.Printf("Dummy refresh server running at %s\n", authServerURL.String()) defer authServerClose() - config := getTestConfig(authServerURL.String(), "", time.Now().Unix()+3600, oldRenkuAccessToken, oldRenkuRefreshToken, "2") - assert.Equal(t, config.getRenkuAccessToken(), oldRenkuAccessToken) - assert.Equal(t, config.renkuRefreshToken, oldRenkuRefreshToken) + config := getTestConfig(authServerURL.String(), oldRenkuAccessToken, oldRenkuRefreshToken) + config.RefreshCheckPeriodSeconds = 2 + store := New(&config) + assert.Equal(t, store.getRenkuAccessToken(), oldRenkuAccessToken) + assert.Equal(t, store.renkuRefreshToken, oldRenkuRefreshToken) // Sleep to allow for automated token refresh to occur - time.Sleep(time.Second * 5) - assert.Equal(t, config.getRenkuAccessToken(), newRenkuAccessToken) - assert.Equal(t, config.renkuRefreshToken, newRenkuRefreshToken) + time.Sleep(5 * time.Second) + assert.Equal(t, store.getRenkuAccessToken(), newRenkuAccessToken) + assert.Equal(t, store.renkuRefreshToken, newRenkuRefreshToken) } diff --git a/git_services/git_services/cli/__init__.py b/git_services/git_services/cli/__init__.py index a762bf5e1..21629b960 100644 --- a/git_services/git_services/cli/__init__.py +++ b/git_services/git_services/cli/__init__.py @@ -19,7 +19,7 @@ def __init__(self, repo_directory: Path) -> None: if not self.repo_directory.exists(): raise RepoDirectoryDoesNotExistError - def _execute_command(self, *args): + def _execute_command(self, *args) -> str: # NOTE: When running in gunicorn with gevent Popen and PIPE from subprocess do not work # and the gevent equivalents have to be used if os.environ.get("RUNNING_WITH_GEVENT"): @@ -89,3 +89,6 @@ def git_clone(self, *args): def git_diff(self, *args): return self._execute_command("git", "diff", *args) + + def git_symbolic_ref(self, *args): + return self._execute_command("git", "symbolic-ref", *args) diff --git a/git_services/git_services/init/clone.py b/git_services/git_services/init/clone.py index 835ac807b..f2258bd59 100644 --- a/git_services/git_services/init/clone.py +++ b/git_services/git_services/init/clone.py @@ -1,4 +1,3 @@ -import json import sys from git_services.cli.sentry import setup_sentry @@ -13,18 +12,12 @@ config = config_from_env() setup_sentry(config.sentry) - repositories = config.repositories - if repositories: - repos = json.loads(repositories) - repository_url = repos[0]["url"] - else: - repository_url = config.repository_url - git_cloner = GitCloner( - repositories=json.loads(config.repositories) if config.repositories else [], + repositories=config.repositories, + git_providers=config.git_providers, workspace_mount_path=config.workspace_mount_path, user=config.user, lfs_auto_fetch=config.lfs_auto_fetch, - repository_url=repository_url, + is_git_proxy_enabled=config.is_git_proxy_enabled, ) git_cloner.run(storage_mounts=config.storage_mounts) diff --git a/git_services/git_services/init/cloner.py b/git_services/git_services/init/cloner.py index d480c8e00..8752d5f1a 100644 --- a/git_services/git_services/init/cloner.py +++ b/git_services/git_services/init/cloner.py @@ -1,42 +1,45 @@ import json import logging +import re from contextlib import contextmanager from dataclasses import dataclass -from datetime import datetime, timedelta from pathlib import Path from shutil import disk_usage -from time import sleep -from typing import Optional from urllib.parse import urljoin, urlparse import requests from git_services.cli import GitCLI, GitCommandError from git_services.init import errors -from git_services.init.config import User +from git_services.init.config import Provider, User +from git_services.init.config import Repository as ConfigRepo @dataclass class Repository: """Information required to clone a repository.""" - namespace: str - project: str - branch: str - commit_sha: str url: str + dirname: str absolute_path: Path - _git_cli: Optional[GitCLI] = None + provider: str | None + branch: str | None = None + commit_sha: str | None = None + _git_cli: GitCLI | None = None @classmethod - def from_dict(cls, data: dict[str, str], workspace_mount_path: Path) -> "Repository": + def from_config_repo(cls, data: ConfigRepo, workspace_mount_path: Path): + dirname = data.dirname or cls._make_dirname(data.url) + provider = data.provider + branch = data.branch + commit_sha = data.commit_sha return cls( - namespace=data["namespace"], - project=data["project"], - branch=data["branch"], - commit_sha=data["commit_sha"], - url=data["url"], - absolute_path=workspace_mount_path / data["project"], + url=data.url, + dirname=dirname, + absolute_path=workspace_mount_path / dirname, + provider=provider, + branch=branch, + commit_sha=commit_sha, ) @property @@ -57,6 +60,12 @@ def exists(self) -> bool: return False return is_inside.lower().strip() == "true" + @staticmethod + def _make_dirname(url: str) -> str: + path = urlparse(url).path + path = path.removesuffix(".git") + return path.rsplit("/", maxsplit=1).pop() + class GitCloner: remote_name = "origin" @@ -65,39 +74,25 @@ class GitCloner: def __init__( self, - repositories: list[dict[str, str]], + repositories: list[ConfigRepo], + git_providers: list[Provider], workspace_mount_path: str, user: User, - repository_url: str, lfs_auto_fetch=False, + is_git_proxy_enabled=False, ): base_path = Path(workspace_mount_path) logging.basicConfig(level=logging.INFO) self.repositories: list[Repository] = [ - Repository.from_dict(r, workspace_mount_path=base_path) for r in repositories + Repository.from_config_repo(r, workspace_mount_path=base_path) + for r in repositories ] + self.git_providers = {p.id: p for p in git_providers} self.workspace_mount_path = Path(workspace_mount_path) self.user = user - self.repository_url = repository_url self.lfs_auto_fetch = lfs_auto_fetch - self._wait_for_server() - - def _wait_for_server(self, timeout_minutes=None): - if not self.repositories: - return - start = datetime.now() - - while True: - logging.info(f"Waiting for git to become available with timeout minutes {timeout_minutes}...") - res = requests.get(self.repository_url) - if 200 <= res.status_code < 400: - logging.info("Git is available") - return - if timeout_minutes is not None: - timeout_delta = timedelta(minutes=timeout_minutes) - if datetime.now() - start > timeout_delta: - raise errors.GitServerUnavailableError - sleep(5) + self.is_git_proxy_enabled = is_git_proxy_enabled + self._access_tokens: dict[str, str | None] = dict() def _initialize_repo(self, repository: Repository): logging.info("Initializing repo") @@ -119,7 +114,9 @@ def _exclude_storages_from_git(repository: Repository, storages: list[str]): if not storages: return - with open(repository.absolute_path / ".git" / "info" / "exclude", "a") as exclude_file: + with open( + repository.absolute_path / ".git" / "info" / "exclude", "a" + ) as exclude_file: exclude_file.write("\n") for storage in storages: @@ -127,18 +124,42 @@ def _exclude_storages_from_git(repository: Repository, storages: list[str]): if repository.absolute_path not in storage_path.parents: # The storage path is not inside the repo, no need to gitignore continue - exclude_path = storage_path.relative_to(repository.absolute_path).as_posix() + exclude_path = storage_path.relative_to( + repository.absolute_path + ).as_posix() exclude_file.write(f"{exclude_path}\n") + def _get_access_token(self, provider_id: str): + if provider_id in self._access_tokens: + return self._access_tokens[provider_id] + if provider_id not in self.git_providers: + return None + + provider = self.git_providers[provider_id] + request_url = provider.access_token_url + headers = {"Authorization": f"bearer {self.user.renku_token}"} + logging.info(f"Requesting token for provider {provider_id}") + res = requests.get(request_url, headers=headers) + if res.status_code != 200: + logging.warning(f"Could not get access token for provider {provider_id}") + del self._access_tokens[provider_id] + return None + token = res.json() + logging.info(f"Got token response for {provider_id}") + self._access_tokens[provider_id] = token["access_token"] + return self._access_tokens[provider_id] + @contextmanager - def _temp_plaintext_credentials(self, repository: Repository): + def _temp_plaintext_credentials( + self, repository: Repository, git_user: str, git_access_token: str + ): # NOTE: If "lfs." is included in urljoin it does not work properly lfs_auth_setting = "lfs." + urljoin(f"{repository.url}/", "info/lfs.access") credential_loc = Path("/tmp/git-credentials") try: with open(credential_loc, "w") as f: git_host = urlparse(repository.url).netloc - f.write(f"https://oauth2:{self.user.oauth_token}@{git_host}") + f.write(f"https://{git_user}:{git_access_token}@{git_host}") # NOTE: This is required to let LFS know that it should use basic auth to pull data. # If not set LFS will try to pull data without any auth and will then set this field # automatically but the password and username will be required for every git @@ -165,7 +186,7 @@ def _temp_plaintext_credentials(self, repository: Repository): ) @staticmethod - def _get_lfs_total_size_bytes(repository) -> int: + def _get_lfs_total_size_bytes(repository: Repository) -> int: """Get the total size of all LFS files in bytes.""" try: res = repository.git_cli.git_lfs("ls-files", "--json") @@ -180,16 +201,40 @@ def _get_lfs_total_size_bytes(repository) -> int: size_bytes += f.get("size", 0) return size_bytes + @staticmethod + def _get_default_branch(repository: Repository, remote_name: str) -> str: + """Get the default branch of the repository.""" + try: + repository.git_cli.git_remote("set-head", remote_name, "--auto") + res = repository.git_cli.git_symbolic_ref( + f"refs/remotes/{remote_name}/HEAD" + ) + except GitCommandError as err: + raise errors.BranchDoesNotExistError from err + r = re.compile(r"^refs/remotes/origin/(?P.*)$") + match = r.match(res) + if match is None: + raise errors.BranchDoesNotExistError + match_dict = match.groupdict() + return match_dict["branch"] + def _clone(self, repository: Repository): - logging.info(f"Cloning branch {repository.branch}") + logging.info(f"Cloning repository {repository.dirname} from {repository.url}") if self.lfs_auto_fetch: repository.git_cli.git_lfs("install", "--local") else: repository.git_cli.git_lfs("install", "--skip-smudge", "--local") repository.git_cli.git_remote("add", self.remote_name, repository.url) - repository.git_cli.git_fetch(self.remote_name) try: - repository.git_cli.git_checkout(repository.branch) + repository.git_cli.git_fetch(self.remote_name) + except GitCommandError as err: + raise errors.GitFetchError from err + branch = repository.branch or self._get_default_branch( + repository=repository, remote_name=self.remote_name + ) + logging.info(f"Checking out branch {branch}") + try: + repository.git_cli.git_checkout(branch) except GitCommandError as err: if err.returncode != 0 or len(err.stderr) != 0: if "no space left on device" in str(err.stderr).lower(): @@ -222,13 +267,35 @@ def run_helper(self, repository: Repository, *, storage_mounts: list[str]): # will result in lost work if there is uncommitted work. logging.info("The repo already exists - exiting.") return + + # TODO: Is this something else for non-GitLab providers? + git_user = "oauth2" + git_access_token = ( + self._get_access_token(repository.provider) if repository.provider else None + ) + self._initialize_repo(repository) - if self.user.is_anonymous: - self._clone(repository) - repository.git_cli.git_reset("--hard", repository.commit_sha) - else: - with self._temp_plaintext_credentials(repository): + try: + if self.user.is_anonymous: + self._clone(repository) + if repository.commit_sha: + repository.git_cli.git_reset("--hard", repository.commit_sha) + elif git_access_token is None: self._clone(repository) + else: + with self._temp_plaintext_credentials( + repository, git_user, git_access_token + ): + self._clone(repository) + except errors.GitFetchError as err: + logging.error(msg=f"Cannot clone {repository.url}", exc_info=err) + with open( + repository.absolute_path / "ERROR", mode="w", encoding="utf-8" + ) as f: + import traceback + + traceback.print_exception(err, file=f) + return # NOTE: If the storage mount location already exists it means that the repo folder/file # or another existing file will be overwritten, so raise an error here and crash. @@ -236,13 +303,18 @@ def run_helper(self, repository: Repository, *, storage_mounts: list[str]): if Path(a_mount).exists(): raise errors.CloudStorageOverwritesExistingFilesError - logging.info(f"Excluding cloud storage from git: {storage_mounts} for {repository}") + logging.info( + f"Excluding cloud storage from git: {storage_mounts} for {repository}" + ) if storage_mounts: self._exclude_storages_from_git(repository, storage_mounts) self._setup_proxy(repository) def _setup_proxy(self, repository: Repository): + if not self.is_git_proxy_enabled: + logging.info("Skipping git proxy setup") + return logging.info(f"Setting up git proxy to {self.proxy_url}") repository.git_cli.git_config("http.proxy", self.proxy_url) repository.git_cli.git_config("http.sslVerify", "false") diff --git a/git_services/git_services/init/config.py b/git_services/git_services/init/config.py index 616ed23bc..36cf182a5 100644 --- a/git_services/git_services/init/config.py +++ b/git_services/git_services/init/config.py @@ -1,7 +1,6 @@ import shlex from dataclasses import dataclass, field from pathlib import Path -from typing import Optional, Union import dataconf @@ -14,9 +13,9 @@ class User: """Class for keep track of basic user info used in cloning a repo.""" username: str - oauth_token: Optional[str] = None - full_name: Optional[str] = None - email: Optional[str] = None + full_name: str | None = None + email: str | None = None + renku_token: str | None = None def __post_init__(self): # NOTE: Sanitize user input that is used in running git shell commands with shlex @@ -28,33 +27,56 @@ def __post_init__(self): @property def is_anonymous(self) -> bool: - return self.oauth_token is None or self.oauth_token == "" + return not self.renku_token + + +@dataclass +class Repository: + """Represents a git repository.""" + + url: str + provider: str | None = None + dirname: str | None = None + branch: str | None = None + commit_sha: str | None = None + + +@dataclass +class Provider: + """Represents a git provider.""" + + id: str + access_token_url: str @dataclass class Config: sentry: SentryConfig - repositories: str = None - workspace_mount_path: str = None - repository_url: str = None - commit_sha: str = None - branch: str = None - git_url: str = None - user: User = None - lfs_auto_fetch: Union[str, bool] = "0" - mount_path: str = "/work" + workspace_mount_path: str + mount_path: str + user: User + repositories: list[Repository] = field(default_factory=list) + git_providers: list[Provider] = field(default_factory=list) + lfs_auto_fetch: str | bool = "0" storage_mounts: list[str] = field(default_factory=list) + is_git_proxy_enabled: str | bool = "0" def __post_init__(self): - allowed_string_flags = ["0", "1"] - if self.lfs_auto_fetch not in allowed_string_flags: - raise ValueError("lfs_auto_fetch can only be a string with values '0' or '1'") - if isinstance(self.lfs_auto_fetch, str): - self.lfs_auto_fetch = self.lfs_auto_fetch == "1" + self._check_bool_flag("lfs_auto_fetch") + self._check_bool_flag("is_git_proxy_enabled") for mount in self.storage_mounts: if not Path(mount).is_absolute(): raise errors.CloudStorageMountPathNotAbsolute + def _check_bool_flag(self, attr: str): + value = getattr(self, attr) + if isinstance(value, bool): + return + allowed_string_flags = ["0", "1"] + if value not in allowed_string_flags: + raise ValueError(f"{attr} can only be a string with values '0' or '1'") + setattr(self, attr, value == "1") + def config_from_env() -> Config: return dataconf.env("GIT_CLONE_", Config) diff --git a/git_services/git_services/init/errors.py b/git_services/git_services/init/errors.py index b105ccabc..906c2fbdd 100644 --- a/git_services/git_services/init/errors.py +++ b/git_services/git_services/init/errors.py @@ -35,6 +35,10 @@ class CloudStorageMountPathNotAbsolute(GitCloneGenericError): exit_code = 207 +class GitFetchError(GitCloneGenericError): + exit_code = 208 + + def handle_exception(exc_type, exc_value, exc_traceback): # NOTE: To prevent restarts of a failing init container from producing ambiguous errors # cleanup the repo after a failure so that a restart of the container produces the same error. diff --git a/git_services/git_services/sidecar/errors.py b/git_services/git_services/sidecar/errors.py index a1101766e..b8b196a8c 100644 --- a/git_services/git_services/sidecar/errors.py +++ b/git_services/git_services/sidecar/errors.py @@ -32,14 +32,18 @@ class SidecarProgrammingError(SidecarGenericError): class JSONRPCGenericError(JSONRPCDispatchException): """Base class for all JSON RPC errors.""" - def __init__(self, code=-32603, message="Something went wrong", data=None, *args, **kwargs): + def __init__( + self, code=-32603, message="Something went wrong", data=None, *args, **kwargs + ): super().__init__(code, message, data, *args, **kwargs) class JSONRPCProgrammingError(JSONRPCDispatchException): """An error that cannot be corrected by the user the RPC server.""" - def __init__(self, code=-32000, message="Something went wrong", data=None, *args, **kwargs): + def __init__( + self, code=-32000, message="Something went wrong", data=None, *args, **kwargs + ): super().__init__(code, message, data, *args, **kwargs) @@ -75,7 +79,7 @@ def _json_rpc_errors(*args, **kwargs): message=getattr( e, "message", - f"Something went wrong running a Renku command, " f"this resulted from Renku error {type(e)}", + f"Something went wrong running a Renku command, this resulted from Renku error {type(e)}", ) ) except GitCommandError as e: @@ -86,7 +90,9 @@ def _json_rpc_errors(*args, **kwargs): except Exception as e: logging.exception(e) raise JSONRPCGenericError( - message=getattr(e, "message", f"Failed with an unexpected error of type {type(e)}") + message=getattr( + e, "message", f"Failed with an unexpected error of type {type(e)}" + ) ) return _json_rpc_errors diff --git a/git_services/tests/test_init_clone.py b/git_services/tests/test_init_clone.py index a4f634c8e..b6c7c4bb6 100644 --- a/git_services/tests/test_init_clone.py +++ b/git_services/tests/test_init_clone.py @@ -7,8 +7,7 @@ from git_services.cli import GitCLI from git_services.init import errors from git_services.init.clone import GitCloner -from git_services.init.cloner import Repository -from git_services.init.config import User +from git_services.init.config import Repository, User @pytest.fixture @@ -17,36 +16,27 @@ def test_user() -> User: username="Test.Username", full_name="Test Name", email="test.uesername@email.com", - oauth_token="TestSecretOauthToken123456", + renku_token="TestRenkuToken12345", ) @pytest.fixture -def clone_dir(tmp_path): +def clone_dir(tmp_path: Path): repo_dir = tmp_path / "clone" repo_dir.mkdir(parents=True, exist_ok=True) yield repo_dir shutil.rmtree(repo_dir, ignore_errors=True) -def test_simple_git_clone(test_user, clone_dir, mocker): +def test_simple_git_clone(test_user: User, clone_dir: str, mocker): repo_url = "https://github.com/SwissDataScienceCenter/amalthea.git" mocker.patch("git_services.init.cloner.GitCloner._temp_plaintext_credentials", autospec=True) - repositories = [ - { - "project": "my-project", - "namespace": "", - "branch": "main", - "commit_sha": "test", - "url": repo_url, - } - ] - + repositories = [Repository(url=repo_url)] cloner = GitCloner( repositories=repositories, + git_providers=[], workspace_mount_path=clone_dir, user=test_user, - repository_url=repo_url, ) assert len(os.listdir(clone_dir)) == 0 @@ -65,22 +55,13 @@ def test_lfs_size_check(test_user, clone_dir, mocker): mock_disk_usage = mocker.patch("git_services.init.cloner.disk_usage", autospec=True) mock_get_lfs_total_size_bytes.return_value = 100 mock_disk_usage.return_value = 0, 0, 10 - repositories = [ - { - "project": "my-project", - "namespace": "", - "branch": "main", - "commit_sha": "test", - "url": repo_url, - } - ] - + repositories = [Repository(url=repo_url)] cloner = GitCloner( repositories=repositories, + git_providers=[], workspace_mount_path=clone_dir, user=test_user, lfs_auto_fetch=True, - repository_url=repo_url, ) with pytest.raises(errors.NoDiskSpaceError): @@ -93,24 +74,17 @@ def test_lfs_size_check(test_user, clone_dir, mocker): ) def test_lfs_output_parse(test_user, clone_dir, mocker, lfs_lfs_files_output, expected_output): repo_url = "https://github.com" - repositories = [ - { - "project": "my-project", - "namespace": "", - "branch": "main", - "commit_sha": "test", - "url": repo_url, - } - ] - repository = Repository.from_dict(repositories[0], Path("test")) - mock_cli = mocker.MagicMock(GitCLI, autospec=True) - mock_cli.git_lfs.return_value = lfs_lfs_files_output - mocker.patch("git_services.init.cloner.Repository.git_cli", mock_cli) - + repositories = [Repository(url=repo_url)] cloner = GitCloner( repositories=repositories, + git_providers=[], workspace_mount_path=clone_dir, user=test_user, - repository_url=repo_url, ) + + repository = cloner.repositories[0] + mock_cli = mocker.MagicMock(GitCLI, autospec=True) + mock_cli.git_lfs.return_value = lfs_lfs_files_output + mocker.patch("git_services.init.cloner.Repository.git_cli", mock_cli) + assert cloner._get_lfs_total_size_bytes(repository=repository) == expected_output diff --git a/renku_notebooks/api/amalthea_patches/git_proxy.py b/renku_notebooks/api/amalthea_patches/git_proxy.py index caccb7cd6..0391ffd7c 100644 --- a/renku_notebooks/api/amalthea_patches/git_proxy.py +++ b/renku_notebooks/api/amalthea_patches/git_proxy.py @@ -1,3 +1,5 @@ +import json +from dataclasses import asdict from typing import TYPE_CHECKING from ...config import config @@ -15,6 +17,41 @@ def main(server: "UserServer"): ) patches = [] + prefix = "GIT_PROXY_" + env = [ + {"name": f"{prefix}PORT", "value": str(config.sessions.git_proxy.port)}, + {"name": f"{prefix}HEALTH_PORT", "value": str(config.sessions.git_proxy.health_port)}, + { + "name": f"{prefix}ANONYMOUS_SESSION", + "value": "true" if server.user.anonymous else "false", + }, + {"name": f"{prefix}RENKU_ACCESS_TOKEN", "value": str(server.user.access_token)}, + {"name": f"{prefix}RENKU_REFRESH_TOKEN", "value": str(server.user.refresh_token)}, + {"name": f"{prefix}RENKU_REALM", "value": config.keycloak_realm}, + { + "name": f"{prefix}RENKU_CLIENT_ID", + "value": str(config.sessions.git_proxy.renku_client_id), + }, + { + "name": f"{prefix}RENKU_CLIENT_SECRET", + "value": str(config.sessions.git_proxy.renku_client_secret), + }, + {"name": f"{prefix}RENKU_URL", "value": "https://" + config.sessions.ingress.host}, + { + "name": f"{prefix}REPOSITORIES", + "value": json.dumps([asdict(repo) for repo in server.repositories]), + }, + { + "name": f"{prefix}PROVIDERS", + "value": json.dumps( + [ + dict(id=provider.id, access_token_url=provider.access_token_url) + for provider in server.git_providers + ] + ), + }, + ] + patches.append( { "type": "application/json-patch+json", @@ -32,56 +69,7 @@ def main(server: "UserServer"): "runAsNonRoot": True, }, "name": "git-proxy", - "env": [ - { - "name": "REPOSITORY_URL", - "value": server.gl_project_url, - }, - { - "name": "GIT_PROXY_PORT", - "value": str(config.sessions.git_proxy.port), - }, - { - "name": "GIT_PROXY_HEALTH_PORT", - "value": str(config.sessions.git_proxy.health_port), - }, - { - "name": "GITLAB_OAUTH_TOKEN", - "value": str(server.user.git_token), - }, - { - "name": "GITLAB_OAUTH_TOKEN_EXPIRES_AT", - "value": str(server.user.git_token_expires_at), - }, - { - "name": "RENKU_ACCESS_TOKEN", - "value": str(server.user.access_token), - }, - { - "name": "RENKU_REFRESH_TOKEN", - "value": str(server.user.refresh_token), - }, - { - "name": "RENKU_REALM", - "value": config.keycloak_realm, - }, - { - "name": "RENKU_CLIENT_ID", - "value": str(config.sessions.git_proxy.renku_client_id), - }, - { - "name": "RENKU_CLIENT_SECRET", - "value": str(config.sessions.git_proxy.renku_client_secret), - }, - { - "name": "RENKU_URL", - "value": "https://" + config.sessions.ingress.host, - }, - { - "name": "ANONYMOUS_SESSION", - "value": "true" if server.user.anonymous else "false", - }, - ], + "env": env, "livenessProbe": { "httpGet": { "path": "/health", diff --git a/renku_notebooks/api/amalthea_patches/git_sidecar.py b/renku_notebooks/api/amalthea_patches/git_sidecar.py index 19a4de9dd..c75789f0a 100644 --- a/renku_notebooks/api/amalthea_patches/git_sidecar.py +++ b/renku_notebooks/api/amalthea_patches/git_sidecar.py @@ -13,7 +13,9 @@ def main(server: "UserServer"): if not isinstance(server.user, RegisteredUser): return [] - gl_project_path = server.gl_project_path or "." + gitlab_project = getattr(server, "gitlab_project", None) + gl_project_path = gitlab_project.path if gitlab_project else None + commit_sha = getattr(server, "commit_sha", None) patches = [ { @@ -54,7 +56,9 @@ def main(server: "UserServer"): }, { "name": "GIT_RPC_SENTRY__ENABLED", - "value": str(config.sessions.git_rpc_server.sentry.enabled).lower(), + "value": str( + config.sessions.git_rpc_server.sentry.enabled + ).lower(), }, { "name": "GIT_RPC_SENTRY__DSN", @@ -66,7 +70,9 @@ def main(server: "UserServer"): }, { "name": "GIT_RPC_SENTRY__SAMPLE_RATE", - "value": str(config.sessions.git_rpc_server.sentry.sample_rate), + "value": str( + config.sessions.git_rpc_server.sentry.sample_rate + ), }, { "name": "SENTRY_RELEASE", @@ -74,7 +80,7 @@ def main(server: "UserServer"): }, { "name": "CI_COMMIT_SHA", - "value": f"{server.commit_sha}", + "value": f"{commit_sha}", }, { "name": "RENKU_USERNAME", @@ -149,7 +155,9 @@ def main(server: "UserServer"): { "op": "add", "path": "/statefulset/spec/template/spec/containers/1/args/-", - "value": (f"--skip-auth-route=^/sessions/{server.server_name}/sidecar/health$"), + "value": ( + f"--skip-auth-route=^/sessions/{server.server_name}/sidecar/health$" + ), }, { "op": "add", diff --git a/renku_notebooks/api/amalthea_patches/init_containers.py b/renku_notebooks/api/amalthea_patches/init_containers.py index e963bef83..2f386276a 100644 --- a/renku_notebooks/api/amalthea_patches/init_containers.py +++ b/renku_notebooks/api/amalthea_patches/init_containers.py @@ -1,5 +1,6 @@ import json import os +from dataclasses import asdict from pathlib import Path from typing import TYPE_CHECKING @@ -19,51 +20,43 @@ def git_clone(server: "UserServer"): read_only_etc_certs=True, ) - gl_project_path = server.gl_project_path or "" - + prefix = "GIT_CLONE_" env = [ { - "name": "GIT_CLONE_REPOSITORIES", - "value": json.dumps(server.repositories), - }, - { - "name": "GIT_CLONE_WORKSPACE_MOUNT_PATH", + "name": f"{prefix}WORKSPACE_MOUNT_PATH", "value": server.workspace_mount_path.absolute().as_posix(), }, { - "name": "GIT_CLONE_REPOSITORY_URL", - "value": server.gl_project.http_url_to_repo if server.gl_project else None, + "name": f"{prefix}MOUNT_PATH", + "value": server.work_dir.absolute().as_posix(), }, { - "name": "GIT_CLONE_MOUNT_PATH", - "value": (server.workspace_mount_path / gl_project_path).absolute().as_posix(), - }, - { - "name": "GIT_CLONE_LFS_AUTO_FETCH", + "name": f"{prefix}LFS_AUTO_FETCH", "value": "1" if server.server_options.lfs_auto_fetch else "0", }, - {"name": "GIT_CLONE_COMMIT_SHA", "value": server.commit_sha}, - {"name": "GIT_CLONE_BRANCH", "value": server.branch}, { - "name": "GIT_CLONE_USER__USERNAME", + "name": f"{prefix}USER__USERNAME", "value": server.user.username, }, - {"name": "GIT_CLONE_GIT_URL", "value": server.user.gitlab_client.url}, - {"name": "GIT_CLONE_USER__OAUTH_TOKEN", "value": server.user.git_token}, { - "name": "GIT_CLONE_SENTRY__ENABLED", + "name": f"{prefix}USER__RENKU_TOKEN", + "value": str(server.user.access_token), + }, + {"name": f"{prefix}IS_GIT_PROXY_ENABLED", "value": "0" if server.user.anonymous else "1"}, + { + "name": f"{prefix}SENTRY__ENABLED", "value": str(config.sessions.git_clone.sentry.enabled).lower(), }, { - "name": "GIT_CLONE_SENTRY__DSN", + "name": f"{prefix}SENTRY__DSN", "value": config.sessions.git_clone.sentry.dsn, }, { - "name": "GIT_CLONE_SENTRY__ENVIRONMENT", + "name": f"{prefix}SENTRY__ENVIRONMENT", "value": config.sessions.git_clone.sentry.env, }, { - "name": "GIT_CLONE_SENTRY__SAMPLE_RATE", + "name": f"{prefix}SENTRY__SAMPLE_RATE", "value": str(config.sessions.git_clone.sentry.sample_rate), }, {"name": "SENTRY_RELEASE", "value": os.environ.get("SENTRY_RELEASE")}, @@ -78,12 +71,34 @@ def git_clone(server: "UserServer"): ] if not server.user.anonymous: env += [ - {"name": "GIT_CLONE_USER__EMAIL", "value": server.user.gitlab_user.email}, + {"name": f"{prefix}USER__EMAIL", "value": server.user.gitlab_user.email}, { - "name": "GIT_CLONE_USER__FULL_NAME", + "name": f"{prefix}USER__FULL_NAME", "value": server.user.gitlab_user.name, }, ] + + # Set up git repositories + for idx, repo in enumerate(server.repositories): + obj_env = f"{prefix}REPOSITORIES_{idx}_" + env.append( + { + "name": obj_env, + "value": json.dumps(asdict(repo)), + } + ) + + # Set up git providers + for idx, provider in enumerate(server.required_git_providers): + obj_env = f"{prefix}GIT_PROVIDERS_{idx}_" + data = dict(id=provider.id, access_token_url=provider.access_token_url) + env.append( + { + "name": obj_env, + "value": json.dumps(data), + } + ) + return [ { "type": "application/json-patch+json", diff --git a/renku_notebooks/api/amalthea_patches/inject_certificates.py b/renku_notebooks/api/amalthea_patches/inject_certificates.py index 544c11080..67e9c83d9 100644 --- a/renku_notebooks/api/amalthea_patches/inject_certificates.py +++ b/renku_notebooks/api/amalthea_patches/inject_certificates.py @@ -20,14 +20,16 @@ def proxy(server: "UserServer"): "patch": [ { "op": "add", - "path": ("/statefulset/spec/template/spec/containers/1/volumeMounts/-"), + "path": ( + "/statefulset/spec/template/spec/containers/1/volumeMounts/-" + ), "value": volume_mount, } for volume_mount in etc_cert_volume_mounts ], }, ] - if isinstance(server._user, RegisteredUser): + if isinstance(server.user, RegisteredUser): patches.append( { "type": "application/json-patch+json", diff --git a/renku_notebooks/api/amalthea_patches/jupyter_server.py b/renku_notebooks/api/amalthea_patches/jupyter_server.py index b34740e10..94c7db443 100644 --- a/renku_notebooks/api/amalthea_patches/jupyter_server.py +++ b/renku_notebooks/api/amalthea_patches/jupyter_server.py @@ -14,7 +14,9 @@ def env(server: "UserServer"): # amalthea always makes the jupyter server the first container in the statefulset - gl_project_path = server.gl_project_path or "" + + commit_sha = getattr(server, "commit_sha", None) + project = getattr(server, "project", None) patch_list = [ { @@ -28,7 +30,7 @@ def env(server: "UserServer"): { "op": "add", "path": "/statefulset/spec/template/spec/containers/0/env/-", - "value": {"name": "CI_COMMIT_SHA", "value": server.commit_sha}, + "value": {"name": "CI_COMMIT_SHA", "value": commit_sha}, }, { "op": "add", @@ -45,7 +47,7 @@ def env(server: "UserServer"): # relative to $HOME. "value": { "name": "MOUNT_PATH", - "value": f"/work/{gl_project_path}", + "value": server.work_dir.absolute().as_posix(), }, }, { @@ -56,7 +58,7 @@ def env(server: "UserServer"): { "op": "add", "path": "/statefulset/spec/template/spec/containers/0/env/-", - "value": {"name": "PROJECT_NAME", "value": server.project}, + "value": {"name": "PROJECT_NAME", "value": project}, }, { "op": "add", diff --git a/renku_notebooks/api/classes/data_service.py b/renku_notebooks/api/classes/data_service.py index 912976e4a..de09b68ce 100644 --- a/renku_notebooks/api/classes/data_service.py +++ b/renku_notebooks/api/classes/data_service.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field from typing import Any, NamedTuple, Optional +from urllib.parse import urljoin, urlparse import requests from flask import current_app @@ -14,6 +15,7 @@ ) from ..schemas.server_options import ServerOptions +from .repository import INTERNAL_GITLAB_PROVIDER, GitProvider, OAuth2Connection, OAuth2Provider from .user import User @@ -225,3 +227,86 @@ def get_default_class(self) -> dict[str, Any]: def find_acceptable_class(self, *args, **kwargs) -> Optional[ServerOptions]: return self.options + + +@dataclass +class GitProviderHelper: + """Calls to the data service to configure git providers.""" + + service_url: str + renku_url: str + internal_gitlab_url: str + + def __post_init__(self): + self.service_url = self.service_url.rstrip("/") + self.renku_url = self.renku_url.rstrip("/") + + def get_providers(self, user: User) -> list[GitProvider]: + connections = self.get_oauth2_connections(user=user) + providers: dict[str, GitProvider] = dict() + for c in connections: + if c.provider_id in providers: + continue + provider = self.get_oauth2_provider(c.provider_id) + access_token_url = urljoin( + self.renku_url, + urlparse(f"{self.service_url}/oauth2/connections/{c.id}/token").path, + ) + providers[c.provider_id] = GitProvider( + id=c.provider_id, + url=provider.url, + connection_id=c.id, + access_token_url=access_token_url, + ) + + providers_list = list(providers.values()) + # Insert the internal GitLab as the first provider + internal_gitlab_access_token_url = urljoin( + self.renku_url, "/api/auth/gitlab/exchange" + ) + providers_list.insert( + 0, + GitProvider( + id=INTERNAL_GITLAB_PROVIDER, + url=self.internal_gitlab_url, + connection_id="", + access_token_url=internal_gitlab_access_token_url, + ), + ) + return providers_list + + def get_oauth2_connections( + self, user: User | None = None + ) -> list[OAuth2Connection]: + if user is None or user.access_token is None: + return [] + request_url = f"{self.service_url}/oauth2/connections" + headers = {"Authorization": f"bearer {user.access_token}"} + res = requests.get(request_url, headers=headers) + if res.status_code != 200: + raise IntermittentError( + message="The data service sent an unexpected response, please try again later" + ) + connections = res.json() + connections = [ + OAuth2Connection.from_dict(c) + for c in connections + if c["status"] == "connected" + ] + return connections + + def get_oauth2_provider(self, provider_id: str) -> OAuth2Provider: + request_url = f"{self.service_url}/oauth2/providers/{provider_id}" + res = requests.get(request_url) + if res.status_code != 200: + raise IntermittentError( + message="The data service sent an unexpected response, please try again later" + ) + provider = res.json() + return OAuth2Provider.from_dict(provider) + + +@dataclass +class DummyGitProviderHelper: + def get_providers(self, *args, **kwargs) -> list[GitProvider]: + return [] diff --git a/renku_notebooks/api/classes/k8s_client.py b/renku_notebooks/api/classes/k8s_client.py index a750e34d1..bbdfc631b 100644 --- a/renku_notebooks/api/classes/k8s_client.py +++ b/renku_notebooks/api/classes/k8s_client.py @@ -9,14 +9,10 @@ import requests from kubernetes import client from kubernetes.client.exceptions import ApiException -from kubernetes.client.models import V1DeleteOptions +from kubernetes.client.models import V1Container, V1DeleteOptions from kubernetes.config import load_config from kubernetes.config.config_exception import ConfigException -from kubernetes.config.incluster_config import ( - SERVICE_CERT_FILENAME, - SERVICE_TOKEN_FILENAME, - InClusterConfigLoader, -) +from kubernetes.config.incluster_config import SERVICE_CERT_FILENAME, SERVICE_TOKEN_FILENAME, InClusterConfigLoader from ...errors.intermittent import ( CannotStartServerError, @@ -250,10 +246,10 @@ def patch_image_pull_secret(self, server_name: str, gitlab_token: GitlabToken): patch, ) - def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens, gitlab_token: GitlabToken): + def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens): """Patch the Renku and Gitlab access tokens that are used in the session statefulset.""" try: - ss = self._apps_v1.read_namespaced_stateful_set(name, self.namespace) + sts = self._apps_v1.read_namespaced_stateful_set(name, self.namespace) except ApiException as err: if err.status == 404: # NOTE: It can happen potentially that another request or something else @@ -261,90 +257,113 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens, gitlab_ # the missing statefulset return raise - if len(ss.spec.template.spec.containers) < 3 or len(ss.spec.template.spec.init_containers) < 3: - raise ProgrammingError( - "The expected setup for a session was not found when trying to inject new tokens", - detail="Please contact a Renku administrator.", + + containers: list[V1Container] = sts.spec.template.spec.containers + init_containers: list[V1Container] = sts.spec.template.spec.init_containers + + git_proxy_container_index, git_proxy_container = next( + ((i, c) for i, c in enumerate(containers) if c.name == "git-proxy"), + (None, None), + ) + git_clone_container_index, git_clone_container = next( + ((i, c) for i, c in enumerate(init_containers) if c.name == "git-proxy"), + (None, None), + ) + secrets_container_index, secrets_container = next( + ( + (i, c) + for i, c in enumerate(init_containers) + if c.name == "init-user-secrets" + ), + (None, None), + ) + + git_proxy_renku_access_token_env = ( + find_env_var(git_proxy_container, "GIT_PROXY_RENKU_ACCESS_TOKEN") + if git_proxy_container is not None + else None + ) + git_proxy_renku_refresh_token_env = ( + find_env_var(git_proxy_container, "GIT_PROXY_RENKU_REFRESH_TOKEN") + if git_proxy_container is not None + else None + ) + git_clone_renku_access_token_env = ( + find_env_var(git_clone_container, "GIT_CLONE_USER__RENKU_TOKEN") + if git_clone_container is not None + else None + ) + secrets_renku_access_token_env = ( + find_env_var(secrets_container, "RENKU_ACCESS_TOKEN") + if secrets_container is not None + else None + ) + + patches = list() + if ( + git_proxy_container_index is not None + and git_proxy_renku_access_token_env is not None + ): + patches.append( + { + "op": "replace", + "path": ( + f"/spec/template/spec/containers/{git_proxy_container_index}" + f"/env/{git_proxy_renku_access_token_env[0]}/value" + ), + "value": renku_tokens.access_token, + } ) - git_proxy_container_index = 2 - git_proxy_container = ss.spec.template.spec.containers[git_proxy_container_index] - secrets_init_container_index = 0 - secrets_init_container = ss.spec.template.spec.init_containers[secrets_init_container_index] - git_init_container_index = 3 - git_init_container = ss.spec.template.spec.init_containers[git_init_container_index] - patch = [] - expires_at_env = find_env_var(git_proxy_container, "GITLAB_OAUTH_TOKEN_EXPIRES_AT") - gitlab_token_env = find_env_var(git_proxy_container, "GITLAB_OAUTH_TOKEN") - git_init_token_env = find_env_var(git_init_container, "GIT_CLONE_USER__OAUTH_TOKEN") - secrets_access_token_env = find_env_var(secrets_init_container, "RENKU_ACCESS_TOKEN") - renku_access_token_env = find_env_var(git_proxy_container, "RENKU_ACCESS_TOKEN") - renku_refresh_token_env = find_env_var(git_proxy_container, "RENKU_REFRESH_TOKEN") - if not all( - [ - expires_at_env, - gitlab_token_env, - git_init_token_env, - secrets_access_token_env, - renku_access_token_env, - renku_refresh_token_env, - ] + if ( + git_proxy_container_index is not None + and git_proxy_renku_refresh_token_env is not None ): - raise ProgrammingError( - "The expected environment variables were not found when trying to inject new tokens.", - detail="Please contact a Renku administrator.", + patches.append( + { + "op": "replace", + "path": ( + f"/spec/template/spec/containers/{git_proxy_container_index}" + f"/env/{git_proxy_renku_refresh_token_env[0]}/value" + ), + "value": renku_tokens.refresh_token, + }, ) - patch = [ - { - "op": "replace", - "path": ( - f"/spec/template/spec/containers/{git_proxy_container_index}" f"/env/{expires_at_env[0]}/value" - ), - "value": str(gitlab_token.expires_at), - }, - { - "op": "replace", - "path": ( - f"/spec/template/spec/containers/{git_proxy_container_index}" f"/env/{gitlab_token_env[0]}/value" - ), - "value": gitlab_token.access_token, - }, - { - "op": "replace", - "path": ( - f"/spec/template/spec/initContainers/{git_init_container_index}" - f"/env/{git_init_token_env[0]}/value" - ), - "value": gitlab_token.access_token, - }, - { - "op": "replace", - "path": ( - f"/spec/template/spec/initContainers/{secrets_init_container_index}" - f"/env/{secrets_access_token_env[0]}/value" - ), - "value": renku_tokens.access_token, - }, - { - "op": "replace", - "path": ( - f"/spec/template/spec/containers/{git_proxy_container_index}" - f"/env/{renku_access_token_env[0]}/value" - ), - "value": renku_tokens.access_token, - }, - { - "op": "replace", - "path": ( - f"/spec/template/spec/containers/{git_proxy_container_index}" - f"/env/{renku_refresh_token_env[0]}/value" - ), - "value": renku_tokens.refresh_token, - }, - ] + if ( + git_clone_container_index is not None + and git_clone_renku_access_token_env is not None + ): + patches.append( + { + "op": "replace", + "path": ( + f"/spec/template/spec/containers/{git_clone_container_index}" + f"/env/{git_clone_renku_access_token_env[0]}/value" + ), + "value": renku_tokens.access_token, + }, + ) + if ( + secrets_container_index is not None + and secrets_renku_access_token_env is not None + ): + patches.append( + { + "op": "replace", + "path": ( + f"/spec/template/spec/containers/{secrets_container_index}" + f"/env/{secrets_renku_access_token_env[0]}/value" + ), + "value": renku_tokens.access_token, + }, + ) + + if not patches: + return + self._apps_v1.patch_namespaced_stateful_set( name, self.namespace, - patch, + patches, ) @@ -517,7 +536,7 @@ def delete_server(self, server_name: str, safe_username: str, forced: bool = Fal def patch_tokens(self, server_name, renku_tokens: RenkuTokens, gitlab_token: GitlabToken): """Patch the Renku and Gitlab access tokens used in a session.""" client = self.session_ns_client if self.session_ns_client else self.renku_ns_client - client.patch_statefulset_tokens(server_name, renku_tokens, gitlab_token) + client.patch_statefulset_tokens(server_name, renku_tokens) client.patch_image_pull_secret(server_name, gitlab_token) @property diff --git a/renku_notebooks/api/classes/repository.py b/renku_notebooks/api/classes/repository.py new file mode 100644 index 000000000..ad96748df --- /dev/null +++ b/renku_notebooks/api/classes/repository.py @@ -0,0 +1,58 @@ +from dataclasses import dataclass + +INTERNAL_GITLAB_PROVIDER = "INTERNAL_GITLAB" + + +@dataclass +class Repository: + """Information required to clone a git repository.""" + + url: str + provider: str | None = None + dirname: str | None = None + branch: str | None = None + commit_sha: str | None = None + + @classmethod + def from_dict(cls, data: dict[str, str]): + return cls( + url=data["url"], + dirname=data.get("dirname"), + branch=data.get("branch"), + commit_sha=data.get("commit_sha"), + ) + + +@dataclass +class GitProvider: + """A fully-configured git provider.""" + + id: str + url: str + connection_id: str + access_token_url: str + + +@dataclass +class OAuth2Provider: + """An OAuth2 provider.""" + + id: str + url: str + + @classmethod + def from_dict(cls, data: dict[str, str]): + return cls(id=data["id"], url=data["url"]) + + +@dataclass +class OAuth2Connection: + """An OAuth2 connection.""" + + id: str + provider_id: str + status: str + + @classmethod + def from_dict(cls, data: dict[str, str]): + return cls(id=data["id"], provider_id=data["provider_id"], status=data["status"]) diff --git a/renku_notebooks/api/classes/server.py b/renku_notebooks/api/classes/server.py index 020e02dbc..6243b89be 100644 --- a/renku_notebooks/api/classes/server.py +++ b/renku_notebooks/api/classes/server.py @@ -1,8 +1,7 @@ -from dataclasses import asdict, dataclass -from functools import lru_cache +from abc import ABC from itertools import chain from pathlib import Path -from typing import Any, Optional, Union +from typing import Any from urllib.parse import urljoin, urlparse from flask import current_app @@ -10,7 +9,6 @@ from ...config import config from ...errors.programming import ConfigurationError, DuplicateEnvironmentVariableError from ...errors.user import MissingResourceError -from ...util.kubernetes_ import make_server_name from ..amalthea_patches import cloudstorage as cloudstorage_patches from ..amalthea_patches import general as general_patches from ..amalthea_patches import git_proxy as git_proxy_patches @@ -23,71 +21,43 @@ from ..schemas.server_options import ServerOptions from .cloud_storage import ICloudStorageRequest from .k8s_client import K8sClient +from .repository import GitProvider, Repository from .user import AnonymousUser, RegisteredUser -@dataclass -class Repository: - """Information required to clone a repository.""" - - namespace: str - project: str - branch: str - commit_sha: str - url: Optional[str] = None - - @classmethod - def from_schema(cls, data: dict[str, str]) -> "Repository": - return cls( - namespace=data["namespace"], - project=data["project"], - branch=data["branch"], - commit_sha=data["commit_sha"], - ) - - -class UserServer: - """Represents a jupyter server session.""" +class UserServer(ABC): + """Represents a Renku server session.""" def __init__( self, - user: Union[AnonymousUser, RegisteredUser], - namespace: Optional[str], - project: Optional[str], - branch: Optional[str], - commit_sha: Optional[str], - notebook: Optional[str], # TODO: Is this value actually needed? - image: Optional[str], + user: AnonymousUser | RegisteredUser, + server_name: str, + image: str | None, server_options: ServerOptions, environment_variables: dict[str, str], - user_secrets: Optional[K8sUserSecrets], + user_secrets: K8sUserSecrets | None, cloudstorage: list[ICloudStorageRequest], k8s_client: K8sClient, workspace_mount_path: Path, work_dir: Path, using_default_image: bool = False, is_image_private: bool = False, + repositories: list[Repository] = [], **_, ): self._check_flask_config() self._user = user + self.server_name = server_name self._k8s_client: K8sClient = k8s_client - self.safe_username = self._user.safe_username # type:ignore - self.namespace = namespace - self.project = project - self.branch = branch - self.commit_sha = commit_sha - self.notebook = notebook + self.safe_username = self._user.safe_username self.image = image self.server_options = server_options self.environment_variables = environment_variables self.user_secrets = user_secrets self.using_default_image = using_default_image - self.git_host = urlparse(config.git.url).netloc self.workspace_mount_path = workspace_mount_path self.work_dir = work_dir - self.cloudstorage: Optional[list[ICloudStorageRequest]] = cloudstorage - self.gl_project_name = f"{self.namespace}/{self.project}" + self.cloudstorage: list[ICloudStorageRequest] | None = cloudstorage self.is_image_private = is_image_private self.idle_seconds_threshold: int = ( config.sessions.culling.registered.idle_seconds @@ -99,143 +69,132 @@ def __init__( if isinstance(user, RegisteredUser) else config.sessions.culling.anonymous.hibernated_seconds ) - self._repositories: Optional[list[Repository]] = None - - @staticmethod - def _check_flask_config(): - """Check the app config and ensure minimum required parameters are present.""" - if config.git.url is None: - raise ConfigurationError( - message="The gitlab URL is missing, it must be provided in an environment variable called GITLAB_URL" - ) - if config.git.registry is None: - raise ConfigurationError( - message="The url to the docker image registry is missing, it must be provided in " - "an environment variable called IMAGE_REGISTRY" - ) - - @property - def gl_project(self): - return self._user.get_renku_project(self.gl_project_name) + self._repositories: list[Repository] = repositories + self._git_providers: list[GitProvider] | None = None + self._has_configured_git_providers = False @property - def gl_project_path(self) -> Optional[str]: - # NOTE: This is case sensitive and will reflect exact lower/uppercase combination of letters - # that is used in Gitlab for the path of the project. - gl_project = self.gl_project - return gl_project.path if gl_project else None + def user(self) -> AnonymousUser | RegisteredUser: + """Getter for server's user.""" + return self._user @property - def gl_project_url(self) -> Optional[str]: - gl_project = self.gl_project - return gl_project.http_url_to_repo if gl_project else None + def k8s_client(self) -> K8sClient: + """Return server's k8s client.""" + return self._k8s_client @property - def repositories(self) -> list[dict[str, str]]: - if self._repositories is None: - self._repositories = [ - asdict( - Repository( - namespace=self.namespace, - project=self.gl_project_path, - branch=self.branch, - commit_sha=self.commit_sha, - url=self.gl_project_url, - ) - ) - ] + def repositories(self) -> list[Repository]: + # Configure git repository providers based on matching URLs. + if not self._has_configured_git_providers: + for repo in self._repositories: + found_provider = None + for provider in self.git_providers: + if urlparse(provider.url).netloc == urlparse(repo.url).netloc: + found_provider = provider + break + if found_provider is not None: + repo.provider = found_provider.id + self._has_configured_git_providers = True return self._repositories @property - def server_name(self): - """Make the name that is used to identify a unique user session.""" - return make_server_name( - self._user.safe_username, - self.namespace, - self.project, - self.branch, - self.commit_sha, + def server_url(self) -> str: + """The URL where a user can access their session.""" + if type(self._user) is RegisteredUser: + return urljoin( + f"https://{config.sessions.ingress.host}", + f"sessions/{self.server_name}", + ) + return urljoin( + f"https://{config.sessions.ingress.host}", + f"sessions/{self.server_name}?token={self._user.username}", ) @property - def user(self) -> Union[AnonymousUser, RegisteredUser]: - """Getter for server's user.""" - return self._user + def git_providers(self) -> list[GitProvider]: + """The list of git providers.""" + if self._git_providers is None: + self._git_providers = config.git_provider_helper.get_providers( + user=self.user + ) + return self._git_providers @property - def user_is_anonymous(self) -> bool: - """Return True if server's user is not registered.""" - return isinstance(self._user, AnonymousUser) + def required_git_providers(self) -> list[GitProvider]: + """The list of required git providers.""" + required_provider_ids: set[str] = set( + r.provider for r in self.repositories if r.provider + ) + return [p for p in self.git_providers if p.id in required_provider_ids] - @property - def k8s_client(self) -> K8sClient: - """Return server's k8s client.""" - return self._k8s_client + def __str__(self): + return ( + f"" + ) - @property - @lru_cache(maxsize=8) - def hibernation_allowed(self): - return self._user and not self.user_is_anonymous + def start(self) -> dict[str, Any] | None: + """Create the jupyterserver resource in k8s.""" + errors = self._get_start_errors() + if errors: + raise MissingResourceError( + message=( + "Cannot start the session because the following Git " + f"or Docker resources are missing: {', '.join(errors)}" + ) + ) + return self._k8s_client.create_server( + self._get_session_manifest(), self.safe_username + ) - def _branch_exists(self): - """Check if a specific branch exists in the user's gitlab project. + @staticmethod + def _check_flask_config(): + """Check the app config and ensure minimum required parameters are present.""" + if config.git.url is None: + raise ConfigurationError( + message="The gitlab URL is missing, it must be provided in an environment variable called GITLAB_URL" + ) + if config.git.registry is None: + raise ConfigurationError( + message="The url to the docker image registry is missing, it must be provided in " + "an environment variable called IMAGE_REGISTRY" + ) - The branch name is not required by the API and therefore - passing None to this function will return True. + @staticmethod + def _check_environment_variables_overrides(patches_list: list[dict[str, Any]]): + """Check if any patch overrides server's environment variables. + + Checks if it overrides with a different value or if two patches create environment variables with different + values. """ - if self.branch is not None and self.gl_project is not None: - try: - self.gl_project.branches.get(self.branch) - except Exception as err: - current_app.logger.warning(f"Branch {self.branch} cannot be verified or does not exist. {err}") - else: - return True - return False + env_vars = {} - def _commit_sha_exists(self): - """Check if a specific commit sha exists in the user's gitlab project.""" - if self.commit_sha is not None and self.gl_project is not None: - try: - self.gl_project.commits.get(self.commit_sha) - except Exception as err: - current_app.logger.warning(f"Commit {self.commit_sha} cannot be verified or does not exist. {err}") - else: - return True - return False + for patch_list in patches_list: + patches = patch_list["patch"] - def _get_patches(self): - return list( - chain( - general_patches.test(self), - general_patches.session_tolerations(self), - general_patches.session_affinity(self), - general_patches.session_node_selector(self), - general_patches.priority_class(self), - general_patches.dev_shm(self), - jupyter_server_patches.args(), - jupyter_server_patches.env(self), - jupyter_server_patches.image_pull_secret(self), - jupyter_server_patches.disable_service_links(), - jupyter_server_patches.rstudio_env_variables(self), - jupyter_server_patches.user_secrets(self), - git_proxy_patches.main(self), - git_sidecar_patches.main(self), - general_patches.oidc_unverified_email(self), - ssh_patches.main(), - # init container for certs must come before all other init containers - # so that it runs first before all other init containers - init_containers_patches.certificates(), - init_containers_patches.download_image(self), - init_containers_patches.git_clone(self), - inject_certificates_patches.proxy(self), - # Cloud Storage needs to patch the git clone sidecar spec and so should come after - # the sidecars - # WARN: this patch depends on the index of the sidecar and so needs to be updated - # if sidercars are added or removed - cloudstorage_patches.main(self), - ) - ) + for patch in patches: + path = patch["path"].lower() + if path.endswith("/env/-"): + name = patch["value"]["name"] + value = patch["value"]["value"] + key = (path, name) + + if key in env_vars and env_vars[key] != value: + raise DuplicateEnvironmentVariableError( + message=f"Environment variable {path}::{name} is being overridden by " + "multiple patches" + ) + else: + env_vars[key] = value + + def _get_start_errors(self) -> list[str]: + """Check if there are any errors before starting the server.""" + errors: list[str] + errors = [] + if self.image is None: + errors.append(f"image {self.image} does not exist or cannot be accessed") + return errors def _get_session_manifest(self): """Compose the body of the user session for the k8s operator.""" @@ -321,86 +280,74 @@ def _get_session_manifest(self): return manifest @staticmethod - def _check_environment_variables_overrides(patches_list): - """Check if any patch overrides server's environment variables. - - Checks if it overrides with a different value or if two patches create environment variables with different - values. - """ - env_vars = {} - - for patch_list in patches_list: - patches = patch_list["patch"] - - for patch in patches: - path = patch["path"].lower() - if path.endswith("/env/-"): - name = patch["value"]["name"] - value = patch["value"]["value"] - key = (path, name) - - if key in env_vars and env_vars[key] != value: - raise DuplicateEnvironmentVariableError( - message=f"Environment variable {path}::{name} is being overridden by multiple patches" - ) - else: - env_vars[key] = value + def _get_renku_annotation_prefix() -> str: + return config.session_get_endpoint_annotations.renku_annotation_prefix - def start(self) -> Optional[dict[str, Any]]: - """Create the jupyterserver resource in k8s.""" - error = [] - if self.gl_project is None: - error.append(f"project {self.project} does not exist") - if not self._branch_exists(): - error.append(f"branch {self.branch} does not exist") - if not self._commit_sha_exists(): - error.append(f"commit {self.commit_sha} does not exist") - if self.image is None: - error.append(f"image {self.image} does not exist or cannot be accessed") - if len(error) == 0: - js = self._k8s_client.create_server(self._get_session_manifest(), self.safe_username) - else: - raise MissingResourceError( - message=( - "Cannot start the session because the following Git " - f"or Docker resources are missing: {', '.join(error)}" - ) - ) - return js + def _get_patches(self) -> list[dict[str, Any]]: + has_repository = bool(self.repositories) - @property - def server_url(self) -> str: - """The URL where a user can access their session.""" - if isinstance(self._user, RegisteredUser): - return urljoin( - "https://" + config.sessions.ingress.host, - f"sessions/{self.server_name}", - ) - else: - return urljoin( - "https://" + config.sessions.ingress.host, - f"sessions/{self.server_name}?token={self._user.username}", + return list( + chain( + general_patches.test(self), + general_patches.session_tolerations(self), + general_patches.session_affinity(self), + general_patches.session_node_selector(self), + general_patches.priority_class(self), + general_patches.dev_shm(self), + jupyter_server_patches.args(), + jupyter_server_patches.env(self), + jupyter_server_patches.image_pull_secret(self), + jupyter_server_patches.disable_service_links(), + jupyter_server_patches.rstudio_env_variables(self), + jupyter_server_patches.user_secrets(self), + ( + git_proxy_patches.main(self) + if has_repository and not self.user.anonymous + else [] + ), + git_sidecar_patches.main(self) if has_repository else [], + general_patches.oidc_unverified_email(self), + ssh_patches.main(), + # init container for certs must come before all other init containers + # so that it runs first before all other init containers + init_containers_patches.certificates(), + init_containers_patches.download_image(self), + init_containers_patches.git_clone(self) if has_repository else [], + inject_certificates_patches.proxy(self), + # Cloud Storage needs to patch the git clone sidecar spec and so should come after + # the sidecars + # WARN: this patch depends on the index of the sidecar and so needs to be updated + # if sidercars are added or removed + cloudstorage_patches.main(self), ) - - def __str__(self): - return ( - f"" ) - def get_annotations(self): - prefix = config.session_get_endpoint_annotations.renku_annotation_prefix + def get_labels(self) -> dict[str, str | None]: + prefix = self._get_renku_annotation_prefix() + labels = { + "app": "jupyter", + "component": "singleuser-server", + f"{prefix}commit-sha": None, + f"{prefix}gitlabProjectId": None, + f"{prefix}safe-username": self.safe_username, + f"{prefix}quota": self.server_options.priority_class, + f"{prefix}userId": self._user.id, + } + return labels + + def get_annotations(self) -> dict[str, str | None]: + prefix = self._get_renku_annotation_prefix() annotations = { - f"{prefix}commit-sha": self.commit_sha, + f"{prefix}commit-sha": None, f"{prefix}gitlabProjectId": None, f"{prefix}safe-username": self._user.safe_username, f"{prefix}username": self._user.username, f"{prefix}userId": self._user.id, f"{prefix}servername": self.server_name, - f"{prefix}branch": self.branch, - f"{prefix}git-host": self.git_host, - f"{prefix}namespace": self.namespace, - f"{prefix}projectName": self.project, + f"{prefix}branch": None, + f"{prefix}git-host": None, + f"{prefix}namespace": None, + f"{prefix}projectName": None, f"{prefix}requested-image": self.image, f"{prefix}repository": None, f"{prefix}hibernation": "", @@ -409,63 +356,59 @@ def get_annotations(self): f"{prefix}hibernationDirty": "", f"{prefix}hibernationSynchronized": "", f"{prefix}hibernationDate": "", - f"{prefix}hibernatedSecondsThreshold": str(self.hibernated_seconds_threshold), + f"{prefix}hibernatedSecondsThreshold": str( + self.hibernated_seconds_threshold + ), f"{prefix}lastActivityDate": "", f"{prefix}idleSecondsThreshold": str(self.idle_seconds_threshold), } if self.server_options.resource_class_id: - annotations[f"{prefix}resourceClassId"] = str(self.server_options.resource_class_id) - if self.gl_project is not None: - annotations[f"{prefix}gitlabProjectId"] = str(self.gl_project.id) - annotations[f"{prefix}repository"] = self.gl_project.web_url + annotations[f"{prefix}resourceClassId"] = str( + self.server_options.resource_class_id + ) return annotations - def get_labels(self): - prefix = config.session_get_endpoint_annotations.renku_annotation_prefix - labels = { - "app": "jupyter", - "component": "singleuser-server", - f"{prefix}commit-sha": self.commit_sha, - f"{prefix}gitlabProjectId": None, - f"{prefix}safe-username": self._user.safe_username, - f"{prefix}quota": self.server_options.priority_class, - f"{prefix}userId": self._user.id, - } - if self.gl_project is not None: - labels[f"{prefix}gitlabProjectId"] = str(self.gl_project.id) - return labels - -class Renku2UserServer(UserServer): - """Represents a Renku 2 jupyter server session.""" +class Renku1UserServer(UserServer): + """Represents a Renku 1.0 server session.""" def __init__( self, - user: Union[AnonymousUser, RegisteredUser], - notebook: Optional[str], # TODO: Is this value actually needed? - image: str, - project_id: str, - launcher_id: str, + user: AnonymousUser | RegisteredUser, server_name: str, + namespace: str, + project: str, + branch: str, + commit_sha: str, + notebook: str | None, # TODO: Is this value actually needed? + image: str | None, server_options: ServerOptions, environment_variables: dict[str, str], - user_secrets: Optional[K8sUserSecrets], + user_secrets: K8sUserSecrets | None, cloudstorage: list[ICloudStorageRequest], k8s_client: K8sClient, workspace_mount_path: Path, work_dir: Path, - repositories: list[Repository], using_default_image: bool = False, is_image_private: bool = False, **_, ): + gitlab_project_name = f"{namespace}/{project}" + gitlab_project = user.get_renku_project(gitlab_project_name) + single_repository = ( + Repository( + url=gitlab_project.http_url_to_repo, + dirname=gitlab_project.path, + branch=branch, + commit_sha=commit_sha, + ) + if gitlab_project is not None + else None + ) + super().__init__( user=user, - namespace=None, - project=None, - branch=None, - commit_sha=None, - notebook=notebook, + server_name=server_name, image=image, server_options=server_options, environment_variables=environment_variables, @@ -476,54 +419,29 @@ def __init__( work_dir=work_dir, using_default_image=using_default_image, is_image_private=is_image_private, + repositories=[single_repository] if single_repository is not None else [], ) - self._server_name = server_name - self.project_id = project_id - self.launcher_id = launcher_id - self._repositories: list[Repository] = repositories or [] - self._calculated_repository_urls: bool = False - - @property - def gl_project(self): - if len(self._repositories) == 1: - project_path = f"{self._repositories[0].namespace}/{self._repositories[0].project}" - return self._user.get_renku_project(project_path) - return None - - @property - def gl_project_path(self) -> Optional[str]: - gl_project = self.gl_project - return gl_project.path if gl_project else None - - @property - def gl_project_url(self) -> Optional[str]: - """Return the common hostname of all repositories.""" - repositories = self.repositories - if not repositories: - return "" - elif len(repositories) == 1: - return repositories[0]["url"] - - # NOTE: For more than one repository, we only support one gitlab instance atm - return self._user.gitlab_client.url - - @property - def server_name(self): - """Make the name that is used to identify a unique user session.""" - return self._server_name - - @property - def repositories(self) -> list[dict[str, str]]: - if self._repositories and not self._calculated_repository_urls: - for r in self._repositories: - project = self._user.get_renku_project(f"{r.namespace}/{r.project}") - if project: - r.url = project.http_url_to_repo - - self._calculated_repository_urls = True - - return [asdict(r) for r in self._repositories] + self.namespace = namespace + self.project = project + self.branch = branch + self.commit_sha = commit_sha + self.notebook = notebook + self.git_host = urlparse(config.git.url).netloc + self.gitlab_project_name = gitlab_project_name + self.gitlab_project = gitlab_project + self.single_repository = single_repository + + def _get_start_errors(self) -> list[str]: + """Check if there are any errors before starting the server.""" + errors = super()._get_start_errors() + if self.gitlab_project is None: + errors.append(f"project {self.project} does not exist") + if not self._branch_exists(): + errors.append(f"branch {self.branch} does not exist") + if not self._commit_sha_exists(): + errors.append(f"commit {self.commit_sha} does not exist") + return errors def _branch_exists(self): """Check if a specific branch exists in the user's gitlab project. @@ -531,66 +449,97 @@ def _branch_exists(self): The branch name is not required by the API and therefore passing None to this function will return True. """ - raise NotImplementedError + if self.branch is not None and self.gitlab_project is not None: + try: + self.gitlab_project.branches.get(self.branch) + except Exception as err: + current_app.logger.warning( + f"Branch {self.branch} cannot be verified or does not exist. {err}" + ) + else: + return True + return False def _commit_sha_exists(self): """Check if a specific commit sha exists in the user's gitlab project.""" - raise NotImplementedError - - def get_annotations(self): - annotations = super().get_annotations() + if self.commit_sha is not None and self.gitlab_project is not None: + try: + self.gitlab_project.commits.get(self.commit_sha) + except Exception as err: + current_app.logger.warning( + f"Commit {self.commit_sha} cannot be verified or does not exist. {err}" + ) + else: + return True + return False - # Add Renku 2.0 annotations - prefix = config.session_get_endpoint_annotations.renku_annotation_prefix - annotations[f"{prefix}renkuVersion"] = "2.0" - annotations[f"{prefix}projectId"] = self.project_id - annotations[f"{prefix}launcherId"] = self.launcher_id + def get_labels(self) -> dict[str, str | None]: + prefix = self._get_renku_annotation_prefix() + labels = super().get_labels() + labels[f"{prefix}commit-sha"] = self.commit_sha + if self.gitlab_project is not None: + labels[f"{prefix}gitlabProjectId"] = str(self.gitlab_project.id) + return labels + def get_annotations(self) -> dict[str, str | None]: + prefix = self._get_renku_annotation_prefix() + annotations = super().get_annotations() + annotations[f"{prefix}commit-sha"] = self.commit_sha + annotations[f"{prefix}branch"] = self.branch + annotations[f"{prefix}git-host"] = self.git_host + annotations[f"{prefix}namespace"] = self.namespace + annotations[f"{prefix}projectName"] = self.project + if self.gitlab_project is not None: + annotations[f"{prefix}gitlabProjectId"] = str(self.gitlab_project.id) + annotations[f"{prefix}repository"] = self.gitlab_project.web_url return annotations - def _get_patches(self): - has_repository = bool(self._repositories) - return list( - chain( - general_patches.test(self), - general_patches.session_tolerations(self), - general_patches.session_affinity(self), - general_patches.session_node_selector(self), - general_patches.priority_class(self), - general_patches.dev_shm(self), - jupyter_server_patches.args(), - jupyter_server_patches.env(self), - jupyter_server_patches.image_pull_secret(self), - jupyter_server_patches.disable_service_links(), - jupyter_server_patches.rstudio_env_variables(self), - git_proxy_patches.main(self) if has_repository else [], - git_sidecar_patches.main(self) if has_repository else [], - general_patches.oidc_unverified_email(self), - ssh_patches.main(), - # init container for certs must come before all other init containers - # so that it runs first before all other init containers - init_containers_patches.certificates(), - init_containers_patches.download_image(self), - init_containers_patches.git_clone(self) if has_repository else [], - inject_certificates_patches.proxy(self), - # Cloud Storage needs to patch the git clone sidecar spec and so should come after - # the sidecars - # WARN: this patch depends on the index of the sidecar and so needs to be updated - # if sidercars are added or removed - cloudstorage_patches.main(self), - ) +class Renku2UserServer(UserServer): + """Represents a Renku 2.0 server session.""" + + def __init__( + self, + user: AnonymousUser | RegisteredUser, + image: str, + project_id: str, + launcher_id: str, + server_name: str, + server_options: ServerOptions, + environment_variables: dict[str, str], + user_secrets: K8sUserSecrets | None, + cloudstorage: list[ICloudStorageRequest], + k8s_client: K8sClient, + workspace_mount_path: Path, + work_dir: Path, + repositories: list[Repository], + using_default_image: bool = False, + is_image_private: bool = False, + **_, + ): + super().__init__( + user=user, + server_name=server_name, + image=image, + server_options=server_options, + environment_variables=environment_variables, + user_secrets=user_secrets, + cloudstorage=cloudstorage, + k8s_client=k8s_client, + workspace_mount_path=workspace_mount_path, + work_dir=work_dir, + using_default_image=using_default_image, + is_image_private=is_image_private, + repositories=repositories, ) - def start(self) -> Optional[dict[str, Any]]: - """Create the jupyterserver resource in k8s.""" - if self.image is None: - errors = [f"image {self.image} does not exist or cannot be accessed"] - raise MissingResourceError( - message=( - "Cannot start the session because the following Git " - f"or Docker resources are missing: {', '.join(errors)}" - ) - ) + self.project_id = project_id + self.launcher_id = launcher_id - return self._k8s_client.create_server(self._get_session_manifest(), self.safe_username) + def get_annotations(self): + prefix = self._get_renku_annotation_prefix() + annotations = super().get_annotations() + annotations[f"{prefix}renkuVersion"] = "2.0" + annotations[f"{prefix}projectId"] = self.project_id + annotations[f"{prefix}launcherId"] = self.launcher_id + return annotations diff --git a/renku_notebooks/api/notebooks.py b/renku_notebooks/api/notebooks.py index 4a59a7b45..3e6434b54 100644 --- a/renku_notebooks/api/notebooks.py +++ b/renku_notebooks/api/notebooks.py @@ -20,7 +20,7 @@ import logging from datetime import UTC, datetime from pathlib import Path -from typing import Optional +from typing import TYPE_CHECKING, Optional import requests from flask import Blueprint, current_app, jsonify @@ -28,7 +28,7 @@ from marshmallow import ValidationError, fields, validate from webargs.flaskparser import use_args -from renku_notebooks.api.classes.user import AnonymousUser +from renku_notebooks.api.classes.user import AnonymousUser, RegisteredUser from renku_notebooks.api.schemas.cloud_storage import RCloneStorage from renku_notebooks.util.repository import get_status @@ -36,11 +36,16 @@ from ..errors.intermittent import AnonymousUserPatchError, PVDisabledError from ..errors.programming import ProgrammingError from ..errors.user import MissingResourceError, UserInputError -from ..util.kubernetes_ import make_server_name, renku_2_make_server_name +from ..util.kubernetes_ import ( + find_container, + renku_1_make_server_name, + renku_2_make_server_name, +) from .auth import authenticated from .classes.auth import GitlabToken, RenkuTokens from .classes.image import Image -from .classes.server import Renku2UserServer, Repository, UserServer +from .classes.repository import Repository +from .classes.server import Renku1UserServer, Renku2UserServer, UserServer from .classes.server_manifest import UserServerManifest from .schemas.config_server_options import ServerOptionsEndpointResponse from .schemas.logs import ServerLogs @@ -51,6 +56,9 @@ from .schemas.servers_post import LaunchNotebookRequest, Renku2LaunchNotebookRequest from .schemas.version import VersionResponse +if TYPE_CHECKING: + from gitlab.v4.objects.projects import Project + bp = Blueprint("notebooks_blueprint", __name__, url_prefix=config.service_prefix) @@ -159,7 +167,7 @@ def user_server(user, server_name): @use_args(LaunchNotebookRequest(), location="json", as_kwargs=True) @authenticated def launch_notebook( - user, + user: AnonymousUser | RegisteredUser, namespace, project, branch, @@ -175,10 +183,12 @@ def launch_notebook( server_options=None, user_secrets=None, ): - server_name = make_server_name(user.safe_username, namespace, project, branch, commit_sha) + server_name = renku_1_make_server_name( + user.safe_username, namespace, project, branch, commit_sha + ) gl_project = user.get_renku_project(f"{namespace}/{project}") gl_project_path = gl_project.path - server_class = UserServer + server_class = Renku1UserServer return launch_notebook_helper( server_name=server_name, @@ -210,7 +220,7 @@ def launch_notebook( @use_args(Renku2LaunchNotebookRequest(), location="json", as_kwargs=True) @authenticated def renku_2_launch_notebook_helper( - user, + user: AnonymousUser | RegisteredUser, notebook, image, resource_class_id, @@ -221,21 +231,19 @@ def renku_2_launch_notebook_helper( cloudstorage=None, server_options=None, user_secrets=None, - project_id: Optional[str] = None, # Renku 2 - launcher_id: Optional[str] = None, # Renku 2 - repositories: Optional[list[dict[str, str]]] = None, # Renku 2 + project_id: str | None = None, # Renku 2 + launcher_id: str | None = None, # Renku 2 + repositories: list[dict[str, str]] | None = None, # Renku 2 ): server_name = renku_2_make_server_name( safe_username=user.safe_username, project_id=project_id, launcher_id=launcher_id ) - gl_project = None - gl_project_path = "" server_class = Renku2UserServer return launch_notebook_helper( server_name=server_name, - gl_project=gl_project, - gl_project_path=gl_project_path, + gl_project=None, + gl_project_path=None, server_class=server_class, user=user, namespace=None, @@ -260,15 +268,8 @@ def renku_2_launch_notebook_helper( def launch_notebook_helper( server_name: str, - gl_project, - gl_project_path: str, server_class: type[UserServer], - user, - namespace, - project, - branch, - commit_sha, - notebook, + user: AnonymousUser | RegisteredUser, image, resource_class_id, storage, @@ -278,9 +279,16 @@ def launch_notebook_helper( lfs_auto_fetch, cloudstorage, server_options, - project_id: Optional[str], # Renku 2 - launcher_id: Optional[str], # Renku 2 - repositories: Optional[list[dict[str, str]]], # Renku 2 + namespace: str | None, # Renku 1.0 + project: str | None, # Renku 1.0 + branch: str | None, # Renku 1.0 + commit_sha: str | None, # Renku 1.0 + notebook: str | None, # Renku 1.0 + gl_project: Optional["Project"], # Renku 1.0 + gl_project_path: str | None, # Renku 1.0 + project_id: str | None, # Renku 2.0 + launcher_id: str | None, # Renku 2.0 + repositories: list[dict[str, str]] | None, # Renku 2.0 ): """Helper function to launch a Jupyter server.""" server = config.k8s.client.get_server(server_name, user.safe_username) @@ -288,6 +296,8 @@ def launch_notebook_helper( if server: return NotebookResponse().dump(UserServerManifest(server)), 200 + gl_project_path = gl_project_path if gl_project_path is not None else "" + # Add annotation for old and new notebooks is_image_private = False using_default_image = False @@ -306,7 +316,7 @@ def launch_notebook_helper( parsed_image = Image.from_path(image) if image_exists_privately: is_image_private = True - else: + elif gl_project is not None: # An image was not requested specifically, use the one automatically built for the commit image = f"{config.git.registry}/{gl_project.path_with_namespace.lower()}:{commit_sha[:7]}" parsed_image = Image( @@ -317,7 +327,10 @@ def launch_notebook_helper( # NOTE: a project pulled from the Gitlab API without credentials has no visibility attribute # and by default it can only be public since only public projects are visible to # non-authenticated users. Also, a nice footgun from the Gitlab API Python library. - is_image_private = getattr(gl_project, "visibility", GitlabVisibility.PUBLIC) != GitlabVisibility.PUBLIC + is_image_private = ( + getattr(gl_project, "visibility", GitlabVisibility.PUBLIC) + != GitlabVisibility.PUBLIC + ) image_repo = parsed_image.repo_api() if is_image_private and user.git_token: image_repo = image_repo.with_oauth2_token(user.git_token) @@ -328,6 +341,8 @@ def launch_notebook_helper( "exist or the user does not have the permissions to access it." ) ) + else: + raise UserInputError(message="Cannot determine which Docker image to use.") if resource_class_id is not None: # A resource class ID was passed in, validate with CRC service @@ -369,7 +384,9 @@ def launch_notebook_helper( ) if storage is None: storage = default_resource_class.get("default_storage") - parsed_server_options = ServerOptions.from_resource_class(default_resource_class) + parsed_server_options = ServerOptions.from_resource_class( + default_resource_class + ) # Storage in request is in GB parsed_server_options.set_storage(storage, gigabytes=True) @@ -429,7 +446,7 @@ def launch_notebook_helper( work_dir=server_work_dir, using_default_image=using_default_image, is_image_private=is_image_private, - repositories=[Repository.from_schema(r) for r in repositories], + repositories=[Repository.from_dict(r) for r in repositories], namespace=namespace, project=project, branch=branch, @@ -600,7 +617,14 @@ def patch_server(user, server_name, patch_body): hibernation = {"branch": "", "commit": "", "dirty": "", "synchronized": ""} - status = get_status(server_name=server_name, access_token=user.access_token) + sidecar_patch = find_container( + server.get("spec", {}).get("patches", []), "git-sidecar" + ) + status = ( + get_status(server_name=server_name, access_token=user.access_token) + if sidecar_patch is not None + else None + ) if status: hibernation = { "branch": status.get("branch", ""), diff --git a/renku_notebooks/api/schemas/repository.py b/renku_notebooks/api/schemas/repository.py index 35860e6af..38e346d25 100644 --- a/renku_notebooks/api/schemas/repository.py +++ b/renku_notebooks/api/schemas/repository.py @@ -2,16 +2,11 @@ from marshmallow import Schema, fields -from .custom_fields import LowercaseString - class Repository(Schema): """Information required to clone a repository.""" - # namespaces in gitlab are NOT case-sensitive - namespace = LowercaseString(required=True) - # project names in gitlab are NOT case-sensitive - project = LowercaseString(required=True) - # branch names in gitlab are case-sensitive - branch = fields.Str(load_default="master") - commit_sha = fields.Str(required=True) + url: str = fields.Str(required=True) + dirname: str | None = fields.Str() + branch: str | None = fields.Str() + commit_sha: str | None = fields.Str() diff --git a/renku_notebooks/api/schemas/servers_post.py b/renku_notebooks/api/schemas/servers_post.py index 6b9eb3c55..6ac358a97 100644 --- a/renku_notebooks/api/schemas/servers_post.py +++ b/renku_notebooks/api/schemas/servers_post.py @@ -67,7 +67,7 @@ class LaunchNotebookRequestWithStorage(LaunchNotebookRequestWithoutStorage): class Renku2LaunchNotebookRequest(LaunchNotebookRequestWithoutStorageBase): - """To validate start request for Renku 2 sessions.""" + """To validate start request for Renku 2.0 sessions.""" project_id = fields.String(required=True) launcher_id = fields.String(required=True) diff --git a/renku_notebooks/config/__init__.py b/renku_notebooks/config/__init__.py index 800b75e5c..b865b488c 100644 --- a/renku_notebooks/config/__init__.py +++ b/renku_notebooks/config/__init__.py @@ -22,6 +22,7 @@ if TYPE_CHECKING: from ..api.classes.data_service import CloudStorageConfig + from ..api.classes.repository import GitProvider from ..api.classes.user import User from ..api.schemas.server_options import ServerOptions @@ -49,6 +50,10 @@ def validate_storage_configuration(self, configuration: dict[str, Any], source_p def obscure_password_fields_for_storage(self, configuration: dict[str, Any]) -> dict[str, Any]: ... +class GitProviderHelperProto(Protocol): + def get_providers(self, user: "User") -> list["GitProvider"]: ... + + @dataclass class _NotebooksConfig: server_options: _ServerOptionsConfig @@ -99,6 +104,7 @@ def __post_init__(self): ) self._crc_validator = None self._storage_validator = None + self._git_provider_helper = None @property def crc_validator(self) -> CRCValidatorProto: @@ -124,6 +130,22 @@ def storage_validator(self) -> StorageValidatorProto: return self._storage_validator + @property + def git_provider_helper(self) -> GitProviderHelperProto: + from ..api.classes.data_service import DummyGitProviderHelper, GitProviderHelper + + if not self._git_provider_helper: + if self.dummy_stores: + self._git_provider_helper = DummyGitProviderHelper() + else: + self._git_provider_helper = GitProviderHelper( + service_url=self.data_service_url, + renku_url="https://" + self.sessions.ingress.host, + internal_gitlab_url=config.git.url, + ) + + return self._git_provider_helper + def get_config(default_config: str) -> _NotebooksConfig: """Compiles the configuration for the notebook service. diff --git a/renku_notebooks/util/kubernetes_.py b/renku_notebooks/util/kubernetes_.py index 88e60debe..5b8076ed9 100644 --- a/renku_notebooks/util/kubernetes_.py +++ b/renku_notebooks/util/kubernetes_.py @@ -19,6 +19,7 @@ from __future__ import annotations from hashlib import md5 +from typing import Any import escapism from kubernetes.client import V1Container @@ -36,7 +37,10 @@ def filter_resources_by_annotations( def filter_resource(resource): res = [] for annotation_name in annotations: - res.append(resource["metadata"]["annotations"].get(annotation_name) == annotations[annotation_name]) + res.append( + resource["metadata"]["annotations"].get(annotation_name) + == annotations[annotation_name] + ) if len(res) == 0: return True else: @@ -45,39 +49,50 @@ def filter_resource(resource): return list(filter(filter_resource, resources)) -def make_server_name(safe_username: str, namespace: str, project: str, branch: str, commit_sha: str) -> str: - """Form a unique server name. +def renku_1_make_server_name( + safe_username: str, namespace: str, project: str, branch: str, commit_sha: str +) -> str: + """Form a unique server name for Renku 1.0 sessions. This is used in naming all the k8s resources created by amalthea. """ - server_string_for_hashing = f"{safe_username}-{namespace}-{project}-{branch}-{commit_sha}" - safe_username_lowercase = safe_username.lower() - if safe_username_lowercase[0].isalpha() and safe_username_lowercase[0].isascii(): # noqa: SIM108 - prefix = "" - else: - # NOTE: Username starts with an invalid character. This has to be modified because a - # k8s service object cannot start with anything other than a lowercase alphabet character. - # NOTE: We do not have worry about collisions with already existing servers from older - # versions because the server name includes the hash of the original username, so the hash - # would be different because the original username differs between someone whose username - # is for example 7User vs. n7User. - prefix = "n" - return "{prefix}{username}-{project}-{hash}".format( - prefix=prefix, - username=safe_username_lowercase[:10], - project=escapism.escape(project, escape_char="-")[:24].lower(), - hash=md5(server_string_for_hashing.encode()).hexdigest()[:8].lower(), + server_string_for_hashing = ( + f"{safe_username}-{namespace}-{project}-{branch}-{commit_sha}" + ) + server_hash = md5(server_string_for_hashing.encode()).hexdigest().lower() + prefix = _make_server_name_prefix(safe_username) + # NOTE: A K8s object name can only contain lowercase alphanumeric characters, hyphens, or dots. + # Must be less than 253 characters long and start and end with an alphanumeric. + # NOTE: We use server name as a label value, so, server name must be less than 63 characters. + # NOTE: Amalthea adds 11 characters to the server name in a label, so we have only + # 52 characters available. + # !NOTE: For now we limit the server name to 42 characters. + # NOTE: This is 12 + 1 + 20 + 1 + 8 = 42 characters + return "{prefix}-{project}-{hash}".format( + prefix=prefix[:12], + project=escapism.escape(project, escape_char="-")[:20].lower(), + hash=server_hash[:8], ) -def renku_2_make_server_name(safe_username: str, project_id: str, launcher_id: str) -> str: - """Form a unique server name.""" - all_hash = md5(f"{safe_username}{project_id}{launcher_id}".encode()).hexdigest().lower() +def renku_2_make_server_name( + safe_username: str, project_id: str, launcher_id: str +) -> str: + """Form a unique server name for Renku 2.0 sessions. + This is used in naming all the k8s resources created by amalthea. + """ + server_string_for_hashing = f"{safe_username}-{project_id}-{launcher_id}" + server_hash = md5(server_string_for_hashing.encode()).hexdigest().lower() + prefix = _make_server_name_prefix(safe_username) # NOTE: A K8s object name can only contain lowercase alphanumeric characters, hyphens, or dots. # Must be less than 253 characters long and start and end with an alphanumeric. # NOTE: We use server name as a label value, so, server name must be less than 63 characters. - return f"renku-2-{all_hash[:40]}" + # NOTE: Amalthea adds 11 characters to the server name in a label, so we have only + # 52 characters available. + # !NOTE: For now we limit the server name to 42 characters. + # NOTE: This is 12 + 9 + 21 = 42 characters + return f"{prefix[:12]}-renku-2-{server_hash[:21]}" def find_env_var(container: V1Container, env_name: str) -> tuple[int, str] | None: @@ -94,3 +109,38 @@ def find_env_var(container: V1Container, env_name: str) -> tuple[int, str] | Non ind = env_var[0] val = env_var[1].value return ind, val + + +def _make_server_name_prefix(safe_username: str): + safe_username_lowercase = safe_username.lower() + prefix = "" + if ( + not safe_username_lowercase[0].isalpha() + or not safe_username_lowercase[0].isascii() + ): + # NOTE: Username starts with an invalid character. This has to be modified because a + # k8s service object cannot start with anything other than a lowercase alphabet character. + # NOTE: We do not have worry about collisions with already existing servers from older + # versions because the server name includes the hash of the original username, so the hash + # would be different because the original username differs between someone whose username + # is for example 7User vs. n7User. + prefix = "n" + + prefix = f"{prefix}{safe_username_lowercase}" + return prefix + + +def find_container( + patches: list[dict[str, Any]], container_name: str +) -> dict[str, Any] | None: + """Find the json patch corresponding a given container.""" + for patch_obj in patches: + inner_patches = patch_obj.get("patch", []) + for p in inner_patches: + if ( + p.get("op") == "add" + and p.get("path") == "/statefulset/spec/template/spec/containers/-" + and p.get("value", {}).get("name") == container_name + ): + return p + return None diff --git a/tests/unit/test_server_class/test_manifest.py b/tests/unit/test_server_class/test_manifest.py index 77f0d9df5..8635993a2 100644 --- a/tests/unit/test_server_class/test_manifest.py +++ b/tests/unit/test_server_class/test_manifest.py @@ -9,6 +9,7 @@ from renku_notebooks.api.schemas.server_options import ServerOptions from renku_notebooks.errors.programming import DuplicateEnvironmentVariableError from renku_notebooks.errors.user import OverriddenEnvironmentVariableError +from renku_notebooks.util.kubernetes_ import renku_1_make_server_name BASE_PARAMETERS = { "namespace": "test-namespace", @@ -56,6 +57,13 @@ def test_session_manifest( base_parameters = BASE_PARAMETERS.copy() base_parameters["user"] = user_with_project_path("namespace/project") base_parameters["k8s_client"] = mocker.MagicMock(K8sClient) + base_parameters["server_name"] = renku_1_make_server_name( + safe_username=base_parameters["user"].safe_username, + namespace=base_parameters["namespace"], + project=base_parameters["project"], + branch=base_parameters["branch"], + commit_sha=base_parameters["commit_sha"], + ) server = UserServer(**{**base_parameters, **parameters}) server._repositories = {} @@ -177,6 +185,13 @@ def test_user_secrets_manifest( base_parameters = BASE_PARAMETERS.copy() base_parameters["user"] = user_with_project_path("namespace/project") base_parameters["k8s_client"] = mocker.MagicMock(K8sClient) + base_parameters["server_name"] = renku_1_make_server_name( + safe_username=base_parameters["user"].safe_username, + namespace=base_parameters["namespace"], + project=base_parameters["project"], + branch=base_parameters["branch"], + commit_sha=base_parameters["commit_sha"], + ) server = UserServer(**{**base_parameters, **parameters}) server._repositories = {} @@ -200,6 +215,13 @@ def test_session_env_var_override(patch_user_server, user_with_project_path, app parameters["k8s_client"] = mocker.MagicMock(K8sClient) # NOTE: NOTEBOOK_DIR is defined in ``jupyter_server.env`` patch parameters["environment_variables"] = {"NOTEBOOK_DIR": "/some/path"} + parameters["server_name"] = renku_1_make_server_name( + safe_username=parameters["user"].safe_username, + namespace=parameters["namespace"], + project=parameters["project"], + branch=parameters["branch"], + commit_sha=parameters["commit_sha"], + ) server = UserServer(**parameters) server._repositories = {} @@ -237,6 +259,13 @@ def test_patches_env_var_override(patch_user_server, user_with_project_path, app parameters = BASE_PARAMETERS.copy() parameters["user"] = user_with_project_path("namespace/project") parameters["k8s_client"] = mocker.MagicMock(K8sClient) + parameters["server_name"] = renku_1_make_server_name( + safe_username=parameters["user"].safe_username, + namespace=parameters["namespace"], + project=parameters["project"], + branch=parameters["branch"], + commit_sha=parameters["commit_sha"], + ) server = UserServer(**parameters) server._repositories = {}