Skip to content

Commit

Permalink
Add SQ8bit signed quantization (#3501)
Browse files Browse the repository at this point in the history
Summary:
### Description
Add new signed 8 bit scalar quantizer, `QT_8bit_direct_signed` to ingest signed 8 bit vectors ([-128 to 127]).

### Issues Resolved
#3488

Pull Request resolved: #3501

Reviewed By: mengdilin

Differential Revision: D58639363

Pulled By: mdouze

fbshipit-source-id: cf7f244fdbb7a34051d2b20c6f8086cd5628b4e0
  • Loading branch information
naveentatikonda authored and facebook-github-bot committed Jun 24, 2024
1 parent da75d03 commit 33c0ba5
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 17 deletions.
1 change: 1 addition & 0 deletions benchs/bench_fw/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def optimize_codec(
(None, "SQfp16"),
(None, "SQbf16"),
(None, "SQ8"),
(None, "SQ8_direct_signed"),
] + [
(f"OPQ{M}_{M * dim}", f"PQ{M}x{b}")
for M in [8, 12, 16, 32, 48, 64, 96, 128, 192, 256]
Expand Down
2 changes: 2 additions & 0 deletions c_api/IndexScalarQuantizer_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ typedef enum FaissQuantizerType {
QT_8bit_direct, ///< fast indexing of uint8s
QT_6bit, ///< 6 bits per component
QT_bf16,
QT_8bit_direct_signed, ///< fast indexing of signed int8s ranging from [-128
///< to 127]
} FaissQuantizerType;

// forward declaration
Expand Down
3 changes: 2 additions & 1 deletion faiss/IndexScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ IndexScalarQuantizer::IndexScalarQuantizer(
: IndexFlatCodes(0, d, metric), sq(d, qtype) {
is_trained = qtype == ScalarQuantizer::QT_fp16 ||
qtype == ScalarQuantizer::QT_8bit_direct ||
qtype == ScalarQuantizer::QT_bf16;
qtype == ScalarQuantizer::QT_bf16 ||
qtype == ScalarQuantizer::QT_8bit_direct_signed;
code_size = sq.code_size;
}

Expand Down
103 changes: 97 additions & 6 deletions faiss/impl/ScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -621,13 +621,90 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {

FAISS_ALWAYS_INLINE float32x4x2_t
reconstruct_8_components(const uint8_t* code, int i) const {
float32_t result[8] = {};
for (size_t j = 0; j < 8; j++) {
result[j] = code[i + j];
uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i));
uint16x8_t y8 = vmovl_u8(x8);
uint16x4_t y8_0 = vget_low_u16(y8);
uint16x4_t y8_1 = vget_high_u16(y8);

// convert uint16 -> uint32 -> fp32
return {vcvtq_f32_u32(vmovl_u16(y8_0)), vcvtq_f32_u32(vmovl_u16(y8_1))};
}
};

#endif

/*******************************************************************
* 8bit_direct_signed quantizer
*******************************************************************/

template <int SIMDWIDTH>
struct Quantizer8bitDirectSigned {};

template <>
struct Quantizer8bitDirectSigned<1> : ScalarQuantizer::SQuantizer {
const size_t d;

Quantizer8bitDirectSigned(size_t d, const std::vector<float>& /* unused */)
: d(d) {}

void encode_vector(const float* x, uint8_t* code) const final {
for (size_t i = 0; i < d; i++) {
code[i] = (uint8_t)(x[i] + 128);
}
float32x4_t res1 = vld1q_f32(result);
float32x4_t res2 = vld1q_f32(result + 4);
return {res1, res2};
}

void decode_vector(const uint8_t* code, float* x) const final {
for (size_t i = 0; i < d; i++) {
x[i] = code[i] - 128;
}
}

FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
const {
return code[i] - 128;
}
};

#ifdef __AVX2__

template <>
struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> {
Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
: Quantizer8bitDirectSigned<1>(d, trained) {}

FAISS_ALWAYS_INLINE __m256
reconstruct_8_components(const uint8_t* code, int i) const {
__m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
__m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32
__m256i c8 = _mm256_set1_epi32(128);
__m256i z8 = _mm256_sub_epi32(y8, c8); // subtract 128 from all lanes
return _mm256_cvtepi32_ps(z8); // 8 * float32
}
};

#endif

#ifdef __aarch64__

template <>
struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> {
Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
: Quantizer8bitDirectSigned<1>(d, trained) {}

FAISS_ALWAYS_INLINE float32x4x2_t
reconstruct_8_components(const uint8_t* code, int i) const {
uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i));
uint16x8_t y8 = vmovl_u8(x8); // convert uint8 -> uint16
uint16x4_t y8_0 = vget_low_u16(y8);
uint16x4_t y8_1 = vget_high_u16(y8);

float32x4_t z8_0 = vcvtq_f32_u32(
vmovl_u16(y8_0)); // convert uint16 -> uint32 -> fp32
float32x4_t z8_1 = vcvtq_f32_u32(vmovl_u16(y8_1));

// subtract 128 to convert into signed numbers
return {vsubq_f32(z8_0, vmovq_n_f32(128.0)),
vsubq_f32(z8_1, vmovq_n_f32(128.0))};
}
};

Expand Down Expand Up @@ -660,6 +737,8 @@ ScalarQuantizer::SQuantizer* select_quantizer_1(
return new QuantizerBF16<SIMDWIDTH>(d, trained);
case ScalarQuantizer::QT_8bit_direct:
return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
case ScalarQuantizer::QT_8bit_direct_signed:
return new Quantizer8bitDirectSigned<SIMDWIDTH>(d, trained);
}
FAISS_THROW_MSG("unknown qtype");
}
Expand Down Expand Up @@ -1460,6 +1539,11 @@ SQDistanceComputer* select_distance_computer(
Sim,
SIMDWIDTH>(d, trained);
}
case ScalarQuantizer::QT_8bit_direct_signed:
return new DCTemplate<
Quantizer8bitDirectSigned<SIMDWIDTH>,
Sim,
SIMDWIDTH>(d, trained);
}
FAISS_THROW_MSG("unknown qtype");
return nullptr;
Expand All @@ -1483,6 +1567,7 @@ void ScalarQuantizer::set_derived_sizes() {
case QT_8bit:
case QT_8bit_uniform:
case QT_8bit_direct:
case QT_8bit_direct_signed:
code_size = d;
bits = 8;
break;
Expand Down Expand Up @@ -1540,6 +1625,7 @@ void ScalarQuantizer::train(size_t n, const float* x) {
case QT_fp16:
case QT_8bit_direct:
case QT_bf16:
case QT_8bit_direct_signed:
// no training necessary
break;
}
Expand Down Expand Up @@ -1885,6 +1971,11 @@ InvertedListScanner* sel1_InvertedListScanner(
Similarity,
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
}
case ScalarQuantizer::QT_8bit_direct_signed:
return sel2_InvertedListScanner<DCTemplate<
Quantizer8bitDirectSigned<SIMDWIDTH>,
Similarity,
SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
}

FAISS_THROW_MSG("unknown qtype");
Expand Down
2 changes: 2 additions & 0 deletions faiss/impl/ScalarQuantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ struct ScalarQuantizer : Quantizer {
QT_8bit_direct, ///< fast indexing of uint8s
QT_6bit, ///< 6 bits per component
QT_bf16,
QT_8bit_direct_signed, ///< fast indexing of signed int8s ranging from
///< [-128 to 127]
};

QuantizerType qtype = QT_8bit;
Expand Down
5 changes: 4 additions & 1 deletion faiss/index_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,11 @@ std::map<std::string, ScalarQuantizer::QuantizerType> sq_types = {
{"SQ6", ScalarQuantizer::QT_6bit},
{"SQfp16", ScalarQuantizer::QT_fp16},
{"SQbf16", ScalarQuantizer::QT_bf16},
{"SQ8_direct_signed", ScalarQuantizer::QT_8bit_direct_signed},
{"SQ8_direct", ScalarQuantizer::QT_8bit_direct},
};
const std::string sq_pattern = "(SQ4|SQ8|SQ6|SQfp16|SQbf16)";
const std::string sq_pattern =
"(SQ4|SQ8|SQ6|SQfp16|SQbf16|SQ8_direct_signed|SQ8_direct)";

std::map<std::string, AdditiveQuantizer::Search_type_t> aq_search_type = {
{"_Nfloat", AdditiveQuantizer::ST_norm_float},
Expand Down
31 changes: 22 additions & 9 deletions tests/test_index_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def test_parallel_mode(self):


class TestSQByte(unittest.TestCase):
def subtest_8bit_direct(self, metric_type, d):
def subtest_8bit_direct(self, metric_type, d, quantizer_type):
xt, xb, xq = get_dataset_2(d, 500, 1000, 30)

# rescale everything to get integer
Expand All @@ -324,16 +324,28 @@ def rescale(x):
x[x > 255] = 255
return x

xt = rescale(xt)
xb = rescale(xb)
xq = rescale(xq)
def rescale_signed(x):
x = np.floor((x - tmin) * 256 / (tmax - tmin))
x[x < 0] = 0
x[x > 255] = 255
x -= 128
return x

if quantizer_type == faiss.ScalarQuantizer.QT_8bit_direct_signed:
xt = rescale_signed(xt)
xb = rescale_signed(xb)
xq = rescale_signed(xq)
else:
xt = rescale(xt)
xb = rescale(xb)
xq = rescale(xq)

gt_index = faiss.IndexFlat(d, metric_type)
gt_index.add(xb)
Dref, Iref = gt_index.search(xq, 10)

index = faiss.IndexScalarQuantizer(
d, faiss.ScalarQuantizer.QT_8bit_direct, metric_type
d, quantizer_type, metric_type
)
index.add(xb)
D, I = index.search(xq, 10)
Expand All @@ -353,7 +365,7 @@ def rescale(x):
Dref, Iref = gt_index.search(xq, 10)

index = faiss.IndexIVFScalarQuantizer(
quantizer, d, nlist, faiss.ScalarQuantizer.QT_8bit_direct,
quantizer, d, nlist, quantizer_type,
metric_type
)
index.nprobe = 4
Expand All @@ -366,9 +378,10 @@ def rescale(x):
assert np.all(D == Dref)

def test_8bit_direct(self):
for d in 13, 16, 24:
for metric_type in faiss.METRIC_L2, faiss.METRIC_INNER_PRODUCT:
self.subtest_8bit_direct(metric_type, d)
for quantizer in faiss.ScalarQuantizer.QT_8bit_direct, faiss.ScalarQuantizer.QT_8bit_direct_signed:
for d in 13, 16, 24:
for metric_type in faiss.METRIC_L2, faiss.METRIC_INNER_PRODUCT:
self.subtest_8bit_direct(metric_type, d, quantizer)


class TestNNDescent(unittest.TestCase):
Expand Down

0 comments on commit 33c0ba5

Please sign in to comment.