Skip to content

Commit

Permalink
wip tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron committed Feb 16, 2024
1 parent 040b54f commit dbf8f9d
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 7 deletions.
5 changes: 5 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,8 @@ Reading as Arrow data

.. autofunction:: pyogrio.raw.read_arrow
.. autofunction:: pyogrio.raw.open_arrow

Writing from Arrow data
---------------------

.. autofunction:: pyogrio.raw.write_arrow
1 change: 1 addition & 0 deletions pyogrio/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@


HAS_ARROW_API = __gdal_version__ >= (3, 6, 0) and pyarrow is not None
HAS_ARROW_WRITE_API = __gdal_version__ >= (3, 8, 0) and pyarrow is not None

HAS_GEOPANDAS = geopandas is not None

Expand Down
11 changes: 7 additions & 4 deletions pyogrio/raw.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
import warnings

from pyogrio._compat import HAS_ARROW_API, HAS_ARROW_WRITE_API
from pyogrio._env import GDALEnv
from pyogrio._compat import HAS_ARROW_API
from pyogrio.core import detect_write_driver
from pyogrio.errors import DataSourceError
from pyogrio.util import (
_mask_to_wkb,
_preprocess_options_key_value,
get_vsi_path,
vsi_path,
_preprocess_options_key_value,
_mask_to_wkb,
)

with GDALEnv():
from pyogrio._io import ogr_open_arrow, ogr_read, ogr_write, ogr_write_arrow
from pyogrio._ogr import (
_get_driver_metadata_item,
get_gdal_version,
get_gdal_version_string,
ogr_driver_supports_write,
remove_virtual_file,
_get_driver_metadata_item,
)


Expand Down Expand Up @@ -581,6 +581,9 @@ def write_arrow(
gdal_tz_offsets=None,
**kwargs,
):
if not HAS_ARROW_WRITE_API:
raise RuntimeError("pyarrow and GDAL>=3.8 required to read using arrow")

# if dtypes is given, remove it from kwargs (dtypes is included in meta returned by
# read, and it is convenient to pass meta directly into write for round trip tests)
kwargs.pop("dtypes", None)
Expand Down
14 changes: 11 additions & 3 deletions pyogrio/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from zipfile import ZipFile, ZIP_DEFLATED
from zipfile import ZIP_DEFLATED, ZipFile

import pytest

Expand All @@ -8,10 +8,14 @@
__version__,
list_drivers,
)
from pyogrio._compat import HAS_ARROW_API, HAS_GDAL_GEOS, HAS_SHAPELY
from pyogrio._compat import (
HAS_ARROW_API,
HAS_ARROW_WRITE_API,
HAS_GDAL_GEOS,
HAS_SHAPELY,
)
from pyogrio.raw import read, write


_data_dir = Path(__file__).parent.resolve() / "fixtures"

# mapping of driver extension to driver name for well-supported drivers
Expand Down Expand Up @@ -47,6 +51,10 @@ def pytest_report_header(config):
not HAS_ARROW_API, reason="GDAL>=3.6 and pyarrow required"
)

requires_arrow_write_api = pytest.mark.skipif(
not HAS_ARROW_WRITE_API, reason="GDAL>=3.8 and pyarrow required"
)

requires_gdal_geos = pytest.mark.skipif(
not HAS_GDAL_GEOS, reason="GDAL compiled with GEOS required"
)
Expand Down
118 changes: 118 additions & 0 deletions pyogrio/tests/test_write_arrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import json
import os

import numpy as np
import pytest

from pyogrio.core import list_layers
from pyogrio.raw import read_arrow, write_arrow
from pyogrio.tests.conftest import requires_arrow_write_api

try:
import pandas as pd
import pyarrow
from geopandas.testing import assert_geodataframe_equal
from pandas.testing import assert_frame_equal, assert_index_equal
except ImportError:
pass

# skip all tests in this file if Arrow Write API or GeoPandas are unavailable
pytestmark = requires_arrow_write_api
pytest.importorskip("geopandas")


def test_write(tmpdir, naturalearth_lowres):
meta, table = read_arrow(naturalearth_lowres)

filename = os.path.join(str(tmpdir), "test.shp")
write_arrow(
filename,
table,
crs=meta["crs"],
encoding=meta["encoding"],
geometry_type=meta["geometry_type"],
)

assert os.path.exists(filename)
for ext in (".dbf", ".prj"):
assert os.path.exists(filename.replace(".shp", ext))


def test_write_gpkg(tmpdir, naturalearth_lowres):
meta, table = read_arrow(naturalearth_lowres)

filename = os.path.join(str(tmpdir), "test.gpkg")
write_arrow(
filename,
table,
driver="GPKG",
crs=meta["crs"],
encoding=meta["encoding"],
geometry_type="MultiPolygon",
)

assert os.path.exists(filename)


def test_write_gpkg_multiple_layers(tmpdir, naturalearth_lowres):
meta, table = read_arrow(naturalearth_lowres)
meta["geometry_type"] = "MultiPolygon"

filename = os.path.join(str(tmpdir), "test.gpkg")
write_arrow(
filename,
table,
driver="GPKG",
layer="first",
crs=meta["crs"],
encoding=meta["encoding"],
geometry_type="MultiPolygon",
)

assert os.path.exists(filename)

assert np.array_equal(list_layers(filename), [["first", "MultiPolygon"]])

write_arrow(
filename,
table,
driver="GPKG",
layer="second",
crs=meta["crs"],
encoding=meta["encoding"],
geometry_type="MultiPolygon",
)

assert np.array_equal(
list_layers(filename), [["first", "MultiPolygon"], ["second", "MultiPolygon"]]
)


def test_write_geojson(tmpdir, naturalearth_lowres):
# I was thinking we might need to rename the wkb_geometry column to geometry
meta, table = read_arrow(naturalearth_lowres)
names = table.column_names
names[-1] = "geometry"
table = table.rename_columns(names)

filename = os.path.join(str(tmpdir), "test.json")
write_arrow(
filename,
table,
driver="GeoJSON",
crs=meta["crs"],
encoding=meta["encoding"],
geometry_type=meta["geometry_type"],
)

assert os.path.exists(filename)

data = json.loads(open(filename).read())

assert data["type"] == "FeatureCollection"
assert data["name"] == "test"
assert "crs" in data
assert len(data["features"]) == len(table)
assert not len(
set(meta["fields"]).difference(data["features"][0]["properties"].keys())
)

0 comments on commit dbf8f9d

Please sign in to comment.