Skip to content

SQL: add index_label keyword to to_sql #6642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,8 @@ def to_msgpack(self, path_or_buf=None, **kwargs):
from pandas.io import packers
return packers.to_msgpack(path_or_buf, self, **kwargs)

def to_sql(self, name, con, flavor='sqlite', if_exists='fail', index=True):
def to_sql(self, name, con, flavor='sqlite', if_exists='fail', index=True,
index_label=None):
"""
Write records stored in a DataFrame to a SQL database.

Expand All @@ -928,12 +929,17 @@ def to_sql(self, name, con, flavor='sqlite', if_exists='fail', index=True):
- replace: If table exists, drop it, recreate it, and insert data.
- append: If table exists, insert data. Create if does not exist.
index : boolean, default True
Write DataFrame index as a column
Write DataFrame index as a column.
index_label : string or sequence, default None
Column label for index column(s). If None is given (default) and
`index` is True, then the index names are used.
A sequence should be given if the DataFrame uses MultiIndex.

"""
from pandas.io import sql
sql.to_sql(
self, name, con, flavor=flavor, if_exists=if_exists, index=index)
self, name, con, flavor=flavor, if_exists=if_exists, index=index,
index_label=index_label)

def to_pickle(self, path):
"""
Expand Down
31 changes: 22 additions & 9 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,8 @@ def read_sql(sql, con, index_col=None, flavor='sqlite', coerce_float=True,
parse_dates=parse_dates)


def to_sql(frame, name, con, flavor='sqlite', if_exists='fail', index=True):
def to_sql(frame, name, con, flavor='sqlite', if_exists='fail', index=True,
index_label=None):
"""
Write records stored in a DataFrame to a SQL database.

Expand All @@ -251,6 +252,11 @@ def to_sql(frame, name, con, flavor='sqlite', if_exists='fail', index=True):
- append: If table exists, insert data. Create if does not exist.
index : boolean, default True
Write DataFrame index as a column
index_label : string or sequence, default None
Column label for index column(s). If None is given (default) and
`index` is True, then the index names are used.
A sequence should be given if the DataFrame uses MultiIndex.

"""
pandas_sql = pandasSQL_builder(con, flavor=flavor)

Expand All @@ -259,7 +265,8 @@ def to_sql(frame, name, con, flavor='sqlite', if_exists='fail', index=True):
elif not isinstance(frame, DataFrame):
raise NotImplementedError

pandas_sql.to_sql(frame, name, if_exists=if_exists, index=index)
pandas_sql.to_sql(frame, name, if_exists=if_exists, index=index,
index_label=index_label)


def has_table(table_name, con, meta=None, flavor='sqlite'):
Expand Down Expand Up @@ -377,12 +384,12 @@ class PandasSQLTable(PandasObject):
"""
# TODO: support for multiIndex
def __init__(self, name, pandas_sql_engine, frame=None, index=True,
if_exists='fail', prefix='pandas'):
if_exists='fail', prefix='pandas', index_label=None):
self.name = name
self.pd_sql = pandas_sql_engine
self.prefix = prefix
self.frame = frame
self.index = self._index_name(index)
self.index = self._index_name(index, index_label)

if frame is not None:
# We want to write a frame
Expand Down Expand Up @@ -473,9 +480,11 @@ def read(self, coerce_float=True, parse_dates=None, columns=None):

return self.frame

def _index_name(self, index):
def _index_name(self, index, index_label):
if index is True:
if self.frame.index.name is not None:
if index_label is not None:
return _safe_col_name(index_label)
elif self.frame.index.name is not None:
return _safe_col_name(self.frame.index.name)
else:
return self.prefix + '_index'
Expand Down Expand Up @@ -652,9 +661,11 @@ def read_sql(self, sql, index_col=None, coerce_float=True,

return data_frame

def to_sql(self, frame, name, if_exists='fail', index=True):
def to_sql(self, frame, name, if_exists='fail', index=True,
index_label=None):
table = PandasSQLTable(
name, self, frame=frame, index=index, if_exists=if_exists)
name, self, frame=frame, index=index, if_exists=if_exists,
index_label=index_label)
table.insert()

@property
Expand Down Expand Up @@ -882,7 +893,8 @@ def _fetchall_as_list(self, cur):
result = list(result)
return result

def to_sql(self, frame, name, if_exists='fail', index=True):
def to_sql(self, frame, name, if_exists='fail', index=True,
index_label=None):
"""
Write records stored in a DataFrame to a SQL database.

Expand All @@ -895,6 +907,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True):
fail: If table exists, do nothing.
replace: If table exists, drop it, recreate it, and insert data.
append: If table exists, insert data. Create if does not exist.
index_label : ignored (only used in sqlalchemy mode)
"""
table = PandasSQLTableLegacy(
name, self, frame=frame, index=index, if_exists=if_exists)
Expand Down
58 changes: 51 additions & 7 deletions pandas/io/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def _tquery(self):
tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa'])


class TestSQLApi(PandasSQLTest):
class _TestSQLApi(PandasSQLTest):

"""Test the public API as it would be used
directly, including legacy names
Expand All @@ -269,12 +269,6 @@ class TestSQLApi(PandasSQLTest):
"""
flavor = 'sqlite'

def connect(self):
if SQLALCHEMY_INSTALLED:
return sqlalchemy.create_engine('sqlite:///:memory:')
else:
return sqlite3.connect(':memory:')

def setUp(self):
self.conn = self.connect()
self._load_iris_data()
Expand Down Expand Up @@ -436,6 +430,56 @@ def test_date_and_index(self):
issubclass(df.IntDateCol.dtype.type, np.datetime64),
"IntDateCol loaded with incorrect type")

class TestSQLApi(_TestSQLApi):
"""Test the public API as it would be used directly
"""
flavor = 'sqlite'

def connect(self):
if SQLALCHEMY_INSTALLED:
return sqlalchemy.create_engine('sqlite:///:memory:')
else:
raise nose.SkipTest('SQLAlchemy not installed')

def test_to_sql_index_label(self):
temp_frame = DataFrame({'col1': range(4)})

# no index name, defaults to 'pandas_index'
sql.to_sql(temp_frame, 'test_index_label', self.conn)
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[0], 'pandas_index')

# specifying index_label
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace', index_label='other_label')
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[0], 'other_label',
"Specified index_label not written to database")

# using the index name
temp_frame.index.name = 'index'
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace')
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[0], 'index',
"Index name not written to database")

# has index name, but specifying index_label
sql.to_sql(temp_frame, 'test_index_label', self.conn,
if_exists='replace', index_label='other_label')
frame = sql.read_table('test_index_label', self.conn)
self.assertEqual(frame.columns[0], 'other_label',
"Specified index_label not written to database")


class TestSQLLegacyApi(_TestSQLApi):
"""Test the public legacy API
"""
flavor = 'sqlite'

def connect(self):
return sqlite3.connect(':memory:')


class _TestSQLAlchemy(PandasSQLTest):
"""
Expand Down