/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jakarta.servlet.http;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.Servlet;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletOutputStream;
import jakarta.servlet.WriteListener;
import org.junit.Assert;
import org.junit.Test;
import static org.apache.catalina.startup.SimpleHttpClient.CRLF;
import org.apache.catalina.Context;
import org.apache.catalina.Wrapper;
import org.apache.catalina.core.StandardContext;
import org.apache.catalina.startup.SimpleHttpClient;
import org.apache.catalina.startup.TesterServlet;
import org.apache.catalina.startup.Tomcat;
import org.apache.catalina.startup.TomcatBaseTest;
import org.apache.tomcat.util.buf.ByteChunk;
import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap;
import org.apache.tomcat.util.http.Method;
import org.apache.tomcat.util.net.TesterSupport.SimpleServlet;
public class TestHttpServlet extends TomcatBaseTest {
/*
 * Regression test for bug 53454. Verifies that a Content-Length value that does not fit in an int (see
 * LargeBodyServlet.RESPONSE_LENGTH) is returned unchanged in the response to a HEAD request.
 */
@Test
public void testBug53454() throws Exception {
    Tomcat tomcat = getTomcatInstance();

    // No file system docBase required
    StandardContext ctx = (StandardContext) getProgrammaticRootContext();

    // Map the test Servlet
    LargeBodyServlet largeBodyServlet = new LargeBodyServlet();
    Tomcat.addServlet(ctx, "largeBodyServlet", largeBodyServlet);
    ctx.addServletMappingDecoded("/", "largeBodyServlet");

    tomcat.start();

    /*
     * HTTP field names are case-insensitive (the Servlet sets "content-length" while this test reads
     * "Content-Length") so collect the response headers in a map that ignores the case of the key.
     */
    Map<String,List<String>> resHeaders = new CaseInsensitiveKeyMap<>();
    int rc = headUrl("http://localhost:" + getPort() + "/", new ByteChunk(), resHeaders);

    Assert.assertEquals(HttpServletResponse.SC_OK, rc);
    Assert.assertEquals(LargeBodyServlet.RESPONSE_LENGTH, resHeaders.get("Content-Length").get(0));
}
/*
 * Test Servlet that answers GET by declaring a response length of 12345678901 without writing any content.
 */
private static class LargeBodyServlet extends HttpServlet {
private static final long serialVersionUID = 1L;
// Declared response length. The value is larger than Integer.MAX_VALUE (2147483647).
private static final String RESPONSE_LENGTH = "12345678901";
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
// Only the header is set (using a lower case name); no body is written
resp.setHeader("content-length", RESPONSE_LENGTH);
}
}
/*
 * Verifies that the same Content-Length is returned for both GET and HEAD operations when a Servlet includes
 * content from another Servlet.
 */
@Test
public void testBug57602() throws Exception {
    Tomcat tomcat = getTomcatInstance();

    // No file system docBase required
    StandardContext ctx = (StandardContext) getProgrammaticRootContext();

    Bug57602ServletOuter outer = new Bug57602ServletOuter();
    Tomcat.addServlet(ctx, "Bug57602ServletOuter", outer);
    ctx.addServletMappingDecoded("/outer", "Bug57602ServletOuter");

    Bug57602ServletInner inner = new Bug57602ServletInner();
    Tomcat.addServlet(ctx, "Bug57602ServletInner", inner);
    ctx.addServletMappingDecoded("/inner", "Bug57602ServletInner");

    tomcat.start();

    String path = "http://localhost:" + getPort() + "/outer";
    ByteChunk out = new ByteChunk();

    // GET: the declared length must match the number of bytes actually received
    Map<String,List<String>> getHeaders = new CaseInsensitiveKeyMap<>();
    int rc = getUrl(path, out, getHeaders);
    Assert.assertEquals(HttpServletResponse.SC_OK, rc);
    String length = getSingleHeader("Content-Length", getHeaders);
    Assert.assertEquals(Long.parseLong(length), out.getLength());
    out.recycle();

    /*
     * HEAD: no body but the same declared length. A separate map is used so that a Content-Length captured from
     * the GET response can never satisfy the assertion on the HEAD response.
     */
    Map<String,List<String>> headHeaders = new CaseInsensitiveKeyMap<>();
    rc = headUrl(path, out, headHeaders);
    Assert.assertEquals(HttpServletResponse.SC_OK, rc);
    Assert.assertEquals(0, out.getLength());
    Assert.assertEquals(length, getSingleHeader("Content-Length", headHeaders));

    tomcat.stop();
}
/*
 * Each of the following tests passes one Servlet to doTestHead() which compares the response headers for GET and
 * HEAD requests to that Servlet.
 */
@Test
public void testHeadWithChunking() throws Exception {
    doTestHead(new ChunkingServlet());
}

@Test
public void testHeadWithResetBufferWriter() throws Exception {
    doTestHead(new ResetBufferServlet(true));
}

@Test
public void testHeadWithResetBufferStream() throws Exception {
    doTestHead(new ResetBufferServlet(false));
}

@Test
public void testHeadWithResetWriter() throws Exception {
    doTestHead(new ResetServlet(true));
}

@Test
public void testHeadWithResetStream() throws Exception {
    doTestHead(new ResetServlet(false));
}

@Test
public void testHeadWithNonBlocking() throws Exception {
    // Less than buffer size
    doTestHead(new NonBlockingWriteServlet(4 * 1024));
}

/*
 * Maps the given Servlet to "/test", requests it with GET and then with HEAD and asserts that both requests
 * return 200 with the same set of response headers. The Date header is ignored and the headers listed in
 * TesterConstants.OPTIONAL_HEADERS_WITH_HEAD are only compared when the HEAD response contains them.
 */
private void doTestHead(Servlet servlet) throws Exception {
    Tomcat tomcat = getTomcatInstance();

    // No file system docBase required
    StandardContext context = (StandardContext) getProgrammaticRootContext();

    Wrapper wrapper = Tomcat.addServlet(context, "TestServlet", servlet);
    // Not every Servlet under test is asynchronous but enabling it for all keeps the set-up simple
    wrapper.setAsyncSupported(true);
    context.addServletMappingDecoded("/test", "TestServlet");

    tomcat.start();

    final String url = "http://localhost:" + getPort() + "/test";
    final ByteChunk responseBody = new ByteChunk();

    Map<String,List<String>> headersForGet = new CaseInsensitiveKeyMap<>();
    int status = getUrl(url, responseBody, headersForGet);
    Assert.assertEquals(HttpServletResponse.SC_OK, status);
    responseBody.recycle();

    Map<String,List<String>> headersForHead = new CaseInsensitiveKeyMap<>();
    status = headUrl(url, responseBody, headersForHead);
    Assert.assertEquals(HttpServletResponse.SC_OK, status);

    // The two responses are generated at different times so Date is excluded from the comparison
    headersForGet.remove("date");
    headersForHead.remove("date");

    /*
     * RFC 9110, section 9.3.2 makes some headers optional for HEAD. When HEAD omits one of them it is dropped
     * from the GET headers as well; when HEAD includes it, it takes part in the comparison below.
     */
    for (String optionalHeader : TesterConstants.OPTIONAL_HEADERS_WITH_HEAD) {
        if (headersForHead.containsKey(optionalHeader)) {
            continue;
        }
        headersForGet.remove(optionalHeader);
    }

    // Same number of headers and, per header, the same number of values with every GET value present in HEAD
    Assert.assertEquals(headersForGet.size(), headersForHead.size());
    for (Map.Entry<String,List<String>> entry : headersForGet.entrySet()) {
        String name = entry.getKey();
        Assert.assertTrue(name, headersForHead.containsKey(name));

        List<String> valuesForGet = entry.getValue();
        List<String> valuesForHead = headersForHead.get(name);
        Assert.assertEquals(valuesForGet.size(), valuesForHead.size());
        Assert.assertTrue(valuesForHead.containsAll(valuesForGet));
    }

    tomcat.stop();
}
/*
 * A Servlet that only overrides doGet() is expected to report GET, HEAD and OPTIONS in the Allow header.
 */
@Test
public void testDoOptions() throws Exception {
    doTestDoOptions(new OptionsServlet(), "GET, HEAD, OPTIONS");
}

/*
 * A sub-class that additionally overrides doPost() is expected to report POST as well.
 */
@Test
public void testDoOptionsSub() throws Exception {
    doTestDoOptions(new OptionsServletSub(), "GET, HEAD, POST, OPTIONS");
}

/*
 * Maps the given Servlet to "/", sends an OPTIONS request and asserts that the response is a 200 whose first
 * Allow header value equals expectedAllow.
 */
private void doTestDoOptions(Servlet servlet, String expectedAllow) throws Exception {
    Tomcat tomcat = getTomcatInstance();

    // No file system docBase required
    StandardContext ctx = (StandardContext) getProgrammaticRootContext();

    // Map the test Servlet
    Tomcat.addServlet(ctx, "servlet", servlet);
    ctx.addServletMappingDecoded("/", "servlet");

    tomcat.start();

    // HTTP field names are case-insensitive so do not depend on the exact case used for "Allow" on the wire
    Map<String,List<String>> resHeaders = new CaseInsensitiveKeyMap<>();
    int rc = methodUrl("http://localhost:" + getPort() + "/", new ByteChunk(), DEFAULT_CLIENT_TIMEOUT_MS, null,
            resHeaders, Method.OPTIONS);

    Assert.assertEquals(HttpServletResponse.SC_OK, rc);
    Assert.assertEquals(expectedAllow, resHeaders.get("Allow").get(0));
}
@Test
public void testUnimplementedMethodHttp09() throws Exception {
    doTestUnimplementedMethod("0.9");
}

@Test
public void testUnimplementedMethodHttp10() throws Exception {
    doTestUnimplementedMethod("1.0");
}

@Test
public void testUnimplementedMethodHttp11() throws Exception {
    doTestUnimplementedMethod("1.1");
}

/*
 * Sends a PUT request for "/test" using the given HTTP version and checks the status that is returned: 400 is
 * expected for HTTP/0.9 and HTTP/1.0, 405 for any other version. For HTTP/0.9 the request line carries no
 * protocol and the status is looked for in the response body. The Exception returned by Client.doRequest() is
 * not inspected here.
 *
 * See org.apache.coyote.http2.TestHttpServlet for the HTTP/2 version of this test. It was placed in that package
 * because it needed access to package private classes.
 */
private void doTestUnimplementedMethod(String httpVersion) {
    final boolean isHttp09 = "0.9".equals(httpVersion);
    final boolean isHttp10 = "1.0".equals(httpVersion);

    // HTTP/0.9 request lines consist of the method and target only
    String requestLine = "PUT /test";
    if (!isHttp09) {
        requestLine = requestLine + " HTTP/" + httpVersion;
    }

    String request = requestLine + CRLF +
            "Host: localhost:8080" + CRLF +
            "Connection: close" + CRLF +
            CRLF;

    Client client = new Client(request, isHttp09);
    client.doRequest();

    if (isHttp09) {
        Assert.assertTrue(client.getResponseBody(), client.getResponseBody().contains(" 400 "));
        return;
    }
    if (isHttp10) {
        Assert.assertTrue(client.getResponseLine(), client.isResponse400());
        return;
    }
    Assert.assertTrue(client.getResponseLine(), client.isResponse405());
}
/*
 * Sends a TRACE request that contains a repeated custom header together with Cookie and Authorization headers
 * and verifies that the custom header values are echoed in the response body while the sensitive headers are not.
 */
@Test
public void testTrace() throws Exception {
    Tomcat tomcat = getTomcatInstance();
    // Enable TRACE on the connector
    tomcat.getConnector().setAllowTrace(true);

    // No file system docBase required
    StandardContext ctx = (StandardContext) getProgrammaticRootContext();

    // Map the test Servlet
    Tomcat.addServlet(ctx, "servlet", new SimpleServlet());
    ctx.addServletMappingDecoded("/", "servlet");

    tomcat.start();

    TraceClient client = new TraceClient();
    client.setPort(getPort());
    // @formatter:off
    client.setRequest(new String[] {
            "TRACE / HTTP/1.1" + CRLF +
            "Host: localhost:" + getPort() + CRLF +
            "X-aaa: a1, a2" + CRLF +
            "X-aaa: a3" + CRLF +
            "Cookie: c1-v1" + CRLF +
            "Authorization: not-a-real-credential" + CRLF +
            CRLF
    });
    // @formatter:on
    client.setUseContentLength(true);

    client.connect();
    try {
        client.sendRequest();
        client.readResponse(true);

        String body = client.getResponseBody();

        System.out.println(body);

        Assert.assertTrue(client.getResponseLine(), client.isResponse200());
        // Far from perfect but good enough
        body = body.toLowerCase(Locale.ENGLISH);
        Assert.assertTrue(body.contains("a1"));
        Assert.assertTrue(body.contains("a2"));
        Assert.assertTrue(body.contains("a3"));
        /*
         * Sensitive request headers (Cookie, Authorization) must not be reflected in the response content. See
         * RFC 9110, section 9.3.8. The body was converted to lower case above so lower case names are used.
         */
        Assert.assertFalse(body.contains("cookie"));
        Assert.assertFalse(body.contains("authorization"));
    } finally {
        // Always release the connection, including when one of the assertions above fails
        client.disconnect();
    }
}
/*
 * Client used by testTrace(). It reports every response body as OK; the test inspects the body itself.
 */
private static final class TraceClient extends SimpleHttpClient {
@Override
public boolean isResponseBodyOK() {
// Unconditionally true: no validation of the body is performed here
return true;
}
}
/*
 * Client that sends a single, pre-built request. This is an inner (non-static) class because doRequest() uses
 * the Tomcat instance of the enclosing test to deploy the target Servlet.
 */
private class Client extends SimpleHttpClient {

    Client(String request, boolean isHttp09) {
        String[] requestParts = { request };
        setRequest(requestParts);
        setUseHttp09(isHttp09);
    }

    /*
     * Deploys a TesterServlet at "/test" in a new root context, starts Tomcat, sends the request and reads the
     * response. Returns the Exception that interrupted this sequence (after printing its stack trace) or null if
     * the sequence completed.
     */
    private Exception doRequest() {
        Tomcat server = getTomcatInstance();

        Context rootContext = server.addContext("", TEMP_DIR);
        Tomcat.addServlet(rootContext, "TesterServlet", new TesterServlet());
        rootContext.addServletMappingDecoded("/test", "TesterServlet");

        Exception failure = null;
        try {
            server.start();
            setPort(server.getConnector().getLocalPort());
            setRequestPause(20);

            // Open the connection, exchange request and response (blocks until the response has been read), close
            connect();
            processRequest();
            disconnect();
        } catch (Exception e) {
            e.printStackTrace();
            failure = e;
        }
        return failure;
    }

    @Override
    public boolean isResponseBodyOK() {
        // The body is always reported as not OK
        return false;
    }
}
/*
 * Writes "Header", includes the output of whatever is mapped to "/inner" and then writes "Footer", all as UTF-8
 * plain text.
 */
private static class Bug57602ServletOuter extends HttpServlet {

    private static final long serialVersionUID = 1L;

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        resp.setContentType("text/plain");
        resp.setCharacterEncoding("UTF-8");

        // The writer is obtained before the include and used on both sides of it
        final PrintWriter writer = resp.getWriter();

        writer.println("Header");
        req.getRequestDispatcher("/inner").include(req, resp);
        writer.println("Footer");
    }
}
/*
 * Writes the single line "Included" as UTF-8 plain text. Bug57602ServletOuter includes this Servlet via "/inner".
 */
private static class Bug57602ServletInner extends HttpServlet {

    private static final long serialVersionUID = 1L;

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        resp.setContentType("text/plain");
        resp.setCharacterEncoding("UTF-8");
        resp.getWriter().println("Included");
    }
}
/*
 * Writes 8192 * 16 NUL characters followed by the line "Data" as UTF-8 plain text. The large write is intended to
 * trigger chunking of the response.
 */
private static class ChunkingServlet extends HttpServlet {

    private static final long serialVersionUID = 1L;

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        resp.setContentType("text/plain");
        resp.setCharacterEncoding("UTF-8");

        // Trigger chunking
        final char[] filler = new char[8192 * 16];

        PrintWriter writer = resp.getWriter();
        writer.write(filler);
        writer.println("Data");
    }
}
/*
 * Writes 4 * 1024 characters (Writer) or bytes (OutputStream), calls resetBuffer() on the response and then
 * writes the same amount again. The constructor argument selects between the Writer and the OutputStream.
 */
private static class ResetBufferServlet extends HttpServlet {

    private static final long serialVersionUID = 1L;

    private final boolean useWriter;

    ResetBufferServlet(boolean useWriter) {
        this.useWriter = useWriter;
    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        resp.setContentType("text/plain");
        resp.setCharacterEncoding("UTF-8");

        if (!useWriter) {
            writeUsingStream(resp);
            return;
        }
        writeUsingWriter(resp);
    }

    // Character based variant: write, discard via resetBuffer(), write again
    private static void writeUsingWriter(HttpServletResponse resp) throws IOException {
        PrintWriter writer = resp.getWriter();
        writer.write(new char[4 * 1024]);
        resp.resetBuffer();
        writer.write(new char[4 * 1024]);
    }

    // Byte based variant: write, discard via resetBuffer(), write again
    private static void writeUsingStream(HttpServletResponse resp) throws IOException {
        ServletOutputStream stream = resp.getOutputStream();
        stream.write(new byte[4 * 1024]);
        resp.resetBuffer();
        stream.write(new byte[4 * 1024]);
    }
}
/*
 * Adds the header "aaa", writes 4 * 1024 characters (Writer) or bytes (OutputStream), calls reset() on the
 * response, adds the header "ccc" and writes the same amount again. The constructor argument selects between the
 * Writer and the OutputStream.
 *
 * Unlike ResetBufferServlet this Servlet calls reset() rather than resetBuffer(). reset() clears the status code
 * and the headers as well as the buffered data, so the header added before the call is expected to be discarded.
 */
private static class ResetServlet extends HttpServlet {

    private static final long serialVersionUID = 1L;

    private final boolean useWriter;

    ResetServlet(boolean useWriter) {
        this.useWriter = useWriter;
    }

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        resp.setContentType("text/plain");
        resp.setCharacterEncoding("UTF-8");

        if (useWriter) {
            PrintWriter pw = resp.getWriter();
            resp.addHeader("aaa", "bbb");
            pw.write(new char[4 * 1024]);
            resp.reset();
            // reset() also clears the getWriter()/getOutputStream() state so obtain the Writer again
            pw = resp.getWriter();
            resp.addHeader("ccc", "ddd");
            pw.write(new char[4 * 1024]);
        } else {
            ServletOutputStream sos = resp.getOutputStream();
            resp.addHeader("aaa", "bbb");
            sos.write(new byte[4 * 1024]);
            resp.reset();
            // reset() also clears the getWriter()/getOutputStream() state so obtain the OutputStream again
            sos = resp.getOutputStream();
            resp.addHeader("ccc", "ddd");
            sos.write(new byte[4 * 1024]);
        }
    }
}
/*
 * Servlet that writes a configurable number of zero bytes to the response using the non-blocking write API.
 */
private static class NonBlockingWriteServlet extends HttpServlet {
private static final long serialVersionUID = 1L;
// Total number of bytes written to each response
private final int bytesToWrite;
NonBlockingWriteServlet(int bytesToWrite) {
this.bytesToWrite = bytesToWrite;
}
/*
 * Puts the request into asynchronous mode with a timeout of 3000 and registers a new WriteListener on the
 * response output stream. The listener performs the writes and completes the AsyncContext.
 */
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
AsyncContext ac = req.startAsync(req, resp);
ac.setTimeout(3000);
WriteListener wListener = new NonBlockingWriteListener(ac, bytesToWrite);
resp.getOutputStream().setWriteListener(wListener);
}
/*
 * WriteListener that writes the requested number of bytes in blocks of at most 1024 bytes and completes the
 * AsyncContext once everything has been written and the stream reports ready.
 */
private static class NonBlockingWriteListener implements WriteListener {
private final AsyncContext ac;
private final ServletOutputStream sos;
// Number of bytes still to be written; decremented after every write
private int bytesToWrite;
NonBlockingWriteListener(AsyncContext ac, int bytesToWrite) throws IOException {
this.ac = ac;
// The output stream is taken from the response associated with the AsyncContext
this.sos = ac.getResponse().getOutputStream();
this.bytesToWrite = bytesToWrite;
}
@Override
public void onWritePossible() throws IOException {
// At least one write is always performed per call; it is zero bytes long if nothing is left to write
do {
// Write up to 1k a time
int bytesThisTime = Math.min(bytesToWrite, 1024);
sos.write(new byte[bytesThisTime]);
bytesToWrite -= bytesThisTime;
} while (sos.isReady() && bytesToWrite > 0);
// Complete only when everything has been written and the stream is ready; otherwise return without completing
if (sos.isReady() && bytesToWrite == 0) {
ac.complete();
}
}
@Override
public void onError(Throwable throwable) {
// The error is only printed; the AsyncContext is not completed here
throwable.printStackTrace();
}
}
}
/*
 * Servlet that only overrides doGet(), which writes "OK" as UTF-8 plain text. Used by the OPTIONS tests.
 */
private static class OptionsServlet extends HttpServlet {

    private static final long serialVersionUID = 1L;

    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        resp.setContentType("text/plain");
        resp.setCharacterEncoding("UTF-8");
        resp.getWriter().print("OK");
    }
}
/*
 * Extends OptionsServlet by also overriding doPost(), which handles POST requests exactly like GET requests.
 */
private static class OptionsServletSub extends OptionsServlet {
private static final long serialVersionUID = 1L;
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
// Delegate to the inherited GET handling
doGet(req, resp);
}
}
}