server_conn.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. package engineio
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "sync"
  9. "time"
  10. "github.com/googollee/go-engine.io/message"
  11. "github.com/googollee/go-engine.io/parser"
  12. "github.com/googollee/go-engine.io/transport"
  13. )
  14. type MessageType message.MessageType
  15. const (
  16. MessageBinary MessageType = MessageType(message.MessageBinary)
  17. MessageText MessageType = MessageType(message.MessageText)
  18. )
  19. // Conn is the connection object of engine.io.
  20. type Conn interface {
  21. // Id returns the session id of connection.
  22. Id() string
  23. // Request returns the first http request when established connection.
  24. Request() *http.Request
  25. // Close closes the connection.
  26. Close() error
  27. // NextReader returns the next message type, reader. If no message received, it will block.
  28. NextReader() (MessageType, io.ReadCloser, error)
  29. // NextWriter returns the next message writer with given message type.
  30. NextWriter(messageType MessageType) (io.WriteCloser, error)
  31. }
  32. type transportCreaters map[string]transport.Creater
  33. func (c transportCreaters) Get(name string) transport.Creater {
  34. return c[name]
  35. }
  36. type serverCallback interface {
  37. configure() config
  38. transports() transportCreaters
  39. onClose(sid string)
  40. }
  41. type state int
  42. const (
  43. stateUnknow state = iota
  44. stateNormal
  45. stateUpgrading
  46. stateClosing
  47. stateClosed
  48. )
// serverConn is the server side of one engine.io session. It holds the
// currently active transport and, while an upgrade is in progress, the
// transport being upgraded to.
type serverConn struct {
	// id is the session id returned by Id.
	id string
	// request is the HTTP request that created the session (see Request).
	request *http.Request
	// callback is the owning server; it supplies configuration and
	// transports and is notified via onClose when the session ends.
	callback serverCallback
	// writerLocker serializes packet writers. NextWriter/Close/pingLoop lock
	// it and hand it to newConnWriter; presumably connWriter.Close releases
	// it — confirm against connWriter.
	writerLocker sync.Mutex
	// transportLocker guards currentName, current, upgradingName and upgrading.
	transportLocker sync.RWMutex
	// currentName and current are the active transport and its name.
	currentName string
	current     transport.Server
	// upgradingName and upgrading are the transport being upgraded to, or
	// empty/nil when no upgrade is in progress.
	upgradingName string
	upgrading     transport.Server
	// state is the lifecycle phase; read and written through getState/setState.
	state       state
	stateLocker sync.RWMutex
	// readerChan hands message readers from OnPacket to NextReader. It is
	// closed by OnClose, after which NextReader returns io.EOF.
	readerChan chan *connReader
	// pingTimeout and pingInterval are copied from the server configuration
	// in newServerConn and drive pingLoop.
	pingTimeout  time.Duration
	pingInterval time.Duration
	// pingChan carries heartbeat signals from OnPacket to pingLoop. It is
	// closed by OnClose, which makes pingLoop return.
	pingChan chan bool
	// pingLocker serializes sends on pingChan against its close in OnClose.
	pingLocker sync.Mutex
}
  67. var InvalidError = errors.New("invalid transport")
  68. func newServerConn(id string, w http.ResponseWriter, r *http.Request, callback serverCallback) (*serverConn, error) {
  69. transportName := r.URL.Query().Get("transport")
  70. creater := callback.transports().Get(transportName)
  71. if creater.Name == "" {
  72. return nil, InvalidError
  73. }
  74. ret := &serverConn{
  75. id: id,
  76. request: r,
  77. callback: callback,
  78. state: stateNormal,
  79. readerChan: make(chan *connReader),
  80. pingTimeout: callback.configure().PingTimeout,
  81. pingInterval: callback.configure().PingInterval,
  82. pingChan: make(chan bool),
  83. }
  84. transport, err := creater.Server(w, r, ret)
  85. if err != nil {
  86. return nil, err
  87. }
  88. ret.setCurrent(transportName, transport)
  89. if err := ret.onOpen(); err != nil {
  90. return nil, err
  91. }
  92. go ret.pingLoop()
  93. return ret, nil
  94. }
  95. func (c *serverConn) Id() string {
  96. return c.id
  97. }
  98. func (c *serverConn) Request() *http.Request {
  99. return c.request
  100. }
  101. func (c *serverConn) NextReader() (MessageType, io.ReadCloser, error) {
  102. if c.getState() == stateClosed {
  103. return MessageBinary, nil, io.EOF
  104. }
  105. ret := <-c.readerChan
  106. if ret == nil {
  107. return MessageBinary, nil, io.EOF
  108. }
  109. return MessageType(ret.MessageType()), ret, nil
  110. }
  111. func (c *serverConn) NextWriter(t MessageType) (io.WriteCloser, error) {
  112. switch c.getState() {
  113. case stateUpgrading:
  114. for i := 0; i < 30; i++ {
  115. time.Sleep(50 * time.Millisecond)
  116. if c.getState() != stateUpgrading {
  117. break
  118. }
  119. }
  120. if c.getState() == stateUpgrading {
  121. return nil, fmt.Errorf("upgrading")
  122. }
  123. case stateNormal:
  124. default:
  125. return nil, io.EOF
  126. }
  127. c.writerLocker.Lock()
  128. ret, err := c.getCurrent().NextWriter(message.MessageType(t), parser.MESSAGE)
  129. if err != nil {
  130. c.writerLocker.Unlock()
  131. return ret, err
  132. }
  133. writer := newConnWriter(ret, &c.writerLocker)
  134. return writer, err
  135. }
  136. func (c *serverConn) Close() error {
  137. if c.getState() != stateNormal && c.getState() != stateUpgrading {
  138. return nil
  139. }
  140. if c.upgrading != nil {
  141. c.upgrading.Close()
  142. }
  143. c.writerLocker.Lock()
  144. if w, err := c.getCurrent().NextWriter(message.MessageText, parser.CLOSE); err == nil {
  145. writer := newConnWriter(w, &c.writerLocker)
  146. writer.Close()
  147. } else {
  148. c.writerLocker.Unlock()
  149. }
  150. if err := c.getCurrent().Close(); err != nil {
  151. return err
  152. }
  153. c.setState(stateClosing)
  154. return nil
  155. }
  156. func (c *serverConn) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  157. transportName := r.URL.Query().Get("transport")
  158. if c.currentName != transportName {
  159. creater := c.callback.transports().Get(transportName)
  160. if creater.Name == "" {
  161. http.Error(w, fmt.Sprintf("invalid transport %s", transportName), http.StatusBadRequest)
  162. return
  163. }
  164. u, err := creater.Server(w, r, c)
  165. if err != nil {
  166. http.Error(w, err.Error(), http.StatusBadRequest)
  167. return
  168. }
  169. c.setUpgrading(creater.Name, u)
  170. return
  171. }
  172. c.current.ServeHTTP(w, r)
  173. }
// OnPacket handles one decoded packet received on a transport. Transports
// are given c as their callback in newServerConn/ServeHTTP, so presumably
// they call this. Packets are ignored unless the connection is normal or
// upgrading.
//
//   - CLOSE closes the current transport.
//   - PING is answered with a PONG echoing its payload, then also counts as
//     a heartbeat (it falls through to the PONG case).
//   - PONG signals pingLoop through pingChan.
//   - MESSAGE hands a reader to NextReader and blocks until it is done.
//   - UPGRADE completes an upgrade via upgraded.
//   - OPEN and NOOP are ignored.
//
// NOTE(review): the PONG case sends on the unbuffered pingChan while holding
// pingLocker. If pingLoop is not receiving (e.g. it is inside Close after a
// timeout), this send blocks, and OnClose then blocks acquiring pingLocker.
// NOTE(review): the state check at the top is not atomic with the send on
// readerChan; if OnClose closes readerChan in between, the send would panic.
func (c *serverConn) OnPacket(r *parser.PacketDecoder) {
	if s := c.getState(); s != stateNormal && s != stateUpgrading {
		return
	}
	switch r.Type() {
	case parser.OPEN:
	case parser.CLOSE:
		c.getCurrent().Close()
	case parser.PING:
		// All writes below happen under writerLocker so they do not interleave
		// with other packet writers.
		c.writerLocker.Lock()
		t := c.getCurrent()
		u := c.getUpgrade()
		newWriter := t.NextWriter
		if u != nil {
			// During an upgrade a NOOP is written on the current transport and
			// the PONG goes out on the upgrading one (presumably to release a
			// pending request on the old transport — confirm).
			if w, _ := t.NextWriter(message.MessageText, parser.NOOP); w != nil {
				w.Close()
			}
			newWriter = u.NextWriter
		}
		// The PONG echoes the PING payload.
		if w, _ := newWriter(message.MessageText, parser.PONG); w != nil {
			io.Copy(w, r)
			w.Close()
		}
		c.writerLocker.Unlock()
		// A PING from the peer also counts as a heartbeat for pingLoop.
		fallthrough
	case parser.PONG:
		// pingLocker plus the state re-check keep this send from racing with
		// close(c.pingChan) in OnClose (which sets stateClosed first).
		c.pingLocker.Lock()
		defer c.pingLocker.Unlock()
		if s := c.getState(); s != stateNormal && s != stateUpgrading {
			return
		}
		c.pingChan <- true
	case parser.MESSAGE:
		// Block until a NextReader caller takes the reader, then until a
		// value arrives on closeChan. Presumably connReader sends it when the
		// reader is closed — confirm against connReader, since the close below
		// assumes connReader never closes closeChan itself.
		closeChan := make(chan struct{})
		c.readerChan <- newConnReader(r, closeChan)
		<-closeChan
		close(closeChan)
		r.Close()
	case parser.UPGRADE:
		c.upgraded()
	case parser.NOOP:
	}
}
// OnClose is the callback for a transport that has closed.
//
// If server is the upgrading transport, the upgrade is abandoned: the
// upgrading slot is cleared and that transport is closed.
//
// If server is the current transport, the whole connection is torn down:
//   - the current transport and any upgrading transport are closed;
//   - the state becomes stateClosed;
//   - readerChan is closed, so NextReader returns io.EOF;
//   - pingChan is closed, so pingLoop returns;
//   - the owning server is notified via callback.onClose.
//
// Calls for any other transport are ignored.
//
// NOTE(review): nothing guards against a second call for the current
// transport; that would close readerChan/pingChan again and panic.
func (c *serverConn) OnClose(server transport.Server) {
	if t := c.getUpgrade(); server == t {
		c.setUpgrading("", nil)
		t.Close()
		return
	}
	t := c.getCurrent()
	if server != t {
		return
	}
	t.Close()
	if t := c.getUpgrade(); t != nil {
		t.Close()
		c.setUpgrading("", nil)
	}
	// stateClosed is set before the channels are closed so OnPacket's
	// state checks stop new sends.
	c.setState(stateClosed)
	close(c.readerChan)
	// pingLocker makes the close exclusive with OnPacket's send on pingChan.
	c.pingLocker.Lock()
	close(c.pingChan)
	c.pingLocker.Unlock()
	c.callback.onClose(c.id)
}
  239. func (s *serverConn) onOpen() error {
  240. upgrades := []string{}
  241. for name := range s.callback.transports() {
  242. if name == s.currentName {
  243. continue
  244. }
  245. upgrades = append(upgrades, name)
  246. }
  247. type connectionInfo struct {
  248. Sid string `json:"sid"`
  249. Upgrades []string `json:"upgrades"`
  250. PingInterval time.Duration `json:"pingInterval"`
  251. PingTimeout time.Duration `json:"pingTimeout"`
  252. }
  253. resp := connectionInfo{
  254. Sid: s.Id(),
  255. Upgrades: upgrades,
  256. PingInterval: s.callback.configure().PingInterval / time.Millisecond,
  257. PingTimeout: s.callback.configure().PingTimeout / time.Millisecond,
  258. }
  259. w, err := s.getCurrent().NextWriter(message.MessageText, parser.OPEN)
  260. if err != nil {
  261. return err
  262. }
  263. encoder := json.NewEncoder(w)
  264. if err := encoder.Encode(resp); err != nil {
  265. return err
  266. }
  267. if err := w.Close(); err != nil {
  268. return err
  269. }
  270. return nil
  271. }
  272. func (c *serverConn) getCurrent() transport.Server {
  273. c.transportLocker.RLock()
  274. defer c.transportLocker.RUnlock()
  275. return c.current
  276. }
  277. func (c *serverConn) getUpgrade() transport.Server {
  278. c.transportLocker.RLock()
  279. defer c.transportLocker.RUnlock()
  280. return c.upgrading
  281. }
  282. func (c *serverConn) setCurrent(name string, s transport.Server) {
  283. c.transportLocker.Lock()
  284. defer c.transportLocker.Unlock()
  285. c.currentName = name
  286. c.current = s
  287. }
  288. func (c *serverConn) setUpgrading(name string, s transport.Server) {
  289. c.transportLocker.Lock()
  290. defer c.transportLocker.Unlock()
  291. c.upgradingName = name
  292. c.upgrading = s
  293. c.setState(stateUpgrading)
  294. }
  295. func (c *serverConn) upgraded() {
  296. c.transportLocker.Lock()
  297. current := c.current
  298. c.current = c.upgrading
  299. c.currentName = c.upgradingName
  300. c.upgrading = nil
  301. c.upgradingName = ""
  302. c.transportLocker.Unlock()
  303. current.Close()
  304. c.setState(stateNormal)
  305. }
  306. func (c *serverConn) getState() state {
  307. c.stateLocker.RLock()
  308. defer c.stateLocker.RUnlock()
  309. return c.state
  310. }
  311. func (c *serverConn) setState(state state) {
  312. c.stateLocker.Lock()
  313. defer c.stateLocker.Unlock()
  314. c.state = state
  315. }
  316. func (c *serverConn) pingLoop() {
  317. lastPing := time.Now()
  318. lastTry := lastPing
  319. for {
  320. now := time.Now()
  321. pingDiff := now.Sub(lastPing)
  322. tryDiff := now.Sub(lastTry)
  323. select {
  324. case ok := <-c.pingChan:
  325. if !ok {
  326. return
  327. }
  328. lastPing = time.Now()
  329. lastTry = lastPing
  330. case <-time.After(c.pingInterval - tryDiff):
  331. c.writerLocker.Lock()
  332. if w, _ := c.getCurrent().NextWriter(message.MessageText, parser.PING); w != nil {
  333. writer := newConnWriter(w, &c.writerLocker)
  334. writer.Close()
  335. } else {
  336. c.writerLocker.Unlock()
  337. }
  338. lastTry = time.Now()
  339. case <-time.After(c.pingTimeout - pingDiff):
  340. c.Close()
  341. return
  342. }
  343. }
  344. }