Package org.mortbay.servlet

Source Code of org.mortbay.servlet.DosFilterTest$DoSFilter2

// ========================================================================
// Copyright 2009 Mort Bay Consulting Pty. Ltd.
// ------------------------------------------------------------------------
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//========================================================================

package org.mortbay.servlet;

import java.io.IOException;
import java.net.Socket;

import javax.servlet.Servlet;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import junit.framework.TestCase;

import org.mortbay.jetty.HttpURI;
import org.mortbay.jetty.servlet.FilterHolder;
import org.mortbay.jetty.testing.ServletTester;
import org.mortbay.log.Log;
import org.mortbay.util.IO;

public class DosFilterTest extends TestCase
{
    private ServletTester _tester;
    private String _host;
    private int _port;
   
    private int _maxRequestMs = 200;
    protected void setUp() throws Exception
    {
        _tester = new ServletTester();
        HttpURI uri=new HttpURI(_tester.createSocketConnector(true));
        _host=uri.getHost();
        _port=uri.getPort();
       
        _tester.setContextPath("/ctx");
        _tester.addServlet(TestServlet.class, "/*");
       
        FilterHolder dos=_tester.addFilter(DoSFilter2.class,"/dos/*",0);
        dos.setInitParameter("maxRequestsPerSec","4");
        dos.setInitParameter("delayMs","200");
        dos.setInitParameter("throttledRequests","1");
        dos.setInitParameter("waitMs","10");
        dos.setInitParameter("throttleMs","4000");
        dos.setInitParameter("remotePort", "false");
        dos.setInitParameter("insertHeaders", "true");
       
        FilterHolder quickTimeout = _tester.addFilter(DoSFilter2.class,"/timeout/*",0);
        quickTimeout.setInitParameter("maxRequestsPerSec","4");
        quickTimeout.setInitParameter("delayMs","200");
        quickTimeout.setInitParameter("throttledRequests","1");
        quickTimeout.setInitParameter("waitMs","10");
        quickTimeout.setInitParameter("throttleMs","4000");
        quickTimeout.setInitParameter("remotePort", "false");
        quickTimeout.setInitParameter("insertHeaders", "true");
        quickTimeout.setInitParameter("maxRequestMs", _maxRequestMs + "");

        _tester.start();

    }
       
    protected void tearDown() throws Exception
    {
        _tester.stop();
    }
   
    private String doRequests(String requests, int loops, long pause0,long pause1,String request)
        throws Exception
    {
        Socket socket = new Socket(_host,_port);
        socket.setSoTimeout(30000);
       
        for (int i=loops;i-->0;)
        {
            socket.getOutputStream().write(requests.getBytes("UTF-8"));
            socket.getOutputStream().flush();
            if (i>0 && pause0>0)
                Thread.sleep(pause0);
        }
        if (pause1>0)
            Thread.sleep(pause1);
        socket.getOutputStream().write(request.getBytes("UTF-8"));
        socket.getOutputStream().flush();
       
       
        String response = "";

        if (requests.indexOf("/unresponsive") != -1)
        {
            // don't read in anything, forcing the request to time out
            Thread.sleep(_maxRequestMs * 2);
            response = IO.toString(socket.getInputStream(),"UTF-8");
        }
        else
        {
            response = IO.toString(socket.getInputStream(),"UTF-8");
        }
        socket.close();
        return response;
    }
   
    private int count(String responses,String substring)
    {
        int count=0;
        int i=responses.indexOf(substring);
        while (i>=0)
        {
            count++;
            i=responses.indexOf(substring,i+substring.length());
        }
       
        return count;
    }
   
    public void testEvenLowRateIP()
        throws Exception
    {
        String request="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
        String last="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String responses = doRequests(request,11,300,300,last);  
        assertEquals(12,count(responses,"HTTP/1.1 200 OK"));
        assertEquals(0,count(responses,"DoSFilter:"));
    }
   
    public void testBurstLowRateIP()
        throws Exception
    {
        String request="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
        String last="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String responses = doRequests(request+request+request+request,2,1100,1100,last);  
       
        assertEquals(9,count(responses,"HTTP/1.1 200 OK"));
        assertEquals(0,count(responses,"DoSFilter:"));
    }
   
    public void testDelayedIP()
        throws Exception
    {
        String request="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
        String last="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String responses = doRequests(request+request+request+request+request,2,1100,1100,last);
       
        assertEquals(11,count(responses,"HTTP/1.1 200 OK"));
        assertEquals(2,count(responses,"DoSFilter: delayed"));
    }
   
    public void testThrottledIP()
        throws Exception
    {
        Thread other = new Thread()
        {
            public void run()
            {
                try
                {
                    // Cause a delay, then sleep while holding pass
                    String request="GET /ctx/dos/sleeper HTTP/1.1\r\nHost: localhost\r\n\r\n";
                    String last="GET /ctx/dos/sleeper?sleep=2000 HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
                    String responses = doRequests(request+request+request+request,1,0,0,last);
                }
                catch(Exception e)
                {
                    e.printStackTrace();
                }
            }
        };
        other.start();
        Thread.sleep(1500);
       
        String request="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
        String last="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String responses = doRequests(request+request+request+request,1,0,0,last);
       
        assertEquals(5,count(responses,"HTTP/1.1 200 OK"));
        assertEquals(1,count(responses,"DoSFilter: delayed"));
        assertEquals(1,count(responses,"DoSFilter: throttled"));
        assertEquals(0,count(responses,"DoSFilter: unavailable"));
       
        other.join();
    }
   
    public void testUnavailableIP()
        throws Exception
    {
        Thread other = new Thread()
        {
            public void run()
            {
                try
                {
                    // Cause a delay, then sleep while holding pass
                    String request="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
                    String last="GET /ctx/dos/test?sleep=5000 HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
                    String responses = doRequests(request+request+request+request,1,0,0,last);
                }
                catch(Exception e)
                {
                    e.printStackTrace();
                }
            }
        };
        other.start();
        Thread.sleep(500);
       
        String request="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\n\r\n";
        String last="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String responses = doRequests(request+request+request+request,1,0,0,last);
       
        assertEquals(4,count(responses,"HTTP/1.1 200 OK"));
        assertEquals(1,count(responses,"HTTP/1.1 503"));
        assertEquals(1,count(responses,"DoSFilter: delayed"));
        assertEquals(1,count(responses,"DoSFilter: throttled"));
        assertEquals(1,count(responses,"DoSFilter: unavailable"));
       
        other.join();
    }
   
    public void testSessionTracking()
        throws Exception
    {
        // get a session, first
        String requestSession="GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String response=doRequests("",1,0,0,requestSession);
        String sessionId=response.substring(response.indexOf("Set-Cookie: ")+12, response.indexOf(";"));

        // all other requests use this session
        String request="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId + "\r\n\r\n";
        String last="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: " + sessionId + "\r\n\r\n";
        String responses = doRequests(request+request+request+request+request,2,1100,1100,last);

        assertEquals(11,count(responses,"HTTP/1.1 200 OK"));
        assertEquals(2,count(responses,"DoSFilter: delayed"));
    }

    public void testMultipleSessionTracking()
        throws Exception
    {
        // get some session ids, first
        String requestSession="GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\n\r\n";
        String closeRequest="GET /ctx/dos/test?session=true HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String response=doRequests(requestSession+requestSession,1,0,0,closeRequest);

        String[] sessions = response.split("\r\n\r\n");

        String sessionId1=sessions[0].substring(sessions[0].indexOf("Set-Cookie: ")+12, sessions[0].indexOf(";"));
        String sessionId2=sessions[1].substring(sessions[1].indexOf("Set-Cookie: ")+12, sessions[1].indexOf(";"));

        String request1="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId1 + "\r\n\r\n";
        String request2="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nCookie: " + sessionId2 + "\r\n\r\n";
        String last="GET /ctx/dos/test HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\nCookie: " + sessionId2 + "\r\n\r\n";
       
        // Make sure the sessions are not new.
        String responses = doRequests(request1+request2+request1+request2+request1,1,1100,1100,last);
       
        Thread.sleep(1000);
       
        // alternate between sessions
        responses = doRequests(request1+request2+request1+request2+request1,2,1100,1100,last);

        assertEquals(11,count(responses,"HTTP/1.1 200 OK"));
        assertEquals(0,count(responses,"DoSFilter: delayed"));

        // alternate between sessions
        responses = doRequests(request1+request2+request1+request2+request1,2,550,550,last);

        assertEquals(11,count(responses,"HTTP/1.1 200 OK"));
        int delayedRequests = count(responses,"DoSFilter: delayed");
        assertTrue(delayedRequests >= 2 && delayedRequests <= 3);
    }   

    public void testUnresponsiveClient()
        throws Exception
    {
        int numRequests = 1000;

        String last="GET /ctx/timeout/unresponsive?lines="+numRequests+" HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n";
        String responses = doRequests("",0,0,0,last);

        // was expired, and stopped before reaching the end of the requests
        int responseLines = count(responses, "Line:");
        assertTrue(responses.indexOf("DoSFilter: timeout") != -1);
        assertTrue(responseLines > 0 && responseLines < numRequests);
    }
  
    public static class TestServlet extends HttpServlet implements Servlet
    {
        protected void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException
        {
            if (request.getParameter("session")!=null)
                request.getSession(true);
            if (request.getParameter("sleep")!=null)
            {
                try
                {
                    Thread.sleep(Long.parseLong(request.getParameter("sleep")));
                }
                catch(InterruptedException e)
                {   
                }
            }
           
            if (request.getParameter("lines")!=null)
            {
                int count = Integer.parseInt(request.getParameter("lines"));
                for(int i = 0; i < count; ++i)
                {                       
                    response.getWriter().write("Line: " + i+"\n");
                    response.flushBuffer();
                   
                    try
                    {
                        Thread.sleep(10);
                    }
                    catch(InterruptedException e)
                    {
                    }

                }
            }
           
            response.setContentType("text/plain");
           
        }
    }
   
    public static class DoSFilter2 extends DoSFilter
    {
        public void closeConnection(HttpServletRequest request, HttpServletResponse response, Thread thread)
        {
            try {
                response.getWriter().write("DoSFilter: timeout");
                super.closeConnection(request,response,thread);
            }
            catch (Exception e)
            {
                Log.warn(e);
            }
        }
    }   
}
TOP

Related Classes of org.mortbay.servlet.DosFilterTest$DoSFilter2

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Inc. Contact coftware#gmail.com.