Compare commits

..

27 Commits

Author SHA1 Message Date
Nicola Murino
665016ed1e set version to 2.3.3
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-08-05 09:47:00 +02:00
Nicola Murino
97d5680d1e azblob: fix SAS URL with embedded container name
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-08-01 21:52:32 +02:00
Nicola Murino
e7866047aa allow to edit profile to users logged in via OIDC
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-08-01 21:49:53 +02:00
Nicola Murino
5f313cc6be macOS: add config file search path
this way the default config file is used in brew package if no config file is
set

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-29 17:55:01 +02:00
Nicola Murino
3c2c703408 user templates: apply placeholders also for start directory
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-27 19:09:54 +02:00
Nicola Murino
78a399eed4 download as zip: improve filename
include username and also filename/directory name if the user downloads
a single file/directory

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-26 17:53:04 +02:00
Nicola Murino
e6d434654d backport from main
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-24 08:56:31 +02:00
Nicola Murino
d34446e6e9 web client: add HTML5 player
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-23 16:30:27 +02:00
Nicola Murino
2da19ef233 backport OIDC related changes from main
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-23 15:31:57 +02:00
Nicola Murino
b34bc2b818 add license header to source files
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-18 13:43:25 +02:00
Nicola Murino
378995147b try to better highlight donations and sponsorships options ...
... and to better explain why they are required.

Please don't say "someone else will help the project, I'll just use it"

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-16 20:29:10 +02:00
Nicola Murino
6b995db864 oidc: allow to configure oauth2 scopes
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-16 19:25:04 +02:00
Nicola Murino
371012a46e backport some fixes from main
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-15 20:09:06 +02:00
Nicola Murino
d3d788c8d0 s3: improve rename performance
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-30 18:25:40 +02:00
maximethebault
756b122ab8 S3: Fix timeout error when renaming large files (#899)
Remove AWS SDK Transport ResponseHeaderTimeout (finer-grained timeout are already handled by the callers)
Lower the threshold for MultipartCopy (5GB -> 500MB) to improve copy performance and reduce chance of hitting Single part copy timeout

Fixes #898

Signed-off-by: Maxime Thébault <contact@maximethebault.me>
2022-06-30 10:25:04 +02:00
Nicola Murino
e244ba37b2 config: fix replace from env vars for some sub list
ensure to merge configuration from files with configuration from env

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-28 19:17:16 +02:00
Nicola Murino
5610b98d19 fix get branding from env
Fixes #895

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-28 10:46:25 +02:00
Nicola Murino
b3ca20b5e6 dataprovider: fix sql tables prefix handling
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-24 12:26:43 +02:00
Nicola Murino
d0b6ca8d2f backup: include folders set on groups
Fixes #885

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-21 14:13:25 +02:00
Nicola Murino
550158ff4b fix database reset
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-13 19:40:24 +02:00
Nicola Murino
14a3803c8f OpenAPI schema: improve compatibility with some generators
Fixes #875

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-11 19:00:04 +02:00
Nicola Murino
ca4da2f64e set version to 2.3.1
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-10 18:42:13 +02:00
Nicola Murino
049c2b7430 mysql: groups is a reserved keyfrom since MySQL 8.0.2
add mysql to CI

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-10 17:36:26 +02:00
Nicola Murino
7fd5558400 parse IP proxy header also if listening on UNIX domain socket
Fixes #867

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-09 09:48:39 +02:00
Nicola Murino
b60255752f web UIs: fix date formatting on Safari
Fixes #869

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-09 09:47:02 +02:00
Nicola Murino
37f79650c8 APT and YUM repo are now available
This is possible thanks to the Oregon State University's free
mirroring service

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-09 09:46:57 +02:00
Nicola Murino
8988d6542b create branch 2.3.x
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-06 19:01:24 +02:00
340 changed files with 8138 additions and 33027 deletions

View File

@@ -2,7 +2,7 @@ name: CI
on:
push:
branches: [main]
branches: [2.3.x]
pull_request:
jobs:
@@ -11,11 +11,11 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
go: [1.19]
go: [1.18]
os: [ubuntu-latest, macos-latest]
upload-coverage: [true]
include:
- go: 1.19
- go: 1.18
os: windows-latest
upload-coverage: false
@@ -32,24 +32,22 @@ jobs:
- name: Build for Linux/macOS x86_64
if: startsWith(matrix.os, 'windows-') != true
run: |
go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo
go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo
cd tests/eventsearcher
go build -trimpath -ldflags "-s -w" -o eventsearcher
cd -
cd tests/ipfilter
go build -trimpath -ldflags "-s -w" -o ipfilter
cd -
./sftpgo initprovider
./sftpgo resetprovider --force
- name: Build for macOS arm64
if: startsWith(matrix.os, 'macos-') == true
run: CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 SDKROOT=$(xcrun --sdk macosx --show-sdk-path) go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo_arm64
run: CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 SDKROOT=$(xcrun --sdk macosx --show-sdk-path) go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo_arm64
- name: Build for Windows
if: startsWith(matrix.os, 'windows-')
run: |
$GIT_COMMIT = (git describe --always --abbrev=8 --dirty) | Out-String
$GIT_COMMIT = (git describe --always --dirty) | Out-String
$DATE_TIME = ([datetime]::Now.ToUniversalTime().toString("yyyy-MM-ddTHH:mm:ssZ")) | Out-String
$LATEST_TAG = ((git describe --tags $(git rev-list --tags --max-count=1)) | Out-String).Trim()
$REV_LIST=$LATEST_TAG+"..HEAD"
@@ -57,7 +55,7 @@ jobs:
$FILE_VERSION = $LATEST_TAG.substring(1) + "." + $COMMITS_FROM_TAG
go install github.com/tc-hib/go-winres@latest
go-winres simply --arch amd64 --product-version $LATEST_TAG-dev-$GIT_COMMIT --file-version $FILE_VERSION --file-description "SFTPGo server" --product-name SFTPGo --copyright "AGPL-3.0" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico
go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o sftpgo.exe
go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/version.date=$DATE_TIME" -o sftpgo.exe
cd tests/eventsearcher
go build -trimpath -ldflags "-s -w" -o eventsearcher.exe
cd ../..
@@ -69,17 +67,17 @@ jobs:
$Env:GOOS='windows'
$Env:GOARCH='arm64'
go-winres simply --arch arm64 --product-version $LATEST_TAG-dev-$GIT_COMMIT --file-version $FILE_VERSION --file-description "SFTPGo server" --product-name SFTPGo --copyright "AGPL-3.0" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico
go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\arm64\sftpgo.exe
go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/version.date=$DATE_TIME" -o .\arm64\sftpgo.exe
mkdir x86
$Env:GOARCH='386'
go-winres simply --arch 386 --product-version $LATEST_TAG-dev-$GIT_COMMIT --file-version $FILE_VERSION --file-description "SFTPGo server" --product-name SFTPGo --copyright "AGPL-3.0" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico
go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\x86\sftpgo.exe
go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/version.date=$DATE_TIME" -o .\x86\sftpgo.exe
Remove-Item Env:\CGO_ENABLED
Remove-Item Env:\GOOS
Remove-Item Env:\GOARCH
- name: Run test cases using SQLite provider
run: go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 15m ./... -coverprofile=coverage.txt -covermode=atomic
run: go test -v -p 1 -timeout 15m ./... -coverprofile=coverage.txt -covermode=atomic
- name: Upload coverage to Codecov
if: ${{ matrix.upload-coverage }}
@@ -90,21 +88,21 @@ jobs:
- name: Run test cases using bolt provider
run: |
go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 2m ./internal/config -covermode=atomic
go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 5m ./internal/common -covermode=atomic
go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 5m ./internal/httpd -covermode=atomic
go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 8m ./internal/sftpd -covermode=atomic
go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 5m ./internal/ftpd -covermode=atomic
go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 5m ./internal/webdavd -covermode=atomic
go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 2m ./internal/telemetry -covermode=atomic
go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 2m ./internal/mfa -covermode=atomic
go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 2m ./internal/command -covermode=atomic
go test -v -p 1 -timeout 2m ./config -covermode=atomic
go test -v -p 1 -timeout 5m ./common -covermode=atomic
go test -v -p 1 -timeout 5m ./httpd -covermode=atomic
go test -v -p 1 -timeout 8m ./sftpd -covermode=atomic
go test -v -p 1 -timeout 5m ./ftpd -covermode=atomic
go test -v -p 1 -timeout 5m ./webdavd -covermode=atomic
go test -v -p 1 -timeout 2m ./telemetry -covermode=atomic
go test -v -p 1 -timeout 2m ./mfa -covermode=atomic
go test -v -p 1 -timeout 2m ./command -covermode=atomic
env:
SFTPGO_DATA_PROVIDER__DRIVER: bolt
SFTPGO_DATA_PROVIDER__NAME: 'sftpgo_bolt.db'
- name: Run test cases using memory provider
run: go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 15m ./... -covermode=atomic
run: go test -v -p 1 -timeout 15m ./... -covermode=atomic
env:
SFTPGO_DATA_PROVIDER__DRIVER: memory
SFTPGO_DATA_PROVIDER__NAME: ''
@@ -222,24 +220,6 @@ jobs:
name: sftpgo-${{ matrix.os }}-go-${{ matrix.go }}
path: output
test-bundle:
name: Build in bundle mode
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19
- name: Build
run: |
cp -r openapi static templates internal/bundle/
go build -trimpath -tags nopgxregisterdefaulttypes,bundle -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo
./sftpgo -v
test-goarch-386:
name: Run test cases on 32-bit arch
runs-on: ubuntu-latest
@@ -250,7 +230,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19
go-version: 1.18
- name: Build
run: |
@@ -264,7 +244,7 @@ jobs:
GOARCH: 386
- name: Run test cases
run: go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 15m ./... -covermode=atomic
run: go test -v -p 1 -timeout 15m ./... -covermode=atomic
env:
SFTPGO_DATA_PROVIDER__DRIVER: memory
SFTPGO_DATA_PROVIDER__NAME: ''
@@ -324,11 +304,10 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19
go-version: 1.18
- name: Build
run: |
go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo
cd tests/eventsearcher
go build -trimpath -ldflags "-s -w" -o eventsearcher
cd -
@@ -338,9 +317,7 @@ jobs:
- name: Run tests using PostgreSQL provider
run: |
./sftpgo initprovider
./sftpgo resetprovider --force
go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 15m ./... -covermode=atomic
go test -v -p 1 -timeout 15m ./... -covermode=atomic
env:
SFTPGO_DATA_PROVIDER__DRIVER: postgresql
SFTPGO_DATA_PROVIDER__NAME: sftpgo
@@ -351,9 +328,7 @@ jobs:
- name: Run tests using MySQL provider
run: |
./sftpgo initprovider
./sftpgo resetprovider --force
go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 15m ./... -covermode=atomic
go test -v -p 1 -timeout 15m ./... -covermode=atomic
env:
SFTPGO_DATA_PROVIDER__DRIVER: mysql
SFTPGO_DATA_PROVIDER__NAME: sftpgo
@@ -364,9 +339,7 @@ jobs:
- name: Run tests using MariaDB provider
run: |
./sftpgo initprovider
./sftpgo resetprovider --force
go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 15m ./... -covermode=atomic
go test -v -p 1 -timeout 15m ./... -covermode=atomic
env:
SFTPGO_DATA_PROVIDER__DRIVER: mysql
SFTPGO_DATA_PROVIDER__NAME: sftpgo
@@ -381,9 +354,7 @@ jobs:
docker run --rm --name crdb --health-cmd "curl -I http://127.0.0.1:8080" --health-interval 10s --health-timeout 5s --health-retries 6 -p 26257:26257 -d cockroachdb/cockroach:latest start-single-node --insecure --listen-addr :26257
sleep 10
docker exec crdb cockroach sql --insecure -e 'create database "sftpgo"'
./sftpgo initprovider
./sftpgo resetprovider --force
go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 15m ./... -covermode=atomic
go test -v -p 1 -timeout 15m ./... -covermode=atomic
docker stop crdb
env:
SFTPGO_DATA_PROVIDER__DRIVER: cockroachdb
@@ -396,13 +367,12 @@ jobs:
build-linux-packages:
name: Build Linux packages
runs-on: ubuntu-latest
runs-on: ubuntu-18.04
strategy:
matrix:
include:
- arch: amd64
distro: ubuntu:18.04
go: latest
go: 1.18
go-arch: amd64
- arch: aarch64
distro: ubuntu18.04
@@ -420,36 +390,16 @@ jobs:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Get commit SHA
id: get_commit
run: echo "COMMIT=${GITHUB_SHA::8}" >> $GITHUB_OUTPUT
shell: bash
- name: Set up Go
if: ${{ matrix.arch == 'amd64' }}
uses: actions/setup-go@v3
with:
go-version: ${{ matrix.go }}
- name: Build on amd64
if: ${{ matrix.arch == 'amd64' }}
run: |
echo '#!/bin/bash' > build.sh
echo '' >> build.sh
echo 'set -e' >> build.sh
echo 'apt-get update -q -y' >> build.sh
echo 'apt-get install -q -y curl gcc' >> build.sh
if [ ${{ matrix.go }} == 'latest' ]
then
echo 'GO_VERSION=$(curl -L https://go.dev/VERSION?m=text)' >> build.sh
else
echo 'GO_VERSION=${{ matrix.go }}' >> build.sh
fi
echo 'GO_DOWNLOAD_ARCH=${{ matrix.go-arch }}' >> build.sh
echo 'curl --retry 5 --retry-delay 2 --connect-timeout 10 -o go.tar.gz -L https://go.dev/dl/${GO_VERSION}.linux-${GO_DOWNLOAD_ARCH}.tar.gz' >> build.sh
echo 'tar -C /usr/local -xzf go.tar.gz' >> build.sh
echo 'export PATH=$PATH:/usr/local/go/bin' >> build.sh
echo 'go version' >> build.sh
echo 'cd /usr/local/src' >> build.sh
echo 'go build -buildvcs=false -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_commit.outputs.COMMIT }} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo' >> build.sh
chmod 755 build.sh
docker run --rm --name ubuntu-build --mount type=bind,source=`pwd`,target=/usr/local/src ${{ matrix.distro }} /usr/local/src/build.sh
go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo
mkdir -p output/{init,bash_completion,zsh_completion}
cp sftpgo.json output/
cp -r templates output/
@@ -476,7 +426,7 @@ jobs:
shell: /bin/bash
install: |
apt-get update -q -y
apt-get install -q -y curl gcc
apt-get install -q -y curl gcc git
if [ ${{ matrix.go }} == 'latest' ]
then
GO_VERSION=$(curl -L https://go.dev/VERSION?m=text)
@@ -492,12 +442,11 @@ jobs:
tar -C /usr/local -xzf go.tar.gz
run: |
export PATH=$PATH:/usr/local/go/bin
go version
if [ ${{ matrix.arch}} == 'armv7' ]
then
export GOARM=7
fi
go build -buildvcs=false -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_commit.outputs.COMMIT }} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo
go build -buildvcs=false -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo
mkdir -p output/{init,bash_completion,zsh_completion}
cp sftpgo.json output/
cp -r templates output/
@@ -523,7 +472,7 @@ jobs:
cd pkgs
./build.sh
PKG_VERSION=$(cat dist/version)
echo "pkg-version=${PKG_VERSION}" >> $GITHUB_OUTPUT
echo "::set-output name=pkg-version::${PKG_VERSION}"
- name: Upload Debian Package
uses: actions/upload-artifact@v3
@@ -544,7 +493,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.19
go-version: 1.18
- uses: actions/checkout@v3
- name: Run golangci-lint
uses: golangci/golangci-lint-action@v3

View File

@@ -5,7 +5,7 @@ on:
# - cron: '0 4 * * *' # everyday at 4:00 AM UTC
push:
branches:
- main
- 2.3.x
tags:
- v*
pull_request:
@@ -28,9 +28,6 @@ jobs:
- os: ubuntu-latest
docker_pkg: distroless
optional_deps: false
- os: ubuntu-latest
docker_pkg: debian-plugins
optional_deps: true
steps:
- name: Checkout
uses: actions/checkout@v3
@@ -67,9 +64,6 @@ jobs:
VERSION="${VERSION}-distroless"
VERSION_SLIM="${VERSION}-slim"
DOCKERFILE=Dockerfile.distroless
elif [[ $DOCKER_PKG == debian-plugins ]]; then
VERSION="${VERSION}-plugins"
VERSION_SLIM="${VERSION}-slim"
fi
DOCKER_IMAGES=("drakkan/sftpgo" "ghcr.io/drakkan/sftpgo")
TAGS="${DOCKER_IMAGES[0]}:${VERSION}"
@@ -95,13 +89,6 @@ jobs:
fi
TAGS="${TAGS},${DOCKER_IMAGE}:distroless"
TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:distroless-slim"
elif [[ $DOCKER_PKG == debian-plugins ]]; then
if [[ -n $MAJOR && -n $MINOR ]]; then
TAGS="${TAGS},${DOCKER_IMAGE}:${MINOR}-plugins,${DOCKER_IMAGE}:${MAJOR}-plugins"
TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:${MINOR}-plugins-slim,${DOCKER_IMAGE}:${MAJOR}-plugins-slim"
fi
TAGS="${TAGS},${DOCKER_IMAGE}:plugins"
TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:plugins-slim"
else
if [[ -n $MAJOR && -n $MINOR ]]; then
TAGS="${TAGS},${DOCKER_IMAGE}:${MINOR}-alpine,${DOCKER_IMAGE}:${MAJOR}-alpine"
@@ -114,22 +101,17 @@ jobs:
done
if [[ $OPTIONAL_DEPS == true ]]; then
echo "version=${VERSION}" >> $GITHUB_OUTPUT
echo "tags=${TAGS}" >> $GITHUB_OUTPUT
echo "full=true" >> $GITHUB_OUTPUT
echo ::set-output name=version::${VERSION}
echo ::set-output name=tags::${TAGS}
echo ::set-output name=full::true
else
echo "version=${VERSION_SLIM}" >> $GITHUB_OUTPUT
echo "tags=${TAGS_SLIM}" >> $GITHUB_OUTPUT
echo "full=false" >> $GITHUB_OUTPUT
echo ::set-output name=version::${VERSION_SLIM}
echo ::set-output name=tags::${TAGS_SLIM}
echo ::set-output name=full::false
fi
if [[ $DOCKER_PKG == debian-plugins ]]; then
echo "plugins=true" >> $GITHUB_OUTPUT
else
echo "plugins=false" >> $GITHUB_OUTPUT
fi
echo "dockerfile=${DOCKERFILE}" >> $GITHUB_OUTPUT
echo "created=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT
echo "sha=${GITHUB_SHA::8}" >> $GITHUB_OUTPUT
echo ::set-output name=dockerfile::${DOCKERFILE}
echo ::set-output name=created::$(date -u +'%Y-%m-%dT%H:%M:%SZ')
echo ::set-output name=sha::${GITHUB_SHA::8}
env:
DOCKER_PKG: ${{ matrix.docker_pkg }}
OPTIONAL_DEPS: ${{ matrix.optional_deps }}
@@ -168,8 +150,6 @@ jobs:
build-args: |
COMMIT_SHA=${{ steps.info.outputs.sha }}
INSTALL_OPTIONAL_PACKAGES=${{ steps.info.outputs.full }}
DOWNLOAD_PLUGINS=${{ steps.info.outputs.plugins }}
FEATURES=nopgxregisterdefaulttypes
labels: |
org.opencontainers.image.title=SFTPGo
org.opencontainers.image.description=Fully featured and highly configurable SFTP server with optional HTTP, FTP/S and WebDAV support
@@ -179,4 +159,4 @@ jobs:
org.opencontainers.image.version=${{ steps.info.outputs.version }}
org.opencontainers.image.created=${{ steps.info.outputs.created }}
org.opencontainers.image.revision=${{ github.sha }}
org.opencontainers.image.licenses=AGPL-3.0-only
org.opencontainers.image.licenses=AGPL-3.0

View File

@@ -5,7 +5,7 @@ on:
tags: 'v*'
env:
GO_VERSION: 1.19.2
GO_VERSION: 1.18.5
jobs:
prepare-sources-with-deps:
@@ -20,13 +20,12 @@ jobs:
- name: Get SFTPGo version
id: get_version
run: echo "VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT
run: echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//}
- name: Prepare release
run: |
go mod vendor
echo "${SFTPGO_VERSION}" > VERSION.txt
echo "${GITHUB_SHA::8}" >> VERSION.txt
tar cJvf sftpgo_${SFTPGO_VERSION}_src_with_deps.tar.xz *
env:
SFTPGO_VERSION: ${{ steps.get_version.outputs.VERSION }}
@@ -54,7 +53,7 @@ jobs:
- name: Get SFTPGo version
id: get_version
run: echo "VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT
run: echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//}
shell: bash
- name: Get OS name
@@ -62,9 +61,9 @@ jobs:
run: |
if [[ $MATRIX_OS =~ ^macos.* ]]
then
echo "OS=macOS" >> $GITHUB_OUTPUT
echo ::set-output name=OS::macOS
else
echo "OS=windows" >> $GITHUB_OUTPUT
echo ::set-output name=OS::windows
fi
shell: bash
env:
@@ -72,31 +71,31 @@ jobs:
- name: Build for macOS x86_64
if: startsWith(matrix.os, 'windows-') != true
run: go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo
run: go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo
- name: Build for macOS arm64
if: startsWith(matrix.os, 'macos-') == true
run: CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 SDKROOT=$(xcrun --sdk macosx --show-sdk-path) go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo_arm64
run: CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 SDKROOT=$(xcrun --sdk macosx --show-sdk-path) go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo_arm64
- name: Build for Windows
if: startsWith(matrix.os, 'windows-')
run: |
$GIT_COMMIT = (git describe --always --abbrev=8 --dirty) | Out-String
$GIT_COMMIT = (git describe --always --dirty) | Out-String
$DATE_TIME = ([datetime]::Now.ToUniversalTime().toString("yyyy-MM-ddTHH:mm:ssZ")) | Out-String
$FILE_VERSION = $Env:SFTPGO_VERSION.substring(1) + ".0"
go install github.com/tc-hib/go-winres@latest
go-winres simply --arch amd64 --product-version $Env:SFTPGO_VERSION-$GIT_COMMIT --file-version $FILE_VERSION --file-description "SFTPGo server" --product-name SFTPGo --copyright "AGPL-3.0" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico
go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o sftpgo.exe
go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/version.date=$DATE_TIME" -o sftpgo.exe
mkdir arm64
$Env:CGO_ENABLED='0'
$Env:GOOS='windows'
$Env:GOARCH='arm64'
go-winres simply --arch arm64 --product-version $Env:SFTPGO_VERSION-$GIT_COMMIT --file-version $FILE_VERSION --file-description "SFTPGo server" --product-name SFTPGo --copyright "AGPL-3.0" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico
go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\arm64\sftpgo.exe
go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/version.date=$DATE_TIME" -o .\arm64\sftpgo.exe
mkdir x86
$Env:GOARCH='386'
go-winres simply --arch 386 --product-version $Env:SFTPGO_VERSION-$GIT_COMMIT --file-version $FILE_VERSION --file-description "SFTPGo server" --product-name SFTPGo --copyright "AGPL-3.0" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico
go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\x86\sftpgo.exe
go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/version.date=$DATE_TIME" -o .\x86\sftpgo.exe
Remove-Item Env:\CGO_ENABLED
Remove-Item Env:\GOOS
Remove-Item Env:\GOARCH
@@ -255,12 +254,11 @@ jobs:
prepare-linux:
name: Prepare Linux binaries
runs-on: ubuntu-latest
runs-on: ubuntu-18.04
strategy:
matrix:
include:
- arch: amd64
distro: ubuntu:18.04
go-arch: amd64
deb-arch: amd64
rpm-arch: x86_64
@@ -286,13 +284,17 @@ jobs:
steps:
- uses: actions/checkout@v3
- name: Set up Go
if: ${{ matrix.arch == 'amd64' }}
uses: actions/setup-go@v3
with:
go-version: ${{ env.GO_VERSION }}
- name: Get versions
id: get_version
run: |
echo "SFTPGO_VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT
echo "GO_VERSION=${GO_VERSION}" >> $GITHUB_OUTPUT
echo "COMMIT=${GITHUB_SHA::8}" >> $GITHUB_OUTPUT
echo ::set-output name=SFTPGO_VERSION::${GITHUB_REF/refs\/tags\//}
echo ::set-output name=GO_VERSION::${GO_VERSION}
shell: bash
env:
GO_VERSION: ${{ env.GO_VERSION }}
@@ -300,20 +302,7 @@ jobs:
- name: Build on amd64
if: ${{ matrix.arch == 'amd64' }}
run: |
echo '#!/bin/bash' > build.sh
echo '' >> build.sh
echo 'set -e' >> build.sh
echo 'apt-get update -q -y' >> build.sh
echo 'apt-get install -q -y curl gcc' >> build.sh
echo 'curl --retry 5 --retry-delay 2 --connect-timeout 10 -o go.tar.gz -L https://go.dev/dl/go${{ steps.get_version.outputs.GO_VERSION }}.linux-${{ matrix.go-arch }}.tar.gz' >> build.sh
echo 'tar -C /usr/local -xzf go.tar.gz' >> build.sh
echo 'export PATH=$PATH:/usr/local/go/bin' >> build.sh
echo 'go version' >> build.sh
echo 'cd /usr/local/src' >> build.sh
echo 'go build -buildvcs=false -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_version.outputs.COMMIT }} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo' >> build.sh
chmod 755 build.sh
docker run --rm --name ubuntu-build --mount type=bind,source=`pwd`,target=/usr/local/src ${{ matrix.distro }} /usr/local/src/build.sh
go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo
mkdir -p output/{init,sqlite,bash_completion,zsh_completion}
echo "For documentation please take a look here:" > output/README.txt
echo "" >> output/README.txt
@@ -351,7 +340,7 @@ jobs:
shell: /bin/bash
install: |
apt-get update -q -y
apt-get install -q -y curl gcc xz-utils
apt-get install -q -y curl gcc git xz-utils
GO_DOWNLOAD_ARCH=${{ matrix.go-arch }}
if [ ${{ matrix.arch}} == 'armv7' ]
then
@@ -361,8 +350,7 @@ jobs:
tar -C /usr/local -xzf go.tar.gz
run: |
export PATH=$PATH:/usr/local/go/bin
go version
go build -buildvcs=false -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_version.outputs.COMMIT }} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo
go build -buildvcs=false -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo
mkdir -p output/{init,sqlite,bash_completion,zsh_completion}
echo "For documentation please take a look here:" > output/README.txt
echo "" >> output/README.txt
@@ -398,7 +386,7 @@ jobs:
cd pkgs
./build.sh
PKG_VERSION=${SFTPGO_VERSION:1}
echo "pkg-version=${PKG_VERSION}" >> $GITHUB_OUTPUT
echo "::set-output name=pkg-version::${PKG_VERSION}"
env:
SFTPGO_VERSION: ${{ steps.get_version.outputs.SFTPGO_VERSION }}
@@ -425,7 +413,7 @@ jobs:
- name: Get versions
id: get_version
run: |
echo "SFTPGO_VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT
echo ::set-output name=SFTPGO_VERSION::${GITHUB_REF/refs\/tags\//}
shell: bash
- name: Download amd64 artifact
@@ -485,8 +473,8 @@ jobs:
run: |
SFTPGO_VERSION=${GITHUB_REF/refs\/tags\//}
PKG_VERSION=${SFTPGO_VERSION:1}
echo "SFTPGO_VERSION=${SFTPGO_VERSION}" >> $GITHUB_OUTPUT
echo "PKG_VERSION=${PKG_VERSION}" >> $GITHUB_OUTPUT
echo ::set-output name=SFTPGO_VERSION::${SFTPGO_VERSION}
echo "::set-output name=PKG_VERSION::${PKG_VERSION}"
shell: bash
- name: Download amd64 artifact

View File

@@ -1,5 +1,5 @@
run:
timeout: 10m
timeout: 5m
issues-exit-code: 1
tests: true

View File

@@ -1,4 +1,4 @@
FROM golang:1.19-bullseye as builder
FROM golang:1.18-bullseye as builder
ENV GOFLAGS="-mod=readonly"
@@ -20,13 +20,8 @@ ARG FEATURES
COPY . .
RUN set -xe && \
export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --abbrev=8 --dirty)} && \
go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -v -o sftpgo
# Set to "true" to download the "official" plugins in /usr/local/bin
ARG DOWNLOAD_PLUGINS=false
RUN if [ "${DOWNLOAD_PLUGINS}" = "true" ]; then apt-get update && apt-get install --no-install-recommends -y curl && ./docker/scripts/download-plugins.sh; fi
export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --dirty)} && \
go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -v -o sftpgo
FROM debian:bullseye-slim
@@ -48,7 +43,7 @@ COPY --from=builder /workspace/sftpgo.json /etc/sftpgo/sftpgo.json
COPY --from=builder /workspace/templates /usr/share/sftpgo/templates
COPY --from=builder /workspace/static /usr/share/sftpgo/static
COPY --from=builder /workspace/openapi /usr/share/sftpgo/openapi
COPY --from=builder /workspace/sftpgo /usr/local/bin/sftpgo-plugin-* /usr/local/bin/
COPY --from=builder /workspace/sftpgo /usr/local/bin/
# Log to the stdout so the logs will be available using docker logs
ENV SFTPGO_LOG_FILE_PATH=""

View File

@@ -1,4 +1,4 @@
FROM golang:1.19-alpine3.16 AS builder
FROM golang:1.18-alpine3.16 AS builder
ENV GOFLAGS="-mod=readonly"
@@ -22,8 +22,8 @@ ARG FEATURES
COPY . .
RUN set -xe && \
export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --abbrev=8 --dirty)} && \
go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -v -o sftpgo
export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --dirty)} && \
go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -v -o sftpgo
FROM alpine:3.16

View File

@@ -1,4 +1,4 @@
FROM golang:1.19-bullseye as builder
FROM golang:1.18-bullseye as builder
ENV CGO_ENABLED=0 GOFLAGS="-mod=readonly"
@@ -20,8 +20,8 @@ ARG FEATURES=nosqlite
COPY . .
RUN set -xe && \
export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --abbrev=8 --dirty)} && \
go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -v -o sftpgo
export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --dirty)} && \
go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -v -o sftpgo
# Modify the default configuration file
RUN sed -i 's|"users_base_dir": "",|"users_base_dir": "/srv/sftpgo/data",|' sftpgo.json && \

View File

@@ -1,8 +1,8 @@
# SFTPGo
[![CI Status](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg?branch=main&event=push)](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg?branch=main&event=push)
![CI Status](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg?branch=main&event=push)
[![Code Coverage](https://codecov.io/gh/drakkan/sftpgo/branch/main/graph/badge.svg)](https://codecov.io/gh/drakkan/sftpgo/branch/main)
[![License: AGPL-3.0-only](https://img.shields.io/badge/License-AGPLv3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0)
[![License: AGPL v3](https://img.shields.io/badge/License-AGPLv3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0)
[![Docker Pulls](https://img.shields.io/docker/pulls/drakkan/sftpgo)](https://hub.docker.com/r/drakkan/sftpgo)
[![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go)
@@ -26,24 +26,10 @@ If you just take and don't return anything back, the project will die in the lon
More [info](https://github.com/drakkan/sftpgo/issues/452).
### Thank you to our sponsors
#### Platinum sponsors
[<img src="./img/Aledade_logo.png" alt="Aledade logo" width="202" height="70">](https://www.aledade.com/)
#### Bronze sponsors
Thank you to our sponsors!
[<img src="https://www.7digital.com/wp-content/themes/sevendigital/images/top_logo.png" alt="7digital logo">](https://www.7digital.com/)
## Support policy
SFTPGo is an Open Source project and you can of course use it for free but please don't ask for free support as well.
We will check the reported issues to see if you are experiencing a bug and if so we'll will fix it, but will only provide support to project [sponsors/donors](#sponsors).
If you report an invalid issue or ask for step-by-step support, your issue will remain open with no answer or will be closed as invalid without further explanation. Thanks for understanding.
## Features
- Support for serving local filesystem, encrypted local filesystem, S3 Compatible Object Storage, Google Cloud Storage, Azure Blob Storage or other SFTP accounts over SFTP/SCP/FTP/WebDAV.
@@ -54,7 +40,6 @@ If you report an invalid issue or ask for step-by-step support, your issue will
- Chroot isolation for local accounts. Cloud-based accounts can be restricted to a certain base path.
- Per-user and per-directory virtual permissions, for each exposed path you can allow or deny: directory listing, upload, overwrite, download, delete, rename, create directories, create symlinks, change owner/group/file mode and modification time.
- [REST API](./docs/rest-api.md) for users and folders management, data retention, backup, restore and real time reports of the active connections with possibility of forcibly closing a connection.
- The [Event Manager](./docs/eventmanager.md) allows to define custom workflows based on server events or schedules.
- [Web based administration interface](./docs/web-admin.md) to easily manage users, folders and connections.
- [Web client interface](./docs/web-client.md) so that end users can change their credentials, manage and share their files in the browser.
- Public key and password authentication. Multiple public keys per-user are supported.
@@ -64,10 +49,10 @@ If you report an invalid issue or ask for step-by-step support, your issue will
- Per-user authentication methods.
- [Two-factor authentication](./docs/howto/two-factor-authentication.md) based on time-based one time passwords (RFC 6238) which works with Authy, Google Authenticator and other compatible apps.
- Simplified user administrations using [groups](./docs/groups.md).
- Custom authentication via [external programs/HTTP API](./docs/external-auth.md).
- Custom authentication via external programs/HTTP API.
- Web Client and Web Admin user interfaces support [OpenID Connect](https://openid.net/connect/) authentication and so they can be integrated with identity providers such as [Keycloak](https://www.keycloak.org/). You can find more details [here](./docs/oidc.md).
- [Data At Rest Encryption](./docs/dare.md).
- Dynamic user modification before login via [external programs/HTTP API](./docs/dynamic-user-mod.md).
- Dynamic user modification before login via external programs/HTTP API.
- Quota support: accounts can have individual disk quota expressed as max total size and/or max number of files.
- Bandwidth throttling, with separate settings for upload and download and overrides based on the client's IP address.
- Data transfer bandwidth limits, with total limit or separate settings for uploads and downloads and overrides based on the client's IP address. Limits can be reset using the REST API.
@@ -104,10 +89,8 @@ SFTPGo is developed and tested on Linux. After each commit, the code is automati
## Requirements
- Go as build only dependency. We support the Go version(s) used in [continuous integration workflows](./.github/workflows).
- A suitable SQL server to use as data provider:
- upstream supported versions of PostgreSQL, MySQL and MariaDB.
- CockroachDB stable.
- The SQL server is optional: you can choose to use an embedded SQLite, bolt or in memory data provider.
- A suitable SQL server to use as data provider: PostgreSQL 9.4+, MySQL 5.6+, SQLite 3.x, CockroachDB stable.
- The SQL server is optional: you can choose to use an embedded bolt database as key/value store or an in memory data provider.
## Installation
@@ -131,13 +114,7 @@ An official Docker image is available. Documentation is [here](./docker/README.m
APT and YUM repositories are [available](./docs/repo.md).
SFTPGo is also available on some marketplaces:
- [AWS Marketplace](https://aws.amazon.com/marketplace/seller-profile?id=6e849ab8-70a6-47de-9a43-13c3fa849335)
- [Azure Marketplace](https://azuremarketplace.microsoft.com/en-us/marketplace/apps/prasselsrl1645470739547.sftpgo_linux)
- [Elest.io](https://elest.io/open-source/sftpgo)
Purchasing from there will help keep SFTPGo a long-term sustainable project.
SFTPGo is also available on [AWS Marketplace](https://aws.amazon.com/marketplace/seller-profile?id=6e849ab8-70a6-47de-9a43-13c3fa849335) and [Azure Marketplace](https://azuremarketplace.microsoft.com/en-us/marketplace/apps/prasselsrl1645470739547.sftpgo_linux), purchasing from there will help keep SFTPGo a long-term sustainable project.
<details><summary>Windows packages</summary>
@@ -248,18 +225,16 @@ The `revertprovider` command is not supported for the memory provider.
Please note that we only support the current release branch and the current main branch, if you find a bug it is better to report it rather than downgrading to an older unsupported version.
## Users, groups and folders management
## Users and folders management
After starting SFTPGo you can manage users, groups, folders and other resources using:
After starting SFTPGo you can manage users and folders using:
- the [web based administration interface](./docs/web-admin.md)
- the [REST API](./docs/rest-api.md)
To support embedded data providers like `bolt` and `SQLite`, which do not support concurrent connections, we can't have a CLI that directly write users and other resources to the data provider, we always have to use the REST API.
To support embedded data providers like `bolt` and `SQLite` we can't have a CLI that directly write users and folders to the data provider, we always have to use the REST API.
Full details for users, groups, folders, admins and other resources are documented in the [OpenAPI](./openapi/openapi.yaml) schema. If you want to render the schema without importing it manually, you can explore it on [Stoplight](https://sftpgo.stoplight.io/docs/sftpgo/openapi.yaml).
:warning: SFTPGo users, groups and folders are virtual and therefore unrelated to the system ones. There is no need to create system-wide users and groups.
Full details for users, folders, admins and other resources are documented in the [OpenAPI](./openapi/openapi.yaml) schema. If you want to render the schema without importing it manually, you can explore it on [Stoplight](https://sftpgo.stoplight.io/docs/sftpgo/openapi.yaml).
## Tutorials
@@ -315,10 +290,6 @@ Each user can be mapped to another SFTP server account or a subfolder of it. Mor
Data at-rest encryption is supported via the [cryptfs backend](./docs/dare.md).
### HTTP/S backend
HTTP/S backend allows you to write your own custom storage backend by implementing a REST API. More information can be found [here](./docs/httpfs.md).
### Other Storage backends
Adding new storage backends is quite easy:
@@ -360,4 +331,4 @@ Thank you [ysura](https://www.ysura.com/) for granting me stable access to a tes
## License
GNU AGPL-3.0-only
GNU AGPLv3

View File

@@ -1,8 +1,8 @@
# SFTPGo
[![CI Status](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg?branch=main&event=push)](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg?branch=main&event=push)
![CI Status](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg?branch=main&event=push)
[![Code Coverage](https://codecov.io/gh/drakkan/sftpgo/branch/main/graph/badge.svg)](https://codecov.io/gh/drakkan/sftpgo/branch/main)
[![License: AGPL-3.0-only](https://img.shields.io/badge/License-AGPLv3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0)
[![License: AGPL v3](https://img.shields.io/badge/License-AGPLv3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0)
[![Docker Pulls](https://img.shields.io/docker/pulls/drakkan/sftpgo)](https://hub.docker.com/r/drakkan/sftpgo)
[![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go)
@@ -11,39 +11,6 @@
功能齐全、高度可配置化、支持自定义 HTTP/SFTP/S 和 WebDAV 的 SFTP 服务。
一些存储后端支持本地文件系统、加密本地文件系统、S3兼容对象存储Google Cloud 存储Azure Blob 存储SFTP。
## 赞助商
如果你觉得 SFTPGo 有用,请考虑支持这个开源项目。
维护和发展 SFTPGo 对我来说是很多工作——很容易相当于一份全职工作。
我想让 SFTPGo 成为一个可持续的长期项目,并且不想引入双重许可选项并将某些功能仅限于专有版本。
如果您使用 SFTPGo确保您所依赖的项目保持健康和维护良好符合您的最大利益。
这只能通过您的捐款和[赞助](https://github.com/sponsors/drakkan) 发生heart
如果您只是拿走任何东西而不返回任何东西,从长远来看,该项目将失败,您将被迫为类似的专有解决方案付费。
[更多信息](https://github.com/drakkan/sftpgo/issues/452)。
### 感谢我们的赞助商
#### 白金赞助商
[<img src="./img/Aledade_logo.png" alt="Aledade logo" width="202" height="70">](https://www.aledade.com/)
#### 铜牌赞助商
[<img src="https://www.7digital.com/wp-content/themes/sevendigital/images/top_logo.png" alt="7digital logo">](https://www.7digital.com/)
## 支持政策
SFTPGo 是一个开源项目,您当然可以免费使用它,但也请不要要求免费支持。
我们将检查报告的问题以查看您是否遇到错误,如果是,我们将修复它,但只会为项目赞助商/捐助者提供支持。
如果您报告无效问题或要求逐步支持,您的问题将保持打开状态而没有答案,或者将被关闭为无效而无需进一步解释。 感谢您的理解。
## 特性
- 支持服务本地文件系统、加密本地文件系统、S3 兼容对象存储、Google Cloud 存储、Azure Blob 存储或其它基于 SFTP/SCP/FTP/WebDAV 协议的 SFTP 账户。
@@ -148,7 +115,7 @@ SFTPGo 在 [AWS Marketplace](https://aws.amazon.com/marketplace/seller-profile?i
可以完整的配置项方法说明可以参考 [配置项](./docs/full-configuration.md)。
请确保按需运行之前,[初始化数据提供程序](#数据提供程序初始化和管理)。
请确保按需运行之前,[初始化数据提供程序](#data-provider-initialization-and-management)。
默认配置启动 STFPGo运行
@@ -348,4 +315,4 @@ SFTPGo 使用了 [go.mod](./go.mod) 中列出的第三方库。
## 许可证
GNU AGPL-3.0-only
GNU AGPLv3

View File

@@ -45,14 +45,13 @@ import (
"github.com/go-acme/lego/v4/registration"
"github.com/robfig/cron/v3"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/ftpd"
"github.com/drakkan/sftpgo/v2/internal/httpd"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/telemetry"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/version"
"github.com/drakkan/sftpgo/v2/internal/webdavd"
"github.com/drakkan/sftpgo/v2/ftpd"
"github.com/drakkan/sftpgo/v2/httpd"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/telemetry"
"github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/version"
"github.com/drakkan/sftpgo/v2/webdavd"
)
const (
@@ -562,24 +561,6 @@ func (c *Configuration) getCertificates() error {
return nil
}
func (c *Configuration) notifyCertificateRenewal(domain string, err error) {
if domain == "" {
domain = strings.Join(c.Domains, ",")
}
params := common.EventParams{
Name: domain,
Event: "Certificate renewal",
Timestamp: time.Now().UnixNano(),
}
if err != nil {
params.Status = 2
params.AddError(err)
} else {
params.Status = 1
}
common.HandleCertificateEvent(params)
}
func (c *Configuration) renewCertificates() error {
lockTime, err := c.getLockTime()
if err != nil {
@@ -592,28 +573,22 @@ func (c *Configuration) renewCertificates() error {
}
err = c.setLockTime()
if err != nil {
c.notifyCertificateRenewal("", err)
return err
}
account, client, err := c.setup()
if err != nil {
c.notifyCertificateRenewal("", err)
return err
}
if account.Registration == nil {
acmeLog(logger.LevelError, "cannot renew certificates, your account is not registered")
err = errors.New("cannot renew certificates, your account is not registered")
c.notifyCertificateRenewal("", err)
return err
return fmt.Errorf("cannot renew certificates, your account is not registered")
}
var errRenew error
needReload := false
for _, domain := range c.Domains {
certificates, err := c.loadCertificatesForDomain(domain)
if err != nil {
c.notifyCertificateRenewal(domain, err)
errRenew = err
continue
return err
}
cert := certificates[0]
if !c.needRenewal(cert, domain) {
@@ -621,10 +596,8 @@ func (c *Configuration) renewCertificates() error {
}
err = c.obtainAndSaveCertificate(client, domain)
if err != nil {
c.notifyCertificateRenewal(domain, err)
errRenew = err
} else {
c.notifyCertificateRenewal(domain, nil)
needReload = true
}
}

View File

@@ -20,10 +20,10 @@ import (
"github.com/rs/zerolog"
"github.com/spf13/cobra"
"github.com/drakkan/sftpgo/v2/internal/acme"
"github.com/drakkan/sftpgo/v2/internal/config"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/acme"
"github.com/drakkan/sftpgo/v2/config"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
var (
@@ -40,13 +40,13 @@ Certificates are saved in the configured "certs_path".
After this initial step, the certificates are automatically checked and
renewed by the SFTPGo service
`,
Run: func(_ *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, args []string) {
logger.DisableLogger()
logger.EnableConsoleLogger(zerolog.DebugLevel)
configDir = util.CleanDirInput(configDir)
err := config.LoadConfig(configDir, configFile)
if err != nil {
logger.ErrorToConsole("Unable to initialize ACME, config load error: %v", err)
logger.ErrorToConsole("Unable to initialize data provider, config load error: %v", err)
return
}
acmeConfig := config.GetACMEConfig()

View File

@@ -21,4 +21,4 @@ import (
"github.com/spf13/cobra"
)
func addAWSContainerFlags(_ *cobra.Command) {}
func addAWSContainerFlags(cmd *cobra.Command) {}

View File

@@ -24,8 +24,8 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/cobra/doc"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/version"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/version"
)
var (
@@ -38,7 +38,7 @@ command-line interface.
By default, it creates the man page files in the "man" directory under the
current directory.
`,
Run: func(cmd *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, args []string) {
logger.DisableLogger()
logger.EnableConsoleLogger(zerolog.DebugLevel)
if _, err := os.Stat(manDir); errors.Is(err, fs.ErrNotExist) {

View File

@@ -21,11 +21,11 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/drakkan/sftpgo/v2/internal/config"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/service"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/config"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/service"
"github.com/drakkan/sftpgo/v2/util"
)
var (
@@ -50,7 +50,7 @@ $ sftpgo initprovider
Any defined action is ignored.
Please take a look at the usage below to customize the options.`,
Run: func(_ *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, args []string) {
logger.DisableLogger()
logger.EnableConsoleLogger(zerolog.DebugLevel)
configDir = util.CleanDirInput(configDir)

View File

@@ -21,8 +21,8 @@ import (
"github.com/spf13/cobra"
"github.com/drakkan/sftpgo/v2/internal/service"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/service"
"github.com/drakkan/sftpgo/v2/util"
)
var (
@@ -35,7 +35,7 @@ line flags simply use:
sftpgo service install
Please take a look at the usage below to customize the startup options`,
Run: func(_ *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, args []string) {
s := service.Service{
ConfigDir: util.CleanDirInput(configDir),
ConfigFile: configFile,
@@ -44,7 +44,7 @@ Please take a look at the usage below to customize the startup options`,
LogMaxBackups: logMaxBackups,
LogMaxAge: logMaxAge,
LogCompress: logCompress,
LogLevel: logLevel,
LogVerbose: logVerbose,
LogUTCTime: logUTCTime,
Shutdown: make(chan bool),
}
@@ -99,9 +99,8 @@ func getCustomServeFlags() []string {
result = append(result, "--"+logMaxAgeFlag)
result = append(result, strconv.Itoa(logMaxAge))
}
if logLevel != defaultLogLevel {
result = append(result, "--"+logLevelFlag)
result = append(result, logLevel)
if logVerbose != defaultLogVerbose {
result = append(result, "--"+logVerboseFlag+"=false")
}
if logUTCTime != defaultLogUTCTime {
result = append(result, "--"+logUTCTimeFlag+"=true")
@@ -109,9 +108,5 @@ func getCustomServeFlags() []string {
if logCompress != defaultLogCompress {
result = append(result, "--"+logCompressFlag+"=true")
}
if graceTime != defaultGraceTime {
result = append(result, "--"+graceTimeFlag)
result = append(result, strconv.Itoa(graceTime))
}
return result
}

View File

@@ -27,13 +27,13 @@ import (
"github.com/sftpgo/sdk"
"github.com/spf13/cobra"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/service"
"github.com/drakkan/sftpgo/v2/internal/sftpd"
"github.com/drakkan/sftpgo/v2/internal/version"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/common"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/kms"
"github.com/drakkan/sftpgo/v2/service"
"github.com/drakkan/sftpgo/v2/sftpd"
"github.com/drakkan/sftpgo/v2/version"
"github.com/drakkan/sftpgo/v2/vfs"
)
var (
@@ -45,7 +45,7 @@ var (
portablePassword string
portableStartDir string
portableLogFile string
portableLogLevel string
portableLogVerbose bool
portableLogUTCTime bool
portablePublicKeys []string
portablePermissions []string
@@ -106,7 +106,7 @@ use:
$ sftpgo portable
Please take a look at the usage below to customize the serving parameters`,
Run: func(_ *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, args []string) {
portableDir := directoryToServe
fsProvider := sdk.GetProviderByName(portableFsProvider)
if !filepath.IsAbs(portableDir) {
@@ -169,7 +169,6 @@ Please take a look at the usage below to customize the serving parameters`,
os.Exit(1)
}
}
service.SetGraceTime(graceTime)
service := service.Service{
ConfigDir: filepath.Clean(defaultConfigDir),
ConfigFile: defaultConfigFile,
@@ -178,7 +177,7 @@ Please take a look at the usage below to customize the serving parameters`,
LogMaxBackups: defaultLogMaxBackup,
LogMaxAge: defaultLogMaxAge,
LogCompress: defaultLogCompress,
LogLevel: portableLogLevel,
LogVerbose: portableLogVerbose,
LogUTCTime: portableLogUTCTime,
Shutdown: make(chan bool),
PortableMode: 1,
@@ -258,10 +257,8 @@ Please take a look at the usage below to customize the serving parameters`,
},
},
}
err := service.StartPortableMode(portableSFTPDPort, portableFTPDPort, portableWebDAVPort, portableSSHCommands,
portableAdvertiseService, portableAdvertiseCredentials, portableFTPSCert, portableFTPSKey, portableWebDAVCert,
portableWebDAVKey)
if err == nil {
if err := service.StartPortableMode(portableSFTPDPort, portableFTPDPort, portableWebDAVPort, portableSSHCommands, portableAdvertiseService,
portableAdvertiseCredentials, portableFTPSCert, portableFTPSKey, portableWebDAVCert, portableWebDAVKey); err == nil {
service.Wait()
if service.Error == nil {
os.Exit(0)
@@ -298,11 +295,7 @@ value`)
portableCmd.Flags().StringVarP(&portablePassword, "password", "p", "", `Leave empty to use an auto generated
value`)
portableCmd.Flags().StringVarP(&portableLogFile, logFilePathFlag, "l", "", "Leave empty to disable logging")
portableCmd.Flags().StringVar(&portableLogLevel, logLevelFlag, defaultLogLevel, `Set the log level.
Supported values:
debug, info, warn, error.
`)
portableCmd.Flags().BoolVarP(&portableLogVerbose, logVerboseFlag, "v", false, "Enable verbose logs")
portableCmd.Flags().BoolVar(&portableLogUTCTime, logUTCTimeFlag, false, "Use UTC time for logging")
portableCmd.Flags().StringSliceVarP(&portablePublicKeys, "public-key", "k", []string{}, "")
portableCmd.Flags().StringSliceVarP(&portablePermissions, "permissions", "g", []string{"list", "download"},
@@ -406,13 +399,6 @@ multiple concurrent requests and this
allows data to be transferred at a
faster rate, over high latency networks,
by overlapping round-trip times`)
portableCmd.Flags().IntVar(&graceTime, graceTimeFlag, 0,
`This grace time defines the number of
seconds allowed for existing transfers
to get completed before shutting down.
A graceful shutdown is triggered by an
interrupt signal.
`)
rootCmd.AddCommand(portableCmd)
}

View File

@@ -17,7 +17,7 @@
package cmd
import "github.com/drakkan/sftpgo/v2/internal/version"
import "github.com/drakkan/sftpgo/v2/version"
func init() {
version.AddFeature("-portable")

View File

@@ -20,14 +20,14 @@ import (
"github.com/spf13/cobra"
"github.com/drakkan/sftpgo/v2/internal/service"
"github.com/drakkan/sftpgo/v2/service"
)
var (
reloadCmd = &cobra.Command{
Use: "reload",
Short: "Reload the SFTPGo Windows Service sending a \"paramchange\" request",
Run: func(_ *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, args []string) {
s := service.WindowsService{
Service: service.Service{
Shutdown: make(chan bool),

View File

@@ -23,10 +23,10 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/drakkan/sftpgo/v2/internal/config"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/config"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
var (
@@ -39,7 +39,7 @@ configuration file and resets the provider by deleting all data and schemas.
This command is not supported for the memory provider.
Please take a look at the usage below to customize the options.`,
Run: func(_ *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, args []string) {
logger.DisableLogger()
logger.EnableConsoleLogger(zerolog.DebugLevel)
configDir = util.CleanDirInput(configDir)

View File

@@ -21,10 +21,10 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/drakkan/sftpgo/v2/internal/config"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/config"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
var (
@@ -37,11 +37,11 @@ configuration file and restore the provider schema and/or data to a previous ver
This command is not supported for the memory provider.
Please take a look at the usage below to customize the options.`,
Run: func(_ *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, args []string) {
logger.DisableLogger()
logger.EnableConsoleLogger(zerolog.DebugLevel)
if revertProviderTargetVersion != 19 {
logger.WarnToConsole("Unsupported target version, 19 is the only supported one")
if revertProviderTargetVersion != 15 {
logger.WarnToConsole("Unsupported target version, 15 is the only supported one")
os.Exit(1)
}
configDir = util.CleanDirInput(configDir)
@@ -71,7 +71,7 @@ Please take a look at the usage below to customize the options.`,
func init() {
addConfigFlags(revertProviderCmd)
revertProviderCmd.Flags().IntVar(&revertProviderTargetVersion, "to-version", 19, `19 means the version supported in v2.3.x`)
revertProviderCmd.Flags().IntVar(&revertProviderTargetVersion, "to-version", 15, `15 means the version supported in v2.2.x`)
rootCmd.AddCommand(revertProviderCmd)
}

View File

@@ -22,7 +22,7 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/drakkan/sftpgo/v2/internal/version"
"github.com/drakkan/sftpgo/v2/version"
)
const (
@@ -40,8 +40,8 @@ const (
logMaxAgeKey = "log_max_age"
logCompressFlag = "log-compress"
logCompressKey = "log_compress"
logLevelFlag = "log-level"
logLevelKey = "log_level"
logVerboseFlag = "log-verbose"
logVerboseKey = "log_verbose"
logUTCTimeFlag = "log-utc-time"
logUTCTimeKey = "log_utc_time"
loadDataFromFlag = "loaddata-from"
@@ -52,8 +52,6 @@ const (
loadDataQuotaScanKey = "loaddata_scan"
loadDataCleanFlag = "loaddata-clean"
loadDataCleanKey = "loaddata_clean"
graceTimeFlag = "grace-time"
graceTimeKey = "grace_time"
defaultConfigDir = "."
defaultConfigFile = ""
defaultLogFile = "sftpgo.log"
@@ -61,13 +59,12 @@ const (
defaultLogMaxBackup = 5
defaultLogMaxAge = 28
defaultLogCompress = false
defaultLogLevel = "debug"
defaultLogVerbose = true
defaultLogUTCTime = false
defaultLoadDataFrom = ""
defaultLoadDataMode = 1
defaultLoadDataQuotaScan = 0
defaultLoadDataClean = false
defaultGraceTime = 0
)
var (
@@ -78,13 +75,12 @@ var (
logMaxBackups int
logMaxAge int
logCompress bool
logLevel string
logVerbose bool
logUTCTime bool
loadDataFrom string
loadDataMode int
loadDataQuotaScan int
loadDataClean bool
graceTime int
// used if awscontainer build tag is enabled
disableAWSInstallationCode bool
@@ -233,17 +229,13 @@ It is unused if log-file-path is empty.
`)
viper.BindPFlag(logCompressKey, cmd.Flags().Lookup(logCompressFlag)) //nolint:errcheck
viper.SetDefault(logLevelKey, defaultLogLevel)
viper.BindEnv(logLevelKey, "SFTPGO_LOG_LEVEL") //nolint:errcheck
cmd.Flags().StringVar(&logLevel, logLevelFlag, viper.GetString(logLevelKey),
`Set the log level. Supported values:
debug, info, warn, error.
This flag can be set
using SFTPGO_LOG_LEVEL env var too.
viper.SetDefault(logVerboseKey, defaultLogVerbose)
viper.BindEnv(logVerboseKey, "SFTPGO_LOG_VERBOSE") //nolint:errcheck
cmd.Flags().BoolVarP(&logVerbose, logVerboseFlag, "v", viper.GetBool(logVerboseKey),
`Enable verbose logs. This flag can be set
using SFTPGO_LOG_VERBOSE env var too.
`)
viper.BindPFlag(logLevelKey, cmd.Flags().Lookup(logLevelFlag)) //nolint:errcheck
viper.BindPFlag(logVerboseKey, cmd.Flags().Lookup(logVerboseFlag)) //nolint:errcheck
viper.SetDefault(logUTCTimeKey, defaultLogUTCTime)
viper.BindEnv(logUTCTimeKey, "SFTPGO_LOG_UTC_TIME") //nolint:errcheck
@@ -266,20 +258,4 @@ This flag can be set using SFTPGO_LOADDATA_QUOTA_SCAN
env var too.
(default 0)`)
viper.BindPFlag(loadDataQuotaScanKey, cmd.Flags().Lookup(loadDataQuotaScanFlag)) //nolint:errcheck
viper.SetDefault(graceTimeKey, defaultGraceTime)
viper.BindEnv(graceTimeKey, "SFTPGO_GRACE_TIME") //nolint:errcheck
cmd.Flags().IntVar(&graceTime, graceTimeFlag, viper.GetInt(graceTimeKey),
`Graceful shutdown is an option to initiate a
shutdown without abrupt cancellation of the
currently ongoing client-initiated transfer
sessions.
This grace time defines the number of seconds
allowed for existing transfers to get
completed before shutting down.
A graceful shutdown is triggered by an
interrupt signal.
This flag can be set using SFTPGO_GRACE_TIME env
var too. 0 means disabled. (default 0)`)
viper.BindPFlag(graceTimeKey, cmd.Flags().Lookup(graceTimeFlag)) //nolint:errcheck
}

View File

@@ -20,14 +20,14 @@ import (
"github.com/spf13/cobra"
"github.com/drakkan/sftpgo/v2/internal/service"
"github.com/drakkan/sftpgo/v2/service"
)
var (
rotateLogCmd = &cobra.Command{
Use: "rotatelogs",
Short: "Signal to the running service to rotate the logs",
Run: func(_ *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, args []string) {
s := service.WindowsService{
Service: service.Service{
Shutdown: make(chan bool),

View File

@@ -19,8 +19,8 @@ import (
"github.com/spf13/cobra"
"github.com/drakkan/sftpgo/v2/internal/service"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/service"
"github.com/drakkan/sftpgo/v2/util"
)
var (
@@ -33,8 +33,7 @@ use:
$ sftpgo serve
Please take a look at the usage below to customize the startup options`,
Run: func(_ *cobra.Command, _ []string) {
service.SetGraceTime(graceTime)
Run: func(cmd *cobra.Command, args []string) {
service := service.Service{
ConfigDir: util.CleanDirInput(configDir),
ConfigFile: configFile,
@@ -43,7 +42,7 @@ Please take a look at the usage below to customize the startup options`,
LogMaxBackups: logMaxBackups,
LogMaxAge: logMaxAge,
LogCompress: logCompress,
LogLevel: logLevel,
LogVerbose: logVerbose,
LogUTCTime: logUTCTime,
LoadDataFrom: loadDataFrom,
LoadDataMode: loadDataMode,

View File

@@ -20,10 +20,10 @@ import (
"github.com/rs/zerolog"
"github.com/spf13/cobra"
"github.com/drakkan/sftpgo/v2/internal/config"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/config"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/smtp"
"github.com/drakkan/sftpgo/v2/util"
)
var (
@@ -33,7 +33,7 @@ var (
Short: "Test the SMTP configuration",
Long: `SFTPGo will try to send a test email to the specified recipient.
If the SMTP configuration is correct you should receive this email.`,
Run: func(_ *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, args []string) {
logger.DisableLogger()
logger.EnableConsoleLogger(zerolog.DebugLevel)
configDir = util.CleanDirInput(configDir)
@@ -48,7 +48,7 @@ If the SMTP configuration is correct you should receive this email.`,
logger.ErrorToConsole("unable to initialize SMTP configuration: %v", err)
os.Exit(1)
}
err = smtp.SendEmail([]string{smtpTestRecipient}, "SFTPGo - Testing Email Settings", "It appears your SFTPGo email is setup correctly!",
err = smtp.SendEmail(smtpTestRecipient, "SFTPGo - Testing Email Settings", "It appears your SFTPGo email is setup correctly!",
smtp.EmailContentTypeTextPlain)
if err != nil {
logger.WarnToConsole("Error sending email: %v", err)

View File

@@ -21,20 +21,19 @@ import (
"github.com/spf13/cobra"
"github.com/drakkan/sftpgo/v2/internal/service"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/service"
"github.com/drakkan/sftpgo/v2/util"
)
var (
startCmd = &cobra.Command{
Use: "start",
Short: "Start the SFTPGo Windows Service",
Run: func(_ *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, args []string) {
configDir = util.CleanDirInput(configDir)
if !filepath.IsAbs(logFilePath) && util.IsFileInputValid(logFilePath) {
logFilePath = filepath.Join(configDir, logFilePath)
}
service.SetGraceTime(graceTime)
s := service.Service{
ConfigDir: configDir,
ConfigFile: configFile,
@@ -43,7 +42,7 @@ var (
LogMaxBackups: logMaxBackups,
LogMaxAge: logMaxAge,
LogCompress: logCompress,
LogLevel: logLevel,
LogVerbose: logVerbose,
LogUTCTime: logUTCTime,
Shutdown: make(chan bool),
}

View File

@@ -25,13 +25,13 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/config"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/sftpd"
"github.com/drakkan/sftpgo/v2/internal/version"
"github.com/drakkan/sftpgo/v2/common"
"github.com/drakkan/sftpgo/v2/config"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/sftpd"
"github.com/drakkan/sftpgo/v2/version"
)
var (
@@ -51,25 +51,18 @@ Subsystem sftp sftpgo startsubsys
Command-line flags should be specified in the Subsystem declaration.
`,
Run: func(_ *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, args []string) {
logSender := "startsubsys"
connectionID := xid.New().String()
var zeroLogLevel zerolog.Level
switch logLevel {
case "info":
zeroLogLevel = zerolog.InfoLevel
case "warn":
zeroLogLevel = zerolog.WarnLevel
case "error":
zeroLogLevel = zerolog.ErrorLevel
default:
zeroLogLevel = zerolog.DebugLevel
logLevel := zerolog.DebugLevel
if !logVerbose {
logLevel = zerolog.InfoLevel
}
logger.SetLogTime(logUTCTime)
if logJournalD {
logger.InitJournalDLogger(zeroLogLevel)
logger.InitJournalDLogger(logLevel)
} else {
logger.InitStdErrLogger(zeroLogLevel)
logger.InitStdErrLogger(logLevel)
}
osUser, err := user.Current()
if err != nil {
@@ -105,7 +98,7 @@ Command-line flags should be specified in the Subsystem declaration.
logger.Error(logSender, "", "unable to initialize MFA: %v", err)
os.Exit(1)
}
if err := plugin.Initialize(config.GetPluginsConfig(), logLevel); err != nil {
if err := plugin.Initialize(config.GetPluginsConfig(), logVerbose); err != nil {
logger.Error(logSender, connectionID, "unable to initialize plugin system: %v", err)
os.Exit(1)
}
@@ -203,17 +196,13 @@ error`)
addConfigFlags(subsystemCmd)
viper.SetDefault(logLevelKey, defaultLogLevel)
viper.BindEnv(logLevelKey, "SFTPGO_LOG_LEVEL") //nolint:errcheck
subsystemCmd.Flags().StringVar(&logLevel, logLevelFlag, viper.GetString(logLevelKey),
`Set the log level. Supported values:
debug, info, warn, error.
This flag can be set
using SFTPGO_LOG_LEVEL env var too.
viper.SetDefault(logVerboseKey, defaultLogVerbose)
viper.BindEnv(logVerboseKey, "SFTPGO_LOG_VERBOSE") //nolint:errcheck
subsystemCmd.Flags().BoolVarP(&logVerbose, logVerboseFlag, "v", viper.GetBool(logVerboseKey),
`Enable verbose logs. This flag can be set
using SFTPGO_LOG_VERBOSE env var too.
`)
viper.BindPFlag(logLevelKey, subsystemCmd.Flags().Lookup(logLevelFlag)) //nolint:errcheck
viper.BindPFlag(logVerboseKey, subsystemCmd.Flags().Lookup(logVerboseFlag)) //nolint:errcheck
viper.SetDefault(logUTCTimeKey, defaultLogUTCTime)
viper.BindEnv(logUTCTimeKey, "SFTPGO_LOG_UTC_TIME") //nolint:errcheck

View File

@@ -20,14 +20,14 @@ import (
"github.com/spf13/cobra"
"github.com/drakkan/sftpgo/v2/internal/service"
"github.com/drakkan/sftpgo/v2/service"
)
var (
statusCmd = &cobra.Command{
Use: "status",
Short: "Retrieve the status for the SFTPGo Windows Service",
Run: func(_ *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, args []string) {
s := service.WindowsService{
Service: service.Service{
Shutdown: make(chan bool),

View File

@@ -20,14 +20,14 @@ import (
"github.com/spf13/cobra"
"github.com/drakkan/sftpgo/v2/internal/service"
"github.com/drakkan/sftpgo/v2/service"
)
var (
stopCmd = &cobra.Command{
Use: "stop",
Short: "Stop the SFTPGo Windows Service",
Run: func(_ *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, args []string) {
s := service.WindowsService{
Service: service.Service{
Shutdown: make(chan bool),

View File

@@ -20,14 +20,14 @@ import (
"github.com/spf13/cobra"
"github.com/drakkan/sftpgo/v2/internal/service"
"github.com/drakkan/sftpgo/v2/service"
)
var (
uninstallCmd = &cobra.Command{
Use: "uninstall",
Short: "Uninstall the SFTPGo Windows Service",
Run: func(_ *cobra.Command, _ []string) {
Run: func(cmd *cobra.Command, args []string) {
s := service.WindowsService{
Service: service.Service{
Shutdown: make(chan bool),

View File

@@ -12,15 +12,13 @@
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
// Package command provides command configuration for SFTPGo hooks
package command
import (
"fmt"
"os"
"strings"
"time"
"github.com/drakkan/sftpgo/v2/internal/util"
)
const (
@@ -29,25 +27,8 @@ const (
defaultTimeout = 30
)
// Supported hook names
const (
HookFsActions = "fs_actions"
HookProviderActions = "provider_actions"
HookStartup = "startup"
HookPostConnect = "post_connect"
HookPostDisconnect = "post_disconnect"
HookDataRetention = "data_retention"
HookCheckPassword = "check_password"
HookPreLogin = "pre_login"
HookPostLogin = "post_login"
HookExternalAuth = "external_auth"
HookKeyboardInteractive = "keyboard_interactive"
)
var (
config Config
supportedHooks = []string{HookFsActions, HookProviderActions, HookStartup, HookPostConnect, HookPostDisconnect,
HookDataRetention, HookCheckPassword, HookPreLogin, HookPostLogin, HookExternalAuth, HookKeyboardInteractive}
config Config
)
// Command define the configuration for a specific commands
@@ -59,14 +40,10 @@ type Command struct {
// Do not use variables with the SFTPGO_ prefix to avoid conflicts with env
// vars that SFTPGo sets
Timeout int `json:"timeout" mapstructure:"timeout"`
// Env defines environment variable for the command.
// Env defines additional environment variable for the commands.
// Each entry is of the form "key=value".
// These values are added to the global environment variables if any
Env []string `json:"env" mapstructure:"env"`
// Args defines arguments to pass to the specified command
Args []string `json:"args" mapstructure:"args"`
// if not empty both command path and hook name must match
Hook string `json:"hook" mapstructure:"hook"`
}
// Config defines the configuration for external commands such as
@@ -74,7 +51,7 @@ type Command struct {
type Config struct {
// Timeout specifies a global time limit, in seconds, for the external commands execution
Timeout int `json:"timeout" mapstructure:"timeout"`
// Env defines environment variable for the commands.
// Env defines additional environment variable for the commands.
// Each entry is of the form "key=value".
// Do not use variables with the SFTPGO_ prefix to avoid conflicts with env
// vars that SFTPGo sets
@@ -95,7 +72,7 @@ func (c Config) Initialize() error {
return fmt.Errorf("invalid timeout %v", c.Timeout)
}
for _, env := range c.Env {
if len(strings.SplitN(env, "=", 2)) != 2 {
if len(strings.Split(env, "=")) != 2 {
return fmt.Errorf("invalid env var %#v", env)
}
}
@@ -111,37 +88,27 @@ func (c Config) Initialize() error {
}
}
for _, env := range cmd.Env {
if len(strings.SplitN(env, "=", 2)) != 2 {
if len(strings.Split(env, "=")) != 2 {
return fmt.Errorf("invalid env var %#v for command %#v", env, cmd.Path)
}
}
// don't validate args, we allow to pass empty arguments
if cmd.Hook != "" {
if !util.Contains(supportedHooks, cmd.Hook) {
return fmt.Errorf("invalid hook name %q, supported values: %+v", cmd.Hook, supportedHooks)
}
}
}
config = c
return nil
}
// GetConfig returns the configuration for the specified command
func GetConfig(command, hook string) (time.Duration, []string, []string) {
env := []string{}
var args []string
func GetConfig(command string) (time.Duration, []string) {
env := os.Environ()
timeout := time.Duration(config.Timeout) * time.Second
env = append(env, config.Env...)
for _, cmd := range config.Commands {
if cmd.Path == command {
if cmd.Hook == "" || cmd.Hook == hook {
timeout = time.Duration(cmd.Timeout) * time.Second
env = append(env, cmd.Env...)
args = cmd.Args
break
}
timeout = time.Duration(cmd.Timeout) * time.Second
env = append(env, cmd.Env...)
break
}
}
return timeout, env, args
return timeout, env
}

View File

@@ -33,17 +33,15 @@ func TestCommandConfig(t *testing.T) {
assert.Equal(t, cfg.Timeout, config.Timeout)
assert.Equal(t, cfg.Env, config.Env)
assert.Len(t, cfg.Commands, 0)
timeout, env, args := GetConfig("cmd", "")
timeout, env := GetConfig("cmd")
assert.Equal(t, time.Duration(config.Timeout)*time.Second, timeout)
assert.Contains(t, env, "a=b")
assert.Len(t, args, 0)
cfg.Commands = []Command{
{
Path: "cmd1",
Timeout: 30,
Env: []string{"c=d"},
Args: []string{"1", "", "2"},
},
{
Path: "cmd2",
@@ -59,68 +57,20 @@ func TestCommandConfig(t *testing.T) {
assert.Equal(t, cfg.Commands[0].Path, config.Commands[0].Path)
assert.Equal(t, cfg.Commands[0].Timeout, config.Commands[0].Timeout)
assert.Equal(t, cfg.Commands[0].Env, config.Commands[0].Env)
assert.Equal(t, cfg.Commands[0].Args, config.Commands[0].Args)
assert.Equal(t, cfg.Commands[1].Path, config.Commands[1].Path)
assert.Equal(t, cfg.Timeout, config.Commands[1].Timeout)
assert.Equal(t, cfg.Commands[1].Env, config.Commands[1].Env)
assert.Equal(t, cfg.Commands[1].Args, config.Commands[1].Args)
}
timeout, env, args = GetConfig("cmd1", "")
timeout, env = GetConfig("cmd1")
assert.Equal(t, time.Duration(config.Commands[0].Timeout)*time.Second, timeout)
assert.Contains(t, env, "a=b")
assert.Contains(t, env, "c=d")
assert.NotContains(t, env, "e=f")
if assert.Len(t, args, 3) {
assert.Equal(t, "1", args[0])
assert.Empty(t, args[1])
assert.Equal(t, "2", args[2])
}
timeout, env, args = GetConfig("cmd2", "")
timeout, env = GetConfig("cmd2")
assert.Equal(t, time.Duration(config.Timeout)*time.Second, timeout)
assert.Contains(t, env, "a=b")
assert.NotContains(t, env, "c=d")
assert.Contains(t, env, "e=f")
assert.Len(t, args, 0)
cfg.Commands = []Command{
{
Path: "cmd1",
Timeout: 30,
Env: []string{"c=d"},
Args: []string{"1", "", "2"},
Hook: HookCheckPassword,
},
{
Path: "cmd1",
Timeout: 0,
Env: []string{"e=f"},
Hook: HookExternalAuth,
},
}
err = cfg.Initialize()
require.NoError(t, err)
timeout, env, args = GetConfig("cmd1", "")
assert.Equal(t, time.Duration(config.Timeout)*time.Second, timeout)
assert.Contains(t, env, "a=b")
assert.NotContains(t, env, "c=d")
assert.NotContains(t, env, "e=f")
assert.Len(t, args, 0)
timeout, env, args = GetConfig("cmd1", HookCheckPassword)
assert.Equal(t, time.Duration(config.Commands[0].Timeout)*time.Second, timeout)
assert.Contains(t, env, "a=b")
assert.Contains(t, env, "c=d")
assert.NotContains(t, env, "e=f")
if assert.Len(t, args, 3) {
assert.Equal(t, "1", args[0])
assert.Empty(t, args[1])
assert.Equal(t, "2", args[2])
}
timeout, env, args = GetConfig("cmd1", HookExternalAuth)
assert.Equal(t, time.Duration(cfg.Timeout)*time.Second, timeout)
assert.Contains(t, env, "a=b")
assert.NotContains(t, env, "c=d")
assert.Contains(t, env, "e=f")
assert.Len(t, args, 0)
}
func TestConfigErrors(t *testing.T) {
@@ -166,16 +116,4 @@ func TestConfigErrors(t *testing.T) {
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "invalid env var")
}
c.Commands = []Command{
{
Path: "path",
Timeout: 30,
Env: []string{"a=b"},
Hook: "invali",
},
}
err = c.Initialize()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "invalid hook name")
}
}

View File

@@ -23,41 +23,27 @@ import (
"net/http"
"net/url"
"os/exec"
"path"
"path/filepath"
"strings"
"sync/atomic"
"time"
"github.com/sftpgo/sdk"
"github.com/sftpgo/sdk/plugin/notifier"
"github.com/drakkan/sftpgo/v2/internal/command"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/httpclient"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/command"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/httpclient"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/util"
)
var (
errUnconfiguredAction = errors.New("no hook is configured for this action")
errNoHook = errors.New("unable to execute action, no hook defined")
errUnexpectedHTTResponse = errors.New("unexpected HTTP hook response code")
hooksConcurrencyGuard = make(chan struct{}, 150)
activeHooks atomic.Int32
errUnexpectedHTTResponse = errors.New("unexpected HTTP response code")
)
func startNewHook() {
activeHooks.Add(1)
hooksConcurrencyGuard <- struct{}{}
}
func hookEnded() {
activeHooks.Add(-1)
<-hooksConcurrencyGuard
}
// ProtocolActions defines the action to execute on file operations and SSH commands
type ProtocolActions struct {
// Valid values are download, upload, pre-delete, delete, rename, ssh_cmd. Empty slice to disable
@@ -112,56 +98,26 @@ func ExecutePreAction(conn *BaseConnection, operation, filePath, virtualPath str
// ExecuteActionNotification executes the defined hook, if any, for the specified action
func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtualPath, target, virtualTarget, sshCmd string,
fileSize int64, err error,
) error {
) {
hasNotifiersPlugin := plugin.Handler.HasNotifiers()
hasHook := util.Contains(Config.Actions.ExecuteOn, operation)
hasRules := eventManager.hasFsRules()
if !hasHook && !hasNotifiersPlugin && !hasRules {
return nil
if !hasHook && !hasNotifiersPlugin {
return
}
notification := newActionNotification(&conn.User, operation, filePath, virtualPath, target, virtualTarget, sshCmd,
conn.protocol, conn.GetRemoteIP(), conn.ID, fileSize, 0, err)
if hasNotifiersPlugin {
plugin.Handler.NotifyFsEvent(notification)
}
var errRes error
if hasRules {
params := EventParams{
Name: notification.Username,
Groups: conn.User.Groups,
Event: notification.Action,
Status: notification.Status,
VirtualPath: notification.VirtualPath,
FsPath: notification.Path,
VirtualTargetPath: notification.VirtualTargetPath,
FsTargetPath: notification.TargetPath,
ObjectName: path.Base(notification.VirtualPath),
FileSize: notification.FileSize,
Protocol: notification.Protocol,
IP: notification.IP,
Timestamp: notification.Timestamp,
Object: nil,
}
if err != nil {
params.AddError(fmt.Errorf("%q failed: %w", params.Event, err))
}
errRes = eventManager.handleFsEvent(params)
}
if hasHook {
if util.Contains(Config.Actions.ExecuteSync, operation) {
if errHook := actionHandler.Handle(notification); errHook != nil {
errRes = errHook
}
} else {
go func() {
startNewHook()
defer hookEnded()
actionHandler.Handle(notification) //nolint:errcheck
}()
actionHandler.Handle(notification) //nolint:errcheck
return
}
go actionHandler.Handle(notification) //nolint:errcheck
}
return errRes
}
// ActionHandler handles a notification for a Protocol Action.
@@ -177,6 +133,7 @@ func newActionNotification(
err error,
) *notifier.FsEvent {
var bucket, endpoint string
status := 1
fsConfig := user.GetFsConfigForPath(virtualPath)
@@ -193,8 +150,12 @@ func newActionNotification(
}
case sdk.SFTPFilesystemProvider:
endpoint = fsConfig.SFTPConfig.Endpoint
case sdk.HTTPFilesystemProvider:
endpoint = fsConfig.HTTPConfig.Endpoint
}
if err == ErrQuotaExceeded {
status = 3
} else if err != nil {
status = 2
}
return &notifier.FsEvent{
@@ -209,7 +170,7 @@ func newActionNotification(
FsProvider: int(fsConfig.Provider),
Bucket: bucket,
Endpoint: endpoint,
Status: getNotificationStatus(err),
Status: status,
Protocol: protocol,
IP: ip,
SessionID: sessionID,
@@ -262,7 +223,7 @@ func (h *defaultActionHandler) handleHTTP(event *notifier.FsEvent) error {
}
}
logger.Debug(event.Protocol, "", "notified operation %q to URL: %s status code: %d, elapsed: %s err: %v",
logger.Debug(event.Protocol, "", "notified operation %#v to URL: %v status code: %v, elapsed: %v err: %v",
event.Action, u.Redacted(), respCode, time.Since(startTime), err)
return err
@@ -276,11 +237,11 @@ func (h *defaultActionHandler) handleCommand(event *notifier.FsEvent) error {
return err
}
timeout, env, args := command.GetConfig(Config.Actions.Hook, command.HookFsActions)
timeout, env := command.GetConfig(Config.Actions.Hook)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
cmd := exec.CommandContext(ctx, Config.Actions.Hook, args...)
cmd := exec.CommandContext(ctx, Config.Actions.Hook)
cmd.Env = append(env, notificationAsEnvVars(event)...)
startTime := time.Now()
@@ -294,32 +255,22 @@ func (h *defaultActionHandler) handleCommand(event *notifier.FsEvent) error {
func notificationAsEnvVars(event *notifier.FsEvent) []string {
return []string{
fmt.Sprintf("SFTPGO_ACTION=%s", event.Action),
fmt.Sprintf("SFTPGO_ACTION_USERNAME=%s", event.Username),
fmt.Sprintf("SFTPGO_ACTION_PATH=%s", event.Path),
fmt.Sprintf("SFTPGO_ACTION_TARGET=%s", event.TargetPath),
fmt.Sprintf("SFTPGO_ACTION_VIRTUAL_PATH=%s", event.VirtualPath),
fmt.Sprintf("SFTPGO_ACTION_VIRTUAL_TARGET=%s", event.VirtualTargetPath),
fmt.Sprintf("SFTPGO_ACTION_SSH_CMD=%s", event.SSHCmd),
fmt.Sprintf("SFTPGO_ACTION_FILE_SIZE=%d", event.FileSize),
fmt.Sprintf("SFTPGO_ACTION_FS_PROVIDER=%d", event.FsProvider),
fmt.Sprintf("SFTPGO_ACTION_BUCKET=%s", event.Bucket),
fmt.Sprintf("SFTPGO_ACTION_ENDPOINT=%s", event.Endpoint),
fmt.Sprintf("SFTPGO_ACTION_STATUS=%d", event.Status),
fmt.Sprintf("SFTPGO_ACTION_PROTOCOL=%s", event.Protocol),
fmt.Sprintf("SFTPGO_ACTION_IP=%s", event.IP),
fmt.Sprintf("SFTPGO_ACTION_SESSION_ID=%s", event.SessionID),
fmt.Sprintf("SFTPGO_ACTION_OPEN_FLAGS=%d", event.OpenFlags),
fmt.Sprintf("SFTPGO_ACTION_TIMESTAMP=%d", event.Timestamp),
fmt.Sprintf("SFTPGO_ACTION=%v", event.Action),
fmt.Sprintf("SFTPGO_ACTION_USERNAME=%v", event.Username),
fmt.Sprintf("SFTPGO_ACTION_PATH=%v", event.Path),
fmt.Sprintf("SFTPGO_ACTION_TARGET=%v", event.TargetPath),
fmt.Sprintf("SFTPGO_ACTION_VIRTUAL_PATH=%v", event.VirtualPath),
fmt.Sprintf("SFTPGO_ACTION_VIRTUAL_TARGET=%v", event.VirtualTargetPath),
fmt.Sprintf("SFTPGO_ACTION_SSH_CMD=%v", event.SSHCmd),
fmt.Sprintf("SFTPGO_ACTION_FILE_SIZE=%v", event.FileSize),
fmt.Sprintf("SFTPGO_ACTION_FS_PROVIDER=%v", event.FsProvider),
fmt.Sprintf("SFTPGO_ACTION_BUCKET=%v", event.Bucket),
fmt.Sprintf("SFTPGO_ACTION_ENDPOINT=%v", event.Endpoint),
fmt.Sprintf("SFTPGO_ACTION_STATUS=%v", event.Status),
fmt.Sprintf("SFTPGO_ACTION_PROTOCOL=%v", event.Protocol),
fmt.Sprintf("SFTPGO_ACTION_IP=%v", event.IP),
fmt.Sprintf("SFTPGO_ACTION_SESSION_ID=%v", event.SessionID),
fmt.Sprintf("SFTPGO_ACTION_OPEN_FLAGS=%v", event.OpenFlags),
fmt.Sprintf("SFTPGO_ACTION_TIMESTAMP=%v", event.Timestamp),
}
}
func getNotificationStatus(err error) int {
status := 1
if err == ErrQuotaExceeded {
status = 3
} else if err != nil {
status = 2
}
return status
}

View File

@@ -29,9 +29,9 @@ import (
"github.com/sftpgo/sdk/plugin/notifier"
"github.com/stretchr/testify/assert"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/vfs"
)
func TestNewActionNotification(t *testing.T) {
@@ -63,11 +63,6 @@ func TestNewActionNotification(t *testing.T) {
Endpoint: "sftpendpoint",
},
}
user.FsConfig.HTTPConfig = vfs.HTTPFsConfig{
BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
Endpoint: "httpendpoint",
},
}
sessionID := xid.New().String()
a := newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSFTP, "", sessionID,
123, 0, errors.New("fake error"))
@@ -90,12 +85,6 @@ func TestNewActionNotification(t *testing.T) {
assert.Equal(t, 0, len(a.Endpoint))
assert.Equal(t, 3, a.Status)
user.FsConfig.Provider = sdk.HTTPFilesystemProvider
a = newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSSH, "", sessionID,
123, 0, nil)
assert.Equal(t, "httpendpoint", a.Endpoint)
assert.Equal(t, 1, a.Status)
user.FsConfig.Provider = sdk.AzureBlobFilesystemProvider
a = newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSCP, "", sessionID,
123, 0, nil)
@@ -171,11 +160,9 @@ func TestActionCMD(t *testing.T) {
assert.NoError(t, err)
c := NewBaseConnection("id", ProtocolSFTP, "", "", *user)
err = ExecuteActionNotification(c, OperationSSHCmd, "path", "vpath", "target", "vtarget", "sha1sum", 0, nil)
assert.NoError(t, err)
ExecuteActionNotification(c, OperationSSHCmd, "path", "vpath", "target", "vtarget", "sha1sum", 0, nil)
err = ExecuteActionNotification(c, operationDownload, "path", "vpath", "", "", "", 0, nil)
assert.NoError(t, err)
ExecuteActionNotification(c, operationDownload, "path", "vpath", "", "", "", 0, nil)
Config.Actions = actionsCopy
}
@@ -278,7 +265,7 @@ func TestUnconfiguredHook(t *testing.T) {
Type: "notifier",
},
}
err := plugin.Initialize(pluginsConfig, "debug")
err := plugin.Initialize(pluginsConfig, true)
assert.Error(t, err)
assert.True(t, plugin.Handler.HasNotifiers())
@@ -288,10 +275,9 @@ func TestUnconfiguredHook(t *testing.T) {
err = ExecutePreAction(c, operationPreDelete, "", "", 0, 0)
assert.ErrorIs(t, err, errUnconfiguredAction)
err = ExecuteActionNotification(c, operationDownload, "", "", "", "", "", 0, nil)
assert.NoError(t, err)
ExecuteActionNotification(c, operationDownload, "", "", "", "", "", 0, nil)
err = plugin.Initialize(nil, "debug")
err = plugin.Initialize(nil, true)
assert.NoError(t, err)
assert.False(t, plugin.Handler.HasNotifiers())

View File

@@ -18,18 +18,18 @@ import (
"sync"
"sync/atomic"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/logger"
)
// clienstMap is a struct containing the map of the connected clients
type clientsMap struct {
totalConnections atomic.Int32
totalConnections int32
mu sync.RWMutex
clients map[string]int
}
func (c *clientsMap) add(source string) {
c.totalConnections.Add(1)
atomic.AddInt32(&c.totalConnections, 1)
c.mu.Lock()
defer c.mu.Unlock()
@@ -42,7 +42,7 @@ func (c *clientsMap) remove(source string) {
defer c.mu.Unlock()
if val, ok := c.clients[source]; ok {
c.totalConnections.Add(-1)
atomic.AddInt32(&c.totalConnections, -1)
c.clients[source]--
if val > 1 {
return
@@ -54,7 +54,7 @@ func (c *clientsMap) remove(source string) {
}
func (c *clientsMap) getTotal() int32 {
return c.totalConnections.Load()
return atomic.LoadInt32(&c.totalConnections)
}
func (c *clientsMap) getTotalFrom(source string) int {

View File

@@ -33,35 +33,33 @@ import (
"github.com/pires/go-proxyproto"
"github.com/drakkan/sftpgo/v2/internal/command"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/httpclient"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/metric"
"github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/command"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/httpclient"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/metric"
"github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/vfs"
)
// constants
const (
logSender = "common"
uploadLogSender = "Upload"
downloadLogSender = "Download"
renameLogSender = "Rename"
rmdirLogSender = "Rmdir"
mkdirLogSender = "Mkdir"
symlinkLogSender = "Symlink"
removeLogSender = "Remove"
chownLogSender = "Chown"
chmodLogSender = "Chmod"
chtimesLogSender = "Chtimes"
truncateLogSender = "Truncate"
operationDownload = "download"
operationUpload = "upload"
operationFirstDownload = "first-download"
operationFirstUpload = "first-upload"
operationDelete = "delete"
logSender = "common"
uploadLogSender = "Upload"
downloadLogSender = "Download"
renameLogSender = "Rename"
rmdirLogSender = "Rmdir"
mkdirLogSender = "Mkdir"
symlinkLogSender = "Symlink"
removeLogSender = "Remove"
chownLogSender = "Chown"
chmodLogSender = "Chmod"
chtimesLogSender = "Chtimes"
truncateLogSender = "Truncate"
operationDownload = "download"
operationUpload = "upload"
operationDelete = "delete"
// Pre-download action name
OperationPreDownload = "pre-download"
// Pre-upload action name
@@ -102,7 +100,6 @@ const (
ProtocolHTTPShare = "HTTPShare"
ProtocolDataRetention = "DataRetention"
ProtocolOIDC = "OIDC"
protocolEventAction = "EventAction"
)
// Upload modes
@@ -117,27 +114,25 @@ func init() {
clients: make(map[string]int),
}
Connections.perUserConns = make(map[string]int)
Connections.mapping = make(map[string]int)
Connections.sshMapping = make(map[string]int)
}
// errors definitions
var (
ErrPermissionDenied = errors.New("permission denied")
ErrNotExist = errors.New("no such file or directory")
ErrOpUnsupported = errors.New("operation unsupported")
ErrGenericFailure = errors.New("failure")
ErrQuotaExceeded = errors.New("denying write due to space limit")
ErrReadQuotaExceeded = errors.New("denying read due to quota limit")
ErrConnectionDenied = errors.New("you are not allowed to connect")
ErrNoBinding = errors.New("no binding configured")
ErrCrtRevoked = errors.New("your certificate has been revoked")
ErrNoCredentials = errors.New("no credential provided")
ErrInternalFailure = errors.New("internal failure")
ErrTransferAborted = errors.New("transfer aborted")
ErrShuttingDown = errors.New("the service is shutting down")
errNoTransfer = errors.New("requested transfer not found")
errTransferMismatch = errors.New("transfer mismatch")
ErrPermissionDenied = errors.New("permission denied")
ErrNotExist = errors.New("no such file or directory")
ErrOpUnsupported = errors.New("operation unsupported")
ErrGenericFailure = errors.New("failure")
ErrQuotaExceeded = errors.New("denying write due to space limit")
ErrReadQuotaExceeded = errors.New("denying read due to quota limit")
ErrSkipPermissionsCheck = errors.New("permission check skipped")
ErrConnectionDenied = errors.New("you are not allowed to connect")
ErrNoBinding = errors.New("no binding configured")
ErrCrtRevoked = errors.New("your certificate has been revoked")
ErrNoCredentials = errors.New("no credential provided")
ErrInternalFailure = errors.New("internal failure")
ErrTransferAborted = errors.New("transfer aborted")
errNoTransfer = errors.New("requested transfer not found")
errTransferMismatch = errors.New("transfer mismatch")
)
var (
@@ -146,28 +141,26 @@ var (
// Connections is the list of active connections
Connections ActiveConnections
// QuotaScans is the list of active quota scans
QuotaScans ActiveScans
// ActiveMetadataChecks holds the active metadata checks
ActiveMetadataChecks MetadataChecks
transfersChecker TransfersChecker
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
QuotaScans ActiveScans
transfersChecker TransfersChecker
periodicTimeoutTicker *time.Ticker
periodicTimeoutTickerDone chan bool
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
ProtocolHTTP, ProtocolHTTPShare, ProtocolOIDC}
disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
// the map key is the protocol, for each protocol we can have multiple rate limiters
rateLimiters map[string][]*rateLimiter
isShuttingDown atomic.Bool
rateLimiters map[string][]*rateLimiter
)
// Initialize sets the common configuration
func Initialize(c Configuration, isShared int) error {
isShuttingDown.Store(false)
Config = c
Config.Actions.ExecuteOn = util.RemoveDuplicates(Config.Actions.ExecuteOn, true)
Config.Actions.ExecuteSync = util.RemoveDuplicates(Config.Actions.ExecuteSync, true)
Config.ProxyAllowed = util.RemoveDuplicates(Config.ProxyAllowed, true)
Config.idleLoginTimeout = 2 * time.Minute
Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute
startPeriodicChecks(periodicTimeoutCheckInterval)
startPeriodicTimeoutTicker(periodicTimeoutCheckInterval)
Config.defender = nil
Config.whitelist = nil
rateLimiters = make(map[string][]*rateLimiter)
@@ -217,73 +210,10 @@ func Initialize(c Configuration, isShared int) error {
}
vfs.SetTempPath(c.TempPath)
dataprovider.SetTempPath(c.TempPath)
vfs.SetAllowSelfConnections(c.AllowSelfConnections)
dataprovider.SetAllowSelfConnections(c.AllowSelfConnections)
transfersChecker = getTransfersChecker(isShared)
return nil
}
// CheckClosing returns an error if the service is closing
func CheckClosing() error {
if isShuttingDown.Load() {
return ErrShuttingDown
}
return nil
}
// WaitForTransfers waits, for the specified grace time, for currently ongoing
// client-initiated transfer sessions to completes.
// A zero graceTime means no wait
func WaitForTransfers(graceTime int) {
if graceTime == 0 {
return
}
if isShuttingDown.Swap(true) {
return
}
if activeHooks.Load() == 0 && getActiveConnections() == 0 {
return
}
graceTimer := time.NewTimer(time.Duration(graceTime) * time.Second)
ticker := time.NewTicker(3 * time.Second)
for {
select {
case <-ticker.C:
hooks := activeHooks.Load()
logger.Info(logSender, "", "active hooks: %d", hooks)
if hooks == 0 && getActiveConnections() == 0 {
logger.Info(logSender, "", "no more active connections, graceful shutdown")
ticker.Stop()
graceTimer.Stop()
return
}
case <-graceTimer.C:
logger.Info(logSender, "", "grace time expired, hard shutdown")
ticker.Stop()
return
}
}
}
// getActiveConnections returns the number of connections with active transfers
func getActiveConnections() int {
var activeConns int
Connections.RLock()
for _, c := range Connections.connections {
if len(c.GetTransfers()) > 0 {
activeConns++
}
}
Connections.RUnlock()
logger.Info(logSender, "", "number of connections with active transfers: %d", activeConns)
return activeConns
}
// LimitRate blocks until all the configured rate limiters
// allow one event to happen.
// It returns an error if the time to wait exceeds the max
@@ -381,18 +311,35 @@ func AddDefenderEvent(ip string, event HostEvent) {
Config.defender.AddEvent(ip, event)
}
func startPeriodicChecks(duration time.Duration) {
startEventScheduler()
spec := fmt.Sprintf("@every %s", duration)
_, err := eventScheduler.AddFunc(spec, Connections.checkTransfers)
util.PanicOnError(err)
logger.Info(logSender, "", "scheduled overquota transfers check, schedule %q", spec)
if Config.IdleTimeout > 0 {
// the ticker cannot be started/stopped from multiple goroutines
func startPeriodicTimeoutTicker(duration time.Duration) {
stopPeriodicTimeoutTicker()
periodicTimeoutTicker = time.NewTicker(duration)
periodicTimeoutTickerDone = make(chan bool)
go func() {
counter := int64(0)
ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval
spec = fmt.Sprintf("@every %s", duration*ratio)
_, err = eventScheduler.AddFunc(spec, Connections.checkIdles)
util.PanicOnError(err)
logger.Info(logSender, "", "scheduled idle connections check, schedule %q", spec)
for {
select {
case <-periodicTimeoutTickerDone:
return
case <-periodicTimeoutTicker.C:
counter++
if Config.IdleTimeout > 0 && counter >= int64(ratio) {
counter = 0
Connections.checkIdles()
}
go Connections.checkTransfers()
}
}
}()
}
func stopPeriodicTimeoutTicker() {
if periodicTimeoutTicker != nil {
periodicTimeoutTicker.Stop()
periodicTimeoutTickerDone <- true
periodicTimeoutTicker = nil
}
}
@@ -464,11 +411,11 @@ func (t *ConnectionTransfer) getConnectionTransferAsString() string {
case operationDownload:
result += "DL "
}
result += fmt.Sprintf("%q ", t.VirtualPath)
result += fmt.Sprintf("%#v ", t.VirtualPath)
if t.Size > 0 {
elapsed := time.Since(util.GetTimeFromMsecSinceEpoch(t.StartTime))
speed := float64(t.Size) / float64(util.GetTimeAsMsSinceEpoch(time.Now())-t.StartTime)
result += fmt.Sprintf("Size: %s Elapsed: %s Speed: \"%.1f KB/s\"", util.ByteCountIEC(t.Size),
result += fmt.Sprintf("Size: %#v Elapsed: %#v Speed: \"%.1f KB/s\"", util.ByteCountIEC(t.Size),
util.GetDurationAsString(elapsed), speed)
}
return result
@@ -572,9 +519,6 @@ type Configuration struct {
// Only the listed IPs/networks can access the configured services, all other client connections
// will be dropped before they even try to authenticate.
WhiteListFile string `json:"whitelist_file" mapstructure:"whitelist_file"`
// Allow users on this instance to use other users/virtual folders on this instance as storage backend.
// Enable this setting if you know what you are doing.
AllowSelfConnections int `json:"allow_self_connections" mapstructure:"allow_self_connections"`
// Defender configuration
DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"`
// Rate limiter configurations
@@ -651,11 +595,11 @@ func (c *Configuration) ExecuteStartupHook() error {
return err
}
startTime := time.Now()
timeout, env, args := command.GetConfig(c.StartupHook, command.HookStartup)
timeout, env := command.GetConfig(c.StartupHook)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
cmd := exec.CommandContext(ctx, c.StartupHook, args...)
cmd := exec.CommandContext(ctx, c.StartupHook)
cmd.Env = env
err := cmd.Run()
logger.Debug(logSender, "", "Startup hook executed, elapsed: %v, error: %v", time.Since(startTime), err)
@@ -663,9 +607,6 @@ func (c *Configuration) ExecuteStartupHook() error {
}
func (c *Configuration) executePostDisconnectHook(remoteAddr, protocol, username, connID string, connectionTime time.Time) {
startNewHook()
defer hookEnded()
ipAddr := util.GetIPFromRemoteAddress(remoteAddr)
connDuration := int64(time.Since(connectionTime) / time.Millisecond)
@@ -697,12 +638,12 @@ func (c *Configuration) executePostDisconnectHook(remoteAddr, protocol, username
logger.Debug(protocol, connID, "invalid post disconnect hook %#v", c.PostDisconnectHook)
return
}
timeout, env, args := command.GetConfig(c.PostDisconnectHook, command.HookPostDisconnect)
timeout, env := command.GetConfig(c.PostDisconnectHook)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
startTime := time.Now()
cmd := exec.CommandContext(ctx, c.PostDisconnectHook, args...)
cmd := exec.CommandContext(ctx, c.PostDisconnectHook)
cmd.Env = append(env,
fmt.Sprintf("SFTPGO_CONNECTION_IP=%v", ipAddr),
fmt.Sprintf("SFTPGO_CONNECTION_USERNAME=%v", username),
@@ -757,11 +698,11 @@ func (c *Configuration) ExecutePostConnectHook(ipAddr, protocol string) error {
logger.Warn(protocol, "", "Login from ip %#v denied: %v", ipAddr, err)
return err
}
timeout, env, args := command.GetConfig(c.PostConnectHook, command.HookPostConnect)
timeout, env := command.GetConfig(c.PostConnectHook)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
cmd := exec.CommandContext(ctx, c.PostConnectHook, args...)
cmd := exec.CommandContext(ctx, c.PostConnectHook)
cmd.Env = append(env,
fmt.Sprintf("SFTPGO_CONNECTION_IP=%v", ipAddr),
fmt.Sprintf("SFTPGO_CONNECTION_PROTOCOL=%v", protocol))
@@ -777,17 +718,16 @@ func (c *Configuration) ExecutePostConnectHook(ipAddr, protocol string) error {
type SSHConnection struct {
id string
conn net.Conn
lastActivity atomic.Int64
lastActivity int64
}
// NewSSHConnection returns a new SSHConnection
func NewSSHConnection(id string, conn net.Conn) *SSHConnection {
c := &SSHConnection{
id: id,
conn: conn,
return &SSHConnection{
id: id,
conn: conn,
lastActivity: time.Now().UnixNano(),
}
c.lastActivity.Store(time.Now().UnixNano())
return c
}
// GetID returns the ID for this SSHConnection
@@ -797,12 +737,12 @@ func (c *SSHConnection) GetID() string {
// UpdateLastActivity updates last activity for this connection
func (c *SSHConnection) UpdateLastActivity() {
c.lastActivity.Store(time.Now().UnixNano())
atomic.StoreInt64(&c.lastActivity, time.Now().UnixNano())
}
// GetLastActivity returns the last connection activity
func (c *SSHConnection) GetLastActivity() time.Time {
return time.Unix(0, c.lastActivity.Load())
return time.Unix(0, atomic.LoadInt64(&c.lastActivity))
}
// Close closes the underlying network connection
@@ -815,12 +755,10 @@ type ActiveConnections struct {
// clients contains both authenticated and estabilished connections and the ones waiting
// for authentication
clients clientsMap
transfersCheckStatus atomic.Bool
transfersCheckStatus int32
sync.RWMutex
connections []ActiveConnection
mapping map[string]int
sshConnections []*SSHConnection
sshMapping map[string]int
perUserConns map[string]int
}
@@ -868,10 +806,9 @@ func (conns *ActiveConnections) Add(c ActiveConnection) error {
}
conns.addUserConnection(username)
}
conns.mapping[c.GetID()] = len(conns.connections)
conns.connections = append(conns.connections, c)
metric.UpdateActiveConnectionsSize(len(conns.connections))
logger.Debug(c.GetProtocol(), c.GetID(), "connection added, local address %q, remote address %q, num open connections: %d",
logger.Debug(c.GetProtocol(), c.GetID(), "connection added, local address %#v, remote address %#v, num open connections: %v",
c.GetLocalAddress(), c.GetRemoteAddress(), len(conns.connections))
return nil
}
@@ -884,25 +821,25 @@ func (conns *ActiveConnections) Swap(c ActiveConnection) error {
conns.Lock()
defer conns.Unlock()
if idx, ok := conns.mapping[c.GetID()]; ok {
conn := conns.connections[idx]
conns.removeUserConnection(conn.GetUsername())
if username := c.GetUsername(); username != "" {
if maxSessions := c.GetMaxSessions(); maxSessions > 0 {
if val, ok := conns.perUserConns[username]; ok && val >= maxSessions {
conns.addUserConnection(conn.GetUsername())
return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions)
for idx, conn := range conns.connections {
if conn.GetID() == c.GetID() {
conns.removeUserConnection(conn.GetUsername())
if username := c.GetUsername(); username != "" {
if maxSessions := c.GetMaxSessions(); maxSessions > 0 {
if val := conns.perUserConns[username]; val >= maxSessions {
conns.addUserConnection(conn.GetUsername())
return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions)
}
}
conns.addUserConnection(username)
}
conns.addUserConnection(username)
err := conn.CloseFS()
conns.connections[idx] = c
logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err)
conn = nil
return nil
}
err := conn.CloseFS()
conns.connections[idx] = c
logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err)
conn = nil
return nil
}
return errors.New("connection to swap not found")
}
@@ -911,53 +848,40 @@ func (conns *ActiveConnections) Remove(connectionID string) {
conns.Lock()
defer conns.Unlock()
if idx, ok := conns.mapping[connectionID]; ok {
conn := conns.connections[idx]
err := conn.CloseFS()
lastIdx := len(conns.connections) - 1
conns.connections[idx] = conns.connections[lastIdx]
conns.connections[lastIdx] = nil
conns.connections = conns.connections[:lastIdx]
delete(conns.mapping, connectionID)
if idx != lastIdx {
conns.mapping[conns.connections[idx].GetID()] = idx
for idx, conn := range conns.connections {
if conn.GetID() == connectionID {
err := conn.CloseFS()
lastIdx := len(conns.connections) - 1
conns.connections[idx] = conns.connections[lastIdx]
conns.connections[lastIdx] = nil
conns.connections = conns.connections[:lastIdx]
conns.removeUserConnection(conn.GetUsername())
metric.UpdateActiveConnectionsSize(lastIdx)
logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %#v, remote address %#v close fs error: %v, num open connections: %v",
conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx)
Config.checkPostDisconnectHook(conn.GetRemoteAddress(), conn.GetProtocol(), conn.GetUsername(),
conn.GetID(), conn.GetConnectionTime())
return
}
conns.removeUserConnection(conn.GetUsername())
metric.UpdateActiveConnectionsSize(lastIdx)
logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %#v, remote address %#v close fs error: %v, num open connections: %v",
conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx)
if conn.GetProtocol() == ProtocolFTP && conn.GetUsername() == "" {
ip := util.GetIPFromRemoteAddress(conn.GetRemoteAddress())
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, conn.GetProtocol(),
dataprovider.ErrNoAuthTryed.Error())
metric.AddNoAuthTryed()
AddDefenderEvent(ip, HostEventNoLoginTried)
dataprovider.ExecutePostLoginHook(&dataprovider.User{}, dataprovider.LoginMethodNoAuthTryed, ip,
conn.GetProtocol(), dataprovider.ErrNoAuthTryed)
}
Config.checkPostDisconnectHook(conn.GetRemoteAddress(), conn.GetProtocol(), conn.GetUsername(),
conn.GetID(), conn.GetConnectionTime())
return
}
logger.Warn(logSender, "", "connection id %q to remove not found!", connectionID)
logger.Warn(logSender, "", "connection id %#v to remove not found!", connectionID)
}
// Close closes an active connection.
// It returns true on success
func (conns *ActiveConnections) Close(connectionID string) bool {
conns.RLock()
result := false
var result bool
if idx, ok := conns.mapping[connectionID]; ok {
c := conns.connections[idx]
defer func(conn ActiveConnection) {
err := conn.Disconnect()
logger.Debug(conn.GetProtocol(), conn.GetID(), "close connection requested, close err: %v", err)
}(c)
result = true
for _, c := range conns.connections {
if c.GetID() == connectionID {
defer func(conn ActiveConnection) {
err := conn.Disconnect()
logger.Debug(conn.GetProtocol(), conn.GetID(), "close connection requested, close err: %v", err)
}(c)
result = true
break
}
}
conns.RUnlock()
@@ -969,9 +893,8 @@ func (conns *ActiveConnections) AddSSHConnection(c *SSHConnection) {
conns.Lock()
defer conns.Unlock()
conns.sshMapping[c.GetID()] = len(conns.sshConnections)
conns.sshConnections = append(conns.sshConnections, c)
logger.Debug(logSender, c.GetID(), "ssh connection added, num open connections: %d", len(conns.sshConnections))
logger.Debug(logSender, c.GetID(), "ssh connection added, num open connections: %v", len(conns.sshConnections))
}
// RemoveSSHConnection removes a connection from the active ones
@@ -979,19 +902,17 @@ func (conns *ActiveConnections) RemoveSSHConnection(connectionID string) {
conns.Lock()
defer conns.Unlock()
if idx, ok := conns.sshMapping[connectionID]; ok {
lastIdx := len(conns.sshConnections) - 1
conns.sshConnections[idx] = conns.sshConnections[lastIdx]
conns.sshConnections[lastIdx] = nil
conns.sshConnections = conns.sshConnections[:lastIdx]
delete(conns.sshMapping, connectionID)
if idx != lastIdx {
conns.sshMapping[conns.sshConnections[idx].GetID()] = idx
for idx, conn := range conns.sshConnections {
if conn.GetID() == connectionID {
lastIdx := len(conns.sshConnections) - 1
conns.sshConnections[idx] = conns.sshConnections[lastIdx]
conns.sshConnections[lastIdx] = nil
conns.sshConnections = conns.sshConnections[:lastIdx]
logger.Debug(logSender, conn.GetID(), "ssh connection removed, num open ssh connections: %v", lastIdx)
return
}
logger.Debug(logSender, connectionID, "ssh connection removed, num open ssh connections: %d", lastIdx)
return
}
logger.Warn(logSender, "", "ssh connection to remove with id %q not found!", connectionID)
logger.Warn(logSender, "", "ssh connection to remove with id %#v not found!", connectionID)
}
func (conns *ActiveConnections) checkIdles() {
@@ -1026,11 +947,19 @@ func (conns *ActiveConnections) checkIdles() {
isUnauthenticatedFTPUser := (c.GetProtocol() == ProtocolFTP && c.GetUsername() == "")
if idleTime > Config.idleTimeoutAsDuration || (isUnauthenticatedFTPUser && idleTime > Config.idleLoginTimeout) {
defer func(conn ActiveConnection) {
defer func(conn ActiveConnection, isFTPNoAuth bool) {
err := conn.Disconnect()
logger.Debug(conn.GetProtocol(), conn.GetID(), "close idle connection, idle time: %v, username: %#v close err: %v",
time.Since(conn.GetLastActivity()), conn.GetUsername(), err)
}(c)
if isFTPNoAuth {
ip := util.GetIPFromRemoteAddress(c.GetRemoteAddress())
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, c.GetProtocol(), "client idle")
metric.AddNoAuthTryed()
AddDefenderEvent(ip, HostEventNoLoginTried)
dataprovider.ExecutePostLoginHook(&dataprovider.User{}, dataprovider.LoginMethodNoAuthTryed, ip, c.GetProtocol(),
dataprovider.ErrNoAuthTryed)
}
}(c, isUnauthenticatedFTPUser)
}
}
@@ -1038,12 +967,12 @@ func (conns *ActiveConnections) checkIdles() {
}
func (conns *ActiveConnections) checkTransfers() {
if conns.transfersCheckStatus.Load() {
if atomic.LoadInt32(&conns.transfersCheckStatus) == 1 {
logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution")
return
}
conns.transfersCheckStatus.Store(true)
defer conns.transfersCheckStatus.Store(false)
atomic.StoreInt32(&conns.transfersCheckStatus, 1)
defer atomic.StoreInt32(&conns.transfersCheckStatus, 0)
conns.RLock()
@@ -1115,34 +1044,30 @@ func (conns *ActiveConnections) GetClientConnections() int32 {
return conns.clients.getTotal()
}
// IsNewConnectionAllowed returns an error if the maximum number of concurrent allowed
// connections is exceeded or a whitelist is defined and the specified ipAddr is not listed
// or the service is shutting down
func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) error {
if isShuttingDown.Load() {
return ErrShuttingDown
}
// IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
// or a whitelist is defined and the specified ipAddr is not listed
func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool {
if Config.whitelist != nil {
if !Config.whitelist.isAllowed(ipAddr) {
return ErrConnectionDenied
return false
}
}
if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 {
return nil
return true
}
if Config.MaxPerHostConnections > 0 {
if total := conns.clients.getTotalFrom(ipAddr); total > Config.MaxPerHostConnections {
logger.Info(logSender, "", "active connections from %s %d/%d", ipAddr, total, Config.MaxPerHostConnections)
logger.Debug(logSender, "", "active connections from %v %v/%v", ipAddr, total, Config.MaxPerHostConnections)
AddDefenderEvent(ipAddr, HostEventLimitExceeded)
return ErrConnectionDenied
return false
}
}
if Config.MaxTotalConnections > 0 {
if total := conns.clients.getTotal(); total > int32(Config.MaxTotalConnections) {
logger.Info(logSender, "", "active client connections %d/%d", total, Config.MaxTotalConnections)
return ErrConnectionDenied
logger.Debug(logSender, "", "active client connections %v/%v", total, Config.MaxTotalConnections)
return false
}
// on a single SFTP connection we could have multiple SFTP channels or commands
@@ -1151,13 +1076,10 @@ func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) error {
conns.RLock()
defer conns.RUnlock()
if sess := len(conns.connections); sess >= Config.MaxTotalConnections {
logger.Info(logSender, "", "active client sessions %d/%d", sess, Config.MaxTotalConnections)
return ErrConnectionDenied
}
return len(conns.connections) < Config.MaxTotalConnections
}
return nil
return true
}
// GetStats returns stats for active connections
@@ -1166,7 +1088,6 @@ func (conns *ActiveConnections) GetStats() []ConnectionStatus {
defer conns.RUnlock()
stats := make([]ConnectionStatus, 0, len(conns.connections))
node := dataprovider.GetNodeName()
for _, c := range conns.connections {
stat := ConnectionStatus{
Username: c.GetUsername(),
@@ -1178,7 +1099,6 @@ func (conns *ActiveConnections) GetStats() []ConnectionStatus {
Protocol: c.GetProtocol(),
Command: c.GetCommand(),
Transfers: c.GetTransfers(),
Node: node,
}
stats = append(stats, stat)
}
@@ -1205,8 +1125,6 @@ type ConnectionStatus struct {
Transfers []ConnectionTransfer `json:"active_transfers,omitempty"`
// SSH command or WebDAV method
Command string `json:"command,omitempty"`
// Node identifier, omitted for single node installations
Node string `json:"node,omitempty"`
}
// GetConnectionDuration returns the connection duration as string
@@ -1248,7 +1166,7 @@ func (c *ConnectionStatus) GetTransfersAsString() string {
return result
}
// ActiveQuotaScan defines an active quota scan for a user
// ActiveQuotaScan defines an active quota scan for a user home dir
type ActiveQuotaScan struct {
// Username to which the quota scan refers
Username string `json:"username"`
@@ -1271,7 +1189,7 @@ type ActiveScans struct {
FolderScans []ActiveVirtualFolderQuotaScan
}
// GetUsersQuotaScans returns the active users quota scans
// GetUsersQuotaScans returns the active quota scans for users home directories
func (s *ActiveScans) GetUsersQuotaScans() []ActiveQuotaScan {
s.RLock()
defer s.RUnlock()
@@ -1361,65 +1279,3 @@ func (s *ActiveScans) RemoveVFolderQuotaScan(folderName string) bool {
return false
}
// MetadataCheck defines an active metadata check
type MetadataCheck struct {
// Username to which the metadata check refers
Username string `json:"username"`
// check start time as unix timestamp in milliseconds
StartTime int64 `json:"start_time"`
}
// MetadataChecks holds the active metadata checks
type MetadataChecks struct {
sync.RWMutex
checks []MetadataCheck
}
// Get returns the active metadata checks
func (c *MetadataChecks) Get() []MetadataCheck {
c.RLock()
defer c.RUnlock()
checks := make([]MetadataCheck, len(c.checks))
copy(checks, c.checks)
return checks
}
// Add adds a user to the ones with active metadata checks.
// Return false if a metadata check is already active for the specified user
func (c *MetadataChecks) Add(username string) bool {
c.Lock()
defer c.Unlock()
for idx := range c.checks {
if c.checks[idx].Username == username {
return false
}
}
c.checks = append(c.checks, MetadataCheck{
Username: username,
StartTime: util.GetTimeAsMsSinceEpoch(time.Now()),
})
return true
}
// Remove removes a user from the ones with active metadata checks
func (c *MetadataChecks) Remove(username string) bool {
c.Lock()
defer c.Unlock()
for idx := range c.checks {
if c.checks[idx].Username == username {
lastIdx := len(c.checks) - 1
c.checks[idx] = c.checks[lastIdx]
c.checks = c.checks[:lastIdx]
return true
}
}
return false
}

View File

@@ -24,7 +24,7 @@ import (
"path/filepath"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
@@ -34,24 +34,21 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/kms"
"github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/vfs"
)
const (
logSenderTest = "common_test"
httpAddr = "127.0.0.1:9999"
configDir = ".."
osWindows = "windows"
userTestUsername = "common_test_username"
)
var (
configDir = filepath.Join(".", "..", "..")
)
type fakeConnection struct {
*BaseConnection
command string
@@ -99,122 +96,6 @@ func (c *customNetConn) Close() error {
return c.Conn.Close()
}
func TestConnections(t *testing.T) {
c1 := &fakeConnection{
BaseConnection: NewBaseConnection("id1", ProtocolSFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: userTestUsername,
},
}),
}
c2 := &fakeConnection{
BaseConnection: NewBaseConnection("id2", ProtocolSFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: userTestUsername,
},
}),
}
c3 := &fakeConnection{
BaseConnection: NewBaseConnection("id3", ProtocolSFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: userTestUsername,
},
}),
}
c4 := &fakeConnection{
BaseConnection: NewBaseConnection("id4", ProtocolSFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: userTestUsername,
},
}),
}
assert.Equal(t, "SFTP_id1", c1.GetID())
assert.Equal(t, "SFTP_id2", c2.GetID())
assert.Equal(t, "SFTP_id3", c3.GetID())
assert.Equal(t, "SFTP_id4", c4.GetID())
err := Connections.Add(c1)
assert.NoError(t, err)
err = Connections.Add(c2)
assert.NoError(t, err)
err = Connections.Add(c3)
assert.NoError(t, err)
err = Connections.Add(c4)
assert.NoError(t, err)
Connections.RLock()
assert.Len(t, Connections.connections, 4)
assert.Len(t, Connections.mapping, 4)
_, ok := Connections.mapping[c1.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.mapping[c1.GetID()])
assert.Equal(t, 1, Connections.mapping[c2.GetID()])
assert.Equal(t, 2, Connections.mapping[c3.GetID()])
assert.Equal(t, 3, Connections.mapping[c4.GetID()])
Connections.RUnlock()
c2 = &fakeConnection{
BaseConnection: NewBaseConnection("id2", ProtocolSFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: userTestUsername + "_mod",
},
}),
}
err = Connections.Swap(c2)
assert.NoError(t, err)
Connections.RLock()
assert.Len(t, Connections.connections, 4)
assert.Len(t, Connections.mapping, 4)
_, ok = Connections.mapping[c1.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.mapping[c1.GetID()])
assert.Equal(t, 1, Connections.mapping[c2.GetID()])
assert.Equal(t, 2, Connections.mapping[c3.GetID()])
assert.Equal(t, 3, Connections.mapping[c4.GetID()])
assert.Equal(t, userTestUsername+"_mod", Connections.connections[1].GetUsername())
Connections.RUnlock()
Connections.Remove(c2.GetID())
Connections.RLock()
assert.Len(t, Connections.connections, 3)
assert.Len(t, Connections.mapping, 3)
_, ok = Connections.mapping[c1.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.mapping[c1.GetID()])
assert.Equal(t, 1, Connections.mapping[c4.GetID()])
assert.Equal(t, 2, Connections.mapping[c3.GetID()])
Connections.RUnlock()
Connections.Remove(c3.GetID())
Connections.RLock()
assert.Len(t, Connections.connections, 2)
assert.Len(t, Connections.mapping, 2)
_, ok = Connections.mapping[c1.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.mapping[c1.GetID()])
assert.Equal(t, 1, Connections.mapping[c4.GetID()])
Connections.RUnlock()
Connections.Remove(c1.GetID())
Connections.RLock()
assert.Len(t, Connections.connections, 1)
assert.Len(t, Connections.mapping, 1)
_, ok = Connections.mapping[c4.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.mapping[c4.GetID()])
Connections.RUnlock()
Connections.Remove(c4.GetID())
Connections.RLock()
assert.Len(t, Connections.connections, 0)
assert.Len(t, Connections.mapping, 0)
Connections.RUnlock()
}
func TestSSHConnections(t *testing.T) {
conn1, conn2 := net.Pipe()
now := time.Now()
@@ -231,44 +112,27 @@ func TestSSHConnections(t *testing.T) {
Connections.AddSSHConnection(sshConn3)
Connections.RLock()
assert.Len(t, Connections.sshConnections, 3)
_, ok := Connections.sshMapping[sshConn1.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.sshMapping[sshConn1.GetID()])
assert.Equal(t, 1, Connections.sshMapping[sshConn2.GetID()])
assert.Equal(t, 2, Connections.sshMapping[sshConn3.GetID()])
Connections.RUnlock()
Connections.RemoveSSHConnection(sshConn1.id)
Connections.RLock()
assert.Len(t, Connections.sshConnections, 2)
assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id)
assert.Equal(t, sshConn2.id, Connections.sshConnections[1].id)
_, ok = Connections.sshMapping[sshConn3.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.sshMapping[sshConn3.GetID()])
assert.Equal(t, 1, Connections.sshMapping[sshConn2.GetID()])
Connections.RUnlock()
Connections.RemoveSSHConnection(sshConn1.id)
Connections.RLock()
assert.Len(t, Connections.sshConnections, 2)
assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id)
assert.Equal(t, sshConn2.id, Connections.sshConnections[1].id)
_, ok = Connections.sshMapping[sshConn3.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.sshMapping[sshConn3.GetID()])
assert.Equal(t, 1, Connections.sshMapping[sshConn2.GetID()])
Connections.RUnlock()
Connections.RemoveSSHConnection(sshConn2.id)
Connections.RLock()
assert.Len(t, Connections.sshConnections, 1)
assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id)
_, ok = Connections.sshMapping[sshConn3.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.sshMapping[sshConn3.GetID()])
Connections.RUnlock()
Connections.RemoveSSHConnection(sshConn3.id)
Connections.RLock()
assert.Len(t, Connections.sshConnections, 0)
assert.Len(t, Connections.sshMapping, 0)
Connections.RUnlock()
assert.NoError(t, sshConn1.Close())
assert.NoError(t, sshConn2.Close())
@@ -284,14 +148,14 @@ func TestDefenderIntegration(t *testing.T) {
pluginsConfig := []plugin.Config{
{
Type: "ipfilter",
Cmd: filepath.Join(wdPath, "..", "..", "tests", "ipfilter", "ipfilter"),
Cmd: filepath.Join(wdPath, "..", "tests", "ipfilter", "ipfilter"),
AutoMTLS: true,
},
}
if runtime.GOOS == osWindows {
pluginsConfig[0].Cmd += ".exe"
}
err = plugin.Initialize(pluginsConfig, "debug")
err = plugin.Initialize(pluginsConfig, true)
require.NoError(t, err)
ip := "127.1.1.1"
@@ -497,10 +361,10 @@ func TestWhitelist(t *testing.T) {
err = Initialize(Config, 0)
assert.NoError(t, err)
assert.NoError(t, Connections.IsNewConnectionAllowed("172.18.1.1"))
assert.Error(t, Connections.IsNewConnectionAllowed("172.18.1.3"))
assert.NoError(t, Connections.IsNewConnectionAllowed("10.8.7.3"))
assert.Error(t, Connections.IsNewConnectionAllowed("10.8.8.2"))
assert.True(t, Connections.IsNewConnectionAllowed("172.18.1.1"))
assert.False(t, Connections.IsNewConnectionAllowed("172.18.1.3"))
assert.True(t, Connections.IsNewConnectionAllowed("10.8.7.3"))
assert.False(t, Connections.IsNewConnectionAllowed("10.8.8.2"))
wl.IPAddresses = append(wl.IPAddresses, "172.18.1.3")
wl.CIDRNetworks = append(wl.CIDRNetworks, "10.8.8.0/24")
@@ -508,14 +372,14 @@ func TestWhitelist(t *testing.T) {
assert.NoError(t, err)
err = os.WriteFile(wlFile, data, 0664)
assert.NoError(t, err)
assert.Error(t, Connections.IsNewConnectionAllowed("10.8.8.3"))
assert.False(t, Connections.IsNewConnectionAllowed("10.8.8.3"))
err = Reload()
assert.NoError(t, err)
assert.NoError(t, Connections.IsNewConnectionAllowed("10.8.8.3"))
assert.NoError(t, Connections.IsNewConnectionAllowed("172.18.1.3"))
assert.NoError(t, Connections.IsNewConnectionAllowed("172.18.1.2"))
assert.Error(t, Connections.IsNewConnectionAllowed("172.18.1.12"))
assert.True(t, Connections.IsNewConnectionAllowed("10.8.8.3"))
assert.True(t, Connections.IsNewConnectionAllowed("172.18.1.3"))
assert.True(t, Connections.IsNewConnectionAllowed("172.18.1.2"))
assert.False(t, Connections.IsNewConnectionAllowed("172.18.1.12"))
Config = configCopy
}
@@ -550,12 +414,12 @@ func TestMaxConnections(t *testing.T) {
Config.MaxPerHostConnections = 0
ipAddr := "192.168.7.8"
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr))
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
Config.MaxTotalConnections = 1
Config.MaxPerHostConnections = perHost
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr))
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{})
fakeConn := &fakeConnection{
BaseConnection: c,
@@ -563,18 +427,18 @@ func TestMaxConnections(t *testing.T) {
err := Connections.Add(fakeConn)
assert.NoError(t, err)
assert.Len(t, Connections.GetStats(), 1)
assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr))
assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
res := Connections.Close(fakeConn.GetID())
assert.True(t, res)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr))
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.AddClientConnection(ipAddr)
Connections.AddClientConnection(ipAddr)
assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr))
assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.RemoveClientConnection(ipAddr)
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr))
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.RemoveClientConnection(ipAddr)
Config.MaxTotalConnections = oldValue
@@ -587,13 +451,13 @@ func TestMaxConnectionPerHost(t *testing.T) {
ipAddr := "192.168.9.9"
Connections.AddClientConnection(ipAddr)
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr))
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.AddClientConnection(ipAddr)
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr))
assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.AddClientConnection(ipAddr)
assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr))
assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
assert.Equal(t, int32(3), Connections.GetClientConnections())
Connections.RemoveClientConnection(ipAddr)
@@ -631,14 +495,14 @@ func TestIdleConnections(t *testing.T) {
},
}
c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, "", "", user)
c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
fakeConn := &fakeConnection{
BaseConnection: c,
}
// both ssh connections are expired but they should get removed only
// if there is no associated connection
sshConn1.lastActivity.Store(c.lastActivity.Load())
sshConn2.lastActivity.Store(c.lastActivity.Load())
sshConn1.lastActivity = c.lastActivity
sshConn2.lastActivity = c.lastActivity
Connections.AddSSHConnection(sshConn1)
err = Connections.Add(fakeConn)
assert.NoError(t, err)
@@ -653,7 +517,7 @@ func TestIdleConnections(t *testing.T) {
assert.Equal(t, Connections.GetActiveSessions(username), 2)
cFTP := NewBaseConnection("id2", ProtocolFTP, "", "", dataprovider.User{})
cFTP.lastActivity.Store(time.Now().UnixNano())
cFTP.lastActivity = time.Now().UnixNano()
fakeConn = &fakeConnection{
BaseConnection: cFTP,
}
@@ -665,27 +529,27 @@ func TestIdleConnections(t *testing.T) {
assert.Len(t, Connections.sshConnections, 2)
Connections.RUnlock()
startPeriodicChecks(100 * time.Millisecond)
assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 2*time.Second, 200*time.Millisecond)
startPeriodicTimeoutTicker(100 * time.Millisecond)
assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 1*time.Second, 200*time.Millisecond)
assert.Eventually(t, func() bool {
Connections.RLock()
defer Connections.RUnlock()
return len(Connections.sshConnections) == 1
}, 1*time.Second, 200*time.Millisecond)
stopEventScheduler()
stopPeriodicTimeoutTicker()
assert.Len(t, Connections.GetStats(), 2)
c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
cFTP.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano())
sshConn2.lastActivity.Store(c.lastActivity.Load())
startPeriodicChecks(100 * time.Millisecond)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 2*time.Second, 200*time.Millisecond)
c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
sshConn2.lastActivity = c.lastActivity
startPeriodicTimeoutTicker(100 * time.Millisecond)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond)
assert.Eventually(t, func() bool {
Connections.RLock()
defer Connections.RUnlock()
return len(Connections.sshConnections) == 0
}, 1*time.Second, 200*time.Millisecond)
assert.Equal(t, int32(0), Connections.GetClientConnections())
stopEventScheduler()
stopPeriodicTimeoutTicker()
assert.True(t, customConn1.isClosed)
assert.True(t, customConn2.isClosed)
@@ -697,7 +561,7 @@ func TestCloseConnection(t *testing.T) {
fakeConn := &fakeConnection{
BaseConnection: c,
}
assert.NoError(t, Connections.IsNewConnectionAllowed("127.0.0.1"))
assert.True(t, Connections.IsNewConnectionAllowed("127.0.0.1"))
err := Connections.Add(fakeConn)
assert.NoError(t, err)
assert.Len(t, Connections.GetStats(), 1)
@@ -779,9 +643,9 @@ func TestConnectionStatus(t *testing.T) {
BaseConnection: c1,
}
t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
t1.BytesReceived.Store(123)
t1.BytesReceived = 123
t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
t2.BytesSent.Store(456)
t2.BytesSent = 456
c2 := NewBaseConnection("id2", ProtocolSSH, "", "", user)
fakeConn2 := &fakeConnection{
BaseConnection: c2,
@@ -831,7 +695,7 @@ func TestConnectionStatus(t *testing.T) {
err = fakeConn3.SignalTransfersAbort()
assert.NoError(t, err)
assert.True(t, t3.AbortTransfer.Load())
assert.Equal(t, int32(1), atomic.LoadInt32(&t3.AbortTransfer))
err = t3.Close()
assert.NoError(t, err)
err = fakeConn3.SignalTransfersAbort()
@@ -1194,189 +1058,6 @@ func TestUserRecentActivity(t *testing.T) {
assert.True(t, res)
}
func TestVfsSameResource(t *testing.T) {
fs := vfs.Filesystem{}
other := vfs.Filesystem{}
res := fs.IsSameResource(other)
assert.True(t, res)
fs = vfs.Filesystem{
Provider: sdk.S3FilesystemProvider,
S3Config: vfs.S3FsConfig{
BaseS3FsConfig: sdk.BaseS3FsConfig{
Bucket: "a",
Region: "b",
},
},
}
other = vfs.Filesystem{
Provider: sdk.S3FilesystemProvider,
S3Config: vfs.S3FsConfig{
BaseS3FsConfig: sdk.BaseS3FsConfig{
Bucket: "a",
Region: "c",
},
},
}
res = fs.IsSameResource(other)
assert.False(t, res)
other = vfs.Filesystem{
Provider: sdk.S3FilesystemProvider,
S3Config: vfs.S3FsConfig{
BaseS3FsConfig: sdk.BaseS3FsConfig{
Bucket: "a",
Region: "b",
},
},
}
res = fs.IsSameResource(other)
assert.True(t, res)
fs = vfs.Filesystem{
Provider: sdk.GCSFilesystemProvider,
GCSConfig: vfs.GCSFsConfig{
BaseGCSFsConfig: sdk.BaseGCSFsConfig{
Bucket: "b",
},
},
}
other = vfs.Filesystem{
Provider: sdk.GCSFilesystemProvider,
GCSConfig: vfs.GCSFsConfig{
BaseGCSFsConfig: sdk.BaseGCSFsConfig{
Bucket: "c",
},
},
}
res = fs.IsSameResource(other)
assert.False(t, res)
other = vfs.Filesystem{
Provider: sdk.GCSFilesystemProvider,
GCSConfig: vfs.GCSFsConfig{
BaseGCSFsConfig: sdk.BaseGCSFsConfig{
Bucket: "b",
},
},
}
res = fs.IsSameResource(other)
assert.True(t, res)
sasURL := kms.NewPlainSecret("http://127.0.0.1/sasurl")
fs = vfs.Filesystem{
Provider: sdk.AzureBlobFilesystemProvider,
AzBlobConfig: vfs.AzBlobFsConfig{
BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{
AccountName: "a",
},
SASURL: sasURL,
},
}
err := fs.Validate("data1")
assert.NoError(t, err)
other = vfs.Filesystem{
Provider: sdk.AzureBlobFilesystemProvider,
AzBlobConfig: vfs.AzBlobFsConfig{
BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{
AccountName: "a",
},
SASURL: sasURL,
},
}
err = other.Validate("data2")
assert.NoError(t, err)
err = fs.AzBlobConfig.SASURL.TryDecrypt()
assert.NoError(t, err)
err = other.AzBlobConfig.SASURL.TryDecrypt()
assert.NoError(t, err)
res = fs.IsSameResource(other)
assert.True(t, res)
fs.AzBlobConfig.AccountName = "b"
res = fs.IsSameResource(other)
assert.False(t, res)
fs.AzBlobConfig.AccountName = "a"
other.AzBlobConfig.SASURL = kms.NewPlainSecret("http://127.1.1.1/sasurl")
err = other.Validate("data2")
assert.NoError(t, err)
err = other.AzBlobConfig.SASURL.TryDecrypt()
assert.NoError(t, err)
res = fs.IsSameResource(other)
assert.False(t, res)
fs = vfs.Filesystem{
Provider: sdk.HTTPFilesystemProvider,
HTTPConfig: vfs.HTTPFsConfig{
BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
Endpoint: "http://127.0.0.1/httpfs",
Username: "a",
},
},
}
other = vfs.Filesystem{
Provider: sdk.HTTPFilesystemProvider,
HTTPConfig: vfs.HTTPFsConfig{
BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
Endpoint: "http://127.0.0.1/httpfs",
Username: "b",
},
},
}
res = fs.IsSameResource(other)
assert.True(t, res)
fs.HTTPConfig.EqualityCheckMode = 1
res = fs.IsSameResource(other)
assert.False(t, res)
}
func TestUpdateTransferTimestamps(t *testing.T) {
username := "user_test_timestamps"
user := &dataprovider.User{
BaseUser: sdk.BaseUser{
Username: username,
HomeDir: filepath.Join(os.TempDir(), username),
Status: 1,
Permissions: map[string][]string{
"/": {dataprovider.PermAny},
},
},
}
err := dataprovider.AddUser(user, "", "")
assert.NoError(t, err)
assert.Equal(t, int64(0), user.FirstUpload)
assert.Equal(t, int64(0), user.FirstDownload)
err = dataprovider.UpdateUserTransferTimestamps(username, true)
assert.NoError(t, err)
userGet, err := dataprovider.UserExists(username)
assert.NoError(t, err)
assert.Greater(t, userGet.FirstUpload, int64(0))
assert.Equal(t, int64(0), user.FirstDownload)
err = dataprovider.UpdateUserTransferTimestamps(username, false)
assert.NoError(t, err)
userGet, err = dataprovider.UserExists(username)
assert.NoError(t, err)
assert.Greater(t, userGet.FirstUpload, int64(0))
assert.Greater(t, userGet.FirstDownload, int64(0))
// updating again must fail
err = dataprovider.UpdateUserTransferTimestamps(username, true)
assert.Error(t, err)
err = dataprovider.UpdateUserTransferTimestamps(username, false)
assert.Error(t, err)
// cleanup
err = dataprovider.DeleteUser(username, "", "")
assert.NoError(t, err)
}
func TestMetadataAPI(t *testing.T) {
username := "metadatauser"
require.False(t, ActiveMetadataChecks.Remove(username))
require.True(t, ActiveMetadataChecks.Add(username))
require.False(t, ActiveMetadataChecks.Add(username))
checks := ActiveMetadataChecks.Get()
require.Len(t, checks, 1)
checks[0].Username = username + "a"
checks = ActiveMetadataChecks.Get()
require.Len(t, checks, 1)
require.Equal(t, username, checks[0].Username)
require.True(t, ActiveMetadataChecks.Remove(username))
require.Len(t, ActiveMetadataChecks.Get(), 0)
}
func BenchmarkBcryptHashing(b *testing.B) {
bcryptPassword := "bcryptpassword"
for i := 0; i < b.N; i++ {
@@ -1416,52 +1097,3 @@ func BenchmarkCompareArgon2Password(b *testing.B) {
}
}
}
func BenchmarkAddRemoveConnections(b *testing.B) {
var conns []ActiveConnection
for i := 0; i < 100; i++ {
conns = append(conns, &fakeConnection{
BaseConnection: NewBaseConnection(fmt.Sprintf("id%d", i), ProtocolSFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: userTestUsername,
},
}),
})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, c := range conns {
if err := Connections.Add(c); err != nil {
panic(err)
}
}
var wg sync.WaitGroup
for idx := len(conns) - 1; idx >= 0; idx-- {
wg.Add(1)
go func(index int) {
defer wg.Done()
Connections.Remove(conns[index].GetID())
}(idx)
}
wg.Wait()
}
}
func BenchmarkAddRemoveSSHConnections(b *testing.B) {
conn1, conn2 := net.Pipe()
var conns []*SSHConnection
for i := 0; i < 2000; i++ {
conns = append(conns, NewSSHConnection(fmt.Sprintf("id%d", i), conn1))
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
for _, c := range conns {
Connections.AddSSHConnection(c)
}
for idx := len(conns) - 1; idx >= 0; idx-- {
Connections.RemoveSSHConnection(conns[idx].GetID())
}
}
conn1.Close()
conn2.Close()
}

View File

@@ -28,22 +28,20 @@ import (
"github.com/pkg/sftp"
"github.com/sftpgo/sdk"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/vfs"
)
// BaseConnection defines common fields for a connection using any supported protocol
type BaseConnection struct {
// last activity for this connection.
// Since this field is accessed atomically we put it as first element of the struct to achieve 64 bit alignment
lastActivity atomic.Int64
uploadDone atomic.Bool
downloadDone atomic.Bool
lastActivity int64
// unique ID for a transfer.
// This field is accessed atomically so we put it at the beginning of the struct to achieve 64 bit alignment
transferID atomic.Int64
transferID int64
// Unique identifier for the connection
ID string
// user associated with this connection if any
@@ -64,18 +62,16 @@ func NewBaseConnection(id, protocol, localAddr, remoteAddr string, user dataprov
connID = fmt.Sprintf("%s_%s", protocol, id)
}
user.UploadBandwidth, user.DownloadBandwidth = user.GetBandwidthForIP(util.GetIPFromRemoteAddress(remoteAddr), connID)
c := &BaseConnection{
ID: connID,
User: user,
startTime: time.Now(),
protocol: protocol,
localAddr: localAddr,
remoteAddr: remoteAddr,
return &BaseConnection{
ID: connID,
User: user,
startTime: time.Now(),
protocol: protocol,
localAddr: localAddr,
remoteAddr: remoteAddr,
lastActivity: time.Now().UnixNano(),
transferID: 0,
}
c.transferID.Store(0)
c.lastActivity.Store(time.Now().UnixNano())
return c
}
// Log outputs a log entry to the configured logger
@@ -85,7 +81,7 @@ func (c *BaseConnection) Log(level logger.LogLevel, format string, v ...any) {
// GetTransferID returns an unique transfer ID for this connection
func (c *BaseConnection) GetTransferID() int64 {
return c.transferID.Add(1)
return atomic.AddInt64(&c.transferID, 1)
}
// GetID returns the connection ID
@@ -128,12 +124,12 @@ func (c *BaseConnection) GetConnectionTime() time.Time {
// UpdateLastActivity updates last activity for this connection
func (c *BaseConnection) UpdateLastActivity() {
c.lastActivity.Store(time.Now().UnixNano())
atomic.StoreInt64(&c.lastActivity, time.Now().UnixNano())
}
// GetLastActivity returns the last connection activity
func (c *BaseConnection) GetLastActivity() time.Time {
return time.Unix(0, c.lastActivity.Load())
return time.Unix(0, atomic.LoadInt64(&c.lastActivity))
}
// CloseFS closes the underlying fs
@@ -257,7 +253,7 @@ func (c *BaseConnection) getRealFsPath(fsPath string) string {
defer c.RUnlock()
for _, t := range c.activeTransfers {
if p := t.GetRealFsPath(fsPath); p != "" {
if p := t.GetRealFsPath(fsPath); len(p) > 0 {
return p
}
}
@@ -362,7 +358,7 @@ func (c *BaseConnection) CreateDir(virtualPath string, checkFilePatterns bool) e
logger.CommandLog(mkdirLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1,
c.localAddr, c.remoteAddr)
ExecuteActionNotification(c, operationMkdir, fsPath, virtualPath, "", "", "", 0, nil) //nolint:errcheck
ExecuteActionNotification(c, operationMkdir, fsPath, virtualPath, "", "", "", 0, nil)
return nil
}
@@ -409,7 +405,7 @@ func (c *BaseConnection) RemoveFile(fs vfs.Fs, fsPath, virtualPath string, info
}
}
if actionErr != nil {
ExecuteActionNotification(c, operationDelete, fsPath, virtualPath, "", "", "", size, nil) //nolint:errcheck
ExecuteActionNotification(c, operationDelete, fsPath, virtualPath, "", "", "", size, nil)
}
return nil
}
@@ -473,7 +469,7 @@ func (c *BaseConnection) RemoveDir(virtualPath string) error {
logger.CommandLog(rmdirLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1,
c.localAddr, c.remoteAddr)
ExecuteActionNotification(c, operationRmdir, fsPath, virtualPath, "", "", "", 0, nil) //nolint:errcheck
ExecuteActionNotification(c, operationRmdir, fsPath, virtualPath, "", "", "", 0, nil)
return nil
}
@@ -500,7 +496,7 @@ func (c *BaseConnection) Rename(virtualSourcePath, virtualTargetPath string) err
initialSize := int64(-1)
if dstInfo, err := fsDst.Lstat(fsTargetPath); err == nil {
if dstInfo.IsDir() {
c.Log(logger.LevelWarn, "attempted to rename %q overwriting an existing directory %q",
c.Log(logger.LevelWarn, "attempted to rename %#v overwriting an existing directory %#v",
fsSourcePath, fsTargetPath)
return c.GetOpUnsupportedError()
}
@@ -520,8 +516,7 @@ func (c *BaseConnection) Rename(virtualSourcePath, virtualTargetPath string) err
virtualSourcePath)
return c.GetOpUnsupportedError()
}
if err = c.checkRecursiveRenameDirPermissions(fsSrc, fsDst, fsSourcePath, fsTargetPath,
virtualSourcePath, virtualTargetPath, srcInfo); err != nil {
if err = c.checkRecursiveRenameDirPermissions(fsSrc, fsDst, fsSourcePath, fsTargetPath); err != nil {
c.Log(logger.LevelDebug, "error checking recursive permissions before renaming %#v: %+v", fsSourcePath, err)
return err
}
@@ -530,7 +525,7 @@ func (c *BaseConnection) Rename(virtualSourcePath, virtualTargetPath string) err
c.Log(logger.LevelInfo, "denying cross rename due to space limit")
return c.GetGenericError(ErrQuotaExceeded)
}
if err := fsDst.Rename(fsSourcePath, fsTargetPath); err != nil {
if err := fsSrc.Rename(fsSourcePath, fsTargetPath); err != nil {
c.Log(logger.LevelError, "failed to rename %#v -> %#v: %+v", fsSourcePath, fsTargetPath, err)
return c.GetFsError(fsSrc, err)
}
@@ -538,21 +533,14 @@ func (c *BaseConnection) Rename(virtualSourcePath, virtualTargetPath string) err
c.updateQuotaAfterRename(fsDst, virtualSourcePath, virtualTargetPath, fsTargetPath, initialSize) //nolint:errcheck
logger.CommandLog(renameLogSender, fsSourcePath, fsTargetPath, c.User.Username, "", c.ID, c.protocol, -1, -1,
"", "", "", -1, c.localAddr, c.remoteAddr)
ExecuteActionNotification(c, operationRename, fsSourcePath, virtualSourcePath, fsTargetPath, //nolint:errcheck
virtualTargetPath, "", 0, nil)
ExecuteActionNotification(c, operationRename, fsSourcePath, virtualSourcePath, fsTargetPath, virtualTargetPath,
"", 0, nil)
return nil
}
// CreateSymlink creates fsTargetPath as a symbolic link to fsSourcePath
func (c *BaseConnection) CreateSymlink(virtualSourcePath, virtualTargetPath string) error {
var relativePath string
if !path.IsAbs(virtualSourcePath) {
relativePath = virtualSourcePath
virtualSourcePath = path.Join(path.Dir(virtualTargetPath), relativePath)
c.Log(logger.LevelDebug, "link relative path %q resolved as %q, target path %q",
relativePath, virtualSourcePath, virtualTargetPath)
}
if c.isCrossFoldersRequest(virtualSourcePath, virtualTargetPath) {
c.Log(logger.LevelWarn, "cross folder symlink is not supported, src: %v dst: %v", virtualSourcePath, virtualTargetPath)
return c.GetOpUnsupportedError()
@@ -586,9 +574,6 @@ func (c *BaseConnection) CreateSymlink(virtualSourcePath, virtualTargetPath stri
c.Log(logger.LevelError, "symlink target path %#v is not allowed", virtualTargetPath)
return c.GetPermissionDeniedError()
}
if relativePath != "" {
fsSourcePath = relativePath
}
if err := fs.Symlink(fsSourcePath, fsTargetPath); err != nil {
c.Log(logger.LevelError, "failed to create symlink %#v -> %#v: %+v", fsSourcePath, fsTargetPath, err)
return c.GetFsError(fs, err)
@@ -608,14 +593,13 @@ func (c *BaseConnection) getPathForSetStatPerms(fs vfs.Fs, fsPath, virtualPath s
return pathForPerms
}
func (c *BaseConnection) doStatInternal(virtualPath string, mode int, checkFilePatterns,
convertResult bool,
) (os.FileInfo, error) {
// DoStat execute a Stat if mode = 0, Lstat if mode = 1
func (c *BaseConnection) DoStat(virtualPath string, mode int, checkFilePatterns bool) (os.FileInfo, error) {
// for some vfs we don't create intermediary folders so we cannot simply check
// if virtualPath is a virtual folder
vfolders := c.User.GetVirtualFoldersInPath(path.Dir(virtualPath))
if _, ok := vfolders[virtualPath]; ok {
return vfs.NewFileInfo(virtualPath, true, 0, time.Unix(0, 0), false), nil
return vfs.NewFileInfo(virtualPath, true, 0, time.Now(), false), nil
}
if checkFilePatterns {
ok, policy := c.User.IsFileAllowed(virtualPath)
@@ -637,20 +621,15 @@ func (c *BaseConnection) doStatInternal(virtualPath string, mode int, checkFileP
info, err = fs.Stat(c.getRealFsPath(fsPath))
}
if err != nil {
c.Log(logger.LevelWarn, "stat error for path %#v: %+v", virtualPath, err)
c.Log(logger.LevelError, "stat error for path %#v: %+v", virtualPath, err)
return info, c.GetFsError(fs, err)
}
if convertResult && vfs.IsCryptOsFs(fs) {
if vfs.IsCryptOsFs(fs) {
info = fs.(*vfs.CryptFs).ConvertFileInfo(info)
}
return info, nil
}
// DoStat execute a Stat if mode = 0, Lstat if mode = 1
func (c *BaseConnection) DoStat(virtualPath string, mode int, checkFilePatterns bool) (os.FileInfo, error) {
return c.doStatInternal(virtualPath, mode, checkFilePatterns, true)
}
func (c *BaseConnection) createDirIfMissing(name string) error {
_, err := c.DoStat(name, 0, false)
if c.IsNotExistError(err) {
@@ -787,7 +766,7 @@ func (c *BaseConnection) truncateFile(fs vfs.Fs, fsPath, virtualPath string, siz
initialSize = info.Size()
err = fs.Truncate(fsPath, size)
}
if err == nil && vfs.HasTruncateSupport(fs) {
if err == nil && vfs.IsLocalOrSFTPFs(fs) {
sizeDiff := initialSize - size
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
if err == nil {
@@ -802,31 +781,23 @@ func (c *BaseConnection) truncateFile(fs vfs.Fs, fsPath, virtualPath string, siz
return err
}
func (c *BaseConnection) checkRecursiveRenameDirPermissions(fsSrc, fsDst vfs.Fs, sourcePath, targetPath,
virtualSourcePath, virtualTargetPath string, fi os.FileInfo,
) error {
if !c.User.HasPermissionsInside(virtualSourcePath) &&
!c.User.HasPermissionsInside(virtualTargetPath) {
if !c.isRenamePermitted(fsSrc, fsDst, sourcePath, targetPath, virtualSourcePath, virtualTargetPath, fi) {
c.Log(logger.LevelInfo, "rename %#v -> %#v is not allowed, virtual destination path: %#v",
sourcePath, targetPath, virtualTargetPath)
return c.GetPermissionDeniedError()
}
// if all rename permissions are granted we have finished, otherwise we have to walk
// because we could have the rename dir permission but not the rename file and the dir to
// rename could contain files
if c.User.HasPermsRenameAll(path.Dir(virtualSourcePath)) && c.User.HasPermsRenameAll(path.Dir(virtualTargetPath)) {
return nil
}
}
return fsSrc.Walk(sourcePath, func(walkedPath string, info os.FileInfo, err error) error {
func (c *BaseConnection) checkRecursiveRenameDirPermissions(fsSrc, fsDst vfs.Fs, sourcePath, targetPath string) error {
err := fsSrc.Walk(sourcePath, func(walkedPath string, info os.FileInfo, err error) error {
if err != nil {
return c.GetFsError(fsSrc, err)
}
dstPath := strings.Replace(walkedPath, sourcePath, targetPath, 1)
virtualSrcPath := fsSrc.GetRelativePath(walkedPath)
virtualDstPath := fsDst.GetRelativePath(dstPath)
// walk scans the directory tree in order, checking the parent directory permissions we are sure that all contents
// inside the parent path was checked. If the current dir has no subdirs with defined permissions inside it
// and it has all the possible permissions we can stop scanning
if !c.User.HasPermissionsInside(path.Dir(virtualSrcPath)) && !c.User.HasPermissionsInside(path.Dir(virtualDstPath)) {
if c.User.HasPermsRenameAll(path.Dir(virtualSrcPath)) &&
c.User.HasPermsRenameAll(path.Dir(virtualDstPath)) {
return ErrSkipPermissionsCheck
}
}
if !c.isRenamePermitted(fsSrc, fsDst, walkedPath, dstPath, virtualSrcPath, virtualDstPath, info) {
c.Log(logger.LevelInfo, "rename %#v -> %#v is not allowed, virtual destination path: %#v",
walkedPath, dstPath, virtualDstPath)
@@ -834,6 +805,10 @@ func (c *BaseConnection) checkRecursiveRenameDirPermissions(fsSrc, fsDst vfs.Fs,
}
return nil
})
if err == ErrSkipPermissionsCheck {
err = nil
}
return err
}
func (c *BaseConnection) hasRenamePerms(virtualSourcePath, virtualTargetPath string, fi os.FileInfo) bool {
@@ -862,11 +837,9 @@ func (c *BaseConnection) hasRenamePerms(virtualSourcePath, virtualTargetPath str
c.User.HasAnyPerm(perms, path.Dir(virtualTargetPath))
}
func (c *BaseConnection) isRenamePermitted(fsSrc, fsDst vfs.Fs, fsSourcePath, fsTargetPath, virtualSourcePath,
virtualTargetPath string, fi os.FileInfo,
) bool {
if !c.isSameResourceRename(virtualSourcePath, virtualTargetPath) {
c.Log(logger.LevelInfo, "rename %#v->%#v is not allowed: the paths must be on the same resource",
func (c *BaseConnection) isRenamePermitted(fsSrc, fsDst vfs.Fs, fsSourcePath, fsTargetPath, virtualSourcePath, virtualTargetPath string, fi os.FileInfo) bool {
if !c.isLocalOrSameFolderRename(virtualSourcePath, virtualTargetPath) {
c.Log(logger.LevelInfo, "rename %#v->%#v is not allowed: the paths must be local or on the same virtual folder",
virtualSourcePath, virtualTargetPath)
return false
}
@@ -878,7 +851,7 @@ func (c *BaseConnection) isRenamePermitted(fsSrc, fsDst vfs.Fs, fsSourcePath, fs
c.Log(logger.LevelWarn, "renaming to a directory mapped as virtual folder is not allowed: %#v", fsTargetPath)
return false
}
if virtualSourcePath == "/" || virtualTargetPath == "/" || fsSrc.GetRelativePath(fsSourcePath) == "/" {
if fsSrc.GetRelativePath(fsSourcePath) == "/" {
c.Log(logger.LevelWarn, "renaming root dir is not allowed")
return false
}
@@ -1108,7 +1081,8 @@ func (c *BaseConnection) HasSpace(checkFiles, getUsage bool, requestPath string)
return result, transferQuota
}
func (c *BaseConnection) isSameResourceRename(virtualSourcePath, virtualTargetPath string) bool {
// returns true if this is a rename on the same fs or local virtual folders
func (c *BaseConnection) isLocalOrSameFolderRename(virtualSourcePath, virtualTargetPath string) bool {
sourceFolder, errSrc := c.User.GetVirtualFolderForPath(virtualSourcePath)
dstFolder, errDst := c.User.GetVirtualFolderForPath(virtualTargetPath)
if errSrc != nil && errDst != nil {
@@ -1118,13 +1092,27 @@ func (c *BaseConnection) isSameResourceRename(virtualSourcePath, virtualTargetPa
if sourceFolder.Name == dstFolder.Name {
return true
}
// we have different folders, check if they point to the same resource
return sourceFolder.FsConfig.IsSameResource(dstFolder.FsConfig)
// we have different folders, only local fs is supported
if sourceFolder.FsConfig.Provider == sdk.LocalFilesystemProvider &&
dstFolder.FsConfig.Provider == sdk.LocalFilesystemProvider {
return true
}
return false
}
if c.User.FsConfig.Provider != sdk.LocalFilesystemProvider {
return false
}
if errSrc == nil {
return sourceFolder.FsConfig.IsSameResource(c.User.FsConfig)
if sourceFolder.FsConfig.Provider == sdk.LocalFilesystemProvider {
return true
}
}
return dstFolder.FsConfig.IsSameResource(c.User.FsConfig)
if errDst == nil {
if dstFolder.FsConfig.Provider == sdk.LocalFilesystemProvider {
return true
}
}
return false
}
func (c *BaseConnection) isCrossFoldersRequest(virtualSourcePath, virtualTargetPath string) bool {
@@ -1362,21 +1350,16 @@ func (c *BaseConnection) GetGenericError(err error) error {
if err == vfs.ErrStorageSizeUnavailable {
return fmt.Errorf("%w: %v", sftp.ErrSSHFxOpUnsupported, err.Error())
}
if err == ErrShuttingDown {
return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, err.Error())
}
if err != nil {
if e, ok := err.(*os.PathError); ok {
c.Log(logger.LevelError, "generic path error: %+v", e)
return fmt.Errorf("%w: %v %v", sftp.ErrSSHFxFailure, e.Op, e.Err.Error())
}
c.Log(logger.LevelError, "generic error: %+v", err)
return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, ErrGenericFailure.Error())
return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, err.Error())
}
return sftp.ErrSSHFxFailure
default:
if err == ErrPermissionDenied || err == ErrNotExist || err == ErrOpUnsupported ||
err == ErrQuotaExceeded || err == vfs.ErrStorageSizeUnavailable || err == ErrShuttingDown {
err == ErrQuotaExceeded || err == vfs.ErrStorageSizeUnavailable {
return err
}
return ErrGenericFailure
@@ -1409,10 +1392,6 @@ func (c *BaseConnection) GetFsAndResolvedPath(virtualPath string) (vfs.Fs, strin
return nil, "", err
}
if isShuttingDown.Load() {
return nil, "", c.GetFsError(fs, ErrShuttingDown)
}
fsPath, err := fs.ResolvePath(virtualPath)
if err != nil {
return nil, "", c.GetFsError(fs, err)

View File

@@ -27,24 +27,20 @@ import (
"github.com/sftpgo/sdk"
"github.com/stretchr/testify/assert"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/kms"
"github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/vfs"
)
// MockOsFs mockable OsFs
type MockOsFs struct {
vfs.Fs
hasVirtualFolders bool
name string
}
// Name returns the name for the Fs implementation
func (fs *MockOsFs) Name() string {
if fs.name != "" {
return fs.name
}
return "mockOsFs"
}
@@ -61,10 +57,9 @@ func (fs *MockOsFs) Chtimes(name string, atime, mtime time.Time, isUploading boo
return vfs.ErrVfsUnsupported
}
func newMockOsFs(hasVirtualFolders bool, connectionID, rootDir, name string) vfs.Fs {
func newMockOsFs(hasVirtualFolders bool, connectionID, rootDir string) vfs.Fs {
return &MockOsFs{
Fs: vfs.NewOsFs(connectionID, rootDir, ""),
name: name,
hasVirtualFolders: hasVirtualFolders,
}
}
@@ -113,7 +108,7 @@ func TestSetStatMode(t *testing.T) {
}
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
fs := newMockOsFs(true, "", user.GetHomeDir(), "")
fs := newMockOsFs(true, "", user.GetHomeDir())
conn := NewBaseConnection("", ProtocolWebDAV, "", "", user)
err := conn.handleChmod(fs, fakePath, fakePath, nil)
assert.NoError(t, err)
@@ -135,28 +130,10 @@ func TestSetStatMode(t *testing.T) {
}
func TestRecursiveRenameWalkError(t *testing.T) {
fs := vfs.NewOsFs("", filepath.Clean(os.TempDir()), "")
conn := NewBaseConnection("", ProtocolWebDAV, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Permissions: map[string][]string{
"/": {dataprovider.PermListItems, dataprovider.PermUpload,
dataprovider.PermDownload, dataprovider.PermRenameDirs},
},
},
})
err := conn.checkRecursiveRenameDirPermissions(fs, fs, filepath.Join(os.TempDir(), "/source"),
filepath.Join(os.TempDir(), "/target"), "/source", "/target",
vfs.NewFileInfo("source", true, 0, time.Now(), false))
fs := vfs.NewOsFs("", os.TempDir(), "")
conn := NewBaseConnection("", ProtocolWebDAV, "", "", dataprovider.User{})
err := conn.checkRecursiveRenameDirPermissions(fs, fs, "/source", "/target")
assert.ErrorIs(t, err, os.ErrNotExist)
conn.User.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload,
dataprovider.PermDownload, dataprovider.PermRenameFiles}
// no dir rename permission, the quick check path returns permission error without walking
err = conn.checkRecursiveRenameDirPermissions(fs, fs, filepath.Join(os.TempDir(), "/source"),
filepath.Join(os.TempDir(), "/target"), "/source", "/target",
vfs.NewFileInfo("source", true, 0, time.Now(), false))
if assert.Error(t, err) {
assert.EqualError(t, err, conn.GetPermissionDeniedError().Error())
}
}
func TestCrossRenameFsErrors(t *testing.T) {
@@ -324,7 +301,7 @@ func TestErrorsMapping(t *testing.T) {
fs := vfs.NewOsFs("", os.TempDir(), "")
conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{BaseUser: sdk.BaseUser{HomeDir: os.TempDir()}})
osErrorsProtocols := []string{ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolHTTPShare,
ProtocolDataRetention, ProtocolOIDC, protocolEventAction}
ProtocolDataRetention, ProtocolOIDC}
for _, protocol := range supportedProtocols {
conn.SetProtocol(protocol)
err := conn.GetFsError(fs, os.ErrNotExist)
@@ -344,12 +321,14 @@ func TestErrorsMapping(t *testing.T) {
err = conn.GetFsError(fs, os.ErrClosed)
if protocol == ProtocolSFTP {
assert.ErrorIs(t, err, sftp.ErrSSHFxFailure)
assert.Contains(t, err.Error(), os.ErrClosed.Error())
} else {
assert.EqualError(t, err, ErrGenericFailure.Error())
}
err = conn.GetFsError(fs, ErrPermissionDenied)
if protocol == ProtocolSFTP {
assert.ErrorIs(t, err, sftp.ErrSSHFxFailure)
assert.Contains(t, err.Error(), ErrPermissionDenied.Error())
} else {
assert.EqualError(t, err, ErrPermissionDenied.Error())
}
@@ -385,13 +364,6 @@ func TestErrorsMapping(t *testing.T) {
} else {
assert.EqualError(t, err, ErrOpUnsupported.Error())
}
err = conn.GetFsError(fs, ErrShuttingDown)
if protocol == ProtocolSFTP {
assert.ErrorIs(t, err, sftp.ErrSSHFxFailure)
assert.Contains(t, err.Error(), ErrShuttingDown.Error())
} else {
assert.EqualError(t, err, ErrShuttingDown.Error())
}
}
}
@@ -441,7 +413,7 @@ func TestMaxWriteSize(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, int64(90), size)
fs = newMockOsFs(true, fs.ConnectionID(), user.GetHomeDir(), "")
fs = newMockOsFs(true, fs.ConnectionID(), user.GetHomeDir())
size, err = conn.GetMaxWriteSize(quotaResult, true, 100, fs.IsUploadResumeSupported())
assert.EqualError(t, err, ErrOpUnsupported.Error())
assert.Equal(t, int64(0), size)

View File

@@ -29,14 +29,12 @@ import (
"sync"
"time"
mail "github.com/xhit/go-simple-mail/v2"
"github.com/drakkan/sftpgo/v2/internal/command"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/httpclient"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/command"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/httpclient"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/smtp"
"github.com/drakkan/sftpgo/v2/util"
)
// RetentionCheckNotification defines the supported notification methods for a retention check result
@@ -51,11 +49,11 @@ const (
)
var (
// RetentionChecks is the list of active retention checks
// RetentionChecks is the list of active quota scans
RetentionChecks ActiveRetentionChecks
)
// ActiveRetentionChecks holds the active retention checks
// ActiveRetentionChecks holds the active quota scans
type ActiveRetentionChecks struct {
sync.RWMutex
Checks []RetentionCheck
@@ -68,7 +66,7 @@ func (c *ActiveRetentionChecks) Get() []RetentionCheck {
checks := make([]RetentionCheck, 0, len(c.Checks))
for _, check := range c.Checks {
foldersCopy := make([]dataprovider.FolderRetention, len(check.Folders))
foldersCopy := make([]FolderRetention, len(check.Folders))
copy(foldersCopy, check.Folders)
notificationsCopy := make([]string, len(check.Notifications))
copy(notificationsCopy, check.Notifications)
@@ -126,6 +124,37 @@ func (c *ActiveRetentionChecks) remove(username string) bool {
return false
}
// FolderRetention defines the retention policy for the specified directory path
type FolderRetention struct {
// Path is the exposed virtual directory path, if no other specific retention is defined,
// the retention applies for sub directories too. For example if retention is defined
// for the paths "/" and "/sub" then the retention for "/" is applied for any file outside
// the "/sub" directory
Path string `json:"path"`
// Retention time in hours. 0 means exclude this path
Retention int `json:"retention"`
// DeleteEmptyDirs defines if empty directories will be deleted.
// The user need the delete permission
DeleteEmptyDirs bool `json:"delete_empty_dirs,omitempty"`
// IgnoreUserPermissions defines if delete files even if the user does not have the delete permission.
// The default is "false" which means that files will be skipped if the user does not have the permission
// to delete them. This applies to sub directories too.
IgnoreUserPermissions bool `json:"ignore_user_permissions,omitempty"`
}
func (f *FolderRetention) isValid() error {
f.Path = path.Clean(f.Path)
if !path.IsAbs(f.Path) {
return util.NewValidationError(fmt.Sprintf("folder retention: invalid path %#v, please specify an absolute POSIX path",
f.Path))
}
if f.Retention < 0 {
return util.NewValidationError(fmt.Sprintf("invalid folder retention %v, it must be greater or equal to zero",
f.Retention))
}
return nil
}
type folderRetentionCheckResult struct {
Path string `json:"path"`
Retention int `json:"retention"`
@@ -143,13 +172,13 @@ type RetentionCheck struct {
// retention check start time as unix timestamp in milliseconds
StartTime int64 `json:"start_time"`
// affected folders
Folders []dataprovider.FolderRetention `json:"folders"`
Folders []FolderRetention `json:"folders"`
// how cleanup results will be notified
Notifications []RetentionCheckNotification `json:"notifications,omitempty"`
// email to use if the notification method is set to email
Email string `json:"email,omitempty"`
// Cleanup results
results []folderRetentionCheckResult `json:"-"`
results []*folderRetentionCheckResult `json:"-"`
conn *BaseConnection
}
@@ -159,7 +188,7 @@ func (c *RetentionCheck) Validate() error {
nothingToDo := true
for idx := range c.Folders {
f := &c.Folders[idx]
if err := f.Validate(); err != nil {
if err := f.isValid(); err != nil {
return err
}
if f.Retention > 0 {
@@ -201,7 +230,7 @@ func (c *RetentionCheck) updateUserPermissions() {
}
}
func (c *RetentionCheck) getFolderRetention(folderPath string) (dataprovider.FolderRetention, error) {
func (c *RetentionCheck) getFolderRetention(folderPath string) (FolderRetention, error) {
dirsForPath := util.GetDirsForVirtualPath(folderPath)
for _, dirPath := range dirsForPath {
for _, folder := range c.Folders {
@@ -211,7 +240,7 @@ func (c *RetentionCheck) getFolderRetention(folderPath string) (dataprovider.Fol
}
}
return dataprovider.FolderRetention{}, fmt.Errorf("unable to find folder retention for %#v", folderPath)
return FolderRetention{}, fmt.Errorf("unable to find folder retention for %#v", folderPath)
}
func (c *RetentionCheck) removeFile(virtualPath string, info os.FileInfo) error {
@@ -225,12 +254,10 @@ func (c *RetentionCheck) removeFile(virtualPath string, info os.FileInfo) error
func (c *RetentionCheck) cleanupFolder(folderPath string) error {
deleteFilesPerms := []string{dataprovider.PermDelete, dataprovider.PermDeleteFiles}
startTime := time.Now()
result := folderRetentionCheckResult{
result := &folderRetentionCheckResult{
Path: folderPath,
}
defer func() {
c.results = append(c.results, result)
}()
c.results = append(c.results, result)
if !c.conn.User.HasPerm(dataprovider.PermListItems, folderPath) || !c.conn.User.HasAnyPerm(deleteFilesPerms, folderPath) {
result.Elapsed = time.Since(startTime)
result.Info = "data retention check skipped: no permissions"
@@ -305,15 +332,7 @@ func (c *RetentionCheck) cleanupFolder(folderPath string) error {
}
func (c *RetentionCheck) checkEmptyDirRemoval(folderPath string) {
if folderPath == "/" {
return
}
for _, folder := range c.Folders {
if folderPath == folder.Path {
return
}
}
if c.conn.User.HasAnyPerm([]string{
if folderPath != "/" && c.conn.User.HasAnyPerm([]string{
dataprovider.PermDelete,
dataprovider.PermDeleteDirs,
}, path.Dir(folderPath),
@@ -327,7 +346,7 @@ func (c *RetentionCheck) checkEmptyDirRemoval(folderPath string) {
}
// Start starts the retention check
func (c *RetentionCheck) Start() error {
func (c *RetentionCheck) Start() {
c.conn.Log(logger.LevelInfo, "retention check started")
defer RetentionChecks.remove(c.conn.User.Username)
defer c.conn.CloseFS() //nolint:errcheck
@@ -338,55 +357,54 @@ func (c *RetentionCheck) Start() error {
if err := c.cleanupFolder(folder.Path); err != nil {
c.conn.Log(logger.LevelError, "retention check failed, unable to cleanup folder %#v", folder.Path)
c.sendNotifications(time.Since(startTime), err)
return err
return
}
}
}
c.conn.Log(logger.LevelInfo, "retention check completed")
c.sendNotifications(time.Since(startTime), nil)
return nil
}
func (c *RetentionCheck) sendNotifications(elapsed time.Duration, err error) {
for _, notification := range c.Notifications {
switch notification {
case RetentionCheckNotificationEmail:
c.sendEmailNotification(err) //nolint:errcheck
c.sendEmailNotification(elapsed, err) //nolint:errcheck
case RetentionCheckNotificationHook:
c.sendHookNotification(elapsed, err) //nolint:errcheck
}
}
}
func (c *RetentionCheck) sendEmailNotification(errCheck error) error {
params := EventParams{}
if len(c.results) > 0 || errCheck != nil {
params.retentionChecks = append(params.retentionChecks, executedRetentionCheck{
Username: c.conn.User.Username,
ActionName: "Retention check",
Results: c.results,
})
func (c *RetentionCheck) sendEmailNotification(elapsed time.Duration, errCheck error) error {
body := new(bytes.Buffer)
data := make(map[string]any)
data["Results"] = c.results
totalDeletedFiles := 0
totalDeletedSize := int64(0)
for _, result := range c.results {
totalDeletedFiles += result.DeletedFiles
totalDeletedSize += result.DeletedSize
}
var files []mail.File
f, err := params.getRetentionReportsAsMailAttachment()
if err != nil {
c.conn.Log(logger.LevelError, "unable to get retention report as mail attachment: %v", err)
data["HumanizeSize"] = util.ByteCountIEC
data["TotalFiles"] = totalDeletedFiles
data["TotalSize"] = totalDeletedSize
data["Elapsed"] = elapsed
data["Username"] = c.conn.User.Username
data["StartTime"] = util.GetTimeFromMsecSinceEpoch(c.StartTime)
if errCheck == nil {
data["Status"] = "Succeeded"
} else {
data["Status"] = "Failed"
}
if err := smtp.RenderRetentionReportTemplate(body, data); err != nil {
c.conn.Log(logger.LevelError, "unable to render retention check template: %v", err)
return err
}
f.Name = "retention-report.zip"
files = append(files, f)
startTime := time.Now()
var subject string
if errCheck == nil {
subject = fmt.Sprintf("Successful retention check for user %q", c.conn.User.Username)
} else {
subject = fmt.Sprintf("Retention check failed for user %q", c.conn.User.Username)
}
body := "Further details attached."
err = smtp.SendEmail([]string{c.Email}, subject, body, smtp.EmailContentTypeTextPlain, files...)
if err != nil {
subject := fmt.Sprintf("Retention check completed for user %#v", c.conn.User.Username)
if err := smtp.SendEmail(c.Email, subject, body.String(), smtp.EmailContentTypeTextHTML); err != nil {
c.conn.Log(logger.LevelError, "unable to notify retention check result via email: %v, elapsed: %v", err,
time.Since(startTime))
return err
@@ -396,9 +414,6 @@ func (c *RetentionCheck) sendEmailNotification(errCheck error) error {
}
func (c *RetentionCheck) sendHookNotification(elapsed time.Duration, errCheck error) error {
startNewHook()
defer hookEnded()
data := make(map[string]any)
totalDeletedFiles := 0
totalDeletedSize := int64(0)
@@ -450,11 +465,11 @@ func (c *RetentionCheck) sendHookNotification(elapsed time.Duration, errCheck er
c.conn.Log(logger.LevelError, "%v", err)
return err
}
timeout, env, args := command.GetConfig(Config.DataRetentionHook, command.HookDataRetention)
timeout, env := command.GetConfig(Config.DataRetentionHook)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
cmd := exec.CommandContext(ctx, Config.DataRetentionHook, args...)
cmd := exec.CommandContext(ctx, Config.DataRetentionHook)
cmd.Env = append(env,
fmt.Sprintf("SFTPGO_DATA_RETENTION_RESULT=%v", string(jsonData)))
err := cmd.Run()

View File

@@ -26,23 +26,31 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/smtp"
)
func TestRetentionValidation(t *testing.T) {
check := RetentionCheck{}
check.Folders = []dataprovider.FolderRetention{
check.Folders = append(check.Folders, FolderRetention{
Path: "relative",
Retention: 10,
})
err := check.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "please specify an absolute POSIX path")
check.Folders = []FolderRetention{
{
Path: "/",
Retention: -1,
},
}
err := check.Validate()
err = check.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid folder retention")
check.Folders = []dataprovider.FolderRetention{
check.Folders = []FolderRetention{
{
Path: "/ab/..",
Retention: 0,
@@ -53,7 +61,7 @@ func TestRetentionValidation(t *testing.T) {
assert.Contains(t, err.Error(), "nothing to delete")
assert.Equal(t, "/", check.Folders[0].Path)
check.Folders = append(check.Folders, dataprovider.FolderRetention{
check.Folders = append(check.Folders, FolderRetention{
Path: "/../..",
Retention: 24,
})
@@ -61,7 +69,7 @@ func TestRetentionValidation(t *testing.T) {
require.Error(t, err)
assert.Contains(t, err.Error(), `duplicated folder path "/"`)
check.Folders = []dataprovider.FolderRetention{
check.Folders = []FolderRetention{
{
Path: "/dir1",
Retention: 48,
@@ -86,7 +94,7 @@ func TestRetentionValidation(t *testing.T) {
Port: 25,
TemplatesPath: "templates",
}
err = smtpCfg.Initialize(configDir)
err = smtpCfg.Initialize("..")
require.NoError(t, err)
err = check.Validate()
@@ -98,7 +106,7 @@ func TestRetentionValidation(t *testing.T) {
assert.NoError(t, err)
smtpCfg = smtp.Config{}
err = smtpCfg.Initialize(configDir)
err = smtpCfg.Initialize("..")
require.NoError(t, err)
check.Notifications = []RetentionCheckNotification{RetentionCheckNotificationHook}
@@ -118,7 +126,7 @@ func TestRetentionEmailNotifications(t *testing.T) {
Port: 2525,
TemplatesPath: "templates",
}
err := smtpCfg.Initialize(configDir)
err := smtpCfg.Initialize("..")
require.NoError(t, err)
user := dataprovider.User{
@@ -131,7 +139,7 @@ func TestRetentionEmailNotifications(t *testing.T) {
check := RetentionCheck{
Notifications: []RetentionCheckNotification{RetentionCheckNotificationEmail},
Email: "notification@example.com",
results: []folderRetentionCheckResult{
results: []*folderRetentionCheckResult{
{
Path: "/",
Retention: 24,
@@ -146,36 +154,21 @@ func TestRetentionEmailNotifications(t *testing.T) {
conn.ID = fmt.Sprintf("data_retention_%v", user.Username)
check.conn = conn
check.sendNotifications(1*time.Second, nil)
err = check.sendEmailNotification(nil)
err = check.sendEmailNotification(1*time.Second, nil)
assert.NoError(t, err)
err = check.sendEmailNotification(errors.New("test error"))
err = check.sendEmailNotification(1*time.Second, errors.New("test error"))
assert.NoError(t, err)
check.results = nil
err = check.sendEmailNotification(nil)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "no data retention report available")
}
smtpCfg.Port = 2626
err = smtpCfg.Initialize(configDir)
err = smtpCfg.Initialize("..")
require.NoError(t, err)
err = check.sendEmailNotification(nil)
err = check.sendEmailNotification(1*time.Second, nil)
assert.Error(t, err)
check.results = []folderRetentionCheckResult{
{
Path: "/",
Retention: 24,
DeletedFiles: 20,
DeletedSize: 456789,
Elapsed: 12 * time.Second,
},
}
smtpCfg = smtp.Config{}
err = smtpCfg.Initialize(configDir)
err = smtpCfg.Initialize("..")
require.NoError(t, err)
err = check.sendEmailNotification(nil)
err = check.sendEmailNotification(1*time.Second, nil)
assert.Error(t, err)
}
@@ -192,7 +185,7 @@ func TestRetentionHookNotifications(t *testing.T) {
user.Permissions["/"] = []string{dataprovider.PermAny}
check := RetentionCheck{
Notifications: []RetentionCheckNotification{RetentionCheckNotificationHook},
results: []folderRetentionCheckResult{
results: []*folderRetentionCheckResult{
{
Path: "/",
Retention: 24,
@@ -247,7 +240,7 @@ func TestRetentionPermissionsAndGetFolder(t *testing.T) {
user.Permissions["/dir2/sub2"] = []string{dataprovider.PermDelete}
check := RetentionCheck{
Folders: []dataprovider.FolderRetention{
Folders: []FolderRetention{
{
Path: "/dir2",
Retention: 24 * 7,
@@ -307,7 +300,7 @@ func TestRetentionCheckAddRemove(t *testing.T) {
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
check := RetentionCheck{
Folders: []dataprovider.FolderRetention{
Folders: []FolderRetention{
{
Path: "/",
Retention: 48,
@@ -341,7 +334,7 @@ func TestCleanupErrors(t *testing.T) {
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
check := &RetentionCheck{
Folders: []dataprovider.FolderRetention{
Folders: []FolderRetention{
{
Path: "/path",
Retention: 48,

View File

@@ -25,9 +25,9 @@ import (
"github.com/yl2chen/cidranger"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
// HostEvent is the enumerable for the supported host events

View File

@@ -366,7 +366,7 @@ func TestLoadHostListFromFile(t *testing.T) {
assert.Len(t, hostList.IPAddresses, 0)
assert.Equal(t, 0, hostList.Ranges.Len())
if runtime.GOOS != osWindows {
if runtime.GOOS != "windows" {
err = os.Chmod(hostsFilePath, 0111)
assert.NoError(t, err)

View File

@@ -17,9 +17,9 @@ package common
import (
"time"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
type dbDefender struct {
@@ -107,14 +107,6 @@ func (d *dbDefender) AddEvent(ip string, event HostEvent) {
if host.Score > d.config.Threshold {
banTime := time.Now().Add(time.Duration(d.config.BanTime) * time.Minute)
err = dataprovider.SetDefenderBanTime(ip, util.GetTimeAsMsSinceEpoch(banTime))
if err == nil {
eventManager.handleIPBlockedEvent(EventParams{
Event: ipBlockedEventName,
IP: ip,
Timestamp: time.Now().UnixNano(),
Status: 1,
})
}
}
if err == nil {

View File

@@ -24,8 +24,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/util"
)
func TestBasicDbDefender(t *testing.T) {

View File

@@ -18,8 +18,8 @@ import (
"sort"
"time"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/util"
)
type memoryDefender struct {
@@ -209,12 +209,6 @@ func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
d.banned[ip] = time.Now().Add(time.Duration(d.config.BanTime) * time.Minute)
delete(d.hosts, ip)
d.cleanupBanned()
eventManager.handleIPBlockedEvent(EventParams{
Event: ipBlockedEventName,
IP: ip,
Timestamp: time.Now().UnixNano(),
Status: 1,
})
} else {
d.hosts[ip] = hs
}

View File

@@ -24,8 +24,8 @@ import (
"github.com/GehirnInc/crypt/md5_crypt"
"golang.org/x/crypto/bcrypt"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
const (

File diff suppressed because it is too large Load Diff

View File

@@ -25,7 +25,7 @@ import (
"golang.org/x/time/rate"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/util"
)
var (
@@ -186,16 +186,16 @@ func (rl *rateLimiter) Wait(source string) (time.Duration, error) {
}
type sourceRateLimiter struct {
lastActivity *atomic.Int64
lastActivity int64
bucket *rate.Limiter
}
func (s *sourceRateLimiter) updateLastActivity() {
s.lastActivity.Store(time.Now().UnixNano())
atomic.StoreInt64(&s.lastActivity, time.Now().UnixNano())
}
func (s *sourceRateLimiter) getLastActivity() int64 {
return s.lastActivity.Load()
return atomic.LoadInt64(&s.lastActivity)
}
type sourceBuckets struct {
@@ -224,8 +224,7 @@ func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Rese
b.cleanup()
src := sourceRateLimiter{
lastActivity: new(atomic.Int64),
bucket: r,
bucket: r,
}
src.updateLastActivity()
b.buckets[source] = src

View File

@@ -21,7 +21,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/util"
)
func TestRateLimiterConfig(t *testing.T) {

View File

@@ -25,8 +25,8 @@ import (
"sync"
"time"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
const (
@@ -97,7 +97,7 @@ func (m *CertManager) loadCertificates() error {
// GetCertificateFunc returns the loaded certificate
func (m *CertManager) GetCertificateFunc(certID string) func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
m.RLock()
defer m.RUnlock()

View File

@@ -21,10 +21,10 @@ import (
"sync/atomic"
"time"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/metric"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/metric"
"github.com/drakkan/sftpgo/v2/vfs"
)
var (
@@ -35,8 +35,8 @@ var (
// BaseTransfer contains protocols common transfer details for an upload or a download.
type BaseTransfer struct { //nolint:maligned
ID int64
BytesSent atomic.Int64
BytesReceived atomic.Int64
BytesSent int64
BytesReceived int64
Fs vfs.Fs
File vfs.File
Connection *BaseConnection
@@ -52,7 +52,7 @@ type BaseTransfer struct { //nolint:maligned
truncatedSize int64
isNewFile bool
transferType int
AbortTransfer atomic.Bool
AbortTransfer int32
aTime time.Time
mTime time.Time
transferQuota dataprovider.TransferQuota
@@ -79,14 +79,14 @@ func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPat
InitialSize: initialSize,
isNewFile: isNewFile,
requestPath: requestPath,
BytesSent: 0,
BytesReceived: 0,
MaxWriteSize: maxWriteSize,
AbortTransfer: 0,
truncatedSize: truncatedSize,
transferQuota: transferQuota,
Fs: fs,
}
t.AbortTransfer.Store(false)
t.BytesSent.Store(0)
t.BytesReceived.Store(0)
conn.AddTransfer(t)
return t
@@ -115,19 +115,19 @@ func (t *BaseTransfer) GetType() int {
// GetSize returns the transferred size
func (t *BaseTransfer) GetSize() int64 {
if t.transferType == TransferDownload {
return t.BytesSent.Load()
return atomic.LoadInt64(&t.BytesSent)
}
return t.BytesReceived.Load()
return atomic.LoadInt64(&t.BytesReceived)
}
// GetDownloadedSize returns the transferred size
func (t *BaseTransfer) GetDownloadedSize() int64 {
return t.BytesSent.Load()
return atomic.LoadInt64(&t.BytesSent)
}
// GetUploadedSize returns the transferred size
func (t *BaseTransfer) GetUploadedSize() int64 {
return t.BytesReceived.Load()
return atomic.LoadInt64(&t.BytesReceived)
}
// GetStartTime returns the start time
@@ -153,7 +153,7 @@ func (t *BaseTransfer) SignalClose(err error) {
t.Lock()
t.errAbort = err
t.Unlock()
t.AbortTransfer.Store(true)
atomic.StoreInt32(&(t.AbortTransfer), 1)
}
// GetTruncatedSize returns the truncated sized if this is an upload overwriting
@@ -217,11 +217,11 @@ func (t *BaseTransfer) CheckRead() error {
return nil
}
if t.transferQuota.AllowedTotalSize > 0 {
if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize {
if atomic.LoadInt64(&t.BytesSent)+atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedTotalSize {
return t.Connection.GetReadQuotaExceededError()
}
} else if t.transferQuota.AllowedDLSize > 0 {
if t.BytesSent.Load() > t.transferQuota.AllowedDLSize {
if atomic.LoadInt64(&t.BytesSent) > t.transferQuota.AllowedDLSize {
return t.Connection.GetReadQuotaExceededError()
}
}
@@ -230,18 +230,18 @@ func (t *BaseTransfer) CheckRead() error {
// CheckWrite returns an error if write if not allowed
func (t *BaseTransfer) CheckWrite() error {
if t.MaxWriteSize > 0 && t.BytesReceived.Load() > t.MaxWriteSize {
if t.MaxWriteSize > 0 && atomic.LoadInt64(&t.BytesReceived) > t.MaxWriteSize {
return t.Connection.GetQuotaExceededError()
}
if t.transferQuota.AllowedULSize == 0 && t.transferQuota.AllowedTotalSize == 0 {
return nil
}
if t.transferQuota.AllowedTotalSize > 0 {
if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize {
if atomic.LoadInt64(&t.BytesSent)+atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedTotalSize {
return t.Connection.GetQuotaExceededError()
}
} else if t.transferQuota.AllowedULSize > 0 {
if t.BytesReceived.Load() > t.transferQuota.AllowedULSize {
if atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedULSize {
return t.Connection.GetQuotaExceededError()
}
}
@@ -261,14 +261,13 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) {
if t.MaxWriteSize > 0 {
sizeDiff := initialSize - size
t.MaxWriteSize += sizeDiff
metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(),
t.transferType, t.ErrTransfer, vfs.IsSFTPFs(t.Fs))
metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.transferType, t.ErrTransfer)
if t.transferQuota.HasSizeLimits() {
go func(ulSize, dlSize int64, user dataprovider.User) {
dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
}(t.BytesReceived.Load(), t.BytesSent.Load(), t.Connection.User)
}(atomic.LoadInt64(&t.BytesReceived), atomic.LoadInt64(&t.BytesSent), t.Connection.User)
}
t.BytesReceived.Store(0)
atomic.StoreInt64(&t.BytesReceived, 0)
}
t.Unlock()
}
@@ -276,7 +275,7 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) {
fsPath, size, t.MaxWriteSize, t.InitialSize, err)
return initialSize, err
}
if size == 0 && t.BytesSent.Load() == 0 {
if size == 0 && atomic.LoadInt64(&t.BytesSent) == 0 {
// for cloud providers the file is always truncated to zero, we don't support append/resume for uploads
// for buffered SFTP we can have buffered bytes so we returns an error
if !vfs.IsBufferedSFTPFs(t.Fs) {
@@ -302,30 +301,23 @@ func (t *BaseTransfer) TransferError(err error) {
}
elapsed := time.Since(t.start).Nanoseconds() / 1000000
t.Connection.Log(logger.LevelError, "Unexpected error for transfer, path: %#v, error: \"%v\" bytes sent: %v, "+
"bytes received: %v transfer running since %v ms", t.fsPath, t.ErrTransfer, t.BytesSent.Load(),
t.BytesReceived.Load(), elapsed)
"bytes received: %v transfer running since %v ms", t.fsPath, t.ErrTransfer, atomic.LoadInt64(&t.BytesSent),
atomic.LoadInt64(&t.BytesReceived), elapsed)
}
func (t *BaseTransfer) getUploadFileSize() (int64, int, error) {
func (t *BaseTransfer) getUploadFileSize() (int64, error) {
var fileSize int64
var deletedFiles int
info, err := t.Fs.Stat(t.fsPath)
if err == nil {
fileSize = info.Size()
}
if t.ErrTransfer != nil && vfs.IsCryptOsFs(t.Fs) {
if vfs.IsCryptOsFs(t.Fs) && t.ErrTransfer != nil {
errDelete := t.Fs.Remove(t.fsPath, false)
if errDelete != nil {
t.Connection.Log(logger.LevelWarn, "error removing partial crypto file %#v: %v", t.fsPath, errDelete)
} else {
fileSize = 0
deletedFiles = 1
t.BytesReceived.Store(0)
t.MinWriteOffset = 0
}
}
return fileSize, deletedFiles, err
return fileSize, err
}
// return 1 if the file is outside the user home dir
@@ -340,7 +332,7 @@ func (t *BaseTransfer) checkUploadOutsideHomeDir(err error) int {
t.Connection.Log(logger.LevelWarn, "upload in temp path cannot be renamed, delete temporary file: %#v, deletion error: %v",
t.effectiveFsPath, err)
// the file is outside the home dir so don't update the quota
t.BytesReceived.Store(0)
atomic.StoreInt64(&t.BytesReceived, 0)
t.MinWriteOffset = 0
return 1
}
@@ -354,18 +346,22 @@ func (t *BaseTransfer) Close() error {
defer t.Connection.RemoveTransfer(t)
var err error
numFiles := t.getUploadedFiles()
metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(),
t.transferType, t.ErrTransfer, vfs.IsSFTPFs(t.Fs))
numFiles := 0
if t.isNewFile {
numFiles = 1
}
metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived),
t.transferType, t.ErrTransfer)
if t.transferQuota.HasSizeLimits() {
dataprovider.UpdateUserTransferQuota(&t.Connection.User, t.BytesReceived.Load(), //nolint:errcheck
t.BytesSent.Load(), false)
dataprovider.UpdateUserTransferQuota(&t.Connection.User, atomic.LoadInt64(&t.BytesReceived), //nolint:errcheck
atomic.LoadInt64(&t.BytesSent), false)
}
if t.File != nil && t.Connection.IsQuotaExceededError(t.ErrTransfer) {
// if quota is exceeded we try to remove the partial file for uploads to local filesystem
err = t.Fs.Remove(t.File.Name(), false)
if err == nil {
t.BytesReceived.Store(0)
numFiles--
atomic.StoreInt64(&t.BytesReceived, 0)
t.MinWriteOffset = 0
}
t.Connection.Log(logger.LevelWarn, "upload denied due to space limit, delete temporary file: %#v, deletion error: %v",
@@ -376,43 +372,35 @@ func (t *BaseTransfer) Close() error {
t.Connection.Log(logger.LevelDebug, "atomic upload completed, rename: %#v -> %#v, error: %v",
t.effectiveFsPath, t.fsPath, err)
// the file must be removed if it is uploaded to a path outside the home dir and cannot be renamed
t.checkUploadOutsideHomeDir(err)
numFiles -= t.checkUploadOutsideHomeDir(err)
} else {
err = t.Fs.Remove(t.effectiveFsPath, false)
t.Connection.Log(logger.LevelWarn, "atomic upload completed with error: \"%v\", delete temporary file: %#v, deletion error: %v",
t.ErrTransfer, t.effectiveFsPath, err)
if err == nil {
t.BytesReceived.Store(0)
numFiles--
atomic.StoreInt64(&t.BytesReceived, 0)
t.MinWriteOffset = 0
}
}
}
elapsed := time.Since(t.start).Nanoseconds() / 1000000
var uploadFileSize int64
if t.transferType == TransferDownload {
logger.TransferLog(downloadLogSender, t.fsPath, elapsed, t.BytesSent.Load(), t.Connection.User.Username,
logger.TransferLog(downloadLogSender, t.fsPath, elapsed, atomic.LoadInt64(&t.BytesSent), t.Connection.User.Username,
t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode)
ExecuteActionNotification(t.Connection, operationDownload, t.fsPath, t.requestPath, "", "", "", //nolint:errcheck
t.BytesSent.Load(), t.ErrTransfer)
ExecuteActionNotification(t.Connection, operationDownload, t.fsPath, t.requestPath, "", "", "",
atomic.LoadInt64(&t.BytesSent), t.ErrTransfer)
} else {
statSize, deletedFiles, errStat := t.getUploadFileSize()
if errStat == nil {
uploadFileSize = statSize
} else {
uploadFileSize = t.BytesReceived.Load() + t.MinWriteOffset
if t.Fs.IsNotExist(errStat) {
uploadFileSize = 0
numFiles--
}
fileSize := atomic.LoadInt64(&t.BytesReceived) + t.MinWriteOffset
if statSize, errStat := t.getUploadFileSize(); errStat == nil {
fileSize = statSize
}
numFiles -= deletedFiles
t.Connection.Log(logger.LevelDebug, "upload file size %d, num files %d, deleted files %d, fs path %q",
uploadFileSize, numFiles, deletedFiles, t.fsPath)
numFiles, uploadFileSize = t.executeUploadHook(numFiles, uploadFileSize)
t.updateQuota(numFiles, uploadFileSize)
t.Connection.Log(logger.LevelDebug, "uploaded file size %v", fileSize)
t.updateQuota(numFiles, fileSize)
t.updateTimes()
logger.TransferLog(uploadLogSender, t.fsPath, elapsed, t.BytesReceived.Load(), t.Connection.User.Username,
logger.TransferLog(uploadLogSender, t.fsPath, elapsed, atomic.LoadInt64(&t.BytesReceived), t.Connection.User.Username,
t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode)
ExecuteActionNotification(t.Connection, operationUpload, t.fsPath, t.requestPath, "", "", "", fileSize, t.ErrTransfer)
}
if t.ErrTransfer != nil {
t.Connection.Log(logger.LevelError, "transfer error: %v, path: %#v", t.ErrTransfer, t.fsPath)
@@ -420,62 +408,9 @@ func (t *BaseTransfer) Close() error {
err = t.ErrTransfer
}
}
t.updateTransferTimestamps(uploadFileSize)
return err
}
func (t *BaseTransfer) updateTransferTimestamps(uploadFileSize int64) {
if t.ErrTransfer != nil {
return
}
if t.transferType == TransferUpload {
if t.Connection.User.FirstUpload == 0 && !t.Connection.uploadDone.Load() {
if err := dataprovider.UpdateUserTransferTimestamps(t.Connection.User.Username, true); err == nil {
t.Connection.uploadDone.Store(true)
ExecuteActionNotification(t.Connection, operationFirstUpload, t.fsPath, t.requestPath, "", //nolint:errcheck
"", "", uploadFileSize, t.ErrTransfer)
}
}
return
}
if t.Connection.User.FirstDownload == 0 && !t.Connection.downloadDone.Load() && t.BytesSent.Load() > 0 {
if err := dataprovider.UpdateUserTransferTimestamps(t.Connection.User.Username, false); err == nil {
t.Connection.downloadDone.Store(true)
ExecuteActionNotification(t.Connection, operationFirstDownload, t.fsPath, t.requestPath, "", //nolint:errcheck
"", "", t.BytesSent.Load(), t.ErrTransfer)
}
}
}
func (t *BaseTransfer) executeUploadHook(numFiles int, fileSize int64) (int, int64) {
err := ExecuteActionNotification(t.Connection, operationUpload, t.fsPath, t.requestPath, "", "", "",
fileSize, t.ErrTransfer)
if err != nil {
if t.ErrTransfer == nil {
t.ErrTransfer = err
}
// try to remove the uploaded file
err = t.Fs.Remove(t.fsPath, false)
if err == nil {
numFiles--
fileSize = 0
t.BytesReceived.Store(0)
t.MinWriteOffset = 0
} else {
t.Connection.Log(logger.LevelWarn, "unable to remove path %q after upload hook failure: %v", t.fsPath, err)
}
}
return numFiles, fileSize
}
func (t *BaseTransfer) getUploadedFiles() int {
numFiles := 0
if t.isNewFile {
numFiles = 1
}
return numFiles
}
func (t *BaseTransfer) updateTimes() {
if !t.aTime.IsZero() && !t.mTime.IsZero() {
err := t.Fs.Chtimes(t.fsPath, t.aTime, t.mTime, true)
@@ -485,12 +420,12 @@ func (t *BaseTransfer) updateTimes() {
}
func (t *BaseTransfer) updateQuota(numFiles int, fileSize int64) bool {
// Uploads on some filesystem (S3 and similar) are atomic, if there is an error nothing is uploaded
if t.File == nil && t.ErrTransfer != nil && vfs.HasImplicitAtomicUploads(t.Fs) {
// S3 uploads are atomic, if there is an error nothing is uploaded
if t.File == nil && t.ErrTransfer != nil && !t.Connection.User.HasBufferedSFTP(t.GetVirtualPath()) {
return false
}
sizeDiff := fileSize - t.InitialSize
if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff != 0) {
if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff > 0) {
vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath))
if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck
@@ -512,10 +447,10 @@ func (t *BaseTransfer) HandleThrottle() {
var trasferredBytes int64
if t.transferType == TransferDownload {
wantedBandwidth = t.Connection.User.DownloadBandwidth
trasferredBytes = t.BytesSent.Load()
trasferredBytes = atomic.LoadInt64(&t.BytesSent)
} else {
wantedBandwidth = t.Connection.User.UploadBandwidth
trasferredBytes = t.BytesReceived.Load()
trasferredBytes = atomic.LoadInt64(&t.BytesReceived)
}
if wantedBandwidth > 0 {
// real and wanted elapsed as milliseconds, bytes as kilobytes

View File

@@ -25,21 +25,22 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/kms"
"github.com/drakkan/sftpgo/v2/vfs"
)
func TestTransferUpdateQuota(t *testing.T) {
conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{})
transfer := BaseTransfer{
Connection: conn,
transferType: TransferUpload,
Fs: vfs.NewOsFs("", os.TempDir(), ""),
Connection: conn,
transferType: TransferUpload,
BytesReceived: 123,
Fs: vfs.NewOsFs("", os.TempDir(), ""),
}
transfer.BytesReceived.Store(123)
errFake := errors.New("fake error")
transfer.TransferError(errFake)
assert.False(t, transfer.updateQuota(1, 0))
err := transfer.Close()
if assert.Error(t, err) {
assert.EqualError(t, err, errFake.Error())
@@ -55,15 +56,11 @@ func TestTransferUpdateQuota(t *testing.T) {
QuotaSize: -1,
})
transfer.ErrTransfer = nil
transfer.BytesReceived.Store(1)
transfer.BytesReceived = 1
transfer.requestPath = "/vdir/file"
assert.True(t, transfer.updateQuota(1, 0))
err = transfer.Close()
assert.NoError(t, err)
transfer.ErrTransfer = errFake
transfer.Fs = newMockOsFs(true, "", "", "S3Fs fake")
assert.False(t, transfer.updateQuota(1, 0))
}
func TestTransferThrottling(t *testing.T) {
@@ -83,7 +80,7 @@ func TestTransferThrottling(t *testing.T) {
wantedDownloadElapsed -= wantedDownloadElapsed / 10
conn := NewBaseConnection("id", ProtocolSCP, "", "", u)
transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
transfer.BytesReceived.Store(testFileSize)
transfer.BytesReceived = testFileSize
transfer.Connection.UpdateLastActivity()
startTime := transfer.Connection.GetLastActivity()
transfer.HandleThrottle()
@@ -93,7 +90,7 @@ func TestTransferThrottling(t *testing.T) {
assert.NoError(t, err)
transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
transfer.BytesSent.Store(testFileSize)
transfer.BytesSent = testFileSize
transfer.Connection.UpdateLastActivity()
startTime = transfer.Connection.GetLastActivity()
@@ -229,7 +226,7 @@ func TestTransferErrors(t *testing.T) {
assert.Equal(t, testFile, transfer.GetFsPath())
transfer.SetCancelFn(cancelFn)
errFake := errors.New("err fake")
transfer.BytesReceived.Store(9)
transfer.BytesReceived = 9
transfer.TransferError(ErrQuotaExceeded)
assert.True(t, isCancelled)
transfer.TransferError(errFake)
@@ -252,7 +249,7 @@ func TestTransferErrors(t *testing.T) {
fsPath := filepath.Join(os.TempDir(), "test_file")
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true,
fs, dataprovider.TransferQuota{})
transfer.BytesReceived.Store(9)
transfer.BytesReceived = 9
transfer.TransferError(errFake)
assert.Error(t, transfer.ErrTransfer, errFake.Error())
// the file is closed from the embedding struct before to call close
@@ -272,7 +269,7 @@ func TestTransferErrors(t *testing.T) {
}
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true,
fs, dataprovider.TransferQuota{})
transfer.BytesReceived.Store(9)
transfer.BytesReceived = 9
// the file is closed from the embedding struct before to call close
err = file.Close()
assert.NoError(t, err)
@@ -300,25 +297,24 @@ func TestRemovePartialCryptoFile(t *testing.T) {
transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload,
0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
transfer.ErrTransfer = errors.New("test error")
_, _, err = transfer.getUploadFileSize()
_, err = transfer.getUploadFileSize()
assert.Error(t, err)
err = os.WriteFile(testFile, []byte("test data"), os.ModePerm)
assert.NoError(t, err)
size, deletedFiles, err := transfer.getUploadFileSize()
size, err := transfer.getUploadFileSize()
assert.NoError(t, err)
assert.Equal(t, int64(0), size)
assert.Equal(t, 1, deletedFiles)
assert.Equal(t, int64(9), size)
assert.NoFileExists(t, testFile)
}
func TestFTPMode(t *testing.T) {
conn := NewBaseConnection("", ProtocolFTP, "", "", dataprovider.User{})
transfer := BaseTransfer{
Connection: conn,
transferType: TransferUpload,
Fs: vfs.NewOsFs("", os.TempDir(), ""),
Connection: conn,
transferType: TransferUpload,
BytesReceived: 123,
Fs: vfs.NewOsFs("", os.TempDir(), ""),
}
transfer.BytesReceived.Store(123)
assert.Empty(t, transfer.ftpMode)
transfer.SetFtpMode("active")
assert.Equal(t, "active", transfer.ftpMode)
@@ -403,14 +399,14 @@ func TestTransferQuota(t *testing.T) {
transfer.transferQuota = dataprovider.TransferQuota{
AllowedTotalSize: 10,
}
transfer.BytesReceived.Store(5)
transfer.BytesSent.Store(4)
transfer.BytesReceived = 5
transfer.BytesSent = 4
err = transfer.CheckRead()
assert.NoError(t, err)
err = transfer.CheckWrite()
assert.NoError(t, err)
transfer.BytesSent.Store(6)
transfer.BytesSent = 6
err = transfer.CheckRead()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())
@@ -432,7 +428,7 @@ func TestTransferQuota(t *testing.T) {
err = transfer.CheckWrite()
assert.NoError(t, err)
transfer.BytesReceived.Store(11)
transfer.BytesReceived = 11
err = transfer.CheckRead()
if assert.Error(t, err) {
assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())
@@ -446,11 +442,11 @@ func TestUploadOutsideHomeRenameError(t *testing.T) {
conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{})
transfer := BaseTransfer{
Connection: conn,
transferType: TransferUpload,
Fs: vfs.NewOsFs("", filepath.Join(os.TempDir(), "home"), ""),
Connection: conn,
transferType: TransferUpload,
BytesReceived: 123,
Fs: vfs.NewOsFs("", filepath.Join(os.TempDir(), "home"), ""),
}
transfer.BytesReceived.Store(123)
fileName := filepath.Join(os.TempDir(), "_temp")
err := os.WriteFile(fileName, []byte(`data`), 0644)
@@ -463,10 +459,10 @@ func TestUploadOutsideHomeRenameError(t *testing.T) {
Config.TempPath = filepath.Clean(os.TempDir())
res = transfer.checkUploadOutsideHomeDir(nil)
assert.Equal(t, 0, res)
assert.Greater(t, transfer.BytesReceived.Load(), int64(0))
assert.Greater(t, transfer.BytesReceived, int64(0))
res = transfer.checkUploadOutsideHomeDir(os.ErrPermission)
assert.Equal(t, 1, res)
assert.Equal(t, int64(0), transfer.BytesReceived.Load())
assert.Equal(t, int64(0), transfer.BytesReceived)
assert.NoFileExists(t, fileName)
Config.TempPath = oldTempPath

View File

@@ -19,9 +19,9 @@ import (
"sync"
"time"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
type overquotaTransfer struct {

View File

@@ -21,6 +21,7 @@ import (
"path/filepath"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
@@ -28,9 +29,9 @@ import (
"github.com/sftpgo/sdk"
"github.com/stretchr/testify/assert"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/vfs"
)
func TestTransfersCheckerDiskQuota(t *testing.T) {
@@ -95,7 +96,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
}
transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
"/file1", TransferUpload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
transfer1.BytesReceived.Store(150)
transfer1.BytesReceived = 150
err = Connections.Add(fakeConn1)
assert.NoError(t, err)
// the transferschecker will do nothing if there is only one ongoing transfer
@@ -109,8 +110,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
}
transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
"/file2", TransferUpload, 0, 0, 120, 40, true, fsUser, dataprovider.TransferQuota{})
transfer1.BytesReceived.Store(50)
transfer2.BytesReceived.Store(60)
transfer1.BytesReceived = 50
transfer2.BytesReceived = 60
err = Connections.Add(fakeConn2)
assert.NoError(t, err)
@@ -121,7 +122,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
}
transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"),
"/file3", TransferDownload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
transfer3.BytesReceived.Store(60) // this value will be ignored, this is a download
transfer3.BytesReceived = 60 // this value will be ignored, this is a download
err = Connections.Add(fakeConn3)
assert.NoError(t, err)
@@ -131,20 +132,20 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
transfer1.BytesReceived.Store(80) // truncated size will be subtracted, we are not overquota
transfer1.BytesReceived = 80 // truncated size will be subtracted, we are not overquota
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
transfer1.BytesReceived.Store(120)
transfer1.BytesReceived = 120
// we are now overquota
// if another check is in progress nothing is done
Connections.transfersCheckStatus.Store(true)
atomic.StoreInt32(&Connections.transfersCheckStatus, 1)
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
Connections.transfersCheckStatus.Store(false)
atomic.StoreInt32(&Connections.transfersCheckStatus, 0)
Connections.checkTransfers()
assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort), transfer1.errAbort)
@@ -171,8 +172,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort)
// now check a public folder
transfer1.BytesReceived.Store(0)
transfer2.BytesReceived.Store(0)
transfer1.BytesReceived = 0
transfer2.BytesReceived = 0
connID4 := xid.New().String()
fsFolder, err := user.GetFilesystemForPath(path.Join(vdirPath, "/file1"), connID4)
assert.NoError(t, err)
@@ -196,12 +197,12 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
err = Connections.Add(fakeConn5)
assert.NoError(t, err)
transfer4.BytesReceived.Store(50)
transfer5.BytesReceived.Store(40)
transfer4.BytesReceived = 50
transfer5.BytesReceived = 40
Connections.checkTransfers()
assert.Nil(t, transfer4.errAbort)
assert.Nil(t, transfer5.errAbort)
transfer5.BytesReceived.Store(60)
transfer5.BytesReceived = 60
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
@@ -285,7 +286,7 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
}
transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
"/file1", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
transfer1.BytesReceived.Store(150)
transfer1.BytesReceived = 150
err = Connections.Add(fakeConn1)
assert.NoError(t, err)
// the transferschecker will do nothing if there is only one ongoing transfer
@@ -299,26 +300,26 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
}
transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
"/file2", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
transfer2.BytesReceived.Store(150)
transfer2.BytesReceived = 150
err = Connections.Add(fakeConn2)
assert.NoError(t, err)
Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort)
// now test overquota
transfer1.BytesReceived.Store(1024*1024 + 1)
transfer2.BytesReceived.Store(0)
transfer1.BytesReceived = 1024*1024 + 1
transfer2.BytesReceived = 0
Connections.checkTransfers()
assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort))
assert.Nil(t, transfer2.errAbort)
transfer1.errAbort = nil
transfer1.BytesReceived.Store(1024*1024 + 1)
transfer2.BytesReceived.Store(1024)
transfer1.BytesReceived = 1024*1024 + 1
transfer2.BytesReceived = 1024
Connections.checkTransfers()
assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort))
assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort))
transfer1.BytesReceived.Store(0)
transfer2.BytesReceived.Store(0)
transfer1.BytesReceived = 0
transfer2.BytesReceived = 0
transfer1.errAbort = nil
transfer2.errAbort = nil
@@ -336,7 +337,7 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
}
transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
"/file1", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
transfer3.BytesSent.Store(150)
transfer3.BytesSent = 150
err = Connections.Add(fakeConn3)
assert.NoError(t, err)
@@ -347,15 +348,15 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
}
transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
"/file2", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
transfer4.BytesSent.Store(150)
transfer4.BytesSent = 150
err = Connections.Add(fakeConn4)
assert.NoError(t, err)
Connections.checkTransfers()
assert.Nil(t, transfer3.errAbort)
assert.Nil(t, transfer4.errAbort)
transfer3.BytesSent.Store(512 * 1024)
transfer4.BytesSent.Store(512*1024 + 1)
transfer3.BytesSent = 512 * 1024
transfer4.BytesSent = 512*1024 + 1
Connections.checkTransfers()
if assert.Error(t, transfer3.errAbort) {
assert.Contains(t, transfer3.errAbort.Error(), ErrReadQuotaExceeded.Error())

View File

@@ -24,25 +24,24 @@ import (
"strings"
"github.com/spf13/viper"
"github.com/subosito/gotenv"
"github.com/drakkan/sftpgo/v2/internal/acme"
"github.com/drakkan/sftpgo/v2/internal/command"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/ftpd"
"github.com/drakkan/sftpgo/v2/internal/httpclient"
"github.com/drakkan/sftpgo/v2/internal/httpd"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/mfa"
"github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/sftpd"
"github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/telemetry"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/version"
"github.com/drakkan/sftpgo/v2/internal/webdavd"
"github.com/drakkan/sftpgo/v2/acme"
"github.com/drakkan/sftpgo/v2/command"
"github.com/drakkan/sftpgo/v2/common"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/ftpd"
"github.com/drakkan/sftpgo/v2/httpclient"
"github.com/drakkan/sftpgo/v2/httpd"
"github.com/drakkan/sftpgo/v2/kms"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/mfa"
"github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/sftpd"
"github.com/drakkan/sftpgo/v2/smtp"
"github.com/drakkan/sftpgo/v2/telemetry"
"github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/version"
"github.com/drakkan/sftpgo/v2/webdavd"
)
const (
@@ -82,26 +81,24 @@ var (
Debug: false,
}
defaultWebDAVDBinding = webdavd.Binding{
Address: "",
Port: 0,
EnableHTTPS: false,
CertificateFile: "",
CertificateKeyFile: "",
MinTLSVersion: 12,
ClientAuthType: 0,
TLSCipherSuites: nil,
Prefix: "",
ProxyAllowed: nil,
ClientIPProxyHeader: "",
ClientIPHeaderDepth: 0,
DisableWWWAuthHeader: false,
Address: "",
Port: 0,
EnableHTTPS: false,
CertificateFile: "",
CertificateKeyFile: "",
MinTLSVersion: 12,
ClientAuthType: 0,
TLSCipherSuites: nil,
Prefix: "",
ProxyAllowed: nil,
ClientIPProxyHeader: "",
ClientIPHeaderDepth: 0,
}
defaultHTTPDBinding = httpd.Binding{
Address: "",
Port: 8080,
EnableWebAdmin: true,
EnableWebClient: true,
EnableRESTAPI: true,
EnabledLoginMethods: 0,
EnableHTTPS: false,
CertificateFile: "",
@@ -116,17 +113,16 @@ var (
RenderOpenAPI: true,
WebClientIntegrations: nil,
OIDC: httpd.OIDC{
ClientID: "",
ClientSecret: "",
ConfigURL: "",
RedirectBaseURL: "",
UsernameField: "",
RoleField: "",
ImplicitRoles: false,
Scopes: []string{"openid", "profile", "email"},
CustomFields: []string{},
InsecureSkipSignatureCheck: false,
Debug: false,
ClientID: "",
ClientSecret: "",
ConfigURL: "",
RedirectBaseURL: "",
UsernameField: "",
RoleField: "",
ImplicitRoles: false,
Scopes: []string{"openid", "profile", "email"},
CustomFields: []string{},
Debug: false,
},
Security: httpd.SecurityConf{
Enabled: false,
@@ -210,7 +206,6 @@ func Init() {
MaxTotalConnections: 0,
MaxPerHostConnections: 20,
WhiteListFile: "",
AllowSelfConnections: 0,
DefenderConfig: common.DefenderConfig{
Enabled: false,
Driver: common.DefenderDriverMemory,
@@ -290,16 +285,13 @@ func Init() {
CACertificates: []string{},
CARevocationLists: []string{},
Cors: webdavd.CorsConfig{
Enabled: false,
AllowedOrigins: []string{},
AllowedMethods: []string{},
AllowedHeaders: []string{},
ExposedHeaders: []string{},
AllowCredentials: false,
MaxAge: 0,
OptionsPassthrough: false,
OptionsSuccessStatus: 0,
AllowPrivateNetwork: false,
Enabled: false,
AllowedOrigins: []string{},
AllowedMethods: []string{},
AllowedHeaders: []string{},
ExposedHeaders: []string{},
AllowCredentials: false,
MaxAge: 0,
},
Cache: webdavd.Cache{
Users: webdavd.UsersCacheConfig{
@@ -313,23 +305,21 @@ func Init() {
},
},
ProviderConf: dataprovider.Config{
Driver: "sqlite",
Name: "sftpgo.db",
Host: "",
Port: 0,
Username: "",
Password: "",
ConnectionString: "",
SQLTablesPrefix: "",
SSLMode: 0,
DisableSNI: false,
TargetSessionAttrs: "",
RootCert: "",
ClientCert: "",
ClientKey: "",
TrackQuota: 2,
PoolSize: 0,
UsersBaseDir: "",
Driver: "sqlite",
Name: "sftpgo.db",
Host: "",
Port: 0,
Username: "",
Password: "",
ConnectionString: "",
SQLTablesPrefix: "",
SSLMode: 0,
RootCert: "",
ClientCert: "",
ClientKey: "",
TrackQuota: 2,
PoolSize: 0,
UsersBaseDir: "",
Actions: dataprovider.ObjectsActions{
ExecuteOn: []string{},
ExecuteFor: []string{},
@@ -337,6 +327,7 @@ func Init() {
},
ExternalAuthHook: "",
ExternalAuthScope: 0,
CredentialsPath: "credentials",
PreLoginHook: "",
PostLoginHook: "",
PostLoginScope: 0,
@@ -367,12 +358,12 @@ func Init() {
CreateDefaultAdmin: false,
NamingRules: 1,
IsShared: 0,
Node: dataprovider.NodeConfig{
Host: "",
Port: 0,
Proto: "http",
BackupsPath: "backups",
AutoBackup: dataprovider.AutoBackup{
Enabled: true,
Hour: "0",
DayOfWeek: "*",
},
BackupsPath: "backups",
},
HTTPDConfig: httpd.Conf{
Bindings: []httpd.Binding{defaultHTTPDBinding},
@@ -388,16 +379,13 @@ func Init() {
TokenValidation: 0,
MaxUploadFileSize: 1048576000,
Cors: httpd.CorsConfig{
Enabled: false,
AllowedOrigins: []string{},
AllowedMethods: []string{},
AllowedHeaders: []string{},
ExposedHeaders: []string{},
AllowCredentials: false,
MaxAge: 0,
OptionsPassthrough: false,
OptionsSuccessStatus: 0,
AllowPrivateNetwork: false,
Enabled: false,
AllowedOrigins: []string{},
AllowedMethods: []string{},
AllowedHeaders: []string{},
ExposedHeaders: []string{},
AllowCredentials: false,
MaxAge: 0,
},
Setup: httpd.SetupConfig{
InstallationCode: "",
@@ -428,7 +416,7 @@ func Init() {
},
},
MFAConfig: mfa.Config{
TOTP: []mfa.TOTPConfig{defaultTOTP},
TOTP: nil,
},
TelemetryConfig: telemetry.Conf{
BindPort: 0,
@@ -636,64 +624,6 @@ func setConfigFile(configDir, configFile string) {
viper.SetConfigFile(configFile)
}
// readEnvFiles reads files inside the "env.d" directory relative to configDir
// and then export the valid variables into environment variables if they do
// not exist
func readEnvFiles(configDir string) {
envd := filepath.Join(configDir, "env.d")
entries, err := os.ReadDir(envd)
if err != nil {
logger.Info(logSender, "", "unable to read env files from %q: %v", envd, err)
return
}
for _, entry := range entries {
info, err := entry.Info()
if err == nil && info.Mode().IsRegular() {
envFile := filepath.Join(envd, entry.Name())
err = gotenv.Load(envFile)
if err != nil {
logger.Error(logSender, "", "unable to load env vars from file %q, err: %v", envFile, err)
} else {
logger.Info(logSender, "", "set env vars from file %q", envFile)
}
}
}
}
func checkOverrideDefaultSettings() {
// for slices we need to set the defaults to nil if the key is set in the config file,
// otherwise the values are merged and not replaced as expected
rateLimiters := viper.Get("common.rate_limiters")
if val, ok := rateLimiters.([]any); ok {
if len(val) > 0 {
if rl, ok := val[0].(map[string]any); ok {
if _, ok := rl["protocols"]; ok {
globalConf.Common.RateLimitersConfig[0].Protocols = nil
}
}
}
}
httpdBindings := viper.Get("httpd.bindings")
if val, ok := httpdBindings.([]any); ok {
if len(val) > 0 {
if binding, ok := val[0].(map[string]any); ok {
if val, ok := binding["oidc"]; ok {
if oidc, ok := val.(map[string]any); ok {
if _, ok := oidc["scopes"]; ok {
globalConf.HTTPDConfig.Bindings[0].OIDC.Scopes = nil
}
}
}
}
}
}
if util.Contains(viper.AllKeys(), "mfa.totp") {
globalConf.MFAConfig.TOTP = nil
}
}
// LoadConfig loads the configuration
// configDir will be added to the configuration search paths.
// The search path contains by default the current directory and on linux it contains
@@ -701,7 +631,6 @@ func checkOverrideDefaultSettings() {
// configFile is an absolute or relative path (to the config dir) to the configuration file.
func LoadConfig(configDir, configFile string) error {
var err error
readEnvFiles(configDir)
viper.AddConfigPath(configDir)
setViperAdditionalConfigPaths()
viper.AddConfigPath(".")
@@ -717,8 +646,8 @@ func LoadConfig(configDir, configFile string) error {
logger.Warn(logSender, "", "error loading configuration file: %v", err)
logger.WarnToConsole("error loading configuration file: %v", err)
}
globalConf.MFAConfig.TOTP = []mfa.TOTPConfig{defaultTOTP}
}
checkOverrideDefaultSettings()
err = viper.Unmarshal(&globalConf)
if err != nil {
logger.Warn(logSender, "", "error parsing configuration file: %v", err)
@@ -780,6 +709,12 @@ func resetInvalidConfigs() {
logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn)
logger.WarnToConsole("Non-fatal configuration error: %v", warn)
}
if globalConf.ProviderConf.CredentialsPath == "" {
warn := "invalid credentials path, reset to \"credentials\""
globalConf.ProviderConf.CredentialsPath = "credentials"
logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn)
logger.WarnToConsole("Non-fatal configuration error: %v", warn)
}
if globalConf.Common.DefenderConfig.Enabled && globalConf.Common.DefenderConfig.Driver == common.DefenderDriverProvider {
if !globalConf.ProviderConf.IsDefenderSupported() {
warn := fmt.Sprintf("provider based defender is not supported with data provider %#v, "+
@@ -1264,12 +1199,6 @@ func getWebDAVDBindingFromEnv(idx int) {
isSet = true
}
enableHTTPS, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__ENABLE_HTTPS", idx))
if ok {
binding.EnableHTTPS = enableHTTPS
isSet = true
}
certificateFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CERTIFICATE_FILE", idx))
if ok {
binding.CertificateFile = certificateFile
@@ -1282,6 +1211,12 @@ func getWebDAVDBindingFromEnv(idx int) {
isSet = true
}
enableHTTPS, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__ENABLE_HTTPS", idx))
if ok {
binding.EnableHTTPS = enableHTTPS
isSet = true
}
tlsVer, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__MIN_TLS_VERSION", idx))
if ok {
binding.MinTLSVersion = int(tlsVer)
@@ -1300,19 +1235,13 @@ func getWebDAVDBindingFromEnv(idx int) {
isSet = true
}
prefix, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PREFIX", idx))
if ok {
binding.Prefix = prefix
isSet = true
}
if getWebDAVDBindingProxyConfigsFromEnv(idx, &binding) {
isSet = true
}
disableWWWAuth, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__DISABLE_WWW_AUTH_HEADER", idx))
prefix, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PREFIX", idx))
if ok {
binding.DisableWWWAuthHeader = disableWWWAuth
binding.Prefix = prefix
isSet = true
}
@@ -1521,12 +1450,6 @@ func getHTTPDOIDCFromEnv(idx int) (httpd.OIDC, bool) {
isSet = true
}
skipSignatureCheck, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__INSECURE_SKIP_SIGNATURE_CHECK", idx))
if ok {
result.InsecureSkipSignatureCheck = skipSignatureCheck
isSet = true
}
debug, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__DEBUG", idx))
if ok {
result.Debug = debug
@@ -1592,7 +1515,6 @@ func getHTTPDUIBrandingFromEnv(prefix string, branding httpd.UIBranding) (httpd.
branding.ExtraCSS = extraCSS
isSet = true
}
return branding, isSet
}
@@ -1760,12 +1682,6 @@ func getHTTPDBindingFromEnv(idx int) { //nolint:gocyclo
isSet = true
}
enableRESTAPI, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLE_REST_API", idx))
if ok {
binding.EnableRESTAPI = enableRESTAPI
isSet = true
}
enabledLoginMethods, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLED_LOGIN_METHODS", idx))
if ok {
binding.EnabledLoginMethods = int(enabledLoginMethods)
@@ -1930,7 +1846,6 @@ func setViperDefaults() {
viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections)
viper.SetDefault("common.max_per_host_connections", globalConf.Common.MaxPerHostConnections)
viper.SetDefault("common.whitelist_file", globalConf.Common.WhiteListFile)
viper.SetDefault("common.allow_self_connections", globalConf.Common.AllowSelfConnections)
viper.SetDefault("common.defender.enabled", globalConf.Common.DefenderConfig.Enabled)
viper.SetDefault("common.defender.driver", globalConf.Common.DefenderConfig.Driver)
viper.SetDefault("common.defender.ban_time", globalConf.Common.DefenderConfig.BanTime)
@@ -1995,9 +1910,6 @@ func setViperDefaults() {
viper.SetDefault("webdavd.cors.allowed_headers", globalConf.WebDAVD.Cors.AllowedHeaders)
viper.SetDefault("webdavd.cors.exposed_headers", globalConf.WebDAVD.Cors.ExposedHeaders)
viper.SetDefault("webdavd.cors.allow_credentials", globalConf.WebDAVD.Cors.AllowCredentials)
viper.SetDefault("webdavd.cors.options_passthrough", globalConf.WebDAVD.Cors.OptionsPassthrough)
viper.SetDefault("webdavd.cors.options_success_status", globalConf.WebDAVD.Cors.OptionsSuccessStatus)
viper.SetDefault("webdavd.cors.allow_private_network", globalConf.WebDAVD.Cors.AllowPrivateNetwork)
viper.SetDefault("webdavd.cors.max_age", globalConf.WebDAVD.Cors.MaxAge)
viper.SetDefault("webdavd.cache.users.expiration_time", globalConf.WebDAVD.Cache.Users.ExpirationTime)
viper.SetDefault("webdavd.cache.users.max_size", globalConf.WebDAVD.Cache.Users.MaxSize)
@@ -2010,8 +1922,6 @@ func setViperDefaults() {
viper.SetDefault("data_provider.username", globalConf.ProviderConf.Username)
viper.SetDefault("data_provider.password", globalConf.ProviderConf.Password)
viper.SetDefault("data_provider.sslmode", globalConf.ProviderConf.SSLMode)
viper.SetDefault("data_provider.disable_sni", globalConf.ProviderConf.DisableSNI)
viper.SetDefault("data_provider.target_session_attrs", globalConf.ProviderConf.TargetSessionAttrs)
viper.SetDefault("data_provider.root_cert", globalConf.ProviderConf.RootCert)
viper.SetDefault("data_provider.client_cert", globalConf.ProviderConf.ClientCert)
viper.SetDefault("data_provider.client_key", globalConf.ProviderConf.ClientKey)
@@ -2025,6 +1935,7 @@ func setViperDefaults() {
viper.SetDefault("data_provider.actions.hook", globalConf.ProviderConf.Actions.Hook)
viper.SetDefault("data_provider.external_auth_hook", globalConf.ProviderConf.ExternalAuthHook)
viper.SetDefault("data_provider.external_auth_scope", globalConf.ProviderConf.ExternalAuthScope)
viper.SetDefault("data_provider.credentials_path", globalConf.ProviderConf.CredentialsPath)
viper.SetDefault("data_provider.pre_login_hook", globalConf.ProviderConf.PreLoginHook)
viper.SetDefault("data_provider.post_login_hook", globalConf.ProviderConf.PostLoginHook)
viper.SetDefault("data_provider.post_login_scope", globalConf.ProviderConf.PostLoginScope)
@@ -2043,10 +1954,10 @@ func setViperDefaults() {
viper.SetDefault("data_provider.create_default_admin", globalConf.ProviderConf.CreateDefaultAdmin)
viper.SetDefault("data_provider.naming_rules", globalConf.ProviderConf.NamingRules)
viper.SetDefault("data_provider.is_shared", globalConf.ProviderConf.IsShared)
viper.SetDefault("data_provider.node.host", globalConf.ProviderConf.Node.Host)
viper.SetDefault("data_provider.node.port", globalConf.ProviderConf.Node.Port)
viper.SetDefault("data_provider.node.proto", globalConf.ProviderConf.Node.Proto)
viper.SetDefault("data_provider.backups_path", globalConf.ProviderConf.BackupsPath)
viper.SetDefault("data_provider.auto_backup.enabled", globalConf.ProviderConf.AutoBackup.Enabled)
viper.SetDefault("data_provider.auto_backup.hour", globalConf.ProviderConf.AutoBackup.Hour)
viper.SetDefault("data_provider.auto_backup.day_of_week", globalConf.ProviderConf.AutoBackup.DayOfWeek)
viper.SetDefault("httpd.templates_path", globalConf.HTTPDConfig.TemplatesPath)
viper.SetDefault("httpd.static_files_path", globalConf.HTTPDConfig.StaticFilesPath)
viper.SetDefault("httpd.openapi_path", globalConf.HTTPDConfig.OpenAPIPath)
@@ -2065,9 +1976,6 @@ func setViperDefaults() {
viper.SetDefault("httpd.cors.exposed_headers", globalConf.HTTPDConfig.Cors.ExposedHeaders)
viper.SetDefault("httpd.cors.allow_credentials", globalConf.HTTPDConfig.Cors.AllowCredentials)
viper.SetDefault("httpd.cors.max_age", globalConf.HTTPDConfig.Cors.MaxAge)
viper.SetDefault("httpd.cors.options_passthrough", globalConf.HTTPDConfig.Cors.OptionsPassthrough)
viper.SetDefault("httpd.cors.options_success_status", globalConf.HTTPDConfig.Cors.OptionsSuccessStatus)
viper.SetDefault("httpd.cors.allow_private_network", globalConf.HTTPDConfig.Cors.AllowPrivateNetwork)
viper.SetDefault("httpd.setup.installation_code", globalConf.HTTPDConfig.Setup.InstallationCode)
viper.SetDefault("httpd.setup.installation_code_hint", globalConf.HTTPDConfig.Setup.InstallationCodeHint)
viper.SetDefault("httpd.hide_support_link", globalConf.HTTPDConfig.HideSupportLink)

View File

@@ -26,28 +26,24 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/drakkan/sftpgo/v2/internal/command"
"github.com/drakkan/sftpgo/v2/internal/common"
"github.com/drakkan/sftpgo/v2/internal/config"
"github.com/drakkan/sftpgo/v2/internal/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/ftpd"
"github.com/drakkan/sftpgo/v2/internal/httpclient"
"github.com/drakkan/sftpgo/v2/internal/httpd"
"github.com/drakkan/sftpgo/v2/internal/mfa"
"github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/sftpd"
"github.com/drakkan/sftpgo/v2/internal/smtp"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/command"
"github.com/drakkan/sftpgo/v2/common"
"github.com/drakkan/sftpgo/v2/config"
"github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/ftpd"
"github.com/drakkan/sftpgo/v2/httpclient"
"github.com/drakkan/sftpgo/v2/httpd"
"github.com/drakkan/sftpgo/v2/mfa"
"github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/sftpd"
"github.com/drakkan/sftpgo/v2/smtp"
"github.com/drakkan/sftpgo/v2/util"
)
const (
tempConfigName = "temp"
)
var (
configDir = filepath.Join(".", "..", "..")
)
func reset() {
viper.Reset()
config.Init()
@@ -56,6 +52,7 @@ func reset() {
func TestLoadConfigTest(t *testing.T) {
reset()
configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
assert.NotEqual(t, httpd.Conf{}, config.GetHTTPConfig())
@@ -84,41 +81,15 @@ func TestLoadConfigFileNotFound(t *testing.T) {
viper.SetConfigName("configfile")
err := config.LoadConfig(os.TempDir(), "")
require.NoError(t, err)
assert.NoError(t, err)
mfaConf := config.GetMFAConfig()
require.Len(t, mfaConf.TOTP, 1)
require.Len(t, config.GetCommonConfig().RateLimitersConfig, 1)
require.Len(t, config.GetCommonConfig().RateLimitersConfig[0].Protocols, 4)
require.Len(t, config.GetHTTPDConfig().Bindings, 1)
require.Len(t, config.GetHTTPDConfig().Bindings[0].OIDC.Scopes, 3)
}
func TestReadEnvFiles(t *testing.T) {
reset()
envd := filepath.Join(configDir, "env.d")
err := os.Mkdir(envd, os.ModePerm)
assert.NoError(t, err)
err = os.WriteFile(filepath.Join(envd, "env1"), []byte("SFTPGO_SFTPD__MAX_AUTH_TRIES = 10"), 0666)
assert.NoError(t, err)
err = os.WriteFile(filepath.Join(envd, "env2"), []byte(`{"invalid env": "value"}`), 0666)
assert.NoError(t, err)
err = config.LoadConfig(configDir, "")
assert.NoError(t, err)
assert.Equal(t, 10, config.GetSFTPDConfig().MaxAuthTries)
_, ok := os.LookupEnv("SFTPGO_SFTPD__MAX_AUTH_TRIES")
assert.True(t, ok)
err = os.Unsetenv("SFTPGO_SFTPD__MAX_AUTH_TRIES")
assert.NoError(t, err)
os.RemoveAll(envd)
assert.Len(t, mfaConf.TOTP, 1)
}
func TestEmptyBanner(t *testing.T) {
reset()
configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -155,6 +126,7 @@ func TestEmptyBanner(t *testing.T) {
func TestEnabledSSHCommands(t *testing.T) {
reset()
configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -183,6 +155,7 @@ func TestEnabledSSHCommands(t *testing.T) {
func TestInvalidUploadMode(t *testing.T) {
reset()
configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -205,6 +178,7 @@ func TestInvalidUploadMode(t *testing.T) {
func TestInvalidExternalAuthScope(t *testing.T) {
reset()
configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -224,9 +198,33 @@ func TestInvalidExternalAuthScope(t *testing.T) {
assert.NoError(t, err)
}
func TestInvalidCredentialsPath(t *testing.T) {
reset()
configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
providerConf := config.GetProviderConf()
providerConf.CredentialsPath = ""
c := make(map[string]dataprovider.Config)
c["data_provider"] = providerConf
jsonConf, err := json.Marshal(c)
assert.NoError(t, err)
err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
assert.NoError(t, err)
err = config.LoadConfig(configDir, confName)
assert.NoError(t, err)
assert.Equal(t, "credentials", config.GetProviderConf().CredentialsPath)
err = os.Remove(configFilePath)
assert.NoError(t, err)
}
func TestInvalidProxyProtocol(t *testing.T) {
reset()
configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -249,6 +247,7 @@ func TestInvalidProxyProtocol(t *testing.T) {
func TestInvalidUsersBaseDir(t *testing.T) {
reset()
configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -271,6 +270,7 @@ func TestInvalidUsersBaseDir(t *testing.T) {
func TestInvalidInstallationHint(t *testing.T) {
reset()
configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -301,6 +301,7 @@ func TestDefenderProviderDriver(t *testing.T) {
}
reset()
configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
providerConf := config.GetProviderConf()
@@ -380,6 +381,7 @@ func TestSetGetConfig(t *testing.T) {
func TestServiceToStart(t *testing.T) {
reset()
configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
assert.True(t, config.HasServicesToStart())
@@ -413,6 +415,7 @@ func TestSSHCommandsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_SFTPD__ENABLED_SSH_COMMANDS")
})
configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
@@ -433,6 +436,7 @@ func TestSMTPFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_SMTP__PORT")
})
configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
smtpConfig := config.GetSMTPConfig()
@@ -454,6 +458,7 @@ func TestMFAFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_MFA__TOTP__1__ALGO")
})
configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
mfaConf := config.GetMFAConfig()
@@ -469,6 +474,7 @@ func TestMFAFromEnv(t *testing.T) {
func TestDisabledMFAConfig(t *testing.T) {
reset()
configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
@@ -493,81 +499,6 @@ func TestDisabledMFAConfig(t *testing.T) {
assert.NoError(t, err)
}
func TestOverrideSliceValues(t *testing.T) {
reset()
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
c := make(map[string]any)
c["common"] = common.Configuration{
RateLimitersConfig: []common.RateLimiterConfig{
{
Type: 1,
Protocols: []string{"HTTP"},
},
},
}
jsonConf, err := json.Marshal(c)
assert.NoError(t, err)
err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
assert.NoError(t, err)
err = config.LoadConfig(configDir, confName)
assert.NoError(t, err)
require.Len(t, config.GetCommonConfig().RateLimitersConfig, 1)
require.Equal(t, []string{"HTTP"}, config.GetCommonConfig().RateLimitersConfig[0].Protocols)
reset()
// empty ratelimiters, default value should be used
c["common"] = common.Configuration{}
jsonConf, err = json.Marshal(c)
assert.NoError(t, err)
err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
assert.NoError(t, err)
err = config.LoadConfig(configDir, confName)
assert.NoError(t, err)
require.Len(t, config.GetCommonConfig().RateLimitersConfig, 1)
rl := config.GetCommonConfig().RateLimitersConfig[0]
require.Equal(t, []string{"SSH", "FTP", "DAV", "HTTP"}, rl.Protocols)
require.Equal(t, int64(1000), rl.Period)
reset()
c = make(map[string]any)
c["httpd"] = httpd.Conf{
Bindings: []httpd.Binding{
{
OIDC: httpd.OIDC{
Scopes: []string{"scope1"},
},
},
},
}
jsonConf, err = json.Marshal(c)
assert.NoError(t, err)
err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
assert.NoError(t, err)
err = config.LoadConfig(configDir, confName)
assert.NoError(t, err)
require.Len(t, config.GetHTTPDConfig().Bindings, 1)
require.Equal(t, []string{"scope1"}, config.GetHTTPDConfig().Bindings[0].OIDC.Scopes)
reset()
c = make(map[string]any)
c["httpd"] = httpd.Conf{
Bindings: []httpd.Binding{},
}
jsonConf, err = json.Marshal(c)
assert.NoError(t, err)
err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
assert.NoError(t, err)
err = config.LoadConfig(configDir, confName)
assert.NoError(t, err)
require.Len(t, config.GetHTTPDConfig().Bindings, 1)
require.Equal(t, []string{"openid", "profile", "email"}, config.GetHTTPDConfig().Bindings[0].OIDC.Scopes)
}
func TestFTPDOverridesFromEnv(t *testing.T) {
reset()
@@ -583,6 +514,7 @@ func TestFTPDOverridesFromEnv(t *testing.T) {
}
t.Cleanup(cleanup)
configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
ftpdConf := config.GetFTPDConfig()
@@ -643,6 +575,7 @@ func TestHTTPDSubObjectsFromEnv(t *testing.T) {
}
t.Cleanup(cleanup)
configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
httpdConf := config.GetHTTPDConfig()
@@ -718,6 +651,7 @@ func TestPluginsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_PLUGINS__0__AUTH_OPTIONS__SCOPE")
})
configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
pluginsConf := config.GetPluginsConfig()
@@ -816,6 +750,7 @@ func TestRateLimitersFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__8__ALLOW_LIST")
})
configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
limiters := config.GetCommonConfig().RateLimitersConfig
@@ -868,6 +803,7 @@ func TestSFTPDBindingsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_SFTPD__BINDINGS__3__PORT")
})
configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
bindings := config.GetSFTPDConfig().Bindings
@@ -883,6 +819,7 @@ func TestSFTPDBindingsFromEnv(t *testing.T) {
func TestCommandsFromEnv(t *testing.T) {
reset()
configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -989,6 +926,7 @@ func TestFTPDBindingsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__CERTIFICATE_KEY_FILE")
})
configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
bindings := config.GetFTPDConfig().Bindings
@@ -1045,7 +983,6 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__PREFIX", "/dav2")
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_FILE", "webdav.crt")
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_KEY_FILE", "webdav.key")
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__DISABLE_WWW_AUTH_HEADER", "1")
t.Cleanup(func() {
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__ADDRESS")
@@ -1063,9 +1000,9 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__PREFIX")
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_FILE")
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_KEY_FILE")
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__DISABLE_WWW_AUTH_HEADER")
})
configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
bindings := config.GetWebDAVDConfig().Bindings
@@ -1077,7 +1014,6 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
require.Len(t, bindings[0].TLSCipherSuites, 0)
require.Empty(t, bindings[0].Prefix)
require.Equal(t, 0, bindings[0].ClientIPHeaderDepth)
require.False(t, bindings[0].DisableWWWAuthHeader)
require.Equal(t, 8000, bindings[1].Port)
require.Equal(t, "127.0.0.1", bindings[1].Address)
require.False(t, bindings[1].EnableHTTPS)
@@ -1089,7 +1025,6 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
require.Equal(t, "X-Forwarded-For", bindings[1].ClientIPProxyHeader)
require.Equal(t, 2, bindings[1].ClientIPHeaderDepth)
require.Empty(t, bindings[1].Prefix)
require.False(t, bindings[1].DisableWWWAuthHeader)
require.Equal(t, 9000, bindings[2].Port)
require.Equal(t, "127.0.1.1", bindings[2].Address)
require.True(t, bindings[2].EnableHTTPS)
@@ -1100,7 +1035,6 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
require.Equal(t, "webdav.crt", bindings[2].CertificateFile)
require.Equal(t, "webdav.key", bindings[2].CertificateKeyFile)
require.Equal(t, 0, bindings[2].ClientIPHeaderDepth)
require.True(t, bindings[2].DisableWWWAuthHeader)
}
func TestHTTPDBindingsFromEnv(t *testing.T) {
@@ -1121,7 +1055,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__PORT", "9000")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_ADMIN", "0")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_CLIENT", "0")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_REST_API", "0")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLED_LOGIN_METHODS", "3")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__RENDER_OPENAPI", "0")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_HTTPS", "1 ")
@@ -1145,7 +1078,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__SCOPES", "openid")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__IMPLICIT_ROLES", "1")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CUSTOM_FIELDS", "field1,field2")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__INSECURE_SKIP_SIGNATURE_CHECK", "1")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__DEBUG", "1")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ENABLED", "true")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ALLOWED_HOSTS", "*.example.com,*.example.net")
@@ -1192,7 +1124,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__MIN_TLS_VERSION")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_ADMIN")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_CLIENT")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_REST_API")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLED_LOGIN_METHODS")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__RENDER_OPENAPI")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE")
@@ -1214,7 +1145,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__SCOPES")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__IMPLICIT_ROLES")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CUSTOM_FIELDS")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__INSECURE_SKIP_SIGNATURE_CHECK")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__DEBUG")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ENABLED")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ALLOWED_HOSTS")
@@ -1245,6 +1175,7 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CERTIFICATE_KEY_FILE")
})
configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
bindings := config.GetHTTPDConfig().Bindings
@@ -1255,7 +1186,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
require.Equal(t, 12, bindings[0].MinTLSVersion)
require.True(t, bindings[0].EnableWebAdmin)
require.True(t, bindings[0].EnableWebClient)
require.True(t, bindings[0].EnableRESTAPI)
require.Equal(t, 0, bindings[0].EnabledLoginMethods)
require.True(t, bindings[0].RenderOpenAPI)
require.Len(t, bindings[0].TLSCipherSuites, 1)
@@ -1265,7 +1195,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
require.False(t, bindings[0].Security.Enabled)
require.Equal(t, 0, bindings[0].ClientIPHeaderDepth)
require.Len(t, bindings[0].OIDC.Scopes, 3)
require.False(t, bindings[0].OIDC.InsecureSkipSignatureCheck)
require.False(t, bindings[0].OIDC.Debug)
require.Equal(t, 8000, bindings[1].Port)
require.Equal(t, "127.0.0.1", bindings[1].Address)
@@ -1273,14 +1202,12 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
require.Equal(t, 12, bindings[0].MinTLSVersion)
require.True(t, bindings[1].EnableWebAdmin)
require.True(t, bindings[1].EnableWebClient)
require.True(t, bindings[1].EnableRESTAPI)
require.Equal(t, 0, bindings[1].EnabledLoginMethods)
require.True(t, bindings[1].RenderOpenAPI)
require.Nil(t, bindings[1].TLSCipherSuites)
require.Equal(t, 1, bindings[1].HideLoginURL)
require.Empty(t, bindings[1].OIDC.ClientID)
require.Len(t, bindings[1].OIDC.Scopes, 3)
require.False(t, bindings[1].OIDC.InsecureSkipSignatureCheck)
require.False(t, bindings[1].OIDC.Debug)
require.False(t, bindings[1].Security.Enabled)
require.Equal(t, "Web Admin", bindings[1].Branding.WebAdmin.Name)
@@ -1292,7 +1219,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
require.Equal(t, 13, bindings[2].MinTLSVersion)
require.False(t, bindings[2].EnableWebAdmin)
require.False(t, bindings[2].EnableWebClient)
require.False(t, bindings[2].EnableRESTAPI)
require.Equal(t, 3, bindings[2].EnabledLoginMethods)
require.False(t, bindings[2].RenderOpenAPI)
require.Equal(t, 1, bindings[2].ClientAuthType)
@@ -1320,7 +1246,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
require.Len(t, bindings[2].OIDC.CustomFields, 2)
require.Equal(t, "field1", bindings[2].OIDC.CustomFields[0])
require.Equal(t, "field2", bindings[2].OIDC.CustomFields[1])
require.True(t, bindings[2].OIDC.InsecureSkipSignatureCheck)
require.True(t, bindings[2].OIDC.Debug)
require.True(t, bindings[2].Security.Enabled)
require.Len(t, bindings[2].Security.AllowedHosts, 2)
@@ -1358,6 +1283,7 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
func TestHTTPClientCertificatesFromEnv(t *testing.T) {
reset()
configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -1418,6 +1344,7 @@ func TestHTTPClientCertificatesFromEnv(t *testing.T) {
func TestHTTPClientHeadersFromEnv(t *testing.T) {
reset()
configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")

View File

@@ -26,11 +26,11 @@ import (
"github.com/sftpgo/sdk/plugin/notifier"
"github.com/drakkan/sftpgo/v2/internal/command"
"github.com/drakkan/sftpgo/v2/internal/httpclient"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/command"
"github.com/drakkan/sftpgo/v2/httpclient"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/util"
)
const (
@@ -42,19 +42,12 @@ const (
)
const (
actionObjectUser = "user"
actionObjectFolder = "folder"
actionObjectGroup = "group"
actionObjectAdmin = "admin"
actionObjectAPIKey = "api_key"
actionObjectShare = "share"
actionObjectEventAction = "event_action"
actionObjectEventRule = "event_rule"
)
var (
actionsConcurrencyGuard = make(chan struct{}, 100)
reservedUsers = []string{ActionExecutorSelf, ActionExecutorSystem}
actionObjectUser = "user"
actionObjectFolder = "folder"
actionObjectGroup = "group"
actionObjectAdmin = "admin"
actionObjectAPIKey = "api_key"
actionObjectShare = "share"
)
func executeAction(operation, executor, ip, objectType, objectName string, object plugin.Renderer) {
@@ -68,9 +61,6 @@ func executeAction(operation, executor, ip, objectType, objectName string, objec
Timestamp: time.Now().UnixNano(),
}, object)
}
if fnHandleRuleForProviderEvent != nil {
fnHandleRuleForProviderEvent(operation, executor, ip, objectType, objectName, object)
}
if config.Actions.Hook == "" {
return
}
@@ -80,11 +70,6 @@ func executeAction(operation, executor, ip, objectType, objectName string, objec
}
go func() {
actionsConcurrencyGuard <- struct{}{}
defer func() {
<-actionsConcurrencyGuard
}()
dataAsJSON, err := object.RenderAsJSON(operation != operationDelete)
if err != nil {
providerLog(logger.LevelError, "unable to serialize user as JSON for operation %#v: %v", operation, err)
@@ -104,7 +89,7 @@ func executeAction(operation, executor, ip, objectType, objectName string, objec
q.Add("ip", ip)
q.Add("object_type", objectType)
q.Add("object_name", objectName)
q.Add("timestamp", fmt.Sprintf("%d", time.Now().UnixNano()))
q.Add("timestamp", fmt.Sprintf("%v", time.Now().UnixNano()))
url.RawQuery = q.Encode()
startTime := time.Now()
resp, err := httpclient.RetryablePost(url.String(), "application/json", bytes.NewBuffer(dataAsJSON))
@@ -128,19 +113,19 @@ func executeNotificationCommand(operation, executor, ip, objectType, objectName
return err
}
timeout, env, args := command.GetConfig(config.Actions.Hook, command.HookProviderActions)
timeout, env := command.GetConfig(config.Actions.Hook)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
cmd := exec.CommandContext(ctx, config.Actions.Hook, args...)
cmd := exec.CommandContext(ctx, config.Actions.Hook)
cmd.Env = append(env,
fmt.Sprintf("SFTPGO_PROVIDER_ACTION=%vs", operation),
fmt.Sprintf("SFTPGO_PROVIDER_OBJECT_TYPE=%s", objectType),
fmt.Sprintf("SFTPGO_PROVIDER_OBJECT_NAME=%s", objectName),
fmt.Sprintf("SFTPGO_PROVIDER_USERNAME=%s", executor),
fmt.Sprintf("SFTPGO_PROVIDER_IP=%s", ip),
fmt.Sprintf("SFTPGO_PROVIDER_TIMESTAMP=%d", util.GetTimeAsMsSinceEpoch(time.Now())),
fmt.Sprintf("SFTPGO_PROVIDER_OBJECT=%s", string(objectAsJSON)))
fmt.Sprintf("SFTPGO_PROVIDER_ACTION=%v", operation),
fmt.Sprintf("SFTPGO_PROVIDER_OBJECT_TYPE=%v", objectType),
fmt.Sprintf("SFTPGO_PROVIDER_OBJECT_NAME=%v", objectName),
fmt.Sprintf("SFTPGO_PROVIDER_USERNAME=%v", executor),
fmt.Sprintf("SFTPGO_PROVIDER_IP=%v", ip),
fmt.Sprintf("SFTPGO_PROVIDER_TIMESTAMP=%v", util.GetTimeAsMsSinceEpoch(time.Now())),
fmt.Sprintf("SFTPGO_PROVIDER_OBJECT=%v", string(objectAsJSON)))
startTime := time.Now()
err := cmd.Run()

View File

@@ -22,18 +22,16 @@ import (
"fmt"
"net"
"os"
"sort"
"strings"
"github.com/alexedwards/argon2id"
"github.com/sftpgo/sdk"
passwordvalidator "github.com/wagslane/go-password-validator"
"golang.org/x/crypto/bcrypt"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/mfa"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/kms"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/mfa"
"github.com/drakkan/sftpgo/v2/util"
)
// Available permissions for SFTPGo admins
@@ -56,16 +54,6 @@ const (
PermAdminRetentionChecks = "retention_checks"
PermAdminMetadataChecks = "metadata_checks"
PermAdminViewEvents = "view_events"
PermAdminManageEventRules = "manage_event_rules"
)
const (
// GroupAddToUsersAsMembership defines that the admin's group will be added as membership group for new users
GroupAddToUsersAsMembership = iota
// GroupAddToUsersAsPrimary defines that the admin's group will be added as primary group for new users
GroupAddToUsersAsPrimary
// GroupAddToUsersAsSecondary defines that the admin's group will be added as secondary group for new users
GroupAddToUsersAsSecondary
)
var (
@@ -107,82 +95,6 @@ func (c *AdminTOTPConfig) validate(username string) error {
return nil
}
// AdminPreferences defines the admin preferences
type AdminPreferences struct {
// Allow to hide some sections from the user page.
// These are not security settings and are not enforced server side
// in any way. They are only intended to simplify the user page in
// the WebAdmin UI.
//
// 1 means hide groups section
// 2 means hide filesystem section, "users_base_dir" must be set in the config file otherwise this setting is ignored
// 4 means hide virtual folders section
// 8 means hide profile section
// 16 means hide ACLs section
// 32 means hide disk and bandwidth quota limits section
// 64 means hide advanced settings section
//
// The settings can be combined
HideUserPageSections int `json:"hide_user_page_sections,omitempty"`
}
// HideGroups returns true if the groups section should be hidden
func (p *AdminPreferences) HideGroups() bool {
return p.HideUserPageSections&1 != 0
}
// HideFilesystem returns true if the filesystem section should be hidden
func (p *AdminPreferences) HideFilesystem() bool {
return config.UsersBaseDir != "" && p.HideUserPageSections&2 != 0
}
// HideVirtualFolders returns true if the virtual folder section should be hidden
func (p *AdminPreferences) HideVirtualFolders() bool {
return p.HideUserPageSections&4 != 0
}
// HideProfile returns true if the profile section should be hidden
func (p *AdminPreferences) HideProfile() bool {
return p.HideUserPageSections&8 != 0
}
// HideACLs returns true if the ACLs section should be hidden
func (p *AdminPreferences) HideACLs() bool {
return p.HideUserPageSections&16 != 0
}
// HideDiskQuotaAndBandwidthLimits returns true if the disk quota and bandwidth limits
// section should be hidden
func (p *AdminPreferences) HideDiskQuotaAndBandwidthLimits() bool {
return p.HideUserPageSections&32 != 0
}
// HideAdvancedSettings returns true if the advanced settings section should be hidden
func (p *AdminPreferences) HideAdvancedSettings() bool {
return p.HideUserPageSections&64 != 0
}
// VisibleUserPageSections returns the number of visible sections
// in the user page
func (p *AdminPreferences) VisibleUserPageSections() int {
var result int
if !p.HideProfile() {
result++
}
if !p.HideACLs() {
result++
}
if !p.HideDiskQuotaAndBandwidthLimits() {
result++
}
if !p.HideAdvancedSettings() {
result++
}
return result
}
// AdminFilters defines additional restrictions for SFTPGo admins
// TODO: rename to AdminOptions in v3
type AdminFilters struct {
@@ -197,38 +109,7 @@ type AdminFilters struct {
// Recovery codes to use if the user loses access to their second factor auth device.
// Each code can only be used once, you should use these codes to login and disable or
// reset 2FA for your account
RecoveryCodes []RecoveryCode `json:"recovery_codes,omitempty"`
Preferences AdminPreferences `json:"preferences"`
}
// AdminGroupMappingOptions defines the options for admin/group mapping
type AdminGroupMappingOptions struct {
AddToUsersAs int `json:"add_to_users_as,omitempty"`
}
func (o *AdminGroupMappingOptions) validate() error {
if o.AddToUsersAs < GroupAddToUsersAsMembership || o.AddToUsersAs > GroupAddToUsersAsSecondary {
return util.NewValidationError(fmt.Sprintf("Invalid mode to add groups to new users: %d", o.AddToUsersAs))
}
return nil
}
// GetUserGroupType returns the type for the matching user group
func (o *AdminGroupMappingOptions) GetUserGroupType() int {
switch o.AddToUsersAs {
case GroupAddToUsersAsPrimary:
return sdk.GroupTypePrimary
case GroupAddToUsersAsSecondary:
return sdk.GroupTypeSecondary
default:
return sdk.GroupTypeMembership
}
}
// AdminGroupMapping defines the mapping between an SFTPGo admin and a group
type AdminGroupMapping struct {
Name string `json:"name"`
Options AdminGroupMappingOptions `json:"options"`
RecoveryCodes []RecoveryCode `json:"recovery_codes,omitempty"`
}
// Admin defines a SFTPGo admin
@@ -245,8 +126,6 @@ type Admin struct {
Filters AdminFilters `json:"filters,omitempty"`
Description string `json:"description,omitempty"`
AdditionalInfo string `json:"additional_info,omitempty"`
// Groups membership
Groups []AdminGroupMapping `json:"groups,omitempty"`
// Creation time as unix timestamp in milliseconds. It will be 0 for admins created before v2.2.0
CreatedAt int64 `json:"created_at"`
// last update time as unix timestamp in milliseconds
@@ -326,33 +205,11 @@ func (a *Admin) validatePermissions() error {
return nil
}
func (a *Admin) validateGroups() error {
hasPrimary := false
for _, g := range a.Groups {
if g.Name == "" {
return util.NewValidationError("group name is mandatory")
}
if err := g.Options.validate(); err != nil {
return err
}
if g.Options.AddToUsersAs == GroupAddToUsersAsPrimary {
if hasPrimary {
return util.NewValidationError("only one primary group is allowed")
}
hasPrimary = true
}
}
return nil
}
func (a *Admin) validate() error {
a.SetEmptySecretsIfNil()
if a.Username == "" {
return util.NewValidationError("username is mandatory")
}
if err := checkReservedUsernames(a.Username); err != nil {
return err
}
if a.Password == "" {
return util.NewValidationError("please set a password")
}
@@ -385,20 +242,7 @@ func (a *Admin) validate() error {
}
}
return a.validateGroups()
}
// GetGroupsAsString returns the user's groups as a string
func (a *Admin) GetGroupsAsString() string {
if len(a.Groups) == 0 {
return ""
}
var groups []string
for _, g := range a.Groups {
groups = append(groups, g.Name)
}
sort.Strings(groups)
return strings.Join(groups, ",")
return nil
}
// CheckPassword verifies the admin password
@@ -577,18 +421,6 @@ func (a *Admin) getACopy() Admin {
Used: code.Used,
})
}
filters.Preferences = AdminPreferences{
HideUserPageSections: a.Filters.Preferences.HideUserPageSections,
}
groups := make([]AdminGroupMapping, 0, len(a.Groups))
for _, g := range a.Groups {
groups = append(groups, AdminGroupMapping{
Name: g.Name,
Options: AdminGroupMappingOptions{
AddToUsersAs: g.Options.AddToUsersAs,
},
})
}
return Admin{
ID: a.ID,
@@ -597,7 +429,6 @@ func (a *Admin) getACopy() Admin {
Password: a.Password,
Email: a.Email,
Permissions: permissions,
Groups: groups,
Filters: filters,
AdditionalInfo: a.AdditionalInfo,
Description: a.Description,

View File

@@ -23,8 +23,8 @@ import (
"github.com/alexedwards/argon2id"
"golang.org/x/crypto/bcrypt"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
// APIKeyScope defines the supported API key scopes

File diff suppressed because it is too large Load Diff

View File

@@ -20,13 +20,13 @@ package dataprovider
import (
"errors"
"github.com/drakkan/sftpgo/v2/internal/version"
"github.com/drakkan/sftpgo/v2/version"
)
func init() {
version.AddFeature("-bolt")
}
func initializeBoltProvider(_ string) error {
func initializeBoltProvider(basePath string) error {
return errors.New("bolt disabled at build time")
}

View File

@@ -20,8 +20,8 @@ import (
"golang.org/x/net/webdav"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
var (

View File

@@ -22,10 +22,10 @@ import (
"github.com/sftpgo/sdk"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/vfs"
)
// GroupUserSettings defines the settings to apply to users
@@ -187,8 +187,6 @@ func (g *Group) validateUserSettings() error {
func (g *Group) getACopy() Group {
users := make([]string, len(g.Users))
copy(users, g.Users)
admins := make([]string, len(g.Admins))
copy(admins, g.Admins)
virtualFolders := make([]vfs.VirtualFolder, 0, len(g.VirtualFolders))
for idx := range g.VirtualFolders {
vfolder := g.VirtualFolders[idx].GetACopy()
@@ -209,7 +207,6 @@ func (g *Group) getACopy() Group {
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
Users: users,
Admins: admins,
},
UserSettings: GroupUserSettings{
BaseGroupUserSettings: sdk.BaseGroupUserSettings{
@@ -231,14 +228,7 @@ func (g *Group) getACopy() Group {
}
}
// GetMembersAsString returns a string representation for the group members
func (g *Group) GetMembersAsString() string {
var sb strings.Builder
if len(g.Users) > 0 {
sb.WriteString(fmt.Sprintf("Users: %d. ", len(g.Users)))
}
if len(g.Admins) > 0 {
sb.WriteString(fmt.Sprintf("Admins: %d. ", len(g.Admins)))
}
return sb.String()
// GetUsersAsString returns the list of users as comma separated string
func (g *Group) GetUsersAsString() string {
return strings.Join(g.Users, ",")
}

View File

@@ -24,9 +24,9 @@ import (
"sync"
"time"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/vfs"
)
var (
@@ -62,14 +62,6 @@ type memoryProviderHandle struct {
shares map[string]Share
// slice with ordered shares shareID
sharesIDs []string
// map for event actions, name is the key
actions map[string]BaseEventAction
// slice with ordered actions
actionsNames []string
// map for event actions, name is the key
rules map[string]EventRule
// slice with ordered rules
rulesNames []string
}
// MemoryProvider defines the auth provider for a memory store
@@ -100,10 +92,6 @@ func initializeMemoryProvider(basePath string) {
apiKeysIDs: []string{},
shares: make(map[string]Share),
sharesIDs: []string{},
actions: make(map[string]BaseEventAction),
actionsNames: []string{},
rules: make(map[string]EventRule),
rulesNames: []string{},
configFile: configFile,
},
}
@@ -146,6 +134,10 @@ func (p *MemoryProvider) validateUserAndTLSCert(username, protocol string, tlsCe
}
func (p *MemoryProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) {
var user User
if password == "" {
return user, errors.New("credentials cannot be null or empty")
}
user, err := p.userExists(username)
if err != nil {
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
@@ -326,22 +318,14 @@ func (p *MemoryProvider) addUser(user *User) error {
user.UsedUploadDataTransfer = 0
user.UsedDownloadDataTransfer = 0
user.LastLogin = 0
user.FirstUpload = 0
user.FirstDownload = 0
user.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
var mappedGroups []string
user.VirtualFolders = p.joinUserVirtualFoldersFields(user)
for idx := range user.Groups {
if err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name); err != nil {
// try to remove group mapping
for _, g := range mappedGroups {
p.removeUserFromGroupMapping(user.Username, g)
}
if err = p.addUserFromGroupMapping(user.Username, user.Groups[idx].Name); err != nil {
return err
}
mappedGroups = append(mappedGroups, user.Groups[idx].Name)
}
user.VirtualFolders = p.joinUserVirtualFoldersFields(user)
p.dbHandle.users[user.Username] = user.getACopy()
p.dbHandle.usernames = append(p.dbHandle.usernames, user.Username)
sort.Strings(p.dbHandle.usernames)
@@ -366,33 +350,26 @@ func (p *MemoryProvider) updateUser(user *User) error {
if err != nil {
return err
}
for idx := range u.Groups {
p.removeUserFromGroupMapping(u.Username, u.Groups[idx].Name)
}
for idx := range user.Groups {
if err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name); err != nil {
// try to add old mapping
for _, g := range u.Groups {
if errRollback := p.addUserToGroupMapping(user.Username, g.Name); errRollback != nil {
providerLog(logger.LevelError, "unable to rollback old group mapping %q for user %q, error: %v",
g.Name, user.Username, errRollback)
}
}
return err
}
}
for _, oldFolder := range u.VirtualFolders {
p.removeRelationFromFolderMapping(oldFolder.Name, u.Username, "")
}
for idx := range u.Groups {
if err = p.removeUserFromGroupMapping(u.Username, u.Groups[idx].Name); err != nil {
return err
}
}
user.VirtualFolders = p.joinUserVirtualFoldersFields(user)
for idx := range user.Groups {
if err = p.addUserFromGroupMapping(user.Username, user.Groups[idx].Name); err != nil {
return err
}
}
user.LastQuotaUpdate = u.LastQuotaUpdate
user.UsedQuotaSize = u.UsedQuotaSize
user.UsedQuotaFiles = u.UsedQuotaFiles
user.UsedUploadDataTransfer = u.UsedUploadDataTransfer
user.UsedDownloadDataTransfer = u.UsedDownloadDataTransfer
user.LastLogin = u.LastLogin
user.FirstDownload = u.FirstDownload
user.FirstUpload = u.FirstUpload
user.CreatedAt = u.CreatedAt
user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
user.ID = u.ID
@@ -402,7 +379,7 @@ func (p *MemoryProvider) updateUser(user *User) error {
return nil
}
func (p *MemoryProvider) deleteUser(user User, softDelete bool) error {
func (p *MemoryProvider) deleteUser(user User) error {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
@@ -416,7 +393,9 @@ func (p *MemoryProvider) deleteUser(user User, softDelete bool) error {
p.removeRelationFromFolderMapping(oldFolder.Name, u.Username, "")
}
for idx := range u.Groups {
p.removeUserFromGroupMapping(u.Username, u.Groups[idx].Name)
if err = p.removeUserFromGroupMapping(u.Username, u.Groups[idx].Name); err != nil {
return err
}
}
delete(p.dbHandle.users, user.Username)
// this could be more efficient
@@ -611,28 +590,14 @@ func (p *MemoryProvider) userExistsInternal(username string) (User, error) {
if val, ok := p.dbHandle.users[username]; ok {
return val.getACopy(), nil
}
return User{}, util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username))
return User{}, util.NewRecordNotFoundError(fmt.Sprintf("username %#v does not exist", username))
}
func (p *MemoryProvider) groupExistsInternal(name string) (Group, error) {
if val, ok := p.dbHandle.groups[name]; ok {
return val.getACopy(), nil
}
return Group{}, util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", name))
}
func (p *MemoryProvider) actionExistsInternal(name string) (BaseEventAction, error) {
if val, ok := p.dbHandle.actions[name]; ok {
return val.getACopy(), nil
}
return BaseEventAction{}, util.NewRecordNotFoundError(fmt.Sprintf("event action %q does not exist", name))
}
func (p *MemoryProvider) ruleExistsInternal(name string) (EventRule, error) {
if val, ok := p.dbHandle.rules[name]; ok {
return val.getACopy(), nil
}
return EventRule{}, util.NewRecordNotFoundError(fmt.Sprintf("event rule %q does not exist", name))
return Group{}, util.NewRecordNotFoundError(fmt.Sprintf("group %#v does not exist", name))
}
func (p *MemoryProvider) addAdmin(admin *Admin) error {
@@ -653,17 +618,6 @@ func (p *MemoryProvider) addAdmin(admin *Admin) error {
admin.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
admin.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
admin.LastLogin = 0
var mappedAdmins []string
for idx := range admin.Groups {
if err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name); err != nil {
// try to remove group mapping
for _, g := range mappedAdmins {
p.removeAdminFromGroupMapping(admin.Username, g)
}
return err
}
mappedAdmins = append(mappedAdmins, admin.Groups[idx].Name)
}
p.dbHandle.admins[admin.Username] = admin.getACopy()
p.dbHandle.adminsUsernames = append(p.dbHandle.adminsUsernames, admin.Username)
sort.Strings(p.dbHandle.adminsUsernames)
@@ -684,21 +638,6 @@ func (p *MemoryProvider) updateAdmin(admin *Admin) error {
if err != nil {
return err
}
for idx := range a.Groups {
p.removeAdminFromGroupMapping(a.Username, a.Groups[idx].Name)
}
for idx := range admin.Groups {
if err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name); err != nil {
// try to add old mapping
for _, oldGroup := range a.Groups {
if errRollback := p.addAdminToGroupMapping(a.Username, oldGroup.Name); errRollback != nil {
providerLog(logger.LevelError, "unable to rollback old group mapping %q for admin %q, error: %v",
oldGroup.Name, a.Username, errRollback)
}
}
return err
}
}
admin.ID = a.ID
admin.CreatedAt = a.CreatedAt
admin.LastLogin = a.LastLogin
@@ -713,13 +652,10 @@ func (p *MemoryProvider) deleteAdmin(admin Admin) error {
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
a, err := p.adminExistsInternal(admin.Username)
_, err := p.adminExistsInternal(admin.Username)
if err != nil {
return err
}
for idx := range a.Groups {
p.removeAdminFromGroupMapping(a.Username, a.Groups[idx].Name)
}
delete(p.dbHandle.admins, admin.Username)
// this could be more efficient
@@ -944,8 +880,6 @@ func (p *MemoryProvider) addGroup(group *Group) error {
group.ID = p.getNextGroupID()
group.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
group.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
group.Users = nil
group.Admins = nil
group.VirtualFolders = p.joinGroupVirtualFoldersFields(group)
p.dbHandle.groups[group.Name] = group.getACopy()
p.dbHandle.groupnames = append(p.dbHandle.groupnames, group.Name)
@@ -973,8 +907,6 @@ func (p *MemoryProvider) updateGroup(group *Group) error {
group.CreatedAt = g.CreatedAt
group.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
group.ID = g.ID
group.Users = g.Users
group.Admins = g.Admins
p.dbHandle.groups[group.Name] = group.getACopy()
return nil
}
@@ -995,9 +927,6 @@ func (p *MemoryProvider) deleteGroup(group Group) error {
for _, oldFolder := range g.VirtualFolders {
p.removeRelationFromFolderMapping(oldFolder.Name, "", g.Name)
}
for _, a := range g.Admins {
p.removeGroupFromAdminMapping(g.Name, a)
}
delete(p.dbHandle.groups, group.Name)
// this could be more efficient
p.dbHandle.groupnames = make([]string, 0, len(p.dbHandle.groups))
@@ -1068,95 +997,7 @@ func (p *MemoryProvider) addVirtualFoldersToGroup(group *Group) {
}
}
func (p *MemoryProvider) addActionsToRule(rule *EventRule) {
var actions []EventAction
for idx := range rule.Actions {
action := &rule.Actions[idx]
baseAction, err := p.actionExistsInternal(action.Name)
if err != nil {
continue
}
baseAction.Options.SetEmptySecretsIfNil()
action.BaseEventAction = baseAction
actions = append(actions, *action)
}
rule.Actions = actions
}
func (p *MemoryProvider) addRuleToActionMapping(ruleName, actionName string) error {
a, err := p.actionExistsInternal(actionName)
if err != nil {
return util.NewGenericError(fmt.Sprintf("action %q does not exist", actionName))
}
if !util.Contains(a.Rules, ruleName) {
a.Rules = append(a.Rules, ruleName)
p.dbHandle.actions[actionName] = a
}
return nil
}
func (p *MemoryProvider) removeRuleFromActionMapping(ruleName, actionName string) {
a, err := p.actionExistsInternal(actionName)
if err != nil {
providerLog(logger.LevelWarn, "action %q does not exist, cannot remove from mapping", actionName)
return
}
if util.Contains(a.Rules, ruleName) {
var rules []string
for _, r := range a.Rules {
if r != ruleName {
rules = append(rules, r)
}
}
a.Rules = rules
p.dbHandle.actions[actionName] = a
}
}
func (p *MemoryProvider) addAdminToGroupMapping(username, groupname string) error {
g, err := p.groupExistsInternal(groupname)
if err != nil {
return err
}
if !util.Contains(g.Admins, username) {
g.Admins = append(g.Admins, username)
p.dbHandle.groups[groupname] = g
}
return nil
}
func (p *MemoryProvider) removeAdminFromGroupMapping(username, groupname string) {
g, err := p.groupExistsInternal(groupname)
if err != nil {
return
}
var admins []string
for _, a := range g.Admins {
if a != username {
admins = append(admins, a)
}
}
g.Admins = admins
p.dbHandle.groups[groupname] = g
}
func (p *MemoryProvider) removeGroupFromAdminMapping(groupname, username string) {
admin, err := p.adminExistsInternal(username)
if err != nil {
// the admin does not exist so there is no associated group
return
}
var newGroups []AdminGroupMapping
for _, g := range admin.Groups {
if g.Name != groupname {
newGroups = append(newGroups, g)
}
}
admin.Groups = newGroups
p.dbHandle.admins[admin.Username] = admin
}
func (p *MemoryProvider) addUserToGroupMapping(username, groupname string) error {
func (p *MemoryProvider) addUserFromGroupMapping(username, groupname string) error {
g, err := p.groupExistsInternal(groupname)
if err != nil {
return err
@@ -1168,19 +1009,22 @@ func (p *MemoryProvider) addUserToGroupMapping(username, groupname string) error
return nil
}
func (p *MemoryProvider) removeUserFromGroupMapping(username, groupname string) {
func (p *MemoryProvider) removeUserFromGroupMapping(username, groupname string) error {
g, err := p.groupExistsInternal(groupname)
if err != nil {
return
return err
}
var users []string
for _, u := range g.Users {
if u != username {
users = append(users, u)
if util.Contains(g.Users, username) {
var users []string
for _, u := range g.Users {
if u != username {
users = append(users, u)
}
}
g.Users = users
p.dbHandle.groups[groupname] = g
}
g.Users = users
p.dbHandle.groups[groupname] = g
return nil
}
func (p *MemoryProvider) joinUserVirtualFoldersFields(user *User) []vfs.VirtualFolder {
@@ -1364,7 +1208,6 @@ func (p *MemoryProvider) addFolder(folder *vfs.BaseVirtualFolder) error {
}
folder.ID = p.getNextFolderID()
folder.Users = nil
folder.Groups = nil
p.dbHandle.vfolders[folder.Name] = folder.GetACopy()
p.dbHandle.vfoldersNames = append(p.dbHandle.vfoldersNames, folder.Name)
sort.Strings(p.dbHandle.vfoldersNames)
@@ -1391,7 +1234,6 @@ func (p *MemoryProvider) updateFolder(folder *vfs.BaseVirtualFolder) error {
folder.UsedQuotaFiles = f.UsedQuotaFiles
folder.UsedQuotaSize = f.UsedQuotaSize
folder.Users = f.Users
folder.Groups = f.Groups
p.dbHandle.vfolders[folder.Name] = folder.GetACopy()
// now update the related users
for _, username := range folder.Users {
@@ -1412,14 +1254,14 @@ func (p *MemoryProvider) updateFolder(folder *vfs.BaseVirtualFolder) error {
return nil
}
func (p *MemoryProvider) deleteFolder(f vfs.BaseVirtualFolder) error {
func (p *MemoryProvider) deleteFolder(folder vfs.BaseVirtualFolder) error {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
folder, err := p.folderExistsInternal(f.Name)
_, err := p.folderExistsInternal(folder.Name)
if err != nil {
return err
}
@@ -1940,426 +1782,6 @@ func (p *MemoryProvider) cleanupSharedSessions(sessionType SessionType, before i
return ErrNotImplemented
}
func (p *MemoryProvider) getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return nil, errMemoryProviderClosed
}
if limit <= 0 {
return nil, nil
}
actions := make([]BaseEventAction, 0, limit)
itNum := 0
if order == OrderASC {
for _, name := range p.dbHandle.actionsNames {
itNum++
if itNum <= offset {
continue
}
a := p.dbHandle.actions[name]
action := a.getACopy()
action.PrepareForRendering()
actions = append(actions, action)
if len(actions) >= limit {
break
}
}
} else {
for i := len(p.dbHandle.actionsNames) - 1; i >= 0; i-- {
itNum++
if itNum <= offset {
continue
}
name := p.dbHandle.actionsNames[i]
a := p.dbHandle.actions[name]
action := a.getACopy()
action.PrepareForRendering()
actions = append(actions, action)
if len(actions) >= limit {
break
}
}
}
return actions, nil
}
func (p *MemoryProvider) dumpEventActions() ([]BaseEventAction, error) {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return nil, errMemoryProviderClosed
}
actions := make([]BaseEventAction, 0, len(p.dbHandle.actions))
for _, name := range p.dbHandle.actionsNames {
a := p.dbHandle.actions[name]
action := a.getACopy()
actions = append(actions, action)
}
return actions, nil
}
func (p *MemoryProvider) eventActionExists(name string) (BaseEventAction, error) {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return BaseEventAction{}, errMemoryProviderClosed
}
return p.actionExistsInternal(name)
}
func (p *MemoryProvider) addEventAction(action *BaseEventAction) error {
err := action.validate()
if err != nil {
return err
}
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
_, err = p.actionExistsInternal(action.Name)
if err == nil {
return fmt.Errorf("event action %q already exists", action.Name)
}
action.ID = p.getNextActionID()
action.Rules = nil
p.dbHandle.actions[action.Name] = action.getACopy()
p.dbHandle.actionsNames = append(p.dbHandle.actionsNames, action.Name)
sort.Strings(p.dbHandle.actionsNames)
return nil
}
func (p *MemoryProvider) updateEventAction(action *BaseEventAction) error {
err := action.validate()
if err != nil {
return err
}
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
oldAction, err := p.actionExistsInternal(action.Name)
if err != nil {
return fmt.Errorf("event action %s does not exist", action.Name)
}
action.ID = oldAction.ID
action.Name = oldAction.Name
action.Rules = nil
if len(oldAction.Rules) > 0 {
var relatedRules []string
for _, ruleName := range oldAction.Rules {
rule, err := p.ruleExistsInternal(ruleName)
if err == nil {
relatedRules = append(relatedRules, ruleName)
rule.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
p.dbHandle.rules[ruleName] = rule
setLastRuleUpdate()
}
}
action.Rules = relatedRules
}
p.dbHandle.actions[action.Name] = action.getACopy()
return nil
}
func (p *MemoryProvider) deleteEventAction(action BaseEventAction) error {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
oldAction, err := p.actionExistsInternal(action.Name)
if err != nil {
return fmt.Errorf("event action %s does not exist", action.Name)
}
if len(oldAction.Rules) > 0 {
return util.NewValidationError(fmt.Sprintf("action %s is referenced, it cannot be removed", oldAction.Name))
}
delete(p.dbHandle.actions, action.Name)
// this could be more efficient
p.dbHandle.actionsNames = make([]string, 0, len(p.dbHandle.actions))
for name := range p.dbHandle.actions {
p.dbHandle.actionsNames = append(p.dbHandle.actionsNames, name)
}
sort.Strings(p.dbHandle.actionsNames)
return nil
}
func (p *MemoryProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return nil, errMemoryProviderClosed
}
if limit <= 0 {
return nil, nil
}
itNum := 0
rules := make([]EventRule, 0, limit)
if order == OrderASC {
for _, name := range p.dbHandle.rulesNames {
itNum++
if itNum <= offset {
continue
}
r := p.dbHandle.rules[name]
rule := r.getACopy()
p.addActionsToRule(&rule)
rule.PrepareForRendering()
rules = append(rules, rule)
if len(rules) >= limit {
break
}
}
} else {
for i := len(p.dbHandle.rulesNames) - 1; i >= 0; i-- {
itNum++
if itNum <= offset {
continue
}
name := p.dbHandle.rulesNames[i]
r := p.dbHandle.rules[name]
rule := r.getACopy()
p.addActionsToRule(&rule)
rule.PrepareForRendering()
rules = append(rules, rule)
if len(rules) >= limit {
break
}
}
}
return rules, nil
}
func (p *MemoryProvider) dumpEventRules() ([]EventRule, error) {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return nil, errMemoryProviderClosed
}
rules := make([]EventRule, 0, len(p.dbHandle.rules))
for _, name := range p.dbHandle.rulesNames {
r := p.dbHandle.rules[name]
rule := r.getACopy()
p.addActionsToRule(&rule)
rules = append(rules, rule)
}
return rules, nil
}
func (p *MemoryProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) {
if getLastRuleUpdate() < after {
return nil, nil
}
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return nil, errMemoryProviderClosed
}
rules := make([]EventRule, 0, 10)
for _, name := range p.dbHandle.rulesNames {
r := p.dbHandle.rules[name]
if r.UpdatedAt < after {
continue
}
rule := r.getACopy()
p.addActionsToRule(&rule)
rules = append(rules, rule)
}
return rules, nil
}
func (p *MemoryProvider) eventRuleExists(name string) (EventRule, error) {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return EventRule{}, errMemoryProviderClosed
}
rule, err := p.ruleExistsInternal(name)
if err != nil {
return rule, err
}
p.addActionsToRule(&rule)
return rule, nil
}
func (p *MemoryProvider) addEventRule(rule *EventRule) error {
if err := rule.validate(); err != nil {
return err
}
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
_, err := p.ruleExistsInternal(rule.Name)
if err == nil {
return fmt.Errorf("event rule %q already exists", rule.Name)
}
rule.ID = p.getNextRuleID()
rule.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
rule.UpdatedAt = rule.CreatedAt
var mappedActions []string
for idx := range rule.Actions {
if err := p.addRuleToActionMapping(rule.Name, rule.Actions[idx].Name); err != nil {
// try to remove action mapping
for _, a := range mappedActions {
p.removeRuleFromActionMapping(rule.Name, a)
}
return err
}
mappedActions = append(mappedActions, rule.Actions[idx].Name)
}
sort.Slice(rule.Actions, func(i, j int) bool {
return rule.Actions[i].Order < rule.Actions[j].Order
})
p.dbHandle.rules[rule.Name] = rule.getACopy()
p.dbHandle.rulesNames = append(p.dbHandle.rulesNames, rule.Name)
sort.Strings(p.dbHandle.rulesNames)
setLastRuleUpdate()
return nil
}
func (p *MemoryProvider) updateEventRule(rule *EventRule) error {
if err := rule.validate(); err != nil {
return err
}
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
oldRule, err := p.ruleExistsInternal(rule.Name)
if err != nil {
return err
}
for idx := range oldRule.Actions {
p.removeRuleFromActionMapping(rule.Name, oldRule.Actions[idx].Name)
}
for idx := range rule.Actions {
if err = p.addRuleToActionMapping(rule.Name, rule.Actions[idx].Name); err != nil {
// try to add old mapping
for _, oldAction := range oldRule.Actions {
if errRollback := p.addRuleToActionMapping(oldRule.Name, oldAction.Name); errRollback != nil {
providerLog(logger.LevelError, "unable to rollback old action mapping %q for rule %q, error: %v",
oldAction.Name, oldRule.Name, errRollback)
}
}
return err
}
}
rule.ID = oldRule.ID
rule.CreatedAt = oldRule.CreatedAt
rule.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
sort.Slice(rule.Actions, func(i, j int) bool {
return rule.Actions[i].Order < rule.Actions[j].Order
})
p.dbHandle.rules[rule.Name] = rule.getACopy()
setLastRuleUpdate()
return nil
}
func (p *MemoryProvider) deleteEventRule(rule EventRule, softDelete bool) error {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
oldRule, err := p.ruleExistsInternal(rule.Name)
if err != nil {
return err
}
if len(oldRule.Actions) > 0 {
for idx := range oldRule.Actions {
p.removeRuleFromActionMapping(rule.Name, oldRule.Actions[idx].Name)
}
}
delete(p.dbHandle.rules, rule.Name)
p.dbHandle.rulesNames = make([]string, 0, len(p.dbHandle.rules))
for name := range p.dbHandle.rules {
p.dbHandle.rulesNames = append(p.dbHandle.rulesNames, name)
}
sort.Strings(p.dbHandle.rulesNames)
setLastRuleUpdate()
return nil
}
func (*MemoryProvider) getTaskByName(name string) (Task, error) {
return Task{}, ErrNotImplemented
}
func (*MemoryProvider) addTask(name string) error {
return ErrNotImplemented
}
func (*MemoryProvider) updateTask(name string, version int64) error {
return ErrNotImplemented
}
func (*MemoryProvider) updateTaskTimestamp(name string) error {
return ErrNotImplemented
}
func (*MemoryProvider) addNode() error {
return ErrNotImplemented
}
func (*MemoryProvider) getNodeByName(name string) (Node, error) {
return Node{}, ErrNotImplemented
}
func (*MemoryProvider) getNodes() ([]Node, error) {
return nil, ErrNotImplemented
}
func (*MemoryProvider) updateNodeTimestamp() error {
return ErrNotImplemented
}
func (*MemoryProvider) cleanupNodes() error {
return ErrNotImplemented
}
func (p *MemoryProvider) setFirstDownloadTimestamp(username string) error {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
user, err := p.userExistsInternal(username)
if err != nil {
return err
}
if user.FirstDownload > 0 {
return util.NewGenericError(fmt.Sprintf("first download already set to %v",
util.GetTimeFromMsecSinceEpoch(user.FirstDownload)))
}
user.FirstDownload = util.GetTimeAsMsSinceEpoch(time.Now())
p.dbHandle.users[user.Username] = user
return nil
}
func (p *MemoryProvider) setFirstUploadTimestamp(username string) error {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
user, err := p.userExistsInternal(username)
if err != nil {
return err
}
if user.FirstUpload > 0 {
return util.NewGenericError(fmt.Sprintf("first upload already set to %v",
util.GetTimeFromMsecSinceEpoch(user.FirstUpload)))
}
user.FirstUpload = util.GetTimeAsMsSinceEpoch(time.Now())
p.dbHandle.users[user.Username] = user
return nil
}
func (p *MemoryProvider) getNextID() int64 {
nextID := int64(1)
for _, v := range p.dbHandle.users {
@@ -2400,26 +1822,6 @@ func (p *MemoryProvider) getNextGroupID() int64 {
return nextID
}
func (p *MemoryProvider) getNextActionID() int64 {
nextID := int64(1)
for _, a := range p.dbHandle.actions {
if a.ID >= nextID {
nextID = a.ID + 1
}
}
return nextID
}
func (p *MemoryProvider) getNextRuleID() int64 {
nextID := int64(1)
for _, r := range p.dbHandle.rules {
if r.ID >= nextID {
nextID = r.ID + 1
}
}
return nextID
}
func (p *MemoryProvider) clear() {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
@@ -2468,35 +1870,27 @@ func (p *MemoryProvider) reloadConfig() error {
}
p.clear()
if err := p.restoreFolders(dump); err != nil {
if err := p.restoreFolders(&dump); err != nil {
return err
}
if err := p.restoreGroups(dump); err != nil {
if err := p.restoreGroups(&dump); err != nil {
return err
}
if err := p.restoreUsers(dump); err != nil {
if err := p.restoreUsers(&dump); err != nil {
return err
}
if err := p.restoreAdmins(dump); err != nil {
if err := p.restoreAdmins(&dump); err != nil {
return err
}
if err := p.restoreAPIKeys(dump); err != nil {
if err := p.restoreAPIKeys(&dump); err != nil {
return err
}
if err := p.restoreShares(dump); err != nil {
return err
}
if err := p.restoreEventActions(dump); err != nil {
return err
}
if err := p.restoreEventRules(dump); err != nil {
if err := p.restoreShares(&dump); err != nil {
return err
}
@@ -2504,51 +1898,7 @@ func (p *MemoryProvider) reloadConfig() error {
return nil
}
func (p *MemoryProvider) restoreEventActions(dump BackupData) error {
for _, action := range dump.EventActions {
a, err := p.eventActionExists(action.Name)
action := action // pin
if err == nil {
action.ID = a.ID
err = UpdateEventAction(&action, ActionExecutorSystem, "")
if err != nil {
providerLog(logger.LevelError, "error updating event action %q: %v", action.Name, err)
return err
}
} else {
err = AddEventAction(&action, ActionExecutorSystem, "")
if err != nil {
providerLog(logger.LevelError, "error adding event action %q: %v", action.Name, err)
return err
}
}
}
return nil
}
func (p *MemoryProvider) restoreEventRules(dump BackupData) error {
for _, rule := range dump.EventRules {
r, err := p.eventRuleExists(rule.Name)
rule := rule // pin
if err == nil {
rule.ID = r.ID
err = UpdateEventRule(&rule, ActionExecutorSystem, "")
if err != nil {
providerLog(logger.LevelError, "error updating event rule %q: %v", rule.Name, err)
return err
}
} else {
err = AddEventRule(&rule, ActionExecutorSystem, "")
if err != nil {
providerLog(logger.LevelError, "error adding event rule %q: %v", rule.Name, err)
return err
}
}
}
return nil
}
func (p *MemoryProvider) restoreShares(dump BackupData) error {
func (p *MemoryProvider) restoreShares(dump *BackupData) error {
for _, share := range dump.Shares {
s, err := p.shareExists(share.ShareID, "")
share := share // pin
@@ -2571,7 +1921,7 @@ func (p *MemoryProvider) restoreShares(dump BackupData) error {
return nil
}
func (p *MemoryProvider) restoreAPIKeys(dump BackupData) error {
func (p *MemoryProvider) restoreAPIKeys(dump *BackupData) error {
for _, apiKey := range dump.APIKeys {
if apiKey.Key == "" {
return fmt.Errorf("cannot restore an empty API key: %+v", apiKey)
@@ -2596,7 +1946,7 @@ func (p *MemoryProvider) restoreAPIKeys(dump BackupData) error {
return nil
}
func (p *MemoryProvider) restoreAdmins(dump BackupData) error {
func (p *MemoryProvider) restoreAdmins(dump *BackupData) error {
for _, admin := range dump.Admins {
admin := admin // pin
admin.Username = config.convertName(admin.Username)
@@ -2619,7 +1969,7 @@ func (p *MemoryProvider) restoreAdmins(dump BackupData) error {
return nil
}
func (p *MemoryProvider) restoreGroups(dump BackupData) error {
func (p *MemoryProvider) restoreGroups(dump *BackupData) error {
for _, group := range dump.Groups {
group := group // pin
group.Name = config.convertName(group.Name)
@@ -2643,7 +1993,7 @@ func (p *MemoryProvider) restoreGroups(dump BackupData) error {
return nil
}
func (p *MemoryProvider) restoreFolders(dump BackupData) error {
func (p *MemoryProvider) restoreFolders(dump *BackupData) error {
for _, folder := range dump.Folders {
folder := folder // pin
folder.Name = config.convertName(folder.Name)
@@ -2667,7 +2017,7 @@ func (p *MemoryProvider) restoreFolders(dump BackupData) error {
return nil
}
func (p *MemoryProvider) restoreUsers(dump BackupData) error {
func (p *MemoryProvider) restoreUsers(dump *BackupData) error {
for _, user := range dump.Users {
user := user // pin
user.Username = config.convertName(user.Username)

View File

@@ -25,15 +25,14 @@ import (
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/go-sql-driver/mysql"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/version"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/version"
"github.com/drakkan/sftpgo/v2/vfs"
)
const (
@@ -41,7 +40,6 @@ const (
"DROP TABLE IF EXISTS `{{folders_mapping}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{users_folders_mapping}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{users_groups_mapping}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{admins_groups_mapping}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{groups_folders_mapping}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{admins}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{folders}}` CASCADE;" +
@@ -52,22 +50,12 @@ const (
"DROP TABLE IF EXISTS `{{defender_hosts}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{active_transfers}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{shared_sessions}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{rules_actions_mapping}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{events_actions}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{events_rules}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{tasks}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{nodes}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{schema_version}}` CASCADE;"
mysqlInitialSQL = "CREATE TABLE `{{schema_version}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `version` integer NOT NULL);" +
"CREATE TABLE `{{admins}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " +
"`description` varchar(512) NULL, `password` varchar(255) NOT NULL, `email` varchar(255) NULL, `status` integer NOT NULL, " +
"`permissions` longtext NOT NULL, `filters` longtext NULL, `additional_info` longtext NULL, `last_login` bigint NOT NULL, " +
"`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL);" +
"CREATE TABLE `{{active_transfers}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`connection_id` varchar(100) NOT NULL, `transfer_id` bigint NOT NULL, `transfer_type` integer NOT NULL, " +
"`username` varchar(255) NOT NULL, `folder_name` varchar(255) NULL, `ip` varchar(50) NOT NULL, " +
"`truncated_size` bigint NOT NULL, `current_ul_size` bigint NOT NULL, `current_dl_size` bigint NOT NULL, " +
"`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL);" +
"CREATE TABLE `{{defender_hosts}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`ip` varchar(50) NOT NULL UNIQUE, `ban_time` bigint NOT NULL, `updated_at` bigint NOT NULL);" +
"CREATE TABLE `{{defender_events}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
@@ -77,11 +65,6 @@ const (
"CREATE TABLE `{{folders}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(255) NOT NULL UNIQUE, " +
"`description` varchar(512) NULL, `path` longtext NULL, `used_quota_size` bigint NOT NULL, " +
"`used_quota_files` integer NOT NULL, `last_quota_update` bigint NOT NULL, `filesystem` longtext NULL);" +
"CREATE TABLE `{{groups}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`name` varchar(255) NOT NULL UNIQUE, `description` varchar(512) NULL, `created_at` bigint NOT NULL, " +
"`updated_at` bigint NOT NULL, `user_settings` longtext NULL);" +
"CREATE TABLE `{{shared_sessions}}` (`key` varchar(128) NOT NULL PRIMARY KEY, " +
"`data` longtext NOT NULL, `type` integer NOT NULL, `timestamp` bigint NOT NULL);" +
"CREATE TABLE `{{users}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " +
"`status` integer NOT NULL, `expiration_date` bigint NOT NULL, `description` varchar(512) NULL, `password` longtext NULL, " +
"`public_keys` longtext NULL, `home_dir` longtext NOT NULL, `uid` bigint NOT NULL, `gid` bigint NOT NULL, " +
@@ -89,33 +72,12 @@ const (
"`permissions` longtext NOT NULL, `used_quota_size` bigint NOT NULL, `used_quota_files` integer NOT NULL, " +
"`last_quota_update` bigint NOT NULL, `upload_bandwidth` integer NOT NULL, `download_bandwidth` integer NOT NULL, " +
"`last_login` bigint NOT NULL, `filters` longtext NULL, `filesystem` longtext NULL, `additional_info` longtext NULL, " +
"`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL, `email` varchar(255) NULL, " +
"`upload_data_transfer` integer NOT NULL, `download_data_transfer` integer NOT NULL, " +
"`total_data_transfer` integer NOT NULL, `used_upload_data_transfer` integer NOT NULL, " +
"`used_download_data_transfer` integer NOT NULL);" +
"CREATE TABLE `{{groups_folders_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`group_id` integer NOT NULL, `folder_id` integer NOT NULL, " +
"`virtual_path` longtext NOT NULL, `quota_size` bigint NOT NULL, `quota_files` integer NOT NULL);" +
"CREATE TABLE `{{users_groups_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`user_id` integer NOT NULL, `group_id` integer NOT NULL, `group_type` integer NOT NULL);" +
"CREATE TABLE `{{users_folders_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `virtual_path` longtext NOT NULL, " +
"`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL, `email` varchar(255) NULL);" +
"CREATE TABLE `{{folders_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `virtual_path` longtext NOT NULL, " +
"`quota_size` bigint NOT NULL, `quota_files` integer NOT NULL, `folder_id` integer NOT NULL, `user_id` integer NOT NULL);" +
"ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_user_folder_mapping` " +
"UNIQUE (`user_id`, `folder_id`);" +
"ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}users_folders_mapping_user_id_fk_users_id` " +
"FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}users_folders_mapping_folder_id_fk_folders_id` " +
"FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}unique_user_group_mapping` UNIQUE (`user_id`, `group_id`);" +
"ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_group_folder_mapping` UNIQUE (`group_id`, `folder_id`);" +
"ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}users_groups_mapping_group_id_fk_groups_id` " +
"FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE NO ACTION;" +
"ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}users_groups_mapping_user_id_fk_users_id` " +
"FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}groups_folders_mapping_folder_id_fk_folders_id` " +
"FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}groups_folders_mapping_group_id_fk_groups_id` " +
"FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_mapping` UNIQUE (`user_id`, `folder_id`);" +
"ALTER TABLE `{{folders_mapping}}` ADD CONSTRAINT `{{prefix}}folders_mapping_folder_id_fk_folders_id` FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{folders_mapping}}` ADD CONSTRAINT `{{prefix}}folders_mapping_user_id_fk_users_id` FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" +
"CREATE TABLE `{{shares}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`share_id` varchar(60) NOT NULL UNIQUE, `name` varchar(255) NOT NULL, `description` varchar(512) NULL, " +
"`scope` integer NOT NULL, `paths` longtext NOT NULL, `created_at` bigint NOT NULL, " +
@@ -133,60 +95,83 @@ const (
"CREATE INDEX `{{prefix}}defender_hosts_updated_at_idx` ON `{{defender_hosts}}` (`updated_at`);" +
"CREATE INDEX `{{prefix}}defender_hosts_ban_time_idx` ON `{{defender_hosts}}` (`ban_time`);" +
"CREATE INDEX `{{prefix}}defender_events_date_time_idx` ON `{{defender_events}}` (`date_time`);" +
"INSERT INTO {{schema_version}} (version) VALUES (15);"
mysqlV16SQL = "ALTER TABLE `{{users}}` ADD COLUMN `download_data_transfer` integer DEFAULT 0 NOT NULL;" +
"ALTER TABLE `{{users}}` ALTER COLUMN `download_data_transfer` DROP DEFAULT;" +
"ALTER TABLE `{{users}}` ADD COLUMN `total_data_transfer` integer DEFAULT 0 NOT NULL;" +
"ALTER TABLE `{{users}}` ALTER COLUMN `total_data_transfer` DROP DEFAULT;" +
"ALTER TABLE `{{users}}` ADD COLUMN `upload_data_transfer` integer DEFAULT 0 NOT NULL;" +
"ALTER TABLE `{{users}}` ALTER COLUMN `upload_data_transfer` DROP DEFAULT;" +
"ALTER TABLE `{{users}}` ADD COLUMN `used_download_data_transfer` integer DEFAULT 0 NOT NULL;" +
"ALTER TABLE `{{users}}` ALTER COLUMN `used_download_data_transfer` DROP DEFAULT;" +
"ALTER TABLE `{{users}}` ADD COLUMN `used_upload_data_transfer` integer DEFAULT 0 NOT NULL;" +
"ALTER TABLE `{{users}}` ALTER COLUMN `used_upload_data_transfer` DROP DEFAULT;" +
"CREATE TABLE `{{active_transfers}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`connection_id` varchar(100) NOT NULL, `transfer_id` bigint NOT NULL, `transfer_type` integer NOT NULL, " +
"`username` varchar(255) NOT NULL, `folder_name` varchar(255) NULL, `ip` varchar(50) NOT NULL, " +
"`truncated_size` bigint NOT NULL, `current_ul_size` bigint NOT NULL, `current_dl_size` bigint NOT NULL, " +
"`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL);" +
"CREATE INDEX `{{prefix}}active_transfers_connection_id_idx` ON `{{active_transfers}}` (`connection_id`);" +
"CREATE INDEX `{{prefix}}active_transfers_transfer_id_idx` ON `{{active_transfers}}` (`transfer_id`);" +
"CREATE INDEX `{{prefix}}active_transfers_updated_at_idx` ON `{{active_transfers}}` (`updated_at`);" +
"CREATE INDEX `{{prefix}}groups_updated_at_idx` ON `{{groups}}` (`updated_at`);" +
"CREATE INDEX `{{prefix}}shared_sessions_type_idx` ON `{{shared_sessions}}` (`type`);" +
"CREATE INDEX `{{prefix}}shared_sessions_timestamp_idx` ON `{{shared_sessions}}` (`timestamp`);" +
"INSERT INTO {{schema_version}} (version) VALUES (19);"
mysqlV20SQL = "CREATE TABLE `{{events_rules}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"CREATE INDEX `{{prefix}}active_transfers_updated_at_idx` ON `{{active_transfers}}` (`updated_at`);"
mysqlV16DownSQL = "ALTER TABLE `{{users}}` DROP COLUMN `used_upload_data_transfer`;" +
"ALTER TABLE `{{users}}` DROP COLUMN `used_download_data_transfer`;" +
"ALTER TABLE `{{users}}` DROP COLUMN `upload_data_transfer`;" +
"ALTER TABLE `{{users}}` DROP COLUMN `total_data_transfer`;" +
"ALTER TABLE `{{users}}` DROP COLUMN `download_data_transfer`;" +
"DROP TABLE `{{active_transfers}}` CASCADE;"
mysqlV17SQL = "CREATE TABLE `{{groups}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`name` varchar(255) NOT NULL UNIQUE, `description` varchar(512) NULL, `created_at` bigint NOT NULL, " +
"`updated_at` bigint NOT NULL, `trigger` integer NOT NULL, `conditions` longtext NOT NULL, `deleted_at` bigint NOT NULL);" +
"CREATE TABLE `{{events_actions}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`name` varchar(255) NOT NULL UNIQUE, `description` varchar(512) NULL, `type` integer NOT NULL, " +
"`options` longtext NOT NULL);" +
"CREATE TABLE `{{rules_actions_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`rule_id` integer NOT NULL, `action_id` integer NOT NULL, `order` integer NOT NULL, `options` longtext NOT NULL);" +
"CREATE TABLE `{{tasks}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(255) NOT NULL UNIQUE, " +
"`updated_at` bigint NOT NULL, `version` bigint NOT NULL);" +
"ALTER TABLE `{{rules_actions_mapping}}` ADD CONSTRAINT `{{prefix}}unique_rule_action_mapping` UNIQUE (`rule_id`, `action_id`);" +
"ALTER TABLE `{{rules_actions_mapping}}` ADD CONSTRAINT `{{prefix}}rules_actions_mapping_rule_id_fk_events_rules_id` " +
"FOREIGN KEY (`rule_id`) REFERENCES `{{events_rules}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{rules_actions_mapping}}` ADD CONSTRAINT `{{prefix}}rules_actions_mapping_action_id_fk_events_targets_id` " +
"FOREIGN KEY (`action_id`) REFERENCES `{{events_actions}}` (`id`) ON DELETE NO ACTION;" +
"ALTER TABLE `{{users}}` ADD COLUMN `deleted_at` bigint DEFAULT 0 NOT NULL;" +
"ALTER TABLE `{{users}}` ALTER COLUMN `deleted_at` DROP DEFAULT;" +
"CREATE INDEX `{{prefix}}events_rules_updated_at_idx` ON `{{events_rules}}` (`updated_at`);" +
"CREATE INDEX `{{prefix}}events_rules_deleted_at_idx` ON `{{events_rules}}` (`deleted_at`);" +
"CREATE INDEX `{{prefix}}events_rules_trigger_idx` ON `{{events_rules}}` (`trigger`);" +
"CREATE INDEX `{{prefix}}rules_actions_mapping_order_idx` ON `{{rules_actions_mapping}}` (`order`);" +
"CREATE INDEX `{{prefix}}users_deleted_at_idx` ON `{{users}}` (`deleted_at`);"
mysqlV20DownSQL = "DROP TABLE `{{rules_actions_mapping}}` CASCADE;" +
"DROP TABLE `{{events_rules}}` CASCADE;" +
"DROP TABLE `{{events_actions}}` CASCADE;" +
"DROP TABLE `{{tasks}}` CASCADE;" +
"ALTER TABLE `{{users}}` DROP COLUMN `deleted_at`;"
mysqlV21SQL = "ALTER TABLE `{{users}}` ADD COLUMN `first_download` bigint DEFAULT 0 NOT NULL; " +
"ALTER TABLE `{{users}}` ALTER COLUMN `first_download` DROP DEFAULT; " +
"ALTER TABLE `{{users}}` ADD COLUMN `first_upload` bigint DEFAULT 0 NOT NULL; " +
"ALTER TABLE `{{users}}` ALTER COLUMN `first_upload` DROP DEFAULT;"
mysqlV21DownSQL = "ALTER TABLE `{{users}}` DROP COLUMN `first_upload`; " +
"ALTER TABLE `{{users}}` DROP COLUMN `first_download`;"
mysqlV22SQL = "CREATE TABLE `{{admins_groups_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
" `admin_id` integer NOT NULL, `group_id` integer NOT NULL, `options` longtext NOT NULL);" +
"ALTER TABLE `{{admins_groups_mapping}}` ADD CONSTRAINT `{{prefix}}unique_admin_group_mapping` " +
"UNIQUE (`admin_id`, `group_id`);" +
"ALTER TABLE `{{admins_groups_mapping}}` ADD CONSTRAINT `{{prefix}}admins_groups_mapping_admin_id_fk_admins_id` " +
"FOREIGN KEY (`admin_id`) REFERENCES `{{admins}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{admins_groups_mapping}}` ADD CONSTRAINT `{{prefix}}admins_groups_mapping_group_id_fk_groups_id` " +
"FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE CASCADE;"
mysqlV22DownSQL = "ALTER TABLE `{{admins_groups_mapping}}` DROP INDEX `{{prefix}}unique_admin_group_mapping`;" +
"DROP TABLE `{{admins_groups_mapping}}` CASCADE;"
mysqlV23SQL = "CREATE TABLE `{{nodes}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`name` varchar(255) NOT NULL UNIQUE, `data` longtext NOT NULL, `created_at` bigint NOT NULL, " +
"`updated_at` bigint NOT NULL);"
mysqlV23DownSQL = "DROP TABLE `{{nodes}}` CASCADE;"
"`updated_at` bigint NOT NULL, `user_settings` longtext NULL);" +
"CREATE TABLE `{{groups_folders_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`group_id` integer NOT NULL, `folder_id` integer NOT NULL, " +
"`virtual_path` longtext NOT NULL, `quota_size` bigint NOT NULL, `quota_files` integer NOT NULL);" +
"CREATE TABLE `{{users_groups_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`user_id` integer NOT NULL, `group_id` integer NOT NULL, `group_type` integer NOT NULL);" +
"ALTER TABLE `{{folders_mapping}}` DROP FOREIGN KEY `{{prefix}}folders_mapping_folder_id_fk_folders_id`;" +
"ALTER TABLE `{{folders_mapping}}` DROP FOREIGN KEY `{{prefix}}folders_mapping_user_id_fk_users_id`;" +
"ALTER TABLE `{{folders_mapping}}` DROP INDEX `{{prefix}}unique_mapping`;" +
"RENAME TABLE `{{folders_mapping}}` TO `{{users_folders_mapping}}`;" +
"ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_user_folder_mapping` " +
"UNIQUE (`user_id`, `folder_id`);" +
"ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}users_folders_mapping_user_id_fk_users_id` " +
"FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}users_folders_mapping_folder_id_fk_folders_id` " +
"FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}unique_user_group_mapping` UNIQUE (`user_id`, `group_id`);" +
"ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_group_folder_mapping` UNIQUE (`group_id`, `folder_id`);" +
"ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}users_groups_mapping_group_id_fk_groups_id` " +
"FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE NO ACTION;" +
"ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}users_groups_mapping_user_id_fk_users_id` " +
"FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}groups_folders_mapping_folder_id_fk_folders_id` " +
"FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}groups_folders_mapping_group_id_fk_groups_id` " +
"FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE CASCADE;" +
"CREATE INDEX `{{prefix}}groups_updated_at_idx` ON `{{groups}}` (`updated_at`);"
mysqlV17DownSQL = "ALTER TABLE `{{groups_folders_mapping}}` DROP FOREIGN KEY `{{prefix}}groups_folders_mapping_group_id_fk_groups_id`;" +
"ALTER TABLE `{{groups_folders_mapping}}` DROP FOREIGN KEY `{{prefix}}groups_folders_mapping_folder_id_fk_folders_id`;" +
"ALTER TABLE `{{users_groups_mapping}}` DROP FOREIGN KEY `{{prefix}}users_groups_mapping_user_id_fk_users_id`;" +
"ALTER TABLE `{{users_groups_mapping}}` DROP FOREIGN KEY `{{prefix}}users_groups_mapping_group_id_fk_groups_id`;" +
"ALTER TABLE `{{groups_folders_mapping}}` DROP INDEX `{{prefix}}unique_group_folder_mapping`;" +
"ALTER TABLE `{{users_groups_mapping}}` DROP INDEX `{{prefix}}unique_user_group_mapping`;" +
"DROP TABLE `{{users_groups_mapping}}` CASCADE;" +
"DROP TABLE `{{groups_folders_mapping}}` CASCADE;" +
"DROP TABLE `{{groups}}` CASCADE;" +
"ALTER TABLE `{{users_folders_mapping}}` DROP FOREIGN KEY `{{prefix}}users_folders_mapping_folder_id_fk_folders_id`;" +
"ALTER TABLE `{{users_folders_mapping}}` DROP FOREIGN KEY `{{prefix}}users_folders_mapping_user_id_fk_users_id`;" +
"ALTER TABLE `{{users_folders_mapping}}` DROP INDEX `{{prefix}}unique_user_folder_mapping`;" +
"RENAME TABLE `{{users_folders_mapping}}` TO `{{folders_mapping}}`;" +
"ALTER TABLE `{{folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_mapping` UNIQUE (`user_id`, `folder_id`);" +
"ALTER TABLE `{{folders_mapping}}` ADD CONSTRAINT `{{prefix}}folders_mapping_user_id_fk_users_id` " +
"FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{folders_mapping}}` ADD CONSTRAINT `{{prefix}}folders_mapping_folder_id_fk_folders_id` " +
"FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;"
mysqlV19SQL = "CREATE TABLE `{{shared_sessions}}` (`key` varchar(128) NOT NULL PRIMARY KEY, " +
"`data` longtext NOT NULL, `type` integer NOT NULL, `timestamp` bigint NOT NULL);" +
"CREATE INDEX `{{prefix}}shared_sessions_type_idx` ON `{{shared_sessions}}` (`type`);" +
"CREATE INDEX `{{prefix}}shared_sessions_timestamp_idx` ON `{{shared_sessions}}` (`timestamp`);"
mysqlV19DownSQL = "DROP TABLE `{{shared_sessions}}` CASCADE;"
)
// MySQLProvider defines the auth provider for MySQL/MariaDB database
@@ -220,7 +205,6 @@ func initializeMySQLProvider() error {
dbHandle.SetMaxIdleConns(2)
}
dbHandle.SetConnMaxLifetime(240 * time.Second)
dbHandle.SetConnMaxIdleTime(120 * time.Second)
provider = &MySQLProvider{dbHandle: dbHandle}
} else {
providerLog(logger.LevelError, "error creating mysql database handler, connection string: %#v, error: %v",
@@ -237,8 +221,37 @@ func getMySQLConnectionString(redactedPwd bool) (string, error) {
}
sslMode := getSSLMode()
if sslMode == "custom" && !redactedPwd {
if err := registerMySQLCustomTLSConfig(); err != nil {
return "", err
tlsConfig := &tls.Config{}
if config.RootCert != "" {
rootCAs, err := x509.SystemCertPool()
if err != nil {
rootCAs = x509.NewCertPool()
}
rootCrt, err := os.ReadFile(config.RootCert)
if err != nil {
return "", fmt.Errorf("unable to load root certificate %#v: %v", config.RootCert, err)
}
if !rootCAs.AppendCertsFromPEM(rootCrt) {
return "", fmt.Errorf("unable to parse root certificate %#v", config.RootCert)
}
tlsConfig.RootCAs = rootCAs
}
if config.ClientCert != "" && config.ClientKey != "" {
clientCert := make([]tls.Certificate, 0, 1)
tlsCert, err := tls.LoadX509KeyPair(config.ClientCert, config.ClientKey)
if err != nil {
return "", fmt.Errorf("unable to load key pair %#v, %#v: %v", config.ClientCert, config.ClientKey, err)
}
clientCert = append(clientCert, tlsCert)
tlsConfig.Certificates = clientCert
}
if config.SSLMode == 2 {
tlsConfig.InsecureSkipVerify = true
}
providerLog(logger.LevelInfo, "registering custom TLS config, root cert %#v, client cert %#v, client key %#v",
config.RootCert, config.ClientCert, config.ClientKey)
if err := mysql.RegisterTLSConfig("custom", tlsConfig); err != nil {
return "", fmt.Errorf("unable to register tls config: %v", err)
}
}
connectionString = fmt.Sprintf("%v:%v@tcp([%v]:%v)/%v?charset=utf8mb4&interpolateParams=true&timeout=10s&parseTime=true&tls=%v&writeTimeout=60s&readTimeout=60s",
@@ -249,45 +262,6 @@ func getMySQLConnectionString(redactedPwd bool) (string, error) {
return connectionString, nil
}
func registerMySQLCustomTLSConfig() error {
tlsConfig := &tls.Config{}
if config.RootCert != "" {
rootCAs, err := x509.SystemCertPool()
if err != nil {
rootCAs = x509.NewCertPool()
}
rootCrt, err := os.ReadFile(config.RootCert)
if err != nil {
return fmt.Errorf("unable to load root certificate %#v: %v", config.RootCert, err)
}
if !rootCAs.AppendCertsFromPEM(rootCrt) {
return fmt.Errorf("unable to parse root certificate %#v", config.RootCert)
}
tlsConfig.RootCAs = rootCAs
}
if config.ClientCert != "" && config.ClientKey != "" {
clientCert := make([]tls.Certificate, 0, 1)
tlsCert, err := tls.LoadX509KeyPair(config.ClientCert, config.ClientKey)
if err != nil {
return fmt.Errorf("unable to load key pair %#v, %#v: %v", config.ClientCert, config.ClientKey, err)
}
clientCert = append(clientCert, tlsCert)
tlsConfig.Certificates = clientCert
}
if config.SSLMode == 2 || config.SSLMode == 3 {
tlsConfig.InsecureSkipVerify = true
}
if !filepath.IsAbs(config.Host) && !config.DisableSNI {
tlsConfig.ServerName = config.Host
}
providerLog(logger.LevelInfo, "registering custom TLS config, root cert %#v, client cert %#v, client key %#v, disable SNI? %v",
config.RootCert, config.ClientCert, config.ClientKey, config.DisableSNI)
if err := mysql.RegisterTLSConfig("custom", tlsConfig); err != nil {
return fmt.Errorf("unable to register tls config: %v", err)
}
return nil
}
func (p *MySQLProvider) checkAvailability() error {
return sqlCommonCheckAvailability(p.dbHandle)
}
@@ -340,8 +314,8 @@ func (p *MySQLProvider) updateUser(user *User) error {
return sqlCommonUpdateUser(user, p.dbHandle)
}
func (p *MySQLProvider) deleteUser(user User, softDelete bool) error {
return sqlCommonDeleteUser(user, softDelete, p.dbHandle)
func (p *MySQLProvider) deleteUser(user User) error {
return sqlCommonDeleteUser(user, p.dbHandle)
}
func (p *MySQLProvider) updateUserPassword(username, password string) error {
@@ -582,102 +556,6 @@ func (p *MySQLProvider) cleanupSharedSessions(sessionType SessionType, before in
return sqlCommonCleanupSessions(sessionType, before, p.dbHandle)
}
func (p *MySQLProvider) getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) {
return sqlCommonGetEventActions(limit, offset, order, minimal, p.dbHandle)
}
func (p *MySQLProvider) dumpEventActions() ([]BaseEventAction, error) {
return sqlCommonDumpEventActions(p.dbHandle)
}
func (p *MySQLProvider) eventActionExists(name string) (BaseEventAction, error) {
return sqlCommonGetEventActionByName(name, p.dbHandle)
}
func (p *MySQLProvider) addEventAction(action *BaseEventAction) error {
return sqlCommonAddEventAction(action, p.dbHandle)
}
func (p *MySQLProvider) updateEventAction(action *BaseEventAction) error {
return sqlCommonUpdateEventAction(action, p.dbHandle)
}
func (p *MySQLProvider) deleteEventAction(action BaseEventAction) error {
return sqlCommonDeleteEventAction(action, p.dbHandle)
}
func (p *MySQLProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) {
return sqlCommonGetEventRules(limit, offset, order, p.dbHandle)
}
func (p *MySQLProvider) dumpEventRules() ([]EventRule, error) {
return sqlCommonDumpEventRules(p.dbHandle)
}
func (p *MySQLProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) {
return sqlCommonGetRecentlyUpdatedRules(after, p.dbHandle)
}
func (p *MySQLProvider) eventRuleExists(name string) (EventRule, error) {
return sqlCommonGetEventRuleByName(name, p.dbHandle)
}
func (p *MySQLProvider) addEventRule(rule *EventRule) error {
return sqlCommonAddEventRule(rule, p.dbHandle)
}
func (p *MySQLProvider) updateEventRule(rule *EventRule) error {
return sqlCommonUpdateEventRule(rule, p.dbHandle)
}
func (p *MySQLProvider) deleteEventRule(rule EventRule, softDelete bool) error {
return sqlCommonDeleteEventRule(rule, softDelete, p.dbHandle)
}
func (p *MySQLProvider) getTaskByName(name string) (Task, error) {
return sqlCommonGetTaskByName(name, p.dbHandle)
}
func (p *MySQLProvider) addTask(name string) error {
return sqlCommonAddTask(name, p.dbHandle)
}
func (p *MySQLProvider) updateTask(name string, version int64) error {
return sqlCommonUpdateTask(name, version, p.dbHandle)
}
func (p *MySQLProvider) updateTaskTimestamp(name string) error {
return sqlCommonUpdateTaskTimestamp(name, p.dbHandle)
}
func (p *MySQLProvider) addNode() error {
return sqlCommonAddNode(p.dbHandle)
}
func (p *MySQLProvider) getNodeByName(name string) (Node, error) {
return sqlCommonGetNodeByName(name, p.dbHandle)
}
func (p *MySQLProvider) getNodes() ([]Node, error) {
return sqlCommonGetNodes(p.dbHandle)
}
func (p *MySQLProvider) updateNodeTimestamp() error {
return sqlCommonUpdateNodeTimestamp(p.dbHandle)
}
func (p *MySQLProvider) cleanupNodes() error {
return sqlCommonCleanupNodes(p.dbHandle)
}
func (p *MySQLProvider) setFirstDownloadTimestamp(username string) error {
return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle)
}
func (p *MySQLProvider) setFirstUploadTimestamp(username string) error {
return sqlCommonSetFirstUploadTimestamp(username, p.dbHandle)
}
func (p *MySQLProvider) close() error {
return p.dbHandle.Close()
}
@@ -695,11 +573,20 @@ func (p *MySQLProvider) initializeDatabase() error {
if errors.Is(err, sql.ErrNoRows) {
return errSchemaVersionEmpty
}
logger.InfoToConsole("creating initial database schema, version 19")
providerLog(logger.LevelInfo, "creating initial database schema, version 19")
initialSQL := sqlReplaceAll(mysqlInitialSQL)
logger.InfoToConsole("creating initial database schema, version 15")
providerLog(logger.LevelInfo, "creating initial database schema, version 15")
initialSQL := strings.ReplaceAll(mysqlInitialSQL, "{{schema_version}}", sqlTableSchemaVersion)
initialSQL = strings.ReplaceAll(initialSQL, "{{admins}}", sqlTableAdmins)
initialSQL = strings.ReplaceAll(initialSQL, "{{folders}}", sqlTableFolders)
initialSQL = strings.ReplaceAll(initialSQL, "{{users}}", sqlTableUsers)
initialSQL = strings.ReplaceAll(initialSQL, "{{folders_mapping}}", sqlTableFoldersMapping)
initialSQL = strings.ReplaceAll(initialSQL, "{{api_keys}}", sqlTableAPIKeys)
initialSQL = strings.ReplaceAll(initialSQL, "{{shares}}", sqlTableShares)
initialSQL = strings.ReplaceAll(initialSQL, "{{defender_events}}", sqlTableDefenderEvents)
initialSQL = strings.ReplaceAll(initialSQL, "{{defender_hosts}}", sqlTableDefenderHosts)
initialSQL = strings.ReplaceAll(initialSQL, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(initialSQL, ";"), 19, true)
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(initialSQL, ";"), 15, true)
}
func (p *MySQLProvider) migrateDatabase() error { //nolint:dupl
@@ -712,28 +599,28 @@ func (p *MySQLProvider) migrateDatabase() error { //nolint:dupl
case version == sqlDatabaseVersion:
providerLog(logger.LevelDebug, "sql database is up to date, current version: %v", version)
return ErrNoInitRequired
case version < 19:
err = fmt.Errorf("database schema version %v is too old, please see the upgrading docs", version)
case version < 15:
err = fmt.Errorf("database version %v is too old, please see the upgrading docs", version)
providerLog(logger.LevelError, "%v", err)
logger.ErrorToConsole("%v", err)
return err
case version == 19:
return updateMySQLDatabaseFromV19(p.dbHandle)
case version == 20:
return updateMySQLDatabaseFromV20(p.dbHandle)
case version == 21:
return updateMySQLDatabaseFromV21(p.dbHandle)
case version == 22:
return updateMySQLDatabaseFromV22(p.dbHandle)
case version == 15:
return updateMySQLDatabaseFromV15(p.dbHandle)
case version == 16:
return updateMySQLDatabaseFromV16(p.dbHandle)
case version == 17:
return updateMySQLDatabaseFromV17(p.dbHandle)
case version == 18:
return updateMySQLDatabaseFromV18(p.dbHandle)
default:
if version > sqlDatabaseVersion {
providerLog(logger.LevelError, "database schema version %v is newer than the supported one: %v", version,
providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version,
sqlDatabaseVersion)
logger.WarnToConsole("database schema version %v is newer than the supported one: %v", version,
logger.WarnToConsole("database version %v is newer than the supported one: %v", version,
sqlDatabaseVersion)
return nil
}
return fmt.Errorf("database schema version not handled: %v", version)
return fmt.Errorf("database version not handled: %v", version)
}
}
@@ -747,16 +634,16 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error {
}
switch dbVersion.Version {
case 20:
return downgradeMySQLDatabaseFromV20(p.dbHandle)
case 21:
return downgradeMySQLDatabaseFromV21(p.dbHandle)
case 22:
return downgradeMySQLDatabaseFromV22(p.dbHandle)
case 23:
return downgradeMySQLDatabaseFromV23(p.dbHandle)
case 16:
return downgradeMySQLDatabaseFromV16(p.dbHandle)
case 17:
return downgradeMySQLDatabaseFromV17(p.dbHandle)
case 18:
return downgradeMySQLDatabaseFromV18(p.dbHandle)
case 19:
return downgradeMySQLDatabaseFromV19(p.dbHandle)
default:
return fmt.Errorf("database schema version not handled: %v", dbVersion.Version)
return fmt.Errorf("database version not handled: %v", dbVersion.Version)
}
}
@@ -765,121 +652,127 @@ func (p *MySQLProvider) resetDatabase() error {
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(sql, ";"), 0, false)
}
func updateMySQLDatabaseFromV19(dbHandle *sql.DB) error {
if err := updateMySQLDatabaseFrom19To20(dbHandle); err != nil {
func updateMySQLDatabaseFromV15(dbHandle *sql.DB) error {
if err := updateMySQLDatabaseFrom15To16(dbHandle); err != nil {
return err
}
return updateMySQLDatabaseFromV20(dbHandle)
return updateMySQLDatabaseFromV16(dbHandle)
}
func updateMySQLDatabaseFromV20(dbHandle *sql.DB) error {
if err := updateMySQLDatabaseFrom20To21(dbHandle); err != nil {
func updateMySQLDatabaseFromV16(dbHandle *sql.DB) error {
if err := updateMySQLDatabaseFrom16To17(dbHandle); err != nil {
return err
}
return updateMySQLDatabaseFromV21(dbHandle)
return updateMySQLDatabaseFromV17(dbHandle)
}
func updateMySQLDatabaseFromV21(dbHandle *sql.DB) error {
if err := updateMySQLDatabaseFrom21To22(dbHandle); err != nil {
func updateMySQLDatabaseFromV17(dbHandle *sql.DB) error {
if err := updateMySQLDatabaseFrom17To18(dbHandle); err != nil {
return err
}
return updateMySQLDatabaseFromV22(dbHandle)
return updateMySQLDatabaseFromV18(dbHandle)
}
func updateMySQLDatabaseFromV22(dbHandle *sql.DB) error {
return updateMySQLDatabaseFrom22To23(dbHandle)
func updateMySQLDatabaseFromV18(dbHandle *sql.DB) error {
return updateMySQLDatabaseFrom18To19(dbHandle)
}
func downgradeMySQLDatabaseFromV20(dbHandle *sql.DB) error {
return downgradeMySQLDatabaseFrom20To19(dbHandle)
func downgradeMySQLDatabaseFromV16(dbHandle *sql.DB) error {
return downgradeMySQLDatabaseFrom16To15(dbHandle)
}
func downgradeMySQLDatabaseFromV21(dbHandle *sql.DB) error {
if err := downgradeMySQLDatabaseFrom21To20(dbHandle); err != nil {
func downgradeMySQLDatabaseFromV17(dbHandle *sql.DB) error {
if err := downgradeMySQLDatabaseFrom17To16(dbHandle); err != nil {
return err
}
return downgradeMySQLDatabaseFromV20(dbHandle)
return downgradeMySQLDatabaseFromV16(dbHandle)
}
func downgradeMySQLDatabaseFromV22(dbHandle *sql.DB) error {
if err := downgradeMySQLDatabaseFrom22To21(dbHandle); err != nil {
func downgradeMySQLDatabaseFromV18(dbHandle *sql.DB) error {
if err := downgradeMySQLDatabaseFrom18To17(dbHandle); err != nil {
return err
}
return downgradeMySQLDatabaseFromV21(dbHandle)
return downgradeMySQLDatabaseFromV17(dbHandle)
}
func downgradeMySQLDatabaseFromV23(dbHandle *sql.DB) error {
if err := downgradeMySQLDatabaseFrom23To22(dbHandle); err != nil {
func downgradeMySQLDatabaseFromV19(dbHandle *sql.DB) error {
if err := downgradeMySQLDatabaseFrom19To18(dbHandle); err != nil {
return err
}
return downgradeMySQLDatabaseFromV22(dbHandle)
return downgradeMySQLDatabaseFromV18(dbHandle)
}
func updateMySQLDatabaseFrom19To20(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 19 -> 20")
providerLog(logger.LevelInfo, "updating database schema version: 19 -> 20")
sql := strings.ReplaceAll(mysqlV20SQL, "{{events_actions}}", sqlTableEventsActions)
sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
func updateMySQLDatabaseFrom15To16(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 15 -> 16")
providerLog(logger.LevelInfo, "updating database version: 15 -> 16")
sql := strings.ReplaceAll(mysqlV16SQL, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 20, true)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 16, true)
}
func updateMySQLDatabaseFrom20To21(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 20 -> 21")
providerLog(logger.LevelInfo, "updating database schema version: 20 -> 21")
sql := strings.ReplaceAll(mysqlV21SQL, "{{users}}", sqlTableUsers)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 21, true)
}
func updateMySQLDatabaseFrom21To22(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 21 -> 22")
providerLog(logger.LevelInfo, "updating database schema version: 21 -> 22")
sql := strings.ReplaceAll(mysqlV22SQL, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping)
sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins)
func updateMySQLDatabaseFrom16To17(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 16 -> 17")
providerLog(logger.LevelInfo, "updating database version: 16 -> 17")
sql := strings.ReplaceAll(mysqlV17SQL, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 22, true)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 17, true)
}
func updateMySQLDatabaseFrom22To23(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 22 -> 23")
providerLog(logger.LevelInfo, "updating database schema version: 22 -> 23")
sql := strings.ReplaceAll(mysqlV23SQL, "{{nodes}}", sqlTableNodes)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 23, true)
func updateMySQLDatabaseFrom17To18(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 17 -> 18")
providerLog(logger.LevelInfo, "updating database version: 17 -> 18")
if err := importGCSCredentials(); err != nil {
return err
}
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 18, true)
}
func downgradeMySQLDatabaseFrom20To19(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 20 -> 19")
providerLog(logger.LevelInfo, "downgrading database schema version: 20 -> 19")
sql := strings.ReplaceAll(mysqlV20DownSQL, "{{events_actions}}", sqlTableEventsActions)
sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 19, false)
}
func downgradeMySQLDatabaseFrom21To20(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 21 -> 20")
providerLog(logger.LevelInfo, "downgrading database schema version: 21 -> 20")
sql := strings.ReplaceAll(mysqlV21DownSQL, "{{users}}", sqlTableUsers)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 20, false)
}
func downgradeMySQLDatabaseFrom22To21(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 22 -> 21")
providerLog(logger.LevelInfo, "downgrading database schema version: 22 -> 21")
sql := strings.ReplaceAll(mysqlV22DownSQL, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping)
func updateMySQLDatabaseFrom18To19(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 18 -> 19")
providerLog(logger.LevelInfo, "updating database version: 18 -> 19")
sql := strings.ReplaceAll(mysqlV19SQL, "{{shared_sessions}}", sqlTableSharedSessions)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 21, false)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 19, true)
}
func downgradeMySQLDatabaseFrom23To22(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 23 -> 22")
providerLog(logger.LevelInfo, "downgrading database schema version: 23 -> 22")
sql := strings.ReplaceAll(mysqlV23DownSQL, "{{nodes}}", sqlTableNodes)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 22, false)
func downgradeMySQLDatabaseFrom16To15(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 16 -> 15")
providerLog(logger.LevelInfo, "downgrading database version: 16 -> 15")
sql := strings.ReplaceAll(mysqlV16DownSQL, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 15, false)
}
func downgradeMySQLDatabaseFrom17To16(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 17 -> 16")
providerLog(logger.LevelInfo, "downgrading database version: 17 -> 16")
sql := strings.ReplaceAll(mysqlV17DownSQL, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 16, false)
}
func downgradeMySQLDatabaseFrom18To17(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 18 -> 17")
providerLog(logger.LevelInfo, "downgrading database version: 18 -> 17")
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 17, false)
}
func downgradeMySQLDatabaseFrom19To18(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 19 -> 18")
providerLog(logger.LevelInfo, "downgrading database version: 19 -> 18")
sql := strings.ReplaceAll(mysqlV19DownSQL, "{{shared_sessions}}", sqlTableSharedSessions)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 18, false)
}

View File

@@ -20,7 +20,7 @@ package dataprovider
import (
"errors"
"github.com/drakkan/sftpgo/v2/internal/version"
"github.com/drakkan/sftpgo/v2/version"
)
func init() {

View File

@@ -26,12 +26,12 @@ import (
"strings"
"time"
// we import pgx here to be able to disable PostgreSQL support using a build tag
_ "github.com/jackc/pgx/v5/stdlib"
// we import lib/pq here to be able to disable PostgreSQL support using a build tag
_ "github.com/lib/pq"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/version"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/version"
"github.com/drakkan/sftpgo/v2/vfs"
)
const (
@@ -39,7 +39,6 @@ const (
DROP TABLE IF EXISTS "{{folders_mapping}}" CASCADE;
DROP TABLE IF EXISTS "{{users_folders_mapping}}" CASCADE;
DROP TABLE IF EXISTS "{{users_groups_mapping}}" CASCADE;
DROP TABLE IF EXISTS "{{admins_groups_mapping}}" CASCADE;
DROP TABLE IF EXISTS "{{groups_folders_mapping}}" CASCADE;
DROP TABLE IF EXISTS "{{admins}}" CASCADE;
DROP TABLE IF EXISTS "{{folders}}" CASCADE;
@@ -50,11 +49,6 @@ DROP TABLE IF EXISTS "{{defender_events}}" CASCADE;
DROP TABLE IF EXISTS "{{defender_hosts}}" CASCADE;
DROP TABLE IF EXISTS "{{active_transfers}}" CASCADE;
DROP TABLE IF EXISTS "{{shared_sessions}}" CASCADE;
DROP TABLE IF EXISTS "{{rules_actions_mapping}}" CASCADE;
DROP TABLE IF EXISTS "{{events_actions}}" CASCADE;
DROP TABLE IF EXISTS "{{events_rules}}" CASCADE;
DROP TABLE IF EXISTS "{{tasks}}" CASCADE;
DROP TABLE IF EXISTS "{{nodes}}" CASCADE;
DROP TABLE IF EXISTS "{{schema_version}}" CASCADE;
`
pgsqlInitial = `CREATE TABLE "{{schema_version}}" ("id" serial NOT NULL PRIMARY KEY, "version" integer NOT NULL);
@@ -62,11 +56,6 @@ CREATE TABLE "{{admins}}" ("id" serial NOT NULL PRIMARY KEY, "username" varchar(
"description" varchar(512) NULL, "password" varchar(255) NOT NULL, "email" varchar(255) NULL, "status" integer NOT NULL,
"permissions" text NOT NULL, "filters" text NULL, "additional_info" text NULL, "last_login" bigint NOT NULL,
"created_at" bigint NOT NULL, "updated_at" bigint NOT NULL);
CREATE TABLE "{{active_transfers}}" ("id" bigserial NOT NULL PRIMARY KEY, "connection_id" varchar(100) NOT NULL,
"transfer_id" bigint NOT NULL, "transfer_type" integer NOT NULL, "username" varchar(255) NOT NULL,
"folder_name" varchar(255) NULL, "ip" varchar(50) NOT NULL, "truncated_size" bigint NOT NULL,
"current_ul_size" bigint NOT NULL, "current_dl_size" bigint NOT NULL, "created_at" bigint NOT NULL,
"updated_at" bigint NOT NULL);
CREATE TABLE "{{defender_hosts}}" ("id" bigserial NOT NULL PRIMARY KEY, "ip" varchar(50) NOT NULL UNIQUE,
"ban_time" bigint NOT NULL, "updated_at" bigint NOT NULL);
CREATE TABLE "{{defender_events}}" ("id" bigserial NOT NULL PRIMARY KEY, "date_time" bigint NOT NULL, "score" integer NOT NULL,
@@ -76,29 +65,19 @@ ALTER TABLE "{{defender_events}}" ADD CONSTRAINT "{{prefix}}defender_events_host
CREATE TABLE "{{folders}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL,
"path" text NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL,
"filesystem" text NULL);
CREATE TABLE "{{groups}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE,
"description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "user_settings" text NULL);
CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL PRIMARY KEY,
"data" text NOT NULL, "type" integer NOT NULL, "timestamp" bigint NOT NULL);
CREATE TABLE "{{users}}" ("id" serial NOT NULL PRIMARY KEY, "username" varchar(255) NOT NULL UNIQUE, "status" integer NOT NULL,
"expiration_date" bigint NOT NULL, "description" varchar(512) NULL, "password" text NULL, "public_keys" text NULL,
"home_dir" text NOT NULL, "uid" bigint NOT NULL, "gid" bigint NOT NULL, "max_sessions" integer NOT NULL,
"quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "permissions" text NOT NULL, "used_quota_size" bigint NOT NULL,
"used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL, "upload_bandwidth" integer NOT NULL,
"download_bandwidth" integer NOT NULL, "last_login" bigint NOT NULL, "filters" text NULL, "filesystem" text NULL,
"additional_info" text NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "email" varchar(255) NULL,
"upload_data_transfer" integer NOT NULL, "download_data_transfer" integer NOT NULL, "total_data_transfer" integer NOT NULL,
"used_upload_data_transfer" integer NOT NULL, "used_download_data_transfer" integer NOT NULL);
CREATE TABLE "{{groups_folders_mapping}}" ("id" serial NOT NULL PRIMARY KEY, "group_id" integer NOT NULL,
"folder_id" integer NOT NULL, "virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL);
CREATE TABLE "{{users_groups_mapping}}" ("id" serial NOT NULL PRIMARY KEY, "user_id" integer NOT NULL,
"group_id" integer NOT NULL, "group_type" integer NOT NULL);
CREATE TABLE "{{users_folders_mapping}}" ("id" serial NOT NULL PRIMARY KEY, "virtual_path" text NOT NULL,
"additional_info" text NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "email" varchar(255) NULL);
CREATE TABLE "{{folders_mapping}}" ("id" serial NOT NULL PRIMARY KEY, "virtual_path" text NOT NULL,
"quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "folder_id" integer NOT NULL, "user_id" integer NOT NULL);
ALTER TABLE "{{users_folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_user_folder_mapping" UNIQUE ("user_id", "folder_id");
ALTER TABLE "{{users_folders_mapping}}" ADD CONSTRAINT "{{prefix}}users_folders_mapping_folder_id_fk_folders_id"
ALTER TABLE "{{folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_mapping" UNIQUE ("user_id", "folder_id");
ALTER TABLE "{{folders_mapping}}" ADD CONSTRAINT "{{prefix}}folders_mapping_folder_id_fk_folders_id"
FOREIGN KEY ("folder_id") REFERENCES "{{folders}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
ALTER TABLE "{{users_folders_mapping}}" ADD CONSTRAINT "{{prefix}}users_folders_mapping_user_id_fk_users_id"
ALTER TABLE "{{folders_mapping}}" ADD CONSTRAINT "{{prefix}}folders_mapping_user_id_fk_users_id"
FOREIGN KEY ("user_id") REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
CREATE TABLE "{{shares}}" ("id" serial NOT NULL PRIMARY KEY,
"share_id" varchar(60) NOT NULL UNIQUE, "name" varchar(255) NOT NULL, "description" varchar(512) NULL,
@@ -116,6 +95,57 @@ ALTER TABLE "{{api_keys}}" ADD CONSTRAINT "{{prefix}}api_keys_admin_id_fk_admins
REFERENCES "{{admins}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
ALTER TABLE "{{api_keys}}" ADD CONSTRAINT "{{prefix}}api_keys_user_id_fk_users_id" FOREIGN KEY ("user_id")
REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
CREATE INDEX "{{prefix}}folders_mapping_folder_id_idx" ON "{{folders_mapping}}" ("folder_id");
CREATE INDEX "{{prefix}}folders_mapping_user_id_idx" ON "{{folders_mapping}}" ("user_id");
CREATE INDEX "{{prefix}}api_keys_admin_id_idx" ON "{{api_keys}}" ("admin_id");
CREATE INDEX "{{prefix}}api_keys_user_id_idx" ON "{{api_keys}}" ("user_id");
CREATE INDEX "{{prefix}}users_updated_at_idx" ON "{{users}}" ("updated_at");
CREATE INDEX "{{prefix}}shares_user_id_idx" ON "{{shares}}" ("user_id");
CREATE INDEX "{{prefix}}defender_hosts_updated_at_idx" ON "{{defender_hosts}}" ("updated_at");
CREATE INDEX "{{prefix}}defender_hosts_ban_time_idx" ON "{{defender_hosts}}" ("ban_time");
CREATE INDEX "{{prefix}}defender_events_date_time_idx" ON "{{defender_events}}" ("date_time");
CREATE INDEX "{{prefix}}defender_events_host_id_idx" ON "{{defender_events}}" ("host_id");
INSERT INTO {{schema_version}} (version) VALUES (15);
`
pgsqlV16SQL = `ALTER TABLE "{{users}}" ADD COLUMN "download_data_transfer" integer DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "download_data_transfer" DROP DEFAULT;
ALTER TABLE "{{users}}" ADD COLUMN "total_data_transfer" integer DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "total_data_transfer" DROP DEFAULT;
ALTER TABLE "{{users}}" ADD COLUMN "upload_data_transfer" integer DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "upload_data_transfer" DROP DEFAULT;
ALTER TABLE "{{users}}" ADD COLUMN "used_download_data_transfer" integer DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "used_download_data_transfer" DROP DEFAULT;
ALTER TABLE "{{users}}" ADD COLUMN "used_upload_data_transfer" integer DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "used_upload_data_transfer" DROP DEFAULT;
CREATE TABLE "{{active_transfers}}" ("id" bigserial NOT NULL PRIMARY KEY, "connection_id" varchar(100) NOT NULL,
"transfer_id" bigint NOT NULL, "transfer_type" integer NOT NULL, "username" varchar(255) NOT NULL,
"folder_name" varchar(255) NULL, "ip" varchar(50) NOT NULL, "truncated_size" bigint NOT NULL,
"current_ul_size" bigint NOT NULL, "current_dl_size" bigint NOT NULL, "created_at" bigint NOT NULL,
"updated_at" bigint NOT NULL);
CREATE INDEX "{{prefix}}active_transfers_connection_id_idx" ON "{{active_transfers}}" ("connection_id");
CREATE INDEX "{{prefix}}active_transfers_transfer_id_idx" ON "{{active_transfers}}" ("transfer_id");
CREATE INDEX "{{prefix}}active_transfers_updated_at_idx" ON "{{active_transfers}}" ("updated_at");
`
pgsqlV16DownSQL = `ALTER TABLE "{{users}}" DROP COLUMN "used_upload_data_transfer" CASCADE;
ALTER TABLE "{{users}}" DROP COLUMN "used_download_data_transfer" CASCADE;
ALTER TABLE "{{users}}" DROP COLUMN "upload_data_transfer" CASCADE;
ALTER TABLE "{{users}}" DROP COLUMN "total_data_transfer" CASCADE;
ALTER TABLE "{{users}}" DROP COLUMN "download_data_transfer" CASCADE;
DROP TABLE "{{active_transfers}}" CASCADE;
`
pgsqlV17SQL = `CREATE TABLE "{{groups}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE,
"description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "user_settings" text NULL);
CREATE TABLE "{{groups_folders_mapping}}" ("id" serial NOT NULL PRIMARY KEY, "group_id" integer NOT NULL,
"folder_id" integer NOT NULL, "virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL);
CREATE TABLE "{{users_groups_mapping}}" ("id" serial NOT NULL PRIMARY KEY, "user_id" integer NOT NULL,
"group_id" integer NOT NULL, "group_type" integer NOT NULL);
DROP INDEX "{{prefix}}folders_mapping_folder_id_idx";
DROP INDEX "{{prefix}}folders_mapping_user_id_idx";
ALTER TABLE "{{folders_mapping}}" DROP CONSTRAINT "{{prefix}}unique_mapping";
ALTER TABLE "{{folders_mapping}}" RENAME TO "{{users_folders_mapping}}";
ALTER TABLE "{{users_folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_user_folder_mapping" UNIQUE ("user_id", "folder_id");
CREATE INDEX "{{prefix}}users_folders_mapping_folder_id_idx" ON "{{users_folders_mapping}}" ("folder_id");
CREATE INDEX "{{prefix}}users_folders_mapping_user_id_idx" ON "{{users_folders_mapping}}" ("user_id");
ALTER TABLE "{{users_groups_mapping}}" ADD CONSTRAINT "{{prefix}}unique_user_group_mapping" UNIQUE ("user_id", "group_id");
ALTER TABLE "{{groups_folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_group_folder_mapping" UNIQUE ("group_id", "folder_id");
CREATE INDEX "{{prefix}}users_groups_mapping_group_id_idx" ON "{{users_groups_mapping}}" ("group_id");
@@ -131,77 +161,23 @@ CREATE INDEX "{{prefix}}groups_folders_mapping_group_id_idx" ON "{{groups_folder
ALTER TABLE "{{groups_folders_mapping}}" ADD CONSTRAINT "{{prefix}}groups_folders_mapping_group_id_fk_groups_id"
FOREIGN KEY ("group_id") REFERENCES "{{groups}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
CREATE INDEX "{{prefix}}groups_updated_at_idx" ON "{{groups}}" ("updated_at");
CREATE INDEX "{{prefix}}users_folders_mapping_folder_id_idx" ON "{{users_folders_mapping}}" ("folder_id");
CREATE INDEX "{{prefix}}users_folders_mapping_user_id_idx" ON "{{users_folders_mapping}}" ("user_id");
CREATE INDEX "{{prefix}}api_keys_admin_id_idx" ON "{{api_keys}}" ("admin_id");
CREATE INDEX "{{prefix}}api_keys_user_id_idx" ON "{{api_keys}}" ("user_id");
CREATE INDEX "{{prefix}}users_updated_at_idx" ON "{{users}}" ("updated_at");
CREATE INDEX "{{prefix}}shares_user_id_idx" ON "{{shares}}" ("user_id");
CREATE INDEX "{{prefix}}defender_hosts_updated_at_idx" ON "{{defender_hosts}}" ("updated_at");
CREATE INDEX "{{prefix}}defender_hosts_ban_time_idx" ON "{{defender_hosts}}" ("ban_time");
CREATE INDEX "{{prefix}}defender_events_date_time_idx" ON "{{defender_events}}" ("date_time");
CREATE INDEX "{{prefix}}defender_events_host_id_idx" ON "{{defender_events}}" ("host_id");
CREATE INDEX "{{prefix}}active_transfers_connection_id_idx" ON "{{active_transfers}}" ("connection_id");
CREATE INDEX "{{prefix}}active_transfers_transfer_id_idx" ON "{{active_transfers}}" ("transfer_id");
CREATE INDEX "{{prefix}}active_transfers_updated_at_idx" ON "{{active_transfers}}" ("updated_at");
`
pgsqlV17DownSQL = `DROP TABLE "{{users_groups_mapping}}" CASCADE;
DROP TABLE "{{groups_folders_mapping}}" CASCADE;
DROP TABLE "{{groups}}" CASCADE;
DROP INDEX "{{prefix}}users_folders_mapping_folder_id_idx";
DROP INDEX "{{prefix}}users_folders_mapping_user_id_idx";
ALTER TABLE "{{users_folders_mapping}}" DROP CONSTRAINT "{{prefix}}unique_user_folder_mapping";
ALTER TABLE "{{users_folders_mapping}}" RENAME TO "{{folders_mapping}}";
ALTER TABLE "{{folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_mapping" UNIQUE ("user_id", "folder_id");
CREATE INDEX "{{prefix}}folders_mapping_folder_id_idx" ON "{{folders_mapping}}" ("folder_id");
CREATE INDEX "{{prefix}}folders_mapping_user_id_idx" ON "{{folders_mapping}}" ("user_id");
`
pgsqlV19SQL = `CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL PRIMARY KEY,
"data" text NOT NULL, "type" integer NOT NULL, "timestamp" bigint NOT NULL);
CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type");
CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp");
INSERT INTO {{schema_version}} (version) VALUES (19);
`
pgsqlV20SQL = `CREATE TABLE "{{events_rules}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE,
"description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "trigger" integer NOT NULL,
"conditions" text NOT NULL, "deleted_at" bigint NOT NULL);
CREATE TABLE "{{events_actions}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE,
"description" varchar(512) NULL, "type" integer NOT NULL, "options" text NOT NULL);
CREATE TABLE "{{rules_actions_mapping}}" ("id" serial NOT NULL PRIMARY KEY, "rule_id" integer NOT NULL,
"action_id" integer NOT NULL, "order" integer NOT NULL, "options" text NOT NULL);
CREATE TABLE "{{tasks}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, "updated_at" bigint NOT NULL,
"version" bigint NOT NULL);
ALTER TABLE "{{rules_actions_mapping}}" ADD CONSTRAINT "{{prefix}}unique_rule_action_mapping" UNIQUE ("rule_id", "action_id");
ALTER TABLE "{{rules_actions_mapping}}" ADD CONSTRAINT "{{prefix}}rules_actions_mapping_rule_id_fk_events_rules_id"
FOREIGN KEY ("rule_id") REFERENCES "{{events_rules}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
ALTER TABLE "{{rules_actions_mapping}}" ADD CONSTRAINT "{{prefix}}rules_actions_mapping_action_id_fk_events_targets_id"
FOREIGN KEY ("action_id") REFERENCES "{{events_actions}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE NO ACTION;
ALTER TABLE "{{users}}" ADD COLUMN "deleted_at" bigint DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "deleted_at" DROP DEFAULT;
CREATE INDEX "{{prefix}}events_rules_updated_at_idx" ON "{{events_rules}}" ("updated_at");
CREATE INDEX "{{prefix}}events_rules_deleted_at_idx" ON "{{events_rules}}" ("deleted_at");
CREATE INDEX "{{prefix}}events_rules_trigger_idx" ON "{{events_rules}}" ("trigger");
CREATE INDEX "{{prefix}}rules_actions_mapping_rule_id_idx" ON "{{rules_actions_mapping}}" ("rule_id");
CREATE INDEX "{{prefix}}rules_actions_mapping_action_id_idx" ON "{{rules_actions_mapping}}" ("action_id");
CREATE INDEX "{{prefix}}rules_actions_mapping_order_idx" ON "{{rules_actions_mapping}}" ("order");
CREATE INDEX "{{prefix}}users_deleted_at_idx" ON "{{users}}" ("deleted_at");
`
pgsqlV20DownSQL = `DROP TABLE "{{rules_actions_mapping}}" CASCADE;
DROP TABLE "{{events_rules}}" CASCADE;
DROP TABLE "{{events_actions}}" CASCADE;
DROP TABLE "{{tasks}}" CASCADE;
ALTER TABLE "{{users}}" DROP COLUMN "deleted_at" CASCADE;
`
pgsqlV21SQL = `ALTER TABLE "{{users}}" ADD COLUMN "first_download" bigint DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "first_download" DROP DEFAULT;
ALTER TABLE "{{users}}" ADD COLUMN "first_upload" bigint DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "first_upload" DROP DEFAULT;
`
pgsqlV21DownSQL = `ALTER TABLE "{{users}}" DROP COLUMN "first_upload" CASCADE;
ALTER TABLE "{{users}}" DROP COLUMN "first_download" CASCADE;
`
pgsqlV22SQL = `CREATE TABLE "{{admins_groups_mapping}}" ("id" serial NOT NULL PRIMARY KEY,
"admin_id" integer NOT NULL, "group_id" integer NOT NULL, "options" text NOT NULL);
ALTER TABLE "{{admins_groups_mapping}}" ADD CONSTRAINT "{{prefix}}unique_admin_group_mapping" UNIQUE ("admin_id", "group_id");
ALTER TABLE "{{admins_groups_mapping}}" ADD CONSTRAINT "{{prefix}}admins_groups_mapping_admin_id_fk_admins_id"
FOREIGN KEY ("admin_id") REFERENCES "{{admins}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
ALTER TABLE "{{admins_groups_mapping}}" ADD CONSTRAINT "{{prefix}}admins_groups_mapping_group_id_fk_groups_id"
FOREIGN KEY ("group_id") REFERENCES "{{groups}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
CREATE INDEX "{{prefix}}admins_groups_mapping_admin_id_idx" ON "{{admins_groups_mapping}}" ("admin_id");
CREATE INDEX "{{prefix}}admins_groups_mapping_group_id_idx" ON "{{admins_groups_mapping}}" ("group_id");
`
pgsqlV22DownSQL = `ALTER TABLE "{{admins_groups_mapping}}" DROP CONSTRAINT "{{prefix}}unique_admin_group_mapping";
DROP TABLE "{{admins_groups_mapping}}" CASCADE;
`
pgsqlV23SQL = `CREATE TABLE "{{nodes}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE,
"data" text NOT NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL);`
pgsqlV23DownSQL = `DROP TABLE "{{nodes}}" CASCADE;`
CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp");`
pgsqlV19DownSQL = `DROP TABLE "{{shared_sessions}}" CASCADE;`
)
// PGSQLProvider defines the auth provider for PostgreSQL database
@@ -215,7 +191,7 @@ func init() {
func initializePGSQLProvider() error {
var err error
dbHandle, err := sql.Open("pgx", getPGSQLConnectionString(false))
dbHandle, err := sql.Open("postgres", getPGSQLConnectionString(false))
if err == nil {
providerLog(logger.LevelDebug, "postgres database handle created, connection string: %#v, pool size: %v",
getPGSQLConnectionString(true), config.PoolSize)
@@ -226,7 +202,6 @@ func initializePGSQLProvider() error {
dbHandle.SetMaxIdleConns(2)
}
dbHandle.SetConnMaxLifetime(240 * time.Second)
dbHandle.SetConnMaxIdleTime(120 * time.Second)
provider = &PGSQLProvider{dbHandle: dbHandle}
} else {
providerLog(logger.LevelError, "error creating postgres database handler, connection string: %#v, error: %v",
@@ -250,12 +225,6 @@ func getPGSQLConnectionString(redactedPwd bool) string {
if config.ClientCert != "" && config.ClientKey != "" {
connectionString += fmt.Sprintf(" sslcert='%v' sslkey='%v'", config.ClientCert, config.ClientKey)
}
if config.DisableSNI {
connectionString += " sslsni=0"
}
if config.TargetSessionAttrs != "" {
connectionString += fmt.Sprintf(" target_session_attrs='%s'", config.TargetSessionAttrs)
}
} else {
connectionString = config.ConnectionString
}
@@ -314,8 +283,8 @@ func (p *PGSQLProvider) updateUser(user *User) error {
return sqlCommonUpdateUser(user, p.dbHandle)
}
func (p *PGSQLProvider) deleteUser(user User, softDelete bool) error {
return sqlCommonDeleteUser(user, softDelete, p.dbHandle)
func (p *PGSQLProvider) deleteUser(user User) error {
return sqlCommonDeleteUser(user, p.dbHandle)
}
func (p *PGSQLProvider) updateUserPassword(username, password string) error {
@@ -556,102 +525,6 @@ func (p *PGSQLProvider) cleanupSharedSessions(sessionType SessionType, before in
return sqlCommonCleanupSessions(sessionType, before, p.dbHandle)
}
func (p *PGSQLProvider) getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) {
return sqlCommonGetEventActions(limit, offset, order, minimal, p.dbHandle)
}
func (p *PGSQLProvider) dumpEventActions() ([]BaseEventAction, error) {
return sqlCommonDumpEventActions(p.dbHandle)
}
func (p *PGSQLProvider) eventActionExists(name string) (BaseEventAction, error) {
return sqlCommonGetEventActionByName(name, p.dbHandle)
}
func (p *PGSQLProvider) addEventAction(action *BaseEventAction) error {
return sqlCommonAddEventAction(action, p.dbHandle)
}
func (p *PGSQLProvider) updateEventAction(action *BaseEventAction) error {
return sqlCommonUpdateEventAction(action, p.dbHandle)
}
func (p *PGSQLProvider) deleteEventAction(action BaseEventAction) error {
return sqlCommonDeleteEventAction(action, p.dbHandle)
}
func (p *PGSQLProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) {
return sqlCommonGetEventRules(limit, offset, order, p.dbHandle)
}
func (p *PGSQLProvider) dumpEventRules() ([]EventRule, error) {
return sqlCommonDumpEventRules(p.dbHandle)
}
func (p *PGSQLProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) {
return sqlCommonGetRecentlyUpdatedRules(after, p.dbHandle)
}
func (p *PGSQLProvider) eventRuleExists(name string) (EventRule, error) {
return sqlCommonGetEventRuleByName(name, p.dbHandle)
}
func (p *PGSQLProvider) addEventRule(rule *EventRule) error {
return sqlCommonAddEventRule(rule, p.dbHandle)
}
func (p *PGSQLProvider) updateEventRule(rule *EventRule) error {
return sqlCommonUpdateEventRule(rule, p.dbHandle)
}
func (p *PGSQLProvider) deleteEventRule(rule EventRule, softDelete bool) error {
return sqlCommonDeleteEventRule(rule, softDelete, p.dbHandle)
}
func (p *PGSQLProvider) getTaskByName(name string) (Task, error) {
return sqlCommonGetTaskByName(name, p.dbHandle)
}
func (p *PGSQLProvider) addTask(name string) error {
return sqlCommonAddTask(name, p.dbHandle)
}
func (p *PGSQLProvider) updateTask(name string, version int64) error {
return sqlCommonUpdateTask(name, version, p.dbHandle)
}
func (p *PGSQLProvider) updateTaskTimestamp(name string) error {
return sqlCommonUpdateTaskTimestamp(name, p.dbHandle)
}
func (p *PGSQLProvider) addNode() error {
return sqlCommonAddNode(p.dbHandle)
}
func (p *PGSQLProvider) getNodeByName(name string) (Node, error) {
return sqlCommonGetNodeByName(name, p.dbHandle)
}
func (p *PGSQLProvider) getNodes() ([]Node, error) {
return sqlCommonGetNodes(p.dbHandle)
}
func (p *PGSQLProvider) updateNodeTimestamp() error {
return sqlCommonUpdateNodeTimestamp(p.dbHandle)
}
func (p *PGSQLProvider) cleanupNodes() error {
return sqlCommonCleanupNodes(p.dbHandle)
}
func (p *PGSQLProvider) setFirstDownloadTimestamp(username string) error {
return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle)
}
func (p *PGSQLProvider) setFirstUploadTimestamp(username string) error {
return sqlCommonSetFirstUploadTimestamp(username, p.dbHandle)
}
func (p *PGSQLProvider) close() error {
return p.dbHandle.Close()
}
@@ -669,11 +542,26 @@ func (p *PGSQLProvider) initializeDatabase() error {
if errors.Is(err, sql.ErrNoRows) {
return errSchemaVersionEmpty
}
logger.InfoToConsole("creating initial database schema, version 19")
providerLog(logger.LevelInfo, "creating initial database schema, version 19")
initialSQL := sqlReplaceAll(pgsqlInitial)
logger.InfoToConsole("creating initial database schema, version 15")
providerLog(logger.LevelInfo, "creating initial database schema, version 15")
initialSQL := strings.ReplaceAll(pgsqlInitial, "{{schema_version}}", sqlTableSchemaVersion)
initialSQL = strings.ReplaceAll(initialSQL, "{{admins}}", sqlTableAdmins)
initialSQL = strings.ReplaceAll(initialSQL, "{{folders}}", sqlTableFolders)
initialSQL = strings.ReplaceAll(initialSQL, "{{users}}", sqlTableUsers)
initialSQL = strings.ReplaceAll(initialSQL, "{{folders_mapping}}", sqlTableFoldersMapping)
initialSQL = strings.ReplaceAll(initialSQL, "{{api_keys}}", sqlTableAPIKeys)
initialSQL = strings.ReplaceAll(initialSQL, "{{shares}}", sqlTableShares)
initialSQL = strings.ReplaceAll(initialSQL, "{{defender_events}}", sqlTableDefenderEvents)
initialSQL = strings.ReplaceAll(initialSQL, "{{defender_hosts}}", sqlTableDefenderHosts)
initialSQL = strings.ReplaceAll(initialSQL, "{{prefix}}", config.SQLTablesPrefix)
if config.Driver == CockroachDataProviderName {
// Cockroach does not support deferrable constraint validation, we don't need them,
// we keep these definitions for the PostgreSQL driver to avoid changes for users
// upgrading from old SFTPGo versions
initialSQL = strings.ReplaceAll(initialSQL, "DEFERRABLE INITIALLY DEFERRED", "")
}
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 19, true)
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 15, true)
}
func (p *PGSQLProvider) migrateDatabase() error { //nolint:dupl
@@ -686,28 +574,28 @@ func (p *PGSQLProvider) migrateDatabase() error { //nolint:dupl
case version == sqlDatabaseVersion:
providerLog(logger.LevelDebug, "sql database is up to date, current version: %v", version)
return ErrNoInitRequired
case version < 19:
err = fmt.Errorf("database schema version %v is too old, please see the upgrading docs", version)
case version < 15:
err = fmt.Errorf("database version %v is too old, please see the upgrading docs", version)
providerLog(logger.LevelError, "%v", err)
logger.ErrorToConsole("%v", err)
return err
case version == 19:
return updatePgSQLDatabaseFromV19(p.dbHandle)
case version == 20:
return updatePgSQLDatabaseFromV20(p.dbHandle)
case version == 21:
return updatePgSQLDatabaseFromV21(p.dbHandle)
case version == 22:
return updatePgSQLDatabaseFromV21(p.dbHandle)
case version == 15:
return updatePGSQLDatabaseFromV15(p.dbHandle)
case version == 16:
return updatePGSQLDatabaseFromV16(p.dbHandle)
case version == 17:
return updatePGSQLDatabaseFromV17(p.dbHandle)
case version == 18:
return updatePGSQLDatabaseFromV18(p.dbHandle)
default:
if version > sqlDatabaseVersion {
providerLog(logger.LevelError, "database schema version %v is newer than the supported one: %v", version,
providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version,
sqlDatabaseVersion)
logger.WarnToConsole("database schema version %v is newer than the supported one: %v", version,
logger.WarnToConsole("database version %v is newer than the supported one: %v", version,
sqlDatabaseVersion)
return nil
}
return fmt.Errorf("database schema version not handled: %v", version)
return fmt.Errorf("database version not handled: %v", version)
}
}
@@ -721,16 +609,16 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error {
}
switch dbVersion.Version {
case 20:
return downgradePgSQLDatabaseFromV20(p.dbHandle)
case 21:
return downgradePgSQLDatabaseFromV21(p.dbHandle)
case 22:
return downgradePgSQLDatabaseFromV22(p.dbHandle)
case 23:
return downgradePgSQLDatabaseFromV23(p.dbHandle)
case 16:
return downgradePGSQLDatabaseFromV16(p.dbHandle)
case 17:
return downgradePGSQLDatabaseFromV17(p.dbHandle)
case 18:
return downgradePGSQLDatabaseFromV18(p.dbHandle)
case 19:
return downgradePGSQLDatabaseFromV19(p.dbHandle)
default:
return fmt.Errorf("database schema version not handled: %v", dbVersion.Version)
return fmt.Errorf("database version not handled: %v", dbVersion.Version)
}
}
@@ -739,135 +627,153 @@ func (p *PGSQLProvider) resetDatabase() error {
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0, false)
}
func updatePgSQLDatabaseFromV19(dbHandle *sql.DB) error {
if err := updatePgSQLDatabaseFrom19To20(dbHandle); err != nil {
func updatePGSQLDatabaseFromV15(dbHandle *sql.DB) error {
if err := updatePGSQLDatabaseFrom15To16(dbHandle); err != nil {
return err
}
return updatePgSQLDatabaseFromV20(dbHandle)
return updatePGSQLDatabaseFromV16(dbHandle)
}
func updatePgSQLDatabaseFromV20(dbHandle *sql.DB) error {
if err := updatePgSQLDatabaseFrom20To21(dbHandle); err != nil {
func updatePGSQLDatabaseFromV16(dbHandle *sql.DB) error {
if err := updatePGSQLDatabaseFrom16To17(dbHandle); err != nil {
return err
}
return updatePgSQLDatabaseFromV21(dbHandle)
return updatePGSQLDatabaseFromV17(dbHandle)
}
func updatePgSQLDatabaseFromV21(dbHandle *sql.DB) error {
if err := updatePgSQLDatabaseFrom21To22(dbHandle); err != nil {
func updatePGSQLDatabaseFromV17(dbHandle *sql.DB) error {
if err := updatePGSQLDatabaseFrom17To18(dbHandle); err != nil {
return err
}
return updatePgSQLDatabaseFromV22(dbHandle)
return updatePGSQLDatabaseFromV18(dbHandle)
}
func updatePgSQLDatabaseFromV22(dbHandle *sql.DB) error {
return updatePgSQLDatabaseFrom22To23(dbHandle)
func updatePGSQLDatabaseFromV18(dbHandle *sql.DB) error {
return updatePGSQLDatabaseFrom18To19(dbHandle)
}
func downgradePgSQLDatabaseFromV20(dbHandle *sql.DB) error {
return downgradePgSQLDatabaseFrom20To19(dbHandle)
func downgradePGSQLDatabaseFromV16(dbHandle *sql.DB) error {
return downgradePGSQLDatabaseFrom16To15(dbHandle)
}
func downgradePgSQLDatabaseFromV21(dbHandle *sql.DB) error {
if err := downgradePgSQLDatabaseFrom21To20(dbHandle); err != nil {
func downgradePGSQLDatabaseFromV17(dbHandle *sql.DB) error {
if err := downgradePGSQLDatabaseFrom17To16(dbHandle); err != nil {
return err
}
return downgradePgSQLDatabaseFromV20(dbHandle)
return downgradePGSQLDatabaseFromV16(dbHandle)
}
func downgradePgSQLDatabaseFromV22(dbHandle *sql.DB) error {
if err := downgradePgSQLDatabaseFrom22To21(dbHandle); err != nil {
func downgradePGSQLDatabaseFromV18(dbHandle *sql.DB) error {
if err := downgradePGSQLDatabaseFrom18To17(dbHandle); err != nil {
return err
}
return downgradePgSQLDatabaseFromV21(dbHandle)
return downgradePGSQLDatabaseFromV17(dbHandle)
}
func downgradePgSQLDatabaseFromV23(dbHandle *sql.DB) error {
if err := downgradePgSQLDatabaseFrom23To22(dbHandle); err != nil {
func downgradePGSQLDatabaseFromV19(dbHandle *sql.DB) error {
if err := downgradePGSQLDatabaseFrom19To18(dbHandle); err != nil {
return err
}
return downgradePgSQLDatabaseFromV22(dbHandle)
return downgradePGSQLDatabaseFromV18(dbHandle)
}
func updatePgSQLDatabaseFrom19To20(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 19 -> 20")
providerLog(logger.LevelInfo, "updating database schema version: 19 -> 20")
sql := pgsqlV20SQL
if config.Driver == CockroachDataProviderName {
sql = strings.ReplaceAll(sql, `ALTER TABLE "{{users}}" ALTER COLUMN "deleted_at" DROP DEFAULT;`, "")
}
sql = strings.ReplaceAll(sql, "{{events_actions}}", sqlTableEventsActions)
sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
func updatePGSQLDatabaseFrom15To16(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 15 -> 16")
providerLog(logger.LevelInfo, "updating database version: 15 -> 16")
sql := strings.ReplaceAll(pgsqlV16SQL, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 20, true)
}
func updatePgSQLDatabaseFrom20To21(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 20 -> 21")
providerLog(logger.LevelInfo, "updating database schema version: 20 -> 21")
sql := pgsqlV21SQL
if config.Driver == CockroachDataProviderName {
sql = strings.ReplaceAll(sql, `ALTER TABLE "{{users}}" ALTER COLUMN "first_download" DROP DEFAULT;`, "")
sql = strings.ReplaceAll(sql, `ALTER TABLE "{{users}}" ALTER COLUMN "first_upload" DROP DEFAULT;`, "")
// Cockroach does not allow to run this schema migration within a transaction
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
defer cancel()
for _, q := range strings.Split(sql, ";") {
if strings.TrimSpace(q) == "" {
continue
}
_, err := dbHandle.ExecContext(ctx, q)
if err != nil {
return err
}
}
return sqlCommonUpdateDatabaseVersion(ctx, dbHandle, 16)
}
sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 21, true)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16, true)
}
func updatePgSQLDatabaseFrom21To22(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 21 -> 22")
providerLog(logger.LevelInfo, "updating database schema version: 21 -> 22")
sql := strings.ReplaceAll(pgsqlV22SQL, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping)
sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins)
func updatePGSQLDatabaseFrom16To17(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 16 -> 17")
providerLog(logger.LevelInfo, "updating database version: 16 -> 17")
sql := pgsqlV17SQL
if config.Driver == CockroachDataProviderName {
sql = strings.ReplaceAll(sql, `ALTER TABLE "{{folders_mapping}}" DROP CONSTRAINT "{{prefix}}unique_mapping";`,
`DROP INDEX "{{prefix}}unique_mapping" CASCADE;`)
}
sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 22, true)
}
func updatePgSQLDatabaseFrom22To23(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 22 -> 23")
providerLog(logger.LevelInfo, "updating database schema version: 22 -> 23")
sql := strings.ReplaceAll(pgsqlV23SQL, "{{nodes}}", sqlTableNodes)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 23, true)
}
func downgradePgSQLDatabaseFrom20To19(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 20 -> 19")
providerLog(logger.LevelInfo, "downgrading database schema version: 20 -> 19")
sql := strings.ReplaceAll(pgsqlV20DownSQL, "{{events_actions}}", sqlTableEventsActions)
sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 19, false)
}
func downgradePgSQLDatabaseFrom21To20(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 21 -> 20")
providerLog(logger.LevelInfo, "downgrading database schema version: 21 -> 20")
sql := strings.ReplaceAll(pgsqlV21DownSQL, "{{users}}", sqlTableUsers)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 20, false)
}
func downgradePgSQLDatabaseFrom22To21(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 22 -> 21")
providerLog(logger.LevelInfo, "downgrading database schema version: 22 -> 21")
sql := pgsqlV22DownSQL
if config.Driver == CockroachDataProviderName {
sql = strings.ReplaceAll(sql, `ALTER TABLE "{{admins_groups_mapping}}" DROP CONSTRAINT "{{prefix}}unique_admin_group_mapping";`,
`DROP INDEX "{{prefix}}unique_admin_group_mapping" CASCADE;`)
}
sql = strings.ReplaceAll(sql, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping)
sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 21, false)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 17, true)
}
func downgradePgSQLDatabaseFrom23To22(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 23 -> 22")
providerLog(logger.LevelInfo, "downgrading database schema version: 23 -> 22")
sql := strings.ReplaceAll(pgsqlV23DownSQL, "{{nodes}}", sqlTableNodes)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 22, false)
func updatePGSQLDatabaseFrom17To18(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 17 -> 18")
providerLog(logger.LevelInfo, "updating database version: 17 -> 18")
if err := importGCSCredentials(); err != nil {
return err
}
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 18, true)
}
func updatePGSQLDatabaseFrom18To19(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 18 -> 19")
providerLog(logger.LevelInfo, "updating database version: 18 -> 19")
sql := strings.ReplaceAll(pgsqlV19SQL, "{{shared_sessions}}", sqlTableSharedSessions)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 19, true)
}
func downgradePGSQLDatabaseFrom16To15(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 16 -> 15")
providerLog(logger.LevelInfo, "downgrading database version: 16 -> 15")
sql := strings.ReplaceAll(pgsqlV16DownSQL, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 15, false)
}
func downgradePGSQLDatabaseFrom17To16(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 17 -> 16")
providerLog(logger.LevelInfo, "downgrading database version: 17 -> 16")
sql := pgsqlV17DownSQL
if config.Driver == CockroachDataProviderName {
sql = strings.ReplaceAll(sql, `ALTER TABLE "{{users_folders_mapping}}" DROP CONSTRAINT "{{prefix}}unique_user_folder_mapping";`,
`DROP INDEX "{{prefix}}unique_user_folder_mapping" CASCADE;`)
}
sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16, false)
}
func downgradePGSQLDatabaseFrom18To17(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 18 -> 17")
providerLog(logger.LevelInfo, "downgrading database version: 18 -> 17")
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 17, false)
}
func downgradePGSQLDatabaseFrom19To18(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 19 -> 18")
providerLog(logger.LevelInfo, "downgrading database version: 19 -> 18")
sql := strings.ReplaceAll(pgsqlV19DownSQL, "{{shared_sessions}}", sqlTableSharedSessions)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 18, false)
}

View File

@@ -20,7 +20,7 @@ package dataprovider
import (
"errors"
"github.com/drakkan/sftpgo/v2/internal/version"
"github.com/drakkan/sftpgo/v2/version"
)
func init() {

View File

@@ -18,7 +18,7 @@ import (
"sync"
"time"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/logger"
)
var delayedQuotaUpdater quotaUpdater

110
dataprovider/scheduler.go Normal file
View File

@@ -0,0 +1,110 @@
// Copyright (C) 2019-2022 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package dataprovider
import (
"fmt"
"sync/atomic"
"time"
"github.com/robfig/cron/v3"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/metric"
"github.com/drakkan/sftpgo/v2/util"
)
var (
scheduler *cron.Cron
lastCachesUpdate int64
// used for bolt and memory providers, so we avoid iterating all users
// to find recently modified ones
lastUserUpdate int64
)
func stopScheduler() {
if scheduler != nil {
scheduler.Stop()
scheduler = nil
}
}
func startScheduler() error {
stopScheduler()
scheduler = cron.New()
_, err := scheduler.AddFunc("@every 30s", checkDataprovider)
if err != nil {
return fmt.Errorf("unable to schedule dataprovider availability check: %w", err)
}
if config.AutoBackup.Enabled {
spec := fmt.Sprintf("0 %v * * %v", config.AutoBackup.Hour, config.AutoBackup.DayOfWeek)
_, err = scheduler.AddFunc(spec, config.doBackup)
if err != nil {
return fmt.Errorf("unable to schedule auto backup: %w", err)
}
}
err = addScheduledCacheUpdates()
if err != nil {
return err
}
scheduler.Start()
return nil
}
func addScheduledCacheUpdates() error {
lastCachesUpdate = util.GetTimeAsMsSinceEpoch(time.Now())
_, err := scheduler.AddFunc("@every 10m", checkCacheUpdates)
if err != nil {
return fmt.Errorf("unable to schedule cache updates: %w", err)
}
return nil
}
func checkDataprovider() {
err := provider.checkAvailability()
if err != nil {
providerLog(logger.LevelError, "check availability error: %v", err)
}
metric.UpdateDataProviderAvailability(err)
}
func checkCacheUpdates() {
providerLog(logger.LevelDebug, "start caches check, update time %v", util.GetTimeFromMsecSinceEpoch(lastCachesUpdate))
checkTime := util.GetTimeAsMsSinceEpoch(time.Now())
users, err := provider.getRecentlyUpdatedUsers(lastCachesUpdate)
if err != nil {
providerLog(logger.LevelError, "unable to get recently updated users: %v", err)
return
}
for _, user := range users {
providerLog(logger.LevelDebug, "invalidate caches for user %#v", user.Username)
webDAVUsersCache.swap(&user)
cachedPasswords.Remove(user.Username)
}
lastCachesUpdate = checkTime
providerLog(logger.LevelDebug, "end caches check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastCachesUpdate))
}
func setLastUserUpdate() {
atomic.StoreInt64(&lastUserUpdate, util.GetTimeAsMsSinceEpoch(time.Now()))
}
func getLastUserUpdate() int64 {
return atomic.LoadInt64(&lastUserUpdate)
}

View File

@@ -24,8 +24,8 @@ import (
"github.com/alexedwards/argon2id"
"golang.org/x/crypto/bcrypt"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
)
// ShareScope defines the supported share scopes

View File

@@ -30,10 +30,10 @@ import (
// we import go-sqlite3 here to be able to disable SQLite support using a build tag
_ "github.com/mattn/go-sqlite3"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/version"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/version"
"github.com/drakkan/sftpgo/v2/vfs"
)
const (
@@ -41,7 +41,6 @@ const (
DROP TABLE IF EXISTS "{{folders_mapping}}";
DROP TABLE IF EXISTS "{{users_folders_mapping}}";
DROP TABLE IF EXISTS "{{users_groups_mapping}}";
DROP TABLE IF EXISTS "{{admins_groups_mapping}}";
DROP TABLE IF EXISTS "{{groups_folders_mapping}}";
DROP TABLE IF EXISTS "{{admins}}";
DROP TABLE IF EXISTS "{{folders}}";
@@ -52,10 +51,6 @@ DROP TABLE IF EXISTS "{{defender_events}}";
DROP TABLE IF EXISTS "{{defender_hosts}}";
DROP TABLE IF EXISTS "{{active_transfers}}";
DROP TABLE IF EXISTS "{{shared_sessions}}";
DROP TABLE IF EXISTS "{{rules_actions_mapping}}";
DROP TABLE IF EXISTS "{{events_rules}}";
DROP TABLE IF EXISTS "{{events_actions}}";
DROP TABLE IF EXISTS "{{tasks}}";
DROP TABLE IF EXISTS "{{schema_version}}";
`
sqliteInitialSQL = `CREATE TABLE "{{schema_version}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "version" integer NOT NULL);
@@ -63,11 +58,6 @@ CREATE TABLE "{{admins}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "use
"description" varchar(512) NULL, "password" varchar(255) NOT NULL, "email" varchar(255) NULL, "status" integer NOT NULL,
"permissions" text NOT NULL, "filters" text NULL, "additional_info" text NULL, "last_login" bigint NOT NULL,
"created_at" bigint NOT NULL, "updated_at" bigint NOT NULL);
CREATE TABLE "{{active_transfers}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "connection_id" varchar(100) NOT NULL,
"transfer_id" bigint NOT NULL, "transfer_type" integer NOT NULL, "username" varchar(255) NOT NULL,
"folder_name" varchar(255) NULL, "ip" varchar(50) NOT NULL, "truncated_size" bigint NOT NULL,
"current_ul_size" bigint NOT NULL, "current_dl_size" bigint NOT NULL, "created_at" bigint NOT NULL,
"updated_at" bigint NOT NULL);
CREATE TABLE "{{defender_hosts}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "ip" varchar(50) NOT NULL UNIQUE,
"ban_time" bigint NOT NULL, "updated_at" bigint NOT NULL);
CREATE TABLE "{{defender_events}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "date_time" bigint NOT NULL,
@@ -76,34 +66,18 @@ DEFERRABLE INITIALLY DEFERRED);
CREATE TABLE "{{folders}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "name" varchar(255) NOT NULL UNIQUE,
"description" varchar(512) NULL, "path" text NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL,
"last_quota_update" bigint NOT NULL, "filesystem" text NULL);
CREATE TABLE "{{groups}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "name" varchar(255) NOT NULL UNIQUE,
"description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "user_settings" text NULL);
CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL PRIMARY KEY, "data" text NOT NULL,
"type" integer NOT NULL, "timestamp" bigint NOT NULL);
CREATE TABLE "{{users}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "username" varchar(255) NOT NULL UNIQUE,
"status" integer NOT NULL, "expiration_date" bigint NOT NULL, "description" varchar(512) NULL, "password" text NULL,
"public_keys" text NULL, "home_dir" text NOT NULL, "uid" bigint NOT NULL, "gid" bigint NOT NULL,
"max_sessions" integer NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "permissions" text NOT NULL,
"used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL,
"upload_bandwidth" integer NOT NULL, "download_bandwidth" integer NOT NULL, "last_login" bigint NOT NULL,
"filters" text NULL, "filesystem" text NULL, "additional_info" text NULL, "created_at" bigint NOT NULL,
"updated_at" bigint NOT NULL, "email" varchar(255) NULL, "upload_data_transfer" integer NOT NULL,
"download_data_transfer" integer NOT NULL, "total_data_transfer" integer NOT NULL, "used_upload_data_transfer" integer NOT NULL,
"used_download_data_transfer" integer NOT NULL);
CREATE TABLE "{{groups_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"folder_id" integer NOT NULL REFERENCES "{{folders}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL,
CONSTRAINT "{{prefix}}unique_group_folder_mapping" UNIQUE ("group_id", "folder_id"));
CREATE TABLE "{{users_groups_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE NO ACTION,
"group_type" integer NOT NULL, CONSTRAINT "{{prefix}}unique_user_group_mapping" UNIQUE ("user_id", "group_id"));
CREATE TABLE "{{users_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"folder_id" integer NOT NULL REFERENCES "{{folders}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL,
CONSTRAINT "{{prefix}}unique_user_folder_mapping" UNIQUE ("user_id", "folder_id"));
"upload_bandwidth" integer NOT NULL, "download_bandwidth" integer NOT NULL, "last_login" bigint NOT NULL, "filters" text NULL,
"filesystem" text NULL, "additional_info" text NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL,
"email" varchar(255) NULL);
CREATE TABLE "{{folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "virtual_path" text NOT NULL,
"quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "folder_id" integer NOT NULL REFERENCES "{{folders}}" ("id")
ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, "user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
CONSTRAINT "{{prefix}}unique_mapping" UNIQUE ("user_id", "folder_id"));
CREATE TABLE "{{shares}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "share_id" varchar(60) NOT NULL UNIQUE,
"name" varchar(255) NOT NULL, "description" varchar(512) NULL, "scope" integer NOT NULL, "paths" text NOT NULL,
"created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "last_use_at" bigint NOT NULL, "expires_at" bigint NOT NULL,
@@ -114,13 +88,8 @@ CREATE TABLE "{{api_keys}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "n
"created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "last_use_at" bigint NOT NULL, "expires_at" bigint NOT NULL,
"description" text NULL, "admin_id" integer NULL REFERENCES "{{admins}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"user_id" integer NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED);
CREATE INDEX "{{prefix}}groups_updated_at_idx" ON "{{groups}}" ("updated_at");
CREATE INDEX "{{prefix}}users_folders_mapping_folder_id_idx" ON "{{users_folders_mapping}}" ("folder_id");
CREATE INDEX "{{prefix}}users_folders_mapping_user_id_idx" ON "{{users_folders_mapping}}" ("user_id");
CREATE INDEX "{{prefix}}users_groups_mapping_group_id_idx" ON "{{users_groups_mapping}}" ("group_id");
CREATE INDEX "{{prefix}}users_groups_mapping_user_id_idx" ON "{{users_groups_mapping}}" ("user_id");
CREATE INDEX "{{prefix}}groups_folders_mapping_folder_id_idx" ON "{{groups_folders_mapping}}" ("folder_id");
CREATE INDEX "{{prefix}}groups_folders_mapping_group_id_idx" ON "{{groups_folders_mapping}}" ("group_id");
CREATE INDEX "{{prefix}}folders_mapping_folder_id_idx" ON "{{folders_mapping}}" ("folder_id");
CREATE INDEX "{{prefix}}folders_mapping_user_id_idx" ON "{{folders_mapping}}" ("user_id");
CREATE INDEX "{{prefix}}api_keys_admin_id_idx" ON "{{api_keys}}" ("admin_id");
CREATE INDEX "{{prefix}}api_keys_user_id_idx" ON "{{api_keys}}" ("user_id");
CREATE INDEX "{{prefix}}users_updated_at_idx" ON "{{users}}" ("updated_at");
@@ -129,54 +98,78 @@ CREATE INDEX "{{prefix}}defender_hosts_updated_at_idx" ON "{{defender_hosts}}" (
CREATE INDEX "{{prefix}}defender_hosts_ban_time_idx" ON "{{defender_hosts}}" ("ban_time");
CREATE INDEX "{{prefix}}defender_events_date_time_idx" ON "{{defender_events}}" ("date_time");
CREATE INDEX "{{prefix}}defender_events_host_id_idx" ON "{{defender_events}}" ("host_id");
INSERT INTO {{schema_version}} (version) VALUES (15);
`
sqliteV16SQL = `ALTER TABLE "{{users}}" ADD COLUMN "download_data_transfer" integer DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ADD COLUMN "total_data_transfer" integer DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ADD COLUMN "upload_data_transfer" integer DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ADD COLUMN "used_download_data_transfer" integer DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ADD COLUMN "used_upload_data_transfer" integer DEFAULT 0 NOT NULL;
CREATE TABLE "{{active_transfers}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "connection_id" varchar(100) NOT NULL,
"transfer_id" bigint NOT NULL, "transfer_type" integer NOT NULL, "username" varchar(255) NOT NULL,
"folder_name" varchar(255) NULL, "ip" varchar(50) NOT NULL, "truncated_size" bigint NOT NULL,
"current_ul_size" bigint NOT NULL, "current_dl_size" bigint NOT NULL, "created_at" bigint NOT NULL,
"updated_at" bigint NOT NULL);
CREATE INDEX "{{prefix}}active_transfers_connection_id_idx" ON "{{active_transfers}}" ("connection_id");
CREATE INDEX "{{prefix}}active_transfers_transfer_id_idx" ON "{{active_transfers}}" ("transfer_id");
CREATE INDEX "{{prefix}}active_transfers_updated_at_idx" ON "{{active_transfers}}" ("updated_at");
`
sqliteV16DownSQL = `ALTER TABLE "{{users}}" DROP COLUMN "used_upload_data_transfer";
ALTER TABLE "{{users}}" DROP COLUMN "used_download_data_transfer";
ALTER TABLE "{{users}}" DROP COLUMN "upload_data_transfer";
ALTER TABLE "{{users}}" DROP COLUMN "total_data_transfer";
ALTER TABLE "{{users}}" DROP COLUMN "download_data_transfer";
DROP TABLE "{{active_transfers}}";
`
sqliteV17SQL = `CREATE TABLE "{{groups}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "name" varchar(255) NOT NULL UNIQUE,
"description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "user_settings" text NULL);
CREATE TABLE "{{groups_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"folder_id" integer NOT NULL REFERENCES "{{folders}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL,
CONSTRAINT "{{prefix}}unique_group_folder_mapping" UNIQUE ("group_id", "folder_id"));
CREATE TABLE "{{users_groups_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE NO ACTION,
"group_type" integer NOT NULL, CONSTRAINT "{{prefix}}unique_user_group_mapping" UNIQUE ("user_id", "group_id"));
CREATE TABLE "new__folders_mapping" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"folder_id" integer NOT NULL REFERENCES "{{folders}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL,
CONSTRAINT "{{prefix}}unique_user_folder_mapping" UNIQUE ("user_id", "folder_id"));
INSERT INTO "new__folders_mapping" ("id", "virtual_path", "quota_size", "quota_files", "folder_id", "user_id") SELECT "id",
"virtual_path", "quota_size", "quota_files", "folder_id", "user_id" FROM "{{folders_mapping}}";
DROP TABLE "{{folders_mapping}}";
ALTER TABLE "new__folders_mapping" RENAME TO "{{users_folders_mapping}}";
CREATE INDEX "{{prefix}}groups_updated_at_idx" ON "{{groups}}" ("updated_at");
CREATE INDEX "{{prefix}}users_folders_mapping_folder_id_idx" ON "{{users_folders_mapping}}" ("folder_id");
CREATE INDEX "{{prefix}}users_folders_mapping_user_id_idx" ON "{{users_folders_mapping}}" ("user_id");
CREATE INDEX "{{prefix}}users_groups_mapping_group_id_idx" ON "{{users_groups_mapping}}" ("group_id");
CREATE INDEX "{{prefix}}users_groups_mapping_user_id_idx" ON "{{users_groups_mapping}}" ("user_id");
CREATE INDEX "{{prefix}}groups_folders_mapping_folder_id_idx" ON "{{groups_folders_mapping}}" ("folder_id");
CREATE INDEX "{{prefix}}groups_folders_mapping_group_id_idx" ON "{{groups_folders_mapping}}" ("group_id");
`
sqliteV17DownSQL = `DROP TABLE "{{users_groups_mapping}}";
DROP TABLE "{{groups_folders_mapping}}";
DROP TABLE "{{groups}}";
CREATE TABLE "new__folders_mapping" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"folder_id" integer NOT NULL REFERENCES "{{folders}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL,
CONSTRAINT "{{prefix}}unique_folder_mapping" UNIQUE ("user_id", "folder_id"));
INSERT INTO "new__folders_mapping" ("id", "virtual_path", "quota_size", "quota_files", "folder_id", "user_id") SELECT "id",
"virtual_path", "quota_size", "quota_files", "folder_id", "user_id" FROM "{{users_folders_mapping}}";
DROP TABLE "{{users_folders_mapping}}";
ALTER TABLE "new__folders_mapping" RENAME TO "{{folders_mapping}}";
CREATE INDEX "{{prefix}}folders_mapping_folder_id_idx" ON "{{folders_mapping}}" ("folder_id");
CREATE INDEX "{{prefix}}folders_mapping_user_id_idx" ON "{{folders_mapping}}" ("user_id");
`
sqliteV19SQL = `CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL PRIMARY KEY, "data" text NOT NULL,
"type" integer NOT NULL, "timestamp" bigint NOT NULL);
CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type");
CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp");
INSERT INTO {{schema_version}} (version) VALUES (19);
`
sqliteV20SQL = `CREATE TABLE "{{events_rules}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"name" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, "created_at" bigint NOT NULL,
"updated_at" bigint NOT NULL, "trigger" integer NOT NULL, "conditions" text NOT NULL, "deleted_at" bigint NOT NULL);
CREATE TABLE "{{events_actions}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "name" varchar(255) NOT NULL UNIQUE,
"description" varchar(512) NULL, "type" integer NOT NULL, "options" text NOT NULL);
CREATE TABLE "{{rules_actions_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"rule_id" integer NOT NULL REFERENCES "{{events_rules}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"action_id" integer NOT NULL REFERENCES "{{events_actions}}" ("id") ON DELETE NO ACTION DEFERRABLE INITIALLY DEFERRED,
"order" integer NOT NULL, "options" text NOT NULL,
CONSTRAINT "{{prefix}}unique_rule_action_mapping" UNIQUE ("rule_id", "action_id"));
CREATE TABLE "{{tasks}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "name" varchar(255) NOT NULL UNIQUE,
"updated_at" bigint NOT NULL, "version" bigint NOT NULL);
ALTER TABLE "{{users}}" ADD COLUMN "deleted_at" bigint DEFAULT 0 NOT NULL;
CREATE INDEX "{{prefix}}events_rules_updated_at_idx" ON "{{events_rules}}" ("updated_at");
CREATE INDEX "{{prefix}}events_rules_deleted_at_idx" ON "{{events_rules}}" ("deleted_at");
CREATE INDEX "{{prefix}}events_rules_trigger_idx" ON "{{events_rules}}" ("trigger");
CREATE INDEX "{{prefix}}rules_actions_mapping_rule_id_idx" ON "{{rules_actions_mapping}}" ("rule_id");
CREATE INDEX "{{prefix}}rules_actions_mapping_action_id_idx" ON "{{rules_actions_mapping}}" ("action_id");
CREATE INDEX "{{prefix}}rules_actions_mapping_order_idx" ON "{{rules_actions_mapping}}" ("order");
CREATE INDEX "{{prefix}}users_deleted_at_idx" ON "{{users}}" ("deleted_at");
`
sqliteV20DownSQL = `DROP TABLE "{{rules_actions_mapping}}";
DROP TABLE "{{events_rules}}";
DROP TABLE "{{events_actions}}";
DROP TABLE "{{tasks}}";
DROP INDEX IF EXISTS "{{prefix}}users_deleted_at_idx";
ALTER TABLE "{{users}}" DROP COLUMN "deleted_at";
`
sqliteV21SQL = `ALTER TABLE "{{users}}" ADD COLUMN "first_download" bigint DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ADD COLUMN "first_upload" bigint DEFAULT 0 NOT NULL;`
sqliteV21DownSQL = `ALTER TABLE "{{users}}" DROP COLUMN "first_upload";
ALTER TABLE "{{users}}" DROP COLUMN "first_download";
`
sqliteV22SQL = `CREATE TABLE "{{admins_groups_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
"admin_id" integer NOT NULL REFERENCES "{{admins}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
"options" text NOT NULL, CONSTRAINT "{{prefix}}unique_admin_group_mapping" UNIQUE ("admin_id", "group_id"));
CREATE INDEX "{{prefix}}admins_groups_mapping_admin_id_idx" ON "{{admins_groups_mapping}}" ("admin_id");
CREATE INDEX "{{prefix}}admins_groups_mapping_group_id_idx" ON "{{admins_groups_mapping}}" ("group_id");
`
sqliteV22DownSQL = `DROP TABLE "{{admins_groups_mapping}}";`
`
sqliteV19DownSQL = `DROP TABLE "{{shared_sessions}}";`
)
// SQLiteProvider defines the auth provider for SQLite database
@@ -268,8 +261,8 @@ func (p *SQLiteProvider) updateUser(user *User) error {
return sqlCommonUpdateUser(user, p.dbHandle)
}
func (p *SQLiteProvider) deleteUser(user User, softDelete bool) error {
return sqlCommonDeleteUser(user, softDelete, p.dbHandle)
func (p *SQLiteProvider) deleteUser(user User) error {
return sqlCommonDeleteUser(user, p.dbHandle)
}
func (p *SQLiteProvider) updateUserPassword(username, password string) error {
@@ -510,102 +503,6 @@ func (p *SQLiteProvider) cleanupSharedSessions(sessionType SessionType, before i
return sqlCommonCleanupSessions(sessionType, before, p.dbHandle)
}
func (p *SQLiteProvider) getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) {
return sqlCommonGetEventActions(limit, offset, order, minimal, p.dbHandle)
}
func (p *SQLiteProvider) dumpEventActions() ([]BaseEventAction, error) {
return sqlCommonDumpEventActions(p.dbHandle)
}
func (p *SQLiteProvider) eventActionExists(name string) (BaseEventAction, error) {
return sqlCommonGetEventActionByName(name, p.dbHandle)
}
func (p *SQLiteProvider) addEventAction(action *BaseEventAction) error {
return sqlCommonAddEventAction(action, p.dbHandle)
}
func (p *SQLiteProvider) updateEventAction(action *BaseEventAction) error {
return sqlCommonUpdateEventAction(action, p.dbHandle)
}
func (p *SQLiteProvider) deleteEventAction(action BaseEventAction) error {
return sqlCommonDeleteEventAction(action, p.dbHandle)
}
func (p *SQLiteProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) {
return sqlCommonGetEventRules(limit, offset, order, p.dbHandle)
}
func (p *SQLiteProvider) dumpEventRules() ([]EventRule, error) {
return sqlCommonDumpEventRules(p.dbHandle)
}
func (p *SQLiteProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) {
return sqlCommonGetRecentlyUpdatedRules(after, p.dbHandle)
}
func (p *SQLiteProvider) eventRuleExists(name string) (EventRule, error) {
return sqlCommonGetEventRuleByName(name, p.dbHandle)
}
func (p *SQLiteProvider) addEventRule(rule *EventRule) error {
return sqlCommonAddEventRule(rule, p.dbHandle)
}
func (p *SQLiteProvider) updateEventRule(rule *EventRule) error {
return sqlCommonUpdateEventRule(rule, p.dbHandle)
}
func (p *SQLiteProvider) deleteEventRule(rule EventRule, softDelete bool) error {
return sqlCommonDeleteEventRule(rule, softDelete, p.dbHandle)
}
func (p *SQLiteProvider) getTaskByName(name string) (Task, error) {
return sqlCommonGetTaskByName(name, p.dbHandle)
}
func (p *SQLiteProvider) addTask(name string) error {
return sqlCommonAddTask(name, p.dbHandle)
}
func (p *SQLiteProvider) updateTask(name string, version int64) error {
return sqlCommonUpdateTask(name, version, p.dbHandle)
}
func (p *SQLiteProvider) updateTaskTimestamp(name string) error {
return sqlCommonUpdateTaskTimestamp(name, p.dbHandle)
}
func (*SQLiteProvider) addNode() error {
return ErrNotImplemented
}
func (*SQLiteProvider) getNodeByName(name string) (Node, error) {
return Node{}, ErrNotImplemented
}
func (*SQLiteProvider) getNodes() ([]Node, error) {
return nil, ErrNotImplemented
}
func (*SQLiteProvider) updateNodeTimestamp() error {
return ErrNotImplemented
}
func (*SQLiteProvider) cleanupNodes() error {
return ErrNotImplemented
}
func (p *SQLiteProvider) setFirstDownloadTimestamp(username string) error {
return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle)
}
func (p *SQLiteProvider) setFirstUploadTimestamp(username string) error {
return sqlCommonSetFirstUploadTimestamp(username, p.dbHandle)
}
func (p *SQLiteProvider) close() error {
return p.dbHandle.Close()
}
@@ -623,11 +520,20 @@ func (p *SQLiteProvider) initializeDatabase() error {
if errors.Is(err, sql.ErrNoRows) {
return errSchemaVersionEmpty
}
logger.InfoToConsole("creating initial database schema, version 19")
providerLog(logger.LevelInfo, "creating initial database schema, version 19")
sql := sqlReplaceAll(sqliteInitialSQL)
logger.InfoToConsole("creating initial database schema, version 15")
providerLog(logger.LevelInfo, "creating initial database schema, version 15")
initialSQL := strings.ReplaceAll(sqliteInitialSQL, "{{schema_version}}", sqlTableSchemaVersion)
initialSQL = strings.ReplaceAll(initialSQL, "{{admins}}", sqlTableAdmins)
initialSQL = strings.ReplaceAll(initialSQL, "{{folders}}", sqlTableFolders)
initialSQL = strings.ReplaceAll(initialSQL, "{{users}}", sqlTableUsers)
initialSQL = strings.ReplaceAll(initialSQL, "{{folders_mapping}}", sqlTableFoldersMapping)
initialSQL = strings.ReplaceAll(initialSQL, "{{api_keys}}", sqlTableAPIKeys)
initialSQL = strings.ReplaceAll(initialSQL, "{{shares}}", sqlTableShares)
initialSQL = strings.ReplaceAll(initialSQL, "{{defender_events}}", sqlTableDefenderEvents)
initialSQL = strings.ReplaceAll(initialSQL, "{{defender_hosts}}", sqlTableDefenderHosts)
initialSQL = strings.ReplaceAll(initialSQL, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 19, true)
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 15, true)
}
func (p *SQLiteProvider) migrateDatabase() error { //nolint:dupl
@@ -640,28 +546,28 @@ func (p *SQLiteProvider) migrateDatabase() error { //nolint:dupl
case version == sqlDatabaseVersion:
providerLog(logger.LevelDebug, "sql database is up to date, current version: %v", version)
return ErrNoInitRequired
case version < 19:
err = fmt.Errorf("database schema version %v is too old, please see the upgrading docs", version)
case version < 15:
err = fmt.Errorf("database version %v is too old, please see the upgrading docs", version)
providerLog(logger.LevelError, "%v", err)
logger.ErrorToConsole("%v", err)
return err
case version == 19:
return updateSQLiteDatabaseFromV19(p.dbHandle)
case version == 20:
return updateSQLiteDatabaseFromV20(p.dbHandle)
case version == 21:
return updateSQLiteDatabaseFromV21(p.dbHandle)
case version == 22:
return updateSQLiteDatabaseFromV22(p.dbHandle)
case version == 15:
return updateSQLiteDatabaseFromV15(p.dbHandle)
case version == 16:
return updateSQLiteDatabaseFromV16(p.dbHandle)
case version == 17:
return updateSQLiteDatabaseFromV17(p.dbHandle)
case version == 18:
return updateSQLiteDatabaseFromV18(p.dbHandle)
default:
if version > sqlDatabaseVersion {
providerLog(logger.LevelError, "database schema version %v is newer than the supported one: %v", version,
providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version,
sqlDatabaseVersion)
logger.WarnToConsole("database schema version %v is newer than the supported one: %v", version,
logger.WarnToConsole("database version %v is newer than the supported one: %v", version,
sqlDatabaseVersion)
return nil
}
return fmt.Errorf("database schema version not handled: %v", version)
return fmt.Errorf("database version not handled: %v", version)
}
}
@@ -675,16 +581,16 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error {
}
switch dbVersion.Version {
case 20:
return downgradeSQLiteDatabaseFromV20(p.dbHandle)
case 21:
return downgradeSQLiteDatabaseFromV21(p.dbHandle)
case 22:
return downgradeSQLiteDatabaseFromV22(p.dbHandle)
case 23:
return downgradeSQLiteDatabaseFromV23(p.dbHandle)
case 16:
return downgradeSQLiteDatabaseFromV16(p.dbHandle)
case 17:
return downgradeSQLiteDatabaseFromV17(p.dbHandle)
case 18:
return downgradeSQLiteDatabaseFromV18(p.dbHandle)
case 19:
return downgradeSQLiteDatabaseFromV19(p.dbHandle)
default:
return fmt.Errorf("database schema version not handled: %v", dbVersion.Version)
return fmt.Errorf("database version not handled: %v", dbVersion.Version)
}
}
@@ -693,125 +599,145 @@ func (p *SQLiteProvider) resetDatabase() error {
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0, false)
}
func updateSQLiteDatabaseFromV19(dbHandle *sql.DB) error {
if err := updateSQLiteDatabaseFrom19To20(dbHandle); err != nil {
func updateSQLiteDatabaseFromV15(dbHandle *sql.DB) error {
if err := updateSQLiteDatabaseFrom15To16(dbHandle); err != nil {
return err
}
return updateSQLiteDatabaseFromV20(dbHandle)
return updateSQLiteDatabaseFromV16(dbHandle)
}
func updateSQLiteDatabaseFromV20(dbHandle *sql.DB) error {
if err := updateSQLiteDatabaseFrom20To21(dbHandle); err != nil {
func updateSQLiteDatabaseFromV16(dbHandle *sql.DB) error {
if err := updateSQLiteDatabaseFrom16To17(dbHandle); err != nil {
return err
}
return updateSQLiteDatabaseFromV21(dbHandle)
return updateSQLiteDatabaseFromV17(dbHandle)
}
func updateSQLiteDatabaseFromV21(dbHandle *sql.DB) error {
if err := updateSQLiteDatabaseFrom21To22(dbHandle); err != nil {
func updateSQLiteDatabaseFromV17(dbHandle *sql.DB) error {
if err := updateSQLiteDatabaseFrom17To18(dbHandle); err != nil {
return err
}
return updateSQLiteDatabaseFromV22(dbHandle)
return updateSQLiteDatabaseFromV18(dbHandle)
}
func updateSQLiteDatabaseFromV22(dbHandle *sql.DB) error {
return updateSQLiteDatabaseFrom22To23(dbHandle)
func updateSQLiteDatabaseFromV18(dbHandle *sql.DB) error {
return updateSQLiteDatabaseFrom18To19(dbHandle)
}
func downgradeSQLiteDatabaseFromV20(dbHandle *sql.DB) error {
return downgradeSQLiteDatabaseFrom20To19(dbHandle)
func downgradeSQLiteDatabaseFromV16(dbHandle *sql.DB) error {
return downgradeSQLiteDatabaseFrom16To15(dbHandle)
}
func downgradeSQLiteDatabaseFromV21(dbHandle *sql.DB) error {
if err := downgradeSQLiteDatabaseFrom21To20(dbHandle); err != nil {
func downgradeSQLiteDatabaseFromV17(dbHandle *sql.DB) error {
if err := downgradeSQLiteDatabaseFrom17To16(dbHandle); err != nil {
return err
}
return downgradeSQLiteDatabaseFromV20(dbHandle)
return downgradeSQLiteDatabaseFromV16(dbHandle)
}
func downgradeSQLiteDatabaseFromV22(dbHandle *sql.DB) error {
if err := downgradeSQLiteDatabaseFrom22To21(dbHandle); err != nil {
func downgradeSQLiteDatabaseFromV18(dbHandle *sql.DB) error {
if err := downgradeSQLiteDatabaseFrom18To17(dbHandle); err != nil {
return err
}
return downgradeSQLiteDatabaseFromV21(dbHandle)
return downgradeSQLiteDatabaseFromV17(dbHandle)
}
func downgradeSQLiteDatabaseFromV23(dbHandle *sql.DB) error {
if err := downgradeSQLiteDatabaseFrom23To22(dbHandle); err != nil {
func downgradeSQLiteDatabaseFromV19(dbHandle *sql.DB) error {
if err := downgradeSQLiteDatabaseFrom19To18(dbHandle); err != nil {
return err
}
return downgradeSQLiteDatabaseFromV22(dbHandle)
return downgradeSQLiteDatabaseFromV18(dbHandle)
}
func updateSQLiteDatabaseFrom19To20(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 19 -> 20")
providerLog(logger.LevelInfo, "updating database schema version: 19 -> 20")
sql := strings.ReplaceAll(sqliteV20SQL, "{{events_actions}}", sqlTableEventsActions)
sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
func updateSQLiteDatabaseFrom15To16(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 15 -> 16")
providerLog(logger.LevelInfo, "updating database version: 15 -> 16")
sql := strings.ReplaceAll(sqliteV16SQL, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 20, true)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16, true)
}
func updateSQLiteDatabaseFrom20To21(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 20 -> 21")
providerLog(logger.LevelInfo, "updating database schema version: 20 -> 21")
sql := strings.ReplaceAll(sqliteV21SQL, "{{users}}", sqlTableUsers)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 21, true)
}
func updateSQLiteDatabaseFrom21To22(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 21 -> 22")
providerLog(logger.LevelInfo, "updating database schema version: 21 -> 22")
sql := strings.ReplaceAll(sqliteV22SQL, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping)
sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins)
func updateSQLiteDatabaseFrom16To17(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 16 -> 17")
providerLog(logger.LevelInfo, "updating database version: 16 -> 17")
if err := setPragmaFK(dbHandle, "OFF"); err != nil {
return err
}
sql := strings.ReplaceAll(sqliteV17SQL, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 22, true)
}
func updateSQLiteDatabaseFrom22To23(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database schema version: 22 -> 23")
providerLog(logger.LevelInfo, "updating database schema version: 22 -> 23")
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{`SELECT 1`}, 23, true)
}
func downgradeSQLiteDatabaseFrom20To19(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 20 -> 19")
providerLog(logger.LevelInfo, "downgrading database schema version: 20 -> 19")
sql := strings.ReplaceAll(sqliteV20DownSQL, "{{events_actions}}", sqlTableEventsActions)
sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 19, false)
if err := sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 17, true); err != nil {
return err
}
return setPragmaFK(dbHandle, "ON")
}
func downgradeSQLiteDatabaseFrom21To20(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 21 -> 20")
providerLog(logger.LevelInfo, "downgrading database schema version: 21 -> 20")
sql := strings.ReplaceAll(sqliteV21DownSQL, "{{users}}", sqlTableUsers)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 20, false)
func updateSQLiteDatabaseFrom17To18(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 17 -> 18")
providerLog(logger.LevelInfo, "updating database version: 17 -> 18")
if err := importGCSCredentials(); err != nil {
return err
}
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 18, true)
}
func downgradeSQLiteDatabaseFrom22To21(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 22 -> 21")
providerLog(logger.LevelInfo, "downgrading database schema version: 22 -> 21")
sql := strings.ReplaceAll(sqliteV22DownSQL, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 21, false)
func updateSQLiteDatabaseFrom18To19(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 18 -> 19")
providerLog(logger.LevelInfo, "updating database version: 18 -> 19")
sql := strings.ReplaceAll(sqliteV19SQL, "{{shared_sessions}}", sqlTableSharedSessions)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 19, true)
}
func downgradeSQLiteDatabaseFrom23To22(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 23 -> 22")
providerLog(logger.LevelInfo, "downgrading database schema version: 23 -> 22")
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{`SELECT 1`}, 22, false)
func downgradeSQLiteDatabaseFrom16To15(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 16 -> 15")
providerLog(logger.LevelInfo, "downgrading database version: 16 -> 15")
sql := strings.ReplaceAll(sqliteV16DownSQL, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 15, false)
}
/*func setPragmaFK(dbHandle *sql.DB, value string) error {
func downgradeSQLiteDatabaseFrom17To16(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 17 -> 16")
providerLog(logger.LevelInfo, "downgrading database version: 17 -> 16")
if err := setPragmaFK(dbHandle, "OFF"); err != nil {
return err
}
sql := strings.ReplaceAll(sqliteV17DownSQL, "{{groups}}", sqlTableGroups)
sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
if err := sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16, false); err != nil {
return err
}
return setPragmaFK(dbHandle, "ON")
}
func downgradeSQLiteDatabaseFrom18To17(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 18 -> 17")
providerLog(logger.LevelInfo, "downgrading database version: 18 -> 17")
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 17, false)
}
func downgradeSQLiteDatabaseFrom19To18(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 19 -> 18")
providerLog(logger.LevelInfo, "downgrading database version: 19 -> 18")
sql := strings.ReplaceAll(sqliteV19DownSQL, "{{shared_sessions}}", sqlTableSharedSessions)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 18, false)
}
func setPragmaFK(dbHandle *sql.DB, value string) error {
ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
defer cancel()
@@ -819,4 +745,4 @@ func downgradeSQLiteDatabaseFrom23To22(dbHandle *sql.DB) error {
_, err := dbHandle.ExecContext(ctx, sql)
return err
}*/
}

View File

@@ -20,13 +20,13 @@ package dataprovider
import (
"errors"
"github.com/drakkan/sftpgo/v2/internal/version"
"github.com/drakkan/sftpgo/v2/version"
)
func init() {
version.AddFeature("-sqlite")
}
func initializeSQLiteProvider(_ string) error {
func initializeSQLiteProvider(basePath string) error {
return errors.New("SQLite disabled at build time")
}

734
dataprovider/sqlqueries.go Normal file
View File

@@ -0,0 +1,734 @@
// Copyright (C) 2019-2022 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package dataprovider
import (
"fmt"
"strconv"
"strings"
"github.com/drakkan/sftpgo/v2/vfs"
)
const (
selectUserFields = "id,username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions,used_quota_size," +
"used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth,expiration_date,last_login,status,filters,filesystem," +
"additional_info,description,email,created_at,updated_at,upload_data_transfer,download_data_transfer,total_data_transfer," +
"used_upload_data_transfer,used_download_data_transfer"
selectFolderFields = "id,path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem"
selectAdminFields = "id,username,password,status,email,permissions,filters,additional_info,description,created_at,updated_at,last_login"
selectAPIKeyFields = "key_id,name,api_key,scope,created_at,updated_at,last_use_at,expires_at,description,user_id,admin_id"
selectShareFields = "s.share_id,s.name,s.description,s.scope,s.paths,u.username,s.created_at,s.updated_at,s.last_use_at," +
"s.expires_at,s.password,s.max_tokens,s.used_tokens,s.allow_from"
selectGroupFields = "id,name,description,created_at,updated_at,user_settings"
)
func getSQLPlaceholders() []string {
var placeholders []string
for i := 1; i <= 50; i++ {
if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName {
placeholders = append(placeholders, fmt.Sprintf("$%v", i))
} else {
placeholders = append(placeholders, "?")
}
}
return placeholders
}
func getSQLTableGroups() string {
if config.Driver == MySQLDataProviderName {
return fmt.Sprintf("`%s`", sqlTableGroups)
}
return sqlTableGroups
}
func getAddSessionQuery() string {
if config.Driver == MySQLDataProviderName {
return fmt.Sprintf("INSERT INTO %s (`key`,`data`,`type`,`timestamp`) VALUES (%s,%s,%s,%s) "+
"ON DUPLICATE KEY UPDATE `data`=VALUES(`data`), `timestamp`=VALUES(`timestamp`)",
sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
return fmt.Sprintf(`INSERT INTO %s (key,data,type,timestamp) VALUES (%s,%s,%s,%s) ON CONFLICT(key) DO UPDATE SET data=
EXCLUDED.data, timestamp=EXCLUDED.timestamp`,
sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
func getDeleteSessionQuery() string {
if config.Driver == MySQLDataProviderName {
return fmt.Sprintf("DELETE FROM %s WHERE `key` = %s", sqlTableSharedSessions, sqlPlaceholders[0])
}
return fmt.Sprintf(`DELETE FROM %s WHERE key = %s`, sqlTableSharedSessions, sqlPlaceholders[0])
}
func getSessionQuery() string {
if config.Driver == MySQLDataProviderName {
return fmt.Sprintf("SELECT `key`,`data`,`type`,`timestamp` FROM %s WHERE `key` = %s", sqlTableSharedSessions,
sqlPlaceholders[0])
}
return fmt.Sprintf(`SELECT key,data,type,timestamp FROM %s WHERE key = %s`, sqlTableSharedSessions,
sqlPlaceholders[0])
}
func getCleanupSessionsQuery() string {
return fmt.Sprintf(`DELETE from %s WHERE type = %s AND timestamp < %s`,
sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getAddDefenderHostQuery() string {
if config.Driver == MySQLDataProviderName {
return fmt.Sprintf("INSERT INTO %v (`ip`,`updated_at`,`ban_time`) VALUES (%v,%v,0) ON DUPLICATE KEY UPDATE `updated_at`=VALUES(`updated_at`)",
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}
return fmt.Sprintf(`INSERT INTO %v (ip,updated_at,ban_time) VALUES (%v,%v,0) ON CONFLICT (ip) DO UPDATE SET updated_at = EXCLUDED.updated_at RETURNING id`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getAddDefenderEventQuery() string {
return fmt.Sprintf(`INSERT INTO %v (date_time,score,host_id) VALUES (%v,%v,(SELECT id from %v WHERE ip = %v))`,
sqlTableDefenderEvents, sqlPlaceholders[0], sqlPlaceholders[1], sqlTableDefenderHosts, sqlPlaceholders[2])
}
func getDefenderHostsQuery() string {
return fmt.Sprintf(`SELECT id,ip,ban_time FROM %v WHERE updated_at >= %v OR ban_time > 0 ORDER BY updated_at DESC LIMIT %v`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDefenderHostQuery() string {
return fmt.Sprintf(`SELECT id,ip,ban_time FROM %v WHERE ip = %v AND (updated_at >= %v OR ban_time > 0)`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDefenderEventsQuery(hostIDS []int64) string {
var sb strings.Builder
for _, hID := range hostIDS {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(hID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
} else {
sb.WriteString("(0)")
}
return fmt.Sprintf(`SELECT host_id,SUM(score) FROM %v WHERE date_time >= %v AND host_id IN %v GROUP BY host_id`,
sqlTableDefenderEvents, sqlPlaceholders[0], sb.String())
}
func getDefenderIsHostBannedQuery() string {
return fmt.Sprintf(`SELECT id FROM %v WHERE ip = %v AND ban_time >= %v`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDefenderIncrementBanTimeQuery() string {
return fmt.Sprintf(`UPDATE %v SET ban_time = ban_time + %v WHERE ip = %v`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDefenderSetBanTimeQuery() string {
return fmt.Sprintf(`UPDATE %v SET ban_time = %v WHERE ip = %v`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDeleteDefenderHostQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE ip = %v`, sqlTableDefenderHosts, sqlPlaceholders[0])
}
func getDefenderHostsCleanupQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE ban_time < %v AND NOT EXISTS (
SELECT id FROM %v WHERE %v.host_id = %v.id AND %v.date_time > %v)`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlTableDefenderEvents, sqlTableDefenderEvents, sqlTableDefenderHosts,
sqlTableDefenderEvents, sqlPlaceholders[1])
}
func getDefenderEventsCleanupQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE date_time < %v`, sqlTableDefenderEvents, sqlPlaceholders[0])
}
func getGroupByNameQuery() string {
return fmt.Sprintf(`SELECT %s FROM %s WHERE name = %s`, selectGroupFields, getSQLTableGroups(), sqlPlaceholders[0])
}
func getGroupsQuery(order string, minimal bool) string {
var fieldSelection string
if minimal {
fieldSelection = "id,name"
} else {
fieldSelection = selectGroupFields
}
return fmt.Sprintf(`SELECT %s FROM %s ORDER BY name %s LIMIT %v OFFSET %v`, fieldSelection, getSQLTableGroups(),
order, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getGroupsWithNamesQuery(numArgs int) string {
var sb strings.Builder
for idx := 0; idx < numArgs; idx++ {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(sqlPlaceholders[idx])
}
if sb.Len() > 0 {
sb.WriteString(")")
} else {
sb.WriteString("('')")
}
return fmt.Sprintf(`SELECT %s FROM %s WHERE name in %s`, selectGroupFields, getSQLTableGroups(), sb.String())
}
func getUsersInGroupsQuery(numArgs int) string {
var sb strings.Builder
for idx := 0; idx < numArgs; idx++ {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(sqlPlaceholders[idx])
}
if sb.Len() > 0 {
sb.WriteString(")")
} else {
sb.WriteString("('')")
}
return fmt.Sprintf(`SELECT username FROM %s WHERE id IN (SELECT user_id from %s WHERE group_id IN (SELECT id FROM %s WHERE name IN (%s)))`,
sqlTableUsers, sqlTableUsersGroupsMapping, getSQLTableGroups(), sb.String())
}
func getDumpGroupsQuery() string {
return fmt.Sprintf(`SELECT %s FROM %s`, selectGroupFields, getSQLTableGroups())
}
func getAddGroupQuery() string {
return fmt.Sprintf(`INSERT INTO %s (name,description,created_at,updated_at,user_settings)
VALUES (%v,%v,%v,%v,%v)`, getSQLTableGroups(), sqlPlaceholders[0], sqlPlaceholders[1],
sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4])
}
func getUpdateGroupQuery() string {
return fmt.Sprintf(`UPDATE %s SET description=%v,user_settings=%v,updated_at=%v
WHERE name = %s`, getSQLTableGroups(), sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
sqlPlaceholders[3])
}
func getDeleteGroupQuery() string {
return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, getSQLTableGroups(), sqlPlaceholders[0])
}
func getAdminByUsernameQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v WHERE username = %v`, selectAdminFields, sqlTableAdmins, sqlPlaceholders[0])
}
func getAdminsQuery(order string) string {
return fmt.Sprintf(`SELECT %v FROM %v ORDER BY username %v LIMIT %v OFFSET %v`, selectAdminFields, sqlTableAdmins,
order, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDumpAdminsQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v`, selectAdminFields, sqlTableAdmins)
}
func getAddAdminQuery() string {
return fmt.Sprintf(`INSERT INTO %v (username,password,status,email,permissions,filters,additional_info,description,created_at,updated_at,last_login)
VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,0)`, sqlTableAdmins, sqlPlaceholders[0], sqlPlaceholders[1],
sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7],
sqlPlaceholders[8], sqlPlaceholders[9])
}
func getUpdateAdminQuery() string {
return fmt.Sprintf(`UPDATE %v SET password=%v,status=%v,email=%v,permissions=%v,filters=%v,additional_info=%v,description=%v,updated_at=%v
WHERE username = %v`, sqlTableAdmins, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8])
}
func getDeleteAdminQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE username = %v`, sqlTableAdmins, sqlPlaceholders[0])
}
func getShareByIDQuery(filterUser bool) string {
if filterUser {
return fmt.Sprintf(`SELECT %v FROM %v s INNER JOIN %v u ON s.user_id = u.id WHERE s.share_id = %v AND u.username = %v`,
selectShareFields, sqlTableShares, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1])
}
return fmt.Sprintf(`SELECT %v FROM %v s INNER JOIN %v u ON s.user_id = u.id WHERE s.share_id = %v`,
selectShareFields, sqlTableShares, sqlTableUsers, sqlPlaceholders[0])
}
func getSharesQuery(order string) string {
return fmt.Sprintf(`SELECT %v FROM %v s INNER JOIN %v u ON s.user_id = u.id WHERE u.username = %v ORDER BY s.share_id %v LIMIT %v OFFSET %v`,
selectShareFields, sqlTableShares, sqlTableUsers, sqlPlaceholders[0], order, sqlPlaceholders[1], sqlPlaceholders[2])
}
func getDumpSharesQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v s INNER JOIN %v u ON s.user_id = u.id`,
selectShareFields, sqlTableShares, sqlTableUsers)
}
func getAddShareQuery() string {
return fmt.Sprintf(`INSERT INTO %v (share_id,name,description,scope,paths,created_at,updated_at,last_use_at,
expires_at,password,max_tokens,used_tokens,allow_from,user_id) VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v)`,
sqlTableShares, sqlPlaceholders[0], sqlPlaceholders[1],
sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6],
sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11],
sqlPlaceholders[12], sqlPlaceholders[13])
}
func getUpdateShareRestoreQuery() string {
return fmt.Sprintf(`UPDATE %v SET name=%v,description=%v,scope=%v,paths=%v,created_at=%v,updated_at=%v,
last_use_at=%v,expires_at=%v,password=%v,max_tokens=%v,used_tokens=%v,allow_from=%v,user_id=%v WHERE share_id = %v`, sqlTableShares,
sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9],
sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13])
}
func getUpdateShareQuery() string {
return fmt.Sprintf(`UPDATE %v SET name=%v,description=%v,scope=%v,paths=%v,updated_at=%v,expires_at=%v,
password=%v,max_tokens=%v,allow_from=%v,user_id=%v WHERE share_id = %v`, sqlTableShares,
sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9],
sqlPlaceholders[10])
}
func getDeleteShareQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE share_id = %v`, sqlTableShares, sqlPlaceholders[0])
}
func getAPIKeyByIDQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v WHERE key_id = %v`, selectAPIKeyFields, sqlTableAPIKeys, sqlPlaceholders[0])
}
func getAPIKeysQuery(order string) string {
return fmt.Sprintf(`SELECT %v FROM %v ORDER BY key_id %v LIMIT %v OFFSET %v`, selectAPIKeyFields, sqlTableAPIKeys,
order, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDumpAPIKeysQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v`, selectAPIKeyFields, sqlTableAPIKeys)
}
func getAddAPIKeyQuery() string {
return fmt.Sprintf(`INSERT INTO %v (key_id,name,api_key,scope,created_at,updated_at,last_use_at,expires_at,description,user_id,admin_id)
VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v)`, sqlTableAPIKeys, sqlPlaceholders[0], sqlPlaceholders[1],
sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6],
sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10])
}
func getUpdateAPIKeyQuery() string {
return fmt.Sprintf(`UPDATE %v SET name=%v,scope=%v,expires_at=%v,user_id=%v,admin_id=%v,description=%v,updated_at=%v
WHERE key_id = %v`, sqlTableAPIKeys, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7])
}
func getDeleteAPIKeyQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE key_id = %v`, sqlTableAPIKeys, sqlPlaceholders[0])
}
func getRelatedUsersForAPIKeysQuery(apiKeys []APIKey) string {
var sb strings.Builder
for _, k := range apiKeys {
if k.userID == 0 {
continue
}
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(k.userID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
} else {
sb.WriteString("(0)")
}
return fmt.Sprintf(`SELECT id,username FROM %v WHERE id IN %v`, sqlTableUsers, sb.String())
}
func getRelatedAdminsForAPIKeysQuery(apiKeys []APIKey) string {
var sb strings.Builder
for _, k := range apiKeys {
if k.adminID == 0 {
continue
}
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(k.adminID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
} else {
sb.WriteString("(0)")
}
return fmt.Sprintf(`SELECT id,username FROM %v WHERE id IN %v`, sqlTableAdmins, sb.String())
}
func getUserByUsernameQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v WHERE username = %v`, selectUserFields, sqlTableUsers, sqlPlaceholders[0])
}
func getUsersQuery(order string) string {
return fmt.Sprintf(`SELECT %v FROM %v ORDER BY username %v LIMIT %v OFFSET %v`, selectUserFields, sqlTableUsers,
order, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getUsersForQuotaCheckQuery(numArgs int) string {
var sb strings.Builder
for idx := 0; idx < numArgs; idx++ {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(sqlPlaceholders[idx])
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT id,username,quota_size,used_quota_size,total_data_transfer,upload_data_transfer,
download_data_transfer,used_upload_data_transfer,used_download_data_transfer,filters FROM %v WHERE username IN %v`,
sqlTableUsers, sb.String())
}
func getRecentlyUpdatedUsersQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v WHERE updated_at >= %v`, selectUserFields, sqlTableUsers, sqlPlaceholders[0])
}
func getDumpUsersQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v`, selectUserFields, sqlTableUsers)
}
func getDumpFoldersQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v`, selectFolderFields, sqlTableFolders)
}
func getUpdateTransferQuotaQuery(reset bool) string {
if reset {
return fmt.Sprintf(`UPDATE %v SET used_upload_data_transfer = %v,used_download_data_transfer = %v,last_quota_update = %v
WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
return fmt.Sprintf(`UPDATE %v SET used_upload_data_transfer = used_upload_data_transfer + %v,
used_download_data_transfer = used_download_data_transfer + %v,last_quota_update = %v
WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
func getUpdateQuotaQuery(reset bool) string {
if reset {
return fmt.Sprintf(`UPDATE %v SET used_quota_size = %v,used_quota_files = %v,last_quota_update = %v
WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
return fmt.Sprintf(`UPDATE %v SET used_quota_size = used_quota_size + %v,used_quota_files = used_quota_files + %v,last_quota_update = %v
WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
func getSetUpdateAtQuery() string {
return fmt.Sprintf(`UPDATE %v SET updated_at = %v WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getUpdateLastLoginQuery() string {
return fmt.Sprintf(`UPDATE %v SET last_login = %v WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getUpdateAdminLastLoginQuery() string {
return fmt.Sprintf(`UPDATE %v SET last_login = %v WHERE username = %v`, sqlTableAdmins, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getUpdateAPIKeyLastUseQuery() string {
return fmt.Sprintf(`UPDATE %v SET last_use_at = %v WHERE key_id = %v`, sqlTableAPIKeys, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getUpdateShareLastUseQuery() string {
return fmt.Sprintf(`UPDATE %v SET last_use_at = %v, used_tokens = used_tokens +%v WHERE share_id = %v`,
sqlTableShares, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2])
}
func getQuotaQuery() string {
return fmt.Sprintf(`SELECT used_quota_size,used_quota_files,used_upload_data_transfer,
used_download_data_transfer FROM %v WHERE username = %v`,
sqlTableUsers, sqlPlaceholders[0])
}
func getAddUserQuery() string {
return fmt.Sprintf(`INSERT INTO %v (username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions,
used_quota_size,used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth,status,last_login,expiration_date,filters,
filesystem,additional_info,description,email,created_at,updated_at,upload_data_transfer,download_data_transfer,total_data_transfer,
used_upload_data_transfer,used_download_data_transfer)
VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,0,0,0,%v,%v,%v,0,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,0,0)`,
sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9],
sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], sqlPlaceholders[14],
sqlPlaceholders[15], sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18], sqlPlaceholders[19],
sqlPlaceholders[20], sqlPlaceholders[21], sqlPlaceholders[22], sqlPlaceholders[23])
}
func getUpdateUserQuery() string {
return fmt.Sprintf(`UPDATE %v SET password=%v,public_keys=%v,home_dir=%v,uid=%v,gid=%v,max_sessions=%v,quota_size=%v,
quota_files=%v,permissions=%v,upload_bandwidth=%v,download_bandwidth=%v,status=%v,expiration_date=%v,filters=%v,filesystem=%v,
additional_info=%v,description=%v,email=%v,updated_at=%v,upload_data_transfer=%v,download_data_transfer=%v,
total_data_transfer=%v WHERE id = %v`,
sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9],
sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], sqlPlaceholders[14],
sqlPlaceholders[15], sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18], sqlPlaceholders[19],
sqlPlaceholders[20], sqlPlaceholders[21], sqlPlaceholders[22])
}
func getUpdateUserPasswordQuery() string {
return fmt.Sprintf(`UPDATE %v SET password=%v WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDeleteUserQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE id = %v`, sqlTableUsers, sqlPlaceholders[0])
}
func getFolderByNameQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v WHERE name = %v`, selectFolderFields, sqlTableFolders, sqlPlaceholders[0])
}
func getAddFolderQuery() string {
return fmt.Sprintf(`INSERT INTO %v (path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem)
VALUES (%v,%v,%v,%v,%v,%v,%v)`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6])
}
func getUpdateFolderQuery() string {
return fmt.Sprintf(`UPDATE %v SET path=%v,description=%v,filesystem=%v WHERE name = %v`, sqlTableFolders, sqlPlaceholders[0],
sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
func getDeleteFolderQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE id = %v`, sqlTableFolders, sqlPlaceholders[0])
}
func getUpsertFolderQuery() string {
if config.Driver == MySQLDataProviderName {
return fmt.Sprintf("INSERT INTO %v (`path`,`used_quota_size`,`used_quota_files`,`last_quota_update`,`name`,"+
"`description`,`filesystem`) VALUES (%v,%v,%v,%v,%v,%v,%v) ON DUPLICATE KEY UPDATE "+
"`path`=VALUES(`path`),`description`=VALUES(`description`),`filesystem`=VALUES(`filesystem`)",
sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
sqlPlaceholders[5], sqlPlaceholders[6])
}
return fmt.Sprintf(`INSERT INTO %v (path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem)
VALUES (%v,%v,%v,%v,%v,%v,%v) ON CONFLICT (name) DO UPDATE SET path = EXCLUDED.path,description=EXCLUDED.description,
filesystem=EXCLUDED.filesystem`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6])
}
func getClearUserGroupMappingQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE user_id = (SELECT id FROM %v WHERE username = %v)`, sqlTableUsersGroupsMapping,
sqlTableUsers, sqlPlaceholders[0])
}
func getAddUserGroupMappingQuery() string {
return fmt.Sprintf(`INSERT INTO %v (user_id,group_id,group_type) VALUES ((SELECT id FROM %v WHERE username = %v),
(SELECT id FROM %v WHERE name = %v),%v)`,
sqlTableUsersGroupsMapping, sqlTableUsers, sqlPlaceholders[0], getSQLTableGroups(), sqlPlaceholders[1], sqlPlaceholders[2])
}
func getClearGroupFolderMappingQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE group_id = (SELECT id FROM %v WHERE name = %v)`, sqlTableGroupsFoldersMapping,
getSQLTableGroups(), sqlPlaceholders[0])
}
func getAddGroupFolderMappingQuery() string {
return fmt.Sprintf(`INSERT INTO %v (virtual_path,quota_size,quota_files,folder_id,group_id)
VALUES (%v,%v,%v,(SELECT id FROM %v WHERE name = %v),(SELECT id FROM %v WHERE name = %v))`,
sqlTableGroupsFoldersMapping, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlTableFolders,
sqlPlaceholders[3], getSQLTableGroups(), sqlPlaceholders[4])
}
func getClearUserFolderMappingQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE user_id = (SELECT id FROM %v WHERE username = %v)`, sqlTableUsersFoldersMapping,
sqlTableUsers, sqlPlaceholders[0])
}
func getAddUserFolderMappingQuery() string {
return fmt.Sprintf(`INSERT INTO %v (virtual_path,quota_size,quota_files,folder_id,user_id)
VALUES (%v,%v,%v,(SELECT id FROM %v WHERE name = %v),(SELECT id FROM %v WHERE username = %v))`,
sqlTableUsersFoldersMapping, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlTableFolders,
sqlPlaceholders[3], sqlTableUsers, sqlPlaceholders[4])
}
func getFoldersQuery(order string, minimal bool) string {
var fieldSelection string
if minimal {
fieldSelection = "id,name"
} else {
fieldSelection = selectFolderFields
}
return fmt.Sprintf(`SELECT %v FROM %v ORDER BY name %v LIMIT %v OFFSET %v`, fieldSelection, sqlTableFolders,
order, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getUpdateFolderQuotaQuery(reset bool) string {
if reset {
return fmt.Sprintf(`UPDATE %v SET used_quota_size = %v,used_quota_files = %v,last_quota_update = %v
WHERE name = %v`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
return fmt.Sprintf(`UPDATE %v SET used_quota_size = used_quota_size + %v,used_quota_files = used_quota_files + %v,last_quota_update = %v
WHERE name = %v`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
func getQuotaFolderQuery() string {
return fmt.Sprintf(`SELECT used_quota_size,used_quota_files FROM %v WHERE name = %v`, sqlTableFolders,
sqlPlaceholders[0])
}
func getRelatedGroupsForUsersQuery(users []User) string {
var sb strings.Builder
for _, u := range users {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(u.ID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT g.name,ug.group_type,ug.user_id FROM %v g INNER JOIN %v ug ON g.id = ug.group_id WHERE
ug.user_id IN %v ORDER BY ug.user_id`, getSQLTableGroups(), sqlTableUsersGroupsMapping, sb.String())
}
func getRelatedFoldersForUsersQuery(users []User) string {
var sb strings.Builder
for _, u := range users {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(u.ID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT f.id,f.name,f.path,f.used_quota_size,f.used_quota_files,f.last_quota_update,fm.virtual_path,
fm.quota_size,fm.quota_files,fm.user_id,f.filesystem,f.description FROM %v f INNER JOIN %v fm ON f.id = fm.folder_id WHERE
fm.user_id IN %v ORDER BY fm.user_id`, sqlTableFolders, sqlTableUsersFoldersMapping, sb.String())
}
func getRelatedUsersForFoldersQuery(folders []vfs.BaseVirtualFolder) string {
var sb strings.Builder
for _, f := range folders {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(f.ID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT fm.folder_id,u.username FROM %v fm INNER JOIN %v u ON fm.user_id = u.id
WHERE fm.folder_id IN %v ORDER BY fm.folder_id`, sqlTableUsersFoldersMapping, sqlTableUsers, sb.String())
}
func getRelatedGroupsForFoldersQuery(folders []vfs.BaseVirtualFolder) string {
var sb strings.Builder
for _, f := range folders {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(f.ID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT fm.folder_id,g.name FROM %v fm INNER JOIN %v g ON fm.group_id = g.id
WHERE fm.folder_id IN %v ORDER BY fm.folder_id`, sqlTableGroupsFoldersMapping, getSQLTableGroups(), sb.String())
}
func getRelatedUsersForGroupsQuery(groups []Group) string {
var sb strings.Builder
for _, g := range groups {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(g.ID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT um.group_id,u.username FROM %v um INNER JOIN %v u ON um.user_id = u.id
WHERE um.group_id IN %v ORDER BY um.group_id`, sqlTableUsersGroupsMapping, sqlTableUsers, sb.String())
}
func getRelatedFoldersForGroupsQuery(groups []Group) string {
var sb strings.Builder
for _, g := range groups {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(g.ID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT f.id,f.name,f.path,f.used_quota_size,f.used_quota_files,f.last_quota_update,fm.virtual_path,
fm.quota_size,fm.quota_files,fm.group_id,f.filesystem,f.description FROM %s f INNER JOIN %s fm ON f.id = fm.folder_id WHERE
fm.group_id IN %v ORDER BY fm.group_id`, sqlTableFolders, sqlTableGroupsFoldersMapping, sb.String())
}
func getActiveTransfersQuery() string {
return fmt.Sprintf(`SELECT transfer_id,connection_id,transfer_type,username,folder_name,ip,truncated_size,
current_ul_size,current_dl_size,created_at,updated_at FROM %v WHERE updated_at > %v`,
sqlTableActiveTransfers, sqlPlaceholders[0])
}
func getAddActiveTransferQuery() string {
return fmt.Sprintf(`INSERT INTO %v (transfer_id,connection_id,transfer_type,username,folder_name,ip,truncated_size,
current_ul_size,current_dl_size,created_at,updated_at) VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v)`,
sqlTableActiveTransfers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3],
sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8],
sqlPlaceholders[9], sqlPlaceholders[10])
}
func getUpdateActiveTransferSizesQuery() string {
return fmt.Sprintf(`UPDATE %v SET current_ul_size=%v,current_dl_size=%v,updated_at=%v WHERE connection_id = %v AND transfer_id = %v`,
sqlTableActiveTransfers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4])
}
func getRemoveActiveTransferQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE connection_id = %v AND transfer_id = %v`,
sqlTableActiveTransfers, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getCleanupActiveTransfersQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE updated_at < %v`, sqlTableActiveTransfers, sqlPlaceholders[0])
}
func getDatabaseVersionQuery() string {
return fmt.Sprintf("SELECT version from %v LIMIT 1", sqlTableSchemaVersion)
}
func getUpdateDBVersionQuery() string {
return fmt.Sprintf(`UPDATE %v SET version=%v`, sqlTableSchemaVersion, sqlPlaceholders[0])
}

View File

@@ -24,17 +24,18 @@ import (
"net"
"os"
"path"
"path/filepath"
"strings"
"time"
"github.com/sftpgo/sdk"
"github.com/drakkan/sftpgo/v2/internal/kms"
"github.com/drakkan/sftpgo/v2/internal/logger"
"github.com/drakkan/sftpgo/v2/internal/mfa"
"github.com/drakkan/sftpgo/v2/internal/plugin"
"github.com/drakkan/sftpgo/v2/internal/util"
"github.com/drakkan/sftpgo/v2/internal/vfs"
"github.com/drakkan/sftpgo/v2/kms"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/mfa"
"github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/vfs"
)
// Available permissions for SFTPGo users
@@ -142,8 +143,6 @@ type User struct {
fsCache map[string]vfs.Fs `json:"-"`
// true if group settings are already applied for this user
groupSettingsApplied bool `json:"-"`
// in multi node setups we mark the user as deleted to be able to update the webdav cache
DeletedAt int64 `json:"-"`
}
// GetFilesystem returns the base filesystem for this user
@@ -168,8 +167,6 @@ func (u *User) getRootFs(connectionID string) (fs vfs.Fs, err error) {
}
forbiddenSelfUsers = append(forbiddenSelfUsers, u.Username)
return vfs.NewSFTPFs(connectionID, "", u.GetHomeDir(), forbiddenSelfUsers, u.FsConfig.SFTPConfig)
case sdk.HTTPFilesystemProvider:
return vfs.NewHTTPFs(connectionID, u.GetHomeDir(), "", u.FsConfig.HTTPConfig)
default:
return vfs.NewOsFs(connectionID, u.GetHomeDir(), ""), nil
}
@@ -300,7 +297,7 @@ func (u *User) isFsEqual(other *User) bool {
if u.FsConfig.Provider == sdk.LocalFilesystemProvider && u.GetHomeDir() != other.GetHomeDir() {
return false
}
if !u.FsConfig.IsEqual(other.FsConfig) {
if !u.FsConfig.IsEqual(&other.FsConfig) {
return false
}
if u.Filters.StartDirectory != other.Filters.StartDirectory {
@@ -319,7 +316,7 @@ func (u *User) isFsEqual(other *User) bool {
if f.FsConfig.Provider == sdk.LocalFilesystemProvider && f.MappedPath != f1.MappedPath {
return false
}
if !f.FsConfig.IsEqual(f1.FsConfig) {
if !f.FsConfig.IsEqual(&f1.FsConfig) {
return false
}
}
@@ -373,20 +370,6 @@ func (u *User) GetSubDirPermissions() []sdk.DirectoryPermissions {
return result
}
func (u *User) setAnonymousSettings() {
for k := range u.Permissions {
u.Permissions[k] = []string{PermListItems, PermDownload}
}
u.Filters.DeniedProtocols = append(u.Filters.DeniedProtocols, protocolSSH, protocolHTTP)
u.Filters.DeniedProtocols = util.RemoveDuplicates(u.Filters.DeniedProtocols, false)
for _, method := range ValidLoginMethods {
if method != LoginMethodPassword {
u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, method)
}
}
u.Filters.DeniedLoginMethods = util.RemoveDuplicates(u.Filters.DeniedLoginMethods, false)
}
// RenderAsJSON implements the renderer interface used within plugins
func (u *User) RenderAsJSON(reload bool) ([]byte, error) {
if reload {
@@ -496,10 +479,16 @@ func (u *User) GetPermissionsForPath(p string) []string {
return permissions
}
func (u *User) getForbiddenSFTPSelfUsers(username string) ([]string, error) {
if allowSelfConnections == 0 {
return nil, nil
// HasBufferedSFTP returns true if the user has a SFTP filesystem with buffering enabled
func (u *User) HasBufferedSFTP(name string) bool {
fs := u.GetFsConfigForPath(name)
if fs.Provider == sdk.SFTPFilesystemProvider {
return fs.SFTPConfig.BufferSize > 0
}
return false
}
func (u *User) getForbiddenSFTPSelfUsers(username string) ([]string, error) {
sftpUser, err := UserExists(username)
if err == nil {
err = sftpUser.LoadAndApplyGroupSettings()
@@ -716,7 +705,7 @@ func (u *User) FilterListDir(dirContents []os.FileInfo, virtualPath string) []os
for dir := range vdirs {
if fi.Name() == dir {
if !fi.IsDir() {
fi = vfs.NewFileInfo(dir, true, 0, time.Unix(0, 0), false)
fi = vfs.NewFileInfo(dir, true, 0, time.Now(), false)
dirContents[index] = fi
}
delete(vdirs, dir)
@@ -738,7 +727,7 @@ func (u *User) FilterListDir(dirContents []os.FileInfo, virtualPath string) []os
}
for dir := range vdirs {
fi := vfs.NewFileInfo(dir, true, 0, time.Unix(0, 0), false)
fi := vfs.NewFileInfo(dir, true, 0, time.Now(), false)
dirContents = append(dirContents, fi)
}
return dirContents
@@ -787,10 +776,7 @@ func (u *User) HasVirtualFoldersInside(virtualPath string) bool {
// HasPermissionsInside returns true if the specified virtualPath has no permissions itself and
// no subdirs with defined permissions
func (u *User) HasPermissionsInside(virtualPath string) bool {
for dir, perms := range u.Permissions {
if len(perms) == 1 && perms[0] == PermAny {
continue
}
for dir := range u.Permissions {
if dir == virtualPath {
return true
} else if len(dir) > len(virtualPath) {
@@ -1164,7 +1150,7 @@ func (u *User) GetBandwidthForIP(clientIP, connectionID string) (int64, int64) {
// IsLoginFromAddrAllowed returns true if the login is allowed from the specified remoteAddr.
// If AllowedIP is defined only the specified IP/Mask can login.
// If DeniedIP is defined the specified IP/Mask cannot login.
// If an IP is both allowed and denied then login will be allowed
// If an IP is both allowed and denied then login will be denied
func (u *User) IsLoginFromAddrAllowed(remoteAddr string) bool {
if len(u.Filters.AllowedIP) == 0 && len(u.Filters.DeniedIP) == 0 {
return true
@@ -1175,15 +1161,6 @@ func (u *User) IsLoginFromAddrAllowed(remoteAddr string) bool {
logger.Warn(logSender, "", "login allowed for invalid IP. remote address: %#v", remoteAddr)
return true
}
for _, IPMask := range u.Filters.AllowedIP {
_, IPNet, err := net.ParseCIDR(IPMask)
if err != nil {
return false
}
if IPNet.Contains(remoteIP) {
return true
}
}
for _, IPMask := range u.Filters.DeniedIP {
_, IPNet, err := net.ParseCIDR(IPMask)
if err != nil {
@@ -1193,6 +1170,15 @@ func (u *User) IsLoginFromAddrAllowed(remoteAddr string) bool {
return false
}
}
for _, IPMask := range u.Filters.AllowedIP {
_, IPNet, err := net.ParseCIDR(IPMask)
if err != nil {
return false
}
if IPNet.Contains(remoteIP) {
return true
}
}
return len(u.Filters.AllowedIP) == 0
}
@@ -1415,8 +1401,6 @@ func (u *User) GetStorageDescrition() string {
return fmt.Sprintf("Encrypted: %v", u.GetHomeDir())
case sdk.SFTPFilesystemProvider:
return fmt.Sprintf("SFTP: %v", u.FsConfig.SFTPConfig.Endpoint)
case sdk.HTTPFilesystemProvider:
return fmt.Sprintf("HTTP: %v", u.FsConfig.HTTPConfig.Endpoint)
default:
return ""
}
@@ -1562,37 +1546,17 @@ func (u *User) HasSecondaryGroup(name string) bool {
return false
}
// HasMembershipGroup returns true if the user has the specified membership group
func (u *User) HasMembershipGroup(name string) bool {
for _, g := range u.Groups {
if g.Name == name {
return g.Type == sdk.GroupTypeMembership
}
}
return false
}
func (u *User) hasSettingsFromGroups() bool {
for _, g := range u.Groups {
if g.Type != sdk.GroupTypeMembership {
return true
}
}
return false
}
func (u *User) applyGroupSettings(groupsMapping map[string]Group) {
if !u.hasSettingsFromGroups() {
if len(u.Groups) == 0 {
return
}
if u.groupSettingsApplied {
return
}
replacer := u.getGroupPlacehodersReplacer()
for _, g := range u.Groups {
if g.Type == sdk.GroupTypePrimary {
if group, ok := groupsMapping[g.Name]; ok {
u.mergeWithPrimaryGroup(group, replacer)
u.mergeWithPrimaryGroup(group)
} else {
providerLog(logger.LevelError, "mapping not found for user %s, group %s", u.Username, g.Name)
}
@@ -1602,7 +1566,7 @@ func (u *User) applyGroupSettings(groupsMapping map[string]Group) {
for _, g := range u.Groups {
if g.Type == sdk.GroupTypeSecondary {
if group, ok := groupsMapping[g.Name]; ok {
u.mergeAdditiveProperties(group, sdk.GroupTypeSecondary, replacer)
u.mergeAdditiveProperties(group, sdk.GroupTypeSecondary)
} else {
providerLog(logger.LevelError, "mapping not found for user %s, group %s", u.Username, g.Name)
}
@@ -1613,7 +1577,7 @@ func (u *User) applyGroupSettings(groupsMapping map[string]Group) {
// LoadAndApplyGroupSettings update the user by loading and applying the group settings
func (u *User) LoadAndApplyGroupSettings() error {
if !u.hasSettingsFromGroups() {
if len(u.Groups) == 0 {
return nil
}
if u.groupSettingsApplied {
@@ -1625,19 +1589,16 @@ func (u *User) LoadAndApplyGroupSettings() error {
if g.Type == sdk.GroupTypePrimary {
primaryGroupName = g.Name
}
if g.Type != sdk.GroupTypeMembership {
names = append(names, g.Name)
}
names = append(names, g.Name)
}
groups, err := provider.getGroupsWithNames(names)
if err != nil {
return fmt.Errorf("unable to get groups: %w", err)
}
replacer := u.getGroupPlacehodersReplacer()
// make sure to always merge with the primary group first
for idx, g := range groups {
if g.Name == primaryGroupName {
u.mergeWithPrimaryGroup(g, replacer)
u.mergeWithPrimaryGroup(g)
lastIdx := len(groups) - 1
groups[idx] = groups[lastIdx]
groups = groups[:lastIdx]
@@ -1645,46 +1606,40 @@ func (u *User) LoadAndApplyGroupSettings() error {
}
}
for _, g := range groups {
u.mergeAdditiveProperties(g, sdk.GroupTypeSecondary, replacer)
u.mergeAdditiveProperties(g, sdk.GroupTypeSecondary)
}
u.removeDuplicatesAfterGroupMerge()
return nil
}
func (u *User) getGroupPlacehodersReplacer() *strings.Replacer {
return strings.NewReplacer("%username%", u.Username)
}
func (u *User) replacePlaceholder(value string, replacer *strings.Replacer) string {
func (u *User) replacePlaceholder(value string) string {
if value == "" {
return value
}
return replacer.Replace(value)
return strings.ReplaceAll(value, "%username%", u.Username)
}
func (u *User) replaceFsConfigPlaceholders(fsConfig vfs.Filesystem, replacer *strings.Replacer) vfs.Filesystem {
func (u *User) replaceFsConfigPlaceholders(fsConfig vfs.Filesystem) vfs.Filesystem {
switch fsConfig.Provider {
case sdk.S3FilesystemProvider:
fsConfig.S3Config.KeyPrefix = u.replacePlaceholder(fsConfig.S3Config.KeyPrefix, replacer)
fsConfig.S3Config.KeyPrefix = u.replacePlaceholder(fsConfig.S3Config.KeyPrefix)
case sdk.GCSFilesystemProvider:
fsConfig.GCSConfig.KeyPrefix = u.replacePlaceholder(fsConfig.GCSConfig.KeyPrefix, replacer)
fsConfig.GCSConfig.KeyPrefix = u.replacePlaceholder(fsConfig.GCSConfig.KeyPrefix)
case sdk.AzureBlobFilesystemProvider:
fsConfig.AzBlobConfig.KeyPrefix = u.replacePlaceholder(fsConfig.AzBlobConfig.KeyPrefix, replacer)
fsConfig.AzBlobConfig.KeyPrefix = u.replacePlaceholder(fsConfig.AzBlobConfig.KeyPrefix)
case sdk.SFTPFilesystemProvider:
fsConfig.SFTPConfig.Username = u.replacePlaceholder(fsConfig.SFTPConfig.Username, replacer)
fsConfig.SFTPConfig.Prefix = u.replacePlaceholder(fsConfig.SFTPConfig.Prefix, replacer)
case sdk.HTTPFilesystemProvider:
fsConfig.HTTPConfig.Username = u.replacePlaceholder(fsConfig.HTTPConfig.Username, replacer)
fsConfig.SFTPConfig.Username = u.replacePlaceholder(fsConfig.SFTPConfig.Username)
fsConfig.SFTPConfig.Prefix = u.replacePlaceholder(fsConfig.SFTPConfig.Prefix)
}
return fsConfig
}
func (u *User) mergeWithPrimaryGroup(group Group, replacer *strings.Replacer) {
func (u *User) mergeWithPrimaryGroup(group Group) {
if group.UserSettings.HomeDir != "" {
u.HomeDir = u.replacePlaceholder(group.UserSettings.HomeDir, replacer)
u.HomeDir = u.replacePlaceholder(group.UserSettings.HomeDir)
}
if group.UserSettings.FsConfig.Provider != 0 {
u.FsConfig = u.replaceFsConfigPlaceholders(group.UserSettings.FsConfig, replacer)
u.FsConfig = u.replaceFsConfigPlaceholders(group.UserSettings.FsConfig)
}
if u.MaxSessions == 0 {
u.MaxSessions = group.UserSettings.MaxSessions
@@ -1706,11 +1661,11 @@ func (u *User) mergeWithPrimaryGroup(group Group, replacer *strings.Replacer) {
u.DownloadDataTransfer = group.UserSettings.DownloadDataTransfer
u.TotalDataTransfer = group.UserSettings.TotalDataTransfer
}
u.mergePrimaryGroupFilters(group.UserSettings.Filters, replacer)
u.mergeAdditiveProperties(group, sdk.GroupTypePrimary, replacer)
u.mergePrimaryGroupFilters(group.UserSettings.Filters)
u.mergeAdditiveProperties(group, sdk.GroupTypePrimary)
}
func (u *User) mergePrimaryGroupFilters(filters sdk.BaseUserFilters, replacer *strings.Replacer) {
func (u *User) mergePrimaryGroupFilters(filters sdk.BaseUserFilters) {
if u.Filters.MaxUploadFileSize == 0 {
u.Filters.MaxUploadFileSize = filters.MaxUploadFileSize
}
@@ -1732,27 +1687,18 @@ func (u *User) mergePrimaryGroupFilters(filters sdk.BaseUserFilters, replacer *s
if !u.Filters.AllowAPIKeyAuth {
u.Filters.AllowAPIKeyAuth = filters.AllowAPIKeyAuth
}
if !u.Filters.IsAnonymous {
u.Filters.IsAnonymous = filters.IsAnonymous
}
if u.Filters.ExternalAuthCacheTime == 0 {
u.Filters.ExternalAuthCacheTime = filters.ExternalAuthCacheTime
}
if u.Filters.FTPSecurity == 0 {
u.Filters.FTPSecurity = filters.FTPSecurity
}
if u.Filters.StartDirectory == "" {
u.Filters.StartDirectory = u.replacePlaceholder(filters.StartDirectory, replacer)
}
if u.Filters.DefaultSharesExpiration == 0 {
u.Filters.DefaultSharesExpiration = filters.DefaultSharesExpiration
u.Filters.StartDirectory = u.replacePlaceholder(filters.StartDirectory)
}
}
func (u *User) mergeAdditiveProperties(group Group, groupType int, replacer *strings.Replacer) {
u.mergeVirtualFolders(group, groupType, replacer)
u.mergePermissions(group, groupType, replacer)
u.mergeFilePatterns(group, groupType, replacer)
func (u *User) mergeAdditiveProperties(group Group, groupType int) {
u.mergeVirtualFolders(group, groupType)
u.mergePermissions(group, groupType)
u.mergeFilePatterns(group, groupType)
u.Filters.BandwidthLimits = append(u.Filters.BandwidthLimits, group.UserSettings.Filters.BandwidthLimits...)
u.Filters.DataTransferLimits = append(u.Filters.DataTransferLimits, group.UserSettings.Filters.DataTransferLimits...)
u.Filters.AllowedIP = append(u.Filters.AllowedIP, group.UserSettings.Filters.AllowedIP...)
@@ -1763,7 +1709,7 @@ func (u *User) mergeAdditiveProperties(group Group, groupType int, replacer *str
u.Filters.TwoFactorAuthProtocols = append(u.Filters.TwoFactorAuthProtocols, group.UserSettings.Filters.TwoFactorAuthProtocols...)
}
func (u *User) mergeVirtualFolders(group Group, groupType int, replacer *strings.Replacer) {
func (u *User) mergeVirtualFolders(group Group, groupType int) {
if len(group.VirtualFolders) > 0 {
folderPaths := make(map[string]bool)
for _, folder := range u.VirtualFolders {
@@ -1773,17 +1719,17 @@ func (u *User) mergeVirtualFolders(group Group, groupType int, replacer *strings
if folder.VirtualPath == "/" && groupType != sdk.GroupTypePrimary {
continue
}
folder.VirtualPath = u.replacePlaceholder(folder.VirtualPath, replacer)
folder.VirtualPath = u.replacePlaceholder(folder.VirtualPath)
if _, ok := folderPaths[folder.VirtualPath]; !ok {
folder.MappedPath = u.replacePlaceholder(folder.MappedPath, replacer)
folder.FsConfig = u.replaceFsConfigPlaceholders(folder.FsConfig, replacer)
folder.MappedPath = u.replacePlaceholder(folder.MappedPath)
folder.FsConfig = u.replaceFsConfigPlaceholders(folder.FsConfig)
u.VirtualFolders = append(u.VirtualFolders, folder)
}
}
}
}
func (u *User) mergePermissions(group Group, groupType int, replacer *strings.Replacer) {
func (u *User) mergePermissions(group Group, groupType int) {
for k, v := range group.UserSettings.Permissions {
if k == "/" {
if groupType == sdk.GroupTypePrimary {
@@ -1792,14 +1738,14 @@ func (u *User) mergePermissions(group Group, groupType int, replacer *strings.Re
continue
}
}
k = u.replacePlaceholder(k, replacer)
k = u.replacePlaceholder(k)
if _, ok := u.Permissions[k]; !ok {
u.Permissions[k] = v
}
}
}
func (u *User) mergeFilePatterns(group Group, groupType int, replacer *strings.Replacer) {
func (u *User) mergeFilePatterns(group Group, groupType int) {
if len(group.UserSettings.Filters.FilePatterns) > 0 {
patternPaths := make(map[string]bool)
for _, pattern := range u.Filters.FilePatterns {
@@ -1809,7 +1755,7 @@ func (u *User) mergeFilePatterns(group Group, groupType int, replacer *strings.R
if pattern.Path == "/" && groupType != sdk.GroupTypePrimary {
continue
}
pattern.Path = u.replacePlaceholder(pattern.Path, replacer)
pattern.Path = u.replacePlaceholder(pattern.Path)
if _, ok := patternPaths[pattern.Path]; !ok {
u.Filters.FilePatterns = append(u.Filters.FilePatterns, pattern)
}
@@ -1896,8 +1842,6 @@ func (u *User) getACopy() User {
Status: u.Status,
ExpirationDate: u.ExpirationDate,
LastLogin: u.LastLogin,
FirstDownload: u.FirstDownload,
FirstUpload: u.FirstUpload,
AdditionalInfo: u.AdditionalInfo,
Description: u.Description,
CreatedAt: u.CreatedAt,
@@ -1915,3 +1859,8 @@ func (u *User) getACopy() User {
func (u *User) GetEncryptionAdditionalData() string {
return u.Username
}
// GetGCSCredentialsFilePath returns the path for GCS credentials
func (u *User) GetGCSCredentialsFilePath() string {
return filepath.Join(credentialsDirPath, fmt.Sprintf("%v_gcs_credentials.json", u.Username))
}

View File

@@ -4,14 +4,12 @@ SFTPGo provides an official Docker image, it is available on both [Docker Hub](h
## Supported tags and respective Dockerfile links
- [v2.4.0, v2.4, v2, latest](https://github.com/drakkan/sftpgo/blob/v2.4.0/Dockerfile)
- [v2.4.0-plugins, v2.4-plugins, v2-plugins, plugins](https://github.com/drakkan/sftpgo/blob/v2.4.0/Dockerfile)
- [v2.4.0-alpine, v2.4-alpine, v2-alpine, alpine](https://github.com/drakkan/sftpgo/blob/v2.4.0/Dockerfile.alpine)
- [v2.4.0-slim, v2.4-slim, v2-slim, slim](https://github.com/drakkan/sftpgo/blob/v2.4.0/Dockerfile)
- [v2.4.0-alpine-slim, v2.4-alpine-slim, v2-alpine-slim, alpine-slim](https://github.com/drakkan/sftpgo/blob/v2.4.0/Dockerfile.alpine)
- [v2.4.0-distroless-slim, v2.4-distroless-slim, v2-distroless-slim, distroless-slim](https://github.com/drakkan/sftpgo/blob/v2.4.0/Dockerfile.distroless)
- [v2.3.3, v2.3, v2, latest](https://github.com/drakkan/sftpgo/blob/v2.3.3/Dockerfile)
- [v2.3.3-alpine, v2.3-alpine, v2-alpine, alpine](https://github.com/drakkan/sftpgo/blob/v2.3.3/Dockerfile.alpine)
- [v2.3.3-slim, v2.3-slim, v2-slim, slim](https://github.com/drakkan/sftpgo/blob/v2.3.3/Dockerfile)
- [v2.3.3-alpine-slim, v2.3-alpine-slim, v2-alpine-slim, alpine-slim](https://github.com/drakkan/sftpgo/blob/v2.3.3/Dockerfile.alpine)
- [v2.3.3-distroless-slim, v2.3-distroless-slim, v2-distroless-slim, distroless-slim](https://github.com/drakkan/sftpgo/blob/v2.3.3/Dockerfile.distroless)
- [edge](../Dockerfile)
- [edge-plugins](../Dockerfile)
- [edge-alpine](../Dockerfile.alpine)
- [edge-slim](../Dockerfile)
- [edge-alpine-slim](../Dockerfile.alpine)
@@ -199,17 +197,8 @@ We only provide the slim variant and so the optional `git` dependency is not ava
### `sftpgo:<suite>-slim`
These tags provide a slimmer image that does not include `jq` and the optional `git` and `rsync` dependencies.
### `sftpgo:<suite>-plugins`
These tags provide the standard image with the addition of all "official" plugins installed in `/usr/local/bin`.
These tags provide a slimmer image that does not include the optional `git` dependency.
## Helm Chart
Some helm charts are available:
- [sagikazarmark/sftpgo](https://artifacthub.io/packages/helm/sagikazarmark/sftpgo)
- [truecharts/sftpgo](https://artifacthub.io/packages/helm/truecharts/sftpgo)
These charts are not maintained by the SFTPGo project and any issues with the charts should be raised to the upstream repo.
An helm chart is [available](https://artifacthub.io/packages/helm/sagikazarmark/sftpgo). You can find the source code [here](https://github.com/sagikazarmark/helm-charts/tree/master/charts/sftpgo).

View File

@@ -1,25 +0,0 @@
#!/usr/bin/env bash
set -e
ARCH=`uname -m`
case ${ARCH} in
"x86_64")
SUFFIX=amd64
;;
"aarch64")
SUFFIX=arm64
;;
*)
SUFFIX=ppc64le
;;
esac
echo "download plugins for arch ${SUFFIX}"
for PLUGIN in geoipfilter kms pubsub eventstore eventsearch metadata
do
echo "download plugin from https://github.com/sftpgo/sftpgo-plugin-${PLUGIN}/releases/latest/download/sftpgo-plugin-${PLUGIN}-linux-${SUFFIX}"
curl -L "https://github.com/sftpgo/sftpgo-plugin-${PLUGIN}/releases/latest/download/sftpgo-plugin-${PLUGIN}-linux-${SUFFIX}" --output "/usr/local/bin/sftpgo-plugin-${PLUGIN}"
chmod 755 "/usr/local/bin/sftpgo-plugin-${PLUGIN}"
done

View File

@@ -0,0 +1,28 @@
#!/usr/bin/env bash
SFTPGO_PUID=${SFTPGO_PUID:-1000}
SFTPGO_PGID=${SFTPGO_PGID:-1000}
if [ "$1" = 'sftpgo' ]; then
if [ "$(id -u)" = '0' ]; then
for DIR in "/etc/sftpgo" "/var/lib/sftpgo" "/srv/sftpgo"
do
DIR_UID=$(stat -c %u ${DIR})
DIR_GID=$(stat -c %g ${DIR})
if [ ${DIR_UID} != ${SFTPGO_PUID} ] || [ ${DIR_GID} != ${SFTPGO_PGID} ]; then
echo '{"level":"info","time":"'`date +%Y-%m-%dT%H:%M:%S.000`'","sender":"entrypoint","message":"change owner for \"'${DIR}'\" UID: '${SFTPGO_PUID}' GID: '${SFTPGO_PGID}'"}'
if [ ${DIR} = "/etc/sftpgo" ]; then
chown -R ${SFTPGO_PUID}:${SFTPGO_PGID} ${DIR}
else
chown ${SFTPGO_PUID}:${SFTPGO_PGID} ${DIR}
fi
fi
done
echo '{"level":"info","time":"'`date +%Y-%m-%dT%H:%M:%S.000`'","sender":"entrypoint","message":"run as UID: '${SFTPGO_PUID}' GID: '${SFTPGO_PGID}'"}'
exec su-exec ${SFTPGO_PUID}:${SFTPGO_PGID} "$@"
fi
exec "$@"
fi
exec "$@"

32
docker/scripts/entrypoint.sh Executable file
View File

@@ -0,0 +1,32 @@
#!/usr/bin/env bash
SFTPGO_PUID=${SFTPGO_PUID:-1000}
SFTPGO_PGID=${SFTPGO_PGID:-1000}
if [ "$1" = 'sftpgo' ]; then
if [ "$(id -u)" = '0' ]; then
getent passwd ${SFTPGO_PUID} > /dev/null
HAS_PUID=$?
getent group ${SFTPGO_PGID} > /dev/null
HAS_PGID=$?
if [ ${HAS_PUID} -ne 0 ] || [ ${HAS_PGID} -ne 0 ]; then
echo '{"level":"info","time":"'`date +%Y-%m-%dT%H:%M:%S.%3N`'","sender":"entrypoint","message":"prepare to run as UID: '${SFTPGO_PUID}' GID: '${SFTPGO_PGID}'"}'
if [ ${HAS_PGID} -ne 0 ]; then
echo '{"level":"info","time":"'`date +%Y-%m-%dT%H:%M:%S.%3N`'","sender":"entrypoint","message":"set GID to: '${SFTPGO_PGID}'"}'
groupmod -g ${SFTPGO_PGID} sftpgo
fi
if [ ${HAS_PUID} -ne 0 ]; then
echo '{"level":"info","time":"'`date +%Y-%m-%dT%H:%M:%S.%3N`'","sender":"entrypoint","message":"set UID to: '${SFTPGO_PUID}'"}'
usermod -u ${SFTPGO_PUID} sftpgo
fi
chown -R ${SFTPGO_PUID}:${SFTPGO_PGID} /etc/sftpgo
chown ${SFTPGO_PUID}:${SFTPGO_PGID} /var/lib/sftpgo /srv/sftpgo
fi
echo '{"level":"info","time":"'`date +%Y-%m-%dT%H:%M:%S.%3N`'","sender":"entrypoint","message":"run as UID: '${SFTPGO_PUID}' GID: '${SFTPGO_PGID}'"}'
exec gosu ${SFTPGO_PUID}:${SFTPGO_PGID} "$@"
fi
exec "$@"
fi
exec "$@"

View File

@@ -0,0 +1,50 @@
FROM golang:alpine as builder
RUN apk add --no-cache git gcc g++ ca-certificates \
&& go get -v -d github.com/drakkan/sftpgo
WORKDIR /go/src/github.com/drakkan/sftpgo
ARG TAG
ARG FEATURES
# Use --build-arg TAG=LATEST for latest tag. Use e.g. --build-arg TAG=v1.0.0 for a specific tag/commit. Otherwise HEAD (master) is built.
RUN git checkout $(if [ "${TAG}" = LATEST ]; then echo `git rev-list --tags --max-count=1`; elif [ -n "${TAG}" ]; then echo "${TAG}"; else echo HEAD; fi)
RUN go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -ldflags "-s -w -X github.com/drakkan/sftpgo/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/version.date=`date -u +%FT%TZ`" -v -o /go/bin/sftpgo
FROM alpine:latest
RUN apk add --no-cache ca-certificates su-exec \
&& mkdir -p /data /etc/sftpgo /srv/sftpgo/config /srv/sftpgo/web /srv/sftpgo/backups
# git and rsync are optional, uncomment the next line to add support for them if needed.
#RUN apk add --no-cache git rsync
COPY --from=builder /go/bin/sftpgo /bin/
COPY --from=builder /go/src/github.com/drakkan/sftpgo/sftpgo.json /etc/sftpgo/sftpgo.json
COPY --from=builder /go/src/github.com/drakkan/sftpgo/templates /srv/sftpgo/web/templates
COPY --from=builder /go/src/github.com/drakkan/sftpgo/static /srv/sftpgo/web/static
COPY docker-entrypoint.sh /bin/entrypoint.sh
RUN chmod +x /bin/entrypoint.sh
VOLUME [ "/data", "/srv/sftpgo/config", "/srv/sftpgo/backups" ]
EXPOSE 2022 8080
# uncomment the following settings to enable FTP support
#ENV SFTPGO_FTPD__BIND_PORT=2121
#ENV SFTPGO_FTPD__FORCE_PASSIVE_IP=<your FTP visibile IP here>
#EXPOSE 2121
# we need to expose the passive ports range too
#EXPOSE 50000-50100
# it is a good idea to provide certificates to enable FTPS too
#ENV SFTPGO_FTPD__CERTIFICATE_FILE=/srv/sftpgo/config/mycert.crt
#ENV SFTPGO_FTPD__CERTIFICATE_KEY_FILE=/srv/sftpgo/config/mycert.key
# uncomment the following setting to enable WebDAV support
#ENV SFTPGO_WEBDAVD__BIND_PORT=8090
# it is a good idea to provide certificates to enable WebDAV over HTTPS
#ENV SFTPGO_WEBDAVD__CERTIFICATE_FILE=${CONFIG_DIR}/mycert.crt
#ENV SFTPGO_WEBDAVD__CERTIFICATE_KEY_FILE=${CONFIG_DIR}/mycert.key
ENTRYPOINT ["/bin/entrypoint.sh"]
CMD ["serve"]

View File

@@ -0,0 +1,61 @@
# SFTPGo with Docker and Alpine
:warning: The recommended way to run SFTPGo on Docker is to use the official [images](https://hub.docker.com/r/drakkan/sftpgo). The documentation here is now obsolete.
This DockerFile is made to build image to host multiple instances of SFTPGo started with different users.
## Example
> 1003 is a custom uid:gid for this instance of SFTPGo
```bash
# Prereq on docker host
sudo groupadd -g 1003 sftpgrp && \
sudo useradd -u 1003 -g 1003 sftpuser -d /home/sftpuser/ && \
sudo -u sftpuser mkdir /home/sftpuser/{conf,data} && \
curl https://raw.githubusercontent.com/drakkan/sftpgo/master/sftpgo.json -o /home/sftpuser/conf/sftpgo.json
# Edit sftpgo.json as you need
# Get and build SFTPGo image.
# Add --build-arg TAG=LATEST to build the latest tag or e.g. TAG=v1.0.0 for a specific tag/commit.
# Add --build-arg FEATURES=<build features comma separated> to specify the features to build.
git clone https://github.com/drakkan/sftpgo.git && \
cd sftpgo && \
sudo docker build -t sftpgo docker/sftpgo/alpine/
# Initialize the configured provider. For PostgreSQL and MySQL providers you need to create the configured database and the "initprovider" command will create the required tables.
sudo docker run --name sftpgo \
-e PUID=1003 \
-e GUID=1003 \
-v /home/sftpuser/conf/:/srv/sftpgo/config \
sftpgo initprovider -c /srv/sftpgo/config
# Start the image
sudo docker rm sftpgo && sudo docker run --name sftpgo \
-e SFTPGO_LOG_FILE_PATH= \
-e SFTPGO_CONFIG_DIR=/srv/sftpgo/config \
-e SFTPGO_HTTPD__TEMPLATES_PATH=/srv/sftpgo/web/templates \
-e SFTPGO_HTTPD__STATIC_FILES_PATH=/srv/sftpgo/web/static \
-e SFTPGO_HTTPD__BACKUPS_PATH=/srv/sftpgo/backups \
-p 8080:8080 \
-p 2022:2022 \
-e PUID=1003 \
-e GUID=1003 \
-v /home/sftpuser/conf/:/srv/sftpgo/config \
-v /home/sftpuser/data:/data \
-v /home/sftpuser/backups:/srv/sftpgo/backups \
sftpgo
```
If you want to enable FTP/S you also need the publish the FTP port and the FTP passive port range, defined in your `Dockerfile`, by adding, for example, the following options to the `docker run` command `-p 2121:2121 -p 50000-50100:50000-50100`. The same goes for WebDAV, you need to publish the configured port.
The script `entrypoint.sh` makes sure to correct the permissions of directories and start the process with the right user.
Several images can be run with different parameters.
## Custom systemd script
An example of systemd script is present [here](sftpgo.service), with `Environment` parameter to set `PUID` and `GUID`
`WorkingDirectory` parameter must be exist with one file in this directory like `sftpgo-${PUID}.env` corresponding to the variable file for SFTPGo instance.

View File

@@ -0,0 +1,7 @@
#!/bin/sh
set -eu
chown -R "${PUID}:${GUID}" /data /etc/sftpgo /srv/sftpgo/config /srv/sftpgo/backups \
&& exec su-exec "${PUID}:${GUID}" \
/bin/sftpgo "$@"

View File

@@ -0,0 +1,35 @@
[Unit]
Description=SFTPGo server
After=docker.service
[Service]
User=root
Group=root
WorkingDirectory=/etc/sftpgo
Environment=PUID=1003
Environment=GUID=1003
EnvironmentFile=-/etc/sysconfig/sftpgo.env
ExecStartPre=-docker kill sftpgo
ExecStartPre=-docker rm sftpgo
ExecStart=docker run --name sftpgo \
--env-file sftpgo-${PUID}.env \
-e PUID=${PUID} \
-e GUID=${GUID} \
-e SFTPGO_LOG_FILE_PATH= \
-e SFTPGO_CONFIG_DIR=/srv/sftpgo/config \
-e SFTPGO_HTTPD__TEMPLATES_PATH=/srv/sftpgo/web/templates \
-e SFTPGO_HTTPD__STATIC_FILES_PATH=/srv/sftpgo/web/static \
-e SFTPGO_HTTPD__BACKUPS_PATH=/srv/sftpgo/backups \
-p 8080:8080 \
-p 2022:2022 \
-v /home/sftpuser/conf/:/srv/sftpgo/config \
-v /home/sftpuser/data:/data \
-v /home/sftpuser/backups:/srv/sftpgo/backups \
sftpgo
ExecStop=docker stop sftpgo
SyslogIdentifier=sftpgo
Restart=always
RestartSec=10s
[Install]
WantedBy=multi-user.target

View File

@@ -0,0 +1,93 @@
# we use a multi stage build to have a separate build and run env
FROM golang:latest as buildenv
LABEL maintainer="nicola.murino@gmail.com"
RUN go get -v -d github.com/drakkan/sftpgo
WORKDIR /go/src/github.com/drakkan/sftpgo
ARG TAG
ARG FEATURES
# Use --build-arg TAG=LATEST for latest tag. Use e.g. --build-arg TAG=v1.0.0 for a specific tag/commit. Otherwise HEAD (master) is built.
RUN git checkout $(if [ "${TAG}" = LATEST ]; then echo `git rev-list --tags --max-count=1`; elif [ -n "${TAG}" ]; then echo "${TAG}"; else echo HEAD; fi)
RUN go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -ldflags "-s -w -X github.com/drakkan/sftpgo/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/version.date=`date -u +%FT%TZ`" -v -o sftpgo
# now define the run environment
FROM debian:latest
# ca-certificates is needed for Cloud Storage Support and for HTTPS/FTPS.
RUN apt-get update && apt-get install -y ca-certificates && apt-get clean
# git and rsync are optional, uncomment the next line to add support for them if needed.
#RUN apt-get update && apt-get install -y git rsync && apt-get clean
ARG BASE_DIR=/app
ARG DATA_REL_DIR=data
ARG CONFIG_REL_DIR=config
ARG BACKUP_REL_DIR=backups
ARG USERNAME=sftpgo
ARG GROUPNAME=sftpgo
ARG UID=515
ARG GID=515
ARG WEB_REL_PATH=web
# HOME_DIR for sftpgo itself
ENV HOME_DIR=${BASE_DIR}/${USERNAME}
# DATA_DIR, this is a volume that you can use hold user's home dirs
ENV DATA_DIR=${BASE_DIR}/${DATA_REL_DIR}
# CONFIG_DIR, this is a volume to persist the daemon private keys, configuration file ecc..
ENV CONFIG_DIR=${BASE_DIR}/${CONFIG_REL_DIR}
# BACKUPS_DIR, this is a volume to store backups done using "dumpdata" REST API
ENV BACKUPS_DIR=${BASE_DIR}/${BACKUP_REL_DIR}
ENV WEB_DIR=${BASE_DIR}/${WEB_REL_PATH}
RUN mkdir -p ${DATA_DIR} ${CONFIG_DIR} ${WEB_DIR} ${BACKUPS_DIR}
RUN groupadd --system -g ${GID} ${GROUPNAME}
RUN useradd --system --create-home --no-log-init --home-dir ${HOME_DIR} --comment "SFTPGo user" --shell /usr/sbin/nologin --gid ${GID} --uid ${UID} ${USERNAME}
WORKDIR ${HOME_DIR}
RUN mkdir -p bin .config/sftpgo
ENV PATH ${HOME_DIR}/bin:$PATH
COPY --from=buildenv /go/src/github.com/drakkan/sftpgo/sftpgo bin/sftpgo
# default config file to use if no config file is found inside the CONFIG_DIR volume.
# You can override each configuration options via env vars too
COPY --from=buildenv /go/src/github.com/drakkan/sftpgo/sftpgo.json .config/sftpgo/
COPY --from=buildenv /go/src/github.com/drakkan/sftpgo/templates ${WEB_DIR}/templates
COPY --from=buildenv /go/src/github.com/drakkan/sftpgo/static ${WEB_DIR}/static
RUN chown -R ${UID}:${GID} ${DATA_DIR} ${BACKUPS_DIR}
# run as non root user
USER ${USERNAME}
EXPOSE 2022 8080
# the defined volumes must have write access for the UID and GID defined above
VOLUME [ "$DATA_DIR", "$CONFIG_DIR", "$BACKUPS_DIR" ]
# override some default configuration options using env vars
ENV SFTPGO_CONFIG_DIR=${CONFIG_DIR}
# setting SFTPGO_LOG_FILE_PATH to an empty string will log to stdout
ENV SFTPGO_LOG_FILE_PATH=""
ENV SFTPGO_HTTPD__BIND_ADDRESS=""
ENV SFTPGO_HTTPD__TEMPLATES_PATH=${WEB_DIR}/templates
ENV SFTPGO_HTTPD__STATIC_FILES_PATH=${WEB_DIR}/static
ENV SFTPGO_DATA_PROVIDER__USERS_BASE_DIR=${DATA_DIR}
ENV SFTPGO_HTTPD__BACKUPS_PATH=${BACKUPS_DIR}
# uncomment the following settings to enable FTP support
#ENV SFTPGO_FTPD__BIND_PORT=2121
#ENV SFTPGO_FTPD__FORCE_PASSIVE_IP=<your FTP visibile IP here>
#EXPOSE 2121
# we need to expose the passive ports range too
#EXPOSE 50000-50100
# it is a good idea to provide certificates to enable FTPS too
#ENV SFTPGO_FTPD__CERTIFICATE_FILE=${CONFIG_DIR}/mycert.crt
#ENV SFTPGO_FTPD__CERTIFICATE_KEY_FILE=${CONFIG_DIR}/mycert.key
# uncomment the following setting to enable WebDAV support
#ENV SFTPGO_WEBDAVD__BIND_PORT=8090
# it is a good idea to provide certificates to enable WebDAV over HTTPS
#ENV SFTPGO_WEBDAVD__CERTIFICATE_FILE=${CONFIG_DIR}/mycert.crt
#ENV SFTPGO_WEBDAVD__CERTIFICATE_KEY_FILE=${CONFIG_DIR}/mycert.key
ENTRYPOINT ["sftpgo"]
CMD ["serve"]

View File

@@ -0,0 +1,59 @@
# Dockerfile based on Debian stable
:warning: The recommended way to run SFTPGo on Docker is to use the official [images](https://hub.docker.com/r/drakkan/sftpgo). The documentation here is now obsolete.
Please read the comments inside the `Dockerfile` to learn how to customize things for your setup.
You can build the container image using `docker build`, for example:
```bash
docker build -t="drakkan/sftpgo" .
```
This will build master of github.com/drakkan/sftpgo.
To build the latest tag you can add `--build-arg TAG=LATEST` and to build a specific tag/commit you can use for example `TAG=v1.0.0`, like this:
```bash
docker build -t="drakkan/sftpgo" --build-arg TAG=v1.0.0 .
```
To specify the features to build you can add `--build-arg FEATURES=<build features comma separated>`. For example you can disable SQLite and S3 support like this:
```bash
docker build -t="drakkan/sftpgo" --build-arg FEATURES=nosqlite,nos3 .
```
Please take a look at the [build from source](./../../../docs/build-from-source.md) documentation for the complete list of the features that can be disabled.
Now create the required folders on the host system, for example:
```bash
sudo mkdir -p /srv/sftpgo/data /srv/sftpgo/config /srv/sftpgo/backups
```
and give write access to them to the UID/GID defined inside the `Dockerfile`. You can choose to create a new user, on the host system, with a matching UID/GID pair, or simply do something like this:
```bash
sudo chown -R <UID>:<GID> /srv/sftpgo/data /srv/sftpgo/config /srv/sftpgo/backups
```
Download the default configuration file and edit it as you need:
```bash
sudo curl https://raw.githubusercontent.com/drakkan/sftpgo/master/sftpgo.json -o /srv/sftpgo/config/sftpgo.json
```
Initialize the configured provider. For PostgreSQL and MySQL providers you need to create the configured database and the `initprovider` command will create the required tables:
```bash
docker run --name sftpgo --mount type=bind,source=/srv/sftpgo/config,target=/app/config drakkan/sftpgo initprovider -c /app/config
```
and finally you can run the image using something like this:
```bash
docker rm sftpgo && docker run --name sftpgo -p 8080:8080 -p 2022:2022 --mount type=bind,source=/srv/sftpgo/data,target=/app/data --mount type=bind,source=/srv/sftpgo/config,target=/app/config --mount type=bind,source=/srv/sftpgo/backups,target=/app/backups drakkan/sftpgo
```
If you want to enable FTP/S you also need the publish the FTP port and the FTP passive port range, defined in your `Dockerfile`, by adding, for example, the following options to the `docker run` command `-p 2121:2121 -p 50000-50100:50000-50100`. The same goes for WebDAV, you need to publish the configured port.

Some files were not shown because too many files have changed in this diff Show More