From 3e75afa6a94567df88b2029ca562c2c9e0064a95 Mon Sep 17 00:00:00 2001
From: Sergey Plis
Date: Mon, 19 Feb 2024 18:15:40 -0500
Subject: [PATCH 1/8] redis version of the code. Word of the week: brpoplpush

---
 mindfultensors/redisloader.py | 184 ++++++++++++++++++++++++++++++++++
 1 file changed, 184 insertions(+)
 create mode 100644 mindfultensors/redisloader.py

diff --git a/mindfultensors/redisloader.py b/mindfultensors/redisloader.py
new file mode 100644
index 0000000..18fc567
--- /dev/null
+++ b/mindfultensors/redisloader.py
@@ -0,0 +1,184 @@
+from typing import Sized
+import pickle as pkl
+
+import numpy as np
+import torch
+import io
+from torch.utils.data import Dataset
+from torch.utils.data.sampler import Sampler
+from mindfultensors.gencoords import CoordsGenerator
+from redis import Redis
+
+def unit_interval_normalize(img):
+    """Unit interval preprocessing"""
+    img = (img - img.min()) / (img.max() - img.min())
+    return img
+
+
+def qnormalize(img, qmin=0.01, qmax=0.99):
+    """Unit interval preprocessing"""
+    img = (img - img.quantile(qmin)) / (img.quantile(qmax) - img.quantile(qmin))
+    return img
+
+
+class RedisDataset(Dataset):
+    """
+    A dataset for fetching batches of records from a Redis server
+    """
+
+    def __init__(
+        self,
+        indices,
+        transform,
+        dbkey,
+        normalize=unit_interval_normalize,
+        id="id",
+    ):
+        """Constructor
+
+        :param indices: a set of indices to be extracted from the collection
+        :param transform: a function to be applied to each extracted record
+        :param dbkey: the Redis key of the list that holds the records
+        :param normalize: a function to normalize each fetched input
+        :param id: the field to be used as an index. The `indices` are values of this field
+        :returns: an object of RedisDataset class
+        
+        """
+
+        self.indices = indices
+        self.transform = transform
+        self.Redis = None
+        self.normalize = normalize
+        self.id = id
+
+    def __len__(self):
+        return len(self.indices)
+
+
+    def __getitem__(self, batch):
+        # Fetch all samples for ids in the batch and where 'kind' is either
+        # data or labela s specified by the sample parameter
+
+        results = {}
+        for id in batch:
+            # Separate samples for this id
+
+            # Separate processing for each 'kind'
+            payload = pkl.loads(self.Redis.brpoplpush(self.dbkey, self.dbkey))
+            data = payload[0]
+            label = payload[1]
+
+            # Add to results
+            results[id] = {
+                "input": self.normalize(self.transform(data).float()),
+                "label": self.transform(label),
+            }
+
+        return results
+
+
+class RBatchSampler(Sampler):
+    """
+    A batch sampler from a random permutation. Used for generating indices for MongoDataset
+    """
+
+    data_source: Sized
+
+    def __init__(self, data_source, batch_size=1, seed=None):
+        """TODO describe function
+
+        :param data_source: a dataset of Dataset class
+        :param batch_size: number of samples in the batch (sample is an MRI split to 8 records)
+        :returns: an object of mBatchSampler class
+
+        """
+        self.batch_size = batch_size
+        self.data_source = data_source
+        self.data_size = len(self.data_source)
+        self.seed = seed
+
+    def __chunks__(self, l, n):
+        for i in range(0, len(l), n):
+            yield l[i : i + n]
+
+    def __iter__(self):
+        if self.seed is not None:
+            np.random.seed(self.seed)
+        return self.__chunks__(
+            np.random.permutation(self.data_size), self.batch_size
+        )
+
+    def __len__(self):
+        return (
+            self.data_size + self.batch_size - 1
+        ) // self.batch_size  # Number of batches
+
+
+def create_client(worker_id, redishost):
+    worker_info = torch.utils.data.get_worker_info()
+    dataset = worker_info.dataset
+    client = Redis(host=redishost)
+    dataset.Redis = client
+
+
+def mtransform(tensor_binary):
+    buffer = io.BytesIO(tensor_binary)
+    tensor = torch.load(buffer)
+    return tensor
+
+
+def mcollate(results, field=("input", "label")):
+    results = results[0]
+    # Assuming 'results' is your dictionary containing all the data
+    input_tensors = [results[id_][field[0]] for id_ in results.keys()]
+    label_tensors = [results[id_][field[1]] for id_ in results.keys()]
+    # Stack all input tensors into a single tensor
+    stacked_inputs = torch.stack(input_tensors)
+    # Stack all label tensors into a single tensor
+    stacked_labels = torch.stack(label_tensors)
+    return stacked_inputs.unsqueeze(1), stacked_labels.long()
+
+
+def collate_subcubes(results, coord_generator, samples=4):
+    data, labels = mcollate(results)
+    num_subjs = labels.shape[0]
+    data = data.squeeze(1)
+
+    batch_data = []
+    batch_labels = []
+
+    for i in range(num_subjs):
+        subcubes, sublabels = subcube_list(
+            data[i, :, :, :], labels[i, :, :, :], samples, coord_generator
+        )
+        batch_data.extend(subcubes)
+        batch_labels.extend(sublabels)
+
+    # Converting the list of tensors to a single tensor
+    batch_data = torch.stack(batch_data).unsqueeze(1)
+    batch_labels = torch.stack(batch_labels)
+
+    return batch_data, batch_labels
+
+
+def subcube_list(cube, labels, num, coords_generator):
+    subcubes = []
+    sublabels = []
+
+    for i in range(num):
+        coords = coords_generator.get_coordinates()
+        subcube = cube[
+            coords[0][0] : coords[0][1],
+            coords[1][0] : coords[1][1],
+            coords[2][0] : coords[2][1],
+        ]
+        sublabel = labels[
+            coords[0][0] : coords[0][1],
+            coords[1][0] : coords[1][1],
+            coords[2][0] : coords[2][1],
+        ]
+        subcubes.append(subcube)
+        sublabels.append(sublabel)
+
+    return subcubes, sublabels
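
[Editor's note: not part of the patch series] PATCH 1 introduces the classes but
does not show how they are meant to be wired together. A minimal sketch of the
intended usage, assuming a Redis server on localhost and a list under the
illustrative key "mydata" pre-filled with pickled pairs of torch-serialized
tensors; key name, index count, and batch size are assumptions. Two caveats
grounded in the series itself: self.dbkey is only stored on the instance from
PATCH 7 on, and create_client relies on torch.utils.data.get_worker_info(), so
the DataLoader must run with num_workers >= 1:

    from functools import partial

    from torch.utils.data import DataLoader
    from mindfultensors.redisloader import (
        RedisDataset,
        RBatchSampler,
        create_client,
        mcollate,
        mtransform,
    )

    # Each "index" only names a slot in the result dict; brpoplpush with the
    # same source and destination key pops the tail of the list and pushes it
    # back onto the head, so workers rotate through the data without draining it.
    dataset = RedisDataset(range(100), mtransform, "mydata")
    sampler = RBatchSampler(dataset, batch_size=4)
    loader = DataLoader(
        dataset,
        sampler=sampler,       # the sampler yields whole batches of indices
        collate_fn=mcollate,   # stacks the per-id dicts into two tensors
        num_workers=2,         # each worker opens its own Redis client
        worker_init_fn=partial(create_client, redishost="localhost"),
    )

    inputs, labels = next(iter(loader))  # (B, 1, ...) float, (B, ...) long
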
From caf2c7c158551200b28451287635cdf159cf8620 Mon Sep 17 00:00:00 2001
From: Sergey Plis
Date: Mon, 19 Feb 2024 18:15:59 -0500
Subject: [PATCH 2/8] made redis a dependency

---
 setup.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/setup.py b/setup.py
index f638c5f..664da61 100644
--- a/setup.py
+++ b/setup.py
@@ -13,6 +13,7 @@
         "numpy",
         "scipy >= 1.7",
         "pymongo >= 4.0",
+        "redis",
         "torch",
         ""],
 )
From 5a1cf03b53ad0dbd04c83d384053bcff05d149d1 Mon Sep 17 00:00:00 2001
From: Sergey Plis
Date: Tue, 20 Feb 2024 14:57:26 -0500
Subject: [PATCH 3/8] refactoring to avoid code duplication. Renamed the sampler

---
 mindfultensors/mongoloader.py | 142 ++++++---------------------------
 mindfultensors/redisloader.py | 145 ++++++----------------------------
 mindfultensors/utils.py       | 115 +++++++++++++++++++++++++++
 3 files changed, 163 insertions(+), 239 deletions(-)
 create mode 100644 mindfultensors/utils.py

diff --git a/mindfultensors/mongoloader.py b/mindfultensors/mongoloader.py
index 916663d..fdbe0ee 100644
--- a/mindfultensors/mongoloader.py
+++ b/mindfultensors/mongoloader.py
@@ -1,25 +1,28 @@
-from typing import Sized
-import ipdb
-
-import numpy as np
 from pymongo import MongoClient
-import torch
-import io
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, get_worker_info
 from torch.utils.data.sampler import Sampler
-from mindfultensors.gencoords import CoordsGenerator
-
-
-def unit_interval_normalize(img):
-    """Unit interval preprocessing"""
-    img = (img - img.min()) / (img.max() - img.min())
-    return img
 
-def qnormalize(img, qmin=0.01, qmax=0.99):
-    """Unit interval preprocessing"""
-    img = (img - img.quantile(qmin)) / (img.quantile(qmax) - img.quantile(qmin))
-    return img
+from .gencoords import CoordsGenerator
+from .utils import (
+    unit_interval_normalize,
+    qnormalize,
+    mtransform,
+    mcollate,
+    collate_subcubes,
+    subcube_list,
+    DBBatchSampler,
+)
+
+__all__ = [
+    "unit_interval_normalize",
+    "qnormalize",
+    "mtransform",
+    "mcollate",
+    "collate_subcubes",
+    "subcube_list",
+    "MongoDataset",
+    "DBBatchSampler",
+]
 
 
 class MongoDataset(Dataset):
@@ -109,43 +112,6 @@ def __getitem__(self, batch):
         return results
 
 
-class MBatchSampler(Sampler):
-    """
-    A batch sampler from a random permutation. Used for generating indices for MongoDataset
-    """
-
-    data_source: Sized
-
-    def __init__(self, data_source, batch_size=1, seed=None):
-        """TODO describe function
-
-        :param data_source: a dataset of Dataset class
-        :param batch_size: number of samples in the batch (sample is an MRI split to 8 records)
-        :returns: an object of mBatchSampler class
-
-        """
-        self.batch_size = batch_size
-        self.data_source = data_source
-        self.data_size = len(self.data_source)
-        self.seed = seed
-
-    def __chunks__(self, l, n):
-        for i in range(0, len(l), n):
-            yield l[i : i + n]
-
-    def __iter__(self):
-        if self.seed is not None:
-            np.random.seed(self.seed)
-        return self.__chunks__(
-            np.random.permutation(self.data_size), self.batch_size
-        )
-
-    def __len__(self):
-        return (
-            self.data_size + self.batch_size - 1
-        ) // self.batch_size  # Number of batches
-
-
 def name2collections(name: str, database):
     collection_bin = database[f"{name}.bin"]
     collection_meta = database[f"{name}.meta"]
@@ -153,70 +119,8 @@ def name2collections(name: str, database):
 
 
 def create_client(worker_id, dbname, colname, mongohost):
-    worker_info = torch.utils.data.get_worker_info()
+    worker_info = get_worker_info()
     dataset = worker_info.dataset
     client = MongoClient("mongodb://" + mongohost + ":27017")
     colbin, colmeta = name2collections(colname, client[dbname])
     dataset.collection = {"bin": colbin, "meta": colmeta}
-
-
-def mtransform(tensor_binary):
-    buffer = io.BytesIO(tensor_binary)
-    tensor = torch.load(buffer)
-    return tensor
-
-
-def mcollate(results, field=("input", "label")):
-    results = results[0]
-    # Assuming 'results' is your dictionary containing all the data
-    input_tensors = [results[id_][field[0]] for id_ in results.keys()]
-    label_tensors = [results[id_][field[1]] for id_ in results.keys()]
-    # Stack all input tensors into a single tensor
-    stacked_inputs = torch.stack(input_tensors)
-    # Stack all label tensors into a single tensor
-    stacked_labels = torch.stack(label_tensors)
-    return stacked_inputs.unsqueeze(1), stacked_labels.long()
-
-
-def collate_subcubes(results, coord_generator, samples=4):
-    data, labels = mcollate(results)
-    num_subjs = labels.shape[0]
-    data = data.squeeze(1)
-
-    batch_data = []
-    batch_labels = []
-
-    for i in range(num_subjs):
-        subcubes, sublabels = subcube_list(
-            data[i, :, :, :], labels[i, :, :, :], samples, coord_generator
-        )
-        batch_data.extend(subcubes)
-        batch_labels.extend(sublabels)
-
-    # Converting the list of tensors to a single tensor
-    batch_data = torch.stack(batch_data).unsqueeze(1)
-    batch_labels = torch.stack(batch_labels)
-
-    return batch_data, batch_labels
-
-
-def subcube_list(cube, labels, num, coords_generator):
-    subcubes = []
-    sublabels = []
-
-    for i in range(num):
-        coords = coords_generator.get_coordinates()
-        subcube = cube[
-            coords[0][0] : coords[0][1],
-            coords[1][0] : coords[1][1],
-            coords[2][0] : coords[2][1],
-        ]
-        sublabel = labels[
-            coords[0][0] : coords[0][1],
-            coords[1][0] : coords[1][1],
-            coords[2][0] : coords[2][1],
-        ]
-        subcubes.append(subcube)
-        sublabels.append(sublabel)
-
-    return subcubes, sublabels
diff --git a/mindfultensors/redisloader.py b/mindfultensors/redisloader.py
index 18fc567..b73b8c5 100644
--- a/mindfultensors/redisloader.py
+++ b/mindfultensors/redisloader.py
@@ -1,24 +1,30 @@
-from typing import Sized
 import pickle as pkl
-
-import numpy as np
-import torch
-import io
-from torch.utils.data import Dataset
-from torch.utils.data.sampler import Sampler
-from mindfultensors.gencoords import CoordsGenerator
 from redis import Redis
 
-def unit_interval_normalize(img):
-    """Unit interval preprocessing"""
-    img = (img - img.min()) / (img.max() - img.min())
-    return img
-
+from torch.utils.data import Dataset, get_worker_info
+from torch.utils.data.sampler import Sampler
 
-def qnormalize(img, qmin=0.01, qmax=0.99):
-    """Unit interval preprocessing"""
-    img = (img - img.quantile(qmin)) / (img.quantile(qmax) - img.quantile(qmin))
-    return img
+from .gencoords import CoordsGenerator
+from .utils import (
+    unit_interval_normalize,
+    qnormalize,
+    mtransform,
+    mcollate,
+    collate_subcubes,
+    subcube_list,
+    DBBatchSampler,
+)
+
+__all__ = [
+    "unit_interval_normalize",
+    "qnormalize",
+    "mtransform",
+    "mcollate",
+    "collate_subcubes",
+    "subcube_list",
+    "RedisDataset",
+    "DBBatchSampler",
+]
 
 
 class RedisDataset(Dataset):
@@ -42,7 +48,7 @@ def __init__(
         :param normalize: a function to normalize each fetched input
         :param id: the field to be used as an index. The `indices` are values of this field
         :returns: an object of RedisDataset class
-        
+
         """
 
         self.indices = indices
@@ -54,7 +60,6 @@ def __init__(
     def __len__(self):
         return len(self.indices)
 
-
     def __getitem__(self, batch):
         # Fetch all samples for ids in the batch and where 'kind' is either
         # data or labela s specified by the sample parameter
@@ -77,108 +82,8 @@ def __getitem__(self, batch):
         return results
 
 
-class RBatchSampler(Sampler):
-    """
-    A batch sampler from a random permutation. Used for generating indices for MongoDataset
-    """
-
-    data_source: Sized
-
-    def __init__(self, data_source, batch_size=1, seed=None):
-        """TODO describe function
-
-        :param data_source: a dataset of Dataset class
-        :param batch_size: number of samples in the batch (sample is an MRI split to 8 records)
-        :returns: an object of mBatchSampler class
-
-        """
-        self.batch_size = batch_size
-        self.data_source = data_source
-        self.data_size = len(self.data_source)
-        self.seed = seed
-
-    def __chunks__(self, l, n):
-        for i in range(0, len(l), n):
-            yield l[i : i + n]
-
-    def __iter__(self):
-        if self.seed is not None:
-            np.random.seed(self.seed)
-        return self.__chunks__(
-            np.random.permutation(self.data_size), self.batch_size
-        )
-
-    def __len__(self):
-        return (
-            self.data_size + self.batch_size - 1
-        ) // self.batch_size  # Number of batches
-
-
 def create_client(worker_id, redishost):
-    worker_info = torch.utils.data.get_worker_info()
+    worker_info = get_worker_info()
     dataset = worker_info.dataset
     client = Redis(host=redishost)
     dataset.Redis = client
-
-
-def mtransform(tensor_binary):
-    buffer = io.BytesIO(tensor_binary)
-    tensor = torch.load(buffer)
-    return tensor
-
-
-def mcollate(results, field=("input", "label")):
-    results = results[0]
-    # Assuming 'results' is your dictionary containing all the data
-    input_tensors = [results[id_][field[0]] for id_ in results.keys()]
-    label_tensors = [results[id_][field[1]] for id_ in results.keys()]
-    # Stack all input tensors into a single tensor
-    stacked_inputs = torch.stack(input_tensors)
-    # Stack all label tensors into a single tensor
-    stacked_labels = torch.stack(label_tensors)
-    return stacked_inputs.unsqueeze(1), stacked_labels.long()
-
-
-def collate_subcubes(results, coord_generator, samples=4):
-    data, labels = mcollate(results)
-    num_subjs = labels.shape[0]
-    data = data.squeeze(1)
-
-    batch_data = []
-    batch_labels = []
-
-    for i in range(num_subjs):
-        subcubes, sublabels = subcube_list(
-            data[i, :, :, :], labels[i, :, :, :], samples, coord_generator
-        )
-        batch_data.extend(subcubes)
-        batch_labels.extend(sublabels)
-
-    # Converting the list of tensors to a single tensor
-    batch_data = torch.stack(batch_data).unsqueeze(1)
-    batch_labels = torch.stack(batch_labels)
-
-    return batch_data, batch_labels
-
-
-def subcube_list(cube, labels, num, coords_generator):
-    subcubes = []
-    sublabels = []
-
-    for i in range(num):
-        coords = coords_generator.get_coordinates()
-        subcube = cube[
-            coords[0][0] : coords[0][1],
-            coords[1][0] : coords[1][1],
-            coords[2][0] : coords[2][1],
-        ]
-        sublabel = labels[
-            coords[0][0] : coords[0][1],
-            coords[1][0] : coords[1][1],
-            coords[2][0] : coords[2][1],
-        ]
-        subcubes.append(subcube)
-        sublabels.append(sublabel)
-
-    return subcubes, sublabels
diff --git a/mindfultensors/utils.py b/mindfultensors/utils.py
new file mode 100644
index 0000000..754bd23
--- /dev/null
+++ b/mindfultensors/utils.py
@@ -0,0 +1,115 @@
+import torch
+import io
+import numpy as np
+from typing import Sized
+
+
+def unit_interval_normalize(img):
+    """Unit interval preprocessing"""
+    img = (img - img.min()) / (img.max() - img.min())
+    return img
+
+
+def qnormalize(img, qmin=0.01, qmax=0.99):
+    """Unit interval preprocessing"""
+    img = (img - img.quantile(qmin)) / (img.quantile(qmax) - img.quantile(qmin))
+    return img
+
+
+def mtransform(tensor_binary):
+    buffer = io.BytesIO(tensor_binary)
+    tensor = torch.load(buffer)
+    return tensor
+
+
+def mcollate(results, field=("input", "label")):
+    results = results[0]
+    # Assuming 'results' is your dictionary containing all the data
+    input_tensors = [results[id_][field[0]] for id_ in results.keys()]
+    label_tensors = [results[id_][field[1]] for id_ in results.keys()]
+    # Stack all input tensors into a single tensor
+    stacked_inputs = torch.stack(input_tensors)
+    # Stack all label tensors into a single tensor
+    stacked_labels = torch.stack(label_tensors)
+    return stacked_inputs.unsqueeze(1), stacked_labels.long()
+
+
+def collate_subcubes(results, coord_generator, samples=4):
+    data, labels = mcollate(results)
+    num_subjs = labels.shape[0]
+    data = data.squeeze(1)
+
+    batch_data = []
+    batch_labels = []
+
+    for i in range(num_subjs):
+        subcubes, sublabels = subcube_list(
+            data[i, :, :, :], labels[i, :, :, :], samples, coord_generator
+        )
+        batch_data.extend(subcubes)
+        batch_labels.extend(sublabels)
+
+    # Converting the list of tensors to a single tensor
+    batch_data = torch.stack(batch_data).unsqueeze(1)
+    batch_labels = torch.stack(batch_labels)
+
+    return batch_data, batch_labels
+
+
+def subcube_list(cube, labels, num, coords_generator):
+    subcubes = []
+    sublabels = []
+
+    for i in range(num):
+        coords = coords_generator.get_coordinates()
+        subcube = cube[
+            coords[0][0] : coords[0][1],
+            coords[1][0] : coords[1][1],
+            coords[2][0] : coords[2][1],
+        ]
+        sublabel = labels[
+            coords[0][0] : coords[0][1],
+            coords[1][0] : coords[1][1],
+            coords[2][0] : coords[2][1],
+        ]
+        subcubes.append(subcube)
+        sublabels.append(sublabel)
+
+    return subcubes, sublabels
+
+
+class DBBatchSampler(Sampler):
+    """
+    A batch sampler from a random permutation. Used for generating indices for a MongoDataset or RedisDataset
+    """
+
+    data_source: Sized
+
+    def __init__(self, data_source, batch_size=1, seed=None):
+        """Constructor
+
+        :param data_source: a dataset of Dataset class
+        :param batch_size: number of samples in the batch (sample is an MRI split to 8 records)
+        :returns: an object of DBBatchSampler class
+
+        """
+        self.batch_size = batch_size
+        self.data_source = data_source
+        self.data_size = len(self.data_source)
+        self.seed = seed
+
+    def __chunks__(self, l, n):
+        for i in range(0, len(l), n):
+            yield l[i : i + n]
+
+    def __iter__(self):
+        if self.seed is not None:
+            np.random.seed(self.seed)
+        return self.__chunks__(
+            np.random.permutation(self.data_size), self.batch_size
+        )
+
+    def __len__(self):
+        return (
+            self.data_size + self.batch_size - 1
+        ) // self.batch_size  # Number of batches
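
[Editor's note: not part of the patch series] The refactor moves the sampler and
the collate helpers into utils.py, so their behavior can be checked in
isolation. A small sketch; it uses a stand-in coordinates object that
implements the single get_coordinates() method subcube_list actually calls
(the real generator lives in mindfultensors.gencoords.CoordsGenerator), and
all shapes and sizes are illustrative:

    import torch
    from mindfultensors.utils import DBBatchSampler, subcube_list

    # DBBatchSampler chunks a random permutation of 0..len-1 into batches.
    sampler = DBBatchSampler(list(range(10)), batch_size=4, seed=0)
    assert len(sampler) == 3                    # ceil(10 / 4) batches
    batches = [chunk.tolist() for chunk in sampler]
    # e.g. [[2, 8, 4, 9], [1, 6, 7, 3], [0, 5]]; the tail batch is short.

    class FixedCoords:
        # Stand-in: always crop the same 32^3 corner.
        def get_coordinates(self):
            return ((0, 32), (0, 32), (0, 32))

    cube = torch.rand(64, 64, 64)
    labels = torch.zeros(64, 64, 64, dtype=torch.long)
    subcubes, sublabels = subcube_list(cube, labels, 2, FixedCoords())
    assert subcubes[0].shape == (32, 32, 32)
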
From 8c05ea44113a7529aa776b5f9b5a747ce09db145 Mon Sep 17 00:00:00 2001
From: Sergey Plis
Date: Tue, 20 Feb 2024 15:18:14 -0500
Subject: [PATCH 4/8] id in redis was not used

---
 mindfultensors/redisloader.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mindfultensors/redisloader.py b/mindfultensors/redisloader.py
index b73b8c5..ab1ebbe 100644
--- a/mindfultensors/redisloader.py
+++ b/mindfultensors/redisloader.py
@@ -38,7 +38,6 @@ def __init__(
         transform,
         dbkey,
         normalize=unit_interval_normalize,
-        id="id",
     ):
         """Constructor
 
@@ -55,7 +54,6 @@ def __init__(
         self.indices = indices
         self.transform = transform
         self.Redis = None
         self.normalize = normalize
-        self.id = id
 
     def __len__(self):
From 8c4dbe049054fe2cb6106cab67b878e08ccb6b9c Mon Sep 17 00:00:00 2001
From: Sergey Plis
Date: Tue, 20 Feb 2024 17:44:20 -0500
Subject: [PATCH 5/8] adding Sampler where it belongs

---
 mindfultensors/utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mindfultensors/utils.py b/mindfultensors/utils.py
index 754bd23..f46a6fa 100644
--- a/mindfultensors/utils.py
+++ b/mindfultensors/utils.py
@@ -2,6 +2,7 @@
 import io
 import numpy as np
 from typing import Sized
+from torch.utils.data.sampler import Sampler
 
 
 def unit_interval_normalize(img):
From 8d06f66fdf341697bea3c56a69f7a9d25fba1a87 Mon Sep 17 00:00:00 2001
From: Sergey Plis
Date: Tue, 20 Feb 2024 18:14:12 -0500
Subject: [PATCH 6/8] removed Sampler from redis where it does not belong

---
 mindfultensors/redisloader.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mindfultensors/redisloader.py b/mindfultensors/redisloader.py
index ab1ebbe..8589206 100644
--- a/mindfultensors/redisloader.py
+++ b/mindfultensors/redisloader.py
@@ -2,7 +2,6 @@
 from redis import Redis
 
 from torch.utils.data import Dataset, get_worker_info
-from torch.utils.data.sampler import Sampler
 
 from .gencoords import CoordsGenerator
 from .utils import (
From 6c3dc7f1705482a303ba00ee44b17ac79e8eaf8e Mon Sep 17 00:00:00 2001
From: Sergey Plis
Date: Tue, 20 Feb 2024 19:03:04 -0500
Subject: [PATCH 7/8] missing dbkey is added

---
 mindfultensors/redisloader.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mindfultensors/redisloader.py b/mindfultensors/redisloader.py
index 8589206..669a369 100644
--- a/mindfultensors/redisloader.py
+++ b/mindfultensors/redisloader.py
@@ -52,6 +52,7 @@ def __init__(
         self.indices = indices
         self.transform = transform
         self.Redis = None
+        self.dbkey = dbkey
         self.normalize = normalize
 
     def __len__(self):

From 912cef17fa97b424ab04d6549584cdcaeb075f98 Mon Sep 17 00:00:00 2001
From: Sergey Plis
Date: Tue, 20 Feb 2024 19:19:36 -0500
Subject: [PATCH 8/8] minor correction to a comment

---
 mindfultensors/redisloader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mindfultensors/redisloader.py b/mindfultensors/redisloader.py
index 669a369..a9c615a 100644
--- a/mindfultensors/redisloader.py
+++ b/mindfultensors/redisloader.py
@@ -60,7 +60,7 @@ def __len__(self):
 
     def __getitem__(self, batch):
         # Fetch all samples for ids in the batch and where 'kind' is either
-        # data or labela s specified by the sample parameter
+        # data or label as specified by the sample parameter
 
         results = {}
         for id in batch:
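
[Editor's note: not part of the patch series] The series only covers the
consumer side. The shape of a record is implied by __getitem__ in PATCH 1:
each list element must be a pickled (data, label) pair whose elements the
chosen transform can decode; with mtransform that means torch-serialized
bytes. A sketch of a matching producer, with the key name, host, and tensor
shapes purely illustrative:

    import io
    import pickle as pkl

    import torch
    from redis import Redis

    def tensor_bytes(t):
        # Serialize so that mtransform (torch.load on a BytesIO) can read it.
        buffer = io.BytesIO()
        torch.save(t, buffer)
        return buffer.getvalue()

    client = Redis(host="localhost")
    pair = (
        tensor_bytes(torch.rand(64, 64, 64)),                     # data
        tensor_bytes(torch.zeros(64, 64, 64, dtype=torch.long)),  # label
    )
    client.lpush("mydata", pkl.dumps(pair))
    # brpoplpush in RedisDataset pops from the tail and re-inserts at the
    # head, so pushed pairs are recycled rather than consumed.
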