View Javadoc
1   package com.ericsson.research.transport.ssl;
2   
3   /*
4    * ##_BEGIN_LICENSE_##
5    * Transport Abstraction Package (trap)
6    * ----------
7    * Copyright (C) 2014 Ericsson AB
8    * ----------
9    * Redistribution and use in source and binary forms, with or without modification,
10   * are permitted provided that the following conditions are met:
11   * 
12   * 1. Redistributions of source code must retain the above copyright notice, this
13   *    list of conditions and the following disclaimer.
14   * 
15   * 2. Redistributions in binary form must reproduce the above copyright notice,
16   *    this list of conditions and the following disclaimer in the documentation
17   *    and/or other materials provided with the distribution.
18   * 
19   * 3. Neither the name of the Ericsson AB nor the names of its contributors
20   *    may be used to endorse or promote products derived from this software without
21   *    specific prior written permission.
22   * 
23   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
24   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
25   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
26   * IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
27   * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
28   * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29   * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
30   * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
31   * OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
32   * OF THE POSSIBILITY OF SUCH DAMAGE.
33   * ##_END_LICENSE_##
34   */
35  
36  import java.io.IOException;
37  import java.nio.ByteBuffer;
38  import java.util.concurrent.Executor;
39  import java.util.concurrent.Executors;
40  
41  import javax.net.ssl.SSLContext;
42  import javax.net.ssl.SSLEngine;
43  import javax.net.ssl.SSLEngineResult;
44  import javax.net.ssl.SSLEngineResult.HandshakeStatus;
45  import javax.net.ssl.SSLEngineResult.Status;
46  import javax.net.ssl.SSLException;
47  
48  import com.ericsson.research.transport.ManagedSocket;
49  import com.ericsson.research.transport.NioEndpoint;
50  
51  public class SSLSocket extends ManagedSocket implements Runnable, NioEndpoint
52  {
53      
54      private final SSLEngine  engine;
55      private final Executor   executor          = Executors.newSingleThreadExecutor();
56      
57      private static final int growSize          = 32 * 1024;
58      private static final int bufSize           = 32 * 1024;
59      
60      private Object           readLock          = new Object();
61      private Object           writeLock         = new Object();
62      
63      private ByteBuffer       netReadBuf        = ByteBuffer.allocate(bufSize);
64      private ByteBuffer       sslNetReadBuf     = ByteBuffer.allocate(bufSize);
65      private ByteBuffer       netWriteBuf       = ByteBuffer.allocate(bufSize);
66      
67      private ByteBuffer       clientReadBuf     = ByteBuffer.allocate(bufSize);
68      private ByteBuffer       clientWriteBuf    = ByteBuffer.allocate(bufSize);
69      private ByteBuffer       sslClientWriteBuf = ByteBuffer.allocate(bufSize);
70      
71      private SSLEngineResult  lastResult        = null;
72      
73      private boolean          needsWrap         = false;
74      private boolean          needsUnwrap       = false;
75      private boolean          needsDisconnect   = false;
76      
77      private int              errors     = 0;
78      private static final int MAX_ERRORS = 10;
79      
80      public synchronized void disconnect()
81      {
82          this.needsDisconnect = true;
83          
84          try
85          {
86              this.executor.execute(this);
87          }
88          catch (Exception e)
89          {
90              e.printStackTrace();
91          }
92      }
93      
94      public SSLSocket(SSLContext sslc)
95      {
96          this(sslc, true);
97      }
98      
99      public SSLSocket(SSLContext sslc, boolean clientMode)
100     {
101         super(!clientMode);
102         this.engine = sslc.createSSLEngine();
103         this.engine.setUseClientMode(clientMode);
104         String[] procols = { "TLSv1" };
105         this.engine.setEnabledProtocols(procols);
106     }
107     
108     ByteBuffer grow(ByteBuffer src, int size)
109     {
110         src.flip();
111         ByteBuffer b1 = ByteBuffer.allocate(src.capacity() + size);
112         b1.put(src);
113         return b1;
114     }
115     
116     public void receive(byte[] data, int size)
117     {
118         
119         if (super.getState() == State.NOT_CONNECTED)
120             return;
121         byte[] mData = new byte[size];
122         System.arraycopy(data, 0, mData, 0, size);
123         
124         synchronized (this.readLock)
125         {
126             if (this.netReadBuf.remaining() < size)
127             {
128                 // Grow the buffer
129                 this.netReadBuf = this.grow(this.netReadBuf, Math.max(growSize, size));
130                 
131             }
132             
133             this.netReadBuf.put(mData, 0, size);
134         }
135         synchronized (this)
136         {
137             this.needsUnwrap = true;
138             this.executor.execute(this);
139         }
140     }
141     
142     public void write(byte[] data) throws IOException
143     {
144         this.write(data, data.length);
145     }
146     
147     @Override
148     public void write(byte[] data, int size) throws IOException
149     {
150         if (super.getState() == State.NOT_CONNECTED)
151             return;
152         if (data.length < size)
153             throw new IOException("Data size may not exceed array length; array length was " + data.length + " and copy length requested was " + size);
154         
155         synchronized (this.writeLock)
156         {
157             if (this.clientWriteBuf.position() + size >= this.clientWriteBuf.limit())
158             {
159                 // Grow the buffer
160                 this.clientWriteBuf = this.grow(this.clientWriteBuf, Math.max(growSize, size));
161             }
162             
163             this.clientWriteBuf.put(data, 0, size);
164         }
165         
166         synchronized (this)
167         {
168             this.needsWrap = true;
169             this.executor.execute(this);
170         }
171     }
172     
173     public synchronized void run()
174     {
175         try
176         {
177             // Run loop that checks last status and runs through the buffers, wrapping and unwrapping as necessary.
178             if (this.lastResult == null)
179             {
180                 this.engine.beginHandshake();
181                 HandshakeStatus handshakeStatus = this.engine.getHandshakeStatus();
182                 
183                 if (handshakeStatus == HandshakeStatus.NEED_UNWRAP)
184                     this.needsUnwrap = true;
185                 else
186                     this.needsWrap = true;
187             }
188             
189             if (this.needsUnwrap)
190                 this.unwrap();
191             
192             if (this.needsWrap)
193                 this.wrap();
194             
195             if (this.needsDisconnect)
196             {
197                 if (!this.needsWrap) // If we still need to wrap (send) something, don't lose data...
198                     super.disconnect();
199             }
200             
201             this.errors = 0;
202             
203         }
204         catch (Exception e)
205         {
206             e.printStackTrace();
207             this.errors++;
208             
209             // Be a little forgiving. This might fix an issue with WSS and Safari/Chrome
210             if (this.errors > MAX_ERRORS)
211             {
212                 this.needsWrap = false;
213                 this.needsUnwrap = false;
214                 super.disconnect();
215             }
216         }
217         
218     }
219     
220     void unwrap() throws SSLException
221     {
222         if (this.sslNetReadBuf.position() <= 0 && this.sslNetReadBuf.remaining() == this.sslNetReadBuf.capacity()) // Empty buffer
223         {
224             synchronized (this.readLock)
225             {
226                 ByteBuffer b = this.netReadBuf;
227                 this.netReadBuf = this.sslNetReadBuf;
228                 this.sslNetReadBuf = b;
229                 if (this.sslNetReadBuf.position() <= 0 && (this.lastResult == null || this.lastResult.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP))
230                     return;
231             }
232         }
233         this.sslNetReadBuf.flip();
234         this.lastResult = this.engine.unwrap(this.sslNetReadBuf, this.clientReadBuf);
235         this.sslNetReadBuf.position(this.lastResult.bytesConsumed());
236         this.sslNetReadBuf.compact();
237         Status status = this.lastResult.getStatus();
238         
239         if (status == Status.CLOSED)
240             return;
241         
242         if (status == Status.BUFFER_UNDERFLOW)
243         {
244             if (this.sslNetReadBuf.position() > 0 && this.netReadBuf.position() > 0)
245             {
246                 // Consolidate the two buffers (in case a message is split up between them)		
247                 synchronized (this.netReadBuf)
248                 {
249                     
250                     this.netReadBuf.flip();
251                     if (this.sslNetReadBuf.remaining() < this.netReadBuf.limit())
252                     {
253                         // Grow the buffer
254                         this.sslNetReadBuf = this.grow(this.sslNetReadBuf, Math.max(growSize, this.netReadBuf.limit()));
255                         
256                     }
257                     
258                     this.sslNetReadBuf.put(this.netReadBuf);
259                     this.netReadBuf.clear();
260                 }
261                 
262                 // Attempt to unwrap using the consolidated buffer.
263                 this.executor.execute(this);
264                 
265             }
266             return; // We'll read more data later
267         }
268         
269         if (status == Status.BUFFER_OVERFLOW)
270         {
271             // Grow the target buffer
272             this.clientReadBuf = this.grow(this.clientReadBuf, growSize);
273             this.unwrap(); // Execute again
274             return;
275         }
276         
277         HandshakeStatus handshakeStatus = this.lastResult.getHandshakeStatus();
278         
279         if (handshakeStatus != HandshakeStatus.NEED_TASK && this.netReadBuf.position() == 0)
280             this.needsUnwrap = false;
281         else if (this.netReadBuf.position() > 0)
282             this.executor.execute(this);
283         
284         if (this.clientReadBuf.position() > 0)
285         {
286             super.receive(this.clientReadBuf.array(), this.clientReadBuf.position());
287             this.clientReadBuf.clear();
288         }
289         
290         this.handleHandshakeStatus();
291         
292     }
293     
294     void wrap() throws IOException
295     {
296         
297         if (this.sslClientWriteBuf.position() <= 0)
298         {
299             synchronized (this.writeLock)
300             {
301                 ByteBuffer b = this.clientWriteBuf;
302                 this.clientWriteBuf = this.sslClientWriteBuf;
303                 this.sslClientWriteBuf = b;
304                 if (this.sslClientWriteBuf.position() <= 0 && (this.lastResult == null || this.lastResult.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP))
305                     return;
306             }
307         }
308         
309         this.sslClientWriteBuf.flip();
310         this.lastResult = this.engine.wrap(this.sslClientWriteBuf, this.netWriteBuf);
311         this.sslClientWriteBuf.compact();
312         Status status = this.lastResult.getStatus();
313         
314         if (status == Status.BUFFER_OVERFLOW)
315         {
316             // Grow target buffer
317             this.netWriteBuf = this.grow(this.netWriteBuf, growSize);
318             this.wrap();
319             return;
320         }
321         
322         if (status == Status.BUFFER_UNDERFLOW)
323         // TODO: Does this ever happen?
324         {
325             return; // We'll read more data later
326         }
327         
328         if (this.lastResult.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING && (this.sslClientWriteBuf.remaining() != this.sslClientWriteBuf.capacity() || this.clientWriteBuf.remaining() != this.clientWriteBuf.capacity()))
329             this.executor.execute(this);
330         else if (this.lastResult.getHandshakeStatus() != HandshakeStatus.NEED_TASK)
331             this.needsWrap = false;
332         
333         if (this.netWriteBuf.position() > 0)
334         {
335             byte[] mData = new byte[this.netWriteBuf.position()];
336             System.arraycopy(this.netWriteBuf.array(), 0, mData, 0, this.netWriteBuf.position());
337             super.write(mData, mData.length);
338             this.netWriteBuf.clear();
339             
340         }
341         
342         this.handleHandshakeStatus();
343         
344     }
345     
346     private void handleHandshakeStatus()
347     {
348         
349         switch (this.lastResult.getHandshakeStatus())
350         {
351         
352             case NEED_TASK:
353                 Runnable r;
354                 while ((r = this.engine.getDelegatedTask()) != null)
355                     this.executor.execute(r);
356                 this.executor.execute(this);
357                 return;
358                 
359             case NEED_WRAP:
360                 this.needsWrap = true;
361                 this.executor.execute(this);
362                 return;
363                 
364             case NEED_UNWRAP:
365                 this.needsUnwrap = true;
366                 this.executor.execute(this);
367                 return;
368                 
369             case FINISHED:
370                 if (this.sslClientWriteBuf.position() > 0 || this.clientWriteBuf.position() > 0)
371                 {
372                     this.needsWrap = true;
373                     this.executor.execute(this);
374                 }
375                 
376                 if (this.sslNetReadBuf.position() > 0 || this.netReadBuf.position() > 0)
377                 {
378                     this.needsUnwrap = true;
379                     this.executor.execute(this);
380                 }
381                 break;
382             
383             case NOT_HANDSHAKING:
384                 if (this.lastResult.getStatus() != Status.OK)
385                     break;
386                 
387                 if (this.sslClientWriteBuf.position() > 0 || this.clientWriteBuf.position() > 0)
388                 {
389                     this.needsWrap = true;
390                     this.executor.execute(this);
391                 }
392                 
393                 if (this.sslNetReadBuf.position() > 0 || this.netReadBuf.position() > 0)
394                 {
395                     this.needsUnwrap = true;
396                     this.executor.execute(this);
397                 }
398                 
399         }
400         
401     }
402     
403 }