request_handler.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. // Code generated by `go generate`. DO NOT EDIT.
  2. // source: server/internal/gen/request_handler.go.tmpl
  3. package server
  4. import (
  5. "context"
  6. "encoding/json"
  7. "fmt"
  8. "net/http"
  9. "github.com/mark3labs/mcp-go/mcp"
  10. )
  11. // HandleMessage processes an incoming JSON-RPC message and returns an appropriate response
  12. func (s *MCPServer) HandleMessage(
  13. ctx context.Context,
  14. message json.RawMessage,
  15. ) mcp.JSONRPCMessage {
  16. // Add server to context
  17. ctx = context.WithValue(ctx, serverKey{}, s)
  18. var err *requestError
  19. var baseMessage struct {
  20. JSONRPC string `json:"jsonrpc"`
  21. Method mcp.MCPMethod `json:"method"`
  22. ID any `json:"id,omitempty"`
  23. Result any `json:"result,omitempty"`
  24. }
  25. if err := json.Unmarshal(message, &baseMessage); err != nil {
  26. return createErrorResponse(
  27. nil,
  28. mcp.PARSE_ERROR,
  29. "Failed to parse message",
  30. )
  31. }
  32. // Check for valid JSONRPC version
  33. if baseMessage.JSONRPC != mcp.JSONRPC_VERSION {
  34. return createErrorResponse(
  35. baseMessage.ID,
  36. mcp.INVALID_REQUEST,
  37. "Invalid JSON-RPC version",
  38. )
  39. }
  40. if baseMessage.ID == nil {
  41. var notification mcp.JSONRPCNotification
  42. if err := json.Unmarshal(message, &notification); err != nil {
  43. return createErrorResponse(
  44. nil,
  45. mcp.PARSE_ERROR,
  46. "Failed to parse notification",
  47. )
  48. }
  49. s.handleNotification(ctx, notification)
  50. return nil // Return nil for notifications
  51. }
  52. if baseMessage.Result != nil {
  53. // this is a response to a request sent by the server (e.g. from a ping
  54. // sent due to WithKeepAlive option)
  55. return nil
  56. }
  57. handleErr := s.hooks.onRequestInitialization(ctx, baseMessage.ID, message)
  58. if handleErr != nil {
  59. return createErrorResponse(
  60. baseMessage.ID,
  61. mcp.INVALID_REQUEST,
  62. handleErr.Error(),
  63. )
  64. }
  65. // Get request header from ctx
  66. h := ctx.Value(requestHeader)
  67. headers, ok := h.(http.Header)
  68. if headers == nil || !ok {
  69. headers = make(http.Header)
  70. }
  71. switch baseMessage.Method {
  72. case mcp.MethodInitialize:
  73. var request mcp.InitializeRequest
  74. var result *mcp.InitializeResult
  75. if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
  76. err = &requestError{
  77. id: baseMessage.ID,
  78. code: mcp.INVALID_REQUEST,
  79. err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
  80. }
  81. } else {
  82. request.Header = headers
  83. s.hooks.beforeInitialize(ctx, baseMessage.ID, &request)
  84. result, err = s.handleInitialize(ctx, baseMessage.ID, request)
  85. }
  86. if err != nil {
  87. s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
  88. return err.ToJSONRPCError()
  89. }
  90. s.hooks.afterInitialize(ctx, baseMessage.ID, &request, result)
  91. return createResponse(baseMessage.ID, *result)
  92. case mcp.MethodPing:
  93. var request mcp.PingRequest
  94. var result *mcp.EmptyResult
  95. if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
  96. err = &requestError{
  97. id: baseMessage.ID,
  98. code: mcp.INVALID_REQUEST,
  99. err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
  100. }
  101. } else {
  102. request.Header = headers
  103. s.hooks.beforePing(ctx, baseMessage.ID, &request)
  104. result, err = s.handlePing(ctx, baseMessage.ID, request)
  105. }
  106. if err != nil {
  107. s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
  108. return err.ToJSONRPCError()
  109. }
  110. s.hooks.afterPing(ctx, baseMessage.ID, &request, result)
  111. return createResponse(baseMessage.ID, *result)
  112. case mcp.MethodSetLogLevel:
  113. var request mcp.SetLevelRequest
  114. var result *mcp.EmptyResult
  115. if s.capabilities.logging == nil {
  116. err = &requestError{
  117. id: baseMessage.ID,
  118. code: mcp.METHOD_NOT_FOUND,
  119. err: fmt.Errorf("logging %w", ErrUnsupported),
  120. }
  121. } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
  122. err = &requestError{
  123. id: baseMessage.ID,
  124. code: mcp.INVALID_REQUEST,
  125. err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
  126. }
  127. } else {
  128. request.Header = headers
  129. s.hooks.beforeSetLevel(ctx, baseMessage.ID, &request)
  130. result, err = s.handleSetLevel(ctx, baseMessage.ID, request)
  131. }
  132. if err != nil {
  133. s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
  134. return err.ToJSONRPCError()
  135. }
  136. s.hooks.afterSetLevel(ctx, baseMessage.ID, &request, result)
  137. return createResponse(baseMessage.ID, *result)
  138. case mcp.MethodResourcesList:
  139. var request mcp.ListResourcesRequest
  140. var result *mcp.ListResourcesResult
  141. if s.capabilities.resources == nil {
  142. err = &requestError{
  143. id: baseMessage.ID,
  144. code: mcp.METHOD_NOT_FOUND,
  145. err: fmt.Errorf("resources %w", ErrUnsupported),
  146. }
  147. } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
  148. err = &requestError{
  149. id: baseMessage.ID,
  150. code: mcp.INVALID_REQUEST,
  151. err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
  152. }
  153. } else {
  154. request.Header = headers
  155. s.hooks.beforeListResources(ctx, baseMessage.ID, &request)
  156. result, err = s.handleListResources(ctx, baseMessage.ID, request)
  157. }
  158. if err != nil {
  159. s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
  160. return err.ToJSONRPCError()
  161. }
  162. s.hooks.afterListResources(ctx, baseMessage.ID, &request, result)
  163. return createResponse(baseMessage.ID, *result)
  164. case mcp.MethodResourcesTemplatesList:
  165. var request mcp.ListResourceTemplatesRequest
  166. var result *mcp.ListResourceTemplatesResult
  167. if s.capabilities.resources == nil {
  168. err = &requestError{
  169. id: baseMessage.ID,
  170. code: mcp.METHOD_NOT_FOUND,
  171. err: fmt.Errorf("resources %w", ErrUnsupported),
  172. }
  173. } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
  174. err = &requestError{
  175. id: baseMessage.ID,
  176. code: mcp.INVALID_REQUEST,
  177. err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
  178. }
  179. } else {
  180. request.Header = headers
  181. s.hooks.beforeListResourceTemplates(ctx, baseMessage.ID, &request)
  182. result, err = s.handleListResourceTemplates(ctx, baseMessage.ID, request)
  183. }
  184. if err != nil {
  185. s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
  186. return err.ToJSONRPCError()
  187. }
  188. s.hooks.afterListResourceTemplates(ctx, baseMessage.ID, &request, result)
  189. return createResponse(baseMessage.ID, *result)
  190. case mcp.MethodResourcesRead:
  191. var request mcp.ReadResourceRequest
  192. var result *mcp.ReadResourceResult
  193. if s.capabilities.resources == nil {
  194. err = &requestError{
  195. id: baseMessage.ID,
  196. code: mcp.METHOD_NOT_FOUND,
  197. err: fmt.Errorf("resources %w", ErrUnsupported),
  198. }
  199. } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
  200. err = &requestError{
  201. id: baseMessage.ID,
  202. code: mcp.INVALID_REQUEST,
  203. err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
  204. }
  205. } else {
  206. request.Header = headers
  207. s.hooks.beforeReadResource(ctx, baseMessage.ID, &request)
  208. result, err = s.handleReadResource(ctx, baseMessage.ID, request)
  209. }
  210. if err != nil {
  211. s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
  212. return err.ToJSONRPCError()
  213. }
  214. s.hooks.afterReadResource(ctx, baseMessage.ID, &request, result)
  215. return createResponse(baseMessage.ID, *result)
  216. case mcp.MethodPromptsList:
  217. var request mcp.ListPromptsRequest
  218. var result *mcp.ListPromptsResult
  219. if s.capabilities.prompts == nil {
  220. err = &requestError{
  221. id: baseMessage.ID,
  222. code: mcp.METHOD_NOT_FOUND,
  223. err: fmt.Errorf("prompts %w", ErrUnsupported),
  224. }
  225. } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
  226. err = &requestError{
  227. id: baseMessage.ID,
  228. code: mcp.INVALID_REQUEST,
  229. err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
  230. }
  231. } else {
  232. request.Header = headers
  233. s.hooks.beforeListPrompts(ctx, baseMessage.ID, &request)
  234. result, err = s.handleListPrompts(ctx, baseMessage.ID, request)
  235. }
  236. if err != nil {
  237. s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
  238. return err.ToJSONRPCError()
  239. }
  240. s.hooks.afterListPrompts(ctx, baseMessage.ID, &request, result)
  241. return createResponse(baseMessage.ID, *result)
  242. case mcp.MethodPromptsGet:
  243. var request mcp.GetPromptRequest
  244. var result *mcp.GetPromptResult
  245. if s.capabilities.prompts == nil {
  246. err = &requestError{
  247. id: baseMessage.ID,
  248. code: mcp.METHOD_NOT_FOUND,
  249. err: fmt.Errorf("prompts %w", ErrUnsupported),
  250. }
  251. } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
  252. err = &requestError{
  253. id: baseMessage.ID,
  254. code: mcp.INVALID_REQUEST,
  255. err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
  256. }
  257. } else {
  258. request.Header = headers
  259. s.hooks.beforeGetPrompt(ctx, baseMessage.ID, &request)
  260. result, err = s.handleGetPrompt(ctx, baseMessage.ID, request)
  261. }
  262. if err != nil {
  263. s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
  264. return err.ToJSONRPCError()
  265. }
  266. s.hooks.afterGetPrompt(ctx, baseMessage.ID, &request, result)
  267. return createResponse(baseMessage.ID, *result)
  268. case mcp.MethodToolsList:
  269. var request mcp.ListToolsRequest
  270. var result *mcp.ListToolsResult
  271. if s.capabilities.tools == nil {
  272. err = &requestError{
  273. id: baseMessage.ID,
  274. code: mcp.METHOD_NOT_FOUND,
  275. err: fmt.Errorf("tools %w", ErrUnsupported),
  276. }
  277. } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
  278. err = &requestError{
  279. id: baseMessage.ID,
  280. code: mcp.INVALID_REQUEST,
  281. err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
  282. }
  283. } else {
  284. request.Header = headers
  285. s.hooks.beforeListTools(ctx, baseMessage.ID, &request)
  286. result, err = s.handleListTools(ctx, baseMessage.ID, request)
  287. }
  288. if err != nil {
  289. s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
  290. return err.ToJSONRPCError()
  291. }
  292. s.hooks.afterListTools(ctx, baseMessage.ID, &request, result)
  293. return createResponse(baseMessage.ID, *result)
  294. case mcp.MethodToolsCall:
  295. var request mcp.CallToolRequest
  296. var result *mcp.CallToolResult
  297. if s.capabilities.tools == nil {
  298. err = &requestError{
  299. id: baseMessage.ID,
  300. code: mcp.METHOD_NOT_FOUND,
  301. err: fmt.Errorf("tools %w", ErrUnsupported),
  302. }
  303. } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil {
  304. err = &requestError{
  305. id: baseMessage.ID,
  306. code: mcp.INVALID_REQUEST,
  307. err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method},
  308. }
  309. } else {
  310. request.Header = headers
  311. s.hooks.beforeCallTool(ctx, baseMessage.ID, &request)
  312. result, err = s.handleToolCall(ctx, baseMessage.ID, request)
  313. }
  314. if err != nil {
  315. s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err)
  316. return err.ToJSONRPCError()
  317. }
  318. s.hooks.afterCallTool(ctx, baseMessage.ID, &request, result)
  319. return createResponse(baseMessage.ID, *result)
  320. default:
  321. return createErrorResponse(
  322. baseMessage.ID,
  323. mcp.METHOD_NOT_FOUND,
  324. fmt.Sprintf("Method %s not found", baseMessage.Method),
  325. )
  326. }
  327. }