package no.uib.cipr.matrix.distributed;

import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Exchanger;
import no.uib.cipr.matrix.distributed.Communicator;

/* loaded from: input_file:lib/mtj-0.9.12.jar:no/uib/cipr/matrix/distributed/CollectiveCommunications.class */
/**
 * Shared-memory collective operations for a fixed group of {@code size}
 * participants, each identified by a rank in {@code [0, size)}.
 *
 * <p>Every collective is backed by its own {@link CyclicBarrier} sized to the
 * whole group, so a call blocks until all {@code size} ranks have made the
 * matching call. The actual data movement is done by the barrier action
 * ({@code run()} of the corresponding inner class). Per the
 * {@code CyclicBarrier} contract, that action is executed once, by the last
 * thread to arrive, before any waiting thread is released.
 *
 * <p>Each rank is presumably driven by its own thread holding a
 * {@link Communicator} from {@link #createCommunicator(int)} — TODO confirm
 * against {@code Communicator}.
 *
 * <p>NOTE(review): the decompiler reports that the collective methods below
 * were originally package-private; they are left {@code public} here to avoid
 * narrowing the binary interface.
 */
public class CollectiveCommunications {
    /** Number of participating ranks. */
    final int size;

    /**
     * Pairwise exchangers: {@code ex.get(r).get(s)} for {@code r != s} is the
     * single {@link Exchanger} shared by ranks {@code r} and {@code s}; the
     * diagonal entries are {@code null}.
     */
    private final List<List<Exchanger<Communicator.SendRecv>>> ex;

    /** Plain synchronization barrier used by {@link #barrier()}. */
    private final CyclicBarrier barrier;

    private final Broadcast broadcast;
    private final Gather gather;
    private final Scatter scatter;
    private final AllGather allGather;
    private final AllToAll allToAll;
    private final Reduce reduce;

    /**
     * All-gather: every rank contributes one array and receives a copy of every
     * rank's array. After the barrier, {@code recvbuf[r][s]} of each rank
     * {@code r} holds the first {@code length[s]} elements of rank {@code s}'s
     * send array.
     */
    private class AllGather implements Runnable {
        final CyclicBarrier barrier;
        private final Object[] sendbuf;
        private final Object[][] recvbuf;
        private final int[] length;

        private AllGather() {
            this.barrier = new CyclicBarrier(CollectiveCommunications.this.size, this);
            this.sendbuf = new Object[CollectiveCommunications.this.size];
            this.recvbuf = new Object[CollectiveCommunications.this.size][CollectiveCommunications.this.size];
            this.length = new int[CollectiveCommunications.this.size];
        }

        /**
         * Registers the buffers of one rank.
         *
         * @param send the array this rank contributes
         * @param recv one destination array per source rank
         * @param rank the calling rank
         */
        public void setSendRecv(Object send, Object[] recv, int rank) {
            this.sendbuf[rank] = send;
            this.recvbuf[rank] = recv;
            this.length[rank] = Array.getLength(send);
        }

        /** Copies each rank's send array into slot {@code [sender]} of every rank's receive set. */
        @Override // java.lang.Runnable
        public void run() {
            for (int sender = 0; sender < CollectiveCommunications.this.size; sender++) {
                for (int receiver = 0; receiver < CollectiveCommunications.this.size; receiver++) {
                    System.arraycopy(this.sendbuf[sender], 0, this.recvbuf[receiver][sender], 0, this.length[sender]);
                }
            }
        }
    }

    /**
     * All-to-all: rank {@code s}'s {@code send[r]} is copied into rank
     * {@code r}'s {@code recv[s]}, using the length of {@code send[r]}.
     */
    private class AllToAll implements Runnable {
        final CyclicBarrier barrier;
        private final Object[][] sendbuf;
        private final Object[][] recvbuf;
        private final int[][] length;

        private AllToAll() {
            this.barrier = new CyclicBarrier(CollectiveCommunications.this.size, this);
            this.sendbuf = new Object[CollectiveCommunications.this.size][CollectiveCommunications.this.size];
            this.recvbuf = new Object[CollectiveCommunications.this.size][CollectiveCommunications.this.size];
            this.length = new int[CollectiveCommunications.this.size][CollectiveCommunications.this.size];
        }

        /**
         * Registers the buffers of one rank.
         *
         * @param send one outgoing array per destination rank
         * @param recv one incoming array per source rank
         * @param rank the calling rank
         */
        public void setSendRecv(Object[] send, Object[] recv, int rank) {
            this.sendbuf[rank] = send;
            this.recvbuf[rank] = recv;
            for (int dest = 0; dest < CollectiveCommunications.this.size; dest++) {
                this.length[rank][dest] = Array.getLength(send[dest]);
            }
        }

        /** Routes {@code sendbuf[sender][receiver]} to {@code recvbuf[receiver][sender]}. */
        @Override // java.lang.Runnable
        public void run() {
            for (int sender = 0; sender < CollectiveCommunications.this.size; sender++) {
                for (int receiver = 0; receiver < CollectiveCommunications.this.size; receiver++) {
                    System.arraycopy(this.sendbuf[sender][receiver], 0, this.recvbuf[receiver][sender], 0,
                            this.length[sender][receiver]);
                }
            }
        }
    }

    /**
     * Broadcast: the root's array is copied into every rank's array, using the
     * length of the root's array.
     */
    private class Broadcast implements Runnable {
        final CyclicBarrier barrier;
        /** Rank whose buffer is the source; written by the root before the barrier. */
        int root;
        /** One array per rank. */
        final Object[] buffer;

        private Broadcast() {
            this.barrier = new CyclicBarrier(CollectiveCommunications.this.size, this);
            this.buffer = new Object[CollectiveCommunications.this.size];
        }

        @Override // java.lang.Runnable
        public void run() {
            int count = Array.getLength(this.buffer[this.root]);
            for (int rank = 0; rank < CollectiveCommunications.this.size; rank++) {
                System.arraycopy(this.buffer[this.root], 0, this.buffer[rank], 0, count);
            }
        }
    }

    /**
     * Gather: each rank's send array is copied into {@code recvbuf[rank]} of
     * the root, using the length of the send array.
     */
    private class Gather implements Runnable {
        CyclicBarrier barrier;
        /** The root's per-rank destination arrays; replaced by the root on each call. */
        Object[] recvbuf;
        Object[] sendbuf;
        private final int[] length;

        private Gather() {
            this.barrier = new CyclicBarrier(CollectiveCommunications.this.size, this);
            this.recvbuf = new Object[CollectiveCommunications.this.size];
            this.sendbuf = new Object[CollectiveCommunications.this.size];
            this.length = new int[CollectiveCommunications.this.size];
        }

        public void setSend(Object send, int rank) {
            this.sendbuf[rank] = send;
            this.length[rank] = Array.getLength(send);
        }

        @Override // java.lang.Runnable
        public void run() {
            for (int rank = 0; rank < CollectiveCommunications.this.size; rank++) {
                System.arraycopy(this.sendbuf[rank], 0, this.recvbuf[rank], 0, this.length[rank]);
            }
        }
    }

    /**
     * Reduce: the root's receive object is first passed to
     * {@code op.init}, then {@code op.op(recv, send)} is applied with every
     * rank's send object in rank order.
     */
    private class Reduce implements Runnable {
        CyclicBarrier barrier;
        /** Reduction supplied by the root. */
        Reduction op;
        /** One contribution per rank. */
        Object[] sendbuf;
        /** The root's destination object. */
        Object recvbuf;

        private Reduce() {
            this.barrier = new CyclicBarrier(CollectiveCommunications.this.size, this);
            this.sendbuf = new Object[CollectiveCommunications.this.size];
        }

        @Override // java.lang.Runnable
        public void run() {
            this.op.init(this.recvbuf);
            for (int rank = 0; rank < CollectiveCommunications.this.size; rank++) {
                this.op.op(this.recvbuf, this.sendbuf[rank]);
            }
        }
    }

    /**
     * Scatter: the root's {@code send[rank]} is copied into each rank's receive
     * array, using the length of the receive array.
     */
    private class Scatter implements Runnable {
        final CyclicBarrier barrier;
        /** The root's per-rank source arrays; replaced by the root on each call. */
        Object[] sendbuf;
        Object[] recvbuf;
        private final int[] length;

        private Scatter() {
            this.barrier = new CyclicBarrier(CollectiveCommunications.this.size, this);
            this.sendbuf = new Object[CollectiveCommunications.this.size];
            this.recvbuf = new Object[CollectiveCommunications.this.size];
            this.length = new int[CollectiveCommunications.this.size];
        }

        public void setRecv(Object recv, int rank) {
            this.recvbuf[rank] = recv;
            this.length[rank] = Array.getLength(recv);
        }

        @Override // java.lang.Runnable
        public void run() {
            for (int rank = 0; rank < CollectiveCommunications.this.size; rank++) {
                System.arraycopy(this.sendbuf[rank], 0, this.recvbuf[rank], 0, this.length[rank]);
            }
        }
    }

    /**
     * Creates the shared state for a group of {@code size} ranks.
     *
     * @param size number of participating ranks
     * @throws IllegalArgumentException if {@code size < 1}
     */
    public CollectiveCommunications(int size) {
        if (size < 1) {
            throw new IllegalArgumentException("size < 1");
        }
        this.size = size;
        this.barrier = new CyclicBarrier(size);
        this.broadcast = new Broadcast();
        this.gather = new Gather();
        this.scatter = new Scatter();
        this.allGather = new AllGather();
        this.allToAll = new AllToAll();
        this.reduce = new Reduce();

        // Build a symmetric matrix of exchangers: row r reuses the exchangers
        // already created by lower ranks (column r of earlier rows), places
        // null on the diagonal and creates fresh exchangers for higher ranks.
        this.ex = new ArrayList<List<Exchanger<Communicator.SendRecv>>>(size);
        for (int rank = 0; rank < size; rank++) {
            List<Exchanger<Communicator.SendRecv>> row = new ArrayList<Exchanger<Communicator.SendRecv>>(size);
            for (int peer = 0; peer < rank; peer++) {
                row.add(this.ex.get(peer).get(rank));
            }
            row.add(null);
            for (int peer = rank + 1; peer < size; peer++) {
                row.add(new Exchanger<Communicator.SendRecv>());
            }
            this.ex.add(row);
        }
    }

    /** Returns the number of participating ranks. */
    public int size() {
        return this.size;
    }

    /**
     * Creates the communicator for one rank, wired to that rank's row of
     * pairwise exchangers and to this shared collective state.
     *
     * @param rank rank in {@code [0, size)}
     * @throws IllegalArgumentException if {@code rank} is out of range
     */
    public Communicator createCommunicator(int rank) {
        if (rank < 0 || rank >= this.size) {
            throw new IllegalArgumentException("rank < 0 || rank >= size");
        }
        return new Communicator(rank, this.ex.get(rank), this);
    }

    /**
     * Waits on {@code cyclicBarrier}, converting the checked exceptions to
     * {@link RuntimeException}.
     *
     * <p>If the wait is interrupted, the thread's interrupt status is restored
     * before rethrowing, so callers can still detect the interruption.
     */
    public static void await(CyclicBarrier cyclicBarrier) {
        try {
            cyclicBarrier.await();
        } catch (InterruptedException e) {
            // InterruptedException clears the flag; re-assert it for upstream code.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (BrokenBarrierException e) {
            throw new RuntimeException(e);
        }
    }

    /** Blocks until all ranks have called this method. */
    public void barrier() {
        await(this.barrier);
    }

    /**
     * Copies the root's array into every rank's array.
     *
     * @param buffer this rank's array (source on the root, destination elsewhere)
     * @param root   rank whose array is broadcast
     * @param rank   the calling rank
     */
    public void broadcast(Object buffer, int root, int rank) {
        this.broadcast.buffer[rank] = buffer;
        if (rank == root) {
            this.broadcast.root = root;
        }
        await(this.broadcast.barrier);
    }

    /**
     * Collects every rank's array into {@code recv[rank]} on the root.
     *
     * @param send this rank's contribution
     * @param recv per-rank destination arrays (only read on the root)
     * @param root receiving rank
     * @param rank the calling rank
     */
    public void gather(Object send, Object[] recv, int root, int rank) {
        this.gather.setSend(send, rank);
        if (rank == root) {
            this.gather.recvbuf = recv;
        }
        await(this.gather.barrier);
    }

    /**
     * Distributes the root's {@code send[rank]} to each rank's {@code recv}.
     *
     * @param send per-rank source arrays (only read on the root)
     * @param recv this rank's destination array
     * @param root sending rank
     * @param rank the calling rank
     */
    public void scatter(Object[] send, Object recv, int root, int rank) {
        this.scatter.setRecv(recv, rank);
        if (rank == root) {
            this.scatter.sendbuf = send;
        }
        await(this.scatter.barrier);
    }

    /**
     * Gathers every rank's array into every rank's {@code recv[sender]}.
     *
     * @param send this rank's contribution
     * @param recv one destination array per source rank
     * @param rank the calling rank
     */
    public void allGather(Object send, Object[] recv, int rank) {
        this.allGather.setSendRecv(send, recv, rank);
        await(this.allGather.barrier);
    }

    /**
     * Sends {@code send[r]} to rank {@code r}, storing what rank {@code s}
     * sent in {@code recv[s]}.
     *
     * @param send one outgoing array per destination rank
     * @param recv one incoming array per source rank
     * @param rank the calling rank
     */
    public void allToAll(Object[] send, Object[] recv, int rank) {
        this.allToAll.setSendRecv(send, recv, rank);
        await(this.allToAll.barrier);
    }

    /**
     * Combines every rank's {@code send} into the root's {@code recv} with
     * {@code reduction}.
     *
     * @param send      this rank's contribution
     * @param recv      destination (only used on the root)
     * @param reduction reduction to apply (only used on the root)
     * @param root      receiving rank
     * @param rank      the calling rank
     */
    public void reduce(Object send, Object recv, Reduction reduction, int root, int rank) {
        this.reduce.sendbuf[rank] = send;
        if (rank == root) {
            this.reduce.op = reduction;
            this.reduce.recvbuf = recv;
        }
        await(this.reduce.barrier);
    }

    /**
     * Reduces onto rank 0, then broadcasts rank 0's result into every rank's
     * {@code recv}. {@code recv} must therefore be an array, since
     * {@link #broadcast} copies it element-wise.
     *
     * @param send      this rank's contribution
     * @param recv      this rank's destination array
     * @param reduction reduction to apply
     * @param rank      the calling rank
     */
    public void allReduce(Object send, Object recv, Reduction reduction, int rank) {
        reduce(send, recv, reduction, 0, rank);
        broadcast(recv, 0, rank);
    }
}
