Skip to content

Commit

Permalink
Merge branch 'main' into amx
Browse files Browse the repository at this point in the history
  • Loading branch information
mengdilin authored Jul 18, 2024
2 parents a3e2ccb + 749163e commit e59deb8
Show file tree
Hide file tree
Showing 24 changed files with 1,746 additions and 286 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/autoclose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Workflow that uses actions/stale to flag, and later close, inactive issues
# carrying the "autoclose" label. Pull requests are left alone.
name: Close inactive issues

on:
  schedule:
    # every day at 01:30 (GitHub Actions evaluates cron schedules in UTC)
    - cron: "30 1 * * *"

jobs:
  close-issues:
    runs-on: ubuntu-latest
    permissions:
      issues: write
      pull-requests: write
    steps:
      - uses: actions/stale@v5
        with:
          repo-token: ${{ secrets.GITHUB_TOKEN }}
          # only issues labelled "autoclose" are considered; the label is
          # dropped again once the issue shows activity
          only-labels: autoclose
          labels-to-remove-when-unstale: autoclose
          # issues: 7 days without activity -> "stale", 7 more days -> closed
          days-before-issue-stale: 7
          days-before-issue-close: 7
          stale-issue-label: "stale"
          stale-issue-message: "This issue is stale because it has been open for 7 days with no activity."
          close-issue-message: "This issue was closed because it has been inactive for 7 days since being marked as stale."
          # pull requests: -1 is the actions/stale setting for "never"
          days-before-pr-stale: -1
          days-before-pr-close: -1
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ jobs:
- uses: ./.github/actions/build_cmake
with:
opt_level: avx512

6 changes: 6 additions & 0 deletions contrib/inspect_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ def get_flat_data(index):
return xb.reshape(index.ntotal, index.d)


def get_flat_codes(index_flat):
    """ return the codes stored in an IndexFlatCodes as an array of shape
    (ntotal, code_size), i.e. one row per stored vector """
    flat = faiss.vector_to_array(index_flat.codes)
    n_rows, row_size = index_flat.ntotal, index_flat.code_size
    return flat.reshape(n_rows, row_size)


def get_NSG_neighbors(nsg):
""" get the neighbor list for the vectors stored in the NSG structure, as
a N-by-K matrix of indices """
Expand Down
77 changes: 77 additions & 0 deletions demos/demo_qinco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
This demonstrates how to reproduce the QINCo paper results using the Faiss
QINCo implementation. The code loads the reference model because training
is not implemented in Faiss.
Prepare the data with
cd /tmp
# get the reference qinco code
git clone https://github.com/facebookresearch/Qinco.git
# get the data
wget https://dl.fbaipublicfiles.com/QINCo/datasets/bigann/bigann1M.bvecs
# get the model
wget https://dl.fbaipublicfiles.com/QINCo/models/bigann_8x8_L2.pt
"""

import numpy as np
from faiss.contrib.vecs_io import bvecs_mmap
import sys
import time
import torch
import faiss

# make sure pickle deserialization will work
sys.path.append("/tmp/Qinco")
import model_qinco

with torch.no_grad():

    # SECURITY NOTE: the checkpoint is a pickled Python object and unpickling
    # can execute arbitrary code, so only load files from a trusted source.
    # weights_only=False is spelled out because the file holds a complete
    # model object (not just tensors): PyTorch versions that default to
    # weights_only=True refuse to load such a checkpoint.
    qinco = torch.load("/tmp/bigann_8x8_L2.pt", weights_only=False)
    qinco.eval()
    # print(qinco)

    # Pin both libraries to a single thread so that the two timings below are
    # comparable. Set to False to measure the multi-threaded numbers.
    single_thread = True
    if single_thread:
        torch.set_num_threads(1)
        faiss.omp_set_num_threads(1)

    # first 1000 database vectors, rescaled the way the model expects
    x_base = bvecs_mmap("/tmp/bigann1M.bvecs")[:1000].astype('float32')
    x_scaled = torch.from_numpy(x_base) / qinco.db_scale

    # reference implementation; note that the timing covers encode + decode
    t0 = time.time()
    codes, _ = qinco.encode(x_scaled)
    x_decoded_scaled = qinco.decode(codes)
    print(f"Pytorch encode {time.time() - t0:.3f} s")
    # multi-thread: 1.13s, single-thread: 7.744

    # undo the scaling to compare against the original vectors
    x_decoded = x_decoded_scaled.numpy() * qinco.db_scale

    err = ((x_decoded - x_base) ** 2).sum(1).mean()
    print("MSE=", err)  # = 14211.956, near the L=2 result in Fig 4 of the paper

    # same round trip with the Faiss implementation built from the loaded
    # model; again the timing covers encode + decode
    qinco2 = faiss.QINCo(qinco)
    t0 = time.time()
    codes2 = qinco2.encode(faiss.Tensor2D(x_scaled))
    x_decoded2 = qinco2.decode(codes2).numpy() * qinco.db_scale
    print(f"Faiss encode {time.time() - t0:.3f} s")
    # multi-thread: 3.2s, single thread: 7.019

    # these tests don't work because there are outlier encodings
    # np.testing.assert_array_equal(codes.numpy(), codes2.numpy())
    # np.testing.assert_allclose(x_decoded, x_decoded2)

    # instead, tolerate up to 1% of differing code entries ...
    ndiff = (codes.numpy() != codes2.numpy()).sum() / codes.numel()
    assert ndiff < 0.01
    # ... and up to 1% of vectors whose two reconstructions differ by more
    # than 1e-5 in squared L2 distance
    ndiff = (((x_decoded - x_decoded2) ** 2).sum(1) > 1e-5).sum()
    assert ndiff / len(x_base) < 0.01

    err = ((x_decoded2 - x_base) ** 2).sum(1).mean()
    print("MSE=", err)  # = 14213.551
2 changes: 2 additions & 0 deletions faiss/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ set(FAISS_SRC
IndexScalarQuantizer.cpp
IndexShards.cpp
IndexShardsIVF.cpp
IndexNeuralNetCodec.cpp
MatrixStats.cpp
MetaIndexes.cpp
VectorTransform.cpp
Expand Down Expand Up @@ -81,6 +82,7 @@ set(FAISS_SRC
invlists/InvertedLists.cpp
invlists/InvertedListsIOHook.cpp
utils/Heap.cpp
utils/NeuralNet.cpp
utils/WorkerThread.cpp
utils/distances.cpp
utils/distances_simd.cpp
Expand Down
164 changes: 159 additions & 5 deletions faiss/IndexFlatCodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <faiss/impl/DistanceComputer.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/IDSelector.h>
#include <faiss/impl/ResultHandler.h>
#include <faiss/utils/extra_distances.h>

namespace faiss {

Expand Down Expand Up @@ -70,11 +72,6 @@ void IndexFlatCodes::reconstruct(idx_t key, float* recons) const {
reconstruct_n(key, 1, recons);
}

FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
const {
FAISS_THROW_MSG("not implemented");
}

void IndexFlatCodes::check_compatible_for_merge(const Index& otherIndex) const {
// minimal sanity checks
const IndexFlatCodes* other =
Expand Down Expand Up @@ -114,4 +111,161 @@ void IndexFlatCodes::permute_entries(const idx_t* perm) {
std::swap(codes, new_codes);
}

namespace {

template <class VD>
struct GenericFlatCodesDistanceComputer : FlatCodesDistanceComputer {
const IndexFlatCodes& codec;
const VD vd;
// temp buffers
std::vector<uint8_t> code_buffer;
std::vector<float> vec_buffer;
const float* query = nullptr;

GenericFlatCodesDistanceComputer(const IndexFlatCodes* codec, const VD& vd)
: FlatCodesDistanceComputer(codec->codes.data(), codec->code_size),
codec(*codec),
vd(vd),
code_buffer(codec->code_size * 4),
vec_buffer(codec->d * 4) {}

void set_query(const float* x) override {
query = x;
}

float operator()(idx_t i) override {
codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
return vd(query, vec_buffer.data());
}

float distance_to_code(const uint8_t* code) override {
codec.sa_decode(1, code, vec_buffer.data());
return vd(query, vec_buffer.data());
}

float symmetric_dis(idx_t i, idx_t j) override {
codec.sa_decode(1, codes + i * code_size, vec_buffer.data());
codec.sa_decode(1, codes + j * code_size, vec_buffer.data() + vd.d);
return vd(vec_buffer.data(), vec_buffer.data() + vd.d);
}

void distances_batch_4(
const idx_t idx0,
const idx_t idx1,
const idx_t idx2,
const idx_t idx3,
float& dis0,
float& dis1,
float& dis2,
float& dis3) override {
uint8_t* cp = code_buffer.data();
for (idx_t i : {idx0, idx1, idx2, idx3}) {
memcpy(cp, codes + i * code_size, code_size);
cp += code_size;
}
// potential benefit is if batch decoding is more efficient than 1 by 1
// decoding
codec.sa_decode(4, code_buffer.data(), vec_buffer.data());
dis0 = vd(query, vec_buffer.data());
dis1 = vd(query, vec_buffer.data() + vd.d);
dis2 = vd(query, vec_buffer.data() + 2 * vd.d);
dis3 = vd(query, vec_buffer.data() + 3 * vd.d);
}
};

struct Run_get_distance_computer {
using T = FlatCodesDistanceComputer*;

template <class VD>
FlatCodesDistanceComputer* f(const VD& vd, const IndexFlatCodes* codec) {
return new GenericFlatCodesDistanceComputer<VD>(codec, vd);
}
};

template <class BlockResultHandler>
struct Run_search_with_decompress {
using T = void;

template <class VectorDistance>
void f(VectorDistance& vd,
const IndexFlatCodes* index_ptr,
const float* xq,
BlockResultHandler& res) {
// Note that there seems to be a clang (?) bug that "sometimes" passes
// the const Index & parameters by value, so to be on the safe side,
// it's better to use pointers.
const IndexFlatCodes& index = *index_ptr;
size_t ntotal = index.ntotal;
using SingleResultHandler =
typename BlockResultHandler::SingleResultHandler;
using DC = GenericFlatCodesDistanceComputer<VectorDistance>;
#pragma omp parallel // if (res.nq > 100)
{
std::unique_ptr<DC> dc(new DC(&index, vd));
SingleResultHandler resi(res);
#pragma omp for
for (int64_t q = 0; q < res.nq; q++) {
resi.begin(q);
dc->set_query(xq + vd.d * q);
for (size_t i = 0; i < ntotal; i++) {
if (res.is_in_selection(i)) {
float dis = (*dc)(i);
resi.add_result(dis, i);
}
}
resi.end();
}
}
}
};

struct Run_search_with_decompress_res {
using T = void;

template <class ResultHandler>
void f(ResultHandler& res, const IndexFlatCodes* index, const float* xq) {
Run_search_with_decompress<ResultHandler> r;
dispatch_VectorDistance(
index->d,
index->metric_type,
index->metric_arg,
r,
index,
xq,
res);
}
};

} // anonymous namespace

/// Default distance computer: decodes codes with sa_decode and compares them
/// using the functor matching this index's metric_type / metric_arg.
FlatCodesDistanceComputer* IndexFlatCodes::get_FlatCodesDistanceComputer()
        const {
    Run_get_distance_computer factory;
    FlatCodesDistanceComputer* computer =
            dispatch_VectorDistance(d, metric_type, metric_arg, factory, this);
    return computer;
}

/// k-nearest-neighbor search implemented by decoding the stored codes.
/// When params is given, its selector restricts the entries considered.
void IndexFlatCodes::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    // no search parameters means no selector
    const IDSelector* sel = nullptr;
    if (params) {
        sel = params->sel;
    }
    Run_search_with_decompress_res searcher;
    dispatch_knn_ResultHandler(
            n, distances, labels, k, metric_type, sel, searcher, this, x);
}

/// Range search implemented by decoding the stored codes.
/// When params is given, its selector restricts the entries considered.
void IndexFlatCodes::range_search(
        idx_t n,
        const float* x,
        float radius,
        RangeSearchResult* result,
        const SearchParameters* params) const {
    // no search parameters means no selector
    const IDSelector* sel = nullptr;
    if (params) {
        sel = params->sel;
    }
    Run_search_with_decompress_res searcher;
    dispatch_range_ResultHandler(
            result, radius, metric_type, sel, searcher, this, x);
}

} // namespace faiss
23 changes: 20 additions & 3 deletions faiss/IndexFlatCodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
* LICENSE file in the root directory of this source tree.
*/

// -*- c++ -*-

#pragma once

#include <faiss/Index.h>
Expand Down Expand Up @@ -45,13 +43,32 @@ struct IndexFlatCodes : Index {
* different from the usual ones: the new ids are shifted */
size_t remove_ids(const IDSelector& sel) override;

/** a FlatCodesDistanceComputer offers a distance_to_code method */
/** a FlatCodesDistanceComputer offers a distance_to_code method
*
* The default implementation explicitly decodes the vector with sa_decode.
*/
virtual FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const;

DistanceComputer* get_distance_computer() const override {
return get_FlatCodesDistanceComputer();
}

/** Search implemented by decoding */
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
const SearchParameters* params = nullptr) const override;

void range_search(
idx_t n,
const float* x,
float radius,
RangeSearchResult* result,
const SearchParameters* params = nullptr) const override;

// returns a new instance of a CodePacker
CodePacker* get_CodePacker() const;

Expand Down
2 changes: 1 addition & 1 deletion faiss/IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ void Level1Quantizer::train_q1(
} else if (quantizer_trains_alone == 1) {
if (verbose)
printf("IVF quantizer trains alone...\n");
quantizer->train(n, x);
quantizer->verbose = verbose;
quantizer->train(n, x);
FAISS_THROW_IF_NOT_MSG(
quantizer->ntotal == nlist,
"nlist not consistent with quantizer size");
Expand Down
Loading

0 comments on commit e59deb8

Please sign in to comment.