session.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. package server
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/mark3labs/mcp-go/mcp"
  6. )
  7. // ClientSession represents an active session that can be used by MCPServer to interact with client.
  8. type ClientSession interface {
  9. // Initialize marks session as fully initialized and ready for notifications
  10. Initialize()
  11. // Initialized returns if session is ready to accept notifications
  12. Initialized() bool
  13. // NotificationChannel provides a channel suitable for sending notifications to client.
  14. NotificationChannel() chan<- mcp.JSONRPCNotification
  15. // SessionID is a unique identifier used to track user session.
  16. SessionID() string
  17. }
  18. // SessionWithLogging is an extension of ClientSession that can receive log message notifications and set log level
  19. type SessionWithLogging interface {
  20. ClientSession
  21. // SetLogLevel sets the minimum log level
  22. SetLogLevel(level mcp.LoggingLevel)
  23. // GetLogLevel retrieves the minimum log level
  24. GetLogLevel() mcp.LoggingLevel
  25. }
  26. // SessionWithTools is an extension of ClientSession that can store session-specific tool data
  27. type SessionWithTools interface {
  28. ClientSession
  29. // GetSessionTools returns the tools specific to this session, if any
  30. // This method must be thread-safe for concurrent access
  31. GetSessionTools() map[string]ServerTool
  32. // SetSessionTools sets tools specific to this session
  33. // This method must be thread-safe for concurrent access
  34. SetSessionTools(tools map[string]ServerTool)
  35. }
  36. // SessionWithClientInfo is an extension of ClientSession that can store client info
  37. type SessionWithClientInfo interface {
  38. ClientSession
  39. // GetClientInfo returns the client information for this session
  40. GetClientInfo() mcp.Implementation
  41. // SetClientInfo sets the client information for this session
  42. SetClientInfo(clientInfo mcp.Implementation)
  43. // GetClientCapabilities returns the client capabilities for this session
  44. GetClientCapabilities() mcp.ClientCapabilities
  45. // SetClientCapabilities sets the client capabilities for this session
  46. SetClientCapabilities(clientCapabilities mcp.ClientCapabilities)
  47. }
  48. // SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations
  49. type SessionWithStreamableHTTPConfig interface {
  50. ClientSession
  51. // UpgradeToSSEWhenReceiveNotification upgrades the client-server communication to SSE stream when the server
  52. // sends notifications to the client
  53. //
  54. // The protocol specification:
  55. // - If the server response contains any JSON-RPC notifications, it MUST either:
  56. // - Return Content-Type: text/event-stream to initiate an SSE stream, OR
  57. // - Return Content-Type: application/json for a single JSON object
  58. // - The client MUST support both response types.
  59. //
  60. // Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#sending-messages-to-the-server
  61. UpgradeToSSEWhenReceiveNotification()
  62. }
  63. // clientSessionKey is the context key for storing current client notification channel.
  64. type clientSessionKey struct{}
  65. // ClientSessionFromContext retrieves current client notification context from context.
  66. func ClientSessionFromContext(ctx context.Context) ClientSession {
  67. if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok {
  68. return session
  69. }
  70. return nil
  71. }
  72. // WithContext sets the current client session and returns the provided context
  73. func (s *MCPServer) WithContext(
  74. ctx context.Context,
  75. session ClientSession,
  76. ) context.Context {
  77. return context.WithValue(ctx, clientSessionKey{}, session)
  78. }
  79. // RegisterSession saves session that should be notified in case if some server attributes changed.
  80. func (s *MCPServer) RegisterSession(
  81. ctx context.Context,
  82. session ClientSession,
  83. ) error {
  84. sessionID := session.SessionID()
  85. if _, exists := s.sessions.LoadOrStore(sessionID, session); exists {
  86. return ErrSessionExists
  87. }
  88. s.hooks.RegisterSession(ctx, session)
  89. return nil
  90. }
  91. func (s *MCPServer) buildLogNotification(notification mcp.LoggingMessageNotification) mcp.JSONRPCNotification {
  92. return mcp.JSONRPCNotification{
  93. JSONRPC: mcp.JSONRPC_VERSION,
  94. Notification: mcp.Notification{
  95. Method: notification.Method,
  96. Params: mcp.NotificationParams{
  97. AdditionalFields: map[string]any{
  98. "level": notification.Params.Level,
  99. "logger": notification.Params.Logger,
  100. "data": notification.Params.Data,
  101. },
  102. },
  103. },
  104. }
  105. }
  106. func (s *MCPServer) SendLogMessageToClient(ctx context.Context, notification mcp.LoggingMessageNotification) error {
  107. session := ClientSessionFromContext(ctx)
  108. if session == nil || !session.Initialized() {
  109. return ErrNotificationNotInitialized
  110. }
  111. sessionLogging, ok := session.(SessionWithLogging)
  112. if !ok {
  113. return ErrSessionDoesNotSupportLogging
  114. }
  115. if !notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) {
  116. return nil
  117. }
  118. return s.sendNotificationCore(ctx, session, s.buildLogNotification(notification))
  119. }
  120. func (s *MCPServer) sendNotificationToAllClients(notification mcp.JSONRPCNotification) {
  121. s.sessions.Range(func(k, v any) bool {
  122. if session, ok := v.(ClientSession); ok && session.Initialized() {
  123. select {
  124. case session.NotificationChannel() <- notification:
  125. // Successfully sent notification
  126. default:
  127. // Channel is blocked, if there's an error hook, use it
  128. if s.hooks != nil && len(s.hooks.OnError) > 0 {
  129. err := ErrNotificationChannelBlocked
  130. // Copy hooks pointer to local variable to avoid race condition
  131. hooks := s.hooks
  132. go func(sessionID string, hooks *Hooks) {
  133. ctx := context.Background()
  134. // Use the error hook to report the blocked channel
  135. hooks.onError(ctx, nil, "notification", map[string]any{
  136. "method": notification.Method,
  137. "sessionID": sessionID,
  138. }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err))
  139. }(session.SessionID(), hooks)
  140. }
  141. }
  142. }
  143. return true
  144. })
  145. }
  146. func (s *MCPServer) sendNotificationToSpecificClient(session ClientSession, notification mcp.JSONRPCNotification) error {
  147. // upgrades the client-server communication to SSE stream when the server sends notifications to the client
  148. if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
  149. sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
  150. }
  151. select {
  152. case session.NotificationChannel() <- notification:
  153. return nil
  154. default:
  155. // Channel is blocked, if there's an error hook, use it
  156. if s.hooks != nil && len(s.hooks.OnError) > 0 {
  157. err := ErrNotificationChannelBlocked
  158. ctx := context.Background()
  159. // Copy hooks pointer to local variable to avoid race condition
  160. hooks := s.hooks
  161. go func(sID string, hooks *Hooks) {
  162. // Use the error hook to report the blocked channel
  163. hooks.onError(ctx, nil, "notification", map[string]any{
  164. "method": notification.Method,
  165. "sessionID": sID,
  166. }, fmt.Errorf("notification channel blocked for session %s: %w", sID, err))
  167. }(session.SessionID(), hooks)
  168. }
  169. return ErrNotificationChannelBlocked
  170. }
  171. }
  172. func (s *MCPServer) SendLogMessageToSpecificClient(sessionID string, notification mcp.LoggingMessageNotification) error {
  173. sessionValue, ok := s.sessions.Load(sessionID)
  174. if !ok {
  175. return ErrSessionNotFound
  176. }
  177. session, ok := sessionValue.(ClientSession)
  178. if !ok || !session.Initialized() {
  179. return ErrSessionNotInitialized
  180. }
  181. sessionLogging, ok := session.(SessionWithLogging)
  182. if !ok {
  183. return ErrSessionDoesNotSupportLogging
  184. }
  185. if !notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) {
  186. return nil
  187. }
  188. return s.sendNotificationToSpecificClient(session, s.buildLogNotification(notification))
  189. }
  190. // UnregisterSession removes from storage session that is shut down.
  191. func (s *MCPServer) UnregisterSession(
  192. ctx context.Context,
  193. sessionID string,
  194. ) {
  195. sessionValue, ok := s.sessions.LoadAndDelete(sessionID)
  196. if !ok {
  197. return
  198. }
  199. if session, ok := sessionValue.(ClientSession); ok {
  200. s.hooks.UnregisterSession(ctx, session)
  201. }
  202. }
  203. // SendNotificationToAllClients sends a notification to all the currently active clients.
  204. func (s *MCPServer) SendNotificationToAllClients(
  205. method string,
  206. params map[string]any,
  207. ) {
  208. notification := mcp.JSONRPCNotification{
  209. JSONRPC: mcp.JSONRPC_VERSION,
  210. Notification: mcp.Notification{
  211. Method: method,
  212. Params: mcp.NotificationParams{
  213. AdditionalFields: params,
  214. },
  215. },
  216. }
  217. s.sendNotificationToAllClients(notification)
  218. }
  219. // SendNotificationToClient sends a notification to the current client
  220. func (s *MCPServer) sendNotificationCore(
  221. ctx context.Context,
  222. session ClientSession,
  223. notification mcp.JSONRPCNotification,
  224. ) error {
  225. // upgrades the client-server communication to SSE stream when the server sends notifications to the client
  226. if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
  227. sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
  228. }
  229. select {
  230. case session.NotificationChannel() <- notification:
  231. return nil
  232. default:
  233. // Channel is blocked, if there's an error hook, use it
  234. if s.hooks != nil && len(s.hooks.OnError) > 0 {
  235. method := notification.Method
  236. err := ErrNotificationChannelBlocked
  237. // Copy hooks pointer to local variable to avoid race condition
  238. hooks := s.hooks
  239. go func(sessionID string, hooks *Hooks) {
  240. // Use the error hook to report the blocked channel
  241. hooks.onError(ctx, nil, "notification", map[string]any{
  242. "method": method,
  243. "sessionID": sessionID,
  244. }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err))
  245. }(session.SessionID(), hooks)
  246. }
  247. return ErrNotificationChannelBlocked
  248. }
  249. }
  250. // SendNotificationToClient sends a notification to the current client
  251. func (s *MCPServer) SendNotificationToClient(
  252. ctx context.Context,
  253. method string,
  254. params map[string]any,
  255. ) error {
  256. session := ClientSessionFromContext(ctx)
  257. if session == nil || !session.Initialized() {
  258. return ErrNotificationNotInitialized
  259. }
  260. notification := mcp.JSONRPCNotification{
  261. JSONRPC: mcp.JSONRPC_VERSION,
  262. Notification: mcp.Notification{
  263. Method: method,
  264. Params: mcp.NotificationParams{
  265. AdditionalFields: params,
  266. },
  267. },
  268. }
  269. return s.sendNotificationCore(ctx, session, notification)
  270. }
  271. // SendNotificationToSpecificClient sends a notification to a specific client by session ID
  272. func (s *MCPServer) SendNotificationToSpecificClient(
  273. sessionID string,
  274. method string,
  275. params map[string]any,
  276. ) error {
  277. sessionValue, ok := s.sessions.Load(sessionID)
  278. if !ok {
  279. return ErrSessionNotFound
  280. }
  281. session, ok := sessionValue.(ClientSession)
  282. if !ok || !session.Initialized() {
  283. return ErrSessionNotInitialized
  284. }
  285. notification := mcp.JSONRPCNotification{
  286. JSONRPC: mcp.JSONRPC_VERSION,
  287. Notification: mcp.Notification{
  288. Method: method,
  289. Params: mcp.NotificationParams{
  290. AdditionalFields: params,
  291. },
  292. },
  293. }
  294. return s.sendNotificationToSpecificClient(session, notification)
  295. }
  296. // AddSessionTool adds a tool for a specific session
  297. func (s *MCPServer) AddSessionTool(sessionID string, tool mcp.Tool, handler ToolHandlerFunc) error {
  298. return s.AddSessionTools(sessionID, ServerTool{Tool: tool, Handler: handler})
  299. }
  300. // AddSessionTools adds tools for a specific session
  301. func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error {
  302. sessionValue, ok := s.sessions.Load(sessionID)
  303. if !ok {
  304. return ErrSessionNotFound
  305. }
  306. session, ok := sessionValue.(SessionWithTools)
  307. if !ok {
  308. return ErrSessionDoesNotSupportTools
  309. }
  310. s.implicitlyRegisterToolCapabilities()
  311. // Get existing tools (this should return a thread-safe copy)
  312. sessionTools := session.GetSessionTools()
  313. // Create a new map to avoid concurrent modification issues
  314. newSessionTools := make(map[string]ServerTool, len(sessionTools)+len(tools))
  315. // Copy existing tools
  316. for k, v := range sessionTools {
  317. newSessionTools[k] = v
  318. }
  319. // Add new tools
  320. for _, tool := range tools {
  321. newSessionTools[tool.Tool.Name] = tool
  322. }
  323. // Set the tools (this should be thread-safe)
  324. session.SetSessionTools(newSessionTools)
  325. // It only makes sense to send tool notifications to initialized sessions --
  326. // if we're not initialized yet the client can't possibly have sent their
  327. // initial tools/list message.
  328. //
  329. // For initialized sessions, honor tools.listChanged, which is specifically
  330. // about whether notifications will be sent or not.
  331. // see <https://modelcontextprotocol.io/specification/2025-03-26/server/tools#capabilities>
  332. if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged {
  333. // Send notification only to this session
  334. if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil {
  335. // Log the error but don't fail the operation
  336. // The tools were successfully added, but notification failed
  337. if s.hooks != nil && len(s.hooks.OnError) > 0 {
  338. hooks := s.hooks
  339. go func(sID string, hooks *Hooks) {
  340. ctx := context.Background()
  341. hooks.onError(ctx, nil, "notification", map[string]any{
  342. "method": "notifications/tools/list_changed",
  343. "sessionID": sID,
  344. }, fmt.Errorf("failed to send notification after adding tools: %w", err))
  345. }(sessionID, hooks)
  346. }
  347. }
  348. }
  349. return nil
  350. }
  351. // DeleteSessionTools removes tools from a specific session
  352. func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error {
  353. sessionValue, ok := s.sessions.Load(sessionID)
  354. if !ok {
  355. return ErrSessionNotFound
  356. }
  357. session, ok := sessionValue.(SessionWithTools)
  358. if !ok {
  359. return ErrSessionDoesNotSupportTools
  360. }
  361. // Get existing tools (this should return a thread-safe copy)
  362. sessionTools := session.GetSessionTools()
  363. if sessionTools == nil {
  364. return nil
  365. }
  366. // Create a new map to avoid concurrent modification issues
  367. newSessionTools := make(map[string]ServerTool, len(sessionTools))
  368. // Copy existing tools except those being deleted
  369. for k, v := range sessionTools {
  370. newSessionTools[k] = v
  371. }
  372. // Remove specified tools
  373. for _, name := range names {
  374. delete(newSessionTools, name)
  375. }
  376. // Set the tools (this should be thread-safe)
  377. session.SetSessionTools(newSessionTools)
  378. // It only makes sense to send tool notifications to initialized sessions --
  379. // if we're not initialized yet the client can't possibly have sent their
  380. // initial tools/list message.
  381. //
  382. // For initialized sessions, honor tools.listChanged, which is specifically
  383. // about whether notifications will be sent or not.
  384. // see <https://modelcontextprotocol.io/specification/2025-03-26/server/tools#capabilities>
  385. if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged {
  386. // Send notification only to this session
  387. if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil {
  388. // Log the error but don't fail the operation
  389. // The tools were successfully deleted, but notification failed
  390. if s.hooks != nil && len(s.hooks.OnError) > 0 {
  391. hooks := s.hooks
  392. go func(sID string, hooks *Hooks) {
  393. ctx := context.Background()
  394. hooks.onError(ctx, nil, "notification", map[string]any{
  395. "method": "notifications/tools/list_changed",
  396. "sessionID": sID,
  397. }, fmt.Errorf("failed to send notification after deleting tools: %w", err))
  398. }(sessionID, hooks)
  399. }
  400. }
  401. }
  402. return nil
  403. }