streamable_http.go 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939
  1. package server
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "mime"
  8. "net/http"
  9. "net/http/httptest"
  10. "os"
  11. "strings"
  12. "sync"
  13. "sync/atomic"
  14. "time"
  15. "github.com/google/uuid"
  16. "github.com/mark3labs/mcp-go/mcp"
  17. "github.com/mark3labs/mcp-go/util"
  18. )
  19. // StreamableHTTPOption defines a function type for configuring StreamableHTTPServer
  20. type StreamableHTTPOption func(*StreamableHTTPServer)
  21. // WithEndpointPath sets the endpoint path for the server.
  22. // The default is "/mcp".
  23. // It's only works for `Start` method. When used as a http.Handler, it has no effect.
  24. func WithEndpointPath(endpointPath string) StreamableHTTPOption {
  25. return func(s *StreamableHTTPServer) {
  26. // Normalize the endpoint path to ensure it starts with a slash and doesn't end with one
  27. normalizedPath := "/" + strings.Trim(endpointPath, "/")
  28. s.endpointPath = normalizedPath
  29. }
  30. }
  31. // WithStateLess sets the server to stateless mode.
  32. // If true, the server will manage no session information. Every request will be treated
  33. // as a new session. No session id returned to the client.
  34. // The default is false.
  35. //
  36. // Notice: This is a convenience method. It's identical to set WithSessionIdManager option
  37. // to StatelessSessionIdManager.
  38. func WithStateLess(stateLess bool) StreamableHTTPOption {
  39. return func(s *StreamableHTTPServer) {
  40. if stateLess {
  41. s.sessionIdManager = &StatelessSessionIdManager{}
  42. }
  43. }
  44. }
  45. // WithSessionIdManager sets a custom session id generator for the server.
  46. // By default, the server will use SimpleStatefulSessionIdGenerator, which generates
  47. // session ids with uuid, and it's insecure.
  48. // Notice: it will override the WithStateLess option.
  49. func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption {
  50. return func(s *StreamableHTTPServer) {
  51. s.sessionIdManager = manager
  52. }
  53. }
  54. // WithHeartbeatInterval sets the heartbeat interval. Positive interval means the
  55. // server will send a heartbeat to the client through the GET connection, to keep
  56. // the connection alive from being closed by the network infrastructure (e.g.
  57. // gateways). If the client does not establish a GET connection, it has no
  58. // effect. The default is not to send heartbeats.
  59. func WithHeartbeatInterval(interval time.Duration) StreamableHTTPOption {
  60. return func(s *StreamableHTTPServer) {
  61. s.listenHeartbeatInterval = interval
  62. }
  63. }
  64. // WithHTTPContextFunc sets a function that will be called to customise the context
  65. // to the server using the incoming request.
  66. // This can be used to inject context values from headers, for example.
  67. func WithHTTPContextFunc(fn HTTPContextFunc) StreamableHTTPOption {
  68. return func(s *StreamableHTTPServer) {
  69. s.contextFunc = fn
  70. }
  71. }
  72. // WithStreamableHTTPServer sets the HTTP server instance for StreamableHTTPServer.
  73. // NOTE: When providing a custom HTTP server, you must handle routing yourself
  74. // If routing is not set up, the server will start but won't handle any MCP requests.
  75. func WithStreamableHTTPServer(srv *http.Server) StreamableHTTPOption {
  76. return func(s *StreamableHTTPServer) {
  77. s.httpServer = srv
  78. }
  79. }
  80. // WithLogger sets the logger for the server
  81. func WithLogger(logger util.Logger) StreamableHTTPOption {
  82. return func(s *StreamableHTTPServer) {
  83. s.logger = logger
  84. }
  85. }
  86. // WithTLSCert sets the TLS certificate and key files for HTTPS support.
  87. // Both certFile and keyFile must be provided to enable TLS.
  88. func WithTLSCert(certFile, keyFile string) StreamableHTTPOption {
  89. return func(s *StreamableHTTPServer) {
  90. s.tlsCertFile = certFile
  91. s.tlsKeyFile = keyFile
  92. }
  93. }
// StreamableHTTPServer implements a Streamable-HTTP based MCP server.
// It talks to clients over HTTP. Each POSTed message is answered either with a
// direct JSON response or with an SSE stream. A GET request opens a long-lived
// SSE stream for server-initiated messages.
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http
//
// Usage:
//
//	server := NewStreamableHTTPServer(mcpServer)
//	server.Start(":8080") // The final url for client is http://xxxx:8080/mcp by default
//
// The server itself can also be used as an http.Handler. This is convenient
// for integrating with existing http servers, or for advanced usage:
//
//	handler := NewStreamableHTTPServer(mcpServer)
//	http.Handle("/streamable-http", handler)
//	http.ListenAndServe(":8080", nil)
//
// Notice:
// Only the GET handler (listening) registers its session with the MCP server.
// The POST handlers (request/notification) use ephemeral sessions that are not
// registered. So methods like `SendNotificationToSpecificClient` or
// `hooks.onRegisterSession` will not be triggered for POST messages.
//
// The current implementation does not support the following features from the specification:
//   - Stream Resumability
type StreamableHTTPServer struct {
	server            *MCPServer         // MCP server that processes every message
	sessionTools      *sessionToolsStore // per-session tool sets, shared by all sessions this server creates
	sessionRequestIDs sync.Map           // sessionId --> last requestID(*atomic.Int64); feeds heartbeat ping IDs, cleared on DELETE
	activeSessions    sync.Map           // sessionId --> *streamableHttpSession (for sampling responses); only GET listeners are stored
	httpServer        *http.Server       // built by Start or supplied via WithStreamableHTTPServer; guarded by mu
	mu                sync.RWMutex       // guards httpServer
	endpointPath      string             // route installed by Start ("/mcp" by default)
	contextFunc       HTTPContextFunc    // optional per-POST context customization
	sessionIdManager  SessionIdManager   // generates/validates/terminates session IDs
	// listenHeartbeatInterval is the ping period on GET streams; <= 0 disables pings.
	listenHeartbeatInterval time.Duration
	logger                  util.Logger
	sessionLogLevels        *sessionLogLevelsStore // per-session logging levels, shared by all sessions
	tlsCertFile             string                 // TLS certificate path; Start uses TLS when set
	tlsKeyFile              string                 // TLS key path; must be set together with tlsCertFile
}
  133. // NewStreamableHTTPServer creates a new streamable-http server instance
  134. func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *StreamableHTTPServer {
  135. s := &StreamableHTTPServer{
  136. server: server,
  137. sessionTools: newSessionToolsStore(),
  138. sessionLogLevels: newSessionLogLevelsStore(),
  139. endpointPath: "/mcp",
  140. sessionIdManager: &InsecureStatefulSessionIdManager{},
  141. logger: util.DefaultLogger(),
  142. }
  143. // Apply all options
  144. for _, opt := range opts {
  145. opt(s)
  146. }
  147. return s
  148. }
  149. // ServeHTTP implements the http.Handler interface.
  150. func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  151. switch r.Method {
  152. case http.MethodPost:
  153. s.handlePost(w, r)
  154. case http.MethodGet:
  155. s.handleGet(w, r)
  156. case http.MethodDelete:
  157. s.handleDelete(w, r)
  158. default:
  159. http.NotFound(w, r)
  160. }
  161. }
  162. // Start begins serving the http server on the specified address and path
  163. // (endpointPath). like:
  164. //
  165. // s.Start(":8080")
  166. func (s *StreamableHTTPServer) Start(addr string) error {
  167. s.mu.Lock()
  168. if s.httpServer == nil {
  169. mux := http.NewServeMux()
  170. mux.Handle(s.endpointPath, s)
  171. s.httpServer = &http.Server{
  172. Addr: addr,
  173. Handler: mux,
  174. }
  175. } else {
  176. if s.httpServer.Addr == "" {
  177. s.httpServer.Addr = addr
  178. } else if s.httpServer.Addr != addr {
  179. return fmt.Errorf("conflicting listen address: WithStreamableHTTPServer(%q) vs Start(%q)", s.httpServer.Addr, addr)
  180. }
  181. }
  182. srv := s.httpServer
  183. s.mu.Unlock()
  184. if s.tlsCertFile != "" || s.tlsKeyFile != "" {
  185. if s.tlsCertFile == "" || s.tlsKeyFile == "" {
  186. return fmt.Errorf("both TLS cert and key must be provided")
  187. }
  188. if _, err := os.Stat(s.tlsCertFile); err != nil {
  189. return fmt.Errorf("failed to find TLS certificate file: %w", err)
  190. }
  191. if _, err := os.Stat(s.tlsKeyFile); err != nil {
  192. return fmt.Errorf("failed to find TLS key file: %w", err)
  193. }
  194. return srv.ListenAndServeTLS(s.tlsCertFile, s.tlsKeyFile)
  195. }
  196. return srv.ListenAndServe()
  197. }
  198. // Shutdown gracefully stops the server, closing all active sessions
  199. // and shutting down the HTTP server.
  200. func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error {
  201. // shutdown the server if needed (may use as a http.Handler)
  202. s.mu.RLock()
  203. srv := s.httpServer
  204. s.mu.RUnlock()
  205. if srv != nil {
  206. return srv.Shutdown(ctx)
  207. }
  208. return nil
  209. }
  210. // --- internal methods ---
  211. func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) {
  212. // post request carry request/notification message
  213. // Check content type
  214. contentType := r.Header.Get("Content-Type")
  215. mediaType, _, err := mime.ParseMediaType(contentType)
  216. if err != nil || mediaType != "application/json" {
  217. http.Error(w, "Invalid content type: must be 'application/json'", http.StatusBadRequest)
  218. return
  219. }
  220. // Check the request body is valid json, meanwhile, get the request Method
  221. rawData, err := io.ReadAll(r.Body)
  222. if err != nil {
  223. s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, fmt.Sprintf("read request body error: %v", err))
  224. return
  225. }
  226. // First, try to parse as a response (sampling responses don't have a method field)
  227. var jsonMessage struct {
  228. ID json.RawMessage `json:"id"`
  229. Result json.RawMessage `json:"result,omitempty"`
  230. Error json.RawMessage `json:"error,omitempty"`
  231. Method mcp.MCPMethod `json:"method,omitempty"`
  232. }
  233. if err := json.Unmarshal(rawData, &jsonMessage); err != nil {
  234. s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json")
  235. return
  236. }
  237. // Check if this is a sampling response (has result/error but no method)
  238. isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil &&
  239. (jsonMessage.Result != nil || jsonMessage.Error != nil)
  240. isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize
  241. // Handle sampling responses separately
  242. if isSamplingResponse {
  243. if err := s.handleSamplingResponse(w, r, jsonMessage); err != nil {
  244. s.logger.Errorf("Failed to handle sampling response: %v", err)
  245. http.Error(w, "Failed to handle sampling response", http.StatusInternalServerError)
  246. }
  247. return
  248. }
  249. // Prepare the session for the mcp server
  250. // The session is ephemeral. Its life is the same as the request. It's only created
  251. // for interaction with the mcp server.
  252. var sessionID string
  253. if isInitializeRequest {
  254. // generate a new one for initialize request
  255. sessionID = s.sessionIdManager.Generate()
  256. } else {
  257. // Get session ID from header.
  258. // Stateful servers need the client to carry the session ID.
  259. sessionID = r.Header.Get(HeaderKeySessionID)
  260. isTerminated, err := s.sessionIdManager.Validate(sessionID)
  261. if err != nil {
  262. http.Error(w, "Invalid session ID", http.StatusBadRequest)
  263. return
  264. }
  265. if isTerminated {
  266. http.Error(w, "Session terminated", http.StatusNotFound)
  267. return
  268. }
  269. }
  270. session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels)
  271. // Set the client context before handling the message
  272. ctx := s.server.WithContext(r.Context(), session)
  273. if s.contextFunc != nil {
  274. ctx = s.contextFunc(ctx, r)
  275. }
  276. // handle potential notifications
  277. mu := sync.Mutex{}
  278. upgradedHeader := false
  279. done := make(chan struct{})
  280. ctx = context.WithValue(ctx, requestHeader, r.Header)
  281. go func() {
  282. for {
  283. select {
  284. case nt := <-session.notificationChannel:
  285. func() {
  286. mu.Lock()
  287. defer mu.Unlock()
  288. // if the done chan is closed, as the request is terminated, just return
  289. select {
  290. case <-done:
  291. return
  292. default:
  293. }
  294. defer func() {
  295. flusher, ok := w.(http.Flusher)
  296. if ok {
  297. flusher.Flush()
  298. }
  299. }()
  300. // if there's notifications, upgradedHeader to SSE response
  301. if !upgradedHeader {
  302. w.Header().Set("Content-Type", "text/event-stream")
  303. w.Header().Set("Connection", "keep-alive")
  304. w.Header().Set("Cache-Control", "no-cache")
  305. w.WriteHeader(http.StatusOK)
  306. upgradedHeader = true
  307. }
  308. err := writeSSEEvent(w, nt)
  309. if err != nil {
  310. s.logger.Errorf("Failed to write SSE event: %v", err)
  311. return
  312. }
  313. }()
  314. case <-done:
  315. return
  316. case <-ctx.Done():
  317. return
  318. }
  319. }
  320. }()
  321. // Process message through MCPServer
  322. response := s.server.HandleMessage(ctx, rawData)
  323. if response == nil {
  324. // For notifications, just send 202 Accepted with no body
  325. w.WriteHeader(http.StatusAccepted)
  326. return
  327. }
  328. // Write response
  329. mu.Lock()
  330. defer mu.Unlock()
  331. // close the done chan before unlock
  332. defer close(done)
  333. if ctx.Err() != nil {
  334. return
  335. }
  336. // If client-server communication already upgraded to SSE stream
  337. if session.upgradeToSSE.Load() {
  338. if !upgradedHeader {
  339. w.Header().Set("Content-Type", "text/event-stream")
  340. w.Header().Set("Connection", "keep-alive")
  341. w.Header().Set("Cache-Control", "no-cache")
  342. w.WriteHeader(http.StatusOK)
  343. upgradedHeader = true
  344. }
  345. if err := writeSSEEvent(w, response); err != nil {
  346. s.logger.Errorf("Failed to write final SSE response event: %v", err)
  347. }
  348. } else {
  349. w.Header().Set("Content-Type", "application/json")
  350. if isInitializeRequest && sessionID != "" {
  351. // send the session ID back to the client
  352. w.Header().Set(HeaderKeySessionID, sessionID)
  353. }
  354. w.WriteHeader(http.StatusOK)
  355. err := json.NewEncoder(w).Encode(response)
  356. if err != nil {
  357. s.logger.Errorf("Failed to write response: %v", err)
  358. }
  359. }
  360. }
  361. func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) {
  362. // get request is for listening to notifications
  363. // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
  364. sessionID := r.Header.Get(HeaderKeySessionID)
  365. // the specification didn't say we should validate the session id
  366. if sessionID == "" {
  367. // It's a stateless server,
  368. // but the MCP server requires a unique ID for registering, so we use a random one
  369. sessionID = uuid.New().String()
  370. }
  371. session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels)
  372. if err := s.server.RegisterSession(r.Context(), session); err != nil {
  373. http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest)
  374. return
  375. }
  376. defer s.server.UnregisterSession(r.Context(), sessionID)
  377. // Register session for sampling response delivery
  378. s.activeSessions.Store(sessionID, session)
  379. defer s.activeSessions.Delete(sessionID)
  380. // Set the client context before handling the message
  381. w.Header().Set("Content-Type", "text/event-stream")
  382. w.Header().Set("Cache-Control", "no-cache")
  383. w.Header().Set("Connection", "keep-alive")
  384. w.WriteHeader(http.StatusOK)
  385. flusher, ok := w.(http.Flusher)
  386. if !ok {
  387. http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
  388. return
  389. }
  390. flusher.Flush()
  391. // Start notification handler for this session
  392. done := make(chan struct{})
  393. defer close(done)
  394. writeChan := make(chan any, 16)
  395. go func() {
  396. for {
  397. select {
  398. case nt := <-session.notificationChannel:
  399. select {
  400. case writeChan <- &nt:
  401. case <-done:
  402. return
  403. }
  404. case samplingReq := <-session.samplingRequestChan:
  405. // Send sampling request to client via SSE
  406. jsonrpcRequest := mcp.JSONRPCRequest{
  407. JSONRPC: "2.0",
  408. ID: mcp.NewRequestId(samplingReq.requestID),
  409. Request: mcp.Request{
  410. Method: string(mcp.MethodSamplingCreateMessage),
  411. },
  412. Params: samplingReq.request.CreateMessageParams,
  413. }
  414. select {
  415. case writeChan <- jsonrpcRequest:
  416. case <-done:
  417. return
  418. }
  419. case <-done:
  420. return
  421. }
  422. }
  423. }()
  424. if s.listenHeartbeatInterval > 0 {
  425. // heartbeat to keep the connection alive
  426. go func() {
  427. ticker := time.NewTicker(s.listenHeartbeatInterval)
  428. defer ticker.Stop()
  429. for {
  430. select {
  431. case <-ticker.C:
  432. message := mcp.JSONRPCRequest{
  433. JSONRPC: "2.0",
  434. ID: mcp.NewRequestId(s.nextRequestID(sessionID)),
  435. Request: mcp.Request{
  436. Method: "ping",
  437. },
  438. }
  439. select {
  440. case writeChan <- message:
  441. case <-done:
  442. return
  443. }
  444. case <-done:
  445. return
  446. }
  447. }
  448. }()
  449. }
  450. // Keep the connection open until the client disconnects
  451. //
  452. // There's will a Available() check when handler ends, and it maybe race with Flush(),
  453. // so we use a separate channel to send the data, inteading of flushing directly in other goroutine.
  454. for {
  455. select {
  456. case data := <-writeChan:
  457. if data == nil {
  458. continue
  459. }
  460. if err := writeSSEEvent(w, data); err != nil {
  461. s.logger.Errorf("Failed to write SSE event: %v", err)
  462. return
  463. }
  464. flusher.Flush()
  465. case <-r.Context().Done():
  466. return
  467. }
  468. }
  469. }
  470. func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) {
  471. // delete request terminate the session
  472. sessionID := r.Header.Get(HeaderKeySessionID)
  473. notAllowed, err := s.sessionIdManager.Terminate(sessionID)
  474. if err != nil {
  475. http.Error(w, fmt.Sprintf("Session termination failed: %v", err), http.StatusInternalServerError)
  476. return
  477. }
  478. if notAllowed {
  479. http.Error(w, "Session termination not allowed", http.StatusMethodNotAllowed)
  480. return
  481. }
  482. // remove the session relateddata from the sessionToolsStore
  483. s.sessionTools.delete(sessionID)
  484. s.sessionLogLevels.delete(sessionID)
  485. // remove current session's requstID information
  486. s.sessionRequestIDs.Delete(sessionID)
  487. w.WriteHeader(http.StatusOK)
  488. }
  489. func writeSSEEvent(w io.Writer, data any) error {
  490. jsonData, err := json.Marshal(data)
  491. if err != nil {
  492. return fmt.Errorf("failed to marshal data: %w", err)
  493. }
  494. _, err = fmt.Fprintf(w, "event: message\ndata: %s\n\n", jsonData)
  495. if err != nil {
  496. return fmt.Errorf("failed to write SSE event: %w", err)
  497. }
  498. return nil
  499. }
  500. // handleSamplingResponse processes incoming sampling responses from clients
  501. func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *http.Request, responseMessage struct {
  502. ID json.RawMessage `json:"id"`
  503. Result json.RawMessage `json:"result,omitempty"`
  504. Error json.RawMessage `json:"error,omitempty"`
  505. Method mcp.MCPMethod `json:"method,omitempty"`
  506. }) error {
  507. // Get session ID from header
  508. sessionID := r.Header.Get(HeaderKeySessionID)
  509. if sessionID == "" {
  510. http.Error(w, "Missing session ID for sampling response", http.StatusBadRequest)
  511. return fmt.Errorf("missing session ID")
  512. }
  513. // Validate session
  514. isTerminated, err := s.sessionIdManager.Validate(sessionID)
  515. if err != nil {
  516. http.Error(w, "Invalid session ID", http.StatusBadRequest)
  517. return err
  518. }
  519. if isTerminated {
  520. http.Error(w, "Session terminated", http.StatusNotFound)
  521. return fmt.Errorf("session terminated")
  522. }
  523. // Parse the request ID
  524. var requestID int64
  525. if err := json.Unmarshal(responseMessage.ID, &requestID); err != nil {
  526. http.Error(w, "Invalid request ID in sampling response", http.StatusBadRequest)
  527. return err
  528. }
  529. // Create the sampling response item
  530. response := samplingResponseItem{
  531. requestID: requestID,
  532. }
  533. // Parse result or error
  534. if responseMessage.Error != nil {
  535. // Parse error
  536. var jsonrpcError struct {
  537. Code int `json:"code"`
  538. Message string `json:"message"`
  539. }
  540. if err := json.Unmarshal(responseMessage.Error, &jsonrpcError); err != nil {
  541. response.err = fmt.Errorf("failed to parse error: %v", err)
  542. } else {
  543. response.err = fmt.Errorf("sampling error %d: %s", jsonrpcError.Code, jsonrpcError.Message)
  544. }
  545. } else if responseMessage.Result != nil {
  546. // Parse result
  547. var result mcp.CreateMessageResult
  548. if err := json.Unmarshal(responseMessage.Result, &result); err != nil {
  549. response.err = fmt.Errorf("failed to parse sampling result: %v", err)
  550. } else {
  551. response.result = &result
  552. }
  553. } else {
  554. response.err = fmt.Errorf("sampling response has neither result nor error")
  555. }
  556. // Find the corresponding session and deliver the response
  557. // The response is delivered to the specific session identified by sessionID
  558. if err := s.deliverSamplingResponse(sessionID, response); err != nil {
  559. s.logger.Errorf("Failed to deliver sampling response: %v", err)
  560. http.Error(w, "Failed to deliver response", http.StatusInternalServerError)
  561. return err
  562. }
  563. // Acknowledge receipt
  564. w.WriteHeader(http.StatusOK)
  565. return nil
  566. }
  567. // deliverSamplingResponse delivers a sampling response to the appropriate session
  568. func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, response samplingResponseItem) error {
  569. // Look up the active session
  570. sessionInterface, ok := s.activeSessions.Load(sessionID)
  571. if !ok {
  572. return fmt.Errorf("no active session found for session %s", sessionID)
  573. }
  574. session, ok := sessionInterface.(*streamableHttpSession)
  575. if !ok {
  576. return fmt.Errorf("invalid session type for session %s", sessionID)
  577. }
  578. // Look up the dedicated response channel for this specific request
  579. responseChannelInterface, exists := session.samplingRequests.Load(response.requestID)
  580. if !exists {
  581. return fmt.Errorf("no pending request found for session %s, request %d", sessionID, response.requestID)
  582. }
  583. responseChan, ok := responseChannelInterface.(chan samplingResponseItem)
  584. if !ok {
  585. return fmt.Errorf("invalid response channel type for session %s, request %d", sessionID, response.requestID)
  586. }
  587. // Attempt to deliver the response with timeout to prevent indefinite blocking
  588. select {
  589. case responseChan <- response:
  590. s.logger.Infof("Delivered sampling response for session %s, request %d", sessionID, response.requestID)
  591. return nil
  592. default:
  593. return fmt.Errorf("failed to deliver sampling response for session %s, request %d: channel full or blocked", sessionID, response.requestID)
  594. }
  595. }
  596. // writeJSONRPCError writes a JSON-RPC error response with the given error details.
  597. func (s *StreamableHTTPServer) writeJSONRPCError(
  598. w http.ResponseWriter,
  599. id any,
  600. code int,
  601. message string,
  602. ) {
  603. response := createErrorResponse(id, code, message)
  604. w.Header().Set("Content-Type", "application/json")
  605. w.WriteHeader(http.StatusBadRequest)
  606. err := json.NewEncoder(w).Encode(response)
  607. if err != nil {
  608. s.logger.Errorf("Failed to write JSONRPCError: %v", err)
  609. }
  610. }
  611. // nextRequestID gets the next incrementing requestID for the current session
  612. func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 {
  613. actual, _ := s.sessionRequestIDs.LoadOrStore(sessionID, new(atomic.Int64))
  614. counter := actual.(*atomic.Int64)
  615. return counter.Add(1)
  616. }
  617. // --- session ---
  618. type sessionLogLevelsStore struct {
  619. mu sync.RWMutex
  620. logs map[string]mcp.LoggingLevel
  621. }
  622. func newSessionLogLevelsStore() *sessionLogLevelsStore {
  623. return &sessionLogLevelsStore{
  624. logs: make(map[string]mcp.LoggingLevel),
  625. }
  626. }
  627. func (s *sessionLogLevelsStore) get(sessionID string) mcp.LoggingLevel {
  628. s.mu.RLock()
  629. defer s.mu.RUnlock()
  630. val, ok := s.logs[sessionID]
  631. if !ok {
  632. return mcp.LoggingLevelError
  633. }
  634. return val
  635. }
  636. func (s *sessionLogLevelsStore) set(sessionID string, level mcp.LoggingLevel) {
  637. s.mu.Lock()
  638. defer s.mu.Unlock()
  639. s.logs[sessionID] = level
  640. }
  641. func (s *sessionLogLevelsStore) delete(sessionID string) {
  642. s.mu.Lock()
  643. defer s.mu.Unlock()
  644. delete(s.logs, sessionID)
  645. }
  646. type sessionToolsStore struct {
  647. mu sync.RWMutex
  648. tools map[string]map[string]ServerTool // sessionID -> toolName -> tool
  649. }
  650. func newSessionToolsStore() *sessionToolsStore {
  651. return &sessionToolsStore{
  652. tools: make(map[string]map[string]ServerTool),
  653. }
  654. }
  655. func (s *sessionToolsStore) get(sessionID string) map[string]ServerTool {
  656. s.mu.RLock()
  657. defer s.mu.RUnlock()
  658. return s.tools[sessionID]
  659. }
  660. func (s *sessionToolsStore) set(sessionID string, tools map[string]ServerTool) {
  661. s.mu.Lock()
  662. defer s.mu.Unlock()
  663. s.tools[sessionID] = tools
  664. }
  665. func (s *sessionToolsStore) delete(sessionID string) {
  666. s.mu.Lock()
  667. defer s.mu.Unlock()
  668. delete(s.tools, sessionID)
  669. }
// Sampling support types for HTTP transport

// samplingRequestItem is a server-initiated sampling/createMessage request.
// RequestSampling queues it on a session's samplingRequestChan, and the GET
// listener (handleGet) sends it to the client.
type samplingRequestItem struct {
	requestID int64                    // JSON-RPC ID the client must echo in its response
	request   mcp.CreateMessageRequest // its CreateMessageParams become the outgoing request's Params
	// response is filled in by RequestSampling. The reply is actually routed
	// through the session's samplingRequests map (see deliverSamplingResponse),
	// not through this field.
	response chan samplingResponseItem
}

// samplingResponseItem is a client's reply to a sampling request, as parsed
// by handleSamplingResponse. Exactly one of result or err is set.
type samplingResponseItem struct {
	requestID int64
	result    *mcp.CreateMessageResult
	err       error
}
// streamableHttpSession is a session for the streamable-http transport.
// In POST handlers (request/notification) it is ephemeral and only exists for
// the life of the request handler.
// In GET handlers (listening) it is a real session and is registered with the
// MCP server.
type streamableHttpSession struct {
	sessionID           string
	notificationChannel chan mcp.JSONRPCNotification // server -> client notifications
	tools               *sessionToolsStore           // shared store; this session uses only its sessionID entry
	upgradeToSSE        atomic.Bool                  // set by UpgradeToSSEWhenReceiveNotification; handlePost then replies with SSE
	logLevels           *sessionLogLevelsStore       // shared store; this session uses only its sessionID entry
	// Sampling support for bidirectional communication
	samplingRequestChan chan samplingRequestItem // server -> client sampling requests, drained by handleGet
	samplingRequests    sync.Map                 // requestID (int64) -> chan samplingResponseItem awaiting the client's reply
	requestIDCounter    atomic.Int64             // sampling request IDs, unique within this session object (first ID is 1)
}
  695. func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession {
  696. s := &streamableHttpSession{
  697. sessionID: sessionID,
  698. notificationChannel: make(chan mcp.JSONRPCNotification, 100),
  699. tools: toolStore,
  700. logLevels: levels,
  701. samplingRequestChan: make(chan samplingRequestItem, 10),
  702. }
  703. return s
  704. }
  705. func (s *streamableHttpSession) SessionID() string {
  706. return s.sessionID
  707. }
  708. func (s *streamableHttpSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
  709. return s.notificationChannel
  710. }
  711. func (s *streamableHttpSession) Initialize() {
  712. // do nothing
  713. // the session is ephemeral, no real initialized action needed
  714. }
  715. func (s *streamableHttpSession) Initialized() bool {
  716. // the session is ephemeral, no real initialized action needed
  717. return true
  718. }
  719. func (s *streamableHttpSession) SetLogLevel(level mcp.LoggingLevel) {
  720. s.logLevels.set(s.sessionID, level)
  721. }
  722. func (s *streamableHttpSession) GetLogLevel() mcp.LoggingLevel {
  723. return s.logLevels.get(s.sessionID)
  724. }
  725. var _ ClientSession = (*streamableHttpSession)(nil)
  726. func (s *streamableHttpSession) GetSessionTools() map[string]ServerTool {
  727. return s.tools.get(s.sessionID)
  728. }
  729. func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) {
  730. s.tools.set(s.sessionID, tools)
  731. }
  732. var (
  733. _ SessionWithTools = (*streamableHttpSession)(nil)
  734. _ SessionWithLogging = (*streamableHttpSession)(nil)
  735. )
  736. func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() {
  737. s.upgradeToSSE.Store(true)
  738. }
  739. var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil)
  740. // RequestSampling implements SessionWithSampling interface for HTTP transport
  741. func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
  742. // Generate unique request ID
  743. requestID := s.requestIDCounter.Add(1)
  744. // Create response channel for this specific request
  745. responseChan := make(chan samplingResponseItem, 1)
  746. // Create the sampling request item
  747. samplingRequest := samplingRequestItem{
  748. requestID: requestID,
  749. request: request,
  750. response: responseChan,
  751. }
  752. // Store the pending request
  753. s.samplingRequests.Store(requestID, responseChan)
  754. defer s.samplingRequests.Delete(requestID)
  755. // Send the sampling request via the channel (non-blocking)
  756. select {
  757. case s.samplingRequestChan <- samplingRequest:
  758. // Request queued successfully
  759. case <-ctx.Done():
  760. return nil, ctx.Err()
  761. default:
  762. return nil, fmt.Errorf("sampling request queue is full - server overloaded")
  763. }
  764. // Wait for response or context cancellation
  765. select {
  766. case response := <-responseChan:
  767. if response.err != nil {
  768. return nil, response.err
  769. }
  770. return response.result, nil
  771. case <-ctx.Done():
  772. return nil, ctx.Err()
  773. }
  774. }
  775. var _ SessionWithSampling = (*streamableHttpSession)(nil)
  // --- session id manager ---

// SessionIdManager governs the lifecycle of streamable HTTP session IDs. It
// issues new IDs, checks IDs presented by clients, and decides whether a
// client may terminate a session.
type SessionIdManager interface {
	// Generate returns a new session ID. The stateless implementation
	// returns the empty string.
	Generate() string

	// Validate reports on sessionID:
	//   - err != nil: the ID is malformed, or the lookup failed;
	//   - isTerminated == true: the ID is well-formed but its session has
	//     been terminated;
	//   - otherwise: the ID is usable.
	Validate(sessionID string) (isTerminated bool, err error)

	// Terminate marks sessionID as terminated:
	//   - isNotAllowed == true: server policy forbids client-initiated
	//     termination;
	//   - err != nil: the ID is invalid, or termination failed.
	Terminate(sessionID string) (isNotAllowed bool, err error)
}
  // StatelessSessionIdManager is a SessionIdManager that performs no session
// management at all:
//   - it issues empty IDs;
//   - it accepts every ID it is shown without inspecting it;
//   - it permits every termination request.
type StatelessSessionIdManager struct{}

// Generate returns the empty string: a stateless server issues no session ID.
func (*StatelessSessionIdManager) Generate() string { return "" }

// Validate ignores sessionID entirely. It never reports an ID as terminated
// or invalid.
func (*StatelessSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
	return // zero values: not terminated, no error
}

// Terminate ignores sessionID and always allows termination without error.
func (*StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
	return // zero values: allowed, no error
}
  800. // InsecureStatefulSessionIdManager generate id with uuid
  801. // It won't validate the id indeed, so it could be fake.
  802. // For more secure session id, use a more complex generator, like a JWT.
  803. type InsecureStatefulSessionIdManager struct{}
  804. const idPrefix = "mcp-session-"
  805. func (s *InsecureStatefulSessionIdManager) Generate() string {
  806. return idPrefix + uuid.New().String()
  807. }
  808. func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) {
  809. // validate the session id is a valid uuid
  810. if !strings.HasPrefix(sessionID, idPrefix) {
  811. return false, fmt.Errorf("invalid session id: %s", sessionID)
  812. }
  813. if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil {
  814. return false, fmt.Errorf("invalid session id: %s", sessionID)
  815. }
  816. return false, nil
  817. }
  818. func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) {
  819. return false, nil
  820. }
  821. // NewTestStreamableHTTPServer creates a test server for testing purposes
  822. func NewTestStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *httptest.Server {
  823. sseServer := NewStreamableHTTPServer(server, opts...)
  824. testServer := httptest.NewServer(sseServer)
  825. return testServer
  826. }