Skip to content

Commit

Permalink
feat: implement preserve (bramvdbogaerde#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
datadius authored and jpdoyon committed Jun 28, 2024
1 parent 50ade08 commit af122c4
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 20 deletions.
39 changes: 35 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,14 +315,38 @@ func (a *Client) CopyFromRemotePassThru(
remotePath string,
passThru PassThru,
) error {
_, err := a.copyFromRemote(ctx, w, remotePath, passThru, false)

return err
}

// CopyFroRemoteFileInfos copies a file from the remote to a given writer and return a FileInfos struct
// containing information about the file such as permissions, the file size, modification time and access time
func (a *Client) CopyFromRemoteFileInfos(
ctx context.Context,
w io.Writer,
remotePath string,
passThru PassThru,
) (*FileInfos, error) {
return a.copyFromRemote(ctx, w, remotePath, passThru, true)
}

func (a *Client) copyFromRemote(
ctx context.Context,
w io.Writer,
remotePath string,
passThru PassThru,
preserveFileTimes bool,
) (*FileInfos, error) {
session, err := a.sshClient.NewSession()
if err != nil {
return fmt.Errorf("Error creating ssh session in copy from remote: %v", err)
return nil, fmt.Errorf("Error creating ssh session in copy from remote: %v", err)
}
defer session.Close()

wg := sync.WaitGroup{}
errCh := make(chan error, 4)
var fileInfos *FileInfos

wg.Add(1)
go func() {
Expand All @@ -349,7 +373,11 @@ func (a *Client) CopyFromRemotePassThru(
}
defer in.Close()

err = session.Start(fmt.Sprintf("%s -f %q", a.RemoteBinary, remotePath))
if preserveFileTimes {
err = session.Start(fmt.Sprintf("%s -pf %q", a.RemoteBinary, remotePath))
} else {
err = session.Start(fmt.Sprintf("%s -f %q", a.RemoteBinary, remotePath))
}
if err != nil {
errCh <- err
return
Expand All @@ -367,6 +395,8 @@ func (a *Client) CopyFromRemotePassThru(
return
}

fileInfos = fileInfo

err = Ack(in)
if err != nil {
errCh <- err
Expand Down Expand Up @@ -403,11 +433,12 @@ func (a *Client) CopyFromRemotePassThru(
}

if err := wait(&wg, ctx); err != nil {
return err
return nil, err
}

finalErr := <-errCh
close(errCh)
return finalErr
return fileInfos, finalErr
}

func (a *Client) Close() {
Expand Down
2 changes: 1 addition & 1 deletion configurer.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,6 @@ func (c *ClientConfigurer) Create() Client {
Timeout: c.timeout,
RemoteBinary: c.remoteBinary,
sshClient: c.sshClient,
closeHandler: EmptyHandler{},
closeHandler: EmptyHandler{},
}
}
37 changes: 24 additions & 13 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,19 @@ func ParseResponse(reader io.Reader, writer io.Writer) (*FileInfos, error) {
return nil, err
}

message, err = bufferedReader.ReadString('\n')
if err == io.EOF {
// A custom ssh server can send both time, permissions and size information at once
// without needing an Ack response. Example: wish from charmbracelet while using their default scp implementation
// If the buffer is empty, then it's likely the default implementation for ssh, so send Ack
if bufferedReader.Buffered() == 0 {
err = Ack(writer)
if err != nil {
return fileInfos, err
}
message, err = bufferedReader.ReadString('\n')

if err != nil {
return fileInfos, err
}
}

if err != nil && err != io.EOF {
message, err = bufferedReader.ReadString('\n')

if err != nil {
return fileInfos, err
}

Expand All @@ -102,7 +101,7 @@ func ParseResponse(reader io.Reader, writer io.Writer) (*FileInfos, error) {
type FileInfos struct {
Message string
Filename string
Permissions string
Permissions uint32
Size int64
Atime int64
Mtime int64
Expand All @@ -119,7 +118,7 @@ func (fileInfos *FileInfos) Update(new *FileInfos) {
if new.Filename != "" {
fileInfos.Filename = new.Filename
}
if new.Permissions != "" {
if new.Permissions != 0 {
fileInfos.Permissions = new.Permissions
}
if new.Size != 0 {
Expand All @@ -140,14 +139,19 @@ func ParseFileInfos(message string, fileInfos *FileInfos) error {
return errors.New("unable to parse Chmod protocol")
}

permissions, err := strconv.ParseUint(parts[0][1:], 0, 32)
if err != nil {
return err
}

size, err := strconv.Atoi(parts[1])
if err != nil {
return err
}

fileInfos.Update(&FileInfos{
Filename: parts[2],
Permissions: parts[0],
Permissions: uint32(permissions),
Size: int64(size),
})

Expand All @@ -164,11 +168,18 @@ func ParseFileTime(
return errors.New("unable to parse Time protocol")
}

aTime, err := strconv.Atoi(string(parts[0][0:10]))
if len(parts[0]) != 10 {
return errors.New("length of ATime is not 10")
}
mTime, err := strconv.Atoi(parts[0][0:10])
if err != nil {
return errors.New("unable to parse ATime component of message")
}
mTime, err := strconv.Atoi(string(parts[2][0:10]))

if len(parts[2]) != 10 {
return errors.New("length of MTime is not 10")
}
aTime, err := strconv.Atoi(parts[2][0:10])
if err != nil {
return errors.New("unable to parse MTime component of message")
}
Expand Down
63 changes: 61 additions & 2 deletions tests/basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package scp
import (
"context"
"fmt"
"io/fs"
"os"
"strings"
"testing"
Expand Down Expand Up @@ -207,7 +208,35 @@ func TestDownloadFile(t *testing.T) {
client := establishConnection(t)
defer client.Close()

// Open a file we can transfer to the remote container.
// Create a local file to write to.
f, err := os.OpenFile("./tmp/output.txt", os.O_RDWR|os.O_CREATE, 0777)
if err != nil {
t.Errorf("Couldn't open the output file")
}
defer f.Close()

// Use a file name with exotic characters and spaces in them.
// If this test works for this, simpler files should not be a problem.
err = client.CopyFromRemote(context.Background(), f, "/input/Exöt1ç download file.txt.txt")
if err != nil {
t.Errorf("Copy failed from remote: %s", err.Error())
}

content, err := os.ReadFile("./tmp/output.txt")
if err != nil {
t.Errorf("Result file could not be read: %s", err)
}

text := string(content)
expected := "It works for download!\n"
if strings.Compare(text, expected) != 0 {
t.Errorf("Got different text than expected, expected %q got, %q", expected, text)
}
}

func TestDownloadFileInfo(t *testing.T) {
client := establishConnection(t)
defer client.Close()
f, _ := os.Open("./data/input.txt")
defer f.Close()

Expand All @@ -220,7 +249,12 @@ func TestDownloadFile(t *testing.T) {

// Use a file name with exotic characters and spaces in them.
// If this test works for this, simpler files should not be a problem.
err = client.CopyFromRemote(context.Background(), f, "/input/Exöt1ç download file.txt.txt")
fileInfos, err := client.CopyFromRemoteFileInfos(
context.Background(),
f,
"/input/Exöt1ç download file.txt.txt",
nil,
)
if err != nil {
t.Errorf("Copy failed from remote: %s", err.Error())
}
Expand All @@ -235,6 +269,31 @@ func TestDownloadFile(t *testing.T) {
if strings.Compare(text, expected) != 0 {
t.Errorf("Got different text than expected, expected %q got, %q", expected, text)
}

fileStat, err := os.Stat("./data/Exöt1ç download file.txt.txt")
if err != nil {
t.Errorf("Result file could not be read: %s", err)
}

if fileInfos.Size != fileStat.Size() {
t.Errorf("File size does not match")
}

if fs.FileMode(fileInfos.Permissions) == fs.FileMode(0777) {
t.Errorf(
"File permissions don't match %s vs %s",
fs.FileMode(fileInfos.Permissions),
fileStat.Mode().Perm(),
)
}

if fileInfos.Mtime != fileStat.ModTime().Unix() {
t.Errorf(
"File modification time does not match %d vs %d",
fileInfos.Mtime,
fileStat.ModTime().Unix(),
)
}
}

// TestTimeoutDownload tests that a timeout error is produced if the file is not copied in the given
Expand Down

0 comments on commit af122c4

Please sign in to comment.