util.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. package tos
  2. import (
  3. "fmt"
  4. "os"
  5. "time"
  6. )
  // min returns the smaller of a and b. This package-level declaration
  // shadows the built-in min of Go 1.21+ inside package tos.
  func min(a, b int) int {
  	if b <= a {
  		return b
  	}
  	return a
  }
  // Per-task event types passed to postEvent.PostEvent by the task group.
  const (
  	// EventPartSucceed is posted after a task result has been recorded in the
  	// checkpoint.
  	EventPartSucceed = 3
  	// EventPartFailed is posted for a task error that does not interrupt the
  	// rest of the group.
  	EventPartFailed = 4
  	// EventPartAborted is posted when a task fails with status code 403, 404
  	// or 405; the whole task group is interrupted in that case.
  	EventPartAborted = 5
  )
  // task is one unit of work executed by a task-group worker.
  type task interface {
  	// getBaseInput exposes the input the task operates on. It is not used in
  	// this file; presumably consumed by callers (confirm).
  	getBaseInput() interface{}
  	// do executes the task, returning either a result or an error.
  	do() (interface{}, error)
  }
  // checkPoint is the resumable-progress record a task group maintains.
  type checkPoint interface {
  	// UpdatePartsInfo records the result of a successfully completed task.
  	UpdatePartsInfo(part interface{})
  	// GetCheckPointFilePath returns the location of the persisted checkpoint.
  	GetCheckPointFilePath() string
  	// WriteToFile persists the checkpoint to GetCheckPointFilePath.
  	WriteToFile() error
  }
  // taskGroup runs a fixed set of tasks on a pool of workers.
  type taskGroup interface {
  	// RunWorker starts the workers.
  	RunWorker()
  	// Scheduler dispatches the tasks to the workers.
  	Scheduler()
  	// Wait waits for the execution results; success is the number of tasks
  	// that succeeded in this run.
  	Wait() (success int, err error)
  }
  // postEvent receives per-task notifications from a task group. eventType is
  // one of EventPartSucceed, EventPartFailed or EventPartAborted; result is
  // set for successes and taskErr for failures and aborts.
  type postEvent interface {
  	PostEvent(eventType int, result interface{}, taskErr error)
  }
  38. type taskGroupImpl struct {
  39. cancelHandle chan struct{}
  40. abortHandle chan struct{}
  41. errCh chan error
  42. resultsCh chan interface{}
  43. tasksCh chan task
  44. routinesNum int
  45. tasks []task
  46. checkPoint checkPoint
  47. enableCheckPoint bool
  48. postEvent postEvent
  49. }
  50. func (t *taskGroupImpl) Wait() (int, error) {
  51. successNum := 0
  52. failNum := 0
  53. Loop:
  54. for successNum+failNum < len(t.tasks) {
  55. select {
  56. case <-t.abortHandle:
  57. break Loop
  58. case <-t.cancelHandle:
  59. break Loop
  60. case part := <-t.resultsCh:
  61. successNum++
  62. t.checkPoint.UpdatePartsInfo(part)
  63. if t.enableCheckPoint {
  64. t.checkPoint.WriteToFile()
  65. }
  66. t.postEvent.PostEvent(EventPartSucceed, part, nil)
  67. case taskErr := <-t.errCh:
  68. if StatusCode(taskErr) == 403 || StatusCode(taskErr) == 404 || StatusCode(taskErr) == 405 {
  69. close(t.abortHandle)
  70. _ = os.Remove(t.checkPoint.GetCheckPointFilePath())
  71. t.postEvent.PostEvent(EventPartAborted, nil, taskErr)
  72. return successNum, fmt.Errorf("status code not service error, err:%s. ", taskErr.Error())
  73. }
  74. t.postEvent.PostEvent(EventPartFailed, nil, taskErr)
  75. failNum++
  76. }
  77. }
  78. return successNum, nil
  79. }
  80. func newTaskGroup(cancelHandle chan struct{}, routinesNum int, checkPoint checkPoint, postEvent postEvent, enableCheckPoint bool, tasks []task) taskGroup {
  81. taskBufferSize := min(routinesNum, DefaultTaskBufferSize)
  82. tasksCh := make(chan task, taskBufferSize)
  83. return &taskGroupImpl{
  84. cancelHandle: cancelHandle,
  85. abortHandle: make(chan struct{}),
  86. errCh: make(chan error),
  87. resultsCh: make(chan interface{}),
  88. tasksCh: tasksCh,
  89. routinesNum: routinesNum,
  90. tasks: tasks,
  91. checkPoint: checkPoint,
  92. enableCheckPoint: enableCheckPoint,
  93. postEvent: postEvent,
  94. }
  95. }
  96. func (t *taskGroupImpl) RunWorker() {
  97. for i := 0; i < t.routinesNum; i++ {
  98. go t.worker()
  99. }
  100. }
  101. func (t *taskGroupImpl) Scheduler() {
  102. go func() {
  103. for _, task := range t.tasks {
  104. select {
  105. case <-t.cancelHandle:
  106. return
  107. case <-t.abortHandle:
  108. return
  109. default:
  110. t.tasksCh <- task
  111. }
  112. }
  113. close(t.tasksCh)
  114. }()
  115. }
  116. func (t *taskGroupImpl) worker() {
  117. for {
  118. select {
  119. case <-t.cancelHandle:
  120. return
  121. case <-t.abortHandle:
  122. return
  123. case task, ok := <-t.tasksCh:
  124. if !ok {
  125. return
  126. }
  127. result, err := task.do()
  128. if err != nil {
  129. t.errCh <- err
  130. }
  131. if result != nil {
  132. t.resultsCh <- result
  133. }
  134. }
  135. }
  136. }
  // GetUnixTimeMs returns the current wall-clock time as milliseconds since
  // the Unix epoch.
  func GetUnixTimeMs() int64 {
  	now := time.Now()
  	return ToMillis(now)
  }

  // ToMillis converts t to milliseconds since the Unix epoch. Any
  // sub-millisecond remainder is truncated toward zero. Like
  // time.Time.UnixNano, the result is undefined for instants whose nanosecond
  // count does not fit in an int64.
  func ToMillis(t time.Time) int64 {
  	const nanosPerMilli = int64(time.Millisecond)
  	return t.UnixNano() / nanosPerMilli
  }
  // StringPtr returns a pointer to a copy of input. Every call yields a
  // distinct pointer, so writes through one never affect another.
  func StringPtr(input string) *string {
  	p := new(string)
  	*p = input
  	return p
  }