defender: add provider driver

Fixes #616
This commit is contained in:
Nicola Murino
2021-12-25 12:08:07 +01:00
parent 8174349032
commit 7d8823307f
30 changed files with 2177 additions and 609 deletions

View File

@@ -27,6 +27,8 @@ const (
"DROP TABLE IF EXISTS `{{folders}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{shares}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{users}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{defender_events}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{defender_hosts}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{schema_version}}` CASCADE;"
mysqlInitialSQL = "CREATE TABLE `{{schema_version}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `version` integer NOT NULL);" +
"CREATE TABLE `{{admins}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " +
@@ -82,6 +84,17 @@ const (
"ALTER TABLE `{{shares}}` ADD CONSTRAINT `{{prefix}}shares_user_id_fk_users_id` " +
"FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;"
mysqlV14DownSQL = "DROP TABLE `{{shares}}` CASCADE;"
mysqlV15SQL = "CREATE TABLE `{{defender_hosts}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`ip` varchar(50) NOT NULL UNIQUE, `ban_time` bigint NOT NULL, `updated_at` bigint NOT NULL);" +
"CREATE TABLE `{{defender_events}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`date_time` bigint NOT NULL, `score` integer NOT NULL, `host_id` bigint NOT NULL);" +
"ALTER TABLE `{{defender_events}}` ADD CONSTRAINT `{{prefix}}defender_events_host_id_fk_defender_hosts_id` " +
"FOREIGN KEY (`host_id`) REFERENCES `{{defender_hosts}}` (`id`) ON DELETE CASCADE;" +
"CREATE INDEX `{{prefix}}defender_hosts_updated_at_idx` ON `{{defender_hosts}}` (`updated_at`);" +
"CREATE INDEX `{{prefix}}defender_hosts_ban_time_idx` ON `{{defender_hosts}}` (`ban_time`);" +
"CREATE INDEX `{{prefix}}defender_events_date_time_idx` ON `{{defender_events}}` (`date_time`);"
mysqlV15DownSQL = "DROP TABLE `{{defender_events}}` CASCADE;" +
"DROP TABLE `{{defender_hosts}}` CASCADE;"
)
// MySQLProvider auth provider for MySQL/MariaDB database
@@ -311,6 +324,38 @@ func (p *MySQLProvider) updateShareLastUse(shareID string, numTokens int) error
return sqlCommonUpdateShareLastUse(shareID, numTokens, p.dbHandle)
}
func (p *MySQLProvider) getDefenderHosts(from int64, limit int) ([]*DefenderEntry, error) {
return sqlCommonGetDefenderHosts(from, limit, p.dbHandle)
}
func (p *MySQLProvider) getDefenderHostByIP(ip string, from int64) (*DefenderEntry, error) {
return sqlCommonGetDefenderHostByIP(ip, from, p.dbHandle)
}
func (p *MySQLProvider) isDefenderHostBanned(ip string) (*DefenderEntry, error) {
return sqlCommonIsDefenderHostBanned(ip, p.dbHandle)
}
func (p *MySQLProvider) updateDefenderBanTime(ip string, minutes int) error {
return sqlCommonDefenderIncrementBanTime(ip, minutes, p.dbHandle)
}
func (p *MySQLProvider) deleteDefenderHost(ip string) error {
return sqlCommonDeleteDefenderHost(ip, p.dbHandle)
}
func (p *MySQLProvider) addDefenderEvent(ip string, score int) error {
return sqlCommonAddDefenderHostAndEvent(ip, score, p.dbHandle)
}
func (p *MySQLProvider) setDefenderBanTime(ip string, banTime int64) error {
return sqlCommonSetDefenderBanTime(ip, banTime, p.dbHandle)
}
func (p *MySQLProvider) cleanupDefender(from int64) error {
return sqlCommonDefenderCleanup(from, p.dbHandle)
}
func (p *MySQLProvider) close() error {
return p.dbHandle.Close()
}
@@ -362,6 +407,8 @@ func (p *MySQLProvider) migrateDatabase() error {
return updateMySQLDatabaseFromV12(p.dbHandle)
case version == 13:
return updateMySQLDatabaseFromV13(p.dbHandle)
case version == 14:
return updateMySQLDatabaseFromV14(p.dbHandle)
default:
if version > sqlDatabaseVersion {
providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version,
@@ -384,6 +431,8 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error {
}
switch dbVersion.Version {
case 15:
return downgradeMySQLDatabaseFromV15(p.dbHandle)
case 14:
return downgradeMySQLDatabaseFromV14(p.dbHandle)
case 13:
@@ -405,6 +454,8 @@ func (p *MySQLProvider) resetDatabase() error {
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
sql = strings.ReplaceAll(sql, "{{api_keys}}", sqlTableAPIKeys)
sql = strings.ReplaceAll(sql, "{{shares}}", sqlTableShares)
sql = strings.ReplaceAll(sql, "{{defender_events}}", sqlTableDefenderEvents)
sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(sql, ";"), 0)
}
@@ -430,7 +481,21 @@ func updateMySQLDatabaseFromV12(dbHandle *sql.DB) error {
}
func updateMySQLDatabaseFromV13(dbHandle *sql.DB) error {
return updateMySQLDatabaseFrom13To14(dbHandle)
if err := updateMySQLDatabaseFrom13To14(dbHandle); err != nil {
return err
}
return updateMySQLDatabaseFromV14(dbHandle)
}
func updateMySQLDatabaseFromV14(dbHandle *sql.DB) error {
return updateMySQLDatabaseFrom14To15(dbHandle)
}
func downgradeMySQLDatabaseFromV15(dbHandle *sql.DB) error {
if err := downgradeMySQLDatabaseFrom15To14(dbHandle); err != nil {
return err
}
return downgradeMySQLDatabaseFromV14(dbHandle)
}
func downgradeMySQLDatabaseFromV14(dbHandle *sql.DB) error {
@@ -467,6 +532,23 @@ func updateMySQLDatabaseFrom13To14(dbHandle *sql.DB) error {
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 14)
}
func updateMySQLDatabaseFrom14To15(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 14 -> 15")
providerLog(logger.LevelInfo, "updating database version: 14 -> 15")
sql := strings.ReplaceAll(mysqlV15SQL, "{{defender_events}}", sqlTableDefenderEvents)
sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 15)
}
func downgradeMySQLDatabaseFrom15To14(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 15 -> 14")
providerLog(logger.LevelInfo, "downgrading database version: 15 -> 14")
sql := strings.ReplaceAll(mysqlV15DownSQL, "{{defender_events}}", sqlTableDefenderEvents)
sql = strings.ReplaceAll(sql, "{{defender_hosts}}", sqlTableDefenderHosts)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 14)
}
func downgradeMySQLDatabaseFrom14To13(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 14 -> 13")
providerLog(logger.LevelInfo, "downgrading database version: 14 -> 13")