1 package com.ericsson.research.transport.ssl;
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36 import java.io.IOException;
37 import java.nio.ByteBuffer;
38 import java.util.concurrent.Executor;
39 import java.util.concurrent.Executors;
40
41 import javax.net.ssl.SSLContext;
42 import javax.net.ssl.SSLEngine;
43 import javax.net.ssl.SSLEngineResult;
44 import javax.net.ssl.SSLEngineResult.HandshakeStatus;
45 import javax.net.ssl.SSLEngineResult.Status;
46 import javax.net.ssl.SSLException;
47
48 import com.ericsson.research.transport.ManagedSocket;
49 import com.ericsson.research.transport.NioEndpoint;
50
51 public class SSLSocket extends ManagedSocket implements Runnable, NioEndpoint
52 {
53
54 private final SSLEngine engine;
55 private final Executor executor = Executors.newSingleThreadExecutor();
56
57 private static final int growSize = 32 * 1024;
58 private static final int bufSize = 32 * 1024;
59
60 private Object readLock = new Object();
61 private Object writeLock = new Object();
62
63 private ByteBuffer netReadBuf = ByteBuffer.allocate(bufSize);
64 private ByteBuffer sslNetReadBuf = ByteBuffer.allocate(bufSize);
65 private ByteBuffer netWriteBuf = ByteBuffer.allocate(bufSize);
66
67 private ByteBuffer clientReadBuf = ByteBuffer.allocate(bufSize);
68 private ByteBuffer clientWriteBuf = ByteBuffer.allocate(bufSize);
69 private ByteBuffer sslClientWriteBuf = ByteBuffer.allocate(bufSize);
70
71 private SSLEngineResult lastResult = null;
72
73 private boolean needsWrap = false;
74 private boolean needsUnwrap = false;
75 private boolean needsDisconnect = false;
76
77 private int errors = 0;
78 private static final int MAX_ERRORS = 10;
79
80 public synchronized void disconnect()
81 {
82 this.needsDisconnect = true;
83
84 try
85 {
86 this.executor.execute(this);
87 }
88 catch (Exception e)
89 {
90 e.printStackTrace();
91 }
92 }
93
94 public SSLSocket(SSLContext sslc)
95 {
96 this(sslc, true);
97 }
98
99 public SSLSocket(SSLContext sslc, boolean clientMode)
100 {
101 super(!clientMode);
102 this.engine = sslc.createSSLEngine();
103 this.engine.setUseClientMode(clientMode);
104 String[] procols = { "TLSv1" };
105 this.engine.setEnabledProtocols(procols);
106 }
107
108 ByteBuffer grow(ByteBuffer src, int size)
109 {
110 src.flip();
111 ByteBuffer b1 = ByteBuffer.allocate(src.capacity() + size);
112 b1.put(src);
113 return b1;
114 }
115
116 public void receive(byte[] data, int size)
117 {
118
119 if (super.getState() == State.NOT_CONNECTED)
120 return;
121 byte[] mData = new byte[size];
122 System.arraycopy(data, 0, mData, 0, size);
123
124 synchronized (this.readLock)
125 {
126 if (this.netReadBuf.remaining() < size)
127 {
128
129 this.netReadBuf = this.grow(this.netReadBuf, Math.max(growSize, size));
130
131 }
132
133 this.netReadBuf.put(mData, 0, size);
134 }
135 synchronized (this)
136 {
137 this.needsUnwrap = true;
138 this.executor.execute(this);
139 }
140 }
141
142 public void write(byte[] data) throws IOException
143 {
144 this.write(data, data.length);
145 }
146
147 @Override
148 public void write(byte[] data, int size) throws IOException
149 {
150 if (super.getState() == State.NOT_CONNECTED)
151 return;
152 if (data.length < size)
153 throw new IOException("Data size may not exceed array length; array length was " + data.length + " and copy length requested was " + size);
154
155 synchronized (this.writeLock)
156 {
157 if (this.clientWriteBuf.position() + size >= this.clientWriteBuf.limit())
158 {
159
160 this.clientWriteBuf = this.grow(this.clientWriteBuf, Math.max(growSize, size));
161 }
162
163 this.clientWriteBuf.put(data, 0, size);
164 }
165
166 synchronized (this)
167 {
168 this.needsWrap = true;
169 this.executor.execute(this);
170 }
171 }
172
173 public synchronized void run()
174 {
175 try
176 {
177
178 if (this.lastResult == null)
179 {
180 this.engine.beginHandshake();
181 HandshakeStatus handshakeStatus = this.engine.getHandshakeStatus();
182
183 if (handshakeStatus == HandshakeStatus.NEED_UNWRAP)
184 this.needsUnwrap = true;
185 else
186 this.needsWrap = true;
187 }
188
189 if (this.needsUnwrap)
190 this.unwrap();
191
192 if (this.needsWrap)
193 this.wrap();
194
195 if (this.needsDisconnect)
196 {
197 if (!this.needsWrap)
198 super.disconnect();
199 }
200
201 this.errors = 0;
202
203 }
204 catch (Exception e)
205 {
206 e.printStackTrace();
207 this.errors++;
208
209
210 if (this.errors > MAX_ERRORS)
211 {
212 this.needsWrap = false;
213 this.needsUnwrap = false;
214 super.disconnect();
215 }
216 }
217
218 }
219
220 void unwrap() throws SSLException
221 {
222 if (this.sslNetReadBuf.position() <= 0 && this.sslNetReadBuf.remaining() == this.sslNetReadBuf.capacity())
223 {
224 synchronized (this.readLock)
225 {
226 ByteBuffer b = this.netReadBuf;
227 this.netReadBuf = this.sslNetReadBuf;
228 this.sslNetReadBuf = b;
229 if (this.sslNetReadBuf.position() <= 0 && (this.lastResult == null || this.lastResult.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP))
230 return;
231 }
232 }
233 this.sslNetReadBuf.flip();
234 this.lastResult = this.engine.unwrap(this.sslNetReadBuf, this.clientReadBuf);
235 this.sslNetReadBuf.position(this.lastResult.bytesConsumed());
236 this.sslNetReadBuf.compact();
237 Status status = this.lastResult.getStatus();
238
239 if (status == Status.CLOSED)
240 return;
241
242 if (status == Status.BUFFER_UNDERFLOW)
243 {
244 if (this.sslNetReadBuf.position() > 0 && this.netReadBuf.position() > 0)
245 {
246
247 synchronized (this.netReadBuf)
248 {
249
250 this.netReadBuf.flip();
251 if (this.sslNetReadBuf.remaining() < this.netReadBuf.limit())
252 {
253
254 this.sslNetReadBuf = this.grow(this.sslNetReadBuf, Math.max(growSize, this.netReadBuf.limit()));
255
256 }
257
258 this.sslNetReadBuf.put(this.netReadBuf);
259 this.netReadBuf.clear();
260 }
261
262
263 this.executor.execute(this);
264
265 }
266 return;
267 }
268
269 if (status == Status.BUFFER_OVERFLOW)
270 {
271
272 this.clientReadBuf = this.grow(this.clientReadBuf, growSize);
273 this.unwrap();
274 return;
275 }
276
277 HandshakeStatus handshakeStatus = this.lastResult.getHandshakeStatus();
278
279 if (handshakeStatus != HandshakeStatus.NEED_TASK && this.netReadBuf.position() == 0)
280 this.needsUnwrap = false;
281 else if (this.netReadBuf.position() > 0)
282 this.executor.execute(this);
283
284 if (this.clientReadBuf.position() > 0)
285 {
286 super.receive(this.clientReadBuf.array(), this.clientReadBuf.position());
287 this.clientReadBuf.clear();
288 }
289
290 this.handleHandshakeStatus();
291
292 }
293
294 void wrap() throws IOException
295 {
296
297 if (this.sslClientWriteBuf.position() <= 0)
298 {
299 synchronized (this.writeLock)
300 {
301 ByteBuffer b = this.clientWriteBuf;
302 this.clientWriteBuf = this.sslClientWriteBuf;
303 this.sslClientWriteBuf = b;
304 if (this.sslClientWriteBuf.position() <= 0 && (this.lastResult == null || this.lastResult.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP))
305 return;
306 }
307 }
308
309 this.sslClientWriteBuf.flip();
310 this.lastResult = this.engine.wrap(this.sslClientWriteBuf, this.netWriteBuf);
311 this.sslClientWriteBuf.compact();
312 Status status = this.lastResult.getStatus();
313
314 if (status == Status.BUFFER_OVERFLOW)
315 {
316
317 this.netWriteBuf = this.grow(this.netWriteBuf, growSize);
318 this.wrap();
319 return;
320 }
321
322 if (status == Status.BUFFER_UNDERFLOW)
323
324 {
325 return;
326 }
327
328 if (this.lastResult.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING && (this.sslClientWriteBuf.remaining() != this.sslClientWriteBuf.capacity() || this.clientWriteBuf.remaining() != this.clientWriteBuf.capacity()))
329 this.executor.execute(this);
330 else if (this.lastResult.getHandshakeStatus() != HandshakeStatus.NEED_TASK)
331 this.needsWrap = false;
332
333 if (this.netWriteBuf.position() > 0)
334 {
335 byte[] mData = new byte[this.netWriteBuf.position()];
336 System.arraycopy(this.netWriteBuf.array(), 0, mData, 0, this.netWriteBuf.position());
337 super.write(mData, mData.length);
338 this.netWriteBuf.clear();
339
340 }
341
342 this.handleHandshakeStatus();
343
344 }
345
346 private void handleHandshakeStatus()
347 {
348
349 switch (this.lastResult.getHandshakeStatus())
350 {
351
352 case NEED_TASK:
353 Runnable r;
354 while ((r = this.engine.getDelegatedTask()) != null)
355 this.executor.execute(r);
356 this.executor.execute(this);
357 return;
358
359 case NEED_WRAP:
360 this.needsWrap = true;
361 this.executor.execute(this);
362 return;
363
364 case NEED_UNWRAP:
365 this.needsUnwrap = true;
366 this.executor.execute(this);
367 return;
368
369 case FINISHED:
370 if (this.sslClientWriteBuf.position() > 0 || this.clientWriteBuf.position() > 0)
371 {
372 this.needsWrap = true;
373 this.executor.execute(this);
374 }
375
376 if (this.sslNetReadBuf.position() > 0 || this.netReadBuf.position() > 0)
377 {
378 this.needsUnwrap = true;
379 this.executor.execute(this);
380 }
381 break;
382
383 case NOT_HANDSHAKING:
384 if (this.lastResult.getStatus() != Status.OK)
385 break;
386
387 if (this.sslClientWriteBuf.position() > 0 || this.clientWriteBuf.position() > 0)
388 {
389 this.needsWrap = true;
390 this.executor.execute(this);
391 }
392
393 if (this.sslNetReadBuf.position() > 0 || this.netReadBuf.position() > 0)
394 {
395 this.needsUnwrap = true;
396 this.executor.execute(this);
397 }
398
399 }
400
401 }
402
403 }