packet.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. package sctp
import (
	"encoding/binary"
	"errors"
	"fmt"
	"hash/crc32"
	"strings"
)
  8. // Create the crc32 table we'll use for the checksum
  9. var castagnoliTable = crc32.MakeTable(crc32.Castagnoli) // nolint:gochecknoglobals
  10. // Allocate and zero this data once.
  11. // We need to use it for the checksum and don't want to allocate/clear each time.
  12. var fourZeroes [4]byte // nolint:gochecknoglobals
  13. /*
  14. Packet represents an SCTP packet, defined in https://tools.ietf.org/html/rfc4960#section-3
  15. An SCTP packet is composed of a common header and chunks. A chunk
  16. contains either control information or user data.
  17. SCTP Packet Format
  18. 0 1 2 3
  19. 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
  20. +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  21. | Common Header |
  22. +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  23. | Chunk #1 |
  24. +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  25. | ... |
  26. +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  27. | Chunk #n |
  28. +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  29. SCTP Common Header Format
  30. 0 1 2 3
  31. 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
  32. +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  33. | Source Value Number | Destination Value Number |
  34. +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  35. | Verification Tag |
  36. +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  37. | Checksum |
  38. +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
  39. */
  40. type packet struct {
  41. sourcePort uint16
  42. destinationPort uint16
  43. verificationTag uint32
  44. chunks []chunk
  45. }
  46. const (
  47. packetHeaderSize = 12
  48. )
  49. var (
  50. errPacketRawTooSmall = errors.New("raw is smaller than the minimum length for a SCTP packet")
  51. errParseSCTPChunkNotEnoughData = errors.New("unable to parse SCTP chunk, not enough data for complete header")
  52. errUnmarshalUnknownChunkType = errors.New("failed to unmarshal, contains unknown chunk type")
  53. errChecksumMismatch = errors.New("checksum mismatch theirs")
  54. )
  55. func (p *packet) unmarshal(raw []byte) error {
  56. if len(raw) < packetHeaderSize {
  57. return fmt.Errorf("%w: raw only %d bytes, %d is the minimum length", errPacketRawTooSmall, len(raw), packetHeaderSize)
  58. }
  59. p.sourcePort = binary.BigEndian.Uint16(raw[0:])
  60. p.destinationPort = binary.BigEndian.Uint16(raw[2:])
  61. p.verificationTag = binary.BigEndian.Uint32(raw[4:])
  62. offset := packetHeaderSize
  63. for {
  64. // Exact match, no more chunks
  65. if offset == len(raw) {
  66. break
  67. } else if offset+chunkHeaderSize > len(raw) {
  68. return fmt.Errorf("%w: offset %d remaining %d", errParseSCTPChunkNotEnoughData, offset, len(raw))
  69. }
  70. var c chunk
  71. switch chunkType(raw[offset]) {
  72. case ctInit:
  73. c = &chunkInit{}
  74. case ctInitAck:
  75. c = &chunkInitAck{}
  76. case ctAbort:
  77. c = &chunkAbort{}
  78. case ctCookieEcho:
  79. c = &chunkCookieEcho{}
  80. case ctCookieAck:
  81. c = &chunkCookieAck{}
  82. case ctHeartbeat:
  83. c = &chunkHeartbeat{}
  84. case ctPayloadData:
  85. c = &chunkPayloadData{}
  86. case ctSack:
  87. c = &chunkSelectiveAck{}
  88. case ctReconfig:
  89. c = &chunkReconfig{}
  90. case ctForwardTSN:
  91. c = &chunkForwardTSN{}
  92. case ctError:
  93. c = &chunkError{}
  94. case ctShutdown:
  95. c = &chunkShutdown{}
  96. case ctShutdownAck:
  97. c = &chunkShutdownAck{}
  98. case ctShutdownComplete:
  99. c = &chunkShutdownComplete{}
  100. default:
  101. return fmt.Errorf("%w: %s", errUnmarshalUnknownChunkType, chunkType(raw[offset]).String())
  102. }
  103. if err := c.unmarshal(raw[offset:]); err != nil {
  104. return err
  105. }
  106. p.chunks = append(p.chunks, c)
  107. chunkValuePadding := getPadding(c.valueLength())
  108. offset += chunkHeaderSize + c.valueLength() + chunkValuePadding
  109. }
  110. theirChecksum := binary.LittleEndian.Uint32(raw[8:])
  111. ourChecksum := generatePacketChecksum(raw)
  112. if theirChecksum != ourChecksum {
  113. return fmt.Errorf("%w: %d ours: %d", errChecksumMismatch, theirChecksum, ourChecksum)
  114. }
  115. return nil
  116. }
  117. func (p *packet) marshal() ([]byte, error) {
  118. raw := make([]byte, packetHeaderSize)
  119. // Populate static headers
  120. // 8-12 is Checksum which will be populated when packet is complete
  121. binary.BigEndian.PutUint16(raw[0:], p.sourcePort)
  122. binary.BigEndian.PutUint16(raw[2:], p.destinationPort)
  123. binary.BigEndian.PutUint32(raw[4:], p.verificationTag)
  124. // Populate chunks
  125. for _, c := range p.chunks {
  126. chunkRaw, err := c.marshal()
  127. if err != nil {
  128. return nil, err
  129. }
  130. raw = append(raw, chunkRaw...)
  131. paddingNeeded := getPadding(len(raw))
  132. if paddingNeeded != 0 {
  133. raw = append(raw, make([]byte, paddingNeeded)...)
  134. }
  135. }
  136. // Checksum is already in BigEndian
  137. // Using LittleEndian.PutUint32 stops it from being flipped
  138. binary.LittleEndian.PutUint32(raw[8:], generatePacketChecksum(raw))
  139. return raw, nil
  140. }
  141. func generatePacketChecksum(raw []byte) (sum uint32) {
  142. // Fastest way to do a crc32 without allocating.
  143. sum = crc32.Update(sum, castagnoliTable, raw[0:8])
  144. sum = crc32.Update(sum, castagnoliTable, fourZeroes[:])
  145. sum = crc32.Update(sum, castagnoliTable, raw[12:])
  146. return sum
  147. }
  148. // String makes packet printable
  149. func (p *packet) String() string {
  150. format := `Packet:
  151. sourcePort: %d
  152. destinationPort: %d
  153. verificationTag: %d
  154. `
  155. res := fmt.Sprintf(format,
  156. p.sourcePort,
  157. p.destinationPort,
  158. p.verificationTag,
  159. )
  160. for i, chunk := range p.chunks {
  161. res += fmt.Sprintf("Chunk %d:\n %s", i, chunk)
  162. }
  163. return res
  164. }