| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213 |
- package peer_protocol
- import (
- "bufio"
- "encoding/binary"
- "fmt"
- "io"
- "sync"
- g "github.com/anacrolix/generics"
- "github.com/pkg/errors"
- )
- type Decoder struct {
- R *bufio.Reader
- // This must return *[]byte where the slices can fit data for piece messages. I think we store
- // *[]byte in the pool to avoid an extra allocation every time we put the slice back into the
- // pool. The chunk size should not change for the life of the decoder.
- Pool *sync.Pool
- MaxLength Integer // TODO: Should this include the length header or not?
- }
// decodeReader confines reads to the bytes of a single message. Once the message bytes are used
// up it reports io.EOF. Callers that do not expect io.EOF should probably wrap it in an
// expectReader.
type decodeReader struct {
	// lr tracks how many bytes of the message are left and serves bulk reads.
	lr io.LimitedReader
	// br is the same underlying reader as lr.R, kept so single bytes can be read directly.
	br *bufio.Reader
}

// Init points the reader at r and allows length more bytes to be read from it.
func (dr *decodeReader) Init(r *bufio.Reader, length int64) {
	dr.br = r
	dr.lr = io.LimitedReader{R: r, N: length}
}

// ReadByte returns the next byte of the message, or io.EOF once no message bytes remain. The
// remaining count only drops when the underlying read succeeds.
func (dr *decodeReader) ReadByte() (byte, error) {
	if dr.lr.N <= 0 {
		return 0, io.EOF
	}
	b, err := dr.br.ReadByte()
	if err != nil {
		return b, err
	}
	dr.lr.N--
	return b, nil
}

// Read fills p from the message. An io.EOF that arrives while message bytes are still owed means
// the source ended early, so it is reported as io.ErrUnexpectedEOF.
func (dr *decodeReader) Read(p []byte) (int, error) {
	n, err := dr.lr.Read(p)
	if err == io.EOF && dr.lr.N != 0 {
		return n, io.ErrUnexpectedEOF
	}
	return n, err
}

// UnreadLength reports how many bytes of the message have not been read yet.
func (dr *decodeReader) UnreadLength() int64 {
	return dr.lr.N
}
- // This expects reads to have enough bytes. io.EOF is mapped to io.ErrUnexpectedEOF. It's probably
- // not a good idea to pass this to functions that expect to read until the end of something, because
- // they will probably expect io.EOF.
- type expectReader struct {
- dr *decodeReader
- }
- func (er expectReader) ReadByte() (c byte, err error) {
- c, err = er.dr.ReadByte()
- if err == io.EOF {
- err = io.ErrUnexpectedEOF
- }
- return
- }
- func (er expectReader) Read(p []byte) (n int, err error) {
- n, err = er.dr.Read(p)
- if err == io.EOF {
- err = io.ErrUnexpectedEOF
- }
- return
- }
- func (er expectReader) UnreadLength() int64 {
- return er.dr.UnreadLength()
- }
- // io.EOF is returned if the source terminates cleanly on a message boundary.
- func (d *Decoder) Decode(msg *Message) (err error) {
- var dr decodeReader
- {
- var length Integer
- err = length.Read(d.R)
- if err != nil {
- return fmt.Errorf("reading message length: %w", err)
- }
- if length > d.MaxLength {
- return errors.New("message too long")
- }
- if length == 0 {
- msg.Keepalive = true
- return
- }
- dr.Init(d.R, int64(length))
- }
- r := expectReader{&dr}
- c, err := r.ReadByte()
- if err != nil {
- return
- }
- msg.Type = MessageType(c)
- err = readMessageAfterType(msg, &r, d.Pool)
- if err != nil {
- err = fmt.Errorf("reading fields for message type %v: %w", msg.Type, err)
- return
- }
- if r.UnreadLength() != 0 {
- err = fmt.Errorf("%v unused bytes in message type %v", r.UnreadLength(), msg.Type)
- }
- return
- }
- func readMessageAfterType(msg *Message, r *expectReader, piecePool *sync.Pool) (err error) {
- switch msg.Type {
- case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
- case Have, AllowedFast, Suggest:
- err = msg.Index.Read(r)
- case Request, Cancel, Reject:
- for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
- err = data.Read(r)
- if err != nil {
- break
- }
- }
- case Bitfield:
- b := make([]byte, r.UnreadLength())
- _, err = io.ReadFull(r, b)
- msg.Bitfield = unmarshalBitfield(b)
- case Piece:
- for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
- err = pi.Read(r)
- if err != nil {
- return
- }
- }
- dataLen := r.UnreadLength()
- if piecePool == nil {
- msg.Piece = make([]byte, dataLen)
- } else {
- msg.Piece = *piecePool.Get().(*[]byte)
- if int64(cap(msg.Piece)) < dataLen {
- return errors.New("piece data longer than expected")
- }
- msg.Piece = msg.Piece[:dataLen]
- }
- _, err = io.ReadFull(r, msg.Piece)
- case Extended:
- var b byte
- b, err = r.ReadByte()
- if err != nil {
- break
- }
- msg.ExtendedID = ExtensionNumber(b)
- msg.ExtendedPayload = make([]byte, r.UnreadLength())
- _, err = io.ReadFull(r, msg.ExtendedPayload)
- case Port:
- err = binary.Read(r, binary.BigEndian, &msg.Port)
- case HashRequest, HashReject:
- err = readHashRequest(r, msg)
- case Hashes:
- err = readHashRequest(r, msg)
- numHashes := (r.UnreadLength() + 31) / 32
- g.MakeSliceWithCap(&msg.Hashes, numHashes)
- for range numHashes {
- var oneHash [32]byte
- _, err = io.ReadFull(r, oneHash[:])
- if err != nil {
- err = fmt.Errorf("error while reading hashes: %w", err)
- return
- }
- msg.Hashes = append(msg.Hashes, oneHash)
- }
- default:
- err = errors.New("unhandled message type")
- }
- return
- }
- func readHashRequest(r io.Reader, msg *Message) (err error) {
- _, err = io.ReadFull(r, msg.PiecesRoot[:])
- if err != nil {
- return
- }
- return readSeq(r, &msg.BaseLayer, &msg.Index, &msg.Length, &msg.ProofLayers)
- }
// readSeq reads each of data from r in order as a big-endian value using binary.Read, so every
// element must be something binary.Read can decode into (such as a pointer to a fixed-size
// value). It returns the first error encountered and leaves the remaining elements untouched.
func readSeq(r io.Reader, data ...any) error {
	for i := range data {
		if err := binary.Read(r, binary.BigEndian, data[i]); err != nil {
			return err
		}
	}
	return nil
}
// unmarshalBitfield expands b into one bool per bit, most significant bit of each byte first, so
// the result has 8*len(b) elements. It returns nil when b is empty.
func unmarshalBitfield(b []byte) (bf []bool) {
	if len(b) == 0 {
		return nil
	}
	// The final length is known up front, so allocate once instead of growing through append.
	bf = make([]bool, 0, len(b)*8)
	for _, c := range b {
		for shift := 7; shift >= 0; shift-- {
			bf = append(bf, (c>>uint(shift))&1 == 1)
		}
	}
	return bf
}
|