Skip to content

Commit 6f681b4

Browse files
committed
refactor(simd): unify cross-ISA implementations via traits + kernel templates
Introduce a type-traits abstraction layer and shared kernel templates to eliminate copy-paste SIMD implementations across ISAs (SSE/AVX/AVX2/AVX512/NEON). Changes: - Add per-ISA traits headers (traits/simd_traits_*.h) exposing a uniform static interface (load/store/fmadd/reduce_add/etc.) over each ISA's native vector types. - Add 17 kernel templates (kernels/*.h) that implement the algorithm once in terms of traits, replacing 5-way duplicated hand-written intrinsics. - Covered function families: FP32 compute/batch/binary-ops/reduce, SQ8/SQ4 quantized compute, SQ4/SQ8 uniform code IP, INT8 compute, BF16/FP16 half-precision compute, bit operations, normalize, RaBitQ binary/batch/split-code IP, butterfly/rotate, and scalar ops. - SVE and AMX are explicitly out of scope (different vector model). Net effect: ISA implementation files shrink from ~11,049 to ~4,628 lines (-58%). New SIMD function cost drops from ~7 files x 30 lines to 1 kernel template + 0-2 trait extensions. All 51 SIMD test cases pass (30,591,891 assertions). Signed-off-by: LHT129 <tianlan.lht@antgroup.com>
1 parent a1b501f commit 6f681b4

30 files changed

Lines changed: 4024 additions & 5521 deletions

src/simd/avx.cpp

Lines changed: 95 additions & 841 deletions
Large diffs are not rendered by default.

src/simd/avx2.cpp

Lines changed: 108 additions & 1037 deletions
Large diffs are not rendered by default.

src/simd/avx512.cpp

Lines changed: 107 additions & 975 deletions
Large diffs are not rendered by default.

src/simd/generic.cpp

Lines changed: 96 additions & 236 deletions
Large diffs are not rendered by default.

src/simd/kernels/binary_op.h

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
2+
// Copyright 2024-present the vsag project
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#pragma once
17+
18+
// Element-wise binary op kernel: z[i] = op(x[i], y[i]).
19+
// Used by FP32Add / FP32Sub / FP32Mul / FP32Div across all ISAs.
20+
//
21+
// Op selects the per-element operation via a tag dispatched at compile time
22+
// to the corresponding traits method (add/sub/mul/div). The Generic backend
23+
// (Width == 1) compiles out the fallback branch via `if constexpr`.
24+
25+
#include <cstdint>
26+
27+
namespace vsag::simd {
28+
29+
using BinaryFallback = void (*)(const float*, const float*, float*, uint64_t);
30+
31+
enum class BinaryOp { Add, Sub, Mul, Div };
32+
33+
template <typename T, BinaryOp Op>
34+
inline __attribute__((always_inline)) typename T::FloatVec
35+
binary_apply(typename T::FloatVec a, typename T::FloatVec b) {
36+
if constexpr (Op == BinaryOp::Add) {
37+
return T::add(a, b);
38+
} else if constexpr (Op == BinaryOp::Sub) {
39+
return T::sub(a, b);
40+
} else if constexpr (Op == BinaryOp::Mul) {
41+
return T::mul(a, b);
42+
} else {
43+
return T::div(a, b);
44+
}
45+
}
46+
47+
template <typename T, BinaryOp Op>
48+
inline void
49+
BinaryOpImpl(
50+
const float* x, const float* y, float* z, uint64_t dim, BinaryFallback fallback = nullptr) {
51+
using V = typename T::FloatVec;
52+
constexpr int W = T::Width;
53+
54+
if constexpr (W > 1) {
55+
if (dim < static_cast<uint64_t>(W)) {
56+
fallback(x, y, z, dim);
57+
return;
58+
}
59+
}
60+
61+
uint64_t i = 0;
62+
for (; i + W <= dim; i += W) {
63+
V a = T::load(x + i);
64+
V b = T::load(y + i);
65+
T::store(z + i, binary_apply<T, Op>(a, b));
66+
}
67+
68+
if constexpr (W > 1) {
69+
if (dim > i) {
70+
fallback(x + i, y + i, z + i, dim - i);
71+
}
72+
}
73+
}
74+
75+
} // namespace vsag::simd

src/simd/kernels/bit_op.h

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// Copyright 2024-present the vsag project
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
// Bitwise operation kernels: AND, OR, XOR, NOT.
18+
//
19+
// Parameterized on BitTraits<ISA>, which must expose:
20+
// IntVec - integer vector type (e.g. __m128i, __m256i, __m512i)
21+
// ByteWidth - number of bytes per vector (16, 32, 64)
22+
// load(const uint8_t* p) -> IntVec
23+
// store(uint8_t* p, IntVec v)
24+
// bit_and(IntVec a, IntVec b) -> IntVec
25+
// bit_or(IntVec a, IntVec b) -> IntVec
26+
// bit_xor(IntVec a, IntVec b) -> IntVec
27+
// bit_not(IntVec a) -> IntVec (optional, only needed for BitNotImpl)
28+
29+
#include <cstdint>
30+
31+
namespace vsag::simd {
32+
33+
using BitOpFallback = void (*)(const uint8_t*, const uint8_t*, uint64_t, uint8_t*);
34+
using BitNotFallback = void (*)(const uint8_t*, uint64_t, uint8_t*);
35+
36+
template <typename T>
37+
inline void
38+
BitAndImpl(const uint8_t* x,
39+
const uint8_t* y,
40+
uint64_t num_byte,
41+
uint8_t* result,
42+
BitOpFallback fallback = nullptr) {
43+
constexpr int W = T::ByteWidth;
44+
if (num_byte == 0)
45+
return;
46+
if (num_byte < static_cast<uint64_t>(W)) {
47+
return fallback(x, y, num_byte, result);
48+
}
49+
int64_t i = 0;
50+
for (; i + W <= static_cast<int64_t>(num_byte); i += W) {
51+
T::store(result + i, T::bit_and(T::load(x + i), T::load(y + i)));
52+
}
53+
if (i < static_cast<int64_t>(num_byte)) {
54+
fallback(x + i, y + i, num_byte - i, result + i);
55+
}
56+
}
57+
58+
template <typename T>
59+
inline void
60+
BitOrImpl(const uint8_t* x,
61+
const uint8_t* y,
62+
uint64_t num_byte,
63+
uint8_t* result,
64+
BitOpFallback fallback = nullptr) {
65+
constexpr int W = T::ByteWidth;
66+
if (num_byte == 0)
67+
return;
68+
if (num_byte < static_cast<uint64_t>(W)) {
69+
return fallback(x, y, num_byte, result);
70+
}
71+
int64_t i = 0;
72+
for (; i + W <= static_cast<int64_t>(num_byte); i += W) {
73+
T::store(result + i, T::bit_or(T::load(x + i), T::load(y + i)));
74+
}
75+
if (i < static_cast<int64_t>(num_byte)) {
76+
fallback(x + i, y + i, num_byte - i, result + i);
77+
}
78+
}
79+
80+
template <typename T>
81+
inline void
82+
BitXorImpl(const uint8_t* x,
83+
const uint8_t* y,
84+
uint64_t num_byte,
85+
uint8_t* result,
86+
BitOpFallback fallback = nullptr) {
87+
constexpr int W = T::ByteWidth;
88+
if (num_byte == 0)
89+
return;
90+
if (num_byte < static_cast<uint64_t>(W)) {
91+
return fallback(x, y, num_byte, result);
92+
}
93+
int64_t i = 0;
94+
for (; i + W <= static_cast<int64_t>(num_byte); i += W) {
95+
T::store(result + i, T::bit_xor(T::load(x + i), T::load(y + i)));
96+
}
97+
if (i < static_cast<int64_t>(num_byte)) {
98+
fallback(x + i, y + i, num_byte - i, result + i);
99+
}
100+
}
101+
102+
template <typename T>
103+
inline void
104+
BitNotImpl(const uint8_t* x,
105+
uint64_t num_byte,
106+
uint8_t* result,
107+
BitNotFallback fallback = nullptr) {
108+
constexpr int W = T::ByteWidth;
109+
if (num_byte == 0)
110+
return;
111+
if (num_byte < static_cast<uint64_t>(W)) {
112+
return fallback(x, num_byte, result);
113+
}
114+
int64_t i = 0;
115+
for (; i + W <= static_cast<int64_t>(num_byte); i += W) {
116+
T::store(result + i, T::bit_not(T::load(x + i)));
117+
}
118+
if (i < static_cast<int64_t>(num_byte)) {
119+
fallback(x + i, num_byte - i, result + i);
120+
}
121+
}
122+
123+
} // namespace vsag::simd

src/simd/kernels/butterfly.h

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// Copyright 2024-present the vsag project
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
// Butterfly-pattern kernels used in Fast Hadamard Transform (FHT):
18+
// RotateOp: data[i+j] = data[i+j] + data[i+j+step]
19+
// data[i+j+step] = data[i+j] - data[i+j+step]
20+
// KacsWalk: in-place butterfly on two halves of the array.
21+
//
22+
// These use only load/store/add/sub from the traits interface.
23+
24+
#include <cmath>
25+
#include <cstdint>
26+
27+
namespace vsag::simd {
28+
29+
using RotateOpFallback = void (*)(float*, int, int, int);
30+
using KacsWalkFallback = void (*)(float*, uint64_t);
31+
32+
// RotateOp: butterfly on stride `step` within [idx, dim_).
33+
// Requires step >= T::Width for the vectorized path.
34+
template <typename T>
35+
inline void
36+
RotateOpImpl(float* data, int idx, int dim_, int step) {
37+
using V = typename T::FloatVec;
38+
constexpr int W = T::Width;
39+
40+
for (int i = idx; i < dim_; i += step * 2) {
41+
int j = 0;
42+
for (; j + W <= step; j += W) {
43+
V g1 = T::load(&data[i + j]);
44+
V g2 = T::load(&data[i + j + step]);
45+
T::store(&data[i + j], T::add(g1, g2));
46+
T::store(&data[i + j + step], T::sub(g1, g2));
47+
}
48+
for (; j < step; ++j) {
49+
float g1 = data[i + j];
50+
float g2 = data[i + j + step];
51+
data[i + j] = g1 + g2;
52+
data[i + j + step] = g1 - g2;
53+
}
54+
}
55+
}
56+
57+
// KacsWalk: in-place butterfly on data[0..len/2-1] vs data[offset..offset+len/2-1].
58+
// For odd-length arrays, the middle element is scaled by sqrt(2).
59+
template <typename T>
60+
inline void
61+
KacsWalkImpl(float* data, uint64_t len, KacsWalkFallback fallback = nullptr) {
62+
using V = typename T::FloatVec;
63+
constexpr int W = T::Width;
64+
65+
if constexpr (W > 1) {
66+
if (len / 2 < static_cast<uint64_t>(W)) {
67+
fallback(data, len);
68+
return;
69+
}
70+
}
71+
72+
uint64_t base = len % 2;
73+
uint64_t offset = base + (len / 2);
74+
uint64_t i = 0;
75+
76+
for (; i + W <= len / 2; i += W) {
77+
V x = T::load(&data[i]);
78+
V y = T::load(&data[i + offset]);
79+
T::store(&data[i], T::add(x, y));
80+
T::store(&data[i + offset], T::sub(x, y));
81+
}
82+
83+
// Scalar tail
84+
for (; i < len / 2; i++) {
85+
float x = data[i];
86+
float y = data[i + offset];
87+
data[i] = x + y;
88+
data[i + offset] = x - y;
89+
}
90+
91+
if (base != 0) {
92+
data[len / 2] *= std::sqrt(2.0f);
93+
}
94+
}
95+
96+
} // namespace vsag::simd

src/simd/kernels/compute_batch4.h

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
2+
// Copyright 2024-present the vsag project
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
#pragma once
17+
18+
// Batch-of-4 IP / L2 kernel: one query vector against four code vectors.
19+
// Results are accumulated into result1..result4 (the caller must initialise
20+
// them before invocation, e.g. to 0). Matches the existing semantics of
21+
// FP32ComputeIPBatch4 / FP32ComputeL2SqrBatch4: the four accumulators
22+
// share the same query load, so we get 4x reuse of every q-cacheline.
23+
24+
#include <cstdint>
25+
26+
#include "simd/simd_marco.h"
27+
28+
namespace vsag::simd {
29+
30+
using Batch4Fallback = void (*)(const float* RESTRICT query,
31+
uint64_t dim,
32+
const float* RESTRICT c1,
33+
const float* RESTRICT c2,
34+
const float* RESTRICT c3,
35+
const float* RESTRICT c4,
36+
float& r1,
37+
float& r2,
38+
float& r3,
39+
float& r4);
40+
41+
enum class Batch4Kind { IP, L2 };
42+
43+
template <typename T, Batch4Kind Kind>
44+
inline __attribute__((always_inline)) typename T::FloatVec
45+
batch4_accumulate(typename T::FloatVec q, typename T::FloatVec c, typename T::FloatVec acc) {
46+
if constexpr (Kind == Batch4Kind::IP) {
47+
return T::fmadd(q, c, acc);
48+
} else {
49+
typename T::FloatVec d = T::sub(q, c);
50+
return T::fmadd(d, d, acc);
51+
}
52+
}
53+
54+
template <typename T, Batch4Kind Kind>
55+
inline void
56+
ComputeBatch4Impl(const float* RESTRICT query,
57+
uint64_t dim,
58+
const float* RESTRICT c1,
59+
const float* RESTRICT c2,
60+
const float* RESTRICT c3,
61+
const float* RESTRICT c4,
62+
float& r1,
63+
float& r2,
64+
float& r3,
65+
float& r4,
66+
Batch4Fallback fallback = nullptr) {
67+
using V = typename T::FloatVec;
68+
constexpr int W = T::Width;
69+
70+
if constexpr (W > 1) {
71+
if (dim < static_cast<uint64_t>(W)) {
72+
fallback(query, dim, c1, c2, c3, c4, r1, r2, r3, r4);
73+
return;
74+
}
75+
}
76+
77+
V s1 = T::zero();
78+
V s2 = T::zero();
79+
V s3 = T::zero();
80+
V s4 = T::zero();
81+
82+
uint64_t i = 0;
83+
for (; i + W <= dim; i += W) {
84+
V q = T::load(query + i);
85+
s1 = batch4_accumulate<T, Kind>(q, T::load(c1 + i), s1);
86+
s2 = batch4_accumulate<T, Kind>(q, T::load(c2 + i), s2);
87+
s3 = batch4_accumulate<T, Kind>(q, T::load(c3 + i), s3);
88+
s4 = batch4_accumulate<T, Kind>(q, T::load(c4 + i), s4);
89+
}
90+
r1 += T::reduce_add(s1);
91+
r2 += T::reduce_add(s2);
92+
r3 += T::reduce_add(s3);
93+
r4 += T::reduce_add(s4);
94+
95+
if constexpr (W > 1) {
96+
if (dim > i) {
97+
fallback(query + i, dim - i, c1 + i, c2 + i, c3 + i, c4 + i, r1, r2, r3, r4);
98+
}
99+
}
100+
}
101+
102+
} // namespace vsag::simd

0 commit comments

Comments
 (0)