/*
 * Decompiled with CFR 0.152.
 */
package org.keycloak.sdjwt;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiFunction;
import org.keycloak.common.VerificationException;
import org.keycloak.common.util.Time;

/**
 * Verifies the header and body claims of a token against configurable lists of checks.
 *
 * <p>Every check is a {@link Predicate}. A check signals failure either by throwing a
 * {@link VerificationException} or by returning {@code false}; both cause verification to fail.
 * Instances are normally assembled with {@link #builder()}.
 */
public class ClaimVerifier {
    private final List<Predicate<ObjectNode>> headerVerifiers;
    private final List<Predicate<ObjectNode>> contentVerifiers;

    /**
     * Creates a verifier from explicit check lists.
     *
     * @param headerVerifiers checks applied to the token header
     * @param contentVerifiers checks applied to the token body
     */
    public ClaimVerifier(List<Predicate<ObjectNode>> headerVerifiers, List<Predicate<ObjectNode>> contentVerifiers) {
        this.headerVerifiers = headerVerifiers;
        this.contentVerifiers = contentVerifiers;
    }

    /**
     * Runs all header checks, then all body checks.
     *
     * @param header the token header
     * @param body the token body
     * @throws VerificationException if any check fails
     */
    public void verifyClaims(ObjectNode header, ObjectNode body) throws VerificationException {
        this.verifyHeaderClaims(header);
        this.verifyBodyClaims(body);
    }

    /**
     * Runs every header check against {@code header}.
     *
     * @param header the token header
     * @throws VerificationException if a check throws or returns {@code false}
     */
    public void verifyHeaderClaims(ObjectNode header) throws VerificationException {
        ClaimVerifier.runVerifiers(this.headerVerifiers, header, "header");
    }

    /**
     * Runs every body (content) check against {@code body}.
     *
     * @param body the token body
     * @throws VerificationException if a check throws or returns {@code false}
     */
    public void verifyBodyClaims(ObjectNode body) throws VerificationException {
        ClaimVerifier.runVerifiers(this.contentVerifiers, body, "body");
    }

    /** Applies each verifier in order; a {@code false} result is treated as a failure, not ignored. */
    private static void runVerifiers(List<Predicate<ObjectNode>> verifiers, ObjectNode node, String section) throws VerificationException {
        for (Predicate<ObjectNode> verifier : verifiers) {
            if (!verifier.test(node)) {
                throw new VerificationException(String.format("JWT %s check failed for check %s", section, verifier.getClass().getName()));
            }
        }
    }

    /** Returns the list of body checks used by this verifier. */
    public List<Predicate<ObjectNode>> getContentVerifiers() {
        return this.contentVerifiers;
    }

    /** Returns a new {@link Builder} preconfigured with the default checks. */
    public static Builder builder() {
        return new Builder();
    }

    /** Reads a claim as text via {@link JsonNode#asText()}; absent and JSON-null claims yield {@code null}. */
    private static String readTextClaim(ObjectNode node, String claimName) {
        return Optional.ofNullable(node.get(claimName))
                .filter(value -> !value.isNull())
                .map(JsonNode::asText)
                .orElse(null);
    }

    /** Reads a claim via {@link JsonNode#asLong()}; absent and JSON-null claims yield {@code null}. */
    private static Long readNumericClaim(ObjectNode node, String claimName) {
        return Optional.ofNullable(node.get(claimName))
                .filter(value -> !value.isNull())
                .map(JsonNode::asLong)
                .orElse(null);
    }

    /**
     * A single check on a JSON object.
     *
     * @param <T> type of the value being checked
     */
    @FunctionalInterface
    public interface Predicate<T> {
        /**
         * Checks {@code var1}.
         *
         * @return {@code true} if the check passes
         * @throws VerificationException if the check fails
         */
        boolean test(T var1) throws VerificationException;

        /** Returns the current time (seconds precision) as reported by Keycloak's {@code Time} utility. */
        default Instant getCurrentTimestamp() {
            return Instant.ofEpochSecond(Time.currentTime());
        }
    }

    /**
     * Fluent builder for {@link ClaimVerifier}.
     *
     * <p>A new builder rejects {@code alg=none} in the header. It also requires {@code iat}
     * (max age 300 s), {@code exp} and {@code nbf}, each evaluated with a 10 s clock skew unless
     * configured otherwise.
     */
    public static class Builder {
        private static final int DEFAULT_CLOCK_SKEW_SECONDS = 10;
        private static final int DEFAULT_ALLOWED_MAX_AGE_SECONDS = 300;

        protected Integer clockSkew = DEFAULT_CLOCK_SKEW_SECONDS;
        protected Integer allowedMaxAge = DEFAULT_ALLOWED_MAX_AGE_SECONDS;
        protected List<Predicate<ObjectNode>> headerVerifiers = new ArrayList<>();
        protected List<Predicate<ObjectNode>> contentVerifiers = new ArrayList<>();

        /** Creates a builder with the default clock skew. */
        public Builder() {
            this(DEFAULT_CLOCK_SKEW_SECONDS);
        }

        /**
         * Creates a builder with the given clock skew.
         *
         * @param clockSkew clock skew in seconds; {@code null} selects the default
         */
        public Builder(Integer clockSkew) {
            // Skew must be set first so the time checks below pick it up.
            this.withClockSkew(Optional.ofNullable(clockSkew).orElse(DEFAULT_CLOCK_SKEW_SECONDS));
            this.withIatCheck(this.allowedMaxAge, false);
            this.withExpCheck(false);
            this.withNbfCheck(false);
            // Mandatory header check: "alg" must be present and must not equal "none" (case-insensitive).
            this.headerVerifiers.add(new NegatedClaimCheck("alg", "none",
                    (expected, actual) -> expected != null && expected.equalsIgnoreCase(actual), false));
        }

        /**
         * Sets the clock skew and applies it to every time-based body check already registered.
         *
         * @param clockSkew clock skew in seconds; negative values are clamped to 0
         * @return this builder
         */
        public Builder withClockSkew(int clockSkew) {
            int effectiveSkew = Math.max(0, clockSkew);
            this.clockSkew = effectiveSkew;
            for (Predicate<ObjectNode> verifier : this.contentVerifiers) {
                if (verifier instanceof TimeCheck) {
                    // Propagate the clamped value, not the raw argument.
                    ((TimeCheck) verifier).setClockSkewSeconds(effectiveSkew);
                }
            }
            return this;
        }

        /** Requires {@code iat} with the given max age (seconds); {@code null} removes the check. */
        public Builder withIatCheck(Integer allowedMaxAge) {
            return this.withIatCheck(allowedMaxAge, false);
        }

        /** (Re)registers the {@code iat} check with the current max age. */
        public Builder withIatCheck(boolean isCheckOptional) {
            return this.withIatCheck(this.allowedMaxAge, isCheckOptional);
        }

        /**
         * Replaces any existing {@code iat} check.
         *
         * @param allowedMaxAge maximum token age in seconds; {@code null} removes the check (and stores 0)
         * @param isCheckOptional whether a missing {@code iat} claim is accepted
         * @return this builder
         */
        public Builder withIatCheck(Integer allowedMaxAge, boolean isCheckOptional) {
            this.allowedMaxAge = Optional.ofNullable(allowedMaxAge).orElse(0);
            this.removeContentChecks(IatLifetimeCheck.class, "iat");
            if (allowedMaxAge != null) {
                this.contentVerifiers.add(new IatLifetimeCheck(Optional.ofNullable(this.clockSkew).orElse(0), allowedMaxAge.intValue(), isCheckOptional));
            }
            return this;
        }

        /** Requires a valid {@code nbf} claim. */
        public Builder withNbfCheck() {
            return this.withNbfCheck(false);
        }

        /**
         * Replaces any existing {@code nbf} check. No check is added when the clock skew is {@code null}.
         *
         * @param isCheckOptional whether a missing {@code nbf} claim is accepted
         * @return this builder
         */
        public Builder withNbfCheck(boolean isCheckOptional) {
            this.removeContentChecks(NbfCheck.class, "nbf");
            if (this.clockSkew != null) {
                this.contentVerifiers.add(new NbfCheck(this.clockSkew, isCheckOptional));
            }
            return this;
        }

        /** Requires a valid {@code exp} claim. */
        public Builder withExpCheck() {
            return this.withExpCheck(false);
        }

        /**
         * Replaces any existing {@code exp} check. No check is added when the clock skew is {@code null}.
         *
         * @param isCheckOptional whether a missing {@code exp} claim is accepted
         * @return this builder
         */
        public Builder withExpCheck(boolean isCheckOptional) {
            this.removeContentChecks(ExpCheck.class, "exp");
            if (this.clockSkew != null) {
                this.contentVerifiers.add(new ExpCheck(this.clockSkew, isCheckOptional));
            }
            return this;
        }

        /**
         * Replaces any existing {@code aud} check.
         *
         * @param expectedAud the audience that must be present; {@code null} removes the check
         * @return this builder
         */
        public Builder withAudCheck(String expectedAud) {
            this.removeContentChecks(AudienceCheck.class, "aud");
            if (expectedAud != null) {
                this.contentVerifiers.add(new AudienceCheck(expectedAud));
            }
            return this;
        }

        /** Requires {@code claimName} to equal {@code expectedValue}; {@code null} removes the check. */
        public Builder withClaimCheck(String claimName, String expectedValue) {
            return this.withClaimCheck(claimName, expectedValue, false);
        }

        /**
         * Replaces any existing {@link ClaimCheck} for exactly {@code claimName} (case-sensitive).
         *
         * @param claimName the claim to check
         * @param expectedValue the required value; {@code null} removes the check
         * @param isOptionalCheck whether a missing claim is accepted
         * @return this builder
         */
        public Builder withClaimCheck(String claimName, String expectedValue, boolean isOptionalCheck) {
            this.contentVerifiers.removeIf(verifier -> verifier instanceof ClaimCheck
                    && Objects.equals(((ClaimCheck) verifier).getClaimName(), claimName));
            if (expectedValue != null) {
                this.contentVerifiers.add(new ClaimCheck(claimName, expectedValue, isOptionalCheck));
            }
            return this;
        }

        /**
         * Replaces all body checks with a mutable copy of {@code contentVerifiers}.
         * A {@code null} argument clears the body checks.
         *
         * @return this builder
         */
        public Builder withContentVerifiers(List<Predicate<ObjectNode>> contentVerifiers) {
            // Copy so later builder calls neither mutate the caller's list nor fail on an immutable one.
            this.contentVerifiers = Builder.mutableCopy(contentVerifiers);
            return this;
        }

        /**
         * Appends {@code contentVerifiers} to the body checks; {@code null} is ignored.
         *
         * @return this builder
         */
        public Builder addContentVerifiers(List<Predicate<ObjectNode>> contentVerifiers) {
            if (this.contentVerifiers == null) {
                this.contentVerifiers = new ArrayList<>();
            }
            if (contentVerifiers != null) {
                this.contentVerifiers.addAll(contentVerifiers);
            }
            return this;
        }

        /**
         * Builds a verifier from snapshots of the current check lists, so later changes to this
         * builder do not add or remove checks on verifiers already built.
         * Already-registered {@link TimeCheck} instances remain shared objects.
         */
        public ClaimVerifier build() {
            return new ClaimVerifier(Builder.mutableCopy(this.headerVerifiers), Builder.mutableCopy(this.contentVerifiers));
        }

        /** Removes checks of {@code checkType} and any {@link ClaimCheck} named {@code claimName} (case-insensitive, null-safe). */
        private void removeContentChecks(Class<?> checkType, String claimName) {
            this.contentVerifiers.removeIf(verifier -> checkType.isInstance(verifier)
                    || verifier instanceof ClaimCheck && claimName.equalsIgnoreCase(((ClaimCheck) verifier).getClaimName()));
        }

        /** Returns a new mutable list holding the elements of {@code source}, or an empty list if it is {@code null}. */
        private static List<Predicate<ObjectNode>> mutableCopy(List<Predicate<ObjectNode>> source) {
            List<Predicate<ObjectNode>> copy = new ArrayList<>();
            if (source != null) {
                copy.addAll(source);
            }
            return copy;
        }
    }

    /** Requires the {@code aud} claim (a string or an array of strings) to contain an expected audience. */
    public static class AudienceCheck
    implements Predicate<ObjectNode> {
        private final String expectedAudience;

        public AudienceCheck(String expectedAudience) {
            this.expectedAudience = expectedAudience;
        }

        @Override
        public boolean test(ObjectNode t) throws VerificationException {
            if (this.expectedAudience == null) {
                throw new VerificationException("Missing expected audience");
            }
            JsonNode audienceArray = t.get("aud");
            // A JSON null "aud" is treated the same as an absent one.
            if (audienceArray == null || audienceArray.isNull()) {
                throw new VerificationException("No audience in the token");
            }
            HashSet<String> audiences = new HashSet<>();
            if (audienceArray.isArray()) {
                for (JsonNode audienceNode : audienceArray) {
                    audiences.add(audienceNode.textValue());
                }
            } else {
                audiences.add(audienceArray.textValue());
            }
            if (audiences.contains(this.expectedAudience)) {
                return true;
            }
            throw new VerificationException(String.format("Expected audience '%s' not available in the token. Present values are '%s'", this.expectedAudience, audiences));
        }
    }

    /** Rejects tokens whose {@code exp} is more than {@code clockSkewSeconds} before the current time. */
    public static class ExpCheck
    extends TimeCheck
    implements Predicate<ObjectNode> {
        private final boolean isOptional;

        public ExpCheck(int clockSkewSeconds) {
            this(clockSkewSeconds, false);
        }

        public ExpCheck(int clockSkewSeconds, boolean isOptional) {
            super(Math.max(0, clockSkewSeconds));
            this.isOptional = isOptional;
        }

        @Override
        public boolean test(ObjectNode jsonWebToken) throws VerificationException {
            Long expiration = ClaimVerifier.readNumericClaim(jsonWebToken, "exp");
            if (expiration == null) {
                if (this.isOptional) {
                    return true;
                }
                throw new VerificationException("Missing required claim 'exp'");
            }
            long now = this.getCurrentTimestamp().getEpochSecond();
            if (expiration < now - (long) this.clockSkewSeconds) {
                throw new VerificationException(String.format("Token has expired by exp: now: '%s', exp: '%s'", now, expiration));
            }
            return true;
        }
    }

    /** Rejects tokens whose {@code nbf} is more than {@code clockSkewSeconds} after the current time. */
    public static class NbfCheck
    extends TimeCheck
    implements Predicate<ObjectNode> {
        private final boolean isOptional;

        public NbfCheck(int clockSkewSeconds) {
            this(clockSkewSeconds, false);
        }

        public NbfCheck(int clockSkewSeconds, boolean isOptional) {
            super(Math.max(0, clockSkewSeconds));
            this.isOptional = isOptional;
        }

        @Override
        public boolean test(ObjectNode jsonWebToken) throws VerificationException {
            Long notBefore = ClaimVerifier.readNumericClaim(jsonWebToken, "nbf");
            if (notBefore == null) {
                if (this.isOptional) {
                    return true;
                }
                throw new VerificationException("Missing required claim 'nbf'");
            }
            long now = this.getCurrentTimestamp().getEpochSecond();
            if (notBefore > now + (long) this.clockSkewSeconds) {
                throw new VerificationException(String.format("Token is not yet valid: now: '%s', nbf: '%s'", now, notBefore));
            }
            return true;
        }
    }

    /**
     * Rejects tokens whose {@code iat} lies in the future, or whose {@code iat + maxLifetime} has
     * passed, with both comparisons relaxed by {@code clockSkewSeconds}.
     */
    public static class IatLifetimeCheck
    extends TimeCheck
    implements Predicate<ObjectNode> {
        private final long maxLifetime;
        private final boolean isOptional;

        public IatLifetimeCheck(int clockSkewSeconds, long maxLifetime) {
            this(clockSkewSeconds, maxLifetime, false);
        }

        public IatLifetimeCheck(int clockSkewSeconds, long maxLifetime, boolean isOptional) {
            super(Math.max(0, clockSkewSeconds));
            this.maxLifetime = Math.max(0L, maxLifetime);
            this.isOptional = isOptional;
        }

        @Override
        public boolean test(ObjectNode jsonWebToken) throws VerificationException {
            Long iat = ClaimVerifier.readNumericClaim(jsonWebToken, "iat");
            if (iat == null) {
                if (this.isOptional) {
                    return true;
                }
                throw new VerificationException("Missing required claim 'iat'");
            }
            long now = this.getCurrentTimestamp().getEpochSecond();
            if (now + (long) this.clockSkewSeconds < iat) {
                throw new VerificationException(String.format("Token was issued in the future: now: '%s', iat: '%s'", now, iat));
            }
            long expiration = iat + this.maxLifetime;
            if (expiration < now - (long) this.clockSkewSeconds) {
                throw new VerificationException(String.format("Token has expired by iat: now: '%s', expired at: '%s', iat: '%s', maxLifetime: '%s'", now, expiration, iat, this.maxLifetime));
            }
            return true;
        }
    }

    /**
     * Inverse of {@link ClaimCheck}: passes when the claim value does NOT match the given value
     * under the comparator. A missing claim fails unless the check is optional.
     */
    public static class NegatedClaimCheck
    extends ClaimCheck {
        public NegatedClaimCheck(String claimName, String expectedClaimValue) {
            super(claimName, expectedClaimValue);
        }

        public NegatedClaimCheck(String claimName, String expectedClaimValue, boolean isOptional) {
            super(claimName, expectedClaimValue, isOptional);
        }

        public NegatedClaimCheck(String claimName, String expectedClaimValue, BiFunction<String, String, Boolean> stringComparator) {
            super(claimName, expectedClaimValue, stringComparator);
        }

        public NegatedClaimCheck(String claimName, String expectedClaimValue, BiFunction<String, String, Boolean> stringComparator, boolean isOptional) {
            super(claimName, expectedClaimValue, stringComparator, isOptional);
        }

        @Override
        public boolean test(ObjectNode t) throws VerificationException {
            String claimValue = ClaimVerifier.readTextClaim(t, this.getClaimName());
            if (claimValue == null) {
                if (this.isOptional()) {
                    return true;
                }
                throw new VerificationException(String.format("Missing claim '%s' in token", this.getClaimName()));
            }
            boolean isParentCheckSuccessful;
            try {
                isParentCheckSuccessful = super.test(t);
            } catch (VerificationException ve) {
                // The positive check rejected the value (or has no value configured), so the negation holds.
                return true;
            }
            if (isParentCheckSuccessful) {
                throw new VerificationException(String.format("Value '%s' is not allowed for claim '%s'!", claimValue, this.getClaimName()));
            }
            return true;
        }
    }

    /**
     * Requires a claim's text value to match an expected value under a comparator
     * ({@link Objects#equals} by default). A missing or JSON-null claim fails unless the check is optional.
     */
    public static class ClaimCheck
    implements Predicate<ObjectNode> {
        private final String claimName;
        private final String expectedClaimValue;
        private final BiFunction<String, String, Boolean> stringComparator;
        private final boolean isOptional;

        public ClaimCheck(String claimName, String expectedClaimValue) {
            this(claimName, expectedClaimValue, false);
        }

        public ClaimCheck(String claimName, String expectedClaimValue, boolean isOptional) {
            this(claimName, expectedClaimValue, ClaimCheck.getDefaultComparator(), isOptional);
        }

        public ClaimCheck(String claimName, String expectedClaimValue, BiFunction<String, String, Boolean> stringComparator) {
            this(claimName, expectedClaimValue, stringComparator, false);
        }

        /**
         * @param claimName name of the claim to read
         * @param expectedClaimValue value passed as the first comparator argument
         * @param stringComparator called as {@code (expected, actual)}; {@code null} selects {@link Objects#equals}
         * @param isOptional whether a missing claim is accepted
         */
        public ClaimCheck(String claimName, String expectedClaimValue, BiFunction<String, String, Boolean> stringComparator, boolean isOptional) {
            this.claimName = claimName;
            this.expectedClaimValue = expectedClaimValue;
            this.stringComparator = Optional.ofNullable(stringComparator).orElseGet(ClaimCheck::getDefaultComparator);
            this.isOptional = isOptional;
        }

        protected static BiFunction<String, String, Boolean> getDefaultComparator() {
            return Objects::equals;
        }

        @Override
        public boolean test(ObjectNode t) throws VerificationException {
            if (this.expectedClaimValue == null) {
                throw new VerificationException(String.format("Missing expected value for claim '%s'", this.claimName));
            }
            String claimValue = ClaimVerifier.readTextClaim(t, this.claimName);
            if (claimValue == null) {
                // Honour isOptional: previously an optional-but-absent claim still reached the comparator and failed.
                if (this.isOptional) {
                    return true;
                }
                throw new VerificationException(String.format("Missing claim '%s' in token", this.claimName));
            }
            boolean checkSuccessful = this.stringComparator.apply(this.expectedClaimValue, claimValue);
            if (!checkSuccessful) {
                String errorMessage = String.format("Expected value '%s' in token for claim '%s' does not match actual value '%s'", this.expectedClaimValue, this.claimName, claimValue);
                throw new VerificationException(errorMessage);
            }
            return true;
        }

        public String getClaimName() {
            return this.claimName;
        }

        public String getExpectedClaimValue() {
            return this.expectedClaimValue;
        }

        public boolean isOptional() {
            return this.isOptional;
        }
    }

    /** Base class for time-based checks that tolerate a non-negative clock skew (seconds). */
    public static abstract class TimeCheck {
        protected int clockSkewSeconds;

        public TimeCheck(int clockSkewSeconds) {
            this.clockSkewSeconds = Math.max(0, clockSkewSeconds);
        }

        public int getClockSkewSeconds() {
            return this.clockSkewSeconds;
        }

        /** Sets the clock skew; negative values are clamped to 0, matching the constructor. */
        public void setClockSkewSeconds(int clockSkewSeconds) {
            this.clockSkewSeconds = Math.max(0, clockSkewSeconds);
        }
    }
}

