/*
 * Decompiled with CFR 0.152.
 */
package de.uni_freiburg.informatik.ultimate.lib.tracecheckerutils.assertorders;

import de.uni_freiburg.informatik.ultimate.automata.nestedword.NestedWord;
import de.uni_freiburg.informatik.ultimate.core.model.services.ILogger;
import de.uni_freiburg.informatik.ultimate.lib.modelcheckerutils.cfg.structure.IAction;
import de.uni_freiburg.informatik.ultimate.lib.modelcheckerutils.smt.tracecheck.ITraceCheckPreferences;
import de.uni_freiburg.informatik.ultimate.lib.smtlibutils.solverbuilder.SMTFeatureExtractionTermClassifier;
import de.uni_freiburg.informatik.ultimate.lib.tracecheckerutils.Counterexample;
import de.uni_freiburg.informatik.ultimate.lib.tracecheckerutils.assertorders.IAssertOrder;
import de.uni_freiburg.informatik.ultimate.logic.Term;
import de.uni_freiburg.informatik.ultimate.util.datastructures.relation.Triple;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * An {@link IAssertOrder} that partitions the statements of a counterexample trace by scores that
 * {@link SMTFeatureExtractionTermClassifier} computes for each statement's transition formula.
 *
 * <p>
 * Every statement is scored individually and the statements are sorted by descending score. The
 * sorted statement indices are then split according to the configured
 * {@link ITraceCheckPreferences.SmtFeatureHeuristicPartitioningType}:
 * <ul>
 * <li>{@code FIXED_NUM_PARTITIONS}: consecutive chunks of (at most) {@code ceil(size / k)} indices,
 * where {@code k} is the configured number of partitions;</li>
 * <li>{@code THRESHOLD}: one partition of indices whose score is at least the configured threshold,
 * followed by one partition with all remaining indices (empty partitions are omitted).</li>
 * </ul>
 *
 * @param <L>
 *            the type of the letters (actions) of the trace
 */
public class AssertOrderSmtFeatureHeuristic<L extends IAction> implements IAssertOrder<L> {
    private final SMTFeatureExtractionTermClassifier.ScoringMethod mScoringMethod;
    private final int mNumberOfPartitions;
    private final double mHeuristicThreshold;
    private final ITraceCheckPreferences.SmtFeatureHeuristicPartitioningType mPartitioningType;
    private final ILogger mLogger;

    /**
     * Creates a new SMT-feature based assert order.
     *
     * @param scoringMethod
     *            the scoring method passed to {@link SMTFeatureExtractionTermClassifier#getScore}
     * @param numberOfPartitions
     *            target number of partitions in {@code FIXED_NUM_PARTITIONS} mode; fewer partitions
     *            are produced when the trace is too short, and values below 1 yield a single partition
     * @param heuristicThreshold
     *            in {@code THRESHOLD} mode, statements whose score is {@code >=} this value form the
     *            first partition
     * @param partitioningType
     *            which partitioning strategy to apply to the sorted scores
     * @param logger
     *            logger used for debug output
     */
    public AssertOrderSmtFeatureHeuristic(final SMTFeatureExtractionTermClassifier.ScoringMethod scoringMethod,
            final int numberOfPartitions, final double heuristicThreshold,
            final ITraceCheckPreferences.SmtFeatureHeuristicPartitioningType partitioningType,
            final ILogger logger) {
        mScoringMethod = scoringMethod;
        mNumberOfPartitions = numberOfPartitions;
        mHeuristicThreshold = heuristicThreshold;
        mPartitioningType = partitioningType;
        mLogger = logger;
    }

    /**
     * Scores the transition formula of every statement of the trace.
     *
     * @param trace
     *            the trace whose statements are scored
     * @return triples {@code (formula, score, statement index)}, sorted by descending score; the sort
     *         is stable, so statements with equal score keep their trace order
     */
    private List<Triple<Term, Double, Integer>> scoreTrace(final NestedWord<? extends IAction> trace) {
        final List<Triple<Term, Double, Integer>> termScores = new ArrayList<>(trace.length());
        for (int index = 0; index < trace.length(); index++) {
            // A fresh classifier per statement, so each score is computed for this formula alone
            // (NOTE(review): assumes the classifier accumulates state across checkTerm calls).
            final SMTFeatureExtractionTermClassifier classifier = new SMTFeatureExtractionTermClassifier();
            final Term term = trace.getSymbol(index).getTransformula().getFormula();
            classifier.checkTerm(term);
            final Double score = classifier.getScore(mScoringMethod);
            termScores.add(new Triple<>(term, score, index));
        }
        // Descending by score; negating keeps Double.compare semantics (NaN scores sort last).
        termScores.sort(Comparator.comparingDouble(triple -> -triple.getSecond()));
        return termScores;
    }

    /**
     * Splits the score-sorted statement indices into consecutive partitions of equal size
     * {@code ceil(size / mNumberOfPartitions)}; only the last partition may be smaller.
     *
     * @param termScores
     *            triples as returned by {@link #scoreTrace}
     * @return the partitions in descending-score order; empty if {@code termScores} is empty
     */
    private List<Set<Integer>> partitionFixedNumberOfPartitions(final List<Triple<Term, Double, Integer>> termScores) {
        // Keep the descending-score order while dropping any duplicate index.
        final Set<Integer> indices =
                termScores.stream().map(Triple::getThird).collect(Collectors.toCollection(LinkedHashSet::new));
        final int numberOfPartitions = Math.max(1, mNumberOfPartitions);
        // At least 1 so that the "partition is full" check below can trigger.
        final int partitionSize = Math.max(1, (int) Math.ceil(indices.size() / (double) numberOfPartitions));

        final List<Set<Integer>> partitions = new ArrayList<>();
        Set<Integer> current = new LinkedHashSet<>();
        for (final Integer index : indices) {
            current.add(index);
            if (current.size() == partitionSize) {
                partitions.add(current);
                current = new LinkedHashSet<>();
            }
        }
        // Flush the trailing, partially filled partition; otherwise those statements would never be
        // part of any partition.
        if (!current.isEmpty()) {
            partitions.add(current);
        }
        return partitions;
    }

    /**
     * Splits the statement indices into those whose score reaches {@code mHeuristicThreshold} and
     * all others.
     *
     * @param termScores
     *            triples as returned by {@link #scoreTrace}
     * @return at most two partitions: first the indices with score {@code >= mHeuristicThreshold},
     *         then the remaining ones (including NaN scores); empty partitions are omitted
     */
    private List<Set<Integer>> partitionUsingThreshold(final List<Triple<Term, Double, Integer>> termScores) {
        final Set<Integer> highScoring = new LinkedHashSet<>();
        final Set<Integer> lowScoring = new LinkedHashSet<>();
        for (final Triple<Term, Double, Integer> termScore : termScores) {
            final double score = termScore.getSecond();
            final Integer index = termScore.getThird();
            if (score >= mHeuristicThreshold) {
                highScoring.add(index);
            } else {
                lowScoring.add(index);
            }
        }
        final List<Set<Integer>> partitions = new ArrayList<>();
        if (!highScoring.isEmpty()) {
            partitions.add(highScoring);
        }
        if (!lowScoring.isEmpty()) {
            partitions.add(lowScoring);
        }
        return partitions;
    }

    /**
     * Dispatches to the partitioning strategy selected by {@code mPartitioningType}.
     *
     * @param termScores
     *            triples as returned by {@link #scoreTrace}
     * @return the partitions produced by the selected strategy
     * @throws UnsupportedOperationException
     *             if the partitioning type is not handled here
     */
    private List<Set<Integer>> partitionStmtsAccordingToTermScores(final List<Triple<Term, Double, Integer>> termScores) {
        return switch (mPartitioningType) {
        case FIXED_NUM_PARTITIONS -> partitionFixedNumberOfPartitions(termScores);
        case THRESHOLD -> partitionUsingThreshold(termScores);
        default -> throw new UnsupportedOperationException("Unknown partitioning type: " + mPartitioningType);
        };
    }

    /**
     * Scores the statements of the counterexample and partitions their indices according to the
     * configured strategy.
     *
     * @param counterexample
     *            the counterexample whose trace is partitioned
     * @return the partitions of statement indices, in descending-score order
     */
    @Override
    public List<Set<Integer>> partition(final Counterexample<L> counterexample) {
        final List<Triple<Term, Double, Integer>> termScores = scoreTrace(counterexample.getWord());
        final List<Set<Integer>> partitions = partitionStmtsAccordingToTermScores(termScores);
        // An empty trace yields no partitions and trips this assertion when assertions are enabled.
        assert !partitions.isEmpty();
        if (mLogger.isDebugEnabled()) {
            mLogger.debug("Trace: " + counterexample.getWord().toString());
            mLogger.debug("TermScoreTriples: " + termScores.toString());
            mLogger.debug("Partitions: " + partitions.toString());
        }
        return partitions;
    }
}

