1+ use futures:: { channel:: mpsc, SinkExt , StreamExt } ;
2+ use futures_util:: { io, AsyncRead , AsyncWrite } ;
3+ use futures_util:: stream:: Stream ;
4+ use gloo_net:: websocket:: { Message , futures:: WebSocket } ;
5+ use std:: pin:: Pin ;
6+ use std:: task:: { Context , Poll } ;
7+ //use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8+
9+ use crate :: utils;
10+
11+ pub struct WsIo {
12+ incoming : mpsc:: UnboundedReceiver < Message > ,
13+ outgoing : mpsc:: UnboundedSender < Message > ,
14+ read_buffer : Vec < u8 > ,
15+ }
16+
17+ impl WsIo {
18+ pub fn new ( ws : WebSocket ) -> Self {
19+ let ( outgoing, write) = mpsc:: unbounded ( ) ;
20+ let ( read, incoming) = mpsc:: unbounded ( ) ;
21+ let ( mut sink, mut source) = ws. split ( ) ;
22+ utils:: spawn ( async move {
23+ while let Some ( msg) = source. next ( ) . await {
24+ match msg {
25+ Ok ( Message :: Text ( _) ) => { } // Ignore text messages
26+ Ok ( Message :: Bytes ( data) ) => {
27+ let _ = read. unbounded_send ( Message :: Bytes ( data) ) ;
28+ }
29+ _ => break ,
30+ }
31+ }
32+ } ) ;
33+
34+ utils:: spawn ( async move {
35+ let mut outgoing = write;
36+ while let Some ( msg) = outgoing. next ( ) . await {
37+ let _ = sink. send ( msg) ;
38+ }
39+ let _ = sink. close ( ) ; // TODO test that this actualy works
40+ } ) ;
41+
42+ WsIo {
43+ incoming,
44+ outgoing,
45+ read_buffer : Vec :: new ( ) ,
46+ }
47+ }
48+ }
49+
50+ impl AsyncRead for WsIo {
51+ fn poll_read (
52+ mut self : Pin < & mut Self > ,
53+ cx : & mut Context < ' _ > ,
54+ buf : & mut [ u8 ] ,
55+ ) -> Poll < io:: Result < usize > > {
56+ if !self . read_buffer . is_empty ( ) {
57+ let len = std:: cmp:: min ( buf. len ( ) , self . read_buffer . len ( ) ) ;
58+ buf[ ..len] . copy_from_slice ( & self . read_buffer [ ..len] ) ;
59+ self . read_buffer . drain ( ..len) ;
60+ return Poll :: Ready ( Ok ( len) ) ;
61+ }
62+
63+ match Pin :: new ( & mut self . incoming ) . poll_next ( cx) {
64+ Poll :: Ready ( Some ( Message :: Bytes ( data) ) ) => {
65+ let len = std:: cmp:: min ( buf. len ( ) , data. len ( ) ) ;
66+ buf[ ..len] . copy_from_slice ( & data[ ..len] ) ;
67+ if data. len ( ) > len {
68+ self . read_buffer . extend_from_slice ( & data[ len..] ) ;
69+ }
70+ Poll :: Ready ( Ok ( len) )
71+ }
72+ Poll :: Ready ( Some ( _) ) => Poll :: Pending , // Ignore non-binary messages
73+ Poll :: Ready ( None ) => Poll :: Ready ( Ok ( 0 ) ) , // End of stream, no data read
74+ Poll :: Pending => Poll :: Pending ,
75+ }
76+ }
77+ }
78+
79+ impl AsyncWrite for WsIo {
80+ fn poll_write (
81+ mut self : Pin < & mut Self > ,
82+ cx : & mut Context < ' _ > ,
83+ data : & [ u8 ] ,
84+ ) -> Poll < io:: Result < usize > > {
85+ match self . outgoing . poll_ready ( cx) {
86+ Poll :: Ready ( Ok ( ( ) ) ) => {
87+ let _ = self . outgoing . start_send ( Message :: Bytes ( data. to_vec ( ) ) ) ;
88+ Poll :: Ready ( Ok ( data. len ( ) ) )
89+ }
90+ Poll :: Ready ( Err ( _) ) => Poll :: Ready ( Err ( std:: io:: Error :: new (
91+ std:: io:: ErrorKind :: Other ,
92+ "WebSocket send channel closed" ,
93+ ) ) ) ,
94+ Poll :: Pending => Poll :: Pending ,
95+ }
96+ }
97+
98+ fn poll_flush ( self : Pin < & mut Self > , _cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , std:: io:: Error > > {
99+ Poll :: Ready ( Ok ( ( ) ) ) // Nothing to flush in WebSocket context
100+ }
101+
102+ fn poll_close (
103+ self : Pin < & mut Self > ,
104+ _cx : & mut Context < ' _ > ,
105+ ) -> Poll < io:: Result < ( ) > > {
106+ Poll :: Ready ( Ok ( self . outgoing . close_channel ( ) ) )
107+ }
108+ }
109+
110+ #[ cfg( test) ]
111+ mod test {
112+ use std:: sync:: Arc ;
113+
114+ use gloo_net:: websocket:: futures:: WebSocket ;
115+ use wasm_bindgen_test:: { wasm_bindgen_test as test} ;
116+
117+ #[ test]
118+ async fn test_ws_io ( ) {
119+ use futures_util:: { AsyncReadExt , AsyncWriteExt } ;
120+ assert ! ( true ) ;
121+ // // DANGER! TODO get from &self config, do not get config directly from PAYJOIN_DIR ohttp-gateway
122+ // // That would reveal IP address
123+ // let tls_connector = {
124+ // let root_store = futures_rustls::rustls::RootCertStore {
125+ // roots: webpki_roots::TLS_SERVER_ROOTS.iter().cloned().collect(),
126+ // };
127+
128+ // let config = futures_rustls::rustls::ClientConfig::builder()
129+ // .with_root_certificates(root_store)
130+ // .with_no_client_auth();
131+ // futures_rustls::TlsConnector::from(Arc::new(config))
132+ // };
133+
134+ // let domain = futures_rustls::rustls::pki_types::ServerName::try_from("payjo.in")
135+ // .map_err(|_| {
136+ // std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid dnsname")
137+ // })
138+ // .unwrap()
139+ // .to_owned();
140+
141+ // let ws = WebSocket::open(&format!("ws://127.0.0.1:3030")).unwrap();
142+ // let ws_io = crate::networking::ws_io::WsIo::new(ws);
143+ // let mut tls_stream = tls_connector.connect(domain, ws_io).await.unwrap();
144+ // let ohttp_keys_req = b"GET /ohttp-keys HTTP/1.1\r\nHost: payjo.in\r\nConnection: close\r\n\r\n";
145+ // tls_stream.write_all(ohttp_keys_req).await.unwrap();
146+ // tls_stream.flush().await.unwrap();
147+ // let mut ohttp_keys = Vec::new();
148+ // tls_stream.read_to_end(&mut ohttp_keys).await.unwrap();
149+ // let ohttp_keys_base64 = base64::encode(ohttp_keys);
150+ // println!("{}", &ohttp_keys_base64);
151+ }
152+ }
0 commit comments