mock.go 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241
  1. package mock
  2. import (
  3. "errors"
  4. "fmt"
  5. "path"
  6. "reflect"
  7. "regexp"
  8. "runtime"
  9. "strings"
  10. "sync"
  11. "time"
  12. "github.com/davecgh/go-spew/spew"
  13. "github.com/pmezard/go-difflib/difflib"
  14. "github.com/stretchr/objx"
  15. "github.com/stretchr/testify/assert"
  16. )
  // gccgoRE matches the ".pN<digits>_" segment that GCCGO inserts into
  // function names to carry interface information; Called cuts the name at
  // the first match to recover the method name.
  var (
  	gccgoRE = regexp.MustCompile(`\.pN\d+_`)
  )
  // TestingT is the subset of *testing.T that this package uses to log
  // messages and to fail a test.
  type TestingT interface {
  	// Errorf records a formatted failure message.
  	Errorf(format string, args ...interface{})
  	// Logf records a formatted informational message.
  	Logf(format string, args ...interface{})
  	// FailNow marks the test as failed; *testing.T additionally stops the
  	// calling goroutine.
  	FailNow()
  }
  /*
  Call
  */

  // Call represents a method call and is used for setting expectations,
  // as well as recording activity. Expectations are created by Mock.On and
  // configured through Call's chainable methods; MethodCalled also records
  // every accepted call as a Call value in Mock.Calls.
  type Call struct {
  	// Parent is the Mock this call belongs to. Call's configuration methods
  	// lock Parent's mutex while they modify the call.
  	Parent *Mock

  	// The name of the method that was or will be called.
  	Method string

  	// Holds the arguments of the method.
  	Arguments Arguments

  	// Holds the arguments that should be returned when
  	// this method is called.
  	ReturnArguments Arguments

  	// Holds the caller info (from assert.CallerInfo) for the On() call; it is
  	// included in failure messages.
  	callerInfo []string

  	// The number of times to return the return arguments when setting
  	// expectations. 0 means to always return the value. MethodCalled
  	// decrements it on each matching call and sets it to -1 once the
  	// expectation has been used up.
  	Repeatability int

  	// Amount of times this call has been called (incremented by MethodCalled).
  	totalCalls int

  	// Call to this method can be optional (set by Maybe): an optional call
  	// that was never made is not reported as a failure for that reason alone.
  	optional bool

  	// Holds a channel that will be used to block the Return until it either
  	// receives a message or is closed. When nil, MethodCalled instead sleeps
  	// for waitTime (which is zero unless After was used).
  	WaitFor <-chan time.Time
  	// waitTime is how long MethodCalled sleeps before returning when WaitFor
  	// is nil; it is set by After.
  	waitTime time.Duration

  	// Holds a handler used to manipulate arguments content that are passed by
  	// reference. It's useful when mocking methods such as unmarshalers or
  	// decoders. MethodCalled invokes it with the actual arguments before the
  	// return values are read.
  	RunFn func(Arguments)

  	// PanicMsg holds the message used to mock a panic on the function call.
  	// If PanicMsg is non-nil, MethodCalled panics with *PanicMsg after any
  	// configured wait and before RunFn is invoked, irrespective of other
  	// settings.
  	PanicMsg *string

  	// Calls which must be satisfied before this call can be made (added by
  	// NotBefore); MethodCalled checks each of them with checkExpectation.
  	requires []*Call
  }
  63. func newCall(parent *Mock, methodName string, callerInfo []string, methodArguments ...interface{}) *Call {
  64. return &Call{
  65. Parent: parent,
  66. Method: methodName,
  67. Arguments: methodArguments,
  68. ReturnArguments: make([]interface{}, 0),
  69. callerInfo: callerInfo,
  70. Repeatability: 0,
  71. WaitFor: nil,
  72. RunFn: nil,
  73. PanicMsg: nil,
  74. }
  75. }
  76. func (c *Call) lock() {
  77. c.Parent.mutex.Lock()
  78. }
  79. func (c *Call) unlock() {
  80. c.Parent.mutex.Unlock()
  81. }
  82. // Return specifies the return arguments for the expectation.
  83. //
  84. // Mock.On("DoSomething").Return(errors.New("failed"))
  85. func (c *Call) Return(returnArguments ...interface{}) *Call {
  86. c.lock()
  87. defer c.unlock()
  88. c.ReturnArguments = returnArguments
  89. return c
  90. }
  91. // Panic specifies if the function call should fail and the panic message
  92. //
  93. // Mock.On("DoSomething").Panic("test panic")
  94. func (c *Call) Panic(msg string) *Call {
  95. c.lock()
  96. defer c.unlock()
  97. c.PanicMsg = &msg
  98. return c
  99. }
  100. // Once indicates that the mock should only return the value once.
  101. //
  102. // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once()
  103. func (c *Call) Once() *Call {
  104. return c.Times(1)
  105. }
  106. // Twice indicates that the mock should only return the value twice.
  107. //
  108. // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice()
  109. func (c *Call) Twice() *Call {
  110. return c.Times(2)
  111. }
  112. // Times indicates that the mock should only return the indicated number
  113. // of times.
  114. //
  115. // Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5)
  116. func (c *Call) Times(i int) *Call {
  117. c.lock()
  118. defer c.unlock()
  119. c.Repeatability = i
  120. return c
  121. }
  122. // WaitUntil sets the channel that will block the mock's return until its closed
  123. // or a message is received.
  124. //
  125. // Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second))
  126. func (c *Call) WaitUntil(w <-chan time.Time) *Call {
  127. c.lock()
  128. defer c.unlock()
  129. c.WaitFor = w
  130. return c
  131. }
  132. // After sets how long to block until the call returns
  133. //
  134. // Mock.On("MyMethod", arg1, arg2).After(time.Second)
  135. func (c *Call) After(d time.Duration) *Call {
  136. c.lock()
  137. defer c.unlock()
  138. c.waitTime = d
  139. return c
  140. }
  141. // Run sets a handler to be called before returning. It can be used when
  142. // mocking a method (such as an unmarshaler) that takes a pointer to a struct and
  143. // sets properties in such struct
  144. //
  145. // Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}")).Return().Run(func(args Arguments) {
  146. // arg := args.Get(0).(*map[string]interface{})
  147. // arg["foo"] = "bar"
  148. // })
  149. func (c *Call) Run(fn func(args Arguments)) *Call {
  150. c.lock()
  151. defer c.unlock()
  152. c.RunFn = fn
  153. return c
  154. }
  155. // Maybe allows the method call to be optional. Not calling an optional method
  156. // will not cause an error while asserting expectations
  157. func (c *Call) Maybe() *Call {
  158. c.lock()
  159. defer c.unlock()
  160. c.optional = true
  161. return c
  162. }
  163. // On chains a new expectation description onto the mocked interface. This
  164. // allows syntax like.
  165. //
  166. // Mock.
  167. // On("MyMethod", 1).Return(nil).
  168. // On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error"))
  169. //
  170. //go:noinline
  171. func (c *Call) On(methodName string, arguments ...interface{}) *Call {
  172. return c.Parent.On(methodName, arguments...)
  173. }
  174. // Unset removes a mock handler from being called.
  175. //
  176. // test.On("func", mock.Anything).Unset()
  177. func (c *Call) Unset() *Call {
  178. var unlockOnce sync.Once
  179. for _, arg := range c.Arguments {
  180. if v := reflect.ValueOf(arg); v.Kind() == reflect.Func {
  181. panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg))
  182. }
  183. }
  184. c.lock()
  185. defer unlockOnce.Do(c.unlock)
  186. foundMatchingCall := false
  187. // in-place filter slice for calls to be removed - iterate from 0'th to last skipping unnecessary ones
  188. var index int // write index
  189. for _, call := range c.Parent.ExpectedCalls {
  190. if call.Method == c.Method {
  191. _, diffCount := call.Arguments.Diff(c.Arguments)
  192. if diffCount == 0 {
  193. foundMatchingCall = true
  194. // Remove from ExpectedCalls - just skip it
  195. continue
  196. }
  197. }
  198. c.Parent.ExpectedCalls[index] = call
  199. index++
  200. }
  201. // trim slice up to last copied index
  202. c.Parent.ExpectedCalls = c.Parent.ExpectedCalls[:index]
  203. if !foundMatchingCall {
  204. unlockOnce.Do(c.unlock)
  205. c.Parent.fail("\n\nmock: Could not find expected call\n-----------------------------\n\n%s\n\n",
  206. callString(c.Method, c.Arguments, true),
  207. )
  208. }
  209. return c
  210. }
  211. // NotBefore indicates that the mock should only be called after the referenced
  212. // calls have been called as expected. The referenced calls may be from the
  213. // same mock instance and/or other mock instances.
  214. //
  215. // Mock.On("Do").Return(nil).Notbefore(
  216. // Mock.On("Init").Return(nil)
  217. // )
  218. func (c *Call) NotBefore(calls ...*Call) *Call {
  219. c.lock()
  220. defer c.unlock()
  221. for _, call := range calls {
  222. if call.Parent == nil {
  223. panic("not before calls must be created with Mock.On()")
  224. }
  225. }
  226. c.requires = append(c.requires, calls...)
  227. return c
  228. }
  // Mock is the workhorse used to track activity on another object.
  // For an example of its usage, refer to the "Example Usage" section at the top
  // of this document.
  type Mock struct {
  	// Represents the calls that are expected of
  	// an object: the expectations registered with On, in registration order
  	// (Unset filters them in place).
  	ExpectedCalls []*Call

  	// Holds the calls that were made to this mocked object; MethodCalled
  	// appends one entry per accepted call.
  	Calls []Call

  	// test is an optional TestingT (set via Test) used to report an invalid
  	// mock call; when it is nil, fail panics instead.
  	test TestingT

  	// testData backs TestData() and is created lazily on first use. Testify
  	// ignores this data completely, allowing you to do whatever you like with it.
  	testData objx.Map

  	// mutex serialises access to ExpectedCalls, Calls and test. Call's
  	// configuration methods lock it through their Parent pointer.
  	// NOTE(review): testData is not accessed under this mutex.
  	mutex sync.Mutex
  }
  246. // String provides a %v format string for Mock.
  247. // Note: this is used implicitly by Arguments.Diff if a Mock is passed.
  248. // It exists because go's default %v formatting traverses the struct
  249. // without acquiring the mutex, which is detected by go test -race.
  250. func (m *Mock) String() string {
  251. return fmt.Sprintf("%[1]T<%[1]p>", m)
  252. }
  253. // TestData holds any data that might be useful for testing. Testify ignores
  254. // this data completely allowing you to do whatever you like with it.
  255. func (m *Mock) TestData() objx.Map {
  256. if m.testData == nil {
  257. m.testData = make(objx.Map)
  258. }
  259. return m.testData
  260. }
  261. /*
  262. Setting expectations
  263. */
  264. // Test sets the test struct variable of the mock object
  265. func (m *Mock) Test(t TestingT) {
  266. m.mutex.Lock()
  267. defer m.mutex.Unlock()
  268. m.test = t
  269. }
  270. // fail fails the current test with the given formatted format and args.
  271. // In case that a test was defined, it uses the test APIs for failing a test,
  272. // otherwise it uses panic.
  273. func (m *Mock) fail(format string, args ...interface{}) {
  274. m.mutex.Lock()
  275. defer m.mutex.Unlock()
  276. if m.test == nil {
  277. panic(fmt.Sprintf(format, args...))
  278. }
  279. m.test.Errorf(format, args...)
  280. m.test.FailNow()
  281. }
  282. // On starts a description of an expectation of the specified method
  283. // being called.
  284. //
  285. // Mock.On("MyMethod", arg1, arg2)
  286. func (m *Mock) On(methodName string, arguments ...interface{}) *Call {
  287. for _, arg := range arguments {
  288. if v := reflect.ValueOf(arg); v.Kind() == reflect.Func {
  289. panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg))
  290. }
  291. }
  292. m.mutex.Lock()
  293. defer m.mutex.Unlock()
  294. c := newCall(m, methodName, assert.CallerInfo(), arguments...)
  295. m.ExpectedCalls = append(m.ExpectedCalls, c)
  296. return c
  297. }
  298. // /*
  299. // Recording and responding to activity
  300. // */
  301. func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) {
  302. var expectedCall *Call
  303. for i, call := range m.ExpectedCalls {
  304. if call.Method == method {
  305. _, diffCount := call.Arguments.Diff(arguments)
  306. if diffCount == 0 {
  307. expectedCall = call
  308. if call.Repeatability > -1 {
  309. return i, call
  310. }
  311. }
  312. }
  313. }
  314. return -1, expectedCall
  315. }
  316. type matchCandidate struct {
  317. call *Call
  318. mismatch string
  319. diffCount int
  320. }
  321. func (c matchCandidate) isBetterMatchThan(other matchCandidate) bool {
  322. if c.call == nil {
  323. return false
  324. }
  325. if other.call == nil {
  326. return true
  327. }
  328. if c.diffCount > other.diffCount {
  329. return false
  330. }
  331. if c.diffCount < other.diffCount {
  332. return true
  333. }
  334. if c.call.Repeatability > 0 && other.call.Repeatability <= 0 {
  335. return true
  336. }
  337. return false
  338. }
  339. func (m *Mock) findClosestCall(method string, arguments ...interface{}) (*Call, string) {
  340. var bestMatch matchCandidate
  341. for _, call := range m.expectedCalls() {
  342. if call.Method == method {
  343. errInfo, tempDiffCount := call.Arguments.Diff(arguments)
  344. tempCandidate := matchCandidate{
  345. call: call,
  346. mismatch: errInfo,
  347. diffCount: tempDiffCount,
  348. }
  349. if tempCandidate.isBetterMatchThan(bestMatch) {
  350. bestMatch = tempCandidate
  351. }
  352. }
  353. }
  354. return bestMatch.call, bestMatch.mismatch
  355. }
  356. func callString(method string, arguments Arguments, includeArgumentValues bool) string {
  357. var argValsString string
  358. if includeArgumentValues {
  359. var argVals []string
  360. for argIndex, arg := range arguments {
  361. if _, ok := arg.(*FunctionalOptionsArgument); ok {
  362. argVals = append(argVals, fmt.Sprintf("%d: %s", argIndex, arg))
  363. continue
  364. }
  365. argVals = append(argVals, fmt.Sprintf("%d: %#v", argIndex, arg))
  366. }
  367. argValsString = fmt.Sprintf("\n\t\t%s", strings.Join(argVals, "\n\t\t"))
  368. }
  369. return fmt.Sprintf("%s(%s)%s", method, arguments.String(), argValsString)
  370. }
  371. // Called tells the mock object that a method has been called, and gets an array
  372. // of arguments to return. Panics if the call is unexpected (i.e. not preceded by
  373. // appropriate .On .Return() calls)
  374. // If Call.WaitFor is set, blocks until the channel is closed or receives a message.
  375. func (m *Mock) Called(arguments ...interface{}) Arguments {
  376. // get the calling function's name
  377. pc, _, _, ok := runtime.Caller(1)
  378. if !ok {
  379. panic("Couldn't get the caller information")
  380. }
  381. functionPath := runtime.FuncForPC(pc).Name()
  382. // Next four lines are required to use GCCGO function naming conventions.
  383. // For Ex: github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock
  384. // uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree
  385. // With GCCGO we need to remove interface information starting from pN<dd>.
  386. if gccgoRE.MatchString(functionPath) {
  387. functionPath = gccgoRE.Split(functionPath, -1)[0]
  388. }
  389. parts := strings.Split(functionPath, ".")
  390. functionName := parts[len(parts)-1]
  391. return m.MethodCalled(functionName, arguments...)
  392. }
  393. // MethodCalled tells the mock object that the given method has been called, and gets
  394. // an array of arguments to return. Panics if the call is unexpected (i.e. not preceded
  395. // by appropriate .On .Return() calls)
  396. // If Call.WaitFor is set, blocks until the channel is closed or receives a message.
  397. func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Arguments {
  398. m.mutex.Lock()
  399. // TODO: could combine expected and closes in single loop
  400. found, call := m.findExpectedCall(methodName, arguments...)
  401. if found < 0 {
  402. // expected call found, but it has already been called with repeatable times
  403. if call != nil {
  404. m.mutex.Unlock()
  405. m.fail("\nassert: mock: The method has been called over %d times.\n\tEither do one more Mock.On(\"%s\").Return(...), or remove extra call.\n\tThis call was unexpected:\n\t\t%s\n\tat: %s", call.totalCalls, methodName, callString(methodName, arguments, true), assert.CallerInfo())
  406. }
  407. // we have to fail here - because we don't know what to do
  408. // as the return arguments. This is because:
  409. //
  410. // a) this is a totally unexpected call to this method,
  411. // b) the arguments are not what was expected, or
  412. // c) the developer has forgotten to add an accompanying On...Return pair.
  413. closestCall, mismatch := m.findClosestCall(methodName, arguments...)
  414. m.mutex.Unlock()
  415. if closestCall != nil {
  416. m.fail("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n\n%s\nDiff: %s",
  417. callString(methodName, arguments, true),
  418. callString(methodName, closestCall.Arguments, true),
  419. diffArguments(closestCall.Arguments, arguments),
  420. strings.TrimSpace(mismatch),
  421. )
  422. } else {
  423. m.fail("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", methodName, methodName, callString(methodName, arguments, true), assert.CallerInfo())
  424. }
  425. }
  426. for _, requirement := range call.requires {
  427. if satisfied, _ := requirement.Parent.checkExpectation(requirement); !satisfied {
  428. m.mutex.Unlock()
  429. m.fail("mock: Unexpected Method Call\n-----------------------------\n\n%s\n\nMust not be called before%s:\n\n%s",
  430. callString(call.Method, call.Arguments, true),
  431. func() (s string) {
  432. if requirement.totalCalls > 0 {
  433. s = " another call of"
  434. }
  435. if call.Parent != requirement.Parent {
  436. s += " method from another mock instance"
  437. }
  438. return
  439. }(),
  440. callString(requirement.Method, requirement.Arguments, true),
  441. )
  442. }
  443. }
  444. if call.Repeatability == 1 {
  445. call.Repeatability = -1
  446. } else if call.Repeatability > 1 {
  447. call.Repeatability--
  448. }
  449. call.totalCalls++
  450. // add the call
  451. m.Calls = append(m.Calls, *newCall(m, methodName, assert.CallerInfo(), arguments...))
  452. m.mutex.Unlock()
  453. // block if specified
  454. if call.WaitFor != nil {
  455. <-call.WaitFor
  456. } else {
  457. time.Sleep(call.waitTime)
  458. }
  459. m.mutex.Lock()
  460. panicMsg := call.PanicMsg
  461. m.mutex.Unlock()
  462. if panicMsg != nil {
  463. panic(*panicMsg)
  464. }
  465. m.mutex.Lock()
  466. runFn := call.RunFn
  467. m.mutex.Unlock()
  468. if runFn != nil {
  469. runFn(arguments)
  470. }
  471. m.mutex.Lock()
  472. returnArgs := call.ReturnArguments
  473. m.mutex.Unlock()
  474. return returnArgs
  475. }
  476. /*
  477. Assertions
  478. */
  479. type assertExpectationiser interface {
  480. AssertExpectations(TestingT) bool
  481. }
  482. // AssertExpectationsForObjects asserts that everything specified with On and Return
  483. // of the specified objects was in fact called as expected.
  484. //
  485. // Calls may have occurred in any order.
  486. func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool {
  487. if h, ok := t.(tHelper); ok {
  488. h.Helper()
  489. }
  490. for _, obj := range testObjects {
  491. if m, ok := obj.(*Mock); ok {
  492. t.Logf("Deprecated mock.AssertExpectationsForObjects(myMock.Mock) use mock.AssertExpectationsForObjects(myMock)")
  493. obj = m
  494. }
  495. m := obj.(assertExpectationiser)
  496. if !m.AssertExpectations(t) {
  497. t.Logf("Expectations didn't match for Mock: %+v", reflect.TypeOf(m))
  498. return false
  499. }
  500. }
  501. return true
  502. }
  503. // AssertExpectations asserts that everything specified with On and Return was
  504. // in fact called as expected. Calls may have occurred in any order.
  505. func (m *Mock) AssertExpectations(t TestingT) bool {
  506. if s, ok := t.(interface{ Skipped() bool }); ok && s.Skipped() {
  507. return true
  508. }
  509. if h, ok := t.(tHelper); ok {
  510. h.Helper()
  511. }
  512. m.mutex.Lock()
  513. defer m.mutex.Unlock()
  514. var failedExpectations int
  515. // iterate through each expectation
  516. expectedCalls := m.expectedCalls()
  517. for _, expectedCall := range expectedCalls {
  518. satisfied, reason := m.checkExpectation(expectedCall)
  519. if !satisfied {
  520. failedExpectations++
  521. t.Logf(reason)
  522. }
  523. }
  524. if failedExpectations != 0 {
  525. t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(expectedCalls)-failedExpectations, len(expectedCalls), failedExpectations, assert.CallerInfo())
  526. }
  527. return failedExpectations == 0
  528. }
  529. func (m *Mock) checkExpectation(call *Call) (bool, string) {
  530. if !call.optional && !m.methodWasCalled(call.Method, call.Arguments) && call.totalCalls == 0 {
  531. return false, fmt.Sprintf("FAIL:\t%s(%s)\n\t\tat: %s", call.Method, call.Arguments.String(), call.callerInfo)
  532. }
  533. if call.Repeatability > 0 {
  534. return false, fmt.Sprintf("FAIL:\t%s(%s)\n\t\tat: %s", call.Method, call.Arguments.String(), call.callerInfo)
  535. }
  536. return true, fmt.Sprintf("PASS:\t%s(%s)", call.Method, call.Arguments.String())
  537. }
  538. // AssertNumberOfCalls asserts that the method was called expectedCalls times.
  539. func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool {
  540. if h, ok := t.(tHelper); ok {
  541. h.Helper()
  542. }
  543. m.mutex.Lock()
  544. defer m.mutex.Unlock()
  545. var actualCalls int
  546. for _, call := range m.calls() {
  547. if call.Method == methodName {
  548. actualCalls++
  549. }
  550. }
  551. return assert.Equal(t, expectedCalls, actualCalls, fmt.Sprintf("Expected number of calls (%d) does not match the actual number of calls (%d).", expectedCalls, actualCalls))
  552. }
  553. // AssertCalled asserts that the method was called.
  554. // It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method.
  555. func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool {
  556. if h, ok := t.(tHelper); ok {
  557. h.Helper()
  558. }
  559. m.mutex.Lock()
  560. defer m.mutex.Unlock()
  561. if !m.methodWasCalled(methodName, arguments) {
  562. var calledWithArgs []string
  563. for _, call := range m.calls() {
  564. calledWithArgs = append(calledWithArgs, fmt.Sprintf("%v", call.Arguments))
  565. }
  566. if len(calledWithArgs) == 0 {
  567. return assert.Fail(t, "Should have called with given arguments",
  568. fmt.Sprintf("Expected %q to have been called with:\n%v\nbut no actual calls happened", methodName, arguments))
  569. }
  570. return assert.Fail(t, "Should have called with given arguments",
  571. fmt.Sprintf("Expected %q to have been called with:\n%v\nbut actual calls were:\n %v", methodName, arguments, strings.Join(calledWithArgs, "\n")))
  572. }
  573. return true
  574. }
  575. // AssertNotCalled asserts that the method was not called.
  576. // It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method.
  577. func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool {
  578. if h, ok := t.(tHelper); ok {
  579. h.Helper()
  580. }
  581. m.mutex.Lock()
  582. defer m.mutex.Unlock()
  583. if m.methodWasCalled(methodName, arguments) {
  584. return assert.Fail(t, "Should not have called with given arguments",
  585. fmt.Sprintf("Expected %q to not have been called with:\n%v\nbut actually it was.", methodName, arguments))
  586. }
  587. return true
  588. }
  589. // IsMethodCallable checking that the method can be called
  590. // If the method was called more than `Repeatability` return false
  591. func (m *Mock) IsMethodCallable(t TestingT, methodName string, arguments ...interface{}) bool {
  592. if h, ok := t.(tHelper); ok {
  593. h.Helper()
  594. }
  595. m.mutex.Lock()
  596. defer m.mutex.Unlock()
  597. for _, v := range m.ExpectedCalls {
  598. if v.Method != methodName {
  599. continue
  600. }
  601. if len(arguments) != len(v.Arguments) {
  602. continue
  603. }
  604. if v.Repeatability < v.totalCalls {
  605. continue
  606. }
  607. if isArgsEqual(v.Arguments, arguments) {
  608. return true
  609. }
  610. }
  611. return false
  612. }
  613. // isArgsEqual compares arguments
  614. func isArgsEqual(expected Arguments, args []interface{}) bool {
  615. if len(expected) != len(args) {
  616. return false
  617. }
  618. for i, v := range args {
  619. if !reflect.DeepEqual(expected[i], v) {
  620. return false
  621. }
  622. }
  623. return true
  624. }
  625. func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool {
  626. for _, call := range m.calls() {
  627. if call.Method == methodName {
  628. _, differences := Arguments(expected).Diff(call.Arguments)
  629. if differences == 0 {
  630. // found the expected call
  631. return true
  632. }
  633. }
  634. }
  635. // we didn't find the expected call
  636. return false
  637. }
  638. func (m *Mock) expectedCalls() []*Call {
  639. return append([]*Call{}, m.ExpectedCalls...)
  640. }
  641. func (m *Mock) calls() []Call {
  642. return append([]Call{}, m.Calls...)
  643. }
  /*
  Arguments
  */

  // Arguments holds an array of method arguments or return values.
  type Arguments []interface{}

  // Anything is used in Diff and Assert when the argument being tested
  // shouldn't be taken into consideration.
  const Anything = "mock.Anything"
  // AnythingOfTypeArgument contains the type of an argument
  // for use when type checking. Used in Diff and Assert.
  //
  // Deprecated: this is an implementation detail that must not be used. Use [AnythingOfType] instead.
  type AnythingOfTypeArgument = anythingOfTypeArgument

  // anythingOfTypeArgument is a string naming the type an argument must have.
  // Used in Diff and Assert.
  type anythingOfTypeArgument string

  // AnythingOfType returns a matcher value holding the name of the type to
  // check for. The name is compared with the string returned by
  // [reflect.Type.String].
  //
  // Used in Diff and Assert.
  //
  // For example:
  //
  //	Assert(t, AnythingOfType("string"), AnythingOfType("int"))
  func AnythingOfType(t string) AnythingOfTypeArgument {
  	return AnythingOfTypeArgument(t)
  }
  // IsTypeArgument holds the reflect.Type an argument must have. It is an
  // alternative to AnythingOfType. Used in Diff and Assert.
  type IsTypeArgument struct {
  	t reflect.Type
  }

  // IsType returns an IsTypeArgument for the dynamic type of t, so a zero
  // value of the wanted type is enough. Passing nil yields a nil type. This is
  // an alternative to AnythingOfType. Used in Diff and Assert.
  //
  // For example:
  //
  //	Assert(t, IsType(""), IsType(0))
  func IsType(t interface{}) *IsTypeArgument {
  	typ := reflect.TypeOf(t)
  	return &IsTypeArgument{t: typ}
  }
  // FunctionalOptionsArgument is a struct that contains the type and value of a
  // functional option argument for use when type checking.
  type FunctionalOptionsArgument struct {
  	value interface{}
  }

  // String renders the options with %#v and replaces the generic
  // "[]interface {}" prefix with a slice of the first option's concrete type.
  // For example, FunctionalOptions(1, 2) renders as "[]int{1, 2}".
  //
  // A nil receiver or a zero value renders as "<nil>". When the first option
  // is an untyped nil, the generic prefix is kept.
  func (f *FunctionalOptionsArgument) String() string {
  	if f == nil {
  		return "<nil>"
  	}
  	rendered := fmt.Sprintf("%#v", f.value)
  	options := reflect.ValueOf(f.value)
  	if options.Kind() != reflect.Slice {
  		// Zero value (not built by FunctionalOptions): there is no slice to
  		// inspect, and calling Len on it would panic.
  		return rendered
  	}
  	var name string
  	if options.Len() > 0 {
  		first := options.Index(0).Interface()
  		if first == nil {
  			// reflect.TypeOf(nil) is nil; there is no concrete type to report.
  			return rendered
  		}
  		name = "[]" + reflect.TypeOf(first).String()
  	}
  	return strings.Replace(rendered, "[]interface {}", name, 1)
  }

  // FunctionalOptions returns a FunctionalOptionsArgument holding the
  // expected functional option values to check against.
  //
  // For example:
  //
  //	Assert(t, FunctionalOptions(foo.Opt1(), foo.Opt2()))
  func FunctionalOptions(value ...interface{}) *FunctionalOptionsArgument {
  	return &FunctionalOptionsArgument{
  		value: value,
  	}
  }
  // argumentMatcher performs custom argument matching, returning whether or
  // not the argument is matched by the expectation fixture function.
  type argumentMatcher struct {
  	// fn is a function which accepts one argument, and returns a bool.
  	fn reflect.Value
  }

  // Matches calls fn with argument and returns its result. An argument whose
  // type is not assignable to fn's parameter type never matches.
  //
  // A nil argument is passed as the zero value of the parameter type when
  // that type can be nil (interface, chan, func, map, slice, pointer). For any
  // other parameter type, a nil argument causes a panic.
  func (f argumentMatcher) Matches(argument interface{}) bool {
  	paramType := f.fn.Type().In(0)

  	if argType := reflect.TypeOf(argument); argType != nil {
  		if !argType.AssignableTo(paramType) {
  			return false
  		}
  		out := f.fn.Call([]reflect.Value{reflect.ValueOf(argument)})
  		return out[0].Bool()
  	}

  	switch paramType.Kind() {
  	case reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Slice, reflect.Ptr:
  		out := f.fn.Call([]reflect.Value{reflect.New(paramType).Elem()})
  		return out[0].Bool()
  	default:
  		panic(errors.New("attempting to call matcher with nil for non-nil expected type"))
  	}
  }

  // String describes the matcher by its signature, e.g. "func(int) bool".
  func (f argumentMatcher) String() string {
  	return "func(" + f.fn.Type().In(0).String() + ") bool"
  }
  744. // MatchedBy can be used to match a mock call based on only certain properties
  745. // from a complex struct or some calculation. It takes a function that will be
  746. // evaluated with the called argument and will return true when there's a match
  747. // and false otherwise.
  748. //
  749. // Example:
  750. // m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" }))
  751. //
  752. // |fn|, must be a function accepting a single argument (of the expected type)
  753. // which returns a bool. If |fn| doesn't match the required signature,
  754. // MatchedBy() panics.
  755. func MatchedBy(fn interface{}) argumentMatcher {
  756. fnType := reflect.TypeOf(fn)
  757. if fnType.Kind() != reflect.Func {
  758. panic(fmt.Sprintf("assert: arguments: %s is not a func", fn))
  759. }
  760. if fnType.NumIn() != 1 {
  761. panic(fmt.Sprintf("assert: arguments: %s does not take exactly one argument", fn))
  762. }
  763. if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool {
  764. panic(fmt.Sprintf("assert: arguments: %s does not return a bool", fn))
  765. }
  766. return argumentMatcher{fn: reflect.ValueOf(fn)}
  767. }
  768. // Get Returns the argument at the specified index.
  769. func (args Arguments) Get(index int) interface{} {
  770. if index+1 > len(args) {
  771. panic(fmt.Sprintf("assert: arguments: Cannot call Get(%d) because there are %d argument(s).", index, len(args)))
  772. }
  773. return args[index]
  774. }
  775. // Is gets whether the objects match the arguments specified.
  776. func (args Arguments) Is(objects ...interface{}) bool {
  777. for i, obj := range args {
  778. if obj != objects[i] {
  779. return false
  780. }
  781. }
  782. return true
  783. }
  784. // Diff gets a string describing the differences between the arguments
  785. // and the specified objects.
  786. //
  787. // Returns the diff string and number of differences found.
  788. func (args Arguments) Diff(objects []interface{}) (string, int) {
  789. // TODO: could return string as error and nil for No difference
  790. output := "\n"
  791. var differences int
  792. maxArgCount := len(args)
  793. if len(objects) > maxArgCount {
  794. maxArgCount = len(objects)
  795. }
  796. for i := 0; i < maxArgCount; i++ {
  797. var actual, expected interface{}
  798. var actualFmt, expectedFmt string
  799. if len(objects) <= i {
  800. actual = "(Missing)"
  801. actualFmt = "(Missing)"
  802. } else {
  803. actual = objects[i]
  804. actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual)
  805. }
  806. if len(args) <= i {
  807. expected = "(Missing)"
  808. expectedFmt = "(Missing)"
  809. } else {
  810. expected = args[i]
  811. expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected)
  812. }
  813. if matcher, ok := expected.(argumentMatcher); ok {
  814. var matches bool
  815. func() {
  816. defer func() {
  817. if r := recover(); r != nil {
  818. actualFmt = fmt.Sprintf("panic in argument matcher: %v", r)
  819. }
  820. }()
  821. matches = matcher.Matches(actual)
  822. }()
  823. if matches {
  824. output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher)
  825. } else {
  826. differences++
  827. output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher)
  828. }
  829. } else {
  830. switch expected := expected.(type) {
  831. case anythingOfTypeArgument:
  832. // type checking
  833. if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) {
  834. // not match
  835. differences++
  836. output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt)
  837. }
  838. case *IsTypeArgument:
  839. actualT := reflect.TypeOf(actual)
  840. if actualT != expected.t {
  841. differences++
  842. output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected.t.Name(), actualT.Name(), actualFmt)
  843. }
  844. case *FunctionalOptionsArgument:
  845. t := expected.value
  846. var name string
  847. tValue := reflect.ValueOf(t)
  848. if tValue.Len() > 0 {
  849. name = "[]" + reflect.TypeOf(tValue.Index(0).Interface()).String()
  850. }
  851. tName := reflect.TypeOf(t).Name()
  852. if name != reflect.TypeOf(actual).String() && tValue.Len() != 0 {
  853. differences++
  854. output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, tName, reflect.TypeOf(actual).Name(), actualFmt)
  855. } else {
  856. if ef, af := assertOpts(t, actual); ef == "" && af == "" {
  857. // match
  858. output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, tName, tName)
  859. } else {
  860. // not match
  861. differences++
  862. output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, af, ef)
  863. }
  864. }
  865. default:
  866. if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
  867. // match
  868. output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt)
  869. } else {
  870. // not match
  871. differences++
  872. output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, actualFmt, expectedFmt)
  873. }
  874. }
  875. }
  876. }
  877. if differences == 0 {
  878. return "No differences.", differences
  879. }
  880. return output, differences
  881. }
  882. // Assert compares the arguments with the specified objects and fails if
  883. // they do not exactly match.
  884. func (args Arguments) Assert(t TestingT, objects ...interface{}) bool {
  885. if h, ok := t.(tHelper); ok {
  886. h.Helper()
  887. }
  888. // get the differences
  889. diff, diffCount := args.Diff(objects)
  890. if diffCount == 0 {
  891. return true
  892. }
  893. // there are differences... report them...
  894. t.Logf(diff)
  895. t.Errorf("%sArguments do not match.", assert.CallerInfo())
  896. return false
  897. }
  898. // String gets the argument at the specified index. Panics if there is no argument, or
  899. // if the argument is of the wrong type.
  900. //
  901. // If no index is provided, String() returns a complete string representation
  902. // of the arguments.
  903. func (args Arguments) String(indexOrNil ...int) string {
  904. if len(indexOrNil) == 0 {
  905. // normal String() method - return a string representation of the args
  906. var argsStr []string
  907. for _, arg := range args {
  908. argsStr = append(argsStr, fmt.Sprintf("%T", arg)) // handles nil nicely
  909. }
  910. return strings.Join(argsStr, ",")
  911. } else if len(indexOrNil) == 1 {
  912. // Index has been specified - get the argument at that index
  913. index := indexOrNil[0]
  914. var s string
  915. var ok bool
  916. if s, ok = args.Get(index).(string); !ok {
  917. panic(fmt.Sprintf("assert: arguments: String(%d) failed because object wasn't correct type: %s", index, args.Get(index)))
  918. }
  919. return s
  920. }
  921. panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String. Must be 0 or 1, not %d", len(indexOrNil)))
  922. }
  923. // Int gets the argument at the specified index. Panics if there is no argument, or
  924. // if the argument is of the wrong type.
  925. func (args Arguments) Int(index int) int {
  926. var s int
  927. var ok bool
  928. if s, ok = args.Get(index).(int); !ok {
  929. panic(fmt.Sprintf("assert: arguments: Int(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
  930. }
  931. return s
  932. }
  933. // Error gets the argument at the specified index. Panics if there is no argument, or
  934. // if the argument is of the wrong type.
  935. func (args Arguments) Error(index int) error {
  936. obj := args.Get(index)
  937. var s error
  938. var ok bool
  939. if obj == nil {
  940. return nil
  941. }
  942. if s, ok = obj.(error); !ok {
  943. panic(fmt.Sprintf("assert: arguments: Error(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
  944. }
  945. return s
  946. }
  947. // Bool gets the argument at the specified index. Panics if there is no argument, or
  948. // if the argument is of the wrong type.
  949. func (args Arguments) Bool(index int) bool {
  950. var s bool
  951. var ok bool
  952. if s, ok = args.Get(index).(bool); !ok {
  953. panic(fmt.Sprintf("assert: arguments: Bool(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
  954. }
  955. return s
  956. }
  957. func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) {
  958. t := reflect.TypeOf(v)
  959. k := t.Kind()
  960. if k == reflect.Ptr {
  961. t = t.Elem()
  962. k = t.Kind()
  963. }
  964. return t, k
  965. }
  966. func diffArguments(expected Arguments, actual Arguments) string {
  967. if len(expected) != len(actual) {
  968. return fmt.Sprintf("Provided %v arguments, mocked for %v arguments", len(expected), len(actual))
  969. }
  970. for x := range expected {
  971. if diffString := diff(expected[x], actual[x]); diffString != "" {
  972. return fmt.Sprintf("Difference found in argument %v:\n\n%s", x, diffString)
  973. }
  974. }
  975. return ""
  976. }
  977. // diff returns a diff of both values as long as both are of the same type and
  978. // are a struct, map, slice or array. Otherwise it returns an empty string.
  979. func diff(expected interface{}, actual interface{}) string {
  980. if expected == nil || actual == nil {
  981. return ""
  982. }
  983. et, ek := typeAndKind(expected)
  984. at, _ := typeAndKind(actual)
  985. if et != at {
  986. return ""
  987. }
  988. if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array {
  989. return ""
  990. }
  991. e := spewConfig.Sdump(expected)
  992. a := spewConfig.Sdump(actual)
  993. diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
  994. A: difflib.SplitLines(e),
  995. B: difflib.SplitLines(a),
  996. FromFile: "Expected",
  997. FromDate: "",
  998. ToFile: "Actual",
  999. ToDate: "",
  1000. Context: 1,
  1001. })
  1002. return diff
  1003. }
  1004. var spewConfig = spew.ConfigState{
  1005. Indent: " ",
  1006. DisablePointerAddresses: true,
  1007. DisableCapacities: true,
  1008. SortKeys: true,
  1009. }
  1010. type tHelper interface {
  1011. Helper()
  1012. }
  1013. func assertOpts(expected, actual interface{}) (expectedFmt, actualFmt string) {
  1014. expectedOpts := reflect.ValueOf(expected)
  1015. actualOpts := reflect.ValueOf(actual)
  1016. var expectedNames []string
  1017. for i := 0; i < expectedOpts.Len(); i++ {
  1018. expectedNames = append(expectedNames, funcName(expectedOpts.Index(i).Interface()))
  1019. }
  1020. var actualNames []string
  1021. for i := 0; i < actualOpts.Len(); i++ {
  1022. actualNames = append(actualNames, funcName(actualOpts.Index(i).Interface()))
  1023. }
  1024. if !assert.ObjectsAreEqual(expectedNames, actualNames) {
  1025. expectedFmt = fmt.Sprintf("%v", expectedNames)
  1026. actualFmt = fmt.Sprintf("%v", actualNames)
  1027. return
  1028. }
  1029. for i := 0; i < expectedOpts.Len(); i++ {
  1030. expectedOpt := expectedOpts.Index(i).Interface()
  1031. actualOpt := actualOpts.Index(i).Interface()
  1032. expectedFunc := expectedNames[i]
  1033. actualFunc := actualNames[i]
  1034. if expectedFunc != actualFunc {
  1035. expectedFmt = expectedFunc
  1036. actualFmt = actualFunc
  1037. return
  1038. }
  1039. ot := reflect.TypeOf(expectedOpt)
  1040. var expectedValues []reflect.Value
  1041. var actualValues []reflect.Value
  1042. if ot.NumIn() == 0 {
  1043. return
  1044. }
  1045. for i := 0; i < ot.NumIn(); i++ {
  1046. vt := ot.In(i).Elem()
  1047. expectedValues = append(expectedValues, reflect.New(vt))
  1048. actualValues = append(actualValues, reflect.New(vt))
  1049. }
  1050. reflect.ValueOf(expectedOpt).Call(expectedValues)
  1051. reflect.ValueOf(actualOpt).Call(actualValues)
  1052. for i := 0; i < ot.NumIn(); i++ {
  1053. if !assert.ObjectsAreEqual(expectedValues[i].Interface(), actualValues[i].Interface()) {
  1054. expectedFmt = fmt.Sprintf("%s %+v", expectedNames[i], expectedValues[i].Interface())
  1055. actualFmt = fmt.Sprintf("%s %+v", expectedNames[i], actualValues[i].Interface())
  1056. return
  1057. }
  1058. }
  1059. }
  1060. return "", ""
  1061. }
  1062. func funcName(opt interface{}) string {
  1063. n := runtime.FuncForPC(reflect.ValueOf(opt).Pointer()).Name()
  1064. return strings.TrimSuffix(path.Base(n), path.Ext(n))
  1065. }