diff --git a/CMakeLists.txt b/CMakeLists.txt index 0f08b5dd..7083502a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -109,7 +109,7 @@ set_target_properties(${_trgt} PROPERTIES CMAKE_POSITION_INDEPENDENT_CODE ON C_STANDARD 99 ) -target_include_directories(${_trgt} PUBLIC mkl_umath/src/ ${Python_NumPy_INCLUDE_DIRS} ${Python_INCLUDE_DIRS}) +target_include_directories(${_trgt} PUBLIC mkl_umath/src/ mkl_umath/src/npyv/ ${Python_NumPy_INCLUDE_DIRS} ${Python_INCLUDE_DIRS}) target_link_libraries(${_trgt} PUBLIC MKL::MKL ${Python_LIBRARIES}) target_link_options(${_trgt} PUBLIC ${_linker_options}) target_compile_options(${_trgt} PUBLIC -fveclib=SVML) @@ -117,6 +117,10 @@ target_compile_options(${_trgt} PUBLIC -fvectorize) if(OPTIMIZATION_REPORT) target_compile_options(${_trgt} PRIVATE -qopt-report=3) endif() +# Enable AVX2/FMA for the npyv SIMD kernels +# -mvzeroupper: emit VZEROUPPER around AVX2 code to avoid AVX/SSE transition penalties +# (no -mtune flag is set here; -mtune=native or -mtune=haswell can be added for target-specific tuning) +target_compile_options(${_trgt} PUBLIC -mavx2 -mfma -mvzeroupper) install(TARGETS ${_trgt} LIBRARY DESTINATION mkl_umath ARCHIVE DESTINATION mkl_umath diff --git a/mkl_umath/src/mkl_umath_loops.c.src b/mkl_umath/src/mkl_umath_loops.c.src index b588ca6f..c68e183a 100644 --- a/mkl_umath/src/mkl_umath_loops.c.src +++ b/mkl_umath/src/mkl_umath_loops.c.src @@ -68,8 +68,10 @@ */ #define PW_BLOCKSIZE 128 #define VML_TRANSCEDENTAL_THRESHOLD 8192 +#define VML_SIMPLE_OPS_THRESHOLD 8192 #define VML_ASM_THRESHOLD 100000 #define VML_D_THRESHOLD 8000 +#define SMALL_ARRAY_THRESHOLD 2048 #define MKL_INT_MAX ((npy_intp) ((~((MKL_UINT) 0)) >> 1)) @@ -130,7 +132,14 @@ /* for pointers p1, and p2 pointing at contiguous arrays n-elements of size s, are arrays disjoint or same * when these conditions are not met VML functions may produce incorrect output */ -#define DISJOINT_OR_SAME(p1, p2, n, s) (((p1) == (p2)) || ((p2) + (n)*(s) < (p1)) || ((p1) + (n)*(s) < (p2))) +static inline int disjoint_or_same_check(const void *p1, const void *p2, npy_intp n, npy_intp s) { + const char *_p1 = (const char*)p1; + const char *_p2 = (const char*)p2; + if (_p1 == _p2) return 1; + const npy_intp size = n * s; + return (_p2 + size <= _p1) || (_p1 + size <= _p2); +} +#define DISJOINT_OR_SAME(p1, p2, n, s) disjoint_or_same_check((p1), (p2), (n), (s)) #define DISJOINT_OR_SAME_TWO_DTYPES(p1, p2, n, s1, s2) (((p1) == (p2)) || ((p2) + (n)*(s2) < (p1)) || ((p1) + (n)*(s1) < (p2))) /* @@ -142,7 +151,6 @@ /** Provides the various *_LOOP macros */ #include "fast_loop_macros.h" -#include static inline npy_double spacing(npy_double x) { if (isinf(x)) @@ -230,6 +238,7 @@ divmod@c@(@type@ a, @type@ b, @type@ *modulus) * #TYPE = FLOAT, DOUBLE# * #c = f, # * #s = s, d# + * #sfx = f32, f64# */ /* @@ -244,7 +253,7 @@ divmod@c@(@type@ a, @type@ b, @type@ *modulus) #pragma intel optimization_level 2 #endif #endif -static @type@ +static inline @type@ pairwise_sum_@TYPE@(char *a, npy_intp n, npy_intp stride) { if (n < 8) { @@ -311,35 +320,224 @@ pairwise_sum_@TYPE@(char *a, npy_intp n, npy_intp stride) } } -/**begin repeat1 - * # kind = add, subtract# - * # OP = +, -# - * # PW = 1, 0# - * # VML = Add, Sub# - */ +/* Special npyv-accelerated add implementation */ +#include "npyv/npyv.h" + void -mkl_umath_@TYPE@_@kind@(char **args, const npy_intp *dimensions, const npy_intp *steps, void *NPY_UNUSED(func)) +mkl_umath_@TYPE@_add(char **args, const npy_intp *dimensions, const npy_intp *steps, void *NPY_UNUSED(func)) { + npy_intp len = dimensions[0]; + char *src0 = 
args[0], *src1 = args[1], *dst = args[2]; + npy_intp ssrc0 = steps[0], ssrc1 = steps[1], sdst = steps[2]; + + /* + * REDUCTION: For np.sum() operations + * Check FIRST, before any size-based optimizations + * Uses pairwise summation for numerical accuracy (O(log n) error vs O(n)) + */ + if (IS_BINARY_REDUCE) { + *((@type@*)src0) += pairwise_sum_@TYPE@(src1, len, ssrc1); + return; + } + + /* + * TIER 1: Tiny arrays - simple scalar loop + * Overhead of SIMD setup not worth it + */ + if (len < 32) { + for (; len > 0; --len, src0 += ssrc0, src1 += ssrc1, dst += sdst) { + const @type@ a = *(@type@*)src0; + const @type@ b = *(@type@*)src1; + *(@type@*)dst = a + b; + } + return; + } + + /* + * Calculate memory overlap checks once - reused by multiple tiers + */ + const int disjoint_or_same1 = DISJOINT_OR_SAME(src0, dst, len, sizeof(@type@)); + const int disjoint_or_same2 = DISJOINT_OR_SAME(src1, dst, len, sizeof(@type@)); + + /* + * TIER 3a: Large arrays with scalar broadcast - use MKL VML LinearFrac + * For operations like array + scalar or scalar + array + * LinearFrac computes: out = scaleA * in + shiftA + * For add: out = 1.0 * array + scalar + */ + // Pattern: scalar + array[i] (scalar is first argument with stride=0) + if (ssrc0 == 0 && ssrc1 == sizeof(@type@) && sdst == sizeof(@type@)) { + if (len >= VML_SIMPLE_OPS_THRESHOLD && disjoint_or_same2) { + // scalar + array[i] → 1.0 * array[i] + scalar + CHUNKED_VML_LINEARFRAC_CALL(v@s@LinearFrac, len, @type@, + src1, dst, 1.0, *(@type@*)src0, 0.0, 1.0); + return; + } + } + // Pattern: array[i] + scalar (scalar is second argument with stride=0) + else if (ssrc1 == 0 && ssrc0 == sizeof(@type@) && sdst == sizeof(@type@)) { + if (len >= VML_SIMPLE_OPS_THRESHOLD && disjoint_or_same1) { + // array[i] + scalar → 1.0 * array[i] + scalar + CHUNKED_VML_LINEARFRAC_CALL(v@s@LinearFrac, len, @type@, + src0, dst, 1.0, *(@type@*)src1, 0.0, 1.0); + return; + } + } + + /* + * TIER 3: Large arrays - use MKL VML + * MKL's overhead is amortized, can use multi-threading + */ + const int is_contiguous = (ssrc0 == sizeof(@type@) && + ssrc1 == sizeof(@type@) && + sdst == sizeof(@type@)); + + if (len >= VML_SIMPLE_OPS_THRESHOLD && is_contiguous) { + if (disjoint_or_same1 && disjoint_or_same2) { + CHUNKED_VML_CALL3(v@s@Add, len, @type@, @type@, src0, src1, dst); + return; + } + } + + /* + * TIER 2: Medium arrays - use npyv SIMD + * Better than scalar, no MKL overhead + */ +#if NPYV_CAN_VECTORIZE_@TYPE@ + if (!is_mem_overlap(src0, ssrc0, dst, sdst, len) && + !is_mem_overlap(src1, ssrc1, dst, sdst, len)) { + + const int vstep = npyv_nlanes_u8; // Byte step (32 for AVX2) + const int wstep = vstep * 2; // Double vector step + const int hstep = npyv_nlanes_@sfx@; // Element step (8 for f32/AVX2) + const int lstep = hstep * 2; // Double element step + + // PATTERN 1: Fully contiguous + if (ssrc0 == sizeof(@type@) && ssrc1 == sizeof(@type@) && sdst == sizeof(@type@)) { + // Main loop: process 2 vectors per iteration + for (; len >= lstep; len -= lstep, src0 += wstep, src1 += wstep, dst += wstep) { + npyv_@sfx@ a0 = npyv_load_@sfx@((const @type@*)src0); + npyv_@sfx@ a1 = npyv_load_@sfx@((const @type@*)(src0 + vstep)); + npyv_@sfx@ b0 = npyv_load_@sfx@((const @type@*)src1); + npyv_@sfx@ b1 = npyv_load_@sfx@((const @type@*)(src1 + vstep)); + npyv_@sfx@ r0 = npyv_add_@sfx@(a0, b0); + npyv_@sfx@ r1 = npyv_add_@sfx@(a1, b1); + npyv_store_@sfx@((@type@*)dst, r0); + npyv_store_@sfx@((@type@*)(dst + vstep), r1); + } + + // Remainder loop: handle trailing elements + for (; len > 0; len 
-= hstep, src0 += vstep, src1 += vstep, dst += vstep) { + npyv_@sfx@ a = npyv_load_tillz_@sfx@((const @type@*)src0, len); + npyv_@sfx@ b = npyv_load_tillz_@sfx@((const @type@*)src1, len); + npyv_@sfx@ r = npyv_add_@sfx@(a, b); + npyv_store_till_@sfx@((@type@*)dst, len, r); + } + + npyv_cleanup(); + return; + } + + // PATTERN 2: Broadcast first argument (e.g., scalar + array) + if (ssrc0 == 0 && ssrc1 == sizeof(@type@) && sdst == sizeof(@type@)) { + npyv_@sfx@ a = npyv_setall_@sfx@(*(@type@*)src0); + + for (; len >= lstep; len -= lstep, src1 += wstep, dst += wstep) { + npyv_@sfx@ b0 = npyv_load_@sfx@((const @type@*)src1); + npyv_@sfx@ b1 = npyv_load_@sfx@((const @type@*)(src1 + vstep)); + npyv_@sfx@ r0 = npyv_add_@sfx@(a, b0); + npyv_@sfx@ r1 = npyv_add_@sfx@(a, b1); + npyv_store_@sfx@((@type@*)dst, r0); + npyv_store_@sfx@((@type@*)(dst + vstep), r1); + } + + for (; len > 0; len -= hstep, src1 += vstep, dst += vstep) { + npyv_@sfx@ b = npyv_load_tillz_@sfx@((const @type@*)src1, len); + npyv_@sfx@ r = npyv_add_@sfx@(a, b); + npyv_store_till_@sfx@((@type@*)dst, len, r); + } + + npyv_cleanup(); + return; + } + + // PATTERN 3: Broadcast second argument (e.g., array + scalar) + if (ssrc1 == 0 && ssrc0 == sizeof(@type@) && sdst == sizeof(@type@)) { + npyv_@sfx@ b = npyv_setall_@sfx@(*(@type@*)src1); + + for (; len >= lstep; len -= lstep, src0 += wstep, dst += wstep) { + npyv_@sfx@ a0 = npyv_load_@sfx@((const @type@*)src0); + npyv_@sfx@ a1 = npyv_load_@sfx@((const @type@*)(src0 + vstep)); + npyv_@sfx@ r0 = npyv_add_@sfx@(a0, b); + npyv_@sfx@ r1 = npyv_add_@sfx@(a1, b); + npyv_store_@sfx@((@type@*)dst, r0); + npyv_store_@sfx@((@type@*)(dst + vstep), r1); + } + + for (; len > 0; len -= hstep, src0 += vstep, dst += vstep) { + npyv_@sfx@ a = npyv_load_tillz_@sfx@((const @type@*)src0, len); + npyv_@sfx@ r = npyv_add_@sfx@(a, b); + npyv_store_till_@sfx@((@type@*)dst, len, r); + } + + npyv_cleanup(); + return; + } + + // Future: Add strided memory patterns here + } +#endif // NPYV_CAN_VECTORIZE + + /* + * FALLBACK: Scalar loop for all other cases + * - Non-contiguous with complex strides + * - Memory overlap detected + * - SIMD not available + */ + for (; len > 0; --len, src0 += ssrc0, src1 += ssrc1, dst += sdst) { + const @type@ a = *(@type@*)src0; + const @type@ b = *(@type@*)src1; + *(@type@*)dst = a + b; + } +} + +/* Keep subtract with original implementation for now */ +void +mkl_umath_@TYPE@_subtract(char **args, const npy_intp *dimensions, const npy_intp *steps, void *NPY_UNUSED(func)) +{ + const npy_intp n = dimensions[0]; + + /* FAST PATH: Small contiguous arrays - avoid all overhead */ + if (IS_BINARY_CONT(@type@, @type@) && n <= SMALL_ARRAY_THRESHOLD) { + @type@ *ip1 = (@type@*)args[0]; + @type@ *ip2 = (@type@*)args[1]; + @type@ *op1 = (@type@*)args[2]; + + NPY_PRAGMA_VECTOR + for(npy_intp i = 0; i < n; i++) { + op1[i] = ip1[i] - ip2[i]; + } + return; + } + const int disjoint_or_same1 = DISJOINT_OR_SAME(args[0], args[2], dimensions[0], sizeof(@type@)); const int disjoint_or_same2 = DISJOINT_OR_SAME(args[1], args[2], dimensions[0], sizeof(@type@)); if (IS_BINARY_CONT(@type@, @type@)) { - if (dimensions[0] > VML_ASM_THRESHOLD && disjoint_or_same1 && disjoint_or_same2) { - CHUNKED_VML_CALL3(v@s@@VML@, dimensions[0], @type@, @type@, args[0], args[1], args[2]); - /* v@s@@VML@(dimensions[0], (@type@*) args[0], (@type@*) args[1], (@type@*) args[2]); */ + if (dimensions[0] > VML_SIMPLE_OPS_THRESHOLD && disjoint_or_same1 && disjoint_or_same2) { + CHUNKED_VML_CALL3(v@s@Sub, dimensions[0], @type@, 
@type@, args[0], args[1], args[2]); } else { @type@ *ip1 = (@type@*)args[0]; @type@ *ip2 = (@type@*)args[1]; @type@ *op1 = (@type@*)args[2]; const npy_intp vsize = 64; - const npy_intp n = dimensions[0]; const npy_intp peel = npy_aligned_block_offset(ip1, sizeof(@type@), vsize, n); const npy_intp blocked_end = npy_blocked_end(peel, sizeof(@type@), vsize, n); npy_intp i; for(i = 0; i < peel; i++) { - op1[i] = ip1[i] @OP@ ip2[i]; + op1[i] = ip1[i] - ip2[i]; } { @@ -349,19 +547,10 @@ mkl_umath_@TYPE@_@kind@(char **args, const npy_intp *dimensions, const npy_intp @type@ *op1_shifted = op1 + peel; @type@ *ip2_shifted = ip2 + peel; - if (DISJOINT_OR_SAME(op1_shifted, ip1_aligned, j_max, 1) && - DISJOINT_OR_SAME(op1_shifted, ip2_shifted, j_max, 1)) { - NPY_ASSUME_ALIGNED(ip1_aligned, 64) - NPY_PRAGMA_VECTOR - for(j = 0; j < j_max; j++) { - op1_shifted[j] = ip1_aligned[j] @OP@ ip2_shifted[j]; - } - } - else { - NPY_ASSUME_ALIGNED(ip1_aligned, 64) - for(j = 0; j < j_max; j++) { - op1_shifted[j] = ip1_aligned[j] @OP@ ip2_shifted[j]; - } + NPY_ASSUME_ALIGNED(ip1_aligned, 64) + NPY_PRAGMA_VECTOR + for(j = 0; j < j_max; j++) { + op1_shifted[j] = ip1_aligned[j] - ip2_shifted[j]; } i = blocked_end; @@ -369,14 +558,13 @@ mkl_umath_@TYPE@_@kind@(char **args, const npy_intp *dimensions, const npy_intp } for(; i < n; i++) { - op1[i] = ip1[i] @OP@ ip2[i]; + op1[i] = ip1[i] - ip2[i]; } } } else if (IS_BINARY_CONT_S1(@type@, @type@)) { - if (dimensions[0] > VML_ASM_THRESHOLD && disjoint_or_same2) { - CHUNKED_VML_LINEARFRAC_CALL(v@s@LinearFrac, dimensions[0], @type@, args[1], args[2], @OP@1.0, *(@type@*)args[0], 0.0, 1.0); - /* v@s@LinearFrac(dimensions[0], (@type@*) args[1], (@type@*) args[1], @OP@1.0, *(@type@*)args[0], 0.0, 1.0, (@type@*) args[2]); */ + if (dimensions[0] > VML_SIMPLE_OPS_THRESHOLD && disjoint_or_same2) { + CHUNKED_VML_LINEARFRAC_CALL(v@s@LinearFrac, dimensions[0], @type@, args[1], args[2], -1.0, *(@type@*)args[0], 0.0, 1.0); } else { @type@ *ip1 = (@type@*)args[0]; @@ -390,7 +578,7 @@ mkl_umath_@TYPE@_@kind@(char **args, const npy_intp *dimensions, const npy_intp const @type@ ip1c = ip1[0]; for(i = 0; i < peel; i++) { - op1[i] = ip1c @OP@ ip2[i]; + op1[i] = ip1c - ip2[i]; } { @@ -401,7 +589,7 @@ mkl_umath_@TYPE@_@kind@(char **args, const npy_intp *dimensions, const npy_intp NPY_ASSUME_ALIGNED(ip2_aligned, 64) for(j = 0; j < j_max; j++) { - op1_shifted[j] = ip1c @OP@ ip2_aligned[j]; + op1_shifted[j] = ip1c - ip2_aligned[j]; } i = blocked_end; @@ -409,14 +597,13 @@ mkl_umath_@TYPE@_@kind@(char **args, const npy_intp *dimensions, const npy_intp } for(; i < n; i++) { - op1[i] = ip1c @OP@ ip2[i]; + op1[i] = ip1c - ip2[i]; } } } else if (IS_BINARY_CONT_S2(@type@, @type@)) { - if (dimensions[0] > VML_ASM_THRESHOLD && disjoint_or_same1) { - CHUNKED_VML_LINEARFRAC_CALL(v@s@LinearFrac, dimensions[0], @type@, args[0], args[2], 1.0, @OP@(*(@type@*)args[1]), 0.0, 1.0); - /* v@s@LinearFrac(dimensions[0], (@type@*) args[0], (@type@*) args[0], 1.0, @OP@(*(@type@*)args[1]), 0.0, 1.0, (@type@*) args[2]); */ + if (dimensions[0] > VML_SIMPLE_OPS_THRESHOLD && disjoint_or_same1) { + CHUNKED_VML_LINEARFRAC_CALL(v@s@LinearFrac, dimensions[0], @type@, args[0], args[2], 1.0, -(*(@type@*)args[1]), 0.0, 1.0); } else { @type@ *ip1 = (@type@*)args[0]; @@ -430,7 +617,7 @@ mkl_umath_@TYPE@_@kind@(char **args, const npy_intp *dimensions, const npy_intp const @type@ ip2c = ip2[0]; for(i = 0; i < peel; i++) { - op1[i] = ip1[i] @OP@ ip2c; + op1[i] = ip1[i] - ip2c; } { @@ -441,7 +628,7 @@ mkl_umath_@TYPE@_@kind@(char **args, 
const npy_intp *dimensions, const npy_intp NPY_ASSUME_ALIGNED(ip1_aligned, 64) for(j = 0; j < j_max; j++) { - op1_shifted[j] = ip1_aligned[j] @OP@ ip2c; + op1_shifted[j] = ip1_aligned[j] - ip2c; } i = blocked_end; @@ -449,41 +636,48 @@ mkl_umath_@TYPE@_@kind@(char **args, const npy_intp *dimensions, const npy_intp } for(; i < n; i++) { - op1[i] = ip1[i] @OP@ ip2c; + op1[i] = ip1[i] - ip2c; } } } else if (IS_BINARY_REDUCE) { -#if @PW@ - @type@ * iop1 = (@type@ *)args[0]; - npy_intp n = dimensions[0]; - - *iop1 @OP@= pairwise_sum_@TYPE@(args[1], n, steps[1]); -#else BINARY_REDUCE_LOOP(@type@) { - io1 @OP@= *(@type@ *)ip2; + io1 -= *(@type@ *)ip2; } *((@type@ *)iop1) = io1; -#endif } else { BINARY_LOOP { const @type@ in1 = *(@type@ *)ip1; const @type@ in2 = *(@type@ *)ip2; - *((@type@ *)op1) = in1 @OP@ in2; + *((@type@ *)op1) = in1 - in2; } } } -/**end repeat1**/ void mkl_umath_@TYPE@_multiply(char **args, const npy_intp *dimensions, const npy_intp *steps, void *NPY_UNUSED(func)) { + const npy_intp n = dimensions[0]; + + /* FAST PATH: Small contiguous arrays - avoid all overhead */ + if (IS_BINARY_CONT(@type@, @type@) && n <= SMALL_ARRAY_THRESHOLD) { + @type@ *ip1 = (@type@*)args[0]; + @type@ *ip2 = (@type@*)args[1]; + @type@ *op1 = (@type@*)args[2]; + + NPY_PRAGMA_VECTOR + for(npy_intp i = 0; i < n; i++) { + op1[i] = ip1[i] * ip2[i]; + } + return; + } + const int disjoint_or_same1 = DISJOINT_OR_SAME(args[0], args[2], dimensions[0], sizeof(@type@)); const int disjoint_or_same2 = DISJOINT_OR_SAME(args[1], args[2], dimensions[0], sizeof(@type@)); if (IS_BINARY_CONT(@type@, @type@)) { - if (dimensions[0] > VML_ASM_THRESHOLD && disjoint_or_same1 && disjoint_or_same2) { + if (dimensions[0] > VML_SIMPLE_OPS_THRESHOLD && disjoint_or_same1 && disjoint_or_same2) { CHUNKED_VML_CALL3(v@s@Mul, dimensions[0], @type@, @type@, args[0], args[1], args[2]); /* v@s@Mul(dimensions[0], (@type@*) args[0], (@type@*) args[1], (@type@*) args[2]); */ } @@ -492,7 +686,6 @@ mkl_umath_@TYPE@_multiply(char **args, const npy_intp *dimensions, const npy_int @type@ *ip2 = (@type@*)args[1]; @type@ *op1 = (@type@*)args[2]; const npy_intp vsize = 64; - const npy_intp n = dimensions[0]; const npy_intp peel = npy_aligned_block_offset(ip1, sizeof(@type@), vsize, n); const npy_intp blocked_end = npy_blocked_end(peel, sizeof(@type@), vsize, n); npy_intp i; @@ -508,19 +701,11 @@ mkl_umath_@TYPE@_multiply(char **args, const npy_intp *dimensions, const npy_int @type@ *op1_shifted = op1 + peel; @type@ *ip2_shifted = ip2 + peel; - if ( DISJOINT_OR_SAME(op1_shifted, ip1_aligned, j_max, 1) && - DISJOINT_OR_SAME(op1_shifted, ip2_shifted, j_max, 1)) { - NPY_ASSUME_ALIGNED(ip1_aligned, 64) - NPY_PRAGMA_VECTOR - for(j = 0; j < j_max; j++) { - op1_shifted[j] = ip1_aligned[j] * ip2_shifted[j]; - } - } - else { - NPY_ASSUME_ALIGNED(ip1_aligned, 64) - for(j = 0; j < j_max; j++) { - op1_shifted[j] = ip1_aligned[j] * ip2_shifted[j]; - } + /* Redundant check removed - already verified disjoint_or_same above */ + NPY_ASSUME_ALIGNED(ip1_aligned, 64) + NPY_PRAGMA_VECTOR + for(j = 0; j < j_max; j++) { + op1_shifted[j] = ip1_aligned[j] * ip2_shifted[j]; } i = blocked_end; @@ -533,7 +718,7 @@ mkl_umath_@TYPE@_multiply(char **args, const npy_intp *dimensions, const npy_int } } else if (IS_BINARY_CONT_S1(@type@, @type@)) { - if (dimensions[0] > VML_ASM_THRESHOLD && disjoint_or_same2) { + if (dimensions[0] > VML_SIMPLE_OPS_THRESHOLD && disjoint_or_same2) { CHUNKED_VML_LINEARFRAC_CALL(v@s@LinearFrac, dimensions[0], @type@, args[1], args[2], *(@type@*)args[0], 
0.0, 0.0, 1.0); /* v@s@LinearFrac(dimensions[0], (@type@*) args[1], (@type@*) args[1], *(@type@*)args[0], 0.0, 0.0, 1.0, (@type@*) args[2]); */ } diff --git a/mkl_umath/src/npyv/array_assign.h b/mkl_umath/src/npyv/array_assign.h new file mode 100644 index 00000000..8a28ed1d --- /dev/null +++ b/mkl_umath/src/npyv/array_assign.h @@ -0,0 +1,118 @@ +#ifndef NUMPY_CORE_SRC_COMMON_ARRAY_ASSIGN_H_ +#define NUMPY_CORE_SRC_COMMON_ARRAY_ASSIGN_H_ + +/* + * An array assignment function for copying arrays, treating the + * arrays as flat according to their respective ordering rules. + * This function makes a temporary copy of 'src' if 'src' and + * 'dst' overlap, to be able to handle views of the same data with + * different strides. + * + * dst: The destination array. + * dst_order: The rule for how 'dst' is to be made flat. + * src: The source array. + * src_order: The rule for how 'src' is to be made flat. + * casting: An exception is raised if the copy violates this + * casting rule. + * + * Returns 0 on success, -1 on failure. + */ +/* Not yet implemented +NPY_NO_EXPORT int +PyArray_AssignArrayAsFlat(PyArrayObject *dst, NPY_ORDER dst_order, + PyArrayObject *src, NPY_ORDER src_order, + NPY_CASTING casting, + npy_bool preservena, npy_bool *preservewhichna); +*/ + +NPY_NO_EXPORT int +PyArray_AssignArray(PyArrayObject *dst, PyArrayObject *src, + PyArrayObject *wheremask, + NPY_CASTING casting); + +NPY_NO_EXPORT int +PyArray_AssignRawScalar(PyArrayObject *dst, + PyArray_Descr *src_dtype, char *src_data, + PyArrayObject *wheremask, + NPY_CASTING casting); + +/******** LOW-LEVEL SCALAR TO ARRAY ASSIGNMENT ********/ + +/* + * Assigns the scalar value to every element of the destination raw array. + * + * Returns 0 on success, -1 on failure. + */ +NPY_NO_EXPORT int +raw_array_assign_scalar(int ndim, npy_intp const *shape, + PyArray_Descr *dst_dtype, char *dst_data, npy_intp const *dst_strides, + PyArray_Descr *src_dtype, char *src_data); + +/* + * Assigns the scalar value to every element of the destination raw array + * where the 'wheremask' value is True. + * + * Returns 0 on success, -1 on failure. + */ +NPY_NO_EXPORT int +raw_array_wheremasked_assign_scalar(int ndim, npy_intp const *shape, + PyArray_Descr *dst_dtype, char *dst_data, npy_intp const *dst_strides, + PyArray_Descr *src_dtype, char *src_data, + PyArray_Descr *wheremask_dtype, char *wheremask_data, + npy_intp const *wheremask_strides); + +/******** LOW-LEVEL ARRAY MANIPULATION HELPERS ********/ + +/* + * Internal detail of how much to buffer during array assignments which + * need it. This is for more complex NA masking operations where masks + * need to be inverted or combined together. + */ +#define NPY_ARRAY_ASSIGN_BUFFERSIZE 8192 + +/* + * Broadcasts strides to match the given dimensions. Can be used, + * for instance, to set up a raw iteration. + * + * 'strides_name' is used to produce an error message if the strides + * cannot be broadcast. + * + * Returns 0 on success, -1 on failure. + */ +NPY_NO_EXPORT int +broadcast_strides(int ndim, npy_intp const *shape, + int strides_ndim, npy_intp const *strides_shape, npy_intp const *strides, + char const *strides_name, + npy_intp *out_strides); + +/* + * Checks whether a data pointer + set of strides refers to a raw + * array whose elements are all aligned to a given alignment. Returns + * 1 if data is aligned to alignment or 0 if not. + * alignment should be a power of two, or may be the sentinel value 0 to mean + * cannot-be-aligned, in which case 0 (false) is always returned. 
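+ *
+ * Illustrative sketch of the check (assumed semantics, not necessarily the
+ * exact implementation): OR the base pointer with the stride of every
+ * dimension whose length is greater than 1, then test the low bits:
+ *   npy_uintp bits = (npy_uintp)data;
+ *   for (int i = 0; i < ndim; ++i) if (shape[i] > 1) bits |= (npy_uintp)strides[i];
+ *   return alignment && (bits & (alignment - 1)) == 0;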
+ */ +NPY_NO_EXPORT int +raw_array_is_aligned(int ndim, npy_intp const *shape, + char *data, npy_intp const *strides, int alignment); + +/* + * Checks if an array is aligned to its "true alignment" + * given by dtype->alignment. + */ +NPY_NO_EXPORT int +IsAligned(PyArrayObject *ap); + +/* + * Checks if an array is aligned to its "uint alignment" + * given by npy_uint_alignment(dtype->elsize). + */ +NPY_NO_EXPORT int +IsUintAligned(PyArrayObject *ap); + +/* Returns 1 if the arrays have overlapping data, 0 otherwise */ +NPY_NO_EXPORT int +arrays_overlap(PyArrayObject *arr1, PyArrayObject *arr2); + + +#endif /* NUMPY_CORE_SRC_COMMON_ARRAY_ASSIGN_H_ */ diff --git a/mkl_umath/src/npyv/avx2/arithmetic.h b/mkl_umath/src/npyv/avx2/arithmetic.h new file mode 100644 index 00000000..58d842a6 --- /dev/null +++ b/mkl_umath/src/npyv/avx2/arithmetic.h @@ -0,0 +1,347 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_AVX2_ARITHMETIC_H +#define _NPY_SIMD_AVX2_ARITHMETIC_H + +#include "../sse/utils.h" +/*************************** + * Addition + ***************************/ +// non-saturated +#define npyv_add_u8 _mm256_add_epi8 +#define npyv_add_s8 _mm256_add_epi8 +#define npyv_add_u16 _mm256_add_epi16 +#define npyv_add_s16 _mm256_add_epi16 +#define npyv_add_u32 _mm256_add_epi32 +#define npyv_add_s32 _mm256_add_epi32 +#define npyv_add_u64 _mm256_add_epi64 +#define npyv_add_s64 _mm256_add_epi64 +#define npyv_add_f32 _mm256_add_ps +#define npyv_add_f64 _mm256_add_pd + +// saturated +#define npyv_adds_u8 _mm256_adds_epu8 +#define npyv_adds_s8 _mm256_adds_epi8 +#define npyv_adds_u16 _mm256_adds_epu16 +#define npyv_adds_s16 _mm256_adds_epi16 +// TODO: rest, after implement Packs intrins + +/*************************** + * Subtraction + ***************************/ +// non-saturated +#define npyv_sub_u8 _mm256_sub_epi8 +#define npyv_sub_s8 _mm256_sub_epi8 +#define npyv_sub_u16 _mm256_sub_epi16 +#define npyv_sub_s16 _mm256_sub_epi16 +#define npyv_sub_u32 _mm256_sub_epi32 +#define npyv_sub_s32 _mm256_sub_epi32 +#define npyv_sub_u64 _mm256_sub_epi64 +#define npyv_sub_s64 _mm256_sub_epi64 +#define npyv_sub_f32 _mm256_sub_ps +#define npyv_sub_f64 _mm256_sub_pd + +// saturated +#define npyv_subs_u8 _mm256_subs_epu8 +#define npyv_subs_s8 _mm256_subs_epi8 +#define npyv_subs_u16 _mm256_subs_epu16 +#define npyv_subs_s16 _mm256_subs_epi16 +// TODO: rest, after implement Packs intrins + +/*************************** + * Multiplication + ***************************/ +// non-saturated +#define npyv_mul_u8 npyv256_mul_u8 +#define npyv_mul_s8 npyv_mul_u8 +#define npyv_mul_u16 _mm256_mullo_epi16 +#define npyv_mul_s16 _mm256_mullo_epi16 +#define npyv_mul_u32 _mm256_mullo_epi32 +#define npyv_mul_s32 _mm256_mullo_epi32 +#define npyv_mul_f32 _mm256_mul_ps +#define npyv_mul_f64 _mm256_mul_pd + +// saturated +// TODO: after implement Packs intrins + +/*************************** + * Integer Division + ***************************/ +// See simd/intdiv.h for more clarification +// divide each unsigned 8-bit element by a precomputed divisor +NPY_FINLINE npyv_u8 npyv_divc_u8(npyv_u8 a, const npyv_u8x3 divisor) +{ + const __m256i bmask = _mm256_set1_epi32(0x00FF00FF); + const __m128i shf1 = _mm256_castsi256_si128(divisor.val[1]); + const __m128i shf2 = _mm256_castsi256_si128(divisor.val[2]); + const __m256i shf1b = _mm256_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf1)); + const __m256i shf2b = _mm256_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf2)); + // high part of unsigned multiplication + __m256i mulhi_even = 
_mm256_mullo_epi16(_mm256_and_si256(a, bmask), divisor.val[0]); + mulhi_even = _mm256_srli_epi16(mulhi_even, 8); + __m256i mulhi_odd = _mm256_mullo_epi16(_mm256_srli_epi16(a, 8), divisor.val[0]); + __m256i mulhi = _mm256_blendv_epi8(mulhi_odd, mulhi_even, bmask); + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + __m256i q = _mm256_sub_epi8(a, mulhi); + q = _mm256_and_si256(_mm256_srl_epi16(q, shf1), shf1b); + q = _mm256_add_epi8(mulhi, q); + q = _mm256_and_si256(_mm256_srl_epi16(q, shf2), shf2b); + return q; +} +// divide each signed 8-bit element by a precomputed divisor (round towards zero) +NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor); +NPY_FINLINE npyv_s8 npyv_divc_s8(npyv_s8 a, const npyv_s8x3 divisor) +{ + const __m256i bmask = _mm256_set1_epi32(0x00FF00FF); + // instead of _mm256_cvtepi8_epi16/_mm256_packs_epi16 to wrap around overflow + __m256i divc_even = npyv_divc_s16(_mm256_srai_epi16(_mm256_slli_epi16(a, 8), 8), divisor); + __m256i divc_odd = npyv_divc_s16(_mm256_srai_epi16(a, 8), divisor); + divc_odd = _mm256_slli_epi16(divc_odd, 8); + return _mm256_blendv_epi8(divc_odd, divc_even, bmask); +} +// divide each unsigned 16-bit element by a precomputed divisor +NPY_FINLINE npyv_u16 npyv_divc_u16(npyv_u16 a, const npyv_u16x3 divisor) +{ + const __m128i shf1 = _mm256_castsi256_si128(divisor.val[1]); + const __m128i shf2 = _mm256_castsi256_si128(divisor.val[2]); + // high part of unsigned multiplication + __m256i mulhi = _mm256_mulhi_epu16(a, divisor.val[0]); + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + __m256i q = _mm256_sub_epi16(a, mulhi); + q = _mm256_srl_epi16(q, shf1); + q = _mm256_add_epi16(mulhi, q); + q = _mm256_srl_epi16(q, shf2); + return q; +} +// divide each signed 16-bit element by a precomputed divisor (round towards zero) +NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor) +{ + const __m128i shf1 = _mm256_castsi256_si128(divisor.val[1]); + // high part of signed multiplication + __m256i mulhi = _mm256_mulhi_epi16(a, divisor.val[0]); + // q = ((a + mulhi) >> sh1) - XSIGN(a) + // trunc(a/d) = (q ^ dsign) - dsign + __m256i q = _mm256_sra_epi16(_mm256_add_epi16(a, mulhi), shf1); + q = _mm256_sub_epi16(q, _mm256_srai_epi16(a, 15)); + q = _mm256_sub_epi16(_mm256_xor_si256(q, divisor.val[2]), divisor.val[2]); + return q; +} +// divide each unsigned 32-bit element by a precomputed divisor +NPY_FINLINE npyv_u32 npyv_divc_u32(npyv_u32 a, const npyv_u32x3 divisor) +{ + const __m128i shf1 = _mm256_castsi256_si128(divisor.val[1]); + const __m128i shf2 = _mm256_castsi256_si128(divisor.val[2]); + // high part of unsigned multiplication + __m256i mulhi_even = _mm256_srli_epi64(_mm256_mul_epu32(a, divisor.val[0]), 32); + __m256i mulhi_odd = _mm256_mul_epu32(_mm256_srli_epi64(a, 32), divisor.val[0]); + __m256i mulhi = _mm256_blend_epi32(mulhi_even, mulhi_odd, 0xAA); + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + __m256i q = _mm256_sub_epi32(a, mulhi); + q = _mm256_srl_epi32(q, shf1); + q = _mm256_add_epi32(mulhi, q); + q = _mm256_srl_epi32(q, shf2); + return q; +} +// divide each signed 32-bit element by a precomputed divisor (round towards zero) +NPY_FINLINE npyv_s32 npyv_divc_s32(npyv_s32 a, const npyv_s32x3 divisor) +{ + const __m128i shf1 = _mm256_castsi256_si128(divisor.val[1]); + // high part of signed multiplication + __m256i mulhi_even = _mm256_srli_epi64(_mm256_mul_epi32(a, divisor.val[0]), 32); + __m256i mulhi_odd = _mm256_mul_epi32(_mm256_srli_epi64(a, 32), divisor.val[0]); + __m256i mulhi = 
_mm256_blend_epi32(mulhi_even, mulhi_odd, 0xAA); + // q = ((a + mulhi) >> sh1) - XSIGN(a) + // trunc(a/d) = (q ^ dsign) - dsign + __m256i q = _mm256_sra_epi32(_mm256_add_epi32(a, mulhi), shf1); + q = _mm256_sub_epi32(q, _mm256_srai_epi32(a, 31)); + q = _mm256_sub_epi32(_mm256_xor_si256(q, divisor.val[2]), divisor.val[2]); + return q; +} +// returns the high 64 bits of unsigned 64-bit multiplication +// xref https://stackoverflow.com/a/28827013 +NPY_FINLINE npyv_u64 npyv__mullhi_u64(npyv_u64 a, npyv_u64 b) +{ + __m256i lomask = npyv_setall_s64(0xffffffff); + __m256i a_hi = _mm256_srli_epi64(a, 32); // a0l, a0h, a1l, a1h + __m256i b_hi = _mm256_srli_epi64(b, 32); // b0l, b0h, b1l, b1h + // compute partial products + __m256i w0 = _mm256_mul_epu32(a, b); // a0l*b0l, a1l*b1l + __m256i w1 = _mm256_mul_epu32(a, b_hi); // a0l*b0h, a1l*b1h + __m256i w2 = _mm256_mul_epu32(a_hi, b); // a0h*b0l, a1h*b0l + __m256i w3 = _mm256_mul_epu32(a_hi, b_hi); // a0h*b0h, a1h*b1h + // sum partial products + __m256i w0h = _mm256_srli_epi64(w0, 32); + __m256i s1 = _mm256_add_epi64(w1, w0h); + __m256i s1l = _mm256_and_si256(s1, lomask); + __m256i s1h = _mm256_srli_epi64(s1, 32); + + __m256i s2 = _mm256_add_epi64(w2, s1l); + __m256i s2h = _mm256_srli_epi64(s2, 32); + + __m256i hi = _mm256_add_epi64(w3, s1h); + hi = _mm256_add_epi64(hi, s2h); + return hi; +} +// divide each unsigned 64-bit element by a divisor +NPY_FINLINE npyv_u64 npyv_divc_u64(npyv_u64 a, const npyv_u64x3 divisor) +{ + const __m128i shf1 = _mm256_castsi256_si128(divisor.val[1]); + const __m128i shf2 = _mm256_castsi256_si128(divisor.val[2]); + // high part of unsigned multiplication + __m256i mulhi = npyv__mullhi_u64(a, divisor.val[0]); + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + __m256i q = _mm256_sub_epi64(a, mulhi); + q = _mm256_srl_epi64(q, shf1); + q = _mm256_add_epi64(mulhi, q); + q = _mm256_srl_epi64(q, shf2); + return q; +} +// divide each unsigned 64-bit element by a divisor (round towards zero) +NPY_FINLINE npyv_s64 npyv_divc_s64(npyv_s64 a, const npyv_s64x3 divisor) +{ + const __m128i shf1 = _mm256_castsi256_si128(divisor.val[1]); + // high part of unsigned multiplication + __m256i mulhi = npyv__mullhi_u64(a, divisor.val[0]); + // convert unsigned to signed high multiplication + // mulhi - ((a < 0) ? m : 0) - ((m < 0) ? 
a : 0); + __m256i asign = _mm256_cmpgt_epi64(_mm256_setzero_si256(), a); + __m256i msign = _mm256_cmpgt_epi64(_mm256_setzero_si256(), divisor.val[0]); + __m256i m_asign = _mm256_and_si256(divisor.val[0], asign); + __m256i a_msign = _mm256_and_si256(a, msign); + mulhi = _mm256_sub_epi64(mulhi, m_asign); + mulhi = _mm256_sub_epi64(mulhi, a_msign); + // q = (a + mulhi) >> sh + __m256i q = _mm256_add_epi64(a, mulhi); + // emulate arithmetic right shift + const __m256i sigb = npyv_setall_s64(1LL << 63); + q = _mm256_srl_epi64(_mm256_add_epi64(q, sigb), shf1); + q = _mm256_sub_epi64(q, _mm256_srl_epi64(sigb, shf1)); + // q = q - XSIGN(a) + // trunc(a/d) = (q ^ dsign) - dsign + q = _mm256_sub_epi64(q, asign); + q = _mm256_sub_epi64(_mm256_xor_si256(q, divisor.val[2]), divisor.val[2]); + return q; +} +/*************************** + * Division + ***************************/ +// TODO: emulate integer division +#define npyv_div_f32 _mm256_div_ps +#define npyv_div_f64 _mm256_div_pd + +/*************************** + * FUSED + ***************************/ +#ifdef NPY_HAVE_FMA3 + // multiply and add, a*b + c + #define npyv_muladd_f32 _mm256_fmadd_ps + #define npyv_muladd_f64 _mm256_fmadd_pd + // multiply and subtract, a*b - c + #define npyv_mulsub_f32 _mm256_fmsub_ps + #define npyv_mulsub_f64 _mm256_fmsub_pd + // negate multiply and add, -(a*b) + c + #define npyv_nmuladd_f32 _mm256_fnmadd_ps + #define npyv_nmuladd_f64 _mm256_fnmadd_pd + // negate multiply and subtract, -(a*b) - c + #define npyv_nmulsub_f32 _mm256_fnmsub_ps + #define npyv_nmulsub_f64 _mm256_fnmsub_pd + // multiply, add for odd elements and subtract even elements. + // (a * b) -+ c + #define npyv_muladdsub_f32 _mm256_fmaddsub_ps + #define npyv_muladdsub_f64 _mm256_fmaddsub_pd +#else + // multiply and add, a*b + c + NPY_FINLINE npyv_f32 npyv_muladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { return npyv_add_f32(npyv_mul_f32(a, b), c); } + NPY_FINLINE npyv_f64 npyv_muladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { return npyv_add_f64(npyv_mul_f64(a, b), c); } + // multiply and subtract, a*b - c + NPY_FINLINE npyv_f32 npyv_mulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { return npyv_sub_f32(npyv_mul_f32(a, b), c); } + NPY_FINLINE npyv_f64 npyv_mulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { return npyv_sub_f64(npyv_mul_f64(a, b), c); } + // negate multiply and add, -(a*b) + c + NPY_FINLINE npyv_f32 npyv_nmuladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { return npyv_sub_f32(c, npyv_mul_f32(a, b)); } + NPY_FINLINE npyv_f64 npyv_nmuladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { return npyv_sub_f64(c, npyv_mul_f64(a, b)); } + // negate multiply and subtract, -(a*b) - c + NPY_FINLINE npyv_f32 npyv_nmulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { + npyv_f32 neg_a = npyv_xor_f32(a, npyv_setall_f32(-0.0f)); + return npyv_sub_f32(npyv_mul_f32(neg_a, b), c); + } + NPY_FINLINE npyv_f64 npyv_nmulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { + npyv_f64 neg_a = npyv_xor_f64(a, npyv_setall_f64(-0.0)); + return npyv_sub_f64(npyv_mul_f64(neg_a, b), c); + } + // multiply, add for odd elements and subtract even elements. 
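+    // i.e. per lane: r[i] = a[i]*b[i] - c[i] for even i and a[i]*b[i] + c[i] for odd i,
+    // matching _mm256_addsub_ps/_mm256_addsub_pd, which subtract in even lanes and add in odd lanes.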
+ // (a * b) -+ c + NPY_FINLINE npyv_f32 npyv_muladdsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { return _mm256_addsub_ps(npyv_mul_f32(a, b), c); } + NPY_FINLINE npyv_f64 npyv_muladdsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { return _mm256_addsub_pd(npyv_mul_f64(a, b), c); } + +#endif // !NPY_HAVE_FMA3 + +/*************************** + * Summation + ***************************/ +// reduce sum across vector +NPY_FINLINE npy_uint32 npyv_sum_u32(npyv_u32 a) +{ + __m256i s0 = _mm256_hadd_epi32(a, a); + s0 = _mm256_hadd_epi32(s0, s0); + __m128i s1 = _mm256_extracti128_si256(s0, 1); + s1 = _mm_add_epi32(_mm256_castsi256_si128(s0), s1); + return _mm_cvtsi128_si32(s1); +} + +NPY_FINLINE npy_uint64 npyv_sum_u64(npyv_u64 a) +{ + __m256i two = _mm256_add_epi64(a, _mm256_shuffle_epi32(a, _MM_SHUFFLE(1, 0, 3, 2))); + __m128i one = _mm_add_epi64(_mm256_castsi256_si128(two), _mm256_extracti128_si256(two, 1)); + return (npy_uint64)npyv128_cvtsi128_si64(one); +} + +NPY_FINLINE float npyv_sum_f32(npyv_f32 a) +{ + __m256 sum_halves = _mm256_hadd_ps(a, a); + sum_halves = _mm256_hadd_ps(sum_halves, sum_halves); + __m128 lo = _mm256_castps256_ps128(sum_halves); + __m128 hi = _mm256_extractf128_ps(sum_halves, 1); + __m128 sum = _mm_add_ps(lo, hi); + return _mm_cvtss_f32(sum); +} + +NPY_FINLINE double npyv_sum_f64(npyv_f64 a) +{ + __m256d sum_halves = _mm256_hadd_pd(a, a); + __m128d lo = _mm256_castpd256_pd128(sum_halves); + __m128d hi = _mm256_extractf128_pd(sum_halves, 1); + __m128d sum = _mm_add_pd(lo, hi); + return _mm_cvtsd_f64(sum); +} + +// expand the source vector and performs sum reduce +NPY_FINLINE npy_uint16 npyv_sumup_u8(npyv_u8 a) +{ + __m256i four = _mm256_sad_epu8(a, _mm256_setzero_si256()); + __m128i two = _mm_add_epi16(_mm256_castsi256_si128(four), _mm256_extracti128_si256(four, 1)); + __m128i one = _mm_add_epi16(two, _mm_unpackhi_epi64(two, two)); + return (npy_uint16)_mm_cvtsi128_si32(one); +} + +NPY_FINLINE npy_uint32 npyv_sumup_u16(npyv_u16 a) +{ + const npyv_u16 even_mask = _mm256_set1_epi32(0x0000FFFF); + __m256i even = _mm256_and_si256(a, even_mask); + __m256i odd = _mm256_srli_epi32(a, 16); + __m256i eight = _mm256_add_epi32(even, odd); + return npyv_sum_u32(eight); +} + +#endif // _NPY_SIMD_AVX2_ARITHMETIC_H diff --git a/mkl_umath/src/npyv/avx2/avx2.h b/mkl_umath/src/npyv/avx2/avx2.h new file mode 100644 index 00000000..d64f3c6d --- /dev/null +++ b/mkl_umath/src/npyv/avx2/avx2.h @@ -0,0 +1,77 @@ +#ifndef _NPY_SIMD_H_ + #error "Not a standalone header" +#endif +#define NPY_SIMD 256 +#define NPY_SIMD_WIDTH 32 +#define NPY_SIMD_F32 1 +#define NPY_SIMD_F64 1 +#ifdef NPY_HAVE_FMA3 + #define NPY_SIMD_FMA3 1 // native support +#else + #define NPY_SIMD_FMA3 0 // fast emulated +#endif +#define NPY_SIMD_BIGENDIAN 0 +#define NPY_SIMD_CMPSIGNAL 0 +// Enough limit to allow us to use _mm256_i32gather_* +#define NPY_SIMD_MAXLOAD_STRIDE32 (0x7fffffff / 8) + +typedef __m256i npyv_u8; +typedef __m256i npyv_s8; +typedef __m256i npyv_u16; +typedef __m256i npyv_s16; +typedef __m256i npyv_u32; +typedef __m256i npyv_s32; +typedef __m256i npyv_u64; +typedef __m256i npyv_s64; +typedef __m256 npyv_f32; +typedef __m256d npyv_f64; + +typedef __m256i npyv_b8; +typedef __m256i npyv_b16; +typedef __m256i npyv_b32; +typedef __m256i npyv_b64; + +typedef struct { __m256i val[2]; } npyv_m256ix2; +typedef npyv_m256ix2 npyv_u8x2; +typedef npyv_m256ix2 npyv_s8x2; +typedef npyv_m256ix2 npyv_u16x2; +typedef npyv_m256ix2 npyv_s16x2; +typedef npyv_m256ix2 npyv_u32x2; +typedef npyv_m256ix2 npyv_s32x2; +typedef npyv_m256ix2 
npyv_u64x2; +typedef npyv_m256ix2 npyv_s64x2; + +typedef struct { __m256i val[3]; } npyv_m256ix3; +typedef npyv_m256ix3 npyv_u8x3; +typedef npyv_m256ix3 npyv_s8x3; +typedef npyv_m256ix3 npyv_u16x3; +typedef npyv_m256ix3 npyv_s16x3; +typedef npyv_m256ix3 npyv_u32x3; +typedef npyv_m256ix3 npyv_s32x3; +typedef npyv_m256ix3 npyv_u64x3; +typedef npyv_m256ix3 npyv_s64x3; + +typedef struct { __m256 val[2]; } npyv_f32x2; +typedef struct { __m256d val[2]; } npyv_f64x2; +typedef struct { __m256 val[3]; } npyv_f32x3; +typedef struct { __m256d val[3]; } npyv_f64x3; + +#define npyv_nlanes_u8 32 +#define npyv_nlanes_s8 32 +#define npyv_nlanes_u16 16 +#define npyv_nlanes_s16 16 +#define npyv_nlanes_u32 8 +#define npyv_nlanes_s32 8 +#define npyv_nlanes_u64 4 +#define npyv_nlanes_s64 4 +#define npyv_nlanes_f32 8 +#define npyv_nlanes_f64 4 + +#include "utils.h" +#include "memory.h" +#include "misc.h" +#include "reorder.h" +#include "operators.h" +#include "conversion.h" +#include "arithmetic.h" +#include "math.h" diff --git a/mkl_umath/src/npyv/avx2/conversion.h b/mkl_umath/src/npyv/avx2/conversion.h new file mode 100644 index 00000000..00ac0d38 --- /dev/null +++ b/mkl_umath/src/npyv/avx2/conversion.h @@ -0,0 +1,99 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_AVX2_CVT_H +#define _NPY_SIMD_AVX2_CVT_H + +// convert mask types to integer types +#define npyv_cvt_u8_b8(A) A +#define npyv_cvt_s8_b8(A) A +#define npyv_cvt_u16_b16(A) A +#define npyv_cvt_s16_b16(A) A +#define npyv_cvt_u32_b32(A) A +#define npyv_cvt_s32_b32(A) A +#define npyv_cvt_u64_b64(A) A +#define npyv_cvt_s64_b64(A) A +#define npyv_cvt_f32_b32 _mm256_castsi256_ps +#define npyv_cvt_f64_b64 _mm256_castsi256_pd + +// convert integer types to mask types +#define npyv_cvt_b8_u8(BL) BL +#define npyv_cvt_b8_s8(BL) BL +#define npyv_cvt_b16_u16(BL) BL +#define npyv_cvt_b16_s16(BL) BL +#define npyv_cvt_b32_u32(BL) BL +#define npyv_cvt_b32_s32(BL) BL +#define npyv_cvt_b64_u64(BL) BL +#define npyv_cvt_b64_s64(BL) BL +#define npyv_cvt_b32_f32 _mm256_castps_si256 +#define npyv_cvt_b64_f64 _mm256_castpd_si256 + +// convert boolean vector to integer bitfield +NPY_FINLINE npy_uint64 npyv_tobits_b8(npyv_b8 a) +{ return (npy_uint32)_mm256_movemask_epi8(a); } + +NPY_FINLINE npy_uint64 npyv_tobits_b16(npyv_b16 a) +{ + __m128i pack = _mm_packs_epi16(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1)); + return (npy_uint16)_mm_movemask_epi8(pack); +} +NPY_FINLINE npy_uint64 npyv_tobits_b32(npyv_b32 a) +{ return (npy_uint8)_mm256_movemask_ps(_mm256_castsi256_ps(a)); } +NPY_FINLINE npy_uint64 npyv_tobits_b64(npyv_b64 a) +{ return (npy_uint8)_mm256_movemask_pd(_mm256_castsi256_pd(a)); } + +// expand +NPY_FINLINE npyv_u16x2 npyv_expand_u16_u8(npyv_u8 data) { + npyv_u16x2 r; + r.val[0] = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(data)); + r.val[1] = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(data, 1)); + return r; +} + +NPY_FINLINE npyv_u32x2 npyv_expand_u32_u16(npyv_u16 data) { + npyv_u32x2 r; + r.val[0] = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(data)); + r.val[1] = _mm256_cvtepu16_epi32(_mm256_extracti128_si256(data, 1)); + return r; +} + +// pack two 16-bit boolean into one 8-bit boolean vector +NPY_FINLINE npyv_b8 npyv_pack_b8_b16(npyv_b16 a, npyv_b16 b) { + __m256i ab = _mm256_packs_epi16(a, b); + return npyv256_shuffle_odd(ab); +} + +// pack four 32-bit boolean vectors into one 8-bit boolean vector +NPY_FINLINE npyv_b8 +npyv_pack_b8_b32(npyv_b32 a, npyv_b32 b, npyv_b32 c, npyv_b32 d) { + __m256i ab = 
_mm256_packs_epi32(a, b); + __m256i cd = _mm256_packs_epi32(c, d); + __m256i abcd = npyv_pack_b8_b16(ab, cd); + return _mm256_shuffle_epi32(abcd, _MM_SHUFFLE(3, 1, 2, 0)); +} + +// pack eight 64-bit boolean vectors into one 8-bit boolean vector +NPY_FINLINE npyv_b8 +npyv_pack_b8_b64(npyv_b64 a, npyv_b64 b, npyv_b64 c, npyv_b64 d, + npyv_b64 e, npyv_b64 f, npyv_b64 g, npyv_b64 h) { + __m256i ab = _mm256_packs_epi32(a, b); + __m256i cd = _mm256_packs_epi32(c, d); + __m256i ef = _mm256_packs_epi32(e, f); + __m256i gh = _mm256_packs_epi32(g, h); + __m256i abcd = _mm256_packs_epi32(ab, cd); + __m256i efgh = _mm256_packs_epi32(ef, gh); + __m256i all = npyv256_shuffle_odd(_mm256_packs_epi16(abcd, efgh)); + __m256i rev128 = _mm256_alignr_epi8(all, all, 8); + return _mm256_unpacklo_epi16(all, rev128); +} + +// round to nearest integer (assuming even) +#define npyv_round_s32_f32 _mm256_cvtps_epi32 +NPY_FINLINE npyv_s32 npyv_round_s32_f64(npyv_f64 a, npyv_f64 b) +{ + __m128i lo = _mm256_cvtpd_epi32(a), hi = _mm256_cvtpd_epi32(b); + return _mm256_inserti128_si256(_mm256_castsi128_si256(lo), hi, 1); +} + +#endif // _NPY_SIMD_AVX2_CVT_H diff --git a/mkl_umath/src/npyv/avx2/math.h b/mkl_umath/src/npyv/avx2/math.h new file mode 100644 index 00000000..1ef9cd36 --- /dev/null +++ b/mkl_umath/src/npyv/avx2/math.h @@ -0,0 +1,253 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_AVX2_MATH_H +#define _NPY_SIMD_AVX2_MATH_H +/*************************** + * Elementary + ***************************/ +// Square root +#define npyv_sqrt_f32 _mm256_sqrt_ps +#define npyv_sqrt_f64 _mm256_sqrt_pd + +// Reciprocal +NPY_FINLINE npyv_f32 npyv_recip_f32(npyv_f32 a) +{ return _mm256_div_ps(_mm256_set1_ps(1.0f), a); } +NPY_FINLINE npyv_f64 npyv_recip_f64(npyv_f64 a) +{ return _mm256_div_pd(_mm256_set1_pd(1.0), a); } + +// Absolute +NPY_FINLINE npyv_f32 npyv_abs_f32(npyv_f32 a) +{ + return _mm256_and_ps( + a, _mm256_castsi256_ps(_mm256_set1_epi32(0x7fffffff)) + ); +} +NPY_FINLINE npyv_f64 npyv_abs_f64(npyv_f64 a) +{ + return _mm256_and_pd( + a, _mm256_castsi256_pd(npyv_setall_s64(0x7fffffffffffffffLL)) + ); +} + +// Square +NPY_FINLINE npyv_f32 npyv_square_f32(npyv_f32 a) +{ return _mm256_mul_ps(a, a); } +NPY_FINLINE npyv_f64 npyv_square_f64(npyv_f64 a) +{ return _mm256_mul_pd(a, a); } + +// Maximum, natively mapping with no guarantees to handle NaN. +#define npyv_max_f32 _mm256_max_ps +#define npyv_max_f64 _mm256_max_pd +// Maximum, supports IEEE floating-point arithmetic (IEC 60559), +// - If one of the two vectors contains NaN, the equivalent element of the other vector is set +// - Only if both corresponded elements are NaN, NaN is set. +NPY_FINLINE npyv_f32 npyv_maxp_f32(npyv_f32 a, npyv_f32 b) +{ + __m256 nn = _mm256_cmp_ps(b, b, _CMP_ORD_Q); + __m256 max = _mm256_max_ps(a, b); + return _mm256_blendv_ps(a, max, nn); +} +NPY_FINLINE npyv_f64 npyv_maxp_f64(npyv_f64 a, npyv_f64 b) +{ + __m256d nn = _mm256_cmp_pd(b, b, _CMP_ORD_Q); + __m256d max = _mm256_max_pd(a, b); + return _mm256_blendv_pd(a, max, nn); +} +// Maximum, propagates NaNs +// If any of corresponded elements is NaN, NaN is set. 
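+// Implementation note: _CMP_ORD_Q on (a, a) is false exactly in the lanes where a
+// is NaN, so the blendv below keeps those lanes of a; in the remaining lanes
+// _mm256_max_ps/_mm256_max_pd return their second operand when it is NaN, so NaNs
+// from b propagate as well.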
+NPY_FINLINE npyv_f32 npyv_maxn_f32(npyv_f32 a, npyv_f32 b) +{ + __m256 nn = _mm256_cmp_ps(a, a, _CMP_ORD_Q); + __m256 max = _mm256_max_ps(a, b); + return _mm256_blendv_ps(a, max, nn); +} +NPY_FINLINE npyv_f64 npyv_maxn_f64(npyv_f64 a, npyv_f64 b) +{ + __m256d nn = _mm256_cmp_pd(a, a, _CMP_ORD_Q); + __m256d max = _mm256_max_pd(a, b); + return _mm256_blendv_pd(a, max, nn); +} + +// Maximum, integer operations +#define npyv_max_u8 _mm256_max_epu8 +#define npyv_max_s8 _mm256_max_epi8 +#define npyv_max_u16 _mm256_max_epu16 +#define npyv_max_s16 _mm256_max_epi16 +#define npyv_max_u32 _mm256_max_epu32 +#define npyv_max_s32 _mm256_max_epi32 +NPY_FINLINE npyv_u64 npyv_max_u64(npyv_u64 a, npyv_u64 b) +{ + return _mm256_blendv_epi8(b, a, npyv_cmpgt_u64(a, b)); +} +NPY_FINLINE npyv_s64 npyv_max_s64(npyv_s64 a, npyv_s64 b) +{ + return _mm256_blendv_epi8(b, a, _mm256_cmpgt_epi64(a, b)); +} + +// Minimum, natively mapping with no guarantees to handle NaN. +#define npyv_min_f32 _mm256_min_ps +#define npyv_min_f64 _mm256_min_pd +// Minimum, supports IEEE floating-point arithmetic (IEC 60559), +// - If one of the two vectors contains NaN, the equivalent element of the other vector is set +// - Only if both corresponded elements are NaN, NaN is set. +NPY_FINLINE npyv_f32 npyv_minp_f32(npyv_f32 a, npyv_f32 b) +{ + __m256 nn = _mm256_cmp_ps(b, b, _CMP_ORD_Q); + __m256 min = _mm256_min_ps(a, b); + return _mm256_blendv_ps(a, min, nn); +} +NPY_FINLINE npyv_f64 npyv_minp_f64(npyv_f64 a, npyv_f64 b) +{ + __m256d nn = _mm256_cmp_pd(b, b, _CMP_ORD_Q); + __m256d min = _mm256_min_pd(a, b); + return _mm256_blendv_pd(a, min, nn); +} +// Minimum, propagates NaNs +// If any of corresponded element is NaN, NaN is set. +NPY_FINLINE npyv_f32 npyv_minn_f32(npyv_f32 a, npyv_f32 b) +{ + __m256 nn = _mm256_cmp_ps(a, a, _CMP_ORD_Q); + __m256 min = _mm256_min_ps(a, b); + return _mm256_blendv_ps(a, min, nn); +} +NPY_FINLINE npyv_f64 npyv_minn_f64(npyv_f64 a, npyv_f64 b) +{ + __m256d nn = _mm256_cmp_pd(a, a, _CMP_ORD_Q); + __m256d min = _mm256_min_pd(a, b); + return _mm256_blendv_pd(a, min, nn); +} +// Minimum, integer operations +#define npyv_min_u8 _mm256_min_epu8 +#define npyv_min_s8 _mm256_min_epi8 +#define npyv_min_u16 _mm256_min_epu16 +#define npyv_min_s16 _mm256_min_epi16 +#define npyv_min_u32 _mm256_min_epu32 +#define npyv_min_s32 _mm256_min_epi32 +NPY_FINLINE npyv_u64 npyv_min_u64(npyv_u64 a, npyv_u64 b) +{ + return _mm256_blendv_epi8(b, a, npyv_cmplt_u64(a, b)); +} +NPY_FINLINE npyv_s64 npyv_min_s64(npyv_s64 a, npyv_s64 b) +{ + return _mm256_blendv_epi8(a, b, _mm256_cmpgt_epi64(a, b)); +} +// reduce min&max for 32&64-bits +#define NPY_IMPL_AVX2_REDUCE_MINMAX(STYPE, INTRIN, VINTRIN) \ + NPY_FINLINE STYPE##32 npyv_reduce_##INTRIN##32(__m256i a) \ + { \ + __m128i v128 = _mm_##VINTRIN##32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1)); \ + __m128i v64 = _mm_##VINTRIN##32(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(0, 0, 3, 2))); \ + __m128i v32 = _mm_##VINTRIN##32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); \ + return (STYPE##32)_mm_cvtsi128_si32(v32); \ + } \ + NPY_FINLINE STYPE##64 npyv_reduce_##INTRIN##64(__m256i a) \ + { \ + __m256i v128 = npyv_##INTRIN##64(a, _mm256_permute2f128_si256(a, a, _MM_SHUFFLE(0, 0, 0, 1))); \ + __m256i v64 = npyv_##INTRIN##64(v128, _mm256_shuffle_epi32(v128, _MM_SHUFFLE(0, 0, 3, 2))); \ + return (STYPE##64)npyv_extract0_u64(v64); \ + } +NPY_IMPL_AVX2_REDUCE_MINMAX(npy_uint, min_u, min_epu) +NPY_IMPL_AVX2_REDUCE_MINMAX(npy_int, min_s, min_epi) 
+NPY_IMPL_AVX2_REDUCE_MINMAX(npy_uint, max_u, max_epu) +NPY_IMPL_AVX2_REDUCE_MINMAX(npy_int, max_s, max_epi) +#undef NPY_IMPL_AVX2_REDUCE_MINMAX + +// reduce min&max for ps & pd +#define NPY_IMPL_AVX2_REDUCE_MINMAX(INTRIN, INF, INF64) \ + NPY_FINLINE float npyv_reduce_##INTRIN##_f32(npyv_f32 a) \ + { \ + __m128 v128 = _mm_##INTRIN##_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1)); \ + __m128 v64 = _mm_##INTRIN##_ps(v128, _mm_shuffle_ps(v128, v128, _MM_SHUFFLE(0, 0, 3, 2))); \ + __m128 v32 = _mm_##INTRIN##_ps(v64, _mm_shuffle_ps(v64, v64, _MM_SHUFFLE(0, 0, 0, 1))); \ + return _mm_cvtss_f32(v32); \ + } \ + NPY_FINLINE double npyv_reduce_##INTRIN##_f64(npyv_f64 a) \ + { \ + __m128d v128 = _mm_##INTRIN##_pd(_mm256_castpd256_pd128(a), _mm256_extractf128_pd(a, 1)); \ + __m128d v64 = _mm_##INTRIN##_pd(v128, _mm_shuffle_pd(v128, v128, _MM_SHUFFLE(0, 0, 0, 1))); \ + return _mm_cvtsd_f64(v64); \ + } \ + NPY_FINLINE float npyv_reduce_##INTRIN##p_f32(npyv_f32 a) \ + { \ + npyv_b32 notnan = npyv_notnan_f32(a); \ + if (NPY_UNLIKELY(!npyv_any_b32(notnan))) { \ + return _mm_cvtss_f32(_mm256_castps256_ps128(a)); \ + } \ + a = npyv_select_f32(notnan, a, npyv_reinterpret_f32_u32(npyv_setall_u32(INF))); \ + return npyv_reduce_##INTRIN##_f32(a); \ + } \ + NPY_FINLINE double npyv_reduce_##INTRIN##p_f64(npyv_f64 a) \ + { \ + npyv_b64 notnan = npyv_notnan_f64(a); \ + if (NPY_UNLIKELY(!npyv_any_b64(notnan))) { \ + return _mm_cvtsd_f64(_mm256_castpd256_pd128(a)); \ + } \ + a = npyv_select_f64(notnan, a, npyv_reinterpret_f64_u64(npyv_setall_u64(INF64))); \ + return npyv_reduce_##INTRIN##_f64(a); \ + } \ + NPY_FINLINE float npyv_reduce_##INTRIN##n_f32(npyv_f32 a) \ + { \ + npyv_b32 notnan = npyv_notnan_f32(a); \ + if (NPY_UNLIKELY(!npyv_all_b32(notnan))) { \ + const union { npy_uint32 i; float f;} pnan = {0x7fc00000UL}; \ + return pnan.f; \ + } \ + return npyv_reduce_##INTRIN##_f32(a); \ + } \ + NPY_FINLINE double npyv_reduce_##INTRIN##n_f64(npyv_f64 a) \ + { \ + npyv_b64 notnan = npyv_notnan_f64(a); \ + if (NPY_UNLIKELY(!npyv_all_b64(notnan))) { \ + const union { npy_uint64 i; double d;} pnan = {0x7ff8000000000000ull}; \ + return pnan.d; \ + } \ + return npyv_reduce_##INTRIN##_f64(a); \ + } +NPY_IMPL_AVX2_REDUCE_MINMAX(min, 0x7f800000, 0x7ff0000000000000) +NPY_IMPL_AVX2_REDUCE_MINMAX(max, 0xff800000, 0xfff0000000000000) +#undef NPY_IMPL_AVX2_REDUCE_MINMAX + +// reduce min&max for 8&16-bits +#define NPY_IMPL_AVX256_REDUCE_MINMAX(STYPE, INTRIN, VINTRIN) \ + NPY_FINLINE STYPE##16 npyv_reduce_##INTRIN##16(__m256i a) \ + { \ + __m128i v128 = _mm_##VINTRIN##16(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1)); \ + __m128i v64 = _mm_##VINTRIN##16(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(0, 0, 3, 2))); \ + __m128i v32 = _mm_##VINTRIN##16(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); \ + __m128i v16 = _mm_##VINTRIN##16(v32, _mm_shufflelo_epi16(v32, _MM_SHUFFLE(0, 0, 0, 1))); \ + return (STYPE##16)_mm_cvtsi128_si32(v16); \ + } \ + NPY_FINLINE STYPE##8 npyv_reduce_##INTRIN##8(__m256i a) \ + { \ + __m128i v128 = _mm_##VINTRIN##8(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1)); \ + __m128i v64 = _mm_##VINTRIN##8(v128, _mm_shuffle_epi32(v128, _MM_SHUFFLE(0, 0, 3, 2))); \ + __m128i v32 = _mm_##VINTRIN##8(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); \ + __m128i v16 = _mm_##VINTRIN##8(v32, _mm_shufflelo_epi16(v32, _MM_SHUFFLE(0, 0, 0, 1))); \ + __m128i v8 = _mm_##VINTRIN##8(v16, _mm_srli_epi16(v16, 8)); \ + return (STYPE##16)_mm_cvtsi128_si32(v8); \ + } 
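+// The instantiations below expand to npyv_reduce_{min,max}_{u,s}{8,16}; each step of
+// the ladder above halves the active width (256->128 via extract, then 64/32/16/8-bit
+// shuffles), so a full-vector reduction costs O(log2 nlanes) min/max instructions.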
+NPY_IMPL_AVX256_REDUCE_MINMAX(npy_uint, min_u, min_epu) +NPY_IMPL_AVX256_REDUCE_MINMAX(npy_int, min_s, min_epi) +NPY_IMPL_AVX256_REDUCE_MINMAX(npy_uint, max_u, max_epu) +NPY_IMPL_AVX256_REDUCE_MINMAX(npy_int, max_s, max_epi) +#undef NPY_IMPL_AVX256_REDUCE_MINMAX + +// round to nearest integer even +#define npyv_rint_f32(A) _mm256_round_ps(A, _MM_FROUND_TO_NEAREST_INT) +#define npyv_rint_f64(A) _mm256_round_pd(A, _MM_FROUND_TO_NEAREST_INT) + +// ceil +#define npyv_ceil_f32 _mm256_ceil_ps +#define npyv_ceil_f64 _mm256_ceil_pd + +// trunc +#define npyv_trunc_f32(A) _mm256_round_ps(A, _MM_FROUND_TO_ZERO) +#define npyv_trunc_f64(A) _mm256_round_pd(A, _MM_FROUND_TO_ZERO) + +// floor +#define npyv_floor_f32 _mm256_floor_ps +#define npyv_floor_f64 _mm256_floor_pd + +#endif // _NPY_SIMD_AVX2_MATH_H diff --git a/mkl_umath/src/npyv/avx2/memory.h b/mkl_umath/src/npyv/avx2/memory.h new file mode 100644 index 00000000..f1863653 --- /dev/null +++ b/mkl_umath/src/npyv/avx2/memory.h @@ -0,0 +1,761 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#include "misc.h" + +#ifndef _NPY_SIMD_AVX2_MEMORY_H +#define _NPY_SIMD_AVX2_MEMORY_H + +/*************************** + * load/store + ***************************/ +#define NPYV_IMPL_AVX2_MEM_INT(CTYPE, SFX) \ + NPY_FINLINE npyv_##SFX npyv_load_##SFX(const CTYPE *ptr) \ + { return _mm256_loadu_si256((const __m256i*)ptr); } \ + NPY_FINLINE npyv_##SFX npyv_loada_##SFX(const CTYPE *ptr) \ + { return _mm256_load_si256((const __m256i*)ptr); } \ + NPY_FINLINE npyv_##SFX npyv_loads_##SFX(const CTYPE *ptr) \ + { return _mm256_stream_load_si256((const __m256i*)ptr); } \ + NPY_FINLINE npyv_##SFX npyv_loadl_##SFX(const CTYPE *ptr) \ + { return _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)ptr)); } \ + NPY_FINLINE void npyv_store_##SFX(CTYPE *ptr, npyv_##SFX vec) \ + { _mm256_storeu_si256((__m256i*)ptr, vec); } \ + NPY_FINLINE void npyv_storea_##SFX(CTYPE *ptr, npyv_##SFX vec) \ + { _mm256_store_si256((__m256i*)ptr, vec); } \ + NPY_FINLINE void npyv_stores_##SFX(CTYPE *ptr, npyv_##SFX vec) \ + { _mm256_stream_si256((__m256i*)ptr, vec); } \ + NPY_FINLINE void npyv_storel_##SFX(CTYPE *ptr, npyv_##SFX vec) \ + { _mm_storeu_si128((__m128i*)(ptr), _mm256_castsi256_si128(vec)); } \ + NPY_FINLINE void npyv_storeh_##SFX(CTYPE *ptr, npyv_##SFX vec) \ + { _mm_storeu_si128((__m128i*)(ptr), _mm256_extracti128_si256(vec, 1)); } + +NPYV_IMPL_AVX2_MEM_INT(npy_uint8, u8) +NPYV_IMPL_AVX2_MEM_INT(npy_int8, s8) +NPYV_IMPL_AVX2_MEM_INT(npy_uint16, u16) +NPYV_IMPL_AVX2_MEM_INT(npy_int16, s16) +NPYV_IMPL_AVX2_MEM_INT(npy_uint32, u32) +NPYV_IMPL_AVX2_MEM_INT(npy_int32, s32) +NPYV_IMPL_AVX2_MEM_INT(npy_uint64, u64) +NPYV_IMPL_AVX2_MEM_INT(npy_int64, s64) + +// unaligned load +#define npyv_load_f32 _mm256_loadu_ps +#define npyv_load_f64 _mm256_loadu_pd +// aligned load +#define npyv_loada_f32 _mm256_load_ps +#define npyv_loada_f64 _mm256_load_pd +// stream load +#define npyv_loads_f32(PTR) \ + _mm256_castsi256_ps(_mm256_stream_load_si256((const __m256i*)(PTR))) +#define npyv_loads_f64(PTR) \ + _mm256_castsi256_pd(_mm256_stream_load_si256((const __m256i*)(PTR))) +// load lower part +#define npyv_loadl_f32(PTR) _mm256_castps128_ps256(_mm_loadu_ps(PTR)) +#define npyv_loadl_f64(PTR) _mm256_castpd128_pd256(_mm_loadu_pd(PTR)) +// unaligned store +#define npyv_store_f32 _mm256_storeu_ps +#define npyv_store_f64 _mm256_storeu_pd +// aligned store +#define npyv_storea_f32 _mm256_store_ps +#define npyv_storea_f64 _mm256_store_pd +// stream store +#define npyv_stores_f32 
_mm256_stream_ps +#define npyv_stores_f64 _mm256_stream_pd +// store lower part +#define npyv_storel_f32(PTR, VEC) _mm_storeu_ps(PTR, _mm256_castps256_ps128(VEC)) +#define npyv_storel_f64(PTR, VEC) _mm_storeu_pd(PTR, _mm256_castpd256_pd128(VEC)) +// store higher part +#define npyv_storeh_f32(PTR, VEC) _mm_storeu_ps(PTR, _mm256_extractf128_ps(VEC, 1)) +#define npyv_storeh_f64(PTR, VEC) _mm_storeu_pd(PTR, _mm256_extractf128_pd(VEC, 1)) +/*************************** + * Non-contiguous Load + ***************************/ +//// 32 +NPY_FINLINE npyv_u32 npyv_loadn_u32(const npy_uint32 *ptr, npy_intp stride) +{ + assert(llabs(stride) <= NPY_SIMD_MAXLOAD_STRIDE32); + const __m256i steps = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + const __m256i idx = _mm256_mullo_epi32(_mm256_set1_epi32((int)stride), steps); + return _mm256_i32gather_epi32((const int*)ptr, idx, 4); +} +NPY_FINLINE npyv_s32 npyv_loadn_s32(const npy_int32 *ptr, npy_intp stride) +{ return npyv_loadn_u32((const npy_uint32*)ptr, stride); } +NPY_FINLINE npyv_f32 npyv_loadn_f32(const float *ptr, npy_intp stride) +{ return _mm256_castsi256_ps(npyv_loadn_u32((const npy_uint32*)ptr, stride)); } +//// 64 +NPY_FINLINE npyv_f64 npyv_loadn_f64(const double *ptr, npy_intp stride) +{ + __m128d a0 = _mm_castsi128_pd(_mm_loadl_epi64((const __m128i*)ptr)); + __m128d a2 = _mm_castsi128_pd(_mm_loadl_epi64((const __m128i*)(ptr + stride*2))); + __m128d a01 = _mm_loadh_pd(a0, ptr + stride); + __m128d a23 = _mm_loadh_pd(a2, ptr + stride*3); + return _mm256_insertf128_pd(_mm256_castpd128_pd256(a01), a23, 1); +} +NPY_FINLINE npyv_u64 npyv_loadn_u64(const npy_uint64 *ptr, npy_intp stride) +{ return _mm256_castpd_si256(npyv_loadn_f64((const double*)ptr, stride)); } +NPY_FINLINE npyv_s64 npyv_loadn_s64(const npy_int64 *ptr, npy_intp stride) +{ return _mm256_castpd_si256(npyv_loadn_f64((const double*)ptr, stride)); } + +//// 64-bit load over 32-bit stride +NPY_FINLINE npyv_f32 npyv_loadn2_f32(const float *ptr, npy_intp stride) +{ + __m128d a0 = _mm_castsi128_pd(_mm_loadl_epi64((const __m128i*)ptr)); + __m128d a2 = _mm_castsi128_pd(_mm_loadl_epi64((const __m128i*)(ptr + stride*2))); + __m128d a01 = _mm_loadh_pd(a0, (const double*)(ptr + stride)); + __m128d a23 = _mm_loadh_pd(a2, (const double*)(ptr + stride*3)); + return _mm256_castpd_ps(_mm256_insertf128_pd(_mm256_castpd128_pd256(a01), a23, 1)); +} +NPY_FINLINE npyv_u32 npyv_loadn2_u32(const npy_uint32 *ptr, npy_intp stride) +{ return _mm256_castps_si256(npyv_loadn2_f32((const float*)ptr, stride)); } +NPY_FINLINE npyv_s32 npyv_loadn2_s32(const npy_int32 *ptr, npy_intp stride) +{ return _mm256_castps_si256(npyv_loadn2_f32((const float*)ptr, stride)); } + +//// 128-bit load over 64-bit stride +NPY_FINLINE npyv_f64 npyv_loadn2_f64(const double *ptr, npy_intp stride) +{ + __m256d a = npyv_loadl_f64(ptr); + return _mm256_insertf128_pd(a, _mm_loadu_pd(ptr + stride), 1); +} +NPY_FINLINE npyv_u64 npyv_loadn2_u64(const npy_uint64 *ptr, npy_intp stride) +{ return _mm256_castpd_si256(npyv_loadn2_f64((const double*)ptr, stride)); } +NPY_FINLINE npyv_s64 npyv_loadn2_s64(const npy_int64 *ptr, npy_intp stride) +{ return _mm256_castpd_si256(npyv_loadn2_f64((const double*)ptr, stride)); } + +/*************************** + * Non-contiguous Store + ***************************/ +//// 32 +NPY_FINLINE void npyv_storen_s32(npy_int32 *ptr, npy_intp stride, npyv_s32 a) +{ + __m128i a0 = _mm256_castsi256_si128(a); + __m128i a1 = _mm256_extracti128_si256(a, 1); + ptr[stride * 0] = _mm_cvtsi128_si32(a0); + ptr[stride * 1] = 
_mm_extract_epi32(a0, 1); + ptr[stride * 2] = _mm_extract_epi32(a0, 2); + ptr[stride * 3] = _mm_extract_epi32(a0, 3); + ptr[stride * 4] = _mm_cvtsi128_si32(a1); + ptr[stride * 5] = _mm_extract_epi32(a1, 1); + ptr[stride * 6] = _mm_extract_epi32(a1, 2); + ptr[stride * 7] = _mm_extract_epi32(a1, 3); +} +NPY_FINLINE void npyv_storen_u32(npy_uint32 *ptr, npy_intp stride, npyv_u32 a) +{ npyv_storen_s32((npy_int32*)ptr, stride, a); } +NPY_FINLINE void npyv_storen_f32(float *ptr, npy_intp stride, npyv_f32 a) +{ npyv_storen_s32((npy_int32*)ptr, stride, _mm256_castps_si256(a)); } +//// 64 +NPY_FINLINE void npyv_storen_f64(double *ptr, npy_intp stride, npyv_f64 a) +{ + __m128d a0 = _mm256_castpd256_pd128(a); + __m128d a1 = _mm256_extractf128_pd(a, 1); + _mm_storel_pd(ptr + stride * 0, a0); + _mm_storeh_pd(ptr + stride * 1, a0); + _mm_storel_pd(ptr + stride * 2, a1); + _mm_storeh_pd(ptr + stride * 3, a1); +} +NPY_FINLINE void npyv_storen_u64(npy_uint64 *ptr, npy_intp stride, npyv_u64 a) +{ npyv_storen_f64((double*)ptr, stride, _mm256_castsi256_pd(a)); } +NPY_FINLINE void npyv_storen_s64(npy_int64 *ptr, npy_intp stride, npyv_s64 a) +{ npyv_storen_f64((double*)ptr, stride, _mm256_castsi256_pd(a)); } + +//// 64-bit store over 32-bit stride +NPY_FINLINE void npyv_storen2_u32(npy_uint32 *ptr, npy_intp stride, npyv_u32 a) +{ + __m128d a0 = _mm256_castpd256_pd128(_mm256_castsi256_pd(a)); + __m128d a1 = _mm256_extractf128_pd(_mm256_castsi256_pd(a), 1); + _mm_storel_pd((double*)ptr, a0); + _mm_storeh_pd((double*)(ptr + stride), a0); + _mm_storel_pd((double*)(ptr + stride*2), a1); + _mm_storeh_pd((double*)(ptr + stride*3), a1); +} +NPY_FINLINE void npyv_storen2_s32(npy_int32 *ptr, npy_intp stride, npyv_s32 a) +{ npyv_storen2_u32((npy_uint32*)ptr, stride, a); } +NPY_FINLINE void npyv_storen2_f32(float *ptr, npy_intp stride, npyv_f32 a) +{ npyv_storen2_u32((npy_uint32*)ptr, stride, _mm256_castps_si256(a)); } + +//// 128-bit store over 64-bit stride +NPY_FINLINE void npyv_storen2_u64(npy_uint64 *ptr, npy_intp stride, npyv_u64 a) +{ + npyv_storel_u64(ptr, a); + npyv_storeh_u64(ptr + stride, a); +} +NPY_FINLINE void npyv_storen2_s64(npy_int64 *ptr, npy_intp stride, npyv_s64 a) +{ npyv_storen2_u64((npy_uint64*)ptr, stride, a); } +NPY_FINLINE void npyv_storen2_f64(double *ptr, npy_intp stride, npyv_f64 a) +{ npyv_storen2_u64((npy_uint64*)ptr, stride, _mm256_castpd_si256(a)); } + +/********************************* + * Partial Load + *********************************/ +//// 32 +NPY_FINLINE npyv_s32 npyv_load_till_s32(const npy_int32 *ptr, npy_uintp nlane, npy_int32 fill) +{ + assert(nlane > 0); + const __m256i vfill = _mm256_set1_epi32(fill); + const __m256i steps = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + __m256i vnlane = _mm256_set1_epi32(nlane > 8 ? 8 : (int)nlane); + __m256i mask = _mm256_cmpgt_epi32(vnlane, steps); + __m256i payload = _mm256_maskload_epi32((const int*)ptr, mask); + __m256i ret = _mm256_blendv_epi8(vfill, payload, mask); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m256i workaround = ret; + ret = _mm256_or_si256(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s32 npyv_load_tillz_s32(const npy_int32 *ptr, npy_uintp nlane) +{ + assert(nlane > 0); + const __m256i steps = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + __m256i vnlane = _mm256_set1_epi32(nlane > 8 ? 
8 : (int)nlane); + __m256i mask = _mm256_cmpgt_epi32(vnlane, steps); + __m256i ret = _mm256_maskload_epi32((const int*)ptr, mask); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m256i workaround = ret; + ret = _mm256_or_si256(workaround, ret); +#endif + return ret; +} +//// 64 +NPY_FINLINE npyv_s64 npyv_load_till_s64(const npy_int64 *ptr, npy_uintp nlane, npy_int64 fill) +{ + assert(nlane > 0); + const __m256i vfill = npyv_setall_s64(fill); + const __m256i steps = npyv_set_s64(0, 1, 2, 3); + __m256i vnlane = npyv_setall_s64(nlane > 4 ? 4 : (int)nlane); + __m256i mask = _mm256_cmpgt_epi64(vnlane, steps); + __m256i payload = _mm256_maskload_epi64((const long long*)ptr, mask); + __m256i ret = _mm256_blendv_epi8(vfill, payload, mask); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m256i workaround = ret; + ret = _mm256_or_si256(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s64 npyv_load_tillz_s64(const npy_int64 *ptr, npy_uintp nlane) +{ + assert(nlane > 0); + const __m256i steps = npyv_set_s64(0, 1, 2, 3); + __m256i vnlane = npyv_setall_s64(nlane > 4 ? 4 : (int)nlane); + __m256i mask = _mm256_cmpgt_epi64(vnlane, steps); + __m256i ret = _mm256_maskload_epi64((const long long*)ptr, mask); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m256i workaround = ret; + ret = _mm256_or_si256(workaround, ret); +#endif + return ret; +} + +//// 64-bit nlane +NPY_FINLINE npyv_s32 npyv_load2_till_s32(const npy_int32 *ptr, npy_uintp nlane, + npy_int32 fill_lo, npy_int32 fill_hi) +{ + assert(nlane > 0); + const __m256i vfill = npyv_set_s32( + fill_lo, fill_hi, fill_lo, fill_hi, + fill_lo, fill_hi, fill_lo, fill_hi + ); + const __m256i steps = npyv_set_s64(0, 1, 2, 3); + __m256i vnlane = npyv_setall_s64(nlane > 4 ? 4 : (int)nlane); + __m256i mask = _mm256_cmpgt_epi64(vnlane, steps); + __m256i payload = _mm256_maskload_epi64((const long long*)ptr, mask); + __m256i ret = _mm256_blendv_epi8(vfill, payload, mask); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m256i workaround = ret; + ret = _mm256_or_si256(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s32 npyv_load2_tillz_s32(const npy_int32 *ptr, npy_uintp nlane) +{ return npyv_load_tillz_s64((const npy_int64*)ptr, nlane); } + +/// 128-bit nlane +NPY_FINLINE npyv_u64 npyv_load2_tillz_s64(const npy_int64 *ptr, npy_uintp nlane) +{ + assert(nlane > 0); + npy_int64 m = -((npy_int64)(nlane > 1)); + __m256i mask = npyv_set_s64(-1, -1, m, m); + __m256i ret = _mm256_maskload_epi64((const long long*)ptr, mask); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m256i workaround = ret; + ret = _mm256_or_si256(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_u64 npyv_load2_till_s64(const npy_int64 *ptr, npy_uintp nlane, + npy_int64 fill_lo, npy_int64 fill_hi) +{ + const __m256i vfill = npyv_set_s64(0, 0, fill_lo, fill_hi); + npy_int64 m = -((npy_int64)(nlane > 1)); + __m256i mask = npyv_set_s64(-1, -1, m, m); + __m256i payload = _mm256_maskload_epi64((const long long*)ptr, mask); + __m256i ret =_mm256_blendv_epi8(vfill, payload, mask); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m256i workaround = ret; + ret = _mm256_or_si256(workaround, ret); +#endif + return ret; +} +/********************************* + * Non-contiguous partial load + *********************************/ +//// 32 +NPY_FINLINE npyv_s32 +npyv_loadn_till_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane, npy_int32 fill) +{ + assert(nlane > 0); + assert(llabs(stride) <= 
NPY_SIMD_MAXLOAD_STRIDE32); + const __m256i vfill = _mm256_set1_epi32(fill); + const __m256i steps = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + const __m256i idx = _mm256_mullo_epi32(_mm256_set1_epi32((int)stride), steps); + __m256i vnlane = _mm256_set1_epi32(nlane > 8 ? 8 : (int)nlane); + __m256i mask = _mm256_cmpgt_epi32(vnlane, steps); + __m256i ret = _mm256_mask_i32gather_epi32(vfill, (const int*)ptr, idx, mask, 4); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m256i workaround = ret; + ret = _mm256_or_si256(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s32 +npyv_loadn_tillz_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane) +{ return npyv_loadn_till_s32(ptr, stride, nlane, 0); } +//// 64 +NPY_FINLINE npyv_s64 +npyv_loadn_till_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane, npy_int64 fill) +{ + assert(nlane > 0); + const __m256i vfill = npyv_setall_s64(fill); + const __m256i idx = npyv_set_s64(0, 1*stride, 2*stride, 3*stride); + const __m256i steps = npyv_set_s64(0, 1, 2, 3); + __m256i vnlane = npyv_setall_s64(nlane > 4 ? 4 : (int)nlane); + __m256i mask = _mm256_cmpgt_epi64(vnlane, steps); + __m256i ret = _mm256_mask_i64gather_epi64(vfill, (const long long*)ptr, idx, mask, 8); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m256i workaround = ret; + ret = _mm256_or_si256(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s64 +npyv_loadn_tillz_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane) +{ return npyv_loadn_till_s64(ptr, stride, nlane, 0); } + +//// 64-bit load over 32-bit stride +NPY_FINLINE npyv_s64 npyv_loadn2_till_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane, + npy_int32 fill_lo, npy_int32 fill_hi) +{ + assert(nlane > 0); + const __m256i vfill = npyv_set_s32( + fill_lo, fill_hi, fill_lo, fill_hi, + fill_lo, fill_hi, fill_lo, fill_hi + ); + const __m256i idx = npyv_set_s64(0, 1*stride, 2*stride, 3*stride); + const __m256i steps = npyv_set_s64(0, 1, 2, 3); + __m256i vnlane = npyv_setall_s64(nlane > 4 ? 4 : (int)nlane); + __m256i mask = _mm256_cmpgt_epi64(vnlane, steps); + __m256i ret = _mm256_mask_i64gather_epi64(vfill, (const long long*)ptr, idx, mask, 4); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m256i workaround = ret; + ret = _mm256_or_si256(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s32 npyv_loadn2_tillz_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane) +{ return npyv_loadn2_till_s32(ptr, stride, nlane, 0, 0); } + +//// 128-bit load over 64-bit stride +NPY_FINLINE npyv_s64 npyv_loadn2_till_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane, + npy_int64 fill_lo, npy_int64 fill_hi) +{ + assert(nlane > 0); + __m256i a = npyv_loadl_s64(ptr); +#if defined(_MSC_VER) && defined(_M_IX86) + __m128i fill =_mm_setr_epi32( + (int)fill_lo, (int)(fill_lo >> 32), + (int)fill_hi, (int)(fill_hi >> 32) + ); +#else + __m128i fill = _mm_set_epi64x(fill_hi, fill_lo); +#endif + __m128i b = nlane > 1 ? 
_mm_loadu_si128((const __m128i*)(ptr + stride)) : fill; + __m256i ret = _mm256_inserti128_si256(a, b, 1); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m256i workaround = ret; + ret = _mm256_or_si256(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s64 npyv_loadn2_tillz_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane) +{ return npyv_loadn2_till_s64(ptr, stride, nlane, 0, 0); } + +/********************************* + * Partial store + *********************************/ +//// 32 +NPY_FINLINE void npyv_store_till_s32(npy_int32 *ptr, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); + const __m256i steps = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); + __m256i vnlane = _mm256_set1_epi32(nlane > 8 ? 8 : (int)nlane); + __m256i mask = _mm256_cmpgt_epi32(vnlane, steps); + _mm256_maskstore_epi32((int*)ptr, mask, a); +} +//// 64 +NPY_FINLINE void npyv_store_till_s64(npy_int64 *ptr, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); + const __m256i steps = npyv_set_s64(0, 1, 2, 3); + __m256i vnlane = npyv_setall_s64(nlane > 8 ? 8 : (int)nlane); + __m256i mask = _mm256_cmpgt_epi64(vnlane, steps); + _mm256_maskstore_epi64((long long*)ptr, mask, a); +} + +//// 64-bit nlane +NPY_FINLINE void npyv_store2_till_s32(npy_int32 *ptr, npy_uintp nlane, npyv_s32 a) +{ npyv_store_till_s64((npy_int64*)ptr, nlane, a); } + +//// 128-bit nlane +NPY_FINLINE void npyv_store2_till_s64(npy_int64 *ptr, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); +#ifdef _MSC_VER + /* + * Although this version is compatible with all other compilers, + * there is no performance benefit in retaining the other branch. + * However, it serves as evidence of a newly emerging bug in MSVC + * that started to appear since v19.30. + * For some reason, the MSVC optimizer chooses to ignore the lower store (128-bit mov) + * and replace with full mov counting on ymmword pointer. + * + * For more details, please refer to the discussion on https://github.com/numpy/numpy/issues/23896. 
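+     *
+     * Illustrative usage sketch (behavior is identical on either branch):
+     *   npy_int64 buf[2];
+     *   npyv_store2_till_s64(buf, 1, a);  // nlane == 1: only the lower 128-bit pair is written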
+ */ + if (nlane > 1) { + npyv_store_s64(ptr, a); + } + else { + npyv_storel_s64(ptr, a); + } +#else + npyv_storel_s64(ptr, a); + if (nlane > 1) { + npyv_storeh_s64(ptr + 2, a); + } +#endif +} +/********************************* + * Non-contiguous partial store + *********************************/ +//// 32 +NPY_FINLINE void npyv_storen_till_s32(npy_int32 *ptr, npy_intp stride, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); + __m128i a0 = _mm256_castsi256_si128(a); + __m128i a1 = _mm256_extracti128_si256(a, 1); + + ptr[stride*0] = _mm_extract_epi32(a0, 0); + switch(nlane) { + case 1: + return; + case 2: + ptr[stride*1] = _mm_extract_epi32(a0, 1); + return; + case 3: + ptr[stride*1] = _mm_extract_epi32(a0, 1); + ptr[stride*2] = _mm_extract_epi32(a0, 2); + return; + case 4: + ptr[stride*1] = _mm_extract_epi32(a0, 1); + ptr[stride*2] = _mm_extract_epi32(a0, 2); + ptr[stride*3] = _mm_extract_epi32(a0, 3); + return; + case 5: + ptr[stride*1] = _mm_extract_epi32(a0, 1); + ptr[stride*2] = _mm_extract_epi32(a0, 2); + ptr[stride*3] = _mm_extract_epi32(a0, 3); + ptr[stride*4] = _mm_extract_epi32(a1, 0); + return; + case 6: + ptr[stride*1] = _mm_extract_epi32(a0, 1); + ptr[stride*2] = _mm_extract_epi32(a0, 2); + ptr[stride*3] = _mm_extract_epi32(a0, 3); + ptr[stride*4] = _mm_extract_epi32(a1, 0); + ptr[stride*5] = _mm_extract_epi32(a1, 1); + return; + case 7: + ptr[stride*1] = _mm_extract_epi32(a0, 1); + ptr[stride*2] = _mm_extract_epi32(a0, 2); + ptr[stride*3] = _mm_extract_epi32(a0, 3); + ptr[stride*4] = _mm_extract_epi32(a1, 0); + ptr[stride*5] = _mm_extract_epi32(a1, 1); + ptr[stride*6] = _mm_extract_epi32(a1, 2); + return; + default: + ptr[stride*1] = _mm_extract_epi32(a0, 1); + ptr[stride*2] = _mm_extract_epi32(a0, 2); + ptr[stride*3] = _mm_extract_epi32(a0, 3); + ptr[stride*4] = _mm_extract_epi32(a1, 0); + ptr[stride*5] = _mm_extract_epi32(a1, 1); + ptr[stride*6] = _mm_extract_epi32(a1, 2); + ptr[stride*7] = _mm_extract_epi32(a1, 3); + } +} +//// 64 +NPY_FINLINE void npyv_storen_till_s64(npy_int64 *ptr, npy_intp stride, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); + __m128d a0 = _mm256_castpd256_pd128(_mm256_castsi256_pd(a)); + __m128d a1 = _mm256_extractf128_pd(_mm256_castsi256_pd(a), 1); + + double *dptr = (double*)ptr; + _mm_storel_pd(dptr, a0); + switch(nlane) { + case 1: + return; + case 2: + _mm_storeh_pd(dptr + stride * 1, a0); + return; + case 3: + _mm_storeh_pd(dptr + stride * 1, a0); + _mm_storel_pd(dptr + stride * 2, a1); + return; + default: + _mm_storeh_pd(dptr + stride * 1, a0); + _mm_storel_pd(dptr + stride * 2, a1); + _mm_storeh_pd(dptr + stride * 3, a1); + } +} + +//// 64-bit store over 32-bit stride +NPY_FINLINE void npyv_storen2_till_s32(npy_int32 *ptr, npy_intp stride, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); + __m128d a0 = _mm256_castpd256_pd128(_mm256_castsi256_pd(a)); + __m128d a1 = _mm256_extractf128_pd(_mm256_castsi256_pd(a), 1); + + _mm_storel_pd((double*)ptr, a0); + switch(nlane) { + case 1: + return; + case 2: + _mm_storeh_pd((double*)(ptr + stride * 1), a0); + return; + case 3: + _mm_storeh_pd((double*)(ptr + stride * 1), a0); + _mm_storel_pd((double*)(ptr + stride * 2), a1); + return; + default: + _mm_storeh_pd((double*)(ptr + stride * 1), a0); + _mm_storel_pd((double*)(ptr + stride * 2), a1); + _mm_storeh_pd((double*)(ptr + stride * 3), a1); + } +} + +//// 128-bit store over 64-bit stride +NPY_FINLINE void npyv_storen2_till_s64(npy_int64 *ptr, npy_intp stride, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); + npyv_storel_s64(ptr, a); 
+ if (nlane > 1) { + npyv_storeh_s64(ptr + stride, a); + } +} +/***************************************************************************** + * Implement partial load/store for u32/f32/u64/f64... via reinterpret cast + *****************************************************************************/ +#define NPYV_IMPL_AVX2_REST_PARTIAL_TYPES(F_SFX, T_SFX) \ + NPY_FINLINE npyv_##F_SFX npyv_load_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, npyv_lanetype_##F_SFX fill) \ + { \ + union { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + } pun; \ + pun.from_##F_SFX = fill; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane, pun.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, \ + npyv_lanetype_##F_SFX fill) \ + { \ + union { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + } pun; \ + pun.from_##F_SFX = fill; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane, pun.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_load_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane \ + )); \ + } \ + NPY_FINLINE void npyv_store_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_store_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } \ + NPY_FINLINE void npyv_storen_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_storen_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, stride, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } + +NPYV_IMPL_AVX2_REST_PARTIAL_TYPES(u32, s32) +NPYV_IMPL_AVX2_REST_PARTIAL_TYPES(f32, s32) +NPYV_IMPL_AVX2_REST_PARTIAL_TYPES(u64, s64) +NPYV_IMPL_AVX2_REST_PARTIAL_TYPES(f64, s64) + +// 128-bit/64-bit stride (load/store pair) +#define NPYV_IMPL_AVX2_REST_PARTIAL_TYPES_PAIR(F_SFX, T_SFX) \ + NPY_FINLINE npyv_##F_SFX npyv_load2_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, \ + npyv_lanetype_##F_SFX fill_lo, npyv_lanetype_##F_SFX fill_hi) \ + { \ + union pun { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + }; \ + union pun pun_lo; \ + union pun pun_hi; \ + pun_lo.from_##F_SFX = fill_lo; \ + pun_hi.from_##F_SFX = fill_hi; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load2_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane, pun_lo.to_##T_SFX, pun_hi.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn2_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, \ + npyv_lanetype_##F_SFX fill_lo, npyv_lanetype_##F_SFX fill_hi) \ + { \ + union pun { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + }; \ + union pun pun_lo; \ + union pun pun_hi; \ + pun_lo.from_##F_SFX = fill_lo; \ + pun_hi.from_##F_SFX = fill_hi; \ + return 
npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn2_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane, pun_lo.to_##T_SFX, \ + pun_hi.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_load2_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load2_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn2_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn2_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane \ + )); \ + } \ + NPY_FINLINE void npyv_store2_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_store2_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } \ + NPY_FINLINE void npyv_storen2_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_storen2_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, stride, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } + +NPYV_IMPL_AVX2_REST_PARTIAL_TYPES_PAIR(u32, s32) +NPYV_IMPL_AVX2_REST_PARTIAL_TYPES_PAIR(f32, s32) +NPYV_IMPL_AVX2_REST_PARTIAL_TYPES_PAIR(u64, s64) +NPYV_IMPL_AVX2_REST_PARTIAL_TYPES_PAIR(f64, s64) + +/************************************************************ + * de-interlave load / interleave contiguous store + ************************************************************/ +// two channels +#define NPYV_IMPL_AVX2_MEM_INTERLEAVE(SFX, ZSFX) \ + NPY_FINLINE npyv_##ZSFX##x2 npyv_zip_##ZSFX(npyv_##ZSFX, npyv_##ZSFX); \ + NPY_FINLINE npyv_##ZSFX##x2 npyv_unzip_##ZSFX(npyv_##ZSFX, npyv_##ZSFX); \ + NPY_FINLINE npyv_##SFX##x2 npyv_load_##SFX##x2( \ + const npyv_lanetype_##SFX *ptr \ + ) { \ + return npyv_unzip_##ZSFX( \ + npyv_load_##SFX(ptr), npyv_load_##SFX(ptr+npyv_nlanes_##SFX) \ + ); \ + } \ + NPY_FINLINE void npyv_store_##SFX##x2( \ + npyv_lanetype_##SFX *ptr, npyv_##SFX##x2 v \ + ) { \ + npyv_##SFX##x2 zip = npyv_zip_##ZSFX(v.val[0], v.val[1]); \ + npyv_store_##SFX(ptr, zip.val[0]); \ + npyv_store_##SFX(ptr + npyv_nlanes_##SFX, zip.val[1]); \ + } + +NPYV_IMPL_AVX2_MEM_INTERLEAVE(u8, u8) +NPYV_IMPL_AVX2_MEM_INTERLEAVE(s8, u8) +NPYV_IMPL_AVX2_MEM_INTERLEAVE(u16, u16) +NPYV_IMPL_AVX2_MEM_INTERLEAVE(s16, u16) +NPYV_IMPL_AVX2_MEM_INTERLEAVE(u32, u32) +NPYV_IMPL_AVX2_MEM_INTERLEAVE(s32, u32) +NPYV_IMPL_AVX2_MEM_INTERLEAVE(u64, u64) +NPYV_IMPL_AVX2_MEM_INTERLEAVE(s64, u64) +NPYV_IMPL_AVX2_MEM_INTERLEAVE(f32, f32) +NPYV_IMPL_AVX2_MEM_INTERLEAVE(f64, f64) + +/********************************* + * Lookup tables + *********************************/ +// uses vector as indexes into a table +// that contains 32 elements of float32. +NPY_FINLINE npyv_f32 npyv_lut32_f32(const float *table, npyv_u32 idx) +{ return _mm256_i32gather_ps(table, idx, 4); } +NPY_FINLINE npyv_u32 npyv_lut32_u32(const npy_uint32 *table, npyv_u32 idx) +{ return npyv_reinterpret_u32_f32(npyv_lut32_f32((const float*)table, idx)); } +NPY_FINLINE npyv_s32 npyv_lut32_s32(const npy_int32 *table, npyv_u32 idx) +{ return npyv_reinterpret_s32_f32(npyv_lut32_f32((const float*)table, idx)); } + +// uses vector as indexes into a table +// that contains 16 elements of float64. 
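+// A minimal, hypothetical usage sketch (table and index vector supplied by the caller):
+//   static const double tbl[16] = { /* 16 precomputed coefficients */ };
+//   npyv_u64 sel = npyv_and_u64(idx, npyv_setall_u64(15)); // keep indexes in 0..15
+//   npyv_f64 val = npyv_lut16_f64(tbl, sel);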
+NPY_FINLINE npyv_f64 npyv_lut16_f64(const double *table, npyv_u64 idx) +{ return _mm256_i64gather_pd(table, idx, 8); } +NPY_FINLINE npyv_u64 npyv_lut16_u64(const npy_uint64 *table, npyv_u64 idx) +{ return npyv_reinterpret_u64_f64(npyv_lut16_f64((const double*)table, idx)); } +NPY_FINLINE npyv_s64 npyv_lut16_s64(const npy_int64 *table, npyv_u64 idx) +{ return npyv_reinterpret_s64_f64(npyv_lut16_f64((const double*)table, idx)); } + +#endif // _NPY_SIMD_AVX2_MEMORY_H diff --git a/mkl_umath/src/npyv/avx2/misc.h b/mkl_umath/src/npyv/avx2/misc.h new file mode 100644 index 00000000..41e788c7 --- /dev/null +++ b/mkl_umath/src/npyv/avx2/misc.h @@ -0,0 +1,258 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_AVX2_MISC_H +#define _NPY_SIMD_AVX2_MISC_H + +// vector with zero lanes +#define npyv_zero_u8 _mm256_setzero_si256 +#define npyv_zero_s8 _mm256_setzero_si256 +#define npyv_zero_u16 _mm256_setzero_si256 +#define npyv_zero_s16 _mm256_setzero_si256 +#define npyv_zero_u32 _mm256_setzero_si256 +#define npyv_zero_s32 _mm256_setzero_si256 +#define npyv_zero_u64 _mm256_setzero_si256 +#define npyv_zero_s64 _mm256_setzero_si256 +#define npyv_zero_f32 _mm256_setzero_ps +#define npyv_zero_f64 _mm256_setzero_pd + +// vector with a specific value set to all lanes +#define npyv_setall_u8(VAL) _mm256_set1_epi8((char)VAL) +#define npyv_setall_s8(VAL) _mm256_set1_epi8((char)VAL) +#define npyv_setall_u16(VAL) _mm256_set1_epi16((short)VAL) +#define npyv_setall_s16(VAL) _mm256_set1_epi16((short)VAL) +#define npyv_setall_u32(VAL) _mm256_set1_epi32((int)VAL) +#define npyv_setall_s32(VAL) _mm256_set1_epi32(VAL) +#define npyv_setall_f32(VAL) _mm256_set1_ps(VAL) +#define npyv_setall_f64(VAL) _mm256_set1_pd(VAL) + +NPY_FINLINE __m256i npyv__setr_epi64(npy_int64, npy_int64, npy_int64, npy_int64); +NPY_FINLINE npyv_u64 npyv_setall_u64(npy_uint64 a) +{ + npy_int64 ai = (npy_int64)a; +#if defined(_MSC_VER) && defined(_M_IX86) + return npyv__setr_epi64(ai, ai, ai, ai); +#else + return _mm256_set1_epi64x(ai); +#endif +} +NPY_FINLINE npyv_s64 npyv_setall_s64(npy_int64 a) +{ +#if defined(_MSC_VER) && defined(_M_IX86) + return npyv__setr_epi64(a, a, a, a); +#else + return _mm256_set1_epi64x(a); +#endif +} +/* + * vector with specific values set to each lane and + * set a specific value to all remained lanes + * + * Args that generated by NPYV__SET_FILL_* not going to expand if + * _mm256_setr_* are defined as macros. 
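+ *
+ * Illustrative fill semantics (AVX2, 8 x s32 lanes):
+ *   npyv_setf_s32(-1, 1, 2, 3) -> {1, 2, 3, -1, -1, -1, -1, -1}
+ *   npyv_set_s32(1, 2, 3)      -> {1, 2, 3,  0,  0,  0,  0,  0}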
+*/ +NPY_FINLINE __m256i npyv__setr_epi8( + char i0, char i1, char i2, char i3, char i4, char i5, char i6, char i7, + char i8, char i9, char i10, char i11, char i12, char i13, char i14, char i15, + char i16, char i17, char i18, char i19, char i20, char i21, char i22, char i23, + char i24, char i25, char i26, char i27, char i28, char i29, char i30, char i31) +{ + return _mm256_setr_epi8( + i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, i14, i15, + i16, i17, i18, i19, i20, i21, i22, i23, i24, i25, i26, i27, i28, i29, i30, i31 + ); +} +NPY_FINLINE __m256i npyv__setr_epi16( + short i0, short i1, short i2, short i3, short i4, short i5, short i6, short i7, + short i8, short i9, short i10, short i11, short i12, short i13, short i14, short i15) +{ + return _mm256_setr_epi16(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, i14, i15); +} +NPY_FINLINE __m256i npyv__setr_epi32(int i0, int i1, int i2, int i3, int i4, int i5, int i6, int i7) +{ + return _mm256_setr_epi32(i0, i1, i2, i3, i4, i5, i6, i7); +} +NPY_FINLINE __m256i npyv__setr_epi64(npy_int64 i0, npy_int64 i1, npy_int64 i2, npy_int64 i3) +{ +#if defined(_MSC_VER) && defined(_M_IX86) + return _mm256_setr_epi32( + (int)i0, (int)(i0 >> 32), (int)i1, (int)(i1 >> 32), + (int)i2, (int)(i2 >> 32), (int)i3, (int)(i3 >> 32) + ); +#else + return _mm256_setr_epi64x(i0, i1, i2, i3); +#endif +} + +NPY_FINLINE __m256 npyv__setr_ps(float i0, float i1, float i2, float i3, float i4, float i5, + float i6, float i7) +{ + return _mm256_setr_ps(i0, i1, i2, i3, i4, i5, i6, i7); +} +NPY_FINLINE __m256d npyv__setr_pd(double i0, double i1, double i2, double i3) +{ + return _mm256_setr_pd(i0, i1, i2, i3); +} +#define npyv_setf_u8(FILL, ...) npyv__setr_epi8(NPYV__SET_FILL_32(char, FILL, __VA_ARGS__)) +#define npyv_setf_s8(FILL, ...) npyv__setr_epi8(NPYV__SET_FILL_32(char, FILL, __VA_ARGS__)) +#define npyv_setf_u16(FILL, ...) npyv__setr_epi16(NPYV__SET_FILL_16(short, FILL, __VA_ARGS__)) +#define npyv_setf_s16(FILL, ...) npyv__setr_epi16(NPYV__SET_FILL_16(short, FILL, __VA_ARGS__)) +#define npyv_setf_u32(FILL, ...) npyv__setr_epi32(NPYV__SET_FILL_8(int, FILL, __VA_ARGS__)) +#define npyv_setf_s32(FILL, ...) npyv__setr_epi32(NPYV__SET_FILL_8(int, FILL, __VA_ARGS__)) +#define npyv_setf_u64(FILL, ...) npyv__setr_epi64(NPYV__SET_FILL_4(npy_uint64, FILL, __VA_ARGS__)) +#define npyv_setf_s64(FILL, ...) npyv__setr_epi64(NPYV__SET_FILL_4(npy_int64, FILL, __VA_ARGS__)) +#define npyv_setf_f32(FILL, ...) npyv__setr_ps(NPYV__SET_FILL_8(float, FILL, __VA_ARGS__)) +#define npyv_setf_f64(FILL, ...) npyv__setr_pd(NPYV__SET_FILL_4(double, FILL, __VA_ARGS__)) + +// vector with specific values set to each lane and +// set zero to all remained lanes +#define npyv_set_u8(...) npyv_setf_u8(0, __VA_ARGS__) +#define npyv_set_s8(...) npyv_setf_s8(0, __VA_ARGS__) +#define npyv_set_u16(...) npyv_setf_u16(0, __VA_ARGS__) +#define npyv_set_s16(...) npyv_setf_s16(0, __VA_ARGS__) +#define npyv_set_u32(...) npyv_setf_u32(0, __VA_ARGS__) +#define npyv_set_s32(...) npyv_setf_s32(0, __VA_ARGS__) +#define npyv_set_u64(...) npyv_setf_u64(0, __VA_ARGS__) +#define npyv_set_s64(...) npyv_setf_s64(0, __VA_ARGS__) +#define npyv_set_f32(...) npyv_setf_f32(0, __VA_ARGS__) +#define npyv_set_f64(...) 
npyv_setf_f64(0, __VA_ARGS__) + +// Per lane select +#define npyv_select_u8(MASK, A, B) _mm256_blendv_epi8(B, A, MASK) +#define npyv_select_s8 npyv_select_u8 +#define npyv_select_u16 npyv_select_u8 +#define npyv_select_s16 npyv_select_u8 +#define npyv_select_u32 npyv_select_u8 +#define npyv_select_s32 npyv_select_u8 +#define npyv_select_u64 npyv_select_u8 +#define npyv_select_s64 npyv_select_u8 +#define npyv_select_f32(MASK, A, B) _mm256_blendv_ps(B, A, _mm256_castsi256_ps(MASK)) +#define npyv_select_f64(MASK, A, B) _mm256_blendv_pd(B, A, _mm256_castsi256_pd(MASK)) + +// extract the first vector's lane +#define npyv_extract0_u8(A) ((npy_uint8)_mm_cvtsi128_si32(_mm256_castsi256_si128(A))) +#define npyv_extract0_s8(A) ((npy_int8)_mm_cvtsi128_si32(_mm256_castsi256_si128(A))) +#define npyv_extract0_u16(A) ((npy_uint16)_mm_cvtsi128_si32(_mm256_castsi256_si128(A))) +#define npyv_extract0_s16(A) ((npy_int16)_mm_cvtsi128_si32(_mm256_castsi256_si128(A))) +#define npyv_extract0_u32(A) ((npy_uint32)_mm_cvtsi128_si32(_mm256_castsi256_si128(A))) +#define npyv_extract0_s32(A) ((npy_int32)_mm_cvtsi128_si32(_mm256_castsi256_si128(A))) +#define npyv_extract0_u64(A) ((npy_uint64)npyv128_cvtsi128_si64(_mm256_castsi256_si128(A))) +#define npyv_extract0_s64(A) ((npy_int64)npyv128_cvtsi128_si64(_mm256_castsi256_si128(A))) +#define npyv_extract0_f32(A) _mm_cvtss_f32(_mm256_castps256_ps128(A)) +#define npyv_extract0_f64(A) _mm_cvtsd_f64(_mm256_castpd256_pd128(A)) + +// Reinterpret +#define npyv_reinterpret_u8_u8(X) X +#define npyv_reinterpret_u8_s8(X) X +#define npyv_reinterpret_u8_u16(X) X +#define npyv_reinterpret_u8_s16(X) X +#define npyv_reinterpret_u8_u32(X) X +#define npyv_reinterpret_u8_s32(X) X +#define npyv_reinterpret_u8_u64(X) X +#define npyv_reinterpret_u8_s64(X) X +#define npyv_reinterpret_u8_f32 _mm256_castps_si256 +#define npyv_reinterpret_u8_f64 _mm256_castpd_si256 + +#define npyv_reinterpret_s8_s8(X) X +#define npyv_reinterpret_s8_u8(X) X +#define npyv_reinterpret_s8_u16(X) X +#define npyv_reinterpret_s8_s16(X) X +#define npyv_reinterpret_s8_u32(X) X +#define npyv_reinterpret_s8_s32(X) X +#define npyv_reinterpret_s8_u64(X) X +#define npyv_reinterpret_s8_s64(X) X +#define npyv_reinterpret_s8_f32 _mm256_castps_si256 +#define npyv_reinterpret_s8_f64 _mm256_castpd_si256 + +#define npyv_reinterpret_u16_u16(X) X +#define npyv_reinterpret_u16_u8(X) X +#define npyv_reinterpret_u16_s8(X) X +#define npyv_reinterpret_u16_s16(X) X +#define npyv_reinterpret_u16_u32(X) X +#define npyv_reinterpret_u16_s32(X) X +#define npyv_reinterpret_u16_u64(X) X +#define npyv_reinterpret_u16_s64(X) X +#define npyv_reinterpret_u16_f32 _mm256_castps_si256 +#define npyv_reinterpret_u16_f64 _mm256_castpd_si256 + +#define npyv_reinterpret_s16_s16(X) X +#define npyv_reinterpret_s16_u8(X) X +#define npyv_reinterpret_s16_s8(X) X +#define npyv_reinterpret_s16_u16(X) X +#define npyv_reinterpret_s16_u32(X) X +#define npyv_reinterpret_s16_s32(X) X +#define npyv_reinterpret_s16_u64(X) X +#define npyv_reinterpret_s16_s64(X) X +#define npyv_reinterpret_s16_f32 _mm256_castps_si256 +#define npyv_reinterpret_s16_f64 _mm256_castpd_si256 + +#define npyv_reinterpret_u32_u32(X) X +#define npyv_reinterpret_u32_u8(X) X +#define npyv_reinterpret_u32_s8(X) X +#define npyv_reinterpret_u32_u16(X) X +#define npyv_reinterpret_u32_s16(X) X +#define npyv_reinterpret_u32_s32(X) X +#define npyv_reinterpret_u32_u64(X) X +#define npyv_reinterpret_u32_s64(X) X +#define npyv_reinterpret_u32_f32 _mm256_castps_si256 +#define npyv_reinterpret_u32_f64 
_mm256_castpd_si256 + +#define npyv_reinterpret_s32_s32(X) X +#define npyv_reinterpret_s32_u8(X) X +#define npyv_reinterpret_s32_s8(X) X +#define npyv_reinterpret_s32_u16(X) X +#define npyv_reinterpret_s32_s16(X) X +#define npyv_reinterpret_s32_u32(X) X +#define npyv_reinterpret_s32_u64(X) X +#define npyv_reinterpret_s32_s64(X) X +#define npyv_reinterpret_s32_f32 _mm256_castps_si256 +#define npyv_reinterpret_s32_f64 _mm256_castpd_si256 + +#define npyv_reinterpret_u64_u64(X) X +#define npyv_reinterpret_u64_u8(X) X +#define npyv_reinterpret_u64_s8(X) X +#define npyv_reinterpret_u64_u16(X) X +#define npyv_reinterpret_u64_s16(X) X +#define npyv_reinterpret_u64_u32(X) X +#define npyv_reinterpret_u64_s32(X) X +#define npyv_reinterpret_u64_s64(X) X +#define npyv_reinterpret_u64_f32 _mm256_castps_si256 +#define npyv_reinterpret_u64_f64 _mm256_castpd_si256 + +#define npyv_reinterpret_s64_s64(X) X +#define npyv_reinterpret_s64_u8(X) X +#define npyv_reinterpret_s64_s8(X) X +#define npyv_reinterpret_s64_u16(X) X +#define npyv_reinterpret_s64_s16(X) X +#define npyv_reinterpret_s64_u32(X) X +#define npyv_reinterpret_s64_s32(X) X +#define npyv_reinterpret_s64_u64(X) X +#define npyv_reinterpret_s64_f32 _mm256_castps_si256 +#define npyv_reinterpret_s64_f64 _mm256_castpd_si256 + +#define npyv_reinterpret_f32_f32(X) X +#define npyv_reinterpret_f32_u8 _mm256_castsi256_ps +#define npyv_reinterpret_f32_s8 _mm256_castsi256_ps +#define npyv_reinterpret_f32_u16 _mm256_castsi256_ps +#define npyv_reinterpret_f32_s16 _mm256_castsi256_ps +#define npyv_reinterpret_f32_u32 _mm256_castsi256_ps +#define npyv_reinterpret_f32_s32 _mm256_castsi256_ps +#define npyv_reinterpret_f32_u64 _mm256_castsi256_ps +#define npyv_reinterpret_f32_s64 _mm256_castsi256_ps +#define npyv_reinterpret_f32_f64 _mm256_castpd_ps + +#define npyv_reinterpret_f64_f64(X) X +#define npyv_reinterpret_f64_u8 _mm256_castsi256_pd +#define npyv_reinterpret_f64_s8 _mm256_castsi256_pd +#define npyv_reinterpret_f64_u16 _mm256_castsi256_pd +#define npyv_reinterpret_f64_s16 _mm256_castsi256_pd +#define npyv_reinterpret_f64_u32 _mm256_castsi256_pd +#define npyv_reinterpret_f64_s32 _mm256_castsi256_pd +#define npyv_reinterpret_f64_u64 _mm256_castsi256_pd +#define npyv_reinterpret_f64_s64 _mm256_castsi256_pd +#define npyv_reinterpret_f64_f32 _mm256_castps_pd + +#define npyv_cleanup _mm256_zeroall + +#endif // _NPY_SIMD_SSE_MISC_H diff --git a/mkl_umath/src/npyv/avx2/operators.h b/mkl_umath/src/npyv/avx2/operators.h new file mode 100644 index 00000000..7b9b6a34 --- /dev/null +++ b/mkl_umath/src/npyv/avx2/operators.h @@ -0,0 +1,282 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_AVX2_OPERATORS_H +#define _NPY_SIMD_AVX2_OPERATORS_H + +/*************************** + * Shifting + ***************************/ + +// left +#define npyv_shl_u16(A, C) _mm256_sll_epi16(A, _mm_cvtsi32_si128(C)) +#define npyv_shl_s16(A, C) _mm256_sll_epi16(A, _mm_cvtsi32_si128(C)) +#define npyv_shl_u32(A, C) _mm256_sll_epi32(A, _mm_cvtsi32_si128(C)) +#define npyv_shl_s32(A, C) _mm256_sll_epi32(A, _mm_cvtsi32_si128(C)) +#define npyv_shl_u64(A, C) _mm256_sll_epi64(A, _mm_cvtsi32_si128(C)) +#define npyv_shl_s64(A, C) _mm256_sll_epi64(A, _mm_cvtsi32_si128(C)) + +// left by an immediate constant +#define npyv_shli_u16 _mm256_slli_epi16 +#define npyv_shli_s16 _mm256_slli_epi16 +#define npyv_shli_u32 _mm256_slli_epi32 +#define npyv_shli_s32 _mm256_slli_epi32 +#define npyv_shli_u64 _mm256_slli_epi64 +#define npyv_shli_s64 _mm256_slli_epi64 + +// right +#define 
npyv_shr_u16(A, C) _mm256_srl_epi16(A, _mm_cvtsi32_si128(C)) +#define npyv_shr_s16(A, C) _mm256_sra_epi16(A, _mm_cvtsi32_si128(C)) +#define npyv_shr_u32(A, C) _mm256_srl_epi32(A, _mm_cvtsi32_si128(C)) +#define npyv_shr_s32(A, C) _mm256_sra_epi32(A, _mm_cvtsi32_si128(C)) +#define npyv_shr_u64(A, C) _mm256_srl_epi64(A, _mm_cvtsi32_si128(C)) +NPY_FINLINE __m256i npyv_shr_s64(__m256i a, int c) +{ + const __m256i sbit = _mm256_set1_epi64x(0x8000000000000000); + const __m128i c64 = _mm_cvtsi32_si128(c); + __m256i r = _mm256_srl_epi64(_mm256_add_epi64(a, sbit), c64); + return _mm256_sub_epi64(r, _mm256_srl_epi64(sbit, c64)); +} + +// right by an immediate constant +#define npyv_shri_u16 _mm256_srli_epi16 +#define npyv_shri_s16 _mm256_srai_epi16 +#define npyv_shri_u32 _mm256_srli_epi32 +#define npyv_shri_s32 _mm256_srai_epi32 +#define npyv_shri_u64 _mm256_srli_epi64 +#define npyv_shri_s64 npyv_shr_s64 + +/*************************** + * Logical + ***************************/ +// AND +#define npyv_and_u8 _mm256_and_si256 +#define npyv_and_s8 _mm256_and_si256 +#define npyv_and_u16 _mm256_and_si256 +#define npyv_and_s16 _mm256_and_si256 +#define npyv_and_u32 _mm256_and_si256 +#define npyv_and_s32 _mm256_and_si256 +#define npyv_and_u64 _mm256_and_si256 +#define npyv_and_s64 _mm256_and_si256 +#define npyv_and_f32 _mm256_and_ps +#define npyv_and_f64 _mm256_and_pd +#define npyv_and_b8 _mm256_and_si256 +#define npyv_and_b16 _mm256_and_si256 +#define npyv_and_b32 _mm256_and_si256 +#define npyv_and_b64 _mm256_and_si256 + +// OR +#define npyv_or_u8 _mm256_or_si256 +#define npyv_or_s8 _mm256_or_si256 +#define npyv_or_u16 _mm256_or_si256 +#define npyv_or_s16 _mm256_or_si256 +#define npyv_or_u32 _mm256_or_si256 +#define npyv_or_s32 _mm256_or_si256 +#define npyv_or_u64 _mm256_or_si256 +#define npyv_or_s64 _mm256_or_si256 +#define npyv_or_f32 _mm256_or_ps +#define npyv_or_f64 _mm256_or_pd +#define npyv_or_b8 _mm256_or_si256 +#define npyv_or_b16 _mm256_or_si256 +#define npyv_or_b32 _mm256_or_si256 +#define npyv_or_b64 _mm256_or_si256 + +// XOR +#define npyv_xor_u8 _mm256_xor_si256 +#define npyv_xor_s8 _mm256_xor_si256 +#define npyv_xor_u16 _mm256_xor_si256 +#define npyv_xor_s16 _mm256_xor_si256 +#define npyv_xor_u32 _mm256_xor_si256 +#define npyv_xor_s32 _mm256_xor_si256 +#define npyv_xor_u64 _mm256_xor_si256 +#define npyv_xor_s64 _mm256_xor_si256 +#define npyv_xor_f32 _mm256_xor_ps +#define npyv_xor_f64 _mm256_xor_pd +#define npyv_xor_b8 _mm256_xor_si256 +#define npyv_xor_b16 _mm256_xor_si256 +#define npyv_xor_b32 _mm256_xor_si256 +#define npyv_xor_b64 _mm256_xor_si256 + +// NOT +#define npyv_not_u8(A) _mm256_xor_si256(A, _mm256_set1_epi32(-1)) +#define npyv_not_s8 npyv_not_u8 +#define npyv_not_u16 npyv_not_u8 +#define npyv_not_s16 npyv_not_u8 +#define npyv_not_u32 npyv_not_u8 +#define npyv_not_s32 npyv_not_u8 +#define npyv_not_u64 npyv_not_u8 +#define npyv_not_s64 npyv_not_u8 +#define npyv_not_f32(A) _mm256_xor_ps(A, _mm256_castsi256_ps(_mm256_set1_epi32(-1))) +#define npyv_not_f64(A) _mm256_xor_pd(A, _mm256_castsi256_pd(_mm256_set1_epi32(-1))) +#define npyv_not_b8 npyv_not_u8 +#define npyv_not_b16 npyv_not_u8 +#define npyv_not_b32 npyv_not_u8 +#define npyv_not_b64 npyv_not_u8 + +// ANDC, ORC and XNOR +#define npyv_andc_u8(A, B) _mm256_andnot_si256(B, A) +#define npyv_andc_b8(A, B) _mm256_andnot_si256(B, A) +#define npyv_orc_b8(A, B) npyv_or_b8(npyv_not_b8(B), A) +#define npyv_xnor_b8 _mm256_cmpeq_epi8 + +/*************************** + * Comparison + ***************************/ + +// int Equal +#define 
npyv_cmpeq_u8 _mm256_cmpeq_epi8 +#define npyv_cmpeq_s8 _mm256_cmpeq_epi8 +#define npyv_cmpeq_u16 _mm256_cmpeq_epi16 +#define npyv_cmpeq_s16 _mm256_cmpeq_epi16 +#define npyv_cmpeq_u32 _mm256_cmpeq_epi32 +#define npyv_cmpeq_s32 _mm256_cmpeq_epi32 +#define npyv_cmpeq_u64 _mm256_cmpeq_epi64 +#define npyv_cmpeq_s64 _mm256_cmpeq_epi64 + +// int Not Equal +#define npyv_cmpneq_u8(A, B) npyv_not_u8(_mm256_cmpeq_epi8(A, B)) +#define npyv_cmpneq_s8 npyv_cmpneq_u8 +#define npyv_cmpneq_u16(A, B) npyv_not_u16(_mm256_cmpeq_epi16(A, B)) +#define npyv_cmpneq_s16 npyv_cmpneq_u16 +#define npyv_cmpneq_u32(A, B) npyv_not_u32(_mm256_cmpeq_epi32(A, B)) +#define npyv_cmpneq_s32 npyv_cmpneq_u32 +#define npyv_cmpneq_u64(A, B) npyv_not_u64(_mm256_cmpeq_epi64(A, B)) +#define npyv_cmpneq_s64 npyv_cmpneq_u64 + +// signed greater than +#define npyv_cmpgt_s8 _mm256_cmpgt_epi8 +#define npyv_cmpgt_s16 _mm256_cmpgt_epi16 +#define npyv_cmpgt_s32 _mm256_cmpgt_epi32 +#define npyv_cmpgt_s64 _mm256_cmpgt_epi64 + +// signed greater than or equal +#define npyv_cmpge_s8(A, B) npyv_not_s8(_mm256_cmpgt_epi8(B, A)) +#define npyv_cmpge_s16(A, B) npyv_not_s16(_mm256_cmpgt_epi16(B, A)) +#define npyv_cmpge_s32(A, B) npyv_not_s32(_mm256_cmpgt_epi32(B, A)) +#define npyv_cmpge_s64(A, B) npyv_not_s64(_mm256_cmpgt_epi64(B, A)) + +// unsigned greater than +#define NPYV_IMPL_AVX2_UNSIGNED_GT(LEN, SIGN) \ + NPY_FINLINE __m256i npyv_cmpgt_u##LEN(__m256i a, __m256i b) \ + { \ + const __m256i sbit = _mm256_set1_epi32(SIGN); \ + return _mm256_cmpgt_epi##LEN( \ + _mm256_xor_si256(a, sbit), _mm256_xor_si256(b, sbit) \ + ); \ + } + +NPYV_IMPL_AVX2_UNSIGNED_GT(8, 0x80808080) +NPYV_IMPL_AVX2_UNSIGNED_GT(16, 0x80008000) +NPYV_IMPL_AVX2_UNSIGNED_GT(32, 0x80000000) + +NPY_FINLINE __m256i npyv_cmpgt_u64(__m256i a, __m256i b) +{ + const __m256i sbit = _mm256_set1_epi64x(0x8000000000000000); + return _mm256_cmpgt_epi64(_mm256_xor_si256(a, sbit), _mm256_xor_si256(b, sbit)); +} + +// unsigned greater than or equal +NPY_FINLINE __m256i npyv_cmpge_u8(__m256i a, __m256i b) +{ return _mm256_cmpeq_epi8(a, _mm256_max_epu8(a, b)); } +NPY_FINLINE __m256i npyv_cmpge_u16(__m256i a, __m256i b) +{ return _mm256_cmpeq_epi16(a, _mm256_max_epu16(a, b)); } +NPY_FINLINE __m256i npyv_cmpge_u32(__m256i a, __m256i b) +{ return _mm256_cmpeq_epi32(a, _mm256_max_epu32(a, b)); } +#define npyv_cmpge_u64(A, B) npyv_not_u64(npyv_cmpgt_u64(B, A)) + +// less than +#define npyv_cmplt_u8(A, B) npyv_cmpgt_u8(B, A) +#define npyv_cmplt_s8(A, B) npyv_cmpgt_s8(B, A) +#define npyv_cmplt_u16(A, B) npyv_cmpgt_u16(B, A) +#define npyv_cmplt_s16(A, B) npyv_cmpgt_s16(B, A) +#define npyv_cmplt_u32(A, B) npyv_cmpgt_u32(B, A) +#define npyv_cmplt_s32(A, B) npyv_cmpgt_s32(B, A) +#define npyv_cmplt_u64(A, B) npyv_cmpgt_u64(B, A) +#define npyv_cmplt_s64(A, B) npyv_cmpgt_s64(B, A) + +// less than or equal +#define npyv_cmple_u8(A, B) npyv_cmpge_u8(B, A) +#define npyv_cmple_s8(A, B) npyv_cmpge_s8(B, A) +#define npyv_cmple_u16(A, B) npyv_cmpge_u16(B, A) +#define npyv_cmple_s16(A, B) npyv_cmpge_s16(B, A) +#define npyv_cmple_u32(A, B) npyv_cmpge_u32(B, A) +#define npyv_cmple_s32(A, B) npyv_cmpge_s32(B, A) +#define npyv_cmple_u64(A, B) npyv_cmpge_u64(B, A) +#define npyv_cmple_s64(A, B) npyv_cmpge_s64(B, A) + +// precision comparison (ordered) +#define npyv_cmpeq_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_EQ_OQ)) +#define npyv_cmpeq_f64(A, B) _mm256_castpd_si256(_mm256_cmp_pd(A, B, _CMP_EQ_OQ)) +#define npyv_cmpneq_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_NEQ_UQ)) +#define npyv_cmpneq_f64(A, 
B) _mm256_castpd_si256(_mm256_cmp_pd(A, B, _CMP_NEQ_UQ)) +#define npyv_cmplt_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_LT_OQ)) +#define npyv_cmplt_f64(A, B) _mm256_castpd_si256(_mm256_cmp_pd(A, B, _CMP_LT_OQ)) +#define npyv_cmple_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_LE_OQ)) +#define npyv_cmple_f64(A, B) _mm256_castpd_si256(_mm256_cmp_pd(A, B, _CMP_LE_OQ)) +#define npyv_cmpgt_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_GT_OQ)) +#define npyv_cmpgt_f64(A, B) _mm256_castpd_si256(_mm256_cmp_pd(A, B, _CMP_GT_OQ)) +#define npyv_cmpge_f32(A, B) _mm256_castps_si256(_mm256_cmp_ps(A, B, _CMP_GE_OQ)) +#define npyv_cmpge_f64(A, B) _mm256_castpd_si256(_mm256_cmp_pd(A, B, _CMP_GE_OQ)) + +// check special cases +NPY_FINLINE npyv_b32 npyv_notnan_f32(npyv_f32 a) +{ return _mm256_castps_si256(_mm256_cmp_ps(a, a, _CMP_ORD_Q)); } +NPY_FINLINE npyv_b64 npyv_notnan_f64(npyv_f64 a) +{ return _mm256_castpd_si256(_mm256_cmp_pd(a, a, _CMP_ORD_Q)); } + +// Test cross all vector lanes +// any: returns true if any of the elements is not equal to zero +// all: returns true if all elements are not equal to zero +#define NPYV_IMPL_AVX2_ANYALL(SFX) \ + NPY_FINLINE bool npyv_any_##SFX(npyv_##SFX a) \ + { return _mm256_movemask_epi8(a) != 0; } \ + NPY_FINLINE bool npyv_all_##SFX(npyv_##SFX a) \ + { return _mm256_movemask_epi8(a) == -1; } +NPYV_IMPL_AVX2_ANYALL(b8) +NPYV_IMPL_AVX2_ANYALL(b16) +NPYV_IMPL_AVX2_ANYALL(b32) +NPYV_IMPL_AVX2_ANYALL(b64) +#undef NPYV_IMPL_AVX2_ANYALL + +#define NPYV_IMPL_AVX2_ANYALL(SFX) \ + NPY_FINLINE bool npyv_any_##SFX(npyv_##SFX a) \ + { \ + return _mm256_movemask_epi8( \ + npyv_cmpeq_##SFX(a, npyv_zero_##SFX()) \ + ) != -1; \ + } \ + NPY_FINLINE bool npyv_all_##SFX(npyv_##SFX a) \ + { \ + return _mm256_movemask_epi8( \ + npyv_cmpeq_##SFX(a, npyv_zero_##SFX()) \ + ) == 0; \ + } +NPYV_IMPL_AVX2_ANYALL(u8) +NPYV_IMPL_AVX2_ANYALL(s8) +NPYV_IMPL_AVX2_ANYALL(u16) +NPYV_IMPL_AVX2_ANYALL(s16) +NPYV_IMPL_AVX2_ANYALL(u32) +NPYV_IMPL_AVX2_ANYALL(s32) +NPYV_IMPL_AVX2_ANYALL(u64) +NPYV_IMPL_AVX2_ANYALL(s64) +#undef NPYV_IMPL_AVX2_ANYALL + +#define NPYV_IMPL_AVX2_ANYALL(SFX, XSFX, MASK) \ + NPY_FINLINE bool npyv_any_##SFX(npyv_##SFX a) \ + { \ + return _mm256_movemask_##XSFX( \ + _mm256_cmp_##XSFX(a, npyv_zero_##SFX(), _CMP_EQ_OQ) \ + ) != MASK; \ + } \ + NPY_FINLINE bool npyv_all_##SFX(npyv_##SFX a) \ + { \ + return _mm256_movemask_##XSFX( \ + _mm256_cmp_##XSFX(a, npyv_zero_##SFX(), _CMP_EQ_OQ) \ + ) == 0; \ + } +NPYV_IMPL_AVX2_ANYALL(f32, ps, 0xff) +NPYV_IMPL_AVX2_ANYALL(f64, pd, 0xf) +#undef NPYV_IMPL_AVX2_ANYALL + +#endif // _NPY_SIMD_AVX2_OPERATORS_H diff --git a/mkl_umath/src/npyv/avx2/reorder.h b/mkl_umath/src/npyv/avx2/reorder.h new file mode 100644 index 00000000..9ebe0e7f --- /dev/null +++ b/mkl_umath/src/npyv/avx2/reorder.h @@ -0,0 +1,216 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_AVX2_REORDER_H +#define _NPY_SIMD_AVX2_REORDER_H + +// combine lower part of two vectors +#define npyv_combinel_u8(A, B) _mm256_permute2x128_si256(A, B, 0x20) +#define npyv_combinel_s8 npyv_combinel_u8 +#define npyv_combinel_u16 npyv_combinel_u8 +#define npyv_combinel_s16 npyv_combinel_u8 +#define npyv_combinel_u32 npyv_combinel_u8 +#define npyv_combinel_s32 npyv_combinel_u8 +#define npyv_combinel_u64 npyv_combinel_u8 +#define npyv_combinel_s64 npyv_combinel_u8 +#define npyv_combinel_f32(A, B) _mm256_permute2f128_ps(A, B, 0x20) +#define npyv_combinel_f64(A, B) _mm256_permute2f128_pd(A, B, 0x20) + +// combine higher part of two vectors 
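+// e.g. (illustrative) for u32 vectors A = {a0..a7}, B = {b0..b7}:
+//   npyv_combinel_u32(A, B) -> {a0, a1, a2, a3, b0, b1, b2, b3}
+//   npyv_combineh_u32(A, B) -> {a4, a5, a6, a7, b4, b5, b6, b7}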
+#define npyv_combineh_u8(A, B) _mm256_permute2x128_si256(A, B, 0x31) +#define npyv_combineh_s8 npyv_combineh_u8 +#define npyv_combineh_u16 npyv_combineh_u8 +#define npyv_combineh_s16 npyv_combineh_u8 +#define npyv_combineh_u32 npyv_combineh_u8 +#define npyv_combineh_s32 npyv_combineh_u8 +#define npyv_combineh_u64 npyv_combineh_u8 +#define npyv_combineh_s64 npyv_combineh_u8 +#define npyv_combineh_f32(A, B) _mm256_permute2f128_ps(A, B, 0x31) +#define npyv_combineh_f64(A, B) _mm256_permute2f128_pd(A, B, 0x31) + +// combine two vectors from lower and higher parts of two other vectors +NPY_FINLINE npyv_m256ix2 npyv__combine(__m256i a, __m256i b) +{ + npyv_m256ix2 r; + __m256i a1b0 = _mm256_permute2x128_si256(a, b, 0x21); + r.val[0] = _mm256_blend_epi32(a, a1b0, 0xF0); + r.val[1] = _mm256_blend_epi32(b, a1b0, 0xF); + return r; +} +NPY_FINLINE npyv_f32x2 npyv_combine_f32(__m256 a, __m256 b) +{ + npyv_f32x2 r; + __m256 a1b0 = _mm256_permute2f128_ps(a, b, 0x21); + r.val[0] = _mm256_blend_ps(a, a1b0, 0xF0); + r.val[1] = _mm256_blend_ps(b, a1b0, 0xF); + return r; +} +NPY_FINLINE npyv_f64x2 npyv_combine_f64(__m256d a, __m256d b) +{ + npyv_f64x2 r; + __m256d a1b0 = _mm256_permute2f128_pd(a, b, 0x21); + r.val[0] = _mm256_blend_pd(a, a1b0, 0xC); + r.val[1] = _mm256_blend_pd(b, a1b0, 0x3); + return r; +} +#define npyv_combine_u8 npyv__combine +#define npyv_combine_s8 npyv__combine +#define npyv_combine_u16 npyv__combine +#define npyv_combine_s16 npyv__combine +#define npyv_combine_u32 npyv__combine +#define npyv_combine_s32 npyv__combine +#define npyv_combine_u64 npyv__combine +#define npyv_combine_s64 npyv__combine + +// interleave two vectors +#define NPYV_IMPL_AVX2_ZIP_U(T_VEC, LEN) \ + NPY_FINLINE T_VEC##x2 npyv_zip_u##LEN(T_VEC a, T_VEC b) \ + { \ + __m256i ab0 = _mm256_unpacklo_epi##LEN(a, b); \ + __m256i ab1 = _mm256_unpackhi_epi##LEN(a, b); \ + return npyv__combine(ab0, ab1); \ + } + +NPYV_IMPL_AVX2_ZIP_U(npyv_u8, 8) +NPYV_IMPL_AVX2_ZIP_U(npyv_u16, 16) +NPYV_IMPL_AVX2_ZIP_U(npyv_u32, 32) +NPYV_IMPL_AVX2_ZIP_U(npyv_u64, 64) +#define npyv_zip_s8 npyv_zip_u8 +#define npyv_zip_s16 npyv_zip_u16 +#define npyv_zip_s32 npyv_zip_u32 +#define npyv_zip_s64 npyv_zip_u64 + +NPY_FINLINE npyv_f32x2 npyv_zip_f32(__m256 a, __m256 b) +{ + __m256 ab0 = _mm256_unpacklo_ps(a, b); + __m256 ab1 = _mm256_unpackhi_ps(a, b); + return npyv_combine_f32(ab0, ab1); +} +NPY_FINLINE npyv_f64x2 npyv_zip_f64(__m256d a, __m256d b) +{ + __m256d ab0 = _mm256_unpacklo_pd(a, b); + __m256d ab1 = _mm256_unpackhi_pd(a, b); + return npyv_combine_f64(ab0, ab1); +} + +// deinterleave two vectors +NPY_FINLINE npyv_u8x2 npyv_unzip_u8(npyv_u8 ab0, npyv_u8 ab1) +{ + const __m256i idx = _mm256_setr_epi8( + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15, + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 + ); + __m256i ab_03 = _mm256_shuffle_epi8(ab0, idx); + __m256i ab_12 = _mm256_shuffle_epi8(ab1, idx); + npyv_u8x2 ab_lh = npyv_combine_u8(ab_03, ab_12); + npyv_u8x2 r; + r.val[0] = _mm256_unpacklo_epi64(ab_lh.val[0], ab_lh.val[1]); + r.val[1] = _mm256_unpackhi_epi64(ab_lh.val[0], ab_lh.val[1]); + return r; +} +#define npyv_unzip_s8 npyv_unzip_u8 + +NPY_FINLINE npyv_u16x2 npyv_unzip_u16(npyv_u16 ab0, npyv_u16 ab1) +{ + const __m256i idx = _mm256_setr_epi8( + 0,1, 4,5, 8,9, 12,13, 2,3, 6,7, 10,11, 14,15, + 0,1, 4,5, 8,9, 12,13, 2,3, 6,7, 10,11, 14,15 + ); + __m256i ab_03 = _mm256_shuffle_epi8(ab0, idx); + __m256i ab_12 = _mm256_shuffle_epi8(ab1, idx); + npyv_u16x2 ab_lh = npyv_combine_u16(ab_03, ab_12); + npyv_u16x2 r; + r.val[0] = 
_mm256_unpacklo_epi64(ab_lh.val[0], ab_lh.val[1]); + r.val[1] = _mm256_unpackhi_epi64(ab_lh.val[0], ab_lh.val[1]); + return r; +} +#define npyv_unzip_s16 npyv_unzip_u16 + +NPY_FINLINE npyv_u32x2 npyv_unzip_u32(npyv_u32 ab0, npyv_u32 ab1) +{ + const __m256i idx = npyv_set_u32(0, 2, 4, 6, 1, 3, 5, 7); + __m256i abl = _mm256_permutevar8x32_epi32(ab0, idx); + __m256i abh = _mm256_permutevar8x32_epi32(ab1, idx); + return npyv_combine_u32(abl, abh); +} +#define npyv_unzip_s32 npyv_unzip_u32 + +NPY_FINLINE npyv_u64x2 npyv_unzip_u64(npyv_u64 ab0, npyv_u64 ab1) +{ + npyv_u64x2 ab_lh = npyv_combine_u64(ab0, ab1); + npyv_u64x2 r; + r.val[0] = _mm256_unpacklo_epi64(ab_lh.val[0], ab_lh.val[1]); + r.val[1] = _mm256_unpackhi_epi64(ab_lh.val[0], ab_lh.val[1]); + return r; +} +#define npyv_unzip_s64 npyv_unzip_u64 + +NPY_FINLINE npyv_f32x2 npyv_unzip_f32(npyv_f32 ab0, npyv_f32 ab1) +{ + const __m256i idx = npyv_set_u32(0, 2, 4, 6, 1, 3, 5, 7); + __m256 abl = _mm256_permutevar8x32_ps(ab0, idx); + __m256 abh = _mm256_permutevar8x32_ps(ab1, idx); + return npyv_combine_f32(abl, abh); +} + +NPY_FINLINE npyv_f64x2 npyv_unzip_f64(npyv_f64 ab0, npyv_f64 ab1) +{ + npyv_f64x2 ab_lh = npyv_combine_f64(ab0, ab1); + npyv_f64x2 r; + r.val[0] = _mm256_unpacklo_pd(ab_lh.val[0], ab_lh.val[1]); + r.val[1] = _mm256_unpackhi_pd(ab_lh.val[0], ab_lh.val[1]); + return r; +} + +// Reverse elements of each 64-bit lane +NPY_FINLINE npyv_u8 npyv_rev64_u8(npyv_u8 a) +{ + const __m256i idx = _mm256_setr_epi8( + 7, 6, 5, 4, 3, 2, 1, 0,/*64*/15, 14, 13, 12, 11, 10, 9, 8, + 7, 6, 5, 4, 3, 2, 1, 0,/*64*/15, 14, 13, 12, 11, 10, 9, 8 + ); + return _mm256_shuffle_epi8(a, idx); +} +#define npyv_rev64_s8 npyv_rev64_u8 + +NPY_FINLINE npyv_u16 npyv_rev64_u16(npyv_u16 a) +{ + const __m256i idx = _mm256_setr_epi8( + 6, 7, 4, 5, 2, 3, 0, 1,/*64*/14, 15, 12, 13, 10, 11, 8, 9, + 6, 7, 4, 5, 2, 3, 0, 1,/*64*/14, 15, 12, 13, 10, 11, 8, 9 + ); + return _mm256_shuffle_epi8(a, idx); +} +#define npyv_rev64_s16 npyv_rev64_u16 + +NPY_FINLINE npyv_u32 npyv_rev64_u32(npyv_u32 a) +{ + return _mm256_shuffle_epi32(a, _MM_SHUFFLE(2, 3, 0, 1)); +} +#define npyv_rev64_s32 npyv_rev64_u32 + +NPY_FINLINE npyv_f32 npyv_rev64_f32(npyv_f32 a) +{ + return _mm256_shuffle_ps(a, a, _MM_SHUFFLE(2, 3, 0, 1)); +} + +// Permuting the elements of each 128-bit lane by immediate index for +// each element. 
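+// e.g. (illustrative) npyv_permi128_u32(A, 3, 2, 1, 0) reverses the four u32
+// elements within each 128-bit half of A, low and high halves independently.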
+#define npyv_permi128_u32(A, E0, E1, E2, E3) \ + _mm256_shuffle_epi32(A, _MM_SHUFFLE(E3, E2, E1, E0)) + +#define npyv_permi128_s32 npyv_permi128_u32 + +#define npyv_permi128_u64(A, E0, E1) \ + _mm256_shuffle_epi32(A, _MM_SHUFFLE(((E1)<<1)+1, ((E1)<<1), ((E0)<<1)+1, ((E0)<<1))) + +#define npyv_permi128_s64 npyv_permi128_u64 + +#define npyv_permi128_f32(A, E0, E1, E2, E3) \ + _mm256_permute_ps(A, _MM_SHUFFLE(E3, E2, E1, E0)) + +#define npyv_permi128_f64(A, E0, E1) \ + _mm256_permute_pd(A, ((E1)<<3) | ((E0)<<2) | ((E1)<<1) | (E0)) + +#endif // _NPY_SIMD_AVX2_REORDER_H diff --git a/mkl_umath/src/npyv/avx2/utils.h b/mkl_umath/src/npyv/avx2/utils.h new file mode 100644 index 00000000..24f1af5d --- /dev/null +++ b/mkl_umath/src/npyv/avx2/utils.h @@ -0,0 +1,21 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_AVX2_UTILS_H +#define _NPY_SIMD_AVX2_UTILS_H + +#define npyv256_shuffle_odd(A) _mm256_permute4x64_epi64(A, _MM_SHUFFLE(3, 1, 2, 0)) +#define npyv256_shuffle_odd_ps(A) _mm256_castsi256_ps(npyv256_shuffle_odd(_mm256_castps_si256(A))) +#define npyv256_shuffle_odd_pd(A) _mm256_permute4x64_pd(A, _MM_SHUFFLE(3, 1, 2, 0)) + +NPY_FINLINE __m256i npyv256_mul_u8(__m256i a, __m256i b) +{ + const __m256i mask = _mm256_set1_epi32(0xFF00FF00); + __m256i even = _mm256_mullo_epi16(a, b); + __m256i odd = _mm256_mullo_epi16(_mm256_srai_epi16(a, 8), _mm256_srai_epi16(b, 8)); + odd = _mm256_slli_epi16(odd, 8); + return _mm256_blendv_epi8(even, odd, mask); +} + +#endif // _NPY_SIMD_AVX2_UTILS_H diff --git a/mkl_umath/src/npyv/avx512/arithmetic.h b/mkl_umath/src/npyv/avx512/arithmetic.h new file mode 100644 index 00000000..a63da87d --- /dev/null +++ b/mkl_umath/src/npyv/avx512/arithmetic.h @@ -0,0 +1,446 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_AVX512_ARITHMETIC_H +#define _NPY_SIMD_AVX512_ARITHMETIC_H + +#include "../avx2/utils.h" +#include "../sse/utils.h" +/*************************** + * Addition + ***************************/ +// non-saturated +#ifdef NPY_HAVE_AVX512BW + #define npyv_add_u8 _mm512_add_epi8 + #define npyv_add_u16 _mm512_add_epi16 +#else + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_add_u8, _mm256_add_epi8) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_add_u16, _mm256_add_epi16) +#endif +#define npyv_add_s8 npyv_add_u8 +#define npyv_add_s16 npyv_add_u16 +#define npyv_add_u32 _mm512_add_epi32 +#define npyv_add_s32 _mm512_add_epi32 +#define npyv_add_u64 _mm512_add_epi64 +#define npyv_add_s64 _mm512_add_epi64 +#define npyv_add_f32 _mm512_add_ps +#define npyv_add_f64 _mm512_add_pd + +// saturated +#ifdef NPY_HAVE_AVX512BW + #define npyv_adds_u8 _mm512_adds_epu8 + #define npyv_adds_s8 _mm512_adds_epi8 + #define npyv_adds_u16 _mm512_adds_epu16 + #define npyv_adds_s16 _mm512_adds_epi16 +#else + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_adds_u8, _mm256_adds_epu8) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_adds_s8, _mm256_adds_epi8) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_adds_u16, _mm256_adds_epu16) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_adds_s16, _mm256_adds_epi16) +#endif +// TODO: rest, after implement Packs intrins + +/*************************** + * Subtraction + ***************************/ +// non-saturated +#ifdef NPY_HAVE_AVX512BW + #define npyv_sub_u8 _mm512_sub_epi8 + #define npyv_sub_u16 _mm512_sub_epi16 +#else + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_sub_u8, _mm256_sub_epi8) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_sub_u16, _mm256_sub_epi16) +#endif +#define npyv_sub_s8 npyv_sub_u8 +#define npyv_sub_s16 npyv_sub_u16 +#define npyv_sub_u32 
_mm512_sub_epi32 +#define npyv_sub_s32 _mm512_sub_epi32 +#define npyv_sub_u64 _mm512_sub_epi64 +#define npyv_sub_s64 _mm512_sub_epi64 +#define npyv_sub_f32 _mm512_sub_ps +#define npyv_sub_f64 _mm512_sub_pd + +// saturated +#ifdef NPY_HAVE_AVX512BW + #define npyv_subs_u8 _mm512_subs_epu8 + #define npyv_subs_s8 _mm512_subs_epi8 + #define npyv_subs_u16 _mm512_subs_epu16 + #define npyv_subs_s16 _mm512_subs_epi16 +#else + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_subs_u8, _mm256_subs_epu8) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_subs_s8, _mm256_subs_epi8) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_subs_u16, _mm256_subs_epu16) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_subs_s16, _mm256_subs_epi16) +#endif +// TODO: rest, after implement Packs intrins + +/*************************** + * Multiplication + ***************************/ +// non-saturated +#ifdef NPY_HAVE_AVX512BW +NPY_FINLINE __m512i npyv_mul_u8(__m512i a, __m512i b) +{ + __m512i even = _mm512_mullo_epi16(a, b); + __m512i odd = _mm512_mullo_epi16(_mm512_srai_epi16(a, 8), _mm512_srai_epi16(b, 8)); + odd = _mm512_slli_epi16(odd, 8); + return _mm512_mask_blend_epi8(0xAAAAAAAAAAAAAAAA, even, odd); +} +#else + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_mul_u8, npyv256_mul_u8) +#endif + +#ifdef NPY_HAVE_AVX512BW + #define npyv_mul_u16 _mm512_mullo_epi16 +#else + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_mul_u16, _mm256_mullo_epi16) +#endif +#define npyv_mul_s8 npyv_mul_u8 +#define npyv_mul_s16 npyv_mul_u16 +#define npyv_mul_u32 _mm512_mullo_epi32 +#define npyv_mul_s32 _mm512_mullo_epi32 +#define npyv_mul_f32 _mm512_mul_ps +#define npyv_mul_f64 _mm512_mul_pd + +// saturated +// TODO: after implement Packs intrins + +/*************************** + * Integer Division + ***************************/ +// See simd/intdiv.h for more clarification +// divide each unsigned 8-bit element by divisor +NPY_FINLINE npyv_u8 npyv_divc_u8(npyv_u8 a, const npyv_u8x3 divisor) +{ + const __m128i shf1 = _mm512_castsi512_si128(divisor.val[1]); + const __m128i shf2 = _mm512_castsi512_si128(divisor.val[2]); +#ifdef NPY_HAVE_AVX512BW + const __m512i bmask = _mm512_set1_epi32(0x00FF00FF); + const __m512i shf1b = _mm512_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf1)); + const __m512i shf2b = _mm512_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf2)); + // high part of unsigned multiplication + __m512i mulhi_even = _mm512_mullo_epi16(_mm512_and_si512(a, bmask), divisor.val[0]); + mulhi_even = _mm512_srli_epi16(mulhi_even, 8); + __m512i mulhi_odd = _mm512_mullo_epi16(_mm512_srli_epi16(a, 8), divisor.val[0]); + __m512i mulhi = _mm512_mask_mov_epi8(mulhi_even, 0xAAAAAAAAAAAAAAAA, mulhi_odd); + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + __m512i q = _mm512_sub_epi8(a, mulhi); + q = _mm512_and_si512(_mm512_srl_epi16(q, shf1), shf1b); + q = _mm512_add_epi8(mulhi, q); + q = _mm512_and_si512(_mm512_srl_epi16(q, shf2), shf2b); + return q; +#else + const __m256i bmask = _mm256_set1_epi32(0x00FF00FF); + const __m256i shf1b = _mm256_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf1)); + const __m256i shf2b = _mm256_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf2)); + const __m512i shf2bw= npyv512_combine_si256(shf2b, shf2b); + const __m256i mulc = npyv512_lower_si256(divisor.val[0]); + //// lower 256-bit + __m256i lo_a = npyv512_lower_si256(a); + // high part of unsigned multiplication + __m256i mulhi_even = _mm256_mullo_epi16(_mm256_and_si256(lo_a, bmask), mulc); + mulhi_even = _mm256_srli_epi16(mulhi_even, 8); + __m256i mulhi_odd = _mm256_mullo_epi16(_mm256_srli_epi16(lo_a, 8), mulc); + __m256i mulhi = 
_mm256_blendv_epi8(mulhi_odd, mulhi_even, bmask); + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + __m256i lo_q = _mm256_sub_epi8(lo_a, mulhi); + lo_q = _mm256_and_si256(_mm256_srl_epi16(lo_q, shf1), shf1b); + lo_q = _mm256_add_epi8(mulhi, lo_q); + lo_q = _mm256_srl_epi16(lo_q, shf2); // no sign extend + + //// higher 256-bit + __m256i hi_a = npyv512_higher_si256(a); + // high part of unsigned multiplication + mulhi_even = _mm256_mullo_epi16(_mm256_and_si256(hi_a, bmask), mulc); + mulhi_even = _mm256_srli_epi16(mulhi_even, 8); + mulhi_odd = _mm256_mullo_epi16(_mm256_srli_epi16(hi_a, 8), mulc); + mulhi = _mm256_blendv_epi8(mulhi_odd, mulhi_even, bmask); + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + __m256i hi_q = _mm256_sub_epi8(hi_a, mulhi); + hi_q = _mm256_and_si256(_mm256_srl_epi16(hi_q, shf1), shf1b); + hi_q = _mm256_add_epi8(mulhi, hi_q); + hi_q = _mm256_srl_epi16(hi_q, shf2); // no sign extend + return _mm512_and_si512(npyv512_combine_si256(lo_q, hi_q), shf2bw); // extend sign +#endif +} +// divide each signed 8-bit element by divisor (round towards zero) +NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor); +NPY_FINLINE npyv_s8 npyv_divc_s8(npyv_s8 a, const npyv_s8x3 divisor) +{ + __m512i divc_even = npyv_divc_s16(npyv_shri_s16(npyv_shli_s16(a, 8), 8), divisor); + __m512i divc_odd = npyv_divc_s16(npyv_shri_s16(a, 8), divisor); + divc_odd = npyv_shli_s16(divc_odd, 8); +#ifdef NPY_HAVE_AVX512BW + return _mm512_mask_mov_epi8(divc_even, 0xAAAAAAAAAAAAAAAA, divc_odd); +#else + const __m512i bmask = _mm512_set1_epi32(0x00FF00FF); + return npyv_select_u8(bmask, divc_even, divc_odd); +#endif +} +// divide each unsigned 16-bit element by divisor +NPY_FINLINE npyv_u16 npyv_divc_u16(npyv_u16 a, const npyv_u16x3 divisor) +{ + const __m128i shf1 = _mm512_castsi512_si128(divisor.val[1]); + const __m128i shf2 = _mm512_castsi512_si128(divisor.val[2]); + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + #define NPYV__DIVC_U16(RLEN, A, MULC, R) \ + mulhi = _mm##RLEN##_mulhi_epu16(A, MULC); \ + R = _mm##RLEN##_sub_epi16(A, mulhi); \ + R = _mm##RLEN##_srl_epi16(R, shf1); \ + R = _mm##RLEN##_add_epi16(mulhi, R); \ + R = _mm##RLEN##_srl_epi16(R, shf2); + +#ifdef NPY_HAVE_AVX512BW + __m512i mulhi, q; + NPYV__DIVC_U16(512, a, divisor.val[0], q) + return q; +#else + const __m256i m = npyv512_lower_si256(divisor.val[0]); + __m256i lo_a = npyv512_lower_si256(a); + __m256i hi_a = npyv512_higher_si256(a); + + __m256i mulhi, lo_q, hi_q; + NPYV__DIVC_U16(256, lo_a, m, lo_q) + NPYV__DIVC_U16(256, hi_a, m, hi_q) + return npyv512_combine_si256(lo_q, hi_q); +#endif + #undef NPYV__DIVC_U16 +} +// divide each signed 16-bit element by divisor (round towards zero) +NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor) +{ + const __m128i shf1 = _mm512_castsi512_si128(divisor.val[1]); + // q = ((a + mulhi) >> sh1) - XSIGN(a) + // trunc(a/d) = (q ^ dsign) - dsign + #define NPYV__DIVC_S16(RLEN, A, MULC, DSIGN, R) \ + mulhi = _mm##RLEN##_mulhi_epi16(A, MULC); \ + R = _mm##RLEN##_sra_epi16(_mm##RLEN##_add_epi16(A, mulhi), shf1); \ + R = _mm##RLEN##_sub_epi16(R, _mm##RLEN##_srai_epi16(A, 15)); \ + R = _mm##RLEN##_sub_epi16(_mm##RLEN##_xor_si##RLEN(R, DSIGN), DSIGN); + +#ifdef NPY_HAVE_AVX512BW + __m512i mulhi, q; + NPYV__DIVC_S16(512, a, divisor.val[0], divisor.val[2], q) + return q; +#else + const __m256i m = npyv512_lower_si256(divisor.val[0]); + const __m256i dsign = npyv512_lower_si256(divisor.val[2]); + __m256i lo_a = npyv512_lower_si256(a); + __m256i hi_a = 
npyv512_higher_si256(a); + + __m256i mulhi, lo_q, hi_q; + NPYV__DIVC_S16(256, lo_a, m, dsign, lo_q) + NPYV__DIVC_S16(256, hi_a, m, dsign, hi_q) + return npyv512_combine_si256(lo_q, hi_q); +#endif + #undef NPYV__DIVC_S16 +} +// divide each unsigned 32-bit element by divisor +NPY_FINLINE npyv_u32 npyv_divc_u32(npyv_u32 a, const npyv_u32x3 divisor) +{ + const __m128i shf1 = _mm512_castsi512_si128(divisor.val[1]); + const __m128i shf2 = _mm512_castsi512_si128(divisor.val[2]); + // high part of unsigned multiplication + __m512i mulhi_even = _mm512_srli_epi64(_mm512_mul_epu32(a, divisor.val[0]), 32); + __m512i mulhi_odd = _mm512_mul_epu32(_mm512_srli_epi64(a, 32), divisor.val[0]); + __m512i mulhi = _mm512_mask_mov_epi32(mulhi_even, 0xAAAA, mulhi_odd); + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + __m512i q = _mm512_sub_epi32(a, mulhi); + q = _mm512_srl_epi32(q, shf1); + q = _mm512_add_epi32(mulhi, q); + q = _mm512_srl_epi32(q, shf2); + return q; +} +// divide each signed 32-bit element by divisor (round towards zero) +NPY_FINLINE npyv_s32 npyv_divc_s32(npyv_s32 a, const npyv_s32x3 divisor) +{ + const __m128i shf1 = _mm512_castsi512_si128(divisor.val[1]); + // high part of signed multiplication + __m512i mulhi_even = _mm512_srli_epi64(_mm512_mul_epi32(a, divisor.val[0]), 32); + __m512i mulhi_odd = _mm512_mul_epi32(_mm512_srli_epi64(a, 32), divisor.val[0]); + __m512i mulhi = _mm512_mask_mov_epi32(mulhi_even, 0xAAAA, mulhi_odd); + // q = ((a + mulhi) >> sh1) - XSIGN(a) + // trunc(a/d) = (q ^ dsign) - dsign + __m512i q = _mm512_sra_epi32(_mm512_add_epi32(a, mulhi), shf1); + q = _mm512_sub_epi32(q, _mm512_srai_epi32(a, 31)); + q = _mm512_sub_epi32(_mm512_xor_si512(q, divisor.val[2]), divisor.val[2]); + return q; +} +// returns the high 64 bits of unsigned 64-bit multiplication +// xref https://stackoverflow.com/a/28827013 +NPY_FINLINE npyv_u64 npyv__mullhi_u64(npyv_u64 a, npyv_u64 b) +{ + __m512i lomask = npyv_setall_s64(0xffffffff); + __m512i a_hi = _mm512_srli_epi64(a, 32); // a0l, a0h, a1l, a1h + __m512i b_hi = _mm512_srli_epi64(b, 32); // b0l, b0h, b1l, b1h + // compute partial products + __m512i w0 = _mm512_mul_epu32(a, b); // a0l*b0l, a1l*b1l + __m512i w1 = _mm512_mul_epu32(a, b_hi); // a0l*b0h, a1l*b1h + __m512i w2 = _mm512_mul_epu32(a_hi, b); // a0h*b0l, a1h*b0l + __m512i w3 = _mm512_mul_epu32(a_hi, b_hi); // a0h*b0h, a1h*b1h + // sum partial products + __m512i w0h = _mm512_srli_epi64(w0, 32); + __m512i s1 = _mm512_add_epi64(w1, w0h); + __m512i s1l = _mm512_and_si512(s1, lomask); + __m512i s1h = _mm512_srli_epi64(s1, 32); + + __m512i s2 = _mm512_add_epi64(w2, s1l); + __m512i s2h = _mm512_srli_epi64(s2, 32); + + __m512i hi = _mm512_add_epi64(w3, s1h); + hi = _mm512_add_epi64(hi, s2h); + return hi; +} +// divide each unsigned 64-bit element by a divisor +NPY_FINLINE npyv_u64 npyv_divc_u64(npyv_u64 a, const npyv_u64x3 divisor) +{ + const __m128i shf1 = _mm512_castsi512_si128(divisor.val[1]); + const __m128i shf2 = _mm512_castsi512_si128(divisor.val[2]); + // high part of unsigned multiplication + __m512i mulhi = npyv__mullhi_u64(a, divisor.val[0]); + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + __m512i q = _mm512_sub_epi64(a, mulhi); + q = _mm512_srl_epi64(q, shf1); + q = _mm512_add_epi64(mulhi, q); + q = _mm512_srl_epi64(q, shf2); + return q; +} +// divide each unsigned 64-bit element by a divisor (round towards zero) +NPY_FINLINE npyv_s64 npyv_divc_s64(npyv_s64 a, const npyv_s64x3 divisor) +{ + const __m128i shf1 = _mm512_castsi512_si128(divisor.val[1]); + // high part of 
unsigned multiplication + __m512i mulhi = npyv__mullhi_u64(a, divisor.val[0]); + // convert unsigned to signed high multiplication + // mulhi - ((a < 0) ? m : 0) - ((m < 0) ? a : 0); + __m512i asign = _mm512_srai_epi64(a, 63); + __m512i msign = _mm512_srai_epi64(divisor.val[0], 63); + __m512i m_asign = _mm512_and_si512(divisor.val[0], asign); + __m512i a_msign = _mm512_and_si512(a, msign); + mulhi = _mm512_sub_epi64(mulhi, m_asign); + mulhi = _mm512_sub_epi64(mulhi, a_msign); + // q = ((a + mulhi) >> sh1) - XSIGN(a) + // trunc(a/d) = (q ^ dsign) - dsign + __m512i q = _mm512_sra_epi64(_mm512_add_epi64(a, mulhi), shf1); + q = _mm512_sub_epi64(q, asign); + q = _mm512_sub_epi64(_mm512_xor_si512(q, divisor.val[2]), divisor.val[2]); + return q; +} +/*************************** + * Division + ***************************/ +// TODO: emulate integer division +#define npyv_div_f32 _mm512_div_ps +#define npyv_div_f64 _mm512_div_pd + +/*************************** + * FUSED + ***************************/ +// multiply and add, a*b + c +#define npyv_muladd_f32 _mm512_fmadd_ps +#define npyv_muladd_f64 _mm512_fmadd_pd +// multiply and subtract, a*b - c +#define npyv_mulsub_f32 _mm512_fmsub_ps +#define npyv_mulsub_f64 _mm512_fmsub_pd +// negate multiply and add, -(a*b) + c +#define npyv_nmuladd_f32 _mm512_fnmadd_ps +#define npyv_nmuladd_f64 _mm512_fnmadd_pd +// negate multiply and subtract, -(a*b) - c +#define npyv_nmulsub_f32 _mm512_fnmsub_ps +#define npyv_nmulsub_f64 _mm512_fnmsub_pd +// multiply, add for odd elements and subtract even elements. +// (a * b) -+ c +#define npyv_muladdsub_f32 _mm512_fmaddsub_ps +#define npyv_muladdsub_f64 _mm512_fmaddsub_pd + +/*************************** + * Summation: Calculates the sum of all vector elements. + * there are three ways to implement reduce sum for AVX512: + * 1- split(256) /add /split(128) /add /hadd /hadd /extract + * 2- shuff(cross) /add /shuff(cross) /add /shuff /add /shuff /add /extract + * 3- _mm512_reduce_add_ps/pd + * The first one is been widely used by many projects + * + * the second one is used by Intel Compiler, maybe because the + * latency of hadd increased by (2-3) starting from Skylake-X which makes two + * extra shuffles(non-cross) cheaper. check https://godbolt.org/z/s3G9Er for more info. + * + * The third one is almost the same as the second one but only works for + * intel compiler/GCC 7.1/Clang 4, we still need to support older GCC. 
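+ *
+ * Whichever variant is used, the result matches the scalar reference below,
+ * up to rounding differences caused by the changed order of additions
+ * (a minimal sketch for clarity only, not part of the header itself):
+ *
+ *   float ref_sum_f32(const float lane[16]) {  // one npyv_f32 holds 16 lanes
+ *       float s = 0.0f;
+ *       for (int i = 0; i < 16; ++i)
+ *           s += lane[i];
+ *       return s;
+ *   }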
+ ***************************/ +// reduce sum across vector +#ifdef NPY_HAVE_AVX512F_REDUCE + #define npyv_sum_u32 _mm512_reduce_add_epi32 + #define npyv_sum_u64 _mm512_reduce_add_epi64 + #define npyv_sum_f32 _mm512_reduce_add_ps + #define npyv_sum_f64 _mm512_reduce_add_pd +#else + NPY_FINLINE npy_uint32 npyv_sum_u32(npyv_u32 a) + { + __m256i half = _mm256_add_epi32(npyv512_lower_si256(a), npyv512_higher_si256(a)); + __m128i quarter = _mm_add_epi32(_mm256_castsi256_si128(half), _mm256_extracti128_si256(half, 1)); + quarter = _mm_hadd_epi32(quarter, quarter); + return _mm_cvtsi128_si32(_mm_hadd_epi32(quarter, quarter)); + } + + NPY_FINLINE npy_uint64 npyv_sum_u64(npyv_u64 a) + { + __m256i four = _mm256_add_epi64(npyv512_lower_si256(a), npyv512_higher_si256(a)); + __m256i two = _mm256_add_epi64(four, _mm256_shuffle_epi32(four, _MM_SHUFFLE(1, 0, 3, 2))); + __m128i one = _mm_add_epi64(_mm256_castsi256_si128(two), _mm256_extracti128_si256(two, 1)); + return (npy_uint64)npyv128_cvtsi128_si64(one); + } + + NPY_FINLINE float npyv_sum_f32(npyv_f32 a) + { + __m512 h64 = _mm512_shuffle_f32x4(a, a, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 sum32 = _mm512_add_ps(a, h64); + __m512 h32 = _mm512_shuffle_f32x4(sum32, sum32, _MM_SHUFFLE(1, 0, 3, 2)); + __m512 sum16 = _mm512_add_ps(sum32, h32); + __m512 h16 = _mm512_permute_ps(sum16, _MM_SHUFFLE(1, 0, 3, 2)); + __m512 sum8 = _mm512_add_ps(sum16, h16); + __m512 h4 = _mm512_permute_ps(sum8, _MM_SHUFFLE(2, 3, 0, 1)); + __m512 sum4 = _mm512_add_ps(sum8, h4); + return _mm_cvtss_f32(_mm512_castps512_ps128(sum4)); + } + + NPY_FINLINE double npyv_sum_f64(npyv_f64 a) + { + __m512d h64 = _mm512_shuffle_f64x2(a, a, _MM_SHUFFLE(3, 2, 3, 2)); + __m512d sum32 = _mm512_add_pd(a, h64); + __m512d h32 = _mm512_permutex_pd(sum32, _MM_SHUFFLE(1, 0, 3, 2)); + __m512d sum16 = _mm512_add_pd(sum32, h32); + __m512d h16 = _mm512_permute_pd(sum16, _MM_SHUFFLE(2, 3, 0, 1)); + __m512d sum8 = _mm512_add_pd(sum16, h16); + return _mm_cvtsd_f64(_mm512_castpd512_pd128(sum8)); + } + +#endif + +// expand the source vector and performs sum reduce +NPY_FINLINE npy_uint16 npyv_sumup_u8(npyv_u8 a) +{ +#ifdef NPY_HAVE_AVX512BW + __m512i eight = _mm512_sad_epu8(a, _mm512_setzero_si512()); + __m256i four = _mm256_add_epi16(npyv512_lower_si256(eight), npyv512_higher_si256(eight)); +#else + __m256i lo_four = _mm256_sad_epu8(npyv512_lower_si256(a), _mm256_setzero_si256()); + __m256i hi_four = _mm256_sad_epu8(npyv512_higher_si256(a), _mm256_setzero_si256()); + __m256i four = _mm256_add_epi16(lo_four, hi_four); +#endif + __m128i two = _mm_add_epi16(_mm256_castsi256_si128(four), _mm256_extracti128_si256(four, 1)); + __m128i one = _mm_add_epi16(two, _mm_unpackhi_epi64(two, two)); + return (npy_uint16)_mm_cvtsi128_si32(one); +} + +NPY_FINLINE npy_uint32 npyv_sumup_u16(npyv_u16 a) +{ + const npyv_u16 even_mask = _mm512_set1_epi32(0x0000FFFF); + __m512i even = _mm512_and_si512(a, even_mask); + __m512i odd = _mm512_srli_epi32(a, 16); + __m512i ff = _mm512_add_epi32(even, odd); + return npyv_sum_u32(ff); +} + +#endif // _NPY_SIMD_AVX512_ARITHMETIC_H diff --git a/mkl_umath/src/npyv/avx512/avx512.h b/mkl_umath/src/npyv/avx512/avx512.h new file mode 100644 index 00000000..2a4a20b2 --- /dev/null +++ b/mkl_umath/src/npyv/avx512/avx512.h @@ -0,0 +1,82 @@ +#ifndef _NPY_SIMD_H_ + #error "Not a standalone header" +#endif +#define NPY_SIMD 512 +#define NPY_SIMD_WIDTH 64 +#define NPY_SIMD_F32 1 +#define NPY_SIMD_F64 1 +#define NPY_SIMD_FMA3 1 // native support +#define NPY_SIMD_BIGENDIAN 0 +#define NPY_SIMD_CMPSIGNAL 0 +// Enough 
limit to allow us to use _mm512_i32gather_* and _mm512_i32scatter_* +#define NPY_SIMD_MAXLOAD_STRIDE32 (0x7fffffff / 16) +#define NPY_SIMD_MAXSTORE_STRIDE32 (0x7fffffff / 16) +#define NPY_SIMD_MAXLOAD_STRIDE64 (0x7fffffff / 16) +#define NPY_SIMD_MAXSTORE_STRIDE64 (0x7fffffff / 16) + +typedef __m512i npyv_u8; +typedef __m512i npyv_s8; +typedef __m512i npyv_u16; +typedef __m512i npyv_s16; +typedef __m512i npyv_u32; +typedef __m512i npyv_s32; +typedef __m512i npyv_u64; +typedef __m512i npyv_s64; +typedef __m512 npyv_f32; +typedef __m512d npyv_f64; + +#ifdef NPY_HAVE_AVX512BW +typedef __mmask64 npyv_b8; +typedef __mmask32 npyv_b16; +#else +typedef __m512i npyv_b8; +typedef __m512i npyv_b16; +#endif +typedef __mmask16 npyv_b32; +typedef __mmask8 npyv_b64; + +typedef struct { __m512i val[2]; } npyv_m512ix2; +typedef npyv_m512ix2 npyv_u8x2; +typedef npyv_m512ix2 npyv_s8x2; +typedef npyv_m512ix2 npyv_u16x2; +typedef npyv_m512ix2 npyv_s16x2; +typedef npyv_m512ix2 npyv_u32x2; +typedef npyv_m512ix2 npyv_s32x2; +typedef npyv_m512ix2 npyv_u64x2; +typedef npyv_m512ix2 npyv_s64x2; + +typedef struct { __m512i val[3]; } npyv_m512ix3; +typedef npyv_m512ix3 npyv_u8x3; +typedef npyv_m512ix3 npyv_s8x3; +typedef npyv_m512ix3 npyv_u16x3; +typedef npyv_m512ix3 npyv_s16x3; +typedef npyv_m512ix3 npyv_u32x3; +typedef npyv_m512ix3 npyv_s32x3; +typedef npyv_m512ix3 npyv_u64x3; +typedef npyv_m512ix3 npyv_s64x3; + +typedef struct { __m512 val[2]; } npyv_f32x2; +typedef struct { __m512d val[2]; } npyv_f64x2; +typedef struct { __m512 val[3]; } npyv_f32x3; +typedef struct { __m512d val[3]; } npyv_f64x3; + +#define npyv_nlanes_u8 64 +#define npyv_nlanes_s8 64 +#define npyv_nlanes_u16 32 +#define npyv_nlanes_s16 32 +#define npyv_nlanes_u32 16 +#define npyv_nlanes_s32 16 +#define npyv_nlanes_u64 8 +#define npyv_nlanes_s64 8 +#define npyv_nlanes_f32 16 +#define npyv_nlanes_f64 8 + +#include "utils.h" +#include "memory.h" +#include "misc.h" +#include "reorder.h" +#include "operators.h" +#include "conversion.h" +#include "arithmetic.h" +#include "math.h" +#include "maskop.h" diff --git a/mkl_umath/src/npyv/avx512/conversion.h b/mkl_umath/src/npyv/avx512/conversion.h new file mode 100644 index 00000000..3b29b672 --- /dev/null +++ b/mkl_umath/src/npyv/avx512/conversion.h @@ -0,0 +1,204 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_AVX512_CVT_H +#define _NPY_SIMD_AVX512_CVT_H + +// convert mask to integer vectors +#ifdef NPY_HAVE_AVX512BW + #define npyv_cvt_u8_b8 _mm512_movm_epi8 + #define npyv_cvt_u16_b16 _mm512_movm_epi16 +#else + #define npyv_cvt_u8_b8(BL) BL + #define npyv_cvt_u16_b16(BL) BL +#endif +#define npyv_cvt_s8_b8 npyv_cvt_u8_b8 +#define npyv_cvt_s16_b16 npyv_cvt_u16_b16 + +#ifdef NPY_HAVE_AVX512DQ + #define npyv_cvt_u32_b32 _mm512_movm_epi32 + #define npyv_cvt_u64_b64 _mm512_movm_epi64 +#else + #define npyv_cvt_u32_b32(BL) _mm512_maskz_set1_epi32(BL, (int)-1) + #define npyv_cvt_u64_b64(BL) _mm512_maskz_set1_epi64(BL, (npy_int64)-1) +#endif +#define npyv_cvt_s32_b32 npyv_cvt_u32_b32 +#define npyv_cvt_s64_b64 npyv_cvt_u64_b64 +#define npyv_cvt_f32_b32(BL) _mm512_castsi512_ps(npyv_cvt_u32_b32(BL)) +#define npyv_cvt_f64_b64(BL) _mm512_castsi512_pd(npyv_cvt_u64_b64(BL)) + +// convert integer vectors to mask +#ifdef NPY_HAVE_AVX512BW + #define npyv_cvt_b8_u8 _mm512_movepi8_mask + #define npyv_cvt_b16_u16 _mm512_movepi16_mask +#else + #define npyv_cvt_b8_u8(A) A + #define npyv_cvt_b16_u16(A) A +#endif +#define npyv_cvt_b8_s8 npyv_cvt_b8_u8 +#define npyv_cvt_b16_s16 npyv_cvt_b16_u16 + 
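+// Illustrative round trip for the 8-bit conversions above (a sketch only;
+// npyv_cmpeq_u8 is assumed to come from the ported operators.h, as in
+// upstream NumPy):
+//   npyv_b8 m   = npyv_cmpeq_u8(a, b);  // opmask with AVX512BW, else a vector
+//   npyv_u8 sel = npyv_cvt_u8_b8(m);    // materialize as 0xFF/0x00 byte lanes
+//   npyv_b8 m2  = npyv_cvt_b8_u8(sel);  // and back, for mask-consuming intrinsics
+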
+#ifdef NPY_HAVE_AVX512DQ + #define npyv_cvt_b32_u32 _mm512_movepi32_mask + #define npyv_cvt_b64_u64 _mm512_movepi64_mask +#else + #define npyv_cvt_b32_u32(A) _mm512_cmpneq_epu32_mask(A, _mm512_setzero_si512()) + #define npyv_cvt_b64_u64(A) _mm512_cmpneq_epu64_mask(A, _mm512_setzero_si512()) +#endif +#define npyv_cvt_b32_s32 npyv_cvt_b32_u32 +#define npyv_cvt_b64_s64 npyv_cvt_b64_u64 +#define npyv_cvt_b32_f32(A) npyv_cvt_b32_u32(_mm512_castps_si512(A)) +#define npyv_cvt_b64_f64(A) npyv_cvt_b64_u64(_mm512_castpd_si512(A)) + +// expand +NPY_FINLINE npyv_u16x2 npyv_expand_u16_u8(npyv_u8 data) +{ + npyv_u16x2 r; + __m256i lo = npyv512_lower_si256(data); + __m256i hi = npyv512_higher_si256(data); +#ifdef NPY_HAVE_AVX512BW + r.val[0] = _mm512_cvtepu8_epi16(lo); + r.val[1] = _mm512_cvtepu8_epi16(hi); +#else + __m256i loelo = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(lo)); + __m256i loehi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(lo, 1)); + __m256i hielo = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(hi)); + __m256i hiehi = _mm256_cvtepu8_epi16(_mm256_extracti128_si256(hi, 1)); + r.val[0] = npyv512_combine_si256(loelo, loehi); + r.val[1] = npyv512_combine_si256(hielo, hiehi); +#endif + return r; +} + +NPY_FINLINE npyv_u32x2 npyv_expand_u32_u16(npyv_u16 data) +{ + npyv_u32x2 r; + __m256i lo = npyv512_lower_si256(data); + __m256i hi = npyv512_higher_si256(data); +#ifdef NPY_HAVE_AVX512BW + r.val[0] = _mm512_cvtepu16_epi32(lo); + r.val[1] = _mm512_cvtepu16_epi32(hi); +#else + __m256i loelo = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(lo)); + __m256i loehi = _mm256_cvtepu16_epi32(_mm256_extracti128_si256(lo, 1)); + __m256i hielo = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(hi)); + __m256i hiehi = _mm256_cvtepu16_epi32(_mm256_extracti128_si256(hi, 1)); + r.val[0] = npyv512_combine_si256(loelo, loehi); + r.val[1] = npyv512_combine_si256(hielo, hiehi); +#endif + return r; +} + +// pack two 16-bit boolean into one 8-bit boolean vector +NPY_FINLINE npyv_b8 npyv_pack_b8_b16(npyv_b16 a, npyv_b16 b) { +#ifdef NPY_HAVE_AVX512BW + return _mm512_kunpackd((__mmask64)b, (__mmask64)a); +#else + const __m512i idx = _mm512_setr_epi64(0, 2, 4, 6, 1, 3, 5, 7); + return _mm512_permutexvar_epi64(idx, npyv512_packs_epi16(a, b)); +#endif +} + +// pack four 32-bit boolean vectors into one 8-bit boolean vector +NPY_FINLINE npyv_b8 +npyv_pack_b8_b32(npyv_b32 a, npyv_b32 b, npyv_b32 c, npyv_b32 d) { +#ifdef NPY_HAVE_AVX512BW + __mmask32 ab = _mm512_kunpackw((__mmask32)b, (__mmask32)a); + __mmask32 cd = _mm512_kunpackw((__mmask32)d, (__mmask32)c); + return npyv_pack_b8_b16(ab, cd); +#else + const __m512i idx = _mm512_setr_epi32( + 0, 4, 1, 5, 2, 6, 3, 7, 8, 12, 9, 13, 10, 14, 11, 15); + __m256i ta = npyv512_pack_lo_hi(npyv_cvt_u32_b32(a)); + __m256i tb = npyv512_pack_lo_hi(npyv_cvt_u32_b32(b)); + __m256i tc = npyv512_pack_lo_hi(npyv_cvt_u32_b32(c)); + __m256i td = npyv512_pack_lo_hi(npyv_cvt_u32_b32(d)); + __m256i ab = _mm256_packs_epi16(ta, tb); + __m256i cd = _mm256_packs_epi16(tc, td); + __m512i abcd = npyv512_combine_si256(ab, cd); + return _mm512_permutexvar_epi32(idx, abcd); +#endif +} + +// pack eight 64-bit boolean vectors into one 8-bit boolean vector +NPY_FINLINE npyv_b8 +npyv_pack_b8_b64(npyv_b64 a, npyv_b64 b, npyv_b64 c, npyv_b64 d, + npyv_b64 e, npyv_b64 f, npyv_b64 g, npyv_b64 h) { + __mmask16 ab = _mm512_kunpackb((__mmask16)b, (__mmask16)a); + __mmask16 cd = _mm512_kunpackb((__mmask16)d, (__mmask16)c); + __mmask16 ef = _mm512_kunpackb((__mmask16)f, (__mmask16)e); + __mmask16 gh = 
_mm512_kunpackb((__mmask16)h, (__mmask16)g); + return npyv_pack_b8_b32(ab, cd, ef, gh); +} +/* + * A compiler bug workaround on Intel Compiler Classic. + * The bug manifests specifically when the + * scalar result of _cvtmask64_u64 is compared against the constant -1. This + * comparison uniquely triggers a bug under conditions of equality (==) or + * inequality (!=) checks, which are typically used in reduction operations like + * np.logical_or. + * + * The underlying issue arises from the compiler's optimizer. When the last + * vector comparison instruction operates on zmm, the optimizer erroneously + * emits a duplicate of this instruction but on the lower half register ymm. It + * then performs a bitwise XOR operation between the mask produced by this + * duplicated instruction and the mask from the original comparison instruction. + * This erroneous behavior leads to incorrect results. + * + * See https://github.com/numpy/numpy/issues/26197#issuecomment-2056750975 + */ +#ifdef __INTEL_COMPILER +#define NPYV__VOLATILE_CVTMASK64 volatile +#else +#define NPYV__VOLATILE_CVTMASK64 +#endif +// convert boolean vectors to integer bitfield +NPY_FINLINE npy_uint64 npyv_tobits_b8(npyv_b8 a) { +#ifdef NPY_HAVE_AVX512BW_MASK + npy_uint64 NPYV__VOLATILE_CVTMASK64 t = (npy_uint64)_cvtmask64_u64(a); + return t; +#elif defined(NPY_HAVE_AVX512BW) + npy_uint64 NPYV__VOLATILE_CVTMASK64 t = (npy_uint64)a; + return t; +#else + int mask_lo = _mm256_movemask_epi8(npyv512_lower_si256(a)); + int mask_hi = _mm256_movemask_epi8(npyv512_higher_si256(a)); + return (unsigned)mask_lo | ((npy_uint64)(unsigned)mask_hi << 32); +#endif +} +#undef NPYV__VOLATILE_CVTMASK64 + +NPY_FINLINE npy_uint64 npyv_tobits_b16(npyv_b16 a) +{ +#ifdef NPY_HAVE_AVX512BW_MASK + return (npy_uint32)_cvtmask32_u32(a); +#elif defined(NPY_HAVE_AVX512BW) + return (npy_uint32)a; +#else + __m256i pack = _mm256_packs_epi16( + npyv512_lower_si256(a), npyv512_higher_si256(a) + ); + return (npy_uint32)_mm256_movemask_epi8(_mm256_permute4x64_epi64(pack, _MM_SHUFFLE(3, 1, 2, 0))); +#endif +} +NPY_FINLINE npy_uint64 npyv_tobits_b32(npyv_b32 a) +{ return (npy_uint16)a; } +NPY_FINLINE npy_uint64 npyv_tobits_b64(npyv_b64 a) +{ +#ifdef NPY_HAVE_AVX512DQ_MASK + return _cvtmask8_u32(a); +#else + return (npy_uint8)a; +#endif +} + +// round to nearest integer (assuming even) +#define npyv_round_s32_f32 _mm512_cvtps_epi32 +NPY_FINLINE npyv_s32 npyv_round_s32_f64(npyv_f64 a, npyv_f64 b) +{ + __m256i lo = _mm512_cvtpd_epi32(a), hi = _mm512_cvtpd_epi32(b); + return npyv512_combine_si256(lo, hi); +} + +#endif // _NPY_SIMD_AVX512_CVT_H diff --git a/mkl_umath/src/npyv/avx512/maskop.h b/mkl_umath/src/npyv/avx512/maskop.h new file mode 100644 index 00000000..88fa4a68 --- /dev/null +++ b/mkl_umath/src/npyv/avx512/maskop.h @@ -0,0 +1,67 @@ +#ifndef NPY_SIMD + #error "Not a standalone header, use simd/simd.h instead" +#endif + +#ifndef _NPY_SIMD_AVX512_MASKOP_H +#define _NPY_SIMD_AVX512_MASKOP_H + +/** + * Implements conditional addition and subtraction. + * e.g. npyv_ifadd_f32(m, a, b, c) -> m ? a + b : c + * e.g. npyv_ifsub_f32(m, a, b, c) -> m ? 
a - b : c + */ +#define NPYV_IMPL_AVX512_EMULATE_MASK_ADDSUB(SFX, BSFX) \ + NPY_FINLINE npyv_##SFX npyv_ifadd_##SFX \ + (npyv_##BSFX m, npyv_##SFX a, npyv_##SFX b, npyv_##SFX c) \ + { \ + npyv_##SFX add = npyv_add_##SFX(a, b); \ + return npyv_select_##SFX(m, add, c); \ + } \ + NPY_FINLINE npyv_##SFX npyv_ifsub_##SFX \ + (npyv_##BSFX m, npyv_##SFX a, npyv_##SFX b, npyv_##SFX c) \ + { \ + npyv_##SFX sub = npyv_sub_##SFX(a, b); \ + return npyv_select_##SFX(m, sub, c); \ + } + +#define NPYV_IMPL_AVX512_MASK_ADDSUB(SFX, BSFX, ZSFX) \ + NPY_FINLINE npyv_##SFX npyv_ifadd_##SFX \ + (npyv_##BSFX m, npyv_##SFX a, npyv_##SFX b, npyv_##SFX c) \ + { return _mm512_mask_add_##ZSFX(c, m, a, b); } \ + NPY_FINLINE npyv_##SFX npyv_ifsub_##SFX \ + (npyv_##BSFX m, npyv_##SFX a, npyv_##SFX b, npyv_##SFX c) \ + { return _mm512_mask_sub_##ZSFX(c, m, a, b); } + +#ifdef NPY_HAVE_AVX512BW + NPYV_IMPL_AVX512_MASK_ADDSUB(u8, b8, epi8) + NPYV_IMPL_AVX512_MASK_ADDSUB(s8, b8, epi8) + NPYV_IMPL_AVX512_MASK_ADDSUB(u16, b16, epi16) + NPYV_IMPL_AVX512_MASK_ADDSUB(s16, b16, epi16) +#else + NPYV_IMPL_AVX512_EMULATE_MASK_ADDSUB(u8, b8) + NPYV_IMPL_AVX512_EMULATE_MASK_ADDSUB(s8, b8) + NPYV_IMPL_AVX512_EMULATE_MASK_ADDSUB(u16, b16) + NPYV_IMPL_AVX512_EMULATE_MASK_ADDSUB(s16, b16) +#endif + +NPYV_IMPL_AVX512_MASK_ADDSUB(u32, b32, epi32) +NPYV_IMPL_AVX512_MASK_ADDSUB(s32, b32, epi32) +NPYV_IMPL_AVX512_MASK_ADDSUB(u64, b64, epi64) +NPYV_IMPL_AVX512_MASK_ADDSUB(s64, b64, epi64) +NPYV_IMPL_AVX512_MASK_ADDSUB(f32, b32, ps) +NPYV_IMPL_AVX512_MASK_ADDSUB(f64, b64, pd) + +// division, m ? a / b : c +NPY_FINLINE npyv_f32 npyv_ifdiv_f32(npyv_b32 m, npyv_f32 a, npyv_f32 b, npyv_f32 c) +{ return _mm512_mask_div_ps(c, m, a, b); } +// conditional division, m ? a / b : 0 +NPY_FINLINE npyv_f32 npyv_ifdivz_f32(npyv_b32 m, npyv_f32 a, npyv_f32 b) +{ return _mm512_maskz_div_ps(m, a, b); } +// division, m ? a / b : c +NPY_FINLINE npyv_f64 npyv_ifdiv_f64(npyv_b32 m, npyv_f64 a, npyv_f64 b, npyv_f64 c) +{ return _mm512_mask_div_pd(c, m, a, b); } +// conditional division, m ? 
a / b : 0 +NPY_FINLINE npyv_f64 npyv_ifdivz_f64(npyv_b32 m, npyv_f64 a, npyv_f64 b) +{ return _mm512_maskz_div_pd(m, a, b); } + +#endif // _NPY_SIMD_AVX512_MASKOP_H diff --git a/mkl_umath/src/npyv/avx512/math.h b/mkl_umath/src/npyv/avx512/math.h new file mode 100644 index 00000000..97fd2d64 --- /dev/null +++ b/mkl_umath/src/npyv/avx512/math.h @@ -0,0 +1,309 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_AVX512_MATH_H +#define _NPY_SIMD_AVX512_MATH_H + +/*************************** + * Elementary + ***************************/ +// Square root +#define npyv_sqrt_f32 _mm512_sqrt_ps +#define npyv_sqrt_f64 _mm512_sqrt_pd + +// Reciprocal +NPY_FINLINE npyv_f32 npyv_recip_f32(npyv_f32 a) +{ return _mm512_div_ps(_mm512_set1_ps(1.0f), a); } +NPY_FINLINE npyv_f64 npyv_recip_f64(npyv_f64 a) +{ return _mm512_div_pd(_mm512_set1_pd(1.0), a); } + +// Absolute +NPY_FINLINE npyv_f32 npyv_abs_f32(npyv_f32 a) +{ +#if 0 // def NPY_HAVE_AVX512DQ + return _mm512_range_ps(a, a, 8); +#else + return npyv_and_f32( + a, _mm512_castsi512_ps(_mm512_set1_epi32(0x7fffffff)) + ); +#endif +} +NPY_FINLINE npyv_f64 npyv_abs_f64(npyv_f64 a) +{ +#if 0 // def NPY_HAVE_AVX512DQ + return _mm512_range_pd(a, a, 8); +#else + return npyv_and_f64( + a, _mm512_castsi512_pd(npyv_setall_s64(0x7fffffffffffffffLL)) + ); +#endif +} + +// Square +NPY_FINLINE npyv_f32 npyv_square_f32(npyv_f32 a) +{ return _mm512_mul_ps(a, a); } +NPY_FINLINE npyv_f64 npyv_square_f64(npyv_f64 a) +{ return _mm512_mul_pd(a, a); } + +// Maximum, natively mapping with no guarantees to handle NaN. +#define npyv_max_f32 _mm512_max_ps +#define npyv_max_f64 _mm512_max_pd +// Maximum, supports IEEE floating-point arithmetic (IEC 60559), +// - If one of the two vectors contains NaN, the equivalent element of the other vector is set +// - Only if both corresponded elements are NaN, NaN is set. +NPY_FINLINE npyv_f32 npyv_maxp_f32(npyv_f32 a, npyv_f32 b) +{ + __mmask16 nn = _mm512_cmp_ps_mask(b, b, _CMP_ORD_Q); + return _mm512_mask_max_ps(a, nn, a, b); +} +NPY_FINLINE npyv_f64 npyv_maxp_f64(npyv_f64 a, npyv_f64 b) +{ + __mmask8 nn = _mm512_cmp_pd_mask(b, b, _CMP_ORD_Q); + return _mm512_mask_max_pd(a, nn, a, b); +} +// Maximum, propagates NaNs +// If any of corresponded element is NaN, NaN is set. +NPY_FINLINE npyv_f32 npyv_maxn_f32(npyv_f32 a, npyv_f32 b) +{ + __mmask16 nn = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q); + return _mm512_mask_max_ps(a, nn, a, b); +} +NPY_FINLINE npyv_f64 npyv_maxn_f64(npyv_f64 a, npyv_f64 b) +{ + __mmask8 nn = _mm512_cmp_pd_mask(a, a, _CMP_ORD_Q); + return _mm512_mask_max_pd(a, nn, a, b); +} +// Maximum, integer operations +#ifdef NPY_HAVE_AVX512BW + #define npyv_max_u8 _mm512_max_epu8 + #define npyv_max_s8 _mm512_max_epi8 + #define npyv_max_u16 _mm512_max_epu16 + #define npyv_max_s16 _mm512_max_epi16 +#else + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_max_u8, _mm256_max_epu8) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_max_s8, _mm256_max_epi8) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_max_u16, _mm256_max_epu16) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_max_s16, _mm256_max_epi16) +#endif +#define npyv_max_u32 _mm512_max_epu32 +#define npyv_max_s32 _mm512_max_epi32 +#define npyv_max_u64 _mm512_max_epu64 +#define npyv_max_s64 _mm512_max_epi64 + +// Minimum, natively mapping with no guarantees to handle NaN. 
+#define npyv_min_f32 _mm512_min_ps +#define npyv_min_f64 _mm512_min_pd +// Minimum, supports IEEE floating-point arithmetic (IEC 60559), +// - If one of the two vectors contains NaN, the equivalent element of the other vector is set +// - Only if both corresponded elements are NaN, NaN is set. +NPY_FINLINE npyv_f32 npyv_minp_f32(npyv_f32 a, npyv_f32 b) +{ + __mmask16 nn = _mm512_cmp_ps_mask(b, b, _CMP_ORD_Q); + return _mm512_mask_min_ps(a, nn, a, b); +} +NPY_FINLINE npyv_f64 npyv_minp_f64(npyv_f64 a, npyv_f64 b) +{ + __mmask8 nn = _mm512_cmp_pd_mask(b, b, _CMP_ORD_Q); + return _mm512_mask_min_pd(a, nn, a, b); +} +// Minimum, propagates NaNs +// If any of corresponded element is NaN, NaN is set. +NPY_FINLINE npyv_f32 npyv_minn_f32(npyv_f32 a, npyv_f32 b) +{ + __mmask16 nn = _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q); + return _mm512_mask_min_ps(a, nn, a, b); +} +NPY_FINLINE npyv_f64 npyv_minn_f64(npyv_f64 a, npyv_f64 b) +{ + __mmask8 nn = _mm512_cmp_pd_mask(a, a, _CMP_ORD_Q); + return _mm512_mask_min_pd(a, nn, a, b); +} +// Minimum, integer operations +#ifdef NPY_HAVE_AVX512BW + #define npyv_min_u8 _mm512_min_epu8 + #define npyv_min_s8 _mm512_min_epi8 + #define npyv_min_u16 _mm512_min_epu16 + #define npyv_min_s16 _mm512_min_epi16 +#else + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_min_u8, _mm256_min_epu8) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_min_s8, _mm256_min_epi8) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_min_u16, _mm256_min_epu16) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_min_s16, _mm256_min_epi16) +#endif +#define npyv_min_u32 _mm512_min_epu32 +#define npyv_min_s32 _mm512_min_epi32 +#define npyv_min_u64 _mm512_min_epu64 +#define npyv_min_s64 _mm512_min_epi64 + +#ifdef NPY_HAVE_AVX512F_REDUCE + #define npyv_reduce_min_u32 _mm512_reduce_min_epu32 + #define npyv_reduce_min_s32 _mm512_reduce_min_epi32 + #define npyv_reduce_min_u64 _mm512_reduce_min_epu64 + #define npyv_reduce_min_s64 _mm512_reduce_min_epi64 + #define npyv_reduce_min_f32 _mm512_reduce_min_ps + #define npyv_reduce_min_f64 _mm512_reduce_min_pd + #define npyv_reduce_max_u32 _mm512_reduce_max_epu32 + #define npyv_reduce_max_s32 _mm512_reduce_max_epi32 + #define npyv_reduce_max_u64 _mm512_reduce_max_epu64 + #define npyv_reduce_max_s64 _mm512_reduce_max_epi64 + #define npyv_reduce_max_f32 _mm512_reduce_max_ps + #define npyv_reduce_max_f64 _mm512_reduce_max_pd +#else + // reduce min&max for 32&64-bits + #define NPY_IMPL_AVX512_REDUCE_MINMAX(STYPE, INTRIN, VINTRIN) \ + NPY_FINLINE STYPE##32 npyv_reduce_##INTRIN##32(__m512i a) \ + { \ + __m256i v256 = _mm256_##VINTRIN##32(npyv512_lower_si256(a), \ + npyv512_higher_si256(a)); \ + __m128i v128 = _mm_##VINTRIN##32(_mm256_castsi256_si128(v256), \ + _mm256_extracti128_si256(v256, 1)); \ + __m128i v64 = _mm_##VINTRIN##32(v128, _mm_shuffle_epi32(v128, \ + (_MM_PERM_ENUM)_MM_SHUFFLE(0, 0, 3, 2))); \ + __m128i v32 = _mm_##VINTRIN##32(v64, _mm_shuffle_epi32(v64, \ + (_MM_PERM_ENUM)_MM_SHUFFLE(0, 0, 0, 1))); \ + return (STYPE##32)_mm_cvtsi128_si32(v32); \ + } \ + NPY_FINLINE STYPE##64 npyv_reduce_##INTRIN##64(__m512i a) \ + { \ + __m512i v256 = _mm512_##VINTRIN##64(a, \ + _mm512_shuffle_i64x2(a, a, (_MM_PERM_ENUM)_MM_SHUFFLE(0, 0, 3, 2))); \ + __m512i v128 = _mm512_##VINTRIN##64(v256, \ + _mm512_shuffle_i64x2(v256, v256, (_MM_PERM_ENUM)_MM_SHUFFLE(0, 0, 0, 1))); \ + __m512i v64 = _mm512_##VINTRIN##64(v128, \ + _mm512_shuffle_epi32(v128, (_MM_PERM_ENUM)_MM_SHUFFLE(0, 0, 3, 2))); \ + return (STYPE##64)npyv_extract0_u64(v64); \ + } + + NPY_IMPL_AVX512_REDUCE_MINMAX(npy_uint, min_u, min_epu) + 
NPY_IMPL_AVX512_REDUCE_MINMAX(npy_int, min_s, min_epi) + NPY_IMPL_AVX512_REDUCE_MINMAX(npy_uint, max_u, max_epu) + NPY_IMPL_AVX512_REDUCE_MINMAX(npy_int, max_s, max_epi) + #undef NPY_IMPL_AVX512_REDUCE_MINMAX + // reduce min&max for ps & pd + #define NPY_IMPL_AVX512_REDUCE_MINMAX(INTRIN) \ + NPY_FINLINE float npyv_reduce_##INTRIN##_f32(npyv_f32 a) \ + { \ + __m256 v256 = _mm256_##INTRIN##_ps( \ + npyv512_lower_ps256(a), npyv512_higher_ps256(a)); \ + __m128 v128 = _mm_##INTRIN##_ps( \ + _mm256_castps256_ps128(v256), _mm256_extractf128_ps(v256, 1)); \ + __m128 v64 = _mm_##INTRIN##_ps(v128, \ + _mm_shuffle_ps(v128, v128, (_MM_PERM_ENUM)_MM_SHUFFLE(0, 0, 3, 2))); \ + __m128 v32 = _mm_##INTRIN##_ps(v64, \ + _mm_shuffle_ps(v64, v64, (_MM_PERM_ENUM)_MM_SHUFFLE(0, 0, 0, 1))); \ + return _mm_cvtss_f32(v32); \ + } \ + NPY_FINLINE double npyv_reduce_##INTRIN##_f64(npyv_f64 a) \ + { \ + __m256d v256 = _mm256_##INTRIN##_pd( \ + npyv512_lower_pd256(a), npyv512_higher_pd256(a)); \ + __m128d v128 = _mm_##INTRIN##_pd( \ + _mm256_castpd256_pd128(v256), _mm256_extractf128_pd(v256, 1)); \ + __m128d v64 = _mm_##INTRIN##_pd(v128, \ + _mm_shuffle_pd(v128, v128, (_MM_PERM_ENUM)_MM_SHUFFLE(0, 0, 0, 1))); \ + return _mm_cvtsd_f64(v64); \ + } + + NPY_IMPL_AVX512_REDUCE_MINMAX(min) + NPY_IMPL_AVX512_REDUCE_MINMAX(max) + #undef NPY_IMPL_AVX512_REDUCE_MINMAX +#endif +#define NPY_IMPL_AVX512_REDUCE_MINMAX(INTRIN, INF, INF64) \ + NPY_FINLINE float npyv_reduce_##INTRIN##p_f32(npyv_f32 a) \ + { \ + npyv_b32 notnan = npyv_notnan_f32(a); \ + if (NPY_UNLIKELY(!npyv_any_b32(notnan))) { \ + return _mm_cvtss_f32(_mm512_castps512_ps128(a)); \ + } \ + a = npyv_select_f32(notnan, a, \ + npyv_reinterpret_f32_u32(npyv_setall_u32(INF))); \ + return npyv_reduce_##INTRIN##_f32(a); \ + } \ + NPY_FINLINE double npyv_reduce_##INTRIN##p_f64(npyv_f64 a) \ + { \ + npyv_b64 notnan = npyv_notnan_f64(a); \ + if (NPY_UNLIKELY(!npyv_any_b64(notnan))) { \ + return _mm_cvtsd_f64(_mm512_castpd512_pd128(a)); \ + } \ + a = npyv_select_f64(notnan, a, \ + npyv_reinterpret_f64_u64(npyv_setall_u64(INF64))); \ + return npyv_reduce_##INTRIN##_f64(a); \ + } \ + NPY_FINLINE float npyv_reduce_##INTRIN##n_f32(npyv_f32 a) \ + { \ + npyv_b32 notnan = npyv_notnan_f32(a); \ + if (NPY_UNLIKELY(!npyv_all_b32(notnan))) { \ + const union { npy_uint32 i; float f;} pnan = { \ + 0x7fc00000ul \ + }; \ + return pnan.f; \ + } \ + return npyv_reduce_##INTRIN##_f32(a); \ + } \ + NPY_FINLINE double npyv_reduce_##INTRIN##n_f64(npyv_f64 a) \ + { \ + npyv_b64 notnan = npyv_notnan_f64(a); \ + if (NPY_UNLIKELY(!npyv_all_b64(notnan))) { \ + const union { npy_uint64 i; double d;} pnan = { \ + 0x7ff8000000000000ull \ + }; \ + return pnan.d; \ + } \ + return npyv_reduce_##INTRIN##_f64(a); \ + } + +NPY_IMPL_AVX512_REDUCE_MINMAX(min, 0x7f800000, 0x7ff0000000000000) +NPY_IMPL_AVX512_REDUCE_MINMAX(max, 0xff800000, 0xfff0000000000000) +#undef NPY_IMPL_AVX512_REDUCE_MINMAX + +// reduce min&max for 8&16-bits +#define NPY_IMPL_AVX512_REDUCE_MINMAX(STYPE, INTRIN, VINTRIN) \ + NPY_FINLINE STYPE##16 npyv_reduce_##INTRIN##16(__m512i a) \ + { \ + __m256i v256 = _mm256_##VINTRIN##16(npyv512_lower_si256(a), npyv512_higher_si256(a)); \ + __m128i v128 = _mm_##VINTRIN##16(_mm256_castsi256_si128(v256), _mm256_extracti128_si256(v256, 1)); \ + __m128i v64 = _mm_##VINTRIN##16(v128, _mm_shuffle_epi32(v128, \ + (_MM_PERM_ENUM)_MM_SHUFFLE(0, 0, 3, 2))); \ + __m128i v32 = _mm_##VINTRIN##16(v64, _mm_shuffle_epi32(v64, \ + (_MM_PERM_ENUM)_MM_SHUFFLE(0, 0, 0, 1))); \ + __m128i v16 = _mm_##VINTRIN##16(v32, 
_mm_shufflelo_epi16(v32, \ + (_MM_PERM_ENUM)_MM_SHUFFLE(0, 0, 0, 1))); \ + return (STYPE##16)_mm_cvtsi128_si32(v16); \ + } \ + NPY_FINLINE STYPE##8 npyv_reduce_##INTRIN##8(__m512i a) \ + { \ + __m256i v256 = _mm256_##VINTRIN##8(npyv512_lower_si256(a), npyv512_higher_si256(a)); \ + __m128i v128 = _mm_##VINTRIN##8(_mm256_castsi256_si128(v256), _mm256_extracti128_si256(v256, 1)); \ + __m128i v64 = _mm_##VINTRIN##8(v128, _mm_shuffle_epi32(v128, \ + (_MM_PERM_ENUM)_MM_SHUFFLE(0, 0, 3, 2))); \ + __m128i v32 = _mm_##VINTRIN##8(v64, _mm_shuffle_epi32(v64, \ + (_MM_PERM_ENUM)_MM_SHUFFLE(0, 0, 0, 1))); \ + __m128i v16 = _mm_##VINTRIN##8(v32, _mm_shufflelo_epi16(v32, \ + (_MM_PERM_ENUM)_MM_SHUFFLE(0, 0, 0, 1))); \ + __m128i v8 = _mm_##VINTRIN##8(v16, _mm_srli_epi16(v16, 8)); \ + return (STYPE##16)_mm_cvtsi128_si32(v8); \ + } +NPY_IMPL_AVX512_REDUCE_MINMAX(npy_uint, min_u, min_epu) +NPY_IMPL_AVX512_REDUCE_MINMAX(npy_int, min_s, min_epi) +NPY_IMPL_AVX512_REDUCE_MINMAX(npy_uint, max_u, max_epu) +NPY_IMPL_AVX512_REDUCE_MINMAX(npy_int, max_s, max_epi) +#undef NPY_IMPL_AVX512_REDUCE_MINMAX + +// round to nearest integer even +#define npyv_rint_f32(A) _mm512_roundscale_ps(A, _MM_FROUND_TO_NEAREST_INT) +#define npyv_rint_f64(A) _mm512_roundscale_pd(A, _MM_FROUND_TO_NEAREST_INT) + +// ceil +#define npyv_ceil_f32(A) _mm512_roundscale_ps(A, _MM_FROUND_TO_POS_INF) +#define npyv_ceil_f64(A) _mm512_roundscale_pd(A, _MM_FROUND_TO_POS_INF) + +// trunc +#define npyv_trunc_f32(A) _mm512_roundscale_ps(A, _MM_FROUND_TO_ZERO) +#define npyv_trunc_f64(A) _mm512_roundscale_pd(A, _MM_FROUND_TO_ZERO) + +// floor +#define npyv_floor_f32(A) _mm512_roundscale_ps(A, _MM_FROUND_TO_NEG_INF) +#define npyv_floor_f64(A) _mm512_roundscale_pd(A, _MM_FROUND_TO_NEG_INF) + +#endif // _NPY_SIMD_AVX512_MATH_H diff --git a/mkl_umath/src/npyv/avx512/memory.h b/mkl_umath/src/npyv/avx512/memory.h new file mode 100644 index 00000000..e981ef8f --- /dev/null +++ b/mkl_umath/src/npyv/avx512/memory.h @@ -0,0 +1,715 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_AVX512_MEMORY_H +#define _NPY_SIMD_AVX512_MEMORY_H + +#include "misc.h" + +/*************************** + * load/store + ***************************/ +#if defined(__GNUC__) + // GCC expect pointer argument type to be `void*` instead of `const void *`, + // which cause a massive warning. 
+ #define npyv__loads(PTR) _mm512_stream_load_si512((__m512i*)(PTR)) +#else + #define npyv__loads(PTR) _mm512_stream_load_si512((const __m512i*)(PTR)) +#endif +#if defined(_MSC_VER) && defined(_M_IX86) + // workaround msvc(32bit) overflow bug, reported at + // https://developercommunity.visualstudio.com/content/problem/911872/u.html + NPY_FINLINE __m512i npyv__loadl(const __m256i *ptr) + { + __m256i a = _mm256_loadu_si256(ptr); + return _mm512_inserti64x4(_mm512_castsi256_si512(a), a, 0); + } +#else + #define npyv__loadl(PTR) \ + _mm512_castsi256_si512(_mm256_loadu_si256(PTR)) +#endif +#define NPYV_IMPL_AVX512_MEM_INT(CTYPE, SFX) \ + NPY_FINLINE npyv_##SFX npyv_load_##SFX(const CTYPE *ptr) \ + { return _mm512_loadu_si512((const __m512i*)ptr); } \ + NPY_FINLINE npyv_##SFX npyv_loada_##SFX(const CTYPE *ptr) \ + { return _mm512_load_si512((const __m512i*)ptr); } \ + NPY_FINLINE npyv_##SFX npyv_loads_##SFX(const CTYPE *ptr) \ + { return npyv__loads(ptr); } \ + NPY_FINLINE npyv_##SFX npyv_loadl_##SFX(const CTYPE *ptr) \ + { return npyv__loadl((const __m256i *)ptr); } \ + NPY_FINLINE void npyv_store_##SFX(CTYPE *ptr, npyv_##SFX vec) \ + { _mm512_storeu_si512((__m512i*)ptr, vec); } \ + NPY_FINLINE void npyv_storea_##SFX(CTYPE *ptr, npyv_##SFX vec) \ + { _mm512_store_si512((__m512i*)ptr, vec); } \ + NPY_FINLINE void npyv_stores_##SFX(CTYPE *ptr, npyv_##SFX vec) \ + { _mm512_stream_si512((__m512i*)ptr, vec); } \ + NPY_FINLINE void npyv_storel_##SFX(CTYPE *ptr, npyv_##SFX vec) \ + { _mm256_storeu_si256((__m256i*)ptr, npyv512_lower_si256(vec)); } \ + NPY_FINLINE void npyv_storeh_##SFX(CTYPE *ptr, npyv_##SFX vec) \ + { _mm256_storeu_si256((__m256i*)(ptr), npyv512_higher_si256(vec)); } + +NPYV_IMPL_AVX512_MEM_INT(npy_uint8, u8) +NPYV_IMPL_AVX512_MEM_INT(npy_int8, s8) +NPYV_IMPL_AVX512_MEM_INT(npy_uint16, u16) +NPYV_IMPL_AVX512_MEM_INT(npy_int16, s16) +NPYV_IMPL_AVX512_MEM_INT(npy_uint32, u32) +NPYV_IMPL_AVX512_MEM_INT(npy_int32, s32) +NPYV_IMPL_AVX512_MEM_INT(npy_uint64, u64) +NPYV_IMPL_AVX512_MEM_INT(npy_int64, s64) + +// unaligned load +#define npyv_load_f32(PTR) _mm512_loadu_ps((const __m512*)(PTR)) +#define npyv_load_f64(PTR) _mm512_loadu_pd((const __m512d*)(PTR)) +// aligned load +#define npyv_loada_f32(PTR) _mm512_load_ps((const __m512*)(PTR)) +#define npyv_loada_f64(PTR) _mm512_load_pd((const __m512d*)(PTR)) +// load lower part +#if defined(_MSC_VER) && defined(_M_IX86) + #define npyv_loadl_f32(PTR) _mm512_castsi512_ps(npyv__loadl((const __m256i *)(PTR))) + #define npyv_loadl_f64(PTR) _mm512_castsi512_pd(npyv__loadl((const __m256i *)(PTR))) +#else + #define npyv_loadl_f32(PTR) _mm512_castps256_ps512(_mm256_loadu_ps(PTR)) + #define npyv_loadl_f64(PTR) _mm512_castpd256_pd512(_mm256_loadu_pd(PTR)) +#endif +// stream load +#define npyv_loads_f32(PTR) _mm512_castsi512_ps(npyv__loads(PTR)) +#define npyv_loads_f64(PTR) _mm512_castsi512_pd(npyv__loads(PTR)) +// unaligned store +#define npyv_store_f32 _mm512_storeu_ps +#define npyv_store_f64 _mm512_storeu_pd +// aligned store +#define npyv_storea_f32 _mm512_store_ps +#define npyv_storea_f64 _mm512_store_pd +// stream store +#define npyv_stores_f32 _mm512_stream_ps +#define npyv_stores_f64 _mm512_stream_pd +// store lower part +#define npyv_storel_f32(PTR, VEC) _mm256_storeu_ps(PTR, npyv512_lower_ps256(VEC)) +#define npyv_storel_f64(PTR, VEC) _mm256_storeu_pd(PTR, npyv512_lower_pd256(VEC)) +// store higher part +#define npyv_storeh_f32(PTR, VEC) _mm256_storeu_ps(PTR, npyv512_higher_ps256(VEC)) +#define npyv_storeh_f64(PTR, VEC) _mm256_storeu_pd(PTR, 
npyv512_higher_pd256(VEC)) +/*************************** + * Non-contiguous Load + ***************************/ +//// 32 +NPY_FINLINE npyv_u32 npyv_loadn_u32(const npy_uint32 *ptr, npy_intp stride) +{ + assert(llabs(stride) <= NPY_SIMD_MAXLOAD_STRIDE32); + const __m512i steps = npyv_set_s32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ); + const __m512i idx = _mm512_mullo_epi32(steps, _mm512_set1_epi32((int)stride)); + return _mm512_i32gather_epi32(idx, (const __m512i*)ptr, 4); +} +NPY_FINLINE npyv_s32 npyv_loadn_s32(const npy_int32 *ptr, npy_intp stride) +{ return npyv_loadn_u32((const npy_uint32*)ptr, stride); } +NPY_FINLINE npyv_f32 npyv_loadn_f32(const float *ptr, npy_intp stride) +{ return _mm512_castsi512_ps(npyv_loadn_u32((const npy_uint32*)ptr, stride)); } +//// 64 +NPY_FINLINE npyv_u64 npyv_loadn_u64(const npy_uint64 *ptr, npy_intp stride) +{ + const __m512i idx = npyv_set_s64( + 0*stride, 1*stride, 2*stride, 3*stride, + 4*stride, 5*stride, 6*stride, 7*stride + ); + return _mm512_i64gather_epi64(idx, (const __m512i*)ptr, 8); +} +NPY_FINLINE npyv_s64 npyv_loadn_s64(const npy_int64 *ptr, npy_intp stride) +{ return npyv_loadn_u64((const npy_uint64*)ptr, stride); } +NPY_FINLINE npyv_f64 npyv_loadn_f64(const double *ptr, npy_intp stride) +{ return _mm512_castsi512_pd(npyv_loadn_u64((const npy_uint64*)ptr, stride)); } + +//// 64-bit load over 32-bit stride +NPY_FINLINE npyv_u32 npyv_loadn2_u32(const npy_uint32 *ptr, npy_intp stride) +{ + __m128d a = _mm_loadh_pd( + _mm_castsi128_pd(_mm_loadl_epi64((const __m128i*)ptr)), + (const double*)(ptr + stride) + ); + __m128d b = _mm_loadh_pd( + _mm_castsi128_pd(_mm_loadl_epi64((const __m128i*)(ptr + stride*2))), + (const double*)(ptr + stride*3) + ); + __m128d c = _mm_loadh_pd( + _mm_castsi128_pd(_mm_loadl_epi64((const __m128i*)(ptr + stride*4))), + (const double*)(ptr + stride*5) + ); + __m128d d = _mm_loadh_pd( + _mm_castsi128_pd(_mm_loadl_epi64((const __m128i*)(ptr + stride*6))), + (const double*)(ptr + stride*7) + ); + return _mm512_castpd_si512(npyv512_combine_pd256( + _mm256_insertf128_pd(_mm256_castpd128_pd256(a), b, 1), + _mm256_insertf128_pd(_mm256_castpd128_pd256(c), d, 1) + )); +} +NPY_FINLINE npyv_s32 npyv_loadn2_s32(const npy_int32 *ptr, npy_intp stride) +{ return npyv_loadn2_u32((const npy_uint32*)ptr, stride); } +NPY_FINLINE npyv_f32 npyv_loadn2_f32(const float *ptr, npy_intp stride) +{ return _mm512_castsi512_ps(npyv_loadn2_u32((const npy_uint32*)ptr, stride)); } + +//// 128-bit load over 64-bit stride +NPY_FINLINE npyv_f64 npyv_loadn2_f64(const double *ptr, npy_intp stride) +{ + __m128d a = _mm_loadu_pd(ptr); + __m128d b = _mm_loadu_pd(ptr + stride); + __m128d c = _mm_loadu_pd(ptr + stride * 2); + __m128d d = _mm_loadu_pd(ptr + stride * 3); + return npyv512_combine_pd256( + _mm256_insertf128_pd(_mm256_castpd128_pd256(a), b, 1), + _mm256_insertf128_pd(_mm256_castpd128_pd256(c), d, 1) + ); +} +NPY_FINLINE npyv_u64 npyv_loadn2_u64(const npy_uint64 *ptr, npy_intp stride) +{ return npyv_reinterpret_u64_f64(npyv_loadn2_f64((const double*)ptr, stride)); } +NPY_FINLINE npyv_s64 npyv_loadn2_s64(const npy_int64 *ptr, npy_intp stride) +{ return npyv_loadn2_u64((const npy_uint64*)ptr, stride); } +/*************************** + * Non-contiguous Store + ***************************/ +//// 32 +NPY_FINLINE void npyv_storen_u32(npy_uint32 *ptr, npy_intp stride, npyv_u32 a) +{ + assert(llabs(stride) <= NPY_SIMD_MAXSTORE_STRIDE32); + const __m512i steps = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ); + 
const __m512i idx = _mm512_mullo_epi32(steps, _mm512_set1_epi32((int)stride)); + _mm512_i32scatter_epi32((__m512i*)ptr, idx, a, 4); +} +NPY_FINLINE void npyv_storen_s32(npy_int32 *ptr, npy_intp stride, npyv_s32 a) +{ npyv_storen_u32((npy_uint32*)ptr, stride, a); } +NPY_FINLINE void npyv_storen_f32(float *ptr, npy_intp stride, npyv_f32 a) +{ npyv_storen_u32((npy_uint32*)ptr, stride, _mm512_castps_si512(a)); } +//// 64 +NPY_FINLINE void npyv_storen_u64(npy_uint64 *ptr, npy_intp stride, npyv_u64 a) +{ + const __m512i idx = npyv_set_s64( + 0*stride, 1*stride, 2*stride, 3*stride, + 4*stride, 5*stride, 6*stride, 7*stride + ); + _mm512_i64scatter_epi64((__m512i*)ptr, idx, a, 8); +} +NPY_FINLINE void npyv_storen_s64(npy_int64 *ptr, npy_intp stride, npyv_s64 a) +{ npyv_storen_u64((npy_uint64*)ptr, stride, a); } +NPY_FINLINE void npyv_storen_f64(double *ptr, npy_intp stride, npyv_f64 a) +{ npyv_storen_u64((npy_uint64*)ptr, stride, _mm512_castpd_si512(a)); } + +//// 64-bit store over 32-bit stride +NPY_FINLINE void npyv_storen2_u32(npy_uint32 *ptr, npy_intp stride, npyv_u32 a) +{ + __m256d lo = _mm512_castpd512_pd256(_mm512_castsi512_pd(a)); + __m256d hi = _mm512_extractf64x4_pd(_mm512_castsi512_pd(a), 1); + __m128d e0 = _mm256_castpd256_pd128(lo); + __m128d e1 = _mm256_extractf128_pd(lo, 1); + __m128d e2 = _mm256_castpd256_pd128(hi); + __m128d e3 = _mm256_extractf128_pd(hi, 1); + _mm_storel_pd((double*)(ptr + stride * 0), e0); + _mm_storeh_pd((double*)(ptr + stride * 1), e0); + _mm_storel_pd((double*)(ptr + stride * 2), e1); + _mm_storeh_pd((double*)(ptr + stride * 3), e1); + _mm_storel_pd((double*)(ptr + stride * 4), e2); + _mm_storeh_pd((double*)(ptr + stride * 5), e2); + _mm_storel_pd((double*)(ptr + stride * 6), e3); + _mm_storeh_pd((double*)(ptr + stride * 7), e3); +} +NPY_FINLINE void npyv_storen2_s32(npy_int32 *ptr, npy_intp stride, npyv_s32 a) +{ npyv_storen2_u32((npy_uint32*)ptr, stride, a); } +NPY_FINLINE void npyv_storen2_f32(float *ptr, npy_intp stride, npyv_f32 a) +{ npyv_storen2_u32((npy_uint32*)ptr, stride, _mm512_castps_si512(a)); } + +//// 128-bit store over 64-bit stride +NPY_FINLINE void npyv_storen2_u64(npy_uint64 *ptr, npy_intp stride, npyv_u64 a) +{ + __m256i lo = npyv512_lower_si256(a); + __m256i hi = npyv512_higher_si256(a); + __m128i e0 = _mm256_castsi256_si128(lo); + __m128i e1 = _mm256_extracti128_si256(lo, 1); + __m128i e2 = _mm256_castsi256_si128(hi); + __m128i e3 = _mm256_extracti128_si256(hi, 1); + _mm_storeu_si128((__m128i*)(ptr + stride * 0), e0); + _mm_storeu_si128((__m128i*)(ptr + stride * 1), e1); + _mm_storeu_si128((__m128i*)(ptr + stride * 2), e2); + _mm_storeu_si128((__m128i*)(ptr + stride * 3), e3); +} +NPY_FINLINE void npyv_storen2_s64(npy_int64 *ptr, npy_intp stride, npyv_s64 a) +{ npyv_storen2_u64((npy_uint64*)ptr, stride, a); } +NPY_FINLINE void npyv_storen2_f64(double *ptr, npy_intp stride, npyv_f64 a) +{ npyv_storen2_u64((npy_uint64*)ptr, stride, _mm512_castpd_si512(a)); } + +/********************************* + * Partial Load + *********************************/ +//// 32 +NPY_FINLINE npyv_s32 npyv_load_till_s32(const npy_int32 *ptr, npy_uintp nlane, npy_int32 fill) +{ + assert(nlane > 0); + const __m512i vfill = _mm512_set1_epi32(fill); + const __mmask16 mask = nlane > 15 ? 
-1 : (1 << nlane) - 1; + __m512i ret = _mm512_mask_loadu_epi32(vfill, mask, (const __m512i*)ptr); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m512i workaround = ret; + ret = _mm512_or_si512(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s32 npyv_load_tillz_s32(const npy_int32 *ptr, npy_uintp nlane) +{ + assert(nlane > 0); + const __mmask16 mask = nlane > 15 ? -1 : (1 << nlane) - 1; + __m512i ret = _mm512_maskz_loadu_epi32(mask, (const __m512i*)ptr); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m512i workaround = ret; + ret = _mm512_or_si512(workaround, ret); +#endif + return ret; +} +//// 64 +NPY_FINLINE npyv_s64 npyv_load_till_s64(const npy_int64 *ptr, npy_uintp nlane, npy_int64 fill) +{ + assert(nlane > 0); + const __m512i vfill = npyv_setall_s64(fill); + const __mmask8 mask = nlane > 7 ? -1 : (1 << nlane) - 1; + __m512i ret = _mm512_mask_loadu_epi64(vfill, mask, (const __m512i*)ptr); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m512i workaround = ret; + ret = _mm512_or_si512(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s64 npyv_load_tillz_s64(const npy_int64 *ptr, npy_uintp nlane) +{ + assert(nlane > 0); + const __mmask8 mask = nlane > 7 ? -1 : (1 << nlane) - 1; + __m512i ret = _mm512_maskz_loadu_epi64(mask, (const __m512i*)ptr); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m512i workaround = ret; + ret = _mm512_or_si512(workaround, ret); +#endif + return ret; +} + +//// 64-bit nlane +NPY_FINLINE npyv_s32 npyv_load2_till_s32(const npy_int32 *ptr, npy_uintp nlane, + npy_int32 fill_lo, npy_int32 fill_hi) +{ + assert(nlane > 0); + const __m512i vfill = _mm512_set4_epi32(fill_hi, fill_lo, fill_hi, fill_lo); + const __mmask8 mask = nlane > 7 ? -1 : (1 << nlane) - 1; + __m512i ret = _mm512_mask_loadu_epi64(vfill, mask, (const __m512i*)ptr); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m512i workaround = ret; + ret = _mm512_or_si512(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s32 npyv_load2_tillz_s32(const npy_int32 *ptr, npy_uintp nlane) +{ return npyv_load_tillz_s64((const npy_int64*)ptr, nlane); } + +//// 128-bit nlane +NPY_FINLINE npyv_u64 npyv_load2_till_s64(const npy_int64 *ptr, npy_uintp nlane, + npy_int64 fill_lo, npy_int64 fill_hi) +{ + assert(nlane > 0); + const __m512i vfill = _mm512_set4_epi64(fill_hi, fill_lo, fill_hi, fill_lo); + const __mmask8 mask = nlane > 3 ? -1 : (1 << (nlane*2)) - 1; + __m512i ret = _mm512_mask_loadu_epi64(vfill, mask, (const __m512i*)ptr); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m512i workaround = ret; + ret = _mm512_or_si512(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s64 npyv_load2_tillz_s64(const npy_int64 *ptr, npy_uintp nlane) +{ + assert(nlane > 0); + const __mmask8 mask = nlane > 3 ? 
-1 : (1 << (nlane*2)) - 1; + __m512i ret = _mm512_maskz_loadu_epi64(mask, (const __m512i*)ptr); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m512i workaround = ret; + ret = _mm512_or_si512(workaround, ret); +#endif + return ret; +} +/********************************* + * Non-contiguous partial load + *********************************/ +//// 32 +NPY_FINLINE npyv_s32 +npyv_loadn_till_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane, npy_int32 fill) +{ + assert(nlane > 0); + assert(llabs(stride) <= NPY_SIMD_MAXLOAD_STRIDE32); + const __m512i steps = npyv_set_s32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ); + const __m512i idx = _mm512_mullo_epi32(steps, _mm512_set1_epi32((int)stride)); + const __m512i vfill = _mm512_set1_epi32(fill); + const __mmask16 mask = nlane > 15 ? -1 : (1 << nlane) - 1; + __m512i ret = _mm512_mask_i32gather_epi32(vfill, mask, idx, (const __m512i*)ptr, 4); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m512i workaround = ret; + ret = _mm512_or_si512(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s32 +npyv_loadn_tillz_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane) +{ return npyv_loadn_till_s32(ptr, stride, nlane, 0); } +//// 64 +NPY_FINLINE npyv_s64 +npyv_loadn_till_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane, npy_int64 fill) +{ + assert(nlane > 0); + const __m512i idx = npyv_set_s64( + 0*stride, 1*stride, 2*stride, 3*stride, + 4*stride, 5*stride, 6*stride, 7*stride + ); + const __m512i vfill = npyv_setall_s64(fill); + const __mmask8 mask = nlane > 15 ? -1 : (1 << nlane) - 1; + __m512i ret = _mm512_mask_i64gather_epi64(vfill, mask, idx, (const __m512i*)ptr, 8); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m512i workaround = ret; + ret = _mm512_or_si512(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s64 +npyv_loadn_tillz_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane) +{ return npyv_loadn_till_s64(ptr, stride, nlane, 0); } + +//// 64-bit load over 32-bit stride +NPY_FINLINE npyv_s64 npyv_loadn2_till_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane, + npy_int32 fill_lo, npy_int32 fill_hi) +{ + assert(nlane > 0); + const __m512i idx = npyv_set_s64( + 0*stride, 1*stride, 2*stride, 3*stride, + 4*stride, 5*stride, 6*stride, 7*stride + ); + const __m512i vfill = _mm512_set4_epi32(fill_hi, fill_lo, fill_hi, fill_lo); + const __mmask8 mask = nlane > 7 ? -1 : (1 << nlane) - 1; + __m512i ret = _mm512_mask_i64gather_epi64(vfill, mask, idx, (const __m512i*)ptr, 4); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m512i workaround = ret; + ret = _mm512_or_si512(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s32 npyv_loadn2_tillz_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane) +{ return npyv_loadn2_till_s32(ptr, stride, nlane, 0, 0); } + +//// 128-bit load over 64-bit stride +NPY_FINLINE npyv_s64 npyv_loadn2_till_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane, + npy_int64 fill_lo, npy_int64 fill_hi) +{ + assert(nlane > 0); + const __m512i idx = npyv_set_s64( + 0, 1, stride, stride+1, + stride*2, stride*2+1, stride*3, stride*3+1 + ); + const __mmask8 mask = nlane > 3 ? 
-1 : (1 << (nlane*2)) - 1; + const __m512i vfill = _mm512_set4_epi64(fill_hi, fill_lo, fill_hi, fill_lo); + __m512i ret = _mm512_mask_i64gather_epi64(vfill, mask, idx, (const __m512i*)ptr, 8); +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m512i workaround = ret; + ret = _mm512_or_si512(workaround, ret); +#endif + return ret; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s64 npyv_loadn2_tillz_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane) +{ return npyv_loadn2_till_s64(ptr, stride, nlane, 0, 0); } + +/********************************* + * Partial store + *********************************/ +//// 32 +NPY_FINLINE void npyv_store_till_s32(npy_int32 *ptr, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); + const __mmask16 mask = nlane > 15 ? -1 : (1 << nlane) - 1; + _mm512_mask_storeu_epi32((__m512i*)ptr, mask, a); +} +//// 64 +NPY_FINLINE void npyv_store_till_s64(npy_int64 *ptr, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); + const __mmask8 mask = nlane > 7 ? -1 : (1 << nlane) - 1; + _mm512_mask_storeu_epi64((__m512i*)ptr, mask, a); +} + +//// 64-bit nlane +NPY_FINLINE void npyv_store2_till_s32(npy_int32 *ptr, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); + const __mmask8 mask = nlane > 7 ? -1 : (1 << nlane) - 1; + _mm512_mask_storeu_epi64((__m512i*)ptr, mask, a); +} + +//// 128-bit nlane +NPY_FINLINE void npyv_store2_till_s64(npy_int64 *ptr, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); + const __mmask8 mask = nlane > 3 ? -1 : (1 << (nlane*2)) - 1; + _mm512_mask_storeu_epi64((__m512i*)ptr, mask, a); +} +/********************************* + * Non-contiguous partial store + *********************************/ +//// 32 +NPY_FINLINE void npyv_storen_till_s32(npy_int32 *ptr, npy_intp stride, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); + assert(llabs(stride) <= NPY_SIMD_MAXSTORE_STRIDE32); + const __m512i steps = _mm512_setr_epi32( + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + ); + const __m512i idx = _mm512_mullo_epi32(steps, _mm512_set1_epi32((int)stride)); + const __mmask16 mask = nlane > 15 ? -1 : (1 << nlane) - 1; + _mm512_mask_i32scatter_epi32((__m512i*)ptr, mask, idx, a, 4); +} +//// 64 +NPY_FINLINE void npyv_storen_till_s64(npy_int64 *ptr, npy_intp stride, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); + const __m512i idx = npyv_set_s64( + 0*stride, 1*stride, 2*stride, 3*stride, + 4*stride, 5*stride, 6*stride, 7*stride + ); + const __mmask8 mask = nlane > 7 ? -1 : (1 << nlane) - 1; + _mm512_mask_i64scatter_epi64((__m512i*)ptr, mask, idx, a, 8); +} + +//// 64-bit store over 32-bit stride +NPY_FINLINE void npyv_storen2_till_s32(npy_int32 *ptr, npy_intp stride, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); + const __m512i idx = npyv_set_s64( + 0*stride, 1*stride, 2*stride, 3*stride, + 4*stride, 5*stride, 6*stride, 7*stride + ); + const __mmask8 mask = nlane > 7 ? -1 : (1 << nlane) - 1; + _mm512_mask_i64scatter_epi64((__m512i*)ptr, mask, idx, a, 4); +} + +//// 128-bit store over 64-bit stride +NPY_FINLINE void npyv_storen2_till_s64(npy_int64 *ptr, npy_intp stride, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); + const __m512i idx = npyv_set_s64( + 0, 1, stride, stride+1, + 2*stride, 2*stride+1, 3*stride, 3*stride+1 + ); + const __mmask8 mask = nlane > 3 ? -1 : (1 << (nlane*2)) - 1; + _mm512_mask_i64scatter_epi64((__m512i*)ptr, mask, idx, a, 8); +} + +/***************************************************************************** + * Implement partial load/store for u32/f32/u64/f64... 
via reinterpret cast + *****************************************************************************/ +#define NPYV_IMPL_AVX512_REST_PARTIAL_TYPES(F_SFX, T_SFX) \ + NPY_FINLINE npyv_##F_SFX npyv_load_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, npyv_lanetype_##F_SFX fill) \ + { \ + union { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + } pun; \ + pun.from_##F_SFX = fill; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane, pun.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, \ + npyv_lanetype_##F_SFX fill) \ + { \ + union { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + } pun; \ + pun.from_##F_SFX = fill; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane, pun.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_load_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane \ + )); \ + } \ + NPY_FINLINE void npyv_store_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_store_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } \ + NPY_FINLINE void npyv_storen_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_storen_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, stride, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } + +NPYV_IMPL_AVX512_REST_PARTIAL_TYPES(u32, s32) +NPYV_IMPL_AVX512_REST_PARTIAL_TYPES(f32, s32) +NPYV_IMPL_AVX512_REST_PARTIAL_TYPES(u64, s64) +NPYV_IMPL_AVX512_REST_PARTIAL_TYPES(f64, s64) + +// 128-bit/64-bit stride (pair load/store) +#define NPYV_IMPL_AVX512_REST_PARTIAL_TYPES_PAIR(F_SFX, T_SFX) \ + NPY_FINLINE npyv_##F_SFX npyv_load2_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, \ + npyv_lanetype_##F_SFX fill_lo, npyv_lanetype_##F_SFX fill_hi) \ + { \ + union pun { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + }; \ + union pun pun_lo; \ + union pun pun_hi; \ + pun_lo.from_##F_SFX = fill_lo; \ + pun_hi.from_##F_SFX = fill_hi; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load2_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane, pun_lo.to_##T_SFX, pun_hi.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn2_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, \ + npyv_lanetype_##F_SFX fill_lo, npyv_lanetype_##F_SFX fill_hi) \ + { \ + union pun { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + }; \ + union pun pun_lo; \ + union pun pun_hi; \ + pun_lo.from_##F_SFX = fill_lo; \ + pun_hi.from_##F_SFX = fill_hi; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn2_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane, pun_lo.to_##T_SFX, \ + pun_hi.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX 
npyv_load2_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load2_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn2_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn2_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane \ + )); \ + } \ + NPY_FINLINE void npyv_store2_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_store2_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } \ + NPY_FINLINE void npyv_storen2_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_storen2_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, stride, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } + +NPYV_IMPL_AVX512_REST_PARTIAL_TYPES_PAIR(u32, s32) +NPYV_IMPL_AVX512_REST_PARTIAL_TYPES_PAIR(f32, s32) +NPYV_IMPL_AVX512_REST_PARTIAL_TYPES_PAIR(u64, s64) +NPYV_IMPL_AVX512_REST_PARTIAL_TYPES_PAIR(f64, s64) + +/************************************************************ + * de-interlave load / interleave contiguous store + ************************************************************/ +// two channels +#define NPYV_IMPL_AVX512_MEM_INTERLEAVE(SFX, ZSFX) \ + NPY_FINLINE npyv_##ZSFX##x2 npyv_zip_##ZSFX(npyv_##ZSFX, npyv_##ZSFX); \ + NPY_FINLINE npyv_##ZSFX##x2 npyv_unzip_##ZSFX(npyv_##ZSFX, npyv_##ZSFX); \ + NPY_FINLINE npyv_##SFX##x2 npyv_load_##SFX##x2( \ + const npyv_lanetype_##SFX *ptr \ + ) { \ + return npyv_unzip_##ZSFX( \ + npyv_load_##SFX(ptr), npyv_load_##SFX(ptr+npyv_nlanes_##SFX) \ + ); \ + } \ + NPY_FINLINE void npyv_store_##SFX##x2( \ + npyv_lanetype_##SFX *ptr, npyv_##SFX##x2 v \ + ) { \ + npyv_##SFX##x2 zip = npyv_zip_##ZSFX(v.val[0], v.val[1]); \ + npyv_store_##SFX(ptr, zip.val[0]); \ + npyv_store_##SFX(ptr + npyv_nlanes_##SFX, zip.val[1]); \ + } + +NPYV_IMPL_AVX512_MEM_INTERLEAVE(u8, u8) +NPYV_IMPL_AVX512_MEM_INTERLEAVE(s8, u8) +NPYV_IMPL_AVX512_MEM_INTERLEAVE(u16, u16) +NPYV_IMPL_AVX512_MEM_INTERLEAVE(s16, u16) +NPYV_IMPL_AVX512_MEM_INTERLEAVE(u32, u32) +NPYV_IMPL_AVX512_MEM_INTERLEAVE(s32, u32) +NPYV_IMPL_AVX512_MEM_INTERLEAVE(u64, u64) +NPYV_IMPL_AVX512_MEM_INTERLEAVE(s64, u64) +NPYV_IMPL_AVX512_MEM_INTERLEAVE(f32, f32) +NPYV_IMPL_AVX512_MEM_INTERLEAVE(f64, f64) + +/************************************************** + * Lookup table + *************************************************/ +// uses vector as indexes into a table +// that contains 32 elements of float32. +NPY_FINLINE npyv_f32 npyv_lut32_f32(const float *table, npyv_u32 idx) +{ + const npyv_f32 table0 = npyv_load_f32(table); + const npyv_f32 table1 = npyv_load_f32(table + 16); + return _mm512_permutex2var_ps(table0, idx, table1); +} +NPY_FINLINE npyv_u32 npyv_lut32_u32(const npy_uint32 *table, npyv_u32 idx) +{ return npyv_reinterpret_u32_f32(npyv_lut32_f32((const float*)table, idx)); } +NPY_FINLINE npyv_s32 npyv_lut32_s32(const npy_int32 *table, npyv_u32 idx) +{ return npyv_reinterpret_s32_f32(npyv_lut32_f32((const float*)table, idx)); } + +// uses vector as indexes into a table +// that contains 16 elements of float64. 
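+// A minimal usage sketch (illustrative only; `coeff_table` and `idx` are
+// hypothetical, not part of this header): pick one of the 16 coefficients per lane,
+//     npyv_u64 k = npyv_and_u64(idx, npyv_setall_u64(15));
+//     npyv_f64 c = npyv_lut16_f64(coeff_table, k);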
+NPY_FINLINE npyv_f64 npyv_lut16_f64(const double *table, npyv_u64 idx) +{ + const npyv_f64 table0 = npyv_load_f64(table); + const npyv_f64 table1 = npyv_load_f64(table + 8); + return _mm512_permutex2var_pd(table0, idx, table1); +} +NPY_FINLINE npyv_u64 npyv_lut16_u64(const npy_uint64 *table, npyv_u64 idx) +{ return npyv_reinterpret_u64_f64(npyv_lut16_f64((const double*)table, idx)); } +NPY_FINLINE npyv_s64 npyv_lut16_s64(const npy_int64 *table, npyv_u64 idx) +{ return npyv_reinterpret_s64_f64(npyv_lut16_f64((const double*)table, idx)); } + +#endif // _NPY_SIMD_AVX512_MEMORY_H diff --git a/mkl_umath/src/npyv/avx512/misc.h b/mkl_umath/src/npyv/avx512/misc.h new file mode 100644 index 00000000..d9190870 --- /dev/null +++ b/mkl_umath/src/npyv/avx512/misc.h @@ -0,0 +1,292 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_AVX512_MISC_H +#define _NPY_SIMD_AVX512_MISC_H + +// set all lanes to zero +#define npyv_zero_u8 _mm512_setzero_si512 +#define npyv_zero_s8 _mm512_setzero_si512 +#define npyv_zero_u16 _mm512_setzero_si512 +#define npyv_zero_s16 _mm512_setzero_si512 +#define npyv_zero_u32 _mm512_setzero_si512 +#define npyv_zero_s32 _mm512_setzero_si512 +#define npyv_zero_u64 _mm512_setzero_si512 +#define npyv_zero_s64 _mm512_setzero_si512 +#define npyv_zero_f32 _mm512_setzero_ps +#define npyv_zero_f64 _mm512_setzero_pd + +// set all lanes to same value +#define npyv_setall_u8(VAL) _mm512_set1_epi8((char)VAL) +#define npyv_setall_s8(VAL) _mm512_set1_epi8((char)VAL) +#define npyv_setall_u16(VAL) _mm512_set1_epi16((short)VAL) +#define npyv_setall_s16(VAL) _mm512_set1_epi16((short)VAL) +#define npyv_setall_u32(VAL) _mm512_set1_epi32((int)VAL) +#define npyv_setall_s32(VAL) _mm512_set1_epi32(VAL) +#define npyv_setall_f32(VAL) _mm512_set1_ps(VAL) +#define npyv_setall_f64(VAL) _mm512_set1_pd(VAL) + +NPY_FINLINE __m512i npyv__setr_epi64( + npy_int64, npy_int64, npy_int64, npy_int64, + npy_int64, npy_int64, npy_int64, npy_int64 +); +NPY_FINLINE npyv_u64 npyv_setall_u64(npy_uint64 a) +{ + npy_int64 ai = (npy_int64)a; +#if defined(_MSC_VER) && defined(_M_IX86) + return npyv__setr_epi64(ai, ai, ai, ai, ai, ai, ai, ai); +#else + return _mm512_set1_epi64(ai); +#endif +} +NPY_FINLINE npyv_s64 npyv_setall_s64(npy_int64 a) +{ +#if defined(_MSC_VER) && defined(_M_IX86) + return npyv__setr_epi64(a, a, a, a, a, a, a, a); +#else + return _mm512_set1_epi64(a); +#endif +} +/** + * vector with specific values set to each lane and + * set a specific value to all remained lanes + * + * _mm512_set_epi8 and _mm512_set_epi16 are missing in many compilers + */ +NPY_FINLINE __m512i npyv__setr_epi8( + char i0, char i1, char i2, char i3, char i4, char i5, char i6, char i7, + char i8, char i9, char i10, char i11, char i12, char i13, char i14, char i15, + char i16, char i17, char i18, char i19, char i20, char i21, char i22, char i23, + char i24, char i25, char i26, char i27, char i28, char i29, char i30, char i31, + char i32, char i33, char i34, char i35, char i36, char i37, char i38, char i39, + char i40, char i41, char i42, char i43, char i44, char i45, char i46, char i47, + char i48, char i49, char i50, char i51, char i52, char i53, char i54, char i55, + char i56, char i57, char i58, char i59, char i60, char i61, char i62, char i63) +{ + const char NPY_DECL_ALIGNED(64) data[64] = { + i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, i14, i15, + i16, i17, i18, i19, i20, i21, i22, i23, i24, i25, i26, i27, i28, i29, i30, i31, + i32, i33, i34, i35, i36, i37, i38, i39, i40, i41, i42, 
i43, i44, i45, i46, i47, + i48, i49, i50, i51, i52, i53, i54, i55, i56, i57, i58, i59, i60, i61, i62, i63 + }; + return _mm512_load_si512((const void*)data); +} +NPY_FINLINE __m512i npyv__setr_epi16( + short i0, short i1, short i2, short i3, short i4, short i5, short i6, short i7, + short i8, short i9, short i10, short i11, short i12, short i13, short i14, short i15, + short i16, short i17, short i18, short i19, short i20, short i21, short i22, short i23, + short i24, short i25, short i26, short i27, short i28, short i29, short i30, short i31) +{ + const short NPY_DECL_ALIGNED(64) data[32] = { + i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, i14, i15, + i16, i17, i18, i19, i20, i21, i22, i23, i24, i25, i26, i27, i28, i29, i30, i31 + }; + return _mm512_load_si512((const void*)data); +} +// args that generated by NPYV__SET_FILL_* not going to expand if +// _mm512_setr_* are defined as macros. +NPY_FINLINE __m512i npyv__setr_epi32( + int i0, int i1, int i2, int i3, int i4, int i5, int i6, int i7, + int i8, int i9, int i10, int i11, int i12, int i13, int i14, int i15) +{ + return _mm512_setr_epi32(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, i14, i15); +} +NPY_FINLINE __m512i npyv__setr_epi64(npy_int64 i0, npy_int64 i1, npy_int64 i2, npy_int64 i3, + npy_int64 i4, npy_int64 i5, npy_int64 i6, npy_int64 i7) +{ +#if defined(_MSC_VER) && defined(_M_IX86) + return _mm512_setr_epi32( + (int)i0, (int)(i0 >> 32), (int)i1, (int)(i1 >> 32), + (int)i2, (int)(i2 >> 32), (int)i3, (int)(i3 >> 32), + (int)i4, (int)(i4 >> 32), (int)i5, (int)(i5 >> 32), + (int)i6, (int)(i6 >> 32), (int)i7, (int)(i7 >> 32) + ); +#else + return _mm512_setr_epi64(i0, i1, i2, i3, i4, i5, i6, i7); +#endif +} + +NPY_FINLINE __m512 npyv__setr_ps( + float i0, float i1, float i2, float i3, float i4, float i5, float i6, float i7, + float i8, float i9, float i10, float i11, float i12, float i13, float i14, float i15) +{ + return _mm512_setr_ps(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, i14, i15); +} +NPY_FINLINE __m512d npyv__setr_pd(double i0, double i1, double i2, double i3, + double i4, double i5, double i6, double i7) +{ + return _mm512_setr_pd(i0, i1, i2, i3, i4, i5, i6, i7); +} +#define npyv_setf_u8(FILL, ...) npyv__setr_epi8(NPYV__SET_FILL_64(char, FILL, __VA_ARGS__)) +#define npyv_setf_s8(FILL, ...) npyv__setr_epi8(NPYV__SET_FILL_64(char, FILL, __VA_ARGS__)) +#define npyv_setf_u16(FILL, ...) npyv__setr_epi16(NPYV__SET_FILL_32(short, FILL, __VA_ARGS__)) +#define npyv_setf_s16(FILL, ...) npyv__setr_epi16(NPYV__SET_FILL_32(short, FILL, __VA_ARGS__)) +#define npyv_setf_u32(FILL, ...) npyv__setr_epi32(NPYV__SET_FILL_16(int, FILL, __VA_ARGS__)) +#define npyv_setf_s32(FILL, ...) npyv__setr_epi32(NPYV__SET_FILL_16(int, FILL, __VA_ARGS__)) +#define npyv_setf_u64(FILL, ...) npyv__setr_epi64(NPYV__SET_FILL_8(npy_int64, FILL, __VA_ARGS__)) +#define npyv_setf_s64(FILL, ...) npyv__setr_epi64(NPYV__SET_FILL_8(npy_int64, FILL, __VA_ARGS__)) +#define npyv_setf_f32(FILL, ...) npyv__setr_ps(NPYV__SET_FILL_16(float, FILL, __VA_ARGS__)) +#define npyv_setf_f64(FILL, ...) npyv__setr_pd(NPYV__SET_FILL_8(double, FILL, __VA_ARGS__)) + +// vector with specific values set to each lane and +// set zero to all remained lanes +#define npyv_set_u8(...) npyv_setf_u8(0, __VA_ARGS__) +#define npyv_set_s8(...) npyv_setf_s8(0, __VA_ARGS__) +#define npyv_set_u16(...) npyv_setf_u16(0, __VA_ARGS__) +#define npyv_set_s16(...) npyv_setf_s16(0, __VA_ARGS__) +#define npyv_set_u32(...) 
npyv_setf_u32(0, __VA_ARGS__) +#define npyv_set_s32(...) npyv_setf_s32(0, __VA_ARGS__) +#define npyv_set_u64(...) npyv_setf_u64(0, __VA_ARGS__) +#define npyv_set_s64(...) npyv_setf_s64(0, __VA_ARGS__) +#define npyv_set_f32(...) npyv_setf_f32(0, __VA_ARGS__) +#define npyv_set_f64(...) npyv_setf_f64(0, __VA_ARGS__) + +// per lane select +#ifdef NPY_HAVE_AVX512BW + #define npyv_select_u8(MASK, A, B) _mm512_mask_blend_epi8(MASK, B, A) + #define npyv_select_u16(MASK, A, B) _mm512_mask_blend_epi16(MASK, B, A) +#else + NPY_FINLINE __m512i npyv_select_u8(__m512i mask, __m512i a, __m512i b) + { return _mm512_xor_si512(b, _mm512_and_si512(_mm512_xor_si512(b, a), mask)); } + #define npyv_select_u16 npyv_select_u8 +#endif +#define npyv_select_s8 npyv_select_u8 +#define npyv_select_s16 npyv_select_u16 +#define npyv_select_u32(MASK, A, B) _mm512_mask_blend_epi32(MASK, B, A) +#define npyv_select_s32 npyv_select_u32 +#define npyv_select_u64(MASK, A, B) _mm512_mask_blend_epi64(MASK, B, A) +#define npyv_select_s64 npyv_select_u64 +#define npyv_select_f32(MASK, A, B) _mm512_mask_blend_ps(MASK, B, A) +#define npyv_select_f64(MASK, A, B) _mm512_mask_blend_pd(MASK, B, A) + +// extract the first vector's lane +#define npyv_extract0_u8(A) ((npy_uint8)_mm_cvtsi128_si32(_mm512_castsi512_si128(A))) +#define npyv_extract0_s8(A) ((npy_int8)_mm_cvtsi128_si32(_mm512_castsi512_si128(A))) +#define npyv_extract0_u16(A) ((npy_uint16)_mm_cvtsi128_si32(_mm512_castsi512_si128(A))) +#define npyv_extract0_s16(A) ((npy_int16)_mm_cvtsi128_si32(_mm512_castsi512_si128(A))) +#define npyv_extract0_u32(A) ((npy_uint32)_mm_cvtsi128_si32(_mm512_castsi512_si128(A))) +#define npyv_extract0_s32(A) ((npy_int32)_mm_cvtsi128_si32(_mm512_castsi512_si128(A))) +#define npyv_extract0_u64(A) ((npy_uint64)npyv128_cvtsi128_si64(_mm512_castsi512_si128(A))) +#define npyv_extract0_s64(A) ((npy_int64)npyv128_cvtsi128_si64(_mm512_castsi512_si128(A))) +#define npyv_extract0_f32(A) _mm_cvtss_f32(_mm512_castps512_ps128(A)) +#define npyv_extract0_f64(A) _mm_cvtsd_f64(_mm512_castpd512_pd128(A)) + +// reinterpret +#define npyv_reinterpret_u8_u8(X) X +#define npyv_reinterpret_u8_s8(X) X +#define npyv_reinterpret_u8_u16(X) X +#define npyv_reinterpret_u8_s16(X) X +#define npyv_reinterpret_u8_u32(X) X +#define npyv_reinterpret_u8_s32(X) X +#define npyv_reinterpret_u8_u64(X) X +#define npyv_reinterpret_u8_s64(X) X +#define npyv_reinterpret_u8_f32 _mm512_castps_si512 +#define npyv_reinterpret_u8_f64 _mm512_castpd_si512 + +#define npyv_reinterpret_s8_s8(X) X +#define npyv_reinterpret_s8_u8(X) X +#define npyv_reinterpret_s8_u16(X) X +#define npyv_reinterpret_s8_s16(X) X +#define npyv_reinterpret_s8_u32(X) X +#define npyv_reinterpret_s8_s32(X) X +#define npyv_reinterpret_s8_u64(X) X +#define npyv_reinterpret_s8_s64(X) X +#define npyv_reinterpret_s8_f32 _mm512_castps_si512 +#define npyv_reinterpret_s8_f64 _mm512_castpd_si512 + +#define npyv_reinterpret_u16_u16(X) X +#define npyv_reinterpret_u16_u8(X) X +#define npyv_reinterpret_u16_s8(X) X +#define npyv_reinterpret_u16_s16(X) X +#define npyv_reinterpret_u16_u32(X) X +#define npyv_reinterpret_u16_s32(X) X +#define npyv_reinterpret_u16_u64(X) X +#define npyv_reinterpret_u16_s64(X) X +#define npyv_reinterpret_u16_f32 _mm512_castps_si512 +#define npyv_reinterpret_u16_f64 _mm512_castpd_si512 + +#define npyv_reinterpret_s16_s16(X) X +#define npyv_reinterpret_s16_u8(X) X +#define npyv_reinterpret_s16_s8(X) X +#define npyv_reinterpret_s16_u16(X) X +#define npyv_reinterpret_s16_u32(X) X +#define npyv_reinterpret_s16_s32(X) X 
+#define npyv_reinterpret_s16_u64(X) X +#define npyv_reinterpret_s16_s64(X) X +#define npyv_reinterpret_s16_f32 _mm512_castps_si512 +#define npyv_reinterpret_s16_f64 _mm512_castpd_si512 + +#define npyv_reinterpret_u32_u32(X) X +#define npyv_reinterpret_u32_u8(X) X +#define npyv_reinterpret_u32_s8(X) X +#define npyv_reinterpret_u32_u16(X) X +#define npyv_reinterpret_u32_s16(X) X +#define npyv_reinterpret_u32_s32(X) X +#define npyv_reinterpret_u32_u64(X) X +#define npyv_reinterpret_u32_s64(X) X +#define npyv_reinterpret_u32_f32 _mm512_castps_si512 +#define npyv_reinterpret_u32_f64 _mm512_castpd_si512 + +#define npyv_reinterpret_s32_s32(X) X +#define npyv_reinterpret_s32_u8(X) X +#define npyv_reinterpret_s32_s8(X) X +#define npyv_reinterpret_s32_u16(X) X +#define npyv_reinterpret_s32_s16(X) X +#define npyv_reinterpret_s32_u32(X) X +#define npyv_reinterpret_s32_u64(X) X +#define npyv_reinterpret_s32_s64(X) X +#define npyv_reinterpret_s32_f32 _mm512_castps_si512 +#define npyv_reinterpret_s32_f64 _mm512_castpd_si512 + +#define npyv_reinterpret_u64_u64(X) X +#define npyv_reinterpret_u64_u8(X) X +#define npyv_reinterpret_u64_s8(X) X +#define npyv_reinterpret_u64_u16(X) X +#define npyv_reinterpret_u64_s16(X) X +#define npyv_reinterpret_u64_u32(X) X +#define npyv_reinterpret_u64_s32(X) X +#define npyv_reinterpret_u64_s64(X) X +#define npyv_reinterpret_u64_f32 _mm512_castps_si512 +#define npyv_reinterpret_u64_f64 _mm512_castpd_si512 + +#define npyv_reinterpret_s64_s64(X) X +#define npyv_reinterpret_s64_u8(X) X +#define npyv_reinterpret_s64_s8(X) X +#define npyv_reinterpret_s64_u16(X) X +#define npyv_reinterpret_s64_s16(X) X +#define npyv_reinterpret_s64_u32(X) X +#define npyv_reinterpret_s64_s32(X) X +#define npyv_reinterpret_s64_u64(X) X +#define npyv_reinterpret_s64_f32 _mm512_castps_si512 +#define npyv_reinterpret_s64_f64 _mm512_castpd_si512 + +#define npyv_reinterpret_f32_f32(X) X +#define npyv_reinterpret_f32_u8 _mm512_castsi512_ps +#define npyv_reinterpret_f32_s8 _mm512_castsi512_ps +#define npyv_reinterpret_f32_u16 _mm512_castsi512_ps +#define npyv_reinterpret_f32_s16 _mm512_castsi512_ps +#define npyv_reinterpret_f32_u32 _mm512_castsi512_ps +#define npyv_reinterpret_f32_s32 _mm512_castsi512_ps +#define npyv_reinterpret_f32_u64 _mm512_castsi512_ps +#define npyv_reinterpret_f32_s64 _mm512_castsi512_ps +#define npyv_reinterpret_f32_f64 _mm512_castpd_ps + +#define npyv_reinterpret_f64_f64(X) X +#define npyv_reinterpret_f64_u8 _mm512_castsi512_pd +#define npyv_reinterpret_f64_s8 _mm512_castsi512_pd +#define npyv_reinterpret_f64_u16 _mm512_castsi512_pd +#define npyv_reinterpret_f64_s16 _mm512_castsi512_pd +#define npyv_reinterpret_f64_u32 _mm512_castsi512_pd +#define npyv_reinterpret_f64_s32 _mm512_castsi512_pd +#define npyv_reinterpret_f64_u64 _mm512_castsi512_pd +#define npyv_reinterpret_f64_s64 _mm512_castsi512_pd +#define npyv_reinterpret_f64_f32 _mm512_castps_pd + +#ifdef NPY_HAVE_AVX512_KNL + #define npyv_cleanup() ((void)0) +#else + #define npyv_cleanup _mm256_zeroall +#endif + +#endif // _NPY_SIMD_AVX512_MISC_H diff --git a/mkl_umath/src/npyv/avx512/operators.h b/mkl_umath/src/npyv/avx512/operators.h new file mode 100644 index 00000000..c70932d5 --- /dev/null +++ b/mkl_umath/src/npyv/avx512/operators.h @@ -0,0 +1,380 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_AVX512_OPERATORS_H +#define _NPY_SIMD_AVX512_OPERATORS_H + +#include "conversion.h" // tobits + +/*************************** + * Shifting + ***************************/ + +// left +#ifdef 
NPY_HAVE_AVX512BW + #define npyv_shl_u16(A, C) _mm512_sll_epi16(A, _mm_cvtsi32_si128(C)) +#else + #define NPYV_IMPL_AVX512_SHIFT(FN, INTRIN) \ + NPY_FINLINE __m512i npyv_##FN(__m512i a, int c) \ + { \ + __m256i l = npyv512_lower_si256(a); \ + __m256i h = npyv512_higher_si256(a); \ + __m128i cv = _mm_cvtsi32_si128(c); \ + l = _mm256_##INTRIN(l, cv); \ + h = _mm256_##INTRIN(h, cv); \ + return npyv512_combine_si256(l, h); \ + } + + NPYV_IMPL_AVX512_SHIFT(shl_u16, sll_epi16) +#endif +#define npyv_shl_s16 npyv_shl_u16 +#define npyv_shl_u32(A, C) _mm512_sll_epi32(A, _mm_cvtsi32_si128(C)) +#define npyv_shl_s32(A, C) _mm512_sll_epi32(A, _mm_cvtsi32_si128(C)) +#define npyv_shl_u64(A, C) _mm512_sll_epi64(A, _mm_cvtsi32_si128(C)) +#define npyv_shl_s64(A, C) _mm512_sll_epi64(A, _mm_cvtsi32_si128(C)) + +// left by an immediate constant +#ifdef NPY_HAVE_AVX512BW + #define npyv_shli_u16 _mm512_slli_epi16 +#else + #define npyv_shli_u16 npyv_shl_u16 +#endif +#define npyv_shli_s16 npyv_shl_u16 +#define npyv_shli_u32 _mm512_slli_epi32 +#define npyv_shli_s32 _mm512_slli_epi32 +#define npyv_shli_u64 _mm512_slli_epi64 +#define npyv_shli_s64 _mm512_slli_epi64 + +// right +#ifdef NPY_HAVE_AVX512BW + #define npyv_shr_u16(A, C) _mm512_srl_epi16(A, _mm_cvtsi32_si128(C)) + #define npyv_shr_s16(A, C) _mm512_sra_epi16(A, _mm_cvtsi32_si128(C)) +#else + NPYV_IMPL_AVX512_SHIFT(shr_u16, srl_epi16) + NPYV_IMPL_AVX512_SHIFT(shr_s16, sra_epi16) +#endif +#define npyv_shr_u32(A, C) _mm512_srl_epi32(A, _mm_cvtsi32_si128(C)) +#define npyv_shr_s32(A, C) _mm512_sra_epi32(A, _mm_cvtsi32_si128(C)) +#define npyv_shr_u64(A, C) _mm512_srl_epi64(A, _mm_cvtsi32_si128(C)) +#define npyv_shr_s64(A, C) _mm512_sra_epi64(A, _mm_cvtsi32_si128(C)) + +// right by an immediate constant +#ifdef NPY_HAVE_AVX512BW + #define npyv_shri_u16 _mm512_srli_epi16 + #define npyv_shri_s16 _mm512_srai_epi16 +#else + #define npyv_shri_u16 npyv_shr_u16 + #define npyv_shri_s16 npyv_shr_s16 +#endif +#define npyv_shri_u32 _mm512_srli_epi32 +#define npyv_shri_s32 _mm512_srai_epi32 +#define npyv_shri_u64 _mm512_srli_epi64 +#define npyv_shri_s64 _mm512_srai_epi64 + +/*************************** + * Logical + ***************************/ + +// AND +#define npyv_and_u8 _mm512_and_si512 +#define npyv_and_s8 _mm512_and_si512 +#define npyv_and_u16 _mm512_and_si512 +#define npyv_and_s16 _mm512_and_si512 +#define npyv_and_u32 _mm512_and_si512 +#define npyv_and_s32 _mm512_and_si512 +#define npyv_and_u64 _mm512_and_si512 +#define npyv_and_s64 _mm512_and_si512 +#ifdef NPY_HAVE_AVX512DQ + #define npyv_and_f32 _mm512_and_ps + #define npyv_and_f64 _mm512_and_pd +#else + NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(npyv_and_f32, _mm512_and_si512) + NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(npyv_and_f64, _mm512_and_si512) +#endif +// OR +#define npyv_or_u8 _mm512_or_si512 +#define npyv_or_s8 _mm512_or_si512 +#define npyv_or_u16 _mm512_or_si512 +#define npyv_or_s16 _mm512_or_si512 +#define npyv_or_u32 _mm512_or_si512 +#define npyv_or_s32 _mm512_or_si512 +#define npyv_or_u64 _mm512_or_si512 +#define npyv_or_s64 _mm512_or_si512 +#ifdef NPY_HAVE_AVX512DQ + #define npyv_or_f32 _mm512_or_ps + #define npyv_or_f64 _mm512_or_pd +#else + NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(npyv_or_f32, _mm512_or_si512) + NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(npyv_or_f64, _mm512_or_si512) +#endif + +// XOR +#define npyv_xor_u8 _mm512_xor_si512 +#define npyv_xor_s8 _mm512_xor_si512 +#define npyv_xor_u16 _mm512_xor_si512 +#define npyv_xor_s16 _mm512_xor_si512 +#define npyv_xor_u32 _mm512_xor_si512 +#define npyv_xor_s32 
_mm512_xor_si512 +#define npyv_xor_u64 _mm512_xor_si512 +#define npyv_xor_s64 _mm512_xor_si512 +#ifdef NPY_HAVE_AVX512DQ + #define npyv_xor_f32 _mm512_xor_ps + #define npyv_xor_f64 _mm512_xor_pd +#else + NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(npyv_xor_f32, _mm512_xor_si512) + NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(npyv_xor_f64, _mm512_xor_si512) +#endif +// NOT +#define npyv_not_u8(A) _mm512_xor_si512(A, _mm512_set1_epi32(-1)) +#define npyv_not_s8 npyv_not_u8 +#define npyv_not_u16 npyv_not_u8 +#define npyv_not_s16 npyv_not_u8 +#define npyv_not_u32 npyv_not_u8 +#define npyv_not_s32 npyv_not_u8 +#define npyv_not_u64 npyv_not_u8 +#define npyv_not_s64 npyv_not_u8 +#ifdef NPY_HAVE_AVX512DQ + #define npyv_not_f32(A) _mm512_xor_ps(A, _mm512_castsi512_ps(_mm512_set1_epi32(-1))) + #define npyv_not_f64(A) _mm512_xor_pd(A, _mm512_castsi512_pd(_mm512_set1_epi32(-1))) +#else + #define npyv_not_f32(A) _mm512_castsi512_ps(npyv_not_u32(_mm512_castps_si512(A))) + #define npyv_not_f64(A) _mm512_castsi512_pd(npyv_not_u64(_mm512_castpd_si512(A))) +#endif + +// ANDC +#define npyv_andc_u8(A, B) _mm512_andnot_si512(B, A) + +/*************************** + * Logical (boolean) + ***************************/ +#ifdef NPY_HAVE_AVX512BW_MASK + #define npyv_and_b8 _kand_mask64 + #define npyv_and_b16 _kand_mask32 + #define npyv_or_b8 _kor_mask64 + #define npyv_or_b16 _kor_mask32 + #define npyv_xor_b8 _kxor_mask64 + #define npyv_xor_b16 _kxor_mask32 + #define npyv_not_b8 _knot_mask64 + #define npyv_not_b16 _knot_mask32 + #define npyv_andc_b8(A, B) _kandn_mask64(B, A) + #define npyv_orc_b8(A, B) npyv_or_b8(npyv_not_b8(B), A) + #define npyv_xnor_b8 _kxnor_mask64 +#elif defined(NPY_HAVE_AVX512BW) + NPY_FINLINE npyv_b8 npyv_and_b8(npyv_b8 a, npyv_b8 b) + { return a & b; } + NPY_FINLINE npyv_b16 npyv_and_b16(npyv_b16 a, npyv_b16 b) + { return a & b; } + NPY_FINLINE npyv_b8 npyv_or_b8(npyv_b8 a, npyv_b8 b) + { return a | b; } + NPY_FINLINE npyv_b16 npyv_or_b16(npyv_b16 a, npyv_b16 b) + { return a | b; } + NPY_FINLINE npyv_b8 npyv_xor_b8(npyv_b8 a, npyv_b8 b) + { return a ^ b; } + NPY_FINLINE npyv_b16 npyv_xor_b16(npyv_b16 a, npyv_b16 b) + { return a ^ b; } + NPY_FINLINE npyv_b8 npyv_not_b8(npyv_b8 a) + { return ~a; } + NPY_FINLINE npyv_b16 npyv_not_b16(npyv_b16 a) + { return ~a; } + NPY_FINLINE npyv_b8 npyv_andc_b8(npyv_b8 a, npyv_b8 b) + { return a & (~b); } + NPY_FINLINE npyv_b8 npyv_orc_b8(npyv_b8 a, npyv_b8 b) + { return a | (~b); } + NPY_FINLINE npyv_b8 npyv_xnor_b8(npyv_b8 a, npyv_b8 b) + { return ~(a ^ b); } +#else + #define npyv_and_b8 _mm512_and_si512 + #define npyv_and_b16 _mm512_and_si512 + #define npyv_or_b8 _mm512_or_si512 + #define npyv_or_b16 _mm512_or_si512 + #define npyv_xor_b8 _mm512_xor_si512 + #define npyv_xor_b16 _mm512_xor_si512 + #define npyv_not_b8 npyv_not_u8 + #define npyv_not_b16 npyv_not_u8 + #define npyv_andc_b8(A, B) _mm512_andnot_si512(B, A) + #define npyv_orc_b8(A, B) npyv_or_b8(npyv_not_b8(B), A) + #define npyv_xnor_b8(A, B) npyv_not_b8(npyv_xor_b8(A, B)) +#endif + +#define npyv_and_b32 _mm512_kand +#define npyv_or_b32 _mm512_kor +#define npyv_xor_b32 _mm512_kxor +#define npyv_not_b32 _mm512_knot + +#ifdef NPY_HAVE_AVX512DQ_MASK + #define npyv_and_b64 _kand_mask8 + #define npyv_or_b64 _kor_mask8 + #define npyv_xor_b64 _kxor_mask8 + #define npyv_not_b64 _knot_mask8 +#else + NPY_FINLINE npyv_b64 npyv_and_b64(npyv_b64 a, npyv_b64 b) + { return (npyv_b64)_mm512_kand((npyv_b32)a, (npyv_b32)b); } + NPY_FINLINE npyv_b64 npyv_or_b64(npyv_b64 a, npyv_b64 b) + { return (npyv_b64)_mm512_kor((npyv_b32)a, 
(npyv_b32)b); } + NPY_FINLINE npyv_b64 npyv_xor_b64(npyv_b64 a, npyv_b64 b) + { return (npyv_b64)_mm512_kxor((npyv_b32)a, (npyv_b32)b); } + NPY_FINLINE npyv_b64 npyv_not_b64(npyv_b64 a) + { return (npyv_b64)_mm512_knot((npyv_b32)a); } +#endif + +/*************************** + * Comparison + ***************************/ + +// int Equal +#ifdef NPY_HAVE_AVX512BW + #define npyv_cmpeq_u8 _mm512_cmpeq_epu8_mask + #define npyv_cmpeq_s8 _mm512_cmpeq_epi8_mask + #define npyv_cmpeq_u16 _mm512_cmpeq_epu16_mask + #define npyv_cmpeq_s16 _mm512_cmpeq_epi16_mask +#else + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_cmpeq_u8, _mm256_cmpeq_epi8) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_cmpeq_u16, _mm256_cmpeq_epi16) + #define npyv_cmpeq_s8 npyv_cmpeq_u8 + #define npyv_cmpeq_s16 npyv_cmpeq_u16 +#endif +#define npyv_cmpeq_u32 _mm512_cmpeq_epu32_mask +#define npyv_cmpeq_s32 _mm512_cmpeq_epi32_mask +#define npyv_cmpeq_u64 _mm512_cmpeq_epu64_mask +#define npyv_cmpeq_s64 _mm512_cmpeq_epi64_mask + +// int not equal +#ifdef NPY_HAVE_AVX512BW + #define npyv_cmpneq_u8 _mm512_cmpneq_epu8_mask + #define npyv_cmpneq_s8 _mm512_cmpneq_epi8_mask + #define npyv_cmpneq_u16 _mm512_cmpneq_epu16_mask + #define npyv_cmpneq_s16 _mm512_cmpneq_epi16_mask +#else + #define npyv_cmpneq_u8(A, B) npyv_not_u8(npyv_cmpeq_u8(A, B)) + #define npyv_cmpneq_u16(A, B) npyv_not_u16(npyv_cmpeq_u16(A, B)) + #define npyv_cmpneq_s8 npyv_cmpneq_u8 + #define npyv_cmpneq_s16 npyv_cmpneq_u16 +#endif +#define npyv_cmpneq_u32 _mm512_cmpneq_epu32_mask +#define npyv_cmpneq_s32 _mm512_cmpneq_epi32_mask +#define npyv_cmpneq_u64 _mm512_cmpneq_epu64_mask +#define npyv_cmpneq_s64 _mm512_cmpneq_epi64_mask + +// greater than +#ifdef NPY_HAVE_AVX512BW + #define npyv_cmpgt_u8 _mm512_cmpgt_epu8_mask + #define npyv_cmpgt_s8 _mm512_cmpgt_epi8_mask + #define npyv_cmpgt_u16 _mm512_cmpgt_epu16_mask + #define npyv_cmpgt_s16 _mm512_cmpgt_epi16_mask +#else + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_cmpgt_s8, _mm256_cmpgt_epi8) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv_cmpgt_s16, _mm256_cmpgt_epi16) + NPY_FINLINE __m512i npyv_cmpgt_u8(__m512i a, __m512i b) + { + const __m512i sbit = _mm512_set1_epi32(0x80808080); + return npyv_cmpgt_s8(_mm512_xor_si512(a, sbit), _mm512_xor_si512(b, sbit)); + } + NPY_FINLINE __m512i npyv_cmpgt_u16(__m512i a, __m512i b) + { + const __m512i sbit = _mm512_set1_epi32(0x80008000); + return npyv_cmpgt_s16(_mm512_xor_si512(a, sbit), _mm512_xor_si512(b, sbit)); + } +#endif +#define npyv_cmpgt_u32 _mm512_cmpgt_epu32_mask +#define npyv_cmpgt_s32 _mm512_cmpgt_epi32_mask +#define npyv_cmpgt_u64 _mm512_cmpgt_epu64_mask +#define npyv_cmpgt_s64 _mm512_cmpgt_epi64_mask + +// greater than or equal +#ifdef NPY_HAVE_AVX512BW + #define npyv_cmpge_u8 _mm512_cmpge_epu8_mask + #define npyv_cmpge_s8 _mm512_cmpge_epi8_mask + #define npyv_cmpge_u16 _mm512_cmpge_epu16_mask + #define npyv_cmpge_s16 _mm512_cmpge_epi16_mask +#else + #define npyv_cmpge_u8(A, B) npyv_not_u8(npyv_cmpgt_u8(B, A)) + #define npyv_cmpge_s8(A, B) npyv_not_s8(npyv_cmpgt_s8(B, A)) + #define npyv_cmpge_u16(A, B) npyv_not_u16(npyv_cmpgt_u16(B, A)) + #define npyv_cmpge_s16(A, B) npyv_not_s16(npyv_cmpgt_s16(B, A)) +#endif +#define npyv_cmpge_u32 _mm512_cmpge_epu32_mask +#define npyv_cmpge_s32 _mm512_cmpge_epi32_mask +#define npyv_cmpge_u64 _mm512_cmpge_epu64_mask +#define npyv_cmpge_s64 _mm512_cmpge_epi64_mask + +// less than +#define npyv_cmplt_u8(A, B) npyv_cmpgt_u8(B, A) +#define npyv_cmplt_s8(A, B) npyv_cmpgt_s8(B, A) +#define npyv_cmplt_u16(A, B) npyv_cmpgt_u16(B, A) +#define npyv_cmplt_s16(A, B) 
npyv_cmpgt_s16(B, A) +#define npyv_cmplt_u32(A, B) npyv_cmpgt_u32(B, A) +#define npyv_cmplt_s32(A, B) npyv_cmpgt_s32(B, A) +#define npyv_cmplt_u64(A, B) npyv_cmpgt_u64(B, A) +#define npyv_cmplt_s64(A, B) npyv_cmpgt_s64(B, A) + +// less than or equal +#define npyv_cmple_u8(A, B) npyv_cmpge_u8(B, A) +#define npyv_cmple_s8(A, B) npyv_cmpge_s8(B, A) +#define npyv_cmple_u16(A, B) npyv_cmpge_u16(B, A) +#define npyv_cmple_s16(A, B) npyv_cmpge_s16(B, A) +#define npyv_cmple_u32(A, B) npyv_cmpge_u32(B, A) +#define npyv_cmple_s32(A, B) npyv_cmpge_s32(B, A) +#define npyv_cmple_u64(A, B) npyv_cmpge_u64(B, A) +#define npyv_cmple_s64(A, B) npyv_cmpge_s64(B, A) + +// precision comparison +#define npyv_cmpeq_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_EQ_OQ) +#define npyv_cmpeq_f64(A, B) _mm512_cmp_pd_mask(A, B, _CMP_EQ_OQ) +#define npyv_cmpneq_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_NEQ_UQ) +#define npyv_cmpneq_f64(A, B) _mm512_cmp_pd_mask(A, B, _CMP_NEQ_UQ) +#define npyv_cmplt_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_LT_OQ) +#define npyv_cmplt_f64(A, B) _mm512_cmp_pd_mask(A, B, _CMP_LT_OQ) +#define npyv_cmple_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_LE_OQ) +#define npyv_cmple_f64(A, B) _mm512_cmp_pd_mask(A, B, _CMP_LE_OQ) +#define npyv_cmpgt_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_GT_OQ) +#define npyv_cmpgt_f64(A, B) _mm512_cmp_pd_mask(A, B, _CMP_GT_OQ) +#define npyv_cmpge_f32(A, B) _mm512_cmp_ps_mask(A, B, _CMP_GE_OQ) +#define npyv_cmpge_f64(A, B) _mm512_cmp_pd_mask(A, B, _CMP_GE_OQ) + +// check special cases +NPY_FINLINE npyv_b32 npyv_notnan_f32(npyv_f32 a) +{ return _mm512_cmp_ps_mask(a, a, _CMP_ORD_Q); } +NPY_FINLINE npyv_b64 npyv_notnan_f64(npyv_f64 a) +{ return _mm512_cmp_pd_mask(a, a, _CMP_ORD_Q); } + +// Test cross all vector lanes +// any: returns true if any of the elements is not equal to zero +// all: returns true if all elements are not equal to zero +#define NPYV_IMPL_AVX512_ANYALL(SFX, MASK) \ + NPY_FINLINE bool npyv_any_##SFX(npyv_##SFX a) \ + { return npyv_tobits_##SFX(a) != 0; } \ + NPY_FINLINE bool npyv_all_##SFX(npyv_##SFX a) \ + { return npyv_tobits_##SFX(a) == MASK; } +NPYV_IMPL_AVX512_ANYALL(b8, 0xffffffffffffffffull) +NPYV_IMPL_AVX512_ANYALL(b16, 0xfffffffful) +NPYV_IMPL_AVX512_ANYALL(b32, 0xffff) +NPYV_IMPL_AVX512_ANYALL(b64, 0xff) +#undef NPYV_IMPL_AVX512_ANYALL + +#define NPYV_IMPL_AVX512_ANYALL(SFX, BSFX, MASK) \ + NPY_FINLINE bool npyv_any_##SFX(npyv_##SFX a) \ + { \ + return npyv_tobits_##BSFX( \ + npyv_cmpeq_##SFX(a, npyv_zero_##SFX()) \ + ) != MASK; \ + } \ + NPY_FINLINE bool npyv_all_##SFX(npyv_##SFX a) \ + { \ + return npyv_tobits_##BSFX( \ + npyv_cmpeq_##SFX(a, npyv_zero_##SFX()) \ + ) == 0; \ + } +NPYV_IMPL_AVX512_ANYALL(u8, b8, 0xffffffffffffffffull) +NPYV_IMPL_AVX512_ANYALL(s8, b8, 0xffffffffffffffffull) +NPYV_IMPL_AVX512_ANYALL(u16, b16, 0xfffffffful) +NPYV_IMPL_AVX512_ANYALL(s16, b16, 0xfffffffful) +NPYV_IMPL_AVX512_ANYALL(u32, b32, 0xffff) +NPYV_IMPL_AVX512_ANYALL(s32, b32, 0xffff) +NPYV_IMPL_AVX512_ANYALL(u64, b64, 0xff) +NPYV_IMPL_AVX512_ANYALL(s64, b64, 0xff) +NPYV_IMPL_AVX512_ANYALL(f32, b32, 0xffff) +NPYV_IMPL_AVX512_ANYALL(f64, b64, 0xff) +#undef NPYV_IMPL_AVX512_ANYALL + +#endif // _NPY_SIMD_AVX512_OPERATORS_H diff --git a/mkl_umath/src/npyv/avx512/reorder.h b/mkl_umath/src/npyv/avx512/reorder.h new file mode 100644 index 00000000..27e66b5e --- /dev/null +++ b/mkl_umath/src/npyv/avx512/reorder.h @@ -0,0 +1,378 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_AVX512_REORDER_H +#define _NPY_SIMD_AVX512_REORDER_H + +// combine 
lower part of two vectors +#define npyv_combinel_u8(A, B) _mm512_inserti64x4(A, _mm512_castsi512_si256(B), 1) +#define npyv_combinel_s8 npyv_combinel_u8 +#define npyv_combinel_u16 npyv_combinel_u8 +#define npyv_combinel_s16 npyv_combinel_u8 +#define npyv_combinel_u32 npyv_combinel_u8 +#define npyv_combinel_s32 npyv_combinel_u8 +#define npyv_combinel_u64 npyv_combinel_u8 +#define npyv_combinel_s64 npyv_combinel_u8 +#define npyv_combinel_f64(A, B) _mm512_insertf64x4(A, _mm512_castpd512_pd256(B), 1) +#ifdef NPY_HAVE_AVX512DQ + #define npyv_combinel_f32(A, B) \ + _mm512_insertf32x8(A, _mm512_castps512_ps256(B), 1) +#else + #define npyv_combinel_f32(A, B) \ + _mm512_castsi512_ps(npyv_combinel_u8(_mm512_castps_si512(A), _mm512_castps_si512(B))) +#endif + +// combine higher part of two vectors +#define npyv_combineh_u8(A, B) _mm512_inserti64x4(B, _mm512_extracti64x4_epi64(A, 1), 0) +#define npyv_combineh_s8 npyv_combineh_u8 +#define npyv_combineh_u16 npyv_combineh_u8 +#define npyv_combineh_s16 npyv_combineh_u8 +#define npyv_combineh_u32 npyv_combineh_u8 +#define npyv_combineh_s32 npyv_combineh_u8 +#define npyv_combineh_u64 npyv_combineh_u8 +#define npyv_combineh_s64 npyv_combineh_u8 +#define npyv_combineh_f64(A, B) _mm512_insertf64x4(B, _mm512_extractf64x4_pd(A, 1), 0) +#ifdef NPY_HAVE_AVX512DQ + #define npyv_combineh_f32(A, B) \ + _mm512_insertf32x8(B, _mm512_extractf32x8_ps(A, 1), 0) +#else + #define npyv_combineh_f32(A, B) \ + _mm512_castsi512_ps(npyv_combineh_u8(_mm512_castps_si512(A), _mm512_castps_si512(B))) +#endif + +// combine two vectors from lower and higher parts of two other vectors +NPY_FINLINE npyv_m512ix2 npyv__combine(__m512i a, __m512i b) +{ + npyv_m512ix2 r; + r.val[0] = npyv_combinel_u8(a, b); + r.val[1] = npyv_combineh_u8(a, b); + return r; +} +NPY_FINLINE npyv_f32x2 npyv_combine_f32(__m512 a, __m512 b) +{ + npyv_f32x2 r; + r.val[0] = npyv_combinel_f32(a, b); + r.val[1] = npyv_combineh_f32(a, b); + return r; +} +NPY_FINLINE npyv_f64x2 npyv_combine_f64(__m512d a, __m512d b) +{ + npyv_f64x2 r; + r.val[0] = npyv_combinel_f64(a, b); + r.val[1] = npyv_combineh_f64(a, b); + return r; +} +#define npyv_combine_u8 npyv__combine +#define npyv_combine_s8 npyv__combine +#define npyv_combine_u16 npyv__combine +#define npyv_combine_s16 npyv__combine +#define npyv_combine_u32 npyv__combine +#define npyv_combine_s32 npyv__combine +#define npyv_combine_u64 npyv__combine +#define npyv_combine_s64 npyv__combine + +// interleave two vectors +#ifndef NPY_HAVE_AVX512BW + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv__unpacklo_epi8, _mm256_unpacklo_epi8) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv__unpackhi_epi8, _mm256_unpackhi_epi8) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv__unpacklo_epi16, _mm256_unpacklo_epi16) + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv__unpackhi_epi16, _mm256_unpackhi_epi16) +#endif + +NPY_FINLINE npyv_u64x2 npyv_zip_u64(__m512i a, __m512i b) +{ + npyv_u64x2 r; + r.val[0] = _mm512_permutex2var_epi64(a, npyv_set_u64(0, 8, 1, 9, 2, 10, 3, 11), b); + r.val[1] = _mm512_permutex2var_epi64(a, npyv_set_u64(4, 12, 5, 13, 6, 14, 7, 15), b); + return r; +} +#define npyv_zip_s64 npyv_zip_u64 + +NPY_FINLINE npyv_u8x2 npyv_zip_u8(__m512i a, __m512i b) +{ + npyv_u8x2 r; +#ifdef NPY_HAVE_AVX512VBMI + r.val[0] = _mm512_permutex2var_epi8(a, + npyv_set_u8(0, 64, 1, 65, 2, 66, 3, 67, 4, 68, 5, 69, 6, 70, 7, 71, + 8, 72, 9, 73, 10, 74, 11, 75, 12, 76, 13, 77, 14, 78, 15, 79, + 16, 80, 17, 81, 18, 82, 19, 83, 20, 84, 21, 85, 22, 86, 23, 87, + 24, 88, 25, 89, 26, 90, 27, 91, 28, 92, 29, 93, 30, 94, 31, 95), b); + 
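+    // indices 64..127 address lanes of b, so the second permute below
+    // interleaves the upper halves (lanes 32..63) of a and b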
r.val[1] = _mm512_permutex2var_epi8(a, + npyv_set_u8(32, 96, 33, 97, 34, 98, 35, 99, 36, 100, 37, 101, 38, 102, 39, 103, + 40, 104, 41, 105, 42, 106, 43, 107, 44, 108, 45, 109, 46, 110, 47, 111, + 48, 112, 49, 113, 50, 114, 51, 115, 52, 116, 53, 117, 54, 118, 55, 119, + 56, 120, 57, 121, 58, 122, 59, 123, 60, 124, 61, 125, 62, 126, 63, 127), b); +#else + #ifdef NPY_HAVE_AVX512BW + __m512i ab0 = _mm512_unpacklo_epi8(a, b); + __m512i ab1 = _mm512_unpackhi_epi8(a, b); + #else + __m512i ab0 = npyv__unpacklo_epi8(a, b); + __m512i ab1 = npyv__unpackhi_epi8(a, b); + #endif + r.val[0] = _mm512_permutex2var_epi64(ab0, npyv_set_u64(0, 1, 8, 9, 2, 3, 10, 11), ab1); + r.val[1] = _mm512_permutex2var_epi64(ab0, npyv_set_u64(4, 5, 12, 13, 6, 7, 14, 15), ab1); +#endif + return r; +} +#define npyv_zip_s8 npyv_zip_u8 + +NPY_FINLINE npyv_u16x2 npyv_zip_u16(__m512i a, __m512i b) +{ + npyv_u16x2 r; +#ifdef NPY_HAVE_AVX512BW + r.val[0] = _mm512_permutex2var_epi16(a, + npyv_set_u16(0, 32, 1, 33, 2, 34, 3, 35, 4, 36, 5, 37, 6, 38, 7, 39, + 8, 40, 9, 41, 10, 42, 11, 43, 12, 44, 13, 45, 14, 46, 15, 47), b); + r.val[1] = _mm512_permutex2var_epi16(a, + npyv_set_u16(16, 48, 17, 49, 18, 50, 19, 51, 20, 52, 21, 53, 22, 54, 23, 55, + 24, 56, 25, 57, 26, 58, 27, 59, 28, 60, 29, 61, 30, 62, 31, 63), b); +#else + __m512i ab0 = npyv__unpacklo_epi16(a, b); + __m512i ab1 = npyv__unpackhi_epi16(a, b); + r.val[0] = _mm512_permutex2var_epi64(ab0, npyv_set_u64(0, 1, 8, 9, 2, 3, 10, 11), ab1); + r.val[1] = _mm512_permutex2var_epi64(ab0, npyv_set_u64(4, 5, 12, 13, 6, 7, 14, 15), ab1); +#endif + return r; +} +#define npyv_zip_s16 npyv_zip_u16 + +NPY_FINLINE npyv_u32x2 npyv_zip_u32(__m512i a, __m512i b) +{ + npyv_u32x2 r; + r.val[0] = _mm512_permutex2var_epi32(a, + npyv_set_u32(0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23), b); + r.val[1] = _mm512_permutex2var_epi32(a, + npyv_set_u32(8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31), b); + return r; +} +#define npyv_zip_s32 npyv_zip_u32 + +NPY_FINLINE npyv_f32x2 npyv_zip_f32(__m512 a, __m512 b) +{ + npyv_f32x2 r; + r.val[0] = _mm512_permutex2var_ps(a, + npyv_set_u32(0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23), b); + r.val[1] = _mm512_permutex2var_ps(a, + npyv_set_u32(8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31), b); + return r; +} + +NPY_FINLINE npyv_f64x2 npyv_zip_f64(__m512d a, __m512d b) +{ + npyv_f64x2 r; + r.val[0] = _mm512_permutex2var_pd(a, npyv_set_u64(0, 8, 1, 9, 2, 10, 3, 11), b); + r.val[1] = _mm512_permutex2var_pd(a, npyv_set_u64(4, 12, 5, 13, 6, 14, 7, 15), b); + return r; +} + +// deinterleave two vectors +NPY_FINLINE npyv_u8x2 npyv_unzip_u8(npyv_u8 ab0, npyv_u8 ab1) +{ + npyv_u8x2 r; +#ifdef NPY_HAVE_AVX512VBMI + const __m512i idx_a = npyv_set_u8( + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62, + 64, 66, 68, 70, 72, 74, 76, 78, 80, 82, 84, 86, 88, 90, 92, 94, + 96, 98, 100, 102, 104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126 + ); + const __m512i idx_b = npyv_set_u8( + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, + 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63, + 65, 67, 69, 71, 73, 75, 77, 79, 81, 83, 85, 87, 89, 91, 93, 95, + 97, 99, 101, 103, 105, 107, 109, 111, 113, 115, 117, 119, 121, 123, 125, 127 + ); + r.val[0] = _mm512_permutex2var_epi8(ab0, idx_a, ab1); + r.val[1] = _mm512_permutex2var_epi8(ab0, idx_b, ab1); +#else + #ifdef NPY_HAVE_AVX512BW + const __m512i idx = npyv_set_u8( + 0, 2, 4, 6, 
8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15, + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15, + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15, + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 + ); + __m512i abl = _mm512_shuffle_epi8(ab0, idx); + __m512i abh = _mm512_shuffle_epi8(ab1, idx); + #else + const __m256i idx = _mm256_setr_epi8( + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15, + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 + ); + __m256i abl_lo = _mm256_shuffle_epi8(npyv512_lower_si256(ab0), idx); + __m256i abl_hi = _mm256_shuffle_epi8(npyv512_higher_si256(ab0), idx); + __m256i abh_lo = _mm256_shuffle_epi8(npyv512_lower_si256(ab1), idx); + __m256i abh_hi = _mm256_shuffle_epi8(npyv512_higher_si256(ab1), idx); + __m512i abl = npyv512_combine_si256(abl_lo, abl_hi); + __m512i abh = npyv512_combine_si256(abh_lo, abh_hi); + #endif + const __m512i idx_a = npyv_set_u64(0, 2, 4, 6, 8, 10, 12, 14); + const __m512i idx_b = npyv_set_u64(1, 3, 5, 7, 9, 11, 13, 15); + r.val[0] = _mm512_permutex2var_epi64(abl, idx_a, abh); + r.val[1] = _mm512_permutex2var_epi64(abl, idx_b, abh); +#endif + return r; +} +#define npyv_unzip_s8 npyv_unzip_u8 + +NPY_FINLINE npyv_u16x2 npyv_unzip_u16(npyv_u16 ab0, npyv_u16 ab1) +{ + npyv_u16x2 r; +#ifdef NPY_HAVE_AVX512BW + const __m512i idx_a = npyv_set_u16( + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, + 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62 + ); + const __m512i idx_b = npyv_set_u16( + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, + 33, 35, 37, 39, 41, 43, 45, 47, 49, 51, 53, 55, 57, 59, 61, 63 + ); + r.val[0] = _mm512_permutex2var_epi16(ab0, idx_a, ab1); + r.val[1] = _mm512_permutex2var_epi16(ab0, idx_b, ab1); +#else + const __m256i idx = _mm256_setr_epi8( + 0,1, 4,5, 8,9, 12,13, 2,3, 6,7, 10,11, 14,15, + 0,1, 4,5, 8,9, 12,13, 2,3, 6,7, 10,11, 14,15 + ); + __m256i abl_lo = _mm256_shuffle_epi8(npyv512_lower_si256(ab0), idx); + __m256i abl_hi = _mm256_shuffle_epi8(npyv512_higher_si256(ab0), idx); + __m256i abh_lo = _mm256_shuffle_epi8(npyv512_lower_si256(ab1), idx); + __m256i abh_hi = _mm256_shuffle_epi8(npyv512_higher_si256(ab1), idx); + __m512i abl = npyv512_combine_si256(abl_lo, abl_hi); + __m512i abh = npyv512_combine_si256(abh_lo, abh_hi); + + const __m512i idx_a = npyv_set_u64(0, 2, 4, 6, 8, 10, 12, 14); + const __m512i idx_b = npyv_set_u64(1, 3, 5, 7, 9, 11, 13, 15); + r.val[0] = _mm512_permutex2var_epi64(abl, idx_a, abh); + r.val[1] = _mm512_permutex2var_epi64(abl, idx_b, abh); +#endif + return r; +} +#define npyv_unzip_s16 npyv_unzip_u16 + +NPY_FINLINE npyv_u32x2 npyv_unzip_u32(npyv_u32 ab0, npyv_u32 ab1) +{ + const __m512i idx_a = npyv_set_u32( + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30 + ); + const __m512i idx_b = npyv_set_u32( + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31 + ); + npyv_u32x2 r; + r.val[0] = _mm512_permutex2var_epi32(ab0, idx_a, ab1); + r.val[1] = _mm512_permutex2var_epi32(ab0, idx_b, ab1); + return r; +} +#define npyv_unzip_s32 npyv_unzip_u32 + +NPY_FINLINE npyv_u64x2 npyv_unzip_u64(npyv_u64 ab0, npyv_u64 ab1) +{ + const __m512i idx_a = npyv_set_u64(0, 2, 4, 6, 8, 10, 12, 14); + const __m512i idx_b = npyv_set_u64(1, 3, 5, 7, 9, 11, 13, 15); + npyv_u64x2 r; + r.val[0] = _mm512_permutex2var_epi64(ab0, idx_a, ab1); + r.val[1] = _mm512_permutex2var_epi64(ab0, idx_b, ab1); + return r; +} +#define npyv_unzip_s64 npyv_unzip_u64 + +NPY_FINLINE npyv_f32x2 npyv_unzip_f32(npyv_f32 ab0, npyv_f32 ab1) +{ + const __m512i idx_a = 
npyv_set_u32( + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30 + ); + const __m512i idx_b = npyv_set_u32( + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31 + ); + npyv_f32x2 r; + r.val[0] = _mm512_permutex2var_ps(ab0, idx_a, ab1); + r.val[1] = _mm512_permutex2var_ps(ab0, idx_b, ab1); + return r; +} +NPY_FINLINE npyv_f64x2 npyv_unzip_f64(npyv_f64 ab0, npyv_f64 ab1) +{ + const __m512i idx_a = npyv_set_u64(0, 2, 4, 6, 8, 10, 12, 14); + const __m512i idx_b = npyv_set_u64(1, 3, 5, 7, 9, 11, 13, 15); + npyv_f64x2 r; + r.val[0] = _mm512_permutex2var_pd(ab0, idx_a, ab1); + r.val[1] = _mm512_permutex2var_pd(ab0, idx_b, ab1); + return r; +} + +// Reverse elements of each 64-bit lane +NPY_FINLINE npyv_u8 npyv_rev64_u8(npyv_u8 a) +{ +#ifdef NPY_HAVE_AVX512BW + const __m512i idx = npyv_set_u8( + 7, 6, 5, 4, 3, 2, 1, 0,/*64*/15, 14, 13, 12, 11, 10, 9, 8, + 7, 6, 5, 4, 3, 2, 1, 0,/*64*/15, 14, 13, 12, 11, 10, 9, 8, + 7, 6, 5, 4, 3, 2, 1, 0,/*64*/15, 14, 13, 12, 11, 10, 9, 8, + 7, 6, 5, 4, 3, 2, 1, 0,/*64*/15, 14, 13, 12, 11, 10, 9, 8 + ); + return _mm512_shuffle_epi8(a, idx); +#else + const __m256i idx = _mm256_setr_epi8( + 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, + 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 + ); + __m256i lo = _mm256_shuffle_epi8(npyv512_lower_si256(a), idx); + __m256i hi = _mm256_shuffle_epi8(npyv512_higher_si256(a), idx); + return npyv512_combine_si256(lo, hi); +#endif +} +#define npyv_rev64_s8 npyv_rev64_u8 + +NPY_FINLINE npyv_u16 npyv_rev64_u16(npyv_u16 a) +{ +#ifdef NPY_HAVE_AVX512BW + const __m512i idx = npyv_set_u8( + 6, 7, 4, 5, 2, 3, 0, 1,/*64*/14, 15, 12, 13, 10, 11, 8, 9, + 6, 7, 4, 5, 2, 3, 0, 1,/*64*/14, 15, 12, 13, 10, 11, 8, 9, + 6, 7, 4, 5, 2, 3, 0, 1,/*64*/14, 15, 12, 13, 10, 11, 8, 9, + 6, 7, 4, 5, 2, 3, 0, 1,/*64*/14, 15, 12, 13, 10, 11, 8, 9 + ); + return _mm512_shuffle_epi8(a, idx); +#else + const __m256i idx = _mm256_setr_epi8( + 6, 7, 4, 5, 2, 3, 0, 1,/*64*/14, 15, 12, 13, 10, 11, 8, 9, + 6, 7, 4, 5, 2, 3, 0, 1,/*64*/14, 15, 12, 13, 10, 11, 8, 9 + ); + __m256i lo = _mm256_shuffle_epi8(npyv512_lower_si256(a), idx); + __m256i hi = _mm256_shuffle_epi8(npyv512_higher_si256(a), idx); + return npyv512_combine_si256(lo, hi); +#endif +} +#define npyv_rev64_s16 npyv_rev64_u16 + +NPY_FINLINE npyv_u32 npyv_rev64_u32(npyv_u32 a) +{ + return _mm512_shuffle_epi32(a, (_MM_PERM_ENUM)_MM_SHUFFLE(2, 3, 0, 1)); +} +#define npyv_rev64_s32 npyv_rev64_u32 + +NPY_FINLINE npyv_f32 npyv_rev64_f32(npyv_f32 a) +{ + return _mm512_shuffle_ps(a, a, (_MM_PERM_ENUM)_MM_SHUFFLE(2, 3, 0, 1)); +} + +// Permuting the elements of each 128-bit lane by immediate index for +// each element. 
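+// e.g. npyv_permi128_f32(a, 3, 2, 1, 0) reverses the four f32 elements inside
+// each 128-bit lane (illustrative example).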
+#define npyv_permi128_u32(A, E0, E1, E2, E3) \ + _mm512_shuffle_epi32(A, _MM_SHUFFLE(E3, E2, E1, E0)) + +#define npyv_permi128_s32 npyv_permi128_u32 + +#define npyv_permi128_u64(A, E0, E1) \ + _mm512_shuffle_epi32(A, _MM_SHUFFLE(((E1)<<1)+1, ((E1)<<1), ((E0)<<1)+1, ((E0)<<1))) + +#define npyv_permi128_s64 npyv_permi128_u64 + +#define npyv_permi128_f32(A, E0, E1, E2, E3) \ + _mm512_permute_ps(A, _MM_SHUFFLE(E3, E2, E1, E0)) + +#define npyv_permi128_f64(A, E0, E1) \ + _mm512_permute_pd(A, (((E1)<<7) | ((E0)<<6) | ((E1)<<5) | ((E0)<<4) | ((E1)<<3) | ((E0)<<2) | ((E1)<<1) | (E0))) + +#endif // _NPY_SIMD_AVX512_REORDER_H diff --git a/mkl_umath/src/npyv/avx512/utils.h b/mkl_umath/src/npyv/avx512/utils.h new file mode 100644 index 00000000..ced3bfef --- /dev/null +++ b/mkl_umath/src/npyv/avx512/utils.h @@ -0,0 +1,102 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_AVX512_UTILS_H +#define _NPY_SIMD_AVX512_UTILS_H + +#define npyv512_lower_si256 _mm512_castsi512_si256 +#define npyv512_lower_ps256 _mm512_castps512_ps256 +#define npyv512_lower_pd256 _mm512_castpd512_pd256 + +#define npyv512_higher_si256(A) _mm512_extracti64x4_epi64(A, 1) +#define npyv512_higher_pd256(A) _mm512_extractf64x4_pd(A, 1) + +#ifdef NPY_HAVE_AVX512DQ + #define npyv512_higher_ps256(A) _mm512_extractf32x8_ps(A, 1) +#else + #define npyv512_higher_ps256(A) \ + _mm256_castsi256_ps(_mm512_extracti64x4_epi64(_mm512_castps_si512(A), 1)) +#endif + +#define npyv512_combine_si256(A, B) _mm512_inserti64x4(_mm512_castsi256_si512(A), B, 1) +#define npyv512_combine_pd256(A, B) _mm512_insertf64x4(_mm512_castpd256_pd512(A), B, 1) + +#ifdef NPY_HAVE_AVX512DQ + #define npyv512_combine_ps256(A, B) _mm512_insertf32x8(_mm512_castps256_ps512(A), B, 1) +#else + #define npyv512_combine_ps256(A, B) \ + _mm512_castsi512_ps(npyv512_combine_si256(_mm256_castps_si256(A), _mm256_castps_si256(B))) +#endif + +#define NPYV_IMPL_AVX512_FROM_AVX2_1ARG(FN_NAME, INTRIN) \ + NPY_FINLINE __m512i FN_NAME(__m512i a) \ + { \ + __m256i l_a = npyv512_lower_si256(a); \ + __m256i h_a = npyv512_higher_si256(a); \ + l_a = INTRIN(l_a); \ + h_a = INTRIN(h_a); \ + return npyv512_combine_si256(l_a, h_a); \ + } + +#define NPYV_IMPL_AVX512_FROM_AVX2_PS_1ARG(FN_NAME, INTRIN) \ + NPY_FINLINE __m512 FN_NAME(__m512 a) \ + { \ + __m256 l_a = npyv512_lower_ps256(a); \ + __m256 h_a = npyv512_higher_ps256(a); \ + l_a = INTRIN(l_a); \ + h_a = INTRIN(h_a); \ + return npyv512_combine_ps256(l_a, h_a); \ + } + +#define NPYV_IMPL_AVX512_FROM_AVX2_PD_1ARG(FN_NAME, INTRIN) \ + NPY_FINLINE __m512d FN_NAME(__m512d a) \ + { \ + __m256d l_a = npyv512_lower_pd256(a); \ + __m256d h_a = npyv512_higher_pd256(a); \ + l_a = INTRIN(l_a); \ + h_a = INTRIN(h_a); \ + return npyv512_combine_pd256(l_a, h_a); \ + } + +#define NPYV_IMPL_AVX512_FROM_AVX2_2ARG(FN_NAME, INTRIN) \ + NPY_FINLINE __m512i FN_NAME(__m512i a, __m512i b) \ + { \ + __m256i l_a = npyv512_lower_si256(a); \ + __m256i h_a = npyv512_higher_si256(a); \ + __m256i l_b = npyv512_lower_si256(b); \ + __m256i h_b = npyv512_higher_si256(b); \ + l_a = INTRIN(l_a, l_b); \ + h_a = INTRIN(h_a, h_b); \ + return npyv512_combine_si256(l_a, h_a); \ + } + +#define NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(FN_NAME, INTRIN) \ + NPY_FINLINE __m512 FN_NAME(__m512 a, __m512 b) \ + { \ + return _mm512_castsi512_ps(INTRIN( \ + _mm512_castps_si512(a), _mm512_castps_si512(b) \ + )); \ + } + +#define NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(FN_NAME, INTRIN) \ + NPY_FINLINE __m512d FN_NAME(__m512d a, __m512d b) \ + { \ + return 
_mm512_castsi512_pd(INTRIN( \ + _mm512_castpd_si512(a), _mm512_castpd_si512(b) \ + )); \ + } + +#ifndef NPY_HAVE_AVX512BW + NPYV_IMPL_AVX512_FROM_AVX2_2ARG(npyv512_packs_epi16, _mm256_packs_epi16) +#else + #define npyv512_packs_epi16 _mm512_packs_epi16 +#endif + +NPY_FINLINE __m256i npyv512_pack_lo_hi(__m512i a) { + __m256i lo = npyv512_lower_si256(a); + __m256i hi = npyv512_higher_si256(a); + return _mm256_packs_epi32(lo, hi); +} + +#endif // _NPY_SIMD_AVX512_UTILS_H diff --git a/mkl_umath/src/npyv/binop_override.h b/mkl_umath/src/npyv/binop_override.h new file mode 100644 index 00000000..ec3d0467 --- /dev/null +++ b/mkl_umath/src/npyv/binop_override.h @@ -0,0 +1,215 @@ +#ifndef NUMPY_CORE_SRC_COMMON_BINOP_OVERRIDE_H_ +#define NUMPY_CORE_SRC_COMMON_BINOP_OVERRIDE_H_ + +#include +#include +#include "numpy/arrayobject.h" + +#include "get_attr_string.h" + +/* + * Logic for deciding when binops should return NotImplemented versus when + * they should go ahead and call a ufunc (or similar). + * + * The interaction between binop methods (ndarray.__add__ and friends) and + * ufuncs (which dispatch to __array_ufunc__) is both complicated in its own + * right, and also has complicated historical constraints. + * + * In the very old days, the rules were: + * - If the other argument has a higher __array_priority__, then return + * NotImplemented + * - Otherwise, call the corresponding ufunc. + * - And the ufunc might return NotImplemented based on some complex + * criteria that I won't reproduce here. + * + * Ufuncs no longer return NotImplemented (except in a few marginal situations + * which are being phased out -- see https://github.com/numpy/numpy/pull/5864) + * + * So as of 1.9, the effective rules were: + * - If the other argument has a higher __array_priority__, and is *not* a + * subclass of ndarray, then return NotImplemented. (If it is a subclass, + * the regular Python rules have already given it a chance to run; so if we + * are running, then it means the other argument has already returned + * NotImplemented and is basically asking us to take care of things.) + * - Otherwise call the corresponding ufunc. + * + * We would like to get rid of __array_priority__, and __array_ufunc__ + * provides a large part of a replacement for it. Once __array_ufunc__ is + * widely available, the simplest dispatch rules that might possibly work + * would be: + * - Always call the corresponding ufunc. + * + * But: + * - Doing this immediately would break backwards compatibility -- there's a + * lot of code using __array_priority__ out there. + * - It's not at all clear whether __array_ufunc__ actually is sufficient for + * all use cases. (See https://github.com/numpy/numpy/issues/5844 for lots + * of discussion of this, and in particular + * https://github.com/numpy/numpy/issues/5844#issuecomment-112014014 + * for a summary of some conclusions.) Also, python 3.6 defines a standard + * where setting a special-method name to None is a signal that that method + * cannot be used. + * + * So for 1.13, we are going to try the following rules. + * + * For binops like a.__add__(b): + * - If b does not define __array_ufunc__, apply the legacy rule: + * - If not isinstance(b, a.__class__), and b.__array_priority__ is higher + * than a.__array_priority__, return NotImplemented + * - If b does define __array_ufunc__ but it is None, return NotImplemented + * - Otherwise, call the corresponding ufunc. 
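+ *   (Illustrative example: if type(b) sets __array_ufunc__ = None, then
+ *   a.__add__(b) returns NotImplemented, and unless b also implements
+ *   __radd__, the expression a + b ends in a TypeError.)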
+ * + * For in-place operations like a.__iadd__(b) + * - If b does not define __array_ufunc__, apply the legacy rule: + * - If not isinstance(b, a.__class__), and b.__array_priority__ is higher + * than a.__array_priority__, return NotImplemented + * - Otherwise, call the corresponding ufunc. + * + * For reversed operations like b.__radd__(a) we call the corresponding ufunc. + * + * Rationale for __radd__: This is because by the time the reversed operation + * is called, there are only two possibilities: The first possibility is that + * the current class is a strict subclass of the other class. In practice, the + * only way this will happen is if b is a strict subclass of a, and a is + * ndarray or a subclass of ndarray, and neither a nor b has actually + * overridden this method. In this case, Python will never call a.__add__ + * (because it's identical to b.__radd__), so we have no-one to defer to; + * there's no reason to return NotImplemented. The second possibility is that + * b.__add__ has already been called and returned NotImplemented. Again, in + * this case there is no point in returning NotImplemented. + * + * Rationale for __iadd__: In-place operations do not take all the trouble + * above, because if __iadd__ returns NotImplemented then Python will silently + * convert the operation into an out-of-place operation, i.e. 'a += b' will + * silently become 'a = a + b'. We don't want to allow this for arrays, + * because it will create unexpected memory allocations, break views, etc. + * However, backwards compatibility requires that we follow the rules of + * __array_priority__ for arrays that define it. For classes that use the new + * __array_ufunc__ mechanism we simply defer to the ufunc. That has the effect + * that when the other array has__array_ufunc = None a TypeError will be raised. + * + * In the future we might change these rules further. For example, we plan to + * eventually deprecate __array_priority__ in cases where __array_ufunc__ is + * not present. + */ + +static int +binop_should_defer(PyObject *self, PyObject *other, int inplace) +{ + /* + * This function assumes that self.__binop__(other) is underway and + * implements the rules described above. Python's C API is funny, and + * makes it tricky to tell whether a given slot is called for __binop__ + * ("forward") or __rbinop__ ("reversed"). You are responsible for + * determining this before calling this function; it only provides the + * logic for forward binop implementations. + */ + + /* + * NB: there's another copy of this code in + * numpy.ma.core.MaskedArray._delegate_binop + * which should possibly be updated when this is. + */ + + PyObject *attr; + double self_prio, other_prio; + int defer; + /* + * attribute check is expensive for scalar operations, avoid if possible + */ + if (other == NULL || + self == NULL || + Py_TYPE(self) == Py_TYPE(other) || + PyArray_CheckExact(other) || + PyArray_CheckAnyScalarExact(other)) { + return 0; + } + /* + * Classes with __array_ufunc__ are living in the future, and only need to + * check whether __array_ufunc__ equals None. + */ + attr = PyArray_LookupSpecial(other, npy_um_str_array_ufunc); + if (attr != NULL) { + defer = !inplace && (attr == Py_None); + Py_DECREF(attr); + return defer; + } + else if (PyErr_Occurred()) { + PyErr_Clear(); /* TODO[gh-14801]: propagate crashes during attribute access? */ + } + /* + * Otherwise, we need to check for the legacy __array_priority__. 
But if + * other.__class__ is a subtype of self.__class__, then it's already had + * a chance to run, so no need to defer to it. + */ + if(PyType_IsSubtype(Py_TYPE(other), Py_TYPE(self))) { + return 0; + } + self_prio = PyArray_GetPriority((PyObject *)self, NPY_SCALAR_PRIORITY); + other_prio = PyArray_GetPriority((PyObject *)other, NPY_SCALAR_PRIORITY); + return self_prio < other_prio; +} + +/* + * A CPython slot like ->tp_as_number->nb_add gets called for *both* forward + * and reversed operations. E.g. + * a + b + * may call + * a->tp_as_number->nb_add(a, b) + * and + * b + a + * may call + * a->tp_as_number->nb_add(b, a) + * and the only way to tell which is which is for a slot implementation 'f' to + * check + * arg1->tp_as_number->nb_add == f + * arg2->tp_as_number->nb_add == f + * If both are true, then CPython will as a special case only call the + * operation once (i.e., it performs both the forward and reversed binops + * simultaneously). This function is mostly intended for figuring out + * whether we are a forward binop that might want to return NotImplemented, + * and in the both-at-once case we never want to return NotImplemented, so in + * that case BINOP_IS_FORWARD returns false. + * + * This is modeled on the checks in CPython's typeobject.c SLOT1BINFULL + * macro. + */ +#define BINOP_IS_FORWARD(m1, m2, SLOT_NAME, test_func) \ + (Py_TYPE(m2)->tp_as_number != NULL && \ + (void*)(Py_TYPE(m2)->tp_as_number->SLOT_NAME) != (void*)(test_func)) + +#define BINOP_GIVE_UP_IF_NEEDED(m1, m2, slot_expr, test_func) \ + do { \ + if (BINOP_IS_FORWARD(m1, m2, slot_expr, test_func) && \ + binop_should_defer((PyObject*)m1, (PyObject*)m2, 0)) { \ + Py_INCREF(Py_NotImplemented); \ + return Py_NotImplemented; \ + } \ + } while (0) + +#define INPLACE_GIVE_UP_IF_NEEDED(m1, m2, slot_expr, test_func) \ + do { \ + if (BINOP_IS_FORWARD(m1, m2, slot_expr, test_func) && \ + binop_should_defer((PyObject*)m1, (PyObject*)m2, 1)) { \ + Py_INCREF(Py_NotImplemented); \ + return Py_NotImplemented; \ + } \ + } while (0) + +/* + * For rich comparison operations, it's impossible to distinguish + * between a forward comparison and a reversed/reflected + * comparison. So we assume they are all forward. This only works because the + * logic in binop_override_forward_binop_should_defer is essentially + * asymmetric -- you can never have two duck-array types that each decide to + * defer to the other. + */ +#define RICHCMP_GIVE_UP_IF_NEEDED(m1, m2) \ + do { \ + if (binop_should_defer((PyObject*)m1, (PyObject*)m2, 0)) { \ + Py_INCREF(Py_NotImplemented); \ + return Py_NotImplemented; \ + } \ + } while (0) + +#endif /* NUMPY_CORE_SRC_COMMON_BINOP_OVERRIDE_H_ */ diff --git a/mkl_umath/src/npyv/cblasfuncs.h b/mkl_umath/src/npyv/cblasfuncs.h new file mode 100644 index 00000000..71c533f3 --- /dev/null +++ b/mkl_umath/src/npyv/cblasfuncs.h @@ -0,0 +1,7 @@ +#ifndef NUMPY_CORE_SRC_COMMON_CBLASFUNCS_H_ +#define NUMPY_CORE_SRC_COMMON_CBLASFUNCS_H_ + +NPY_NO_EXPORT PyObject * +cblas_matrixproduct(int, PyArrayObject *, PyArrayObject *, PyArrayObject *); + +#endif /* NUMPY_CORE_SRC_COMMON_CBLASFUNCS_H_ */ diff --git a/mkl_umath/src/npyv/emulate_maskop.h b/mkl_umath/src/npyv/emulate_maskop.h new file mode 100644 index 00000000..0a3a164f --- /dev/null +++ b/mkl_umath/src/npyv/emulate_maskop.h @@ -0,0 +1,80 @@ +/** + * This header is used internally by all current supported SIMD extensions, + * except for AVX512. 
+ */ +#ifndef NPY_SIMD + #error "Not a standalone header, use simd/simd.h instead" +#endif + +#ifndef _NPY_SIMD_EMULATE_MASKOP_H +#define _NPY_SIMD_EMULATE_MASKOP_H + +/** + * Implements conditional addition and subtraction. + * e.g. npyv_ifadd_f32(mask, a, b, c) -> mask ? a + b : c + * e.g. npyv_ifsub_f32(mask, a, b, c) -> mask ? a - b : c + */ +#define NPYV_IMPL_EMULATE_MASK_ADDSUB(SFX, BSFX) \ + NPY_FINLINE npyv_##SFX npyv_ifadd_##SFX \ + (npyv_##BSFX m, npyv_##SFX a, npyv_##SFX b, npyv_##SFX c) \ + { \ + npyv_##SFX add = npyv_add_##SFX(a, b); \ + return npyv_select_##SFX(m, add, c); \ + } \ + NPY_FINLINE npyv_##SFX npyv_ifsub_##SFX \ + (npyv_##BSFX m, npyv_##SFX a, npyv_##SFX b, npyv_##SFX c) \ + { \ + npyv_##SFX sub = npyv_sub_##SFX(a, b); \ + return npyv_select_##SFX(m, sub, c); \ + } + +NPYV_IMPL_EMULATE_MASK_ADDSUB(u8, b8) +NPYV_IMPL_EMULATE_MASK_ADDSUB(s8, b8) +NPYV_IMPL_EMULATE_MASK_ADDSUB(u16, b16) +NPYV_IMPL_EMULATE_MASK_ADDSUB(s16, b16) +NPYV_IMPL_EMULATE_MASK_ADDSUB(u32, b32) +NPYV_IMPL_EMULATE_MASK_ADDSUB(s32, b32) +NPYV_IMPL_EMULATE_MASK_ADDSUB(u64, b64) +NPYV_IMPL_EMULATE_MASK_ADDSUB(s64, b64) +#if NPY_SIMD_F32 + NPYV_IMPL_EMULATE_MASK_ADDSUB(f32, b32) +#endif +#if NPY_SIMD_F64 + NPYV_IMPL_EMULATE_MASK_ADDSUB(f64, b64) +#endif +#if NPY_SIMD_F32 + // conditional division, m ? a / b : c + NPY_FINLINE npyv_f32 + npyv_ifdiv_f32(npyv_b32 m, npyv_f32 a, npyv_f32 b, npyv_f32 c) + { + const npyv_f32 one = npyv_setall_f32(1.0f); + npyv_f32 div = npyv_div_f32(a, npyv_select_f32(m, b, one)); + return npyv_select_f32(m, div, c); + } + // conditional division, m ? a / b : 0 + NPY_FINLINE npyv_f32 + npyv_ifdivz_f32(npyv_b32 m, npyv_f32 a, npyv_f32 b) + { + const npyv_f32 zero = npyv_zero_f32(); + return npyv_ifdiv_f32(m, a, b, zero); + } +#endif +#if NPY_SIMD_F64 + // conditional division, m ? a / b : c + NPY_FINLINE npyv_f64 + npyv_ifdiv_f64(npyv_b64 m, npyv_f64 a, npyv_f64 b, npyv_f64 c) + { + const npyv_f64 one = npyv_setall_f64(1.0); + npyv_f64 div = npyv_div_f64(a, npyv_select_f64(m, b, one)); + return npyv_select_f64(m, div, c); + } + // conditional division, m ? a / b : 0 + NPY_FINLINE npyv_f64 + npyv_ifdivz_f64(npyv_b64 m, npyv_f64 a, npyv_f64 b) + { + const npyv_f64 zero = npyv_zero_f64(); + return npyv_ifdiv_f64(m, a, b, zero); + } +#endif + +#endif // _NPY_SIMD_EMULATE_MASKOP_H diff --git a/mkl_umath/src/npyv/get_attr_string.h b/mkl_umath/src/npyv/get_attr_string.h new file mode 100644 index 00000000..36d39189 --- /dev/null +++ b/mkl_umath/src/npyv/get_attr_string.h @@ -0,0 +1,95 @@ +#ifndef NUMPY_CORE_SRC_COMMON_GET_ATTR_STRING_H_ +#define NUMPY_CORE_SRC_COMMON_GET_ATTR_STRING_H_ + +#include +#include "ufunc_object.h" + +static inline npy_bool +_is_basic_python_type(PyTypeObject *tp) +{ + return ( + /* Basic number types */ + tp == &PyBool_Type || + tp == &PyLong_Type || + tp == &PyFloat_Type || + tp == &PyComplex_Type || + + /* Basic sequence types */ + tp == &PyList_Type || + tp == &PyTuple_Type || + tp == &PyDict_Type || + tp == &PySet_Type || + tp == &PyFrozenSet_Type || + tp == &PyUnicode_Type || + tp == &PyBytes_Type || + + /* other builtins */ + tp == &PySlice_Type || + tp == Py_TYPE(Py_None) || + tp == Py_TYPE(Py_Ellipsis) || + tp == Py_TYPE(Py_NotImplemented) || + + /* TODO: ndarray, but we can't see PyArray_Type here */ + + /* sentinel to swallow trailing || */ + NPY_FALSE + ); +} + + +/* + * Lookup a special method, following the python approach of looking up + * on the type object, rather than on the instance itself. 
+ * + * Assumes that the special method is a numpy-specific one, so does not look + * at builtin types. It does check base ndarray and numpy scalar types. + * + * In future, could be made more like _Py_LookupSpecial + */ +static inline PyObject * +PyArray_LookupSpecial(PyObject *obj, PyObject *name_unicode) +{ + PyTypeObject *tp = Py_TYPE(obj); + + /* We do not need to check for special attributes on trivial types */ + if (_is_basic_python_type(tp)) { + return NULL; + } + PyObject *res = PyObject_GetAttr((PyObject *)tp, name_unicode); + + if (res == NULL && PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + } + + return res; +} + + +/* + * PyArray_LookupSpecial_OnInstance: + * + * Implements incorrect special method lookup rules, that break the python + * convention, and looks on the instance, not the type. + * + * Kept for backwards compatibility. In future, we should deprecate this. + */ +static inline PyObject * +PyArray_LookupSpecial_OnInstance(PyObject *obj, PyObject *name_unicode) +{ + PyTypeObject *tp = Py_TYPE(obj); + + /* We do not need to check for special attributes on trivial types */ + if (_is_basic_python_type(tp)) { + return NULL; + } + + PyObject *res = PyObject_GetAttr(obj, name_unicode); + + if (res == NULL && PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + } + + return res; +} + +#endif /* NUMPY_CORE_SRC_COMMON_GET_ATTR_STRING_H_ */ diff --git a/mkl_umath/src/npyv/gil_utils.h b/mkl_umath/src/npyv/gil_utils.h new file mode 100644 index 00000000..fd77fa60 --- /dev/null +++ b/mkl_umath/src/npyv/gil_utils.h @@ -0,0 +1,15 @@ +#ifndef NUMPY_CORE_SRC_COMMON_GIL_UTILS_H_ +#define NUMPY_CORE_SRC_COMMON_GIL_UTILS_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +NPY_NO_EXPORT void +npy_gil_error(PyObject *type, const char *format, ...); + +#ifdef __cplusplus +} +#endif + +#endif /* NUMPY_CORE_SRC_COMMON_GIL_UTILS_H_ */ diff --git a/mkl_umath/src/npyv/intdiv.h b/mkl_umath/src/npyv/intdiv.h new file mode 100644 index 00000000..d843eaf4 --- /dev/null +++ b/mkl_umath/src/npyv/intdiv.h @@ -0,0 +1,475 @@ +/** + * This header implements `npyv_divisor_*` intrinsics used for computing the parameters + * of fast integer division, while division intrinsics `npyv_divc_*` are defined in + * {extension}/arithmetic.h. + */ +#ifndef NPY_SIMD + #error "Not a standalone header, use simd/simd.h instead" +#endif +#ifndef _NPY_SIMD_INTDIV_H +#define _NPY_SIMD_INTDIV_H +/********************************************************************************** + ** Integer division + ********************************************************************************** + * Almost all architecture (except Power10) doesn't support integer vector division, + * also the cost of scalar division in architectures like x86 is too high it can take + * 30 to 40 cycles on modern chips and up to 100 on old ones. + * + * Therefore we are using division by multiplying with precomputed reciprocal technique, + * the method that been used in this implementation is based on T. Granlund and P. L. Montgomery + * “Division by invariant integers using multiplication(see [Figure 4.1, 5.1] + * https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.1.2556) + * + * It shows a good impact for all architectures especially on X86, + * however computing divisor parameters is kind of expensive so this implementation + * should only works when divisor is a scalar and used multiple of times. 
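As a scalar illustration of the precomputed-reciprocal technique, the sketch below derives and applies the multiplier and shift counts for an unsigned 32-bit divisor, following the Figure 4.1 recipe quoted later in this comment. It is illustrative only (not part of the patch), handles only d > 1, and assumes the GCC/Clang __builtin_clz builtin; the vector intrinsics in this header apply the same arithmetic per lane.

```c
#include <assert.h>
#include <stdint.h>

typedef struct { uint32_t m; unsigned sh1, sh2; } udiv32_params;

/* Precompute multiplier and shift counts for a run-time invariant divisor d. */
static udiv32_params udiv32_precompute(uint32_t d)
{
    assert(d > 1);                              /* d == 0 and d == 1 need special cases */
    unsigned l  = 32 - __builtin_clz(d - 1);    /* ceil(log2(d)) */
    uint64_t l2 = (uint64_t)1 << l;             /* 2^l */
    udiv32_params p;
    p.m   = (uint32_t)((((l2 - d) << 32) / d) + 1);  /* m = 2^32 * (2^l - d) / d + 1 */
    p.sh1 = 1;                                  /* min(l, 1), and l >= 1 here */
    p.sh2 = l - 1;                              /* max(l - 1, 0) */
    return p;
}

/* floor(a / d) without a divide: q = (t1 + ((a - t1) >> sh1)) >> sh2 */
static uint32_t udiv32_apply(uint32_t a, udiv32_params p)
{
    uint32_t t1 = (uint32_t)(((uint64_t)p.m * a) >> 32);  /* MULUH(m, a) */
    return (t1 + ((a - t1) >> p.sh1)) >> p.sh2;
}
```

For example, udiv32_precompute(7) yields m = 613566757, sh1 = 1, sh2 = 2, and udiv32_apply(100, p) then returns 14.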
+ * + * The division process is separated into two intrinsics for each data type + * + * 1- npyv_{dtype}x3 npyv_divisor_{dtype} ({dtype} divisor); + * For computing the divisor parameters (multiplier + shifters + sign of divisor(signed only)) + * + * 2- npyv_{dtype} npyv_divisor_{dtype} (npyv_{dtype} dividend, npyv_{dtype}x3 divisor_parms); + * For performing the final division. + * + ** For example: + * int vstep = npyv_nlanes_s32; // number of lanes + * int x = 0x6e70; + * npyv_s32x3 divisor = npyv_divisor_s32(x); // init divisor params + * for (; len >= vstep; src += vstep, dst += vstep, len -= vstep) { + * npyv_s32 a = npyv_load_s32(*src); // load s32 vector from memory + * a = npyv_divc_s32(a, divisor); // divide all elements by x + * npyv_store_s32(dst, a); // store s32 vector into memory + * } + * + ** NOTES: + * - For 64-bit division on Aarch64 and IBM/Power, we fall-back to the scalar division + * since emulating multiply-high is expensive and both architectures have very fast dividers. + * + *************************************************************** + ** Figure 4.1: Unsigned division by run–time invariant divisor + *************************************************************** + * Initialization (given uword d with 1 ≤ d < 2^N): + * int l = ceil(log2(d)); + * uword m = 2^N * (2^l− d) / d + 1; + * int sh1 = min(l, 1); + * int sh2 = max(l − 1, 0); + * + * For q = FLOOR(a/d), all uword: + * uword t1 = MULUH(m, a); + * q = SRL(t1 + SRL(a − t1, sh1), sh2); + * + ************************************************************************************ + ** Figure 5.1: Signed division by run–time invariant divisor, rounded towards zero + ************************************************************************************ + * Initialization (given constant sword d with d !=0): + * int l = max(ceil(log2(abs(d))), 1); + * udword m0 = 1 + (2^(N+l-1)) / abs(d); + * sword m = m0 − 2^N; + * sword dsign = XSIGN(d); + * int sh = l − 1; + * + * For q = TRUNC(a/d), all sword: + * sword q0 = a + MULSH(m, a); + * q0 = SRA(q0, sh) − XSIGN(a); + * q = EOR(q0, dsign) − dsign; + */ +/** + * bit-scan reverse for non-zeros. returns the index of the highest set bit. 
+ * equivalent to floor(log2(a)) + */ +#ifdef _MSC_VER + #include // _BitScanReverse +#endif +NPY_FINLINE unsigned npyv__bitscan_revnz_u32(npy_uint32 a) +{ + assert(a > 0); // due to use __builtin_clz + unsigned r; +#if defined(NPY_HAVE_SSE2) && defined(_MSC_VER) + unsigned long rl; + (void)_BitScanReverse(&rl, (unsigned long)a); + r = (unsigned)rl; + +#elif defined(NPY_HAVE_SSE2) && (defined(__GNUC__) || defined(__clang__) || defined(__INTEL_COMPILER)) \ + && (defined(NPY_CPU_X86) || defined(NPY_CPU_AMD64)) + __asm__("bsr %1, %0" : "=r" (r) : "r"(a)); +#elif defined(__GNUC__) || defined(__clang__) + r = 31 - __builtin_clz(a); // performs on arm -> clz, ppc -> cntlzw +#else + r = 0; + while (a >>= 1) { + r++; + } +#endif + return r; +} +NPY_FINLINE unsigned npyv__bitscan_revnz_u64(npy_uint64 a) +{ + assert(a > 0); // due to use __builtin_clzll +#if defined(_M_AMD64) && defined(_MSC_VER) + unsigned long rl; + (void)_BitScanReverse64(&rl, a); + return (unsigned)rl; +#elif defined(__x86_64__) && (defined(__GNUC__) || defined(__clang__) || defined(__INTEL_COMPILER)) + npy_uint64 r; + __asm__("bsrq %1, %0" : "=r"(r) : "r"(a)); + return (unsigned)r; +#elif defined(__GNUC__) || defined(__clang__) + return 63 - __builtin_clzll(a); +#else + npy_uint64 a_hi = a >> 32; + if (a_hi == 0) { + return npyv__bitscan_revnz_u32((npy_uint32)a); + } + return 32 + npyv__bitscan_revnz_u32((npy_uint32)a_hi); +#endif +} +/** + * Divides 128-bit unsigned integer by a 64-bit when the lower + * 64-bit of the dividend is zero. + * + * This function is needed to calculate the multiplier of 64-bit integer division + * see npyv_divisor_u64/npyv_divisor_s64. + */ +NPY_FINLINE npy_uint64 npyv__divh128_u64(npy_uint64 high, npy_uint64 divisor) +{ + assert(divisor > 1); + npy_uint64 quotient; +#if defined(_M_X64) && defined(_MSC_VER) && _MSC_VER >= 1920 && !defined(__clang__) + npy_uint64 remainder; + quotient = _udiv128(high, 0, divisor, &remainder); + (void)remainder; +#elif defined(__x86_64__) && (defined(__GNUC__) || defined(__clang__) || defined(__INTEL_COMPILER)) + __asm__("divq %[d]" : "=a"(quotient) : [d] "r"(divisor), "a"(0), "d"(high)); +#elif defined(__SIZEOF_INT128__) + quotient = (npy_uint64)((((__uint128_t)high) << 64) / divisor); +#else + /** + * Minified version based on Donald Knuth’s Algorithm D (Division of nonnegative integers), + * and Generic implementation in Hacker’s Delight. 
+ * + * See https://skanthak.homepage.t-online.de/division.html + * with respect to the license of the Hacker's Delight book + * (https://web.archive.org/web/20190408122508/http://www.hackersdelight.org/permissions.htm) + */ + // shift amount for normalize + unsigned ldz = 63 - npyv__bitscan_revnz_u64(divisor); + // normalize divisor + divisor <<= ldz; + high <<= ldz; + // break divisor up into two 32-bit digits + npy_uint32 divisor_hi = divisor >> 32; + npy_uint32 divisor_lo = divisor & 0xFFFFFFFF; + // compute high quotient digit + npy_uint64 quotient_hi = high / divisor_hi; + npy_uint64 remainder = high - divisor_hi * quotient_hi; + npy_uint64 base32 = 1ULL << 32; + while (quotient_hi >= base32 || quotient_hi*divisor_lo > base32*remainder) { + --quotient_hi; + remainder += divisor_hi; + if (remainder >= base32) { + break; + } + } + // compute dividend digit pairs + npy_uint64 dividend_pairs = base32*high - divisor*quotient_hi; + // compute second quotient digit for lower zeros + npy_uint32 quotient_lo = (npy_uint32)(dividend_pairs / divisor_hi); + quotient = base32*quotient_hi + quotient_lo; +#endif + return quotient; +} +// Initializing divisor parameters for unsigned 8-bit division +NPY_FINLINE npyv_u8x3 npyv_divisor_u8(npy_uint8 d) +{ + unsigned l, l2, sh1, sh2, m; + switch (d) { + case 0: // LCOV_EXCL_LINE + // for potential divide by zero, On x86 GCC inserts `ud2` instruction + // instead of letting the HW/CPU trap it which leads to illegal instruction exception. + // 'volatile' should suppress this behavior and allow us to raise HW/CPU + // arithmetic exception. + m = sh1 = sh2 = 1 / ((npy_uint8 volatile *)&d)[0]; + break; + case 1: + m = 1; sh1 = sh2 = 0; + break; + case 2: + m = 1; sh1 = 1; sh2 = 0; + break; + default: + l = npyv__bitscan_revnz_u32(d - 1) + 1; // ceil(log2(d)) + l2 = (npy_uint8)(1 << l); // 2^l, overflow to 0 if l = 8 + m = ((npy_uint16)((l2 - d) << 8)) / d + 1; // multiplier + sh1 = 1; sh2 = l - 1; // shift counts + } + npyv_u8x3 divisor; +#ifdef NPY_HAVE_SSE2 // SSE/AVX2/AVX512 + divisor.val[0] = npyv_setall_u16(m); + divisor.val[1] = npyv_set_u8(sh1); + divisor.val[2] = npyv_set_u8(sh2); +#elif defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX) + divisor.val[0] = npyv_setall_u8(m); + divisor.val[1] = npyv_setall_u8(sh1); + divisor.val[2] = npyv_setall_u8(sh2); +#elif defined(NPY_HAVE_NEON) + divisor.val[0] = npyv_setall_u8(m); + divisor.val[1] = npyv_reinterpret_u8_s8(npyv_setall_s8(-sh1)); + divisor.val[2] = npyv_reinterpret_u8_s8(npyv_setall_s8(-sh2)); +#else + #error "please initialize the shifting operand for the new architecture" +#endif + return divisor; +} +// Initializing divisor parameters for signed 8-bit division +NPY_FINLINE npyv_s16x3 npyv_divisor_s16(npy_int16 d); +NPY_FINLINE npyv_s8x3 npyv_divisor_s8(npy_int8 d) +{ +#ifdef NPY_HAVE_SSE2 // SSE/AVX2/AVX512 + npyv_s16x3 p = npyv_divisor_s16(d); + npyv_s8x3 r; + r.val[0] = npyv_reinterpret_s8_s16(p.val[0]); + r.val[1] = npyv_reinterpret_s8_s16(p.val[1]); + r.val[2] = npyv_reinterpret_s8_s16(p.val[2]); + return r; +#else + int d1 = abs(d); + int sh, m; + if (d1 > 1) { + sh = (int)npyv__bitscan_revnz_u32(d1-1); // ceil(log2(abs(d))) - 1 + m = (1 << (8 + sh)) / d1 + 1; // multiplier + } + else if (d1 == 1) { + sh = 0; m = 1; + } + else { + // raise arithmetic exception for d == 0 + sh = m = 1 / ((npy_int8 volatile *)&d)[0]; // LCOV_EXCL_LINE + } + npyv_s8x3 divisor; + divisor.val[0] = npyv_setall_s8(m); + divisor.val[2] = npyv_setall_s8(d < 0 ? 
-1 : 0); + #if defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX) + divisor.val[1] = npyv_setall_s8(sh); + #elif defined(NPY_HAVE_NEON) + divisor.val[1] = npyv_setall_s8(-sh); + #else + #error "please initialize the shifting operand for the new architecture" + #endif + return divisor; +#endif +} +// Initializing divisor parameters for unsigned 16-bit division +NPY_FINLINE npyv_u16x3 npyv_divisor_u16(npy_uint16 d) +{ + unsigned l, l2, sh1, sh2, m; + switch (d) { + case 0: // LCOV_EXCL_LINE + // raise arithmetic exception for d == 0 + m = sh1 = sh2 = 1 / ((npy_uint16 volatile *)&d)[0]; + break; + case 1: + m = 1; sh1 = sh2 = 0; + break; + case 2: + m = 1; sh1 = 1; sh2 = 0; + break; + default: + l = npyv__bitscan_revnz_u32(d - 1) + 1; // ceil(log2(d)) + l2 = (npy_uint16)(1 << l); // 2^l, overflow to 0 if l = 16 + m = ((l2 - d) << 16) / d + 1; // multiplier + sh1 = 1; sh2 = l - 1; // shift counts + } + npyv_u16x3 divisor; + divisor.val[0] = npyv_setall_u16(m); +#ifdef NPY_HAVE_SSE2 // SSE/AVX2/AVX512 + divisor.val[1] = npyv_set_u16(sh1); + divisor.val[2] = npyv_set_u16(sh2); +#elif defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX) + divisor.val[1] = npyv_setall_u16(sh1); + divisor.val[2] = npyv_setall_u16(sh2); +#elif defined(NPY_HAVE_NEON) + divisor.val[1] = npyv_reinterpret_u16_s16(npyv_setall_s16(-sh1)); + divisor.val[2] = npyv_reinterpret_u16_s16(npyv_setall_s16(-sh2)); +#else + #error "please initialize the shifting operand for the new architecture" +#endif + return divisor; +} +// Initializing divisor parameters for signed 16-bit division +NPY_FINLINE npyv_s16x3 npyv_divisor_s16(npy_int16 d) +{ + int d1 = abs(d); + int sh, m; + if (d1 > 1) { + sh = (int)npyv__bitscan_revnz_u32(d1 - 1); // ceil(log2(abs(d))) - 1 + m = (1 << (16 + sh)) / d1 + 1; // multiplier + } + else if (d1 == 1) { + sh = 0; m = 1; + } + else { + // raise arithmetic exception for d == 0 + sh = m = 1 / ((npy_int16 volatile *)&d)[0]; // LCOV_EXCL_LINE + } + npyv_s16x3 divisor; + divisor.val[0] = npyv_setall_s16(m); + divisor.val[2] = npyv_setall_s16(d < 0 ? 
-1 : 0); // sign of divisor +#ifdef NPY_HAVE_SSE2 // SSE/AVX2/AVX512 + divisor.val[1] = npyv_set_s16(sh); +#elif defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX) + divisor.val[1] = npyv_setall_s16(sh); +#elif defined(NPY_HAVE_NEON) + divisor.val[1] = npyv_setall_s16(-sh); +#else + #error "please initialize the shifting operand for the new architecture" +#endif + return divisor; +} +// Initializing divisor parameters for unsigned 32-bit division +NPY_FINLINE npyv_u32x3 npyv_divisor_u32(npy_uint32 d) +{ + npy_uint32 l, l2, sh1, sh2, m; + switch (d) { + case 0: // LCOV_EXCL_LINE + // raise arithmetic exception for d == 0 + m = sh1 = sh2 = 1 / ((npy_uint32 volatile *)&d)[0]; // LCOV_EXCL_LINE + break; + case 1: + m = 1; sh1 = sh2 = 0; + break; + case 2: + m = 1; sh1 = 1; sh2 = 0; + break; + default: + l = npyv__bitscan_revnz_u32(d - 1) + 1; // ceil(log2(d)) + l2 = (npy_uint32)(1ULL << l); // 2^l, overflow to 0 if l = 32 + m = ((npy_uint64)(l2 - d) << 32) / d + 1; // multiplier + sh1 = 1; sh2 = l - 1; // shift counts + } + npyv_u32x3 divisor; + divisor.val[0] = npyv_setall_u32(m); +#ifdef NPY_HAVE_SSE2 // SSE/AVX2/AVX512 + divisor.val[1] = npyv_set_u32(sh1); + divisor.val[2] = npyv_set_u32(sh2); +#elif defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX) + divisor.val[1] = npyv_setall_u32(sh1); + divisor.val[2] = npyv_setall_u32(sh2); +#elif defined(NPY_HAVE_NEON) + divisor.val[1] = npyv_reinterpret_u32_s32(npyv_setall_s32(-sh1)); + divisor.val[2] = npyv_reinterpret_u32_s32(npyv_setall_s32(-sh2)); +#else + #error "please initialize the shifting operand for the new architecture" +#endif + return divisor; +} +// Initializing divisor parameters for signed 32-bit division +NPY_FINLINE npyv_s32x3 npyv_divisor_s32(npy_int32 d) +{ + npy_int32 d1 = abs(d); + npy_int32 sh, m; + // Handle abs overflow + if ((npy_uint32)d == 0x80000000U) { + m = 0x80000001; + sh = 30; + } + else if (d1 > 1) { + sh = npyv__bitscan_revnz_u32(d1 - 1); // ceil(log2(abs(d))) - 1 + m = (1ULL << (32 + sh)) / d1 + 1; // multiplier + } + else if (d1 == 1) { + sh = 0; m = 1; + } + else { + // raise arithmetic exception for d == 0 + sh = m = 1 / ((npy_int32 volatile *)&d)[0]; // LCOV_EXCL_LINE + } + npyv_s32x3 divisor; + divisor.val[0] = npyv_setall_s32(m); + divisor.val[2] = npyv_setall_s32(d < 0 ? -1 : 0); // sign of divisor +#ifdef NPY_HAVE_SSE2 // SSE/AVX2/AVX512 + divisor.val[1] = npyv_set_s32(sh); +#elif defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX) + divisor.val[1] = npyv_setall_s32(sh); +#elif defined(NPY_HAVE_NEON) + divisor.val[1] = npyv_setall_s32(-sh); +#else + #error "please initialize the shifting operand for the new architecture" +#endif + return divisor; +} +// Initializing divisor parameters for unsigned 64-bit division +NPY_FINLINE npyv_u64x3 npyv_divisor_u64(npy_uint64 d) +{ + npyv_u64x3 divisor; +#if defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX) || defined(NPY_HAVE_NEON) + divisor.val[0] = npyv_setall_u64(d); +#else + npy_uint64 l, l2, sh1, sh2, m; + switch (d) { + case 0: // LCOV_EXCL_LINE + // raise arithmetic exception for d == 0 + m = sh1 = sh2 = 1 / ((npy_uint64 volatile *)&d)[0]; // LCOV_EXCL_LINE + break; + case 1: + m = 1; sh1 = sh2 = 0; + break; + case 2: + m = 1; sh1 = 1; sh2 = 0; + break; + default: + l = npyv__bitscan_revnz_u64(d - 1) + 1; // ceil(log2(d)) + l2 = l < 64 ? 
1ULL << l : 0; // 2^l + m = npyv__divh128_u64(l2 - d, d) + 1; // multiplier + sh1 = 1; sh2 = l - 1; // shift counts + } + divisor.val[0] = npyv_setall_u64(m); + #ifdef NPY_HAVE_SSE2 // SSE/AVX2/AVX512 + divisor.val[1] = npyv_set_u64(sh1); + divisor.val[2] = npyv_set_u64(sh2); + #else + #error "please initialize the shifting operand for the new architecture" + #endif +#endif + return divisor; +} +// Initializing divisor parameters for signed 64-bit division +NPY_FINLINE npyv_s64x3 npyv_divisor_s64(npy_int64 d) +{ + npyv_s64x3 divisor; +#if defined(NPY_HAVE_VSX2) || defined(NPY_HAVE_VX) || defined(NPY_HAVE_NEON) + divisor.val[0] = npyv_setall_s64(d); + divisor.val[1] = npyv_cvt_s64_b64( + npyv_cmpeq_s64(npyv_setall_s64(-1), divisor.val[0]) + ); +#else + npy_int64 d1 = llabs(d); + npy_int64 sh, m; + // Handle abs overflow + if ((npy_uint64)d == 0x8000000000000000ULL) { + m = 0x8000000000000001LL; + sh = 62; + } + else if (d1 > 1) { + sh = npyv__bitscan_revnz_u64(d1 - 1); // ceil(log2(abs(d))) - 1 + m = npyv__divh128_u64(1ULL << sh, d1) + 1; // multiplier + } + else if (d1 == 1) { + sh = 0; m = 1; + } + else { + // raise arithmetic exception for d == 0 + sh = m = 1 / ((npy_int64 volatile *)&d)[0]; // LCOV_EXCL_LINE + } + divisor.val[0] = npyv_setall_s64(m); + divisor.val[2] = npyv_setall_s64(d < 0 ? -1 : 0); // sign of divisor + #ifdef NPY_HAVE_SSE2 // SSE/AVX2/AVX512 + divisor.val[1] = npyv_set_s64(sh); + #else + #error "please initialize the shifting operand for the new architecture" + #endif +#endif + return divisor; +} + +#endif // _NPY_SIMD_INTDIV_H diff --git a/mkl_umath/src/npyv/lowlevel_strided_loops.h b/mkl_umath/src/npyv/lowlevel_strided_loops.h new file mode 100644 index 00000000..9bcfcf2d --- /dev/null +++ b/mkl_umath/src/npyv/lowlevel_strided_loops.h @@ -0,0 +1,790 @@ +#ifndef NUMPY_CORE_SRC_COMMON_LOWLEVEL_STRIDED_LOOPS_H_ +#define NUMPY_CORE_SRC_COMMON_LOWLEVEL_STRIDED_LOOPS_H_ +#include "common.h" +#include "npy_config.h" +#include "array_method.h" +#include "dtype_transfer.h" +#include "mem_overlap.h" +#include "mapping.h" + +/* For PyArray_ macros used below */ +#include "numpy/ndarrayobject.h" + +/* + * NOTE: This API should remain private for the time being, to allow + * for further refinement. I think the 'aligned' mechanism + * needs changing, for example. + * + * Note: Updated in 2018 to distinguish "true" from "uint" alignment. + */ + +/* + * This function pointer is for unary operations that input an + * arbitrarily strided one-dimensional array segment and output + * an arbitrarily strided array segment of the same size. + * It may be a fully general function, or a specialized function + * when the strides or item size have particular known values. + * + * Examples of unary operations are a straight copy, a byte-swap, + * and a casting operation, + * + * The 'transferdata' parameter is slightly special, following a + * generic auxiliary data pattern defined in ndarraytypes.h + * Use NPY_AUXDATA_CLONE and NPY_AUXDATA_FREE to deal with this data. + * + */ +// TODO: FIX! That comment belongs to something now in array-method + +/* + * This is for pointers to functions which behave exactly as + * for PyArrayMethod_StridedLoop, but with an additional mask controlling + * which values are transformed. + * + * TODO: We should move this mask "capability" to the ArrayMethod itself + * probably. Although for NumPy internal things this works decently, + * and exposing it there should be well thought out to be useful beyond + * NumPy if possible. 
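The typedef that follows declares strided loops that take an additional mask. As a reference for the mask semantics (spelled out in the next sentence of this comment), a plain scalar equivalent of a masked strided copy might look like the sketch below; it is illustrative only, not part of the patch, and assumes the usual numpy typedefs plus <string.h>.

```c
#include <string.h>

/* Scalar reference: element i is touched only when mask[i * mask_stride] is true. */
static void
masked_strided_copy_ref(char *dst, npy_intp dst_stride,
                        const char *src, npy_intp src_stride,
                        const npy_bool *mask, npy_intp mask_stride,
                        npy_intp count, npy_intp itemsize)
{
    for (npy_intp i = 0; i < count; ++i) {
        if (mask[i * mask_stride]) {
            memcpy(dst + i * dst_stride, src + i * src_stride, (size_t)itemsize);
        }
    }
}
```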
+ * + * In particular, the 'i'-th element is operated on if and only if + * mask[i*mask_stride] is true. + */ +typedef int (PyArray_MaskedStridedUnaryOp)( + PyArrayMethod_Context *context, char *const *args, + const npy_intp *dimensions, const npy_intp *strides, + npy_bool *mask, npy_intp mask_stride, + NpyAuxData *auxdata); + +/* + * Gives back a function pointer to a specialized function for copying + * strided memory. Returns NULL if there is a problem with the inputs. + * + * aligned: + * Should be 1 if the src and dst pointers always point to + * locations at which a uint of equal size to dtype->elsize + * would be aligned, 0 otherwise. + * src_stride: + * Should be the src stride if it will always be the same, + * NPY_MAX_INTP otherwise. + * dst_stride: + * Should be the dst stride if it will always be the same, + * NPY_MAX_INTP otherwise. + * itemsize: + * Should be the item size if it will always be the same, 0 otherwise. + * + */ +NPY_NO_EXPORT PyArrayMethod_StridedLoop * +PyArray_GetStridedCopyFn(int aligned, + npy_intp src_stride, npy_intp dst_stride, + npy_intp itemsize); + +/* + * Gives back a function pointer to a specialized function for copying + * and swapping strided memory. This assumes each element is a single + * value to be swapped. + * + * For information on the 'aligned', 'src_stride' and 'dst_stride' parameters + * see above. + * + * Parameters are as for PyArray_GetStridedCopyFn. + */ +NPY_NO_EXPORT PyArrayMethod_StridedLoop * +PyArray_GetStridedCopySwapFn(int aligned, + npy_intp src_stride, npy_intp dst_stride, + npy_intp itemsize); + +/* + * Gives back a function pointer to a specialized function for copying + * and swapping strided memory. This assumes each element is a pair + * of values, each of which needs to be swapped. + * + * For information on the 'aligned', 'src_stride' and 'dst_stride' parameters + * see above. + * + * Parameters are as for PyArray_GetStridedCopyFn. + */ +NPY_NO_EXPORT PyArrayMethod_StridedLoop * +PyArray_GetStridedCopySwapPairFn(int aligned, + npy_intp src_stride, npy_intp dst_stride, + npy_intp itemsize); + +/* + * Gives back a transfer function and transfer data pair which copies + * the data from source to dest, truncating it if the data doesn't + * fit, and padding with zero bytes if there's too much space. + * + * For information on the 'aligned', 'src_stride' and 'dst_stride' parameters + * see above. + * + * Returns NPY_SUCCEED or NPY_FAIL + */ +NPY_NO_EXPORT int +PyArray_GetStridedZeroPadCopyFn(int aligned, int unicode_swap, + npy_intp src_stride, npy_intp dst_stride, + npy_intp src_itemsize, npy_intp dst_itemsize, + PyArrayMethod_StridedLoop **outstransfer, + NpyAuxData **outtransferdata); + +/* + * For casts between built-in numeric types, + * this produces a function pointer for casting from src_type_num + * to dst_type_num. If a conversion is unsupported, returns NULL + * without setting a Python exception. + */ +NPY_NO_EXPORT PyArrayMethod_StridedLoop * +PyArray_GetStridedNumericCastFn(int aligned, + npy_intp src_stride, npy_intp dst_stride, + int src_type_num, int dst_type_num); + +/* + * Gets an operation which copies elements of the given dtype, + * swapping if the dtype isn't in NBO. 
+ * + * Returns NPY_SUCCEED or NPY_FAIL + */ +NPY_NO_EXPORT int +PyArray_GetDTypeCopySwapFn(int aligned, + npy_intp src_stride, npy_intp dst_stride, + PyArray_Descr *dtype, + PyArrayMethod_StridedLoop **outstransfer, + NpyAuxData **outtransferdata); + +/* + * If it's possible, gives back a transfer function which casts and/or + * byte swaps data with the dtype 'src_dtype' into data with the dtype + * 'dst_dtype'. If the outtransferdata is populated with a non-NULL value, + * it must be deallocated with the NPY_AUXDATA_FREE + * function when the transfer function is no longer required. + * + * aligned: + * Should be 1 if the src and dst pointers always point to + * locations at which a uint of equal size to dtype->elsize + * would be aligned, 0 otherwise. + * src_stride: + * Should be the src stride if it will always be the same, + * NPY_MAX_INTP otherwise. + * dst_stride: + * Should be the dst stride if it will always be the same, + * NPY_MAX_INTP otherwise. + * src_dtype: + * The data type of source data. Must not be NULL. + * dst_dtype: + * The data type of destination data. If this is NULL and + * move_references is 1, a transfer function which decrements + * source data references is produced. + * move_references: + * If 0, the destination data gets new reference ownership. + * If 1, the references from the source data are moved to + * the destination data. + * cast_info: + * A pointer to an (uninitialized) `NPY_cast_info` struct, the caller + * must call `NPY_cast_info_xfree` on it (except on error) and handle + * its memory livespan. + * out_needs_api: + * If this is non-NULL, and the transfer function produced needs + * to call into the (Python) API, this gets set to 1. This + * remains untouched if no API access is required. + * + * WARNING: If you set move_references to 1, it is best that src_stride is + * never zero when calling the transfer function. Otherwise, the + * first destination reference will get the value and all the rest + * will get NULL. + * + * Returns NPY_SUCCEED or NPY_FAIL. + */ +NPY_NO_EXPORT int +PyArray_GetDTypeTransferFunction(int aligned, + npy_intp src_stride, npy_intp dst_stride, + PyArray_Descr *src_dtype, PyArray_Descr *dst_dtype, + int move_references, + NPY_cast_info *cast_info, + NPY_ARRAYMETHOD_FLAGS *out_flags); + +NPY_NO_EXPORT int +get_fields_transfer_function(int aligned, + npy_intp src_stride, npy_intp dst_stride, + PyArray_Descr *src_dtype, PyArray_Descr *dst_dtype, + int move_references, + PyArrayMethod_StridedLoop **out_stransfer, + NpyAuxData **out_transferdata, + NPY_ARRAYMETHOD_FLAGS *out_flags); + +NPY_NO_EXPORT int +get_subarray_transfer_function(int aligned, + npy_intp src_stride, npy_intp dst_stride, + PyArray_Descr *src_dtype, PyArray_Descr *dst_dtype, + int move_references, + PyArrayMethod_StridedLoop **out_stransfer, + NpyAuxData **out_transferdata, + NPY_ARRAYMETHOD_FLAGS *out_flags); + +/* + * This is identical to PyArray_GetDTypeTransferFunction, but returns a + * transfer function which also takes a mask as a parameter. The mask is used + * to determine which values to copy, and data is transferred exactly when + * mask[i*mask_stride] is true. + * + * If move_references is true, values which are not copied to the + * destination will still have their source reference decremented. + * + * If mask_dtype is NPY_BOOL or NPY_UINT8, each full element is either + * transferred or not according to the mask as described above. 
If + * dst_dtype and mask_dtype are both struct dtypes, their names must + * match exactly, and the dtype of each leaf field in mask_dtype must + * be either NPY_BOOL or NPY_UINT8. + */ +NPY_NO_EXPORT int +PyArray_GetMaskedDTypeTransferFunction(int aligned, + npy_intp src_stride, + npy_intp dst_stride, + npy_intp mask_stride, + PyArray_Descr *src_dtype, + PyArray_Descr *dst_dtype, + PyArray_Descr *mask_dtype, + int move_references, + NPY_cast_info *cast_info, + NPY_ARRAYMETHOD_FLAGS *out_flags); + +/* + * Casts the specified number of elements from 'src' with data type + * 'src_dtype' to 'dst' with 'dst_dtype'. See + * PyArray_GetDTypeTransferFunction for more details. + * + * Returns NPY_SUCCEED or NPY_FAIL. + */ +NPY_NO_EXPORT int +PyArray_CastRawArrays(npy_intp count, + char *src, char *dst, + npy_intp src_stride, npy_intp dst_stride, + PyArray_Descr *src_dtype, PyArray_Descr *dst_dtype, + int move_references); + +/* + * These two functions copy or convert the data of an n-dimensional array + * to/from a 1-dimensional strided buffer. These functions will only call + * 'stransfer' with the provided dst_stride/src_stride and + * dst_strides[0]/src_strides[0], so the caller can use those values to + * specialize the function. + * Note that even if ndim == 0, everything needs to be set as if ndim == 1. + * + * The return value is the number of elements it couldn't copy. A return value + * of 0 means all elements were copied, a larger value means the end of + * the n-dimensional array was reached before 'count' elements were copied. + * A negative return value indicates an error occurred. + * + * ndim: + * The number of dimensions of the n-dimensional array. + * dst/src/mask: + * The destination, source or mask starting pointer. + * dst_stride/src_stride/mask_stride: + * The stride of the 1-dimensional strided buffer + * dst_strides/src_strides: + * The strides of the n-dimensional array. + * dst_strides_inc/src_strides_inc: + * How much to add to the ..._strides pointer to get to the next stride. + * coords: + * The starting coordinates in the n-dimensional array. + * coords_inc: + * How much to add to the coords pointer to get to the next coordinate. + * shape: + * The shape of the n-dimensional array. + * shape_inc: + * How much to add to the shape pointer to get to the next shape entry. + * count: + * How many elements to transfer + * src_itemsize: + * How big each element is. If transferring between elements of different + * sizes, for example a casting operation, the 'stransfer' function + * should be specialized for that, in which case 'stransfer' will use + * this parameter as the source item size. + * cast_info: + * Pointer to the NPY_cast_info struct which summarizes all information + * necessary to perform a cast. 
+ */ +NPY_NO_EXPORT npy_intp +PyArray_TransferNDimToStrided(npy_intp ndim, + char *dst, npy_intp dst_stride, + char *src, npy_intp const *src_strides, npy_intp src_strides_inc, + npy_intp const *coords, npy_intp coords_inc, + npy_intp const *shape, npy_intp shape_inc, + npy_intp count, npy_intp src_itemsize, + NPY_cast_info *cast_info); + +NPY_NO_EXPORT npy_intp +PyArray_TransferStridedToNDim(npy_intp ndim, + char *dst, npy_intp const *dst_strides, npy_intp dst_strides_inc, + char *src, npy_intp src_stride, + npy_intp const *coords, npy_intp coords_inc, + npy_intp const *shape, npy_intp shape_inc, + npy_intp count, npy_intp src_itemsize, + NPY_cast_info *cast_info); + +NPY_NO_EXPORT npy_intp +PyArray_TransferMaskedStridedToNDim(npy_intp ndim, + char *dst, npy_intp const *dst_strides, npy_intp dst_strides_inc, + char *src, npy_intp src_stride, + npy_bool *mask, npy_intp mask_stride, + npy_intp const *coords, npy_intp coords_inc, + npy_intp const *shape, npy_intp shape_inc, + npy_intp count, npy_intp src_itemsize, + NPY_cast_info *cast_info); + +NPY_NO_EXPORT int +mapiter_trivial_get( + PyArrayObject *self, PyArrayObject *ind, PyArrayObject *result, + int is_aligned, NPY_cast_info *cast_info); + +NPY_NO_EXPORT int +mapiter_trivial_set( + PyArrayObject *self, PyArrayObject *ind, PyArrayObject *result, + int is_aligned, NPY_cast_info *cast_info); + +NPY_NO_EXPORT int +mapiter_get( + PyArrayMapIterObject *mit, NPY_cast_info *cast_info, + NPY_ARRAYMETHOD_FLAGS flags, int is_aligned); + +NPY_NO_EXPORT int +mapiter_set( + PyArrayMapIterObject *mit, NPY_cast_info *cast_info, + NPY_ARRAYMETHOD_FLAGS flags, int is_aligned); + +/* + * Prepares shape and strides for a simple raw array iteration. + * This sorts the strides into FORTRAN order, reverses any negative + * strides, then coalesces axes where possible. The results are + * filled in the output parameters. + * + * This is intended for simple, lightweight iteration over arrays + * where no buffering of any kind is needed, and the array may + * not be stored as a PyArrayObject. + * + * You can use this together with NPY_RAW_ITER_START and + * NPY_RAW_ITER_ONE_NEXT to handle the looping boilerplate of everything + * but the innermost loop (which is for idim == 0). + * + * Returns 0 on success, -1 on failure. + */ +NPY_NO_EXPORT int +PyArray_PrepareOneRawArrayIter(int ndim, npy_intp const *shape, + char *data, npy_intp const *strides, + int *out_ndim, npy_intp *out_shape, + char **out_data, npy_intp *out_strides); + +/* + * The same as PyArray_PrepareOneRawArrayIter, but for two + * operands instead of one. Any broadcasting of the two operands + * should have already been done before calling this function, + * as the ndim and shape is only specified once for both operands. + * + * Only the strides of the first operand are used to reorder + * the dimensions, no attempt to consider all the strides together + * is made, as is done in the NpyIter object. + * + * You can use this together with NPY_RAW_ITER_START and + * NPY_RAW_ITER_TWO_NEXT to handle the looping boilerplate of everything + * but the innermost loop (which is for idim == 0). + * + * Returns 0 on success, -1 on failure. 
+ */ +NPY_NO_EXPORT int +PyArray_PrepareTwoRawArrayIter(int ndim, npy_intp const *shape, + char *dataA, npy_intp const *stridesA, + char *dataB, npy_intp const *stridesB, + int *out_ndim, npy_intp *out_shape, + char **out_dataA, npy_intp *out_stridesA, + char **out_dataB, npy_intp *out_stridesB); + +/* + * The same as PyArray_PrepareOneRawArrayIter, but for three + * operands instead of one. Any broadcasting of the three operands + * should have already been done before calling this function, + * as the ndim and shape is only specified once for all operands. + * + * Only the strides of the first operand are used to reorder + * the dimensions, no attempt to consider all the strides together + * is made, as is done in the NpyIter object. + * + * You can use this together with NPY_RAW_ITER_START and + * NPY_RAW_ITER_THREE_NEXT to handle the looping boilerplate of everything + * but the innermost loop (which is for idim == 0). + * + * Returns 0 on success, -1 on failure. + */ +NPY_NO_EXPORT int +PyArray_PrepareThreeRawArrayIter(int ndim, npy_intp const *shape, + char *dataA, npy_intp const *stridesA, + char *dataB, npy_intp const *stridesB, + char *dataC, npy_intp const *stridesC, + int *out_ndim, npy_intp *out_shape, + char **out_dataA, npy_intp *out_stridesA, + char **out_dataB, npy_intp *out_stridesB, + char **out_dataC, npy_intp *out_stridesC); + +/* + * Return number of elements that must be peeled from the start of 'addr' with + * 'nvals' elements of size 'esize' in order to reach blockable alignment. + * The required alignment in bytes is passed as the 'alignment' argument and + * must be a power of two. This function is used to prepare an array for + * blocking. See the 'npy_blocked_end' function documentation below for an + * example of how this function is used. + */ +static inline npy_intp +npy_aligned_block_offset(const void * addr, const npy_uintp esize, + const npy_uintp alignment, const npy_uintp nvals) +{ + npy_uintp offset, peel; + + offset = (npy_uintp)addr & (alignment - 1); + peel = offset ? (alignment - offset) / esize : 0; + peel = (peel <= nvals) ? 
peel : nvals; + assert(peel <= NPY_MAX_INTP); + return (npy_intp)peel; +} + +/* + * Return upper loop bound for an array of 'nvals' elements + * of size 'esize' peeled by 'offset' elements and blocking to + * a vector size of 'vsz' in bytes + * + * example usage: + * npy_intp i; + * double v[101]; + * npy_intp esize = sizeof(v[0]); + * npy_intp peel = npy_aligned_block_offset(v, esize, 16, n); + * // peel to alignment 16 + * for (i = 0; i < peel; i++) + * + * // simd vectorized operation + * for (; i < npy_blocked_end(peel, esize, 16, n); i += 16 / esize) + * + * // handle scalar rest + * for(; i < n; i++) + * + */ +static inline npy_intp +npy_blocked_end(const npy_uintp peel, const npy_uintp esize, + const npy_uintp vsz, const npy_uintp nvals) +{ + npy_uintp ndiff = nvals - peel; + npy_uintp res = (ndiff - ndiff % (vsz / esize)); + + assert(nvals >= peel); + assert(res <= NPY_MAX_INTP); + return (npy_intp)(res); +} + + +/* byte swapping functions */ +static inline npy_uint16 +npy_bswap2(npy_uint16 x) +{ + return ((x & 0xffu) << 8) | (x >> 8); +} + +/* + * treat as int16 and byteswap unaligned memory, + * some cpus don't support unaligned access + */ +static inline void +npy_bswap2_unaligned(char * x) +{ + char a = x[0]; + x[0] = x[1]; + x[1] = a; +} + +static inline npy_uint32 +npy_bswap4(npy_uint32 x) +{ +#ifdef HAVE___BUILTIN_BSWAP32 + return __builtin_bswap32(x); +#else + return ((x & 0xffu) << 24) | ((x & 0xff00u) << 8) | + ((x & 0xff0000u) >> 8) | (x >> 24); +#endif +} + +static inline void +npy_bswap4_unaligned(char * x) +{ + char a = x[0]; + x[0] = x[3]; + x[3] = a; + a = x[1]; + x[1] = x[2]; + x[2] = a; +} + +static inline npy_uint64 +npy_bswap8(npy_uint64 x) +{ +#ifdef HAVE___BUILTIN_BSWAP64 + return __builtin_bswap64(x); +#else + return ((x & 0xffULL) << 56) | + ((x & 0xff00ULL) << 40) | + ((x & 0xff0000ULL) << 24) | + ((x & 0xff000000ULL) << 8) | + ((x & 0xff00000000ULL) >> 8) | + ((x & 0xff0000000000ULL) >> 24) | + ((x & 0xff000000000000ULL) >> 40) | + ( x >> 56); +#endif +} + +static inline void +npy_bswap8_unaligned(char * x) +{ + char a = x[0]; x[0] = x[7]; x[7] = a; + a = x[1]; x[1] = x[6]; x[6] = a; + a = x[2]; x[2] = x[5]; x[5] = a; + a = x[3]; x[3] = x[4]; x[4] = a; +} + + +/* Start raw iteration */ +#define NPY_RAW_ITER_START(idim, ndim, coord, shape) \ + memset((coord), 0, (ndim) * sizeof(coord[0])); \ + do { + +/* Increment to the next n-dimensional coordinate for one raw array */ +#define NPY_RAW_ITER_ONE_NEXT(idim, ndim, coord, shape, data, strides) \ + for ((idim) = 1; (idim) < (ndim); ++(idim)) { \ + if (++(coord)[idim] == (shape)[idim]) { \ + (coord)[idim] = 0; \ + (data) -= ((shape)[idim] - 1) * (strides)[idim]; \ + } \ + else { \ + (data) += (strides)[idim]; \ + break; \ + } \ + } \ + } while ((idim) < (ndim)) + +/* Increment to the next n-dimensional coordinate for two raw arrays */ +#define NPY_RAW_ITER_TWO_NEXT(idim, ndim, coord, shape, \ + dataA, stridesA, dataB, stridesB) \ + for ((idim) = 1; (idim) < (ndim); ++(idim)) { \ + if (++(coord)[idim] == (shape)[idim]) { \ + (coord)[idim] = 0; \ + (dataA) -= ((shape)[idim] - 1) * (stridesA)[idim]; \ + (dataB) -= ((shape)[idim] - 1) * (stridesB)[idim]; \ + } \ + else { \ + (dataA) += (stridesA)[idim]; \ + (dataB) += (stridesB)[idim]; \ + break; \ + } \ + } \ + } while ((idim) < (ndim)) + +/* Increment to the next n-dimensional coordinate for three raw arrays */ +#define NPY_RAW_ITER_THREE_NEXT(idim, ndim, coord, shape, \ + dataA, stridesA, \ + dataB, stridesB, \ + dataC, stridesC) \ + for ((idim) = 1; (idim) < 
(ndim); ++(idim)) { \ + if (++(coord)[idim] == (shape)[idim]) { \ + (coord)[idim] = 0; \ + (dataA) -= ((shape)[idim] - 1) * (stridesA)[idim]; \ + (dataB) -= ((shape)[idim] - 1) * (stridesB)[idim]; \ + (dataC) -= ((shape)[idim] - 1) * (stridesC)[idim]; \ + } \ + else { \ + (dataA) += (stridesA)[idim]; \ + (dataB) += (stridesB)[idim]; \ + (dataC) += (stridesC)[idim]; \ + break; \ + } \ + } \ + } while ((idim) < (ndim)) + +/* Increment to the next n-dimensional coordinate for four raw arrays */ +#define NPY_RAW_ITER_FOUR_NEXT(idim, ndim, coord, shape, \ + dataA, stridesA, \ + dataB, stridesB, \ + dataC, stridesC, \ + dataD, stridesD) \ + for ((idim) = 1; (idim) < (ndim); ++(idim)) { \ + if (++(coord)[idim] == (shape)[idim]) { \ + (coord)[idim] = 0; \ + (dataA) -= ((shape)[idim] - 1) * (stridesA)[idim]; \ + (dataB) -= ((shape)[idim] - 1) * (stridesB)[idim]; \ + (dataC) -= ((shape)[idim] - 1) * (stridesC)[idim]; \ + (dataD) -= ((shape)[idim] - 1) * (stridesD)[idim]; \ + } \ + else { \ + (dataA) += (stridesA)[idim]; \ + (dataB) += (stridesB)[idim]; \ + (dataC) += (stridesC)[idim]; \ + (dataD) += (stridesD)[idim]; \ + break; \ + } \ + } \ + } while ((idim) < (ndim)) + + +/* + * TRIVIAL ITERATION + * + * In some cases when the iteration order isn't important, iteration over + * arrays is trivial. This is the case when: + * * The array has 0 or 1 dimensions. + * * The array is C or Fortran contiguous. + * Use of an iterator can be skipped when this occurs. These macros assist + * in detecting and taking advantage of the situation. Note that it may + * be worthwhile to further check if the stride is a contiguous stride + * and take advantage of that. + * + * Here is example code for a single array: + * + * if (PyArray_TRIVIALLY_ITERABLE(self)) { + * char *data; + * npy_intp count, stride; + * + * PyArray_PREPARE_TRIVIAL_ITERATION(self, count, data, stride); + * + * while (count--) { + * // Use the data pointer + * + * data += stride; + * } + * } + * else { + * // Create iterator, etc... + * } + * + */ + +/* + * Note: Equivalently iterable macro requires one of arr1 or arr2 be + * trivially iterable to be valid. + */ + +/** + * Determine whether two arrays are safe for trivial iteration in cases where + * some of the arrays may be modified. + * + * In-place iteration is safe if one of the following is true: + * + * - Both arrays are read-only + * - The arrays do not have overlapping memory (based on a check that may be too + * strict) + * - The strides match, and the non-read-only array base addresses are equal or + * before the read-only one, ensuring correct data dependency. + */ + +#define PyArray_TRIVIALLY_ITERABLE_OP_NOREAD 0 +#define PyArray_TRIVIALLY_ITERABLE_OP_READ 1 + +#define PyArray_TRIVIALLY_ITERABLE(arr) ( \ + PyArray_NDIM(arr) <= 1 || \ + PyArray_CHKFLAGS(arr, NPY_ARRAY_C_CONTIGUOUS) || \ + PyArray_CHKFLAGS(arr, NPY_ARRAY_F_CONTIGUOUS) \ + ) + +#define PyArray_TRIVIAL_PAIR_ITERATION_STRIDE(size, arr) ( \ + assert(PyArray_TRIVIALLY_ITERABLE(arr)), \ + size == 1 ? 0 : ((PyArray_NDIM(arr) == 1) ? 
\ + PyArray_STRIDE(arr, 0) : PyArray_ITEMSIZE(arr))) + +static inline int +PyArray_EQUIVALENTLY_ITERABLE_OVERLAP_OK(PyArrayObject *arr1, PyArrayObject *arr2, + int arr1_read, int arr2_read) +{ + npy_intp size1, size2, stride1, stride2; + int arr1_ahead = 0, arr2_ahead = 0; + + if (arr1_read && arr2_read) { + return 1; + } + + size1 = PyArray_SIZE(arr1); + stride1 = PyArray_TRIVIAL_PAIR_ITERATION_STRIDE(size1, arr1); + + /* + * arr1 == arr2 is common for in-place operations, so we fast-path it here. + * TODO: The stride1 != 0 check rejects broadcast arrays. This may affect + * self-overlapping arrays, but seems only necessary due to + * `try_trivial_single_output_loop` not rejecting broadcast outputs. + */ + if (arr1 == arr2 && stride1 != 0) { + return 1; + } + + if (solve_may_share_memory(arr1, arr2, 1) == 0) { + return 1; + } + + /* + * Arrays overlapping in memory may be equivalently iterable if input + * arrays stride ahead faster than output arrays. + */ + + size2 = PyArray_SIZE(arr2); + stride2 = PyArray_TRIVIAL_PAIR_ITERATION_STRIDE(size2, arr2); + + /* + * Arrays with zero stride are never "ahead" since the element is reused + * (at this point we know the array extents overlap). + */ + + if (stride1 > 0) { + arr1_ahead = (stride1 >= stride2 && + PyArray_BYTES(arr1) >= PyArray_BYTES(arr2)); + } + else if (stride1 < 0) { + arr1_ahead = (stride1 <= stride2 && + PyArray_BYTES(arr1) <= PyArray_BYTES(arr2)); + } + + if (stride2 > 0) { + arr2_ahead = (stride2 >= stride1 && + PyArray_BYTES(arr2) >= PyArray_BYTES(arr1)); + } + else if (stride2 < 0) { + arr2_ahead = (stride2 <= stride1 && + PyArray_BYTES(arr2) <= PyArray_BYTES(arr1)); + } + + return (!arr1_read || arr1_ahead) && (!arr2_read || arr2_ahead); +} + +#define PyArray_EQUIVALENTLY_ITERABLE_BASE(arr1, arr2) ( \ + PyArray_NDIM(arr1) == PyArray_NDIM(arr2) && \ + PyArray_CompareLists(PyArray_DIMS(arr1), \ + PyArray_DIMS(arr2), \ + PyArray_NDIM(arr1)) && \ + (PyArray_FLAGS(arr1)&(NPY_ARRAY_C_CONTIGUOUS| \ + NPY_ARRAY_F_CONTIGUOUS)) & \ + (PyArray_FLAGS(arr2)&(NPY_ARRAY_C_CONTIGUOUS| \ + NPY_ARRAY_F_CONTIGUOUS)) \ + ) + +#define PyArray_EQUIVALENTLY_ITERABLE(arr1, arr2, arr1_read, arr2_read) ( \ + PyArray_EQUIVALENTLY_ITERABLE_BASE(arr1, arr2) && \ + PyArray_EQUIVALENTLY_ITERABLE_OVERLAP_OK( \ + arr1, arr2, arr1_read, arr2_read)) + +#define PyArray_PREPARE_TRIVIAL_ITERATION(arr, count, data, stride) \ + count = PyArray_SIZE(arr); \ + data = PyArray_BYTES(arr); \ + stride = ((PyArray_NDIM(arr) == 0) ? 0 : \ + ((PyArray_NDIM(arr) == 1) ? \ + PyArray_STRIDE(arr, 0) : \ + PyArray_ITEMSIZE(arr))); + +#define PyArray_PREPARE_TRIVIAL_PAIR_ITERATION(arr1, arr2, \ + count, \ + data1, data2, \ + stride1, stride2) { \ + npy_intp size1 = PyArray_SIZE(arr1); \ + npy_intp size2 = PyArray_SIZE(arr2); \ + count = ((size1 > size2) || size1 == 0) ? 
size1 : size2; \ + data1 = PyArray_BYTES(arr1); \ + data2 = PyArray_BYTES(arr2); \ + stride1 = PyArray_TRIVIAL_PAIR_ITERATION_STRIDE(size1, arr1); \ + stride2 = PyArray_TRIVIAL_PAIR_ITERATION_STRIDE(size2, arr2); \ + } + +#endif /* NUMPY_CORE_SRC_COMMON_LOWLEVEL_STRIDED_LOOPS_H_ */ diff --git a/mkl_umath/src/npyv/mem_overlap.h b/mkl_umath/src/npyv/mem_overlap.h new file mode 100644 index 00000000..3aa4f798 --- /dev/null +++ b/mkl_umath/src/npyv/mem_overlap.h @@ -0,0 +1,49 @@ +#ifndef NUMPY_CORE_SRC_COMMON_MEM_OVERLAP_H_ +#define NUMPY_CORE_SRC_COMMON_MEM_OVERLAP_H_ + +#include "npy_config.h" +#include "numpy/ndarraytypes.h" + + +/* Bounds check only */ +#define NPY_MAY_SHARE_BOUNDS 0 + +/* Exact solution */ +#define NPY_MAY_SHARE_EXACT -1 + + +typedef enum { + MEM_OVERLAP_NO = 0, /* no solution exists */ + MEM_OVERLAP_YES = 1, /* solution found */ + MEM_OVERLAP_TOO_HARD = -1, /* max_work exceeded */ + MEM_OVERLAP_OVERFLOW = -2, /* algorithm failed due to integer overflow */ + MEM_OVERLAP_ERROR = -3 /* invalid input */ +} mem_overlap_t; + + +typedef struct { + npy_int64 a; + npy_int64 ub; +} diophantine_term_t; + +NPY_VISIBILITY_HIDDEN mem_overlap_t +solve_diophantine(unsigned int n, diophantine_term_t *E, + npy_int64 b, Py_ssize_t max_work, int require_nontrivial, + npy_int64 *x); + +NPY_VISIBILITY_HIDDEN int +diophantine_simplify(unsigned int *n, diophantine_term_t *E, npy_int64 b); + +NPY_VISIBILITY_HIDDEN mem_overlap_t +solve_may_share_memory(PyArrayObject *a, PyArrayObject *b, + Py_ssize_t max_work); + +NPY_VISIBILITY_HIDDEN mem_overlap_t +solve_may_have_internal_overlap(PyArrayObject *a, Py_ssize_t max_work); + +NPY_VISIBILITY_HIDDEN void +offset_bounds_from_strides(const int itemsize, const int nd, + const npy_intp *dims, const npy_intp *strides, + npy_intp *lower_offset, npy_intp *upper_offset); + +#endif /* NUMPY_CORE_SRC_COMMON_MEM_OVERLAP_H_ */ diff --git a/mkl_umath/src/npyv/neon/arithmetic.h b/mkl_umath/src/npyv/neon/arithmetic.h new file mode 100644 index 00000000..68362011 --- /dev/null +++ b/mkl_umath/src/npyv/neon/arithmetic.h @@ -0,0 +1,343 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_NEON_ARITHMETIC_H +#define _NPY_SIMD_NEON_ARITHMETIC_H + +/*************************** + * Addition + ***************************/ +// non-saturated +#define npyv_add_u8 vaddq_u8 +#define npyv_add_s8 vaddq_s8 +#define npyv_add_u16 vaddq_u16 +#define npyv_add_s16 vaddq_s16 +#define npyv_add_u32 vaddq_u32 +#define npyv_add_s32 vaddq_s32 +#define npyv_add_u64 vaddq_u64 +#define npyv_add_s64 vaddq_s64 +#define npyv_add_f32 vaddq_f32 +#define npyv_add_f64 vaddq_f64 + +// saturated +#define npyv_adds_u8 vqaddq_u8 +#define npyv_adds_s8 vqaddq_s8 +#define npyv_adds_u16 vqaddq_u16 +#define npyv_adds_s16 vqaddq_s16 + +/*************************** + * Subtraction + ***************************/ +// non-saturated +#define npyv_sub_u8 vsubq_u8 +#define npyv_sub_s8 vsubq_s8 +#define npyv_sub_u16 vsubq_u16 +#define npyv_sub_s16 vsubq_s16 +#define npyv_sub_u32 vsubq_u32 +#define npyv_sub_s32 vsubq_s32 +#define npyv_sub_u64 vsubq_u64 +#define npyv_sub_s64 vsubq_s64 +#define npyv_sub_f32 vsubq_f32 +#define npyv_sub_f64 vsubq_f64 + +// saturated +#define npyv_subs_u8 vqsubq_u8 +#define npyv_subs_s8 vqsubq_s8 +#define npyv_subs_u16 vqsubq_u16 +#define npyv_subs_s16 vqsubq_s16 + +/*************************** + * Multiplication + ***************************/ +// non-saturated +#define npyv_mul_u8 vmulq_u8 +#define npyv_mul_s8 vmulq_s8 +#define npyv_mul_u16 vmulq_u16 +#define 
npyv_mul_s16 vmulq_s16 +#define npyv_mul_u32 vmulq_u32 +#define npyv_mul_s32 vmulq_s32 +#define npyv_mul_f32 vmulq_f32 +#define npyv_mul_f64 vmulq_f64 + +/*************************** + * Integer Division + ***************************/ +// See simd/intdiv.h for more clarification +// divide each unsigned 8-bit element by a precomputed divisor +NPY_FINLINE npyv_u8 npyv_divc_u8(npyv_u8 a, const npyv_u8x3 divisor) +{ + const uint8x8_t mulc_lo = vget_low_u8(divisor.val[0]); + // high part of unsigned multiplication + uint16x8_t mull_lo = vmull_u8(vget_low_u8(a), mulc_lo); +#if NPY_SIMD_F64 + uint16x8_t mull_hi = vmull_high_u8(a, divisor.val[0]); + // get the high unsigned bytes + uint8x16_t mulhi = vuzp2q_u8(vreinterpretq_u8_u16(mull_lo), vreinterpretq_u8_u16(mull_hi)); +#else + const uint8x8_t mulc_hi = vget_high_u8(divisor.val[0]); + uint16x8_t mull_hi = vmull_u8(vget_high_u8(a), mulc_hi); + uint8x16_t mulhi = vuzpq_u8(vreinterpretq_u8_u16(mull_lo), vreinterpretq_u8_u16(mull_hi)).val[1]; +#endif + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + uint8x16_t q = vsubq_u8(a, mulhi); + q = vshlq_u8(q, vreinterpretq_s8_u8(divisor.val[1])); + q = vaddq_u8(mulhi, q); + q = vshlq_u8(q, vreinterpretq_s8_u8(divisor.val[2])); + return q; +} +// divide each signed 8-bit element by a precomputed divisor (round towards zero) +NPY_FINLINE npyv_s8 npyv_divc_s8(npyv_s8 a, const npyv_s8x3 divisor) +{ + const int8x8_t mulc_lo = vget_low_s8(divisor.val[0]); + // high part of signed multiplication + int16x8_t mull_lo = vmull_s8(vget_low_s8(a), mulc_lo); +#if NPY_SIMD_F64 + int16x8_t mull_hi = vmull_high_s8(a, divisor.val[0]); + // get the high unsigned bytes + int8x16_t mulhi = vuzp2q_s8(vreinterpretq_s8_s16(mull_lo), vreinterpretq_s8_s16(mull_hi)); +#else + const int8x8_t mulc_hi = vget_high_s8(divisor.val[0]); + int16x8_t mull_hi = vmull_s8(vget_high_s8(a), mulc_hi); + int8x16_t mulhi = vuzpq_s8(vreinterpretq_s8_s16(mull_lo), vreinterpretq_s8_s16(mull_hi)).val[1]; +#endif + // q = ((a + mulhi) >> sh1) - XSIGN(a) + // trunc(a/d) = (q ^ dsign) - dsign + int8x16_t q = vshlq_s8(vaddq_s8(a, mulhi), divisor.val[1]); + q = vsubq_s8(q, vshrq_n_s8(a, 7)); + q = vsubq_s8(veorq_s8(q, divisor.val[2]), divisor.val[2]); + return q; +} +// divide each unsigned 16-bit element by a precomputed divisor +NPY_FINLINE npyv_u16 npyv_divc_u16(npyv_u16 a, const npyv_u16x3 divisor) +{ + const uint16x4_t mulc_lo = vget_low_u16(divisor.val[0]); + // high part of unsigned multiplication + uint32x4_t mull_lo = vmull_u16(vget_low_u16(a), mulc_lo); +#if NPY_SIMD_F64 + uint32x4_t mull_hi = vmull_high_u16(a, divisor.val[0]); + // get the high unsigned bytes + uint16x8_t mulhi = vuzp2q_u16(vreinterpretq_u16_u32(mull_lo), vreinterpretq_u16_u32(mull_hi)); +#else + const uint16x4_t mulc_hi = vget_high_u16(divisor.val[0]); + uint32x4_t mull_hi = vmull_u16(vget_high_u16(a), mulc_hi); + uint16x8_t mulhi = vuzpq_u16(vreinterpretq_u16_u32(mull_lo), vreinterpretq_u16_u32(mull_hi)).val[1]; +#endif + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + uint16x8_t q = vsubq_u16(a, mulhi); + q = vshlq_u16(q, vreinterpretq_s16_u16(divisor.val[1])); + q = vaddq_u16(mulhi, q); + q = vshlq_u16(q, vreinterpretq_s16_u16(divisor.val[2])); + return q; +} +// divide each signed 16-bit element by a precomputed divisor (round towards zero) +NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor) +{ + const int16x4_t mulc_lo = vget_low_s16(divisor.val[0]); + // high part of signed multiplication + int32x4_t mull_lo = vmull_s16(vget_low_s16(a), 
mulc_lo); +#if NPY_SIMD_F64 + int32x4_t mull_hi = vmull_high_s16(a, divisor.val[0]); + // get the high unsigned bytes + int16x8_t mulhi = vuzp2q_s16(vreinterpretq_s16_s32(mull_lo), vreinterpretq_s16_s32(mull_hi)); +#else + const int16x4_t mulc_hi = vget_high_s16(divisor.val[0]); + int32x4_t mull_hi = vmull_s16(vget_high_s16(a), mulc_hi); + int16x8_t mulhi = vuzpq_s16(vreinterpretq_s16_s32(mull_lo), vreinterpretq_s16_s32(mull_hi)).val[1]; +#endif + // q = ((a + mulhi) >> sh1) - XSIGN(a) + // trunc(a/d) = (q ^ dsign) - dsign + int16x8_t q = vshlq_s16(vaddq_s16(a, mulhi), divisor.val[1]); + q = vsubq_s16(q, vshrq_n_s16(a, 15)); + q = vsubq_s16(veorq_s16(q, divisor.val[2]), divisor.val[2]); + return q; +} +// divide each unsigned 32-bit element by a precomputed divisor +NPY_FINLINE npyv_u32 npyv_divc_u32(npyv_u32 a, const npyv_u32x3 divisor) +{ + const uint32x2_t mulc_lo = vget_low_u32(divisor.val[0]); + // high part of unsigned multiplication + uint64x2_t mull_lo = vmull_u32(vget_low_u32(a), mulc_lo); +#if NPY_SIMD_F64 + uint64x2_t mull_hi = vmull_high_u32(a, divisor.val[0]); + // get the high unsigned bytes + uint32x4_t mulhi = vuzp2q_u32(vreinterpretq_u32_u64(mull_lo), vreinterpretq_u32_u64(mull_hi)); +#else + const uint32x2_t mulc_hi = vget_high_u32(divisor.val[0]); + uint64x2_t mull_hi = vmull_u32(vget_high_u32(a), mulc_hi); + uint32x4_t mulhi = vuzpq_u32(vreinterpretq_u32_u64(mull_lo), vreinterpretq_u32_u64(mull_hi)).val[1]; +#endif + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + uint32x4_t q = vsubq_u32(a, mulhi); + q = vshlq_u32(q, vreinterpretq_s32_u32(divisor.val[1])); + q = vaddq_u32(mulhi, q); + q = vshlq_u32(q, vreinterpretq_s32_u32(divisor.val[2])); + return q; +} +// divide each signed 32-bit element by a precomputed divisor (round towards zero) +NPY_FINLINE npyv_s32 npyv_divc_s32(npyv_s32 a, const npyv_s32x3 divisor) +{ + const int32x2_t mulc_lo = vget_low_s32(divisor.val[0]); + // high part of signed multiplication + int64x2_t mull_lo = vmull_s32(vget_low_s32(a), mulc_lo); +#if NPY_SIMD_F64 + int64x2_t mull_hi = vmull_high_s32(a, divisor.val[0]); + // get the high unsigned bytes + int32x4_t mulhi = vuzp2q_s32(vreinterpretq_s32_s64(mull_lo), vreinterpretq_s32_s64(mull_hi)); +#else + const int32x2_t mulc_hi = vget_high_s32(divisor.val[0]); + int64x2_t mull_hi = vmull_s32(vget_high_s32(a), mulc_hi); + int32x4_t mulhi = vuzpq_s32(vreinterpretq_s32_s64(mull_lo), vreinterpretq_s32_s64(mull_hi)).val[1]; +#endif + // q = ((a + mulhi) >> sh1) - XSIGN(a) + // trunc(a/d) = (q ^ dsign) - dsign + int32x4_t q = vshlq_s32(vaddq_s32(a, mulhi), divisor.val[1]); + q = vsubq_s32(q, vshrq_n_s32(a, 31)); + q = vsubq_s32(veorq_s32(q, divisor.val[2]), divisor.val[2]); + return q; +} +// divide each unsigned 64-bit element by a divisor +NPY_FINLINE npyv_u64 npyv_divc_u64(npyv_u64 a, const npyv_u64x3 divisor) +{ + const uint64_t d = vgetq_lane_u64(divisor.val[0], 0); + return npyv_set_u64(vgetq_lane_u64(a, 0) / d, vgetq_lane_u64(a, 1) / d); +} +// returns the high 64 bits of signed 64-bit multiplication +NPY_FINLINE npyv_s64 npyv_divc_s64(npyv_s64 a, const npyv_s64x3 divisor) +{ + const int64_t d = vgetq_lane_s64(divisor.val[0], 0); + return npyv_set_s64(vgetq_lane_s64(a, 0) / d, vgetq_lane_s64(a, 1) / d); +} +/*************************** + * Division + ***************************/ +#if NPY_SIMD_F64 + #define npyv_div_f32 vdivq_f32 +#else + NPY_FINLINE npyv_f32 npyv_div_f32(npyv_f32 a, npyv_f32 b) + { + // Based on ARM doc, see https://developer.arm.com/documentation/dui0204/j/CIHDIACI + // 
estimate to 1/b + npyv_f32 recipe = vrecpeq_f32(b); + /** + * Newton-Raphson iteration: + * x[n+1] = x[n] * (2-d * x[n]) + * converges to (1/d) if x0 is the result of VRECPE applied to d. + * + * NOTE: at least 3 iterations is needed to improve precision + */ + recipe = vmulq_f32(vrecpsq_f32(b, recipe), recipe); + recipe = vmulq_f32(vrecpsq_f32(b, recipe), recipe); + recipe = vmulq_f32(vrecpsq_f32(b, recipe), recipe); + // a/b = a*recip(b) + return vmulq_f32(a, recipe); + } +#endif +#define npyv_div_f64 vdivq_f64 + +/*************************** + * FUSED F32 + ***************************/ +#ifdef NPY_HAVE_NEON_VFPV4 // FMA + // multiply and add, a*b + c + NPY_FINLINE npyv_f32 npyv_muladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { return vfmaq_f32(c, a, b); } + // multiply and subtract, a*b - c + NPY_FINLINE npyv_f32 npyv_mulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { return vfmaq_f32(vnegq_f32(c), a, b); } + // negate multiply and add, -(a*b) + c + NPY_FINLINE npyv_f32 npyv_nmuladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { return vfmsq_f32(c, a, b); } + // negate multiply and subtract, -(a*b) - c + NPY_FINLINE npyv_f32 npyv_nmulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { return vfmsq_f32(vnegq_f32(c), a, b); } +#else + // multiply and add, a*b + c + NPY_FINLINE npyv_f32 npyv_muladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { return vmlaq_f32(c, a, b); } + // multiply and subtract, a*b - c + NPY_FINLINE npyv_f32 npyv_mulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { return vmlaq_f32(vnegq_f32(c), a, b); } + // negate multiply and add, -(a*b) + c + NPY_FINLINE npyv_f32 npyv_nmuladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { return vmlsq_f32(c, a, b); } + // negate multiply and subtract, -(a*b) - c + NPY_FINLINE npyv_f32 npyv_nmulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { return vmlsq_f32(vnegq_f32(c), a, b); } +#endif +// multiply, add for odd elements and subtract even elements. 
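As a reading aid for the add/subtract-interleaved variant that follows (its sign mask {-0.0f, 0.0f, -0.0f, 0.0f} flips the sign of c in the even lanes only), a scalar model of one 4-lane block might look like the sketch below; the reference function name is illustrative and not part of the header.

    /* Scalar model of npyv_muladdsub_f32: even lanes compute a*b - c, odd lanes a*b + c. */
    static void muladdsub_ref_f32(const float *a, const float *b, const float *c, float *r)
    {
        for (int i = 0; i < 4; ++i) {
            r[i] = a[i] * b[i] + ((i & 1) ? c[i] : -c[i]);
        }
    }

This alternating sign pattern is the shape used, for example, by multiply kernels over interleaved real/imaginary data.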
+// (a * b) -+ c +NPY_FINLINE npyv_f32 npyv_muladdsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) +{ + const npyv_f32 msign = npyv_set_f32(-0.0f, 0.0f, -0.0f, 0.0f); + return npyv_muladd_f32(a, b, npyv_xor_f32(msign, c)); +} + +/*************************** + * FUSED F64 + ***************************/ +#if NPY_SIMD_F64 + NPY_FINLINE npyv_f64 npyv_muladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { return vfmaq_f64(c, a, b); } + NPY_FINLINE npyv_f64 npyv_mulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { return vfmaq_f64(vnegq_f64(c), a, b); } + NPY_FINLINE npyv_f64 npyv_nmuladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { return vfmsq_f64(c, a, b); } + NPY_FINLINE npyv_f64 npyv_nmulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { return vfmsq_f64(vnegq_f64(c), a, b); } + NPY_FINLINE npyv_f64 npyv_muladdsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { + const npyv_f64 msign = npyv_set_f64(-0.0, 0.0); + return npyv_muladd_f64(a, b, npyv_xor_f64(msign, c)); + } +#endif // NPY_SIMD_F64 + +/*************************** + * Summation + ***************************/ +// reduce sum across vector +#if NPY_SIMD_F64 + #define npyv_sum_u32 vaddvq_u32 + #define npyv_sum_u64 vaddvq_u64 + #define npyv_sum_f32 vaddvq_f32 + #define npyv_sum_f64 vaddvq_f64 +#else + NPY_FINLINE npy_uint64 npyv_sum_u64(npyv_u64 a) + { + return vget_lane_u64(vadd_u64(vget_low_u64(a), vget_high_u64(a)),0); + } + + NPY_FINLINE npy_uint32 npyv_sum_u32(npyv_u32 a) + { + uint32x2_t a0 = vpadd_u32(vget_low_u32(a), vget_high_u32(a)); + return (unsigned)vget_lane_u32(vpadd_u32(a0, vget_high_u32(a)),0); + } + + NPY_FINLINE float npyv_sum_f32(npyv_f32 a) + { + float32x2_t r = vadd_f32(vget_high_f32(a), vget_low_f32(a)); + return vget_lane_f32(vpadd_f32(r, r), 0); + } +#endif + +// expand the source vector and performs sum reduce +#if NPY_SIMD_F64 + #define npyv_sumup_u8 vaddlvq_u8 + #define npyv_sumup_u16 vaddlvq_u16 +#else + NPY_FINLINE npy_uint16 npyv_sumup_u8(npyv_u8 a) + { + uint32x4_t t0 = vpaddlq_u16(vpaddlq_u8(a)); + uint32x2_t t1 = vpadd_u32(vget_low_u32(t0), vget_high_u32(t0)); + return vget_lane_u32(vpadd_u32(t1, t1), 0); + } + + NPY_FINLINE npy_uint32 npyv_sumup_u16(npyv_u16 a) + { + uint32x4_t t0 = vpaddlq_u16(a); + uint32x2_t t1 = vpadd_u32(vget_low_u32(t0), vget_high_u32(t0)); + return vget_lane_u32(vpadd_u32(t1, t1), 0); + } +#endif + +#endif // _NPY_SIMD_NEON_ARITHMETIC_H diff --git a/mkl_umath/src/npyv/neon/conversion.h b/mkl_umath/src/npyv/neon/conversion.h new file mode 100644 index 00000000..03a5259f --- /dev/null +++ b/mkl_umath/src/npyv/neon/conversion.h @@ -0,0 +1,150 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_NEON_CVT_H +#define _NPY_SIMD_NEON_CVT_H + +// convert boolean vectors to integer vectors +#define npyv_cvt_u8_b8(A) A +#define npyv_cvt_s8_b8 vreinterpretq_s8_u8 +#define npyv_cvt_u16_b16(A) A +#define npyv_cvt_s16_b16 vreinterpretq_s16_u16 +#define npyv_cvt_u32_b32(A) A +#define npyv_cvt_s32_b32 vreinterpretq_s32_u32 +#define npyv_cvt_u64_b64(A) A +#define npyv_cvt_s64_b64 vreinterpretq_s64_u64 +#define npyv_cvt_f32_b32 vreinterpretq_f32_u32 +#define npyv_cvt_f64_b64 vreinterpretq_f64_u64 + +// convert integer vectors to boolean vectors +#define npyv_cvt_b8_u8(BL) BL +#define npyv_cvt_b8_s8 vreinterpretq_u8_s8 +#define npyv_cvt_b16_u16(BL) BL +#define npyv_cvt_b16_s16 vreinterpretq_u16_s16 +#define npyv_cvt_b32_u32(BL) BL +#define npyv_cvt_b32_s32 vreinterpretq_u32_s32 +#define npyv_cvt_b64_u64(BL) BL +#define npyv_cvt_b64_s64 vreinterpretq_u64_s64 +#define npyv_cvt_b32_f32 
vreinterpretq_u32_f32 +#define npyv_cvt_b64_f64 vreinterpretq_u64_f64 + +// convert boolean vector to integer bitfield +NPY_FINLINE npy_uint64 npyv_tobits_b8(npyv_b8 a) +{ + const npyv_u8 scale = npyv_set_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + npyv_u8 seq_scale = vandq_u8(a, scale); +#if defined(__aarch64__) + const npyv_u8 byteOrder = {0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15}; + npyv_u8 v0 = vqtbl1q_u8(seq_scale, byteOrder); + return vaddlvq_u16(vreinterpretq_u16_u8(v0)); +#else + npyv_u64 sumh = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(seq_scale))); + return vgetq_lane_u64(sumh, 0) + ((int)vgetq_lane_u64(sumh, 1) << 8); +#endif +} +NPY_FINLINE npy_uint64 npyv_tobits_b16(npyv_b16 a) +{ + const npyv_u16 scale = npyv_set_u16(1, 2, 4, 8, 16, 32, 64, 128); + npyv_u16 seq_scale = vandq_u16(a, scale); +#if NPY_SIMD_F64 + return vaddvq_u16(seq_scale); +#else + npyv_u64 sumh = vpaddlq_u32(vpaddlq_u16(seq_scale)); + return vgetq_lane_u64(sumh, 0) + vgetq_lane_u64(sumh, 1); +#endif +} +NPY_FINLINE npy_uint64 npyv_tobits_b32(npyv_b32 a) +{ + const npyv_u32 scale = npyv_set_u32(1, 2, 4, 8); + npyv_u32 seq_scale = vandq_u32(a, scale); +#if NPY_SIMD_F64 + return vaddvq_u32(seq_scale); +#else + npyv_u64 sumh = vpaddlq_u32(seq_scale); + return vgetq_lane_u64(sumh, 0) + vgetq_lane_u64(sumh, 1); +#endif +} +NPY_FINLINE npy_uint64 npyv_tobits_b64(npyv_b64 a) +{ + uint64_t lo = vgetq_lane_u64(a, 0); + uint64_t hi = vgetq_lane_u64(a, 1); + return ((hi & 0x2) | (lo & 0x1)); +} + +//expand +NPY_FINLINE npyv_u16x2 npyv_expand_u16_u8(npyv_u8 data) { + npyv_u16x2 r; + r.val[0] = vmovl_u8(vget_low_u8(data)); + r.val[1] = vmovl_u8(vget_high_u8(data)); + return r; +} + +NPY_FINLINE npyv_u32x2 npyv_expand_u32_u16(npyv_u16 data) { + npyv_u32x2 r; + r.val[0] = vmovl_u16(vget_low_u16(data)); + r.val[1] = vmovl_u16(vget_high_u16(data)); + return r; +} + +// pack two 16-bit boolean into one 8-bit boolean vector +NPY_FINLINE npyv_b8 npyv_pack_b8_b16(npyv_b16 a, npyv_b16 b) { +#if defined(__aarch64__) + return vuzp1q_u8((uint8x16_t)a, (uint8x16_t)b); +#else + return vcombine_u8(vmovn_u16(a), vmovn_u16(b)); +#endif +} + +// pack four 32-bit boolean vectors into one 8-bit boolean vector +NPY_FINLINE npyv_b8 +npyv_pack_b8_b32(npyv_b32 a, npyv_b32 b, npyv_b32 c, npyv_b32 d) { +#if defined(__aarch64__) + npyv_b16 ab = vuzp1q_u16((uint16x8_t)a, (uint16x8_t)b); + npyv_b16 cd = vuzp1q_u16((uint16x8_t)c, (uint16x8_t)d); +#else + npyv_b16 ab = vcombine_u16(vmovn_u32(a), vmovn_u32(b)); + npyv_b16 cd = vcombine_u16(vmovn_u32(c), vmovn_u32(d)); +#endif + return npyv_pack_b8_b16(ab, cd); +} + +// pack eight 64-bit boolean vectors into one 8-bit boolean vector +NPY_FINLINE npyv_b8 +npyv_pack_b8_b64(npyv_b64 a, npyv_b64 b, npyv_b64 c, npyv_b64 d, + npyv_b64 e, npyv_b64 f, npyv_b64 g, npyv_b64 h) { +#if defined(__aarch64__) + npyv_b32 ab = vuzp1q_u32((uint32x4_t)a, (uint32x4_t)b); + npyv_b32 cd = vuzp1q_u32((uint32x4_t)c, (uint32x4_t)d); + npyv_b32 ef = vuzp1q_u32((uint32x4_t)e, (uint32x4_t)f); + npyv_u32 gh = vuzp1q_u32((uint32x4_t)g, (uint32x4_t)h); +#else + npyv_b32 ab = vcombine_u32(vmovn_u64(a), vmovn_u64(b)); + npyv_b32 cd = vcombine_u32(vmovn_u64(c), vmovn_u64(d)); + npyv_b32 ef = vcombine_u32(vmovn_u64(e), vmovn_u64(f)); + npyv_b32 gh = vcombine_u32(vmovn_u64(g), vmovn_u64(h)); +#endif + return npyv_pack_b8_b32(ab, cd, ef, gh); +} + +// round to nearest integer +#if NPY_SIMD_F64 + #define npyv_round_s32_f32 vcvtnq_s32_f32 + NPY_FINLINE npyv_s32 npyv_round_s32_f64(npyv_f64 a, npyv_f64 b) + { + npyv_s64 lo = 
vcvtnq_s64_f64(a), hi = vcvtnq_s64_f64(b); + return vcombine_s32(vmovn_s64(lo), vmovn_s64(hi)); + } +#else + NPY_FINLINE npyv_s32 npyv_round_s32_f32(npyv_f32 a) + { + // halves will be rounded up. it's very costly + // to obey IEEE standard on arm7. tests should pass +-1 difference + const npyv_u32 sign = vdupq_n_u32(0x80000000); + const npyv_f32 half = vdupq_n_f32(0.5f); + npyv_f32 sign_half = vbslq_f32(sign, a, half); + return vcvtq_s32_f32(vaddq_f32(a, sign_half)); + } +#endif + +#endif // _NPY_SIMD_NEON_CVT_H diff --git a/mkl_umath/src/npyv/neon/math.h b/mkl_umath/src/npyv/neon/math.h new file mode 100644 index 00000000..58d14809 --- /dev/null +++ b/mkl_umath/src/npyv/neon/math.h @@ -0,0 +1,416 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_NEON_MATH_H +#define _NPY_SIMD_NEON_MATH_H + +/*************************** + * Elementary + ***************************/ +// Absolute +#define npyv_abs_f32 vabsq_f32 +#define npyv_abs_f64 vabsq_f64 + +// Square +NPY_FINLINE npyv_f32 npyv_square_f32(npyv_f32 a) +{ return vmulq_f32(a, a); } +#if NPY_SIMD_F64 + NPY_FINLINE npyv_f64 npyv_square_f64(npyv_f64 a) + { return vmulq_f64(a, a); } +#endif + +// Square root +#if NPY_SIMD_F64 + #define npyv_sqrt_f32 vsqrtq_f32 + #define npyv_sqrt_f64 vsqrtq_f64 +#else + // Based on ARM doc, see https://developer.arm.com/documentation/dui0204/j/CIHDIACI + NPY_FINLINE npyv_f32 npyv_sqrt_f32(npyv_f32 a) + { + const npyv_f32 zero = vdupq_n_f32(0.0f); + const npyv_u32 pinf = vdupq_n_u32(0x7f800000); + npyv_u32 is_zero = vceqq_f32(a, zero), is_inf = vceqq_u32(vreinterpretq_u32_f32(a), pinf); + // guard against floating-point division-by-zero error + npyv_f32 guard_byz = vbslq_f32(is_zero, vreinterpretq_f32_u32(pinf), a); + // estimate to (1/√a) + npyv_f32 rsqrte = vrsqrteq_f32(guard_byz); + /** + * Newton-Raphson iteration: + * x[n+1] = x[n] * (3-d * (x[n]*x[n]) )/2) + * converges to (1/√d)if x0 is the result of VRSQRTE applied to d. + * + * NOTE: at least 3 iterations is needed to improve precision + */ + rsqrte = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, rsqrte), rsqrte), rsqrte); + rsqrte = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, rsqrte), rsqrte), rsqrte); + rsqrte = vmulq_f32(vrsqrtsq_f32(vmulq_f32(a, rsqrte), rsqrte), rsqrte); + // a * (1/√a) + npyv_f32 sqrt = vmulq_f32(a, rsqrte); + // return zero if the a is zero + // - return zero if a is zero. + // - return positive infinity if a is positive infinity + return vbslq_f32(vorrq_u32(is_zero, is_inf), a, sqrt); + } +#endif // NPY_SIMD_F64 + +// Reciprocal +NPY_FINLINE npyv_f32 npyv_recip_f32(npyv_f32 a) +{ +#if NPY_SIMD_F64 + const npyv_f32 one = vdupq_n_f32(1.0f); + return npyv_div_f32(one, a); +#else + npyv_f32 recipe = vrecpeq_f32(a); + /** + * Newton-Raphson iteration: + * x[n+1] = x[n] * (2-d * x[n]) + * converges to (1/d) if x0 is the result of VRECPE applied to d. + * + * NOTE: at least 3 iterations is needed to improve precision + */ + recipe = vmulq_f32(vrecpsq_f32(a, recipe), recipe); + recipe = vmulq_f32(vrecpsq_f32(a, recipe), recipe); + recipe = vmulq_f32(vrecpsq_f32(a, recipe), recipe); + return recipe; +#endif +} +#if NPY_SIMD_F64 + NPY_FINLINE npyv_f64 npyv_recip_f64(npyv_f64 a) + { + const npyv_f64 one = vdupq_n_f64(1.0); + return npyv_div_f64(one, a); + } +#endif // NPY_SIMD_F64 + +// Maximum, natively mapping with no guarantees to handle NaN. 
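The max/min families that follow come in three NaN flavors: the plain native mapping (no guarantee), a "p" variant that prefers the non-NaN operand, and an "n" variant that propagates NaN. A scalar model of the "p" behavior described in the comments below, with an illustrative helper name, is simply:

    /* Scalar model of npyv_maxp_*: NaN is returned only when both inputs are NaN. */
    static float maxp_ref_f32(float a, float b)
    {
        if (a != a) return b;        /* a is NaN -> take b */
        if (b != b) return a;        /* b is NaN -> take a */
        return a > b ? a : b;
    }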
+#define npyv_max_f32 vmaxq_f32 +#define npyv_max_f64 vmaxq_f64 +// Maximum, supports IEEE floating-point arithmetic (IEC 60559), +// - If one of the two vectors contains NaN, the equivalent element of the other vector is set +// - Only if both corresponded elements are NaN, NaN is set. +#ifdef NPY_HAVE_ASIMD + #define npyv_maxp_f32 vmaxnmq_f32 +#else + NPY_FINLINE npyv_f32 npyv_maxp_f32(npyv_f32 a, npyv_f32 b) + { + npyv_u32 nn_a = vceqq_f32(a, a); + npyv_u32 nn_b = vceqq_f32(b, b); + return vmaxq_f32(vbslq_f32(nn_a, a, b), vbslq_f32(nn_b, b, a)); + } +#endif +// Max, propagates NaNs +// If any of corresponded element is NaN, NaN is set. +#define npyv_maxn_f32 vmaxq_f32 +#if NPY_SIMD_F64 + #define npyv_maxp_f64 vmaxnmq_f64 + #define npyv_maxn_f64 vmaxq_f64 +#endif // NPY_SIMD_F64 +// Maximum, integer operations +#define npyv_max_u8 vmaxq_u8 +#define npyv_max_s8 vmaxq_s8 +#define npyv_max_u16 vmaxq_u16 +#define npyv_max_s16 vmaxq_s16 +#define npyv_max_u32 vmaxq_u32 +#define npyv_max_s32 vmaxq_s32 +NPY_FINLINE npyv_u64 npyv_max_u64(npyv_u64 a, npyv_u64 b) +{ + return vbslq_u64(npyv_cmpgt_u64(a, b), a, b); +} +NPY_FINLINE npyv_s64 npyv_max_s64(npyv_s64 a, npyv_s64 b) +{ + return vbslq_s64(npyv_cmpgt_s64(a, b), a, b); +} + +// Minimum, natively mapping with no guarantees to handle NaN. +#define npyv_min_f32 vminq_f32 +#define npyv_min_f64 vminq_f64 +// Minimum, supports IEEE floating-point arithmetic (IEC 60559), +// - If one of the two vectors contains NaN, the equivalent element of the other vector is set +// - Only if both corresponded elements are NaN, NaN is set. +#ifdef NPY_HAVE_ASIMD + #define npyv_minp_f32 vminnmq_f32 +#else + NPY_FINLINE npyv_f32 npyv_minp_f32(npyv_f32 a, npyv_f32 b) + { + npyv_u32 nn_a = vceqq_f32(a, a); + npyv_u32 nn_b = vceqq_f32(b, b); + return vminq_f32(vbslq_f32(nn_a, a, b), vbslq_f32(nn_b, b, a)); + } +#endif +// Min, propagates NaNs +// If any of corresponded element is NaN, NaN is set. 
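The armv7 fallbacks for the "p" variants above build their lane masks from the fact that a value compares unequal to itself only when it is NaN: vceqq_f32(a, a) is all-ones exactly where a is not NaN, and vbslq_f32 then substitutes the other operand in the NaN lanes. In scalar form, with an illustrative name:

    /* The self-comparison trick behind the vceqq/vbslq fallback. */
    static float pick_if_nan(float a, float b)
    {
        int a_not_nan = (a == a);    /* false only when a is NaN */
        return a_not_nan ? a : b;    /* what vbslq_f32(nn_a, a, b) does per lane */
    }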
+#define npyv_minn_f32 vminq_f32 +#if NPY_SIMD_F64 + #define npyv_minp_f64 vminnmq_f64 + #define npyv_minn_f64 vminq_f64 +#endif // NPY_SIMD_F64 + +// Minimum, integer operations +#define npyv_min_u8 vminq_u8 +#define npyv_min_s8 vminq_s8 +#define npyv_min_u16 vminq_u16 +#define npyv_min_s16 vminq_s16 +#define npyv_min_u32 vminq_u32 +#define npyv_min_s32 vminq_s32 +NPY_FINLINE npyv_u64 npyv_min_u64(npyv_u64 a, npyv_u64 b) +{ + return vbslq_u64(npyv_cmplt_u64(a, b), a, b); +} +NPY_FINLINE npyv_s64 npyv_min_s64(npyv_s64 a, npyv_s64 b) +{ + return vbslq_s64(npyv_cmplt_s64(a, b), a, b); +} +// reduce min/max for all data types +#if NPY_SIMD_F64 + #define npyv_reduce_max_u8 vmaxvq_u8 + #define npyv_reduce_max_s8 vmaxvq_s8 + #define npyv_reduce_max_u16 vmaxvq_u16 + #define npyv_reduce_max_s16 vmaxvq_s16 + #define npyv_reduce_max_u32 vmaxvq_u32 + #define npyv_reduce_max_s32 vmaxvq_s32 + + #define npyv_reduce_max_f32 vmaxvq_f32 + #define npyv_reduce_max_f64 vmaxvq_f64 + #define npyv_reduce_maxn_f32 vmaxvq_f32 + #define npyv_reduce_maxn_f64 vmaxvq_f64 + #define npyv_reduce_maxp_f32 vmaxnmvq_f32 + #define npyv_reduce_maxp_f64 vmaxnmvq_f64 + + #define npyv_reduce_min_u8 vminvq_u8 + #define npyv_reduce_min_s8 vminvq_s8 + #define npyv_reduce_min_u16 vminvq_u16 + #define npyv_reduce_min_s16 vminvq_s16 + #define npyv_reduce_min_u32 vminvq_u32 + #define npyv_reduce_min_s32 vminvq_s32 + + #define npyv_reduce_min_f32 vminvq_f32 + #define npyv_reduce_min_f64 vminvq_f64 + #define npyv_reduce_minn_f32 vminvq_f32 + #define npyv_reduce_minn_f64 vminvq_f64 + #define npyv_reduce_minp_f32 vminnmvq_f32 + #define npyv_reduce_minp_f64 vminnmvq_f64 +#else + #define NPY_IMPL_NEON_REDUCE_MINMAX(INTRIN, STYPE, SFX) \ + NPY_FINLINE npy_##STYPE npyv_reduce_##INTRIN##_##SFX(npyv_##SFX a) \ + { \ + STYPE##x8_t r = vp##INTRIN##_##SFX(vget_low_##SFX(a), vget_high_##SFX(a)); \ + r = vp##INTRIN##_##SFX(r, r); \ + r = vp##INTRIN##_##SFX(r, r); \ + r = vp##INTRIN##_##SFX(r, r); \ + return (npy_##STYPE)vget_lane_##SFX(r, 0); \ + } + NPY_IMPL_NEON_REDUCE_MINMAX(min, uint8, u8) + NPY_IMPL_NEON_REDUCE_MINMAX(max, uint8, u8) + NPY_IMPL_NEON_REDUCE_MINMAX(min, int8, s8) + NPY_IMPL_NEON_REDUCE_MINMAX(max, int8, s8) + #undef NPY_IMPL_NEON_REDUCE_MINMAX + + #define NPY_IMPL_NEON_REDUCE_MINMAX(INTRIN, STYPE, SFX) \ + NPY_FINLINE npy_##STYPE npyv_reduce_##INTRIN##_##SFX(npyv_##SFX a) \ + { \ + STYPE##x4_t r = vp##INTRIN##_##SFX(vget_low_##SFX(a), vget_high_##SFX(a)); \ + r = vp##INTRIN##_##SFX(r, r); \ + r = vp##INTRIN##_##SFX(r, r); \ + return (npy_##STYPE)vget_lane_##SFX(r, 0); \ + } + NPY_IMPL_NEON_REDUCE_MINMAX(min, uint16, u16) + NPY_IMPL_NEON_REDUCE_MINMAX(max, uint16, u16) + NPY_IMPL_NEON_REDUCE_MINMAX(min, int16, s16) + NPY_IMPL_NEON_REDUCE_MINMAX(max, int16, s16) + #undef NPY_IMPL_NEON_REDUCE_MINMAX + + #define NPY_IMPL_NEON_REDUCE_MINMAX(INTRIN, STYPE, SFX) \ + NPY_FINLINE npy_##STYPE npyv_reduce_##INTRIN##_##SFX(npyv_##SFX a) \ + { \ + STYPE##x2_t r = vp##INTRIN##_##SFX(vget_low_##SFX(a), vget_high_##SFX(a)); \ + r = vp##INTRIN##_##SFX(r, r); \ + return (npy_##STYPE)vget_lane_##SFX(r, 0); \ + } + NPY_IMPL_NEON_REDUCE_MINMAX(min, uint32, u32) + NPY_IMPL_NEON_REDUCE_MINMAX(max, uint32, u32) + NPY_IMPL_NEON_REDUCE_MINMAX(min, int32, s32) + NPY_IMPL_NEON_REDUCE_MINMAX(max, int32, s32) + #undef NPY_IMPL_NEON_REDUCE_MINMAX + + #define NPY_IMPL_NEON_REDUCE_MINMAX(INTRIN, INF) \ + NPY_FINLINE float npyv_reduce_##INTRIN##_f32(npyv_f32 a) \ + { \ + float32x2_t r = vp##INTRIN##_f32(vget_low_f32(a), vget_high_f32(a));\ + r = 
vp##INTRIN##_f32(r, r); \ + return vget_lane_f32(r, 0); \ + } \ + NPY_FINLINE float npyv_reduce_##INTRIN##p_f32(npyv_f32 a) \ + { \ + npyv_b32 notnan = npyv_notnan_f32(a); \ + if (NPY_UNLIKELY(!npyv_any_b32(notnan))) { \ + return vgetq_lane_f32(a, 0); \ + } \ + a = npyv_select_f32(notnan, a, \ + npyv_reinterpret_f32_u32(npyv_setall_u32(INF))); \ + return npyv_reduce_##INTRIN##_f32(a); \ + } \ + NPY_FINLINE float npyv_reduce_##INTRIN##n_f32(npyv_f32 a) \ + { \ + return npyv_reduce_##INTRIN##_f32(a); \ + } + NPY_IMPL_NEON_REDUCE_MINMAX(min, 0x7f800000) + NPY_IMPL_NEON_REDUCE_MINMAX(max, 0xff800000) + #undef NPY_IMPL_NEON_REDUCE_MINMAX +#endif // NPY_SIMD_F64 +#define NPY_IMPL_NEON_REDUCE_MINMAX(INTRIN, STYPE, SFX, OP) \ + NPY_FINLINE STYPE npyv_reduce_##INTRIN##_##SFX(npyv_##SFX a) \ + { \ + STYPE al = (STYPE)vget_low_##SFX(a); \ + STYPE ah = (STYPE)vget_high_##SFX(a); \ + return al OP ah ? al : ah; \ + } +NPY_IMPL_NEON_REDUCE_MINMAX(max, npy_uint64, u64, >) +NPY_IMPL_NEON_REDUCE_MINMAX(max, npy_int64, s64, >) +NPY_IMPL_NEON_REDUCE_MINMAX(min, npy_uint64, u64, <) +NPY_IMPL_NEON_REDUCE_MINMAX(min, npy_int64, s64, <) +#undef NPY_IMPL_NEON_REDUCE_MINMAX + +// round to nearest integer even +NPY_FINLINE npyv_f32 npyv_rint_f32(npyv_f32 a) +{ +#ifdef NPY_HAVE_ASIMD + return vrndnq_f32(a); +#else + // ARMv7 NEON only supports fp to int truncate conversion. + // a magic trick of adding 1.5 * 2^23 is used for rounding + // to nearest even and then subtract this magic number to get + // the integer. + // + const npyv_u32 szero = vreinterpretq_u32_f32(vdupq_n_f32(-0.0f)); + const npyv_u32 sign_mask = vandq_u32(vreinterpretq_u32_f32(a), szero); + const npyv_f32 two_power_23 = vdupq_n_f32(8388608.0); // 2^23 + const npyv_f32 two_power_23h = vdupq_n_f32(12582912.0f); // 1.5 * 2^23 + npyv_u32 nnan_mask = vceqq_f32(a, a); + // eliminate nans to avoid invalid fp errors + npyv_f32 abs_x = vabsq_f32(vreinterpretq_f32_u32(vandq_u32(nnan_mask, vreinterpretq_u32_f32(a)))); + // round by add magic number 1.5 * 2^23 + npyv_f32 round = vsubq_f32(vaddq_f32(two_power_23h, abs_x), two_power_23h); + // copysign + round = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(round), sign_mask )); + // a if |a| >= 2^23 or a == NaN + npyv_u32 mask = vcleq_f32(abs_x, two_power_23); + mask = vandq_u32(mask, nnan_mask); + return vbslq_f32(mask, round, a); +#endif +} +#if NPY_SIMD_F64 + #define npyv_rint_f64 vrndnq_f64 +#endif // NPY_SIMD_F64 + +// ceil +#ifdef NPY_HAVE_ASIMD + #define npyv_ceil_f32 vrndpq_f32 +#else + NPY_FINLINE npyv_f32 npyv_ceil_f32(npyv_f32 a) + { + const npyv_u32 one = vreinterpretq_u32_f32(vdupq_n_f32(1.0f)); + const npyv_u32 szero = vreinterpretq_u32_f32(vdupq_n_f32(-0.0f)); + const npyv_u32 sign_mask = vandq_u32(vreinterpretq_u32_f32(a), szero); + const npyv_f32 two_power_23 = vdupq_n_f32(8388608.0); // 2^23 + const npyv_f32 two_power_23h = vdupq_n_f32(12582912.0f); // 1.5 * 2^23 + npyv_u32 nnan_mask = vceqq_f32(a, a); + npyv_f32 x = vreinterpretq_f32_u32(vandq_u32(nnan_mask, vreinterpretq_u32_f32(a))); + // eliminate nans to avoid invalid fp errors + npyv_f32 abs_x = vabsq_f32(x); + // round by add magic number 1.5 * 2^23 + npyv_f32 round = vsubq_f32(vaddq_f32(two_power_23h, abs_x), two_power_23h); + // copysign + round = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(round), sign_mask)); + npyv_f32 ceil = vaddq_f32(round, vreinterpretq_f32_u32( + vandq_u32(vcltq_f32(round, x), one)) + ); + // respects signed zero + ceil = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(ceil), 
sign_mask)); + // a if |a| >= 2^23 or a == NaN + npyv_u32 mask = vcleq_f32(abs_x, two_power_23); + mask = vandq_u32(mask, nnan_mask); + return vbslq_f32(mask, ceil, a); + } +#endif +#if NPY_SIMD_F64 + #define npyv_ceil_f64 vrndpq_f64 +#endif // NPY_SIMD_F64 + +// trunc +#ifdef NPY_HAVE_ASIMD + #define npyv_trunc_f32 vrndq_f32 +#else + NPY_FINLINE npyv_f32 npyv_trunc_f32(npyv_f32 a) + { + const npyv_s32 max_int = vdupq_n_s32(0x7fffffff); + const npyv_u32 exp_mask = vdupq_n_u32(0xff000000); + const npyv_s32 szero = vreinterpretq_s32_f32(vdupq_n_f32(-0.0f)); + const npyv_u32 sign_mask = vandq_u32( + vreinterpretq_u32_f32(a), vreinterpretq_u32_s32(szero)); + + npyv_u32 nfinite_mask = vshlq_n_u32(vreinterpretq_u32_f32(a), 1); + nfinite_mask = vandq_u32(nfinite_mask, exp_mask); + nfinite_mask = vceqq_u32(nfinite_mask, exp_mask); + // eliminate nans/inf to avoid invalid fp errors + npyv_f32 x = vreinterpretq_f32_u32( + veorq_u32(nfinite_mask, vreinterpretq_u32_f32(a))); + /** + * On armv7, vcvtq.f32 handles special cases as follows: + * NaN return 0 + * +inf or +outrange return 0x80000000(-0.0f) + * -inf or -outrange return 0x7fffffff(nan) + */ + npyv_s32 trunci = vcvtq_s32_f32(x); + npyv_f32 trunc = vcvtq_f32_s32(trunci); + // respect signed zero, e.g. -0.5 -> -0.0 + trunc = vreinterpretq_f32_u32( + vorrq_u32(vreinterpretq_u32_f32(trunc), sign_mask)); + // if overflow return a + npyv_u32 overflow_mask = vorrq_u32( + vceqq_s32(trunci, szero), vceqq_s32(trunci, max_int) + ); + // a if a overflow or nonfinite + return vbslq_f32(vorrq_u32(nfinite_mask, overflow_mask), a, trunc); + } +#endif +#if NPY_SIMD_F64 + #define npyv_trunc_f64 vrndq_f64 +#endif // NPY_SIMD_F64 + +// floor +#ifdef NPY_HAVE_ASIMD + #define npyv_floor_f32 vrndmq_f32 +#else + NPY_FINLINE npyv_f32 npyv_floor_f32(npyv_f32 a) + { + const npyv_u32 one = vreinterpretq_u32_f32(vdupq_n_f32(1.0f)); + const npyv_u32 szero = vreinterpretq_u32_f32(vdupq_n_f32(-0.0f)); + const npyv_u32 sign_mask = vandq_u32(vreinterpretq_u32_f32(a), szero); + const npyv_f32 two_power_23 = vdupq_n_f32(8388608.0); // 2^23 + const npyv_f32 two_power_23h = vdupq_n_f32(12582912.0f); // 1.5 * 2^23 + + npyv_u32 nnan_mask = vceqq_f32(a, a); + npyv_f32 x = vreinterpretq_f32_u32(vandq_u32(nnan_mask, vreinterpretq_u32_f32(a))); + // eliminate nans to avoid invalid fp errors + npyv_f32 abs_x = vabsq_f32(x); + // round by add magic number 1.5 * 2^23 + npyv_f32 round = vsubq_f32(vaddq_f32(two_power_23h, abs_x), two_power_23h); + // copysign + round = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(round), sign_mask)); + npyv_f32 floor = vsubq_f32(round, vreinterpretq_f32_u32( + vandq_u32(vcgtq_f32(round, x), one) + )); + // respects signed zero + floor = vreinterpretq_f32_u32(vorrq_u32(vreinterpretq_u32_f32(floor), sign_mask)); + // a if |a| >= 2^23 or a == NaN + npyv_u32 mask = vcleq_f32(abs_x, two_power_23); + mask = vandq_u32(mask, nnan_mask); + return vbslq_f32(mask, floor, a); + } +#endif // NPY_HAVE_ASIMD +#if NPY_SIMD_F64 + #define npyv_floor_f64 vrndmq_f64 +#endif // NPY_SIMD_F64 + +#endif // _NPY_SIMD_NEON_MATH_H diff --git a/mkl_umath/src/npyv/neon/memory.h b/mkl_umath/src/npyv/neon/memory.h new file mode 100644 index 00000000..e7503b82 --- /dev/null +++ b/mkl_umath/src/npyv/neon/memory.h @@ -0,0 +1,678 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_NEON_MEMORY_H +#define _NPY_SIMD_NEON_MEMORY_H + +#include "misc.h" + +/*************************** + * load/store + ***************************/ +// GCC requires literal 
type definitions for pointers types otherwise it causes ambiguous errors +#define NPYV_IMPL_NEON_MEM(SFX, CTYPE) \ + NPY_FINLINE npyv_##SFX npyv_load_##SFX(const npyv_lanetype_##SFX *ptr) \ + { return vld1q_##SFX((const CTYPE*)ptr); } \ + NPY_FINLINE npyv_##SFX npyv_loada_##SFX(const npyv_lanetype_##SFX *ptr) \ + { return vld1q_##SFX((const CTYPE*)ptr); } \ + NPY_FINLINE npyv_##SFX npyv_loads_##SFX(const npyv_lanetype_##SFX *ptr) \ + { return vld1q_##SFX((const CTYPE*)ptr); } \ + NPY_FINLINE npyv_##SFX npyv_loadl_##SFX(const npyv_lanetype_##SFX *ptr) \ + { \ + return vcombine_##SFX( \ + vld1_##SFX((const CTYPE*)ptr), vdup_n_##SFX(0) \ + ); \ + } \ + NPY_FINLINE void npyv_store_##SFX(npyv_lanetype_##SFX *ptr, npyv_##SFX vec) \ + { vst1q_##SFX((CTYPE*)ptr, vec); } \ + NPY_FINLINE void npyv_storea_##SFX(npyv_lanetype_##SFX *ptr, npyv_##SFX vec) \ + { vst1q_##SFX((CTYPE*)ptr, vec); } \ + NPY_FINLINE void npyv_stores_##SFX(npyv_lanetype_##SFX *ptr, npyv_##SFX vec) \ + { vst1q_##SFX((CTYPE*)ptr, vec); } \ + NPY_FINLINE void npyv_storel_##SFX(npyv_lanetype_##SFX *ptr, npyv_##SFX vec) \ + { vst1_##SFX((CTYPE*)ptr, vget_low_##SFX(vec)); } \ + NPY_FINLINE void npyv_storeh_##SFX(npyv_lanetype_##SFX *ptr, npyv_##SFX vec) \ + { vst1_##SFX((CTYPE*)ptr, vget_high_##SFX(vec)); } + +NPYV_IMPL_NEON_MEM(u8, uint8_t) +NPYV_IMPL_NEON_MEM(s8, int8_t) +NPYV_IMPL_NEON_MEM(u16, uint16_t) +NPYV_IMPL_NEON_MEM(s16, int16_t) +NPYV_IMPL_NEON_MEM(u32, uint32_t) +NPYV_IMPL_NEON_MEM(s32, int32_t) +NPYV_IMPL_NEON_MEM(u64, uint64_t) +NPYV_IMPL_NEON_MEM(s64, int64_t) +NPYV_IMPL_NEON_MEM(f32, float) +#if NPY_SIMD_F64 +NPYV_IMPL_NEON_MEM(f64, double) +#endif +/*************************** + * Non-contiguous Load + ***************************/ +NPY_FINLINE npyv_s32 npyv_loadn_s32(const npy_int32 *ptr, npy_intp stride) +{ + int32x4_t a = vdupq_n_s32(0); + a = vld1q_lane_s32((const int32_t*)ptr, a, 0); + a = vld1q_lane_s32((const int32_t*)ptr + stride, a, 1); + a = vld1q_lane_s32((const int32_t*)ptr + stride*2, a, 2); + a = vld1q_lane_s32((const int32_t*)ptr + stride*3, a, 3); + return a; +} + +NPY_FINLINE npyv_u32 npyv_loadn_u32(const npy_uint32 *ptr, npy_intp stride) +{ + return npyv_reinterpret_u32_s32( + npyv_loadn_s32((const npy_int32*)ptr, stride) + ); +} +NPY_FINLINE npyv_f32 npyv_loadn_f32(const float *ptr, npy_intp stride) +{ + return npyv_reinterpret_f32_s32( + npyv_loadn_s32((const npy_int32*)ptr, stride) + ); +} +//// 64 +NPY_FINLINE npyv_s64 npyv_loadn_s64(const npy_int64 *ptr, npy_intp stride) +{ + return vcombine_s64( + vld1_s64((const int64_t*)ptr), vld1_s64((const int64_t*)ptr + stride) + ); +} +NPY_FINLINE npyv_u64 npyv_loadn_u64(const npy_uint64 *ptr, npy_intp stride) +{ + return npyv_reinterpret_u64_s64( + npyv_loadn_s64((const npy_int64*)ptr, stride) + ); +} +#if NPY_SIMD_F64 +NPY_FINLINE npyv_f64 npyv_loadn_f64(const double *ptr, npy_intp stride) +{ + return npyv_reinterpret_f64_s64( + npyv_loadn_s64((const npy_int64*)ptr, stride) + ); +} +#endif + +//// 64-bit load over 32-bit stride +NPY_FINLINE npyv_u32 npyv_loadn2_u32(const npy_uint32 *ptr, npy_intp stride) +{ + return vcombine_u32( + vld1_u32((const uint32_t*)ptr), vld1_u32((const uint32_t*)ptr + stride) + ); +} +NPY_FINLINE npyv_s32 npyv_loadn2_s32(const npy_int32 *ptr, npy_intp stride) +{ return npyv_reinterpret_s32_u32(npyv_loadn2_u32((const npy_uint32*)ptr, stride)); } +NPY_FINLINE npyv_f32 npyv_loadn2_f32(const float *ptr, npy_intp stride) +{ return npyv_reinterpret_f32_u32(npyv_loadn2_u32((const npy_uint32*)ptr, stride)); } + +//// 128-bit load 
over 64-bit stride +NPY_FINLINE npyv_u64 npyv_loadn2_u64(const npy_uint64 *ptr, npy_intp stride) +{ (void)stride; return npyv_load_u64(ptr); } +NPY_FINLINE npyv_s64 npyv_loadn2_s64(const npy_int64 *ptr, npy_intp stride) +{ (void)stride; return npyv_load_s64(ptr); } +#if NPY_SIMD_F64 +NPY_FINLINE npyv_f64 npyv_loadn2_f64(const double *ptr, npy_intp stride) +{ (void)stride; return npyv_load_f64(ptr); } +#endif + +/*************************** + * Non-contiguous Store + ***************************/ +//// 32 +NPY_FINLINE void npyv_storen_s32(npy_int32 *ptr, npy_intp stride, npyv_s32 a) +{ + vst1q_lane_s32((int32_t*)ptr, a, 0); + vst1q_lane_s32((int32_t*)ptr + stride, a, 1); + vst1q_lane_s32((int32_t*)ptr + stride*2, a, 2); + vst1q_lane_s32((int32_t*)ptr + stride*3, a, 3); +} +NPY_FINLINE void npyv_storen_u32(npy_uint32 *ptr, npy_intp stride, npyv_u32 a) +{ npyv_storen_s32((npy_int32*)ptr, stride, npyv_reinterpret_s32_u32(a)); } +NPY_FINLINE void npyv_storen_f32(float *ptr, npy_intp stride, npyv_f32 a) +{ npyv_storen_s32((npy_int32*)ptr, stride, npyv_reinterpret_s32_f32(a)); } +//// 64 +NPY_FINLINE void npyv_storen_s64(npy_int64 *ptr, npy_intp stride, npyv_s64 a) +{ + vst1q_lane_s64((int64_t*)ptr, a, 0); + vst1q_lane_s64((int64_t*)ptr + stride, a, 1); +} +NPY_FINLINE void npyv_storen_u64(npy_uint64 *ptr, npy_intp stride, npyv_u64 a) +{ npyv_storen_s64((npy_int64*)ptr, stride, npyv_reinterpret_s64_u64(a)); } + +#if NPY_SIMD_F64 +NPY_FINLINE void npyv_storen_f64(double *ptr, npy_intp stride, npyv_f64 a) +{ npyv_storen_s64((npy_int64*)ptr, stride, npyv_reinterpret_s64_f64(a)); } +#endif + +//// 64-bit store over 32-bit stride +NPY_FINLINE void npyv_storen2_u32(npy_uint32 *ptr, npy_intp stride, npyv_u32 a) +{ +#if NPY_SIMD_F64 + vst1q_lane_u64((uint64_t*)ptr, npyv_reinterpret_u64_u32(a), 0); + vst1q_lane_u64((uint64_t*)(ptr + stride), npyv_reinterpret_u64_u32(a), 1); +#else + // armhf strict to alignment + vst1_u32((uint32_t*)ptr, vget_low_u32(a)); + vst1_u32((uint32_t*)ptr + stride, vget_high_u32(a)); +#endif +} +NPY_FINLINE void npyv_storen2_s32(npy_int32 *ptr, npy_intp stride, npyv_s32 a) +{ npyv_storen2_u32((npy_uint32*)ptr, stride, npyv_reinterpret_u32_s32(a)); } +NPY_FINLINE void npyv_storen2_f32(float *ptr, npy_intp stride, npyv_f32 a) +{ npyv_storen2_u32((npy_uint32*)ptr, stride, npyv_reinterpret_u32_f32(a)); } + +//// 128-bit store over 64-bit stride +NPY_FINLINE void npyv_storen2_u64(npy_uint64 *ptr, npy_intp stride, npyv_u64 a) +{ (void)stride; npyv_store_u64(ptr, a); } +NPY_FINLINE void npyv_storen2_s64(npy_int64 *ptr, npy_intp stride, npyv_s64 a) +{ (void)stride; npyv_store_s64(ptr, a); } +#if NPY_SIMD_F64 +NPY_FINLINE void npyv_storen2_f64(double *ptr, npy_intp stride, npyv_f64 a) +{ (void)stride; npyv_store_f64(ptr, a); } +#endif +/********************************* + * Partial Load + *********************************/ +//// 32 +NPY_FINLINE npyv_s32 npyv_load_till_s32(const npy_int32 *ptr, npy_uintp nlane, npy_int32 fill) +{ + assert(nlane > 0); + npyv_s32 a; + switch(nlane) { + case 1: + a = vld1q_lane_s32((const int32_t*)ptr, vdupq_n_s32(fill), 0); + break; + case 2: + a = vcombine_s32(vld1_s32((const int32_t*)ptr), vdup_n_s32(fill)); + break; + case 3: + a = vcombine_s32( + vld1_s32((const int32_t*)ptr), + vld1_lane_s32((const int32_t*)ptr + 2, vdup_n_s32(fill), 0) + ); + break; + default: + return npyv_load_s32(ptr); + } +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile npyv_s32 workaround = a; + a = vorrq_s32(workaround, a); +#endif + return a; +} +// fill zero to rest lanes 
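These partial loads and stores exist for loop tails: the main loop consumes full 4-lane blocks, the final short block is padded on load, and only the first nlane elements are written on store. A minimal usage sketch built only from helpers defined in these headers (the wrapper function itself is made up for illustration):

    /* Add 1 to every element, handling the 1..3-element tail with the *_till_* helpers. */
    static void add_one_s32(const npy_int32 *src, npy_int32 *dst, npy_intp len)
    {
        const npyv_s32 one = npyv_setall_s32(1);
        npy_intp i = 0;
        for (; i + npyv_nlanes_s32 <= len; i += npyv_nlanes_s32) {
            npyv_store_s32(dst + i, npyv_add_s32(npyv_load_s32(src + i), one));
        }
        if (i < len) {
            npyv_s32 v = npyv_load_tillz_s32(src + i, len - i);   /* pad the rest with zero */
            npyv_store_till_s32(dst + i, len - i, npyv_add_s32(v, one));
        }
    }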
+NPY_FINLINE npyv_s32 npyv_load_tillz_s32(const npy_int32 *ptr, npy_uintp nlane) +{ return npyv_load_till_s32(ptr, nlane, 0); } +//// 64 +NPY_FINLINE npyv_s64 npyv_load_till_s64(const npy_int64 *ptr, npy_uintp nlane, npy_int64 fill) +{ + assert(nlane > 0); + if (nlane == 1) { + npyv_s64 a = vcombine_s64(vld1_s64((const int64_t*)ptr), vdup_n_s64(fill)); + #if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile npyv_s64 workaround = a; + a = vorrq_s64(workaround, a); + #endif + return a; + } + return npyv_load_s64(ptr); +} +// fill zero to rest lanes +NPY_FINLINE npyv_s64 npyv_load_tillz_s64(const npy_int64 *ptr, npy_uintp nlane) +{ return npyv_load_till_s64(ptr, nlane, 0); } + +//// 64-bit nlane +NPY_FINLINE npyv_s32 npyv_load2_till_s32(const npy_int32 *ptr, npy_uintp nlane, + npy_int32 fill_lo, npy_int32 fill_hi) +{ + assert(nlane > 0); + if (nlane == 1) { + const int32_t NPY_DECL_ALIGNED(16) fill[2] = {fill_lo, fill_hi}; + npyv_s32 a = vcombine_s32(vld1_s32((const int32_t*)ptr), vld1_s32(fill)); + #if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile npyv_s32 workaround = a; + a = vorrq_s32(workaround, a); + #endif + return a; + } + return npyv_load_s32(ptr); +} +// fill zero to rest lanes +NPY_FINLINE npyv_s32 npyv_load2_tillz_s32(const npy_int32 *ptr, npy_uintp nlane) +{ return vreinterpretq_s32_s64(npyv_load_tillz_s64((const npy_int64*)ptr, nlane)); } + +//// 128-bit nlane +NPY_FINLINE npyv_s64 npyv_load2_till_s64(const npy_int64 *ptr, npy_uintp nlane, + npy_int64 fill_lo, npy_int64 fill_hi) +{ (void)nlane; (void)fill_lo; (void)fill_hi; return npyv_load_s64(ptr); } + +NPY_FINLINE npyv_s64 npyv_load2_tillz_s64(const npy_int64 *ptr, npy_uintp nlane) +{ (void)nlane; return npyv_load_s64(ptr); } + +/********************************* + * Non-contiguous partial load + *********************************/ +//// 32 +NPY_FINLINE npyv_s32 +npyv_loadn_till_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane, npy_int32 fill) +{ + assert(nlane > 0); + int32x4_t vfill = vdupq_n_s32(fill); + switch(nlane) { + case 3: + vfill = vld1q_lane_s32((const int32_t*)ptr + stride*2, vfill, 2); + case 2: + vfill = vld1q_lane_s32((const int32_t*)ptr + stride, vfill, 1); + case 1: + vfill = vld1q_lane_s32((const int32_t*)ptr, vfill, 0); + break; + default: + return npyv_loadn_s32(ptr, stride); + } +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile npyv_s32 workaround = vfill; + vfill = vorrq_s32(workaround, vfill); +#endif + return vfill; +} +NPY_FINLINE npyv_s32 +npyv_loadn_tillz_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane) +{ return npyv_loadn_till_s32(ptr, stride, nlane, 0); } + +NPY_FINLINE npyv_s64 +npyv_loadn_till_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane, npy_int64 fill) +{ + assert(nlane > 0); + if (nlane == 1) { + return npyv_load_till_s64(ptr, 1, fill); + } + return npyv_loadn_s64(ptr, stride); +} +// fill zero to rest lanes +NPY_FINLINE npyv_s64 npyv_loadn_tillz_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane) +{ return npyv_loadn_till_s64(ptr, stride, nlane, 0); } + +//// 64-bit load over 32-bit stride +NPY_FINLINE npyv_s32 npyv_loadn2_till_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane, + npy_int32 fill_lo, npy_int32 fill_hi) +{ + assert(nlane > 0); + if (nlane == 1) { + const int32_t NPY_DECL_ALIGNED(16) fill[2] = {fill_lo, fill_hi}; + npyv_s32 a = vcombine_s32(vld1_s32((const int32_t*)ptr), vld1_s32(fill)); + #if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile npyv_s32 workaround = a; + a = vorrq_s32(workaround, a); + #endif + return a; + } + return npyv_loadn2_s32(ptr, stride); 
+} +NPY_FINLINE npyv_s32 npyv_loadn2_tillz_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane) +{ + assert(nlane > 0); + if (nlane == 1) { + npyv_s32 a = vcombine_s32(vld1_s32((const int32_t*)ptr), vdup_n_s32(0)); + #if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile npyv_s32 workaround = a; + a = vorrq_s32(workaround, a); + #endif + return a; + } + return npyv_loadn2_s32(ptr, stride); +} + +//// 128-bit load over 64-bit stride +NPY_FINLINE npyv_s64 npyv_loadn2_till_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane, + npy_int64 fill_lo, npy_int64 fill_hi) +{ assert(nlane > 0); (void)stride; (void)nlane; (void)fill_lo; (void)fill_hi; return npyv_load_s64(ptr); } + +NPY_FINLINE npyv_s64 npyv_loadn2_tillz_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane) +{ assert(nlane > 0); (void)stride; (void)nlane; return npyv_load_s64(ptr); } + +/********************************* + * Partial store + *********************************/ +//// 32 +NPY_FINLINE void npyv_store_till_s32(npy_int32 *ptr, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); + switch(nlane) { + case 1: + vst1q_lane_s32((int32_t*)ptr, a, 0); + break; + case 2: + vst1_s32((int32_t*)ptr, vget_low_s32(a)); + break; + case 3: + vst1_s32((int32_t*)ptr, vget_low_s32(a)); + vst1q_lane_s32((int32_t*)ptr + 2, a, 2); + break; + default: + npyv_store_s32(ptr, a); + } +} +//// 64 +NPY_FINLINE void npyv_store_till_s64(npy_int64 *ptr, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); + if (nlane == 1) { + vst1q_lane_s64((int64_t*)ptr, a, 0); + return; + } + npyv_store_s64(ptr, a); +} + +//// 64-bit nlane +NPY_FINLINE void npyv_store2_till_s32(npy_int32 *ptr, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); + if (nlane == 1) { + // armhf strict to alignment, may cause bus error + #if NPY_SIMD_F64 + vst1q_lane_s64((int64_t*)ptr, npyv_reinterpret_s64_s32(a), 0); + #else + npyv_storel_s32(ptr, a); + #endif + return; + } + npyv_store_s32(ptr, a); +} + +//// 128-bit nlane +NPY_FINLINE void npyv_store2_till_s64(npy_int64 *ptr, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); (void)nlane; + npyv_store_s64(ptr, a); +} + +/********************************* + * Non-contiguous partial store + *********************************/ +//// 32 +NPY_FINLINE void npyv_storen_till_s32(npy_int32 *ptr, npy_intp stride, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); + vst1q_lane_s32((int32_t*)ptr, a, 0); + switch(nlane) { + case 1: + return; + case 2: + vst1q_lane_s32((int32_t*)ptr + stride, a, 1); + return; + case 3: + vst1q_lane_s32((int32_t*)ptr + stride, a, 1); + vst1q_lane_s32((int32_t*)ptr + stride*2, a, 2); + return; + default: + vst1q_lane_s32((int32_t*)ptr + stride, a, 1); + vst1q_lane_s32((int32_t*)ptr + stride*2, a, 2); + vst1q_lane_s32((int32_t*)ptr + stride*3, a, 3); + } +} +//// 64 +NPY_FINLINE void npyv_storen_till_s64(npy_int64 *ptr, npy_intp stride, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); + if (nlane == 1) { + vst1q_lane_s64((int64_t*)ptr, a, 0); + return; + } + npyv_storen_s64(ptr, stride, a); +} + +//// 64-bit store over 32-bit stride +NPY_FINLINE void npyv_storen2_till_s32(npy_int32 *ptr, npy_intp stride, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); +#if NPY_SIMD_F64 + vst1q_lane_s64((int64_t*)ptr, npyv_reinterpret_s64_s32(a), 0); + if (nlane > 1) { + vst1q_lane_s64((int64_t*)(ptr + stride), npyv_reinterpret_s64_s32(a), 1); + } +#else + npyv_storel_s32(ptr, a); + if (nlane > 1) { + npyv_storeh_s32(ptr + stride, a); + } +#endif +} + +//// 128-bit store over 64-bit stride +NPY_FINLINE void 
npyv_storen2_till_s64(npy_int64 *ptr, npy_intp stride, npy_uintp nlane, npyv_s64 a) +{ assert(nlane > 0); (void)stride; (void)nlane; npyv_store_s64(ptr, a); } + +/***************************************************************** + * Implement partial load/store for u32/f32/u64/f64... via casting + *****************************************************************/ +#define NPYV_IMPL_NEON_REST_PARTIAL_TYPES(F_SFX, T_SFX) \ + NPY_FINLINE npyv_##F_SFX npyv_load_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, npyv_lanetype_##F_SFX fill) \ + { \ + union { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + } pun; \ + pun.from_##F_SFX = fill; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane, pun.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, \ + npyv_lanetype_##F_SFX fill) \ + { \ + union { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + } pun; \ + pun.from_##F_SFX = fill; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane, pun.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_load_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane \ + )); \ + } \ + NPY_FINLINE void npyv_store_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_store_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } \ + NPY_FINLINE void npyv_storen_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_storen_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, stride, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } + +NPYV_IMPL_NEON_REST_PARTIAL_TYPES(u32, s32) +NPYV_IMPL_NEON_REST_PARTIAL_TYPES(f32, s32) +NPYV_IMPL_NEON_REST_PARTIAL_TYPES(u64, s64) +#if NPY_SIMD_F64 +NPYV_IMPL_NEON_REST_PARTIAL_TYPES(f64, s64) +#endif + +// 128-bit/64-bit stride +#define NPYV_IMPL_NEON_REST_PARTIAL_TYPES_PAIR(F_SFX, T_SFX) \ + NPY_FINLINE npyv_##F_SFX npyv_load2_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, \ + npyv_lanetype_##F_SFX fill_lo, npyv_lanetype_##F_SFX fill_hi) \ + { \ + union pun { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + }; \ + union pun pun_lo; \ + union pun pun_hi; \ + pun_lo.from_##F_SFX = fill_lo; \ + pun_hi.from_##F_SFX = fill_hi; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load2_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane, pun_lo.to_##T_SFX, pun_hi.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn2_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, \ + npyv_lanetype_##F_SFX fill_lo, npyv_lanetype_##F_SFX fill_hi) \ + { \ + union pun { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + }; \ + union pun pun_lo; \ + union pun pun_hi; \ + pun_lo.from_##F_SFX = 
fill_lo; \ + pun_hi.from_##F_SFX = fill_hi; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn2_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane, pun_lo.to_##T_SFX, \ + pun_hi.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_load2_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load2_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn2_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn2_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane \ + )); \ + } \ + NPY_FINLINE void npyv_store2_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_store2_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } \ + NPY_FINLINE void npyv_storen2_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_storen2_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, stride, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } + +NPYV_IMPL_NEON_REST_PARTIAL_TYPES_PAIR(u32, s32) +NPYV_IMPL_NEON_REST_PARTIAL_TYPES_PAIR(f32, s32) +NPYV_IMPL_NEON_REST_PARTIAL_TYPES_PAIR(u64, s64) +#if NPY_SIMD_F64 +NPYV_IMPL_NEON_REST_PARTIAL_TYPES_PAIR(f64, s64) +#endif + +/************************************************************ + * de-interlave load / interleave contiguous store + ************************************************************/ +// two channels +#define NPYV_IMPL_NEON_MEM_INTERLEAVE(SFX, T_PTR) \ + NPY_FINLINE npyv_##SFX##x2 npyv_load_##SFX##x2( \ + const npyv_lanetype_##SFX *ptr \ + ) { \ + return vld2q_##SFX((const T_PTR*)ptr); \ + } \ + NPY_FINLINE void npyv_store_##SFX##x2( \ + npyv_lanetype_##SFX *ptr, npyv_##SFX##x2 v \ + ) { \ + vst2q_##SFX((T_PTR*)ptr, v); \ + } + +NPYV_IMPL_NEON_MEM_INTERLEAVE(u8, uint8_t) +NPYV_IMPL_NEON_MEM_INTERLEAVE(s8, int8_t) +NPYV_IMPL_NEON_MEM_INTERLEAVE(u16, uint16_t) +NPYV_IMPL_NEON_MEM_INTERLEAVE(s16, int16_t) +NPYV_IMPL_NEON_MEM_INTERLEAVE(u32, uint32_t) +NPYV_IMPL_NEON_MEM_INTERLEAVE(s32, int32_t) +NPYV_IMPL_NEON_MEM_INTERLEAVE(f32, float) + +#if NPY_SIMD_F64 + NPYV_IMPL_NEON_MEM_INTERLEAVE(f64, double) + NPYV_IMPL_NEON_MEM_INTERLEAVE(u64, uint64_t) + NPYV_IMPL_NEON_MEM_INTERLEAVE(s64, int64_t) +#else + #define NPYV_IMPL_NEON_MEM_INTERLEAVE_64(SFX) \ + NPY_FINLINE npyv_##SFX##x2 npyv_load_##SFX##x2( \ + const npyv_lanetype_##SFX *ptr) \ + { \ + npyv_##SFX a = npyv_load_##SFX(ptr); \ + npyv_##SFX b = npyv_load_##SFX(ptr + 2); \ + npyv_##SFX##x2 r; \ + r.val[0] = vcombine_##SFX(vget_low_##SFX(a), vget_low_##SFX(b)); \ + r.val[1] = vcombine_##SFX(vget_high_##SFX(a), vget_high_##SFX(b)); \ + return r; \ + } \ + NPY_FINLINE void npyv_store_##SFX##x2( \ + npyv_lanetype_##SFX *ptr, npyv_##SFX##x2 v) \ + { \ + npyv_store_##SFX(ptr, vcombine_##SFX( \ + vget_low_##SFX(v.val[0]), vget_low_##SFX(v.val[1]))); \ + npyv_store_##SFX(ptr + 2, vcombine_##SFX( \ + vget_high_##SFX(v.val[0]), vget_high_##SFX(v.val[1]))); \ + } + NPYV_IMPL_NEON_MEM_INTERLEAVE_64(u64) + NPYV_IMPL_NEON_MEM_INTERLEAVE_64(s64) +#endif +/********************************* + * Lookup table + *********************************/ +// uses vector as indexes into a table +// that contains 32 elements of uint32. 
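In other words, the lut helpers below are a per-lane gather from a small constant table (4 lanes for 32-bit elements, 2 lanes for 64-bit), as used for example by table-driven polynomial approximations. A scalar model, with an illustrative name:

    /* Scalar model of npyv_lut32_u32: one table entry per lane; indices are assumed to be < 32. */
    static void lut32_ref_u32(const npy_uint32 table[32], const npy_uint32 idx[4], npy_uint32 out[4])
    {
        for (int i = 0; i < 4; ++i) {
            out[i] = table[idx[i]];
        }
    }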
+NPY_FINLINE npyv_u32 npyv_lut32_u32(const npy_uint32 *table, npyv_u32 idx) +{ + const unsigned i0 = vgetq_lane_u32(idx, 0); + const unsigned i1 = vgetq_lane_u32(idx, 1); + const unsigned i2 = vgetq_lane_u32(idx, 2); + const unsigned i3 = vgetq_lane_u32(idx, 3); + + uint32x2_t low = vcreate_u32(table[i0]); + low = vld1_lane_u32((const uint32_t*)table + i1, low, 1); + uint32x2_t high = vcreate_u32(table[i2]); + high = vld1_lane_u32((const uint32_t*)table + i3, high, 1); + return vcombine_u32(low, high); +} +NPY_FINLINE npyv_s32 npyv_lut32_s32(const npy_int32 *table, npyv_u32 idx) +{ return npyv_reinterpret_s32_u32(npyv_lut32_u32((const npy_uint32*)table, idx)); } +NPY_FINLINE npyv_f32 npyv_lut32_f32(const float *table, npyv_u32 idx) +{ return npyv_reinterpret_f32_u32(npyv_lut32_u32((const npy_uint32*)table, idx)); } + +// uses vector as indexes into a table +// that contains 16 elements of uint64. +NPY_FINLINE npyv_u64 npyv_lut16_u64(const npy_uint64 *table, npyv_u64 idx) +{ + const unsigned i0 = vgetq_lane_u32(vreinterpretq_u32_u64(idx), 0); + const unsigned i1 = vgetq_lane_u32(vreinterpretq_u32_u64(idx), 2); + return vcombine_u64( + vld1_u64((const uint64_t*)table + i0), + vld1_u64((const uint64_t*)table + i1) + ); +} +NPY_FINLINE npyv_s64 npyv_lut16_s64(const npy_int64 *table, npyv_u64 idx) +{ return npyv_reinterpret_s64_u64(npyv_lut16_u64((const npy_uint64*)table, idx)); } +#if NPY_SIMD_F64 +NPY_FINLINE npyv_f64 npyv_lut16_f64(const double *table, npyv_u64 idx) +{ return npyv_reinterpret_f64_u64(npyv_lut16_u64((const npy_uint64*)table, idx)); } +#endif + +#endif // _NPY_SIMD_NEON_MEMORY_H diff --git a/mkl_umath/src/npyv/neon/misc.h b/mkl_umath/src/npyv/neon/misc.h new file mode 100644 index 00000000..9dac0cfa --- /dev/null +++ b/mkl_umath/src/npyv/neon/misc.h @@ -0,0 +1,275 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_NEON_MISC_H +#define _NPY_SIMD_NEON_MISC_H + +// vector with zero lanes +#define npyv_zero_u8() vreinterpretq_u8_s32(npyv_zero_s32()) +#define npyv_zero_s8() vreinterpretq_s8_s32(npyv_zero_s32()) +#define npyv_zero_u16() vreinterpretq_u16_s32(npyv_zero_s32()) +#define npyv_zero_s16() vreinterpretq_s16_s32(npyv_zero_s32()) +#define npyv_zero_u32() vdupq_n_u32((unsigned)0) +#define npyv_zero_s32() vdupq_n_s32((int)0) +#define npyv_zero_u64() vreinterpretq_u64_s32(npyv_zero_s32()) +#define npyv_zero_s64() vreinterpretq_s64_s32(npyv_zero_s32()) +#define npyv_zero_f32() vdupq_n_f32(0.0f) +#define npyv_zero_f64() vdupq_n_f64(0.0) + +// vector with a specific value set to all lanes +#define npyv_setall_u8 vdupq_n_u8 +#define npyv_setall_s8 vdupq_n_s8 +#define npyv_setall_u16 vdupq_n_u16 +#define npyv_setall_s16 vdupq_n_s16 +#define npyv_setall_u32 vdupq_n_u32 +#define npyv_setall_s32 vdupq_n_s32 +#define npyv_setall_u64 vdupq_n_u64 +#define npyv_setall_s64 vdupq_n_s64 +#define npyv_setall_f32 vdupq_n_f32 +#define npyv_setall_f64 vdupq_n_f64 + +// vector with specific values set to each lane and +// set a specific value to all remained lanes +#if defined(__clang__) || defined(__GNUC__) + #define npyv_setf_u8(FILL, ...) ((uint8x16_t){NPYV__SET_FILL_16(uint8_t, FILL, __VA_ARGS__)}) + #define npyv_setf_s8(FILL, ...) ((int8x16_t){NPYV__SET_FILL_16(int8_t, FILL, __VA_ARGS__)}) + #define npyv_setf_u16(FILL, ...) ((uint16x8_t){NPYV__SET_FILL_8(uint16_t, FILL, __VA_ARGS__)}) + #define npyv_setf_s16(FILL, ...) ((int16x8_t){NPYV__SET_FILL_8(int16_t, FILL, __VA_ARGS__)}) + #define npyv_setf_u32(FILL, ...) 
((uint32x4_t){NPYV__SET_FILL_4(uint32_t, FILL, __VA_ARGS__)}) + #define npyv_setf_s32(FILL, ...) ((int32x4_t){NPYV__SET_FILL_4(int32_t, FILL, __VA_ARGS__)}) + #define npyv_setf_u64(FILL, ...) ((uint64x2_t){NPYV__SET_FILL_2(uint64_t, FILL, __VA_ARGS__)}) + #define npyv_setf_s64(FILL, ...) ((int64x2_t){NPYV__SET_FILL_2(int64_t, FILL, __VA_ARGS__)}) + #define npyv_setf_f32(FILL, ...) ((float32x4_t){NPYV__SET_FILL_4(float, FILL, __VA_ARGS__)}) + #if NPY_SIMD_F64 + #define npyv_setf_f64(FILL, ...) ((float64x2_t){NPYV__SET_FILL_2(double, FILL, __VA_ARGS__)}) + #endif +#else + NPY_FINLINE uint8x16_t npyv__set_u8(npy_uint8 i0, npy_uint8 i1, npy_uint8 i2, npy_uint8 i3, + npy_uint8 i4, npy_uint8 i5, npy_uint8 i6, npy_uint8 i7, npy_uint8 i8, npy_uint8 i9, + npy_uint8 i10, npy_uint8 i11, npy_uint8 i12, npy_uint8 i13, npy_uint8 i14, npy_uint8 i15) + { + const uint8_t NPY_DECL_ALIGNED(16) data[16] = { + i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, i14, i15 + }; + return vld1q_u8(data); + } + NPY_FINLINE int8x16_t npyv__set_s8(npy_int8 i0, npy_int8 i1, npy_int8 i2, npy_int8 i3, + npy_int8 i4, npy_int8 i5, npy_int8 i6, npy_int8 i7, npy_int8 i8, npy_int8 i9, + npy_int8 i10, npy_int8 i11, npy_int8 i12, npy_int8 i13, npy_int8 i14, npy_int8 i15) + { + const int8_t NPY_DECL_ALIGNED(16) data[16] = { + i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, i14, i15 + }; + return vld1q_s8(data); + } + NPY_FINLINE uint16x8_t npyv__set_u16(npy_uint16 i0, npy_uint16 i1, npy_uint16 i2, npy_uint16 i3, + npy_uint16 i4, npy_uint16 i5, npy_uint16 i6, npy_uint16 i7) + { + const uint16_t NPY_DECL_ALIGNED(16) data[8] = {i0, i1, i2, i3, i4, i5, i6, i7}; + return vld1q_u16(data); + } + NPY_FINLINE int16x8_t npyv__set_s16(npy_int16 i0, npy_int16 i1, npy_int16 i2, npy_int16 i3, + npy_int16 i4, npy_int16 i5, npy_int16 i6, npy_int16 i7) + { + const int16_t NPY_DECL_ALIGNED(16) data[8] = {i0, i1, i2, i3, i4, i5, i6, i7}; + return vld1q_s16(data); + } + NPY_FINLINE uint32x4_t npyv__set_u32(npy_uint32 i0, npy_uint32 i1, npy_uint32 i2, npy_uint32 i3) + { + const uint32_t NPY_DECL_ALIGNED(16) data[4] = {i0, i1, i2, i3}; + return vld1q_u32(data); + } + NPY_FINLINE int32x4_t npyv__set_s32(npy_int32 i0, npy_int32 i1, npy_int32 i2, npy_int32 i3) + { + const int32_t NPY_DECL_ALIGNED(16) data[4] = {i0, i1, i2, i3}; + return vld1q_s32(data); + } + NPY_FINLINE uint64x2_t npyv__set_u64(npy_uint64 i0, npy_uint64 i1) + { + const uint64_t NPY_DECL_ALIGNED(16) data[2] = {i0, i1}; + return vld1q_u64(data); + } + NPY_FINLINE int64x2_t npyv__set_s64(npy_int64 i0, npy_int64 i1) + { + const int64_t NPY_DECL_ALIGNED(16) data[2] = {i0, i1}; + return vld1q_s64(data); + } + NPY_FINLINE float32x4_t npyv__set_f32(float i0, float i1, float i2, float i3) + { + const float NPY_DECL_ALIGNED(16) data[4] = {i0, i1, i2, i3}; + return vld1q_f32(data); + } + #if NPY_SIMD_F64 + NPY_FINLINE float64x2_t npyv__set_f64(double i0, double i1) + { + const double NPY_DECL_ALIGNED(16) data[2] = {i0, i1}; + return vld1q_f64(data); + } + #endif + #define npyv_setf_u8(FILL, ...) npyv__set_u8(NPYV__SET_FILL_16(npy_uint8, FILL, __VA_ARGS__)) + #define npyv_setf_s8(FILL, ...) npyv__set_s8(NPYV__SET_FILL_16(npy_int8, FILL, __VA_ARGS__)) + #define npyv_setf_u16(FILL, ...) npyv__set_u16(NPYV__SET_FILL_8(npy_uint16, FILL, __VA_ARGS__)) + #define npyv_setf_s16(FILL, ...) npyv__set_s16(NPYV__SET_FILL_8(npy_int16, FILL, __VA_ARGS__)) + #define npyv_setf_u32(FILL, ...) npyv__set_u32(NPYV__SET_FILL_4(npy_uint32, FILL, __VA_ARGS__)) + #define npyv_setf_s32(FILL, ...) 
npyv__set_s32(NPYV__SET_FILL_4(npy_int32, FILL, __VA_ARGS__)) + #define npyv_setf_u64(FILL, ...) npyv__set_u64(NPYV__SET_FILL_2(npy_uint64, FILL, __VA_ARGS__)) + #define npyv_setf_s64(FILL, ...) npyv__set_s64(NPYV__SET_FILL_2(npy_int64, FILL, __VA_ARGS__)) + #define npyv_setf_f32(FILL, ...) npyv__set_f32(NPYV__SET_FILL_4(float, FILL, __VA_ARGS__)) + #if NPY_SIMD_F64 + #define npyv_setf_f64(FILL, ...) npyv__set_f64(NPYV__SET_FILL_2(double, FILL, __VA_ARGS__)) + #endif +#endif + +// vector with specific values set to each lane and +// set zero to all remained lanes +#define npyv_set_u8(...) npyv_setf_u8(0, __VA_ARGS__) +#define npyv_set_s8(...) npyv_setf_s8(0, __VA_ARGS__) +#define npyv_set_u16(...) npyv_setf_u16(0, __VA_ARGS__) +#define npyv_set_s16(...) npyv_setf_s16(0, __VA_ARGS__) +#define npyv_set_u32(...) npyv_setf_u32(0, __VA_ARGS__) +#define npyv_set_s32(...) npyv_setf_s32(0, __VA_ARGS__) +#define npyv_set_u64(...) npyv_setf_u64(0, __VA_ARGS__) +#define npyv_set_s64(...) npyv_setf_s64(0, __VA_ARGS__) +#define npyv_set_f32(...) npyv_setf_f32(0, __VA_ARGS__) +#define npyv_set_f64(...) npyv_setf_f64(0, __VA_ARGS__) + +// Per lane select +#define npyv_select_u8 vbslq_u8 +#define npyv_select_s8 vbslq_s8 +#define npyv_select_u16 vbslq_u16 +#define npyv_select_s16 vbslq_s16 +#define npyv_select_u32 vbslq_u32 +#define npyv_select_s32 vbslq_s32 +#define npyv_select_u64 vbslq_u64 +#define npyv_select_s64 vbslq_s64 +#define npyv_select_f32 vbslq_f32 +#define npyv_select_f64 vbslq_f64 + +// extract the first vector's lane +#define npyv_extract0_u8(A) ((npy_uint8)vgetq_lane_u8(A, 0)) +#define npyv_extract0_s8(A) ((npy_int8)vgetq_lane_s8(A, 0)) +#define npyv_extract0_u16(A) ((npy_uint16)vgetq_lane_u16(A, 0)) +#define npyv_extract0_s16(A) ((npy_int16)vgetq_lane_s16(A, 0)) +#define npyv_extract0_u32(A) ((npy_uint32)vgetq_lane_u32(A, 0)) +#define npyv_extract0_s32(A) ((npy_int32)vgetq_lane_s32(A, 0)) +#define npyv_extract0_u64(A) ((npy_uint64)vgetq_lane_u64(A, 0)) +#define npyv_extract0_s64(A) ((npy_int64)vgetq_lane_s64(A, 0)) +#define npyv_extract0_f32(A) vgetq_lane_f32(A, 0) +#define npyv_extract0_f64(A) vgetq_lane_f64(A, 0) + +// Reinterpret +#define npyv_reinterpret_u8_u8(X) X +#define npyv_reinterpret_u8_s8 vreinterpretq_u8_s8 +#define npyv_reinterpret_u8_u16 vreinterpretq_u8_u16 +#define npyv_reinterpret_u8_s16 vreinterpretq_u8_s16 +#define npyv_reinterpret_u8_u32 vreinterpretq_u8_u32 +#define npyv_reinterpret_u8_s32 vreinterpretq_u8_s32 +#define npyv_reinterpret_u8_u64 vreinterpretq_u8_u64 +#define npyv_reinterpret_u8_s64 vreinterpretq_u8_s64 +#define npyv_reinterpret_u8_f32 vreinterpretq_u8_f32 +#define npyv_reinterpret_u8_f64 vreinterpretq_u8_f64 + +#define npyv_reinterpret_s8_s8(X) X +#define npyv_reinterpret_s8_u8 vreinterpretq_s8_u8 +#define npyv_reinterpret_s8_u16 vreinterpretq_s8_u16 +#define npyv_reinterpret_s8_s16 vreinterpretq_s8_s16 +#define npyv_reinterpret_s8_u32 vreinterpretq_s8_u32 +#define npyv_reinterpret_s8_s32 vreinterpretq_s8_s32 +#define npyv_reinterpret_s8_u64 vreinterpretq_s8_u64 +#define npyv_reinterpret_s8_s64 vreinterpretq_s8_s64 +#define npyv_reinterpret_s8_f32 vreinterpretq_s8_f32 +#define npyv_reinterpret_s8_f64 vreinterpretq_s8_f64 + +#define npyv_reinterpret_u16_u16(X) X +#define npyv_reinterpret_u16_u8 vreinterpretq_u16_u8 +#define npyv_reinterpret_u16_s8 vreinterpretq_u16_s8 +#define npyv_reinterpret_u16_s16 vreinterpretq_u16_s16 +#define npyv_reinterpret_u16_u32 vreinterpretq_u16_u32 +#define npyv_reinterpret_u16_s32 vreinterpretq_u16_s32 +#define 
npyv_reinterpret_u16_u64 vreinterpretq_u16_u64 +#define npyv_reinterpret_u16_s64 vreinterpretq_u16_s64 +#define npyv_reinterpret_u16_f32 vreinterpretq_u16_f32 +#define npyv_reinterpret_u16_f64 vreinterpretq_u16_f64 + +#define npyv_reinterpret_s16_s16(X) X +#define npyv_reinterpret_s16_u8 vreinterpretq_s16_u8 +#define npyv_reinterpret_s16_s8 vreinterpretq_s16_s8 +#define npyv_reinterpret_s16_u16 vreinterpretq_s16_u16 +#define npyv_reinterpret_s16_u32 vreinterpretq_s16_u32 +#define npyv_reinterpret_s16_s32 vreinterpretq_s16_s32 +#define npyv_reinterpret_s16_u64 vreinterpretq_s16_u64 +#define npyv_reinterpret_s16_s64 vreinterpretq_s16_s64 +#define npyv_reinterpret_s16_f32 vreinterpretq_s16_f32 +#define npyv_reinterpret_s16_f64 vreinterpretq_s16_f64 + +#define npyv_reinterpret_u32_u32(X) X +#define npyv_reinterpret_u32_u8 vreinterpretq_u32_u8 +#define npyv_reinterpret_u32_s8 vreinterpretq_u32_s8 +#define npyv_reinterpret_u32_u16 vreinterpretq_u32_u16 +#define npyv_reinterpret_u32_s16 vreinterpretq_u32_s16 +#define npyv_reinterpret_u32_s32 vreinterpretq_u32_s32 +#define npyv_reinterpret_u32_u64 vreinterpretq_u32_u64 +#define npyv_reinterpret_u32_s64 vreinterpretq_u32_s64 +#define npyv_reinterpret_u32_f32 vreinterpretq_u32_f32 +#define npyv_reinterpret_u32_f64 vreinterpretq_u32_f64 + +#define npyv_reinterpret_s32_s32(X) X +#define npyv_reinterpret_s32_u8 vreinterpretq_s32_u8 +#define npyv_reinterpret_s32_s8 vreinterpretq_s32_s8 +#define npyv_reinterpret_s32_u16 vreinterpretq_s32_u16 +#define npyv_reinterpret_s32_s16 vreinterpretq_s32_s16 +#define npyv_reinterpret_s32_u32 vreinterpretq_s32_u32 +#define npyv_reinterpret_s32_u64 vreinterpretq_s32_u64 +#define npyv_reinterpret_s32_s64 vreinterpretq_s32_s64 +#define npyv_reinterpret_s32_f32 vreinterpretq_s32_f32 +#define npyv_reinterpret_s32_f64 vreinterpretq_s32_f64 + +#define npyv_reinterpret_u64_u64(X) X +#define npyv_reinterpret_u64_u8 vreinterpretq_u64_u8 +#define npyv_reinterpret_u64_s8 vreinterpretq_u64_s8 +#define npyv_reinterpret_u64_u16 vreinterpretq_u64_u16 +#define npyv_reinterpret_u64_s16 vreinterpretq_u64_s16 +#define npyv_reinterpret_u64_u32 vreinterpretq_u64_u32 +#define npyv_reinterpret_u64_s32 vreinterpretq_u64_s32 +#define npyv_reinterpret_u64_s64 vreinterpretq_u64_s64 +#define npyv_reinterpret_u64_f32 vreinterpretq_u64_f32 +#define npyv_reinterpret_u64_f64 vreinterpretq_u64_f64 + +#define npyv_reinterpret_s64_s64(X) X +#define npyv_reinterpret_s64_u8 vreinterpretq_s64_u8 +#define npyv_reinterpret_s64_s8 vreinterpretq_s64_s8 +#define npyv_reinterpret_s64_u16 vreinterpretq_s64_u16 +#define npyv_reinterpret_s64_s16 vreinterpretq_s64_s16 +#define npyv_reinterpret_s64_u32 vreinterpretq_s64_u32 +#define npyv_reinterpret_s64_s32 vreinterpretq_s64_s32 +#define npyv_reinterpret_s64_u64 vreinterpretq_s64_u64 +#define npyv_reinterpret_s64_f32 vreinterpretq_s64_f32 +#define npyv_reinterpret_s64_f64 vreinterpretq_s64_f64 + +#define npyv_reinterpret_f32_f32(X) X +#define npyv_reinterpret_f32_u8 vreinterpretq_f32_u8 +#define npyv_reinterpret_f32_s8 vreinterpretq_f32_s8 +#define npyv_reinterpret_f32_u16 vreinterpretq_f32_u16 +#define npyv_reinterpret_f32_s16 vreinterpretq_f32_s16 +#define npyv_reinterpret_f32_u32 vreinterpretq_f32_u32 +#define npyv_reinterpret_f32_s32 vreinterpretq_f32_s32 +#define npyv_reinterpret_f32_u64 vreinterpretq_f32_u64 +#define npyv_reinterpret_f32_s64 vreinterpretq_f32_s64 +#define npyv_reinterpret_f32_f64 vreinterpretq_f32_f64 + +#define npyv_reinterpret_f64_f64(X) X +#define npyv_reinterpret_f64_u8 
vreinterpretq_f64_u8 +#define npyv_reinterpret_f64_s8 vreinterpretq_f64_s8 +#define npyv_reinterpret_f64_u16 vreinterpretq_f64_u16 +#define npyv_reinterpret_f64_s16 vreinterpretq_f64_s16 +#define npyv_reinterpret_f64_u32 vreinterpretq_f64_u32 +#define npyv_reinterpret_f64_s32 vreinterpretq_f64_s32 +#define npyv_reinterpret_f64_u64 vreinterpretq_f64_u64 +#define npyv_reinterpret_f64_s64 vreinterpretq_f64_s64 +#define npyv_reinterpret_f64_f32 vreinterpretq_f64_f32 + +// Only required by AVX2/AVX512 +#define npyv_cleanup() ((void)0) + +#endif // _NPY_SIMD_NEON_MISC_H diff --git a/mkl_umath/src/npyv/neon/neon.h b/mkl_umath/src/npyv/neon/neon.h new file mode 100644 index 00000000..49c35c41 --- /dev/null +++ b/mkl_umath/src/npyv/neon/neon.h @@ -0,0 +1,82 @@ +#ifndef _NPY_SIMD_H_ + #error "Not a standalone header" +#endif + +#define NPY_SIMD 128 +#define NPY_SIMD_WIDTH 16 +#define NPY_SIMD_F32 1 +#ifdef __aarch64__ + #define NPY_SIMD_F64 1 +#else + #define NPY_SIMD_F64 0 +#endif +#ifdef NPY_HAVE_NEON_VFPV4 + #define NPY_SIMD_FMA3 1 // native support +#else + #define NPY_SIMD_FMA3 0 // HW emulated +#endif +#define NPY_SIMD_BIGENDIAN 0 +#define NPY_SIMD_CMPSIGNAL 1 + +typedef uint8x16_t npyv_u8; +typedef int8x16_t npyv_s8; +typedef uint16x8_t npyv_u16; +typedef int16x8_t npyv_s16; +typedef uint32x4_t npyv_u32; +typedef int32x4_t npyv_s32; +typedef uint64x2_t npyv_u64; +typedef int64x2_t npyv_s64; +typedef float32x4_t npyv_f32; +#if NPY_SIMD_F64 +typedef float64x2_t npyv_f64; +#endif + +typedef uint8x16_t npyv_b8; +typedef uint16x8_t npyv_b16; +typedef uint32x4_t npyv_b32; +typedef uint64x2_t npyv_b64; + +typedef uint8x16x2_t npyv_u8x2; +typedef int8x16x2_t npyv_s8x2; +typedef uint16x8x2_t npyv_u16x2; +typedef int16x8x2_t npyv_s16x2; +typedef uint32x4x2_t npyv_u32x2; +typedef int32x4x2_t npyv_s32x2; +typedef uint64x2x2_t npyv_u64x2; +typedef int64x2x2_t npyv_s64x2; +typedef float32x4x2_t npyv_f32x2; +#if NPY_SIMD_F64 +typedef float64x2x2_t npyv_f64x2; +#endif + +typedef uint8x16x3_t npyv_u8x3; +typedef int8x16x3_t npyv_s8x3; +typedef uint16x8x3_t npyv_u16x3; +typedef int16x8x3_t npyv_s16x3; +typedef uint32x4x3_t npyv_u32x3; +typedef int32x4x3_t npyv_s32x3; +typedef uint64x2x3_t npyv_u64x3; +typedef int64x2x3_t npyv_s64x3; +typedef float32x4x3_t npyv_f32x3; +#if NPY_SIMD_F64 +typedef float64x2x3_t npyv_f64x3; +#endif + +#define npyv_nlanes_u8 16 +#define npyv_nlanes_s8 16 +#define npyv_nlanes_u16 8 +#define npyv_nlanes_s16 8 +#define npyv_nlanes_u32 4 +#define npyv_nlanes_s32 4 +#define npyv_nlanes_u64 2 +#define npyv_nlanes_s64 2 +#define npyv_nlanes_f32 4 +#define npyv_nlanes_f64 2 + +#include "memory.h" +#include "misc.h" +#include "reorder.h" +#include "operators.h" +#include "conversion.h" +#include "arithmetic.h" +#include "math.h" diff --git a/mkl_umath/src/npyv/neon/operators.h b/mkl_umath/src/npyv/neon/operators.h new file mode 100644 index 00000000..e18ea94b --- /dev/null +++ b/mkl_umath/src/npyv/neon/operators.h @@ -0,0 +1,399 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_NEON_OPERATORS_H +#define _NPY_SIMD_NEON_OPERATORS_H + +/*************************** + * Shifting + ***************************/ + +// left +#define npyv_shl_u16(A, C) vshlq_u16(A, npyv_setall_s16(C)) +#define npyv_shl_s16(A, C) vshlq_s16(A, npyv_setall_s16(C)) +#define npyv_shl_u32(A, C) vshlq_u32(A, npyv_setall_s32(C)) +#define npyv_shl_s32(A, C) vshlq_s32(A, npyv_setall_s32(C)) +#define npyv_shl_u64(A, C) vshlq_u64(A, npyv_setall_s64(C)) +#define npyv_shl_s64(A, C) vshlq_s64(A, 
npyv_setall_s64(C)) + +// left by an immediate constant +#define npyv_shli_u16 vshlq_n_u16 +#define npyv_shli_s16 vshlq_n_s16 +#define npyv_shli_u32 vshlq_n_u32 +#define npyv_shli_s32 vshlq_n_s32 +#define npyv_shli_u64 vshlq_n_u64 +#define npyv_shli_s64 vshlq_n_s64 + +// right +#define npyv_shr_u16(A, C) vshlq_u16(A, npyv_setall_s16(-(C))) +#define npyv_shr_s16(A, C) vshlq_s16(A, npyv_setall_s16(-(C))) +#define npyv_shr_u32(A, C) vshlq_u32(A, npyv_setall_s32(-(C))) +#define npyv_shr_s32(A, C) vshlq_s32(A, npyv_setall_s32(-(C))) +#define npyv_shr_u64(A, C) vshlq_u64(A, npyv_setall_s64(-(C))) +#define npyv_shr_s64(A, C) vshlq_s64(A, npyv_setall_s64(-(C))) + +// right by an immediate constant +#define npyv_shri_u16 vshrq_n_u16 +#define npyv_shri_s16 vshrq_n_s16 +#define npyv_shri_u32 vshrq_n_u32 +#define npyv_shri_s32 vshrq_n_s32 +#define npyv_shri_u64 vshrq_n_u64 +#define npyv_shri_s64 vshrq_n_s64 + +/*************************** + * Logical + ***************************/ + +// AND +#define npyv_and_u8 vandq_u8 +#define npyv_and_s8 vandq_s8 +#define npyv_and_u16 vandq_u16 +#define npyv_and_s16 vandq_s16 +#define npyv_and_u32 vandq_u32 +#define npyv_and_s32 vandq_s32 +#define npyv_and_u64 vandq_u64 +#define npyv_and_s64 vandq_s64 +#define npyv_and_f32(A, B) \ + vreinterpretq_f32_u8(vandq_u8(vreinterpretq_u8_f32(A), vreinterpretq_u8_f32(B))) +#define npyv_and_f64(A, B) \ + vreinterpretq_f64_u8(vandq_u8(vreinterpretq_u8_f64(A), vreinterpretq_u8_f64(B))) +#define npyv_and_b8 vandq_u8 +#define npyv_and_b16 vandq_u16 +#define npyv_and_b32 vandq_u32 +#define npyv_and_b64 vandq_u64 + +// OR +#define npyv_or_u8 vorrq_u8 +#define npyv_or_s8 vorrq_s8 +#define npyv_or_u16 vorrq_u16 +#define npyv_or_s16 vorrq_s16 +#define npyv_or_u32 vorrq_u32 +#define npyv_or_s32 vorrq_s32 +#define npyv_or_u64 vorrq_u64 +#define npyv_or_s64 vorrq_s64 +#define npyv_or_f32(A, B) \ + vreinterpretq_f32_u8(vorrq_u8(vreinterpretq_u8_f32(A), vreinterpretq_u8_f32(B))) +#define npyv_or_f64(A, B) \ + vreinterpretq_f64_u8(vorrq_u8(vreinterpretq_u8_f64(A), vreinterpretq_u8_f64(B))) +#define npyv_or_b8 vorrq_u8 +#define npyv_or_b16 vorrq_u16 +#define npyv_or_b32 vorrq_u32 +#define npyv_or_b64 vorrq_u64 + + +// XOR +#define npyv_xor_u8 veorq_u8 +#define npyv_xor_s8 veorq_s8 +#define npyv_xor_u16 veorq_u16 +#define npyv_xor_s16 veorq_s16 +#define npyv_xor_u32 veorq_u32 +#define npyv_xor_s32 veorq_s32 +#define npyv_xor_u64 veorq_u64 +#define npyv_xor_s64 veorq_s64 +#define npyv_xor_f32(A, B) \ + vreinterpretq_f32_u8(veorq_u8(vreinterpretq_u8_f32(A), vreinterpretq_u8_f32(B))) +#define npyv_xor_f64(A, B) \ + vreinterpretq_f64_u8(veorq_u8(vreinterpretq_u8_f64(A), vreinterpretq_u8_f64(B))) +#define npyv_xor_b8 veorq_u8 +#define npyv_xor_b16 veorq_u16 +#define npyv_xor_b32 veorq_u32 +#define npyv_xor_b64 veorq_u64 + +// NOT +#define npyv_not_u8 vmvnq_u8 +#define npyv_not_s8 vmvnq_s8 +#define npyv_not_u16 vmvnq_u16 +#define npyv_not_s16 vmvnq_s16 +#define npyv_not_u32 vmvnq_u32 +#define npyv_not_s32 vmvnq_s32 +#define npyv_not_u64(A) vreinterpretq_u64_u8(vmvnq_u8(vreinterpretq_u8_u64(A))) +#define npyv_not_s64(A) vreinterpretq_s64_u8(vmvnq_u8(vreinterpretq_u8_s64(A))) +#define npyv_not_f32(A) vreinterpretq_f32_u8(vmvnq_u8(vreinterpretq_u8_f32(A))) +#define npyv_not_f64(A) vreinterpretq_f64_u8(vmvnq_u8(vreinterpretq_u8_f64(A))) +#define npyv_not_b8 vmvnq_u8 +#define npyv_not_b16 vmvnq_u16 +#define npyv_not_b32 vmvnq_u32 +#define npyv_not_b64 npyv_not_u64 + +// ANDC, ORC and XNOR +#define npyv_andc_u8 vbicq_u8 +#define npyv_andc_b8 vbicq_u8 
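
The float variants of the logical wrappers above operate on the raw IEEE-754 bit pattern (they reinterpret to u8, apply the NEON bitwise intrinsic, and reinterpret back), which is what makes sign-bit manipulation possible without leaving vector registers. Below is a minimal illustrative sketch, not part of this patch, of how they compose. It assumes `npyv_setall_u32` from misc.h and `NPY_FINLINE` from the common npyv header are in scope (as they are once neon.h pulls in this file); the `npyv__example_*` helper names are hypothetical, and real kernels would normally prefer the dedicated math.h wrappers.

/* illustrative only */
NPY_FINLINE npyv_f32 npyv__example_abs_f32(npyv_f32 x)
{
    /* clear the IEEE-754 sign bit: |x| == x & 0x7fffffff */
    const npyv_f32 nosign = npyv_reinterpret_f32_u32(npyv_setall_u32(0x7fffffffU));
    return npyv_and_f32(x, nosign);
}

NPY_FINLINE npyv_f32 npyv__example_copysign_f32(npyv_f32 mag, npyv_f32 sgn)
{
    /* magnitude bits from 'mag', sign bit from 'sgn' */
    const npyv_f32 nosign = npyv_reinterpret_f32_u32(npyv_setall_u32(0x7fffffffU));
    const npyv_f32 sbit   = npyv_reinterpret_f32_u32(npyv_setall_u32(0x80000000U));
    return npyv_or_f32(npyv_and_f32(mag, nosign), npyv_and_f32(sgn, sbit));
}
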
+#define npyv_orc_b8 vornq_u8 +#define npyv_xnor_b8 vceqq_u8 + +/*************************** + * Comparison + ***************************/ + +// equal +#define npyv_cmpeq_u8 vceqq_u8 +#define npyv_cmpeq_s8 vceqq_s8 +#define npyv_cmpeq_u16 vceqq_u16 +#define npyv_cmpeq_s16 vceqq_s16 +#define npyv_cmpeq_u32 vceqq_u32 +#define npyv_cmpeq_s32 vceqq_s32 +#define npyv_cmpeq_f32 vceqq_f32 +#define npyv_cmpeq_f64 vceqq_f64 + +#ifdef __aarch64__ + #define npyv_cmpeq_u64 vceqq_u64 + #define npyv_cmpeq_s64 vceqq_s64 +#else + NPY_FINLINE uint64x2_t npyv_cmpeq_u64(uint64x2_t a, uint64x2_t b) + { + uint64x2_t cmpeq = vreinterpretq_u64_u32(vceqq_u32( + vreinterpretq_u32_u64(a), vreinterpretq_u32_u64(b) + )); + uint64x2_t cmpeq_h = vshlq_n_u64(cmpeq, 32); + uint64x2_t test = vandq_u64(cmpeq, cmpeq_h); + return vreinterpretq_u64_s64(vshrq_n_s64(vreinterpretq_s64_u64(test), 32)); + } + #define npyv_cmpeq_s64(A, B) \ + npyv_cmpeq_u64(vreinterpretq_u64_s64(A), vreinterpretq_u64_s64(B)) +#endif + +// not Equal +#define npyv_cmpneq_u8(A, B) vmvnq_u8(vceqq_u8(A, B)) +#define npyv_cmpneq_s8(A, B) vmvnq_u8(vceqq_s8(A, B)) +#define npyv_cmpneq_u16(A, B) vmvnq_u16(vceqq_u16(A, B)) +#define npyv_cmpneq_s16(A, B) vmvnq_u16(vceqq_s16(A, B)) +#define npyv_cmpneq_u32(A, B) vmvnq_u32(vceqq_u32(A, B)) +#define npyv_cmpneq_s32(A, B) vmvnq_u32(vceqq_s32(A, B)) +#define npyv_cmpneq_u64(A, B) npyv_not_u64(npyv_cmpeq_u64(A, B)) +#define npyv_cmpneq_s64(A, B) npyv_not_u64(npyv_cmpeq_s64(A, B)) +#define npyv_cmpneq_f32(A, B) vmvnq_u32(vceqq_f32(A, B)) +#define npyv_cmpneq_f64(A, B) npyv_not_u64(vceqq_f64(A, B)) + +// greater than +#define npyv_cmpgt_u8 vcgtq_u8 +#define npyv_cmpgt_s8 vcgtq_s8 +#define npyv_cmpgt_u16 vcgtq_u16 +#define npyv_cmpgt_s16 vcgtq_s16 +#define npyv_cmpgt_u32 vcgtq_u32 +#define npyv_cmpgt_s32 vcgtq_s32 +#define npyv_cmpgt_f32 vcgtq_f32 +#define npyv_cmpgt_f64 vcgtq_f64 + +#ifdef __aarch64__ + #define npyv_cmpgt_u64 vcgtq_u64 + #define npyv_cmpgt_s64 vcgtq_s64 +#else + NPY_FINLINE uint64x2_t npyv_cmpgt_s64(int64x2_t a, int64x2_t b) + { + int64x2_t sub = vsubq_s64(b, a); + uint64x2_t nsame_sbit = vreinterpretq_u64_s64(veorq_s64(a, b)); + int64x2_t test = vbslq_s64(nsame_sbit, b, sub); + int64x2_t extend_sbit = vshrq_n_s64(test, 63); + return vreinterpretq_u64_s64(extend_sbit); + } + NPY_FINLINE uint64x2_t npyv_cmpgt_u64(uint64x2_t a, uint64x2_t b) + { + const uint64x2_t sbit = npyv_setall_u64(0x8000000000000000); + a = npyv_xor_u64(a, sbit); + b = npyv_xor_u64(b, sbit); + return npyv_cmpgt_s64(vreinterpretq_s64_u64(a), vreinterpretq_s64_u64(b)); + } +#endif + +// greater than or equal +#define npyv_cmpge_u8 vcgeq_u8 +#define npyv_cmpge_s8 vcgeq_s8 +#define npyv_cmpge_u16 vcgeq_u16 +#define npyv_cmpge_s16 vcgeq_s16 +#define npyv_cmpge_u32 vcgeq_u32 +#define npyv_cmpge_s32 vcgeq_s32 +#define npyv_cmpge_f32 vcgeq_f32 +#define npyv_cmpge_f64 vcgeq_f64 + +#ifdef __aarch64__ + #define npyv_cmpge_u64 vcgeq_u64 + #define npyv_cmpge_s64 vcgeq_s64 +#else + #define npyv_cmpge_u64(A, B) npyv_not_u64(npyv_cmpgt_u64(B, A)) + #define npyv_cmpge_s64(A, B) npyv_not_u64(npyv_cmpgt_s64(B, A)) +#endif + +// less than +#define npyv_cmplt_u8(A, B) npyv_cmpgt_u8(B, A) +#define npyv_cmplt_s8(A, B) npyv_cmpgt_s8(B, A) +#define npyv_cmplt_u16(A, B) npyv_cmpgt_u16(B, A) +#define npyv_cmplt_s16(A, B) npyv_cmpgt_s16(B, A) +#define npyv_cmplt_u32(A, B) npyv_cmpgt_u32(B, A) +#define npyv_cmplt_s32(A, B) npyv_cmpgt_s32(B, A) +#define npyv_cmplt_u64(A, B) npyv_cmpgt_u64(B, A) +#define npyv_cmplt_s64(A, B) npyv_cmpgt_s64(B, A) +#define 
npyv_cmplt_f32(A, B) npyv_cmpgt_f32(B, A) +#define npyv_cmplt_f64(A, B) npyv_cmpgt_f64(B, A) + +// less than or equal +#define npyv_cmple_u8(A, B) npyv_cmpge_u8(B, A) +#define npyv_cmple_s8(A, B) npyv_cmpge_s8(B, A) +#define npyv_cmple_u16(A, B) npyv_cmpge_u16(B, A) +#define npyv_cmple_s16(A, B) npyv_cmpge_s16(B, A) +#define npyv_cmple_u32(A, B) npyv_cmpge_u32(B, A) +#define npyv_cmple_s32(A, B) npyv_cmpge_s32(B, A) +#define npyv_cmple_u64(A, B) npyv_cmpge_u64(B, A) +#define npyv_cmple_s64(A, B) npyv_cmpge_s64(B, A) +#define npyv_cmple_f32(A, B) npyv_cmpge_f32(B, A) +#define npyv_cmple_f64(A, B) npyv_cmpge_f64(B, A) + +// check special cases +NPY_FINLINE npyv_b32 npyv_notnan_f32(npyv_f32 a) +{ +#if defined(__clang__) +/** + * To avoid signaling qNaN, workaround for clang symmetric inputs bug + * check https://github.com/numpy/numpy/issues/22933, + * for more clarification. + */ + npyv_b32 ret; + #if NPY_SIMD_F64 + __asm("fcmeq %0.4s, %1.4s, %1.4s" : "=w" (ret) : "w" (a)); + #else + __asm("vceq.f32 %q0, %q1, %q1" : "=w" (ret) : "w" (a)); + #endif + return ret; +#else + return vceqq_f32(a, a); +#endif +} +#if NPY_SIMD_F64 + NPY_FINLINE npyv_b64 npyv_notnan_f64(npyv_f64 a) + { + #if defined(__clang__) + npyv_b64 ret; + __asm("fcmeq %0.2d, %1.2d, %1.2d" : "=w" (ret) : "w" (a)); + return ret; + #else + return vceqq_f64(a, a); + #endif + } +#endif + +// Test cross all vector lanes +// any: returns true if any of the elements is not equal to zero +// all: returns true if all elements are not equal to zero +#if NPY_SIMD_F64 + #define NPYV_IMPL_NEON_ANYALL(LEN) \ + NPY_FINLINE bool npyv_any_b##LEN(npyv_b##LEN a) \ + { return vmaxvq_u##LEN(a) != 0; } \ + NPY_FINLINE bool npyv_all_b##LEN(npyv_b##LEN a) \ + { return vminvq_u##LEN(a) != 0; } + NPYV_IMPL_NEON_ANYALL(8) + NPYV_IMPL_NEON_ANYALL(16) + NPYV_IMPL_NEON_ANYALL(32) + #undef NPYV_IMPL_NEON_ANYALL + + #define NPYV_IMPL_NEON_ANYALL(SFX, USFX, BSFX) \ + NPY_FINLINE bool npyv_any_##SFX(npyv_##SFX a) \ + { return npyv_any_##BSFX(npyv_reinterpret_##USFX##_##SFX(a)); } \ + NPY_FINLINE bool npyv_all_##SFX(npyv_##SFX a) \ + { return npyv_all_##BSFX(npyv_reinterpret_##USFX##_##SFX(a)); } + NPYV_IMPL_NEON_ANYALL(u8, u8, b8) + NPYV_IMPL_NEON_ANYALL(s8, u8, b8) + NPYV_IMPL_NEON_ANYALL(u16, u16, b16) + NPYV_IMPL_NEON_ANYALL(s16, u16, b16) + NPYV_IMPL_NEON_ANYALL(u32, u32, b32) + NPYV_IMPL_NEON_ANYALL(s32, u32, b32) + #undef NPYV_IMPL_NEON_ANYALL + + NPY_FINLINE bool npyv_any_b64(npyv_b64 a) + { return vmaxvq_u32(vreinterpretq_u32_u64(a)) != 0; } + NPY_FINLINE bool npyv_all_b64(npyv_b64 a) + { return vminvq_u32(vreinterpretq_u32_u64(a)) != 0; } + #define npyv_any_u64 npyv_any_b64 + NPY_FINLINE bool npyv_all_u64(npyv_u64 a) + { + uint32x4_t a32 = vreinterpretq_u32_u64(a); + a32 = vorrq_u32(a32, vrev64q_u32(a32)); + return vminvq_u32(a32) != 0; + } + NPY_FINLINE bool npyv_any_s64(npyv_s64 a) + { return npyv_any_u64(vreinterpretq_u64_s64(a)); } + NPY_FINLINE bool npyv_all_s64(npyv_s64 a) + { return npyv_all_u64(vreinterpretq_u64_s64(a)); } + + #define NPYV_IMPL_NEON_ANYALL(SFX, BSFX) \ + NPY_FINLINE bool npyv_any_##SFX(npyv_##SFX a) \ + { return !npyv_all_##BSFX(npyv_cmpeq_##SFX(a, npyv_zero_##SFX())); } \ + NPY_FINLINE bool npyv_all_##SFX(npyv_##SFX a) \ + { return !npyv_any_##BSFX(npyv_cmpeq_##SFX(a, npyv_zero_##SFX())); } + NPYV_IMPL_NEON_ANYALL(f32, b32) + NPYV_IMPL_NEON_ANYALL(f64, b64) + #undef NPYV_IMPL_NEON_ANYALL +#else + #define NPYV_IMPL_NEON_ANYALL(LEN) \ + NPY_FINLINE bool npyv_any_b##LEN(npyv_b##LEN a) \ + { \ + int64x2_t a64 = 
vreinterpretq_s64_u##LEN(a); \ + return ( \ + vgetq_lane_s64(a64, 0) | \ + vgetq_lane_s64(a64, 1) \ + ) != 0; \ + } \ + NPY_FINLINE bool npyv_all_b##LEN(npyv_b##LEN a) \ + { \ + int64x2_t a64 = vreinterpretq_s64_u##LEN(a); \ + return ( \ + vgetq_lane_s64(a64, 0) & \ + vgetq_lane_s64(a64, 1) \ + ) == -1; \ + } + NPYV_IMPL_NEON_ANYALL(8) + NPYV_IMPL_NEON_ANYALL(16) + NPYV_IMPL_NEON_ANYALL(32) + NPYV_IMPL_NEON_ANYALL(64) + #undef NPYV_IMPL_NEON_ANYALL + + #define NPYV_IMPL_NEON_ANYALL(SFX, USFX) \ + NPY_FINLINE bool npyv_any_##SFX(npyv_##SFX a) \ + { \ + int64x2_t a64 = vreinterpretq_s64_##SFX(a); \ + return ( \ + vgetq_lane_s64(a64, 0) | \ + vgetq_lane_s64(a64, 1) \ + ) != 0; \ + } \ + NPY_FINLINE bool npyv_all_##SFX(npyv_##SFX a) \ + { \ + npyv_##USFX tz = npyv_cmpeq_##SFX( \ + a, npyv_zero_##SFX() \ + ); \ + int64x2_t a64 = vreinterpretq_s64_##USFX(tz); \ + return ( \ + vgetq_lane_s64(a64, 0) | \ + vgetq_lane_s64(a64, 1) \ + ) == 0; \ + } + NPYV_IMPL_NEON_ANYALL(u8, u8) + NPYV_IMPL_NEON_ANYALL(s8, u8) + NPYV_IMPL_NEON_ANYALL(u16, u16) + NPYV_IMPL_NEON_ANYALL(s16, u16) + NPYV_IMPL_NEON_ANYALL(u32, u32) + NPYV_IMPL_NEON_ANYALL(s32, u32) + #undef NPYV_IMPL_NEON_ANYALL + + NPY_FINLINE bool npyv_any_f32(npyv_f32 a) + { + uint32x4_t tz = npyv_cmpeq_f32(a, npyv_zero_f32()); + int64x2_t a64 = vreinterpretq_s64_u32(tz); + return (vgetq_lane_s64(a64, 0) & vgetq_lane_s64(a64, 1)) != -1ll; + } + NPY_FINLINE bool npyv_all_f32(npyv_f32 a) + { + uint32x4_t tz = npyv_cmpeq_f32(a, npyv_zero_f32()); + int64x2_t a64 = vreinterpretq_s64_u32(tz); + return (vgetq_lane_s64(a64, 0) | vgetq_lane_s64(a64, 1)) == 0; + } + NPY_FINLINE bool npyv_any_s64(npyv_s64 a) + { return (vgetq_lane_s64(a, 0) | vgetq_lane_s64(a, 1)) != 0; } + NPY_FINLINE bool npyv_all_s64(npyv_s64 a) + { return vgetq_lane_s64(a, 0) && vgetq_lane_s64(a, 1); } + NPY_FINLINE bool npyv_any_u64(npyv_u64 a) + { return (vgetq_lane_u64(a, 0) | vgetq_lane_u64(a, 1)) != 0; } + NPY_FINLINE bool npyv_all_u64(npyv_u64 a) + { return vgetq_lane_u64(a, 0) && vgetq_lane_u64(a, 1); } +#endif // NPY_SIMD_F64 + +#endif // _NPY_SIMD_NEON_OPERATORS_H diff --git a/mkl_umath/src/npyv/neon/reorder.h b/mkl_umath/src/npyv/neon/reorder.h new file mode 100644 index 00000000..8bf68f5b --- /dev/null +++ b/mkl_umath/src/npyv/neon/reorder.h @@ -0,0 +1,189 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_NEON_REORDER_H +#define _NPY_SIMD_NEON_REORDER_H + +// combine lower part of two vectors +#ifdef __aarch64__ + #define npyv_combinel_u8(A, B) vreinterpretq_u8_u64(vzip1q_u64(vreinterpretq_u64_u8(A), vreinterpretq_u64_u8(B))) + #define npyv_combinel_s8(A, B) vreinterpretq_s8_u64(vzip1q_u64(vreinterpretq_u64_s8(A), vreinterpretq_u64_s8(B))) + #define npyv_combinel_u16(A, B) vreinterpretq_u16_u64(vzip1q_u64(vreinterpretq_u64_u16(A), vreinterpretq_u64_u16(B))) + #define npyv_combinel_s16(A, B) vreinterpretq_s16_u64(vzip1q_u64(vreinterpretq_u64_s16(A), vreinterpretq_u64_s16(B))) + #define npyv_combinel_u32(A, B) vreinterpretq_u32_u64(vzip1q_u64(vreinterpretq_u64_u32(A), vreinterpretq_u64_u32(B))) + #define npyv_combinel_s32(A, B) vreinterpretq_s32_u64(vzip1q_u64(vreinterpretq_u64_s32(A), vreinterpretq_u64_s32(B))) + #define npyv_combinel_u64 vzip1q_u64 + #define npyv_combinel_s64 vzip1q_s64 + #define npyv_combinel_f32(A, B) vreinterpretq_f32_u64(vzip1q_u64(vreinterpretq_u64_f32(A), vreinterpretq_u64_f32(B))) + #define npyv_combinel_f64 vzip1q_f64 +#else + #define npyv_combinel_u8(A, B) vcombine_u8(vget_low_u8(A), vget_low_u8(B)) + #define 
npyv_combinel_s8(A, B) vcombine_s8(vget_low_s8(A), vget_low_s8(B)) + #define npyv_combinel_u16(A, B) vcombine_u16(vget_low_u16(A), vget_low_u16(B)) + #define npyv_combinel_s16(A, B) vcombine_s16(vget_low_s16(A), vget_low_s16(B)) + #define npyv_combinel_u32(A, B) vcombine_u32(vget_low_u32(A), vget_low_u32(B)) + #define npyv_combinel_s32(A, B) vcombine_s32(vget_low_s32(A), vget_low_s32(B)) + #define npyv_combinel_u64(A, B) vcombine_u64(vget_low_u64(A), vget_low_u64(B)) + #define npyv_combinel_s64(A, B) vcombine_s64(vget_low_s64(A), vget_low_s64(B)) + #define npyv_combinel_f32(A, B) vcombine_f32(vget_low_f32(A), vget_low_f32(B)) +#endif + +// combine higher part of two vectors +#ifdef __aarch64__ + #define npyv_combineh_u8(A, B) vreinterpretq_u8_u64(vzip2q_u64(vreinterpretq_u64_u8(A), vreinterpretq_u64_u8(B))) + #define npyv_combineh_s8(A, B) vreinterpretq_s8_u64(vzip2q_u64(vreinterpretq_u64_s8(A), vreinterpretq_u64_s8(B))) + #define npyv_combineh_u16(A, B) vreinterpretq_u16_u64(vzip2q_u64(vreinterpretq_u64_u16(A), vreinterpretq_u64_u16(B))) + #define npyv_combineh_s16(A, B) vreinterpretq_s16_u64(vzip2q_u64(vreinterpretq_u64_s16(A), vreinterpretq_u64_s16(B))) + #define npyv_combineh_u32(A, B) vreinterpretq_u32_u64(vzip2q_u64(vreinterpretq_u64_u32(A), vreinterpretq_u64_u32(B))) + #define npyv_combineh_s32(A, B) vreinterpretq_s32_u64(vzip2q_u64(vreinterpretq_u64_s32(A), vreinterpretq_u64_s32(B))) + #define npyv_combineh_u64 vzip2q_u64 + #define npyv_combineh_s64 vzip2q_s64 + #define npyv_combineh_f32(A, B) vreinterpretq_f32_u64(vzip2q_u64(vreinterpretq_u64_f32(A), vreinterpretq_u64_f32(B))) + #define npyv_combineh_f64 vzip2q_f64 +#else + #define npyv_combineh_u8(A, B) vcombine_u8(vget_high_u8(A), vget_high_u8(B)) + #define npyv_combineh_s8(A, B) vcombine_s8(vget_high_s8(A), vget_high_s8(B)) + #define npyv_combineh_u16(A, B) vcombine_u16(vget_high_u16(A), vget_high_u16(B)) + #define npyv_combineh_s16(A, B) vcombine_s16(vget_high_s16(A), vget_high_s16(B)) + #define npyv_combineh_u32(A, B) vcombine_u32(vget_high_u32(A), vget_high_u32(B)) + #define npyv_combineh_s32(A, B) vcombine_s32(vget_high_s32(A), vget_high_s32(B)) + #define npyv_combineh_u64(A, B) vcombine_u64(vget_high_u64(A), vget_high_u64(B)) + #define npyv_combineh_s64(A, B) vcombine_s64(vget_high_s64(A), vget_high_s64(B)) + #define npyv_combineh_f32(A, B) vcombine_f32(vget_high_f32(A), vget_high_f32(B)) +#endif + +// combine two vectors from lower and higher parts of two other vectors +#define NPYV_IMPL_NEON_COMBINE(T_VEC, SFX) \ + NPY_FINLINE T_VEC##x2 npyv_combine_##SFX(T_VEC a, T_VEC b) \ + { \ + T_VEC##x2 r; \ + r.val[0] = NPY_CAT(npyv_combinel_, SFX)(a, b); \ + r.val[1] = NPY_CAT(npyv_combineh_, SFX)(a, b); \ + return r; \ + } + +NPYV_IMPL_NEON_COMBINE(npyv_u8, u8) +NPYV_IMPL_NEON_COMBINE(npyv_s8, s8) +NPYV_IMPL_NEON_COMBINE(npyv_u16, u16) +NPYV_IMPL_NEON_COMBINE(npyv_s16, s16) +NPYV_IMPL_NEON_COMBINE(npyv_u32, u32) +NPYV_IMPL_NEON_COMBINE(npyv_s32, s32) +NPYV_IMPL_NEON_COMBINE(npyv_u64, u64) +NPYV_IMPL_NEON_COMBINE(npyv_s64, s64) +NPYV_IMPL_NEON_COMBINE(npyv_f32, f32) +#ifdef __aarch64__ +NPYV_IMPL_NEON_COMBINE(npyv_f64, f64) +#endif + +// interleave & deinterleave two vectors +#ifdef __aarch64__ + #define NPYV_IMPL_NEON_ZIP(T_VEC, SFX) \ + NPY_FINLINE T_VEC##x2 npyv_zip_##SFX(T_VEC a, T_VEC b) \ + { \ + T_VEC##x2 r; \ + r.val[0] = vzip1q_##SFX(a, b); \ + r.val[1] = vzip2q_##SFX(a, b); \ + return r; \ + } \ + NPY_FINLINE T_VEC##x2 npyv_unzip_##SFX(T_VEC a, T_VEC b) \ + { \ + T_VEC##x2 r; \ + r.val[0] = vuzp1q_##SFX(a, b); \ + 
r.val[1] = vuzp2q_##SFX(a, b); \ + return r; \ + } +#else + #define NPYV_IMPL_NEON_ZIP(T_VEC, SFX) \ + NPY_FINLINE T_VEC##x2 npyv_zip_##SFX(T_VEC a, T_VEC b) \ + { return vzipq_##SFX(a, b); } \ + NPY_FINLINE T_VEC##x2 npyv_unzip_##SFX(T_VEC a, T_VEC b) \ + { return vuzpq_##SFX(a, b); } +#endif + +NPYV_IMPL_NEON_ZIP(npyv_u8, u8) +NPYV_IMPL_NEON_ZIP(npyv_s8, s8) +NPYV_IMPL_NEON_ZIP(npyv_u16, u16) +NPYV_IMPL_NEON_ZIP(npyv_s16, s16) +NPYV_IMPL_NEON_ZIP(npyv_u32, u32) +NPYV_IMPL_NEON_ZIP(npyv_s32, s32) +NPYV_IMPL_NEON_ZIP(npyv_f32, f32) + +#define npyv_zip_u64 npyv_combine_u64 +#define npyv_zip_s64 npyv_combine_s64 +#define npyv_zip_f64 npyv_combine_f64 +#define npyv_unzip_u64 npyv_combine_u64 +#define npyv_unzip_s64 npyv_combine_s64 +#define npyv_unzip_f64 npyv_combine_f64 + +// Reverse elements of each 64-bit lane +#define npyv_rev64_u8 vrev64q_u8 +#define npyv_rev64_s8 vrev64q_s8 +#define npyv_rev64_u16 vrev64q_u16 +#define npyv_rev64_s16 vrev64q_s16 +#define npyv_rev64_u32 vrev64q_u32 +#define npyv_rev64_s32 vrev64q_s32 +#define npyv_rev64_f32 vrev64q_f32 + +// Permuting the elements of each 128-bit lane by immediate index for +// each element. +#ifdef __clang__ + #define npyv_permi128_u32(A, E0, E1, E2, E3) \ + __builtin_shufflevector(A, A, E0, E1, E2, E3) +#elif defined(__GNUC__) + #define npyv_permi128_u32(A, E0, E1, E2, E3) \ + __builtin_shuffle(A, npyv_set_u32(E0, E1, E2, E3)) +#else + #define npyv_permi128_u32(A, E0, E1, E2, E3) \ + npyv_set_u32( \ + vgetq_lane_u32(A, E0), vgetq_lane_u32(A, E1), \ + vgetq_lane_u32(A, E2), vgetq_lane_u32(A, E3) \ + ) + #define npyv_permi128_s32(A, E0, E1, E2, E3) \ + npyv_set_s32( \ + vgetq_lane_s32(A, E0), vgetq_lane_s32(A, E1), \ + vgetq_lane_s32(A, E2), vgetq_lane_s32(A, E3) \ + ) + #define npyv_permi128_f32(A, E0, E1, E2, E3) \ + npyv_set_f32( \ + vgetq_lane_f32(A, E0), vgetq_lane_f32(A, E1), \ + vgetq_lane_f32(A, E2), vgetq_lane_f32(A, E3) \ + ) +#endif + +#if defined(__clang__) || defined(__GNUC__) + #define npyv_permi128_s32 npyv_permi128_u32 + #define npyv_permi128_f32 npyv_permi128_u32 +#endif + +#ifdef __clang__ + #define npyv_permi128_u64(A, E0, E1) \ + __builtin_shufflevector(A, A, E0, E1) +#elif defined(__GNUC__) + #define npyv_permi128_u64(A, E0, E1) \ + __builtin_shuffle(A, npyv_set_u64(E0, E1)) +#else + #define npyv_permi128_u64(A, E0, E1) \ + npyv_set_u64( \ + vgetq_lane_u64(A, E0), vgetq_lane_u64(A, E1) \ + ) + #define npyv_permi128_s64(A, E0, E1) \ + npyv_set_s64( \ + vgetq_lane_s64(A, E0), vgetq_lane_s64(A, E1) \ + ) + #define npyv_permi128_f64(A, E0, E1) \ + npyv_set_f64( \ + vgetq_lane_f64(A, E0), vgetq_lane_f64(A, E1) \ + ) +#endif + +#if defined(__clang__) || defined(__GNUC__) + #define npyv_permi128_s64 npyv_permi128_u64 + #define npyv_permi128_f64 npyv_permi128_u64 +#endif + +#if !NPY_SIMD_F64 + #undef npyv_permi128_f64 +#endif + +#endif // _NPY_SIMD_NEON_REORDER_H diff --git a/mkl_umath/src/npyv/npy_argparse.h b/mkl_umath/src/npyv/npy_argparse.h new file mode 100644 index 00000000..f4122103 --- /dev/null +++ b/mkl_umath/src/npyv/npy_argparse.h @@ -0,0 +1,96 @@ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_ARGPARSE_H +#define NUMPY_CORE_SRC_COMMON_NPY_ARGPARSE_H + +#include +#include "numpy/ndarraytypes.h" + +/* + * This file defines macros to help with keyword argument parsing. + * This solves two issues as of now: + * 1. Pythons C-API PyArg_* keyword argument parsers are slow, due to + * not caching the strings they use. + * 2. 
It allows the use of METH_ARGPARSE (and `tp_vectorcall`) + * when available in Python, which removes a large chunk of overhead. + * + * Internally CPython achieves similar things by using a code generator + * argument clinic. NumPy may well decide to use argument clinic or a different + * solution in the future. + */ + +NPY_NO_EXPORT int +PyArray_PythonPyIntFromInt(PyObject *obj, int *value); + + +#define _NPY_MAX_KWARGS 15 + +typedef struct { + int npositional; + int nargs; + int npositional_only; + int nrequired; + /* Null terminated list of keyword argument name strings */ + PyObject *kw_strings[_NPY_MAX_KWARGS+1]; +} _NpyArgParserCache; + + +/* + * The sole purpose of this macro is to hide the argument parsing cache. + * Since this cache must be static, this also removes a source of error. + */ +#define NPY_PREPARE_ARGPARSER static _NpyArgParserCache __argparse_cache = {-1} + +/** + * Macro to help with argument parsing. + * + * The pattern for using this macro is by defining the method as: + * + * @code + * static PyObject * + * my_method(PyObject *self, + * PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames) + * { + * NPY_PREPARE_ARGPARSER; + * + * PyObject *argument1, *argument3; + * int argument2 = -1; + * if (npy_parse_arguments("method", args, len_args, kwnames), + * "argument1", NULL, &argument1, + * "|argument2", &PyArray_PythonPyIntFromInt, &argument2, + * "$argument3", NULL, &argument3, + * NULL, NULL, NULL) < 0) { + * return NULL; + * } + * } + * @endcode + * + * The `NPY_PREPARE_ARGPARSER` macro sets up a static cache variable necessary + * to hold data for speeding up the parsing. `npy_parse_arguments` must be + * used in cunjunction with the macro defined in the same scope. + * (No two `npy_parse_arguments` may share a single `NPY_PREPARE_ARGPARSER`.) + * + * @param funcname + * @param args Python passed args (METH_FASTCALL) + * @param len_args Number of arguments (not flagged) + * @param kwnames Tuple as passed by METH_FASTCALL or NULL. + * @param ... List of arguments must be param1_name, param1_converter, + * *param1_outvalue, param2_name, ..., NULL, NULL, NULL. + * Where name is ``char *``, ``converter`` a python converter + * function or NULL and ``outvalue`` is the ``void *`` passed to + * the converter (holding the converted data or a borrowed + * reference if converter is NULL). + * + * @return Returns 0 on success and -1 on failure. + */ +NPY_NO_EXPORT int +_npy_parse_arguments(const char *funcname, + /* cache_ptr is a NULL initialized persistent storage for data */ + _NpyArgParserCache *cache_ptr, + PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames, + /* va_list is NULL, NULL, NULL terminated: name, converter, value */ + ...) NPY_GCC_NONNULL(1); + +#define npy_parse_arguments(funcname, args, len_args, kwnames, ...) 
\ + _npy_parse_arguments(funcname, &__argparse_cache, \ + args, len_args, kwnames, __VA_ARGS__) + +#endif /* NUMPY_CORE_SRC_COMMON_NPY_ARGPARSE_H */ diff --git a/mkl_umath/src/npyv/npy_binsearch.h b/mkl_umath/src/npyv/npy_binsearch.h new file mode 100644 index 00000000..8d2f0714 --- /dev/null +++ b/mkl_umath/src/npyv/npy_binsearch.h @@ -0,0 +1,31 @@ +#ifndef __NPY_BINSEARCH_H__ +#define __NPY_BINSEARCH_H__ + +#include "npy_sort.h" +#include +#include + + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void (PyArray_BinSearchFunc)(const char*, const char*, char*, + npy_intp, npy_intp, + npy_intp, npy_intp, npy_intp, + PyArrayObject*); + +typedef int (PyArray_ArgBinSearchFunc)(const char*, const char*, + const char*, char*, + npy_intp, npy_intp, npy_intp, + npy_intp, npy_intp, npy_intp, + PyArrayObject*); + +NPY_NO_EXPORT PyArray_BinSearchFunc* get_binsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side); +NPY_NO_EXPORT PyArray_ArgBinSearchFunc* get_argbinsearch_func(PyArray_Descr *dtype, NPY_SEARCHSIDE side); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/mkl_umath/src/npyv/npy_cblas.h b/mkl_umath/src/npyv/npy_cblas.h new file mode 100644 index 00000000..596a7c68 --- /dev/null +++ b/mkl_umath/src/npyv/npy_cblas.h @@ -0,0 +1,129 @@ +/* + * This header provides numpy a consistent interface to CBLAS code. It is needed + * because not all providers of cblas provide cblas.h. For instance, MKL provides + * mkl_cblas.h and also typedefs the CBLAS_XXX enums. + */ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_CBLAS_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_CBLAS_H_ + +#include + +/* Allow the use in C++ code. */ +#ifdef __cplusplus +extern "C" +{ +#endif + +/* + * Enumerated and derived types + */ +enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102}; +enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113}; +enum CBLAS_UPLO {CblasUpper=121, CblasLower=122}; +enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132}; +enum CBLAS_SIDE {CblasLeft=141, CblasRight=142}; + +#define CBLAS_INDEX size_t /* this may vary between platforms */ + +#ifdef ACCELERATE_NEW_LAPACK + #if __MAC_OS_X_VERSION_MAX_ALLOWED < 130300 + #ifdef HAVE_BLAS_ILP64 + #error "Accelerate ILP64 support is only available with macOS 13.3 SDK or later" + #endif + #else + #define NO_APPEND_FORTRAN + #ifdef HAVE_BLAS_ILP64 + #define BLAS_SYMBOL_SUFFIX $NEWLAPACK$ILP64 + #else + #define BLAS_SYMBOL_SUFFIX $NEWLAPACK + #endif + #endif +#endif + +#ifdef NO_APPEND_FORTRAN +#define BLAS_FORTRAN_SUFFIX +#else +#define BLAS_FORTRAN_SUFFIX _ +#endif + +#ifndef BLAS_SYMBOL_PREFIX +#define BLAS_SYMBOL_PREFIX +#endif + +#ifndef BLAS_SYMBOL_SUFFIX +#define BLAS_SYMBOL_SUFFIX +#endif + +#define BLAS_FUNC_CONCAT(name,prefix,suffix,suffix2) prefix ## name ## suffix ## suffix2 +#define BLAS_FUNC_EXPAND(name,prefix,suffix,suffix2) BLAS_FUNC_CONCAT(name,prefix,suffix,suffix2) + +/* + * Use either the OpenBLAS scheme with the `64_` suffix behind the Fortran + * compiler symbol mangling, or the MKL scheme (and upcoming + * reference-lapack#666) which does it the other way around and uses `_64`. 
+ */ +#ifdef OPENBLAS_ILP64_NAMING_SCHEME +#define BLAS_FUNC(name) BLAS_FUNC_EXPAND(name,BLAS_SYMBOL_PREFIX,BLAS_FORTRAN_SUFFIX,BLAS_SYMBOL_SUFFIX) +#else +#define BLAS_FUNC(name) BLAS_FUNC_EXPAND(name,BLAS_SYMBOL_PREFIX,BLAS_SYMBOL_SUFFIX,BLAS_FORTRAN_SUFFIX) +#endif +/* + * Note that CBLAS doesn't include Fortran compiler symbol mangling, so ends up + * being the same in both schemes + */ +#define CBLAS_FUNC(name) BLAS_FUNC_EXPAND(name,BLAS_SYMBOL_PREFIX,,BLAS_SYMBOL_SUFFIX) + +#ifdef HAVE_BLAS_ILP64 +#define CBLAS_INT npy_int64 +#define CBLAS_INT_MAX NPY_MAX_INT64 +#else +#define CBLAS_INT int +#define CBLAS_INT_MAX INT_MAX +#endif + +#define BLASNAME(name) CBLAS_FUNC(name) +#define BLASINT CBLAS_INT + +#include "npy_cblas_base.h" + +#undef BLASINT +#undef BLASNAME + + +/* + * Convert NumPy stride to BLAS stride. Returns 0 if conversion cannot be done + * (BLAS won't handle negative or zero strides the way we want). + */ +static inline CBLAS_INT +blas_stride(npy_intp stride, unsigned itemsize) +{ + /* + * Should probably check pointer alignment also, but this may cause + * problems if we require complex to be 16 byte aligned. + */ + if (stride > 0 && (stride % itemsize) == 0) { + stride /= itemsize; + if (stride <= CBLAS_INT_MAX) { + return stride; + } + } + return 0; +} + +/* + * Define a chunksize for CBLAS. + * + * The chunksize is the greatest power of two less than CBLAS_INT_MAX. + */ +#if NPY_MAX_INTP > CBLAS_INT_MAX +# define NPY_CBLAS_CHUNK (CBLAS_INT_MAX / 2 + 1) +#else +# define NPY_CBLAS_CHUNK NPY_MAX_INTP +#endif + + +#ifdef __cplusplus +} +#endif + +#endif /* NUMPY_CORE_SRC_COMMON_NPY_CBLAS_H_ */ diff --git a/mkl_umath/src/npyv/npy_cblas_base.h b/mkl_umath/src/npyv/npy_cblas_base.h new file mode 100644 index 00000000..12dfb2e7 --- /dev/null +++ b/mkl_umath/src/npyv/npy_cblas_base.h @@ -0,0 +1,562 @@ +/* + * This header provides numpy a consistent interface to CBLAS code. It is needed + * because not all providers of cblas provide cblas.h. For instance, MKL provides + * mkl_cblas.h and also typedefs the CBLAS_XXX enums. 
+ */ + +/* + * =========================================================================== + * Prototypes for level 1 BLAS functions (complex are recast as routines) + * =========================================================================== + */ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_CBLAS_BASE_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_CBLAS_BASE_H_ + +float BLASNAME(cblas_sdsdot)(const BLASINT N, const float alpha, const float *X, + const BLASINT incX, const float *Y, const BLASINT incY); +double BLASNAME(cblas_dsdot)(const BLASINT N, const float *X, const BLASINT incX, const float *Y, + const BLASINT incY); +float BLASNAME(cblas_sdot)(const BLASINT N, const float *X, const BLASINT incX, + const float *Y, const BLASINT incY); +double BLASNAME(cblas_ddot)(const BLASINT N, const double *X, const BLASINT incX, + const double *Y, const BLASINT incY); + +/* + * Functions having prefixes Z and C only + */ +void BLASNAME(cblas_cdotu_sub)(const BLASINT N, const void *X, const BLASINT incX, + const void *Y, const BLASINT incY, void *dotu); +void BLASNAME(cblas_cdotc_sub)(const BLASINT N, const void *X, const BLASINT incX, + const void *Y, const BLASINT incY, void *dotc); + +void BLASNAME(cblas_zdotu_sub)(const BLASINT N, const void *X, const BLASINT incX, + const void *Y, const BLASINT incY, void *dotu); +void BLASNAME(cblas_zdotc_sub)(const BLASINT N, const void *X, const BLASINT incX, + const void *Y, const BLASINT incY, void *dotc); + + +/* + * Functions having prefixes S D SC DZ + */ +float BLASNAME(cblas_snrm2)(const BLASINT N, const float *X, const BLASINT incX); +float BLASNAME(cblas_sasum)(const BLASINT N, const float *X, const BLASINT incX); + +double BLASNAME(cblas_dnrm2)(const BLASINT N, const double *X, const BLASINT incX); +double BLASNAME(cblas_dasum)(const BLASINT N, const double *X, const BLASINT incX); + +float BLASNAME(cblas_scnrm2)(const BLASINT N, const void *X, const BLASINT incX); +float BLASNAME(cblas_scasum)(const BLASINT N, const void *X, const BLASINT incX); + +double BLASNAME(cblas_dznrm2)(const BLASINT N, const void *X, const BLASINT incX); +double BLASNAME(cblas_dzasum)(const BLASINT N, const void *X, const BLASINT incX); + + +/* + * Functions having standard 4 prefixes (S D C Z) + */ +CBLAS_INDEX BLASNAME(cblas_isamax)(const BLASINT N, const float *X, const BLASINT incX); +CBLAS_INDEX BLASNAME(cblas_idamax)(const BLASINT N, const double *X, const BLASINT incX); +CBLAS_INDEX BLASNAME(cblas_icamax)(const BLASINT N, const void *X, const BLASINT incX); +CBLAS_INDEX BLASNAME(cblas_izamax)(const BLASINT N, const void *X, const BLASINT incX); + +/* + * =========================================================================== + * Prototypes for level 1 BLAS routines + * =========================================================================== + */ + +/* + * Routines with standard 4 prefixes (s, d, c, z) + */ +void BLASNAME(cblas_sswap)(const BLASINT N, float *X, const BLASINT incX, + float *Y, const BLASINT incY); +void BLASNAME(cblas_scopy)(const BLASINT N, const float *X, const BLASINT incX, + float *Y, const BLASINT incY); +void BLASNAME(cblas_saxpy)(const BLASINT N, const float alpha, const float *X, + const BLASINT incX, float *Y, const BLASINT incY); + +void BLASNAME(cblas_dswap)(const BLASINT N, double *X, const BLASINT incX, + double *Y, const BLASINT incY); +void BLASNAME(cblas_dcopy)(const BLASINT N, const double *X, const BLASINT incX, + double *Y, const BLASINT incY); +void BLASNAME(cblas_daxpy)(const BLASINT N, const double alpha, const double *X, + const 
BLASINT incX, double *Y, const BLASINT incY); + +void BLASNAME(cblas_cswap)(const BLASINT N, void *X, const BLASINT incX, + void *Y, const BLASINT incY); +void BLASNAME(cblas_ccopy)(const BLASINT N, const void *X, const BLASINT incX, + void *Y, const BLASINT incY); +void BLASNAME(cblas_caxpy)(const BLASINT N, const void *alpha, const void *X, + const BLASINT incX, void *Y, const BLASINT incY); + +void BLASNAME(cblas_zswap)(const BLASINT N, void *X, const BLASINT incX, + void *Y, const BLASINT incY); +void BLASNAME(cblas_zcopy)(const BLASINT N, const void *X, const BLASINT incX, + void *Y, const BLASINT incY); +void BLASNAME(cblas_zaxpy)(const BLASINT N, const void *alpha, const void *X, + const BLASINT incX, void *Y, const BLASINT incY); + + +/* + * Routines with S and D prefix only + */ +void BLASNAME(cblas_srotg)(float *a, float *b, float *c, float *s); +void BLASNAME(cblas_srotmg)(float *d1, float *d2, float *b1, const float b2, float *P); +void BLASNAME(cblas_srot)(const BLASINT N, float *X, const BLASINT incX, + float *Y, const BLASINT incY, const float c, const float s); +void BLASNAME(cblas_srotm)(const BLASINT N, float *X, const BLASINT incX, + float *Y, const BLASINT incY, const float *P); + +void BLASNAME(cblas_drotg)(double *a, double *b, double *c, double *s); +void BLASNAME(cblas_drotmg)(double *d1, double *d2, double *b1, const double b2, double *P); +void BLASNAME(cblas_drot)(const BLASINT N, double *X, const BLASINT incX, + double *Y, const BLASINT incY, const double c, const double s); +void BLASNAME(cblas_drotm)(const BLASINT N, double *X, const BLASINT incX, + double *Y, const BLASINT incY, const double *P); + + +/* + * Routines with S D C Z CS and ZD prefixes + */ +void BLASNAME(cblas_sscal)(const BLASINT N, const float alpha, float *X, const BLASINT incX); +void BLASNAME(cblas_dscal)(const BLASINT N, const double alpha, double *X, const BLASINT incX); +void BLASNAME(cblas_cscal)(const BLASINT N, const void *alpha, void *X, const BLASINT incX); +void BLASNAME(cblas_zscal)(const BLASINT N, const void *alpha, void *X, const BLASINT incX); +void BLASNAME(cblas_csscal)(const BLASINT N, const float alpha, void *X, const BLASINT incX); +void BLASNAME(cblas_zdscal)(const BLASINT N, const double alpha, void *X, const BLASINT incX); + +/* + * =========================================================================== + * Prototypes for level 2 BLAS + * =========================================================================== + */ + +/* + * Routines with standard 4 prefixes (S, D, C, Z) + */ +void BLASNAME(cblas_sgemv)(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const BLASINT M, const BLASINT N, + const float alpha, const float *A, const BLASINT lda, + const float *X, const BLASINT incX, const float beta, + float *Y, const BLASINT incY); +void BLASNAME(cblas_sgbmv)(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const BLASINT M, const BLASINT N, + const BLASINT KL, const BLASINT KU, const float alpha, + const float *A, const BLASINT lda, const float *X, + const BLASINT incX, const float beta, float *Y, const BLASINT incY); +void BLASNAME(cblas_strmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const float *A, const BLASINT lda, + float *X, const BLASINT incX); +void BLASNAME(cblas_stbmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const BLASINT 
K, const float *A, const BLASINT lda, + float *X, const BLASINT incX); +void BLASNAME(cblas_stpmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const float *Ap, float *X, const BLASINT incX); +void BLASNAME(cblas_strsv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const float *A, const BLASINT lda, float *X, + const BLASINT incX); +void BLASNAME(cblas_stbsv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const BLASINT K, const float *A, const BLASINT lda, + float *X, const BLASINT incX); +void BLASNAME(cblas_stpsv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const float *Ap, float *X, const BLASINT incX); + +void BLASNAME(cblas_dgemv)(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const BLASINT M, const BLASINT N, + const double alpha, const double *A, const BLASINT lda, + const double *X, const BLASINT incX, const double beta, + double *Y, const BLASINT incY); +void BLASNAME(cblas_dgbmv)(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const BLASINT M, const BLASINT N, + const BLASINT KL, const BLASINT KU, const double alpha, + const double *A, const BLASINT lda, const double *X, + const BLASINT incX, const double beta, double *Y, const BLASINT incY); +void BLASNAME(cblas_dtrmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const double *A, const BLASINT lda, + double *X, const BLASINT incX); +void BLASNAME(cblas_dtbmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const BLASINT K, const double *A, const BLASINT lda, + double *X, const BLASINT incX); +void BLASNAME(cblas_dtpmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const double *Ap, double *X, const BLASINT incX); +void BLASNAME(cblas_dtrsv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const double *A, const BLASINT lda, double *X, + const BLASINT incX); +void BLASNAME(cblas_dtbsv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const BLASINT K, const double *A, const BLASINT lda, + double *X, const BLASINT incX); +void BLASNAME(cblas_dtpsv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const double *Ap, double *X, const BLASINT incX); + +void BLASNAME(cblas_cgemv)(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const BLASINT M, const BLASINT N, + const void *alpha, const void *A, const BLASINT lda, + const void *X, const BLASINT incX, const void *beta, + void *Y, const BLASINT incY); +void BLASNAME(cblas_cgbmv)(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const BLASINT M, const BLASINT N, + const BLASINT KL, const BLASINT KU, const void *alpha, + const void *A, const BLASINT lda, const void *X, + const 
BLASINT incX, const void *beta, void *Y, const BLASINT incY); +void BLASNAME(cblas_ctrmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const void *A, const BLASINT lda, + void *X, const BLASINT incX); +void BLASNAME(cblas_ctbmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const BLASINT K, const void *A, const BLASINT lda, + void *X, const BLASINT incX); +void BLASNAME(cblas_ctpmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const void *Ap, void *X, const BLASINT incX); +void BLASNAME(cblas_ctrsv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const void *A, const BLASINT lda, void *X, + const BLASINT incX); +void BLASNAME(cblas_ctbsv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const BLASINT K, const void *A, const BLASINT lda, + void *X, const BLASINT incX); +void BLASNAME(cblas_ctpsv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const void *Ap, void *X, const BLASINT incX); + +void BLASNAME(cblas_zgemv)(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const BLASINT M, const BLASINT N, + const void *alpha, const void *A, const BLASINT lda, + const void *X, const BLASINT incX, const void *beta, + void *Y, const BLASINT incY); +void BLASNAME(cblas_zgbmv)(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE TransA, const BLASINT M, const BLASINT N, + const BLASINT KL, const BLASINT KU, const void *alpha, + const void *A, const BLASINT lda, const void *X, + const BLASINT incX, const void *beta, void *Y, const BLASINT incY); +void BLASNAME(cblas_ztrmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const void *A, const BLASINT lda, + void *X, const BLASINT incX); +void BLASNAME(cblas_ztbmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const BLASINT K, const void *A, const BLASINT lda, + void *X, const BLASINT incX); +void BLASNAME(cblas_ztpmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const void *Ap, void *X, const BLASINT incX); +void BLASNAME(cblas_ztrsv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const void *A, const BLASINT lda, void *X, + const BLASINT incX); +void BLASNAME(cblas_ztbsv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const BLASINT K, const void *A, const BLASINT lda, + void *X, const BLASINT incX); +void BLASNAME(cblas_ztpsv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const BLASINT N, const void *Ap, void *X, const BLASINT incX); + + +/* + * Routines with S and D prefixes only + */ +void BLASNAME(cblas_ssymv)(const enum 
CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const float alpha, const float *A, + const BLASINT lda, const float *X, const BLASINT incX, + const float beta, float *Y, const BLASINT incY); +void BLASNAME(cblas_ssbmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const BLASINT K, const float alpha, const float *A, + const BLASINT lda, const float *X, const BLASINT incX, + const float beta, float *Y, const BLASINT incY); +void BLASNAME(cblas_sspmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const float alpha, const float *Ap, + const float *X, const BLASINT incX, + const float beta, float *Y, const BLASINT incY); +void BLASNAME(cblas_sger)(const enum CBLAS_ORDER order, const BLASINT M, const BLASINT N, + const float alpha, const float *X, const BLASINT incX, + const float *Y, const BLASINT incY, float *A, const BLASINT lda); +void BLASNAME(cblas_ssyr)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const float alpha, const float *X, + const BLASINT incX, float *A, const BLASINT lda); +void BLASNAME(cblas_sspr)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const float alpha, const float *X, + const BLASINT incX, float *Ap); +void BLASNAME(cblas_ssyr2)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const float alpha, const float *X, + const BLASINT incX, const float *Y, const BLASINT incY, float *A, + const BLASINT lda); +void BLASNAME(cblas_sspr2)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const float alpha, const float *X, + const BLASINT incX, const float *Y, const BLASINT incY, float *A); + +void BLASNAME(cblas_dsymv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const double alpha, const double *A, + const BLASINT lda, const double *X, const BLASINT incX, + const double beta, double *Y, const BLASINT incY); +void BLASNAME(cblas_dsbmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const BLASINT K, const double alpha, const double *A, + const BLASINT lda, const double *X, const BLASINT incX, + const double beta, double *Y, const BLASINT incY); +void BLASNAME(cblas_dspmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const double alpha, const double *Ap, + const double *X, const BLASINT incX, + const double beta, double *Y, const BLASINT incY); +void BLASNAME(cblas_dger)(const enum CBLAS_ORDER order, const BLASINT M, const BLASINT N, + const double alpha, const double *X, const BLASINT incX, + const double *Y, const BLASINT incY, double *A, const BLASINT lda); +void BLASNAME(cblas_dsyr)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const double alpha, const double *X, + const BLASINT incX, double *A, const BLASINT lda); +void BLASNAME(cblas_dspr)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const double alpha, const double *X, + const BLASINT incX, double *Ap); +void BLASNAME(cblas_dsyr2)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const double alpha, const double *X, + const BLASINT incX, const double *Y, const BLASINT incY, double *A, + const BLASINT lda); +void BLASNAME(cblas_dspr2)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const double alpha, const double *X, + const BLASINT incX, const double *Y, const BLASINT incY, double *A); + + +/* + * Routines with C and Z 
prefixes only + */ +void BLASNAME(cblas_chemv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const void *alpha, const void *A, + const BLASINT lda, const void *X, const BLASINT incX, + const void *beta, void *Y, const BLASINT incY); +void BLASNAME(cblas_chbmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const BLASINT K, const void *alpha, const void *A, + const BLASINT lda, const void *X, const BLASINT incX, + const void *beta, void *Y, const BLASINT incY); +void BLASNAME(cblas_chpmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const void *alpha, const void *Ap, + const void *X, const BLASINT incX, + const void *beta, void *Y, const BLASINT incY); +void BLASNAME(cblas_cgeru)(const enum CBLAS_ORDER order, const BLASINT M, const BLASINT N, + const void *alpha, const void *X, const BLASINT incX, + const void *Y, const BLASINT incY, void *A, const BLASINT lda); +void BLASNAME(cblas_cgerc)(const enum CBLAS_ORDER order, const BLASINT M, const BLASINT N, + const void *alpha, const void *X, const BLASINT incX, + const void *Y, const BLASINT incY, void *A, const BLASINT lda); +void BLASNAME(cblas_cher)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const float alpha, const void *X, const BLASINT incX, + void *A, const BLASINT lda); +void BLASNAME(cblas_chpr)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const float alpha, const void *X, + const BLASINT incX, void *A); +void BLASNAME(cblas_cher2)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const BLASINT N, + const void *alpha, const void *X, const BLASINT incX, + const void *Y, const BLASINT incY, void *A, const BLASINT lda); +void BLASNAME(cblas_chpr2)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const BLASINT N, + const void *alpha, const void *X, const BLASINT incX, + const void *Y, const BLASINT incY, void *Ap); + +void BLASNAME(cblas_zhemv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const void *alpha, const void *A, + const BLASINT lda, const void *X, const BLASINT incX, + const void *beta, void *Y, const BLASINT incY); +void BLASNAME(cblas_zhbmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const BLASINT K, const void *alpha, const void *A, + const BLASINT lda, const void *X, const BLASINT incX, + const void *beta, void *Y, const BLASINT incY); +void BLASNAME(cblas_zhpmv)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const void *alpha, const void *Ap, + const void *X, const BLASINT incX, + const void *beta, void *Y, const BLASINT incY); +void BLASNAME(cblas_zgeru)(const enum CBLAS_ORDER order, const BLASINT M, const BLASINT N, + const void *alpha, const void *X, const BLASINT incX, + const void *Y, const BLASINT incY, void *A, const BLASINT lda); +void BLASNAME(cblas_zgerc)(const enum CBLAS_ORDER order, const BLASINT M, const BLASINT N, + const void *alpha, const void *X, const BLASINT incX, + const void *Y, const BLASINT incY, void *A, const BLASINT lda); +void BLASNAME(cblas_zher)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const double alpha, const void *X, const BLASINT incX, + void *A, const BLASINT lda); +void BLASNAME(cblas_zhpr)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, + const BLASINT N, const double alpha, const void *X, + const BLASINT incX, void *A); +void BLASNAME(cblas_zher2)(const enum CBLAS_ORDER order, 
const enum CBLAS_UPLO Uplo, const BLASINT N, + const void *alpha, const void *X, const BLASINT incX, + const void *Y, const BLASINT incY, void *A, const BLASINT lda); +void BLASNAME(cblas_zhpr2)(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const BLASINT N, + const void *alpha, const void *X, const BLASINT incX, + const void *Y, const BLASINT incY, void *Ap); + +/* + * =========================================================================== + * Prototypes for level 3 BLAS + * =========================================================================== + */ + +/* + * Routines with standard 4 prefixes (S, D, C, Z) + */ +void BLASNAME(cblas_sgemm)(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const BLASINT M, const BLASINT N, + const BLASINT K, const float alpha, const float *A, + const BLASINT lda, const float *B, const BLASINT ldb, + const float beta, float *C, const BLASINT ldc); +void BLASNAME(cblas_ssymm)(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const BLASINT M, const BLASINT N, + const float alpha, const float *A, const BLASINT lda, + const float *B, const BLASINT ldb, const float beta, + float *C, const BLASINT ldc); +void BLASNAME(cblas_ssyrk)(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const BLASINT N, const BLASINT K, + const float alpha, const float *A, const BLASINT lda, + const float beta, float *C, const BLASINT ldc); +void BLASNAME(cblas_ssyr2k)(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const BLASINT N, const BLASINT K, + const float alpha, const float *A, const BLASINT lda, + const float *B, const BLASINT ldb, const float beta, + float *C, const BLASINT ldc); +void BLASNAME(cblas_strmm)(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const BLASINT M, const BLASINT N, + const float alpha, const float *A, const BLASINT lda, + float *B, const BLASINT ldb); +void BLASNAME(cblas_strsm)(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const BLASINT M, const BLASINT N, + const float alpha, const float *A, const BLASINT lda, + float *B, const BLASINT ldb); + +void BLASNAME(cblas_dgemm)(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const BLASINT M, const BLASINT N, + const BLASINT K, const double alpha, const double *A, + const BLASINT lda, const double *B, const BLASINT ldb, + const double beta, double *C, const BLASINT ldc); +void BLASNAME(cblas_dsymm)(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const BLASINT M, const BLASINT N, + const double alpha, const double *A, const BLASINT lda, + const double *B, const BLASINT ldb, const double beta, + double *C, const BLASINT ldc); +void BLASNAME(cblas_dsyrk)(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const BLASINT N, const BLASINT K, + const double alpha, const double *A, const BLASINT lda, + const double beta, double *C, const BLASINT ldc); +void BLASNAME(cblas_dsyr2k)(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const BLASINT N, const BLASINT K, + const double alpha, const double *A, const BLASINT lda, + const double *B, 
const BLASINT ldb, const double beta, + double *C, const BLASINT ldc); +void BLASNAME(cblas_dtrmm)(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const BLASINT M, const BLASINT N, + const double alpha, const double *A, const BLASINT lda, + double *B, const BLASINT ldb); +void BLASNAME(cblas_dtrsm)(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const BLASINT M, const BLASINT N, + const double alpha, const double *A, const BLASINT lda, + double *B, const BLASINT ldb); + +void BLASNAME(cblas_cgemm)(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const BLASINT M, const BLASINT N, + const BLASINT K, const void *alpha, const void *A, + const BLASINT lda, const void *B, const BLASINT ldb, + const void *beta, void *C, const BLASINT ldc); +void BLASNAME(cblas_csymm)(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const BLASINT M, const BLASINT N, + const void *alpha, const void *A, const BLASINT lda, + const void *B, const BLASINT ldb, const void *beta, + void *C, const BLASINT ldc); +void BLASNAME(cblas_csyrk)(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const BLASINT N, const BLASINT K, + const void *alpha, const void *A, const BLASINT lda, + const void *beta, void *C, const BLASINT ldc); +void BLASNAME(cblas_csyr2k)(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const BLASINT N, const BLASINT K, + const void *alpha, const void *A, const BLASINT lda, + const void *B, const BLASINT ldb, const void *beta, + void *C, const BLASINT ldc); +void BLASNAME(cblas_ctrmm)(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const BLASINT M, const BLASINT N, + const void *alpha, const void *A, const BLASINT lda, + void *B, const BLASINT ldb); +void BLASNAME(cblas_ctrsm)(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const BLASINT M, const BLASINT N, + const void *alpha, const void *A, const BLASINT lda, + void *B, const BLASINT ldb); + +void BLASNAME(cblas_zgemm)(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const BLASINT M, const BLASINT N, + const BLASINT K, const void *alpha, const void *A, + const BLASINT lda, const void *B, const BLASINT ldb, + const void *beta, void *C, const BLASINT ldc); +void BLASNAME(cblas_zsymm)(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const BLASINT M, const BLASINT N, + const void *alpha, const void *A, const BLASINT lda, + const void *B, const BLASINT ldb, const void *beta, + void *C, const BLASINT ldc); +void BLASNAME(cblas_zsyrk)(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const BLASINT N, const BLASINT K, + const void *alpha, const void *A, const BLASINT lda, + const void *beta, void *C, const BLASINT ldc); +void BLASNAME(cblas_zsyr2k)(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const BLASINT N, const BLASINT K, + const void *alpha, const void *A, const BLASINT lda, + const void 
*B, const BLASINT ldb, const void *beta, + void *C, const BLASINT ldc); +void BLASNAME(cblas_ztrmm)(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const BLASINT M, const BLASINT N, + const void *alpha, const void *A, const BLASINT lda, + void *B, const BLASINT ldb); +void BLASNAME(cblas_ztrsm)(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const BLASINT M, const BLASINT N, + const void *alpha, const void *A, const BLASINT lda, + void *B, const BLASINT ldb); + + +/* + * Routines with prefixes C and Z only + */ +void BLASNAME(cblas_chemm)(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const BLASINT M, const BLASINT N, + const void *alpha, const void *A, const BLASINT lda, + const void *B, const BLASINT ldb, const void *beta, + void *C, const BLASINT ldc); +void BLASNAME(cblas_cherk)(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const BLASINT N, const BLASINT K, + const float alpha, const void *A, const BLASINT lda, + const float beta, void *C, const BLASINT ldc); +void BLASNAME(cblas_cher2k)(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const BLASINT N, const BLASINT K, + const void *alpha, const void *A, const BLASINT lda, + const void *B, const BLASINT ldb, const float beta, + void *C, const BLASINT ldc); + +void BLASNAME(cblas_zhemm)(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const BLASINT M, const BLASINT N, + const void *alpha, const void *A, const BLASINT lda, + const void *B, const BLASINT ldb, const void *beta, + void *C, const BLASINT ldc); +void BLASNAME(cblas_zherk)(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const BLASINT N, const BLASINT K, + const double alpha, const void *A, const BLASINT lda, + const double beta, void *C, const BLASINT ldc); +void BLASNAME(cblas_zher2k)(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const BLASINT N, const BLASINT K, + const void *alpha, const void *A, const BLASINT lda, + const void *B, const BLASINT ldb, const double beta, + void *C, const BLASINT ldc); + +void BLASNAME(cblas_xerbla)(BLASINT p, const char *rout, const char *form, ...); + +#endif /* NUMPY_CORE_SRC_COMMON_NPY_CBLAS_BASE_H_ */ diff --git a/mkl_umath/src/npyv/npy_config.h b/mkl_umath/src/npyv/npy_config.h new file mode 100644 index 00000000..e5903668 --- /dev/null +++ b/mkl_umath/src/npyv/npy_config.h @@ -0,0 +1,188 @@ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_CONFIG_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_CONFIG_H_ + +#include "config.h" +#include "npy_cpu_dispatch.h" // brings NPY_HAVE_[CPU features] +#include "numpy/numpyconfig.h" +#include "numpy/utils.h" +#include "numpy/npy_os.h" + +/* blocklist */ + +/* Disable broken functions on z/OS */ +#if defined (__MVS__) + +#define NPY_BLOCK_POWF +#define NPY_BLOCK_EXPF +#undef HAVE___THREAD + +#endif + +/* Disable broken MS math functions */ +#if defined(__MINGW32_VERSION) + +#define NPY_BLOCK_ATAN2 +#define NPY_BLOCK_ATAN2F +#define NPY_BLOCK_ATAN2L + +#define NPY_BLOCK_HYPOT +#define NPY_BLOCK_HYPOTF +#define NPY_BLOCK_HYPOTL + +#endif + +#if defined(_MSC_VER) + +#undef HAVE_CASIN +#undef HAVE_CASINF +#undef HAVE_CASINL +#undef HAVE_CASINH +#undef 
HAVE_CASINHF +#undef HAVE_CASINHL +#undef HAVE_CATAN +#undef HAVE_CATANF +#undef HAVE_CATANL +#undef HAVE_CATANH +#undef HAVE_CATANHF +#undef HAVE_CATANHL +#undef HAVE_CSQRT +#undef HAVE_CSQRTF +#undef HAVE_CSQRTL +#undef HAVE_CLOG +#undef HAVE_CLOGF +#undef HAVE_CLOGL +#undef HAVE_CACOS +#undef HAVE_CACOSF +#undef HAVE_CACOSL +#undef HAVE_CACOSH +#undef HAVE_CACOSHF +#undef HAVE_CACOSHL + +#endif + +/* MSVC _hypot messes with fp precision mode on 32-bit, see gh-9567 */ +#if defined(_MSC_VER) && !defined(_WIN64) + +#undef HAVE_CABS +#undef HAVE_CABSF +#undef HAVE_CABSL + +#define NPY_BLOCK_HYPOT +#define NPY_BLOCK_HYPOTF +#define NPY_BLOCK_HYPOTL + +#endif + + +/* Intel C for Windows uses POW for 64 bits longdouble*/ +#if defined(_MSC_VER) && defined(__INTEL_COMPILER) +#if NPY_SIZEOF_LONGDOUBLE == 8 +#define NPY_BLOCK_POWL +#endif +#endif /* defined(_MSC_VER) && defined(__INTEL_COMPILER) */ + +/* powl gives zero division warning on OS X, see gh-8307 */ +#if defined(NPY_OS_DARWIN) +#define NPY_BLOCK_POWL +#endif + +#ifdef __CYGWIN__ +/* Loss of precision */ +#undef HAVE_CASINHL +#undef HAVE_CASINH +#undef HAVE_CASINHF + +/* Loss of precision */ +#undef HAVE_CATANHL +#undef HAVE_CATANH +#undef HAVE_CATANHF + +/* Loss of precision and branch cuts */ +#undef HAVE_CATANL +#undef HAVE_CATAN +#undef HAVE_CATANF + +/* Branch cuts */ +#undef HAVE_CACOSHF +#undef HAVE_CACOSH + +/* Branch cuts */ +#undef HAVE_CSQRTF +#undef HAVE_CSQRT + +/* Branch cuts and loss of precision */ +#undef HAVE_CASINF +#undef HAVE_CASIN +#undef HAVE_CASINL + +/* Branch cuts */ +#undef HAVE_CACOSF +#undef HAVE_CACOS + +/* log2(exp2(i)) off by a few eps */ +#define NPY_BLOCK_LOG2 + +/* np.power(..., dtype=np.complex256) doesn't report overflow */ +#undef HAVE_CPOWL +#undef HAVE_CEXPL + +#include +#if CYGWIN_VERSION_DLL_MAJOR < 3003 +// rather than blocklist cabsl, hypotl, modfl, sqrtl, error out +#error cygwin < 3.3 not supported, please update +#endif +#endif + +/* Disable broken gnu trig functions */ +#if defined(HAVE_FEATURES_H) +#include + +#if defined(__GLIBC__) +#if !__GLIBC_PREREQ(2, 18) + +#undef HAVE_CASIN +#undef HAVE_CASINF +#undef HAVE_CASINL +#undef HAVE_CASINH +#undef HAVE_CASINHF +#undef HAVE_CASINHL +#undef HAVE_CATAN +#undef HAVE_CATANF +#undef HAVE_CATANL +#undef HAVE_CATANH +#undef HAVE_CATANHF +#undef HAVE_CATANHL +#undef HAVE_CACOS +#undef HAVE_CACOSF +#undef HAVE_CACOSL +#undef HAVE_CACOSH +#undef HAVE_CACOSHF +#undef HAVE_CACOSHL + +#endif /* __GLIBC_PREREQ(2, 18) */ +#else /* defined(__GLIBC) */ +/* musl linux?, see issue #25092 */ + +#undef HAVE_CASIN +#undef HAVE_CASINF +#undef HAVE_CASINL +#undef HAVE_CASINH +#undef HAVE_CASINHF +#undef HAVE_CASINHL +#undef HAVE_CATAN +#undef HAVE_CATANF +#undef HAVE_CATANL +#undef HAVE_CATANH +#undef HAVE_CATANHF +#undef HAVE_CATANHL +#undef HAVE_CACOS +#undef HAVE_CACOSF +#undef HAVE_CACOSL +#undef HAVE_CACOSH +#undef HAVE_CACOSHF +#undef HAVE_CACOSHL + +#endif /* defined(__GLIBC) */ +#endif /* defined(HAVE_FEATURES_H) */ + +#endif /* NUMPY_CORE_SRC_COMMON_NPY_CONFIG_H_ */ diff --git a/mkl_umath/src/npyv/npy_cpu_dispatch.h b/mkl_umath/src/npyv/npy_cpu_dispatch.h new file mode 100644 index 00000000..ddf6bd55 --- /dev/null +++ b/mkl_umath/src/npyv/npy_cpu_dispatch.h @@ -0,0 +1,132 @@ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_CPU_DISPATCH_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_CPU_DISPATCH_H_ +/** + * This file is part of the NumPy CPU dispatcher. 
+ * + * Please have a look at doc/reference/simd-optimizations.html + * To get a better understanding of the mechanism behind it. + */ +#include "npy_cpu_features.h" // NPY_CPU_HAVE +#if (defined(__s390x__) || defined(__powerpc64__)) && !defined(__cplusplus) && defined(bool) + /* + * "altivec.h" header contains the definitions(bool, vector, pixel), + * usually in c++ we undefine them after including the header. + * It's better anyway to take them off and use built-in types(__vector, __pixel, __bool) instead, + * since c99 supports bool variables which may lead to ambiguous errors. + */ + // backup 'bool' before including 'npy_cpu_dispatch_config.h', since it may not defined as a compiler token. + #define NPY__CPU_DISPATCH_GUARD_BOOL + typedef bool npy__cpu_dispatch_guard_bool; +#endif +/** + * Including the main configuration header 'npy_cpu_dispatch_config.h'. + * This header is generated by the 'ccompiler_opt' distutils module and the Meson build system. + * + * For the distutils-generated version, it contains: + * - Headers for platform-specific instruction sets. + * - Feature #definitions, e.g. NPY_HAVE_AVX2. + * - Helper macros that encapsulate enabled features through user-defined build options + * '--cpu-baseline' and '--cpu-dispatch'. These options are essential for implementing + * attributes like `__cpu_baseline__` and `__cpu_dispatch__` in the NumPy module. + * + * For the Meson-generated version, it contains: + * - Headers for platform-specific instruction sets. + * - Helper macros that encapsulate enabled features through user-defined build options + * '--cpu-baseline' and '--cpu-dispatch'. These options remain crucial for implementing + * attributes like `__cpu_baseline__` and `__cpu_dispatch__` in the NumPy module. + * - Additional helper macros necessary for runtime dispatching. + * + * Note: In the Meson build, features #definitions are conveyed via compiler arguments. + */ +#include "npy_cpu_dispatch_config.h" +#ifndef NPY__CPU_MESON_BUILD + // Define helper macros necessary for runtime dispatching for distutils. + #include "npy_cpu_dispatch_distutils.h" +#endif +#if defined(NPY_HAVE_VSX) || defined(NPY_HAVE_VX) + #undef bool + #undef vector + #undef pixel + #ifdef NPY__CPU_DISPATCH_GUARD_BOOL + #define bool npy__cpu_dispatch_guard_bool + #undef NPY__CPU_DISPATCH_GUARD_BOOL + #endif +#endif +/** + * Initialize the CPU dispatch tracer. + * + * This function simply adds an empty dictionary with the attribute + * '__cpu_targets_info__' to the provided module. + * + * It should be called only once during the loading of the NumPy module. + * Note: This function is not thread-safe. + * + * @param mod The module to which the '__cpu_targets_info__' dictionary will be added. + * @return 0 on success. + */ +NPY_VISIBILITY_HIDDEN int +npy_cpu_dispatch_tracer_init(PyObject *mod); +/** + * Insert data into the initialized '__cpu_targets_info__' dictionary. + * + * This function adds the function name as a key and another dictionary as a value. + * The inner dictionary holds the 'signature' as a key and splits 'dispatch_info' into another dictionary. + * The innermost dictionary contains the current enabled target as 'current' and available targets as 'available'. + * + * Note: This function should not be used directly; it should be used through the macro NPY_CPU_DISPATCH_TRACE(), + * which is responsible for filling in the enabled CPU targets. 
+ * + * Example: + * + * const char *dispatch_info[] = {"AVX2", "AVX512_SKX AVX2 baseline"}; + * npy_cpu_dispatch_trace("add", "bbb", dispatch_info); + * + * const char *dispatch_info[] = {"AVX2", "AVX2 SSE41 baseline"}; + * npy_cpu_dispatch_trace("add", "BBB", dispatch_info); + * + * This will insert the following structure into the '__cpu_targets_info__' dictionary: + * + * numpy._core._multiarray_umath.__cpu_targets_info__ + * { + * "add": { + * "bbb": { + * "current": "AVX2", + * "available": "AVX512_SKX AVX2 baseline" + * }, + * "BBB": { + * "current": "AVX2", + * "available": "AVX2 SSE41 baseline" + * }, + * }, + * } + * + * @param func_name The name of the function. + * @param signature The signature of the function. + * @param dispatch_info The information about CPU dispatching. + */ +NPY_VISIBILITY_HIDDEN void +npy_cpu_dispatch_trace(const char *func_name, const char *signature, + const char **dispatch_info); +/** + * Extract the enabled CPU targets from the generated configuration file. + * + * This macro is used to extract the enabled CPU targets from the generated configuration file, + * which is derived from 'meson.multi_targets()' or from 'disutils.CCompilerOpt' in the case of using distutils. + * It then calls 'npy_cpu_dispatch_trace()' to insert a new item into the '__cpu_targets_info__' dictionary, + * based on the provided FUNC_NAME and SIGNATURE. + * + * For more clarification, please refer to the macro 'NPY_CPU_DISPATCH_INFO()' defined in 'meson_cpu/main_config.h.in' + * and check 'np.lib.utils.opt_func_info()' for the final usage of this trace. + * + * Example: + * #include "arithmetic.dispatch.h" + * NPY_CPU_DISPATCH_CALL(BYTE_add_ptr = BYTE_add); + * NPY_CPU_DISPATCH_TRACE("add", "bbb"); + */ +#define NPY_CPU_DISPATCH_TRACE(FNAME, SIGNATURE) \ +{ \ + const char *dinfo[] = NPY_CPU_DISPATCH_INFO(); \ + npy_cpu_dispatch_trace(FNAME, SIGNATURE, dinfo); \ +} while(0) + +#endif // NUMPY_CORE_SRC_COMMON_NPY_CPU_DISPATCH_H_ diff --git a/mkl_umath/src/npyv/npy_cpu_dispatch_distutils.h b/mkl_umath/src/npyv/npy_cpu_dispatch_distutils.h new file mode 100644 index 00000000..8db99541 --- /dev/null +++ b/mkl_umath/src/npyv/npy_cpu_dispatch_distutils.h @@ -0,0 +1,116 @@ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_CPU_DISPATCH_DISTUTILS_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_CPU_DISPATCH_DISTUTILS_H_ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_CPU_DISPATCH_H_ + #error "Not standalone header please use 'npy_cpu_dispatch.h'" +#endif +/** + * This header should be removed after support for distutils is removed. + * It provides helper macros required for CPU runtime dispatching, + * which are already defined within `meson_cpu/main_config.h.in`. + * + * The following macros are explained within `meson_cpu/main_config.h.in`, + * although there are some differences in their usage: + * + * - Dispatched targets must be defined at the top of each dispatch-able + * source file within an inline or multi-line comment block. + * For example: //@targets baseline SSE2 AVX2 AVX512_SKX + * + * - The generated configuration derived from each dispatch-able source + * file must be guarded with `#ifndef NPY_DISABLE_OPTIMIZATION`. 
+ * For example: + * #ifndef NPY_DISABLE_OPTIMIZATION + * #include "arithmetic.dispatch.h" + * #endif + */ +#include "npy_cpu_features.h" // NPY_CPU_HAVE +#include "numpy/utils.h" // NPY_EXPAND, NPY_CAT + +#ifdef NPY__CPU_TARGET_CURRENT + // 'NPY__CPU_TARGET_CURRENT': only defined by the dispatch-able sources + #define NPY_CPU_DISPATCH_CURFX(NAME) NPY_CAT(NPY_CAT(NAME, _), NPY__CPU_TARGET_CURRENT) +#else + #define NPY_CPU_DISPATCH_CURFX(NAME) NPY_EXPAND(NAME) +#endif +/** + * Defining the default behavior for the configurable macros of dispatch-able sources, + * 'NPY__CPU_DISPATCH_CALL(...)' and 'NPY__CPU_DISPATCH_BASELINE_CALL(...)' + * + * These macros are defined inside the generated config files that been derived from + * the configuration statements of the dispatch-able sources. + * + * The generated config file takes the same name of the dispatch-able source with replacing + * the extension to '.h' instead of '.c', and it should be treated as a header template. + */ +#ifndef NPY_DISABLE_OPTIMIZATION + #define NPY__CPU_DISPATCH_BASELINE_CALL(CB, ...) \ + &&"Expected config header of the dispatch-able source"; + #define NPY__CPU_DISPATCH_CALL(CHK, CB, ...) \ + &&"Expected config header of the dispatch-able source"; +#else + /** + * We assume by default that all configuration statements contains 'baseline' option however, + * if the dispatch-able source doesn't require it, then the dispatch-able source and following macros + * need to be guard it with '#ifndef NPY_DISABLE_OPTIMIZATION' + */ + #define NPY__CPU_DISPATCH_BASELINE_CALL(CB, ...) \ + NPY_EXPAND(CB(__VA_ARGS__)) + #define NPY__CPU_DISPATCH_CALL(CHK, CB, ...) +#endif // !NPY_DISABLE_OPTIMIZATION + +#define NPY_CPU_DISPATCH_DECLARE(...) \ + NPY__CPU_DISPATCH_CALL(NPY_CPU_DISPATCH_DECLARE_CHK_, NPY_CPU_DISPATCH_DECLARE_CB_, __VA_ARGS__) \ + NPY__CPU_DISPATCH_BASELINE_CALL(NPY_CPU_DISPATCH_DECLARE_BASE_CB_, __VA_ARGS__) +// Preprocessor callbacks +#define NPY_CPU_DISPATCH_DECLARE_CB_(DUMMY, TARGET_NAME, LEFT, ...) \ + NPY_CAT(NPY_CAT(LEFT, _), TARGET_NAME) __VA_ARGS__; +#define NPY_CPU_DISPATCH_DECLARE_BASE_CB_(LEFT, ...) \ + LEFT __VA_ARGS__; +// Dummy CPU runtime checking +#define NPY_CPU_DISPATCH_DECLARE_CHK_(FEATURE) + +#define NPY_CPU_DISPATCH_DECLARE_XB(...) \ + NPY__CPU_DISPATCH_CALL(NPY_CPU_DISPATCH_DECLARE_CHK_, NPY_CPU_DISPATCH_DECLARE_CB_, __VA_ARGS__) +#define NPY_CPU_DISPATCH_CALL(...) \ + NPY__CPU_DISPATCH_CALL(NPY_CPU_HAVE, NPY_CPU_DISPATCH_CALL_CB_, __VA_ARGS__) \ + NPY__CPU_DISPATCH_BASELINE_CALL(NPY_CPU_DISPATCH_CALL_BASE_CB_, __VA_ARGS__) +// Preprocessor callbacks +#define NPY_CPU_DISPATCH_CALL_CB_(TESTED_FEATURES, TARGET_NAME, LEFT, ...) \ + (TESTED_FEATURES) ? (NPY_CAT(NPY_CAT(LEFT, _), TARGET_NAME) __VA_ARGS__) : +#define NPY_CPU_DISPATCH_CALL_BASE_CB_(LEFT, ...) \ + (LEFT __VA_ARGS__) + +#define NPY_CPU_DISPATCH_CALL_XB(...) \ + NPY__CPU_DISPATCH_CALL(NPY_CPU_HAVE, NPY_CPU_DISPATCH_CALL_XB_CB_, __VA_ARGS__) \ + ((void) 0 /* discarded expression value */) +#define NPY_CPU_DISPATCH_CALL_XB_CB_(TESTED_FEATURES, TARGET_NAME, LEFT, ...) \ + (TESTED_FEATURES) ? (void) (NPY_CAT(NPY_CAT(LEFT, _), TARGET_NAME) __VA_ARGS__) : + +#define NPY_CPU_DISPATCH_CALL_ALL(...) \ + (NPY__CPU_DISPATCH_CALL(NPY_CPU_HAVE, NPY_CPU_DISPATCH_CALL_ALL_CB_, __VA_ARGS__) \ + NPY__CPU_DISPATCH_BASELINE_CALL(NPY_CPU_DISPATCH_CALL_ALL_BASE_CB_, __VA_ARGS__)) +// Preprocessor callbacks +#define NPY_CPU_DISPATCH_CALL_ALL_CB_(TESTED_FEATURES, TARGET_NAME, LEFT, ...) \ + ((TESTED_FEATURES) ? 
(NPY_CAT(NPY_CAT(LEFT, _), TARGET_NAME) __VA_ARGS__) : (void) 0), +#define NPY_CPU_DISPATCH_CALL_ALL_BASE_CB_(LEFT, ...) \ + ( LEFT __VA_ARGS__ ) + +#define NPY_CPU_DISPATCH_INFO() \ + { \ + NPY__CPU_DISPATCH_CALL(NPY_CPU_HAVE, NPY_CPU_DISPATCH_INFO_HIGH_CB_, DUMMY) \ + NPY__CPU_DISPATCH_BASELINE_CALL(NPY_CPU_DISPATCH_INFO_BASE_HIGH_CB_, DUMMY) \ + "", \ + NPY__CPU_DISPATCH_CALL(NPY_CPU_HAVE, NPY_CPU_DISPATCH_INFO_CB_, DUMMY) \ + NPY__CPU_DISPATCH_BASELINE_CALL(NPY_CPU_DISPATCH_INFO_BASE_CB_, DUMMY) \ + ""\ + } +#define NPY_CPU_DISPATCH_INFO_HIGH_CB_(TESTED_FEATURES, TARGET_NAME, ...) \ + (TESTED_FEATURES) ? NPY_TOSTRING(TARGET_NAME) : +#define NPY_CPU_DISPATCH_INFO_BASE_HIGH_CB_(...) \ + (1) ? "baseline(" NPY_WITH_CPU_BASELINE ")" : +// Preprocessor callbacks +#define NPY_CPU_DISPATCH_INFO_CB_(TESTED_FEATURES, TARGET_NAME, ...) \ + NPY_TOSTRING(TARGET_NAME) " " +#define NPY_CPU_DISPATCH_INFO_BASE_CB_(...) \ + "baseline(" NPY_WITH_CPU_BASELINE ")" + +#endif // NUMPY_CORE_SRC_COMMON_NPY_CPU_DISPATCH_DISTUTILS_H_ diff --git a/mkl_umath/src/npyv/npy_cpu_features.h b/mkl_umath/src/npyv/npy_cpu_features.h new file mode 100644 index 00000000..83522b93 --- /dev/null +++ b/mkl_umath/src/npyv/npy_cpu_features.h @@ -0,0 +1,201 @@ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_CPU_FEATURES_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_CPU_FEATURES_H_ + +#include // for PyObject +#include "numpy/numpyconfig.h" // for NPY_VISIBILITY_HIDDEN + +#ifdef __cplusplus +extern "C" { +#endif + +enum npy_cpu_features +{ + NPY_CPU_FEATURE_NONE = 0, + // X86 + NPY_CPU_FEATURE_MMX = 1, + NPY_CPU_FEATURE_SSE = 2, + NPY_CPU_FEATURE_SSE2 = 3, + NPY_CPU_FEATURE_SSE3 = 4, + NPY_CPU_FEATURE_SSSE3 = 5, + NPY_CPU_FEATURE_SSE41 = 6, + NPY_CPU_FEATURE_POPCNT = 7, + NPY_CPU_FEATURE_SSE42 = 8, + NPY_CPU_FEATURE_AVX = 9, + NPY_CPU_FEATURE_F16C = 10, + NPY_CPU_FEATURE_XOP = 11, + NPY_CPU_FEATURE_FMA4 = 12, + NPY_CPU_FEATURE_FMA3 = 13, + NPY_CPU_FEATURE_AVX2 = 14, + NPY_CPU_FEATURE_FMA = 15, // AVX2 & FMA3, provides backward compatibility + + NPY_CPU_FEATURE_AVX512F = 30, + NPY_CPU_FEATURE_AVX512CD = 31, + NPY_CPU_FEATURE_AVX512ER = 32, + NPY_CPU_FEATURE_AVX512PF = 33, + NPY_CPU_FEATURE_AVX5124FMAPS = 34, + NPY_CPU_FEATURE_AVX5124VNNIW = 35, + NPY_CPU_FEATURE_AVX512VPOPCNTDQ = 36, + NPY_CPU_FEATURE_AVX512BW = 37, + NPY_CPU_FEATURE_AVX512DQ = 38, + NPY_CPU_FEATURE_AVX512VL = 39, + NPY_CPU_FEATURE_AVX512IFMA = 40, + NPY_CPU_FEATURE_AVX512VBMI = 41, + NPY_CPU_FEATURE_AVX512VNNI = 42, + NPY_CPU_FEATURE_AVX512VBMI2 = 43, + NPY_CPU_FEATURE_AVX512BITALG = 44, + NPY_CPU_FEATURE_AVX512FP16 = 45, + + // X86 CPU Groups + // Knights Landing (F,CD,ER,PF) + NPY_CPU_FEATURE_AVX512_KNL = 101, + // Knights Mill (F,CD,ER,PF,4FMAPS,4VNNIW,VPOPCNTDQ) + NPY_CPU_FEATURE_AVX512_KNM = 102, + // Skylake-X (F,CD,BW,DQ,VL) + NPY_CPU_FEATURE_AVX512_SKX = 103, + // Cascade Lake (F,CD,BW,DQ,VL,VNNI) + NPY_CPU_FEATURE_AVX512_CLX = 104, + // Cannon Lake (F,CD,BW,DQ,VL,IFMA,VBMI) + NPY_CPU_FEATURE_AVX512_CNL = 105, + // Ice Lake (F,CD,BW,DQ,VL,IFMA,VBMI,VNNI,VBMI2,BITALG,VPOPCNTDQ) + NPY_CPU_FEATURE_AVX512_ICL = 106, + // Sapphire Rapids (Ice Lake, AVX512FP16) + NPY_CPU_FEATURE_AVX512_SPR = 107, + + // IBM/POWER VSX + // POWER7 + NPY_CPU_FEATURE_VSX = 200, + // POWER8 + NPY_CPU_FEATURE_VSX2 = 201, + // POWER9 + NPY_CPU_FEATURE_VSX3 = 202, + // POWER10 + NPY_CPU_FEATURE_VSX4 = 203, + + // ARM + NPY_CPU_FEATURE_NEON = 300, + NPY_CPU_FEATURE_NEON_FP16 = 301, + // FMA + NPY_CPU_FEATURE_NEON_VFPV4 = 302, + // Advanced SIMD + NPY_CPU_FEATURE_ASIMD = 303, + // ARMv8.2 half-precision + 
NPY_CPU_FEATURE_FPHP = 304, + // ARMv8.2 half-precision vector arithm + NPY_CPU_FEATURE_ASIMDHP = 305, + // ARMv8.2 dot product + NPY_CPU_FEATURE_ASIMDDP = 306, + // ARMv8.2 single&half-precision multiply + NPY_CPU_FEATURE_ASIMDFHM = 307, + // Scalable Vector Extensions (SVE) + NPY_CPU_FEATURE_SVE = 308, + + // IBM/ZARCH + NPY_CPU_FEATURE_VX = 350, + + // Vector-Enhancements Facility 1 + NPY_CPU_FEATURE_VXE = 351, + + // Vector-Enhancements Facility 2 + NPY_CPU_FEATURE_VXE2 = 352, + + NPY_CPU_FEATURE_MAX +}; + +/* + * Initialize CPU features + * + * This function + * - detects runtime CPU features + * - check that baseline CPU features are present + * - uses 'NPY_DISABLE_CPU_FEATURES' to disable dispatchable features + * - uses 'NPY_ENABLE_CPU_FEATURES' to enable dispatchable features + * + * It will set a RuntimeError when + * - CPU baseline features from the build are not supported at runtime + * - 'NPY_DISABLE_CPU_FEATURES' tries to disable a baseline feature + * - 'NPY_DISABLE_CPU_FEATURES' and 'NPY_ENABLE_CPU_FEATURES' are + * simultaneously set + * - 'NPY_ENABLE_CPU_FEATURES' tries to enable a feature that is not supported + * by the machine or build + * - 'NPY_ENABLE_CPU_FEATURES' tries to enable a feature when the project was + * not built with any feature optimization support + * + * It will set an ImportWarning when: + * - 'NPY_DISABLE_CPU_FEATURES' tries to disable a feature that is not supported + * by the machine or build + * - 'NPY_DISABLE_CPU_FEATURES' or 'NPY_ENABLE_CPU_FEATURES' tries to + * disable/enable a feature when the project was not built with any feature + * optimization support + * + * return 0 on success otherwise return -1 + */ +NPY_VISIBILITY_HIDDEN int +npy_cpu_init(void); + +/* + * return 0 if CPU feature isn't available + * note: `npy_cpu_init` must be called first otherwise it will always return 0 +*/ +NPY_VISIBILITY_HIDDEN int +npy_cpu_have(int feature_id); + +#define NPY_CPU_HAVE(FEATURE_NAME) \ +npy_cpu_have(NPY_CPU_FEATURE_##FEATURE_NAME) + +/* + * return a new dictionary contains CPU feature names + * with runtime availability. + * same as npy_cpu_have, `npy_cpu_init` must be called first. + */ +NPY_VISIBILITY_HIDDEN PyObject * +npy_cpu_features_dict(void); +/* + * Return a new a Python list contains the minimal set of required optimizations + * that supported by the compiler and platform according to the specified + * values to command argument '--cpu-baseline'. + * + * This function is mainly used to implement umath's attribute '__cpu_baseline__', + * and the items are sorted from the lowest to highest interest. + * + * For example, according to the default build configuration and by assuming the compiler + * support all the involved optimizations then the returned list should equivalent to: + * + * On x86: ['SSE', 'SSE2'] + * On x64: ['SSE', 'SSE2', 'SSE3'] + * On armhf: [] + * On aarch64: ['NEON', 'NEON_FP16', 'NEON_VPFV4', 'ASIMD'] + * On ppc64: [] + * On ppc64le: ['VSX', 'VSX2'] + * On s390x: [] + * On any other arch or if the optimization is disabled: [] + */ +NPY_VISIBILITY_HIDDEN PyObject * +npy_cpu_baseline_list(void); +/* + * Return a new a Python list contains the dispatched set of additional optimizations + * that supported by the compiler and platform according to the specified + * values to command argument '--cpu-dispatch'. + * + * This function is mainly used to implement umath's attribute '__cpu_dispatch__', + * and the items are sorted from the lowest to highest interest. 
+ * + * For example, according to the default build configuration and by assuming the compiler + * support all the involved optimizations then the returned list should equivalent to: + * + * On x86: ['SSE3', 'SSSE3', 'SSE41', 'POPCNT', 'SSE42', 'AVX', 'F16C', 'FMA3', 'AVX2', 'AVX512F', ...] + * On x64: ['SSSE3', 'SSE41', 'POPCNT', 'SSE42', 'AVX', 'F16C', 'FMA3', 'AVX2', 'AVX512F', ...] + * On armhf: ['NEON', 'NEON_FP16', 'NEON_VPFV4', 'ASIMD', 'ASIMDHP', 'ASIMDDP', 'ASIMDFHM'] + * On aarch64: ['ASIMDHP', 'ASIMDDP', 'ASIMDFHM'] + * On ppc64: ['VSX', 'VSX2', 'VSX3', 'VSX4'] + * On ppc64le: ['VSX3', 'VSX4'] + * On s390x: ['VX', 'VXE', VXE2] + * On any other arch or if the optimization is disabled: [] + */ +NPY_VISIBILITY_HIDDEN PyObject * +npy_cpu_dispatch_list(void); + +#ifdef __cplusplus +} +#endif + +#endif // NUMPY_CORE_SRC_COMMON_NPY_CPU_FEATURES_H_ diff --git a/mkl_umath/src/npyv/npy_cpuinfo_parser.h b/mkl_umath/src/npyv/npy_cpuinfo_parser.h new file mode 100644 index 00000000..154c4245 --- /dev/null +++ b/mkl_umath/src/npyv/npy_cpuinfo_parser.h @@ -0,0 +1,263 @@ +/* + * Copyright (C) 2010 The Android Open Source Project + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS + * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE + * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, + * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS + * OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED + * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT + * OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + */ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_CPUINFO_PARSER_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_CPUINFO_PARSER_H_ +#include +#include +#include +#include +#include + +#define NPY__HWCAP 16 +#define NPY__HWCAP2 26 + +// arch/arm/include/uapi/asm/hwcap.h +#define NPY__HWCAP_HALF (1 << 1) +#define NPY__HWCAP_NEON (1 << 12) +#define NPY__HWCAP_VFPv3 (1 << 13) +#define NPY__HWCAP_VFPv4 (1 << 16) +#define NPY__HWCAP2_AES (1 << 0) +#define NPY__HWCAP2_PMULL (1 << 1) +#define NPY__HWCAP2_SHA1 (1 << 2) +#define NPY__HWCAP2_SHA2 (1 << 3) +#define NPY__HWCAP2_CRC32 (1 << 4) +// arch/arm64/include/uapi/asm/hwcap.h +#define NPY__HWCAP_FP (1 << 0) +#define NPY__HWCAP_ASIMD (1 << 1) +#define NPY__HWCAP_FPHP (1 << 9) +#define NPY__HWCAP_ASIMDHP (1 << 10) +#define NPY__HWCAP_ASIMDDP (1 << 20) +#define NPY__HWCAP_SVE (1 << 22) +#define NPY__HWCAP_ASIMDFHM (1 << 23) +/* + * Get the size of a file by reading it until the end. This is needed + * because files under /proc do not always return a valid size when + * using fseek(0, SEEK_END) + ftell(). 
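A minimal sketch (not part of the patch) of how the runtime-detection API declared in npy_cpu_features.h above is typically consumed at module initialization; the kernel names are placeholders and the AVX2 kernel is assumed to be compiled in a separate AVX2-enabled translation unit:

    #include "numpy/npy_common.h"
    #include "npy_cpu_features.h"

    /* portable fallback kernel */
    static void add_f32_scalar(const float *a, const float *b, float *r, npy_intp n)
    { for (npy_intp i = 0; i < n; ++i) r[i] = a[i] + b[i]; }

    /* assumed to exist in an AVX2-enabled translation unit (placeholder) */
    void add_f32_avx2(const float *a, const float *b, float *r, npy_intp n);

    static void (*add_f32_impl)(const float*, const float*, float*, npy_intp) = add_f32_scalar;

    static int
    select_add_impl(void)
    {
        if (npy_cpu_init() < 0) {            /* detect features, validate the baseline */
            return -1;                       /* a Python error is already set */
        }
        if (NPY_CPU_HAVE(AVX2) && NPY_CPU_HAVE(FMA3)) {
            add_f32_impl = add_f32_avx2;     /* take the SIMD path only when supported */
        }
        return 0;
    }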
Nor can they be mmap()-ed. + */ +static int +get_file_size(const char* pathname) +{ + int fd, result = 0; + char buffer[256]; + + fd = open(pathname, O_RDONLY); + if (fd < 0) { + return -1; + } + + for (;;) { + int ret = read(fd, buffer, sizeof buffer); + if (ret < 0) { + if (errno == EINTR) { + continue; + } + break; + } + if (ret == 0) { + break; + } + result += ret; + } + close(fd); + return result; +} + +/* + * Read the content of /proc/cpuinfo into a user-provided buffer. + * Return the length of the data, or -1 on error. Does *not* + * zero-terminate the content. Will not read more + * than 'buffsize' bytes. + */ +static int +read_file(const char* pathname, char* buffer, size_t buffsize) +{ + int fd, count; + + fd = open(pathname, O_RDONLY); + if (fd < 0) { + return -1; + } + count = 0; + while (count < (int)buffsize) { + int ret = read(fd, buffer + count, buffsize - count); + if (ret < 0) { + if (errno == EINTR) { + continue; + } + if (count == 0) { + count = -1; + } + break; + } + if (ret == 0) { + break; + } + count += ret; + } + close(fd); + return count; +} + +/* + * Extract the content of a the first occurrence of a given field in + * the content of /proc/cpuinfo and return it as a heap-allocated + * string that must be freed by the caller. + * + * Return NULL if not found + */ +static char* +extract_cpuinfo_field(const char* buffer, int buflen, const char* field) +{ + int fieldlen = strlen(field); + const char* bufend = buffer + buflen; + char* result = NULL; + int len; + const char *p, *q; + + /* Look for first field occurrence, and ensures it starts the line. */ + p = buffer; + for (;;) { + p = memmem(p, bufend-p, field, fieldlen); + if (p == NULL) { + goto EXIT; + } + + if (p == buffer || p[-1] == '\n') { + break; + } + + p += fieldlen; + } + + /* Skip to the first column followed by a space */ + p += fieldlen; + p = memchr(p, ':', bufend-p); + if (p == NULL || p[1] != ' ') { + goto EXIT; + } + + /* Find the end of the line */ + p += 2; + q = memchr(p, '\n', bufend-p); + if (q == NULL) { + q = bufend; + } + + /* Copy the line into a heap-allocated buffer */ + len = q - p; + result = malloc(len + 1); + if (result == NULL) { + goto EXIT; + } + + memcpy(result, p, len); + result[len] = '\0'; + +EXIT: + return result; +} + +/* + * Checks that a space-separated list of items contains one given 'item'. + * Returns 1 if found, 0 otherwise. + */ +static int +has_list_item(const char* list, const char* item) +{ + const char* p = list; + int itemlen = strlen(item); + + if (list == NULL) { + return 0; + } + + while (*p) { + const char* q; + + /* skip spaces */ + while (*p == ' ' || *p == '\t') { + p++; + } + + /* find end of current list item */ + q = p; + while (*q && *q != ' ' && *q != '\t') { + q++; + } + + if (itemlen == q-p && !memcmp(p, item, itemlen)) { + return 1; + } + + /* skip to next item */ + p = q; + } + return 0; +} + +static void setHwcap(char* cpuFeatures, unsigned long* hwcap) { + *hwcap |= has_list_item(cpuFeatures, "neon") ? NPY__HWCAP_NEON : 0; + *hwcap |= has_list_item(cpuFeatures, "half") ? NPY__HWCAP_HALF : 0; + *hwcap |= has_list_item(cpuFeatures, "vfpv3") ? NPY__HWCAP_VFPv3 : 0; + *hwcap |= has_list_item(cpuFeatures, "vfpv4") ? NPY__HWCAP_VFPv4 : 0; + + *hwcap |= has_list_item(cpuFeatures, "asimd") ? NPY__HWCAP_ASIMD : 0; + *hwcap |= has_list_item(cpuFeatures, "fp") ? NPY__HWCAP_FP : 0; + *hwcap |= has_list_item(cpuFeatures, "fphp") ? NPY__HWCAP_FPHP : 0; + *hwcap |= has_list_item(cpuFeatures, "asimdhp") ? 
NPY__HWCAP_ASIMDHP : 0; + *hwcap |= has_list_item(cpuFeatures, "asimddp") ? NPY__HWCAP_ASIMDDP : 0; + *hwcap |= has_list_item(cpuFeatures, "asimdfhm") ? NPY__HWCAP_ASIMDFHM : 0; +} + +static int +get_feature_from_proc_cpuinfo(unsigned long *hwcap, unsigned long *hwcap2) { + char* cpuinfo = NULL; + int cpuinfo_len; + cpuinfo_len = get_file_size("/proc/cpuinfo"); + if (cpuinfo_len < 0) { + return 0; + } + cpuinfo = malloc(cpuinfo_len); + if (cpuinfo == NULL) { + return 0; + } + cpuinfo_len = read_file("/proc/cpuinfo", cpuinfo, cpuinfo_len); + char* cpuFeatures = extract_cpuinfo_field(cpuinfo, cpuinfo_len, "Features"); + if(cpuFeatures == NULL) { + return 0; + } + setHwcap(cpuFeatures, hwcap); + *hwcap2 |= *hwcap; + *hwcap2 |= has_list_item(cpuFeatures, "aes") ? NPY__HWCAP2_AES : 0; + *hwcap2 |= has_list_item(cpuFeatures, "pmull") ? NPY__HWCAP2_PMULL : 0; + *hwcap2 |= has_list_item(cpuFeatures, "sha1") ? NPY__HWCAP2_SHA1 : 0; + *hwcap2 |= has_list_item(cpuFeatures, "sha2") ? NPY__HWCAP2_SHA2 : 0; + *hwcap2 |= has_list_item(cpuFeatures, "crc32") ? NPY__HWCAP2_CRC32 : 0; + return 1; +} +#endif /* NUMPY_CORE_SRC_COMMON_NPY_CPUINFO_PARSER_H_ */ diff --git a/mkl_umath/src/npyv/npy_ctypes.h b/mkl_umath/src/npyv/npy_ctypes.h new file mode 100644 index 00000000..578de063 --- /dev/null +++ b/mkl_umath/src/npyv/npy_ctypes.h @@ -0,0 +1,50 @@ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_CTYPES_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_CTYPES_H_ + +#include + +#include "npy_import.h" + +/* + * Check if a python type is a ctypes class. + * + * Works like the Py_Check functions, returning true if the argument + * looks like a ctypes object. + * + * This entire function is just a wrapper around the Python function of the + * same name. + */ +static inline int +npy_ctypes_check(PyTypeObject *obj) +{ + static PyObject *py_func = NULL; + PyObject *ret_obj; + int ret; + + npy_cache_import("numpy._core._internal", "npy_ctypes_check", &py_func); + if (py_func == NULL) { + goto fail; + } + + ret_obj = PyObject_CallFunctionObjArgs(py_func, (PyObject *)obj, NULL); + if (ret_obj == NULL) { + goto fail; + } + + ret = PyObject_IsTrue(ret_obj); + Py_DECREF(ret_obj); + if (ret == -1) { + goto fail; + } + + return ret; + +fail: + /* If the above fails, then we should just assume that the type is not from + * ctypes + */ + PyErr_Clear(); + return 0; +} + +#endif /* NUMPY_CORE_SRC_COMMON_NPY_CTYPES_H_ */ diff --git a/mkl_umath/src/npyv/npy_dlpack.h b/mkl_umath/src/npyv/npy_dlpack.h new file mode 100644 index 00000000..cb926a26 --- /dev/null +++ b/mkl_umath/src/npyv/npy_dlpack.h @@ -0,0 +1,28 @@ +#include "Python.h" +#include "dlpack/dlpack.h" + +#ifndef NPY_DLPACK_H +#define NPY_DLPACK_H + +// Part of the Array API specification. +#define NPY_DLPACK_CAPSULE_NAME "dltensor" +#define NPY_DLPACK_USED_CAPSULE_NAME "used_dltensor" + +// Used internally by NumPy to store a base object +// as it has to release a reference to the original +// capsule. 
+#define NPY_DLPACK_INTERNAL_CAPSULE_NAME "numpy_dltensor" + +PyObject * +array_dlpack(PyArrayObject *self, PyObject *const *args, Py_ssize_t len_args, + PyObject *kwnames); + + +PyObject * +array_dlpack_device(PyArrayObject *self, PyObject *NPY_UNUSED(args)); + + +NPY_NO_EXPORT PyObject * +from_dlpack(PyObject *NPY_UNUSED(self), PyObject *obj); + +#endif diff --git a/mkl_umath/src/npyv/npy_extint128.h b/mkl_umath/src/npyv/npy_extint128.h new file mode 100644 index 00000000..776d71c7 --- /dev/null +++ b/mkl_umath/src/npyv/npy_extint128.h @@ -0,0 +1,317 @@ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_EXTINT128_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_EXTINT128_H_ + + +typedef struct { + signed char sign; + npy_uint64 lo, hi; +} npy_extint128_t; + + +/* Integer addition with overflow checking */ +static inline npy_int64 +safe_add(npy_int64 a, npy_int64 b, char *overflow_flag) +{ + if (a > 0 && b > NPY_MAX_INT64 - a) { + *overflow_flag = 1; + } + else if (a < 0 && b < NPY_MIN_INT64 - a) { + *overflow_flag = 1; + } + return a + b; +} + + +/* Integer subtraction with overflow checking */ +static inline npy_int64 +safe_sub(npy_int64 a, npy_int64 b, char *overflow_flag) +{ + if (a >= 0 && b < a - NPY_MAX_INT64) { + *overflow_flag = 1; + } + else if (a < 0 && b > a - NPY_MIN_INT64) { + *overflow_flag = 1; + } + return a - b; +} + + +/* Integer multiplication with overflow checking */ +static inline npy_int64 +safe_mul(npy_int64 a, npy_int64 b, char *overflow_flag) +{ + if (a > 0) { + if (b > NPY_MAX_INT64 / a || b < NPY_MIN_INT64 / a) { + *overflow_flag = 1; + } + } + else if (a < 0) { + if (b > 0 && a < NPY_MIN_INT64 / b) { + *overflow_flag = 1; + } + else if (b < 0 && a < NPY_MAX_INT64 / b) { + *overflow_flag = 1; + } + } + return a * b; +} + + +/* Long integer init */ +static inline npy_extint128_t +to_128(npy_int64 x) +{ + npy_extint128_t result; + result.sign = (x >= 0 ? 
1 : -1); + if (x >= 0) { + result.lo = x; + } + else { + result.lo = (npy_uint64)(-(x + 1)) + 1; + } + result.hi = 0; + return result; +} + + +static inline npy_int64 +to_64(npy_extint128_t x, char *overflow) +{ + if (x.hi != 0 || + (x.sign > 0 && x.lo > NPY_MAX_INT64) || + (x.sign < 0 && x.lo != 0 && x.lo - 1 > -(NPY_MIN_INT64 + 1))) { + *overflow = 1; + } + return x.lo * x.sign; +} + + +/* Long integer multiply */ +static inline npy_extint128_t +mul_64_64(npy_int64 a, npy_int64 b) +{ + npy_extint128_t x, y, z; + npy_uint64 x1, x2, y1, y2, r1, r2, prev; + + x = to_128(a); + y = to_128(b); + + x1 = x.lo & 0xffffffff; + x2 = x.lo >> 32; + + y1 = y.lo & 0xffffffff; + y2 = y.lo >> 32; + + r1 = x1*y2; + r2 = x2*y1; + + z.sign = x.sign * y.sign; + z.hi = x2*y2 + (r1 >> 32) + (r2 >> 32); + z.lo = x1*y1; + + /* Add with carry */ + prev = z.lo; + z.lo += (r1 << 32); + if (z.lo < prev) { + ++z.hi; + } + + prev = z.lo; + z.lo += (r2 << 32); + if (z.lo < prev) { + ++z.hi; + } + + return z; +} + + +/* Long integer add */ +static inline npy_extint128_t +add_128(npy_extint128_t x, npy_extint128_t y, char *overflow) +{ + npy_extint128_t z; + + if (x.sign == y.sign) { + z.sign = x.sign; + z.hi = x.hi + y.hi; + if (z.hi < x.hi) { + *overflow = 1; + } + z.lo = x.lo + y.lo; + if (z.lo < x.lo) { + if (z.hi == NPY_MAX_UINT64) { + *overflow = 1; + } + ++z.hi; + } + } + else if (x.hi > y.hi || (x.hi == y.hi && x.lo >= y.lo)) { + z.sign = x.sign; + z.hi = x.hi - y.hi; + z.lo = x.lo; + z.lo -= y.lo; + if (z.lo > x.lo) { + --z.hi; + } + } + else { + z.sign = y.sign; + z.hi = y.hi - x.hi; + z.lo = y.lo; + z.lo -= x.lo; + if (z.lo > y.lo) { + --z.hi; + } + } + + return z; +} + + +/* Long integer negation */ +static inline npy_extint128_t +neg_128(npy_extint128_t x) +{ + npy_extint128_t z = x; + z.sign *= -1; + return z; +} + + +static inline npy_extint128_t +sub_128(npy_extint128_t x, npy_extint128_t y, char *overflow) +{ + return add_128(x, neg_128(y), overflow); +} + + +static inline npy_extint128_t +shl_128(npy_extint128_t v) +{ + npy_extint128_t z; + z = v; + z.hi <<= 1; + z.hi |= (z.lo & (((npy_uint64)1) << 63)) >> 63; + z.lo <<= 1; + return z; +} + + +static inline npy_extint128_t +shr_128(npy_extint128_t v) +{ + npy_extint128_t z; + z = v; + z.lo >>= 1; + z.lo |= (z.hi & 0x1) << 63; + z.hi >>= 1; + return z; +} + +static inline int +gt_128(npy_extint128_t a, npy_extint128_t b) +{ + if (a.sign > 0 && b.sign > 0) { + return (a.hi > b.hi) || (a.hi == b.hi && a.lo > b.lo); + } + else if (a.sign < 0 && b.sign < 0) { + return (a.hi < b.hi) || (a.hi == b.hi && a.lo < b.lo); + } + else if (a.sign > 0 && b.sign < 0) { + return a.hi != 0 || a.lo != 0 || b.hi != 0 || b.lo != 0; + } + else { + return 0; + } +} + + +/* Long integer divide */ +static inline npy_extint128_t +divmod_128_64(npy_extint128_t x, npy_int64 b, npy_int64 *mod) +{ + npy_extint128_t remainder, pointer, result, divisor; + char overflow = 0; + + assert(b > 0); + + if (b <= 1 || x.hi == 0) { + result.sign = x.sign; + result.lo = x.lo / b; + result.hi = x.hi / b; + *mod = x.sign * (x.lo % b); + return result; + } + + /* Long division, not the most efficient choice */ + remainder = x; + remainder.sign = 1; + + divisor.sign = 1; + divisor.hi = 0; + divisor.lo = b; + + result.sign = 1; + result.lo = 0; + result.hi = 0; + + pointer.sign = 1; + pointer.lo = 1; + pointer.hi = 0; + + while ((divisor.hi & (((npy_uint64)1) << 63)) == 0 && + gt_128(remainder, divisor)) { + divisor = shl_128(divisor); + pointer = shl_128(pointer); + } + + while (pointer.lo || 
pointer.hi) { + if (!gt_128(divisor, remainder)) { + remainder = sub_128(remainder, divisor, &overflow); + result = add_128(result, pointer, &overflow); + } + divisor = shr_128(divisor); + pointer = shr_128(pointer); + } + + /* Fix signs and return; cannot overflow */ + result.sign = x.sign; + *mod = x.sign * remainder.lo; + + return result; +} + + +/* Divide and round down (positive divisor; no overflows) */ +static inline npy_extint128_t +floordiv_128_64(npy_extint128_t a, npy_int64 b) +{ + npy_extint128_t result; + npy_int64 remainder; + char overflow = 0; + assert(b > 0); + result = divmod_128_64(a, b, &remainder); + if (a.sign < 0 && remainder != 0) { + result = sub_128(result, to_128(1), &overflow); + } + return result; +} + + +/* Divide and round up (positive divisor; no overflows) */ +static inline npy_extint128_t +ceildiv_128_64(npy_extint128_t a, npy_int64 b) +{ + npy_extint128_t result; + npy_int64 remainder; + char overflow = 0; + assert(b > 0); + result = divmod_128_64(a, b, &remainder); + if (a.sign > 0 && remainder != 0) { + result = add_128(result, to_128(1), &overflow); + } + return result; +} + +#endif /* NUMPY_CORE_SRC_COMMON_NPY_EXTINT128_H_ */ diff --git a/mkl_umath/src/npyv/npy_fpmath.h b/mkl_umath/src/npyv/npy_fpmath.h new file mode 100644 index 00000000..27e9ea3f --- /dev/null +++ b/mkl_umath/src/npyv/npy_fpmath.h @@ -0,0 +1,30 @@ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_NPY_FPMATH_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_NPY_FPMATH_H_ + +#include "npy_config.h" + +#include "numpy/npy_os.h" +#include "numpy/npy_cpu.h" +#include "numpy/npy_common.h" + +#if !(defined(HAVE_LDOUBLE_IEEE_QUAD_BE) || \ + defined(HAVE_LDOUBLE_IEEE_QUAD_LE) || \ + defined(HAVE_LDOUBLE_IEEE_DOUBLE_LE) || \ + defined(HAVE_LDOUBLE_IEEE_DOUBLE_BE) || \ + defined(HAVE_LDOUBLE_INTEL_EXTENDED_16_BYTES_LE) || \ + defined(HAVE_LDOUBLE_INTEL_EXTENDED_12_BYTES_LE) || \ + defined(HAVE_LDOUBLE_MOTOROLA_EXTENDED_12_BYTES_BE) || \ + defined(HAVE_LDOUBLE_IBM_DOUBLE_DOUBLE_BE) || \ + defined(HAVE_LDOUBLE_IBM_DOUBLE_DOUBLE_LE)) + #error No long double representation defined +#endif + +/* for back-compat, also keep old name for double-double */ +#ifdef HAVE_LDOUBLE_IBM_DOUBLE_DOUBLE_LE + #define HAVE_LDOUBLE_DOUBLE_DOUBLE_LE +#endif +#ifdef HAVE_LDOUBLE_IBM_DOUBLE_DOUBLE_BE + #define HAVE_LDOUBLE_DOUBLE_DOUBLE_BE +#endif + +#endif /* NUMPY_CORE_SRC_COMMON_NPY_NPY_FPMATH_H_ */ diff --git a/mkl_umath/src/npyv/npy_hashtable.h b/mkl_umath/src/npyv/npy_hashtable.h new file mode 100644 index 00000000..a0bf8196 --- /dev/null +++ b/mkl_umath/src/npyv/npy_hashtable.h @@ -0,0 +1,32 @@ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_NPY_HASHTABLE_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_NPY_HASHTABLE_H_ + +#include + +#define NPY_NO_DEPRECATED_API NPY_API_VERSION +#include "numpy/ndarraytypes.h" + + +typedef struct { + int key_len; /* number of identities used */ + /* Buckets stores: val1, key1[0], key1[1], ..., val2, key2[0], ... 
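A minimal sketch (not part of the patch) combining the npy_extint128.h helpers above to evaluate a*b + c exactly in 128-bit precision and fold the result back to 64 bits; the function name is illustrative:

    #include "numpy/npy_common.h"
    #include "npy_extint128.h"

    /* caller initializes *overflow to 0; it is set if any step does not fit */
    static npy_int64
    mul_add_64(npy_int64 a, npy_int64 b, npy_int64 c, char *overflow)
    {
        npy_extint128_t prod = mul_64_64(a, b);                 /* exact 128-bit product */
        npy_extint128_t sum  = add_128(prod, to_128(c), overflow);
        return to_64(sum, overflow);                            /* flags truncation to 64 bits */
    }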
*/ + PyObject **buckets; + npy_intp size; /* current size */ + npy_intp nelem; /* number of elements */ +} PyArrayIdentityHash; + + +NPY_NO_EXPORT int +PyArrayIdentityHash_SetItem(PyArrayIdentityHash *tb, + PyObject *const *key, PyObject *value, int replace); + +NPY_NO_EXPORT PyObject * +PyArrayIdentityHash_GetItem(PyArrayIdentityHash const *tb, PyObject *const *key); + +NPY_NO_EXPORT PyArrayIdentityHash * +PyArrayIdentityHash_New(int key_len); + +NPY_NO_EXPORT void +PyArrayIdentityHash_Dealloc(PyArrayIdentityHash *tb); + +#endif /* NUMPY_CORE_SRC_COMMON_NPY_NPY_HASHTABLE_H_ */ diff --git a/mkl_umath/src/npyv/npy_import.h b/mkl_umath/src/npyv/npy_import.h new file mode 100644 index 00000000..58b4ba0b --- /dev/null +++ b/mkl_umath/src/npyv/npy_import.h @@ -0,0 +1,32 @@ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_IMPORT_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_IMPORT_H_ + +#include + +/*! \brief Fetch and cache Python function. + * + * Import a Python function and cache it for use. The function checks if + * cache is NULL, and if not NULL imports the Python function specified by + * \a module and \a function, increments its reference count, and stores + * the result in \a cache. Usually \a cache will be a static variable and + * should be initialized to NULL. On error \a cache will contain NULL on + * exit, + * + * @param module Absolute module name. + * @param attr module attribute to cache. + * @param cache Storage location for imported function. + */ +static inline void +npy_cache_import(const char *module, const char *attr, PyObject **cache) +{ + if (NPY_UNLIKELY(*cache == NULL)) { + PyObject *mod = PyImport_ImportModule(module); + + if (mod != NULL) { + *cache = PyObject_GetAttrString(mod, attr); + Py_DECREF(mod); + } + } +} + +#endif /* NUMPY_CORE_SRC_COMMON_NPY_IMPORT_H_ */ diff --git a/mkl_umath/src/npyv/npy_longdouble.h b/mkl_umath/src/npyv/npy_longdouble.h new file mode 100644 index 00000000..cf8b37bc --- /dev/null +++ b/mkl_umath/src/npyv/npy_longdouble.h @@ -0,0 +1,27 @@ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_LONGDOUBLE_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_LONGDOUBLE_H_ + +#include "npy_config.h" +#include "numpy/ndarraytypes.h" + +/* Convert a npy_longdouble to a python `long` integer. + * + * Results are rounded towards zero. + * + * This performs the same task as PyLong_FromDouble, but for long doubles + * which have a greater range. + */ +NPY_VISIBILITY_HIDDEN PyObject * +npy_longdouble_to_PyLong(npy_longdouble ldval); + +/* Convert a python `long` integer to a npy_longdouble + * + * This performs the same task as PyLong_AsDouble, but for long doubles + * which have a greater range. + * + * Returns -1 if an error occurs. 
+ */ +NPY_VISIBILITY_HIDDEN npy_longdouble +npy_longdouble_from_PyLong(PyObject *long_obj); + +#endif /* NUMPY_CORE_SRC_COMMON_NPY_LONGDOUBLE_H_ */ diff --git a/mkl_umath/src/npyv/npy_partition.h b/mkl_umath/src/npyv/npy_partition.h new file mode 100644 index 00000000..57f834d9 --- /dev/null +++ b/mkl_umath/src/npyv/npy_partition.h @@ -0,0 +1,34 @@ +#ifndef NUMPY_CORE_SRC_COMMON_PARTITION_H_ +#define NUMPY_CORE_SRC_COMMON_PARTITION_H_ + +#include "npy_sort.h" + +/* Python include is for future object sorts */ +#include + +#include +#include + +#define NPY_MAX_PIVOT_STACK 50 + +typedef int (PyArray_PartitionFunc)(void *, npy_intp, npy_intp, + npy_intp *, npy_intp *, npy_intp, + void *); +typedef int (PyArray_ArgPartitionFunc)(void *, npy_intp *, npy_intp, npy_intp, + npy_intp *, npy_intp *, npy_intp, + void *); +#ifdef __cplusplus +extern "C" { +#endif + +NPY_NO_EXPORT PyArray_PartitionFunc * +get_partition_func(int type, NPY_SELECTKIND which); + +NPY_NO_EXPORT PyArray_ArgPartitionFunc * +get_argpartition_func(int type, NPY_SELECTKIND which); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/mkl_umath/src/npyv/npy_pycompat.h b/mkl_umath/src/npyv/npy_pycompat.h new file mode 100644 index 00000000..ce6c34fa --- /dev/null +++ b/mkl_umath/src/npyv/npy_pycompat.h @@ -0,0 +1,22 @@ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_PYCOMPAT_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_PYCOMPAT_H_ + +#include "numpy/npy_3kcompat.h" + + +/* + * In Python 3.10a7 (or b1), python started using the identity for the hash + * when a value is NaN. See https://bugs.python.org/issue43475 + */ +#if PY_VERSION_HEX > 0x030a00a6 +#define Npy_HashDouble _Py_HashDouble +#else +static inline Py_hash_t +Npy_HashDouble(PyObject *NPY_UNUSED(identity), double val) +{ + return _Py_HashDouble(val); +} +#endif + + +#endif /* NUMPY_CORE_SRC_COMMON_NPY_PYCOMPAT_H_ */ diff --git a/mkl_umath/src/npyv/npy_svml.h b/mkl_umath/src/npyv/npy_svml.h new file mode 100644 index 00000000..64e9ee53 --- /dev/null +++ b/mkl_umath/src/npyv/npy_svml.h @@ -0,0 +1,71 @@ +#if NPY_SIMD && defined(NPY_HAVE_AVX512_SPR) && defined(NPY_CAN_LINK_SVML) +extern void __svml_exps32(const npy_half*, npy_half*, npy_intp); +extern void __svml_exp2s32(const npy_half*, npy_half*, npy_intp); +extern void __svml_logs32(const npy_half*, npy_half*, npy_intp); +extern void __svml_log2s32(const npy_half*, npy_half*, npy_intp); +extern void __svml_log10s32(const npy_half*, npy_half*, npy_intp); +extern void __svml_expm1s32(const npy_half*, npy_half*, npy_intp); +extern void __svml_log1ps32(const npy_half*, npy_half*, npy_intp); +extern void __svml_cbrts32(const npy_half*, npy_half*, npy_intp); +extern void __svml_sins32(const npy_half*, npy_half*, npy_intp); +extern void __svml_coss32(const npy_half*, npy_half*, npy_intp); +extern void __svml_tans32(const npy_half*, npy_half*, npy_intp); +extern void __svml_asins32(const npy_half*, npy_half*, npy_intp); +extern void __svml_acoss32(const npy_half*, npy_half*, npy_intp); +extern void __svml_atans32(const npy_half*, npy_half*, npy_intp); +extern void __svml_atan2s32(const npy_half*, npy_half*, npy_intp); +extern void __svml_sinhs32(const npy_half*, npy_half*, npy_intp); +extern void __svml_coshs32(const npy_half*, npy_half*, npy_intp); +extern void __svml_tanhs32(const npy_half*, npy_half*, npy_intp); +extern void __svml_asinhs32(const npy_half*, npy_half*, npy_intp); +extern void __svml_acoshs32(const npy_half*, npy_half*, npy_intp); +extern void __svml_atanhs32(const npy_half*, npy_half*, npy_intp); +#endif + +#if NPY_SIMD && 
defined(NPY_HAVE_AVX512_SKX) && defined(NPY_CAN_LINK_SVML) +extern __m512 __svml_expf16(__m512 x); +extern __m512 __svml_exp2f16(__m512 x); +extern __m512 __svml_logf16(__m512 x); +extern __m512 __svml_log2f16(__m512 x); +extern __m512 __svml_log10f16(__m512 x); +extern __m512 __svml_expm1f16(__m512 x); +extern __m512 __svml_log1pf16(__m512 x); +extern __m512 __svml_cbrtf16(__m512 x); +extern __m512 __svml_sinf16(__m512 x); +extern __m512 __svml_cosf16(__m512 x); +extern __m512 __svml_tanf16(__m512 x); +extern __m512 __svml_asinf16(__m512 x); +extern __m512 __svml_acosf16(__m512 x); +extern __m512 __svml_atanf16(__m512 x); +extern __m512 __svml_atan2f16(__m512 x, __m512 y); +extern __m512 __svml_sinhf16(__m512 x); +extern __m512 __svml_coshf16(__m512 x); +extern __m512 __svml_tanhf16(__m512 x); +extern __m512 __svml_asinhf16(__m512 x); +extern __m512 __svml_acoshf16(__m512 x); +extern __m512 __svml_atanhf16(__m512 x); +extern __m512 __svml_powf16(__m512 x, __m512 y); + +extern __m512d __svml_exp8_ha(__m512d x); +extern __m512d __svml_exp28_ha(__m512d x); +extern __m512d __svml_log8_ha(__m512d x); +extern __m512d __svml_log28_ha(__m512d x); +extern __m512d __svml_log108_ha(__m512d x); +extern __m512d __svml_expm18_ha(__m512d x); +extern __m512d __svml_log1p8_ha(__m512d x); +extern __m512d __svml_cbrt8_ha(__m512d x); +extern __m512d __svml_sin8_ha(__m512d x); +extern __m512d __svml_cos8_ha(__m512d x); +extern __m512d __svml_tan8_ha(__m512d x); +extern __m512d __svml_asin8_ha(__m512d x); +extern __m512d __svml_acos8_ha(__m512d x); +extern __m512d __svml_atan8_ha(__m512d x); +extern __m512d __svml_atan28_ha(__m512d x, __m512d y); +extern __m512d __svml_sinh8_ha(__m512d x); +extern __m512d __svml_cosh8_ha(__m512d x); +extern __m512d __svml_tanh8_ha(__m512d x); +extern __m512d __svml_asinh8_ha(__m512d x); +extern __m512d __svml_acosh8_ha(__m512d x); +extern __m512d __svml_atanh8_ha(__m512d x); +extern __m512d __svml_pow8_ha(__m512d x, __m512d y); +#endif diff --git a/mkl_umath/src/npyv/npyv.h b/mkl_umath/src/npyv/npyv.h new file mode 100644 index 00000000..545ca48c --- /dev/null +++ b/mkl_umath/src/npyv/npyv.h @@ -0,0 +1,112 @@ +#ifndef MKL_UMATH_NPYV_H +#define MKL_UMATH_NPYV_H + +/* + * Simplified SIMD vectorization wrapper for mkl_umath + * Using direct AVX2 intrinsics instead of NumPy's complex npyv layer + * + * This is a proof-of-concept focusing on FLOAT add operation only. 
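A minimal sketch (not part of the patch) showing how the partial-lane helpers declared further down in npyv.h can handle a tail of fewer than npyv_nlanes_f32 elements without reading or writing past the buffers; the function name is illustrative:

    #if NPYV_CAN_VECTORIZE_FLOAT
    static void
    add_f32_tail(const float *a, const float *b, float *dst, npy_intp len)
    {
        npyv_f32 va = npyv_load_tillz_f32(a, len);             /* lanes >= len are zero-filled */
        npyv_f32 vb = npyv_load_tillz_f32(b, len);
        npyv_store_till_f32(dst, len, npyv_add_f32(va, vb));   /* stores only the first len lanes */
    }
    #endif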
+ */ + +#include "numpy/npy_common.h" +#include // AVX2 intrinsics + +/* + * Check if AVX2 is available at compile time + */ +#ifdef __AVX2__ + #define NPYV_CAN_VECTORIZE_FLOAT 1 + #define NPYV_CAN_VECTORIZE_DOUBLE 1 + + // AVX2 vector lanes + #define npyv_nlanes_f32 8 // 256 bits / 32 bits + #define npyv_nlanes_f64 4 // 256 bits / 64 bits + #define npyv_nlanes_u8 32 // 256 bits / 8 bits + + // Type definitions + typedef __m256 npyv_f32; + typedef __m256d npyv_f64; + + // Load operations + #define npyv_load_f32(ptr) _mm256_loadu_ps((const float*)(ptr)) + #define npyv_load_f64(ptr) _mm256_loadu_pd((const double*)(ptr)) + + // Store operations + #define npyv_store_f32(ptr, vec) _mm256_storeu_ps((float*)(ptr), (vec)) + #define npyv_store_f64(ptr, vec) _mm256_storeu_pd((double*)(ptr), (vec)) + + // Arithmetic operations + #define npyv_add_f32(a, b) _mm256_add_ps((a), (b)) + #define npyv_add_f64(a, b) _mm256_add_pd((a), (b)) + + // Set all lanes to same value + #define npyv_setall_f32(val) _mm256_set1_ps(val) + #define npyv_setall_f64(val) _mm256_set1_pd(val) + + // Conditional load/store (simplified - just use regular ops for prototype) + static inline npyv_f32 npyv_load_tillz_f32(const float *ptr, npy_intp len) { + if (len >= 8) return npyv_load_f32(ptr); + // For partial loads, create mask and zero out unused elements + float temp[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + for (npy_intp i = 0; i < len && i < 8; i++) temp[i] = ptr[i]; + return _mm256_loadu_ps(temp); + } + + static inline npyv_f64 npyv_load_tillz_f64(const double *ptr, npy_intp len) { + if (len >= 4) return npyv_load_f64(ptr); + double temp[4] = {0, 0, 0, 0}; + for (npy_intp i = 0; i < len && i < 4; i++) temp[i] = ptr[i]; + return _mm256_loadu_pd(temp); + } + + static inline void npyv_store_till_f32(float *ptr, npy_intp len, npyv_f32 vec) { + if (len >= 8) { + npyv_store_f32(ptr, vec); + } else { + float temp[8]; + npyv_store_f32(temp, vec); + for (npy_intp i = 0; i < len && i < 8; i++) ptr[i] = temp[i]; + } + } + + static inline void npyv_store_till_f64(double *ptr, npy_intp len, npyv_f64 vec) { + if (len >= 4) { + npyv_store_f64(ptr, vec); + } else { + double temp[4]; + npyv_store_f64(temp, vec); + for (npy_intp i = 0; i < len && i < 4; i++) ptr[i] = temp[i]; + } + } + + // Cleanup (no-op for AVX2) + #define npyv_cleanup() do {} while(0) + +#else + // SIMD not available + #define NPYV_CAN_VECTORIZE_FLOAT 0 + #define NPYV_CAN_VECTORIZE_DOUBLE 0 +#endif + +/* + * Memory overlap detection (copied from NumPy) + */ +static inline int +is_mem_overlap(const void *src, npy_intp src_step, + const void *dst, npy_intp dst_step, npy_intp len) +{ + const char *src_ptr = (const char *)src; + const char *dst_ptr = (const char *)dst; + + if (src_ptr == dst_ptr) { + return 0; // Same pointer, always safe + } + + npy_intp src_size = len * src_step; + npy_intp dst_size = len * dst_step; + + // Check if ranges overlap + return !((src_ptr + src_size <= dst_ptr) || (dst_ptr + dst_size <= src_ptr)); +} + +#endif /* MKL_UMATH_NPYV_H */ diff --git a/mkl_umath/src/npyv/numpy_tag.h b/mkl_umath/src/npyv/numpy_tag.h new file mode 100644 index 00000000..ee0c36ca --- /dev/null +++ b/mkl_umath/src/npyv/numpy_tag.h @@ -0,0 +1,259 @@ +#ifndef _NPY_COMMON_TAG_H_ +#define _NPY_COMMON_TAG_H_ + +#include "../npysort/npysort_common.h" + +namespace npy { + +template +struct taglist { + static constexpr unsigned size = sizeof...(tags); +}; + +struct integral_tag { +}; +struct floating_point_tag { +}; +struct complex_tag { +}; +struct date_tag { +}; + +struct bool_tag : 
integral_tag { + using type = npy_bool; + static constexpr NPY_TYPES type_value = NPY_BOOL; + static int less(type const& a, type const& b) { + return BOOL_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct byte_tag : integral_tag { + using type = npy_byte; + static constexpr NPY_TYPES type_value = NPY_BYTE; + static int less(type const& a, type const& b) { + return BYTE_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct ubyte_tag : integral_tag { + using type = npy_ubyte; + static constexpr NPY_TYPES type_value = NPY_UBYTE; + static int less(type const& a, type const& b) { + return UBYTE_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct short_tag : integral_tag { + using type = npy_short; + static constexpr NPY_TYPES type_value = NPY_SHORT; + static int less(type const& a, type const& b) { + return SHORT_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct ushort_tag : integral_tag { + using type = npy_ushort; + static constexpr NPY_TYPES type_value = NPY_USHORT; + static int less(type const& a, type const& b) { + return USHORT_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct int_tag : integral_tag { + using type = npy_int; + static constexpr NPY_TYPES type_value = NPY_INT; + static int less(type const& a, type const& b) { + return INT_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct uint_tag : integral_tag { + using type = npy_uint; + static constexpr NPY_TYPES type_value = NPY_UINT; + static int less(type const& a, type const& b) { + return UINT_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct long_tag : integral_tag { + using type = npy_long; + static constexpr NPY_TYPES type_value = NPY_LONG; + static int less(type const& a, type const& b) { + return LONG_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct ulong_tag : integral_tag { + using type = npy_ulong; + static constexpr NPY_TYPES type_value = NPY_ULONG; + static int less(type const& a, type const& b) { + return ULONG_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct longlong_tag : integral_tag { + using type = npy_longlong; + static constexpr NPY_TYPES type_value = NPY_LONGLONG; + static int less(type const& a, type const& b) { + return LONGLONG_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct ulonglong_tag : integral_tag { + using type = npy_ulonglong; + static constexpr NPY_TYPES type_value = NPY_ULONGLONG; + static int less(type const& a, type const& b) { + return ULONGLONG_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct half_tag { + using type = npy_half; + static constexpr NPY_TYPES type_value = NPY_HALF; + static int less(type const& a, type const& b) { + return HALF_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct float_tag : floating_point_tag { + using type = npy_float; + static constexpr NPY_TYPES type_value = NPY_FLOAT; + static int less(type const& a, type const& b) { + return FLOAT_LT(a, b); + } + static int less_equal(type const& 
a, type const& b) { + return !less(b, a); + } +}; +struct double_tag : floating_point_tag { + using type = npy_double; + static constexpr NPY_TYPES type_value = NPY_DOUBLE; + static int less(type const& a, type const& b) { + return DOUBLE_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct longdouble_tag : floating_point_tag { + using type = npy_longdouble; + static constexpr NPY_TYPES type_value = NPY_LONGDOUBLE; + static int less(type const& a, type const& b) { + return LONGDOUBLE_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct cfloat_tag : complex_tag { + using type = npy_cfloat; + static constexpr NPY_TYPES type_value = NPY_CFLOAT; + static int less(type const& a, type const& b) { + return CFLOAT_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct cdouble_tag : complex_tag { + using type = npy_cdouble; + static constexpr NPY_TYPES type_value = NPY_CDOUBLE; + static int less(type const& a, type const& b) { + return CDOUBLE_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct clongdouble_tag : complex_tag { + using type = npy_clongdouble; + static constexpr NPY_TYPES type_value = NPY_CLONGDOUBLE; + static int less(type const& a, type const& b) { + return CLONGDOUBLE_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct datetime_tag : date_tag { + using type = npy_datetime; + static constexpr NPY_TYPES type_value = NPY_DATETIME; + static int less(type const& a, type const& b) { + return DATETIME_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; +struct timedelta_tag : date_tag { + using type = npy_timedelta; + static constexpr NPY_TYPES type_value = NPY_TIMEDELTA; + static int less(type const& a, type const& b) { + return TIMEDELTA_LT(a, b); + } + static int less_equal(type const& a, type const& b) { + return !less(b, a); + } +}; + +struct string_tag { + using type = npy_char; + static constexpr NPY_TYPES type_value = NPY_STRING; + static int less(type const* a, type const* b, size_t len) { + return STRING_LT(a, b, len); + } + static int less_equal(type const* a, type const* b, size_t len) { + return !less(b, a, len); + } + static void swap(type* a, type* b, size_t len) { + STRING_SWAP(a, b, len); + } + static void copy(type * a, type const* b, size_t len) { + STRING_COPY(a, b, len); + } +}; + +struct unicode_tag { + using type = npy_ucs4; + static constexpr NPY_TYPES type_value = NPY_UNICODE; + static int less(type const* a, type const* b, size_t len) { + return UNICODE_LT(a, b, len); + } + static int less_equal(type const* a, type const* b, size_t len) { + return !less(b, a, len); + } + static void swap(type* a, type* b, size_t len) { + UNICODE_SWAP(a, b, len); + } + static void copy(type * a, type const* b, size_t len) { + UNICODE_COPY(a, b, len); + } +}; + +} // namespace npy + +#endif diff --git a/mkl_umath/src/npyv/numpyos.h b/mkl_umath/src/npyv/numpyos.h new file mode 100644 index 00000000..fac82f7d --- /dev/null +++ b/mkl_umath/src/npyv/numpyos.h @@ -0,0 +1,68 @@ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_NUMPYOS_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_NUMPYOS_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +NPY_NO_EXPORT char* +NumPyOS_ascii_formatd(char *buffer, size_t buf_size, + const char *format, + double val, int decimal); + +NPY_NO_EXPORT char* 
+NumPyOS_ascii_formatf(char *buffer, size_t buf_size, + const char *format, + float val, int decimal); + +NPY_NO_EXPORT char* +NumPyOS_ascii_formatl(char *buffer, size_t buf_size, + const char *format, + long double val, int decimal); + +NPY_NO_EXPORT double +NumPyOS_ascii_strtod(const char *s, char** endptr); + +NPY_NO_EXPORT long double +NumPyOS_ascii_strtold(const char *s, char** endptr); + +NPY_NO_EXPORT int +NumPyOS_ascii_ftolf(FILE *fp, double *value); + +NPY_NO_EXPORT int +NumPyOS_ascii_ftoLf(FILE *fp, long double *value); + +NPY_NO_EXPORT int +NumPyOS_ascii_isspace(int c); + +NPY_NO_EXPORT int +NumPyOS_ascii_isalpha(char c); + +NPY_NO_EXPORT int +NumPyOS_ascii_isdigit(char c); + +NPY_NO_EXPORT int +NumPyOS_ascii_isalnum(char c); + +NPY_NO_EXPORT int +NumPyOS_ascii_islower(char c); + +NPY_NO_EXPORT int +NumPyOS_ascii_isupper(char c); + +NPY_NO_EXPORT int +NumPyOS_ascii_tolower(char c); + +/* Convert a string to an int in an arbitrary base */ +NPY_NO_EXPORT npy_longlong +NumPyOS_strtoll(const char *str, char **endptr, int base); + +/* Convert a string to an int in an arbitrary base */ +NPY_NO_EXPORT npy_ulonglong +NumPyOS_strtoull(const char *str, char **endptr, int base); + +#ifdef __cplusplus +} +#endif + +#endif /* NUMPY_CORE_SRC_COMMON_NPY_NUMPYOS_H_ */ diff --git a/mkl_umath/src/npyv/simd.h b/mkl_umath/src/npyv/simd.h new file mode 100644 index 00000000..2d9d48cf --- /dev/null +++ b/mkl_umath/src/npyv/simd.h @@ -0,0 +1,171 @@ +#ifndef _NPY_SIMD_H_ +#define _NPY_SIMD_H_ +/** + * the NumPy C SIMD vectorization interface "NPYV" are types and functions intended + * to simplify vectorization of code on different platforms, currently supports + * the following SIMD extensions SSE, AVX2, AVX512, VSX and NEON. + * + * TODO: Add an independent sphinx doc. +*/ +#include "numpy/npy_common.h" +#ifndef __cplusplus + #include +#endif + +#include "npy_cpu_dispatch.h" +#include "simd_utils.h" + +#ifdef __cplusplus +extern "C" { +#endif +/* + * clang commit an aggressive optimization behaviour when flag `-ftrapping-math` + * isn't fully supported that's present at -O1 or greater. When partially loading a + * vector register for a operations that requires to fill up the remaining lanes + * with certain value for example divide operation needs to fill the remaining value + * with non-zero integer to avoid fp exception divide-by-zero. + * clang optimizer notices that the entire register is not needed for the store + * and optimizes out the fill of non-zero integer to the remaining + * elements. As workaround we mark the returned register with `volatile` + * followed by symmetric operand operation e.g. `or` + * to convince the compiler that the entire vector is needed. + */ +#if defined(__clang__) && !defined(NPY_HAVE_CLANG_FPSTRICT) + #define NPY_SIMD_GUARD_PARTIAL_LOAD 1 +#else + #define NPY_SIMD_GUARD_PARTIAL_LOAD 0 +#endif + +#if defined(_MSC_VER) && defined(_M_IX86) +/* + * Avoid using any of the following intrinsics with MSVC 32-bit, + * even if they are apparently work on newer versions. + * They had bad impact on the generated instructions, + * sometimes the compiler deal with them without the respect + * of 32-bit mode which lead to crush due to execute 64-bit + * instructions and other times generate bad emulated instructions. 
+ */ + #undef _mm512_set1_epi64 + #undef _mm256_set1_epi64x + #undef _mm_set1_epi64x + #undef _mm512_setr_epi64x + #undef _mm256_setr_epi64x + #undef _mm_setr_epi64x + #undef _mm512_set_epi64x + #undef _mm256_set_epi64x + #undef _mm_set_epi64x +#endif + +// lane type by intrin suffix +typedef npy_uint8 npyv_lanetype_u8; +typedef npy_int8 npyv_lanetype_s8; +typedef npy_uint16 npyv_lanetype_u16; +typedef npy_int16 npyv_lanetype_s16; +typedef npy_uint32 npyv_lanetype_u32; +typedef npy_int32 npyv_lanetype_s32; +typedef npy_uint64 npyv_lanetype_u64; +typedef npy_int64 npyv_lanetype_s64; +typedef float npyv_lanetype_f32; +typedef double npyv_lanetype_f64; + +#if defined(NPY_HAVE_AVX512F) && !defined(NPY_SIMD_FORCE_256) && !defined(NPY_SIMD_FORCE_128) + #include "avx512/avx512.h" +#elif defined(NPY_HAVE_AVX2) && !defined(NPY_SIMD_FORCE_128) + #include "avx2/avx2.h" +#elif defined(NPY_HAVE_SSE2) + #include "sse/sse.h" +#endif + +// TODO: Add support for VSX(2.06) and BE Mode for VSX +#if defined(NPY_HAVE_VX) || (defined(NPY_HAVE_VSX2) && defined(__LITTLE_ENDIAN__)) + #include "vec/vec.h" +#endif + +#ifdef NPY_HAVE_NEON + #include "neon/neon.h" +#endif + +#ifndef NPY_SIMD + /// SIMD width in bits or 0 if there's no SIMD extension available. + #define NPY_SIMD 0 + /// SIMD width in bytes or 0 if there's no SIMD extension available. + #define NPY_SIMD_WIDTH 0 + /// 1 if the enabled SIMD extension supports single-precision otherwise 0. + #define NPY_SIMD_F32 0 + /// 1 if the enabled SIMD extension supports double-precision otherwise 0. + #define NPY_SIMD_F64 0 + /// 1 if the enabled SIMD extension supports native FMA otherwise 0. + /// note: we still emulate(fast) FMA intrinsics even if they + /// aren't supported but they shouldn't be used if the precision is matters. + #define NPY_SIMD_FMA3 0 + /// 1 if the enabled SIMD extension is running on big-endian mode otherwise 0. + #define NPY_SIMD_BIGENDIAN 0 + /// 1 if the supported comparison intrinsics(lt, le, gt, ge) + /// raises FP invalid exception for quite NaNs. + #define NPY_SIMD_CMPSIGNAL 0 +#endif + +// enable emulated mask operations for all SIMD extension except for AVX512 +#if !defined(NPY_HAVE_AVX512F) && NPY_SIMD && NPY_SIMD < 512 + #include "emulate_maskop.h" +#endif + +// enable integer divisor generator for all SIMD extensions +#if NPY_SIMD + #include "intdiv.h" +#endif + +/** + * Some SIMD extensions currently(AVX2, AVX512F) require (de facto) + * a maximum number of strides sizes when dealing with non-contiguous memory access. + * + * Therefore the following functions must be used to check the maximum + * acceptable limit of strides before using any of non-contiguous load/store intrinsics. + * + * For instance: + * npy_intp ld_stride = step[0] / sizeof(float); + * npy_intp st_stride = step[1] / sizeof(float); + * + * if (npyv_loadable_stride_f32(ld_stride) && npyv_storable_stride_f32(st_stride)) { + * for (;;) + * npyv_f32 a = npyv_loadn_f32(ld_pointer, ld_stride); + * // ... 
+ * npyv_storen_f32(st_pointer, st_stride, a); + * } + * else { + * for (;;) + * // C scalars + * } + */ +#ifndef NPY_SIMD_MAXLOAD_STRIDE32 + #define NPY_SIMD_MAXLOAD_STRIDE32 0 +#endif +#ifndef NPY_SIMD_MAXSTORE_STRIDE32 + #define NPY_SIMD_MAXSTORE_STRIDE32 0 +#endif +#ifndef NPY_SIMD_MAXLOAD_STRIDE64 + #define NPY_SIMD_MAXLOAD_STRIDE64 0 +#endif +#ifndef NPY_SIMD_MAXSTORE_STRIDE64 + #define NPY_SIMD_MAXSTORE_STRIDE64 0 +#endif +#define NPYV_IMPL_MAXSTRIDE(SFX, MAXLOAD, MAXSTORE) \ + NPY_FINLINE int npyv_loadable_stride_##SFX(npy_intp stride) \ + { return MAXLOAD > 0 ? llabs(stride) <= MAXLOAD : 1; } \ + NPY_FINLINE int npyv_storable_stride_##SFX(npy_intp stride) \ + { return MAXSTORE > 0 ? llabs(stride) <= MAXSTORE : 1; } +#if NPY_SIMD + NPYV_IMPL_MAXSTRIDE(u32, NPY_SIMD_MAXLOAD_STRIDE32, NPY_SIMD_MAXSTORE_STRIDE32) + NPYV_IMPL_MAXSTRIDE(s32, NPY_SIMD_MAXLOAD_STRIDE32, NPY_SIMD_MAXSTORE_STRIDE32) + NPYV_IMPL_MAXSTRIDE(f32, NPY_SIMD_MAXLOAD_STRIDE32, NPY_SIMD_MAXSTORE_STRIDE32) + NPYV_IMPL_MAXSTRIDE(u64, NPY_SIMD_MAXLOAD_STRIDE64, NPY_SIMD_MAXSTORE_STRIDE64) + NPYV_IMPL_MAXSTRIDE(s64, NPY_SIMD_MAXLOAD_STRIDE64, NPY_SIMD_MAXSTORE_STRIDE64) +#endif +#if NPY_SIMD_F64 + NPYV_IMPL_MAXSTRIDE(f64, NPY_SIMD_MAXLOAD_STRIDE64, NPY_SIMD_MAXSTORE_STRIDE64) +#endif + +#ifdef __cplusplus +} +#endif +#endif // _NPY_SIMD_H_ diff --git a/mkl_umath/src/npyv/simd_utils.h b/mkl_umath/src/npyv/simd_utils.h new file mode 100644 index 00000000..06c2f16f --- /dev/null +++ b/mkl_umath/src/npyv/simd_utils.h @@ -0,0 +1,48 @@ +#ifndef _NPY_SIMD_UTILS_H +#define _NPY_SIMD_UTILS_H + +#define NPYV__SET_2(CAST, I0, I1, ...) (CAST)(I0), (CAST)(I1) + +#define NPYV__SET_4(CAST, I0, I1, I2, I3, ...) \ + (CAST)(I0), (CAST)(I1), (CAST)(I2), (CAST)(I3) + +#define NPYV__SET_8(CAST, I0, I1, I2, I3, I4, I5, I6, I7, ...) \ + (CAST)(I0), (CAST)(I1), (CAST)(I2), (CAST)(I3), (CAST)(I4), (CAST)(I5), (CAST)(I6), (CAST)(I7) + +#define NPYV__SET_16(CAST, I0, I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, I13, I14, I15, ...) \ + NPYV__SET_8(CAST, I0, I1, I2, I3, I4, I5, I6, I7), \ + NPYV__SET_8(CAST, I8, I9, I10, I11, I12, I13, I14, I15) + +#define NPYV__SET_32(CAST, I0, I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, I13, I14, I15, \ +I16, I17, I18, I19, I20, I21, I22, I23, I24, I25, I26, I27, I28, I29, I30, I31, ...) \ + \ + NPYV__SET_16(CAST, I0, I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, I13, I14, I15), \ + NPYV__SET_16(CAST, I16, I17, I18, I19, I20, I21, I22, I23, I24, I25, I26, I27, I28, I29, I30, I31) + +#define NPYV__SET_64(CAST, I0, I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, I13, I14, I15, \ +I16, I17, I18, I19, I20, I21, I22, I23, I24, I25, I26, I27, I28, I29, I30, I31, \ +I32, I33, I34, I35, I36, I37, I38, I39, I40, I41, I42, I43, I44, I45, I46, I47, \ +I48, I49, I50, I51, I52, I53, I54, I55, I56, I57, I58, I59, I60, I61, I62, I63, ...) \ + \ + NPYV__SET_32(CAST, I0, I1, I2, I3, I4, I5, I6, I7, I8, I9, I10, I11, I12, I13, I14, I15, \ +I16, I17, I18, I19, I20, I21, I22, I23, I24, I25, I26, I27, I28, I29, I30, I31), \ + NPYV__SET_32(CAST, I32, I33, I34, I35, I36, I37, I38, I39, I40, I41, I42, I43, I44, I45, I46, I47, \ +I48, I49, I50, I51, I52, I53, I54, I55, I56, I57, I58, I59, I60, I61, I62, I63) + +#define NPYV__SET_FILL_2(CAST, F, ...) NPY_EXPAND(NPYV__SET_2(CAST, __VA_ARGS__, F, F)) + +#define NPYV__SET_FILL_4(CAST, F, ...) NPY_EXPAND(NPYV__SET_4(CAST, __VA_ARGS__, F, F, F, F)) + +#define NPYV__SET_FILL_8(CAST, F, ...) 
NPY_EXPAND(NPYV__SET_8(CAST, __VA_ARGS__, F, F, F, F, F, F, F, F)) + +#define NPYV__SET_FILL_16(CAST, F, ...) NPY_EXPAND(NPYV__SET_16(CAST, __VA_ARGS__, \ + F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F)) + +#define NPYV__SET_FILL_32(CAST, F, ...) NPY_EXPAND(NPYV__SET_32(CAST, __VA_ARGS__, \ + F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F)) + +#define NPYV__SET_FILL_64(CAST, F, ...) NPY_EXPAND(NPYV__SET_64(CAST, __VA_ARGS__, \ + F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, \ + F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F, F)) + +#endif // _NPY_SIMD_UTILS_H diff --git a/mkl_umath/src/npyv/sse/arithmetic.h b/mkl_umath/src/npyv/sse/arithmetic.h new file mode 100644 index 00000000..357b136d --- /dev/null +++ b/mkl_umath/src/npyv/sse/arithmetic.h @@ -0,0 +1,415 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_SSE_ARITHMETIC_H +#define _NPY_SIMD_SSE_ARITHMETIC_H + +/*************************** + * Addition + ***************************/ +// non-saturated +#define npyv_add_u8 _mm_add_epi8 +#define npyv_add_s8 _mm_add_epi8 +#define npyv_add_u16 _mm_add_epi16 +#define npyv_add_s16 _mm_add_epi16 +#define npyv_add_u32 _mm_add_epi32 +#define npyv_add_s32 _mm_add_epi32 +#define npyv_add_u64 _mm_add_epi64 +#define npyv_add_s64 _mm_add_epi64 +#define npyv_add_f32 _mm_add_ps +#define npyv_add_f64 _mm_add_pd + +// saturated +#define npyv_adds_u8 _mm_adds_epu8 +#define npyv_adds_s8 _mm_adds_epi8 +#define npyv_adds_u16 _mm_adds_epu16 +#define npyv_adds_s16 _mm_adds_epi16 +// TODO: rest, after implement Packs intrins + +/*************************** + * Subtraction + ***************************/ +// non-saturated +#define npyv_sub_u8 _mm_sub_epi8 +#define npyv_sub_s8 _mm_sub_epi8 +#define npyv_sub_u16 _mm_sub_epi16 +#define npyv_sub_s16 _mm_sub_epi16 +#define npyv_sub_u32 _mm_sub_epi32 +#define npyv_sub_s32 _mm_sub_epi32 +#define npyv_sub_u64 _mm_sub_epi64 +#define npyv_sub_s64 _mm_sub_epi64 +#define npyv_sub_f32 _mm_sub_ps +#define npyv_sub_f64 _mm_sub_pd + +// saturated +#define npyv_subs_u8 _mm_subs_epu8 +#define npyv_subs_s8 _mm_subs_epi8 +#define npyv_subs_u16 _mm_subs_epu16 +#define npyv_subs_s16 _mm_subs_epi16 +// TODO: rest, after implement Packs intrins + +/*************************** + * Multiplication + ***************************/ +// non-saturated +NPY_FINLINE __m128i npyv_mul_u8(__m128i a, __m128i b) +{ + const __m128i mask = _mm_set1_epi32(0xFF00FF00); + __m128i even = _mm_mullo_epi16(a, b); + __m128i odd = _mm_mullo_epi16(_mm_srai_epi16(a, 8), _mm_srai_epi16(b, 8)); + odd = _mm_slli_epi16(odd, 8); + return npyv_select_u8(mask, odd, even); +} +#define npyv_mul_s8 npyv_mul_u8 +#define npyv_mul_u16 _mm_mullo_epi16 +#define npyv_mul_s16 _mm_mullo_epi16 + +#ifdef NPY_HAVE_SSE41 + #define npyv_mul_u32 _mm_mullo_epi32 +#else + NPY_FINLINE __m128i npyv_mul_u32(__m128i a, __m128i b) + { + __m128i even = _mm_mul_epu32(a, b); + __m128i odd = _mm_mul_epu32(_mm_srli_epi64(a, 32), _mm_srli_epi64(b, 32)); + __m128i low = _mm_unpacklo_epi32(even, odd); + __m128i high = _mm_unpackhi_epi32(even, odd); + return _mm_unpacklo_epi64(low, high); + } +#endif // NPY_HAVE_SSE41 +#define npyv_mul_s32 npyv_mul_u32 +// TODO: emulate 64-bit*/ +#define npyv_mul_f32 _mm_mul_ps +#define npyv_mul_f64 _mm_mul_pd + +// saturated +// TODO: after implement Packs intrins + +/*************************** + * Integer Division + 
***************************/ +// See simd/intdiv.h for more clarification +// divide each unsigned 8-bit element by a precomputed divisor +NPY_FINLINE npyv_u8 npyv_divc_u8(npyv_u8 a, const npyv_u8x3 divisor) +{ + const __m128i bmask = _mm_set1_epi32(0x00FF00FF); + const __m128i shf1b = _mm_set1_epi8(0xFFU >> _mm_cvtsi128_si32(divisor.val[1])); + const __m128i shf2b = _mm_set1_epi8(0xFFU >> _mm_cvtsi128_si32(divisor.val[2])); + // high part of unsigned multiplication + __m128i mulhi_even = _mm_mullo_epi16(_mm_and_si128(a, bmask), divisor.val[0]); + __m128i mulhi_odd = _mm_mullo_epi16(_mm_srli_epi16(a, 8), divisor.val[0]); + mulhi_even = _mm_srli_epi16(mulhi_even, 8); + __m128i mulhi = npyv_select_u8(bmask, mulhi_even, mulhi_odd); + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + __m128i q = _mm_sub_epi8(a, mulhi); + q = _mm_and_si128(_mm_srl_epi16(q, divisor.val[1]), shf1b); + q = _mm_add_epi8(mulhi, q); + q = _mm_and_si128(_mm_srl_epi16(q, divisor.val[2]), shf2b); + return q; +} +// divide each signed 8-bit element by a precomputed divisor (round towards zero) +NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor); +NPY_FINLINE npyv_s8 npyv_divc_s8(npyv_s8 a, const npyv_s8x3 divisor) +{ + const __m128i bmask = _mm_set1_epi32(0x00FF00FF); + // instead of _mm_cvtepi8_epi16/_mm_packs_epi16 to wrap around overflow + __m128i divc_even = npyv_divc_s16(_mm_srai_epi16(_mm_slli_epi16(a, 8), 8), divisor); + __m128i divc_odd = npyv_divc_s16(_mm_srai_epi16(a, 8), divisor); + divc_odd = _mm_slli_epi16(divc_odd, 8); + return npyv_select_u8(bmask, divc_even, divc_odd); +} +// divide each unsigned 16-bit element by a precomputed divisor +NPY_FINLINE npyv_u16 npyv_divc_u16(npyv_u16 a, const npyv_u16x3 divisor) +{ + // high part of unsigned multiplication + __m128i mulhi = _mm_mulhi_epu16(a, divisor.val[0]); + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + __m128i q = _mm_sub_epi16(a, mulhi); + q = _mm_srl_epi16(q, divisor.val[1]); + q = _mm_add_epi16(mulhi, q); + q = _mm_srl_epi16(q, divisor.val[2]); + return q; +} +// divide each signed 16-bit element by a precomputed divisor (round towards zero) +NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor) +{ + // high part of signed multiplication + __m128i mulhi = _mm_mulhi_epi16(a, divisor.val[0]); + // q = ((a + mulhi) >> sh1) - XSIGN(a) + // trunc(a/d) = (q ^ dsign) - dsign + __m128i q = _mm_sra_epi16(_mm_add_epi16(a, mulhi), divisor.val[1]); + q = _mm_sub_epi16(q, _mm_srai_epi16(a, 15)); + q = _mm_sub_epi16(_mm_xor_si128(q, divisor.val[2]), divisor.val[2]); + return q; +} +// divide each unsigned 32-bit element by a precomputed divisor +NPY_FINLINE npyv_u32 npyv_divc_u32(npyv_u32 a, const npyv_u32x3 divisor) +{ + // high part of unsigned multiplication + __m128i mulhi_even = _mm_srli_epi64(_mm_mul_epu32(a, divisor.val[0]), 32); + __m128i mulhi_odd = _mm_mul_epu32(_mm_srli_epi64(a, 32), divisor.val[0]); +#ifdef NPY_HAVE_SSE41 + __m128i mulhi = _mm_blend_epi16(mulhi_even, mulhi_odd, 0xCC); +#else + __m128i mask_13 = _mm_setr_epi32(0, -1, 0, -1); + mulhi_odd = _mm_and_si128(mulhi_odd, mask_13); + __m128i mulhi = _mm_or_si128(mulhi_even, mulhi_odd); +#endif + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + __m128i q = _mm_sub_epi32(a, mulhi); + q = _mm_srl_epi32(q, divisor.val[1]); + q = _mm_add_epi32(mulhi, q); + q = _mm_srl_epi32(q, divisor.val[2]); + return q; +} +// divide each signed 32-bit element by a precomputed divisor (round towards zero) +NPY_FINLINE npyv_s32 npyv_divc_s32(npyv_s32 a, const 
npyv_s32x3 divisor) +{ + __m128i asign = _mm_srai_epi32(a, 31); +#ifdef NPY_HAVE_SSE41 + // high part of signed multiplication + __m128i mulhi_even = _mm_srli_epi64(_mm_mul_epi32(a, divisor.val[0]), 32); + __m128i mulhi_odd = _mm_mul_epi32(_mm_srli_epi64(a, 32), divisor.val[0]); + __m128i mulhi = _mm_blend_epi16(mulhi_even, mulhi_odd, 0xCC); +#else // not SSE4.1 + // high part of "unsigned" multiplication + __m128i mulhi_even = _mm_srli_epi64(_mm_mul_epu32(a, divisor.val[0]), 32); + __m128i mulhi_odd = _mm_mul_epu32(_mm_srli_epi64(a, 32), divisor.val[0]); + __m128i mask_13 = _mm_setr_epi32(0, -1, 0, -1); + mulhi_odd = _mm_and_si128(mulhi_odd, mask_13); + __m128i mulhi = _mm_or_si128(mulhi_even, mulhi_odd); + // convert unsigned to signed high multiplication + // mulhi - ((a < 0) ? m : 0) - ((m < 0) ? a : 0); + const __m128i msign= _mm_srai_epi32(divisor.val[0], 31); + __m128i m_asign = _mm_and_si128(divisor.val[0], asign); + __m128i a_msign = _mm_and_si128(a, msign); + mulhi = _mm_sub_epi32(mulhi, m_asign); + mulhi = _mm_sub_epi32(mulhi, a_msign); +#endif + // q = ((a + mulhi) >> sh1) - XSIGN(a) + // trunc(a/d) = (q ^ dsign) - dsign + __m128i q = _mm_sra_epi32(_mm_add_epi32(a, mulhi), divisor.val[1]); + q = _mm_sub_epi32(q, asign); + q = _mm_sub_epi32(_mm_xor_si128(q, divisor.val[2]), divisor.val[2]); + return q; +} +// returns the high 64 bits of unsigned 64-bit multiplication +// xref https://stackoverflow.com/a/28827013 +NPY_FINLINE npyv_u64 npyv__mullhi_u64(npyv_u64 a, npyv_u64 b) +{ + __m128i lomask = npyv_setall_s64(0xffffffff); + __m128i a_hi = _mm_srli_epi64(a, 32); // a0l, a0h, a1l, a1h + __m128i b_hi = _mm_srli_epi64(b, 32); // b0l, b0h, b1l, b1h + // compute partial products + __m128i w0 = _mm_mul_epu32(a, b); // a0l*b0l, a1l*b1l + __m128i w1 = _mm_mul_epu32(a, b_hi); // a0l*b0h, a1l*b1h + __m128i w2 = _mm_mul_epu32(a_hi, b); // a0h*b0l, a1h*b0l + __m128i w3 = _mm_mul_epu32(a_hi, b_hi); // a0h*b0h, a1h*b1h + // sum partial products + __m128i w0h = _mm_srli_epi64(w0, 32); + __m128i s1 = _mm_add_epi64(w1, w0h); + __m128i s1l = _mm_and_si128(s1, lomask); + __m128i s1h = _mm_srli_epi64(s1, 32); + + __m128i s2 = _mm_add_epi64(w2, s1l); + __m128i s2h = _mm_srli_epi64(s2, 32); + + __m128i hi = _mm_add_epi64(w3, s1h); + hi = _mm_add_epi64(hi, s2h); + return hi; +} +// divide each unsigned 64-bit element by a precomputed divisor +NPY_FINLINE npyv_u64 npyv_divc_u64(npyv_u64 a, const npyv_u64x3 divisor) +{ + // high part of unsigned multiplication + __m128i mulhi = npyv__mullhi_u64(a, divisor.val[0]); + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + __m128i q = _mm_sub_epi64(a, mulhi); + q = _mm_srl_epi64(q, divisor.val[1]); + q = _mm_add_epi64(mulhi, q); + q = _mm_srl_epi64(q, divisor.val[2]); + return q; +} +// divide each signed 64-bit element by a precomputed divisor (round towards zero) +NPY_FINLINE npyv_s64 npyv_divc_s64(npyv_s64 a, const npyv_s64x3 divisor) +{ + // high part of unsigned multiplication + __m128i mulhi = npyv__mullhi_u64(a, divisor.val[0]); + // convert unsigned to signed high multiplication + // mulhi - ((a < 0) ? m : 0) - ((m < 0) ? 
a : 0); +#ifdef NPY_HAVE_SSE42 + const __m128i msign= _mm_cmpgt_epi64(_mm_setzero_si128(), divisor.val[0]); + __m128i asign = _mm_cmpgt_epi64(_mm_setzero_si128(), a); +#else + const __m128i msign= _mm_srai_epi32(_mm_shuffle_epi32(divisor.val[0], _MM_SHUFFLE(3, 3, 1, 1)), 31); + __m128i asign = _mm_srai_epi32(_mm_shuffle_epi32(a, _MM_SHUFFLE(3, 3, 1, 1)), 31); +#endif + __m128i m_asign = _mm_and_si128(divisor.val[0], asign); + __m128i a_msign = _mm_and_si128(a, msign); + mulhi = _mm_sub_epi64(mulhi, m_asign); + mulhi = _mm_sub_epi64(mulhi, a_msign); + // q = (a + mulhi) >> sh + __m128i q = _mm_add_epi64(a, mulhi); + // emulate arithmetic right shift + const __m128i sigb = npyv_setall_s64(1LL << 63); + q = _mm_srl_epi64(_mm_add_epi64(q, sigb), divisor.val[1]); + q = _mm_sub_epi64(q, _mm_srl_epi64(sigb, divisor.val[1])); + // q = q - XSIGN(a) + // trunc(a/d) = (q ^ dsign) - dsign + q = _mm_sub_epi64(q, asign); + q = _mm_sub_epi64(_mm_xor_si128(q, divisor.val[2]), divisor.val[2]); + return q; +} +/*************************** + * Division + ***************************/ +// TODO: emulate integer division +#define npyv_div_f32 _mm_div_ps +#define npyv_div_f64 _mm_div_pd +/*************************** + * FUSED + ***************************/ +#ifdef NPY_HAVE_FMA3 + // multiply and add, a*b + c + #define npyv_muladd_f32 _mm_fmadd_ps + #define npyv_muladd_f64 _mm_fmadd_pd + // multiply and subtract, a*b - c + #define npyv_mulsub_f32 _mm_fmsub_ps + #define npyv_mulsub_f64 _mm_fmsub_pd + // negate multiply and add, -(a*b) + c + #define npyv_nmuladd_f32 _mm_fnmadd_ps + #define npyv_nmuladd_f64 _mm_fnmadd_pd + // negate multiply and subtract, -(a*b) - c + #define npyv_nmulsub_f32 _mm_fnmsub_ps + #define npyv_nmulsub_f64 _mm_fnmsub_pd + // multiply, add for odd elements and subtract even elements. + // (a * b) -+ c + #define npyv_muladdsub_f32 _mm_fmaddsub_ps + #define npyv_muladdsub_f64 _mm_fmaddsub_pd +#elif defined(NPY_HAVE_FMA4) + // multiply and add, a*b + c + #define npyv_muladd_f32 _mm_macc_ps + #define npyv_muladd_f64 _mm_macc_pd + // multiply and subtract, a*b - c + #define npyv_mulsub_f32 _mm_msub_ps + #define npyv_mulsub_f64 _mm_msub_pd + // negate multiply and add, -(a*b) + c + #define npyv_nmuladd_f32 _mm_nmacc_ps + #define npyv_nmuladd_f64 _mm_nmacc_pd + // multiply, add for odd elements and subtract even elements. + // (a * b) -+ c + #define npyv_muladdsub_f32 _mm_maddsub_ps + #define npyv_muladdsub_f64 _mm_maddsub_pd +#else + // multiply and add, a*b + c + NPY_FINLINE npyv_f32 npyv_muladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { return npyv_add_f32(npyv_mul_f32(a, b), c); } + NPY_FINLINE npyv_f64 npyv_muladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { return npyv_add_f64(npyv_mul_f64(a, b), c); } + // multiply and subtract, a*b - c + NPY_FINLINE npyv_f32 npyv_mulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { return npyv_sub_f32(npyv_mul_f32(a, b), c); } + NPY_FINLINE npyv_f64 npyv_mulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { return npyv_sub_f64(npyv_mul_f64(a, b), c); } + // negate multiply and add, -(a*b) + c + NPY_FINLINE npyv_f32 npyv_nmuladd_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { return npyv_sub_f32(c, npyv_mul_f32(a, b)); } + NPY_FINLINE npyv_f64 npyv_nmuladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { return npyv_sub_f64(c, npyv_mul_f64(a, b)); } + // multiply, add for odd elements and subtract even elements. 
+ // (a * b) -+ c + NPY_FINLINE npyv_f32 npyv_muladdsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { + npyv_f32 m = npyv_mul_f32(a, b); + #ifdef NPY_HAVE_SSE3 + return _mm_addsub_ps(m, c); + #else + const npyv_f32 msign = npyv_set_f32(-0.0f, 0.0f, -0.0f, 0.0f); + return npyv_add_f32(m, npyv_xor_f32(msign, c)); + #endif + } + NPY_FINLINE npyv_f64 npyv_muladdsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { + npyv_f64 m = npyv_mul_f64(a, b); + #ifdef NPY_HAVE_SSE3 + return _mm_addsub_pd(m, c); + #else + const npyv_f64 msign = npyv_set_f64(-0.0, 0.0); + return npyv_add_f64(m, npyv_xor_f64(msign, c)); + #endif + } +#endif // NPY_HAVE_FMA3 +#ifndef NPY_HAVE_FMA3 // for FMA4 and NON-FMA3 + // negate multiply and subtract, -(a*b) - c + NPY_FINLINE npyv_f32 npyv_nmulsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) + { + npyv_f32 neg_a = npyv_xor_f32(a, npyv_setall_f32(-0.0f)); + return npyv_sub_f32(npyv_mul_f32(neg_a, b), c); + } + NPY_FINLINE npyv_f64 npyv_nmulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { + npyv_f64 neg_a = npyv_xor_f64(a, npyv_setall_f64(-0.0)); + return npyv_sub_f64(npyv_mul_f64(neg_a, b), c); + } +#endif // !NPY_HAVE_FMA3 + +/*************************** + * Summation + ***************************/ +// reduce sum across vector +NPY_FINLINE npy_uint32 npyv_sum_u32(npyv_u32 a) +{ + __m128i t = _mm_add_epi32(a, _mm_srli_si128(a, 8)); + t = _mm_add_epi32(t, _mm_srli_si128(t, 4)); + return (unsigned)_mm_cvtsi128_si32(t); +} + +NPY_FINLINE npy_uint64 npyv_sum_u64(npyv_u64 a) +{ + __m128i one = _mm_add_epi64(a, _mm_unpackhi_epi64(a, a)); + return (npy_uint64)npyv128_cvtsi128_si64(one); +} + +NPY_FINLINE float npyv_sum_f32(npyv_f32 a) +{ +#ifdef NPY_HAVE_SSE3 + __m128 sum_halves = _mm_hadd_ps(a, a); + return _mm_cvtss_f32(_mm_hadd_ps(sum_halves, sum_halves)); +#else + __m128 t1 = _mm_movehl_ps(a, a); + __m128 t2 = _mm_add_ps(a, t1); + __m128 t3 = _mm_shuffle_ps(t2, t2, 1); + __m128 t4 = _mm_add_ss(t2, t3); + return _mm_cvtss_f32(t4); +#endif +} + +NPY_FINLINE double npyv_sum_f64(npyv_f64 a) +{ +#ifdef NPY_HAVE_SSE3 + return _mm_cvtsd_f64(_mm_hadd_pd(a, a)); +#else + return _mm_cvtsd_f64(_mm_add_pd(a, _mm_unpackhi_pd(a, a))); +#endif +} + +// expand the source vector and performs sum reduce +NPY_FINLINE npy_uint16 npyv_sumup_u8(npyv_u8 a) +{ + __m128i two = _mm_sad_epu8(a, _mm_setzero_si128()); + __m128i one = _mm_add_epi16(two, _mm_unpackhi_epi64(two, two)); + return (npy_uint16)_mm_cvtsi128_si32(one); +} + +NPY_FINLINE npy_uint32 npyv_sumup_u16(npyv_u16 a) +{ + const __m128i even_mask = _mm_set1_epi32(0x0000FFFF); + __m128i even = _mm_and_si128(a, even_mask); + __m128i odd = _mm_srli_epi32(a, 16); + __m128i four = _mm_add_epi32(even, odd); + return npyv_sum_u32(four); +} + +#endif // _NPY_SIMD_SSE_ARITHMETIC_H + + diff --git a/mkl_umath/src/npyv/sse/conversion.h b/mkl_umath/src/npyv/sse/conversion.h new file mode 100644 index 00000000..0811bf06 --- /dev/null +++ b/mkl_umath/src/npyv/sse/conversion.h @@ -0,0 +1,94 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_SSE_CVT_H +#define _NPY_SIMD_SSE_CVT_H + +// convert mask types to integer types +#define npyv_cvt_u8_b8(BL) BL +#define npyv_cvt_s8_b8(BL) BL +#define npyv_cvt_u16_b16(BL) BL +#define npyv_cvt_s16_b16(BL) BL +#define npyv_cvt_u32_b32(BL) BL +#define npyv_cvt_s32_b32(BL) BL +#define npyv_cvt_u64_b64(BL) BL +#define npyv_cvt_s64_b64(BL) BL +#define npyv_cvt_f32_b32 _mm_castsi128_ps +#define npyv_cvt_f64_b64 _mm_castsi128_pd + +// convert integer types to mask types +#define npyv_cvt_b8_u8(A) A 
+#define npyv_cvt_b8_s8(A) A +#define npyv_cvt_b16_u16(A) A +#define npyv_cvt_b16_s16(A) A +#define npyv_cvt_b32_u32(A) A +#define npyv_cvt_b32_s32(A) A +#define npyv_cvt_b64_u64(A) A +#define npyv_cvt_b64_s64(A) A +#define npyv_cvt_b32_f32 _mm_castps_si128 +#define npyv_cvt_b64_f64 _mm_castpd_si128 + +// convert boolean vector to integer bitfield +NPY_FINLINE npy_uint64 npyv_tobits_b8(npyv_b8 a) +{ return (npy_uint16)_mm_movemask_epi8(a); } +NPY_FINLINE npy_uint64 npyv_tobits_b16(npyv_b16 a) +{ + __m128i pack = _mm_packs_epi16(a, a); + return (npy_uint8)_mm_movemask_epi8(pack); +} +NPY_FINLINE npy_uint64 npyv_tobits_b32(npyv_b32 a) +{ return (npy_uint8)_mm_movemask_ps(_mm_castsi128_ps(a)); } +NPY_FINLINE npy_uint64 npyv_tobits_b64(npyv_b64 a) +{ return (npy_uint8)_mm_movemask_pd(_mm_castsi128_pd(a)); } + +// expand +NPY_FINLINE npyv_u16x2 npyv_expand_u16_u8(npyv_u8 data) { + npyv_u16x2 r; + const __m128i z = _mm_setzero_si128(); + r.val[0] = _mm_unpacklo_epi8(data, z); + r.val[1] = _mm_unpackhi_epi8(data, z); + return r; +} + +NPY_FINLINE npyv_u32x2 npyv_expand_u32_u16(npyv_u16 data) { + npyv_u32x2 r; + const __m128i z = _mm_setzero_si128(); + r.val[0] = _mm_unpacklo_epi16(data, z); + r.val[1] = _mm_unpackhi_epi16(data, z); + return r; +} + +// pack two 16-bit boolean into one 8-bit boolean vector +NPY_FINLINE npyv_b8 npyv_pack_b8_b16(npyv_b16 a, npyv_b16 b) { + return _mm_packs_epi16(a, b); +} + +// pack four 32-bit boolean vectors into one 8-bit boolean vector +NPY_FINLINE npyv_b8 +npyv_pack_b8_b32(npyv_b32 a, npyv_b32 b, npyv_b32 c, npyv_b32 d) { + npyv_b16 ab = _mm_packs_epi32(a, b); + npyv_b16 cd = _mm_packs_epi32(c, d); + return npyv_pack_b8_b16(ab, cd); +} + +// pack eight 64-bit boolean vectors into one 8-bit boolean vector +NPY_FINLINE npyv_b8 +npyv_pack_b8_b64(npyv_b64 a, npyv_b64 b, npyv_b64 c, npyv_b64 d, + npyv_b64 e, npyv_b64 f, npyv_b64 g, npyv_b64 h) { + npyv_b32 ab = _mm_packs_epi32(a, b); + npyv_b32 cd = _mm_packs_epi32(c, d); + npyv_b32 ef = _mm_packs_epi32(e, f); + npyv_b32 gh = _mm_packs_epi32(g, h); + return npyv_pack_b8_b32(ab, cd, ef, gh); +} + +// round to nearest integer (assuming even) +#define npyv_round_s32_f32 _mm_cvtps_epi32 +NPY_FINLINE npyv_s32 npyv_round_s32_f64(npyv_f64 a, npyv_f64 b) +{ + __m128i lo = _mm_cvtpd_epi32(a), hi = _mm_cvtpd_epi32(b); + return _mm_unpacklo_epi64(lo, hi); +} + +#endif // _NPY_SIMD_SSE_CVT_H diff --git a/mkl_umath/src/npyv/sse/math.h b/mkl_umath/src/npyv/sse/math.h new file mode 100644 index 00000000..b51c935a --- /dev/null +++ b/mkl_umath/src/npyv/sse/math.h @@ -0,0 +1,463 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_SSE_MATH_H +#define _NPY_SIMD_SSE_MATH_H +/*************************** + * Elementary + ***************************/ +// Square root +#define npyv_sqrt_f32 _mm_sqrt_ps +#define npyv_sqrt_f64 _mm_sqrt_pd + +// Reciprocal +NPY_FINLINE npyv_f32 npyv_recip_f32(npyv_f32 a) +{ return _mm_div_ps(_mm_set1_ps(1.0f), a); } +NPY_FINLINE npyv_f64 npyv_recip_f64(npyv_f64 a) +{ return _mm_div_pd(_mm_set1_pd(1.0), a); } + +// Absolute +NPY_FINLINE npyv_f32 npyv_abs_f32(npyv_f32 a) +{ + return _mm_and_ps( + a, _mm_castsi128_ps(_mm_set1_epi32(0x7fffffff)) + ); +} +NPY_FINLINE npyv_f64 npyv_abs_f64(npyv_f64 a) +{ + return _mm_and_pd( + a, _mm_castsi128_pd(npyv_setall_s64(0x7fffffffffffffffLL)) + ); +} + +// Square +NPY_FINLINE npyv_f32 npyv_square_f32(npyv_f32 a) +{ return _mm_mul_ps(a, a); } +NPY_FINLINE npyv_f64 npyv_square_f64(npyv_f64 a) +{ return _mm_mul_pd(a, a); } + +// Maximum, 
natively mapping with no guarantees to handle NaN. +#define npyv_max_f32 _mm_max_ps +#define npyv_max_f64 _mm_max_pd +// Maximum, supports IEEE floating-point arithmetic (IEC 60559), +// - If one of the two vectors contains NaN, the equivalent element of the other vector is set +// - Only if both corresponded elements are NaN, NaN is set. +NPY_FINLINE npyv_f32 npyv_maxp_f32(npyv_f32 a, npyv_f32 b) +{ + __m128i nn = npyv_notnan_f32(b); + __m128 max = _mm_max_ps(a, b); + return npyv_select_f32(nn, max, a); +} +NPY_FINLINE npyv_f64 npyv_maxp_f64(npyv_f64 a, npyv_f64 b) +{ + __m128i nn = npyv_notnan_f64(b); + __m128d max = _mm_max_pd(a, b); + return npyv_select_f64(nn, max, a); +} +NPY_FINLINE npyv_f32 npyv_maxn_f32(npyv_f32 a, npyv_f32 b) +{ + __m128i nn = npyv_notnan_f32(a); + __m128 max = _mm_max_ps(a, b); + return npyv_select_f32(nn, max, a); +} +NPY_FINLINE npyv_f64 npyv_maxn_f64(npyv_f64 a, npyv_f64 b) +{ + __m128i nn = npyv_notnan_f64(a); + __m128d max = _mm_max_pd(a, b); + return npyv_select_f64(nn, max, a); +} +// Maximum, integer operations +#ifdef NPY_HAVE_SSE41 + #define npyv_max_s8 _mm_max_epi8 + #define npyv_max_u16 _mm_max_epu16 + #define npyv_max_u32 _mm_max_epu32 + #define npyv_max_s32 _mm_max_epi32 +#else + NPY_FINLINE npyv_s8 npyv_max_s8(npyv_s8 a, npyv_s8 b) + { + return npyv_select_s8(npyv_cmpgt_s8(a, b), a, b); + } + NPY_FINLINE npyv_u16 npyv_max_u16(npyv_u16 a, npyv_u16 b) + { + return npyv_select_u16(npyv_cmpgt_u16(a, b), a, b); + } + NPY_FINLINE npyv_u32 npyv_max_u32(npyv_u32 a, npyv_u32 b) + { + return npyv_select_u32(npyv_cmpgt_u32(a, b), a, b); + } + NPY_FINLINE npyv_s32 npyv_max_s32(npyv_s32 a, npyv_s32 b) + { + return npyv_select_s32(npyv_cmpgt_s32(a, b), a, b); + } +#endif +#define npyv_max_u8 _mm_max_epu8 +#define npyv_max_s16 _mm_max_epi16 +NPY_FINLINE npyv_u64 npyv_max_u64(npyv_u64 a, npyv_u64 b) +{ + return npyv_select_u64(npyv_cmpgt_u64(a, b), a, b); +} +NPY_FINLINE npyv_s64 npyv_max_s64(npyv_s64 a, npyv_s64 b) +{ + return npyv_select_s64(npyv_cmpgt_s64(a, b), a, b); +} + +// Minimum, natively mapping with no guarantees to handle NaN. +#define npyv_min_f32 _mm_min_ps +#define npyv_min_f64 _mm_min_pd +// Minimum, supports IEEE floating-point arithmetic (IEC 60559), +// - If one of the two vectors contains NaN, the equivalent element of the other vector is set +// - Only if both corresponded elements are NaN, NaN is set. 
+NPY_FINLINE npyv_f32 npyv_minp_f32(npyv_f32 a, npyv_f32 b) +{ + __m128i nn = npyv_notnan_f32(b); + __m128 min = _mm_min_ps(a, b); + return npyv_select_f32(nn, min, a); +} +NPY_FINLINE npyv_f64 npyv_minp_f64(npyv_f64 a, npyv_f64 b) +{ + __m128i nn = npyv_notnan_f64(b); + __m128d min = _mm_min_pd(a, b); + return npyv_select_f64(nn, min, a); +} +NPY_FINLINE npyv_f32 npyv_minn_f32(npyv_f32 a, npyv_f32 b) +{ + __m128i nn = npyv_notnan_f32(a); + __m128 min = _mm_min_ps(a, b); + return npyv_select_f32(nn, min, a); +} +NPY_FINLINE npyv_f64 npyv_minn_f64(npyv_f64 a, npyv_f64 b) +{ + __m128i nn = npyv_notnan_f64(a); + __m128d min = _mm_min_pd(a, b); + return npyv_select_f64(nn, min, a); +} +// Minimum, integer operations +#ifdef NPY_HAVE_SSE41 + #define npyv_min_s8 _mm_min_epi8 + #define npyv_min_u16 _mm_min_epu16 + #define npyv_min_u32 _mm_min_epu32 + #define npyv_min_s32 _mm_min_epi32 +#else + NPY_FINLINE npyv_s8 npyv_min_s8(npyv_s8 a, npyv_s8 b) + { + return npyv_select_s8(npyv_cmplt_s8(a, b), a, b); + } + NPY_FINLINE npyv_u16 npyv_min_u16(npyv_u16 a, npyv_u16 b) + { + return npyv_select_u16(npyv_cmplt_u16(a, b), a, b); + } + NPY_FINLINE npyv_u32 npyv_min_u32(npyv_u32 a, npyv_u32 b) + { + return npyv_select_u32(npyv_cmplt_u32(a, b), a, b); + } + NPY_FINLINE npyv_s32 npyv_min_s32(npyv_s32 a, npyv_s32 b) + { + return npyv_select_s32(npyv_cmplt_s32(a, b), a, b); + } +#endif +#define npyv_min_u8 _mm_min_epu8 +#define npyv_min_s16 _mm_min_epi16 +NPY_FINLINE npyv_u64 npyv_min_u64(npyv_u64 a, npyv_u64 b) +{ + return npyv_select_u64(npyv_cmplt_u64(a, b), a, b); +} +NPY_FINLINE npyv_s64 npyv_min_s64(npyv_s64 a, npyv_s64 b) +{ + return npyv_select_s64(npyv_cmplt_s64(a, b), a, b); +} + +// reduce min&max for 32&64-bits +#define NPY_IMPL_SSE_REDUCE_MINMAX(STYPE, INTRIN, VINTRIN) \ + NPY_FINLINE STYPE##32 npyv_reduce_##INTRIN##32(__m128i a) \ + { \ + __m128i v64 = npyv_##INTRIN##32(a, _mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 3, 2))); \ + __m128i v32 = npyv_##INTRIN##32(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); \ + return (STYPE##32)_mm_cvtsi128_si32(v32); \ + } \ + NPY_FINLINE STYPE##64 npyv_reduce_##INTRIN##64(__m128i a) \ + { \ + __m128i v64 = npyv_##INTRIN##64(a, _mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 3, 2))); \ + return (STYPE##64)npyv_extract0_u64(v64); \ + } + +NPY_IMPL_SSE_REDUCE_MINMAX(npy_uint, min_u, min_epu) +NPY_IMPL_SSE_REDUCE_MINMAX(npy_int, min_s, min_epi) +NPY_IMPL_SSE_REDUCE_MINMAX(npy_uint, max_u, max_epu) +NPY_IMPL_SSE_REDUCE_MINMAX(npy_int, max_s, max_epi) +#undef NPY_IMPL_SSE_REDUCE_MINMAX +// reduce min&max for ps & pd +#define NPY_IMPL_SSE_REDUCE_MINMAX(INTRIN, INF, INF64) \ + NPY_FINLINE float npyv_reduce_##INTRIN##_f32(npyv_f32 a) \ + { \ + __m128 v64 = _mm_##INTRIN##_ps(a, _mm_shuffle_ps(a, a, _MM_SHUFFLE(0, 0, 3, 2))); \ + __m128 v32 = _mm_##INTRIN##_ps(v64, _mm_shuffle_ps(v64, v64, _MM_SHUFFLE(0, 0, 0, 1))); \ + return _mm_cvtss_f32(v32); \ + } \ + NPY_FINLINE double npyv_reduce_##INTRIN##_f64(npyv_f64 a) \ + { \ + __m128d v64 = _mm_##INTRIN##_pd(a, _mm_shuffle_pd(a, a, _MM_SHUFFLE(0, 0, 0, 1))); \ + return _mm_cvtsd_f64(v64); \ + } \ + NPY_FINLINE float npyv_reduce_##INTRIN##p_f32(npyv_f32 a) \ + { \ + npyv_b32 notnan = npyv_notnan_f32(a); \ + if (NPY_UNLIKELY(!npyv_any_b32(notnan))) { \ + return _mm_cvtss_f32(a); \ + } \ + a = npyv_select_f32(notnan, a, npyv_reinterpret_f32_u32(npyv_setall_u32(INF))); \ + return npyv_reduce_##INTRIN##_f32(a); \ + } \ + NPY_FINLINE double npyv_reduce_##INTRIN##p_f64(npyv_f64 a) \ + { \ + npyv_b64 notnan = npyv_notnan_f64(a); \ + if 
(NPY_UNLIKELY(!npyv_any_b64(notnan))) { \ + return _mm_cvtsd_f64(a); \ + } \ + a = npyv_select_f64(notnan, a, npyv_reinterpret_f64_u64(npyv_setall_u64(INF64))); \ + return npyv_reduce_##INTRIN##_f64(a); \ + } \ + NPY_FINLINE float npyv_reduce_##INTRIN##n_f32(npyv_f32 a) \ + { \ + npyv_b32 notnan = npyv_notnan_f32(a); \ + if (NPY_UNLIKELY(!npyv_all_b32(notnan))) { \ + const union { npy_uint32 i; float f;} pnan = {0x7fc00000UL}; \ + return pnan.f; \ + } \ + return npyv_reduce_##INTRIN##_f32(a); \ + } \ + NPY_FINLINE double npyv_reduce_##INTRIN##n_f64(npyv_f64 a) \ + { \ + npyv_b64 notnan = npyv_notnan_f64(a); \ + if (NPY_UNLIKELY(!npyv_all_b64(notnan))) { \ + const union { npy_uint64 i; double d;} pnan = {0x7ff8000000000000ull}; \ + return pnan.d; \ + } \ + return npyv_reduce_##INTRIN##_f64(a); \ + } + +NPY_IMPL_SSE_REDUCE_MINMAX(min, 0x7f800000, 0x7ff0000000000000) +NPY_IMPL_SSE_REDUCE_MINMAX(max, 0xff800000, 0xfff0000000000000) +#undef NPY_IMPL_SSE_REDUCE_MINMAX + +// reduce min&max for 8&16-bits +#define NPY_IMPL_SSE_REDUCE_MINMAX(STYPE, INTRIN) \ + NPY_FINLINE STYPE##16 npyv_reduce_##INTRIN##16(__m128i a) \ + { \ + __m128i v64 = npyv_##INTRIN##16(a, _mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 3, 2))); \ + __m128i v32 = npyv_##INTRIN##16(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); \ + __m128i v16 = npyv_##INTRIN##16(v32, _mm_shufflelo_epi16(v32, _MM_SHUFFLE(0, 0, 0, 1))); \ + return (STYPE##16)_mm_cvtsi128_si32(v16); \ + } \ + NPY_FINLINE STYPE##8 npyv_reduce_##INTRIN##8(__m128i a) \ + { \ + __m128i v64 = npyv_##INTRIN##8(a, _mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 3, 2))); \ + __m128i v32 = npyv_##INTRIN##8(v64, _mm_shuffle_epi32(v64, _MM_SHUFFLE(0, 0, 0, 1))); \ + __m128i v16 = npyv_##INTRIN##8(v32, _mm_shufflelo_epi16(v32, _MM_SHUFFLE(0, 0, 0, 1))); \ + __m128i v8 = npyv_##INTRIN##8(v16, _mm_srli_epi16(v16, 8)); \ + return (STYPE##16)_mm_cvtsi128_si32(v8); \ + } +NPY_IMPL_SSE_REDUCE_MINMAX(npy_uint, min_u) +NPY_IMPL_SSE_REDUCE_MINMAX(npy_int, min_s) +NPY_IMPL_SSE_REDUCE_MINMAX(npy_uint, max_u) +NPY_IMPL_SSE_REDUCE_MINMAX(npy_int, max_s) +#undef NPY_IMPL_SSE_REDUCE_MINMAX + +// round to nearest integer even +NPY_FINLINE npyv_f32 npyv_rint_f32(npyv_f32 a) +{ +#ifdef NPY_HAVE_SSE41 + return _mm_round_ps(a, _MM_FROUND_TO_NEAREST_INT); +#else + const __m128 szero = _mm_set1_ps(-0.0f); + const __m128i exp_mask = _mm_set1_epi32(0xff000000); + + __m128i nfinite_mask = _mm_slli_epi32(_mm_castps_si128(a), 1); + nfinite_mask = _mm_and_si128(nfinite_mask, exp_mask); + nfinite_mask = _mm_cmpeq_epi32(nfinite_mask, exp_mask); + + // eliminate nans/inf to avoid invalid fp errors + __m128 x = _mm_xor_ps(a, _mm_castsi128_ps(nfinite_mask)); + __m128i roundi = _mm_cvtps_epi32(x); + __m128 round = _mm_cvtepi32_ps(roundi); + // respect signed zero + round = _mm_or_ps(round, _mm_and_ps(a, szero)); + // if overflow return a + __m128i overflow_mask = _mm_cmpeq_epi32(roundi, _mm_castps_si128(szero)); + // a if a overflow or nonfinite + return npyv_select_f32(_mm_or_si128(nfinite_mask, overflow_mask), a, round); +#endif +} + +// round to nearest integer even +NPY_FINLINE npyv_f64 npyv_rint_f64(npyv_f64 a) +{ +#ifdef NPY_HAVE_SSE41 + return _mm_round_pd(a, _MM_FROUND_TO_NEAREST_INT); +#else + const __m128d szero = _mm_set1_pd(-0.0); + const __m128d two_power_52 = _mm_set1_pd(0x10000000000000); + __m128d nan_mask = _mm_cmpunord_pd(a, a); + // eliminate nans to avoid invalid fp errors within cmpge + __m128d abs_x = npyv_abs_f64(_mm_xor_pd(nan_mask, a)); + // round by add magic number 2^52 + // assuming that 
MXCSR register is set to rounding + __m128d round = _mm_sub_pd(_mm_add_pd(two_power_52, abs_x), two_power_52); + // copysign + round = _mm_or_pd(round, _mm_and_pd(a, szero)); + // a if |a| >= 2^52 or a == NaN + __m128d mask = _mm_cmpge_pd(abs_x, two_power_52); + mask = _mm_or_pd(mask, nan_mask); + return npyv_select_f64(_mm_castpd_si128(mask), a, round); +#endif +} +// ceil +#ifdef NPY_HAVE_SSE41 + #define npyv_ceil_f32 _mm_ceil_ps + #define npyv_ceil_f64 _mm_ceil_pd +#else + NPY_FINLINE npyv_f32 npyv_ceil_f32(npyv_f32 a) + { + const __m128 one = _mm_set1_ps(1.0f); + const __m128 szero = _mm_set1_ps(-0.0f); + const __m128i exp_mask = _mm_set1_epi32(0xff000000); + + __m128i nfinite_mask = _mm_slli_epi32(_mm_castps_si128(a), 1); + nfinite_mask = _mm_and_si128(nfinite_mask, exp_mask); + nfinite_mask = _mm_cmpeq_epi32(nfinite_mask, exp_mask); + + // eliminate nans/inf to avoid invalid fp errors + __m128 x = _mm_xor_ps(a, _mm_castsi128_ps(nfinite_mask)); + __m128i roundi = _mm_cvtps_epi32(x); + __m128 round = _mm_cvtepi32_ps(roundi); + __m128 ceil = _mm_add_ps(round, _mm_and_ps(_mm_cmplt_ps(round, x), one)); + // respect signed zero + ceil = _mm_or_ps(ceil, _mm_and_ps(a, szero)); + // if overflow return a + __m128i overflow_mask = _mm_cmpeq_epi32(roundi, _mm_castps_si128(szero)); + // a if a overflow or nonfinite + return npyv_select_f32(_mm_or_si128(nfinite_mask, overflow_mask), a, ceil); + } + NPY_FINLINE npyv_f64 npyv_ceil_f64(npyv_f64 a) + { + const __m128d one = _mm_set1_pd(1.0); + const __m128d szero = _mm_set1_pd(-0.0); + const __m128d two_power_52 = _mm_set1_pd(0x10000000000000); + __m128d nan_mask = _mm_cmpunord_pd(a, a); + // eliminate nans to avoid invalid fp errors within cmpge + __m128d x = _mm_xor_pd(nan_mask, a); + __m128d abs_x = npyv_abs_f64(x); + __m128d sign_x = _mm_and_pd(x, szero); + // round by add magic number 2^52 + // assuming that MXCSR register is set to rounding + __m128d round = _mm_sub_pd(_mm_add_pd(two_power_52, abs_x), two_power_52); + // copysign + round = _mm_or_pd(round, sign_x); + __m128d ceil = _mm_add_pd(round, _mm_and_pd(_mm_cmplt_pd(round, x), one)); + // respects sign of 0.0 + ceil = _mm_or_pd(ceil, sign_x); + // a if |a| >= 2^52 or a == NaN + __m128d mask = _mm_cmpge_pd(abs_x, two_power_52); + mask = _mm_or_pd(mask, nan_mask); + return npyv_select_f64(_mm_castpd_si128(mask), a, ceil); + } +#endif + +// trunc +#ifdef NPY_HAVE_SSE41 + #define npyv_trunc_f32(A) _mm_round_ps(A, _MM_FROUND_TO_ZERO) + #define npyv_trunc_f64(A) _mm_round_pd(A, _MM_FROUND_TO_ZERO) +#else + NPY_FINLINE npyv_f32 npyv_trunc_f32(npyv_f32 a) + { + const __m128 szero = _mm_set1_ps(-0.0f); + const __m128i exp_mask = _mm_set1_epi32(0xff000000); + + __m128i nfinite_mask = _mm_slli_epi32(_mm_castps_si128(a), 1); + nfinite_mask = _mm_and_si128(nfinite_mask, exp_mask); + nfinite_mask = _mm_cmpeq_epi32(nfinite_mask, exp_mask); + + // eliminate nans/inf to avoid invalid fp errors + __m128 x = _mm_xor_ps(a, _mm_castsi128_ps(nfinite_mask)); + __m128i trunci = _mm_cvttps_epi32(x); + __m128 trunc = _mm_cvtepi32_ps(trunci); + // respect signed zero, e.g. 
-0.5 -> -0.0 + trunc = _mm_or_ps(trunc, _mm_and_ps(a, szero)); + // if overflow return a + __m128i overflow_mask = _mm_cmpeq_epi32(trunci, _mm_castps_si128(szero)); + // a if a overflow or nonfinite + return npyv_select_f32(_mm_or_si128(nfinite_mask, overflow_mask), a, trunc); + } + NPY_FINLINE npyv_f64 npyv_trunc_f64(npyv_f64 a) + { + const __m128d one = _mm_set1_pd(1.0); + const __m128d szero = _mm_set1_pd(-0.0); + const __m128d two_power_52 = _mm_set1_pd(0x10000000000000); + __m128d nan_mask = _mm_cmpunord_pd(a, a); + // eliminate nans to avoid invalid fp errors within cmpge + __m128d abs_x = npyv_abs_f64(_mm_xor_pd(nan_mask, a)); + // round by add magic number 2^52 + // assuming that MXCSR register is set to rounding + __m128d abs_round = _mm_sub_pd(_mm_add_pd(two_power_52, abs_x), two_power_52); + __m128d subtrahend = _mm_and_pd(_mm_cmpgt_pd(abs_round, abs_x), one); + __m128d trunc = _mm_sub_pd(abs_round, subtrahend); + // copysign + trunc = _mm_or_pd(trunc, _mm_and_pd(a, szero)); + // a if |a| >= 2^52 or a == NaN + __m128d mask = _mm_cmpge_pd(abs_x, two_power_52); + mask = _mm_or_pd(mask, nan_mask); + return npyv_select_f64(_mm_castpd_si128(mask), a, trunc); + } +#endif + +// floor +#ifdef NPY_HAVE_SSE41 + #define npyv_floor_f32 _mm_floor_ps + #define npyv_floor_f64 _mm_floor_pd +#else + NPY_FINLINE npyv_f32 npyv_floor_f32(npyv_f32 a) + { + const __m128 one = _mm_set1_ps(1.0f); + const __m128 szero = _mm_set1_ps(-0.0f); + const __m128i exp_mask = _mm_set1_epi32(0xff000000); + + __m128i nfinite_mask = _mm_slli_epi32(_mm_castps_si128(a), 1); + nfinite_mask = _mm_and_si128(nfinite_mask, exp_mask); + nfinite_mask = _mm_cmpeq_epi32(nfinite_mask, exp_mask); + + // eliminate nans/inf to avoid invalid fp errors + __m128 x = _mm_xor_ps(a, _mm_castsi128_ps(nfinite_mask)); + __m128i roundi = _mm_cvtps_epi32(x); + __m128 round = _mm_cvtepi32_ps(roundi); + __m128 floor = _mm_sub_ps(round, _mm_and_ps(_mm_cmpgt_ps(round, x), one)); + // respect signed zero + floor = _mm_or_ps(floor, _mm_and_ps(a, szero)); + // if overflow return a + __m128i overflow_mask = _mm_cmpeq_epi32(roundi, _mm_castps_si128(szero)); + // a if a overflow or nonfinite + return npyv_select_f32(_mm_or_si128(nfinite_mask, overflow_mask), a, floor); + } + NPY_FINLINE npyv_f64 npyv_floor_f64(npyv_f64 a) + { + const __m128d one = _mm_set1_pd(1.0f); + const __m128d szero = _mm_set1_pd(-0.0f); + const __m128d two_power_52 = _mm_set1_pd(0x10000000000000); + __m128d nan_mask = _mm_cmpunord_pd(a, a); + // eliminate nans to avoid invalid fp errors within cmpge + __m128d x = _mm_xor_pd(nan_mask, a); + __m128d abs_x = npyv_abs_f64(x); + __m128d sign_x = _mm_and_pd(x, szero); + // round by add magic number 2^52 + // assuming that MXCSR register is set to rounding + __m128d round = _mm_sub_pd(_mm_add_pd(two_power_52, abs_x), two_power_52); + // copysign + round = _mm_or_pd(round, sign_x); + __m128d floor = _mm_sub_pd(round, _mm_and_pd(_mm_cmpgt_pd(round, x), one)); + // a if |a| >= 2^52 or a == NaN + __m128d mask = _mm_cmpge_pd(abs_x, two_power_52); + mask = _mm_or_pd(mask, nan_mask); + return npyv_select_f64(_mm_castpd_si128(mask), a, floor); + } +#endif // NPY_HAVE_SSE41 + +#endif // _NPY_SIMD_SSE_MATH_H diff --git a/mkl_umath/src/npyv/sse/memory.h b/mkl_umath/src/npyv/sse/memory.h new file mode 100644 index 00000000..90c01ffe --- /dev/null +++ b/mkl_umath/src/npyv/sse/memory.h @@ -0,0 +1,759 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_SSE_MEMORY_H +#define _NPY_SIMD_SSE_MEMORY_H + +#include "misc.h" 
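+
+/*
+ * Naming note (descriptive only, summarizing the definitions below): plain
+ * load/store are unaligned, the "a" suffix is aligned, "s" is stream
+ * (non-temporal), "l"/"h" act on the lower/higher 64-bit half, and the
+ * "n"/"n2" variants take an element stride for non-contiguous access, e.g.
+ *
+ *   // gathers ptr[0], ptr[stride], ptr[2*stride], ptr[3*stride]
+ *   npyv_f32 v = npyv_loadn_f32(ptr, stride);
+ */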
+ +/*************************** + * load/store + ***************************/ +// stream load +#ifdef NPY_HAVE_SSE41 + #define npyv__loads(PTR) _mm_stream_load_si128((__m128i *)(PTR)) +#else + #define npyv__loads(PTR) _mm_load_si128((const __m128i *)(PTR)) +#endif +#define NPYV_IMPL_SSE_MEM_INT(CTYPE, SFX) \ + NPY_FINLINE npyv_##SFX npyv_load_##SFX(const CTYPE *ptr) \ + { return _mm_loadu_si128((const __m128i*)ptr); } \ + NPY_FINLINE npyv_##SFX npyv_loada_##SFX(const CTYPE *ptr) \ + { return _mm_load_si128((const __m128i*)ptr); } \ + NPY_FINLINE npyv_##SFX npyv_loads_##SFX(const CTYPE *ptr) \ + { return npyv__loads(ptr); } \ + NPY_FINLINE npyv_##SFX npyv_loadl_##SFX(const CTYPE *ptr) \ + { return _mm_loadl_epi64((const __m128i*)ptr); } \ + NPY_FINLINE void npyv_store_##SFX(CTYPE *ptr, npyv_##SFX vec) \ + { _mm_storeu_si128((__m128i*)ptr, vec); } \ + NPY_FINLINE void npyv_storea_##SFX(CTYPE *ptr, npyv_##SFX vec) \ + { _mm_store_si128((__m128i*)ptr, vec); } \ + NPY_FINLINE void npyv_stores_##SFX(CTYPE *ptr, npyv_##SFX vec) \ + { _mm_stream_si128((__m128i*)ptr, vec); } \ + NPY_FINLINE void npyv_storel_##SFX(CTYPE *ptr, npyv_##SFX vec) \ + { _mm_storel_epi64((__m128i *)ptr, vec); } \ + NPY_FINLINE void npyv_storeh_##SFX(CTYPE *ptr, npyv_##SFX vec) \ + { _mm_storel_epi64((__m128i *)ptr, _mm_unpackhi_epi64(vec, vec)); } + +NPYV_IMPL_SSE_MEM_INT(npy_uint8, u8) +NPYV_IMPL_SSE_MEM_INT(npy_int8, s8) +NPYV_IMPL_SSE_MEM_INT(npy_uint16, u16) +NPYV_IMPL_SSE_MEM_INT(npy_int16, s16) +NPYV_IMPL_SSE_MEM_INT(npy_uint32, u32) +NPYV_IMPL_SSE_MEM_INT(npy_int32, s32) +NPYV_IMPL_SSE_MEM_INT(npy_uint64, u64) +NPYV_IMPL_SSE_MEM_INT(npy_int64, s64) + +// unaligned load +#define npyv_load_f32 _mm_loadu_ps +#define npyv_load_f64 _mm_loadu_pd +// aligned load +#define npyv_loada_f32 _mm_load_ps +#define npyv_loada_f64 _mm_load_pd +// load lower part +#define npyv_loadl_f32(PTR) _mm_castsi128_ps(npyv_loadl_u32((const npy_uint32*)(PTR))) +#define npyv_loadl_f64(PTR) _mm_castsi128_pd(npyv_loadl_u32((const npy_uint32*)(PTR))) +// stream load +#define npyv_loads_f32(PTR) _mm_castsi128_ps(npyv__loads(PTR)) +#define npyv_loads_f64(PTR) _mm_castsi128_pd(npyv__loads(PTR)) +// unaligned store +#define npyv_store_f32 _mm_storeu_ps +#define npyv_store_f64 _mm_storeu_pd +// aligned store +#define npyv_storea_f32 _mm_store_ps +#define npyv_storea_f64 _mm_store_pd +// stream store +#define npyv_stores_f32 _mm_stream_ps +#define npyv_stores_f64 _mm_stream_pd +// store lower part +#define npyv_storel_f32(PTR, VEC) _mm_storel_epi64((__m128i*)(PTR), _mm_castps_si128(VEC)); +#define npyv_storel_f64(PTR, VEC) _mm_storel_epi64((__m128i*)(PTR), _mm_castpd_si128(VEC)); +// store higher part +#define npyv_storeh_f32(PTR, VEC) npyv_storeh_u32((npy_uint32*)(PTR), _mm_castps_si128(VEC)) +#define npyv_storeh_f64(PTR, VEC) npyv_storeh_u32((npy_uint32*)(PTR), _mm_castpd_si128(VEC)) +/*************************** + * Non-contiguous Load + ***************************/ +//// 32 +NPY_FINLINE npyv_s32 npyv_loadn_s32(const npy_int32 *ptr, npy_intp stride) +{ + __m128i a = _mm_cvtsi32_si128(*ptr); +#ifdef NPY_HAVE_SSE41 + a = _mm_insert_epi32(a, ptr[stride], 1); + a = _mm_insert_epi32(a, ptr[stride*2], 2); + a = _mm_insert_epi32(a, ptr[stride*3], 3); +#else + __m128i a1 = _mm_cvtsi32_si128(ptr[stride]); + __m128i a2 = _mm_cvtsi32_si128(ptr[stride*2]); + __m128i a3 = _mm_cvtsi32_si128(ptr[stride*3]); + a = _mm_unpacklo_epi32(a, a1); + a = _mm_unpacklo_epi64(a, _mm_unpacklo_epi32(a2, a3)); +#endif + return a; +} +NPY_FINLINE npyv_u32 npyv_loadn_u32(const 
npy_uint32 *ptr, npy_intp stride) +{ return npyv_loadn_s32((const npy_int32*)ptr, stride); } +NPY_FINLINE npyv_f32 npyv_loadn_f32(const float *ptr, npy_intp stride) +{ return _mm_castsi128_ps(npyv_loadn_s32((const npy_int32*)ptr, stride)); } +//// 64 +NPY_FINLINE npyv_f64 npyv_loadn_f64(const double *ptr, npy_intp stride) +{ return _mm_loadh_pd(npyv_loadl_f64(ptr), ptr + stride); } +NPY_FINLINE npyv_u64 npyv_loadn_u64(const npy_uint64 *ptr, npy_intp stride) +{ return _mm_castpd_si128(npyv_loadn_f64((const double*)ptr, stride)); } +NPY_FINLINE npyv_s64 npyv_loadn_s64(const npy_int64 *ptr, npy_intp stride) +{ return _mm_castpd_si128(npyv_loadn_f64((const double*)ptr, stride)); } + +//// 64-bit load over 32-bit stride +NPY_FINLINE npyv_f32 npyv_loadn2_f32(const float *ptr, npy_intp stride) +{ + __m128d r = _mm_loadh_pd( + npyv_loadl_f64((const double*)ptr), (const double*)(ptr + stride) + ); + return _mm_castpd_ps(r); +} +NPY_FINLINE npyv_u32 npyv_loadn2_u32(const npy_uint32 *ptr, npy_intp stride) +{ return _mm_castps_si128(npyv_loadn2_f32((const float*)ptr, stride)); } +NPY_FINLINE npyv_s32 npyv_loadn2_s32(const npy_int32 *ptr, npy_intp stride) +{ return _mm_castps_si128(npyv_loadn2_f32((const float*)ptr, stride)); } + +//// 128-bit load over 64-bit stride +NPY_FINLINE npyv_f64 npyv_loadn2_f64(const double *ptr, npy_intp stride) +{ (void)stride; return npyv_load_f64(ptr); } +NPY_FINLINE npyv_u64 npyv_loadn2_u64(const npy_uint64 *ptr, npy_intp stride) +{ (void)stride; return npyv_load_u64(ptr); } +NPY_FINLINE npyv_s64 npyv_loadn2_s64(const npy_int64 *ptr, npy_intp stride) +{ (void)stride; return npyv_load_s64(ptr); } + +/*************************** + * Non-contiguous Store + ***************************/ +//// 32 +NPY_FINLINE void npyv_storen_s32(npy_int32 *ptr, npy_intp stride, npyv_s32 a) +{ + ptr[stride * 0] = _mm_cvtsi128_si32(a); +#ifdef NPY_HAVE_SSE41 + ptr[stride * 1] = _mm_extract_epi32(a, 1); + ptr[stride * 2] = _mm_extract_epi32(a, 2); + ptr[stride * 3] = _mm_extract_epi32(a, 3); +#else + ptr[stride * 1] = _mm_cvtsi128_si32(_mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 0, 1))); + ptr[stride * 2] = _mm_cvtsi128_si32(_mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 0, 2))); + ptr[stride * 3] = _mm_cvtsi128_si32(_mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 0, 3))); +#endif +} +NPY_FINLINE void npyv_storen_u32(npy_uint32 *ptr, npy_intp stride, npyv_u32 a) +{ npyv_storen_s32((npy_int32*)ptr, stride, a); } +NPY_FINLINE void npyv_storen_f32(float *ptr, npy_intp stride, npyv_f32 a) +{ npyv_storen_s32((npy_int32*)ptr, stride, _mm_castps_si128(a)); } +//// 64 +NPY_FINLINE void npyv_storen_f64(double *ptr, npy_intp stride, npyv_f64 a) +{ + _mm_storel_pd(ptr, a); + _mm_storeh_pd(ptr + stride, a); +} +NPY_FINLINE void npyv_storen_u64(npy_uint64 *ptr, npy_intp stride, npyv_u64 a) +{ npyv_storen_f64((double*)ptr, stride, _mm_castsi128_pd(a)); } +NPY_FINLINE void npyv_storen_s64(npy_int64 *ptr, npy_intp stride, npyv_s64 a) +{ npyv_storen_f64((double*)ptr, stride, _mm_castsi128_pd(a)); } + +//// 64-bit store over 32-bit stride +NPY_FINLINE void npyv_storen2_u32(npy_uint32 *ptr, npy_intp stride, npyv_u32 a) +{ + _mm_storel_pd((double*)ptr, _mm_castsi128_pd(a)); + _mm_storeh_pd((double*)(ptr + stride), _mm_castsi128_pd(a)); +} +NPY_FINLINE void npyv_storen2_s32(npy_int32 *ptr, npy_intp stride, npyv_s32 a) +{ npyv_storen2_u32((npy_uint32*)ptr, stride, a); } +NPY_FINLINE void npyv_storen2_f32(float *ptr, npy_intp stride, npyv_f32 a) +{ npyv_storen2_u32((npy_uint32*)ptr, stride, _mm_castps_si128(a)); } + +//// 128-bit store 
over 64-bit stride +NPY_FINLINE void npyv_storen2_u64(npy_uint64 *ptr, npy_intp stride, npyv_u64 a) +{ (void)stride; npyv_store_u64(ptr, a); } +NPY_FINLINE void npyv_storen2_s64(npy_int64 *ptr, npy_intp stride, npyv_s64 a) +{ (void)stride; npyv_store_s64(ptr, a); } +NPY_FINLINE void npyv_storen2_f64(double *ptr, npy_intp stride, npyv_f64 a) +{ (void)stride; npyv_store_f64(ptr, a); } +/********************************* + * Partial Load + *********************************/ +//// 32 +NPY_FINLINE npyv_s32 npyv_load_till_s32(const npy_int32 *ptr, npy_uintp nlane, npy_int32 fill) +{ + assert(nlane > 0); + #ifndef NPY_HAVE_SSE41 + const short *wptr = (const short*)ptr; + #endif + const __m128i vfill = npyv_setall_s32(fill); + __m128i a; + switch(nlane) { + case 2: + a = _mm_castpd_si128( + _mm_loadl_pd(_mm_castsi128_pd(vfill), (double*)ptr) + ); + break; + #ifdef NPY_HAVE_SSE41 + case 1: + a = _mm_insert_epi32(vfill, ptr[0], 0); + break; + case 3: + a = _mm_loadl_epi64((const __m128i*)ptr); + a = _mm_insert_epi32(a, ptr[2], 2); + a = _mm_insert_epi32(a, fill, 3); + break; + #else + case 1: + a = _mm_insert_epi16(vfill, wptr[0], 0); + a = _mm_insert_epi16(a, wptr[1], 1); + break; + case 3: + a = _mm_loadl_epi64((const __m128i*)ptr); + a = _mm_unpacklo_epi64(a, vfill); + a = _mm_insert_epi16(a, wptr[4], 4); + a = _mm_insert_epi16(a, wptr[5], 5); + break; + #endif // NPY_HAVE_SSE41 + default: + return npyv_load_s32(ptr); + } + #if NPY_SIMD_GUARD_PARTIAL_LOAD + // We use a variable marked 'volatile' to convince the compiler that + // the entire vector is needed. + volatile __m128i workaround = a; + // avoid optimizing it out + a = _mm_or_si128(workaround, a); + #endif + return a; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s32 npyv_load_tillz_s32(const npy_int32 *ptr, npy_uintp nlane) +{ + assert(nlane > 0); + switch(nlane) { + case 1: + return _mm_cvtsi32_si128(*ptr); + case 2: + return _mm_loadl_epi64((const __m128i*)ptr); + case 3: { + npyv_s32 a = _mm_loadl_epi64((const __m128i*)ptr); + #ifdef NPY_HAVE_SSE41 + return _mm_insert_epi32(a, ptr[2], 2); + #else + return _mm_unpacklo_epi64(a, _mm_cvtsi32_si128(ptr[2])); + #endif + } + default: + return npyv_load_s32(ptr); + } +} +//// 64 +NPY_FINLINE npyv_s64 npyv_load_till_s64(const npy_int64 *ptr, npy_uintp nlane, npy_int64 fill) +{ + assert(nlane > 0); + if (nlane == 1) { + const __m128i vfill = npyv_setall_s64(fill); + npyv_s64 a = _mm_castpd_si128( + _mm_loadl_pd(_mm_castsi128_pd(vfill), (double*)ptr) + ); + #if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m128i workaround = a; + a = _mm_or_si128(workaround, a); + #endif + return a; + } + return npyv_load_s64(ptr); +} +// fill zero to rest lanes +NPY_FINLINE npyv_s64 npyv_load_tillz_s64(const npy_int64 *ptr, npy_uintp nlane) +{ + assert(nlane > 0); + if (nlane == 1) { + return _mm_loadl_epi64((const __m128i*)ptr); + } + return npyv_load_s64(ptr); +} + +//// 64-bit nlane +NPY_FINLINE npyv_s32 npyv_load2_till_s32(const npy_int32 *ptr, npy_uintp nlane, + npy_int32 fill_lo, npy_int32 fill_hi) +{ + assert(nlane > 0); + if (nlane == 1) { + const __m128i vfill = npyv_set_s32(fill_lo, fill_hi, fill_lo, fill_hi); + __m128i a = _mm_castpd_si128( + _mm_loadl_pd(_mm_castsi128_pd(vfill), (double*)ptr) + ); + #if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m128i workaround = a; + a = _mm_or_si128(workaround, a); + #endif + return a; + } + return npyv_load_s32(ptr); +} +// fill zero to rest lanes +NPY_FINLINE npyv_s32 npyv_load2_tillz_s32(const npy_int32 *ptr, npy_uintp nlane) +{ return npyv_load_tillz_s64((const 
npy_int64*)ptr, nlane); } + +//// 128-bit nlane +NPY_FINLINE npyv_s64 npyv_load2_till_s64(const npy_int64 *ptr, npy_uintp nlane, + npy_int64 fill_lo, npy_int64 fill_hi) +{ (void)nlane; (void)fill_lo; (void)fill_hi; return npyv_load_s64(ptr); } + +NPY_FINLINE npyv_s64 npyv_load2_tillz_s64(const npy_int64 *ptr, npy_uintp nlane) +{ (void)nlane; return npyv_load_s64(ptr); } + +/********************************* + * Non-contiguous partial load + *********************************/ +//// 32 +NPY_FINLINE npyv_s32 +npyv_loadn_till_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane, npy_int32 fill) +{ + assert(nlane > 0); + __m128i vfill = npyv_setall_s32(fill); + #ifndef NPY_HAVE_SSE41 + const short *wptr = (const short*)ptr; + #endif + switch(nlane) { + #ifdef NPY_HAVE_SSE41 + case 3: + vfill = _mm_insert_epi32(vfill, ptr[stride*2], 2); + case 2: + vfill = _mm_insert_epi32(vfill, ptr[stride], 1); + case 1: + vfill = _mm_insert_epi32(vfill, ptr[0], 0); + break; + #else + case 3: + vfill = _mm_unpacklo_epi32(_mm_cvtsi32_si128(ptr[stride*2]), vfill); + case 2: + vfill = _mm_unpacklo_epi64(_mm_unpacklo_epi32( + _mm_cvtsi32_si128(*ptr), _mm_cvtsi32_si128(ptr[stride]) + ), vfill); + break; + case 1: + vfill = _mm_insert_epi16(vfill, wptr[0], 0); + vfill = _mm_insert_epi16(vfill, wptr[1], 1); + break; + #endif // NPY_HAVE_SSE41 + default: + return npyv_loadn_s32(ptr, stride); + } // switch +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m128i workaround = vfill; + vfill = _mm_or_si128(workaround, vfill); +#endif + return vfill; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s32 +npyv_loadn_tillz_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane) +{ + assert(nlane > 0); + switch(nlane) { + case 1: + return _mm_cvtsi32_si128(ptr[0]); + case 2:; + { + npyv_s32 a = _mm_cvtsi32_si128(ptr[0]); + #ifdef NPY_HAVE_SSE41 + return _mm_insert_epi32(a, ptr[stride], 1); + #else + return _mm_unpacklo_epi32(a, _mm_cvtsi32_si128(ptr[stride])); + #endif // NPY_HAVE_SSE41 + } + case 3: + { + npyv_s32 a = _mm_cvtsi32_si128(ptr[0]); + #ifdef NPY_HAVE_SSE41 + a = _mm_insert_epi32(a, ptr[stride], 1); + a = _mm_insert_epi32(a, ptr[stride*2], 2); + return a; + #else + a = _mm_unpacklo_epi32(a, _mm_cvtsi32_si128(ptr[stride])); + a = _mm_unpacklo_epi64(a, _mm_cvtsi32_si128(ptr[stride*2])); + return a; + #endif // NPY_HAVE_SSE41 + } + default: + return npyv_loadn_s32(ptr, stride); + } +} +//// 64 +NPY_FINLINE npyv_s64 +npyv_loadn_till_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane, npy_int64 fill) +{ + assert(nlane > 0); + if (nlane == 1) { + return npyv_load_till_s64(ptr, 1, fill); + } + return npyv_loadn_s64(ptr, stride); +} +// fill zero to rest lanes +NPY_FINLINE npyv_s64 npyv_loadn_tillz_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane) +{ + assert(nlane > 0); + if (nlane == 1) { + return _mm_loadl_epi64((const __m128i*)ptr); + } + return npyv_loadn_s64(ptr, stride); +} + +//// 64-bit load over 32-bit stride +NPY_FINLINE npyv_s32 npyv_loadn2_till_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane, + npy_int32 fill_lo, npy_int32 fill_hi) +{ + assert(nlane > 0); + if (nlane == 1) { + const __m128i vfill = npyv_set_s32(0, 0, fill_lo, fill_hi); + __m128i a = _mm_castpd_si128( + _mm_loadl_pd(_mm_castsi128_pd(vfill), (double*)ptr) + ); + #if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile __m128i workaround = a; + a = _mm_or_si128(workaround, a); + #endif + return a; + } + return npyv_loadn2_s32(ptr, stride); +} +NPY_FINLINE npyv_s32 npyv_loadn2_tillz_s32(const npy_int32 *ptr, npy_intp stride, 
npy_uintp nlane) +{ + assert(nlane > 0); + if (nlane == 1) { + return _mm_loadl_epi64((const __m128i*)ptr); + } + return npyv_loadn2_s32(ptr, stride); +} + +//// 128-bit load over 64-bit stride +NPY_FINLINE npyv_s64 npyv_loadn2_till_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane, + npy_int64 fill_lo, npy_int64 fill_hi) +{ assert(nlane > 0); (void)stride; (void)nlane; (void)fill_lo; (void)fill_hi; return npyv_load_s64(ptr); } + +NPY_FINLINE npyv_s64 npyv_loadn2_tillz_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane) +{ assert(nlane > 0); (void)stride; (void)nlane; return npyv_load_s64(ptr); } + +/********************************* + * Partial store + *********************************/ +//// 32 +NPY_FINLINE void npyv_store_till_s32(npy_int32 *ptr, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); + switch(nlane) { + case 1: + *ptr = _mm_cvtsi128_si32(a); + break; + case 2: + _mm_storel_epi64((__m128i *)ptr, a); + break; + case 3: + _mm_storel_epi64((__m128i *)ptr, a); + #ifdef NPY_HAVE_SSE41 + ptr[2] = _mm_extract_epi32(a, 2); + #else + ptr[2] = _mm_cvtsi128_si32(_mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 0, 2))); + #endif + break; + default: + npyv_store_s32(ptr, a); + } +} +//// 64 +NPY_FINLINE void npyv_store_till_s64(npy_int64 *ptr, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); + if (nlane == 1) { + _mm_storel_epi64((__m128i *)ptr, a); + return; + } + npyv_store_s64(ptr, a); +} +//// 64-bit nlane +NPY_FINLINE void npyv_store2_till_s32(npy_int32 *ptr, npy_uintp nlane, npyv_s32 a) +{ npyv_store_till_s64((npy_int64*)ptr, nlane, a); } + +//// 128-bit nlane +NPY_FINLINE void npyv_store2_till_s64(npy_int64 *ptr, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); (void)nlane; + npyv_store_s64(ptr, a); +} + +/********************************* + * Non-contiguous partial store + *********************************/ +//// 32 +NPY_FINLINE void npyv_storen_till_s32(npy_int32 *ptr, npy_intp stride, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); + ptr[stride*0] = _mm_cvtsi128_si32(a); + switch(nlane) { + case 1: + return; +#ifdef NPY_HAVE_SSE41 + case 2: + ptr[stride*1] = _mm_extract_epi32(a, 1); + return; + case 3: + ptr[stride*1] = _mm_extract_epi32(a, 1); + ptr[stride*2] = _mm_extract_epi32(a, 2); + return; + default: + ptr[stride*1] = _mm_extract_epi32(a, 1); + ptr[stride*2] = _mm_extract_epi32(a, 2); + ptr[stride*3] = _mm_extract_epi32(a, 3); +#else + case 2: + ptr[stride*1] = _mm_cvtsi128_si32(_mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 0, 1))); + return; + case 3: + ptr[stride*1] = _mm_cvtsi128_si32(_mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 0, 1))); + ptr[stride*2] = _mm_cvtsi128_si32(_mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 0, 2))); + return; + default: + ptr[stride*1] = _mm_cvtsi128_si32(_mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 0, 1))); + ptr[stride*2] = _mm_cvtsi128_si32(_mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 0, 2))); + ptr[stride*3] = _mm_cvtsi128_si32(_mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 0, 3))); +#endif + } +} +//// 64 +NPY_FINLINE void npyv_storen_till_s64(npy_int64 *ptr, npy_intp stride, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); + if (nlane == 1) { + _mm_storel_epi64((__m128i *)ptr, a); + return; + } + npyv_storen_s64(ptr, stride, a); +} + +//// 64-bit store over 32-bit stride +NPY_FINLINE void npyv_storen2_till_s32(npy_int32 *ptr, npy_intp stride, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); + npyv_storel_s32(ptr, a); + if (nlane > 1) { + npyv_storeh_s32(ptr + stride, a); + } +} + +//// 128-bit store over 64-bit stride +NPY_FINLINE void 
npyv_storen2_till_s64(npy_int64 *ptr, npy_intp stride, npy_uintp nlane, npyv_s64 a) +{ assert(nlane > 0); (void)stride; (void)nlane; npyv_store_s64(ptr, a); } + +/***************************************************************** + * Implement partial load/store for u32/f32/u64/f64... via casting + *****************************************************************/ +#define NPYV_IMPL_SSE_REST_PARTIAL_TYPES(F_SFX, T_SFX) \ + NPY_FINLINE npyv_##F_SFX npyv_load_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, npyv_lanetype_##F_SFX fill) \ + { \ + union { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + } pun; \ + pun.from_##F_SFX = fill; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane, pun.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, \ + npyv_lanetype_##F_SFX fill) \ + { \ + union { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + } pun; \ + pun.from_##F_SFX = fill; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane, pun.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_load_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane \ + )); \ + } \ + NPY_FINLINE void npyv_store_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_store_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } \ + NPY_FINLINE void npyv_storen_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_storen_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, stride, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } + +NPYV_IMPL_SSE_REST_PARTIAL_TYPES(u32, s32) +NPYV_IMPL_SSE_REST_PARTIAL_TYPES(f32, s32) +NPYV_IMPL_SSE_REST_PARTIAL_TYPES(u64, s64) +NPYV_IMPL_SSE_REST_PARTIAL_TYPES(f64, s64) + +// 128-bit/64-bit stride +#define NPYV_IMPL_SSE_REST_PARTIAL_TYPES_PAIR(F_SFX, T_SFX) \ + NPY_FINLINE npyv_##F_SFX npyv_load2_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, \ + npyv_lanetype_##F_SFX fill_lo, npyv_lanetype_##F_SFX fill_hi) \ + { \ + union pun { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + }; \ + union pun pun_lo; \ + union pun pun_hi; \ + pun_lo.from_##F_SFX = fill_lo; \ + pun_hi.from_##F_SFX = fill_hi; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load2_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane, pun_lo.to_##T_SFX, pun_hi.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn2_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, \ + npyv_lanetype_##F_SFX fill_lo, npyv_lanetype_##F_SFX fill_hi) \ + { \ + union pun { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + }; \ + union pun pun_lo; \ + union pun pun_hi; \ + pun_lo.from_##F_SFX = fill_lo; \ + pun_hi.from_##F_SFX 
= fill_hi; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn2_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane, pun_lo.to_##T_SFX, \ + pun_hi.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_load2_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load2_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn2_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn2_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane \ + )); \ + } \ + NPY_FINLINE void npyv_store2_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_store2_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } \ + NPY_FINLINE void npyv_storen2_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_storen2_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, stride, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } + +NPYV_IMPL_SSE_REST_PARTIAL_TYPES_PAIR(u32, s32) +NPYV_IMPL_SSE_REST_PARTIAL_TYPES_PAIR(f32, s32) +NPYV_IMPL_SSE_REST_PARTIAL_TYPES_PAIR(u64, s64) +NPYV_IMPL_SSE_REST_PARTIAL_TYPES_PAIR(f64, s64) + +/************************************************************ + * de-interlave load / interleave contiguous store + ************************************************************/ +// two channels +#define NPYV_IMPL_SSE_MEM_INTERLEAVE(SFX, ZSFX) \ + NPY_FINLINE npyv_##ZSFX##x2 npyv_zip_##ZSFX(npyv_##ZSFX, npyv_##ZSFX); \ + NPY_FINLINE npyv_##ZSFX##x2 npyv_unzip_##ZSFX(npyv_##ZSFX, npyv_##ZSFX); \ + NPY_FINLINE npyv_##SFX##x2 npyv_load_##SFX##x2( \ + const npyv_lanetype_##SFX *ptr \ + ) { \ + return npyv_unzip_##ZSFX( \ + npyv_load_##SFX(ptr), npyv_load_##SFX(ptr+npyv_nlanes_##SFX) \ + ); \ + } \ + NPY_FINLINE void npyv_store_##SFX##x2( \ + npyv_lanetype_##SFX *ptr, npyv_##SFX##x2 v \ + ) { \ + npyv_##SFX##x2 zip = npyv_zip_##ZSFX(v.val[0], v.val[1]); \ + npyv_store_##SFX(ptr, zip.val[0]); \ + npyv_store_##SFX(ptr + npyv_nlanes_##SFX, zip.val[1]); \ + } + +NPYV_IMPL_SSE_MEM_INTERLEAVE(u8, u8) +NPYV_IMPL_SSE_MEM_INTERLEAVE(s8, u8) +NPYV_IMPL_SSE_MEM_INTERLEAVE(u16, u16) +NPYV_IMPL_SSE_MEM_INTERLEAVE(s16, u16) +NPYV_IMPL_SSE_MEM_INTERLEAVE(u32, u32) +NPYV_IMPL_SSE_MEM_INTERLEAVE(s32, u32) +NPYV_IMPL_SSE_MEM_INTERLEAVE(u64, u64) +NPYV_IMPL_SSE_MEM_INTERLEAVE(s64, u64) +NPYV_IMPL_SSE_MEM_INTERLEAVE(f32, f32) +NPYV_IMPL_SSE_MEM_INTERLEAVE(f64, f64) + +/********************************* + * Lookup table + *********************************/ +// uses vector as indexes into a table +// that contains 32 elements of float32. 
+NPY_FINLINE npyv_f32 npyv_lut32_f32(const float *table, npyv_u32 idx) +{ + const int i0 = _mm_cvtsi128_si32(idx); +#ifdef NPY_HAVE_SSE41 + const int i1 = _mm_extract_epi32(idx, 1); + const int i2 = _mm_extract_epi32(idx, 2); + const int i3 = _mm_extract_epi32(idx, 3); +#else + const int i1 = _mm_extract_epi16(idx, 2); + const int i2 = _mm_extract_epi16(idx, 4); + const int i3 = _mm_extract_epi16(idx, 6); +#endif + return npyv_set_f32(table[i0], table[i1], table[i2], table[i3]); +} +NPY_FINLINE npyv_u32 npyv_lut32_u32(const npy_uint32 *table, npyv_u32 idx) +{ return npyv_reinterpret_u32_f32(npyv_lut32_f32((const float*)table, idx)); } +NPY_FINLINE npyv_s32 npyv_lut32_s32(const npy_int32 *table, npyv_u32 idx) +{ return npyv_reinterpret_s32_f32(npyv_lut32_f32((const float*)table, idx)); } + +// uses vector as indexes into a table +// that contains 16 elements of float64. +NPY_FINLINE npyv_f64 npyv_lut16_f64(const double *table, npyv_u64 idx) +{ + const int i0 = _mm_cvtsi128_si32(idx); +#ifdef NPY_HAVE_SSE41 + const int i1 = _mm_extract_epi32(idx, 2); +#else + const int i1 = _mm_extract_epi16(idx, 4); +#endif + return npyv_set_f64(table[i0], table[i1]); +} +NPY_FINLINE npyv_u64 npyv_lut16_u64(const npy_uint64 *table, npyv_u64 idx) +{ return npyv_reinterpret_u64_f64(npyv_lut16_f64((const double*)table, idx)); } +NPY_FINLINE npyv_s64 npyv_lut16_s64(const npy_int64 *table, npyv_u64 idx) +{ return npyv_reinterpret_s64_f64(npyv_lut16_f64((const double*)table, idx)); } + +#endif // _NPY_SIMD_SSE_MEMORY_H diff --git a/mkl_umath/src/npyv/sse/misc.h b/mkl_umath/src/npyv/sse/misc.h new file mode 100644 index 00000000..b01ff172 --- /dev/null +++ b/mkl_umath/src/npyv/sse/misc.h @@ -0,0 +1,258 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_SSE_MISC_H +#define _NPY_SIMD_SSE_MISC_H + +// vector with zero lanes +#define npyv_zero_u8 _mm_setzero_si128 +#define npyv_zero_s8 _mm_setzero_si128 +#define npyv_zero_u16 _mm_setzero_si128 +#define npyv_zero_s16 _mm_setzero_si128 +#define npyv_zero_u32 _mm_setzero_si128 +#define npyv_zero_s32 _mm_setzero_si128 +#define npyv_zero_u64 _mm_setzero_si128 +#define npyv_zero_s64 _mm_setzero_si128 +#define npyv_zero_f32 _mm_setzero_ps +#define npyv_zero_f64 _mm_setzero_pd + +// vector with a specific value set to all lanes +#define npyv_setall_u8(VAL) _mm_set1_epi8((char)(VAL)) +#define npyv_setall_s8(VAL) _mm_set1_epi8((char)(VAL)) +#define npyv_setall_u16(VAL) _mm_set1_epi16((short)(VAL)) +#define npyv_setall_s16(VAL) _mm_set1_epi16((short)(VAL)) +#define npyv_setall_u32(VAL) _mm_set1_epi32((int)(VAL)) +#define npyv_setall_s32(VAL) _mm_set1_epi32((int)(VAL)) +#define npyv_setall_f32 _mm_set1_ps +#define npyv_setall_f64 _mm_set1_pd + +NPY_FINLINE __m128i npyv__setr_epi64(npy_int64 i0, npy_int64 i1); + +NPY_FINLINE npyv_u64 npyv_setall_u64(npy_uint64 a) +{ +#if defined(_MSC_VER) && defined(_M_IX86) + return npyv__setr_epi64((npy_int64)a, (npy_int64)a); +#else + return _mm_set1_epi64x((npy_int64)a); +#endif +} +NPY_FINLINE npyv_s64 npyv_setall_s64(npy_int64 a) +{ +#if defined(_MSC_VER) && defined(_M_IX86) + return npyv__setr_epi64(a, a); +#else + return _mm_set1_epi64x((npy_int64)a); +#endif +} + +/** + * vector with specific values set to each lane and + * set a specific value to all remained lanes + * + * Args that generated by NPYV__SET_FILL_* not going to expand if + * _mm_setr_* are defined as macros. 
+ */ +NPY_FINLINE __m128i npyv__setr_epi8( + char i0, char i1, char i2, char i3, char i4, char i5, char i6, char i7, + char i8, char i9, char i10, char i11, char i12, char i13, char i14, char i15) +{ + return _mm_setr_epi8(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, i14, i15); +} +NPY_FINLINE __m128i npyv__setr_epi16(short i0, short i1, short i2, short i3, short i4, short i5, + short i6, short i7) +{ + return _mm_setr_epi16(i0, i1, i2, i3, i4, i5, i6, i7); +} +NPY_FINLINE __m128i npyv__setr_epi32(int i0, int i1, int i2, int i3) +{ + return _mm_setr_epi32(i0, i1, i2, i3); +} +NPY_FINLINE __m128i npyv__setr_epi64(npy_int64 i0, npy_int64 i1) +{ +#if defined(_MSC_VER) && defined(_M_IX86) + return _mm_setr_epi32((int)i0, (int)(i0 >> 32), (int)i1, (int)(i1 >> 32)); +#else + return _mm_set_epi64x(i1, i0); +#endif +} +NPY_FINLINE __m128 npyv__setr_ps(float i0, float i1, float i2, float i3) +{ + return _mm_setr_ps(i0, i1, i2, i3); +} +NPY_FINLINE __m128d npyv__setr_pd(double i0, double i1) +{ + return _mm_setr_pd(i0, i1); +} +#define npyv_setf_u8(FILL, ...) npyv__setr_epi8(NPYV__SET_FILL_16(char, FILL, __VA_ARGS__)) +#define npyv_setf_s8(FILL, ...) npyv__setr_epi8(NPYV__SET_FILL_16(char, FILL, __VA_ARGS__)) +#define npyv_setf_u16(FILL, ...) npyv__setr_epi16(NPYV__SET_FILL_8(short, FILL, __VA_ARGS__)) +#define npyv_setf_s16(FILL, ...) npyv__setr_epi16(NPYV__SET_FILL_8(short, FILL, __VA_ARGS__)) +#define npyv_setf_u32(FILL, ...) npyv__setr_epi32(NPYV__SET_FILL_4(int, FILL, __VA_ARGS__)) +#define npyv_setf_s32(FILL, ...) npyv__setr_epi32(NPYV__SET_FILL_4(int, FILL, __VA_ARGS__)) +#define npyv_setf_u64(FILL, ...) npyv__setr_epi64(NPYV__SET_FILL_2(npy_int64, FILL, __VA_ARGS__)) +#define npyv_setf_s64(FILL, ...) npyv__setr_epi64(NPYV__SET_FILL_2(npy_int64, FILL, __VA_ARGS__)) +#define npyv_setf_f32(FILL, ...) npyv__setr_ps(NPYV__SET_FILL_4(float, FILL, __VA_ARGS__)) +#define npyv_setf_f64(FILL, ...) npyv__setr_pd(NPYV__SET_FILL_2(double, FILL, __VA_ARGS__)) + +// vector with specific values set to each lane and +// set zero to all remained lanes +#define npyv_set_u8(...) npyv_setf_u8(0, __VA_ARGS__) +#define npyv_set_s8(...) npyv_setf_s8(0, __VA_ARGS__) +#define npyv_set_u16(...) npyv_setf_u16(0, __VA_ARGS__) +#define npyv_set_s16(...) npyv_setf_s16(0, __VA_ARGS__) +#define npyv_set_u32(...) npyv_setf_u32(0, __VA_ARGS__) +#define npyv_set_s32(...) npyv_setf_s32(0, __VA_ARGS__) +#define npyv_set_u64(...) npyv_setf_u64(0, __VA_ARGS__) +#define npyv_set_s64(...) npyv_setf_s64(0, __VA_ARGS__) +#define npyv_set_f32(...) npyv_setf_f32(0, __VA_ARGS__) +#define npyv_set_f64(...) 
npyv_setf_f64(0, __VA_ARGS__) + +// Per lane select +#ifdef NPY_HAVE_SSE41 + #define npyv_select_u8(MASK, A, B) _mm_blendv_epi8(B, A, MASK) + #define npyv_select_f32(MASK, A, B) _mm_blendv_ps(B, A, _mm_castsi128_ps(MASK)) + #define npyv_select_f64(MASK, A, B) _mm_blendv_pd(B, A, _mm_castsi128_pd(MASK)) +#else + NPY_FINLINE __m128i npyv_select_u8(__m128i mask, __m128i a, __m128i b) + { return _mm_xor_si128(b, _mm_and_si128(_mm_xor_si128(b, a), mask)); } + NPY_FINLINE __m128 npyv_select_f32(__m128i mask, __m128 a, __m128 b) + { return _mm_xor_ps(b, _mm_and_ps(_mm_xor_ps(b, a), _mm_castsi128_ps(mask))); } + NPY_FINLINE __m128d npyv_select_f64(__m128i mask, __m128d a, __m128d b) + { return _mm_xor_pd(b, _mm_and_pd(_mm_xor_pd(b, a), _mm_castsi128_pd(mask))); } +#endif +#define npyv_select_s8 npyv_select_u8 +#define npyv_select_u16 npyv_select_u8 +#define npyv_select_s16 npyv_select_u8 +#define npyv_select_u32 npyv_select_u8 +#define npyv_select_s32 npyv_select_u8 +#define npyv_select_u64 npyv_select_u8 +#define npyv_select_s64 npyv_select_u8 + +// extract the first vector's lane +#define npyv_extract0_u8(A) ((npy_uint8)_mm_cvtsi128_si32(A)) +#define npyv_extract0_s8(A) ((npy_int8)_mm_cvtsi128_si32(A)) +#define npyv_extract0_u16(A) ((npy_uint16)_mm_cvtsi128_si32(A)) +#define npyv_extract0_s16(A) ((npy_int16)_mm_cvtsi128_si32(A)) +#define npyv_extract0_u32(A) ((npy_uint32)_mm_cvtsi128_si32(A)) +#define npyv_extract0_s32(A) ((npy_int32)_mm_cvtsi128_si32(A)) +#define npyv_extract0_u64(A) ((npy_uint64)npyv128_cvtsi128_si64(A)) +#define npyv_extract0_s64(A) ((npy_int64)npyv128_cvtsi128_si64(A)) +#define npyv_extract0_f32 _mm_cvtss_f32 +#define npyv_extract0_f64 _mm_cvtsd_f64 + +// Reinterpret +#define npyv_reinterpret_u8_u8(X) X +#define npyv_reinterpret_u8_s8(X) X +#define npyv_reinterpret_u8_u16(X) X +#define npyv_reinterpret_u8_s16(X) X +#define npyv_reinterpret_u8_u32(X) X +#define npyv_reinterpret_u8_s32(X) X +#define npyv_reinterpret_u8_u64(X) X +#define npyv_reinterpret_u8_s64(X) X +#define npyv_reinterpret_u8_f32 _mm_castps_si128 +#define npyv_reinterpret_u8_f64 _mm_castpd_si128 + +#define npyv_reinterpret_s8_s8(X) X +#define npyv_reinterpret_s8_u8(X) X +#define npyv_reinterpret_s8_u16(X) X +#define npyv_reinterpret_s8_s16(X) X +#define npyv_reinterpret_s8_u32(X) X +#define npyv_reinterpret_s8_s32(X) X +#define npyv_reinterpret_s8_u64(X) X +#define npyv_reinterpret_s8_s64(X) X +#define npyv_reinterpret_s8_f32 _mm_castps_si128 +#define npyv_reinterpret_s8_f64 _mm_castpd_si128 + +#define npyv_reinterpret_u16_u16(X) X +#define npyv_reinterpret_u16_u8(X) X +#define npyv_reinterpret_u16_s8(X) X +#define npyv_reinterpret_u16_s16(X) X +#define npyv_reinterpret_u16_u32(X) X +#define npyv_reinterpret_u16_s32(X) X +#define npyv_reinterpret_u16_u64(X) X +#define npyv_reinterpret_u16_s64(X) X +#define npyv_reinterpret_u16_f32 _mm_castps_si128 +#define npyv_reinterpret_u16_f64 _mm_castpd_si128 + +#define npyv_reinterpret_s16_s16(X) X +#define npyv_reinterpret_s16_u8(X) X +#define npyv_reinterpret_s16_s8(X) X +#define npyv_reinterpret_s16_u16(X) X +#define npyv_reinterpret_s16_u32(X) X +#define npyv_reinterpret_s16_s32(X) X +#define npyv_reinterpret_s16_u64(X) X +#define npyv_reinterpret_s16_s64(X) X +#define npyv_reinterpret_s16_f32 _mm_castps_si128 +#define npyv_reinterpret_s16_f64 _mm_castpd_si128 + +#define npyv_reinterpret_u32_u32(X) X +#define npyv_reinterpret_u32_u8(X) X +#define npyv_reinterpret_u32_s8(X) X +#define npyv_reinterpret_u32_u16(X) X +#define npyv_reinterpret_u32_s16(X) X +#define 
npyv_reinterpret_u32_s32(X) X +#define npyv_reinterpret_u32_u64(X) X +#define npyv_reinterpret_u32_s64(X) X +#define npyv_reinterpret_u32_f32 _mm_castps_si128 +#define npyv_reinterpret_u32_f64 _mm_castpd_si128 + +#define npyv_reinterpret_s32_s32(X) X +#define npyv_reinterpret_s32_u8(X) X +#define npyv_reinterpret_s32_s8(X) X +#define npyv_reinterpret_s32_u16(X) X +#define npyv_reinterpret_s32_s16(X) X +#define npyv_reinterpret_s32_u32(X) X +#define npyv_reinterpret_s32_u64(X) X +#define npyv_reinterpret_s32_s64(X) X +#define npyv_reinterpret_s32_f32 _mm_castps_si128 +#define npyv_reinterpret_s32_f64 _mm_castpd_si128 + +#define npyv_reinterpret_u64_u64(X) X +#define npyv_reinterpret_u64_u8(X) X +#define npyv_reinterpret_u64_s8(X) X +#define npyv_reinterpret_u64_u16(X) X +#define npyv_reinterpret_u64_s16(X) X +#define npyv_reinterpret_u64_u32(X) X +#define npyv_reinterpret_u64_s32(X) X +#define npyv_reinterpret_u64_s64(X) X +#define npyv_reinterpret_u64_f32 _mm_castps_si128 +#define npyv_reinterpret_u64_f64 _mm_castpd_si128 + +#define npyv_reinterpret_s64_s64(X) X +#define npyv_reinterpret_s64_u8(X) X +#define npyv_reinterpret_s64_s8(X) X +#define npyv_reinterpret_s64_u16(X) X +#define npyv_reinterpret_s64_s16(X) X +#define npyv_reinterpret_s64_u32(X) X +#define npyv_reinterpret_s64_s32(X) X +#define npyv_reinterpret_s64_u64(X) X +#define npyv_reinterpret_s64_f32 _mm_castps_si128 +#define npyv_reinterpret_s64_f64 _mm_castpd_si128 + +#define npyv_reinterpret_f32_f32(X) X +#define npyv_reinterpret_f32_u8 _mm_castsi128_ps +#define npyv_reinterpret_f32_s8 _mm_castsi128_ps +#define npyv_reinterpret_f32_u16 _mm_castsi128_ps +#define npyv_reinterpret_f32_s16 _mm_castsi128_ps +#define npyv_reinterpret_f32_u32 _mm_castsi128_ps +#define npyv_reinterpret_f32_s32 _mm_castsi128_ps +#define npyv_reinterpret_f32_u64 _mm_castsi128_ps +#define npyv_reinterpret_f32_s64 _mm_castsi128_ps +#define npyv_reinterpret_f32_f64 _mm_castpd_ps + +#define npyv_reinterpret_f64_f64(X) X +#define npyv_reinterpret_f64_u8 _mm_castsi128_pd +#define npyv_reinterpret_f64_s8 _mm_castsi128_pd +#define npyv_reinterpret_f64_u16 _mm_castsi128_pd +#define npyv_reinterpret_f64_s16 _mm_castsi128_pd +#define npyv_reinterpret_f64_u32 _mm_castsi128_pd +#define npyv_reinterpret_f64_s32 _mm_castsi128_pd +#define npyv_reinterpret_f64_u64 _mm_castsi128_pd +#define npyv_reinterpret_f64_s64 _mm_castsi128_pd +#define npyv_reinterpret_f64_f32 _mm_castps_pd + +// Only required by AVX2/AVX512 +#define npyv_cleanup() ((void)0) + +#endif // _NPY_SIMD_SSE_MISC_H diff --git a/mkl_umath/src/npyv/sse/operators.h b/mkl_umath/src/npyv/sse/operators.h new file mode 100644 index 00000000..59182679 --- /dev/null +++ b/mkl_umath/src/npyv/sse/operators.h @@ -0,0 +1,342 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_SSE_OPERATORS_H +#define _NPY_SIMD_SSE_OPERATORS_H + +/*************************** + * Shifting + ***************************/ + +// left +#define npyv_shl_u16(A, C) _mm_sll_epi16(A, _mm_cvtsi32_si128(C)) +#define npyv_shl_s16(A, C) _mm_sll_epi16(A, _mm_cvtsi32_si128(C)) +#define npyv_shl_u32(A, C) _mm_sll_epi32(A, _mm_cvtsi32_si128(C)) +#define npyv_shl_s32(A, C) _mm_sll_epi32(A, _mm_cvtsi32_si128(C)) +#define npyv_shl_u64(A, C) _mm_sll_epi64(A, _mm_cvtsi32_si128(C)) +#define npyv_shl_s64(A, C) _mm_sll_epi64(A, _mm_cvtsi32_si128(C)) + +// left by an immediate constant +#define npyv_shli_u16 _mm_slli_epi16 +#define npyv_shli_s16 _mm_slli_epi16 +#define npyv_shli_u32 _mm_slli_epi32 +#define npyv_shli_s32 
_mm_slli_epi32 +#define npyv_shli_u64 _mm_slli_epi64 +#define npyv_shli_s64 _mm_slli_epi64 + +// right +#define npyv_shr_u16(A, C) _mm_srl_epi16(A, _mm_cvtsi32_si128(C)) +#define npyv_shr_s16(A, C) _mm_sra_epi16(A, _mm_cvtsi32_si128(C)) +#define npyv_shr_u32(A, C) _mm_srl_epi32(A, _mm_cvtsi32_si128(C)) +#define npyv_shr_s32(A, C) _mm_sra_epi32(A, _mm_cvtsi32_si128(C)) +#define npyv_shr_u64(A, C) _mm_srl_epi64(A, _mm_cvtsi32_si128(C)) +NPY_FINLINE __m128i npyv_shr_s64(__m128i a, int c) +{ + const __m128i sbit = npyv_setall_s64(0x8000000000000000); + const __m128i cv = _mm_cvtsi32_si128(c); + __m128i r = _mm_srl_epi64(_mm_add_epi64(a, sbit), cv); + return _mm_sub_epi64(r, _mm_srl_epi64(sbit, cv)); +} + +// Right by an immediate constant +#define npyv_shri_u16 _mm_srli_epi16 +#define npyv_shri_s16 _mm_srai_epi16 +#define npyv_shri_u32 _mm_srli_epi32 +#define npyv_shri_s32 _mm_srai_epi32 +#define npyv_shri_u64 _mm_srli_epi64 +#define npyv_shri_s64 npyv_shr_s64 + +/*************************** + * Logical + ***************************/ + +// AND +#define npyv_and_u8 _mm_and_si128 +#define npyv_and_s8 _mm_and_si128 +#define npyv_and_u16 _mm_and_si128 +#define npyv_and_s16 _mm_and_si128 +#define npyv_and_u32 _mm_and_si128 +#define npyv_and_s32 _mm_and_si128 +#define npyv_and_u64 _mm_and_si128 +#define npyv_and_s64 _mm_and_si128 +#define npyv_and_f32 _mm_and_ps +#define npyv_and_f64 _mm_and_pd +#define npyv_and_b8 _mm_and_si128 +#define npyv_and_b16 _mm_and_si128 +#define npyv_and_b32 _mm_and_si128 +#define npyv_and_b64 _mm_and_si128 + +// OR +#define npyv_or_u8 _mm_or_si128 +#define npyv_or_s8 _mm_or_si128 +#define npyv_or_u16 _mm_or_si128 +#define npyv_or_s16 _mm_or_si128 +#define npyv_or_u32 _mm_or_si128 +#define npyv_or_s32 _mm_or_si128 +#define npyv_or_u64 _mm_or_si128 +#define npyv_or_s64 _mm_or_si128 +#define npyv_or_f32 _mm_or_ps +#define npyv_or_f64 _mm_or_pd +#define npyv_or_b8 _mm_or_si128 +#define npyv_or_b16 _mm_or_si128 +#define npyv_or_b32 _mm_or_si128 +#define npyv_or_b64 _mm_or_si128 + +// XOR +#define npyv_xor_u8 _mm_xor_si128 +#define npyv_xor_s8 _mm_xor_si128 +#define npyv_xor_u16 _mm_xor_si128 +#define npyv_xor_s16 _mm_xor_si128 +#define npyv_xor_u32 _mm_xor_si128 +#define npyv_xor_s32 _mm_xor_si128 +#define npyv_xor_u64 _mm_xor_si128 +#define npyv_xor_s64 _mm_xor_si128 +#define npyv_xor_f32 _mm_xor_ps +#define npyv_xor_f64 _mm_xor_pd +#define npyv_xor_b8 _mm_xor_si128 +#define npyv_xor_b16 _mm_xor_si128 +#define npyv_xor_b32 _mm_xor_si128 +#define npyv_xor_b64 _mm_xor_si128 + +// NOT +#define npyv_not_u8(A) _mm_xor_si128(A, _mm_set1_epi32(-1)) +#define npyv_not_s8 npyv_not_u8 +#define npyv_not_u16 npyv_not_u8 +#define npyv_not_s16 npyv_not_u8 +#define npyv_not_u32 npyv_not_u8 +#define npyv_not_s32 npyv_not_u8 +#define npyv_not_u64 npyv_not_u8 +#define npyv_not_s64 npyv_not_u8 +#define npyv_not_f32(A) _mm_xor_ps(A, _mm_castsi128_ps(_mm_set1_epi32(-1))) +#define npyv_not_f64(A) _mm_xor_pd(A, _mm_castsi128_pd(_mm_set1_epi32(-1))) +#define npyv_not_b8 npyv_not_u8 +#define npyv_not_b16 npyv_not_u8 +#define npyv_not_b32 npyv_not_u8 +#define npyv_not_b64 npyv_not_u8 + +// ANDC, ORC and XNOR +#define npyv_andc_u8(A, B) _mm_andnot_si128(B, A) +#define npyv_andc_b8(A, B) _mm_andnot_si128(B, A) +#define npyv_orc_b8(A, B) npyv_or_b8(npyv_not_b8(B), A) +#define npyv_xnor_b8 _mm_cmpeq_epi8 + +/*************************** + * Comparison + ***************************/ + +// Int Equal +#define npyv_cmpeq_u8 _mm_cmpeq_epi8 +#define npyv_cmpeq_s8 _mm_cmpeq_epi8 +#define npyv_cmpeq_u16 
_mm_cmpeq_epi16 +#define npyv_cmpeq_s16 _mm_cmpeq_epi16 +#define npyv_cmpeq_u32 _mm_cmpeq_epi32 +#define npyv_cmpeq_s32 _mm_cmpeq_epi32 +#define npyv_cmpeq_s64 npyv_cmpeq_u64 + +#ifdef NPY_HAVE_SSE41 + #define npyv_cmpeq_u64 _mm_cmpeq_epi64 +#else + NPY_FINLINE __m128i npyv_cmpeq_u64(__m128i a, __m128i b) + { + __m128i cmpeq = _mm_cmpeq_epi32(a, b); + __m128i cmpeq_h = _mm_srli_epi64(cmpeq, 32); + __m128i test = _mm_and_si128(cmpeq, cmpeq_h); + return _mm_shuffle_epi32(test, _MM_SHUFFLE(2, 2, 0, 0)); + } +#endif + +// Int Not Equal +#ifdef NPY_HAVE_XOP + #define npyv_cmpneq_u8 _mm_comneq_epi8 + #define npyv_cmpneq_u16 _mm_comneq_epi16 + #define npyv_cmpneq_u32 _mm_comneq_epi32 + #define npyv_cmpneq_u64 _mm_comneq_epi64 +#else + #define npyv_cmpneq_u8(A, B) npyv_not_u8(npyv_cmpeq_u8(A, B)) + #define npyv_cmpneq_u16(A, B) npyv_not_u16(npyv_cmpeq_u16(A, B)) + #define npyv_cmpneq_u32(A, B) npyv_not_u32(npyv_cmpeq_u32(A, B)) + #define npyv_cmpneq_u64(A, B) npyv_not_u64(npyv_cmpeq_u64(A, B)) +#endif +#define npyv_cmpneq_s8 npyv_cmpneq_u8 +#define npyv_cmpneq_s16 npyv_cmpneq_u16 +#define npyv_cmpneq_s32 npyv_cmpneq_u32 +#define npyv_cmpneq_s64 npyv_cmpneq_u64 + +// signed greater than +#define npyv_cmpgt_s8 _mm_cmpgt_epi8 +#define npyv_cmpgt_s16 _mm_cmpgt_epi16 +#define npyv_cmpgt_s32 _mm_cmpgt_epi32 + +#ifdef NPY_HAVE_SSE42 + #define npyv_cmpgt_s64 _mm_cmpgt_epi64 +#else + NPY_FINLINE __m128i npyv_cmpgt_s64(__m128i a, __m128i b) + { + __m128i sub = _mm_sub_epi64(b, a); + __m128i nsame_sbit = _mm_xor_si128(a, b); + // nsame_sbit ? b : sub + __m128i test = _mm_xor_si128(sub, _mm_and_si128(_mm_xor_si128(sub, b), nsame_sbit)); + __m128i extend_sbit = _mm_shuffle_epi32(_mm_srai_epi32(test, 31), _MM_SHUFFLE(3, 3, 1, 1)); + return extend_sbit; + } +#endif + +// signed greater than or equal +#ifdef NPY_HAVE_XOP + #define npyv_cmpge_s8 _mm_comge_epi8 + #define npyv_cmpge_s16 _mm_comge_epi16 + #define npyv_cmpge_s32 _mm_comge_epi32 + #define npyv_cmpge_s64 _mm_comge_epi64 +#else + #define npyv_cmpge_s8(A, B) npyv_not_s8(_mm_cmpgt_epi8(B, A)) + #define npyv_cmpge_s16(A, B) npyv_not_s16(_mm_cmpgt_epi16(B, A)) + #define npyv_cmpge_s32(A, B) npyv_not_s32(_mm_cmpgt_epi32(B, A)) + #define npyv_cmpge_s64(A, B) npyv_not_s64(npyv_cmpgt_s64(B, A)) +#endif + +// unsigned greater than +#ifdef NPY_HAVE_XOP + #define npyv_cmpgt_u8 _mm_comgt_epu8 + #define npyv_cmpgt_u16 _mm_comgt_epu16 + #define npyv_cmpgt_u32 _mm_comgt_epu32 + #define npyv_cmpgt_u64 _mm_comgt_epu64 +#else + #define NPYV_IMPL_SSE_UNSIGNED_GT(LEN, SIGN) \ + NPY_FINLINE __m128i npyv_cmpgt_u##LEN(__m128i a, __m128i b) \ + { \ + const __m128i sbit = _mm_set1_epi32(SIGN); \ + return _mm_cmpgt_epi##LEN( \ + _mm_xor_si128(a, sbit), _mm_xor_si128(b, sbit) \ + ); \ + } + + NPYV_IMPL_SSE_UNSIGNED_GT(8, 0x80808080) + NPYV_IMPL_SSE_UNSIGNED_GT(16, 0x80008000) + NPYV_IMPL_SSE_UNSIGNED_GT(32, 0x80000000) + + NPY_FINLINE __m128i npyv_cmpgt_u64(__m128i a, __m128i b) + { + const __m128i sbit = npyv_setall_s64(0x8000000000000000); + return npyv_cmpgt_s64(_mm_xor_si128(a, sbit), _mm_xor_si128(b, sbit)); + } +#endif + +// unsigned greater than or equal +#ifdef NPY_HAVE_XOP + #define npyv_cmpge_u8 _mm_comge_epu8 + #define npyv_cmpge_u16 _mm_comge_epu16 + #define npyv_cmpge_u32 _mm_comge_epu32 + #define npyv_cmpge_u64 _mm_comge_epu64 +#else + NPY_FINLINE __m128i npyv_cmpge_u8(__m128i a, __m128i b) + { return _mm_cmpeq_epi8(a, _mm_max_epu8(a, b)); } + #ifdef NPY_HAVE_SSE41 + NPY_FINLINE __m128i npyv_cmpge_u16(__m128i a, __m128i b) + { return _mm_cmpeq_epi16(a, 
_mm_max_epu16(a, b)); } + NPY_FINLINE __m128i npyv_cmpge_u32(__m128i a, __m128i b) + { return _mm_cmpeq_epi32(a, _mm_max_epu32(a, b)); } + #else + #define npyv_cmpge_u16(A, B) _mm_cmpeq_epi16(_mm_subs_epu16(B, A), _mm_setzero_si128()) + #define npyv_cmpge_u32(A, B) npyv_not_u32(npyv_cmpgt_u32(B, A)) + #endif + #define npyv_cmpge_u64(A, B) npyv_not_u64(npyv_cmpgt_u64(B, A)) +#endif + +// less than +#define npyv_cmplt_u8(A, B) npyv_cmpgt_u8(B, A) +#define npyv_cmplt_s8(A, B) npyv_cmpgt_s8(B, A) +#define npyv_cmplt_u16(A, B) npyv_cmpgt_u16(B, A) +#define npyv_cmplt_s16(A, B) npyv_cmpgt_s16(B, A) +#define npyv_cmplt_u32(A, B) npyv_cmpgt_u32(B, A) +#define npyv_cmplt_s32(A, B) npyv_cmpgt_s32(B, A) +#define npyv_cmplt_u64(A, B) npyv_cmpgt_u64(B, A) +#define npyv_cmplt_s64(A, B) npyv_cmpgt_s64(B, A) + +// less than or equal +#define npyv_cmple_u8(A, B) npyv_cmpge_u8(B, A) +#define npyv_cmple_s8(A, B) npyv_cmpge_s8(B, A) +#define npyv_cmple_u16(A, B) npyv_cmpge_u16(B, A) +#define npyv_cmple_s16(A, B) npyv_cmpge_s16(B, A) +#define npyv_cmple_u32(A, B) npyv_cmpge_u32(B, A) +#define npyv_cmple_s32(A, B) npyv_cmpge_s32(B, A) +#define npyv_cmple_u64(A, B) npyv_cmpge_u64(B, A) +#define npyv_cmple_s64(A, B) npyv_cmpge_s64(B, A) + +// precision comparison +#define npyv_cmpeq_f32(a, b) _mm_castps_si128(_mm_cmpeq_ps(a, b)) +#define npyv_cmpeq_f64(a, b) _mm_castpd_si128(_mm_cmpeq_pd(a, b)) +#define npyv_cmpneq_f32(a, b) _mm_castps_si128(_mm_cmpneq_ps(a, b)) +#define npyv_cmpneq_f64(a, b) _mm_castpd_si128(_mm_cmpneq_pd(a, b)) +#define npyv_cmplt_f32(a, b) _mm_castps_si128(_mm_cmplt_ps(a, b)) +#define npyv_cmplt_f64(a, b) _mm_castpd_si128(_mm_cmplt_pd(a, b)) +#define npyv_cmple_f32(a, b) _mm_castps_si128(_mm_cmple_ps(a, b)) +#define npyv_cmple_f64(a, b) _mm_castpd_si128(_mm_cmple_pd(a, b)) +#define npyv_cmpgt_f32(a, b) _mm_castps_si128(_mm_cmpgt_ps(a, b)) +#define npyv_cmpgt_f64(a, b) _mm_castpd_si128(_mm_cmpgt_pd(a, b)) +#define npyv_cmpge_f32(a, b) _mm_castps_si128(_mm_cmpge_ps(a, b)) +#define npyv_cmpge_f64(a, b) _mm_castpd_si128(_mm_cmpge_pd(a, b)) + +// check special cases +NPY_FINLINE npyv_b32 npyv_notnan_f32(npyv_f32 a) +{ return _mm_castps_si128(_mm_cmpord_ps(a, a)); } +NPY_FINLINE npyv_b64 npyv_notnan_f64(npyv_f64 a) +{ return _mm_castpd_si128(_mm_cmpord_pd(a, a)); } + +// Test cross all vector lanes +// any: returns true if any of the elements is not equal to zero +// all: returns true if all elements are not equal to zero +#define NPYV_IMPL_SSE_ANYALL(SFX) \ + NPY_FINLINE bool npyv_any_##SFX(npyv_##SFX a) \ + { return _mm_movemask_epi8(a) != 0; } \ + NPY_FINLINE bool npyv_all_##SFX(npyv_##SFX a) \ + { return _mm_movemask_epi8(a) == 0xffff; } +NPYV_IMPL_SSE_ANYALL(b8) +NPYV_IMPL_SSE_ANYALL(b16) +NPYV_IMPL_SSE_ANYALL(b32) +NPYV_IMPL_SSE_ANYALL(b64) +#undef NPYV_IMPL_SSE_ANYALL + +#define NPYV_IMPL_SSE_ANYALL(SFX, MSFX, TSFX, MASK) \ + NPY_FINLINE bool npyv_any_##SFX(npyv_##SFX a) \ + { \ + return _mm_movemask_##MSFX( \ + _mm_cmpeq_##TSFX(a, npyv_zero_##SFX()) \ + ) != MASK; \ + } \ + NPY_FINLINE bool npyv_all_##SFX(npyv_##SFX a) \ + { \ + return _mm_movemask_##MSFX( \ + _mm_cmpeq_##TSFX(a, npyv_zero_##SFX()) \ + ) == 0; \ + } +NPYV_IMPL_SSE_ANYALL(u8, epi8, epi8, 0xffff) +NPYV_IMPL_SSE_ANYALL(s8, epi8, epi8, 0xffff) +NPYV_IMPL_SSE_ANYALL(u16, epi8, epi16, 0xffff) +NPYV_IMPL_SSE_ANYALL(s16, epi8, epi16, 0xffff) +NPYV_IMPL_SSE_ANYALL(u32, epi8, epi32, 0xffff) +NPYV_IMPL_SSE_ANYALL(s32, epi8, epi32, 0xffff) +#ifdef NPY_HAVE_SSE41 + NPYV_IMPL_SSE_ANYALL(u64, epi8, epi64, 0xffff) + 
NPYV_IMPL_SSE_ANYALL(s64, epi8, epi64, 0xffff) +#else + NPY_FINLINE bool npyv_any_u64(npyv_u64 a) + { + return _mm_movemask_epi8( + _mm_cmpeq_epi32(a, npyv_zero_u64()) + ) != 0xffff; + } + NPY_FINLINE bool npyv_all_u64(npyv_u64 a) + { + a = _mm_cmpeq_epi32(a, npyv_zero_u64()); + a = _mm_and_si128(a, _mm_shuffle_epi32(a, _MM_SHUFFLE(2, 3, 0, 1))); + return _mm_movemask_epi8(a) == 0; + } + #define npyv_any_s64 npyv_any_u64 + #define npyv_all_s64 npyv_all_u64 +#endif +NPYV_IMPL_SSE_ANYALL(f32, ps, ps, 0xf) +NPYV_IMPL_SSE_ANYALL(f64, pd, pd, 0x3) +#undef NPYV_IMPL_SSE_ANYALL + +#endif // _NPY_SIMD_SSE_OPERATORS_H diff --git a/mkl_umath/src/npyv/sse/reorder.h b/mkl_umath/src/npyv/sse/reorder.h new file mode 100644 index 00000000..9a57f648 --- /dev/null +++ b/mkl_umath/src/npyv/sse/reorder.h @@ -0,0 +1,212 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_SSE_REORDER_H +#define _NPY_SIMD_SSE_REORDER_H + +// combine lower part of two vectors +#define npyv_combinel_u8 _mm_unpacklo_epi64 +#define npyv_combinel_s8 _mm_unpacklo_epi64 +#define npyv_combinel_u16 _mm_unpacklo_epi64 +#define npyv_combinel_s16 _mm_unpacklo_epi64 +#define npyv_combinel_u32 _mm_unpacklo_epi64 +#define npyv_combinel_s32 _mm_unpacklo_epi64 +#define npyv_combinel_u64 _mm_unpacklo_epi64 +#define npyv_combinel_s64 _mm_unpacklo_epi64 +#define npyv_combinel_f32(A, B) _mm_castsi128_ps(_mm_unpacklo_epi64(_mm_castps_si128(A), _mm_castps_si128(B))) +#define npyv_combinel_f64 _mm_unpacklo_pd + +// combine higher part of two vectors +#define npyv_combineh_u8 _mm_unpackhi_epi64 +#define npyv_combineh_s8 _mm_unpackhi_epi64 +#define npyv_combineh_u16 _mm_unpackhi_epi64 +#define npyv_combineh_s16 _mm_unpackhi_epi64 +#define npyv_combineh_u32 _mm_unpackhi_epi64 +#define npyv_combineh_s32 _mm_unpackhi_epi64 +#define npyv_combineh_u64 _mm_unpackhi_epi64 +#define npyv_combineh_s64 _mm_unpackhi_epi64 +#define npyv_combineh_f32(A, B) _mm_castsi128_ps(_mm_unpackhi_epi64(_mm_castps_si128(A), _mm_castps_si128(B))) +#define npyv_combineh_f64 _mm_unpackhi_pd + +// combine two vectors from lower and higher parts of two other vectors +NPY_FINLINE npyv_m128ix2 npyv__combine(__m128i a, __m128i b) +{ + npyv_m128ix2 r; + r.val[0] = npyv_combinel_u8(a, b); + r.val[1] = npyv_combineh_u8(a, b); + return r; +} +NPY_FINLINE npyv_f32x2 npyv_combine_f32(__m128 a, __m128 b) +{ + npyv_f32x2 r; + r.val[0] = npyv_combinel_f32(a, b); + r.val[1] = npyv_combineh_f32(a, b); + return r; +} +NPY_FINLINE npyv_f64x2 npyv_combine_f64(__m128d a, __m128d b) +{ + npyv_f64x2 r; + r.val[0] = npyv_combinel_f64(a, b); + r.val[1] = npyv_combineh_f64(a, b); + return r; +} +#define npyv_combine_u8 npyv__combine +#define npyv_combine_s8 npyv__combine +#define npyv_combine_u16 npyv__combine +#define npyv_combine_s16 npyv__combine +#define npyv_combine_u32 npyv__combine +#define npyv_combine_s32 npyv__combine +#define npyv_combine_u64 npyv__combine +#define npyv_combine_s64 npyv__combine + +// interleave two vectors +#define NPYV_IMPL_SSE_ZIP(T_VEC, SFX, INTR_SFX) \ + NPY_FINLINE T_VEC##x2 npyv_zip_##SFX(T_VEC a, T_VEC b) \ + { \ + T_VEC##x2 r; \ + r.val[0] = _mm_unpacklo_##INTR_SFX(a, b); \ + r.val[1] = _mm_unpackhi_##INTR_SFX(a, b); \ + return r; \ + } + +NPYV_IMPL_SSE_ZIP(npyv_u8, u8, epi8) +NPYV_IMPL_SSE_ZIP(npyv_s8, s8, epi8) +NPYV_IMPL_SSE_ZIP(npyv_u16, u16, epi16) +NPYV_IMPL_SSE_ZIP(npyv_s16, s16, epi16) +NPYV_IMPL_SSE_ZIP(npyv_u32, u32, epi32) +NPYV_IMPL_SSE_ZIP(npyv_s32, s32, epi32) +NPYV_IMPL_SSE_ZIP(npyv_u64, u64, epi64) 
+NPYV_IMPL_SSE_ZIP(npyv_s64, s64, epi64) +NPYV_IMPL_SSE_ZIP(npyv_f32, f32, ps) +NPYV_IMPL_SSE_ZIP(npyv_f64, f64, pd) + +// deinterleave two vectors +NPY_FINLINE npyv_u8x2 npyv_unzip_u8(npyv_u8 ab0, npyv_u8 ab1) +{ +#ifdef NPY_HAVE_SSSE3 + const __m128i idx = _mm_setr_epi8( + 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 + ); + __m128i abl = _mm_shuffle_epi8(ab0, idx); + __m128i abh = _mm_shuffle_epi8(ab1, idx); + return npyv_combine_u8(abl, abh); +#else + __m128i ab_083b = _mm_unpacklo_epi8(ab0, ab1); + __m128i ab_4c6e = _mm_unpackhi_epi8(ab0, ab1); + __m128i ab_048c = _mm_unpacklo_epi8(ab_083b, ab_4c6e); + __m128i ab_36be = _mm_unpackhi_epi8(ab_083b, ab_4c6e); + __m128i ab_0346 = _mm_unpacklo_epi8(ab_048c, ab_36be); + __m128i ab_8bc8 = _mm_unpackhi_epi8(ab_048c, ab_36be); + npyv_u8x2 r; + r.val[0] = _mm_unpacklo_epi8(ab_0346, ab_8bc8); + r.val[1] = _mm_unpackhi_epi8(ab_0346, ab_8bc8); + return r; +#endif +} +#define npyv_unzip_s8 npyv_unzip_u8 + +NPY_FINLINE npyv_u16x2 npyv_unzip_u16(npyv_u16 ab0, npyv_u16 ab1) +{ +#ifdef NPY_HAVE_SSSE3 + const __m128i idx = _mm_setr_epi8( + 0,1, 4,5, 8,9, 12,13, 2,3, 6,7, 10,11, 14,15 + ); + __m128i abl = _mm_shuffle_epi8(ab0, idx); + __m128i abh = _mm_shuffle_epi8(ab1, idx); + return npyv_combine_u16(abl, abh); +#else + __m128i ab_0415 = _mm_unpacklo_epi16(ab0, ab1); + __m128i ab_263f = _mm_unpackhi_epi16(ab0, ab1); + __m128i ab_0246 = _mm_unpacklo_epi16(ab_0415, ab_263f); + __m128i ab_135f = _mm_unpackhi_epi16(ab_0415, ab_263f); + npyv_u16x2 r; + r.val[0] = _mm_unpacklo_epi16(ab_0246, ab_135f); + r.val[1] = _mm_unpackhi_epi16(ab_0246, ab_135f); + return r; +#endif +} +#define npyv_unzip_s16 npyv_unzip_u16 + +NPY_FINLINE npyv_u32x2 npyv_unzip_u32(npyv_u32 ab0, npyv_u32 ab1) +{ + __m128i abl = _mm_shuffle_epi32(ab0, _MM_SHUFFLE(3, 1, 2, 0)); + __m128i abh = _mm_shuffle_epi32(ab1, _MM_SHUFFLE(3, 1, 2, 0)); + return npyv_combine_u32(abl, abh); +} +#define npyv_unzip_s32 npyv_unzip_u32 + +NPY_FINLINE npyv_u64x2 npyv_unzip_u64(npyv_u64 ab0, npyv_u64 ab1) +{ return npyv_combine_u64(ab0, ab1); } +#define npyv_unzip_s64 npyv_unzip_u64 + +NPY_FINLINE npyv_f32x2 npyv_unzip_f32(npyv_f32 ab0, npyv_f32 ab1) +{ + npyv_f32x2 r; + r.val[0] = _mm_shuffle_ps(ab0, ab1, _MM_SHUFFLE(2, 0, 2, 0)); + r.val[1] = _mm_shuffle_ps(ab0, ab1, _MM_SHUFFLE(3, 1, 3, 1)); + return r; +} +NPY_FINLINE npyv_f64x2 npyv_unzip_f64(npyv_f64 ab0, npyv_f64 ab1) +{ return npyv_combine_f64(ab0, ab1); } + +// Reverse elements of each 64-bit lane +NPY_FINLINE npyv_u16 npyv_rev64_u16(npyv_u16 a) +{ +#ifdef NPY_HAVE_SSSE3 + const __m128i idx = _mm_setr_epi8( + 6, 7, 4, 5, 2, 3, 0, 1,/*64*/14, 15, 12, 13, 10, 11, 8, 9 + ); + return _mm_shuffle_epi8(a, idx); +#else + __m128i lo = _mm_shufflelo_epi16(a, _MM_SHUFFLE(0, 1, 2, 3)); + return _mm_shufflehi_epi16(lo, _MM_SHUFFLE(0, 1, 2, 3)); +#endif +} +#define npyv_rev64_s16 npyv_rev64_u16 + +NPY_FINLINE npyv_u8 npyv_rev64_u8(npyv_u8 a) +{ +#ifdef NPY_HAVE_SSSE3 + const __m128i idx = _mm_setr_epi8( + 7, 6, 5, 4, 3, 2, 1, 0,/*64*/15, 14, 13, 12, 11, 10, 9, 8 + ); + return _mm_shuffle_epi8(a, idx); +#else + __m128i rev16 = npyv_rev64_u16(a); + // swap 8bit pairs + return _mm_or_si128(_mm_slli_epi16(rev16, 8), _mm_srli_epi16(rev16, 8)); +#endif +} +#define npyv_rev64_s8 npyv_rev64_u8 + +NPY_FINLINE npyv_u32 npyv_rev64_u32(npyv_u32 a) +{ + return _mm_shuffle_epi32(a, _MM_SHUFFLE(2, 3, 0, 1)); +} +#define npyv_rev64_s32 npyv_rev64_u32 + +NPY_FINLINE npyv_f32 npyv_rev64_f32(npyv_f32 a) +{ + return _mm_shuffle_ps(a, a, _MM_SHUFFLE(2, 3, 0, 1)); +} + +// 
Permuting the elements of each 128-bit lane by immediate index for +// each element. +#define npyv_permi128_u32(A, E0, E1, E2, E3) \ + _mm_shuffle_epi32(A, _MM_SHUFFLE(E3, E2, E1, E0)) + +#define npyv_permi128_s32 npyv_permi128_u32 + +#define npyv_permi128_u64(A, E0, E1) \ + _mm_shuffle_epi32(A, _MM_SHUFFLE(((E1)<<1)+1, ((E1)<<1), ((E0)<<1)+1, ((E0)<<1))) + +#define npyv_permi128_s64 npyv_permi128_u64 + +#define npyv_permi128_f32(A, E0, E1, E2, E3) \ + _mm_shuffle_ps(A, A, _MM_SHUFFLE(E3, E2, E1, E0)) + +#define npyv_permi128_f64(A, E0, E1) \ + _mm_shuffle_pd(A, A, _MM_SHUFFLE2(E1, E0)) + +#endif // _NPY_SIMD_SSE_REORDER_H diff --git a/mkl_umath/src/npyv/sse/sse.h b/mkl_umath/src/npyv/sse/sse.h new file mode 100644 index 00000000..0c6b8cdb --- /dev/null +++ b/mkl_umath/src/npyv/sse/sse.h @@ -0,0 +1,76 @@ +#ifndef _NPY_SIMD_H_ + #error "Not a standalone header" +#endif + +#define NPY_SIMD 128 +#define NPY_SIMD_WIDTH 16 +#define NPY_SIMD_F32 1 +#define NPY_SIMD_F64 1 +#if defined(NPY_HAVE_FMA3) || defined(NPY_HAVE_FMA4) + #define NPY_SIMD_FMA3 1 // native support +#else + #define NPY_SIMD_FMA3 0 // fast emulated +#endif +#define NPY_SIMD_BIGENDIAN 0 +#define NPY_SIMD_CMPSIGNAL 1 + +typedef __m128i npyv_u8; +typedef __m128i npyv_s8; +typedef __m128i npyv_u16; +typedef __m128i npyv_s16; +typedef __m128i npyv_u32; +typedef __m128i npyv_s32; +typedef __m128i npyv_u64; +typedef __m128i npyv_s64; +typedef __m128 npyv_f32; +typedef __m128d npyv_f64; + +typedef __m128i npyv_b8; +typedef __m128i npyv_b16; +typedef __m128i npyv_b32; +typedef __m128i npyv_b64; + +typedef struct { __m128i val[2]; } npyv_m128ix2; +typedef npyv_m128ix2 npyv_u8x2; +typedef npyv_m128ix2 npyv_s8x2; +typedef npyv_m128ix2 npyv_u16x2; +typedef npyv_m128ix2 npyv_s16x2; +typedef npyv_m128ix2 npyv_u32x2; +typedef npyv_m128ix2 npyv_s32x2; +typedef npyv_m128ix2 npyv_u64x2; +typedef npyv_m128ix2 npyv_s64x2; + +typedef struct { __m128i val[3]; } npyv_m128ix3; +typedef npyv_m128ix3 npyv_u8x3; +typedef npyv_m128ix3 npyv_s8x3; +typedef npyv_m128ix3 npyv_u16x3; +typedef npyv_m128ix3 npyv_s16x3; +typedef npyv_m128ix3 npyv_u32x3; +typedef npyv_m128ix3 npyv_s32x3; +typedef npyv_m128ix3 npyv_u64x3; +typedef npyv_m128ix3 npyv_s64x3; + +typedef struct { __m128 val[2]; } npyv_f32x2; +typedef struct { __m128d val[2]; } npyv_f64x2; +typedef struct { __m128 val[3]; } npyv_f32x3; +typedef struct { __m128d val[3]; } npyv_f64x3; + +#define npyv_nlanes_u8 16 +#define npyv_nlanes_s8 16 +#define npyv_nlanes_u16 8 +#define npyv_nlanes_s16 8 +#define npyv_nlanes_u32 4 +#define npyv_nlanes_s32 4 +#define npyv_nlanes_u64 2 +#define npyv_nlanes_s64 2 +#define npyv_nlanes_f32 4 +#define npyv_nlanes_f64 2 + +#include "utils.h" +#include "memory.h" +#include "misc.h" +#include "reorder.h" +#include "operators.h" +#include "conversion.h" +#include "arithmetic.h" +#include "math.h" diff --git a/mkl_umath/src/npyv/sse/utils.h b/mkl_umath/src/npyv/sse/utils.h new file mode 100644 index 00000000..c23def11 --- /dev/null +++ b/mkl_umath/src/npyv/sse/utils.h @@ -0,0 +1,19 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_SSE_UTILS_H +#define _NPY_SIMD_SSE_UTILS_H + +#if !defined(__x86_64__) && !defined(_M_X64) +NPY_FINLINE npy_int64 npyv128_cvtsi128_si64(__m128i a) +{ + npy_int64 NPY_DECL_ALIGNED(16) idx[2]; + _mm_store_si128((__m128i *)idx, a); + return idx[0]; +} +#else + #define npyv128_cvtsi128_si64 _mm_cvtsi128_si64 +#endif + +#endif // _NPY_SIMD_SSE_UTILS_H diff --git a/mkl_umath/src/npyv/ucsnarrow.h 
b/mkl_umath/src/npyv/ucsnarrow.h new file mode 100644 index 00000000..4b17a280 --- /dev/null +++ b/mkl_umath/src/npyv/ucsnarrow.h @@ -0,0 +1,7 @@ +#ifndef NUMPY_CORE_SRC_COMMON_NPY_UCSNARROW_H_ +#define NUMPY_CORE_SRC_COMMON_NPY_UCSNARROW_H_ + +NPY_NO_EXPORT PyUnicodeObject * +PyUnicode_FromUCS4(char const *src, Py_ssize_t size, int swap, int align); + +#endif /* NUMPY_CORE_SRC_COMMON_NPY_UCSNARROW_H_ */ diff --git a/mkl_umath/src/npyv/ufunc_override.h b/mkl_umath/src/npyv/ufunc_override.h new file mode 100644 index 00000000..5da95fb2 --- /dev/null +++ b/mkl_umath/src/npyv/ufunc_override.h @@ -0,0 +1,38 @@ +#ifndef NUMPY_CORE_SRC_COMMON_UFUNC_OVERRIDE_H_ +#define NUMPY_CORE_SRC_COMMON_UFUNC_OVERRIDE_H_ + +#include "npy_config.h" + +/* + * Check whether an object has __array_ufunc__ defined on its class and it + * is not the default, i.e., the object is not an ndarray, and its + * __array_ufunc__ is not the same as that of ndarray. + * + * Returns a new reference, the value of type(obj).__array_ufunc__ if it + * exists and is different from that of ndarray, and NULL otherwise. + */ +NPY_NO_EXPORT PyObject * +PyUFuncOverride_GetNonDefaultArrayUfunc(PyObject *obj); + +/* + * Check whether an object has __array_ufunc__ defined on its class and it + * is not the default, i.e., the object is not an ndarray, and its + * __array_ufunc__ is not the same as that of ndarray. + * + * Returns 1 if this is the case, 0 if not. + */ +NPY_NO_EXPORT int +PyUFunc_HasOverride(PyObject *obj); + +/* + * Get possible out argument from kwds, and returns the number of outputs + * contained within it: if a tuple, the number of elements in it, 1 otherwise. + * The out argument itself is returned in out_kwd_obj, and the outputs + * in the out_obj array (as borrowed references). + * + * Returns 0 if no outputs found, -1 if kwds is not a dict (with an error set). 
+ */ +NPY_NO_EXPORT int +PyUFuncOverride_GetOutObjects(PyObject *kwds, PyObject **out_kwd_obj, PyObject ***out_objs); + +#endif /* NUMPY_CORE_SRC_COMMON_UFUNC_OVERRIDE_H_ */ diff --git a/mkl_umath/src/npyv/umathmodule.h b/mkl_umath/src/npyv/umathmodule.h new file mode 100644 index 00000000..73d85334 --- /dev/null +++ b/mkl_umath/src/npyv/umathmodule.h @@ -0,0 +1,18 @@ +#ifndef NUMPY_CORE_SRC_COMMON_UMATHMODULE_H_ +#define NUMPY_CORE_SRC_COMMON_UMATHMODULE_H_ + +#include "ufunc_object.h" +#include "ufunc_type_resolution.h" +#include "extobj.h" /* for the python side extobj set/get */ + + +NPY_NO_EXPORT PyObject * +get_sfloat_dtype(PyObject *NPY_UNUSED(mod), PyObject *NPY_UNUSED(args)); + +PyObject * add_newdoc_ufunc(PyObject *NPY_UNUSED(dummy), PyObject *args); +PyObject * ufunc_frompyfunc(PyObject *NPY_UNUSED(dummy), PyObject *args, PyObject *NPY_UNUSED(kwds)); + + +int initumath(PyObject *m); + +#endif /* NUMPY_CORE_SRC_COMMON_UMATHMODULE_H_ */ diff --git a/mkl_umath/src/npyv/vec/arithmetic.h b/mkl_umath/src/npyv/vec/arithmetic.h new file mode 100644 index 00000000..85f4d6b2 --- /dev/null +++ b/mkl_umath/src/npyv/vec/arithmetic.h @@ -0,0 +1,409 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_VEC_ARITHMETIC_H +#define _NPY_SIMD_VEC_ARITHMETIC_H + +/*************************** + * Addition + ***************************/ +// non-saturated +#define npyv_add_u8 vec_add +#define npyv_add_s8 vec_add +#define npyv_add_u16 vec_add +#define npyv_add_s16 vec_add +#define npyv_add_u32 vec_add +#define npyv_add_s32 vec_add +#define npyv_add_u64 vec_add +#define npyv_add_s64 vec_add +#if NPY_SIMD_F32 +#define npyv_add_f32 vec_add +#endif +#define npyv_add_f64 vec_add + +// saturated +#ifdef NPY_HAVE_VX + #define NPYV_IMPL_VX_ADDS(SFX, PSFX) \ + NPY_FINLINE npyv_##SFX npyv_adds_##SFX(npyv_##SFX a, npyv_##SFX b)\ + { \ + return vec_pack##PSFX( \ + vec_add(vec_unpackh(a), vec_unpackh(b)), \ + vec_add(vec_unpackl(a), vec_unpackl(b)) \ + ); \ + } + + NPYV_IMPL_VX_ADDS(u8, su) + NPYV_IMPL_VX_ADDS(s8, s) + NPYV_IMPL_VX_ADDS(u16, su) + NPYV_IMPL_VX_ADDS(s16, s) +#else // VSX + #define npyv_adds_u8 vec_adds + #define npyv_adds_s8 vec_adds + #define npyv_adds_u16 vec_adds + #define npyv_adds_s16 vec_adds +#endif +/*************************** + * Subtraction + ***************************/ +// non-saturated +#define npyv_sub_u8 vec_sub +#define npyv_sub_s8 vec_sub +#define npyv_sub_u16 vec_sub +#define npyv_sub_s16 vec_sub +#define npyv_sub_u32 vec_sub +#define npyv_sub_s32 vec_sub +#define npyv_sub_u64 vec_sub +#define npyv_sub_s64 vec_sub +#if NPY_SIMD_F32 +#define npyv_sub_f32 vec_sub +#endif +#define npyv_sub_f64 vec_sub + +// saturated +#ifdef NPY_HAVE_VX + #define NPYV_IMPL_VX_SUBS(SFX, PSFX) \ + NPY_FINLINE npyv_##SFX npyv_subs_##SFX(npyv_##SFX a, npyv_##SFX b)\ + { \ + return vec_pack##PSFX( \ + vec_sub(vec_unpackh(a), vec_unpackh(b)), \ + vec_sub(vec_unpackl(a), vec_unpackl(b)) \ + ); \ + } + + NPYV_IMPL_VX_SUBS(u8, su) + NPYV_IMPL_VX_SUBS(s8, s) + NPYV_IMPL_VX_SUBS(u16, su) + NPYV_IMPL_VX_SUBS(s16, s) +#else // VSX + #define npyv_subs_u8 vec_subs + #define npyv_subs_s8 vec_subs + #define npyv_subs_u16 vec_subs + #define npyv_subs_s16 vec_subs +#endif + +/*************************** + * Multiplication + ***************************/ +// non-saturated +// up to GCC 6 vec_mul only supports precisions and llong +#if defined(NPY_HAVE_VSX) && defined(__GNUC__) && __GNUC__ < 7 + #define NPYV_IMPL_VSX_MUL(T_VEC, SFX, ...) 
\ + NPY_FINLINE T_VEC npyv_mul_##SFX(T_VEC a, T_VEC b) \ + { \ + const npyv_u8 ev_od = {__VA_ARGS__}; \ + return vec_perm( \ + (T_VEC)vec_mule(a, b), \ + (T_VEC)vec_mulo(a, b), ev_od \ + ); \ + } + + NPYV_IMPL_VSX_MUL(npyv_u8, u8, 0, 16, 2, 18, 4, 20, 6, 22, 8, 24, 10, 26, 12, 28, 14, 30) + NPYV_IMPL_VSX_MUL(npyv_s8, s8, 0, 16, 2, 18, 4, 20, 6, 22, 8, 24, 10, 26, 12, 28, 14, 30) + NPYV_IMPL_VSX_MUL(npyv_u16, u16, 0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29) + NPYV_IMPL_VSX_MUL(npyv_s16, s16, 0, 1, 16, 17, 4, 5, 20, 21, 8, 9, 24, 25, 12, 13, 28, 29) + + // vmuluwm can be used for unsigned or signed 32-bit integers + #define NPYV_IMPL_VSX_MUL_32(T_VEC, SFX) \ + NPY_FINLINE T_VEC npyv_mul_##SFX(T_VEC a, T_VEC b) \ + { \ + T_VEC ret; \ + __asm__ __volatile__( \ + "vmuluwm %0,%1,%2" : \ + "=v" (ret) : "v" (a), "v" (b) \ + ); \ + return ret; \ + } + + NPYV_IMPL_VSX_MUL_32(npyv_u32, u32) + NPYV_IMPL_VSX_MUL_32(npyv_s32, s32) + +#else + #define npyv_mul_u8 vec_mul + #define npyv_mul_s8 vec_mul + #define npyv_mul_u16 vec_mul + #define npyv_mul_s16 vec_mul + #define npyv_mul_u32 vec_mul + #define npyv_mul_s32 vec_mul +#endif +#if NPY_SIMD_F32 +#define npyv_mul_f32 vec_mul +#endif +#define npyv_mul_f64 vec_mul + +/*************************** + * Integer Division + ***************************/ +// See simd/intdiv.h for more clarification +// divide each unsigned 8-bit element by a precomputed divisor +NPY_FINLINE npyv_u8 npyv_divc_u8(npyv_u8 a, const npyv_u8x3 divisor) +{ +#ifdef NPY_HAVE_VX + npyv_u8 mulhi = vec_mulh(a, divisor.val[0]); +#else // VSX + const npyv_u8 mergeo_perm = { + 1, 17, 3, 19, 5, 21, 7, 23, 9, 25, 11, 27, 13, 29, 15, 31 + }; + // high part of unsigned multiplication + npyv_u16 mul_even = vec_mule(a, divisor.val[0]); + npyv_u16 mul_odd = vec_mulo(a, divisor.val[0]); + npyv_u8 mulhi = (npyv_u8)vec_perm(mul_even, mul_odd, mergeo_perm); +#endif + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + npyv_u8 q = vec_sub(a, mulhi); + q = vec_sr(q, divisor.val[1]); + q = vec_add(mulhi, q); + q = vec_sr(q, divisor.val[2]); + return q; +} +// divide each signed 8-bit element by a precomputed divisor +NPY_FINLINE npyv_s8 npyv_divc_s8(npyv_s8 a, const npyv_s8x3 divisor) +{ +#ifdef NPY_HAVE_VX + npyv_s8 mulhi = vec_mulh(a, divisor.val[0]); +#else + const npyv_u8 mergeo_perm = { + 1, 17, 3, 19, 5, 21, 7, 23, 9, 25, 11, 27, 13, 29, 15, 31 + }; + // high part of signed multiplication + npyv_s16 mul_even = vec_mule(a, divisor.val[0]); + npyv_s16 mul_odd = vec_mulo(a, divisor.val[0]); + npyv_s8 mulhi = (npyv_s8)vec_perm(mul_even, mul_odd, mergeo_perm); +#endif + // q = ((a + mulhi) >> sh1) - XSIGN(a) + // trunc(a/d) = (q ^ dsign) - dsign + npyv_s8 q = vec_sra_s8(vec_add(a, mulhi), (npyv_u8)divisor.val[1]); + q = vec_sub(q, vec_sra_s8(a, npyv_setall_u8(7))); + q = vec_sub(vec_xor(q, divisor.val[2]), divisor.val[2]); + return q; +} +// divide each unsigned 16-bit element by a precomputed divisor +NPY_FINLINE npyv_u16 npyv_divc_u16(npyv_u16 a, const npyv_u16x3 divisor) +{ +#ifdef NPY_HAVE_VX + npyv_u16 mulhi = vec_mulh(a, divisor.val[0]); +#else // VSX + const npyv_u8 mergeo_perm = { + 2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31 + }; + // high part of unsigned multiplication + npyv_u32 mul_even = vec_mule(a, divisor.val[0]); + npyv_u32 mul_odd = vec_mulo(a, divisor.val[0]); + npyv_u16 mulhi = (npyv_u16)vec_perm(mul_even, mul_odd, mergeo_perm); +#endif + // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2 + npyv_u16 q = vec_sub(a, mulhi); + q = vec_sr(q, divisor.val[1]); + 
q = vec_add(mulhi, q); + q = vec_sr(q, divisor.val[2]); + return q; +} +// divide each signed 16-bit element by a precomputed divisor (round towards zero) +NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor) +{ +#ifdef NPY_HAVE_VX + npyv_s16 mulhi = vec_mulh(a, divisor.val[0]); +#else // VSX + const npyv_u8 mergeo_perm = { + 2, 3, 18, 19, 6, 7, 22, 23, 10, 11, 26, 27, 14, 15, 30, 31 + }; + // high part of signed multiplication + npyv_s32 mul_even = vec_mule(a, divisor.val[0]); + npyv_s32 mul_odd = vec_mulo(a, divisor.val[0]); + npyv_s16 mulhi = (npyv_s16)vec_perm(mul_even, mul_odd, mergeo_perm); +#endif + // q = ((a + mulhi) >> sh1) - XSIGN(a) + // trunc(a/d) = (q ^ dsign) - dsign + npyv_s16 q = vec_sra_s16(vec_add(a, mulhi), (npyv_u16)divisor.val[1]); + q = vec_sub(q, vec_sra_s16(a, npyv_setall_u16(15))); + q = vec_sub(vec_xor(q, divisor.val[2]), divisor.val[2]); + return q; +} +// divide each unsigned 32-bit element by a precomputed divisor +NPY_FINLINE npyv_u32 npyv_divc_u32(npyv_u32 a, const npyv_u32x3 divisor) +{ +#if defined(NPY_HAVE_VSX4) || defined(NPY_HAVE_VX) + // high part of unsigned multiplication + npyv_u32 mulhi = vec_mulh(a, divisor.val[0]); +#else // VSX + #if defined(__GNUC__) && __GNUC__ < 8 + // Doubleword integer wide multiplication supported by GCC 8+ + npyv_u64 mul_even, mul_odd; + __asm__ ("vmulouw %0,%1,%2" : "=v" (mul_even) : "v" (a), "v" (divisor.val[0])); + __asm__ ("vmuleuw %0,%1,%2" : "=v" (mul_odd) : "v" (a), "v" (divisor.val[0])); + #else + // Doubleword integer wide multiplication supported by GCC 8+ + npyv_u64 mul_even = vec_mule(a, divisor.val[0]); + npyv_u64 mul_odd = vec_mulo(a, divisor.val[0]); + #endif + // high part of unsigned multiplication + npyv_u32 mulhi = vec_mergeo((npyv_u32)mul_even, (npyv_u32)mul_odd); +#endif + // floor(x/d) = (((a-mulhi) >> sh1) + mulhi) >> sh2 + npyv_u32 q = vec_sub(a, mulhi); + q = vec_sr(q, divisor.val[1]); + q = vec_add(mulhi, q); + q = vec_sr(q, divisor.val[2]); + return q; +} +// divide each signed 32-bit element by a precomputed divisor (round towards zero) +NPY_FINLINE npyv_s32 npyv_divc_s32(npyv_s32 a, const npyv_s32x3 divisor) +{ +#if defined(NPY_HAVE_VSX4) || defined(NPY_HAVE_VX) + // high part of signed multiplication + npyv_s32 mulhi = vec_mulh(a, divisor.val[0]); +#else + #if defined(__GNUC__) && __GNUC__ < 8 + // Doubleword integer wide multiplication supported by GCC8+ + npyv_s64 mul_even, mul_odd; + __asm__ ("vmulosw %0,%1,%2" : "=v" (mul_even) : "v" (a), "v" (divisor.val[0])); + __asm__ ("vmulesw %0,%1,%2" : "=v" (mul_odd) : "v" (a), "v" (divisor.val[0])); + #else + // Doubleword integer wide multiplication supported by GCC8+ + npyv_s64 mul_even = vec_mule(a, divisor.val[0]); + npyv_s64 mul_odd = vec_mulo(a, divisor.val[0]); + #endif + // high part of signed multiplication + npyv_s32 mulhi = vec_mergeo((npyv_s32)mul_even, (npyv_s32)mul_odd); +#endif + // q = ((a + mulhi) >> sh1) - XSIGN(a) + // trunc(a/d) = (q ^ dsign) - dsign + npyv_s32 q = vec_sra_s32(vec_add(a, mulhi), (npyv_u32)divisor.val[1]); + q = vec_sub(q, vec_sra_s32(a, npyv_setall_u32(31))); + q = vec_sub(vec_xor(q, divisor.val[2]), divisor.val[2]); + return q; +} +// divide each unsigned 64-bit element by a precomputed divisor +NPY_FINLINE npyv_u64 npyv_divc_u64(npyv_u64 a, const npyv_u64x3 divisor) +{ +#if defined(NPY_HAVE_VSX4) + return vec_div(a, divisor.val[0]); +#else + const npy_uint64 d = vec_extract(divisor.val[0], 0); + return npyv_set_u64(vec_extract(a, 0) / d, vec_extract(a, 1) / d); +#endif +} +// divide each 
signed 64-bit element by a precomputed divisor (round towards zero) +NPY_FINLINE npyv_s64 npyv_divc_s64(npyv_s64 a, const npyv_s64x3 divisor) +{ + npyv_b64 overflow = npyv_and_b64(vec_cmpeq(a, npyv_setall_s64(-1LL << 63)), (npyv_b64)divisor.val[1]); + npyv_s64 d = vec_sel(divisor.val[0], npyv_setall_s64(1), overflow); + return vec_div(a, d); +} +/*************************** + * Division + ***************************/ +#if NPY_SIMD_F32 + #define npyv_div_f32 vec_div +#endif +#define npyv_div_f64 vec_div + +/*************************** + * FUSED + ***************************/ +// multiply and add, a*b + c +#define npyv_muladd_f64 vec_madd +// multiply and subtract, a*b - c +#define npyv_mulsub_f64 vec_msub +#if NPY_SIMD_F32 + #define npyv_muladd_f32 vec_madd + #define npyv_mulsub_f32 vec_msub +#endif +#if defined(NPY_HAVE_VXE) || defined(NPY_HAVE_VSX) + // negate multiply and add, -(a*b) + c + #define npyv_nmuladd_f32 vec_nmsub // equivalent to -(a*b - c) + #define npyv_nmuladd_f64 vec_nmsub + // negate multiply and subtract, -(a*b) - c + #define npyv_nmulsub_f64 vec_nmadd + #define npyv_nmulsub_f32 vec_nmadd // equivalent to -(a*b + c) +#else + NPY_FINLINE npyv_f64 npyv_nmuladd_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { return vec_neg(vec_msub(a, b, c)); } + NPY_FINLINE npyv_f64 npyv_nmulsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) + { return vec_neg(vec_madd(a, b, c)); } +#endif +// multiply, add for odd elements and subtract even elements. +// (a * b) -+ c +#if NPY_SIMD_F32 +NPY_FINLINE npyv_f32 npyv_muladdsub_f32(npyv_f32 a, npyv_f32 b, npyv_f32 c) +{ + const npyv_f32 msign = npyv_set_f32(-0.0f, 0.0f, -0.0f, 0.0f); + return npyv_muladd_f32(a, b, npyv_xor_f32(msign, c)); +} +#endif +NPY_FINLINE npyv_f64 npyv_muladdsub_f64(npyv_f64 a, npyv_f64 b, npyv_f64 c) +{ + const npyv_f64 msign = npyv_set_f64(-0.0, 0.0); + return npyv_muladd_f64(a, b, npyv_xor_f64(msign, c)); +} +/*************************** + * Summation + ***************************/ +// reduce sum across vector +NPY_FINLINE npy_uint64 npyv_sum_u64(npyv_u64 a) +{ +#ifdef NPY_HAVE_VX + const npyv_u64 zero = npyv_zero_u64(); + return vec_extract((npyv_u64)vec_sum_u128(a, zero), 1); +#else + return vec_extract(vec_add(a, vec_mergel(a, a)), 0); +#endif +} + +NPY_FINLINE npy_uint32 npyv_sum_u32(npyv_u32 a) +{ +#ifdef NPY_HAVE_VX + const npyv_u32 zero = npyv_zero_u32(); + return vec_extract((npyv_u32)vec_sum_u128(a, zero), 3); +#else + const npyv_u32 rs = vec_add(a, vec_sld(a, a, 8)); + return vec_extract(vec_add(rs, vec_sld(rs, rs, 4)), 0); +#endif +} + +#if NPY_SIMD_F32 +NPY_FINLINE float npyv_sum_f32(npyv_f32 a) +{ + npyv_f32 sum = vec_add(a, npyv_combineh_f32(a, a)); + return vec_extract(sum, 0) + vec_extract(sum, 1); + (void)sum; +} +#endif + +NPY_FINLINE double npyv_sum_f64(npyv_f64 a) +{ + return vec_extract(a, 0) + vec_extract(a, 1); +} + +// expand the source vector and performs sum reduce +NPY_FINLINE npy_uint16 npyv_sumup_u8(npyv_u8 a) +{ +#ifdef NPY_HAVE_VX + const npyv_u8 zero = npyv_zero_u8(); + npyv_u32 sum4 = vec_sum4(a, zero); + return (npy_uint16)npyv_sum_u32(sum4); +#else + const npyv_u32 zero = npyv_zero_u32(); + npyv_u32 four = vec_sum4s(a, zero); + npyv_s32 one = vec_sums((npyv_s32)four, (npyv_s32)zero); + return (npy_uint16)vec_extract(one, 3); + (void)one; +#endif +} + +NPY_FINLINE npy_uint32 npyv_sumup_u16(npyv_u16 a) +{ +#ifdef NPY_HAVE_VX + npyv_u64 sum = vec_sum2(a, npyv_zero_u16()); + return (npy_uint32)npyv_sum_u64(sum); +#else // VSX + const npyv_s32 zero = npyv_zero_s32(); + npyv_u32x2 eight = 
npyv_expand_u32_u16(a); + npyv_u32 four = vec_add(eight.val[0], eight.val[1]); + npyv_s32 one = vec_sums((npyv_s32)four, zero); + return (npy_uint32)vec_extract(one, 3); + (void)one; +#endif +} + +#endif // _NPY_SIMD_VEC_ARITHMETIC_H diff --git a/mkl_umath/src/npyv/vec/conversion.h b/mkl_umath/src/npyv/vec/conversion.h new file mode 100644 index 00000000..922109f7 --- /dev/null +++ b/mkl_umath/src/npyv/vec/conversion.h @@ -0,0 +1,237 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_VEC_CVT_H +#define _NPY_SIMD_VEC_CVT_H + +// convert boolean vectors to integer vectors +#define npyv_cvt_u8_b8(BL) ((npyv_u8) BL) +#define npyv_cvt_s8_b8(BL) ((npyv_s8) BL) +#define npyv_cvt_u16_b16(BL) ((npyv_u16) BL) +#define npyv_cvt_s16_b16(BL) ((npyv_s16) BL) +#define npyv_cvt_u32_b32(BL) ((npyv_u32) BL) +#define npyv_cvt_s32_b32(BL) ((npyv_s32) BL) +#define npyv_cvt_u64_b64(BL) ((npyv_u64) BL) +#define npyv_cvt_s64_b64(BL) ((npyv_s64) BL) +#if NPY_SIMD_F32 + #define npyv_cvt_f32_b32(BL) ((npyv_f32) BL) +#endif +#define npyv_cvt_f64_b64(BL) ((npyv_f64) BL) + +// convert integer vectors to boolean vectors +#define npyv_cvt_b8_u8(A) ((npyv_b8) A) +#define npyv_cvt_b8_s8(A) ((npyv_b8) A) +#define npyv_cvt_b16_u16(A) ((npyv_b16) A) +#define npyv_cvt_b16_s16(A) ((npyv_b16) A) +#define npyv_cvt_b32_u32(A) ((npyv_b32) A) +#define npyv_cvt_b32_s32(A) ((npyv_b32) A) +#define npyv_cvt_b64_u64(A) ((npyv_b64) A) +#define npyv_cvt_b64_s64(A) ((npyv_b64) A) +#if NPY_SIMD_F32 + #define npyv_cvt_b32_f32(A) ((npyv_b32) A) +#endif +#define npyv_cvt_b64_f64(A) ((npyv_b64) A) + +//expand +NPY_FINLINE npyv_u16x2 npyv_expand_u16_u8(npyv_u8 data) +{ + npyv_u16x2 r; +#ifdef NPY_HAVE_VX + r.val[0] = vec_unpackh(data); + r.val[1] = vec_unpackl(data); +#else + npyv_u8 zero = npyv_zero_u8(); + r.val[0] = (npyv_u16)vec_mergeh(data, zero); + r.val[1] = (npyv_u16)vec_mergel(data, zero); +#endif + return r; +} + +NPY_FINLINE npyv_u32x2 npyv_expand_u32_u16(npyv_u16 data) +{ + npyv_u32x2 r; +#ifdef NPY_HAVE_VX + r.val[0] = vec_unpackh(data); + r.val[1] = vec_unpackl(data); +#else + npyv_u16 zero = npyv_zero_u16(); + r.val[0] = (npyv_u32)vec_mergeh(data, zero); + r.val[1] = (npyv_u32)vec_mergel(data, zero); +#endif + return r; +} + +// pack two 16-bit boolean into one 8-bit boolean vector +NPY_FINLINE npyv_b8 npyv_pack_b8_b16(npyv_b16 a, npyv_b16 b) { + return vec_pack(a, b); +} + +// pack four 32-bit boolean vectors into one 8-bit boolean vector +NPY_FINLINE npyv_b8 npyv_pack_b8_b32(npyv_b32 a, npyv_b32 b, npyv_b32 c, npyv_b32 d) { + npyv_b16 ab = vec_pack(a, b); + npyv_b16 cd = vec_pack(c, d); + return npyv_pack_b8_b16(ab, cd); +} + +// pack eight 64-bit boolean vectors into one 8-bit boolean vector +NPY_FINLINE npyv_b8 +npyv_pack_b8_b64(npyv_b64 a, npyv_b64 b, npyv_b64 c, npyv_b64 d, + npyv_b64 e, npyv_b64 f, npyv_b64 g, npyv_b64 h) { + npyv_b32 ab = vec_pack(a, b); + npyv_b32 cd = vec_pack(c, d); + npyv_b32 ef = vec_pack(e, f); + npyv_b32 gh = vec_pack(g, h); + return npyv_pack_b8_b32(ab, cd, ef, gh); +} + +// convert boolean vector to integer bitfield +#if defined(NPY_HAVE_VXE) || defined(NPY_HAVE_VSX2) + NPY_FINLINE npy_uint64 npyv_tobits_b8(npyv_b8 a) + { + const npyv_u8 qperm = npyv_set_u8(120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 0); + npyv_u16 r = (npyv_u16)vec_vbpermq((npyv_u8)a, qperm); + #ifdef NPY_HAVE_VXE + return vec_extract(r, 3); + #else + return vec_extract(r, 4); + #endif + // to suppress ambiguous warning: variable `r` but not used [-Wunused-but-set-variable] + 
(void)r; + } + NPY_FINLINE npy_uint64 npyv_tobits_b16(npyv_b16 a) + { + const npyv_u8 qperm = npyv_setf_u8(128, 112, 96, 80, 64, 48, 32, 16, 0); + npyv_u8 r = (npyv_u8)vec_vbpermq((npyv_u8)a, qperm); + #ifdef NPY_HAVE_VXE + return vec_extract(r, 6); + #else + return vec_extract(r, 8); + #endif + // to suppress ambiguous warning: variable `r` but not used [-Wunused-but-set-variable] + (void)r; + } + NPY_FINLINE npy_uint64 npyv_tobits_b32(npyv_b32 a) + { + #ifdef NPY_HAVE_VXE + const npyv_u8 qperm = npyv_setf_u8(128, 128, 128, 128, 128, 96, 64, 32, 0); + #else + const npyv_u8 qperm = npyv_setf_u8(128, 96, 64, 32, 0); + #endif + npyv_u8 r = (npyv_u8)vec_vbpermq((npyv_u8)a, qperm); + #ifdef NPY_HAVE_VXE + return vec_extract(r, 6); + #else + return vec_extract(r, 8); + #endif + // to suppress ambiguous warning: variable `r` but not used [-Wunused-but-set-variable] + (void)r; + } + NPY_FINLINE npy_uint64 npyv_tobits_b64(npyv_b64 a) + { + #ifdef NPY_HAVE_VXE + const npyv_u8 qperm = npyv_setf_u8(128, 128, 128, 128, 128, 128, 128, 64, 0); + #else + const npyv_u8 qperm = npyv_setf_u8(128, 64, 0); + #endif + npyv_u8 r = (npyv_u8)vec_vbpermq((npyv_u8)a, qperm); + #ifdef NPY_HAVE_VXE + return vec_extract(r, 6); + #else + return vec_extract(r, 8); + #endif + // to suppress ambiguous warning: variable `r` but not used [-Wunused-but-set-variable] + (void)r; + } +#else + NPY_FINLINE npy_uint64 npyv_tobits_b8(npyv_b8 a) + { + const npyv_u8 scale = npyv_set_u8(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128); + npyv_u8 seq_scale = vec_and((npyv_u8)a, scale); + npyv_u64 sum = vec_sum2(vec_sum4(seq_scale, npyv_zero_u8()), npyv_zero_u32()); + return vec_extract(sum, 0) + ((int)vec_extract(sum, 1) << 8); + } + NPY_FINLINE npy_uint64 npyv_tobits_b16(npyv_b16 a) + { + const npyv_u16 scale = npyv_set_u16(1, 2, 4, 8, 16, 32, 64, 128); + npyv_u16 seq_scale = vec_and((npyv_u16)a, scale); + npyv_u64 sum = vec_sum2(seq_scale, npyv_zero_u16()); + return vec_extract(vec_sum_u128(sum, npyv_zero_u64()), 15); + } + NPY_FINLINE npy_uint64 npyv_tobits_b32(npyv_b32 a) + { + const npyv_u32 scale = npyv_set_u32(1, 2, 4, 8); + npyv_u32 seq_scale = vec_and((npyv_u32)a, scale); + return vec_extract(vec_sum_u128(seq_scale, npyv_zero_u32()), 15); + } + NPY_FINLINE npy_uint64 npyv_tobits_b64(npyv_b64 a) + { + const npyv_u64 scale = npyv_set_u64(1, 2); + npyv_u64 seq_scale = vec_and((npyv_u64)a, scale); + return vec_extract(vec_sum_u128(seq_scale, npyv_zero_u64()), 15); + } +#endif +// truncate compatible with all compilers(internal use for now) +#if NPY_SIMD_F32 + NPY_FINLINE npyv_s32 npyv__trunc_s32_f32(npyv_f32 a) + { + #ifdef NPY_HAVE_VXE2 + return vec_signed(a); + #elif defined(NPY_HAVE_VXE) + return vec_packs(vec_signed(npyv_doublee(vec_mergeh(a,a))), + vec_signed(npyv_doublee(vec_mergel(a, a)))); + // VSX + #elif defined(__IBMC__) + return vec_cts(a, 0); + #elif defined(__clang__) + /** + * old versions of CLANG doesn't support %x in the inline asm template + * which fixes register number when using any of the register constraints wa, wd, wf. + * therefore, we count on built-in functions. 
+ */ + return __builtin_convertvector(a, npyv_s32); + #else // gcc + npyv_s32 ret; + __asm__ ("xvcvspsxws %x0,%x1" : "=wa" (ret) : "wa" (a)); + return ret; + #endif + } +#endif + +NPY_FINLINE npyv_s32 npyv__trunc_s32_f64(npyv_f64 a, npyv_f64 b) +{ +#ifdef NPY_HAVE_VX + return vec_packs(vec_signed(a), vec_signed(b)); +// VSX +#elif defined(__IBMC__) + const npyv_u8 seq_even = npyv_set_u8(0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27); + // unfortunately, XLC missing asm register vsx fixer + // hopefully, xlc can optimize around big-endian compatibility + npyv_s32 lo_even = vec_cts(a, 0); + npyv_s32 hi_even = vec_cts(b, 0); + return vec_perm(lo_even, hi_even, seq_even); +#else + const npyv_u8 seq_odd = npyv_set_u8(4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31); + #ifdef __clang__ + // __builtin_convertvector doesn't support this conversion on wide range of versions + // fortunately, almost all versions have direct builtin of 'xvcvdpsxws' + npyv_s32 lo_odd = __builtin_vsx_xvcvdpsxws(a); + npyv_s32 hi_odd = __builtin_vsx_xvcvdpsxws(b); + #else // gcc + npyv_s32 lo_odd, hi_odd; + __asm__ ("xvcvdpsxws %x0,%x1" : "=wa" (lo_odd) : "wa" (a)); + __asm__ ("xvcvdpsxws %x0,%x1" : "=wa" (hi_odd) : "wa" (b)); + #endif + return vec_perm(lo_odd, hi_odd, seq_odd); +#endif +} + +// round to nearest integer (assuming even) +#if NPY_SIMD_F32 + NPY_FINLINE npyv_s32 npyv_round_s32_f32(npyv_f32 a) + { return npyv__trunc_s32_f32(vec_rint(a)); } +#endif +NPY_FINLINE npyv_s32 npyv_round_s32_f64(npyv_f64 a, npyv_f64 b) +{ return npyv__trunc_s32_f64(vec_rint(a), vec_rint(b)); } + +#endif // _NPY_SIMD_VEC_CVT_H diff --git a/mkl_umath/src/npyv/vec/math.h b/mkl_umath/src/npyv/vec/math.h new file mode 100644 index 00000000..85690f76 --- /dev/null +++ b/mkl_umath/src/npyv/vec/math.h @@ -0,0 +1,285 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_VEC_MATH_H +#define _NPY_SIMD_VEC_MATH_H +/*************************** + * Elementary + ***************************/ +// Square root +#if NPY_SIMD_F32 + #define npyv_sqrt_f32 vec_sqrt +#endif +#define npyv_sqrt_f64 vec_sqrt + +// Reciprocal +#if NPY_SIMD_F32 + NPY_FINLINE npyv_f32 npyv_recip_f32(npyv_f32 a) + { + const npyv_f32 one = npyv_setall_f32(1.0f); + return vec_div(one, a); + } +#endif +NPY_FINLINE npyv_f64 npyv_recip_f64(npyv_f64 a) +{ + const npyv_f64 one = npyv_setall_f64(1.0); + return vec_div(one, a); +} + +// Absolute +#if NPY_SIMD_F32 + #define npyv_abs_f32 vec_abs +#endif +#define npyv_abs_f64 vec_abs + +// Square +#if NPY_SIMD_F32 + NPY_FINLINE npyv_f32 npyv_square_f32(npyv_f32 a) + { return vec_mul(a, a); } +#endif +NPY_FINLINE npyv_f64 npyv_square_f64(npyv_f64 a) +{ return vec_mul(a, a); } + +// Maximum, natively mapping with no guarantees to handle NaN. +#if NPY_SIMD_F32 + #define npyv_max_f32 vec_max +#endif +#define npyv_max_f64 vec_max +// Maximum, supports IEEE floating-point arithmetic (IEC 60559), +// - If one of the two vectors contains NaN, the equivalent element of the other vector is set +// - Only if both corresponded elements are NaN, NaN is set. 
+#if NPY_SIMD_F32 + #define npyv_maxp_f32 vec_max +#endif +#if defined(NPY_HAVE_VXE) || defined(NPY_HAVE_VSX) + #define npyv_maxp_f64 vec_max +#else + // vfmindb & vfmaxdb appears in zarch12 + NPY_FINLINE npyv_f64 npyv_maxp_f64(npyv_f64 a, npyv_f64 b) + { + npyv_b64 nn_a = npyv_notnan_f64(a); + npyv_b64 nn_b = npyv_notnan_f64(b); + return vec_max(vec_sel(b, a, nn_a), vec_sel(a, b, nn_b)); + } +#endif +#if NPY_SIMD_F32 + NPY_FINLINE npyv_f32 npyv_maxn_f32(npyv_f32 a, npyv_f32 b) + { + npyv_b32 nn_a = npyv_notnan_f32(a); + npyv_b32 nn_b = npyv_notnan_f32(b); + npyv_f32 max = vec_max(a, b); + return vec_sel(b, vec_sel(a, max, nn_a), nn_b); + } +#endif +NPY_FINLINE npyv_f64 npyv_maxn_f64(npyv_f64 a, npyv_f64 b) +{ + npyv_b64 nn_a = npyv_notnan_f64(a); + npyv_b64 nn_b = npyv_notnan_f64(b); + npyv_f64 max = vec_max(a, b); + return vec_sel(b, vec_sel(a, max, nn_a), nn_b); +} + +// Maximum, integer operations +#define npyv_max_u8 vec_max +#define npyv_max_s8 vec_max +#define npyv_max_u16 vec_max +#define npyv_max_s16 vec_max +#define npyv_max_u32 vec_max +#define npyv_max_s32 vec_max +#define npyv_max_u64 vec_max +#define npyv_max_s64 vec_max + +// Minimum, natively mapping with no guarantees to handle NaN. +#if NPY_SIMD_F32 + #define npyv_min_f32 vec_min +#endif +#define npyv_min_f64 vec_min +// Minimum, supports IEEE floating-point arithmetic (IEC 60559), +// - If one of the two vectors contains NaN, the equivalent element of the other vector is set +// - Only if both corresponded elements are NaN, NaN is set. +#if NPY_SIMD_F32 + #define npyv_minp_f32 vec_min +#endif +#if defined(NPY_HAVE_VXE) || defined(NPY_HAVE_VSX) + #define npyv_minp_f64 vec_min +#else + // vfmindb & vfmaxdb appears in zarch12 + NPY_FINLINE npyv_f64 npyv_minp_f64(npyv_f64 a, npyv_f64 b) + { + npyv_b64 nn_a = npyv_notnan_f64(a); + npyv_b64 nn_b = npyv_notnan_f64(b); + return vec_min(vec_sel(b, a, nn_a), vec_sel(a, b, nn_b)); + } +#endif +#if NPY_SIMD_F32 + NPY_FINLINE npyv_f32 npyv_minn_f32(npyv_f32 a, npyv_f32 b) + { + npyv_b32 nn_a = npyv_notnan_f32(a); + npyv_b32 nn_b = npyv_notnan_f32(b); + npyv_f32 min = vec_min(a, b); + return vec_sel(b, vec_sel(a, min, nn_a), nn_b); + } +#endif +NPY_FINLINE npyv_f64 npyv_minn_f64(npyv_f64 a, npyv_f64 b) +{ + npyv_b64 nn_a = npyv_notnan_f64(a); + npyv_b64 nn_b = npyv_notnan_f64(b); + npyv_f64 min = vec_min(a, b); + return vec_sel(b, vec_sel(a, min, nn_a), nn_b); +} + +// Minimum, integer operations +#define npyv_min_u8 vec_min +#define npyv_min_s8 vec_min +#define npyv_min_u16 vec_min +#define npyv_min_s16 vec_min +#define npyv_min_u32 vec_min +#define npyv_min_s32 vec_min +#define npyv_min_u64 vec_min +#define npyv_min_s64 vec_min + +#define NPY_IMPL_VEC_REDUCE_MINMAX(INTRIN, STYPE, SFX) \ + NPY_FINLINE npy_##STYPE npyv_reduce_##INTRIN##_##SFX(npyv_##SFX a) \ + { \ + npyv_##SFX r = vec_##INTRIN(a, vec_sld(a, a, 8)); \ + r = vec_##INTRIN(r, vec_sld(r, r, 4)); \ + r = vec_##INTRIN(r, vec_sld(r, r, 2)); \ + r = vec_##INTRIN(r, vec_sld(r, r, 1)); \ + return (npy_##STYPE)vec_extract(r, 0); \ + } +NPY_IMPL_VEC_REDUCE_MINMAX(min, uint8, u8) +NPY_IMPL_VEC_REDUCE_MINMAX(max, uint8, u8) +NPY_IMPL_VEC_REDUCE_MINMAX(min, int8, s8) +NPY_IMPL_VEC_REDUCE_MINMAX(max, int8, s8) +#undef NPY_IMPL_VEC_REDUCE_MINMAX + +#define NPY_IMPL_VEC_REDUCE_MINMAX(INTRIN, STYPE, SFX) \ + NPY_FINLINE npy_##STYPE npyv_reduce_##INTRIN##_##SFX(npyv_##SFX a) \ + { \ + npyv_##SFX r = vec_##INTRIN(a, vec_sld(a, a, 8)); \ + r = vec_##INTRIN(r, vec_sld(r, r, 4)); \ + r = vec_##INTRIN(r, vec_sld(r, r, 2)); \ + return 
(npy_##STYPE)vec_extract(r, 0); \ + } +NPY_IMPL_VEC_REDUCE_MINMAX(min, uint16, u16) +NPY_IMPL_VEC_REDUCE_MINMAX(max, uint16, u16) +NPY_IMPL_VEC_REDUCE_MINMAX(min, int16, s16) +NPY_IMPL_VEC_REDUCE_MINMAX(max, int16, s16) +#undef NPY_IMPL_VEC_REDUCE_MINMAX + +#define NPY_IMPL_VEC_REDUCE_MINMAX(INTRIN, STYPE, SFX) \ + NPY_FINLINE npy_##STYPE npyv_reduce_##INTRIN##_##SFX(npyv_##SFX a) \ + { \ + npyv_##SFX r = vec_##INTRIN(a, vec_sld(a, a, 8)); \ + r = vec_##INTRIN(r, vec_sld(r, r, 4)); \ + return (npy_##STYPE)vec_extract(r, 0); \ + } +NPY_IMPL_VEC_REDUCE_MINMAX(min, uint32, u32) +NPY_IMPL_VEC_REDUCE_MINMAX(max, uint32, u32) +NPY_IMPL_VEC_REDUCE_MINMAX(min, int32, s32) +NPY_IMPL_VEC_REDUCE_MINMAX(max, int32, s32) +#undef NPY_IMPL_VEC_REDUCE_MINMAX + +#define NPY_IMPL_VEC_REDUCE_MINMAX(INTRIN, STYPE, SFX) \ + NPY_FINLINE npy_##STYPE npyv_reduce_##INTRIN##_##SFX(npyv_##SFX a) \ + { \ + npyv_##SFX r = vec_##INTRIN(a, vec_sld(a, a, 8)); \ + return (npy_##STYPE)vec_extract(r, 0); \ + (void)r; \ + } +NPY_IMPL_VEC_REDUCE_MINMAX(min, uint64, u64) +NPY_IMPL_VEC_REDUCE_MINMAX(max, uint64, u64) +NPY_IMPL_VEC_REDUCE_MINMAX(min, int64, s64) +NPY_IMPL_VEC_REDUCE_MINMAX(max, int64, s64) +#undef NPY_IMPL_VEC_REDUCE_MINMAX + +#if NPY_SIMD_F32 + #define NPY_IMPL_VEC_REDUCE_MINMAX(INTRIN, INF) \ + NPY_FINLINE float npyv_reduce_##INTRIN##_f32(npyv_f32 a) \ + { \ + npyv_f32 r = vec_##INTRIN(a, vec_sld(a, a, 8)); \ + r = vec_##INTRIN(r, vec_sld(r, r, 4)); \ + return vec_extract(r, 0); \ + } \ + NPY_FINLINE float npyv_reduce_##INTRIN##p_f32(npyv_f32 a) \ + { \ + return npyv_reduce_##INTRIN##_f32(a); \ + } \ + NPY_FINLINE float npyv_reduce_##INTRIN##n_f32(npyv_f32 a) \ + { \ + npyv_b32 notnan = npyv_notnan_f32(a); \ + if (NPY_UNLIKELY(!npyv_all_b32(notnan))) { \ + const union { npy_uint32 i; float f;} \ + pnan = {0x7fc00000UL}; \ + return pnan.f; \ + } \ + return npyv_reduce_##INTRIN##_f32(a); \ + } + NPY_IMPL_VEC_REDUCE_MINMAX(min, 0x7f800000) + NPY_IMPL_VEC_REDUCE_MINMAX(max, 0xff800000) + #undef NPY_IMPL_VEC_REDUCE_MINMAX +#endif // NPY_SIMD_F32 + +#define NPY_IMPL_VEC_REDUCE_MINMAX(INTRIN, INF) \ + NPY_FINLINE double npyv_reduce_##INTRIN##_f64(npyv_f64 a) \ + { \ + npyv_f64 r = vec_##INTRIN(a, vec_sld(a, a, 8)); \ + return vec_extract(r, 0); \ + (void)r; \ + } \ + NPY_FINLINE double npyv_reduce_##INTRIN##n_f64(npyv_f64 a) \ + { \ + npyv_b64 notnan = npyv_notnan_f64(a); \ + if (NPY_UNLIKELY(!npyv_all_b64(notnan))) { \ + const union { npy_uint64 i; double f;} \ + pnan = {0x7ff8000000000000ull}; \ + return pnan.f; \ + } \ + return npyv_reduce_##INTRIN##_f64(a); \ + } +NPY_IMPL_VEC_REDUCE_MINMAX(min, 0x7ff0000000000000) +NPY_IMPL_VEC_REDUCE_MINMAX(max, 0xfff0000000000000) +#undef NPY_IMPL_VEC_REDUCE_MINMAX + +#if defined(NPY_HAVE_VXE) || defined(NPY_HAVE_VSX) + #define npyv_reduce_minp_f64 npyv_reduce_min_f64 + #define npyv_reduce_maxp_f64 npyv_reduce_max_f64 +#else + NPY_FINLINE double npyv_reduce_minp_f64(npyv_f64 a) + { + npyv_b64 notnan = npyv_notnan_f64(a); + if (NPY_UNLIKELY(!npyv_any_b64(notnan))) { + return vec_extract(a, 0); + } + a = npyv_select_f64(notnan, a, npyv_reinterpret_f64_u64( + npyv_setall_u64(0x7ff0000000000000))); + return npyv_reduce_min_f64(a); + } + NPY_FINLINE double npyv_reduce_maxp_f64(npyv_f64 a) + { + npyv_b64 notnan = npyv_notnan_f64(a); + if (NPY_UNLIKELY(!npyv_any_b64(notnan))) { + return vec_extract(a, 0); + } + a = npyv_select_f64(notnan, a, npyv_reinterpret_f64_u64( + npyv_setall_u64(0xfff0000000000000))); + return npyv_reduce_max_f64(a); + } +#endif +// round to nearest int even 
+#define npyv_rint_f64 vec_rint +// ceil +#define npyv_ceil_f64 vec_ceil +// trunc +#define npyv_trunc_f64 vec_trunc +// floor +#define npyv_floor_f64 vec_floor +#if NPY_SIMD_F32 + #define npyv_rint_f32 vec_rint + #define npyv_ceil_f32 vec_ceil + #define npyv_trunc_f32 vec_trunc + #define npyv_floor_f32 vec_floor +#endif + +#endif // _NPY_SIMD_VEC_MATH_H diff --git a/mkl_umath/src/npyv/vec/memory.h b/mkl_umath/src/npyv/vec/memory.h new file mode 100644 index 00000000..dbcdc16d --- /dev/null +++ b/mkl_umath/src/npyv/vec/memory.h @@ -0,0 +1,703 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_VEC_MEMORY_H +#define _NPY_SIMD_VEC_MEMORY_H + +#include "misc.h" + +/**************************** + * Private utilities + ****************************/ +// TODO: test load by cast +#define VSX__CAST_lOAD 0 +#if VSX__CAST_lOAD + #define npyv__load(T_VEC, PTR) (*((T_VEC*)(PTR))) +#else + /** + * CLANG fails to load unaligned addresses via vec_xl, vec_xst + * so we failback to vec_vsx_ld, vec_vsx_st + */ + #if defined (NPY_HAVE_VSX2) && ( \ + (defined(__GNUC__) && !defined(vec_xl)) || (defined(__clang__) && !defined(__IBMC__)) \ + ) + #define npyv__load(T_VEC, PTR) vec_vsx_ld(0, PTR) + #else // VX + #define npyv__load(T_VEC, PTR) vec_xl(0, PTR) + #endif +#endif +// unaligned store +#if defined (NPY_HAVE_VSX2) && ( \ + (defined(__GNUC__) && !defined(vec_xl)) || (defined(__clang__) && !defined(__IBMC__)) \ +) + #define npyv__store(PTR, VEC) vec_vsx_st(VEC, 0, PTR) +#else // VX + #define npyv__store(PTR, VEC) vec_xst(VEC, 0, PTR) +#endif + +// aligned load/store +#if defined (NPY_HAVE_VSX) + #define npyv__loada(PTR) vec_ld(0, PTR) + #define npyv__storea(PTR, VEC) vec_st(VEC, 0, PTR) +#else // VX + #define npyv__loada(PTR) vec_xl(0, PTR) + #define npyv__storea(PTR, VEC) vec_xst(VEC, 0, PTR) +#endif + +// avoid aliasing rules +NPY_FINLINE npy_uint64 *npyv__ptr2u64(const void *ptr) +{ npy_uint64 *ptr64 = (npy_uint64*)ptr; return ptr64; } + +// load lower part +NPY_FINLINE npyv_u64 npyv__loadl(const void *ptr) +{ +#ifdef NPY_HAVE_VSX + #if defined(__clang__) && !defined(__IBMC__) + // vec_promote doesn't support doubleword on clang + return npyv_setall_u64(*npyv__ptr2u64(ptr)); + #else + return vec_promote(*npyv__ptr2u64(ptr), 0); + #endif +#else // VX + return vec_load_len((const unsigned long long*)ptr, 7); +#endif +} +// store lower part +#define npyv__storel(PTR, VEC) \ + *npyv__ptr2u64(PTR) = vec_extract(((npyv_u64)VEC), 0) + +#define npyv__storeh(PTR, VEC) \ + *npyv__ptr2u64(PTR) = vec_extract(((npyv_u64)VEC), 1) + +/**************************** + * load/store + ****************************/ +#define NPYV_IMPL_VEC_MEM(SFX, DW_CAST) \ + NPY_FINLINE npyv_##SFX npyv_load_##SFX(const npyv_lanetype_##SFX *ptr) \ + { return (npyv_##SFX)npyv__load(npyv_##SFX, (const npyv_lanetype_##DW_CAST*)ptr); } \ + NPY_FINLINE npyv_##SFX npyv_loada_##SFX(const npyv_lanetype_##SFX *ptr) \ + { return (npyv_##SFX)npyv__loada((const npyv_lanetype_u32*)ptr); } \ + NPY_FINLINE npyv_##SFX npyv_loads_##SFX(const npyv_lanetype_##SFX *ptr) \ + { return npyv_loada_##SFX(ptr); } \ + NPY_FINLINE npyv_##SFX npyv_loadl_##SFX(const npyv_lanetype_##SFX *ptr) \ + { return (npyv_##SFX)npyv__loadl(ptr); } \ + NPY_FINLINE void npyv_store_##SFX(npyv_lanetype_##SFX *ptr, npyv_##SFX vec) \ + { npyv__store((npyv_lanetype_##DW_CAST*)ptr, (npyv_##DW_CAST)vec); } \ + NPY_FINLINE void npyv_storea_##SFX(npyv_lanetype_##SFX *ptr, npyv_##SFX vec) \ + { npyv__storea((npyv_lanetype_u32*)ptr, (npyv_u32)vec); } \ + 
NPY_FINLINE void npyv_stores_##SFX(npyv_lanetype_##SFX *ptr, npyv_##SFX vec) \ + { npyv_storea_##SFX(ptr, vec); } \ + NPY_FINLINE void npyv_storel_##SFX(npyv_lanetype_##SFX *ptr, npyv_##SFX vec) \ + { npyv__storel(ptr, vec); } \ + NPY_FINLINE void npyv_storeh_##SFX(npyv_lanetype_##SFX *ptr, npyv_##SFX vec) \ + { npyv__storeh(ptr, vec); } + +NPYV_IMPL_VEC_MEM(u8, u8) +NPYV_IMPL_VEC_MEM(s8, s8) +NPYV_IMPL_VEC_MEM(u16, u16) +NPYV_IMPL_VEC_MEM(s16, s16) +NPYV_IMPL_VEC_MEM(u32, u32) +NPYV_IMPL_VEC_MEM(s32, s32) +NPYV_IMPL_VEC_MEM(u64, f64) +NPYV_IMPL_VEC_MEM(s64, f64) +#if NPY_SIMD_F32 +NPYV_IMPL_VEC_MEM(f32, f32) +#endif +NPYV_IMPL_VEC_MEM(f64, f64) + +/*************************** + * Non-contiguous Load + ***************************/ +//// 32 +NPY_FINLINE npyv_u32 npyv_loadn_u32(const npy_uint32 *ptr, npy_intp stride) +{ + return npyv_set_u32( + ptr[stride * 0], ptr[stride * 1], + ptr[stride * 2], ptr[stride * 3] + ); +} +NPY_FINLINE npyv_s32 npyv_loadn_s32(const npy_int32 *ptr, npy_intp stride) +{ return (npyv_s32)npyv_loadn_u32((const npy_uint32*)ptr, stride); } +#if NPY_SIMD_F32 +NPY_FINLINE npyv_f32 npyv_loadn_f32(const float *ptr, npy_intp stride) +{ return (npyv_f32)npyv_loadn_u32((const npy_uint32*)ptr, stride); } +#endif +//// 64 +NPY_FINLINE npyv_u64 npyv_loadn_u64(const npy_uint64 *ptr, npy_intp stride) +{ return npyv_set_u64(ptr[0], ptr[stride]); } +NPY_FINLINE npyv_s64 npyv_loadn_s64(const npy_int64 *ptr, npy_intp stride) +{ return npyv_set_s64(ptr[0], ptr[stride]); } +NPY_FINLINE npyv_f64 npyv_loadn_f64(const double *ptr, npy_intp stride) +{ return npyv_set_f64(ptr[0], ptr[stride]); } + +//// 64-bit load over 32-bit stride +NPY_FINLINE npyv_u32 npyv_loadn2_u32(const npy_uint32 *ptr, npy_intp stride) +{ return (npyv_u32)npyv_set_u64(*(npy_uint64*)ptr, *(npy_uint64*)(ptr + stride)); } +NPY_FINLINE npyv_s32 npyv_loadn2_s32(const npy_int32 *ptr, npy_intp stride) +{ return (npyv_s32)npyv_set_u64(*(npy_uint64*)ptr, *(npy_uint64*)(ptr + stride)); } +#if NPY_SIMD_F32 +NPY_FINLINE npyv_f32 npyv_loadn2_f32(const float *ptr, npy_intp stride) +{ return (npyv_f32)npyv_set_u64(*(npy_uint64*)ptr, *(npy_uint64*)(ptr + stride)); } +#endif +//// 128-bit load over 64-bit stride +NPY_FINLINE npyv_u64 npyv_loadn2_u64(const npy_uint64 *ptr, npy_intp stride) +{ (void)stride; return npyv_load_u64(ptr); } +NPY_FINLINE npyv_s64 npyv_loadn2_s64(const npy_int64 *ptr, npy_intp stride) +{ (void)stride; return npyv_load_s64(ptr); } +NPY_FINLINE npyv_f64 npyv_loadn2_f64(const double *ptr, npy_intp stride) +{ (void)stride; return npyv_load_f64(ptr); } + +/*************************** + * Non-contiguous Store + ***************************/ +//// 32 +NPY_FINLINE void npyv_storen_u32(npy_uint32 *ptr, npy_intp stride, npyv_u32 a) +{ + ptr[stride * 0] = vec_extract(a, 0); + ptr[stride * 1] = vec_extract(a, 1); + ptr[stride * 2] = vec_extract(a, 2); + ptr[stride * 3] = vec_extract(a, 3); +} +NPY_FINLINE void npyv_storen_s32(npy_int32 *ptr, npy_intp stride, npyv_s32 a) +{ npyv_storen_u32((npy_uint32*)ptr, stride, (npyv_u32)a); } +#if NPY_SIMD_F32 +NPY_FINLINE void npyv_storen_f32(float *ptr, npy_intp stride, npyv_f32 a) +{ npyv_storen_u32((npy_uint32*)ptr, stride, (npyv_u32)a); } +#endif +//// 64 +NPY_FINLINE void npyv_storen_u64(npy_uint64 *ptr, npy_intp stride, npyv_u64 a) +{ + ptr[stride * 0] = vec_extract(a, 0); + ptr[stride * 1] = vec_extract(a, 1); +} +NPY_FINLINE void npyv_storen_s64(npy_int64 *ptr, npy_intp stride, npyv_s64 a) +{ npyv_storen_u64((npy_uint64*)ptr, stride, (npyv_u64)a); } +NPY_FINLINE void 
npyv_storen_f64(double *ptr, npy_intp stride, npyv_f64 a) +{ npyv_storen_u64((npy_uint64*)ptr, stride, (npyv_u64)a); } + +//// 64-bit store over 32-bit stride +NPY_FINLINE void npyv_storen2_u32(npy_uint32 *ptr, npy_intp stride, npyv_u32 a) +{ + *(npy_uint64*)ptr = vec_extract((npyv_u64)a, 0); + *(npy_uint64*)(ptr + stride) = vec_extract((npyv_u64)a, 1); +} +NPY_FINLINE void npyv_storen2_s32(npy_int32 *ptr, npy_intp stride, npyv_s32 a) +{ npyv_storen2_u32((npy_uint32*)ptr, stride, (npyv_u32)a); } +#if NPY_SIMD_F32 +NPY_FINLINE void npyv_storen2_f32(float *ptr, npy_intp stride, npyv_f32 a) +{ npyv_storen2_u32((npy_uint32*)ptr, stride, (npyv_u32)a); } +#endif +//// 128-bit store over 64-bit stride +NPY_FINLINE void npyv_storen2_u64(npy_uint64 *ptr, npy_intp stride, npyv_u64 a) +{ (void)stride; npyv_store_u64(ptr, a); } +NPY_FINLINE void npyv_storen2_s64(npy_int64 *ptr, npy_intp stride, npyv_s64 a) +{ (void)stride; npyv_store_s64(ptr, a); } +NPY_FINLINE void npyv_storen2_f64(double *ptr, npy_intp stride, npyv_f64 a) +{ (void)stride; npyv_store_f64(ptr, a); } + +/********************************* + * Partial Load + *********************************/ +//// 32 +NPY_FINLINE npyv_s32 npyv_load_till_s32(const npy_int32 *ptr, npy_uintp nlane, npy_int32 fill) +{ + assert(nlane > 0); + npyv_s32 vfill = npyv_setall_s32(fill); +#ifdef NPY_HAVE_VX + const unsigned blane = (nlane > 4) ? 4 : nlane; + const npyv_u32 steps = npyv_set_u32(0, 1, 2, 3); + const npyv_u32 vlane = npyv_setall_u32(blane); + const npyv_b32 mask = vec_cmpgt(vlane, steps); + npyv_s32 a = vec_load_len(ptr, blane*4-1); + a = vec_sel(vfill, a, mask); +#else + npyv_s32 a; + switch(nlane) { + case 1: + a = vec_insert(ptr[0], vfill, 0); + break; + case 2: + a = (npyv_s32)vec_insert( + *npyv__ptr2u64(ptr), (npyv_u64)vfill, 0 + ); + break; + case 3: + vfill = vec_insert(ptr[2], vfill, 2); + a = (npyv_s32)vec_insert( + *npyv__ptr2u64(ptr), (npyv_u64)vfill, 0 + ); + break; + default: + return npyv_load_s32(ptr); + } +#endif +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile npyv_s32 workaround = a; + a = vec_or(workaround, a); +#endif + return a; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s32 npyv_load_tillz_s32(const npy_int32 *ptr, npy_uintp nlane) +{ +#ifdef NPY_HAVE_VX + unsigned blane = (nlane > 4) ? 4 : nlane; + return vec_load_len(ptr, blane*4-1); +#else + return npyv_load_till_s32(ptr, nlane, 0); +#endif +} +//// 64 +NPY_FINLINE npyv_s64 npyv_load_till_s64(const npy_int64 *ptr, npy_uintp nlane, npy_int64 fill) +{ + assert(nlane > 0); + if (nlane == 1) { + npyv_s64 r = npyv_set_s64(ptr[0], fill); + #if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile npyv_s64 workaround = r; + r = vec_or(workaround, r); + #endif + return r; + } + return npyv_load_s64(ptr); +} +// fill zero to rest lanes +NPY_FINLINE npyv_s64 npyv_load_tillz_s64(const npy_int64 *ptr, npy_uintp nlane) +{ +#ifdef NPY_HAVE_VX + unsigned blane = (nlane > 2) ? 
2 : nlane; + return vec_load_len((const signed long long*)ptr, blane*8-1); +#else + return npyv_load_till_s64(ptr, nlane, 0); +#endif +} +//// 64-bit nlane +NPY_FINLINE npyv_s32 npyv_load2_till_s32(const npy_int32 *ptr, npy_uintp nlane, + npy_int32 fill_lo, npy_int32 fill_hi) +{ + assert(nlane > 0); + if (nlane == 1) { + npyv_s32 r = npyv_set_s32(ptr[0], ptr[1], fill_lo, fill_hi); + #if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile npyv_s32 workaround = r; + r = vec_or(workaround, r); + #endif + return r; + } + return npyv_load_s32(ptr); +} +// fill zero to rest lanes +NPY_FINLINE npyv_s32 npyv_load2_tillz_s32(const npy_int32 *ptr, npy_uintp nlane) +{ return (npyv_s32)npyv_load_tillz_s64((const npy_int64*)ptr, nlane); } + +//// 128-bit nlane +NPY_FINLINE npyv_s64 npyv_load2_till_s64(const npy_int64 *ptr, npy_uintp nlane, + npy_int64 fill_lo, npy_int64 fill_hi) +{ (void)nlane; (void)fill_lo; (void)fill_hi; return npyv_load_s64(ptr); } + +NPY_FINLINE npyv_s64 npyv_load2_tillz_s64(const npy_int64 *ptr, npy_uintp nlane) +{ (void)nlane; return npyv_load_s64(ptr); } +/********************************* + * Non-contiguous partial load + *********************************/ +//// 32 +NPY_FINLINE npyv_s32 +npyv_loadn_till_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane, npy_int32 fill) +{ + assert(nlane > 0); + npyv_s32 vfill = npyv_setall_s32(fill); + switch(nlane) { + case 3: + vfill = vec_insert(ptr[stride*2], vfill, 2); + case 2: + vfill = vec_insert(ptr[stride], vfill, 1); + case 1: + vfill = vec_insert(*ptr, vfill, 0); + break; + default: + return npyv_loadn_s32(ptr, stride); + } // switch +#if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile npyv_s32 workaround = vfill; + vfill = vec_or(workaround, vfill); +#endif + return vfill; +} +// fill zero to rest lanes +NPY_FINLINE npyv_s32 +npyv_loadn_tillz_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane) +{ return npyv_loadn_till_s32(ptr, stride, nlane, 0); } +//// 64 +NPY_FINLINE npyv_s64 +npyv_loadn_till_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane, npy_int64 fill) +{ + assert(nlane > 0); + if (nlane == 1) { + return npyv_load_till_s64(ptr, nlane, fill); + } + return npyv_loadn_s64(ptr, stride); +} +// fill zero to rest lanes +NPY_FINLINE npyv_s64 npyv_loadn_tillz_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane) +{ return npyv_loadn_till_s64(ptr, stride, nlane, 0); } + +//// 64-bit load over 32-bit stride +NPY_FINLINE npyv_s32 npyv_loadn2_till_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane, + npy_int32 fill_lo, npy_int32 fill_hi) +{ + assert(nlane > 0); + if (nlane == 1) { + npyv_s32 r = npyv_set_s32(ptr[0], ptr[1], fill_lo, fill_hi); + #if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile npyv_s32 workaround = r; + r = vec_or(workaround, r); + #endif + return r; + } + return npyv_loadn2_s32(ptr, stride); +} +NPY_FINLINE npyv_s32 npyv_loadn2_tillz_s32(const npy_int32 *ptr, npy_intp stride, npy_uintp nlane) +{ + assert(nlane > 0); + if (nlane == 1) { + npyv_s32 r = (npyv_s32)npyv_set_s64(*(npy_int64*)ptr, 0); + #if NPY_SIMD_GUARD_PARTIAL_LOAD + volatile npyv_s32 workaround = r; + r = vec_or(workaround, r); + #endif + return r; + } + return npyv_loadn2_s32(ptr, stride); +} + +//// 128-bit load over 64-bit stride +NPY_FINLINE npyv_s64 npyv_loadn2_till_s64(const npy_int64 *ptr, npy_intp stride, npy_uintp nlane, + npy_int64 fill_lo, npy_int64 fill_hi) +{ assert(nlane > 0); (void)stride; (void)nlane; (void)fill_lo; (void)fill_hi; return npyv_load_s64(ptr); } + +NPY_FINLINE npyv_s64 npyv_loadn2_tillz_s64(const npy_int64 *ptr, 
npy_intp stride, npy_uintp nlane) +{ assert(nlane > 0); (void)stride; (void)nlane; return npyv_load_s64(ptr); } + +/********************************* + * Partial store + *********************************/ +//// 32 +NPY_FINLINE void npyv_store_till_s32(npy_int32 *ptr, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); +#ifdef NPY_HAVE_VX + unsigned blane = (nlane > 4) ? 4 : nlane; + vec_store_len(a, ptr, blane*4-1); +#else + switch(nlane) { + case 1: + *ptr = vec_extract(a, 0); + break; + case 2: + npyv_storel_s32(ptr, a); + break; + case 3: + npyv_storel_s32(ptr, a); + ptr[2] = vec_extract(a, 2); + break; + default: + npyv_store_s32(ptr, a); + } +#endif +} +//// 64 +NPY_FINLINE void npyv_store_till_s64(npy_int64 *ptr, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); +#ifdef NPY_HAVE_VX + unsigned blane = (nlane > 2) ? 2 : nlane; + vec_store_len(a, (signed long long*)ptr, blane*8-1); +#else + if (nlane == 1) { + npyv_storel_s64(ptr, a); + return; + } + npyv_store_s64(ptr, a); +#endif +} + +//// 64-bit nlane +NPY_FINLINE void npyv_store2_till_s32(npy_int32 *ptr, npy_uintp nlane, npyv_s32 a) +{ npyv_store_till_s64((npy_int64*)ptr, nlane, (npyv_s64)a); } + +//// 128-bit nlane +NPY_FINLINE void npyv_store2_till_s64(npy_int64 *ptr, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); (void)nlane; + npyv_store_s64(ptr, a); +} + +/********************************* + * Non-contiguous partial store + *********************************/ +//// 32 +NPY_FINLINE void npyv_storen_till_s32(npy_int32 *ptr, npy_intp stride, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); + ptr[stride*0] = vec_extract(a, 0); + switch(nlane) { + case 1: + return; + case 2: + ptr[stride*1] = vec_extract(a, 1); + return; + case 3: + ptr[stride*1] = vec_extract(a, 1); + ptr[stride*2] = vec_extract(a, 2); + return; + default: + ptr[stride*1] = vec_extract(a, 1); + ptr[stride*2] = vec_extract(a, 2); + ptr[stride*3] = vec_extract(a, 3); + } +} +//// 64 +NPY_FINLINE void npyv_storen_till_s64(npy_int64 *ptr, npy_intp stride, npy_uintp nlane, npyv_s64 a) +{ + assert(nlane > 0); + if (nlane == 1) { + npyv_storel_s64(ptr, a); + return; + } + npyv_storen_s64(ptr, stride, a); +} + +//// 64-bit store over 32-bit stride +NPY_FINLINE void npyv_storen2_till_s32(npy_int32 *ptr, npy_intp stride, npy_uintp nlane, npyv_s32 a) +{ + assert(nlane > 0); + npyv_storel_s32(ptr, a); + if (nlane > 1) { + npyv_storeh_s32(ptr + stride, a); + } +} + +//// 128-bit store over 64-bit stride +NPY_FINLINE void npyv_storen2_till_s64(npy_int64 *ptr, npy_intp stride, npy_uintp nlane, npyv_s64 a) +{ assert(nlane > 0); (void)stride; (void)nlane; npyv_store_s64(ptr, a); } + +/***************************************************************** + * Implement partial load/store for u32/f32/u64/f64... 
via casting + *****************************************************************/ +#define NPYV_IMPL_VEC_REST_PARTIAL_TYPES(F_SFX, T_SFX) \ + NPY_FINLINE npyv_##F_SFX npyv_load_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, npyv_lanetype_##F_SFX fill) \ + { \ + union { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + } pun; \ + pun.from_##F_SFX = fill; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane, pun.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, \ + npyv_lanetype_##F_SFX fill) \ + { \ + union { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + } pun; \ + pun.from_##F_SFX = fill; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane, pun.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_load_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane \ + )); \ + } \ + NPY_FINLINE void npyv_store_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_store_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } \ + NPY_FINLINE void npyv_storen_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_storen_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, stride, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } + +NPYV_IMPL_VEC_REST_PARTIAL_TYPES(u32, s32) +#if NPY_SIMD_F32 +NPYV_IMPL_VEC_REST_PARTIAL_TYPES(f32, s32) +#endif +NPYV_IMPL_VEC_REST_PARTIAL_TYPES(u64, s64) +NPYV_IMPL_VEC_REST_PARTIAL_TYPES(f64, s64) + +// 128-bit/64-bit stride +#define NPYV_IMPL_VEC_REST_PARTIAL_TYPES_PAIR(F_SFX, T_SFX) \ + NPY_FINLINE npyv_##F_SFX npyv_load2_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, \ + npyv_lanetype_##F_SFX fill_lo, npyv_lanetype_##F_SFX fill_hi) \ + { \ + union pun { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + }; \ + union pun pun_lo; \ + union pun pun_hi; \ + pun_lo.from_##F_SFX = fill_lo; \ + pun_hi.from_##F_SFX = fill_hi; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load2_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane, pun_lo.to_##T_SFX, pun_hi.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn2_till_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, \ + npyv_lanetype_##F_SFX fill_lo, npyv_lanetype_##F_SFX fill_hi) \ + { \ + union pun { \ + npyv_lanetype_##F_SFX from_##F_SFX; \ + npyv_lanetype_##T_SFX to_##T_SFX; \ + }; \ + union pun pun_lo; \ + union pun pun_hi; \ + pun_lo.from_##F_SFX = fill_lo; \ + pun_hi.from_##F_SFX = fill_hi; \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn2_till_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane, pun_lo.to_##T_SFX, \ + pun_hi.to_##T_SFX \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_load2_tillz_##F_SFX \ + (const 
npyv_lanetype_##F_SFX *ptr, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_load2_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, nlane \ + )); \ + } \ + NPY_FINLINE npyv_##F_SFX npyv_loadn2_tillz_##F_SFX \ + (const npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane) \ + { \ + return npyv_reinterpret_##F_SFX##_##T_SFX(npyv_loadn2_tillz_##T_SFX( \ + (const npyv_lanetype_##T_SFX *)ptr, stride, nlane \ + )); \ + } \ + NPY_FINLINE void npyv_store2_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_store2_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } \ + NPY_FINLINE void npyv_storen2_till_##F_SFX \ + (npyv_lanetype_##F_SFX *ptr, npy_intp stride, npy_uintp nlane, npyv_##F_SFX a) \ + { \ + npyv_storen2_till_##T_SFX( \ + (npyv_lanetype_##T_SFX *)ptr, stride, nlane, \ + npyv_reinterpret_##T_SFX##_##F_SFX(a) \ + ); \ + } + +NPYV_IMPL_VEC_REST_PARTIAL_TYPES_PAIR(u32, s32) +#if NPY_SIMD_F32 +NPYV_IMPL_VEC_REST_PARTIAL_TYPES_PAIR(f32, s32) +#endif +NPYV_IMPL_VEC_REST_PARTIAL_TYPES_PAIR(u64, s64) +NPYV_IMPL_VEC_REST_PARTIAL_TYPES_PAIR(f64, s64) + +/************************************************************ + * de-interlave load / interleave contiguous store + ************************************************************/ +// two channels +#define NPYV_IMPL_VEC_MEM_INTERLEAVE(SFX) \ + NPY_FINLINE npyv_##SFX##x2 npyv_zip_##SFX(npyv_##SFX, npyv_##SFX); \ + NPY_FINLINE npyv_##SFX##x2 npyv_unzip_##SFX(npyv_##SFX, npyv_##SFX); \ + NPY_FINLINE npyv_##SFX##x2 npyv_load_##SFX##x2( \ + const npyv_lanetype_##SFX *ptr \ + ) { \ + return npyv_unzip_##SFX( \ + npyv_load_##SFX(ptr), npyv_load_##SFX(ptr+npyv_nlanes_##SFX) \ + ); \ + } \ + NPY_FINLINE void npyv_store_##SFX##x2( \ + npyv_lanetype_##SFX *ptr, npyv_##SFX##x2 v \ + ) { \ + npyv_##SFX##x2 zip = npyv_zip_##SFX(v.val[0], v.val[1]); \ + npyv_store_##SFX(ptr, zip.val[0]); \ + npyv_store_##SFX(ptr + npyv_nlanes_##SFX, zip.val[1]); \ + } + +NPYV_IMPL_VEC_MEM_INTERLEAVE(u8) +NPYV_IMPL_VEC_MEM_INTERLEAVE(s8) +NPYV_IMPL_VEC_MEM_INTERLEAVE(u16) +NPYV_IMPL_VEC_MEM_INTERLEAVE(s16) +NPYV_IMPL_VEC_MEM_INTERLEAVE(u32) +NPYV_IMPL_VEC_MEM_INTERLEAVE(s32) +NPYV_IMPL_VEC_MEM_INTERLEAVE(u64) +NPYV_IMPL_VEC_MEM_INTERLEAVE(s64) +#if NPY_SIMD_F32 +NPYV_IMPL_VEC_MEM_INTERLEAVE(f32) +#endif +NPYV_IMPL_VEC_MEM_INTERLEAVE(f64) + +/********************************* + * Lookup table + *********************************/ +// uses vector as indexes into a table +// that contains 32 elements of float32. +NPY_FINLINE npyv_u32 npyv_lut32_u32(const npy_uint32 *table, npyv_u32 idx) +{ + const unsigned i0 = vec_extract(idx, 0); + const unsigned i1 = vec_extract(idx, 1); + const unsigned i2 = vec_extract(idx, 2); + const unsigned i3 = vec_extract(idx, 3); + npyv_u32 r = vec_promote(table[i0], 0); + r = vec_insert(table[i1], r, 1); + r = vec_insert(table[i2], r, 2); + r = vec_insert(table[i3], r, 3); + return r; +} +NPY_FINLINE npyv_s32 npyv_lut32_s32(const npy_int32 *table, npyv_u32 idx) +{ return (npyv_s32)npyv_lut32_u32((const npy_uint32*)table, idx); } +#if NPY_SIMD_F32 + NPY_FINLINE npyv_f32 npyv_lut32_f32(const float *table, npyv_u32 idx) + { return (npyv_f32)npyv_lut32_u32((const npy_uint32*)table, idx); } +#endif +// uses vector as indexes into a table +// that contains 16 elements of float64. 
+NPY_FINLINE npyv_f64 npyv_lut16_f64(const double *table, npyv_u64 idx) +{ +#ifdef NPY_HAVE_VX + const unsigned i0 = vec_extract((npyv_u32)idx, 1); + const unsigned i1 = vec_extract((npyv_u32)idx, 3); +#else + const unsigned i0 = vec_extract((npyv_u32)idx, 0); + const unsigned i1 = vec_extract((npyv_u32)idx, 2); +#endif + npyv_f64 r = vec_promote(table[i0], 0); + r = vec_insert(table[i1], r, 1); + return r; +} +NPY_FINLINE npyv_u64 npyv_lut16_u64(const npy_uint64 *table, npyv_u64 idx) +{ return npyv_reinterpret_u64_f64(npyv_lut16_f64((const double*)table, idx)); } +NPY_FINLINE npyv_s64 npyv_lut16_s64(const npy_int64 *table, npyv_u64 idx) +{ return npyv_reinterpret_s64_f64(npyv_lut16_f64((const double*)table, idx)); } + +#endif // _NPY_SIMD_VEC_MEMORY_H diff --git a/mkl_umath/src/npyv/vec/misc.h b/mkl_umath/src/npyv/vec/misc.h new file mode 100644 index 00000000..79c194d9 --- /dev/null +++ b/mkl_umath/src/npyv/vec/misc.h @@ -0,0 +1,233 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_VEC_MISC_H +#define _NPY_SIMD_VEC_MISC_H + +// vector with zero lanes +#define npyv_zero_u8() ((npyv_u8) npyv_setall_s32(0)) +#define npyv_zero_s8() ((npyv_s8) npyv_setall_s32(0)) +#define npyv_zero_u16() ((npyv_u16) npyv_setall_s32(0)) +#define npyv_zero_s16() ((npyv_s16) npyv_setall_s32(0)) +#define npyv_zero_u32() npyv_setall_u32(0) +#define npyv_zero_s32() npyv_setall_s32(0) +#define npyv_zero_u64() ((npyv_u64) npyv_setall_s32(0)) +#define npyv_zero_s64() ((npyv_s64) npyv_setall_s32(0)) +#if NPY_SIMD_F32 + #define npyv_zero_f32() npyv_setall_f32(0.0f) +#endif +#define npyv_zero_f64() npyv_setall_f64(0.0) + +// vector with a specific value set to all lanes +// the safest way to generate vsplti* and vsplt* instructions +#define NPYV_IMPL_VEC_SPLTB(T_VEC, V) ((T_VEC){V, V, V, V, V, V, V, V, V, V, V, V, V, V, V, V}) +#define NPYV_IMPL_VEC_SPLTH(T_VEC, V) ((T_VEC){V, V, V, V, V, V, V, V}) +#define NPYV_IMPL_VEC_SPLTW(T_VEC, V) ((T_VEC){V, V, V, V}) +#define NPYV_IMPL_VEC_SPLTD(T_VEC, V) ((T_VEC){V, V}) + +#define npyv_setall_u8(VAL) NPYV_IMPL_VEC_SPLTB(npyv_u8, (unsigned char)(VAL)) +#define npyv_setall_s8(VAL) NPYV_IMPL_VEC_SPLTB(npyv_s8, (signed char)(VAL)) +#define npyv_setall_u16(VAL) NPYV_IMPL_VEC_SPLTH(npyv_u16, (unsigned short)(VAL)) +#define npyv_setall_s16(VAL) NPYV_IMPL_VEC_SPLTH(npyv_s16, (short)(VAL)) +#define npyv_setall_u32(VAL) NPYV_IMPL_VEC_SPLTW(npyv_u32, (unsigned int)(VAL)) +#define npyv_setall_s32(VAL) NPYV_IMPL_VEC_SPLTW(npyv_s32, (int)(VAL)) +#if NPY_SIMD_F32 + #define npyv_setall_f32(VAL) NPYV_IMPL_VEC_SPLTW(npyv_f32, (VAL)) +#endif +#define npyv_setall_u64(VAL) NPYV_IMPL_VEC_SPLTD(npyv_u64, (npy_uint64)(VAL)) +#define npyv_setall_s64(VAL) NPYV_IMPL_VEC_SPLTD(npyv_s64, (npy_int64)(VAL)) +#define npyv_setall_f64(VAL) NPYV_IMPL_VEC_SPLTD(npyv_f64, VAL) + +// vector with specific values set to each lane and +// set a specific value to all remained lanes +#define npyv_setf_u8(FILL, ...) ((npyv_u8){NPYV__SET_FILL_16(unsigned char, FILL, __VA_ARGS__)}) +#define npyv_setf_s8(FILL, ...) ((npyv_s8){NPYV__SET_FILL_16(signed char, FILL, __VA_ARGS__)}) +#define npyv_setf_u16(FILL, ...) ((npyv_u16){NPYV__SET_FILL_8(unsigned short, FILL, __VA_ARGS__)}) +#define npyv_setf_s16(FILL, ...) ((npyv_s16){NPYV__SET_FILL_8(short, FILL, __VA_ARGS__)}) +#define npyv_setf_u32(FILL, ...) ((npyv_u32){NPYV__SET_FILL_4(unsigned int, FILL, __VA_ARGS__)}) +#define npyv_setf_s32(FILL, ...) ((npyv_s32){NPYV__SET_FILL_4(int, FILL, __VA_ARGS__)}) +#define npyv_setf_u64(FILL, ...) 
((npyv_u64){NPYV__SET_FILL_2(npy_uint64, FILL, __VA_ARGS__)}) +#define npyv_setf_s64(FILL, ...) ((npyv_s64){NPYV__SET_FILL_2(npy_int64, FILL, __VA_ARGS__)}) +#if NPY_SIMD_F32 + #define npyv_setf_f32(FILL, ...) ((npyv_f32){NPYV__SET_FILL_4(float, FILL, __VA_ARGS__)}) +#endif +#define npyv_setf_f64(FILL, ...) ((npyv_f64){NPYV__SET_FILL_2(double, FILL, __VA_ARGS__)}) + +// vector with specific values set to each lane and +// set zero to all remained lanes +#define npyv_set_u8(...) npyv_setf_u8(0, __VA_ARGS__) +#define npyv_set_s8(...) npyv_setf_s8(0, __VA_ARGS__) +#define npyv_set_u16(...) npyv_setf_u16(0, __VA_ARGS__) +#define npyv_set_s16(...) npyv_setf_s16(0, __VA_ARGS__) +#define npyv_set_u32(...) npyv_setf_u32(0, __VA_ARGS__) +#define npyv_set_s32(...) npyv_setf_s32(0, __VA_ARGS__) +#define npyv_set_u64(...) npyv_setf_u64(0, __VA_ARGS__) +#define npyv_set_s64(...) npyv_setf_s64(0, __VA_ARGS__) +#if NPY_SIMD_F32 + #define npyv_set_f32(...) npyv_setf_f32(0, __VA_ARGS__) +#endif +#define npyv_set_f64(...) npyv_setf_f64(0, __VA_ARGS__) + +// Per lane select +#define npyv_select_u8(MASK, A, B) vec_sel(B, A, MASK) +#define npyv_select_s8 npyv_select_u8 +#define npyv_select_u16 npyv_select_u8 +#define npyv_select_s16 npyv_select_u8 +#define npyv_select_u32 npyv_select_u8 +#define npyv_select_s32 npyv_select_u8 +#define npyv_select_u64 npyv_select_u8 +#define npyv_select_s64 npyv_select_u8 +#if NPY_SIMD_F32 + #define npyv_select_f32 npyv_select_u8 +#endif +#define npyv_select_f64 npyv_select_u8 + +// extract the first vector's lane +#define npyv_extract0_u8(A) ((npy_uint8)vec_extract(A, 0)) +#define npyv_extract0_s8(A) ((npy_int8)vec_extract(A, 0)) +#define npyv_extract0_u16(A) ((npy_uint16)vec_extract(A, 0)) +#define npyv_extract0_s16(A) ((npy_int16)vec_extract(A, 0)) +#define npyv_extract0_u32(A) ((npy_uint32)vec_extract(A, 0)) +#define npyv_extract0_s32(A) ((npy_int32)vec_extract(A, 0)) +#define npyv_extract0_u64(A) ((npy_uint64)vec_extract(A, 0)) +#define npyv_extract0_s64(A) ((npy_int64)vec_extract(A, 0)) +#if NPY_SIMD_F32 + #define npyv_extract0_f32(A) vec_extract(A, 0) +#endif +#define npyv_extract0_f64(A) vec_extract(A, 0) + +// Reinterpret +#define npyv_reinterpret_u8_u8(X) X +#define npyv_reinterpret_u8_s8(X) ((npyv_u8)X) +#define npyv_reinterpret_u8_u16 npyv_reinterpret_u8_s8 +#define npyv_reinterpret_u8_s16 npyv_reinterpret_u8_s8 +#define npyv_reinterpret_u8_u32 npyv_reinterpret_u8_s8 +#define npyv_reinterpret_u8_s32 npyv_reinterpret_u8_s8 +#define npyv_reinterpret_u8_u64 npyv_reinterpret_u8_s8 +#define npyv_reinterpret_u8_s64 npyv_reinterpret_u8_s8 +#if NPY_SIMD_F32 + #define npyv_reinterpret_u8_f32 npyv_reinterpret_u8_s8 +#endif +#define npyv_reinterpret_u8_f64 npyv_reinterpret_u8_s8 + +#define npyv_reinterpret_s8_s8(X) X +#define npyv_reinterpret_s8_u8(X) ((npyv_s8)X) +#define npyv_reinterpret_s8_u16 npyv_reinterpret_s8_u8 +#define npyv_reinterpret_s8_s16 npyv_reinterpret_s8_u8 +#define npyv_reinterpret_s8_u32 npyv_reinterpret_s8_u8 +#define npyv_reinterpret_s8_s32 npyv_reinterpret_s8_u8 +#define npyv_reinterpret_s8_u64 npyv_reinterpret_s8_u8 +#define npyv_reinterpret_s8_s64 npyv_reinterpret_s8_u8 +#if NPY_SIMD_F32 + #define npyv_reinterpret_s8_f32 npyv_reinterpret_s8_u8 +#endif +#define npyv_reinterpret_s8_f64 npyv_reinterpret_s8_u8 + +#define npyv_reinterpret_u16_u16(X) X +#define npyv_reinterpret_u16_u8(X) ((npyv_u16)X) +#define npyv_reinterpret_u16_s8 npyv_reinterpret_u16_u8 +#define npyv_reinterpret_u16_s16 npyv_reinterpret_u16_u8 +#define npyv_reinterpret_u16_u32 
npyv_reinterpret_u16_u8 +#define npyv_reinterpret_u16_s32 npyv_reinterpret_u16_u8 +#define npyv_reinterpret_u16_u64 npyv_reinterpret_u16_u8 +#define npyv_reinterpret_u16_s64 npyv_reinterpret_u16_u8 +#if NPY_SIMD_F32 + #define npyv_reinterpret_u16_f32 npyv_reinterpret_u16_u8 +#endif +#define npyv_reinterpret_u16_f64 npyv_reinterpret_u16_u8 + +#define npyv_reinterpret_s16_s16(X) X +#define npyv_reinterpret_s16_u8(X) ((npyv_s16)X) +#define npyv_reinterpret_s16_s8 npyv_reinterpret_s16_u8 +#define npyv_reinterpret_s16_u16 npyv_reinterpret_s16_u8 +#define npyv_reinterpret_s16_u32 npyv_reinterpret_s16_u8 +#define npyv_reinterpret_s16_s32 npyv_reinterpret_s16_u8 +#define npyv_reinterpret_s16_u64 npyv_reinterpret_s16_u8 +#define npyv_reinterpret_s16_s64 npyv_reinterpret_s16_u8 +#if NPY_SIMD_F32 + #define npyv_reinterpret_s16_f32 npyv_reinterpret_s16_u8 +#endif +#define npyv_reinterpret_s16_f64 npyv_reinterpret_s16_u8 + +#define npyv_reinterpret_u32_u32(X) X +#define npyv_reinterpret_u32_u8(X) ((npyv_u32)X) +#define npyv_reinterpret_u32_s8 npyv_reinterpret_u32_u8 +#define npyv_reinterpret_u32_u16 npyv_reinterpret_u32_u8 +#define npyv_reinterpret_u32_s16 npyv_reinterpret_u32_u8 +#define npyv_reinterpret_u32_s32 npyv_reinterpret_u32_u8 +#define npyv_reinterpret_u32_u64 npyv_reinterpret_u32_u8 +#define npyv_reinterpret_u32_s64 npyv_reinterpret_u32_u8 +#if NPY_SIMD_F32 + #define npyv_reinterpret_u32_f32 npyv_reinterpret_u32_u8 +#endif +#define npyv_reinterpret_u32_f64 npyv_reinterpret_u32_u8 + +#define npyv_reinterpret_s32_s32(X) X +#define npyv_reinterpret_s32_u8(X) ((npyv_s32)X) +#define npyv_reinterpret_s32_s8 npyv_reinterpret_s32_u8 +#define npyv_reinterpret_s32_u16 npyv_reinterpret_s32_u8 +#define npyv_reinterpret_s32_s16 npyv_reinterpret_s32_u8 +#define npyv_reinterpret_s32_u32 npyv_reinterpret_s32_u8 +#define npyv_reinterpret_s32_u64 npyv_reinterpret_s32_u8 +#define npyv_reinterpret_s32_s64 npyv_reinterpret_s32_u8 +#if NPY_SIMD_F32 + #define npyv_reinterpret_s32_f32 npyv_reinterpret_s32_u8 +#endif +#define npyv_reinterpret_s32_f64 npyv_reinterpret_s32_u8 + +#define npyv_reinterpret_u64_u64(X) X +#define npyv_reinterpret_u64_u8(X) ((npyv_u64)X) +#define npyv_reinterpret_u64_s8 npyv_reinterpret_u64_u8 +#define npyv_reinterpret_u64_u16 npyv_reinterpret_u64_u8 +#define npyv_reinterpret_u64_s16 npyv_reinterpret_u64_u8 +#define npyv_reinterpret_u64_u32 npyv_reinterpret_u64_u8 +#define npyv_reinterpret_u64_s32 npyv_reinterpret_u64_u8 +#define npyv_reinterpret_u64_s64 npyv_reinterpret_u64_u8 +#if NPY_SIMD_F32 + #define npyv_reinterpret_u64_f32 npyv_reinterpret_u64_u8 +#endif +#define npyv_reinterpret_u64_f64 npyv_reinterpret_u64_u8 + +#define npyv_reinterpret_s64_s64(X) X +#define npyv_reinterpret_s64_u8(X) ((npyv_s64)X) +#define npyv_reinterpret_s64_s8 npyv_reinterpret_s64_u8 +#define npyv_reinterpret_s64_u16 npyv_reinterpret_s64_u8 +#define npyv_reinterpret_s64_s16 npyv_reinterpret_s64_u8 +#define npyv_reinterpret_s64_u32 npyv_reinterpret_s64_u8 +#define npyv_reinterpret_s64_s32 npyv_reinterpret_s64_u8 +#define npyv_reinterpret_s64_u64 npyv_reinterpret_s64_u8 +#if NPY_SIMD_F32 + #define npyv_reinterpret_s64_f32 npyv_reinterpret_s64_u8 +#endif +#define npyv_reinterpret_s64_f64 npyv_reinterpret_s64_u8 + +#if NPY_SIMD_F32 + #define npyv_reinterpret_f32_f32(X) X + #define npyv_reinterpret_f32_u8(X) ((npyv_f32)X) + #define npyv_reinterpret_f32_s8 npyv_reinterpret_f32_u8 + #define npyv_reinterpret_f32_u16 npyv_reinterpret_f32_u8 + #define npyv_reinterpret_f32_s16 npyv_reinterpret_f32_u8 + #define 
npyv_reinterpret_f32_u32 npyv_reinterpret_f32_u8 + #define npyv_reinterpret_f32_s32 npyv_reinterpret_f32_u8 + #define npyv_reinterpret_f32_u64 npyv_reinterpret_f32_u8 + #define npyv_reinterpret_f32_s64 npyv_reinterpret_f32_u8 + #define npyv_reinterpret_f32_f64 npyv_reinterpret_f32_u8 +#endif + +#define npyv_reinterpret_f64_f64(X) X +#define npyv_reinterpret_f64_u8(X) ((npyv_f64)X) +#define npyv_reinterpret_f64_s8 npyv_reinterpret_f64_u8 +#define npyv_reinterpret_f64_u16 npyv_reinterpret_f64_u8 +#define npyv_reinterpret_f64_s16 npyv_reinterpret_f64_u8 +#define npyv_reinterpret_f64_u32 npyv_reinterpret_f64_u8 +#define npyv_reinterpret_f64_s32 npyv_reinterpret_f64_u8 +#define npyv_reinterpret_f64_u64 npyv_reinterpret_f64_u8 +#define npyv_reinterpret_f64_s64 npyv_reinterpret_f64_u8 +#if NPY_SIMD_F32 + #define npyv_reinterpret_f64_f32 npyv_reinterpret_f64_u8 +#endif +// Only required by AVX2/AVX512 +#define npyv_cleanup() ((void)0) + +#endif // _NPY_SIMD_VEC_MISC_H diff --git a/mkl_umath/src/npyv/vec/operators.h b/mkl_umath/src/npyv/vec/operators.h new file mode 100644 index 00000000..50dac20f --- /dev/null +++ b/mkl_umath/src/npyv/vec/operators.h @@ -0,0 +1,303 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_VEC_OPERATORS_H +#define _NPY_SIMD_VEC_OPERATORS_H + +/*************************** + * Shifting + ***************************/ + +// Left +#define npyv_shl_u16(A, C) vec_sl(A, npyv_setall_u16(C)) +#define npyv_shl_s16(A, C) vec_sl_s16(A, npyv_setall_u16(C)) +#define npyv_shl_u32(A, C) vec_sl(A, npyv_setall_u32(C)) +#define npyv_shl_s32(A, C) vec_sl_s32(A, npyv_setall_u32(C)) +#define npyv_shl_u64(A, C) vec_sl(A, npyv_setall_u64(C)) +#define npyv_shl_s64(A, C) vec_sl_s64(A, npyv_setall_u64(C)) + +// Left by an immediate constant +#define npyv_shli_u16 npyv_shl_u16 +#define npyv_shli_s16 npyv_shl_s16 +#define npyv_shli_u32 npyv_shl_u32 +#define npyv_shli_s32 npyv_shl_s32 +#define npyv_shli_u64 npyv_shl_u64 +#define npyv_shli_s64 npyv_shl_s64 + +// Right +#define npyv_shr_u16(A, C) vec_sr(A, npyv_setall_u16(C)) +#define npyv_shr_s16(A, C) vec_sra_s16(A, npyv_setall_u16(C)) +#define npyv_shr_u32(A, C) vec_sr(A, npyv_setall_u32(C)) +#define npyv_shr_s32(A, C) vec_sra_s32(A, npyv_setall_u32(C)) +#define npyv_shr_u64(A, C) vec_sr(A, npyv_setall_u64(C)) +#define npyv_shr_s64(A, C) vec_sra_s64(A, npyv_setall_u64(C)) + +// Right by an immediate constant +#define npyv_shri_u16 npyv_shr_u16 +#define npyv_shri_s16 npyv_shr_s16 +#define npyv_shri_u32 npyv_shr_u32 +#define npyv_shri_s32 npyv_shr_s32 +#define npyv_shri_u64 npyv_shr_u64 +#define npyv_shri_s64 npyv_shr_s64 + +/*************************** + * Logical + ***************************/ +#define NPYV_IMPL_VEC_BIN_CAST(INTRIN, SFX, CAST) \ + NPY_FINLINE npyv_##SFX npyv_##INTRIN##_##SFX(npyv_##SFX a, npyv_##SFX b) \ + { return (npyv_##SFX)vec_##INTRIN((CAST)a, (CAST)b); } + +// Up to GCC 6 logical intrinsics don't support bool long long +#if defined(__GNUC__) && __GNUC__ <= 6 + #define NPYV_IMPL_VEC_BIN_B64(INTRIN) NPYV_IMPL_VEC_BIN_CAST(INTRIN, b64, npyv_u64) +#else + #define NPYV_IMPL_VEC_BIN_B64(INTRIN) NPYV_IMPL_VEC_BIN_CAST(INTRIN, b64, npyv_b64) +#endif +// AND +#define npyv_and_u8 vec_and +#define npyv_and_s8 vec_and +#define npyv_and_u16 vec_and +#define npyv_and_s16 vec_and +#define npyv_and_u32 vec_and +#define npyv_and_s32 vec_and +#define npyv_and_u64 vec_and +#define npyv_and_s64 vec_and +#if NPY_SIMD_F32 + #define npyv_and_f32 vec_and +#endif +#define npyv_and_f64 vec_and +#define npyv_and_b8 
vec_and +#define npyv_and_b16 vec_and +#define npyv_and_b32 vec_and +NPYV_IMPL_VEC_BIN_B64(and) + +// OR +#define npyv_or_u8 vec_or +#define npyv_or_s8 vec_or +#define npyv_or_u16 vec_or +#define npyv_or_s16 vec_or +#define npyv_or_u32 vec_or +#define npyv_or_s32 vec_or +#define npyv_or_u64 vec_or +#define npyv_or_s64 vec_or +#if NPY_SIMD_F32 + #define npyv_or_f32 vec_or +#endif +#define npyv_or_f64 vec_or +#define npyv_or_b8 vec_or +#define npyv_or_b16 vec_or +#define npyv_or_b32 vec_or +NPYV_IMPL_VEC_BIN_B64(or) + +// XOR +#define npyv_xor_u8 vec_xor +#define npyv_xor_s8 vec_xor +#define npyv_xor_u16 vec_xor +#define npyv_xor_s16 vec_xor +#define npyv_xor_u32 vec_xor +#define npyv_xor_s32 vec_xor +#define npyv_xor_u64 vec_xor +#define npyv_xor_s64 vec_xor +#if NPY_SIMD_F32 + #define npyv_xor_f32 vec_xor +#endif +#define npyv_xor_f64 vec_xor +#define npyv_xor_b8 vec_xor +#define npyv_xor_b16 vec_xor +#define npyv_xor_b32 vec_xor +NPYV_IMPL_VEC_BIN_B64(xor) + +// NOT +// note: we implement npyv_not_b*(boolean types) for internal use*/ +#define NPYV_IMPL_VEC_NOT_INT(VEC_LEN) \ + NPY_FINLINE npyv_u##VEC_LEN npyv_not_u##VEC_LEN(npyv_u##VEC_LEN a) \ + { return vec_nor(a, a); } \ + NPY_FINLINE npyv_s##VEC_LEN npyv_not_s##VEC_LEN(npyv_s##VEC_LEN a) \ + { return vec_nor(a, a); } \ + NPY_FINLINE npyv_b##VEC_LEN npyv_not_b##VEC_LEN(npyv_b##VEC_LEN a) \ + { return vec_nor(a, a); } + +NPYV_IMPL_VEC_NOT_INT(8) +NPYV_IMPL_VEC_NOT_INT(16) +NPYV_IMPL_VEC_NOT_INT(32) + +// on ppc64, up to gcc5 vec_nor doesn't support bool long long +#if defined(NPY_HAVE_VSX) && defined(__GNUC__) && __GNUC__ > 5 + NPYV_IMPL_VEC_NOT_INT(64) +#else + NPY_FINLINE npyv_u64 npyv_not_u64(npyv_u64 a) + { return vec_nor(a, a); } + NPY_FINLINE npyv_s64 npyv_not_s64(npyv_s64 a) + { return vec_nor(a, a); } + NPY_FINLINE npyv_b64 npyv_not_b64(npyv_b64 a) + { return (npyv_b64)vec_nor((npyv_u64)a, (npyv_u64)a); } +#endif + +#if NPY_SIMD_F32 + NPY_FINLINE npyv_f32 npyv_not_f32(npyv_f32 a) + { return vec_nor(a, a); } +#endif +NPY_FINLINE npyv_f64 npyv_not_f64(npyv_f64 a) +{ return vec_nor(a, a); } + +// ANDC, ORC and XNOR +#define npyv_andc_u8 vec_andc +#define npyv_andc_b8 vec_andc +#if defined(NPY_HAVE_VXE) || defined(NPY_HAVE_VSX) + #define npyv_orc_b8 vec_orc + #define npyv_xnor_b8 vec_eqv +#else + #define npyv_orc_b8(A, B) npyv_or_b8(npyv_not_b8(B), A) + #define npyv_xnor_b8(A, B) npyv_not_b8(npyv_xor_b8(B, A)) +#endif + +/*************************** + * Comparison + ***************************/ + +// Int Equal +#define npyv_cmpeq_u8 vec_cmpeq +#define npyv_cmpeq_s8 vec_cmpeq +#define npyv_cmpeq_u16 vec_cmpeq +#define npyv_cmpeq_s16 vec_cmpeq +#define npyv_cmpeq_u32 vec_cmpeq +#define npyv_cmpeq_s32 vec_cmpeq +#define npyv_cmpeq_u64 vec_cmpeq +#define npyv_cmpeq_s64 vec_cmpeq +#if NPY_SIMD_F32 + #define npyv_cmpeq_f32 vec_cmpeq +#endif +#define npyv_cmpeq_f64 vec_cmpeq + +// Int Not Equal +#if defined(NPY_HAVE_VSX3) && (!defined(__GNUC__) || defined(vec_cmpne)) + // vec_cmpne supported by gcc since version 7 + #define npyv_cmpneq_u8 vec_cmpne + #define npyv_cmpneq_s8 vec_cmpne + #define npyv_cmpneq_u16 vec_cmpne + #define npyv_cmpneq_s16 vec_cmpne + #define npyv_cmpneq_u32 vec_cmpne + #define npyv_cmpneq_s32 vec_cmpne + #define npyv_cmpneq_u64 vec_cmpne + #define npyv_cmpneq_s64 vec_cmpne + #define npyv_cmpneq_f32 vec_cmpne + #define npyv_cmpneq_f64 vec_cmpne +#else + #define npyv_cmpneq_u8(A, B) npyv_not_b8(vec_cmpeq(A, B)) + #define npyv_cmpneq_s8(A, B) npyv_not_b8(vec_cmpeq(A, B)) + #define npyv_cmpneq_u16(A, B) 
npyv_not_b16(vec_cmpeq(A, B)) + #define npyv_cmpneq_s16(A, B) npyv_not_b16(vec_cmpeq(A, B)) + #define npyv_cmpneq_u32(A, B) npyv_not_b32(vec_cmpeq(A, B)) + #define npyv_cmpneq_s32(A, B) npyv_not_b32(vec_cmpeq(A, B)) + #define npyv_cmpneq_u64(A, B) npyv_not_b64(vec_cmpeq(A, B)) + #define npyv_cmpneq_s64(A, B) npyv_not_b64(vec_cmpeq(A, B)) + #if NPY_SIMD_F32 + #define npyv_cmpneq_f32(A, B) npyv_not_b32(vec_cmpeq(A, B)) + #endif + #define npyv_cmpneq_f64(A, B) npyv_not_b64(vec_cmpeq(A, B)) +#endif + +// Greater than +#define npyv_cmpgt_u8 vec_cmpgt +#define npyv_cmpgt_s8 vec_cmpgt +#define npyv_cmpgt_u16 vec_cmpgt +#define npyv_cmpgt_s16 vec_cmpgt +#define npyv_cmpgt_u32 vec_cmpgt +#define npyv_cmpgt_s32 vec_cmpgt +#define npyv_cmpgt_u64 vec_cmpgt +#define npyv_cmpgt_s64 vec_cmpgt +#if NPY_SIMD_F32 + #define npyv_cmpgt_f32 vec_cmpgt +#endif +#define npyv_cmpgt_f64 vec_cmpgt + +// Greater than or equal +// On ppc64le, up to gcc5 vec_cmpge only supports single and double precision +#if defined(NPY_HAVE_VX) || (defined(__GNUC__) && __GNUC__ > 5) + #define npyv_cmpge_u8 vec_cmpge + #define npyv_cmpge_s8 vec_cmpge + #define npyv_cmpge_u16 vec_cmpge + #define npyv_cmpge_s16 vec_cmpge + #define npyv_cmpge_u32 vec_cmpge + #define npyv_cmpge_s32 vec_cmpge + #define npyv_cmpge_u64 vec_cmpge + #define npyv_cmpge_s64 vec_cmpge +#else + #define npyv_cmpge_u8(A, B) npyv_not_b8(vec_cmpgt(B, A)) + #define npyv_cmpge_s8(A, B) npyv_not_b8(vec_cmpgt(B, A)) + #define npyv_cmpge_u16(A, B) npyv_not_b16(vec_cmpgt(B, A)) + #define npyv_cmpge_s16(A, B) npyv_not_b16(vec_cmpgt(B, A)) + #define npyv_cmpge_u32(A, B) npyv_not_b32(vec_cmpgt(B, A)) + #define npyv_cmpge_s32(A, B) npyv_not_b32(vec_cmpgt(B, A)) + #define npyv_cmpge_u64(A, B) npyv_not_b64(vec_cmpgt(B, A)) + #define npyv_cmpge_s64(A, B) npyv_not_b64(vec_cmpgt(B, A)) +#endif +#if NPY_SIMD_F32 + #define npyv_cmpge_f32 vec_cmpge +#endif +#define npyv_cmpge_f64 vec_cmpge + +// Less than +#define npyv_cmplt_u8(A, B) npyv_cmpgt_u8(B, A) +#define npyv_cmplt_s8(A, B) npyv_cmpgt_s8(B, A) +#define npyv_cmplt_u16(A, B) npyv_cmpgt_u16(B, A) +#define npyv_cmplt_s16(A, B) npyv_cmpgt_s16(B, A) +#define npyv_cmplt_u32(A, B) npyv_cmpgt_u32(B, A) +#define npyv_cmplt_s32(A, B) npyv_cmpgt_s32(B, A) +#define npyv_cmplt_u64(A, B) npyv_cmpgt_u64(B, A) +#define npyv_cmplt_s64(A, B) npyv_cmpgt_s64(B, A) +#if NPY_SIMD_F32 + #define npyv_cmplt_f32(A, B) npyv_cmpgt_f32(B, A) +#endif +#define npyv_cmplt_f64(A, B) npyv_cmpgt_f64(B, A) + +// Less than or equal +#define npyv_cmple_u8(A, B) npyv_cmpge_u8(B, A) +#define npyv_cmple_s8(A, B) npyv_cmpge_s8(B, A) +#define npyv_cmple_u16(A, B) npyv_cmpge_u16(B, A) +#define npyv_cmple_s16(A, B) npyv_cmpge_s16(B, A) +#define npyv_cmple_u32(A, B) npyv_cmpge_u32(B, A) +#define npyv_cmple_s32(A, B) npyv_cmpge_s32(B, A) +#define npyv_cmple_u64(A, B) npyv_cmpge_u64(B, A) +#define npyv_cmple_s64(A, B) npyv_cmpge_s64(B, A) +#if NPY_SIMD_F32 + #define npyv_cmple_f32(A, B) npyv_cmpge_f32(B, A) +#endif +#define npyv_cmple_f64(A, B) npyv_cmpge_f64(B, A) + +// check special cases +#if NPY_SIMD_F32 + NPY_FINLINE npyv_b32 npyv_notnan_f32(npyv_f32 a) + { return vec_cmpeq(a, a); } +#endif +NPY_FINLINE npyv_b64 npyv_notnan_f64(npyv_f64 a) +{ return vec_cmpeq(a, a); } + +// Test cross all vector lanes +// any: returns true if any of the elements is not equal to zero +// all: returns true if all elements are not equal to zero +#define NPYV_IMPL_VEC_ANYALL(SFX, SFX2) \ + NPY_FINLINE bool npyv_any_##SFX(npyv_##SFX a) \ + { return vec_any_ne(a, 
(npyv_##SFX)npyv_zero_##SFX2()); } \ + NPY_FINLINE bool npyv_all_##SFX(npyv_##SFX a) \ + { return vec_all_ne(a, (npyv_##SFX)npyv_zero_##SFX2()); } +NPYV_IMPL_VEC_ANYALL(b8, u8) +NPYV_IMPL_VEC_ANYALL(b16, u16) +NPYV_IMPL_VEC_ANYALL(b32, u32) +NPYV_IMPL_VEC_ANYALL(b64, u64) +NPYV_IMPL_VEC_ANYALL(u8, u8) +NPYV_IMPL_VEC_ANYALL(s8, s8) +NPYV_IMPL_VEC_ANYALL(u16, u16) +NPYV_IMPL_VEC_ANYALL(s16, s16) +NPYV_IMPL_VEC_ANYALL(u32, u32) +NPYV_IMPL_VEC_ANYALL(s32, s32) +NPYV_IMPL_VEC_ANYALL(u64, u64) +NPYV_IMPL_VEC_ANYALL(s64, s64) +#if NPY_SIMD_F32 + NPYV_IMPL_VEC_ANYALL(f32, f32) +#endif +NPYV_IMPL_VEC_ANYALL(f64, f64) +#undef NPYV_IMPL_VEC_ANYALL + +#endif // _NPY_SIMD_VEC_OPERATORS_H diff --git a/mkl_umath/src/npyv/vec/reorder.h b/mkl_umath/src/npyv/vec/reorder.h new file mode 100644 index 00000000..3910980a --- /dev/null +++ b/mkl_umath/src/npyv/vec/reorder.h @@ -0,0 +1,213 @@ +#ifndef NPY_SIMD + #error "Not a standalone header" +#endif + +#ifndef _NPY_SIMD_VEC_REORDER_H +#define _NPY_SIMD_VEC_REORDER_H + +// combine lower part of two vectors +#define npyv__combinel(A, B) vec_mergeh((npyv_u64)(A), (npyv_u64)(B)) +#define npyv_combinel_u8(A, B) ((npyv_u8) npyv__combinel(A, B)) +#define npyv_combinel_s8(A, B) ((npyv_s8) npyv__combinel(A, B)) +#define npyv_combinel_u16(A, B) ((npyv_u16)npyv__combinel(A, B)) +#define npyv_combinel_s16(A, B) ((npyv_s16)npyv__combinel(A, B)) +#define npyv_combinel_u32(A, B) ((npyv_u32)npyv__combinel(A, B)) +#define npyv_combinel_s32(A, B) ((npyv_s32)npyv__combinel(A, B)) +#define npyv_combinel_u64 vec_mergeh +#define npyv_combinel_s64 vec_mergeh +#if NPY_SIMD_F32 + #define npyv_combinel_f32(A, B) ((npyv_f32)npyv__combinel(A, B)) +#endif +#define npyv_combinel_f64 vec_mergeh + +// combine higher part of two vectors +#define npyv__combineh(A, B) vec_mergel((npyv_u64)(A), (npyv_u64)(B)) +#define npyv_combineh_u8(A, B) ((npyv_u8) npyv__combineh(A, B)) +#define npyv_combineh_s8(A, B) ((npyv_s8) npyv__combineh(A, B)) +#define npyv_combineh_u16(A, B) ((npyv_u16)npyv__combineh(A, B)) +#define npyv_combineh_s16(A, B) ((npyv_s16)npyv__combineh(A, B)) +#define npyv_combineh_u32(A, B) ((npyv_u32)npyv__combineh(A, B)) +#define npyv_combineh_s32(A, B) ((npyv_s32)npyv__combineh(A, B)) +#define npyv_combineh_u64 vec_mergel +#define npyv_combineh_s64 vec_mergel +#if NPY_SIMD_F32 + #define npyv_combineh_f32(A, B) ((npyv_f32)npyv__combineh(A, B)) +#endif +#define npyv_combineh_f64 vec_mergel + +/* + * combine: combine two vectors from lower and higher parts of two other vectors + * zip: interleave two vectors +*/ +#define NPYV_IMPL_VEC_COMBINE_ZIP(T_VEC, SFX) \ + NPY_FINLINE T_VEC##x2 npyv_combine_##SFX(T_VEC a, T_VEC b) \ + { \ + T_VEC##x2 r; \ + r.val[0] = NPY_CAT(npyv_combinel_, SFX)(a, b); \ + r.val[1] = NPY_CAT(npyv_combineh_, SFX)(a, b); \ + return r; \ + } \ + NPY_FINLINE T_VEC##x2 npyv_zip_##SFX(T_VEC a, T_VEC b) \ + { \ + T_VEC##x2 r; \ + r.val[0] = vec_mergeh(a, b); \ + r.val[1] = vec_mergel(a, b); \ + return r; \ + } + +NPYV_IMPL_VEC_COMBINE_ZIP(npyv_u8, u8) +NPYV_IMPL_VEC_COMBINE_ZIP(npyv_s8, s8) +NPYV_IMPL_VEC_COMBINE_ZIP(npyv_u16, u16) +NPYV_IMPL_VEC_COMBINE_ZIP(npyv_s16, s16) +NPYV_IMPL_VEC_COMBINE_ZIP(npyv_u32, u32) +NPYV_IMPL_VEC_COMBINE_ZIP(npyv_s32, s32) +NPYV_IMPL_VEC_COMBINE_ZIP(npyv_u64, u64) +NPYV_IMPL_VEC_COMBINE_ZIP(npyv_s64, s64) +#if NPY_SIMD_F32 + NPYV_IMPL_VEC_COMBINE_ZIP(npyv_f32, f32) +#endif +NPYV_IMPL_VEC_COMBINE_ZIP(npyv_f64, f64) + +// deinterleave two vectors +NPY_FINLINE npyv_u8x2 npyv_unzip_u8(npyv_u8 ab0, npyv_u8 ab1) +{ + const npyv_u8 idx_even = 
npyv_set_u8( + 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30 + ); + const npyv_u8 idx_odd = npyv_set_u8( + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31 + ); + npyv_u8x2 r; + r.val[0] = vec_perm(ab0, ab1, idx_even); + r.val[1] = vec_perm(ab0, ab1, idx_odd); + return r; +} +NPY_FINLINE npyv_s8x2 npyv_unzip_s8(npyv_s8 ab0, npyv_s8 ab1) +{ + npyv_u8x2 ru = npyv_unzip_u8((npyv_u8)ab0, (npyv_u8)ab1); + npyv_s8x2 r; + r.val[0] = (npyv_s8)ru.val[0]; + r.val[1] = (npyv_s8)ru.val[1]; + return r; +} +NPY_FINLINE npyv_u16x2 npyv_unzip_u16(npyv_u16 ab0, npyv_u16 ab1) +{ + const npyv_u8 idx_even = npyv_set_u8( + 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 + ); + const npyv_u8 idx_odd = npyv_set_u8( + 2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23, 26, 27, 30, 31 + ); + npyv_u16x2 r; + r.val[0] = vec_perm(ab0, ab1, idx_even); + r.val[1] = vec_perm(ab0, ab1, idx_odd); + return r; +} +NPY_FINLINE npyv_s16x2 npyv_unzip_s16(npyv_s16 ab0, npyv_s16 ab1) +{ + npyv_u16x2 ru = npyv_unzip_u16((npyv_u16)ab0, (npyv_u16)ab1); + npyv_s16x2 r; + r.val[0] = (npyv_s16)ru.val[0]; + r.val[1] = (npyv_s16)ru.val[1]; + return r; +} +NPY_FINLINE npyv_u32x2 npyv_unzip_u32(npyv_u32 ab0, npyv_u32 ab1) +{ + npyv_u32 m0 = vec_mergeh(ab0, ab1); + npyv_u32 m1 = vec_mergel(ab0, ab1); + npyv_u32 r0 = vec_mergeh(m0, m1); + npyv_u32 r1 = vec_mergel(m0, m1); + npyv_u32x2 r; + r.val[0] = r0; + r.val[1] = r1; + return r; +} +NPY_FINLINE npyv_s32x2 npyv_unzip_s32(npyv_s32 ab0, npyv_s32 ab1) +{ + npyv_u32x2 ru = npyv_unzip_u32((npyv_u32)ab0, (npyv_u32)ab1); + npyv_s32x2 r; + r.val[0] = (npyv_s32)ru.val[0]; + r.val[1] = (npyv_s32)ru.val[1]; + return r; +} +#if NPY_SIMD_F32 + NPY_FINLINE npyv_f32x2 npyv_unzip_f32(npyv_f32 ab0, npyv_f32 ab1) + { + npyv_u32x2 ru = npyv_unzip_u32((npyv_u32)ab0, (npyv_u32)ab1); + npyv_f32x2 r; + r.val[0] = (npyv_f32)ru.val[0]; + r.val[1] = (npyv_f32)ru.val[1]; + return r; + } +#endif +NPY_FINLINE npyv_u64x2 npyv_unzip_u64(npyv_u64 ab0, npyv_u64 ab1) +{ return npyv_combine_u64(ab0, ab1); } +NPY_FINLINE npyv_s64x2 npyv_unzip_s64(npyv_s64 ab0, npyv_s64 ab1) +{ return npyv_combine_s64(ab0, ab1); } +NPY_FINLINE npyv_f64x2 npyv_unzip_f64(npyv_f64 ab0, npyv_f64 ab1) +{ return npyv_combine_f64(ab0, ab1); } + +// Reverse elements of each 64-bit lane +NPY_FINLINE npyv_u8 npyv_rev64_u8(npyv_u8 a) +{ +#if defined(NPY_HAVE_VSX3) && ((defined(__GNUC__) && __GNUC__ > 7) || defined(__IBMC__)) + return (npyv_u8)vec_revb((npyv_u64)a); +#elif defined(NPY_HAVE_VSX3) && defined(NPY_HAVE_VSX_ASM) + npyv_u8 ret; + __asm__ ("xxbrd %x0,%x1" : "=wa" (ret) : "wa" (a)); + return ret; +#else + const npyv_u8 idx = npyv_set_u8( + 7, 6, 5, 4, 3, 2, 1, 0,/*64*/15, 14, 13, 12, 11, 10, 9, 8 + ); + return vec_perm(a, a, idx); +#endif +} +NPY_FINLINE npyv_s8 npyv_rev64_s8(npyv_s8 a) +{ return (npyv_s8)npyv_rev64_u8((npyv_u8)a); } + +NPY_FINLINE npyv_u16 npyv_rev64_u16(npyv_u16 a) +{ + const npyv_u8 idx = npyv_set_u8( + 6, 7, 4, 5, 2, 3, 0, 1,/*64*/14, 15, 12, 13, 10, 11, 8, 9 + ); + return vec_perm(a, a, idx); +} +NPY_FINLINE npyv_s16 npyv_rev64_s16(npyv_s16 a) +{ return (npyv_s16)npyv_rev64_u16((npyv_u16)a); } + +NPY_FINLINE npyv_u32 npyv_rev64_u32(npyv_u32 a) +{ + const npyv_u8 idx = npyv_set_u8( + 4, 5, 6, 7, 0, 1, 2, 3,/*64*/12, 13, 14, 15, 8, 9, 10, 11 + ); + return vec_perm(a, a, idx); +} +NPY_FINLINE npyv_s32 npyv_rev64_s32(npyv_s32 a) +{ return (npyv_s32)npyv_rev64_u32((npyv_u32)a); } +#if NPY_SIMD_F32 + NPY_FINLINE npyv_f32 npyv_rev64_f32(npyv_f32 a) + { return (npyv_f32)npyv_rev64_u32((npyv_u32)a); } +#endif 
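(Editor's note, not part of the patch: the combine/zip/unzip helpers above are the building blocks for splitting interleaved data into separate channels and back. Below is a minimal usage sketch, assuming these headers are reachable through the npyv umbrella header added by this patch and that NPY_SIMD_F32 is enabled on the target; the function and buffer names are illustrative only.)

/* Illustrative sketch: deinterleave npyv_nlanes_f32 (re, im) float32 pairs
 * into separate real/imaginary vectors, then interleave them back. */
#include "npyv/npyv.h"   /* assumed umbrella header provided by this patch */

static void split_and_merge_f32(const float *interleaved, float *re, float *im,
                                float *roundtrip)
{
    /* interleaved holds re0, im0, re1, im1, ... (2 * npyv_nlanes_f32 floats) */
    npyv_f32 ab0 = npyv_load_f32(interleaved);
    npyv_f32 ab1 = npyv_load_f32(interleaved + npyv_nlanes_f32);
    npyv_f32x2 parts = npyv_unzip_f32(ab0, ab1);   /* val[0] = reals, val[1] = imags */
    npyv_store_f32(re, parts.val[0]);
    npyv_store_f32(im, parts.val[1]);

    npyv_f32x2 zipped = npyv_zip_f32(parts.val[0], parts.val[1]);  /* re-interleave */
    npyv_store_f32(roundtrip, zipped.val[0]);
    npyv_store_f32(roundtrip + npyv_nlanes_f32, zipped.val[1]);
}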
+
+// Permuting the elements of each 128-bit lane by an immediate index for
+// each element.
+#define npyv_permi128_u32(A, E0, E1, E2, E3) \
+    vec_perm(A, A, npyv_set_u8( \
+        (E0<<2), (E0<<2)+1, (E0<<2)+2, (E0<<2)+3, \
+        (E1<<2), (E1<<2)+1, (E1<<2)+2, (E1<<2)+3, \
+        (E2<<2), (E2<<2)+1, (E2<<2)+2, (E2<<2)+3, \
+        (E3<<2), (E3<<2)+1, (E3<<2)+2, (E3<<2)+3 \
+    ))
+#define npyv_permi128_s32 npyv_permi128_u32
+#define npyv_permi128_f32 npyv_permi128_u32
+
+#if defined(__IBMC__) || defined(vec_permi)
+    #define npyv_permi128_u64(A, E0, E1) vec_permi(A, A, ((E0)<<1) | (E1))
+#else
+    #define npyv_permi128_u64(A, E0, E1) vec_xxpermdi(A, A, ((E0)<<1) | (E1))
+#endif
+#define npyv_permi128_s64 npyv_permi128_u64
+#define npyv_permi128_f64 npyv_permi128_u64
+
+#endif // _NPY_SIMD_VEC_REORDER_H
diff --git a/mkl_umath/src/npyv/vec/utils.h b/mkl_umath/src/npyv/vec/utils.h
new file mode 100644
index 00000000..f8b28cfe
--- /dev/null
+++ b/mkl_umath/src/npyv/vec/utils.h
@@ -0,0 +1,84 @@
+#ifndef NPY_SIMD
+    #error "Not a standalone header"
+#endif
+
+#ifndef _NPY_SIMD_VEC_UTILS_H
+#define _NPY_SIMD_VEC_UTILS_H
+
+// some or all of the following intrinsics may not be provided by the zvector API on gcc/clang
+#ifdef NPY_HAVE_VX
+    #ifndef vec_neg
+        #define vec_neg(a) (-(a)) // Vector Negate
+    #endif
+    #ifndef vec_add
+        #define vec_add(a, b) ((a) + (b)) // Vector Add
+    #endif
+    #ifndef vec_sub
+        #define vec_sub(a, b) ((a) - (b)) // Vector Subtract
+    #endif
+    #ifndef vec_mul
+        #define vec_mul(a, b) ((a) * (b)) // Vector Multiply
+    #endif
+    #ifndef vec_div
+        #define vec_div(a, b) ((a) / (b)) // Vector Divide
+    #endif
+    #ifndef vec_neg
+        #define vec_neg(a) (-(a))
+    #endif
+    #ifndef vec_and
+        #define vec_and(a, b) ((a) & (b)) // Vector AND
+    #endif
+    #ifndef vec_or
+        #define vec_or(a, b) ((a) | (b)) // Vector OR
+    #endif
+    #ifndef vec_xor
+        #define vec_xor(a, b) ((a) ^ (b)) // Vector XOR
+    #endif
+    #ifndef vec_sl
+        #define vec_sl(a, b) ((a) << (b)) // Vector Shift Left
+    #endif
+    #ifndef vec_sra
+        #define vec_sra(a, b) ((a) >> (b)) // Vector Shift Right Algebraic
+    #endif
+    #ifndef vec_sr
+        #define vec_sr(a, b) ((a) >> (b)) // Vector Shift Right
+    #endif
+    #ifndef vec_slo
+        #define vec_slo(a, b) vec_slb(a, (b) << 64) // Vector Shift Left by Octet
+    #endif
+    #ifndef vec_sro
+        #define vec_sro(a, b) vec_srb(a, (b) << 64) // Vector Shift Right by Octet
+    #endif
+    // vec_doublee maps to wrong intrin "vfll".
+    // see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=100871
+    #if defined(__GNUC__) && !defined(__clang__)
+        #define npyv_doublee __builtin_s390_vflls
+    #else
+        #define npyv_doublee vec_doublee
+    #endif
+    // compatibility with vsx
+    #ifndef vec_vbpermq
+        #define vec_vbpermq vec_bperm_u128
+    #endif
+    // zvector requires the second operand to be signed while the vsx api expects it to be
+    // unsigned; the following macros are defined to remove this conflict
+    #define vec_sl_s8(a, b)   vec_sl(a, (npyv_s8)(b))
+    #define vec_sl_s16(a, b)  vec_sl(a, (npyv_s16)(b))
+    #define vec_sl_s32(a, b)  vec_sl(a, (npyv_s32)(b))
+    #define vec_sl_s64(a, b)  vec_sl(a, (npyv_s64)(b))
+    #define vec_sra_s8(a, b)  vec_sra(a, (npyv_s8)(b))
+    #define vec_sra_s16(a, b) vec_sra(a, (npyv_s16)(b))
+    #define vec_sra_s32(a, b) vec_sra(a, (npyv_s32)(b))
+    #define vec_sra_s64(a, b) vec_sra(a, (npyv_s64)(b))
+#else
+    #define vec_sl_s8   vec_sl
+    #define vec_sl_s16  vec_sl
+    #define vec_sl_s32  vec_sl
+    #define vec_sl_s64  vec_sl
+    #define vec_sra_s8  vec_sra
+    #define vec_sra_s16 vec_sra
+    #define vec_sra_s32 vec_sra
+    #define vec_sra_s64 vec_sra
+#endif
+
+#endif // _NPY_SIMD_VEC_UTILS_H
diff --git a/mkl_umath/src/npyv/vec/vec.h b/mkl_umath/src/npyv/vec/vec.h
new file mode 100644
index 00000000..1d450866
--- /dev/null
+++ b/mkl_umath/src/npyv/vec/vec.h
@@ -0,0 +1,111 @@
+/**
+ * The /vec branch (altivec-like) provides the SIMD operations for
+ * both IBM VSX (Power) and VX (ZArch).
+*/
+#ifndef _NPY_SIMD_H_
+    #error "Not a standalone header"
+#endif
+
+#if !defined(NPY_HAVE_VX) && !defined(NPY_HAVE_VSX2)
+    #error "requires at least VX (zarch11) or VSX2 (Power8/ISA 2.07)"
+#endif
+
+#if defined(NPY_HAVE_VSX) && !defined(__LITTLE_ENDIAN__)
+    #error "VSX support doesn't cover big-endian mode yet, only zarch."
+#endif
+#if defined(NPY_HAVE_VX) && defined(__LITTLE_ENDIAN__)
+    #error "VX (zarch) support doesn't cover little-endian mode."
+#endif
+
+#if defined(__GNUC__) && __GNUC__ <= 7
+    /**
+     * GCC <= 7 produces an ambiguous warning caused by -Werror=maybe-uninitialized
+     * when certain intrinsics are involved. `vec_ld` is one of them, but it seems to work fine,
+     * and suppressing the warning doesn't affect its functionality.
+ */ + #pragma GCC diagnostic ignored "-Wuninitialized" + #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#endif + +#define NPY_SIMD 128 +#define NPY_SIMD_WIDTH 16 +#define NPY_SIMD_F64 1 +#if defined(NPY_HAVE_VXE) || defined(NPY_HAVE_VSX) + #define NPY_SIMD_F32 1 +#else + #define NPY_SIMD_F32 0 +#endif +#define NPY_SIMD_FMA3 1 // native support + +#ifdef NPY_HAVE_VX + #define NPY_SIMD_BIGENDIAN 1 + #define NPY_SIMD_CMPSIGNAL 0 +#else + #define NPY_SIMD_BIGENDIAN 0 + #define NPY_SIMD_CMPSIGNAL 1 +#endif + +typedef __vector unsigned char npyv_u8; +typedef __vector signed char npyv_s8; +typedef __vector unsigned short npyv_u16; +typedef __vector signed short npyv_s16; +typedef __vector unsigned int npyv_u32; +typedef __vector signed int npyv_s32; +typedef __vector unsigned long long npyv_u64; +typedef __vector signed long long npyv_s64; +#if NPY_SIMD_F32 +typedef __vector float npyv_f32; +#endif +typedef __vector double npyv_f64; + +typedef struct { npyv_u8 val[2]; } npyv_u8x2; +typedef struct { npyv_s8 val[2]; } npyv_s8x2; +typedef struct { npyv_u16 val[2]; } npyv_u16x2; +typedef struct { npyv_s16 val[2]; } npyv_s16x2; +typedef struct { npyv_u32 val[2]; } npyv_u32x2; +typedef struct { npyv_s32 val[2]; } npyv_s32x2; +typedef struct { npyv_u64 val[2]; } npyv_u64x2; +typedef struct { npyv_s64 val[2]; } npyv_s64x2; +#if NPY_SIMD_F32 +typedef struct { npyv_f32 val[2]; } npyv_f32x2; +#endif +typedef struct { npyv_f64 val[2]; } npyv_f64x2; + +typedef struct { npyv_u8 val[3]; } npyv_u8x3; +typedef struct { npyv_s8 val[3]; } npyv_s8x3; +typedef struct { npyv_u16 val[3]; } npyv_u16x3; +typedef struct { npyv_s16 val[3]; } npyv_s16x3; +typedef struct { npyv_u32 val[3]; } npyv_u32x3; +typedef struct { npyv_s32 val[3]; } npyv_s32x3; +typedef struct { npyv_u64 val[3]; } npyv_u64x3; +typedef struct { npyv_s64 val[3]; } npyv_s64x3; +#if NPY_SIMD_F32 +typedef struct { npyv_f32 val[3]; } npyv_f32x3; +#endif +typedef struct { npyv_f64 val[3]; } npyv_f64x3; + +#define npyv_nlanes_u8 16 +#define npyv_nlanes_s8 16 +#define npyv_nlanes_u16 8 +#define npyv_nlanes_s16 8 +#define npyv_nlanes_u32 4 +#define npyv_nlanes_s32 4 +#define npyv_nlanes_u64 2 +#define npyv_nlanes_s64 2 +#define npyv_nlanes_f32 4 +#define npyv_nlanes_f64 2 + +// using __bool with typedef cause ambiguous errors +#define npyv_b8 __vector __bool char +#define npyv_b16 __vector __bool short +#define npyv_b32 __vector __bool int +#define npyv_b64 __vector __bool long long + +#include "utils.h" +#include "memory.h" +#include "misc.h" +#include "reorder.h" +#include "operators.h" +#include "conversion.h" +#include "arithmetic.h" +#include "math.h"
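(Editor's note, not part of the patch: vec.h above fixes the vector width at 128 bits, so portable kernels iterate by the npyv_nlanes_* constants it defines. Below is a minimal sketch of that pattern, assuming a backend where NPY_SIMD is non-zero, that npyv_mul_f64 is provided by the arithmetic.h included at the end of vec.h, and that the partial load/store helpers come from memory.h; the function name and buffers are illustrative only.)

/* Illustrative sketch: scale a contiguous double buffer, processing
 * npyv_nlanes_f64 elements per iteration and finishing the tail with
 * the partial load/store helpers. */
#include "npyv/npyv.h"   /* assumed umbrella header provided by this patch */

static void scale_f64(double *dst, const double *src, npy_intp len, double factor)
{
    const npyv_f64 vfactor = npyv_setall_f64(factor);
    npy_intp i = 0;
    for (; i + npyv_nlanes_f64 <= len; i += npyv_nlanes_f64) {
        npyv_f64 v = npyv_load_f64(src + i);
        npyv_store_f64(dst + i, npyv_mul_f64(v, vfactor));
    }
    if (i < len) {
        /* tail: load only the remaining lanes, filling unused lanes with 0.0 */
        npyv_f64 v = npyv_load_till_f64(src + i, len - i, 0.0);
        npyv_store_till_f64(dst + i, len - i, npyv_mul_f64(v, vfactor));
    }
    npyv_cleanup();   /* no-op on this backend, needed on AVX targets */
}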