/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.gpu;

import com.ibm.cuda.CudaBuffer;
import com.ibm.cuda.CudaDevice;
import com.ibm.cuda.CudaException;
import com.ibm.cuda.CudaGrid;
import com.ibm.cuda.CudaKernel;
import com.ibm.cuda.CudaModule;
import com.ibm.cuda.CudaStream;
import com.ibm.cuda.Dim3;
import com.ibm.gpu.CUDAManager;
import com.ibm.gpu.GPUConfigurationException;
import com.ibm.gpu.GPUSortException;
import com.ibm.gpu.PtxKernelGenerator;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Queue;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;

final class SortNetwork {
    // Boxed powers of two: the static initializer fills index i with 1 << i for i in 0..30.
    // Elements are passed as kernel launch arguments by sortBuffer.
    private static final Integer[] powersOf2;
    // Cache of load outcomes keyed by (deviceId, type); populated through computeIfAbsent in load().
    private static final ConcurrentHashMap<LoadKey, LoadResult> resultsMap;
    // Device on which buffers and streams are created by the sort methods.
    private final CudaDevice device;
    // Value of device.getAttribute(5), read in the constructor; makeGridDim keeps the grid's X extent at or below it.
    private final int maxGridDimX;
    // Kernels looked up in the loaded module under the names "first4", "other1", "other2",
    // "other3", "other4" and "phase9" respectively (see the constructor).
    private final CudaKernel sortFirst4;
    private final CudaKernel sortOther1;
    private final CudaKernel sortOther2;
    private final CudaKernel sortOther3;
    private final CudaKernel sortOther4;
    private final CudaKernel sortPhase9;

    /**
     * Validates that {@code [fromIndex, toIndex)} is a legal range of an array of the given length.
     *
     * @param length    length of the array being indexed
     * @param fromIndex inclusive start of the range
     * @param toIndex   exclusive end of the range
     * @throws IllegalArgumentException       if {@code fromIndex > toIndex}
     * @throws ArrayIndexOutOfBoundsException if {@code fromIndex < 0} or {@code toIndex > length}
     */
    private static void checkIndices(int length, int fromIndex, int toIndex) {
        if (fromIndex > toIndex) {
            throw new IllegalArgumentException();
        }
        // The lower bound is examined before the upper bound, so a range that violates
        // both reports fromIndex.
        final int offending;
        if (fromIndex < 0) {
            offending = fromIndex;
        } else if (toIndex > length) {
            offending = toIndex;
        } else {
            return;
        }
        throw new ArrayIndexOutOfBoundsException(offending);
    }

    /**
     * Returns the sort network for the given device and element type, creating and caching
     * the load outcome on first use.
     *
     * @param deviceId identifier of the device
     * @param type     element type code used in the cache key ('D', 'F', 'I' or 'J' in this file)
     * @return the cached network
     * @throws GPUConfigurationException if the cached outcome records a problem
     */
    private static SortNetwork load(int deviceId, char type) throws GPUConfigurationException {
        // Failures are cached too: LoadResult.get() throws for a cached problem.
        LoadResult outcome = resultsMap.computeIfAbsent(new LoadKey(deviceId, type), LoadResult::create);
        return outcome.get();
    }

    /**
     * Rounds {@code value} up to the nearest multiple of {@code unit}.
     *
     * @param value the value to round; asserted to be positive
     * @param unit  the granularity; asserted to be positive
     * @return {@code value} if it is already a multiple of {@code unit}, otherwise the next multiple
     */
    private static int roundUp(int value, int unit) {
        assert (value > 0);
        assert (unit > 0);
        int excess = value % unit;
        if (excess == 0) {
            return value;
        }
        // Add what is missing to reach the next multiple.
        return value + (unit - excess);
    }

    /**
     * Returns the number of bits needed to represent {@code value}, treating any value
     * below 1 as 1 (so the result is never less than 1).
     *
     * @param value the value to measure
     * @return 32 minus the number of leading zero bits of {@code max(1, value)}
     */
    private static int significantBits(int value) {
        int clamped = value < 1 ? 1 : value;
        return Integer.SIZE - Integer.numberOfLeadingZeros(clamped);
    }

    /**
     * Sorts the given range of a {@code double} array on the given device, emitting verbose
     * trace output at start and on success or failure.
     *
     * @param deviceId  identifier of the device to use
     * @param array     the array to sort
     * @param fromIndex inclusive start of the range
     * @param toIndex   exclusive end of the range
     * @throws GPUConfigurationException if the sort network cannot be loaded for the device
     * @throws GPUSortException          if the sort itself fails
     */
    static void sortArray(int deviceId, double[] array, int fromIndex, int toIndex) throws GPUConfigurationException, GPUSortException {
        final String typeName = "double";
        final CUDAManager tracer = traceStart(deviceId, typeName, fromIndex, toIndex);
        try {
            load(deviceId, 'D').sort(array, fromIndex, toIndex);
        } catch (GPUConfigurationException | GPUSortException failure) {
            // Only the two checked failure kinds are traced; anything else propagates untraced.
            traceFailure(tracer, failure);
            throw failure;
        }
        traceSuccess(tracer, deviceId, typeName);
    }

    /**
     * Sorts the given range of a {@code float} array on the given device, emitting verbose
     * trace output at start and on success or failure.
     *
     * @param deviceId  identifier of the device to use
     * @param array     the array to sort
     * @param fromIndex inclusive start of the range
     * @param toIndex   exclusive end of the range
     * @throws GPUConfigurationException if the sort network cannot be loaded for the device
     * @throws GPUSortException          if the sort itself fails
     */
    static void sortArray(int deviceId, float[] array, int fromIndex, int toIndex) throws GPUConfigurationException, GPUSortException {
        final String typeName = "float";
        final CUDAManager tracer = traceStart(deviceId, typeName, fromIndex, toIndex);
        try {
            load(deviceId, 'F').sort(array, fromIndex, toIndex);
        } catch (GPUConfigurationException | GPUSortException failure) {
            // Only the two checked failure kinds are traced; anything else propagates untraced.
            traceFailure(tracer, failure);
            throw failure;
        }
        traceSuccess(tracer, deviceId, typeName);
    }

    /**
     * Sorts the given range of an {@code int} array on the given device, emitting verbose
     * trace output at start and on success or failure.
     *
     * @param deviceId  identifier of the device to use
     * @param array     the array to sort
     * @param fromIndex inclusive start of the range
     * @param toIndex   exclusive end of the range
     * @throws GPUConfigurationException if the sort network cannot be loaded for the device
     * @throws GPUSortException          if the sort itself fails
     */
    static void sortArray(int deviceId, int[] array, int fromIndex, int toIndex) throws GPUConfigurationException, GPUSortException {
        final String typeName = "int";
        final CUDAManager tracer = traceStart(deviceId, typeName, fromIndex, toIndex);
        try {
            load(deviceId, 'I').sort(array, fromIndex, toIndex);
        } catch (GPUConfigurationException | GPUSortException failure) {
            // Only the two checked failure kinds are traced; anything else propagates untraced.
            traceFailure(tracer, failure);
            throw failure;
        }
        traceSuccess(tracer, deviceId, typeName);
    }

    /**
     * Sorts the given range of a {@code long} array on the given device, emitting verbose
     * trace output at start and on success or failure.
     *
     * @param deviceId  identifier of the device to use
     * @param array     the array to sort
     * @param fromIndex inclusive start of the range
     * @param toIndex   exclusive end of the range
     * @throws GPUConfigurationException if the sort network cannot be loaded for the device
     * @throws GPUSortException          if the sort itself fails
     */
    static void sortArray(int deviceId, long[] array, int fromIndex, int toIndex) throws GPUConfigurationException, GPUSortException {
        final String typeName = "long";
        final CUDAManager tracer = traceStart(deviceId, typeName, fromIndex, toIndex);
        try {
            load(deviceId, 'J').sort(array, fromIndex, toIndex);
        } catch (GPUConfigurationException | GPUSortException failure) {
            // Only the two checked failure kinds are traced; anything else propagates untraced.
            traceFailure(tracer, failure);
            throw failure;
        }
        traceSuccess(tracer, deviceId, typeName);
    }

    /**
     * Passes the localized message of a failure to the manager's verbose output.
     *
     * @param manager   the manager that receives the output
     * @param exception the failure being reported
     */
    private static void traceFailure(CUDAManager manager, Exception exception) {
        String detail = exception.getLocalizedMessage();
        manager.outputIfVerbose(detail);
    }

    /**
     * Obtains the manager instance and, when verbose GPU output is enabled, announces
     * the sort that is about to start.
     *
     * @param deviceId  identifier of the device that will sort
     * @param type      element type name used in the message
     * @param fromIndex start index reported in the message
     * @param toIndex   end index reported in the message
     * @return the manager, for use by the matching success/failure trace call
     */
    private static CUDAManager traceStart(int deviceId, String type, int fromIndex, int toIndex) {
        CUDAManager tracer = CUDAManager.instanceInternal();
        if (!tracer.getVerboseGPUOutput()) {
            // Skip building the message when nobody will see it.
            return tracer;
        }
        StringBuilder message = new StringBuilder();
        message.append("Using device: ").append(deviceId);
        message.append(" to sort ").append(type);
        message.append(" array; elements ").append(fromIndex);
        message.append(" to ").append(toIndex);
        tracer.outputIfVerbose(message.toString());
        return tracer;
    }

    /**
     * Reports a completed sort when verbose GPU output is enabled.
     *
     * @param manager  the manager that receives the output
     * @param deviceId identifier of the device that sorted
     * @param type     element type name used in the message
     */
    private static void traceSuccess(CUDAManager manager, int deviceId, String type) {
        if (!manager.getVerboseGPUOutput()) {
            return;
        }
        StringBuilder message = new StringBuilder("Sorted ");
        message.append(type).append("s on device ").append(deviceId).append(" successfully");
        manager.outputIfVerbose(message.toString());
    }

    SortNetwork(CudaDevice device, CudaModule module) throws CudaException {
        this.device = device;
        this.maxGridDimX = device.getAttribute(5);
        this.sortFirst4 = new CudaKernel(module, "first4");
        this.sortOther1 = new CudaKernel(module, "other1");
        this.sortOther2 = new CudaKernel(module, "other2");
        this.sortOther3 = new CudaKernel(module, "other3");
        this.sortOther4 = new CudaKernel(module, "other4");
        this.sortPhase9 = new CudaKernel(module, "phase9");
    }

    /**
     * Builds a launch grid with enough blocks of {@code blockSize} threads to cover
     * {@code threadCount} threads, using at least one block.
     *
     * @param threadCount number of threads required
     * @param blockSize   number of threads per block
     * @param stream      the stream passed to the grid
     * @return the grid description
     */
    private CudaGrid makeGrid(int threadCount, int blockSize, CudaStream stream) {
        // Ceiling division, clamped so that a grid is never empty.
        int blocksNeeded = (threadCount + blockSize - 1) / blockSize;
        if (blocksNeeded < 1) {
            blocksNeeded = 1;
        }
        Dim3 gridDim = this.makeGridDim(blocksNeeded);
        Dim3 blockDim = new Dim3(blockSize);
        return new CudaGrid(gridDim, blockDim, stream);
    }

    /**
     * Splits a block count into a two-dimensional grid whose X extent does not exceed
     * {@code maxGridDimX}.
     *
     * While X is too large it is halved (rounding up) and Y is doubled, so X * Y is
     * always at least the requested block count.
     *
     * @param blockCount number of blocks required; values below 1 are treated as 1
     * @return the grid dimensions
     */
    private Dim3 makeGridDim(int blockCount) {
        int x = blockCount < 1 ? 1 : blockCount;
        int y = 1;
        for (; x > this.maxGridDimX; y <<= 1) {
            // (x + 1) >> 1 is the round-up halving: odd values are bumped before the shift,
            // even values are unaffected by the +1.
            x = (x + 1) >> 1;
        }
        return new Dim3(x, y);
    }

    /**
     * Sorts {@code array[fromIndex, toIndex)} in place: the range is copied into a device
     * buffer, sorted there by {@link #sortBuffer}, and copied back.
     *
     * The range is validated before anything else. Validating only the short-range case
     * would let an invalid range of two or more elements reach the device allocation
     * below, where it could waste device memory or surface as a {@link GPUSortException}
     * instead of an index error.
     *
     * @param array     the array to sort
     * @param fromIndex inclusive start of the range
     * @param toIndex   exclusive end of the range
     * @throws GPUSortException               if a CUDA operation fails
     * @throws IllegalArgumentException       if {@code fromIndex > toIndex}
     * @throws ArrayIndexOutOfBoundsException if the range is outside the array
     */
    private void sort(double[] array, int fromIndex, int toIndex) throws GPUSortException {
        SortNetwork.checkIndices(array.length, fromIndex, toIndex);
        // After validation 0 <= fromIndex <= toIndex, so the subtraction cannot overflow.
        int length = toIndex - fromIndex;
        if (length < 2) {
            // Zero or one element is already sorted.
            return;
        }
        try (CudaBuffer gpuBuffer = new CudaBuffer(this.device, (long) length * Double.BYTES)) {
            gpuBuffer.copyFrom(array, fromIndex, toIndex);
            this.sortBuffer(gpuBuffer, length);
            gpuBuffer.copyTo(array, fromIndex, toIndex);
        } catch (CudaException e) {
            throw new GPUSortException(e.getLocalizedMessage(), e);
        }
    }

    /**
     * Sorts {@code array[fromIndex, toIndex)} in place: the range is copied into a device
     * buffer, sorted there by {@link #sortBuffer}, and copied back.
     *
     * The range is validated before anything else. Validating only the short-range case
     * would let an invalid range of two or more elements reach the device allocation
     * below, where it could waste device memory or surface as a {@link GPUSortException}
     * instead of an index error.
     *
     * @param array     the array to sort
     * @param fromIndex inclusive start of the range
     * @param toIndex   exclusive end of the range
     * @throws GPUSortException               if a CUDA operation fails
     * @throws IllegalArgumentException       if {@code fromIndex > toIndex}
     * @throws ArrayIndexOutOfBoundsException if the range is outside the array
     */
    private void sort(float[] array, int fromIndex, int toIndex) throws GPUSortException {
        SortNetwork.checkIndices(array.length, fromIndex, toIndex);
        // After validation 0 <= fromIndex <= toIndex, so the subtraction cannot overflow.
        int length = toIndex - fromIndex;
        if (length < 2) {
            // Zero or one element is already sorted.
            return;
        }
        try (CudaBuffer gpuBuffer = new CudaBuffer(this.device, (long) length * Float.BYTES)) {
            gpuBuffer.copyFrom(array, fromIndex, toIndex);
            this.sortBuffer(gpuBuffer, length);
            gpuBuffer.copyTo(array, fromIndex, toIndex);
        } catch (CudaException e) {
            throw new GPUSortException(e.getLocalizedMessage(), e);
        }
    }

    /**
     * Sorts {@code array[fromIndex, toIndex)} in place: the range is copied into a device
     * buffer, sorted there by {@link #sortBuffer}, and copied back.
     *
     * The range is validated before anything else. Validating only the short-range case
     * would let an invalid range of two or more elements reach the device allocation
     * below, where it could waste device memory or surface as a {@link GPUSortException}
     * instead of an index error.
     *
     * @param array     the array to sort
     * @param fromIndex inclusive start of the range
     * @param toIndex   exclusive end of the range
     * @throws GPUSortException               if a CUDA operation fails
     * @throws IllegalArgumentException       if {@code fromIndex > toIndex}
     * @throws ArrayIndexOutOfBoundsException if the range is outside the array
     */
    private void sort(int[] array, int fromIndex, int toIndex) throws GPUSortException {
        SortNetwork.checkIndices(array.length, fromIndex, toIndex);
        // After validation 0 <= fromIndex <= toIndex, so the subtraction cannot overflow.
        int length = toIndex - fromIndex;
        if (length < 2) {
            // Zero or one element is already sorted.
            return;
        }
        try (CudaBuffer gpuBuffer = new CudaBuffer(this.device, (long) length * Integer.BYTES)) {
            gpuBuffer.copyFrom(array, fromIndex, toIndex);
            this.sortBuffer(gpuBuffer, length);
            gpuBuffer.copyTo(array, fromIndex, toIndex);
        } catch (CudaException e) {
            throw new GPUSortException(e.getLocalizedMessage(), e);
        }
    }

    /**
     * Sorts {@code array[fromIndex, toIndex)} in place: the range is copied into a device
     * buffer, sorted there by {@link #sortBuffer}, and copied back.
     *
     * The range is validated before anything else. Validating only the short-range case
     * would let an invalid range of two or more elements reach the device allocation
     * below, where it could waste device memory or surface as a {@link GPUSortException}
     * instead of an index error.
     *
     * @param array     the array to sort
     * @param fromIndex inclusive start of the range
     * @param toIndex   exclusive end of the range
     * @throws GPUSortException               if a CUDA operation fails
     * @throws IllegalArgumentException       if {@code fromIndex > toIndex}
     * @throws ArrayIndexOutOfBoundsException if the range is outside the array
     */
    private void sort(long[] array, int fromIndex, int toIndex) throws GPUSortException {
        SortNetwork.checkIndices(array.length, fromIndex, toIndex);
        // After validation 0 <= fromIndex <= toIndex, so the subtraction cannot overflow.
        int length = toIndex - fromIndex;
        if (length < 2) {
            // Zero or one element is already sorted.
            return;
        }
        try (CudaBuffer gpuBuffer = new CudaBuffer(this.device, (long) length * Long.BYTES)) {
            gpuBuffer.copyFrom(array, fromIndex, toIndex);
            this.sortBuffer(gpuBuffer, length);
            gpuBuffer.copyTo(array, fromIndex, toIndex);
        } catch (CudaException e) {
            throw new GPUSortException(e.getLocalizedMessage(), e);
        }
    }

    /**
     * Launches the sorting kernels over a device buffer holding {@code length} elements.
     *
     * The "phase9" kernel is launched once. If {@code length - 1} needs more than nine
     * significant bits, each further phase launches "first4", then "other4" for every
     * step that is four below the previous one while the step is at least 3, and finally
     * at most one of "other3" / "other2" / "other1", chosen by {@code phase & 3}.
     * All launches use a single stream that is closed when this method returns.
     *
     * @param buffer the device buffer to sort
     * @param length number of elements in the buffer
     * @throws CudaException if stream creation or a kernel launch fails
     */
    private void sortBuffer(CudaBuffer buffer, int length) throws CudaException {
        try (CudaStream stream = new CudaStream(this.device)) {
            final int initialPhases = 9;
            final int threadsPerBlock = 256;
            // Boxed once so that every launch passes the same argument object.
            final Integer boxedLength = Integer.valueOf(length);
            final int halfLength = length >> 1;

            CudaGrid initialGrid = this.makeGrid(halfLength, threadsPerBlock, stream);
            this.sortPhase9.launch(initialGrid, buffer, boxedLength);

            final int totalPhases = SortNetwork.significantBits(length - 1);
            if (totalPhases <= initialPhases) {
                return;
            }

            // Grid shared by all "other" kernels; the "first4" grid is rebuilt per phase.
            final CudaGrid otherGrid = this.makeGrid(halfLength, threadsPerBlock, stream);
            for (int phase = initialPhases; phase < totalPhases; ++phase) {
                int paddedLength = SortNetwork.roundUp(length, 1 << phase);
                CudaGrid firstGrid = this.makeGrid(paddedLength >> 1, threadsPerBlock, stream);
                this.sortFirst4.launch(firstGrid, buffer, boxedLength, powersOf2[phase]);

                for (int step = phase - 4; step >= 3; step -= 4) {
                    this.sortOther4.launch(otherGrid, buffer, boxedLength, powersOf2[step]);
                }

                // A remainder of 3 needs no trailing launch.
                int remainder = phase & 3;
                if (remainder == 2) {
                    this.sortOther3.launch(otherGrid, buffer, boxedLength, powersOf2[2]);
                } else if (remainder == 1) {
                    this.sortOther2.launch(otherGrid, buffer, boxedLength, powersOf2[1]);
                } else if (remainder == 0) {
                    this.sortOther1.launch(otherGrid, buffer, boxedLength, powersOf2[0]);
                }
            }
        }
    }

    static {
        // One boxed power of two per bit position 0..30, i.e. every positive int power
        // of two; 1 << 31 is excluded because it is negative.
        final int powerCount = Integer.SIZE - 1;
        Integer[] powers = new Integer[powerCount];
        for (int bit = 0; bit < powerCount; ++bit) {
            powers[bit] = 1 << bit;
        }
        powersOf2 = powers;
        // Diamond instead of the raw type, so the assignment is checked against the field's type.
        resultsMap = new ConcurrentHashMap<>();
    }

    /**
     * Shutdown hook that unloads every module registered through
     * {@link #unloadOnShutdown(CudaModule)}. One instance is registered with the runtime
     * when this class is initialized.
     */
    private static final class ShutdownHook
    extends Thread {
        // Modules waiting to be unloaded; a concurrent queue is used for registration and draining.
        private static final Queue<CudaModule> modules = new ConcurrentLinkedQueue<CudaModule>();

        /**
         * Registers a module to be unloaded when the hook runs.
         *
         * @param module the module to unload at shutdown
         */
        public static void unloadOnShutdown(CudaModule module) {
            modules.add(module);
        }

        private ShutdownHook() {
            super("GPU sort shutdown helper");
        }

        /**
         * Drains the queue, unloading each module in turn.
         */
        @Override
        public void run() {
            CudaModule module;
            while ((module = modules.poll()) != null) {
                try {
                    module.unload();
                } catch (CudaException ignored) {
                    // Best effort: a failed unload is ignored so the remaining modules
                    // still get their turn.
                }
            }
        }

        static {
            // The explicit PrivilegedAction cast is required: a bare lambda matches both
            // doPrivileged(PrivilegedAction) and doPrivileged(PrivilegedExceptionAction),
            // which the compiler rejects as ambiguous.
            AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
                Runtime.getRuntime().addShutdownHook(new ShutdownHook());
                return null;
            });
        }
    }

    /**
     * Outcome of trying to build a {@link SortNetwork} for one {@link LoadKey}: either the
     * network, or a description of why it could not be created. Exactly one of the two
     * fields is non-null.
     */
    private static final class LoadResult {
        private final SortNetwork network;
        private final String problem;

        /**
         * Generates the PTX code for the key's element type and loads it on the key's device.
         * Failures are captured in the returned result rather than thrown.
         *
         * @param key the device and element type to load for
         * @return a success result holding the network, or a failure result
         */
        static LoadResult create(LoadKey key) {
            try {
                CudaDevice device = new CudaDevice(key.deviceId);
                // NOTE(review): attribute id 75 is presumably the compute capability major
                // version, given the variable name and the "< 2" test - confirm against the
                // CudaDevice attribute constants.
                int capability = device.getAttribute(75);
                if (capability < 2) {
                    return LoadResult.failure("Unsupported device");
                }
                ByteArrayOutputStream ptxBuffer = new ByteArrayOutputStream(49152);
                PtxKernelGenerator.writeTo(capability, key.type, ptxBuffer);
                // Append a single zero byte after the generated code (presumably a terminator
                // expected by the module loader - TODO confirm).
                ptxBuffer.write(0);
                byte[] ptxCode = ptxBuffer.toByteArray();
                // Typed local rather than an inline lambda, which would be ambiguous between
                // the two doPrivileged overloads.
                PrivilegedAction<LoadResult> loader = () -> LoadResult.load(device, ptxCode);
                return AccessController.doPrivileged(loader);
            } catch (CudaException | IOException e) {
                return LoadResult.failure(e);
            }
        }

        /**
         * Returns a non-null description of the exception: its localized message when it
         * has one, otherwise its {@code toString()}.
         */
        private static String describe(Exception exception) {
            String message = exception.getLocalizedMessage();
            return message != null ? message : exception.toString();
        }

        private static LoadResult failure(Exception exception) {
            return new LoadResult(exception);
        }

        private static LoadResult failure(String problem) {
            return new LoadResult(problem);
        }

        /**
         * Loads the module and wraps it in a network. On success the module is handed to
         * the shutdown hook; if anything fails after the module was loaded, it is unloaded
         * again before the failure is recorded.
         *
         * @param device  the device to load the module on
         * @param ptxCode the generated code for the module
         * @return a success or failure result; never null
         */
        private static LoadResult load(CudaDevice device, byte[] ptxCode) {
            try {
                CudaModule module = new CudaModule(device, ptxCode);
                boolean handedOver = false;
                try {
                    LoadResult result = LoadResult.success(new SortNetwork(device, module));
                    ShutdownHook.unloadOnShutdown(module);
                    handedOver = true;
                    return result;
                } finally {
                    if (!handedOver) {
                        module.unload();
                    }
                }
            } catch (CudaException e) {
                return LoadResult.failure(e);
            }
        }

        private static LoadResult success(SortNetwork network) {
            return new LoadResult(network);
        }

        private LoadResult(Exception exception) {
            // An exception without a message must still yield a non-null problem; otherwise
            // get() would return a null network instead of throwing.
            this(LoadResult.describe(exception));
        }

        private LoadResult(SortNetwork network) {
            this.network = network;
            this.problem = null;
        }

        private LoadResult(String problem) {
            this.network = null;
            this.problem = problem;
        }

        /**
         * Returns the loaded network.
         *
         * @return the network of a successful load
         * @throws GPUConfigurationException carrying the recorded problem if loading failed
         */
        SortNetwork get() throws GPUConfigurationException {
            if (this.problem != null) {
                throw new GPUConfigurationException(this.problem);
            }
            return this.network;
        }
    }

    /**
     * Immutable cache key: a device identifier paired with an element type code.
     */
    private static final class LoadKey {
        final int deviceId;
        final char type;

        LoadKey(int deviceId, char type) {
            this.deviceId = deviceId;
            this.type = type;
        }

        /**
         * Two keys are equal when both the device identifier and the type code match.
         */
        @Override
        public boolean equals(Object object) {
            if (!(object instanceof LoadKey)) {
                return false;
            }
            LoadKey other = (LoadKey) object;
            return this.deviceId == other.deviceId && this.type == other.type;
        }

        /**
         * Combines the device identifier, shifted left by four bits, with the type code.
         */
        @Override
        public int hashCode() {
            return (this.deviceId << 4) ^ this.type;
        }
    }
}

