Skip to content

Commit

Permalink
Merge pull request #6735 from jorisvandenbossche/sql-multiindex
Browse files Browse the repository at this point in the history
ENH: SQL multiindex support
  • Loading branch information
jorisvandenbossche committed Apr 14, 2014
2 parents 8e36ff4 + 18bd0d6 commit ad1f47d
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 36 deletions.
87 changes: 58 additions & 29 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,8 @@ def read_table(table_name, con, meta=None, index_col=None, coerce_float=True,
Legacy mode not supported
meta : SQLAlchemy meta, optional
If omitted MetaData is reflected from engine
index_col : string, optional
Column to set as index
index_col : string or sequence of strings, optional
Column(s) to set as index.
coerce_float : boolean, default True
Attempt to convert values to non-string, non-numeric objects (like
decimal.Decimal) to floating point. Can result in loss of Precision.
Expand All @@ -324,7 +324,7 @@ def read_table(table_name, con, meta=None, index_col=None, coerce_float=True,
to the keyword arguments of :func:`pandas.to_datetime`
Especially useful with databases without native Datetime support,
such as SQLite
columns : list
columns : list, optional
List of column names to select from sql table
Returns
Expand All @@ -340,7 +340,8 @@ def read_table(table_name, con, meta=None, index_col=None, coerce_float=True,
table = pandas_sql.read_table(table_name,
index_col=index_col,
coerce_float=coerce_float,
parse_dates=parse_dates)
parse_dates=parse_dates,
columns=columns)

if table is not None:
return table
Expand Down Expand Up @@ -438,19 +439,25 @@ def maybe_asscalar(self, i):
def insert(self):
ins = self.insert_statement()
data_list = []
# to avoid if check for every row
keys = self.frame.columns

if self.index is not None:
for t in self.frame.itertuples():
data = dict((k, self.maybe_asscalar(v))
for k, v in zip(keys, t[1:]))
data[self.index] = self.maybe_asscalar(t[0])
data_list.append(data)
temp = self.frame.copy()
temp.index.names = self.index
try:
temp.reset_index(inplace=True)
except ValueError as err:
raise ValueError(
"duplicate name in index/columns: {0}".format(err))
else:
for t in self.frame.itertuples():
data = dict((k, self.maybe_asscalar(v))
for k, v in zip(keys, t[1:]))
data_list.append(data)
temp = self.frame

keys = temp.columns

for t in temp.itertuples():
data = dict((k, self.maybe_asscalar(v))
for k, v in zip(keys, t[1:]))
data_list.append(data)

self.pd_sql.execute(ins, data_list)

def read(self, coerce_float=True, parse_dates=None, columns=None):
Expand All @@ -459,7 +466,7 @@ def read(self, coerce_float=True, parse_dates=None, columns=None):
from sqlalchemy import select
cols = [self.table.c[n] for n in columns]
if self.index is not None:
cols.insert(0, self.table.c[self.index])
[cols.insert(0, self.table.c[idx]) for idx in self.index[::-1]]
sql_select = select(cols)
else:
sql_select = self.table.select()
Expand All @@ -476,22 +483,33 @@ def read(self, coerce_float=True, parse_dates=None, columns=None):
if self.index is not None:
self.frame.set_index(self.index, inplace=True)

# Assume if the index in prefix_index format, we gave it a name
# and should return it nameless
if self.index == self.prefix + '_index':
self.frame.index.name = None

return self.frame

def _index_name(self, index, index_label):
# for writing: index=True to include index in sql table
if index is True:
nlevels = self.frame.index.nlevels
# if index_label is specified, set this as index name(s)
if index_label is not None:
return _safe_col_name(index_label)
elif self.frame.index.name is not None:
return _safe_col_name(self.frame.index.name)
if not isinstance(index_label, list):
index_label = [index_label]
if len(index_label) != nlevels:
raise ValueError(
"Length of 'index_label' should match number of "
"levels, which is {0}".format(nlevels))
else:
return index_label
# return the used column labels for the index columns
if nlevels == 1 and 'index' not in self.frame.columns and self.frame.index.name is None:
return ['index']
else:
return self.prefix + '_index'
return [l if l is not None else "level_{0}".format(i)
for i, l in enumerate(self.frame.index.names)]

# for reading: index=(list of) string to specify column to set as index
elif isinstance(index, string_types):
return [index]
elif isinstance(index, list):
return index
else:
return None
Expand All @@ -506,10 +524,10 @@ def _create_table_statement(self):
for name, typ in zip(safe_columns, column_types)]

if self.index is not None:
columns.insert(0, Column(self.index,
self._sqlalchemy_type(
self.frame.index),
index=True))
for i, idx_label in enumerate(self.index[::-1]):
idx_type = self._sqlalchemy_type(
self.frame.index.get_level_values(i))
columns.insert(0, Column(idx_label, idx_type, index=True))

return Table(self.name, self.pd_sql.meta, *columns)

Expand Down Expand Up @@ -787,6 +805,17 @@ def insert(self):
cur.close()
self.pd_sql.con.commit()

def _index_name(self, index, index_label):
if index is True:
if self.frame.index.name is not None:
return _safe_col_name(self.frame.index.name)
else:
return 'pandas_index'
elif isinstance(index, string_types):
return index
else:
return None

def _create_table_statement(self):
"Return a CREATE TABLE statement to suit the contents of a DataFrame."

Expand Down
109 changes: 102 additions & 7 deletions pandas/io/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import nose
import numpy as np

from pandas import DataFrame, Series
from pandas import DataFrame, Series, MultiIndex
from pandas.compat import range, lrange, iteritems
#from pandas.core.datetools import format as date_format

Expand Down Expand Up @@ -266,7 +266,7 @@ def _roundtrip(self):
self.pandasSQL.to_sql(self.test_frame1, 'test_frame_roundtrip')
result = self.pandasSQL.read_sql('SELECT * FROM test_frame_roundtrip')

result.set_index('pandas_index', inplace=True)
result.set_index('level_0', inplace=True)
# result.index.astype(int)

result.index.name = None
Expand Down Expand Up @@ -391,7 +391,7 @@ def test_roundtrip(self):

# HACK!
result.index = self.test_frame1.index
result.set_index('pandas_index', inplace=True)
result.set_index('level_0', inplace=True)
result.index.astype(int)
result.index.name = None
tm.assert_frame_equal(result, self.test_frame1)
Expand Down Expand Up @@ -460,7 +460,9 @@ def test_date_and_index(self):
issubclass(df.IntDateCol.dtype.type, np.datetime64),
"IntDateCol loaded with incorrect type")


class TestSQLApi(_TestSQLApi):

"""Test the public API as it would be used directly
"""
flavor = 'sqlite'
Expand All @@ -474,10 +476,10 @@ def connect(self):
def test_to_sql_index_label(self):
temp_frame = DataFrame({'col1': range(4)})

# no index name, defaults to 'pandas_index'
# no index name, defaults to 'index'
sql.to_sql(temp_frame, 'test_index_label', self.conn)
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[0], 'pandas_index')
self.assertEqual(frame.columns[0], 'index')

# specifying index_label
sql.to_sql(temp_frame, 'test_index_label', self.conn,
Expand All @@ -487,11 +489,11 @@ def test_to_sql_index_label(self):
"Specified index_label not written to database")

# using the index name
temp_frame.index.name = 'index'
temp_frame.index.name = 'index_name'
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace')
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[0], 'index',
self.assertEqual(frame.columns[0], 'index_name',
"Index name not written to database")

# has index name, but specifying index_label
Expand All @@ -501,8 +503,74 @@ def test_to_sql_index_label(self):
self.assertEqual(frame.columns[0], 'other_label',
"Specified index_label not written to database")

def test_to_sql_index_label_multiindex(self):
temp_frame = DataFrame({'col1': range(4)},
index=MultiIndex.from_product([('A0', 'A1'), ('B0', 'B1')]))

# no index name, defaults to 'level_0' and 'level_1'
sql.to_sql(temp_frame, 'test_index_label', self.conn)
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[0], 'level_0')
self.assertEqual(frame.columns[1], 'level_1')

# specifying index_label
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace', index_label=['A', 'B'])
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[:2].tolist(), ['A', 'B'],
"Specified index_labels not written to database")

# using the index name
temp_frame.index.names = ['A', 'B']
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace')
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[:2].tolist(), ['A', 'B'],
"Index names not written to database")

# has index name, but specifying index_label
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace', index_label=['C', 'D'])
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[:2].tolist(), ['C', 'D'],
"Specified index_labels not written to database")

# wrong length of index_label
self.assertRaises(ValueError, sql.to_sql, temp_frame,
'test_index_label', self.conn, if_exists='replace',
index_label='C')

def test_read_table_columns(self):
# test columns argument in read_table
sql.to_sql(self.test_frame1, 'test_frame', self.conn)

cols = ['A', 'B']
result = sql.read_table('test_frame', self.conn, columns=cols)
self.assertEqual(result.columns.tolist(), cols,
"Columns not correctly selected")

def test_read_table_index_col(self):
# test columns argument in read_table
sql.to_sql(self.test_frame1, 'test_frame', self.conn)

result = sql.read_table('test_frame', self.conn, index_col="index")
self.assertEqual(result.index.names, ["index"],
"index_col not correctly set")

result = sql.read_table('test_frame', self.conn, index_col=["A", "B"])
self.assertEqual(result.index.names, ["A", "B"],
"index_col not correctly set")

result = sql.read_table('test_frame', self.conn, index_col=["A", "B"],
columns=["C", "D"])
self.assertEqual(result.index.names, ["A", "B"],
"index_col not correctly set")
self.assertEqual(result.columns.tolist(), ["C", "D"],
"columns not set correctly whith index_col")


class TestSQLLegacyApi(_TestSQLApi):

"""Test the public legacy API
"""
flavor = 'sqlite'
Expand Down Expand Up @@ -554,6 +622,23 @@ def test_sql_open_close(self):

tm.assert_frame_equal(self.test_frame2, result)

def test_roundtrip(self):
# this test otherwise fails, Legacy mode still uses 'pandas_index'
# as default index column label
sql.to_sql(self.test_frame1, 'test_frame_roundtrip',
con=self.conn, flavor='sqlite')
result = sql.read_sql(
'SELECT * FROM test_frame_roundtrip',
con=self.conn,
flavor='sqlite')

# HACK!
result.index = self.test_frame1.index
result.set_index('pandas_index', inplace=True)
result.index.astype(int)
result.index.name = None
tm.assert_frame_equal(result, self.test_frame1)


class _TestSQLAlchemy(PandasSQLTest):
"""
Expand Down Expand Up @@ -776,6 +861,16 @@ def setUp(self):

self._load_test1_data()

def _roundtrip(self):
# overwrite parent function (level_0 -> pandas_index in legacy mode)
self.drop_table('test_frame_roundtrip')
self.pandasSQL.to_sql(self.test_frame1, 'test_frame_roundtrip')
result = self.pandasSQL.read_sql('SELECT * FROM test_frame_roundtrip')
result.set_index('pandas_index', inplace=True)
result.index.name = None

tm.assert_frame_equal(result, self.test_frame1)

def test_invalid_flavor(self):
self.assertRaises(
NotImplementedError, sql.PandasSQLLegacy, self.conn, 'oracle')
Expand Down

0 comments on commit ad1f47d

Please sign in to comment.