ttomcat-1778514358873.zip-extract/apache-tomcat-11.0.18-src/java/org/apache/tomcat/websocket/WsWebSocketContainer.java

Path
ttomcat-1778514358873.zip-extract/apache-tomcat-11.0.18-src/java/org/apache/tomcat/websocket/WsWebSocketContainer.java
Status
scanned
Type
file
Name
WsWebSocketContainer.java
Extension
.java
Programming language
Java
Mime type
text/plain
File type
ASCII text, with CRLF line terminators
Tag

      
    
Rootfs path

      
    
Size
44204 (43.2 KB)
MD5
783bcd9cf8dd5db21599ec1499fa1b17
SHA1
d437f534726cf613ade6b4868868838af29ec8a9
SHA256
a01dc9ff9d4218b3e7f6679a5adc516401145f5a4234a497df49f415a4e3ed65
SHA512

      
    
SHA1_git
8997c01c344f3bc1e2e2afd74d59a05f9831bb0e
Is binary

      
    
Is text
True
Is archive

      
    
Is media

      
    
Is legal

      
    
Is manifest

      
    
Is readme

      
    
Is top level

      
    
Is key file

      
    
WsWebSocketContainer.java | 43.2 KB |

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.tomcat.websocket; import java.io.EOFException; import java.io.IOException; import java.net.InetSocketAddress; import java.net.Proxy; import java.net.ProxySelector; import java.net.SocketAddress; import java.net.URI; import java.net.URISyntaxException; import java.nio.ByteBuffer; import java.nio.channels.AsynchronousChannelGroup; import java.nio.channels.AsynchronousSocketChannel; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.Random; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLException; import javax.net.ssl.SSLParameters; import jakarta.websocket.ClientEndpoint; import jakarta.websocket.ClientEndpointConfig; import jakarta.websocket.CloseReason; import jakarta.websocket.CloseReason.CloseCodes; import jakarta.websocket.DeploymentException; import jakarta.websocket.Endpoint; import jakarta.websocket.Extension; import jakarta.websocket.HandshakeResponse; import jakarta.websocket.Session; import jakarta.websocket.WebSocketContainer; import org.apache.juli.logging.Log; import org.apache.juli.logging.LogFactory; import org.apache.tomcat.InstanceManager; import org.apache.tomcat.InstanceManagerBindings; import org.apache.tomcat.util.buf.StringUtils; import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap; import org.apache.tomcat.util.res.StringManager; public class WsWebSocketContainer implements WebSocketContainer, BackgroundProcess { private static final StringManager sm = StringManager.getManager(WsWebSocketContainer.class); private static final Random RANDOM = new Random(); private static final byte[] CRLF = new byte[] { 13, 10 }; private static final byte[] GET_BYTES = "GET ".getBytes(StandardCharsets.ISO_8859_1); private static final byte[] ROOT_URI_BYTES = "/".getBytes(StandardCharsets.ISO_8859_1); private static final byte[] HTTP_VERSION_BYTES = " HTTP/1.1 ".getBytes(StandardCharsets.ISO_8859_1); private volatile AsynchronousChannelGroup asynchronousChannelGroup = null; private final Object asynchronousChannelGroupLock = new Object(); private final Log log = LogFactory.getLog(WsWebSocketContainer.class); // must not be static // Server side uses the endpoint path as the key // Client side uses the client endpoint instance private final Map<Object,Set<WsSession>> endpointSessionMap = new HashMap<>(); private final Map<WsSession,WsSession> sessions = new ConcurrentHashMap<>(); private final Object endPointSessionMapLock = new Object(); private long defaultAsyncTimeout = -1; private int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; private int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; private volatile long defaultMaxSessionIdleTimeout = 0; private int backgroundProcessCount = 0; private int processPeriod = Constants.DEFAULT_PROCESS_PERIOD; private InstanceManager instanceManager; protected InstanceManager getInstanceManager(ClassLoader classLoader) { if (instanceManager != null) { return instanceManager; } return InstanceManagerBindings.get(classLoader); } protected void setInstanceManager(InstanceManager instanceManager) { this.instanceManager = instanceManager; } @Override public Session connectToServer(Object pojo, URI path) throws DeploymentException { ClientEndpointConfig config = createClientEndpointConfig(pojo.getClass()); ClientEndpointHolder holder = new PojoHolder(pojo, config); return connectToServerRecursive(holder, config, path, new HashSet<>()); } @Override public Session connectToServer(Class<?> annotatedEndpointClass, URI path) throws DeploymentException { ClientEndpointConfig config = createClientEndpointConfig(annotatedEndpointClass); ClientEndpointHolder holder = new PojoClassHolder(annotatedEndpointClass, config); return connectToServerRecursive(holder, config, path, new HashSet<>()); } private ClientEndpointConfig createClientEndpointConfig(Class<?> annotatedEndpointClass) throws DeploymentException { ClientEndpoint annotation = annotatedEndpointClass.getAnnotation(ClientEndpoint.class); if (annotation == null) { throw new DeploymentException( sm.getString("wsWebSocketContainer.missingAnnotation", annotatedEndpointClass.getName())); } Class<? extends ClientEndpointConfig.Configurator> configuratorClazz = annotation.configurator(); ClientEndpointConfig.Configurator configurator = null; if (!ClientEndpointConfig.Configurator.class.equals(configuratorClazz)) { try { configurator = configuratorClazz.getConstructor().newInstance(); } catch (ReflectiveOperationException e) { throw new DeploymentException(sm.getString("wsWebSocketContainer.defaultConfiguratorFail"), e); } } ClientEndpointConfig.Builder builder = ClientEndpointConfig.Builder.create(); // Avoid NPE when using RI API JAR - see BZ 56343 if (configurator != null) { builder.configurator(configurator); } return builder.decoders(Arrays.asList(annotation.decoders())).encoders(Arrays.asList(annotation.encoders())) .preferredSubprotocols(Arrays.asList(annotation.subprotocols())).build(); } @Override public Session connectToServer(Class<? extends Endpoint> clazz, ClientEndpointConfig clientEndpointConfiguration, URI path) throws DeploymentException { ClientEndpointHolder holder = new EndpointClassHolder(clazz); return connectToServerRecursive(holder, clientEndpointConfiguration, path, new HashSet<>()); } @Override public Session connectToServer(Endpoint endpoint, ClientEndpointConfig clientEndpointConfiguration, URI path) throws DeploymentException { ClientEndpointHolder holder = new EndpointHolder(endpoint); return connectToServerRecursive(holder, clientEndpointConfiguration, path, new HashSet<>()); } private Session connectToServerRecursive(ClientEndpointHolder clientEndpointHolder, ClientEndpointConfig clientEndpointConfiguration, URI path, Set<URI> redirectSet) throws DeploymentException { if (log.isTraceEnabled()) { log.trace(sm.getString("wsWebSocketContainer.connect.entry", clientEndpointHolder.getClassName(), path)); } boolean secure = false; ByteBuffer proxyConnect = null; URI proxyPath; // Validate scheme (and build proxyPath) String scheme = path.getScheme(); if ("ws".equalsIgnoreCase(scheme)) { proxyPath = URI.create("http" + path.toString().substring(2)); } else if ("wss".equalsIgnoreCase(scheme)) { proxyPath = URI.create("https" + path.toString().substring(3)); secure = true; } else { throw new DeploymentException(sm.getString("wsWebSocketContainer.pathWrongScheme", scheme)); } // Validate host String host = path.getHost(); if (host == null) { throw new DeploymentException(sm.getString("wsWebSocketContainer.pathNoHost")); } int port = path.getPort(); SocketAddress sa = null; // Check to see if a proxy is configured. Javadoc indicates return value // will never be null List<Proxy> proxies = ProxySelector.getDefault().select(proxyPath); Proxy selectedProxy = null; for (Proxy proxy : proxies) { if (proxy.type().equals(Proxy.Type.HTTP)) { sa = proxy.address(); if (sa instanceof InetSocketAddress inet) { if (inet.isUnresolved()) { sa = new InetSocketAddress(inet.getHostName(), inet.getPort()); } } selectedProxy = proxy; break; } } // If the port is not explicitly specified, compute it based on the // scheme if (port == -1) { if ("ws".equalsIgnoreCase(scheme)) { port = 80; } else { // Must be wss due to scheme validation above port = 443; } } Map<String,Object> userProperties = clientEndpointConfiguration.getUserProperties(); // If sa is null, no proxy is configured so need to create sa if (sa == null) { sa = new InetSocketAddress(host, port); } else { proxyConnect = createProxyRequest(host, port, (String) userProperties.get(Constants.PROXY_AUTHORIZATION_HEADER_NAME)); } // Create the initial HTTP request to open the WebSocket connection Map<String,List<String>> reqHeaders = createRequestHeaders(host, port, secure, clientEndpointConfiguration); clientEndpointConfiguration.getConfigurator().beforeRequest(reqHeaders); if (Constants.DEFAULT_ORIGIN_HEADER_VALUE != null && !reqHeaders.containsKey(Constants.ORIGIN_HEADER_NAME)) { List<String> originValues = new ArrayList<>(1); originValues.add(Constants.DEFAULT_ORIGIN_HEADER_VALUE); reqHeaders.put(Constants.ORIGIN_HEADER_NAME, originValues); } ByteBuffer request = createRequest(path, reqHeaders); // Get the connection timeout long timeout = Constants.IO_TIMEOUT_MS_DEFAULT; String timeoutValue = (String) userProperties.get(Constants.IO_TIMEOUT_MS_PROPERTY); if (timeoutValue != null) { timeout = Long.valueOf(timeoutValue).intValue(); } AsynchronousSocketChannel socketChannel; try { socketChannel = AsynchronousSocketChannel.open(getAsynchronousChannelGroup()); } catch (IOException ioe) { throw new DeploymentException(sm.getString("wsWebSocketContainer.asynchronousSocketChannelFail"), ioe); } // Set-up // Same size as the WsFrame input buffer ByteBuffer response = ByteBuffer.allocate(getDefaultMaxBinaryMessageBufferSize()); String subProtocol; boolean success = false; List<Extension> extensionsAgreed = new ArrayList<>(); Transformation transformation = null; AsyncChannelWrapper channel = null; try { // Open the connection Future<Void> fConnect = socketChannel.connect(sa); if (proxyConnect != null) { fConnect.get(timeout, TimeUnit.MILLISECONDS); // Proxy CONNECT is clear text channel = new AsyncChannelWrapperNonSecure(socketChannel); writeRequest(channel, proxyConnect, timeout); HttpResponse httpResponse = processResponse(response, channel, timeout); if (httpResponse.status == Constants.PROXY_AUTHENTICATION_REQUIRED) { return processAuthenticationChallenge(clientEndpointHolder, clientEndpointConfiguration, path, redirectSet, userProperties, request, httpResponse, AuthenticationType.PROXY); } else if (httpResponse.status() != 200) { throw new DeploymentException(sm.getString("wsWebSocketContainer.proxyConnectFail", selectedProxy, Integer.toString(httpResponse.status()))); } } if (secure) { // Regardless of whether a non-secure wrapper was created for a // proxy CONNECT, need to use TLS from this point on so wrap the // original AsynchronousSocketChannel SSLEngine sslEngine = createSSLEngine(clientEndpointConfiguration, host, port); channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine); } else if (channel == null) { // Only need to wrap as this point if it wasn't wrapped to process a // proxy CONNECT channel = new AsyncChannelWrapperNonSecure(socketChannel); } fConnect.get(timeout, TimeUnit.MILLISECONDS); Future<Void> fHandshake = channel.handshake(); fHandshake.get(timeout, TimeUnit.MILLISECONDS); if (log.isTraceEnabled()) { SocketAddress localAddress = null; try { localAddress = channel.getLocalAddress(); } catch (IOException ioe) { // Ignore } log.trace(sm.getString("wsWebSocketContainer.connect.write", Integer.valueOf(request.position()), Integer.valueOf(request.limit()), localAddress)); } writeRequest(channel, request, timeout); HttpResponse httpResponse = processResponse(response, channel, timeout); // Check maximum permitted redirects int maxRedirects = Constants.MAX_REDIRECTIONS_DEFAULT; String maxRedirectsValue = (String) userProperties.get(Constants.MAX_REDIRECTIONS_PROPERTY); if (maxRedirectsValue != null) { maxRedirects = Integer.parseInt(maxRedirectsValue); } if (httpResponse.status != 101) { if (isRedirectStatus(httpResponse.status)) { List<String> locationHeader = httpResponse.handshakeResponse().getHeaders().get(Constants.LOCATION_HEADER_NAME); if (locationHeader == null || locationHeader.isEmpty() || locationHeader.get(0) == null || locationHeader.get(0).isEmpty()) { throw new DeploymentException(sm.getString("wsWebSocketContainer.missingLocationHeader", Integer.toString(httpResponse.status))); } URI redirectLocation = URI.create(locationHeader.get(0)).normalize(); if (!redirectLocation.isAbsolute()) { redirectLocation = path.resolve(redirectLocation); } String redirectScheme = redirectLocation.getScheme().toLowerCase(Locale.ENGLISH); if (redirectScheme.startsWith("http")) { redirectLocation = new URI(redirectScheme.replace("http", "ws"), redirectLocation.getUserInfo(), redirectLocation.getHost(), redirectLocation.getPort(), redirectLocation.getPath(), redirectLocation.getQuery(), redirectLocation.getFragment()); } if (!redirectSet.add(redirectLocation) || redirectSet.size() > maxRedirects) { throw new DeploymentException( sm.getString("wsWebSocketContainer.redirectThreshold", redirectLocation, Integer.toString(redirectSet.size()), Integer.toString(maxRedirects))); } return connectToServerRecursive(clientEndpointHolder, clientEndpointConfiguration, redirectLocation, redirectSet); } else if (httpResponse.status == Constants.UNAUTHORIZED) { return processAuthenticationChallenge(clientEndpointHolder, clientEndpointConfiguration, path, redirectSet, userProperties, request, httpResponse, AuthenticationType.WWW); } else { throw new DeploymentException( sm.getString("wsWebSocketContainer.invalidStatus", Integer.toString(httpResponse.status))); } } HandshakeResponse handshakeResponse = httpResponse.handshakeResponse(); clientEndpointConfiguration.getConfigurator().afterResponse(handshakeResponse); // Sub-protocol List<String> protocolHeaders = handshakeResponse.getHeaders().get(Constants.WS_PROTOCOL_HEADER_NAME); if (protocolHeaders == null || protocolHeaders.isEmpty()) { subProtocol = null; } else if (protocolHeaders.size() == 1) { subProtocol = protocolHeaders.get(0); } else { throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidSubProtocol")); } // Extensions // Should normally only be one header but handle the case of // multiple headers List<String> extHeaders = handshakeResponse.getHeaders().get(Constants.WS_EXTENSIONS_HEADER_NAME); if (extHeaders != null) { for (String extHeader : extHeaders) { Util.parseExtensionHeader(extensionsAgreed, extHeader); } } // Build the transformations TransformationFactory factory = TransformationFactory.getInstance(); for (Extension extension : extensionsAgreed) { List<List<Extension.Parameter>> wrapper = new ArrayList<>(1); wrapper.add(extension.getParameters()); Transformation t = factory.create(extension.getName(), wrapper, false); if (t == null) { throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidExtensionParameters")); } if (transformation == null) { transformation = t; } else { transformation.setNext(t); } } success = true; } catch (ExecutionException | InterruptedException | SSLException | EOFException | TimeoutException | URISyntaxException | AuthenticationException e) { throw new DeploymentException(sm.getString("wsWebSocketContainer.httpRequestFailed", path), e); } finally { if (!success) { if (channel != null) { channel.close(); } else { try { socketChannel.close(); } catch (IOException ioe) { // Ignore } } } } // Switch to WebSocket WsRemoteEndpointImplClient wsRemoteEndpointClient = new WsRemoteEndpointImplClient(channel); WsSession wsSession = new WsSession(clientEndpointHolder, wsRemoteEndpointClient, this, extensionsAgreed, subProtocol, Collections.emptyMap(), secure, clientEndpointConfiguration); WsFrameClient wsFrameClient = new WsFrameClient(response, channel, wsSession, transformation); // WsFrame adds the necessary final transformations. Copy the // completed transformation chain to the remote end point. wsRemoteEndpointClient.setTransformation(wsFrameClient.getTransformation()); wsSession.getLocal().onOpen(wsSession, clientEndpointConfiguration); registerSession(wsSession.getLocal(), wsSession); /* * It is possible that the server sent one or more messages as soon as the WebSocket connection was established. * Depending on the exact timing of when those messages were sent they could be sat in the input buffer waiting * to be read and will not trigger a "data available to read" event. Therefore, it is necessary to process the * input buffer here. Note that this happens on the current thread which means that this thread will be used for * any onMessage notifications. This is a special case. Subsequent "data available to read" events will be * handled by threads from the AsyncChannelGroup's executor. */ wsFrameClient.startInputProcessing(); return wsSession; } private Session processAuthenticationChallenge(ClientEndpointHolder clientEndpointHolder, ClientEndpointConfig clientEndpointConfiguration, URI path, Set<URI> redirectSet, Map<String,Object> userProperties, ByteBuffer request, HttpResponse httpResponse, AuthenticationType authenticationType) throws DeploymentException, AuthenticationException { if (userProperties.get(authenticationType.getAuthorizationHeaderName()) != null) { throw new DeploymentException(sm.getString("wsWebSocketContainer.failedAuthentication", Integer.valueOf(httpResponse.status), authenticationType.getAuthorizationHeaderName())); } List<String> authenticateHeaders = httpResponse.handshakeResponse().getHeaders().get(authenticationType.getAuthenticateHeaderName()); if (authenticateHeaders == null || authenticateHeaders.isEmpty() || authenticateHeaders.get(0) == null || authenticateHeaders.get(0).isEmpty()) { throw new DeploymentException(sm.getString("wsWebSocketContainer.missingAuthenticateHeader", Integer.toString(httpResponse.status), authenticationType.getAuthenticateHeaderName())); } String authScheme = authenticateHeaders.get(0).split("\\s+", 2)[0]; Authenticator auth = AuthenticatorFactory.getAuthenticator(authScheme); if (auth == null) { throw new DeploymentException(sm.getString("wsWebSocketContainer.unsupportedAuthScheme", Integer.valueOf(httpResponse.status), authScheme)); } String requestUri = new String(request.array(), StandardCharsets.ISO_8859_1).split("\\s", 3)[1]; userProperties.put(authenticationType.getAuthorizationHeaderName(), auth.getAuthorization(requestUri, authenticateHeaders.get(0), (String) userProperties.get(authenticationType.getUserNameProperty()), (String) userProperties.get(authenticationType.getUserPasswordProperty()), (String) userProperties.get(authenticationType.getUserRealmProperty()))); return connectToServerRecursive(clientEndpointHolder, clientEndpointConfiguration, path, redirectSet); } private static void writeRequest(AsyncChannelWrapper channel, ByteBuffer request, long timeout) throws TimeoutException, InterruptedException, ExecutionException { int toWrite = request.limit(); Future<Integer> fWrite = channel.write(request); Integer thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS); toWrite -= thisWrite.intValue(); while (toWrite > 0) { fWrite = channel.write(request); thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS); toWrite -= thisWrite.intValue(); } } private static boolean isRedirectStatus(int httpResponseCode) { boolean isRedirect = false; switch (httpResponseCode) { case Constants.MULTIPLE_CHOICES: case Constants.MOVED_PERMANENTLY: case Constants.FOUND: case Constants.SEE_OTHER: case Constants.USE_PROXY: case Constants.TEMPORARY_REDIRECT: isRedirect = true; break; default: break; } return isRedirect; } private static ByteBuffer createProxyRequest(String host, int port, String authorizationHeader) { StringBuilder request = new StringBuilder(); request.append("CONNECT "); request.append(host); request.append(':'); request.append(port); request.append(" HTTP/1.1 Proxy-Connection: keep-alive Connection: keepalive Host: "); request.append(host); request.append(':'); request.append(port); if (authorizationHeader != null) { request.append(" "); request.append(Constants.PROXY_AUTHORIZATION_HEADER_NAME); request.append(':'); request.append(authorizationHeader); } request.append(" "); byte[] bytes = request.toString().getBytes(StandardCharsets.ISO_8859_1); return ByteBuffer.wrap(bytes); } protected void registerSession(Object key, WsSession wsSession) { if (!wsSession.isOpen()) { // The session was closed during onOpen. No need to register it. return; } synchronized (endPointSessionMapLock) { if (endpointSessionMap.isEmpty()) { BackgroundProcessManager.getInstance().register(this); } endpointSessionMap.computeIfAbsent(key, k -> new HashSet<>()).add(wsSession); } sessions.put(wsSession, wsSession); } protected void unregisterSession(Object key, WsSession wsSession) { synchronized (endPointSessionMapLock) { Set<WsSession> wsSessions = endpointSessionMap.get(key); if (wsSessions != null) { wsSessions.remove(wsSession); if (wsSessions.isEmpty()) { endpointSessionMap.remove(key); } } if (endpointSessionMap.isEmpty()) { BackgroundProcessManager.getInstance().unregister(this); } } sessions.remove(wsSession); } Set<Session> getOpenSessions(Object key) { HashSet<Session> result = new HashSet<>(); synchronized (endPointSessionMapLock) { Set<WsSession> sessions = endpointSessionMap.get(key); if (sessions != null) { // Some sessions may be in the process of closing for (WsSession session : sessions) { if (session.isOpen()) { result.add(session); } } } } return result; } private static Map<String,List<String>> createRequestHeaders(String host, int port, boolean secure, ClientEndpointConfig clientEndpointConfiguration) { Map<String,List<String>> headers = new HashMap<>(); List<Extension> extensions = clientEndpointConfiguration.getExtensions(); List<String> subProtocols = clientEndpointConfiguration.getPreferredSubprotocols(); Map<String,Object> userProperties = clientEndpointConfiguration.getUserProperties(); if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) { List<String> authValues = new ArrayList<>(1); authValues.add((String) userProperties.get(Constants.AUTHORIZATION_HEADER_NAME)); headers.put(Constants.AUTHORIZATION_HEADER_NAME, authValues); } // Host header List<String> hostValues = new ArrayList<>(1); if (port == 80 && !secure || port == 443 && secure) { // Default ports. Do not include port in host header hostValues.add(host); } else { hostValues.add(host + ':' + port); } headers.put(Constants.HOST_HEADER_NAME, hostValues); // Upgrade header List<String> upgradeValues = new ArrayList<>(1); upgradeValues.add(Constants.UPGRADE_HEADER_VALUE); headers.put(Constants.UPGRADE_HEADER_NAME, upgradeValues); // Connection header List<String> connectionValues = new ArrayList<>(1); connectionValues.add(Constants.CONNECTION_HEADER_VALUE); headers.put(Constants.CONNECTION_HEADER_NAME, connectionValues); // WebSocket version header List<String> wsVersionValues = new ArrayList<>(1); wsVersionValues.add(Constants.WS_VERSION_HEADER_VALUE); headers.put(Constants.WS_VERSION_HEADER_NAME, wsVersionValues); // WebSocket key List<String> wsKeyValues = new ArrayList<>(1); wsKeyValues.add(generateWsKeyValue()); headers.put(Constants.WS_KEY_HEADER_NAME, wsKeyValues); // WebSocket sub-protocols if (subProtocols != null && !subProtocols.isEmpty()) { headers.put(Constants.WS_PROTOCOL_HEADER_NAME, subProtocols); } // WebSocket extensions if (extensions != null) { // Filter the requested extensions to remove any that are not supported by the client container. Set<String> installed = TransformationFactory.getInstance().getInstalledExtensionNames(); List<Extension> availableExtensions = new ArrayList<>(extensions); Iterator<Extension> availableExtensionsIter = availableExtensions.iterator(); while (availableExtensionsIter.hasNext()) { Extension e = availableExtensionsIter.next(); if (!installed.contains(e.getName())) { availableExtensionsIter.remove(); } } if (!availableExtensions.isEmpty()) { headers.put(Constants.WS_EXTENSIONS_HEADER_NAME, generateExtensionHeaders(availableExtensions)); } } return headers; } private static List<String> generateExtensionHeaders(List<Extension> extensions) { List<String> result = new ArrayList<>(extensions.size()); for (Extension extension : extensions) { StringBuilder header = new StringBuilder(); header.append(extension.getName()); for (Extension.Parameter param : extension.getParameters()) { header.append(';'); header.append(param.getName()); String value = param.getValue(); if (value != null && !value.isEmpty()) { header.append('='); header.append(value); } } result.add(header.toString()); } return result; } private static String generateWsKeyValue() { byte[] keyBytes = new byte[16]; RANDOM.nextBytes(keyBytes); return Base64.getEncoder().encodeToString(keyBytes); } private static ByteBuffer createRequest(URI uri, Map<String,List<String>> reqHeaders) { ByteBuffer result = ByteBuffer.allocate(4 * 1024); // Request line result.put(GET_BYTES); final String path = uri.getPath(); if (null == path || path.isEmpty()) { result.put(ROOT_URI_BYTES); } else { result.put(uri.getRawPath().getBytes(StandardCharsets.ISO_8859_1)); } String query = uri.getRawQuery(); if (query != null) { result.put((byte) '?'); result.put(query.getBytes(StandardCharsets.ISO_8859_1)); } result.put(HTTP_VERSION_BYTES); // Headers for (Entry<String,List<String>> entry : reqHeaders.entrySet()) { result = addHeader(result, entry.getKey(), entry.getValue()); } // Terminating CRLF result.put(CRLF); result.flip(); return result; } private static ByteBuffer addHeader(ByteBuffer result, String key, List<String> values) { if (values.isEmpty()) { return result; } result = putWithExpand(result, key.getBytes(StandardCharsets.ISO_8859_1)); result = putWithExpand(result, ": ".getBytes(StandardCharsets.ISO_8859_1)); result = putWithExpand(result, StringUtils.join(values).getBytes(StandardCharsets.ISO_8859_1)); result = putWithExpand(result, CRLF); return result; } private static ByteBuffer putWithExpand(ByteBuffer input, byte[] bytes) { if (bytes.length > input.remaining()) { int newSize; if (bytes.length > input.capacity()) { newSize = 2 * bytes.length; } else { newSize = input.capacity() * 2; } ByteBuffer expanded = ByteBuffer.allocate(newSize); input.flip(); expanded.put(input); input = expanded; } return input.put(bytes); } /** * Process response, blocking until HTTP response has been fully received. * * @throws ExecutionException if there is an exception reading the response * @throws InterruptedException if the thread is interrupted while reading the response * @throws DeploymentException if the response status line is not correctly formatted * @throws TimeoutException if the response was not read within the expected timeout */ private HttpResponse processResponse(ByteBuffer response, AsyncChannelWrapper channel, long timeout) throws InterruptedException, ExecutionException, DeploymentException, EOFException, TimeoutException { Map<String,List<String>> headers = new CaseInsensitiveKeyMap<>(); int status = 0; boolean readStatus = false; boolean readHeaders = false; String line = null; while (!readHeaders) { // On entering loop buffer will be empty and at the start of a new // loop the buffer will have been fully read. response.clear(); // Blocking read Future<Integer> read = channel.read(response); Integer bytesRead; try { bytesRead = read.get(timeout, TimeUnit.MILLISECONDS); } catch (TimeoutException e) { TimeoutException te = new TimeoutException( sm.getString("wsWebSocketContainer.responseFail", Integer.toString(status), headers)); te.initCause(e); throw te; } if (bytesRead.intValue() == -1) { throw new EOFException( sm.getString("wsWebSocketContainer.responseFail", Integer.toString(status), headers)); } response.flip(); while (response.hasRemaining() && !readHeaders) { if (line == null) { line = readLine(response); } else { line += readLine(response); } if (" ".equals(line)) { readHeaders = true; } else if (line.endsWith(" ")) { if (readStatus) { parseHeaders(line, headers); } else { status = parseStatus(line); readStatus = true; } line = null; } } } return new HttpResponse(status, new WsHandshakeResponse(headers)); } private int parseStatus(String line) throws DeploymentException { // This client only understands HTTP 1. // RFC2616 is case specific String[] parts = line.trim().split(" "); // CONNECT for proxy may return a 1.0 response if (parts.length < 2 || !("HTTP/1.0".equals(parts[0]) || "HTTP/1.1".equals(parts[0]))) { throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidStatus", line)); } try { return Integer.parseInt(parts[1]); } catch (NumberFormatException nfe) { throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidStatus", line)); } } private void parseHeaders(String line, Map<String,List<String>> headers) { // Treat headers as single values by default. int index = line.indexOf(':'); if (index == -1) { log.warn(sm.getString("wsWebSocketContainer.invalidHeader", line)); return; } // Header names are case-insensitive so always use lower case String headerName = line.substring(0, index).trim().toLowerCase(Locale.ENGLISH); // Multi-value headers are stored as a single header and the client is // expected to handle splitting into individual values String headerValue = line.substring(index + 1).trim(); List<String> values = headers.computeIfAbsent(headerName, k -> new ArrayList<>(1)); values.add(headerValue); } private String readLine(ByteBuffer response) { // All ISO-8859-1 StringBuilder sb = new StringBuilder(); char c; while (response.hasRemaining()) { c = (char) response.get(); sb.append(c); if (c == 10) { break; } } return sb.toString(); } private SSLEngine createSSLEngine(ClientEndpointConfig clientEndpointConfig, String host, int port) throws DeploymentException { try { // See if a custom SSLContext has been provided SSLContext sslContext = clientEndpointConfig.getSSLContext(); if (sslContext == null) { // Create the SSL Context sslContext = SSLContext.getInstance("TLS"); sslContext.init(null, null, null); } SSLEngine engine = sslContext.createSSLEngine(host, port); engine.setUseClientMode(true); // Enable host verification // Start with current settings (returns a copy) SSLParameters sslParams = engine.getSSLParameters(); // Use HTTPS since WebSocket starts over HTTP(S) sslParams.setEndpointIdentificationAlgorithm("HTTPS"); // Write the parameters back engine.setSSLParameters(sslParams); return engine; } catch (Exception e) { throw new DeploymentException(sm.getString("wsWebSocketContainer.sslEngineFail"), e); } } @Override public long getDefaultMaxSessionIdleTimeout() { return defaultMaxSessionIdleTimeout; } @Override public void setDefaultMaxSessionIdleTimeout(long timeout) { this.defaultMaxSessionIdleTimeout = timeout; } @Override public int getDefaultMaxBinaryMessageBufferSize() { return maxBinaryMessageBufferSize; } @Override public void setDefaultMaxBinaryMessageBufferSize(int max) { maxBinaryMessageBufferSize = max; } @Override public int getDefaultMaxTextMessageBufferSize() { return maxTextMessageBufferSize; } @Override public void setDefaultMaxTextMessageBufferSize(int max) { maxTextMessageBufferSize = max; } /** * {@inheritDoc} Currently, this implementation does not support any extensions. */ @Override public Set<Extension> getInstalledExtensions() { return TransformationFactory.getInstance().getInstalledExtensions(); } /** * {@inheritDoc} The default value for this implementation is -1. */ @Override public long getDefaultAsyncSendTimeout() { return defaultAsyncTimeout; } /** * {@inheritDoc} The default value for this implementation is -1. */ @Override public void setAsyncSendTimeout(long timeout) { this.defaultAsyncTimeout = timeout; } /** * Cleans up the resources still in use by WebSocket sessions created from this container. This includes closing * sessions and cancelling {@link Future}s associated with blocking read/writes. */ public void destroy() { CloseReason cr = new CloseReason(CloseCodes.GOING_AWAY, sm.getString("wsWebSocketContainer.shutdown")); for (WsSession session : sessions.keySet()) { try { session.close(cr); } catch (IOException ioe) { if (log.isDebugEnabled()) { log.debug(sm.getString("wsWebSocketContainer.sessionCloseFail", session.getId()), ioe); } } } // Only unregister with AsyncChannelGroupUtil if this instance // registered with it if (asynchronousChannelGroup != null) { synchronized (asynchronousChannelGroupLock) { if (asynchronousChannelGroup != null) { AsyncChannelGroupUtil.unregister(); asynchronousChannelGroup = null; } } } } private AsynchronousChannelGroup getAsynchronousChannelGroup() { // Use AsyncChannelGroupUtil to share a common group amongst all // WebSocket clients AsynchronousChannelGroup result = asynchronousChannelGroup; if (result == null) { synchronized (asynchronousChannelGroupLock) { if (asynchronousChannelGroup == null) { asynchronousChannelGroup = AsyncChannelGroupUtil.register(); } result = asynchronousChannelGroup; } } return result; } // ----------------------------------------------- BackgroundProcess methods @Override public void backgroundProcess() { // This method gets called once a second. backgroundProcessCount++; if (backgroundProcessCount >= processPeriod) { backgroundProcessCount = 0; // Check all registered sessions. for (WsSession wsSession : sessions.keySet()) { wsSession.checkExpiration(); wsSession.checkCloseTimeout(); } } } @Override public void setProcessPeriod(int period) { this.processPeriod = period; } /** * {@inheritDoc} The default value is 10 which means session expirations are processed every 10 seconds. */ @Override public int getProcessPeriod() { return processPeriod; } private record HttpResponse(int status, HandshakeResponse handshakeResponse) { } }
Detected license expression
apache-2.0
Detected license expression (SPDX)
Apache-2.0
Percentage of license text
3.34
Copyrights

      
    
Holders

      
    
Authors

      
    
License detections License expression License expression SPDX
apache_2_0-4bde3f57-78aa-4201-96bf-531cba09e7de apache-2.0 Apache-2.0
URL Start line End line
http://www.apache.org/licenses/LICENSE-2.0 9 9