| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388 |
- package engineio
- import (
- "encoding/json"
- "errors"
- "fmt"
- "io"
- "net/http"
- "sync"
- "time"
- "github.com/googollee/go-engine.io/message"
- "github.com/googollee/go-engine.io/parser"
- "github.com/googollee/go-engine.io/transport"
- )
- type MessageType message.MessageType
- const (
- MessageBinary MessageType = MessageType(message.MessageBinary)
- MessageText MessageType = MessageType(message.MessageText)
- )
- // Conn is the connection object of engine.io.
- type Conn interface {
- // Id returns the session id of connection.
- Id() string
- // Request returns the first http request when established connection.
- Request() *http.Request
- // Close closes the connection.
- Close() error
- // NextReader returns the next message type, reader. If no message received, it will block.
- NextReader() (MessageType, io.ReadCloser, error)
- // NextWriter returns the next message writer with given message type.
- NextWriter(messageType MessageType) (io.WriteCloser, error)
- }
- type transportCreaters map[string]transport.Creater
- func (c transportCreaters) Get(name string) transport.Creater {
- return c[name]
- }
- type serverCallback interface {
- configure() config
- transports() transportCreaters
- onClose(sid string)
- }
// state is the lifecycle phase of a serverConn.
type state int

// Lifecycle phases, numbered consecutively from zero.
const (
	// stateUnknow is the zero value.
	stateUnknow state = iota
	// stateNormal is the phase a connection starts in and returns to after an
	// upgrade completed.
	stateNormal
	// stateUpgrading is set while a second transport is being attached.
	stateUpgrading
	// stateClosing is set by Close after the current transport was closed.
	stateClosing
	// stateClosed is set by OnClose when the current transport is gone.
	stateClosed
)
- type serverConn struct {
- id string
- request *http.Request
- callback serverCallback
- writerLocker sync.Mutex
- transportLocker sync.RWMutex
- currentName string
- current transport.Server
- upgradingName string
- upgrading transport.Server
- state state
- stateLocker sync.RWMutex
- readerChan chan *connReader
- pingTimeout time.Duration
- pingInterval time.Duration
- pingChan chan bool
- pingLocker sync.Mutex
- }
// InvalidError is returned by newServerConn when the request asks for a
// transport for which no creater is registered.
var InvalidError = errors.New("invalid transport")
- func newServerConn(id string, w http.ResponseWriter, r *http.Request, callback serverCallback) (*serverConn, error) {
- transportName := r.URL.Query().Get("transport")
- creater := callback.transports().Get(transportName)
- if creater.Name == "" {
- return nil, InvalidError
- }
- ret := &serverConn{
- id: id,
- request: r,
- callback: callback,
- state: stateNormal,
- readerChan: make(chan *connReader),
- pingTimeout: callback.configure().PingTimeout,
- pingInterval: callback.configure().PingInterval,
- pingChan: make(chan bool),
- }
- transport, err := creater.Server(w, r, ret)
- if err != nil {
- return nil, err
- }
- ret.setCurrent(transportName, transport)
- if err := ret.onOpen(); err != nil {
- return nil, err
- }
- go ret.pingLoop()
- return ret, nil
- }
- func (c *serverConn) Id() string {
- return c.id
- }
- func (c *serverConn) Request() *http.Request {
- return c.request
- }
- func (c *serverConn) NextReader() (MessageType, io.ReadCloser, error) {
- if c.getState() == stateClosed {
- return MessageBinary, nil, io.EOF
- }
- ret := <-c.readerChan
- if ret == nil {
- return MessageBinary, nil, io.EOF
- }
- return MessageType(ret.MessageType()), ret, nil
- }
- func (c *serverConn) NextWriter(t MessageType) (io.WriteCloser, error) {
- switch c.getState() {
- case stateUpgrading:
- for i := 0; i < 30; i++ {
- time.Sleep(50 * time.Millisecond)
- if c.getState() != stateUpgrading {
- break
- }
- }
- if c.getState() == stateUpgrading {
- return nil, fmt.Errorf("upgrading")
- }
- case stateNormal:
- default:
- return nil, io.EOF
- }
- c.writerLocker.Lock()
- ret, err := c.getCurrent().NextWriter(message.MessageType(t), parser.MESSAGE)
- if err != nil {
- c.writerLocker.Unlock()
- return ret, err
- }
- writer := newConnWriter(ret, &c.writerLocker)
- return writer, err
- }
- func (c *serverConn) Close() error {
- if c.getState() != stateNormal && c.getState() != stateUpgrading {
- return nil
- }
- if c.upgrading != nil {
- c.upgrading.Close()
- }
- c.writerLocker.Lock()
- if w, err := c.getCurrent().NextWriter(message.MessageText, parser.CLOSE); err == nil {
- writer := newConnWriter(w, &c.writerLocker)
- writer.Close()
- } else {
- c.writerLocker.Unlock()
- }
- if err := c.getCurrent().Close(); err != nil {
- return err
- }
- c.setState(stateClosing)
- return nil
- }
- func (c *serverConn) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- transportName := r.URL.Query().Get("transport")
- if c.currentName != transportName {
- creater := c.callback.transports().Get(transportName)
- if creater.Name == "" {
- http.Error(w, fmt.Sprintf("invalid transport %s", transportName), http.StatusBadRequest)
- return
- }
- u, err := creater.Server(w, r, c)
- if err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
- c.setUpgrading(creater.Name, u)
- return
- }
- c.current.ServeHTTP(w, r)
- }
- func (c *serverConn) OnPacket(r *parser.PacketDecoder) {
- if s := c.getState(); s != stateNormal && s != stateUpgrading {
- return
- }
- switch r.Type() {
- case parser.OPEN:
- case parser.CLOSE:
- c.getCurrent().Close()
- case parser.PING:
- c.writerLocker.Lock()
- t := c.getCurrent()
- u := c.getUpgrade()
- newWriter := t.NextWriter
- if u != nil {
- if w, _ := t.NextWriter(message.MessageText, parser.NOOP); w != nil {
- w.Close()
- }
- newWriter = u.NextWriter
- }
- if w, _ := newWriter(message.MessageText, parser.PONG); w != nil {
- io.Copy(w, r)
- w.Close()
- }
- c.writerLocker.Unlock()
- fallthrough
- case parser.PONG:
- c.pingLocker.Lock()
- defer c.pingLocker.Unlock()
- if s := c.getState(); s != stateNormal && s != stateUpgrading {
- return
- }
- c.pingChan <- true
- case parser.MESSAGE:
- closeChan := make(chan struct{})
- c.readerChan <- newConnReader(r, closeChan)
- <-closeChan
- close(closeChan)
- r.Close()
- case parser.UPGRADE:
- c.upgraded()
- case parser.NOOP:
- }
- }
- func (c *serverConn) OnClose(server transport.Server) {
- if t := c.getUpgrade(); server == t {
- c.setUpgrading("", nil)
- t.Close()
- return
- }
- t := c.getCurrent()
- if server != t {
- return
- }
- t.Close()
- if t := c.getUpgrade(); t != nil {
- t.Close()
- c.setUpgrading("", nil)
- }
- c.setState(stateClosed)
- close(c.readerChan)
- c.pingLocker.Lock()
- close(c.pingChan)
- c.pingLocker.Unlock()
- c.callback.onClose(c.id)
- }
- func (s *serverConn) onOpen() error {
- upgrades := []string{}
- for name := range s.callback.transports() {
- if name == s.currentName {
- continue
- }
- upgrades = append(upgrades, name)
- }
- type connectionInfo struct {
- Sid string `json:"sid"`
- Upgrades []string `json:"upgrades"`
- PingInterval time.Duration `json:"pingInterval"`
- PingTimeout time.Duration `json:"pingTimeout"`
- }
- resp := connectionInfo{
- Sid: s.Id(),
- Upgrades: upgrades,
- PingInterval: s.callback.configure().PingInterval / time.Millisecond,
- PingTimeout: s.callback.configure().PingTimeout / time.Millisecond,
- }
- w, err := s.getCurrent().NextWriter(message.MessageText, parser.OPEN)
- if err != nil {
- return err
- }
- encoder := json.NewEncoder(w)
- if err := encoder.Encode(resp); err != nil {
- return err
- }
- if err := w.Close(); err != nil {
- return err
- }
- return nil
- }
- func (c *serverConn) getCurrent() transport.Server {
- c.transportLocker.RLock()
- defer c.transportLocker.RUnlock()
- return c.current
- }
- func (c *serverConn) getUpgrade() transport.Server {
- c.transportLocker.RLock()
- defer c.transportLocker.RUnlock()
- return c.upgrading
- }
- func (c *serverConn) setCurrent(name string, s transport.Server) {
- c.transportLocker.Lock()
- defer c.transportLocker.Unlock()
- c.currentName = name
- c.current = s
- }
- func (c *serverConn) setUpgrading(name string, s transport.Server) {
- c.transportLocker.Lock()
- defer c.transportLocker.Unlock()
- c.upgradingName = name
- c.upgrading = s
- c.setState(stateUpgrading)
- }
- func (c *serverConn) upgraded() {
- c.transportLocker.Lock()
- current := c.current
- c.current = c.upgrading
- c.currentName = c.upgradingName
- c.upgrading = nil
- c.upgradingName = ""
- c.transportLocker.Unlock()
- current.Close()
- c.setState(stateNormal)
- }
- func (c *serverConn) getState() state {
- c.stateLocker.RLock()
- defer c.stateLocker.RUnlock()
- return c.state
- }
- func (c *serverConn) setState(state state) {
- c.stateLocker.Lock()
- defer c.stateLocker.Unlock()
- c.state = state
- }
- func (c *serverConn) pingLoop() {
- lastPing := time.Now()
- lastTry := lastPing
- for {
- now := time.Now()
- pingDiff := now.Sub(lastPing)
- tryDiff := now.Sub(lastTry)
- select {
- case ok := <-c.pingChan:
- if !ok {
- return
- }
- lastPing = time.Now()
- lastTry = lastPing
- case <-time.After(c.pingInterval - tryDiff):
- c.writerLocker.Lock()
- if w, _ := c.getCurrent().NextWriter(message.MessageText, parser.PING); w != nil {
- writer := newConnWriter(w, &c.writerLocker)
- writer.Close()
- } else {
- c.writerLocker.Unlock()
- }
- lastTry = time.Now()
- case <-time.After(c.pingTimeout - pingDiff):
- c.Close()
- return
- }
- }
- }
|