Skip to content

Commit

Permalink
Calling SyscallN directly when dealing with pointer-pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljoos committed May 15, 2023
1 parent 1af1852 commit 18a76dc
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 167 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ go 1.13

require (
github.com/stretchr/testify v1.8.1
golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c
golang.org/x/sys v0.8.0
)
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c h1:Lyn7+CqXIiC+LOR9aHD6jDK+hPcmAuCfuXztd1v4w1Q=
golang.org/x/sys v0.0.0-20210819135213-f52c844e1c1c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
Expand Down
28 changes: 16 additions & 12 deletions sys.go
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
//go:build windows
// +build windows

package wincred

import (
"reflect"
"syscall"
"unsafe"

syscall "golang.org/x/sys/windows"
"golang.org/x/sys/windows"
)

var (
modadvapi32 = syscall.NewLazyDLL("advapi32.dll")
procCredRead proc = modadvapi32.NewProc("CredReadW")
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
procCredRead = modadvapi32.NewProc("CredReadW")
procCredWrite proc = modadvapi32.NewProc("CredWriteW")
procCredDelete proc = modadvapi32.NewProc("CredDeleteW")
procCredFree proc = modadvapi32.NewProc("CredFree")
procCredEnumerate proc = modadvapi32.NewProc("CredEnumerateW")
procCredEnumerate = modadvapi32.NewProc("CredEnumerateW")
)

// Interface for syscall.Proc: helps testing
Expand All @@ -29,7 +31,7 @@ type sysCREDENTIAL struct {
Type uint32
TargetName *uint16
Comment *uint16
LastWritten syscall.Filetime
LastWritten windows.Filetime
CredentialBlobSize uint32
CredentialBlob uintptr
Persist uint32
Expand Down Expand Up @@ -59,15 +61,16 @@ const (
sysCRED_TYPE_DOMAIN_EXTENDED sysCRED_TYPE = 0x6

// https://docs.microsoft.com/en-us/windows/desktop/Debug/system-error-codes
sysERROR_NOT_FOUND = syscall.Errno(1168)
sysERROR_INVALID_PARAMETER = syscall.Errno(87)
sysERROR_NOT_FOUND = windows.Errno(1168)
sysERROR_INVALID_PARAMETER = windows.Errno(87)
)

// https://docs.microsoft.com/en-us/windows/desktop/api/wincred/nf-wincred-credreadw
func sysCredRead(targetName string, typ sysCRED_TYPE) (*Credential, error) {
var pcred *sysCREDENTIAL
targetNamePtr, _ := syscall.UTF16PtrFromString(targetName)
ret, _, err := procCredRead.Call(
targetNamePtr, _ := windows.UTF16PtrFromString(targetName)
ret, _, err := syscall.SyscallN(
procCredRead.Addr(),
uintptr(unsafe.Pointer(targetNamePtr)),
uintptr(typ),
0,
Expand Down Expand Up @@ -98,7 +101,7 @@ func sysCredWrite(cred *Credential, typ sysCRED_TYPE) error {

// https://docs.microsoft.com/en-us/windows/desktop/api/wincred/nf-wincred-creddeletew
func sysCredDelete(cred *Credential, typ sysCRED_TYPE) error {
targetNamePtr, _ := syscall.UTF16PtrFromString(cred.TargetName)
targetNamePtr, _ := windows.UTF16PtrFromString(cred.TargetName)
ret, _, err := procCredDelete.Call(
uintptr(unsafe.Pointer(targetNamePtr)),
uintptr(typ),
Expand All @@ -117,9 +120,10 @@ func sysCredEnumerate(filter string, all bool) ([]*Credential, error) {
var pcreds uintptr
var filterPtr *uint16
if !all {
filterPtr, _ = syscall.UTF16PtrFromString(filter)
filterPtr, _ = windows.UTF16PtrFromString(filter)
}
ret, _, err := procCredEnumerate.Call(
ret, _, err := syscall.SyscallN(
procCredEnumerate.Addr(),
uintptr(unsafe.Pointer(filterPtr)),
0,
uintptr(unsafe.Pointer(&count)),
Expand Down
153 changes: 1 addition & 152 deletions sys_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
//go:build windows
// +build windows

package wincred

import (
"errors"
"testing"
"unsafe"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
Expand All @@ -32,80 +32,6 @@ func (t *mockProc) Call(a ...uintptr) (r1, r2 uintptr, lastErr error) {
return uintptr(args.Int(0)), uintptr(args.Int(1)), args.Error(2)
}

func TestSysCredRead_MockFailure(t *testing.T) {
// The test error
testError := errors.New("test error")
// Mock `CreadRead`: returns failure state and the error
mockCredRead := new(mockProc)
mockCredRead.On("Call", mock.AnythingOfType("[]uintptr")).Return(0, 0, testError)
mockCredRead.Setup(&procCredRead)
defer mockCredRead.TearDown()
// Mock `CredFree`: Must not be called
mockCredFree := new(mockProc)
mockCredFree.On("Call", mock.AnythingOfType("[]uintptr")).Return(0, 0, nil)
mockCredFree.Setup(&procCredFree)
defer mockCredFree.TearDown()

// Test it:
var res *Credential
var err error
assert.NotPanics(t, func() { res, err = sysCredRead("foo", sysCRED_TYPE_GENERIC) })
assert.Nil(t, res)
assert.NotNil(t, err)
assert.Equal(t, "test error", err.Error())
mockCredRead.AssertNumberOfCalls(t, "Call", 1)
mockCredFree.AssertNumberOfCalls(t, "Call", 0)
}

func TestSysCredRead_Mock(t *testing.T) {
// prepare some test data
cred := new(Credential)
cred.TargetName = "Foo"
cred.Comment = "Bar"
cred.CredentialBlob = []byte{1, 2, 3}
credSys := sysFromCredential(cred)
t.Log(credSys) // Workaround to keep the object alive

// Mock `CreadRead`: returns success and sets the pointer to the prepared sysCred struct
mockCredRead := new(mockProc)
mockCredRead.
On("Call", mock.AnythingOfType("[]uintptr")).
Return(1, 0, nil).
Run(func(args mock.Arguments) {
arg := args.Get(0).([]uintptr)
assert.Equal(t, 4, len(arg))
*(**sysCREDENTIAL)(unsafe.Pointer(arg[3])) = credSys
})
mockCredRead.Setup(&procCredRead)
defer mockCredRead.TearDown()

// Mock `CredFree`: Must be called as well with the correct pointer
mockCredFree := new(mockProc)
mockCredFree.
On("Call", mock.AnythingOfType("[]uintptr")).
Return(0, 0, nil).
Run(func(args mock.Arguments) {
arg := args.Get(0).([]uintptr)
assert.Equal(t, 1, len(arg))
assert.Equal(t, uintptr(unsafe.Pointer(credSys)), arg[0])
})
mockCredFree.Setup(&procCredFree)
defer mockCredFree.TearDown()

// Test it:
var res *Credential
var err error
assert.NotPanics(t, func() { res, err = sysCredRead("Foo", sysCRED_TYPE_GENERIC) })
mockCredRead.AssertNumberOfCalls(t, "Call", 1)
mockCredFree.AssertNumberOfCalls(t, "Call", 1)
assert.NotNil(t, res)
assert.Nil(t, err)
assert.Equal(t, "Foo", res.TargetName)
assert.Equal(t, "Bar", res.Comment)
assert.Equal(t, []byte{1, 2, 3}, res.CredentialBlob)
assert.NotEqual(t, &cred, &res)
}

func TestSysCredWrite_MockFailure(t *testing.T) {
// Mock `CreadWrite`: returns failure state and the error
mockCredWrite := new(mockProc)
Expand Down Expand Up @@ -163,80 +89,3 @@ func TestSysCredDelete_Mock(t *testing.T) {
assert.Nil(t, err)
mockCredDelete.AssertNumberOfCalls(t, "Call", 1)
}

func TestSysCredEnumerate_MockFailure(t *testing.T) {
// The test error
testError := errors.New("test error")
// Mock `CreadEnumerate`: returns failure state and the error
mockCredEnumerate := new(mockProc)
mockCredEnumerate.On("Call", mock.AnythingOfType("[]uintptr")).Return(0, 0, testError)
mockCredEnumerate.Setup(&procCredEnumerate)
defer mockCredEnumerate.TearDown()
// Mock `CredFree`: Must not be called
mockCredFree := new(mockProc)
mockCredFree.On("Call", mock.AnythingOfType("[]uintptr")).Return(0, 0, nil)
mockCredFree.Setup(&procCredFree)
defer mockCredFree.TearDown()

// Test it:
var res []*Credential
var err error
assert.NotPanics(t, func() { res, err = sysCredEnumerate("", true) })
assert.Nil(t, res)
assert.NotNil(t, err)
assert.Equal(t, "test error", err.Error())
mockCredEnumerate.AssertNumberOfCalls(t, "Call", 1)
mockCredFree.AssertNumberOfCalls(t, "Call", 0)
}

func TestSysCredEnumerate_Mock(t *testing.T) {
// prepare some test data
creds := []*Credential{new(Credential), new(Credential)}
creds[0].TargetName = "Foo"
creds[1].TargetName = "Bar"
credsSys := [](*sysCREDENTIAL){
sysFromCredential(creds[0]),
sysFromCredential(creds[1]),
}
t.Log(credsSys[0]) // Workaround to keep the object alive
t.Log(credsSys[1]) // Workaround to keep the object alive

// Mock `CreadEnumerate`: returns success and sets the pointer to the prepared sysCreds array
mockCredEnumerate := new(mockProc)
mockCredEnumerate.
On("Call", mock.AnythingOfType("[]uintptr")).
Return(1, 0, nil).
Run(func(args mock.Arguments) {
arg := args.Get(0).([]uintptr)
assert.Equal(t, 4, len(arg))
*(*int)(unsafe.Pointer(arg[2])) = len(credsSys)
*(*[]*sysCREDENTIAL)(unsafe.Pointer(arg[3])) = credsSys
})
mockCredEnumerate.Setup(&procCredEnumerate)
defer mockCredEnumerate.TearDown()

// Mock `CredFree`: Must be called as well with the correct pointer
mockCredFree := new(mockProc)
mockCredFree.
On("Call", mock.AnythingOfType("[]uintptr")).
Return(0, 0, nil).
Run(func(args mock.Arguments) {
arg := args.Get(0).([]uintptr)
assert.Equal(t, 1, len(arg))
assert.Equal(t, uintptr(unsafe.Pointer(&credsSys[0])), arg[0])
})
mockCredFree.Setup(&procCredFree)
defer mockCredFree.TearDown()

// Test it:
var res []*Credential
var err error
assert.NotPanics(t, func() { res, err = sysCredEnumerate("", true) })
mockCredEnumerate.AssertNumberOfCalls(t, "Call", 1)
mockCredFree.AssertNumberOfCalls(t, "Call", 1)
assert.NotNil(t, res)
assert.Nil(t, err)
assert.Equal(t, 2, len(res))
assert.Equal(t, "Foo", res[0].TargetName)
assert.Equal(t, "Bar", res[1].TargetName)
}

0 comments on commit 18a76dc

Please sign in to comment.