Skip to content

Commit

Permalink
pygmt.x2sys_cross: Add 'output_type' parameter for output in pandas/n…
Browse files Browse the repository at this point in the history
…umpy/file formats
  • Loading branch information
seisman committed Apr 19, 2024
1 parent 62872d3 commit 295afc0
Showing 1 changed file with 18 additions and 29 deletions.
47 changes: 18 additions & 29 deletions pygmt/src/x2sys_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@
import contextlib
import os
from pathlib import Path
from typing import Literal

import pandas as pd
from packaging.version import Version
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
GMTTempFile,
build_arg_list,
data_kind,
fmt_docstring,
kwargs_to_strings,
unique_name,
use_alias,
validate_output_table_type,
)


Expand Down Expand Up @@ -71,7 +71,12 @@ def tempfile_from_dftrack(track, suffix):
Z="trackvalues",
)
@kwargs_to_strings(R="sequence")
def x2sys_cross(tracks=None, outfile=None, **kwargs):
def x2sys_cross(
tracks=None,
output_type: Literal["pandas", "numpy", "file"] = "pandas",
outfile: str | None = None,
**kwargs,
):
r"""
Calculate crossovers between track data files.
Expand Down Expand Up @@ -192,6 +197,8 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
- None if ``outfile`` is set (track output will be stored in the set in
``outfile``)
"""
output_type = validate_output_table_type(output_type, outfile=outfile)

with Session() as lib:
file_contexts = []
for track in tracks:
Expand All @@ -216,35 +223,17 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
else:
raise GMTInvalidInput(f"Unrecognized data type: {type(track)}")

with GMTTempFile(suffix=".txt") as tmpfile:
with lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl:
with contextlib.ExitStack() as stack:
fnames = [stack.enter_context(c) for c in file_contexts]
if outfile is None:
outfile = tmpfile.name
lib.call_module(
module="x2sys_cross",
args=build_arg_list(kwargs, infile=fnames, outfile=outfile),
)

# Read temporary csv output to a pandas table
if outfile == tmpfile.name: # if outfile isn't set, return pd.DataFrame
# Read the tab-separated ASCII table
date_format_kwarg = (
{"date_format": "ISO8601"}
if Version(pd.__version__) >= Version("2.0.0")
else {}
args=build_arg_list(kwargs, infile=fnames, outfile=vouttbl),
)
table = pd.read_csv(
tmpfile.name,
sep="\t",
header=2, # Column names are on 2nd row
comment=">", # Skip the 3rd row with a ">"
parse_dates=[2, 3], # Datetimes on 3rd and 4th column
**date_format_kwarg, # Parse dates in ISO8601 format on pandas>=2
result = lib.virtualfile_to_dataset(
vfname=vouttbl, output_type=output_type, header=2
)
# Remove the "# " from "# x" in the first column
table = table.rename(columns={table.columns[0]: table.columns[0][2:]})
elif outfile != tmpfile.name: # if outfile is set, output in outfile only
table = None

return table
if output_type != "file":
# Datetimes on 3rd and 4th columns
result.iloc[:, 2:4] = result.iloc[:, 2:4].apply(pd.to_datetime)
return result

0 comments on commit 295afc0

Please sign in to comment.