package com.whisperonnx.voice_translation.neural_networks.voice;

import ai.onnxruntime.OnnxJavaType;
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints;
import java.lang.reflect.Array;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.function.IntToLongFunction;

/* loaded from: classes2.dex */
public final class TensorUtils {

    /**
     * Builds a batch of {@code batchSize} entries that all reference the same 2-D array.
     *
     * <p>No data is copied: every entry of the result is the very same {@code sample} instance,
     * so mutating one batch entry mutates all of them.
     *
     * @param sample the 2-D array to repeat
     * @param batchSize number of batch entries
     * @return an array of length {@code batchSize} whose entries are all {@code sample}
     * @throws NegativeArraySizeException if {@code batchSize} is negative
     */
    public static float[][][] batchTensor(float[][] sample, int batchSize) {
        float[][][] batch = new float[batchSize][];
        Arrays.fill(batch, sample);
        return batch;
    }

    /**
     * Builds a batch of {@code batchSize} entries that all reference the same 3-D array.
     *
     * <p>No data is copied: every entry of the result is the very same {@code sample} instance.
     *
     * @param sample the 3-D array to repeat
     * @param batchSize number of batch entries
     * @return an array of length {@code batchSize} whose entries are all {@code sample}
     * @throws NegativeArraySizeException if {@code batchSize} is negative
     */
    public static float[][][][] batchTensor(float[][][] sample, int batchSize) {
        float[][][][] batch = new float[batchSize][][];
        Arrays.fill(batch, sample);
        return batch;
    }

    /**
     * Creates a BOOL tensor of shape {@code [1]} holding {@code value}.
     *
     * @param env the ONNX Runtime environment
     * @param value the boolean to store
     * @return a new tensor; the caller owns it and must close it
     * @throws OrtException if ONNX Runtime fails to create the tensor
     */
    public static OnnxTensor convertBooleanToTensor(OrtEnvironment env, boolean value) throws OrtException {
        byte[] data = {(byte) (value ? 1 : 0)};
        return OnnxTensor.createTensor(env, ByteBuffer.wrap(data), new long[] {1}, OnnxJavaType.BOOL);
    }

    /**
     * Creates an INT64 tensor of shape {@code [1, values.length]} from {@code values}.
     *
     * @param env the ONNX Runtime environment
     * @param values the values, widened from int to long
     * @return a new tensor; the caller owns it and must close it
     * @throws OrtException if ONNX Runtime fails to create the tensor
     */
    public static OnnxTensor convertIntArrayToTensor(OrtEnvironment env, int[] values) throws OrtException {
        return OnnxTensor.createTensor(env, LongBuffer.wrap(toLongs(values)), new long[] {1, values.length});
    }

    /**
     * Creates an INT64 tensor with the given shape from {@code values}.
     *
     * @param env the ONNX Runtime environment
     * @param values the values, widened from int to long, in row-major order
     * @param shape the tensor shape; its element count must match {@code values.length}
     * @return a new tensor; the caller owns it and must close it
     * @throws OrtException if ONNX Runtime fails to create the tensor (e.g. shape mismatch)
     */
    public static OnnxTensor convertIntArrayToTensor(OrtEnvironment env, int[] values, long[] shape) throws OrtException {
        return OnnxTensor.createTensor(env, LongBuffer.wrap(toLongs(values)), shape);
    }

    /**
     * Creates a FLOAT tensor with the given shape from an already-flat array.
     *
     * @param env the ONNX Runtime environment
     * @param data the values in row-major order
     * @param shape the tensor shape
     * @return a new tensor; the caller owns it and must close it
     * @throws OrtException if ONNX Runtime fails to create the tensor
     */
    public static OnnxTensor createFloatTensor(OrtEnvironment env, float[] data, long[] shape) throws OrtException {
        return OnnxTensor.createTensor(env, FloatBuffer.wrap(data), shape);
    }

    /**
     * Flattens a 3-D array and creates a FLOAT tensor backed by a direct buffer.
     *
     * @param env the ONNX Runtime environment
     * @param data the 3-D values, flattened in row-major order
     * @param shape the tensor shape
     * @param timingMs output slot: {@code timingMs[0]} receives the milliseconds spent building
     *     the buffer and tensor (flattening time is not included)
     * @return a new tensor; the caller owns it and must close it
     * @throws OrtException if ONNX Runtime fails to create the tensor
     */
    public static OnnxTensor createFloatTensor(OrtEnvironment env, float[][][] data, long[] shape, long[] timingMs) throws OrtException {
        float[] flat = flattenFloatArray(data);
        long start = System.nanoTime();
        OnnxTensor tensor = createDirectFloatTensor(env, flat, shape);
        timingMs[0] = elapsedMillis(start);
        return tensor;
    }

    /**
     * Flattens a 4-D array and creates a FLOAT tensor from it (ONNX Runtime copies the data).
     *
     * @param env the ONNX Runtime environment
     * @param data the 4-D values, flattened in row-major order
     * @param shape the tensor shape
     * @return a new tensor; the caller owns it and must close it
     * @throws OrtException if ONNX Runtime fails to create the tensor
     */
    public static OnnxTensor createFloatTensor(OrtEnvironment env, float[][][][] data, long[] shape) throws OrtException {
        return OnnxTensor.createTensor(env, FloatBuffer.wrap(flattenFloatArray(data)), shape);
    }

    /**
     * Flattens a 4-D array and creates a FLOAT tensor backed by a direct buffer.
     *
     * @param env the ONNX Runtime environment
     * @param data the 4-D values, flattened in row-major order
     * @param shape the tensor shape
     * @param timingMs output slot: {@code timingMs[0]} receives the milliseconds spent building
     *     the buffer and tensor (flattening time is not included)
     * @return a new tensor; the caller owns it and must close it
     * @throws OrtException if ONNX Runtime fails to create the tensor
     */
    public static OnnxTensor createFloatTensor(OrtEnvironment env, float[][][][] data, long[] shape, long[] timingMs) throws OrtException {
        float[] flat = flattenFloatArray(data);
        long start = System.nanoTime();
        OnnxTensor tensor = createDirectFloatTensor(env, flat, shape);
        timingMs[0] = elapsedMillis(start);
        return tensor;
    }

    /**
     * Creates a FLOAT tensor directly from a 4-D Java array.
     *
     * <p>The {@code shape} argument is ignored: ONNX Runtime infers the shape from the array's
     * dimensions (the array is presumably expected to be rectangular).
     *
     * @param env the ONNX Runtime environment
     * @param data the 4-D values
     * @param shape unused; kept for signature compatibility
     * @return a new tensor; the caller owns it and must close it
     * @throws OrtException if ONNX Runtime fails to create the tensor
     */
    public static OnnxTensor createFloatTensorOptimized(OrtEnvironment env, float[][][][] data, long[] shape) throws OrtException {
        return OnnxTensor.createTensor(env, data);
    }

    /**
     * Creates a FLOAT tensor of the given shape with every element set to {@code value}.
     *
     * <p>For {@code 0.0f} a zero-initialized direct buffer is used, so no explicit fill is needed.
     * An empty {@code shape} denotes a scalar (one element).
     *
     * @param env the ONNX Runtime environment
     * @param value the fill value
     * @param shape the tensor shape
     * @return a new tensor; the caller owns it and must close it
     * @throws IllegalArgumentException if {@code shape} contains a negative dimension
     * @throws ArithmeticException if the element count does not fit in an {@code int}
     * @throws OrtException if ONNX Runtime fails to create the tensor
     */
    public static OnnxTensor createFloatTensorWithSingleValue(OrtEnvironment env, float value, long[] shape) throws OrtException {
        int count = elementCount(shape);
        FloatBuffer buffer;
        if (value != 0.0f) {
            float[] filled = new float[count];
            Arrays.fill(filled, value);
            buffer = FloatBuffer.wrap(filled);
        } else {
            // allocateDirect zero-fills its memory, which is exactly the requested content.
            buffer = ByteBuffer.allocateDirect(Math.multiplyExact(count, Float.BYTES))
                    .order(ByteOrder.nativeOrder())
                    .asFloatBuffer();
        }
        return OnnxTensor.createTensor(env, buffer, shape);
    }

    /**
     * Creates an INT32 tensor with the given shape from {@code values}.
     *
     * @param env the ONNX Runtime environment
     * @param values the values in row-major order
     * @param shape the tensor shape
     * @return a new tensor; the caller owns it and must close it
     * @throws OrtException if ONNX Runtime fails to create the tensor
     */
    public static OnnxTensor createInt32Tensor(OrtEnvironment env, int[] values, long[] shape) throws OrtException {
        return OnnxTensor.createTensor(env, IntBuffer.wrap(values), shape);
    }

    /**
     * Creates an INT64 tensor of the given shape with every element set to {@code value}.
     *
     * <p>Unlike the other factories, ONNX Runtime failures are rethrown unchecked so callers need
     * not handle {@link OrtException}. An empty {@code shape} denotes a scalar (one element).
     *
     * @param env the ONNX Runtime environment
     * @param value the fill value
     * @param shape the tensor shape
     * @return a new tensor; the caller owns it and must close it
     * @throws IllegalArgumentException if {@code shape} contains a negative dimension
     * @throws ArithmeticException if the element count does not fit in an {@code int}
     * @throws RuntimeException wrapping the {@link OrtException} if tensor creation fails
     */
    public static OnnxTensor createInt64TensorWithSingleValue(OrtEnvironment env, long value, long[] shape) {
        int count = elementCount(shape);
        LongBuffer buffer;
        if (value != 0) {
            long[] filled = new long[count];
            Arrays.fill(filled, value);
            buffer = LongBuffer.wrap(filled);
        } else {
            // allocateDirect zero-fills its memory, which is exactly the requested content.
            buffer = ByteBuffer.allocateDirect(Math.multiplyExact(count, Long.BYTES))
                    .order(ByteOrder.nativeOrder())
                    .asLongBuffer();
        }
        try {
            return OnnxTensor.createTensor(env, buffer, shape);
        } catch (OrtException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates an INT64 tensor (despite the name) with the given shape from {@code values}.
     *
     * @param env the ONNX Runtime environment
     * @param values the values, widened from int to long, in row-major order
     * @param shape the tensor shape
     * @return a new tensor; the caller owns it and must close it
     * @throws OrtException if ONNX Runtime fails to create the tensor
     */
    public static OnnxTensor createIntTensor(OrtEnvironment env, int[] values, long[] shape) throws OrtException {
        return OnnxTensor.createTensor(env, LongBuffer.wrap(toLongs(values)), shape);
    }

    /**
     * Copies a float tensor's contents into a new {@code [d0][d1][d2][d3]} array.
     *
     * <p>Reads the tensor through ONNX Runtime's private {@code getBuffer()} method via
     * reflection. NOTE(review): this depends on an ORT implementation detail that may change
     * between versions; the public {@code getFloatBuffer()} is the supported (copying) alternative.
     *
     * @param tensor the tensor to read; assumed to hold at least {@code d0*d1*d2*d3} floats
     * @param d0 size of dimension 0
     * @param d1 size of dimension 1
     * @param d2 size of dimension 2
     * @param d3 size of dimension 3
     * @param timingMs output slots (length at least 2): {@code [0]} receives ms spent reading the
     *     buffer, {@code [1]} ms spent reshaping into the 4-D array
     * @return the reshaped data in row-major order
     * @throws ArithmeticException if the element count does not fit in an {@code int}
     * @throws NoSuchMethodException if ONNX Runtime no longer has a {@code getBuffer()} method
     * @throws InvocationTargetException if {@code getBuffer()} throws
     * @throws IllegalAccessException if reflective access is denied
     */
    public static float[][][][] extractFloatMatrix(OnnxTensor tensor, int d0, int d1, int d2, int d3, long[] timingMs) throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
        long readStart = System.nanoTime();
        float[] flat = readFloats(tensor, checkedVolume(d0, d1, d2, d3));
        timingMs[0] = elapsedMillis(readStart);
        long reshapeStart = System.nanoTime();
        float[][][][] reshaped = reshape4d(flat, d0, d1, d2, d3);
        timingMs[1] = elapsedMillis(reshapeStart);
        return reshaped;
    }

    /**
     * Looks up output {@code name} in {@code result} and copies it into a new
     * {@code [d0][d1][d2][d3]} array.
     *
     * <p>Uses the same reflective read as the timed overload; see its caveats.
     *
     * @param result the session result
     * @param name the output name; it must be present, otherwise {@code Optional.get()} throws
     * @param d0 size of dimension 0
     * @param d1 size of dimension 1
     * @param d2 size of dimension 2
     * @param d3 size of dimension 3
     * @return the reshaped data in row-major order
     * @throws ArithmeticException if the element count does not fit in an {@code int}
     * @throws NoSuchMethodException if ONNX Runtime no longer has a {@code getBuffer()} method
     * @throws InvocationTargetException if {@code getBuffer()} throws
     * @throws IllegalAccessException if reflective access is denied
     */
    public static float[][][][] extractFloatMatrix(OrtSession.Result result, String name, int d0, int d1, int d2, int d3) throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
        int count = checkedVolume(d0, d1, d2, d3);
        OnnxTensor tensor = (OnnxTensor) result.get(name).get();
        return reshape4d(readFloats(tensor, count), d0, d1, d2, d3);
    }

    /**
     * Returns the tensor's value as a 4-D float array via ONNX Runtime's public API.
     *
     * <p>The dimension arguments are ignored; the shape comes from the tensor itself.
     *
     * @param tensor the tensor to read; must be a rank-4 FLOAT tensor or the cast fails
     * @param d0 unused
     * @param d1 unused
     * @param d2 unused
     * @param d3 unused
     * @return the tensor contents
     * @throws OrtException if ONNX Runtime fails to read the value
     */
    public static float[][][][] extractFloatMatrixAlternative(OnnxTensor tensor, int d0, int d1, int d2, int d3) throws OrtException {
        return (float[][][][]) tensor.getValue();
    }

    /**
     * Returns the Java value of output {@code name} in {@code result}.
     *
     * @param result the session result
     * @param name the output name; it must be present, otherwise {@code Optional.get()} throws
     * @return the tensor's value as returned by {@link OnnxTensor#getValue()}
     * @throws OrtException if ONNX Runtime fails to read the value
     */
    public static Object extractValue(OrtSession.Result result, String name) throws OrtException {
        return ((OnnxTensor) result.get(name).get()).getValue();
    }

    /**
     * Flattens a (possibly jagged) 3-D array into a single array in row-major order.
     *
     * @param data the array to flatten; may be empty
     * @return a new array holding every element of {@code data}
     * @throws ArithmeticException if the total element count does not fit in an {@code int}
     */
    public static float[] flattenFloatArray(float[][][] data) {
        int total = 0;
        for (float[][] plane : data) {
            total = Math.addExact(total, countElements(plane));
        }
        float[] flat = new float[total];
        int offset = 0;
        for (float[][] plane : data) {
            offset = copyRows(plane, flat, offset);
        }
        return flat;
    }

    /**
     * Flattens a (possibly jagged) 4-D array into a single array in row-major order.
     *
     * @param data the array to flatten; may be empty
     * @return a new array holding every element of {@code data}
     * @throws ArithmeticException if the total element count does not fit in an {@code int}
     */
    public static float[] flattenFloatArray(float[][][][] data) {
        int total = 0;
        for (float[][][] cube : data) {
            for (float[][] plane : cube) {
                total = Math.addExact(total, countElements(plane));
            }
        }
        float[] flat = new float[total];
        int offset = 0;
        for (float[][][] cube : data) {
            for (float[][] plane : cube) {
                offset = copyRows(plane, flat, offset);
            }
        }
        return flat;
    }

    /**
     * Flattens a 2-D array and repeats the result {@code batchSize} times back to back.
     *
     * @param rows the 2-D array to flatten
     * @param batchSize number of repetitions
     * @return a new array of length {@code batchSize * elements(rows)}
     * @throws IllegalArgumentException if {@code batchSize} is negative
     * @throws ArithmeticException if the result length does not fit in an {@code int}
     */
    public static float[] flattenFloatArrayBatched(float[][] rows, int batchSize) {
        return repeat(Floats.concat(rows), batchSize);
    }

    /**
     * Flattens a 3-D array and repeats the result {@code batchSize} times back to back.
     *
     * @param data the 3-D array to flatten
     * @param batchSize number of repetitions
     * @return a new array of length {@code batchSize * elements(data)}
     * @throws IllegalArgumentException if {@code batchSize} is negative
     * @throws ArithmeticException if the result length does not fit in an {@code int}
     */
    public static float[] flattenFloatArrayBatched(float[][][] data, int batchSize) {
        return repeat(flattenFloatArray(data), batchSize);
    }

    /**
     * Repeats {@code values} {@code batchSize} times back to back.
     *
     * @param values the values to repeat
     * @param batchSize number of repetitions
     * @return a new array of length {@code batchSize * values.length}
     * @throws IllegalArgumentException if {@code batchSize} is negative
     * @throws ArithmeticException if the result length does not fit in an {@code int}
     */
    public static int[] flattenIntArrayBatched(int[] values, int batchSize) {
        if (batchSize < 0) {
            throw new IllegalArgumentException("batchSize must be non-negative: " + batchSize);
        }
        int[] out = new int[Math.multiplyExact(values.length, batchSize)];
        for (int b = 0; b < batchSize; b++) {
            System.arraycopy(values, 0, out, b * values.length, values.length);
        }
        return out;
    }

    /* Decompiler-generated lambda bodies. They are no longer used internally and are kept only
     * because they are public, so removing them could break binary compatibility. */

    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ long lambda$convertIntArrayToTensor$1(int i) {
        return i;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ long lambda$convertIntArrayToTensor$2(int i) {
        return i;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public static /* synthetic */ long lambda$createIntTensor$0(int i) {
        return i;
    }

    /**
     * Returns a defensive copy of the given dimensions, convenient for building shape arrays.
     *
     * @param dims the dimensions
     * @return a new array equal to {@code dims}
     */
    public static long[] tensorShape(long... dims) {
        return Arrays.copyOf(dims, dims.length);
    }

    /** Widens each int to a long (used for INT64 tensors). */
    private static long[] toLongs(int[] values) {
        return Arrays.stream(values).asLongStream().toArray();
    }

    /** Milliseconds elapsed since {@code startNanos}, measured on the monotonic clock. */
    private static long elapsedMillis(long startNanos) {
        return (System.nanoTime() - startNanos) / 1_000_000L;
    }

    /** Product of the shape's dimensions (1 for an empty/scalar shape), checked for overflow. */
    private static int elementCount(long[] shape) {
        long count = 1;
        for (long dim : shape) {
            if (dim < 0) {
                throw new IllegalArgumentException("Negative dimension in shape " + Arrays.toString(shape));
            }
            count = Math.multiplyExact(count, dim);
        }
        return Math.toIntExact(count);
    }

    /** Product of four int dimensions, checked for overflow. */
    private static int checkedVolume(int d0, int d1, int d2, int d3) {
        return Math.multiplyExact(Math.multiplyExact(d0, d1), Math.multiplyExact(d2, d3));
    }

    /**
     * Copies {@code data} into a direct buffer and wraps it as a FLOAT tensor without a further copy.
     *
     * <p>ONNX Runtime reads direct buffers as raw native memory, so the buffer must use the
     * platform's native byte order.
     */
    private static OnnxTensor createDirectFloatTensor(OrtEnvironment env, float[] data, long[] shape) throws OrtException {
        ByteBuffer buffer = ByteBuffer.allocateDirect(Math.multiplyExact(data.length, Float.BYTES))
                .order(ByteOrder.nativeOrder());
        // Writing through the float view leaves the byte buffer's own position at 0.
        buffer.asFloatBuffer().put(data);
        return OnnxTensor.createTensor(env, buffer, shape, OnnxJavaType.FLOAT);
    }

    /** Reads the first {@code count} floats of the tensor via ORT's private {@code getBuffer()}. */
    private static float[] readFloats(OnnxTensor tensor, int count) throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
        Method getBuffer = tensor.getClass().getDeclaredMethod("getBuffer");
        getBuffer.setAccessible(true);
        ByteBuffer raw = (ByteBuffer) getBuffer.invoke(tensor);
        float[] flat = new float[count];
        raw.asFloatBuffer().get(flat, 0, count);
        return flat;
    }

    /** Reshapes a row-major flat array into a new {@code [d0][d1][d2][d3]} array. */
    private static float[][][][] reshape4d(float[] flat, int d0, int d1, int d2, int d3) {
        float[][][][] out = new float[d0][d1][d2][d3];
        int offset = 0;
        for (float[][][] cube : out) {
            for (float[][] plane : cube) {
                for (float[] row : plane) {
                    System.arraycopy(flat, offset, row, 0, d3);
                    offset += d3;
                }
            }
        }
        return out;
    }

    /** Total number of elements across all rows, checked for overflow. */
    private static int countElements(float[][] rows) {
        int total = 0;
        for (float[] row : rows) {
            total = Math.addExact(total, row.length);
        }
        return total;
    }

    /** Copies every row into {@code dst} starting at {@code offset}; returns the next free offset. */
    private static int copyRows(float[][] rows, float[] dst, int offset) {
        for (float[] row : rows) {
            System.arraycopy(row, 0, dst, offset, row.length);
            offset += row.length;
        }
        return offset;
    }

    /** Concatenates {@code times} copies of {@code src}. */
    private static float[] repeat(float[] src, int times) {
        if (times < 0) {
            throw new IllegalArgumentException("batchSize must be non-negative: " + times);
        }
        float[] out = new float[Math.multiplyExact(src.length, times)];
        for (int t = 0; t < times; t++) {
            System.arraycopy(src, 0, out, t * src.length, src.length);
        }
        return out;
    }
}
