diff --git a/mysql.go b/mysql.go index 914889e..1c1b30c 100644 --- a/mysql.go +++ b/mysql.go @@ -2,6 +2,7 @@ package main import ( "database/sql" + "fmt" _ "github.com/go-sql-driver/mysql" @@ -44,15 +45,15 @@ CREATE TABLE IF NOT EXISTS users ( passbcrypt VARCHAR(64) DEFAULT '', otpsecret VARCHAR(64) DEFAULT '', yubikey VARCHAR(128) DEFAULT '', - sshkeys TEXT DEFAULT '', - custattr TEXT DEFAULT '{}') + sshkeys TEXT DEFAULT (''), + custattr TEXT DEFAULT ('{}')) `) statement.Exec() statement, _ = db.Prepare("CREATE UNIQUE INDEX idx_user_name on users(name)") statement.Exec() - statement, _ = db.Prepare("CREATE TABLE IF NOT EXISTS groups (id INTEGER AUTO_INCREMENT PRIMARY KEY, name VARCHAR(64) NOT NULL, gidnumber INTEGER NOT NULL)") + statement, _ = db.Prepare("CREATE TABLE IF NOT EXISTS ldapgroups (id INTEGER AUTO_INCREMENT PRIMARY KEY, name VARCHAR(64) NOT NULL, gidnumber INTEGER NOT NULL)") statement.Exec() - statement, _ = db.Prepare("CREATE UNIQUE INDEX idx_group_name on groups(name)") + statement, _ = db.Prepare("CREATE UNIQUE INDEX idx_group_name on ldapgroups(name)") statement.Exec() statement, _ = db.Prepare("CREATE TABLE IF NOT EXISTS includegroups (id INTEGER AUTO_INCREMENT PRIMARY KEY, parentgroupid INTEGER NOT NULL, includegroupid INTEGER NOT NULL)") statement.Exec() @@ -63,7 +64,27 @@ CREATE TABLE IF NOT EXISTS users ( // Migrate schema if necessary func (b MysqlBackend) MigrateSchema(db *sql.DB, checker func(*sql.DB, string) bool) { if !checker(db, "sshkeys") { - statement, _ := db.Prepare("ALTER TABLE users ADD COLUMN sshkeys TEXT DEFAULT ''") + statement, _ := db.Prepare("ALTER TABLE users ADD COLUMN sshkeys TEXT DEFAULT ('')") statement.Exec() } + + if TableExists(db, "`groups`") { + // Drop the table created during schema creation + statement, _ := db.Prepare("DROP TABLE ldapgroups") + statement.Exec() + + statement, _ = db.Prepare("ALTER TABLE `groups` RENAME ldapgroups") + statement.Exec() + } +} + +// Indicates whether the table exists or not +func TableExists(db *sql.DB, tableName string) bool { + var found string + err := db.QueryRow(fmt.Sprintf("SELECT COUNT(id) FROM %s", tableName)).Scan( + &found) + if err != nil { + return false + } + return true }