Skip to content

Commit

Permalink
Merge pull request #87 from Code-Hex/fix/listener-interface
Browse files Browse the repository at this point in the history
breaking change around VirtioSocketListener
  • Loading branch information
Code-Hex authored Oct 27, 2022
2 parents e69f5f8 + 005adff commit b33b17e
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 53 deletions.
4 changes: 2 additions & 2 deletions osversion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ func TestAvailableVersion(t *testing.T) {
_, err := NewVirtioSocketDeviceConfiguration()
return err
},
"NewVirtioSocketListener": func() error {
_, err := NewVirtioSocketListener(nil)
"(*VirtioSocketDevice).Listen": func() error {
_, err := (*VirtioSocketDevice)(nil).Listen(1)
return err
},
"NewDiskImageStorageDeviceAttachment": func() error {
Expand Down
131 changes: 96 additions & 35 deletions socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ package vz
*/
import "C"
import (
"fmt"
"net"
"os"
"runtime"
"runtime/cgo"
"sync"
"time"
"unsafe"
)
Expand Down Expand Up @@ -83,18 +85,44 @@ func newVirtioSocketDevice(ptr, dispatchQueue unsafe.Pointer) *VirtioSocketDevic
}
}

// SetSocketListenerForPort configures an object to monitor the specified port for new connections.
// Listen creates a new VirtioSocketListener which is a struct that listens for port-based connection requests
// from the guest operating system.
//
// see: https://developer.apple.com/documentation/virtualization/vzvirtiosocketdevice/3656679-setsocketlistener?language=objc
func (v *VirtioSocketDevice) SetSocketListenerForPort(listener *VirtioSocketListener, port uint32) {
C.VZVirtioSocketDevice_setSocketListenerForPort(v.Ptr(), v.dispatchQueue, listener.Ptr(), C.uint32_t(port))
}

// RemoveSocketListenerForPort removes the listener object from the specfied port.
// Be sure to close the listener by calling `VirtioSocketListener.Close` after used this one.
//
// see: https://developer.apple.com/documentation/virtualization/vzvirtiosocketdevice/3656678-removesocketlistenerforport?language=objc
func (v *VirtioSocketDevice) RemoveSocketListenerForPort(listener *VirtioSocketListener, port uint32) {
C.VZVirtioSocketDevice_removeSocketListenerForPort(v.Ptr(), v.dispatchQueue, C.uint32_t(port))
// This is only supported on macOS 11 and newer, ErrUnsupportedOSVersion will
// be returned on older versions.
func (v *VirtioSocketDevice) Listen(port uint32) (*VirtioSocketListener, error) {
if macosMajorVersionLessThan(11) {
return nil, ErrUnsupportedOSVersion
}

ch := make(chan accept, 1) // should I increase more caps?

handler := cgo.NewHandle(func(conn *VirtioSocketConnection, err error) {
ch <- accept{conn, err}
})
ptr := C.newVZVirtioSocketListener(
unsafe.Pointer(&handler),
)
listener := &VirtioSocketListener{
pointer: pointer{
ptr: ptr,
},
vsockDevice: v,
port: port,
handler: handler,
acceptch: ch,
}

C.VZVirtioSocketDevice_setSocketListenerForPort(
v.Ptr(),
v.dispatchQueue,
listener.Ptr(),
C.uint32_t(port),
)

return listener, nil
}

//export connectionHandler
Expand All @@ -121,51 +149,84 @@ func connectionHandler(connPtr, errPtr, cgoHandlerPtr unsafe.Pointer) {
func (v *VirtioSocketDevice) ConnectToPort(port uint32, fn func(conn *VirtioSocketConnection, err error)) {
cgoHandler := cgo.NewHandle(fn)
C.VZVirtioSocketDevice_connectToPort(v.Ptr(), v.dispatchQueue, C.uint32_t(port), unsafe.Pointer(&cgoHandler))
runtime.KeepAlive(v)
}

type accept struct {
conn *VirtioSocketConnection
err error
}

// VirtioSocketListener a struct that listens for port-based connection requests from the guest operating system.
//
// see: https://developer.apple.com/documentation/virtualization/vzvirtiosocketlistener?language=objc
type VirtioSocketListener struct {
pointer
vsockDevice *VirtioSocketDevice
handler cgo.Handle
port uint32
acceptch chan accept
closeOnce sync.Once
}

var shouldAcceptNewConnectionHandlers = map[unsafe.Pointer]func(conn *VirtioSocketConnection, err error) bool{}
var _ net.Listener = (*VirtioSocketListener)(nil)

// NewVirtioSocketListener creates a new VirtioSocketListener with connection handler.
//
// The handler is executed asynchronously. Be sure to close the connection used in the handler by calling `conn.Close`.
// This is to prevent connection leaks.
//
// This is only supported on macOS 11 and newer, ErrUnsupportedOSVersion will
// be returned on older versions.
func NewVirtioSocketListener(handler func(conn *VirtioSocketConnection, err error)) (*VirtioSocketListener, error) {
if macosMajorVersionLessThan(11) {
return nil, ErrUnsupportedOSVersion
}
// Accept implements the Accept method in the Listener interface; it waits for the next call and returns a net.Conn.
func (v *VirtioSocketListener) Accept() (net.Conn, error) {
return v.AcceptVirtioSocketConnection()
}

ptr := C.newVZVirtioSocketListener()
listener := &VirtioSocketListener{
pointer: pointer{
ptr: ptr,
},
}
// AcceptVirtioSocketConnection accepts the next incoming call and returns the new connection.
func (v *VirtioSocketListener) AcceptVirtioSocketConnection() (*VirtioSocketConnection, error) {
result := <-v.acceptch
return result.conn, result.err
}

// Close stops listening on the virtio socket.
func (v *VirtioSocketListener) Close() error {
v.closeOnce.Do(func() {
C.VZVirtioSocketDevice_removeSocketListenerForPort(
v.vsockDevice.Ptr(),
v.vsockDevice.dispatchQueue,
C.uint32_t(v.port),
)
v.handler.Delete()
})
return nil
}

shouldAcceptNewConnectionHandlers[ptr] = func(conn *VirtioSocketConnection, err error) bool {
go handler(conn, err)
return true // must be connected
// Addr returns the listener's network address, a *VirtioSocketListenerAddr.
func (v *VirtioSocketListener) Addr() net.Addr {
const VMADDR_CID_HOST = 2 // copied from unix pacage
return &VirtioSocketListenerAddr{
CID: VMADDR_CID_HOST,
Port: v.port,
}
}

return listener, nil
// VirtioSocketListenerAddr represents a network end point address for the vsock protocol.
type VirtioSocketListenerAddr struct {
CID uint32
Port uint32
}

var _ net.Addr = (*VirtioSocketListenerAddr)(nil)

// Network returns "vsock".
func (a *VirtioSocketListenerAddr) Network() string { return "vsock" }

// String returns string of "<cid>:<port>"
func (a *VirtioSocketListenerAddr) String() string { return fmt.Sprintf("%d:%d", a.CID, a.Port) }

//export shouldAcceptNewConnectionHandler
func shouldAcceptNewConnectionHandler(listenerPtr, connPtr, devicePtr unsafe.Pointer) C.bool {
_ = devicePtr // NOTO(codehex): Is this really required? How to use?
func shouldAcceptNewConnectionHandler(cgoHandlerPtr, connPtr, devicePtr unsafe.Pointer) C.bool {
cgoHandler := *(*cgo.Handle)(cgoHandlerPtr)
handler := cgoHandler.Value().(func(*VirtioSocketConnection, error))

// see: startHandler
conn, err := newVirtioSocketConnection(connPtr)
return (C.bool)(shouldAcceptNewConnectionHandlers[listenerPtr](conn, err))
go handler(conn, err)
return (C.bool)(true)
}

// VirtioSocketConnection is a port-based connection between the guest operating system and the host computer.
Expand Down
20 changes: 10 additions & 10 deletions socket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,23 @@ func TestVirtioSocketListener(t *testing.T) {
wantData := "hello"
done := make(chan struct{})

listener, err := vz.NewVirtioSocketListener(func(conn *vz.VirtioSocketConnection, err error) {
listener, err := socketDevice.Listen(uint32(port))
if err != nil {
t.Fatal(err)
}
defer listener.Close()

go func() {
defer close(done)

conn, err := listener.Accept()
if err != nil {
t.Errorf("failed to accept connection: %v", err)
return
}
defer conn.Close()

destPort := conn.DestinationPort()
destPort := conn.(*vz.VirtioSocketConnection).DestinationPort()
if port != int(destPort) {
t.Errorf("want destination port %d but got %d", destPort, port)
return
Expand All @@ -48,12 +55,7 @@ func TestVirtioSocketListener(t *testing.T) {
if wantData != got {
t.Errorf("want %q but got %q", wantData, got)
}
})
if err != nil {
t.Fatal(err)
}

socketDevice.SetSocketListenerForPort(listener, uint32(port))
}()

session := container.NewSession(t)
var buf bytes.Buffer
Expand All @@ -69,6 +71,4 @@ func TestVirtioSocketListener(t *testing.T) {
case <-time.After(3 * time.Second):
t.Fatalf("timeout connection handling after accepted")
}

socketDevice.RemoveSocketListenerForPort(listener, uint32(port))
}
5 changes: 3 additions & 2 deletions virtualization.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
void virtualMachineCompletionHandler(void *cgoHandler, void *errPtr);
void connectionHandler(void *connection, void *err, void *cgoHandlerPtr);
void changeStateOnObserver(int state, void *cgoHandler);
bool shouldAcceptNewConnectionHandler(void *listener, void *connection, void *socketDevice);
bool shouldAcceptNewConnectionHandler(void *cgoHandler, void *connection, void *socketDevice);

@interface Observer : NSObject
- (void)observeValueForKeyPath:(NSString *)keyPath ofObject:(id)object change:(NSDictionary *)change context:(void *)context;
@end

/* VZVirtioSocketListener */
@interface VZVirtioSocketListenerDelegateImpl : NSObject <VZVirtioSocketListenerDelegate>
- (instancetype)initWithHandler:(void *)cgoHandler;
- (BOOL)listener:(VZVirtioSocketListener *)listener shouldAcceptNewConnection:(VZVirtioSocketConnection *)connection fromSocketDevice:(VZVirtioSocketDevice *)socketDevice;
@end

Expand Down Expand Up @@ -81,7 +82,7 @@ void *newVZVirtioSocketDeviceConfiguration();
void *newVZMACAddress(const char *macAddress);
void *newRandomLocallyAdministeredVZMACAddress();
const char *getVZMACAddressString(void *macAddress);
void *newVZVirtioSocketListener();
void *newVZVirtioSocketListener(void *cgoHandlerPtr);
void *newVZSharedDirectory(const char *dirPath, bool readOnly);
void *newVZSingleDirectoryShare(void *sharedDirectory);
void *newVZMultipleDirectoryShare(void *sharedDirectories);
Expand Down
18 changes: 14 additions & 4 deletions virtualization.m
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,20 @@ - (void)observeValueForKeyPath:(NSString *)keyPath ofObject:(id)object change:(N
}
@end

@implementation VZVirtioSocketListenerDelegateImpl
@implementation VZVirtioSocketListenerDelegateImpl {
void *_cgoHandler;
}

- (instancetype)initWithHandler:(void *)cgoHandler
{
self = [super init];
_cgoHandler = cgoHandler;
return self;
}

- (BOOL)listener:(VZVirtioSocketListener *)listener shouldAcceptNewConnection:(VZVirtioSocketConnection *)connection fromSocketDevice:(VZVirtioSocketDevice *)socketDevice;
{
return (BOOL)shouldAcceptNewConnectionHandler(listener, connection, socketDevice);
return (BOOL)shouldAcceptNewConnectionHandler(_cgoHandler, connection, socketDevice);
}
@end

Expand Down Expand Up @@ -743,11 +753,11 @@ void setStreamsVZVirtioSoundDeviceConfiguration(void *audioDeviceConfiguration,
@see VZVirtioSocketDevice
@see VZVirtioSocketListenerDelegate
*/
void *newVZVirtioSocketListener()
void *newVZVirtioSocketListener(void *cgoHandlerPtr)
{
if (@available(macOS 11, *)) {
VZVirtioSocketListener *ret = [[VZVirtioSocketListener alloc] init];
[ret setDelegate:[[VZVirtioSocketListenerDelegateImpl alloc] init]];
[ret setDelegate:[[VZVirtioSocketListenerDelegateImpl alloc] initWithHandler:cgoHandlerPtr]];
return ret;
}

Expand Down

0 comments on commit b33b17e

Please sign in to comment.