package no.uib.cipr.matrix.distributed;

import java.util.Arrays;
import java.util.Iterator;
import no.uib.cipr.matrix.DenseLU;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.DenseVector;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.MatrixEntry;
import no.uib.cipr.matrix.Vector;
import no.uib.cipr.matrix.VectorEntry;
import no.uib.cipr.matrix.sparse.Preconditioner;

/* loaded from: input_file:lib/mtj-0.9.12.jar:no/uib/cipr/matrix/distributed/TwoLevelPreconditioner.class */
/**
 * Two-level preconditioner for distributed matrices: a coarse-space correction with one
 * unknown per process, followed by the block-diagonal preconditioner inherited from
 * {@link BlockDiagonalPreconditioner}.
 *
 * <p>The coarse operator {@code A0} is a {@code size x size} matrix. Entry {@code (i, j)} is
 * the sum of all entries of the system matrix whose row is owned by rank {@code i} and whose
 * column is owned by rank {@code j}. This assumes {@code getBlock()} holds the locally owned
 * diagonal block and {@code getOff()} uses global indices; TODO confirm. {@code A0} is
 * assembled and LU-factored on {@link #root} only.
 */
public class TwoLevelPreconditioner extends BlockDiagonalPreconditioner {

    /** Rank that assembles, factors and solves the coarse system. */
    private static final int root = 0;

    /** Distributed system matrix supplied at construction; used for residuals and A0. */
    private final DistMatrix A;

    /** Communicator of {@link #A}. */
    private final Communicator comm;

    /** Rank of this process in {@link #comm}. */
    private final int rank;

    /** Number of processes in {@link #comm}; also the dimension of the coarse system. */
    private final int size;

    /** Maps a global index to the rank that owns it (built from column ownerships). */
    private final int[] indexToRank;

    /** Work vector holding the current residual. This is the caller's vector, not a copy. */
    private final DistVector z;

    /** Work vector receiving the block-preconditioner output. */
    private final DistVector r;

    /** Coarse matrix; populated on {@link #root} only. */
    private final DenseMatrix A0;

    /** Coarse right-hand side, overwritten with the coarse solution on {@link #root}. */
    private final DenseVector b0;

    /** LU factorization of {@link #A0}; factored on {@link #root} only. */
    private final DenseLU lu;

    /** This rank's contribution to the coarse matrix, one slot per rank. */
    private final double[] Ai;

    /** Gathered {@link #Ai} arrays, indexed by source rank; {@code null} off root. */
    private final double[][] Ai0;

    /** {@code true} when built from a {@link DistRowMatrix}, {@code false} for a {@link DistColMatrix}. */
    private final boolean row;

    /** Per-rank coarse residual, then coarse correction (length-1 rows); {@code null} off root. */
    private final double[][] zi0;

    /**
     * Creates a two-level preconditioner for a row-distributed matrix.
     *
     * @param preconditioner preconditioner applied to each local diagonal block
     * @param distRowMatrix distributed system matrix
     * @param distVector vector with the distribution of the system. It is used directly as a
     *     work vector and will be overwritten.
     */
    public TwoLevelPreconditioner(Preconditioner preconditioner, DistRowMatrix distRowMatrix, DistVector distVector) {
        this(preconditioner, distRowMatrix, distVector, true,
                createIndexToRank(distRowMatrix.numColumns(), distRowMatrix.getColumnOwnerships()),
                distRowMatrix.getCommunicator());
    }

    /**
     * Creates a two-level preconditioner for a column-distributed matrix.
     *
     * @param preconditioner preconditioner applied to each local diagonal block
     * @param distColMatrix distributed system matrix
     * @param distVector vector with the distribution of the system. It is used directly as a
     *     work vector and will be overwritten.
     */
    public TwoLevelPreconditioner(Preconditioner preconditioner, DistColMatrix distColMatrix, DistVector distVector) {
        this(preconditioner, distColMatrix, distVector, false,
                // NOTE(review): column ownerships are also used to map *row* indices in
                // setMatrix; this assumes a square matrix with matching row/column partition.
                createIndexToRank(distColMatrix.numColumns(), distColMatrix.getColumnOwnerships()),
                distColMatrix.getCommunicator());
    }

    /** Shared initialization for both public constructors. */
    private TwoLevelPreconditioner(Preconditioner preconditioner, DistMatrix matrix, DistVector work,
            boolean rowDistributed, int[] ownerOfIndex, Communicator communicator) {
        super(preconditioner);
        this.A = matrix;
        this.z = work;
        this.r = work.copy();
        this.row = rowDistributed;
        this.indexToRank = ownerOfIndex;
        this.comm = communicator;
        this.rank = communicator.rank();
        this.size = communicator.size();
        this.A0 = new DenseMatrix(this.size, this.size);
        this.b0 = new DenseVector(this.size);
        this.lu = new DenseLU(this.size, this.size);
        this.Ai = new double[this.size];
        if (this.rank == root) {
            this.Ai0 = new double[this.size][this.size];
            this.zi0 = new double[this.size][1];
        } else {
            this.Ai0 = null;
            this.zi0 = null;
        }
    }

    /**
     * Builds a lookup table from global index to owning rank.
     *
     * @param n number of global indices
     * @param ownerships ownership offsets: rank {@code p} owns
     *     {@code [ownerships[p], ownerships[p + 1])}
     * @return array of length {@code n}; indices not covered by any range stay {@code 0}
     */
    private static int[] createIndexToRank(int n, int[] ownerships) {
        int[] owner = new int[n];
        for (int p = 0; p < ownerships.length - 1; p++) {
            for (int idx = ownerships[p]; idx < ownerships[p + 1]; idx++) {
                owner[idx] = p;
            }
        }
        return owner;
    }

    /**
     * Applies the preconditioner. {@code x} is read as the current iterate when forming
     * residuals and is updated in place.
     *
     * @throws IllegalArgumentException if either vector is not a {@link DistVector}
     */
    @Override // no.uib.cipr.matrix.distributed.BlockDiagonalPreconditioner, no.uib.cipr.matrix.sparse.Preconditioner
    public Vector apply(Vector b, Vector x) {
        if ((b instanceof DistVector) && (x instanceof DistVector)) {
            return apply(b, x, false);
        }
        throw new IllegalArgumentException("Vectors must be DistVectors");
    }

    /**
     * Applies the transposed preconditioner. {@code x} is read as the current iterate when
     * forming residuals and is updated in place.
     *
     * @throws IllegalArgumentException if either vector is not a {@link DistVector}
     */
    @Override // no.uib.cipr.matrix.distributed.BlockDiagonalPreconditioner, no.uib.cipr.matrix.sparse.Preconditioner
    public Vector transApply(Vector b, Vector x) {
        if ((b instanceof DistVector) && (x instanceof DistVector)) {
            return apply(b, x, true);
        }
        throw new IllegalArgumentException("Vectors must be DistVectors");
    }

    /** Coarse correction followed by a block-diagonal smoothing step on {@code op(A)}. */
    private Vector apply(Vector b, Vector x, boolean transpose) {
        calculateCoarseResidual(b, x, transpose);
        if (rank == root) {
            solveCoarseSystem(transpose);
        }
        updateWithCoarseCorrection(x);
        return applyBlockPreconditioner(b, x, transpose);
    }

    /** Overwrites {@link #z} with {@code b - op(A) x}, where op is identity or transpose. */
    private void computeResidual(Vector b, Vector x, boolean transpose) {
        if (transpose) {
            A.transMultAdd(-1.0d, x, z.set(b));
        } else {
            A.multAdd(-1.0d, x, z.set(b));
        }
    }

    /** Restricts the residual to one value per rank (its local sum) and gathers it on root. */
    private void calculateCoarseResidual(Vector b, Vector x, boolean transpose) {
        computeResidual(b, x, transpose);
        double localSum = 0.0d;
        for (VectorEntry entry : z.getLocal()) {
            localSum += entry.get();
        }
        comm.gather(new double[] {localSum}, zi0, root);
    }

    /** Solves the coarse system on root, replacing {@link #zi0} with the coarse correction. */
    private void solveCoarseSystem(boolean transpose) {
        for (int i = 0; i < size; i++) {
            b0.set(i, zi0[i][0]);
        }
        // deep = false: the matrix shares b0's storage, so the solution is read back from b0.
        DenseMatrix rhs = new DenseMatrix((Vector) b0, false);
        if (transpose) {
            lu.transSolve(rhs);
        } else {
            lu.solve(rhs);
        }
        double[] solution = b0.getData();
        for (int i = 0; i < size; i++) {
            zi0[i][0] = solution[i];
        }
    }

    /** Scatters the coarse correction and adds this rank's value to every local entry of x. */
    private void updateWithCoarseCorrection(Vector x) {
        double[] correction = new double[1];
        comm.scatter(zi0, correction, root);
        Vector local = ((DistVector) x).getLocal();
        // Index loop, not the entry iterator: a sparse local vector's iterator may skip
        // unstored entries, but every entry must receive the constant shift.
        for (int i = 0; i < local.size(); i++) {
            local.add(i, correction[0]);
        }
    }

    /** Applies the block-diagonal preconditioner to the new residual and adds it to x. */
    private Vector applyBlockPreconditioner(Vector b, Vector x, boolean transpose) {
        computeResidual(b, x, transpose);
        // r is seeded with b before the block solve, as in the original implementation.
        r.set(b);
        if (transpose) {
            super.transApply(z, r);
        } else {
            super.apply(z, r);
        }
        return x.add(r);
    }

    /**
     * Sets up the block preconditioner for {@code matrix}, then assembles and factors the
     * coarse matrix.
     *
     * <p>NOTE(review): the coarse matrix is built from the block and off-diagonal parts of the
     * matrix passed at construction ({@link #A}), not from {@code matrix}. Only the block
     * preconditioner uses the argument. Callers should pass the same matrix they constructed
     * with; TODO confirm whether that is intended.
     *
     * @throws IllegalArgumentException if {@code matrix} is not a {@link DistMatrix}
     */
    @Override // no.uib.cipr.matrix.distributed.BlockDiagonalPreconditioner, no.uib.cipr.matrix.sparse.Preconditioner
    public void setMatrix(Matrix matrix) {
        if (!(matrix instanceof DistMatrix)) {
            throw new IllegalArgumentException("A is not a DistRowMatrix or a DistColMatrix");
        }
        Matrix block = A.getBlock();
        Matrix off = A.getOff();
        super.setMatrix(matrix);

        Arrays.fill(Ai, 0.0d);
        for (MatrixEntry entry : block) {
            Ai[rank] += entry.get();
        }
        // Off-diagonal entries go to the rank owning the "other" index: the column for a
        // row-distributed matrix, the row for a column-distributed one.
        for (MatrixEntry entry : off) {
            int owner = indexToRank[row ? entry.column() : entry.row()];
            Ai[owner] += entry.get();
        }
        comm.gather(Ai, Ai0, root);

        if (rank == root) {
            // Ai0[p] is rank p's contribution: a row of A0 when row-distributed, a column otherwise.
            for (int i = 0; i < size; i++) {
                for (int j = 0; j < size; j++) {
                    A0.set(i, j, row ? Ai0[i][j] : Ai0[j][i]);
                }
            }
            lu.factor(A0);
        }
    }
}
