1 /** Network transport management implementation for JSON-RPC data.
2 
    You attach a transport to your RPCClient and a listener to your RPCServer;
    you do not need to use these APIs directly.
5 
6     Example:
7     ---
8     interface IMyFuncs { void f(); }
    class MyFuncs : IMyFuncs { void f() { return; } }
10 
11     // TCP sockets are the default - you don't have to name them explicitly...
12     auto server = new RPCServer!(MyFuncs, TCPListener!MyFuncs)
13             ("127.0.0.1", 54321);
14     auto client = new RPCClient!(IMyFuncs, TCPTransport)
15             ("127.0.0.1", 54321);
16 
17     client.f();
18     ---
19 
20     Authors:
21         Ryan Frame
22 
23     Copyright:
24         Copyright 2018 Ryan Frame
25 
26     License:
27         MIT
28 */
29 module jsonrpc.transport; @safe:
30 
31 import std.socket;
32 import std.traits : ReturnType;
33 import jsonrpc.exception;
34 
35 version(Have_tested) import tested : test = name;
36 else private struct test { string name; }
37 
/// Size in bytes of the buffer `TCPTransport` receives into on each socket read.
private enum SocketBufSize = 4096;
39 
/** Check whether the specified object is a valid transport for RPCClients.

    A valid transport is a struct on which all of the following compile:
    $(UL
        $(LI `send([])` returning `size_t`)
        $(LI `receiveJSONObjectOrArray` returning `char[]`)
        $(LI `close`)
        $(LI `isAlive` returning `bool`)
    )
*/
template isTransport(T) {
    enum bool sendsBytes = is(ReturnType!((T t) => t.send([])) == size_t);
    enum bool receivesJSON =
            is(ReturnType!((T t) => t.receiveJSONObjectOrArray) == char[]);
    enum bool canClose = is(typeof((T t) => t.close));
    enum bool reportsAlive = is(ReturnType!((T t) => t.isAlive) == bool);

    enum bool isTransport = is(T == struct)
            && sendsBytes
            && receivesJSON
            && canClose
            && reportsAlive;
}
47 
/** Check whether the specified object is a valid listener for RPCServers.

    Currently the only requirement is that the type is a struct.
*/
template isListener(T) {
    static if (is(T == struct))
        enum bool isListener = true;
    else
        enum bool isListener = false;
}
51 
52 
/** Receive a JSON object or array. Mixin template for transport implementations.

    If your transport provides a `receiveData` function defined as
    $(D_INLINECODE char[] receiveData(); ), the mixed-in
    `receiveJSONObjectOrArray` will call it and return the first complete JSON
    object or array from the char stream. Any trailing data is thrown away.

    The mixed-in code uses `raise`, `InvalidDataReceived` and
    `ConnectionException`, so those must be visible where it is mixed in.
*/
mixin template ReceiveJSON() {
    /** Receive a single JSON object or array from the socket stream.

        `receiveData` is called repeatedly until the top-level object or array
        that starts the stream is closed. Brace characters that appear inside
        JSON string literals (including after escaped quotes) are not counted.
        Any data after the closing brace is thrown away.

        Throws:
            InvalidDataReceived if the first read returns no data, or the data
            does not start with `{` or `[`.

            ConnectionException if `receiveData` returns no data before the
            object or array is complete.
    */
    char[] receiveJSONObjectOrArray() {
        auto data = receiveData();

        // Check the length first: an empty read would otherwise fail on
        // data[0] with a RangeError instead of a meaningful exception.
        if (data.length == 0 || (data[0] != '{' && data[0] != '[')) {
            raise!(InvalidDataReceived, data)
                    ("Expected to receive a JSON object or array.");
        }

        immutable char startBrace = data[0];
        immutable char endBrace = (startBrace == '{') ? '}' : ']';

        int depth = 0;
        bool inString = false;  // currently inside a JSON string literal
        bool escaped = false;   // previous char was a backslash inside a string
        size_t loc = 0;
        while (true) {
            // Resume scanning where the previous pass stopped; `loc` is not
            // reset when more data is appended.
            for (; loc < data.length; ++loc) {
                immutable c = data[loc];
                if (inString) {
                    if (escaped) escaped = false;
                    else if (c == '\\') escaped = true;
                    else if (c == '"') inString = false;
                } else if (c == '"') {
                    inString = true;
                } else if (c == startBrace) {
                    ++depth;
                } else if (c == endBrace) {
                    --depth;
                    // The top-level value is complete; drop anything after it.
                    if (depth == 0) return data[0 .. loc + 1];
                }
            }

            // Incomplete value: read more. An empty read means no further data
            // is coming (for TCPTransport, the socket receive returned 0 or an
            // error), so stop instead of looping forever.
            auto more = receiveData();
            if (more.length == 0) {
                raise!(ConnectionException)("Connection closed before a complete JSON object or array was received.");
            }
            data ~= more;
        }
    }
}
99 
100 
/** Manage TCP transport connection details and tasks. */
struct TCPTransport {
    static assert(isTransport!TCPTransport);

    package:

    /** Instantiate a TCPTransport object connected to the given host.

        Params:
            host = The hostname to connect to.
            port = The port number of the host to connect to.
    */
    this(string host, ushort port) in {
        assert(host.length > 0);
    } body {
        // Connect to the first address the resolver returns for host/port.
        this(new TcpSocket(getAddress(host, port)[0]));
    }

    /** Send the provided data and return the number of bytes sent.

        The data is written with as many socket sends as needed. If the return
        value is less than the length of the input in bytes, a send reported an
        error or sent nothing and transmission stopped early.

        Params:
            data = The string data to send.
    */
    size_t send(const char[] data) {
        size_t bytesSent = 0;
        while (bytesSent < data.length) {
            immutable sent = _socket.send(data[bytesSent .. $]);
            // Socket.ERROR is negative; a zero-byte send would loop forever.
            if (sent <= 0) break;
            bytesSent += sent;
        }
        return bytesSent;
    }

    mixin ReceiveJSON;

    /** Shut down both directions of the underlying socket, then close it. */
    void close() {
        _socket.shutdown(SocketShutdown.BOTH);
        _socket.close();
    }

    /** Query the transport to see if it's still active.

        Returns `false` if querying the socket throws.
    */
    nothrow
    bool isAlive() {
        scope(failure) return false;
        return _socket.isAlive();
    }

    private:

    /** Receive up to SocketBufSize bytes of incoming data.

        Returns an empty array if the socket receive returns 0 or an error.
    */
    char[] receiveData() {
        char[SocketBufSize] buf;
        ptrdiff_t receivedBytes = 0;

        receivedBytes = _socket.receive(buf);
        if (receivedBytes <= 0) return [];
        // buf lives on the stack, so hand back a heap copy.
        return buf[0..receivedBytes].dup;
    }

    Socket _socket;

    /** Wrap an existing socket (switched to blocking mode).

        This constructor is for unit testing and for accepted server-side
        connections.
    */
    package this(Socket socket) {
        _socket = socket;
        _socket.blocking = true;
    }
}
172 
/** Listen for incoming connections and pass clients to a handler function.

    Template_Parameters:
        API = The class containing the methods for the server to execute.
*/
struct TCPListener(API) {
    static assert(isListener!(TCPListener!(API)));

    package:

    /** Instantiate a TCPListener object bound to the given address.

        The socket is created in blocking mode with SO_REUSEADDR enabled and is
        bound immediately; it does not start listening until `listen` is
        called.

        Params:
            host = The hostname or address to bind the listening socket to.
            port = The local port number to bind to.
    */
    this(string host, ushort port) in {
        assert(host.length > 0);
    } body {
        auto sock = new TcpSocket();
        _socket = sock;
        sock.blocking = true;
        sock.setOption(SocketOptionLevel.SOCKET, SocketOption.REUSEADDR, true);
        // Bind to the first address the resolver returns for host/port.
        auto bindAddress = getAddress(host, port)[0];
        sock.bind(bindAddress);
    }

    /** Listen for client requests.

        Each accepted client is wrapped in a `TCPTransport` and passed, along
        with `api`, to `handler`, which is run in a new thread. The accept loop
        never exits, so this function does not return normally.

        Template_Parameters:
            handler = The handler function to call when a client connects.

        Params:
            api =                  An instantiated class with the methods to
                                   execute.
            maxQueuedConnections = The backlog size passed to the socket's
                                   `listen` call.

        Throws:
            ConnectionException (via `raise`) if the socket is not alive after
            it starts listening.
    */
    void listen(alias handler)(API api, int maxQueuedConnections = 10) {
        import std.parallelism : task;

        _socket.listen(maxQueuedConnections);
        if (!_socket.isAlive) {
            raise!(ConnectionException)("Listening socket not active.");
        }

        for (;;) {
            Socket client = _socket.accept();
            task!handler(TCPTransport(client), api).executeInNewThread();
        }
    }

    private:

    Socket _socket;
}
229 
230 version(unittest) @system:
231 
@test("receiveJSONObjectOrArray can receive a JSON object")
unittest {
    auto sock = new FakeSocket();
    auto transport = TCPTransport(sock);
    enum val = cast(char[])`{"id":23,"method":"func","params":[1,2,3]}`;

    // The fake socket hands the whole object back from a single receive.
    sock._receiveReturnValue = val;
    auto ret = transport.receiveJSONObjectOrArray();
    assert(ret == val);
}
243 
@test("receiveJSONObjectOrArray can receive a JSON array")
unittest {
    auto sock = new FakeSocket();
    auto transport = TCPTransport(sock);
    // A batch request: an array containing several objects (one without an id).
    enum val = cast(char[])
          `[{"id":23,"method":"func","params":[1,2,3]},
            {"id":24,"method":"func","params":[1,2,3]},
            {"id":25,"method":"func","params":[1,2,3]},
            {"method":"func","params":[1,2,3]},
            {"id":26,"method":"func","params":[1,2,3]}]`;

    sock._receiveReturnValue = val;
    auto ret = transport.receiveJSONObjectOrArray();
    assert(ret == val);
}
260 
@test("receiveJSONObjectOrArray throws an exception if not given an array or object")
unittest {
    import std.exception : assertThrown;
    auto sock = new FakeSocket();
    auto transport = TCPTransport(sock);
    // Missing the opening brace, so the first character is a quote.
    enum val = cast(char[])`"id":23,"method":"func","params":[1,2,3]}`;

    sock._receiveReturnValue = val;
    assertThrown!InvalidDataReceived(transport.receiveJSONObjectOrArray());
}
272 
@test("receiveJSONObjectOrArray receives a full object when its length exceeds SocketBufSize")
unittest {
    import std.array : array;
    import std.range : repeat, takeExactly;

    auto fake = new FakeSocket();
    auto transport = TCPTransport(fake);

    // Key and value are each half a buffer long; with the JSON punctuation
    // (`{"`, `": "`, `"}`) the message is SocketBufSize + 8 bytes, so it cannot
    // arrive in a single receive.
    enum halfBuf = SocketBufSize / 2;
    enum key = 'a'.repeat().takeExactly(halfBuf).array;
    enum val = 'b'.repeat().takeExactly(halfBuf).array;
    auto expected = cast(char[]) (`{"` ~ key ~ `": "` ~ val ~ `"}`);

    fake._receiveReturnValue = expected;
    auto received = transport.receiveJSONObjectOrArray();
    assert(cast(string) received == expected);
}
289 
version(unittest) {
    /** In-memory stand-in for a `Socket`, used by the transport unit tests.

        `receive` hands out (and consumes) the bytes in `_receiveReturnValue`,
        `send` stores a reference to the last buffer passed to it, and
        `bind`/`listen` just mark the socket alive.
    */
    class FakeSocket : Socket {
        private bool _blocking;
        private bool _isAlive;

        // Data that subsequent receive calls will return, front first.
        private char[] _receiveReturnValue =
                cast(char[])`{"id":3,"result":[1,2,3]}`;

        // The most recent buffer passed to send (not copied).
        private char[] _lastDataSent;

        /// Replace the data that subsequent receive calls will return.
        @property void receiveReturnValue(const(char)[] s) {
            _receiveReturnValue = cast(char[])s;
        }

        @property lastDataSent() { return _lastDataSent; }

        @property char[] receiveReturnValue() { return _receiveReturnValue; }

        override void bind(Address addr) { _isAlive = true; }

        override const nothrow @nogc @property @trusted bool blocking() {
            return _blocking;
        }

        override @property @trusted void blocking(bool byes) {
            _blocking = byes;
        }

        override @trusted void setOption(SocketOptionLevel level,
                SocketOption option, void[] value) {}

        override const @property @trusted bool isAlive() { return _isAlive; }

        override @trusted void listen(int backlog) { _isAlive = true; }

        alias receive = Socket.receive;
        /** Copy as much pending data as fits into buf and consume it.

            Returns the number of bytes copied (0 when buf is empty or no data
            remains).
        */
        override @trusted ptrdiff_t receive(void[] buf) {
            if (buf.length == 0) return 0;
            auto ret = fillBuffer(cast(char*)buf.ptr, buf.length);
            // Drop what was handed out so the next call continues from here.
            _receiveReturnValue = _receiveReturnValue[ret..$];
            return ret;
        }

        @test("FakeSocket.receive allows injecting 'received' characters.")
        unittest {
            auto s = new FakeSocket;
            char[] buf = new char[](SocketBufSize);
            s.receiveReturnValue = `{"id":3,"result":[1,2,3]}`;

            auto len = s.receive(buf);
            assert(buf[0..len] == `{"id":3,"result":[1,2,3]}`,
                    "Incorrect data received: " ~ buf);
        }

        alias send = Socket.send;
        /// Record buf as the last data sent and report it all as sent.
        override @trusted ptrdiff_t send(const(void)[] buf) {
            _lastDataSent = cast(char[])buf;
            return buf.length;
        }

        /** Copy up to `length` pending bytes to `ptr`; return the count copied.

            Does not consume the data; `receive` does that.
        */
        private @trusted ptrdiff_t fillBuffer(char* ptr, size_t length) {
            import std.algorithm.comparison : min;
            immutable count = min(length, _receiveReturnValue.length);
            ptr[0 .. count] = _receiveReturnValue[0 .. count];
            return count;
        }
    }
}