sampling.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. package server
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/mark3labs/mcp-go/mcp"
  6. )
  7. // EnableSampling enables sampling capabilities for the server.
  8. // This allows the server to send sampling requests to clients that support it.
  9. func (s *MCPServer) EnableSampling() {
  10. s.capabilitiesMu.Lock()
  11. defer s.capabilitiesMu.Unlock()
  12. enabled := true
  13. s.capabilities.sampling = &enabled
  14. }
  15. // RequestSampling sends a sampling request to the client.
  16. // The client must have declared sampling capability during initialization.
  17. func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
  18. session := ClientSessionFromContext(ctx)
  19. if session == nil {
  20. return nil, fmt.Errorf("no active session")
  21. }
  22. // Check if the session supports sampling requests
  23. if samplingSession, ok := session.(SessionWithSampling); ok {
  24. return samplingSession.RequestSampling(ctx, request)
  25. }
  26. // Check for inprocess sampling handler in context
  27. if handler := InProcessSamplingHandlerFromContext(ctx); handler != nil {
  28. return handler.CreateMessage(ctx, request)
  29. }
  30. return nil, fmt.Errorf("session does not support sampling")
  31. }
  32. // SessionWithSampling extends ClientSession to support sampling requests.
  33. type SessionWithSampling interface {
  34. ClientSession
  35. RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error)
  36. }
  type (
  	// inProcessSamplingHandlerKey is the unexported, empty-struct context key
  	// under which an in-process SamplingHandler is stored. Because the type is
  	// private to this package, no other package can construct a colliding key.
  	inProcessSamplingHandlerKey struct{}
  )
  39. // WithInProcessSamplingHandler adds a sampling handler to the context for inprocess clients
  40. func WithInProcessSamplingHandler(ctx context.Context, handler SamplingHandler) context.Context {
  41. return context.WithValue(ctx, inProcessSamplingHandlerKey{}, handler)
  42. }
  43. // InProcessSamplingHandlerFromContext retrieves the inprocess sampling handler from context
  44. func InProcessSamplingHandlerFromContext(ctx context.Context) SamplingHandler {
  45. if handler, ok := ctx.Value(inProcessSamplingHandlerKey{}).(SamplingHandler); ok {
  46. return handler
  47. }
  48. return nil
  49. }