WSO2 API Manager - Decode JWT in ESB using JWT Decode mediator.

In this post I list the code for a JWTDecoder mediator, which can be used in WSO2 ESB or any other Synapse-based WSO2 product to decode the JWT header.

After a message has gone through this mediator, all claims in the JWT are present in the message context as properties. Each property name is the claim name (with the claim URI prefix removed), so you can use the claims in the rest of the mediation flow.

import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.charset.Charset;
import java.security.GeneralSecurityException;
import java.security.InvalidKeyException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.Signature;
import java.security.SignatureException;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Enumeration;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.axiom.util.base64.Base64Utils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.oltu.oauth2.jwt.JWTException;
import org.apache.oltu.oauth2.jwt.JWTProcessor;
import org.apache.synapse.ManagedLifecycle;
import org.apache.synapse.MessageContext;
import org.apache.synapse.SynapseException;
import org.apache.synapse.SynapseLog;
import org.apache.synapse.core.SynapseEnvironment;
import org.apache.synapse.core.axis2.Axis2MessageContext;
import org.apache.synapse.mediators.AbstractMediator;
import org.wso2.carbon.context.CarbonContext;
import org.wso2.carbon.core.util.KeyStoreManager;



public class JWTDecoder extends AbstractMediator implements ManagedLifecycle {
   
    private static Log log = LogFactory.getLog(JWTDecoder.class);

    private final String CLAIM_URI = "http://wso2.org/claims/";
    private final String SCIM_CLAIM_URI = "urn:scim:schemas:core:1.0:";

    private KeyStore keyStore;

    public void init(SynapseEnvironment synapseEnvironment) {
        if (log.isInfoEnabled()) {
            log.info("Initializing JWTDecoder Mediator");
        }
        String keyStoreFile = "";
        String password = "";

        try {
            keyStore = KeyStore.getInstance("JKS");
        } catch (KeyStoreException e) {
            //throw new Exception("Unable to get JKS KeyStore instance");
        }
        char[] storePass = password.toCharArray();

        // load the key store from file system
        FileInputStream fileInputStream = null;
        try {
            fileInputStream = new FileInputStream(keyStoreFile);
            keyStore.load(fileInputStream, storePass);
            fileInputStream.close();
        } catch (FileNotFoundException e) {
            if (log.isErrorEnabled()) {
                log.error("Error loading keystore", e);
            }
        } catch (NoSuchAlgorithmException e) {
            if (log.isErrorEnabled()) {
                log.error("Error loading keystore", e);
            }
        } catch (CertificateException e) {
            if (log.isErrorEnabled()) {
                log.error("Error loading keystore", e);
            }
        } catch (IOException e) {
            if (log.isErrorEnabled()) {
                log.error("Error loading keystore", e);
            }
        }
    }

    public boolean mediate(MessageContext synapseContext) {
        SynapseLog synLog = getLog(synapseContext);

        if (synLog.isTraceOrDebugEnabled()) {
            synLog.traceOrDebug("Start : JWTDecoder mediator");
            if (synLog.isTraceTraceEnabled()) {
                synLog.traceTrace("Message : " + synapseContext.getEnvelope());
            }
        }

        // Extract the HTTP headers and then extract the JWT from the HTTP Header map
        org.apache.axis2.context.MessageContext axis2MessageContext = ((Axis2MessageContext) synapseContext).getAxis2MessageContext();
        Object headerObj = axis2MessageContext.getProperty(org.apache.axis2.context.MessageContext.TRANSPORT_HEADERS);
        @SuppressWarnings("unchecked")
        Map headers = (Map) headerObj;
        String jwt_assertion = (String) headers.get("x-jwt-assertion");

        // Incoming request does not contain the JWT assertion
        if (jwt_assertion == null || jwt_assertion == "") {
            // Since this is an unauthorized request, send the response back to client with 401 - Unauthorized error
            synapseContext.setTo(null);
            synapseContext.setResponse(true);
            axis2MessageContext.setProperty("HTTP_SC", "401");
            // Log the authentication failure
            String err = "JWT assertion not found in the message header";
            handleException(err, synapseContext);
            return false;
        }

        boolean isSignatureVerified = verifySignature(jwt_assertion, synapseContext);

        try {
            if (isSignatureVerified) {
                // Process the JWT, extract the values and set them to the Synapse environment
                if (log.isDebugEnabled()){
                    log.debug("JWT assertion is : "+jwt_assertion);
                }
                JWTProcessor processor = new JWTProcessor().process(jwt_assertion);
                Map claims = processor.getPayloadClaims();
                for (Map.Entry claimEntry : claims.entrySet()) {
                    // Extract the claims and set it in Synapse context
                    if (claimEntry.getKey().startsWith(CLAIM_URI)) {
                        String tempPropName = claimEntry.getKey().split(CLAIM_URI)[1];
                        synapseContext.setProperty(tempPropName, claimEntry.getValue());
                        if(log.isDebugEnabled()){
                            log.debug("Getting claim :"+tempPropName+" , " +claimEntry.getValue() );
                        }
                    } else if (claimEntry.getKey().startsWith(SCIM_CLAIM_URI)) {
                        String tempPropName = claimEntry.getKey().split(SCIM_CLAIM_URI)[1];
                        if (tempPropName.contains(".")) {
                            tempPropName = tempPropName.split("\\.")[1];
                        }

                        synapseContext.setProperty(tempPropName, claimEntry.getValue());
                        if(log.isDebugEnabled()){
                            log.debug("Getting claim :"+tempPropName+" , " +claimEntry.getValue() );
                        }
                    }
                }
            } else {
                return false;
            }
        } catch (JWTException e) {
            log.error(e.getMessage(), e);
            throw new SynapseException(e.getMessage(), e);
        }

        if (synLog.isTraceOrDebugEnabled()) {
            synLog.traceOrDebug("End : JWTDecoder mediator");
        }
       
        return true;
    }

    private boolean verifySignature(String jwt_assertion, MessageContext synapseContext) {
        boolean isVerified = false;
        String[] split_string = jwt_assertion.split("\\.");
        String base64EncodedHeader = split_string[0];
        String base64EncodedBody = split_string[1];
        String base64EncodedSignature = split_string[2];

        String decodedHeader = new String(Base64Utils.decode(base64EncodedHeader));
        byte[] decodedSignature = Base64Utils.decode(base64EncodedSignature);
        Pattern pattern = Pattern.compile("^[^:]*:[^:]*:[^:]*:\"(.+)\"}$");
        Matcher matcher = pattern.matcher(decodedHeader);
        String base64EncodedCertThumb = null;
        if (matcher.find()) {
            base64EncodedCertThumb = matcher.group(1);
        }
        byte[] decodedCertThumb = Base64Utils.decode(base64EncodedCertThumb);

        Certificate publicCert = null;


        publicCert = getSuperTenantPublicKey(decodedCertThumb, synapseContext);
        try {
            if (publicCert != null) {
                isVerified = verifySignature(publicCert, decodedSignature, base64EncodedHeader, base64EncodedBody,
                        base64EncodedSignature);
            } else if (!isVerified) {
                publicCert = getTenantPublicKey(decodedCertThumb, synapseContext);
                if (publicCert != null) {
                    isVerified = verifySignature(publicCert, decodedSignature, base64EncodedHeader, base64EncodedBody,
                            base64EncodedSignature);
                } else {
                    throw new Exception("Couldn't find a public certificate to verify signature");
                }

            }

        } catch (Exception e) {
            handleSigVerificationException(e, synapseContext);
        }
        return isVerified;
    }

    private Certificate getSuperTenantPublicKey(byte[] decodedCertThumb, MessageContext synapseContext){
        String alias = getAliasForX509CertThumb(keyStore, decodedCertThumb, synapseContext);
        if (alias != null) {
            // get the certificate associated with the given alias from
            // default keystore
            try {
                return keyStore.getCertificate(alias);
            } catch (KeyStoreException e) {
                if (log.isErrorEnabled()) {
                    log.error("Error when getting server public certificate: " , e);
                }
            }
        }
        return null;
    }

    private Certificate getTenantPublicKey(byte[] decodedCertThumb, MessageContext synapseContext){
        SynapseLog synLog = getLog(synapseContext);

        int tenantId = CarbonContext.getThreadLocalCarbonContext().getTenantId();
        String tenantDomain = CarbonContext.getThreadLocalCarbonContext().getTenantDomain();
       
        if (synLog.isTraceOrDebugEnabled()) {
            synLog.traceOrDebug("Tenant Domain: " + tenantDomain);
        }

        KeyStore tenantKeyStore = null;
        KeyStoreManager tenantKSM = KeyStoreManager.getInstance(tenantId);
        String ksName = tenantDomain.trim().replace(".", "-");
        String jksName = ksName + ".jks";
        try {
            tenantKeyStore = tenantKSM.getKeyStore(jksName);
        } catch (Exception e) {
            if (log.isErrorEnabled()) {
                log.error("Error getting keystore for " + tenantDomain, e);
            }
        }
        if (tenantKeyStore != null) {
            String alias = getAliasForX509CertThumb(tenantKeyStore, decodedCertThumb, synapseContext);
            if (alias != null) {
                // get the certificate associated with the given alias
                // from
                // tenant's keystore
                try {
                    return tenantKeyStore.getCertificate(alias);
                } catch (KeyStoreException e) {
                    if (log.isErrorEnabled()) {
                        log.error("Error when getting tenants public certificate: " + tenantDomain, e);
                    }
                }
            }
        }

        return null;
    }
   
    private boolean verifySignature(Certificate publicCert, byte[] decodedSignature, String base64EncodedHeader,
            String base64EncodedBody, String base64EncodedSignature) throws NoSuchAlgorithmException,
            InvalidKeyException, SignatureException {
        // create signature instance with signature algorithm and public cert,
        // to verify the signature.
        Signature verifySig = Signature.getInstance("SHA256withRSA");
        // init
        verifySig.initVerify(publicCert);
        // update signature with signature data.
        verifySig.update((base64EncodedHeader + "." + base64EncodedBody).getBytes());
        // do the verification
        return verifySig.verify(decodedSignature);
    }

    private String getAliasForX509CertThumb(KeyStore keyStore, byte[] thumb, MessageContext synapseContext) {
        SynapseLog synLog = getLog(synapseContext);
        Certificate cert = null;
        MessageDigest sha = null;

        try {
            sha = MessageDigest.getInstance("SHA-1");
        } catch (NoSuchAlgorithmException e) {
            handleSigVerificationException(e, synapseContext);
        }
        try {
            for (Enumeration e = keyStore.aliases(); e.hasMoreElements();) {
                String alias = e.nextElement();
                Certificate[] certs = keyStore.getCertificateChain(alias);
                if (certs == null || certs.length == 0) {
                    // no cert chain, so lets check if getCertificate gives us a result.
                    cert = keyStore.getCertificate(alias);
                    if (cert == null) {
                        return null;
                    }
                } else {
                    cert = certs[0];
                }
                if (!(cert instanceof X509Certificate)) {
                    continue;
                }
                sha.reset();
                try {
                    sha.update(cert.getEncoded());
                } catch (CertificateEncodingException e1) {
                    //throw new Exception("Error encoding certificate");
                }
                byte[] data = sha.digest();
                if (new String(thumb).equals(hexify(data))) {
                    if (synLog.isTraceOrDebugEnabled()) {
                        synLog.traceOrDebug("Found matching alias: " + alias);
                    }
                    return alias;
                }
            }
        } catch (KeyStoreException e) {
            if (log.isErrorEnabled()) {
                log.error("Error getting alias from keystore", e);
            }
        }
        return null;
    }

    private String hexify(byte bytes[]) {
        char[] hexDigits = {'0', '1', '2', '3', '4', '5', '6', '7',
                            '8', '9', 'a', 'b', 'c', 'd', 'e', 'f'};

        StringBuffer buf = new StringBuffer(bytes.length * 2);

        for (int i = 0; i < bytes.length; ++i) {
            buf.append(hexDigits[(bytes[i] & 0xf0) >> 4]);
            buf.append(hexDigits[bytes[i] & 0x0f]);
        }

        return buf.toString();
    }

    private void handleSigVerificationException(Exception e, MessageContext synapseContext) {
        synapseContext.setTo(null);
        synapseContext.setResponse(true);
        org.apache.axis2.context.MessageContext axis2MessageContext = ((Axis2MessageContext) synapseContext).getAxis2MessageContext();
        axis2MessageContext.setProperty("HTTP_SC", "401");
        String err = e.getMessage();
        handleException(err, synapseContext);
    }

    public void destroy() {
        if (log.isInfoEnabled()) {
            log.info("Destroying JWTDecoder Mediator");
        }
    }
}





3 comments:

  1. Hi Sanjeewa,

    Nice post, thanks. Is this going to be one of the out-of-the-box mediators in future releases of ESB?

    ReplyDelete
  2. Hi Sanjeewa,
    I've managed to compile this, but not sure how to make it available in WSO2 API manager 2.0 so it would convert claims into properties.
    Thanks,
    Brian.

    ReplyDelete
    Replies
    1. You can drop this into the product pack, as we do for any custom mediator, and then use it in the mediation flow.

      Delete

Empowering the Future of API Management: Unveiling the Journey of WSO2 API Platform for Kubernetes (APK) Project and the Anticipated Alpha Release

  Introduction In the ever-evolving realm of API management, our journey embarked on the APK project eight months ago, and now, with great a...