inprocess_session.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. package server
  2. import (
  3. "context"
  4. "fmt"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. "github.com/mark3labs/mcp-go/mcp"
  9. )
  10. // SamplingHandler defines the interface for handling sampling requests from servers.
  11. type SamplingHandler interface {
  12. CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error)
  13. }
  14. type InProcessSession struct {
  15. sessionID string
  16. notifications chan mcp.JSONRPCNotification
  17. initialized atomic.Bool
  18. loggingLevel atomic.Value
  19. clientInfo atomic.Value
  20. clientCapabilities atomic.Value
  21. samplingHandler SamplingHandler
  22. mu sync.RWMutex
  23. }
  24. func NewInProcessSession(sessionID string, samplingHandler SamplingHandler) *InProcessSession {
  25. return &InProcessSession{
  26. sessionID: sessionID,
  27. notifications: make(chan mcp.JSONRPCNotification, 100),
  28. samplingHandler: samplingHandler,
  29. }
  30. }
  31. func (s *InProcessSession) SessionID() string {
  32. return s.sessionID
  33. }
  34. func (s *InProcessSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
  35. return s.notifications
  36. }
  37. func (s *InProcessSession) Initialize() {
  38. s.loggingLevel.Store(mcp.LoggingLevelError)
  39. s.initialized.Store(true)
  40. }
  41. func (s *InProcessSession) Initialized() bool {
  42. return s.initialized.Load()
  43. }
  44. func (s *InProcessSession) GetClientInfo() mcp.Implementation {
  45. if value := s.clientInfo.Load(); value != nil {
  46. if clientInfo, ok := value.(mcp.Implementation); ok {
  47. return clientInfo
  48. }
  49. }
  50. return mcp.Implementation{}
  51. }
  52. func (s *InProcessSession) SetClientInfo(clientInfo mcp.Implementation) {
  53. s.clientInfo.Store(clientInfo)
  54. }
  55. func (s *InProcessSession) GetClientCapabilities() mcp.ClientCapabilities {
  56. if value := s.clientCapabilities.Load(); value != nil {
  57. if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok {
  58. return clientCapabilities
  59. }
  60. }
  61. return mcp.ClientCapabilities{}
  62. }
  63. func (s *InProcessSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) {
  64. s.clientCapabilities.Store(clientCapabilities)
  65. }
  66. func (s *InProcessSession) SetLogLevel(level mcp.LoggingLevel) {
  67. s.loggingLevel.Store(level)
  68. }
  69. func (s *InProcessSession) GetLogLevel() mcp.LoggingLevel {
  70. level := s.loggingLevel.Load()
  71. if level == nil {
  72. return mcp.LoggingLevelError
  73. }
  74. return level.(mcp.LoggingLevel)
  75. }
  76. func (s *InProcessSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
  77. s.mu.RLock()
  78. handler := s.samplingHandler
  79. s.mu.RUnlock()
  80. if handler == nil {
  81. return nil, fmt.Errorf("no sampling handler available")
  82. }
  83. return handler.CreateMessage(ctx, request)
  84. }
  // inProcessSessionSeq is a process-wide counter appended to every generated
// in-process session ID so IDs stay distinct even when two are created
// within the same clock tick.
var inProcessSessionSeq atomic.Uint64

// GenerateInProcessSessionID generates a unique session ID for in-process
// clients, of the form "inprocess-<unix-nanos>-<seq>".
//
// The timestamp alone is not enough: time.Now().UnixNano() can return the
// same value for back-to-back calls because clock resolution is
// platform-dependent. The atomic sequence number keeps IDs unique within
// this process. Callers should treat the ID as opaque apart from its
// "inprocess-" prefix.
func GenerateInProcessSessionID() string {
	seq := inProcessSessionSeq.Add(1)
	return fmt.Sprintf("inprocess-%d-%d", time.Now().UnixNano(), seq)
}
  89. // Ensure interface compliance
  90. var (
  91. _ ClientSession = (*InProcessSession)(nil)
  92. _ SessionWithLogging = (*InProcessSession)(nil)
  93. _ SessionWithClientInfo = (*InProcessSession)(nil)
  94. _ SessionWithSampling = (*InProcessSession)(nil)
  95. )