Add TLS support in the Opendaylight Controller:
[controller.git] / opendaylight / protocol_plugins / openflow / src / main / java / org / opendaylight / controller / protocol_plugin / openflow / core / internal / SecureMessageReadWriteService.java
1
2 /*
3  * Copyright (c) 2013 Cisco Systems, Inc. and others.  All rights reserved.
4  *
5  * This program and the accompanying materials are made available under the
6  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
7  * and is available at http://www.eclipse.org/legal/epl-v10.html
8  */
9
10 package org.opendaylight.controller.protocol_plugin.openflow.core.internal;
11
12 import java.io.FileInputStream;
13 import java.nio.ByteBuffer;
14 import java.nio.channels.AsynchronousCloseException;
15 import java.nio.channels.SelectionKey;
16 import java.nio.channels.Selector;
17 import java.nio.channels.SocketChannel;
18 import java.security.KeyStore;
19 import java.security.SecureRandom;
20 import java.util.List;
21 import javax.net.ssl.KeyManagerFactory;
22 import javax.net.ssl.SSLContext;
23 import javax.net.ssl.SSLEngine;
24 import javax.net.ssl.SSLEngineResult;
25 import javax.net.ssl.SSLSession;
26 import javax.net.ssl.TrustManagerFactory;
27 import javax.net.ssl.SSLEngineResult.HandshakeStatus;
28 import org.opendaylight.controller.protocol_plugin.openflow.core.IMessageReadWrite;
29 import org.openflow.protocol.OFMessage;
30 import org.openflow.protocol.factory.BasicFactory;
31 import org.slf4j.Logger;
32 import org.slf4j.LoggerFactory;
33
34 /**
35  * This class implements methods to read/write messages over an established
36  * socket channel. The data exchange is encrypted/decrypted by SSLEngine.
37  */
38 public class SecureMessageReadWriteService implements IMessageReadWrite {
39     private static final Logger logger = LoggerFactory
40             .getLogger(SecureMessageReadWriteService.class);
41
42     private Selector selector;
43     private SelectionKey clientSelectionKey;
44     private SocketChannel socket;
45     private BasicFactory factory;
46
47     private SSLEngine sslEngine;
48         private SSLEngineResult sslEngineResult;        // results from sslEngine last operation
49     private ByteBuffer myAppData;                               // clear text message to be sent
50     private ByteBuffer myNetData;                       // encrypted message to be sent
51     private ByteBuffer peerAppData;                             // clear text message received from the switch
52     private ByteBuffer peerNetData;                     // encrypted message from the switch
53
54     public SecureMessageReadWriteService(SocketChannel socket, Selector selector) throws Exception {
55         this.socket = socket;
56         this.selector = selector;
57         this.factory = new BasicFactory();
58
59         createSecureChannel(socket);
60         createBuffers(sslEngine);
61     }
62
63         /**
64          * Bring up secure channel using SSL Engine
65          * 
66          * @param socket TCP socket channel
67          * @throws Exception
68          */
69     private void createSecureChannel(SocketChannel socket) throws Exception {
70         String keyStoreFile = System.getProperty("controllerKeyStore");
71         String keyStorePassword = System.getProperty("controllerKeyStorePassword");
72         String trustStoreFile = System.getProperty("controllerTrustStore");;
73         String trustStorePassword = System.getProperty("controllerTrustStorePassword");;
74
75         KeyStore ks = KeyStore.getInstance("JKS");
76         KeyStore ts = KeyStore.getInstance("JKS");
77         KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
78         TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
79         ks.load(new FileInputStream(keyStoreFile), keyStorePassword.toCharArray());
80         ts.load(new FileInputStream(trustStoreFile), trustStorePassword.toCharArray());
81         kmf.init(ks, keyStorePassword.toCharArray());
82         tmf.init(ts);
83
84         SecureRandom random = new SecureRandom();
85         random.nextInt();
86
87         SSLContext sslContext = SSLContext.getInstance("TLS");
88         sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), random);
89         sslEngine = sslContext.createSSLEngine();
90         sslEngine.setUseClientMode(false);
91         sslEngine.setNeedClientAuth(true);
92         
93         // Do initial handshake
94         doHandshake(socket, sslEngine);
95         
96         this.clientSelectionKey = this.socket.register(this.selector,
97                 SelectionKey.OP_READ);
98     }
99
100         /**
101          * Sends the OF message out over the socket channel. The message is
102          * encrypted by SSL Engine.
103          * 
104          * @param msg OF message to be sent
105          * @throws Exception
106          */
107     @Override
108     public void asyncSend(OFMessage msg) throws Exception {
109         synchronized (myAppData) {
110                 int msgLen = msg.getLengthU();
111                 if (myAppData.remaining() < msgLen) {
112                         // increase the buffer size so that it can contain this message
113                         ByteBuffer newBuffer = ByteBuffer.allocateDirect(myAppData
114                                         .capacity()
115                                         + msgLen);
116                         myAppData.flip();
117                         newBuffer.put(myAppData);
118                         myAppData = newBuffer;
119                 }
120                 msg.writeTo(myAppData);
121                 myAppData.flip();
122                 sslEngineResult = sslEngine.wrap(myAppData, myNetData);
123                 logger.trace("asyncSend sslEngine wrap: {}", sslEngineResult);
124                 runDelegatedTasks(sslEngineResult, sslEngine);
125
126                 if (!socket.isOpen()) {
127                         return;
128                 }
129
130                 myNetData.flip();
131                 socket.write(myNetData);
132                 if (myNetData.hasRemaining()) {
133                         myNetData.compact();
134                 } else {
135                         myNetData.clear();
136                 }
137
138                 if (myAppData.hasRemaining()) {
139                         myAppData.compact();
140                         this.clientSelectionKey = this.socket.register(
141                                         this.selector, SelectionKey.OP_WRITE, this);
142                 } else {
143                         myAppData.clear();
144                         this.clientSelectionKey = this.socket.register(
145                                         this.selector, SelectionKey.OP_READ, this);
146                 }
147
148                 logger.trace("Message sent: {}", msg.toString());
149         }
150     }
151
152         /**
153          * Resumes sending the remaining messages in the outgoing buffer
154          * @throws Exception
155          */
156     @Override
157     public void resumeSend() throws Exception {
158                 synchronized (myAppData) {
159                         myAppData.flip();
160                         sslEngineResult = sslEngine.wrap(myAppData, myNetData);
161                         logger.trace("resumeSend sslEngine wrap: {}", sslEngineResult);
162                         runDelegatedTasks(sslEngineResult, sslEngine);
163
164                         if (!socket.isOpen()) {
165                                 return;
166                         }
167
168                         myNetData.flip();
169                         socket.write(myNetData);
170                         if (myNetData.hasRemaining()) {
171                                 myNetData.compact();
172                         } else {
173                                 myNetData.clear();
174                         }
175
176                         if (myAppData.hasRemaining()) {
177                                 myAppData.compact();
178                                 this.clientSelectionKey = this.socket.register(this.selector,
179                                                 SelectionKey.OP_WRITE, this);
180                         } else {
181                                 myAppData.clear();
182                                 this.clientSelectionKey = this.socket.register(this.selector,
183                                                 SelectionKey.OP_READ, this);
184                         }
185                 }
186     }
187
188         /**
189          * Reads the incoming network data from the socket, decryptes them and then
190          * retrieves the OF messages.
191          * 
192          * @return list of OF messages
193          * @throws Exception
194          */
195     @Override
196     public List<OFMessage> readMessages() throws Exception {
197                 if (!socket.isOpen()) {
198                         return null;
199                 }
200
201                 List<OFMessage> msgs = null;
202         int bytesRead = -1;
203         int countDown = 50;             
204
205         bytesRead = socket.read(peerNetData);
206         if (bytesRead < 0) {
207                         throw new AsynchronousCloseException();
208         }
209
210         do {                    
211                 peerNetData.flip();
212                 sslEngineResult = sslEngine.unwrap(peerNetData, peerAppData);
213                 if (peerNetData.hasRemaining()) {
214                         peerNetData.compact();
215                 } else {
216                         peerNetData.clear();
217                 }
218                 logger.trace("sslEngine unwrap result: {}", sslEngineResult);
219                 runDelegatedTasks(sslEngineResult, sslEngine);
220         } while ((sslEngineResult.getStatus() == SSLEngineResult.Status.OK) &&
221                           peerNetData.hasRemaining() && (--countDown > 0));
222         
223         if (countDown == 0) {
224                 logger.trace("countDown reaches 0. peerNetData pos {} lim {}", peerNetData.position(), peerNetData.limit());
225         }
226
227         peerAppData.flip();
228         msgs = factory.parseMessages(peerAppData);
229         if (peerAppData.hasRemaining()) {
230                 peerAppData.compact();
231         } else {
232                 peerAppData.clear();
233         }
234
235         this.clientSelectionKey = this.socket.register(
236                         this.selector, SelectionKey.OP_READ, this);
237         
238         return msgs;
239     }
240
241     /**
242      *  If the result indicates that we have outstanding tasks to do,
243      *  go ahead and run them in this thread.
244      */
245     private void runDelegatedTasks(SSLEngineResult result,
246                 SSLEngine engine) throws Exception {
247
248         if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
249                 Runnable runnable;
250                 while ((runnable = engine.getDelegatedTask()) != null) {
251                         logger.debug("\trunning delegated task...");
252                         runnable.run();
253                 }
254                 HandshakeStatus hsStatus = engine.getHandshakeStatus();
255                 if (hsStatus == HandshakeStatus.NEED_TASK) {
256                         throw new Exception(
257                                         "handshake shouldn't need additional tasks");
258                 }
259                 logger.debug("\tnew HandshakeStatus: {}", hsStatus);
260         }
261     }
262
263     private void doHandshake(SocketChannel socket, SSLEngine engine) throws Exception {
264         SSLSession session = engine.getSession();
265         ByteBuffer myAppData = ByteBuffer.allocate(session.getApplicationBufferSize());
266         ByteBuffer peerAppData = ByteBuffer.allocate(session.getApplicationBufferSize());
267         ByteBuffer myNetData = ByteBuffer.allocate(session.getPacketBufferSize());
268         ByteBuffer peerNetData = ByteBuffer.allocate(session.getPacketBufferSize());
269
270         // Begin handshake
271         engine.beginHandshake();
272         SSLEngineResult.HandshakeStatus hs = engine.getHandshakeStatus();
273
274         // Process handshaking message
275         while (hs != SSLEngineResult.HandshakeStatus.FINISHED &&
276                    hs != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
277                 switch (hs) {
278                 case NEED_UNWRAP:
279                         // Receive handshaking data from peer
280                         if (socket.read(peerNetData) < 0) {
281                                 throw new AsynchronousCloseException();
282                         }
283
284                         // Process incoming handshaking data
285                         peerNetData.flip();
286                         SSLEngineResult res = engine.unwrap(peerNetData, peerAppData);
287                         peerNetData.compact();
288                         hs = res.getHandshakeStatus();
289
290                         // Check status
291                         switch (res.getStatus()) {
292                         case OK :
293                                 // Handle OK status
294                                 break;
295                         }
296                         break;
297
298                 case NEED_WRAP :
299                         // Empty the local network packet buffer.
300                         myNetData.clear();
301
302                         // Generate handshaking data
303                         res = engine.wrap(myAppData, myNetData);
304                         hs = res.getHandshakeStatus();
305
306                         // Check status
307                         switch (res.getStatus()) {
308                         case OK :
309                                 myNetData.flip();
310
311                                 // Send the handshaking data to peer
312                                 while (myNetData.hasRemaining()) {
313                                         if (socket.write(myNetData) < 0) {
314                                         throw new AsynchronousCloseException();
315                                         }
316                                 }
317                                 break;
318                         }
319                         break;
320
321                 case NEED_TASK :
322                         // Handle blocking tasks
323                         Runnable runnable;
324                         while ((runnable = engine.getDelegatedTask()) != null) {
325                                 logger.debug("\trunning delegated task...");
326                                 runnable.run();
327                         }
328                         hs = engine.getHandshakeStatus();
329                         if (hs == HandshakeStatus.NEED_TASK) {
330                                 throw new Exception(
331                                                 "handshake shouldn't need additional tasks");
332                         }
333                         logger.debug("\tnew HandshakeStatus: {}", hs);
334                         break;
335                 }
336         }
337     }
338     
339     private void createBuffers(SSLEngine engine) {
340         SSLSession session = engine.getSession();
341         this.myAppData = ByteBuffer.allocate(session.getApplicationBufferSize());
342         this.peerAppData = ByteBuffer.allocate(session.getApplicationBufferSize());
343         this.myNetData = ByteBuffer.allocate(session.getPacketBufferSize());
344         this.peerNetData = ByteBuffer.allocate(session.getPacketBufferSize());
345     }
346 }