stdio.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. package server
  2. import (
  3. "bufio"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "log"
  9. "os"
  10. "os/signal"
  11. "sync"
  12. "sync/atomic"
  13. "syscall"
  14. "github.com/mark3labs/mcp-go/mcp"
  15. )
  16. // StdioContextFunc is a function that takes an existing context and returns
  17. // a potentially modified context.
  18. // This can be used to inject context values from environment variables,
  19. // for example.
  20. type StdioContextFunc func(ctx context.Context) context.Context
  21. // StdioServer wraps a MCPServer and handles stdio communication.
  22. // It provides a simple way to create command-line MCP servers that
  23. // communicate via standard input/output streams using JSON-RPC messages.
  24. type StdioServer struct {
  25. server *MCPServer
  26. errLogger *log.Logger
  27. contextFunc StdioContextFunc
  28. // Thread-safe tool call processing
  29. toolCallQueue chan *toolCallWork
  30. workerWg sync.WaitGroup
  31. workerPoolSize int
  32. queueSize int
  33. writeMu sync.Mutex // Protects concurrent writes
  34. }
  35. // toolCallWork represents a queued tool call request
  36. type toolCallWork struct {
  37. ctx context.Context
  38. message json.RawMessage
  39. writer io.Writer
  40. }
  41. // StdioOption defines a function type for configuring StdioServer
  42. type StdioOption func(*StdioServer)
  43. // WithErrorLogger sets the error logger for the server
  44. func WithErrorLogger(logger *log.Logger) StdioOption {
  45. return func(s *StdioServer) {
  46. s.errLogger = logger
  47. }
  48. }
  49. // WithStdioContextFunc sets a function that will be called to customise the context
  50. // to the server. Note that the stdio server uses the same context for all requests,
  51. // so this function will only be called once per server instance.
  52. func WithStdioContextFunc(fn StdioContextFunc) StdioOption {
  53. return func(s *StdioServer) {
  54. s.contextFunc = fn
  55. }
  56. }
  57. // WithWorkerPoolSize sets the number of workers for processing tool calls
  58. func WithWorkerPoolSize(size int) StdioOption {
  59. return func(s *StdioServer) {
  60. const maxWorkerPoolSize = 100
  61. if size > 0 && size <= maxWorkerPoolSize {
  62. s.workerPoolSize = size
  63. } else if size > maxWorkerPoolSize {
  64. s.errLogger.Printf("Worker pool size %d exceeds maximum (%d), using maximum", size, maxWorkerPoolSize)
  65. s.workerPoolSize = maxWorkerPoolSize
  66. }
  67. }
  68. }
  69. // WithQueueSize sets the size of the tool call queue
  70. func WithQueueSize(size int) StdioOption {
  71. return func(s *StdioServer) {
  72. const maxQueueSize = 10000
  73. if size > 0 && size <= maxQueueSize {
  74. s.queueSize = size
  75. } else if size > maxQueueSize {
  76. s.errLogger.Printf("Queue size %d exceeds maximum (%d), using maximum", size, maxQueueSize)
  77. s.queueSize = maxQueueSize
  78. }
  79. }
  80. }
  81. // stdioSession is a static client session, since stdio has only one client.
  82. type stdioSession struct {
  83. notifications chan mcp.JSONRPCNotification
  84. initialized atomic.Bool
  85. loggingLevel atomic.Value
  86. clientInfo atomic.Value // stores session-specific client info
  87. clientCapabilities atomic.Value // stores session-specific client capabilities
  88. writer io.Writer // for sending requests to client
  89. requestID atomic.Int64 // for generating unique request IDs
  90. mu sync.RWMutex // protects writer
  91. pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests
  92. pendingMu sync.RWMutex // protects pendingRequests
  93. }
  94. // samplingResponse represents a response to a sampling request
  95. type samplingResponse struct {
  96. result *mcp.CreateMessageResult
  97. err error
  98. }
  99. func (s *stdioSession) SessionID() string {
  100. return "stdio"
  101. }
  102. func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification {
  103. return s.notifications
  104. }
  105. func (s *stdioSession) Initialize() {
  106. // set default logging level
  107. s.loggingLevel.Store(mcp.LoggingLevelError)
  108. s.initialized.Store(true)
  109. }
  110. func (s *stdioSession) Initialized() bool {
  111. return s.initialized.Load()
  112. }
  113. func (s *stdioSession) GetClientInfo() mcp.Implementation {
  114. if value := s.clientInfo.Load(); value != nil {
  115. if clientInfo, ok := value.(mcp.Implementation); ok {
  116. return clientInfo
  117. }
  118. }
  119. return mcp.Implementation{}
  120. }
  121. func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) {
  122. s.clientInfo.Store(clientInfo)
  123. }
  124. func (s *stdioSession) GetClientCapabilities() mcp.ClientCapabilities {
  125. if value := s.clientCapabilities.Load(); value != nil {
  126. if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok {
  127. return clientCapabilities
  128. }
  129. }
  130. return mcp.ClientCapabilities{}
  131. }
  132. func (s *stdioSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) {
  133. s.clientCapabilities.Store(clientCapabilities)
  134. }
  135. func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) {
  136. s.loggingLevel.Store(level)
  137. }
  138. func (s *stdioSession) GetLogLevel() mcp.LoggingLevel {
  139. level := s.loggingLevel.Load()
  140. if level == nil {
  141. return mcp.LoggingLevelError
  142. }
  143. return level.(mcp.LoggingLevel)
  144. }
  145. // RequestSampling sends a sampling request to the client and waits for the response.
  146. func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
  147. s.mu.RLock()
  148. writer := s.writer
  149. s.mu.RUnlock()
  150. if writer == nil {
  151. return nil, fmt.Errorf("no writer available for sending requests")
  152. }
  153. // Generate a unique request ID
  154. id := s.requestID.Add(1)
  155. // Create a response channel for this request
  156. responseChan := make(chan *samplingResponse, 1)
  157. s.pendingMu.Lock()
  158. s.pendingRequests[id] = responseChan
  159. s.pendingMu.Unlock()
  160. // Cleanup function to remove the pending request
  161. cleanup := func() {
  162. s.pendingMu.Lock()
  163. delete(s.pendingRequests, id)
  164. s.pendingMu.Unlock()
  165. }
  166. defer cleanup()
  167. // Create the JSON-RPC request
  168. jsonRPCRequest := struct {
  169. JSONRPC string `json:"jsonrpc"`
  170. ID int64 `json:"id"`
  171. Method string `json:"method"`
  172. Params mcp.CreateMessageParams `json:"params"`
  173. }{
  174. JSONRPC: mcp.JSONRPC_VERSION,
  175. ID: id,
  176. Method: string(mcp.MethodSamplingCreateMessage),
  177. Params: request.CreateMessageParams,
  178. }
  179. // Marshal and send the request
  180. requestBytes, err := json.Marshal(jsonRPCRequest)
  181. if err != nil {
  182. return nil, fmt.Errorf("failed to marshal sampling request: %w", err)
  183. }
  184. requestBytes = append(requestBytes, '\n')
  185. if _, err := writer.Write(requestBytes); err != nil {
  186. return nil, fmt.Errorf("failed to write sampling request: %w", err)
  187. }
  188. // Wait for the response or context cancellation
  189. select {
  190. case <-ctx.Done():
  191. return nil, ctx.Err()
  192. case response := <-responseChan:
  193. if response.err != nil {
  194. return nil, response.err
  195. }
  196. return response.result, nil
  197. }
  198. }
  199. // SetWriter sets the writer for sending requests to the client.
  200. func (s *stdioSession) SetWriter(writer io.Writer) {
  201. s.mu.Lock()
  202. defer s.mu.Unlock()
  203. s.writer = writer
  204. }
  205. var (
  206. _ ClientSession = (*stdioSession)(nil)
  207. _ SessionWithLogging = (*stdioSession)(nil)
  208. _ SessionWithClientInfo = (*stdioSession)(nil)
  209. _ SessionWithSampling = (*stdioSession)(nil)
  210. )
  211. var stdioSessionInstance = stdioSession{
  212. notifications: make(chan mcp.JSONRPCNotification, 100),
  213. pendingRequests: make(map[int64]chan *samplingResponse),
  214. }
  215. // NewStdioServer creates a new stdio server wrapper around an MCPServer.
  216. // It initializes the server with a default error logger that discards all output.
  217. func NewStdioServer(server *MCPServer) *StdioServer {
  218. return &StdioServer{
  219. server: server,
  220. errLogger: log.New(
  221. os.Stderr,
  222. "",
  223. log.LstdFlags,
  224. ), // Default to discarding logs
  225. workerPoolSize: 5, // Default worker pool size
  226. queueSize: 100, // Default queue size
  227. }
  228. }
  229. // SetErrorLogger configures where error messages from the StdioServer are logged.
  230. // The provided logger will receive all error messages generated during server operation.
  231. func (s *StdioServer) SetErrorLogger(logger *log.Logger) {
  232. s.errLogger = logger
  233. }
  234. // SetContextFunc sets a function that will be called to customise the context
  235. // to the server. Note that the stdio server uses the same context for all requests,
  236. // so this function will only be called once per server instance.
  237. func (s *StdioServer) SetContextFunc(fn StdioContextFunc) {
  238. s.contextFunc = fn
  239. }
  240. // handleNotifications continuously processes notifications from the session's notification channel
  241. // and writes them to the provided output. It runs until the context is cancelled.
  242. // Any errors encountered while writing notifications are logged but do not stop the handler.
  243. func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) {
  244. for {
  245. select {
  246. case notification := <-stdioSessionInstance.notifications:
  247. if err := s.writeResponse(notification, stdout); err != nil {
  248. s.errLogger.Printf("Error writing notification: %v", err)
  249. }
  250. case <-ctx.Done():
  251. return
  252. }
  253. }
  254. }
  255. // processInputStream continuously reads and processes messages from the input stream.
  256. // It handles EOF gracefully as a normal termination condition.
  257. // The function returns when either:
  258. // - The context is cancelled (returns context.Err())
  259. // - EOF is encountered (returns nil)
  260. // - An error occurs while reading or processing messages (returns the error)
  261. func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error {
  262. for {
  263. if err := ctx.Err(); err != nil {
  264. return err
  265. }
  266. line, err := s.readNextLine(ctx, reader)
  267. if err != nil {
  268. if err == io.EOF {
  269. return nil
  270. }
  271. s.errLogger.Printf("Error reading input: %v", err)
  272. return err
  273. }
  274. if err := s.processMessage(ctx, line, stdout); err != nil {
  275. if err == io.EOF {
  276. return nil
  277. }
  278. s.errLogger.Printf("Error handling message: %v", err)
  279. return err
  280. }
  281. }
  282. }
  283. // toolCallWorker processes tool calls from the queue
  284. func (s *StdioServer) toolCallWorker(ctx context.Context) {
  285. defer s.workerWg.Done()
  286. for {
  287. select {
  288. case work, ok := <-s.toolCallQueue:
  289. if !ok {
  290. // Channel closed, exit worker
  291. return
  292. }
  293. // Process the tool call
  294. response := s.server.HandleMessage(work.ctx, work.message)
  295. if response != nil {
  296. if err := s.writeResponse(response, work.writer); err != nil {
  297. s.errLogger.Printf("Error writing tool response: %v", err)
  298. }
  299. }
  300. case <-ctx.Done():
  301. return
  302. }
  303. }
  304. }
  305. // readNextLine reads a single line from the input reader in a context-aware manner.
  306. // It uses channels to make the read operation cancellable via context.
  307. // Returns the read line and any error encountered. If the context is cancelled,
  308. // returns an empty string and the context's error. EOF is returned when the input
  309. // stream is closed.
  310. func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) {
  311. type result struct {
  312. line string
  313. err error
  314. }
  315. resultCh := make(chan result, 1)
  316. go func() {
  317. line, err := reader.ReadString('\n')
  318. resultCh <- result{line: line, err: err}
  319. }()
  320. select {
  321. case <-ctx.Done():
  322. return "", nil
  323. case res := <-resultCh:
  324. return res.line, res.err
  325. }
  326. }
  327. // Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output.
  328. // It runs until the context is cancelled or an error occurs.
  329. // Returns an error if there are issues with reading input or writing output.
  330. func (s *StdioServer) Listen(
  331. ctx context.Context,
  332. stdin io.Reader,
  333. stdout io.Writer,
  334. ) error {
  335. // Initialize the tool call queue
  336. s.toolCallQueue = make(chan *toolCallWork, s.queueSize)
  337. // Set a static client context since stdio only has one client
  338. if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil {
  339. return fmt.Errorf("register session: %w", err)
  340. }
  341. defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID())
  342. ctx = s.server.WithContext(ctx, &stdioSessionInstance)
  343. // Set the writer for sending requests to the client
  344. stdioSessionInstance.SetWriter(stdout)
  345. // Add in any custom context.
  346. if s.contextFunc != nil {
  347. ctx = s.contextFunc(ctx)
  348. }
  349. reader := bufio.NewReader(stdin)
  350. // Start worker pool for tool calls
  351. for i := 0; i < s.workerPoolSize; i++ {
  352. s.workerWg.Add(1)
  353. go s.toolCallWorker(ctx)
  354. }
  355. // Start notification handler
  356. go s.handleNotifications(ctx, stdout)
  357. // Process input stream
  358. err := s.processInputStream(ctx, reader, stdout)
  359. // Shutdown workers gracefully
  360. close(s.toolCallQueue)
  361. s.workerWg.Wait()
  362. return err
  363. }
  364. // processMessage handles a single JSON-RPC message and writes the response.
  365. // It parses the message, processes it through the wrapped MCPServer, and writes any response.
  366. // Returns an error if there are issues with message processing or response writing.
  367. func (s *StdioServer) processMessage(
  368. ctx context.Context,
  369. line string,
  370. writer io.Writer,
  371. ) error {
  372. // If line is empty, likely due to ctx cancellation
  373. if len(line) == 0 {
  374. return nil
  375. }
  376. // Parse the message as raw JSON
  377. var rawMessage json.RawMessage
  378. if err := json.Unmarshal([]byte(line), &rawMessage); err != nil {
  379. response := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error")
  380. return s.writeResponse(response, writer)
  381. }
  382. // Check if this is a response to a sampling request
  383. if s.handleSamplingResponse(rawMessage) {
  384. return nil
  385. }
  386. // Check if this is a tool call that might need sampling (and thus should be processed concurrently)
  387. var baseMessage struct {
  388. Method string `json:"method"`
  389. }
  390. if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" {
  391. // Queue tool calls for processing by workers
  392. select {
  393. case s.toolCallQueue <- &toolCallWork{
  394. ctx: ctx,
  395. message: rawMessage,
  396. writer: writer,
  397. }:
  398. return nil
  399. case <-ctx.Done():
  400. return ctx.Err()
  401. default:
  402. // Queue is full, process synchronously as fallback
  403. s.errLogger.Printf("Tool call queue full, processing synchronously")
  404. response := s.server.HandleMessage(ctx, rawMessage)
  405. if response != nil {
  406. return s.writeResponse(response, writer)
  407. }
  408. return nil
  409. }
  410. }
  411. // Handle other messages synchronously
  412. response := s.server.HandleMessage(ctx, rawMessage)
  413. // Only write response if there is one (not for notifications)
  414. if response != nil {
  415. if err := s.writeResponse(response, writer); err != nil {
  416. return fmt.Errorf("failed to write response: %w", err)
  417. }
  418. }
  419. return nil
  420. }
  421. // handleSamplingResponse checks if the message is a response to a sampling request
  422. // and routes it to the appropriate pending request channel.
  423. func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool {
  424. return stdioSessionInstance.handleSamplingResponse(rawMessage)
  425. }
  426. // handleSamplingResponse handles incoming sampling responses for this session
  427. func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool {
  428. // Try to parse as a JSON-RPC response
  429. var response struct {
  430. JSONRPC string `json:"jsonrpc"`
  431. ID json.Number `json:"id"`
  432. Result json.RawMessage `json:"result,omitempty"`
  433. Error *struct {
  434. Code int `json:"code"`
  435. Message string `json:"message"`
  436. } `json:"error,omitempty"`
  437. }
  438. if err := json.Unmarshal(rawMessage, &response); err != nil {
  439. return false
  440. }
  441. // Parse the ID as int64
  442. idInt64, err := response.ID.Int64()
  443. if err != nil || (response.Result == nil && response.Error == nil) {
  444. return false
  445. }
  446. // Look for a pending request with this ID
  447. s.pendingMu.RLock()
  448. responseChan, exists := s.pendingRequests[idInt64]
  449. s.pendingMu.RUnlock()
  450. if !exists {
  451. return false
  452. } // Parse and send the response
  453. samplingResp := &samplingResponse{}
  454. if response.Error != nil {
  455. samplingResp.err = fmt.Errorf("sampling request failed: %s", response.Error.Message)
  456. } else {
  457. var result mcp.CreateMessageResult
  458. if err := json.Unmarshal(response.Result, &result); err != nil {
  459. samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err)
  460. } else {
  461. samplingResp.result = &result
  462. }
  463. }
  464. // Send the response (non-blocking)
  465. select {
  466. case responseChan <- samplingResp:
  467. default:
  468. // Channel is full or closed, ignore
  469. }
  470. return true
  471. }
  472. // writeResponse marshals and writes a JSON-RPC response message followed by a newline.
  473. // Returns an error if marshaling or writing fails.
  474. func (s *StdioServer) writeResponse(
  475. response mcp.JSONRPCMessage,
  476. writer io.Writer,
  477. ) error {
  478. responseBytes, err := json.Marshal(response)
  479. if err != nil {
  480. return err
  481. }
  482. // Protect concurrent writes
  483. s.writeMu.Lock()
  484. defer s.writeMu.Unlock()
  485. // Write response followed by newline
  486. if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil {
  487. return err
  488. }
  489. return nil
  490. }
  491. // ServeStdio is a convenience function that creates and starts a StdioServer with os.Stdin and os.Stdout.
  492. // It sets up signal handling for graceful shutdown on SIGTERM and SIGINT.
  493. // Returns an error if the server encounters any issues during operation.
  494. func ServeStdio(server *MCPServer, opts ...StdioOption) error {
  495. s := NewStdioServer(server)
  496. for _, opt := range opts {
  497. opt(s)
  498. }
  499. ctx, cancel := context.WithCancel(context.Background())
  500. defer cancel()
  501. // Set up signal handling
  502. sigChan := make(chan os.Signal, 1)
  503. signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT)
  504. go func() {
  505. <-sigChan
  506. cancel()
  507. }()
  508. return s.Listen(ctx, os.Stdin, os.Stdout)
  509. }