/*
 * Copyright 2005-2007 WSO2, Inc. (http://wso2.com)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.wso2.wsas.admin.service;

import org.apache.axiom.om.OMElement;
import org.apache.axis2.AxisFault;
import org.apache.axis2.context.MessageContext;
import org.apache.axis2.description.AxisService;
import org.apache.axis2.description.Parameter;
import org.apache.axis2.engine.AxisConfiguration;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.rahas.impl.SAMLTokenIssuerConfig;
import org.wso2.utils.AbstractAdmin;
import org.wso2.wsas.ServerConstants;
import org.wso2.wsas.admin.service.util.TrustedServiceData;
import org.wso2.wsas.persistence.PersistenceManager;
import org.wso2.wsas.persistence.dataobject.KeyStoreDO;
import org.wso2.wsas.util.KeyStoreUtil;

import java.security.KeyStore;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;

/**
 * Administration service of the WSO2WSAS-SecurityTokenService
 */
public class STSAdmin extends AbstractAdmin {

    private static Log log = LogFactory.getLog(STSAdmin.class);

    /**
     * Add a the given service endpoint as a trusted endpoint address
     *
     * @param serviceAddress Address of the service endpoint
     * @param certAlias      Alias of the service cert
     * @throws AxisFault
     */
    public void addTrustedService(String serviceAddress, String certAlias) throws AxisFault {
        try {
            MessageContext msgCtx = MessageContext.getCurrentMessageContext();
            AxisConfiguration config = msgCtx.getConfigurationContext().getAxisConfiguration();
            AxisService stsService = config.getService(ServerConstants.STS_NAME);
            Parameter origParam =
                    stsService.getParameter(SAMLTokenIssuerConfig.SAML_ISSUER_CONFIG.getLocalPart());
            if (origParam != null) {
                OMElement samlConfigElem =
                        origParam.getParameterElement().
                                getFirstChildWithName(SAMLTokenIssuerConfig.SAML_ISSUER_CONFIG);
                SAMLTokenIssuerConfig samlConfig = new SAMLTokenIssuerConfig(samlConfigElem);
                samlConfig.addTrustedServiceEndpointAddress(serviceAddress, certAlias);
                setSTSParameter(samlConfig);
            } else {
                throw new AxisFault("missing parameter : "
                                    + SAMLTokenIssuerConfig.SAML_ISSUER_CONFIG.getLocalPart());
            }

        } catch (Exception e) {
            throw new AxisFault(e.getMessage(), e);
        }
    }

    public void setProofKeyType(String keyType) throws AxisFault {
        try {
            MessageContext msgCtx = MessageContext.getCurrentMessageContext();
            AxisConfiguration config = msgCtx.getConfigurationContext().getAxisConfiguration();
            AxisService service = config.getService(ServerConstants.STS_NAME);
            Parameter origParam = service
                    .getParameter(SAMLTokenIssuerConfig.SAML_ISSUER_CONFIG
                            .getLocalPart());
            if (origParam != null) {
                OMElement samlConfigElem = origParam.getParameterElement().getFirstChildWithName(
                        SAMLTokenIssuerConfig.SAML_ISSUER_CONFIG);
                SAMLTokenIssuerConfig samlConfig = new SAMLTokenIssuerConfig(samlConfigElem);
                samlConfig.setProofKeyType(keyType);
                setSTSParameter(samlConfig);
            } else {
                throw new AxisFault("missing parameter : "
                                    + SAMLTokenIssuerConfig.SAML_ISSUER_CONFIG
                        .getLocalPart());
            }

        } catch (Exception e) {
            throw new AxisFault(e.getMessage(), e);
        }
    }

    public TrustedServiceData[] getTrustedServices() throws AxisFault {
        try {
            MessageContext msgCtx = MessageContext.getCurrentMessageContext();
            AxisConfiguration config = msgCtx.getConfigurationContext().getAxisConfiguration();
            AxisService service = config.getService(ServerConstants.STS_NAME);
            Parameter origParam = service
                    .getParameter(SAMLTokenIssuerConfig.SAML_ISSUER_CONFIG
                            .getLocalPart());
            if (origParam != null) {
                OMElement samlConfigElem = origParam.getParameterElement().getFirstChildWithName(
                        SAMLTokenIssuerConfig.SAML_ISSUER_CONFIG);
                SAMLTokenIssuerConfig samlConfig = new SAMLTokenIssuerConfig(samlConfigElem);
                Map trustedServicesMap = samlConfig.getTrustedServices();
                Set addresses = trustedServicesMap.keySet();

                ArrayList serviceBag = new ArrayList();
                for (Iterator iterator = addresses.iterator(); iterator.hasNext();) {
                    String address = (String) iterator.next();
                    String alias = (String) trustedServicesMap.get(address);
                    TrustedServiceData data = new TrustedServiceData(address, alias);
                    serviceBag.add(data);
                }
                return (TrustedServiceData[])
                        serviceBag.toArray(new TrustedServiceData[serviceBag.size()]);
            } else {
                throw new AxisFault("missing parameter : "
                                    + SAMLTokenIssuerConfig.SAML_ISSUER_CONFIG.getLocalPart());
            }

        } catch (Exception e) {
            throw new AxisFault(e.getMessage(), e);
        }
    }

    public String getProofKeyType() throws AxisFault {
        try {
            MessageContext msgCtx = MessageContext.getCurrentMessageContext();
            AxisConfiguration config = msgCtx.getConfigurationContext().getAxisConfiguration();
            AxisService service = config.getService(ServerConstants.STS_NAME);
            Parameter origParam = service
                    .getParameter(SAMLTokenIssuerConfig.SAML_ISSUER_CONFIG
                            .getLocalPart());
            if (origParam != null) {
                OMElement samlConfigElem =
                        origParam.getParameterElement().
                                getFirstChildWithName(SAMLTokenIssuerConfig.SAML_ISSUER_CONFIG);
                SAMLTokenIssuerConfig samlConfig = new SAMLTokenIssuerConfig(samlConfigElem);
                return samlConfig.getProofKeyType();
            } else {
                throw new AxisFault("missing parameter : "
                                    + SAMLTokenIssuerConfig.SAML_ISSUER_CONFIG.getLocalPart());
            }
        } catch (Exception e) {
            throw new AxisFault(e.getMessage(), e);
        }
    }


    /**
     * Returns certificate aliases of primary keystore
     *
     * @return String array
     * @throws AxisFault
     */
    public String[] getCertAliasOfPrimaryKeyStore() throws AxisFault {
        PersistenceManager pm = new PersistenceManager();
        KeyStoreDO[] keyStores = pm.getKeyStores();
        KeyStoreDO primaryKeystoe = null;
        for (int i = 0; i < keyStores.length; i++) {
            if (keyStores[i].getIsPrimaryKeyStore()) {
                primaryKeystoe = keyStores[i];
                break;
            }
        }
        if (primaryKeystoe != null) {
            String keysoteName = primaryKeystoe.getKeyStoreName();
            Collection aliases = new ArrayList();
            try {
                KeyStore keyStore = KeyStoreUtil.getKeyStore(keysoteName);
                Enumeration enumeration = keyStore.aliases();
                while (enumeration.hasMoreElements()) {
                    String alias = (String) enumeration.nextElement();
                    aliases.add(alias);
                }
            } catch (Exception e) {
                String msg = "Could not read certificates from keystore file. ";
                log.error(msg, e);
                throw new AxisFault(msg + e.getMessage());
            }
            return (String[]) aliases.toArray(new String[aliases.size()]);
        }
        throw new AxisFault("Primary Keystore cannot be found.");
    }

    private void setSTSParameter(SAMLTokenIssuerConfig samlConfig) throws AxisFault {
        new ServiceAdmin().setServiceParameter3(ServerConstants.STS_NAME,
                                                samlConfig.getParameter());
    }
}
