Clear the receive buffer when message parser throws exceptions. Don't terminate the...
[controller.git] / opendaylight / protocol_plugins / openflow / src / main / java / org / opendaylight / controller / protocol_plugin / openflow / core / internal / SecureMessageReadWriteService.java
1 /*
2  * Copyright (c) 2013 Cisco Systems, Inc. and others.  All rights reserved.
3  *
4  * This program and the accompanying materials are made available under the
5  * terms of the Eclipse Public License v1.0 which accompanies this distribution,
6  * and is available at http://www.eclipse.org/legal/epl-v10.html
7  */
8
9 package org.opendaylight.controller.protocol_plugin.openflow.core.internal;
10
11 import java.io.FileInputStream;
12 import java.io.FileNotFoundException;
13 import java.io.IOException;
14 import java.nio.ByteBuffer;
15 import java.nio.channels.AsynchronousCloseException;
16 import java.nio.channels.SelectionKey;
17 import java.nio.channels.Selector;
18 import java.nio.channels.SocketChannel;
19 import java.security.KeyStore;
20 import java.security.SecureRandom;
21 import java.util.List;
22
23 import javax.net.ssl.KeyManagerFactory;
24 import javax.net.ssl.SSLContext;
25 import javax.net.ssl.SSLEngine;
26 import javax.net.ssl.SSLEngineResult;
27 import javax.net.ssl.SSLEngineResult.HandshakeStatus;
28 import javax.net.ssl.SSLSession;
29 import javax.net.ssl.TrustManagerFactory;
30
31 import org.opendaylight.controller.protocol_plugin.openflow.core.IMessageReadWrite;
32 import org.openflow.protocol.OFMessage;
33 import org.openflow.protocol.factory.BasicFactory;
34 import org.slf4j.Logger;
35 import org.slf4j.LoggerFactory;
36
37 /**
38  * This class implements methods to read/write messages over an established
39  * socket channel. The data exchange is encrypted/decrypted by SSLEngine.
40  */
41 public class SecureMessageReadWriteService implements IMessageReadWrite {
42     private static final Logger logger = LoggerFactory
43             .getLogger(SecureMessageReadWriteService.class);
44
45     private Selector selector;
46     private SocketChannel socket;
47     private BasicFactory factory;
48
49     private SSLEngine sslEngine;
50     private SSLEngineResult sslEngineResult; // results from sslEngine last operation
51     private ByteBuffer myAppData; // clear text message to be sent
52     private ByteBuffer myNetData; // encrypted message to be sent
53     private ByteBuffer peerAppData; // clear text message received from the
54                                     // switch
55     private ByteBuffer peerNetData; // encrypted message from the switch
56     private FileInputStream kfd = null, tfd = null;
57
58     public SecureMessageReadWriteService(SocketChannel socket, Selector selector)
59             throws Exception {
60         this.socket = socket;
61         this.selector = selector;
62         this.factory = new BasicFactory();
63
64         try {
65             createSecureChannel(socket);
66             createBuffers(sslEngine);
67         } catch (Exception e) {
68             logger.warn("Failed to setup TLS connection {} {}", socket, e);
69             stop();
70             throw e;
71         }
72     }
73
74     /**
75      * Bring up secure channel using SSL Engine
76      *
77      * @param socket
78      *            TCP socket channel
79      * @throws Exception
80      */
81     private void createSecureChannel(SocketChannel socket) throws Exception {
82         String keyStoreFile = System.getProperty("controllerKeyStore");
83         String keyStorePassword = System
84                 .getProperty("controllerKeyStorePassword");
85         String trustStoreFile = System.getProperty("controllerTrustStore");
86         String trustStorePassword = System
87                 .getProperty("controllerTrustStorePassword");
88
89         if (keyStoreFile != null) {
90             keyStoreFile = keyStoreFile.trim();
91         }
92         if ((keyStoreFile == null) || keyStoreFile.isEmpty()) {
93             throw new FileNotFoundException("TLS KeyStore file not found.");
94         }
95         if (keyStorePassword != null) {
96             keyStorePassword = keyStorePassword.trim();
97         }
98         if ((keyStorePassword == null) || keyStorePassword.isEmpty()) {
99             throw new FileNotFoundException("TLS KeyStore Password not provided.");
100         }
101         if (trustStoreFile != null) {
102             trustStoreFile = trustStoreFile.trim();
103         }
104         if ((trustStoreFile == null) || trustStoreFile.isEmpty()) {
105             throw new FileNotFoundException("TLS TrustStore file not found");
106         }
107         if (trustStorePassword != null) {
108             trustStorePassword = trustStorePassword.trim();
109         }
110         if ((trustStorePassword == null) || trustStorePassword.isEmpty()) {
111             throw new FileNotFoundException("TLS TrustStore Password not provided.");
112         }
113
114         KeyStore ks = KeyStore.getInstance("JKS");
115         KeyStore ts = KeyStore.getInstance("JKS");
116         KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
117         TrustManagerFactory tmf = TrustManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
118         kfd = new FileInputStream(keyStoreFile);
119         tfd = new FileInputStream(trustStoreFile);
120         ks.load(kfd, keyStorePassword.toCharArray());
121         ts.load(tfd, trustStorePassword.toCharArray());
122         kmf.init(ks, keyStorePassword.toCharArray());
123         tmf.init(ts);
124
125         SecureRandom random = new SecureRandom();
126         random.nextInt();
127
128         SSLContext sslContext = SSLContext.getInstance("TLS");
129         sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), random);
130         sslEngine = sslContext.createSSLEngine();
131         sslEngine.setUseClientMode(false);
132         sslEngine.setNeedClientAuth(true);
133         sslEngine.setEnabledCipherSuites(new String[] {
134                 "SSL_RSA_WITH_RC4_128_MD5",
135                 "SSL_RSA_WITH_RC4_128_SHA",
136                 "TLS_RSA_WITH_AES_128_CBC_SHA",
137                 "TLS_DHE_RSA_WITH_AES_128_CBC_SHA",
138                 "TLS_DHE_DSS_WITH_AES_128_CBC_SHA",
139                 "SSL_RSA_WITH_3DES_EDE_CBC_SHA",
140                 "SSL_DHE_RSA_WITH_3DES_EDE_CBC_SHA",
141                 "SSL_DHE_DSS_WITH_3DES_EDE_CBC_SHA",
142                 "SSL_RSA_WITH_DES_CBC_SHA",
143                 "SSL_DHE_RSA_WITH_DES_CBC_SHA",
144                 "SSL_DHE_DSS_WITH_DES_CBC_SHA",
145                 "SSL_RSA_EXPORT_WITH_RC4_40_MD5",
146                 "SSL_RSA_EXPORT_WITH_DES40_CBC_SHA",
147                 "SSL_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA",
148                 "SSL_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA",
149                 "TLS_EMPTY_RENEGOTIATION_INFO_SCSV"});
150
151         // Do initial handshake
152         doHandshake(socket, sslEngine);
153
154         this.socket.register(this.selector, SelectionKey.OP_READ);
155     }
156
157     /**
158      * Sends the OF message out over the socket channel. The message is
159      * encrypted by SSL Engine.
160      *
161      * @param msg
162      *            OF message to be sent
163      * @throws Exception
164      */
165     @Override
166     public void asyncSend(OFMessage msg) throws Exception {
167         synchronized (myAppData) {
168             int msgLen = msg.getLengthU();
169             if (myAppData.remaining() < msgLen) {
170                 // increase the buffer size so that it can contain this message
171                 ByteBuffer newBuffer = ByteBuffer.allocateDirect(myAppData
172                         .capacity() + msgLen);
173                 myAppData.flip();
174                 newBuffer.put(myAppData);
175                 myAppData = newBuffer;
176             }
177         }
178         synchronized (myAppData) {
179             msg.writeTo(myAppData);
180             myAppData.flip();
181             sslEngineResult = sslEngine.wrap(myAppData, myNetData);
182             logger.trace("asyncSend sslEngine wrap: {}", sslEngineResult);
183             runDelegatedTasks(sslEngineResult, sslEngine);
184
185             if (!socket.isOpen()) {
186                 return;
187             }
188
189             myNetData.flip();
190             socket.write(myNetData);
191             if (myNetData.hasRemaining()) {
192                 myNetData.compact();
193             } else {
194                 myNetData.clear();
195             }
196
197             if (myAppData.hasRemaining()) {
198                 myAppData.compact();
199                 this.socket.register(this.selector, SelectionKey.OP_WRITE, this);
200             } else {
201                 myAppData.clear();
202                 this.socket.register(this.selector, SelectionKey.OP_READ, this);
203             }
204
205             logger.trace("Message sent: {}", msg);
206         }
207     }
208
209     /**
210      * Resumes sending the remaining messages in the outgoing buffer
211      *
212      * @throws Exception
213      */
214     @Override
215     public void resumeSend() throws Exception {
216         synchronized (myAppData) {
217             myAppData.flip();
218             sslEngineResult = sslEngine.wrap(myAppData, myNetData);
219             logger.trace("resumeSend sslEngine wrap: {}", sslEngineResult);
220             runDelegatedTasks(sslEngineResult, sslEngine);
221
222             if (!socket.isOpen()) {
223                 return;
224             }
225
226             myNetData.flip();
227             socket.write(myNetData);
228             if (myNetData.hasRemaining()) {
229                 myNetData.compact();
230             } else {
231                 myNetData.clear();
232             }
233
234             if (myAppData.hasRemaining()) {
235                 myAppData.compact();
236                 this.socket.register(this.selector, SelectionKey.OP_WRITE, this);
237             } else {
238                 myAppData.clear();
239                 this.socket.register(this.selector, SelectionKey.OP_READ, this);
240             }
241         }
242     }
243
244     /**
245      * Reads the incoming network data from the socket, decryptes them and then
246      * retrieves the OF messages.
247      *
248      * @return list of OF messages
249      * @throws Exception
250      */
251     @Override
252     public List<OFMessage> readMessages() throws Exception {
253         if (!socket.isOpen()) {
254             return null;
255         }
256
257         List<OFMessage> msgs = null;
258         int bytesRead = -1;
259         int countDown = 50;
260
261         bytesRead = socket.read(peerNetData);
262         if (bytesRead < 0) {
263             logger.debug("Message read operation failed");
264             throw new AsynchronousCloseException();
265         }
266
267         do {
268             peerNetData.flip();
269             sslEngineResult = sslEngine.unwrap(peerNetData, peerAppData);
270             if (peerNetData.hasRemaining()) {
271                 peerNetData.compact();
272             } else {
273                 peerNetData.clear();
274             }
275             logger.trace("sslEngine unwrap result: {}", sslEngineResult);
276             runDelegatedTasks(sslEngineResult, sslEngine);
277         } while ((sslEngineResult.getStatus() == SSLEngineResult.Status.OK)
278                 && peerNetData.hasRemaining() && (--countDown > 0));
279
280         if (countDown == 0) {
281             logger.trace("countDown reaches 0. peerNetData pos {} lim {}",
282                     peerNetData.position(), peerNetData.limit());
283         }
284
285         try {
286             peerAppData.flip();
287             msgs = factory.parseMessages(peerAppData);
288             if (peerAppData.hasRemaining()) {
289                 peerAppData.compact();
290             } else {
291                 peerAppData.clear();
292             }
293         } catch (Exception e) {
294             peerAppData.clear();
295             logger.debug("Caught exception: ", e);
296         }
297
298         this.socket.register(this.selector, SelectionKey.OP_READ, this);
299
300         return msgs;
301     }
302
303     /**
304      * If the result indicates that we have outstanding tasks to do, go ahead
305      * and run them in this thread.
306      */
307     private void runDelegatedTasks(SSLEngineResult result, SSLEngine engine)
308             throws Exception {
309
310         if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
311             Runnable runnable;
312             while ((runnable = engine.getDelegatedTask()) != null) {
313                 logger.debug("\trunning delegated task...");
314                 runnable.run();
315             }
316             HandshakeStatus hsStatus = engine.getHandshakeStatus();
317             if (hsStatus == HandshakeStatus.NEED_TASK) {
318                 throw new Exception("handshake shouldn't need additional tasks");
319             }
320             logger.debug("\tnew HandshakeStatus: {}", hsStatus);
321         }
322     }
323
324     private void doHandshake(SocketChannel socket, SSLEngine engine)
325             throws Exception {
326         SSLSession session = engine.getSession();
327         ByteBuffer myAppData = ByteBuffer.allocate(session
328                 .getApplicationBufferSize());
329         ByteBuffer peerAppData = ByteBuffer.allocate(session
330                 .getApplicationBufferSize());
331         ByteBuffer myNetData = ByteBuffer.allocate(session
332                 .getPacketBufferSize());
333         ByteBuffer peerNetData = ByteBuffer.allocate(session
334                 .getPacketBufferSize());
335
336         // Begin handshake
337         engine.beginHandshake();
338         SSLEngineResult.HandshakeStatus hs = engine.getHandshakeStatus();
339
340         // Process handshaking message
341         while (hs != SSLEngineResult.HandshakeStatus.FINISHED
342                 && hs != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) {
343             switch (hs) {
344             case NEED_UNWRAP:
345                 // Receive handshaking data from peer
346                 if (socket.read(peerNetData) < 0) {
347                     throw new AsynchronousCloseException();
348                 }
349
350                 // Process incoming handshaking data
351                 peerNetData.flip();
352                 SSLEngineResult res = engine.unwrap(peerNetData, peerAppData);
353                 peerNetData.compact();
354                 hs = res.getHandshakeStatus();
355
356                 // Check status
357                 switch (res.getStatus()) {
358                 case OK:
359                     // Handle OK status
360                     break;
361                 }
362                 break;
363
364             case NEED_WRAP:
365                 // Empty the local network packet buffer.
366                 myNetData.clear();
367
368                 // Generate handshaking data
369                 res = engine.wrap(myAppData, myNetData);
370                 hs = res.getHandshakeStatus();
371
372                 // Check status
373                 switch (res.getStatus()) {
374                 case OK:
375                     myNetData.flip();
376
377                     // Send the handshaking data to peer
378                     while (myNetData.hasRemaining()) {
379                         if (socket.write(myNetData) < 0) {
380                             throw new AsynchronousCloseException();
381                         }
382                     }
383                     break;
384                 }
385                 break;
386
387             case NEED_TASK:
388                 // Handle blocking tasks
389                 Runnable runnable;
390                 while ((runnable = engine.getDelegatedTask()) != null) {
391                     logger.debug("\trunning delegated task...");
392                     runnable.run();
393                 }
394                 hs = engine.getHandshakeStatus();
395                 if (hs == HandshakeStatus.NEED_TASK) {
396                     throw new Exception(
397                             "handshake shouldn't need additional tasks");
398                 }
399                 logger.debug("\tnew HandshakeStatus: {}", hs);
400                 break;
401             }
402         }
403     }
404
405     private void createBuffers(SSLEngine engine) {
406         SSLSession session = engine.getSession();
407         this.myAppData = ByteBuffer
408                 .allocate(session.getApplicationBufferSize());
409         this.peerAppData = ByteBuffer.allocate(session
410                 .getApplicationBufferSize());
411         this.myNetData = ByteBuffer.allocate(session.getPacketBufferSize());
412         this.peerNetData = ByteBuffer.allocate(session.getPacketBufferSize());
413     }
414
415     @Override
416     public void stop() throws IOException {
417         this.sslEngine = null;
418         this.sslEngineResult = null;
419         this.myAppData = null;
420         this.myNetData = null;
421         this.peerAppData = null;
422         this.peerNetData = null;
423
424         if (this.kfd != null) {
425             this.kfd.close();
426             this.kfd = null;
427         }
428         if (this.tfd != null) {
429             this.tfd.close();
430             this.tfd = null;
431         }
432     }
433 }