diff --git a/lib/files/about/about.go b/lib/files/about/about.go index 68dca2a..27853c3 100644 --- a/lib/files/about/about.go +++ b/lib/files/about/about.go @@ -8,7 +8,7 @@ import ( "fmt" "net/url" "os" - "syscall" + "strings" "time" "github.com/puellanivis/breton/lib/files" @@ -20,41 +20,60 @@ import ( type handler struct{} func init() { - files.RegisterScheme(&handler{}, "about") + files.RegisterScheme(handler{}, "about") } -func (h *handler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { - return nil, files.PathError("create", uri.String(), os.ErrInvalid) +type reader interface { + ReadAll() ([]byte, error) } -type fn func() ([]byte, error) - -func blank() ([]byte, error) { - return nil, nil +type lister interface { + ReadDir() ([]os.FileInfo, error) } -func notfound() ([]byte, error) { - return nil, os.ErrNotExist +type stringFunc func() string + +func (f stringFunc) ReadAll() ([]byte, error) { + return append([]byte(f()), '\n'), nil } -var errUnresolvable = errors.New("unresolvable address") +var ( + blank stringFunc = func() string { return "" } + version stringFunc = func() string { return process.Version() } + now stringFunc = func() string { return time.Now().Truncate(0).String() } +) + +type errorURL struct { + error +} -func unresolvable() ([]byte, error) { - return nil, errUnresolvable +func (e errorURL) ReadAll() ([]byte, error) { + return nil, e.error } -func version() ([]byte, error) { - return append([]byte(process.Version()), '\n'), nil +func (e errorURL) ReadDir() ([]os.FileInfo, error) { + return nil, e.error } +// ErrNoSuchHost defines an error, where a DNS host lookup failed to resolve. +var ErrNoSuchHost = errors.New("no such host") + +var ( + notfound = errorURL{os.ErrNotExist} + unresolvable = errorURL{ErrNoSuchHost} +) + +type aboutMap map[string]reader + var ( - aboutMap = map[string]fn{ + about = aboutMap{ "": version, "blank": blank, "cache": blank, "invalid": notfound, "html-kind": unresolvable, "legacy-compat": unresolvable, + "now": now, "plugins": plugins, "srcdoc": unresolvable, "version": version, @@ -64,113 +83,204 @@ var ( func init() { // if aboutMap references about, then about references aboutMap // and go errors with "initialization loop" - aboutMap["about"] = about + about["about"] = about } -func listOf(list []string) ([]byte, error) { +func (m aboutMap) keys() []string { + var list []string + + for key := range m { + if key == "" || strings.HasPrefix(key, ".") { + continue + } + + list = append(list, key) + } + sort.Strings(list) + return list +} + +func (m aboutMap) ReadAll() ([]byte, error) { + keys := m.keys() + b := new(bytes.Buffer) - for _, item := range list { - fmt.Fprintln(b, item) + for _, key := range keys { + uri := &url.URL{ + Scheme: "about", + Opaque: url.PathEscape(key), + } + + fmt.Fprintln(b, uri) } return b.Bytes(), nil } -func plugins() ([]byte, error) { - return listOf(files.RegisteredSchemes()) +func (m aboutMap) ReadDir() ([]os.FileInfo, error) { + keys := m.keys() + + var infos []os.FileInfo + + for _, key := range keys { + f := m[key] + + data, err := f.ReadAll() + if err != nil { + // skip errorURL endpoints. + continue + } + + uri := &url.URL{ + Path: key, + } + + info := wrapper.NewInfo(uri, len(data), time.Now()) + + if _, ok := f.(lister); ok { + info.Chmod(info.Mode() | os.ModeDir) + } + + infos = append(infos, info) + } + + return infos, nil } -func about() ([]byte, error) { - var list []string +type schemeList struct{} + +func (schemeList) ReadAll() ([]byte, error) { + schemes := files.RegisteredSchemes() - for name := range aboutMap { + b := new(bytes.Buffer) + + for _, scheme := range schemes { uri := &url.URL{ - Scheme: "about", - Opaque: name, + Scheme: scheme, } - list = append(list, uri.String()) + fmt.Fprintln(b, uri) } - return listOf(list) + return b.Bytes(), nil } -func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { +func (schemeList) ReadDir() ([]os.FileInfo, error) { + schemes := files.RegisteredSchemes() + + var infos []os.FileInfo + + for _, scheme := range schemes { + uri := &url.URL{ + Path: scheme, + } + + infos = append(infos, wrapper.NewInfo(uri, 0, time.Now())) + } + + return infos, nil +} + +var ( + plugins schemeList +) + +func (h handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { if uri.Host != "" || uri.User != nil { - return nil, files.PathError("open", uri.String(), os.ErrInvalid) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: files.ErrURLCannotHaveAuthority, + } } path := uri.Path if path == "" { - path = uri.Opaque + var err error + path, err = url.PathUnescape(uri.Opaque) + if err != nil { + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: files.ErrURLInvalid, + } + } } - f, ok := aboutMap[path] + f, ok := about[path] if !ok { - return nil, files.PathError("open", uri.String(), os.ErrNotExist) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: os.ErrNotExist, + } } - data, err := f() + data, err := f.ReadAll() if err != nil { - return nil, files.PathError("open", uri.String(), err) + return nil, &os.PathError{ + Op: "read", + Path: uri.String(), + Err: err, + } } return wrapper.NewReaderFromBytes(data, uri, time.Now()), nil } -func (h *handler) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { +func (h handler) ReadDir(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { if uri.Host != "" || uri.User != nil { - return nil, files.PathError("readdir", uri.String(), os.ErrInvalid) + return nil, &os.PathError{ + Op: "readdir", + Path: uri.String(), + Err: files.ErrURLCannotHaveAuthority, + } } path := uri.Path if path == "" { - path = uri.Opaque - } - - if f, ok := aboutMap[path]; !ok { - return nil, files.PathError("readdir", uri.String(), os.ErrNotExist) - - } else if f != nil { - if _, err := f(); err != nil { - return nil, files.PathError("readdir", uri.String(), err) + var err error + path, err = url.PathUnescape(uri.Opaque) + if err != nil { + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: files.ErrURLInvalid, + } } } - if path != "about" && path != "" { - return nil, files.PathError("readdir", uri.String(), syscall.ENOTDIR) - } - - var list []string - for name := range aboutMap { - list = append(list, name) + if path == "" { + path = "about" } - sort.Strings(list) - - var ret []os.FileInfo - for _, name := range list { - f := aboutMap[name] - - uri := &url.URL{ - Scheme: "about", - Opaque: name, - } - - if f == nil { - continue + f, ok := about[path] + if !ok { + return nil, &os.PathError{ + Op: "readdir", + Path: uri.String(), + Err: os.ErrNotExist, } + } - data, err := f() + if f, ok := f.(lister); ok { + infos, err := f.ReadDir() if err != nil { - continue + return nil, &os.PathError{ + Op: "readdir", + Path: uri.String(), + Err: err, + } } - ret = append(ret, wrapper.NewInfo(uri, len(data), time.Now())) + return infos, nil } - return ret, nil + return nil, &os.PathError{ + Op: "readdir", + Path: uri.String(), + Err: files.ErrNotDirectory, + } } diff --git a/lib/files/about/about_test.go b/lib/files/about/about_test.go new file mode 100644 index 0000000..01ba4f8 --- /dev/null +++ b/lib/files/about/about_test.go @@ -0,0 +1,15 @@ +package aboutfiles + +import ( + "testing" + + "github.com/puellanivis/breton/lib/files" +) + +func TestHandlerFulfillsReadDirFS(t *testing.T) { + var h files.FS = handler{} + + if _, ok := h.(files.ReadDirFS); !ok { + t.Fatal("handler does not implement files.ReadDirFS") + } +} diff --git a/lib/files/cachefiles/cache.go b/lib/files/cachefiles/cache.go index dd170f0..05f0692 100644 --- a/lib/files/cachefiles/cache.go +++ b/lib/files/cachefiles/cache.go @@ -14,23 +14,20 @@ import ( ) type line struct { - os.FileInfo - + info os.FileInfo data []byte } -// FileStore is a caching structure that holds copies of the content of files. -type FileStore struct { +// FS is a caching structure that holds copies of the content of files. +type FS struct { sync.RWMutex cache map[string]*line } -// New returns a new caching FileStore, which can be registered into lib/files -func New() *FileStore { - return &FileStore{ - cache: make(map[string]*line), - } +// New returns a new caching FS, which can be registered into lib/files +func New() *FS { + return &FS{} } // Default is the default cache attached to the "cache" Scheme @@ -40,81 +37,139 @@ func init() { files.RegisterScheme(Default, "cache") } -func (h *FileStore) expire(filename string) { +func (h *FS) expire(filename string) { h.Lock() defer h.Unlock() delete(h.cache, filename) } -func trimScheme(uri *url.URL) string { - if uri.Scheme == "" { - return uri.String() +func resolveReference(uri *url.URL) (string, error) { + if uri.Host != "" || uri.User != nil { + return "", files.ErrURLCannotHaveAuthority } - return uri.String()[len(uri.Scheme)+1:] + if uri.Path != "" { + return uri.Path, nil + } + + path, err := url.PathUnescape(uri.Opaque) + if err != nil { + return "", files.ErrURLInvalid + } + + return path, nil } -// Create implements the files.FileStore Create. At this time, it just returns the files.Create() from the wrapped url. -func (h *FileStore) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { - return files.Create(ctx, trimScheme(uri)) +// Create implements files.CreateFS. +// At this time, it just returns the files.Create() from the wrapped url. +func (h *FS) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { + filename, err := resolveReference(uri) + if err != nil { + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } + } + + return files.Create(ctx, filename) } -// Open implements the files.FileStore Open. It returns a buffered copy of the files.Reader returned from reading the uri escaped by the "cache:" scheme. Any access within the next ExpireTime set by the context.Context (5 minutes by default) will return a new copy of an bytes.Reader of the same buffer. -func (h *FileStore) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { +// Open implements files.FS. +// It returns a buffered copy of the files.Reader returned from reading the uri escaped by the "cache:" scheme. +// Any access within the next ExpireTime set by the context.Context (or 5 minutes by default) will return a new copy of a files.Reader, with the same content. +func (h *FS) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { + filename, err := resolveReference(uri) + if err != nil { + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } + } + + ctx, safe := isReentrySafe(ctx) + if !safe { + // We are in a rentrant caching scenario. + // Continuing will deadlock, so we won’t even try to cache at all. + return files.Open(ctx, filename) + } + + h.RLock() + f, ok := h.cache[filename] + h.RUnlock() + + if ok { + return wrapper.NewReaderWithInfo(bytes.NewReader(f.data), f.info), nil + } + + // default 5 minute expiration + expiration := 5 * time.Minute + if d, ok := GetExpire(ctx); ok { + expiration = d + } + h.Lock() defer h.Unlock() - filename := trimScheme(uri) + f, ok = h.cache[filename] - f, ok := h.cache[filename] + // We have to test existence again. + // Maybe another thread already did our work. if !ok { - if _, ok := ctx.Deadline(); !ok { - // default 5 minute expire time - d := 5 * time.Minute - if t, ok := GetExpire(ctx); ok { - d = t - } - - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, d) - defer cancel() - } - raw, err := files.Open(ctx, filename) if err != nil { return nil, err } + info, err := raw.Stat() + if err != nil { + info = nil // safety guard. + } + data, err := files.ReadFrom(raw) if err != nil { return nil, err } - info, err := raw.Stat() - if err != nil { + if info == nil { info = wrapper.NewInfo(uri, len(data), time.Now()) } f = &line{ - data: data, - FileInfo: info, + data: data, + info: info, + } + + if h.cache == nil { + h.cache = make(map[string]*line) } h.cache[filename] = f + timer := time.NewTimer(expiration) go func() { - defer h.expire(filename) - - <-ctx.Done() + <-timer.C + h.expire(filename) }() } - return wrapper.NewReaderWithInfo(bytes.NewReader(f.data), f.FileInfo), nil + return wrapper.NewReaderWithInfo(bytes.NewReader(f.data), f.info), nil } -// List implements the files.FileStore List. It does not cache anything and just returns the files.List() from the wrapped url. -func (h *FileStore) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { - return files.List(ctx, trimScheme(uri)) +// ReadDir implements files.ReadDirFS. +// It does not cache anything and just returns the files.ReadDir() from the wrapped url. +func (h *FS) ReadDir(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { + filename, err := resolveReference(uri) + if err != nil { + return nil, &os.PathError{ + Op: "readdir", + Path: uri.String(), + Err: err, + } + } + + return files.ReadDir(ctx, filename) } diff --git a/lib/files/cachefiles/cache_test.go b/lib/files/cachefiles/cache_test.go new file mode 100644 index 0000000..ac832e5 --- /dev/null +++ b/lib/files/cachefiles/cache_test.go @@ -0,0 +1,23 @@ +package cachefiles + +import ( + "testing" + + "github.com/puellanivis/breton/lib/files" +) + +func TestHandlerFulfillsReadDirFS(t *testing.T) { + var h files.FS = &FS{} + + if _, ok := h.(files.ReadDirFS); !ok { + t.Fatal("handler does not implement files.ReadDirFS") + } +} + +func TestHandlerFulfillsCreateFS(t *testing.T) { + var h files.FS = &FS{} + + if _, ok := h.(files.CreateFS); !ok { + t.Fatal("handler does not implement files.CreateFS") + } +} diff --git a/lib/files/cachefiles/context.go b/lib/files/cachefiles/context.go index 49c3fa4..20dc2c7 100644 --- a/lib/files/cachefiles/context.go +++ b/lib/files/cachefiles/context.go @@ -6,7 +6,8 @@ import ( ) type ( - expireKey struct{} + expireKey struct{} + reentranceKey struct{} ) // WithExpire returns a Context that includes information for the cache FileStore to expire buffers after the given timeout. @@ -20,3 +21,11 @@ func GetExpire(ctx context.Context) (time.Duration, bool) { return timeout, ok } + +func isReentrySafe(ctx context.Context) (context.Context, bool) { + if v := ctx.Value(reentranceKey{}); v != nil { + return ctx, false + } + + return context.WithValue(ctx, reentranceKey{}, struct{}{}), true +} diff --git a/lib/files/clipboard/clip.go b/lib/files/clipboard/clip.go index 7085cfc..526209c 100644 --- a/lib/files/clipboard/clip.go +++ b/lib/files/clipboard/clip.go @@ -5,7 +5,6 @@ import ( "context" "net/url" "os" - "syscall" "time" "github.com/puellanivis/breton/lib/files" @@ -28,12 +27,16 @@ var clipboards = make(map[string]clipboard) func getClip(uri *url.URL) (clipboard, error) { if uri.Host != "" || uri.User != nil { - return nil, os.ErrInvalid + return nil, files.ErrURLCannotHaveAuthority } path := uri.Path if path == "" { - path = uri.Opaque + var err error + path, err = url.PathUnescape(uri.Opaque) + if err != nil { + return nil, files.ErrURLInvalid + } } clip := clipboards[path] @@ -44,10 +47,14 @@ func getClip(uri *url.URL) (clipboard, error) { return clip, nil } -func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { +func (handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { clip, err := getClip(uri) if err != nil { - return nil, files.PathError("open", uri.String(), err) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } } b, err := clip.Read() @@ -58,10 +65,14 @@ func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) return wrapper.NewReaderFromBytes(b, uri, time.Now()), nil } -func (h *handler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { +func (handler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { clip, err := getClip(uri) if err != nil { - return nil, files.PathError("create", uri.String(), err) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } } return wrapper.NewWriter(ctx, uri, func(b []byte) error { @@ -69,34 +80,62 @@ func (h *handler) Create(ctx context.Context, uri *url.URL) (files.Writer, error }), nil } -func (h *handler) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { +func (handler) ReadDir(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { if uri.Host != "" || uri.User != nil { - return nil, files.PathError("readdir", uri.String(), os.ErrInvalid) + return nil, &os.PathError{ + Op: "readdir", + Path: uri.String(), + Err: files.ErrURLCannotHaveAuthority, + } } path := uri.Path if path == "" { - path = uri.Opaque + var err error + path, err = url.PathUnescape(uri.Opaque) + if err != nil { + return nil, files.ErrURLInvalid + } } clip := clipboards[path] if clip == nil { - return nil, files.PathError("readdir", uri.String(), os.ErrNotExist) + return nil, &os.PathError{ + Op: "readdir", + Path: uri.String(), + Err: os.ErrNotExist, + } } if path != "" { - return nil, files.PathError("readdir", uri.String(), syscall.ENOTDIR) - } - - if len(clipboards) < 1 { - return nil, files.PathError("readdir", uri.String(), os.ErrNotExist) + return nil, &os.PathError{ + Op: "readdir", + Path: uri.String(), + Err: files.ErrNotDirectory, + } } var ret []os.FileInfo for _, clip := range clipboards { - if fi, err := clip.Stat(); err == nil { - ret = append(ret, fi) + if info, err := clip.Stat(); err == nil { + if fi, ok := info.(interface{ URL() *url.URL }); ok { + u := fi.URL() + u.Scheme = "" + + switch fi := fi.(type) { + case interface{ SetNameFromURL(*url.URL) }: + fi.SetNameFromURL(u) + + case interface{ SetName(string) }: + fi.SetName(u.String()) + + default: + info = wrapper.NewInfo(u, int(info.Size()), info.ModTime()) + } + } + + ret = append(ret, info) } } diff --git a/lib/files/clipboard/exec.go b/lib/files/clipboard/exec.go index f37a571..6c5436a 100644 --- a/lib/files/clipboard/exec.go +++ b/lib/files/clipboard/exec.go @@ -24,7 +24,7 @@ func newExecClip(name string, target ...string) { clipboards[name] = &execClip{ name: &url.URL{ Scheme: "clipboard", - Opaque: name, + Opaque: url.PathEscape(name), }, paste: append(pasteCmd, append(selParam, target...)...), copy: append(copyCmd, append(selParam, target...)...), diff --git a/lib/files/context.go b/lib/files/context.go index 7ae8804..4d95eaf 100644 --- a/lib/files/context.go +++ b/lib/files/context.go @@ -3,48 +3,25 @@ package files import ( "context" "net/url" - "path/filepath" ) type ( rootKey struct{} ) +func withRoot(ctx context.Context, root *url.URL) context.Context { + return context.WithValue(ctx, rootKey{}, root) +} + // WithRootURL attaches a url.URL to a Context // and is used as the resolution reference for any files.Open() using that context. func WithRootURL(ctx context.Context, uri *url.URL) context.Context { - uriCopy := uri - uri = resolveFilename(ctx, uri) - - // we got the same URL back, clone it so that it stays immutable to the original uri passed in. - if uriCopy == uri { - uriCopy = new(url.URL) - *uriCopy = *uri - - if uri.User != nil { - uriCopy.User = new(url.Userinfo) - *uriCopy.User = *uri.User // gotta copy this pointer struct also. - } - - uri = uriCopy - } - - return context.WithValue(ctx, rootKey{}, uri) + return withRoot(ctx, resolveURL(ctx, uri)) } // WithRoot stores either a URL or a local path to use as a root point when resolving filenames. func WithRoot(ctx context.Context, path string) (context.Context, error) { - if filepath.IsAbs(path) { - path = filepath.Clean(path) - return WithRootURL(ctx, makePath(path)), nil - } - - uri, err := url.Parse(path) - if err != nil { - return ctx, err - } - - return WithRootURL(ctx, uri), nil + return withRoot(ctx, parsePath(ctx, path)), nil } func getRoot(ctx context.Context) (*url.URL, bool) { @@ -60,7 +37,7 @@ func GetRoot(ctx context.Context) (string, bool) { } if isPath(root) { - return getPath(root), true + return root.Path, true } return root.String(), true diff --git a/lib/files/create.go b/lib/files/create.go index c1a567c..1e26e65 100644 --- a/lib/files/create.go +++ b/lib/files/create.go @@ -4,12 +4,18 @@ import ( "context" "net/url" "os" - "path/filepath" ) -// Create takes a context and a filename (which may be a URL) and returns a -// files.Writer that allows writing data to that local filename or URL. All -// errors and reversion functions returned by Option arguments are discarded. +// CreateFS defines an extention interface on FS, which also provides an ability to create a new file for read/write. +type CreateFS interface { + FS + Create(ctx context.Context, uri *url.URL) (Writer, error) +} + +// Create takes a context and a filename (which may be a URL) and +// returns a files.Writer that allows writing data to that local filename or URL. +// +// All errors and reversion functions returned by Option arguments are discarded. func Create(ctx context.Context, filename string, options ...Option) (Writer, error) { f, err := create(ctx, filename) if err != nil { @@ -31,17 +37,30 @@ func create(ctx context.Context, filename string) (Writer, error) { return os.Stderr, nil } - if filepath.IsAbs(filename) { + uri := parsePath(ctx, filename) + if isPath(uri) { return os.Create(filename) } - if uri, err := url.Parse(filename); err == nil { - uri = resolveFilename(ctx, uri) - - if fs, ok := getFS(uri); ok { - return fs.Create(ctx, uri) + fsys, ok := getFS(uri) + if !ok { + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: ErrNotSupported, } } - return os.Create(filename) + switch fsys := fsys.(type) { + case CreateFS: + return fsys.Create(ctx, uri) + + // case OpenFileFS: // implement + } + + return nil, &os.PathError{ + Op: "create", + Path: filename, + Err: ErrNotSupported, + } } diff --git a/lib/files/datafiles/data.go b/lib/files/datafiles/data.go index aff85cd..0632eed 100644 --- a/lib/files/datafiles/data.go +++ b/lib/files/datafiles/data.go @@ -17,45 +17,54 @@ import ( type handler struct{} func init() { - files.RegisterScheme(&handler{}, "data") + files.RegisterScheme(handler{}, "data") } var b64enc = base64.StdEncoding -func (h *handler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { - return nil, files.PathError("create", uri.String(), os.ErrInvalid) -} - -type withHeaders struct { +type withHeader struct { files.Reader header http.Header } -func (w *withHeaders) Header() http.Header { +func (w *withHeader) Header() http.Header { return w.header } -func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { +func (handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { if uri.Host != "" || uri.User != nil { - return nil, files.PathError("open", uri.String(), os.ErrInvalid) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: files.ErrURLCannotHaveAuthority, + } } path := uri.Path if path == "" { - path = uri.Opaque - if p, err := url.PathUnescape(path); err == nil { - path = p + var err error + path, err = url.PathUnescape(uri.Opaque) + if err != nil { + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } } } i := strings.IndexByte(path, ',') if i < 0 { - return nil, files.PathError("open", uri.String(), os.ErrInvalid) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: files.ErrURLInvalid, + } } contentType, data := path[:i], []byte(path[i+1:]) - var isBase64 bool + var isBase64 bool if strings.HasSuffix(contentType, ";base64") { contentType = strings.TrimSuffix(contentType, ";base64") isBase64 = true @@ -65,26 +74,27 @@ func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) contentType = "text/plain;charset=US-ASCII" } - header := make(http.Header) - header.Set("Content-Type", contentType) + header := http.Header{ + "Content-Type": []string{contentType}, + } if isBase64 { b := make([]byte, b64enc.DecodedLen(len(data))) n, err := b64enc.Decode(b, data) if err != nil { - return nil, files.PathError("decode", uri.String(), err) + return nil, &os.PathError{ + Op: "decode_base64", + Path: uri.String(), + Err: err, + } } data = b[:n] } - return &withHeaders{ + return &withHeader{ Reader: wrapper.NewReaderFromBytes(data, uri, time.Now()), header: header, }, nil } - -func (h *handler) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { - return nil, files.PathError("readdir", uri.String(), os.ErrInvalid) -} diff --git a/lib/files/datafiles/data_test.go b/lib/files/datafiles/data_test.go index ce2b395..0cb4238 100644 --- a/lib/files/datafiles/data_test.go +++ b/lib/files/datafiles/data_test.go @@ -14,139 +14,112 @@ type headerer interface { Header() http.Header } -func TestDataURL(t *testing.T) { - uri, err := url.Parse("data:,ohai%2A") - if err != nil { - t.Fatal("unexpected error parsing constant URL", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - f, err := (&handler{}).Open(ctx, uri) - if err != nil { - t.Fatal("unexpected error", err) - } - defer f.Close() - - b, err := ioutil.ReadAll(f) - if err != nil { - t.Fatal("unexpected error", err) - } - - expected := []byte("ohai*") - - if !bytes.Equal(b, expected) { - t.Errorf("got wrong content for data:,ohai%%2A got %v, wanted %v", b, expected) - } - - h, ok := f.(headerer) - if !ok { - t.Fatalf("returned files.Reader does not implement interface{ Header() (http.Header, error}") - } - - header := h.Header() - if err != nil { - t.Fatal("unexpected error", err) - } - - expectedContentType := "text/plain;charset=US-ASCII" - if got := header.Get("Content-Type"); got != expectedContentType { - t.Errorf("unexpected Content-Type header, got %q, wanted %q", got, expectedContentType) - } -} - -func TestDataURLBadBase64(t *testing.T) { - uri, err := url.Parse("data:base64,b2hhaSo=") - if err != nil { - t.Fatal("unexpected error parsing constant URL", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - f, err := (&handler{}).Open(ctx, uri) - if err != nil { - t.Fatal("unexpected error", err) - } - defer f.Close() - - b, err := ioutil.ReadAll(f) - if err != nil { - t.Fatal("unexpected error", err) - } - - expected := []byte("b2hhaSo=") - - if !bytes.Equal(b, expected) { - t.Errorf("got wrong content for data:base64,b2hhaSo= got %v, wanted %v", b, expected) - } -} - -func TestDataURLSimpleBase64(t *testing.T) { - uri, err := url.Parse("data:;base64,b2hhaSo=") - if err != nil { - t.Fatal("unexpected error parsing constant URL", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - f, err := (&handler{}).Open(ctx, uri) - if err != nil { - t.Fatal("unexpected error", err) - } - defer f.Close() - - b, err := ioutil.ReadAll(f) - if err != nil { - t.Fatal("unexpected error", err) - } - - expected := []byte("ohai*") - - if !bytes.Equal(b, expected) { - t.Errorf("got wrong content for data:base64,b2hhaSo= got %v, wanted %v", b, expected) - } -} - -func TestDataURLComplexBase64(t *testing.T) { - uri, err := url.Parse("data:;base64,ohai+/Z=") - if err != nil { - t.Fatal("unexpected error parsing constant URL", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - f, err := (&handler{}).Open(ctx, uri) - if err != nil { - t.Fatal("unexpected error", err) - } - defer f.Close() - - b, err := ioutil.ReadAll(f) - if err != nil { - t.Fatal("unexpected error", err) - } - - expected := []byte{162, 22, 162, 251, 246} - - if !bytes.Equal(b, expected) { - t.Errorf("got wrong content for data:base64,ohai+/Z= got %v, wanted %v", b, expected) +var testHandler handler + +func TestDataURLs(t *testing.T) { + type test struct { + name string + input string + expected []byte + expectedContentType string + } + + tests := []test{ + { + name: "correctly encoded text with default media type", + input: "data:,ohai%2A", + expected: []byte("ohai*"), + }, + { + name: "correctly encoded text with media type", + input: "data:type/subtype;foo=bar,ohai%2A", + expected: []byte("ohai*"), + + expectedContentType: "type/subtype;foo=bar", + }, + { + name: "correctly encoded base64 text with default media type", + input: "data:;base64,b2hhaSo=", + expected: []byte("ohai*"), + }, + { + name: "correctly encoded base64 binary data with default media type", + input: "data:;base64,ohai+/Z=", + expected: []byte{162, 22, 162, 251, 246}, + }, + { + name: "correctly encoded base64 with media type", + input: "data:type/subtype;foo=bar;base64,b2hhaSo=", + expected: []byte("ohai*"), + + expectedContentType: "type/subtype;foo=bar", + }, + { + name: "incorrectly encoded base64 directive is actually media type", + input: "data:base64,b2hhaSo=", + expected: []byte("b2hhaSo="), + + expectedContentType: "base64", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + uri, err := url.Parse(tt.input) + if err != nil { + t.Fatal("unexpected error:", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + f, err := testHandler.Open(ctx, uri) + if err != nil { + t.Fatal("unexpected error:", err) + } + defer f.Close() + + b, err := ioutil.ReadAll(f) + if err != nil { + t.Fatal("unexpected error:", err) + } + + if !bytes.Equal(b, tt.expected) { + t.Errorf("got wrong content for %q got %q, wanted %q", tt.input, b, tt.expected) + } + + h, ok := f.(headerer) + if !ok { + t.Fatalf("returned files.Reader does not implement interface{ Header() http.Header }") + } + + header := h.Header() + if err != nil { + t.Fatal("unexpected error:", err) + } + + expectedContentType := "text/plain;charset=US-ASCII" + if tt.expectedContentType != "" { + expectedContentType = tt.expectedContentType + } + + if got := header.Get("Content-Type"); got != expectedContentType { + t.Errorf("Content-Type header was %q, expected %q", got, expectedContentType) + } + }) } } func TestDataURLNoComma(t *testing.T) { uri, err := url.Parse("data:ohai%2A") if err != nil { - t.Fatal("unexpected error parsing constant URL", err) + t.Fatal("unexpected error:", err) } ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - f, err := (&handler{}).Open(ctx, uri) + f, err := testHandler.Open(ctx, uri) if err == nil { f.Close() t.Fatal("expected error but got none") @@ -156,13 +129,13 @@ func TestDataURLNoComma(t *testing.T) { func TestDataURLWithHost(t *testing.T) { uri, err := url.Parse("data://host/,ohai%2A") if err != nil { - t.Fatal("unexpected error parsing constant URL", err) + t.Fatal("unexpected error:", err) } ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - f, err := (&handler{}).Open(ctx, uri) + f, err := testHandler.Open(ctx, uri) if err == nil { f.Close() t.Fatal("expected error but got none") @@ -172,58 +145,15 @@ func TestDataURLWithHost(t *testing.T) { func TestDataURLWithUser(t *testing.T) { uri, err := url.Parse("data://user@/,ohai%2A") if err != nil { - t.Fatal("unexpected error parsing constant URL", err) + t.Fatal("unexpected error:", err) } ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() - f, err := (&handler{}).Open(ctx, uri) + f, err := testHandler.Open(ctx, uri) if err == nil { f.Close() t.Fatal("expected error but got none") } } - -func TestDataWithHeader(t *testing.T) { - uriString := "data:type/subtype;foo=bar;base64,b2hhaSo=" - uri, err := url.Parse(uriString) - if err != nil { - t.Fatal("unexpected error parsing constant URL", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - - f, err := (&handler{}).Open(ctx, uri) - if err != nil { - t.Fatal("unexpected error", err) - } - defer f.Close() - - b, err := ioutil.ReadAll(f) - if err != nil { - t.Fatal("unexpected error", err) - } - - expected := []byte("ohai*") - - if !bytes.Equal(b, expected) { - t.Errorf("got wrong content for %s got %v, wanted %v", uriString, b, expected) - } - - h, ok := f.(headerer) - if !ok { - t.Fatalf("returned files.Reader does not implement interface{ Header() (http.Header, error}") - } - - header := h.Header() - if err != nil { - t.Fatal("unexpected error", err) - } - - expectedContentType := "type/subtype;foo=bar" - if got := header.Get("Content-Type"); got != expectedContentType { - t.Errorf("unexpected Content-Type header, got %q, wanted %q", got, expectedContentType) - } -} diff --git a/lib/files/errors.go b/lib/files/errors.go index 7510461..dbfd448 100644 --- a/lib/files/errors.go +++ b/lib/files/errors.go @@ -1,13 +1,17 @@ package files import ( + "errors" "os" + "syscall" ) -// PathError returns an *os.PathError with appropriate fields set. DO NOT USE. +// PathError is DEPRECATED, and returns an *os.PathError with appropriate fields set. DO NOT USE. // // This is a stop-gap quick-replace to remove `&os.PathError{ op, path, err }`. // One should use the direct complex literal instruction instead. +// +// DEPRECATED: use &os.PathError{} directly. func PathError(op, path string, err error) error { return &os.PathError{ Op: op, @@ -15,3 +19,52 @@ func PathError(op, path string, err error) error { Err: err, } } + +var ( + // ErrNotSupported should be returned, if a particular feature or option is not supported. + ErrNotSupported = errors.New("not supported") + + // ErrNotDirectory should be returned, if a request is made to ReadDir a non-directory. + ErrNotDirectory = syscall.ENOTDIR +) + +type invalidURLError struct { + s string +} + +func (e *invalidURLError) Error() string { + return e.s +} + +func (e *invalidURLError) Unwrap() error { + if e == ErrURLInvalid { + return os.ErrInvalid + } + + return ErrURLInvalid +} + +// NewInvalidURLError returns an error that formats as the given text, +// and errors.Is: files.ErrURLInvalid and os.ErrInvalid. +func NewInvalidURLError(reason string) error { + return &invalidURLError{ + s: reason, + } +} + +var ( + // ErrURLInvalid should be returned, if the URL is syntactically valid, but semantically invalid. + ErrURLInvalid = NewInvalidURLError("invalid url") + + // ErrURLCannotHaveAuthority should be returned, if the URL scheme does not allow a non-empty authority section. + ErrURLCannotHaveAuthority = NewInvalidURLError("invalid url: cannot have authority") + + // ErrURLNoHost should be return, if the URL scheme requires the authority section to not specify a host. + ErrURLNoHost = NewInvalidURLError("invalid url: scheme cannot have host in authority") + + // ErrURLHostRequired should be returned, if the URL scheme requires a host in the authority section. + ErrURLHostRequired = NewInvalidURLError("invalid url: scheme requires host in authority") + + // ErrURLPathRequired should be returned, if the URL scheme requires a non-empty path. + ErrURLPathRequired = NewInvalidURLError("invalid url: scheme requires non-empty path") +) diff --git a/lib/files/fd.go b/lib/files/fd.go index f268387..6edc20d 100644 --- a/lib/files/fd.go +++ b/lib/files/fd.go @@ -4,36 +4,77 @@ import ( "context" "net/url" "os" - "strconv" + "strings" ) type descriptorHandler struct{} func init() { - RegisterScheme(&descriptorHandler{}, "fd") + RegisterScheme(descriptorHandler{}, "fd") } -func (h *descriptorHandler) open(uri *url.URL) (*os.File, error) { - fd, err := strconv.ParseUint(filename(uri), 0, 64) +func openFD(uri *url.URL) (*os.File, error) { + if uri.Host != "" || uri.User != nil { + return nil, ErrURLCannotHaveAuthority + } + + num := strings.TrimPrefix(uri.Path, "/") + if num == "" { + var err error + num, err = url.PathUnescape(uri.Opaque) + if err != nil { + return nil, ErrURLInvalid + } + } + + fd, err := resolveFileHandle(num) if err != nil { return nil, err } + // Canonicalize the name. + uri = &url.URL{ + Scheme: "fd", + Opaque: url.PathEscape(num), + } + return os.NewFile(uintptr(fd), uri.String()), nil } -func (h *descriptorHandler) Open(ctx context.Context, uri *url.URL) (Reader, error) { - return h.open(uri) +func (descriptorHandler) Open(ctx context.Context, uri *url.URL) (Reader, error) { + f, err := openFD(uri) + if err != nil { + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } + } + + return f, nil } -func (h *descriptorHandler) Create(ctx context.Context, uri *url.URL) (Writer, error) { - return h.open(uri) +func (descriptorHandler) Create(ctx context.Context, uri *url.URL) (Writer, error) { + f, err := openFD(uri) + if err != nil { + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } + } + + return f, nil } -func (h *descriptorHandler) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { - f, err := h.open(uri) +func (descriptorHandler) ReadDir(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { + f, err := openFD(uri) if err != nil { - return nil, err + return nil, &os.PathError{ + Op: "readdir", + Path: uri.String(), + Err: err, + } } defer f.Close() diff --git a/lib/files/files.go b/lib/files/files.go index db20d0d..4f37850 100644 --- a/lib/files/files.go +++ b/lib/files/files.go @@ -84,22 +84,33 @@ import ( "os" ) -// File defines an interface that abstracts the concept of files to allow for multiple types implementation, beyond the local filesystem. +// File defines an interface that abstracts the central core concepts of files, for broad implementation. type File interface { io.Closer Name() string Stat() (os.FileInfo, error) } -// Reader defines a files.File that is also an io.ReadSeeker +// Reader defines an extension interface on files.File that is also an io.Reader. type Reader interface { File - io.ReadSeeker + io.Reader } -// Writer defines a files.File that is also an io.Writer with a Sync() function. +// SeekReader defines an extension interface on files.Reader that is also an io.Seeker +type SeekReader interface { + Reader + io.Seeker +} + +// Writer defines an extention interface on files.File that is also an io.Writer. type Writer interface { File io.Writer +} + +// SyncWriter defines an extension interface on files.Writer that also supports Sync(). +type SyncWriter interface { + Writer Sync() error } diff --git a/lib/files/filestore.go b/lib/files/filestore.go index 3c2c21e..511f9ba 100644 --- a/lib/files/filestore.go +++ b/lib/files/filestore.go @@ -4,8 +4,6 @@ import ( "context" "net/url" "os" - "sort" - "sync" ) // FileStore defines an interface which implements a system of accessing files for reading (Open) writing (Write) and directly listing (List) @@ -14,63 +12,3 @@ type FileStore interface { Create(ctx context.Context, uri *url.URL) (Writer, error) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) } - -var fsMap struct { - sync.Mutex - - m map[string]FileStore - keys []string - sorted bool -} - -func getFS(uri *url.URL) (FileStore, bool) { - fsMap.Lock() - defer fsMap.Unlock() - - if fsMap.m == nil { - return nil, false - } - - fs, ok := fsMap.m[uri.Scheme] - return fs, ok -} - -// RegisterScheme takes a FileStore and attaches to it the given schemes so -// that files.Open will use that FileStore when a files.Open() is performed -// with a URL of any of those schemes. -func RegisterScheme(fs FileStore, schemes ...string) { - if len(schemes) < 1 { - return - } - - fsMap.Lock() - defer fsMap.Unlock() - - if fsMap.m == nil { - fsMap.m = make(map[string]FileStore) - } - fsMap.sorted = false - - for _, scheme := range schemes { - if _, ok := fsMap.m[scheme]; ok { - // TODO: report duplicate scheme registration - continue - } - - fsMap.m[scheme] = fs - fsMap.keys = append(fsMap.keys, scheme) - } -} - -// RegisteredSchemes returns a slice of strings that describe all registered schemes. -func RegisteredSchemes() []string { - fsMap.Lock() - defer fsMap.Unlock() - - if !fsMap.sorted { - sort.Strings(fsMap.keys) - fsMap.sorted = true - } - - return fsMap.keys -} diff --git a/lib/files/filesystem.go b/lib/files/filesystem.go new file mode 100644 index 0000000..8582115 --- /dev/null +++ b/lib/files/filesystem.go @@ -0,0 +1,69 @@ +package files + +import ( + "context" + "net/url" + "sort" + "sync" +) + +// FS defines an interface, which implements a minimal set of functionality for a filesystem from this package. +type FS interface { + Open(ctx context.Context, uri *url.URL) (Reader, error) +} + +var fsMap struct { + sync.Mutex + + m map[string]FS + keys []string + sorted bool +} + +func getFS(uri *url.URL) (FS, bool) { + fsMap.Lock() + defer fsMap.Unlock() + + fs, ok := fsMap.m[uri.Scheme] + return fs, ok +} + +// RegisterScheme takes an FS and attaches to it the given schemes so +// that files.Open will use that FS when a files.Open() is performed +// with a URL with any of those schemes. +func RegisterScheme(fs FS, schemes ...string) { + if len(schemes) < 1 { + return + } + + fsMap.Lock() + defer fsMap.Unlock() + + if fsMap.m == nil { + fsMap.m = make(map[string]FS) + } + fsMap.sorted = false + + for _, scheme := range schemes { + if _, ok := fsMap.m[scheme]; ok { + // TODO: report duplicate scheme registration + continue + } + + fsMap.m[scheme] = fs + fsMap.keys = append(fsMap.keys, scheme) + } +} + +// RegisteredSchemes returns a slice of strings that describe all registered schemes. +func RegisteredSchemes() []string { + fsMap.Lock() + defer fsMap.Unlock() + + if !fsMap.sorted { + sort.Strings(fsMap.keys) + fsMap.sorted = true + } + + return append([]string(nil), fsMap.keys...) +} diff --git a/lib/files/home/filename.go b/lib/files/home/filename.go new file mode 100644 index 0000000..9c0e1b9 --- /dev/null +++ b/lib/files/home/filename.go @@ -0,0 +1,129 @@ +// Package home implements a URL scheme "home:" which references files according to user home directories. +package home + +import ( + "fmt" + "net/url" + "os" + "os/user" + "path/filepath" + "sync" + + "github.com/puellanivis/breton/lib/files" +) + +type cache struct { + mu sync.RWMutex + + cur *user.User + users map[string]*user.User +} + +func (c *cache) current() (*user.User, error) { + c.mu.RLock() + u := c.cur + c.mu.RUnlock() + + if u != nil { + return u, nil + } + + c.mu.Lock() + defer c.mu.Unlock() + + if u := c.cur; u != nil { + // Another thread already did the work. + return u, nil + } + + u, err := user.Current() + if err != nil { + return nil, err + } + + c.cur = u + + if c.users == nil { + c.users = make(map[string]*user.User) + } + + c.users[u.Username] = u + + return u, nil +} + +func (c *cache) lookup(username string) (*user.User, error) { + c.mu.RLock() + u := c.users[username] + c.mu.RUnlock() + + if u != nil { + return u, nil + } + + c.mu.Lock() + defer c.mu.Unlock() + + if u = c.users[username]; u != nil { + // Another thread already did the work. + return u, nil + } + + u, err := user.Lookup(username) + if err != nil { + return nil, err + } + + if c.users == nil { + c.users = make(map[string]*user.User) + } + + c.users[username] = u + + return u, nil +} + +var users cache + +// Filename takes a given url, and returns a filename that is an absolute path +// for the specific default user if home:filename, or a specific user if home://user@/filename. +func Filename(uri *url.URL) (string, error) { + if uri.Host != "" { + return "", files.ErrURLNoHost + } + + path := uri.Path + if path == "" { + var err error + path, err = url.PathUnescape(uri.Opaque) + if err != nil { + return "", files.ErrURLInvalid + } + } + + var base string + + switch uri.User { + case nil: + u, err := users.current() + if err != nil { + return "", err + } + + base = u.HomeDir + + default: + u, err := users.lookup(uri.User.Username()) + if err != nil { + return "", err + } + + base = u.HomeDir + } + + if base == "" { + return "", fmt.Errorf("could not find home directory: %w", os.ErrNotExist) + } + + return filepath.Join(base, path), nil +} diff --git a/lib/files/home/home.go b/lib/files/home/home.go index acd29b1..635ddc3 100644 --- a/lib/files/home/home.go +++ b/lib/files/home/home.go @@ -6,78 +6,50 @@ import ( "io/ioutil" "net/url" "os" - "path/filepath" "github.com/puellanivis/breton/lib/files" - "github.com/puellanivis/breton/lib/os/user" ) -var userDir string - type handler struct{} func init() { - var err error - - // Short-circuit figuring out the whole User, in case we're on Windows. - userDir, err = user.CurrentHomeDir() - if err != nil { - return - } - - files.RegisterScheme(&handler{}, "home") -} - -// Filename takes a given url, and returns a filename that is an absolute path -// for the specific default user if home:filename, or a specific user if home://user@/filename. -func Filename(uri *url.URL) (string, error) { - if uri.Host != "" { - return "", os.ErrInvalid - } - - path := uri.Path - if path == "" { - path = uri.Opaque - } - - dir := userDir - - if uri.User != nil { - u, err := user.Lookup(uri.User.Username()) - if err != nil { - return "", err - } - - if u.HomeDir != "" { - dir = u.HomeDir - } - } - - return filepath.Join(dir, path), nil + files.RegisterScheme(handler{}, "home") } -func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { +func (handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { filename, err := Filename(uri) if err != nil { - return nil, files.PathError("open", uri.String(), err) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } } return os.Open(filename) } -func (h *handler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { +func (handler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { filename, err := Filename(uri) if err != nil { - return nil, files.PathError("create", uri.String(), err) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } } return os.Create(filename) } -func (h *handler) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { +func (handler) ReadDir(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { filename, err := Filename(uri) if err != nil { - return nil, files.PathError("readdir", uri.String(), err) + return nil, &os.PathError{ + Op: "readdir", + Path: uri.String(), + Err: err, + } } return ioutil.ReadDir(filename) diff --git a/lib/files/home/home_test.go b/lib/files/home/home_test.go new file mode 100644 index 0000000..5ba3619 --- /dev/null +++ b/lib/files/home/home_test.go @@ -0,0 +1,23 @@ +package home + +import ( + "testing" + + "github.com/puellanivis/breton/lib/files" +) + +func TestHandlerFulfillsReadDirFS(t *testing.T) { + var h files.FS = handler{} + + if _, ok := h.(files.ReadDirFS); !ok { + t.Fatal("handler does not implement files.ReadDirFS") + } +} + +func TestHandlerFulfillsCreateFS(t *testing.T) { + var h files.FS = handler{} + + if _, ok := h.(files.CreateFS); !ok { + t.Fatal("handler does not implement files.CreateFS") + } +} diff --git a/lib/files/httpfiles/http.go b/lib/files/httpfiles/http.go index aafce1b..59dd1c1 100644 --- a/lib/files/httpfiles/http.go +++ b/lib/files/httpfiles/http.go @@ -2,7 +2,6 @@ package httpfiles import ( - "context" "errors" "net/http" "net/url" @@ -25,17 +24,20 @@ func init() { schemeList = append(schemeList, scheme) } - files.RegisterScheme(&handler{}, schemeList...) + files.RegisterScheme(handler{}, schemeList...) } func elideDefaultPort(uri *url.URL) *url.URL { port := uri.Port() + if port == "" { + return uri + } /* elide default ports */ - if defport, ok := schemes[uri.Scheme]; ok && defport == port { - newuri := *uri - newuri.Host = uri.Hostname() - return &newuri + if defport := schemes[uri.Scheme]; defport == port { + u := *uri + u.Host = uri.Hostname() + return &u } return uri @@ -53,7 +55,3 @@ func getErr(resp *http.Response) error { return errors.New(resp.Status) } - -func (h *handler) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { - return nil, files.PathError("readdir", uri.String(), os.ErrInvalid) -} diff --git a/lib/files/httpfiles/http_test.go b/lib/files/httpfiles/http_test.go new file mode 100644 index 0000000..3630013 --- /dev/null +++ b/lib/files/httpfiles/http_test.go @@ -0,0 +1,15 @@ +package httpfiles + +import ( + "testing" + + "github.com/puellanivis/breton/lib/files" +) + +func TestHandlerFulfillsCreateFS(t *testing.T) { + var h files.FS = handler{} + + if _, ok := h.(files.CreateFS); !ok { + t.Fatal("handler does not implement files.CreateFS") + } +} diff --git a/lib/files/httpfiles/option.go b/lib/files/httpfiles/option.go index f3867bd..e99d324 100644 --- a/lib/files/httpfiles/option.go +++ b/lib/files/httpfiles/option.go @@ -9,44 +9,43 @@ import ( // WithForm returns a files.Option that that will add to the underlying HTTP // request the url.Values given as a POST request. (A GET request can always // be composed through the URL string itself. -func WithForm(vals url.Values) files.Option { - body := []byte(vals.Encode()) +func WithForm(form url.Values) files.Option { + body := []byte(form.Encode()) return WithContent("POST", "application/x-www-form-urlencoded", body) } // WithContent returns a files.Option that will set the Method, Body and // Content-Type of the underlying HTTP request to the given values. func WithContent(method, contentType string, data []byte) files.Option { - type methodSetter interface { - SetMethod(string) string - } - - type ctypeSetter interface { - SetContentType(string) string - } + data = append([]byte(nil), data...) - type bodySetter interface { - SetBody([]byte) []byte + type contentSetter interface { + SetMethod(string) (string, error) + SetContentType(string) (string, error) + SetBody([]byte) ([]byte, error) } return func(f files.File) (files.Option, error) { - var methodSave, ctypeSave string - var dataSave []byte - - if r, ok := f.(methodSetter); ok { - methodSave = r.SetMethod(method) - } - - if r, ok := f.(ctypeSetter); ok { - ctypeSave = r.SetContentType(contentType) - } - - if r, ok := f.(bodySetter); ok { - dataSave = r.SetBody(data) + if f, ok := f.(contentSetter); ok { + methodSave, err := f.SetMethod(method) + if err != nil { + return nil, err + } + + ctypeSave, err := f.SetContentType(contentType) + if err != nil { + return nil, err + } + + dataSave, err := f.SetBody(data) + if err != nil { + return nil, err + } + + return WithContent(methodSave, ctypeSave, dataSave), nil } - // option is not reversible - return WithContent(methodSave, ctypeSave, dataSave), nil + return nil, files.ErrNotSupported } } @@ -54,17 +53,20 @@ func WithContent(method, contentType string, data []byte) files.Option { // underlying HTTP request to be the given value. func WithMethod(method string) files.Option { type methodSetter interface { - SetMethod(string) string + SetMethod(string) (string, error) } return func(f files.File) (files.Option, error) { - var save string + if f, ok := f.(methodSetter); ok { + save, err := f.SetMethod(method) + if err != nil { + return nil, err + } - if r, ok := f.(methodSetter); ok { - save = r.SetMethod(method) + return WithMethod(save), nil } - return WithMethod(save), nil + return nil, files.ErrNotSupported } } @@ -74,16 +76,19 @@ func WithMethod(method string) files.Option { // during the eventual commit of the request at Sync() or Close().) func WithContentType(contentType string) files.Option { type ctypeSetter interface { - SetContentType(string) string + SetContentType(string) (string, error) } return func(f files.File) (files.Option, error) { - var save string + if f, ok := f.(ctypeSetter); ok { + save, err := f.SetContentType(contentType) + if err != nil { + return nil, err + } - if r, ok := f.(ctypeSetter); ok { - save = r.SetContentType(contentType) + return WithContentType(save), nil } - return WithContentType(save), nil + return nil, files.ErrNotSupported } } diff --git a/lib/files/httpfiles/reader.go b/lib/files/httpfiles/reader.go index fd2a7b7..2aff484 100644 --- a/lib/files/httpfiles/reader.go +++ b/lib/files/httpfiles/reader.go @@ -67,16 +67,11 @@ func (r *reader) Seek(offset int64, whence int) (int64, error) { return 0, r.err } - if r.s == nil { - switch s := r.r.(type) { - case io.Seeker: - r.s = s - default: - return 0, os.ErrInvalid - } + if s, ok := r.r.(io.Seeker); ok { + return s.Seek(offset, whence) } - return r.s.Seek(offset, whence) + return 0, files.ErrNotSupported } func (r *reader) Close() error { @@ -92,7 +87,7 @@ func (r *reader) Close() error { return nil } -func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { +func (handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { uri = elideDefaultPort(uri) cl, ok := getClient(ctx) @@ -100,8 +95,7 @@ func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) cl = http.DefaultClient } - req := newHTTPRequest(http.MethodGet, uri) - req = req.WithContext(ctx) + req := httpNewRequestWithContext(ctx, http.MethodGet, uri) if ua, ok := getUserAgent(ctx); ok { req.Header.Set("User-Agent", ua) @@ -133,10 +127,16 @@ func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) select { case loading <- struct{}{}: case <-ctx.Done(): - r.err = files.PathError("open", r.name, ctx.Err()) + r.err = &os.PathError{ + Op: "open", + Path: r.name, + Err: ctx.Err(), + } return } + r.markSent() + // So, we will not arrive here until someone is ranging over the loading channel. // // This ensures the actual http request HAPPENS AFTER the first file operation is called, @@ -147,7 +147,11 @@ func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) resp, err := cl.Do(req) if err != nil { - r.err = files.PathError("open", r.name, err) + r.err = &os.PathError{ + Op: "open", + Path: r.name, + Err: err, + } return } @@ -164,9 +168,13 @@ func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) r.info = wrapper.NewInfo(uri, int(resp.ContentLength), t) if err := getErr(resp); err != nil { - resp.Body.Close() + _ = files.Discard(resp.Body) - r.err = files.PathError("open", uri.String(), err) + r.err = &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } return } @@ -177,10 +185,14 @@ func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) b, err := files.ReadFrom(resp.Body) if err != nil { - r.err = files.PathError("read", uri.String(), err) + r.err = &os.PathError{ + Op: "read", + Path: uri.String(), + Err: err, + } return } - resp.Body.Close() + // files.ReadFrom will close the body. r.r = bytes.NewReader(b) }() diff --git a/lib/files/httpfiles/request.go b/lib/files/httpfiles/request.go index f648f66..679efa9 100644 --- a/lib/files/httpfiles/request.go +++ b/lib/files/httpfiles/request.go @@ -2,22 +2,22 @@ package httpfiles import ( "bytes" + "context" + "errors" "io" "io/ioutil" "net/http" + "net/textproto" "net/url" + "sync" ) -type request struct { - name string - - // this is what we really care about - body []byte - req *http.Request -} +// ErrRequestAlreadySent is returned, if you attempt to modify a Request field +// after the request has already been sent. +var ErrRequestAlreadySent = errors.New("request already sent") -func newHTTPRequest(method string, uri *url.URL) *http.Request { - return &http.Request{ +func httpNewRequestWithContext(ctx context.Context, method string, uri *url.URL) *http.Request { + r := &http.Request{ Method: method, URL: uri, Proto: "HTTP/1.1", @@ -26,47 +26,163 @@ func newHTTPRequest(method string, uri *url.URL) *http.Request { Header: make(http.Header), Host: uri.Host, } + + return r.WithContext(ctx) +} + +type request struct { + mu sync.RWMutex + + sent bool // true if request has already been sent. + + setContentType bool // true if Content-Type has been set to a specific value. + + name string + + // this is what we really care about + body []byte + req *http.Request +} + +func (r *request) markSent() { + r.mu.Lock() + defer r.mu.Unlock() + + r.sent = true } func (r *request) Name() string { + r.mu.RLock() + defer r.mu.RUnlock() + return r.name } -func (r *request) SetMethod(method string) string { +func (r *request) SetName(name string) { + r.mu.Lock() + defer r.mu.Unlock() + + r.name = name +} + +func (r *request) SetMethod(method string) (string, error) { + r.mu.Lock() + defer r.mu.Unlock() + save := r.req.Method + + if r.sent { + return save, ErrRequestAlreadySent + } + r.req.Method = method - return save + return save, nil } -func (r *request) SetContentType(contentType string) string { +func (r *request) SetHeader(key string, values ...string) ([]string, error) { + key = textproto.CanonicalMIMEHeaderKey(key) + + r.mu.Lock() + defer r.mu.Unlock() + + if r.req.Header == nil { + // Safety check. + r.req.Header = make(http.Header) + } + + save := r.req.Header[key] + + if r.sent { + return append([]string(nil), save...), ErrRequestAlreadySent + } + + if key == "Content-Type" { + r.setContentType = len(values) > 0 + } + + r.req.Header[key] = append([]string(nil), values...) + + return save, nil +} + +func (r *request) AddHeader(key string, values ...string) ([]string, error) { + key = textproto.CanonicalMIMEHeaderKey(key) + + r.mu.Lock() + defer r.mu.Unlock() + if r.req.Header == nil { + // Safety check. r.req.Header = make(http.Header) } - save := r.req.Header.Get("Content-Type") - r.req.Header.Set("Content-Type", contentType) + save := r.req.Header[key] + + switch { + case r.sent: + return append([]string(nil), save...), ErrRequestAlreadySent + + case len(values) < 1: + return append([]string(nil), save...), nil + } + + cur := save[:len(save):len(save)] // truncate capacity, so the append below clones save. + r.req.Header[key] = append(cur, values...) + + // Go ahead and return `save` with full allocated capacity. + return save, nil +} - return save +func (r *request) SetContentType(contentType string) (string, error) { + prev, err := r.SetHeader("Content-Type", contentType) + if len(prev) > 0 { + return prev[0], err + } + return "", err } -func (r *request) SetBody(body []byte) []byte { +func (r *request) SetBody(body []byte) ([]byte, error) { + r.mu.Lock() + defer r.mu.Unlock() + save := r.body - r.body = body - r.req.Method = http.MethodPost + if r.sent { + // Since we are not changing `r.body` at all, + // multiple calls would all return the same backing store. + // So, we clone the body we’re returning just to be safe. + return append([]byte(nil), save...), ErrRequestAlreadySent + } + + // To ensure we have an exclusive copy of the backing store, + // we clone the input body as a safety measure. + // Otherwise, a caller could mutate the backing store behind our back. + body = append([]byte(nil), body...) + r.body = body // `save` is now the exclusive reference to the previous backing store. + r.req.ContentLength = int64(len(r.body)) - r.req.GetBody = func() (io.ReadCloser, error) { - if len(r.body) < 1 { - return nil, nil + if !r.setContentType { + r.req.Header.Set("Content-Type", http.DetectContentType(body)) + } + + switch { + case len(body) < 1: + r.req.GetBody = func() (io.ReadCloser, error) { + return http.NoBody, nil } - return ioutil.NopCloser(bytes.NewReader(r.body)), nil + default: + r.req.GetBody = func() (io.ReadCloser, error) { + return ioutil.NopCloser(bytes.NewReader(body)), nil + } } - // we know this http.Request.GetBody won’t throw an error + // we know this http.Request.GetBody can’t throw an error r.req.Body, _ = r.req.GetBody() - return save + // Since save is the exclusive reference to the previous backing store, + // there is no need for a copy here. + return save, nil } diff --git a/lib/files/httpfiles/writer.go b/lib/files/httpfiles/writer.go index 70842b3..8aca88c 100644 --- a/lib/files/httpfiles/writer.go +++ b/lib/files/httpfiles/writer.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/url" + "os" "github.com/puellanivis/breton/lib/files" "github.com/puellanivis/breton/lib/files/wrapper" @@ -22,7 +23,7 @@ func (w *writer) Header() (http.Header, error) { return w.request.req.Header, nil } -func (h *handler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { +func (handler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { uri = elideDefaultPort(uri) cl, ok := getClient(ctx) @@ -30,8 +31,7 @@ func (h *handler) Create(ctx context.Context, uri *url.URL) (files.Writer, error cl = http.DefaultClient } - req := newHTTPRequest(http.MethodPost, uri) - req = req.WithContext(ctx) + req := httpNewRequestWithContext(ctx, http.MethodPost, uri) if ua, ok := getUserAgent(ctx); ok { req.Header.Set("User-Agent", ua) @@ -45,22 +45,35 @@ func (h *handler) Create(ctx context.Context, uri *url.URL) (files.Writer, error // The http.Writer does not actually perform the http.Request until wrapper.Sync is called, // So there is no need for complex synchronization like the httpfiles.Reader needs. w := wrapper.NewWriter(ctx, uri, func(b []byte) error { - if r.req.Header.Get("Content-Type") == "" { - r.req.Header.Set("Content-Type", http.DetectContentType(b)) + _, err := r.SetBody(b) + if err != nil { + return err } - _ = r.SetBody(b) + + r.mu.Lock() + defer r.mu.Unlock() + + // Unlike on the reader side, we never want to call r.markSent() + // Because, we perform a brand new request, every Sync() or Close(). + // So, we can continuously update headers, and bodies, and methods. resp, err := cl.Do(r.req) if err != nil { - return files.PathError("write", r.name, err) + return &os.PathError{ + Op: "write", + Path: r.name, + Err: err, + } } - if err := files.Discard(resp.Body); err != nil { - return err - } + _ = files.Discard(resp.Body) if err := getErr(resp); err != nil { - return files.PathError("write", r.name, err) + return &os.PathError{ + Op: "write", + Path: r.name, + Err: err, + } } return nil diff --git a/lib/files/json/json.go b/lib/files/json/json.go index 1e89f33..67e5597 100644 --- a/lib/files/json/json.go +++ b/lib/files/json/json.go @@ -15,18 +15,15 @@ func Unmarshal(data []byte, v interface{}) error { return json.Unmarshal(data, v) } -// ReadFrom will ReadAndClose the given io.ReadCloser and unmarshal that data into v as per json.Unmarshal. -func ReadFrom(r io.ReadCloser, v interface{}) error { +// ReadFrom will read the whole of io.Reader into memory. +// It will then close the reader, if it implements io.Closer. +// Finally it will unmarshal that data into v as per json.Unmarshal. +func ReadFrom(r io.Reader, v interface{}) error { data, err := files.ReadFrom(r) if err != nil { return err } - if len(data) < 1 { - v = nil - return nil - } - return json.Unmarshal(data, v) } @@ -40,29 +37,24 @@ func Read(ctx context.Context, filename string, v interface{}) error { return ReadFrom(f, v) } -// Marshal is a wrapper around encoding/json.Marshal that will optionally apply -// Indent or Compact options. +// Marshal is a wrapper around encoding/json.Marshal, +// that will optionally apply Indent, EscapeHTML, or Compact options. func Marshal(v interface{}, opts ...Option) ([]byte, error) { + b := new(bytes.Buffer) + c := &config{ - escapeHTML: true, + Encoder: json.NewEncoder(b), } for _, opt := range opts { - _ = opt(c) + opt(c) } - b := new(bytes.Buffer) - enc := json.NewEncoder(b) - if c.prefix != "" || c.indent != "" { - enc.SetIndent(c.prefix, c.indent) - } - - if !c.escapeHTML { - enc.SetEscapeHTML(c.escapeHTML) + c.SetIndent(c.prefix, c.indent) } - if err := enc.Encode(v); err != nil { + if err := c.Encode(v); err != nil { return nil, err } @@ -71,14 +63,17 @@ func Marshal(v interface{}, opts ...Option) ([]byte, error) { if err := json.Compact(buf, b.Bytes()); err != nil { return nil, err } - b = buf + + return buf.Bytes(), nil } return b.Bytes(), nil } -// WriteTo writes a value marshalled as JSON to the the given io.WriteCloser. -func WriteTo(w io.WriteCloser, v interface{}, opts ...Option) error { +// WriteTo marshals v as per json.Marshal, +// it then writes that data to the the given io.Writer. +// Finally, it will close it, if it implements io.Closer. +func WriteTo(w io.Writer, v interface{}, opts ...Option) error { b, err := Marshal(v, opts...) if err != nil { return err @@ -87,7 +82,7 @@ func WriteTo(w io.WriteCloser, v interface{}, opts ...Option) error { return files.WriteTo(w, b) } -// Write writes a marshaled JSON to a filename with the given Context. +// Write writes a marshaled JSON output to the given filename. func Write(ctx context.Context, filename string, v interface{}, opts ...Option) error { f, err := files.Create(ctx, filename) if err != nil { diff --git a/lib/files/json/option.go b/lib/files/json/option.go index f5fc4d7..6e54047 100644 --- a/lib/files/json/option.go +++ b/lib/files/json/option.go @@ -1,55 +1,44 @@ package json +import ( + "encoding/json" +) + type config struct { + *json.Encoder + prefix, indent string - escapeHTML, compact bool + compact bool } -// An Option is a function that apply a specific option, then returns an Option function -// that will revert the change applied. -type Option func(*config) Option +// An Option is a function that applies a specific option to an encoder config. +type Option func(*config) // WithPrefix returns a function that directs Marshal to use the prefix string given. func WithPrefix(prefix string) Option { - return func(c *config) Option { - save := c.prefix - + return func(c *config) { c.prefix = prefix - - return WithPrefix(save) } } // WithIndent returns a function that directs Marshal to use the indenting string given. func WithIndent(indent string) Option { - return func(c *config) Option { - save := c.indent - + return func(c *config) { c.indent = indent - - return WithIndent(save) } } // EscapeHTML returns a function that directs Marshal to either enable or disable HTML escaping. -func EscapeHTML(value bool) Option { - return func(c *config) Option { - save := c.escapeHTML - - c.escapeHTML = value - - return EscapeHTML(save) +func EscapeHTML(on bool) Option { + return func(c *config) { + c.SetEscapeHTML(on) } } // Compact returns a function that directs Marshal to use compact format. func Compact(value bool) Option { - return func(c *config) Option { - save := c.compact - + return func(c *config) { c.compact = value - - return Compact(save) } } diff --git a/lib/files/local.go b/lib/files/local.go index d3c671c..8c8659b 100644 --- a/lib/files/local.go +++ b/lib/files/local.go @@ -9,33 +9,51 @@ import ( type localFS struct{} -// Local implements a wrapper from the os functions Open, Create, and Readdir, to the files.FileStore implementation. -var Local FileStore = &localFS{} +// Local implements a wrapper from os.Open, os.Create, and os.Readdir, to the files.FS implementation. +var Local FS = localFS{} func init() { RegisterScheme(Local, "file") } -func filename(uri *url.URL) string { - fname := uri.Path - if fname == "" { - fname = uri.Opaque +// Open opens up a local filesystem file specified in the uri.Path for reading. +func (localFS) Open(ctx context.Context, uri *url.URL) (Reader, error) { + name, err := resolveFileURL(uri) + if err != nil { + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } } - return fname -} - -// Open opens up a local filesystem file specified in the uri.Path for reading. -func (h *localFS) Open(ctx context.Context, uri *url.URL) (Reader, error) { - return os.Open(filename(uri)) + return os.Open(name) } // Create opens up a local filesystem file specified in the uri.Path for writing. It will create a new one if it does not exist. -func (h *localFS) Create(ctx context.Context, uri *url.URL) (Writer, error) { - return os.Create(filename(uri)) +func (localFS) Create(ctx context.Context, uri *url.URL) (Writer, error) { + name, err := resolveFileURL(uri) + if err != nil { + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } + } + + return os.Create(name) } // List returns the whole slice of os.FileInfos for a specific local filesystem at uri.Path. -func (h *localFS) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { - return ioutil.ReadDir(filename(uri)) +func (localFS) ReadDir(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { + name, err := resolveFileURL(uri) + if err != nil { + return nil, &os.PathError{ + Op: "readdir", + Path: uri.String(), + Err: err, + } + } + + return ioutil.ReadDir(name) } diff --git a/lib/files/open.go b/lib/files/open.go index 94a3959..64303fd 100644 --- a/lib/files/open.go +++ b/lib/files/open.go @@ -2,15 +2,13 @@ package files import ( "context" - "io/ioutil" - "net/url" "os" - "path/filepath" ) -// Open takes a Context and a filename (which may be a URL) and returns a -// files.Reader which will read the contents of that filename or URL. All -// errors and reversion functions returned by Option arguments are discarded. +// Open takes a context and a filename (which may be a URL) and +// returns a `files.Reader`, which will read the contents of that filename or URL. +// +// All errors and reversion functions returned by Option arguments are discarded. func Open(ctx context.Context, filename string, options ...Option) (Reader, error) { f, err := open(ctx, filename) if err != nil { @@ -30,40 +28,19 @@ func open(ctx context.Context, filename string) (Reader, error) { return os.Stdin, nil } - if filepath.IsAbs(filename) { - return os.Open(filename) + uri := parsePath(ctx, filename) + if isPath(uri) { + return os.Open(uri.Path) } - if uri, err := url.Parse(filename); err == nil { - uri = resolveFilename(ctx, uri) - - if fs, ok := getFS(uri); ok { - return fs.Open(ctx, uri) - } - } - - return os.Open(filename) -} - -// List takes a Context and a filename (which may be a URL) and returns a list -// of os.FileInfo that describes the files contained in the directory or listing. -func List(ctx context.Context, filename string) ([]os.FileInfo, error) { - switch filename { - case "", "-", "/dev/stdin": - return os.Stdin.Readdir(0) - } - - if filepath.IsAbs(filename) { - return ioutil.ReadDir(filename) - } - - if uri, err := url.Parse(filename); err == nil { - uri = resolveFilename(ctx, uri) - - if fs, ok := getFS(uri); ok { - return fs.List(ctx, uri) + fsys, ok := getFS(uri) + if !ok { + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: ErrNotSupported, } } - return ioutil.ReadDir(filename) + return fsys.Open(ctx, uri) } diff --git a/lib/files/option.go b/lib/files/option.go index 086bcf3..9f6e984 100644 --- a/lib/files/option.go +++ b/lib/files/option.go @@ -1,15 +1,10 @@ package files import ( - "errors" "os" "time" ) -// ErrNotSupported should be returned when a specific file.File given to an -// Option does not support the Option specified. -var ErrNotSupported = errors.New("option not supported") - // Option is a function that applies a specific option to a files.File, it // returns an Option and and error. If error is not nil, then the Option // returned will revert the option that was set. Since errors returned by diff --git a/lib/files/path.go b/lib/files/path.go index 6c4eed4..470a5ba 100644 --- a/lib/files/path.go +++ b/lib/files/path.go @@ -7,64 +7,80 @@ import ( ) func isPath(uri *url.URL) bool { - if uri.IsAbs() { + switch { + case uri.IsAbs(): return false - } - if uri.User != nil { + case uri.User != nil: return false - } - if len(uri.Host)+len(uri.RawQuery)+len(uri.Fragment) > 0 { + case len(uri.Host)+len(uri.RawQuery)+len(uri.Fragment) > 0: return false - } - if uri.ForceQuery { + case uri.ForceQuery: return false } return true } -func getPath(uri *url.URL) string { - if uri.RawPath != "" { - return uri.RawPath +func resolveURL(ctx context.Context, uri *url.URL) *url.URL { + if uri.IsAbs() { + // short-circuit: If the url is absolute, + // then we should never consider resolving it as a reference relative to root. + return uri } - return uri.Path -} + if root, ok := getRoot(ctx); ok { + switch { + case !isPath(root): + // If root is not a path-only URL, + // then always resolve the uri as a reference. + return root.ResolveReference(uri) + + case isPath(uri): + // special-case: If both root and uri are wrapped simple paths, + // then join their Paths through filepath.Join, + // instead of using URL path handling. + return &url.URL{ + Path: filepath.Join(root.Path, uri.Path), + } + } -func makePath(path string) *url.URL { - return &url.URL{ - Path: path, - RawPath: path, + // root is a wrapped simple path, but uri is not, + // there’s no good way to really join these two together. + // fallthrough to not even considering root. } -} -func resolveFilename(ctx context.Context, uri *url.URL) *url.URL { - if uri.IsAbs() { - return uri + if isPath(uri) { + uri.Path = filepath.Clean(uri.Path) } - var path string - - if isPath(uri) { - path = getPath(uri) + return uri +} - if filepath.IsAbs(path) { - return makePath(path) +// parsePath will always return a non-nil `*url.URL`. +// +// If the path is an invalid URL, then we will return a wrapped simple path, +// which is simply a &url.URL{ Path: path }. +func parsePath(ctx context.Context, path string) *url.URL { + if filepath.IsAbs(path) { + // If for this architecture, path is an an absolute path + // then we should only ever treat it as a wrapped simple path. + return &url.URL{ + Path: filepath.Clean(path), } } - root, ok := getRoot(ctx) - if !ok { - return uri - } - - if path != "" && isPath(root) { - return makePath(filepath.Join(getPath(root), path)) - + uri, err := url.Parse(path) + if err != nil { + // If this path fails to parse as a URL, treat it like a wrapped simple path. + uri = &url.URL{ + Path: path, + } } - return root.ResolveReference(uri) + // Since we do not `filepath.Clean` the wrapped simple path from above, + // this function must assure that the wrapped simple path is cleaned before returning the url. + return resolveURL(ctx, uri) } diff --git a/lib/files/path_test.go b/lib/files/path_test.go index 6033046..21a7c0e 100644 --- a/lib/files/path_test.go +++ b/lib/files/path_test.go @@ -3,29 +3,48 @@ package files import ( "context" "net/url" + "reflect" "runtime" "testing" ) +func TestInvalidURLAsSimplePath(t *testing.T) { + path := parsePath(context.Background(), ":/foo") + expect := &url.URL{ + Path: ":/foo", + } + + if !reflect.DeepEqual(path, expect) { + t.Errorf("parsePath returned %#v, expected %#v", path, expect) + } +} + func TestPathWindows(t *testing.T) { if runtime.GOOS != "windows" { return } - p := makePath("C:\\") - if !isPath(p) { - t.Fatalf("makePath returned something not an isPath, got %#v", p) + ctx := context.Background() + + root := parsePath(ctx, "C:\\") + expect := &url.URL{ + Path: "C:\\", } - if path := getPath(p); path != "C:\\" { - t.Errorf("getPath(makePath) not inverting, got %s", path) + if !reflect.DeepEqual(root, expect) { + t.Fatalf("parsePath returned %#v, expected: %#v", root, expect) } - ctx := WithRootURL(context.Background(), p) + ctx = WithRootURL(ctx, root) + + filename := parsePath(ctx, "filename") + + expect = &url.URL{ + Path: "C:\\filename", + } - filename := makePath("filename") - if path := resolveFilename(ctx, filename); getPath(path) != "C:\\filename" { - t.Errorf("resolveFilename with %q and %q gave %#v instead", filename, p, path) + if !reflect.DeepEqual(filename, expect) { + t.Errorf("resolveFilename returned %#v, expected: %#v", filename, expect) } } @@ -34,52 +53,48 @@ func TestPathPOSIX(t *testing.T) { return } - p := makePath("/asdf") - if !isPath(p) { - t.Fatalf("makePath returned something not an isPath, got %#v", p) + ctx := context.Background() + + root := parsePath(ctx, "/tmp") + expect := &url.URL{ + Path: "/tmp", } - if path := getPath(p); path != "/asdf" { - t.Errorf("getPath(makePath) not inverting, got %s", path) + if !reflect.DeepEqual(root, expect) { + t.Fatalf("parsePath returned %#v, expected: %#v", root, expect) } - ctx := WithRootURL(context.Background(), p) + ctx = WithRootURL(ctx, root) - filename := makePath("filename") - if path := resolveFilename(ctx, filename); getPath(path) != "/asdf/filename" { - t.Errorf("resolveFilename with %q and %q gave %#v instead", filename, p, path) - } -} + filename := parsePath(ctx, "filename") -func TestPathURL(t *testing.T) { - p, err := url.Parse("scheme://username:password@hostname:12345/path/?query#fragment") - if err != nil { - t.Fatal(err) + expect = &url.URL{ + Path: "/tmp/filename", } - if isPath(p) { - t.Fatalf("url.Parse with scheme returned something that is an isPath, got %#v", p) + if !reflect.DeepEqual(filename, expect) { + t.Errorf("resolveFilename returned %#v, expected: %#v", filename, expect) } +} - ctx := WithRootURL(context.Background(), p) +func TestPathURL(t *testing.T) { + ctx := context.Background() - filename := makePath("filename") - if path := resolveFilename(ctx, filename); path.String() != "scheme://username:password@hostname:12345/path/filename" { - t.Errorf("resolveFilename with %q and %q gave %#v instead", filename, p, path) - } + path := "scheme://username:password@hostname:12345/path/?query#fragment" - p, err = url.Parse("file:///c:/Windows/") - if err != nil { - t.Fatal(err) - } + root := parsePath(ctx, path) + expect := path - if isPath(p) { - t.Fatalf("url.Parse with scheme returned something that is an isPath, got %#v", p) + if got := root.String(); got != expect { + t.Fatalf("parsePath returned %q, expected: %q", root, expect) } - ctx = WithRootURL(context.Background(), p) + ctx = WithRootURL(ctx, root) + + filename := parsePath(ctx, "filename?newquery#newfragment") + expect = "scheme://username:password@hostname:12345/path/filename?newquery#newfragment" - if path := resolveFilename(ctx, filename); path.String() != "file:///c:/Windows/filename" { - t.Errorf("resolveFilename with %q and %q gave %#v instead", filename, p, path) + if got := filename.String(); got != expect { + t.Errorf("resolveFilename returned %q, expected: %q", filename, expect) } } diff --git a/lib/files/read.go b/lib/files/read.go index f606926..7c211b6 100644 --- a/lib/files/read.go +++ b/lib/files/read.go @@ -6,22 +6,31 @@ import ( "io/ioutil" ) -// ReadFrom reads the entire content of an io.ReadCloser and returns the content as a byte slice. It will also Close the reader. -func ReadFrom(r io.ReadCloser) ([]byte, error) { +// ReadFrom reads the entire content of a io.Reader and returns the content as a byte slice. +// If the reader also implements io.Closer, it will also Close it. +func ReadFrom(r io.Reader) ([]byte, error) { b, err := ioutil.ReadAll(r) - if err1 := r.Close(); err == nil { - err = err1 + + if c, ok := r.(io.Closer); ok { + if err2 := c.Close(); err == nil { + err = err2 + } } + return b, err } -// Discard throws away the entire content of an io.ReadCloser and closes the reader. +// Discard throws away the entire content of an io.Reader. +// If the reader also implements io.Closer, it will also Close it. +// // This is specifically not context aware, it is intended to always run to completion. -func Discard(r io.ReadCloser) error { +func Discard(r io.Reader) error { _, err := io.Copy(ioutil.Discard, r) - if err2 := r.Close(); err == nil { - err = err2 + if c, ok := r.(io.Closer); ok { + if err2 := c.Close(); err == nil { + err = err2 + } } return err diff --git a/lib/files/readdir.go b/lib/files/readdir.go new file mode 100644 index 0000000..d462dc1 --- /dev/null +++ b/lib/files/readdir.go @@ -0,0 +1,90 @@ +package files + +import ( + "context" + "io/ioutil" + "net/url" + "os" +) + +// ReadDirFS defines an extension interface on FS, which also provides an ability to enumerate files given a prefix. +type ReadDirFS interface { + FS + ReadDir(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) +} + +// ReadDir takes a context and a filename (which may be a URL) and +// returns a slice of `os.FileInfo` that describes the files contained in the directory or listing. +func ReadDir(ctx context.Context, filename string) ([]os.FileInfo, error) { + switch filename { + case "", "-", "/dev/stdin": + return os.Stdin.Readdir(0) + } + + uri := parsePath(ctx, filename) + if isPath(uri) { + return ioutil.ReadDir(filename) + } + + fsys, ok := getFS(uri) + if !ok { + return nil, &os.PathError{ + Op: "readdir", + Path: uri.String(), + Err: ErrNotSupported, + } + } + + switch fsys := fsys.(type) { + case ReadDirFS: + return fsys.ReadDir(ctx, uri) + + case FileStore: + return fsys.List(ctx, uri) + } + + return nil, &os.PathError{ + Op: "readdir", + Path: filename, + Err: ErrNotSupported, + } +} + +// List takes a context and a filename (which may be a URL) and +// returns a slice of `os.FileInfo` that describes the files contained in the directory or listing. +// +// DEPRECATED: use `ReadDir`. +func List(ctx context.Context, filename string) ([]os.FileInfo, error) { + switch filename { + case "", "-", "/dev/stdin": + return os.Stdin.Readdir(0) + } + + uri := parsePath(ctx, filename) + if isPath(uri) { + return ioutil.ReadDir(filename) + } + + fsys, ok := getFS(uri) + if !ok { + return nil, &os.PathError{ + Op: "readdir", + Path: uri.String(), + Err: ErrNotSupported, + } + } + + switch fsys := fsys.(type) { + case ReadDirFS: + return fsys.ReadDir(ctx, uri) + + case FileStore: + return fsys.List(ctx, uri) + } + + return nil, &os.PathError{ + Op: "readdir", + Path: filename, + Err: ErrNotSupported, + } +} diff --git a/lib/files/resolve_posix.go b/lib/files/resolve_posix.go new file mode 100644 index 0000000..f2d105d --- /dev/null +++ b/lib/files/resolve_posix.go @@ -0,0 +1,51 @@ +// +build dragonflybsd freebsd linux netbsd openbsd solaris darwin + +package files + +import ( + "net/url" + "path" + "strconv" +) + +func resolveFileHandle(num string) (uintptr, error) { + fd, err := strconv.ParseUint(num, 0, strconv.IntSize) + if err != nil { + return uintptr(^fd), ErrURLInvalid + } + + return uintptr(fd), nil +} + +func resolveFileURL(uri *url.URL) (string, error) { + if uri.User != nil { + return "", ErrURLInvalid + } + + switch uri.Host { + case "", "localhost": + default: + return "", ErrURLInvalid + } + + if name := uri.Opaque; name != "" { + if !path.IsAbs(name) { + // a path in Opaque must start with `/` and not with `%2f`. + return "", ErrURLInvalid + } + + name, err := url.PathUnescape(name) + if err != nil { + return "", ErrURLInvalid + } + + return path.Clean(name), nil + } + + name := uri.Path + if !path.IsAbs(name) { + return "", ErrURLInvalid + } + + return path.Clean(name), nil +} diff --git a/lib/files/resolve_windows.go b/lib/files/resolve_windows.go new file mode 100644 index 0000000..9de80af --- /dev/null +++ b/lib/files/resolve_windows.go @@ -0,0 +1,64 @@ +package files + +import ( + "net/url" + "path" + "path/filepath" + "strconv" + "strings" +) + +func resolveFileHandle(num string) (uintptr, error) { + fd, err := strconv.ParseInt(num, 0, 32) + if err != nil { + return uintptr(^fd), ErrURLInvalid + } + + return uintptr(fd), nil +} + +func resolveFileURL(uri *url.URL) (string, error) { + if uri.User != nil { + return "", ErrURLInvalid + } + + if name := uri.Opaque; name != "" { + if !path.IsAbs(name) { + // a path in Opaque must start with `/` and not with `%2f`. + return "", ErrURLInvalid + } + + name = strings.TrimPrefix(name, "/") + + name, err := url.PathUnescape(name) + if err != nil { + return "", ErrURLInvalid + } + + if !filepath.IsAbs(name) { + return "", ErrURLInvalid + } + + return filepath.Clean(filepath.FromSlash(name)), nil + } + + name := uri.Path + if !path.IsAbs(name) { + return "", ErrURLInvalid + } + + switch uri.Host { + case "", ".": + name = strings.TrimPrefix(name, "/") + + if !filepath.IsAbs(name) { + return "", ErrURLInvalid + } + + return filepath.Clean(filepath.FromSlash(name)), nil + } + + name = filepath.Clean(filepath.FromSlash(name)) + + return `\\` + uri.Host + name, nil +} diff --git a/lib/files/s3files/reader.go b/lib/files/s3files/reader.go index cc84dd9..f65d23d 100644 --- a/lib/files/s3files/reader.go +++ b/lib/files/s3files/reader.go @@ -3,6 +3,7 @@ package s3files import ( "context" "net/url" + "os" "time" "github.com/puellanivis/breton/lib/files" @@ -13,14 +14,22 @@ import ( ) func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { - bucket, key, err := getBucketKey("open", uri) + bucket, key, err := getBucketKey(uri) if err != nil { - return nil, err + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } } cl, err := h.getClient(ctx, bucket) if err != nil { - return nil, files.PathError("open", uri.String(), err) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } } req := &s3.GetObjectInput{ @@ -30,12 +39,16 @@ func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) res, err := cl.GetObjectWithContext(ctx, req) if err != nil { - return nil, files.PathError("read", uri.String(), err) + return nil, &os.PathError{ + Op: "get_object", + Path: uri.String(), + Err: err, + } } - var l int64 + var sz int64 if res.ContentLength != nil { - l = *res.ContentLength + sz = *res.ContentLength } lm := time.Now() @@ -43,5 +56,5 @@ func (h *handler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) lm = *res.LastModified } - return wrapper.NewReaderWithInfo(res.Body, wrapper.NewInfo(uri, int(l), lm)), nil + return wrapper.NewReaderWithInfo(res.Body, wrapper.NewInfo(uri, int(sz), lm)), nil } diff --git a/lib/files/s3files/s3.go b/lib/files/s3files/s3.go index 42993cb..188492e 100644 --- a/lib/files/s3files/s3.go +++ b/lib/files/s3files/s3.go @@ -72,6 +72,7 @@ func (h *handler) lookup(region string) (*region, error) { if err != nil { return nil, err } + h.rmap[region] = r return r, nil @@ -104,35 +105,51 @@ func (h *handler) getClient(ctx context.Context, bucket string) (*s3.S3, error) return r.cl, nil } -func getBucketKey(op string, uri *url.URL) (bucket, key string, err error) { - if uri.Host == "" || uri.Path == "" { - return "", "", files.PathError(op, uri.String(), os.ErrInvalid) +func getBucketKey(uri *url.URL) (bucket, key string, err error) { + if uri.Host == "" { + return "", "", files.ErrURLHostRequired + } + + if uri.Path == "" { + return "", "", files.ErrURLPathRequired } return uri.Host, uri.Path, nil } -func (h *handler) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { +func (h *handler) ReadDir(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { if uri.Host == "" { - return nil, files.PathError("list", uri.String(), os.ErrInvalid) + return nil, &os.PathError{ + Op: "readdir", + Path: uri.String(), + Err: files.ErrURLHostRequired, + } } bucket, key := uri.Host, strings.TrimPrefix(uri.Path, "/") cl, err := h.getClient(ctx, bucket) if err != nil { - return nil, files.PathError("list", uri.String(), err) + return nil, &os.PathError{ + Op: "readdir", + Path: uri.String(), + Err: err, + } } req := &s3.ListObjectsInput{ - Bucket: aws.String(bucket), + Bucket: &bucket, Delimiter: aws.String("/"), - Prefix: aws.String(key), + Prefix: &key, } res, err := cl.ListObjectsWithContext(ctx, req) if err != nil { - return nil, files.PathError("list", uri.String(), err) + return nil, &os.PathError{ + Op: "list_objects", + Path: uri.String(), + Err: err, + } } var fi []os.FileInfo diff --git a/lib/files/s3files/s3_test.go b/lib/files/s3files/s3_test.go new file mode 100644 index 0000000..0d6e330 --- /dev/null +++ b/lib/files/s3files/s3_test.go @@ -0,0 +1,19 @@ +package s3files + +import ( + "testing" + + "github.com/puellanivis/breton/lib/files" +) + +func TestHandlerFulfillsReadDirFS(t *testing.T) { + var h files.FS = &handler{} + + if _, ok := h.(files.ReadDirFS); !ok { + t.Fatal("handler does not implement files.ReadDirFS") + } + + if _, ok := h.(files.CreateFS); !ok { + t.Fatal("handler does not implement files.CreateFS") + } +} diff --git a/lib/files/s3files/writer.go b/lib/files/s3files/writer.go index 17ff518..9991558 100644 --- a/lib/files/s3files/writer.go +++ b/lib/files/s3files/writer.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "net/url" + "os" "github.com/puellanivis/breton/lib/files" "github.com/puellanivis/breton/lib/files/wrapper" @@ -13,15 +14,25 @@ import ( ) func (h *handler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { - bucket, key, err := getBucketKey("create", uri) + bucket, key, err := getBucketKey(uri) if err != nil { - return nil, err + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } } - w := wrapper.NewWriter(ctx, uri, func(b []byte) error { + path := uri.String() + + return wrapper.NewWriter(ctx, uri, func(b []byte) error { cl, err := h.getClient(ctx, bucket) if err != nil { - return files.PathError("sync", uri.String(), err) + return &os.PathError{ + Op: "write", + Path: path, + Err: err, + } } req := &s3.PutObjectInput{ @@ -32,11 +43,13 @@ func (h *handler) Create(ctx context.Context, uri *url.URL) (files.Writer, error _, err = cl.PutObjectWithContext(ctx, req) if err != nil { - return files.PathError("sync", uri.String(), err) + return &os.PathError{ + Op: "put_object", + Path: path, + Err: err, + } } return nil - }) - - return w, nil + }), nil } diff --git a/lib/files/sftpfiles/agent.go b/lib/files/sftpfiles/agent.go index 8b75808..15b8d9a 100644 --- a/lib/files/sftpfiles/agent.go +++ b/lib/files/sftpfiles/agent.go @@ -27,7 +27,7 @@ func (a *Agent) Close() error { func GetAgent() (*Agent, error) { sock := os.Getenv("SSH_AUTH_SOCK") if sock == "" { - // No agent setup, so return no agent and not error. + // No agent setup, so return no agent and no error. return nil, nil } diff --git a/lib/files/sftpfiles/host.go b/lib/files/sftpfiles/host.go index c3b42ff..8524bbc 100644 --- a/lib/files/sftpfiles/host.go +++ b/lib/files/sftpfiles/host.go @@ -3,10 +3,9 @@ package sftpfiles import ( "errors" "net/url" + "os/user" "sync" - "github.com/puellanivis/breton/lib/os/user" - "github.com/pkg/sftp" "golang.org/x/crypto/ssh" ) @@ -18,7 +17,8 @@ type Host struct { conn *ssh.Client cl *sftp.Client - uri *url.URL + uri *url.URL + name string auths []ssh.AuthMethod @@ -34,12 +34,12 @@ var ( func getUser() *url.Userinfo { userInit.Do(func() { - name, err := user.CurrentUsername() + u, err := user.Current() if err != nil { return } - defaultUser = url.User(name) + defaultUser = url.User(u.Username) }) return defaultUser @@ -71,13 +71,22 @@ func NewHost(uri *url.URL) *Host { return &Host{ uri: uri, + name: uri.String(), auths: auths, } } // Name returns an identifying name of the Host composed of the authority section of the URL: //user[:pass]@hostname:port func (h *Host) Name() string { - return h.uri.String() + return h.name +} + +func (h *Host) getPath(uri *url.URL) *url.URL { + u := *uri + u.Host = h.uri.Host + u.User = h.uri.User + + return &u } func (h *Host) close() error { @@ -168,21 +177,34 @@ func (h *Host) Connect() (*sftp.Client, error) { } func (h *Host) cloneAuths() []ssh.AuthMethod { - return append([]ssh.AuthMethod{}, h.auths...) + return append([]ssh.AuthMethod(nil), h.auths...) } // addAuths is an internal convenience func to add any number of auths. func (h *Host) addAuths(auths ...ssh.AuthMethod) []ssh.AuthMethod { - return h.SetAuths(append(h.cloneAuths(), auths...)) + h.mu.Lock() + defer h.mu.Unlock() + + clone := h.cloneAuths() + + if len(auths) < 1 { + return clone + } + + return h.setAuths(append(clone, auths...)) } // AddAuth adds the given ssh.AuthMethod to the authorization methods for the Host, and return the previous value. func (h *Host) AddAuth(auth ssh.AuthMethod) []ssh.AuthMethod { - return h.addAuths(auth) + h.mu.Lock() + defer h.mu.Unlock() + + return h.setAuths(append(h.cloneAuths(), auth)) } -// SetAuths sets the slice of ssh.AuthMethod on the Host, and returns the previous value. -func (h *Host) SetAuths(auths []ssh.AuthMethod) []ssh.AuthMethod { +// setAuths sets the slice of ssh.AuthMethod on the Host, and returns the previous value. +// It must be called under lock. +func (h *Host) setAuths(auths []ssh.AuthMethod) []ssh.AuthMethod { save := h.auths h.auths = auths @@ -190,9 +212,20 @@ func (h *Host) SetAuths(auths []ssh.AuthMethod) []ssh.AuthMethod { return save } +// SetAuths sets the slice of ssh.AuthMethod on the Host, and returns the previous value. +func (h *Host) SetAuths(auths []ssh.AuthMethod) []ssh.AuthMethod { + h.mu.Lock() + defer h.mu.Unlock() + + return h.setAuths(append([]ssh.AuthMethod(nil), auths...)) +} + // IgnoreHostKeys sets a flag that Host should ignore Host keys when connecting. // THIS IS INSECURE. func (h *Host) IgnoreHostKeys(state bool) bool { + h.mu.Lock() + defer h.mu.Unlock() + save := h.ignoreHostkey h.ignoreHostkey = state @@ -202,6 +235,9 @@ func (h *Host) IgnoreHostKeys(state bool) bool { // SetHostKeyCallback sets the current hostkey callback for the Host, and returns the previous value. func (h *Host) SetHostKeyCallback(cb ssh.HostKeyCallback, algos []string) (ssh.HostKeyCallback, []string) { + h.mu.Lock() + defer h.mu.Unlock() + saveHK, saveAlgos := h.hostkey, h.hostkeyAlgos h.hostkey = cb diff --git a/lib/files/sftpfiles/option.go b/lib/files/sftpfiles/option.go index 1932dfc..ac30d69 100644 --- a/lib/files/sftpfiles/option.go +++ b/lib/files/sftpfiles/option.go @@ -6,12 +6,6 @@ import ( "golang.org/x/crypto/ssh" ) -func noopOption() files.Option { - return func(_ files.File) (files.Option, error) { - return noopOption(), nil - } -} - func withAuths(auths []ssh.AuthMethod) files.Option { type authSetter interface { SetAuths([]ssh.AuthMethod) []ssh.AuthMethod @@ -20,10 +14,11 @@ func withAuths(auths []ssh.AuthMethod) files.Option { return func(f files.File) (files.Option, error) { h, ok := f.(authSetter) if !ok { - return noopOption(), nil + return nil, files.ErrNotSupported } save := h.SetAuths(auths) + return withAuths(save), nil } } @@ -37,10 +32,11 @@ func WithAuth(auth ssh.AuthMethod) files.Option { return func(f files.File) (files.Option, error) { h, ok := f.(authAdder) if !ok { - return noopOption(), nil + return nil, files.ErrNotSupported } save := h.AddAuth(auth) + return withAuths(save), nil } } @@ -56,10 +52,11 @@ func IgnoreHostKeys(state bool) files.Option { return func(f files.File) (files.Option, error) { h, ok := f.(hostkeyIgnorer) if !ok { - return noopOption(), nil + return nil, files.ErrNotSupported } save := h.IgnoreHostKeys(state) + return IgnoreHostKeys(save), nil } } @@ -72,10 +69,11 @@ func withHostKeyCallback(cb ssh.HostKeyCallback, algos []string) files.Option { return func(f files.File) (files.Option, error) { h, ok := f.(hostkeySetter) if !ok { - return noopOption(), nil + return nil, files.ErrNotSupported } saveHK, saveAlgos := h.SetHostKeyCallback(cb, algos) + return withHostKeyCallback(saveHK, saveAlgos), nil } } diff --git a/lib/files/sftpfiles/reader.go b/lib/files/sftpfiles/reader.go index ec89fda..53b26d9 100644 --- a/lib/files/sftpfiles/reader.go +++ b/lib/files/sftpfiles/reader.go @@ -70,12 +70,16 @@ func (r *reader) Close() error { } func (fs *filesystem) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { - h := fs.getHost(uri) + h, u := fs.getHost(uri) if cl := h.GetClient(); cl != nil { - f, err := cl.Open(uri.Path) + f, err := cl.Open(u.Path) if err != nil { - return nil, files.PathError("open", uri.String(), err) + return nil, &os.PathError{ + Op: "open", + Path: u.String(), + Err: err, + } } return f, nil @@ -83,12 +87,8 @@ func (fs *filesystem) Open(ctx context.Context, uri *url.URL) (files.Reader, err loading := make(chan struct{}) - fixURL := *uri - fixURL.Host = h.uri.Host - fixURL.User = h.uri.User - r := &reader{ - uri: &fixURL, + uri: u, Host: h, loading: loading, @@ -100,19 +100,31 @@ func (fs *filesystem) Open(ctx context.Context, uri *url.URL) (files.Reader, err select { case loading <- struct{}{}: case <-ctx.Done(): - r.err = files.PathError("connect", h.Name(), ctx.Err()) + r.err = &os.PathError{ + Op: "connect", + Path: h.Name(), + Err: ctx.Err(), + } return } cl, err := h.Connect() if err != nil { - r.err = files.PathError("connect", h.Name(), err) + r.err = &os.PathError{ + Op: "connect", + Path: h.Name(), + Err: err, + } return } - f, err := cl.Open(uri.Path) + f, err := cl.Open(u.Path) if err != nil { - r.err = files.PathError("open", r.Name(), err) + r.err = &os.PathError{ + Op: "open", + Path: u.String(), + Err: err, + } return } diff --git a/lib/files/sftpfiles/sftpfiles.go b/lib/files/sftpfiles/sftpfiles.go index fdbfe3e..2c440c3 100644 --- a/lib/files/sftpfiles/sftpfiles.go +++ b/lib/files/sftpfiles/sftpfiles.go @@ -4,11 +4,11 @@ import ( "context" "net/url" "os" + "os/user" "path/filepath" "sync" "github.com/puellanivis/breton/lib/files" - "github.com/puellanivis/breton/lib/os/user" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/knownhosts" @@ -31,60 +31,78 @@ func (fs *filesystem) lazyInit() { fs.auths = append(fs.auths, ssh.PublicKeysCallback(agent.Signers)) } - if home, err := user.CurrentHomeDir(); err == nil { - filename := filepath.Join(home, ".ssh", "known_hosts") + var home string - if cb, err := knownhosts.New(filename); err == nil { - fs.knownhosts = cb + if u, err := user.Current(); err == nil { + home = u.HomeDir + } + + if home == "" { + if h, err := os.UserHomeDir(); err == nil { + home = h } } -} -func init() { - fs := &filesystem{ - hosts: make(map[string]*Host), + if home == "" { + // couldn’t find a home directory, just give up trying to import known_hosts + return } - files.RegisterScheme(fs, "sftp", "scp") + filename := filepath.Join(home, ".ssh", "known_hosts") + + if cb, err := knownhosts.New(filename); err == nil { + fs.knownhosts = cb + } +} + +func init() { + files.RegisterScheme(&filesystem{}, "sftp", "scp") } -func (fs *filesystem) getHost(uri *url.URL) *Host { +func (fs *filesystem) getHost(uri *url.URL) (*Host, *url.URL) { fs.once.Do(fs.lazyInit) h := NewHost(uri) + key := h.Name() fs.mu.Lock() defer fs.mu.Unlock() - key := h.Name() - if h := fs.hosts[key]; h != nil { - return h + return h, h.getPath(uri) } _ = h.addAuths(fs.auths...) _, _ = h.SetHostKeyCallback(fs.knownhosts, nil) + if fs.hosts == nil { + fs.hosts = make(map[string]*Host) + } + fs.hosts[key] = h - return h + return h, h.getPath(uri) } -func (fs *filesystem) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { - h := fs.getHost(uri) +func (fs *filesystem) ReadDir(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { + h, u := fs.getHost(uri) cl, err := h.Connect() if err != nil { - return nil, files.PathError("connect", h.Name(), err) + return nil, &os.PathError{ + Op: "connect", + Path: h.Name(), + Err: err, + } } fi, err := cl.ReadDir(uri.Path) if err != nil { - fixURL := *uri - fixURL.Host = h.uri.Host - fixURL.User = h.uri.User - - return nil, files.PathError("readdir", fixURL.String(), err) + return nil, &os.PathError{ + Op: "readdir", + Path: u.String(), + Err: err, + } } return fi, nil diff --git a/lib/files/sftpfiles/sftpfiles_test.go b/lib/files/sftpfiles/sftpfiles_test.go new file mode 100644 index 0000000..533b765 --- /dev/null +++ b/lib/files/sftpfiles/sftpfiles_test.go @@ -0,0 +1,19 @@ +package sftpfiles + +import ( + "testing" + + "github.com/puellanivis/breton/lib/files" +) + +func TestHandlerFulfillsReadDirFS(t *testing.T) { + var h files.FS = &filesystem{} + + if _, ok := h.(files.ReadDirFS); !ok { + t.Fatal("handler does not implement files.ReadDirFS") + } + + if _, ok := h.(files.CreateFS); !ok { + t.Fatal("handler does not implement files.CreateFS") + } +} diff --git a/lib/files/sftpfiles/writer.go b/lib/files/sftpfiles/writer.go index 016939b..a1ac148 100644 --- a/lib/files/sftpfiles/writer.go +++ b/lib/files/sftpfiles/writer.go @@ -56,13 +56,6 @@ func (w *writer) Seek(offset int64, whence int) (int64, error) { return w.f.Seek(offset, whence) } -func (w *writer) Sync() error { - for range w.loading { - } - - return nil -} - func (w *writer) Close() error { for range w.loading { } @@ -76,34 +69,26 @@ func (w *writer) Close() error { return w.f.Close() } -type noopSync struct { - *sftp.File -} - -func (f noopSync) Sync() error { - return nil -} - func (fs *filesystem) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { - h := fs.getHost(uri) + h, u := fs.getHost(uri) if cl := h.GetClient(); cl != nil { f, err := cl.Create(uri.Path) if err != nil { - return nil, files.PathError("create", uri.String(), err) + return nil, &os.PathError{ + Op: "create", + Path: u.String(), + Err: err, + } } - return noopSync{f}, nil + return f, nil } loading := make(chan struct{}) - fixURL := *uri - fixURL.Host = h.uri.Host - fixURL.User = h.uri.User - w := &writer{ - uri: &fixURL, + uri: u, Host: h, loading: loading, @@ -115,19 +100,31 @@ func (fs *filesystem) Create(ctx context.Context, uri *url.URL) (files.Writer, e select { case loading <- struct{}{}: case <-ctx.Done(): - w.err = files.PathError("connect", h.Name(), ctx.Err()) + w.err = &os.PathError{ + Op: "connect", + Path: h.Name(), + Err: ctx.Err(), + } return } cl, err := h.Connect() if err != nil { - w.err = files.PathError("connect", h.Name(), err) + w.err = &os.PathError{ + Op: "connect", + Path: h.Name(), + Err: ctx.Err(), + } return } f, err := cl.Create(uri.Path) if err != nil { - w.err = files.PathError("create", w.Name(), err) + w.err = &os.PathError{ + Op: "create", + Path: u.String(), + Err: err, + } return } diff --git a/lib/files/socketfiles/socket.go b/lib/files/socketfiles/socket.go index 7ff642d..34cbf11 100644 --- a/lib/files/socketfiles/socket.go +++ b/lib/files/socketfiles/socket.go @@ -3,7 +3,6 @@ package socketfiles import ( "context" - "errors" "net" "net/url" "strconv" @@ -12,11 +11,6 @@ import ( "golang.org/x/net/ipv4" ) -var ( - errInvalidURL = errors.New("invalid url") - errInvalidIP = errors.New("invalid ip") -) - // URL query field keys. const ( FieldBufferSize = "buffer_size" diff --git a/lib/files/socketfiles/socket_test.go b/lib/files/socketfiles/socket_test.go new file mode 100644 index 0000000..c5a7d36 --- /dev/null +++ b/lib/files/socketfiles/socket_test.go @@ -0,0 +1,25 @@ +package socketfiles + +import ( + "testing" + + "github.com/puellanivis/breton/lib/files" +) + +func TestHandlersFulfillsCreateFS(t *testing.T) { + var h files.FS = tcpHandler{} + + if _, ok := h.(files.CreateFS); !ok { + t.Fatal("tcp handler does not implement files.CreateFS") + } + + h = udpHandler{} + if _, ok := h.(files.CreateFS); !ok { + t.Fatal("udp handler does not implement files.CreateFS") + } + + h = unixHandler{} + if _, ok := h.(files.CreateFS); !ok { + t.Fatal("unix handler does not implement files.CreateFS") + } +} diff --git a/lib/files/socketfiles/stream.go b/lib/files/socketfiles/stream.go index 50e5d4a..f8c2cf9 100644 --- a/lib/files/socketfiles/stream.go +++ b/lib/files/socketfiles/stream.go @@ -8,7 +8,6 @@ import ( "sync" "time" - "github.com/puellanivis/breton/lib/files" "github.com/puellanivis/breton/lib/files/wrapper" ) @@ -151,7 +150,11 @@ func newStreamReader(ctx context.Context, l net.Listener) (*streamReader, error) select { case loading <- struct{}{}: case <-ctx.Done(): - r.err = files.PathError("open", uri.String(), ctx.Err()) + r.err = &os.PathError{ + Op: "accept", + Path: uri.String(), + Err: ctx.Err(), + } return } @@ -165,7 +168,11 @@ func newStreamReader(ctx context.Context, l net.Listener) (*streamReader, error) } if err := do(ctx, accept); err != nil { - r.err = files.PathError("accept", uri.String(), err) + r.err = &os.PathError{ + Op: "accept", + Path: uri.String(), + Err: err, + } return } diff --git a/lib/files/socketfiles/tcp.go b/lib/files/socketfiles/tcp.go index 83cb8f4..a70724d 100644 --- a/lib/files/socketfiles/tcp.go +++ b/lib/files/socketfiles/tcp.go @@ -12,35 +12,55 @@ import ( type tcpHandler struct{} func init() { - files.RegisterScheme(&tcpHandler{}, "tcp") + files.RegisterScheme(tcpHandler{}, "tcp") } -func (h *tcpHandler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { +func (tcpHandler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { if uri.Host == "" { - return nil, files.PathError("open", uri.String(), errInvalidURL) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: files.ErrURLHostRequired, + } } laddr, err := net.ResolveTCPAddr("tcp", uri.Host) if err != nil { - return nil, files.PathError("open", uri.String(), err) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } } l, err := net.ListenTCP("tcp", laddr) if err != nil { - return nil, files.PathError("open", uri.String(), err) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } } return newStreamReader(ctx, l) } -func (h *tcpHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { +func (tcpHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { if uri.Host == "" { - return nil, files.PathError("create", uri.String(), errInvalidURL) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: files.ErrURLHostRequired, + } } raddr, err := net.ResolveTCPAddr("tcp", uri.Host) if err != nil { - return nil, files.PathError("create", uri.String(), err) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } } q := uri.Query() @@ -52,7 +72,11 @@ func (h *tcpHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, er if host != "" || port != "" { laddr, err = net.ResolveTCPAddr("tcp", net.JoinHostPort(host, port)) if err != nil { - return nil, files.PathError("create", uri.String(), err) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } } } @@ -66,18 +90,22 @@ func (h *tcpHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, er } if err := do(ctx, dial); err != nil { - return nil, files.PathError("create", uri.String(), err) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } } sock, err := sockWriter(conn, laddr != nil, q) if err != nil { conn.Close() - return nil, files.PathError("create", uri.String(), err) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } } return newStreamWriter(ctx, sock), nil } - -func (h *tcpHandler) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { - return nil, files.PathError("readdir", uri.String(), os.ErrInvalid) -} diff --git a/lib/files/socketfiles/udp.go b/lib/files/socketfiles/udp.go index 47e0e54..a497edd 100644 --- a/lib/files/socketfiles/udp.go +++ b/lib/files/socketfiles/udp.go @@ -12,22 +12,34 @@ import ( type udpHandler struct{} func init() { - files.RegisterScheme(&udpHandler{}, "udp") + files.RegisterScheme(udpHandler{}, "udp") } -func (h *udpHandler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { +func (udpHandler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { if uri.Host == "" { - return nil, files.PathError("open", uri.String(), errInvalidURL) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: files.ErrURLHostRequired, + } } laddr, err := net.ResolveUDPAddr("udp", uri.Host) if err != nil { - return nil, files.PathError("open", uri.String(), err) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } } conn, err := net.ListenUDP("udp", laddr) if err != nil { - return nil, files.PathError("open", uri.String(), err) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } } // Maybe we asked for an arbitrary port, @@ -37,20 +49,32 @@ func (h *udpHandler) Open(ctx context.Context, uri *url.URL) (files.Reader, erro sock, err := sockReader(conn, uri.Query()) if err != nil { conn.Close() - return nil, files.PathError("open", uri.String(), err) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } } return newDatagramReader(ctx, sock), nil } -func (h *udpHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { +func (udpHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { if uri.Host == "" { - return nil, files.PathError("create", uri.String(), errInvalidURL) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: files.ErrURLHostRequired, + } } raddr, err := net.ResolveUDPAddr("udp", uri.Host) if err != nil { - return nil, files.PathError("create", uri.String(), err) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } } q := uri.Query() @@ -62,7 +86,11 @@ func (h *udpHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, er if host != "" || port != "" { laddr, err = net.ResolveUDPAddr("udp", net.JoinHostPort(host, port)) if err != nil { - return nil, files.PathError("create", uri.String(), err) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } } } @@ -76,18 +104,22 @@ func (h *udpHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, er } if err := do(ctx, dial); err != nil { - return nil, files.PathError("create", uri.String(), err) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } } sock, err := sockWriter(conn, laddr != nil, q) if err != nil { conn.Close() - return nil, files.PathError("create", uri.String(), err) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } } return newDatagramWriter(ctx, sock), nil } - -func (h *udpHandler) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { - return nil, files.PathError("readdir", uri.String(), os.ErrInvalid) -} diff --git a/lib/files/socketfiles/unixsock.go b/lib/files/socketfiles/unixsock.go index ea1c630..bc3bc1f 100644 --- a/lib/files/socketfiles/unixsock.go +++ b/lib/files/socketfiles/unixsock.go @@ -2,7 +2,6 @@ package socketfiles import ( "context" - "errors" "net" "net/url" "os" @@ -13,32 +12,52 @@ import ( type unixHandler struct{} func init() { - files.RegisterScheme(&unixHandler{}, "unix", "unixgram") + files.RegisterScheme(unixHandler{}, "unix", "unixgram") } -func (h *unixHandler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { +func (h unixHandler) Open(ctx context.Context, uri *url.URL) (files.Reader, error) { + if uri.Host != "" || uri.User != nil { + return nil, files.ErrURLCannotHaveAuthority + } + path := uri.Path if path == "" { - path = uri.Opaque + var err error + path, err = url.PathUnescape(uri.Opaque) + if err != nil { + return nil, files.ErrURLInvalid + } } network := uri.Scheme laddr, err := net.ResolveUnixAddr(network, path) if err != nil { - return nil, files.PathError("open", uri.String(), err) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } } switch laddr.Network() { case "unixgram": conn, err := net.ListenUnixgram(network, laddr) if err != nil { - return nil, files.PathError("open", uri.String(), err) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } } sock, err := sockReader(conn, uri.Query()) if err != nil { conn.Close() - return nil, files.PathError("open", uri.String(), err) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } } return newDatagramReader(ctx, sock), nil @@ -46,19 +65,35 @@ func (h *unixHandler) Open(ctx context.Context, uri *url.URL) (files.Reader, err case "unix": l, err := net.ListenUnix(network, laddr) if err != nil { - return nil, files.PathError("open", uri.String(), err) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: err, + } } return newStreamReader(ctx, l) } - return nil, files.PathError("create", uri.String(), errors.New("unknown unix socket type")) + return nil, &os.PathError{ + Op: "open", + Path: uri.String(), + Err: files.ErrNotSupported, + } } -func (h *unixHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { +func (h unixHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, error) { + if uri.Host != "" || uri.User != nil { + return nil, files.ErrURLCannotHaveAuthority + } + path := uri.Path if path == "" { - path = uri.Opaque + var err error + path, err = url.PathUnescape(uri.Opaque) + if err != nil { + return nil, files.ErrURLInvalid + } } network := uri.Scheme @@ -75,7 +110,11 @@ func (h *unixHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, e if addr != "" { laddr, err = net.ResolveUnixAddr(network, addr) if err != nil { - return nil, files.PathError("create", uri.String(), err) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } } } @@ -89,13 +128,21 @@ func (h *unixHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, e } if err := do(ctx, dial); err != nil { - return nil, files.PathError("create", uri.String(), err) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } } sock, err := sockWriter(conn, laddr != nil, q) if err != nil { conn.Close() - return nil, files.PathError("create", uri.String(), err) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: err, + } } switch network { @@ -107,9 +154,9 @@ func (h *unixHandler) Create(ctx context.Context, uri *url.URL) (files.Writer, e } conn.Close() - return nil, files.PathError("create", uri.String(), errors.New("unknown unix socket type")) -} - -func (h *unixHandler) List(ctx context.Context, uri *url.URL) ([]os.FileInfo, error) { - return nil, files.PathError("readdir", uri.String(), os.ErrInvalid) + return nil, &os.PathError{ + Op: "create", + Path: uri.String(), + Err: files.ErrNotSupported, + } } diff --git a/lib/files/wrapper/info.go b/lib/files/wrapper/info.go index 12d8ced..1659e3c 100644 --- a/lib/files/wrapper/info.go +++ b/lib/files/wrapper/info.go @@ -22,8 +22,10 @@ type Info struct { // NewInfo returns a new Info set with the url, size and time specified. func NewInfo(uri *url.URL, size int, t time.Time) *Info { + u := *uri + return &Info{ - uri: uri, + uri: &u, sz: int64(size), mode: os.FileMode(0644), t: t, @@ -52,11 +54,30 @@ func (fi *Info) SetNameFromURL(uri *url.URL) { fi.mu.Lock() defer fi.mu.Unlock() - fi.uri = uri + u := *uri + fi.uri = &u + fi.name = "" } -func (fi *Info) fixName() string { +// Name returns the filename of the Info, if name == "" and there is a url, +// then it renders the url, and returns that as the name. +func (fi *Info) Name() string { + if fi == nil { + return "" + } + + fi.mu.RLock() + name, uri := fi.name, fi.uri + fi.mu.RUnlock() + + if name != "" || uri == nil { + return name + } + + fi.mu.Lock() + defer fi.mu.Unlock() + if fi.name != "" || fi.uri == nil { // Nothing to fix. // Likely, someone else already fixed the name while we were waiting on the mutex. @@ -74,25 +95,33 @@ func (fi *Info) fixName() string { return fi.name } -// Name returns the filename of the Info, if name == "" and there is a url, -// then it renders the url, and returns that as the name. -func (fi *Info) Name() string { +// URL returns the URL of the Info, if there is no URL yet, +// then it will set the URL to be `&url.URL{ Path: name }`. +func (fi *Info) URL() *url.URL { if fi == nil { - return "" + return &url.URL{} } fi.mu.RLock() - name, uri := fi.name, fi.uri + uri := fi.uri fi.mu.RUnlock() - if name == "" && uri != nil { - fi.mu.Lock() - defer fi.mu.Unlock() + if uri != nil { + u := *uri + return &u + } - return fi.fixName() + fi.mu.Lock() + defer fi.mu.Unlock() + + if fi.uri == nil { + fi.uri = &url.URL{ + Path: fi.name, + } } - return name + u := *fi.uri + return &u } // Size returns the size declared in the Info. @@ -102,9 +131,10 @@ func (fi *Info) Size() int64 { } fi.mu.RLock() - defer fi.mu.RUnlock() + sz := fi.sz + fi.mu.RUnlock() - return fi.sz + return sz } // SetSize sets a new size in the Info. @@ -126,9 +156,10 @@ func (fi *Info) Mode() (mode os.FileMode) { } fi.mu.RLock() - defer fi.mu.RUnlock() + mode = fi.mode + fi.mu.RUnlock() - return fi.mode + return mode } // Chmod sets the os.FileMode to be returned from Mode(). @@ -151,9 +182,10 @@ func (fi *Info) ModTime() (t time.Time) { } fi.mu.RLock() - defer fi.mu.RUnlock() + t = fi.t + fi.mu.RUnlock() - return fi.t + return t } // SetModTime sets the modification time in the Info to the time.Time given. @@ -171,9 +203,10 @@ func (fi *Info) IsDir() bool { } fi.mu.RLock() - defer fi.mu.RUnlock() + isDir := fi.mode&os.ModeDir != 0 + fi.mu.RUnlock() - return fi.mode&os.ModeDir != 0 + return isDir } // Sys returns the Info object itself, as it is already the underlying data source. diff --git a/lib/files/wrapper/reader.go b/lib/files/wrapper/reader.go index e9aa930..39122b5 100644 --- a/lib/files/wrapper/reader.go +++ b/lib/files/wrapper/reader.go @@ -15,7 +15,6 @@ type Reader struct { fi os.FileInfo r io.Reader - s io.Seeker } // NewReaderWithInfo returns a new Reader with the given FileInfo. @@ -54,16 +53,11 @@ func (r *Reader) Seek(offset int64, whence int) (int64, error) { r.mu.Lock() defer r.mu.Unlock() - if r.s == nil { - switch s := r.r.(type) { - case io.Seeker: - r.s = s - default: - return 0, os.ErrInvalid - } + if s, ok := r.r.(io.Seeker); ok { + return s.Seek(offset, whence) } - return r.s.Seek(offset, whence) + return 0, os.ErrInvalid } // Close recovers resources assigned in the Reader. @@ -73,17 +67,15 @@ func (r *Reader) Close() error { var err error - switch c := r.r.(type) { - case nil: - err = os.ErrClosed + if r.r == nil { + return os.ErrClosed + } - case io.Closer: + if c, ok := r.r.(io.Closer); ok { err = c.Close() } - r.s = nil r.r = nil - r.fi = nil return err } diff --git a/lib/files/wrapper/reader_test.go b/lib/files/wrapper/reader_test.go index 3961135..1cfd745 100644 --- a/lib/files/wrapper/reader_test.go +++ b/lib/files/wrapper/reader_test.go @@ -7,7 +7,7 @@ import ( ) func TestImplementsFilesReader(t *testing.T) { - var f files.Reader = new(Reader) + var f files.SeekReader = new(Reader) _ = f } diff --git a/lib/files/wrapper/writer.go b/lib/files/wrapper/writer.go index a30fe9e..f6a8e0f 100644 --- a/lib/files/wrapper/writer.go +++ b/lib/files/wrapper/writer.go @@ -3,7 +3,6 @@ package wrapper import ( "bytes" "context" - "io" "net/url" "os" "sync" @@ -15,11 +14,8 @@ type Writer struct { mu sync.Mutex *Info - b bytes.Buffer - - flush chan struct{} - done chan struct{} - errch chan error + b *bytes.Buffer + do func([]byte) error // must be called with lock. } // WriteFunc is a function that is intended to write the given byte slice to some @@ -30,49 +26,18 @@ type WriteFunc func([]byte) error // NewWriter returns a Writer that is setup to call the given WriteFunc with // the underlying buffer on every Sync, and Close. func NewWriter(ctx context.Context, uri *url.URL, f WriteFunc) *Writer { - wr := &Writer{ - Info: NewInfo(uri, 0, time.Now()), - flush: make(chan struct{}), - done: make(chan struct{}), - errch: make(chan error), - } + info := NewInfo(uri, 0, time.Now()) - doWrite := func() error { - wr.mu.Lock() - defer wr.mu.Unlock() - - // Update ModTime to now. - wr.Info.SetModTime(time.Now()) - return f(wr.b.Bytes()) + wr := &Writer{ + Info: info, + b: new(bytes.Buffer), + do: func(b []byte) error { + // Update ModTime to now. + info.SetModTime(time.Now()) + return f(b) + }, } - go func() { - defer func() { - close(wr.errch) - close(wr.flush) - }() - - for { - select { - case <-wr.done: - // For done, we only send a non-nil err, - // When we close the errch, it will then return nil errors. - if err := doWrite(); err != nil { - wr.errch <- err - } - return - - case <-wr.flush: - // For flush, we send even nil errors, - // Otherwise, the Sync() routine would block forever waiting on an errch. - wr.errch <- doWrite() - - case <-ctx.Done(): - return - } - } - }() - return wr } @@ -81,6 +46,11 @@ func (w *Writer) Write(b []byte) (n int, err error) { w.mu.Lock() defer w.mu.Unlock() + if w.b == nil { + // cannot write to closed Writer + return 0, os.ErrClosed + } + n, err = w.b.Write(b) w.Info.SetSize(w.b.Len()) @@ -88,60 +58,30 @@ func (w *Writer) Write(b []byte) (n int, err error) { return n, err } -func (w *Writer) signalSync() error { - w.mu.Lock() - defer w.mu.Unlock() - - select { - case <-w.done: - // cannot flush a closed Writer. - return io.ErrClosedPipe - default: - } - - w.flush <- struct{}{} - return nil -} - // Sync calls the defined WriteFunc for the Writer with the entire underlying buffer. func (w *Writer) Sync() error { - if err := w.signalSync(); err != nil { - return err - } - - // We cannot wait here under Lock, because the sync process requires the Lock. - return <-w.errch -} - -func (w *Writer) markDone() error { w.mu.Lock() defer w.mu.Unlock() - select { - case <-w.done: - // already closed + if w.b == nil { + // cannot sync a closed Writer return os.ErrClosed - default: } - close(w.done) - return nil + return w.do(w.b.Bytes()) } // Close performs a marks the Writer as complete, which also causes a Sync. func (w *Writer) Close() error { - if err := w.markDone(); err != nil { - return err - } - - var err error + w.mu.Lock() + defer w.mu.Unlock() - // We cannot wait here under Lock, because the sync process requires the Lock. - for err2 := range w.errch { - if err == nil { - err = err2 - } + if w.b == nil { + // cannot sync a closed Writer + return os.ErrClosed } + data := w.b.Bytes() + w.b = nil - return nil + return w.do(data) } diff --git a/lib/files/wrapper/writer_test.go b/lib/files/wrapper/writer_test.go index a1f5363..4cec475 100644 --- a/lib/files/wrapper/writer_test.go +++ b/lib/files/wrapper/writer_test.go @@ -7,7 +7,7 @@ import ( ) func TestImplementsFilesWriter(t *testing.T) { - var f files.Writer = new(Writer) + var f files.SyncWriter = new(Writer) _ = f } diff --git a/lib/files/write.go b/lib/files/write.go index eb23607..b8db4da 100644 --- a/lib/files/write.go +++ b/lib/files/write.go @@ -5,15 +5,23 @@ import ( "io" ) -// WriteTo will write the given data to the io.WriteCloser and Close the writer. -func WriteTo(w io.WriteCloser, data []byte) error { +// WriteTo will write the given data to the io.Writer, +// it will then Close the writer if it implements io.Closer. +// +// Note: this function will return io.ErrShortWrite, +// if the amount written is less than the data given as input. +func WriteTo(w io.Writer, data []byte) error { n, err := w.Write(data) if err == nil && n < len(data) { err = io.ErrShortWrite } - if err1 := w.Close(); err == nil { - err = err1 + + if c, ok := w.(io.Closer); ok { + if err2 := c.Close(); err == nil { + err = err2 + } } + return err }