Skip to content

Commit

Permalink
pygmt.x2sys_cross: Add 'output_type' parameter for output in pandas/n…
Browse files Browse the repository at this point in the history
…umpy/file formats
  • Loading branch information
seisman committed Apr 19, 2024
1 parent 62872d3 commit 5280524
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 33 deletions.
49 changes: 21 additions & 28 deletions pygmt/src/x2sys_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@
import contextlib
import os
from pathlib import Path
from typing import Literal

import pandas as pd
from packaging.version import Version
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
GMTTempFile,
build_arg_list,
data_kind,
fmt_docstring,
kwargs_to_strings,
unique_name,
use_alias,
validate_output_table_type,
)


Expand Down Expand Up @@ -71,7 +71,12 @@ def tempfile_from_dftrack(track, suffix):
Z="trackvalues",
)
@kwargs_to_strings(R="sequence")
def x2sys_cross(tracks=None, outfile=None, **kwargs):
def x2sys_cross(
tracks=None,
output_type: Literal["pandas", "numpy", "file"] = "pandas",
outfile: str | None = None,
**kwargs,
):
r"""
Calculate crossovers between track data files.
Expand Down Expand Up @@ -192,6 +197,8 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
- None if ``outfile`` is set (track output will be stored in the set in
``outfile``)
"""
output_type = validate_output_table_type(output_type, outfile=outfile)

with Session() as lib:
file_contexts = []
for track in tracks:
Expand All @@ -216,35 +223,21 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
else:
raise GMTInvalidInput(f"Unrecognized data type: {type(track)}")

with GMTTempFile(suffix=".txt") as tmpfile:
with lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl:
with contextlib.ExitStack() as stack:
fnames = [stack.enter_context(c) for c in file_contexts]
if outfile is None:
outfile = tmpfile.name
lib.call_module(
module="x2sys_cross",
args=build_arg_list(kwargs, infile=fnames, outfile=outfile),
)

# Read temporary csv output to a pandas table
if outfile == tmpfile.name: # if outfile isn't set, return pd.DataFrame
# Read the tab-separated ASCII table
date_format_kwarg = (
{"date_format": "ISO8601"}
if Version(pd.__version__) >= Version("2.0.0")
else {}
args=build_arg_list(kwargs, infile=fnames, outfile=vouttbl),
)
table = pd.read_csv(
tmpfile.name,
sep="\t",
header=2, # Column names are on 2nd row
comment=">", # Skip the 3rd row with a ">"
parse_dates=[2, 3], # Datetimes on 3rd and 4th column
**date_format_kwarg, # Parse dates in ISO8601 format on pandas>=2
result = lib.virtualfile_to_dataset(
vfname=vouttbl, output_type=output_type, header=2
)
# Remove the "# " from "# x" in the first column
table = table.rename(columns={table.columns[0]: table.columns[0][2:]})
elif outfile != tmpfile.name: # if outfile is set, output in outfile only
table = None

return table
# Convert 3rd and 4th columns to datetimes.
# These two columns have names "t_1"/"t_2" or "i_1"/"i_2".
# "t_1"/"t_2" means they are datetimes and should be converted.
# "i_1"/"i_2" means they are dummy times (i.e., floating-point values).
if output_type != "file" and result.columns[2] == "t_1":
result.iloc[:, 2:4] = result.iloc[:, 2:4].apply(pd.to_datetime)
return result
14 changes: 9 additions & 5 deletions pygmt/tests/test_x2sys_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ def test_x2sys_cross_input_file_output_file():
x2sys_init(tag=tag, fmtfile="xyz", force=True)
outfile = tmpdir_p / "tmp_coe.txt"
output = x2sys_cross(
tracks=["@tut_ship.xyz"], tag=tag, coe="i", outfile=outfile
tracks=["@tut_ship.xyz"],
tag=tag,
coe="i",
outfile=outfile,
output_type="file",
)

assert output is None # check that output is None since outfile is set
Expand Down Expand Up @@ -97,8 +101,8 @@ def test_x2sys_cross_input_dataframe_output_dataframe(tracks):
columns = list(output.columns)
assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"]
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
assert output.dtypes["i_1"].type == np.object_
assert output.dtypes["i_2"].type == np.object_
assert output.dtypes["i_1"].type == np.float64
assert output.dtypes["i_2"].type == np.float64


@pytest.mark.usefixtures("mock_x2sys_home")
Expand Down Expand Up @@ -158,8 +162,8 @@ def test_x2sys_cross_input_dataframe_with_nan(tracks):
columns = list(output.columns)
assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"]
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
assert output.dtypes["i_1"].type == np.object_
assert output.dtypes["i_2"].type == np.object_
assert output.dtypes["i_1"].type == np.float64
assert output.dtypes["i_2"].type == np.float64


@pytest.mark.usefixtures("mock_x2sys_home")
Expand Down

0 comments on commit 5280524

Please sign in to comment.