Merge pull request #16 from JeffreyRichter/master
Improvement to parallel upload to block blob
JeffreyRichter authored Jan 26, 2018
2 parents fb32827 + f5afd62 commit c91e48a
Showing 21 changed files with 630 additions and 430 deletions.
15 changes: 14 additions & 1 deletion 2016-05-31/azblob/credential_anonymous.go
@@ -12,13 +12,26 @@ type Credential interface {
credentialMarker()
}

type credentialFunc pipeline.FactoryFunc

func (f credentialFunc) New(next pipeline.Policy, po *pipeline.PolicyOptions) pipeline.Policy {
return f(next, po)
}

// credentialMarker is a package-internal method that exists just to satisfy the Credential interface.
func (credentialFunc) credentialMarker() {}

//////////////////////////////

// NewAnonymousCredential creates an anonymous credential for use with HTTP(S)
// requests that read blobs from public containers or for use with Shared Access
// Signatures (SAS).
func NewAnonymousCredential() Credential {
return &anonymousCredentialPolicyFactory{}
return anonymousCredentialFactory
}

var anonymousCredentialFactory Credential = &anonymousCredentialPolicyFactory{} // Singleton

// anonymousCredentialPolicyFactory is the credential's policy factory.
type anonymousCredentialPolicyFactory struct {
}
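
These two changes work together: credentialFunc lets a plain function satisfy the Credential factory interface (the same trick net/http uses for http.HandlerFunc), and because the anonymous factory carries no state, NewAnonymousCredential can hand out one shared singleton instead of allocating on every call. A minimal, self-contained sketch of the pattern, with illustrative names (Doer and DoerFunc are not part of azblob):

    package main

    import "fmt"

    // Doer is a stand-in for azblob's Credential interface.
    type Doer interface {
        Do(s string) string
    }

    // DoerFunc adapts an ordinary function to the Doer interface by calling
    // itself, exactly as credentialFunc adapts a pipeline.FactoryFunc.
    type DoerFunc func(s string) string

    func (f DoerFunc) Do(s string) string { return f(s) }

    // A stateless implementation is safe to share as a package-level
    // singleton, which is what NewAnonymousCredential now returns.
    var echo Doer = DoerFunc(func(s string) string { return s })

    func main() {
        fmt.Println(echo.Do("hello")) // hello
    }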
44 changes: 17 additions & 27 deletions 2016-05-31/azblob/credential_shared_key.go
@@ -40,38 +40,28 @@ func (f SharedKeyCredential) AccountName() string {

// New creates a credential policy object.
func (f *SharedKeyCredential) New(next pipeline.Policy, po *pipeline.PolicyOptions) pipeline.Policy {
return sharedKeyCredentialPolicy{factory: f, next: next, po: po}
return pipeline.PolicyFunc(func(ctx context.Context, request pipeline.Request) (pipeline.Response, error) {
        // Add an x-ms-date header if it doesn't already exist
if d := request.Header.Get(headerXmsDate); d == "" {
request.Header[headerXmsDate] = []string{time.Now().UTC().Format(http.TimeFormat)}
}
stringToSign := f.buildStringToSign(request)
signature := f.ComputeHMACSHA256(stringToSign)
authHeader := strings.Join([]string{"SharedKey ", f.accountName, ":", signature}, "")
request.Header[headerAuthorization] = []string{authHeader}

response, err := next.Do(ctx, request)
if err != nil && response != nil && response.Response() != nil && response.Response().StatusCode == http.StatusForbidden {
// Service failed to authenticate request, log it
po.Log(pipeline.LogError, "===== HTTP Forbidden status, String-to-Sign:\n"+stringToSign+"\n===============================\n")
}
return response, err
})
}

// credentialMarker is a package-internal method that exists just to satisfy the Credential interface.
func (*SharedKeyCredential) credentialMarker() {}

// sharedKeyCredentialPolicy is the credential's policy object.
type sharedKeyCredentialPolicy struct {
factory *SharedKeyCredential
next pipeline.Policy
po *pipeline.PolicyOptions
}

// Do implements the credential's policy interface.
func (p sharedKeyCredentialPolicy) Do(ctx context.Context, request pipeline.Request) (pipeline.Response, error) {
    // Add an x-ms-date header if it doesn't already exist
if d := request.Header.Get(headerXmsDate); d == "" {
request.Header[headerXmsDate] = []string{time.Now().UTC().Format(http.TimeFormat)}
}
stringToSign := p.factory.buildStringToSign(request)
signature := p.factory.ComputeHMACSHA256(stringToSign)
authHeader := strings.Join([]string{"SharedKey ", p.factory.accountName, ":", signature}, "")
request.Header[headerAuthorization] = []string{authHeader}

response, err := p.next.Do(ctx, request)
if err != nil && response != nil && response.Response() != nil && response.Response().StatusCode == http.StatusForbidden {
// Service failed to authenticate request, log it
p.po.Log(pipeline.LogError, "===== HTTP Forbidden status, String-to-Sign:\n"+stringToSign+"\n===============================\n")
}
return response, err
}
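
The shape of this refactor — deleting a single-use policy struct in favor of a pipeline.PolicyFunc closure that captures f, next, and po — is sketched below with simplified stand-ins (Policy and PolicyFunc here are illustrative, not the real azure-pipeline-go API):

    package main

    import "fmt"

    // Policy and PolicyFunc are reduced stand-ins for the pipeline types.
    type Policy interface{ Do(req string) string }

    type PolicyFunc func(req string) string

    func (f PolicyFunc) Do(req string) string { return f(req) }

    // Before: a dedicated struct carries the captured state.
    type signingPolicy struct {
        key  string
        next Policy
    }

    func (p signingPolicy) Do(req string) string { return p.next.Do(req + " signed:" + p.key) }

    // After: a closure captures the same state, and the struct disappears.
    func newSigningPolicy(key string, next Policy) Policy {
        return PolicyFunc(func(req string) string {
            return next.Do(req + " signed:" + key)
        })
    }

    func main() {
        terminal := PolicyFunc(func(req string) string { return req })
        fmt.Println(signingPolicy{key: "k", next: terminal}.Do("GET")) // GET signed:k
        fmt.Println(newSigningPolicy("k", terminal).Do("GET"))         // GET signed:k
    }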

// Constants ensuring that header names are correctly spelled and consistently cased.
const (
headerAuthorization = "Authorization"
27 changes: 9 additions & 18 deletions 2016-05-31/azblob/credential_token.go
@@ -24,25 +24,16 @@ func (f *TokenCredential) Token() string { return f.token.Load().(string) }
// SetToken changes the current token value
func (f *TokenCredential) SetToken(token string) { f.token.Store(token) }

// New creates a credential policy object.
func (f *TokenCredential) New(next pipeline.Policy, po *pipeline.PolicyOptions) pipeline.Policy {
return &tokenCredentialPolicy{factory: f, next: next}
}

// credentialMarker is a package-internal method that exists just to satisfy the Credential interface.
func (*TokenCredential) credentialMarker() {}

// tokenCredentialPolicy is the credential's policy object.
type tokenCredentialPolicy struct {
factory *TokenCredential
next pipeline.Policy
}

// Do implements the credential's policy interface.
func (p tokenCredentialPolicy) Do(ctx context.Context, request pipeline.Request) (pipeline.Response, error) {
if request.URL.Scheme != "https" {
panic("Token credentials require a URL using the https protocol scheme.")
}
request.Header[headerAuthorization] = []string{"Bearer " + p.factory.Token()}
return p.next.Do(ctx, request)
// New creates a credential policy object.
func (f *TokenCredential) New(next pipeline.Policy, po *pipeline.PolicyOptions) pipeline.Policy {
return pipeline.PolicyFunc(func(ctx context.Context, request pipeline.Request) (pipeline.Response, error) {
if request.URL.Scheme != "https" {
panic("Token credentials require a URL using the https protocol scheme.")
}
request.Header[headerAuthorization] = []string{"Bearer " + f.Token()}
return next.Do(ctx, request)
})
}
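
The closure reads f.Token() on every request, which matters because SetToken stores the token in an atomic.Value and may be called from a refresh goroutine while requests are in flight. A reduced sketch of that mechanism (tokenHolder is illustrative, not part of azblob):

    package main

    import (
        "fmt"
        "sync/atomic"
    )

    // tokenHolder mirrors the atomic.Value mechanics behind TokenCredential.
    type tokenHolder struct{ v atomic.Value }

    func (t *tokenHolder) Token() string         { return t.v.Load().(string) }
    func (t *tokenHolder) SetToken(token string) { t.v.Store(token) }

    func main() {
        t := &tokenHolder{}
        t.SetToken("initial")
        // A refresh goroutine may call SetToken while request goroutines
        // call Token; atomic.Value makes the handoff race-free.
        t.SetToken("refreshed")
        fmt.Println("Bearer " + t.Token()) // Bearer refreshed
    }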
141 changes: 122 additions & 19 deletions 2016-05-31/azblob/highlevel.go
@@ -8,14 +8,38 @@ import (
"net"
"net/http"

"bytes"
"github.com/Azure/azure-pipeline-go/pipeline"
"os"
"sync"
"time"
)

// UploadStreamToBlockBlobOptions identifies options used by the UploadStreamToBlockBlob function. Note that the
// BlockSize field is mandatory and must be set; other fields are optional.
type UploadStreamToBlockBlobOptions struct {
// BlockSize is mandatory. It specifies the block size to use; the maximum size is BlockBlobMaxPutBlockBytes.
BlockSize int64
// CommonResponse identifies the headers common to all blob REST API responses.
type CommonResponse interface {
// ETag returns the value for header ETag.
ETag() ETag

// LastModified returns the value for header Last-Modified.
LastModified() time.Time

// RequestID returns the value for header x-ms-request-id.
RequestID() string

// Date returns the value for header Date.
Date() time.Time

// Version returns the value for header x-ms-version.
Version() string

// Response returns the raw HTTP response object.
Response() *http.Response
}
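
PutBlob and PutBlockList return different concrete response types, so the upload helpers below return this interface instead; callers that only need the common headers can stay agnostic to which path was taken. A hypothetical consumer (logResult is not part of azblob; it assumes CommonResponse is in scope and imports fmt and time):

    // logResult works whether the upload went through PutBlob or PutBlockList.
    func logResult(r CommonResponse) {
        fmt.Printf("etag=%s request-id=%s version=%s last-modified=%s\n",
            r.ETag(), r.RequestID(), r.Version(),
            r.LastModified().Format(time.RFC1123))
    }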

// UploadToBlockBlobOptions identifies options used by the UploadBufferToBlockBlob and UploadFileToBlockBlob functions.
type UploadToBlockBlobOptions struct {
// BlockSize specifies the block size to use; the default (and maximum size) is BlockBlobMaxPutBlockBytes.
BlockSize uint64

    // Progress is a function that is invoked periodically as bytes are sent in a PutBlock call to the BlockBlobURL.
Progress pipeline.ProgressReceiver
@@ -28,45 +52,124 @@

// AccessConditions indicates the access conditions for the block blob.
AccessConditions BlobAccessConditions

// Parallelism indicates the maximum number of blocks to upload in parallel (0=default)
Parallelism uint16
}

// UploadStreamToBlockBlob uploads a stream of data in blocks to a block blob.
func UploadStreamToBlockBlob(ctx context.Context, stream io.ReaderAt, streamSize int64,
blockBlobURL BlockBlobURL, o UploadStreamToBlockBlobOptions) (*BlockBlobsPutBlockListResponse, error) {
// UploadBufferToBlockBlob uploads a buffer in blocks to a block blob.
func UploadBufferToBlockBlob(ctx context.Context, b []byte,
blockBlobURL BlockBlobURL, o UploadToBlockBlobOptions) (CommonResponse, error) {

if o.BlockSize <= 0 || o.BlockSize > BlockBlobMaxPutBlockBytes {
if o.BlockSize < 0 || o.BlockSize > BlockBlobMaxPutBlockBytes {
panic(fmt.Sprintf("BlockSize option must be > 0 and <= %d", BlockBlobMaxPutBlockBytes))
}
if o.BlockSize == 0 {
o.BlockSize = BlockBlobMaxPutBlockBytes // Default if unspecified
}
size := uint64(len(b))

if size <= BlockBlobMaxPutBlobBytes {
// If the size can fit in 1 Put Blob call, do it this way
var body io.ReadSeeker = bytes.NewReader(b)
if o.Progress != nil {
body = pipeline.NewRequestBodyProgress(body, o.Progress)
}
return blockBlobURL.PutBlob(ctx, body, o.BlobHTTPHeaders, o.Metadata, o.AccessConditions)
}

parallelism := o.Parallelism
if parallelism == 0 {
parallelism = 5 // default parallelism
}

numBlocks := ((streamSize - int64(1)) / o.BlockSize) + 1
var numBlocks uint16 = uint16(((size - 1) / o.BlockSize) + 1)
if numBlocks > BlockBlobMaxBlocks {
panic(fmt.Sprintf("The streamSize is too big or the BlockSize is too small; the number of blocks must be <= %d", BlockBlobMaxBlocks))
}

ctx, cancel := context.WithCancel(ctx)
defer cancel()
blockIDList := make([]string, numBlocks) // Base 64 encoded block IDs
blockSize := o.BlockSize

for blockNum := int64(0); blockNum < numBlocks; blockNum++ {
    putBlockChannel := make(chan func() (*BlockBlobsPutBlockResponse, error), parallelism) // Create the channel that feeds Put Block operations to the 'parallelism' goroutines
putBlockResponseChannel := make(chan error, numBlocks) // Holds each Put Block's response

// Create the goroutines that process each Put Block (in parallel)
for g := uint16(0); g < parallelism; g++ {
go func() {
for f := range putBlockChannel {
_, err := f()
putBlockResponseChannel <- err
}
}()
}

blobProgress := int64(0)
progressLock := &sync.Mutex{}

// Add each put block to the channel
for blockNum := uint16(0); blockNum < numBlocks; blockNum++ {
if blockNum == numBlocks-1 { // Last block
blockSize = streamSize - (blockNum * o.BlockSize) // Remove size of all uploaded blocks from total
blockSize = size - (uint64(blockNum) * o.BlockSize) // Remove size of all uploaded blocks from total
}
offset := uint64(blockNum) * o.BlockSize

streamOffset := blockNum * o.BlockSize
// Prepare to read the proper block/section of the file
var body io.ReadSeeker = io.NewSectionReader(stream, streamOffset, blockSize)
// Prepare to read the proper block/section of the buffer
var body io.ReadSeeker = bytes.NewReader(b[offset : offset+blockSize])
capturedBlockNum := blockNum
if o.Progress != nil {
blockProgress := int64(0)
body = pipeline.NewRequestBodyProgress(body,
func(bytesTransferred int64) { o.Progress(streamOffset + bytesTransferred) })
func(bytesTransferred int64) {
diff := bytesTransferred - blockProgress
blockProgress = bytesTransferred
progressLock.Lock()
blobProgress += diff
o.Progress(blobProgress)
progressLock.Unlock()
})
}

// Block IDs are unique values to avoid issue if 2+ clients are uploading blocks
// at the same time causeing PutBlockList to get a mix of blocks from all the clients.
// at the same time causing PutBlockList to get a mix of blocks from all the clients.
blockIDList[blockNum] = base64.StdEncoding.EncodeToString(newUUID().bytes())
_, err := blockBlobURL.PutBlock(ctx, blockIDList[blockNum], body, o.AccessConditions.LeaseAccessConditions)
putBlockChannel <- func() (*BlockBlobsPutBlockResponse, error) {
return blockBlobURL.PutBlock(ctx, blockIDList[capturedBlockNum], body, o.AccessConditions.LeaseAccessConditions)
}
}
close(putBlockChannel)

// Wait for the put blocks to complete
for blockNum := uint16(0); blockNum < numBlocks; blockNum++ {
responseError := <-putBlockResponseChannel
if responseError != nil {
cancel() // As soon as any Put Block fails, cancel all remaining Put Block calls
            return nil, responseError // No need to process any more responses
}
}
// All put blocks were successful, call Put Block List to finalize the blob
return blockBlobURL.PutBlockList(ctx, blockIDList, o.Metadata, o.BlobHTTPHeaders, o.AccessConditions)
}
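
Stripped of the blob-specific details, the new concurrency scheme is a bounded worker pool: a channel of closures whose buffer equals the worker count, a fixed set of goroutines draining it, and a results channel sized to the job count so the producer can stop at the first error (the real function additionally cancels the shared context to abort in-flight Put Blocks). A self-contained sketch of just that skeleton:

    package main

    import (
        "errors"
        "fmt"
    )

    // runBounded drains a slice of jobs through a fixed pool of workers.
    // The work channel's buffer equals the worker count; the results channel
    // is sized to the job count so no worker ever blocks sending a result.
    func runBounded(jobs []func() error, parallelism int) error {
        work := make(chan func() error, parallelism)
        results := make(chan error, len(jobs))

        for g := 0; g < parallelism; g++ {
            go func() {
                for job := range work {
                    results <- job()
                }
            }()
        }

        for _, j := range jobs {
            work <- j
        }
        close(work)

        for range jobs {
            if err := <-results; err != nil {
                return err // stop at the first failure, as the upload loop does
            }
        }
        return nil
    }

    func main() {
        jobs := []func() error{
            func() error { return nil },
            func() error { return errors.New("block 1 failed") },
        }
        fmt.Println(runBounded(jobs, 2)) // block 1 failed
    }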

// UploadFileToBlockBlob uploads a file in blocks to a block blob.
func UploadFileToBlockBlob(ctx context.Context, file *os.File,
blockBlobURL BlockBlobURL, o UploadToBlockBlobOptions) (CommonResponse, error) {

stat, err := file.Stat()
if err != nil {
return nil, err
}
m := mmf{} // Default to an empty slice; used for 0-size file
if stat.Size() != 0 {
m, err = newMMF(file, false, 0, int(stat.Size()))
if err != nil {
return nil, err
}
defer m.unmap()
}
return blockBlobURL.PutBlockList(ctx, blockIDList, o.Metadata, o.BlobHTTPHeaders, o.AccessConditions)
return UploadBufferToBlockBlob(ctx, m, blockBlobURL, o)
}
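
A hedged usage sketch for the new helpers, assuming the package's NewSharedKeyCredential, NewPipeline, and NewBlockBlobURL constructors and the import path implied by the file paths above; the account, key, container, and file names are placeholders:

    package main

    import (
        "context"
        "log"
        "net/url"
        "os"

        "github.com/Azure/azure-storage-blob-go/2016-05-31/azblob"
    )

    func main() {
        // Placeholders: substitute a real account, key, container, and file.
        credential := azblob.NewSharedKeyCredential("myaccount", "<base64-account-key>")
        p := azblob.NewPipeline(credential, azblob.PipelineOptions{})

        u, _ := url.Parse("https://myaccount.blob.core.windows.net/mycontainer/myblob")
        blobURL := azblob.NewBlockBlobURL(*u, p)

        file, err := os.Open("big-file.bin")
        if err != nil {
            log.Fatal(err)
        }
        defer file.Close()

        // BlockSize 0 now means "use the maximum"; Parallelism 0 means 5 workers.
        _, err = azblob.UploadFileToBlockBlob(context.Background(), file, blobURL,
            azblob.UploadToBlockBlobOptions{Parallelism: 8})
        if err != nil {
            log.Fatal(err)
        }
    }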

// DownloadStreamOptions is used to configure a call to NewDownloadBlobToStream to download a large stream with intelligent retries.
25 changes: 25 additions & 0 deletions 2016-05-31/azblob/mmap_linux.go
@@ -0,0 +1,25 @@
package azblob

import (
"os"
"syscall"
)

type mmf []byte

func newMMF(file *os.File, writable bool, offset int64, length int) (mmf, error) {
prot, flags := syscall.PROT_READ, syscall.MAP_SHARED // Assume read-only
if writable {
prot, flags = syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED
}
addr, err := syscall.Mmap(int(file.Fd()), offset, length, prot, flags)
return mmf(addr), err
}

func (m *mmf) unmap() {
err := syscall.Munmap(*m)
*m = nil
if err != nil {
panic(err)
}
}
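
The _linux.go and _windows.go filename suffixes select between the two implementations at build time; Go treats them as implicit GOOS build constraints. The Linux version is a thin wrapper over syscall.Mmap, shown standalone below (a sketch, separate from azblob; note that mapping a zero-length file fails with EINVAL, which is why UploadFileToBlockBlob special-cases empty files):

    //go:build linux

    package main

    import (
        "fmt"
        "os"
        "syscall"
    )

    func main() {
        f, err := os.Open("/etc/hostname") // any non-empty file works
        if err != nil {
            panic(err)
        }
        defer f.Close()

        st, err := f.Stat()
        if err != nil {
            panic(err)
        }

        // Read-only shared mapping of the whole file, as newMMF sets up for uploads.
        data, err := syscall.Mmap(int(f.Fd()), 0, int(st.Size()), syscall.PROT_READ, syscall.MAP_SHARED)
        if err != nil {
            panic(err)
        }
        defer syscall.Munmap(data)

        n := len(data)
        if n > 16 {
            n = 16
        }
        fmt.Printf("%d bytes mapped, starting %q\n", len(data), data[:n])
    }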
38 changes: 38 additions & 0 deletions 2016-05-31/azblob/mmap_windows.go
@@ -0,0 +1,38 @@
package azblob

import (
"os"
"reflect"
"syscall"
"unsafe"
)

type mmf []byte

func newMMF(file *os.File, writable bool, offset int64, length int) (mmf, error) {
prot, access := uint32(syscall.PAGE_READONLY), uint32(syscall.FILE_MAP_READ) // Assume read-only
if writable {
prot, access = uint32(syscall.PAGE_READWRITE), uint32(syscall.FILE_MAP_WRITE)
}
hMMF, errno := syscall.CreateFileMapping(syscall.Handle(file.Fd()), nil, prot, uint32(int64(length)>>32), uint32(int64(length)&0xffffffff), nil)
if hMMF == 0 {
return nil, os.NewSyscallError("CreateFileMapping", errno)
}
defer syscall.CloseHandle(hMMF)
addr, errno := syscall.MapViewOfFile(hMMF, access, uint32(offset>>32), uint32(offset&0xffffffff), uintptr(length))
m := mmf{}
h := (*reflect.SliceHeader)(unsafe.Pointer(&m))
h.Data = addr
h.Len = length
h.Cap = h.Len
return m, nil
}

func (m *mmf) unmap() {
addr := uintptr(unsafe.Pointer(&(([]byte)(*m)[0])))
*m = mmf{}
err := syscall.UnmapViewOfFile(addr)
if err != nil {
panic(err)
}
}
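
The Windows version has to materialize a []byte over memory the Go runtime does not own, which it does by writing through a reflect.SliceHeader (note also that the errno from MapViewOfFile goes unchecked here; a failed mapping would only surface when the slice is first touched). On Go 1.17 and later the same conversion is spelled more safely with unsafe.Slice; a tiny illustration of the equivalence, using ordinary memory rather than a file mapping:

    package main

    import (
        "fmt"
        "unsafe"
    )

    func main() {
        backing := [4]byte{'a', 'b', 'c', 'd'}
        addr := unsafe.Pointer(&backing[0])

        // What the reflect.SliceHeader dance accomplishes: view raw memory as
        // a []byte of a given length. unsafe.Slice (Go 1.17+) does it in one step.
        view := unsafe.Slice((*byte)(addr), len(backing))
        fmt.Println(string(view)) // abcd
    }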
8 changes: 4 additions & 4 deletions 2016-05-31/azblob/parsing_urls.go
@@ -63,11 +63,11 @@ func NewBlobURLParts(u url.URL) BlobURLParts {
}

type caseInsensitiveValues url.Values // map[string][]string
func (v caseInsensitiveValues) Get(key string) ([]string, bool) {
func (values caseInsensitiveValues) Get(key string) ([]string, bool) {
key = strings.ToLower(key)
for key, value := range v {
if strings.ToLower(key) == key {
return value, true
for k, v := range values {
if strings.ToLower(k) == key {
return v, true
}
}
return []string{}, false
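
The bug fixed here is variable shadowing: the old loop redeclared key, so the condition compared each map key's lowercased form against that same map key — the caller's lowercased search key was never consulted, and the lookup returned whichever entry happened to have an all-lowercase key. The corrected lookup, runnable in isolation (a sketch mirroring the type above):

    package main

    import (
        "fmt"
        "net/url"
        "strings"
    )

    type caseInsensitiveValues url.Values

    // Get performs the corrected case-insensitive lookup: the loop variables
    // k and v no longer shadow the lowercased search key.
    func (values caseInsensitiveValues) Get(key string) ([]string, bool) {
        key = strings.ToLower(key)
        for k, v := range values {
            if strings.ToLower(k) == key {
                return v, true
            }
        }
        return []string{}, false
    }

    func main() {
        q, _ := url.ParseQuery("SigNature=abc&sv=2016-05-31")
        v, ok := caseInsensitiveValues(q).Get("signature")
        fmt.Println(v, ok) // [abc] true — the shadowed-variable version missed this
    }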