| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189 |
- package engineio
- import (
- "bytes"
- "crypto/md5"
- "encoding/base64"
- "fmt"
- "net/http"
- "sync/atomic"
- "time"
- "github.com/googollee/go-engine.io/polling"
- "github.com/googollee/go-engine.io/websocket"
- )
// config holds the tunable settings of a Server. NewServer fills in the
// defaults and the Server's Set* methods overwrite individual fields.
type config struct {
	// PingTimeout and PingInterval control ping timing (see SetPingTimeout and
	// SetPingInterval); NewServer defaults them to 60s and 25s.
	PingTimeout, PingInterval time.Duration
	// MaxConnection caps the number of concurrent sessions ServeHTTP accepts.
	MaxConnection int
	// AllowRequest vets requests that open a new session; a non-nil error
	// rejects the request.
	AllowRequest func(*http.Request) error
	// AllowUpgrades reports whether transport upgrades are permitted.
	AllowUpgrades bool
	// Cookie is the name of the cookie that carries the session id.
	Cookie string
	// NewId generates the id of a new session.
	NewId func(r *http.Request) string
}
- // Server is the server of engine.io.
- type Server struct {
- config config
- socketChan chan Conn
- serverSessions Sessions
- creaters transportCreaters
- currentConnection int32
- }
- // NewServer returns the server suppported given transports. If transports is nil, server will use ["polling", "websocket"] as default.
- func NewServer(transports []string) (*Server, error) {
- if transports == nil {
- transports = []string{"polling", "websocket"}
- }
- creaters := make(transportCreaters)
- for _, t := range transports {
- switch t {
- case "polling":
- creaters[t] = polling.Creater
- case "websocket":
- creaters[t] = websocket.Creater
- default:
- return nil, InvalidError
- }
- }
- return &Server{
- config: config{
- PingTimeout: 60000 * time.Millisecond,
- PingInterval: 25000 * time.Millisecond,
- MaxConnection: 1000,
- AllowRequest: func(*http.Request) error { return nil },
- AllowUpgrades: true,
- Cookie: "io",
- NewId: newId,
- },
- socketChan: make(chan Conn),
- serverSessions: newServerSessions(),
- creaters: creaters,
- }, nil
- }
- // SetPingTimeout sets the timeout of ping. When time out, server will close connection. Default is 60s.
- func (s *Server) SetPingTimeout(t time.Duration) {
- s.config.PingTimeout = t
- }
- // SetPingInterval sets the interval of ping. Default is 25s.
- func (s *Server) SetPingInterval(t time.Duration) {
- s.config.PingInterval = t
- }
- // SetMaxConnection sets the max connetion. Default is 1000.
- func (s *Server) SetMaxConnection(n int) {
- s.config.MaxConnection = n
- }
- // GetMaxConnection returns the current max connection
- func (s *Server) GetMaxConnection() int {
- return s.config.MaxConnection
- }
- // Count returns a count of current number of active connections in session
- func (s *Server) Count() int {
- return int(atomic.LoadInt32(&s.currentConnection))
- }
- // SetAllowRequest sets the middleware function when establish connection. If it return non-nil, connection won't be established. Default will allow all request.
- func (s *Server) SetAllowRequest(f func(*http.Request) error) {
- s.config.AllowRequest = f
- }
- // SetAllowUpgrades sets whether server allows transport upgrade. Default is true.
- func (s *Server) SetAllowUpgrades(allow bool) {
- s.config.AllowUpgrades = allow
- }
- // SetCookie sets the name of cookie which used by engine.io. Default is "io".
- func (s *Server) SetCookie(prefix string) {
- s.config.Cookie = prefix
- }
- // SetNewId sets the callback func to generate new connection id. By default, id is generated from remote addr + current time stamp
- func (s *Server) SetNewId(f func(*http.Request) string) {
- s.config.NewId = f
- }
- // SetSessionManager sets the sessions as server's session manager. Default sessions is single process manager. You can custom it as load balance.
- func (s *Server) SetSessionManager(sessions Sessions) {
- s.serverSessions = sessions
- }
- // ServeHTTP handles http request.
- func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- defer r.Body.Close()
- sid := r.URL.Query().Get("sid")
- conn := s.serverSessions.Get(sid)
- if conn == nil {
- if sid != "" {
- http.Error(w, "invalid sid", http.StatusBadRequest)
- return
- }
- if err := s.config.AllowRequest(r); err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
- n := atomic.AddInt32(&s.currentConnection, 1)
- if int(n) > s.config.MaxConnection {
- atomic.AddInt32(&s.currentConnection, -1)
- http.Error(w, "too many connections", http.StatusServiceUnavailable)
- return
- }
- sid = s.config.NewId(r)
- var err error
- conn, err = newServerConn(sid, w, r, s)
- if err != nil {
- atomic.AddInt32(&s.currentConnection, -1)
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
- s.serverSessions.Set(sid, conn)
- s.socketChan <- conn
- }
- http.SetCookie(w, &http.Cookie{
- Name: s.config.Cookie,
- Value: sid,
- })
- conn.(*serverConn).ServeHTTP(w, r)
- }
- // Accept returns Conn when client connect to server.
- func (s *Server) Accept() (Conn, error) {
- return <-s.socketChan, nil
- }
- func (s *Server) configure() config {
- return s.config
- }
- func (s *Server) transports() transportCreaters {
- return s.creaters
- }
- func (s *Server) onClose(id string) {
- s.serverSessions.Remove(id)
- atomic.AddInt32(&s.currentConnection, -1)
- }
// idSeq is a process-wide counter mixed into every generated id. It
// guarantees that two calls never hash identical input, even within the same
// clock tick for the same remote address.
var idSeq uint64

// newId is the default session id generator. It hashes the request's remote
// address, the current time and a process-wide sequence number with MD5. It
// returns the first 20 characters of the URL-safe base64 encoding of the
// digest.
func newId(r *http.Request) string {
	// Clocks with coarse resolution can return the same timestamp for
	// back-to-back calls. Without the sequence number, two handshakes over the
	// same keep-alive connection could then get the same session id.
	seq := atomic.AddUint64(&idSeq, 1)
	hash := fmt.Sprintf("%s %s %d", r.RemoteAddr, time.Now(), seq)
	sum := md5.Sum([]byte(hash))
	var buf bytes.Buffer
	encoder := base64.NewEncoder(base64.URLEncoding, &buf)
	encoder.Write(sum[:]) // writes into a bytes.Buffer cannot fail
	encoder.Close()
	// 16 digest bytes encode to 24 characters ending in "=="; keep the first
	// 20, which are padding-free.
	return buf.String()[:20]
}
|