diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 6c57a9ce5beaa..df2fbcbe32c8f 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -908,7 +908,8 @@ def to_msgpack(self, path_or_buf=None, **kwargs): from pandas.io import packers return packers.to_msgpack(path_or_buf, self, **kwargs) - def to_sql(self, name, con, flavor='sqlite', if_exists='fail', index=True): + def to_sql(self, name, con, flavor='sqlite', if_exists='fail', index=True, + index_label=None): """ Write records stored in a DataFrame to a SQL database. @@ -928,12 +929,17 @@ def to_sql(self, name, con, flavor='sqlite', if_exists='fail', index=True): - replace: If table exists, drop it, recreate it, and insert data. - append: If table exists, insert data. Create if does not exist. index : boolean, default True - Write DataFrame index as a column + Write DataFrame index as a column. + index_label : string or sequence, default None + Column label for index column(s). If None is given (default) and + `index` is True, then the index names are used. + A sequence should be given if the DataFrame uses MultiIndex. """ from pandas.io import sql sql.to_sql( - self, name, con, flavor=flavor, if_exists=if_exists, index=index) + self, name, con, flavor=flavor, if_exists=if_exists, index=index, + index_label=index_label) def to_pickle(self, path): """ diff --git a/pandas/io/sql.py b/pandas/io/sql.py index f17820b06ce5e..fa89cf488125a 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -229,7 +229,8 @@ def read_sql(sql, con, index_col=None, flavor='sqlite', coerce_float=True, parse_dates=parse_dates) -def to_sql(frame, name, con, flavor='sqlite', if_exists='fail', index=True): +def to_sql(frame, name, con, flavor='sqlite', if_exists='fail', index=True, + index_label=None): """ Write records stored in a DataFrame to a SQL database. @@ -251,6 +252,11 @@ def to_sql(frame, name, con, flavor='sqlite', if_exists='fail', index=True): - append: If table exists, insert data. Create if does not exist. index : boolean, default True Write DataFrame index as a column + index_label : string or sequence, default None + Column label for index column(s). If None is given (default) and + `index` is True, then the index names are used. + A sequence should be given if the DataFrame uses MultiIndex. + """ pandas_sql = pandasSQL_builder(con, flavor=flavor) @@ -259,7 +265,8 @@ def to_sql(frame, name, con, flavor='sqlite', if_exists='fail', index=True): elif not isinstance(frame, DataFrame): raise NotImplementedError - pandas_sql.to_sql(frame, name, if_exists=if_exists, index=index) + pandas_sql.to_sql(frame, name, if_exists=if_exists, index=index, + index_label=index_label) def has_table(table_name, con, meta=None, flavor='sqlite'): @@ -377,12 +384,12 @@ class PandasSQLTable(PandasObject): """ # TODO: support for multiIndex def __init__(self, name, pandas_sql_engine, frame=None, index=True, - if_exists='fail', prefix='pandas'): + if_exists='fail', prefix='pandas', index_label=None): self.name = name self.pd_sql = pandas_sql_engine self.prefix = prefix self.frame = frame - self.index = self._index_name(index) + self.index = self._index_name(index, index_label) if frame is not None: # We want to write a frame @@ -473,9 +480,11 @@ def read(self, coerce_float=True, parse_dates=None, columns=None): return self.frame - def _index_name(self, index): + def _index_name(self, index, index_label): if index is True: - if self.frame.index.name is not None: + if index_label is not None: + return _safe_col_name(index_label) + elif self.frame.index.name is not None: return _safe_col_name(self.frame.index.name) else: return self.prefix + '_index' @@ -652,9 +661,11 @@ def read_sql(self, sql, index_col=None, coerce_float=True, return data_frame - def to_sql(self, frame, name, if_exists='fail', index=True): + def to_sql(self, frame, name, if_exists='fail', index=True, + index_label=None): table = PandasSQLTable( - name, self, frame=frame, index=index, if_exists=if_exists) + name, self, frame=frame, index=index, if_exists=if_exists, + index_label=index_label) table.insert() @property @@ -882,7 +893,8 @@ def _fetchall_as_list(self, cur): result = list(result) return result - def to_sql(self, frame, name, if_exists='fail', index=True): + def to_sql(self, frame, name, if_exists='fail', index=True, + index_label=None): """ Write records stored in a DataFrame to a SQL database. @@ -895,6 +907,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True): fail: If table exists, do nothing. replace: If table exists, drop it, recreate it, and insert data. append: If table exists, insert data. Create if does not exist. + index_label : ignored (only used in sqlalchemy mode) """ table = PandasSQLTableLegacy( name, self, frame=frame, index=index, if_exists=if_exists) diff --git a/pandas/io/tests/test_sql.py b/pandas/io/tests/test_sql.py index 8e045db0315cb..2f9323e50c9e2 100644 --- a/pandas/io/tests/test_sql.py +++ b/pandas/io/tests/test_sql.py @@ -255,7 +255,7 @@ def _tquery(self): tm.equalContents(row, [5.1, 3.5, 1.4, 0.2, 'Iris-setosa']) -class TestSQLApi(PandasSQLTest): +class _TestSQLApi(PandasSQLTest): """Test the public API as it would be used directly, including legacy names @@ -269,12 +269,6 @@ class TestSQLApi(PandasSQLTest): """ flavor = 'sqlite' - def connect(self): - if SQLALCHEMY_INSTALLED: - return sqlalchemy.create_engine('sqlite:///:memory:') - else: - return sqlite3.connect(':memory:') - def setUp(self): self.conn = self.connect() self._load_iris_data() @@ -436,6 +430,56 @@ def test_date_and_index(self): issubclass(df.IntDateCol.dtype.type, np.datetime64), "IntDateCol loaded with incorrect type") +class TestSQLApi(_TestSQLApi): + """Test the public API as it would be used directly + """ + flavor = 'sqlite' + + def connect(self): + if SQLALCHEMY_INSTALLED: + return sqlalchemy.create_engine('sqlite:///:memory:') + else: + raise nose.SkipTest('SQLAlchemy not installed') + + def test_to_sql_index_label(self): + temp_frame = DataFrame({'col1': range(4)}) + + # no index name, defaults to 'pandas_index' + sql.to_sql(temp_frame, 'test_index_label', self.conn) + frame = sql.read_table('test_index_label', self.conn) + self.assertEqual(frame.columns[0], 'pandas_index') + + # specifying index_label + sql.to_sql(temp_frame, 'test_index_label', self.conn, + if_exists='replace', index_label='other_label') + frame = sql.read_table('test_index_label', self.conn) + self.assertEqual(frame.columns[0], 'other_label', + "Specified index_label not written to database") + + # using the index name + temp_frame.index.name = 'index' + sql.to_sql(temp_frame, 'test_index_label', self.conn, + if_exists='replace') + frame = sql.read_table('test_index_label', self.conn) + self.assertEqual(frame.columns[0], 'index', + "Index name not written to database") + + # has index name, but specifying index_label + sql.to_sql(temp_frame, 'test_index_label', self.conn, + if_exists='replace', index_label='other_label') + frame = sql.read_table('test_index_label', self.conn) + self.assertEqual(frame.columns[0], 'other_label', + "Specified index_label not written to database") + + +class TestSQLLegacyApi(_TestSQLApi): + """Test the public legacy API + """ + flavor = 'sqlite' + + def connect(self): + return sqlite3.connect(':memory:') + class _TestSQLAlchemy(PandasSQLTest): """