/*
* Copyright 2005,2006 WSO2, Inc. http://wso2.com
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*      http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*
*/

package org.wso2.throttle.module.handler;


import org.apache.axis2.AxisFault;
import org.apache.axis2.clustering.ClusterManager;
import org.apache.axis2.clustering.ClusteringFault;
import org.apache.axis2.clustering.context.Replicator;
import org.apache.axis2.context.ConfigurationContext;
import org.apache.axis2.context.MessageContext;
import org.apache.axis2.description.AxisOperation;
import org.apache.axis2.description.AxisService;
import org.apache.axis2.handlers.AbstractHandler;
import org.apache.axis2.transport.http.HTTPConstants;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.wso2.throttle.*;

import javax.servlet.http.HttpServletRequest;
import javax.xml.namespace.QName;
import java.util.Map;


/**
 * Base Axis2 handler that applies throttling to messages passing through it:
 * 1) concurrency throttling via a ConcurrentAccessController, and
 * 2) access-rate throttling for the remote caller, by domain name or, failing that, by IP.
 * Subclasses choose the throttle scope (global / service / operation) via {@link #getThrottleType()}.
 */
public abstract class ThrottleHandler extends AbstractHandler {

    private static final Log log = LogFactory.getLog(ThrottleHandler.class.getName());

    /* The AccessRateController - control(limit) access for a remote caller */
    private final AccessRateController accessRateController;

    /* Whether debug logging is enabled - evaluated once, when the handler is created */
    private final boolean debugOn;

    public ThrottleHandler() {
        this.debugOn = log.isDebugEnabled();
        this.accessRateController = new AccessRateController();
    }

    /**
     * @return int - indicates the type of the throttle according to the scope; loadThrottle
     *         supports ThrottleConstants.GLOBAL_THROTTLE, OPERATION_BASED_THROTTLE and
     *         SERVICE_BASED_THROTTLE
     */
    protected abstract int getThrottleType();

    /**
     * Loads the Throttle for a particular throttle type from the throttles map that is stored
     * (non-replicable) in the configuration context under ThrottleConstants.THROTTLES_MAP.
     * The map is looked up by GLOBAL_THROTTLE_KEY (global), by the service name (service based)
     * or by service name + operation local name (operation based).
     *
     * @param messageContext - The messageContext
     * @param throttleType   - The type of throttle
     * @return Throttle - the matching Throttle, or null if there is no throttles map, no axis
     *         operation/service for the message, or no throttle registered for the key
     * @throws ThrottleException Throws if the throttle type is unsupported
     */
    public Throttle loadThrottle(MessageContext messageContext,
                                 int throttleType) throws ThrottleException {

        Throttle throttle = null;
        ConfigurationContext configContext = messageContext.getConfigurationContext();
        // the throttles map is kept as a non-replicable property of the configuration context
        Map throttles = (Map) configContext.getPropertyNonReplicable(ThrottleConstants.THROTTLES_MAP);
        if (throttles == null) {
            if (debugOn) {
                log.debug("Couldn't find thottles object map .. thottlling will not be occurred ");
            }
            return null;
        }
        switch (throttleType) {
            case ThrottleConstants.GLOBAL_THROTTLE: {
                throttle =
                    (Throttle) throttles.get(ThrottleConstants.GLOBAL_THROTTLE_KEY);
                break;
            }
            case ThrottleConstants.OPERATION_BASED_THROTTLE: {
                AxisOperation axisOperation = messageContext.getAxisOperation();
                if (axisOperation != null) {
                    QName opName = axisOperation.getName();
                    if (opName != null) {
                        AxisService service = (AxisService) axisOperation.getParent();
                        if (service != null) {
                            String currentServiceName = service.getName();
                            if (currentServiceName != null) {
                                // operation throttles are keyed by service name + operation local name
                                throttle =
                                    (Throttle) throttles.get(currentServiceName + opName.getLocalPart());
                            }
                        }
                    }
                } else {
                    if (debugOn) {
                        log.debug("Couldn't find axis operation ");
                    }
                    return null;
                }
                break;
            }
            case ThrottleConstants.SERVICE_BASED_THROTTLE: {
                AxisService axisService = messageContext.getAxisService();
                if (axisService != null) {
                    throttle =
                        (Throttle) throttles.get(axisService.getName());
                } else {
                    if (debugOn) {
                        log.debug("Couldn't find axis service ");
                    }
                    return null;
                }
                break;
            }
            default: {
                throw new ThrottleException("Unsupported Throttle type");
            }
        }
        return throttle;
    }

    /**
     * processing through the throttle
     * 1) concurrent throttling
     * 2) access rate based throttling - domain or ip (IN_FLOW only; IP based throttling is
     * applied only when no domain based configuration matched the caller)
     * <p/>
     * In a clustered env. (a ClusterManager with a ContextManager is configured) the
     * ConcurrentAccessController is read from the configuration context and its state is
     * replicated after it changes.
     *
     * @param throttle       The Throttle object - holds all configuration and state data of the throttle
     * @param messageContext The MessageContext , that holds all data per message basis
     * @throws AxisFault         Throws when access must deny for caller
     * @throws ThrottleException declared for subclasses/callers; not thrown by this implementation
     */
    public void process(Throttle throttle,
                        MessageContext messageContext) throws ThrottleException, AxisFault {

        String throttleId = throttle.getId();
        ConfigurationContext cc = messageContext.getConfigurationContext();

        //check the env - whether clustered  or not
        boolean isClusteringEnable = false;
        ClusterManager clusterManager = cc.getAxisConfiguration().getClusterManager();
        if (clusterManager != null &&
            clusterManager.getContextManager() != null) {
            isClusteringEnable = true;
        }

        // Get the concurrent access controller
        ConcurrentAccessController cac;
        String key = null;
        if (isClusteringEnable) {
            // for clustered env., gets it from the configuration context
            key = ThrottleConstants.THROTTLE_PROPERTY_PREFIX + throttleId
                + ThrottleConstants.CAC_SUFFIX;
            cac = (ConcurrentAccessController) cc.getProperty(key);
        } else {
            // for non-clustered env., gets it from the throttle itself
            cac = throttle.getConcurrentAccessController();
        }

        // check for concurrent access
        boolean canAccess = doConcurrentThrottling(cac, messageContext);

        if (!canAccess) {
            // doConcurrentThrottling has already taken a slot (getAndDecrement) for this request.
            // The request is rejected with a fault, and doConcurrentThrottling only gives slots
            // back on OUT_FLOW, so (presumably) nothing else will return it. Return it here, like
            // the rate-based denial paths below do; otherwise every rejected request would
            // permanently shrink the available concurrency.
            releaseConcurrentAccess(cc, key, cac, isClusteringEnable);
            throw new AxisFault("Access has currently been denied since " +
                " maximum concurrent access have exceeded");
        }

        // the concurrent access is success, so do the access rate based throttling
        if (messageContext.getFLOW() == MessageContext.IN_FLOW) {
            //gets the remote caller domain name
            String domain = null;
            HttpServletRequest request =
                (HttpServletRequest) messageContext.getPropertyNonReplicable(
                    HTTPConstants.MC_HTTP_SERVLETREQUEST);
            if (request != null) {
                domain = request.getRemoteHost();
            }

            // Domain name based throttling
            //check whether a configuration has been defined for this domain name or not
            String callerId = null;
            if (domain != null) {
                //loads the ThrottleContext
                ThrottleContext context =
                    throttle.getThrottleContext(ThrottleConstants.DOMAIN_BASED_THROTTLE_KEY);
                if (context != null) {
                    //Loads the ThrottleConfiguration
                    ThrottleConfiguration config = context.getThrottleConfiguration();
                    if (config != null) {
                        //check for configuration for this caller
                        callerId = config.getConfigurationKeyOfCaller(domain);
                        if (callerId != null) {
                            // If this is a clustered env.
                            if (isClusteringEnable) {
                                context.setConfigurationContext(cc);
                                context.setThrottleId(throttleId);
                            }
                            //check for the permission for access
                            if (!accessRateController.canAccess(context, callerId,
                                ThrottleConstants.DOMAIN_BASE)) {
                                // The rate-based throttling denied access after a concurrency slot
                                // was taken. Give the slot back; otherwise, when the access rate
                                // is lower than the maximum concurrent access, all slots could
                                // leak and lock every caller out.
                                releaseConcurrentAccess(cc, key, cac, isClusteringEnable);
                                throw new AxisFault("A caller with domain " + domain + " cannot access "
                                    + "this service since the allocated quota  have been exceeded.");
                            }
                        } else {
                            // no domain-based configuration matched this caller - fall through to IP
                            if (debugOn) {
                                log.debug("Could not find the Throttle Context for domain-Based " +
                                    "Thottling for domain name " + domain + " Throttling for this " +
                                    "domain name may not be configured from policy");
                            }
                        }
                    }
                }
            } else {
                if (debugOn) {
                    log.debug("Could not find the domain of the caller - IP-based throttling may occur");
                }
            }

            //IP based throttling - Only if there is no configuration for caller domain name
            if (callerId == null) {
                String ip = (String) messageContext.getProperty(MessageContext.REMOTE_ADDR);
                if (ip != null) {
                    // loads IP based throttle context
                    ThrottleContext context =
                        throttle.getThrottleContext(ThrottleConstants.IP_BASED_THROTTLE_KEY);
                    if (context != null) {
                        //Loads the ThrottleConfiguration
                        ThrottleConfiguration config = context.getThrottleConfiguration();
                        if (config != null) {
                            // check for configuration for this ip
                            callerId = config.getConfigurationKeyOfCaller(ip);
                            if (callerId != null) {
                                // for clustered env.
                                if (isClusteringEnable) {
                                    context.setConfigurationContext(cc);
                                    context.setThrottleId(throttleId);
                                }
                                // check for the permission for access
                                if (!accessRateController.canAccess(context
                                    , callerId
                                    , ThrottleConstants.IP_BASE)) {
                                    // same slot give-back as for the domain-based denial above
                                    releaseConcurrentAccess(cc, key, cac, isClusteringEnable);
                                    throw new AxisFault("A caller with IP " + ip + " " +
                                        "cannot access this service since the allocated quota" +
                                        "  have been exceeded.");
                                }
                            }
                        }
                    } else {
                        if (debugOn) {
                            log.debug("Could not find the Throttle Context for IP-Based Thottling");
                        }
                    }
                } else {
                    if (debugOn) {
                        log.debug("Could not find the IP address of the caller " +
                            "- throttling will not occur");
                    }
                }
            }
        }

        // The access rate based throttling handles its own replication;
        // just replicate the current state of the ConcurrentAccessController.
        if (isClusteringEnable && cac != null) {
            replicateConcurrentAccessController(cc, key, cac);
        }
    }

    /**
     * Returns the concurrency slot that doConcurrentThrottling took for a request that is about
     * to be rejected. In a clustered env. it also sets the controller back on the configuration
     * context and replicates it.
     *
     * @param cc                 the configuration context of the message
     * @param key                property key of the controller in the configuration context
     *                           (null in a non-clustered env.)
     * @param cac                the ConcurrentAccessController; if null there is nothing to release
     * @param isClusteringEnable whether the env. is clustered
     */
    private void releaseConcurrentAccess(ConfigurationContext cc, String key,
                                         ConcurrentAccessController cac,
                                         boolean isClusteringEnable) {
        if (cac == null) {
            return;
        }
        cac.incrementAndGet();
        if (isClusteringEnable) {
            replicateConcurrentAccessController(cc, key, cac);
        }
    }

    /**
     * Sets the ConcurrentAccessController back on the configuration context under the given key
     * and replicates the configuration context state across the cluster. A replication failure
     * is logged and not propagated.
     *
     * @param cc  the configuration context holding the controller
     * @param key property key of the controller
     * @param cac the (already updated) ConcurrentAccessController
     */
    private void replicateConcurrentAccessController(ConfigurationContext cc, String key,
                                                     ConcurrentAccessController cac) {
        // The controller was mutated in place. Set it back explicitly, as the rate-denial path
        // always did, so the change is seen as a property update when replicating.
        // NOTE(review): assumes Replicator.replicate(cc) ships properties updated through
        // setProperty - confirm against the Axis2 clustering version in use.
        cc.setProperty(key, cac);
        try {
            if (debugOn) {
                log.debug("Going to replicates the states of the ConcurrentAccessController" +
                    " with key : " + key);
            }
            Replicator.replicate(cc);
        } catch (ClusteringFault clusteringFault) {
            log.error("Error during replicating states ", clusteringFault);
        }
    }

    /**
     * Helper method for handling concurrent throttling.
     * On IN_FLOW a slot is taken (getAndDecrement) and access is allowed only if the value read
     * was positive. On OUT_FLOW a slot is given back (incrementAndGet). Other flows are ignored.
     *
     * @param concurrentAccessController ConcurrentAccessController, may be null (no concurrency limit)
     * @param messageContext             MessageContext - message level states
     * @return true if access is allowed through concurrent throttling ,o.w false
     */
    private boolean doConcurrentThrottling(ConcurrentAccessController concurrentAccessController,
                                           MessageContext messageContext) {

        boolean canAccess = true;
        int available;

        if (concurrentAccessController != null) {
            if (messageContext.getFLOW() == MessageContext.IN_FLOW) {
                available = concurrentAccessController.getAndDecrement();
                canAccess = available > 0;
                if (debugOn) {
                    log.debug("Concurrency Throttle : Access " + (canAccess ? "allowed" : "denied") +
                        " :: " + available + " of available of " +
                        concurrentAccessController.getLimit() + " connections");
                    if (!canAccess) {
                        log.debug("Concurrency Throttle : Access has currently been denied since allowed" +
                            " maximum concurrent access have exceeded");
                    }
                }
            } else if (messageContext.getFLOW() == MessageContext.OUT_FLOW) {
                available = concurrentAccessController.incrementAndGet();
                if (debugOn) {
                    log.debug("Concurrency Throttle : Connection returned" +
                        " :: " + available + " of available of "
                        + concurrentAccessController.getLimit() + " connections");
                }
            }
        }
        return canAccess;
    }

    /**
     * Loads the throttle for this handler's throttle type and, if one is configured, runs the
     * message through it.
     *
     * @param msgContext the message context
     * @return InvocationResponse.CONTINUE when the message may proceed
     * @throws AxisFault if access is denied, or wrapping a ThrottleException (cause preserved)
     */
    public InvocationResponse invoke(MessageContext msgContext) throws AxisFault {
        //Load throttle
        try {
            Throttle throttle = loadThrottle(msgContext, getThrottleType());
            if (throttle != null) {
                process(throttle, msgContext);
            }
        } catch (ThrottleException e) {
            // keep the original exception as the cause so the stack trace is not lost
            log.error(e.getMessage(), e);
            throw new AxisFault(e.getMessage(), e);
        }
        return InvocationResponse.CONTINUE;
    }

}
