/*
 * Decompiled with CFR 0.152.
 */
package net.shibboleth.idp.plugin.authn.oidc.rp.impl;

import jakarta.servlet.http.HttpServletRequest;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Collections;
import java.util.Objects;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import net.shibboleth.idp.plugin.authn.oidc.rp.context.OAuth2ClientContext;
import net.shibboleth.oidc.profile.messaging.context.OIDCPeerEntityContext;
import net.shibboleth.shared.annotation.constraint.NonnullAfterInit;
import net.shibboleth.shared.annotation.constraint.NotEmpty;
import net.shibboleth.shared.annotation.constraint.ThreadSafeAfterInit;
import net.shibboleth.shared.component.AbstractIdentifiableInitializableComponent;
import net.shibboleth.shared.component.ComponentInitializationException;
import net.shibboleth.shared.logic.Constraint;
import net.shibboleth.shared.primitive.LoggerFactory;
import net.shibboleth.shared.primitive.StringSupport;
import org.apache.hc.core5.net.URIBuilder;
import org.opensaml.messaging.context.navigate.ChildContextLookup;
import org.opensaml.profile.context.ProfileRequestContext;
import org.opensaml.profile.context.navigate.OutboundMessageContextLookup;
import org.slf4j.Logger;

@ThreadSafeAfterInit
public class DefaultRedirectUriCreationFunction
extends AbstractIdentifiableInitializableComponent
implements BiFunction<HttpServletRequest, ProfileRequestContext, URI> {
    @Nonnull
    private final Logger log = LoggerFactory.getLogger(DefaultRedirectUriCreationFunction.class);
    @Nonnull
    private Function<ProfileRequestContext, OAuth2ClientContext> oauth2ClientContextLookupStrategy = new ChildContextLookup(OAuth2ClientContext.class).compose(new ChildContextLookup(OIDCPeerEntityContext.class).compose((Function)new OutboundMessageContextLookup()));
    @NonnullAfterInit
    @NotEmpty
    private String callbackServletPath;
    @NonnullAfterInit
    private Set<String> allowedOrigins;

    protected void doInitialize() throws ComponentInitializationException {
        super.doInitialize();
        if (StringSupport.trimOrNull((String)this.callbackServletPath) == null) {
            throw new ComponentInitializationException("Callback servlet path can not be null");
        }
        if (this.allowedOrigins == null) {
            this.allowedOrigins = Collections.emptySet();
        }
    }

    public void setCallbackServletPath(@Nonnull @NotEmpty String path) {
        this.checkSetterPreconditions();
        this.callbackServletPath = Constraint.isNotEmpty((String)path, (String)"callbackServletPath can not be null");
    }

    public void setAllowedOrigins(@Nullable Set<String> origins) {
        this.checkSetterPreconditions();
        if (origins == null) {
            this.allowedOrigins = Collections.emptySet();
        }
        this.allowedOrigins = Collections.unmodifiableSet(origins);
    }

    public void setOAuth2ClientContextLookupStrategy(@Nonnull Function<ProfileRequestContext, OAuth2ClientContext> strgy) {
        this.checkSetterPreconditions();
        this.oauth2ClientContextLookupStrategy = (Function)Constraint.isNotNull(strgy, (String)"OAuth2 client context lookup strategy cannot be null");
    }

    @Override
    @Nullable
    public URI apply(@Nullable HttpServletRequest request, @Nullable ProfileRequestContext prc) {
        OAuth2ClientContext context = this.oauth2ClientContextLookupStrategy.apply(prc);
        if (context == null) {
            this.log.warn("Could not locate the OAuth2 Client Context, can not compute redirect_uri");
            return null;
        }
        if (request == null) {
            this.log.warn("HttpServletRequest was unavailable, can not compute redirect_uri");
            return null;
        }
        if (context.getRedirectUriOverride() != null) {
            return context.getRedirectUriOverride();
        }
        if (this.allowedOrigins.isEmpty()) {
            this.log.warn("Can not compute redirect_uri if allowed origins is empty");
            return null;
        }
        try {
            String scheme = request.getScheme();
            assert (scheme != null);
            String serverName = request.getServerName();
            assert (serverName != null);
            URI redirectUri = this.buildURIIgnoreDefaultPorts(scheme, serverName, request.getServerPort(), request.getContextPath() + request.getServletPath() + this.callbackServletPath);
            String origin = this.buildOrigin(redirectUri);
            if (!this.allowedOrigins.contains(origin)) {
                this.log.warn("The 'origin' of the computed redirect_uri ('{}') is not allowed. If permissible, add it to the allowed origins property.", (Object)origin);
                return null;
            }
            return redirectUri;
        }
        catch (URISyntaxException e) {
            this.log.warn("Unable to create redirect_uri for OIDC authentication request", (Throwable)e);
            return null;
        }
    }

    @Nonnull
    private String buildOrigin(@Nonnull URI uri) throws URISyntaxException {
        if (uri.getPort() == -1) {
            String uriAsString = new URI(String.format("%s://%s", uri.getScheme(), uri.getHost())).toString();
            assert (uriAsString != null);
            return uriAsString;
        }
        String uriAsString = new URI(String.format("%s://%s:%s", uri.getScheme(), uri.getHost(), uri.getPort())).toString();
        assert (uriAsString != null);
        return uriAsString;
    }

    @Nonnull
    private final URI buildURIIgnoreDefaultPorts(@Nonnull String scheme, @Nonnull String host, int port, @Nonnull String path) throws URISyntaxException {
        int usedPort = port;
        if ("http".equalsIgnoreCase(scheme)) {
            if (port == 80) {
                usedPort = -1;
            }
        } else if ("https".equalsIgnoreCase(scheme) && port == 443) {
            usedPort = -1;
        }
        URI builtUri = new URIBuilder().setScheme(scheme).setHost(host).setPort(usedPort).setPath(path).build();
        assert (builtUri != null);
        return builtUri;
    }
}

