/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tomcat.util.net;
import java.io.EOFException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.CompletionHandler;
import java.nio.channels.WritePendingException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import org.apache.juli.logging.Log;
import org.apache.juli.logging.LogFactory;
import org.apache.tomcat.util.buf.ByteBufferUtils;
import org.apache.tomcat.util.net.TLSClientHelloExtractor.ExtractorResult;
import org.apache.tomcat.util.net.openssl.ciphers.Cipher;
import org.apache.tomcat.util.net.openssl.ciphers.Group;
import org.apache.tomcat.util.net.openssl.ciphers.SignatureScheme;
import org.apache.tomcat.util.res.StringManager;
/**
* Implementation of a secure socket channel for NIO2.
*/
public class SecureNio2Channel extends Nio2Channel {
private static final Log log = LogFactory.getLog(SecureNio2Channel.class);
private static final StringManager sm = StringManager.getManager(SecureNio2Channel.class);
// Value determined by observation of what the SSL Engine requested in various scenarios
private static final int DEFAULT_NET_BUFFER_SIZE = 16921;
// Much longer than it should ever need to be but short enough to trigger connection closure if something goes wrong
// (upper bound on consecutive NEED_WRAP results seen while unwrapping application data before the next
// application-data wrap resets the counter)
private static final int HANDSHAKE_WRAP_QUEUE_LENGTH_LIMIT = 100;
// Endpoint that owns this channel; source of socket properties, connection timeout and SSLEngine instances
protected final Nio2Endpoint endpoint;
// Encrypted bytes read from the network; kept in "write to" mode (compacted) between unwrap operations
protected ByteBuffer netInBuffer;
// Encrypted bytes produced by wrap(); flipped to "read from" mode before being written to the network
protected ByteBuffer netOutBuffer;
// Created during SNI processing once the requested host name is known; null after reset()
protected SSLEngine sslEngine;
// true once processSNI() has created the engine and begun the handshake
protected volatile boolean sniComplete = false;
// true once the initial (or re-) handshake has finished and all handshake data has been written
private volatile boolean handshakeComplete = false;
// Count of NEED_WRAP results observed while unwrapping application data; compared to
// HANDSHAKE_WRAP_QUEUE_LENGTH_LIMIT and reset whenever application data is wrapped or the channel is reset
private final AtomicInteger handshakeWrapQueueLength = new AtomicInteger();
private volatile HandshakeStatus handshakeStatus; // gets set by handshake
// Set by close(): closing once a close has been requested, closed once close_notify has been fully written
protected boolean closed;
protected boolean closing;
// TLS attributes taken from the ClientHello (requested protocols / ciphers) that the SSLSession does not expose
private final Map<String,List<String>> additionalTlsAttributes = new HashMap<>();
// When true, reads first try to unwrap data already held in netInBuffer before issuing a network read
private volatile boolean unwrapBeforeRead;
// Completion handlers used for asynchronous handshake I/O; they re-dispatch the socket to the endpoint
private final CompletionHandler<Integer,SocketWrapperBase<Nio2Channel>> handshakeReadCompletionHandler;
private final CompletionHandler<Integer,SocketWrapperBase<Nio2Channel>> handshakeWriteCompletionHandler;
public SecureNio2Channel(SocketBufferHandler bufHandler, Nio2Endpoint endpoint) {
super(bufHandler);
this.endpoint = endpoint;
if (endpoint.getSocketProperties().getDirectSslBuffer()) {
netInBuffer = ByteBuffer.allocateDirect(DEFAULT_NET_BUFFER_SIZE);
netOutBuffer = ByteBuffer.allocateDirect(DEFAULT_NET_BUFFER_SIZE);
} else {
netInBuffer = ByteBuffer.allocate(DEFAULT_NET_BUFFER_SIZE);
netOutBuffer = ByteBuffer.allocate(DEFAULT_NET_BUFFER_SIZE);
}
handshakeReadCompletionHandler = new HandshakeReadCompletionHandler();
handshakeWriteCompletionHandler = new HandshakeWriteCompletionHandler();
}
/**
 * Obtains a new {@link SSLEngine} from the endpoint, configured from what the client offered in its ClientHello,
 * and stores it in {@code sslEngine}.
 */
protected void createSSLEngine(String hostName, List<Cipher> clientRequestedCiphers,
        List<String> clientRequestedApplicationProtocols, List<String> clientRequestedProtocols,
        List<Group> clientSupportedGroups, List<SignatureScheme> clientSignatureSchemes) {
    SSLEngine engine = endpoint.createSSLEngine(hostName, clientRequestedCiphers,
            clientRequestedApplicationProtocols, clientRequestedProtocols, clientSupportedGroups,
            clientSignatureSchemes);
    sslEngine = engine;
}
/**
 * Completion handler for handshake reads: dispatches OPEN_READ on success and ERROR on failure or end of stream.
 */
private class HandshakeReadCompletionHandler implements CompletionHandler<Integer,SocketWrapperBase<Nio2Channel>> {
    @Override
    public void completed(Integer result, SocketWrapperBase<Nio2Channel> attachment) {
        // A negative byte count means the peer closed the connection
        if (result.intValue() >= 0) {
            endpoint.processSocket(attachment, SocketEvent.OPEN_READ, false);
            return;
        }
        failed(new EOFException(), attachment);
    }

    @Override
    public void failed(Throwable exc, SocketWrapperBase<Nio2Channel> attachment) {
        endpoint.processSocket(attachment, SocketEvent.ERROR, false);
    }
}
/**
 * Completion handler for handshake writes: dispatches OPEN_WRITE on success and ERROR on failure or end of stream.
 */
private class HandshakeWriteCompletionHandler implements CompletionHandler<Integer,SocketWrapperBase<Nio2Channel>> {
    @Override
    public void completed(Integer result, SocketWrapperBase<Nio2Channel> attachment) {
        // A negative byte count means the channel reached end of stream
        if (result.intValue() >= 0) {
            endpoint.processSocket(attachment, SocketEvent.OPEN_WRITE, false);
            return;
        }
        failed(new EOFException(), attachment);
    }

    @Override
    public void failed(Throwable exc, SocketWrapperBase<Nio2Channel> attachment) {
        endpoint.processSocket(attachment, SocketEvent.ERROR, false);
    }
}
/**
 * Prepares this channel for reuse with a new underlying socket, discarding all TLS and close state.
 */
@Override
public void reset(AsynchronousSocketChannel channel, SocketWrapperBase<Nio2Channel> socket) throws IOException {
    super.reset(channel, socket);
    // TLS engine and handshake progress
    sslEngine = null;
    sniComplete = false;
    handshakeComplete = false;
    handshakeWrapQueueLength.set(0);
    // Read and close bookkeeping
    unwrapBeforeRead = true;
    closed = false;
    closing = false;
    // Drop any encrypted bytes left over from the previous connection
    netInBuffer.clear();
}
/**
 * Releases resources; when direct SSL buffers are configured, both network buffers are explicitly cleaned.
 */
@Override
public void free() {
    super.free();
    if (!endpoint.getSocketProperties().getDirectSslBuffer()) {
        return;
    }
    ByteBufferUtils.cleanDirectBuffer(netInBuffer);
    ByteBufferUtils.cleanDirectBuffer(netOutBuffer);
}
/**
 * Future for a single write of {@code netOutBuffer}. Resolves to {@code true} unless the write reported end of
 * stream. If the write could not be started, the future reports done/cancelled and {@code get} throws.
 */
private class FutureFlush implements Future<Boolean> {
    // The underlying write; stays null if starting it failed
    private Future<Integer> integer;
    // Failure raised while starting the write, if any
    private Exception e = null;

    protected FutureFlush() {
        try {
            integer = sc.write(netOutBuffer);
        } catch (IllegalStateException startFailure) {
            e = startFailure;
        }
    }

    private boolean startFailed() {
        return e != null;
    }

    @Override
    public boolean cancel(boolean mayInterruptIfRunning) {
        if (startFailed()) {
            return true;
        }
        return integer.cancel(mayInterruptIfRunning);
    }

    @Override
    public boolean isCancelled() {
        if (startFailed()) {
            return true;
        }
        return integer.isCancelled();
    }

    @Override
    public boolean isDone() {
        if (startFailed()) {
            return true;
        }
        return integer.isDone();
    }

    @Override
    public Boolean get() throws InterruptedException, ExecutionException {
        if (startFailed()) {
            throw new ExecutionException(e);
        }
        int bytesWritten = integer.get().intValue();
        return Boolean.valueOf(bytesWritten >= 0);
    }

    @Override
    public Boolean get(long timeout, TimeUnit unit)
            throws InterruptedException, ExecutionException, TimeoutException {
        if (startFailed()) {
            throw new ExecutionException(e);
        }
        int bytesWritten = integer.get(timeout, unit).intValue();
        return Boolean.valueOf(bytesWritten >= 0);
    }
}
/**
 * Starts writing the contents of the network output buffer.
 *
 * @return a future that resolves to <code>true</code> if the network buffer has been flushed out and is empty,
 *         else <code>false</code>
 */
@Override
public Future<Boolean> flush() {
    final FutureFlush pendingFlush = new FutureFlush();
    return pendingFlush;
}
/**
 * Performs the SSL handshake without blocking on network I/O, although delegated NEED_TASK work runs on the
 * calling thread. For that reason this must never be called from the Acceptor thread, since doing so would slow
 * the whole system down.
 * <p>
 * When the handshake cannot complete yet, the required network read or write has already been issued with a
 * CompletionHandler that re-dispatches the socket once it finishes.
 *
 * @return 0 if the handshake is complete, a negative value if the socket needs to close, and a positive value if
 *         the handshake is still in progress
 *
 * @throws IOException if an error occurs during the handshake
 */
@Override
public int handshake() throws IOException {
    final boolean useCompletionHandlers = true;
    return handshakeInternal(useCompletionHandlers);
}
/**
 * Drives the TLS handshake as far as it can go with the data currently available.
 * <p>
 * On the first call this also runs SNI processing, which creates the {@link SSLEngine}.
 *
 * @param async {@code true} to issue any required network I/O with the handshake completion handlers and return
 *                  immediately; {@code false} to block on each network operation. Blocking operations are
 *                  bounded by the endpoint connection timeout when it is positive.
 *
 * @return 0 if the handshake is complete, negative if the socket needs to close and positive if the handshake is
 *             incomplete and network I/O has been issued (or, when blocking, performed)
 *
 * @throws IOException if an error occurs during the handshake. For blocking I/O failures the original
 *                         exception is preserved as the cause.
 */
protected int handshakeInternal(boolean async) throws IOException {
    if (handshakeComplete) {
        return 0; // we have done our initial handshake
    }

    if (!sniComplete) {
        int sniResult = processSNI();
        if (sniResult == 0) {
            sniComplete = true;
        } else {
            return sniResult;
        }
    }

    SSLEngineResult handshake;
    long timeout = endpoint.getConnectionTimeout();

    while (!handshakeComplete) {
        switch (handshakeStatus) {
            case NOT_HANDSHAKING: {
                // should never happen
                throw new IOException(sm.getString("channel.nio.ssl.notHandshaking"));
            }
            case FINISHED: {
                if (endpoint.hasNegotiableProtocols()) {
                    if (sslEngine instanceof SSLUtil.ProtocolInfo) {
                        socketWrapper.setNegotiatedProtocol(
                                ((SSLUtil.ProtocolInfo) sslEngine).getNegotiatedProtocol());
                    } else {
                        socketWrapper.setNegotiatedProtocol(sslEngine.getApplicationProtocol());
                    }
                }
                // we are complete if we have delivered the last package
                handshakeComplete = !netOutBuffer.hasRemaining();
                if (handshakeComplete) {
                    return 0;
                }
                // otherwise we still have handshake data to write
                writeHandshakeData(async, timeout);
                return 1;
            }
            case NEED_WRAP: {
                // perform the wrap function
                try {
                    handshake = handshakeWrap();
                } catch (SSLException e) {
                    // Wrap once more before rethrowing (presumably so the engine can place any outbound alert in
                    // netOutBuffer). A failure of this second wrap must not hide the original exception.
                    try {
                        handshakeWrap();
                    } catch (SSLException secondFailure) {
                        e.addSuppressed(secondFailure);
                    }
                    throw e;
                }
                if (handshake.getStatus() == Status.OK) {
                    if (handshakeStatus == HandshakeStatus.NEED_TASK) {
                        handshakeStatus = tasks();
                    }
                } else if (handshake.getStatus() == Status.CLOSED) {
                    return -1;
                } else {
                    // wrap should always work with our buffers
                    throw new IOException(
                            sm.getString("channel.nio.ssl.unexpectedStatusDuringWrap", handshake.getStatus()));
                }
                if (handshakeStatus != HandshakeStatus.NEED_UNWRAP || netOutBuffer.remaining() > 0) {
                    // should actually return OP_READ if we have NEED_UNWRAP
                    writeHandshakeData(async, timeout);
                    return 1;
                }
                // fall down to NEED_UNWRAP on the same call, will result in a
                // BUFFER_UNDERFLOW if it needs data
            }
            //$FALL-THROUGH$
            case NEED_UNWRAP: {
                // perform the unwrap function
                handshake = handshakeUnwrap();
                if (handshake.getStatus() == Status.OK) {
                    if (handshakeStatus == HandshakeStatus.NEED_TASK) {
                        handshakeStatus = tasks();
                    }
                } else if (handshake.getStatus() == Status.BUFFER_UNDERFLOW) {
                    // read more data
                    readHandshakeData(async, timeout);
                    return 1;
                } else {
                    throw new IOException(
                            sm.getString("channel.nio.ssl.unexpectedStatusDuringUnwrap", handshake.getStatus()));
                }
                break;
            }
            case NEED_TASK: {
                handshakeStatus = tasks();
                break;
            }
            default:
                throw new IllegalStateException(sm.getString("channel.nio.ssl.invalidStatus", handshakeStatus));
        }
    }
    // return 0 if we are complete, otherwise recurse to process the task
    return handshakeComplete ? 0 : handshakeInternal(async);
}

/*
 * Writes the pending contents of netOutBuffer during the handshake, either asynchronously (completion handled by
 * handshakeWriteCompletionHandler) or blocking (bounded by the timeout when it is positive).
 */
private void writeHandshakeData(boolean async, long timeout) throws IOException {
    if (async) {
        sc.write(netOutBuffer, AbstractEndpoint.toTimeout(timeout), TimeUnit.MILLISECONDS, socketWrapper,
                handshakeWriteCompletionHandler);
        return;
    }
    try {
        if (timeout > 0) {
            sc.write(netOutBuffer).get(timeout, TimeUnit.MILLISECONDS);
        } else {
            sc.write(netOutBuffer).get();
        }
    } catch (InterruptedException e) {
        // Restore the interrupt status for code further up the stack
        Thread.currentThread().interrupt();
        throw new IOException(sm.getString("channel.nio.ssl.handshakeError"), e);
    } catch (ExecutionException | TimeoutException e) {
        throw new IOException(sm.getString("channel.nio.ssl.handshakeError"), e);
    }
}

/*
 * Reads more handshake data into netInBuffer, either asynchronously (completion handled by
 * handshakeReadCompletionHandler) or blocking (bounded by the timeout when it is positive). A blocking read that
 * hits end of stream throws EOFException.
 */
private void readHandshakeData(boolean async, long timeout) throws IOException {
    if (async) {
        sc.read(netInBuffer, AbstractEndpoint.toTimeout(timeout), TimeUnit.MILLISECONDS, socketWrapper,
                handshakeReadCompletionHandler);
        return;
    }
    int read;
    try {
        if (timeout > 0) {
            read = sc.read(netInBuffer).get(timeout, TimeUnit.MILLISECONDS).intValue();
        } else {
            read = sc.read(netInBuffer).get().intValue();
        }
    } catch (InterruptedException e) {
        // Restore the interrupt status for code further up the stack
        Thread.currentThread().interrupt();
        throw new IOException(sm.getString("channel.nio.ssl.handshakeError"), e);
    } catch (ExecutionException | TimeoutException e) {
        throw new IOException(sm.getString("channel.nio.ssl.handshakeError"), e);
    }
    if (read == -1) {
        throw new EOFException();
    }
}
/**
 * Peeks at the initial network bytes to determine if the SNI extension is present and, if it is, what host name has
 * been requested. Based on the provided host name (and the other ClientHello values the extractor reports), create
 * and configure the SSLEngine for this connection, size the buffers for its session and begin the handshake.
 *
 * @return 0 once the SSLEngine has been created and the handshake begun; 1 if more ClientHello data is needed, in
 *             which case an asynchronous read with handshakeReadCompletionHandler has already been issued
 *
 * @throws IOException if the client sent plain-text (non-TLS) data, or if the underlying engine setup fails
 */
protected int processSNI() throws IOException {
// If there is no data to process, trigger a read immediately. This is
// an optimisation for the typical case so we don't create an
// SNIExtractor only to discover there is no data to process
if (netInBuffer.position() == 0) {
sc.read(netInBuffer, AbstractEndpoint.toTimeout(endpoint.getConnectionTimeout()), TimeUnit.MILLISECONDS,
socketWrapper, handshakeReadCompletionHandler);
return 1;
}
TLSClientHelloExtractor extractor = new TLSClientHelloExtractor(netInBuffer);
if (extractor.getResult() == ExtractorResult.UNDERFLOW &&
netInBuffer.capacity() < endpoint.getSniParseLimit()) {
// extractor needed more data to process but netInBuffer was full so
// expand the buffer and read some more data.
// Growth doubles the capacity but never exceeds the configured SNI parse limit.
int newLimit = Math.min(netInBuffer.capacity() * 2, endpoint.getSniParseLimit());
log.info(sm.getString("channel.nio.ssl.expandNetInBuffer", Integer.toString(newLimit)));
netInBuffer = ByteBufferUtils.expand(netInBuffer, newLimit);
sc.read(netInBuffer, AbstractEndpoint.toTimeout(endpoint.getConnectionTimeout()), TimeUnit.MILLISECONDS,
socketWrapper, handshakeReadCompletionHandler);
return 1;
}
// Values gathered from the ClientHello; any that the extractor result does not provide stay null
String hostName = null;
List<Cipher> clientRequestedCiphers = null;
List<String> clientRequestedApplicationProtocols = null;
List<Group> clientSupportedGroups = null;
List<SignatureScheme> clientSignatureSchemes = null;
switch (extractor.getResult()) {
case COMPLETE:
hostName = extractor.getSNIValue();
socketWrapper.setSniHostName(hostName);
clientRequestedApplicationProtocols = extractor.getClientRequestedApplicationProtocols();
//$FALL-THROUGH$ to set the client requested ciphers
case NOT_PRESENT:
clientRequestedCiphers = extractor.getClientRequestedCiphers();
clientSupportedGroups = extractor.getClientSupportedGroups();
clientSignatureSchemes = extractor.getClientSignatureSchemes();
break;
case NEED_READ:
sc.read(netInBuffer, AbstractEndpoint.toTimeout(endpoint.getConnectionTimeout()), TimeUnit.MILLISECONDS,
socketWrapper, handshakeReadCompletionHandler);
return 1;
case UNDERFLOW:
// Unable to buffer enough data to read SNI extension data
// (reached only when the buffer is already at the SNI parse limit). Fall back to the default host
// config; groups and signature schemes are left null here.
if (log.isDebugEnabled()) {
log.debug(sm.getString("channel.nio.ssl.sniDefault"));
}
hostName = endpoint.getDefaultSSLHostConfigName();
clientRequestedCiphers = Collections.emptyList();
break;
case NON_SECURE:
// Plain-text request on a TLS port: queue a fixed response and fail the connection.
// NOTE(review): the Future returned by flush() is not awaited, so the response write may still be in
// progress (or fail silently) when the exception below is thrown.
netOutBuffer.clear();
netOutBuffer.put(TLSClientHelloExtractor.USE_TLS_RESPONSE);
netOutBuffer.flip();
flush();
throw new IOException(sm.getString("channel.nio.ssl.foundHttp"));
}
if (log.isTraceEnabled()) {
log.trace(sm.getString("channel.nio.ssl.sniHostName", sc, hostName));
}
createSSLEngine(hostName, clientRequestedCiphers, clientRequestedApplicationProtocols,
extractor.getClientRequestedProtocols(), clientSupportedGroups, clientSignatureSchemes);
// Populate additional TLS attributes obtained from the handshake that
// aren't available from the session
additionalTlsAttributes.put(SSLSupport.REQUESTED_PROTOCOL_VERSIONS_KEY,
extractor.getClientRequestedProtocols());
additionalTlsAttributes.put(SSLSupport.REQUESTED_CIPHERS_KEY, extractor.getClientRequestedCipherNames());
// Ensure the application buffers (which have to be created earlier) are
// big enough.
getBufHandler().expand(sslEngine.getSession().getApplicationBufferSize());
// NOTE(review): this check compares netOutBuffer against the *application* buffer size, while the expansion
// below uses the *packet* buffer size; confirm which size the log message is meant to track.
if (netOutBuffer.capacity() < sslEngine.getSession().getApplicationBufferSize()) {
// Info for now as we may need to increase DEFAULT_NET_BUFFER_SIZE
log.info(sm.getString("channel.nio.ssl.expandNetOutBuffer",
Integer.toString(sslEngine.getSession().getApplicationBufferSize())));
}
netInBuffer = ByteBufferUtils.expand(netInBuffer, sslEngine.getSession().getPacketBufferSize());
netOutBuffer = ByteBufferUtils.expand(netOutBuffer, sslEngine.getSession().getPacketBufferSize());
// Set limit and position to expected values
// (position == limit == 0 means "nothing to write")
netOutBuffer.position(0);
netOutBuffer.limit(0);
// Initiate handshake
sslEngine.beginHandshake();
handshakeStatus = sslEngine.getHandshakeStatus();
return 0;
}
/**
 * Forces a blocking handshake to take place for this key. Both the network and the application buffers must be
 * empty before this is called; otherwise an IOException is thrown. Any failure during the handshake closes the
 * channel silently before the exception is propagated.
 *
 * @throws IOException                     - if an IO exception occurs or if application or network buffers
 *                                             contain data
 * @throws java.net.SocketTimeoutException - if a socket operation timed out
 */
public void rehandshake() throws IOException {
    // Refuse to renegotiate while any buffer still holds unprocessed data
    if (netInBuffer.position() > 0 && netInBuffer.position() < netInBuffer.limit()) {
        throw new IOException(sm.getString("channel.nio.ssl.netInputNotEmpty"));
    }
    if (netOutBuffer.position() > 0 && netOutBuffer.position() < netOutBuffer.limit()) {
        throw new IOException(sm.getString("channel.nio.ssl.netOutputNotEmpty"));
    }
    if (!getBufHandler().isReadBufferEmpty()) {
        throw new IOException(sm.getString("channel.nio.ssl.appInputNotEmpty"));
    }
    if (!getBufHandler().isWriteBufferEmpty()) {
        throw new IOException(sm.getString("channel.nio.ssl.appOutputNotEmpty"));
    }

    // Start from empty buffers
    netOutBuffer.position(0);
    netOutBuffer.limit(0);
    netInBuffer.position(0);
    netInBuffer.limit(0);
    getBufHandler().reset();

    handshakeComplete = false;
    sslEngine.beginHandshake();
    handshakeStatus = sslEngine.getHandshakeStatus();

    try {
        // Keep driving the blocking handshake until it reports completion (0); -1 means the peer went away
        int progress;
        do {
            progress = handshakeInternal(false);
            if (progress == -1) {
                throw new EOFException(sm.getString("channel.nio.ssl.eofDuringHandshake"));
            }
        } while (progress != 0);
    } catch (IOException ioe) {
        closeSilently();
        throw ioe;
    } catch (Exception other) {
        closeSilently();
        throw new IOException(other);
    }
}
/**
 * Runs every delegated task the engine currently offers, on the calling thread.
 *
 * @return the engine's handshake status after all tasks have run
 */
protected SSLEngineResult.HandshakeStatus tasks() {
    for (Runnable task = sslEngine.getDelegatedTask(); task != null; task = sslEngine.getDelegatedTask()) {
        task.run();
    }
    return sslEngine.getHandshakeStatus();
}
/**
 * Wraps handshake data from the application write buffer into {@code netOutBuffer}, leaving {@code netOutBuffer}
 * ready to be written and updating {@code handshakeStatus}.
 *
 * @return the result of the wrap
 *
 * @throws IOException An IO error occurred
 */
protected SSLEngineResult handshakeWrap() throws IOException {
    // The network buffer is expected to be empty whenever this is called, so it can safely be cleared
    netOutBuffer.clear();
    getBufHandler().configureWriteBufferForRead();
    ByteBuffer appSource = getBufHandler().getWriteBuffer();
    SSLEngineResult wrapResult = sslEngine.wrap(appSource, netOutBuffer);
    // Switch to "read from" mode so the produced bytes can be written out
    netOutBuffer.flip();
    handshakeStatus = wrapResult.getHandshakeStatus();
    return wrapResult;
}
/**
 * Unwraps buffered handshake data from {@code netInBuffer}. It runs delegated tasks as they come up and keeps
 * unwrapping while the engine returns OK and still needs to unwrap.
 *
 * @return the result of the last unwrap
 *
 * @throws IOException An IO error occurred
 */
protected SSLEngineResult handshakeUnwrap() throws IOException {
    while (true) {
        netInBuffer.flip();
        getBufHandler().configureReadBufferForWrite();
        SSLEngineResult result = sslEngine.unwrap(netInBuffer, getBufHandler().getReadBuffer());
        /*
         * ByteBuffer.compact() is an optional method but netInBuffer is created from either ByteBuffer.allocate()
         * or ByteBuffer.allocateDirect() and the ByteBuffers returned by those methods do implement compact(). The
         * ByteBuffer must be in 'read from' mode when compact() is called and will be in 'write to' mode
         * afterwards.
         */
        netInBuffer.compact();
        handshakeStatus = result.getHandshakeStatus();
        boolean statusOk = result.getStatus() == SSLEngineResult.Status.OK;
        if (statusOk && result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
            handshakeStatus = tasks();
        }
        // Stop unless the engine is happy and wants yet another unwrap
        if (!statusOk || handshakeStatus != HandshakeStatus.NEED_UNWRAP) {
            return result;
        }
    }
}
/**
 * Returns SSL support information for the current engine's session, including the ClientHello-derived attributes.
 *
 * @return the SSL support object, or {@code null} if no engine has been created yet
 */
public SSLSupport getSSLSupport() {
    if (sslEngine == null) {
        return null;
    }
    SSLSession session = sslEngine.getSession();
    return endpoint.getSslImplementation().getSSLSupport(session, additionalTlsAttributes);
}
/**
 * Sends an SSL close message; this does not physically close the connection.<br>
 * To close the connection, you could do something like
 *
 * <pre>
 * <code>
 * close();
 * while (isOpen() &amp;&amp; !myTimeoutFunction()) Thread.sleep(25);
 * if ( isOpen() ) close(true); //forces a close if you timed out
 * </code>
 * </pre>
 *
 * Any pending outbound data is flushed first. The close message is then wrapped and flushed. Each flush waits up
 * to the endpoint connection timeout when that timeout is positive, and otherwise waits indefinitely.
 *
 * @throws IOException if an I/O error occurs, if the engine is not in the CLOSED state after wrapping the close
 *                         message, or if there is data on the outgoing network buffer that cannot be flushed (in
 *                         that case the channel is force-closed first)
 */
@Override
public void close() throws IOException {
    if (closing) {
        return;
    }
    closing = true;
    if (sslEngine == null) {
        netOutBuffer.clear();
        closed = true;
        return;
    }
    sslEngine.closeOutbound();

    long timeout = endpoint.getConnectionTimeout();

    // Flush whatever was already pending before the close message is produced
    flushDuringClose(timeout);
    // prep the buffer for the close message
    netOutBuffer.clear();
    // perform the close, since we called sslEngine.closeOutbound
    SSLEngineResult handshake = sslEngine.wrap(getEmptyBuf(), netOutBuffer);
    // we should be in a close state
    if (handshake.getStatus() != SSLEngineResult.Status.CLOSED) {
        throw new IOException(sm.getString("channel.nio.ssl.invalidCloseState"));
    }
    // prepare the buffer for writing
    netOutBuffer.flip();
    // write the close message
    flushDuringClose(timeout);

    // is the channel closed?
    closed = (!netOutBuffer.hasRemaining() && (handshake.getHandshakeStatus() != HandshakeStatus.NEED_WRAP));
}

/*
 * Flushes netOutBuffer as part of close(). Waits up to timeout ms when timeout is positive, otherwise waits
 * indefinitely. On any failure the channel is force-closed and an IOException is thrown that preserves the cause
 * where there is one.
 */
private void flushDuringClose(long timeout) throws IOException {
    boolean flushed;
    try {
        if (timeout > 0) {
            flushed = flush().get(timeout, TimeUnit.MILLISECONDS).booleanValue();
        } else {
            flushed = flush().get().booleanValue();
        }
    } catch (InterruptedException e) {
        closeSilently();
        // Restore the interrupt status for code further up the stack
        Thread.currentThread().interrupt();
        throw new IOException(sm.getString("channel.nio.ssl.remainingDataDuringClose"), e);
    } catch (ExecutionException | TimeoutException e) {
        closeSilently();
        throw new IOException(sm.getString("channel.nio.ssl.remainingDataDuringClose"), e);
    } catch (WritePendingException e) {
        closeSilently();
        throw new IOException(sm.getString("channel.nio.ssl.pendingWriteDuringClose"), e);
    }
    if (!flushed) {
        closeSilently();
        throw new IOException(sm.getString("channel.nio.ssl.remainingDataDuringClose"));
    }
}
/**
 * Performs the TLS close. The underlying socket is then closed when {@code force} is set or when the TLS close
 * completed. This happens even if {@link #close()} throws.
 *
 * @param force {@code true} to close the socket regardless of the TLS close outcome
 *
 * @throws IOException if the TLS close fails
 */
@Override
public void close(boolean force) throws IOException {
    try {
        close();
    } finally {
        boolean closeSocket = force || closed;
        if (closeSocket) {
            closed = true;
            sc.close();
        }
    }
}
/*
 * Force-closes the channel and swallows any IOException, logging it at debug level.
 */
private void closeSilently() {
    try {
        close(true);
    } catch (IOException swallowed) {
        // Swallowing is the whole point of this method; debug logging keeps the detail available
        log.debug(sm.getString("channel.nio.ssl.closeSilentError"), swallowed);
    }
}
/**
 * {@link Future} returned by {@link #read(ByteBuffer)}. If {@code unwrapBeforeRead} is set or {@code netInBuffer}
 * already holds data, no network read is started up front and the buffered data is unwrapped when {@code get} is
 * called. Otherwise a network read into {@code netInBuffer} is started immediately.
 */
private class FutureRead implements Future<Integer> {
    // Destination for decrypted data; replaced if its backing buffer has to be expanded
    private ByteBuffer dst;
    // Pending network read, or null when already-buffered data is unwrapped first
    private Future<Integer> integer;

    private FutureRead(ByteBuffer dst) {
        this.dst = dst;
        if (unwrapBeforeRead || netInBuffer.position() > 0) {
            this.integer = null;
        } else {
            this.integer = sc.read(netInBuffer);
        }
    }

    @Override
    public boolean cancel(boolean mayInterruptIfRunning) {
        return integer != null && integer.cancel(mayInterruptIfRunning);
    }

    @Override
    public boolean isCancelled() {
        return integer != null && integer.isCancelled();
    }

    @Override
    public boolean isDone() {
        return integer == null || integer.isDone();
    }

    @Override
    public Integer get() throws InterruptedException, ExecutionException {
        try {
            return (integer == null) ? unwrap(netInBuffer.position(), -1, TimeUnit.MILLISECONDS) :
                    unwrap(integer.get().intValue(), -1, TimeUnit.MILLISECONDS);
        } catch (TimeoutException e) {
            // Cannot happen: no timeout
            throw new ExecutionException(e);
        }
    }

    @Override
    public Integer get(long timeout, TimeUnit unit)
            throws InterruptedException, ExecutionException, TimeoutException {
        return (integer == null) ? unwrap(netInBuffer.position(), timeout, unit) :
                unwrap(integer.get(timeout, unit).intValue(), timeout, unit);
    }

    /*
     * Unwraps as much of netInBuffer into dst as possible. If nothing could be produced because more network data
     * is needed, it reads again (with the timeout if positive, otherwise without one) and retries. Returns the
     * number of decrypted bytes, or -1 if the channel is closing/closed or end of stream was reached.
     */
    private Integer unwrap(int nRead, long timeout, TimeUnit unit)
            throws ExecutionException, TimeoutException, InterruptedException {
        // are we in the middle of closing or closed?
        if (closing || closed) {
            return Integer.valueOf(-1);
        }
        // did we reach EOF? if so send EOF up one layer.
        if (nRead < 0) {
            return Integer.valueOf(-1);
        }
        // the data read
        int read = 0;
        // the SSL engine result
        SSLEngineResult unwrap;
        do {
            // prepare the buffer
            netInBuffer.flip();
            // unwrap the data
            try {
                unwrap = sslEngine.unwrap(netInBuffer, dst);
            } catch (SSLException e) {
                throw new ExecutionException(e);
            }
            // compact the buffer
            netInBuffer.compact();
            if (unwrap.getStatus() == Status.OK || unwrap.getStatus() == Status.BUFFER_UNDERFLOW) {
                // we did receive some data, add it to our total
                read += unwrap.bytesProduced();
                // perform any tasks if needed
                if (unwrap.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                    tasks();
                } else if (unwrap.getHandshakeStatus() == HandshakeStatus.NEED_WRAP) {
                    if (handshakeWrapQueueLength.incrementAndGet() > HANDSHAKE_WRAP_QUEUE_LENGTH_LIMIT) {
                        throw new ExecutionException(
                                new IOException(sm.getString("channel.nio.ssl.handshakeWrapQueueTooLong")));
                    }
                }
                // if we need more network data, then bail out for now.
                if (unwrap.getStatus() == Status.BUFFER_UNDERFLOW) {
                    if (read == 0) {
                        integer = sc.read(netInBuffer);
                        if (timeout > 0) {
                            return unwrap(integer.get(timeout, unit).intValue(), timeout, unit);
                        } else {
                            return unwrap(integer.get().intValue(), -1, TimeUnit.MILLISECONDS);
                        }
                    } else {
                        break;
                    }
                }
            } else if (unwrap.getStatus() == Status.BUFFER_OVERFLOW) {
                if (read > 0) {
                    // Buffer overflow can happen if we have read data. Return
                    // so the destination buffer can be emptied before another
                    // read is attempted
                    break;
                } else {
                    // The SSL session has increased the required buffer size
                    // since the buffer was created.
                    if (dst == getBufHandler().getReadBuffer()) {
                        // This is the normal case for this code
                        getBufHandler().expand(sslEngine.getSession().getApplicationBufferSize());
                        dst = getBufHandler().getReadBuffer();
                    } else if (getAppReadBufHandler() != null && dst == getAppReadBufHandler().getByteBuffer()) {
                        // Null check matches the CompletionHandler-based read; without it a missing
                        // application read buffer handler caused an NPE instead of the resize error below.
                        getAppReadBufHandler().expand(sslEngine.getSession().getApplicationBufferSize());
                        dst = getAppReadBufHandler().getByteBuffer();
                    } else {
                        // Can't expand the buffer as there is no way to signal
                        // to the caller that the buffer has been replaced.
                        throw new ExecutionException(new IOException(
                                sm.getString("channel.nio.ssl.unwrapFailResize", unwrap.getStatus())));
                    }
                }
            } else {
                // Something else went wrong
                throw new ExecutionException(
                        new IOException(sm.getString("channel.nio.ssl.unwrapFail", unwrap.getStatus())));
            }
        } while (netInBuffer.position() != 0); // continue to unwrapping as long as the input buffer has stuff
        unwrapBeforeRead = !dst.hasRemaining();
        return Integer.valueOf(read);
    }
}
/**
 * Reads and decrypts a sequence of bytes from this channel into the given buffer.
 *
 * @param dst The buffer into which bytes are to be transferred
 *
 * @return a future holding the number of bytes read, possibly zero, or <code>-1</code> if the channel has
 *         reached end-of-stream
 *
 * @throws IllegalStateException if the handshake was not completed
 */
@Override
public Future<Integer> read(ByteBuffer dst) {
    if (handshakeComplete) {
        return new FutureRead(dst);
    }
    throw new IllegalStateException(sm.getString("channel.nio.ssl.incompleteHandshake"));
}
/**
 * {@link Future} returned by {@link #write(ByteBuffer)}. Unless the channel is closing, the constructor wraps data
 * from {@code src} into {@code netOutBuffer} and starts the network write. {@code get} keeps writing until
 * {@code netOutBuffer} is drained and returns the number of bytes the last wrap consumed from {@code src}.
 */
private class FutureWrite implements Future<Integer> {
    private final ByteBuffer src;
    // Current network write; null when a failure was recorded before any write was started
    private Future<Integer> integer = null;
    // Bytes consumed from src by the most recent wrap
    private int written = 0;
    // Failure recorded while constructing or wrapping; get() rethrows it wrapped in an ExecutionException
    private Throwable t = null;

    private FutureWrite(ByteBuffer src) {
        this.src = src;
        // are we closing or closed?
        if (closing || closed) {
            t = new IOException(sm.getString("channel.nio.ssl.closing"));
        } else {
            wrap();
        }
    }

    @Override
    public boolean cancel(boolean mayInterruptIfRunning) {
        // integer is null if the write failed before a network write was started
        return integer != null && integer.cancel(mayInterruptIfRunning);
    }

    @Override
    public boolean isCancelled() {
        return integer != null && integer.isCancelled();
    }

    @Override
    public boolean isDone() {
        // Without a network write the outcome (the recorded failure) is already known
        return integer == null || integer.isDone();
    }

    @Override
    public Integer get() throws InterruptedException, ExecutionException {
        if (t != null) {
            throw new ExecutionException(t);
        }
        if (integer.get().intValue() > 0 && written == 0) {
            wrap();
            return get();
        } else if (netOutBuffer.hasRemaining()) {
            integer = sc.write(netOutBuffer);
            return get();
        } else {
            return Integer.valueOf(written);
        }
    }

    @Override
    public Integer get(long timeout, TimeUnit unit)
            throws InterruptedException, ExecutionException, TimeoutException {
        if (t != null) {
            throw new ExecutionException(t);
        }
        if (integer.get(timeout, unit).intValue() > 0 && written == 0) {
            wrap();
            return get(timeout, unit);
        } else if (netOutBuffer.hasRemaining()) {
            integer = sc.write(netOutBuffer);
            return get(timeout, unit);
        } else {
            return Integer.valueOf(written);
        }
    }

    /*
     * Wraps more of src if netOutBuffer has been drained, then starts writing netOutBuffer. Engine failures are
     * recorded in t rather than thrown.
     */
    protected void wrap() {
        try {
            if (!netOutBuffer.hasRemaining()) {
                netOutBuffer.clear();
                SSLEngineResult result = sslEngine.wrap(src, netOutBuffer);
                // Call to wrap() will have included any required handshake data
                handshakeWrapQueueLength.set(0);
                written = result.bytesConsumed();
                netOutBuffer.flip();
                if (result.getStatus() == Status.OK) {
                    if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                        tasks();
                    }
                } else {
                    t = new IOException(sm.getString("channel.nio.ssl.wrapFail", result.getStatus()));
                }
            }
            integer = sc.write(netOutBuffer);
        } catch (SSLException e) {
            t = e;
        }
    }
}
/**
 * Encrypts and writes a sequence of bytes to this channel from the given buffer.
 *
 * @param src The buffer from which bytes are to be retrieved
 *
 * @return a future holding the number of bytes written, possibly zero
 */
@Override
public Future<Integer> write(ByteBuffer src) {
    // Wrapping and the first network write are started by the FutureWrite constructor
    final FutureWrite pendingWrite = new FutureWrite(src);
    return pendingWrite;
}
/**
 * Asynchronously reads and decrypts data into {@code dst}, reporting the number of decrypted bytes to
 * {@code handler}.
 * <p>
 * If the channel is closing or closed, {@code handler} is completed immediately with -1. If {@code unwrapBeforeRead}
 * is set or {@code netInBuffer} already holds data, the unwrap is attempted immediately on the calling thread;
 * otherwise a network read into {@code netInBuffer} is issued first. End of stream and unwrap errors are reported
 * through {@code handler.failed}.
 *
 * @throws IllegalStateException if the handshake has not completed
 */
@Override
public <A> void read(final ByteBuffer dst, final long timeout, final TimeUnit unit, final A attachment,
final CompletionHandler<Integer,? super A> handler) {
// Check state
if (closing || closed) {
handler.completed(Integer.valueOf(-1), attachment);
return;
}
if (!handshakeComplete) {
throw new IllegalStateException(sm.getString("channel.nio.ssl.incompleteHandshake"));
}
// Invoked with the result of each network read (or synthetically with the buffered byte count, below)
CompletionHandler<Integer,A> readCompletionHandler = new CompletionHandler<>() {
@Override
public void completed(Integer nBytes, A attach) {
if (nBytes.intValue() < 0) {
failed(new EOFException(), attach);
} else {
try {
// May be replaced by a larger buffer if the session's application buffer size grew
ByteBuffer dst2 = dst;
// the data read
int read = 0;
// the SSL engine result
SSLEngineResult unwrap;
do {
// prepare the buffer
netInBuffer.flip();
// unwrap the data
unwrap = sslEngine.unwrap(netInBuffer, dst2);
// compact the buffer
netInBuffer.compact();
if (unwrap.getStatus() == Status.OK || unwrap.getStatus() == Status.BUFFER_UNDERFLOW) {
// we did receive some data, add it to our total
read += unwrap.bytesProduced();
// perform any tasks if needed
if (unwrap.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
tasks();
} else if (unwrap.getHandshakeStatus() == HandshakeStatus.NEED_WRAP) {
// NOTE(review): unlike the other failures in this handler, this one is wrapped in an
// ExecutionException before reaching handler.failed.
if (handshakeWrapQueueLength
.incrementAndGet() > HANDSHAKE_WRAP_QUEUE_LENGTH_LIMIT) {
throw new ExecutionException(new IOException(
sm.getString("channel.nio.ssl.handshakeWrapQueueTooLong")));
}
}
// if we need more network data, then bail out for now.
if (unwrap.getStatus() == Status.BUFFER_UNDERFLOW) {
if (read == 0) {
// Nothing decrypted yet: read more and re-enter this handler
sc.read(netInBuffer, timeout, unit, attachment, this);
return;
} else {
break;
}
}
} else if (unwrap.getStatus() == Status.BUFFER_OVERFLOW) {
if (read > 0) {
// Buffer overflow can happen if we have read data. Return
// so the destination buffer can be emptied before another
// read is attempted
break;
} else {
// The SSL session has increased the required buffer size
// since the buffer was created.
if (dst2 == getBufHandler().getReadBuffer()) {
// This is the normal case for this code
getBufHandler().expand(sslEngine.getSession().getApplicationBufferSize());
dst2 = getBufHandler().getReadBuffer();
} else if (getAppReadBufHandler() != null &&
dst2 == getAppReadBufHandler().getByteBuffer()) {
getAppReadBufHandler()
.expand(sslEngine.getSession().getApplicationBufferSize());
dst2 = getAppReadBufHandler().getByteBuffer();
} else {
// Can't expand the buffer as there is no way to signal
// to the caller that the buffer has been replaced.
throw new IOException(
sm.getString("channel.nio.ssl.unwrapFailResize", unwrap.getStatus()));
}
}
} else {
// Something else went wrong
throw new IOException(sm.getString("channel.nio.ssl.unwrapFail", unwrap.getStatus()));
}
// continue to unwrap as long as the input buffer has stuff
} while (netInBuffer.position() != 0);
// A full destination means buffered network data may remain; unwrap it before the next read
unwrapBeforeRead = !dst2.hasRemaining();
// If everything is OK, so complete
// NOTE(review): this call is inside the try, so an exception thrown by handler.completed would be
// caught below and also reported through handler.failed.
handler.completed(Integer.valueOf(read), attach);
} catch (Exception e) {
failed(e, attach);
}
}
}
@Override
public void failed(Throwable exc, A attach) {
handler.failed(exc, attach);
}
};
// Unwrap already-buffered data first when there may be some; otherwise wait for the network
if (unwrapBeforeRead || netInBuffer.position() > 0) {
readCompletionHandler.completed(Integer.valueOf(netInBuffer.position()), attachment);
} else {
sc.read(netInBuffer, timeout, unit, attachment, readCompletionHandler);
}
}
/**
 * Scattering read: reads TLS data from the network into {@code netInBuffer} and unwraps it into
 * {@code dsts[offset]} .. {@code dsts[offset + length - 1]}.
 * <p>
 * If the channel is closing or closed the handler is completed with {@code -1}. If network data is
 * already buffered in {@code netInBuffer}, or {@code unwrapBeforeRead} is set (a previous read left
 * the destination buffers full), unwrapping starts immediately; otherwise a network read is issued
 * first.
 *
 * @param dsts       the destination buffers
 * @param offset     index of the first destination buffer to use
 * @param length     number of destination buffers to use
 * @param timeout    timeout for any network read that is issued
 * @param unit       unit of {@code timeout}
 * @param attachment object passed back to {@code handler}
 * @param handler    completed with the number of application bytes produced into the caller's
 *                   buffers, or failed with the error ({@link EOFException} at end of stream)
 *
 * @throws IllegalArgumentException if {@code dsts} is {@code null} or {@code offset} /
 *                                  {@code length} do not describe a valid range of {@code dsts}
 * @throws IllegalStateException    if the TLS handshake has not completed
 */
@Override
public <A> void read(final ByteBuffer[] dsts, final int offset, final int length, final long timeout,
        final TimeUnit unit, final A attachment, final CompletionHandler<Long,? super A> handler) {
    // Reject bad ranges up front. The subtraction form cannot overflow (unlike offset + length),
    // and a negative length is refused here instead of surfacing later, asynchronously, from
    // SSLEngine.unwrap().
    if (dsts == null || offset < 0 || length < 0 || offset > dsts.length - length) {
        throw new IllegalArgumentException();
    }
    if (closing || closed) {
        handler.completed(Long.valueOf(-1), attachment);
        return;
    }
    if (!handshakeComplete) {
        throw new IllegalStateException(sm.getString("channel.nio.ssl.incompleteHandshake"));
    }
    CompletionHandler<Integer,A> readCompletionHandler = new CompletionHandler<>() {
        @Override
        public void completed(Integer nBytes, A attach) {
            if (nBytes.intValue() < 0) {
                failed(new EOFException(), attach);
            } else {
                try {
                    // the application data read
                    long read = 0;
                    // the SSL engine result
                    SSLEngineResult unwrap;
                    // Destinations actually handed to the engine: the caller's buffers, possibly
                    // extended with the main read buffer (see OverflowState)
                    ByteBuffer[] dsts2 = dsts;
                    int length2 = length;
                    OverflowState overflowState = OverflowState.NONE;
                    do {
                        if (overflowState == OverflowState.PROCESSING) {
                            overflowState = OverflowState.DONE;
                        }
                        // prepare the buffer
                        netInBuffer.flip();
                        // unwrap the data
                        unwrap = sslEngine.unwrap(netInBuffer, dsts2, offset, length2);
                        // compact the buffer
                        netInBuffer.compact();
                        if (unwrap.getStatus() == Status.OK || unwrap.getStatus() == Status.BUFFER_UNDERFLOW) {
                            // we did receive some data, add it to our total
                            read += unwrap.bytesProduced();
                            if (overflowState == OverflowState.DONE) {
                                // Remove the data read into the overflow buffer
                                read -= getBufHandler().getReadBuffer().position();
                            }
                            // perform any tasks if needed
                            if (unwrap.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                                tasks();
                            } else if (unwrap.getHandshakeStatus() == HandshakeStatus.NEED_WRAP) {
                                // Bound the number of handshake wraps left pending by reads
                                if (handshakeWrapQueueLength
                                        .incrementAndGet() > HANDSHAKE_WRAP_QUEUE_LENGTH_LIMIT) {
                                    throw new ExecutionException(new IOException(
                                            sm.getString("channel.nio.ssl.handshakeWrapQueueTooLong")));
                                }
                            }
                            // if we need more network data, then bail out for now.
                            if (unwrap.getStatus() == Status.BUFFER_UNDERFLOW) {
                                if (read == 0) {
                                    sc.read(netInBuffer, timeout, unit, attachment, this);
                                    return;
                                } else {
                                    break;
                                }
                            }
                        } else if (unwrap.getStatus() == Status.BUFFER_OVERFLOW && read > 0) {
                            // buffer overflow can happen, if we have read data, then
                            // empty out the dst buffer before we do another read
                            break;
                        } else if (unwrap.getStatus() == Status.BUFFER_OVERFLOW) {
                            // Nothing produced yet and the destinations are too small. First try
                            // to expand one of the caller's buffers that this channel manages;
                            // if none is in range, append the main read buffer and retry once.
                            ByteBuffer readBuffer = getBufHandler().getReadBuffer();
                            boolean found = false;
                            boolean resized = true;
                            // Search only the caller's range of dsts: length2 may be length + 1
                            // (read buffer appended to dsts2), which would index past that range.
                            for (int i = 0; i < length; i++) {
                                // The SSL session has increased the required buffer size
                                // since the buffer was created.
                                if (dsts[offset + i] == getBufHandler().getReadBuffer()) {
                                    getBufHandler().expand(sslEngine.getSession().getApplicationBufferSize());
                                    if (dsts[offset + i] == getBufHandler().getReadBuffer()) {
                                        resized = false;
                                    }
                                    dsts[offset + i] = getBufHandler().getReadBuffer();
                                    found = true;
                                } else if (getAppReadBufHandler() != null &&
                                        dsts[offset + i] == getAppReadBufHandler().getByteBuffer()) {
                                    getAppReadBufHandler()
                                            .expand(sslEngine.getSession().getApplicationBufferSize());
                                    if (dsts[offset + i] == getAppReadBufHandler().getByteBuffer()) {
                                        resized = false;
                                    }
                                    dsts[offset + i] = getAppReadBufHandler().getByteBuffer();
                                    found = true;
                                }
                            }
                            if (found) {
                                if (!resized) {
                                    throw new IOException(
                                            sm.getString("channel.nio.ssl.unwrapFail", unwrap.getStatus()));
                                }
                            } else if (overflowState == OverflowState.DONE) {
                                // The read buffer was already appended and the engine still
                                // overflowed; appending it again would retry forever.
                                throw new IOException(
                                        sm.getString("channel.nio.ssl.unwrapFail", unwrap.getStatus()));
                            } else {
                                // Add the main read buffer to the destinations, directly after
                                // the caller's range, and try again
                                final int insertAt = offset + length;
                                dsts2 = new ByteBuffer[dsts.length + 1];
                                System.arraycopy(dsts, 0, dsts2, 0, insertAt);
                                dsts2[insertAt] = readBuffer;
                                System.arraycopy(dsts, insertAt, dsts2, insertAt + 1, dsts.length - insertAt);
                                length2 = length + 1;
                                getBufHandler().configureReadBufferForWrite();
                                overflowState = OverflowState.PROCESSING;
                            }
                        } else if (unwrap.getStatus() == Status.CLOSED) {
                            break;
                        } else {
                            throw new IOException(sm.getString("channel.nio.ssl.unwrapFail", unwrap.getStatus()));
                        }
                        // Keep unwrapping while network data is buffered or a retry with the
                        // appended read buffer is pending, but stop once that retry has run
                    } while ((netInBuffer.position() != 0 || overflowState == OverflowState.PROCESSING) &&
                            overflowState != OverflowState.DONE);
                    int capacity = 0;
                    final int endOffset = offset + length;
                    for (int i = offset; i < endOffset; i++) {
                        capacity += dsts[i].remaining();
                    }
                    // Destinations full: the next read should unwrap before reading the network
                    unwrapBeforeRead = capacity == 0;
                    // If everything is OK, so complete
                    handler.completed(Long.valueOf(read), attach);
                } catch (Exception e) {
                    failed(e, attach);
                }
            }
        }
        @Override
        public void failed(Throwable exc, A attach) {
            handler.failed(exc, attach);
        }
    };
    if (unwrapBeforeRead || netInBuffer.position() > 0) {
        readCompletionHandler.completed(Integer.valueOf(netInBuffer.position()), attachment);
    } else {
        sc.read(netInBuffer, timeout, unit, attachment, readCompletionHandler);
    }
}
/**
 * Wraps as much of {@code src} as the engine accepts into {@code netOutBuffer} and writes the
 * resulting TLS data to the underlying channel.
 * <p>
 * The handler is failed with an {@link IOException} if the channel is closing or closed, or if
 * {@code wrap()} does not report {@code OK}. Once the network write has drained
 * {@code netOutBuffer}, the handler is completed with the number of bytes consumed from
 * {@code src}. If {@code wrap()} consumed nothing, the whole operation is started again.
 */
@Override
public <A> void write(final ByteBuffer src, final long timeout, final TimeUnit unit, final A attachment,
        final CompletionHandler<Integer,? super A> handler) {
    // No new writes once shutdown has begun
    if (closing || closed) {
        handler.failed(new IOException(sm.getString("channel.nio.ssl.closing")), attachment);
        return;
    }
    try {
        netOutBuffer.clear();
        final SSLEngineResult wrapResult = sslEngine.wrap(src, netOutBuffer);
        // wrap() has emitted any handshake data that was pending, so the queue is empty again
        handshakeWrapQueueLength.set(0);
        final int consumed = wrapResult.bytesConsumed();
        netOutBuffer.flip();
        if (wrapResult.getStatus() != Status.OK) {
            throw new IOException(sm.getString("channel.nio.ssl.wrapFail", wrapResult.getStatus()));
        }
        if (wrapResult.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
            tasks();
        }
        final CompletionHandler<Integer,A> flushHandler = new CompletionHandler<>() {
            @Override
            public void completed(Integer nBytes, A attach) {
                if (nBytes.intValue() < 0) {
                    failed(new EOFException(), attach);
                    return;
                }
                if (netOutBuffer.hasRemaining()) {
                    // Partial network write: keep flushing what wrap() produced
                    sc.write(netOutBuffer, timeout, unit, attachment, this);
                    return;
                }
                if (consumed == 0) {
                    // wrap() took no application bytes; restart rather than duplicate the logic
                    write(src, timeout, unit, attachment, handler);
                    return;
                }
                // Report how many bytes of src were consumed
                handler.completed(Integer.valueOf(consumed), attach);
            }
            @Override
            public void failed(Throwable exc, A attach) {
                handler.failed(exc, attach);
            }
        };
        sc.write(netOutBuffer, timeout, unit, attachment, flushHandler);
    } catch (Exception e) {
        handler.failed(e, attachment);
    }
}
/**
 * Gathering write: wraps {@code srcs[offset]} .. {@code srcs[offset + length - 1]} into
 * {@code netOutBuffer} with a single {@code wrap()} call and writes the resulting TLS data to the
 * underlying channel.
 * <p>
 * The handler is failed with an {@link IOException} if the channel is closing or closed, or if
 * {@code wrap()} does not report {@code OK}. Once the network write has drained
 * {@code netOutBuffer}, the handler is completed with the number of bytes consumed from the
 * sources. If {@code wrap()} consumed nothing, the whole operation is started again.
 *
 * @throws IndexOutOfBoundsException if {@code offset} / {@code length} do not describe a valid
 *                                   range of {@code srcs}
 */
@Override
public <A> void write(final ByteBuffer[] srcs, final int offset, final int length, final long timeout,
        final TimeUnit unit, final A attachment, final CompletionHandler<Long,? super A> handler) {
    if ((offset < 0) || (length < 0) || (offset > srcs.length - length)) {
        throw new IndexOutOfBoundsException();
    }
    // Check state
    if (closing || closed) {
        handler.failed(new IOException(sm.getString("channel.nio.ssl.closing")), attachment);
        return;
    }
    try {
        // Prepare the output buffer
        netOutBuffer.clear();
        // Wrap the source data into the internal buffer
        SSLEngineResult result = sslEngine.wrap(srcs, offset, length, netOutBuffer);
        // Call to wrap() will have included any required handshake data, so reset the counter
        // exactly as the single-buffer write does. Without this, the counter that reads increment
        // on NEED_WRAP is never cleared on this write path and reads eventually fail with
        // handshakeWrapQueueTooLong.
        handshakeWrapQueueLength.set(0);
        final int written = result.bytesConsumed();
        netOutBuffer.flip();
        if (result.getStatus() == Status.OK) {
            if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                tasks();
            }
            // Write data to the channel
            sc.write(netOutBuffer, timeout, unit, attachment, new CompletionHandler<>() {
                @Override
                public void completed(Integer nBytes, A attach) {
                    if (nBytes.intValue() < 0) {
                        failed(new EOFException(), attach);
                    } else if (netOutBuffer.hasRemaining()) {
                        // Partial network write: keep flushing
                        sc.write(netOutBuffer, timeout, unit, attachment, this);
                    } else if (written == 0) {
                        // Special case, start over to avoid code duplication
                        write(srcs, offset, length, timeout, unit, attachment, handler);
                    } else {
                        // Call the handler completed method with the
                        // consumed bytes number
                        handler.completed(Long.valueOf(written), attach);
                    }
                }
                @Override
                public void failed(Throwable exc, A attach) {
                    handler.failed(exc, attach);
                }
            });
        } else {
            throw new IOException(sm.getString("channel.nio.ssl.wrapFail", result.getStatus()));
        }
    } catch (Exception e) {
        handler.failed(e, attachment);
    }
}
/**
 * Reports the {@code handshakeComplete} flag. Reads and writes on this channel are refused with
 * an {@link IllegalStateException} while it is {@code false}.
 *
 * @return {@code true} once the TLS handshake has completed
 */
@Override
public boolean isHandshakeComplete() {
    final boolean complete = this.handshakeComplete;
    return complete;
}
/**
 * Reports the {@code closing} flag. While it is set, writes are failed and reads complete with
 * {@code -1}.
 *
 * @return {@code true} if the channel has been flagged as closing
 */
@Override
public boolean isClosing() {
    final boolean shuttingDown = this.closing;
    return shuttingDown;
}
/**
 * Exposes the engine this channel uses to wrap and unwrap TLS data.
 *
 * @return the {@link SSLEngine} held by this channel
 */
public SSLEngine getSslEngine() {
    final SSLEngine engine = this.sslEngine;
    return engine;
}
/**
 * Exposes the {@code emptyBuf} buffer instance held by this channel.
 *
 * @return the channel's {@code emptyBuf}
 */
public ByteBuffer getEmptyBuf() {
    final ByteBuffer empty = this.emptyBuf;
    return empty;
}
/**
 * Progress of the scattering read's fallback when the engine reports {@code BUFFER_OVERFLOW}
 * before any data has been produced and the main read buffer is appended to the destinations.
 */
private enum OverflowState {
    /** The caller's destination buffers are used unchanged. */
    NONE,
    /** The main read buffer has just been appended; the next unwrap will use it. */
    PROCESSING,
    /** The unwrap using the appended read buffer has been performed. */
    DONE
}
}