From 8d08912453812a26c390b49ac5c2fbd43be3e986 Mon Sep 17 00:00:00 2001
From: Matthijs Douze
Date: Thu, 29 Aug 2019 13:44:08 +0200
Subject: [PATCH] Ondisk distributed index implementation (#930)

Adds the code for the distributed on-disk index
---
 benchs/distributed_ondisk/README.md            | 194 +++++++++
 benchs/distributed_ondisk/combined_index.py    | 187 ++++++++
 .../distributed_ondisk/distributed_kmeans.py   | 404 ++++++++++++++++++
 .../distributed_query_demo.py                  |  65 +++
 .../distributed_ondisk/make_index_vslice.py    | 112 +++++
 .../distributed_ondisk/make_trained_index.py   |  48 +++
 benchs/distributed_ondisk/merge_to_ondisk.py   |  91 ++++
 benchs/distributed_ondisk/rpc.py               | 245 +++++++++++
 benchs/distributed_ondisk/run_on_cluster.bash  | 259 +++++++++++
 benchs/distributed_ondisk/search_server.py     | 216 ++++++++++
 10 files changed, 1821 insertions(+)
 create mode 100644 benchs/distributed_ondisk/README.md
 create mode 100644 benchs/distributed_ondisk/combined_index.py
 create mode 100644 benchs/distributed_ondisk/distributed_kmeans.py
 create mode 100644 benchs/distributed_ondisk/distributed_query_demo.py
 create mode 100644 benchs/distributed_ondisk/make_index_vslice.py
 create mode 100644 benchs/distributed_ondisk/make_trained_index.py
 create mode 100644 benchs/distributed_ondisk/merge_to_ondisk.py
 create mode 100644 benchs/distributed_ondisk/rpc.py
 create mode 100644 benchs/distributed_ondisk/run_on_cluster.bash
 create mode 100644 benchs/distributed_ondisk/search_server.py

diff --git a/benchs/distributed_ondisk/README.md b/benchs/distributed_ondisk/README.md
new file mode 100644
index 0000000000..e1d91aed5b
--- /dev/null
+++ b/benchs/distributed_ondisk/README.md
@@ -0,0 +1,194 @@
+# Distributed on-disk index for 1T-scale datasets
+
+This is the code corresponding to the description in [Indexing 1T vectors](https://github.com/facebookresearch/faiss/wiki/Indexing-1T-vectors).
+All the code is in Python 3 (it is not compatible with Python 2).
+The current code uses the Deep1B dataset for demonstration purposes, but it can scale to datasets 1000x larger.
+To run it, download the Deep1B dataset as explained [here](../#getting-deep1b), and edit the paths to the dataset in the scripts.
+
+## Distributed k-means
+
+To cluster 500M vectors to 10M centroids, it is useful to have a distributed k-means implementation.
+The distribution simply consists of splitting the training vectors across machines (servers) and having them do the assignment.
+The master/client then synthesizes the results and updates the centroids.
+
+The distributed k-means implementation here is based on 3 files:
+
+- [`rpc.py`](rpc.py) is a very simple remote procedure call implementation based on sockets and pickle.
+It exposes the methods of an object on the server side so that they can be called from the client as if the object were local.
+
+- [`distributed_kmeans.py`](distributed_kmeans.py) contains the k-means implementation.
+The main loop of k-means is re-implemented in Python but closely follows the Faiss C++ implementation, and should not be significantly less efficient.
+It relies on a `DatasetAssign` object that does the assignment to centroids, which is the bulk of the computation (a minimal usage sketch is given after this list).
+The object can be a Faiss CPU index, a GPU index or a set of remote GPU or CPU indexes.
+
+- [`run_on_cluster.bash`](run_on_cluster.bash) contains the shell code to run the distributed k-means on a cluster.
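+
+A `DatasetAssign` only needs to implement `count()`, `dim()`, `get_subset()` and `assign_to()`.
+As a rough, local-only usage sketch (not part of the patch; the random training matrix and the parameters are only illustrative):
+
+```python
+import numpy as np
+from distributed_kmeans import DatasetAssign, kmeans
+
+xt = np.random.rand(100000, 96).astype('float32')  # stand-in for the real training vectors
+data = DatasetAssign(xt)                  # local, in-RAM implementation of the assignment interface
+centroids = kmeans(1000, data, niter=20)  # pure-Python main loop, mirrors the Faiss C++ k-means
+```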
+
+The distributed k-means works with a Python install that contains faiss and scipy (for sparse matrices).
+By default it clusters the training data of Deep1B; this can easily be changed to any file in fvecs, bvecs or npy format that contains the training set.
+The training vectors may be too large to fit in RAM, but they are memory-mapped so that should not be a problem.
+The file is also assumed to be accessible from all server machines, e.g. via a distributed file system.
+
+### Local tests
+
+Edit `distributed_kmeans.py` to point `testdata` to your local copy of the dataset.
+
+Then, 4 levels of sanity check can be run:
+```bash
+# reference Faiss C++ run
+python distributed_kmeans.py --test 0
+# using the Python implementation
+python distributed_kmeans.py --test 1
+# use the dispatch object (on local datasets)
+python distributed_kmeans.py --test 2
+# same, with GPUs
+python distributed_kmeans.py --test 3
+```
+The output should look like [this gist](https://gist.github.com/mdouze/ffa01fe666a9325761266fe55ead72ad).
+
+### Distributed sanity check
+
+To run the distributed k-means, `distributed_kmeans.py` has to be run both on the server side (`--server` option) and on the client side (`--client` option).
+Edit the top of `run_on_cluster.bash` to set the path of the data to cluster.
+
+Sanity checks can be run with
+```bash
+# non-distributed baseline
+bash run_on_cluster.bash test_kmeans_0
+# using all the machine's GPUs
+bash run_on_cluster.bash test_kmeans_1
+# distributed run, with one local server per GPU
+bash run_on_cluster.bash test_kmeans_2
+```
+The test `test_kmeans_2` simulates a distributed run on a single machine by starting one server process per GPU and connecting to the servers via the RPC protocol.
+The output should look like [this gist](https://gist.github.com/mdouze/5b2dc69b74579ecff04e1686a277d32e).
+
+### Distributed run
+
+The way the script can be distributed depends on the cluster's scheduling system.
+Here we use Slurm, but it should be relatively easy to adapt to any scheduler that can allocate a set of machines and start the same executable on all of them.
+
+The command
+```
+bash run_on_cluster.bash slurm_distributed_kmeans
+```
+asks Slurm for 5 machines with 4 GPUs each via the `srun` command.
+All 5 machines run the script with the `slurm_within_kmeans_server` option.
+They determine the number of servers and their own server id via the `SLURM_NPROCS` and `SLURM_PROCID` environment variables.
+
+All machines start `distributed_kmeans.py` in server mode for the slice of the dataset they are responsible for.
+
+In addition, machine #0 also starts the client.
+The client knows which machines are the servers via the `SLURM_JOB_NODELIST` environment variable.
+It connects to all servers and performs the clustering.
+
+The output should look like [this gist](https://gist.github.com/mdouze/8d25e89fb4af5093057cae0f917da6cd).
+
+### Run used for deep1B
+
+For the real run, we cluster 50M vectors to 1M centroids.
+This is just a matter of using as many machines / GPUs as possible and setting the output centroids file with the `--out filename` option.
+Then run
+```
+bash run_on_cluster.bash deep1b_clustering
+```
+
+The last lines of output read like:
+```
+ Iteration 19 (898.92 s, search 875.71 s): objective=1.33601e+07 imbalance=1.303 nsplit=0
+ 0: writing centroids to /checkpoint/matthijs/ondisk_distributed/1M_centroids.npy
+```
+
+This means that the total training time was 899s, of which 876s were used for computation.
+However, the computation time includes the I/O overhead to the assignment servers.
+In this implementation, the overhead of transmitting the data is non-negligible, and so is the centroid computation stage.
+This is due to the inefficient Python implementation and to the RPC protocol, which is not optimized for broadcast / gather (unlike MPI).
+Still, it is a simple implementation that should run on most clusters.
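+
+For reference, once the Slurm glue and argument parsing are stripped away, the two roles played by `distributed_kmeans.py` reduce to roughly the sketch below (not part of the patch; hostnames, port and file names are placeholders, not the values used in the run above):
+
+```python
+import numpy as np
+import rpc
+from distributed_kmeans import (AssignServer, DatasetAssign,
+                                DatasetAssignDispatch, kmeans)
+
+# --- on every server machine: expose one shard of the training set over RPC ---
+shard = np.load('my_shard.npy', mmap_mode='r')   # placeholder shard file
+rpc.run_server(lambda s: AssignServer(s, DatasetAssign(shard)), 12345, v6=False)
+
+# --- on the client: connect to all servers and drive the k-means iterations ---
+servers = [rpc.Client(host, 12345, v6=False) for host in ('server1', 'server2')]
+data = DatasetAssignDispatch(servers, True)      # dispatches assignment to all shards in parallel
+centroids = kmeans(1000000, data, niter=20)      # k = 1M centroids, as in the run above
+np.save('centroids.npy', centroids)
+```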
+
+## Making the trained index
+
+After the centroids are obtained, an empty trained index must be constructed.
+This is done by:
+
+- applying a pre-processing stage (a random rotation) to balance the dimensions of the vectors. This can be done after clustering; the centroids are simply rotated as well.
+
+- wrapping the centroids into an HNSW index to speed up the CPU-based assignment of vectors
+
+- training the 6-bit scalar quantizer used to encode the vectors
+
+This is performed by the script [`make_trained_index.py`](make_trained_index.py).
+
+## Building the index by slices
+
+We call the slices "vslices" as they are vertical slices of the big matrix (see the explanation in the wiki section [Split across database partitions](https://github.com/facebookresearch/faiss/wiki/Indexing-1T-vectors#split-across-database-partitions)).
+
+The script [make_index_vslice.py](make_index_vslice.py) makes an index for a subset of the vectors of the input data and stores it as an independent index.
+For Deep1B there are 200 slices of 5M vectors each.
+It can be run in a brute-force parallel fashion; there is no constraint on ordering.
+To run the script in parallel on a Slurm cluster, use:
+```
+bash run_on_cluster.bash make_index_vslices
+```
+For a real dataset, the data would be read from a DBMS.
+In that case, reading the data and indexing it in parallel is worthwhile because reading is very slow.
+
+## Splitting across inverted lists
+
+The 200 slices need to be merged together.
+This is done with the script [merge_to_ondisk.py](merge_to_ondisk.py), which memory-maps the 200 vertical slice indexes, extracts a subset of the inverted lists and writes them to a contiguous horizontal slice.
+We slice the inverted lists into 50 horizontal slices.
+This is run with
+```
+bash run_on_cluster.bash make_index_hslices
+```
+
+## Querying the index
+
+At this point the index is ready.
+The horizontal slices need to be loaded in the right order and combined into an index to be usable.
+This is done in the [combined_index.py](combined_index.py) script.
+It provides a `CombinedIndexDeep1B` object that contains an index object that can be searched.
+To test it, run:
+```
+python combined_index.py
+```
+The output should look like:
+```
+(faiss_1.5.2) matthijs@devfair0144:~/faiss_versions/faiss_1Tcode/faiss/benchs/distributed_ondisk$ python combined_index.py
+reading /checkpoint/matthijs/ondisk_distributed//hslices/slice49.faissindex
+loading empty index /checkpoint/matthijs/ondisk_distributed/trained.faissindex
+replace invlists
+loaded index of size 1000000000
+nprobe=1 1-recall@1=0.2904 t=12.35s
+nprobe=10 1-recall@1=0.6499 t=17.67s
+nprobe=100 1-recall@1=0.8673 t=29.23s
+nprobe=1000 1-recall@1=0.9132 t=129.58s
+```
+i.e. searching from disk is a lot slower than searching from RAM.
+
+## Distributed query
+
+To reduce the bandwidth required on the machine that does the queries, it is possible to split the search across several search servers.
+This way, only the final search results are returned to the main machine.
+
+The search client and server are implemented in [`search_server.py`](search_server.py).
+It can be used as a script to start a search server for `CombinedIndexDeep1B` or as a module to load the clients. + +The search servers can be started with +``` +bash run_on_cluster.bash run_search_servers +``` +(adjust to the number of servers that can be used). + +Then an example of search client is [`distributed_query_demo.py`](distributed_query_demo.py). +It connects to the servers and assigns subsets of inverted lists to visit to each of them. + +A typical output is [this gist](https://gist.github.com/mdouze/1585b9854a9a2437d71f2b2c3c05c7c5). +The number in MiB indicates the amount of data that is read from disk to perform the search. +In this case, the scale of the dataset is too small for the distributed search to have much impact, but on datasets > 10x larger, the difference becomes more significant. + +## Conclusion + +This code contains the core components to make an index that scales up to 1T vectors. +There are a few simplifications wrt. the index that was effectively used in [Indexing 1T vectors](https://github.com/facebookresearch/faiss/wiki/Indexing-1T-vectors). diff --git a/benchs/distributed_ondisk/combined_index.py b/benchs/distributed_ondisk/combined_index.py new file mode 100644 index 0000000000..50281c1668 --- /dev/null +++ b/benchs/distributed_ondisk/combined_index.py @@ -0,0 +1,187 @@ +import os +import faiss +import numpy as np + + +class CombinedIndex: + """ + combines a set of inverted lists into a hstack + masks part of those lists + adds these inverted lists to an empty index that contains + the info on how to perform searches + """ + + def __init__(self, invlist_fnames, empty_index_fname, + masked_index_fname=None): + + self.indexes = indexes = [] + ilv = faiss.InvertedListsPtrVector() + + for fname in invlist_fnames: + if os.path.exists(fname): + print('reading', fname, end='\r', flush=True) + index = faiss.read_index(fname) + indexes.append(index) + il = faiss.extract_index_ivf(index).invlists + else: + assert False + ilv.push_back(il) + print() + + self.big_il = faiss.VStackInvertedLists(ilv.size(), ilv.data()) + if masked_index_fname: + self.big_il_base = self.big_il + print('loading', masked_index_fname) + self.masked_index = faiss.read_index( + masked_index_fname, + faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY) + self.big_il = faiss.MaskedInvertedLists( + faiss.extract_index_ivf(self.masked_index).invlists, + self.big_il_base) + + print('loading empty index', empty_index_fname) + self.index = faiss.read_index(empty_index_fname) + ntotal = self.big_il.compute_ntotal() + + print('replace invlists') + index_ivf = faiss.extract_index_ivf(self.index) + index_ivf.replace_invlists(self.big_il, False) + index_ivf.ntotal = self.index.ntotal = ntotal + index_ivf.parallel_mode = 1 # seems reasonable to do this all the time + + quantizer = faiss.downcast_index(index_ivf.quantizer) + quantizer.hnsw.efSearch = 1024 + + ############################################################ + # Expose fields and functions of the index as methods so that they + # can be called by RPC + + def search(self, x, k): + return self.index.search(x, k) + + def range_search(self, x, radius): + return self.index.range_search(x, radius) + + def transform_and_assign(self, xq): + index = self.index + + if isinstance(index, faiss.IndexPreTransform): + assert index.chain.size() == 1 + vt = index.chain.at(0) + xq = vt.apply_py(xq) + + # perform quantization + index_ivf = faiss.extract_index_ivf(index) + quantizer = index_ivf.quantizer + coarse_dis, list_nos = quantizer.search(xq, 
index_ivf.nprobe) + return xq, list_nos, coarse_dis + + + def ivf_search_preassigned(self, xq, list_nos, coarse_dis, k): + index_ivf = faiss.extract_index_ivf(self.index) + n, d = xq.shape + assert d == index_ivf.d + n2, d2 = list_nos.shape + assert list_nos.shape == coarse_dis.shape + assert n2 == n + assert d2 == index_ivf.nprobe + D = np.empty((n, k), dtype='float32') + I = np.empty((n, k), dtype='int64') + index_ivf.search_preassigned( + n, faiss.swig_ptr(xq), k, + faiss.swig_ptr(list_nos), faiss.swig_ptr(coarse_dis), + faiss.swig_ptr(D), faiss.swig_ptr(I), False) + return D, I + + + def ivf_range_search_preassigned(self, xq, list_nos, coarse_dis, radius): + index_ivf = faiss.extract_index_ivf(self.index) + n, d = xq.shape + assert d == index_ivf.d + n2, d2 = list_nos.shape + assert list_nos.shape == coarse_dis.shape + assert n2 == n + assert d2 == index_ivf.nprobe + res = faiss.RangeSearchResult(n) + + index_ivf.range_search_preassigned( + n, faiss.swig_ptr(xq), radius, + faiss.swig_ptr(list_nos), faiss.swig_ptr(coarse_dis), + res) + + lims = faiss.rev_swig_ptr(res.lims, n + 1).copy() + nd = int(lims[-1]) + D = faiss.rev_swig_ptr(res.distances, nd).copy() + I = faiss.rev_swig_ptr(res.labels, nd).copy() + return lims, D, I + + def set_nprobe(self, nprobe): + index_ivf = faiss.extract_index_ivf(self.index) + index_ivf.nprobe = nprobe + + def set_parallel_mode(self, pm): + index_ivf = faiss.extract_index_ivf(self.index) + index_ivf.parallel_mode = pm + + def get_ntotal(self): + return self.index.ntotal + + def set_prefetch_nthread(self, nt): + for idx in self.indexes: + il = faiss.downcast_InvertedLists( + faiss.extract_index_ivf(idx).invlists) + il.prefetch_nthread + il.prefetch_nthread = nt + + def set_omp_num_threads(self, nt): + faiss.omp_set_num_threads(nt) + +class CombinedIndexDeep1B(CombinedIndex): + """ loads a CombinedIndex with the data from the big photodna index """ + + def __init__(self): + # set some paths + workdir = "/checkpoint/matthijs/ondisk_distributed/" + + # empty index with the proper quantizer + indexfname = workdir + 'trained.faissindex' + + # index that has some invlists that override the big one + masked_index_fname = None + invlist_fnames = [ + '%s/hslices/slice%d.faissindex' % (workdir, i) + for i in range(50) + ] + CombinedIndex.__init__(self, invlist_fnames, indexfname, masked_index_fname) + + +def ivecs_read(fname): + a = np.fromfile(fname, dtype='int32') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:].copy() + + +def fvecs_read(fname): + return ivecs_read(fname).view('float32') + + +if __name__ == '__main__': + import time + ci = CombinedIndexDeep1B() + print('loaded index of size ', ci.index.ntotal) + + deep1bdir = "/datasets01_101/simsearch/041218/deep1b/" + + xq = fvecs_read(deep1bdir + "deep1B_queries.fvecs") + gt_fname = deep1bdir + "deep1B_groundtruth.ivecs" + gt = ivecs_read(gt_fname) + + for nprobe in 1, 10, 100, 1000: + ci.set_nprobe(nprobe) + t0 = time.time() + D, I = ci.search(xq, 100) + t1 = time.time() + print('nprobe=%d 1-recall@1=%.4f t=%.2fs' % ( + nprobe, (I[:, 0] == gt[:, 0]).sum() / len(xq), + t1 - t0 + )) diff --git a/benchs/distributed_ondisk/distributed_kmeans.py b/benchs/distributed_ondisk/distributed_kmeans.py new file mode 100644 index 0000000000..7c6c2156b3 --- /dev/null +++ b/benchs/distributed_ondisk/distributed_kmeans.py @@ -0,0 +1,404 @@ +#! /usr/bin/env python3 +""" +Simple distributed kmeans implementation Relies on an abstraction +for the training matrix, that can be sharded over several machines. 
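+
+The same script can run either as an assignment server that exposes a
+shard of the training data over RPC (--server option) or as the client
+that drives the k-means iterations (--client option); see main() below.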
+""" + +import faiss +import time +import numpy as np +import sys +import pdb +import argparse + +from scipy.sparse import csc_matrix + +from multiprocessing.dummy import Pool as ThreadPool + +import rpc + + + + +class DatasetAssign: + """Wrapper for a matrix that offers a function to assign the vectors + to centroids. All other implementations offer the same interface""" + + def __init__(self, x): + self.x = np.ascontiguousarray(x, dtype='float32') + + def count(self): + return self.x.shape[0] + + def dim(self): + return self.x.shape[1] + + def get_subset(self, indices): + return self.x[indices] + + def perform_search(self, centroids): + index = faiss.IndexFlatL2(self.x.shape[1]) + index.add(centroids) + return index.search(self.x, 1) + + def assign_to(self, centroids, weights=None): + D, I = self.perform_search(centroids) + + I = I.ravel() + D = D.ravel() + n = len(self.x) + if weights is None: + weights = np.ones(n, dtype='float32') + nc = len(centroids) + m = csc_matrix((weights, I, np.arange(n + 1)), + shape=(nc, n)) + sum_per_centroid = m * self.x + + return I, D, sum_per_centroid + + +class DatasetAssignGPU(DatasetAssign): + """ GPU version of the previous """ + + def __init__(self, x, gpu_id, verbose=False): + DatasetAssign.__init__(self, x) + index = faiss.IndexFlatL2(x.shape[1]) + if gpu_id >= 0: + self.index = faiss.index_cpu_to_gpu( + faiss.StandardGpuResources(), + gpu_id, index) + else: + # -1 -> assign to all GPUs + self.index = faiss.index_cpu_to_all_gpus(index) + + + def perform_search(self, centroids): + self.index.reset() + self.index.add(centroids) + return self.index.search(self.x, 1) + + +class DatasetAssignDispatch: + """dispatches to several other DatasetAssigns and combines the + results""" + + def __init__(self, xes, in_parallel): + self.xes = xes + self.d = xes[0].dim() + if not in_parallel: + self.imap = map + else: + self.pool = ThreadPool(len(self.xes)) + self.imap = self.pool.imap + self.sizes = list(map(lambda x: x.count(), self.xes)) + self.cs = np.cumsum([0] + self.sizes) + + def count(self): + return self.cs[-1] + + def dim(self): + return self.d + + def get_subset(self, indices): + res = np.zeros((len(indices), self.d), dtype='float32') + nos = np.searchsorted(self.cs[1:], indices, side='right') + + def handle(i): + mask = nos == i + sub_indices = indices[mask] - self.cs[i] + subset = self.xes[i].get_subset(sub_indices) + res[mask] = subset + + list(self.imap(handle, range(len(self.xes)))) + return res + + def assign_to(self, centroids, weights=None): + src = self.imap( + lambda x: x.assign_to(centroids, weights), + self.xes + ) + I = [] + D = [] + sum_per_centroid = None + for Ii, Di, sum_per_centroid_i in src: + I.append(Ii) + D.append(Di) + if sum_per_centroid is None: + sum_per_centroid = sum_per_centroid_i + else: + sum_per_centroid += sum_per_centroid_i + return np.hstack(I), np.hstack(D), sum_per_centroid + + +def imbalance_factor(k , assign): + return faiss.imbalance_factor(len(assign), k, faiss.swig_ptr(assign)) + + +def reassign_centroids(hassign, centroids, rs=None): + """ reassign centroids when some of them collapse """ + if rs is None: + rs = np.random + k, d = centroids.shape + nsplit = 0 + empty_cents = np.where(hassign == 0)[0] + + if empty_cents.size == 0: + return 0 + + fac = np.ones(d) + fac[::2] += 1 / 1024. + fac[1::2] -= 1 / 1024. 
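+    # fac is a small per-dimension perturbation (1 +/- 1/1024, alternating):
+    # the empty centroid ci is re-seeded as centroids[cj] * fac and the
+    # overloaded centroid cj becomes centroids[cj] / fac, i.e. two slightly
+    # shifted copies of the same centroid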
+ + # this is a single pass unless there are more than k/2 + # empty centroids + while empty_cents.size > 0: + # choose which centroids to split + probas = hassign.astype('float') - 1 + probas[probas < 0] = 0 + probas /= probas.sum() + nnz = (probas > 0).sum() + + nreplace = min(nnz, empty_cents.size) + cjs = rs.choice(k, size=nreplace, p=probas) + + for ci, cj in zip(empty_cents[:nreplace], cjs): + + c = centroids[cj] + centroids[ci] = c * fac + centroids[cj] = c / fac + + hassign[ci] = hassign[cj] // 2 + hassign[cj] -= hassign[ci] + nsplit += 1 + + empty_cents = empty_cents[nreplace:] + + return nsplit + + +def kmeans(k, data, niter=25, seed=1234, checkpoint=None): + """Pure python kmeans implementation. Follows the Faiss C++ version + quite closely, but takes a DatasetAssign instead of a training data + matrix. Also redo is not implemented. """ + n, d = data.count(), data.dim() + + print(("Clustering %d points in %dD to %d clusters, " + + "%d iterations seed %d") % (n, d, k, niter, seed)) + + rs = np.random.RandomState(seed) + print("preproc...") + t0 = time.time() + # initialization + perm = rs.choice(n, size=k, replace=False) + centroids = data.get_subset(perm) + + print(" done") + t_search_tot = 0 + obj = [] + for i in range(niter): + t0s = time.time() + + print('assigning', end='\r', flush=True) + assign, D, sums = data.assign_to(centroids) + + print('compute centroids', end='\r', flush=True) + + # pdb.set_trace() + + t_search_tot += time.time() - t0s; + + err = D.sum() + obj.append(err) + + hassign = np.bincount(assign, minlength=k) + + fac = hassign.reshape(-1, 1).astype('float32') + fac[fac == 0] = 1 # quiet warning + + centroids = sums / fac + + nsplit = reassign_centroids(hassign, centroids, rs) + + print((" Iteration %d (%.2f s, search %.2f s): " + "objective=%g imbalance=%.3f nsplit=%d") % ( + i, (time.time() - t0), t_search_tot, + err, imbalance_factor (k, assign), + nsplit) + ) + + if checkpoint is not None: + print('storing centroids in', checkpoint) + np.save(checkpoint, centroids) + + return centroids + + +class AssignServer(rpc.Server): + """ Assign version that can be exposed via RPC """ + + def __init__(self, s, assign, log_prefix=''): + rpc.Server.__init__(self, s, log_prefix=log_prefix) + self.assign = assign + + def __getattr__(self, f): + return getattr(self.assign, f) + + + +def bvecs_mmap(fname): + x = np.memmap(fname, dtype='uint8', mode='r') + d = x[:4].view('int32')[0] + return x.reshape(-1, d + 4)[:, 4:] + + +def ivecs_mmap(fname): + a = np.memmap(fname, dtype='int32', mode='r') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:] + +def fvecs_mmap(fname): + return ivecs_mmap(fname).view('float32') + + +def do_test(todo): + testdata = '/datasets01_101/simsearch/041218/bigann/bigann_learn.bvecs' + + x = bvecs_mmap(testdata) + + # bad distribution to stress-test split code + xx = x[:100000].copy() + xx[:50000] = x[0] + + todo = sys.argv[1:] + + if "0" in todo: + # reference C++ run + km = faiss.Kmeans(x.shape[1], 1000, niter=20, verbose=True) + km.train(xx.astype('float32')) + + if "1" in todo: + # using the Faiss c++ implementation + data = DatasetAssign(xx) + kmeans(1000, data, 20) + + if "2" in todo: + # use the dispatch object (on local datasets) + data = DatasetAssignDispatch([ + DatasetAssign(xx[20000 * i : 20000 * (i + 1)]) + for i in range(5) + ], False + ) + kmeans(1000, data, 20) + + if "3" in todo: + # same, with GPU + ngpu = faiss.get_num_gpus() + print('using %d GPUs' % ngpu) + data = DatasetAssignDispatch([ + DatasetAssignGPU(xx[100000 * i // ngpu: 
100000 * (i + 1) // ngpu], i) + for i in range(ngpu) + ], True + ) + kmeans(1000, data, 20) + + +def main(): + parser = argparse.ArgumentParser() + + def aa(*args, **kwargs): + group.add_argument(*args, **kwargs) + + group = parser.add_argument_group('general options') + aa('--test', default='', help='perform tests (comma-separated numbers)') + + aa('--k', default=0, type=int, help='nb centroids') + aa('--seed', default=1234, type=int, help='random seed') + aa('--niter', default=20, type=int, help='nb iterations') + aa('--gpu', default=-2, type=int, help='GPU to use (-2:none, -1: all)') + + group = parser.add_argument_group('I/O options') + aa('--indata', default='', + help='data file to load (supported formats fvecs, bvecs, npy') + aa('--i0', default=0, type=int, help='first vector to keep') + aa('--i1', default=-1, type=int, help='last vec to keep + 1') + aa('--out', default='', help='file to store centroids') + aa('--store_each_iteration', default=False, action='store_true', + help='store centroid checkpoints') + + group = parser.add_argument_group('server options') + aa('--server', action='store_true', default=False, help='run server') + aa('--port', default=12345, type=int, help='server port') + aa('--when_ready', default=None, help='store host:port to this file when ready') + aa('--ipv4', default=False, action='store_true', help='force ipv4') + + group = parser.add_argument_group('client options') + aa('--client', action='store_true', default=False, help='run client') + aa('--servers', default='', help='list of server:port separated by spaces') + + args = parser.parse_args() + + if args.test: + do_test(args.test.split(',')) + return + + # prepare data matrix (either local or remote) + if args.indata: + print('loading ', args.indata) + if args.indata.endswith('.bvecs'): + x = bvecs_mmap(args.indata) + elif args.indata.endswith('.fvecs'): + x = fvecs_mmap(args.indata) + elif args.indata.endswith('.npy'): + x = np.load(args.indata, mmap_mode='r') + else: + assert False + + if args.i1 == -1: + args.i1 = len(x) + x = x[args.i0:args.i1] + if args.gpu == -2: + data = DatasetAssign(x) + else: + print('moving to GPU') + data = DatasetAssignGPU(x, args.gpu) + + elif args.client: + print('connecting to servers') + + def connect_client(hostport): + host, port = hostport.split(':') + port = int(port) + print('connecting %s:%d' % (host, port)) + client = rpc.Client(host, port, v6=not args.ipv4) + print('client %s:%d ready' % (host, port)) + return client + + hostports = args.servers.strip().split(' ') + # pool = ThreadPool(len(hostports)) + + data = DatasetAssignDispatch( + list(map(connect_client, hostports)), + True + ) + else: + assert False + + if args.server: + print('starting server') + log_prefix = f"{rpc.socket.gethostname()}:{args.port}" + rpc.run_server( + lambda s: AssignServer(s, data, log_prefix=log_prefix), + args.port, report_to_file=args.when_ready, + v6=not args.ipv4) + + else: + print('running kmeans') + centroids = kmeans(args.k, data, niter=args.niter, seed=args.seed, + checkpoint=args.out if args.store_each_iteration else None) + if args.out != '': + print('writing centroids to', args.out) + np.save(args.out, centroids) + + +if __name__ == '__main__': + main() diff --git a/benchs/distributed_ondisk/distributed_query_demo.py b/benchs/distributed_ondisk/distributed_query_demo.py new file mode 100644 index 0000000000..2e4c4e911d --- /dev/null +++ b/benchs/distributed_ondisk/distributed_query_demo.py @@ -0,0 +1,65 @@ +import os +import faiss +import numpy as np +import time 
+import rpc +import sys + +import combined_index +import search_server + +hostnames = sys.argv[1:] + +print("Load local index") +ci = combined_index.CombinedIndexDeep1B() + +print("connect to clients") +clients = [] +for host in hostnames: + client = rpc.Client(host, 12012, v6=False) + clients.append(client) + +# check if all servers respond +print("sizes seen by servers:", [cl.get_ntotal() for cl in clients]) + + +# aggregate all clients into a one that uses them all for speed +# note that it also requires a local index ci +sindex = search_server.SplitPerListIndex(ci, clients) +sindex.verbose = True + +# set reasonable parameters +ci.set_parallel_mode(1) +ci.set_prefetch_nthread(0) +ci.set_omp_num_threads(64) + +# initialize params +sindex.set_parallel_mode(1) +sindex.set_prefetch_nthread(0) +sindex.set_omp_num_threads(64) + +def ivecs_read(fname): + a = np.fromfile(fname, dtype='int32') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:].copy() + +def fvecs_read(fname): + return ivecs_read(fname).view('float32') + + +deep1bdir = "/datasets01_101/simsearch/041218/deep1b/" + +xq = fvecs_read(deep1bdir + "deep1B_queries.fvecs") +gt_fname = deep1bdir + "deep1B_groundtruth.ivecs" +gt = ivecs_read(gt_fname) + + +for nprobe in 1, 10, 100, 1000: + sindex.set_nprobe(nprobe) + t0 = time.time() + D, I = sindex.search(xq, 100) + t1 = time.time() + print('nprobe=%d 1-recall@1=%.4f t=%.2fs' % ( + nprobe, (I[:, 0] == gt[:, 0]).sum() / len(xq), + t1 - t0 + )) diff --git a/benchs/distributed_ondisk/make_index_vslice.py b/benchs/distributed_ondisk/make_index_vslice.py new file mode 100644 index 0000000000..dfd4e92b8d --- /dev/null +++ b/benchs/distributed_ondisk/make_index_vslice.py @@ -0,0 +1,112 @@ +import os +import time +import numpy as np +import faiss +import argparse +from multiprocessing.dummy import Pool as ThreadPool + +def ivecs_mmap(fname): + a = np.memmap(fname, dtype='int32', mode='r') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:] + +def fvecs_mmap(fname): + return ivecs_mmap(fname).view('float32') + + +def produce_batches(args): + + x = fvecs_mmap(args.input) + + if args.i1 == -1: + args.i1 = len(x) + + print("Iterating on vectors %d:%d from %s by batches of size %d" % ( + args.i0, args.i1, args.input, args.bs)) + + for j0 in range(args.i0, args.i1, args.bs): + j1 = min(j0 + args.bs, args.i1) + yield np.arange(j0, j1), x[j0:j1] + + +def rate_limited_iter(l): + 'a thread pre-processes the next element' + pool = ThreadPool(1) + res = None + + def next_or_None(): + try: + return next(l) + except StopIteration: + return None + + while True: + res_next = pool.apply_async(next_or_None) + if res is not None: + res = res.get() + if res is None: + return + yield res + res = res_next + +deep1bdir = "/datasets01_101/simsearch/041218/deep1b/" +workdir = "/checkpoint/matthijs/ondisk_distributed/" + +def main(): + parser = argparse.ArgumentParser( + description='make index for a subset of the data') + + def aa(*args, **kwargs): + group.add_argument(*args, **kwargs) + + group = parser.add_argument_group('index type') + aa('--inputindex', + default=workdir + 'trained.faissindex', + help='empty input index to fill in') + aa('--nt', default=-1, type=int, help='nb of openmp threads to use') + + group = parser.add_argument_group('db options') + aa('--input', default=deep1bdir + "base.fvecs") + aa('--bs', default=2**18, type=int, + help='batch size for db access') + aa('--i0', default=0, type=int, help='lower bound to index') + aa('--i1', default=-1, type=int, help='upper bound of vectors to index') + + group = 
parser.add_argument_group('output') + aa('-o', default='/tmp/x', help='output index') + aa('--keepquantizer', default=False, action='store_true', + help='by default we remove the data from the quantizer to save space') + + args = parser.parse_args() + print('args=', args) + + print('start accessing data') + src = produce_batches(args) + + print('loading index', args.inputindex) + index = faiss.read_index(args.inputindex) + + if args.nt != -1: + faiss.omp_set_num_threads(args.nt) + + t0 = time.time() + ntot = 0 + for ids, x in rate_limited_iter(src): + print('add %d:%d (%.3f s)' % (ntot, ntot + ids.size, time.time() - t0)) + index.add_with_ids(np.ascontiguousarray(x, dtype='float32'), ids) + ntot += ids.size + + index_ivf = faiss.extract_index_ivf(index) + print('invlists stats: imbalance %.3f' % index_ivf.invlists.imbalance_factor()) + index_ivf.invlists.print_stats() + + if not args.keepquantizer: + print('resetting quantizer content') + index_ivf = faiss.extract_index_ivf(index) + index_ivf.quantizer.reset() + + print('store output', args.o) + faiss.write_index(index, args.o) + +if __name__ == '__main__': + main() diff --git a/benchs/distributed_ondisk/make_trained_index.py b/benchs/distributed_ondisk/make_trained_index.py new file mode 100644 index 0000000000..345a40d879 --- /dev/null +++ b/benchs/distributed_ondisk/make_trained_index.py @@ -0,0 +1,48 @@ + +import numpy as np +import faiss + +deep1bdir = "/datasets01_101/simsearch/041218/deep1b/" +workdir = "/checkpoint/matthijs/ondisk_distributed/" + + +print('Load centroids') +centroids = np.load(workdir + '1M_centroids.npy') +ncent, d = centroids.shape + + +print('apply random rotation') +rrot = faiss.RandomRotationMatrix(d, d) +rrot.init(1234) +centroids = rrot.apply_py(centroids) + +print('make HNSW index as quantizer') +quantizer = faiss.IndexHNSWFlat(d, 32) +quantizer.hnsw.efSearch = 1024 +quantizer.hnsw.efConstruction = 200 +quantizer.add(centroids) + +print('build index') +index = faiss.IndexPreTransform( + rrot, + faiss.IndexIVFScalarQuantizer( + quantizer, d, ncent, faiss.ScalarQuantizer.QT_6bit + ) + ) + +def ivecs_mmap(fname): + a = np.memmap(fname, dtype='int32', mode='r') + d = a[0] + return a.reshape(-1, d + 1)[:, 1:] + +def fvecs_mmap(fname): + return ivecs_mmap(fname).view('float32') + + +print('finish training index') +xt = fvecs_mmap(deep1bdir + 'learn.fvecs') +xt = np.ascontiguousarray(xt[:256 * 1000], dtype='float32') +index.train(xt) + +print('write output') +faiss.write_index(index, workdir + 'trained.faissindex') diff --git a/benchs/distributed_ondisk/merge_to_ondisk.py b/benchs/distributed_ondisk/merge_to_ondisk.py new file mode 100644 index 0000000000..70119b293a --- /dev/null +++ b/benchs/distributed_ondisk/merge_to_ondisk.py @@ -0,0 +1,91 @@ +import os +import faiss +import argparse +from multiprocessing.dummy import Pool as ThreadPool + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + + parser.add_argument('--inputs', nargs='*', required=True, + help='input indexes to merge') + parser.add_argument('--l0', type=int, default=0) + parser.add_argument('--l1', type=int, default=-1) + + parser.add_argument('--nt', default=-1, + help='nb threads') + + parser.add_argument('--output', required=True, + help='output index filename') + parser.add_argument('--outputIL', + help='output invfile filename') + + args = parser.parse_args() + + if args.nt != -1: + print('set nb of threads to', args.nt) + + + ils = faiss.InvertedListsPtrVector() + ils_dont_dealloc = [] + + pool = ThreadPool(20) + + def 
load_index(fname): + print("loading", fname) + try: + index = faiss.read_index(fname, faiss.IO_FLAG_MMAP | faiss.IO_FLAG_READ_ONLY) + except RuntimeError as e: + print('could not load %s: %s' % (fname, e)) + return fname, None + + print(" %d entries" % index.ntotal) + return fname, index + + index0 = None + + for fname, index in pool.imap(load_index, args.inputs): + if index is None: + continue + index_ivf = faiss.extract_index_ivf(index) + il = faiss.downcast_InvertedLists(index_ivf.invlists) + index_ivf.invlists = None + il.this.own() + ils_dont_dealloc.append(il) + if (args.l0, args.l1) != (0, -1): + print('restricting to lists %d:%d' % (args.l0, args.l1)) + # il = faiss.SliceInvertedLists(il, args.l0, args.l1) + + il.crop_invlists(args.l0, args.l1) + ils_dont_dealloc.append(il) + ils.push_back(il) + + if index0 is None: + index0 = index + + print("loaded %d invlists" % ils.size()) + + if not args.outputIL: + args.outputIL = args.output + '_invlists' + + il0 = ils.at(0) + + il = faiss.OnDiskInvertedLists( + il0.nlist, il0.code_size, + args.outputIL) + + print("perform merge") + + ntotal = il.merge_from(ils.data(), ils.size(), True) + + print("swap into index0") + + index0_ivf = faiss.extract_index_ivf(index0) + index0_ivf.nlist = il0.nlist + index0_ivf.ntotal = index0.ntotal = ntotal + index0_ivf.invlists = il + index0_ivf.own_invlists = False + + print("write", args.output) + + faiss.write_index(index0, args.output) diff --git a/benchs/distributed_ondisk/rpc.py b/benchs/distributed_ondisk/rpc.py new file mode 100644 index 0000000000..c3ccfaf5e9 --- /dev/null +++ b/benchs/distributed_ondisk/rpc.py @@ -0,0 +1,245 @@ +""" +Simplistic RPC implementation. +Exposes all functions of a Server object. + +Uses pickle for serialization and the socket interface. +""" + +import os,pdb,pickle,time,errno,sys,_thread,traceback,socket,threading,gc + + +# default +PORT=12032 + + +######################################################################### +# simple I/O functions + + + +def inline_send_handle(f, conn): + st = os.fstat(f.fileno()) + size = st.st_size + pickle.dump(size, conn) + conn.write(f.read(size)) + +def inline_send_string(s, conn): + size = len(s) + pickle.dump(size, conn) + conn.write(s) + + +class FileSock: + " wraps a socket so that it is usable by pickle/cPickle " + + def __init__(self,sock): + self.sock = sock + self.nr=0 + + def write(self, buf): + # print("sending %d bytes"%len(buf)) + #self.sock.sendall(buf) + # print("...done") + bs = 512 * 1024 + ns = 0 + while ns < len(buf): + sent = self.sock.send(buf[ns:ns + bs]) + ns += sent + + + def read(self,bs=512*1024): + #if self.nr==10000: pdb.set_trace() + self.nr+=1 + # print("read bs=%d"%bs) + b = [] + nb = 0 + while len(b) $workdir/vslices/slice$i.bash < $workdir/hslices/slice$i.bash <