/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.translate;

import ai.djl.Model;
import ai.djl.translate.BasicTranslator;
import ai.djl.translate.PostProcessor;
import ai.djl.translate.PreProcessor;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.translate.TranslatorOptions;
import ai.djl.util.Pair;
import java.lang.reflect.Type;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

/**
 * A {@link TranslatorFactory} that builds one base {@link Translator} and derives translators for
 * additional input/output types from it.
 *
 * <p>A request for the exact base (input, output) pair returns the base translator unchanged.
 * Otherwise a whole-translator expansion registered in {@link #getExpansions()} for that exact
 * pair is used if present. Failing that, a pre-processor (from {@link
 * #getPreprocessorExpansions()}, keyed by input type) and a post-processor (from {@link
 * #getPostprocessorExpansions()}, keyed by output type) are resolved independently and combined,
 * so any supported input type can be paired with any supported output type.
 *
 * @param <IbaseT> the input type of the base translator
 * @param <ObaseT> the output type of the base translator
 */
public abstract class ExpansionTranslatorFactory<IbaseT, ObaseT> implements TranslatorFactory {

    /**
     * Returns every (input, output) type pair this factory can build.
     *
     * <p>The result is the keys of {@link #getExpansions()} plus the cross product of all
     * pre-processor input types (always including {@link #getBaseInputType()}) with all
     * post-processor output types (always including {@link #getBaseOutputType()}). A new set is
     * built on every call.
     *
     * @return the supported (input, output) type pairs
     */
    @Override
    public Set<Pair<Type, Type>> getSupportedTypes() {
        Set<Pair<Type, Type>> results = new HashSet<>(getExpansions().keySet());

        Set<Type> preProcessorTypes = new HashSet<>(getPreprocessorExpansions().keySet());
        preProcessorTypes.add(getBaseInputType());

        Set<Type> postProcessorTypes = new HashSet<>(getPostprocessorExpansions().keySet());
        postProcessorTypes.add(getBaseOutputType());

        for (Type i : preProcessorTypes) {
            for (Type o : postProcessorTypes) {
                results.add(new Pair<>(i, o));
            }
        }
        return results;
    }

    /**
     * Builds the base translator for {@code model} and expands it to the requested types.
     *
     * @param input the requested input class
     * @param output the requested output class
     * @param model the model passed to {@link #buildBaseTranslator(Model, Map)}
     * @param arguments the arguments passed to {@link #buildBaseTranslator(Model, Map)}
     * @param <I> the requested input type
     * @param <O> the requested output type
     * @return a translator from {@code I} to {@code O}
     * @throws IllegalArgumentException if the (input, output) pair cannot be built
     */
    @Override
    public <I, O> Translator<I, O> newInstance(
            Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) {
        Translator<IbaseT, ObaseT> baseTranslator = buildBaseTranslator(model, arguments);
        return newInstance(input, output, baseTranslator);
    }

    /**
     * Expands an already-built base translator to the requested input/output types.
     *
     * @param input the requested input class
     * @param output the requested output class
     * @param translator the base translator to expand
     * @param <I> the requested input type
     * @param <O> the requested output type
     * @return a translator from {@code I} to {@code O}
     * @throws IllegalArgumentException if neither a whole-translator expansion nor a matching
     *     pre-processor/post-processor combination exists for the pair
     */
    @SuppressWarnings("unchecked")
    <I, O> Translator<I, O> newInstance(
            Class<I> input, Class<O> output, Translator<IbaseT, ObaseT> translator) {
        if (input.equals(getBaseInputType()) && output.equals(getBaseOutputType())) {
            // Here I is IbaseT and O is ObaseT, so this cast is safe.
            return (Translator<I, O>) translator;
        }

        TranslatorExpansion<IbaseT, ObaseT> expansion =
                getExpansions().get(new Pair<Type, Type>(input, output));
        if (expansion != null) {
            // The expansion map's value type cannot express the per-key result type, so the
            // cast is unchecked; it relies on the expansion registered under (input, output)
            // producing a Translator<I, O>.
            return (Translator<I, O>) expansion.apply(translator);
        }

        // Whole-translator expansions take precedence over pre/post-processor composition.
        PreProcessor<I> preProcessor = resolvePreProcessor(input, translator);
        PostProcessor<O> postProcessor = resolvePostProcessor(output, translator);
        if (preProcessor != null && postProcessor != null) {
            return new BasicTranslator<>(preProcessor, postProcessor, translator.getBatchifier());
        }
        throw new IllegalArgumentException(
                "Unsupported expansion input/output types: "
                        + input.getName()
                        + " -> "
                        + output.getName());
    }

    /**
     * Finds a pre-processor accepting {@code input}, or {@code null} if none is registered.
     *
     * @param input the requested input class
     * @param translator the base translator, used directly or wrapped by an expander
     * @param <I> the requested input type
     * @return the pre-processor, or {@code null} if {@code input} is unsupported
     */
    @SuppressWarnings("unchecked")
    private <I> PreProcessor<I> resolvePreProcessor(
            Class<I> input, Translator<IbaseT, ObaseT> translator) {
        if (input.equals(getBaseInputType())) {
            // I is IbaseT here and the translator is itself a PreProcessor<IbaseT>.
            return (PreProcessor<I>) translator;
        }
        Function<PreProcessor<IbaseT>, PreProcessor<?>> expander =
                getPreprocessorExpansions().get(input);
        // The expander was registered under this input type; its result is assumed to accept I.
        return expander == null ? null : (PreProcessor<I>) expander.apply(translator);
    }

    /**
     * Finds a post-processor producing {@code output}, or {@code null} if none is registered.
     *
     * @param output the requested output class
     * @param translator the base translator, used directly or wrapped by an expander
     * @param <O> the requested output type
     * @return the post-processor, or {@code null} if {@code output} is unsupported
     */
    @SuppressWarnings("unchecked")
    private <O> PostProcessor<O> resolvePostProcessor(
            Class<O> output, Translator<IbaseT, ObaseT> translator) {
        if (output.equals(getBaseOutputType())) {
            // O is ObaseT here and the translator is itself a PostProcessor<ObaseT>.
            return (PostProcessor<O>) translator;
        }
        Function<PostProcessor<ObaseT>, PostProcessor<?>> expander =
                getPostprocessorExpansions().get(output);
        // The expander was registered under this output type; its result is assumed to produce O.
        return expander == null ? null : (PostProcessor<O>) expander.apply(translator);
    }

    /**
     * Returns {@link TranslatorOptions} that expand the given translator instead of building one
     * from a model.
     *
     * @param translator the base translator to expand
     * @return options exposing {@link #getSupportedTypes()} and building from {@code translator}
     */
    public ExpandedTranslatorOptions withTranslator(Translator<IbaseT, ObaseT> translator) {
        return new ExpandedTranslatorOptions(translator);
    }

    /**
     * Builds the base translator that all expansions wrap.
     *
     * @param model the model the translator is built for
     * @param arguments factory arguments
     * @return the base translator
     */
    protected abstract Translator<IbaseT, ObaseT> buildBaseTranslator(
            Model model, Map<String, ?> arguments);

    /**
     * Returns the input class of the base translator.
     *
     * @return the base input class
     */
    public abstract Class<IbaseT> getBaseInputType();

    /**
     * Returns the output class of the base translator.
     *
     * @return the base output class
     */
    public abstract Class<ObaseT> getBaseOutputType();

    /**
     * Returns whole-translator expansions keyed by exact (input, output) pair. Empty by default.
     *
     * @return the whole-translator expansions
     */
    protected Map<Pair<Type, Type>, TranslatorExpansion<IbaseT, ObaseT>> getExpansions() {
        return Collections.emptyMap();
    }

    /**
     * Returns pre-processor expansions keyed by input type. By default only the base input type
     * is registered, mapped to the identity function.
     *
     * @return the pre-processor expansions
     */
    protected Map<Type, Function<PreProcessor<IbaseT>, PreProcessor<?>>>
            getPreprocessorExpansions() {
        return Collections.singletonMap(getBaseInputType(), p -> p);
    }

    /**
     * Returns post-processor expansions keyed by output type. By default only the base output
     * type is registered, mapped to the identity function.
     *
     * @return the post-processor expansions
     */
    protected Map<Type, Function<PostProcessor<ObaseT>, PostProcessor<?>>>
            getPostprocessorExpansions() {
        return Collections.singletonMap(getBaseOutputType(), p -> p);
    }

    /**
     * Turns a base translator into a translator for other input/output types.
     *
     * @param <IbaseT> the input type of the base translator
     * @param <ObaseT> the output type of the base translator
     */
    @FunctionalInterface
    public interface TranslatorExpansion<IbaseT, ObaseT>
            extends Function<Translator<IbaseT, ObaseT>, Translator<?, ?>> {}

    /**
     * {@link TranslatorOptions} backed by a fixed, already-built base translator.
     *
     * <p>Deliberately an inner (non-static) class: it delegates to the enclosing factory's
     * {@link #getSupportedTypes()} and expansion logic.
     */
    final class ExpandedTranslatorOptions implements TranslatorOptions {

        private final Translator<IbaseT, ObaseT> translator;

        private ExpandedTranslatorOptions(Translator<IbaseT, ObaseT> translator) {
            this.translator = translator;
        }

        /** Returns the enclosing factory's supported (input, output) type pairs. */
        @Override
        public Set<Pair<Type, Type>> getOptions() {
            return ExpansionTranslatorFactory.this.getSupportedTypes();
        }

        /**
         * Expands the wrapped translator to the requested types.
         *
         * @throws IllegalArgumentException if the (input, output) pair is unsupported
         */
        @Override
        public <I, O> Translator<I, O> option(Class<I> input, Class<O> output) {
            return ExpansionTranslatorFactory.this.newInstance(input, output, translator);
        }
    }
}

