Compare commits


27 Commits

Author SHA1 Message Date
Nicola Murino
665016ed1e set version to 2.3.3
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-08-05 09:47:00 +02:00
Nicola Murino
97d5680d1e azblob: fix SAS URL with embedded container name
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-08-01 21:52:32 +02:00
Nicola Murino
e7866047aa allow users logged in via OIDC to edit their profile
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-08-01 21:49:53 +02:00
Nicola Murino
5f313cc6be macOS: add config file search path
this way the default config file is used in the brew package if no config file is set

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-29 17:55:01 +02:00
Nicola Murino
3c2c703408 user templates: apply placeholders also for start directory
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-27 19:09:54 +02:00
Nicola Murino
78a399eed4 download as zip: improve filename
include the username, and also the filename/directory name if the user downloads a single file/directory

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-26 17:53:04 +02:00
Nicola Murino
e6d434654d backport from main
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-24 08:56:31 +02:00
Nicola Murino
d34446e6e9 web client: add HTML5 player
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-23 16:30:27 +02:00
Nicola Murino
2da19ef233 backport OIDC related changes from main
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-23 15:31:57 +02:00
Nicola Murino
b34bc2b818 add license header to source files
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-18 13:43:25 +02:00
Nicola Murino
378995147b try to better highlight donation and sponsorship options ...
... and to better explain why they are required.

Please don't say "someone else will help the project, I'll just use it"

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-16 20:29:10 +02:00
Nicola Murino
6b995db864 oidc: allow to configure oauth2 scopes
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-16 19:25:04 +02:00
Nicola Murino
371012a46e backport some fixes from main
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-07-15 20:09:06 +02:00
Nicola Murino
d3d788c8d0 s3: improve rename performance
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-30 18:25:40 +02:00
maximethebault
756b122ab8 S3: Fix timeout error when renaming large files (#899)
Remove AWS SDK Transport ResponseHeaderTimeout (finer-grained timeout are already handled by the callers)
Lower the threshold for MultipartCopy (5GB -> 500MB) to improve copy performance and reduce chance of hitting Single part copy timeout

Fixes #898

Signed-off-by: Maxime Thébault <contact@maximethebault.me>
2022-06-30 10:25:04 +02:00
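The commit above lowers the size threshold at which a server-side S3 copy switches to multipart. A minimal Go sketch of that decision logic follows; the 500MB constant mirrors the commit message, while copyObject and the printed operation names are illustrative, not SFTPGo's actual code.

package main

import "fmt"

// Threshold from the commit above: objects larger than 500MB are copied in
// parts, so no single server-side copy can hit a response-header timeout.
const multipartCopyThreshold = 500 * 1024 * 1024

// copyObject is a hypothetical helper: it only shows which S3 operation a
// copy of the given size would use (CopyObject vs. UploadPartCopy per part).
func copyObject(sizeBytes int64) {
	if sizeBytes > multipartCopyThreshold {
		fmt.Printf("%d bytes: multipart copy (UploadPartCopy)\n", sizeBytes)
		return
	}
	fmt.Printf("%d bytes: single-part copy (CopyObject)\n", sizeBytes)
}

func main() {
	copyObject(100 * 1024 * 1024)  // below threshold: single-part
	copyObject(2048 * 1024 * 1024) // above threshold: multipart
}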
Nicola Murino
e244ba37b2 config: fix replace from env vars for some sub lists
ensure to merge configuration from files with configuration from env

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-28 19:17:16 +02:00
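SFTPGo merges file configuration with SFTPGO_* environment overrides (the SFTPGO_DATA_PROVIDER__DRIVER style visible in the CI diffs below). A minimal sketch of the file-then-env precedence this commit restores; the helper is illustrative, not SFTPGo's internals.

package main

import (
	"fmt"
	"os"
)

// configValue returns the environment override when present, otherwise the
// value loaded from the configuration file.
func configValue(fromFile, envKey string) string {
	if v, ok := os.LookupEnv(envKey); ok {
		return v // environment wins over the file
	}
	return fromFile
}

func main() {
	os.Setenv("SFTPGO_DATA_PROVIDER__DRIVER", "bolt")
	// File says "sqlite", environment wins: prints "bolt".
	fmt.Println(configValue("sqlite", "SFTPGO_DATA_PROVIDER__DRIVER"))
}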
Nicola Murino
5610b98d19 fix get branding from env
Fixes #895

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-28 10:46:25 +02:00
Nicola Murino
b3ca20b5e6 dataprovider: fix sql tables prefix handling
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-24 12:26:43 +02:00
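The data provider prepends a configurable prefix to SQL table names so the schema can coexist with other applications in one database. A hedged sketch of the idea; the names are illustrative, not SFTPGo's internals.

package main

import "fmt"

// tableName expands a base table name with the configured prefix.
func tableName(prefix, base string) string {
	return prefix + base
}

func main() {
	prefix := "sftpgo_"
	fmt.Printf("SELECT id FROM %s WHERE username = ?\n", tableName(prefix, "users"))
}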
Nicola Murino
d0b6ca8d2f backup: include folders set on groups
Fixes #885

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-21 14:13:25 +02:00
Nicola Murino
550158ff4b fix database reset
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-13 19:40:24 +02:00
Nicola Murino
14a3803c8f OpenAPI schema: improve compatibility with some generators
Fixes #875

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-11 19:00:04 +02:00
Nicola Murino
ca4da2f64e set version to 2.3.1
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-10 18:42:13 +02:00
Nicola Murino
049c2b7430 mysql: groups is a reserved keyword since MySQL 8.0.2
add mysql to CI

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-10 17:36:26 +02:00
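GROUPS is reserved in MySQL since 8.0.2 (it was introduced for window functions), so a table named groups must be quoted with backticks. A hedged sketch assuming the go-sql-driver/mysql driver; the connection details are illustrative.

package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/go-sql-driver/mysql"
)

func main() {
	// DSN is illustrative; adjust credentials and host for a real server.
	db, err := sql.Open("mysql", "sftpgo:password@tcp(127.0.0.1:3306)/sftpgo")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Unquoted "FROM groups" is a syntax error on MySQL >= 8.0.2;
	// the backticks make the identifier explicit.
	var count int
	if err := db.QueryRow("SELECT COUNT(*) FROM `groups`").Scan(&count); err != nil {
		log.Fatal(err)
	}
	fmt.Println("groups:", count)
}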
Nicola Murino
7fd5558400 parse IP proxy header also when listening on a UNIX domain socket
Fixes #867

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-09 09:48:39 +02:00
Nicola Murino
b60255752f web UIs: fix date formatting on Safari
Fixes #869

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-09 09:47:02 +02:00
Nicola Murino
37f79650c8 APT and YUM repositories are now available
This is possible thanks to Oregon State University's free
mirroring service

Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-09 09:46:57 +02:00
Nicola Murino
8988d6542b create branch 2.3.x
Signed-off-by: Nicola Murino <nicola.murino@gmail.com>
2022-06-06 19:01:24 +02:00
340 changed files with 8138 additions and 33027 deletions

File: CI workflow (.github/workflows)

@@ -2,7 +2,7 @@ name: CI
 on:
   push:
-    branches: [main]
+    branches: [2.3.x]
   pull_request:
 jobs:

@@ -11,11 +11,11 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        go: [1.19]
+        go: [1.18]
         os: [ubuntu-latest, macos-latest]
         upload-coverage: [true]
         include:
-          - go: 1.19
+          - go: 1.18
            os: windows-latest
            upload-coverage: false

@@ -32,24 +32,22 @@ jobs:
       - name: Build for Linux/macOS x86_64
         if: startsWith(matrix.os, 'windows-') != true
         run: |
-          go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo
+          go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo
           cd tests/eventsearcher
           go build -trimpath -ldflags "-s -w" -o eventsearcher
           cd -
           cd tests/ipfilter
           go build -trimpath -ldflags "-s -w" -o ipfilter
           cd -
-          ./sftpgo initprovider
-          ./sftpgo resetprovider --force
       - name: Build for macOS arm64
         if: startsWith(matrix.os, 'macos-') == true
-        run: CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 SDKROOT=$(xcrun --sdk macosx --show-sdk-path) go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo_arm64
+        run: CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 SDKROOT=$(xcrun --sdk macosx --show-sdk-path) go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo_arm64
       - name: Build for Windows
         if: startsWith(matrix.os, 'windows-')
         run: |
-          $GIT_COMMIT = (git describe --always --abbrev=8 --dirty) | Out-String
+          $GIT_COMMIT = (git describe --always --dirty) | Out-String
           $DATE_TIME = ([datetime]::Now.ToUniversalTime().toString("yyyy-MM-ddTHH:mm:ssZ")) | Out-String
           $LATEST_TAG = ((git describe --tags $(git rev-list --tags --max-count=1)) | Out-String).Trim()
           $REV_LIST=$LATEST_TAG+"..HEAD"

@@ -57,7 +55,7 @@ jobs:
           $FILE_VERSION = $LATEST_TAG.substring(1) + "." + $COMMITS_FROM_TAG
           go install github.com/tc-hib/go-winres@latest
           go-winres simply --arch amd64 --product-version $LATEST_TAG-dev-$GIT_COMMIT --file-version $FILE_VERSION --file-description "SFTPGo server" --product-name SFTPGo --copyright "AGPL-3.0" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico
-          go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o sftpgo.exe
+          go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/version.date=$DATE_TIME" -o sftpgo.exe
           cd tests/eventsearcher
           go build -trimpath -ldflags "-s -w" -o eventsearcher.exe
           cd ../..

@@ -69,17 +67,17 @@ jobs:
           $Env:GOOS='windows'
           $Env:GOARCH='arm64'
           go-winres simply --arch arm64 --product-version $LATEST_TAG-dev-$GIT_COMMIT --file-version $FILE_VERSION --file-description "SFTPGo server" --product-name SFTPGo --copyright "AGPL-3.0" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico
-          go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\arm64\sftpgo.exe
+          go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/version.date=$DATE_TIME" -o .\arm64\sftpgo.exe
           mkdir x86
           $Env:GOARCH='386'
           go-winres simply --arch 386 --product-version $LATEST_TAG-dev-$GIT_COMMIT --file-version $FILE_VERSION --file-description "SFTPGo server" --product-name SFTPGo --copyright "AGPL-3.0" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico
-          go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\x86\sftpgo.exe
+          go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/version.date=$DATE_TIME" -o .\x86\sftpgo.exe
           Remove-Item Env:\CGO_ENABLED
           Remove-Item Env:\GOOS
           Remove-Item Env:\GOARCH
       - name: Run test cases using SQLite provider
-        run: go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 15m ./... -coverprofile=coverage.txt -covermode=atomic
+        run: go test -v -p 1 -timeout 15m ./... -coverprofile=coverage.txt -covermode=atomic
       - name: Upload coverage to Codecov
         if: ${{ matrix.upload-coverage }}

@@ -90,21 +88,21 @@ jobs:
       - name: Run test cases using bolt provider
         run: |
-          go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 2m ./internal/config -covermode=atomic
-          go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 5m ./internal/common -covermode=atomic
-          go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 5m ./internal/httpd -covermode=atomic
-          go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 8m ./internal/sftpd -covermode=atomic
-          go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 5m ./internal/ftpd -covermode=atomic
-          go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 5m ./internal/webdavd -covermode=atomic
-          go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 2m ./internal/telemetry -covermode=atomic
-          go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 2m ./internal/mfa -covermode=atomic
-          go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 2m ./internal/command -covermode=atomic
+          go test -v -p 1 -timeout 2m ./config -covermode=atomic
+          go test -v -p 1 -timeout 5m ./common -covermode=atomic
+          go test -v -p 1 -timeout 5m ./httpd -covermode=atomic
+          go test -v -p 1 -timeout 8m ./sftpd -covermode=atomic
+          go test -v -p 1 -timeout 5m ./ftpd -covermode=atomic
+          go test -v -p 1 -timeout 5m ./webdavd -covermode=atomic
+          go test -v -p 1 -timeout 2m ./telemetry -covermode=atomic
+          go test -v -p 1 -timeout 2m ./mfa -covermode=atomic
+          go test -v -p 1 -timeout 2m ./command -covermode=atomic
         env:
           SFTPGO_DATA_PROVIDER__DRIVER: bolt
           SFTPGO_DATA_PROVIDER__NAME: 'sftpgo_bolt.db'
       - name: Run test cases using memory provider
-        run: go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 15m ./... -covermode=atomic
+        run: go test -v -p 1 -timeout 15m ./... -covermode=atomic
         env:
           SFTPGO_DATA_PROVIDER__DRIVER: memory
           SFTPGO_DATA_PROVIDER__NAME: ''

@@ -222,24 +220,6 @@ jobs:
           name: sftpgo-${{ matrix.os }}-go-${{ matrix.go }}
           path: output
-  test-bundle:
-    name: Build in bundle mode
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Go
-        uses: actions/setup-go@v3
-        with:
-          go-version: 1.19
-      - name: Build
-        run: |
-          cp -r openapi static templates internal/bundle/
-          go build -trimpath -tags nopgxregisterdefaulttypes,bundle -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo
-          ./sftpgo -v
   test-goarch-386:
     name: Run test cases on 32-bit arch
     runs-on: ubuntu-latest

@@ -250,7 +230,7 @@ jobs:
       - name: Set up Go
         uses: actions/setup-go@v3
         with:
-          go-version: 1.19
+          go-version: 1.18
       - name: Build
         run: |

@@ -264,7 +244,7 @@ jobs:
           GOARCH: 386
       - name: Run test cases
-        run: go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 15m ./... -covermode=atomic
+        run: go test -v -p 1 -timeout 15m ./... -covermode=atomic
         env:
           SFTPGO_DATA_PROVIDER__DRIVER: memory
           SFTPGO_DATA_PROVIDER__NAME: ''

@@ -324,11 +304,10 @@ jobs:
       - name: Set up Go
         uses: actions/setup-go@v3
         with:
-          go-version: 1.19
+          go-version: 1.18
       - name: Build
         run: |
-          go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo
           cd tests/eventsearcher
           go build -trimpath -ldflags "-s -w" -o eventsearcher
           cd -

@@ -338,9 +317,7 @@ jobs:
       - name: Run tests using PostgreSQL provider
         run: |
-          ./sftpgo initprovider
-          ./sftpgo resetprovider --force
-          go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 15m ./... -covermode=atomic
+          go test -v -p 1 -timeout 15m ./... -covermode=atomic
         env:
           SFTPGO_DATA_PROVIDER__DRIVER: postgresql
           SFTPGO_DATA_PROVIDER__NAME: sftpgo

@@ -351,9 +328,7 @@ jobs:
       - name: Run tests using MySQL provider
         run: |
-          ./sftpgo initprovider
-          ./sftpgo resetprovider --force
-          go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 15m ./... -covermode=atomic
+          go test -v -p 1 -timeout 15m ./... -covermode=atomic
         env:
           SFTPGO_DATA_PROVIDER__DRIVER: mysql
           SFTPGO_DATA_PROVIDER__NAME: sftpgo

@@ -364,9 +339,7 @@ jobs:
       - name: Run tests using MariaDB provider
         run: |
-          ./sftpgo initprovider
-          ./sftpgo resetprovider --force
-          go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 15m ./... -covermode=atomic
+          go test -v -p 1 -timeout 15m ./... -covermode=atomic
         env:
           SFTPGO_DATA_PROVIDER__DRIVER: mysql
           SFTPGO_DATA_PROVIDER__NAME: sftpgo

@@ -381,9 +354,7 @@ jobs:
           docker run --rm --name crdb --health-cmd "curl -I http://127.0.0.1:8080" --health-interval 10s --health-timeout 5s --health-retries 6 -p 26257:26257 -d cockroachdb/cockroach:latest start-single-node --insecure --listen-addr :26257
           sleep 10
           docker exec crdb cockroach sql --insecure -e 'create database "sftpgo"'
-          ./sftpgo initprovider
-          ./sftpgo resetprovider --force
-          go test -v -tags nopgxregisterdefaulttypes -p 1 -timeout 15m ./... -covermode=atomic
+          go test -v -p 1 -timeout 15m ./... -covermode=atomic
           docker stop crdb
         env:
           SFTPGO_DATA_PROVIDER__DRIVER: cockroachdb

@@ -396,13 +367,12 @@ jobs:
   build-linux-packages:
     name: Build Linux packages
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-18.04
     strategy:
       matrix:
         include:
           - arch: amd64
-            distro: ubuntu:18.04
-            go: latest
+            go: 1.18
             go-arch: amd64
           - arch: aarch64
             distro: ubuntu18.04

@@ -420,36 +390,16 @@ jobs:
       - uses: actions/checkout@v3
         with:
           fetch-depth: 0
-      - name: Get commit SHA
-        id: get_commit
-        run: echo "COMMIT=${GITHUB_SHA::8}" >> $GITHUB_OUTPUT
-        shell: bash
+      - name: Set up Go
+        if: ${{ matrix.arch == 'amd64' }}
+        uses: actions/setup-go@v3
+        with:
+          go-version: ${{ matrix.go }}
       - name: Build on amd64
         if: ${{ matrix.arch == 'amd64' }}
         run: |
-          echo '#!/bin/bash' > build.sh
-          echo '' >> build.sh
-          echo 'set -e' >> build.sh
-          echo 'apt-get update -q -y' >> build.sh
-          echo 'apt-get install -q -y curl gcc' >> build.sh
-          if [ ${{ matrix.go }} == 'latest' ]
-          then
-            echo 'GO_VERSION=$(curl -L https://go.dev/VERSION?m=text)' >> build.sh
-          else
-            echo 'GO_VERSION=${{ matrix.go }}' >> build.sh
-          fi
-          echo 'GO_DOWNLOAD_ARCH=${{ matrix.go-arch }}' >> build.sh
-          echo 'curl --retry 5 --retry-delay 2 --connect-timeout 10 -o go.tar.gz -L https://go.dev/dl/${GO_VERSION}.linux-${GO_DOWNLOAD_ARCH}.tar.gz' >> build.sh
-          echo 'tar -C /usr/local -xzf go.tar.gz' >> build.sh
-          echo 'export PATH=$PATH:/usr/local/go/bin' >> build.sh
-          echo 'go version' >> build.sh
-          echo 'cd /usr/local/src' >> build.sh
-          echo 'go build -buildvcs=false -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_commit.outputs.COMMIT }} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo' >> build.sh
-          chmod 755 build.sh
-          docker run --rm --name ubuntu-build --mount type=bind,source=`pwd`,target=/usr/local/src ${{ matrix.distro }} /usr/local/src/build.sh
+          go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo
           mkdir -p output/{init,bash_completion,zsh_completion}
           cp sftpgo.json output/
           cp -r templates output/

@@ -476,7 +426,7 @@ jobs:
         shell: /bin/bash
         install: |
           apt-get update -q -y
-          apt-get install -q -y curl gcc
+          apt-get install -q -y curl gcc git
           if [ ${{ matrix.go }} == 'latest' ]
           then
             GO_VERSION=$(curl -L https://go.dev/VERSION?m=text)

@@ -492,12 +442,11 @@ jobs:
           tar -C /usr/local -xzf go.tar.gz
         run: |
           export PATH=$PATH:/usr/local/go/bin
-          go version
           if [ ${{ matrix.arch}} == 'armv7' ]
           then
             export GOARM=7
           fi
-          go build -buildvcs=false -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_commit.outputs.COMMIT }} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo
+          go build -buildvcs=false -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo
           mkdir -p output/{init,bash_completion,zsh_completion}
           cp sftpgo.json output/
           cp -r templates output/

@@ -523,7 +472,7 @@ jobs:
           cd pkgs
           ./build.sh
           PKG_VERSION=$(cat dist/version)
-          echo "pkg-version=${PKG_VERSION}" >> $GITHUB_OUTPUT
+          echo "::set-output name=pkg-version::${PKG_VERSION}"
       - name: Upload Debian Package
         uses: actions/upload-artifact@v3

@@ -544,7 +493,7 @@ jobs:
       - name: Set up Go
         uses: actions/setup-go@v3
         with:
-          go-version: 1.19
+          go-version: 1.18
       - uses: actions/checkout@v3
       - name: Run golangci-lint
         uses: golangci/golangci-lint-action@v3
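Nearly every build line in the workflow above stamps the binary through -ldflags "-X ...". A self-contained sketch of the mechanism, using package main instead of SFTPGo's version package:

// Build with, for example:
//   go build -trimpath -ldflags "-s -w -X main.commit=`git describe --always --dirty` -X main.date=`date -u +%FT%TZ`" -o versiondemo
package main

import "fmt"

// The Go linker overwrites these package-level string variables at link
// time; SFTPGo's equivalents live in its version package (the diff above
// reflects the internal/version vs. v2/version import-path difference
// between the branches).
var (
	commit = "unknown"
	date   = "unknown"
)

func main() {
	fmt.Printf("commit: %s, built at: %s\n", commit, date)
}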

File: Docker publish workflow (.github/workflows)

@@ -5,7 +5,7 @@ on:
   # - cron: '0 4 * * *' # everyday at 4:00 AM UTC
   push:
     branches:
-      - main
+      - 2.3.x
     tags:
       - v*
   pull_request:

@@ -28,9 +28,6 @@ jobs:
           - os: ubuntu-latest
             docker_pkg: distroless
             optional_deps: false
-          - os: ubuntu-latest
-            docker_pkg: debian-plugins
-            optional_deps: true
     steps:
       - name: Checkout
         uses: actions/checkout@v3

@@ -67,9 +64,6 @@ jobs:
             VERSION="${VERSION}-distroless"
             VERSION_SLIM="${VERSION}-slim"
             DOCKERFILE=Dockerfile.distroless
-          elif [[ $DOCKER_PKG == debian-plugins ]]; then
-            VERSION="${VERSION}-plugins"
-            VERSION_SLIM="${VERSION}-slim"
           fi
           DOCKER_IMAGES=("drakkan/sftpgo" "ghcr.io/drakkan/sftpgo")
           TAGS="${DOCKER_IMAGES[0]}:${VERSION}"

@@ -95,13 +89,6 @@ jobs:
             fi
             TAGS="${TAGS},${DOCKER_IMAGE}:distroless"
             TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:distroless-slim"
-          elif [[ $DOCKER_PKG == debian-plugins ]]; then
-            if [[ -n $MAJOR && -n $MINOR ]]; then
-              TAGS="${TAGS},${DOCKER_IMAGE}:${MINOR}-plugins,${DOCKER_IMAGE}:${MAJOR}-plugins"
-              TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:${MINOR}-plugins-slim,${DOCKER_IMAGE}:${MAJOR}-plugins-slim"
-            fi
-            TAGS="${TAGS},${DOCKER_IMAGE}:plugins"
-            TAGS_SLIM="${TAGS_SLIM},${DOCKER_IMAGE}:plugins-slim"
           else
             if [[ -n $MAJOR && -n $MINOR ]]; then
               TAGS="${TAGS},${DOCKER_IMAGE}:${MINOR}-alpine,${DOCKER_IMAGE}:${MAJOR}-alpine"

@@ -114,22 +101,17 @@ jobs:
           done
           if [[ $OPTIONAL_DEPS == true ]]; then
-            echo "version=${VERSION}" >> $GITHUB_OUTPUT
-            echo "tags=${TAGS}" >> $GITHUB_OUTPUT
-            echo "full=true" >> $GITHUB_OUTPUT
+            echo ::set-output name=version::${VERSION}
+            echo ::set-output name=tags::${TAGS}
+            echo ::set-output name=full::true
           else
-            echo "version=${VERSION_SLIM}" >> $GITHUB_OUTPUT
-            echo "tags=${TAGS_SLIM}" >> $GITHUB_OUTPUT
-            echo "full=false" >> $GITHUB_OUTPUT
+            echo ::set-output name=version::${VERSION_SLIM}
+            echo ::set-output name=tags::${TAGS_SLIM}
+            echo ::set-output name=full::false
           fi
-          if [[ $DOCKER_PKG == debian-plugins ]]; then
-            echo "plugins=true" >> $GITHUB_OUTPUT
-          else
-            echo "plugins=false" >> $GITHUB_OUTPUT
-          fi
-          echo "dockerfile=${DOCKERFILE}" >> $GITHUB_OUTPUT
-          echo "created=$(date -u +'%Y-%m-%dT%H:%M:%SZ')" >> $GITHUB_OUTPUT
-          echo "sha=${GITHUB_SHA::8}" >> $GITHUB_OUTPUT
+          echo ::set-output name=dockerfile::${DOCKERFILE}
+          echo ::set-output name=created::$(date -u +'%Y-%m-%dT%H:%M:%SZ')
+          echo ::set-output name=sha::${GITHUB_SHA::8}
         env:
           DOCKER_PKG: ${{ matrix.docker_pkg }}
           OPTIONAL_DEPS: ${{ matrix.optional_deps }}

@@ -168,8 +150,6 @@ jobs:
           build-args: |
             COMMIT_SHA=${{ steps.info.outputs.sha }}
             INSTALL_OPTIONAL_PACKAGES=${{ steps.info.outputs.full }}
-            DOWNLOAD_PLUGINS=${{ steps.info.outputs.plugins }}
-            FEATURES=nopgxregisterdefaulttypes
           labels: |
             org.opencontainers.image.title=SFTPGo
             org.opencontainers.image.description=Fully featured and highly configurable SFTP server with optional HTTP, FTP/S and WebDAV support

@@ -179,4 +159,4 @@ jobs:
             org.opencontainers.image.version=${{ steps.info.outputs.version }}
             org.opencontainers.image.created=${{ steps.info.outputs.created }}
             org.opencontainers.image.revision=${{ github.sha }}
-            org.opencontainers.image.licenses=AGPL-3.0-only
+            org.opencontainers.image.licenses=AGPL-3.0

File: release workflow (.github/workflows)

@@ -5,7 +5,7 @@ on:
     tags: 'v*'
 env:
-  GO_VERSION: 1.19.2
+  GO_VERSION: 1.18.5
 jobs:
   prepare-sources-with-deps:

@@ -20,13 +20,12 @@ jobs:
       - name: Get SFTPGo version
         id: get_version
-        run: echo "VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT
+        run: echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//}
       - name: Prepare release
         run: |
           go mod vendor
           echo "${SFTPGO_VERSION}" > VERSION.txt
-          echo "${GITHUB_SHA::8}" >> VERSION.txt
           tar cJvf sftpgo_${SFTPGO_VERSION}_src_with_deps.tar.xz *
         env:
           SFTPGO_VERSION: ${{ steps.get_version.outputs.VERSION }}

@@ -54,7 +53,7 @@ jobs:
       - name: Get SFTPGo version
         id: get_version
-        run: echo "VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT
+        run: echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//}
         shell: bash
       - name: Get OS name

@@ -62,9 +61,9 @@
         run: |
           if [[ $MATRIX_OS =~ ^macos.* ]]
           then
-            echo "OS=macOS" >> $GITHUB_OUTPUT
+            echo ::set-output name=OS::macOS
           else
-            echo "OS=windows" >> $GITHUB_OUTPUT
+            echo ::set-output name=OS::windows
           fi
         shell: bash
         env:

@@ -72,31 +71,31 @@
       - name: Build for macOS x86_64
         if: startsWith(matrix.os, 'windows-') != true
-        run: go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo
+        run: go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo
       - name: Build for macOS arm64
         if: startsWith(matrix.os, 'macos-') == true
-        run: CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 SDKROOT=$(xcrun --sdk macosx --show-sdk-path) go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=`git describe --always --abbrev=8 --dirty` -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo_arm64
+        run: CGO_ENABLED=1 GOOS=darwin GOARCH=arm64 SDKROOT=$(xcrun --sdk macosx --show-sdk-path) go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo_arm64
       - name: Build for Windows
         if: startsWith(matrix.os, 'windows-')
         run: |
-          $GIT_COMMIT = (git describe --always --abbrev=8 --dirty) | Out-String
+          $GIT_COMMIT = (git describe --always --dirty) | Out-String
           $DATE_TIME = ([datetime]::Now.ToUniversalTime().toString("yyyy-MM-ddTHH:mm:ssZ")) | Out-String
           $FILE_VERSION = $Env:SFTPGO_VERSION.substring(1) + ".0"
           go install github.com/tc-hib/go-winres@latest
           go-winres simply --arch amd64 --product-version $Env:SFTPGO_VERSION-$GIT_COMMIT --file-version $FILE_VERSION --file-description "SFTPGo server" --product-name SFTPGo --copyright "AGPL-3.0" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico
-          go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o sftpgo.exe
+          go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/version.date=$DATE_TIME" -o sftpgo.exe
           mkdir arm64
           $Env:CGO_ENABLED='0'
           $Env:GOOS='windows'
           $Env:GOARCH='arm64'
           go-winres simply --arch arm64 --product-version $Env:SFTPGO_VERSION-$GIT_COMMIT --file-version $FILE_VERSION --file-description "SFTPGo server" --product-name SFTPGo --copyright "AGPL-3.0" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico
-          go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\arm64\sftpgo.exe
+          go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/version.date=$DATE_TIME" -o .\arm64\sftpgo.exe
           mkdir x86
           $Env:GOARCH='386'
           go-winres simply --arch 386 --product-version $Env:SFTPGO_VERSION-$GIT_COMMIT --file-version $FILE_VERSION --file-description "SFTPGo server" --product-name SFTPGo --copyright "AGPL-3.0" --original-filename sftpgo.exe --icon .\windows-installer\icon.ico
-          go build -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/internal/version.date=$DATE_TIME" -o .\x86\sftpgo.exe
+          go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=$GIT_COMMIT -X github.com/drakkan/sftpgo/v2/version.date=$DATE_TIME" -o .\x86\sftpgo.exe
           Remove-Item Env:\CGO_ENABLED
           Remove-Item Env:\GOOS
           Remove-Item Env:\GOARCH

@@ -255,12 +254,11 @@ jobs:
   prepare-linux:
     name: Prepare Linux binaries
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-18.04
     strategy:
       matrix:
         include:
           - arch: amd64
-            distro: ubuntu:18.04
             go-arch: amd64
             deb-arch: amd64
             rpm-arch: x86_64

@@ -286,13 +284,17 @@ jobs:
     steps:
       - uses: actions/checkout@v3
+      - name: Set up Go
+        if: ${{ matrix.arch == 'amd64' }}
+        uses: actions/setup-go@v3
+        with:
+          go-version: ${{ env.GO_VERSION }}
       - name: Get versions
         id: get_version
         run: |
-          echo "SFTPGO_VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT
-          echo "GO_VERSION=${GO_VERSION}" >> $GITHUB_OUTPUT
-          echo "COMMIT=${GITHUB_SHA::8}" >> $GITHUB_OUTPUT
+          echo ::set-output name=SFTPGO_VERSION::${GITHUB_REF/refs\/tags\//}
+          echo ::set-output name=GO_VERSION::${GO_VERSION}
         shell: bash
         env:
           GO_VERSION: ${{ env.GO_VERSION }}

@@ -300,20 +302,7 @@ jobs:
       - name: Build on amd64
         if: ${{ matrix.arch == 'amd64' }}
         run: |
-          echo '#!/bin/bash' > build.sh
-          echo '' >> build.sh
-          echo 'set -e' >> build.sh
-          echo 'apt-get update -q -y' >> build.sh
-          echo 'apt-get install -q -y curl gcc' >> build.sh
-          echo 'curl --retry 5 --retry-delay 2 --connect-timeout 10 -o go.tar.gz -L https://go.dev/dl/go${{ steps.get_version.outputs.GO_VERSION }}.linux-${{ matrix.go-arch }}.tar.gz' >> build.sh
-          echo 'tar -C /usr/local -xzf go.tar.gz' >> build.sh
-          echo 'export PATH=$PATH:/usr/local/go/bin' >> build.sh
-          echo 'go version' >> build.sh
-          echo 'cd /usr/local/src' >> build.sh
-          echo 'go build -buildvcs=false -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_version.outputs.COMMIT }} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo' >> build.sh
-          chmod 755 build.sh
-          docker run --rm --name ubuntu-build --mount type=bind,source=`pwd`,target=/usr/local/src ${{ matrix.distro }} /usr/local/src/build.sh
+          go build -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo
           mkdir -p output/{init,sqlite,bash_completion,zsh_completion}
           echo "For documentation please take a look here:" > output/README.txt
           echo "" >> output/README.txt

@@ -351,7 +340,7 @@ jobs:
         shell: /bin/bash
         install: |
           apt-get update -q -y
-          apt-get install -q -y curl gcc xz-utils
+          apt-get install -q -y curl gcc git xz-utils
           GO_DOWNLOAD_ARCH=${{ matrix.go-arch }}
           if [ ${{ matrix.arch}} == 'armv7' ]
           then

@@ -361,8 +350,7 @@ jobs:
           tar -C /usr/local -xzf go.tar.gz
         run: |
           export PATH=$PATH:/usr/local/go/bin
-          go version
-          go build -buildvcs=false -trimpath -tags nopgxregisterdefaulttypes -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${{ steps.get_version.outputs.COMMIT }} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -o sftpgo
+          go build -buildvcs=false -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -o sftpgo
           mkdir -p output/{init,sqlite,bash_completion,zsh_completion}
           echo "For documentation please take a look here:" > output/README.txt
           echo "" >> output/README.txt

@@ -398,7 +386,7 @@ jobs:
           cd pkgs
           ./build.sh
           PKG_VERSION=${SFTPGO_VERSION:1}
-          echo "pkg-version=${PKG_VERSION}" >> $GITHUB_OUTPUT
+          echo "::set-output name=pkg-version::${PKG_VERSION}"
         env:
           SFTPGO_VERSION: ${{ steps.get_version.outputs.SFTPGO_VERSION }}

@@ -425,7 +413,7 @@ jobs:
       - name: Get versions
         id: get_version
         run: |
-          echo "SFTPGO_VERSION=${GITHUB_REF/refs\/tags\//}" >> $GITHUB_OUTPUT
+          echo ::set-output name=SFTPGO_VERSION::${GITHUB_REF/refs\/tags\//}
         shell: bash
       - name: Download amd64 artifact

@@ -485,8 +473,8 @@ jobs:
         run: |
           SFTPGO_VERSION=${GITHUB_REF/refs\/tags\//}
           PKG_VERSION=${SFTPGO_VERSION:1}
-          echo "SFTPGO_VERSION=${SFTPGO_VERSION}" >> $GITHUB_OUTPUT
-          echo "PKG_VERSION=${PKG_VERSION}" >> $GITHUB_OUTPUT
+          echo ::set-output name=SFTPGO_VERSION::${SFTPGO_VERSION}
+          echo "::set-output name=PKG_VERSION::${PKG_VERSION}"
         shell: bash
       - name: Download amd64 artifact

File: golangci-lint configuration

@@ -1,5 +1,5 @@
 run:
-  timeout: 10m
+  timeout: 5m
   issues-exit-code: 1
   tests: true

File: Dockerfile (Debian)

@@ -1,4 +1,4 @@
-FROM golang:1.19-bullseye as builder
+FROM golang:1.18-bullseye as builder

 ENV GOFLAGS="-mod=readonly"

@@ -20,13 +20,8 @@ ARG FEATURES
 COPY . .

 RUN set -xe && \
-    export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --abbrev=8 --dirty)} && \
-    go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -v -o sftpgo
+    export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --dirty)} && \
+    go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -v -o sftpgo

-# Set to "true" to download the "official" plugins in /usr/local/bin
-ARG DOWNLOAD_PLUGINS=false
-RUN if [ "${DOWNLOAD_PLUGINS}" = "true" ]; then apt-get update && apt-get install --no-install-recommends -y curl && ./docker/scripts/download-plugins.sh; fi

 FROM debian:bullseye-slim

@@ -48,7 +43,7 @@ COPY --from=builder /workspace/sftpgo.json /etc/sftpgo/sftpgo.json
 COPY --from=builder /workspace/templates /usr/share/sftpgo/templates
 COPY --from=builder /workspace/static /usr/share/sftpgo/static
 COPY --from=builder /workspace/openapi /usr/share/sftpgo/openapi
-COPY --from=builder /workspace/sftpgo /usr/local/bin/sftpgo-plugin-* /usr/local/bin/
+COPY --from=builder /workspace/sftpgo /usr/local/bin/

 # Log to the stdout so the logs will be available using docker logs
 ENV SFTPGO_LOG_FILE_PATH=""

File: Dockerfile (Alpine)

@@ -1,4 +1,4 @@
-FROM golang:1.19-alpine3.16 AS builder
+FROM golang:1.18-alpine3.16 AS builder

 ENV GOFLAGS="-mod=readonly"

@@ -22,8 +22,8 @@ ARG FEATURES
 COPY . .

 RUN set -xe && \
-    export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --abbrev=8 --dirty)} && \
-    go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -v -o sftpgo
+    export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --dirty)} && \
+    go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -v -o sftpgo

 FROM alpine:3.16

File: Dockerfile (distroless)

@@ -1,4 +1,4 @@
-FROM golang:1.19-bullseye as builder
+FROM golang:1.18-bullseye as builder

 ENV CGO_ENABLED=0 GOFLAGS="-mod=readonly"

@@ -20,8 +20,8 @@ ARG FEATURES=nosqlite
 COPY . .

 RUN set -xe && \
-    export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --abbrev=8 --dirty)} && \
-    go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/internal/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/internal/version.date=`date -u +%FT%TZ`" -v -o sftpgo
+    export COMMIT_SHA=${COMMIT_SHA:-$(git describe --always --dirty)} && \
+    go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -trimpath -ldflags "-s -w -X github.com/drakkan/sftpgo/v2/version.commit=${COMMIT_SHA} -X github.com/drakkan/sftpgo/v2/version.date=`date -u +%FT%TZ`" -v -o sftpgo

 # Modify the default configuration file
 RUN sed -i 's|"users_base_dir": "",|"users_base_dir": "/srv/sftpgo/data",|' sftpgo.json && \

File: README.md

@@ -1,8 +1,8 @@
 # SFTPGo
-[![CI Status](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg?branch=main&event=push)](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg?branch=main&event=push)
+![CI Status](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg?branch=main&event=push)
 [![Code Coverage](https://codecov.io/gh/drakkan/sftpgo/branch/main/graph/badge.svg)](https://codecov.io/gh/drakkan/sftpgo/branch/main)
-[![License: AGPL-3.0-only](https://img.shields.io/badge/License-AGPLv3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0)
+[![License: AGPL v3](https://img.shields.io/badge/License-AGPLv3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0)
 [![Docker Pulls](https://img.shields.io/docker/pulls/drakkan/sftpgo)](https://hub.docker.com/r/drakkan/sftpgo)
 [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go)

@@ -26,24 +26,10 @@ If you just take and don't return anything back, the project will die in the lon
 More [info](https://github.com/drakkan/sftpgo/issues/452).

-### Thank you to our sponsors
+Thank you to our sponsors!

-#### Platinum sponsors
-
-[<img src="./img/Aledade_logo.png" alt="Aledade logo" width="202" height="70">](https://www.aledade.com/)
-
-#### Bronze sponsors
-
 [<img src="https://www.7digital.com/wp-content/themes/sevendigital/images/top_logo.png" alt="7digital logo">](https://www.7digital.com/)

-## Support policy
-
-SFTPGo is an Open Source project and you can of course use it for free but please don't ask for free support as well.
-
-We will check the reported issues to see if you are experiencing a bug and if so we'll will fix it, but will only provide support to project [sponsors/donors](#sponsors).
-
-If you report an invalid issue or ask for step-by-step support, your issue will remain open with no answer or will be closed as invalid without further explanation. Thanks for understanding.
-
 ## Features

 - Support for serving local filesystem, encrypted local filesystem, S3 Compatible Object Storage, Google Cloud Storage, Azure Blob Storage or other SFTP accounts over SFTP/SCP/FTP/WebDAV.

@@ -54,7 +40,6 @@ If you report an invalid issue or ask for step-by-step support, your issue will
 - Chroot isolation for local accounts. Cloud-based accounts can be restricted to a certain base path.
 - Per-user and per-directory virtual permissions, for each exposed path you can allow or deny: directory listing, upload, overwrite, download, delete, rename, create directories, create symlinks, change owner/group/file mode and modification time.
 - [REST API](./docs/rest-api.md) for users and folders management, data retention, backup, restore and real time reports of the active connections with possibility of forcibly closing a connection.
-- The [Event Manager](./docs/eventmanager.md) allows to define custom workflows based on server events or schedules.
 - [Web based administration interface](./docs/web-admin.md) to easily manage users, folders and connections.
 - [Web client interface](./docs/web-client.md) so that end users can change their credentials, manage and share their files in the browser.
 - Public key and password authentication. Multiple public keys per-user are supported.

@@ -64,10 +49,10 @@ If you report an invalid issue or ask for step-by-step support, your issue will
 - Per-user authentication methods.
 - [Two-factor authentication](./docs/howto/two-factor-authentication.md) based on time-based one time passwords (RFC 6238) which works with Authy, Google Authenticator and other compatible apps.
 - Simplified user administrations using [groups](./docs/groups.md).
-- Custom authentication via [external programs/HTTP API](./docs/external-auth.md).
+- Custom authentication via external programs/HTTP API.
 - Web Client and Web Admin user interfaces support [OpenID Connect](https://openid.net/connect/) authentication and so they can be integrated with identity providers such as [Keycloak](https://www.keycloak.org/). You can find more details [here](./docs/oidc.md).
 - [Data At Rest Encryption](./docs/dare.md).
-- Dynamic user modification before login via [external programs/HTTP API](./docs/dynamic-user-mod.md).
+- Dynamic user modification before login via external programs/HTTP API.
 - Quota support: accounts can have individual disk quota expressed as max total size and/or max number of files.
 - Bandwidth throttling, with separate settings for upload and download and overrides based on the client's IP address.
 - Data transfer bandwidth limits, with total limit or separate settings for uploads and downloads and overrides based on the client's IP address. Limits can be reset using the REST API.

@@ -104,10 +89,8 @@ SFTPGo is developed and tested on Linux. After each commit, the code is automati
 ## Requirements

 - Go as build only dependency. We support the Go version(s) used in [continuous integration workflows](./.github/workflows).
-- A suitable SQL server to use as data provider:
-  - upstream supported versions of PostgreSQL, MySQL and MariaDB.
-  - CockroachDB stable.
-- The SQL server is optional: you can choose to use an embedded SQLite, bolt or in memory data provider.
+- A suitable SQL server to use as data provider: PostgreSQL 9.4+, MySQL 5.6+, SQLite 3.x, CockroachDB stable.
+- The SQL server is optional: you can choose to use an embedded bolt database as key/value store or an in memory data provider.

 ## Installation

@@ -131,13 +114,7 @@ An official Docker image is available. Documentation is [here](./docker/README.m
 APT and YUM repositories are [available](./docs/repo.md).

-SFTPGo is also available on some marketplaces:
-
-- [AWS Marketplace](https://aws.amazon.com/marketplace/seller-profile?id=6e849ab8-70a6-47de-9a43-13c3fa849335)
-- [Azure Marketplace](https://azuremarketplace.microsoft.com/en-us/marketplace/apps/prasselsrl1645470739547.sftpgo_linux)
-- [Elest.io](https://elest.io/open-source/sftpgo)
-
-Purchasing from there will help keep SFTPGo a long-term sustainable project.
+SFTPGo is also available on [AWS Marketplace](https://aws.amazon.com/marketplace/seller-profile?id=6e849ab8-70a6-47de-9a43-13c3fa849335) and [Azure Marketplace](https://azuremarketplace.microsoft.com/en-us/marketplace/apps/prasselsrl1645470739547.sftpgo_linux), purchasing from there will help keep SFTPGo a long-term sustainable project.

 <details><summary>Windows packages</summary>

@@ -248,18 +225,16 @@ The `revertprovider` command is not supported for the memory provider.
 Please note that we only support the current release branch and the current main branch, if you find a bug it is better to report it rather than downgrading to an older unsupported version.

-## Users, groups and folders management
+## Users and folders management

-After starting SFTPGo you can manage users, groups, folders and other resources using:
+After starting SFTPGo you can manage users and folders using:

 - the [web based administration interface](./docs/web-admin.md)
 - the [REST API](./docs/rest-api.md)

-To support embedded data providers like `bolt` and `SQLite`, which do not support concurrent connections, we can't have a CLI that directly write users and other resources to the data provider, we always have to use the REST API.
+To support embedded data providers like `bolt` and `SQLite` we can't have a CLI that directly write users and folders to the data provider, we always have to use the REST API.

-Full details for users, groups, folders, admins and other resources are documented in the [OpenAPI](./openapi/openapi.yaml) schema. If you want to render the schema without importing it manually, you can explore it on [Stoplight](https://sftpgo.stoplight.io/docs/sftpgo/openapi.yaml).
+Full details for users, folders, admins and other resources are documented in the [OpenAPI](./openapi/openapi.yaml) schema. If you want to render the schema without importing it manually, you can explore it on [Stoplight](https://sftpgo.stoplight.io/docs/sftpgo/openapi.yaml).

-:warning: SFTPGo users, groups and folders are virtual and therefore unrelated to the system ones. There is no need to create system-wide users and groups.
-
 ## Tutorials

@@ -315,10 +290,6 @@ Each user can be mapped to another SFTP server account or a subfolder of it. Mor
 Data at-rest encryption is supported via the [cryptfs backend](./docs/dare.md).

-### HTTP/S backend
-
-HTTP/S backend allows you to write your own custom storage backend by implementing a REST API. More information can be found [here](./docs/httpfs.md).
-
 ### Other Storage backends

 Adding new storage backends is quite easy:

@@ -360,4 +331,4 @@ Thank you [ysura](https://www.ysura.com/) for granting me stable access to a tes
 ## License

-GNU AGPL-3.0-only
+GNU AGPLv3

File: README, Chinese translation

@@ -1,8 +1,8 @@
 # SFTPGo
-[![CI Status](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg?branch=main&event=push)](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg?branch=main&event=push)
+![CI Status](https://github.com/drakkan/sftpgo/workflows/CI/badge.svg?branch=main&event=push)
 [![Code Coverage](https://codecov.io/gh/drakkan/sftpgo/branch/main/graph/badge.svg)](https://codecov.io/gh/drakkan/sftpgo/branch/main)
-[![License: AGPL-3.0-only](https://img.shields.io/badge/License-AGPLv3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0)
+[![License: AGPL v3](https://img.shields.io/badge/License-AGPLv3-blue.svg)](https://www.gnu.org/licenses/agpl-3.0)
 [![Docker Pulls](https://img.shields.io/docker/pulls/drakkan/sftpgo)](https://hub.docker.com/r/drakkan/sftpgo)
 [![Mentioned in Awesome Go](https://awesome.re/mentioned-badge.svg)](https://github.com/avelino/awesome-go)

@@ -11,39 +11,6 @@
 A fully featured, highly configurable SFTP server with optional HTTP, FTP/S and WebDAV support.
 Supported storage backends: local filesystem, encrypted local filesystem, S3-compatible object storage, Google Cloud Storage, Azure Blob Storage, SFTP.

-## Sponsors
-
-If you find SFTPGo useful, please consider supporting this Open Source project.
-
-Maintaining and evolving SFTPGo is a lot of work, easily the equivalent of a full-time job.
-
-I would like to make SFTPGo a sustainable long-term project and do not want to introduce a dual-licensing option that limits some features to a proprietary version only.
-
-If you use SFTPGo, it is in your best interest to ensure that the project you rely on stays healthy and well maintained.
-
-This can only happen through your donations and [sponsorships](https://github.com/sponsors/drakkan) :heart:
-
-If you just take without giving anything back, the project will fail in the long run and you will be forced to pay for a similar proprietary solution.
-
-[More info](https://github.com/drakkan/sftpgo/issues/452).
-
-### Thank you to our sponsors
-
-#### Platinum sponsors
-
-[<img src="./img/Aledade_logo.png" alt="Aledade logo" width="202" height="70">](https://www.aledade.com/)
-
-#### Bronze sponsors
-
-[<img src="https://www.7digital.com/wp-content/themes/sevendigital/images/top_logo.png" alt="7digital logo">](https://www.7digital.com/)
-
-## Support policy
-
-SFTPGo is an Open Source project and you can of course use it for free, but please do not ask for free support as well.
-
-We will check reported issues to see whether you are hitting a bug and, if so, we will fix it, but we will only provide support to project sponsors/donors.
-
-If you report an invalid issue or ask for step-by-step support, your issue will remain open without an answer, or will be closed as invalid without further explanation. Thanks for understanding.
-
 ## Features

 - Support for serving local filesystem, encrypted local filesystem, S3-compatible object storage, Google Cloud Storage, Azure Blob Storage or other SFTP accounts over SFTP/SCP/FTP/WebDAV.

@@ -148,7 +115,7 @@ SFTPGo is available on [AWS Marketplace](https://aws.amazon.com/marketplace/seller-profile?i
 Full configuration details are documented in [configuration](./docs/full-configuration.md).

-Make sure to [initialize the data provider](#数据提供程序初始化和管理) before running, as needed.
+Make sure to [initialize the data provider](#data-provider-initialization-and-management) before running, as needed.

 To start SFTPGo with the default configuration, run:

@@ -348,4 +315,4 @@ SFTPGo uses the third-party libraries listed in [go.mod](./go.mod).
 ## License

-GNU AGPL-3.0-only
+GNU AGPLv3


@@ -45,14 +45,13 @@ import (
 	"github.com/go-acme/lego/v4/registration"
 	"github.com/robfig/cron/v3"
-	"github.com/drakkan/sftpgo/v2/internal/common"
-	"github.com/drakkan/sftpgo/v2/internal/ftpd"
-	"github.com/drakkan/sftpgo/v2/internal/httpd"
-	"github.com/drakkan/sftpgo/v2/internal/logger"
-	"github.com/drakkan/sftpgo/v2/internal/telemetry"
-	"github.com/drakkan/sftpgo/v2/internal/util"
-	"github.com/drakkan/sftpgo/v2/internal/version"
-	"github.com/drakkan/sftpgo/v2/internal/webdavd"
+	"github.com/drakkan/sftpgo/v2/ftpd"
+	"github.com/drakkan/sftpgo/v2/httpd"
+	"github.com/drakkan/sftpgo/v2/logger"
+	"github.com/drakkan/sftpgo/v2/telemetry"
+	"github.com/drakkan/sftpgo/v2/util"
+	"github.com/drakkan/sftpgo/v2/version"
+	"github.com/drakkan/sftpgo/v2/webdavd"
 )
 const (
@@ -562,24 +561,6 @@ func (c *Configuration) getCertificates() error {
 	return nil
 }
-func (c *Configuration) notifyCertificateRenewal(domain string, err error) {
-	if domain == "" {
-		domain = strings.Join(c.Domains, ",")
-	}
-	params := common.EventParams{
-		Name:      domain,
-		Event:     "Certificate renewal",
-		Timestamp: time.Now().UnixNano(),
-	}
-	if err != nil {
-		params.Status = 2
-		params.AddError(err)
-	} else {
-		params.Status = 1
-	}
-	common.HandleCertificateEvent(params)
-}
 func (c *Configuration) renewCertificates() error {
 	lockTime, err := c.getLockTime()
 	if err != nil {
@@ -592,28 +573,22 @@ func (c *Configuration) renewCertificates() error {
 	}
 	err = c.setLockTime()
 	if err != nil {
-		c.notifyCertificateRenewal("", err)
 		return err
 	}
 	account, client, err := c.setup()
 	if err != nil {
-		c.notifyCertificateRenewal("", err)
 		return err
 	}
 	if account.Registration == nil {
 		acmeLog(logger.LevelError, "cannot renew certificates, your account is not registered")
-		err = errors.New("cannot renew certificates, your account is not registered")
-		c.notifyCertificateRenewal("", err)
-		return err
+		return fmt.Errorf("cannot renew certificates, your account is not registered")
 	}
 	var errRenew error
 	needReload := false
 	for _, domain := range c.Domains {
 		certificates, err := c.loadCertificatesForDomain(domain)
 		if err != nil {
-			c.notifyCertificateRenewal(domain, err)
-			errRenew = err
-			continue
+			return err
 		}
 		cert := certificates[0]
 		if !c.needRenewal(cert, domain) {
@@ -621,10 +596,8 @@ func (c *Configuration) renewCertificates() error {
 		}
 		err = c.obtainAndSaveCertificate(client, domain)
 		if err != nil {
-			c.notifyCertificateRenewal(domain, err)
 			errRenew = err
 		} else {
-			c.notifyCertificateRenewal(domain, nil)
 			needReload = true
 		}
 	}


@@ -20,10 +20,10 @@ import (
 	"github.com/rs/zerolog"
 	"github.com/spf13/cobra"
-	"github.com/drakkan/sftpgo/v2/internal/acme"
-	"github.com/drakkan/sftpgo/v2/internal/config"
-	"github.com/drakkan/sftpgo/v2/internal/logger"
-	"github.com/drakkan/sftpgo/v2/internal/util"
+	"github.com/drakkan/sftpgo/v2/acme"
+	"github.com/drakkan/sftpgo/v2/config"
+	"github.com/drakkan/sftpgo/v2/logger"
+	"github.com/drakkan/sftpgo/v2/util"
 )
 var (
@@ -40,13 +40,13 @@ Certificates are saved in the configured "certs_path".
 After this initial step, the certificates are automatically checked and
 renewed by the SFTPGo service
 `,
-		Run: func(_ *cobra.Command, _ []string) {
+		Run: func(cmd *cobra.Command, args []string) {
 			logger.DisableLogger()
 			logger.EnableConsoleLogger(zerolog.DebugLevel)
 			configDir = util.CleanDirInput(configDir)
 			err := config.LoadConfig(configDir, configFile)
 			if err != nil {
-				logger.ErrorToConsole("Unable to initialize ACME, config load error: %v", err)
+				logger.ErrorToConsole("Unable to initialize data provider, config load error: %v", err)
 				return
 			}
 			acmeConfig := config.GetACMEConfig()


@@ -21,4 +21,4 @@ import (
 	"github.com/spf13/cobra"
 )
-func addAWSContainerFlags(_ *cobra.Command) {}
+func addAWSContainerFlags(cmd *cobra.Command) {}


@@ -24,8 +24,8 @@ import (
 	"github.com/spf13/cobra"
 	"github.com/spf13/cobra/doc"
-	"github.com/drakkan/sftpgo/v2/internal/logger"
-	"github.com/drakkan/sftpgo/v2/internal/version"
+	"github.com/drakkan/sftpgo/v2/logger"
+	"github.com/drakkan/sftpgo/v2/version"
 )
 var (
@@ -38,7 +38,7 @@ command-line interface.
 By default, it creates the man page files in the "man" directory under the
 current directory.
 `,
-		Run: func(cmd *cobra.Command, _ []string) {
+		Run: func(cmd *cobra.Command, args []string) {
 			logger.DisableLogger()
 			logger.EnableConsoleLogger(zerolog.DebugLevel)
 			if _, err := os.Stat(manDir); errors.Is(err, fs.ErrNotExist) {


@@ -21,11 +21,11 @@ import (
 	"github.com/spf13/cobra"
 	"github.com/spf13/viper"
-	"github.com/drakkan/sftpgo/v2/internal/config"
-	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
-	"github.com/drakkan/sftpgo/v2/internal/logger"
-	"github.com/drakkan/sftpgo/v2/internal/service"
-	"github.com/drakkan/sftpgo/v2/internal/util"
+	"github.com/drakkan/sftpgo/v2/config"
+	"github.com/drakkan/sftpgo/v2/dataprovider"
+	"github.com/drakkan/sftpgo/v2/logger"
+	"github.com/drakkan/sftpgo/v2/service"
+	"github.com/drakkan/sftpgo/v2/util"
 )
 var (
@@ -50,7 +50,7 @@ $ sftpgo initprovider
 Any defined action is ignored.
 Please take a look at the usage below to customize the options.`,
-		Run: func(_ *cobra.Command, _ []string) {
+		Run: func(cmd *cobra.Command, args []string) {
 			logger.DisableLogger()
 			logger.EnableConsoleLogger(zerolog.DebugLevel)
 			configDir = util.CleanDirInput(configDir)


@@ -21,8 +21,8 @@ import (
 	"github.com/spf13/cobra"
-	"github.com/drakkan/sftpgo/v2/internal/service"
-	"github.com/drakkan/sftpgo/v2/internal/util"
+	"github.com/drakkan/sftpgo/v2/service"
+	"github.com/drakkan/sftpgo/v2/util"
 )
 var (
@@ -35,7 +35,7 @@ line flags simply use:
 sftpgo service install
 Please take a look at the usage below to customize the startup options`,
-		Run: func(_ *cobra.Command, _ []string) {
+		Run: func(cmd *cobra.Command, args []string) {
 			s := service.Service{
 				ConfigDir:  util.CleanDirInput(configDir),
 				ConfigFile: configFile,
@@ -44,7 +44,7 @@ Please take a look at the usage below to customize the startup options`,
 				LogMaxBackups: logMaxBackups,
 				LogMaxAge:     logMaxAge,
 				LogCompress:   logCompress,
-				LogLevel:      logLevel,
+				LogVerbose:    logVerbose,
 				LogUTCTime:    logUTCTime,
 				Shutdown:      make(chan bool),
 			}
@@ -99,9 +99,8 @@ func getCustomServeFlags() []string {
 		result = append(result, "--"+logMaxAgeFlag)
 		result = append(result, strconv.Itoa(logMaxAge))
 	}
-	if logLevel != defaultLogLevel {
-		result = append(result, "--"+logLevelFlag)
-		result = append(result, logLevel)
+	if logVerbose != defaultLogVerbose {
+		result = append(result, "--"+logVerboseFlag+"=false")
 	}
 	if logUTCTime != defaultLogUTCTime {
 		result = append(result, "--"+logUTCTimeFlag+"=true")
@@ -109,9 +108,5 @@ func getCustomServeFlags() []string {
 	if logCompress != defaultLogCompress {
 		result = append(result, "--"+logCompressFlag+"=true")
 	}
-	if graceTime != defaultGraceTime {
-		result = append(result, "--"+graceTimeFlag)
-		result = append(result, strconv.Itoa(graceTime))
-	}
 	return result
 }


@@ -27,13 +27,13 @@ import (
 	"github.com/sftpgo/sdk"
 	"github.com/spf13/cobra"
-	"github.com/drakkan/sftpgo/v2/internal/common"
-	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
-	"github.com/drakkan/sftpgo/v2/internal/kms"
-	"github.com/drakkan/sftpgo/v2/internal/service"
-	"github.com/drakkan/sftpgo/v2/internal/sftpd"
-	"github.com/drakkan/sftpgo/v2/internal/version"
-	"github.com/drakkan/sftpgo/v2/internal/vfs"
+	"github.com/drakkan/sftpgo/v2/common"
+	"github.com/drakkan/sftpgo/v2/dataprovider"
+	"github.com/drakkan/sftpgo/v2/kms"
+	"github.com/drakkan/sftpgo/v2/service"
+	"github.com/drakkan/sftpgo/v2/sftpd"
+	"github.com/drakkan/sftpgo/v2/version"
+	"github.com/drakkan/sftpgo/v2/vfs"
 )
 var (
@@ -45,7 +45,7 @@ var (
 	portablePassword    string
 	portableStartDir    string
 	portableLogFile     string
-	portableLogLevel    string
+	portableLogVerbose  bool
 	portableLogUTCTime  bool
 	portablePublicKeys  []string
 	portablePermissions []string
@@ -106,7 +106,7 @@ use:
 $ sftpgo portable
 Please take a look at the usage below to customize the serving parameters`,
-		Run: func(_ *cobra.Command, _ []string) {
+		Run: func(cmd *cobra.Command, args []string) {
 			portableDir := directoryToServe
 			fsProvider := sdk.GetProviderByName(portableFsProvider)
 			if !filepath.IsAbs(portableDir) {
@@ -169,7 +169,6 @@ Please take a look at the usage below to customize the serving parameters`,
 					os.Exit(1)
 				}
 			}
-			service.SetGraceTime(graceTime)
 			service := service.Service{
 				ConfigDir:  filepath.Clean(defaultConfigDir),
 				ConfigFile: defaultConfigFile,
@@ -178,7 +177,7 @@ Please take a look at the usage below to customize the serving parameters`,
 				LogMaxBackups: defaultLogMaxBackup,
 				LogMaxAge:     defaultLogMaxAge,
 				LogCompress:   defaultLogCompress,
-				LogLevel:      portableLogLevel,
+				LogVerbose:    portableLogVerbose,
 				LogUTCTime:    portableLogUTCTime,
 				Shutdown:      make(chan bool),
 				PortableMode:  1,
@@ -258,10 +257,8 @@ Please take a look at the usage below to customize the serving parameters`,
 				},
 			},
 		}
-			err := service.StartPortableMode(portableSFTPDPort, portableFTPDPort, portableWebDAVPort, portableSSHCommands,
-				portableAdvertiseService, portableAdvertiseCredentials, portableFTPSCert, portableFTPSKey, portableWebDAVCert,
-				portableWebDAVKey)
-			if err == nil {
+			if err := service.StartPortableMode(portableSFTPDPort, portableFTPDPort, portableWebDAVPort, portableSSHCommands, portableAdvertiseService,
+				portableAdvertiseCredentials, portableFTPSCert, portableFTPSKey, portableWebDAVCert, portableWebDAVKey); err == nil {
 				service.Wait()
 				if service.Error == nil {
 					os.Exit(0)
@@ -298,11 +295,7 @@ value`)
 	portableCmd.Flags().StringVarP(&portablePassword, "password", "p", "", `Leave empty to use an auto generated
 value`)
 	portableCmd.Flags().StringVarP(&portableLogFile, logFilePathFlag, "l", "", "Leave empty to disable logging")
-	portableCmd.Flags().StringVar(&portableLogLevel, logLevelFlag, defaultLogLevel, `Set the log level.
-Supported values:
-debug, info, warn, error.
-`)
+	portableCmd.Flags().BoolVarP(&portableLogVerbose, logVerboseFlag, "v", false, "Enable verbose logs")
 	portableCmd.Flags().BoolVar(&portableLogUTCTime, logUTCTimeFlag, false, "Use UTC time for logging")
 	portableCmd.Flags().StringSliceVarP(&portablePublicKeys, "public-key", "k", []string{}, "")
 	portableCmd.Flags().StringSliceVarP(&portablePermissions, "permissions", "g", []string{"list", "download"},
@@ -406,13 +399,6 @@ multiple concurrent requests and this
 allows data to be transferred at a
 faster rate, over high latency networks,
 by overlapping round-trip times`)
-	portableCmd.Flags().IntVar(&graceTime, graceTimeFlag, 0,
-		`This grace time defines the number of
-seconds allowed for existing transfers
-to get completed before shutting down.
-A graceful shutdown is triggered by an
-interrupt signal.
-`)
 	rootCmd.AddCommand(portableCmd)
 }


@@ -17,7 +17,7 @@
 package cmd
-import "github.com/drakkan/sftpgo/v2/internal/version"
+import "github.com/drakkan/sftpgo/v2/version"
 func init() {
 	version.AddFeature("-portable")


@@ -20,14 +20,14 @@ import (
 	"github.com/spf13/cobra"
-	"github.com/drakkan/sftpgo/v2/internal/service"
+	"github.com/drakkan/sftpgo/v2/service"
 )
 var (
 	reloadCmd = &cobra.Command{
 		Use:   "reload",
 		Short: "Reload the SFTPGo Windows Service sending a \"paramchange\" request",
-		Run: func(_ *cobra.Command, _ []string) {
+		Run: func(cmd *cobra.Command, args []string) {
 			s := service.WindowsService{
 				Service: service.Service{
 					Shutdown: make(chan bool),


@@ -23,10 +23,10 @@ import (
 	"github.com/spf13/cobra"
 	"github.com/spf13/viper"
-	"github.com/drakkan/sftpgo/v2/internal/config"
-	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
-	"github.com/drakkan/sftpgo/v2/internal/logger"
-	"github.com/drakkan/sftpgo/v2/internal/util"
+	"github.com/drakkan/sftpgo/v2/config"
+	"github.com/drakkan/sftpgo/v2/dataprovider"
+	"github.com/drakkan/sftpgo/v2/logger"
+	"github.com/drakkan/sftpgo/v2/util"
 )
 var (
@@ -39,7 +39,7 @@ configuration file and resets the provider by deleting all data and schemas.
 This command is not supported for the memory provider.
 Please take a look at the usage below to customize the options.`,
-		Run: func(_ *cobra.Command, _ []string) {
+		Run: func(cmd *cobra.Command, args []string) {
 			logger.DisableLogger()
 			logger.EnableConsoleLogger(zerolog.DebugLevel)
 			configDir = util.CleanDirInput(configDir)


@@ -21,10 +21,10 @@ import (
 	"github.com/spf13/cobra"
 	"github.com/spf13/viper"
-	"github.com/drakkan/sftpgo/v2/internal/config"
-	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
-	"github.com/drakkan/sftpgo/v2/internal/logger"
-	"github.com/drakkan/sftpgo/v2/internal/util"
+	"github.com/drakkan/sftpgo/v2/config"
+	"github.com/drakkan/sftpgo/v2/dataprovider"
+	"github.com/drakkan/sftpgo/v2/logger"
+	"github.com/drakkan/sftpgo/v2/util"
 )
 var (
@@ -37,11 +37,11 @@ configuration file and restore the provider schema and/or data to a previous ver
 This command is not supported for the memory provider.
 Please take a look at the usage below to customize the options.`,
-		Run: func(_ *cobra.Command, _ []string) {
+		Run: func(cmd *cobra.Command, args []string) {
 			logger.DisableLogger()
 			logger.EnableConsoleLogger(zerolog.DebugLevel)
-			if revertProviderTargetVersion != 19 {
-				logger.WarnToConsole("Unsupported target version, 19 is the only supported one")
+			if revertProviderTargetVersion != 15 {
+				logger.WarnToConsole("Unsupported target version, 15 is the only supported one")
 				os.Exit(1)
 			}
 			configDir = util.CleanDirInput(configDir)
@@ -71,7 +71,7 @@ Please take a look at the usage below to customize the options.`,
 func init() {
 	addConfigFlags(revertProviderCmd)
-	revertProviderCmd.Flags().IntVar(&revertProviderTargetVersion, "to-version", 19, `19 means the version supported in v2.3.x`)
+	revertProviderCmd.Flags().IntVar(&revertProviderTargetVersion, "to-version", 15, `15 means the version supported in v2.2.x`)
 	rootCmd.AddCommand(revertProviderCmd)
 }


@@ -22,7 +22,7 @@ import (
 	"github.com/spf13/cobra"
 	"github.com/spf13/viper"
-	"github.com/drakkan/sftpgo/v2/internal/version"
+	"github.com/drakkan/sftpgo/v2/version"
 )
 const (
@@ -40,8 +40,8 @@ const (
 	logMaxAgeKey    = "log_max_age"
 	logCompressFlag = "log-compress"
 	logCompressKey  = "log_compress"
-	logLevelFlag    = "log-level"
-	logLevelKey     = "log_level"
+	logVerboseFlag  = "log-verbose"
+	logVerboseKey   = "log_verbose"
 	logUTCTimeFlag  = "log-utc-time"
 	logUTCTimeKey   = "log_utc_time"
 	loadDataFromFlag = "loaddata-from"
@@ -52,8 +52,6 @@ const (
 	loadDataQuotaScanKey = "loaddata_scan"
 	loadDataCleanFlag    = "loaddata-clean"
 	loadDataCleanKey     = "loaddata_clean"
-	graceTimeFlag        = "grace-time"
-	graceTimeKey         = "grace_time"
 	defaultConfigDir     = "."
 	defaultConfigFile    = ""
 	defaultLogFile       = "sftpgo.log"
@@ -61,13 +59,12 @@ const (
 	defaultLogMaxBackup      = 5
 	defaultLogMaxAge         = 28
 	defaultLogCompress       = false
-	defaultLogLevel          = "debug"
+	defaultLogVerbose        = true
 	defaultLogUTCTime        = false
 	defaultLoadDataFrom      = ""
 	defaultLoadDataMode      = 1
 	defaultLoadDataQuotaScan = 0
 	defaultLoadDataClean     = false
-	defaultGraceTime         = 0
 )
 var (
@@ -78,13 +75,12 @@ var (
 	logMaxBackups     int
 	logMaxAge         int
 	logCompress       bool
-	logLevel          string
+	logVerbose        bool
 	logUTCTime        bool
 	loadDataFrom      string
 	loadDataMode      int
 	loadDataQuotaScan int
 	loadDataClean     bool
-	graceTime         int
 	// used if awscontainer build tag is enabled
 	disableAWSInstallationCode bool
@@ -233,17 +229,13 @@ It is unused if log-file-path is empty.
 `)
 	viper.BindPFlag(logCompressKey, cmd.Flags().Lookup(logCompressFlag)) //nolint:errcheck
-	viper.SetDefault(logLevelKey, defaultLogLevel)
-	viper.BindEnv(logLevelKey, "SFTPGO_LOG_LEVEL") //nolint:errcheck
-	cmd.Flags().StringVar(&logLevel, logLevelFlag, viper.GetString(logLevelKey),
-		`Set the log level. Supported values:
-debug, info, warn, error.
-This flag can be set
-using SFTPGO_LOG_LEVEL env var too.
-`)
-	viper.BindPFlag(logLevelKey, cmd.Flags().Lookup(logLevelFlag)) //nolint:errcheck
+	viper.SetDefault(logVerboseKey, defaultLogVerbose)
+	viper.BindEnv(logVerboseKey, "SFTPGO_LOG_VERBOSE") //nolint:errcheck
+	cmd.Flags().BoolVarP(&logVerbose, logVerboseFlag, "v", viper.GetBool(logVerboseKey),
+		`Enable verbose logs. This flag can be set
+using SFTPGO_LOG_VERBOSE env var too.
+`)
+	viper.BindPFlag(logVerboseKey, cmd.Flags().Lookup(logVerboseFlag)) //nolint:errcheck
 	viper.SetDefault(logUTCTimeKey, defaultLogUTCTime)
 	viper.BindEnv(logUTCTimeKey, "SFTPGO_LOG_UTC_TIME") //nolint:errcheck
@@ -266,20 +258,4 @@ This flag can be set using SFTPGO_LOADDATA_QUOTA_SCAN
 env var too.
 (default 0)`)
 	viper.BindPFlag(loadDataQuotaScanKey, cmd.Flags().Lookup(loadDataQuotaScanFlag)) //nolint:errcheck
-	viper.SetDefault(graceTimeKey, defaultGraceTime)
-	viper.BindEnv(graceTimeKey, "SFTPGO_GRACE_TIME") //nolint:errcheck
-	cmd.Flags().IntVar(&graceTime, graceTimeFlag, viper.GetInt(graceTimeKey),
-		`Graceful shutdown is an option to initiate a
-shutdown without abrupt cancellation of the
-currently ongoing client-initiated transfer
-sessions.
-This grace time defines the number of seconds
-allowed for existing transfers to get
-completed before shutting down.
-A graceful shutdown is triggered by an
-interrupt signal.
-This flag can be set using SFTPGO_GRACE_TIME env
-var too. 0 means disabled. (default 0)`)
-	viper.BindPFlag(graceTimeKey, cmd.Flags().Lookup(graceTimeFlag)) //nolint:errcheck
 }
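The removed flag help above describes the general behavior: on an interrupt signal the service stops accepting new work and gives in-flight transfers a bounded number of seconds to complete. A generic sketch of that shutdown pattern in Go (illustrative only, not SFTPGo's actual implementation):

```go
package main

import (
	"fmt"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"
)

// waitWithGrace blocks until an interrupt arrives, then gives active
// transfers up to graceTime seconds to finish before forcing shutdown.
func waitWithGrace(graceTime int, active *sync.WaitGroup) {
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
	<-sig

	done := make(chan struct{})
	go func() {
		active.Wait() // all in-flight transfers completed
		close(done)
	}()

	select {
	case <-done:
		fmt.Println("all transfers completed, shutting down")
	case <-time.After(time.Duration(graceTime) * time.Second):
		fmt.Println("grace time expired, forcing shutdown")
	}
}

func main() {
	var active sync.WaitGroup
	active.Add(1)
	go func() { // simulate an in-flight transfer
		defer active.Done()
		time.Sleep(2 * time.Second)
	}()
	waitWithGrace(10, &active)
}
```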


@@ -20,14 +20,14 @@ import (
 	"github.com/spf13/cobra"
-	"github.com/drakkan/sftpgo/v2/internal/service"
+	"github.com/drakkan/sftpgo/v2/service"
 )
 var (
 	rotateLogCmd = &cobra.Command{
 		Use:   "rotatelogs",
 		Short: "Signal to the running service to rotate the logs",
-		Run: func(_ *cobra.Command, _ []string) {
+		Run: func(cmd *cobra.Command, args []string) {
 			s := service.WindowsService{
 				Service: service.Service{
 					Shutdown: make(chan bool),


@@ -19,8 +19,8 @@ import (
 	"github.com/spf13/cobra"
-	"github.com/drakkan/sftpgo/v2/internal/service"
-	"github.com/drakkan/sftpgo/v2/internal/util"
+	"github.com/drakkan/sftpgo/v2/service"
+	"github.com/drakkan/sftpgo/v2/util"
 )
 var (
@@ -33,8 +33,7 @@ use:
 $ sftpgo serve
 Please take a look at the usage below to customize the startup options`,
-		Run: func(_ *cobra.Command, _ []string) {
-			service.SetGraceTime(graceTime)
+		Run: func(cmd *cobra.Command, args []string) {
 			service := service.Service{
 				ConfigDir:  util.CleanDirInput(configDir),
 				ConfigFile: configFile,
@@ -43,7 +42,7 @@ Please take a look at the usage below to customize the startup options`,
 				LogMaxBackups: logMaxBackups,
 				LogMaxAge:     logMaxAge,
 				LogCompress:   logCompress,
-				LogLevel:      logLevel,
+				LogVerbose:    logVerbose,
 				LogUTCTime:    logUTCTime,
 				LoadDataFrom:  loadDataFrom,
 				LoadDataMode:  loadDataMode,


@@ -20,10 +20,10 @@ import (
 	"github.com/rs/zerolog"
 	"github.com/spf13/cobra"
-	"github.com/drakkan/sftpgo/v2/internal/config"
-	"github.com/drakkan/sftpgo/v2/internal/logger"
-	"github.com/drakkan/sftpgo/v2/internal/smtp"
-	"github.com/drakkan/sftpgo/v2/internal/util"
+	"github.com/drakkan/sftpgo/v2/config"
+	"github.com/drakkan/sftpgo/v2/logger"
+	"github.com/drakkan/sftpgo/v2/smtp"
+	"github.com/drakkan/sftpgo/v2/util"
 )
 var (
@@ -33,7 +33,7 @@ var (
 		Short: "Test the SMTP configuration",
 		Long: `SFTPGo will try to send a test email to the specified recipient.
 If the SMTP configuration is correct you should receive this email.`,
-		Run: func(_ *cobra.Command, _ []string) {
+		Run: func(cmd *cobra.Command, args []string) {
 			logger.DisableLogger()
 			logger.EnableConsoleLogger(zerolog.DebugLevel)
 			configDir = util.CleanDirInput(configDir)
@@ -48,7 +48,7 @@ If the SMTP configuration is correct you should receive this email.`,
 				logger.ErrorToConsole("unable to initialize SMTP configuration: %v", err)
 				os.Exit(1)
 			}
-			err = smtp.SendEmail([]string{smtpTestRecipient}, "SFTPGo - Testing Email Settings", "It appears your SFTPGo email is setup correctly!",
+			err = smtp.SendEmail(smtpTestRecipient, "SFTPGo - Testing Email Settings", "It appears your SFTPGo email is setup correctly!",
 				smtp.EmailContentTypeTextPlain)
 			if err != nil {
 				logger.WarnToConsole("Error sending email: %v", err)


@@ -21,20 +21,19 @@ import (
 	"github.com/spf13/cobra"
-	"github.com/drakkan/sftpgo/v2/internal/service"
-	"github.com/drakkan/sftpgo/v2/internal/util"
+	"github.com/drakkan/sftpgo/v2/service"
+	"github.com/drakkan/sftpgo/v2/util"
 )
 var (
 	startCmd = &cobra.Command{
 		Use:   "start",
 		Short: "Start the SFTPGo Windows Service",
-		Run: func(_ *cobra.Command, _ []string) {
+		Run: func(cmd *cobra.Command, args []string) {
 			configDir = util.CleanDirInput(configDir)
 			if !filepath.IsAbs(logFilePath) && util.IsFileInputValid(logFilePath) {
 				logFilePath = filepath.Join(configDir, logFilePath)
 			}
-			service.SetGraceTime(graceTime)
 			s := service.Service{
 				ConfigDir:  configDir,
 				ConfigFile: configFile,
@@ -43,7 +42,7 @@ var (
 				LogMaxBackups: logMaxBackups,
 				LogMaxAge:     logMaxAge,
 				LogCompress:   logCompress,
-				LogLevel:      logLevel,
+				LogVerbose:    logVerbose,
 				LogUTCTime:    logUTCTime,
 				Shutdown:      make(chan bool),
 			}


@@ -25,13 +25,13 @@ import (
 	"github.com/spf13/cobra"
 	"github.com/spf13/viper"
-	"github.com/drakkan/sftpgo/v2/internal/common"
-	"github.com/drakkan/sftpgo/v2/internal/config"
-	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
-	"github.com/drakkan/sftpgo/v2/internal/logger"
-	"github.com/drakkan/sftpgo/v2/internal/plugin"
-	"github.com/drakkan/sftpgo/v2/internal/sftpd"
-	"github.com/drakkan/sftpgo/v2/internal/version"
+	"github.com/drakkan/sftpgo/v2/common"
+	"github.com/drakkan/sftpgo/v2/config"
+	"github.com/drakkan/sftpgo/v2/dataprovider"
+	"github.com/drakkan/sftpgo/v2/logger"
+	"github.com/drakkan/sftpgo/v2/plugin"
+	"github.com/drakkan/sftpgo/v2/sftpd"
+	"github.com/drakkan/sftpgo/v2/version"
 )
 var (
@@ -51,25 +51,18 @@ Subsystem sftp sftpgo startsubsys
 Command-line flags should be specified in the Subsystem declaration.
 `,
-		Run: func(_ *cobra.Command, _ []string) {
+		Run: func(cmd *cobra.Command, args []string) {
 			logSender := "startsubsys"
 			connectionID := xid.New().String()
-			var zeroLogLevel zerolog.Level
-			switch logLevel {
-			case "info":
-				zeroLogLevel = zerolog.InfoLevel
-			case "warn":
-				zeroLogLevel = zerolog.WarnLevel
-			case "error":
-				zeroLogLevel = zerolog.ErrorLevel
-			default:
-				zeroLogLevel = zerolog.DebugLevel
+			logLevel := zerolog.DebugLevel
+			if !logVerbose {
+				logLevel = zerolog.InfoLevel
 			}
 			logger.SetLogTime(logUTCTime)
 			if logJournalD {
-				logger.InitJournalDLogger(zeroLogLevel)
+				logger.InitJournalDLogger(logLevel)
 			} else {
-				logger.InitStdErrLogger(zeroLogLevel)
+				logger.InitStdErrLogger(logLevel)
 			}
 			osUser, err := user.Current()
 			if err != nil {
@@ -105,7 +98,7 @@ Command-line flags should be specified in the Subsystem declaration.
 				logger.Error(logSender, "", "unable to initialize MFA: %v", err)
 				os.Exit(1)
 			}
-			if err := plugin.Initialize(config.GetPluginsConfig(), logLevel); err != nil {
+			if err := plugin.Initialize(config.GetPluginsConfig(), logVerbose); err != nil {
 				logger.Error(logSender, connectionID, "unable to initialize plugin system: %v", err)
 				os.Exit(1)
 			}
@@ -203,17 +196,13 @@ error`)
 	addConfigFlags(subsystemCmd)
-	viper.SetDefault(logLevelKey, defaultLogLevel)
-	viper.BindEnv(logLevelKey, "SFTPGO_LOG_LEVEL") //nolint:errcheck
-	subsystemCmd.Flags().StringVar(&logLevel, logLevelFlag, viper.GetString(logLevelKey),
-		`Set the log level. Supported values:
-debug, info, warn, error.
-This flag can be set
-using SFTPGO_LOG_LEVEL env var too.
-`)
-	viper.BindPFlag(logLevelKey, subsystemCmd.Flags().Lookup(logLevelFlag)) //nolint:errcheck
+	viper.SetDefault(logVerboseKey, defaultLogVerbose)
+	viper.BindEnv(logVerboseKey, "SFTPGO_LOG_VERBOSE") //nolint:errcheck
+	subsystemCmd.Flags().BoolVarP(&logVerbose, logVerboseFlag, "v", viper.GetBool(logVerboseKey),
+		`Enable verbose logs. This flag can be set
+using SFTPGO_LOG_VERBOSE env var too.
+`)
+	viper.BindPFlag(logVerboseKey, subsystemCmd.Flags().Lookup(logVerboseFlag)) //nolint:errcheck
 	viper.SetDefault(logUTCTimeKey, defaultLogUTCTime)
 	viper.BindEnv(logUTCTimeKey, "SFTPGO_LOG_UTC_TIME") //nolint:errcheck


@@ -20,14 +20,14 @@ import (
 	"github.com/spf13/cobra"
-	"github.com/drakkan/sftpgo/v2/internal/service"
+	"github.com/drakkan/sftpgo/v2/service"
 )
 var (
 	statusCmd = &cobra.Command{
 		Use:   "status",
 		Short: "Retrieve the status for the SFTPGo Windows Service",
-		Run: func(_ *cobra.Command, _ []string) {
+		Run: func(cmd *cobra.Command, args []string) {
 			s := service.WindowsService{
 				Service: service.Service{
 					Shutdown: make(chan bool),


@@ -20,14 +20,14 @@ import (
 	"github.com/spf13/cobra"
-	"github.com/drakkan/sftpgo/v2/internal/service"
+	"github.com/drakkan/sftpgo/v2/service"
 )
 var (
 	stopCmd = &cobra.Command{
 		Use:   "stop",
 		Short: "Stop the SFTPGo Windows Service",
-		Run: func(_ *cobra.Command, _ []string) {
+		Run: func(cmd *cobra.Command, args []string) {
 			s := service.WindowsService{
 				Service: service.Service{
 					Shutdown: make(chan bool),


@@ -20,14 +20,14 @@ import (
 	"github.com/spf13/cobra"
-	"github.com/drakkan/sftpgo/v2/internal/service"
+	"github.com/drakkan/sftpgo/v2/service"
 )
 var (
 	uninstallCmd = &cobra.Command{
 		Use:   "uninstall",
 		Short: "Uninstall the SFTPGo Windows Service",
-		Run: func(_ *cobra.Command, _ []string) {
+		Run: func(cmd *cobra.Command, args []string) {
 			s := service.WindowsService{
 				Service: service.Service{
 					Shutdown: make(chan bool),


@@ -12,15 +12,13 @@
 // You should have received a copy of the GNU Affero General Public License
 // along with this program. If not, see <https://www.gnu.org/licenses/>.
-// Package command provides command configuration for SFTPGo hooks
 package command
 import (
 	"fmt"
+	"os"
 	"strings"
 	"time"
-	"github.com/drakkan/sftpgo/v2/internal/util"
 )
 const (
@@ -29,25 +27,8 @@ const (
 	defaultTimeout = 30
 )
-// Supported hook names
-const (
-	HookFsActions           = "fs_actions"
-	HookProviderActions     = "provider_actions"
-	HookStartup             = "startup"
-	HookPostConnect         = "post_connect"
-	HookPostDisconnect      = "post_disconnect"
-	HookDataRetention       = "data_retention"
-	HookCheckPassword       = "check_password"
-	HookPreLogin            = "pre_login"
-	HookPostLogin           = "post_login"
-	HookExternalAuth        = "external_auth"
-	HookKeyboardInteractive = "keyboard_interactive"
-)
 var (
 	config Config
-	supportedHooks = []string{HookFsActions, HookProviderActions, HookStartup, HookPostConnect, HookPostDisconnect,
-		HookDataRetention, HookCheckPassword, HookPreLogin, HookPostLogin, HookExternalAuth, HookKeyboardInteractive}
 )
 // Command define the configuration for a specific commands
@@ -59,14 +40,10 @@ type Command struct {
 	// Do not use variables with the SFTPGO_ prefix to avoid conflicts with env
 	// vars that SFTPGo sets
 	Timeout int `json:"timeout" mapstructure:"timeout"`
-	// Env defines environment variable for the command.
+	// Env defines additional environment variable for the commands.
 	// Each entry is of the form "key=value".
 	// These values are added to the global environment variables if any
 	Env []string `json:"env" mapstructure:"env"`
-	// Args defines arguments to pass to the specified command
-	Args []string `json:"args" mapstructure:"args"`
-	// if not empty both command path and hook name must match
-	Hook string `json:"hook" mapstructure:"hook"`
 }
 // Config defines the configuration for external commands such as
@@ -74,7 +51,7 @@ type Command struct {
 type Config struct {
 	// Timeout specifies a global time limit, in seconds, for the external commands execution
 	Timeout int `json:"timeout" mapstructure:"timeout"`
-	// Env defines environment variable for the commands.
+	// Env defines additional environment variable for the commands.
 	// Each entry is of the form "key=value".
 	// Do not use variables with the SFTPGO_ prefix to avoid conflicts with env
 	// vars that SFTPGo sets
@@ -95,7 +72,7 @@ func (c Config) Initialize() error {
 		return fmt.Errorf("invalid timeout %v", c.Timeout)
 	}
 	for _, env := range c.Env {
-		if len(strings.SplitN(env, "=", 2)) != 2 {
+		if len(strings.Split(env, "=")) != 2 {
 			return fmt.Errorf("invalid env var %#v", env)
 		}
 	}
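The two sides validate `KEY=value` entries differently: `strings.Split` rejects any value that itself contains `=`, while `strings.SplitN` with a limit of 2 only splits on the first separator. A quick standalone illustration:

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	env := "TOKEN=abc=def" // a valid entry whose value contains '='

	// Split produces 3 parts, so a `len(...) != 2` check rejects it.
	fmt.Println(len(strings.Split(env, "=")))     // 3
	// SplitN with limit 2 stops after the first '=', so the check accepts it.
	fmt.Println(len(strings.SplitN(env, "=", 2))) // 2
}
```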
@@ -111,37 +88,27 @@ func (c Config) Initialize() error {
 			}
 		}
 		for _, env := range cmd.Env {
-			if len(strings.SplitN(env, "=", 2)) != 2 {
+			if len(strings.Split(env, "=")) != 2 {
 				return fmt.Errorf("invalid env var %#v for command %#v", env, cmd.Path)
 			}
 		}
-		// don't validate args, we allow to pass empty arguments
-		if cmd.Hook != "" {
-			if !util.Contains(supportedHooks, cmd.Hook) {
-				return fmt.Errorf("invalid hook name %q, supported values: %+v", cmd.Hook, supportedHooks)
-			}
-		}
 	}
 	config = c
 	return nil
 }
 // GetConfig returns the configuration for the specified command
-func GetConfig(command, hook string) (time.Duration, []string, []string) {
-	env := []string{}
-	var args []string
+func GetConfig(command string) (time.Duration, []string) {
+	env := os.Environ()
 	timeout := time.Duration(config.Timeout) * time.Second
 	env = append(env, config.Env...)
 	for _, cmd := range config.Commands {
 		if cmd.Path == command {
-			if cmd.Hook == "" || cmd.Hook == hook {
-				timeout = time.Duration(cmd.Timeout) * time.Second
-				env = append(env, cmd.Env...)
-				args = cmd.Args
-				break
-			}
+			timeout = time.Duration(cmd.Timeout) * time.Second
+			env = append(env, cmd.Env...)
+			break
 		}
 	}
-	return timeout, env, args
+	return timeout, env
 }
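One visible behavioral difference in this hunk: the right side starts from `os.Environ()`, so every hook inherits the daemon's full environment, while the left side starts from an empty slice and passes only explicitly configured entries. A small sketch of the consequence, on a Unix-like system with an `env` binary available (illustrative only):

```go
package main

import (
	"fmt"
	"os"
	"os/exec"
)

func main() {
	// Inherit-everything style: the configured entry is appended to the
	// daemon's own environment.
	inherit := append(os.Environ(), "SFTPGO_ACTION=upload")
	// Isolated style: the hook sees only what was explicitly configured.
	isolated := []string{"SFTPGO_ACTION=upload"}

	fmt.Println(len(inherit) > len(isolated)) // true: inherited env is larger

	cmd := exec.Command("env")
	cmd.Env = isolated // the child process sees exactly one variable
	out, _ := cmd.Output()
	fmt.Print(string(out)) // prints: SFTPGO_ACTION=upload
}
```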


@@ -33,17 +33,15 @@ func TestCommandConfig(t *testing.T) {
 	assert.Equal(t, cfg.Timeout, config.Timeout)
 	assert.Equal(t, cfg.Env, config.Env)
 	assert.Len(t, cfg.Commands, 0)
-	timeout, env, args := GetConfig("cmd", "")
+	timeout, env := GetConfig("cmd")
 	assert.Equal(t, time.Duration(config.Timeout)*time.Second, timeout)
 	assert.Contains(t, env, "a=b")
-	assert.Len(t, args, 0)
 	cfg.Commands = []Command{
 		{
 			Path:    "cmd1",
 			Timeout: 30,
 			Env:     []string{"c=d"},
-			Args:    []string{"1", "", "2"},
 		},
 		{
 			Path:    "cmd2",
@@ -59,68 +57,20 @@ func TestCommandConfig(t *testing.T) {
 		assert.Equal(t, cfg.Commands[0].Path, config.Commands[0].Path)
 		assert.Equal(t, cfg.Commands[0].Timeout, config.Commands[0].Timeout)
 		assert.Equal(t, cfg.Commands[0].Env, config.Commands[0].Env)
-		assert.Equal(t, cfg.Commands[0].Args, config.Commands[0].Args)
 		assert.Equal(t, cfg.Commands[1].Path, config.Commands[1].Path)
 		assert.Equal(t, cfg.Timeout, config.Commands[1].Timeout)
 		assert.Equal(t, cfg.Commands[1].Env, config.Commands[1].Env)
-		assert.Equal(t, cfg.Commands[1].Args, config.Commands[1].Args)
 	}
-	timeout, env, args = GetConfig("cmd1", "")
+	timeout, env = GetConfig("cmd1")
 	assert.Equal(t, time.Duration(config.Commands[0].Timeout)*time.Second, timeout)
 	assert.Contains(t, env, "a=b")
 	assert.Contains(t, env, "c=d")
 	assert.NotContains(t, env, "e=f")
-	if assert.Len(t, args, 3) {
-		assert.Equal(t, "1", args[0])
-		assert.Empty(t, args[1])
-		assert.Equal(t, "2", args[2])
-	}
-	timeout, env, args = GetConfig("cmd2", "")
+	timeout, env = GetConfig("cmd2")
 	assert.Equal(t, time.Duration(config.Timeout)*time.Second, timeout)
 	assert.Contains(t, env, "a=b")
 	assert.NotContains(t, env, "c=d")
 	assert.Contains(t, env, "e=f")
-	assert.Len(t, args, 0)
-	cfg.Commands = []Command{
-		{
-			Path:    "cmd1",
-			Timeout: 30,
-			Env:     []string{"c=d"},
-			Args:    []string{"1", "", "2"},
-			Hook:    HookCheckPassword,
-		},
-		{
-			Path:    "cmd1",
-			Timeout: 0,
-			Env:     []string{"e=f"},
-			Hook:    HookExternalAuth,
-		},
-	}
-	err = cfg.Initialize()
-	require.NoError(t, err)
-	timeout, env, args = GetConfig("cmd1", "")
-	assert.Equal(t, time.Duration(config.Timeout)*time.Second, timeout)
-	assert.Contains(t, env, "a=b")
-	assert.NotContains(t, env, "c=d")
-	assert.NotContains(t, env, "e=f")
-	assert.Len(t, args, 0)
-	timeout, env, args = GetConfig("cmd1", HookCheckPassword)
-	assert.Equal(t, time.Duration(config.Commands[0].Timeout)*time.Second, timeout)
-	assert.Contains(t, env, "a=b")
-	assert.Contains(t, env, "c=d")
-	assert.NotContains(t, env, "e=f")
-	if assert.Len(t, args, 3) {
-		assert.Equal(t, "1", args[0])
-		assert.Empty(t, args[1])
-		assert.Equal(t, "2", args[2])
-	}
-	timeout, env, args = GetConfig("cmd1", HookExternalAuth)
-	assert.Equal(t, time.Duration(cfg.Timeout)*time.Second, timeout)
-	assert.Contains(t, env, "a=b")
-	assert.NotContains(t, env, "c=d")
-	assert.Contains(t, env, "e=f")
-	assert.Len(t, args, 0)
 }
 func TestConfigErrors(t *testing.T) {
@@ -166,16 +116,4 @@ func TestConfigErrors(t *testing.T) {
 	if assert.Error(t, err) {
 		assert.Contains(t, err.Error(), "invalid env var")
 	}
-	c.Commands = []Command{
-		{
-			Path:    "path",
-			Timeout: 30,
-			Env:     []string{"a=b"},
-			Hook:    "invali",
-		},
-	}
-	err = c.Initialize()
-	if assert.Error(t, err) {
-		assert.Contains(t, err.Error(), "invalid hook name")
-	}
 }


@@ -23,41 +23,27 @@ import (
 	"net/http"
 	"net/url"
 	"os/exec"
-	"path"
 	"path/filepath"
 	"strings"
-	"sync/atomic"
 	"time"
 	"github.com/sftpgo/sdk"
 	"github.com/sftpgo/sdk/plugin/notifier"
-	"github.com/drakkan/sftpgo/v2/internal/command"
-	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
-	"github.com/drakkan/sftpgo/v2/internal/httpclient"
-	"github.com/drakkan/sftpgo/v2/internal/logger"
-	"github.com/drakkan/sftpgo/v2/internal/plugin"
-	"github.com/drakkan/sftpgo/v2/internal/util"
+	"github.com/drakkan/sftpgo/v2/command"
+	"github.com/drakkan/sftpgo/v2/dataprovider"
+	"github.com/drakkan/sftpgo/v2/httpclient"
+	"github.com/drakkan/sftpgo/v2/logger"
+	"github.com/drakkan/sftpgo/v2/plugin"
+	"github.com/drakkan/sftpgo/v2/util"
 )
 var (
 	errUnconfiguredAction    = errors.New("no hook is configured for this action")
 	errNoHook                = errors.New("unable to execute action, no hook defined")
-	errUnexpectedHTTResponse = errors.New("unexpected HTTP hook response code")
-	hooksConcurrencyGuard    = make(chan struct{}, 150)
-	activeHooks              atomic.Int32
+	errUnexpectedHTTResponse = errors.New("unexpected HTTP response code")
 )
-func startNewHook() {
-	activeHooks.Add(1)
-	hooksConcurrencyGuard <- struct{}{}
-}
-func hookEnded() {
-	activeHooks.Add(-1)
-	<-hooksConcurrencyGuard
-}
 // ProtocolActions defines the action to execute on file operations and SSH commands
 type ProtocolActions struct {
 	// Valid values are download, upload, pre-delete, delete, rename, ssh_cmd. Empty slice to disable
@@ -112,56 +98,26 @@ func ExecutePreAction(conn *BaseConnection, operation, filePath, virtualPath str
 // ExecuteActionNotification executes the defined hook, if any, for the specified action
 func ExecuteActionNotification(conn *BaseConnection, operation, filePath, virtualPath, target, virtualTarget, sshCmd string,
 	fileSize int64, err error,
-) error {
+) {
 	hasNotifiersPlugin := plugin.Handler.HasNotifiers()
 	hasHook := util.Contains(Config.Actions.ExecuteOn, operation)
-	hasRules := eventManager.hasFsRules()
-	if !hasHook && !hasNotifiersPlugin && !hasRules {
-		return nil
+	if !hasHook && !hasNotifiersPlugin {
+		return
 	}
 	notification := newActionNotification(&conn.User, operation, filePath, virtualPath, target, virtualTarget, sshCmd,
 		conn.protocol, conn.GetRemoteIP(), conn.ID, fileSize, 0, err)
 	if hasNotifiersPlugin {
 		plugin.Handler.NotifyFsEvent(notification)
 	}
-	var errRes error
-	if hasRules {
-		params := EventParams{
-			Name:              notification.Username,
-			Groups:            conn.User.Groups,
-			Event:             notification.Action,
-			Status:            notification.Status,
-			VirtualPath:       notification.VirtualPath,
-			FsPath:            notification.Path,
-			VirtualTargetPath: notification.VirtualTargetPath,
-			FsTargetPath:      notification.TargetPath,
-			ObjectName:        path.Base(notification.VirtualPath),
-			FileSize:          notification.FileSize,
-			Protocol:          notification.Protocol,
-			IP:                notification.IP,
-			Timestamp:         notification.Timestamp,
-			Object:            nil,
-		}
-		if err != nil {
-			params.AddError(fmt.Errorf("%q failed: %w", params.Event, err))
-		}
-		errRes = eventManager.handleFsEvent(params)
-	}
 	if hasHook {
 		if util.Contains(Config.Actions.ExecuteSync, operation) {
-			if errHook := actionHandler.Handle(notification); errHook != nil {
-				errRes = errHook
-			}
-		} else {
-			go func() {
-				startNewHook()
-				defer hookEnded()
-				actionHandler.Handle(notification) //nolint:errcheck
-			}()
+			actionHandler.Handle(notification) //nolint:errcheck
+			return
 		}
+		go actionHandler.Handle(notification) //nolint:errcheck
 	}
-	return errRes
 }
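The removed `startNewHook`/`hookEnded` helpers implement a common Go pattern: a buffered channel used as a counting semaphore to cap how many hook goroutines run at once (150 in the code above). A standalone sketch of the same pattern:

```go
package main

import (
	"fmt"
	"sync"
)

// A buffered channel acts as a counting semaphore: sends block once
// maxConcurrent tokens are outstanding, capping in-flight goroutines.
const maxConcurrent = 150

var guard = make(chan struct{}, maxConcurrent)

func runHook(id int, wg *sync.WaitGroup) {
	defer wg.Done()
	guard <- struct{}{}        // acquire a slot (blocks when the guard is full)
	defer func() { <-guard }() // release the slot when done
	fmt.Println("hook", id, "running")
}

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 500; i++ {
		wg.Add(1)
		go runHook(i, &wg)
	}
	wg.Wait()
}
```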
 // ActionHandler handles a notification for a Protocol Action.
@@ -177,6 +133,7 @@ func newActionNotification(
 	err error,
 ) *notifier.FsEvent {
 	var bucket, endpoint string
+	status := 1
 	fsConfig := user.GetFsConfigForPath(virtualPath)
@@ -193,8 +150,12 @@ func newActionNotification(
 		}
 	case sdk.SFTPFilesystemProvider:
 		endpoint = fsConfig.SFTPConfig.Endpoint
-	case sdk.HTTPFilesystemProvider:
-		endpoint = fsConfig.HTTPConfig.Endpoint
+	}
+	if err == ErrQuotaExceeded {
+		status = 3
+	} else if err != nil {
+		status = 2
 	}
 	return &notifier.FsEvent{
@@ -209,7 +170,7 @@ func newActionNotification(
 		FsProvider: int(fsConfig.Provider),
 		Bucket:     bucket,
 		Endpoint:   endpoint,
-		Status:     getNotificationStatus(err),
+		Status:     status,
 		Protocol:   protocol,
 		IP:         ip,
 		SessionID:  sessionID,
@@ -262,7 +223,7 @@ func (h *defaultActionHandler) handleHTTP(event *notifier.FsEvent) error {
 		}
 	}
-	logger.Debug(event.Protocol, "", "notified operation %q to URL: %s status code: %d, elapsed: %s err: %v",
+	logger.Debug(event.Protocol, "", "notified operation %#v to URL: %v status code: %v, elapsed: %v err: %v",
 		event.Action, u.Redacted(), respCode, time.Since(startTime), err)
 	return err
@@ -276,11 +237,11 @@ func (h *defaultActionHandler) handleCommand(event *notifier.FsEvent) error {
 		return err
 	}
-	timeout, env, args := command.GetConfig(Config.Actions.Hook, command.HookFsActions)
+	timeout, env := command.GetConfig(Config.Actions.Hook)
 	ctx, cancel := context.WithTimeout(context.Background(), timeout)
 	defer cancel()
-	cmd := exec.CommandContext(ctx, Config.Actions.Hook, args...)
+	cmd := exec.CommandContext(ctx, Config.Actions.Hook)
 	cmd.Env = append(env, notificationAsEnvVars(event)...)
 	startTime := time.Now()
@@ -294,32 +255,22 @@ func (h *defaultActionHandler) handleCommand(event *notifier.FsEvent) error {
 func notificationAsEnvVars(event *notifier.FsEvent) []string {
 	return []string{
-		fmt.Sprintf("SFTPGO_ACTION=%s", event.Action),
-		fmt.Sprintf("SFTPGO_ACTION_USERNAME=%s", event.Username),
-		fmt.Sprintf("SFTPGO_ACTION_PATH=%s", event.Path),
-		fmt.Sprintf("SFTPGO_ACTION_TARGET=%s", event.TargetPath),
-		fmt.Sprintf("SFTPGO_ACTION_VIRTUAL_PATH=%s", event.VirtualPath),
-		fmt.Sprintf("SFTPGO_ACTION_VIRTUAL_TARGET=%s", event.VirtualTargetPath),
-		fmt.Sprintf("SFTPGO_ACTION_SSH_CMD=%s", event.SSHCmd),
-		fmt.Sprintf("SFTPGO_ACTION_FILE_SIZE=%d", event.FileSize),
-		fmt.Sprintf("SFTPGO_ACTION_FS_PROVIDER=%d", event.FsProvider),
-		fmt.Sprintf("SFTPGO_ACTION_BUCKET=%s", event.Bucket),
-		fmt.Sprintf("SFTPGO_ACTION_ENDPOINT=%s", event.Endpoint),
-		fmt.Sprintf("SFTPGO_ACTION_STATUS=%d", event.Status),
-		fmt.Sprintf("SFTPGO_ACTION_PROTOCOL=%s", event.Protocol),
-		fmt.Sprintf("SFTPGO_ACTION_IP=%s", event.IP),
-		fmt.Sprintf("SFTPGO_ACTION_SESSION_ID=%s", event.SessionID),
-		fmt.Sprintf("SFTPGO_ACTION_OPEN_FLAGS=%d", event.OpenFlags),
-		fmt.Sprintf("SFTPGO_ACTION_TIMESTAMP=%d", event.Timestamp),
+		fmt.Sprintf("SFTPGO_ACTION=%v", event.Action),
+		fmt.Sprintf("SFTPGO_ACTION_USERNAME=%v", event.Username),
+		fmt.Sprintf("SFTPGO_ACTION_PATH=%v", event.Path),
+		fmt.Sprintf("SFTPGO_ACTION_TARGET=%v", event.TargetPath),
+		fmt.Sprintf("SFTPGO_ACTION_VIRTUAL_PATH=%v", event.VirtualPath),
+		fmt.Sprintf("SFTPGO_ACTION_VIRTUAL_TARGET=%v", event.VirtualTargetPath),
+		fmt.Sprintf("SFTPGO_ACTION_SSH_CMD=%v", event.SSHCmd),
+		fmt.Sprintf("SFTPGO_ACTION_FILE_SIZE=%v", event.FileSize),
+		fmt.Sprintf("SFTPGO_ACTION_FS_PROVIDER=%v", event.FsProvider),
+		fmt.Sprintf("SFTPGO_ACTION_BUCKET=%v", event.Bucket),
+		fmt.Sprintf("SFTPGO_ACTION_ENDPOINT=%v", event.Endpoint),
+		fmt.Sprintf("SFTPGO_ACTION_STATUS=%v", event.Status),
+		fmt.Sprintf("SFTPGO_ACTION_PROTOCOL=%v", event.Protocol),
+		fmt.Sprintf("SFTPGO_ACTION_IP=%v", event.IP),
+		fmt.Sprintf("SFTPGO_ACTION_SESSION_ID=%v", event.SessionID),
+		fmt.Sprintf("SFTPGO_ACTION_OPEN_FLAGS=%v", event.OpenFlags),
+		fmt.Sprintf("SFTPGO_ACTION_TIMESTAMP=%v", event.Timestamp),
 	}
 }
-func getNotificationStatus(err error) int {
-	status := 1
-	if err == ErrQuotaExceeded {
-		status = 3
-	} else if err != nil {
-		status = 2
-	}
-	return status
-}
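Since the hook command receives the event exclusively through these environment variables, a hook can be as small as a program that reads them back. A hypothetical standalone hook, using the variable names from `notificationAsEnvVars` and the status codes from `getNotificationStatus` above:

```go
package main

import (
	"fmt"
	"os"
)

func main() {
	// Read the event details exported by notificationAsEnvVars.
	action := os.Getenv("SFTPGO_ACTION")
	user := os.Getenv("SFTPGO_ACTION_USERNAME")
	path := os.Getenv("SFTPGO_ACTION_VIRTUAL_PATH")
	// Per the code above: "1" means ok, "2" generic error, "3" quota exceeded.
	status := os.Getenv("SFTPGO_ACTION_STATUS")

	fmt.Printf("user %s: %s on %s (status %s)\n", user, action, path, status)
}
```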


@@ -29,9 +29,9 @@ import (
 	"github.com/sftpgo/sdk/plugin/notifier"
 	"github.com/stretchr/testify/assert"
-	"github.com/drakkan/sftpgo/v2/internal/dataprovider"
-	"github.com/drakkan/sftpgo/v2/internal/plugin"
-	"github.com/drakkan/sftpgo/v2/internal/vfs"
+	"github.com/drakkan/sftpgo/v2/dataprovider"
+	"github.com/drakkan/sftpgo/v2/plugin"
+	"github.com/drakkan/sftpgo/v2/vfs"
 )
 func TestNewActionNotification(t *testing.T) {
@@ -63,11 +63,6 @@ func TestNewActionNotification(t *testing.T) {
 			Endpoint: "sftpendpoint",
 		},
 	}
-	user.FsConfig.HTTPConfig = vfs.HTTPFsConfig{
-		BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
-			Endpoint: "httpendpoint",
-		},
-	}
 	sessionID := xid.New().String()
 	a := newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSFTP, "", sessionID,
 		123, 0, errors.New("fake error"))
@@ -90,12 +85,6 @@ func TestNewActionNotification(t *testing.T) {
 	assert.Equal(t, 0, len(a.Endpoint))
 	assert.Equal(t, 3, a.Status)
-	user.FsConfig.Provider = sdk.HTTPFilesystemProvider
-	a = newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSSH, "", sessionID,
-		123, 0, nil)
-	assert.Equal(t, "httpendpoint", a.Endpoint)
-	assert.Equal(t, 1, a.Status)
 	user.FsConfig.Provider = sdk.AzureBlobFilesystemProvider
 	a = newActionNotification(user, operationDownload, "path", "vpath", "target", "", "", ProtocolSCP, "", sessionID,
 		123, 0, nil)
@@ -171,11 +160,9 @@ func TestActionCMD(t *testing.T) {
 	assert.NoError(t, err)
 	c := NewBaseConnection("id", ProtocolSFTP, "", "", *user)
-	err = ExecuteActionNotification(c, OperationSSHCmd, "path", "vpath", "target", "vtarget", "sha1sum", 0, nil)
-	assert.NoError(t, err)
-	err = ExecuteActionNotification(c, operationDownload, "path", "vpath", "", "", "", 0, nil)
-	assert.NoError(t, err)
+	ExecuteActionNotification(c, OperationSSHCmd, "path", "vpath", "target", "vtarget", "sha1sum", 0, nil)
+	ExecuteActionNotification(c, operationDownload, "path", "vpath", "", "", "", 0, nil)
 	Config.Actions = actionsCopy
 }
@@ -278,7 +265,7 @@ func TestUnconfiguredHook(t *testing.T) {
 			Type: "notifier",
 		},
 	}
-	err := plugin.Initialize(pluginsConfig, "debug")
+	err := plugin.Initialize(pluginsConfig, true)
 	assert.Error(t, err)
 	assert.True(t, plugin.Handler.HasNotifiers())
@@ -288,10 +275,9 @@ func TestUnconfiguredHook(t *testing.T) {
 	err = ExecutePreAction(c, operationPreDelete, "", "", 0, 0)
 	assert.ErrorIs(t, err, errUnconfiguredAction)
-	err = ExecuteActionNotification(c, operationDownload, "", "", "", "", "", 0, nil)
-	assert.NoError(t, err)
-	err = plugin.Initialize(nil, "debug")
+	ExecuteActionNotification(c, operationDownload, "", "", "", "", "", 0, nil)
+	err = plugin.Initialize(nil, true)
 	assert.NoError(t, err)
 	assert.False(t, plugin.Handler.HasNotifiers())


@@ -18,18 +18,18 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/logger"
) )
// clientsMap is a struct containing the map of the connected clients // clientsMap is a struct containing the map of the connected clients
type clientsMap struct { type clientsMap struct {
totalConnections atomic.Int32 totalConnections int32
mu sync.RWMutex mu sync.RWMutex
clients map[string]int clients map[string]int
} }
func (c *clientsMap) add(source string) { func (c *clientsMap) add(source string) {
c.totalConnections.Add(1) atomic.AddInt32(&c.totalConnections, 1)
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@@ -42,7 +42,7 @@ func (c *clientsMap) remove(source string) {
defer c.mu.Unlock() defer c.mu.Unlock()
if val, ok := c.clients[source]; ok { if val, ok := c.clients[source]; ok {
c.totalConnections.Add(-1) atomic.AddInt32(&c.totalConnections, -1)
c.clients[source]-- c.clients[source]--
if val > 1 { if val > 1 {
return return
@@ -54,7 +54,7 @@ func (c *clientsMap) remove(source string) {
} }
func (c *clientsMap) getTotal() int32 { func (c *clientsMap) getTotal() int32 {
return c.totalConnections.Load() return atomic.LoadInt32(&c.totalConnections)
} }
func (c *clientsMap) getTotalFrom(source string) int { func (c *clientsMap) getTotalFrom(source string) int {
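This hunk is part of the move between the package-level sync/atomic helpers and the typed atomics added in Go 1.19: the typed form needs no & at the call site and cannot be read or written non-atomically by mistake. A minimal, self-contained sketch contrasting the two styles (the counter types here are illustrative, not SFTPGo code):

package main

import (
	"fmt"
	"sync/atomic"
)

// legacyCounter: pre-Go 1.19 style, a plain int32 driven through the
// package-level atomic functions, as on one side of the hunk above.
type legacyCounter struct {
	total int32
}

func (c *legacyCounter) add(n int32) { atomic.AddInt32(&c.total, n) }
func (c *legacyCounter) get() int32  { return atomic.LoadInt32(&c.total) }

// typedCounter: Go 1.19+ style using the atomic.Int32 methods.
type typedCounter struct {
	total atomic.Int32
}

func (c *typedCounter) add(n int32) { c.total.Add(n) }
func (c *typedCounter) get() int32  { return c.total.Load() }

func main() {
	l, t := &legacyCounter{}, &typedCounter{}
	l.add(2)
	t.add(2)
	fmt.Println(l.get(), t.get()) // 2 2
}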


@@ -33,35 +33,33 @@ import (
"github.com/pires/go-proxyproto" "github.com/pires/go-proxyproto"
"github.com/drakkan/sftpgo/v2/internal/command" "github.com/drakkan/sftpgo/v2/command"
"github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/httpclient" "github.com/drakkan/sftpgo/v2/httpclient"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/internal/metric" "github.com/drakkan/sftpgo/v2/metric"
"github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/internal/vfs" "github.com/drakkan/sftpgo/v2/vfs"
) )
// constants // constants
const ( const (
logSender = "common" logSender = "common"
uploadLogSender = "Upload" uploadLogSender = "Upload"
downloadLogSender = "Download" downloadLogSender = "Download"
renameLogSender = "Rename" renameLogSender = "Rename"
rmdirLogSender = "Rmdir" rmdirLogSender = "Rmdir"
mkdirLogSender = "Mkdir" mkdirLogSender = "Mkdir"
symlinkLogSender = "Symlink" symlinkLogSender = "Symlink"
removeLogSender = "Remove" removeLogSender = "Remove"
chownLogSender = "Chown" chownLogSender = "Chown"
chmodLogSender = "Chmod" chmodLogSender = "Chmod"
chtimesLogSender = "Chtimes" chtimesLogSender = "Chtimes"
truncateLogSender = "Truncate" truncateLogSender = "Truncate"
operationDownload = "download" operationDownload = "download"
operationUpload = "upload" operationUpload = "upload"
operationFirstDownload = "first-download" operationDelete = "delete"
operationFirstUpload = "first-upload"
operationDelete = "delete"
// Pre-download action name // Pre-download action name
OperationPreDownload = "pre-download" OperationPreDownload = "pre-download"
// Pre-upload action name // Pre-upload action name
@@ -102,7 +100,6 @@ const (
ProtocolHTTPShare = "HTTPShare" ProtocolHTTPShare = "HTTPShare"
ProtocolDataRetention = "DataRetention" ProtocolDataRetention = "DataRetention"
ProtocolOIDC = "OIDC" ProtocolOIDC = "OIDC"
protocolEventAction = "EventAction"
) )
// Upload modes // Upload modes
@@ -117,27 +114,25 @@ func init() {
clients: make(map[string]int), clients: make(map[string]int),
} }
Connections.perUserConns = make(map[string]int) Connections.perUserConns = make(map[string]int)
Connections.mapping = make(map[string]int)
Connections.sshMapping = make(map[string]int)
} }
// errors definitions // errors definitions
var ( var (
ErrPermissionDenied = errors.New("permission denied") ErrPermissionDenied = errors.New("permission denied")
ErrNotExist = errors.New("no such file or directory") ErrNotExist = errors.New("no such file or directory")
ErrOpUnsupported = errors.New("operation unsupported") ErrOpUnsupported = errors.New("operation unsupported")
ErrGenericFailure = errors.New("failure") ErrGenericFailure = errors.New("failure")
ErrQuotaExceeded = errors.New("denying write due to space limit") ErrQuotaExceeded = errors.New("denying write due to space limit")
ErrReadQuotaExceeded = errors.New("denying read due to quota limit") ErrReadQuotaExceeded = errors.New("denying read due to quota limit")
ErrConnectionDenied = errors.New("you are not allowed to connect") ErrSkipPermissionsCheck = errors.New("permission check skipped")
ErrNoBinding = errors.New("no binding configured") ErrConnectionDenied = errors.New("you are not allowed to connect")
ErrCrtRevoked = errors.New("your certificate has been revoked") ErrNoBinding = errors.New("no binding configured")
ErrNoCredentials = errors.New("no credential provided") ErrCrtRevoked = errors.New("your certificate has been revoked")
ErrInternalFailure = errors.New("internal failure") ErrNoCredentials = errors.New("no credential provided")
ErrTransferAborted = errors.New("transfer aborted") ErrInternalFailure = errors.New("internal failure")
ErrShuttingDown = errors.New("the service is shutting down") ErrTransferAborted = errors.New("transfer aborted")
errNoTransfer = errors.New("requested transfer not found") errNoTransfer = errors.New("requested transfer not found")
errTransferMismatch = errors.New("transfer mismatch") errTransferMismatch = errors.New("transfer mismatch")
) )
var ( var (
@@ -146,28 +141,26 @@ var (
// Connections is the list of active connections // Connections is the list of active connections
Connections ActiveConnections Connections ActiveConnections
// QuotaScans is the list of active quota scans // QuotaScans is the list of active quota scans
QuotaScans ActiveScans QuotaScans ActiveScans
// ActiveMetadataChecks holds the active metadata checks transfersChecker TransfersChecker
ActiveMetadataChecks MetadataChecks periodicTimeoutTicker *time.Ticker
transfersChecker TransfersChecker periodicTimeoutTickerDone chan bool
supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV, supportedProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP, ProtocolWebDAV,
ProtocolHTTP, ProtocolHTTPShare, ProtocolOIDC} ProtocolHTTP, ProtocolHTTPShare, ProtocolOIDC}
disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP} disconnHookProtocols = []string{ProtocolSFTP, ProtocolSCP, ProtocolSSH, ProtocolFTP}
// the map key is the protocol, for each protocol we can have multiple rate limiters // the map key is the protocol, for each protocol we can have multiple rate limiters
rateLimiters map[string][]*rateLimiter rateLimiters map[string][]*rateLimiter
isShuttingDown atomic.Bool
) )
// Initialize sets the common configuration // Initialize sets the common configuration
func Initialize(c Configuration, isShared int) error { func Initialize(c Configuration, isShared int) error {
isShuttingDown.Store(false)
Config = c Config = c
Config.Actions.ExecuteOn = util.RemoveDuplicates(Config.Actions.ExecuteOn, true) Config.Actions.ExecuteOn = util.RemoveDuplicates(Config.Actions.ExecuteOn, true)
Config.Actions.ExecuteSync = util.RemoveDuplicates(Config.Actions.ExecuteSync, true) Config.Actions.ExecuteSync = util.RemoveDuplicates(Config.Actions.ExecuteSync, true)
Config.ProxyAllowed = util.RemoveDuplicates(Config.ProxyAllowed, true) Config.ProxyAllowed = util.RemoveDuplicates(Config.ProxyAllowed, true)
Config.idleLoginTimeout = 2 * time.Minute Config.idleLoginTimeout = 2 * time.Minute
Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute Config.idleTimeoutAsDuration = time.Duration(Config.IdleTimeout) * time.Minute
startPeriodicChecks(periodicTimeoutCheckInterval) startPeriodicTimeoutTicker(periodicTimeoutCheckInterval)
Config.defender = nil Config.defender = nil
Config.whitelist = nil Config.whitelist = nil
rateLimiters = make(map[string][]*rateLimiter) rateLimiters = make(map[string][]*rateLimiter)
@@ -217,73 +210,10 @@ func Initialize(c Configuration, isShared int) error {
} }
vfs.SetTempPath(c.TempPath) vfs.SetTempPath(c.TempPath)
dataprovider.SetTempPath(c.TempPath) dataprovider.SetTempPath(c.TempPath)
vfs.SetAllowSelfConnections(c.AllowSelfConnections)
dataprovider.SetAllowSelfConnections(c.AllowSelfConnections)
transfersChecker = getTransfersChecker(isShared) transfersChecker = getTransfersChecker(isShared)
return nil return nil
} }
// CheckClosing returns an error if the service is closing
func CheckClosing() error {
if isShuttingDown.Load() {
return ErrShuttingDown
}
return nil
}
// WaitForTransfers waits, for the specified grace time, for currently ongoing
// client-initiated transfer sessions to complete.
// A zero graceTime means no wait.
func WaitForTransfers(graceTime int) {
if graceTime == 0 {
return
}
if isShuttingDown.Swap(true) {
return
}
if activeHooks.Load() == 0 && getActiveConnections() == 0 {
return
}
graceTimer := time.NewTimer(time.Duration(graceTime) * time.Second)
ticker := time.NewTicker(3 * time.Second)
for {
select {
case <-ticker.C:
hooks := activeHooks.Load()
logger.Info(logSender, "", "active hooks: %d", hooks)
if hooks == 0 && getActiveConnections() == 0 {
logger.Info(logSender, "", "no more active connections, graceful shutdown")
ticker.Stop()
graceTimer.Stop()
return
}
case <-graceTimer.C:
logger.Info(logSender, "", "grace time expired, hard shutdown")
ticker.Stop()
return
}
}
}
// getActiveConnections returns the number of connections with active transfers
func getActiveConnections() int {
var activeConns int
Connections.RLock()
for _, c := range Connections.connections {
if len(c.GetTransfers()) > 0 {
activeConns++
}
}
Connections.RUnlock()
logger.Info(logSender, "", "number of connections with active transfers: %d", activeConns)
return activeConns
}
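Taken together, CheckClosing, WaitForTransfers and getActiveConnections implement the graceful-shutdown path on this side of the diff: the first WaitForTransfers call flips the shutdown flag, then the loop polls every 3 seconds until hooks and transfers drain or the grace timer fires. A hypothetical caller, assuming it lives in this module's service layer; the signal wiring below is illustrative, not SFTPGo's actual service code:

package service // hypothetical package, for illustration only

import (
	"os"
	"os/signal"
	"syscall"

	"github.com/drakkan/sftpgo/v2/internal/common"
)

// waitShutdown blocks until SIGTERM/SIGINT, then gives ongoing
// transfers up to graceTime seconds to finish. With a non-zero
// graceTime, WaitForTransfers also flips the shutdown flag, so
// IsNewConnectionAllowed starts returning ErrShuttingDown.
func waitShutdown(graceTime int) {
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT)
	<-sigCh
	common.WaitForTransfers(graceTime)
}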
// LimitRate blocks until all the configured rate limiters // LimitRate blocks until all the configured rate limiters
// allow one event to happen. // allow one event to happen.
// It returns an error if the time to wait exceeds the max // It returns an error if the time to wait exceeds the max
@@ -381,18 +311,35 @@ func AddDefenderEvent(ip string, event HostEvent) {
Config.defender.AddEvent(ip, event) Config.defender.AddEvent(ip, event)
} }
func startPeriodicChecks(duration time.Duration) { // the ticker cannot be started/stopped from multiple goroutines
startEventScheduler() func startPeriodicTimeoutTicker(duration time.Duration) {
spec := fmt.Sprintf("@every %s", duration) stopPeriodicTimeoutTicker()
_, err := eventScheduler.AddFunc(spec, Connections.checkTransfers) periodicTimeoutTicker = time.NewTicker(duration)
util.PanicOnError(err) periodicTimeoutTickerDone = make(chan bool)
logger.Info(logSender, "", "scheduled overquota transfers check, schedule %q", spec) go func() {
if Config.IdleTimeout > 0 { counter := int64(0)
ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval ratio := idleTimeoutCheckInterval / periodicTimeoutCheckInterval
spec = fmt.Sprintf("@every %s", duration*ratio) for {
_, err = eventScheduler.AddFunc(spec, Connections.checkIdles) select {
util.PanicOnError(err) case <-periodicTimeoutTickerDone:
logger.Info(logSender, "", "scheduled idle connections check, schedule %q", spec) return
case <-periodicTimeoutTicker.C:
counter++
if Config.IdleTimeout > 0 && counter >= int64(ratio) {
counter = 0
Connections.checkIdles()
}
go Connections.checkTransfers()
}
}
}()
}
func stopPeriodicTimeoutTicker() {
if periodicTimeoutTicker != nil {
periodicTimeoutTicker.Stop()
periodicTimeoutTickerDone <- true
periodicTimeoutTicker = nil
} }
} }
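On the newer side the dedicated ticker goroutine is replaced by jobs registered on the event scheduler with a cron-style "@every" spec. Assuming the scheduler is backed by robfig/cron (which understands that spec), a standalone sketch of the same scheduling call:

package main

import (
	"fmt"
	"time"

	"github.com/robfig/cron/v3"
)

func main() {
	// "@every <duration>" runs the job on a fixed interval; this is
	// the same spec string the hunk above builds with fmt.Sprintf.
	scheduler := cron.New()
	spec := fmt.Sprintf("@every %s", 2*time.Second)
	if _, err := scheduler.AddFunc(spec, func() {
		fmt.Println("periodic check at", time.Now().Format(time.RFC3339))
	}); err != nil {
		panic(err)
	}
	scheduler.Start()
	time.Sleep(5 * time.Second)
	scheduler.Stop()
}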
@@ -464,11 +411,11 @@ func (t *ConnectionTransfer) getConnectionTransferAsString() string {
case operationDownload: case operationDownload:
result += "DL " result += "DL "
} }
result += fmt.Sprintf("%q ", t.VirtualPath) result += fmt.Sprintf("%#v ", t.VirtualPath)
if t.Size > 0 { if t.Size > 0 {
elapsed := time.Since(util.GetTimeFromMsecSinceEpoch(t.StartTime)) elapsed := time.Since(util.GetTimeFromMsecSinceEpoch(t.StartTime))
speed := float64(t.Size) / float64(util.GetTimeAsMsSinceEpoch(time.Now())-t.StartTime) speed := float64(t.Size) / float64(util.GetTimeAsMsSinceEpoch(time.Now())-t.StartTime)
result += fmt.Sprintf("Size: %s Elapsed: %s Speed: \"%.1f KB/s\"", util.ByteCountIEC(t.Size), result += fmt.Sprintf("Size: %#v Elapsed: %#v Speed: \"%.1f KB/s\"", util.ByteCountIEC(t.Size),
util.GetDurationAsString(elapsed), speed) util.GetDurationAsString(elapsed), speed)
} }
return result return result
@@ -572,9 +519,6 @@ type Configuration struct {
// Only the listed IPs/networks can access the configured services, all other client connections // Only the listed IPs/networks can access the configured services, all other client connections
// will be dropped before they even try to authenticate. // will be dropped before they even try to authenticate.
WhiteListFile string `json:"whitelist_file" mapstructure:"whitelist_file"` WhiteListFile string `json:"whitelist_file" mapstructure:"whitelist_file"`
// Allow users on this instance to use other users/virtual folders on this instance as storage backend.
// Enable this setting if you know what you are doing.
AllowSelfConnections int `json:"allow_self_connections" mapstructure:"allow_self_connections"`
// Defender configuration // Defender configuration
DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"` DefenderConfig DefenderConfig `json:"defender" mapstructure:"defender"`
// Rate limiter configurations // Rate limiter configurations
@@ -651,11 +595,11 @@ func (c *Configuration) ExecuteStartupHook() error {
return err return err
} }
startTime := time.Now() startTime := time.Now()
timeout, env, args := command.GetConfig(c.StartupHook, command.HookStartup) timeout, env := command.GetConfig(c.StartupHook)
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel() defer cancel()
cmd := exec.CommandContext(ctx, c.StartupHook, args...) cmd := exec.CommandContext(ctx, c.StartupHook)
cmd.Env = env cmd.Env = env
err := cmd.Run() err := cmd.Run()
logger.Debug(logSender, "", "Startup hook executed, elapsed: %v, error: %v", time.Since(startTime), err) logger.Debug(logSender, "", "Startup hook executed, elapsed: %v, error: %v", time.Since(startTime), err)
@@ -663,9 +607,6 @@ func (c *Configuration) ExecuteStartupHook() error {
} }
func (c *Configuration) executePostDisconnectHook(remoteAddr, protocol, username, connID string, connectionTime time.Time) { func (c *Configuration) executePostDisconnectHook(remoteAddr, protocol, username, connID string, connectionTime time.Time) {
startNewHook()
defer hookEnded()
ipAddr := util.GetIPFromRemoteAddress(remoteAddr) ipAddr := util.GetIPFromRemoteAddress(remoteAddr)
connDuration := int64(time.Since(connectionTime) / time.Millisecond) connDuration := int64(time.Since(connectionTime) / time.Millisecond)
@@ -697,12 +638,12 @@ func (c *Configuration) executePostDisconnectHook(remoteAddr, protocol, username
logger.Debug(protocol, connID, "invalid post disconnect hook %#v", c.PostDisconnectHook) logger.Debug(protocol, connID, "invalid post disconnect hook %#v", c.PostDisconnectHook)
return return
} }
timeout, env, args := command.GetConfig(c.PostDisconnectHook, command.HookPostDisconnect) timeout, env := command.GetConfig(c.PostDisconnectHook)
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel() defer cancel()
startTime := time.Now() startTime := time.Now()
cmd := exec.CommandContext(ctx, c.PostDisconnectHook, args...) cmd := exec.CommandContext(ctx, c.PostDisconnectHook)
cmd.Env = append(env, cmd.Env = append(env,
fmt.Sprintf("SFTPGO_CONNECTION_IP=%v", ipAddr), fmt.Sprintf("SFTPGO_CONNECTION_IP=%v", ipAddr),
fmt.Sprintf("SFTPGO_CONNECTION_USERNAME=%v", username), fmt.Sprintf("SFTPGO_CONNECTION_USERNAME=%v", username),
@@ -757,11 +698,11 @@ func (c *Configuration) ExecutePostConnectHook(ipAddr, protocol string) error {
logger.Warn(protocol, "", "Login from ip %#v denied: %v", ipAddr, err) logger.Warn(protocol, "", "Login from ip %#v denied: %v", ipAddr, err)
return err return err
} }
timeout, env, args := command.GetConfig(c.PostConnectHook, command.HookPostConnect) timeout, env := command.GetConfig(c.PostConnectHook)
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel() defer cancel()
cmd := exec.CommandContext(ctx, c.PostConnectHook, args...) cmd := exec.CommandContext(ctx, c.PostConnectHook)
cmd.Env = append(env, cmd.Env = append(env,
fmt.Sprintf("SFTPGO_CONNECTION_IP=%v", ipAddr), fmt.Sprintf("SFTPGO_CONNECTION_IP=%v", ipAddr),
fmt.Sprintf("SFTPGO_CONNECTION_PROTOCOL=%v", protocol)) fmt.Sprintf("SFTPGO_CONNECTION_PROTOCOL=%v", protocol))
@@ -777,17 +718,16 @@ func (c *Configuration) ExecutePostConnectHook(ipAddr, protocol string) error {
type SSHConnection struct { type SSHConnection struct {
id string id string
conn net.Conn conn net.Conn
lastActivity atomic.Int64 lastActivity int64
} }
// NewSSHConnection returns a new SSHConnection // NewSSHConnection returns a new SSHConnection
func NewSSHConnection(id string, conn net.Conn) *SSHConnection { func NewSSHConnection(id string, conn net.Conn) *SSHConnection {
c := &SSHConnection{ return &SSHConnection{
id: id, id: id,
conn: conn, conn: conn,
lastActivity: time.Now().UnixNano(),
} }
c.lastActivity.Store(time.Now().UnixNano())
return c
} }
// GetID returns the ID for this SSHConnection // GetID returns the ID for this SSHConnection
@@ -797,12 +737,12 @@ func (c *SSHConnection) GetID() string {
// UpdateLastActivity updates last activity for this connection // UpdateLastActivity updates last activity for this connection
func (c *SSHConnection) UpdateLastActivity() { func (c *SSHConnection) UpdateLastActivity() {
c.lastActivity.Store(time.Now().UnixNano()) atomic.StoreInt64(&c.lastActivity, time.Now().UnixNano())
} }
// GetLastActivity returns the last connection activity // GetLastActivity returns the last connection activity
func (c *SSHConnection) GetLastActivity() time.Time { func (c *SSHConnection) GetLastActivity() time.Time {
return time.Unix(0, c.lastActivity.Load()) return time.Unix(0, atomic.LoadInt64(&c.lastActivity))
} }
// Close closes the underlying network connection // Close closes the underlying network connection
@@ -815,12 +755,10 @@ type ActiveConnections struct {
// clients contains both authenticated and established connections and the ones waiting // clients contains both authenticated and established connections and the ones waiting
// for authentication // for authentication
clients clientsMap clients clientsMap
transfersCheckStatus atomic.Bool transfersCheckStatus int32
sync.RWMutex sync.RWMutex
connections []ActiveConnection connections []ActiveConnection
mapping map[string]int
sshConnections []*SSHConnection sshConnections []*SSHConnection
sshMapping map[string]int
perUserConns map[string]int perUserConns map[string]int
} }
@@ -868,10 +806,9 @@ func (conns *ActiveConnections) Add(c ActiveConnection) error {
} }
conns.addUserConnection(username) conns.addUserConnection(username)
} }
conns.mapping[c.GetID()] = len(conns.connections)
conns.connections = append(conns.connections, c) conns.connections = append(conns.connections, c)
metric.UpdateActiveConnectionsSize(len(conns.connections)) metric.UpdateActiveConnectionsSize(len(conns.connections))
logger.Debug(c.GetProtocol(), c.GetID(), "connection added, local address %q, remote address %q, num open connections: %d", logger.Debug(c.GetProtocol(), c.GetID(), "connection added, local address %#v, remote address %#v, num open connections: %v",
c.GetLocalAddress(), c.GetRemoteAddress(), len(conns.connections)) c.GetLocalAddress(), c.GetRemoteAddress(), len(conns.connections))
return nil return nil
} }
@@ -884,25 +821,25 @@ func (conns *ActiveConnections) Swap(c ActiveConnection) error {
conns.Lock() conns.Lock()
defer conns.Unlock() defer conns.Unlock()
if idx, ok := conns.mapping[c.GetID()]; ok { for idx, conn := range conns.connections {
conn := conns.connections[idx] if conn.GetID() == c.GetID() {
conns.removeUserConnection(conn.GetUsername()) conns.removeUserConnection(conn.GetUsername())
if username := c.GetUsername(); username != "" { if username := c.GetUsername(); username != "" {
if maxSessions := c.GetMaxSessions(); maxSessions > 0 { if maxSessions := c.GetMaxSessions(); maxSessions > 0 {
if val, ok := conns.perUserConns[username]; ok && val >= maxSessions { if val := conns.perUserConns[username]; val >= maxSessions {
conns.addUserConnection(conn.GetUsername()) conns.addUserConnection(conn.GetUsername())
return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions) return fmt.Errorf("too many open sessions: %d/%d", val, maxSessions)
}
} }
conns.addUserConnection(username)
} }
conns.addUserConnection(username) err := conn.CloseFS()
conns.connections[idx] = c
logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err)
conn = nil
return nil
} }
err := conn.CloseFS()
conns.connections[idx] = c
logger.Debug(logSender, c.GetID(), "connection swapped, close fs error: %v", err)
conn = nil
return nil
} }
return errors.New("connection to swap not found") return errors.New("connection to swap not found")
} }
@@ -911,53 +848,40 @@ func (conns *ActiveConnections) Remove(connectionID string) {
conns.Lock() conns.Lock()
defer conns.Unlock() defer conns.Unlock()
if idx, ok := conns.mapping[connectionID]; ok { for idx, conn := range conns.connections {
conn := conns.connections[idx] if conn.GetID() == connectionID {
err := conn.CloseFS() err := conn.CloseFS()
lastIdx := len(conns.connections) - 1 lastIdx := len(conns.connections) - 1
conns.connections[idx] = conns.connections[lastIdx] conns.connections[idx] = conns.connections[lastIdx]
conns.connections[lastIdx] = nil conns.connections[lastIdx] = nil
conns.connections = conns.connections[:lastIdx] conns.connections = conns.connections[:lastIdx]
delete(conns.mapping, connectionID) conns.removeUserConnection(conn.GetUsername())
if idx != lastIdx { metric.UpdateActiveConnectionsSize(lastIdx)
conns.mapping[conns.connections[idx].GetID()] = idx logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %#v, remote address %#v close fs error: %v, num open connections: %v",
conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx)
Config.checkPostDisconnectHook(conn.GetRemoteAddress(), conn.GetProtocol(), conn.GetUsername(),
conn.GetID(), conn.GetConnectionTime())
return
} }
conns.removeUserConnection(conn.GetUsername())
metric.UpdateActiveConnectionsSize(lastIdx)
logger.Debug(conn.GetProtocol(), conn.GetID(), "connection removed, local address %#v, remote address %#v close fs error: %v, num open connections: %v",
conn.GetLocalAddress(), conn.GetRemoteAddress(), err, lastIdx)
if conn.GetProtocol() == ProtocolFTP && conn.GetUsername() == "" {
ip := util.GetIPFromRemoteAddress(conn.GetRemoteAddress())
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, conn.GetProtocol(),
dataprovider.ErrNoAuthTryed.Error())
metric.AddNoAuthTryed()
AddDefenderEvent(ip, HostEventNoLoginTried)
dataprovider.ExecutePostLoginHook(&dataprovider.User{}, dataprovider.LoginMethodNoAuthTryed, ip,
conn.GetProtocol(), dataprovider.ErrNoAuthTryed)
}
Config.checkPostDisconnectHook(conn.GetRemoteAddress(), conn.GetProtocol(), conn.GetUsername(),
conn.GetID(), conn.GetConnectionTime())
return
} }
logger.Warn(logSender, "", "connection id %#v to remove not found!", connectionID)
logger.Warn(logSender, "", "connection id %q to remove not found!", connectionID)
} }
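On the newer side Add, Swap and Remove keep a mapping map[string]int next to the connections slice, so lookup is O(1) and removal swaps the victim with the last element, shrinks the slice, and fixes the moved element's index, instead of scanning linearly. A minimal, self-contained sketch of that pattern with plain strings standing in for connections:

package main

import "fmt"

// indexedSlice pairs an id->index map with a slice: O(1) lookup,
// O(1) removal via swap-with-last plus an index fix-up, the same
// bookkeeping done for connections/mapping above.
type indexedSlice struct {
	items   []string
	mapping map[string]int
}

func newIndexedSlice() *indexedSlice {
	return &indexedSlice{mapping: make(map[string]int)}
}

func (s *indexedSlice) add(id string) {
	s.mapping[id] = len(s.items)
	s.items = append(s.items, id)
}

func (s *indexedSlice) remove(id string) bool {
	idx, ok := s.mapping[id]
	if !ok {
		return false
	}
	lastIdx := len(s.items) - 1
	s.items[idx] = s.items[lastIdx]
	s.items = s.items[:lastIdx]
	delete(s.mapping, id)
	if idx != lastIdx {
		// the element we moved into idx needs its index updated
		s.mapping[s.items[idx]] = idx
	}
	return true
}

func main() {
	s := newIndexedSlice()
	s.add("a")
	s.add("b")
	s.add("c")
	s.remove("a")
	fmt.Println(s.items, s.mapping) // [c b] map[b:1 c:0]
}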
// Close closes an active connection. // Close closes an active connection.
// It returns true on success // It returns true on success
func (conns *ActiveConnections) Close(connectionID string) bool { func (conns *ActiveConnections) Close(connectionID string) bool {
conns.RLock() conns.RLock()
result := false
var result bool for _, c := range conns.connections {
if c.GetID() == connectionID {
if idx, ok := conns.mapping[connectionID]; ok { defer func(conn ActiveConnection) {
c := conns.connections[idx] err := conn.Disconnect()
logger.Debug(conn.GetProtocol(), conn.GetID(), "close connection requested, close err: %v", err)
defer func(conn ActiveConnection) { }(c)
err := conn.Disconnect() result = true
logger.Debug(conn.GetProtocol(), conn.GetID(), "close connection requested, close err: %v", err) break
}(c) }
result = true
} }
conns.RUnlock() conns.RUnlock()
@@ -969,9 +893,8 @@ func (conns *ActiveConnections) AddSSHConnection(c *SSHConnection) {
conns.Lock() conns.Lock()
defer conns.Unlock() defer conns.Unlock()
conns.sshMapping[c.GetID()] = len(conns.sshConnections)
conns.sshConnections = append(conns.sshConnections, c) conns.sshConnections = append(conns.sshConnections, c)
logger.Debug(logSender, c.GetID(), "ssh connection added, num open connections: %d", len(conns.sshConnections)) logger.Debug(logSender, c.GetID(), "ssh connection added, num open connections: %v", len(conns.sshConnections))
} }
// RemoveSSHConnection removes a connection from the active ones // RemoveSSHConnection removes a connection from the active ones
@@ -979,19 +902,17 @@ func (conns *ActiveConnections) RemoveSSHConnection(connectionID string) {
conns.Lock() conns.Lock()
defer conns.Unlock() defer conns.Unlock()
if idx, ok := conns.sshMapping[connectionID]; ok { for idx, conn := range conns.sshConnections {
lastIdx := len(conns.sshConnections) - 1 if conn.GetID() == connectionID {
conns.sshConnections[idx] = conns.sshConnections[lastIdx] lastIdx := len(conns.sshConnections) - 1
conns.sshConnections[lastIdx] = nil conns.sshConnections[idx] = conns.sshConnections[lastIdx]
conns.sshConnections = conns.sshConnections[:lastIdx] conns.sshConnections[lastIdx] = nil
delete(conns.sshMapping, connectionID) conns.sshConnections = conns.sshConnections[:lastIdx]
if idx != lastIdx { logger.Debug(logSender, conn.GetID(), "ssh connection removed, num open ssh connections: %v", lastIdx)
conns.sshMapping[conns.sshConnections[idx].GetID()] = idx return
} }
logger.Debug(logSender, connectionID, "ssh connection removed, num open ssh connections: %d", lastIdx)
return
} }
logger.Warn(logSender, "", "ssh connection to remove with id %q not found!", connectionID) logger.Warn(logSender, "", "ssh connection to remove with id %#v not found!", connectionID)
} }
func (conns *ActiveConnections) checkIdles() { func (conns *ActiveConnections) checkIdles() {
@@ -1026,11 +947,19 @@ func (conns *ActiveConnections) checkIdles() {
isUnauthenticatedFTPUser := (c.GetProtocol() == ProtocolFTP && c.GetUsername() == "") isUnauthenticatedFTPUser := (c.GetProtocol() == ProtocolFTP && c.GetUsername() == "")
if idleTime > Config.idleTimeoutAsDuration || (isUnauthenticatedFTPUser && idleTime > Config.idleLoginTimeout) { if idleTime > Config.idleTimeoutAsDuration || (isUnauthenticatedFTPUser && idleTime > Config.idleLoginTimeout) {
defer func(conn ActiveConnection) { defer func(conn ActiveConnection, isFTPNoAuth bool) {
err := conn.Disconnect() err := conn.Disconnect()
logger.Debug(conn.GetProtocol(), conn.GetID(), "close idle connection, idle time: %v, username: %#v close err: %v", logger.Debug(conn.GetProtocol(), conn.GetID(), "close idle connection, idle time: %v, username: %#v close err: %v",
time.Since(conn.GetLastActivity()), conn.GetUsername(), err) time.Since(conn.GetLastActivity()), conn.GetUsername(), err)
}(c) if isFTPNoAuth {
ip := util.GetIPFromRemoteAddress(c.GetRemoteAddress())
logger.ConnectionFailedLog("", ip, dataprovider.LoginMethodNoAuthTryed, c.GetProtocol(), "client idle")
metric.AddNoAuthTryed()
AddDefenderEvent(ip, HostEventNoLoginTried)
dataprovider.ExecutePostLoginHook(&dataprovider.User{}, dataprovider.LoginMethodNoAuthTryed, ip, c.GetProtocol(),
dataprovider.ErrNoAuthTryed)
}
}(c, isUnauthenticatedFTPUser)
} }
} }
@@ -1038,12 +967,12 @@ func (conns *ActiveConnections) checkIdles() {
} }
func (conns *ActiveConnections) checkTransfers() { func (conns *ActiveConnections) checkTransfers() {
if conns.transfersCheckStatus.Load() { if atomic.LoadInt32(&conns.transfersCheckStatus) == 1 {
logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution") logger.Warn(logSender, "", "the previous transfer check is still running, skipping execution")
return return
} }
conns.transfersCheckStatus.Store(true) atomic.StoreInt32(&conns.transfersCheckStatus, 1)
defer conns.transfersCheckStatus.Store(false) defer atomic.StoreInt32(&conns.transfersCheckStatus, 0)
conns.RLock() conns.RLock()
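checkTransfers is guarded so that a run still in flight causes the next tick to be skipped rather than stacked. A minimal sketch of that non-reentrant guard; it uses CompareAndSwap to make the check-and-set a single atomic step, a slightly stricter variant of the load-then-store pair above:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

type guardedJob struct {
	running atomic.Bool
}

// run executes work unless a previous run is still active.
func (j *guardedJob) run(work func()) {
	if !j.running.CompareAndSwap(false, true) {
		fmt.Println("previous check still running, skipping execution")
		return
	}
	defer j.running.Store(false)
	work()
}

func main() {
	var j guardedJob
	go j.run(func() { time.Sleep(100 * time.Millisecond) })
	time.Sleep(10 * time.Millisecond)
	j.run(func() {}) // skipped: the first run is still active
	time.Sleep(200 * time.Millisecond)
}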
@@ -1115,34 +1044,30 @@ func (conns *ActiveConnections) GetClientConnections() int32 {
return conns.clients.getTotal() return conns.clients.getTotal()
} }
// IsNewConnectionAllowed returns an error if the maximum number of concurrent allowed // IsNewConnectionAllowed returns false if the maximum number of concurrent allowed connections is exceeded
// connections is exceeded or a whitelist is defined and the specified ipAddr is not listed // or a whitelist is defined and the specified ipAddr is not listed
// or the service is shutting down func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) bool {
func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) error {
if isShuttingDown.Load() {
return ErrShuttingDown
}
if Config.whitelist != nil { if Config.whitelist != nil {
if !Config.whitelist.isAllowed(ipAddr) { if !Config.whitelist.isAllowed(ipAddr) {
return ErrConnectionDenied return false
} }
} }
if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 { if Config.MaxTotalConnections == 0 && Config.MaxPerHostConnections == 0 {
return nil return true
} }
if Config.MaxPerHostConnections > 0 { if Config.MaxPerHostConnections > 0 {
if total := conns.clients.getTotalFrom(ipAddr); total > Config.MaxPerHostConnections { if total := conns.clients.getTotalFrom(ipAddr); total > Config.MaxPerHostConnections {
logger.Info(logSender, "", "active connections from %s %d/%d", ipAddr, total, Config.MaxPerHostConnections) logger.Debug(logSender, "", "active connections from %v %v/%v", ipAddr, total, Config.MaxPerHostConnections)
AddDefenderEvent(ipAddr, HostEventLimitExceeded) AddDefenderEvent(ipAddr, HostEventLimitExceeded)
return ErrConnectionDenied return false
} }
} }
if Config.MaxTotalConnections > 0 { if Config.MaxTotalConnections > 0 {
if total := conns.clients.getTotal(); total > int32(Config.MaxTotalConnections) { if total := conns.clients.getTotal(); total > int32(Config.MaxTotalConnections) {
logger.Info(logSender, "", "active client connections %d/%d", total, Config.MaxTotalConnections) logger.Debug(logSender, "", "active client connections %v/%v", total, Config.MaxTotalConnections)
return ErrConnectionDenied return false
} }
// on a single SFTP connection we could have multiple SFTP channels or commands // on a single SFTP connection we could have multiple SFTP channels or commands
@@ -1151,13 +1076,10 @@ func (conns *ActiveConnections) IsNewConnectionAllowed(ipAddr string) error {
conns.RLock() conns.RLock()
defer conns.RUnlock() defer conns.RUnlock()
if sess := len(conns.connections); sess >= Config.MaxTotalConnections { return len(conns.connections) < Config.MaxTotalConnections
logger.Info(logSender, "", "active client sessions %d/%d", sess, Config.MaxTotalConnections)
return ErrConnectionDenied
}
} }
return nil return true
} }
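Returning an error instead of a bool lets callers distinguish why a connection was refused. A hypothetical accept-path check, assumed to live in this package (handleIncoming is illustrative, not an SFTPGo function; errors is the standard library package):

// handleIncoming sketches how a listener could act on the specific
// refusal reason instead of a bare false.
func handleIncoming(ipAddr string) error {
	if err := Connections.IsNewConnectionAllowed(ipAddr); err != nil {
		switch {
		case errors.Is(err, ErrShuttingDown):
			// the service is draining connections: refuse quietly
		case errors.Is(err, ErrConnectionDenied):
			// limit exceeded or IP not in the whitelist
		}
		return err
	}
	// proceed with the protocol handshake and authentication
	return nil
}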
// GetStats returns stats for active connections // GetStats returns stats for active connections
@@ -1166,7 +1088,6 @@ func (conns *ActiveConnections) GetStats() []ConnectionStatus {
defer conns.RUnlock() defer conns.RUnlock()
stats := make([]ConnectionStatus, 0, len(conns.connections)) stats := make([]ConnectionStatus, 0, len(conns.connections))
node := dataprovider.GetNodeName()
for _, c := range conns.connections { for _, c := range conns.connections {
stat := ConnectionStatus{ stat := ConnectionStatus{
Username: c.GetUsername(), Username: c.GetUsername(),
@@ -1178,7 +1099,6 @@ func (conns *ActiveConnections) GetStats() []ConnectionStatus {
Protocol: c.GetProtocol(), Protocol: c.GetProtocol(),
Command: c.GetCommand(), Command: c.GetCommand(),
Transfers: c.GetTransfers(), Transfers: c.GetTransfers(),
Node: node,
} }
stats = append(stats, stat) stats = append(stats, stat)
} }
@@ -1205,8 +1125,6 @@ type ConnectionStatus struct {
Transfers []ConnectionTransfer `json:"active_transfers,omitempty"` Transfers []ConnectionTransfer `json:"active_transfers,omitempty"`
// SSH command or WebDAV method // SSH command or WebDAV method
Command string `json:"command,omitempty"` Command string `json:"command,omitempty"`
// Node identifier, omitted for single node installations
Node string `json:"node,omitempty"`
} }
// GetConnectionDuration returns the connection duration as string // GetConnectionDuration returns the connection duration as string
@@ -1248,7 +1166,7 @@ func (c *ConnectionStatus) GetTransfersAsString() string {
return result return result
} }
// ActiveQuotaScan defines an active quota scan for a user // ActiveQuotaScan defines an active quota scan for a user home dir
type ActiveQuotaScan struct { type ActiveQuotaScan struct {
// Username to which the quota scan refers // Username to which the quota scan refers
Username string `json:"username"` Username string `json:"username"`
@@ -1271,7 +1189,7 @@ type ActiveScans struct {
FolderScans []ActiveVirtualFolderQuotaScan FolderScans []ActiveVirtualFolderQuotaScan
} }
// GetUsersQuotaScans returns the active users quota scans // GetUsersQuotaScans returns the active quota scans for users home directories
func (s *ActiveScans) GetUsersQuotaScans() []ActiveQuotaScan { func (s *ActiveScans) GetUsersQuotaScans() []ActiveQuotaScan {
s.RLock() s.RLock()
defer s.RUnlock() defer s.RUnlock()
@@ -1361,65 +1279,3 @@ func (s *ActiveScans) RemoveVFolderQuotaScan(folderName string) bool {
return false return false
} }
// MetadataCheck defines an active metadata check
type MetadataCheck struct {
// Username to which the metadata check refers
Username string `json:"username"`
// check start time as unix timestamp in milliseconds
StartTime int64 `json:"start_time"`
}
// MetadataChecks holds the active metadata checks
type MetadataChecks struct {
sync.RWMutex
checks []MetadataCheck
}
// Get returns the active metadata checks
func (c *MetadataChecks) Get() []MetadataCheck {
c.RLock()
defer c.RUnlock()
checks := make([]MetadataCheck, len(c.checks))
copy(checks, c.checks)
return checks
}
// Add adds a user to the ones with active metadata checks.
// Return false if a metadata check is already active for the specified user
func (c *MetadataChecks) Add(username string) bool {
c.Lock()
defer c.Unlock()
for idx := range c.checks {
if c.checks[idx].Username == username {
return false
}
}
c.checks = append(c.checks, MetadataCheck{
Username: username,
StartTime: util.GetTimeAsMsSinceEpoch(time.Now()),
})
return true
}
// Remove removes a user from the ones with active metadata checks
func (c *MetadataChecks) Remove(username string) bool {
c.Lock()
defer c.Unlock()
for idx := range c.checks {
if c.checks[idx].Username == username {
lastIdx := len(c.checks) - 1
c.checks[idx] = c.checks[lastIdx]
c.checks = c.checks[:lastIdx]
return true
}
}
return false
}
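Note that Get copies the checks slice before returning it, so callers can mutate the result without corrupting the shared state; the TestMetadataAPI test later in this diff relies on exactly that by editing the returned slice and verifying the stored username is unchanged.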


@@ -24,7 +24,7 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
"sync" "sync/atomic"
"testing" "testing"
"time" "time"
@@ -34,24 +34,21 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/kms"
"github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/internal/vfs" "github.com/drakkan/sftpgo/v2/vfs"
) )
const ( const (
logSenderTest = "common_test" logSenderTest = "common_test"
httpAddr = "127.0.0.1:9999" httpAddr = "127.0.0.1:9999"
configDir = ".."
osWindows = "windows" osWindows = "windows"
userTestUsername = "common_test_username" userTestUsername = "common_test_username"
) )
var (
configDir = filepath.Join(".", "..", "..")
)
type fakeConnection struct { type fakeConnection struct {
*BaseConnection *BaseConnection
command string command string
@@ -99,122 +96,6 @@ func (c *customNetConn) Close() error {
return c.Conn.Close() return c.Conn.Close()
} }
func TestConnections(t *testing.T) {
c1 := &fakeConnection{
BaseConnection: NewBaseConnection("id1", ProtocolSFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: userTestUsername,
},
}),
}
c2 := &fakeConnection{
BaseConnection: NewBaseConnection("id2", ProtocolSFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: userTestUsername,
},
}),
}
c3 := &fakeConnection{
BaseConnection: NewBaseConnection("id3", ProtocolSFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: userTestUsername,
},
}),
}
c4 := &fakeConnection{
BaseConnection: NewBaseConnection("id4", ProtocolSFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: userTestUsername,
},
}),
}
assert.Equal(t, "SFTP_id1", c1.GetID())
assert.Equal(t, "SFTP_id2", c2.GetID())
assert.Equal(t, "SFTP_id3", c3.GetID())
assert.Equal(t, "SFTP_id4", c4.GetID())
err := Connections.Add(c1)
assert.NoError(t, err)
err = Connections.Add(c2)
assert.NoError(t, err)
err = Connections.Add(c3)
assert.NoError(t, err)
err = Connections.Add(c4)
assert.NoError(t, err)
Connections.RLock()
assert.Len(t, Connections.connections, 4)
assert.Len(t, Connections.mapping, 4)
_, ok := Connections.mapping[c1.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.mapping[c1.GetID()])
assert.Equal(t, 1, Connections.mapping[c2.GetID()])
assert.Equal(t, 2, Connections.mapping[c3.GetID()])
assert.Equal(t, 3, Connections.mapping[c4.GetID()])
Connections.RUnlock()
c2 = &fakeConnection{
BaseConnection: NewBaseConnection("id2", ProtocolSFTP, "", "", dataprovider.User{
BaseUser: sdk.BaseUser{
Username: userTestUsername + "_mod",
},
}),
}
err = Connections.Swap(c2)
assert.NoError(t, err)
Connections.RLock()
assert.Len(t, Connections.connections, 4)
assert.Len(t, Connections.mapping, 4)
_, ok = Connections.mapping[c1.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.mapping[c1.GetID()])
assert.Equal(t, 1, Connections.mapping[c2.GetID()])
assert.Equal(t, 2, Connections.mapping[c3.GetID()])
assert.Equal(t, 3, Connections.mapping[c4.GetID()])
assert.Equal(t, userTestUsername+"_mod", Connections.connections[1].GetUsername())
Connections.RUnlock()
Connections.Remove(c2.GetID())
Connections.RLock()
assert.Len(t, Connections.connections, 3)
assert.Len(t, Connections.mapping, 3)
_, ok = Connections.mapping[c1.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.mapping[c1.GetID()])
assert.Equal(t, 1, Connections.mapping[c4.GetID()])
assert.Equal(t, 2, Connections.mapping[c3.GetID()])
Connections.RUnlock()
Connections.Remove(c3.GetID())
Connections.RLock()
assert.Len(t, Connections.connections, 2)
assert.Len(t, Connections.mapping, 2)
_, ok = Connections.mapping[c1.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.mapping[c1.GetID()])
assert.Equal(t, 1, Connections.mapping[c4.GetID()])
Connections.RUnlock()
Connections.Remove(c1.GetID())
Connections.RLock()
assert.Len(t, Connections.connections, 1)
assert.Len(t, Connections.mapping, 1)
_, ok = Connections.mapping[c4.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.mapping[c4.GetID()])
Connections.RUnlock()
Connections.Remove(c4.GetID())
Connections.RLock()
assert.Len(t, Connections.connections, 0)
assert.Len(t, Connections.mapping, 0)
Connections.RUnlock()
}
func TestSSHConnections(t *testing.T) { func TestSSHConnections(t *testing.T) {
conn1, conn2 := net.Pipe() conn1, conn2 := net.Pipe()
now := time.Now() now := time.Now()
@@ -231,44 +112,27 @@ func TestSSHConnections(t *testing.T) {
Connections.AddSSHConnection(sshConn3) Connections.AddSSHConnection(sshConn3)
Connections.RLock() Connections.RLock()
assert.Len(t, Connections.sshConnections, 3) assert.Len(t, Connections.sshConnections, 3)
_, ok := Connections.sshMapping[sshConn1.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.sshMapping[sshConn1.GetID()])
assert.Equal(t, 1, Connections.sshMapping[sshConn2.GetID()])
assert.Equal(t, 2, Connections.sshMapping[sshConn3.GetID()])
Connections.RUnlock() Connections.RUnlock()
Connections.RemoveSSHConnection(sshConn1.id) Connections.RemoveSSHConnection(sshConn1.id)
Connections.RLock() Connections.RLock()
assert.Len(t, Connections.sshConnections, 2) assert.Len(t, Connections.sshConnections, 2)
assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id) assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id)
assert.Equal(t, sshConn2.id, Connections.sshConnections[1].id) assert.Equal(t, sshConn2.id, Connections.sshConnections[1].id)
_, ok = Connections.sshMapping[sshConn3.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.sshMapping[sshConn3.GetID()])
assert.Equal(t, 1, Connections.sshMapping[sshConn2.GetID()])
Connections.RUnlock() Connections.RUnlock()
Connections.RemoveSSHConnection(sshConn1.id) Connections.RemoveSSHConnection(sshConn1.id)
Connections.RLock() Connections.RLock()
assert.Len(t, Connections.sshConnections, 2) assert.Len(t, Connections.sshConnections, 2)
assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id) assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id)
assert.Equal(t, sshConn2.id, Connections.sshConnections[1].id) assert.Equal(t, sshConn2.id, Connections.sshConnections[1].id)
_, ok = Connections.sshMapping[sshConn3.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.sshMapping[sshConn3.GetID()])
assert.Equal(t, 1, Connections.sshMapping[sshConn2.GetID()])
Connections.RUnlock() Connections.RUnlock()
Connections.RemoveSSHConnection(sshConn2.id) Connections.RemoveSSHConnection(sshConn2.id)
Connections.RLock() Connections.RLock()
assert.Len(t, Connections.sshConnections, 1) assert.Len(t, Connections.sshConnections, 1)
assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id) assert.Equal(t, sshConn3.id, Connections.sshConnections[0].id)
_, ok = Connections.sshMapping[sshConn3.GetID()]
assert.True(t, ok)
assert.Equal(t, 0, Connections.sshMapping[sshConn3.GetID()])
Connections.RUnlock() Connections.RUnlock()
Connections.RemoveSSHConnection(sshConn3.id) Connections.RemoveSSHConnection(sshConn3.id)
Connections.RLock() Connections.RLock()
assert.Len(t, Connections.sshConnections, 0) assert.Len(t, Connections.sshConnections, 0)
assert.Len(t, Connections.sshMapping, 0)
Connections.RUnlock() Connections.RUnlock()
assert.NoError(t, sshConn1.Close()) assert.NoError(t, sshConn1.Close())
assert.NoError(t, sshConn2.Close()) assert.NoError(t, sshConn2.Close())
@@ -284,14 +148,14 @@ func TestDefenderIntegration(t *testing.T) {
pluginsConfig := []plugin.Config{ pluginsConfig := []plugin.Config{
{ {
Type: "ipfilter", Type: "ipfilter",
Cmd: filepath.Join(wdPath, "..", "..", "tests", "ipfilter", "ipfilter"), Cmd: filepath.Join(wdPath, "..", "tests", "ipfilter", "ipfilter"),
AutoMTLS: true, AutoMTLS: true,
}, },
} }
if runtime.GOOS == osWindows { if runtime.GOOS == osWindows {
pluginsConfig[0].Cmd += ".exe" pluginsConfig[0].Cmd += ".exe"
} }
err = plugin.Initialize(pluginsConfig, "debug") err = plugin.Initialize(pluginsConfig, true)
require.NoError(t, err) require.NoError(t, err)
ip := "127.1.1.1" ip := "127.1.1.1"
@@ -497,10 +361,10 @@ func TestWhitelist(t *testing.T) {
err = Initialize(Config, 0) err = Initialize(Config, 0)
assert.NoError(t, err) assert.NoError(t, err)
assert.NoError(t, Connections.IsNewConnectionAllowed("172.18.1.1")) assert.True(t, Connections.IsNewConnectionAllowed("172.18.1.1"))
assert.Error(t, Connections.IsNewConnectionAllowed("172.18.1.3")) assert.False(t, Connections.IsNewConnectionAllowed("172.18.1.3"))
assert.NoError(t, Connections.IsNewConnectionAllowed("10.8.7.3")) assert.True(t, Connections.IsNewConnectionAllowed("10.8.7.3"))
assert.Error(t, Connections.IsNewConnectionAllowed("10.8.8.2")) assert.False(t, Connections.IsNewConnectionAllowed("10.8.8.2"))
wl.IPAddresses = append(wl.IPAddresses, "172.18.1.3") wl.IPAddresses = append(wl.IPAddresses, "172.18.1.3")
wl.CIDRNetworks = append(wl.CIDRNetworks, "10.8.8.0/24") wl.CIDRNetworks = append(wl.CIDRNetworks, "10.8.8.0/24")
@@ -508,14 +372,14 @@ func TestWhitelist(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
err = os.WriteFile(wlFile, data, 0664) err = os.WriteFile(wlFile, data, 0664)
assert.NoError(t, err) assert.NoError(t, err)
assert.Error(t, Connections.IsNewConnectionAllowed("10.8.8.3")) assert.False(t, Connections.IsNewConnectionAllowed("10.8.8.3"))
err = Reload() err = Reload()
assert.NoError(t, err) assert.NoError(t, err)
assert.NoError(t, Connections.IsNewConnectionAllowed("10.8.8.3")) assert.True(t, Connections.IsNewConnectionAllowed("10.8.8.3"))
assert.NoError(t, Connections.IsNewConnectionAllowed("172.18.1.3")) assert.True(t, Connections.IsNewConnectionAllowed("172.18.1.3"))
assert.NoError(t, Connections.IsNewConnectionAllowed("172.18.1.2")) assert.True(t, Connections.IsNewConnectionAllowed("172.18.1.2"))
assert.Error(t, Connections.IsNewConnectionAllowed("172.18.1.12")) assert.False(t, Connections.IsNewConnectionAllowed("172.18.1.12"))
Config = configCopy Config = configCopy
} }
@@ -550,12 +414,12 @@ func TestMaxConnections(t *testing.T) {
Config.MaxPerHostConnections = 0 Config.MaxPerHostConnections = 0
ipAddr := "192.168.7.8" ipAddr := "192.168.7.8"
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr)) assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
Config.MaxTotalConnections = 1 Config.MaxTotalConnections = 1
Config.MaxPerHostConnections = perHost Config.MaxPerHostConnections = perHost
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr)) assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{}) c := NewBaseConnection("id", ProtocolSFTP, "", "", dataprovider.User{})
fakeConn := &fakeConnection{ fakeConn := &fakeConnection{
BaseConnection: c, BaseConnection: c,
@@ -563,18 +427,18 @@ func TestMaxConnections(t *testing.T) {
err := Connections.Add(fakeConn) err := Connections.Add(fakeConn)
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, Connections.GetStats(), 1) assert.Len(t, Connections.GetStats(), 1)
assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr)) assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
res := Connections.Close(fakeConn.GetID()) res := Connections.Close(fakeConn.GetID())
assert.True(t, res) assert.True(t, res)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond) assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 300*time.Millisecond, 50*time.Millisecond)
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr)) assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.AddClientConnection(ipAddr) Connections.AddClientConnection(ipAddr)
Connections.AddClientConnection(ipAddr) Connections.AddClientConnection(ipAddr)
assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr)) assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.RemoveClientConnection(ipAddr) Connections.RemoveClientConnection(ipAddr)
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr)) assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.RemoveClientConnection(ipAddr) Connections.RemoveClientConnection(ipAddr)
Config.MaxTotalConnections = oldValue Config.MaxTotalConnections = oldValue
@@ -587,13 +451,13 @@ func TestMaxConnectionPerHost(t *testing.T) {
ipAddr := "192.168.9.9" ipAddr := "192.168.9.9"
Connections.AddClientConnection(ipAddr) Connections.AddClientConnection(ipAddr)
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr)) assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.AddClientConnection(ipAddr) Connections.AddClientConnection(ipAddr)
assert.NoError(t, Connections.IsNewConnectionAllowed(ipAddr)) assert.True(t, Connections.IsNewConnectionAllowed(ipAddr))
Connections.AddClientConnection(ipAddr) Connections.AddClientConnection(ipAddr)
assert.Error(t, Connections.IsNewConnectionAllowed(ipAddr)) assert.False(t, Connections.IsNewConnectionAllowed(ipAddr))
assert.Equal(t, int32(3), Connections.GetClientConnections()) assert.Equal(t, int32(3), Connections.GetClientConnections())
Connections.RemoveClientConnection(ipAddr) Connections.RemoveClientConnection(ipAddr)
@@ -631,14 +495,14 @@ func TestIdleConnections(t *testing.T) {
}, },
} }
c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, "", "", user) c := NewBaseConnection(sshConn1.id+"_1", ProtocolSFTP, "", "", user)
c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano()) c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
fakeConn := &fakeConnection{ fakeConn := &fakeConnection{
BaseConnection: c, BaseConnection: c,
} }
// both ssh connections are expired but they should get removed only // both ssh connections are expired but they should get removed only
// if there is no associated connection // if there is no associated connection
sshConn1.lastActivity.Store(c.lastActivity.Load()) sshConn1.lastActivity = c.lastActivity
sshConn2.lastActivity.Store(c.lastActivity.Load()) sshConn2.lastActivity = c.lastActivity
Connections.AddSSHConnection(sshConn1) Connections.AddSSHConnection(sshConn1)
err = Connections.Add(fakeConn) err = Connections.Add(fakeConn)
assert.NoError(t, err) assert.NoError(t, err)
@@ -653,7 +517,7 @@ func TestIdleConnections(t *testing.T) {
assert.Equal(t, Connections.GetActiveSessions(username), 2) assert.Equal(t, Connections.GetActiveSessions(username), 2)
cFTP := NewBaseConnection("id2", ProtocolFTP, "", "", dataprovider.User{}) cFTP := NewBaseConnection("id2", ProtocolFTP, "", "", dataprovider.User{})
cFTP.lastActivity.Store(time.Now().UnixNano()) cFTP.lastActivity = time.Now().UnixNano()
fakeConn = &fakeConnection{ fakeConn = &fakeConnection{
BaseConnection: cFTP, BaseConnection: cFTP,
} }
@@ -665,27 +529,27 @@ func TestIdleConnections(t *testing.T) {
assert.Len(t, Connections.sshConnections, 2) assert.Len(t, Connections.sshConnections, 2)
Connections.RUnlock() Connections.RUnlock()
startPeriodicChecks(100 * time.Millisecond) startPeriodicTimeoutTicker(100 * time.Millisecond)
assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 2*time.Second, 200*time.Millisecond) assert.Eventually(t, func() bool { return Connections.GetActiveSessions(username) == 1 }, 1*time.Second, 200*time.Millisecond)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
Connections.RLock() Connections.RLock()
defer Connections.RUnlock() defer Connections.RUnlock()
return len(Connections.sshConnections) == 1 return len(Connections.sshConnections) == 1
}, 1*time.Second, 200*time.Millisecond) }, 1*time.Second, 200*time.Millisecond)
stopEventScheduler() stopPeriodicTimeoutTicker()
assert.Len(t, Connections.GetStats(), 2) assert.Len(t, Connections.GetStats(), 2)
c.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano()) c.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
cFTP.lastActivity.Store(time.Now().Add(-24 * time.Hour).UnixNano()) cFTP.lastActivity = time.Now().Add(-24 * time.Hour).UnixNano()
sshConn2.lastActivity.Store(c.lastActivity.Load()) sshConn2.lastActivity = c.lastActivity
startPeriodicChecks(100 * time.Millisecond) startPeriodicTimeoutTicker(100 * time.Millisecond)
assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 2*time.Second, 200*time.Millisecond) assert.Eventually(t, func() bool { return len(Connections.GetStats()) == 0 }, 1*time.Second, 200*time.Millisecond)
assert.Eventually(t, func() bool { assert.Eventually(t, func() bool {
Connections.RLock() Connections.RLock()
defer Connections.RUnlock() defer Connections.RUnlock()
return len(Connections.sshConnections) == 0 return len(Connections.sshConnections) == 0
}, 1*time.Second, 200*time.Millisecond) }, 1*time.Second, 200*time.Millisecond)
assert.Equal(t, int32(0), Connections.GetClientConnections()) assert.Equal(t, int32(0), Connections.GetClientConnections())
stopEventScheduler() stopPeriodicTimeoutTicker()
assert.True(t, customConn1.isClosed) assert.True(t, customConn1.isClosed)
assert.True(t, customConn2.isClosed) assert.True(t, customConn2.isClosed)
@@ -697,7 +561,7 @@ func TestCloseConnection(t *testing.T) {
fakeConn := &fakeConnection{ fakeConn := &fakeConnection{
BaseConnection: c, BaseConnection: c,
} }
assert.NoError(t, Connections.IsNewConnectionAllowed("127.0.0.1")) assert.True(t, Connections.IsNewConnectionAllowed("127.0.0.1"))
err := Connections.Add(fakeConn)
assert.NoError(t, err)
assert.Len(t, Connections.GetStats(), 1)
@@ -779,9 +643,9 @@ func TestConnectionStatus(t *testing.T) {
BaseConnection: c1,
}
t1 := NewBaseTransfer(nil, c1, nil, "/p1", "/p1", "/r1", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
- t1.BytesReceived.Store(123)
+ t1.BytesReceived = 123
t2 := NewBaseTransfer(nil, c1, nil, "/p2", "/p2", "/r2", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
- t2.BytesSent.Store(456)
+ t2.BytesSent = 456
c2 := NewBaseConnection("id2", ProtocolSSH, "", "", user)
fakeConn2 := &fakeConnection{
BaseConnection: c2,
@@ -831,7 +695,7 @@ func TestConnectionStatus(t *testing.T) {
err = fakeConn3.SignalTransfersAbort()
assert.NoError(t, err)
- assert.True(t, t3.AbortTransfer.Load())
+ assert.Equal(t, int32(1), atomic.LoadInt32(&t3.AbortTransfer))
err = t3.Close()
assert.NoError(t, err)
err = fakeConn3.SignalTransfersAbort()
@@ -1194,189 +1058,6 @@ func TestUserRecentActivity(t *testing.T) {
assert.True(t, res)
}
- func TestVfsSameResource(t *testing.T) {
- fs := vfs.Filesystem{}
- other := vfs.Filesystem{}
- res := fs.IsSameResource(other)
- assert.True(t, res)
- fs = vfs.Filesystem{
- Provider: sdk.S3FilesystemProvider,
- S3Config: vfs.S3FsConfig{
- BaseS3FsConfig: sdk.BaseS3FsConfig{
- Bucket: "a",
- Region: "b",
- },
- },
- }
- other = vfs.Filesystem{
- Provider: sdk.S3FilesystemProvider,
- S3Config: vfs.S3FsConfig{
- BaseS3FsConfig: sdk.BaseS3FsConfig{
- Bucket: "a",
- Region: "c",
- },
- },
- }
- res = fs.IsSameResource(other)
- assert.False(t, res)
- other = vfs.Filesystem{
- Provider: sdk.S3FilesystemProvider,
- S3Config: vfs.S3FsConfig{
- BaseS3FsConfig: sdk.BaseS3FsConfig{
- Bucket: "a",
- Region: "b",
- },
- },
- }
- res = fs.IsSameResource(other)
- assert.True(t, res)
- fs = vfs.Filesystem{
- Provider: sdk.GCSFilesystemProvider,
- GCSConfig: vfs.GCSFsConfig{
- BaseGCSFsConfig: sdk.BaseGCSFsConfig{
- Bucket: "b",
- },
- },
- }
- other = vfs.Filesystem{
- Provider: sdk.GCSFilesystemProvider,
- GCSConfig: vfs.GCSFsConfig{
- BaseGCSFsConfig: sdk.BaseGCSFsConfig{
- Bucket: "c",
- },
- },
- }
- res = fs.IsSameResource(other)
- assert.False(t, res)
- other = vfs.Filesystem{
- Provider: sdk.GCSFilesystemProvider,
- GCSConfig: vfs.GCSFsConfig{
- BaseGCSFsConfig: sdk.BaseGCSFsConfig{
- Bucket: "b",
- },
- },
- }
- res = fs.IsSameResource(other)
- assert.True(t, res)
- sasURL := kms.NewPlainSecret("http://127.0.0.1/sasurl")
- fs = vfs.Filesystem{
- Provider: sdk.AzureBlobFilesystemProvider,
- AzBlobConfig: vfs.AzBlobFsConfig{
- BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{
- AccountName: "a",
- },
- SASURL: sasURL,
- },
- }
- err := fs.Validate("data1")
- assert.NoError(t, err)
- other = vfs.Filesystem{
- Provider: sdk.AzureBlobFilesystemProvider,
- AzBlobConfig: vfs.AzBlobFsConfig{
- BaseAzBlobFsConfig: sdk.BaseAzBlobFsConfig{
- AccountName: "a",
- },
- SASURL: sasURL,
- },
- }
- err = other.Validate("data2")
- assert.NoError(t, err)
- err = fs.AzBlobConfig.SASURL.TryDecrypt()
- assert.NoError(t, err)
- err = other.AzBlobConfig.SASURL.TryDecrypt()
- assert.NoError(t, err)
- res = fs.IsSameResource(other)
- assert.True(t, res)
- fs.AzBlobConfig.AccountName = "b"
- res = fs.IsSameResource(other)
- assert.False(t, res)
- fs.AzBlobConfig.AccountName = "a"
- other.AzBlobConfig.SASURL = kms.NewPlainSecret("http://127.1.1.1/sasurl")
- err = other.Validate("data2")
- assert.NoError(t, err)
- err = other.AzBlobConfig.SASURL.TryDecrypt()
- assert.NoError(t, err)
- res = fs.IsSameResource(other)
- assert.False(t, res)
- fs = vfs.Filesystem{
- Provider: sdk.HTTPFilesystemProvider,
- HTTPConfig: vfs.HTTPFsConfig{
- BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
- Endpoint: "http://127.0.0.1/httpfs",
- Username: "a",
- },
- },
- }
- other = vfs.Filesystem{
- Provider: sdk.HTTPFilesystemProvider,
- HTTPConfig: vfs.HTTPFsConfig{
- BaseHTTPFsConfig: sdk.BaseHTTPFsConfig{
- Endpoint: "http://127.0.0.1/httpfs",
- Username: "b",
- },
- },
- }
- res = fs.IsSameResource(other)
- assert.True(t, res)
- fs.HTTPConfig.EqualityCheckMode = 1
- res = fs.IsSameResource(other)
- assert.False(t, res)
- }
- func TestUpdateTransferTimestamps(t *testing.T) {
- username := "user_test_timestamps"
- user := &dataprovider.User{
- BaseUser: sdk.BaseUser{
- Username: username,
- HomeDir: filepath.Join(os.TempDir(), username),
- Status: 1,
- Permissions: map[string][]string{
- "/": {dataprovider.PermAny},
- },
- },
- }
- err := dataprovider.AddUser(user, "", "")
- assert.NoError(t, err)
- assert.Equal(t, int64(0), user.FirstUpload)
- assert.Equal(t, int64(0), user.FirstDownload)
- err = dataprovider.UpdateUserTransferTimestamps(username, true)
- assert.NoError(t, err)
- userGet, err := dataprovider.UserExists(username)
- assert.NoError(t, err)
- assert.Greater(t, userGet.FirstUpload, int64(0))
- assert.Equal(t, int64(0), user.FirstDownload)
- err = dataprovider.UpdateUserTransferTimestamps(username, false)
- assert.NoError(t, err)
- userGet, err = dataprovider.UserExists(username)
- assert.NoError(t, err)
- assert.Greater(t, userGet.FirstUpload, int64(0))
- assert.Greater(t, userGet.FirstDownload, int64(0))
- // updating again must fail
- err = dataprovider.UpdateUserTransferTimestamps(username, true)
- assert.Error(t, err)
- err = dataprovider.UpdateUserTransferTimestamps(username, false)
- assert.Error(t, err)
- // cleanup
- err = dataprovider.DeleteUser(username, "", "")
- assert.NoError(t, err)
- }
- func TestMetadataAPI(t *testing.T) {
- username := "metadatauser"
- require.False(t, ActiveMetadataChecks.Remove(username))
- require.True(t, ActiveMetadataChecks.Add(username))
- require.False(t, ActiveMetadataChecks.Add(username))
- checks := ActiveMetadataChecks.Get()
- require.Len(t, checks, 1)
- checks[0].Username = username + "a"
- checks = ActiveMetadataChecks.Get()
- require.Len(t, checks, 1)
- require.Equal(t, username, checks[0].Username)
- require.True(t, ActiveMetadataChecks.Remove(username))
- require.Len(t, ActiveMetadataChecks.Get(), 0)
- }
func BenchmarkBcryptHashing(b *testing.B) {
bcryptPassword := "bcryptpassword"
for i := 0; i < b.N; i++ {
@@ -1416,52 +1097,3 @@ func BenchmarkCompareArgon2Password(b *testing.B) {
}
}
}
- func BenchmarkAddRemoveConnections(b *testing.B) {
- var conns []ActiveConnection
- for i := 0; i < 100; i++ {
- conns = append(conns, &fakeConnection{
- BaseConnection: NewBaseConnection(fmt.Sprintf("id%d", i), ProtocolSFTP, "", "", dataprovider.User{
- BaseUser: sdk.BaseUser{
- Username: userTestUsername,
- },
- }),
- })
- }
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- for _, c := range conns {
- if err := Connections.Add(c); err != nil {
- panic(err)
- }
- }
- var wg sync.WaitGroup
- for idx := len(conns) - 1; idx >= 0; idx-- {
- wg.Add(1)
- go func(index int) {
- defer wg.Done()
- Connections.Remove(conns[index].GetID())
- }(idx)
- }
- wg.Wait()
- }
- }
- func BenchmarkAddRemoveSSHConnections(b *testing.B) {
- conn1, conn2 := net.Pipe()
- var conns []*SSHConnection
- for i := 0; i < 2000; i++ {
- conns = append(conns, NewSSHConnection(fmt.Sprintf("id%d", i), conn1))
- }
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- for _, c := range conns {
- Connections.AddSSHConnection(c)
- }
- for idx := len(conns) - 1; idx >= 0; idx-- {
- Connections.RemoveSSHConnection(conns[idx].GetID())
- }
- }
- conn1.Close()
- conn2.Close()
- }

View File

@@ -28,22 +28,20 @@ import (
"github.com/pkg/sftp"
"github.com/sftpgo/sdk"
- "github.com/drakkan/sftpgo/v2/internal/dataprovider"
- "github.com/drakkan/sftpgo/v2/internal/logger"
- "github.com/drakkan/sftpgo/v2/internal/util"
- "github.com/drakkan/sftpgo/v2/internal/vfs"
+ "github.com/drakkan/sftpgo/v2/dataprovider"
+ "github.com/drakkan/sftpgo/v2/logger"
+ "github.com/drakkan/sftpgo/v2/util"
+ "github.com/drakkan/sftpgo/v2/vfs"
)
// BaseConnection defines common fields for a connection using any supported protocol
type BaseConnection struct {
// last activity for this connection.
// Since this field is accessed atomically we put it as first element of the struct to achieve 64 bit alignment
- lastActivity atomic.Int64
- uploadDone atomic.Bool
- downloadDone atomic.Bool
+ lastActivity int64
// unique ID for a transfer.
// This field is accessed atomically so we put it at the beginning of the struct to achieve 64 bit alignment
- transferID atomic.Int64
+ transferID int64
// Unique identifier for the connection
ID string
// user associated with this connection if any
@@ -64,18 +62,16 @@ func NewBaseConnection(id, protocol, localAddr, remoteAddr string, user dataprov
connID = fmt.Sprintf("%s_%s", protocol, id)
}
user.UploadBandwidth, user.DownloadBandwidth = user.GetBandwidthForIP(util.GetIPFromRemoteAddress(remoteAddr), connID)
- c := &BaseConnection{
+ return &BaseConnection{
ID: connID,
User: user,
startTime: time.Now(),
protocol: protocol,
localAddr: localAddr,
remoteAddr: remoteAddr,
+ lastActivity: time.Now().UnixNano(),
+ transferID: 0,
}
- c.transferID.Store(0)
- c.lastActivity.Store(time.Now().UnixNano())
- return c
}
// Log outputs a log entry to the configured logger
@@ -85,7 +81,7 @@ func (c *BaseConnection) Log(level logger.LogLevel, format string, v ...any) {
// GetTransferID returns an unique transfer ID for this connection
func (c *BaseConnection) GetTransferID() int64 {
- return c.transferID.Add(1)
+ return atomic.AddInt64(&c.transferID, 1)
}
// GetID returns the connection ID
@@ -128,12 +124,12 @@ func (c *BaseConnection) GetConnectionTime() time.Time {
// UpdateLastActivity updates last activity for this connection
func (c *BaseConnection) UpdateLastActivity() {
- c.lastActivity.Store(time.Now().UnixNano())
+ atomic.StoreInt64(&c.lastActivity, time.Now().UnixNano())
}
// GetLastActivity returns the last connection activity
func (c *BaseConnection) GetLastActivity() time.Time {
- return time.Unix(0, c.lastActivity.Load())
+ return time.Unix(0, atomic.LoadInt64(&c.lastActivity))
}
// CloseFS closes the underlying fs
@@ -257,7 +253,7 @@ func (c *BaseConnection) getRealFsPath(fsPath string) string {
defer c.RUnlock()
for _, t := range c.activeTransfers {
- if p := t.GetRealFsPath(fsPath); p != "" {
+ if p := t.GetRealFsPath(fsPath); len(p) > 0 {
return p
}
}
@@ -362,7 +358,7 @@ func (c *BaseConnection) CreateDir(virtualPath string, checkFilePatterns bool) e
logger.CommandLog(mkdirLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1,
c.localAddr, c.remoteAddr)
- ExecuteActionNotification(c, operationMkdir, fsPath, virtualPath, "", "", "", 0, nil) //nolint:errcheck
+ ExecuteActionNotification(c, operationMkdir, fsPath, virtualPath, "", "", "", 0, nil)
return nil
}
@@ -409,7 +405,7 @@ func (c *BaseConnection) RemoveFile(fs vfs.Fs, fsPath, virtualPath string, info
}
}
if actionErr != nil {
- ExecuteActionNotification(c, operationDelete, fsPath, virtualPath, "", "", "", size, nil) //nolint:errcheck
+ ExecuteActionNotification(c, operationDelete, fsPath, virtualPath, "", "", "", size, nil)
}
return nil
}
@@ -473,7 +469,7 @@ func (c *BaseConnection) RemoveDir(virtualPath string) error {
logger.CommandLog(rmdirLogSender, fsPath, "", c.User.Username, "", c.ID, c.protocol, -1, -1, "", "", "", -1,
c.localAddr, c.remoteAddr)
- ExecuteActionNotification(c, operationRmdir, fsPath, virtualPath, "", "", "", 0, nil) //nolint:errcheck
+ ExecuteActionNotification(c, operationRmdir, fsPath, virtualPath, "", "", "", 0, nil)
return nil
}
@@ -500,7 +496,7 @@ func (c *BaseConnection) Rename(virtualSourcePath, virtualTargetPath string) err
initialSize := int64(-1)
if dstInfo, err := fsDst.Lstat(fsTargetPath); err == nil {
if dstInfo.IsDir() {
- c.Log(logger.LevelWarn, "attempted to rename %q overwriting an existing directory %q",
+ c.Log(logger.LevelWarn, "attempted to rename %#v overwriting an existing directory %#v",
fsSourcePath, fsTargetPath)
return c.GetOpUnsupportedError()
}
@@ -520,8 +516,7 @@ func (c *BaseConnection) Rename(virtualSourcePath, virtualTargetPath string) err
virtualSourcePath)
return c.GetOpUnsupportedError()
}
- if err = c.checkRecursiveRenameDirPermissions(fsSrc, fsDst, fsSourcePath, fsTargetPath,
- virtualSourcePath, virtualTargetPath, srcInfo); err != nil {
+ if err = c.checkRecursiveRenameDirPermissions(fsSrc, fsDst, fsSourcePath, fsTargetPath); err != nil {
c.Log(logger.LevelDebug, "error checking recursive permissions before renaming %#v: %+v", fsSourcePath, err)
return err
}
@@ -530,7 +525,7 @@ func (c *BaseConnection) Rename(virtualSourcePath, virtualTargetPath string) err
c.Log(logger.LevelInfo, "denying cross rename due to space limit")
return c.GetGenericError(ErrQuotaExceeded)
}
- if err := fsDst.Rename(fsSourcePath, fsTargetPath); err != nil {
+ if err := fsSrc.Rename(fsSourcePath, fsTargetPath); err != nil {
c.Log(logger.LevelError, "failed to rename %#v -> %#v: %+v", fsSourcePath, fsTargetPath, err)
return c.GetFsError(fsSrc, err)
}
@@ -538,21 +533,14 @@ func (c *BaseConnection) Rename(virtualSourcePath, virtualTargetPath string) err
c.updateQuotaAfterRename(fsDst, virtualSourcePath, virtualTargetPath, fsTargetPath, initialSize) //nolint:errcheck
logger.CommandLog(renameLogSender, fsSourcePath, fsTargetPath, c.User.Username, "", c.ID, c.protocol, -1, -1,
"", "", "", -1, c.localAddr, c.remoteAddr)
- ExecuteActionNotification(c, operationRename, fsSourcePath, virtualSourcePath, fsTargetPath, //nolint:errcheck
- virtualTargetPath, "", 0, nil)
+ ExecuteActionNotification(c, operationRename, fsSourcePath, virtualSourcePath, fsTargetPath, virtualTargetPath,
+ "", 0, nil)
return nil
}
// CreateSymlink creates fsTargetPath as a symbolic link to fsSourcePath
func (c *BaseConnection) CreateSymlink(virtualSourcePath, virtualTargetPath string) error {
- var relativePath string
- if !path.IsAbs(virtualSourcePath) {
- relativePath = virtualSourcePath
- virtualSourcePath = path.Join(path.Dir(virtualTargetPath), relativePath)
- c.Log(logger.LevelDebug, "link relative path %q resolved as %q, target path %q",
- relativePath, virtualSourcePath, virtualTargetPath)
- }
if c.isCrossFoldersRequest(virtualSourcePath, virtualTargetPath) {
c.Log(logger.LevelWarn, "cross folder symlink is not supported, src: %v dst: %v", virtualSourcePath, virtualTargetPath)
return c.GetOpUnsupportedError()
@@ -586,9 +574,6 @@ func (c *BaseConnection) CreateSymlink(virtualSourcePath, virtualTargetPath stri
c.Log(logger.LevelError, "symlink target path %#v is not allowed", virtualTargetPath)
return c.GetPermissionDeniedError()
}
- if relativePath != "" {
- fsSourcePath = relativePath
- }
if err := fs.Symlink(fsSourcePath, fsTargetPath); err != nil {
c.Log(logger.LevelError, "failed to create symlink %#v -> %#v: %+v", fsSourcePath, fsTargetPath, err)
return c.GetFsError(fs, err)
@@ -608,14 +593,13 @@ func (c *BaseConnection) getPathForSetStatPerms(fs vfs.Fs, fsPath, virtualPath s
return pathForPerms
}
- func (c *BaseConnection) doStatInternal(virtualPath string, mode int, checkFilePatterns,
- convertResult bool,
- ) (os.FileInfo, error) {
+ // DoStat execute a Stat if mode = 0, Lstat if mode = 1
+ func (c *BaseConnection) DoStat(virtualPath string, mode int, checkFilePatterns bool) (os.FileInfo, error) {
// for some vfs we don't create intermediary folders so we cannot simply check
// if virtualPath is a virtual folder
vfolders := c.User.GetVirtualFoldersInPath(path.Dir(virtualPath))
if _, ok := vfolders[virtualPath]; ok {
- return vfs.NewFileInfo(virtualPath, true, 0, time.Unix(0, 0), false), nil
+ return vfs.NewFileInfo(virtualPath, true, 0, time.Now(), false), nil
}
if checkFilePatterns {
ok, policy := c.User.IsFileAllowed(virtualPath)
@@ -637,20 +621,15 @@ func (c *BaseConnection) doStatInternal(virtualPath string, mode int, checkFileP
info, err = fs.Stat(c.getRealFsPath(fsPath))
}
if err != nil {
- c.Log(logger.LevelWarn, "stat error for path %#v: %+v", virtualPath, err)
+ c.Log(logger.LevelError, "stat error for path %#v: %+v", virtualPath, err)
return info, c.GetFsError(fs, err)
}
- if convertResult && vfs.IsCryptOsFs(fs) {
+ if vfs.IsCryptOsFs(fs) {
info = fs.(*vfs.CryptFs).ConvertFileInfo(info)
}
return info, nil
}
- // DoStat execute a Stat if mode = 0, Lstat if mode = 1
- func (c *BaseConnection) DoStat(virtualPath string, mode int, checkFilePatterns bool) (os.FileInfo, error) {
- return c.doStatInternal(virtualPath, mode, checkFilePatterns, true)
- }
func (c *BaseConnection) createDirIfMissing(name string) error {
_, err := c.DoStat(name, 0, false)
if c.IsNotExistError(err) {
@@ -787,7 +766,7 @@ func (c *BaseConnection) truncateFile(fs vfs.Fs, fsPath, virtualPath string, siz
initialSize = info.Size()
err = fs.Truncate(fsPath, size)
}
- if err == nil && vfs.HasTruncateSupport(fs) {
+ if err == nil && vfs.IsLocalOrSFTPFs(fs) {
sizeDiff := initialSize - size
vfolder, err := c.User.GetVirtualFolderForPath(path.Dir(virtualPath))
if err == nil {
@@ -802,31 +781,23 @@ func (c *BaseConnection) truncateFile(fs vfs.Fs, fsPath, virtualPath string, siz
return err
}
- func (c *BaseConnection) checkRecursiveRenameDirPermissions(fsSrc, fsDst vfs.Fs, sourcePath, targetPath,
- virtualSourcePath, virtualTargetPath string, fi os.FileInfo,
- ) error {
- if !c.User.HasPermissionsInside(virtualSourcePath) &&
- !c.User.HasPermissionsInside(virtualTargetPath) {
- if !c.isRenamePermitted(fsSrc, fsDst, sourcePath, targetPath, virtualSourcePath, virtualTargetPath, fi) {
- c.Log(logger.LevelInfo, "rename %#v -> %#v is not allowed, virtual destination path: %#v",
- sourcePath, targetPath, virtualTargetPath)
- return c.GetPermissionDeniedError()
- }
- // if all rename permissions are granted we have finished, otherwise we have to walk
- // because we could have the rename dir permission but not the rename file and the dir to
- // rename could contain files
- if c.User.HasPermsRenameAll(path.Dir(virtualSourcePath)) && c.User.HasPermsRenameAll(path.Dir(virtualTargetPath)) {
- return nil
- }
- }
- return fsSrc.Walk(sourcePath, func(walkedPath string, info os.FileInfo, err error) error {
+ func (c *BaseConnection) checkRecursiveRenameDirPermissions(fsSrc, fsDst vfs.Fs, sourcePath, targetPath string) error {
+ err := fsSrc.Walk(sourcePath, func(walkedPath string, info os.FileInfo, err error) error {
if err != nil {
return c.GetFsError(fsSrc, err)
}
dstPath := strings.Replace(walkedPath, sourcePath, targetPath, 1)
virtualSrcPath := fsSrc.GetRelativePath(walkedPath)
virtualDstPath := fsDst.GetRelativePath(dstPath)
+ // walk scans the directory tree in order, checking the parent directory permissions we are sure that all contents
+ // inside the parent path was checked. If the current dir has no subdirs with defined permissions inside it
+ // and it has all the possible permissions we can stop scanning
+ if !c.User.HasPermissionsInside(path.Dir(virtualSrcPath)) && !c.User.HasPermissionsInside(path.Dir(virtualDstPath)) {
+ if c.User.HasPermsRenameAll(path.Dir(virtualSrcPath)) &&
+ c.User.HasPermsRenameAll(path.Dir(virtualDstPath)) {
+ return ErrSkipPermissionsCheck
+ }
+ }
if !c.isRenamePermitted(fsSrc, fsDst, walkedPath, dstPath, virtualSrcPath, virtualDstPath, info) {
c.Log(logger.LevelInfo, "rename %#v -> %#v is not allowed, virtual destination path: %#v",
walkedPath, dstPath, virtualDstPath)
@@ -834,6 +805,10 @@ func (c *BaseConnection) checkRecursiveRenameDirPermissions(fsSrc, fsDst vfs.Fs,
}
return nil
})
+ if err == ErrSkipPermissionsCheck {
+ err = nil
+ }
+ return err
}
func (c *BaseConnection) hasRenamePerms(virtualSourcePath, virtualTargetPath string, fi os.FileInfo) bool {
@@ -862,11 +837,9 @@ func (c *BaseConnection) hasRenamePerms(virtualSourcePath, virtualTargetPath str
c.User.HasAnyPerm(perms, path.Dir(virtualTargetPath))
}
- func (c *BaseConnection) isRenamePermitted(fsSrc, fsDst vfs.Fs, fsSourcePath, fsTargetPath, virtualSourcePath,
- virtualTargetPath string, fi os.FileInfo,
- ) bool {
- if !c.isSameResourceRename(virtualSourcePath, virtualTargetPath) {
- c.Log(logger.LevelInfo, "rename %#v->%#v is not allowed: the paths must be on the same resource",
+ func (c *BaseConnection) isRenamePermitted(fsSrc, fsDst vfs.Fs, fsSourcePath, fsTargetPath, virtualSourcePath, virtualTargetPath string, fi os.FileInfo) bool {
+ if !c.isLocalOrSameFolderRename(virtualSourcePath, virtualTargetPath) {
+ c.Log(logger.LevelInfo, "rename %#v->%#v is not allowed: the paths must be local or on the same virtual folder",
virtualSourcePath, virtualTargetPath)
return false
}
@@ -878,7 +851,7 @@ func (c *BaseConnection) isRenamePermitted(fsSrc, fsDst vfs.Fs, fsSourcePath, fs
c.Log(logger.LevelWarn, "renaming to a directory mapped as virtual folder is not allowed: %#v", fsTargetPath)
return false
}
- if virtualSourcePath == "/" || virtualTargetPath == "/" || fsSrc.GetRelativePath(fsSourcePath) == "/" {
+ if fsSrc.GetRelativePath(fsSourcePath) == "/" {
c.Log(logger.LevelWarn, "renaming root dir is not allowed")
return false
}
@@ -1108,7 +1081,8 @@ func (c *BaseConnection) HasSpace(checkFiles, getUsage bool, requestPath string)
return result, transferQuota
}
- func (c *BaseConnection) isSameResourceRename(virtualSourcePath, virtualTargetPath string) bool {
+ // returns true if this is a rename on the same fs or local virtual folders
+ func (c *BaseConnection) isLocalOrSameFolderRename(virtualSourcePath, virtualTargetPath string) bool {
sourceFolder, errSrc := c.User.GetVirtualFolderForPath(virtualSourcePath)
dstFolder, errDst := c.User.GetVirtualFolderForPath(virtualTargetPath)
if errSrc != nil && errDst != nil {
@@ -1118,13 +1092,27 @@ func (c *BaseConnection) isSameResourceRename(virtualSourcePath, virtualTargetPa
if sourceFolder.Name == dstFolder.Name {
return true
}
- // we have different folders, check if they point to the same resource
- return sourceFolder.FsConfig.IsSameResource(dstFolder.FsConfig)
+ // we have different folders, only local fs is supported
+ if sourceFolder.FsConfig.Provider == sdk.LocalFilesystemProvider &&
+ dstFolder.FsConfig.Provider == sdk.LocalFilesystemProvider {
+ return true
+ }
+ return false
+ }
+ if c.User.FsConfig.Provider != sdk.LocalFilesystemProvider {
+ return false
}
if errSrc == nil {
- return sourceFolder.FsConfig.IsSameResource(c.User.FsConfig)
+ if sourceFolder.FsConfig.Provider == sdk.LocalFilesystemProvider {
+ return true
+ }
}
- return dstFolder.FsConfig.IsSameResource(c.User.FsConfig)
+ if errDst == nil {
+ if dstFolder.FsConfig.Provider == sdk.LocalFilesystemProvider {
+ return true
+ }
+ }
+ return false
}
func (c *BaseConnection) isCrossFoldersRequest(virtualSourcePath, virtualTargetPath string) bool {
@@ -1362,21 +1350,16 @@ func (c *BaseConnection) GetGenericError(err error) error {
if err == vfs.ErrStorageSizeUnavailable {
return fmt.Errorf("%w: %v", sftp.ErrSSHFxOpUnsupported, err.Error())
}
- if err == ErrShuttingDown {
- return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, err.Error())
- }
if err != nil {
if e, ok := err.(*os.PathError); ok {
- c.Log(logger.LevelError, "generic path error: %+v", e)
return fmt.Errorf("%w: %v %v", sftp.ErrSSHFxFailure, e.Op, e.Err.Error())
}
- c.Log(logger.LevelError, "generic error: %+v", err)
- return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, ErrGenericFailure.Error())
+ return fmt.Errorf("%w: %v", sftp.ErrSSHFxFailure, err.Error())
}
return sftp.ErrSSHFxFailure
default:
if err == ErrPermissionDenied || err == ErrNotExist || err == ErrOpUnsupported ||
- err == ErrQuotaExceeded || err == vfs.ErrStorageSizeUnavailable || err == ErrShuttingDown {
+ err == ErrQuotaExceeded || err == vfs.ErrStorageSizeUnavailable {
return err
}
return ErrGenericFailure
@@ -1409,10 +1392,6 @@ func (c *BaseConnection) GetFsAndResolvedPath(virtualPath string) (vfs.Fs, strin
return nil, "", err
}
- if isShuttingDown.Load() {
- return nil, "", c.GetFsError(fs, ErrShuttingDown)
- }
fsPath, err := fs.ResolvePath(virtualPath)
if err != nil {
return nil, "", c.GetFsError(fs, err)
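Several of the hunks above revolve around checkRecursiveRenameDirPermissions, which on one side stops a filesystem walk early by returning the ErrSkipPermissionsCheck sentinel and mapping it back to success once the walk returns. A rough sketch of that sentinel pattern using the standard-library walker; errSkipRest, allowed and canStopEarly are illustrative names, not project APIs:

package main

import (
	"errors"
	"fmt"
	"io/fs"
	"path/filepath"
)

// errSkipRest signals "everything below here is already known to be fine";
// it aborts the walk without reporting a failure.
var errSkipRest = errors.New("skip remaining permission checks")

func checkTree(root string, allowed, canStopEarly func(string) bool) error {
	err := filepath.WalkDir(root, func(p string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		if !allowed(p) {
			return fmt.Errorf("permission denied: %s", p)
		}
		if d.IsDir() && canStopEarly(p) {
			return errSkipRest // aborts the whole walk, not just this subtree
		}
		return nil
	})
	if errors.Is(err, errSkipRest) {
		err = nil // the early stop means success, not failure
	}
	return err
}

func main() {
	err := checkTree(".", func(string) bool { return true }, func(string) bool { return false })
	fmt.Println(err)
}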

View File

@@ -27,24 +27,20 @@ import (
"github.com/sftpgo/sdk"
"github.com/stretchr/testify/assert"
- "github.com/drakkan/sftpgo/v2/internal/dataprovider"
- "github.com/drakkan/sftpgo/v2/internal/kms"
- "github.com/drakkan/sftpgo/v2/internal/util"
- "github.com/drakkan/sftpgo/v2/internal/vfs"
+ "github.com/drakkan/sftpgo/v2/dataprovider"
+ "github.com/drakkan/sftpgo/v2/kms"
+ "github.com/drakkan/sftpgo/v2/util"
+ "github.com/drakkan/sftpgo/v2/vfs"
)
// MockOsFs mockable OsFs
type MockOsFs struct {
vfs.Fs
hasVirtualFolders bool
- name string
}
// Name returns the name for the Fs implementation
func (fs *MockOsFs) Name() string {
- if fs.name != "" {
- return fs.name
- }
return "mockOsFs"
}
@@ -61,10 +57,9 @@ func (fs *MockOsFs) Chtimes(name string, atime, mtime time.Time, isUploading boo
return vfs.ErrVfsUnsupported
}
- func newMockOsFs(hasVirtualFolders bool, connectionID, rootDir, name string) vfs.Fs {
+ func newMockOsFs(hasVirtualFolders bool, connectionID, rootDir string) vfs.Fs {
return &MockOsFs{
Fs: vfs.NewOsFs(connectionID, rootDir, ""),
- name: name,
hasVirtualFolders: hasVirtualFolders,
}
}
@@ -113,7 +108,7 @@ func TestSetStatMode(t *testing.T) {
}
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
- fs := newMockOsFs(true, "", user.GetHomeDir(), "")
+ fs := newMockOsFs(true, "", user.GetHomeDir())
conn := NewBaseConnection("", ProtocolWebDAV, "", "", user)
err := conn.handleChmod(fs, fakePath, fakePath, nil)
assert.NoError(t, err)
@@ -135,28 +130,10 @@ func TestSetStatMode(t *testing.T) {
}
func TestRecursiveRenameWalkError(t *testing.T) {
- fs := vfs.NewOsFs("", filepath.Clean(os.TempDir()), "")
- conn := NewBaseConnection("", ProtocolWebDAV, "", "", dataprovider.User{
- BaseUser: sdk.BaseUser{
- Permissions: map[string][]string{
- "/": {dataprovider.PermListItems, dataprovider.PermUpload,
- dataprovider.PermDownload, dataprovider.PermRenameDirs},
- },
- },
- })
- err := conn.checkRecursiveRenameDirPermissions(fs, fs, filepath.Join(os.TempDir(), "/source"),
- filepath.Join(os.TempDir(), "/target"), "/source", "/target",
- vfs.NewFileInfo("source", true, 0, time.Now(), false))
+ fs := vfs.NewOsFs("", os.TempDir(), "")
+ conn := NewBaseConnection("", ProtocolWebDAV, "", "", dataprovider.User{})
+ err := conn.checkRecursiveRenameDirPermissions(fs, fs, "/source", "/target")
assert.ErrorIs(t, err, os.ErrNotExist)
- conn.User.Permissions["/"] = []string{dataprovider.PermListItems, dataprovider.PermUpload,
- dataprovider.PermDownload, dataprovider.PermRenameFiles}
- // no dir rename permission, the quick check path returns permission error without walking
- err = conn.checkRecursiveRenameDirPermissions(fs, fs, filepath.Join(os.TempDir(), "/source"),
- filepath.Join(os.TempDir(), "/target"), "/source", "/target",
- vfs.NewFileInfo("source", true, 0, time.Now(), false))
- if assert.Error(t, err) {
- assert.EqualError(t, err, conn.GetPermissionDeniedError().Error())
- }
}
func TestCrossRenameFsErrors(t *testing.T) {
@@ -324,7 +301,7 @@ func TestErrorsMapping(t *testing.T) {
fs := vfs.NewOsFs("", os.TempDir(), "")
conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{BaseUser: sdk.BaseUser{HomeDir: os.TempDir()}})
osErrorsProtocols := []string{ProtocolWebDAV, ProtocolFTP, ProtocolHTTP, ProtocolHTTPShare,
- ProtocolDataRetention, ProtocolOIDC, protocolEventAction}
+ ProtocolDataRetention, ProtocolOIDC}
for _, protocol := range supportedProtocols {
conn.SetProtocol(protocol)
err := conn.GetFsError(fs, os.ErrNotExist)
@@ -344,12 +321,14 @@ func TestErrorsMapping(t *testing.T) {
err = conn.GetFsError(fs, os.ErrClosed)
if protocol == ProtocolSFTP {
assert.ErrorIs(t, err, sftp.ErrSSHFxFailure)
+ assert.Contains(t, err.Error(), os.ErrClosed.Error())
} else {
assert.EqualError(t, err, ErrGenericFailure.Error())
}
err = conn.GetFsError(fs, ErrPermissionDenied)
if protocol == ProtocolSFTP {
assert.ErrorIs(t, err, sftp.ErrSSHFxFailure)
+ assert.Contains(t, err.Error(), ErrPermissionDenied.Error())
} else {
assert.EqualError(t, err, ErrPermissionDenied.Error())
}
@@ -385,13 +364,6 @@ func TestErrorsMapping(t *testing.T) {
} else {
assert.EqualError(t, err, ErrOpUnsupported.Error())
}
- err = conn.GetFsError(fs, ErrShuttingDown)
- if protocol == ProtocolSFTP {
- assert.ErrorIs(t, err, sftp.ErrSSHFxFailure)
- assert.Contains(t, err.Error(), ErrShuttingDown.Error())
- } else {
- assert.EqualError(t, err, ErrShuttingDown.Error())
- }
}
}
@@ -441,7 +413,7 @@ func TestMaxWriteSize(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, int64(90), size)
- fs = newMockOsFs(true, fs.ConnectionID(), user.GetHomeDir(), "")
+ fs = newMockOsFs(true, fs.ConnectionID(), user.GetHomeDir())
size, err = conn.GetMaxWriteSize(quotaResult, true, 100, fs.IsUploadResumeSupported())
assert.EqualError(t, err, ErrOpUnsupported.Error())
assert.Equal(t, int64(0), size)
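MockOsFs above shows the usual Go approach to test doubles: embed the real interface so every method is inherited, then override only what the test needs. A reduced sketch with a toy interface standing in for the much larger vfs.Fs:

package main

import "fmt"

// Fs is a toy stand-in for a filesystem interface such as vfs.Fs.
type Fs interface {
	Name() string
	HasVirtualFolders() bool
}

type osFs struct{}

func (osFs) Name() string            { return "osfs" }
func (osFs) HasVirtualFolders() bool { return false }

// mockOsFs embeds Fs: every method is inherited from the wrapped value and
// only the behavior under test is overridden.
type mockOsFs struct {
	Fs
	hasVirtualFolders bool
}

func (m *mockOsFs) HasVirtualFolders() bool { return m.hasVirtualFolders }

func main() {
	m := &mockOsFs{Fs: osFs{}, hasVirtualFolders: true}
	fmt.Println(m.Name(), m.HasVirtualFolders()) // osfs true
}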

View File

@@ -29,14 +29,12 @@ import (
"sync"
"time"
- mail "github.com/xhit/go-simple-mail/v2"
- "github.com/drakkan/sftpgo/v2/internal/command"
- "github.com/drakkan/sftpgo/v2/internal/dataprovider"
- "github.com/drakkan/sftpgo/v2/internal/httpclient"
- "github.com/drakkan/sftpgo/v2/internal/logger"
- "github.com/drakkan/sftpgo/v2/internal/smtp"
- "github.com/drakkan/sftpgo/v2/internal/util"
+ "github.com/drakkan/sftpgo/v2/command"
+ "github.com/drakkan/sftpgo/v2/dataprovider"
+ "github.com/drakkan/sftpgo/v2/httpclient"
+ "github.com/drakkan/sftpgo/v2/logger"
+ "github.com/drakkan/sftpgo/v2/smtp"
+ "github.com/drakkan/sftpgo/v2/util"
)
// RetentionCheckNotification defines the supported notification methods for a retention check result
@@ -51,11 +49,11 @@ const (
)
var (
- // RetentionChecks is the list of active retention checks
+ // RetentionChecks is the list of active quota scans
RetentionChecks ActiveRetentionChecks
)
- // ActiveRetentionChecks holds the active retention checks
+ // ActiveRetentionChecks holds the active quota scans
type ActiveRetentionChecks struct {
sync.RWMutex
Checks []RetentionCheck
@@ -68,7 +66,7 @@ func (c *ActiveRetentionChecks) Get() []RetentionCheck {
checks := make([]RetentionCheck, 0, len(c.Checks))
for _, check := range c.Checks {
- foldersCopy := make([]dataprovider.FolderRetention, len(check.Folders))
+ foldersCopy := make([]FolderRetention, len(check.Folders))
copy(foldersCopy, check.Folders)
notificationsCopy := make([]string, len(check.Notifications))
copy(notificationsCopy, check.Notifications)
@@ -126,6 +124,37 @@ func (c *ActiveRetentionChecks) remove(username string) bool {
return false
}
+ // FolderRetention defines the retention policy for the specified directory path
+ type FolderRetention struct {
+ // Path is the exposed virtual directory path, if no other specific retention is defined,
+ // the retention applies for sub directories too. For example if retention is defined
+ // for the paths "/" and "/sub" then the retention for "/" is applied for any file outside
+ // the "/sub" directory
+ Path string `json:"path"`
+ // Retention time in hours. 0 means exclude this path
+ Retention int `json:"retention"`
+ // DeleteEmptyDirs defines if empty directories will be deleted.
+ // The user need the delete permission
+ DeleteEmptyDirs bool `json:"delete_empty_dirs,omitempty"`
+ // IgnoreUserPermissions defines if delete files even if the user does not have the delete permission.
+ // The default is "false" which means that files will be skipped if the user does not have the permission
+ // to delete them. This applies to sub directories too.
+ IgnoreUserPermissions bool `json:"ignore_user_permissions,omitempty"`
+ }
+ func (f *FolderRetention) isValid() error {
+ f.Path = path.Clean(f.Path)
+ if !path.IsAbs(f.Path) {
+ return util.NewValidationError(fmt.Sprintf("folder retention: invalid path %#v, please specify an absolute POSIX path",
+ f.Path))
+ }
+ if f.Retention < 0 {
+ return util.NewValidationError(fmt.Sprintf("invalid folder retention %v, it must be greater or equal to zero",
+ f.Retention))
+ }
+ return nil
+ }
type folderRetentionCheckResult struct {
Path string `json:"path"`
Retention int `json:"retention"`
@@ -143,13 +172,13 @@ type RetentionCheck struct {
// retention check start time as unix timestamp in milliseconds
StartTime int64 `json:"start_time"`
// affected folders
- Folders []dataprovider.FolderRetention `json:"folders"`
+ Folders []FolderRetention `json:"folders"`
// how cleanup results will be notified
Notifications []RetentionCheckNotification `json:"notifications,omitempty"`
// email to use if the notification method is set to email
Email string `json:"email,omitempty"`
// Cleanup results
- results []folderRetentionCheckResult `json:"-"`
+ results []*folderRetentionCheckResult `json:"-"`
conn *BaseConnection
}
@@ -159,7 +188,7 @@ func (c *RetentionCheck) Validate() error {
nothingToDo := true
for idx := range c.Folders {
f := &c.Folders[idx]
- if err := f.Validate(); err != nil {
+ if err := f.isValid(); err != nil {
return err
}
if f.Retention > 0 {
@@ -201,7 +230,7 @@ func (c *RetentionCheck) updateUserPermissions() {
}
}
- func (c *RetentionCheck) getFolderRetention(folderPath string) (dataprovider.FolderRetention, error) {
+ func (c *RetentionCheck) getFolderRetention(folderPath string) (FolderRetention, error) {
dirsForPath := util.GetDirsForVirtualPath(folderPath)
for _, dirPath := range dirsForPath {
for _, folder := range c.Folders {
@@ -211,7 +240,7 @@ func (c *RetentionCheck) getFolderRetention(folderPath string) (dataprovider.Fol
}
}
- return dataprovider.FolderRetention{}, fmt.Errorf("unable to find folder retention for %#v", folderPath)
+ return FolderRetention{}, fmt.Errorf("unable to find folder retention for %#v", folderPath)
}
func (c *RetentionCheck) removeFile(virtualPath string, info os.FileInfo) error {
@@ -225,12 +254,10 @@ func (c *RetentionCheck) removeFile(virtualPath string, info os.FileInfo) error
func (c *RetentionCheck) cleanupFolder(folderPath string) error {
deleteFilesPerms := []string{dataprovider.PermDelete, dataprovider.PermDeleteFiles}
startTime := time.Now()
- result := folderRetentionCheckResult{
+ result := &folderRetentionCheckResult{
Path: folderPath,
}
- defer func() {
- c.results = append(c.results, result)
- }()
+ c.results = append(c.results, result)
if !c.conn.User.HasPerm(dataprovider.PermListItems, folderPath) || !c.conn.User.HasAnyPerm(deleteFilesPerms, folderPath) {
result.Elapsed = time.Since(startTime)
result.Info = "data retention check skipped: no permissions"
@@ -305,15 +332,7 @@ func (c *RetentionCheck) cleanupFolder(folderPath string) error {
}
func (c *RetentionCheck) checkEmptyDirRemoval(folderPath string) {
- if folderPath == "/" {
- return
- }
- for _, folder := range c.Folders {
- if folderPath == folder.Path {
- return
- }
- }
- if c.conn.User.HasAnyPerm([]string{
+ if folderPath != "/" && c.conn.User.HasAnyPerm([]string{
dataprovider.PermDelete,
dataprovider.PermDeleteDirs,
}, path.Dir(folderPath),
@@ -327,7 +346,7 @@ func (c *RetentionCheck) checkEmptyDirRemoval(folderPath string) {
}
// Start starts the retention check
- func (c *RetentionCheck) Start() error {
+ func (c *RetentionCheck) Start() {
c.conn.Log(logger.LevelInfo, "retention check started")
defer RetentionChecks.remove(c.conn.User.Username)
defer c.conn.CloseFS() //nolint:errcheck
@@ -338,55 +357,54 @@ func (c *RetentionCheck) Start() error {
if err := c.cleanupFolder(folder.Path); err != nil {
c.conn.Log(logger.LevelError, "retention check failed, unable to cleanup folder %#v", folder.Path)
c.sendNotifications(time.Since(startTime), err)
- return err
+ return
}
}
}
c.conn.Log(logger.LevelInfo, "retention check completed")
c.sendNotifications(time.Since(startTime), nil)
- return nil
}
func (c *RetentionCheck) sendNotifications(elapsed time.Duration, err error) {
for _, notification := range c.Notifications {
switch notification {
case RetentionCheckNotificationEmail:
- c.sendEmailNotification(err) //nolint:errcheck
+ c.sendEmailNotification(elapsed, err) //nolint:errcheck
case RetentionCheckNotificationHook:
c.sendHookNotification(elapsed, err) //nolint:errcheck
}
}
}
- func (c *RetentionCheck) sendEmailNotification(errCheck error) error {
- params := EventParams{}
- if len(c.results) > 0 || errCheck != nil {
- params.retentionChecks = append(params.retentionChecks, executedRetentionCheck{
- Username: c.conn.User.Username,
- ActionName: "Retention check",
- Results: c.results,
- })
+ func (c *RetentionCheck) sendEmailNotification(elapsed time.Duration, errCheck error) error {
+ body := new(bytes.Buffer)
+ data := make(map[string]any)
+ data["Results"] = c.results
+ totalDeletedFiles := 0
+ totalDeletedSize := int64(0)
+ for _, result := range c.results {
+ totalDeletedFiles += result.DeletedFiles
+ totalDeletedSize += result.DeletedSize
}
- var files []mail.File
- f, err := params.getRetentionReportsAsMailAttachment()
- if err != nil {
- c.conn.Log(logger.LevelError, "unable to get retention report as mail attachment: %v", err)
+ data["HumanizeSize"] = util.ByteCountIEC
+ data["TotalFiles"] = totalDeletedFiles
+ data["TotalSize"] = totalDeletedSize
+ data["Elapsed"] = elapsed
+ data["Username"] = c.conn.User.Username
+ data["StartTime"] = util.GetTimeFromMsecSinceEpoch(c.StartTime)
+ if errCheck == nil {
+ data["Status"] = "Succeeded"
+ } else {
+ data["Status"] = "Failed"
+ }
+ if err := smtp.RenderRetentionReportTemplate(body, data); err != nil {
+ c.conn.Log(logger.LevelError, "unable to render retention check template: %v", err)
return err
}
- f.Name = "retention-report.zip"
- files = append(files, f)
startTime := time.Now()
- var subject string
- if errCheck == nil {
- subject = fmt.Sprintf("Successful retention check for user %q", c.conn.User.Username)
- } else {
- subject = fmt.Sprintf("Retention check failed for user %q", c.conn.User.Username)
- }
- body := "Further details attached."
- err = smtp.SendEmail([]string{c.Email}, subject, body, smtp.EmailContentTypeTextPlain, files...)
- if err != nil {
+ subject := fmt.Sprintf("Retention check completed for user %#v", c.conn.User.Username)
+ if err := smtp.SendEmail(c.Email, subject, body.String(), smtp.EmailContentTypeTextHTML); err != nil {
c.conn.Log(logger.LevelError, "unable to notify retention check result via email: %v, elapsed: %v", err,
time.Since(startTime))
return err
@@ -396,9 +414,6 @@ func (c *RetentionCheck) sendEmailNotification(errCheck error) error {
}
func (c *RetentionCheck) sendHookNotification(elapsed time.Duration, errCheck error) error {
- startNewHook()
- defer hookEnded()
data := make(map[string]any)
totalDeletedFiles := 0
totalDeletedSize := int64(0)
@@ -450,11 +465,11 @@ func (c *RetentionCheck) sendHookNotification(elapsed time.Duration, errCheck er
c.conn.Log(logger.LevelError, "%v", err)
return err
}
- timeout, env, args := command.GetConfig(Config.DataRetentionHook, command.HookDataRetention)
+ timeout, env := command.GetConfig(Config.DataRetentionHook)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
- cmd := exec.CommandContext(ctx, Config.DataRetentionHook, args...)
+ cmd := exec.CommandContext(ctx, Config.DataRetentionHook)
cmd.Env = append(env,
fmt.Sprintf("SFTPGO_DATA_RETENTION_RESULT=%v", string(jsonData)))
err := cmd.Run()
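Both sides of the data retention hook hunk share the same execution pattern: marshal the results to JSON, expose them through an environment variable, and run the external command under a context timeout. A standalone sketch under those assumptions; the hook path in main is a placeholder:

package main

import (
	"context"
	"fmt"
	"os"
	"os/exec"
	"time"
)

// runHook executes hookPath with the JSON payload exposed via an environment
// variable; the context kills the process if it exceeds the timeout.
func runHook(hookPath string, jsonData []byte, timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	cmd := exec.CommandContext(ctx, hookPath)
	cmd.Env = append(os.Environ(),
		fmt.Sprintf("SFTPGO_DATA_RETENTION_RESULT=%s", jsonData))
	return cmd.Run()
}

func main() {
	// "/usr/local/bin/retention-hook" is a placeholder path.
	err := runHook("/usr/local/bin/retention-hook", []byte(`{"status":1}`), 30*time.Second)
	fmt.Println(err)
}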

View File

@@ -26,23 +26,31 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "github.com/drakkan/sftpgo/v2/internal/dataprovider"
- "github.com/drakkan/sftpgo/v2/internal/smtp"
+ "github.com/drakkan/sftpgo/v2/dataprovider"
+ "github.com/drakkan/sftpgo/v2/smtp"
)
func TestRetentionValidation(t *testing.T) {
check := RetentionCheck{}
- check.Folders = []dataprovider.FolderRetention{
+ check.Folders = append(check.Folders, FolderRetention{
+ Path: "relative",
+ Retention: 10,
+ })
+ err := check.Validate()
+ require.Error(t, err)
+ assert.Contains(t, err.Error(), "please specify an absolute POSIX path")
+ check.Folders = []FolderRetention{
{
Path: "/",
Retention: -1,
},
}
- err := check.Validate()
+ err = check.Validate()
require.Error(t, err)
assert.Contains(t, err.Error(), "invalid folder retention")
- check.Folders = []dataprovider.FolderRetention{
+ check.Folders = []FolderRetention{
{
Path: "/ab/..",
Retention: 0,
@@ -53,7 +61,7 @@ func TestRetentionValidation(t *testing.T) {
assert.Contains(t, err.Error(), "nothing to delete")
assert.Equal(t, "/", check.Folders[0].Path)
- check.Folders = append(check.Folders, dataprovider.FolderRetention{
+ check.Folders = append(check.Folders, FolderRetention{
Path: "/../..",
Retention: 24,
})
@@ -61,7 +69,7 @@ func TestRetentionValidation(t *testing.T) {
require.Error(t, err)
assert.Contains(t, err.Error(), `duplicated folder path "/"`)
- check.Folders = []dataprovider.FolderRetention{
+ check.Folders = []FolderRetention{
{
Path: "/dir1",
Retention: 48,
@@ -86,7 +94,7 @@ func TestRetentionValidation(t *testing.T) {
Port: 25,
TemplatesPath: "templates",
}
- err = smtpCfg.Initialize(configDir)
+ err = smtpCfg.Initialize("..")
require.NoError(t, err)
err = check.Validate()
@@ -98,7 +106,7 @@ func TestRetentionValidation(t *testing.T) {
assert.NoError(t, err)
smtpCfg = smtp.Config{}
- err = smtpCfg.Initialize(configDir)
+ err = smtpCfg.Initialize("..")
require.NoError(t, err)
check.Notifications = []RetentionCheckNotification{RetentionCheckNotificationHook}
@@ -118,7 +126,7 @@ func TestRetentionEmailNotifications(t *testing.T) {
Port: 2525,
TemplatesPath: "templates",
}
- err := smtpCfg.Initialize(configDir)
+ err := smtpCfg.Initialize("..")
require.NoError(t, err)
user := dataprovider.User{
@@ -131,7 +139,7 @@ func TestRetentionEmailNotifications(t *testing.T) {
check := RetentionCheck{
Notifications: []RetentionCheckNotification{RetentionCheckNotificationEmail},
Email: "notification@example.com",
- results: []folderRetentionCheckResult{
+ results: []*folderRetentionCheckResult{
{
Path: "/",
Retention: 24,
@@ -146,36 +154,21 @@ func TestRetentionEmailNotifications(t *testing.T) {
conn.ID = fmt.Sprintf("data_retention_%v", user.Username)
check.conn = conn
check.sendNotifications(1*time.Second, nil)
- err = check.sendEmailNotification(nil)
+ err = check.sendEmailNotification(1*time.Second, nil)
assert.NoError(t, err)
- err = check.sendEmailNotification(errors.New("test error"))
+ err = check.sendEmailNotification(1*time.Second, errors.New("test error"))
assert.NoError(t, err)
- check.results = nil
- err = check.sendEmailNotification(nil)
- if assert.Error(t, err) {
- assert.Contains(t, err.Error(), "no data retention report available")
- }
smtpCfg.Port = 2626
- err = smtpCfg.Initialize(configDir)
+ err = smtpCfg.Initialize("..")
require.NoError(t, err)
- err = check.sendEmailNotification(nil)
+ err = check.sendEmailNotification(1*time.Second, nil)
assert.Error(t, err)
- check.results = []folderRetentionCheckResult{
- {
- Path: "/",
- Retention: 24,
- DeletedFiles: 20,
- DeletedSize: 456789,
- Elapsed: 12 * time.Second,
- },
- }
smtpCfg = smtp.Config{}
- err = smtpCfg.Initialize(configDir)
+ err = smtpCfg.Initialize("..")
require.NoError(t, err)
- err = check.sendEmailNotification(nil)
+ err = check.sendEmailNotification(1*time.Second, nil)
assert.Error(t, err)
}
@@ -192,7 +185,7 @@ func TestRetentionHookNotifications(t *testing.T) {
user.Permissions["/"] = []string{dataprovider.PermAny}
check := RetentionCheck{
Notifications: []RetentionCheckNotification{RetentionCheckNotificationHook},
- results: []folderRetentionCheckResult{
+ results: []*folderRetentionCheckResult{
{
Path: "/",
Retention: 24,
@@ -247,7 +240,7 @@ func TestRetentionPermissionsAndGetFolder(t *testing.T) {
user.Permissions["/dir2/sub2"] = []string{dataprovider.PermDelete}
check := RetentionCheck{
- Folders: []dataprovider.FolderRetention{
+ Folders: []FolderRetention{
{
Path: "/dir2",
Retention: 24 * 7,
@@ -307,7 +300,7 @@ func TestRetentionCheckAddRemove(t *testing.T) {
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
check := RetentionCheck{
- Folders: []dataprovider.FolderRetention{
+ Folders: []FolderRetention{
{
Path: "/",
Retention: 48,
@@ -341,7 +334,7 @@ func TestCleanupErrors(t *testing.T) {
user.Permissions = make(map[string][]string)
user.Permissions["/"] = []string{dataprovider.PermAny}
check := &RetentionCheck{
- Folders: []dataprovider.FolderRetention{
+ Folders: []FolderRetention{
{
Path: "/path",
Retention: 48,
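getFolderRetention, exercised by these tests, resolves the policy for a path by walking up through its parent directories and taking the first configured entry, so the most specific folder wins. A simplified sketch of that nearest-ancestor lookup (folderRetention and closestRetention are stand-in names):

package main

import (
	"fmt"
	"path"
)

// folderRetention pairs a virtual path with a retention time in hours,
// mirroring the FolderRetention shape above in reduced form.
type folderRetention struct {
	Path      string
	Retention int
}

// dirsForVirtualPath returns the path and all of its ancestors:
// "/logs/app" -> ["/logs/app", "/logs", "/"].
func dirsForVirtualPath(p string) []string {
	dirs := []string{p}
	for p != "/" && p != "." {
		p = path.Dir(p)
		dirs = append(dirs, p)
	}
	return dirs
}

// closestRetention returns the policy of the nearest configured ancestor.
func closestRetention(policies []folderRetention, folderPath string) (folderRetention, bool) {
	for _, dir := range dirsForVirtualPath(folderPath) {
		for _, f := range policies {
			if f.Path == dir {
				return f, true
			}
		}
	}
	return folderRetention{}, false
}

func main() {
	policies := []folderRetention{{Path: "/", Retention: 24}, {Path: "/logs", Retention: 1}}
	fmt.Println(closestRetention(policies, "/logs/app")) // {/logs 1} true
	fmt.Println(closestRetention(policies, "/docs/x"))   // {/ 24} true
}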

View File

@@ -25,9 +25,9 @@ import (
"github.com/yl2chen/cidranger"
- "github.com/drakkan/sftpgo/v2/internal/dataprovider"
- "github.com/drakkan/sftpgo/v2/internal/logger"
- "github.com/drakkan/sftpgo/v2/internal/util"
+ "github.com/drakkan/sftpgo/v2/dataprovider"
+ "github.com/drakkan/sftpgo/v2/logger"
+ "github.com/drakkan/sftpgo/v2/util"
)
// HostEvent is the enumerable for the supported host events

View File

@@ -366,7 +366,7 @@ func TestLoadHostListFromFile(t *testing.T) {
assert.Len(t, hostList.IPAddresses, 0)
assert.Equal(t, 0, hostList.Ranges.Len())
- if runtime.GOOS != osWindows {
+ if runtime.GOOS != "windows" {
err = os.Chmod(hostsFilePath, 0111)
assert.NoError(t, err)

View File

@@ -17,9 +17,9 @@ package common
import (
"time"
- "github.com/drakkan/sftpgo/v2/internal/dataprovider"
- "github.com/drakkan/sftpgo/v2/internal/logger"
- "github.com/drakkan/sftpgo/v2/internal/util"
+ "github.com/drakkan/sftpgo/v2/dataprovider"
+ "github.com/drakkan/sftpgo/v2/logger"
+ "github.com/drakkan/sftpgo/v2/util"
)
type dbDefender struct {
@@ -107,14 +107,6 @@ func (d *dbDefender) AddEvent(ip string, event HostEvent) {
if host.Score > d.config.Threshold {
banTime := time.Now().Add(time.Duration(d.config.BanTime) * time.Minute)
err = dataprovider.SetDefenderBanTime(ip, util.GetTimeAsMsSinceEpoch(banTime))
- if err == nil {
- eventManager.handleIPBlockedEvent(EventParams{
- Event: ipBlockedEventName,
- IP: ip,
- Timestamp: time.Now().UnixNano(),
- Status: 1,
- })
- }
}
if err == nil {

View File

@@ -24,8 +24,8 @@ import (
"github.com/stretchr/testify/assert"
- "github.com/drakkan/sftpgo/v2/internal/dataprovider"
- "github.com/drakkan/sftpgo/v2/internal/util"
+ "github.com/drakkan/sftpgo/v2/dataprovider"
+ "github.com/drakkan/sftpgo/v2/util"
)
func TestBasicDbDefender(t *testing.T) {

View File

@@ -18,8 +18,8 @@ import (
"sort"
"time"
- "github.com/drakkan/sftpgo/v2/internal/dataprovider"
- "github.com/drakkan/sftpgo/v2/internal/util"
+ "github.com/drakkan/sftpgo/v2/dataprovider"
+ "github.com/drakkan/sftpgo/v2/util"
)
type memoryDefender struct {
@@ -209,12 +209,6 @@ func (d *memoryDefender) AddEvent(ip string, event HostEvent) {
d.banned[ip] = time.Now().Add(time.Duration(d.config.BanTime) * time.Minute)
delete(d.hosts, ip)
d.cleanupBanned()
- eventManager.handleIPBlockedEvent(EventParams{
- Event: ipBlockedEventName,
- IP: ip,
- Timestamp: time.Now().UnixNano(),
- Status: 1,
- })
} else {
d.hosts[ip] = hs
}
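The defender hunks only drop the event-manager notification; the underlying scheme is the same on both sides: accumulate a per-IP score and, once it crosses the threshold, move the host into a banned map with an expiry. A compact sketch of that bookkeeping; banList is an invented type:

package main

import (
	"fmt"
	"sync"
	"time"
)

// banList is an invented, minimal version of the defender bookkeeping:
// per-IP scores plus a banned map keyed by expiration time.
type banList struct {
	mu        sync.Mutex
	scores    map[string]int
	banned    map[string]time.Time
	threshold int
	banTime   time.Duration
}

func (b *banList) addEvent(ip string, score int) {
	b.mu.Lock()
	defer b.mu.Unlock()
	b.scores[ip] += score
	if b.scores[ip] > b.threshold {
		b.banned[ip] = time.Now().Add(b.banTime)
		delete(b.scores, ip) // the score is no longer needed once banned
	}
}

func (b *banList) isBanned(ip string) bool {
	b.mu.Lock()
	defer b.mu.Unlock()
	until, ok := b.banned[ip]
	if ok && time.Now().After(until) {
		delete(b.banned, ip) // lazily expire old bans
		return false
	}
	return ok
}

func main() {
	b := &banList{scores: map[string]int{}, banned: map[string]time.Time{}, threshold: 5, banTime: time.Minute}
	b.addEvent("203.0.113.7", 6)
	fmt.Println(b.isBanned("203.0.113.7")) // true
}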

View File

@@ -24,8 +24,8 @@ import (
"github.com/GehirnInc/crypt/md5_crypt"
"golang.org/x/crypto/bcrypt"
- "github.com/drakkan/sftpgo/v2/internal/logger"
- "github.com/drakkan/sftpgo/v2/internal/util"
+ "github.com/drakkan/sftpgo/v2/logger"
+ "github.com/drakkan/sftpgo/v2/util"
)
const (

File diff suppressed because it is too large

View File

@@ -25,7 +25,7 @@ import (
"golang.org/x/time/rate" "golang.org/x/time/rate"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/util"
) )
var ( var (
@@ -186,16 +186,16 @@ func (rl *rateLimiter) Wait(source string) (time.Duration, error) {
} }
type sourceRateLimiter struct { type sourceRateLimiter struct {
lastActivity *atomic.Int64 lastActivity int64
bucket *rate.Limiter bucket *rate.Limiter
} }
func (s *sourceRateLimiter) updateLastActivity() { func (s *sourceRateLimiter) updateLastActivity() {
s.lastActivity.Store(time.Now().UnixNano()) atomic.StoreInt64(&s.lastActivity, time.Now().UnixNano())
} }
func (s *sourceRateLimiter) getLastActivity() int64 { func (s *sourceRateLimiter) getLastActivity() int64 {
return s.lastActivity.Load() return atomic.LoadInt64(&s.lastActivity)
} }
type sourceBuckets struct { type sourceBuckets struct {
@@ -224,8 +224,7 @@ func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Rese
b.cleanup() b.cleanup()
src := sourceRateLimiter{ src := sourceRateLimiter{
lastActivity: new(atomic.Int64), bucket: r,
bucket: r,
} }
src.updateLastActivity() src.updateLastActivity()
b.buckets[source] = src b.buckets[source] = src
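
This hunk shows the recurring theme of the whole compare: one column uses the atomic.Int64 type introduced in Go 1.19, while the other keeps a plain int64 driven through the sync/atomic functions so the branch still builds with older toolchains. A sketch of the two equivalent styles (type and method names are illustrative):

    package example

    import (
        "sync/atomic"
        "time"
    )

    // Go 1.19+ style: the counter is a typed atomic value.
    type sourceNew struct {
        lastActivity atomic.Int64
    }

    func (s *sourceNew) touch()      { s.lastActivity.Store(time.Now().UnixNano()) }
    func (s *sourceNew) last() int64 { return s.lastActivity.Load() }

    // Pre-1.19 style: a plain int64 that must only ever be touched
    // through the sync/atomic functions, never read or written directly.
    type sourceOld struct {
        lastActivity int64
    }

    func (s *sourceOld) touch()      { atomic.StoreInt64(&s.lastActivity, time.Now().UnixNano()) }
    func (s *sourceOld) last() int64 { return atomic.LoadInt64(&s.lastActivity) }

With the function-based form the int64 field must also stay 64-bit aligned on 32-bit platforms, one reason the typed wrapper is preferred where the toolchain allows it.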

View File

@@ -21,7 +21,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/util"
) )
func TestRateLimiterConfig(t *testing.T) { func TestRateLimiterConfig(t *testing.T) {

View File

@@ -25,8 +25,8 @@ import (
"sync" "sync"
"time" "time"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/util"
) )
const ( const (
@@ -97,7 +97,7 @@ func (m *CertManager) loadCertificates() error {
// GetCertificateFunc returns the loaded certificate // GetCertificateFunc returns the loaded certificate
func (m *CertManager) GetCertificateFunc(certID string) func(*tls.ClientHelloInfo) (*tls.Certificate, error) { func (m *CertManager) GetCertificateFunc(certID string) func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
m.RLock() m.RLock()
defer m.RUnlock() defer m.RUnlock()
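
GetCertificateFunc returns a closure keyed on certID, so one side can ignore the ClientHelloInfo argument entirely. A minimal sketch of the same shape, assuming a map-backed store (the real CertManager loads certificates from files):

    package example

    import (
        "crypto/tls"
        "errors"
        "sync"
    )

    // certStore is a hypothetical miniature of the CertManager shown above:
    // certificates are reloaded under the write lock, served under read locks.
    type certStore struct {
        sync.RWMutex
        certs map[string]*tls.Certificate
    }

    // getCertificateFunc returns a callback suitable for tls.Config.GetCertificate.
    // The ClientHelloInfo is ignored because the certificate is chosen by ID.
    func (s *certStore) getCertificateFunc(certID string) func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
        return func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
            s.RLock()
            defer s.RUnlock()
            cert, ok := s.certs[certID]
            if !ok {
                return nil, errors.New("certificate not found")
            }
            return cert, nil
        }
    }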

View File

@@ -21,10 +21,10 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/internal/metric" "github.com/drakkan/sftpgo/v2/metric"
"github.com/drakkan/sftpgo/v2/internal/vfs" "github.com/drakkan/sftpgo/v2/vfs"
) )
var ( var (
@@ -35,8 +35,8 @@ var (
// BaseTransfer contains protocols common transfer details for an upload or a download. // BaseTransfer contains protocols common transfer details for an upload or a download.
type BaseTransfer struct { //nolint:maligned type BaseTransfer struct { //nolint:maligned
ID int64 ID int64
BytesSent atomic.Int64 BytesSent int64
BytesReceived atomic.Int64 BytesReceived int64
Fs vfs.Fs Fs vfs.Fs
File vfs.File File vfs.File
Connection *BaseConnection Connection *BaseConnection
@@ -52,7 +52,7 @@ type BaseTransfer struct { //nolint:maligned
truncatedSize int64 truncatedSize int64
isNewFile bool isNewFile bool
transferType int transferType int
AbortTransfer atomic.Bool AbortTransfer int32
aTime time.Time aTime time.Time
mTime time.Time mTime time.Time
transferQuota dataprovider.TransferQuota transferQuota dataprovider.TransferQuota
@@ -79,14 +79,14 @@ func NewBaseTransfer(file vfs.File, conn *BaseConnection, cancelFn func(), fsPat
InitialSize: initialSize, InitialSize: initialSize,
isNewFile: isNewFile, isNewFile: isNewFile,
requestPath: requestPath, requestPath: requestPath,
BytesSent: 0,
BytesReceived: 0,
MaxWriteSize: maxWriteSize, MaxWriteSize: maxWriteSize,
AbortTransfer: 0,
truncatedSize: truncatedSize, truncatedSize: truncatedSize,
transferQuota: transferQuota, transferQuota: transferQuota,
Fs: fs, Fs: fs,
} }
t.AbortTransfer.Store(false)
t.BytesSent.Store(0)
t.BytesReceived.Store(0)
conn.AddTransfer(t) conn.AddTransfer(t)
return t return t
@@ -115,19 +115,19 @@ func (t *BaseTransfer) GetType() int {
// GetSize returns the transferred size // GetSize returns the transferred size
func (t *BaseTransfer) GetSize() int64 { func (t *BaseTransfer) GetSize() int64 {
if t.transferType == TransferDownload { if t.transferType == TransferDownload {
return t.BytesSent.Load() return atomic.LoadInt64(&t.BytesSent)
} }
return t.BytesReceived.Load() return atomic.LoadInt64(&t.BytesReceived)
} }
// GetDownloadedSize returns the transferred size // GetDownloadedSize returns the transferred size
func (t *BaseTransfer) GetDownloadedSize() int64 { func (t *BaseTransfer) GetDownloadedSize() int64 {
return t.BytesSent.Load() return atomic.LoadInt64(&t.BytesSent)
} }
// GetUploadedSize returns the transferred size // GetUploadedSize returns the transferred size
func (t *BaseTransfer) GetUploadedSize() int64 { func (t *BaseTransfer) GetUploadedSize() int64 {
return t.BytesReceived.Load() return atomic.LoadInt64(&t.BytesReceived)
} }
// GetStartTime returns the start time // GetStartTime returns the start time
@@ -153,7 +153,7 @@ func (t *BaseTransfer) SignalClose(err error) {
t.Lock() t.Lock()
t.errAbort = err t.errAbort = err
t.Unlock() t.Unlock()
t.AbortTransfer.Store(true) atomic.StoreInt32(&(t.AbortTransfer), 1)
} }
// GetTruncatedSize returns the truncated size if this is an upload overwriting
@@ -217,11 +217,11 @@ func (t *BaseTransfer) CheckRead() error {
return nil return nil
} }
if t.transferQuota.AllowedTotalSize > 0 { if t.transferQuota.AllowedTotalSize > 0 {
if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize { if atomic.LoadInt64(&t.BytesSent)+atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedTotalSize {
return t.Connection.GetReadQuotaExceededError() return t.Connection.GetReadQuotaExceededError()
} }
} else if t.transferQuota.AllowedDLSize > 0 { } else if t.transferQuota.AllowedDLSize > 0 {
if t.BytesSent.Load() > t.transferQuota.AllowedDLSize { if atomic.LoadInt64(&t.BytesSent) > t.transferQuota.AllowedDLSize {
return t.Connection.GetReadQuotaExceededError() return t.Connection.GetReadQuotaExceededError()
} }
} }
@@ -230,18 +230,18 @@ func (t *BaseTransfer) CheckRead() error {
// CheckWrite returns an error if write is not allowed // CheckWrite returns an error if write is not allowed
func (t *BaseTransfer) CheckWrite() error { func (t *BaseTransfer) CheckWrite() error {
if t.MaxWriteSize > 0 && t.BytesReceived.Load() > t.MaxWriteSize { if t.MaxWriteSize > 0 && atomic.LoadInt64(&t.BytesReceived) > t.MaxWriteSize {
return t.Connection.GetQuotaExceededError() return t.Connection.GetQuotaExceededError()
} }
if t.transferQuota.AllowedULSize == 0 && t.transferQuota.AllowedTotalSize == 0 { if t.transferQuota.AllowedULSize == 0 && t.transferQuota.AllowedTotalSize == 0 {
return nil return nil
} }
if t.transferQuota.AllowedTotalSize > 0 { if t.transferQuota.AllowedTotalSize > 0 {
if t.BytesSent.Load()+t.BytesReceived.Load() > t.transferQuota.AllowedTotalSize { if atomic.LoadInt64(&t.BytesSent)+atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedTotalSize {
return t.Connection.GetQuotaExceededError() return t.Connection.GetQuotaExceededError()
} }
} else if t.transferQuota.AllowedULSize > 0 { } else if t.transferQuota.AllowedULSize > 0 {
if t.BytesReceived.Load() > t.transferQuota.AllowedULSize { if atomic.LoadInt64(&t.BytesReceived) > t.transferQuota.AllowedULSize {
return t.Connection.GetQuotaExceededError() return t.Connection.GetQuotaExceededError()
} }
} }
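
CheckRead and CheckWrite compare the two byte counters against per-session allowances: a shared total budget wins when set, otherwise the direction-specific one applies. A sketch of the write-side logic with assumed field names:

    package example

    import "errors"

    var errQuotaExceeded = errors.New("transfer quota exceeded")

    type transferQuota struct {
        AllowedTotalSize int64 // shared upload+download budget, 0 means unset
        AllowedULSize    int64 // upload-only budget, used when the total is unset
    }

    // checkWrite mirrors the structure of CheckWrite above: the shared budget
    // is enforced when set, otherwise the direction-specific one applies.
    func checkWrite(q transferQuota, sent, received int64) error {
        if q.AllowedTotalSize > 0 {
            if sent+received > q.AllowedTotalSize {
                return errQuotaExceeded
            }
            return nil
        }
        if q.AllowedULSize > 0 && received > q.AllowedULSize {
            return errQuotaExceeded
        }
        return nil
    }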
@@ -261,14 +261,13 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) {
if t.MaxWriteSize > 0 { if t.MaxWriteSize > 0 {
sizeDiff := initialSize - size sizeDiff := initialSize - size
t.MaxWriteSize += sizeDiff t.MaxWriteSize += sizeDiff
metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(), metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived), t.transferType, t.ErrTransfer)
t.transferType, t.ErrTransfer, vfs.IsSFTPFs(t.Fs))
if t.transferQuota.HasSizeLimits() { if t.transferQuota.HasSizeLimits() {
go func(ulSize, dlSize int64, user dataprovider.User) { go func(ulSize, dlSize int64, user dataprovider.User) {
dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck dataprovider.UpdateUserTransferQuota(&user, ulSize, dlSize, false) //nolint:errcheck
}(t.BytesReceived.Load(), t.BytesSent.Load(), t.Connection.User) }(atomic.LoadInt64(&t.BytesReceived), atomic.LoadInt64(&t.BytesSent), t.Connection.User)
} }
t.BytesReceived.Store(0) atomic.StoreInt64(&t.BytesReceived, 0)
} }
t.Unlock() t.Unlock()
} }
@@ -276,7 +275,7 @@ func (t *BaseTransfer) Truncate(fsPath string, size int64) (int64, error) {
fsPath, size, t.MaxWriteSize, t.InitialSize, err) fsPath, size, t.MaxWriteSize, t.InitialSize, err)
return initialSize, err return initialSize, err
} }
if size == 0 && t.BytesSent.Load() == 0 { if size == 0 && atomic.LoadInt64(&t.BytesSent) == 0 {
// for cloud providers the file is always truncated to zero, we don't support append/resume for uploads // for cloud providers the file is always truncated to zero, we don't support append/resume for uploads
// for buffered SFTP we can have buffered bytes so we return an error // for buffered SFTP we can have buffered bytes so we return an error
if !vfs.IsBufferedSFTPFs(t.Fs) { if !vfs.IsBufferedSFTPFs(t.Fs) {
@@ -302,30 +301,23 @@ func (t *BaseTransfer) TransferError(err error) {
} }
elapsed := time.Since(t.start).Nanoseconds() / 1000000 elapsed := time.Since(t.start).Nanoseconds() / 1000000
t.Connection.Log(logger.LevelError, "Unexpected error for transfer, path: %#v, error: \"%v\" bytes sent: %v, "+ t.Connection.Log(logger.LevelError, "Unexpected error for transfer, path: %#v, error: \"%v\" bytes sent: %v, "+
"bytes received: %v transfer running since %v ms", t.fsPath, t.ErrTransfer, t.BytesSent.Load(), "bytes received: %v transfer running since %v ms", t.fsPath, t.ErrTransfer, atomic.LoadInt64(&t.BytesSent),
t.BytesReceived.Load(), elapsed) atomic.LoadInt64(&t.BytesReceived), elapsed)
} }
func (t *BaseTransfer) getUploadFileSize() (int64, int, error) { func (t *BaseTransfer) getUploadFileSize() (int64, error) {
var fileSize int64 var fileSize int64
var deletedFiles int
info, err := t.Fs.Stat(t.fsPath) info, err := t.Fs.Stat(t.fsPath)
if err == nil { if err == nil {
fileSize = info.Size() fileSize = info.Size()
} }
if t.ErrTransfer != nil && vfs.IsCryptOsFs(t.Fs) { if vfs.IsCryptOsFs(t.Fs) && t.ErrTransfer != nil {
errDelete := t.Fs.Remove(t.fsPath, false) errDelete := t.Fs.Remove(t.fsPath, false)
if errDelete != nil { if errDelete != nil {
t.Connection.Log(logger.LevelWarn, "error removing partial crypto file %#v: %v", t.fsPath, errDelete) t.Connection.Log(logger.LevelWarn, "error removing partial crypto file %#v: %v", t.fsPath, errDelete)
} else {
fileSize = 0
deletedFiles = 1
t.BytesReceived.Store(0)
t.MinWriteOffset = 0
} }
} }
return fileSize, deletedFiles, err return fileSize, err
} }
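
Both versions stat the uploaded file to learn its final size, and the caller falls back to the in-memory counters when the stat fails. A sketch of that fallback, with bytesReceived and minWriteOffset standing in for the transfer fields:

    package example

    import "os"

    // uploadedSize stats the final file and falls back to the bytes counted
    // in memory when the stat fails (e.g. the upload never created the file).
    func uploadedSize(path string, bytesReceived, minWriteOffset int64) int64 {
        if info, err := os.Stat(path); err == nil {
            return info.Size()
        }
        return bytesReceived + minWriteOffset
    }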
// return 1 if the file is outside the user home dir // return 1 if the file is outside the user home dir
@@ -340,7 +332,7 @@ func (t *BaseTransfer) checkUploadOutsideHomeDir(err error) int {
t.Connection.Log(logger.LevelWarn, "upload in temp path cannot be renamed, delete temporary file: %#v, deletion error: %v", t.Connection.Log(logger.LevelWarn, "upload in temp path cannot be renamed, delete temporary file: %#v, deletion error: %v",
t.effectiveFsPath, err) t.effectiveFsPath, err)
// the file is outside the home dir so don't update the quota // the file is outside the home dir so don't update the quota
t.BytesReceived.Store(0) atomic.StoreInt64(&t.BytesReceived, 0)
t.MinWriteOffset = 0 t.MinWriteOffset = 0
return 1 return 1
} }
@@ -354,18 +346,22 @@ func (t *BaseTransfer) Close() error {
defer t.Connection.RemoveTransfer(t) defer t.Connection.RemoveTransfer(t)
var err error var err error
numFiles := t.getUploadedFiles() numFiles := 0
metric.TransferCompleted(t.BytesSent.Load(), t.BytesReceived.Load(), if t.isNewFile {
t.transferType, t.ErrTransfer, vfs.IsSFTPFs(t.Fs)) numFiles = 1
}
metric.TransferCompleted(atomic.LoadInt64(&t.BytesSent), atomic.LoadInt64(&t.BytesReceived),
t.transferType, t.ErrTransfer)
if t.transferQuota.HasSizeLimits() { if t.transferQuota.HasSizeLimits() {
dataprovider.UpdateUserTransferQuota(&t.Connection.User, t.BytesReceived.Load(), //nolint:errcheck dataprovider.UpdateUserTransferQuota(&t.Connection.User, atomic.LoadInt64(&t.BytesReceived), //nolint:errcheck
t.BytesSent.Load(), false) atomic.LoadInt64(&t.BytesSent), false)
} }
if t.File != nil && t.Connection.IsQuotaExceededError(t.ErrTransfer) { if t.File != nil && t.Connection.IsQuotaExceededError(t.ErrTransfer) {
// if quota is exceeded we try to remove the partial file for uploads to local filesystem // if quota is exceeded we try to remove the partial file for uploads to local filesystem
err = t.Fs.Remove(t.File.Name(), false) err = t.Fs.Remove(t.File.Name(), false)
if err == nil { if err == nil {
t.BytesReceived.Store(0) numFiles--
atomic.StoreInt64(&t.BytesReceived, 0)
t.MinWriteOffset = 0 t.MinWriteOffset = 0
} }
t.Connection.Log(logger.LevelWarn, "upload denied due to space limit, delete temporary file: %#v, deletion error: %v", t.Connection.Log(logger.LevelWarn, "upload denied due to space limit, delete temporary file: %#v, deletion error: %v",
@@ -376,43 +372,35 @@ func (t *BaseTransfer) Close() error {
t.Connection.Log(logger.LevelDebug, "atomic upload completed, rename: %#v -> %#v, error: %v", t.Connection.Log(logger.LevelDebug, "atomic upload completed, rename: %#v -> %#v, error: %v",
t.effectiveFsPath, t.fsPath, err) t.effectiveFsPath, t.fsPath, err)
// the file must be removed if it is uploaded to a path outside the home dir and cannot be renamed // the file must be removed if it is uploaded to a path outside the home dir and cannot be renamed
t.checkUploadOutsideHomeDir(err) numFiles -= t.checkUploadOutsideHomeDir(err)
} else { } else {
err = t.Fs.Remove(t.effectiveFsPath, false) err = t.Fs.Remove(t.effectiveFsPath, false)
t.Connection.Log(logger.LevelWarn, "atomic upload completed with error: \"%v\", delete temporary file: %#v, deletion error: %v", t.Connection.Log(logger.LevelWarn, "atomic upload completed with error: \"%v\", delete temporary file: %#v, deletion error: %v",
t.ErrTransfer, t.effectiveFsPath, err) t.ErrTransfer, t.effectiveFsPath, err)
if err == nil { if err == nil {
t.BytesReceived.Store(0) numFiles--
atomic.StoreInt64(&t.BytesReceived, 0)
t.MinWriteOffset = 0 t.MinWriteOffset = 0
} }
} }
} }
elapsed := time.Since(t.start).Nanoseconds() / 1000000 elapsed := time.Since(t.start).Nanoseconds() / 1000000
var uploadFileSize int64
if t.transferType == TransferDownload { if t.transferType == TransferDownload {
logger.TransferLog(downloadLogSender, t.fsPath, elapsed, t.BytesSent.Load(), t.Connection.User.Username, logger.TransferLog(downloadLogSender, t.fsPath, elapsed, atomic.LoadInt64(&t.BytesSent), t.Connection.User.Username,
t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode) t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode)
ExecuteActionNotification(t.Connection, operationDownload, t.fsPath, t.requestPath, "", "", "", //nolint:errcheck ExecuteActionNotification(t.Connection, operationDownload, t.fsPath, t.requestPath, "", "", "",
t.BytesSent.Load(), t.ErrTransfer) atomic.LoadInt64(&t.BytesSent), t.ErrTransfer)
} else { } else {
statSize, deletedFiles, errStat := t.getUploadFileSize() fileSize := atomic.LoadInt64(&t.BytesReceived) + t.MinWriteOffset
if errStat == nil { if statSize, errStat := t.getUploadFileSize(); errStat == nil {
uploadFileSize = statSize fileSize = statSize
} else {
uploadFileSize = t.BytesReceived.Load() + t.MinWriteOffset
if t.Fs.IsNotExist(errStat) {
uploadFileSize = 0
numFiles--
}
} }
numFiles -= deletedFiles t.Connection.Log(logger.LevelDebug, "uploaded file size %v", fileSize)
t.Connection.Log(logger.LevelDebug, "upload file size %d, num files %d, deleted files %d, fs path %q", t.updateQuota(numFiles, fileSize)
uploadFileSize, numFiles, deletedFiles, t.fsPath)
numFiles, uploadFileSize = t.executeUploadHook(numFiles, uploadFileSize)
t.updateQuota(numFiles, uploadFileSize)
t.updateTimes() t.updateTimes()
logger.TransferLog(uploadLogSender, t.fsPath, elapsed, t.BytesReceived.Load(), t.Connection.User.Username, logger.TransferLog(uploadLogSender, t.fsPath, elapsed, atomic.LoadInt64(&t.BytesReceived), t.Connection.User.Username,
t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode) t.Connection.ID, t.Connection.protocol, t.Connection.localAddr, t.Connection.remoteAddr, t.ftpMode)
ExecuteActionNotification(t.Connection, operationUpload, t.fsPath, t.requestPath, "", "", "", fileSize, t.ErrTransfer)
} }
if t.ErrTransfer != nil { if t.ErrTransfer != nil {
t.Connection.Log(logger.LevelError, "transfer error: %v, path: %#v", t.ErrTransfer, t.fsPath) t.Connection.Log(logger.LevelError, "transfer error: %v, path: %#v", t.ErrTransfer, t.fsPath)
@@ -420,62 +408,9 @@ func (t *BaseTransfer) Close() error {
err = t.ErrTransfer err = t.ErrTransfer
} }
} }
t.updateTransferTimestamps(uploadFileSize)
return err return err
} }
func (t *BaseTransfer) updateTransferTimestamps(uploadFileSize int64) {
if t.ErrTransfer != nil {
return
}
if t.transferType == TransferUpload {
if t.Connection.User.FirstUpload == 0 && !t.Connection.uploadDone.Load() {
if err := dataprovider.UpdateUserTransferTimestamps(t.Connection.User.Username, true); err == nil {
t.Connection.uploadDone.Store(true)
ExecuteActionNotification(t.Connection, operationFirstUpload, t.fsPath, t.requestPath, "", //nolint:errcheck
"", "", uploadFileSize, t.ErrTransfer)
}
}
return
}
if t.Connection.User.FirstDownload == 0 && !t.Connection.downloadDone.Load() && t.BytesSent.Load() > 0 {
if err := dataprovider.UpdateUserTransferTimestamps(t.Connection.User.Username, false); err == nil {
t.Connection.downloadDone.Store(true)
ExecuteActionNotification(t.Connection, operationFirstDownload, t.fsPath, t.requestPath, "", //nolint:errcheck
"", "", t.BytesSent.Load(), t.ErrTransfer)
}
}
}
func (t *BaseTransfer) executeUploadHook(numFiles int, fileSize int64) (int, int64) {
err := ExecuteActionNotification(t.Connection, operationUpload, t.fsPath, t.requestPath, "", "", "",
fileSize, t.ErrTransfer)
if err != nil {
if t.ErrTransfer == nil {
t.ErrTransfer = err
}
// try to remove the uploaded file
err = t.Fs.Remove(t.fsPath, false)
if err == nil {
numFiles--
fileSize = 0
t.BytesReceived.Store(0)
t.MinWriteOffset = 0
} else {
t.Connection.Log(logger.LevelWarn, "unable to remove path %q after upload hook failure: %v", t.fsPath, err)
}
}
return numFiles, fileSize
}
func (t *BaseTransfer) getUploadedFiles() int {
numFiles := 0
if t.isNewFile {
numFiles = 1
}
return numFiles
}
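
The updateTransferTimestamps helper above exists on one side only: it records the first upload or download per user and fires a one-shot notification guarded by an atomic flag on the connection. A simplified sketch of the one-shot pattern (the real code flips the flag only after the data provider update succeeds):

    package example

    import (
        "fmt"
        "sync/atomic"
    )

    // firstEvent runs fn exactly once per flag, the pattern behind the
    // operationFirstUpload / operationFirstDownload notifications above.
    func firstEvent(done *atomic.Bool, fn func()) {
        if done.CompareAndSwap(false, true) {
            fn()
        }
    }

    func main() {
        var uploadDone atomic.Bool
        for i := 0; i < 3; i++ {
            firstEvent(&uploadDone, func() { fmt.Println("first upload completed") })
        }
        // "first upload completed" is printed once
    }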
func (t *BaseTransfer) updateTimes() { func (t *BaseTransfer) updateTimes() {
if !t.aTime.IsZero() && !t.mTime.IsZero() { if !t.aTime.IsZero() && !t.mTime.IsZero() {
err := t.Fs.Chtimes(t.fsPath, t.aTime, t.mTime, true) err := t.Fs.Chtimes(t.fsPath, t.aTime, t.mTime, true)
@@ -485,12 +420,12 @@ func (t *BaseTransfer) updateTimes() {
} }
func (t *BaseTransfer) updateQuota(numFiles int, fileSize int64) bool { func (t *BaseTransfer) updateQuota(numFiles int, fileSize int64) bool {
// Uploads on some filesystems (S3 and similar) are atomic, if there is an error nothing is uploaded // S3 uploads are atomic, if there is an error nothing is uploaded
if t.File == nil && t.ErrTransfer != nil && vfs.HasImplicitAtomicUploads(t.Fs) { if t.File == nil && t.ErrTransfer != nil && !t.Connection.User.HasBufferedSFTP(t.GetVirtualPath()) {
return false return false
} }
sizeDiff := fileSize - t.InitialSize sizeDiff := fileSize - t.InitialSize
if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff != 0) { if t.transferType == TransferUpload && (numFiles != 0 || sizeDiff > 0) {
vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath)) vfolder, err := t.Connection.User.GetVirtualFolderForPath(path.Dir(t.requestPath))
if err == nil { if err == nil {
dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck dataprovider.UpdateVirtualFolderQuota(&vfolder.BaseVirtualFolder, numFiles, //nolint:errcheck
@@ -512,10 +447,10 @@ func (t *BaseTransfer) HandleThrottle() {
var trasferredBytes int64 var trasferredBytes int64
if t.transferType == TransferDownload { if t.transferType == TransferDownload {
wantedBandwidth = t.Connection.User.DownloadBandwidth wantedBandwidth = t.Connection.User.DownloadBandwidth
trasferredBytes = t.BytesSent.Load() trasferredBytes = atomic.LoadInt64(&t.BytesSent)
} else { } else {
wantedBandwidth = t.Connection.User.UploadBandwidth wantedBandwidth = t.Connection.User.UploadBandwidth
trasferredBytes = t.BytesReceived.Load() trasferredBytes = atomic.LoadInt64(&t.BytesReceived)
} }
if wantedBandwidth > 0 { if wantedBandwidth > 0 {
// real and wanted elapsed as milliseconds, bytes as kilobytes // real and wanted elapsed as milliseconds, bytes as kilobytes
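
HandleThrottle sleeps whenever the transfer has moved faster than the configured bandwidth allows. A sketch of the pacing arithmetic implied by the comment above, elapsed in milliseconds and bytes in kilobytes; the exact formula here is an assumption, not a copy of the implementation:

    package example

    import "time"

    // throttleSleep returns how long a transfer should pause so that the
    // average rate does not exceed wantedBandwidth (in KB/s).
    func throttleSleep(transferred, wantedBandwidth int64, elapsed time.Duration) time.Duration {
        if wantedBandwidth <= 0 {
            return 0
        }
        realElapsedMs := elapsed.Milliseconds()
        wantedElapsedMs := 1000 * (transferred / 1024) / wantedBandwidth
        if wantedElapsedMs > realElapsedMs {
            return time.Duration(wantedElapsedMs-realElapsedMs) * time.Millisecond
        }
        return 0
    }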

View File

@@ -25,21 +25,22 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/kms"
"github.com/drakkan/sftpgo/v2/internal/vfs" "github.com/drakkan/sftpgo/v2/vfs"
) )
func TestTransferUpdateQuota(t *testing.T) { func TestTransferUpdateQuota(t *testing.T) {
conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{}) conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{})
transfer := BaseTransfer{ transfer := BaseTransfer{
Connection: conn, Connection: conn,
transferType: TransferUpload, transferType: TransferUpload,
Fs: vfs.NewOsFs("", os.TempDir(), ""), BytesReceived: 123,
Fs: vfs.NewOsFs("", os.TempDir(), ""),
} }
transfer.BytesReceived.Store(123)
errFake := errors.New("fake error") errFake := errors.New("fake error")
transfer.TransferError(errFake) transfer.TransferError(errFake)
assert.False(t, transfer.updateQuota(1, 0))
err := transfer.Close() err := transfer.Close()
if assert.Error(t, err) { if assert.Error(t, err) {
assert.EqualError(t, err, errFake.Error()) assert.EqualError(t, err, errFake.Error())
@@ -55,15 +56,11 @@ func TestTransferUpdateQuota(t *testing.T) {
QuotaSize: -1, QuotaSize: -1,
}) })
transfer.ErrTransfer = nil transfer.ErrTransfer = nil
transfer.BytesReceived.Store(1) transfer.BytesReceived = 1
transfer.requestPath = "/vdir/file" transfer.requestPath = "/vdir/file"
assert.True(t, transfer.updateQuota(1, 0)) assert.True(t, transfer.updateQuota(1, 0))
err = transfer.Close() err = transfer.Close()
assert.NoError(t, err) assert.NoError(t, err)
transfer.ErrTransfer = errFake
transfer.Fs = newMockOsFs(true, "", "", "S3Fs fake")
assert.False(t, transfer.updateQuota(1, 0))
} }
func TestTransferThrottling(t *testing.T) { func TestTransferThrottling(t *testing.T) {
@@ -83,7 +80,7 @@ func TestTransferThrottling(t *testing.T) {
wantedDownloadElapsed -= wantedDownloadElapsed / 10 wantedDownloadElapsed -= wantedDownloadElapsed / 10
conn := NewBaseConnection("id", ProtocolSCP, "", "", u) conn := NewBaseConnection("id", ProtocolSCP, "", "", u)
transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) transfer := NewBaseTransfer(nil, conn, nil, "", "", "", TransferUpload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
transfer.BytesReceived.Store(testFileSize) transfer.BytesReceived = testFileSize
transfer.Connection.UpdateLastActivity() transfer.Connection.UpdateLastActivity()
startTime := transfer.Connection.GetLastActivity() startTime := transfer.Connection.GetLastActivity()
transfer.HandleThrottle() transfer.HandleThrottle()
@@ -93,7 +90,7 @@ func TestTransferThrottling(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) transfer = NewBaseTransfer(nil, conn, nil, "", "", "", TransferDownload, 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
transfer.BytesSent.Store(testFileSize) transfer.BytesSent = testFileSize
transfer.Connection.UpdateLastActivity() transfer.Connection.UpdateLastActivity()
startTime = transfer.Connection.GetLastActivity() startTime = transfer.Connection.GetLastActivity()
@@ -229,7 +226,7 @@ func TestTransferErrors(t *testing.T) {
assert.Equal(t, testFile, transfer.GetFsPath()) assert.Equal(t, testFile, transfer.GetFsPath())
transfer.SetCancelFn(cancelFn) transfer.SetCancelFn(cancelFn)
errFake := errors.New("err fake") errFake := errors.New("err fake")
transfer.BytesReceived.Store(9) transfer.BytesReceived = 9
transfer.TransferError(ErrQuotaExceeded) transfer.TransferError(ErrQuotaExceeded)
assert.True(t, isCancelled) assert.True(t, isCancelled)
transfer.TransferError(errFake) transfer.TransferError(errFake)
@@ -252,7 +249,7 @@ func TestTransferErrors(t *testing.T) {
fsPath := filepath.Join(os.TempDir(), "test_file") fsPath := filepath.Join(os.TempDir(), "test_file")
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true,
fs, dataprovider.TransferQuota{}) fs, dataprovider.TransferQuota{})
transfer.BytesReceived.Store(9) transfer.BytesReceived = 9
transfer.TransferError(errFake) transfer.TransferError(errFake)
assert.Error(t, transfer.ErrTransfer, errFake.Error()) assert.Error(t, transfer.ErrTransfer, errFake.Error())
// the file is closed from the embedding struct before calling close // the file is closed from the embedding struct before calling close
@@ -272,7 +269,7 @@ func TestTransferErrors(t *testing.T) {
} }
transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true, transfer = NewBaseTransfer(file, conn, nil, fsPath, file.Name(), "/test_file", TransferUpload, 0, 0, 0, 0, true,
fs, dataprovider.TransferQuota{}) fs, dataprovider.TransferQuota{})
transfer.BytesReceived.Store(9) transfer.BytesReceived = 9
// the file is closed from the embedding struct before calling close // the file is closed from the embedding struct before calling close
err = file.Close() err = file.Close()
assert.NoError(t, err) assert.NoError(t, err)
@@ -300,25 +297,24 @@ func TestRemovePartialCryptoFile(t *testing.T) {
transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload, transfer := NewBaseTransfer(nil, conn, nil, testFile, testFile, "/transfer_test_file", TransferUpload,
0, 0, 0, 0, true, fs, dataprovider.TransferQuota{}) 0, 0, 0, 0, true, fs, dataprovider.TransferQuota{})
transfer.ErrTransfer = errors.New("test error") transfer.ErrTransfer = errors.New("test error")
_, _, err = transfer.getUploadFileSize() _, err = transfer.getUploadFileSize()
assert.Error(t, err) assert.Error(t, err)
err = os.WriteFile(testFile, []byte("test data"), os.ModePerm) err = os.WriteFile(testFile, []byte("test data"), os.ModePerm)
assert.NoError(t, err) assert.NoError(t, err)
size, deletedFiles, err := transfer.getUploadFileSize() size, err := transfer.getUploadFileSize()
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, int64(0), size) assert.Equal(t, int64(9), size)
assert.Equal(t, 1, deletedFiles)
assert.NoFileExists(t, testFile) assert.NoFileExists(t, testFile)
} }
func TestFTPMode(t *testing.T) { func TestFTPMode(t *testing.T) {
conn := NewBaseConnection("", ProtocolFTP, "", "", dataprovider.User{}) conn := NewBaseConnection("", ProtocolFTP, "", "", dataprovider.User{})
transfer := BaseTransfer{ transfer := BaseTransfer{
Connection: conn, Connection: conn,
transferType: TransferUpload, transferType: TransferUpload,
Fs: vfs.NewOsFs("", os.TempDir(), ""), BytesReceived: 123,
Fs: vfs.NewOsFs("", os.TempDir(), ""),
} }
transfer.BytesReceived.Store(123)
assert.Empty(t, transfer.ftpMode) assert.Empty(t, transfer.ftpMode)
transfer.SetFtpMode("active") transfer.SetFtpMode("active")
assert.Equal(t, "active", transfer.ftpMode) assert.Equal(t, "active", transfer.ftpMode)
@@ -403,14 +399,14 @@ func TestTransferQuota(t *testing.T) {
transfer.transferQuota = dataprovider.TransferQuota{ transfer.transferQuota = dataprovider.TransferQuota{
AllowedTotalSize: 10, AllowedTotalSize: 10,
} }
transfer.BytesReceived.Store(5) transfer.BytesReceived = 5
transfer.BytesSent.Store(4) transfer.BytesSent = 4
err = transfer.CheckRead() err = transfer.CheckRead()
assert.NoError(t, err) assert.NoError(t, err)
err = transfer.CheckWrite() err = transfer.CheckWrite()
assert.NoError(t, err) assert.NoError(t, err)
transfer.BytesSent.Store(6) transfer.BytesSent = 6
err = transfer.CheckRead() err = transfer.CheckRead()
if assert.Error(t, err) { if assert.Error(t, err) {
assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error()) assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())
@@ -432,7 +428,7 @@ func TestTransferQuota(t *testing.T) {
err = transfer.CheckWrite() err = transfer.CheckWrite()
assert.NoError(t, err) assert.NoError(t, err)
transfer.BytesReceived.Store(11) transfer.BytesReceived = 11
err = transfer.CheckRead() err = transfer.CheckRead()
if assert.Error(t, err) { if assert.Error(t, err) {
assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error()) assert.Contains(t, err.Error(), ErrReadQuotaExceeded.Error())
@@ -446,11 +442,11 @@ func TestUploadOutsideHomeRenameError(t *testing.T) {
conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{}) conn := NewBaseConnection("", ProtocolSFTP, "", "", dataprovider.User{})
transfer := BaseTransfer{ transfer := BaseTransfer{
Connection: conn, Connection: conn,
transferType: TransferUpload, transferType: TransferUpload,
Fs: vfs.NewOsFs("", filepath.Join(os.TempDir(), "home"), ""), BytesReceived: 123,
Fs: vfs.NewOsFs("", filepath.Join(os.TempDir(), "home"), ""),
} }
transfer.BytesReceived.Store(123)
fileName := filepath.Join(os.TempDir(), "_temp") fileName := filepath.Join(os.TempDir(), "_temp")
err := os.WriteFile(fileName, []byte(`data`), 0644) err := os.WriteFile(fileName, []byte(`data`), 0644)
@@ -463,10 +459,10 @@ func TestUploadOutsideHomeRenameError(t *testing.T) {
Config.TempPath = filepath.Clean(os.TempDir()) Config.TempPath = filepath.Clean(os.TempDir())
res = transfer.checkUploadOutsideHomeDir(nil) res = transfer.checkUploadOutsideHomeDir(nil)
assert.Equal(t, 0, res) assert.Equal(t, 0, res)
assert.Greater(t, transfer.BytesReceived.Load(), int64(0)) assert.Greater(t, transfer.BytesReceived, int64(0))
res = transfer.checkUploadOutsideHomeDir(os.ErrPermission) res = transfer.checkUploadOutsideHomeDir(os.ErrPermission)
assert.Equal(t, 1, res) assert.Equal(t, 1, res)
assert.Equal(t, int64(0), transfer.BytesReceived.Load()) assert.Equal(t, int64(0), transfer.BytesReceived)
assert.NoFileExists(t, fileName) assert.NoFileExists(t, fileName)
Config.TempPath = oldTempPath Config.TempPath = oldTempPath

View File

@@ -19,9 +19,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/util"
) )
type overquotaTransfer struct { type overquotaTransfer struct {

View File

@@ -21,6 +21,7 @@ import (
"path/filepath" "path/filepath"
"strconv" "strconv"
"strings" "strings"
"sync/atomic"
"testing" "testing"
"time" "time"
@@ -28,9 +29,9 @@ import (
"github.com/sftpgo/sdk" "github.com/sftpgo/sdk"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/internal/vfs" "github.com/drakkan/sftpgo/v2/vfs"
) )
func TestTransfersCheckerDiskQuota(t *testing.T) { func TestTransfersCheckerDiskQuota(t *testing.T) {
@@ -95,7 +96,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
} }
transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
"/file1", TransferUpload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{}) "/file1", TransferUpload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
transfer1.BytesReceived.Store(150) transfer1.BytesReceived = 150
err = Connections.Add(fakeConn1) err = Connections.Add(fakeConn1)
assert.NoError(t, err) assert.NoError(t, err)
// the transferschecker will do nothing if there is only one ongoing transfer // the transferschecker will do nothing if there is only one ongoing transfer
@@ -109,8 +110,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
} }
transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"), transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
"/file2", TransferUpload, 0, 0, 120, 40, true, fsUser, dataprovider.TransferQuota{}) "/file2", TransferUpload, 0, 0, 120, 40, true, fsUser, dataprovider.TransferQuota{})
transfer1.BytesReceived.Store(50) transfer1.BytesReceived = 50
transfer2.BytesReceived.Store(60) transfer2.BytesReceived = 60
err = Connections.Add(fakeConn2) err = Connections.Add(fakeConn2)
assert.NoError(t, err) assert.NoError(t, err)
@@ -121,7 +122,7 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
} }
transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"), transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file3"), filepath.Join(user.HomeDir, "file3"),
"/file3", TransferDownload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{}) "/file3", TransferDownload, 0, 0, 120, 0, true, fsUser, dataprovider.TransferQuota{})
transfer3.BytesReceived.Store(60) // this value will be ignored, this is a download transfer3.BytesReceived = 60 // this value will be ignored, this is a download
err = Connections.Add(fakeConn3) err = Connections.Add(fakeConn3)
assert.NoError(t, err) assert.NoError(t, err)
@@ -131,20 +132,20 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort) assert.Nil(t, transfer3.errAbort)
transfer1.BytesReceived.Store(80) // truncated size will be subtracted, we are not overquota transfer1.BytesReceived = 80 // truncated size will be subtracted, we are not overquota
Connections.checkTransfers() Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort) assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort) assert.Nil(t, transfer3.errAbort)
transfer1.BytesReceived.Store(120) transfer1.BytesReceived = 120
// we are now overquota // we are now overquota
// if another check is in progress nothing is done // if another check is in progress nothing is done
Connections.transfersCheckStatus.Store(true) atomic.StoreInt32(&Connections.transfersCheckStatus, 1)
Connections.checkTransfers() Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort) assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort) assert.Nil(t, transfer3.errAbort)
Connections.transfersCheckStatus.Store(false) atomic.StoreInt32(&Connections.transfersCheckStatus, 0)
Connections.checkTransfers() Connections.checkTransfers()
assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort), transfer1.errAbort) assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort), transfer1.errAbort)
@@ -171,8 +172,8 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer2.errAbort)
assert.Nil(t, transfer3.errAbort) assert.Nil(t, transfer3.errAbort)
// now check a public folder // now check a public folder
transfer1.BytesReceived.Store(0) transfer1.BytesReceived = 0
transfer2.BytesReceived.Store(0) transfer2.BytesReceived = 0
connID4 := xid.New().String() connID4 := xid.New().String()
fsFolder, err := user.GetFilesystemForPath(path.Join(vdirPath, "/file1"), connID4) fsFolder, err := user.GetFilesystemForPath(path.Join(vdirPath, "/file1"), connID4)
assert.NoError(t, err) assert.NoError(t, err)
@@ -196,12 +197,12 @@ func TestTransfersCheckerDiskQuota(t *testing.T) {
err = Connections.Add(fakeConn5) err = Connections.Add(fakeConn5)
assert.NoError(t, err) assert.NoError(t, err)
transfer4.BytesReceived.Store(50) transfer4.BytesReceived = 50
transfer5.BytesReceived.Store(40) transfer5.BytesReceived = 40
Connections.checkTransfers() Connections.checkTransfers()
assert.Nil(t, transfer4.errAbort) assert.Nil(t, transfer4.errAbort)
assert.Nil(t, transfer5.errAbort) assert.Nil(t, transfer5.errAbort)
transfer5.BytesReceived.Store(60) transfer5.BytesReceived = 60
Connections.checkTransfers() Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort) assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer2.errAbort)
@@ -285,7 +286,7 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
} }
transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), transfer1 := NewBaseTransfer(nil, conn1, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
"/file1", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100}) "/file1", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
transfer1.BytesReceived.Store(150) transfer1.BytesReceived = 150
err = Connections.Add(fakeConn1) err = Connections.Add(fakeConn1)
assert.NoError(t, err) assert.NoError(t, err)
// the transferschecker will do nothing if there is only one ongoing transfer // the transferschecker will do nothing if there is only one ongoing transfer
@@ -299,26 +300,26 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
} }
transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"), transfer2 := NewBaseTransfer(nil, conn2, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
"/file2", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100}) "/file2", TransferUpload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedTotalSize: 100})
transfer2.BytesReceived.Store(150) transfer2.BytesReceived = 150
err = Connections.Add(fakeConn2) err = Connections.Add(fakeConn2)
assert.NoError(t, err) assert.NoError(t, err)
Connections.checkTransfers() Connections.checkTransfers()
assert.Nil(t, transfer1.errAbort) assert.Nil(t, transfer1.errAbort)
assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer2.errAbort)
// now test overquota // now test overquota
transfer1.BytesReceived.Store(1024*1024 + 1) transfer1.BytesReceived = 1024*1024 + 1
transfer2.BytesReceived.Store(0) transfer2.BytesReceived = 0
Connections.checkTransfers() Connections.checkTransfers()
assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort)) assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort))
assert.Nil(t, transfer2.errAbort) assert.Nil(t, transfer2.errAbort)
transfer1.errAbort = nil transfer1.errAbort = nil
transfer1.BytesReceived.Store(1024*1024 + 1) transfer1.BytesReceived = 1024*1024 + 1
transfer2.BytesReceived.Store(1024) transfer2.BytesReceived = 1024
Connections.checkTransfers() Connections.checkTransfers()
assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort)) assert.True(t, conn1.IsQuotaExceededError(transfer1.errAbort))
assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort)) assert.True(t, conn2.IsQuotaExceededError(transfer2.errAbort))
transfer1.BytesReceived.Store(0) transfer1.BytesReceived = 0
transfer2.BytesReceived.Store(0) transfer2.BytesReceived = 0
transfer1.errAbort = nil transfer1.errAbort = nil
transfer2.errAbort = nil transfer2.errAbort = nil
@@ -336,7 +337,7 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
} }
transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"), transfer3 := NewBaseTransfer(nil, conn3, nil, filepath.Join(user.HomeDir, "file1"), filepath.Join(user.HomeDir, "file1"),
"/file1", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100}) "/file1", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
transfer3.BytesSent.Store(150) transfer3.BytesSent = 150
err = Connections.Add(fakeConn3) err = Connections.Add(fakeConn3)
assert.NoError(t, err) assert.NoError(t, err)
@@ -347,15 +348,15 @@ func TestTransferCheckerTransferQuota(t *testing.T) {
} }
transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"), transfer4 := NewBaseTransfer(nil, conn4, nil, filepath.Join(user.HomeDir, "file2"), filepath.Join(user.HomeDir, "file2"),
"/file2", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100}) "/file2", TransferDownload, 0, 0, 0, 0, true, fsUser, dataprovider.TransferQuota{AllowedDLSize: 100})
transfer4.BytesSent.Store(150) transfer4.BytesSent = 150
err = Connections.Add(fakeConn4) err = Connections.Add(fakeConn4)
assert.NoError(t, err) assert.NoError(t, err)
Connections.checkTransfers() Connections.checkTransfers()
assert.Nil(t, transfer3.errAbort) assert.Nil(t, transfer3.errAbort)
assert.Nil(t, transfer4.errAbort) assert.Nil(t, transfer4.errAbort)
transfer3.BytesSent.Store(512 * 1024) transfer3.BytesSent = 512 * 1024
transfer4.BytesSent.Store(512*1024 + 1) transfer4.BytesSent = 512*1024 + 1
Connections.checkTransfers() Connections.checkTransfers()
if assert.Error(t, transfer3.errAbort) { if assert.Error(t, transfer3.errAbort) {
assert.Contains(t, transfer3.errAbort.Error(), ErrReadQuotaExceeded.Error()) assert.Contains(t, transfer3.errAbort.Error(), ErrReadQuotaExceeded.Error())
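
transfersCheckStatus is the flag that makes overlapping checks a no-op, toggled here directly by the test. A sketch of the guard itself; using CompareAndSwap (Go 1.19 atomic.Bool) makes the skip race-free, which is a slight simplification of the Store-based code above:

    package example

    import "sync/atomic"

    var checkInProgress atomic.Bool

    // runCheck skips the periodic check when another one is still running.
    func runCheck(check func()) {
        if !checkInProgress.CompareAndSwap(false, true) {
            return // another check is in progress, nothing is done
        }
        defer checkInProgress.Store(false)
        check()
    }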

View File

@@ -24,25 +24,24 @@ import (
"strings" "strings"
"github.com/spf13/viper" "github.com/spf13/viper"
"github.com/subosito/gotenv"
"github.com/drakkan/sftpgo/v2/internal/acme" "github.com/drakkan/sftpgo/v2/acme"
"github.com/drakkan/sftpgo/v2/internal/command" "github.com/drakkan/sftpgo/v2/command"
"github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/common"
"github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/ftpd" "github.com/drakkan/sftpgo/v2/ftpd"
"github.com/drakkan/sftpgo/v2/internal/httpclient" "github.com/drakkan/sftpgo/v2/httpclient"
"github.com/drakkan/sftpgo/v2/internal/httpd" "github.com/drakkan/sftpgo/v2/httpd"
"github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/kms"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/mfa"
"github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/internal/sftpd" "github.com/drakkan/sftpgo/v2/sftpd"
"github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/smtp"
"github.com/drakkan/sftpgo/v2/internal/telemetry" "github.com/drakkan/sftpgo/v2/telemetry"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/internal/version" "github.com/drakkan/sftpgo/v2/version"
"github.com/drakkan/sftpgo/v2/internal/webdavd" "github.com/drakkan/sftpgo/v2/webdavd"
) )
const ( const (
@@ -82,26 +81,24 @@ var (
Debug: false, Debug: false,
} }
defaultWebDAVDBinding = webdavd.Binding{ defaultWebDAVDBinding = webdavd.Binding{
Address: "", Address: "",
Port: 0, Port: 0,
EnableHTTPS: false, EnableHTTPS: false,
CertificateFile: "", CertificateFile: "",
CertificateKeyFile: "", CertificateKeyFile: "",
MinTLSVersion: 12, MinTLSVersion: 12,
ClientAuthType: 0, ClientAuthType: 0,
TLSCipherSuites: nil, TLSCipherSuites: nil,
Prefix: "", Prefix: "",
ProxyAllowed: nil, ProxyAllowed: nil,
ClientIPProxyHeader: "", ClientIPProxyHeader: "",
ClientIPHeaderDepth: 0, ClientIPHeaderDepth: 0,
DisableWWWAuthHeader: false,
} }
defaultHTTPDBinding = httpd.Binding{ defaultHTTPDBinding = httpd.Binding{
Address: "", Address: "",
Port: 8080, Port: 8080,
EnableWebAdmin: true, EnableWebAdmin: true,
EnableWebClient: true, EnableWebClient: true,
EnableRESTAPI: true,
EnabledLoginMethods: 0, EnabledLoginMethods: 0,
EnableHTTPS: false, EnableHTTPS: false,
CertificateFile: "", CertificateFile: "",
@@ -116,17 +113,16 @@ var (
RenderOpenAPI: true, RenderOpenAPI: true,
WebClientIntegrations: nil, WebClientIntegrations: nil,
OIDC: httpd.OIDC{ OIDC: httpd.OIDC{
ClientID: "", ClientID: "",
ClientSecret: "", ClientSecret: "",
ConfigURL: "", ConfigURL: "",
RedirectBaseURL: "", RedirectBaseURL: "",
UsernameField: "", UsernameField: "",
RoleField: "", RoleField: "",
ImplicitRoles: false, ImplicitRoles: false,
Scopes: []string{"openid", "profile", "email"}, Scopes: []string{"openid", "profile", "email"},
CustomFields: []string{}, CustomFields: []string{},
InsecureSkipSignatureCheck: false, Debug: false,
Debug: false,
}, },
Security: httpd.SecurityConf{ Security: httpd.SecurityConf{
Enabled: false, Enabled: false,
@@ -210,7 +206,6 @@ func Init() {
MaxTotalConnections: 0, MaxTotalConnections: 0,
MaxPerHostConnections: 20, MaxPerHostConnections: 20,
WhiteListFile: "", WhiteListFile: "",
AllowSelfConnections: 0,
DefenderConfig: common.DefenderConfig{ DefenderConfig: common.DefenderConfig{
Enabled: false, Enabled: false,
Driver: common.DefenderDriverMemory, Driver: common.DefenderDriverMemory,
@@ -290,16 +285,13 @@ func Init() {
CACertificates: []string{}, CACertificates: []string{},
CARevocationLists: []string{}, CARevocationLists: []string{},
Cors: webdavd.CorsConfig{ Cors: webdavd.CorsConfig{
Enabled: false, Enabled: false,
AllowedOrigins: []string{}, AllowedOrigins: []string{},
AllowedMethods: []string{}, AllowedMethods: []string{},
AllowedHeaders: []string{}, AllowedHeaders: []string{},
ExposedHeaders: []string{}, ExposedHeaders: []string{},
AllowCredentials: false, AllowCredentials: false,
MaxAge: 0, MaxAge: 0,
OptionsPassthrough: false,
OptionsSuccessStatus: 0,
AllowPrivateNetwork: false,
}, },
Cache: webdavd.Cache{ Cache: webdavd.Cache{
Users: webdavd.UsersCacheConfig{ Users: webdavd.UsersCacheConfig{
@@ -313,23 +305,21 @@ func Init() {
}, },
}, },
ProviderConf: dataprovider.Config{ ProviderConf: dataprovider.Config{
Driver: "sqlite", Driver: "sqlite",
Name: "sftpgo.db", Name: "sftpgo.db",
Host: "", Host: "",
Port: 0, Port: 0,
Username: "", Username: "",
Password: "", Password: "",
ConnectionString: "", ConnectionString: "",
SQLTablesPrefix: "", SQLTablesPrefix: "",
SSLMode: 0, SSLMode: 0,
DisableSNI: false, RootCert: "",
TargetSessionAttrs: "", ClientCert: "",
RootCert: "", ClientKey: "",
ClientCert: "", TrackQuota: 2,
ClientKey: "", PoolSize: 0,
TrackQuota: 2, UsersBaseDir: "",
PoolSize: 0,
UsersBaseDir: "",
Actions: dataprovider.ObjectsActions{ Actions: dataprovider.ObjectsActions{
ExecuteOn: []string{}, ExecuteOn: []string{},
ExecuteFor: []string{}, ExecuteFor: []string{},
@@ -337,6 +327,7 @@ func Init() {
}, },
ExternalAuthHook: "", ExternalAuthHook: "",
ExternalAuthScope: 0, ExternalAuthScope: 0,
CredentialsPath: "credentials",
PreLoginHook: "", PreLoginHook: "",
PostLoginHook: "", PostLoginHook: "",
PostLoginScope: 0, PostLoginScope: 0,
@@ -367,12 +358,12 @@ func Init() {
CreateDefaultAdmin: false, CreateDefaultAdmin: false,
NamingRules: 1, NamingRules: 1,
IsShared: 0, IsShared: 0,
Node: dataprovider.NodeConfig{ BackupsPath: "backups",
Host: "", AutoBackup: dataprovider.AutoBackup{
Port: 0, Enabled: true,
Proto: "http", Hour: "0",
DayOfWeek: "*",
}, },
BackupsPath: "backups",
}, },
HTTPDConfig: httpd.Conf{ HTTPDConfig: httpd.Conf{
Bindings: []httpd.Binding{defaultHTTPDBinding}, Bindings: []httpd.Binding{defaultHTTPDBinding},
@@ -388,16 +379,13 @@ func Init() {
TokenValidation: 0, TokenValidation: 0,
MaxUploadFileSize: 1048576000, MaxUploadFileSize: 1048576000,
Cors: httpd.CorsConfig{ Cors: httpd.CorsConfig{
Enabled: false, Enabled: false,
AllowedOrigins: []string{}, AllowedOrigins: []string{},
AllowedMethods: []string{}, AllowedMethods: []string{},
AllowedHeaders: []string{}, AllowedHeaders: []string{},
ExposedHeaders: []string{}, ExposedHeaders: []string{},
AllowCredentials: false, AllowCredentials: false,
MaxAge: 0, MaxAge: 0,
OptionsPassthrough: false,
OptionsSuccessStatus: 0,
AllowPrivateNetwork: false,
}, },
Setup: httpd.SetupConfig{ Setup: httpd.SetupConfig{
InstallationCode: "", InstallationCode: "",
@@ -428,7 +416,7 @@ func Init() {
}, },
}, },
MFAConfig: mfa.Config{ MFAConfig: mfa.Config{
TOTP: []mfa.TOTPConfig{defaultTOTP}, TOTP: nil,
}, },
TelemetryConfig: telemetry.Conf{ TelemetryConfig: telemetry.Conf{
BindPort: 0, BindPort: 0,
@@ -636,64 +624,6 @@ func setConfigFile(configDir, configFile string) {
viper.SetConfigFile(configFile) viper.SetConfigFile(configFile)
} }
// readEnvFiles reads files inside the "env.d" directory relative to configDir
// and then exports the valid variables into environment variables if they do
// not exist
func readEnvFiles(configDir string) {
envd := filepath.Join(configDir, "env.d")
entries, err := os.ReadDir(envd)
if err != nil {
logger.Info(logSender, "", "unable to read env files from %q: %v", envd, err)
return
}
for _, entry := range entries {
info, err := entry.Info()
if err == nil && info.Mode().IsRegular() {
envFile := filepath.Join(envd, entry.Name())
err = gotenv.Load(envFile)
if err != nil {
logger.Error(logSender, "", "unable to load env vars from file %q, err: %v", envFile, err)
} else {
logger.Info(logSender, "", "set env vars from file %q", envFile)
}
}
}
}
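
readEnvFiles feeds every regular file under env.d to gotenv.Load, which sets a variable only when it is not already present in the environment. A trimmed, self-contained sketch of the same loop:

    package example

    import (
        "log"
        "os"
        "path/filepath"

        "github.com/subosito/gotenv"
    )

    // loadEnvDir mirrors readEnvFiles above: every regular file in dir is
    // fed to gotenv.Load, which never overrides variables already set.
    func loadEnvDir(dir string) {
        entries, err := os.ReadDir(dir)
        if err != nil {
            log.Printf("unable to read env files from %q: %v", dir, err)
            return
        }
        for _, entry := range entries {
            info, err := entry.Info()
            if err != nil || !info.Mode().IsRegular() {
                continue
            }
            if err := gotenv.Load(filepath.Join(dir, entry.Name())); err != nil {
                log.Printf("unable to load env vars from %q: %v", entry.Name(), err)
            }
        }
    }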
func checkOverrideDefaultSettings() {
// for slices we need to set the defaults to nil if the key is set in the config file,
// otherwise the values are merged and not replaced as expected
rateLimiters := viper.Get("common.rate_limiters")
if val, ok := rateLimiters.([]any); ok {
if len(val) > 0 {
if rl, ok := val[0].(map[string]any); ok {
if _, ok := rl["protocols"]; ok {
globalConf.Common.RateLimitersConfig[0].Protocols = nil
}
}
}
}
httpdBindings := viper.Get("httpd.bindings")
if val, ok := httpdBindings.([]any); ok {
if len(val) > 0 {
if binding, ok := val[0].(map[string]any); ok {
if val, ok := binding["oidc"]; ok {
if oidc, ok := val.(map[string]any); ok {
if _, ok := oidc["scopes"]; ok {
globalConf.HTTPDConfig.Bindings[0].OIDC.Scopes = nil
}
}
}
}
}
}
if util.Contains(viper.AllKeys(), "mfa.totp") {
globalConf.MFAConfig.TOTP = nil
}
}
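
checkOverrideDefaultSettings works around a viper quirk: for slice values, defaults are merged with the configuration file instead of being replaced, so the defaults must be cleared whenever the key appears in the file. A sketch of the key test used above:

    package example

    import "github.com/spf13/viper"

    // keyInConfig reports whether key appears in the loaded configuration,
    // mirroring util.Contains(viper.AllKeys(), ...) in the code above.
    func keyInConfig(key string) bool {
        for _, k := range viper.AllKeys() {
            if k == key {
                return true
            }
        }
        return false
    }

    // Usage: if the file sets mfa.totp, drop the in-memory default so that
    // viper.Unmarshal replaces the slice instead of merging into it.
    //
    //    if keyInConfig("mfa.totp") {
    //        globalConf.MFAConfig.TOTP = nil
    //    }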
// LoadConfig loads the configuration // LoadConfig loads the configuration
// configDir will be added to the configuration search paths. // configDir will be added to the configuration search paths.
// The search path contains by default the current directory and on linux it contains // The search path contains by default the current directory and on linux it contains
@@ -701,7 +631,6 @@ func checkOverrideDefaultSettings() {
// configFile is an absolute or relative path (to the config dir) to the configuration file. // configFile is an absolute or relative path (to the config dir) to the configuration file.
func LoadConfig(configDir, configFile string) error { func LoadConfig(configDir, configFile string) error {
var err error var err error
readEnvFiles(configDir)
viper.AddConfigPath(configDir) viper.AddConfigPath(configDir)
setViperAdditionalConfigPaths() setViperAdditionalConfigPaths()
viper.AddConfigPath(".") viper.AddConfigPath(".")
@@ -717,8 +646,8 @@ func LoadConfig(configDir, configFile string) error {
logger.Warn(logSender, "", "error loading configuration file: %v", err) logger.Warn(logSender, "", "error loading configuration file: %v", err)
logger.WarnToConsole("error loading configuration file: %v", err) logger.WarnToConsole("error loading configuration file: %v", err)
} }
globalConf.MFAConfig.TOTP = []mfa.TOTPConfig{defaultTOTP}
} }
checkOverrideDefaultSettings()
err = viper.Unmarshal(&globalConf) err = viper.Unmarshal(&globalConf)
if err != nil { if err != nil {
logger.Warn(logSender, "", "error parsing configuration file: %v", err) logger.Warn(logSender, "", "error parsing configuration file: %v", err)
@@ -780,6 +709,12 @@ func resetInvalidConfigs() {
logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn) logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn)
logger.WarnToConsole("Non-fatal configuration error: %v", warn) logger.WarnToConsole("Non-fatal configuration error: %v", warn)
} }
if globalConf.ProviderConf.CredentialsPath == "" {
warn := "invalid credentials path, reset to \"credentials\""
globalConf.ProviderConf.CredentialsPath = "credentials"
logger.Warn(logSender, "", "Non-fatal configuration error: %v", warn)
logger.WarnToConsole("Non-fatal configuration error: %v", warn)
}
if globalConf.Common.DefenderConfig.Enabled && globalConf.Common.DefenderConfig.Driver == common.DefenderDriverProvider { if globalConf.Common.DefenderConfig.Enabled && globalConf.Common.DefenderConfig.Driver == common.DefenderDriverProvider {
if !globalConf.ProviderConf.IsDefenderSupported() { if !globalConf.ProviderConf.IsDefenderSupported() {
warn := fmt.Sprintf("provider based defender is not supported with data provider %#v, "+ warn := fmt.Sprintf("provider based defender is not supported with data provider %#v, "+
@@ -1264,12 +1199,6 @@ func getWebDAVDBindingFromEnv(idx int) {
isSet = true isSet = true
} }
enableHTTPS, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__ENABLE_HTTPS", idx))
if ok {
binding.EnableHTTPS = enableHTTPS
isSet = true
}
certificateFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CERTIFICATE_FILE", idx)) certificateFile, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__CERTIFICATE_FILE", idx))
if ok { if ok {
binding.CertificateFile = certificateFile binding.CertificateFile = certificateFile
@@ -1282,6 +1211,12 @@ func getWebDAVDBindingFromEnv(idx int) {
isSet = true isSet = true
} }
enableHTTPS, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__ENABLE_HTTPS", idx))
if ok {
binding.EnableHTTPS = enableHTTPS
isSet = true
}
tlsVer, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__MIN_TLS_VERSION", idx)) tlsVer, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__MIN_TLS_VERSION", idx))
if ok { if ok {
binding.MinTLSVersion = int(tlsVer) binding.MinTLSVersion = int(tlsVer)
@@ -1300,19 +1235,13 @@ func getWebDAVDBindingFromEnv(idx int) {
isSet = true isSet = true
} }
prefix, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PREFIX", idx))
if ok {
binding.Prefix = prefix
isSet = true
}
if getWebDAVDBindingProxyConfigsFromEnv(idx, &binding) { if getWebDAVDBindingProxyConfigsFromEnv(idx, &binding) {
isSet = true isSet = true
} }
disableWWWAuth, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__DISABLE_WWW_AUTH_HEADER", idx)) prefix, ok := os.LookupEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__PREFIX", idx))
if ok { if ok {
binding.DisableWWWAuthHeader = disableWWWAuth binding.Prefix = prefix
isSet = true isSet = true
} }
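
The binding helpers read one environment variable per field, using the double-underscore index convention visible in the names above. A sketch of a lookupBoolFromEnv-style helper; the real helper's exact behavior (trimming, accepted spellings) is assumed here:

    package example

    import (
        "fmt"
        "os"
        "strconv"
    )

    // lookupBoolFromEnv reads an indexed binding variable such as
    // SFTPGO_WEBDAVD__BINDINGS__0__ENABLE_HTTPS and reports whether it
    // was both set and parseable as a boolean.
    func lookupBoolFromEnv(key string) (bool, bool) {
        v, ok := os.LookupEnv(key)
        if !ok {
            return false, false
        }
        b, err := strconv.ParseBool(v)
        if err != nil {
            return false, false
        }
        return b, true
    }

    func main() {
        idx := 0
        if enableHTTPS, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_WEBDAVD__BINDINGS__%v__ENABLE_HTTPS", idx)); ok {
            fmt.Println("enable_https overridden:", enableHTTPS)
        }
    }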
@@ -1521,12 +1450,6 @@ func getHTTPDOIDCFromEnv(idx int) (httpd.OIDC, bool) {
isSet = true isSet = true
} }
skipSignatureCheck, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__INSECURE_SKIP_SIGNATURE_CHECK", idx))
if ok {
result.InsecureSkipSignatureCheck = skipSignatureCheck
isSet = true
}
debug, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__DEBUG", idx)) debug, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__OIDC__DEBUG", idx))
if ok { if ok {
result.Debug = debug result.Debug = debug
@@ -1592,7 +1515,6 @@ func getHTTPDUIBrandingFromEnv(prefix string, branding httpd.UIBranding) (httpd.
branding.ExtraCSS = extraCSS branding.ExtraCSS = extraCSS
isSet = true isSet = true
} }
return branding, isSet return branding, isSet
} }
@@ -1760,12 +1682,6 @@ func getHTTPDBindingFromEnv(idx int) { //nolint:gocyclo
isSet = true isSet = true
} }
enableRESTAPI, ok := lookupBoolFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLE_REST_API", idx))
if ok {
binding.EnableRESTAPI = enableRESTAPI
isSet = true
}
enabledLoginMethods, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLED_LOGIN_METHODS", idx)) enabledLoginMethods, ok := lookupIntFromEnv(fmt.Sprintf("SFTPGO_HTTPD__BINDINGS__%v__ENABLED_LOGIN_METHODS", idx))
if ok { if ok {
binding.EnabledLoginMethods = int(enabledLoginMethods) binding.EnabledLoginMethods = int(enabledLoginMethods)
@@ -1930,7 +1846,6 @@ func setViperDefaults() {
viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections) viper.SetDefault("common.max_total_connections", globalConf.Common.MaxTotalConnections)
viper.SetDefault("common.max_per_host_connections", globalConf.Common.MaxPerHostConnections) viper.SetDefault("common.max_per_host_connections", globalConf.Common.MaxPerHostConnections)
viper.SetDefault("common.whitelist_file", globalConf.Common.WhiteListFile) viper.SetDefault("common.whitelist_file", globalConf.Common.WhiteListFile)
viper.SetDefault("common.allow_self_connections", globalConf.Common.AllowSelfConnections)
viper.SetDefault("common.defender.enabled", globalConf.Common.DefenderConfig.Enabled) viper.SetDefault("common.defender.enabled", globalConf.Common.DefenderConfig.Enabled)
viper.SetDefault("common.defender.driver", globalConf.Common.DefenderConfig.Driver) viper.SetDefault("common.defender.driver", globalConf.Common.DefenderConfig.Driver)
viper.SetDefault("common.defender.ban_time", globalConf.Common.DefenderConfig.BanTime) viper.SetDefault("common.defender.ban_time", globalConf.Common.DefenderConfig.BanTime)
@@ -1995,9 +1910,6 @@ func setViperDefaults() {
viper.SetDefault("webdavd.cors.allowed_headers", globalConf.WebDAVD.Cors.AllowedHeaders) viper.SetDefault("webdavd.cors.allowed_headers", globalConf.WebDAVD.Cors.AllowedHeaders)
viper.SetDefault("webdavd.cors.exposed_headers", globalConf.WebDAVD.Cors.ExposedHeaders) viper.SetDefault("webdavd.cors.exposed_headers", globalConf.WebDAVD.Cors.ExposedHeaders)
viper.SetDefault("webdavd.cors.allow_credentials", globalConf.WebDAVD.Cors.AllowCredentials) viper.SetDefault("webdavd.cors.allow_credentials", globalConf.WebDAVD.Cors.AllowCredentials)
viper.SetDefault("webdavd.cors.options_passthrough", globalConf.WebDAVD.Cors.OptionsPassthrough)
viper.SetDefault("webdavd.cors.options_success_status", globalConf.WebDAVD.Cors.OptionsSuccessStatus)
viper.SetDefault("webdavd.cors.allow_private_network", globalConf.WebDAVD.Cors.AllowPrivateNetwork)
viper.SetDefault("webdavd.cors.max_age", globalConf.WebDAVD.Cors.MaxAge) viper.SetDefault("webdavd.cors.max_age", globalConf.WebDAVD.Cors.MaxAge)
viper.SetDefault("webdavd.cache.users.expiration_time", globalConf.WebDAVD.Cache.Users.ExpirationTime) viper.SetDefault("webdavd.cache.users.expiration_time", globalConf.WebDAVD.Cache.Users.ExpirationTime)
viper.SetDefault("webdavd.cache.users.max_size", globalConf.WebDAVD.Cache.Users.MaxSize) viper.SetDefault("webdavd.cache.users.max_size", globalConf.WebDAVD.Cache.Users.MaxSize)
@@ -2010,8 +1922,6 @@ func setViperDefaults() {
viper.SetDefault("data_provider.username", globalConf.ProviderConf.Username) viper.SetDefault("data_provider.username", globalConf.ProviderConf.Username)
viper.SetDefault("data_provider.password", globalConf.ProviderConf.Password) viper.SetDefault("data_provider.password", globalConf.ProviderConf.Password)
viper.SetDefault("data_provider.sslmode", globalConf.ProviderConf.SSLMode) viper.SetDefault("data_provider.sslmode", globalConf.ProviderConf.SSLMode)
viper.SetDefault("data_provider.disable_sni", globalConf.ProviderConf.DisableSNI)
viper.SetDefault("data_provider.target_session_attrs", globalConf.ProviderConf.TargetSessionAttrs)
viper.SetDefault("data_provider.root_cert", globalConf.ProviderConf.RootCert) viper.SetDefault("data_provider.root_cert", globalConf.ProviderConf.RootCert)
viper.SetDefault("data_provider.client_cert", globalConf.ProviderConf.ClientCert) viper.SetDefault("data_provider.client_cert", globalConf.ProviderConf.ClientCert)
viper.SetDefault("data_provider.client_key", globalConf.ProviderConf.ClientKey) viper.SetDefault("data_provider.client_key", globalConf.ProviderConf.ClientKey)
@@ -2025,6 +1935,7 @@ func setViperDefaults() {
viper.SetDefault("data_provider.actions.hook", globalConf.ProviderConf.Actions.Hook) viper.SetDefault("data_provider.actions.hook", globalConf.ProviderConf.Actions.Hook)
viper.SetDefault("data_provider.external_auth_hook", globalConf.ProviderConf.ExternalAuthHook) viper.SetDefault("data_provider.external_auth_hook", globalConf.ProviderConf.ExternalAuthHook)
viper.SetDefault("data_provider.external_auth_scope", globalConf.ProviderConf.ExternalAuthScope) viper.SetDefault("data_provider.external_auth_scope", globalConf.ProviderConf.ExternalAuthScope)
viper.SetDefault("data_provider.credentials_path", globalConf.ProviderConf.CredentialsPath)
viper.SetDefault("data_provider.pre_login_hook", globalConf.ProviderConf.PreLoginHook) viper.SetDefault("data_provider.pre_login_hook", globalConf.ProviderConf.PreLoginHook)
viper.SetDefault("data_provider.post_login_hook", globalConf.ProviderConf.PostLoginHook) viper.SetDefault("data_provider.post_login_hook", globalConf.ProviderConf.PostLoginHook)
viper.SetDefault("data_provider.post_login_scope", globalConf.ProviderConf.PostLoginScope) viper.SetDefault("data_provider.post_login_scope", globalConf.ProviderConf.PostLoginScope)
@@ -2043,10 +1954,10 @@ func setViperDefaults() {
viper.SetDefault("data_provider.create_default_admin", globalConf.ProviderConf.CreateDefaultAdmin) viper.SetDefault("data_provider.create_default_admin", globalConf.ProviderConf.CreateDefaultAdmin)
viper.SetDefault("data_provider.naming_rules", globalConf.ProviderConf.NamingRules) viper.SetDefault("data_provider.naming_rules", globalConf.ProviderConf.NamingRules)
viper.SetDefault("data_provider.is_shared", globalConf.ProviderConf.IsShared) viper.SetDefault("data_provider.is_shared", globalConf.ProviderConf.IsShared)
viper.SetDefault("data_provider.node.host", globalConf.ProviderConf.Node.Host)
viper.SetDefault("data_provider.node.port", globalConf.ProviderConf.Node.Port)
viper.SetDefault("data_provider.node.proto", globalConf.ProviderConf.Node.Proto)
viper.SetDefault("data_provider.backups_path", globalConf.ProviderConf.BackupsPath) viper.SetDefault("data_provider.backups_path", globalConf.ProviderConf.BackupsPath)
viper.SetDefault("data_provider.auto_backup.enabled", globalConf.ProviderConf.AutoBackup.Enabled)
viper.SetDefault("data_provider.auto_backup.hour", globalConf.ProviderConf.AutoBackup.Hour)
viper.SetDefault("data_provider.auto_backup.day_of_week", globalConf.ProviderConf.AutoBackup.DayOfWeek)
viper.SetDefault("httpd.templates_path", globalConf.HTTPDConfig.TemplatesPath) viper.SetDefault("httpd.templates_path", globalConf.HTTPDConfig.TemplatesPath)
viper.SetDefault("httpd.static_files_path", globalConf.HTTPDConfig.StaticFilesPath) viper.SetDefault("httpd.static_files_path", globalConf.HTTPDConfig.StaticFilesPath)
viper.SetDefault("httpd.openapi_path", globalConf.HTTPDConfig.OpenAPIPath) viper.SetDefault("httpd.openapi_path", globalConf.HTTPDConfig.OpenAPIPath)
@@ -2065,9 +1976,6 @@ func setViperDefaults() {
viper.SetDefault("httpd.cors.exposed_headers", globalConf.HTTPDConfig.Cors.ExposedHeaders) viper.SetDefault("httpd.cors.exposed_headers", globalConf.HTTPDConfig.Cors.ExposedHeaders)
viper.SetDefault("httpd.cors.allow_credentials", globalConf.HTTPDConfig.Cors.AllowCredentials) viper.SetDefault("httpd.cors.allow_credentials", globalConf.HTTPDConfig.Cors.AllowCredentials)
viper.SetDefault("httpd.cors.max_age", globalConf.HTTPDConfig.Cors.MaxAge) viper.SetDefault("httpd.cors.max_age", globalConf.HTTPDConfig.Cors.MaxAge)
viper.SetDefault("httpd.cors.options_passthrough", globalConf.HTTPDConfig.Cors.OptionsPassthrough)
viper.SetDefault("httpd.cors.options_success_status", globalConf.HTTPDConfig.Cors.OptionsSuccessStatus)
viper.SetDefault("httpd.cors.allow_private_network", globalConf.HTTPDConfig.Cors.AllowPrivateNetwork)
viper.SetDefault("httpd.setup.installation_code", globalConf.HTTPDConfig.Setup.InstallationCode) viper.SetDefault("httpd.setup.installation_code", globalConf.HTTPDConfig.Setup.InstallationCode)
viper.SetDefault("httpd.setup.installation_code_hint", globalConf.HTTPDConfig.Setup.InstallationCodeHint) viper.SetDefault("httpd.setup.installation_code_hint", globalConf.HTTPDConfig.Setup.InstallationCodeHint)
viper.SetDefault("httpd.hide_support_link", globalConf.HTTPDConfig.HideSupportLink) viper.SetDefault("httpd.hide_support_link", globalConf.HTTPDConfig.HideSupportLink)


@@ -26,28 +26,24 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/drakkan/sftpgo/v2/internal/command" "github.com/drakkan/sftpgo/v2/command"
"github.com/drakkan/sftpgo/v2/internal/common" "github.com/drakkan/sftpgo/v2/common"
"github.com/drakkan/sftpgo/v2/internal/config" "github.com/drakkan/sftpgo/v2/config"
"github.com/drakkan/sftpgo/v2/internal/dataprovider" "github.com/drakkan/sftpgo/v2/dataprovider"
"github.com/drakkan/sftpgo/v2/internal/ftpd" "github.com/drakkan/sftpgo/v2/ftpd"
"github.com/drakkan/sftpgo/v2/internal/httpclient" "github.com/drakkan/sftpgo/v2/httpclient"
"github.com/drakkan/sftpgo/v2/internal/httpd" "github.com/drakkan/sftpgo/v2/httpd"
"github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/mfa"
"github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/internal/sftpd" "github.com/drakkan/sftpgo/v2/sftpd"
"github.com/drakkan/sftpgo/v2/internal/smtp" "github.com/drakkan/sftpgo/v2/smtp"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/util"
) )
const ( const (
tempConfigName = "temp" tempConfigName = "temp"
) )
var (
configDir = filepath.Join(".", "..", "..")
)
func reset() { func reset() {
viper.Reset() viper.Reset()
config.Init() config.Init()
@@ -56,6 +52,7 @@ func reset() {
func TestLoadConfigTest(t *testing.T) {
reset()
+ configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
assert.NotEqual(t, httpd.Conf{}, config.GetHTTPConfig())
@@ -84,41 +81,15 @@ func TestLoadConfigFileNotFound(t *testing.T) {
viper.SetConfigName("configfile")
err := config.LoadConfig(os.TempDir(), "")
- require.NoError(t, err)
+ assert.NoError(t, err)
mfaConf := config.GetMFAConfig()
- require.Len(t, mfaConf.TOTP, 1)
+ assert.Len(t, mfaConf.TOTP, 1)
- require.Len(t, config.GetCommonConfig().RateLimitersConfig, 1)
- require.Len(t, config.GetCommonConfig().RateLimitersConfig[0].Protocols, 4)
- require.Len(t, config.GetHTTPDConfig().Bindings, 1)
- require.Len(t, config.GetHTTPDConfig().Bindings[0].OIDC.Scopes, 3)
- }
- func TestReadEnvFiles(t *testing.T) {
- reset()
- envd := filepath.Join(configDir, "env.d")
- err := os.Mkdir(envd, os.ModePerm)
- assert.NoError(t, err)
- err = os.WriteFile(filepath.Join(envd, "env1"), []byte("SFTPGO_SFTPD__MAX_AUTH_TRIES = 10"), 0666)
- assert.NoError(t, err)
- err = os.WriteFile(filepath.Join(envd, "env2"), []byte(`{"invalid env": "value"}`), 0666)
- assert.NoError(t, err)
- err = config.LoadConfig(configDir, "")
- assert.NoError(t, err)
- assert.Equal(t, 10, config.GetSFTPDConfig().MaxAuthTries)
- _, ok := os.LookupEnv("SFTPGO_SFTPD__MAX_AUTH_TRIES")
- assert.True(t, ok)
- err = os.Unsetenv("SFTPGO_SFTPD__MAX_AUTH_TRIES")
- assert.NoError(t, err)
- os.RemoveAll(envd)
}
func TestEmptyBanner(t *testing.T) {
reset()
+ configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -155,6 +126,7 @@ func TestEmptyBanner(t *testing.T) {
func TestEnabledSSHCommands(t *testing.T) {
reset()
+ configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -183,6 +155,7 @@ func TestEnabledSSHCommands(t *testing.T) {
func TestInvalidUploadMode(t *testing.T) {
reset()
+ configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -205,6 +178,7 @@ func TestInvalidUploadMode(t *testing.T) {
func TestInvalidExternalAuthScope(t *testing.T) {
reset()
+ configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -224,9 +198,33 @@ func TestInvalidExternalAuthScope(t *testing.T) {
assert.NoError(t, err)
}
+ func TestInvalidCredentialsPath(t *testing.T) {
+ reset()
+ configDir := ".."
+ confName := tempConfigName + ".json"
+ configFilePath := filepath.Join(configDir, confName)
+ err := config.LoadConfig(configDir, "")
+ assert.NoError(t, err)
+ providerConf := config.GetProviderConf()
+ providerConf.CredentialsPath = ""
+ c := make(map[string]dataprovider.Config)
+ c["data_provider"] = providerConf
+ jsonConf, err := json.Marshal(c)
+ assert.NoError(t, err)
+ err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
+ assert.NoError(t, err)
+ err = config.LoadConfig(configDir, confName)
+ assert.NoError(t, err)
+ assert.Equal(t, "credentials", config.GetProviderConf().CredentialsPath)
+ err = os.Remove(configFilePath)
+ assert.NoError(t, err)
+ }
func TestInvalidProxyProtocol(t *testing.T) {
reset()
+ configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -249,6 +247,7 @@ func TestInvalidProxyProtocol(t *testing.T) {
func TestInvalidUsersBaseDir(t *testing.T) {
reset()
+ configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -271,6 +270,7 @@ func TestInvalidUsersBaseDir(t *testing.T) {
func TestInvalidInstallationHint(t *testing.T) {
reset()
+ configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -301,6 +301,7 @@ func TestDefenderProviderDriver(t *testing.T) {
}
reset()
+ configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
providerConf := config.GetProviderConf()
@@ -380,6 +381,7 @@ func TestSetGetConfig(t *testing.T) {
func TestServiceToStart(t *testing.T) {
reset()
+ configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
assert.True(t, config.HasServicesToStart())
@@ -413,6 +415,7 @@ func TestSSHCommandsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_SFTPD__ENABLED_SSH_COMMANDS")
})
+ configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
@@ -433,6 +436,7 @@ func TestSMTPFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_SMTP__PORT")
})
+ configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
smtpConfig := config.GetSMTPConfig()
@@ -454,6 +458,7 @@ func TestMFAFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_MFA__TOTP__1__ALGO")
})
+ configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
mfaConf := config.GetMFAConfig()
@@ -469,6 +474,7 @@ func TestMFAFromEnv(t *testing.T) {
func TestDisabledMFAConfig(t *testing.T) {
reset()
+ configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
@@ -493,81 +499,6 @@ func TestDisabledMFAConfig(t *testing.T) {
assert.NoError(t, err)
}
- func TestOverrideSliceValues(t *testing.T) {
- reset()
- confName := tempConfigName + ".json"
- configFilePath := filepath.Join(configDir, confName)
- c := make(map[string]any)
- c["common"] = common.Configuration{
- RateLimitersConfig: []common.RateLimiterConfig{
- {
- Type: 1,
- Protocols: []string{"HTTP"},
- },
- },
- }
- jsonConf, err := json.Marshal(c)
- assert.NoError(t, err)
- err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
- assert.NoError(t, err)
- err = config.LoadConfig(configDir, confName)
- assert.NoError(t, err)
- require.Len(t, config.GetCommonConfig().RateLimitersConfig, 1)
- require.Equal(t, []string{"HTTP"}, config.GetCommonConfig().RateLimitersConfig[0].Protocols)
- reset()
- // empty ratelimiters, default value should be used
- c["common"] = common.Configuration{}
- jsonConf, err = json.Marshal(c)
- assert.NoError(t, err)
- err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
- assert.NoError(t, err)
- err = config.LoadConfig(configDir, confName)
- assert.NoError(t, err)
- require.Len(t, config.GetCommonConfig().RateLimitersConfig, 1)
- rl := config.GetCommonConfig().RateLimitersConfig[0]
- require.Equal(t, []string{"SSH", "FTP", "DAV", "HTTP"}, rl.Protocols)
- require.Equal(t, int64(1000), rl.Period)
- reset()
- c = make(map[string]any)
- c["httpd"] = httpd.Conf{
- Bindings: []httpd.Binding{
- {
- OIDC: httpd.OIDC{
- Scopes: []string{"scope1"},
- },
- },
- },
- }
- jsonConf, err = json.Marshal(c)
- assert.NoError(t, err)
- err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
- assert.NoError(t, err)
- err = config.LoadConfig(configDir, confName)
- assert.NoError(t, err)
- require.Len(t, config.GetHTTPDConfig().Bindings, 1)
- require.Equal(t, []string{"scope1"}, config.GetHTTPDConfig().Bindings[0].OIDC.Scopes)
- reset()
- c = make(map[string]any)
- c["httpd"] = httpd.Conf{
- Bindings: []httpd.Binding{},
- }
- jsonConf, err = json.Marshal(c)
- assert.NoError(t, err)
- err = os.WriteFile(configFilePath, jsonConf, os.ModePerm)
- assert.NoError(t, err)
- err = config.LoadConfig(configDir, confName)
- assert.NoError(t, err)
- require.Len(t, config.GetHTTPDConfig().Bindings, 1)
- require.Equal(t, []string{"openid", "profile", "email"}, config.GetHTTPDConfig().Bindings[0].OIDC.Scopes)
- }
func TestFTPDOverridesFromEnv(t *testing.T) {
reset()
@@ -583,6 +514,7 @@ func TestFTPDOverridesFromEnv(t *testing.T) {
}
t.Cleanup(cleanup)
+ configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
ftpdConf := config.GetFTPDConfig()
@@ -643,6 +575,7 @@ func TestHTTPDSubObjectsFromEnv(t *testing.T) {
}
t.Cleanup(cleanup)
+ configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
httpdConf := config.GetHTTPDConfig()
@@ -718,6 +651,7 @@ func TestPluginsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_PLUGINS__0__AUTH_OPTIONS__SCOPE")
})
+ configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
pluginsConf := config.GetPluginsConfig()
@@ -816,6 +750,7 @@ func TestRateLimitersFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_COMMON__RATE_LIMITERS__8__ALLOW_LIST")
})
+ configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
limiters := config.GetCommonConfig().RateLimitersConfig
@@ -868,6 +803,7 @@ func TestSFTPDBindingsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_SFTPD__BINDINGS__3__PORT")
})
+ configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
bindings := config.GetSFTPDConfig().Bindings
@@ -883,6 +819,7 @@ func TestSFTPDBindingsFromEnv(t *testing.T) {
func TestCommandsFromEnv(t *testing.T) {
reset()
+ configDir := ".."
confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "")
@@ -989,6 +926,7 @@ func TestFTPDBindingsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_FTPD__BINDINGS__9__CERTIFICATE_KEY_FILE")
})
+ configDir := ".."
err := config.LoadConfig(configDir, "")
assert.NoError(t, err)
bindings := config.GetFTPDConfig().Bindings
@@ -1045,7 +983,6 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__PREFIX", "/dav2") os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__PREFIX", "/dav2")
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_FILE", "webdav.crt") os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_FILE", "webdav.crt")
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_KEY_FILE", "webdav.key") os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_KEY_FILE", "webdav.key")
os.Setenv("SFTPGO_WEBDAVD__BINDINGS__2__DISABLE_WWW_AUTH_HEADER", "1")
t.Cleanup(func() { t.Cleanup(func() {
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__ADDRESS") os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__ADDRESS")
@@ -1063,9 +1000,9 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__PREFIX") os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__PREFIX")
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_FILE") os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_FILE")
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_KEY_FILE") os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__CERTIFICATE_KEY_FILE")
os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__2__DISABLE_WWW_AUTH_HEADER")
}) })
configDir := ".."
err := config.LoadConfig(configDir, "") err := config.LoadConfig(configDir, "")
assert.NoError(t, err) assert.NoError(t, err)
bindings := config.GetWebDAVDConfig().Bindings bindings := config.GetWebDAVDConfig().Bindings
@@ -1077,7 +1014,6 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
require.Len(t, bindings[0].TLSCipherSuites, 0) require.Len(t, bindings[0].TLSCipherSuites, 0)
require.Empty(t, bindings[0].Prefix) require.Empty(t, bindings[0].Prefix)
require.Equal(t, 0, bindings[0].ClientIPHeaderDepth) require.Equal(t, 0, bindings[0].ClientIPHeaderDepth)
require.False(t, bindings[0].DisableWWWAuthHeader)
require.Equal(t, 8000, bindings[1].Port) require.Equal(t, 8000, bindings[1].Port)
require.Equal(t, "127.0.0.1", bindings[1].Address) require.Equal(t, "127.0.0.1", bindings[1].Address)
require.False(t, bindings[1].EnableHTTPS) require.False(t, bindings[1].EnableHTTPS)
@@ -1089,7 +1025,6 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
require.Equal(t, "X-Forwarded-For", bindings[1].ClientIPProxyHeader) require.Equal(t, "X-Forwarded-For", bindings[1].ClientIPProxyHeader)
require.Equal(t, 2, bindings[1].ClientIPHeaderDepth) require.Equal(t, 2, bindings[1].ClientIPHeaderDepth)
require.Empty(t, bindings[1].Prefix) require.Empty(t, bindings[1].Prefix)
require.False(t, bindings[1].DisableWWWAuthHeader)
require.Equal(t, 9000, bindings[2].Port) require.Equal(t, 9000, bindings[2].Port)
require.Equal(t, "127.0.1.1", bindings[2].Address) require.Equal(t, "127.0.1.1", bindings[2].Address)
require.True(t, bindings[2].EnableHTTPS) require.True(t, bindings[2].EnableHTTPS)
@@ -1100,7 +1035,6 @@ func TestWebDAVBindingsFromEnv(t *testing.T) {
require.Equal(t, "webdav.crt", bindings[2].CertificateFile) require.Equal(t, "webdav.crt", bindings[2].CertificateFile)
require.Equal(t, "webdav.key", bindings[2].CertificateKeyFile) require.Equal(t, "webdav.key", bindings[2].CertificateKeyFile)
require.Equal(t, 0, bindings[2].ClientIPHeaderDepth) require.Equal(t, 0, bindings[2].ClientIPHeaderDepth)
require.True(t, bindings[2].DisableWWWAuthHeader)
} }
func TestHTTPDBindingsFromEnv(t *testing.T) { func TestHTTPDBindingsFromEnv(t *testing.T) {
@@ -1121,7 +1055,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__PORT", "9000") os.Setenv("SFTPGO_HTTPD__BINDINGS__2__PORT", "9000")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_ADMIN", "0") os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_ADMIN", "0")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_CLIENT", "0") os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_CLIENT", "0")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_REST_API", "0")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLED_LOGIN_METHODS", "3") os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLED_LOGIN_METHODS", "3")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__RENDER_OPENAPI", "0") os.Setenv("SFTPGO_HTTPD__BINDINGS__2__RENDER_OPENAPI", "0")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_HTTPS", "1 ") os.Setenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_HTTPS", "1 ")
@@ -1145,7 +1078,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__SCOPES", "openid") os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__SCOPES", "openid")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__IMPLICIT_ROLES", "1") os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__IMPLICIT_ROLES", "1")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CUSTOM_FIELDS", "field1,field2") os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CUSTOM_FIELDS", "field1,field2")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__INSECURE_SKIP_SIGNATURE_CHECK", "1")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__DEBUG", "1") os.Setenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__DEBUG", "1")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ENABLED", "true") os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ENABLED", "true")
os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ALLOWED_HOSTS", "*.example.com,*.example.net") os.Setenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ALLOWED_HOSTS", "*.example.com,*.example.net")
@@ -1192,7 +1124,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__MIN_TLS_VERSION") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__MIN_TLS_VERSION")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_ADMIN") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_ADMIN")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_CLIENT") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_WEB_CLIENT")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLE_REST_API")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLED_LOGIN_METHODS") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__ENABLED_LOGIN_METHODS")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__RENDER_OPENAPI") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__RENDER_OPENAPI")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CLIENT_AUTH_TYPE")
@@ -1214,7 +1145,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__SCOPES") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__SCOPES")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__IMPLICIT_ROLES") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__IMPLICIT_ROLES")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CUSTOM_FIELDS") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__CUSTOM_FIELDS")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__INSECURE_SKIP_SIGNATURE_CHECK")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__DEBUG") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__OIDC__DEBUG")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ENABLED") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ENABLED")
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ALLOWED_HOSTS") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__SECURITY__ALLOWED_HOSTS")
@@ -1245,6 +1175,7 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CERTIFICATE_KEY_FILE") os.Unsetenv("SFTPGO_HTTPD__BINDINGS__2__CERTIFICATE_KEY_FILE")
}) })
configDir := ".."
err := config.LoadConfig(configDir, "") err := config.LoadConfig(configDir, "")
assert.NoError(t, err) assert.NoError(t, err)
bindings := config.GetHTTPDConfig().Bindings bindings := config.GetHTTPDConfig().Bindings
@@ -1255,7 +1186,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
require.Equal(t, 12, bindings[0].MinTLSVersion) require.Equal(t, 12, bindings[0].MinTLSVersion)
require.True(t, bindings[0].EnableWebAdmin) require.True(t, bindings[0].EnableWebAdmin)
require.True(t, bindings[0].EnableWebClient) require.True(t, bindings[0].EnableWebClient)
require.True(t, bindings[0].EnableRESTAPI)
require.Equal(t, 0, bindings[0].EnabledLoginMethods) require.Equal(t, 0, bindings[0].EnabledLoginMethods)
require.True(t, bindings[0].RenderOpenAPI) require.True(t, bindings[0].RenderOpenAPI)
require.Len(t, bindings[0].TLSCipherSuites, 1) require.Len(t, bindings[0].TLSCipherSuites, 1)
@@ -1265,7 +1195,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
require.False(t, bindings[0].Security.Enabled) require.False(t, bindings[0].Security.Enabled)
require.Equal(t, 0, bindings[0].ClientIPHeaderDepth) require.Equal(t, 0, bindings[0].ClientIPHeaderDepth)
require.Len(t, bindings[0].OIDC.Scopes, 3) require.Len(t, bindings[0].OIDC.Scopes, 3)
require.False(t, bindings[0].OIDC.InsecureSkipSignatureCheck)
require.False(t, bindings[0].OIDC.Debug) require.False(t, bindings[0].OIDC.Debug)
require.Equal(t, 8000, bindings[1].Port) require.Equal(t, 8000, bindings[1].Port)
require.Equal(t, "127.0.0.1", bindings[1].Address) require.Equal(t, "127.0.0.1", bindings[1].Address)
@@ -1273,14 +1202,12 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
require.Equal(t, 12, bindings[0].MinTLSVersion) require.Equal(t, 12, bindings[0].MinTLSVersion)
require.True(t, bindings[1].EnableWebAdmin) require.True(t, bindings[1].EnableWebAdmin)
require.True(t, bindings[1].EnableWebClient) require.True(t, bindings[1].EnableWebClient)
require.True(t, bindings[1].EnableRESTAPI)
require.Equal(t, 0, bindings[1].EnabledLoginMethods) require.Equal(t, 0, bindings[1].EnabledLoginMethods)
require.True(t, bindings[1].RenderOpenAPI) require.True(t, bindings[1].RenderOpenAPI)
require.Nil(t, bindings[1].TLSCipherSuites) require.Nil(t, bindings[1].TLSCipherSuites)
require.Equal(t, 1, bindings[1].HideLoginURL) require.Equal(t, 1, bindings[1].HideLoginURL)
require.Empty(t, bindings[1].OIDC.ClientID) require.Empty(t, bindings[1].OIDC.ClientID)
require.Len(t, bindings[1].OIDC.Scopes, 3) require.Len(t, bindings[1].OIDC.Scopes, 3)
require.False(t, bindings[1].OIDC.InsecureSkipSignatureCheck)
require.False(t, bindings[1].OIDC.Debug) require.False(t, bindings[1].OIDC.Debug)
require.False(t, bindings[1].Security.Enabled) require.False(t, bindings[1].Security.Enabled)
require.Equal(t, "Web Admin", bindings[1].Branding.WebAdmin.Name) require.Equal(t, "Web Admin", bindings[1].Branding.WebAdmin.Name)
@@ -1292,7 +1219,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
require.Equal(t, 13, bindings[2].MinTLSVersion) require.Equal(t, 13, bindings[2].MinTLSVersion)
require.False(t, bindings[2].EnableWebAdmin) require.False(t, bindings[2].EnableWebAdmin)
require.False(t, bindings[2].EnableWebClient) require.False(t, bindings[2].EnableWebClient)
require.False(t, bindings[2].EnableRESTAPI)
require.Equal(t, 3, bindings[2].EnabledLoginMethods) require.Equal(t, 3, bindings[2].EnabledLoginMethods)
require.False(t, bindings[2].RenderOpenAPI) require.False(t, bindings[2].RenderOpenAPI)
require.Equal(t, 1, bindings[2].ClientAuthType) require.Equal(t, 1, bindings[2].ClientAuthType)
@@ -1320,7 +1246,6 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
require.Len(t, bindings[2].OIDC.CustomFields, 2) require.Len(t, bindings[2].OIDC.CustomFields, 2)
require.Equal(t, "field1", bindings[2].OIDC.CustomFields[0]) require.Equal(t, "field1", bindings[2].OIDC.CustomFields[0])
require.Equal(t, "field2", bindings[2].OIDC.CustomFields[1]) require.Equal(t, "field2", bindings[2].OIDC.CustomFields[1])
require.True(t, bindings[2].OIDC.InsecureSkipSignatureCheck)
require.True(t, bindings[2].OIDC.Debug) require.True(t, bindings[2].OIDC.Debug)
require.True(t, bindings[2].Security.Enabled) require.True(t, bindings[2].Security.Enabled)
require.Len(t, bindings[2].Security.AllowedHosts, 2) require.Len(t, bindings[2].Security.AllowedHosts, 2)
@@ -1358,6 +1283,7 @@ func TestHTTPDBindingsFromEnv(t *testing.T) {
func TestHTTPClientCertificatesFromEnv(t *testing.T) { func TestHTTPClientCertificatesFromEnv(t *testing.T) {
reset() reset()
configDir := ".."
confName := tempConfigName + ".json" confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName) configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "") err := config.LoadConfig(configDir, "")
@@ -1418,6 +1344,7 @@ func TestHTTPClientCertificatesFromEnv(t *testing.T) {
func TestHTTPClientHeadersFromEnv(t *testing.T) { func TestHTTPClientHeadersFromEnv(t *testing.T) {
reset() reset()
configDir := ".."
confName := tempConfigName + ".json" confName := tempConfigName + ".json"
configFilePath := filepath.Join(configDir, confName) configFilePath := filepath.Join(configDir, confName)
err := config.LoadConfig(configDir, "") err := config.LoadConfig(configDir, "")
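Every *FromEnv test above follows the same shape: reset the global viper/config state, export the SFTPGO_* overrides, register cleanup, reload the configuration from the parent directory, then assert on the parsed result. A condensed sketch of that shape, using only identifiers that already appear in this diff (it assumes the same test-file imports shown above):

func TestBindingFromEnvSketch(t *testing.T) {
	reset() // viper.Reset() plus config.Init(), as defined at the top of the file
	os.Setenv("SFTPGO_WEBDAVD__BINDINGS__1__PORT", "8000")
	t.Cleanup(func() {
		os.Unsetenv("SFTPGO_WEBDAVD__BINDINGS__1__PORT")
	})
	configDir := ".."
	err := config.LoadConfig(configDir, "")
	assert.NoError(t, err)
	// the env override materializes a second binding at index 1
	bindings := config.GetWebDAVDConfig().Bindings
	require.Equal(t, 8000, bindings[1].Port)
}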


@@ -26,11 +26,11 @@ import (
"github.com/sftpgo/sdk/plugin/notifier" "github.com/sftpgo/sdk/plugin/notifier"
"github.com/drakkan/sftpgo/v2/internal/command" "github.com/drakkan/sftpgo/v2/command"
"github.com/drakkan/sftpgo/v2/internal/httpclient" "github.com/drakkan/sftpgo/v2/httpclient"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/util"
) )
const ( const (
@@ -42,19 +42,12 @@ const (
)
const (
actionObjectUser = "user"
actionObjectFolder = "folder"
actionObjectGroup = "group"
actionObjectAdmin = "admin"
actionObjectAPIKey = "api_key"
actionObjectShare = "share"
- actionObjectEventAction = "event_action"
- actionObjectEventRule = "event_rule"
- )
- var (
- actionsConcurrencyGuard = make(chan struct{}, 100)
- reservedUsers = []string{ActionExecutorSelf, ActionExecutorSystem}
)
func executeAction(operation, executor, ip, objectType, objectName string, object plugin.Renderer) {
@@ -68,9 +61,6 @@ func executeAction(operation, executor, ip, objectType, objectName string, objec
Timestamp: time.Now().UnixNano(),
}, object)
}
- if fnHandleRuleForProviderEvent != nil {
- fnHandleRuleForProviderEvent(operation, executor, ip, objectType, objectName, object)
- }
if config.Actions.Hook == "" {
return
}
@@ -80,11 +70,6 @@ func executeAction(operation, executor, ip, objectType, objectName string, objec
}
go func() {
- actionsConcurrencyGuard <- struct{}{}
- defer func() {
- <-actionsConcurrencyGuard
- }()
dataAsJSON, err := object.RenderAsJSON(operation != operationDelete)
if err != nil {
providerLog(logger.LevelError, "unable to serialize user as JSON for operation %#v: %v", operation, err)
@@ -104,7 +89,7 @@ func executeAction(operation, executor, ip, objectType, objectName string, objec
q.Add("ip", ip)
q.Add("object_type", objectType)
q.Add("object_name", objectName)
- q.Add("timestamp", fmt.Sprintf("%d", time.Now().UnixNano()))
+ q.Add("timestamp", fmt.Sprintf("%v", time.Now().UnixNano()))
url.RawQuery = q.Encode()
startTime := time.Now()
resp, err := httpclient.RetryablePost(url.String(), "application/json", bytes.NewBuffer(dataAsJSON))
@@ -128,19 +113,19 @@ func executeNotificationCommand(operation, executor, ip, objectType, objectName
return err
}
- timeout, env, args := command.GetConfig(config.Actions.Hook, command.HookProviderActions)
+ timeout, env := command.GetConfig(config.Actions.Hook)
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
- cmd := exec.CommandContext(ctx, config.Actions.Hook, args...)
+ cmd := exec.CommandContext(ctx, config.Actions.Hook)
cmd.Env = append(env,
- fmt.Sprintf("SFTPGO_PROVIDER_ACTION=%s", operation),
+ fmt.Sprintf("SFTPGO_PROVIDER_ACTION=%v", operation),
- fmt.Sprintf("SFTPGO_PROVIDER_OBJECT_TYPE=%s", objectType),
+ fmt.Sprintf("SFTPGO_PROVIDER_OBJECT_TYPE=%v", objectType),
- fmt.Sprintf("SFTPGO_PROVIDER_OBJECT_NAME=%s", objectName),
+ fmt.Sprintf("SFTPGO_PROVIDER_OBJECT_NAME=%v", objectName),
- fmt.Sprintf("SFTPGO_PROVIDER_USERNAME=%s", executor),
+ fmt.Sprintf("SFTPGO_PROVIDER_USERNAME=%v", executor),
- fmt.Sprintf("SFTPGO_PROVIDER_IP=%s", ip),
+ fmt.Sprintf("SFTPGO_PROVIDER_IP=%v", ip),
- fmt.Sprintf("SFTPGO_PROVIDER_TIMESTAMP=%d", util.GetTimeAsMsSinceEpoch(time.Now())),
+ fmt.Sprintf("SFTPGO_PROVIDER_TIMESTAMP=%v", util.GetTimeAsMsSinceEpoch(time.Now())),
- fmt.Sprintf("SFTPGO_PROVIDER_OBJECT=%s", string(objectAsJSON)))
+ fmt.Sprintf("SFTPGO_PROVIDER_OBJECT=%v", string(objectAsJSON)))
startTime := time.Now()
err := cmd.Run()
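The actionsConcurrencyGuard that main adds (and that is absent on the 2.3.x side) is a buffered-channel semaphore: every hook goroutine sends into the channel before doing work and receives on the way out, so at most 100 hooks run at once. A self-contained sketch of the same pattern:

package main

import (
	"fmt"
	"sync"
)

func main() {
	// Sends block once the buffer holds 100 tokens, capping the number
	// of in-flight workers; receiving frees a slot for the next one.
	guard := make(chan struct{}, 100)
	var wg sync.WaitGroup
	for i := 0; i < 500; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			guard <- struct{}{}        // acquire a slot
			defer func() { <-guard }() // release it when done
			fmt.Println("hook for event", n)
		}(i)
	}
	wg.Wait()
}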


@@ -22,18 +22,16 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"sort"
"strings" "strings"
"github.com/alexedwards/argon2id" "github.com/alexedwards/argon2id"
"github.com/sftpgo/sdk"
passwordvalidator "github.com/wagslane/go-password-validator" passwordvalidator "github.com/wagslane/go-password-validator"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/drakkan/sftpgo/v2/internal/kms" "github.com/drakkan/sftpgo/v2/kms"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/internal/mfa" "github.com/drakkan/sftpgo/v2/mfa"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/util"
) )
// Available permissions for SFTPGo admins // Available permissions for SFTPGo admins
@@ -56,16 +54,6 @@ const (
PermAdminRetentionChecks = "retention_checks"
PermAdminMetadataChecks = "metadata_checks"
PermAdminViewEvents = "view_events"
- PermAdminManageEventRules = "manage_event_rules"
- )
- const (
- // GroupAddToUsersAsMembership defines that the admin's group will be added as membership group for new users
- GroupAddToUsersAsMembership = iota
- // GroupAddToUsersAsPrimary defines that the admin's group will be added as primary group for new users
- GroupAddToUsersAsPrimary
- // GroupAddToUsersAsSecondary defines that the admin's group will be added as secondary group for new users
- GroupAddToUsersAsSecondary
)
var (
@@ -107,82 +95,6 @@ func (c *AdminTOTPConfig) validate(username string) error {
return nil
}
- // AdminPreferences defines the admin preferences
- type AdminPreferences struct {
- // Allow to hide some sections from the user page.
- // These are not security settings and are not enforced server side
- // in any way. They are only intended to simplify the user page in
- // the WebAdmin UI.
- //
- // 1 means hide groups section
- // 2 means hide filesystem section, "users_base_dir" must be set in the config file otherwise this setting is ignored
- // 4 means hide virtual folders section
- // 8 means hide profile section
- // 16 means hide ACLs section
- // 32 means hide disk and bandwidth quota limits section
- // 64 means hide advanced settings section
- //
- // The settings can be combined
- HideUserPageSections int `json:"hide_user_page_sections,omitempty"`
- }
- // HideGroups returns true if the groups section should be hidden
- func (p *AdminPreferences) HideGroups() bool {
- return p.HideUserPageSections&1 != 0
- }
- // HideFilesystem returns true if the filesystem section should be hidden
- func (p *AdminPreferences) HideFilesystem() bool {
- return config.UsersBaseDir != "" && p.HideUserPageSections&2 != 0
- }
- // HideVirtualFolders returns true if the virtual folder section should be hidden
- func (p *AdminPreferences) HideVirtualFolders() bool {
- return p.HideUserPageSections&4 != 0
- }
- // HideProfile returns true if the profile section should be hidden
- func (p *AdminPreferences) HideProfile() bool {
- return p.HideUserPageSections&8 != 0
- }
- // HideACLs returns true if the ACLs section should be hidden
- func (p *AdminPreferences) HideACLs() bool {
- return p.HideUserPageSections&16 != 0
- }
- // HideDiskQuotaAndBandwidthLimits returns true if the disk quota and bandwidth limits
- // section should be hidden
- func (p *AdminPreferences) HideDiskQuotaAndBandwidthLimits() bool {
- return p.HideUserPageSections&32 != 0
- }
- // HideAdvancedSettings returns true if the advanced settings section should be hidden
- func (p *AdminPreferences) HideAdvancedSettings() bool {
- return p.HideUserPageSections&64 != 0
- }
- // VisibleUserPageSections returns the number of visible sections
- // in the user page
- func (p *AdminPreferences) VisibleUserPageSections() int {
- var result int
- if !p.HideProfile() {
- result++
- }
- if !p.HideACLs() {
- result++
- }
- if !p.HideDiskQuotaAndBandwidthLimits() {
- result++
- }
- if !p.HideAdvancedSettings() {
- result++
- }
- return result
- }
// AdminFilters defines additional restrictions for SFTPGo admins
// TODO: rename to AdminOptions in v3
type AdminFilters struct {
@@ -197,38 +109,7 @@ type AdminFilters struct {
// Recovery codes to use if the user loses access to their second factor auth device.
// Each code can only be used once, you should use these codes to login and disable or
// reset 2FA for your account
RecoveryCodes []RecoveryCode `json:"recovery_codes,omitempty"`
- Preferences AdminPreferences `json:"preferences"`
- }
- // AdminGroupMappingOptions defines the options for admin/group mapping
- type AdminGroupMappingOptions struct {
- AddToUsersAs int `json:"add_to_users_as,omitempty"`
- }
- func (o *AdminGroupMappingOptions) validate() error {
- if o.AddToUsersAs < GroupAddToUsersAsMembership || o.AddToUsersAs > GroupAddToUsersAsSecondary {
- return util.NewValidationError(fmt.Sprintf("Invalid mode to add groups to new users: %d", o.AddToUsersAs))
- }
- return nil
- }
- // GetUserGroupType returns the type for the matching user group
- func (o *AdminGroupMappingOptions) GetUserGroupType() int {
- switch o.AddToUsersAs {
- case GroupAddToUsersAsPrimary:
- return sdk.GroupTypePrimary
- case GroupAddToUsersAsSecondary:
- return sdk.GroupTypeSecondary
- default:
- return sdk.GroupTypeMembership
- }
- }
- // AdminGroupMapping defines the mapping between an SFTPGo admin and a group
- type AdminGroupMapping struct {
- Name string `json:"name"`
- Options AdminGroupMappingOptions `json:"options"`
}
// Admin defines a SFTPGo admin
// Admin defines a SFTPGo admin // Admin defines a SFTPGo admin
@@ -245,8 +126,6 @@ type Admin struct {
Filters AdminFilters `json:"filters,omitempty"`
Description string `json:"description,omitempty"`
AdditionalInfo string `json:"additional_info,omitempty"`
- // Groups membership
- Groups []AdminGroupMapping `json:"groups,omitempty"`
// Creation time as unix timestamp in milliseconds. It will be 0 for admins created before v2.2.0
CreatedAt int64 `json:"created_at"`
// last update time as unix timestamp in milliseconds
@@ -326,33 +205,11 @@ func (a *Admin) validatePermissions() error {
return nil
}
- func (a *Admin) validateGroups() error {
- hasPrimary := false
- for _, g := range a.Groups {
- if g.Name == "" {
- return util.NewValidationError("group name is mandatory")
- }
- if err := g.Options.validate(); err != nil {
- return err
- }
- if g.Options.AddToUsersAs == GroupAddToUsersAsPrimary {
- if hasPrimary {
- return util.NewValidationError("only one primary group is allowed")
- }
- hasPrimary = true
- }
- }
- return nil
- }
func (a *Admin) validate() error {
a.SetEmptySecretsIfNil()
if a.Username == "" {
return util.NewValidationError("username is mandatory")
}
- if err := checkReservedUsernames(a.Username); err != nil {
- return err
- }
if a.Password == "" {
return util.NewValidationError("please set a password")
}
@@ -385,20 +242,7 @@ func (a *Admin) validate() error {
}
}
- return a.validateGroups()
+ return nil
- }
- // GetGroupsAsString returns the user's groups as a string
- func (a *Admin) GetGroupsAsString() string {
- if len(a.Groups) == 0 {
- return ""
- }
- var groups []string
- for _, g := range a.Groups {
- groups = append(groups, g.Name)
- }
- sort.Strings(groups)
- return strings.Join(groups, ",")
}
// CheckPassword verifies the admin password
@@ -577,18 +421,6 @@ func (a *Admin) getACopy() Admin {
Used: code.Used,
})
}
- filters.Preferences = AdminPreferences{
- HideUserPageSections: a.Filters.Preferences.HideUserPageSections,
- }
- groups := make([]AdminGroupMapping, 0, len(a.Groups))
- for _, g := range a.Groups {
- groups = append(groups, AdminGroupMapping{
- Name: g.Name,
- Options: AdminGroupMappingOptions{
- AddToUsersAs: g.Options.AddToUsersAs,
- },
- })
- }
return Admin{
ID: a.ID,
@@ -597,7 +429,6 @@ func (a *Admin) getACopy() Admin {
Password: a.Password,
Email: a.Email,
Permissions: permissions,
- Groups: groups,
Filters: filters,
AdditionalInfo: a.AdditionalInfo,
Description: a.Description,
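The removed AdminPreferences type encodes WebAdmin UI visibility as bit flags: values are combined with bitwise OR and tested with a mask, which is exactly what the Hide* helpers above do. A self-contained sketch; the constant names are illustrative, while the numeric values follow the comments in the diff:

package main

import "fmt"

const (
	hideGroups         = 1 // hide groups section
	hideFilesystem     = 2 // hide filesystem section
	hideVirtualFolders = 4 // hide virtual folders section
	hideProfile        = 8 // hide profile section
)

func main() {
	// Combine sections with bitwise OR, test them with bitwise AND.
	prefs := hideGroups | hideProfile
	fmt.Println("hide groups:", prefs&hideGroups != 0)          // true
	fmt.Println("hide profile:", prefs&hideProfile != 0)        // true
	fmt.Println("hide folders:", prefs&hideVirtualFolders != 0) // false
}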


@@ -23,8 +23,8 @@ import (
"github.com/alexedwards/argon2id" "github.com/alexedwards/argon2id"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/util"
) )
// APIKeyScope defines the supported API key scopes // APIKeyScope defines the supported API key scopes

File diff suppressed because it is too large


@@ -20,13 +20,13 @@ package dataprovider
import (
"errors"
- "github.com/drakkan/sftpgo/v2/internal/version"
+ "github.com/drakkan/sftpgo/v2/version"
)
func init() {
version.AddFeature("-bolt")
}
- func initializeBoltProvider(_ string) error {
+ func initializeBoltProvider(basePath string) error {
return errors.New("bolt disabled at build time")
}
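This file is the disabled-at-build-time half of the BoltDB provider: it registers the "-bolt" feature string and fails initialization, while a sibling file compiled under the opposite build tag supplies the real provider. The renamed parameter (basePath to _) on the main side only silences the unused-parameter lint warning. A sketch of the stub's shape; the nobolt tag name is an assumption here, not confirmed by this diff:

//go:build nobolt
// +build nobolt

package dataprovider

import (
	"errors"

	"github.com/drakkan/sftpgo/v2/version"
)

// init runs even for the stub, so the binary can report that bolt
// support was compiled out.
func init() {
	version.AddFeature("-bolt")
}

// initializeBoltProvider always fails; the real implementation lives in a
// sibling file guarded by the negated tag.
func initializeBoltProvider(_ string) error {
	return errors.New("bolt disabled at build time")
}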


@@ -20,8 +20,8 @@ import (
"golang.org/x/net/webdav" "golang.org/x/net/webdav"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/util"
) )
var ( var (


@@ -22,10 +22,10 @@ import (
"github.com/sftpgo/sdk" "github.com/sftpgo/sdk"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/internal/plugin" "github.com/drakkan/sftpgo/v2/plugin"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/internal/vfs" "github.com/drakkan/sftpgo/v2/vfs"
) )
// GroupUserSettings defines the settings to apply to users // GroupUserSettings defines the settings to apply to users
@@ -187,8 +187,6 @@ func (g *Group) validateUserSettings() error {
func (g *Group) getACopy() Group {
users := make([]string, len(g.Users))
copy(users, g.Users)
- admins := make([]string, len(g.Admins))
- copy(admins, g.Admins)
virtualFolders := make([]vfs.VirtualFolder, 0, len(g.VirtualFolders))
for idx := range g.VirtualFolders {
vfolder := g.VirtualFolders[idx].GetACopy()
@@ -209,7 +207,6 @@ func (g *Group) getACopy() Group {
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
Users: users,
- Admins: admins,
},
UserSettings: GroupUserSettings{
BaseGroupUserSettings: sdk.BaseGroupUserSettings{
@@ -231,14 +228,7 @@ func (g *Group) getACopy() Group {
}
}
- // GetMembersAsString returns a string representation for the group members
+ // GetUsersAsString returns the list of users as comma separated string
- func (g *Group) GetMembersAsString() string {
+ func (g *Group) GetUsersAsString() string {
- var sb strings.Builder
+ return strings.Join(g.Users, ",")
- if len(g.Users) > 0 {
- sb.WriteString(fmt.Sprintf("Users: %d. ", len(g.Users)))
- }
- if len(g.Admins) > 0 {
- sb.WriteString(fmt.Sprintf("Admins: %d. ", len(g.Admins)))
- }
- return sb.String()
}
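The hunk above swaps a joined user list for counters: main's GetMembersAsString summarizes both users and admins, while 2.3.x's GetUsersAsString returns the raw comma-separated names. A small self-contained sketch contrasting the two outputs:

package main

import (
	"fmt"
	"strings"
)

func main() {
	users := []string{"user1", "user2"}
	admins := []string{"admin1"}

	// main-style summary: counts per member kind
	var sb strings.Builder
	if len(users) > 0 {
		sb.WriteString(fmt.Sprintf("Users: %d. ", len(users)))
	}
	if len(admins) > 0 {
		sb.WriteString(fmt.Sprintf("Admins: %d. ", len(admins)))
	}
	fmt.Println(sb.String()) // Users: 2. Admins: 1.

	// 2.3.x-style summary: the raw user names
	fmt.Println(strings.Join(users, ",")) // user1,user2
}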


@@ -24,9 +24,9 @@ import (
"sync" "sync"
"time" "time"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/internal/util" "github.com/drakkan/sftpgo/v2/util"
"github.com/drakkan/sftpgo/v2/internal/vfs" "github.com/drakkan/sftpgo/v2/vfs"
) )
var ( var (
@@ -62,14 +62,6 @@ type memoryProviderHandle struct {
shares map[string]Share
// slice with ordered shares shareID
sharesIDs []string
- // map for event actions, name is the key
- actions map[string]BaseEventAction
- // slice with ordered actions
- actionsNames []string
- // map for event rules, name is the key
- rules map[string]EventRule
- // slice with ordered rules
- rulesNames []string
}
// MemoryProvider defines the auth provider for a memory store
@@ -100,10 +92,6 @@ func initializeMemoryProvider(basePath string) {
apiKeysIDs: []string{},
shares: make(map[string]Share),
sharesIDs: []string{},
- actions: make(map[string]BaseEventAction),
- actionsNames: []string{},
- rules: make(map[string]EventRule),
- rulesNames: []string{},
configFile: configFile,
},
}
@@ -146,6 +134,10 @@ func (p *MemoryProvider) validateUserAndTLSCert(username, protocol string, tlsCe
}
func (p *MemoryProvider) validateUserAndPass(username, password, ip, protocol string) (User, error) {
+ var user User
+ if password == "" {
+ return user, errors.New("credentials cannot be null or empty")
+ }
user, err := p.userExists(username)
if err != nil {
providerLog(logger.LevelWarn, "error authenticating user %#v: %v", username, err)
@@ -326,22 +318,14 @@ func (p *MemoryProvider) addUser(user *User) error {
user.UsedUploadDataTransfer = 0
user.UsedDownloadDataTransfer = 0
user.LastLogin = 0
- user.FirstUpload = 0
- user.FirstDownload = 0
user.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
- var mappedGroups []string
+ user.VirtualFolders = p.joinUserVirtualFoldersFields(user)
for idx := range user.Groups {
- if err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name); err != nil {
+ if err = p.addUserFromGroupMapping(user.Username, user.Groups[idx].Name); err != nil {
- // try to remove group mapping
- for _, g := range mappedGroups {
- p.removeUserFromGroupMapping(user.Username, g)
- }
return err
}
- mappedGroups = append(mappedGroups, user.Groups[idx].Name)
}
- user.VirtualFolders = p.joinUserVirtualFoldersFields(user)
p.dbHandle.users[user.Username] = user.getACopy()
p.dbHandle.usernames = append(p.dbHandle.usernames, user.Username)
sort.Strings(p.dbHandle.usernames)
@@ -366,33 +350,26 @@ func (p *MemoryProvider) updateUser(user *User) error {
if err != nil {
return err
}
- for idx := range u.Groups {
- p.removeUserFromGroupMapping(u.Username, u.Groups[idx].Name)
- }
- for idx := range user.Groups {
- if err = p.addUserToGroupMapping(user.Username, user.Groups[idx].Name); err != nil {
- // try to add old mapping
- for _, g := range u.Groups {
- if errRollback := p.addUserToGroupMapping(user.Username, g.Name); errRollback != nil {
- providerLog(logger.LevelError, "unable to rollback old group mapping %q for user %q, error: %v",
- g.Name, user.Username, errRollback)
- }
- }
- return err
- }
- }
for _, oldFolder := range u.VirtualFolders {
p.removeRelationFromFolderMapping(oldFolder.Name, u.Username, "")
}
+ for idx := range u.Groups {
+ if err = p.removeUserFromGroupMapping(u.Username, u.Groups[idx].Name); err != nil {
+ return err
+ }
+ }
user.VirtualFolders = p.joinUserVirtualFoldersFields(user)
+ for idx := range user.Groups {
+ if err = p.addUserFromGroupMapping(user.Username, user.Groups[idx].Name); err != nil {
+ return err
+ }
+ }
user.LastQuotaUpdate = u.LastQuotaUpdate
user.UsedQuotaSize = u.UsedQuotaSize
user.UsedQuotaFiles = u.UsedQuotaFiles
user.UsedUploadDataTransfer = u.UsedUploadDataTransfer
user.UsedDownloadDataTransfer = u.UsedDownloadDataTransfer
user.LastLogin = u.LastLogin
- user.FirstDownload = u.FirstDownload
- user.FirstUpload = u.FirstUpload
user.CreatedAt = u.CreatedAt
user.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
user.ID = u.ID
@@ -402,7 +379,7 @@ func (p *MemoryProvider) updateUser(user *User) error {
	return nil
}

-func (p *MemoryProvider) deleteUser(user User, softDelete bool) error {
+func (p *MemoryProvider) deleteUser(user User) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
@@ -416,7 +393,9 @@ func (p *MemoryProvider) deleteUser(user User, softDelete bool) error {
		p.removeRelationFromFolderMapping(oldFolder.Name, u.Username, "")
	}
	for idx := range u.Groups {
-		p.removeUserFromGroupMapping(u.Username, u.Groups[idx].Name)
+		if err = p.removeUserFromGroupMapping(u.Username, u.Groups[idx].Name); err != nil {
+			return err
+		}
	}
	delete(p.dbHandle.users, user.Username)
	// this could be more efficient
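Note: the recurring "this could be more efficient" comment refers to rebuilding and re-sorting the whole name slice after each delete. Since the slice is kept sorted, the entry could instead be removed in place via binary search; a sketch, not the provider's code:

```go
// removeSorted deletes name from an already-sorted slice without a full
// rebuild and re-sort; sort.SearchStrings does the binary search.
func removeSorted(names []string, name string) []string {
	i := sort.SearchStrings(names, name)
	if i < len(names) && names[i] == name {
		return append(names[:i], names[i+1:]...)
	}
	return names
}
```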
@@ -611,28 +590,14 @@ func (p *MemoryProvider) userExistsInternal(username string) (User, error) {
	if val, ok := p.dbHandle.users[username]; ok {
		return val.getACopy(), nil
	}
-	return User{}, util.NewRecordNotFoundError(fmt.Sprintf("username %q does not exist", username))
+	return User{}, util.NewRecordNotFoundError(fmt.Sprintf("username %#v does not exist", username))
}

func (p *MemoryProvider) groupExistsInternal(name string) (Group, error) {
	if val, ok := p.dbHandle.groups[name]; ok {
		return val.getACopy(), nil
	}
-	return Group{}, util.NewRecordNotFoundError(fmt.Sprintf("group %q does not exist", name))
+	return Group{}, util.NewRecordNotFoundError(fmt.Sprintf("group %#v does not exist", name))
}
-
-func (p *MemoryProvider) actionExistsInternal(name string) (BaseEventAction, error) {
-	if val, ok := p.dbHandle.actions[name]; ok {
-		return val.getACopy(), nil
-	}
-	return BaseEventAction{}, util.NewRecordNotFoundError(fmt.Sprintf("event action %q does not exist", name))
-}
-
-func (p *MemoryProvider) ruleExistsInternal(name string) (EventRule, error) {
-	if val, ok := p.dbHandle.rules[name]; ok {
-		return val.getACopy(), nil
-	}
-	return EventRule{}, util.NewRecordNotFoundError(fmt.Sprintf("event rule %q does not exist", name))
-}

func (p *MemoryProvider) addAdmin(admin *Admin) error {
@@ -653,17 +618,6 @@ func (p *MemoryProvider) addAdmin(admin *Admin) error {
	admin.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	admin.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	admin.LastLogin = 0
-	var mappedAdmins []string
-	for idx := range admin.Groups {
-		if err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name); err != nil {
-			// try to remove group mapping
-			for _, g := range mappedAdmins {
-				p.removeAdminFromGroupMapping(admin.Username, g)
-			}
-			return err
-		}
-		mappedAdmins = append(mappedAdmins, admin.Groups[idx].Name)
-	}
	p.dbHandle.admins[admin.Username] = admin.getACopy()
	p.dbHandle.adminsUsernames = append(p.dbHandle.adminsUsernames, admin.Username)
	sort.Strings(p.dbHandle.adminsUsernames)
@@ -684,21 +638,6 @@ func (p *MemoryProvider) updateAdmin(admin *Admin) error {
	if err != nil {
		return err
	}
-	for idx := range a.Groups {
-		p.removeAdminFromGroupMapping(a.Username, a.Groups[idx].Name)
-	}
-	for idx := range admin.Groups {
-		if err = p.addAdminToGroupMapping(admin.Username, admin.Groups[idx].Name); err != nil {
-			// try to add old mapping
-			for _, oldGroup := range a.Groups {
-				if errRollback := p.addAdminToGroupMapping(a.Username, oldGroup.Name); errRollback != nil {
-					providerLog(logger.LevelError, "unable to rollback old group mapping %q for admin %q, error: %v",
-						oldGroup.Name, a.Username, errRollback)
-				}
-			}
-			return err
-		}
-	}
	admin.ID = a.ID
	admin.CreatedAt = a.CreatedAt
	admin.LastLogin = a.LastLogin
@@ -713,13 +652,10 @@ func (p *MemoryProvider) deleteAdmin(admin Admin) error {
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
-	a, err := p.adminExistsInternal(admin.Username)
+	_, err := p.adminExistsInternal(admin.Username)
	if err != nil {
		return err
	}
-	for idx := range a.Groups {
-		p.removeAdminFromGroupMapping(a.Username, a.Groups[idx].Name)
-	}
	delete(p.dbHandle.admins, admin.Username)
	// this could be more efficient
@@ -944,8 +880,6 @@ func (p *MemoryProvider) addGroup(group *Group) error {
	group.ID = p.getNextGroupID()
	group.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	group.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
-	group.Users = nil
-	group.Admins = nil
	group.VirtualFolders = p.joinGroupVirtualFoldersFields(group)
	p.dbHandle.groups[group.Name] = group.getACopy()
	p.dbHandle.groupnames = append(p.dbHandle.groupnames, group.Name)
@@ -973,8 +907,6 @@ func (p *MemoryProvider) updateGroup(group *Group) error {
	group.CreatedAt = g.CreatedAt
	group.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
	group.ID = g.ID
-	group.Users = g.Users
-	group.Admins = g.Admins
	p.dbHandle.groups[group.Name] = group.getACopy()
	return nil
}
@@ -995,9 +927,6 @@ func (p *MemoryProvider) deleteGroup(group Group) error {
	for _, oldFolder := range g.VirtualFolders {
		p.removeRelationFromFolderMapping(oldFolder.Name, "", g.Name)
	}
-	for _, a := range g.Admins {
-		p.removeGroupFromAdminMapping(g.Name, a)
-	}
	delete(p.dbHandle.groups, group.Name)
	// this could be more efficient
	p.dbHandle.groupnames = make([]string, 0, len(p.dbHandle.groups))
@@ -1068,95 +997,7 @@ func (p *MemoryProvider) addVirtualFoldersToGroup(group *Group) {
	}
}

-func (p *MemoryProvider) addActionsToRule(rule *EventRule) {
-	var actions []EventAction
-	for idx := range rule.Actions {
-		action := &rule.Actions[idx]
-		baseAction, err := p.actionExistsInternal(action.Name)
-		if err != nil {
-			continue
-		}
-		baseAction.Options.SetEmptySecretsIfNil()
-		action.BaseEventAction = baseAction
-		actions = append(actions, *action)
-	}
-	rule.Actions = actions
-}
-
-func (p *MemoryProvider) addRuleToActionMapping(ruleName, actionName string) error {
-	a, err := p.actionExistsInternal(actionName)
-	if err != nil {
-		return util.NewGenericError(fmt.Sprintf("action %q does not exist", actionName))
-	}
-	if !util.Contains(a.Rules, ruleName) {
-		a.Rules = append(a.Rules, ruleName)
-		p.dbHandle.actions[actionName] = a
-	}
-	return nil
-}
-
-func (p *MemoryProvider) removeRuleFromActionMapping(ruleName, actionName string) {
-	a, err := p.actionExistsInternal(actionName)
-	if err != nil {
-		providerLog(logger.LevelWarn, "action %q does not exist, cannot remove from mapping", actionName)
-		return
-	}
-	if util.Contains(a.Rules, ruleName) {
-		var rules []string
-		for _, r := range a.Rules {
-			if r != ruleName {
-				rules = append(rules, r)
-			}
-		}
-		a.Rules = rules
-		p.dbHandle.actions[actionName] = a
-	}
-}
-
-func (p *MemoryProvider) addAdminToGroupMapping(username, groupname string) error {
-	g, err := p.groupExistsInternal(groupname)
-	if err != nil {
-		return err
-	}
-	if !util.Contains(g.Admins, username) {
-		g.Admins = append(g.Admins, username)
-		p.dbHandle.groups[groupname] = g
-	}
-	return nil
-}
-
-func (p *MemoryProvider) removeAdminFromGroupMapping(username, groupname string) {
-	g, err := p.groupExistsInternal(groupname)
-	if err != nil {
-		return
-	}
-	var admins []string
-	for _, a := range g.Admins {
-		if a != username {
-			admins = append(admins, a)
-		}
-	}
-	g.Admins = admins
-	p.dbHandle.groups[groupname] = g
-}
-
-func (p *MemoryProvider) removeGroupFromAdminMapping(groupname, username string) {
-	admin, err := p.adminExistsInternal(username)
-	if err != nil {
-		// the admin does not exist so there is no associated group
-		return
-	}
-	var newGroups []AdminGroupMapping
-	for _, g := range admin.Groups {
-		if g.Name != groupname {
-			newGroups = append(newGroups, g)
-		}
-	}
-	admin.Groups = newGroups
-	p.dbHandle.admins[admin.Username] = admin
-}
-
-func (p *MemoryProvider) addUserToGroupMapping(username, groupname string) error {
+func (p *MemoryProvider) addUserFromGroupMapping(username, groupname string) error {
	g, err := p.groupExistsInternal(groupname)
	if err != nil {
		return err
@@ -1168,19 +1009,22 @@ func (p *MemoryProvider) addUserToGroupMapping(username, groupname string) error
	return nil
}

-func (p *MemoryProvider) removeUserFromGroupMapping(username, groupname string) {
+func (p *MemoryProvider) removeUserFromGroupMapping(username, groupname string) error {
	g, err := p.groupExistsInternal(groupname)
	if err != nil {
-		return
+		return err
	}
-	var users []string
-	for _, u := range g.Users {
-		if u != username {
-			users = append(users, u)
-		}
-	}
-	g.Users = users
-	p.dbHandle.groups[groupname] = g
+	if util.Contains(g.Users, username) {
+		var users []string
+		for _, u := range g.Users {
+			if u != username {
+				users = append(users, u)
+			}
+		}
+		g.Users = users
+		p.dbHandle.groups[groupname] = g
+	}
+	return nil
}
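Note: util.Contains, used in the right-hand version above, is assumed here to be a plain slice membership test (its usage pattern suggests the equivalent of slices.Contains). A minimal stand-in:

```go
// contains reports whether v is present in s; a stand-in for util.Contains
// as it is used in the mapping helpers above.
func contains[T comparable](s []T, v T) bool {
	for _, item := range s {
		if item == v {
			return true
		}
	}
	return false
}
```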
func (p *MemoryProvider) joinUserVirtualFoldersFields(user *User) []vfs.VirtualFolder {
@@ -1364,7 +1208,6 @@ func (p *MemoryProvider) addFolder(folder *vfs.BaseVirtualFolder) error {
	}
	folder.ID = p.getNextFolderID()
	folder.Users = nil
-	folder.Groups = nil
	p.dbHandle.vfolders[folder.Name] = folder.GetACopy()
	p.dbHandle.vfoldersNames = append(p.dbHandle.vfoldersNames, folder.Name)
	sort.Strings(p.dbHandle.vfoldersNames)
@@ -1391,7 +1234,6 @@ func (p *MemoryProvider) updateFolder(folder *vfs.BaseVirtualFolder) error {
	folder.UsedQuotaFiles = f.UsedQuotaFiles
	folder.UsedQuotaSize = f.UsedQuotaSize
	folder.Users = f.Users
-	folder.Groups = f.Groups
	p.dbHandle.vfolders[folder.Name] = folder.GetACopy()
	// now update the related users
	for _, username := range folder.Users {
@@ -1412,14 +1254,14 @@ func (p *MemoryProvider) updateFolder(folder *vfs.BaseVirtualFolder) error {
	return nil
}

-func (p *MemoryProvider) deleteFolder(f vfs.BaseVirtualFolder) error {
+func (p *MemoryProvider) deleteFolder(folder vfs.BaseVirtualFolder) error {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
	if p.dbHandle.isClosed {
		return errMemoryProviderClosed
	}
-	folder, err := p.folderExistsInternal(f.Name)
+	_, err := p.folderExistsInternal(folder.Name)
	if err != nil {
		return err
	}
@@ -1940,426 +1782,6 @@ func (p *MemoryProvider) cleanupSharedSessions(sessionType SessionType, before i
	return ErrNotImplemented
}
func (p *MemoryProvider) getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return nil, errMemoryProviderClosed
}
if limit <= 0 {
return nil, nil
}
actions := make([]BaseEventAction, 0, limit)
itNum := 0
if order == OrderASC {
for _, name := range p.dbHandle.actionsNames {
itNum++
if itNum <= offset {
continue
}
a := p.dbHandle.actions[name]
action := a.getACopy()
action.PrepareForRendering()
actions = append(actions, action)
if len(actions) >= limit {
break
}
}
} else {
for i := len(p.dbHandle.actionsNames) - 1; i >= 0; i-- {
itNum++
if itNum <= offset {
continue
}
name := p.dbHandle.actionsNames[i]
a := p.dbHandle.actions[name]
action := a.getACopy()
action.PrepareForRendering()
actions = append(actions, action)
if len(actions) >= limit {
break
}
}
}
return actions, nil
}
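Note: getEventActions and getEventRules (below) share the same limit/offset walk over a pre-sorted name slice, iterating forward for ascending order and backward otherwise. The skip-then-collect core, extracted as a sketch rather than provider API:

```go
// page returns up to limit names after skipping offset entries; this is
// equivalent to the itNum bookkeeping in the loops above.
func page(names []string, limit, offset int) []string {
	if limit <= 0 || offset >= len(names) {
		return nil
	}
	out := make([]string, 0, limit)
	for _, name := range names[offset:] {
		out = append(out, name)
		if len(out) >= limit {
			break
		}
	}
	return out
}
```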
func (p *MemoryProvider) dumpEventActions() ([]BaseEventAction, error) {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return nil, errMemoryProviderClosed
}
actions := make([]BaseEventAction, 0, len(p.dbHandle.actions))
for _, name := range p.dbHandle.actionsNames {
a := p.dbHandle.actions[name]
action := a.getACopy()
actions = append(actions, action)
}
return actions, nil
}
func (p *MemoryProvider) eventActionExists(name string) (BaseEventAction, error) {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return BaseEventAction{}, errMemoryProviderClosed
}
return p.actionExistsInternal(name)
}
func (p *MemoryProvider) addEventAction(action *BaseEventAction) error {
err := action.validate()
if err != nil {
return err
}
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
_, err = p.actionExistsInternal(action.Name)
if err == nil {
return fmt.Errorf("event action %q already exists", action.Name)
}
action.ID = p.getNextActionID()
action.Rules = nil
p.dbHandle.actions[action.Name] = action.getACopy()
p.dbHandle.actionsNames = append(p.dbHandle.actionsNames, action.Name)
sort.Strings(p.dbHandle.actionsNames)
return nil
}
func (p *MemoryProvider) updateEventAction(action *BaseEventAction) error {
err := action.validate()
if err != nil {
return err
}
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
oldAction, err := p.actionExistsInternal(action.Name)
if err != nil {
return fmt.Errorf("event action %s does not exist", action.Name)
}
action.ID = oldAction.ID
action.Name = oldAction.Name
action.Rules = nil
if len(oldAction.Rules) > 0 {
var relatedRules []string
for _, ruleName := range oldAction.Rules {
rule, err := p.ruleExistsInternal(ruleName)
if err == nil {
relatedRules = append(relatedRules, ruleName)
rule.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
p.dbHandle.rules[ruleName] = rule
setLastRuleUpdate()
}
}
action.Rules = relatedRules
}
p.dbHandle.actions[action.Name] = action.getACopy()
return nil
}
func (p *MemoryProvider) deleteEventAction(action BaseEventAction) error {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
oldAction, err := p.actionExistsInternal(action.Name)
if err != nil {
return fmt.Errorf("event action %s does not exist", action.Name)
}
if len(oldAction.Rules) > 0 {
return util.NewValidationError(fmt.Sprintf("action %s is referenced, it cannot be removed", oldAction.Name))
}
delete(p.dbHandle.actions, action.Name)
// this could be more efficient
p.dbHandle.actionsNames = make([]string, 0, len(p.dbHandle.actions))
for name := range p.dbHandle.actions {
p.dbHandle.actionsNames = append(p.dbHandle.actionsNames, name)
}
sort.Strings(p.dbHandle.actionsNames)
return nil
}
func (p *MemoryProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return nil, errMemoryProviderClosed
}
if limit <= 0 {
return nil, nil
}
itNum := 0
rules := make([]EventRule, 0, limit)
if order == OrderASC {
for _, name := range p.dbHandle.rulesNames {
itNum++
if itNum <= offset {
continue
}
r := p.dbHandle.rules[name]
rule := r.getACopy()
p.addActionsToRule(&rule)
rule.PrepareForRendering()
rules = append(rules, rule)
if len(rules) >= limit {
break
}
}
} else {
for i := len(p.dbHandle.rulesNames) - 1; i >= 0; i-- {
itNum++
if itNum <= offset {
continue
}
name := p.dbHandle.rulesNames[i]
r := p.dbHandle.rules[name]
rule := r.getACopy()
p.addActionsToRule(&rule)
rule.PrepareForRendering()
rules = append(rules, rule)
if len(rules) >= limit {
break
}
}
}
return rules, nil
}
func (p *MemoryProvider) dumpEventRules() ([]EventRule, error) {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return nil, errMemoryProviderClosed
}
rules := make([]EventRule, 0, len(p.dbHandle.rules))
for _, name := range p.dbHandle.rulesNames {
r := p.dbHandle.rules[name]
rule := r.getACopy()
p.addActionsToRule(&rule)
rules = append(rules, rule)
}
return rules, nil
}
func (p *MemoryProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) {
if getLastRuleUpdate() < after {
return nil, nil
}
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return nil, errMemoryProviderClosed
}
rules := make([]EventRule, 0, 10)
for _, name := range p.dbHandle.rulesNames {
r := p.dbHandle.rules[name]
if r.UpdatedAt < after {
continue
}
rule := r.getACopy()
p.addActionsToRule(&rule)
rules = append(rules, rule)
}
return rules, nil
}
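Note: getRecentlyUpdatedRules checks a package-level "last rule update" timestamp before taking the provider lock, so callers that are already up to date return without contention. setLastRuleUpdate/getLastRuleUpdate are defined elsewhere in this package; a plausible shape, assuming Go 1.19's atomic.Int64 and millisecond timestamps as used throughout this file:

```go
var lastRuleUpdate atomic.Int64

// setLastRuleUpdate records the time of the most recent rule mutation.
func setLastRuleUpdate() {
	lastRuleUpdate.Store(util.GetTimeAsMsSinceEpoch(time.Now()))
}

// getLastRuleUpdate returns the recorded timestamp without locking.
func getLastRuleUpdate() int64 {
	return lastRuleUpdate.Load()
}
```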
func (p *MemoryProvider) eventRuleExists(name string) (EventRule, error) {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return EventRule{}, errMemoryProviderClosed
}
rule, err := p.ruleExistsInternal(name)
if err != nil {
return rule, err
}
p.addActionsToRule(&rule)
return rule, nil
}
func (p *MemoryProvider) addEventRule(rule *EventRule) error {
if err := rule.validate(); err != nil {
return err
}
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
_, err := p.ruleExistsInternal(rule.Name)
if err == nil {
return fmt.Errorf("event rule %q already exists", rule.Name)
}
rule.ID = p.getNextRuleID()
rule.CreatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
rule.UpdatedAt = rule.CreatedAt
var mappedActions []string
for idx := range rule.Actions {
if err := p.addRuleToActionMapping(rule.Name, rule.Actions[idx].Name); err != nil {
// try to remove action mapping
for _, a := range mappedActions {
p.removeRuleFromActionMapping(rule.Name, a)
}
return err
}
mappedActions = append(mappedActions, rule.Actions[idx].Name)
}
sort.Slice(rule.Actions, func(i, j int) bool {
return rule.Actions[i].Order < rule.Actions[j].Order
})
p.dbHandle.rules[rule.Name] = rule.getACopy()
p.dbHandle.rulesNames = append(p.dbHandle.rulesNames, rule.Name)
sort.Strings(p.dbHandle.rulesNames)
setLastRuleUpdate()
return nil
}
func (p *MemoryProvider) updateEventRule(rule *EventRule) error {
if err := rule.validate(); err != nil {
return err
}
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
oldRule, err := p.ruleExistsInternal(rule.Name)
if err != nil {
return err
}
for idx := range oldRule.Actions {
p.removeRuleFromActionMapping(rule.Name, oldRule.Actions[idx].Name)
}
for idx := range rule.Actions {
if err = p.addRuleToActionMapping(rule.Name, rule.Actions[idx].Name); err != nil {
// try to add old mapping
for _, oldAction := range oldRule.Actions {
if errRollback := p.addRuleToActionMapping(oldRule.Name, oldAction.Name); errRollback != nil {
providerLog(logger.LevelError, "unable to rollback old action mapping %q for rule %q, error: %v",
oldAction.Name, oldRule.Name, errRollback)
}
}
return err
}
}
rule.ID = oldRule.ID
rule.CreatedAt = oldRule.CreatedAt
rule.UpdatedAt = util.GetTimeAsMsSinceEpoch(time.Now())
sort.Slice(rule.Actions, func(i, j int) bool {
return rule.Actions[i].Order < rule.Actions[j].Order
})
p.dbHandle.rules[rule.Name] = rule.getACopy()
setLastRuleUpdate()
return nil
}
func (p *MemoryProvider) deleteEventRule(rule EventRule, softDelete bool) error {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
oldRule, err := p.ruleExistsInternal(rule.Name)
if err != nil {
return err
}
if len(oldRule.Actions) > 0 {
for idx := range oldRule.Actions {
p.removeRuleFromActionMapping(rule.Name, oldRule.Actions[idx].Name)
}
}
delete(p.dbHandle.rules, rule.Name)
p.dbHandle.rulesNames = make([]string, 0, len(p.dbHandle.rules))
for name := range p.dbHandle.rules {
p.dbHandle.rulesNames = append(p.dbHandle.rulesNames, name)
}
sort.Strings(p.dbHandle.rulesNames)
setLastRuleUpdate()
return nil
}
func (*MemoryProvider) getTaskByName(name string) (Task, error) {
return Task{}, ErrNotImplemented
}
func (*MemoryProvider) addTask(name string) error {
return ErrNotImplemented
}
func (*MemoryProvider) updateTask(name string, version int64) error {
return ErrNotImplemented
}
func (*MemoryProvider) updateTaskTimestamp(name string) error {
return ErrNotImplemented
}
func (*MemoryProvider) addNode() error {
return ErrNotImplemented
}
func (*MemoryProvider) getNodeByName(name string) (Node, error) {
return Node{}, ErrNotImplemented
}
func (*MemoryProvider) getNodes() ([]Node, error) {
return nil, ErrNotImplemented
}
func (*MemoryProvider) updateNodeTimestamp() error {
return ErrNotImplemented
}
func (*MemoryProvider) cleanupNodes() error {
return ErrNotImplemented
}
func (p *MemoryProvider) setFirstDownloadTimestamp(username string) error {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
user, err := p.userExistsInternal(username)
if err != nil {
return err
}
if user.FirstDownload > 0 {
return util.NewGenericError(fmt.Sprintf("first download already set to %v",
util.GetTimeFromMsecSinceEpoch(user.FirstDownload)))
}
user.FirstDownload = util.GetTimeAsMsSinceEpoch(time.Now())
p.dbHandle.users[user.Username] = user
return nil
}
func (p *MemoryProvider) setFirstUploadTimestamp(username string) error {
p.dbHandle.Lock()
defer p.dbHandle.Unlock()
if p.dbHandle.isClosed {
return errMemoryProviderClosed
}
user, err := p.userExistsInternal(username)
if err != nil {
return err
}
if user.FirstUpload > 0 {
return util.NewGenericError(fmt.Sprintf("first upload already set to %v",
util.GetTimeFromMsecSinceEpoch(user.FirstUpload)))
}
user.FirstUpload = util.GetTimeAsMsSinceEpoch(time.Now())
p.dbHandle.users[user.Username] = user
return nil
}
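Note: both setters above implement a write-once field: once the timestamp is non-zero the update is rejected, so only the very first download/upload is recorded. Reduced to its core:

```go
// setOnce stores now in *ts only if it was never set before.
func setOnce(ts *int64, now int64) error {
	if *ts > 0 {
		return fmt.Errorf("already set to %d", *ts)
	}
	*ts = now
	return nil
}
```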
func (p *MemoryProvider) getNextID() int64 {
	nextID := int64(1)
	for _, v := range p.dbHandle.users {
@@ -2400,26 +1822,6 @@ func (p *MemoryProvider) getNextGroupID() int64 {
	return nextID
}
func (p *MemoryProvider) getNextActionID() int64 {
nextID := int64(1)
for _, a := range p.dbHandle.actions {
if a.ID >= nextID {
nextID = a.ID + 1
}
}
return nextID
}
func (p *MemoryProvider) getNextRuleID() int64 {
nextID := int64(1)
for _, r := range p.dbHandle.rules {
if r.ID >= nextID {
nextID = r.ID + 1
}
}
return nextID
}
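Note: every getNext*ID helper in this provider is the same max-plus-one scan over its backing map; a generic sketch of the shared shape:

```go
// nextID returns max(ID)+1 across the stored values, or 1 for an empty map.
func nextID[K comparable, V any](m map[K]V, id func(V) int64) int64 {
	next := int64(1)
	for _, v := range m {
		if id(v) >= next {
			next = id(v) + 1
		}
	}
	return next
}

// e.g. getNextRuleID is roughly:
//   nextID(p.dbHandle.rules, func(r EventRule) int64 { return r.ID })
```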
func (p *MemoryProvider) clear() {
	p.dbHandle.Lock()
	defer p.dbHandle.Unlock()
@@ -2468,35 +1870,27 @@ func (p *MemoryProvider) reloadConfig() error {
	}
	p.clear()
-	if err := p.restoreFolders(dump); err != nil {
+	if err := p.restoreFolders(&dump); err != nil {
		return err
	}
-	if err := p.restoreGroups(dump); err != nil {
+	if err := p.restoreGroups(&dump); err != nil {
		return err
	}
-	if err := p.restoreUsers(dump); err != nil {
+	if err := p.restoreUsers(&dump); err != nil {
		return err
	}
-	if err := p.restoreAdmins(dump); err != nil {
+	if err := p.restoreAdmins(&dump); err != nil {
		return err
	}
-	if err := p.restoreAPIKeys(dump); err != nil {
+	if err := p.restoreAPIKeys(&dump); err != nil {
		return err
	}
-	if err := p.restoreShares(dump); err != nil {
-		return err
-	}
-	if err := p.restoreEventActions(dump); err != nil {
-		return err
-	}
-	if err := p.restoreEventRules(dump); err != nil {
+	if err := p.restoreShares(&dump); err != nil {
		return err
	}
@@ -2504,51 +1898,7 @@ func (p *MemoryProvider) reloadConfig() error {
	return nil
}
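Note: the restore order in reloadConfig is dependency-driven: folders first (groups and users reference them), then groups, then users, then admins, API keys and shares, which reference users or admins. The run of nearly identical if blocks could be expressed as a table; an equivalent sketch for the right-hand version:

```go
for _, step := range []struct {
	name    string
	restore func(*BackupData) error
}{
	{"folders", p.restoreFolders},
	{"groups", p.restoreGroups},
	{"users", p.restoreUsers},
	{"admins", p.restoreAdmins},
	{"api keys", p.restoreAPIKeys},
	{"shares", p.restoreShares},
} {
	if err := step.restore(&dump); err != nil {
		return fmt.Errorf("unable to restore %s: %w", step.name, err)
	}
}
```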
-func (p *MemoryProvider) restoreEventActions(dump BackupData) error {
-	for _, action := range dump.EventActions {
-		a, err := p.eventActionExists(action.Name)
-		action := action // pin
-		if err == nil {
-			action.ID = a.ID
-			err = UpdateEventAction(&action, ActionExecutorSystem, "")
-			if err != nil {
-				providerLog(logger.LevelError, "error updating event action %q: %v", action.Name, err)
-				return err
-			}
-		} else {
-			err = AddEventAction(&action, ActionExecutorSystem, "")
-			if err != nil {
-				providerLog(logger.LevelError, "error adding event action %q: %v", action.Name, err)
-				return err
-			}
-		}
-	}
-	return nil
-}
-
-func (p *MemoryProvider) restoreEventRules(dump BackupData) error {
-	for _, rule := range dump.EventRules {
-		r, err := p.eventRuleExists(rule.Name)
-		rule := rule // pin
-		if err == nil {
-			rule.ID = r.ID
-			err = UpdateEventRule(&rule, ActionExecutorSystem, "")
-			if err != nil {
-				providerLog(logger.LevelError, "error updating event rule %q: %v", rule.Name, err)
-				return err
-			}
-		} else {
-			err = AddEventRule(&rule, ActionExecutorSystem, "")
-			if err != nil {
-				providerLog(logger.LevelError, "error adding event rule %q: %v", rule.Name, err)
-				return err
-			}
-		}
-	}
-	return nil
-}
-
-func (p *MemoryProvider) restoreShares(dump BackupData) error {
+func (p *MemoryProvider) restoreShares(dump *BackupData) error {
	for _, share := range dump.Shares {
		s, err := p.shareExists(share.ShareID, "")
		share := share // pin
@@ -2571,7 +1921,7 @@ func (p *MemoryProvider) restoreShares(dump BackupData) error {
	return nil
}

-func (p *MemoryProvider) restoreAPIKeys(dump BackupData) error {
+func (p *MemoryProvider) restoreAPIKeys(dump *BackupData) error {
	for _, apiKey := range dump.APIKeys {
		if apiKey.Key == "" {
			return fmt.Errorf("cannot restore an empty API key: %+v", apiKey)
@@ -2596,7 +1946,7 @@ func (p *MemoryProvider) restoreAPIKeys(dump BackupData) error {
	return nil
}

-func (p *MemoryProvider) restoreAdmins(dump BackupData) error {
+func (p *MemoryProvider) restoreAdmins(dump *BackupData) error {
	for _, admin := range dump.Admins {
		admin := admin // pin
		admin.Username = config.convertName(admin.Username)
@@ -2619,7 +1969,7 @@ func (p *MemoryProvider) restoreAdmins(dump BackupData) error {
	return nil
}

-func (p *MemoryProvider) restoreGroups(dump BackupData) error {
+func (p *MemoryProvider) restoreGroups(dump *BackupData) error {
	for _, group := range dump.Groups {
		group := group // pin
		group.Name = config.convertName(group.Name)
@@ -2643,7 +1993,7 @@ func (p *MemoryProvider) restoreGroups(dump BackupData) error {
	return nil
}

-func (p *MemoryProvider) restoreFolders(dump BackupData) error {
+func (p *MemoryProvider) restoreFolders(dump *BackupData) error {
	for _, folder := range dump.Folders {
		folder := folder // pin
		folder.Name = config.convertName(folder.Name)
@@ -2667,7 +2017,7 @@ func (p *MemoryProvider) restoreFolders(dump BackupData) error {
	return nil
}

-func (p *MemoryProvider) restoreUsers(dump BackupData) error {
+func (p *MemoryProvider) restoreUsers(dump *BackupData) error {
	for _, user := range dump.Users {
		user := user // pin
		user.Username = config.convertName(user.Username)


@@ -25,15 +25,14 @@ import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"path/filepath"
"strings" "strings"
"time" "time"
"github.com/go-sql-driver/mysql" "github.com/go-sql-driver/mysql"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/internal/version" "github.com/drakkan/sftpgo/v2/version"
"github.com/drakkan/sftpgo/v2/internal/vfs" "github.com/drakkan/sftpgo/v2/vfs"
) )
const ( const (
@@ -41,7 +40,6 @@ const (
"DROP TABLE IF EXISTS `{{folders_mapping}}` CASCADE;" + "DROP TABLE IF EXISTS `{{folders_mapping}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{users_folders_mapping}}` CASCADE;" + "DROP TABLE IF EXISTS `{{users_folders_mapping}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{users_groups_mapping}}` CASCADE;" + "DROP TABLE IF EXISTS `{{users_groups_mapping}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{admins_groups_mapping}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{groups_folders_mapping}}` CASCADE;" + "DROP TABLE IF EXISTS `{{groups_folders_mapping}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{admins}}` CASCADE;" + "DROP TABLE IF EXISTS `{{admins}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{folders}}` CASCADE;" + "DROP TABLE IF EXISTS `{{folders}}` CASCADE;" +
@@ -52,22 +50,12 @@ const (
"DROP TABLE IF EXISTS `{{defender_hosts}}` CASCADE;" + "DROP TABLE IF EXISTS `{{defender_hosts}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{active_transfers}}` CASCADE;" + "DROP TABLE IF EXISTS `{{active_transfers}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{shared_sessions}}` CASCADE;" + "DROP TABLE IF EXISTS `{{shared_sessions}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{rules_actions_mapping}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{events_actions}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{events_rules}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{tasks}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{nodes}}` CASCADE;" +
"DROP TABLE IF EXISTS `{{schema_version}}` CASCADE;" "DROP TABLE IF EXISTS `{{schema_version}}` CASCADE;"
mysqlInitialSQL = "CREATE TABLE `{{schema_version}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `version` integer NOT NULL);" + mysqlInitialSQL = "CREATE TABLE `{{schema_version}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `version` integer NOT NULL);" +
"CREATE TABLE `{{admins}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " + "CREATE TABLE `{{admins}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " +
"`description` varchar(512) NULL, `password` varchar(255) NOT NULL, `email` varchar(255) NULL, `status` integer NOT NULL, " + "`description` varchar(512) NULL, `password` varchar(255) NOT NULL, `email` varchar(255) NULL, `status` integer NOT NULL, " +
"`permissions` longtext NOT NULL, `filters` longtext NULL, `additional_info` longtext NULL, `last_login` bigint NOT NULL, " + "`permissions` longtext NOT NULL, `filters` longtext NULL, `additional_info` longtext NULL, `last_login` bigint NOT NULL, " +
"`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL);" + "`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL);" +
"CREATE TABLE `{{active_transfers}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`connection_id` varchar(100) NOT NULL, `transfer_id` bigint NOT NULL, `transfer_type` integer NOT NULL, " +
"`username` varchar(255) NOT NULL, `folder_name` varchar(255) NULL, `ip` varchar(50) NOT NULL, " +
"`truncated_size` bigint NOT NULL, `current_ul_size` bigint NOT NULL, `current_dl_size` bigint NOT NULL, " +
"`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL);" +
"CREATE TABLE `{{defender_hosts}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "CREATE TABLE `{{defender_hosts}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`ip` varchar(50) NOT NULL UNIQUE, `ban_time` bigint NOT NULL, `updated_at` bigint NOT NULL);" + "`ip` varchar(50) NOT NULL UNIQUE, `ban_time` bigint NOT NULL, `updated_at` bigint NOT NULL);" +
"CREATE TABLE `{{defender_events}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "CREATE TABLE `{{defender_events}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
@@ -77,11 +65,6 @@ const (
"CREATE TABLE `{{folders}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(255) NOT NULL UNIQUE, " + "CREATE TABLE `{{folders}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(255) NOT NULL UNIQUE, " +
"`description` varchar(512) NULL, `path` longtext NULL, `used_quota_size` bigint NOT NULL, " + "`description` varchar(512) NULL, `path` longtext NULL, `used_quota_size` bigint NOT NULL, " +
"`used_quota_files` integer NOT NULL, `last_quota_update` bigint NOT NULL, `filesystem` longtext NULL);" + "`used_quota_files` integer NOT NULL, `last_quota_update` bigint NOT NULL, `filesystem` longtext NULL);" +
"CREATE TABLE `{{groups}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`name` varchar(255) NOT NULL UNIQUE, `description` varchar(512) NULL, `created_at` bigint NOT NULL, " +
"`updated_at` bigint NOT NULL, `user_settings` longtext NULL);" +
"CREATE TABLE `{{shared_sessions}}` (`key` varchar(128) NOT NULL PRIMARY KEY, " +
"`data` longtext NOT NULL, `type` integer NOT NULL, `timestamp` bigint NOT NULL);" +
"CREATE TABLE `{{users}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " + "CREATE TABLE `{{users}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `username` varchar(255) NOT NULL UNIQUE, " +
"`status` integer NOT NULL, `expiration_date` bigint NOT NULL, `description` varchar(512) NULL, `password` longtext NULL, " + "`status` integer NOT NULL, `expiration_date` bigint NOT NULL, `description` varchar(512) NULL, `password` longtext NULL, " +
"`public_keys` longtext NULL, `home_dir` longtext NOT NULL, `uid` bigint NOT NULL, `gid` bigint NOT NULL, " + "`public_keys` longtext NULL, `home_dir` longtext NOT NULL, `uid` bigint NOT NULL, `gid` bigint NOT NULL, " +
@@ -89,33 +72,12 @@ const (
"`permissions` longtext NOT NULL, `used_quota_size` bigint NOT NULL, `used_quota_files` integer NOT NULL, " + "`permissions` longtext NOT NULL, `used_quota_size` bigint NOT NULL, `used_quota_files` integer NOT NULL, " +
"`last_quota_update` bigint NOT NULL, `upload_bandwidth` integer NOT NULL, `download_bandwidth` integer NOT NULL, " + "`last_quota_update` bigint NOT NULL, `upload_bandwidth` integer NOT NULL, `download_bandwidth` integer NOT NULL, " +
"`last_login` bigint NOT NULL, `filters` longtext NULL, `filesystem` longtext NULL, `additional_info` longtext NULL, " + "`last_login` bigint NOT NULL, `filters` longtext NULL, `filesystem` longtext NULL, `additional_info` longtext NULL, " +
"`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL, `email` varchar(255) NULL, " + "`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL, `email` varchar(255) NULL);" +
"`upload_data_transfer` integer NOT NULL, `download_data_transfer` integer NOT NULL, " + "CREATE TABLE `{{folders_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `virtual_path` longtext NOT NULL, " +
"`total_data_transfer` integer NOT NULL, `used_upload_data_transfer` integer NOT NULL, " +
"`used_download_data_transfer` integer NOT NULL);" +
"CREATE TABLE `{{groups_folders_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`group_id` integer NOT NULL, `folder_id` integer NOT NULL, " +
"`virtual_path` longtext NOT NULL, `quota_size` bigint NOT NULL, `quota_files` integer NOT NULL);" +
"CREATE TABLE `{{users_groups_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`user_id` integer NOT NULL, `group_id` integer NOT NULL, `group_type` integer NOT NULL);" +
"CREATE TABLE `{{users_folders_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `virtual_path` longtext NOT NULL, " +
"`quota_size` bigint NOT NULL, `quota_files` integer NOT NULL, `folder_id` integer NOT NULL, `user_id` integer NOT NULL);" + "`quota_size` bigint NOT NULL, `quota_files` integer NOT NULL, `folder_id` integer NOT NULL, `user_id` integer NOT NULL);" +
"ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_user_folder_mapping` " + "ALTER TABLE `{{folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_mapping` UNIQUE (`user_id`, `folder_id`);" +
"UNIQUE (`user_id`, `folder_id`);" + "ALTER TABLE `{{folders_mapping}}` ADD CONSTRAINT `{{prefix}}folders_mapping_folder_id_fk_folders_id` FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}users_folders_mapping_user_id_fk_users_id` " + "ALTER TABLE `{{folders_mapping}}` ADD CONSTRAINT `{{prefix}}folders_mapping_user_id_fk_users_id` FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" +
"FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}users_folders_mapping_folder_id_fk_folders_id` " +
"FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}unique_user_group_mapping` UNIQUE (`user_id`, `group_id`);" +
"ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_group_folder_mapping` UNIQUE (`group_id`, `folder_id`);" +
"ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}users_groups_mapping_group_id_fk_groups_id` " +
"FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE NO ACTION;" +
"ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}users_groups_mapping_user_id_fk_users_id` " +
"FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}groups_folders_mapping_folder_id_fk_folders_id` " +
"FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}groups_folders_mapping_group_id_fk_groups_id` " +
"FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE CASCADE;" +
"CREATE TABLE `{{shares}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "CREATE TABLE `{{shares}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`share_id` varchar(60) NOT NULL UNIQUE, `name` varchar(255) NOT NULL, `description` varchar(512) NULL, " + "`share_id` varchar(60) NOT NULL UNIQUE, `name` varchar(255) NOT NULL, `description` varchar(512) NULL, " +
"`scope` integer NOT NULL, `paths` longtext NOT NULL, `created_at` bigint NOT NULL, " + "`scope` integer NOT NULL, `paths` longtext NOT NULL, `created_at` bigint NOT NULL, " +
@@ -133,60 +95,83 @@ const (
"CREATE INDEX `{{prefix}}defender_hosts_updated_at_idx` ON `{{defender_hosts}}` (`updated_at`);" + "CREATE INDEX `{{prefix}}defender_hosts_updated_at_idx` ON `{{defender_hosts}}` (`updated_at`);" +
"CREATE INDEX `{{prefix}}defender_hosts_ban_time_idx` ON `{{defender_hosts}}` (`ban_time`);" + "CREATE INDEX `{{prefix}}defender_hosts_ban_time_idx` ON `{{defender_hosts}}` (`ban_time`);" +
"CREATE INDEX `{{prefix}}defender_events_date_time_idx` ON `{{defender_events}}` (`date_time`);" + "CREATE INDEX `{{prefix}}defender_events_date_time_idx` ON `{{defender_events}}` (`date_time`);" +
"INSERT INTO {{schema_version}} (version) VALUES (15);"
mysqlV16SQL = "ALTER TABLE `{{users}}` ADD COLUMN `download_data_transfer` integer DEFAULT 0 NOT NULL;" +
"ALTER TABLE `{{users}}` ALTER COLUMN `download_data_transfer` DROP DEFAULT;" +
"ALTER TABLE `{{users}}` ADD COLUMN `total_data_transfer` integer DEFAULT 0 NOT NULL;" +
"ALTER TABLE `{{users}}` ALTER COLUMN `total_data_transfer` DROP DEFAULT;" +
"ALTER TABLE `{{users}}` ADD COLUMN `upload_data_transfer` integer DEFAULT 0 NOT NULL;" +
"ALTER TABLE `{{users}}` ALTER COLUMN `upload_data_transfer` DROP DEFAULT;" +
"ALTER TABLE `{{users}}` ADD COLUMN `used_download_data_transfer` integer DEFAULT 0 NOT NULL;" +
"ALTER TABLE `{{users}}` ALTER COLUMN `used_download_data_transfer` DROP DEFAULT;" +
"ALTER TABLE `{{users}}` ADD COLUMN `used_upload_data_transfer` integer DEFAULT 0 NOT NULL;" +
"ALTER TABLE `{{users}}` ALTER COLUMN `used_upload_data_transfer` DROP DEFAULT;" +
"CREATE TABLE `{{active_transfers}}` (`id` bigint AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`connection_id` varchar(100) NOT NULL, `transfer_id` bigint NOT NULL, `transfer_type` integer NOT NULL, " +
"`username` varchar(255) NOT NULL, `folder_name` varchar(255) NULL, `ip` varchar(50) NOT NULL, " +
"`truncated_size` bigint NOT NULL, `current_ul_size` bigint NOT NULL, `current_dl_size` bigint NOT NULL, " +
"`created_at` bigint NOT NULL, `updated_at` bigint NOT NULL);" +
"CREATE INDEX `{{prefix}}active_transfers_connection_id_idx` ON `{{active_transfers}}` (`connection_id`);" + "CREATE INDEX `{{prefix}}active_transfers_connection_id_idx` ON `{{active_transfers}}` (`connection_id`);" +
"CREATE INDEX `{{prefix}}active_transfers_transfer_id_idx` ON `{{active_transfers}}` (`transfer_id`);" + "CREATE INDEX `{{prefix}}active_transfers_transfer_id_idx` ON `{{active_transfers}}` (`transfer_id`);" +
"CREATE INDEX `{{prefix}}active_transfers_updated_at_idx` ON `{{active_transfers}}` (`updated_at`);" + "CREATE INDEX `{{prefix}}active_transfers_updated_at_idx` ON `{{active_transfers}}` (`updated_at`);"
"CREATE INDEX `{{prefix}}groups_updated_at_idx` ON `{{groups}}` (`updated_at`);" + mysqlV16DownSQL = "ALTER TABLE `{{users}}` DROP COLUMN `used_upload_data_transfer`;" +
"CREATE INDEX `{{prefix}}shared_sessions_type_idx` ON `{{shared_sessions}}` (`type`);" + "ALTER TABLE `{{users}}` DROP COLUMN `used_download_data_transfer`;" +
"CREATE INDEX `{{prefix}}shared_sessions_timestamp_idx` ON `{{shared_sessions}}` (`timestamp`);" + "ALTER TABLE `{{users}}` DROP COLUMN `upload_data_transfer`;" +
"INSERT INTO {{schema_version}} (version) VALUES (19);" "ALTER TABLE `{{users}}` DROP COLUMN `total_data_transfer`;" +
mysqlV20SQL = "CREATE TABLE `{{events_rules}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "ALTER TABLE `{{users}}` DROP COLUMN `download_data_transfer`;" +
"DROP TABLE `{{active_transfers}}` CASCADE;"
mysqlV17SQL = "CREATE TABLE `{{groups}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`name` varchar(255) NOT NULL UNIQUE, `description` varchar(512) NULL, `created_at` bigint NOT NULL, " + "`name` varchar(255) NOT NULL UNIQUE, `description` varchar(512) NULL, `created_at` bigint NOT NULL, " +
"`updated_at` bigint NOT NULL, `trigger` integer NOT NULL, `conditions` longtext NOT NULL, `deleted_at` bigint NOT NULL);" + "`updated_at` bigint NOT NULL, `user_settings` longtext NULL);" +
"CREATE TABLE `{{events_actions}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "CREATE TABLE `{{groups_folders_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`name` varchar(255) NOT NULL UNIQUE, `description` varchar(512) NULL, `type` integer NOT NULL, " + "`group_id` integer NOT NULL, `folder_id` integer NOT NULL, " +
"`options` longtext NOT NULL);" + "`virtual_path` longtext NOT NULL, `quota_size` bigint NOT NULL, `quota_files` integer NOT NULL);" +
"CREATE TABLE `{{rules_actions_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "CREATE TABLE `{{users_groups_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " +
"`rule_id` integer NOT NULL, `action_id` integer NOT NULL, `order` integer NOT NULL, `options` longtext NOT NULL);" + "`user_id` integer NOT NULL, `group_id` integer NOT NULL, `group_type` integer NOT NULL);" +
"CREATE TABLE `{{tasks}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, `name` varchar(255) NOT NULL UNIQUE, " + "ALTER TABLE `{{folders_mapping}}` DROP FOREIGN KEY `{{prefix}}folders_mapping_folder_id_fk_folders_id`;" +
"`updated_at` bigint NOT NULL, `version` bigint NOT NULL);" + "ALTER TABLE `{{folders_mapping}}` DROP FOREIGN KEY `{{prefix}}folders_mapping_user_id_fk_users_id`;" +
"ALTER TABLE `{{rules_actions_mapping}}` ADD CONSTRAINT `{{prefix}}unique_rule_action_mapping` UNIQUE (`rule_id`, `action_id`);" + "ALTER TABLE `{{folders_mapping}}` DROP INDEX `{{prefix}}unique_mapping`;" +
"ALTER TABLE `{{rules_actions_mapping}}` ADD CONSTRAINT `{{prefix}}rules_actions_mapping_rule_id_fk_events_rules_id` " + "RENAME TABLE `{{folders_mapping}}` TO `{{users_folders_mapping}}`;" +
"FOREIGN KEY (`rule_id`) REFERENCES `{{events_rules}}` (`id`) ON DELETE CASCADE;" + "ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_user_folder_mapping` " +
"ALTER TABLE `{{rules_actions_mapping}}` ADD CONSTRAINT `{{prefix}}rules_actions_mapping_action_id_fk_events_targets_id` " + "UNIQUE (`user_id`, `folder_id`);" +
"FOREIGN KEY (`action_id`) REFERENCES `{{events_actions}}` (`id`) ON DELETE NO ACTION;" + "ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}users_folders_mapping_user_id_fk_users_id` " +
"ALTER TABLE `{{users}}` ADD COLUMN `deleted_at` bigint DEFAULT 0 NOT NULL;" + "FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{users}}` ALTER COLUMN `deleted_at` DROP DEFAULT;" + "ALTER TABLE `{{users_folders_mapping}}` ADD CONSTRAINT `{{prefix}}users_folders_mapping_folder_id_fk_folders_id` " +
"CREATE INDEX `{{prefix}}events_rules_updated_at_idx` ON `{{events_rules}}` (`updated_at`);" + "FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;" +
"CREATE INDEX `{{prefix}}events_rules_deleted_at_idx` ON `{{events_rules}}` (`deleted_at`);" + "ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}unique_user_group_mapping` UNIQUE (`user_id`, `group_id`);" +
"CREATE INDEX `{{prefix}}events_rules_trigger_idx` ON `{{events_rules}}` (`trigger`);" + "ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_group_folder_mapping` UNIQUE (`group_id`, `folder_id`);" +
"CREATE INDEX `{{prefix}}rules_actions_mapping_order_idx` ON `{{rules_actions_mapping}}` (`order`);" + "ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}users_groups_mapping_group_id_fk_groups_id` " +
"CREATE INDEX `{{prefix}}users_deleted_at_idx` ON `{{users}}` (`deleted_at`);" "FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE NO ACTION;" +
mysqlV20DownSQL = "DROP TABLE `{{rules_actions_mapping}}` CASCADE;" + "ALTER TABLE `{{users_groups_mapping}}` ADD CONSTRAINT `{{prefix}}users_groups_mapping_user_id_fk_users_id` " +
"DROP TABLE `{{events_rules}}` CASCADE;" + "FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" +
"DROP TABLE `{{events_actions}}` CASCADE;" + "ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}groups_folders_mapping_folder_id_fk_folders_id` " +
"DROP TABLE `{{tasks}}` CASCADE;" + "FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{users}}` DROP COLUMN `deleted_at`;" "ALTER TABLE `{{groups_folders_mapping}}` ADD CONSTRAINT `{{prefix}}groups_folders_mapping_group_id_fk_groups_id` " +
mysqlV21SQL = "ALTER TABLE `{{users}}` ADD COLUMN `first_download` bigint DEFAULT 0 NOT NULL; " + "FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE CASCADE;" +
"ALTER TABLE `{{users}}` ALTER COLUMN `first_download` DROP DEFAULT; " + "CREATE INDEX `{{prefix}}groups_updated_at_idx` ON `{{groups}}` (`updated_at`);"
"ALTER TABLE `{{users}}` ADD COLUMN `first_upload` bigint DEFAULT 0 NOT NULL; " + mysqlV17DownSQL = "ALTER TABLE `{{groups_folders_mapping}}` DROP FOREIGN KEY `{{prefix}}groups_folders_mapping_group_id_fk_groups_id`;" +
"ALTER TABLE `{{users}}` ALTER COLUMN `first_upload` DROP DEFAULT;" "ALTER TABLE `{{groups_folders_mapping}}` DROP FOREIGN KEY `{{prefix}}groups_folders_mapping_folder_id_fk_folders_id`;" +
mysqlV21DownSQL = "ALTER TABLE `{{users}}` DROP COLUMN `first_upload`; " + "ALTER TABLE `{{users_groups_mapping}}` DROP FOREIGN KEY `{{prefix}}users_groups_mapping_user_id_fk_users_id`;" +
"ALTER TABLE `{{users}}` DROP COLUMN `first_download`;" "ALTER TABLE `{{users_groups_mapping}}` DROP FOREIGN KEY `{{prefix}}users_groups_mapping_group_id_fk_groups_id`;" +
mysqlV22SQL = "CREATE TABLE `{{admins_groups_mapping}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "ALTER TABLE `{{groups_folders_mapping}}` DROP INDEX `{{prefix}}unique_group_folder_mapping`;" +
" `admin_id` integer NOT NULL, `group_id` integer NOT NULL, `options` longtext NOT NULL);" + "ALTER TABLE `{{users_groups_mapping}}` DROP INDEX `{{prefix}}unique_user_group_mapping`;" +
"ALTER TABLE `{{admins_groups_mapping}}` ADD CONSTRAINT `{{prefix}}unique_admin_group_mapping` " + "DROP TABLE `{{users_groups_mapping}}` CASCADE;" +
"UNIQUE (`admin_id`, `group_id`);" + "DROP TABLE `{{groups_folders_mapping}}` CASCADE;" +
"ALTER TABLE `{{admins_groups_mapping}}` ADD CONSTRAINT `{{prefix}}admins_groups_mapping_admin_id_fk_admins_id` " + "DROP TABLE `{{groups}}` CASCADE;" +
"FOREIGN KEY (`admin_id`) REFERENCES `{{admins}}` (`id`) ON DELETE CASCADE;" + "ALTER TABLE `{{users_folders_mapping}}` DROP FOREIGN KEY `{{prefix}}users_folders_mapping_folder_id_fk_folders_id`;" +
"ALTER TABLE `{{admins_groups_mapping}}` ADD CONSTRAINT `{{prefix}}admins_groups_mapping_group_id_fk_groups_id` " + "ALTER TABLE `{{users_folders_mapping}}` DROP FOREIGN KEY `{{prefix}}users_folders_mapping_user_id_fk_users_id`;" +
"FOREIGN KEY (`group_id`) REFERENCES `{{groups}}` (`id`) ON DELETE CASCADE;" "ALTER TABLE `{{users_folders_mapping}}` DROP INDEX `{{prefix}}unique_user_folder_mapping`;" +
mysqlV22DownSQL = "ALTER TABLE `{{admins_groups_mapping}}` DROP INDEX `{{prefix}}unique_admin_group_mapping`;" + "RENAME TABLE `{{users_folders_mapping}}` TO `{{folders_mapping}}`;" +
"DROP TABLE `{{admins_groups_mapping}}` CASCADE;" "ALTER TABLE `{{folders_mapping}}` ADD CONSTRAINT `{{prefix}}unique_mapping` UNIQUE (`user_id`, `folder_id`);" +
mysqlV23SQL = "CREATE TABLE `{{nodes}}` (`id` integer AUTO_INCREMENT NOT NULL PRIMARY KEY, " + "ALTER TABLE `{{folders_mapping}}` ADD CONSTRAINT `{{prefix}}folders_mapping_user_id_fk_users_id` " +
"`name` varchar(255) NOT NULL UNIQUE, `data` longtext NOT NULL, `created_at` bigint NOT NULL, " + "FOREIGN KEY (`user_id`) REFERENCES `{{users}}` (`id`) ON DELETE CASCADE;" +
"`updated_at` bigint NOT NULL);" "ALTER TABLE `{{folders_mapping}}` ADD CONSTRAINT `{{prefix}}folders_mapping_folder_id_fk_folders_id` " +
mysqlV23DownSQL = "DROP TABLE `{{nodes}}` CASCADE;" "FOREIGN KEY (`folder_id`) REFERENCES `{{folders}}` (`id`) ON DELETE CASCADE;"
mysqlV19SQL = "CREATE TABLE `{{shared_sessions}}` (`key` varchar(128) NOT NULL PRIMARY KEY, " +
"`data` longtext NOT NULL, `type` integer NOT NULL, `timestamp` bigint NOT NULL);" +
"CREATE INDEX `{{prefix}}shared_sessions_type_idx` ON `{{shared_sessions}}` (`type`);" +
"CREATE INDEX `{{prefix}}shared_sessions_timestamp_idx` ON `{{shared_sessions}}` (`timestamp`);"
mysqlV19DownSQL = "DROP TABLE `{{shared_sessions}}` CASCADE;"
)
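Note: all of these SQL constants use {{table}} placeholders so a configurable table prefix can be applied before execution (the left-hand side funnels this through sqlReplaceAll, the right-hand side calls strings.ReplaceAll per table, as in initializeDatabase below), and each mysqlVnSQL migration has a matching mysqlVnDownSQL so a schema step can be reverted. The substitution mechanism, as a sketch with an abbreviated, illustrative table map:

```go
// replacePlaceholders substitutes each {{name}} token with its resolved,
// prefix-aware table name; the table set here is abbreviated for illustration.
func replacePlaceholders(sql string, tables map[string]string) string {
	for placeholder, tableName := range tables {
		sql = strings.ReplaceAll(sql, "{{"+placeholder+"}}", tableName)
	}
	return sql
}
```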
// MySQLProvider defines the auth provider for MySQL/MariaDB database
@@ -220,7 +205,6 @@ func initializeMySQLProvider() error {
		dbHandle.SetMaxIdleConns(2)
	}
	dbHandle.SetConnMaxLifetime(240 * time.Second)
-	dbHandle.SetConnMaxIdleTime(120 * time.Second)
	provider = &MySQLProvider{dbHandle: dbHandle}
	} else {
		providerLog(logger.LevelError, "error creating mysql database handler, connection string: %#v, error: %v",
@@ -237,8 +221,37 @@ func getMySQLConnectionString(redactedPwd bool) (string, error) {
	}
	sslMode := getSSLMode()
	if sslMode == "custom" && !redactedPwd {
-		if err := registerMySQLCustomTLSConfig(); err != nil {
-			return "", err
+		tlsConfig := &tls.Config{}
+		if config.RootCert != "" {
+			rootCAs, err := x509.SystemCertPool()
+			if err != nil {
+				rootCAs = x509.NewCertPool()
+			}
+			rootCrt, err := os.ReadFile(config.RootCert)
+			if err != nil {
+				return "", fmt.Errorf("unable to load root certificate %#v: %v", config.RootCert, err)
+			}
+			if !rootCAs.AppendCertsFromPEM(rootCrt) {
+				return "", fmt.Errorf("unable to parse root certificate %#v", config.RootCert)
+			}
+			tlsConfig.RootCAs = rootCAs
+		}
+		if config.ClientCert != "" && config.ClientKey != "" {
+			clientCert := make([]tls.Certificate, 0, 1)
+			tlsCert, err := tls.LoadX509KeyPair(config.ClientCert, config.ClientKey)
+			if err != nil {
+				return "", fmt.Errorf("unable to load key pair %#v, %#v: %v", config.ClientCert, config.ClientKey, err)
+			}
+			clientCert = append(clientCert, tlsCert)
+			tlsConfig.Certificates = clientCert
+		}
+		if config.SSLMode == 2 {
+			tlsConfig.InsecureSkipVerify = true
+		}
+		providerLog(logger.LevelInfo, "registering custom TLS config, root cert %#v, client cert %#v, client key %#v",
+			config.RootCert, config.ClientCert, config.ClientKey)
+		if err := mysql.RegisterTLSConfig("custom", tlsConfig); err != nil {
+			return "", fmt.Errorf("unable to register tls config: %v", err)
		}
	}
	connectionString = fmt.Sprintf("%v:%v@tcp([%v]:%v)/%v?charset=utf8mb4&interpolateParams=true&timeout=10s&parseTime=true&tls=%v&writeTimeout=60s&readTimeout=60s",
@@ -249,45 +262,6 @@ func getMySQLConnectionString(redactedPwd bool) (string, error) {
	return connectionString, nil
}
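Note: both versions follow the documented go-sql-driver/mysql pattern for custom TLS: build a tls.Config, register it under a name with mysql.RegisterTLSConfig, then select it via the DSN's tls parameter. A compact usage sketch; the path and DSN are illustrative:

```go
rootCAs := x509.NewCertPool()
pem, err := os.ReadFile("/path/to/ca.pem") // illustrative CA path
if err != nil {
	return err
}
if !rootCAs.AppendCertsFromPEM(pem) {
	return errors.New("unable to parse root certificate")
}
// "custom" must not collide with the driver's reserved names
// (true, false, skip-verify, preferred).
if err := mysql.RegisterTLSConfig("custom", &tls.Config{RootCAs: rootCAs}); err != nil {
	return err
}
dsn := "user:password@tcp(db.example.com:3306)/sftpgo?tls=custom"
```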
func registerMySQLCustomTLSConfig() error {
tlsConfig := &tls.Config{}
if config.RootCert != "" {
rootCAs, err := x509.SystemCertPool()
if err != nil {
rootCAs = x509.NewCertPool()
}
rootCrt, err := os.ReadFile(config.RootCert)
if err != nil {
return fmt.Errorf("unable to load root certificate %#v: %v", config.RootCert, err)
}
if !rootCAs.AppendCertsFromPEM(rootCrt) {
return fmt.Errorf("unable to parse root certificate %#v", config.RootCert)
}
tlsConfig.RootCAs = rootCAs
}
if config.ClientCert != "" && config.ClientKey != "" {
clientCert := make([]tls.Certificate, 0, 1)
tlsCert, err := tls.LoadX509KeyPair(config.ClientCert, config.ClientKey)
if err != nil {
return fmt.Errorf("unable to load key pair %#v, %#v: %v", config.ClientCert, config.ClientKey, err)
}
clientCert = append(clientCert, tlsCert)
tlsConfig.Certificates = clientCert
}
if config.SSLMode == 2 || config.SSLMode == 3 {
tlsConfig.InsecureSkipVerify = true
}
if !filepath.IsAbs(config.Host) && !config.DisableSNI {
tlsConfig.ServerName = config.Host
}
providerLog(logger.LevelInfo, "registering custom TLS config, root cert %#v, client cert %#v, client key %#v, disable SNI? %v",
config.RootCert, config.ClientCert, config.ClientKey, config.DisableSNI)
if err := mysql.RegisterTLSConfig("custom", tlsConfig); err != nil {
return fmt.Errorf("unable to register tls config: %v", err)
}
return nil
}
func (p *MySQLProvider) checkAvailability() error {
	return sqlCommonCheckAvailability(p.dbHandle)
}
@@ -340,8 +314,8 @@ func (p *MySQLProvider) updateUser(user *User) error {
	return sqlCommonUpdateUser(user, p.dbHandle)
}

-func (p *MySQLProvider) deleteUser(user User, softDelete bool) error {
-	return sqlCommonDeleteUser(user, softDelete, p.dbHandle)
+func (p *MySQLProvider) deleteUser(user User) error {
+	return sqlCommonDeleteUser(user, p.dbHandle)
}

func (p *MySQLProvider) updateUserPassword(username, password string) error {
@@ -582,102 +556,6 @@ func (p *MySQLProvider) cleanupSharedSessions(sessionType SessionType, before in
	return sqlCommonCleanupSessions(sessionType, before, p.dbHandle)
}
-func (p *MySQLProvider) getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) {
-return sqlCommonGetEventActions(limit, offset, order, minimal, p.dbHandle)
-}
-func (p *MySQLProvider) dumpEventActions() ([]BaseEventAction, error) {
-return sqlCommonDumpEventActions(p.dbHandle)
-}
-func (p *MySQLProvider) eventActionExists(name string) (BaseEventAction, error) {
-return sqlCommonGetEventActionByName(name, p.dbHandle)
-}
-func (p *MySQLProvider) addEventAction(action *BaseEventAction) error {
-return sqlCommonAddEventAction(action, p.dbHandle)
-}
-func (p *MySQLProvider) updateEventAction(action *BaseEventAction) error {
-return sqlCommonUpdateEventAction(action, p.dbHandle)
-}
-func (p *MySQLProvider) deleteEventAction(action BaseEventAction) error {
-return sqlCommonDeleteEventAction(action, p.dbHandle)
-}
-func (p *MySQLProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) {
-return sqlCommonGetEventRules(limit, offset, order, p.dbHandle)
-}
-func (p *MySQLProvider) dumpEventRules() ([]EventRule, error) {
-return sqlCommonDumpEventRules(p.dbHandle)
-}
-func (p *MySQLProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) {
-return sqlCommonGetRecentlyUpdatedRules(after, p.dbHandle)
-}
-func (p *MySQLProvider) eventRuleExists(name string) (EventRule, error) {
-return sqlCommonGetEventRuleByName(name, p.dbHandle)
-}
-func (p *MySQLProvider) addEventRule(rule *EventRule) error {
-return sqlCommonAddEventRule(rule, p.dbHandle)
-}
-func (p *MySQLProvider) updateEventRule(rule *EventRule) error {
-return sqlCommonUpdateEventRule(rule, p.dbHandle)
-}
-func (p *MySQLProvider) deleteEventRule(rule EventRule, softDelete bool) error {
-return sqlCommonDeleteEventRule(rule, softDelete, p.dbHandle)
-}
-func (p *MySQLProvider) getTaskByName(name string) (Task, error) {
-return sqlCommonGetTaskByName(name, p.dbHandle)
-}
-func (p *MySQLProvider) addTask(name string) error {
-return sqlCommonAddTask(name, p.dbHandle)
-}
-func (p *MySQLProvider) updateTask(name string, version int64) error {
-return sqlCommonUpdateTask(name, version, p.dbHandle)
-}
-func (p *MySQLProvider) updateTaskTimestamp(name string) error {
-return sqlCommonUpdateTaskTimestamp(name, p.dbHandle)
-}
-func (p *MySQLProvider) addNode() error {
-return sqlCommonAddNode(p.dbHandle)
-}
-func (p *MySQLProvider) getNodeByName(name string) (Node, error) {
-return sqlCommonGetNodeByName(name, p.dbHandle)
-}
-func (p *MySQLProvider) getNodes() ([]Node, error) {
-return sqlCommonGetNodes(p.dbHandle)
-}
-func (p *MySQLProvider) updateNodeTimestamp() error {
-return sqlCommonUpdateNodeTimestamp(p.dbHandle)
-}
-func (p *MySQLProvider) cleanupNodes() error {
-return sqlCommonCleanupNodes(p.dbHandle)
-}
-func (p *MySQLProvider) setFirstDownloadTimestamp(username string) error {
-return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle)
-}
-func (p *MySQLProvider) setFirstUploadTimestamp(username string) error {
-return sqlCommonSetFirstUploadTimestamp(username, p.dbHandle)
-}
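All of the provider methods above are thin wrappers that forward to shared sqlCommon* helpers, which is how the MySQL, PostgreSQL and other SQL providers stay in sync. A minimal sketch of that delegation pattern; the interface, type and table names here are simplified stand-ins, not SFTPGo's actual definitions:

package main

import "database/sql"

// provider is a hypothetical subset of a data-provider interface.
type provider interface {
	addTask(name string) error
}

// sqlCommonAddTask is the shared, driver-agnostic implementation.
func sqlCommonAddTask(name string, db *sql.DB) error {
	_, err := db.Exec("INSERT INTO tasks (name) VALUES (?)", name)
	return err
}

// mySQLLikeProvider just forwards its own handle to the shared helper.
type mySQLLikeProvider struct{ dbHandle *sql.DB }

func (p *mySQLLikeProvider) addTask(name string) error {
	return sqlCommonAddTask(name, p.dbHandle)
}

func main() {
	var p provider = &mySQLLikeProvider{dbHandle: nil} // wiring only; no real DB here
	_ = p
}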
func (p *MySQLProvider) close() error {
return p.dbHandle.Close()
}
@@ -695,11 +573,20 @@ func (p *MySQLProvider) initializeDatabase() error {
if errors.Is(err, sql.ErrNoRows) {
return errSchemaVersionEmpty
}
logger.InfoToConsole("creating initial database schema, version 19") logger.InfoToConsole("creating initial database schema, version 15")
providerLog(logger.LevelInfo, "creating initial database schema, version 19") providerLog(logger.LevelInfo, "creating initial database schema, version 15")
initialSQL := sqlReplaceAll(mysqlInitialSQL) initialSQL := strings.ReplaceAll(mysqlInitialSQL, "{{schema_version}}", sqlTableSchemaVersion)
initialSQL = strings.ReplaceAll(initialSQL, "{{admins}}", sqlTableAdmins)
initialSQL = strings.ReplaceAll(initialSQL, "{{folders}}", sqlTableFolders)
initialSQL = strings.ReplaceAll(initialSQL, "{{users}}", sqlTableUsers)
initialSQL = strings.ReplaceAll(initialSQL, "{{folders_mapping}}", sqlTableFoldersMapping)
initialSQL = strings.ReplaceAll(initialSQL, "{{api_keys}}", sqlTableAPIKeys)
initialSQL = strings.ReplaceAll(initialSQL, "{{shares}}", sqlTableShares)
initialSQL = strings.ReplaceAll(initialSQL, "{{defender_events}}", sqlTableDefenderEvents)
initialSQL = strings.ReplaceAll(initialSQL, "{{defender_hosts}}", sqlTableDefenderHosts)
initialSQL = strings.ReplaceAll(initialSQL, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(initialSQL, ";"), 19, true) return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(initialSQL, ";"), 15, true)
} }
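The initial schema is a template: `{{name}}` placeholders are expanded with plain strings.ReplaceAll before execution, which is also how a configurable table prefix is applied. A self-contained sketch of that templating step, with example table names rather than SFTPGo's configured ones:

package main

import (
	"fmt"
	"strings"
)

func main() {
	const tpl = `CREATE TABLE "{{users}}" ("id" serial PRIMARY KEY);
INSERT INTO {{schema_version}} (version) VALUES (15);`

	// Expand each placeholder; the prefix is already baked into the table names here.
	sql := strings.ReplaceAll(tpl, "{{users}}", "sftpgo_users")
	sql = strings.ReplaceAll(sql, "{{schema_version}}", "sftpgo_schema_version")
	fmt.Println(sql)
}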
func (p *MySQLProvider) migrateDatabase() error { //nolint:dupl
@@ -712,28 +599,28 @@ func (p *MySQLProvider) migrateDatabase() error { //nolint:dupl
case version == sqlDatabaseVersion:
providerLog(logger.LevelDebug, "sql database is up to date, current version: %v", version)
return ErrNoInitRequired
-case version < 19:
-err = fmt.Errorf("database schema version %v is too old, please see the upgrading docs", version)
+case version < 15:
+err = fmt.Errorf("database version %v is too old, please see the upgrading docs", version)
providerLog(logger.LevelError, "%v", err)
logger.ErrorToConsole("%v", err)
return err
-case version == 19:
-return updateMySQLDatabaseFromV19(p.dbHandle)
-case version == 20:
-return updateMySQLDatabaseFromV20(p.dbHandle)
-case version == 21:
-return updateMySQLDatabaseFromV21(p.dbHandle)
-case version == 22:
-return updateMySQLDatabaseFromV22(p.dbHandle)
+case version == 15:
+return updateMySQLDatabaseFromV15(p.dbHandle)
+case version == 16:
+return updateMySQLDatabaseFromV16(p.dbHandle)
+case version == 17:
+return updateMySQLDatabaseFromV17(p.dbHandle)
+case version == 18:
+return updateMySQLDatabaseFromV18(p.dbHandle)
default:
if version > sqlDatabaseVersion {
-providerLog(logger.LevelError, "database schema version %v is newer than the supported one: %v", version,
+providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version,
sqlDatabaseVersion)
-logger.WarnToConsole("database schema version %v is newer than the supported one: %v", version,
+logger.WarnToConsole("database version %v is newer than the supported one: %v", version,
sqlDatabaseVersion)
return nil
}
-return fmt.Errorf("database schema version not handled: %v", version)
+return fmt.Errorf("database version not handled: %v", version)
}
}
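Both sides of this hunk implement the same idea: switch on the stored schema version and walk forward one step at a time until the latest version is reached. A compact sketch of that chaining, using hypothetical version numbers and a logging-only step instead of real SQL:

package main

import "fmt"

const latestVersion = 19

// step applies one version bump; here it only logs.
func step(from int) error {
	fmt.Printf("updating database version: %d -> %d\n", from, from+1)
	return nil
}

// migrate walks from the current version to latestVersion, one step at a time,
// mirroring the updateXFromVn -> updateXFromVn+1 call chain in the diff.
func migrate(current int) error {
	for v := current; v < latestVersion; v++ {
		if err := step(v); err != nil {
			return err
		}
	}
	return nil
}

func main() {
	if err := migrate(15); err != nil {
		fmt.Println("migration failed:", err)
	}
}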
@@ -747,16 +634,16 @@ func (p *MySQLProvider) revertDatabase(targetVersion int) error {
}
switch dbVersion.Version {
-case 20:
-return downgradeMySQLDatabaseFromV20(p.dbHandle)
-case 21:
-return downgradeMySQLDatabaseFromV21(p.dbHandle)
-case 22:
-return downgradeMySQLDatabaseFromV22(p.dbHandle)
-case 23:
-return downgradeMySQLDatabaseFromV23(p.dbHandle)
+case 16:
+return downgradeMySQLDatabaseFromV16(p.dbHandle)
+case 17:
+return downgradeMySQLDatabaseFromV17(p.dbHandle)
+case 18:
+return downgradeMySQLDatabaseFromV18(p.dbHandle)
+case 19:
+return downgradeMySQLDatabaseFromV19(p.dbHandle)
default:
-return fmt.Errorf("database schema version not handled: %v", dbVersion.Version)
+return fmt.Errorf("database version not handled: %v", dbVersion.Version)
}
}
@@ -765,121 +652,127 @@ func (p *MySQLProvider) resetDatabase() error {
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, strings.Split(sql, ";"), 0, false)
}
-func updateMySQLDatabaseFromV19(dbHandle *sql.DB) error {
-if err := updateMySQLDatabaseFrom19To20(dbHandle); err != nil {
+func updateMySQLDatabaseFromV15(dbHandle *sql.DB) error {
+if err := updateMySQLDatabaseFrom15To16(dbHandle); err != nil {
return err
}
-return updateMySQLDatabaseFromV20(dbHandle)
+return updateMySQLDatabaseFromV16(dbHandle)
}
-func updateMySQLDatabaseFromV20(dbHandle *sql.DB) error {
-if err := updateMySQLDatabaseFrom20To21(dbHandle); err != nil {
+func updateMySQLDatabaseFromV16(dbHandle *sql.DB) error {
+if err := updateMySQLDatabaseFrom16To17(dbHandle); err != nil {
return err
}
-return updateMySQLDatabaseFromV21(dbHandle)
+return updateMySQLDatabaseFromV17(dbHandle)
}
-func updateMySQLDatabaseFromV21(dbHandle *sql.DB) error {
-if err := updateMySQLDatabaseFrom21To22(dbHandle); err != nil {
+func updateMySQLDatabaseFromV17(dbHandle *sql.DB) error {
+if err := updateMySQLDatabaseFrom17To18(dbHandle); err != nil {
return err
}
-return updateMySQLDatabaseFromV22(dbHandle)
+return updateMySQLDatabaseFromV18(dbHandle)
}
-func updateMySQLDatabaseFromV22(dbHandle *sql.DB) error {
-return updateMySQLDatabaseFrom22To23(dbHandle)
+func updateMySQLDatabaseFromV18(dbHandle *sql.DB) error {
+return updateMySQLDatabaseFrom18To19(dbHandle)
}
-func downgradeMySQLDatabaseFromV20(dbHandle *sql.DB) error {
-return downgradeMySQLDatabaseFrom20To19(dbHandle)
+func downgradeMySQLDatabaseFromV16(dbHandle *sql.DB) error {
+return downgradeMySQLDatabaseFrom16To15(dbHandle)
}
-func downgradeMySQLDatabaseFromV21(dbHandle *sql.DB) error {
-if err := downgradeMySQLDatabaseFrom21To20(dbHandle); err != nil {
+func downgradeMySQLDatabaseFromV17(dbHandle *sql.DB) error {
+if err := downgradeMySQLDatabaseFrom17To16(dbHandle); err != nil {
return err
}
-return downgradeMySQLDatabaseFromV20(dbHandle)
+return downgradeMySQLDatabaseFromV16(dbHandle)
}
-func downgradeMySQLDatabaseFromV22(dbHandle *sql.DB) error {
-if err := downgradeMySQLDatabaseFrom22To21(dbHandle); err != nil {
+func downgradeMySQLDatabaseFromV18(dbHandle *sql.DB) error {
+if err := downgradeMySQLDatabaseFrom18To17(dbHandle); err != nil {
return err
}
-return downgradeMySQLDatabaseFromV21(dbHandle)
+return downgradeMySQLDatabaseFromV17(dbHandle)
}
-func downgradeMySQLDatabaseFromV23(dbHandle *sql.DB) error {
-if err := downgradeMySQLDatabaseFrom23To22(dbHandle); err != nil {
+func downgradeMySQLDatabaseFromV19(dbHandle *sql.DB) error {
+if err := downgradeMySQLDatabaseFrom19To18(dbHandle); err != nil {
return err
}
-return downgradeMySQLDatabaseFromV22(dbHandle)
+return downgradeMySQLDatabaseFromV18(dbHandle)
}
-func updateMySQLDatabaseFrom19To20(dbHandle *sql.DB) error {
-logger.InfoToConsole("updating database schema version: 19 -> 20")
-providerLog(logger.LevelInfo, "updating database schema version: 19 -> 20")
-sql := strings.ReplaceAll(mysqlV20SQL, "{{events_actions}}", sqlTableEventsActions)
-sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
-sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
-sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
-sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
-sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
-return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 20, true)
-}
-func updateMySQLDatabaseFrom20To21(dbHandle *sql.DB) error {
-logger.InfoToConsole("updating database schema version: 20 -> 21")
-providerLog(logger.LevelInfo, "updating database schema version: 20 -> 21")
-sql := strings.ReplaceAll(mysqlV21SQL, "{{users}}", sqlTableUsers)
-return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 21, true)
-}
-func updateMySQLDatabaseFrom21To22(dbHandle *sql.DB) error {
-logger.InfoToConsole("updating database schema version: 21 -> 22")
-providerLog(logger.LevelInfo, "updating database schema version: 21 -> 22")
-sql := strings.ReplaceAll(mysqlV22SQL, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping)
-sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins)
-sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
-sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
-return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 22, true)
-}
-func updateMySQLDatabaseFrom22To23(dbHandle *sql.DB) error {
-logger.InfoToConsole("updating database schema version: 22 -> 23")
-providerLog(logger.LevelInfo, "updating database schema version: 22 -> 23")
-sql := strings.ReplaceAll(mysqlV23SQL, "{{nodes}}", sqlTableNodes)
-return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 23, true)
-}
-func downgradeMySQLDatabaseFrom20To19(dbHandle *sql.DB) error {
-logger.InfoToConsole("downgrading database schema version: 20 -> 19")
-providerLog(logger.LevelInfo, "downgrading database schema version: 20 -> 19")
-sql := strings.ReplaceAll(mysqlV20DownSQL, "{{events_actions}}", sqlTableEventsActions)
-sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
-sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
-sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
-sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
-return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 19, false)
-}
-func downgradeMySQLDatabaseFrom21To20(dbHandle *sql.DB) error {
-logger.InfoToConsole("downgrading database schema version: 21 -> 20")
-providerLog(logger.LevelInfo, "downgrading database schema version: 21 -> 20")
-sql := strings.ReplaceAll(mysqlV21DownSQL, "{{users}}", sqlTableUsers)
-return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 20, false)
-}
-func downgradeMySQLDatabaseFrom22To21(dbHandle *sql.DB) error {
-logger.InfoToConsole("downgrading database schema version: 22 -> 21")
-providerLog(logger.LevelInfo, "downgrading database schema version: 22 -> 21")
-sql := strings.ReplaceAll(mysqlV22DownSQL, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping)
-sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
-return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 21, false)
-}
-func downgradeMySQLDatabaseFrom23To22(dbHandle *sql.DB) error {
-logger.InfoToConsole("downgrading database schema version: 23 -> 22")
-providerLog(logger.LevelInfo, "downgrading database schema version: 23 -> 22")
-sql := strings.ReplaceAll(mysqlV23DownSQL, "{{nodes}}", sqlTableNodes)
-return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 22, false)
+func updateMySQLDatabaseFrom15To16(dbHandle *sql.DB) error {
+logger.InfoToConsole("updating database version: 15 -> 16")
+providerLog(logger.LevelInfo, "updating database version: 15 -> 16")
+sql := strings.ReplaceAll(mysqlV16SQL, "{{users}}", sqlTableUsers)
+sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
+sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
+return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 16, true)
+}
+func updateMySQLDatabaseFrom16To17(dbHandle *sql.DB) error {
+logger.InfoToConsole("updating database version: 16 -> 17")
+providerLog(logger.LevelInfo, "updating database version: 16 -> 17")
+sql := strings.ReplaceAll(mysqlV17SQL, "{{users}}", sqlTableUsers)
+sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
+sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
+sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
+sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
+sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
+sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
+sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
+return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 17, true)
+}
+func updateMySQLDatabaseFrom17To18(dbHandle *sql.DB) error {
+logger.InfoToConsole("updating database version: 17 -> 18")
+providerLog(logger.LevelInfo, "updating database version: 17 -> 18")
+if err := importGCSCredentials(); err != nil {
+return err
+}
+return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 18, true)
+}
+func updateMySQLDatabaseFrom18To19(dbHandle *sql.DB) error {
+logger.InfoToConsole("updating database version: 18 -> 19")
+providerLog(logger.LevelInfo, "updating database version: 18 -> 19")
+sql := strings.ReplaceAll(mysqlV19SQL, "{{shared_sessions}}", sqlTableSharedSessions)
+sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
+return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 19, true)
+}
+func downgradeMySQLDatabaseFrom16To15(dbHandle *sql.DB) error {
+logger.InfoToConsole("downgrading database version: 16 -> 15")
+providerLog(logger.LevelInfo, "downgrading database version: 16 -> 15")
+sql := strings.ReplaceAll(mysqlV16DownSQL, "{{users}}", sqlTableUsers)
+sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
+return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 15, false)
+}
+func downgradeMySQLDatabaseFrom17To16(dbHandle *sql.DB) error {
+logger.InfoToConsole("downgrading database version: 17 -> 16")
+providerLog(logger.LevelInfo, "downgrading database version: 17 -> 16")
+sql := strings.ReplaceAll(mysqlV17DownSQL, "{{users}}", sqlTableUsers)
+sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
+sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
+sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
+sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
+sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
+sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
+sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
+return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, strings.Split(sql, ";"), 16, false)
+}
+func downgradeMySQLDatabaseFrom18To17(dbHandle *sql.DB) error {
+logger.InfoToConsole("downgrading database version: 18 -> 17")
+providerLog(logger.LevelInfo, "downgrading database version: 18 -> 17")
+return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 17, false)
+}
+func downgradeMySQLDatabaseFrom19To18(dbHandle *sql.DB) error {
+logger.InfoToConsole("downgrading database version: 19 -> 18")
+providerLog(logger.LevelInfo, "downgrading database version: 19 -> 18")
+sql := strings.ReplaceAll(mysqlV19DownSQL, "{{shared_sessions}}", sqlTableSharedSessions)
+return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 18, false)
}
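The MySQL migrations above hand strings.Split(sql, ";") to the executor, so each script runs one statement at a time. The executor itself (sqlCommonExecSQLAndUpdateDBVersion) is not shown in this diff, so the following loop is only an assumption about its likely shape, using only standard database/sql calls:

package main

import (
	"context"
	"database/sql"
	"strings"
	"time"
)

// execScript runs each non-empty statement of a ";"-separated script.
// Updating the stored schema version afterwards is omitted here.
func execScript(db *sql.DB, script string) error {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	for _, stmt := range strings.Split(script, ";") {
		if strings.TrimSpace(stmt) == "" {
			continue
		}
		if _, err := db.ExecContext(ctx, stmt); err != nil {
			return err
		}
	}
	return nil
}

func main() {}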

View File

@@ -20,7 +20,7 @@ package dataprovider
import (
"errors"
-"github.com/drakkan/sftpgo/v2/internal/version"
+"github.com/drakkan/sftpgo/v2/version"
)
func init() {

View File

@@ -26,12 +26,12 @@ import (
"strings" "strings"
"time" "time"
// we import pgx here to be able to disable PostgreSQL support using a build tag // we import lib/pq here to be able to disable PostgreSQL support using a build tag
_ "github.com/jackc/pgx/v5/stdlib" _ "github.com/lib/pq"
"github.com/drakkan/sftpgo/v2/internal/logger" "github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/internal/version" "github.com/drakkan/sftpgo/v2/version"
"github.com/drakkan/sftpgo/v2/internal/vfs" "github.com/drakkan/sftpgo/v2/vfs"
) )
const ( const (
@@ -39,7 +39,6 @@ const (
DROP TABLE IF EXISTS "{{folders_mapping}}" CASCADE; DROP TABLE IF EXISTS "{{folders_mapping}}" CASCADE;
DROP TABLE IF EXISTS "{{users_folders_mapping}}" CASCADE; DROP TABLE IF EXISTS "{{users_folders_mapping}}" CASCADE;
DROP TABLE IF EXISTS "{{users_groups_mapping}}" CASCADE; DROP TABLE IF EXISTS "{{users_groups_mapping}}" CASCADE;
DROP TABLE IF EXISTS "{{admins_groups_mapping}}" CASCADE;
DROP TABLE IF EXISTS "{{groups_folders_mapping}}" CASCADE; DROP TABLE IF EXISTS "{{groups_folders_mapping}}" CASCADE;
DROP TABLE IF EXISTS "{{admins}}" CASCADE; DROP TABLE IF EXISTS "{{admins}}" CASCADE;
DROP TABLE IF EXISTS "{{folders}}" CASCADE; DROP TABLE IF EXISTS "{{folders}}" CASCADE;
@@ -50,11 +49,6 @@ DROP TABLE IF EXISTS "{{defender_events}}" CASCADE;
DROP TABLE IF EXISTS "{{defender_hosts}}" CASCADE; DROP TABLE IF EXISTS "{{defender_hosts}}" CASCADE;
DROP TABLE IF EXISTS "{{active_transfers}}" CASCADE; DROP TABLE IF EXISTS "{{active_transfers}}" CASCADE;
DROP TABLE IF EXISTS "{{shared_sessions}}" CASCADE; DROP TABLE IF EXISTS "{{shared_sessions}}" CASCADE;
DROP TABLE IF EXISTS "{{rules_actions_mapping}}" CASCADE;
DROP TABLE IF EXISTS "{{events_actions}}" CASCADE;
DROP TABLE IF EXISTS "{{events_rules}}" CASCADE;
DROP TABLE IF EXISTS "{{tasks}}" CASCADE;
DROP TABLE IF EXISTS "{{nodes}}" CASCADE;
DROP TABLE IF EXISTS "{{schema_version}}" CASCADE; DROP TABLE IF EXISTS "{{schema_version}}" CASCADE;
` `
pgsqlInitial = `CREATE TABLE "{{schema_version}}" ("id" serial NOT NULL PRIMARY KEY, "version" integer NOT NULL); pgsqlInitial = `CREATE TABLE "{{schema_version}}" ("id" serial NOT NULL PRIMARY KEY, "version" integer NOT NULL);
@@ -62,11 +56,6 @@ CREATE TABLE "{{admins}}" ("id" serial NOT NULL PRIMARY KEY, "username" varchar(
"description" varchar(512) NULL, "password" varchar(255) NOT NULL, "email" varchar(255) NULL, "status" integer NOT NULL, "description" varchar(512) NULL, "password" varchar(255) NOT NULL, "email" varchar(255) NULL, "status" integer NOT NULL,
"permissions" text NOT NULL, "filters" text NULL, "additional_info" text NULL, "last_login" bigint NOT NULL, "permissions" text NOT NULL, "filters" text NULL, "additional_info" text NULL, "last_login" bigint NOT NULL,
"created_at" bigint NOT NULL, "updated_at" bigint NOT NULL); "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL);
CREATE TABLE "{{active_transfers}}" ("id" bigserial NOT NULL PRIMARY KEY, "connection_id" varchar(100) NOT NULL,
"transfer_id" bigint NOT NULL, "transfer_type" integer NOT NULL, "username" varchar(255) NOT NULL,
"folder_name" varchar(255) NULL, "ip" varchar(50) NOT NULL, "truncated_size" bigint NOT NULL,
"current_ul_size" bigint NOT NULL, "current_dl_size" bigint NOT NULL, "created_at" bigint NOT NULL,
"updated_at" bigint NOT NULL);
CREATE TABLE "{{defender_hosts}}" ("id" bigserial NOT NULL PRIMARY KEY, "ip" varchar(50) NOT NULL UNIQUE, CREATE TABLE "{{defender_hosts}}" ("id" bigserial NOT NULL PRIMARY KEY, "ip" varchar(50) NOT NULL UNIQUE,
"ban_time" bigint NOT NULL, "updated_at" bigint NOT NULL); "ban_time" bigint NOT NULL, "updated_at" bigint NOT NULL);
CREATE TABLE "{{defender_events}}" ("id" bigserial NOT NULL PRIMARY KEY, "date_time" bigint NOT NULL, "score" integer NOT NULL, CREATE TABLE "{{defender_events}}" ("id" bigserial NOT NULL PRIMARY KEY, "date_time" bigint NOT NULL, "score" integer NOT NULL,
@@ -76,29 +65,19 @@ ALTER TABLE "{{defender_events}}" ADD CONSTRAINT "{{prefix}}defender_events_host
CREATE TABLE "{{folders}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, CREATE TABLE "{{folders}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL,
"path" text NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL, "path" text NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL,
"filesystem" text NULL); "filesystem" text NULL);
CREATE TABLE "{{groups}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE,
"description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "user_settings" text NULL);
CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL PRIMARY KEY,
"data" text NOT NULL, "type" integer NOT NULL, "timestamp" bigint NOT NULL);
CREATE TABLE "{{users}}" ("id" serial NOT NULL PRIMARY KEY, "username" varchar(255) NOT NULL UNIQUE, "status" integer NOT NULL, CREATE TABLE "{{users}}" ("id" serial NOT NULL PRIMARY KEY, "username" varchar(255) NOT NULL UNIQUE, "status" integer NOT NULL,
"expiration_date" bigint NOT NULL, "description" varchar(512) NULL, "password" text NULL, "public_keys" text NULL, "expiration_date" bigint NOT NULL, "description" varchar(512) NULL, "password" text NULL, "public_keys" text NULL,
"home_dir" text NOT NULL, "uid" bigint NOT NULL, "gid" bigint NOT NULL, "max_sessions" integer NOT NULL, "home_dir" text NOT NULL, "uid" bigint NOT NULL, "gid" bigint NOT NULL, "max_sessions" integer NOT NULL,
"quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "permissions" text NOT NULL, "used_quota_size" bigint NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "permissions" text NOT NULL, "used_quota_size" bigint NOT NULL,
"used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL, "upload_bandwidth" integer NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL, "upload_bandwidth" integer NOT NULL,
"download_bandwidth" integer NOT NULL, "last_login" bigint NOT NULL, "filters" text NULL, "filesystem" text NULL, "download_bandwidth" integer NOT NULL, "last_login" bigint NOT NULL, "filters" text NULL, "filesystem" text NULL,
"additional_info" text NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "email" varchar(255) NULL, "additional_info" text NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "email" varchar(255) NULL);
"upload_data_transfer" integer NOT NULL, "download_data_transfer" integer NOT NULL, "total_data_transfer" integer NOT NULL, CREATE TABLE "{{folders_mapping}}" ("id" serial NOT NULL PRIMARY KEY, "virtual_path" text NOT NULL,
"used_upload_data_transfer" integer NOT NULL, "used_download_data_transfer" integer NOT NULL);
CREATE TABLE "{{groups_folders_mapping}}" ("id" serial NOT NULL PRIMARY KEY, "group_id" integer NOT NULL,
"folder_id" integer NOT NULL, "virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL);
CREATE TABLE "{{users_groups_mapping}}" ("id" serial NOT NULL PRIMARY KEY, "user_id" integer NOT NULL,
"group_id" integer NOT NULL, "group_type" integer NOT NULL);
CREATE TABLE "{{users_folders_mapping}}" ("id" serial NOT NULL PRIMARY KEY, "virtual_path" text NOT NULL,
"quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "folder_id" integer NOT NULL, "user_id" integer NOT NULL); "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "folder_id" integer NOT NULL, "user_id" integer NOT NULL);
ALTER TABLE "{{users_folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_user_folder_mapping" UNIQUE ("user_id", "folder_id"); ALTER TABLE "{{folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_mapping" UNIQUE ("user_id", "folder_id");
ALTER TABLE "{{users_folders_mapping}}" ADD CONSTRAINT "{{prefix}}users_folders_mapping_folder_id_fk_folders_id" ALTER TABLE "{{folders_mapping}}" ADD CONSTRAINT "{{prefix}}folders_mapping_folder_id_fk_folders_id"
FOREIGN KEY ("folder_id") REFERENCES "{{folders}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; FOREIGN KEY ("folder_id") REFERENCES "{{folders}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
ALTER TABLE "{{users_folders_mapping}}" ADD CONSTRAINT "{{prefix}}users_folders_mapping_user_id_fk_users_id" ALTER TABLE "{{folders_mapping}}" ADD CONSTRAINT "{{prefix}}folders_mapping_user_id_fk_users_id"
FOREIGN KEY ("user_id") REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; FOREIGN KEY ("user_id") REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
CREATE TABLE "{{shares}}" ("id" serial NOT NULL PRIMARY KEY, CREATE TABLE "{{shares}}" ("id" serial NOT NULL PRIMARY KEY,
"share_id" varchar(60) NOT NULL UNIQUE, "name" varchar(255) NOT NULL, "description" varchar(512) NULL, "share_id" varchar(60) NOT NULL UNIQUE, "name" varchar(255) NOT NULL, "description" varchar(512) NULL,
@@ -116,6 +95,57 @@ ALTER TABLE "{{api_keys}}" ADD CONSTRAINT "{{prefix}}api_keys_admin_id_fk_admins
REFERENCES "{{admins}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; REFERENCES "{{admins}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
ALTER TABLE "{{api_keys}}" ADD CONSTRAINT "{{prefix}}api_keys_user_id_fk_users_id" FOREIGN KEY ("user_id") ALTER TABLE "{{api_keys}}" ADD CONSTRAINT "{{prefix}}api_keys_user_id_fk_users_id" FOREIGN KEY ("user_id")
REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; REFERENCES "{{users}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
CREATE INDEX "{{prefix}}folders_mapping_folder_id_idx" ON "{{folders_mapping}}" ("folder_id");
CREATE INDEX "{{prefix}}folders_mapping_user_id_idx" ON "{{folders_mapping}}" ("user_id");
CREATE INDEX "{{prefix}}api_keys_admin_id_idx" ON "{{api_keys}}" ("admin_id");
CREATE INDEX "{{prefix}}api_keys_user_id_idx" ON "{{api_keys}}" ("user_id");
CREATE INDEX "{{prefix}}users_updated_at_idx" ON "{{users}}" ("updated_at");
CREATE INDEX "{{prefix}}shares_user_id_idx" ON "{{shares}}" ("user_id");
CREATE INDEX "{{prefix}}defender_hosts_updated_at_idx" ON "{{defender_hosts}}" ("updated_at");
CREATE INDEX "{{prefix}}defender_hosts_ban_time_idx" ON "{{defender_hosts}}" ("ban_time");
CREATE INDEX "{{prefix}}defender_events_date_time_idx" ON "{{defender_events}}" ("date_time");
CREATE INDEX "{{prefix}}defender_events_host_id_idx" ON "{{defender_events}}" ("host_id");
INSERT INTO {{schema_version}} (version) VALUES (15);
`
pgsqlV16SQL = `ALTER TABLE "{{users}}" ADD COLUMN "download_data_transfer" integer DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "download_data_transfer" DROP DEFAULT;
ALTER TABLE "{{users}}" ADD COLUMN "total_data_transfer" integer DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "total_data_transfer" DROP DEFAULT;
ALTER TABLE "{{users}}" ADD COLUMN "upload_data_transfer" integer DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "upload_data_transfer" DROP DEFAULT;
ALTER TABLE "{{users}}" ADD COLUMN "used_download_data_transfer" integer DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "used_download_data_transfer" DROP DEFAULT;
ALTER TABLE "{{users}}" ADD COLUMN "used_upload_data_transfer" integer DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "used_upload_data_transfer" DROP DEFAULT;
CREATE TABLE "{{active_transfers}}" ("id" bigserial NOT NULL PRIMARY KEY, "connection_id" varchar(100) NOT NULL,
"transfer_id" bigint NOT NULL, "transfer_type" integer NOT NULL, "username" varchar(255) NOT NULL,
"folder_name" varchar(255) NULL, "ip" varchar(50) NOT NULL, "truncated_size" bigint NOT NULL,
"current_ul_size" bigint NOT NULL, "current_dl_size" bigint NOT NULL, "created_at" bigint NOT NULL,
"updated_at" bigint NOT NULL);
CREATE INDEX "{{prefix}}active_transfers_connection_id_idx" ON "{{active_transfers}}" ("connection_id");
CREATE INDEX "{{prefix}}active_transfers_transfer_id_idx" ON "{{active_transfers}}" ("transfer_id");
CREATE INDEX "{{prefix}}active_transfers_updated_at_idx" ON "{{active_transfers}}" ("updated_at");
`
pgsqlV16DownSQL = `ALTER TABLE "{{users}}" DROP COLUMN "used_upload_data_transfer" CASCADE;
ALTER TABLE "{{users}}" DROP COLUMN "used_download_data_transfer" CASCADE;
ALTER TABLE "{{users}}" DROP COLUMN "upload_data_transfer" CASCADE;
ALTER TABLE "{{users}}" DROP COLUMN "total_data_transfer" CASCADE;
ALTER TABLE "{{users}}" DROP COLUMN "download_data_transfer" CASCADE;
DROP TABLE "{{active_transfers}}" CASCADE;
`
pgsqlV17SQL = `CREATE TABLE "{{groups}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE,
"description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "user_settings" text NULL);
CREATE TABLE "{{groups_folders_mapping}}" ("id" serial NOT NULL PRIMARY KEY, "group_id" integer NOT NULL,
"folder_id" integer NOT NULL, "virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL);
CREATE TABLE "{{users_groups_mapping}}" ("id" serial NOT NULL PRIMARY KEY, "user_id" integer NOT NULL,
"group_id" integer NOT NULL, "group_type" integer NOT NULL);
DROP INDEX "{{prefix}}folders_mapping_folder_id_idx";
DROP INDEX "{{prefix}}folders_mapping_user_id_idx";
ALTER TABLE "{{folders_mapping}}" DROP CONSTRAINT "{{prefix}}unique_mapping";
ALTER TABLE "{{folders_mapping}}" RENAME TO "{{users_folders_mapping}}";
ALTER TABLE "{{users_folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_user_folder_mapping" UNIQUE ("user_id", "folder_id");
CREATE INDEX "{{prefix}}users_folders_mapping_folder_id_idx" ON "{{users_folders_mapping}}" ("folder_id");
CREATE INDEX "{{prefix}}users_folders_mapping_user_id_idx" ON "{{users_folders_mapping}}" ("user_id");
ALTER TABLE "{{users_groups_mapping}}" ADD CONSTRAINT "{{prefix}}unique_user_group_mapping" UNIQUE ("user_id", "group_id"); ALTER TABLE "{{users_groups_mapping}}" ADD CONSTRAINT "{{prefix}}unique_user_group_mapping" UNIQUE ("user_id", "group_id");
ALTER TABLE "{{groups_folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_group_folder_mapping" UNIQUE ("group_id", "folder_id"); ALTER TABLE "{{groups_folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_group_folder_mapping" UNIQUE ("group_id", "folder_id");
CREATE INDEX "{{prefix}}users_groups_mapping_group_id_idx" ON "{{users_groups_mapping}}" ("group_id"); CREATE INDEX "{{prefix}}users_groups_mapping_group_id_idx" ON "{{users_groups_mapping}}" ("group_id");
@@ -131,77 +161,23 @@ CREATE INDEX "{{prefix}}groups_folders_mapping_group_id_idx" ON "{{groups_folder
ALTER TABLE "{{groups_folders_mapping}}" ADD CONSTRAINT "{{prefix}}groups_folders_mapping_group_id_fk_groups_id" ALTER TABLE "{{groups_folders_mapping}}" ADD CONSTRAINT "{{prefix}}groups_folders_mapping_group_id_fk_groups_id"
FOREIGN KEY ("group_id") REFERENCES "{{groups}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE; FOREIGN KEY ("group_id") REFERENCES "{{groups}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
CREATE INDEX "{{prefix}}groups_updated_at_idx" ON "{{groups}}" ("updated_at"); CREATE INDEX "{{prefix}}groups_updated_at_idx" ON "{{groups}}" ("updated_at");
CREATE INDEX "{{prefix}}users_folders_mapping_folder_id_idx" ON "{{users_folders_mapping}}" ("folder_id"); `
CREATE INDEX "{{prefix}}users_folders_mapping_user_id_idx" ON "{{users_folders_mapping}}" ("user_id"); pgsqlV17DownSQL = `DROP TABLE "{{users_groups_mapping}}" CASCADE;
CREATE INDEX "{{prefix}}api_keys_admin_id_idx" ON "{{api_keys}}" ("admin_id"); DROP TABLE "{{groups_folders_mapping}}" CASCADE;
CREATE INDEX "{{prefix}}api_keys_user_id_idx" ON "{{api_keys}}" ("user_id"); DROP TABLE "{{groups}}" CASCADE;
CREATE INDEX "{{prefix}}users_updated_at_idx" ON "{{users}}" ("updated_at"); DROP INDEX "{{prefix}}users_folders_mapping_folder_id_idx";
CREATE INDEX "{{prefix}}shares_user_id_idx" ON "{{shares}}" ("user_id"); DROP INDEX "{{prefix}}users_folders_mapping_user_id_idx";
CREATE INDEX "{{prefix}}defender_hosts_updated_at_idx" ON "{{defender_hosts}}" ("updated_at"); ALTER TABLE "{{users_folders_mapping}}" DROP CONSTRAINT "{{prefix}}unique_user_folder_mapping";
CREATE INDEX "{{prefix}}defender_hosts_ban_time_idx" ON "{{defender_hosts}}" ("ban_time"); ALTER TABLE "{{users_folders_mapping}}" RENAME TO "{{folders_mapping}}";
CREATE INDEX "{{prefix}}defender_events_date_time_idx" ON "{{defender_events}}" ("date_time"); ALTER TABLE "{{folders_mapping}}" ADD CONSTRAINT "{{prefix}}unique_mapping" UNIQUE ("user_id", "folder_id");
CREATE INDEX "{{prefix}}defender_events_host_id_idx" ON "{{defender_events}}" ("host_id"); CREATE INDEX "{{prefix}}folders_mapping_folder_id_idx" ON "{{folders_mapping}}" ("folder_id");
CREATE INDEX "{{prefix}}active_transfers_connection_id_idx" ON "{{active_transfers}}" ("connection_id"); CREATE INDEX "{{prefix}}folders_mapping_user_id_idx" ON "{{folders_mapping}}" ("user_id");
CREATE INDEX "{{prefix}}active_transfers_transfer_id_idx" ON "{{active_transfers}}" ("transfer_id"); `
CREATE INDEX "{{prefix}}active_transfers_updated_at_idx" ON "{{active_transfers}}" ("updated_at"); pgsqlV19SQL = `CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL PRIMARY KEY,
"data" text NOT NULL, "type" integer NOT NULL, "timestamp" bigint NOT NULL);
CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type"); CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type");
CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp"); CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp");`
INSERT INTO {{schema_version}} (version) VALUES (19); pgsqlV19DownSQL = `DROP TABLE "{{shared_sessions}}" CASCADE;`
`
pgsqlV20SQL = `CREATE TABLE "{{events_rules}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE,
"description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "trigger" integer NOT NULL,
"conditions" text NOT NULL, "deleted_at" bigint NOT NULL);
CREATE TABLE "{{events_actions}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE,
"description" varchar(512) NULL, "type" integer NOT NULL, "options" text NOT NULL);
CREATE TABLE "{{rules_actions_mapping}}" ("id" serial NOT NULL PRIMARY KEY, "rule_id" integer NOT NULL,
"action_id" integer NOT NULL, "order" integer NOT NULL, "options" text NOT NULL);
CREATE TABLE "{{tasks}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE, "updated_at" bigint NOT NULL,
"version" bigint NOT NULL);
ALTER TABLE "{{rules_actions_mapping}}" ADD CONSTRAINT "{{prefix}}unique_rule_action_mapping" UNIQUE ("rule_id", "action_id");
ALTER TABLE "{{rules_actions_mapping}}" ADD CONSTRAINT "{{prefix}}rules_actions_mapping_rule_id_fk_events_rules_id"
FOREIGN KEY ("rule_id") REFERENCES "{{events_rules}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
ALTER TABLE "{{rules_actions_mapping}}" ADD CONSTRAINT "{{prefix}}rules_actions_mapping_action_id_fk_events_targets_id"
FOREIGN KEY ("action_id") REFERENCES "{{events_actions}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE NO ACTION;
ALTER TABLE "{{users}}" ADD COLUMN "deleted_at" bigint DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "deleted_at" DROP DEFAULT;
CREATE INDEX "{{prefix}}events_rules_updated_at_idx" ON "{{events_rules}}" ("updated_at");
CREATE INDEX "{{prefix}}events_rules_deleted_at_idx" ON "{{events_rules}}" ("deleted_at");
CREATE INDEX "{{prefix}}events_rules_trigger_idx" ON "{{events_rules}}" ("trigger");
CREATE INDEX "{{prefix}}rules_actions_mapping_rule_id_idx" ON "{{rules_actions_mapping}}" ("rule_id");
CREATE INDEX "{{prefix}}rules_actions_mapping_action_id_idx" ON "{{rules_actions_mapping}}" ("action_id");
CREATE INDEX "{{prefix}}rules_actions_mapping_order_idx" ON "{{rules_actions_mapping}}" ("order");
CREATE INDEX "{{prefix}}users_deleted_at_idx" ON "{{users}}" ("deleted_at");
`
pgsqlV20DownSQL = `DROP TABLE "{{rules_actions_mapping}}" CASCADE;
DROP TABLE "{{events_rules}}" CASCADE;
DROP TABLE "{{events_actions}}" CASCADE;
DROP TABLE "{{tasks}}" CASCADE;
ALTER TABLE "{{users}}" DROP COLUMN "deleted_at" CASCADE;
`
pgsqlV21SQL = `ALTER TABLE "{{users}}" ADD COLUMN "first_download" bigint DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "first_download" DROP DEFAULT;
ALTER TABLE "{{users}}" ADD COLUMN "first_upload" bigint DEFAULT 0 NOT NULL;
ALTER TABLE "{{users}}" ALTER COLUMN "first_upload" DROP DEFAULT;
`
pgsqlV21DownSQL = `ALTER TABLE "{{users}}" DROP COLUMN "first_upload" CASCADE;
ALTER TABLE "{{users}}" DROP COLUMN "first_download" CASCADE;
`
pgsqlV22SQL = `CREATE TABLE "{{admins_groups_mapping}}" ("id" serial NOT NULL PRIMARY KEY,
"admin_id" integer NOT NULL, "group_id" integer NOT NULL, "options" text NOT NULL);
ALTER TABLE "{{admins_groups_mapping}}" ADD CONSTRAINT "{{prefix}}unique_admin_group_mapping" UNIQUE ("admin_id", "group_id");
ALTER TABLE "{{admins_groups_mapping}}" ADD CONSTRAINT "{{prefix}}admins_groups_mapping_admin_id_fk_admins_id"
FOREIGN KEY ("admin_id") REFERENCES "{{admins}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
ALTER TABLE "{{admins_groups_mapping}}" ADD CONSTRAINT "{{prefix}}admins_groups_mapping_group_id_fk_groups_id"
FOREIGN KEY ("group_id") REFERENCES "{{groups}}" ("id") MATCH SIMPLE ON UPDATE NO ACTION ON DELETE CASCADE;
CREATE INDEX "{{prefix}}admins_groups_mapping_admin_id_idx" ON "{{admins_groups_mapping}}" ("admin_id");
CREATE INDEX "{{prefix}}admins_groups_mapping_group_id_idx" ON "{{admins_groups_mapping}}" ("group_id");
`
pgsqlV22DownSQL = `ALTER TABLE "{{admins_groups_mapping}}" DROP CONSTRAINT "{{prefix}}unique_admin_group_mapping";
DROP TABLE "{{admins_groups_mapping}}" CASCADE;
`
pgsqlV23SQL = `CREATE TABLE "{{nodes}}" ("id" serial NOT NULL PRIMARY KEY, "name" varchar(255) NOT NULL UNIQUE,
"data" text NOT NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL);`
pgsqlV23DownSQL = `DROP TABLE "{{nodes}}" CASCADE;`
) )
// PGSQLProvider defines the auth provider for PostgreSQL database
@@ -215,7 +191,7 @@ func init() {
func initializePGSQLProvider() error {
var err error
-dbHandle, err := sql.Open("pgx", getPGSQLConnectionString(false))
+dbHandle, err := sql.Open("postgres", getPGSQLConnectionString(false))
if err == nil {
providerLog(logger.LevelDebug, "postgres database handle created, connection string: %#v, pool size: %v",
getPGSQLConnectionString(true), config.PoolSize)
@@ -226,7 +202,6 @@ func initializePGSQLProvider() error {
dbHandle.SetMaxIdleConns(2)
}
dbHandle.SetConnMaxLifetime(240 * time.Second)
-dbHandle.SetConnMaxIdleTime(120 * time.Second)
provider = &PGSQLProvider{dbHandle: dbHandle}
} else {
providerLog(logger.LevelError, "error creating postgres database handler, connection string: %#v, error: %v",
@@ -250,12 +225,6 @@ func getPGSQLConnectionString(redactedPwd bool) string {
if config.ClientCert != "" && config.ClientKey != "" { if config.ClientCert != "" && config.ClientKey != "" {
connectionString += fmt.Sprintf(" sslcert='%v' sslkey='%v'", config.ClientCert, config.ClientKey) connectionString += fmt.Sprintf(" sslcert='%v' sslkey='%v'", config.ClientCert, config.ClientKey)
} }
if config.DisableSNI {
connectionString += " sslsni=0"
}
if config.TargetSessionAttrs != "" {
connectionString += fmt.Sprintf(" target_session_attrs='%s'", config.TargetSessionAttrs)
}
} else { } else {
connectionString = config.ConnectionString connectionString = config.ConnectionString
} }
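The connection string assembled above uses libpq's keyword/value syntax. A small sketch that builds one with placeholder values, following the same pattern of appending optional parameters:

package main

import "fmt"

func main() {
	host, port, db, user, pass, sslmode := "127.0.0.1", 5432, "sftpgo", "sftpgo", "secret", "disable"
	// Base parameters; optional ones (sslcert, sslkey, ...) are appended as needed.
	connectionString := fmt.Sprintf("host='%s' port=%d dbname='%s' user='%s' password='%s' sslmode=%s connect_timeout=10",
		host, port, db, user, pass, sslmode)
	fmt.Println(connectionString)
}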
@@ -314,8 +283,8 @@ func (p *PGSQLProvider) updateUser(user *User) error {
return sqlCommonUpdateUser(user, p.dbHandle)
}
-func (p *PGSQLProvider) deleteUser(user User, softDelete bool) error {
-return sqlCommonDeleteUser(user, softDelete, p.dbHandle)
+func (p *PGSQLProvider) deleteUser(user User) error {
+return sqlCommonDeleteUser(user, p.dbHandle)
}
func (p *PGSQLProvider) updateUserPassword(username, password string) error {
@@ -556,102 +525,6 @@ func (p *PGSQLProvider) cleanupSharedSessions(sessionType SessionType, before in
return sqlCommonCleanupSessions(sessionType, before, p.dbHandle)
}
-func (p *PGSQLProvider) getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) {
-return sqlCommonGetEventActions(limit, offset, order, minimal, p.dbHandle)
-}
-func (p *PGSQLProvider) dumpEventActions() ([]BaseEventAction, error) {
-return sqlCommonDumpEventActions(p.dbHandle)
-}
-func (p *PGSQLProvider) eventActionExists(name string) (BaseEventAction, error) {
-return sqlCommonGetEventActionByName(name, p.dbHandle)
-}
-func (p *PGSQLProvider) addEventAction(action *BaseEventAction) error {
-return sqlCommonAddEventAction(action, p.dbHandle)
-}
-func (p *PGSQLProvider) updateEventAction(action *BaseEventAction) error {
-return sqlCommonUpdateEventAction(action, p.dbHandle)
-}
-func (p *PGSQLProvider) deleteEventAction(action BaseEventAction) error {
-return sqlCommonDeleteEventAction(action, p.dbHandle)
-}
-func (p *PGSQLProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) {
-return sqlCommonGetEventRules(limit, offset, order, p.dbHandle)
-}
-func (p *PGSQLProvider) dumpEventRules() ([]EventRule, error) {
-return sqlCommonDumpEventRules(p.dbHandle)
-}
-func (p *PGSQLProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) {
-return sqlCommonGetRecentlyUpdatedRules(after, p.dbHandle)
-}
-func (p *PGSQLProvider) eventRuleExists(name string) (EventRule, error) {
-return sqlCommonGetEventRuleByName(name, p.dbHandle)
-}
-func (p *PGSQLProvider) addEventRule(rule *EventRule) error {
-return sqlCommonAddEventRule(rule, p.dbHandle)
-}
-func (p *PGSQLProvider) updateEventRule(rule *EventRule) error {
-return sqlCommonUpdateEventRule(rule, p.dbHandle)
-}
-func (p *PGSQLProvider) deleteEventRule(rule EventRule, softDelete bool) error {
-return sqlCommonDeleteEventRule(rule, softDelete, p.dbHandle)
-}
-func (p *PGSQLProvider) getTaskByName(name string) (Task, error) {
-return sqlCommonGetTaskByName(name, p.dbHandle)
-}
-func (p *PGSQLProvider) addTask(name string) error {
-return sqlCommonAddTask(name, p.dbHandle)
-}
-func (p *PGSQLProvider) updateTask(name string, version int64) error {
-return sqlCommonUpdateTask(name, version, p.dbHandle)
-}
-func (p *PGSQLProvider) updateTaskTimestamp(name string) error {
-return sqlCommonUpdateTaskTimestamp(name, p.dbHandle)
-}
-func (p *PGSQLProvider) addNode() error {
-return sqlCommonAddNode(p.dbHandle)
-}
-func (p *PGSQLProvider) getNodeByName(name string) (Node, error) {
-return sqlCommonGetNodeByName(name, p.dbHandle)
-}
-func (p *PGSQLProvider) getNodes() ([]Node, error) {
-return sqlCommonGetNodes(p.dbHandle)
-}
-func (p *PGSQLProvider) updateNodeTimestamp() error {
-return sqlCommonUpdateNodeTimestamp(p.dbHandle)
-}
-func (p *PGSQLProvider) cleanupNodes() error {
-return sqlCommonCleanupNodes(p.dbHandle)
-}
-func (p *PGSQLProvider) setFirstDownloadTimestamp(username string) error {
-return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle)
-}
-func (p *PGSQLProvider) setFirstUploadTimestamp(username string) error {
-return sqlCommonSetFirstUploadTimestamp(username, p.dbHandle)
-}
func (p *PGSQLProvider) close() error {
return p.dbHandle.Close()
}
@@ -669,11 +542,26 @@ func (p *PGSQLProvider) initializeDatabase() error {
if errors.Is(err, sql.ErrNoRows) {
return errSchemaVersionEmpty
}
logger.InfoToConsole("creating initial database schema, version 19") logger.InfoToConsole("creating initial database schema, version 15")
providerLog(logger.LevelInfo, "creating initial database schema, version 19") providerLog(logger.LevelInfo, "creating initial database schema, version 15")
initialSQL := sqlReplaceAll(pgsqlInitial) initialSQL := strings.ReplaceAll(pgsqlInitial, "{{schema_version}}", sqlTableSchemaVersion)
initialSQL = strings.ReplaceAll(initialSQL, "{{admins}}", sqlTableAdmins)
initialSQL = strings.ReplaceAll(initialSQL, "{{folders}}", sqlTableFolders)
initialSQL = strings.ReplaceAll(initialSQL, "{{users}}", sqlTableUsers)
initialSQL = strings.ReplaceAll(initialSQL, "{{folders_mapping}}", sqlTableFoldersMapping)
initialSQL = strings.ReplaceAll(initialSQL, "{{api_keys}}", sqlTableAPIKeys)
initialSQL = strings.ReplaceAll(initialSQL, "{{shares}}", sqlTableShares)
initialSQL = strings.ReplaceAll(initialSQL, "{{defender_events}}", sqlTableDefenderEvents)
initialSQL = strings.ReplaceAll(initialSQL, "{{defender_hosts}}", sqlTableDefenderHosts)
initialSQL = strings.ReplaceAll(initialSQL, "{{prefix}}", config.SQLTablesPrefix)
if config.Driver == CockroachDataProviderName {
// Cockroach does not support deferrable constraint validation, we don't need them,
// we keep these definitions for the PostgreSQL driver to avoid changes for users
// upgrading from old SFTPGo versions
initialSQL = strings.ReplaceAll(initialSQL, "DEFERRABLE INITIALLY DEFERRED", "")
}
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 19, true) return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 15, true)
} }
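The CockroachDB special case above is a plain textual rewrite of the schema before execution, not a driver feature. As a standalone sketch of that conditional rewrite:

package main

import (
	"fmt"
	"strings"
)

func main() {
	sql := `ALTER TABLE "t" ADD CONSTRAINT "fk" FOREIGN KEY ("x") REFERENCES "u" ("id") DEFERRABLE INITIALLY DEFERRED;`
	// CockroachDB does not support deferrable constraint validation,
	// so the clause is stripped before the statement is executed there.
	sql = strings.ReplaceAll(sql, "DEFERRABLE INITIALLY DEFERRED", "")
	fmt.Println(sql)
}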
func (p *PGSQLProvider) migrateDatabase() error { //nolint:dupl
@@ -686,28 +574,28 @@ func (p *PGSQLProvider) migrateDatabase() error { //nolint:dupl
case version == sqlDatabaseVersion:
providerLog(logger.LevelDebug, "sql database is up to date, current version: %v", version)
return ErrNoInitRequired
-case version < 19:
-err = fmt.Errorf("database schema version %v is too old, please see the upgrading docs", version)
+case version < 15:
+err = fmt.Errorf("database version %v is too old, please see the upgrading docs", version)
providerLog(logger.LevelError, "%v", err)
logger.ErrorToConsole("%v", err)
return err
-case version == 19:
-return updatePgSQLDatabaseFromV19(p.dbHandle)
-case version == 20:
-return updatePgSQLDatabaseFromV20(p.dbHandle)
-case version == 21:
-return updatePgSQLDatabaseFromV21(p.dbHandle)
-case version == 22:
-return updatePgSQLDatabaseFromV22(p.dbHandle)
+case version == 15:
+return updatePGSQLDatabaseFromV15(p.dbHandle)
+case version == 16:
+return updatePGSQLDatabaseFromV16(p.dbHandle)
+case version == 17:
+return updatePGSQLDatabaseFromV17(p.dbHandle)
+case version == 18:
+return updatePGSQLDatabaseFromV18(p.dbHandle)
default:
if version > sqlDatabaseVersion {
-providerLog(logger.LevelError, "database schema version %v is newer than the supported one: %v", version,
+providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version,
sqlDatabaseVersion)
-logger.WarnToConsole("database schema version %v is newer than the supported one: %v", version,
+logger.WarnToConsole("database version %v is newer than the supported one: %v", version,
sqlDatabaseVersion)
return nil
}
-return fmt.Errorf("database schema version not handled: %v", version)
+return fmt.Errorf("database version not handled: %v", version)
}
}
@@ -721,16 +609,16 @@ func (p *PGSQLProvider) revertDatabase(targetVersion int) error {
}
switch dbVersion.Version {
-case 20:
-return downgradePgSQLDatabaseFromV20(p.dbHandle)
-case 21:
-return downgradePgSQLDatabaseFromV21(p.dbHandle)
-case 22:
-return downgradePgSQLDatabaseFromV22(p.dbHandle)
-case 23:
-return downgradePgSQLDatabaseFromV23(p.dbHandle)
+case 16:
+return downgradePGSQLDatabaseFromV16(p.dbHandle)
+case 17:
+return downgradePGSQLDatabaseFromV17(p.dbHandle)
+case 18:
+return downgradePGSQLDatabaseFromV18(p.dbHandle)
+case 19:
+return downgradePGSQLDatabaseFromV19(p.dbHandle)
default:
-return fmt.Errorf("database schema version not handled: %v", dbVersion.Version)
+return fmt.Errorf("database version not handled: %v", dbVersion.Version)
}
}
@@ -739,135 +627,153 @@ func (p *PGSQLProvider) resetDatabase() error {
return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0, false)
}
-func updatePgSQLDatabaseFromV19(dbHandle *sql.DB) error {
-if err := updatePgSQLDatabaseFrom19To20(dbHandle); err != nil {
+func updatePGSQLDatabaseFromV15(dbHandle *sql.DB) error {
+if err := updatePGSQLDatabaseFrom15To16(dbHandle); err != nil {
return err
}
-return updatePgSQLDatabaseFromV20(dbHandle)
+return updatePGSQLDatabaseFromV16(dbHandle)
}
-func updatePgSQLDatabaseFromV20(dbHandle *sql.DB) error {
-if err := updatePgSQLDatabaseFrom20To21(dbHandle); err != nil {
+func updatePGSQLDatabaseFromV16(dbHandle *sql.DB) error {
+if err := updatePGSQLDatabaseFrom16To17(dbHandle); err != nil {
return err
}
-return updatePgSQLDatabaseFromV21(dbHandle)
+return updatePGSQLDatabaseFromV17(dbHandle)
}
-func updatePgSQLDatabaseFromV21(dbHandle *sql.DB) error {
-if err := updatePgSQLDatabaseFrom21To22(dbHandle); err != nil {
+func updatePGSQLDatabaseFromV17(dbHandle *sql.DB) error {
+if err := updatePGSQLDatabaseFrom17To18(dbHandle); err != nil {
return err
}
-return updatePgSQLDatabaseFromV22(dbHandle)
+return updatePGSQLDatabaseFromV18(dbHandle)
}
-func updatePgSQLDatabaseFromV22(dbHandle *sql.DB) error {
-return updatePgSQLDatabaseFrom22To23(dbHandle)
+func updatePGSQLDatabaseFromV18(dbHandle *sql.DB) error {
+return updatePGSQLDatabaseFrom18To19(dbHandle)
}
-func downgradePgSQLDatabaseFromV20(dbHandle *sql.DB) error {
-return downgradePgSQLDatabaseFrom20To19(dbHandle)
+func downgradePGSQLDatabaseFromV16(dbHandle *sql.DB) error {
+return downgradePGSQLDatabaseFrom16To15(dbHandle)
}
-func downgradePgSQLDatabaseFromV21(dbHandle *sql.DB) error {
-if err := downgradePgSQLDatabaseFrom21To20(dbHandle); err != nil {
+func downgradePGSQLDatabaseFromV17(dbHandle *sql.DB) error {
+if err := downgradePGSQLDatabaseFrom17To16(dbHandle); err != nil {
return err
}
-return downgradePgSQLDatabaseFromV20(dbHandle)
+return downgradePGSQLDatabaseFromV16(dbHandle)
}
-func downgradePgSQLDatabaseFromV22(dbHandle *sql.DB) error {
-if err := downgradePgSQLDatabaseFrom22To21(dbHandle); err != nil {
+func downgradePGSQLDatabaseFromV18(dbHandle *sql.DB) error {
+if err := downgradePGSQLDatabaseFrom18To17(dbHandle); err != nil {
return err
}
-return downgradePgSQLDatabaseFromV21(dbHandle)
+return downgradePGSQLDatabaseFromV17(dbHandle)
}
-func downgradePgSQLDatabaseFromV23(dbHandle *sql.DB) error {
-if err := downgradePgSQLDatabaseFrom23To22(dbHandle); err != nil {
+func downgradePGSQLDatabaseFromV19(dbHandle *sql.DB) error {
+if err := downgradePGSQLDatabaseFrom19To18(dbHandle); err != nil {
return err
}
-return downgradePgSQLDatabaseFromV22(dbHandle)
+return downgradePGSQLDatabaseFromV18(dbHandle)
}
-func updatePgSQLDatabaseFrom19To20(dbHandle *sql.DB) error {
-logger.InfoToConsole("updating database schema version: 19 -> 20")
-providerLog(logger.LevelInfo, "updating database schema version: 19 -> 20")
-sql := pgsqlV20SQL
-if config.Driver == CockroachDataProviderName {
-sql = strings.ReplaceAll(sql, `ALTER TABLE "{{users}}" ALTER COLUMN "deleted_at" DROP DEFAULT;`, "")
-}
-sql = strings.ReplaceAll(sql, "{{events_actions}}", sqlTableEventsActions)
-sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
-sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
-sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
-sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
-sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
-return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 20, true)
-}
-func updatePgSQLDatabaseFrom20To21(dbHandle *sql.DB) error {
-logger.InfoToConsole("updating database schema version: 20 -> 21")
-providerLog(logger.LevelInfo, "updating database schema version: 20 -> 21")
-sql := pgsqlV21SQL
-if config.Driver == CockroachDataProviderName {
-sql = strings.ReplaceAll(sql, `ALTER TABLE "{{users}}" ALTER COLUMN "first_download" DROP DEFAULT;`, "")
-sql = strings.ReplaceAll(sql, `ALTER TABLE "{{users}}" ALTER COLUMN "first_upload" DROP DEFAULT;`, "")
-}
-sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
-return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 21, true)
+func updatePGSQLDatabaseFrom15To16(dbHandle *sql.DB) error {
+logger.InfoToConsole("updating database version: 15 -> 16")
+providerLog(logger.LevelInfo, "updating database version: 15 -> 16")
+sql := strings.ReplaceAll(pgsqlV16SQL, "{{users}}", sqlTableUsers)
+sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
+sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
+if config.Driver == CockroachDataProviderName {
+// Cockroach does not allow to run this schema migration within a transaction
+ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
+defer cancel()
+for _, q := range strings.Split(sql, ";") {
+if strings.TrimSpace(q) == "" {
+continue
+}
+_, err := dbHandle.ExecContext(ctx, q)
+if err != nil {
+return err
+}
+}
+return sqlCommonUpdateDatabaseVersion(ctx, dbHandle, 16)
+}
+return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16, true)
}
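For CockroachDB the 15 -> 16 step above runs each statement separately instead of inside a transaction, then records the new schema version. longSQLQueryTimeout and sqlCommonUpdateDatabaseVersion are SFTPGo internals not shown in this diff, so the sketch below substitutes stand-ins for both:

package main

import (
	"context"
	"database/sql"
	"strings"
	"time"
)

// execOutsideTx runs the script statement by statement (no wrapping transaction),
// then stores the new version; the UPDATE is a stand-in for the real helper.
func execOutsideTx(db *sql.DB, script string, newVersion int) error {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) // stand-in for longSQLQueryTimeout
	defer cancel()
	for _, q := range strings.Split(script, ";") {
		if strings.TrimSpace(q) == "" {
			continue
		}
		if _, err := db.ExecContext(ctx, q); err != nil {
			return err
		}
	}
	// stand-in for sqlCommonUpdateDatabaseVersion(ctx, db, newVersion)
	_, err := db.ExecContext(ctx, "UPDATE schema_version SET version = $1", newVersion)
	return err
}

func main() {}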
-func updatePgSQLDatabaseFrom21To22(dbHandle *sql.DB) error {
-logger.InfoToConsole("updating database schema version: 21 -> 22")
-providerLog(logger.LevelInfo, "updating database schema version: 21 -> 22")
-sql := strings.ReplaceAll(pgsqlV22SQL, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping)
-sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins)
-sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
-sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
-return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 22, true)
-}
-func updatePgSQLDatabaseFrom22To23(dbHandle *sql.DB) error {
-logger.InfoToConsole("updating database schema version: 22 -> 23")
-providerLog(logger.LevelInfo, "updating database schema version: 22 -> 23")
-sql := strings.ReplaceAll(pgsqlV23SQL, "{{nodes}}", sqlTableNodes)
-return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 23, true)
-}
-func downgradePgSQLDatabaseFrom20To19(dbHandle *sql.DB) error {
-logger.InfoToConsole("downgrading database schema version: 20 -> 19")
-providerLog(logger.LevelInfo, "downgrading database schema version: 20 -> 19")
-sql := strings.ReplaceAll(pgsqlV20DownSQL, "{{events_actions}}", sqlTableEventsActions)
-sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
-sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
-sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
-sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
-return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 19, false)
-}
-func downgradePgSQLDatabaseFrom21To20(dbHandle *sql.DB) error {
-logger.InfoToConsole("downgrading database schema version: 21 -> 20")
-providerLog(logger.LevelInfo, "downgrading database schema version: 21 -> 20")
-sql := strings.ReplaceAll(pgsqlV21DownSQL, "{{users}}", sqlTableUsers)
-return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 20, false)
-}
-func downgradePgSQLDatabaseFrom22To21(dbHandle *sql.DB) error {
-logger.InfoToConsole("downgrading database schema version: 22 -> 21")
-providerLog(logger.LevelInfo, "downgrading database schema version: 22 -> 21")
-sql := pgsqlV22DownSQL
-if config.Driver == CockroachDataProviderName {
-sql = strings.ReplaceAll(sql, `ALTER TABLE "{{admins_groups_mapping}}" DROP CONSTRAINT "{{prefix}}unique_admin_group_mapping";`,
-`DROP INDEX "{{prefix}}unique_admin_group_mapping" CASCADE;`)
-}
-sql = strings.ReplaceAll(sql, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping)
-sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
-return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 21, false)
-}
-func downgradePgSQLDatabaseFrom23To22(dbHandle *sql.DB) error {
+func updatePGSQLDatabaseFrom16To17(dbHandle *sql.DB) error {
+logger.InfoToConsole("updating database version: 16 -> 17")
+providerLog(logger.LevelInfo, "updating database version: 16 -> 17")
+sql := pgsqlV17SQL
+if config.Driver == CockroachDataProviderName {
+sql = strings.ReplaceAll(sql, `ALTER TABLE "{{folders_mapping}}" DROP CONSTRAINT "{{prefix}}unique_mapping";`,
+`DROP INDEX "{{prefix}}unique_mapping" CASCADE;`)
+}
+sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
+sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
+sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
+sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
+sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
+sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
+sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
+sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
+return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 17, true)
+}
+func updatePGSQLDatabaseFrom17To18(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database schema version: 23 -> 22") logger.InfoToConsole("updating database version: 17 -> 18")
providerLog(logger.LevelInfo, "downgrading database schema version: 23 -> 22") providerLog(logger.LevelInfo, "updating database version: 17 -> 18")
sql := strings.ReplaceAll(pgsqlV23DownSQL, "{{nodes}}", sqlTableNodes) if err := importGCSCredentials(); err != nil {
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 22, false) return err
}
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 18, true)
}
func updatePGSQLDatabaseFrom18To19(dbHandle *sql.DB) error {
logger.InfoToConsole("updating database version: 18 -> 19")
providerLog(logger.LevelInfo, "updating database version: 18 -> 19")
sql := strings.ReplaceAll(pgsqlV19SQL, "{{shared_sessions}}", sqlTableSharedSessions)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 19, true)
}
func downgradePGSQLDatabaseFrom16To15(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 16 -> 15")
providerLog(logger.LevelInfo, "downgrading database version: 16 -> 15")
sql := strings.ReplaceAll(pgsqlV16DownSQL, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 15, false)
}
func downgradePGSQLDatabaseFrom17To16(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 17 -> 16")
providerLog(logger.LevelInfo, "downgrading database version: 17 -> 16")
sql := pgsqlV17DownSQL
if config.Driver == CockroachDataProviderName {
sql = strings.ReplaceAll(sql, `ALTER TABLE "{{users_folders_mapping}}" DROP CONSTRAINT "{{prefix}}unique_user_folder_mapping";`,
`DROP INDEX "{{prefix}}unique_user_folder_mapping" CASCADE;`)
}
sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16, false)
}
func downgradePGSQLDatabaseFrom18To17(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 18 -> 17")
providerLog(logger.LevelInfo, "downgrading database version: 18 -> 17")
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 17, false)
}
func downgradePGSQLDatabaseFrom19To18(dbHandle *sql.DB) error {
logger.InfoToConsole("downgrading database version: 19 -> 18")
providerLog(logger.LevelInfo, "downgrading database version: 19 -> 18")
sql := strings.ReplaceAll(pgsqlV19DownSQL, "{{shared_sessions}}", sqlTableSharedSessions)
return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 18, false)
} }
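Both sides of this diff use the same chained-migration layout: each updateDatabaseFromVN helper applies a single N -> N+1 step and then delegates to the helper for N+1 until the target schema version is reached. A minimal standalone sketch of that pattern (the step functions and version numbers below are illustrative, not SFTPGo's actual migrations):

package main

import "fmt"

// steps maps a schema version to the function that migrates it one step forward.
var steps = map[int]func() error{
	15: func() error { fmt.Println("apply 15 -> 16"); return nil },
	16: func() error { fmt.Println("apply 16 -> 17"); return nil },
}

const targetVersion = 17

// migrateFrom walks the chain until targetVersion is reached, stopping on the
// first error, mirroring how updatePGSQLDatabaseFromV15 delegates to FromV16.
func migrateFrom(version int) error {
	for version < targetVersion {
		step, ok := steps[version]
		if !ok {
			return fmt.Errorf("database version not handled: %d", version)
		}
		if err := step(); err != nil {
			return err
		}
		version++
	}
	return nil
}

func main() {
	if err := migrateFrom(15); err != nil {
		fmt.Println("migration failed:", err)
	}
}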


@@ -20,7 +20,7 @@ package dataprovider
 import (
 	"errors"
 
-	"github.com/drakkan/sftpgo/v2/internal/version"
+	"github.com/drakkan/sftpgo/v2/version"
 )
 
 func init() {


@@ -18,7 +18,7 @@ import (
 	"sync"
 	"time"
 
-	"github.com/drakkan/sftpgo/v2/internal/logger"
+	"github.com/drakkan/sftpgo/v2/logger"
 )
 
 var delayedQuotaUpdater quotaUpdater

dataprovider/scheduler.go (new file, 110 lines)

@@ -0,0 +1,110 @@
// Copyright (C) 2019-2022 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package dataprovider
import (
"fmt"
"sync/atomic"
"time"
"github.com/robfig/cron/v3"
"github.com/drakkan/sftpgo/v2/logger"
"github.com/drakkan/sftpgo/v2/metric"
"github.com/drakkan/sftpgo/v2/util"
)
var (
scheduler *cron.Cron
lastCachesUpdate int64
// used for bolt and memory providers, so we avoid iterating all users
// to find recently modified ones
lastUserUpdate int64
)
func stopScheduler() {
if scheduler != nil {
scheduler.Stop()
scheduler = nil
}
}
func startScheduler() error {
stopScheduler()
scheduler = cron.New()
_, err := scheduler.AddFunc("@every 30s", checkDataprovider)
if err != nil {
return fmt.Errorf("unable to schedule dataprovider availability check: %w", err)
}
if config.AutoBackup.Enabled {
spec := fmt.Sprintf("0 %v * * %v", config.AutoBackup.Hour, config.AutoBackup.DayOfWeek)
_, err = scheduler.AddFunc(spec, config.doBackup)
if err != nil {
return fmt.Errorf("unable to schedule auto backup: %w", err)
}
}
err = addScheduledCacheUpdates()
if err != nil {
return err
}
scheduler.Start()
return nil
}
func addScheduledCacheUpdates() error {
lastCachesUpdate = util.GetTimeAsMsSinceEpoch(time.Now())
_, err := scheduler.AddFunc("@every 10m", checkCacheUpdates)
if err != nil {
return fmt.Errorf("unable to schedule cache updates: %w", err)
}
return nil
}
func checkDataprovider() {
err := provider.checkAvailability()
if err != nil {
providerLog(logger.LevelError, "check availability error: %v", err)
}
metric.UpdateDataProviderAvailability(err)
}
func checkCacheUpdates() {
providerLog(logger.LevelDebug, "start caches check, update time %v", util.GetTimeFromMsecSinceEpoch(lastCachesUpdate))
checkTime := util.GetTimeAsMsSinceEpoch(time.Now())
users, err := provider.getRecentlyUpdatedUsers(lastCachesUpdate)
if err != nil {
providerLog(logger.LevelError, "unable to get recently updated users: %v", err)
return
}
for _, user := range users {
providerLog(logger.LevelDebug, "invalidate caches for user %#v", user.Username)
webDAVUsersCache.swap(&user)
cachedPasswords.Remove(user.Username)
}
lastCachesUpdate = checkTime
providerLog(logger.LevelDebug, "end caches check, new update time %v", util.GetTimeFromMsecSinceEpoch(lastCachesUpdate))
}
func setLastUserUpdate() {
atomic.StoreInt64(&lastUserUpdate, util.GetTimeAsMsSinceEpoch(time.Now()))
}
func getLastUserUpdate() int64 {
return atomic.LoadInt64(&lastUserUpdate)
}
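The scheduler above is a thin wrapper over robfig/cron: "@every 30s" is a cron v3 interval descriptor, and the auto backup spec is a regular five-field expression built from the configured hour and day of week. A minimal sketch of the same API, outside SFTPGo:

package main

import (
	"fmt"
	"time"

	"github.com/robfig/cron/v3"
)

func main() {
	c := cron.New()
	// interval descriptor, as used for the dataprovider availability check
	if _, err := c.AddFunc("@every 30s", func() { fmt.Println("tick", time.Now()) }); err != nil {
		panic(err)
	}
	// five-field spec: minute 0 of hour 2, every day, similar to the auto backup spec
	if _, err := c.AddFunc("0 2 * * *", func() { fmt.Println("backup") }); err != nil {
		panic(err)
	}
	c.Start()
	defer c.Stop()
	time.Sleep(65 * time.Second) // let the interval job fire twice, then exit
}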


@@ -24,8 +24,8 @@ import (
 	"github.com/alexedwards/argon2id"
 	"golang.org/x/crypto/bcrypt"
 
-	"github.com/drakkan/sftpgo/v2/internal/logger"
-	"github.com/drakkan/sftpgo/v2/internal/util"
+	"github.com/drakkan/sftpgo/v2/logger"
+	"github.com/drakkan/sftpgo/v2/util"
 )
 
 // ShareScope defines the supported share scopes

@@ -30,10 +30,10 @@ import (
 	// we import go-sqlite3 here to be able to disable SQLite support using a build tag
 	_ "github.com/mattn/go-sqlite3"
 
-	"github.com/drakkan/sftpgo/v2/internal/logger"
-	"github.com/drakkan/sftpgo/v2/internal/util"
-	"github.com/drakkan/sftpgo/v2/internal/version"
-	"github.com/drakkan/sftpgo/v2/internal/vfs"
+	"github.com/drakkan/sftpgo/v2/logger"
+	"github.com/drakkan/sftpgo/v2/util"
+	"github.com/drakkan/sftpgo/v2/version"
+	"github.com/drakkan/sftpgo/v2/vfs"
 )
 
 const (
@@ -41,7 +41,6 @@ const (
 DROP TABLE IF EXISTS "{{folders_mapping}}";
 DROP TABLE IF EXISTS "{{users_folders_mapping}}";
 DROP TABLE IF EXISTS "{{users_groups_mapping}}";
-DROP TABLE IF EXISTS "{{admins_groups_mapping}}";
 DROP TABLE IF EXISTS "{{groups_folders_mapping}}";
 DROP TABLE IF EXISTS "{{admins}}";
 DROP TABLE IF EXISTS "{{folders}}";
@@ -52,10 +51,6 @@ DROP TABLE IF EXISTS "{{defender_events}}";
 DROP TABLE IF EXISTS "{{defender_hosts}}";
 DROP TABLE IF EXISTS "{{active_transfers}}";
 DROP TABLE IF EXISTS "{{shared_sessions}}";
-DROP TABLE IF EXISTS "{{rules_actions_mapping}}";
-DROP TABLE IF EXISTS "{{events_rules}}";
-DROP TABLE IF EXISTS "{{events_actions}}";
-DROP TABLE IF EXISTS "{{tasks}}";
 DROP TABLE IF EXISTS "{{schema_version}}";
 `
 	sqliteInitialSQL = `CREATE TABLE "{{schema_version}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "version" integer NOT NULL);
@@ -63,11 +58,6 @@ CREATE TABLE "{{admins}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "use
 "description" varchar(512) NULL, "password" varchar(255) NOT NULL, "email" varchar(255) NULL, "status" integer NOT NULL,
 "permissions" text NOT NULL, "filters" text NULL, "additional_info" text NULL, "last_login" bigint NOT NULL,
 "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL);
-CREATE TABLE "{{active_transfers}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "connection_id" varchar(100) NOT NULL,
-"transfer_id" bigint NOT NULL, "transfer_type" integer NOT NULL, "username" varchar(255) NOT NULL,
-"folder_name" varchar(255) NULL, "ip" varchar(50) NOT NULL, "truncated_size" bigint NOT NULL,
-"current_ul_size" bigint NOT NULL, "current_dl_size" bigint NOT NULL, "created_at" bigint NOT NULL,
-"updated_at" bigint NOT NULL);
 CREATE TABLE "{{defender_hosts}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "ip" varchar(50) NOT NULL UNIQUE,
 "ban_time" bigint NOT NULL, "updated_at" bigint NOT NULL);
 CREATE TABLE "{{defender_events}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "date_time" bigint NOT NULL,
@@ -76,34 +66,18 @@ DEFERRABLE INITIALLY DEFERRED);
 CREATE TABLE "{{folders}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "name" varchar(255) NOT NULL UNIQUE,
 "description" varchar(512) NULL, "path" text NULL, "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL,
 "last_quota_update" bigint NOT NULL, "filesystem" text NULL);
-CREATE TABLE "{{groups}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "name" varchar(255) NOT NULL UNIQUE,
-"description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "user_settings" text NULL);
-CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL PRIMARY KEY, "data" text NOT NULL,
-"type" integer NOT NULL, "timestamp" bigint NOT NULL);
 CREATE TABLE "{{users}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "username" varchar(255) NOT NULL UNIQUE,
 "status" integer NOT NULL, "expiration_date" bigint NOT NULL, "description" varchar(512) NULL, "password" text NULL,
 "public_keys" text NULL, "home_dir" text NOT NULL, "uid" bigint NOT NULL, "gid" bigint NOT NULL,
 "max_sessions" integer NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "permissions" text NOT NULL,
 "used_quota_size" bigint NOT NULL, "used_quota_files" integer NOT NULL, "last_quota_update" bigint NOT NULL,
-"upload_bandwidth" integer NOT NULL, "download_bandwidth" integer NOT NULL, "last_login" bigint NOT NULL,
-"filters" text NULL, "filesystem" text NULL, "additional_info" text NULL, "created_at" bigint NOT NULL,
-"updated_at" bigint NOT NULL, "email" varchar(255) NULL, "upload_data_transfer" integer NOT NULL,
-"download_data_transfer" integer NOT NULL, "total_data_transfer" integer NOT NULL, "used_upload_data_transfer" integer NOT NULL,
-"used_download_data_transfer" integer NOT NULL);
-CREATE TABLE "{{groups_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
-"folder_id" integer NOT NULL REFERENCES "{{folders}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
-"group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
-"virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL,
-CONSTRAINT "{{prefix}}unique_group_folder_mapping" UNIQUE ("group_id", "folder_id"));
-CREATE TABLE "{{users_groups_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
-"user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
-"group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE NO ACTION,
-"group_type" integer NOT NULL, CONSTRAINT "{{prefix}}unique_user_group_mapping" UNIQUE ("user_id", "group_id"));
-CREATE TABLE "{{users_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
-"user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
-"folder_id" integer NOT NULL REFERENCES "{{folders}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
-"virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL,
-CONSTRAINT "{{prefix}}unique_user_folder_mapping" UNIQUE ("user_id", "folder_id"));
+"upload_bandwidth" integer NOT NULL, "download_bandwidth" integer NOT NULL, "last_login" bigint NOT NULL, "filters" text NULL,
+"filesystem" text NULL, "additional_info" text NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL,
+"email" varchar(255) NULL);
+CREATE TABLE "{{folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "virtual_path" text NOT NULL,
+"quota_size" bigint NOT NULL, "quota_files" integer NOT NULL, "folder_id" integer NOT NULL REFERENCES "{{folders}}" ("id")
+ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED, "user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
+CONSTRAINT "{{prefix}}unique_mapping" UNIQUE ("user_id", "folder_id"));
 CREATE TABLE "{{shares}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "share_id" varchar(60) NOT NULL UNIQUE,
 "name" varchar(255) NOT NULL, "description" varchar(512) NULL, "scope" integer NOT NULL, "paths" text NOT NULL,
 "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "last_use_at" bigint NOT NULL, "expires_at" bigint NOT NULL,
@@ -114,13 +88,8 @@ CREATE TABLE "{{api_keys}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "n
 "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "last_use_at" bigint NOT NULL, "expires_at" bigint NOT NULL,
 "description" text NULL, "admin_id" integer NULL REFERENCES "{{admins}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
 "user_id" integer NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED);
-CREATE INDEX "{{prefix}}groups_updated_at_idx" ON "{{groups}}" ("updated_at");
-CREATE INDEX "{{prefix}}users_folders_mapping_folder_id_idx" ON "{{users_folders_mapping}}" ("folder_id");
-CREATE INDEX "{{prefix}}users_folders_mapping_user_id_idx" ON "{{users_folders_mapping}}" ("user_id");
-CREATE INDEX "{{prefix}}users_groups_mapping_group_id_idx" ON "{{users_groups_mapping}}" ("group_id");
-CREATE INDEX "{{prefix}}users_groups_mapping_user_id_idx" ON "{{users_groups_mapping}}" ("user_id");
-CREATE INDEX "{{prefix}}groups_folders_mapping_folder_id_idx" ON "{{groups_folders_mapping}}" ("folder_id");
-CREATE INDEX "{{prefix}}groups_folders_mapping_group_id_idx" ON "{{groups_folders_mapping}}" ("group_id");
+CREATE INDEX "{{prefix}}folders_mapping_folder_id_idx" ON "{{folders_mapping}}" ("folder_id");
+CREATE INDEX "{{prefix}}folders_mapping_user_id_idx" ON "{{folders_mapping}}" ("user_id");
 CREATE INDEX "{{prefix}}api_keys_admin_id_idx" ON "{{api_keys}}" ("admin_id");
 CREATE INDEX "{{prefix}}api_keys_user_id_idx" ON "{{api_keys}}" ("user_id");
 CREATE INDEX "{{prefix}}users_updated_at_idx" ON "{{users}}" ("updated_at");
@@ -129,54 +98,78 @@ CREATE INDEX "{{prefix}}defender_hosts_updated_at_idx" ON "{{defender_hosts}}" (
 CREATE INDEX "{{prefix}}defender_hosts_ban_time_idx" ON "{{defender_hosts}}" ("ban_time");
 CREATE INDEX "{{prefix}}defender_events_date_time_idx" ON "{{defender_events}}" ("date_time");
 CREATE INDEX "{{prefix}}defender_events_host_id_idx" ON "{{defender_events}}" ("host_id");
+INSERT INTO {{schema_version}} (version) VALUES (15);
+`
+	sqliteV16SQL = `ALTER TABLE "{{users}}" ADD COLUMN "download_data_transfer" integer DEFAULT 0 NOT NULL;
+ALTER TABLE "{{users}}" ADD COLUMN "total_data_transfer" integer DEFAULT 0 NOT NULL;
+ALTER TABLE "{{users}}" ADD COLUMN "upload_data_transfer" integer DEFAULT 0 NOT NULL;
+ALTER TABLE "{{users}}" ADD COLUMN "used_download_data_transfer" integer DEFAULT 0 NOT NULL;
+ALTER TABLE "{{users}}" ADD COLUMN "used_upload_data_transfer" integer DEFAULT 0 NOT NULL;
+CREATE TABLE "{{active_transfers}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "connection_id" varchar(100) NOT NULL,
+"transfer_id" bigint NOT NULL, "transfer_type" integer NOT NULL, "username" varchar(255) NOT NULL,
+"folder_name" varchar(255) NULL, "ip" varchar(50) NOT NULL, "truncated_size" bigint NOT NULL,
+"current_ul_size" bigint NOT NULL, "current_dl_size" bigint NOT NULL, "created_at" bigint NOT NULL,
+"updated_at" bigint NOT NULL);
 CREATE INDEX "{{prefix}}active_transfers_connection_id_idx" ON "{{active_transfers}}" ("connection_id");
 CREATE INDEX "{{prefix}}active_transfers_transfer_id_idx" ON "{{active_transfers}}" ("transfer_id");
 CREATE INDEX "{{prefix}}active_transfers_updated_at_idx" ON "{{active_transfers}}" ("updated_at");
+`
+	sqliteV16DownSQL = `ALTER TABLE "{{users}}" DROP COLUMN "used_upload_data_transfer";
+ALTER TABLE "{{users}}" DROP COLUMN "used_download_data_transfer";
+ALTER TABLE "{{users}}" DROP COLUMN "upload_data_transfer";
+ALTER TABLE "{{users}}" DROP COLUMN "total_data_transfer";
+ALTER TABLE "{{users}}" DROP COLUMN "download_data_transfer";
+DROP TABLE "{{active_transfers}}";
+`
+	sqliteV17SQL = `CREATE TABLE "{{groups}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "name" varchar(255) NOT NULL UNIQUE,
+"description" varchar(512) NULL, "created_at" bigint NOT NULL, "updated_at" bigint NOT NULL, "user_settings" text NULL);
+CREATE TABLE "{{groups_folders_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
+"folder_id" integer NOT NULL REFERENCES "{{folders}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
+"group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
+"virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL,
+CONSTRAINT "{{prefix}}unique_group_folder_mapping" UNIQUE ("group_id", "folder_id"));
+CREATE TABLE "{{users_groups_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
+"user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
+"group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE NO ACTION,
+"group_type" integer NOT NULL, CONSTRAINT "{{prefix}}unique_user_group_mapping" UNIQUE ("user_id", "group_id"));
+CREATE TABLE "new__folders_mapping" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
+"user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
+"folder_id" integer NOT NULL REFERENCES "{{folders}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
+"virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL,
+CONSTRAINT "{{prefix}}unique_user_folder_mapping" UNIQUE ("user_id", "folder_id"));
+INSERT INTO "new__folders_mapping" ("id", "virtual_path", "quota_size", "quota_files", "folder_id", "user_id") SELECT "id",
+"virtual_path", "quota_size", "quota_files", "folder_id", "user_id" FROM "{{folders_mapping}}";
+DROP TABLE "{{folders_mapping}}";
+ALTER TABLE "new__folders_mapping" RENAME TO "{{users_folders_mapping}}";
+CREATE INDEX "{{prefix}}groups_updated_at_idx" ON "{{groups}}" ("updated_at");
+CREATE INDEX "{{prefix}}users_folders_mapping_folder_id_idx" ON "{{users_folders_mapping}}" ("folder_id");
+CREATE INDEX "{{prefix}}users_folders_mapping_user_id_idx" ON "{{users_folders_mapping}}" ("user_id");
+CREATE INDEX "{{prefix}}users_groups_mapping_group_id_idx" ON "{{users_groups_mapping}}" ("group_id");
+CREATE INDEX "{{prefix}}users_groups_mapping_user_id_idx" ON "{{users_groups_mapping}}" ("user_id");
+CREATE INDEX "{{prefix}}groups_folders_mapping_folder_id_idx" ON "{{groups_folders_mapping}}" ("folder_id");
+CREATE INDEX "{{prefix}}groups_folders_mapping_group_id_idx" ON "{{groups_folders_mapping}}" ("group_id");
+`
+	sqliteV17DownSQL = `DROP TABLE "{{users_groups_mapping}}";
+DROP TABLE "{{groups_folders_mapping}}";
+DROP TABLE "{{groups}}";
+CREATE TABLE "new__folders_mapping" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
+"user_id" integer NOT NULL REFERENCES "{{users}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
+"folder_id" integer NOT NULL REFERENCES "{{folders}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
+"virtual_path" text NOT NULL, "quota_size" bigint NOT NULL, "quota_files" integer NOT NULL,
+CONSTRAINT "{{prefix}}unique_folder_mapping" UNIQUE ("user_id", "folder_id"));
+INSERT INTO "new__folders_mapping" ("id", "virtual_path", "quota_size", "quota_files", "folder_id", "user_id") SELECT "id",
+"virtual_path", "quota_size", "quota_files", "folder_id", "user_id" FROM "{{users_folders_mapping}}";
+DROP TABLE "{{users_folders_mapping}}";
+ALTER TABLE "new__folders_mapping" RENAME TO "{{folders_mapping}}";
+CREATE INDEX "{{prefix}}folders_mapping_folder_id_idx" ON "{{folders_mapping}}" ("folder_id");
+CREATE INDEX "{{prefix}}folders_mapping_user_id_idx" ON "{{folders_mapping}}" ("user_id");
+`
+	sqliteV19SQL = `CREATE TABLE "{{shared_sessions}}" ("key" varchar(128) NOT NULL PRIMARY KEY, "data" text NOT NULL,
+"type" integer NOT NULL, "timestamp" bigint NOT NULL);
 CREATE INDEX "{{prefix}}shared_sessions_type_idx" ON "{{shared_sessions}}" ("type");
 CREATE INDEX "{{prefix}}shared_sessions_timestamp_idx" ON "{{shared_sessions}}" ("timestamp");
-INSERT INTO {{schema_version}} (version) VALUES (19);
 `
-	sqliteV20SQL = `CREATE TABLE "{{events_rules}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
-"name" varchar(255) NOT NULL UNIQUE, "description" varchar(512) NULL, "created_at" bigint NOT NULL,
-"updated_at" bigint NOT NULL, "trigger" integer NOT NULL, "conditions" text NOT NULL, "deleted_at" bigint NOT NULL);
-CREATE TABLE "{{events_actions}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "name" varchar(255) NOT NULL UNIQUE,
-"description" varchar(512) NULL, "type" integer NOT NULL, "options" text NOT NULL);
-CREATE TABLE "{{rules_actions_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
-"rule_id" integer NOT NULL REFERENCES "{{events_rules}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
-"action_id" integer NOT NULL REFERENCES "{{events_actions}}" ("id") ON DELETE NO ACTION DEFERRABLE INITIALLY DEFERRED,
-"order" integer NOT NULL, "options" text NOT NULL,
-CONSTRAINT "{{prefix}}unique_rule_action_mapping" UNIQUE ("rule_id", "action_id"));
-CREATE TABLE "{{tasks}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, "name" varchar(255) NOT NULL UNIQUE,
-"updated_at" bigint NOT NULL, "version" bigint NOT NULL);
-ALTER TABLE "{{users}}" ADD COLUMN "deleted_at" bigint DEFAULT 0 NOT NULL;
-CREATE INDEX "{{prefix}}events_rules_updated_at_idx" ON "{{events_rules}}" ("updated_at");
-CREATE INDEX "{{prefix}}events_rules_deleted_at_idx" ON "{{events_rules}}" ("deleted_at");
-CREATE INDEX "{{prefix}}events_rules_trigger_idx" ON "{{events_rules}}" ("trigger");
-CREATE INDEX "{{prefix}}rules_actions_mapping_rule_id_idx" ON "{{rules_actions_mapping}}" ("rule_id");
-CREATE INDEX "{{prefix}}rules_actions_mapping_action_id_idx" ON "{{rules_actions_mapping}}" ("action_id");
-CREATE INDEX "{{prefix}}rules_actions_mapping_order_idx" ON "{{rules_actions_mapping}}" ("order");
-CREATE INDEX "{{prefix}}users_deleted_at_idx" ON "{{users}}" ("deleted_at");
-`
-	sqliteV20DownSQL = `DROP TABLE "{{rules_actions_mapping}}";
-DROP TABLE "{{events_rules}}";
-DROP TABLE "{{events_actions}}";
-DROP TABLE "{{tasks}}";
-DROP INDEX IF EXISTS "{{prefix}}users_deleted_at_idx";
-ALTER TABLE "{{users}}" DROP COLUMN "deleted_at";
-`
-	sqliteV21SQL = `ALTER TABLE "{{users}}" ADD COLUMN "first_download" bigint DEFAULT 0 NOT NULL;
-ALTER TABLE "{{users}}" ADD COLUMN "first_upload" bigint DEFAULT 0 NOT NULL;`
-	sqliteV21DownSQL = `ALTER TABLE "{{users}}" DROP COLUMN "first_upload";
-ALTER TABLE "{{users}}" DROP COLUMN "first_download";
-`
-	sqliteV22SQL = `CREATE TABLE "{{admins_groups_mapping}}" ("id" integer NOT NULL PRIMARY KEY AUTOINCREMENT,
-"admin_id" integer NOT NULL REFERENCES "{{admins}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
-"group_id" integer NOT NULL REFERENCES "{{groups}}" ("id") ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED,
-"options" text NOT NULL, CONSTRAINT "{{prefix}}unique_admin_group_mapping" UNIQUE ("admin_id", "group_id"));
-CREATE INDEX "{{prefix}}admins_groups_mapping_admin_id_idx" ON "{{admins_groups_mapping}}" ("admin_id");
-CREATE INDEX "{{prefix}}admins_groups_mapping_group_id_idx" ON "{{admins_groups_mapping}}" ("group_id");
-`
-	sqliteV22DownSQL = `DROP TABLE "{{admins_groups_mapping}}";`
+	sqliteV19DownSQL = `DROP TABLE "{{shared_sessions}}";`
 )
 // SQLiteProvider defines the auth provider for SQLite database
@@ -268,8 +261,8 @@ func (p *SQLiteProvider) updateUser(user *User) error {
 	return sqlCommonUpdateUser(user, p.dbHandle)
 }
 
-func (p *SQLiteProvider) deleteUser(user User, softDelete bool) error {
-	return sqlCommonDeleteUser(user, softDelete, p.dbHandle)
+func (p *SQLiteProvider) deleteUser(user User) error {
+	return sqlCommonDeleteUser(user, p.dbHandle)
 }
 
 func (p *SQLiteProvider) updateUserPassword(username, password string) error {
@@ -510,102 +503,6 @@ func (p *SQLiteProvider) cleanupSharedSessions(sessionType SessionType, before i
 	return sqlCommonCleanupSessions(sessionType, before, p.dbHandle)
 }
 
-func (p *SQLiteProvider) getEventActions(limit, offset int, order string, minimal bool) ([]BaseEventAction, error) {
-	return sqlCommonGetEventActions(limit, offset, order, minimal, p.dbHandle)
-}
-
-func (p *SQLiteProvider) dumpEventActions() ([]BaseEventAction, error) {
-	return sqlCommonDumpEventActions(p.dbHandle)
-}
-
-func (p *SQLiteProvider) eventActionExists(name string) (BaseEventAction, error) {
-	return sqlCommonGetEventActionByName(name, p.dbHandle)
-}
-
-func (p *SQLiteProvider) addEventAction(action *BaseEventAction) error {
-	return sqlCommonAddEventAction(action, p.dbHandle)
-}
-
-func (p *SQLiteProvider) updateEventAction(action *BaseEventAction) error {
-	return sqlCommonUpdateEventAction(action, p.dbHandle)
-}
-
-func (p *SQLiteProvider) deleteEventAction(action BaseEventAction) error {
-	return sqlCommonDeleteEventAction(action, p.dbHandle)
-}
-
-func (p *SQLiteProvider) getEventRules(limit, offset int, order string) ([]EventRule, error) {
-	return sqlCommonGetEventRules(limit, offset, order, p.dbHandle)
-}
-
-func (p *SQLiteProvider) dumpEventRules() ([]EventRule, error) {
-	return sqlCommonDumpEventRules(p.dbHandle)
-}
-
-func (p *SQLiteProvider) getRecentlyUpdatedRules(after int64) ([]EventRule, error) {
-	return sqlCommonGetRecentlyUpdatedRules(after, p.dbHandle)
-}
-
-func (p *SQLiteProvider) eventRuleExists(name string) (EventRule, error) {
-	return sqlCommonGetEventRuleByName(name, p.dbHandle)
-}
-
-func (p *SQLiteProvider) addEventRule(rule *EventRule) error {
-	return sqlCommonAddEventRule(rule, p.dbHandle)
-}
-
-func (p *SQLiteProvider) updateEventRule(rule *EventRule) error {
-	return sqlCommonUpdateEventRule(rule, p.dbHandle)
-}
-
-func (p *SQLiteProvider) deleteEventRule(rule EventRule, softDelete bool) error {
-	return sqlCommonDeleteEventRule(rule, softDelete, p.dbHandle)
-}
-
-func (p *SQLiteProvider) getTaskByName(name string) (Task, error) {
-	return sqlCommonGetTaskByName(name, p.dbHandle)
-}
-
-func (p *SQLiteProvider) addTask(name string) error {
-	return sqlCommonAddTask(name, p.dbHandle)
-}
-
-func (p *SQLiteProvider) updateTask(name string, version int64) error {
-	return sqlCommonUpdateTask(name, version, p.dbHandle)
-}
-
-func (p *SQLiteProvider) updateTaskTimestamp(name string) error {
-	return sqlCommonUpdateTaskTimestamp(name, p.dbHandle)
-}
-
-func (*SQLiteProvider) addNode() error {
-	return ErrNotImplemented
-}
-
-func (*SQLiteProvider) getNodeByName(name string) (Node, error) {
-	return Node{}, ErrNotImplemented
-}
-
-func (*SQLiteProvider) getNodes() ([]Node, error) {
-	return nil, ErrNotImplemented
-}
-
-func (*SQLiteProvider) updateNodeTimestamp() error {
-	return ErrNotImplemented
-}
-
-func (*SQLiteProvider) cleanupNodes() error {
-	return ErrNotImplemented
-}
-
-func (p *SQLiteProvider) setFirstDownloadTimestamp(username string) error {
-	return sqlCommonSetFirstDownloadTimestamp(username, p.dbHandle)
-}
-
-func (p *SQLiteProvider) setFirstUploadTimestamp(username string) error {
-	return sqlCommonSetFirstUploadTimestamp(username, p.dbHandle)
-}
-
 func (p *SQLiteProvider) close() error {
 	return p.dbHandle.Close()
 }
@@ -623,11 +520,20 @@ func (p *SQLiteProvider) initializeDatabase() error {
 	if errors.Is(err, sql.ErrNoRows) {
 		return errSchemaVersionEmpty
 	}
-	logger.InfoToConsole("creating initial database schema, version 19")
-	providerLog(logger.LevelInfo, "creating initial database schema, version 19")
-	sql := sqlReplaceAll(sqliteInitialSQL)
-	return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 19, true)
+	logger.InfoToConsole("creating initial database schema, version 15")
+	providerLog(logger.LevelInfo, "creating initial database schema, version 15")
+	initialSQL := strings.ReplaceAll(sqliteInitialSQL, "{{schema_version}}", sqlTableSchemaVersion)
+	initialSQL = strings.ReplaceAll(initialSQL, "{{admins}}", sqlTableAdmins)
+	initialSQL = strings.ReplaceAll(initialSQL, "{{folders}}", sqlTableFolders)
+	initialSQL = strings.ReplaceAll(initialSQL, "{{users}}", sqlTableUsers)
+	initialSQL = strings.ReplaceAll(initialSQL, "{{folders_mapping}}", sqlTableFoldersMapping)
+	initialSQL = strings.ReplaceAll(initialSQL, "{{api_keys}}", sqlTableAPIKeys)
+	initialSQL = strings.ReplaceAll(initialSQL, "{{shares}}", sqlTableShares)
+	initialSQL = strings.ReplaceAll(initialSQL, "{{defender_events}}", sqlTableDefenderEvents)
+	initialSQL = strings.ReplaceAll(initialSQL, "{{defender_hosts}}", sqlTableDefenderHosts)
+	initialSQL = strings.ReplaceAll(initialSQL, "{{prefix}}", config.SQLTablesPrefix)
+	return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{initialSQL}, 15, true)
 }
 func (p *SQLiteProvider) migrateDatabase() error { //nolint:dupl
@@ -640,28 +546,28 @@ func (p *SQLiteProvider) migrateDatabase() error { //nolint:dupl
 	case version == sqlDatabaseVersion:
 		providerLog(logger.LevelDebug, "sql database is up to date, current version: %v", version)
 		return ErrNoInitRequired
-	case version < 19:
-		err = fmt.Errorf("database schema version %v is too old, please see the upgrading docs", version)
+	case version < 15:
+		err = fmt.Errorf("database version %v is too old, please see the upgrading docs", version)
 		providerLog(logger.LevelError, "%v", err)
 		logger.ErrorToConsole("%v", err)
 		return err
-	case version == 19:
-		return updateSQLiteDatabaseFromV19(p.dbHandle)
-	case version == 20:
-		return updateSQLiteDatabaseFromV20(p.dbHandle)
-	case version == 21:
-		return updateSQLiteDatabaseFromV21(p.dbHandle)
-	case version == 22:
-		return updateSQLiteDatabaseFromV22(p.dbHandle)
+	case version == 15:
+		return updateSQLiteDatabaseFromV15(p.dbHandle)
+	case version == 16:
+		return updateSQLiteDatabaseFromV16(p.dbHandle)
+	case version == 17:
+		return updateSQLiteDatabaseFromV17(p.dbHandle)
+	case version == 18:
+		return updateSQLiteDatabaseFromV18(p.dbHandle)
 	default:
 		if version > sqlDatabaseVersion {
-			providerLog(logger.LevelError, "database schema version %v is newer than the supported one: %v", version,
+			providerLog(logger.LevelError, "database version %v is newer than the supported one: %v", version,
 				sqlDatabaseVersion)
-			logger.WarnToConsole("database schema version %v is newer than the supported one: %v", version,
+			logger.WarnToConsole("database version %v is newer than the supported one: %v", version,
 				sqlDatabaseVersion)
 			return nil
 		}
-		return fmt.Errorf("database schema version not handled: %v", version)
+		return fmt.Errorf("database version not handled: %v", version)
 	}
 }
@@ -675,16 +581,16 @@ func (p *SQLiteProvider) revertDatabase(targetVersion int) error {
 	}
 	switch dbVersion.Version {
-	case 20:
-		return downgradeSQLiteDatabaseFromV20(p.dbHandle)
-	case 21:
-		return downgradeSQLiteDatabaseFromV21(p.dbHandle)
-	case 22:
-		return downgradeSQLiteDatabaseFromV22(p.dbHandle)
-	case 23:
-		return downgradeSQLiteDatabaseFromV23(p.dbHandle)
+	case 16:
+		return downgradeSQLiteDatabaseFromV16(p.dbHandle)
+	case 17:
+		return downgradeSQLiteDatabaseFromV17(p.dbHandle)
+	case 18:
+		return downgradeSQLiteDatabaseFromV18(p.dbHandle)
+	case 19:
+		return downgradeSQLiteDatabaseFromV19(p.dbHandle)
 	default:
-		return fmt.Errorf("database schema version not handled: %v", dbVersion.Version)
+		return fmt.Errorf("database version not handled: %v", dbVersion.Version)
 	}
 }
@@ -693,125 +599,145 @@ func (p *SQLiteProvider) resetDatabase() error {
 	return sqlCommonExecSQLAndUpdateDBVersion(p.dbHandle, []string{sql}, 0, false)
 }
 
-func updateSQLiteDatabaseFromV19(dbHandle *sql.DB) error {
-	if err := updateSQLiteDatabaseFrom19To20(dbHandle); err != nil {
+func updateSQLiteDatabaseFromV15(dbHandle *sql.DB) error {
+	if err := updateSQLiteDatabaseFrom15To16(dbHandle); err != nil {
 		return err
 	}
-	return updateSQLiteDatabaseFromV20(dbHandle)
+	return updateSQLiteDatabaseFromV16(dbHandle)
 }
 
-func updateSQLiteDatabaseFromV20(dbHandle *sql.DB) error {
-	if err := updateSQLiteDatabaseFrom20To21(dbHandle); err != nil {
+func updateSQLiteDatabaseFromV16(dbHandle *sql.DB) error {
+	if err := updateSQLiteDatabaseFrom16To17(dbHandle); err != nil {
 		return err
 	}
-	return updateSQLiteDatabaseFromV21(dbHandle)
+	return updateSQLiteDatabaseFromV17(dbHandle)
 }
 
-func updateSQLiteDatabaseFromV21(dbHandle *sql.DB) error {
-	if err := updateSQLiteDatabaseFrom21To22(dbHandle); err != nil {
+func updateSQLiteDatabaseFromV17(dbHandle *sql.DB) error {
+	if err := updateSQLiteDatabaseFrom17To18(dbHandle); err != nil {
 		return err
 	}
-	return updateSQLiteDatabaseFromV22(dbHandle)
+	return updateSQLiteDatabaseFromV18(dbHandle)
 }
 
-func updateSQLiteDatabaseFromV22(dbHandle *sql.DB) error {
-	return updateSQLiteDatabaseFrom22To23(dbHandle)
+func updateSQLiteDatabaseFromV18(dbHandle *sql.DB) error {
+	return updateSQLiteDatabaseFrom18To19(dbHandle)
 }
 
-func downgradeSQLiteDatabaseFromV20(dbHandle *sql.DB) error {
-	return downgradeSQLiteDatabaseFrom20To19(dbHandle)
+func downgradeSQLiteDatabaseFromV16(dbHandle *sql.DB) error {
+	return downgradeSQLiteDatabaseFrom16To15(dbHandle)
 }
 
-func downgradeSQLiteDatabaseFromV21(dbHandle *sql.DB) error {
-	if err := downgradeSQLiteDatabaseFrom21To20(dbHandle); err != nil {
+func downgradeSQLiteDatabaseFromV17(dbHandle *sql.DB) error {
+	if err := downgradeSQLiteDatabaseFrom17To16(dbHandle); err != nil {
 		return err
 	}
-	return downgradeSQLiteDatabaseFromV20(dbHandle)
+	return downgradeSQLiteDatabaseFromV16(dbHandle)
 }
 
-func downgradeSQLiteDatabaseFromV22(dbHandle *sql.DB) error {
-	if err := downgradeSQLiteDatabaseFrom22To21(dbHandle); err != nil {
+func downgradeSQLiteDatabaseFromV18(dbHandle *sql.DB) error {
+	if err := downgradeSQLiteDatabaseFrom18To17(dbHandle); err != nil {
 		return err
 	}
-	return downgradeSQLiteDatabaseFromV21(dbHandle)
+	return downgradeSQLiteDatabaseFromV17(dbHandle)
 }
 
-func downgradeSQLiteDatabaseFromV23(dbHandle *sql.DB) error {
-	if err := downgradeSQLiteDatabaseFrom23To22(dbHandle); err != nil {
+func downgradeSQLiteDatabaseFromV19(dbHandle *sql.DB) error {
+	if err := downgradeSQLiteDatabaseFrom19To18(dbHandle); err != nil {
 		return err
 	}
-	return downgradeSQLiteDatabaseFromV22(dbHandle)
+	return downgradeSQLiteDatabaseFromV18(dbHandle)
 }
 
-func updateSQLiteDatabaseFrom19To20(dbHandle *sql.DB) error {
-	logger.InfoToConsole("updating database schema version: 19 -> 20")
-	providerLog(logger.LevelInfo, "updating database schema version: 19 -> 20")
-	sql := strings.ReplaceAll(sqliteV20SQL, "{{events_actions}}", sqlTableEventsActions)
-	sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
-	sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
-	sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
-	sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
-	sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
-	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 20, true)
-}
-
-func updateSQLiteDatabaseFrom20To21(dbHandle *sql.DB) error {
-	logger.InfoToConsole("updating database schema version: 20 -> 21")
-	providerLog(logger.LevelInfo, "updating database schema version: 20 -> 21")
-	sql := strings.ReplaceAll(sqliteV21SQL, "{{users}}", sqlTableUsers)
-	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 21, true)
-}
-
-func updateSQLiteDatabaseFrom21To22(dbHandle *sql.DB) error {
-	logger.InfoToConsole("updating database schema version: 21 -> 22")
-	providerLog(logger.LevelInfo, "updating database schema version: 21 -> 22")
-	sql := strings.ReplaceAll(sqliteV22SQL, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping)
-	sql = strings.ReplaceAll(sql, "{{admins}}", sqlTableAdmins)
-	sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
-	sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
-	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 22, true)
-}
-
-func updateSQLiteDatabaseFrom22To23(dbHandle *sql.DB) error {
-	logger.InfoToConsole("updating database schema version: 22 -> 23")
-	providerLog(logger.LevelInfo, "updating database schema version: 22 -> 23")
-	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{`SELECT 1`}, 23, true)
-}
-
-func downgradeSQLiteDatabaseFrom20To19(dbHandle *sql.DB) error {
-	logger.InfoToConsole("downgrading database schema version: 20 -> 19")
-	providerLog(logger.LevelInfo, "downgrading database schema version: 20 -> 19")
-	sql := strings.ReplaceAll(sqliteV20DownSQL, "{{events_actions}}", sqlTableEventsActions)
-	sql = strings.ReplaceAll(sql, "{{events_rules}}", sqlTableEventsRules)
-	sql = strings.ReplaceAll(sql, "{{rules_actions_mapping}}", sqlTableRulesActionsMapping)
-	sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
-	sql = strings.ReplaceAll(sql, "{{tasks}}", sqlTableTasks)
-	sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
-	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 19, false)
-}
-
-func downgradeSQLiteDatabaseFrom21To20(dbHandle *sql.DB) error {
-	logger.InfoToConsole("downgrading database schema version: 21 -> 20")
-	providerLog(logger.LevelInfo, "downgrading database schema version: 21 -> 20")
-	sql := strings.ReplaceAll(sqliteV21DownSQL, "{{users}}", sqlTableUsers)
-	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 20, false)
-}
-
-func downgradeSQLiteDatabaseFrom22To21(dbHandle *sql.DB) error {
-	logger.InfoToConsole("downgrading database schema version: 22 -> 21")
-	providerLog(logger.LevelInfo, "downgrading database schema version: 22 -> 21")
-	sql := strings.ReplaceAll(sqliteV22DownSQL, "{{admins_groups_mapping}}", sqlTableAdminsGroupsMapping)
-	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 21, false)
-}
-
-func downgradeSQLiteDatabaseFrom23To22(dbHandle *sql.DB) error {
-	logger.InfoToConsole("downgrading database schema version: 23 -> 22")
-	providerLog(logger.LevelInfo, "downgrading database schema version: 23 -> 22")
-	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{`SELECT 1`}, 22, false)
-}
-
-/*func setPragmaFK(dbHandle *sql.DB, value string) error {
+func updateSQLiteDatabaseFrom15To16(dbHandle *sql.DB) error {
+	logger.InfoToConsole("updating database version: 15 -> 16")
+	providerLog(logger.LevelInfo, "updating database version: 15 -> 16")
+	sql := strings.ReplaceAll(sqliteV16SQL, "{{users}}", sqlTableUsers)
+	sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
+	sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
+	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16, true)
+}
+
+func updateSQLiteDatabaseFrom16To17(dbHandle *sql.DB) error {
+	logger.InfoToConsole("updating database version: 16 -> 17")
+	providerLog(logger.LevelInfo, "updating database version: 16 -> 17")
+	if err := setPragmaFK(dbHandle, "OFF"); err != nil {
+		return err
+	}
+	sql := strings.ReplaceAll(sqliteV17SQL, "{{users}}", sqlTableUsers)
+	sql = strings.ReplaceAll(sql, "{{groups}}", sqlTableGroups)
+	sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
+	sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
+	sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
+	sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
+	sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
+	sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
+	if err := sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 17, true); err != nil {
+		return err
+	}
+	return setPragmaFK(dbHandle, "ON")
+}
+
+func updateSQLiteDatabaseFrom17To18(dbHandle *sql.DB) error {
+	logger.InfoToConsole("updating database version: 17 -> 18")
+	providerLog(logger.LevelInfo, "updating database version: 17 -> 18")
+	if err := importGCSCredentials(); err != nil {
+		return err
+	}
+	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 18, true)
+}
+
+func updateSQLiteDatabaseFrom18To19(dbHandle *sql.DB) error {
+	logger.InfoToConsole("updating database version: 18 -> 19")
+	providerLog(logger.LevelInfo, "updating database version: 18 -> 19")
+	sql := strings.ReplaceAll(sqliteV19SQL, "{{shared_sessions}}", sqlTableSharedSessions)
+	sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
+	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 19, true)
+}
+
+func downgradeSQLiteDatabaseFrom16To15(dbHandle *sql.DB) error {
+	logger.InfoToConsole("downgrading database version: 16 -> 15")
+	providerLog(logger.LevelInfo, "downgrading database version: 16 -> 15")
+	sql := strings.ReplaceAll(sqliteV16DownSQL, "{{users}}", sqlTableUsers)
+	sql = strings.ReplaceAll(sql, "{{active_transfers}}", sqlTableActiveTransfers)
+	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 15, false)
+}
+
+func downgradeSQLiteDatabaseFrom17To16(dbHandle *sql.DB) error {
+	logger.InfoToConsole("downgrading database version: 17 -> 16")
+	providerLog(logger.LevelInfo, "downgrading database version: 17 -> 16")
+	if err := setPragmaFK(dbHandle, "OFF"); err != nil {
+		return err
+	}
+	sql := strings.ReplaceAll(sqliteV17DownSQL, "{{groups}}", sqlTableGroups)
+	sql = strings.ReplaceAll(sql, "{{users}}", sqlTableUsers)
+	sql = strings.ReplaceAll(sql, "{{folders}}", sqlTableFolders)
+	sql = strings.ReplaceAll(sql, "{{folders_mapping}}", sqlTableFoldersMapping)
+	sql = strings.ReplaceAll(sql, "{{users_folders_mapping}}", sqlTableUsersFoldersMapping)
+	sql = strings.ReplaceAll(sql, "{{users_groups_mapping}}", sqlTableUsersGroupsMapping)
+	sql = strings.ReplaceAll(sql, "{{groups_folders_mapping}}", sqlTableGroupsFoldersMapping)
+	sql = strings.ReplaceAll(sql, "{{prefix}}", config.SQLTablesPrefix)
+	if err := sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 16, false); err != nil {
+		return err
+	}
+	return setPragmaFK(dbHandle, "ON")
+}
+
+func downgradeSQLiteDatabaseFrom18To17(dbHandle *sql.DB) error {
+	logger.InfoToConsole("downgrading database version: 18 -> 17")
+	providerLog(logger.LevelInfo, "downgrading database version: 18 -> 17")
+	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, nil, 17, false)
+}
+
+func downgradeSQLiteDatabaseFrom19To18(dbHandle *sql.DB) error {
+	logger.InfoToConsole("downgrading database version: 19 -> 18")
+	providerLog(logger.LevelInfo, "downgrading database version: 19 -> 18")
+	sql := strings.ReplaceAll(sqliteV19DownSQL, "{{shared_sessions}}", sqlTableSharedSessions)
+	return sqlCommonExecSQLAndUpdateDBVersion(dbHandle, []string{sql}, 18, false)
+}
+
+func setPragmaFK(dbHandle *sql.DB, value string) error {
 	ctx, cancel := context.WithTimeout(context.Background(), longSQLQueryTimeout)
 	defer cancel()
 
@@ -819,4 +745,4 @@ func downgradeSQLiteDatabaseFrom23To22(dbHandle *sql.DB) error {
 	_, err := dbHandle.ExecContext(ctx, sql)
 	return err
-}*/
+}
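A note on the SQLite path: SQLite cannot drop a constraint with ALTER TABLE, so the v17 scripts above rebuild the mapping table (create new__folders_mapping, copy the rows, drop the old table, rename) while setPragmaFK toggles foreign_keys off and back on around the copy. A self-contained sketch of that rebuild pattern against an in-memory database (table and column names here are illustrative, not SFTPGo's schema):

package main

import (
	"database/sql"
	"log"

	_ "github.com/mattn/go-sqlite3"
)

func main() {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	stmts := []string{
		"PRAGMA foreign_keys=OFF;", // same effect as setPragmaFK(dbHandle, \"OFF\")
		"CREATE TABLE mapping (id INTEGER PRIMARY KEY, path TEXT NOT NULL);",
		"INSERT INTO mapping (path) VALUES ('/srv/a'), ('/srv/b');",
		// build the new shape, copy the rows, drop the old table, then
		// rename, as sqliteV17SQL does for folders_mapping
		"CREATE TABLE new__mapping (id INTEGER PRIMARY KEY, path TEXT NOT NULL UNIQUE);",
		"INSERT INTO new__mapping (id, path) SELECT id, path FROM mapping;",
		"DROP TABLE mapping;",
		"ALTER TABLE new__mapping RENAME TO mapping;",
		"PRAGMA foreign_keys=ON;",
	}
	for _, s := range stmts {
		if _, err := db.Exec(s); err != nil {
			log.Fatal(err)
		}
	}
	log.Println("rebuild migration completed")
}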


@@ -20,13 +20,13 @@ package dataprovider
 import (
 	"errors"
 
-	"github.com/drakkan/sftpgo/v2/internal/version"
+	"github.com/drakkan/sftpgo/v2/version"
 )
 
 func init() {
 	version.AddFeature("-sqlite")
 }
 
-func initializeSQLiteProvider(_ string) error {
+func initializeSQLiteProvider(basePath string) error {
 	return errors.New("SQLite disabled at build time")
 }

dataprovider/sqlqueries.go (new file, 734 lines)

@@ -0,0 +1,734 @@
// Copyright (C) 2019-2022 Nicola Murino
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package dataprovider
import (
"fmt"
"strconv"
"strings"
"github.com/drakkan/sftpgo/v2/vfs"
)
const (
selectUserFields = "id,username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions,used_quota_size," +
"used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth,expiration_date,last_login,status,filters,filesystem," +
"additional_info,description,email,created_at,updated_at,upload_data_transfer,download_data_transfer,total_data_transfer," +
"used_upload_data_transfer,used_download_data_transfer"
selectFolderFields = "id,path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem"
selectAdminFields = "id,username,password,status,email,permissions,filters,additional_info,description,created_at,updated_at,last_login"
selectAPIKeyFields = "key_id,name,api_key,scope,created_at,updated_at,last_use_at,expires_at,description,user_id,admin_id"
selectShareFields = "s.share_id,s.name,s.description,s.scope,s.paths,u.username,s.created_at,s.updated_at,s.last_use_at," +
"s.expires_at,s.password,s.max_tokens,s.used_tokens,s.allow_from"
selectGroupFields = "id,name,description,created_at,updated_at,user_settings"
)
func getSQLPlaceholders() []string {
var placeholders []string
for i := 1; i <= 50; i++ {
if config.Driver == PGSQLDataProviderName || config.Driver == CockroachDataProviderName {
placeholders = append(placeholders, fmt.Sprintf("$%v", i))
} else {
placeholders = append(placeholders, "?")
}
}
return placeholders
}
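// Illustrative example (assumed usage, based on the helpers in this file):
// with the PostgreSQL or CockroachDB driver the slice above holds "$1".."$50",
// while every other driver gets the positional "?" marker, so a query built as
//
//	fmt.Sprintf(`SELECT %s FROM %s WHERE name = %s`,
//		selectGroupFields, getSQLTableGroups(), sqlPlaceholders[0])
//
// renders as "... WHERE name = $1" on PostgreSQL and as
// "... WHERE name = ?" on MySQL and SQLite.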
func getSQLTableGroups() string {
if config.Driver == MySQLDataProviderName {
return fmt.Sprintf("`%s`", sqlTableGroups)
}
return sqlTableGroups
}
func getAddSessionQuery() string {
if config.Driver == MySQLDataProviderName {
return fmt.Sprintf("INSERT INTO %s (`key`,`data`,`type`,`timestamp`) VALUES (%s,%s,%s,%s) "+
"ON DUPLICATE KEY UPDATE `data`=VALUES(`data`), `timestamp`=VALUES(`timestamp`)",
sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
return fmt.Sprintf(`INSERT INTO %s (key,data,type,timestamp) VALUES (%s,%s,%s,%s) ON CONFLICT(key) DO UPDATE SET data=
EXCLUDED.data, timestamp=EXCLUDED.timestamp`,
sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
func getDeleteSessionQuery() string {
if config.Driver == MySQLDataProviderName {
return fmt.Sprintf("DELETE FROM %s WHERE `key` = %s", sqlTableSharedSessions, sqlPlaceholders[0])
}
return fmt.Sprintf(`DELETE FROM %s WHERE key = %s`, sqlTableSharedSessions, sqlPlaceholders[0])
}
func getSessionQuery() string {
if config.Driver == MySQLDataProviderName {
return fmt.Sprintf("SELECT `key`,`data`,`type`,`timestamp` FROM %s WHERE `key` = %s", sqlTableSharedSessions,
sqlPlaceholders[0])
}
return fmt.Sprintf(`SELECT key,data,type,timestamp FROM %s WHERE key = %s`, sqlTableSharedSessions,
sqlPlaceholders[0])
}
func getCleanupSessionsQuery() string {
return fmt.Sprintf(`DELETE from %s WHERE type = %s AND timestamp < %s`,
sqlTableSharedSessions, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getAddDefenderHostQuery() string {
if config.Driver == MySQLDataProviderName {
return fmt.Sprintf("INSERT INTO %v (`ip`,`updated_at`,`ban_time`) VALUES (%v,%v,0) ON DUPLICATE KEY UPDATE `updated_at`=VALUES(`updated_at`)",
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}
return fmt.Sprintf(`INSERT INTO %v (ip,updated_at,ban_time) VALUES (%v,%v,0) ON CONFLICT (ip) DO UPDATE SET updated_at = EXCLUDED.updated_at RETURNING id`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getAddDefenderEventQuery() string {
return fmt.Sprintf(`INSERT INTO %v (date_time,score,host_id) VALUES (%v,%v,(SELECT id FROM %v WHERE ip = %v))`,
sqlTableDefenderEvents, sqlPlaceholders[0], sqlPlaceholders[1], sqlTableDefenderHosts, sqlPlaceholders[2])
}
func getDefenderHostsQuery() string {
return fmt.Sprintf(`SELECT id,ip,ban_time FROM %v WHERE updated_at >= %v OR ban_time > 0 ORDER BY updated_at DESC LIMIT %v`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDefenderHostQuery() string {
return fmt.Sprintf(`SELECT id,ip,ban_time FROM %v WHERE ip = %v AND (updated_at >= %v OR ban_time > 0)`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}
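// getDefenderEventsQuery builds the IN clause directly from the numeric
// host IDs, so no placeholders are needed; "(0)" keeps the statement
// valid when the list is empty.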
func getDefenderEventsQuery(hostIDS []int64) string {
var sb strings.Builder
for _, hID := range hostIDS {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(hID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
} else {
sb.WriteString("(0)")
}
return fmt.Sprintf(`SELECT host_id,SUM(score) FROM %v WHERE date_time >= %v AND host_id IN %v GROUP BY host_id`,
sqlTableDefenderEvents, sqlPlaceholders[0], sb.String())
}
func getDefenderIsHostBannedQuery() string {
return fmt.Sprintf(`SELECT id FROM %v WHERE ip = %v AND ban_time >= %v`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDefenderIncrementBanTimeQuery() string {
return fmt.Sprintf(`UPDATE %v SET ban_time = ban_time + %v WHERE ip = %v`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDefenderSetBanTimeQuery() string {
return fmt.Sprintf(`UPDATE %v SET ban_time = %v WHERE ip = %v`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDeleteDefenderHostQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE ip = %v`, sqlTableDefenderHosts, sqlPlaceholders[0])
}
func getDefenderHostsCleanupQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE ban_time < %v AND NOT EXISTS (
SELECT id FROM %v WHERE %v.host_id = %v.id AND %v.date_time > %v)`,
sqlTableDefenderHosts, sqlPlaceholders[0], sqlTableDefenderEvents, sqlTableDefenderEvents, sqlTableDefenderHosts,
sqlTableDefenderEvents, sqlPlaceholders[1])
}
func getDefenderEventsCleanupQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE date_time < %v`, sqlTableDefenderEvents, sqlPlaceholders[0])
}
func getGroupByNameQuery() string {
return fmt.Sprintf(`SELECT %s FROM %s WHERE name = %s`, selectGroupFields, getSQLTableGroups(), sqlPlaceholders[0])
}
func getGroupsQuery(order string, minimal bool) string {
var fieldSelection string
if minimal {
fieldSelection = "id,name"
} else {
fieldSelection = selectGroupFields
}
return fmt.Sprintf(`SELECT %s FROM %s ORDER BY name %s LIMIT %v OFFSET %v`, fieldSelection, getSQLTableGroups(),
order, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getGroupsWithNamesQuery(numArgs int) string {
var sb strings.Builder
for idx := 0; idx < numArgs; idx++ {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(sqlPlaceholders[idx])
}
if sb.Len() > 0 {
sb.WriteString(")")
} else {
sb.WriteString("('')")
}
return fmt.Sprintf(`SELECT %s FROM %s WHERE name IN %s`, selectGroupFields, getSQLTableGroups(), sb.String())
}
func getUsersInGroupsQuery(numArgs int) string {
var sb strings.Builder
for idx := 0; idx < numArgs; idx++ {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(sqlPlaceholders[idx])
}
if sb.Len() > 0 {
sb.WriteString(")")
} else {
sb.WriteString("('')")
}
return fmt.Sprintf(`SELECT username FROM %s WHERE id IN (SELECT user_id FROM %s WHERE group_id IN (SELECT id FROM %s WHERE name IN %s))`,
sqlTableUsers, sqlTableUsersGroupsMapping, getSQLTableGroups(), sb.String())
}
func getDumpGroupsQuery() string {
return fmt.Sprintf(`SELECT %s FROM %s`, selectGroupFields, getSQLTableGroups())
}
func getAddGroupQuery() string {
return fmt.Sprintf(`INSERT INTO %s (name,description,created_at,updated_at,user_settings)
VALUES (%v,%v,%v,%v,%v)`, getSQLTableGroups(), sqlPlaceholders[0], sqlPlaceholders[1],
sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4])
}
func getUpdateGroupQuery() string {
return fmt.Sprintf(`UPDATE %s SET description=%v,user_settings=%v,updated_at=%v
WHERE name = %s`, getSQLTableGroups(), sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
sqlPlaceholders[3])
}
func getDeleteGroupQuery() string {
return fmt.Sprintf(`DELETE FROM %s WHERE name = %s`, getSQLTableGroups(), sqlPlaceholders[0])
}
func getAdminByUsernameQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v WHERE username = %v`, selectAdminFields, sqlTableAdmins, sqlPlaceholders[0])
}
func getAdminsQuery(order string) string {
return fmt.Sprintf(`SELECT %v FROM %v ORDER BY username %v LIMIT %v OFFSET %v`, selectAdminFields, sqlTableAdmins,
order, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDumpAdminsQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v`, selectAdminFields, sqlTableAdmins)
}
func getAddAdminQuery() string {
return fmt.Sprintf(`INSERT INTO %v (username,password,status,email,permissions,filters,additional_info,description,created_at,updated_at,last_login)
VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,0)`, sqlTableAdmins, sqlPlaceholders[0], sqlPlaceholders[1],
sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7],
sqlPlaceholders[8], sqlPlaceholders[9])
}
func getUpdateAdminQuery() string {
return fmt.Sprintf(`UPDATE %v SET password=%v,status=%v,email=%v,permissions=%v,filters=%v,additional_info=%v,description=%v,updated_at=%v
WHERE username = %v`, sqlTableAdmins, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8])
}
func getDeleteAdminQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE username = %v`, sqlTableAdmins, sqlPlaceholders[0])
}
func getShareByIDQuery(filterUser bool) string {
if filterUser {
return fmt.Sprintf(`SELECT %v FROM %v s INNER JOIN %v u ON s.user_id = u.id WHERE s.share_id = %v AND u.username = %v`,
selectShareFields, sqlTableShares, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1])
}
return fmt.Sprintf(`SELECT %v FROM %v s INNER JOIN %v u ON s.user_id = u.id WHERE s.share_id = %v`,
selectShareFields, sqlTableShares, sqlTableUsers, sqlPlaceholders[0])
}
func getSharesQuery(order string) string {
return fmt.Sprintf(`SELECT %v FROM %v s INNER JOIN %v u ON s.user_id = u.id WHERE u.username = %v ORDER BY s.share_id %v LIMIT %v OFFSET %v`,
selectShareFields, sqlTableShares, sqlTableUsers, sqlPlaceholders[0], order, sqlPlaceholders[1], sqlPlaceholders[2])
}
func getDumpSharesQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v s INNER JOIN %v u ON s.user_id = u.id`,
selectShareFields, sqlTableShares, sqlTableUsers)
}
func getAddShareQuery() string {
return fmt.Sprintf(`INSERT INTO %v (share_id,name,description,scope,paths,created_at,updated_at,last_use_at,
expires_at,password,max_tokens,used_tokens,allow_from,user_id) VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v)`,
sqlTableShares, sqlPlaceholders[0], sqlPlaceholders[1],
sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6],
sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10], sqlPlaceholders[11],
sqlPlaceholders[12], sqlPlaceholders[13])
}
func getUpdateShareRestoreQuery() string {
return fmt.Sprintf(`UPDATE %v SET name=%v,description=%v,scope=%v,paths=%v,created_at=%v,updated_at=%v,
last_use_at=%v,expires_at=%v,password=%v,max_tokens=%v,used_tokens=%v,allow_from=%v,user_id=%v WHERE share_id = %v`, sqlTableShares,
sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9],
sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13])
}
func getUpdateShareQuery() string {
return fmt.Sprintf(`UPDATE %v SET name=%v,description=%v,scope=%v,paths=%v,updated_at=%v,expires_at=%v,
password=%v,max_tokens=%v,allow_from=%v,user_id=%v WHERE share_id = %v`, sqlTableShares,
sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9],
sqlPlaceholders[10])
}
func getDeleteShareQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE share_id = %v`, sqlTableShares, sqlPlaceholders[0])
}
func getAPIKeyByIDQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v WHERE key_id = %v`, selectAPIKeyFields, sqlTableAPIKeys, sqlPlaceholders[0])
}
func getAPIKeysQuery(order string) string {
return fmt.Sprintf(`SELECT %v FROM %v ORDER BY key_id %v LIMIT %v OFFSET %v`, selectAPIKeyFields, sqlTableAPIKeys,
order, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDumpAPIKeysQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v`, selectAPIKeyFields, sqlTableAPIKeys)
}
func getAddAPIKeyQuery() string {
return fmt.Sprintf(`INSERT INTO %v (key_id,name,api_key,scope,created_at,updated_at,last_use_at,expires_at,description,user_id,admin_id)
VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v)`, sqlTableAPIKeys, sqlPlaceholders[0], sqlPlaceholders[1],
sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6],
sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9], sqlPlaceholders[10])
}
func getUpdateAPIKeyQuery() string {
return fmt.Sprintf(`UPDATE %v SET name=%v,scope=%v,expires_at=%v,user_id=%v,admin_id=%v,description=%v,updated_at=%v
WHERE key_id = %v`, sqlTableAPIKeys, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7])
}
func getDeleteAPIKeyQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE key_id = %v`, sqlTableAPIKeys, sqlPlaceholders[0])
}
func getRelatedUsersForAPIKeysQuery(apiKeys []APIKey) string {
var sb strings.Builder
for _, k := range apiKeys {
if k.userID == 0 {
continue
}
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(k.userID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
} else {
sb.WriteString("(0)")
}
return fmt.Sprintf(`SELECT id,username FROM %v WHERE id IN %v`, sqlTableUsers, sb.String())
}
func getRelatedAdminsForAPIKeysQuery(apiKeys []APIKey) string {
var sb strings.Builder
for _, k := range apiKeys {
if k.adminID == 0 {
continue
}
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(k.adminID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
} else {
sb.WriteString("(0)")
}
return fmt.Sprintf(`SELECT id,username FROM %v WHERE id IN %v`, sqlTableAdmins, sb.String())
}
func getUserByUsernameQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v WHERE username = %v`, selectUserFields, sqlTableUsers, sqlPlaceholders[0])
}
func getUsersQuery(order string) string {
return fmt.Sprintf(`SELECT %v FROM %v ORDER BY username %v LIMIT %v OFFSET %v`, selectUserFields, sqlTableUsers,
order, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getUsersForQuotaCheckQuery(numArgs int) string {
var sb strings.Builder
for idx := 0; idx < numArgs; idx++ {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(sqlPlaceholders[idx])
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT id,username,quota_size,used_quota_size,total_data_transfer,upload_data_transfer,
download_data_transfer,used_upload_data_transfer,used_download_data_transfer,filters FROM %v WHERE username IN %v`,
sqlTableUsers, sb.String())
}
func getRecentlyUpdatedUsersQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v WHERE updated_at >= %v`, selectUserFields, sqlTableUsers, sqlPlaceholders[0])
}
func getDumpUsersQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v`, selectUserFields, sqlTableUsers)
}
func getDumpFoldersQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v`, selectFolderFields, sqlTableFolders)
}
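// getUpdateTransferQuotaQuery returns the query to set (reset == true) or
// increment (reset == false) the used transfer quota for a user.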
func getUpdateTransferQuotaQuery(reset bool) string {
if reset {
return fmt.Sprintf(`UPDATE %v SET used_upload_data_transfer = %v,used_download_data_transfer = %v,last_quota_update = %v
WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
return fmt.Sprintf(`UPDATE %v SET used_upload_data_transfer = used_upload_data_transfer + %v,
used_download_data_transfer = used_download_data_transfer + %v,last_quota_update = %v
WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
func getUpdateQuotaQuery(reset bool) string {
if reset {
return fmt.Sprintf(`UPDATE %v SET used_quota_size = %v,used_quota_files = %v,last_quota_update = %v
WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
return fmt.Sprintf(`UPDATE %v SET used_quota_size = used_quota_size + %v,used_quota_files = used_quota_files + %v,last_quota_update = %v
WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
func getSetUpdateAtQuery() string {
return fmt.Sprintf(`UPDATE %v SET updated_at = %v WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getUpdateLastLoginQuery() string {
return fmt.Sprintf(`UPDATE %v SET last_login = %v WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getUpdateAdminLastLoginQuery() string {
return fmt.Sprintf(`UPDATE %v SET last_login = %v WHERE username = %v`, sqlTableAdmins, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getUpdateAPIKeyLastUseQuery() string {
return fmt.Sprintf(`UPDATE %v SET last_use_at = %v WHERE key_id = %v`, sqlTableAPIKeys, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getUpdateShareLastUseQuery() string {
return fmt.Sprintf(`UPDATE %v SET last_use_at = %v, used_tokens = used_tokens +%v WHERE share_id = %v`,
sqlTableShares, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2])
}
func getQuotaQuery() string {
return fmt.Sprintf(`SELECT used_quota_size,used_quota_files,used_upload_data_transfer,
used_download_data_transfer FROM %v WHERE username = %v`,
sqlTableUsers, sqlPlaceholders[0])
}
func getAddUserQuery() string {
return fmt.Sprintf(`INSERT INTO %v (username,password,public_keys,home_dir,uid,gid,max_sessions,quota_size,quota_files,permissions,
used_quota_size,used_quota_files,last_quota_update,upload_bandwidth,download_bandwidth,status,last_login,expiration_date,filters,
filesystem,additional_info,description,email,created_at,updated_at,upload_data_transfer,download_data_transfer,total_data_transfer,
used_upload_data_transfer,used_download_data_transfer)
VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,0,0,0,%v,%v,%v,0,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,0,0)`,
sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9],
sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], sqlPlaceholders[14],
sqlPlaceholders[15], sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18], sqlPlaceholders[19],
sqlPlaceholders[20], sqlPlaceholders[21], sqlPlaceholders[22], sqlPlaceholders[23])
}
func getUpdateUserQuery() string {
return fmt.Sprintf(`UPDATE %v SET password=%v,public_keys=%v,home_dir=%v,uid=%v,gid=%v,max_sessions=%v,quota_size=%v,
quota_files=%v,permissions=%v,upload_bandwidth=%v,download_bandwidth=%v,status=%v,expiration_date=%v,filters=%v,filesystem=%v,
additional_info=%v,description=%v,email=%v,updated_at=%v,upload_data_transfer=%v,download_data_transfer=%v,
total_data_transfer=%v WHERE id = %v`,
sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8], sqlPlaceholders[9],
sqlPlaceholders[10], sqlPlaceholders[11], sqlPlaceholders[12], sqlPlaceholders[13], sqlPlaceholders[14],
sqlPlaceholders[15], sqlPlaceholders[16], sqlPlaceholders[17], sqlPlaceholders[18], sqlPlaceholders[19],
sqlPlaceholders[20], sqlPlaceholders[21], sqlPlaceholders[22])
}
func getUpdateUserPasswordQuery() string {
return fmt.Sprintf(`UPDATE %v SET password=%v WHERE username = %v`, sqlTableUsers, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getDeleteUserQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE id = %v`, sqlTableUsers, sqlPlaceholders[0])
}
func getFolderByNameQuery() string {
return fmt.Sprintf(`SELECT %v FROM %v WHERE name = %v`, selectFolderFields, sqlTableFolders, sqlPlaceholders[0])
}
func getAddFolderQuery() string {
return fmt.Sprintf(`INSERT INTO %v (path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem)
VALUES (%v,%v,%v,%v,%v,%v,%v)`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6])
}
func getUpdateFolderQuery() string {
return fmt.Sprintf(`UPDATE %v SET path=%v,description=%v,filesystem=%v WHERE name = %v`, sqlTableFolders, sqlPlaceholders[0],
sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
func getDeleteFolderQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE id = %v`, sqlTableFolders, sqlPlaceholders[0])
}
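// getUpsertFolderQuery inserts a folder or, if the name already exists,
// updates its path, description and filesystem configuration.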
func getUpsertFolderQuery() string {
if config.Driver == MySQLDataProviderName {
return fmt.Sprintf("INSERT INTO %v (`path`,`used_quota_size`,`used_quota_files`,`last_quota_update`,`name`,"+
"`description`,`filesystem`) VALUES (%v,%v,%v,%v,%v,%v,%v) ON DUPLICATE KEY UPDATE "+
"`path`=VALUES(`path`),`description`=VALUES(`description`),`filesystem`=VALUES(`filesystem`)",
sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4],
sqlPlaceholders[5], sqlPlaceholders[6])
}
return fmt.Sprintf(`INSERT INTO %v (path,used_quota_size,used_quota_files,last_quota_update,name,description,filesystem)
VALUES (%v,%v,%v,%v,%v,%v,%v) ON CONFLICT (name) DO UPDATE SET path = EXCLUDED.path,description=EXCLUDED.description,
filesystem=EXCLUDED.filesystem`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2],
sqlPlaceholders[3], sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6])
}
func getClearUserGroupMappingQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE user_id = (SELECT id FROM %v WHERE username = %v)`, sqlTableUsersGroupsMapping,
sqlTableUsers, sqlPlaceholders[0])
}
func getAddUserGroupMappingQuery() string {
return fmt.Sprintf(`INSERT INTO %v (user_id,group_id,group_type) VALUES ((SELECT id FROM %v WHERE username = %v),
(SELECT id FROM %v WHERE name = %v),%v)`,
sqlTableUsersGroupsMapping, sqlTableUsers, sqlPlaceholders[0], getSQLTableGroups(), sqlPlaceholders[1], sqlPlaceholders[2])
}
func getClearGroupFolderMappingQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE group_id = (SELECT id FROM %v WHERE name = %v)`, sqlTableGroupsFoldersMapping,
getSQLTableGroups(), sqlPlaceholders[0])
}
func getAddGroupFolderMappingQuery() string {
return fmt.Sprintf(`INSERT INTO %v (virtual_path,quota_size,quota_files,folder_id,group_id)
VALUES (%v,%v,%v,(SELECT id FROM %v WHERE name = %v),(SELECT id FROM %v WHERE name = %v))`,
sqlTableGroupsFoldersMapping, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlTableFolders,
sqlPlaceholders[3], getSQLTableGroups(), sqlPlaceholders[4])
}
func getClearUserFolderMappingQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE user_id = (SELECT id FROM %v WHERE username = %v)`, sqlTableUsersFoldersMapping,
sqlTableUsers, sqlPlaceholders[0])
}
func getAddUserFolderMappingQuery() string {
return fmt.Sprintf(`INSERT INTO %v (virtual_path,quota_size,quota_files,folder_id,user_id)
VALUES (%v,%v,%v,(SELECT id FROM %v WHERE name = %v),(SELECT id FROM %v WHERE username = %v))`,
sqlTableUsersFoldersMapping, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlTableFolders,
sqlPlaceholders[3], sqlTableUsers, sqlPlaceholders[4])
}
func getFoldersQuery(order string, minimal bool) string {
var fieldSelection string
if minimal {
fieldSelection = "id,name"
} else {
fieldSelection = selectFolderFields
}
return fmt.Sprintf(`SELECT %v FROM %v ORDER BY name %v LIMIT %v OFFSET %v`, fieldSelection, sqlTableFolders,
order, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getUpdateFolderQuotaQuery(reset bool) string {
if reset {
return fmt.Sprintf(`UPDATE %v SET used_quota_size = %v,used_quota_files = %v,last_quota_update = %v
WHERE name = %v`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
return fmt.Sprintf(`UPDATE %v SET used_quota_size = used_quota_size + %v,used_quota_files = used_quota_files + %v,last_quota_update = %v
WHERE name = %v`, sqlTableFolders, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3])
}
func getQuotaFolderQuery() string {
return fmt.Sprintf(`SELECT used_quota_size,used_quota_files FROM %v WHERE name = %v`, sqlTableFolders,
sqlPlaceholders[0])
}
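// getRelatedGroupsForUsersQuery loads the group mappings for a batch of
// users in a single query; results are ordered by user_id so the caller
// can attach them while scanning the rows.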
func getRelatedGroupsForUsersQuery(users []User) string {
var sb strings.Builder
for _, u := range users {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(u.ID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT g.name,ug.group_type,ug.user_id FROM %v g INNER JOIN %v ug ON g.id = ug.group_id WHERE
ug.user_id IN %v ORDER BY ug.user_id`, getSQLTableGroups(), sqlTableUsersGroupsMapping, sb.String())
}
func getRelatedFoldersForUsersQuery(users []User) string {
var sb strings.Builder
for _, u := range users {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(u.ID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT f.id,f.name,f.path,f.used_quota_size,f.used_quota_files,f.last_quota_update,fm.virtual_path,
fm.quota_size,fm.quota_files,fm.user_id,f.filesystem,f.description FROM %v f INNER JOIN %v fm ON f.id = fm.folder_id WHERE
fm.user_id IN %v ORDER BY fm.user_id`, sqlTableFolders, sqlTableUsersFoldersMapping, sb.String())
}
func getRelatedUsersForFoldersQuery(folders []vfs.BaseVirtualFolder) string {
var sb strings.Builder
for _, f := range folders {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(f.ID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT fm.folder_id,u.username FROM %v fm INNER JOIN %v u ON fm.user_id = u.id
WHERE fm.folder_id IN %v ORDER BY fm.folder_id`, sqlTableUsersFoldersMapping, sqlTableUsers, sb.String())
}
func getRelatedGroupsForFoldersQuery(folders []vfs.BaseVirtualFolder) string {
var sb strings.Builder
for _, f := range folders {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(f.ID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT fm.folder_id,g.name FROM %v fm INNER JOIN %v g ON fm.group_id = g.id
WHERE fm.folder_id IN %v ORDER BY fm.folder_id`, sqlTableGroupsFoldersMapping, getSQLTableGroups(), sb.String())
}
func getRelatedUsersForGroupsQuery(groups []Group) string {
var sb strings.Builder
for _, g := range groups {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(g.ID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT um.group_id,u.username FROM %v um INNER JOIN %v u ON um.user_id = u.id
WHERE um.group_id IN %v ORDER BY um.group_id`, sqlTableUsersGroupsMapping, sqlTableUsers, sb.String())
}
func getRelatedFoldersForGroupsQuery(groups []Group) string {
var sb strings.Builder
for _, g := range groups {
if sb.Len() == 0 {
sb.WriteString("(")
} else {
sb.WriteString(",")
}
sb.WriteString(strconv.FormatInt(g.ID, 10))
}
if sb.Len() > 0 {
sb.WriteString(")")
}
return fmt.Sprintf(`SELECT f.id,f.name,f.path,f.used_quota_size,f.used_quota_files,f.last_quota_update,fm.virtual_path,
fm.quota_size,fm.quota_files,fm.group_id,f.filesystem,f.description FROM %s f INNER JOIN %s fm ON f.id = fm.folder_id WHERE
fm.group_id IN %v ORDER BY fm.group_id`, sqlTableFolders, sqlTableGroupsFoldersMapping, sb.String())
}
func getActiveTransfersQuery() string {
return fmt.Sprintf(`SELECT transfer_id,connection_id,transfer_type,username,folder_name,ip,truncated_size,
current_ul_size,current_dl_size,created_at,updated_at FROM %v WHERE updated_at > %v`,
sqlTableActiveTransfers, sqlPlaceholders[0])
}
func getAddActiveTransferQuery() string {
return fmt.Sprintf(`INSERT INTO %v (transfer_id,connection_id,transfer_type,username,folder_name,ip,truncated_size,
current_ul_size,current_dl_size,created_at,updated_at) VALUES (%v,%v,%v,%v,%v,%v,%v,%v,%v,%v,%v)`,
sqlTableActiveTransfers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3],
sqlPlaceholders[4], sqlPlaceholders[5], sqlPlaceholders[6], sqlPlaceholders[7], sqlPlaceholders[8],
sqlPlaceholders[9], sqlPlaceholders[10])
}
func getUpdateActiveTransferSizesQuery() string {
return fmt.Sprintf(`UPDATE %v SET current_ul_size=%v,current_dl_size=%v,updated_at=%v WHERE connection_id = %v AND transfer_id = %v`,
sqlTableActiveTransfers, sqlPlaceholders[0], sqlPlaceholders[1], sqlPlaceholders[2], sqlPlaceholders[3], sqlPlaceholders[4])
}
func getRemoveActiveTransferQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE connection_id = %v AND transfer_id = %v`,
sqlTableActiveTransfers, sqlPlaceholders[0], sqlPlaceholders[1])
}
func getCleanupActiveTransfersQuery() string {
return fmt.Sprintf(`DELETE FROM %v WHERE updated_at < %v`, sqlTableActiveTransfers, sqlPlaceholders[0])
}
func getDatabaseVersionQuery() string {
return fmt.Sprintf("SELECT version from %v LIMIT 1", sqlTableSchemaVersion)
}
func getUpdateDBVersionQuery() string {
return fmt.Sprintf(`UPDATE %v SET version=%v`, sqlTableSchemaVersion, sqlPlaceholders[0])
}


@@ -24,17 +24,18 @@ import (
 	"net"
 	"os"
 	"path"
+	"path/filepath"
 	"strings"
 	"time"

 	"github.com/sftpgo/sdk"

-	"github.com/drakkan/sftpgo/v2/internal/kms"
-	"github.com/drakkan/sftpgo/v2/internal/logger"
-	"github.com/drakkan/sftpgo/v2/internal/mfa"
-	"github.com/drakkan/sftpgo/v2/internal/plugin"
-	"github.com/drakkan/sftpgo/v2/internal/util"
-	"github.com/drakkan/sftpgo/v2/internal/vfs"
+	"github.com/drakkan/sftpgo/v2/kms"
+	"github.com/drakkan/sftpgo/v2/logger"
+	"github.com/drakkan/sftpgo/v2/mfa"
+	"github.com/drakkan/sftpgo/v2/plugin"
+	"github.com/drakkan/sftpgo/v2/util"
+	"github.com/drakkan/sftpgo/v2/vfs"
 )

 // Available permissions for SFTPGo users
@@ -142,8 +143,6 @@ type User struct {
 	fsCache map[string]vfs.Fs `json:"-"`
 	// true if group settings are already applied for this user
 	groupSettingsApplied bool `json:"-"`
-	// in multi node setups we mark the user as deleted to be able to update the webdav cache
-	DeletedAt int64 `json:"-"`
 }

 // GetFilesystem returns the base filesystem for this user
@@ -168,8 +167,6 @@ func (u *User) getRootFs(connectionID string) (fs vfs.Fs, err error) {
 		}
 		forbiddenSelfUsers = append(forbiddenSelfUsers, u.Username)
 		return vfs.NewSFTPFs(connectionID, "", u.GetHomeDir(), forbiddenSelfUsers, u.FsConfig.SFTPConfig)
-	case sdk.HTTPFilesystemProvider:
-		return vfs.NewHTTPFs(connectionID, u.GetHomeDir(), "", u.FsConfig.HTTPConfig)
 	default:
 		return vfs.NewOsFs(connectionID, u.GetHomeDir(), ""), nil
 	}
@@ -300,7 +297,7 @@ func (u *User) isFsEqual(other *User) bool {
 	if u.FsConfig.Provider == sdk.LocalFilesystemProvider && u.GetHomeDir() != other.GetHomeDir() {
 		return false
 	}
-	if !u.FsConfig.IsEqual(other.FsConfig) {
+	if !u.FsConfig.IsEqual(&other.FsConfig) {
 		return false
 	}
 	if u.Filters.StartDirectory != other.Filters.StartDirectory {
@@ -319,7 +316,7 @@ func (u *User) isFsEqual(other *User) bool {
 		if f.FsConfig.Provider == sdk.LocalFilesystemProvider && f.MappedPath != f1.MappedPath {
 			return false
 		}
-		if !f.FsConfig.IsEqual(f1.FsConfig) {
+		if !f.FsConfig.IsEqual(&f1.FsConfig) {
 			return false
 		}
 	}
@@ -373,20 +370,6 @@ func (u *User) GetSubDirPermissions() []sdk.DirectoryPermissions {
 	return result
 }

-func (u *User) setAnonymousSettings() {
-	for k := range u.Permissions {
-		u.Permissions[k] = []string{PermListItems, PermDownload}
-	}
-	u.Filters.DeniedProtocols = append(u.Filters.DeniedProtocols, protocolSSH, protocolHTTP)
-	u.Filters.DeniedProtocols = util.RemoveDuplicates(u.Filters.DeniedProtocols, false)
-	for _, method := range ValidLoginMethods {
-		if method != LoginMethodPassword {
-			u.Filters.DeniedLoginMethods = append(u.Filters.DeniedLoginMethods, method)
-		}
-	}
-	u.Filters.DeniedLoginMethods = util.RemoveDuplicates(u.Filters.DeniedLoginMethods, false)
-}
-
 // RenderAsJSON implements the renderer interface used within plugins
 func (u *User) RenderAsJSON(reload bool) ([]byte, error) {
 	if reload {
@@ -496,10 +479,16 @@ func (u *User) GetPermissionsForPath(p string) []string {
 	return permissions
 }

-func (u *User) getForbiddenSFTPSelfUsers(username string) ([]string, error) {
-	if allowSelfConnections == 0 {
-		return nil, nil
-	}
+// HasBufferedSFTP returns true if the user has a SFTP filesystem with buffering enabled
+func (u *User) HasBufferedSFTP(name string) bool {
+	fs := u.GetFsConfigForPath(name)
+	if fs.Provider == sdk.SFTPFilesystemProvider {
+		return fs.SFTPConfig.BufferSize > 0
+	}
+	return false
+}
+
+func (u *User) getForbiddenSFTPSelfUsers(username string) ([]string, error) {
 	sftpUser, err := UserExists(username)
 	if err == nil {
 		err = sftpUser.LoadAndApplyGroupSettings()
@@ -716,7 +705,7 @@ func (u *User) FilterListDir(dirContents []os.FileInfo, virtualPath string) []os
 	for dir := range vdirs {
 		if fi.Name() == dir {
 			if !fi.IsDir() {
-				fi = vfs.NewFileInfo(dir, true, 0, time.Unix(0, 0), false)
+				fi = vfs.NewFileInfo(dir, true, 0, time.Now(), false)
 				dirContents[index] = fi
 			}
 			delete(vdirs, dir)
@@ -738,7 +727,7 @@ func (u *User) FilterListDir(dirContents []os.FileInfo, virtualPath string) []os
 	}
 	for dir := range vdirs {
-		fi := vfs.NewFileInfo(dir, true, 0, time.Unix(0, 0), false)
+		fi := vfs.NewFileInfo(dir, true, 0, time.Now(), false)
 		dirContents = append(dirContents, fi)
 	}
 	return dirContents
@@ -787,10 +776,7 @@ func (u *User) HasVirtualFoldersInside(virtualPath string) bool {
 // HasPermissionsInside returns true if the specified virtualPath has no permissions itself and
 // no subdirs with defined permissions
 func (u *User) HasPermissionsInside(virtualPath string) bool {
-	for dir, perms := range u.Permissions {
-		if len(perms) == 1 && perms[0] == PermAny {
-			continue
-		}
+	for dir := range u.Permissions {
 		if dir == virtualPath {
 			return true
 		} else if len(dir) > len(virtualPath) {
@@ -1164,7 +1150,7 @@ func (u *User) GetBandwidthForIP(clientIP, connectionID string) (int64, int64) {
 // IsLoginFromAddrAllowed returns true if the login is allowed from the specified remoteAddr.
 // If AllowedIP is defined only the specified IP/Mask can login.
 // If DeniedIP is defined the specified IP/Mask cannot login.
-// If an IP is both allowed and denied then login will be allowed
+// If an IP is both allowed and denied then login will be denied
 func (u *User) IsLoginFromAddrAllowed(remoteAddr string) bool {
 	if len(u.Filters.AllowedIP) == 0 && len(u.Filters.DeniedIP) == 0 {
 		return true
@@ -1175,15 +1161,6 @@ func (u *User) IsLoginFromAddrAllowed(remoteAddr string) bool {
 		logger.Warn(logSender, "", "login allowed for invalid IP. remote address: %#v", remoteAddr)
 		return true
 	}
-	for _, IPMask := range u.Filters.AllowedIP {
-		_, IPNet, err := net.ParseCIDR(IPMask)
-		if err != nil {
-			return false
-		}
-		if IPNet.Contains(remoteIP) {
-			return true
-		}
-	}
 	for _, IPMask := range u.Filters.DeniedIP {
 		_, IPNet, err := net.ParseCIDR(IPMask)
 		if err != nil {
@@ -1193,6 +1170,15 @@ func (u *User) IsLoginFromAddrAllowed(remoteAddr string) bool {
 			return false
 		}
 	}
+	for _, IPMask := range u.Filters.AllowedIP {
+		_, IPNet, err := net.ParseCIDR(IPMask)
+		if err != nil {
+			return false
+		}
+		if IPNet.Contains(remoteIP) {
+			return true
+		}
+	}
 	return len(u.Filters.AllowedIP) == 0
 }
@@ -1415,8 +1401,6 @@ func (u *User) GetStorageDescrition() string {
 		return fmt.Sprintf("Encrypted: %v", u.GetHomeDir())
 	case sdk.SFTPFilesystemProvider:
 		return fmt.Sprintf("SFTP: %v", u.FsConfig.SFTPConfig.Endpoint)
-	case sdk.HTTPFilesystemProvider:
-		return fmt.Sprintf("HTTP: %v", u.FsConfig.HTTPConfig.Endpoint)
 	default:
 		return ""
 	}
@@ -1562,37 +1546,17 @@ func (u *User) HasSecondaryGroup(name string) bool {
 	return false
 }

-// HasMembershipGroup returns true if the user has the specified membership group
-func (u *User) HasMembershipGroup(name string) bool {
-	for _, g := range u.Groups {
-		if g.Name == name {
-			return g.Type == sdk.GroupTypeMembership
-		}
-	}
-	return false
-}
-
-func (u *User) hasSettingsFromGroups() bool {
-	for _, g := range u.Groups {
-		if g.Type != sdk.GroupTypeMembership {
-			return true
-		}
-	}
-	return false
-}
-
 func (u *User) applyGroupSettings(groupsMapping map[string]Group) {
-	if !u.hasSettingsFromGroups() {
+	if len(u.Groups) == 0 {
 		return
 	}
 	if u.groupSettingsApplied {
 		return
 	}
-	replacer := u.getGroupPlacehodersReplacer()
 	for _, g := range u.Groups {
 		if g.Type == sdk.GroupTypePrimary {
 			if group, ok := groupsMapping[g.Name]; ok {
-				u.mergeWithPrimaryGroup(group, replacer)
+				u.mergeWithPrimaryGroup(group)
 			} else {
 				providerLog(logger.LevelError, "mapping not found for user %s, group %s", u.Username, g.Name)
 			}
@@ -1602,7 +1566,7 @@ func (u *User) applyGroupSettings(groupsMapping map[string]Group) {
 	for _, g := range u.Groups {
 		if g.Type == sdk.GroupTypeSecondary {
 			if group, ok := groupsMapping[g.Name]; ok {
-				u.mergeAdditiveProperties(group, sdk.GroupTypeSecondary, replacer)
+				u.mergeAdditiveProperties(group, sdk.GroupTypeSecondary)
 			} else {
 				providerLog(logger.LevelError, "mapping not found for user %s, group %s", u.Username, g.Name)
 			}
@@ -1613,7 +1577,7 @@ func (u *User) applyGroupSettings(groupsMapping map[string]Group) {
 // LoadAndApplyGroupSettings update the user by loading and applying the group settings
 func (u *User) LoadAndApplyGroupSettings() error {
-	if !u.hasSettingsFromGroups() {
+	if len(u.Groups) == 0 {
 		return nil
 	}
 	if u.groupSettingsApplied {
@@ -1625,19 +1589,16 @@ func (u *User) LoadAndApplyGroupSettings() error {
 		if g.Type == sdk.GroupTypePrimary {
 			primaryGroupName = g.Name
 		}
-		if g.Type != sdk.GroupTypeMembership {
-			names = append(names, g.Name)
-		}
+		names = append(names, g.Name)
 	}
 	groups, err := provider.getGroupsWithNames(names)
 	if err != nil {
 		return fmt.Errorf("unable to get groups: %w", err)
 	}
-	replacer := u.getGroupPlacehodersReplacer()
 	// make sure to always merge with the primary group first
 	for idx, g := range groups {
 		if g.Name == primaryGroupName {
-			u.mergeWithPrimaryGroup(g, replacer)
+			u.mergeWithPrimaryGroup(g)
 			lastIdx := len(groups) - 1
 			groups[idx] = groups[lastIdx]
 			groups = groups[:lastIdx]
@@ -1645,46 +1606,40 @@ func (u *User) LoadAndApplyGroupSettings() error {
 		}
 	}
 	for _, g := range groups {
-		u.mergeAdditiveProperties(g, sdk.GroupTypeSecondary, replacer)
+		u.mergeAdditiveProperties(g, sdk.GroupTypeSecondary)
 	}
 	u.removeDuplicatesAfterGroupMerge()
 	return nil
 }

-func (u *User) getGroupPlacehodersReplacer() *strings.Replacer {
-	return strings.NewReplacer("%username%", u.Username)
-}
-
-func (u *User) replacePlaceholder(value string, replacer *strings.Replacer) string {
+func (u *User) replacePlaceholder(value string) string {
 	if value == "" {
 		return value
 	}
-	return replacer.Replace(value)
+	return strings.ReplaceAll(value, "%username%", u.Username)
 }

-func (u *User) replaceFsConfigPlaceholders(fsConfig vfs.Filesystem, replacer *strings.Replacer) vfs.Filesystem {
+func (u *User) replaceFsConfigPlaceholders(fsConfig vfs.Filesystem) vfs.Filesystem {
 	switch fsConfig.Provider {
 	case sdk.S3FilesystemProvider:
-		fsConfig.S3Config.KeyPrefix = u.replacePlaceholder(fsConfig.S3Config.KeyPrefix, replacer)
+		fsConfig.S3Config.KeyPrefix = u.replacePlaceholder(fsConfig.S3Config.KeyPrefix)
 	case sdk.GCSFilesystemProvider:
-		fsConfig.GCSConfig.KeyPrefix = u.replacePlaceholder(fsConfig.GCSConfig.KeyPrefix, replacer)
+		fsConfig.GCSConfig.KeyPrefix = u.replacePlaceholder(fsConfig.GCSConfig.KeyPrefix)
 	case sdk.AzureBlobFilesystemProvider:
-		fsConfig.AzBlobConfig.KeyPrefix = u.replacePlaceholder(fsConfig.AzBlobConfig.KeyPrefix, replacer)
+		fsConfig.AzBlobConfig.KeyPrefix = u.replacePlaceholder(fsConfig.AzBlobConfig.KeyPrefix)
 	case sdk.SFTPFilesystemProvider:
-		fsConfig.SFTPConfig.Username = u.replacePlaceholder(fsConfig.SFTPConfig.Username, replacer)
-		fsConfig.SFTPConfig.Prefix = u.replacePlaceholder(fsConfig.SFTPConfig.Prefix, replacer)
-	case sdk.HTTPFilesystemProvider:
-		fsConfig.HTTPConfig.Username = u.replacePlaceholder(fsConfig.HTTPConfig.Username, replacer)
+		fsConfig.SFTPConfig.Username = u.replacePlaceholder(fsConfig.SFTPConfig.Username)
+		fsConfig.SFTPConfig.Prefix = u.replacePlaceholder(fsConfig.SFTPConfig.Prefix)
 	}
 	return fsConfig
 }

-func (u *User) mergeWithPrimaryGroup(group Group, replacer *strings.Replacer) {
+func (u *User) mergeWithPrimaryGroup(group Group) {
 	if group.UserSettings.HomeDir != "" {
-		u.HomeDir = u.replacePlaceholder(group.UserSettings.HomeDir, replacer)
+		u.HomeDir = u.replacePlaceholder(group.UserSettings.HomeDir)
 	}
 	if group.UserSettings.FsConfig.Provider != 0 {
-		u.FsConfig = u.replaceFsConfigPlaceholders(group.UserSettings.FsConfig, replacer)
+		u.FsConfig = u.replaceFsConfigPlaceholders(group.UserSettings.FsConfig)
 	}
 	if u.MaxSessions == 0 {
 		u.MaxSessions = group.UserSettings.MaxSessions
@@ -1706,11 +1661,11 @@ func (u *User) mergeWithPrimaryGroup(group Group, replacer *strings.Replacer) {
 		u.DownloadDataTransfer = group.UserSettings.DownloadDataTransfer
 		u.TotalDataTransfer = group.UserSettings.TotalDataTransfer
 	}
-	u.mergePrimaryGroupFilters(group.UserSettings.Filters, replacer)
-	u.mergeAdditiveProperties(group, sdk.GroupTypePrimary, replacer)
+	u.mergePrimaryGroupFilters(group.UserSettings.Filters)
+	u.mergeAdditiveProperties(group, sdk.GroupTypePrimary)
 }

-func (u *User) mergePrimaryGroupFilters(filters sdk.BaseUserFilters, replacer *strings.Replacer) {
+func (u *User) mergePrimaryGroupFilters(filters sdk.BaseUserFilters) {
 	if u.Filters.MaxUploadFileSize == 0 {
 		u.Filters.MaxUploadFileSize = filters.MaxUploadFileSize
 	}
@@ -1732,27 +1687,18 @@ func (u *User) mergePrimaryGroupFilters(filters sdk.BaseUserFilters, replacer *s
 	if !u.Filters.AllowAPIKeyAuth {
 		u.Filters.AllowAPIKeyAuth = filters.AllowAPIKeyAuth
 	}
-	if !u.Filters.IsAnonymous {
-		u.Filters.IsAnonymous = filters.IsAnonymous
-	}
 	if u.Filters.ExternalAuthCacheTime == 0 {
 		u.Filters.ExternalAuthCacheTime = filters.ExternalAuthCacheTime
 	}
-	if u.Filters.FTPSecurity == 0 {
-		u.Filters.FTPSecurity = filters.FTPSecurity
-	}
 	if u.Filters.StartDirectory == "" {
-		u.Filters.StartDirectory = u.replacePlaceholder(filters.StartDirectory, replacer)
-	}
-	if u.Filters.DefaultSharesExpiration == 0 {
-		u.Filters.DefaultSharesExpiration = filters.DefaultSharesExpiration
+		u.Filters.StartDirectory = u.replacePlaceholder(filters.StartDirectory)
 	}
 }

-func (u *User) mergeAdditiveProperties(group Group, groupType int, replacer *strings.Replacer) {
-	u.mergeVirtualFolders(group, groupType, replacer)
-	u.mergePermissions(group, groupType, replacer)
-	u.mergeFilePatterns(group, groupType, replacer)
+func (u *User) mergeAdditiveProperties(group Group, groupType int) {
+	u.mergeVirtualFolders(group, groupType)
+	u.mergePermissions(group, groupType)
+	u.mergeFilePatterns(group, groupType)
 	u.Filters.BandwidthLimits = append(u.Filters.BandwidthLimits, group.UserSettings.Filters.BandwidthLimits...)
 	u.Filters.DataTransferLimits = append(u.Filters.DataTransferLimits, group.UserSettings.Filters.DataTransferLimits...)
 	u.Filters.AllowedIP = append(u.Filters.AllowedIP, group.UserSettings.Filters.AllowedIP...)
@@ -1763,7 +1709,7 @@ func (u *User) mergeAdditiveProperties(group Group, groupType int, replacer *str
 	u.Filters.TwoFactorAuthProtocols = append(u.Filters.TwoFactorAuthProtocols, group.UserSettings.Filters.TwoFactorAuthProtocols...)
 }

-func (u *User) mergeVirtualFolders(group Group, groupType int, replacer *strings.Replacer) {
+func (u *User) mergeVirtualFolders(group Group, groupType int) {
 	if len(group.VirtualFolders) > 0 {
 		folderPaths := make(map[string]bool)
 		for _, folder := range u.VirtualFolders {
@@ -1773,17 +1719,17 @@ func (u *User) mergeVirtualFolders(group Group, groupType int, replacer *strings
 			if folder.VirtualPath == "/" && groupType != sdk.GroupTypePrimary {
 				continue
 			}
-			folder.VirtualPath = u.replacePlaceholder(folder.VirtualPath, replacer)
+			folder.VirtualPath = u.replacePlaceholder(folder.VirtualPath)
 			if _, ok := folderPaths[folder.VirtualPath]; !ok {
-				folder.MappedPath = u.replacePlaceholder(folder.MappedPath, replacer)
-				folder.FsConfig = u.replaceFsConfigPlaceholders(folder.FsConfig, replacer)
+				folder.MappedPath = u.replacePlaceholder(folder.MappedPath)
+				folder.FsConfig = u.replaceFsConfigPlaceholders(folder.FsConfig)
 				u.VirtualFolders = append(u.VirtualFolders, folder)
 			}
 		}
 	}
 }

-func (u *User) mergePermissions(group Group, groupType int, replacer *strings.Replacer) {
+func (u *User) mergePermissions(group Group, groupType int) {
 	for k, v := range group.UserSettings.Permissions {
 		if k == "/" {
 			if groupType == sdk.GroupTypePrimary {
@@ -1792,14 +1738,14 @@ func (u *User) mergePermissions(group Group, groupType int, replacer *strings.Re
 				continue
 			}
 		}
-		k = u.replacePlaceholder(k, replacer)
+		k = u.replacePlaceholder(k)
 		if _, ok := u.Permissions[k]; !ok {
 			u.Permissions[k] = v
 		}
 	}
 }

-func (u *User) mergeFilePatterns(group Group, groupType int, replacer *strings.Replacer) {
+func (u *User) mergeFilePatterns(group Group, groupType int) {
 	if len(group.UserSettings.Filters.FilePatterns) > 0 {
 		patternPaths := make(map[string]bool)
 		for _, pattern := range u.Filters.FilePatterns {
@@ -1809,7 +1755,7 @@ func (u *User) mergeFilePatterns(group Group, groupType int, replacer *strings.R
 			if pattern.Path == "/" && groupType != sdk.GroupTypePrimary {
 				continue
 			}
-			pattern.Path = u.replacePlaceholder(pattern.Path, replacer)
+			pattern.Path = u.replacePlaceholder(pattern.Path)
 			if _, ok := patternPaths[pattern.Path]; !ok {
 				u.Filters.FilePatterns = append(u.Filters.FilePatterns, pattern)
 			}
@@ -1896,8 +1842,6 @@ func (u *User) getACopy() User {
 		Status:         u.Status,
 		ExpirationDate: u.ExpirationDate,
 		LastLogin:      u.LastLogin,
-		FirstDownload:  u.FirstDownload,
-		FirstUpload:    u.FirstUpload,
 		AdditionalInfo: u.AdditionalInfo,
 		Description:    u.Description,
 		CreatedAt:      u.CreatedAt,
@@ -1915,3 +1859,8 @@ func (u *User) getACopy() User {
 func (u *User) GetEncryptionAdditionalData() string {
 	return u.Username
 }
+
+// GetGCSCredentialsFilePath returns the path for GCS credentials
+func (u *User) GetGCSCredentialsFilePath() string {
+	return filepath.Join(credentialsDirPath, fmt.Sprintf("%v_gcs_credentials.json", u.Username))
+}


@@ -4,14 +4,12 @@ SFTPGo provides an official Docker image, it is available on both [Docker Hub](h

 ## Supported tags and respective Dockerfile links

-- [v2.4.0, v2.4, v2, latest](https://github.com/drakkan/sftpgo/blob/v2.4.0/Dockerfile)
-- [v2.4.0-plugins, v2.4-plugins, v2-plugins, plugins](https://github.com/drakkan/sftpgo/blob/v2.4.0/Dockerfile)
-- [v2.4.0-alpine, v2.4-alpine, v2-alpine, alpine](https://github.com/drakkan/sftpgo/blob/v2.4.0/Dockerfile.alpine)
-- [v2.4.0-slim, v2.4-slim, v2-slim, slim](https://github.com/drakkan/sftpgo/blob/v2.4.0/Dockerfile)
-- [v2.4.0-alpine-slim, v2.4-alpine-slim, v2-alpine-slim, alpine-slim](https://github.com/drakkan/sftpgo/blob/v2.4.0/Dockerfile.alpine)
-- [v2.4.0-distroless-slim, v2.4-distroless-slim, v2-distroless-slim, distroless-slim](https://github.com/drakkan/sftpgo/blob/v2.4.0/Dockerfile.distroless)
+- [v2.3.3, v2.3, v2, latest](https://github.com/drakkan/sftpgo/blob/v2.3.3/Dockerfile)
+- [v2.3.3-alpine, v2.3-alpine, v2-alpine, alpine](https://github.com/drakkan/sftpgo/blob/v2.3.3/Dockerfile.alpine)
+- [v2.3.3-slim, v2.3-slim, v2-slim, slim](https://github.com/drakkan/sftpgo/blob/v2.3.3/Dockerfile)
+- [v2.3.3-alpine-slim, v2.3-alpine-slim, v2-alpine-slim, alpine-slim](https://github.com/drakkan/sftpgo/blob/v2.3.3/Dockerfile.alpine)
+- [v2.3.3-distroless-slim, v2.3-distroless-slim, v2-distroless-slim, distroless-slim](https://github.com/drakkan/sftpgo/blob/v2.3.3/Dockerfile.distroless)
 - [edge](../Dockerfile)
-- [edge-plugins](../Dockerfile)
 - [edge-alpine](../Dockerfile.alpine)
 - [edge-slim](../Dockerfile)
 - [edge-alpine-slim](../Dockerfile.alpine)
@@ -199,17 +197,8 @@ We only provide the slim variant and so the optional `git` dependency is not ava

 ### `sftpgo:<suite>-slim`

-These tags provide a slimmer image that does not include `jq` and the optional `git` and `rsync` dependencies.
+These tags provide a slimmer image that does not include the optional `git` dependency.

-### `sftpgo:<suite>-plugins`
-
-These tags provide the standard image with the addition of all "official" plugins installed in `/usr/local/bin`.
-
 ## Helm Chart

-Some helm charts are available:
-
-- [sagikazarmark/sftpgo](https://artifacthub.io/packages/helm/sagikazarmark/sftpgo)
-- [truecharts/sftpgo](https://artifacthub.io/packages/helm/truecharts/sftpgo)
-
-These charts are not maintained by the SFTPGo project and any issues with the charts should be raised to the upstream repo.
+An helm chart is [available](https://artifacthub.io/packages/helm/sagikazarmark/sftpgo). You can find the source code [here](https://github.com/sagikazarmark/helm-charts/tree/master/charts/sftpgo).


@@ -1,25 +0,0 @@
#!/usr/bin/env bash
set -e
ARCH=`uname -m`
case ${ARCH} in
"x86_64")
SUFFIX=amd64
;;
"aarch64")
SUFFIX=arm64
;;
*)
SUFFIX=ppc64le
;;
esac
echo "download plugins for arch ${SUFFIX}"
for PLUGIN in geoipfilter kms pubsub eventstore eventsearch metadata
do
echo "download plugin from https://github.com/sftpgo/sftpgo-plugin-${PLUGIN}/releases/latest/download/sftpgo-plugin-${PLUGIN}-linux-${SUFFIX}"
curl -L "https://github.com/sftpgo/sftpgo-plugin-${PLUGIN}/releases/latest/download/sftpgo-plugin-${PLUGIN}-linux-${SUFFIX}" --output "/usr/local/bin/sftpgo-plugin-${PLUGIN}"
chmod 755 "/usr/local/bin/sftpgo-plugin-${PLUGIN}"
done


@@ -0,0 +1,28 @@
#!/usr/bin/env bash
SFTPGO_PUID=${SFTPGO_PUID:-1000}
SFTPGO_PGID=${SFTPGO_PGID:-1000}
if [ "$1" = 'sftpgo' ]; then
if [ "$(id -u)" = '0' ]; then
for DIR in "/etc/sftpgo" "/var/lib/sftpgo" "/srv/sftpgo"
do
DIR_UID=$(stat -c %u ${DIR})
DIR_GID=$(stat -c %g ${DIR})
if [ ${DIR_UID} != ${SFTPGO_PUID} ] || [ ${DIR_GID} != ${SFTPGO_PGID} ]; then
echo '{"level":"info","time":"'`date +%Y-%m-%dT%H:%M:%S.000`'","sender":"entrypoint","message":"change owner for \"'${DIR}'\" UID: '${SFTPGO_PUID}' GID: '${SFTPGO_PGID}'"}'
if [ ${DIR} = "/etc/sftpgo" ]; then
chown -R ${SFTPGO_PUID}:${SFTPGO_PGID} ${DIR}
else
chown ${SFTPGO_PUID}:${SFTPGO_PGID} ${DIR}
fi
fi
done
echo '{"level":"info","time":"'`date +%Y-%m-%dT%H:%M:%S.000`'","sender":"entrypoint","message":"run as UID: '${SFTPGO_PUID}' GID: '${SFTPGO_PGID}'"}'
exec su-exec ${SFTPGO_PUID}:${SFTPGO_PGID} "$@"
fi
exec "$@"
fi
exec "$@"

docker/scripts/entrypoint.sh (new executable file)

@@ -0,0 +1,32 @@
#!/usr/bin/env bash
SFTPGO_PUID=${SFTPGO_PUID:-1000}
SFTPGO_PGID=${SFTPGO_PGID:-1000}
if [ "$1" = 'sftpgo' ]; then
if [ "$(id -u)" = '0' ]; then
getent passwd ${SFTPGO_PUID} > /dev/null
HAS_PUID=$?
getent group ${SFTPGO_PGID} > /dev/null
HAS_PGID=$?
if [ ${HAS_PUID} -ne 0 ] || [ ${HAS_PGID} -ne 0 ]; then
echo '{"level":"info","time":"'`date +%Y-%m-%dT%H:%M:%S.%3N`'","sender":"entrypoint","message":"prepare to run as UID: '${SFTPGO_PUID}' GID: '${SFTPGO_PGID}'"}'
if [ ${HAS_PGID} -ne 0 ]; then
echo '{"level":"info","time":"'`date +%Y-%m-%dT%H:%M:%S.%3N`'","sender":"entrypoint","message":"set GID to: '${SFTPGO_PGID}'"}'
groupmod -g ${SFTPGO_PGID} sftpgo
fi
if [ ${HAS_PUID} -ne 0 ]; then
echo '{"level":"info","time":"'`date +%Y-%m-%dT%H:%M:%S.%3N`'","sender":"entrypoint","message":"set UID to: '${SFTPGO_PUID}'"}'
usermod -u ${SFTPGO_PUID} sftpgo
fi
chown -R ${SFTPGO_PUID}:${SFTPGO_PGID} /etc/sftpgo
chown ${SFTPGO_PUID}:${SFTPGO_PGID} /var/lib/sftpgo /srv/sftpgo
fi
echo '{"level":"info","time":"'`date +%Y-%m-%dT%H:%M:%S.%3N`'","sender":"entrypoint","message":"run as UID: '${SFTPGO_PUID}' GID: '${SFTPGO_PGID}'"}'
exec gosu ${SFTPGO_PUID}:${SFTPGO_PGID} "$@"
fi
exec "$@"
fi
exec "$@"


@@ -0,0 +1,50 @@
FROM golang:alpine as builder
RUN apk add --no-cache git gcc g++ ca-certificates \
&& go get -v -d github.com/drakkan/sftpgo
WORKDIR /go/src/github.com/drakkan/sftpgo
ARG TAG
ARG FEATURES
# Use --build-arg TAG=LATEST for latest tag. Use e.g. --build-arg TAG=v1.0.0 for a specific tag/commit. Otherwise HEAD (master) is built.
RUN git checkout $(if [ "${TAG}" = LATEST ]; then echo `git rev-list --tags --max-count=1`; elif [ -n "${TAG}" ]; then echo "${TAG}"; else echo HEAD; fi)
RUN go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -ldflags "-s -w -X github.com/drakkan/sftpgo/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/version.date=`date -u +%FT%TZ`" -v -o /go/bin/sftpgo
FROM alpine:latest
RUN apk add --no-cache ca-certificates su-exec \
&& mkdir -p /data /etc/sftpgo /srv/sftpgo/config /srv/sftpgo/web /srv/sftpgo/backups
# git and rsync are optional, uncomment the next line to add support for them if needed.
#RUN apk add --no-cache git rsync
COPY --from=builder /go/bin/sftpgo /bin/
COPY --from=builder /go/src/github.com/drakkan/sftpgo/sftpgo.json /etc/sftpgo/sftpgo.json
COPY --from=builder /go/src/github.com/drakkan/sftpgo/templates /srv/sftpgo/web/templates
COPY --from=builder /go/src/github.com/drakkan/sftpgo/static /srv/sftpgo/web/static
COPY docker-entrypoint.sh /bin/entrypoint.sh
RUN chmod +x /bin/entrypoint.sh
VOLUME [ "/data", "/srv/sftpgo/config", "/srv/sftpgo/backups" ]
EXPOSE 2022 8080
# uncomment the following settings to enable FTP support
#ENV SFTPGO_FTPD__BIND_PORT=2121
#ENV SFTPGO_FTPD__FORCE_PASSIVE_IP=<your FTP visible IP here>
#EXPOSE 2121
# we need to expose the passive ports range too
#EXPOSE 50000-50100
# it is a good idea to provide certificates to enable FTPS too
#ENV SFTPGO_FTPD__CERTIFICATE_FILE=/srv/sftpgo/config/mycert.crt
#ENV SFTPGO_FTPD__CERTIFICATE_KEY_FILE=/srv/sftpgo/config/mycert.key
# uncomment the following setting to enable WebDAV support
#ENV SFTPGO_WEBDAVD__BIND_PORT=8090
# it is a good idea to provide certificates to enable WebDAV over HTTPS
#ENV SFTPGO_WEBDAVD__CERTIFICATE_FILE=${CONFIG_DIR}/mycert.crt
#ENV SFTPGO_WEBDAVD__CERTIFICATE_KEY_FILE=${CONFIG_DIR}/mycert.key
ENTRYPOINT ["/bin/entrypoint.sh"]
CMD ["serve"]


@@ -0,0 +1,61 @@
# SFTPGo with Docker and Alpine
:warning: The recommended way to run SFTPGo on Docker is to use the official [images](https://hub.docker.com/r/drakkan/sftpgo). The documentation here is now obsolete.
This Dockerfile builds an image that can host multiple instances of SFTPGo, each started as a different user.
## Example
> 1003 is a custom UID:GID for this instance of SFTPGo
```bash
# Prereq on docker host
sudo groupadd -g 1003 sftpgrp && \
sudo useradd -u 1003 -g 1003 sftpuser -d /home/sftpuser/ && \
sudo -u sftpuser mkdir /home/sftpuser/{conf,data} && \
curl https://raw.githubusercontent.com/drakkan/sftpgo/master/sftpgo.json -o /home/sftpuser/conf/sftpgo.json
# Edit sftpgo.json as needed
# Get and build SFTPGo image.
# Add --build-arg TAG=LATEST to build the latest tag or e.g. TAG=v1.0.0 for a specific tag/commit.
# Add --build-arg FEATURES=<build features comma separated> to specify the features to build.
git clone https://github.com/drakkan/sftpgo.git && \
cd sftpgo && \
sudo docker build -t sftpgo docker/sftpgo/alpine/
# Initialize the configured provider. For PostgreSQL and MySQL providers you need to create the configured database and the "initprovider" command will create the required tables.
sudo docker run --name sftpgo \
-e PUID=1003 \
-e GUID=1003 \
-v /home/sftpuser/conf/:/srv/sftpgo/config \
sftpgo initprovider -c /srv/sftpgo/config
# Start the image
sudo docker rm sftpgo && sudo docker run --name sftpgo \
-e SFTPGO_LOG_FILE_PATH= \
-e SFTPGO_CONFIG_DIR=/srv/sftpgo/config \
-e SFTPGO_HTTPD__TEMPLATES_PATH=/srv/sftpgo/web/templates \
-e SFTPGO_HTTPD__STATIC_FILES_PATH=/srv/sftpgo/web/static \
-e SFTPGO_HTTPD__BACKUPS_PATH=/srv/sftpgo/backups \
-p 8080:8080 \
-p 2022:2022 \
-e PUID=1003 \
-e GUID=1003 \
-v /home/sftpuser/conf/:/srv/sftpgo/config \
-v /home/sftpuser/data:/data \
-v /home/sftpuser/backups:/srv/sftpgo/backups \
sftpgo
```
If you want to enable FTP/S you also need to publish the FTP port and the FTP passive port range defined in your `Dockerfile`, for example by adding the options `-p 2121:2121 -p 50000-50100:50000-50100` to the `docker run` command, as shown in the sketch below. The same goes for WebDAV: you need to publish the configured port.
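For example, a minimal sketch of the run command with FTP published, assuming you uncommented the FTP settings in the `Dockerfile` and kept the passive range 50000-50100 (values mirror the example above):
```bash
sudo docker run --name sftpgo \
  -e SFTPGO_LOG_FILE_PATH= \
  -e SFTPGO_CONFIG_DIR=/srv/sftpgo/config \
  -e SFTPGO_HTTPD__TEMPLATES_PATH=/srv/sftpgo/web/templates \
  -e SFTPGO_HTTPD__STATIC_FILES_PATH=/srv/sftpgo/web/static \
  -e SFTPGO_HTTPD__BACKUPS_PATH=/srv/sftpgo/backups \
  -e PUID=1003 \
  -e GUID=1003 \
  -p 8080:8080 \
  -p 2022:2022 \
  -p 2121:2121 \
  -p 50000-50100:50000-50100 \
  -v /home/sftpuser/conf/:/srv/sftpgo/config \
  -v /home/sftpuser/data:/data \
  -v /home/sftpuser/backups:/srv/sftpgo/backups \
  sftpgo
```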
The `entrypoint.sh` script fixes the permissions of the mounted directories and starts the process as the configured user.
Multiple containers can be run from this image with different parameters.
## Custom systemd script
An example systemd unit is available [here](sftpgo.service); it uses `Environment` directives to set `PUID` and `GUID`.
The directory set as `WorkingDirectory` must exist and contain a variable file for the SFTPGo instance, named like `sftpgo-${PUID}.env`, as illustrated below.
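As an illustration, a hypothetical `/etc/sftpgo/sftpgo-1003.env` variable file could look like this (the values are examples only; any `SFTPGO_*` setting can be overridden this way):
```bash
# example variable file for the instance started with PUID=1003 (hypothetical values)
SFTPGO_SFTPD__BIND_PORT=2022
SFTPGO_HTTPD__BIND_PORT=8080
SFTPGO_LOG_LEVEL=debug
```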


@@ -0,0 +1,7 @@
#!/bin/sh
set -eu
chown -R "${PUID}:${GUID}" /data /etc/sftpgo /srv/sftpgo/config /srv/sftpgo/backups \
&& exec su-exec "${PUID}:${GUID}" \
/bin/sftpgo "$@"


@@ -0,0 +1,35 @@
[Unit]
Description=SFTPGo server
After=docker.service
[Service]
User=root
Group=root
WorkingDirectory=/etc/sftpgo
Environment=PUID=1003
Environment=GUID=1003
EnvironmentFile=-/etc/sysconfig/sftpgo.env
ExecStartPre=-docker kill sftpgo
ExecStartPre=-docker rm sftpgo
ExecStart=docker run --name sftpgo \
--env-file sftpgo-${PUID}.env \
-e PUID=${PUID} \
-e GUID=${GUID} \
-e SFTPGO_LOG_FILE_PATH= \
-e SFTPGO_CONFIG_DIR=/srv/sftpgo/config \
-e SFTPGO_HTTPD__TEMPLATES_PATH=/srv/sftpgo/web/templates \
-e SFTPGO_HTTPD__STATIC_FILES_PATH=/srv/sftpgo/web/static \
-e SFTPGO_HTTPD__BACKUPS_PATH=/srv/sftpgo/backups \
-p 8080:8080 \
-p 2022:2022 \
-v /home/sftpuser/conf/:/srv/sftpgo/config \
-v /home/sftpuser/data:/data \
-v /home/sftpuser/backups:/srv/sftpgo/backups \
sftpgo
ExecStop=docker stop sftpgo
SyslogIdentifier=sftpgo
Restart=always
RestartSec=10s
[Install]
WantedBy=multi-user.target


@@ -0,0 +1,93 @@
# we use a multi-stage build to have separate build and run environments
FROM golang:latest as buildenv
LABEL maintainer="nicola.murino@gmail.com"
RUN go get -v -d github.com/drakkan/sftpgo
WORKDIR /go/src/github.com/drakkan/sftpgo
ARG TAG
ARG FEATURES
# Use --build-arg TAG=LATEST for latest tag. Use e.g. --build-arg TAG=v1.0.0 for a specific tag/commit. Otherwise HEAD (master) is built.
RUN git checkout $(if [ "${TAG}" = LATEST ]; then echo `git rev-list --tags --max-count=1`; elif [ -n "${TAG}" ]; then echo "${TAG}"; else echo HEAD; fi)
RUN go build $(if [ -n "${FEATURES}" ]; then echo "-tags ${FEATURES}"; fi) -ldflags "-s -w -X github.com/drakkan/sftpgo/version.commit=`git describe --always --dirty` -X github.com/drakkan/sftpgo/version.date=`date -u +%FT%TZ`" -v -o sftpgo
# now define the run environment
FROM debian:latest
# ca-certificates is needed for Cloud Storage Support and for HTTPS/FTPS.
RUN apt-get update && apt-get install -y ca-certificates && apt-get clean
# git and rsync are optional, uncomment the next line to add support for them if needed.
#RUN apt-get update && apt-get install -y git rsync && apt-get clean
ARG BASE_DIR=/app
ARG DATA_REL_DIR=data
ARG CONFIG_REL_DIR=config
ARG BACKUP_REL_DIR=backups
ARG USERNAME=sftpgo
ARG GROUPNAME=sftpgo
ARG UID=515
ARG GID=515
ARG WEB_REL_PATH=web
# HOME_DIR for sftpgo itself
ENV HOME_DIR=${BASE_DIR}/${USERNAME}
# DATA_DIR, this is a volume that you can use to hold users' home dirs
ENV DATA_DIR=${BASE_DIR}/${DATA_REL_DIR}
# CONFIG_DIR, this is a volume to persist the daemon private keys, configuration file etc.
ENV CONFIG_DIR=${BASE_DIR}/${CONFIG_REL_DIR}
# BACKUPS_DIR, this is a volume to store backups done using "dumpdata" REST API
ENV BACKUPS_DIR=${BASE_DIR}/${BACKUP_REL_DIR}
ENV WEB_DIR=${BASE_DIR}/${WEB_REL_PATH}
RUN mkdir -p ${DATA_DIR} ${CONFIG_DIR} ${WEB_DIR} ${BACKUPS_DIR}
RUN groupadd --system -g ${GID} ${GROUPNAME}
RUN useradd --system --create-home --no-log-init --home-dir ${HOME_DIR} --comment "SFTPGo user" --shell /usr/sbin/nologin --gid ${GID} --uid ${UID} ${USERNAME}
WORKDIR ${HOME_DIR}
RUN mkdir -p bin .config/sftpgo
ENV PATH ${HOME_DIR}/bin:$PATH
COPY --from=buildenv /go/src/github.com/drakkan/sftpgo/sftpgo bin/sftpgo
# default config file to use if no config file is found inside the CONFIG_DIR volume.
# You can override each configuration option via env vars too
COPY --from=buildenv /go/src/github.com/drakkan/sftpgo/sftpgo.json .config/sftpgo/
COPY --from=buildenv /go/src/github.com/drakkan/sftpgo/templates ${WEB_DIR}/templates
COPY --from=buildenv /go/src/github.com/drakkan/sftpgo/static ${WEB_DIR}/static
RUN chown -R ${UID}:${GID} ${DATA_DIR} ${BACKUPS_DIR}
# run as non root user
USER ${USERNAME}
EXPOSE 2022 8080
# the defined volumes must have write access for the UID and GID defined above
VOLUME [ "$DATA_DIR", "$CONFIG_DIR", "$BACKUPS_DIR" ]
# override some default configuration options using env vars
ENV SFTPGO_CONFIG_DIR=${CONFIG_DIR}
# setting SFTPGO_LOG_FILE_PATH to an empty string will log to stdout
ENV SFTPGO_LOG_FILE_PATH=""
ENV SFTPGO_HTTPD__BIND_ADDRESS=""
ENV SFTPGO_HTTPD__TEMPLATES_PATH=${WEB_DIR}/templates
ENV SFTPGO_HTTPD__STATIC_FILES_PATH=${WEB_DIR}/static
ENV SFTPGO_DATA_PROVIDER__USERS_BASE_DIR=${DATA_DIR}
ENV SFTPGO_HTTPD__BACKUPS_PATH=${BACKUPS_DIR}
# uncomment the following settings to enable FTP support
#ENV SFTPGO_FTPD__BIND_PORT=2121
#ENV SFTPGO_FTPD__FORCE_PASSIVE_IP=<your FTP visible IP here>
#EXPOSE 2121
# we need to expose the passive ports range too
#EXPOSE 50000-50100
# it is a good idea to provide certificates to enable FTPS too
#ENV SFTPGO_FTPD__CERTIFICATE_FILE=${CONFIG_DIR}/mycert.crt
#ENV SFTPGO_FTPD__CERTIFICATE_KEY_FILE=${CONFIG_DIR}/mycert.key
# uncomment the following setting to enable WebDAV support
#ENV SFTPGO_WEBDAVD__BIND_PORT=8090
# it is a good idea to provide certificates to enable WebDAV over HTTPS
#ENV SFTPGO_WEBDAVD__CERTIFICATE_FILE=${CONFIG_DIR}/mycert.crt
#ENV SFTPGO_WEBDAVD__CERTIFICATE_KEY_FILE=${CONFIG_DIR}/mycert.key
ENTRYPOINT ["sftpgo"]
CMD ["serve"]


@@ -0,0 +1,59 @@
# Dockerfile based on Debian stable
:warning: The recommended way to run SFTPGo on Docker is to use the official [images](https://hub.docker.com/r/drakkan/sftpgo). The documentation here is now obsolete.
Please read the comments inside the `Dockerfile` to learn how to customize things for your setup.
You can build the container image using `docker build`, for example:
```bash
docker build -t="drakkan/sftpgo" .
```
This will build master of github.com/drakkan/sftpgo.
To build the latest tag you can add `--build-arg TAG=LATEST` and to build a specific tag/commit you can use for example `TAG=v1.0.0`, like this:
```bash
docker build -t="drakkan/sftpgo" --build-arg TAG=v1.0.0 .
```
To specify the features to build you can add `--build-arg FEATURES=<build features comma separated>`. For example you can disable SQLite and S3 support like this:
```bash
docker build -t="drakkan/sftpgo" --build-arg FEATURES=nosqlite,nos3 .
```
Please take a look at the [build from source](./../../../docs/build-from-source.md) documentation for the complete list of the features that can be disabled.
Now create the required folders on the host system, for example:
```bash
sudo mkdir -p /srv/sftpgo/data /srv/sftpgo/config /srv/sftpgo/backups
```
and give the UID/GID defined inside the `Dockerfile` write access to them. You can create a new user on the host system with a matching UID/GID pair, or simply do something like this:
```bash
sudo chown -R <UID>:<GID> /srv/sftpgo/data /srv/sftpgo/config /srv/sftpgo/backups
```
Download the default configuration file and edit it as you need:
```bash
sudo curl https://raw.githubusercontent.com/drakkan/sftpgo/master/sftpgo.json -o /srv/sftpgo/config/sftpgo.json
```
Initialize the configured provider. For the PostgreSQL and MySQL providers you need to create the configured database first; the `initprovider` command will then create the required tables:
```bash
docker run --name sftpgo --mount type=bind,source=/srv/sftpgo/config,target=/app/config drakkan/sftpgo initprovider -c /app/config
```
and finally you can run the image using something like this:
```bash
docker rm sftpgo && docker run --name sftpgo -p 8080:8080 -p 2022:2022 --mount type=bind,source=/srv/sftpgo/data,target=/app/data --mount type=bind,source=/srv/sftpgo/config,target=/app/config --mount type=bind,source=/srv/sftpgo/backups,target=/app/backups drakkan/sftpgo
```
If you want to enable FTP/S you also need to publish the FTP port and the FTP passive port range defined in your `Dockerfile`, for example by adding the options `-p 2121:2121 -p 50000-50100:50000-50100` to the `docker run` command. The same goes for WebDAV: you need to publish the configured port, as in the sketch below.
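For instance, a sketch of the run command with WebDAV enabled, assuming you uncommented `SFTPGO_WEBDAVD__BIND_PORT=8090` in the `Dockerfile`:
```bash
docker rm sftpgo && docker run --name sftpgo \
  -p 8080:8080 -p 2022:2022 -p 8090:8090 \
  --mount type=bind,source=/srv/sftpgo/data,target=/app/data \
  --mount type=bind,source=/srv/sftpgo/config,target=/app/config \
  --mount type=bind,source=/srv/sftpgo/backups,target=/app/backups \
  drakkan/sftpgo
```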

Some files were not shown because too many files have changed in this diff.