Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
ybizeul committed Sep 17, 2024
2 parents 7ef5166 + 8843a4d commit ae1946d
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 39 deletions.
24 changes: 15 additions & 9 deletions internal/feed/feed.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func (feed *Feed) GetPublicItem(i string) (*PublicFeedItem, error) {
return nil, FeedErrorInvalidFeedItem
}

s, err := os.Stat(path.Join(feed.Path, i))
s, err := os.Stat(path.Join(feed.Path, path.Join("/", i)))

if err != nil {
if os.IsNotExist(err) {
Expand All @@ -290,8 +290,11 @@ func (feed *Feed) GetItemData(item string) ([]byte, error) {
var content []byte

// Get path to feed item
filePath := path.Join(feed.Path, item)
filePath := path.Join(feed.Path, path.Join("/"+item))

if path.Base(filePath) == "secret" || path.Base(filePath) == "pin" || path.Base(filePath) == "config.json" {
return nil, fmt.Errorf("%w: %s", FeedErrorItemNotFound, item)
}
// Read feed item content
content, err := os.ReadFile(filePath)
if err != nil {
Expand Down Expand Up @@ -329,7 +332,7 @@ func (feed *Feed) IsSecretValid(secret string) error {

// AddItem reads content from r and creates a new file in the feed directory
// with a name and file extension based on contentType, then notifies clients
func (f *Feed) AddItem(contentType string, r io.ReadCloser) error {
func (f *Feed) AddItem(contentType string, r io.Reader) error {
fL.Logger.Debug("Adding Item", slog.String("feed", f.Name()), slog.String("content-type", contentType))

var err error
Expand Down Expand Up @@ -399,10 +402,11 @@ func (f *Feed) AddItem(contentType string, r io.ReadCloser) error {
}

// Notify additon to all connected browsers
if err = f.WebSocketManager.NotifyAdd(publicItem); err != nil {
return err
if f.WebSocketManager != nil {
if err = f.WebSocketManager.NotifyAdd(publicItem); err != nil {
return err
}
}

// Send push notification to subscribed browsers
err = f.sendPushNotification()
if err != nil {
Expand All @@ -419,7 +423,7 @@ func (f *Feed) AddItem(contentType string, r io.ReadCloser) error {
func (f *Feed) RemoveItem(item string) error {
fL.Logger.Debug("Remove Item", slog.String("name", item), slog.String("feed", f.Path))

itemPath := path.Join(f.Path, item)
itemPath := path.Join(f.Path, path.Join("/", item))

// Save public item before deletion for notification later
publicItem, err := f.GetPublicItem(item)
Expand All @@ -437,8 +441,10 @@ func (f *Feed) RemoveItem(item string) error {
}

// Notify all connected websockets
if err = f.WebSocketManager.NotifyRemove(publicItem); err != nil {
return err
if f.WebSocketManager != nil {
if err = f.WebSocketManager.NotifyRemove(publicItem); err != nil {
return err
}
}

fL.Logger.Debug("Removed Item", slog.String("name", item), slog.String("feed", f.Path))
Expand Down
129 changes: 129 additions & 0 deletions internal/feed/feed_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package feed

import (
"bytes"
"os"
"testing"
)

func TestGetFeedItemData(t *testing.T) {
t.Cleanup(func() {
os.RemoveAll("tests/feed1")
})
f, err := NewFeed("tests/feed1")
if err != nil {
t.Fatal(err)
}

reader := bytes.NewReader([]byte("test"))

err = f.AddItem("text/plain", reader)
if err != nil {
t.Fatal(err)
}

pf, err := f.Public()
if err != nil {
t.Fatal(err)
}
i := pf.Items[0]
b, err := f.GetItemData(i.Name)
if len(b) == 0 || err != nil {
t.Fatal(err)
}
}

func TestPathTraversalGet(t *testing.T) {
t.Cleanup(func() {
os.RemoveAll("tests/feed1")
os.RemoveAll("tests/feed2")
})
_, err := NewFeed("tests/feed1")
if err != nil {
t.Fatal(err)
}

f, err := NewFeed("tests/feed2")
if err != nil {
t.Fatal(err)
}

b, err := f.GetItemData("../feed1/config.json")

if len(b) != 0 || err == nil {
t.Fatal("Path traversal not blocked")
}
}

func TestPathTraversalDelete(t *testing.T) {
t.Cleanup(func() {
os.RemoveAll("tests/feed1")
os.RemoveAll("tests/feed2")
})
_, err := NewFeed("tests/feed1")
if err != nil {
t.Fatal(err)
}

f, err := NewFeed("tests/feed2")
if err != nil {
t.Fatal(err)
}

err = f.RemoveItem("../feed1/config.json")

if err == nil {
t.Fatal("Path traversal not blocked")
}
}

func TestPathTraversalPublicItem(t *testing.T) {
t.Cleanup(func() {
os.RemoveAll("tests/feed1")
os.RemoveAll("tests/feed2")
})
_, err := NewFeed("tests/feed1")
if err != nil {
t.Fatal(err)
}

f, err := NewFeed("tests/feed2")
if err != nil {
t.Fatal(err)
}

p, err := f.GetPublicItem("../feed1/config.json")

if p != nil || err == nil {
t.Fatal("Path traversal not blocked")
}
}

func TestPublicItem(t *testing.T) {
t.Cleanup(func() {
os.RemoveAll("tests/feed1")
})
f, err := NewFeed("tests/feed1")
if err != nil {
t.Fatal(err)
}

reader := bytes.NewReader([]byte("test"))

err = f.AddItem("text/plain", reader)
if err != nil {
t.Fatal(err)
}

pf, err := f.Public()
if err != nil {
t.Fatal(err)
}
i := pf.Items[0]

p, err := f.GetPublicItem(i.Name)

if p == nil || err != nil {
t.Fatal(err)
}
}
25 changes: 20 additions & 5 deletions internal/feed/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,15 @@ func (m *WebSocketManager) RunSocketForFeed(feedName string, w http.ResponseWrit
m.FeedSockets = append(m.FeedSockets, feedSockets)
}

// Upgrade http connection to websocket
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
utils.CloseWithCodeAndMessage(w, 500, "Unable to upgrade WebSocket")
}

feedSockets.websockets = append(feedSockets.websockets, c)

// Get provided secret and validate feed access
secret, _ := utils.GetSecret(r)

f, err := m.FeedManager.GetFeedWithAuth(feedName, secret)
Expand All @@ -113,19 +115,26 @@ func (m *WebSocketManager) RunSocketForFeed(feedName string, w http.ResponseWrit
}
}

// Cleanup
defer func() {
feedSockets.RemoveConn(c)
c.Close()
}()

// Start waiting for messages
for {
mt, message, err := c.ReadMessage()
wsL.Logger.Debug("Message Received", slog.String("message", string(message)), slog.Int("messageType", mt))
wsL.Logger.Debug("Message Received",
slog.String("message", string(message)),
slog.Int("messageType", mt))
if err != nil {
slog.Error("Error reading message", slog.String("error", err.Error()), slog.Int("messageType", mt))
slog.Error("Error reading message",
slog.String("error", err.Error()),
slog.Int("messageType", mt))
break
}
switch strings.TrimSpace(string(message)) {
// Return pubic feed content
case "feed":
pf, err := f.Public()
if err != nil {
Expand All @@ -139,8 +148,11 @@ func (m *WebSocketManager) RunSocketForFeed(feedName string, w http.ResponseWrit
}
}

// NotifyAdd notifies all connected websockets that an item has been added
func (m *WebSocketManager) NotifyAdd(item *PublicFeedItem) error {
wsL.Logger.Debug("Notify websocket", slog.Any("item", item), slog.Int("ws count", len(m.FeedSockets)))
wsL.Logger.Debug("Notify websocket",
slog.Any("item", item),
slog.Int("ws count", len(m.FeedSockets)))
for _, f := range m.FeedSockets {
wsL.Logger.Debug("checking feed", slog.String("feedName", f.feedName))
if f.feedName == item.Feed.Name {
Expand All @@ -158,12 +170,15 @@ func (m *WebSocketManager) NotifyAdd(item *PublicFeedItem) error {
return nil
}

// NotifyRemove notify all connected websockets that an item has been removed
func (m *WebSocketManager) NotifyRemove(item *PublicFeedItem) error {
wsL.Logger.Debug("Notify websocket", slog.Any("item", item), slog.Int("ws count", len(m.FeedSockets)))
wsL.Logger.Debug("Notify websocket",
slog.Any("item", item),
slog.Int("ws count", len(m.FeedSockets)))
for _, f := range m.FeedSockets {
wsL.Logger.Debug("checking feed", slog.String("feedName", f.feedName))
if f.feedName == item.Feed.Name {
wsL.Logger.Debug("Found feed", slog.String("feedName", f.feedName))
wsL.Logger.Debug("found feed", slog.String("feedName", f.feedName))
for _, w := range f.websockets {
if err := w.WriteJSON(FeedNotification{
Action: "remove",
Expand Down
Loading

0 comments on commit ae1946d

Please sign in to comment.