httpstream_test.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
// Copyright 2019 Yunion
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/*
Copyright 2015 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
  26. package httpstream
  27. import (
  28. "net/http"
  29. "reflect"
  30. "testing"
  31. )
  32. type responseWriter struct {
  33. header http.Header
  34. statusCode *int
  35. }
  36. func newResponseWriter() *responseWriter {
  37. return &responseWriter{
  38. header: make(http.Header),
  39. }
  40. }
  41. func (r *responseWriter) Header() http.Header {
  42. return r.header
  43. }
  44. func (r *responseWriter) WriteHeader(code int) {
  45. r.statusCode = &code
  46. }
  47. func (r *responseWriter) Write([]byte) (int, error) {
  48. return 0, nil
  49. }
  50. func TestHandshake(t *testing.T) {
  51. tests := map[string]struct {
  52. clientProtocols []string
  53. serverProtocols []string
  54. expectedProtocol string
  55. expectError bool
  56. }{
  57. "no client protocols": {
  58. clientProtocols: []string{},
  59. serverProtocols: []string{"a", "b"},
  60. expectedProtocol: "",
  61. },
  62. "no common protocol": {
  63. clientProtocols: []string{"c"},
  64. serverProtocols: []string{"a", "b"},
  65. expectedProtocol: "",
  66. expectError: true,
  67. },
  68. "common protocol": {
  69. clientProtocols: []string{"b"},
  70. serverProtocols: []string{"a", "b"},
  71. expectedProtocol: "b",
  72. },
  73. }
  74. for name, test := range tests {
  75. req, err := http.NewRequest("GET", "http://www.example.com/", nil)
  76. if err != nil {
  77. t.Fatalf("%s: error creating request: %v", name, err)
  78. }
  79. for _, p := range test.clientProtocols {
  80. req.Header.Add(HeaderProtocolVersion, p)
  81. }
  82. w := newResponseWriter()
  83. negotiated, err := Handshake(req, w, test.serverProtocols)
  84. // verify negotiated protocol
  85. if e, a := test.expectedProtocol, negotiated; e != a {
  86. t.Errorf("%s: protocol: expected %q, got %q", name, e, a)
  87. }
  88. if test.expectError {
  89. if err == nil {
  90. t.Errorf("%s: expected error but did not get one", name)
  91. }
  92. if w.statusCode == nil {
  93. t.Errorf("%s: expected w.statusCode to be set", name)
  94. } else if e, a := http.StatusForbidden, *w.statusCode; e != a {
  95. t.Errorf("%s: w.statusCode: expected %d, got %d", name, e, a)
  96. }
  97. if e, a := test.serverProtocols, w.Header()[HeaderAcceptedProtocolVersions]; !reflect.DeepEqual(e, a) {
  98. t.Errorf("%s: accepted server protocols: expected %v, got %v", name, e, a)
  99. }
  100. continue
  101. }
  102. if !test.expectError && err != nil {
  103. t.Errorf("%s: unexpected error: %v", name, err)
  104. continue
  105. }
  106. if w.statusCode != nil {
  107. t.Errorf("%s: unexpected non-nil w.statusCode: %d", name, w.statusCode)
  108. }
  109. if len(test.expectedProtocol) == 0 {
  110. if len(w.Header()[HeaderProtocolVersion]) > 0 {
  111. t.Errorf("%s: unexpected protocol version response header: %s", name, w.Header()[HeaderProtocolVersion])
  112. }
  113. continue
  114. }
  115. // verify response headers
  116. if e, a := []string{test.expectedProtocol}, w.Header()[HeaderProtocolVersion]; !reflect.DeepEqual(e, a) {
  117. t.Errorf("%s: protocol response header: expected %v, got %v", name, e, a)
  118. }
  119. }
  120. }