server.go 33 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201
  1. // Package server provides MCP (Model Context Protocol) server implementations.
  2. package server
  3. import (
  4. "context"
  5. "encoding/base64"
  6. "encoding/json"
  7. "fmt"
  8. "slices"
  9. "sort"
  10. "sync"
  11. "github.com/mark3labs/mcp-go/mcp"
  12. )
  13. // resourceEntry holds both a resource and its handler
  14. type resourceEntry struct {
  15. resource mcp.Resource
  16. handler ResourceHandlerFunc
  17. }
  18. // resourceTemplateEntry holds both a template and its handler
  19. type resourceTemplateEntry struct {
  20. template mcp.ResourceTemplate
  21. handler ResourceTemplateHandlerFunc
  22. }
  23. // ServerOption is a function that configures an MCPServer.
  24. type ServerOption func(*MCPServer)
  25. // ResourceHandlerFunc is a function that returns resource contents.
  26. type ResourceHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error)
  27. // ResourceTemplateHandlerFunc is a function that returns a resource template.
  28. type ResourceTemplateHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error)
  29. // PromptHandlerFunc handles prompt requests with given arguments.
  30. type PromptHandlerFunc func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error)
  31. // ToolHandlerFunc handles tool calls with given arguments.
  32. type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error)
  33. // ToolHandlerMiddleware is a middleware function that wraps a ToolHandlerFunc.
  34. type ToolHandlerMiddleware func(ToolHandlerFunc) ToolHandlerFunc
  35. // ResourceHandlerMiddleware is a middleware function that wraps a ResourceHandlerFunc.
  36. type ResourceHandlerMiddleware func(ResourceHandlerFunc) ResourceHandlerFunc
  37. // ToolFilterFunc is a function that filters tools based on context, typically using session information.
  38. type ToolFilterFunc func(ctx context.Context, tools []mcp.Tool) []mcp.Tool
  39. // ServerTool combines a Tool with its ToolHandlerFunc.
  40. type ServerTool struct {
  41. Tool mcp.Tool
  42. Handler ToolHandlerFunc
  43. }
  44. // ServerPrompt combines a Prompt with its handler function.
  45. type ServerPrompt struct {
  46. Prompt mcp.Prompt
  47. Handler PromptHandlerFunc
  48. }
  49. // ServerResource combines a Resource with its handler function.
  50. type ServerResource struct {
  51. Resource mcp.Resource
  52. Handler ResourceHandlerFunc
  53. }
  54. // ServerResourceTemplate combines a ResourceTemplate with its handler function.
  55. type ServerResourceTemplate struct {
  56. Template mcp.ResourceTemplate
  57. Handler ResourceTemplateHandlerFunc
  58. }
  59. // serverKey is the context key for storing the server instance
  60. type serverKey struct{}
  61. // ServerFromContext retrieves the MCPServer instance from a context
  62. func ServerFromContext(ctx context.Context) *MCPServer {
  63. if srv, ok := ctx.Value(serverKey{}).(*MCPServer); ok {
  64. return srv
  65. }
  66. return nil
  67. }
  68. // UnparsableMessageError is attached to the RequestError when json.Unmarshal
  69. // fails on the request.
  70. type UnparsableMessageError struct {
  71. message json.RawMessage
  72. method mcp.MCPMethod
  73. err error
  74. }
  75. func (e *UnparsableMessageError) Error() string {
  76. return fmt.Sprintf("unparsable %s request: %s", e.method, e.err)
  77. }
  78. func (e *UnparsableMessageError) Unwrap() error {
  79. return e.err
  80. }
  81. func (e *UnparsableMessageError) GetMessage() json.RawMessage {
  82. return e.message
  83. }
  84. func (e *UnparsableMessageError) GetMethod() mcp.MCPMethod {
  85. return e.method
  86. }
  87. // RequestError is an error that can be converted to a JSON-RPC error.
  88. // Implements Unwrap() to allow inspecting the error chain.
  89. type requestError struct {
  90. id any
  91. code int
  92. err error
  93. }
  94. func (e *requestError) Error() string {
  95. return fmt.Sprintf("request error: %s", e.err)
  96. }
  97. func (e *requestError) ToJSONRPCError() mcp.JSONRPCError {
  98. return mcp.JSONRPCError{
  99. JSONRPC: mcp.JSONRPC_VERSION,
  100. ID: mcp.NewRequestId(e.id),
  101. Error: struct {
  102. Code int `json:"code"`
  103. Message string `json:"message"`
  104. Data any `json:"data,omitempty"`
  105. }{
  106. Code: e.code,
  107. Message: e.err.Error(),
  108. },
  109. }
  110. }
  111. func (e *requestError) Unwrap() error {
  112. return e.err
  113. }
  114. // NotificationHandlerFunc handles incoming notifications.
  115. type NotificationHandlerFunc func(ctx context.Context, notification mcp.JSONRPCNotification)
  116. // MCPServer implements a Model Context Protocol server that can handle various types of requests
  117. // including resources, prompts, and tools.
  118. type MCPServer struct {
  119. // Separate mutexes for different resource types
  120. resourcesMu sync.RWMutex
  121. promptsMu sync.RWMutex
  122. toolsMu sync.RWMutex
  123. middlewareMu sync.RWMutex
  124. notificationHandlersMu sync.RWMutex
  125. capabilitiesMu sync.RWMutex
  126. toolFiltersMu sync.RWMutex
  127. name string
  128. version string
  129. instructions string
  130. resources map[string]resourceEntry
  131. resourceTemplates map[string]resourceTemplateEntry
  132. prompts map[string]mcp.Prompt
  133. promptHandlers map[string]PromptHandlerFunc
  134. tools map[string]ServerTool
  135. toolHandlerMiddlewares []ToolHandlerMiddleware
  136. resourceHandlerMiddlewares []ResourceHandlerMiddleware
  137. toolFilters []ToolFilterFunc
  138. notificationHandlers map[string]NotificationHandlerFunc
  139. capabilities serverCapabilities
  140. paginationLimit *int
  141. sessions sync.Map
  142. hooks *Hooks
  143. }
  144. // WithPaginationLimit sets the pagination limit for the server.
  145. func WithPaginationLimit(limit int) ServerOption {
  146. return func(s *MCPServer) {
  147. s.paginationLimit = &limit
  148. }
  149. }
  150. // serverCapabilities defines the supported features of the MCP server
  151. type serverCapabilities struct {
  152. tools *toolCapabilities
  153. resources *resourceCapabilities
  154. prompts *promptCapabilities
  155. logging *bool
  156. sampling *bool
  157. }
  158. // resourceCapabilities defines the supported resource-related features
  159. type resourceCapabilities struct {
  160. subscribe bool
  161. listChanged bool
  162. }
  163. // promptCapabilities defines the supported prompt-related features
  164. type promptCapabilities struct {
  165. listChanged bool
  166. }
  167. // toolCapabilities defines the supported tool-related features
  168. type toolCapabilities struct {
  169. listChanged bool
  170. }
  171. // WithResourceCapabilities configures resource-related server capabilities
  172. func WithResourceCapabilities(subscribe, listChanged bool) ServerOption {
  173. return func(s *MCPServer) {
  174. // Always create a non-nil capability object
  175. s.capabilities.resources = &resourceCapabilities{
  176. subscribe: subscribe,
  177. listChanged: listChanged,
  178. }
  179. }
  180. }
  181. // WithToolHandlerMiddleware allows adding a middleware for the
  182. // tool handler call chain.
  183. func WithToolHandlerMiddleware(
  184. toolHandlerMiddleware ToolHandlerMiddleware,
  185. ) ServerOption {
  186. return func(s *MCPServer) {
  187. s.middlewareMu.Lock()
  188. s.toolHandlerMiddlewares = append(s.toolHandlerMiddlewares, toolHandlerMiddleware)
  189. s.middlewareMu.Unlock()
  190. }
  191. }
  192. // WithResourceHandlerMiddleware allows adding a middleware for the
  193. // resource handler call chain.
  194. func WithResourceHandlerMiddleware(
  195. resourceHandlerMiddleware ResourceHandlerMiddleware,
  196. ) ServerOption {
  197. return func(s *MCPServer) {
  198. s.middlewareMu.Lock()
  199. s.resourceHandlerMiddlewares = append(s.resourceHandlerMiddlewares, resourceHandlerMiddleware)
  200. s.middlewareMu.Unlock()
  201. }
  202. }
  203. // WithResourceRecovery adds a middleware that recovers from panics in resource handlers.
  204. func WithResourceRecovery() ServerOption {
  205. return WithResourceHandlerMiddleware(func(next ResourceHandlerFunc) ResourceHandlerFunc {
  206. return func(ctx context.Context, request mcp.ReadResourceRequest) (result []mcp.ResourceContents, err error) {
  207. defer func() {
  208. if r := recover(); r != nil {
  209. err = fmt.Errorf(
  210. "panic recovered in %s resource handler: %v",
  211. request.Params.URI,
  212. r,
  213. )
  214. }
  215. }()
  216. return next(ctx, request)
  217. }
  218. })
  219. }
  220. // WithToolFilter adds a filter function that will be applied to tools before they are returned in list_tools
  221. func WithToolFilter(
  222. toolFilter ToolFilterFunc,
  223. ) ServerOption {
  224. return func(s *MCPServer) {
  225. s.toolFiltersMu.Lock()
  226. s.toolFilters = append(s.toolFilters, toolFilter)
  227. s.toolFiltersMu.Unlock()
  228. }
  229. }
  230. // WithRecovery adds a middleware that recovers from panics in tool handlers.
  231. func WithRecovery() ServerOption {
  232. return WithToolHandlerMiddleware(func(next ToolHandlerFunc) ToolHandlerFunc {
  233. return func(ctx context.Context, request mcp.CallToolRequest) (result *mcp.CallToolResult, err error) {
  234. defer func() {
  235. if r := recover(); r != nil {
  236. err = fmt.Errorf(
  237. "panic recovered in %s tool handler: %v",
  238. request.Params.Name,
  239. r,
  240. )
  241. }
  242. }()
  243. return next(ctx, request)
  244. }
  245. })
  246. }
  247. // WithHooks allows adding hooks that will be called before or after
  248. // either [all] requests or before / after specific request methods, or else
  249. // prior to returning an error to the client.
  250. func WithHooks(hooks *Hooks) ServerOption {
  251. return func(s *MCPServer) {
  252. s.hooks = hooks
  253. }
  254. }
  255. // WithPromptCapabilities configures prompt-related server capabilities
  256. func WithPromptCapabilities(listChanged bool) ServerOption {
  257. return func(s *MCPServer) {
  258. // Always create a non-nil capability object
  259. s.capabilities.prompts = &promptCapabilities{
  260. listChanged: listChanged,
  261. }
  262. }
  263. }
  264. // WithToolCapabilities configures tool-related server capabilities
  265. func WithToolCapabilities(listChanged bool) ServerOption {
  266. return func(s *MCPServer) {
  267. // Always create a non-nil capability object
  268. s.capabilities.tools = &toolCapabilities{
  269. listChanged: listChanged,
  270. }
  271. }
  272. }
  273. // WithLogging enables logging capabilities for the server
  274. func WithLogging() ServerOption {
  275. return func(s *MCPServer) {
  276. s.capabilities.logging = mcp.ToBoolPtr(true)
  277. }
  278. }
  279. // WithInstructions sets the server instructions for the client returned in the initialize response
  280. func WithInstructions(instructions string) ServerOption {
  281. return func(s *MCPServer) {
  282. s.instructions = instructions
  283. }
  284. }
  285. // NewMCPServer creates a new MCP server instance with the given name, version and options
  286. func NewMCPServer(
  287. name, version string,
  288. opts ...ServerOption,
  289. ) *MCPServer {
  290. s := &MCPServer{
  291. resources: make(map[string]resourceEntry),
  292. resourceTemplates: make(map[string]resourceTemplateEntry),
  293. prompts: make(map[string]mcp.Prompt),
  294. promptHandlers: make(map[string]PromptHandlerFunc),
  295. tools: make(map[string]ServerTool),
  296. toolHandlerMiddlewares: make([]ToolHandlerMiddleware, 0),
  297. resourceHandlerMiddlewares: make([]ResourceHandlerMiddleware, 0),
  298. name: name,
  299. version: version,
  300. notificationHandlers: make(map[string]NotificationHandlerFunc),
  301. capabilities: serverCapabilities{
  302. tools: nil,
  303. resources: nil,
  304. prompts: nil,
  305. logging: nil,
  306. },
  307. }
  308. for _, opt := range opts {
  309. opt(s)
  310. }
  311. return s
  312. }
  313. // GenerateInProcessSessionID generates a unique session ID for inprocess clients
  314. func (s *MCPServer) GenerateInProcessSessionID() string {
  315. return GenerateInProcessSessionID()
  316. }
  317. // AddResources registers multiple resources at once
  318. func (s *MCPServer) AddResources(resources ...ServerResource) {
  319. s.implicitlyRegisterResourceCapabilities()
  320. s.resourcesMu.Lock()
  321. for _, entry := range resources {
  322. s.resources[entry.Resource.URI] = resourceEntry{
  323. resource: entry.Resource,
  324. handler: entry.Handler,
  325. }
  326. }
  327. s.resourcesMu.Unlock()
  328. // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification
  329. if s.capabilities.resources.listChanged {
  330. // Send notification to all initialized sessions
  331. s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil)
  332. }
  333. }
  334. // SetResources replaces all existing resources with the provided list
  335. func (s *MCPServer) SetResources(resources ...ServerResource) {
  336. s.resourcesMu.Lock()
  337. s.resources = make(map[string]resourceEntry, len(resources))
  338. s.resourcesMu.Unlock()
  339. s.AddResources(resources...)
  340. }
  341. // AddResource registers a new resource and its handler
  342. func (s *MCPServer) AddResource(
  343. resource mcp.Resource,
  344. handler ResourceHandlerFunc,
  345. ) {
  346. s.AddResources(ServerResource{Resource: resource, Handler: handler})
  347. }
  348. // DeleteResources removes resources from the server
  349. func (s *MCPServer) DeleteResources(uris ...string) {
  350. s.resourcesMu.Lock()
  351. var exists bool
  352. for _, uri := range uris {
  353. if _, ok := s.resources[uri]; ok {
  354. delete(s.resources, uri)
  355. exists = true
  356. }
  357. }
  358. s.resourcesMu.Unlock()
  359. // Send notification to all initialized sessions if listChanged capability is enabled and we actually remove a resource
  360. if exists && s.capabilities.resources != nil && s.capabilities.resources.listChanged {
  361. s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil)
  362. }
  363. }
  364. // RemoveResource removes a resource from the server
  365. func (s *MCPServer) RemoveResource(uri string) {
  366. s.resourcesMu.Lock()
  367. _, exists := s.resources[uri]
  368. if exists {
  369. delete(s.resources, uri)
  370. }
  371. s.resourcesMu.Unlock()
  372. // Send notification to all initialized sessions if listChanged capability is enabled and we actually remove a resource
  373. if exists && s.capabilities.resources != nil && s.capabilities.resources.listChanged {
  374. s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil)
  375. }
  376. }
  377. // AddResourceTemplates registers multiple resource templates at once
  378. func (s *MCPServer) AddResourceTemplates(resourceTemplates ...ServerResourceTemplate) {
  379. s.implicitlyRegisterResourceCapabilities()
  380. s.resourcesMu.Lock()
  381. for _, entry := range resourceTemplates {
  382. s.resourceTemplates[entry.Template.URITemplate.Raw()] = resourceTemplateEntry{
  383. template: entry.Template,
  384. handler: entry.Handler,
  385. }
  386. }
  387. s.resourcesMu.Unlock()
  388. // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification
  389. if s.capabilities.resources.listChanged {
  390. // Send notification to all initialized sessions
  391. s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil)
  392. }
  393. }
  394. // SetResourceTemplates replaces all existing resource templates with the provided list
  395. func (s *MCPServer) SetResourceTemplates(templates ...ServerResourceTemplate) {
  396. s.resourcesMu.Lock()
  397. s.resourceTemplates = make(map[string]resourceTemplateEntry, len(templates))
  398. s.resourcesMu.Unlock()
  399. s.AddResourceTemplates(templates...)
  400. }
  401. // AddResourceTemplate registers a new resource template and its handler
  402. func (s *MCPServer) AddResourceTemplate(
  403. template mcp.ResourceTemplate,
  404. handler ResourceTemplateHandlerFunc,
  405. ) {
  406. s.AddResourceTemplates(ServerResourceTemplate{Template: template, Handler: handler})
  407. }
  408. // AddPrompts registers multiple prompts at once
  409. func (s *MCPServer) AddPrompts(prompts ...ServerPrompt) {
  410. s.implicitlyRegisterPromptCapabilities()
  411. s.promptsMu.Lock()
  412. for _, entry := range prompts {
  413. s.prompts[entry.Prompt.Name] = entry.Prompt
  414. s.promptHandlers[entry.Prompt.Name] = entry.Handler
  415. }
  416. s.promptsMu.Unlock()
  417. // When the list of available prompts changes, servers that declared the listChanged capability SHOULD send a notification.
  418. if s.capabilities.prompts.listChanged {
  419. // Send notification to all initialized sessions
  420. s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil)
  421. }
  422. }
  423. // AddPrompt registers a new prompt handler with the given name
  424. func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) {
  425. s.AddPrompts(ServerPrompt{Prompt: prompt, Handler: handler})
  426. }
  427. // SetPrompts replaces all existing prompts with the provided list
  428. func (s *MCPServer) SetPrompts(prompts ...ServerPrompt) {
  429. s.promptsMu.Lock()
  430. s.prompts = make(map[string]mcp.Prompt, len(prompts))
  431. s.promptHandlers = make(map[string]PromptHandlerFunc, len(prompts))
  432. s.promptsMu.Unlock()
  433. s.AddPrompts(prompts...)
  434. }
  435. // DeletePrompts removes prompts from the server
  436. func (s *MCPServer) DeletePrompts(names ...string) {
  437. s.promptsMu.Lock()
  438. var exists bool
  439. for _, name := range names {
  440. if _, ok := s.prompts[name]; ok {
  441. delete(s.prompts, name)
  442. delete(s.promptHandlers, name)
  443. exists = true
  444. }
  445. }
  446. s.promptsMu.Unlock()
  447. // Send notification to all initialized sessions if listChanged capability is enabled, and we actually remove a prompt
  448. if exists && s.capabilities.prompts != nil && s.capabilities.prompts.listChanged {
  449. // Send notification to all initialized sessions
  450. s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil)
  451. }
  452. }
  453. // AddTool registers a new tool and its handler
  454. func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) {
  455. s.AddTools(ServerTool{Tool: tool, Handler: handler})
  456. }
  457. // Register tool capabilities due to a tool being added. Default to
  458. // listChanged: true, but don't change the value if we've already explicitly
  459. // registered tools.listChanged false.
  460. func (s *MCPServer) implicitlyRegisterToolCapabilities() {
  461. s.implicitlyRegisterCapabilities(
  462. func() bool { return s.capabilities.tools != nil },
  463. func() { s.capabilities.tools = &toolCapabilities{listChanged: true} },
  464. )
  465. }
  466. func (s *MCPServer) implicitlyRegisterResourceCapabilities() {
  467. s.implicitlyRegisterCapabilities(
  468. func() bool { return s.capabilities.resources != nil },
  469. func() { s.capabilities.resources = &resourceCapabilities{} },
  470. )
  471. }
  472. func (s *MCPServer) implicitlyRegisterPromptCapabilities() {
  473. s.implicitlyRegisterCapabilities(
  474. func() bool { return s.capabilities.prompts != nil },
  475. func() { s.capabilities.prompts = &promptCapabilities{} },
  476. )
  477. }
  478. func (s *MCPServer) implicitlyRegisterCapabilities(check func() bool, register func()) {
  479. s.capabilitiesMu.RLock()
  480. if check() {
  481. s.capabilitiesMu.RUnlock()
  482. return
  483. }
  484. s.capabilitiesMu.RUnlock()
  485. s.capabilitiesMu.Lock()
  486. if !check() {
  487. register()
  488. }
  489. s.capabilitiesMu.Unlock()
  490. }
  491. // AddTools registers multiple tools at once
  492. func (s *MCPServer) AddTools(tools ...ServerTool) {
  493. s.implicitlyRegisterToolCapabilities()
  494. s.toolsMu.Lock()
  495. for _, entry := range tools {
  496. s.tools[entry.Tool.Name] = entry
  497. }
  498. s.toolsMu.Unlock()
  499. // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification.
  500. if s.capabilities.tools.listChanged {
  501. // Send notification to all initialized sessions
  502. s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil)
  503. }
  504. }
  505. // SetTools replaces all existing tools with the provided list
  506. func (s *MCPServer) SetTools(tools ...ServerTool) {
  507. s.toolsMu.Lock()
  508. s.tools = make(map[string]ServerTool, len(tools))
  509. s.toolsMu.Unlock()
  510. s.AddTools(tools...)
  511. }
  512. // DeleteTools removes tools from the server
  513. func (s *MCPServer) DeleteTools(names ...string) {
  514. s.toolsMu.Lock()
  515. var exists bool
  516. for _, name := range names {
  517. if _, ok := s.tools[name]; ok {
  518. delete(s.tools, name)
  519. exists = true
  520. }
  521. }
  522. s.toolsMu.Unlock()
  523. // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification.
  524. if exists && s.capabilities.tools != nil && s.capabilities.tools.listChanged {
  525. // Send notification to all initialized sessions
  526. s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil)
  527. }
  528. }
  529. // AddNotificationHandler registers a new handler for incoming notifications
  530. func (s *MCPServer) AddNotificationHandler(
  531. method string,
  532. handler NotificationHandlerFunc,
  533. ) {
  534. s.notificationHandlersMu.Lock()
  535. defer s.notificationHandlersMu.Unlock()
  536. s.notificationHandlers[method] = handler
  537. }
  538. func (s *MCPServer) handleInitialize(
  539. ctx context.Context,
  540. _ any,
  541. request mcp.InitializeRequest,
  542. ) (*mcp.InitializeResult, *requestError) {
  543. capabilities := mcp.ServerCapabilities{}
  544. // Only add resource capabilities if they're configured
  545. if s.capabilities.resources != nil {
  546. capabilities.Resources = &struct {
  547. Subscribe bool `json:"subscribe,omitempty"`
  548. ListChanged bool `json:"listChanged,omitempty"`
  549. }{
  550. Subscribe: s.capabilities.resources.subscribe,
  551. ListChanged: s.capabilities.resources.listChanged,
  552. }
  553. }
  554. // Only add prompt capabilities if they're configured
  555. if s.capabilities.prompts != nil {
  556. capabilities.Prompts = &struct {
  557. ListChanged bool `json:"listChanged,omitempty"`
  558. }{
  559. ListChanged: s.capabilities.prompts.listChanged,
  560. }
  561. }
  562. // Only add tool capabilities if they're configured
  563. if s.capabilities.tools != nil {
  564. capabilities.Tools = &struct {
  565. ListChanged bool `json:"listChanged,omitempty"`
  566. }{
  567. ListChanged: s.capabilities.tools.listChanged,
  568. }
  569. }
  570. if s.capabilities.logging != nil && *s.capabilities.logging {
  571. capabilities.Logging = &struct{}{}
  572. }
  573. if s.capabilities.sampling != nil && *s.capabilities.sampling {
  574. capabilities.Sampling = &struct{}{}
  575. }
  576. result := mcp.InitializeResult{
  577. ProtocolVersion: s.protocolVersion(request.Params.ProtocolVersion),
  578. ServerInfo: mcp.Implementation{
  579. Name: s.name,
  580. Version: s.version,
  581. },
  582. Capabilities: capabilities,
  583. Instructions: s.instructions,
  584. }
  585. if session := ClientSessionFromContext(ctx); session != nil {
  586. session.Initialize()
  587. // Store client info if the session supports it
  588. if sessionWithClientInfo, ok := session.(SessionWithClientInfo); ok {
  589. sessionWithClientInfo.SetClientInfo(request.Params.ClientInfo)
  590. sessionWithClientInfo.SetClientCapabilities(request.Params.Capabilities)
  591. }
  592. }
  593. return &result, nil
  594. }
  595. func (s *MCPServer) protocolVersion(clientVersion string) string {
  596. // For backwards compatibility, if the server does not receive an MCP-Protocol-Version header,
  597. // and has no other way to identify the version - for example, by relying on the protocol version negotiated
  598. // during initialization - the server SHOULD assume protocol version 2025-03-26
  599. // https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header
  600. if len(clientVersion) == 0 {
  601. clientVersion = "2025-03-26"
  602. }
  603. if slices.Contains(mcp.ValidProtocolVersions, clientVersion) {
  604. return clientVersion
  605. }
  606. return mcp.LATEST_PROTOCOL_VERSION
  607. }
  608. func (s *MCPServer) handlePing(
  609. _ context.Context,
  610. _ any,
  611. _ mcp.PingRequest,
  612. ) (*mcp.EmptyResult, *requestError) {
  613. return &mcp.EmptyResult{}, nil
  614. }
  615. func (s *MCPServer) handleSetLevel(
  616. ctx context.Context,
  617. id any,
  618. request mcp.SetLevelRequest,
  619. ) (*mcp.EmptyResult, *requestError) {
  620. clientSession := ClientSessionFromContext(ctx)
  621. if clientSession == nil || !clientSession.Initialized() {
  622. return nil, &requestError{
  623. id: id,
  624. code: mcp.INTERNAL_ERROR,
  625. err: ErrSessionNotInitialized,
  626. }
  627. }
  628. sessionLogging, ok := clientSession.(SessionWithLogging)
  629. if !ok {
  630. return nil, &requestError{
  631. id: id,
  632. code: mcp.INTERNAL_ERROR,
  633. err: ErrSessionDoesNotSupportLogging,
  634. }
  635. }
  636. level := request.Params.Level
  637. // Validate logging level
  638. switch level {
  639. case mcp.LoggingLevelDebug, mcp.LoggingLevelInfo, mcp.LoggingLevelNotice,
  640. mcp.LoggingLevelWarning, mcp.LoggingLevelError, mcp.LoggingLevelCritical,
  641. mcp.LoggingLevelAlert, mcp.LoggingLevelEmergency:
  642. // Valid level
  643. default:
  644. return nil, &requestError{
  645. id: id,
  646. code: mcp.INVALID_PARAMS,
  647. err: fmt.Errorf("invalid logging level '%s'", level),
  648. }
  649. }
  650. sessionLogging.SetLogLevel(level)
  651. return &mcp.EmptyResult{}, nil
  652. }
  653. func listByPagination[T mcp.Named](
  654. _ context.Context,
  655. s *MCPServer,
  656. cursor mcp.Cursor,
  657. allElements []T,
  658. ) ([]T, mcp.Cursor, error) {
  659. startPos := 0
  660. if cursor != "" {
  661. c, err := base64.StdEncoding.DecodeString(string(cursor))
  662. if err != nil {
  663. return nil, "", err
  664. }
  665. cString := string(c)
  666. startPos = sort.Search(len(allElements), func(i int) bool {
  667. return allElements[i].GetName() > cString
  668. })
  669. }
  670. endPos := len(allElements)
  671. if s.paginationLimit != nil {
  672. if len(allElements) > startPos+*s.paginationLimit {
  673. endPos = startPos + *s.paginationLimit
  674. }
  675. }
  676. elementsToReturn := allElements[startPos:endPos]
  677. // set the next cursor
  678. nextCursor := func() mcp.Cursor {
  679. if s.paginationLimit != nil && len(elementsToReturn) >= *s.paginationLimit {
  680. nc := elementsToReturn[len(elementsToReturn)-1].GetName()
  681. toString := base64.StdEncoding.EncodeToString([]byte(nc))
  682. return mcp.Cursor(toString)
  683. }
  684. return ""
  685. }()
  686. return elementsToReturn, nextCursor, nil
  687. }
  688. func (s *MCPServer) handleListResources(
  689. ctx context.Context,
  690. id any,
  691. request mcp.ListResourcesRequest,
  692. ) (*mcp.ListResourcesResult, *requestError) {
  693. s.resourcesMu.RLock()
  694. resources := make([]mcp.Resource, 0, len(s.resources))
  695. for _, entry := range s.resources {
  696. resources = append(resources, entry.resource)
  697. }
  698. s.resourcesMu.RUnlock()
  699. // Sort the resources by name
  700. sort.Slice(resources, func(i, j int) bool {
  701. return resources[i].Name < resources[j].Name
  702. })
  703. resourcesToReturn, nextCursor, err := listByPagination(
  704. ctx,
  705. s,
  706. request.Params.Cursor,
  707. resources,
  708. )
  709. if err != nil {
  710. return nil, &requestError{
  711. id: id,
  712. code: mcp.INVALID_PARAMS,
  713. err: err,
  714. }
  715. }
  716. result := mcp.ListResourcesResult{
  717. Resources: resourcesToReturn,
  718. PaginatedResult: mcp.PaginatedResult{
  719. NextCursor: nextCursor,
  720. },
  721. }
  722. return &result, nil
  723. }
  724. func (s *MCPServer) handleListResourceTemplates(
  725. ctx context.Context,
  726. id any,
  727. request mcp.ListResourceTemplatesRequest,
  728. ) (*mcp.ListResourceTemplatesResult, *requestError) {
  729. s.resourcesMu.RLock()
  730. templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates))
  731. for _, entry := range s.resourceTemplates {
  732. templates = append(templates, entry.template)
  733. }
  734. s.resourcesMu.RUnlock()
  735. sort.Slice(templates, func(i, j int) bool {
  736. return templates[i].Name < templates[j].Name
  737. })
  738. templatesToReturn, nextCursor, err := listByPagination(
  739. ctx,
  740. s,
  741. request.Params.Cursor,
  742. templates,
  743. )
  744. if err != nil {
  745. return nil, &requestError{
  746. id: id,
  747. code: mcp.INVALID_PARAMS,
  748. err: err,
  749. }
  750. }
  751. result := mcp.ListResourceTemplatesResult{
  752. ResourceTemplates: templatesToReturn,
  753. PaginatedResult: mcp.PaginatedResult{
  754. NextCursor: nextCursor,
  755. },
  756. }
  757. return &result, nil
  758. }
  759. func (s *MCPServer) handleReadResource(
  760. ctx context.Context,
  761. id any,
  762. request mcp.ReadResourceRequest,
  763. ) (*mcp.ReadResourceResult, *requestError) {
  764. s.resourcesMu.RLock()
  765. // First try direct resource handlers
  766. if entry, ok := s.resources[request.Params.URI]; ok {
  767. handler := entry.handler
  768. s.resourcesMu.RUnlock()
  769. finalHandler := handler
  770. s.middlewareMu.RLock()
  771. mw := s.resourceHandlerMiddlewares
  772. // Apply middlewares in reverse order
  773. for i := len(mw) - 1; i >= 0; i-- {
  774. finalHandler = mw[i](finalHandler)
  775. }
  776. s.middlewareMu.RUnlock()
  777. contents, err := finalHandler(ctx, request)
  778. if err != nil {
  779. return nil, &requestError{
  780. id: id,
  781. code: mcp.INTERNAL_ERROR,
  782. err: err,
  783. }
  784. }
  785. return &mcp.ReadResourceResult{Contents: contents}, nil
  786. }
  787. // If no direct handler found, try matching against templates
  788. var matchedHandler ResourceTemplateHandlerFunc
  789. var matched bool
  790. for _, entry := range s.resourceTemplates {
  791. template := entry.template
  792. if matchesTemplate(request.Params.URI, template.URITemplate) {
  793. matchedHandler = entry.handler
  794. matched = true
  795. matchedVars := template.URITemplate.Match(request.Params.URI)
  796. // Convert matched variables to a map
  797. request.Params.Arguments = make(map[string]any, len(matchedVars))
  798. for name, value := range matchedVars {
  799. request.Params.Arguments[name] = value.V
  800. }
  801. break
  802. }
  803. }
  804. s.resourcesMu.RUnlock()
  805. if matched {
  806. contents, err := matchedHandler(ctx, request)
  807. if err != nil {
  808. return nil, &requestError{
  809. id: id,
  810. code: mcp.INTERNAL_ERROR,
  811. err: err,
  812. }
  813. }
  814. return &mcp.ReadResourceResult{Contents: contents}, nil
  815. }
  816. return nil, &requestError{
  817. id: id,
  818. code: mcp.RESOURCE_NOT_FOUND,
  819. err: fmt.Errorf(
  820. "handler not found for resource URI '%s': %w",
  821. request.Params.URI,
  822. ErrResourceNotFound,
  823. ),
  824. }
  825. }
  826. // matchesTemplate checks if a URI matches a URI template pattern
  827. func matchesTemplate(uri string, template *mcp.URITemplate) bool {
  828. return template.Regexp().MatchString(uri)
  829. }
  830. func (s *MCPServer) handleListPrompts(
  831. ctx context.Context,
  832. id any,
  833. request mcp.ListPromptsRequest,
  834. ) (*mcp.ListPromptsResult, *requestError) {
  835. s.promptsMu.RLock()
  836. prompts := make([]mcp.Prompt, 0, len(s.prompts))
  837. for _, prompt := range s.prompts {
  838. prompts = append(prompts, prompt)
  839. }
  840. s.promptsMu.RUnlock()
  841. // sort prompts by name
  842. sort.Slice(prompts, func(i, j int) bool {
  843. return prompts[i].Name < prompts[j].Name
  844. })
  845. promptsToReturn, nextCursor, err := listByPagination(
  846. ctx,
  847. s,
  848. request.Params.Cursor,
  849. prompts,
  850. )
  851. if err != nil {
  852. return nil, &requestError{
  853. id: id,
  854. code: mcp.INVALID_PARAMS,
  855. err: err,
  856. }
  857. }
  858. result := mcp.ListPromptsResult{
  859. Prompts: promptsToReturn,
  860. PaginatedResult: mcp.PaginatedResult{
  861. NextCursor: nextCursor,
  862. },
  863. }
  864. return &result, nil
  865. }
  866. func (s *MCPServer) handleGetPrompt(
  867. ctx context.Context,
  868. id any,
  869. request mcp.GetPromptRequest,
  870. ) (*mcp.GetPromptResult, *requestError) {
  871. s.promptsMu.RLock()
  872. handler, ok := s.promptHandlers[request.Params.Name]
  873. s.promptsMu.RUnlock()
  874. if !ok {
  875. return nil, &requestError{
  876. id: id,
  877. code: mcp.INVALID_PARAMS,
  878. err: fmt.Errorf("prompt '%s' not found: %w", request.Params.Name, ErrPromptNotFound),
  879. }
  880. }
  881. result, err := handler(ctx, request)
  882. if err != nil {
  883. return nil, &requestError{
  884. id: id,
  885. code: mcp.INTERNAL_ERROR,
  886. err: err,
  887. }
  888. }
  889. return result, nil
  890. }
  891. func (s *MCPServer) handleListTools(
  892. ctx context.Context,
  893. id any,
  894. request mcp.ListToolsRequest,
  895. ) (*mcp.ListToolsResult, *requestError) {
  896. // Get the base tools from the server
  897. s.toolsMu.RLock()
  898. tools := make([]mcp.Tool, 0, len(s.tools))
  899. // Get all tool names for consistent ordering
  900. toolNames := make([]string, 0, len(s.tools))
  901. for name := range s.tools {
  902. toolNames = append(toolNames, name)
  903. }
  904. // Sort the tool names for consistent ordering
  905. sort.Strings(toolNames)
  906. // Add tools in sorted order
  907. for _, name := range toolNames {
  908. tools = append(tools, s.tools[name].Tool)
  909. }
  910. s.toolsMu.RUnlock()
  911. // Check if there are session-specific tools
  912. session := ClientSessionFromContext(ctx)
  913. if session != nil {
  914. if sessionWithTools, ok := session.(SessionWithTools); ok {
  915. if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil {
  916. // Override or add session-specific tools
  917. // We need to create a map first to merge the tools properly
  918. toolMap := make(map[string]mcp.Tool)
  919. // Add global tools first
  920. for _, tool := range tools {
  921. toolMap[tool.Name] = tool
  922. }
  923. // Then override with session-specific tools
  924. for name, serverTool := range sessionTools {
  925. toolMap[name] = serverTool.Tool
  926. }
  927. // Convert back to slice
  928. tools = make([]mcp.Tool, 0, len(toolMap))
  929. for _, tool := range toolMap {
  930. tools = append(tools, tool)
  931. }
  932. // Sort again to maintain consistent ordering
  933. sort.Slice(tools, func(i, j int) bool {
  934. return tools[i].Name < tools[j].Name
  935. })
  936. }
  937. }
  938. }
  939. // Apply tool filters if any are defined
  940. s.toolFiltersMu.RLock()
  941. if len(s.toolFilters) > 0 {
  942. for _, filter := range s.toolFilters {
  943. tools = filter(ctx, tools)
  944. }
  945. }
  946. s.toolFiltersMu.RUnlock()
  947. // Apply pagination
  948. toolsToReturn, nextCursor, err := listByPagination(
  949. ctx,
  950. s,
  951. request.Params.Cursor,
  952. tools,
  953. )
  954. if err != nil {
  955. return nil, &requestError{
  956. id: id,
  957. code: mcp.INVALID_PARAMS,
  958. err: err,
  959. }
  960. }
  961. result := mcp.ListToolsResult{
  962. Tools: toolsToReturn,
  963. PaginatedResult: mcp.PaginatedResult{
  964. NextCursor: nextCursor,
  965. },
  966. }
  967. return &result, nil
  968. }
  969. func (s *MCPServer) handleToolCall(
  970. ctx context.Context,
  971. id any,
  972. request mcp.CallToolRequest,
  973. ) (*mcp.CallToolResult, *requestError) {
  974. // First check session-specific tools
  975. var tool ServerTool
  976. var ok bool
  977. session := ClientSessionFromContext(ctx)
  978. if session != nil {
  979. if sessionWithTools, typeAssertOk := session.(SessionWithTools); typeAssertOk {
  980. if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil {
  981. var sessionOk bool
  982. tool, sessionOk = sessionTools[request.Params.Name]
  983. if sessionOk {
  984. ok = true
  985. }
  986. }
  987. }
  988. }
  989. // If not found in session tools, check global tools
  990. if !ok {
  991. s.toolsMu.RLock()
  992. tool, ok = s.tools[request.Params.Name]
  993. s.toolsMu.RUnlock()
  994. }
  995. if !ok {
  996. return nil, &requestError{
  997. id: id,
  998. code: mcp.INVALID_PARAMS,
  999. err: fmt.Errorf("tool '%s' not found: %w", request.Params.Name, ErrToolNotFound),
  1000. }
  1001. }
  1002. finalHandler := tool.Handler
  1003. s.middlewareMu.RLock()
  1004. mw := s.toolHandlerMiddlewares
  1005. // Apply middlewares in reverse order
  1006. for i := len(mw) - 1; i >= 0; i-- {
  1007. finalHandler = mw[i](finalHandler)
  1008. }
  1009. s.middlewareMu.RUnlock()
  1010. result, err := finalHandler(ctx, request)
  1011. if err != nil {
  1012. return nil, &requestError{
  1013. id: id,
  1014. code: mcp.INTERNAL_ERROR,
  1015. err: err,
  1016. }
  1017. }
  1018. return result, nil
  1019. }
  1020. func (s *MCPServer) handleNotification(
  1021. ctx context.Context,
  1022. notification mcp.JSONRPCNotification,
  1023. ) mcp.JSONRPCMessage {
  1024. s.notificationHandlersMu.RLock()
  1025. handler, ok := s.notificationHandlers[notification.Method]
  1026. s.notificationHandlersMu.RUnlock()
  1027. if ok {
  1028. handler(ctx, notification)
  1029. }
  1030. return nil
  1031. }
  1032. func createResponse(id any, result any) mcp.JSONRPCMessage {
  1033. return mcp.JSONRPCResponse{
  1034. JSONRPC: mcp.JSONRPC_VERSION,
  1035. ID: mcp.NewRequestId(id),
  1036. Result: result,
  1037. }
  1038. }
  1039. func createErrorResponse(
  1040. id any,
  1041. code int,
  1042. message string,
  1043. ) mcp.JSONRPCMessage {
  1044. return mcp.JSONRPCError{
  1045. JSONRPC: mcp.JSONRPC_VERSION,
  1046. ID: mcp.NewRequestId(id),
  1047. Error: struct {
  1048. Code int `json:"code"`
  1049. Message string `json:"message"`
  1050. Data any `json:"data,omitempty"`
  1051. }{
  1052. Code: code,
  1053. Message: message,
  1054. },
  1055. }
  1056. }