webdav: performance improvements and bug fixes

we need my custom golang/x/net/webdav fork for now

https://github.com/drakkan/net/tree/sftpgo
This commit is contained in:
Nicola Murino
2020-11-04 19:11:40 +01:00
parent 442efa0607
commit 0a14297b48
22 changed files with 448 additions and 202 deletions

View File

@@ -18,6 +18,7 @@ import (
"github.com/eikenb/pipeat"
"github.com/stretchr/testify/assert"
"golang.org/x/net/webdav"
"github.com/drakkan/sftpgo/common"
"github.com/drakkan/sftpgo/dataprovider"
@@ -89,9 +90,9 @@ func (fs MockOsFs) Walk(root string, walkFn filepath.WalkFunc) error {
return fs.err
}
// GetMimeType implements vfs.MimeTyper
// GetMimeType returns the content type
func (fs MockOsFs) GetMimeType(name string) (string, error) {
return "application/octet-stream", nil
return "application/custom-mime", nil
}
func newMockOsFs(err error, atomicUpload bool, connectionID, rootDir string) vfs.Fs {
@@ -319,13 +320,11 @@ func TestFileAccessErrors(t *testing.T) {
if assert.Error(t, err) {
assert.EqualError(t, err, os.ErrNotExist.Error())
}
info := vfs.NewFileInfo(missingPath, true, 0, time.Now(), false)
_, err = connection.getFile(fsMissingPath, missingPath, info)
_, err = connection.getFile(fsMissingPath, missingPath)
if assert.Error(t, err) {
assert.EqualError(t, err, os.ErrNotExist.Error())
}
info = vfs.NewFileInfo(missingPath, false, 123, time.Now(), false)
_, err = connection.getFile(fsMissingPath, missingPath, info)
_, err = connection.getFile(fsMissingPath, missingPath)
if assert.Error(t, err) {
assert.EqualError(t, err, os.ErrNotExist.Error())
}
@@ -434,20 +433,34 @@ func TestContentType(t *testing.T) {
fs = newMockOsFs(nil, false, fs.ConnectionID(), user.GetHomeDir())
err := ioutil.WriteFile(testFilePath, []byte(""), os.ModePerm)
assert.NoError(t, err)
fi, err := os.Stat(testFilePath)
assert.NoError(t, err)
davFile := newWebDavFile(baseTransfer, nil, nil, fi)
davFile := newWebDavFile(baseTransfer, nil, nil)
davFile.Fs = fs
fi, err = davFile.Stat()
fi, err := davFile.Stat()
if assert.NoError(t, err) {
ctype, err := fi.(webDavFileInfo).ContentType(ctx)
ctype, err := fi.(*webDavFileInfo).ContentType(ctx)
assert.NoError(t, err)
assert.Equal(t, "application/octet-stream", ctype)
assert.Equal(t, "application/custom-mime", ctype)
}
_, err = davFile.Readdir(-1)
assert.Error(t, err)
err = davFile.Close()
assert.NoError(t, err)
davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.Fs = vfs.NewOsFs("id", user.HomeDir, nil)
fi, err = davFile.Stat()
if assert.NoError(t, err) {
ctype, err := fi.(*webDavFileInfo).ContentType(ctx)
assert.NoError(t, err)
assert.Equal(t, "text/plain; charset=utf-8", ctype)
}
err = davFile.Close()
assert.NoError(t, err)
fi.(*webDavFileInfo).fsPath = "missing"
_, err = fi.(*webDavFileInfo).ContentType(ctx)
assert.EqualError(t, err, webdav.ErrNotImplemented.Error())
err = os.Remove(testFilePath)
assert.NoError(t, err)
}
@@ -465,17 +478,16 @@ func TestTransferReadWriteErrors(t *testing.T) {
testFilePath := filepath.Join(user.HomeDir, testFile)
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferUpload, 0, 0, 0, false, fs)
davFile := newWebDavFile(baseTransfer, nil, nil, nil)
assert.False(t, davFile.isDir())
davFile := newWebDavFile(baseTransfer, nil, nil)
p := make([]byte, 1)
_, err := davFile.Read(p)
assert.EqualError(t, err, common.ErrOpUnsupported.Error())
r, w, err := pipeat.Pipe()
assert.NoError(t, err)
davFile = newWebDavFile(baseTransfer, nil, r, nil)
davFile = newWebDavFile(baseTransfer, nil, r)
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
davFile = newWebDavFile(baseTransfer, vfs.NewPipeWriter(w), nil, nil)
davFile = newWebDavFile(baseTransfer, vfs.NewPipeWriter(w), nil)
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
err = r.Close()
assert.NoError(t, err)
@@ -484,7 +496,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil, nil)
davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Read(p)
assert.True(t, os.IsNotExist(err))
_, err = davFile.Stat()
@@ -499,7 +511,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
err = f.Close()
assert.NoError(t, err)
}
davFile = newWebDavFile(baseTransfer, nil, nil, nil)
davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.reader = f
err = davFile.Close()
assert.EqualError(t, err, common.ErrGenericFailure.Error())
@@ -514,7 +526,7 @@ func TestTransferReadWriteErrors(t *testing.T) {
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil, nil)
davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.writer = f
err = davFile.Close()
assert.EqualError(t, err, common.ErrGenericFailure.Error())
@@ -534,9 +546,10 @@ func TestTransferSeek(t *testing.T) {
BaseConnection: common.NewBaseConnection(fs.ConnectionID(), common.ProtocolWebDAV, user, fs),
}
testFilePath := filepath.Join(user.HomeDir, testFile)
testFileContents := []byte("content")
baseTransfer := common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferUpload, 0, 0, 0, false, fs)
davFile := newWebDavFile(baseTransfer, nil, nil, nil)
davFile := newWebDavFile(baseTransfer, nil, nil)
_, err := davFile.Seek(0, io.SeekStart)
assert.EqualError(t, err, common.ErrOpUnsupported.Error())
err = davFile.Close()
@@ -544,12 +557,12 @@ func TestTransferSeek(t *testing.T) {
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil, nil)
davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Seek(0, io.SeekCurrent)
assert.True(t, os.IsNotExist(err))
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
err = ioutil.WriteFile(testFilePath, []byte("content"), os.ModePerm)
err = ioutil.WriteFile(testFilePath, testFileContents, os.ModePerm)
assert.NoError(t, err)
f, err := os.Open(testFilePath)
if assert.NoError(t, err) {
@@ -558,44 +571,55 @@ func TestTransferSeek(t *testing.T) {
}
baseTransfer = common.NewBaseTransfer(f, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil, nil)
davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Seek(0, io.SeekStart)
assert.Error(t, err)
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil, nil)
davFile.reader = f
davFile = newWebDavFile(baseTransfer, nil, nil)
res, err := davFile.Seek(0, io.SeekStart)
assert.NoError(t, err)
assert.Equal(t, int64(0), res)
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
info, err := os.Stat(testFilePath)
assert.NoError(t, err)
davFile = newWebDavFile(baseTransfer, nil, nil, info)
davFile.reader = f
davFile = newWebDavFile(baseTransfer, nil, nil)
res, err = davFile.Seek(0, io.SeekEnd)
assert.NoError(t, err)
assert.Equal(t, int64(7), res)
assert.Equal(t, int64(len(testFileContents)), res)
err = davFile.updateStatInfo()
assert.Nil(t, err)
davFile = newWebDavFile(baseTransfer, nil, nil, info)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFile,
common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil)
_, err = davFile.Seek(0, io.SeekEnd)
assert.True(t, os.IsNotExist(err))
davFile.Connection.RemoveTransfer(davFile.BaseTransfer)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath, testFile,
common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.reader = f
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir())
res, err = davFile.Seek(2, io.SeekStart)
assert.NoError(t, err)
assert.Equal(t, int64(2), res)
davFile = newWebDavFile(baseTransfer, nil, nil, info)
davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir())
res, err = davFile.Seek(2, io.SeekEnd)
assert.NoError(t, err)
assert.Equal(t, int64(5), res)
davFile = newWebDavFile(baseTransfer, nil, nil, nil)
baseTransfer = common.NewBaseTransfer(nil, connection.BaseConnection, nil, testFilePath+"1", testFile,
common.TransferDownload, 0, 0, 0, false, fs)
davFile = newWebDavFile(baseTransfer, nil, nil)
davFile.Fs = newMockOsFs(nil, true, fs.ConnectionID(), user.GetHomeDir())
res, err = davFile.Seek(2, io.SeekEnd)
assert.EqualError(t, err, "unable to get file size, seek from end not possible")
assert.True(t, os.IsNotExist(err))
assert.Equal(t, int64(0), res)
assert.Len(t, common.Connections.GetStats(), 0)
@@ -622,9 +646,11 @@ func TestBasicUsersCache(t *testing.T) {
c := &Configuration{
BindPort: 9000,
Cache: Cache{
Enabled: true,
MaxSize: 50,
ExpirationTime: 1,
Users: UsersCacheConfig{
Enabled: true,
MaxSize: 50,
ExpirationTime: 1,
},
},
}
server, err := newServer(c, configDir)
@@ -633,48 +659,48 @@ func TestBasicUsersCache(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user.Username), nil)
assert.NoError(t, err)
_, _, err = server.authenticate(req)
_, _, _, err = server.authenticate(req) //nolint:dogsled
assert.Error(t, err)
now := time.Now()
req.SetBasicAuth(username, password)
_, isCached, err := server.authenticate(req)
_, isCached, _, err := server.authenticate(req)
assert.NoError(t, err)
assert.False(t, isCached)
// now the user should be cached
var cachedUser dataprovider.CachedUser
var cachedUser *dataprovider.CachedUser
result, ok := dataprovider.GetCachedWebDAVUser(username)
if assert.True(t, ok) {
cachedUser = result.(dataprovider.CachedUser)
cachedUser = result.(*dataprovider.CachedUser)
assert.False(t, cachedUser.IsExpired())
assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.ExpirationTime)*time.Minute)))
assert.True(t, cachedUser.Expiration.After(now.Add(time.Duration(c.Cache.Users.ExpirationTime)*time.Minute)))
// authenticate must return the cached user now
authUser, isCached, err := server.authenticate(req)
authUser, isCached, _, err := server.authenticate(req)
assert.NoError(t, err)
assert.True(t, isCached)
assert.Equal(t, cachedUser.User, authUser)
}
// a wrong password must fail
req.SetBasicAuth(username, "wrong")
_, _, err = server.authenticate(req)
_, _, _, err = server.authenticate(req) //nolint:dogsled
assert.EqualError(t, err, dataprovider.ErrInvalidCredentials.Error())
req.SetBasicAuth(username, password)
// force cached user expiration
cachedUser.Expiration = now
dataprovider.CacheWebDAVUser(cachedUser, c.Cache.MaxSize)
dataprovider.CacheWebDAVUser(cachedUser, c.Cache.Users.MaxSize)
result, ok = dataprovider.GetCachedWebDAVUser(username)
if assert.True(t, ok) {
cachedUser = result.(dataprovider.CachedUser)
cachedUser = result.(*dataprovider.CachedUser)
assert.True(t, cachedUser.IsExpired())
}
// now authenticate should get the user from the data provider and update the cache
_, isCached, err = server.authenticate(req)
_, isCached, _, err = server.authenticate(req)
assert.NoError(t, err)
assert.False(t, isCached)
result, ok = dataprovider.GetCachedWebDAVUser(username)
if assert.True(t, ok) {
cachedUser = result.(dataprovider.CachedUser)
cachedUser = result.(*dataprovider.CachedUser)
assert.False(t, cachedUser.IsExpired())
}
// cache is invalidated after a user modification
@@ -683,7 +709,7 @@ func TestBasicUsersCache(t *testing.T) {
_, ok = dataprovider.GetCachedWebDAVUser(username)
assert.False(t, ok)
_, isCached, err = server.authenticate(req)
_, isCached, _, err = server.authenticate(req)
assert.NoError(t, err)
assert.False(t, isCached)
_, ok = dataprovider.GetCachedWebDAVUser(username)
@@ -725,9 +751,11 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
c := &Configuration{
BindPort: 9000,
Cache: Cache{
Enabled: true,
MaxSize: 3,
ExpirationTime: 1,
Users: UsersCacheConfig{
Enabled: true,
MaxSize: 3,
ExpirationTime: 1,
},
},
}
server, err := newServer(c, configDir)
@@ -736,21 +764,21 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil)
assert.NoError(t, err)
req.SetBasicAuth(user1.Username, password+"1")
_, isCached, err := server.authenticate(req)
_, isCached, _, err := server.authenticate(req)
assert.NoError(t, err)
assert.False(t, isCached)
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user2.Username), nil)
assert.NoError(t, err)
req.SetBasicAuth(user2.Username, password+"2")
_, isCached, err = server.authenticate(req)
_, isCached, _, err = server.authenticate(req)
assert.NoError(t, err)
assert.False(t, isCached)
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user3.Username), nil)
assert.NoError(t, err)
req.SetBasicAuth(user3.Username, password+"3")
_, isCached, err = server.authenticate(req)
_, isCached, _, err = server.authenticate(req)
assert.NoError(t, err)
assert.False(t, isCached)
@@ -765,7 +793,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user4.Username), nil)
assert.NoError(t, err)
req.SetBasicAuth(user4.Username, password+"4")
_, isCached, err = server.authenticate(req)
_, isCached, _, err = server.authenticate(req)
assert.NoError(t, err)
assert.False(t, isCached)
// user1, the first cached, should be removed now
@@ -782,7 +810,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil)
assert.NoError(t, err)
req.SetBasicAuth(user1.Username, password+"1")
_, isCached, err = server.authenticate(req)
_, isCached, _, err = server.authenticate(req)
assert.NoError(t, err)
assert.False(t, isCached)
_, ok = dataprovider.GetCachedWebDAVUser(user2.Username)
@@ -798,7 +826,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user2.Username), nil)
assert.NoError(t, err)
req.SetBasicAuth(user2.Username, password+"2")
_, isCached, err = server.authenticate(req)
_, isCached, _, err = server.authenticate(req)
assert.NoError(t, err)
assert.False(t, isCached)
_, ok = dataprovider.GetCachedWebDAVUser(user3.Username)
@@ -814,7 +842,7 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user3.Username), nil)
assert.NoError(t, err)
req.SetBasicAuth(user3.Username, password+"3")
_, isCached, err = server.authenticate(req)
_, isCached, _, err = server.authenticate(req)
assert.NoError(t, err)
assert.False(t, isCached)
_, ok = dataprovider.GetCachedWebDAVUser(user4.Username)
@@ -835,14 +863,14 @@ func TestUsersCacheSizeAndExpiration(t *testing.T) {
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user4.Username), nil)
assert.NoError(t, err)
req.SetBasicAuth(user4.Username, password+"4")
_, isCached, err = server.authenticate(req)
_, isCached, _, err = server.authenticate(req)
assert.NoError(t, err)
assert.False(t, isCached)
req, err = http.NewRequest(http.MethodGet, fmt.Sprintf("/%v", user1.Username), nil)
assert.NoError(t, err)
req.SetBasicAuth(user1.Username, password+"1")
_, isCached, err = server.authenticate(req)
_, isCached, _, err = server.authenticate(req)
assert.NoError(t, err)
assert.False(t, isCached)
_, ok = dataprovider.GetCachedWebDAVUser(user2.Username)
@@ -874,3 +902,20 @@ func TestRecoverer(t *testing.T) {
server.ServeHTTP(rr, nil)
assert.Equal(t, http.StatusInternalServerError, rr.Code)
}
func TestMimeCache(t *testing.T) {
cache := mimeCache{
maxSize: 0,
mimeTypes: make(map[string]string),
}
cache.addMimeToCache(".zip", "application/zip")
mtype := cache.getMimeFromCache(".zip")
assert.Equal(t, "", mtype)
cache.maxSize = 1
cache.addMimeToCache(".zip", "application/zip")
mtype = cache.getMimeFromCache(".zip")
assert.Equal(t, "application/zip", mtype)
cache.addMimeToCache(".jpg", "image/jpeg")
mtype = cache.getMimeFromCache(".jpg")
assert.Equal(t, "", mtype)
}