sse.go 22 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751
  1. package server
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "log"
  7. "net/http"
  8. "net/http/httptest"
  9. "net/url"
  10. "path"
  11. "strings"
  12. "sync"
  13. "sync/atomic"
  14. "time"
  15. "github.com/google/uuid"
  16. "github.com/mark3labs/mcp-go/mcp"
  17. )
// sseSession represents an active SSE connection.
//
// handleSSE creates one sseSession per SSE connection, stores it in
// SSEServer.sessions under sessionID, and removes it when the handler returns.
type sseSession struct {
	// done is closed to stop the session's background goroutines and the SSE
	// write loop in handleSSE.
	done chan struct{}
	// eventQueue holds fully formatted SSE frames waiting to be written to the
	// client by handleSSE's write loop.
	eventQueue chan string // Channel for queuing events
	sessionID  string
	// requestID is incremented to produce the IDs of keep-alive "ping" requests.
	requestID atomic.Int64
	// notificationChannel receives notifications that handleSSE forwards to the
	// client as SSE "message" events.
	notificationChannel chan mcp.JSONRPCNotification
	initialized         atomic.Bool
	// loggingLevel holds an mcp.LoggingLevel once Initialize or SetLogLevel has
	// run; GetLogLevel falls back to mcp.LoggingLevelError while it is unset.
	loggingLevel       atomic.Value
	tools              sync.Map     // stores session-specific tools
	clientInfo         atomic.Value // stores session-specific client info
	clientCapabilities atomic.Value // stores session-specific client capabilities
}
// SSEContextFunc is a function that takes an existing context and the current
// request and returns a potentially modified context based on the request
// content. This can be used to inject context values from headers, for example.
//
// It is applied in handleMessage to every POSTed message, after the session has
// been attached to the context.
type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context

// DynamicBasePathFunc allows the user to provide a function to generate the
// base path for a given request and sessionID. This is useful for cases where
// the base path is not known at the time of SSE server creation, such as when
// using a reverse proxy or when the base path is dynamically generated. The
// function should return the base path (e.g., "/mcp/tenant123").
//
// When installed via WithDynamicBasePath, the returned value is passed through
// normalizeURLPath, so leading/trailing slashes are not significant.
type DynamicBasePathFunc func(r *http.Request, sessionID string) string
  41. func (s *sseSession) SessionID() string {
  42. return s.sessionID
  43. }
  44. func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
  45. return s.notificationChannel
  46. }
  47. func (s *sseSession) Initialize() {
  48. // set default logging level
  49. s.loggingLevel.Store(mcp.LoggingLevelError)
  50. s.initialized.Store(true)
  51. }
  52. func (s *sseSession) Initialized() bool {
  53. return s.initialized.Load()
  54. }
  55. func (s *sseSession) SetLogLevel(level mcp.LoggingLevel) {
  56. s.loggingLevel.Store(level)
  57. }
  58. func (s *sseSession) GetLogLevel() mcp.LoggingLevel {
  59. level := s.loggingLevel.Load()
  60. if level == nil {
  61. return mcp.LoggingLevelError
  62. }
  63. return level.(mcp.LoggingLevel)
  64. }
  65. func (s *sseSession) GetSessionTools() map[string]ServerTool {
  66. tools := make(map[string]ServerTool)
  67. s.tools.Range(func(key, value any) bool {
  68. if tool, ok := value.(ServerTool); ok {
  69. tools[key.(string)] = tool
  70. }
  71. return true
  72. })
  73. return tools
  74. }
  75. func (s *sseSession) SetSessionTools(tools map[string]ServerTool) {
  76. // Clear existing tools
  77. s.tools.Clear()
  78. // Set new tools
  79. for name, tool := range tools {
  80. s.tools.Store(name, tool)
  81. }
  82. }
  83. func (s *sseSession) GetClientInfo() mcp.Implementation {
  84. if value := s.clientInfo.Load(); value != nil {
  85. if clientInfo, ok := value.(mcp.Implementation); ok {
  86. return clientInfo
  87. }
  88. }
  89. return mcp.Implementation{}
  90. }
  91. func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) {
  92. s.clientInfo.Store(clientInfo)
  93. }
  94. func (s *sseSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) {
  95. s.clientCapabilities.Store(clientCapabilities)
  96. }
  97. func (s *sseSession) GetClientCapabilities() mcp.ClientCapabilities {
  98. if value := s.clientCapabilities.Load(); value != nil {
  99. if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok {
  100. return clientCapabilities
  101. }
  102. }
  103. return mcp.ClientCapabilities{}
  104. }
  105. var (
  106. _ ClientSession = (*sseSession)(nil)
  107. _ SessionWithTools = (*sseSession)(nil)
  108. _ SessionWithLogging = (*sseSession)(nil)
  109. _ SessionWithClientInfo = (*sseSession)(nil)
  110. )
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
// It provides real-time communication capabilities over HTTP using the SSE protocol.
//
// Clients open a long-lived GET on the SSE endpoint and POST JSON-RPC messages
// to the message endpoint; responses and notifications are delivered back over
// the SSE stream.
type SSEServer struct {
	// server processes incoming messages and tracks registered sessions.
	server *MCPServer
	// baseURL, when non-empty, is prefixed to advertised endpoints (WithBaseURL
	// strips a trailing "/").
	baseURL string
	// basePath is the static, normalized path prefix shared by both endpoints.
	basePath string
	// appendQueryToMessageEndpoint copies the SSE request's raw query string
	// onto the message endpoint sent to the client.
	appendQueryToMessageEndpoint bool
	// useFullURLForMessageEndpoint prefixes baseURL (when non-empty) onto the
	// message endpoint sent to the client. NewSSEServer defaults it to true.
	useFullURLForMessageEndpoint bool
	messageEndpoint              string // NewSSEServer default: "/message"
	sseEndpoint                  string // NewSSEServer default: "/sse"
	// sessions maps session ID strings to live *sseSession values.
	sessions sync.Map
	// srv is the HTTP server; Start creates one if none was supplied.
	srv                 *http.Server
	contextFunc         SSEContextFunc
	dynamicBasePathFunc DynamicBasePathFunc
	keepAlive           bool
	keepAliveInterval   time.Duration // NewSSEServer default: 10s
	// mu guards srv in Start and Shutdown.
	mu sync.RWMutex
}

// SSEOption defines a function type for configuring SSEServer.
// NewSSEServer applies options in the order they are given.
type SSEOption func(*SSEServer)
  131. // WithBaseURL sets the base URL for the SSE server
  132. func WithBaseURL(baseURL string) SSEOption {
  133. return func(s *SSEServer) {
  134. if baseURL != "" {
  135. u, err := url.Parse(baseURL)
  136. if err != nil {
  137. return
  138. }
  139. if u.Scheme != "http" && u.Scheme != "https" {
  140. return
  141. }
  142. // Check if the host is empty or only contains a port
  143. if u.Host == "" || strings.HasPrefix(u.Host, ":") {
  144. return
  145. }
  146. if len(u.Query()) > 0 {
  147. return
  148. }
  149. }
  150. s.baseURL = strings.TrimSuffix(baseURL, "/")
  151. }
  152. }
  153. // WithStaticBasePath adds a new option for setting a static base path
  154. func WithStaticBasePath(basePath string) SSEOption {
  155. return func(s *SSEServer) {
  156. s.basePath = normalizeURLPath(basePath)
  157. }
  158. }
  159. // WithBasePath adds a new option for setting a static base path.
  160. //
  161. // Deprecated: Use WithStaticBasePath instead. This will be removed in a future version.
  162. //
  163. //go:deprecated
  164. func WithBasePath(basePath string) SSEOption {
  165. return WithStaticBasePath(basePath)
  166. }
  167. // WithDynamicBasePath accepts a function for generating the base path. This is
  168. // useful for cases where the base path is not known at the time of SSE server
  169. // creation, such as when using a reverse proxy or when the server is mounted
  170. // at a dynamic path.
  171. func WithDynamicBasePath(fn DynamicBasePathFunc) SSEOption {
  172. return func(s *SSEServer) {
  173. if fn != nil {
  174. s.dynamicBasePathFunc = func(r *http.Request, sid string) string {
  175. bp := fn(r, sid)
  176. return normalizeURLPath(bp)
  177. }
  178. }
  179. }
  180. }
  181. // WithMessageEndpoint sets the message endpoint path
  182. func WithMessageEndpoint(endpoint string) SSEOption {
  183. return func(s *SSEServer) {
  184. s.messageEndpoint = endpoint
  185. }
  186. }
  187. // WithAppendQueryToMessageEndpoint configures the SSE server to append the original request's
  188. // query parameters to the message endpoint URL that is sent to clients during the SSE connection
  189. // initialization. This is useful when you need to preserve query parameters from the initial
  190. // SSE connection request and carry them over to subsequent message requests, maintaining
  191. // context or authentication details across the communication channel.
  192. func WithAppendQueryToMessageEndpoint() SSEOption {
  193. return func(s *SSEServer) {
  194. s.appendQueryToMessageEndpoint = true
  195. }
  196. }
  197. // WithUseFullURLForMessageEndpoint controls whether the SSE server returns a complete URL (including baseURL)
  198. // or just the path portion for the message endpoint. Set to false when clients will concatenate
  199. // the baseURL themselves to avoid malformed URLs like "http://localhost/mcphttp://localhost/mcp/message".
  200. func WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint bool) SSEOption {
  201. return func(s *SSEServer) {
  202. s.useFullURLForMessageEndpoint = useFullURLForMessageEndpoint
  203. }
  204. }
  205. // WithSSEEndpoint sets the SSE endpoint path
  206. func WithSSEEndpoint(endpoint string) SSEOption {
  207. return func(s *SSEServer) {
  208. s.sseEndpoint = endpoint
  209. }
  210. }
  211. // WithHTTPServer sets the HTTP server instance.
  212. // NOTE: When providing a custom HTTP server, you must handle routing yourself
  213. // If routing is not set up, the server will start but won't handle any MCP requests.
  214. func WithHTTPServer(srv *http.Server) SSEOption {
  215. return func(s *SSEServer) {
  216. s.srv = srv
  217. }
  218. }
  219. func WithKeepAliveInterval(keepAliveInterval time.Duration) SSEOption {
  220. return func(s *SSEServer) {
  221. s.keepAlive = true
  222. s.keepAliveInterval = keepAliveInterval
  223. }
  224. }
  225. func WithKeepAlive(keepAlive bool) SSEOption {
  226. return func(s *SSEServer) {
  227. s.keepAlive = keepAlive
  228. }
  229. }
  230. // WithSSEContextFunc sets a function that will be called to customise the context
  231. // to the server using the incoming request.
  232. func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
  233. return func(s *SSEServer) {
  234. s.contextFunc = fn
  235. }
  236. }
  237. // NewSSEServer creates a new SSE server instance with the given MCP server and options.
  238. func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
  239. s := &SSEServer{
  240. server: server,
  241. sseEndpoint: "/sse",
  242. messageEndpoint: "/message",
  243. useFullURLForMessageEndpoint: true,
  244. keepAlive: false,
  245. keepAliveInterval: 10 * time.Second,
  246. }
  247. // Apply all options
  248. for _, opt := range opts {
  249. opt(s)
  250. }
  251. return s
  252. }
  253. // NewTestServer creates a test server for testing purposes
  254. func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server {
  255. sseServer := NewSSEServer(server, opts...)
  256. testServer := httptest.NewServer(sseServer)
  257. sseServer.baseURL = testServer.URL
  258. return testServer
  259. }
  260. // Start begins serving SSE connections on the specified address.
  261. // It sets up HTTP handlers for SSE and message endpoints.
  262. func (s *SSEServer) Start(addr string) error {
  263. s.mu.Lock()
  264. if s.srv == nil {
  265. s.srv = &http.Server{
  266. Addr: addr,
  267. Handler: s,
  268. }
  269. } else {
  270. if s.srv.Addr == "" {
  271. s.srv.Addr = addr
  272. } else if s.srv.Addr != addr {
  273. return fmt.Errorf("conflicting listen address: WithHTTPServer(%q) vs Start(%q)", s.srv.Addr, addr)
  274. }
  275. }
  276. srv := s.srv
  277. s.mu.Unlock()
  278. return srv.ListenAndServe()
  279. }
  280. // Shutdown gracefully stops the SSE server, closing all active sessions
  281. // and shutting down the HTTP server.
  282. func (s *SSEServer) Shutdown(ctx context.Context) error {
  283. s.mu.RLock()
  284. srv := s.srv
  285. s.mu.RUnlock()
  286. if srv != nil {
  287. s.sessions.Range(func(key, value any) bool {
  288. if session, ok := value.(*sseSession); ok {
  289. close(session.done)
  290. }
  291. s.sessions.Delete(key)
  292. return true
  293. })
  294. return srv.Shutdown(ctx)
  295. }
  296. return nil
  297. }
  298. // handleSSE handles incoming SSE connection requests.
  299. // It sets up appropriate headers and creates a new session for the client.
  300. func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
  301. if r.Method != http.MethodGet {
  302. http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
  303. return
  304. }
  305. w.Header().Set("Content-Type", "text/event-stream")
  306. w.Header().Set("Cache-Control", "no-cache")
  307. w.Header().Set("Connection", "keep-alive")
  308. w.Header().Set("Access-Control-Allow-Origin", "*")
  309. flusher, ok := w.(http.Flusher)
  310. if !ok {
  311. http.Error(w, "Streaming unsupported", http.StatusInternalServerError)
  312. return
  313. }
  314. sessionID := uuid.New().String()
  315. session := &sseSession{
  316. done: make(chan struct{}),
  317. eventQueue: make(chan string, 100), // Buffer for events
  318. sessionID: sessionID,
  319. notificationChannel: make(chan mcp.JSONRPCNotification, 100),
  320. }
  321. s.sessions.Store(sessionID, session)
  322. defer s.sessions.Delete(sessionID)
  323. if err := s.server.RegisterSession(r.Context(), session); err != nil {
  324. http.Error(
  325. w,
  326. fmt.Sprintf("Session registration failed: %v", err),
  327. http.StatusInternalServerError,
  328. )
  329. return
  330. }
  331. defer s.server.UnregisterSession(r.Context(), sessionID)
  332. // Start notification handler for this session
  333. go func() {
  334. for {
  335. select {
  336. case notification := <-session.notificationChannel:
  337. eventData, err := json.Marshal(notification)
  338. if err == nil {
  339. select {
  340. case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
  341. // Event queued successfully
  342. case <-session.done:
  343. return
  344. }
  345. }
  346. case <-session.done:
  347. return
  348. case <-r.Context().Done():
  349. return
  350. }
  351. }
  352. }()
  353. // Start keep alive : ping
  354. if s.keepAlive {
  355. go func() {
  356. ticker := time.NewTicker(s.keepAliveInterval)
  357. defer ticker.Stop()
  358. for {
  359. select {
  360. case <-ticker.C:
  361. message := mcp.JSONRPCRequest{
  362. JSONRPC: "2.0",
  363. ID: mcp.NewRequestId(session.requestID.Add(1)),
  364. Request: mcp.Request{
  365. Method: "ping",
  366. },
  367. }
  368. messageBytes, _ := json.Marshal(message)
  369. pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes)
  370. select {
  371. case session.eventQueue <- pingMsg:
  372. // Message sent successfully
  373. case <-session.done:
  374. return
  375. }
  376. case <-session.done:
  377. return
  378. case <-r.Context().Done():
  379. return
  380. }
  381. }
  382. }()
  383. }
  384. // Send the initial endpoint event
  385. endpoint := s.GetMessageEndpointForClient(r, sessionID)
  386. if s.appendQueryToMessageEndpoint && len(r.URL.RawQuery) > 0 {
  387. endpoint += "&" + r.URL.RawQuery
  388. }
  389. fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", endpoint)
  390. flusher.Flush()
  391. // Main event loop - this runs in the HTTP handler goroutine
  392. for {
  393. select {
  394. case event := <-session.eventQueue:
  395. // Write the event to the response
  396. fmt.Fprint(w, event)
  397. flusher.Flush()
  398. case <-r.Context().Done():
  399. close(session.done)
  400. return
  401. case <-session.done:
  402. return
  403. }
  404. }
  405. }
  406. // GetMessageEndpointForClient returns the appropriate message endpoint URL with session ID
  407. // for the given request. This is the canonical way to compute the message endpoint for a client.
  408. // It handles both dynamic and static path modes, and honors the WithUseFullURLForMessageEndpoint flag.
  409. func (s *SSEServer) GetMessageEndpointForClient(r *http.Request, sessionID string) string {
  410. basePath := s.basePath
  411. if s.dynamicBasePathFunc != nil {
  412. basePath = s.dynamicBasePathFunc(r, sessionID)
  413. }
  414. endpointPath := normalizeURLPath(basePath, s.messageEndpoint)
  415. if s.useFullURLForMessageEndpoint && s.baseURL != "" {
  416. endpointPath = s.baseURL + endpointPath
  417. }
  418. return fmt.Sprintf("%s?sessionId=%s", endpointPath, sessionID)
  419. }
  420. // handleMessage processes incoming JSON-RPC messages from clients and sends responses
  421. // back through the SSE connection and 202 code to HTTP response.
  422. func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
  423. if r.Method != http.MethodPost {
  424. s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, "Method not allowed")
  425. return
  426. }
  427. sessionID := r.URL.Query().Get("sessionId")
  428. if sessionID == "" {
  429. s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId")
  430. return
  431. }
  432. sessionI, ok := s.sessions.Load(sessionID)
  433. if !ok {
  434. s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID")
  435. return
  436. }
  437. session := sessionI.(*sseSession)
  438. // Set the client context before handling the message
  439. ctx := s.server.WithContext(r.Context(), session)
  440. if s.contextFunc != nil {
  441. ctx = s.contextFunc(ctx, r)
  442. }
  443. // Parse message as raw JSON
  444. var rawMessage json.RawMessage
  445. if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil {
  446. s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "Parse error")
  447. return
  448. }
  449. // Create a context that preserves all values from parent ctx but won't be canceled when the parent is canceled.
  450. // this is required because the http ctx will be canceled when the client disconnects
  451. detachedCtx := context.WithoutCancel(ctx)
  452. // quick return request, send 202 Accepted with no body, then deal the message and sent response via SSE
  453. w.WriteHeader(http.StatusAccepted)
  454. // Create a new context for handling the message that will be canceled when the message handling is done
  455. messageCtx := context.WithValue(detachedCtx, requestHeader, r.Header)
  456. messageCtx, cancel := context.WithCancel(messageCtx)
  457. go func(ctx context.Context) {
  458. defer cancel()
  459. // Use the context that will be canceled when session is done
  460. // Process message through MCPServer
  461. response := s.server.HandleMessage(ctx, rawMessage)
  462. // Only send response if there is one (not for notifications)
  463. if response != nil {
  464. var message string
  465. if eventData, err := json.Marshal(response); err != nil {
  466. // If there is an error marshalling the response, send a generic error response
  467. log.Printf("failed to marshal response: %v", err)
  468. message = "event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n"
  469. } else {
  470. message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData)
  471. }
  472. // Queue the event for sending via SSE
  473. select {
  474. case session.eventQueue <- message:
  475. // Event queued successfully
  476. case <-session.done:
  477. // Session is closed, don't try to queue
  478. default:
  479. // Queue is full, log this situation
  480. log.Printf("Event queue full for session %s", sessionID)
  481. }
  482. }
  483. }(messageCtx)
  484. }
  485. // writeJSONRPCError writes a JSON-RPC error response with the given error details.
  486. func (s *SSEServer) writeJSONRPCError(
  487. w http.ResponseWriter,
  488. id any,
  489. code int,
  490. message string,
  491. ) {
  492. response := createErrorResponse(id, code, message)
  493. w.Header().Set("Content-Type", "application/json")
  494. w.WriteHeader(http.StatusBadRequest)
  495. if err := json.NewEncoder(w).Encode(response); err != nil {
  496. http.Error(
  497. w,
  498. fmt.Sprintf("Failed to encode response: %v", err),
  499. http.StatusInternalServerError,
  500. )
  501. return
  502. }
  503. }
  504. // SendEventToSession sends an event to a specific SSE session identified by sessionID.
  505. // Returns an error if the session is not found or closed.
  506. func (s *SSEServer) SendEventToSession(
  507. sessionID string,
  508. event any,
  509. ) error {
  510. sessionI, ok := s.sessions.Load(sessionID)
  511. if !ok {
  512. return fmt.Errorf("session not found: %s", sessionID)
  513. }
  514. session := sessionI.(*sseSession)
  515. eventData, err := json.Marshal(event)
  516. if err != nil {
  517. return err
  518. }
  519. // Queue the event for sending via SSE
  520. select {
  521. case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
  522. return nil
  523. case <-session.done:
  524. return fmt.Errorf("session closed")
  525. default:
  526. return fmt.Errorf("event queue full")
  527. }
  528. }
  529. func (s *SSEServer) GetUrlPath(input string) (string, error) {
  530. parse, err := url.Parse(input)
  531. if err != nil {
  532. return "", fmt.Errorf("failed to parse URL %s: %w", input, err)
  533. }
  534. return parse.Path, nil
  535. }
  536. func (s *SSEServer) CompleteSseEndpoint() (string, error) {
  537. if s.dynamicBasePathFunc != nil {
  538. return "", &ErrDynamicPathConfig{Method: "CompleteSseEndpoint"}
  539. }
  540. path := normalizeURLPath(s.basePath, s.sseEndpoint)
  541. return s.baseURL + path, nil
  542. }
  543. func (s *SSEServer) CompleteSsePath() string {
  544. path, err := s.CompleteSseEndpoint()
  545. if err != nil {
  546. return normalizeURLPath(s.basePath, s.sseEndpoint)
  547. }
  548. urlPath, err := s.GetUrlPath(path)
  549. if err != nil {
  550. return normalizeURLPath(s.basePath, s.sseEndpoint)
  551. }
  552. return urlPath
  553. }
  554. func (s *SSEServer) CompleteMessageEndpoint() (string, error) {
  555. if s.dynamicBasePathFunc != nil {
  556. return "", &ErrDynamicPathConfig{Method: "CompleteMessageEndpoint"}
  557. }
  558. path := normalizeURLPath(s.basePath, s.messageEndpoint)
  559. return s.baseURL + path, nil
  560. }
  561. func (s *SSEServer) CompleteMessagePath() string {
  562. path, err := s.CompleteMessageEndpoint()
  563. if err != nil {
  564. return normalizeURLPath(s.basePath, s.messageEndpoint)
  565. }
  566. urlPath, err := s.GetUrlPath(path)
  567. if err != nil {
  568. return normalizeURLPath(s.basePath, s.messageEndpoint)
  569. }
  570. return urlPath
  571. }
  572. // SSEHandler returns an http.Handler for the SSE endpoint.
  573. //
  574. // This method allows you to mount the SSE handler at any arbitrary path
  575. // using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is
  576. // intended for advanced scenarios where you want to control the routing or
  577. // support dynamic segments.
  578. //
  579. // IMPORTANT: When using this handler in advanced/dynamic mounting scenarios,
  580. // you must use the WithDynamicBasePath option to ensure the correct base path
  581. // is communicated to clients.
  582. //
  583. // Example usage:
  584. //
  585. // // Advanced/dynamic:
  586. // sseServer := NewSSEServer(mcpServer,
  587. // WithDynamicBasePath(func(r *http.Request, sessionID string) string {
  588. // tenant := r.PathValue("tenant")
  589. // return "/mcp/" + tenant
  590. // }),
  591. // WithBaseURL("http://localhost:8080")
  592. // )
  593. // mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler())
  594. // mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler())
  595. //
  596. // For non-dynamic cases, use ServeHTTP method instead.
  597. func (s *SSEServer) SSEHandler() http.Handler {
  598. return http.HandlerFunc(s.handleSSE)
  599. }
  600. // MessageHandler returns an http.Handler for the message endpoint.
  601. //
  602. // This method allows you to mount the message handler at any arbitrary path
  603. // using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is
  604. // intended for advanced scenarios where you want to control the routing or
  605. // support dynamic segments.
  606. //
  607. // IMPORTANT: When using this handler in advanced/dynamic mounting scenarios,
  608. // you must use the WithDynamicBasePath option to ensure the correct base path
  609. // is communicated to clients.
  610. //
  611. // Example usage:
  612. //
  613. // // Advanced/dynamic:
  614. // sseServer := NewSSEServer(mcpServer,
  615. // WithDynamicBasePath(func(r *http.Request, sessionID string) string {
  616. // tenant := r.PathValue("tenant")
  617. // return "/mcp/" + tenant
  618. // }),
  619. // WithBaseURL("http://localhost:8080")
  620. // )
  621. // mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler())
  622. // mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler())
  623. //
  624. // For non-dynamic cases, use ServeHTTP method instead.
  625. func (s *SSEServer) MessageHandler() http.Handler {
  626. return http.HandlerFunc(s.handleMessage)
  627. }
  628. // ServeHTTP implements the http.Handler interface.
  629. func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  630. if s.dynamicBasePathFunc != nil {
  631. http.Error(
  632. w,
  633. (&ErrDynamicPathConfig{Method: "ServeHTTP"}).Error(),
  634. http.StatusInternalServerError,
  635. )
  636. return
  637. }
  638. path := r.URL.Path
  639. // Use exact path matching rather than Contains
  640. ssePath := s.CompleteSsePath()
  641. if ssePath != "" && path == ssePath {
  642. s.handleSSE(w, r)
  643. return
  644. }
  645. messagePath := s.CompleteMessagePath()
  646. if messagePath != "" && path == messagePath {
  647. s.handleMessage(w, r)
  648. return
  649. }
  650. http.NotFound(w, r)
  651. }
  652. // normalizeURLPath joins path elements like path.Join but ensures the
  653. // result always starts with a leading slash and never ends with a slash
  654. func normalizeURLPath(elem ...string) string {
  655. joined := path.Join(elem...)
  656. // Ensure leading slash
  657. if !strings.HasPrefix(joined, "/") {
  658. joined = "/" + joined
  659. }
  660. // Remove trailing slash if not just "/"
  661. if len(joined) > 1 && strings.HasSuffix(joined, "/") {
  662. joined = joined[:len(joined)-1]
  663. }
  664. return joined
  665. }