Skip to content

Commit

Permalink
add 'timeout' parameter to the OSSession.SaveData method
Browse files Browse the repository at this point in the history
  • Loading branch information
darkdarkdragon committed Jun 29, 2021
1 parent 8027b6e commit f383b08
Show file tree
Hide file tree
Showing 18 changed files with 67 additions and 54 deletions.
11 changes: 7 additions & 4 deletions core/capabilities_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"sort"
"testing"
"time"

"github.com/livepeer/go-livepeer/drivers"
"github.com/livepeer/go-livepeer/net"
Expand Down Expand Up @@ -270,10 +271,12 @@ func (os *stubOS) GetInfo() *net.OSInfo {
}
return &net.OSInfo{StorageType: net.OSInfo_StorageType(os.storageType)}
}
func (os *stubOS) EndSession() {}
func (os *stubOS) SaveData(string, []byte, map[string]string) (string, error) { return "", nil }
func (os *stubOS) IsExternal() bool { return false }
func (os *stubOS) IsOwn(url string) bool { return true }
func (os *stubOS) EndSession() {}
func (os *stubOS) SaveData(string, []byte, map[string]string, time.Duration) (string, error) {
return "", nil
}
func (os *stubOS) IsExternal() bool { return false }
func (os *stubOS) IsOwn(url string) bool { return true }
func (os *stubOS) ListFiles(ctx context.Context, prefix, delim string) (drivers.PageInfo, error) {
return nil, nil
}
Expand Down
6 changes: 3 additions & 3 deletions core/livepeernode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,23 +130,23 @@ func TestServiceURIChange(t *testing.T) {

drivers.NodeStorage = drivers.NewMemoryDriver(n.GetServiceURI())
sesh := drivers.NodeStorage.NewSession("testpath")
savedUrl, err := sesh.SaveData("testdata1", []byte{0, 0, 0}, nil)
savedUrl, err := sesh.SaveData("testdata1", []byte{0, 0, 0}, nil, 0)
require.Nil(err)
assert.Equal("test://testurl.com/stream/testpath/testdata1", savedUrl)

glog.Infof("Setting service URL to newurl")
newUrl, err := url.Parse("test://newurl.com")
n.SetServiceURI(newUrl)
require.Nil(err)
furl, err := sesh.SaveData("testdata2", []byte{0, 0, 0}, nil)
furl, err := sesh.SaveData("testdata2", []byte{0, 0, 0}, nil, 0)
require.Nil(err)
assert.Equal("test://newurl.com/stream/testpath/testdata2", furl)

glog.Infof("Setting service URL to secondurl")
secondUrl, err := url.Parse("test://secondurl.com")
n.SetServiceURI(secondUrl)
require.Nil(err)
surl, err := sesh.SaveData("testdata3", []byte{0, 0, 0}, nil)
surl, err := sesh.SaveData("testdata3", []byte{0, 0, 0}, nil, 0)
require.Nil(err)
assert.Equal("test://secondurl.com/stream/testpath/testdata3", surl)
}
Expand Down
2 changes: 1 addition & 1 deletion core/orchestrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ func (n *LivepeerNode) transcodeSeg(config transcodeConfig, seg *stream.HLSSegme
// Need to store segment in our local OS
var err error
name := fmt.Sprintf("%d.tempfile", seg.SeqNo)
url, err = config.LocalOS.SaveData(name, seg.Data, nil)
url, err = config.LocalOS.SaveData(name, seg.Data, nil, 0)
if err != nil {
return terr(err)
}
Expand Down
2 changes: 1 addition & 1 deletion core/playlistmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ func (mgr *BasicPlaylistManager) FlushRecord() {
}
go func(name string, data []byte) {
now := time.Now()
_, err := mgr.recordSession.SaveData(name, b, nil)
_, err := mgr.recordSession.SaveData(name, b, nil, 0)
took := time.Since(now)
if err != nil {
glog.Errorf("Error saving json playlist name=%s bytes=%d took=%s err=%v", name,
Expand Down
2 changes: 1 addition & 1 deletion core/playlistmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ func TestCleanup(t *testing.T) {
testData := []byte{1, 2, 3, 4}

c := NewBasicPlaylistManager(mid, osSession, nil)
uri, err := c.GetOSSession().SaveData("testName", testData, nil)
uri, err := c.GetOSSession().SaveData("testName", testData, nil, 0)
if err != nil {
t.Fatal(err)
}
Expand Down
6 changes: 3 additions & 3 deletions drivers/drivers.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ type PageInfo interface {
type OSSession interface {
OS() OSDriver

SaveData(name string, data []byte, meta map[string]string) (string, error)
SaveData(name string, data []byte, meta map[string]string, timeout time.Duration) (string, error)
EndSession()

// Info in order to have this session used via RPC
Expand Down Expand Up @@ -187,12 +187,12 @@ func ParseOSURL(input string, useFullAPI bool) (OSDriver, error) {
// SaveRetried tries to SaveData specified number of times
func SaveRetried(sess OSSession, name string, data []byte, meta map[string]string, retryCount int) (string, error) {
if retryCount < 1 {
return "", fmt.Errorf("Invalid retry count %d", retryCount)
return "", fmt.Errorf("invalid retry count %d", retryCount)
}
var uri string
var err error
for i := 0; i < retryCount; i++ {
uri, err = sess.SaveData(name, data, meta)
uri, err = sess.SaveData(name, data, meta, 0)
if err == nil {
return uri, err
}
Expand Down
9 changes: 6 additions & 3 deletions drivers/gs.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ func (os *gsSession) createClient() error {
return nil
}

func (os *gsSession) SaveData(name string, data []byte, meta map[string]string) (string, error) {
func (os *gsSession) SaveData(name string, data []byte, meta map[string]string, timeout time.Duration) (string, error) {
if os.useFullAPI {
if os.client == nil {
if err := os.createClient(); err != nil {
Expand All @@ -163,7 +163,10 @@ func (os *gsSession) SaveData(name string, data []byte, meta map[string]string)
keyname := os.key + "/" + name
objh := os.client.Bucket(os.bucket).Object(keyname)
glog.V(common.VERBOSE).Infof("Saving to GS %s/%s", os.bucket, keyname)
ctx, cancel := context.WithTimeout(context.Background(), saveTimeout)
if timeout == 0 {
timeout = saveTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
wr := objh.NewWriter(ctx)
if len(meta) > 0 && wr.Metadata == nil {
Expand All @@ -185,7 +188,7 @@ func (os *gsSession) SaveData(name string, data []byte, meta map[string]string)
glog.V(common.VERBOSE).Infof("Saved to GS %s", uri)
return uri, err
}
return os.s3Session.SaveData(name, data, meta)
return os.s3Session.SaveData(name, data, meta, timeout)
}

type gsPageInfo struct {
Expand Down
3 changes: 2 additions & 1 deletion drivers/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"path"
"strings"
"sync"
"time"

"github.com/livepeer/go-livepeer/net"
)
Expand Down Expand Up @@ -194,7 +195,7 @@ func (ostore *MemorySession) GetInfo() *net.OSInfo {
return nil
}

func (ostore *MemorySession) SaveData(name string, data []byte, meta map[string]string) (string, error) {
func (ostore *MemorySession) SaveData(name string, data []byte, meta map[string]string, timeout time.Duration) (string, error) {
path, file := path.Split(ostore.getAbsolutePath(name))

ostore.dLock.Lock()
Expand Down
8 changes: 4 additions & 4 deletions drivers/local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@ func TestLocalOS(t *testing.T) {
assert.NoError((err))
os := NewMemoryDriver(u)
sess := os.NewSession(("sesspath")).(*MemorySession)
path, err := sess.SaveData("name1/1.ts", copyBytes(tempData1), nil)
path, err := sess.SaveData("name1/1.ts", copyBytes(tempData1), nil, 0)
glog.Info(path)
fmt.Println(path)
assert.Equal("fake.com/url/stream/sesspath/name1/1.ts", path)
data := sess.GetData("sesspath/name1/1.ts")
fmt.Printf("got Data: '%s'\n", data)
assert.Equal(tempData1, string(data))
path, err = sess.SaveData("name1/1.ts", copyBytes(tempData2), nil)
path, err = sess.SaveData("name1/1.ts", copyBytes(tempData2), nil, 0)
data = sess.GetData("sesspath/name1/1.ts")
assert.Equal(tempData2, string(data))
path, err = sess.SaveData("name1/2.ts", copyBytes(tempData3), nil)
path, err = sess.SaveData("name1/2.ts", copyBytes(tempData3), nil, 0)
data = sess.GetData("sesspath/name1/2.ts")
assert.Equal(tempData3, string(data))
// Test trim prefix when baseURI != nil
Expand All @@ -55,7 +55,7 @@ func TestLocalOS(t *testing.T) {
// Test trim prefix when baseURI = nil
os = NewMemoryDriver(nil)
sess = os.NewSession("sesspath").(*MemorySession)
path, err = sess.SaveData("name1/1.ts", copyBytes(tempData1), nil)
path, err = sess.SaveData("name1/1.ts", copyBytes(tempData1), nil, 0)
assert.Nil(err)
assert.Equal("/stream/sesspath/name1/1.ts", path)

Expand Down
25 changes: 15 additions & 10 deletions drivers/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ func (os *s3Session) ReadData(ctx context.Context, name string) (*FileInfoReader
return res, nil
}

func (os *s3Session) saveDataPut(name string, data []byte, meta map[string]string) (string, error) {
func (os *s3Session) saveDataPut(name string, data []byte, meta map[string]string, timeout time.Duration) (string, error) {
now := time.Now()
bucket := aws.String(os.bucket)
keyname := aws.String(os.key + "/" + name)
Expand All @@ -294,7 +294,10 @@ func (os *s3Session) saveDataPut(name string, data []byte, meta map[string]strin
ContentType: contentType,
ContentLength: aws.Int64(int64(len(data))),
}
ctx, cancel := context.WithTimeout(context.Background(), saveTimeout)
if timeout == 0 {
timeout = saveTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
resp, err := os.s3svc.PutObjectWithContext(ctx, params, request.WithLogLevel(aws.LogDebug))
cancel()
if err != nil {
Expand All @@ -306,14 +309,14 @@ func (os *s3Session) saveDataPut(name string, data []byte, meta map[string]strin
return uri, err
}

func (os *s3Session) SaveData(name string, data []byte, meta map[string]string) (string, error) {
func (os *s3Session) SaveData(name string, data []byte, meta map[string]string, timeout time.Duration) (string, error) {
if os.s3svc != nil {
return os.saveDataPut(name, data, meta)
return os.saveDataPut(name, data, meta, timeout)
}
// tentativeUrl just used for logging
tentativeURL := path.Join(os.host, os.key, name)
glog.V(common.VERBOSE).Infof("Saving to S3 %s", tentativeURL)
path, err := os.postData(name, data, meta)
path, err := os.postData(name, data, meta, timeout)
if err != nil {
// handle error
glog.Errorf("Save S3 error: %v", err)
Expand Down Expand Up @@ -358,7 +361,7 @@ func (os *s3Session) getContentType(fileName string, buffer []byte) string {
}

// if s3 storage is not our own, we are saving data into it using POST request
func (os *s3Session) postData(fileName string, buffer []byte, meta map[string]string) (string, error) {
func (os *s3Session) postData(fileName string, buffer []byte, meta map[string]string, timeout time.Duration) (string, error) {
fileBytes := bytes.NewReader(buffer)
fileType := os.getContentType(fileName, buffer)
path, fileName := path.Split(path.Join(os.key, fileName))
Expand All @@ -375,7 +378,7 @@ func (os *s3Session) postData(fileName string, buffer []byte, meta map[string]st
if !strings.Contains(postURL, os.bucket) {
postURL += "/" + os.bucket
}
req, cancel, err := newfileUploadRequest(postURL, fields, fileBytes, fileName)
req, cancel, err := newfileUploadRequest(postURL, fields, fileBytes, fileName, timeout)
if err != nil {
glog.Error(err)
return "", err
Expand Down Expand Up @@ -446,7 +449,7 @@ func createPolicy(key, bucket, region, secret, path string) (string, string, str
return policy, signString(policy, region, xAmzDate, secret), xAmzCredential, xAmzDate + "T000000Z"
}

func newfileUploadRequest(uri string, params map[string]string, fData io.Reader, fileName string) (*http.Request, context.CancelFunc, error) {
func newfileUploadRequest(uri string, params map[string]string, fData io.Reader, fileName string, timeout time.Duration) (*http.Request, context.CancelFunc, error) {
glog.Infof("Posting data to %s (params %+v)", uri, params)
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
Expand All @@ -466,8 +469,10 @@ func newfileUploadRequest(uri string, params map[string]string, fData io.Reader,
if err != nil {
return nil, nil, err
}

ctx, cancel := context.WithTimeout(context.Background(), saveTimeout)
if timeout == 0 {
timeout = saveTimeout
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
req, err := http.NewRequestWithContext(ctx, "POST", uri, body)
if err != nil {
cancel()
Expand Down
3 changes: 2 additions & 1 deletion drivers/session_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package drivers

import (
"context"
"time"

"github.com/livepeer/go-livepeer/net"
"github.com/stretchr/testify/mock"
Expand All @@ -11,7 +12,7 @@ type MockOSSession struct {
mock.Mock
}

func (s *MockOSSession) SaveData(name string, data []byte, meta map[string]string) (string, error) {
func (s *MockOSSession) SaveData(name string, data []byte, meta map[string]string, timeout time.Duration) (string, error) {
args := s.Called()
return args.String(0), args.Error(1)
}
Expand Down
8 changes: 4 additions & 4 deletions server/broadcast.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ func processSegment(cxn *rtmpConnection, seg *stream.HLSSegment) ([]string, erro
}
}()
}
uri, err := cpl.GetOSSession().SaveData(name, seg.Data, nil)
uri, err := cpl.GetOSSession().SaveData(name, seg.Data, nil, 0)
if err != nil {
glog.Errorf("Error saving segment nonce=%d seqNo=%d: %v", nonce, seg.SeqNo, err)
if monitor.Enabled {
Expand Down Expand Up @@ -498,7 +498,7 @@ func transcodeSegment(cxn *rtmpConnection, seg *stream.HLSSegment, name string,
// storage the orchestrator prefers
if ios := sess.OrchestratorOS; ios != nil {
// XXX handle case when orch expects direct upload
uri, err := ios.SaveData(name, seg.Data, nil)
uri, err := ios.SaveData(name, seg.Data, nil, 0)
if err != nil {
glog.Errorf("Error saving segment to OS nonce=%d seqNo=%d: %v", nonce, seg.SeqNo, err)
if monitor.Enabled {
Expand Down Expand Up @@ -686,7 +686,7 @@ func transcodeSegment(cxn *rtmpConnection, seg *stream.HLSSegment, name string,
return
}
name := fmt.Sprintf("%s/%d%s", profile.Name, seg.SeqNo, ext)
newURL, err := bos.SaveData(name, data, nil)
newURL, err := bos.SaveData(name, data, nil, 0)
if err != nil {
switch err.Error() {
case "Session ended":
Expand Down Expand Up @@ -830,7 +830,7 @@ func verify(verifier *verification.SegmentVerifier, cxn *rtmpConnection,
// Hence, trim the /stream/<manifestID> prefix if it exists.
pfx := fmt.Sprintf("/stream/%s/", sess.Params.ManifestID)
uri := strings.TrimPrefix(accepted.URIs[i], pfx)
_, err := sess.BroadcasterOS.SaveData(uri, data, nil)
_, err := sess.BroadcasterOS.SaveData(uri, data, nil, 0)
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions server/broadcast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ type stubOSSession struct {
err error
}

func (s *stubOSSession) SaveData(name string, data []byte, meta map[string]string) (string, error) {
func (s *stubOSSession) SaveData(name string, data []byte, meta map[string]string, timeout time.Duration) (string, error) {
s.saved = append(s.saved, name)
return "saved_" + name, s.err
}
Expand Down Expand Up @@ -1437,7 +1437,7 @@ func TestVerifier_Verify(t *testing.T) {
}
mem, ok := drivers.NewMemoryDriver(nil).NewSession("streamName").(*drivers.MemorySession)
assert.True(ok)
name, err := mem.SaveData("/rendition/seg/1", []byte("attempt1"), nil)
name, err := mem.SaveData("/rendition/seg/1", []byte("attempt1"), nil, 0)
assert.Nil(err)
assert.Equal([]byte("attempt1"), mem.GetData(name))
sess.BroadcasterOS = mem
Expand All @@ -1449,7 +1449,7 @@ func TestVerifier_Verify(t *testing.T) {

// Now "insert" 2nd attempt into OS
// and ensure 1st attempt is what remains after verification
_, err = mem.SaveData("/rendition/seg/1", []byte("attempt2"), nil)
_, err = mem.SaveData("/rendition/seg/1", []byte("attempt2"), nil, 0)
assert.Nil(err)
assert.Equal([]byte("attempt2"), mem.GetData(name))
renditionData = [][]byte{[]byte("attempt2")}
Expand Down
4 changes: 2 additions & 2 deletions server/mediaserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -1390,7 +1390,7 @@ func (s *LivepeerServer) HandleRecordings(w http.ResponseWriter, r *http.Request
mainJspl.AddSegmentsToMPL(manifests, trackName, mpl, resp.RecordObjectStoreURL)
fileName := trackName + ".m3u8"
nows := time.Now()
_, err = sess.SaveData(fileName, mpl.Encode().Bytes(), nil)
_, err = sess.SaveData(fileName, mpl.Encode().Bytes(), nil, 0)
glog.V(common.VERBOSE).Infof("Saving playlist fileName=%s for manifestID=%s took=%s", fileName, manifestID, time.Since(nows))
if err != nil {
glog.Error(err)
Expand All @@ -1399,7 +1399,7 @@ func (s *LivepeerServer) HandleRecordings(w http.ResponseWriter, r *http.Request
}
}
nows := time.Now()
_, err = sess.SaveData("index.m3u8", masterPList.Encode().Bytes(), nil)
_, err = sess.SaveData("index.m3u8", masterPList.Encode().Bytes(), nil, 0)
glog.V(common.VERBOSE).Infof("Saving playlist fileName=%s for manifestID=%s took=%s", "index.m3u8", manifestID, time.Since(nows))
if err != nil {
glog.Error(err)
Expand Down
Loading

0 comments on commit f383b08

Please sign in to comment.