fragment_buffer.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. package dtls
  2. import (
  3. "github.com/pion/dtls/v2/pkg/protocol"
  4. "github.com/pion/dtls/v2/pkg/protocol/handshake"
  5. "github.com/pion/dtls/v2/pkg/protocol/recordlayer"
  6. )
  7. // 2 megabytes
  8. const fragmentBufferMaxSize = 2000000
// fragment is one handshake message, or one piece of a fragmented handshake
// message, as cached by fragmentBuffer.push.
type fragment struct {
	// recordLayerHeader is the DTLS record-layer header push recorded for this
	// fragment; pop reads its Epoch.
	recordLayerHeader recordlayer.Header
	// handshakeHeader is this fragment's handshake header, which carries the
	// message sequence, fragment offset/length and total message length.
	handshakeHeader handshake.Header
	// data holds this fragment's handshake body bytes with the handshake header
	// stripped; pop rebuilds a single header when reassembling.
	data []byte
}
// fragmentBuffer caches handshake fragments until a complete message can be
// reassembled and handed out, in message-sequence order, by pop.
type fragmentBuffer struct {
	// cache maps a handshake MessageSequence to the fragments received for it,
	// in the order they were pushed.
	cache map[uint16][]*fragment
	// currentMessageSequenceNumber is the message sequence pop tries to
	// reassemble next; it advances by one after each successful pop.
	currentMessageSequenceNumber uint16
}
  19. func newFragmentBuffer() *fragmentBuffer {
  20. return &fragmentBuffer{cache: map[uint16][]*fragment{}}
  21. }
  22. // current total size of buffer
  23. func (f *fragmentBuffer) size() int {
  24. size := 0
  25. for i := range f.cache {
  26. for j := range f.cache[i] {
  27. size += len(f.cache[i][j].data)
  28. }
  29. }
  30. return size
  31. }
  32. // Attempts to push a DTLS packet to the fragmentBuffer
  33. // when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled
  34. // when an error returns it is fatal, and the DTLS connection should be stopped
  35. func (f *fragmentBuffer) push(buf []byte) (bool, error) {
  36. if f.size()+len(buf) >= fragmentBufferMaxSize {
  37. return false, errFragmentBufferOverflow
  38. }
  39. frag := new(fragment)
  40. if err := frag.recordLayerHeader.Unmarshal(buf); err != nil {
  41. return false, err
  42. }
  43. // fragment isn't a handshake, we don't need to handle it
  44. if frag.recordLayerHeader.ContentType != protocol.ContentTypeHandshake {
  45. return false, nil
  46. }
  47. for buf = buf[recordlayer.HeaderSize:]; len(buf) != 0; frag = new(fragment) {
  48. if err := frag.handshakeHeader.Unmarshal(buf); err != nil {
  49. return false, err
  50. }
  51. if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok {
  52. f.cache[frag.handshakeHeader.MessageSequence] = []*fragment{}
  53. }
  54. // end index should be the length of handshake header but if the handshake
  55. // was fragmented, we should keep them all
  56. end := int(handshake.HeaderLength + frag.handshakeHeader.Length)
  57. if size := len(buf); end > size {
  58. end = size
  59. }
  60. // Discard all headers, when rebuilding the packet we will re-build
  61. frag.data = append([]byte{}, buf[handshake.HeaderLength:end]...)
  62. f.cache[frag.handshakeHeader.MessageSequence] = append(f.cache[frag.handshakeHeader.MessageSequence], frag)
  63. buf = buf[end:]
  64. }
  65. return true, nil
  66. }
  67. func (f *fragmentBuffer) pop() (content []byte, epoch uint16) {
  68. frags, ok := f.cache[f.currentMessageSequenceNumber]
  69. if !ok {
  70. return nil, 0
  71. }
  72. // Go doesn't support recursive lambdas
  73. var appendMessage func(targetOffset uint32) bool
  74. rawMessage := []byte{}
  75. appendMessage = func(targetOffset uint32) bool {
  76. for _, f := range frags {
  77. if f.handshakeHeader.FragmentOffset == targetOffset {
  78. fragmentEnd := (f.handshakeHeader.FragmentOffset + f.handshakeHeader.FragmentLength)
  79. if fragmentEnd != f.handshakeHeader.Length && f.handshakeHeader.FragmentLength != 0 {
  80. if !appendMessage(fragmentEnd) {
  81. return false
  82. }
  83. }
  84. rawMessage = append(f.data, rawMessage...)
  85. return true
  86. }
  87. }
  88. return false
  89. }
  90. // Recursively collect up
  91. if !appendMessage(0) {
  92. return nil, 0
  93. }
  94. firstHeader := frags[0].handshakeHeader
  95. firstHeader.FragmentOffset = 0
  96. firstHeader.FragmentLength = firstHeader.Length
  97. rawHeader, err := firstHeader.Marshal()
  98. if err != nil {
  99. return nil, 0
  100. }
  101. messageEpoch := frags[0].recordLayerHeader.Epoch
  102. delete(f.cache, f.currentMessageSequenceNumber)
  103. f.currentMessageSequenceNumber++
  104. return append(rawHeader, rawMessage...), messageEpoch
  105. }