diff --git a/socket.go b/socket.go index 9c9dbf9..173cae6 100644 --- a/socket.go +++ b/socket.go @@ -12,8 +12,6 @@ import ( "os" "runtime" "runtime/cgo" - "syscall" - "time" "unsafe" "golang.org/x/sys/unix" @@ -106,11 +104,11 @@ func connectionHandler(connPtr, errPtr, cgoHandlerPtr unsafe.Pointer) { handler := cgoHandler.Value().(func(*VirtioSocketConnection, error)) defer cgoHandler.Delete() // see: startHandler - conn := newVirtioSocketConnection(connPtr) if err := newNSError(errPtr); err != nil { - handler(conn, err) + handler(nil, err) } else { - handler(conn, nil) + conn, err := newVirtioSocketConnection(connPtr) + handler(conn, err) } } @@ -133,12 +131,7 @@ type VirtioSocketListener struct { pointer } -type dup struct { - conn *VirtioSocketConnection - err error -} - -var shouldAcceptNewConnectionHandlers = map[unsafe.Pointer]func(conn *VirtioSocketConnection) bool{} +var shouldAcceptNewConnectionHandlers = map[unsafe.Pointer]func(conn *VirtioSocketConnection, err error) bool{} // NewVirtioSocketListener creates a new VirtioSocketListener with connection handler. // @@ -159,18 +152,8 @@ func NewVirtioSocketListener(handler func(conn *VirtioSocketConnection, err erro }, } - dupCh := make(chan dup, 1) - go func() { - for dup := range dupCh { - go handler(dup.conn, dup.err) - } - }() - shouldAcceptNewConnectionHandlers[ptr] = func(conn *VirtioSocketConnection) bool { - dupConn, err := conn.dup() - dupCh <- dup{ - conn: dupConn, - err: err, - } + shouldAcceptNewConnectionHandlers[ptr] = func(conn *VirtioSocketConnection, err error) bool { + go handler(conn, err) return true // must be connected } @@ -185,8 +168,8 @@ func shouldAcceptNewConnectionHandler(listenerPtr, connPtr, devicePtr unsafe.Poi _ = devicePtr // NOTO(codehex): Is this really required? How to use? // see: startHandler - conn := newVirtioSocketConnection(connPtr) - return (C.bool)(shouldAcceptNewConnectionHandlers[listenerPtr](conn)) + conn, err := newVirtioSocketConnection(connPtr) + return (C.bool)(shouldAcceptNewConnectionHandlers[listenerPtr](conn, err)) } // VirtioSocketConnection is a port-based connection between the guest operating system and the host computer. @@ -202,27 +185,21 @@ func shouldAcceptNewConnectionHandler(listenerPtr, connPtr, devicePtr unsafe.Poi // // see: https://developer.apple.com/documentation/virtualization/vzvirtiosocketconnection?language=objc type VirtioSocketConnection struct { - sourcePort uint32 - destinationPort uint32 - fileDescriptor uintptr - file *os.File - laddr net.Addr // local - raddr net.Addr // remote + net.Conn + laddr *Addr // local + raddr *Addr // remote } -var _ net.Conn = (*VirtioSocketConnection)(nil) - -func newVirtioSocketConnection(ptr unsafe.Pointer) *VirtioSocketConnection { +func newVirtioSocketConnection(ptr unsafe.Pointer) (*VirtioSocketConnection, error) { vzVirtioSocketConnection := C.convertVZVirtioSocketConnection2Flat(ptr) - err := unix.SetNonblock(int(vzVirtioSocketConnection.fileDescriptor), true) + file := os.NewFile((uintptr)(vzVirtioSocketConnection.fileDescriptor), "") + defer file.Close() + rawConn, err := net.FileConn(file) if err != nil { - fmt.Printf("set nonblock: %s\n", err.Error()) + return nil, err } conn := &VirtioSocketConnection{ - sourcePort: (uint32)(vzVirtioSocketConnection.sourcePort), - destinationPort: (uint32)(vzVirtioSocketConnection.destinationPort), - fileDescriptor: (uintptr)(vzVirtioSocketConnection.fileDescriptor), - file: os.NewFile((uintptr)(vzVirtioSocketConnection.fileDescriptor), ""), + Conn: rawConn, laddr: &Addr{ CID: unix.VMADDR_CID_HOST, Port: (uint32)(vzVirtioSocketConnection.destinationPort), @@ -232,40 +209,7 @@ func newVirtioSocketConnection(ptr unsafe.Pointer) *VirtioSocketConnection { Port: (uint32)(vzVirtioSocketConnection.sourcePort), }, } - return conn -} - -func (v *VirtioSocketConnection) dup() (*VirtioSocketConnection, error) { - nfd, err := syscall.Dup(int(v.fileDescriptor)) - if err != nil { - return nil, &net.OpError{ - Op: "dup", - Net: "vsock", - Source: v.laddr, - Addr: v.raddr, - Err: err, - } - } - - dupConn := new(VirtioSocketConnection) - *dupConn = *v - dupConn.fileDescriptor = uintptr(nfd) - dupConn.file = os.NewFile(uintptr(nfd), v.file.Name()) - dupConn.laddr = v.laddr - dupConn.raddr = v.raddr - - return dupConn, nil -} - -// Read reads data from connection of the vsock protocol. -func (v *VirtioSocketConnection) Read(b []byte) (n int, err error) { return v.file.Read(b) } - -// Write writes data to the connection of the vsock protocol. -func (v *VirtioSocketConnection) Write(b []byte) (n int, err error) { return v.file.Write(b) } - -// Close will be called when caused something error in socket. -func (v *VirtioSocketConnection) Close() error { - return v.file.Close() + return conn, nil } // LocalAddr returns the local network address. @@ -274,44 +218,14 @@ func (v *VirtioSocketConnection) LocalAddr() net.Addr { return v.laddr } // RemoteAddr returns the remote network address. func (v *VirtioSocketConnection) RemoteAddr() net.Addr { return v.raddr } -// SetDeadline sets the read and write deadlines associated -// with the connection. It is equivalent to calling both -// SetReadDeadline and SetWriteDeadline. -func (v *VirtioSocketConnection) SetDeadline(t time.Time) error { return v.file.SetDeadline(t) } - -// SetReadDeadline sets the deadline for future Read calls -// and any currently-blocked Read call. -// A zero value for t means Read will not time out. -func (v *VirtioSocketConnection) SetReadDeadline(t time.Time) error { - return v.file.SetReadDeadline(t) -} - -// SetWriteDeadline sets the deadline for future Write calls -// and any currently-blocked Write call. -// Even if write times out, it may return n > 0, indicating that -// some of the data was successfully written. -// A zero value for t means Write will not time out. -func (v *VirtioSocketConnection) SetWriteDeadline(t time.Time) error { - return v.file.SetWriteDeadline(t) -} - // DestinationPort returns the destination port number of the connection. func (v *VirtioSocketConnection) DestinationPort() uint32 { - return v.destinationPort + return v.laddr.Port } // SourcePort returns the source port number of the connection. func (v *VirtioSocketConnection) SourcePort() uint32 { - return v.sourcePort -} - -// FileDescriptor returns the file descriptor associated with the socket. -// -// Data is sent by writing to the file descriptor. -// Data is received by reading from the file descriptor. -// A file descriptor of -1 indicates a closed connection. -func (v *VirtioSocketConnection) FileDescriptor() uintptr { - return v.fileDescriptor + return v.raddr.Port } // Addr represents a network end point address for the vsock protocol.