From 7043b277f11909308b3ff01074ab4bcbd0209ba1 Mon Sep 17 00:00:00 2001 From: zongz Date: Wed, 30 Oct 2024 15:30:31 +0800 Subject: [PATCH] feat: make 'pull' supports ModSpec Signed-off-by: zongz --- pkg/client/pull.go | 44 ++++++++----- pkg/client/pull_test.go | 61 +++++++++++++++++++ .../test_data/test_pull_with_modspec/.gitkeep | 0 pkg/downloader/source.go | 27 ++++++-- 4 files changed, 114 insertions(+), 18 deletions(-) create mode 100644 pkg/client/test_data/test_pull_with_modspec/.gitkeep diff --git a/pkg/client/pull.go b/pkg/client/pull.go index 00590de2..c54cd0bd 100644 --- a/pkg/client/pull.go +++ b/pkg/client/pull.go @@ -3,11 +3,13 @@ package client import ( "errors" "fmt" + "os" "path/filepath" "kcl-lang.io/kpm/pkg/downloader" pkg "kcl-lang.io/kpm/pkg/package" "kcl-lang.io/kpm/pkg/reporter" + "kcl-lang.io/kpm/pkg/utils" ) // The PullOptions struct contains the options for pulling a package from the registry. @@ -79,25 +81,39 @@ func (c *KpmClient) Pull(options ...PullOption) (*pkg.KclPkg, error) { c.GetLogWriter(), ) - kPkg, err := c.downloadPkg( - // The package pulled will be stored in the 'opts.LocalPath/sourceFilePath' - // 'opts.LocalPath' is the local path input by the user. - // 'sourceFilePath' is generated by the source. - // For example, - // kcl package from 'https://github.com/kcl-lang/kcl' will be stored in '$LOCAL_PATH/git/github.com/kcl-lang/kcl' - // kcl package from 'oci://ghcr.io/kcl-lang/kcl' will be stored in '$LOCAL_PATH/oci/ghcr.io/kcl-lang/kcl' - downloader.WithLocalPath(filepath.Join(opts.LocalPath, sourceFilePath)), - downloader.WithSource(*opts.Source), - downloader.WithInsecureSkipTLSverify(c.insecureSkipTLSverify), - ) + pkgSource := opts.Source + pulledFullPath := filepath.Join(opts.LocalPath, sourceFilePath) + + err = NewVisitor(*pkgSource, c).Visit(pkgSource, func(kPkg *pkg.KclPkg) error { + if !utils.DirExists(filepath.Dir(pulledFullPath)) { + err := os.MkdirAll(filepath.Dir(pulledFullPath), os.ModePerm) + if err != nil { + return err + } + } + err := utils.MoveOrCopy(kPkg.HomePath, pulledFullPath) + if err != nil { + return err + } + reporter.ReportMsgTo( + fmt.Sprintf("pulled %s %s successfully", kPkg.GetPkgName(), kPkg.GetPkgVersion()), + c.GetLogWriter(), + ) + return nil + }) if err != nil { return nil, err } - reporter.ReportMsgTo( - fmt.Sprintf("pulled %s %s successfully", kPkg.GetPkgName(), kPkg.GetPkgVersion()), - c.GetLogWriter(), + kPkg, err := pkg.LoadKclPkgWithOpts( + pkg.WithPath(pulledFullPath), + pkg.WithSettings(c.GetSettings()), ) + + if err != nil { + return nil, err + } + return kPkg, nil } diff --git a/pkg/client/pull_test.go b/pkg/client/pull_test.go index 2b03056f..00685c77 100644 --- a/pkg/client/pull_test.go +++ b/pkg/client/pull_test.go @@ -16,6 +16,10 @@ import ( func TestPull(t *testing.T) { pulledPath := getTestDir("test_pull") + defer func() { + err := os.RemoveAll(filepath.Join(pulledPath, "oci")) + assert.NilError(t, err) + }() kpmcli, err := NewKpmClient() assert.NilError(t, err) @@ -39,6 +43,8 @@ func TestPull(t *testing.T) { assert.Equal(t, kPkg.GetPkgName(), "helloworld") assert.Equal(t, kPkg.GetPkgVersion(), "0.0.1") assert.Equal(t, kPkg.HomePath, pkgPath) + err = os.RemoveAll(filepath.Join(pulledPath, "oci")) + assert.NilError(t, err) kPkg, err = kpmcli.Pull( WithLocalPath(pulledPath), @@ -121,3 +127,58 @@ func TestInsecureSkipTLSverifyOCIRegistry(t *testing.T) { assert.Equal(t, buf.String(), "Called Success\n") } + +func TestPullWithModSpec(t *testing.T) { + pulledPath := getTestDir("test_pull_with_modspec") + defer func() { + err := os.RemoveAll(filepath.Join(pulledPath, "oci")) + assert.NilError(t, err) + }() + + kpmcli, err := NewKpmClient() + assert.NilError(t, err) + + var buf bytes.Buffer + kpmcli.SetLogWriter(&buf) + + kPkg, err := kpmcli.Pull( + WithLocalPath(pulledPath), + WithPullSource(&downloader.Source{ + ModSpec: &downloader.ModSpec{ + Name: "subhelloworld", + Version: "0.0.1", + }, + Oci: &downloader.Oci{ + Reg: "ghcr.io", + Repo: "kcl-lang/helloworld", + Tag: "0.1.4", + }, + }), + ) + + pkgPath := filepath.Join(pulledPath, "oci", "ghcr.io", "kcl-lang", "helloworld", "0.1.4", "subhelloworld", "0.0.1") + assert.NilError(t, err) + assert.Equal(t, kPkg.GetPkgName(), "subhelloworld") + assert.Equal(t, kPkg.GetPkgVersion(), "0.0.1") + assert.Equal(t, kPkg.HomePath, pkgPath) + err = os.RemoveAll(filepath.Join(pulledPath, "oci")) + assert.NilError(t, err) + + kPkg, err = kpmcli.Pull( + WithLocalPath(pulledPath), + WithPullSourceUrl("oci://ghcr.io/kcl-lang/helloworld?tag=0.1.4&mod=subhelloworld:0.0.1"), + ) + pkgPath = filepath.Join(pulledPath, "oci", "ghcr.io", "kcl-lang", "helloworld", "0.1.4", "subhelloworld", "0.0.1") + assert.NilError(t, err) + assert.Equal(t, kPkg.GetPkgName(), "subhelloworld") + assert.Equal(t, kPkg.GetPkgVersion(), "0.0.1") + assert.Equal(t, kPkg.HomePath, pkgPath) + err = os.RemoveAll(filepath.Join(pulledPath, "oci")) + assert.NilError(t, err) + + _, err = kpmcli.Pull( + WithLocalPath(pulledPath), + WithPullSourceUrl("oci://ghcr.io/kcl-lang/helloworld?tag=0.1.4&mod=subhelloworld:0.0.2"), + ) + assert.Equal(t, err.Error(), "version mismatch: 0.0.1 != 0.0.2, version 0.0.2 not found") +} diff --git a/pkg/client/test_data/test_pull_with_modspec/.gitkeep b/pkg/client/test_data/test_pull_with_modspec/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/pkg/downloader/source.go b/pkg/downloader/source.go index a2996ef2..71718eba 100644 --- a/pkg/downloader/source.go +++ b/pkg/downloader/source.go @@ -197,19 +197,38 @@ func (local *Local) FindRootPath() (string, error) { } func (source *Source) ToFilePath() (string, error) { + var path string + var err error if source == nil { return "", fmt.Errorf("source is nil") } if source.Git != nil { - return source.Git.ToFilePath() + path, err = source.Git.ToFilePath() + if err != nil { + return "", err + } } if source.Oci != nil { - return source.Oci.ToFilePath() + path, err = source.Oci.ToFilePath() + if err != nil { + return "", err + } } if source.Local != nil { - return source.Local.ToFilePath() + path, err = source.Local.ToFilePath() + if err != nil { + return "", err + } } - return "", fmt.Errorf("source is nil") + + if !source.ModSpec.IsNil() { + path = filepath.Join(path, source.ModSpec.Name, source.ModSpec.Version) + if err != nil { + return "", err + } + } + + return path, err } func (git *Git) ToFilePath() (string, error) {