- Add a stop() method to IMessageReadWrite for proper cleanup when the connection is closed.
[controller.git] / opendaylight / protocol_plugins / openflow / src / main / java / org / opendaylight / controller / protocol_plugin / openflow / core / internal / SecureMessageReadWriteService.java
1
2 /*
3  * Copyright (c) 2013 Cisco Systems, Inc. and others.  All rights reserved.
4  *
5  * This program and the accompanying materials are made available under the
6  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
7  * and is available at http://www.eclipse.org/legal/epl-v10.html
8  */
9
10 package org.opendaylight.controller.protocol_plugin.openflow.core.internal;
11
12 import java.io.FileInputStream;
13 import java.io.IOException;
14 import java.nio.ByteBuffer;
15 import java.nio.channels.AsynchronousCloseException;
16 import java.nio.channels.SelectionKey;
17 import java.nio.channels.Selector;
18 import java.nio.channels.SocketChannel;
19 import java.security.KeyStore;
20 import java.security.SecureRandom;
21 import java.util.List;
22 import javax.net.ssl.KeyManagerFactory;
23 import javax.net.ssl.SSLContext;
24 import javax.net.ssl.SSLEngine;
25 import javax.net.ssl.SSLEngineResult;
26 import javax.net.ssl.SSLSession;
27 import javax.net.ssl.TrustManagerFactory;
28 import javax.net.ssl.SSLEngineResult.HandshakeStatus;
29 import org.opendaylight.controller.protocol_plugin.openflow.core.IMessageReadWrite;
30 import org.openflow.protocol.OFMessage;
31 import org.openflow.protocol.factory.BasicFactory;
32 import org.slf4j.Logger;
33 import org.slf4j.LoggerFactory;
34
35 /**
36  * This class implements methods to read/write messages over an established
37  * socket channel. The data exchange is encrypted/decrypted by SSLEngine.
38  */
39 public class SecureMessageReadWriteService implements IMessageReadWrite {
40     private static final Logger logger = LoggerFactory
41             .getLogger(SecureMessageReadWriteService.class);
42
43     private Selector selector;
44     private SelectionKey clientSelectionKey;
45     private SocketChannel socket;
46     private BasicFactory factory;
47
48     private SSLEngine sslEngine;
49         private SSLEngineResult sslEngineResult;        // results from sslEngine last operation
50     private ByteBuffer myAppData;                               // clear text message to be sent
51     private ByteBuffer myNetData;                       // encrypted message to be sent
52     private ByteBuffer peerAppData;                             // clear text message received from the switch
53     private ByteBuffer peerNetData;                     // encrypted message from the switch
54     private FileInputStream kfd = null, tfd = null;
55
56     public SecureMessageReadWriteService(SocketChannel socket, Selector selector) throws Exception {
57         this.socket = socket;
58         this.selector = selector;
59         this.factory = new BasicFactory();
60
61         try {
62                 createSecureChannel(socket);
63                 createBuffers(sslEngine);
64         } catch (Exception e) {
65                 stop();
66                 throw e;
67         }
68     }
69
70         /**
71          * Bring up secure channel using SSL Engine
72          * 
73          * @param socket TCP socket channel
74          * @throws Exception
75          */
76     private void createSecureChannel(SocketChannel socket) throws Exception {
77         String keyStoreFile = System.getProperty("controllerKeyStore").trim();
78         String keyStorePassword = System.getProperty("controllerKeyStorePassword").trim();
79         String trustStoreFile = System.getProperty("controllerTrustStore").trim();
80         String trustStorePassword = System.getProperty("controllerTrustStorePassword").trim();
81
82         KeyStore ks = KeyStore.getInstance("JKS");
83         KeyStore ts = KeyStore.getInstance("JKS");
84         KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
85         TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
86         kfd = new FileInputStream(keyStoreFile);
87         tfd = new FileInputStream(trustStoreFile);
88         ks.load(kfd, keyStorePassword.toCharArray());
89         ts.load(tfd, trustStorePassword.toCharArray());
90         kmf.init(ks, keyStorePassword.toCharArray());
91         tmf.init(ts);
92
93         SecureRandom random = new SecureRandom();
94         random.nextInt();
95
96         SSLContext sslContext = SSLContext.getInstance("TLS");
97         sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), random);
98         sslEngine = sslContext.createSSLEngine();
99         sslEngine.setUseClientMode(false);
100         sslEngine.setNeedClientAuth(true);
101         
102         // Do initial handshake
103         doHandshake(socket, sslEngine);
104         
105         this.clientSelectionKey = this.socket.register(this.selector,
106                 SelectionKey.OP_READ);
107     }
108
109         /**
110          * Sends the OF message out over the socket channel. The message is
111          * encrypted by SSL Engine.
112          * 
113          * @param msg OF message to be sent
114          * @throws Exception
115          */
116     @Override
117     public void asyncSend(OFMessage msg) throws Exception {
118         synchronized (myAppData) {
119                 int msgLen = msg.getLengthU();
120                 if (myAppData.remaining() < msgLen) {
121                         // increase the buffer size so that it can contain this message
122                         ByteBuffer newBuffer = ByteBuffer.allocateDirect(myAppData
123                                         .capacity()
124                                         + msgLen);
125                         myAppData.flip();
126                         newBuffer.put(myAppData);
127                         myAppData = newBuffer;
128                 }
129                 msg.writeTo(myAppData);
130                 myAppData.flip();
131                 sslEngineResult = sslEngine.wrap(myAppData, myNetData);
132                 logger.trace("asyncSend sslEngine wrap: {}", sslEngineResult);
133                 runDelegatedTasks(sslEngineResult, sslEngine);
134
135                 if (!socket.isOpen()) {
136                         return;
137                 }
138
139                 myNetData.flip();
140                 socket.write(myNetData);
141                 if (myNetData.hasRemaining()) {
142                         myNetData.compact();
143                 } else {
144                         myNetData.clear();
145                 }
146
147                 if (myAppData.hasRemaining()) {
148                         myAppData.compact();
149                         this.clientSelectionKey = this.socket.register(
150                                         this.selector, SelectionKey.OP_WRITE, this);
151                 } else {
152                         myAppData.clear();
153                         this.clientSelectionKey = this.socket.register(
154                                         this.selector, SelectionKey.OP_READ, this);
155                 }
156
157                 logger.trace("Message sent: {}", msg.toString());
158         }
159     }
160
161         /**
162          * Resumes sending the remaining messages in the outgoing buffer
163          * @throws Exception
164          */
165     @Override
166     public void resumeSend() throws Exception {
167                 synchronized (myAppData) {
168                         myAppData.flip();
169                         sslEngineResult = sslEngine.wrap(myAppData, myNetData);
170                         logger.trace("resumeSend sslEngine wrap: {}", sslEngineResult);
171                         runDelegatedTasks(sslEngineResult, sslEngine);
172
173                         if (!socket.isOpen()) {
174                                 return;
175                         }
176
177                         myNetData.flip();
178                         socket.write(myNetData);
179                         if (myNetData.hasRemaining()) {
180                                 myNetData.compact();
181                         } else {
182                                 myNetData.clear();
183                         }
184
185                         if (myAppData.hasRemaining()) {
186                                 myAppData.compact();
187                                 this.clientSelectionKey = this.socket.register(this.selector,
188                                                 SelectionKey.OP_WRITE, this);
189                         } else {
190                                 myAppData.clear();
191                                 this.clientSelectionKey = this.socket.register(this.selector,
192                                                 SelectionKey.OP_READ, this);
193                         }
194                 }
195     }
196
197         /**
198          * Reads the incoming network data from the socket, decryptes them and then
199          * retrieves the OF messages.
200          * 
201          * @return list of OF messages
202          * @throws Exception
203          */
204     @Override
205     public List<OFMessage> readMessages() throws Exception {
206                 if (!socket.isOpen()) {
207                         return null;
208                 }
209
210                 List<OFMessage> msgs = null;
211         int bytesRead = -1;
212         int countDown = 50;             
213
214         bytesRead = socket.read(peerNetData);
215         if (bytesRead < 0) {
216                 logger.debug("Message read operation failed");
217                         throw new AsynchronousCloseException();
218         }
219
220         do {                    
221                 peerNetData.flip();
222                 sslEngineResult = sslEngine.unwrap(peerNetData, peerAppData);
223                 if (peerNetData.hasRemaining()) {
224                         peerNetData.compact();
225                 } else {
226                         peerNetData.clear();
227                 }
228                 logger.trace("sslEngine unwrap result: {}", sslEngineResult);
229                 runDelegatedTasks(sslEngineResult, sslEngine);
230         } while ((sslEngineResult.getStatus() == SSLEngineResult.Status.OK) &&
231                           peerNetData.hasRemaining() && (--countDown > 0));
232         
233         if (countDown == 0) {
234                 logger.trace("countDown reaches 0. peerNetData pos {} lim {}", peerNetData.position(), peerNetData.limit());
235         }
236
237         peerAppData.flip();
238         msgs = factory.parseMessages(peerAppData);
239         if (peerAppData.hasRemaining()) {
240                 peerAppData.compact();
241         } else {
242                 peerAppData.clear();
243         }
244
245         this.clientSelectionKey = this.socket.register(
246                         this.selector, SelectionKey.OP_READ, this);
247         
248         return msgs;
249     }
250
251     /**
252      *  If the result indicates that we have outstanding tasks to do,
253      *  go ahead and run them in this thread.
254      */
255     private void runDelegatedTasks(SSLEngineResult result,
256                 SSLEngine engine) throws Exception {
257
258         if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
259                 Runnable runnable;
260                 while ((runnable = engine.getDelegatedTask()) != null) {
261                         logger.debug("\trunning delegated task...");
262                         runnable.run();
263                 }
264                 HandshakeStatus hsStatus = engine.getHandshakeStatus();
265                 if (hsStatus == HandshakeStatus.NEED_TASK) {
266                         throw new Exception(
267                                         "handshake shouldn't need additional tasks");
268                 }
269                 logger.debug("\tnew HandshakeStatus: {}", hsStatus);
270         }
271     }
272
273     private void doHandshake(SocketChannel socket, SSLEngine engine) throws Exception {
274         SSLSession session = engine.getSession();
275         ByteBuffer myAppData = ByteBuffer.allocate(session.getApplicationBufferSize());
276         ByteBuffer peerAppData = ByteBuffer.allocate(session.getApplicationBufferSize());
277         ByteBuffer myNetData = ByteBuffer.allocate(session.getPacketBufferSize());
278         ByteBuffer peerNetData = ByteBuffer.allocate(session.getPacketBufferSize());
279
280         // Begin handshake
281         engine.beginHandshake();
282         SSLEngineResult.HandshakeStatus hs = engine.getHandshakeStatus();
283
284         // Process handshaking message
285         while (hs != SSLEngineResult.HandshakeStatus.FINISHED &&
286                    hs != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
287                 switch (hs) {
288                 case NEED_UNWRAP:
289                         // Receive handshaking data from peer
290                         if (socket.read(peerNetData) < 0) {
291                                 throw new AsynchronousCloseException();
292                         }
293
294                         // Process incoming handshaking data
295                         peerNetData.flip();
296                         SSLEngineResult res = engine.unwrap(peerNetData, peerAppData);
297                         peerNetData.compact();
298                         hs = res.getHandshakeStatus();
299
300                         // Check status
301                         switch (res.getStatus()) {
302                         case OK :
303                                 // Handle OK status
304                                 break;
305                         }
306                         break;
307
308                 case NEED_WRAP :
309                         // Empty the local network packet buffer.
310                         myNetData.clear();
311
312                         // Generate handshaking data
313                         res = engine.wrap(myAppData, myNetData);
314                         hs = res.getHandshakeStatus();
315
316                         // Check status
317                         switch (res.getStatus()) {
318                         case OK :
319                                 myNetData.flip();
320
321                                 // Send the handshaking data to peer
322                                 while (myNetData.hasRemaining()) {
323                                         if (socket.write(myNetData) < 0) {
324                                         throw new AsynchronousCloseException();
325                                         }
326                                 }
327                                 break;
328                         }
329                         break;
330
331                 case NEED_TASK :
332                         // Handle blocking tasks
333                         Runnable runnable;
334                         while ((runnable = engine.getDelegatedTask()) != null) {
335                                 logger.debug("\trunning delegated task...");
336                                 runnable.run();
337                         }
338                         hs = engine.getHandshakeStatus();
339                         if (hs == HandshakeStatus.NEED_TASK) {
340                                 throw new Exception(
341                                                 "handshake shouldn't need additional tasks");
342                         }
343                         logger.debug("\tnew HandshakeStatus: {}", hs);
344                         break;
345                 }
346         }
347     }
348     
349     private void createBuffers(SSLEngine engine) {
350         SSLSession session = engine.getSession();
351         this.myAppData = ByteBuffer.allocate(session.getApplicationBufferSize());
352         this.peerAppData = ByteBuffer.allocate(session.getApplicationBufferSize());
353         this.myNetData = ByteBuffer.allocate(session.getPacketBufferSize());
354         this.peerNetData = ByteBuffer.allocate(session.getPacketBufferSize());
355     }
356
357         @Override
358         public void stop() throws IOException {
359                 this.sslEngine = null;
360                 this.sslEngineResult = null;
361                 this.myAppData = null;
362                 this.myNetData = null;
363                 this.peerAppData = null;
364                 this.peerNetData = null;
365             
366             if (this.kfd != null) {
367                 this.kfd.close();
368                 this.kfd = null;
369                 }
370                 if (this.tfd != null) {
371                         this.tfd.close();
372                         this.tfd = null;
373                 }
374         }
375 }