msg.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. package peer_protocol
  2. import (
  3. "bufio"
  4. "bytes"
  5. "encoding"
  6. "encoding/binary"
  7. "fmt"
  8. "io"
  9. )
  10. // This is a lazy union representing all the possible fields for messages. Go doesn't have ADTs, and
  11. // I didn't choose to use type-assertions. Fields are ordered to minimize struct size and padding.
  12. type Message struct {
  13. PiecesRoot [32]byte
  14. Piece []byte
  15. Bitfield []bool
  16. ExtendedPayload []byte
  17. Hashes [][32]byte
  18. Index, Begin, Length Integer
  19. BaseLayer Integer
  20. ProofLayers Integer
  21. Port uint16
  22. Type MessageType
  23. ExtendedID ExtensionNumber
  24. Keepalive bool
  25. }
  26. var _ interface {
  27. encoding.BinaryUnmarshaler
  28. encoding.BinaryMarshaler
  29. } = (*Message)(nil)
  30. func MakeCancelMessage(piece, offset, length Integer) Message {
  31. return Message{
  32. Type: Cancel,
  33. Index: piece,
  34. Begin: offset,
  35. Length: length,
  36. }
  37. }
  38. func (msg Message) RequestSpec() (ret RequestSpec) {
  39. return RequestSpec{
  40. msg.Index,
  41. msg.Begin,
  42. func() Integer {
  43. if msg.Type == Piece {
  44. return Integer(len(msg.Piece))
  45. } else {
  46. return msg.Length
  47. }
  48. }(),
  49. }
  50. }
  51. func (msg Message) MustMarshalBinary() []byte {
  52. b, err := msg.MarshalBinary()
  53. if err != nil {
  54. panic(err)
  55. }
  56. return b
  57. }
  // MessageWriter is the destination messages are serialized into. It needs both bulk writes
  // and single-byte writes.
  type MessageWriter interface {
  	io.Writer
  	io.ByteWriter
  }
  62. func (msg *Message) writeHashCommon(buf MessageWriter) (err error) {
  63. if _, err = buf.Write(msg.PiecesRoot[:]); err != nil {
  64. return
  65. }
  66. for _, d := range []Integer{msg.BaseLayer, msg.Index, msg.Length, msg.ProofLayers} {
  67. if err = binary.Write(buf, binary.BigEndian, d); err != nil {
  68. return
  69. }
  70. }
  71. return nil
  72. }
  73. func (msg *Message) writePayloadTo(buf MessageWriter) (err error) {
  74. if !msg.Keepalive {
  75. err = buf.WriteByte(byte(msg.Type))
  76. if err != nil {
  77. return
  78. }
  79. switch msg.Type {
  80. case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
  81. case Have, AllowedFast, Suggest:
  82. err = binary.Write(buf, binary.BigEndian, msg.Index)
  83. case Request, Cancel, Reject:
  84. for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
  85. err = binary.Write(buf, binary.BigEndian, i)
  86. if err != nil {
  87. break
  88. }
  89. }
  90. case Bitfield:
  91. _, err = buf.Write(marshalBitfield(msg.Bitfield))
  92. case Piece:
  93. for _, i := range []Integer{msg.Index, msg.Begin} {
  94. err = binary.Write(buf, binary.BigEndian, i)
  95. if err != nil {
  96. return
  97. }
  98. }
  99. n, err := buf.Write(msg.Piece)
  100. if err != nil {
  101. break
  102. }
  103. if n != len(msg.Piece) {
  104. panic(n)
  105. }
  106. case Extended:
  107. err = buf.WriteByte(byte(msg.ExtendedID))
  108. if err != nil {
  109. return
  110. }
  111. _, err = buf.Write(msg.ExtendedPayload)
  112. case Port:
  113. err = binary.Write(buf, binary.BigEndian, msg.Port)
  114. case HashRequest, HashReject:
  115. err = msg.writeHashCommon(buf)
  116. case Hashes:
  117. err = msg.writeHashCommon(buf)
  118. if err != nil {
  119. return
  120. }
  121. for _, h := range msg.Hashes {
  122. if _, err = buf.Write(h[:]); err != nil {
  123. return
  124. }
  125. }
  126. default:
  127. err = fmt.Errorf("unknown message type: %v", msg.Type)
  128. }
  129. }
  130. return
  131. }
  132. func (msg *Message) WriteTo(w MessageWriter) (err error) {
  133. length, err := msg.getPayloadLength()
  134. if err != nil {
  135. return
  136. }
  137. err = binary.Write(w, binary.BigEndian, length)
  138. if err != nil {
  139. return
  140. }
  141. return msg.writePayloadTo(w)
  142. }
  143. func (msg *Message) getPayloadLength() (length Integer, err error) {
  144. var lw lengthWriter
  145. err = msg.writePayloadTo(&lw)
  146. length = lw.n
  147. return
  148. }
  149. func (msg Message) MarshalBinary() (data []byte, err error) {
  150. // It might look like you could have a pool of buffers and preallocate the message length
  151. // prefix, but because we have to return []byte, it becomes non-trivial to make this fast. You
  152. // will need a benchmark.
  153. var buf bytes.Buffer
  154. err = msg.WriteTo(&buf)
  155. data = buf.Bytes()
  156. return
  157. }
  // marshalBitfield packs bf into bytes, eight entries per byte, with the first entry of each
  // group in the most significant bit. The result has ceil(len(bf)/8) bytes; unused low bits of
  // the final byte are zero.
  func marshalBitfield(bf []bool) (b []byte) {
  	b = make([]byte, (len(bf)+7)/8)
  	for i := range bf {
  		if bf[i] {
  			b[i/8] |= byte(0x80) >> uint(i%8)
  		}
  	}
  	return b
  }
  170. func (me *Message) UnmarshalBinary(b []byte) error {
  171. d := Decoder{
  172. R: bufio.NewReader(bytes.NewReader(b)),
  173. }
  174. err := d.Decode(me)
  175. if err != nil {
  176. return err
  177. }
  178. if d.R.Buffered() != 0 {
  179. return fmt.Errorf("%d trailing bytes", d.R.Buffered())
  180. }
  181. return nil
  182. }
  183. type lengthWriter struct {
  184. n Integer
  185. }
  186. func (l *lengthWriter) WriteByte(c byte) error {
  187. l.n++
  188. return nil
  189. }
  190. func (l *lengthWriter) Write(p []byte) (n int, err error) {
  191. n = len(p)
  192. l.n += Integer(n)
  193. return
  194. }