mmsg.go 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. package mmsg
import (
	"fmt"
	"net"

	"github.com/anacrolix/missinggo/expect"
	"github.com/anacrolix/mmsg/socket"
)
  7. // Considered MSG_DONTWAIT, but I think Go puts the socket into non-blocking
  8. // mode in its runtime and it seems to do the right thing.
  9. const flags = 0
  10. type Conn struct {
  11. err error
  12. s *socket.Conn
  13. pr PacketReader
  14. }
  15. type PacketReader interface {
  16. ReadFrom([]byte) (int, net.Addr, error)
  17. }
  18. func NewConn(pr PacketReader) *Conn {
  19. ret := Conn{
  20. pr: pr,
  21. }
  22. ret.s, ret.err = socket.NewConn(pr)
  23. return &ret
  24. }
  25. func (me *Conn) recvMsgAsMsgs(ms []Message) (int, error) {
  26. err := me.RecvMsg(&ms[0])
  27. if err != nil {
  28. return 0, err
  29. }
  30. return 1, err
  31. }
  32. func (me *Conn) RecvMsgs(ms []Message) (n int, err error) {
  33. if me.err != nil || len(ms) == 1 {
  34. return me.recvMsgAsMsgs(ms)
  35. }
  36. sms := make([]socket.Message, len(ms))
  37. for i := range ms {
  38. sms[i].Buffers = ms[i].Buffers
  39. }
  40. n, err = me.s.RecvMsgs(sms, flags)
  41. if err != nil && err.Error() == "not implemented" {
  42. expect.Nil(me.err)
  43. me.err = err
  44. if n <= 0 {
  45. return me.recvMsgAsMsgs(ms)
  46. }
  47. err = nil
  48. }
  49. for i := 0; i < n; i++ {
  50. ms[i].Addr = sms[i].Addr
  51. ms[i].N = sms[i].N
  52. }
  53. return n, err
  54. }
  55. func (me *Conn) RecvMsg(m *Message) error {
  56. if len(m.Buffers) == 1 { // What about 0?
  57. var err error
  58. m.N, m.Addr, err = me.pr.ReadFrom(m.Buffers[0])
  59. return err
  60. }
  61. sm := socket.Message{
  62. Buffers: m.Buffers,
  63. }
  64. err := me.s.RecvMsg(&sm, flags)
  65. m.Addr = sm.Addr
  66. m.N = sm.N
  67. return err
  68. }
  69. type Message struct {
  70. Buffers [][]byte
  71. N int
  72. Addr net.Addr
  73. }
  74. func (me *Message) Payload() (p []byte) {
  75. n := me.N
  76. for _, b := range me.Buffers {
  77. if len(b) >= n {
  78. p = append(p, b[:n]...)
  79. return
  80. }
  81. p = append(p, b...)
  82. n -= len(b)
  83. }
  84. panic(n)
  85. }
  86. func (me *Conn) Err() error {
  87. return me.err
  88. }