/*
 * Copyright 2005-2007 WSO2, Inc. (http://wso2.com)
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.wso2.wsas.admin.service;

import org.apache.axis2.AxisFault;
import org.apache.axis2.context.ConfigurationContext;
import org.apache.axis2.context.MessageContext;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.wso2.utils.AbstractAdmin;
import org.wso2.utils.ServerConfiguration;
import org.wso2.utils.ServerException;
import org.wso2.utils.security.CryptoUtil;
import org.wso2.wsas.ServerConstants;
import org.wso2.wsas.admin.service.util.CertData;
import org.wso2.wsas.admin.service.util.KeyStoreData;
import org.wso2.wsas.admin.service.util.KeyStoreSummary;
import org.wso2.wsas.admin.service.util.ServiceKeyStore;
import org.wso2.wsas.persistence.PersistenceManager;
import org.wso2.wsas.persistence.dataobject.KeyStoreDO;
import org.wso2.wsas.persistence.dataobject.ServiceDO;
import org.wso2.wsas.persistence.dataobject.ServiceIdentifierDO;
import org.wso2.wsas.persistence.exception.KeyStoreAlreadyExistsException;
import org.wso2.wsas.util.KeyStoreUtil;
import sun.misc.BASE64Encoder;

import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.text.Format;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Handles all cryptographic administration functions: listing the contents of
 * key stores, adding and deleting key stores, and importing certificates.
 */
public class CryptoAdmin extends AbstractAdmin {
    private static final Log log = LogFactory.getLog(CryptoAdmin.class);
    private PersistenceManager pm = new PersistenceManager();

    /**
     * Get the aliases of all the key entries in the KeyStore file corresponding to
     * <code>keyStoreFileId</code>
     *
     * @param keyStoreFileId Id under which the key store file is registered in the
     *                       file resource map of the current configuration context
     * @param storePass      Password of the key store
     * @param keyStoreType   Type of the Key Store (JKS/PKCS12 etc)
     * @return The aliases of the key entries in the KeyStore corresponding to
     *         <code>keyStoreFileId</code>
     * @throws AxisFault If no file is registered under the id, or the key store
     *                   cannot be read
     */
    public String[] getPrivateKeys(String keyStoreFileId,
                                   String storePass,
                                   String keyStoreType) throws AxisFault {
        String filePath = getFilePathFromFileId(keyStoreFileId);
        if (filePath == null) {
            String msg = "File path corresponding to " + keyStoreFileId + " cannot be found.";
            log.error(msg);
            throw new AxisFault(msg);
        }

        List pvtKeys = new ArrayList();
        InputStream in = null;
        try {
            KeyStore keyStore = KeyStore.getInstance(keyStoreType);
            in = new BufferedInputStream(new FileInputStream(filePath));
            keyStore.load(in, storePass.toCharArray()); //Populate the keystore
            Enumeration aliases = keyStore.aliases();
            while (aliases.hasMoreElements()) {
                String alias = (String) aliases.nextElement();
                if (keyStore.isKeyEntry(alias)) {
                    pvtKeys.add(alias);
                }
            }
        } catch (Exception e) {
            String msg = "Could not read private keys from keystore file. ";
            log.error(msg, e);
            throw new AxisFault(msg + e.getMessage(), e);
        } finally {
            closeQuietly(in, filePath);
        }
        return (String[]) pvtKeys.toArray(new String[pvtKeys.size()]);
    }

    /**
     * Get all the aliases in a particular key store, sorted case-insensitively
     *
     * @param keyStoreName The name of the keystore
     * @return The aliases found in the key store
     * @throws AxisFault If the key store cannot be loaded or read
     */
    public String[] getCertificates(String keyStoreName) throws AxisFault {
        List certificates = new ArrayList();
        try {
            KeyStore keyStore = KeyStoreUtil.getKeyStore(keyStoreName);
            Enumeration aliases = keyStore.aliases();
            while (aliases.hasMoreElements()) {
                certificates.add(aliases.nextElement());
            }
        } catch (Exception e) {
            String msg = "Could not read certificates from keystore file. ";
            log.error(msg, e);
            throw new AxisFault(msg + e.getMessage(), e);
        }
        Collections.sort(certificates, String.CASE_INSENSITIVE_ORDER);
        return (String[]) certificates.toArray(new String[certificates.size()]);
    }

    /**
     * This method will list
     * 1. Certificate entries
     * 2. Private key alias (with the details of its certificate)
     * 3. Private key value, PEM encoded
     * of a given keystore. Only the first key entry found is reported.
     *
     * @param keyStoreName The name of the keystore
     * @return Instance of KeyStoreData
     * @throws AxisFault If the key store, or its persisted record, cannot be
     *                   loaded or the private key cannot be recovered
     */
    public KeyStoreData getKeystoreInfo(String keyStoreName) throws AxisFault {
        try {
            KeyStoreData keyStoreData = new KeyStoreData();
            keyStoreData.setKeyStoreName(keyStoreName);
            Format formatter = new SimpleDateFormat("dd/MM/yyyy");
            KeyStore keyStore = KeyStoreUtil.getKeyStore(keyStoreName);

            List certDataList = new ArrayList();
            boolean keyLoaded = false;
            Enumeration aliases = keyStore.aliases();
            while (aliases.hasMoreElements()) {
                String alias = (String) aliases.nextElement();
                if (keyStore.isCertificateEntry(alias)) {
                    X509Certificate cert = (X509Certificate) keyStore.getCertificate(alias);
                    certDataList.add(fillCertData(cert, alias, formatter));
                } else if (!keyLoaded && keyStore.isKeyEntry(alias)) {
                    // There be only one key entry in WSAS related keystores,
                    // so only the first one found is reported
                    X509Certificate cert = (X509Certificate) keyStore.getCertificate(alias);
                    keyStoreData.setKey(fillCertData(cert, alias, formatter));

                    KeyStoreDO keyStoreDO = pm.getKeyStore(keyStoreName);
                    if (keyStoreDO == null) {
                        throw new AxisFault("Key store " + keyStoreName + " not found!");
                    }

                    // NOTE(review): a relative Security.KeyStore.Location is resolved here
                    // with File.getAbsolutePath(), i.e. against the JVM working directory,
                    // whereas importCertificate resolves it against WSO2WSAS_HOME.
                    // Confirm which of the two is intended.
                    ServerConfiguration config = ServerConfiguration.getInstance();
                    String masterKeyStorePath =
                            new File(config.getFirstProperty("Security.KeyStore.Location")).
                                    getAbsolutePath();
                    char[] privateKeyPassword =
                            decryptPassword(keyStoreDO.getPrivateKeyPassword(),
                                            masterKeyStorePath);

                    PrivateKey key = (PrivateKey) keyStore.getKey(alias, privateKeyPassword);
                    BASE64Encoder encoder = new BASE64Encoder();
                    StringBuffer pemKey = new StringBuffer();
                    pemKey.append("-----BEGIN PRIVATE KEY-----\n");
                    pemKey.append(encoder.encode(key.getEncoded()));
                    pemKey.append("\n-----END PRIVATE KEY-----");
                    keyStoreData.setKeyValue(pemKey.toString());
                    keyStoreData.setType(keyStoreDO.getKeyStoreType());
                    keyLoaded = true;
                }
            }
            CertData[] certs =
                    (CertData[]) certDataList.toArray(new CertData[certDataList.size()]);
            keyStoreData.setCerts(certs);
            return keyStoreData;
        } catch (Exception e) {
            String msg =
                    "Error has encounted while loading the keystore to the given keystore name " +
                    keyStoreName;
            log.error(msg, e);
            throw new AxisFault(msg, e);
        }
    }

    /**
     * Copies the displayable attributes of a certificate into a CertData bean.
     *
     * @param cert      The certificate to describe
     * @param alias     The alias under which the certificate is stored
     * @param formatter Formatter applied to the validity dates
     * @return The populated CertData
     */
    private CertData fillCertData(X509Certificate cert, String alias, Format formatter) {
        CertData certData = new CertData();
        certData.setAlias(alias);
        certData.setSubjectDN(cert.getSubjectDN().getName());
        certData.setIssuerDN(cert.getIssuerDN().getName());
        certData.setSerialNumber(cert.getSerialNumber());
        certData.setVersion(cert.getVersion());
        certData.setNotAfter(formatter.format(cert.getNotAfter()));
        certData.setNotBefore(formatter.format(cert.getNotBefore()));
        return certData;
    }

    /**
     * Adds an uploaded key store: the file registered under <code>ksFilePathId</code>
     * is loaded, every key entry other than <code>pvtKeyAlias</code> is removed, the
     * result is written into the directory configured as
     * <code>Security.KeyStoresDir</code>, and the key store is then persisted through
     * KeyStoreUtil. If writing or persisting fails, the file written by this call is
     * removed again so that the operation can be retried.
     *
     * @param ksFilePathId   Id under which the uploaded key store file is registered
     * @param ksPassword     Password of the key store
     * @param pvtKeyAlias    Alias of the private key to keep
     * @param pvtKeyPassword Password of the private key
     * @param keyStoreType   Type of the Key Store (JKS/PKCS12 etc)
     * @param provider       Provider, handed on unchanged to the persistence step
     * @return A message describing the outcome; a missing upload and an alias which
     *         is not a key entry are reported through this message, not a fault
     * @throws AxisFault If the key store already exists, cannot be read or written,
     *                   the private key cannot be recovered, or persisting fails
     */
    public String addNewKeyStore(String ksFilePathId,
                                 String ksPassword,
                                 String pvtKeyAlias,
                                 String pvtKeyPassword,
                                 String keyStoreType,
                                 String provider) throws AxisFault {
        String ksFilePath = getFilePathFromFileId(ksFilePathId);
        if (ksFilePath == null) {
            return "File path corresponding to " + ksFilePathId + " cannot be found.";
        }
        ServerConfiguration serverConfig = ServerConfiguration.getInstance();
        String keyStoreName = new File(ksFilePath).getName();

        // Move the KS file from work to WSO2WSAS_HOME/conf/keystores, if there is more than
        // 1 Kpr, delete all Kpr which are not equal to pvtKeyAlias
        File ksFile = null;
        boolean fileCreated = false;
        boolean stored = false;
        InputStream in = null;
        OutputStream os = null;
        try {
            // mkdir keystore
            File ksDir = new File(serverConfig.getFirstProperty("Security.KeyStoresDir"));
            if (!ksDir.exists() && !ksDir.mkdirs() && !ksDir.isDirectory()) {
                throw new IOException("Could not create keystore directory " +
                                      ksDir.getAbsolutePath());
            }

            // Check whether KS file already exists
            ksFile = new File(ksDir.getAbsolutePath(), keyStoreName);
            if (ksFile.exists()) {
                throw new AxisFault("Keystore file " + ksFile.getName() + " already exists!");
            }

            KeyStore keyStore = KeyStore.getInstance(keyStoreType);
            in = new BufferedInputStream(new FileInputStream(ksFilePath));
            keyStore.load(in, ksPassword.toCharArray()); //Populate the keystore

            if (!keyStore.isKeyEntry(pvtKeyAlias)) {
                return pvtKeyAlias + " is not a key entry";
            }

            // Verifies the key password; the recovered key itself is not needed
            keyStore.getKey(pvtKeyAlias, pvtKeyPassword.toCharArray());

            // Remove all other private keys. The aliases are collected first because
            // deleting entries while the aliases() enumeration is still being
            // traversed is not safe for every KeyStore implementation.
            List surplusKeyAliases = new ArrayList();
            Enumeration aliases = keyStore.aliases();
            while (aliases.hasMoreElements()) {
                String alias = (String) aliases.nextElement();
                if (keyStore.isKeyEntry(alias) && !alias.equals(pvtKeyAlias)) {
                    surplusKeyAliases.add(alias);
                }
            }
            for (Iterator iterator = surplusKeyAliases.iterator(); iterator.hasNext();) {
                keyStore.deleteEntry((String) iterator.next());
            }

            // Move the KS file to WSO2WSAS_HOME/conf/keystores
            os = new FileOutputStream(ksFile);
            fileCreated = true;
            keyStore.store(os, ksPassword.toCharArray());
            os.flush();
            os.close();
            stored = true;
        } catch (UnrecoverableKeyException e) {
            String msg = "Cannot retrieve private key." +
                         " Please verify that the password is correct.";
            log.error(msg, e);
            throw new AxisFault(msg, e);
        } catch (Exception e) {
            String msg = "Could not add new keystore. ";
            log.error(msg, e);
            throw new AxisFault(msg + e.getMessage(), e);
        } finally {
            closeQuietly(in, ksFilePath);
            closeQuietly(os, ksFilePath);
            if (fileCreated && !stored) {
                // Do not leave a partially written keystore behind
                discardFile(ksFile);
            }
        }

        // Store the KS entry in the database
        try {
            KeyStoreUtil.persistKeyStore(ksFile.getAbsolutePath(),
                                         ksPassword,
                                         keyStoreType,
                                         pvtKeyAlias,
                                         pvtKeyPassword,
                                         provider,
                                         false);
        } catch (KeyStoreAlreadyExistsException e) {
            discardFile(ksFile);
            String msg = "Cannot add new keystore. ";
            log.error(msg, e);
            throw new AxisFault(msg, e);
        } catch (ServerException e) {
            discardFile(ksFile);
            String msg = "Cannot add new keystore. ";
            log.error(msg, e);
            throw new AxisFault(msg + e.getMessage(), e);
        }
        return "Keystore " + keyStoreName + " successfully added.";
    }

    /**
     * Imports the X.509 certificate registered under <code>certificatePathId</code>
     * into the given key store and writes the key store back to its file. The file
     * name of the certificate is used as the alias; if that alias is already taken,
     * a numeric suffix (<code>.1</code>, <code>.2</code>, ...) is appended.
     *
     * @param keyStoreName      The name of the keystore
     * @param certificatePathId Id under which the uploaded certificate file is registered
     * @return A message stating that the import succeeded
     * @throws AxisFault If the key store or the certificate cannot be read, or the
     *                   key store cannot be written
     */
    public String importCertificate(String keyStoreName,
                                    String certificatePathId) throws AxisFault {
        String msg;
        InputStream ksIn = null;
        InputStream certIn = null;
        OutputStream os = null;
        try {
            KeyStoreDO keyStoreDO = pm.getKeyStore(keyStoreName);
            KeyStore keyStore = KeyStore.getInstance(keyStoreDO.getKeyStoreType());
            ksIn = new BufferedInputStream(new FileInputStream(keyStoreDO.getFilePath()));

            ServerConfiguration config = ServerConfiguration.getInstance();
            String masterKeyStorePath = config.getFirstProperty("Security.KeyStore.Location");
            if (!new File(masterKeyStorePath).isAbsolute()) {
                masterKeyStorePath = System.getProperty(ServerConstants.WSO2WSAS_HOME) +
                                     File.separator + masterKeyStorePath;
            }
            char[] decryptedStorePass =
                    decryptPassword(keyStoreDO.getStorePassword(), masterKeyStorePath);
            keyStore.load(ksIn, decryptedStorePass);
            // The same file is rewritten below, so release the read handle first
            ksIn.close();

            String certFile = getFilePathFromFileId(certificatePathId);
            if (certFile == null) {
                throw new AxisFault("File path corresponding to " + certificatePathId +
                                    " cannot be found.");
            }
            certIn = new FileInputStream(certFile);

            String origCertAlias = new File(certFile).getName();
            String certAlias = origCertAlias;
            int seq = 0;
            while (keyStore.containsAlias(certAlias)) {
                seq++;
                certAlias = origCertAlias + "." + seq;
            }
            keyStore.setCertificateEntry(certAlias,
                                         CertificateFactory.getInstance("X.509").
                                                 generateCertificate(certIn));
            os = new FileOutputStream(new File(keyStoreDO.getFilePath()));
            keyStore.store(os, decryptedStorePass);
            os.flush();
            os.close();
        } catch (Exception e) {
            msg = "Could not import certificate. Certificate may be invalid. ";
            log.error(msg, e);
            throw new AxisFault(msg, e);
        } finally {
            closeQuietly(ksIn, keyStoreName);
            closeQuietly(certIn, keyStoreName);
            closeQuietly(os, keyStoreName);
        }
        msg = "Certificate imported successfully";
        return msg;
    }

    /**
     * Get the names of all the key stores known to the persistence manager
     *
     * @return The key store names
     */
    public String[] getAllKeyStoreNames() {
        KeyStoreDO[] keyStores = pm.getKeyStores();
        String[] ksNames = new String[keyStores.length];

        for (int i = 0; i < keyStores.length; i++) {
            ksNames[i] = keyStores[i].getKeyStoreName();
        }

        return ksNames;
    }

    /**
     * Lists every known key store, flagging those which the given service uses as
     * a trusted certificate store and the one it uses as its private key store.
     *
     * @param serviceName The name of the service
     * @return One ServiceKeyStore per known key store
     * @throws IllegalArgumentException If no service with the given name is found
     */
    public ServiceKeyStore[] getServiceKeyStores(String serviceName) {
        String serviceVersion = ServiceIdentifierDO.EMPTY_SERVICE_VERSION;
        ServiceDO service = pm.getService(serviceName, serviceVersion);
        if (service == null) {
            throw new IllegalArgumentException("Service " + serviceName + " not found");
        }
        List list = new ArrayList();
        KeyStoreDO[] keyStores = pm.getKeyStores();
        KeyStoreDO privateKeyStore = service.getPrivateKeyStore();
        for (int i = 0; i < keyStores.length; i++) {
            ServiceKeyStore serviceKS = new ServiceKeyStore();
            KeyStoreDO keyStore = keyStores[i];
            serviceKS.setKeyStoreName(keyStore.getKeyStoreName());
            serviceKS.setKeyStoreType(keyStore.getKeyStoreType());
            for (Iterator iterator =
                    service.getTrustedCertStores().iterator(); iterator.hasNext();) {
                KeyStoreDO trustedKS = (KeyStoreDO) iterator.next();
                if (trustedKS.getKeyStoreName().equals(keyStore.getKeyStoreName())) {
                    serviceKS.setSelected(true);
                    break;
                }
            }
            if (privateKeyStore != null &&
                keyStore.getKeyStoreName().equals(privateKeyStore.getKeyStoreName())) {
                serviceKS.setPrivateKeyStore(true);
            }
            list.add(serviceKS);
        }
        return (ServiceKeyStore[]) list.toArray(new ServiceKeyStore[list.size()]);
    }

    /**
     * Get a summary (file path, name and type) of every known key store
     *
     * @return The key store summaries
     */
    public KeyStoreSummary[] getKeyStores() {
        KeyStoreDO[] keyStores = pm.getKeyStores();
        KeyStoreSummary[] ksSummaries = new KeyStoreSummary[keyStores.length];

        for (int i = 0; i < keyStores.length; i++) {
            KeyStoreDO keyStore = keyStores[i];
            KeyStoreSummary ksSummary = new KeyStoreSummary();
            ksSummary.setKeyStoreFilename(keyStore.getFilePath());
            ksSummary.setKeyStoreName(keyStore.getKeyStoreName());
            ksSummary.setKeyStoreType(keyStore.getKeyStoreType());
            ksSummaries[i] = ksSummary;
        }

        return ksSummaries;
    }

    /**
     * Deletes a key store: both its file and its persisted record. The primary key
     * store, and key stores which are still referenced by a service as a trusted
     * certificate store or as a private key store, are not deleted.
     *
     * @param keyStoreName The name of the keystore
     * @return A message describing the outcome; an unknown key store is reported
     *         through this message, not a fault
     * @throws AxisFault If the key store is the primary key store, is still in use
     *                   by a service, or its file cannot be deleted
     */
    public String deleteKeyStore(String keyStoreName) throws AxisFault {
        KeyStoreDO keyStore = pm.getKeyStore(keyStoreName);
        if (keyStore == null) {
            return "Key store " + keyStoreName + " not found!";
        }

        if (keyStore.getIsPrimaryKeyStore()) {
            throw new AxisFault("Primary keystore " + keyStoreName + " cannot be deleted!");
        }

        // Check whether we can delete this KS
        Set trustStoreServices = keyStore.getTrustStoreServices();
        if (!trustStoreServices.isEmpty()) {
            throw new AxisFault("Cannot delete keystore since the following services" +
                                " have been associated with a security scenario using this keystore" +
                                " as a trusted certificate store:<br/><br/>" +
                                listServiceIds(trustStoreServices) + "<br/>" +
                                "You may disable security for these services and retry.");
        }
        Set pkStoreServices = keyStore.getPkStoreServices();
        if (!pkStoreServices.isEmpty()) {
            throw new AxisFault("Cannot delete keystore since the following services" +
                                " have been associated with a security scenario using this keystore" +
                                " as a private key store:<br/><br/>" +
                                listServiceIds(pkStoreServices) + "<br/>" +
                                "You may disable security for these services and retry.");
        }

        // A file which is already gone must not block the removal of the record
        String filePath = keyStore.getFilePath();
        File ksFile = new File(filePath);
        if (ksFile.exists() && !ksFile.delete()) {
            String msg = "Could not delete keystore file " + filePath +
                         ". Due to a JVM issue on MS-Windows, files may not be deleted sometimes.";
            throw new AxisFault(msg);
        }
        pm.deleteKeyStore(keyStoreName);
        return "Key store " + keyStoreName + " successfully deleted.";
    }

    /**
     * Builds a numbered, <code>&lt;br/&gt;</code> separated list of the ids of the
     * given services, for use in the fault messages of deleteKeyStore.
     *
     * @param services Set of ServiceDO instances
     * @return The formatted list
     */
    private static String listServiceIds(Set services) {
        StringBuffer serviceIds = new StringBuffer();
        int i = 1;
        for (Iterator iterator = services.iterator(); iterator.hasNext();) {
            ServiceDO s = (ServiceDO) iterator.next();
            serviceIds.append(i).append(". ").
                    append(s.getServiceIdentifierDO().getServiceId()).append("<br/>");
            i++;
        }
        return serviceIds.toString();
    }

    /**
     * Decrypts a persisted, encrypted and base64 encoded password using a CryptoUtil
     * built from the <code>Security.KeyStore.*</code> properties of the server
     * configuration.
     *
     * @param encryptedPassword  The password as persisted
     * @param masterKeyStorePath Resolved location of the key store handed to CryptoUtil
     * @return The decrypted password
     * @throws Exception Whatever CryptoUtil raises while decrypting
     */
    private static char[] decryptPassword(String encryptedPassword,
                                          String masterKeyStorePath) throws Exception {
        ServerConfiguration config = ServerConfiguration.getInstance();
        CryptoUtil cryptoUtil =
                new CryptoUtil(masterKeyStorePath,
                               config.getFirstProperty("Security.KeyStore.Password"),
                               config.getFirstProperty("Security.KeyStore.KeyAlias"),
                               config.getFirstProperty("Security.KeyStore.KeyPassword"),
                               config.getFirstProperty("Security.KeyStore.Type"));
        return new String(cryptoUtil.base64DecodeAndDecrypt(encryptedPassword)).toCharArray();
    }

    /**
     * Closes an input stream, logging instead of propagating a failure so that the
     * remaining cleanup of the caller still runs.
     *
     * @param in           The stream to close, may be null
     * @param resourceName Name of the resource, used in the log message
     */
    private static void closeQuietly(InputStream in, String resourceName) {
        if (in == null) {
            return;
        }
        try {
            in.close();
        } catch (IOException e) {
            log.error("Error occurred while closing keystore file " + resourceName, e);
        }
    }

    /**
     * Closes an output stream, logging instead of propagating a failure so that the
     * remaining cleanup of the caller still runs.
     *
     * @param out          The stream to close, may be null
     * @param resourceName Name of the resource, used in the log message
     */
    private static void closeQuietly(OutputStream out, String resourceName) {
        if (out == null) {
            return;
        }
        try {
            out.close();
        } catch (IOException e) {
            log.error("Error occurred while closing keystore file " + resourceName, e);
        }
    }

    /**
     * Removes a file written by a failed operation, logging if it cannot be removed.
     *
     * @param file The file to remove, may be null
     */
    private static void discardFile(File file) {
        if (file != null && file.exists() && !file.delete()) {
            log.warn("Could not remove keystore file " + file.getAbsolutePath());
        }
    }

    /**
     * Looks up the path registered under a file id in the file resource map of the
     * current configuration context.
     *
     * @param fileId The id of the file
     * @return The file path, or null if there is no current message context, no
     *         file resource map, or no entry for the id
     */
    private String getFilePathFromFileId(String fileId) {
        MessageContext msgCtx = MessageContext.getCurrentMessageContext();
        if (msgCtx == null) {
            return null;
        }
        ConfigurationContext configCtx = msgCtx.getConfigurationContext();
        Map fileResMap = (Map) configCtx.getProperty(ServerConstants.FILE_RESOURCE_MAP);
        if (fileResMap == null) {
            return null;
        }
        return (String) fileResMap.get(fileId);
    }
}
