server.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. package engineio
  import (
	"bytes"
	"crypto/md5"
	"crypto/rand"
	"encoding/base64"
	"fmt"
	"net/http"
	"sync/atomic"
	"time"

	"github.com/googollee/go-engine.io/polling"
	"github.com/googollee/go-engine.io/websocket"
)
  // config holds the tunable settings of a Server. NewServer fills in the
// defaults and the Set* methods on Server overwrite individual fields.
type config struct {
	// PingTimeout and PingInterval are the ping settings (see SetPingTimeout
	// and SetPingInterval).
	PingTimeout  time.Duration
	PingInterval time.Duration
	// MaxConnection caps the number of connections ServeHTTP admits.
	MaxConnection int
	// AllowRequest vets every handshake; a non-nil error rejects it.
	AllowRequest func(*http.Request) error
	// AllowUpgrades records whether transport upgrades are allowed.
	AllowUpgrades bool
	// Cookie is the name of the cookie that carries the session id.
	Cookie string
	// NewId generates the session id of each new connection.
	NewId func(*http.Request) string
}
  22. // Server is the server of engine.io.
  23. type Server struct {
  24. config config
  25. socketChan chan Conn
  26. serverSessions Sessions
  27. creaters transportCreaters
  28. currentConnection int32
  29. }
  30. // NewServer returns the server suppported given transports. If transports is nil, server will use ["polling", "websocket"] as default.
  31. func NewServer(transports []string) (*Server, error) {
  32. if transports == nil {
  33. transports = []string{"polling", "websocket"}
  34. }
  35. creaters := make(transportCreaters)
  36. for _, t := range transports {
  37. switch t {
  38. case "polling":
  39. creaters[t] = polling.Creater
  40. case "websocket":
  41. creaters[t] = websocket.Creater
  42. default:
  43. return nil, InvalidError
  44. }
  45. }
  46. return &Server{
  47. config: config{
  48. PingTimeout: 60000 * time.Millisecond,
  49. PingInterval: 25000 * time.Millisecond,
  50. MaxConnection: 1000,
  51. AllowRequest: func(*http.Request) error { return nil },
  52. AllowUpgrades: true,
  53. Cookie: "io",
  54. NewId: newId,
  55. },
  56. socketChan: make(chan Conn),
  57. serverSessions: newServerSessions(),
  58. creaters: creaters,
  59. }, nil
  60. }
  61. // SetPingTimeout sets the timeout of ping. When time out, server will close connection. Default is 60s.
  62. func (s *Server) SetPingTimeout(t time.Duration) {
  63. s.config.PingTimeout = t
  64. }
  65. // SetPingInterval sets the interval of ping. Default is 25s.
  66. func (s *Server) SetPingInterval(t time.Duration) {
  67. s.config.PingInterval = t
  68. }
  69. // SetMaxConnection sets the max connetion. Default is 1000.
  70. func (s *Server) SetMaxConnection(n int) {
  71. s.config.MaxConnection = n
  72. }
  73. // GetMaxConnection returns the current max connection
  74. func (s *Server) GetMaxConnection() int {
  75. return s.config.MaxConnection
  76. }
  77. // Count returns a count of current number of active connections in session
  78. func (s *Server) Count() int {
  79. return int(atomic.LoadInt32(&s.currentConnection))
  80. }
  81. // SetAllowRequest sets the middleware function when establish connection. If it return non-nil, connection won't be established. Default will allow all request.
  82. func (s *Server) SetAllowRequest(f func(*http.Request) error) {
  83. s.config.AllowRequest = f
  84. }
  85. // SetAllowUpgrades sets whether server allows transport upgrade. Default is true.
  86. func (s *Server) SetAllowUpgrades(allow bool) {
  87. s.config.AllowUpgrades = allow
  88. }
  89. // SetCookie sets the name of cookie which used by engine.io. Default is "io".
  90. func (s *Server) SetCookie(prefix string) {
  91. s.config.Cookie = prefix
  92. }
  93. // SetNewId sets the callback func to generate new connection id. By default, id is generated from remote addr + current time stamp
  94. func (s *Server) SetNewId(f func(*http.Request) string) {
  95. s.config.NewId = f
  96. }
  97. // SetSessionManager sets the sessions as server's session manager. Default sessions is single process manager. You can custom it as load balance.
  98. func (s *Server) SetSessionManager(sessions Sessions) {
  99. s.serverSessions = sessions
  100. }
  101. // ServeHTTP handles http request.
  102. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  103. defer r.Body.Close()
  104. sid := r.URL.Query().Get("sid")
  105. conn := s.serverSessions.Get(sid)
  106. if conn == nil {
  107. if sid != "" {
  108. http.Error(w, "invalid sid", http.StatusBadRequest)
  109. return
  110. }
  111. if err := s.config.AllowRequest(r); err != nil {
  112. http.Error(w, err.Error(), http.StatusBadRequest)
  113. return
  114. }
  115. n := atomic.AddInt32(&s.currentConnection, 1)
  116. if int(n) > s.config.MaxConnection {
  117. atomic.AddInt32(&s.currentConnection, -1)
  118. http.Error(w, "too many connections", http.StatusServiceUnavailable)
  119. return
  120. }
  121. sid = s.config.NewId(r)
  122. var err error
  123. conn, err = newServerConn(sid, w, r, s)
  124. if err != nil {
  125. atomic.AddInt32(&s.currentConnection, -1)
  126. http.Error(w, err.Error(), http.StatusBadRequest)
  127. return
  128. }
  129. s.serverSessions.Set(sid, conn)
  130. s.socketChan <- conn
  131. }
  132. http.SetCookie(w, &http.Cookie{
  133. Name: s.config.Cookie,
  134. Value: sid,
  135. })
  136. conn.(*serverConn).ServeHTTP(w, r)
  137. }
  138. // Accept returns Conn when client connect to server.
  139. func (s *Server) Accept() (Conn, error) {
  140. return <-s.socketChan, nil
  141. }
  142. func (s *Server) configure() config {
  143. return s.config
  144. }
  145. func (s *Server) transports() transportCreaters {
  146. return s.creaters
  147. }
  148. func (s *Server) onClose(id string) {
  149. s.serverSessions.Remove(id)
  150. atomic.AddInt32(&s.currentConnection, -1)
  151. }
  // newId is the default session-id generator.
//
// It hashes the client's remote address, the current time stamp and 16 bytes
// from crypto/rand with MD5. It returns the first 20 characters of the URL-safe
// base64 encoding of the digest (no padding falls within those 20 characters).
// The random bytes make the id unpredictable. They also stop two handshakes
// from the same address within one clock tick from colliding, which the
// address and time stamp alone did not guarantee.
func newId(r *http.Request) string {
	var nonce [16]byte
	// rand.Read can only report an error on Go versions before 1.24. If it
	// ever does, the nonce stays zeroed and the id degrades to the
	// address+time scheme instead of failing the handshake.
	_, _ = rand.Read(nonce[:])
	hash := fmt.Sprintf("%s %s %x", r.RemoteAddr, time.Now(), nonce)
	sum := md5.Sum([]byte(hash))
	buf := bytes.NewBuffer(nil)
	encoder := base64.NewEncoder(base64.URLEncoding, buf)
	// Writes to a bytes.Buffer cannot fail, so these errors are not checked.
	encoder.Write(sum[:])
	encoder.Close()
	return buf.String()[:20]
}
  158. encoder.Close()
  159. return buf.String()[:20]
  160. }