Skip to content

Commit 087ab44

Browse files
committed
use ContextDialer isof DialContextFn
1 parent e0a97d7 commit 087ab44

6 files changed

Lines changed: 52 additions & 15 deletions

File tree

contextdialer.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
package bwlimit
2+
3+
import (
4+
"context"
5+
"net"
6+
)
7+
8+
type ContextDialer interface {
9+
DialContext(ctx context.Context, network, address string) (conn net.Conn, err error)
10+
}

dialer.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ import (
66
)
77

88
type Dialer struct {
9-
DialContextFn // DialContext function we wrap
9+
ContextDialer // ContextDialer we wrap
1010
*Limiter // Limiter to use
1111
}
1212

1313
func (d *Dialer) DialContext(ctx context.Context, network, address string) (conn net.Conn, err error) {
14-
if conn, err = d.DialContextFn(ctx, network, address); err == nil {
14+
if conn, err = d.ContextDialer.DialContext(ctx, network, address); err == nil {
1515
conn = &Conn{
1616
Conn: conn,
1717
Limiter: d.Limiter,

dialer_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@ func TestDialer_Dial(t *testing.T) {
1313
defer l2.Stop()
1414

1515
d1 := &Dialer{
16-
DialContextFn: l1.Wrap(nil),
1716
Limiter: l1,
17+
ContextDialer: l1.Wrap(nil),
1818
}
1919
d2 := &Dialer{
20-
DialContextFn: l2.Wrap(d1.DialContext),
2120
Limiter: l2,
21+
ContextDialer: l2.Wrap(d1),
2222
}
2323

2424
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

example_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func ExampleLimiter_NewLimiter() {
2323

2424
// wrap the default http transport DialContext
2525
tp := http.DefaultTransport.(*http.Transport)
26-
tp.DialContext = lim.Wrap(tp.DialContext)
26+
tp.DialContext = lim.Wrap(nil).DialContext
2727

2828
// make a request and time it
2929
now := time.Now()

limiter.go

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
package bwlimit
22

33
import (
4-
"context"
54
"net"
65
)
76

8-
type DialContextFn func(ctx context.Context, network string, address string) (net.Conn, error)
9-
107
var DefaultNetDialer = &net.Dialer{}
118

129
type Limiter struct {
@@ -32,12 +29,31 @@ func (l *Limiter) Stop() {
3229
l.Writes.Stop()
3330
}
3431

35-
// Wrap returns a DialContextFn using the given fn that is bandwidth limited by this Limiter.
36-
// If fn is nil we use DefaultNetDialer.DialContext.
37-
func (l *Limiter) Wrap(fn DialContextFn) DialContextFn {
38-
if fn == nil {
39-
fn = DefaultNetDialer.DialContext
32+
// alreadyLimits returns true if cd is already limited by this Limiter.
33+
// This lets us help the user avoiding double-accounting bandwidth.
34+
func (l *Limiter) alreadyLimits(cd ContextDialer) bool {
35+
for {
36+
if d, ok := cd.(*Dialer); ok {
37+
if d.Limiter == l {
38+
return true
39+
}
40+
cd = d.ContextDialer
41+
} else {
42+
return false
43+
}
44+
}
45+
}
46+
47+
// Wrap returns a ContextDialer wrapping cd that is bandwidth limited by this Limiter.
48+
//
49+
// If cd is nil we use DefaultNetDialer. If cd is already limited by this Limiter, cd
50+
// is returned unchanged.
51+
func (l *Limiter) Wrap(cd ContextDialer) ContextDialer {
52+
if cd == nil {
53+
cd = DefaultNetDialer
54+
}
55+
if l.alreadyLimits(cd) {
56+
return cd
4057
}
41-
d := &Dialer{DialContextFn: fn, Limiter: l}
42-
return d.DialContext
58+
return &Dialer{ContextDialer: cd, Limiter: l}
4359
}

limiter_test.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,14 @@ func TestLimiter_Stop(t *testing.T) {
3131
t.Error(n)
3232
}
3333
}
34+
35+
func TestLimiter_double_Wrap(t *testing.T) {
36+
l := NewLimiter()
37+
defer l.Stop()
38+
39+
d1 := l.Wrap(nil)
40+
d2 := l.Wrap(d1)
41+
if d1 != d2 {
42+
t.Error(d1, d2)
43+
}
44+
}

0 commit comments

Comments
 (0)