decoder.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. package peer_protocol
  2. import (
  3. "bufio"
  4. "encoding/binary"
  5. "fmt"
  6. "io"
  7. "sync"
  8. g "github.com/anacrolix/generics"
  9. "github.com/pkg/errors"
  10. )
// Decoder reads length-prefixed peer protocol messages from R.
type Decoder struct {
	// R is the buffered source that Decode reads message bytes from.
	R *bufio.Reader
	// Pool must return *[]byte whose slices have enough capacity to hold the data of piece
	// messages; decoding a piece message fails if its data is longer than the capacity of the
	// slice obtained from the pool. I think we store *[]byte in the pool to avoid an extra
	// allocation every time we put the slice back into the pool. The chunk size should not change
	// for the life of the decoder. If Pool is nil, a new slice is allocated for the data of each
	// piece message.
	Pool *sync.Pool
	// MaxLength is the largest length prefix value that Decode accepts; messages announcing a
	// greater length are rejected.
	MaxLength Integer // TODO: Should this include the length header or not?
}
  19. // This limits reads to the length of a message, returning io.EOF when the end of the message bytes
  20. // are reached. If you aren't expecting io.EOF, you should probably wrap it with expectReader.
  21. type decodeReader struct {
  22. lr io.LimitedReader
  23. br *bufio.Reader
  24. }
  25. func (dr *decodeReader) Init(r *bufio.Reader, length int64) {
  26. dr.lr.R = r
  27. dr.lr.N = length
  28. dr.br = r
  29. }
  30. func (dr *decodeReader) ReadByte() (c byte, err error) {
  31. if dr.lr.N <= 0 {
  32. err = io.EOF
  33. return
  34. }
  35. c, err = dr.br.ReadByte()
  36. if err == nil {
  37. dr.lr.N--
  38. }
  39. return
  40. }
  41. func (dr *decodeReader) Read(p []byte) (n int, err error) {
  42. n, err = dr.lr.Read(p)
  43. if dr.lr.N != 0 && err == io.EOF {
  44. err = io.ErrUnexpectedEOF
  45. }
  46. return
  47. }
  48. func (dr *decodeReader) UnreadLength() int64 {
  49. return dr.lr.N
  50. }
  51. // This expects reads to have enough bytes. io.EOF is mapped to io.ErrUnexpectedEOF. It's probably
  52. // not a good idea to pass this to functions that expect to read until the end of something, because
  53. // they will probably expect io.EOF.
  54. type expectReader struct {
  55. dr *decodeReader
  56. }
  57. func (er expectReader) ReadByte() (c byte, err error) {
  58. c, err = er.dr.ReadByte()
  59. if err == io.EOF {
  60. err = io.ErrUnexpectedEOF
  61. }
  62. return
  63. }
  64. func (er expectReader) Read(p []byte) (n int, err error) {
  65. n, err = er.dr.Read(p)
  66. if err == io.EOF {
  67. err = io.ErrUnexpectedEOF
  68. }
  69. return
  70. }
  71. func (er expectReader) UnreadLength() int64 {
  72. return er.dr.UnreadLength()
  73. }
  74. // io.EOF is returned if the source terminates cleanly on a message boundary.
  75. func (d *Decoder) Decode(msg *Message) (err error) {
  76. var dr decodeReader
  77. {
  78. var length Integer
  79. err = length.Read(d.R)
  80. if err != nil {
  81. return fmt.Errorf("reading message length: %w", err)
  82. }
  83. if length > d.MaxLength {
  84. return errors.New("message too long")
  85. }
  86. if length == 0 {
  87. msg.Keepalive = true
  88. return
  89. }
  90. dr.Init(d.R, int64(length))
  91. }
  92. r := expectReader{&dr}
  93. c, err := r.ReadByte()
  94. if err != nil {
  95. return
  96. }
  97. msg.Type = MessageType(c)
  98. err = readMessageAfterType(msg, &r, d.Pool)
  99. if err != nil {
  100. err = fmt.Errorf("reading fields for message type %v: %w", msg.Type, err)
  101. return
  102. }
  103. if r.UnreadLength() != 0 {
  104. err = fmt.Errorf("%v unused bytes in message type %v", r.UnreadLength(), msg.Type)
  105. }
  106. return
  107. }
  108. func readMessageAfterType(msg *Message, r *expectReader, piecePool *sync.Pool) (err error) {
  109. switch msg.Type {
  110. case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
  111. case Have, AllowedFast, Suggest:
  112. err = msg.Index.Read(r)
  113. case Request, Cancel, Reject:
  114. for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
  115. err = data.Read(r)
  116. if err != nil {
  117. break
  118. }
  119. }
  120. case Bitfield:
  121. b := make([]byte, r.UnreadLength())
  122. _, err = io.ReadFull(r, b)
  123. msg.Bitfield = unmarshalBitfield(b)
  124. case Piece:
  125. for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
  126. err = pi.Read(r)
  127. if err != nil {
  128. return
  129. }
  130. }
  131. dataLen := r.UnreadLength()
  132. if piecePool == nil {
  133. msg.Piece = make([]byte, dataLen)
  134. } else {
  135. msg.Piece = *piecePool.Get().(*[]byte)
  136. if int64(cap(msg.Piece)) < dataLen {
  137. return errors.New("piece data longer than expected")
  138. }
  139. msg.Piece = msg.Piece[:dataLen]
  140. }
  141. _, err = io.ReadFull(r, msg.Piece)
  142. case Extended:
  143. var b byte
  144. b, err = r.ReadByte()
  145. if err != nil {
  146. break
  147. }
  148. msg.ExtendedID = ExtensionNumber(b)
  149. msg.ExtendedPayload = make([]byte, r.UnreadLength())
  150. _, err = io.ReadFull(r, msg.ExtendedPayload)
  151. case Port:
  152. err = binary.Read(r, binary.BigEndian, &msg.Port)
  153. case HashRequest, HashReject:
  154. err = readHashRequest(r, msg)
  155. case Hashes:
  156. err = readHashRequest(r, msg)
  157. numHashes := (r.UnreadLength() + 31) / 32
  158. g.MakeSliceWithCap(&msg.Hashes, numHashes)
  159. for range numHashes {
  160. var oneHash [32]byte
  161. _, err = io.ReadFull(r, oneHash[:])
  162. if err != nil {
  163. err = fmt.Errorf("error while reading hashes: %w", err)
  164. return
  165. }
  166. msg.Hashes = append(msg.Hashes, oneHash)
  167. }
  168. default:
  169. err = errors.New("unhandled message type")
  170. }
  171. return
  172. }
  173. func readHashRequest(r io.Reader, msg *Message) (err error) {
  174. _, err = io.ReadFull(r, msg.PiecesRoot[:])
  175. if err != nil {
  176. return
  177. }
  178. return readSeq(r, &msg.BaseLayer, &msg.Index, &msg.Length, &msg.ProofLayers)
  179. }
  180. func readSeq(r io.Reader, data ...any) (err error) {
  181. for _, d := range data {
  182. err = binary.Read(r, binary.BigEndian, d)
  183. if err != nil {
  184. return
  185. }
  186. }
  187. return
  188. }
  189. func unmarshalBitfield(b []byte) (bf []bool) {
  190. for _, c := range b {
  191. for i := 7; i >= 0; i-- {
  192. bf = append(bf, (c>>uint(i))&1 == 1)
  193. }
  194. }
  195. return
  196. }