| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444 |
- package server
- import (
- "context"
- "fmt"
- "github.com/mark3labs/mcp-go/mcp"
- )
- // ClientSession represents an active session that can be used by MCPServer to interact with client.
- type ClientSession interface {
- // Initialize marks session as fully initialized and ready for notifications
- Initialize()
- // Initialized returns if session is ready to accept notifications
- Initialized() bool
- // NotificationChannel provides a channel suitable for sending notifications to client.
- NotificationChannel() chan<- mcp.JSONRPCNotification
- // SessionID is a unique identifier used to track user session.
- SessionID() string
- }
- // SessionWithLogging is an extension of ClientSession that can receive log message notifications and set log level
- type SessionWithLogging interface {
- ClientSession
- // SetLogLevel sets the minimum log level
- SetLogLevel(level mcp.LoggingLevel)
- // GetLogLevel retrieves the minimum log level
- GetLogLevel() mcp.LoggingLevel
- }
- // SessionWithTools is an extension of ClientSession that can store session-specific tool data
- type SessionWithTools interface {
- ClientSession
- // GetSessionTools returns the tools specific to this session, if any
- // This method must be thread-safe for concurrent access
- GetSessionTools() map[string]ServerTool
- // SetSessionTools sets tools specific to this session
- // This method must be thread-safe for concurrent access
- SetSessionTools(tools map[string]ServerTool)
- }
- // SessionWithClientInfo is an extension of ClientSession that can store client info
- type SessionWithClientInfo interface {
- ClientSession
- // GetClientInfo returns the client information for this session
- GetClientInfo() mcp.Implementation
- // SetClientInfo sets the client information for this session
- SetClientInfo(clientInfo mcp.Implementation)
- // GetClientCapabilities returns the client capabilities for this session
- GetClientCapabilities() mcp.ClientCapabilities
- // SetClientCapabilities sets the client capabilities for this session
- SetClientCapabilities(clientCapabilities mcp.ClientCapabilities)
- }
- // SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations
- type SessionWithStreamableHTTPConfig interface {
- ClientSession
- // UpgradeToSSEWhenReceiveNotification upgrades the client-server communication to SSE stream when the server
- // sends notifications to the client
- //
- // The protocol specification:
- // - If the server response contains any JSON-RPC notifications, it MUST either:
- // - Return Content-Type: text/event-stream to initiate an SSE stream, OR
- // - Return Content-Type: application/json for a single JSON object
- // - The client MUST support both response types.
- //
- // Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#sending-messages-to-the-server
- UpgradeToSSEWhenReceiveNotification()
- }
// clientSessionKey is the context key under which WithContext stores the
// current ClientSession (read back by ClientSessionFromContext). It is an
// unexported struct type so no other package can construct a colliding key.
type clientSessionKey struct{}
- // ClientSessionFromContext retrieves current client notification context from context.
- func ClientSessionFromContext(ctx context.Context) ClientSession {
- if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok {
- return session
- }
- return nil
- }
- // WithContext sets the current client session and returns the provided context
- func (s *MCPServer) WithContext(
- ctx context.Context,
- session ClientSession,
- ) context.Context {
- return context.WithValue(ctx, clientSessionKey{}, session)
- }
- // RegisterSession saves session that should be notified in case if some server attributes changed.
- func (s *MCPServer) RegisterSession(
- ctx context.Context,
- session ClientSession,
- ) error {
- sessionID := session.SessionID()
- if _, exists := s.sessions.LoadOrStore(sessionID, session); exists {
- return ErrSessionExists
- }
- s.hooks.RegisterSession(ctx, session)
- return nil
- }
- func (s *MCPServer) buildLogNotification(notification mcp.LoggingMessageNotification) mcp.JSONRPCNotification {
- return mcp.JSONRPCNotification{
- JSONRPC: mcp.JSONRPC_VERSION,
- Notification: mcp.Notification{
- Method: notification.Method,
- Params: mcp.NotificationParams{
- AdditionalFields: map[string]any{
- "level": notification.Params.Level,
- "logger": notification.Params.Logger,
- "data": notification.Params.Data,
- },
- },
- },
- }
- }
- func (s *MCPServer) SendLogMessageToClient(ctx context.Context, notification mcp.LoggingMessageNotification) error {
- session := ClientSessionFromContext(ctx)
- if session == nil || !session.Initialized() {
- return ErrNotificationNotInitialized
- }
- sessionLogging, ok := session.(SessionWithLogging)
- if !ok {
- return ErrSessionDoesNotSupportLogging
- }
- if !notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) {
- return nil
- }
- return s.sendNotificationCore(ctx, session, s.buildLogNotification(notification))
- }
- func (s *MCPServer) sendNotificationToAllClients(notification mcp.JSONRPCNotification) {
- s.sessions.Range(func(k, v any) bool {
- if session, ok := v.(ClientSession); ok && session.Initialized() {
- select {
- case session.NotificationChannel() <- notification:
- // Successfully sent notification
- default:
- // Channel is blocked, if there's an error hook, use it
- if s.hooks != nil && len(s.hooks.OnError) > 0 {
- err := ErrNotificationChannelBlocked
- // Copy hooks pointer to local variable to avoid race condition
- hooks := s.hooks
- go func(sessionID string, hooks *Hooks) {
- ctx := context.Background()
- // Use the error hook to report the blocked channel
- hooks.onError(ctx, nil, "notification", map[string]any{
- "method": notification.Method,
- "sessionID": sessionID,
- }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err))
- }(session.SessionID(), hooks)
- }
- }
- }
- return true
- })
- }
- func (s *MCPServer) sendNotificationToSpecificClient(session ClientSession, notification mcp.JSONRPCNotification) error {
- // upgrades the client-server communication to SSE stream when the server sends notifications to the client
- if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
- sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
- }
- select {
- case session.NotificationChannel() <- notification:
- return nil
- default:
- // Channel is blocked, if there's an error hook, use it
- if s.hooks != nil && len(s.hooks.OnError) > 0 {
- err := ErrNotificationChannelBlocked
- ctx := context.Background()
- // Copy hooks pointer to local variable to avoid race condition
- hooks := s.hooks
- go func(sID string, hooks *Hooks) {
- // Use the error hook to report the blocked channel
- hooks.onError(ctx, nil, "notification", map[string]any{
- "method": notification.Method,
- "sessionID": sID,
- }, fmt.Errorf("notification channel blocked for session %s: %w", sID, err))
- }(session.SessionID(), hooks)
- }
- return ErrNotificationChannelBlocked
- }
- }
- func (s *MCPServer) SendLogMessageToSpecificClient(sessionID string, notification mcp.LoggingMessageNotification) error {
- sessionValue, ok := s.sessions.Load(sessionID)
- if !ok {
- return ErrSessionNotFound
- }
- session, ok := sessionValue.(ClientSession)
- if !ok || !session.Initialized() {
- return ErrSessionNotInitialized
- }
- sessionLogging, ok := session.(SessionWithLogging)
- if !ok {
- return ErrSessionDoesNotSupportLogging
- }
- if !notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) {
- return nil
- }
- return s.sendNotificationToSpecificClient(session, s.buildLogNotification(notification))
- }
- // UnregisterSession removes from storage session that is shut down.
- func (s *MCPServer) UnregisterSession(
- ctx context.Context,
- sessionID string,
- ) {
- sessionValue, ok := s.sessions.LoadAndDelete(sessionID)
- if !ok {
- return
- }
- if session, ok := sessionValue.(ClientSession); ok {
- s.hooks.UnregisterSession(ctx, session)
- }
- }
- // SendNotificationToAllClients sends a notification to all the currently active clients.
- func (s *MCPServer) SendNotificationToAllClients(
- method string,
- params map[string]any,
- ) {
- notification := mcp.JSONRPCNotification{
- JSONRPC: mcp.JSONRPC_VERSION,
- Notification: mcp.Notification{
- Method: method,
- Params: mcp.NotificationParams{
- AdditionalFields: params,
- },
- },
- }
- s.sendNotificationToAllClients(notification)
- }
- // SendNotificationToClient sends a notification to the current client
- func (s *MCPServer) sendNotificationCore(
- ctx context.Context,
- session ClientSession,
- notification mcp.JSONRPCNotification,
- ) error {
- // upgrades the client-server communication to SSE stream when the server sends notifications to the client
- if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok {
- sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification()
- }
- select {
- case session.NotificationChannel() <- notification:
- return nil
- default:
- // Channel is blocked, if there's an error hook, use it
- if s.hooks != nil && len(s.hooks.OnError) > 0 {
- method := notification.Method
- err := ErrNotificationChannelBlocked
- // Copy hooks pointer to local variable to avoid race condition
- hooks := s.hooks
- go func(sessionID string, hooks *Hooks) {
- // Use the error hook to report the blocked channel
- hooks.onError(ctx, nil, "notification", map[string]any{
- "method": method,
- "sessionID": sessionID,
- }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err))
- }(session.SessionID(), hooks)
- }
- return ErrNotificationChannelBlocked
- }
- }
- // SendNotificationToClient sends a notification to the current client
- func (s *MCPServer) SendNotificationToClient(
- ctx context.Context,
- method string,
- params map[string]any,
- ) error {
- session := ClientSessionFromContext(ctx)
- if session == nil || !session.Initialized() {
- return ErrNotificationNotInitialized
- }
- notification := mcp.JSONRPCNotification{
- JSONRPC: mcp.JSONRPC_VERSION,
- Notification: mcp.Notification{
- Method: method,
- Params: mcp.NotificationParams{
- AdditionalFields: params,
- },
- },
- }
- return s.sendNotificationCore(ctx, session, notification)
- }
- // SendNotificationToSpecificClient sends a notification to a specific client by session ID
- func (s *MCPServer) SendNotificationToSpecificClient(
- sessionID string,
- method string,
- params map[string]any,
- ) error {
- sessionValue, ok := s.sessions.Load(sessionID)
- if !ok {
- return ErrSessionNotFound
- }
- session, ok := sessionValue.(ClientSession)
- if !ok || !session.Initialized() {
- return ErrSessionNotInitialized
- }
- notification := mcp.JSONRPCNotification{
- JSONRPC: mcp.JSONRPC_VERSION,
- Notification: mcp.Notification{
- Method: method,
- Params: mcp.NotificationParams{
- AdditionalFields: params,
- },
- },
- }
- return s.sendNotificationToSpecificClient(session, notification)
- }
- // AddSessionTool adds a tool for a specific session
- func (s *MCPServer) AddSessionTool(sessionID string, tool mcp.Tool, handler ToolHandlerFunc) error {
- return s.AddSessionTools(sessionID, ServerTool{Tool: tool, Handler: handler})
- }
- // AddSessionTools adds tools for a specific session
- func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error {
- sessionValue, ok := s.sessions.Load(sessionID)
- if !ok {
- return ErrSessionNotFound
- }
- session, ok := sessionValue.(SessionWithTools)
- if !ok {
- return ErrSessionDoesNotSupportTools
- }
- s.implicitlyRegisterToolCapabilities()
- // Get existing tools (this should return a thread-safe copy)
- sessionTools := session.GetSessionTools()
- // Create a new map to avoid concurrent modification issues
- newSessionTools := make(map[string]ServerTool, len(sessionTools)+len(tools))
- // Copy existing tools
- for k, v := range sessionTools {
- newSessionTools[k] = v
- }
- // Add new tools
- for _, tool := range tools {
- newSessionTools[tool.Tool.Name] = tool
- }
- // Set the tools (this should be thread-safe)
- session.SetSessionTools(newSessionTools)
- // It only makes sense to send tool notifications to initialized sessions --
- // if we're not initialized yet the client can't possibly have sent their
- // initial tools/list message.
- //
- // For initialized sessions, honor tools.listChanged, which is specifically
- // about whether notifications will be sent or not.
- // see <https://modelcontextprotocol.io/specification/2025-03-26/server/tools#capabilities>
- if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged {
- // Send notification only to this session
- if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil {
- // Log the error but don't fail the operation
- // The tools were successfully added, but notification failed
- if s.hooks != nil && len(s.hooks.OnError) > 0 {
- hooks := s.hooks
- go func(sID string, hooks *Hooks) {
- ctx := context.Background()
- hooks.onError(ctx, nil, "notification", map[string]any{
- "method": "notifications/tools/list_changed",
- "sessionID": sID,
- }, fmt.Errorf("failed to send notification after adding tools: %w", err))
- }(sessionID, hooks)
- }
- }
- }
- return nil
- }
- // DeleteSessionTools removes tools from a specific session
- func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error {
- sessionValue, ok := s.sessions.Load(sessionID)
- if !ok {
- return ErrSessionNotFound
- }
- session, ok := sessionValue.(SessionWithTools)
- if !ok {
- return ErrSessionDoesNotSupportTools
- }
- // Get existing tools (this should return a thread-safe copy)
- sessionTools := session.GetSessionTools()
- if sessionTools == nil {
- return nil
- }
- // Create a new map to avoid concurrent modification issues
- newSessionTools := make(map[string]ServerTool, len(sessionTools))
- // Copy existing tools except those being deleted
- for k, v := range sessionTools {
- newSessionTools[k] = v
- }
- // Remove specified tools
- for _, name := range names {
- delete(newSessionTools, name)
- }
- // Set the tools (this should be thread-safe)
- session.SetSessionTools(newSessionTools)
- // It only makes sense to send tool notifications to initialized sessions --
- // if we're not initialized yet the client can't possibly have sent their
- // initial tools/list message.
- //
- // For initialized sessions, honor tools.listChanged, which is specifically
- // about whether notifications will be sent or not.
- // see <https://modelcontextprotocol.io/specification/2025-03-26/server/tools#capabilities>
- if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged {
- // Send notification only to this session
- if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil {
- // Log the error but don't fail the operation
- // The tools were successfully deleted, but notification failed
- if s.hooks != nil && len(s.hooks.OnError) > 0 {
- hooks := s.hooks
- go func(sID string, hooks *Hooks) {
- ctx := context.Background()
- hooks.onError(ctx, nil, "notification", map[string]any{
- "method": "notifications/tools/list_changed",
- "sessionID": sID,
- }, fmt.Errorf("failed to send notification after deleting tools: %w", err))
- }(sessionID, hooks)
- }
- }
- }
- return nil
- }
|