@@ -14,6 +14,7 @@ import (
1414 "github.com/libp2p/go-libp2p/core/host"
1515 "github.com/libp2p/go-libp2p/core/peer"
1616 "github.com/rs/zerolog"
17+ "github.com/sourcegraph/conc/pool"
1718 "github.com/sprintertech/sprinter-signing/comm"
1819 "github.com/sprintertech/sprinter-signing/tss/message"
1920)
@@ -57,15 +58,17 @@ func (b *BaseTss) ProcessInboundMessages(ctx context.Context, msgChan chan *comm
5758
5859 for {
5960 select {
60- case wMsg := <- msgChan :
61+ case msg := <- msgChan :
6162 {
62- go func (wMsg * comm.WrappedMessage ) {
63+ wMsg := msg
64+ p := pool .New ().WithContext (ctx ).WithCancelOnError ()
65+ p .Go (func (ctx context.Context ) error {
6366 b .Log .Debug ().Msgf ("Processed inbound message from %s" , wMsg .From )
6467
6568 msg , err := message .UnmarshalTssMessage (wMsg .Payload )
6669 if err != nil {
6770 b .Log .Error ().Err (err ).Msgf ("Failed unmarshaling message from %s" , wMsg .From )
68- return
71+ return err
6972 }
7073
7174 ok , err := b .Party .UpdateFromBytes (
@@ -75,10 +78,11 @@ func (b *BaseTss) ProcessInboundMessages(ctx context.Context, msgChan chan *comm
7578 new (big.Int ).SetBytes ([]byte (b .SID )))
7679 if ! ok {
7780 b .Log .Error ().Err (err ).Msgf ("Failed updating party with message from %s" , wMsg .From )
78- return
81+ return err
7982 }
8083 b .Log .Debug ().Msgf ("Updated party with message from %s" , wMsg .From )
81- }(wMsg )
84+ return nil
85+ })
8286 }
8387 case <- ctx .Done ():
8488 return nil
@@ -93,33 +97,37 @@ func (b *BaseTss) ProcessOutboundMessages(ctx context.Context, outChn chan tss.M
9397 select {
9498 case msg := <- outChn :
9599 {
96- go func (msg tss.Message ) {
97- b .Log .Debug ().Msg (msg .String ())
98- wireBytes , routing , err := msg .WireBytes ()
100+ wMsg := msg
101+ p := pool .New ().WithContext (ctx ).WithCancelOnError ()
102+ p .Go (func (ctx context.Context ) error {
103+ b .Log .Debug ().Msg (wMsg .String ())
104+ wireBytes , routing , err := wMsg .WireBytes ()
99105 if err != nil {
100106 b .Log .Error ().Err (err ).Msgf ("Failed getting wire bytes" )
101- return
107+ return err
102108 }
103109
104110 msgBytes , err := message .MarshalTssMessage (wireBytes , routing .IsBroadcast )
105111 if err != nil {
106112 b .Log .Error ().Err (err ).Msgf ("Failed marshaling message" )
107- return
113+ return err
108114 }
109115
110- peers , err := b .BroadcastPeers (msg )
116+ peers , err := b .BroadcastPeers (wMsg )
111117 if err != nil {
112118 b .Log .Error ().Err (err ).Msgf ("Failed getting broadcast peers" )
113- return
119+ return err
114120 }
115121
116122 b .Log .Debug ().Msgf ("Sending message to %s" , peers )
117123 err = b .Communication .Broadcast (peers , msgBytes , messageType , b .SessionID ())
118124 if err != nil {
119125 b .Log .Error ().Err (err ).Msgf ("Failed broadcasting message" )
120- return
126+ return err
121127 }
122- }(msg )
128+
129+ return nil
130+ })
123131 }
124132 case <- ctx .Done ():
125133 {
0 commit comments