Skip to content

Commit

Permalink
Making GeoDatasets use seperate values for x and y resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
calebrob6 committed Feb 21, 2025
1 parent 1eb4301 commit d6cb939
Show file tree
Hide file tree
Showing 36 changed files with 71 additions and 71 deletions.
2 changes: 1 addition & 1 deletion tests/datamodules/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(
super().__init__()
for i in range(length):
self.index.insert(i, (0, 1, 2, 3, 4, 5))
self.res = 1
self.res = (1, 1)

def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
image = torch.arange(3 * 2 * 2, dtype=torch.float).view(3, 2, 2)
Expand Down
2 changes: 1 addition & 1 deletion tests/datasets/test_cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def dataset(
root = tmp_path
transforms = nn.Identity()
return CanadianBuildingFootprints(
root, res=0.1, transforms=transforms, download=True, checksum=True
root, res=(0.1, 0.1), transforms=transforms, download=True, checksum=True
)

def test_getitem(self, dataset: CanadianBuildingFootprints) -> None:
Expand Down
35 changes: 18 additions & 17 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
self,
bounds: BoundingBox = BoundingBox(0, 1, 2, 3, 4, 5),
crs: CRS = CRS.from_epsg(4087),
res: float = 1,
res: tuple[float, float] = (1, 1),
paths: str | os.PathLike[str] | Iterable[str | os.PathLike[str]] | None = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -324,7 +324,8 @@ def test_getitem_separate_files(self, sentinel: Sentinel2) -> None:
def test_reprojection(self, naip: NAIP) -> None:
naip2 = NAIP(naip.paths, crs='EPSG:4326')
assert naip.crs != naip2.crs
assert not math.isclose(naip.res, naip2.res)
assert not math.isclose(naip.res[0], naip2.res[0])
assert not math.isclose(naip.res[1], naip2.res[1])

@pytest.mark.parametrize('dtype', ['uint16', 'uint32'])
def test_getitem_uint_dtype(self, dtype: str) -> None:
Expand Down Expand Up @@ -381,14 +382,14 @@ class TestVectorDataset:
def dataset(self) -> CustomVectorDataset:
root = os.path.join('tests', 'data', 'vector')
transforms = nn.Identity()
return CustomVectorDataset(root, res=0.1, transforms=transforms)
return CustomVectorDataset(root, res=(0.1, 0.1), transforms=transforms)

@pytest.fixture(scope='class')
def multilabel(self) -> CustomVectorDataset:
root = os.path.join('tests', 'data', 'vector')
transforms = nn.Identity()
return CustomVectorDataset(
root, res=0.1, transforms=transforms, label_name='label_id'
root, res=(0.1, 0.1), transforms=transforms, label_name='label_id'
)

def test_getitem(self, dataset: CustomVectorDataset) -> None:
Expand Down Expand Up @@ -562,7 +563,7 @@ def test_different_crs_12(self) -> None:
ds = IntersectionDataset(ds1, ds2)
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds.res == 2
assert ds1.res == ds2.res == ds.res == (2, 2)
assert len(ds1) == len(ds2) == len(ds) == 1
assert isinstance(sample['image'], torch.Tensor)

Expand All @@ -573,7 +574,7 @@ def test_different_crs_12_3(self) -> None:
ds = (ds1 & ds2) & ds3
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert ds1.res == ds2.res == ds3.res == ds.res == (2, 2)
assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1
assert isinstance(sample['image'], torch.Tensor)

Expand All @@ -584,7 +585,7 @@ def test_different_crs_1_23(self) -> None:
ds = ds1 & (ds2 & ds3)
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert ds1.res == ds2.res == ds3.res == ds.res == (2, 2)
assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1
assert isinstance(sample['image'], torch.Tensor)

Expand All @@ -594,7 +595,7 @@ def test_different_res_12(self) -> None:
ds = IntersectionDataset(ds1, ds2)
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds.res == 2
assert ds1.res == ds2.res == ds.res == (2, 2)
assert len(ds1) == len(ds2) == len(ds) == 1
assert isinstance(sample['image'], torch.Tensor)

Expand All @@ -605,7 +606,7 @@ def test_different_res_12_3(self) -> None:
ds = (ds1 & ds2) & ds3
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert ds1.res == ds2.res == ds3.res == ds.res == (2, 2)
assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1
assert isinstance(sample['image'], torch.Tensor)

Expand All @@ -616,7 +617,7 @@ def test_different_res_1_23(self) -> None:
ds = ds1 & (ds2 & ds3)
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert ds1.res == ds2.res == ds3.res == ds.res == (2, 2)
assert len(ds1) == len(ds2) == len(ds3) == len(ds) == 1
assert isinstance(sample['image'], torch.Tensor)

Expand All @@ -625,7 +626,7 @@ def test_point_dataset(self) -> None:
ds2 = CustomGeoDataset(BoundingBox(1, 1, 3, 3, 5, 5))
ds = IntersectionDataset(ds1, ds2)
assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds.res == 1
assert ds1.res == ds2.res == ds.res == (1, 1)
assert len(ds1) == len(ds2) == len(ds) == 1

def test_no_overlap(self) -> None:
Expand Down Expand Up @@ -678,7 +679,7 @@ def test_different_crs_12(self) -> None:
ds = UnionDataset(ds1, ds2)
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds.res == 2
assert ds1.res == ds2.res == ds.res == (2, 2)
assert len(ds1) == len(ds2) == 1
assert len(ds) == 2
assert isinstance(sample['image'], torch.Tensor)
Expand All @@ -690,7 +691,7 @@ def test_different_crs_12_3(self) -> None:
ds = (ds1 | ds2) | ds3
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert ds1.res == ds2.res == ds3.res == ds.res == (2, 2)
assert len(ds1) == len(ds2) == len(ds3) == 1
assert len(ds) == 3
assert isinstance(sample['image'], torch.Tensor)
Expand All @@ -702,7 +703,7 @@ def test_different_crs_1_23(self) -> None:
ds = ds1 | (ds2 | ds3)
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert ds1.res == ds2.res == ds3.res == ds.res == (2, 2)
assert len(ds1) == len(ds2) == len(ds3) == 1
assert len(ds) == 3
assert isinstance(sample['image'], torch.Tensor)
Expand All @@ -713,7 +714,7 @@ def test_different_res_12(self) -> None:
ds = UnionDataset(ds1, ds2)
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds.res == 2
assert ds1.res == ds2.res == ds.res == (2, 2)
assert len(ds1) == len(ds2) == 1
assert len(ds) == 2
assert isinstance(sample['image'], torch.Tensor)
Expand All @@ -725,7 +726,7 @@ def test_different_res_12_3(self) -> None:
ds = (ds1 | ds2) | ds3
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert ds1.res == ds2.res == ds3.res == ds.res == (2, 2)
assert len(ds1) == len(ds2) == len(ds3) == 1
assert len(ds) == 3
assert isinstance(sample['image'], torch.Tensor)
Expand All @@ -737,7 +738,7 @@ def test_different_res_1_23(self) -> None:
ds = ds1 | (ds2 | ds3)
sample = ds[ds.bounds]
assert ds1.crs == ds2.crs == ds3.crs == ds.crs == CRS.from_epsg(4087)
assert ds1.res == ds2.res == ds3.res == ds.res == 2
assert ds1.res == ds2.res == ds3.res == ds.res == (2, 2)
assert len(ds1) == len(ds2) == len(ds3) == 1
assert len(ds) == 3
assert isinstance(sample['image'], torch.Tensor)
Expand Down
Empty file.
2 changes: 1 addition & 1 deletion tests/datasets/test_sentinel.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class TestSentinel2:
@pytest.fixture
def dataset(self) -> Sentinel2:
root = os.path.join('tests', 'data', 'sentinel2')
res = 10
res = (10.0, 10.0)
bands = ['B02', 'B03', 'B04', 'B08']
transforms = nn.Identity()
return Sentinel2(root, res=res, bands=bands, transforms=transforms)
Expand Down
2 changes: 1 addition & 1 deletion tests/datasets/test_splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
self,
items: list[tuple[BoundingBox, str]] = [(BoundingBox(0, 1, 0, 1, 0, 40), '')],
crs: CRS = CRS.from_epsg(3005),
res: float = 1,
res: tuple[float, float] = (1, 1),
) -> None:
super().__init__()
for box, content in items:
Expand Down
2 changes: 1 addition & 1 deletion tests/samplers/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __len__(self) -> int:


class CustomGeoDataset(GeoDataset):
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: tuple[float, float] = (10, 10)) -> None:
super().__init__()
self._crs = crs
self.res = res
Expand Down
2 changes: 1 addition & 1 deletion tests/samplers/test_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __len__(self) -> int:


class CustomGeoDataset(GeoDataset):
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: float = 10) -> None:
def __init__(self, crs: CRS = CRS.from_epsg(3005), res: tuple[float, float] = (10, 10)) -> None:
super().__init__()
self._crs = crs
self.res = res
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/agb_live_woody_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
self,
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
res: tuple[float, float] | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
cache: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/astergdem.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
self,
paths: Path | list[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
res: tuple[float, float] | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
self,
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
res: float = 0.00001,
res: tuple[float, float] = (0.00001, 0.00001),
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
checksum: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def __init__(
self,
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
res: tuple[float, float] | None = None,
years: list[int] = [2023],
classes: list[int] = list(cmap.keys()),
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/chesapeake.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def __init__(
self,
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
res: tuple[float, float] | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
Expand Down Expand Up @@ -351,7 +351,7 @@ class ChesapeakeCVPR(GeoDataset):
}

crs = CRS.from_epsg(3857)
res = 1
res = (1, 1)

lc_cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {
0: (0, 0, 0, 0),
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/cms_mangrove_canopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def __init__(
self,
paths: Path | list[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
res: tuple[float, float] | None = None,
measurement: str = 'agb',
country: str = all_countries[0],
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/eddmaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class EDDMapS(GeoDataset):
.. versionadded:: 0.3
"""

res = 0
res = (0, 0)
_crs = CRS.from_epsg(4326) # Lat/Lon

def __init__(self, root: Path = 'data') -> None:
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/enmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def __init__(
self,
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
res: tuple[float, float] | None = None,
bands: Sequence[str] | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/enviroatlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class EnviroAtlas(GeoDataset):
md5 = 'bfe601be21c7c001315fc6154be8ef14'

crs = CRS.from_epsg(3857)
res = 1
res = (1, 1)

valid_prior_layers = ('prior', 'prior_no_osm_no_buildings')

Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/esri2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
self,
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
res: tuple[float, float] | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
download: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/eudem.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
self,
paths: Path | Iterable[Path] = 'data',
crs: CRS | None = None,
res: float | None = None,
res: tuple[float, float] | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
checksum: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/eurocrops.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
self,
paths: Path | Iterable[Path] = 'data',
crs: CRS = CRS.from_epsg(4326),
res: float = 0.00001,
res: tuple[float, float] = (0.00001, 0.00001),
classes: list[str] | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
download: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/gbif.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class GBIF(GeoDataset):
.. versionadded:: 0.3
"""

res = 0
res = (0, 0)
_crs = CRS.from_epsg(4326) # Lat/Lon

def __init__(self, root: Path = 'data') -> None:
Expand Down
Loading

0 comments on commit d6cb939

Please sign in to comment.