hvsock.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582
  1. //go:build windows
  2. // +build windows
  3. package winio
  4. import (
  5. "context"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net"
  10. "os"
  11. "time"
  12. "unsafe"
  13. "golang.org/x/sys/windows"
  14. "github.com/Microsoft/go-winio/internal/socket"
  15. "github.com/Microsoft/go-winio/pkg/guid"
  16. )
// afHVSock is the address family value for Hyper-V sockets. It is passed when creating
// sockets and stored as the Family of every raw hvsock address.
const afHVSock = 34 // AF_HYPERV
  18. // Well known Service and VM IDs
  19. // https://docs.microsoft.com/en-us/virtualization/hyper-v-on-windows/user-guide/make-integration-service#vmid-wildcards
  20. // HvsockGUIDWildcard is the wildcard VmId for accepting connections from all partitions.
  21. func HvsockGUIDWildcard() guid.GUID { // 00000000-0000-0000-0000-000000000000
  22. return guid.GUID{}
  23. }
  24. // HvsockGUIDBroadcast is the wildcard VmId for broadcasting sends to all partitions.
  25. func HvsockGUIDBroadcast() guid.GUID { // ffffffff-ffff-ffff-ffff-ffffffffffff
  26. return guid.GUID{
  27. Data1: 0xffffffff,
  28. Data2: 0xffff,
  29. Data3: 0xffff,
  30. Data4: [8]uint8{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
  31. }
  32. }
  33. // HvsockGUIDLoopback is the Loopback VmId for accepting connections to the same partition as the connector.
  34. func HvsockGUIDLoopback() guid.GUID { // e0e16197-dd56-4a10-9195-5ee7a155a838
  35. return guid.GUID{
  36. Data1: 0xe0e16197,
  37. Data2: 0xdd56,
  38. Data3: 0x4a10,
  39. Data4: [8]uint8{0x91, 0x95, 0x5e, 0xe7, 0xa1, 0x55, 0xa8, 0x38},
  40. }
  41. }
  42. // HvsockGUIDSiloHost is the address of a silo's host partition:
  43. // - The silo host of a hosted silo is the utility VM.
  44. // - The silo host of a silo on a physical host is the physical host.
  45. func HvsockGUIDSiloHost() guid.GUID { // 36bd0c5c-7276-4223-88ba-7d03b654c568
  46. return guid.GUID{
  47. Data1: 0x36bd0c5c,
  48. Data2: 0x7276,
  49. Data3: 0x4223,
  50. Data4: [8]byte{0x88, 0xba, 0x7d, 0x03, 0xb6, 0x54, 0xc5, 0x68},
  51. }
  52. }
  53. // HvsockGUIDChildren is the wildcard VmId for accepting connections from the connector's child partitions.
  54. func HvsockGUIDChildren() guid.GUID { // 90db8b89-0d35-4f79-8ce9-49ea0ac8b7cd
  55. return guid.GUID{
  56. Data1: 0x90db8b89,
  57. Data2: 0xd35,
  58. Data3: 0x4f79,
  59. Data4: [8]uint8{0x8c, 0xe9, 0x49, 0xea, 0xa, 0xc8, 0xb7, 0xcd},
  60. }
  61. }
  62. // HvsockGUIDParent is the wildcard VmId for accepting connections from the connector's parent partition.
  63. // Listening on this VmId accepts connection from:
  64. // - Inside silos: silo host partition.
  65. // - Inside hosted silo: host of the VM.
  66. // - Inside VM: VM host.
  67. // - Physical host: Not supported.
  68. func HvsockGUIDParent() guid.GUID { // a42e7cda-d03f-480c-9cc2-a4de20abb878
  69. return guid.GUID{
  70. Data1: 0xa42e7cda,
  71. Data2: 0xd03f,
  72. Data3: 0x480c,
  73. Data4: [8]uint8{0x9c, 0xc2, 0xa4, 0xde, 0x20, 0xab, 0xb8, 0x78},
  74. }
  75. }
  76. // hvsockVsockServiceTemplate is the Service GUID used for the VSOCK protocol.
  77. func hvsockVsockServiceTemplate() guid.GUID { // 00000000-facb-11e6-bd58-64006a7986d3
  78. return guid.GUID{
  79. Data2: 0xfacb,
  80. Data3: 0x11e6,
  81. Data4: [8]uint8{0xbd, 0x58, 0x64, 0x00, 0x6a, 0x79, 0x86, 0xd3},
  82. }
  83. }
  84. // An HvsockAddr is an address for a AF_HYPERV socket.
  85. type HvsockAddr struct {
  86. VMID guid.GUID
  87. ServiceID guid.GUID
  88. }
  89. type rawHvsockAddr struct {
  90. Family uint16
  91. _ uint16
  92. VMID guid.GUID
  93. ServiceID guid.GUID
  94. }
  95. var _ socket.RawSockaddr = &rawHvsockAddr{}
  96. // Network returns the address's network name, "hvsock".
  97. func (*HvsockAddr) Network() string {
  98. return "hvsock"
  99. }
  100. func (addr *HvsockAddr) String() string {
  101. return fmt.Sprintf("%s:%s", &addr.VMID, &addr.ServiceID)
  102. }
  103. // VsockServiceID returns an hvsock service ID corresponding to the specified AF_VSOCK port.
  104. func VsockServiceID(port uint32) guid.GUID {
  105. g := hvsockVsockServiceTemplate() // make a copy
  106. g.Data1 = port
  107. return g
  108. }
  109. func (addr *HvsockAddr) raw() rawHvsockAddr {
  110. return rawHvsockAddr{
  111. Family: afHVSock,
  112. VMID: addr.VMID,
  113. ServiceID: addr.ServiceID,
  114. }
  115. }
  116. func (addr *HvsockAddr) fromRaw(raw *rawHvsockAddr) {
  117. addr.VMID = raw.VMID
  118. addr.ServiceID = raw.ServiceID
  119. }
  120. // Sockaddr returns a pointer to and the size of this struct.
  121. //
  122. // Implements the [socket.RawSockaddr] interface, and allows use in
  123. // [socket.Bind] and [socket.ConnectEx].
  124. func (r *rawHvsockAddr) Sockaddr() (unsafe.Pointer, int32, error) {
  125. return unsafe.Pointer(r), int32(unsafe.Sizeof(rawHvsockAddr{})), nil
  126. }
  127. // Sockaddr interface allows use with `sockets.Bind()` and `.ConnectEx()`.
  128. func (r *rawHvsockAddr) FromBytes(b []byte) error {
  129. n := int(unsafe.Sizeof(rawHvsockAddr{}))
  130. if len(b) < n {
  131. return fmt.Errorf("got %d, want %d: %w", len(b), n, socket.ErrBufferSize)
  132. }
  133. copy(unsafe.Slice((*byte)(unsafe.Pointer(r)), n), b[:n])
  134. if r.Family != afHVSock {
  135. return fmt.Errorf("got %d, want %d: %w", r.Family, afHVSock, socket.ErrAddrFamily)
  136. }
  137. return nil
  138. }
  139. // HvsockListener is a socket listener for the AF_HYPERV address family.
  140. type HvsockListener struct {
  141. sock *win32File
  142. addr HvsockAddr
  143. }
  144. var _ net.Listener = &HvsockListener{}
  145. // HvsockConn is a connected socket of the AF_HYPERV address family.
  146. type HvsockConn struct {
  147. sock *win32File
  148. local, remote HvsockAddr
  149. }
  150. var _ net.Conn = &HvsockConn{}
  151. func newHVSocket() (*win32File, error) {
  152. fd, err := windows.Socket(afHVSock, windows.SOCK_STREAM, 1)
  153. if err != nil {
  154. return nil, os.NewSyscallError("socket", err)
  155. }
  156. f, err := makeWin32File(fd)
  157. if err != nil {
  158. windows.Close(fd)
  159. return nil, err
  160. }
  161. f.socket = true
  162. return f, nil
  163. }
  164. // ListenHvsock listens for connections on the specified hvsock address.
  165. func ListenHvsock(addr *HvsockAddr) (_ *HvsockListener, err error) {
  166. l := &HvsockListener{addr: *addr}
  167. var sock *win32File
  168. sock, err = newHVSocket()
  169. if err != nil {
  170. return nil, l.opErr("listen", err)
  171. }
  172. defer func() {
  173. if err != nil {
  174. _ = sock.Close()
  175. }
  176. }()
  177. sa := addr.raw()
  178. err = socket.Bind(sock.handle, &sa)
  179. if err != nil {
  180. return nil, l.opErr("listen", os.NewSyscallError("socket", err))
  181. }
  182. err = windows.Listen(sock.handle, 16)
  183. if err != nil {
  184. return nil, l.opErr("listen", os.NewSyscallError("listen", err))
  185. }
  186. return &HvsockListener{sock: sock, addr: *addr}, nil
  187. }
  188. func (l *HvsockListener) opErr(op string, err error) error {
  189. return &net.OpError{Op: op, Net: "hvsock", Addr: &l.addr, Err: err}
  190. }
  191. // Addr returns the listener's network address.
  192. func (l *HvsockListener) Addr() net.Addr {
  193. return &l.addr
  194. }
// Accept waits for the next connection and returns it.
//
// A new AF_HYPERV socket is created on every call and passed to AcceptEx together with the
// listening socket. On success the returned net.Conn is a *HvsockConn that owns the new
// socket, with its local and remote addresses taken from the AcceptEx address buffer. On
// failure the new socket is closed and the error is a *net.OpError with Op "accept".
func (l *HvsockListener) Accept() (_ net.Conn, err error) {
	sock, err := newHVSocket()
	if err != nil {
		return nil, l.opErr("accept", err)
	}
	// Close the new socket on every error path. sock is set to nil just before the
	// successful return, which hands ownership to the returned connection.
	defer func() {
		if sock != nil {
			sock.Close()
		}
	}()
	c, err := l.sock.prepareIO()
	if err != nil {
		return nil, l.opErr("accept", err)
	}
	// NOTE(review): presumably balances a wg.Add done inside prepareIO; confirm in win32File.
	defer l.sock.wg.Done()

	// AcceptEx, per documentation, requires an extra 16 bytes per address.
	//
	// https://docs.microsoft.com/en-us/windows/win32/api/mswsock/nf-mswsock-acceptex
	const addrlen = uint32(16 + unsafe.Sizeof(rawHvsockAddr{}))
	// one slot for the local address followed by one slot for the remote address
	var addrbuf [addrlen * 2]byte

	var bytes uint32
	err = windows.AcceptEx(l.sock.handle, sock.handle, &addrbuf[0], 0 /* rxdatalen */, addrlen, addrlen, &bytes, &c.o)
	// asyncIO receives the immediate AcceptEx result; no deadline is supplied (nil).
	if _, err = l.sock.asyncIO(c, nil, bytes, err); err != nil {
		return nil, l.opErr("accept", os.NewSyscallError("acceptex", err))
	}

	conn := &HvsockConn{
		sock: sock,
	}
	// The local address returned in the AcceptEx buffer is the same as the Listener socket's
	// address. However, the service GUID reported by GetSockName is different from the Listener's
	// socket, and is sometimes the same as the local address of the socket that dialed the
	// address, with the service GUID.Data1 incremented, but at other times is different.
	// todo: does the local address matter? is the listener's address or the actual address appropriate?
	// The buffer is read in place: local address at offset 0, remote address at offset addrlen.
	conn.local.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[0])))
	conn.remote.fromRaw((*rawHvsockAddr)(unsafe.Pointer(&addrbuf[addrlen])))

	// initialize the accepted socket and update its properties with those of the listening socket
	if err = windows.Setsockopt(sock.handle,
		windows.SOL_SOCKET, windows.SO_UPDATE_ACCEPT_CONTEXT,
		(*byte)(unsafe.Pointer(&l.sock.handle)), int32(unsafe.Sizeof(l.sock.handle))); err != nil {
		// reported through the connection's opErr, so the error carries both endpoint addresses
		return nil, conn.opErr("accept", os.NewSyscallError("setsockopt", err))
	}

	// success: disarm the deferred close
	sock = nil
	return conn, nil
}
  240. // Close closes the listener, causing any pending Accept calls to fail.
  241. func (l *HvsockListener) Close() error {
  242. return l.sock.Close()
  243. }
// HvsockDialer configures and dials a Hyper-V Socket (ie, [HvsockConn]).
//
// The zero value is usable: it makes a single connection attempt with no deadline.
//
// NOTE(review): rt is read and written by redialWait without synchronization, so sharing
// one dialer between concurrent Dial calls looks unsafe; confirm before doing so.
type HvsockDialer struct {
	// Deadline is the time the Dial operation must connect before erroring.
	// The zero value means no deadline.
	Deadline time.Time

	// Retries is the number of additional connects to try if the connection times out, is refused,
	// or the host is unreachable.
	Retries uint

	// RetryWait is the time to wait after a connection error to retry.
	// A value of zero retries immediately.
	RetryWait time.Duration

	rt *time.Timer // redial wait timer, created lazily and reused across retries
}
  255. // Dial the Hyper-V socket at addr.
  256. //
  257. // See [HvsockDialer.Dial] for more information.
  258. func Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
  259. return (&HvsockDialer{}).Dial(ctx, addr)
  260. }
  261. // Dial attempts to connect to the Hyper-V socket at addr, and returns a connection if successful.
  262. // Will attempt (HvsockDialer).Retries if dialing fails, waiting (HvsockDialer).RetryWait between
  263. // retries.
  264. //
  265. // Dialing can be cancelled either by providing (HvsockDialer).Deadline, or cancelling ctx.
  266. func (d *HvsockDialer) Dial(ctx context.Context, addr *HvsockAddr) (conn *HvsockConn, err error) {
  267. op := "dial"
  268. // create the conn early to use opErr()
  269. conn = &HvsockConn{
  270. remote: *addr,
  271. }
  272. if !d.Deadline.IsZero() {
  273. var cancel context.CancelFunc
  274. ctx, cancel = context.WithDeadline(ctx, d.Deadline)
  275. defer cancel()
  276. }
  277. // preemptive timeout/cancellation check
  278. if err = ctx.Err(); err != nil {
  279. return nil, conn.opErr(op, err)
  280. }
  281. sock, err := newHVSocket()
  282. if err != nil {
  283. return nil, conn.opErr(op, err)
  284. }
  285. defer func() {
  286. if sock != nil {
  287. sock.Close()
  288. }
  289. }()
  290. sa := addr.raw()
  291. err = socket.Bind(sock.handle, &sa)
  292. if err != nil {
  293. return nil, conn.opErr(op, os.NewSyscallError("bind", err))
  294. }
  295. c, err := sock.prepareIO()
  296. if err != nil {
  297. return nil, conn.opErr(op, err)
  298. }
  299. defer sock.wg.Done()
  300. var bytes uint32
  301. for i := uint(0); i <= d.Retries; i++ {
  302. err = socket.ConnectEx(
  303. sock.handle,
  304. &sa,
  305. nil, // sendBuf
  306. 0, // sendDataLen
  307. &bytes,
  308. (*windows.Overlapped)(unsafe.Pointer(&c.o)))
  309. _, err = sock.asyncIO(c, nil, bytes, err)
  310. if i < d.Retries && canRedial(err) {
  311. if err = d.redialWait(ctx); err == nil {
  312. continue
  313. }
  314. }
  315. break
  316. }
  317. if err != nil {
  318. return nil, conn.opErr(op, os.NewSyscallError("connectex", err))
  319. }
  320. // update the connection properties, so shutdown can be used
  321. if err = windows.Setsockopt(
  322. sock.handle,
  323. windows.SOL_SOCKET,
  324. windows.SO_UPDATE_CONNECT_CONTEXT,
  325. nil, // optvalue
  326. 0, // optlen
  327. ); err != nil {
  328. return nil, conn.opErr(op, os.NewSyscallError("setsockopt", err))
  329. }
  330. // get the local name
  331. var sal rawHvsockAddr
  332. err = socket.GetSockName(sock.handle, &sal)
  333. if err != nil {
  334. return nil, conn.opErr(op, os.NewSyscallError("getsockname", err))
  335. }
  336. conn.local.fromRaw(&sal)
  337. // one last check for timeout, since asyncIO doesn't check the context
  338. if err = ctx.Err(); err != nil {
  339. return nil, conn.opErr(op, err)
  340. }
  341. conn.sock = sock
  342. sock = nil
  343. return conn, nil
  344. }
  345. // redialWait waits before attempting to redial, resetting the timer as appropriate.
  346. func (d *HvsockDialer) redialWait(ctx context.Context) (err error) {
  347. if d.RetryWait == 0 {
  348. return nil
  349. }
  350. if d.rt == nil {
  351. d.rt = time.NewTimer(d.RetryWait)
  352. } else {
  353. // should already be stopped and drained
  354. d.rt.Reset(d.RetryWait)
  355. }
  356. select {
  357. case <-ctx.Done():
  358. case <-d.rt.C:
  359. return nil
  360. }
  361. // stop and drain the timer
  362. if !d.rt.Stop() {
  363. <-d.rt.C
  364. }
  365. return ctx.Err()
  366. }
  367. // assumes error is a plain, unwrapped windows.Errno provided by direct syscall.
  368. func canRedial(err error) bool {
  369. //nolint:errorlint // guaranteed to be an Errno
  370. switch err {
  371. case windows.WSAECONNREFUSED, windows.WSAENETUNREACH, windows.WSAETIMEDOUT,
  372. windows.ERROR_CONNECTION_REFUSED, windows.ERROR_CONNECTION_UNAVAIL:
  373. return true
  374. default:
  375. return false
  376. }
  377. }
  378. func (conn *HvsockConn) opErr(op string, err error) error {
  379. // translate from "file closed" to "socket closed"
  380. if errors.Is(err, ErrFileClosed) {
  381. err = socket.ErrSocketClosed
  382. }
  383. return &net.OpError{Op: op, Net: "hvsock", Source: &conn.local, Addr: &conn.remote, Err: err}
  384. }
  385. func (conn *HvsockConn) Read(b []byte) (int, error) {
  386. c, err := conn.sock.prepareIO()
  387. if err != nil {
  388. return 0, conn.opErr("read", err)
  389. }
  390. defer conn.sock.wg.Done()
  391. buf := windows.WSABuf{Buf: &b[0], Len: uint32(len(b))}
  392. var flags, bytes uint32
  393. err = windows.WSARecv(conn.sock.handle, &buf, 1, &bytes, &flags, &c.o, nil)
  394. n, err := conn.sock.asyncIO(c, &conn.sock.readDeadline, bytes, err)
  395. if err != nil {
  396. var eno windows.Errno
  397. if errors.As(err, &eno) {
  398. err = os.NewSyscallError("wsarecv", eno)
  399. }
  400. return 0, conn.opErr("read", err)
  401. } else if n == 0 {
  402. err = io.EOF
  403. }
  404. return n, err
  405. }
  406. func (conn *HvsockConn) Write(b []byte) (int, error) {
  407. t := 0
  408. for len(b) != 0 {
  409. n, err := conn.write(b)
  410. if err != nil {
  411. return t + n, err
  412. }
  413. t += n
  414. b = b[n:]
  415. }
  416. return t, nil
  417. }
  418. func (conn *HvsockConn) write(b []byte) (int, error) {
  419. c, err := conn.sock.prepareIO()
  420. if err != nil {
  421. return 0, conn.opErr("write", err)
  422. }
  423. defer conn.sock.wg.Done()
  424. buf := windows.WSABuf{Buf: &b[0], Len: uint32(len(b))}
  425. var bytes uint32
  426. err = windows.WSASend(conn.sock.handle, &buf, 1, &bytes, 0, &c.o, nil)
  427. n, err := conn.sock.asyncIO(c, &conn.sock.writeDeadline, bytes, err)
  428. if err != nil {
  429. var eno windows.Errno
  430. if errors.As(err, &eno) {
  431. err = os.NewSyscallError("wsasend", eno)
  432. }
  433. return 0, conn.opErr("write", err)
  434. }
  435. return n, err
  436. }
  437. // Close closes the socket connection, failing any pending read or write calls.
  438. func (conn *HvsockConn) Close() error {
  439. return conn.sock.Close()
  440. }
  441. func (conn *HvsockConn) IsClosed() bool {
  442. return conn.sock.IsClosed()
  443. }
  444. // shutdown disables sending or receiving on a socket.
  445. func (conn *HvsockConn) shutdown(how int) error {
  446. if conn.IsClosed() {
  447. return socket.ErrSocketClosed
  448. }
  449. err := windows.Shutdown(conn.sock.handle, how)
  450. if err != nil {
  451. // If the connection was closed, shutdowns fail with "not connected"
  452. if errors.Is(err, windows.WSAENOTCONN) ||
  453. errors.Is(err, windows.WSAESHUTDOWN) {
  454. err = socket.ErrSocketClosed
  455. }
  456. return os.NewSyscallError("shutdown", err)
  457. }
  458. return nil
  459. }
  460. // CloseRead shuts down the read end of the socket, preventing future read operations.
  461. func (conn *HvsockConn) CloseRead() error {
  462. err := conn.shutdown(windows.SHUT_RD)
  463. if err != nil {
  464. return conn.opErr("closeread", err)
  465. }
  466. return nil
  467. }
  468. // CloseWrite shuts down the write end of the socket, preventing future write operations and
  469. // notifying the other endpoint that no more data will be written.
  470. func (conn *HvsockConn) CloseWrite() error {
  471. err := conn.shutdown(windows.SHUT_WR)
  472. if err != nil {
  473. return conn.opErr("closewrite", err)
  474. }
  475. return nil
  476. }
  477. // LocalAddr returns the local address of the connection.
  478. func (conn *HvsockConn) LocalAddr() net.Addr {
  479. return &conn.local
  480. }
  481. // RemoteAddr returns the remote address of the connection.
  482. func (conn *HvsockConn) RemoteAddr() net.Addr {
  483. return &conn.remote
  484. }
  485. // SetDeadline implements the net.Conn SetDeadline method.
  486. func (conn *HvsockConn) SetDeadline(t time.Time) error {
  487. // todo: implement `SetDeadline` for `win32File`
  488. if err := conn.SetReadDeadline(t); err != nil {
  489. return fmt.Errorf("set read deadline: %w", err)
  490. }
  491. if err := conn.SetWriteDeadline(t); err != nil {
  492. return fmt.Errorf("set write deadline: %w", err)
  493. }
  494. return nil
  495. }
  496. // SetReadDeadline implements the net.Conn SetReadDeadline method.
  497. func (conn *HvsockConn) SetReadDeadline(t time.Time) error {
  498. return conn.sock.SetReadDeadline(t)
  499. }
  500. // SetWriteDeadline implements the net.Conn SetWriteDeadline method.
  501. func (conn *HvsockConn) SetWriteDeadline(t time.Time) error {
  502. return conn.sock.SetWriteDeadline(t)
  503. }