Tuesday, June 30, 2015

How to write custom throttle handler to throttle requests based on IP address - WSO2 API Manager

Please find the sample source code for custom throttle handler to throttle requests based on IP address. Based on your requirements you can change the logic here.

package org.wso2.carbon.apimgt.gateway.handlers.throttling; import org.apache.axiom.om.OMAbstractFactory;
import org.apache.axiom.om.OMElement;
import org.apache.axiom.om.OMFactory;
import org.apache.axiom.om.OMNamespace;
import org.apache.axis2.context.ConfigurationContext;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.http.HttpStatus;
import org.apache.neethi.PolicyEngine;
import org.apache.synapse.Mediator;
import org.apache.synapse.MessageContext;
import org.apache.synapse.SynapseConstants;
import org.apache.synapse.SynapseException;
import org.apache.synapse.config.Entry;
import org.apache.synapse.core.axis2.Axis2MessageContext;
import org.apache.synapse.rest.AbstractHandler;
import org.wso2.carbon.apimgt.gateway.handlers.Utils;
import org.wso2.carbon.apimgt.gateway.handlers.security.APISecurityUtils;
import org.wso2.carbon.apimgt.gateway.handlers.security.AuthenticationContext;
import org.wso2.carbon.apimgt.impl.APIConstants;
import org.wso2.carbon.throttle.core.AccessInformation;
import org.wso2.carbon.throttle.core.RoleBasedAccessRateController;
import org.wso2.carbon.throttle.core.Throttle;
import org.wso2.carbon.throttle.core.ThrottleContext;
import org.wso2.carbon.throttle.core.ThrottleException;
import org.wso2.carbon.throttle.core.ThrottleFactory;

import java.util.Map;
import java.util.TreeMap;


public class IPBasedThrottleHandler extends AbstractHandler {

    private static final Log log = LogFactory.getLog(IPBasedThrottleHandler.class);

    /** The Throttle object - holds all runtime and configuration data */
    private volatile Throttle throttle;

    private RoleBasedAccessRateController applicationRoleBasedAccessController;

    /** The key for getting the throttling policy - key refers to a/an [registry] entry    */
    private String policyKey = null;
    /** The concurrent access control group id */
    private String id;
    /** Version number of the throttle policy */
    private long version;

    public IPBasedThrottleHandler() {
        this.applicationRoleBasedAccessController = new RoleBasedAccessRateController();
    }

    public boolean handleRequest(MessageContext messageContext) {
        return doThrottle(messageContext);
    }

    public boolean handleResponse(MessageContext messageContext) {
        return doThrottle(messageContext);
    }

    private boolean doThrottle(MessageContext messageContext) {
        boolean canAccess = true;
        boolean isResponse = messageContext.isResponse();
        org.apache.axis2.context.MessageContext axis2MC = ((Axis2MessageContext) messageContext).
                getAxis2MessageContext();
        ConfigurationContext cc = axis2MC.getConfigurationContext();
        synchronized (this) {

            if (!isResponse) {
                initThrottle(messageContext, cc);
            }
        }         // if the access is success through concurrency throttle and if this is a request message
        // then do access rate based throttling
        if (!isResponse && throttle != null) {
            AuthenticationContext authContext = APISecurityUtils.getAuthenticationContext(messageContext);
            String tier;             if (authContext != null) {
                AccessInformation info = null;
                try {

                    String ipBasedKey = (String) ((TreeMap) axis2MC.
                            getProperty("TRANSPORT_HEADERS")).get("X-Forwarded-For");
                    if (ipBasedKey == null) {
                        ipBasedKey = (String) axis2MC.getProperty("REMOTE_ADDR");
                    }
                    tier = authContext.getApplicationTier();
                    ThrottleContext apiThrottleContext =
                            ApplicationThrottleController.
                                    getApplicationThrottleContext(messageContext, cc, tier);
                    //    if (isClusteringEnable) {
                    //      applicationThrottleContext.setConfigurationContext(cc);
                    apiThrottleContext.setThrottleId(id);
                    info = applicationRoleBasedAccessController.canAccess(apiThrottleContext,
                                                                          ipBasedKey, tier);
                    canAccess = info.isAccessAllowed();
                } catch (ThrottleException e) {
                    handleException("Error while trying evaluate IPBased throttling policy", e);
                }
            }
        }         if (!canAccess) {
            handleThrottleOut(messageContext);
            return false;
        }

        return canAccess;
    }     private void initThrottle(MessageContext synCtx, ConfigurationContext cc) {
        if (policyKey == null) {
            throw new SynapseException("Throttle policy unspecified for the API");
        }         Entry entry = synCtx.getConfiguration().getEntryDefinition(policyKey);
        if (entry == null) {
            handleException("Cannot find throttling policy using key: " + policyKey);
            return;
        }
        Object entryValue = null;
        boolean reCreate = false;         if (entry.isDynamic()) {
            if ((!entry.isCached()) || (entry.isExpired()) || throttle == null) {
                entryValue = synCtx.getEntry(this.policyKey);
                if (this.version != entry.getVersion()) {
                    reCreate = true;
                }
            }
        } else if (this.throttle == null) {
            entryValue = synCtx.getEntry(this.policyKey);
        }         if (reCreate || throttle == null) {
            if (entryValue == null || !(entryValue instanceof OMElement)) {
                handleException("Unable to load throttling policy using key: " + policyKey);
                return;
            }
            version = entry.getVersion();             try {
                // Creates the throttle from the policy
                throttle = ThrottleFactory.createMediatorThrottle(
                        PolicyEngine.getPolicy((OMElement) entryValue));

            } catch (ThrottleException e) {
                handleException("Error processing the throttling policy", e);
            }
        }
    }     public void setId(String id) {
        this.id = id;
    }     public String getId(){
        return id;
    }     public void setPolicyKey(String policyKey){
        this.policyKey = policyKey;
    }     public String gePolicyKey(){
        return policyKey;
    }     private void handleException(String msg, Exception e) {
        log.error(msg, e);
        throw new SynapseException(msg, e);
    }     private void handleException(String msg) {
        log.error(msg);
        throw new SynapseException(msg);
    }     private OMElement getFaultPayload() {
        OMFactory fac = OMAbstractFactory.getOMFactory();
        OMNamespace ns = fac.createOMNamespace(APIThrottleConstants.API_THROTTLE_NS,
                                               APIThrottleConstants.API_THROTTLE_NS_PREFIX);
        OMElement payload = fac.createOMElement("fault", ns);         OMElement errorCode = fac.createOMElement("code", ns);
     errorCode.setText(String.valueOf(APIThrottleConstants.THROTTLE_OUT_ERROR_CODE));
        OMElement errorMessage = fac.createOMElement("message", ns);
        errorMessage.setText("Message Throttled Out");
        OMElement errorDetail = fac.createOMElement("description", ns);
        errorDetail.setText("You have exceeded your quota");

        payload.addChild(errorCode);
        payload.addChild(errorMessage);
        payload.addChild(errorDetail);
        return payload;
    }     private void handleThrottleOut(MessageContext messageContext) {
        messageContext.setProperty(SynapseConstants.ERROR_CODE, 900800);
        messageContext.setProperty(SynapseConstants.ERROR_MESSAGE, "Message throttled out");

        Mediator sequence = messageContext.getSequence(APIThrottleConstants.API_THROTTLE_OUT_HANDLER);
        // Invoke the custom error handler specified by the user
        if (sequence != null && !sequence.mediate(messageContext)) {
            // If needed user should be able to prevent the rest of the fault handling
            // logic from getting executed
            return;
        }         // By default we send a 503 response back
        if (messageContext.isDoingPOX() || messageContext.isDoingGET()) {
            Utils.setFaultPayload(messageContext, getFaultPayload());
        } else {
            Utils.setSOAPFault(messageContext, "Server", "Message Throttled Out",
                               "You have exceeded your quota");
        }
        org.apache.axis2.context.MessageContext axis2MC = ((Axis2MessageContext) messageContext).
                getAxis2MessageContext();

        if (Utils.isCORSEnabled()) {
            /* For CORS support adding required headers to the fault response */
            Map headers = (Map) axis2MC.getProperty(org.apache.axis2.context.MessageContext.TRANSPORT_HEADERS);
            headers.put(APIConstants.CORSHeaders.ACCESS_CONTROL_ALLOW_ORIGIN, Utils.getAllowedOrigin((String)headers.get("Origin")));
            headers.put(APIConstants.CORSHeaders.ACCESS_CONTROL_ALLOW_METHODS, Utils.getAllowedMethods());
            headers.put(APIConstants.CORSHeaders.ACCESS_CONTROL_ALLOW_HEADERS, Utils.getAllowedHeaders());
            axis2MC.setProperty(org.apache.axis2.context.MessageContext.TRANSPORT_HEADERS, headers);
        }
        Utils.sendFault(messageContext, HttpStatus.SC_SERVICE_UNAVAILABLE);
    }
}

As listed above your custom handler class is : "org.wso2.carbon.apimgt.gateway.handlers.throttling.IPBasedThrottleHandler", the following will be the handler definition for your API.


<handler class="org.wso2.carbon.apimgt.gateway.handlers.throttling.IPBasedThrottleHandler">
<property name="id" value="A"/>
<property name="policyKey" value="gov:/apimgt/applicationdata/tiers.xml"/>
</handler>

Then try to invoke API and see how throttling works.

No comments:

Post a Comment