/*
 * Decompiled with CFR 0.152.
 */
package net.shibboleth.oidc.security.impl;

import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.JWEEncrypter;
import com.nimbusds.jose.JWEHeader;
import com.nimbusds.jose.JWEObject;
import com.nimbusds.jose.Payload;
import com.nimbusds.jose.crypto.AESEncrypter;
import com.nimbusds.jose.crypto.DirectEncrypter;
import com.nimbusds.jose.crypto.ECDHEncrypter;
import com.nimbusds.jose.crypto.RSAEncrypter;
import com.nimbusds.jwt.EncryptedJWT;
import com.nimbusds.jwt.JWT;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPublicKey;
import java.util.function.BiConsumer;
import java.util.function.Function;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import net.shibboleth.oidc.security.CredentialConversionUtil;
import net.shibboleth.oidc.security.jose.EncryptionParameters;
import net.shibboleth.oidc.security.jose.context.SecurityParametersContext;
import net.shibboleth.shared.annotation.constraint.NonnullAfterInit;
import net.shibboleth.shared.annotation.constraint.NonnullBeforeExec;
import net.shibboleth.shared.annotation.constraint.NotEmpty;
import net.shibboleth.shared.component.ComponentInitializationException;
import net.shibboleth.shared.logic.Constraint;
import net.shibboleth.shared.primitive.LoggerFactory;
import net.shibboleth.shared.primitive.StringSupport;
import org.opensaml.messaging.context.MessageContext;
import org.opensaml.messaging.context.navigate.ChildContextLookup;
import org.opensaml.messaging.handler.AbstractMessageHandler;
import org.opensaml.messaging.handler.MessageHandlerException;
import org.opensaml.security.credential.Credential;
import org.slf4j.Logger;

/**
 * Message handler that encrypts a {@link Payload} into a JWE and passes the resulting
 * {@link EncryptedJWT} to a configured consumer.
 *
 * <p>The payload is obtained from the message context via a lookup strategy, and the encryption
 * parameters are read from a {@link SecurityParametersContext} located via another lookup strategy.
 * The handler supports RSA, ECDH-ES, AES key wrap, AES-GCM key wrap and direct ({@code dir})
 * encryption. The JWE header always carries content type {@code JWT} and the key identifier
 * resolved from the credential in use.</p>
 *
 * <p>NOTE(review): the encryption parameters resolved in {@link #doPreInvoke(MessageContext)} are kept
 * in an instance field until {@link #doInvoke(MessageContext)} runs, so an instance presumably must not
 * be shared between concurrent invocations - confirm against how the handler is wired.</p>
 */
public class EncryptJWTHandler
extends AbstractMessageHandler {

    /** Class logger. */
    @Nonnull
    private final Logger log = LoggerFactory.getLogger(EncryptJWTHandler.class);

    /** Strategy used to locate the {@link SecurityParametersContext}; defaults to a direct child lookup. */
    @Nonnull
    private Function<MessageContext, SecurityParametersContext> securityParametersLookupStrategy =
            new ChildContextLookup<>(SecurityParametersContext.class);

    /** Strategy used to obtain the plain-text payload to encrypt. Must be set before initialization. */
    @NonnullAfterInit
    private Function<MessageContext, Payload> payloadToEncryptLookupStrategy;

    /** Consumer that receives the encrypted JWT together with the message context. Must be set before init. */
    @NonnullAfterInit
    private BiConsumer<JWT, MessageContext> jwtUpdateConsumer;

    /** Encryption parameters resolved by {@link #doPreInvoke(MessageContext)} for the current invocation. */
    @NonnullBeforeExec
    private EncryptionParameters encryptionParameters;

    /** Human-readable name of the thing being encrypted; only used in log statements. */
    @Nonnull
    private String logName = "not-specified";

    /**
     * Set the name used in log statements to describe what is being encrypted.
     *
     * @param name the name to log, never null or empty
     */
    public void setLogName(@Nonnull @NotEmpty String name) {
        ifInitializedThrowUnmodifiabledComponentException();
        ifDestroyedThrowDestroyedComponentException();
        // The message previously referred to an unrelated "ForFriendlyName" property.
        logName = Constraint.isNotEmpty(name, "Log name can not be null or empty");
    }

    /**
     * Set the consumer that is handed the encrypted JWT and the message context after encryption.
     *
     * @param consumer the consumer, never null
     */
    public void setJwtUpdateConsumer(@Nonnull BiConsumer<JWT, MessageContext> consumer) {
        ifInitializedThrowUnmodifiabledComponentException();
        ifDestroyedThrowDestroyedComponentException();
        jwtUpdateConsumer = Constraint.isNotNull(consumer, "JWT Update Consumer can not be null");
    }

    /**
     * Set the strategy used to obtain the payload to encrypt from the message context.
     *
     * @param strategy the lookup strategy, never null
     */
    public void setPayloadToEncryptLookupStrategy(@Nonnull Function<MessageContext, Payload> strategy) {
        ifInitializedThrowUnmodifiabledComponentException();
        ifDestroyedThrowDestroyedComponentException();
        payloadToEncryptLookupStrategy =
                Constraint.isNotNull(strategy, "Payload To Encrypt Lookup Strategy can not be null");
    }

    /**
     * Set the strategy used to locate the {@link SecurityParametersContext} in the message context.
     *
     * @param strategy the lookup strategy, never null
     */
    public void setSecurityParametersLookupStrategy(
            @Nonnull Function<MessageContext, SecurityParametersContext> strategy) {
        ifInitializedThrowUnmodifiabledComponentException();
        // Guard added for consistency with the other setters of this class.
        ifDestroyedThrowDestroyedComponentException();
        securityParametersLookupStrategy =
                Constraint.isNotNull(strategy, "SecurityParameterContext lookup strategy cannot be null");
    }

    /**
     * Verify that the payload lookup strategy and the JWT consumer have been configured.
     *
     * @throws ComponentInitializationException if either of them is missing
     */
    protected void doInitialize() throws ComponentInitializationException {
        if (payloadToEncryptLookupStrategy == null) {
            throw new ComponentInitializationException("Payload To Encrypt Lookup Strategy can not be null");
        }
        if (jwtUpdateConsumer == null) {
            throw new ComponentInitializationException("JWT Update Consumer can not be null");
        }
        super.doInitialize();
    }

    /**
     * Resolve and validate the encryption parameters for this invocation.
     *
     * @param messageContext the message context being processed
     * @return {@code false} if the superclass declines, or if no security parameters context or no
     *         encryption parameters are available (encryption is skipped); {@code true} otherwise
     * @throws MessageHandlerException if an algorithm is missing, if no credential is present, or if
     *         both a key transport and a data encryption credential are present
     */
    protected boolean doPreInvoke(@Nonnull MessageContext messageContext) throws MessageHandlerException {
        if (!super.doPreInvoke(messageContext)) {
            return false;
        }

        final SecurityParametersContext secParamCtx = securityParametersLookupStrategy.apply(messageContext);
        if (secParamCtx == null) {
            log.trace("{} Message context did not contain an encryption parameters context, encryption skipped",
                    getLogPrefix());
            return false;
        }

        final EncryptionParameters params = secParamCtx.getEncryptionParameters();
        encryptionParameters = params;
        if (params == null) {
            log.debug("{} Message context did not contain encryption parameters, '{}' will not be encrypted",
                    getLogPrefix(), logName);
            return false;
        }

        final boolean hasKeyTransportCredential = params.getKeyTransportEncryptionCredential() != null;
        final boolean hasDataEncryptionCredential = params.getDataEncryptionCredential() != null;

        if (StringSupport.trimOrNull(params.getKeyTransportEncryptionAlgorithm()) == null
                || StringSupport.trimOrNull(params.getDataEncryptionAlgorithm()) == null
                || (!hasKeyTransportCredential && !hasDataEncryptionCredential)) {
            throw new MessageHandlerException("Message context did not contain all required encryption parameters");
        }
        // Exactly one credential is expected: a key transport key or a (direct) content encryption key.
        if (hasKeyTransportCredential && hasDataEncryptionCredential) {
            throw new MessageHandlerException(
                    "Message context contained both a content encryption and key transport credential. Only one required.");
        }
        return true;
    }

    /**
     * Encrypt the payload and pass the resulting encrypted JWT to the configured consumer.
     *
     * <p>If the payload lookup strategy returns {@code null} nothing is encrypted.</p>
     *
     * @param messageContext the message context being processed
     * @throws MessageHandlerException with message {@code "Unsupported algorithm ..."} if the algorithm
     *         and credential combination is not supported, or with message {@code "Encryption failed"}
     *         (and the underlying cause) if anything else goes wrong
     */
    protected void doInvoke(@Nonnull MessageContext messageContext) throws MessageHandlerException {
        final Payload payload = payloadToEncryptLookupStrategy.apply(messageContext);
        if (payload == null) {
            log.trace("{} No plain text source payload provided to encrypt, encryption skipped", getLogPrefix());
            return;
        }

        final JWEAlgorithm encAlg = JWEAlgorithm.parse(encryptionParameters.getKeyTransportEncryptionAlgorithm());
        final EncryptionMethod encEnc = EncryptionMethod.parse(encryptionParameters.getDataEncryptionAlgorithm());
        final Credential keyTransportCredential = encryptionParameters.getKeyTransportEncryptionCredential();
        final Credential dataEncryptionCredential = encryptionParameters.getDataEncryptionCredential();
        final String keyTransportKid =
                keyTransportCredential == null ? null : CredentialConversionUtil.resolveKid(keyTransportCredential);
        final String dataEncryptionKid =
                dataEncryptionCredential == null ? null : CredentialConversionUtil.resolveKid(dataEncryptionCredential);

        try {
            final JWEObject jweObject;
            if (JWEAlgorithm.Family.RSA.contains(encAlg) && keyTransportCredential != null
                    && keyTransportCredential.getPublicKey() != null) {
                jweObject = newJweObject(encAlg, encEnc, keyTransportKid, payload);
                jweObject.encrypt(new RSAEncrypter((RSAPublicKey) keyTransportCredential.getPublicKey()));
            } else if (JWEAlgorithm.Family.ECDH_ES.contains(encAlg) && keyTransportCredential != null
                    && keyTransportCredential.getPublicKey() != null) {
                jweObject = newJweObject(encAlg, encEnc, keyTransportKid, payload);
                jweObject.encrypt(new ECDHEncrypter((ECPublicKey) keyTransportCredential.getPublicKey()));
            } else if ((JWEAlgorithm.Family.AES_KW.contains(encAlg) || JWEAlgorithm.Family.AES_GCM_KW.contains(encAlg))
                    && keyTransportCredential != null && keyTransportCredential.getSecretKey() != null) {
                jweObject = newJweObject(encAlg, encEnc, keyTransportKid, payload);
                jweObject.encrypt(new AESEncrypter(keyTransportCredential.getSecretKey()));
            } else if (JWEAlgorithm.DIR.equals(encAlg) && dataEncryptionCredential != null
                    && dataEncryptionCredential.getSecretKey() != null) {
                jweObject = newJweObject(encAlg, encEnc, dataEncryptionKid, payload);
                jweObject.encrypt(new DirectEncrypter(dataEncryptionCredential.getSecretKey()));
            } else {
                // Report the kid of whichever credential was actually supplied.
                final String offendingKid = keyTransportCredential != null ? keyTransportKid : dataEncryptionKid;
                log.error("{} Unsupported algorithm '{}' or key '{}'", getLogPrefix(), encAlg.getName(), offendingKid);
                throw new MessageHandlerException("Unsupported algorithm " + encAlg.getName());
            }

            final EncryptedJWT encryptedJWT = EncryptedJWT.parse(jweObject.serialize());
            jwtUpdateConsumer.accept(encryptedJWT, messageContext);

            // The serialized token is only written out at trace level.
            if (log.isTraceEnabled()) {
                log.trace("{} Encrypted '{}' JWT: {}", getLogPrefix(), logName, encryptedJWT.serialize());
            } else {
                log.debug("{} Encrypted '{}' JWT", getLogPrefix(), logName);
            }
        } catch (final MessageHandlerException e) {
            // Already logged and already the right type: do not re-wrap it as "Encryption failed".
            throw e;
        } catch (final Exception e) {
            log.error("{} Encryption failed", getLogPrefix(), e);
            throw new MessageHandlerException("Encryption failed", e);
        }
    }

    /**
     * Build the not-yet-encrypted JWE object for the payload and log the parameters that will be used.
     *
     * @param alg the JWE (key management) algorithm
     * @param enc the content encryption method
     * @param keyID the key identifier to place in the header, may be null
     * @param payload the payload to encrypt
     * @return a JWE object whose header has content type {@code JWT} and the given key identifier
     */
    @Nonnull
    private JWEObject newJweObject(@Nonnull JWEAlgorithm alg, @Nonnull EncryptionMethod enc,
            @Nullable String keyID, @Nonnull Payload payload) {
        final JWEObject jweObject =
                new JWEObject(new JWEHeader.Builder(alg, enc).contentType("JWT").keyID(keyID).build(), payload);
        logEncryption(keyID, alg.getName(), enc.getName());
        return jweObject;
    }

    /**
     * Log, at debug level, the key identifier and algorithms about to be used for encryption.
     *
     * @param keyID the key identifier, may be null
     * @param alg the JWE algorithm name, may be null
     * @param enc the content encryption method name, may be null
     */
    private void logEncryption(@Nullable String keyID, @Nullable String alg, @Nullable String enc) {
        log.debug("{} Encrypting '{}' with kid '{}' and params alg: {} enc: {}",
                getLogPrefix(), logName, keyID, alg, enc);
    }
}

