diff --git a/cmd/pam-moduler/moduler.go b/cmd/pam-moduler/moduler.go index a8b5bfb6..fd3da218 100644 --- a/cmd/pam-moduler/moduler.go +++ b/cmd/pam-moduler/moduler.go @@ -68,6 +68,7 @@ var ( moduleBuildFlags = flag.String("build-flags", "", "comma-separated list of go build flags to use when generating the module") moduleBuildTags = flag.String("build-tags", "", "comma-separated list of build tags to use when generating the module") noMain = flag.Bool("no-main", false, "whether to add an empty main to generated file") + parallelConv = flag.Bool("parallel-conv", false, "whether to support performing PAM conversations in parallel") ) // Usage is a replacement usage function for the flags package. @@ -136,6 +137,7 @@ func main() { generateTags: generateTags, noMain: *noMain, typeName: *typeName, + parallelConv: *parallelConv, } // Print the header and package clause. @@ -168,6 +170,7 @@ type Generator struct { generateTags []string buildFlags []string noMain bool + parallelConv bool } func (g *Generator) Printf(format string, args ...interface{}) { @@ -185,6 +188,11 @@ func (g *Generator) generate() { buildTagsArg = fmt.Sprintf("-tags %s", strings.Join(g.generateTags, ",")) } + var transactionCreator = "NewModuleTransaction" + if g.parallelConv { + transactionCreator = "NewModuleTransactionParallelConv" + } + vFuncs := map[string]string{ "authenticate": "Authenticate", "setcred": "SetCred", @@ -247,7 +255,7 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, return pam.Ignore } - mt := pam.NewModuleTransaction(pam.NativeHandle(pamh)) + mt := pam.%s(pam.NativeHandle(pamh)) ret, err := mt.InvokeHandler(moduleFunc, pam.Flags(flags), sliceFromArgv(argc, argv)) @@ -257,7 +265,7 @@ func handlePamCall(pamh *C.pam_handle_t, flags C.int, argc C.int, return ret } -`) +`, transactionCreator) for cName, goName := range vFuncs { g.Printf(` diff --git a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go index 23cc2624..70867f04 100644 --- a/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go +++ b/cmd/pam-moduler/tests/integration-tester-module/integration-tester-module.go @@ -1,4 +1,4 @@ -//go:generate go run github.com/msteinert/pam/cmd/pam-moduler -type integrationTesterModule +//go:generate go run github.com/msteinert/pam/cmd/pam-moduler -type integrationTesterModule -parallel-conv //go:generate go generate --skip="pam_module.go" package main diff --git a/module-transaction.go b/module-transaction.go index 24e2375c..71e6fabd 100644 --- a/module-transaction.go +++ b/module-transaction.go @@ -20,6 +20,7 @@ import ( "fmt" "runtime" "runtime/cgo" + "sync" "sync/atomic" "unsafe" ) @@ -52,6 +53,7 @@ type ModuleTransaction interface { // ModuleTransaction is the module-side handle for a PAM transaction type moduleTransaction struct { transactionBase + convMutex *sync.Mutex } // ModuleHandler is an interface for objects that can be used to create @@ -65,9 +67,19 @@ type ModuleHandler interface { SetCred(ModuleTransaction, Flags, []string) error } -// NewModuleTransaction allows initializing a transaction from the module side +// NewModuleTransaction allows initializing a transaction from the module side. +// Using this transaction conversations can't be performed in parallel and a +// mutex is used to ensure this is the case. func NewModuleTransaction(handle NativeHandle) ModuleTransaction { - return &moduleTransaction{transactionBase{handle: handle}} + return &moduleTransaction{transactionBase{handle: handle}, &sync.Mutex{}} +} + +// NewModuleTransaction allows initializing a transaction from the module side. +// Conversations using this transaction can be multi-thread, but this requires +// the application loading the module to support this, otherwise we may just +// break their assumptions. +func NewModuleTransactionParallelConv(handle NativeHandle) ModuleTransaction { + return &moduleTransaction{transactionBase{handle: handle}, nil} } func (m *moduleTransaction) InvokeHandler(handler ModuleHandlerFunc, @@ -460,6 +472,10 @@ func (m *moduleTransaction) startConvMultiImpl(iface moduleTransactionIface, } } + if m.convMutex != nil { + m.convMutex.Lock() + defer m.convMutex.Unlock() + } var cResponses *C.struct_pam_response if err := m.handlePamStatus( iface.startConv(conv, C.int(len(requests)), cMessages, &cResponses)); err != nil { diff --git a/module-transaction_test.go b/module-transaction_test.go index 9e19a1a6..5de99ac4 100644 --- a/module-transaction_test.go +++ b/module-transaction_test.go @@ -296,11 +296,9 @@ func Test_ModuleTransaction_InvokeHandler(t *testing.T) { } } -func Test_MockModuleTransaction(t *testing.T) { +func testMockModuleTransaction(t *testing.T, mt *moduleTransaction) { t.Parallel() - mt := NewModuleTransaction(nil).(*moduleTransaction) - tests := map[string]struct { testFunc func(mock *mockModuleTransaction) (any, error) mockExpectations mockModuleTransactionExpectations @@ -857,3 +855,12 @@ func Test_MockModuleTransaction(t *testing.T) { }) } } + +func Test_MockModuleTransaction(t *testing.T) { + testMockModuleTransaction(t, NewModuleTransaction(nil).(*moduleTransaction)) +} + +func Test_MockModuleTransactionParallelConv(t *testing.T) { + testMockModuleTransaction(t, + NewModuleTransactionParallelConv(nil).(*moduleTransaction)) +}