1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 module thrift.transport.socket;
20 
21 import core.thread : Thread;
22 import core.time : dur, Duration;
23 import std.array : empty;
24 import std.conv : text, to;
25 import std.exception : enforce;
26 import std.socket;
27 import thrift.base;
28 import thrift.transport.base;
29 import thrift.internal.socket;
30 
/**
 * Common parts of a socket TTransport implementation, regardless of how the
 * actual I/O is performed (sync/async).
 *
 * Stores the remote host/port, the send/receive timeouts and the wrapped
 * std.socket.Socket. Subclasses provide the actual I/O, including
 * writeSome().
 */
abstract class TSocketBase : TBaseTransport {
  /**
   * Constructor that takes an already created, connected (!) socket.
   *
   * The socket options are applied right away via setSocketOpts(). Because
   * that method is virtual, a subclass override runs here as well. host and
   * port keep their null/zero defaults.
   *
   * Params:
   *   socket = Already created, connected socket object.
   */
  this(Socket socket) {
    socket_ = socket;
    setSocketOpts();
  }

  /**
   * Creates a new unconnected socket that will connect to the given host
   * on the given port.
   *
   * This only records host and port; no socket object is created here, so
   * isOpen is false until a subclass connects.
   *
   * Params:
   *   host = Remote host.
   *   port = Remote port.
   */
  this(string host, ushort port) {
    host_ = host;
    port_ = port;
  }

  /**
   * Checks whether the socket is connected.
   *
   * Note: this only tests whether a socket object is currently held (i.e.
   * the transport was opened and not closed since). It does not probe the
   * connection state on the wire.
   */
  override bool isOpen() @property {
    return socket_ !is null;
  }

  /**
   * Writes as much data to the socket as there can be in a single OS call.
   *
   * Params:
   *   buf = Data to write.
   *
   * Returns: The actual number of bytes written. Never more than buf.length.
   */
  abstract size_t writeSome(in ubyte[] buf) out (written) {
    // DMD @@BUG@@: Enabling this e.g. fails the contract in the
    // async_test_server, because buf.length evaluates to 0 here, even though
    // in the method body it correctly is 27 (equal to the return value).
    version (none) assert(written <= buf.length, text("Implementation wrote " ~
      "more data than requested to?! (", written, " vs. ", buf.length, ")"));
  } body {
    // Never executed for concrete subclasses; the body only exists because
    // the compiler refuses contracts on abstract methods without one.
    assert(0, "DMD bug? – Why would contracts work for interfaces, but not "
      "for abstract methods? "
      "(Error: function […] in and out contracts require function body");
  }

  /**
   * Returns the actual address of the peer the socket is connected to.
   *
   * In contrast, the host and port properties contain the address used to
   * establish the connection, and are not updated after the connection.
   *
   * The address is queried from the socket on the first call and cached in
   * peerAddress_ afterwards.
   *
   * The socket must be open when calling this.
   *
   * Throws: TTransportException (NOT_OPEN) if the socket is not open.
   */
  Address getPeerAddress() {
    enforce(isOpen, new TTransportException("Cannot get peer host for " ~
      "closed socket.", TTransportException.Type.NOT_OPEN));

    if (!peerAddress_) {
      peerAddress_ = socket_.remoteAddress();
      assert(peerAddress_);
    }

    return peerAddress_;
  }

  /**
   * The host the socket is connected to or will connect to. Null if an
   * already connected socket was used to construct the object.
   */
  string host() const @property {
    return host_;
  }

  /**
   * The port the socket is connected to or will connect to. Zero if an
   * already connected socket was used to construct the object.
   */
  ushort port() const @property {
    return port_;
  }

  /// The socket send timeout.
  Duration sendTimeout() const @property {
    return sendTimeout_;
  }

  /// Ditto. In this base class the setter only stores the value; applying it
  /// to the OS socket is left to subclasses.
  void sendTimeout(Duration value) @property {
    sendTimeout_ = value;
  }

  /// The socket receiving timeout. Values smaller than 500 ms are not
  /// supported on Windows.
  Duration recvTimeout() const @property {
    return recvTimeout_;
  }

  /// Ditto. In this base class the setter only stores the value; applying it
  /// to the OS socket is left to subclasses.
  void recvTimeout(Duration value) @property {
    recvTimeout_ = value;
  }

  /**
   * Returns the OS handle of the underlying socket.
   *
   * Should not usually be used directly, but access to it can be necessary
   * to interface with C libraries.
   *
   * Throws: TTransportException (NOT_OPEN) if the socket is not open (the
   *   wrapped socket object is null in that case).
   */
  typeof(socket_.handle()) socketHandle() @property {
    enforce(isOpen, new TTransportException("Cannot get handle of " ~
      "closed socket.", TTransportException.Type.NOT_OPEN));
    return socket_.handle();
  }

protected:
  /**
   * Sets the needed socket options.
   *
   * Turns SO_LINGER off (on = 0, time = 0) and tries to disable Nagle's
   * algorithm (TCP_NODELAY). A failure to set SO_LINGER is logged; a failure
   * to set TCP_NODELAY is deliberately ignored. Neither is rethrown.
   */
  void setSocketOpts() {
    try {
      alias SocketOptionLevel.SOCKET lvlSock;
      Linger l;
      l.on = 0;
      l.time = 0;
      socket_.setOption(lvlSock, SocketOption.LINGER, l);
    } catch (SocketException e) {
      logError("Could not set socket option: %s", e);
    }

    // Just try to disable Nagle's algorithm – this will fail if we are passed
    // in a non-TCP socket via the Socket-accepting constructor.
    try {
      socket_.setOption(SocketOptionLevel.TCP, SocketOption.TCP_NODELAY, true);
    } catch (SocketException e) {}
  }

  /// Remote host.
  string host_;

  /// Remote port.
  ushort port_;

  /// Timeout for sending.
  Duration sendTimeout_;

  /// Timeout for receiving.
  Duration recvTimeout_;

  /// Cached peer address (filled lazily by getPeerAddress()).
  Address peerAddress_;

  /// Cached peer host name. NOTE(review): not referenced in this class,
  /// presumably used by subclasses elsewhere — verify before removing.
  string peerHost_;

  /// Cached peer port. NOTE(review): not referenced in this class,
  /// presumably used by subclasses elsewhere — verify before removing.
  ushort peerPort_;

  /// Wrapped socket object; null while the transport is not open.
  Socket socket_;
}
200 
/**
 * Socket implementation of the TTransport interface.
 *
 * Due to the limitations of std.socket, currently only TCP/IP sockets are
 * supported (i.e. Unix domain sockets are not).
 */
class TSocket : TSocketBase {
  ///
  this(Socket socket) {
    super(socket);
  }

  ///
  this(string host, ushort port) {
    super(host, port);
  }

  // The setter overrides below would otherwise hide the base class getters
  // of the same name; re-introduce the full overload sets.
  alias TSocketBase.sendTimeout sendTimeout;
  alias TSocketBase.recvTimeout recvTimeout;

  /**
   * Connects the socket.
   *
   * Resolves host/port and tries every resulting address in order until a
   * connection succeeds. Does nothing if the socket is already open.
   *
   * Throws: TTransportException (NOT_OPEN) if the host is empty, the port is
   *   zero, the host cannot be resolved, or no address could be connected to
   *   (the per-address SocketExceptions are attached as a
   *   TCompoundOperationException). Exceptions thrown by setSocketOpts()
   *   other than SocketException propagate unchanged.
   */
  override void open() {
    if (isOpen) return;

    enforce(!host_.empty, new TTransportException(
      "Cannot open socket to null host.", TTransportException.Type.NOT_OPEN));
    enforce(port_ != 0, new TTransportException(
      "Cannot open socket to port zero.", TTransportException.Type.NOT_OPEN));

    Address[] addrs;
    try {
      addrs = getAddress(host_, port_);
    } catch (SocketException e) {
      throw new TTransportException("Could not resolve given host string.",
        TTransportException.Type.NOT_OPEN, __FILE__, __LINE__, e);
    }

    Exception[] errors;
    foreach (addr; addrs) {
      try {
        socket_ = new TcpSocket(addr.addressFamily);
        // On any failure from here on (SocketException or otherwise), release
        // the OS handle of this attempt and reset socket_. This way no handle
        // leaks and an unconnected socket never makes isOpen report true.
        scope (failure) {
          socket_.close();
          socket_ = null;
        }
        setSocketOpts();
        socket_.connect(addr);
        break;
      } catch (SocketException e) {
        errors ~= e;
      }
    }

    if (!isOpen) {
      // Need to throw a TTransportException to abide the TTransport API.
      import std.algorithm, std.range;
      throw new TTransportException(
        text("Failed to connect to ", host_, ":", port_, "."),
        TTransportException.Type.NOT_OPEN,
        __FILE__, __LINE__,
        new TCompoundOperationException(
          text(
            "All addresses tried failed (",
            joiner(map!q{text(a._0, `: "`, a._1.msg, `"`)}(zip(addrs, errors)), ", "),
            ")."
          ),
          errors
        )
      );
    }
  }

  /**
   * Closes the socket. Does nothing if it is not open.
   *
   * Also drops the cached peer address, since a later open() may connect to
   * a different address.
   */
  override void close() {
    if (!isOpen) return;

    socket_.close();
    socket_ = null;
    peerAddress_ = null;
  }

  /**
   * Checks whether data is available by peeking at (not consuming) one byte.
   *
   * Returns: false if the socket is not open or the peek received no data
   *   (and, where connresetOnPeerShutdown applies, after ECONNRESET, in which
   *   case the socket is closed); true otherwise.
   *
   * Throws: TTransportException (UNKNOWN) for any other receive error.
   */
  override bool peek() {
    if (!isOpen) return false;

    ubyte buf;
    auto r = socket_.receive((&buf)[0 .. 1], SocketFlags.PEEK);
    if (r == -1) {
      auto lastErrno = getSocketErrno();
      static if (connresetOnPeerShutdown) {
        if (lastErrno == ECONNRESET) {
          close();
          return false;
        }
      }
      throw new TTransportException("Peeking into socket failed: " ~
        socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
    }
    return (r > 0);
  }

  /**
   * Receives up to buf.length bytes in a single recv() call.
   *
   * An interrupted call (INTERRUPTED_ERRNO) is retried up to maxRecvRetries
   * times.
   *
   * Returns: The number of bytes received (0 at end of stream).
   *
   * Throws: TTransportException: NOT_OPEN if the socket is not open or the
   *   error indicates a closed connection (the socket is closed then),
   *   TIMED_OUT on a receive timeout, UNKNOWN otherwise.
   */
  override size_t read(ubyte[] buf) {
    enforce(isOpen, new TTransportException(
      "Cannot read if socket is not open.", TTransportException.Type.NOT_OPEN));

    typeof(getSocketErrno()) lastErrno;
    // Wider than maxRecvRetries_ (ushort) so the post-increment cannot wrap
    // around and loop forever when maxRecvRetries_ == ushort.max.
    uint tries;
    while (tries++ <= maxRecvRetries_) {
      auto r = socket_.receive(cast(void[])buf);

      // If recv went fine, immediately return.
      if (r >= 0) return r;

      // Something went wrong, find out how to handle it.
      lastErrno = getSocketErrno();

      if (lastErrno == INTERRUPTED_ERRNO) {
        // If the syscall was interrupted, just try again.
        continue;
      }

      static if (connresetOnPeerShutdown) {
        // NOTE(review): connresetOnPeerShutdown comes from
        // thrift.internal.socket; presumably it marks platforms where an
        // orderly peer shutdown is reported as ECONNRESET. Report it as end
        // of stream (0 bytes read).
        if (lastErrno == ECONNRESET) {
          return 0;
        }
      }

      // Not an error which is handled in a special way, just leave the loop.
      break;
    }

    if (isSocketCloseErrno(lastErrno)) {
      close();
      throw new TTransportException("Receiving failed, closing socket: " ~
        socketErrnoString(lastErrno), TTransportException.Type.NOT_OPEN);
    } else if (lastErrno == TIMEOUT_ERRNO) {
      throw new TTransportException(TTransportException.Type.TIMED_OUT);
    } else {
      throw new TTransportException("Receiving from socket failed: " ~
        socketErrnoString(lastErrno), TTransportException.Type.UNKNOWN);
    }
  }

  /**
   * Writes all of buf, calling writeSome() until everything is sent.
   *
   * Throws: TTransportException (TIMED_OUT) if writeSome() makes no
   *   progress, plus anything writeSome() throws.
   */
  override void write(in ubyte[] buf) {
    size_t sent;
    while (sent < buf.length) {
      auto b = writeSome(buf[sent .. $]);
      if (b == 0) {
        // This should only happen if the timeout set with SO_SNDTIMEO expired.
        throw new TTransportException("send() timeout expired.",
          TTransportException.Type.TIMED_OUT);
      }
      sent += b;
    }
    assert(sent == buf.length);
  }

  /**
   * Performs a single send() call.
   *
   * Returns: The number of bytes sent; 0 if buf is empty or the call would
   *   block (WOULD_BLOCK_ERRNO).
   *
   * Throws: TTransportException: NOT_OPEN if the socket is not open or the
   *   error indicates a closed connection (the socket is closed then),
   *   UNKNOWN otherwise.
   */
  override size_t writeSome(in ubyte[] buf) {
    enforce(isOpen, new TTransportException(
      "Cannot write if socket is not open.", TTransportException.Type.NOT_OPEN));

    // Nothing to do – avoids treating the 0 returned by send() for an empty
    // buffer as an error below.
    if (buf.empty) return 0;

    auto r = socket_.send(buf);

    // Everything went well, just return the number of bytes written.
    if (r > 0) return r;

    // Handle error conditions.
    if (r < 0) {
      auto lastErrno = getSocketErrno();

      if (lastErrno == WOULD_BLOCK_ERRNO) {
        // Not an exceptional error per se – even with blocking sockets,
        // EAGAIN apparently is returned sometimes on out-of-resource
        // conditions (see the C++ implementation for details). Also, this
        // allows using TSocket with non-blocking sockets e.g. in
        // TNonblockingServer.
        return 0;
      }

      auto type = TTransportException.Type.UNKNOWN;
      if (isSocketCloseErrno(lastErrno)) {
        type = TTransportException.Type.NOT_OPEN;
        close();
      }

      throw new TTransportException("Sending to socket failed: " ~
        socketErrnoString(lastErrno), type);
    }

    // send() should never return 0 for a non-empty buffer.
    throw new TTransportException("Sending to socket failed (0 bytes written).",
      TTransportException.Type.UNKNOWN);
  }

  /// Sets the send timeout, applying it to the open socket (if any) before
  /// storing it, so a failure leaves the stored value unchanged.
  override void sendTimeout(Duration value) @property {
    setTimeout(SocketOption.SNDTIMEO, value);
    super.sendTimeout(value);
  }

  /// Sets the receive timeout, applying it to the open socket (if any) before
  /// storing it, so a failure leaves the stored value unchanged.
  override void recvTimeout(Duration value) @property {
    setTimeout(SocketOption.RCVTIMEO, value);
    super.recvTimeout(value);
  }

  /**
   * Maximum number of retries for receiving from socket on read() in case
   * the recv() call was interrupted (EINTR).
   */
  ushort maxRecvRetries() @property const {
    return maxRecvRetries_;
  }

  /// Ditto
  void maxRecvRetries(ushort value) @property {
    maxRecvRetries_ = value;
  }

  /// Ditto
  enum DEFAULT_MAX_RECV_RETRIES = 5;

protected:
  /// Applies the base socket options plus the stored send/receive timeouts.
  override void setSocketOpts() {
    super.setSocketOpts();
    setTimeout(SocketOption.SNDTIMEO, sendTimeout_);
    setTimeout(SocketOption.RCVTIMEO, recvTimeout_);
  }

  /**
   * Sets the SNDTIMEO/RCVTIMEO option on the socket, if one is open.
   *
   * On Windows, a warning is logged for positive values below 500 ms.
   *
   * Throws: TTransportException (UNKNOWN) if setting the option fails.
   */
  void setTimeout(SocketOption type, Duration value) {
    assert(type == SocketOption.SNDTIMEO || type == SocketOption.RCVTIMEO);
    // Windows (not just Win32) so 64-bit Windows builds get the warning too.
    version (Windows) {
      if (value > dur!"hnsecs"(0) && value < dur!"msecs"(500)) {
        logError(
          "Socket %s timeout of %s ms might be raised to 500 ms on Windows.",
          (type == SocketOption.SNDTIMEO) ? "send" : "receive",
          value.total!"msecs"
        );
      }
    }

    if (socket_) {
      try {
        socket_.setOption(SocketOptionLevel.SOCKET, type, value);
      } catch (SocketException e) {
        throw new TTransportException(
          "Could not set timeout.",
          TTransportException.Type.UNKNOWN,
          __FILE__,
          __LINE__,
          e
        );
      }
    }
  }

  /// Maximum number of recv() retries.
  ushort maxRecvRetries_  = DEFAULT_MAX_RECV_RETRIES;
}