cors.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508
  1. // Copyright 2019 Yunion
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package appsrv
  15. /*
CORS support for appsrv, adapted from the github.com/rs/cors package: a net/http
handler that processes CORS requests as defined by https://www.w3.org/TR/cors/
(now maintained as part of the WHATWG Fetch standard).
Configure it by passing an options struct to NewCors:

	c := NewCors(CorsOptions{
		AllowedOrigins:   []string{"foo.com"},
		AllowedMethods:   []string{"GET", "POST", "DELETE"},
		AllowCredentials: true,
	})

Then insert the handler in the chain:

	handler = c.Handler(handler)

See the CorsOptions documentation for more options.
The resulting handler is a standard net/http handler.
  28. */
  29. // package cors
  30. import (
  31. "log"
  32. "net/http"
  33. "net/url"
  34. "os"
  35. "strconv"
  36. "strings"
  37. )
  // toLower is the byte distance between a lower-case ASCII letter and its
// upper-case counterpart ('a' - 'A' == 32): subtracting it upper-cases a
// letter, adding it lower-cases one.
const (
	toLower = 'a' - 'A'
)
  // wildcard is an allowed-origin pattern that contained a single '*': it matches
// any string that starts with prefix and ends with suffix, the '*' standing for
// zero or more characters in between.
type wildcard struct {
	prefix string
	suffix string
}

// match reports whether s matches the wildcard pattern. The length check keeps
// prefix and suffix from overlapping, so the pattern "a*a" does not match "a".
func (w wildcard) match(s string) bool {
	// Summing the lengths avoids allocating w.prefix+w.suffix on every call.
	return len(s) >= len(w.prefix)+len(w.suffix) &&
		strings.HasPrefix(s, w.prefix) &&
		strings.HasSuffix(s, w.suffix)
}
  // converter transforms a single string; callers pass functions such as
// strings.ToUpper or http.CanonicalHeaderKey.
type converter func(string) string

// convert applies c to every element of s and returns the results in a new
// slice, leaving s untouched. The result is never nil, even for empty input.
func convert(s []string, c converter) []string {
	// The output length is known up front, so allocate exactly once.
	out := make([]string, 0, len(s))
	for _, item := range s {
		out = append(out, c(item))
	}
	return out
}
  // parseHeaderList tokenizes and normalizes a header-list value such as the one
// carried by Access-Control-Request-Headers.
//
// Tokens are separated by commas and/or spaces. Inside a token only ASCII
// letters, digits, '-' and '_' are kept; any other byte is dropped. Letters are
// re-cased so that the first letter and every letter following '-' or '_' is
// upper-case and all others are lower-case ("x-custom_header" becomes
// "X-Custom_Header"). Empty tokens are skipped. The result is never nil.
func parseHeaderList(headerList string) []string {
	// Distance between a lower-case ASCII letter and its upper-case form.
	const caseShift = 'a' - 'A'

	n := len(headerList)
	// A list with k commas holds up to k+1 headers. Counting only the commas
	// under-sized the slice by one and forced a reallocation for every list.
	headers := make([]string, 0, strings.Count(headerList, ",")+1)
	token := make([]byte, 0, n)
	upper := true // the next letter starts a word and must be upper-cased
	for i := 0; i < n; i++ {
		c := headerList[i]
		switch {
		case c >= 'a' && c <= 'z':
			if upper {
				token = append(token, c-caseShift)
			} else {
				token = append(token, c)
			}
		case c >= 'A' && c <= 'Z':
			if upper {
				token = append(token, c)
			} else {
				token = append(token, c+caseShift)
			}
		case c == '-' || c == '_' || (c >= '0' && c <= '9'):
			token = append(token, c)
		}
		if c == ' ' || c == ',' || i == n-1 {
			// Separator or end of input: flush the pending token, if any.
			if len(token) > 0 {
				headers = append(headers, string(token))
				token = token[:0]
				upper = true
			}
			continue
		}
		upper = c == '-' || c == '_'
	}
	return headers
}
  // CorsOptions is a configuration container used to set up the CORS middleware
// via NewCors.
type CorsOptions struct {
	// AllowedOrigins is a list of origins a cross-domain request can be executed from.
	// Entries are compared case-insensitively with the host name parsed from the
	// request's Origin header; the scheme and port are not compared, so entries
	// should be bare host names such as "foo.com".
	// If the special "*" value is present in the list, all origins are allowed.
	// An entry may contain a wildcard (*) standing for 0 or more characters
	// (e.g. "*.domain.com"). Using wildcards implies a small performance penalty.
	// Only the first '*' of an entry acts as a wildcard; later ones are literal.
	// When both AllowedOrigins and AllowOriginFunc are empty, all origins are allowed.
	AllowedOrigins []string
	// AllowOriginFunc is a custom origin validator. It receives the host name
	// parsed from the Origin header (not lower-cased) and returns true if the
	// origin is allowed. When set, it is used instead of the AllowedOrigins
	// entries, unless AllowedOrigins contains "*", in which case every origin is
	// allowed and the function is never called.
	AllowOriginFunc func(origin string) bool
	// AllowedMethods is a list of methods the client is allowed to use with
	// cross-domain requests; matching is case-insensitive.
	// Default value is the simple methods (HEAD, GET and POST).
	AllowedMethods []string
	// AllowedHeaders is a list of non-simple headers the client is allowed to use
	// with cross-domain requests.
	// If the special "*" value is present in the list, all headers are allowed.
	// When empty, it defaults to Origin, Accept and Content-Type; otherwise
	// "Origin" is always appended to the list.
	AllowedHeaders []string
	// ExposedHeaders lists response headers that scripts are allowed to read; they
	// are sent in Access-Control-Expose-Headers on actual (non-preflight) requests.
	ExposedHeaders []string
	// AllowCredentials indicates whether the request can include user credentials
	// like cookies, HTTP authentication or client-side SSL certificates. When true,
	// the request's Origin is echoed in Access-Control-Allow-Origin even if all
	// origins are allowed, instead of "*".
	AllowCredentials bool
	// MaxAge indicates how long (in seconds) the results of a preflight request
	// can be cached. Zero or negative omits the Access-Control-Max-Age header.
	MaxAge int
	// Debug enables logging of CORS decisions to stdout with a "[cors] " prefix.
	Debug bool
}
  131. // Cors http handler
  132. type Cors struct {
  133. // Debug logger
  134. Log *log.Logger
  135. // Set to true when allowed origins contains a "*"
  136. allowedOriginsAll bool
  137. // Normalized list of plain allowed origins
  138. allowedOrigins []string
  139. // List of allowed origins containing wildcards
  140. allowedWOrigins []wildcard
  141. // Optional origin validator function
  142. allowOriginFunc func(origin string) bool
  143. // Set to true when allowed headers contains a "*"
  144. allowedHeadersAll bool
  145. // Normalized list of allowed headers
  146. allowedHeaders []string
  147. // Normalized list of allowed methods
  148. allowedMethods []string
  149. // Normalized list of exposed headers
  150. exposedHeaders []string
  151. allowCredentials bool
  152. maxAge int
  153. optionPassthrough bool
  154. }
  155. // New creates a new Cors handler with the provided options.
  156. func NewCors(options CorsOptions) *Cors {
  157. c := &Cors{
  158. exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey),
  159. allowOriginFunc: options.AllowOriginFunc,
  160. allowCredentials: options.AllowCredentials,
  161. maxAge: options.MaxAge,
  162. }
  163. if options.Debug {
  164. c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags)
  165. }
  166. // Normalize options
  167. // Note: for origins and methods matching, the spec requires a case-sensitive matching.
  168. // As it may error prone, we chose to ignore the spec here.
  169. // Allowed Origins
  170. if len(options.AllowedOrigins) == 0 {
  171. if options.AllowOriginFunc == nil {
  172. // Default is all origins
  173. c.allowedOriginsAll = true
  174. }
  175. } else {
  176. c.allowedOrigins = []string{}
  177. c.allowedWOrigins = []wildcard{}
  178. for _, origin := range options.AllowedOrigins {
  179. // Normalize
  180. origin = strings.ToLower(origin)
  181. if origin == "*" {
  182. // If "*" is present in the list, turn the whole list into a match all
  183. c.allowedOriginsAll = true
  184. c.allowedOrigins = nil
  185. c.allowedWOrigins = nil
  186. break
  187. } else if i := strings.IndexByte(origin, '*'); i >= 0 {
  188. // Split the origin in two: start and end string without the *
  189. w := wildcard{origin[0:i], origin[i+1:]}
  190. c.allowedWOrigins = append(c.allowedWOrigins, w)
  191. } else {
  192. c.allowedOrigins = append(c.allowedOrigins, origin)
  193. }
  194. }
  195. }
  196. // Allowed Headers
  197. if len(options.AllowedHeaders) == 0 {
  198. // Use sensible defaults
  199. c.allowedHeaders = []string{"Origin", "Accept", "Content-Type"}
  200. } else {
  201. // Origin is always appended as some browsers will always request for this header at preflight
  202. c.allowedHeaders = convert(append(options.AllowedHeaders, "Origin"), http.CanonicalHeaderKey)
  203. for _, h := range options.AllowedHeaders {
  204. if h == "*" {
  205. c.allowedHeadersAll = true
  206. c.allowedHeaders = nil
  207. break
  208. }
  209. }
  210. }
  211. // Allowed Methods
  212. if len(options.AllowedMethods) == 0 {
  213. // Default is spec's "simple" methods
  214. c.allowedMethods = []string{"GET", "POST", "HEAD"}
  215. } else {
  216. c.allowedMethods = convert(options.AllowedMethods, strings.ToUpper)
  217. }
  218. return c
  219. }
// AllowAll creates a new Cors handler with a permissive configuration allowing all
// origins, all standard methods, any header, and credentials.
  222. /* func AllowAll() *Cors {
  223. return NewCors(CorsOptions{
  224. AllowedOrigins: []string{"*"},
  225. AllowedMethods: []string{"HEAD", "GET", "POST", "PUT", "PATCH", "DELETE"},
  226. AllowedHeaders: []string{"*"},
  227. AllowCredentials: true,
  228. })
  229. }*/
  230. // Handler apply the CORS specification on the request, and add relevant CORS headers
  231. // as necessary.
  232. func (c *Cors) Handler(h http.Handler) http.Handler {
  233. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  234. if r.Method == "OPTIONS" && r.Header.Get("Access-Control-Request-Method") != "" {
  235. c.logf("Handler: Preflight request")
  236. c.handlePreflight(w, r)
  237. // Preflight requests are standalone and should stop the chain as some other
  238. // middleware may not handle OPTIONS requests correctly. One typical example
  239. // is authentication middleware ; OPTIONS requests won't carry authentication
  240. // headers (see #1)
  241. if c.optionPassthrough {
  242. h.ServeHTTP(w, r)
  243. } else {
  244. w.WriteHeader(http.StatusOK)
  245. }
  246. } else {
  247. c.logf("Handler: Actual request")
  248. c.handleActualRequest(w, r)
  249. h.ServeHTTP(w, r)
  250. }
  251. })
  252. }
  253. // HandlerFunc provides Martini compatible handler
  254. func (c *Cors) HandlerFunc(w http.ResponseWriter, r *http.Request) {
  255. if r.Method == "OPTIONS" && r.Header.Get("Access-Control-Request-Method") != "" {
  256. c.logf("HandlerFunc: Preflight request")
  257. c.handlePreflight(w, r)
  258. } else {
  259. c.logf("HandlerFunc: Actual request")
  260. c.handleActualRequest(w, r)
  261. }
  262. }
  263. // Negroni compatible interface
  264. func (c *Cors) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
  265. if r.Method == "OPTIONS" && r.Header.Get("Access-Control-Request-Method") != "" {
  266. c.logf("ServeHTTP: Preflight request")
  267. c.handlePreflight(w, r)
  268. // Preflight requests are standalone and should stop the chain as some other
  269. // middleware may not handle OPTIONS requests correctly. One typical example
  270. // is authentication middleware ; OPTIONS requests won't carry authentication
  271. // headers (see #1)
  272. if c.optionPassthrough {
  273. next(w, r)
  274. } else {
  275. w.WriteHeader(http.StatusOK)
  276. }
  277. } else {
  278. c.logf("ServeHTTP: Actual request")
  279. c.handleActualRequest(w, r)
  280. next(w, r)
  281. }
  282. }
  283. // handlePreflight handles pre-flight CORS requests
  284. func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
  285. headers := w.Header()
  286. origin := r.Header.Get("Origin")
  287. if r.Method != "OPTIONS" {
  288. c.logf(" Preflight aborted: %s!=OPTIONS", r.Method)
  289. return
  290. }
  291. // Always set Vary headers
  292. // see https://github.com/rs/cors/issues/10,
  293. // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
  294. headers.Add("Vary", "Origin")
  295. headers.Add("Vary", "Access-Control-Request-Method")
  296. headers.Add("Vary", "Access-Control-Request-Headers")
  297. if origin == "" {
  298. c.logf(" Preflight aborted: empty origin")
  299. return
  300. }
  301. if !c.isOriginAllowed(origin) {
  302. c.logf(" Preflight aborted: origin '%s' not allowed", origin)
  303. return
  304. }
  305. reqMethod := r.Header.Get("Access-Control-Request-Method")
  306. if !c.isMethodAllowed(reqMethod) {
  307. c.logf(" Preflight aborted: method '%s' not allowed", reqMethod)
  308. return
  309. }
  310. reqHeaders := parseHeaderList(r.Header.Get("Access-Control-Request-Headers"))
  311. if !c.areHeadersAllowed(reqHeaders) {
  312. c.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders)
  313. return
  314. }
  315. if c.allowedOriginsAll && !c.allowCredentials {
  316. headers.Set("Access-Control-Allow-Origin", "*")
  317. } else {
  318. headers.Set("Access-Control-Allow-Origin", origin)
  319. }
  320. // Spec says: Since the list of methods can be unbounded, simply returning the method indicated
  321. // by Access-Control-Request-Method (if supported) can be enough
  322. headers.Set("Access-Control-Allow-Methods", strings.ToUpper(reqMethod))
  323. if len(reqHeaders) > 0 {
  324. // Spec says: Since the list of headers can be unbounded, simply returning supported headers
  325. // from Access-Control-Request-Headers can be enough
  326. headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", "))
  327. }
  328. if c.allowCredentials {
  329. headers.Set("Access-Control-Allow-Credentials", "true")
  330. }
  331. if c.maxAge > 0 {
  332. headers.Set("Access-Control-Max-Age", strconv.Itoa(c.maxAge))
  333. }
  334. c.logf(" Preflight response headers: %v", headers)
  335. }
  336. // handleActualRequest handles simple cross-origin requests, actual request or redirects
  337. func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {
  338. headers := w.Header()
  339. origin := r.Header.Get("Origin")
  340. if r.Method == "OPTIONS" {
  341. c.logf(" Actual request no headers added: method == %s", r.Method)
  342. return
  343. }
  344. // Always set Vary, see https://github.com/rs/cors/issues/10
  345. headers.Add("Vary", "Origin")
  346. if origin == "" {
  347. c.logf(" Actual request no headers added: missing origin")
  348. return
  349. }
  350. if !c.isOriginAllowed(origin) {
  351. c.logf(" Actual request no headers added: origin '%s' not allowed", origin)
  352. return
  353. }
  354. // Note that spec does define a way to specifically disallow a simple method like GET or
  355. // POST. Access-Control-Allow-Methods is only used for pre-flight requests and the
  356. // spec doesn't instruct to check the allowed methods for simple cross-origin requests.
  357. // We think it's a nice feature to be able to have control on those methods though.
  358. if !c.isMethodAllowed(r.Method) {
  359. c.logf(" Actual request no headers added: method '%s' not allowed", r.Method)
  360. return
  361. }
  362. if c.allowedOriginsAll && !c.allowCredentials {
  363. headers.Set("Access-Control-Allow-Origin", "*")
  364. } else {
  365. headers.Set("Access-Control-Allow-Origin", origin)
  366. }
  367. if len(c.exposedHeaders) > 0 {
  368. headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", "))
  369. }
  370. if c.allowCredentials {
  371. headers.Set("Access-Control-Allow-Credentials", "true")
  372. }
  373. c.logf(" Actual response added headers: %v", headers)
  374. }
  375. // convenience method. checks if debugging is turned on before printing
  376. func (c *Cors) logf(format string, a ...interface{}) {
  377. if c.Log != nil {
  378. c.Log.Printf(format, a...)
  379. }
  380. }
  381. // isOriginAllowed checks if a given origin is allowed to perform cross-domain requests
  382. // on the endpoint
  383. func (c *Cors) isOriginAllowed(originURL string) bool {
  384. if c.allowedOriginsAll {
  385. return true
  386. }
  387. u, e := url.Parse(originURL)
  388. if e != nil {
  389. return false
  390. }
  391. origin := u.Hostname()
  392. if c.allowOriginFunc != nil {
  393. return c.allowOriginFunc(origin)
  394. }
  395. origin = strings.ToLower(origin)
  396. for _, o := range c.allowedOrigins {
  397. if o == origin {
  398. return true
  399. }
  400. }
  401. for _, w := range c.allowedWOrigins {
  402. if w.match(origin) {
  403. return true
  404. }
  405. }
  406. return false
  407. }
  408. // isMethodAllowed checks if a given method can be used as part of a cross-domain request
  409. // on the endpoing
  410. func (c *Cors) isMethodAllowed(method string) bool {
  411. if len(c.allowedMethods) == 0 {
  412. // If no method allowed, always return false, even for preflight request
  413. return false
  414. }
  415. method = strings.ToUpper(method)
  416. if method == "OPTIONS" {
  417. // Always allow preflight requests
  418. return true
  419. }
  420. for _, m := range c.allowedMethods {
  421. if m == method {
  422. return true
  423. }
  424. }
  425. return false
  426. }
  427. // areHeadersAllowed checks if a given list of headers are allowed to used within
  428. // a cross-domain request.
  429. func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool {
  430. if c.allowedHeadersAll || len(requestedHeaders) == 0 {
  431. return true
  432. }
  433. for _, header := range requestedHeaders {
  434. header = http.CanonicalHeaderKey(header)
  435. found := false
  436. for _, h := range c.allowedHeaders {
  437. if h == header {
  438. found = true
  439. }
  440. }
  441. if !found {
  442. return false
  443. }
  444. }
  445. return true
  446. }
  447. func (app *Application) CORSAllowAll() {
  448. app.CORSAllowHosts([]string{"*"})
  449. }
  450. func (app *Application) CORSAllowHosts(hosts []string) {
  451. log.Println("Allow hosts", hosts)
  452. options := CorsOptions{
  453. AllowedOrigins: hosts,
  454. AllowedMethods: []string{"HEAD", "GET", "POST", "PUT", "PATCH", "DELETE"},
  455. AllowedHeaders: []string{"*"},
  456. ExposedHeaders: []string{"Authorization"},
  457. AllowCredentials: true,
  458. // Debug: true,
  459. }
  460. app.EnableCORS(options)
  461. }
  462. func (app *Application) EnableCORS(options CorsOptions) {
  463. app.cors = NewCors(options)
  464. }