/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.scheduler;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
import org.apache.flink.runtime.executiongraph.EdgeManagerBuildUtil;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.IndexRange;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.shuffle.ShuffleMaster;
import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor;
import org.apache.flink.util.Preconditions;

/**
 * Utilities for computing the network memory needed by a {@link SlotSharingGroup} and writing it
 * into the group's {@link ResourceProfile}.
 *
 * <p>For every job vertex in the group, a {@link TaskInputsOutputsDescriptor} is built that records,
 * per intermediate data set, the maximum number of input channels / subpartitions a single subtask
 * may have. The {@link ShuffleMaster} turns each descriptor into a memory size, and the sizes are
 * summed.
 */
public class SsgNetworkMemoryCalculationUtils {

    /**
     * Sets the network memory of {@code ssg}'s resource profile to the sum of the shuffle memory
     * that {@code shuffleMaster} computes for each job vertex in the group.
     *
     * <p>The profile is left untouched if it is {@link ResourceProfile#UNKNOWN} or already declares
     * a non-zero network memory.
     *
     * @param ssg the slot sharing group whose resource profile may be enriched
     * @param ejvs lookup from a job vertex id to its execution job vertex
     * @param shuffleMaster computes the shuffle memory required by a single task
     */
    public static void enrichNetworkMemory(
            SlotSharingGroup ssg,
            Function<JobVertexID, ExecutionJobVertex> ejvs,
            ShuffleMaster<?> shuffleMaster) {
        ResourceProfile original = ssg.getResourceProfile();

        // Only fill in network memory for known profiles that have not set it explicitly.
        if (original.equals(ResourceProfile.UNKNOWN)
                || !original.getNetworkMemory().equals(MemorySize.ZERO)) {
            return;
        }

        MemorySize networkMemory = MemorySize.ZERO;
        for (JobVertexID jvId : ssg.getJobVertexIds()) {
            ExecutionJobVertex ejv = ejvs.apply(jvId);
            TaskInputsOutputsDescriptor desc = buildTaskInputsOutputsDescriptor(ejv, ejvs);
            MemorySize requiredNetworkMemory = shuffleMaster.computeShuffleMemorySizeForTask(desc);
            networkMemory = networkMemory.add(requiredNetworkMemory);
        }

        // Copy every other dimension of the original profile unchanged.
        ResourceProfile enriched =
                ResourceProfile.newBuilder()
                        .setCpuCores(original.getCpuCores())
                        .setTaskHeapMemory(original.getTaskHeapMemory())
                        .setTaskOffHeapMemory(original.getTaskOffHeapMemory())
                        .setManagedMemory(original.getManagedMemory())
                        .setNetworkMemory(networkMemory)
                        .setExtendedResources(original.getExtendedResources().values())
                        .build();
        ssg.setResourceProfile(enriched);
    }

    /**
     * Builds the input/output descriptor for a single subtask of {@code ejv}.
     *
     * <p>For a dynamic graph the channel/subpartition counts are read from the execution vertices
     * and result partitions of {@code ejv}. Otherwise they are derived from the job graph edges and
     * vertex parallelisms.
     */
    private static TaskInputsOutputsDescriptor buildTaskInputsOutputsDescriptor(
            ExecutionJobVertex ejv, Function<JobVertexID, ExecutionJobVertex> ejvs) {
        Map<IntermediateDataSetID, Integer> partitionReuseCount = getPartitionReuseCount(ejv);
        Map<IntermediateDataSetID, Integer> maxInputChannelNums = new HashMap<>();
        Map<IntermediateDataSetID, Integer> maxSubpartitionNums = new HashMap<>();
        Map<IntermediateDataSetID, ResultPartitionType> inputPartitionTypes = new HashMap<>();
        Map<IntermediateDataSetID, ResultPartitionType> partitionTypes = new HashMap<>();

        if (ejv.getGraph().isDynamic()) {
            getMaxInputChannelInfoForDynamicGraph(ejv, maxInputChannelNums, inputPartitionTypes);
            getMaxSubpartitionInfoForDynamicGraph(ejv, maxSubpartitionNums, partitionTypes);
        } else {
            getMaxInputChannelInfo(ejv, maxInputChannelNums, inputPartitionTypes);
            getMaxSubpartitionInfo(ejv, maxSubpartitionNums, partitionTypes, ejvs);
        }

        JobVertex jv = ejv.getJobVertex();
        return TaskInputsOutputsDescriptor.from(
                jv.getNumberOfInputs(),
                maxInputChannelNums,
                partitionReuseCount,
                maxSubpartitionNums,
                inputPartitionTypes,
                partitionTypes);
    }

    /** Counts how many times each intermediate data set appears among {@code ejv}'s inputs. */
    private static Map<IntermediateDataSetID, Integer> getPartitionReuseCount(
            ExecutionJobVertex ejv) {
        Map<IntermediateDataSetID, Integer> partitionReuseCount = new HashMap<>();
        for (IntermediateResult intermediateResult : ejv.getInputs()) {
            partitionReuseCount.merge(intermediateResult.getId(), 1, Integer::sum);
        }
        return partitionReuseCount;
    }

    /**
     * Non-dynamic graph: computes, per consumed data set, the maximum number of input channels of
     * one subtask of {@code ejv} from the edge's distribution pattern and the parallelisms, and
     * records the result partition type.
     *
     * <p>If the same data set is consumed via several edges, the largest count is kept. This mirrors
     * {@link #getMaxInputChannelInfoForDynamicGraph}.
     */
    private static void getMaxInputChannelInfo(
            ExecutionJobVertex ejv,
            Map<IntermediateDataSetID, Integer> maxInputChannelNums,
            Map<IntermediateDataSetID, ResultPartitionType> inputPartitionTypes) {
        List<JobEdge> inputEdges = ejv.getJobVertex().getInputs();
        List<IntermediateResult> consumedResults = ejv.getInputs();

        for (int i = 0; i < inputEdges.size(); ++i) {
            JobEdge inputEdge = inputEdges.get(i);
            IntermediateResult consumedResult = consumedResults.get(i);

            // Input edges and consumed results are expected to be index-aligned.
            Preconditions.checkState(
                    consumedResult.getId().equals(inputEdge.getSourceId()),
                    "Consumed result %s does not match the source %s of input edge %s.",
                    consumedResult.getId(),
                    inputEdge.getSourceId(),
                    i);

            int maxNum =
                    EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
                            ejv.getParallelism(),
                            consumedResult.getNumberOfAssignedPartitions(),
                            inputEdge.getDistributionPattern());
            maxInputChannelNums.merge(consumedResult.getId(), maxNum, Integer::max);
            inputPartitionTypes.putIfAbsent(consumedResult.getId(), consumedResult.getResultType());
        }
    }

    /**
     * Non-dynamic graph: computes, per produced data set, the maximum number of subpartitions of
     * one subtask of {@code ejv} and records the result partition type. A data set without
     * consumers gets 0.
     *
     * @throws IllegalStateException if {@code ejv} belongs to a dynamic graph
     */
    private static void getMaxSubpartitionInfo(
            ExecutionJobVertex ejv,
            Map<IntermediateDataSetID, Integer> maxSubpartitionNums,
            Map<IntermediateDataSetID, ResultPartitionType> partitionTypes,
            Function<JobVertexID, ExecutionJobVertex> ejvs) {
        List<IntermediateDataSet> producedDataSets = ejv.getJobVertex().getProducedDataSets();
        Preconditions.checkState(!ejv.getGraph().isDynamic(), "Only support non-dynamic graph.");

        for (IntermediateDataSet producedDataSet : producedDataSets) {
            int maxNum = 0;
            List<JobEdge> outputEdges = producedDataSet.getConsumers();

            if (!outputEdges.isEmpty()) {
                // Only the first consumer is inspected. This assumes all consumers of a data set
                // share parallelism and distribution pattern; NOTE(review): confirm upstream.
                JobEdge outputEdge = outputEdges.get(0);
                ExecutionJobVertex consumerJobVertex =
                        ejvs.apply(outputEdge.getTarget().getID());
                maxNum =
                        EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
                                ejv.getParallelism(),
                                consumerJobVertex.getParallelism(),
                                outputEdge.getDistributionPattern());
            }

            maxSubpartitionNums.put(producedDataSet.getId(), maxNum);
            partitionTypes.putIfAbsent(producedDataSet.getId(), producedDataSet.getResultType());
        }
    }

    /**
     * Dynamic graph: for each consumed data set, records the maximum over all subtasks of
     * (consumed subpartition range size x number of partitions in the consumed partition group),
     * along with the result partition type.
     */
    @VisibleForTesting
    static void getMaxInputChannelInfoForDynamicGraph(
            ExecutionJobVertex ejv,
            Map<IntermediateDataSetID, Integer> maxInputChannelNums,
            Map<IntermediateDataSetID, ResultPartitionType> inputPartitionTypes) {
        for (ExecutionVertex vertex : ejv.getTaskVertices()) {
            for (ConsumedPartitionGroup partitionGroup : vertex.getAllConsumedPartitionGroups()) {
                IntermediateResultPartition resultPartition =
                        ejv.getGraph().getResultPartitionOrThrow(partitionGroup.getFirst());
                IndexRange subpartitionIndexRange =
                        vertex.getExecutionVertexInputInfo(
                                        resultPartition.getIntermediateResult().getId())
                                .getSubpartitionIndexRange();

                maxInputChannelNums.merge(
                        partitionGroup.getIntermediateDataSetID(),
                        subpartitionIndexRange.size() * partitionGroup.size(),
                        Integer::max);
                inputPartitionTypes.putIfAbsent(
                        partitionGroup.getIntermediateDataSetID(),
                        partitionGroup.getResultPartitionType());
            }
        }
    }

    /**
     * Dynamic graph: for each produced intermediate result, records the largest subpartition count
     * among its partitions (0 if it has none) and the result partition type.
     */
    private static void getMaxSubpartitionInfoForDynamicGraph(
            ExecutionJobVertex ejv,
            Map<IntermediateDataSetID, Integer> maxSubpartitionNums,
            Map<IntermediateDataSetID, ResultPartitionType> partitionTypes) {
        for (IntermediateResult intermediateResult : ejv.getProducedDataSets()) {
            int maxNum =
                    Arrays.stream(intermediateResult.getPartitions())
                            .mapToInt(IntermediateResultPartition::getNumberOfSubpartitions)
                            .reduce(0, Integer::max);
            maxSubpartitionNums.put(intermediateResult.getId(), maxNum);
            partitionTypes.putIfAbsent(intermediateResult.getId(), intermediateResult.getResultType());
        }
    }

    /** Utility class; not instantiable. */
    private SsgNetworkMemoryCalculationUtils() {}
}

