Skip to content

Commit

Permalink
Add some archivehelpers
Browse files Browse the repository at this point in the history
  • Loading branch information
bep committed Sep 9, 2022
1 parent 6aacbe1 commit 4ba6d65
Show file tree
Hide file tree
Showing 2 changed files with 303 additions and 0 deletions.
213 changes: 213 additions & 0 deletions archivehelpers/archivehelpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
package archivehelpers

import (
"archive/tar"
"compress/gzip"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)

// The archive types known to this package. TypeUnknown is the zero value of
// Type; New returns an error for it.
const (
// TypeUnknown is the default value for an unknown type.
TypeUnknown Type = iota
// TypeTarGz is a tar.gz archive.
TypeTarGz
)

// New creates an Archiver for archives of the given type.
// TypeTarGz is the only supported type; any other value yields a nil
// Archiver and an error that includes the numeric value of typ.
func New(typ Type) (Archiver, error) {
	if typ != TypeTarGz {
		return nil, fmt.Errorf("unknown type %d", typ)
	}

	return &archivist{archiver: &tarGzExtractor{}}, nil
}

// Archiver is an interface for archiving files and directories.
// It embeds Extracter, so every Archiver can also unpack archives.
type Archiver interface {
Extracter

// ArchiveDirectory archives the given directory into the given output stream
// for all files matching predicate, i.e. the files for which predicate
// returns true.
// out is closed by this method.
ArchiveDirectory(directory string, predicate func(string) bool, out io.WriteCloser) error
}

// Extracter is an interface for unpacking an archive into a directory.
type Extracter interface {
// Extract extracts the given archive into the given directory.
// The tar.gz implementation in this file closes in before returning.
Extract(in io.ReadCloser, targetDir string) error
}

// Type identifies an archive format. Its zero value is TypeUnknown.
type Type int

// String returns "tar.gz" for TypeTarGz and "unknown" for every other value.
func (t Type) String() string {
	if t == TypeTarGz {
		return "tar.gz"
	}

	return "unknown"
}

// archiveAdder writes files, one at a time, to an archive that is being built.
type archiveAdder interface {
// Add adds the file at filename, described by info, to the archive under
// the entry name targetPath.
Add(filename string, info os.FileInfo, targetPath string) error
// Close finishes the archive. The tar.gz implementation in this file also
// closes the output stream it writes to.
Close() error
}

// archiver is the format-specific part of an Archiver: it can extract an
// archive and create an archiveAdder that writes a new archive to out.
type archiver interface {
Extracter
NewArchiveAdder(out io.WriteCloser) archiveAdder
}

// archivist is the Archiver returned by New. It embeds a format-specific
// archiver, which supplies Extract and NewArchiveAdder, and adds the
// directory walking in ArchiveDirectory on top of it.
type archivist struct {
archiver
}

// ArchiveDirectory walks directory and adds every non-directory entry for
// which predicate returns true to a new archive written to out.
//
// predicate receives the path as reported by filepath.Walk. Each file is
// stored under its slash-separated path relative to directory.
//
// The archive adder created for out is always closed before returning, also
// when the walk fails. The walk error takes precedence; if the walk succeeds,
// the error from closing is returned.
func (a *archivist) ArchiveDirectory(directory string, predicate func(string) bool, out io.WriteCloser) (err error) {
	archive := a.archiver.NewArchiveAdder(out)
	defer func() {
		if closeErr := archive.Close(); err == nil {
			err = closeErr
		}
	}()

	return filepath.Walk(directory, func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}

		if info.IsDir() || !predicate(path) {
			return nil
		}

		// filepath.Walk builds child paths with filepath.Join, which cleans
		// them, so path does not necessarily start with directory as given
		// ("./dir", "dir//", "dir/."). Compute the relative path instead of
		// stripping directory as a string prefix.
		rel, err := filepath.Rel(directory, path)
		if err != nil {
			return err
		}

		// Entry names use forward slashes; the Trim guarantees they never
		// start or end with one.
		targetPath := strings.Trim(filepath.ToSlash(rel), "/")
		return archive.Add(path, info, targetPath)
	})
}

// tarGzArchiver is the archiveAdder for tar.gz archives.
// The writers are stacked: tw writes tar data into gw, and gw is created on
// top of out (see NewArchiveAdder).
type tarGzArchiver struct {
out io.WriteCloser // the destination stream; closed by Close
gw *gzip.Writer // gzip layer writing to out
tw *tar.Writer // tar layer writing to gw
}

// Add appends the file at filename to the tar stream as an entry named
// targetPath. The entry's metadata is derived from info; its content is read
// from filename. The first error encountered is returned as-is.
func (a *tarGzArchiver) Add(filename string, info os.FileInfo, targetPath string) error {
	src, err := os.Open(filename)
	if err != nil {
		return err
	}
	defer src.Close()

	hdr, err := tar.FileInfoHeader(info, "") // TODO(bep) symlink handling?
	if err != nil {
		return err
	}
	// FileInfoHeader only knows the base name; store the archive-relative one.
	hdr.Name = targetPath

	if err := a.tw.WriteHeader(hdr); err != nil {
		return err
	}

	_, err = io.Copy(a.tw, src)
	return err
}

// Close finishes the archive and releases the output stream.
//
// The tar writer, the gzip writer and out are closed in that order, since
// each layer flushes into the next. All three are always closed, even when an
// earlier one fails, so out is never leaked. The first error encountered is
// returned.
func (a *tarGzArchiver) Close() error {
	err := a.tw.Close()

	if closeErr := a.gw.Close(); err == nil {
		err = closeErr
	}

	if closeErr := a.out.Close(); err == nil {
		err = closeErr
	}

	return err
}

// tarGzExtractor is the tar.gz implementation of the archiver interface used
// by New: Extract unpacks a tar.gz stream and NewArchiveAdder creates the
// writer for a new one. It carries no state.
type tarGzExtractor struct {
}

// NewArchiveAdder returns an archiveAdder that writes a tar.gz archive to out,
// compressed with gzip.BestCompression. The returned value additionally
// carries a tarGzExtractor as its Extracter.
func (e *tarGzExtractor) NewArchiveAdder(out io.WriteCloser) archiveAdder {
	// The error is deliberately ignored: NewWriterLevel only fails for an
	// invalid compression level, and BestCompression is a valid one.
	gzw, _ := gzip.NewWriterLevel(out, gzip.BestCompression)

	adder := &tarGzArchiver{
		out: out,
		gw:  gzw,
		tw:  tar.NewWriter(gzw),
	}

	return struct {
		archiveAdder
		Extracter
	}{
		archiveAdder: adder,
		Extracter:    &tarGzExtractor{},
	}
}

// Extract unpacks the gzip-compressed tar stream in into targetDir.
//
// Only directories and regular files are supported; any other entry type
// aborts the extraction with an error. Missing parent directories are
// created, and existing files are overwritten. Entries whose name resolves to
// a location outside targetDir (for example "../x") are rejected with an
// error, since the archive may come from an untrusted source.
//
// in is always closed before returning. On error, entries extracted so far
// are left in place.
func (e *tarGzExtractor) Extract(in io.ReadCloser, targetDir string) error {
	defer in.Close()

	gzr, err := gzip.NewReader(in)
	if err != nil {
		return err
	}
	defer gzr.Close()

	tr := tar.NewReader(gzr)

	for {
		header, err := tr.Next()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return err
		}

		target := filepath.Join(targetDir, header.Name)

		// Guard against path traversal ("zip slip"): the cleaned target must
		// still live inside targetDir.
		rel, err := filepath.Rel(targetDir, target)
		if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
			return fmt.Errorf("illegal path %q in archive: resolves outside of target directory", header.Name)
		}

		switch header.Typeflag {
		case tar.TypeDir:
			if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil {
				return err
			}
		case tar.TypeReg:
			// Archives are not required to contain entries for parent directories.
			if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil && !os.IsExist(err) {
				return err
			}
			if err := writeExtractedFile(target, os.FileMode(header.Mode), tr); err != nil {
				return err
			}
		default:
			return fmt.Errorf("unable to untar type: %c in file %s", header.Typeflag, target)
		}
	}
}

// writeExtractedFile creates (or truncates) target with the given mode and
// fills it with the content of r. The file is closed on every path, and a
// failure to close is reported if the copy itself succeeded.
func writeExtractedFile(target string, mode os.FileMode, r io.Reader) (err error) {
	// O_TRUNC: without it, overwriting a longer existing file would leave its
	// old tail behind the new content.
	f, err := os.OpenFile(target, os.O_CREATE|os.O_RDWR|os.O_TRUNC, mode)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := f.Close(); err == nil {
			err = closeErr
		}
	}()

	_, err = io.Copy(f, r)
	return err
}
90 changes: 90 additions & 0 deletions archivehelpers/archivehelpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package archivehelpers

import (
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"testing"

qt "github.com/frankban/quicktest"
)

// TestArchive round-trips a small directory tree through every supported
// archive type: it archives the tree, extracts the archive into a fresh
// directory and compares file modes and names of source and result. This is
// done once for all files and once for a predicate matching a single file.
func TestArchive(t *testing.T) {
	c := qt.New(t)

	_, err := New(32)
	c.Assert(err, qt.IsNotNil)

	c.Assert(TypeTarGz.String(), qt.Equals, "tar.gz")
	c.Assert(TypeUnknown.String(), qt.Equals, "unknown")

	tempDir := t.TempDir()

	for _, tp := range []Type{TypeTarGz} {
		a, err := New(tp)
		c.Assert(err, qt.IsNil)

		sourceDir := t.TempDir()
		subDir := filepath.Join(sourceDir, "subdir")
		c.Assert(os.MkdirAll(subDir, 0755), qt.IsNil)
		c.Assert(os.WriteFile(filepath.Join(sourceDir, "file1.txt"), []byte("hello"), 0644), qt.IsNil)
		c.Assert(os.WriteFile(filepath.Join(subDir, "file2.txt"), []byte("world"), 0643), qt.IsNil)

		// dirList returns one line with mode, permission bits and base name
		// for every file below dirname that matches predicate.
		dirList := func(dirname string, predicate func(string) bool) string {
			var sb strings.Builder
			err := filepath.WalkDir(dirname, func(path string, d fs.DirEntry, err error) error {
				if err != nil {
					return err
				}
				if d.IsDir() || !predicate(path) {
					return nil
				}

				fi, err := d.Info()
				if err != nil {
					return err
				}

				fmt.Fprintf(&sb, "%s %04o %s\n", fi.Mode(), fi.Mode().Perm(), fi.Name())
				return nil
			})

			c.Assert(err, qt.IsNil)
			return sb.String()
		}

		// assertArchive extracts archiveFilename into a fresh directory and
		// checks that it holds the same files as sourceDir, restricted to
		// predicate.
		assertArchive := func(archiveFilename string, predicate func(string) bool) {
			resultDir := t.TempDir()
			in, err := os.Open(archiveFilename)
			c.Assert(err, qt.IsNil)
			c.Assert(a.Extract(in, resultDir), qt.IsNil)

			c.Assert(dirList(sourceDir, predicate), qt.Equals, dirList(resultDir, predicate))
		}

		matchAll := func(string) bool { return true }
		archiveFilename := filepath.Join(tempDir, "myarchive1."+tp.String())
		f, err := os.Create(archiveFilename)
		c.Assert(err, qt.IsNil)
		c.Assert(a.ArchiveDirectory(sourceDir, matchAll, f), qt.IsNil)
		assertArchive(archiveFilename, matchAll)

		matchSome := func(s string) bool {
			return filepath.Base(s) == "file1.txt"
		}
		archiveFilename = filepath.Join(tempDir, "myarchive2."+tp.String())
		f, err = os.Create(archiveFilename)
		c.Assert(err, qt.IsNil)
		c.Assert(a.ArchiveDirectory(sourceDir, matchSome, f), qt.IsNil)
		assertArchive(archiveFilename, matchSome)
	}
}

0 comments on commit 4ba6d65

Please sign in to comment.