Skip to content

Commit 501ea1d

Browse files
committed
perf: increase performance and fix potential concurrency issues
1 parent 48f27fe commit 501ea1d

1 file changed

Lines changed: 30 additions & 17 deletions

File tree

schedule.go

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
package schedule
22

3-
import "time"
3+
import (
4+
"sync"
5+
"sync/atomic"
6+
"time"
7+
)
48

59
// Task holds information about the running task and can be used to stop running tasks.
610
type Task struct {
711
stop chan struct{}
812
nextExecution time.Time
913
startedAt time.Time
10-
stopped bool
14+
stopped int32 // 0 means active, 1 means stopped
15+
once sync.Once
1116
}
1217

1318
// newTask creates a new Task.
@@ -35,7 +40,7 @@ func (s *Task) ExecutesIn() time.Duration {
3540

3641
// IsActive returns true if the scheduler is active.
3742
func (s *Task) IsActive() bool {
38-
return !s.stopped
43+
return atomic.LoadInt32(&s.stopped) == 0
3944
}
4045

4146
// Wait blocks until the scheduler is stopped.
@@ -46,26 +51,29 @@ func (s *Task) Wait() {
4651

4752
// Stop stops the scheduler.
4853
func (s *Task) Stop() {
49-
if s.stopped {
50-
return
51-
}
52-
53-
s.stopped = true
54-
close(s.stop)
54+
s.once.Do(func() {
55+
atomic.StoreInt32(&s.stopped, 1)
56+
close(s.stop)
57+
})
5558
}
5659

5760
// After executes the task after the given duration.
5861
// The function is non-blocking. If you want to wait for the task to be executed, use the Task.Wait method.
5962
func After(duration time.Duration, task func()) *Task {
6063
scheduler := newTask()
6164
scheduler.nextExecution = time.Now().Add(duration)
65+
timer := time.NewTimer(duration)
6266

6367
go func() {
6468
select {
65-
case <-time.After(duration):
69+
case <-timer.C:
6670
task()
6771
scheduler.Stop()
6872
case <-scheduler.stop:
73+
// If the task is stopped before the timer fires, stop the timer.
74+
if !timer.Stop() {
75+
<-timer.C // drain if necessary
76+
}
6977
return
7078
}
7179
}()
@@ -78,13 +86,21 @@ func After(duration time.Duration, task func()) *Task {
7886
func At(t time.Time, task func()) *Task {
7987
scheduler := newTask()
8088
scheduler.nextExecution = t
89+
d := time.Until(t)
90+
if d < 0 {
91+
d = 0
92+
}
93+
timer := time.NewTimer(d)
8194

8295
go func() {
8396
select {
84-
case <-time.After(time.Until(t)):
97+
case <-timer.C:
8598
task()
8699
scheduler.Stop()
87100
case <-scheduler.stop:
101+
if !timer.Stop() {
102+
<-timer.C
103+
}
88104
return
89105
}
90106
}()
@@ -97,23 +113,20 @@ func At(t time.Time, task func()) *Task {
97113
func Every(interval time.Duration, task func() bool) *Task {
98114
scheduler := newTask()
99115
scheduler.nextExecution = time.Now().Add(interval)
100-
101116
ticker := time.NewTicker(interval)
102117

103118
go func() {
104119
for {
105120
select {
106121
case <-ticker.C:
107-
res := task()
108-
if !res {
122+
if !task() {
109123
scheduler.Stop()
124+
ticker.Stop()
125+
return
110126
}
111-
112127
scheduler.nextExecution = time.Now().Add(interval)
113-
114128
case <-scheduler.stop:
115129
ticker.Stop()
116-
117130
return
118131
}
119132
}

0 commit comments

Comments
 (0)