From ffdaff06c3e5c579e3dbf34eaa0a189334ac559d Mon Sep 17 00:00:00 2001 From: Shiming Zhang Date: Mon, 10 Feb 2025 19:44:43 +0800 Subject: [PATCH] Fix big cache --- agent/agent.go | 93 ++++++++++++++++++++++++++++++++++----------- agent/blob_cache.go | 11 ++++-- runner/runner.go | 88 ++++++++++++++++++++++++------------------ 3 files changed, 129 insertions(+), 63 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 48a4744..1d32e56 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -338,19 +338,51 @@ func (c *Agent) Serve(rw http.ResponseWriter, r *http.Request, info *BlobInfo, t } c.rateLimit(rw, r, info.Blobs, info, t, value.Size, start) + if value.BigCache { + c.serveBigCachedBlob(rw, r, info.Blobs, info, t, value.Size) + return + } + c.serveCachedBlob(rw, r, info.Blobs, info, t, value.Size) return } + var isBigCache bool stat, err := c.cache.StatBlob(ctx, info.Blobs) if err == nil { - if c.serveCachedBlobHead(rw, r, stat.Size()) { + if c.bigCache != nil && stat.Size() >= int64(c.bigCacheSize) { + isBigCache = true + _, err := c.bigCache.StatBlob(ctx, info.Blobs) + if err == nil { + if c.serveCachedBlobHead(rw, r, stat.Size()) { + return + } + + c.rateLimit(rw, r, info.Blobs, info, t, value.Size, start) + c.serveBigCachedBlob(rw, r, info.Blobs, info, t, stat.Size()) + return + } + } else { + if c.serveCachedBlobHead(rw, r, stat.Size()) { + return + } + + c.rateLimit(rw, r, info.Blobs, info, t, value.Size, start) + c.serveCachedBlob(rw, r, info.Blobs, info, t, stat.Size()) return } + } else { + stat, err := c.bigCache.StatBlob(ctx, info.Blobs) + if err == nil { + isBigCache = true + if c.serveCachedBlobHead(rw, r, stat.Size()) { + return + } - c.rateLimit(rw, r, info.Blobs, info, t, value.Size, start) - c.serveCachedBlob(rw, r, info.Blobs, info, t, stat.Size()) - return + c.rateLimit(rw, r, info.Blobs, info, t, value.Size, start) + c.serveBigCachedBlob(rw, r, info.Blobs, info, t, stat.Size()) + return + } } c.rateLimit(rw, r, info.Blobs, info, t, value.Size, start) @@ -373,6 +405,17 @@ func (c *Agent) Serve(rw http.ResponseWriter, r *http.Request, info *BlobInfo, t return } + if isBigCache { + stat, err = c.bigCache.StatBlob(ctx, info.Blobs) + if err == nil { + if c.serveCachedBlobHead(rw, r, stat.Size()) { + return + } + c.serveBigCachedBlob(rw, r, info.Blobs, info, t, stat.Size()) + return + } + } + stat, err = c.cache.StatBlob(ctx, info.Blobs) if err == nil { if c.serveCachedBlobHead(rw, r, stat.Size()) { @@ -475,7 +518,7 @@ func (c *Agent) cacheBlob(info *BlobInfo) (int64, func() error, int, error) { if err != nil { return fmt.Errorf("Put to big cache: %w", err) } - c.blobCache.PutNoTTL(info.Blobs, size) + c.blobCache.PutNoTTL(info.Blobs, size, true) return nil } @@ -483,7 +526,7 @@ func (c *Agent) cacheBlob(info *BlobInfo) (int64, func() error, int, error) { if err != nil { return fmt.Errorf("Put to cache: %w", err) } - c.blobCache.Put(info.Blobs, size) + c.blobCache.Put(info.Blobs, size, false) return nil } @@ -508,28 +551,29 @@ func (c *Agent) serveCachedBlobHead(rw http.ResponseWriter, r *http.Request, siz return false } -func (c *Agent) serveCachedBlob(rw http.ResponseWriter, r *http.Request, blob string, info *BlobInfo, t *token.Token, size int64) { +func (c *Agent) serveBigCachedBlob(rw http.ResponseWriter, r *http.Request, blob string, info *BlobInfo, t *token.Token, size int64) { referer := r.RemoteAddr if info != nil { referer = fmt.Sprintf("%d-%d:%s:%s/%s", t.RegistryID, t.TokenID, referer, info.Host, info.Image) } - if c.bigCache != nil && c.bigCacheSize > 0 && size >= int64(c.bigCacheSize) { - u, err := c.bigCache.RedirectBlob(r.Context(), blob, referer) - if err != nil { - c.logger.Info("failed to redirect blob", "digest", blob, "error", err) - c.blobCache.Remove(info.Blobs) - utils.ServeError(rw, r, errcode.ErrorCodeUnknown, 0) - return - } - - c.blobCache.PutNoTTL(info.Blobs, size) - - c.logger.Info("Cache hit", "digest", blob, "url", u) - http.Redirect(rw, r, u, http.StatusTemporaryRedirect) + u, err := c.bigCache.RedirectBlob(r.Context(), blob, referer) + if err != nil { + c.logger.Info("failed to redirect blob", "digest", blob, "error", err) + c.blobCache.Remove(info.Blobs) + utils.ServeError(rw, r, errcode.ErrorCodeUnknown, 0) return } + c.blobCache.PutNoTTL(info.Blobs, size, true) + + c.logger.Info("Big Cache hit", "digest", blob, "url", u) + http.Redirect(rw, r, u, http.StatusTemporaryRedirect) + return +} + +func (c *Agent) serveCachedBlob(rw http.ResponseWriter, r *http.Request, blob string, info *BlobInfo, t *token.Token, size int64) { + if c.blobNoRedirectSize < 0 || int64(c.blobNoRedirectSize) > size { data, err := c.cache.GetBlob(r.Context(), info.Blobs) if err != nil { @@ -540,7 +584,7 @@ func (c *Agent) serveCachedBlob(rw http.ResponseWriter, r *http.Request, blob st } defer data.Close() - c.blobCache.Put(info.Blobs, size) + c.blobCache.Put(info.Blobs, size, false) rw.Header().Set("Content-Length", strconv.FormatInt(size, 10)) rw.Header().Set("Content-Type", "application/octet-stream") @@ -561,6 +605,11 @@ func (c *Agent) serveCachedBlob(rw http.ResponseWriter, r *http.Request, blob st // fallback to redirect } + referer := r.RemoteAddr + if info != nil { + referer = fmt.Sprintf("%d-%d:%s:%s/%s", t.RegistryID, t.TokenID, referer, info.Host, info.Image) + } + u, err := c.cache.RedirectBlob(r.Context(), blob, referer) if err != nil { c.logger.Info("failed to redirect blob", "digest", blob, "error", err) @@ -569,7 +618,7 @@ func (c *Agent) serveCachedBlob(rw http.ResponseWriter, r *http.Request, blob st return } - c.blobCache.Put(info.Blobs, size) + c.blobCache.Put(info.Blobs, size, false) c.logger.Info("Cache hit", "digest", blob, "url", u) http.Redirect(rw, r, u, http.StatusTemporaryRedirect) diff --git a/agent/blob_cache.go b/agent/blob_cache.go index cba7837..a936d73 100644 --- a/agent/blob_cache.go +++ b/agent/blob_cache.go @@ -47,20 +47,23 @@ func (m *blobCache) PutError(key string, err error, sc int) { }, m.duration) } -func (m *blobCache) Put(key string, size int64) { +func (m *blobCache) Put(key string, size int64, bigCache bool) { m.digest.SetWithTTL(key, blobValue{ - Size: size, + Size: size, + BigCache: bigCache, }, m.duration) } -func (m *blobCache) PutNoTTL(key string, size int64) { +func (m *blobCache) PutNoTTL(key string, size int64, bigCache bool) { m.digest.Set(key, blobValue{ - Size: size, + Size: size, + BigCache: bigCache, }) } type blobValue struct { Size int64 + BigCache bool Error error StatusCode int } diff --git a/runner/runner.go b/runner/runner.go index 9fd32bc..bf9ce08 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -372,19 +372,56 @@ func (r *Runner) runOnceManifestSync(ctx context.Context) error { } func (r *Runner) blob(ctx context.Context, host, name, blob string, size int64, gotSize, progress *atomic.Int64) error { + u := &url.URL{ + Scheme: "https", + Host: host, + Path: fmt.Sprintf("/v2/%s/blobs/%s", name, blob), + } + + if size == 0 { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, u.String(), nil) + if err != nil { + return err + } + resp, err := r.httpClient.Do(req) + if err != nil { + return err + } + if resp.Body != nil { + _ = resp.Body.Close() + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to head blob: status code %d", resp.StatusCode) + } + size = resp.ContentLength + } + + if size > 0 { + gotSize.Store(size) + } + + var caches []*cache.Cache + + if r.bigCache != nil && r.bigCacheSize > 0 && size >= int64(r.bigCacheSize) { + if !r.bigCacheBackup { + caches = []*cache.Cache{r.bigCache} + } else { + caches = append([]*cache.Cache{r.bigCache}, r.caches...) + } + } else { + caches = append(caches, r.caches...) + } + var subCaches []*cache.Cache - for _, cache := range r.caches { + for _, cache := range caches { stat, err := cache.StatBlob(ctx, blob) if err == nil { - if size > 0 { - gotSize := stat.Size() - if size == gotSize { - continue - } - r.logger.Error("size is not meeting expectations", "digest", blob, "size", size, "gotSize", gotSize) - } else { + gotSize := stat.Size() + if size == gotSize { continue } + r.logger.Error("size is not meeting expectations", "digest", blob, "size", size, "gotSize", gotSize) } subCaches = append(subCaches, cache) } @@ -394,12 +431,6 @@ func (r *Runner) blob(ctx context.Context, host, name, blob string, size int64, return nil } - u := &url.URL{ - Scheme: "https", - Host: host, - Path: fmt.Sprintf("/v2/%s/blobs/%s", name, blob), - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) if err != nil { return err @@ -411,7 +442,7 @@ func (r *Runner) blob(ctx context.Context, host, name, blob string, size int64, defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to retrieve blob: status code %d", resp.StatusCode) + return fmt.Errorf("failed to get blob: status code %d", resp.StatusCode) } r.logger.Info("start sync blob", "digest", blob, "url", u.String()) @@ -422,32 +453,11 @@ func (r *Runner) blob(ctx context.Context, host, name, blob string, size int64, } } - if size > 0 { - gotSize.Store(size) - } - if resp.ContentLength > 0 { - gotSize.Store(resp.ContentLength) - } - body := &readerCounter{ r: resp.Body, counter: progress, } - if r.bigCache != nil && r.bigCacheSize > 0 && gotSize.Load() >= int64(r.bigCacheSize) { - if !r.bigCacheBackup { - n, err := r.bigCache.PutBlob(ctx, blob, body) - if err != nil { - return fmt.Errorf("put blob failed: %w", err) - } - - r.logger.Info("finish sync blob", "digest", blob, "size", n) - return nil - } - - subCaches = append(subCaches, r.bigCache) - } - if len(subCaches) == 1 { n, err := subCaches[0].PutBlob(ctx, blob, body) if err != nil { @@ -640,6 +650,10 @@ func (r *Runner) manifest(ctx context.Context, messageID int64, host, image, tag _ = resp.Body.Close() } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to head manifest: status code %d", resp.StatusCode) + } + digest := resp.Header.Get("Docker-Content-Digest") if digest != "" { for _, cache := range r.caches { @@ -682,7 +696,7 @@ func (r *Runner) manifest(ctx context.Context, messageID int64, host, image, tag defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to retrieve manifest: status code %d", resp.StatusCode) + return fmt.Errorf("failed to get manifest: status code %d", resp.StatusCode) } body, err := io.ReadAll(resp.Body)