Skip to content

Commit ff64d28

Browse files
committed
Add concurrency to mdformatter
Signed-off-by: Saswata Mukherjee <saswataminsta@yahoo.com>
1 parent f49414c commit ff64d28

3 files changed

Lines changed: 71 additions & 31 deletions

File tree

pkg/mdformatter/linktransformer/link.go

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ type validator struct {
125125
c *colly.Collector
126126

127127
futureMu sync.Mutex
128-
destFutures map[futureKey]*futureResult
128+
destFutures sync.Map
129129
}
130130

131131
type futureKey struct {
@@ -156,14 +156,13 @@ func NewValidator(ctx context.Context, logger log.Logger, linksValidateConfig []
156156
localLinks: map[string]*[]string{},
157157
remoteLinks: map[string]error{},
158158
c: colly.NewCollector(colly.Async(), colly.StdlibContext(ctx)),
159-
destFutures: map[futureKey]*futureResult{},
160159
}
161160
// Set very soft limits.
162161
// E.g github has 50-5000 https://docs.github.com/en/free-pro-team@latest/rest/reference/rate-limit limit depending
163162
// on api (only search is below 100).
164163
if err := v.c.Limit(&colly.LimitRule{
165164
DomainGlob: "*",
166-
Parallelism: 100,
165+
Parallelism: 5,
167166
}); err != nil {
168167
return nil, err
169168
}
@@ -234,6 +233,11 @@ func MustNewValidator(logger log.Logger, linksValidateConfig []byte, anchorDir s
234233
}
235234

236235
func (v *validator) TransformDestination(ctx mdformatter.SourceContext, destination []byte) (_ []byte, err error) {
236+
select {
237+
case <-ctx.Context.Done():
238+
return nil, ctx.Err()
239+
default:
240+
}
237241
v.visit(ctx.Filepath, string(destination), ctx.LineNumbers)
238242
return destination, nil
239243
}
@@ -242,7 +246,14 @@ func (v *validator) Close(ctx mdformatter.SourceContext) error {
242246
v.c.Wait()
243247

244248
var keys []futureKey
245-
for k := range v.destFutures {
249+
// Read map from sync.Map.
250+
destFuturesMap := map[futureKey]*futureResult{}
251+
v.destFutures.Range(func(key, value interface{}) bool {
252+
destFuturesMap[key.(futureKey)] = value.(*futureResult)
253+
return true
254+
})
255+
256+
for k := range destFuturesMap {
246257
if k.filepath != ctx.Filepath {
247258
continue
248259
}
@@ -263,7 +274,7 @@ func (v *validator) Close(ctx mdformatter.SourceContext) error {
263274
}
264275

265276
for _, k := range keys {
266-
f := v.destFutures[k]
277+
f := destFuturesMap[k]
267278
if err := f.resultFn(); err != nil {
268279
if f.cases == 1 {
269280
merr.Add(errors.Wrapf(err, "%v:%v", path, k.lineNumbers))
@@ -279,19 +290,27 @@ func (v *validator) visit(filepath string, dest string, lineNumbers string) {
279290
v.futureMu.Lock()
280291
defer v.futureMu.Unlock()
281292
k := futureKey{filepath: filepath, dest: dest, lineNumbers: lineNumbers}
282-
if _, ok := v.destFutures[k]; ok {
283-
v.destFutures[k].cases++
293+
// If key present, delete and increment cases.
294+
if prevResult, loaded := v.destFutures.LoadAndDelete(k); loaded {
295+
newResult := prevResult.(*futureResult)
296+
newResult.cases++
297+
v.destFutures.Store(k, newResult)
284298
return
285299
}
286-
v.destFutures[k] = &futureResult{cases: 1, resultFn: func() error { return nil }}
300+
301+
// Key not present, no store.
302+
v.destFutures.Store(k, &futureResult{cases: 1, resultFn: func() error { return nil }})
287303
matches := remoteLinkPrefixRe.FindAllStringIndex(dest, 1)
288304
if matches == nil {
289305
// Relative or absolute path. Check if exists.
290306
newDest := absLocalLink(v.anchorDir, filepath, dest)
291307

292308
// Local link. Check if exists.
293309
if err := v.localLinks.Lookup(newDest); err != nil {
294-
v.destFutures[k].resultFn = func() error { return errors.Wrapf(err, "link %v, normalized to", dest) }
310+
prevResult, _ := v.destFutures.LoadAndDelete(k)
311+
newResult := prevResult.(*futureResult)
312+
newResult.resultFn = func() error { return errors.Wrapf(err, "link %v, normalized to", dest) }
313+
v.destFutures.Store(k, newResult)
295314
}
296315
return
297316
}

pkg/mdformatter/linktransformer/validator.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,11 @@ func (v GitHubValidator) IsValid(k futureKey, r *validator) (bool, error) {
3434
// RoundTripValidator.IsValid returns true if url is checked by colly.
3535
func (v RoundTripValidator) IsValid(k futureKey, r *validator) (bool, error) {
3636
// Result will be in future.
37-
r.destFutures[k].resultFn = func() error { return r.remoteLinks[k.dest] }
37+
prevResult, _ := r.destFutures.LoadAndDelete(k)
38+
newResult := prevResult.(*futureResult)
39+
newResult.resultFn = func() error { return r.remoteLinks[k.dest] }
40+
r.destFutures.Store(k, newResult)
41+
3842
r.rMu.RLock()
3943
if _, ok := r.remoteLinks[k.dest]; ok {
4044
r.rMu.RUnlock()

pkg/mdformatter/mdformatter.go

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"io/ioutil"
1111
"os"
1212
"sort"
13+
"sync"
1314
"time"
1415

1516
"github.com/Kunde21/markdownfmt/v2/markdown"
@@ -217,7 +218,7 @@ func newSpinner(suffix string) (*yacspin.Spinner, error) {
217218

218219
// Format formats given markdown files in-place. IsFormatted `With...` function to see what modifiers you can add.
219220
func Format(ctx context.Context, logger log.Logger, files []string, opts ...Option) error {
220-
spin, err := newSpinner(" Formatting: ")
221+
spin, err := newSpinner(" Formatting... ")
221222
if err != nil {
222223
return err
223224
}
@@ -228,7 +229,7 @@ func Format(ctx context.Context, logger log.Logger, files []string, opts ...Opti
228229
// If diff is empty it means all files are formatted.
229230
func IsFormatted(ctx context.Context, logger log.Logger, files []string, opts ...Option) (diffs Diffs, err error) {
230231
d := Diffs{}
231-
spin, err := newSpinner(" Checking: ")
232+
spin, err := newSpinner(" Checking... ")
232233
if err != nil {
233234
return nil, err
234235
}
@@ -240,57 +241,73 @@ func IsFormatted(ctx context.Context, logger log.Logger, files []string, opts ..
240241

241242
func format(ctx context.Context, logger log.Logger, files []string, diffs *Diffs, spin *yacspin.Spinner, opts ...Option) error {
242243
f := New(ctx, opts...)
243-
b := bytes.Buffer{}
244-
// TODO(bwplotka): Add concurrency (collector will need to redone).
244+
errorChannel := make(chan error)
245+
var wg sync.WaitGroup
245246

246247
errs := merrors.New()
247248
if spin != nil {
248249
errs.Add(spin.Start())
249250
}
251+
252+
wg.Add(len(files))
253+
254+
go func() {
255+
wg.Wait()
256+
close(errorChannel)
257+
}()
258+
250259
for _, fn := range files {
251-
select {
252-
case <-ctx.Done():
253-
return ctx.Err()
254-
default:
255-
}
256-
if spin != nil {
257-
spin.Message(fn + "...")
258-
}
259-
errs.Add(func() error {
260+
go func(fn string) {
261+
defer wg.Done()
262+
b := bytes.Buffer{}
263+
260264
file, err := os.OpenFile(fn, os.O_RDWR, 0)
261265
if err != nil {
262-
return errors.Wrapf(err, "open %v", fn)
266+
errorChannel <- errors.Wrapf(err, "open %v", fn)
267+
return
263268
}
264269
defer logerrcapture.ExhaustClose(logger, file, "close file %v", fn)
265270

266271
b.Reset()
267272
if err := f.Format(file, &b); err != nil {
268-
return err
273+
errorChannel <- err
274+
return
269275
}
270276

271277
if diffs != nil {
272278
if _, err := file.Seek(0, 0); err != nil {
273-
return err
279+
errorChannel <- err
280+
return
274281
}
275282

276283
in, err := ioutil.ReadAll(file)
277284
if err != nil {
278-
return errors.Wrapf(err, "read all %v", fn)
285+
errorChannel <- errors.Wrapf(err, "read all %v", fn)
286+
return
279287
}
280288

281289
if !bytes.Equal(in, b.Bytes()) {
282290
*diffs = append(*diffs, gitdiff.CompareBytes(in, fn, b.Bytes(), fn+" (formatted)"))
283291
}
284-
return nil
292+
return
285293
}
286294

287295
n, err := file.WriteAt(b.Bytes(), 0)
288296
if err != nil {
289-
return errors.Wrapf(err, "write %v", fn)
297+
errorChannel <- errors.Wrapf(err, "write %v", fn)
298+
return
299+
}
300+
if err := file.Truncate(int64(n)); err != nil {
301+
errorChannel <- err
302+
return
290303
}
291-
return file.Truncate(int64(n))
292-
}())
304+
}(fn)
293305
}
306+
307+
for err := range errorChannel {
308+
errs.Add(err)
309+
}
310+
294311
if spin != nil {
295312
errs.Add(spin.Stop())
296313
}

0 commit comments

Comments
 (0)