From 005adff7ec6b049db78ddddb022c6a286a318a2b Mon Sep 17 00:00:00 2001 From: Kei Kamikawa Date: Thu, 27 Oct 2022 13:22:28 +0900 Subject: [PATCH] breaking change around VirtioSocketListener --- osversion_test.go | 4 +- socket.go | 131 +++++++++++++++++++++++++++++++++------------- socket_test.go | 20 +++---- virtualization.h | 5 +- virtualization.m | 18 +++++-- 5 files changed, 125 insertions(+), 53 deletions(-) diff --git a/osversion_test.go b/osversion_test.go index 4bf8e22..0d549e1 100644 --- a/osversion_test.go +++ b/osversion_test.go @@ -76,8 +76,8 @@ func TestAvailableVersion(t *testing.T) { _, err := NewVirtioSocketDeviceConfiguration() return err }, - "NewVirtioSocketListener": func() error { - _, err := NewVirtioSocketListener(nil) + "(*VirtioSocketDevice).Listen": func() error { + _, err := (*VirtioSocketDevice)(nil).Listen(1) return err }, "NewDiskImageStorageDeviceAttachment": func() error { diff --git a/socket.go b/socket.go index c911a03..d6b182b 100644 --- a/socket.go +++ b/socket.go @@ -7,10 +7,12 @@ package vz */ import "C" import ( + "fmt" "net" "os" "runtime" "runtime/cgo" + "sync" "time" "unsafe" ) @@ -83,18 +85,44 @@ func newVirtioSocketDevice(ptr, dispatchQueue unsafe.Pointer) *VirtioSocketDevic } } -// SetSocketListenerForPort configures an object to monitor the specified port for new connections. +// Listen creates a new VirtioSocketListener which is a struct that listens for port-based connection requests +// from the guest operating system. // -// see: https://developer.apple.com/documentation/virtualization/vzvirtiosocketdevice/3656679-setsocketlistener?language=objc -func (v *VirtioSocketDevice) SetSocketListenerForPort(listener *VirtioSocketListener, port uint32) { - C.VZVirtioSocketDevice_setSocketListenerForPort(v.Ptr(), v.dispatchQueue, listener.Ptr(), C.uint32_t(port)) -} - -// RemoveSocketListenerForPort removes the listener object from the specfied port. +// Be sure to close the listener by calling `VirtioSocketListener.Close` after used this one. // -// see: https://developer.apple.com/documentation/virtualization/vzvirtiosocketdevice/3656678-removesocketlistenerforport?language=objc -func (v *VirtioSocketDevice) RemoveSocketListenerForPort(listener *VirtioSocketListener, port uint32) { - C.VZVirtioSocketDevice_removeSocketListenerForPort(v.Ptr(), v.dispatchQueue, C.uint32_t(port)) +// This is only supported on macOS 11 and newer, ErrUnsupportedOSVersion will +// be returned on older versions. +func (v *VirtioSocketDevice) Listen(port uint32) (*VirtioSocketListener, error) { + if macosMajorVersionLessThan(11) { + return nil, ErrUnsupportedOSVersion + } + + ch := make(chan accept, 1) // should I increase more caps? + + handler := cgo.NewHandle(func(conn *VirtioSocketConnection, err error) { + ch <- accept{conn, err} + }) + ptr := C.newVZVirtioSocketListener( + unsafe.Pointer(&handler), + ) + listener := &VirtioSocketListener{ + pointer: pointer{ + ptr: ptr, + }, + vsockDevice: v, + port: port, + handler: handler, + acceptch: ch, + } + + C.VZVirtioSocketDevice_setSocketListenerForPort( + v.Ptr(), + v.dispatchQueue, + listener.Ptr(), + C.uint32_t(port), + ) + + return listener, nil } //export connectionHandler @@ -121,6 +149,12 @@ func connectionHandler(connPtr, errPtr, cgoHandlerPtr unsafe.Pointer) { func (v *VirtioSocketDevice) ConnectToPort(port uint32, fn func(conn *VirtioSocketConnection, err error)) { cgoHandler := cgo.NewHandle(fn) C.VZVirtioSocketDevice_connectToPort(v.Ptr(), v.dispatchQueue, C.uint32_t(port), unsafe.Pointer(&cgoHandler)) + runtime.KeepAlive(v) +} + +type accept struct { + conn *VirtioSocketConnection + err error } // VirtioSocketListener a struct that listens for port-based connection requests from the guest operating system. @@ -128,44 +162,71 @@ func (v *VirtioSocketDevice) ConnectToPort(port uint32, fn func(conn *VirtioSock // see: https://developer.apple.com/documentation/virtualization/vzvirtiosocketlistener?language=objc type VirtioSocketListener struct { pointer + vsockDevice *VirtioSocketDevice + handler cgo.Handle + port uint32 + acceptch chan accept + closeOnce sync.Once } -var shouldAcceptNewConnectionHandlers = map[unsafe.Pointer]func(conn *VirtioSocketConnection, err error) bool{} +var _ net.Listener = (*VirtioSocketListener)(nil) -// NewVirtioSocketListener creates a new VirtioSocketListener with connection handler. -// -// The handler is executed asynchronously. Be sure to close the connection used in the handler by calling `conn.Close`. -// This is to prevent connection leaks. -// -// This is only supported on macOS 11 and newer, ErrUnsupportedOSVersion will -// be returned on older versions. -func NewVirtioSocketListener(handler func(conn *VirtioSocketConnection, err error)) (*VirtioSocketListener, error) { - if macosMajorVersionLessThan(11) { - return nil, ErrUnsupportedOSVersion - } +// Accept implements the Accept method in the Listener interface; it waits for the next call and returns a net.Conn. +func (v *VirtioSocketListener) Accept() (net.Conn, error) { + return v.AcceptVirtioSocketConnection() +} - ptr := C.newVZVirtioSocketListener() - listener := &VirtioSocketListener{ - pointer: pointer{ - ptr: ptr, - }, - } +// AcceptVirtioSocketConnection accepts the next incoming call and returns the new connection. +func (v *VirtioSocketListener) AcceptVirtioSocketConnection() (*VirtioSocketConnection, error) { + result := <-v.acceptch + return result.conn, result.err +} + +// Close stops listening on the virtio socket. +func (v *VirtioSocketListener) Close() error { + v.closeOnce.Do(func() { + C.VZVirtioSocketDevice_removeSocketListenerForPort( + v.vsockDevice.Ptr(), + v.vsockDevice.dispatchQueue, + C.uint32_t(v.port), + ) + v.handler.Delete() + }) + return nil +} - shouldAcceptNewConnectionHandlers[ptr] = func(conn *VirtioSocketConnection, err error) bool { - go handler(conn, err) - return true // must be connected +// Addr returns the listener's network address, a *VirtioSocketListenerAddr. +func (v *VirtioSocketListener) Addr() net.Addr { + const VMADDR_CID_HOST = 2 // copied from unix pacage + return &VirtioSocketListenerAddr{ + CID: VMADDR_CID_HOST, + Port: v.port, } +} - return listener, nil +// VirtioSocketListenerAddr represents a network end point address for the vsock protocol. +type VirtioSocketListenerAddr struct { + CID uint32 + Port uint32 } +var _ net.Addr = (*VirtioSocketListenerAddr)(nil) + +// Network returns "vsock". +func (a *VirtioSocketListenerAddr) Network() string { return "vsock" } + +// String returns string of ":" +func (a *VirtioSocketListenerAddr) String() string { return fmt.Sprintf("%d:%d", a.CID, a.Port) } + //export shouldAcceptNewConnectionHandler -func shouldAcceptNewConnectionHandler(listenerPtr, connPtr, devicePtr unsafe.Pointer) C.bool { - _ = devicePtr // NOTO(codehex): Is this really required? How to use? +func shouldAcceptNewConnectionHandler(cgoHandlerPtr, connPtr, devicePtr unsafe.Pointer) C.bool { + cgoHandler := *(*cgo.Handle)(cgoHandlerPtr) + handler := cgoHandler.Value().(func(*VirtioSocketConnection, error)) // see: startHandler conn, err := newVirtioSocketConnection(connPtr) - return (C.bool)(shouldAcceptNewConnectionHandlers[listenerPtr](conn, err)) + go handler(conn, err) + return (C.bool)(true) } // VirtioSocketConnection is a port-based connection between the guest operating system and the host computer. diff --git a/socket_test.go b/socket_test.go index e316cd2..8025d15 100644 --- a/socket_test.go +++ b/socket_test.go @@ -22,16 +22,23 @@ func TestVirtioSocketListener(t *testing.T) { wantData := "hello" done := make(chan struct{}) - listener, err := vz.NewVirtioSocketListener(func(conn *vz.VirtioSocketConnection, err error) { + listener, err := socketDevice.Listen(uint32(port)) + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + go func() { defer close(done) + conn, err := listener.Accept() if err != nil { t.Errorf("failed to accept connection: %v", err) return } defer conn.Close() - destPort := conn.DestinationPort() + destPort := conn.(*vz.VirtioSocketConnection).DestinationPort() if port != int(destPort) { t.Errorf("want destination port %d but got %d", destPort, port) return @@ -48,12 +55,7 @@ func TestVirtioSocketListener(t *testing.T) { if wantData != got { t.Errorf("want %q but got %q", wantData, got) } - }) - if err != nil { - t.Fatal(err) - } - - socketDevice.SetSocketListenerForPort(listener, uint32(port)) + }() session := container.NewSession(t) var buf bytes.Buffer @@ -69,6 +71,4 @@ func TestVirtioSocketListener(t *testing.T) { case <-time.After(3 * time.Second): t.Fatalf("timeout connection handling after accepted") } - - socketDevice.RemoveSocketListenerForPort(listener, uint32(port)) } diff --git a/virtualization.h b/virtualization.h index d1c9780..6a4616d 100644 --- a/virtualization.h +++ b/virtualization.h @@ -13,7 +13,7 @@ void virtualMachineCompletionHandler(void *cgoHandler, void *errPtr); void connectionHandler(void *connection, void *err, void *cgoHandlerPtr); void changeStateOnObserver(int state, void *cgoHandler); -bool shouldAcceptNewConnectionHandler(void *listener, void *connection, void *socketDevice); +bool shouldAcceptNewConnectionHandler(void *cgoHandler, void *connection, void *socketDevice); @interface Observer : NSObject - (void)observeValueForKeyPath:(NSString *)keyPath ofObject:(id)object change:(NSDictionary *)change context:(void *)context; @@ -21,6 +21,7 @@ bool shouldAcceptNewConnectionHandler(void *listener, void *connection, void *so /* VZVirtioSocketListener */ @interface VZVirtioSocketListenerDelegateImpl : NSObject +- (instancetype)initWithHandler:(void *)cgoHandler; - (BOOL)listener:(VZVirtioSocketListener *)listener shouldAcceptNewConnection:(VZVirtioSocketConnection *)connection fromSocketDevice:(VZVirtioSocketDevice *)socketDevice; @end @@ -81,7 +82,7 @@ void *newVZVirtioSocketDeviceConfiguration(); void *newVZMACAddress(const char *macAddress); void *newRandomLocallyAdministeredVZMACAddress(); const char *getVZMACAddressString(void *macAddress); -void *newVZVirtioSocketListener(); +void *newVZVirtioSocketListener(void *cgoHandlerPtr); void *newVZSharedDirectory(const char *dirPath, bool readOnly); void *newVZSingleDirectoryShare(void *sharedDirectory); void *newVZMultipleDirectoryShare(void *sharedDirectories); diff --git a/virtualization.m b/virtualization.m index f218b29..3f2b4ca 100644 --- a/virtualization.m +++ b/virtualization.m @@ -36,10 +36,20 @@ - (void)observeValueForKeyPath:(NSString *)keyPath ofObject:(id)object change:(N } @end -@implementation VZVirtioSocketListenerDelegateImpl +@implementation VZVirtioSocketListenerDelegateImpl { + void *_cgoHandler; +} + +- (instancetype)initWithHandler:(void *)cgoHandler +{ + self = [super init]; + _cgoHandler = cgoHandler; + return self; +} + - (BOOL)listener:(VZVirtioSocketListener *)listener shouldAcceptNewConnection:(VZVirtioSocketConnection *)connection fromSocketDevice:(VZVirtioSocketDevice *)socketDevice; { - return (BOOL)shouldAcceptNewConnectionHandler(listener, connection, socketDevice); + return (BOOL)shouldAcceptNewConnectionHandler(_cgoHandler, connection, socketDevice); } @end @@ -743,11 +753,11 @@ void setStreamsVZVirtioSoundDeviceConfiguration(void *audioDeviceConfiguration, @see VZVirtioSocketDevice @see VZVirtioSocketListenerDelegate */ -void *newVZVirtioSocketListener() +void *newVZVirtioSocketListener(void *cgoHandlerPtr) { if (@available(macOS 11, *)) { VZVirtioSocketListener *ret = [[VZVirtioSocketListener alloc] init]; - [ret setDelegate:[[VZVirtioSocketListenerDelegateImpl alloc] init]]; + [ret setDelegate:[[VZVirtioSocketListenerDelegateImpl alloc] initWithHandler:cgoHandlerPtr]]; return ret; }