/*
 * Decompiled with CFR 0.152.
 */
package net.shibboleth.idp.plugin.authn.oidc.rp.impl;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.openid.connect.sdk.AuthenticationSuccessResponse;
import com.nimbusds.openid.connect.sdk.claims.ClaimsSet;
import java.security.Principal;
import java.text.ParseException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.security.auth.Subject;
import net.minidev.json.JSONObject;
import net.shibboleth.idp.attribute.AttributeDecodingException;
import net.shibboleth.idp.attribute.IdPAttribute;
import net.shibboleth.idp.attribute.context.AttributeContext;
import net.shibboleth.idp.attribute.filter.AttributeFilter;
import net.shibboleth.idp.attribute.filter.AttributeFilterException;
import net.shibboleth.idp.attribute.filter.context.AttributeFilterContext;
import net.shibboleth.idp.attribute.transcoding.AttributeTranscoder;
import net.shibboleth.idp.attribute.transcoding.AttributeTranscoderRegistry;
import net.shibboleth.idp.attribute.transcoding.TranscoderSupport;
import net.shibboleth.idp.attribute.transcoding.TranscodingRule;
import net.shibboleth.idp.authn.AbstractValidationAction;
import net.shibboleth.idp.authn.AuthenticationResult;
import net.shibboleth.idp.authn.context.AuthenticationContext;
import net.shibboleth.idp.authn.principal.IdPAttributePrincipal;
import net.shibboleth.idp.authn.principal.ProxyAuthenticationPrincipal;
import net.shibboleth.idp.plugin.authn.oidc.rp.context.EndUserClaimsContext;
import net.shibboleth.idp.plugin.authn.oidc.rp.principal.OIDCSubjectIdentifierPrincipal;
import net.shibboleth.oidc.profile.config.OIDCAuthenticationRelyingPartyProfileConfiguration;
import net.shibboleth.profile.context.RelyingPartyContext;
import net.shibboleth.saml.profile.context.navigate.SAMLMetadataContextLookupFunction;
import net.shibboleth.shared.annotation.constraint.Live;
import net.shibboleth.shared.annotation.constraint.NonnullAfterInit;
import net.shibboleth.shared.annotation.constraint.NonnullBeforeExec;
import net.shibboleth.shared.annotation.constraint.NonnullElements;
import net.shibboleth.shared.annotation.constraint.NotEmpty;
import net.shibboleth.shared.collection.CollectionSupport;
import net.shibboleth.shared.logic.Constraint;
import net.shibboleth.shared.primitive.LoggerFactory;
import net.shibboleth.shared.service.ReloadableService;
import net.shibboleth.shared.service.ServiceException;
import net.shibboleth.shared.service.ServiceableComponent;
import org.opensaml.messaging.context.MessageContext;
import org.opensaml.messaging.context.navigate.ChildContextLookup;
import org.opensaml.messaging.context.navigate.RecursiveTypedParentContextLookup;
import org.opensaml.profile.action.ActionSupport;
import org.opensaml.profile.context.ProfileRequestContext;
import org.opensaml.profile.context.navigate.InboundMessageContextLookup;
import org.opensaml.saml.metadata.resolver.MetadataResolver;
import org.slf4j.Logger;

/**
 * Action that validates a proxied OpenID Connect authentication and builds an {@link AuthenticationResult}
 * from the claims carried by the inbound {@link AuthenticationSuccessResponse}.
 *
 * <p>Pre-execution requires all of the following:</p>
 * <ul>
 *   <li>an attempted flow;</li>
 *   <li>an inbound {@link AuthenticationSuccessResponse};</li>
 *   <li>an {@link OIDCAuthenticationRelyingPartyProfileConfiguration};</li>
 *   <li>an {@link EndUserClaimsContext} holding end-user claims;</li>
 *   <li>an unprocessed id_token that contains a subject.</li>
 * </ul>
 * <p>Otherwise one of the events "InvalidProfileContext", "InvalidMessageContext",
 * "InvalidRelyingPartyContext" or "InvalidProfileConfiguration" is signalled.</p>
 *
 * <p>On execution the end-user claims are decoded into {@link IdPAttribute}s (when a transcoder registry is
 * set) and filtered inbound (when a filter service is set). They may then be supplemented by a custom
 * extraction strategy. The resulting Subject receives:</p>
 * <ul>
 *   <li>the translated ACR/AMR principals;</li>
 *   <li>an {@link OIDCSubjectIdentifierPrincipal};</li>
 *   <li>a {@link ProxyAuthenticationPrincipal} naming the id_token issuer;</li>
 *   <li>the filtered attributes;</li>
 *   <li>any mapped private credentials.</li>
 * </ul>
 *
 * <p>Per-request state is held in instance fields, so a single instance must not be executed
 * concurrently.</p>
 */
public class ValidateOIDCAuthentication
extends AbstractValidationAction {

    /** Default metric name recorded by this action. */
    @Nonnull
    @NotEmpty
    private static final String DEFAULT_METRIC_NAME = "net.shibboleth.idp.authn.oidc.rp";

    /** Class logger. */
    @Nonnull
    private final Logger log = LoggerFactory.getLogger(ValidateOIDCAuthentication.class);

    /** Registry used to decode claims into attributes; when null, claims are not decoded. */
    @NonnullAfterInit
    private ReloadableService<AttributeTranscoderRegistry> transcoderRegistry;

    /** Optional inbound attribute filter service. */
    @Nullable
    private ReloadableService<AttributeFilter> attributeFilterService;

    /** Optional metadata resolver handed to the filter context. */
    @Nullable
    private MetadataResolver metadataResolver;

    /** Strategy locating the {@link RelyingPartyContext} holding the profile configuration. */
    @Nonnull
    private Function<ProfileRequestContext, RelyingPartyContext> relyingPartyContextLookupStrategy;

    /** Optional strategy producing private credentials to add to the Subject. */
    @Nullable
    private Function<ProfileRequestContext, Collection<Principal>> contextToPrivateCredentialsMappingStrategy;

    /** Profile configuration captured during pre-execution. */
    @NonnullBeforeExec
    private OIDCAuthenticationRelyingPartyProfileConfiguration profileConfiguration;

    /** End-user claims context captured during pre-execution. */
    @NonnullBeforeExec
    private EndUserClaimsContext endUserContext;

    /** Unprocessed id_token claims captured during pre-execution. */
    @NonnullBeforeExec
    private JWTClaimsSet unprocessedIdTokenClaims;

    /** Strategy locating the {@link EndUserClaimsContext} beneath the inbound message context. */
    @Nonnull
    private final Function<ProfileRequestContext, EndUserClaimsContext> endUserClaimsContextLookupStrategy;

    /** Attribute context populated during execution, if any attributes were produced. */
    @Nullable
    private AttributeContext attributeContext;

    /** ACR translation strategy obtained from the profile configuration. */
    @Nullable
    private Function<Collection<String>, Collection<Principal>> acrTranslator;

    /** AMR translation strategy obtained from the profile configuration. */
    @Nullable
    private Function<Collection<String>, Collection<Principal>> amrTranslator;

    /** Optional custom strategy producing additional attributes. */
    @Nullable
    private Function<ProfileRequestContext, Collection<IdPAttribute>> attributeExtractionStrategy;

    /** Profile request context of the current execution. */
    @Nullable
    private ProfileRequestContext prc;

    /** Constructor. */
    public ValidateOIDCAuthentication() {
        setMetricName(DEFAULT_METRIC_NAME);
        relyingPartyContextLookupStrategy = new ChildContextLookup<>(RelyingPartyContext.class);
        // The 'true' flag is ChildContextLookup's create-if-absent option, so the pre-execution
        // check that actually matters is the one on the context's end-user claims.
        endUserClaimsContextLookupStrategy = new ChildContextLookup<>(EndUserClaimsContext.class, true)
                .compose(new InboundMessageContextLookup());
    }

    /**
     * Set the strategy used to derive private credentials to attach to the Subject.
     *
     * @param strategy mapping strategy, or null for none
     */
    public void setContextToPrivateCredentialsMappingStrategy(
            @Nullable final Function<ProfileRequestContext, Collection<Principal>> strategy) {
        checkSetterPreconditions();
        contextToPrivateCredentialsMappingStrategy = strategy;
    }

    /**
     * Set the inbound attribute filter service.
     *
     * @param filterService filter service, or null for none
     */
    public void setAttributeFilter(@Nullable final ReloadableService<AttributeFilter> filterService) {
        checkSetterPreconditions();
        attributeFilterService = filterService;
    }

    /**
     * Set the attribute transcoder registry used to decode claims.
     *
     * @param registry transcoder registry service
     */
    public void setTranscoderRegistry(@Nonnull final ReloadableService<AttributeTranscoderRegistry> registry) {
        checkSetterPreconditions();
        transcoderRegistry = Constraint.isNotNull(registry, "AttributeTranscoderRegistry cannot be null");
    }

    /**
     * Set the metadata resolver supplied to the attribute filter context.
     *
     * @param resolver metadata resolver, or null
     */
    public void setMetadataResolver(@Nullable final MetadataResolver resolver) {
        checkSetterPreconditions();
        metadataResolver = resolver;
    }

    /**
     * Set the strategy used to locate the {@link RelyingPartyContext}.
     *
     * @param strategy lookup strategy
     */
    public void setRelyingPartyContextLookupStrategy(
            @Nonnull final Function<ProfileRequestContext, RelyingPartyContext> strategy) {
        checkSetterPreconditions();
        relyingPartyContextLookupStrategy =
                Constraint.isNotNull(strategy, "RelyingPartyContext lookup strategy cannot be null");
    }

    /**
     * Set a custom strategy producing additional attributes.
     *
     * @param strategy extraction strategy, or null for none
     */
    public void setAttributeExtractionStrategy(
            @Nullable final Function<ProfileRequestContext, Collection<IdPAttribute>> strategy) {
        checkSetterPreconditions();
        attributeExtractionStrategy = strategy;
    }

    /** {@inheritDoc} */
    @Override
    protected boolean doPreExecute(@Nonnull final ProfileRequestContext profileRequestContext,
            @Nonnull final AuthenticationContext authenticationContext) {
        if (!super.doPreExecute(profileRequestContext, authenticationContext)) {
            return false;
        }
        prc = profileRequestContext;
        // Never carry attributes over from a previous execution of this instance.
        attributeContext = null;

        if (authenticationContext.getAttemptedFlow() == null) {
            log.debug("{} No attempted flow within authentication context", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, "InvalidProfileContext");
            return false;
        }

        final MessageContext inboundMessageCtx = profileRequestContext.getInboundMessageContext();
        if (inboundMessageCtx == null) {
            log.error("{} No inbound message context", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, "InvalidMessageContext");
            return false;
        }
        if (inboundMessageCtx.getMessage() == null) {
            log.error("{} No inbound message", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, "InvalidMessageContext");
            return false;
        }
        if (!(inboundMessageCtx.getMessage() instanceof AuthenticationSuccessResponse)) {
            log.error("{} No inbound authentication success response", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, "InvalidMessageContext");
            return false;
        }

        final RelyingPartyContext rpContext = relyingPartyContextLookupStrategy.apply(profileRequestContext);
        if (rpContext == null) {
            log.error("{} Unable to locate RelyingPartyContext", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, "InvalidRelyingPartyContext");
            return false;
        }
        if (rpContext.getProfileConfig() == null) {
            log.error("{} Unable to locate profile configuration", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, "InvalidProfileConfiguration");
            return false;
        }
        if (!(rpContext.getProfileConfig() instanceof OIDCAuthenticationRelyingPartyProfileConfiguration oidcConfig)) {
            log.error("{} No OIDC RP SSO profile configuration", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, "InvalidProfileConfiguration");
            return false;
        }
        profileConfiguration = oidcConfig;

        endUserContext = endUserClaimsContextLookupStrategy.apply(profileRequestContext);
        if (endUserContext == null) {
            log.error("{} Unable to locate end-user claims context", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, "InvalidProfileContext");
            return false;
        }
        if (endUserContext.getEndUserClaims() == null) {
            log.error("{} End-user claims are null", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, "InvalidProfileContext");
            return false;
        }

        unprocessedIdTokenClaims = endUserContext.getUnprocessedIdTokenClaims();
        if (unprocessedIdTokenClaims == null) {
            log.error("{} id_token not found in response", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, "InvalidProfileContext");
            return false;
        }
        if (unprocessedIdTokenClaims.getSubject() == null) {
            log.error("{} id_token did not contain a subject (sub)", getLogPrefix());
            ActionSupport.buildEvent(profileRequestContext, "InvalidProfileContext");
            return false;
        }
        return true;
    }

    /**
     * Decodes and filters claims, and applies any custom attribute extraction. Builds the authentication
     * result, then optionally resets its authentication instant to the id_token's {@code auth_time}.
     *
     * @param profileRequestContext the profile request context
     * @param authenticationContext the authentication context
     */
    @Override
    protected void doExecute(@Nonnull final ProfileRequestContext profileRequestContext,
            @Nonnull final AuthenticationContext authenticationContext) {
        recordSuccess(profileRequestContext);
        log.debug("{} Validating OIDC proxy authentication", getLogPrefix());

        if (transcoderRegistry != null) {
            processAttributes(profileRequestContext);
        }

        final Function<ProfileRequestContext, Collection<IdPAttribute>> extractor = attributeExtractionStrategy;
        if (extractor != null) {
            log.debug("{} Applying custom extraction strategy function", getLogPrefix());
            final Collection<IdPAttribute> newAttributes = extractor.apply(profileRequestContext);
            if (newAttributes != null) {
                if (log.isDebugEnabled()) {
                    log.debug("{} Extracted attributes with custom strategy: {}", getLogPrefix(),
                            newAttributes.stream().map(IdPAttribute::getId).collect(Collectors.toUnmodifiableList()));
                }
                final AttributeContext existing = attributeContext;
                if (existing != null) {
                    // Custom attributes come first; values of same-named existing attributes are appended.
                    final Map<String, IdPAttribute> merged = toMapMergeDuplicates(newAttributes);
                    existing.setIdPAttributes(
                            withMapMergeDuplicates(merged, existing.getIdPAttributes().values()).values());
                } else {
                    final AttributeContext created = ensureAttributeContext(profileRequestContext);
                    attributeContext = created;
                    created.setIdPAttributes(toMapMergeDuplicates(newAttributes).values());
                }
            }
        }

        log.info("{} OIDC authentication succeeded for '{}'", getLogPrefix(), unprocessedIdTokenClaims.getSubject());
        // The translators must be in place before the result (and so the Subject) is built.
        acrTranslator = profileConfiguration.getAuthenticationContextClassReferenceTranslationStrategy(
                profileRequestContext);
        amrTranslator = profileConfiguration.getAuthenticationMethodsReferencesTranslationStrategy(
                profileRequestContext);
        buildAuthenticationResult(profileRequestContext, authenticationContext);

        final AuthenticationResult authnResult = authenticationContext.getAuthenticationResult();
        if (authnResult != null && profileConfiguration.isProxiedAuthnInstant(profileRequestContext)) {
            try {
                final Date authnTimeDate = unprocessedIdTokenClaims.getDateClaim("auth_time");
                if (authnTimeDate != null) {
                    log.debug("{} Resetting authentication time to proxied value: {}", getLogPrefix(), authnTimeDate);
                    final Instant authnTimeInstant = authnTimeDate.toInstant();
                    assert authnTimeInstant != null;
                    authnResult.setAuthenticationInstant(authnTimeInstant);
                } else {
                    log.debug("{} Unable to reset authentication time, auth_time not present in id_token",
                            getLogPrefix());
                }
            } catch (final ParseException e) {
                log.debug("{} Unable to reset authentication time, auth_time could not be parsed from id_token: {}",
                        getLogPrefix(), e.getMessage());
            }
        }
    }

    /**
     * Adds principals and credentials to the Subject. This is the superclass hook, expected to run when
     * {@code buildAuthenticationResult} is called from {@link #doExecute}. It adds:
     * <ul>
     *   <li>translated ACR and AMR principals (when translators are configured);</li>
     *   <li>the OIDC subject identifier;</li>
     *   <li>a proxy principal naming the id_token issuer;</li>
     *   <li>any filtered attributes;</li>
     *   <li>any mapped private credentials.</li>
     * </ul>
     *
     * @param subject the Subject to populate
     * @return the same Subject
     */
    @Override
    @Nonnull
    protected Subject populateSubject(@Nonnull final Subject subject) {
        final Function<Collection<String>, Collection<Principal>> acrFn = acrTranslator;
        if (acrFn != null) {
            final List<String> acrs;
            if (unprocessedIdTokenClaims.getClaim("acr") instanceof String acr) {
                acrs = List.of(acr);
            } else {
                acrs = CollectionSupport.emptyList();
            }
            final Collection<Principal> translated = acrFn.apply(acrs);
            if (translated != null && !translated.isEmpty()) {
                subject.getPrincipals().addAll(translated);
                if (log.isDebugEnabled()) {
                    log.debug("{} Added translated ACR Principals: {}", getLogPrefix(),
                            translated.stream().map(Principal::getName).toList());
                }
            }
        }

        final Function<Collection<String>, Collection<Principal>> amrFn = amrTranslator;
        if (amrFn != null) {
            List<String> amrs = CollectionSupport.emptyList();
            if (unprocessedIdTokenClaims.getClaim("amr") instanceof Collection) {
                try {
                    amrs = unprocessedIdTokenClaims.getStringListClaim("amr");
                } catch (final ParseException e) {
                    // Best effort: an unparseable amr claim translates as no methods.
                    log.debug("{} Unable to parse AMR claims", getLogPrefix(), e);
                }
            }
            final Collection<Principal> translated = amrFn.apply(amrs);
            if (translated != null && !translated.isEmpty()) {
                subject.getPrincipals().addAll(translated);
                if (log.isDebugEnabled()) {
                    log.debug("{} Added translated AMR Principals: {}", getLogPrefix(),
                            translated.stream().map(Principal::getName).toList());
                }
            }
        }

        final String localSubject = unprocessedIdTokenClaims.getSubject();
        assert localSubject != null;
        subject.getPrincipals().add(new OIDCSubjectIdentifierPrincipal(localSubject));
        subject.getPrincipals().add(buildProxyPrincipal());

        final AttributeContext ac = attributeContext;
        final Map<String, IdPAttribute> idpAttributes = ac != null ? ac.getIdPAttributes() : null;
        if (idpAttributes != null && !idpAttributes.isEmpty()) {
            log.debug("{} Adding filtered inbound attributes to Subject", getLogPrefix());
            subject.getPrincipals().addAll(idpAttributes.values().stream().map(IdPAttributePrincipal::new).toList());
        }

        final Function<ProfileRequestContext, Collection<Principal>> credentialMapper =
                contextToPrivateCredentialsMappingStrategy;
        if (credentialMapper != null) {
            final Collection<Principal> privateCredentials = credentialMapper.apply(prc);
            if (privateCredentials != null) {
                subject.getPrivateCredentials().addAll(privateCredentials);
                log.trace("{} Added '{}' private credential(s) from mapping strategy", getLogPrefix(),
                        privateCredentials.size());
            }
        }
        return subject;
    }

    /**
     * Build a proxy principal naming the id_token issuer as the proxied authority.
     *
     * @return the principal; it carries no authority when the id_token has no issuer
     */
    @Nonnull
    private ProxyAuthenticationPrincipal buildProxyPrincipal() {
        final ProxyAuthenticationPrincipal proxied = new ProxyAuthenticationPrincipal();
        final String issuer = unprocessedIdTokenClaims.getIssuer();
        if (issuer != null) {
            proxied.getAuthorities().add(issuer);
        } else {
            // Do not put a null entry into the authority collection.
            log.warn("{} id_token did not contain an issuer (iss), proxied authority not recorded", getLogPrefix());
        }
        return proxied;
    }

    /**
     * Get or create the {@link AttributeContext} beneath the {@link RelyingPartyContext} that is a direct
     * child of the profile request context. Either context is created if it does not exist.
     *
     * @param profileRequestContext the profile request context
     * @return the attribute context
     */
    @Nonnull
    private AttributeContext ensureAttributeContext(@Nonnull final ProfileRequestContext profileRequestContext) {
        return profileRequestContext.ensureSubcontext(RelyingPartyContext.class)
                .ensureSubcontext(AttributeContext.class);
    }

    /**
     * Decode each end-user claim into attributes via the transcoder registry. If any are produced, store
     * them as unfiltered attributes and run inbound filtering.
     *
     * @param profileRequestContext the profile request context
     */
    private void processAttributes(@Nonnull final ProfileRequestContext profileRequestContext) {
        log.debug("{} Decoding incoming OIDC claims", getLogPrefix());
        final ReloadableService<AttributeTranscoderRegistry> registryService = transcoderRegistry;
        assert registryService != null;
        final Multimap<String, IdPAttribute> mapped = HashMultimap.create();

        try (final ServiceableComponent<AttributeTranscoderRegistry> component =
                registryService.getServiceableComponent()) {
            final AttributeTranscoderRegistry registry = component.getComponent();
            final ClaimsSet endUserClaims = endUserContext.getEndUserClaims();
            assert endUserClaims != null;
            for (final Map.Entry<String, Object> claim : endUserClaims.toJSONObject().entrySet()) {
                // Each claim is decoded independently, wrapped as a single-entry JSON object.
                final JSONObject jsonClaim = new JSONObject();
                jsonClaim.put(claim.getKey(), claim.getValue());
                try {
                    decodeAttribute(registry, profileRequestContext, jsonClaim, mapped);
                } catch (final AttributeDecodingException e) {
                    log.error("{} Error decoding inbound claim", getLogPrefix(), e);
                }
            }
        } catch (final ServiceException e) {
            log.error("{} Attribute transcoder service unavailable", getLogPrefix(), e);
            return;
        }

        log.debug("{} Incoming OIDC Attributes mapped to attribute IDs: {}", getLogPrefix(), mapped.keySet());
        if (!mapped.isEmpty()) {
            final AttributeContext ac = ensureAttributeContext(profileRequestContext);
            attributeContext = ac;
            // Nothing is released until filtering has run.
            ac.setUnfilteredIdPAttributes(toMapMergeDuplicates(mapped.values()).values()).setIdPAttributes(null);
            filterAttributes(profileRequestContext);
        }
    }

    /**
     * Run the inbound attribute filter over the unfiltered attributes, storing the result as the
     * attribute context's released attributes. The temporary filter context is always detached again.
     *
     * @param profileRequestContext the profile request context
     */
    private void filterAttributes(@Nonnull final ProfileRequestContext profileRequestContext) {
        final ReloadableService<AttributeFilter> service = attributeFilterService;
        if (service == null) {
            log.warn("{} No AttributeFilter service provided", getLogPrefix());
            return;
        }

        final AttributeFilterContext filterContext =
                profileRequestContext.ensureSubcontext(AttributeFilterContext.class);
        try {
            populateFilterContext(profileRequestContext, filterContext);
            try (final ServiceableComponent<AttributeFilter> component = service.getServiceableComponent()) {
                component.getComponent().filterAttributes(filterContext);
            }
            final AttributeContext ac = attributeContext;
            assert ac != null;
            ac.setIdPAttributes(filterContext.getFilteredIdPAttributes().values());
        } catch (final AttributeFilterException e) {
            log.error("{} Error while filtering inbound attributes", getLogPrefix(), e);
        } catch (final ServiceException e) {
            log.error("{} Invalid AttributeFilter configuration", getLogPrefix(), e);
        } finally {
            // Detach even on failure so the filter context does not linger in the request tree.
            filterContext.removeFromParent();
        }
    }

    /**
     * Configure the filter context for inbound filtering of the unfiltered attributes.
     *
     * @param profileRequestContext the profile request context
     * @param filterContext the filter context to populate
     */
    private void populateFilterContext(@Nonnull final ProfileRequestContext profileRequestContext,
            @Nonnull final AttributeFilterContext filterContext) {
        final AttributeContext ac = attributeContext;
        assert ac != null;
        filterContext.setDirection(AttributeFilterContext.Direction.INBOUND)
                .setPrefilteredIdPAttributes(ac.getUnfilteredIdPAttributes().values())
                .setMetadataResolver(metadataResolver)
                .setRequesterMetadataContextLookupStrategy(null)
                .setIssuerMetadataContextLookupStrategy(new SAMLMetadataContextLookupFunction().compose(
                        new RecursiveTypedParentContextLookup<>(ProfileRequestContext.class)))
                .setProxiedRequesterContextLookupStrategy(null)
                .setAttributeIssuerID(Constraint.isNotNull(getResponderLookupStrategy(), "No responder Strategy")
                        .apply(profileRequestContext))
                .setAttributeRecipientID(Constraint.isNotNull(getRequesterLookupStrategy(), "No requester strategy")
                        .apply(profileRequestContext));
    }

    /**
     * Decode one claim with every matching transcoding rule, adding non-null results keyed by attribute ID.
     *
     * @param registry transcoder registry
     * @param profileRequestContext the profile request context
     * @param input single-claim JSON object to decode
     * @param results multimap receiving decoded attributes
     * @throws AttributeDecodingException if a transcoder fails
     */
    private void decodeAttribute(@Nonnull final AttributeTranscoderRegistry registry,
            @Nonnull final ProfileRequestContext profileRequestContext, @Nonnull final JSONObject input,
            @Nonnull @NonnullElements @Live final Multimap<String, IdPAttribute> results)
            throws AttributeDecodingException {
        final Collection<TranscodingRule> transcodingRules = registry.getTranscodingRules(input);
        if (transcodingRules.isEmpty()) {
            log.debug("{} No transcoding rule for Attribute '{}'", getLogPrefix(), input);
            return;
        }
        for (final TranscodingRule rule : transcodingRules) {
            assert rule != null;
            final AttributeTranscoder<JSONObject> transcoder = TranscoderSupport.getTranscoder(rule);
            final IdPAttribute decodedAttribute = transcoder.decode(profileRequestContext, input, rule);
            if (decodedAttribute != null) {
                results.put(decodedAttribute.getId(), decodedAttribute);
            }
        }
    }

    /**
     * Index attributes by ID, merging the values of attributes that share an ID.
     *
     * @param attributes attributes to index, may be null
     * @return a new mutable map
     */
    @Nonnull
    @Live
    private static Map<String, IdPAttribute> toMapMergeDuplicates(@Nullable final Collection<IdPAttribute> attributes) {
        return withMapMergeDuplicates(new HashMap<>(), attributes);
    }

    /**
     * Merge attributes into an existing map. A same-ID entry is replaced by a clone of it whose values are
     * the existing values followed by the new ones; the attribute objects themselves are never mutated.
     *
     * @param existingAttributes map to merge into (modified in place)
     * @param newAttributes attributes to merge, may be null
     * @return {@code existingAttributes}
     */
    @Nonnull
    @Live
    private static Map<String, IdPAttribute> withMapMergeDuplicates(
            @Nonnull @Live final Map<String, IdPAttribute> existingAttributes,
            @Nullable final Collection<IdPAttribute> newAttributes) {
        if (newAttributes == null) {
            return existingAttributes;
        }
        for (final IdPAttribute attribute : newAttributes) {
            final String id = attribute.getId();
            final IdPAttribute existing = existingAttributes.get(id);
            if (existing == null) {
                existingAttributes.put(id, attribute);
                continue;
            }
            final IdPAttribute merged;
            try {
                merged = existing.clone();
            } catch (final Exception e) {
                throw new IllegalStateException("Unable to clone attribute '" + id + "' for merging", e);
            }
            final var values = new ArrayList<>(merged.getValues());
            values.addAll(attribute.getValues());
            merged.setValues(values);
            existingAttributes.put(id, merged);
        }
        return existingAttributes;
    }
}

