diff --git a/modules/storage/helper.go b/modules/storage/helper.go index f8dff9e64d..95f1c7b9a8 100644 --- a/modules/storage/helper.go +++ b/modules/storage/helper.go @@ -10,30 +10,30 @@ import ( "os" ) -var uninitializedStorage = discardStorage("uninitialized storage") +var UninitializedStorage = DiscardStorage("uninitialized storage") -type discardStorage string +type DiscardStorage string -func (s discardStorage) Open(_ string) (Object, error) { +func (s DiscardStorage) Open(_ string) (Object, error) { return nil, fmt.Errorf("%s", s) } -func (s discardStorage) Save(_ string, _ io.Reader, _ int64) (int64, error) { +func (s DiscardStorage) Save(_ string, _ io.Reader, _ int64) (int64, error) { return 0, fmt.Errorf("%s", s) } -func (s discardStorage) Stat(_ string) (os.FileInfo, error) { +func (s DiscardStorage) Stat(_ string) (os.FileInfo, error) { return nil, fmt.Errorf("%s", s) } -func (s discardStorage) Delete(_ string) error { +func (s DiscardStorage) Delete(_ string) error { return fmt.Errorf("%s", s) } -func (s discardStorage) URL(_, _ string) (*url.URL, error) { +func (s DiscardStorage) URL(_, _ string) (*url.URL, error) { return nil, fmt.Errorf("%s", s) } -func (s discardStorage) IterateObjects(_ string, _ func(string, Object) error) error { +func (s DiscardStorage) IterateObjects(_ string, _ func(string, Object) error) error { return fmt.Errorf("%s", s) } diff --git a/modules/storage/helper_test.go b/modules/storage/helper_test.go index f4c2d0467f..f1f9791044 100644 --- a/modules/storage/helper_test.go +++ b/modules/storage/helper_test.go @@ -11,9 +11,9 @@ import ( ) func Test_discardStorage(t *testing.T) { - tests := []discardStorage{ - uninitializedStorage, - discardStorage("empty"), + tests := []DiscardStorage{ + UninitializedStorage, + DiscardStorage("empty"), } for _, tt := range tests { t.Run(string(tt), func(t *testing.T) { diff --git a/modules/storage/storage.go b/modules/storage/storage.go index 8f970b5dfc..b83b1c7929 100644 --- a/modules/storage/storage.go +++ b/modules/storage/storage.go @@ -109,26 +109,26 @@ func SaveFrom(objStorage ObjectStorage, p string, callback func(w io.Writer) err var ( // Attachments represents attachments storage - Attachments ObjectStorage = uninitializedStorage + Attachments ObjectStorage = UninitializedStorage // LFS represents lfs storage - LFS ObjectStorage = uninitializedStorage + LFS ObjectStorage = UninitializedStorage // Avatars represents user avatars storage - Avatars ObjectStorage = uninitializedStorage + Avatars ObjectStorage = UninitializedStorage // RepoAvatars represents repository avatars storage - RepoAvatars ObjectStorage = uninitializedStorage + RepoAvatars ObjectStorage = UninitializedStorage // RepoArchives represents repository archives storage - RepoArchives ObjectStorage = uninitializedStorage + RepoArchives ObjectStorage = UninitializedStorage // Packages represents packages storage - Packages ObjectStorage = uninitializedStorage + Packages ObjectStorage = UninitializedStorage // Actions represents actions storage - Actions ObjectStorage = uninitializedStorage + Actions ObjectStorage = UninitializedStorage // Actions Artifacts represents actions artifacts storage - ActionsArtifacts ObjectStorage = uninitializedStorage + ActionsArtifacts ObjectStorage = UninitializedStorage ) // Init init the stoarge @@ -170,7 +170,7 @@ func initAvatars() (err error) { func initAttachments() (err error) { if !setting.Attachment.Enabled { - Attachments = discardStorage("Attachment isn't enabled") + Attachments = DiscardStorage("Attachment isn't enabled") return nil } log.Info("Initialising Attachment storage with type: %s", setting.Attachment.Storage.Type) @@ -180,7 +180,7 @@ func initAttachments() (err error) { func initLFS() (err error) { if !setting.LFS.StartServer { - LFS = discardStorage("LFS isn't enabled") + LFS = DiscardStorage("LFS isn't enabled") return nil } log.Info("Initialising LFS storage with type: %s", setting.LFS.Storage.Type) @@ -202,7 +202,7 @@ func initRepoArchives() (err error) { func initPackages() (err error) { if !setting.Packages.Enabled { - Packages = discardStorage("Packages isn't enabled") + Packages = DiscardStorage("Packages isn't enabled") return nil } log.Info("Initialising Packages storage with type: %s", setting.Packages.Storage.Type) @@ -212,8 +212,8 @@ func initPackages() (err error) { func initActions() (err error) { if !setting.Actions.Enabled { - Actions = discardStorage("Actions isn't enabled") - ActionsArtifacts = discardStorage("ActionsArtifacts isn't enabled") + Actions = DiscardStorage("Actions isn't enabled") + ActionsArtifacts = DiscardStorage("ActionsArtifacts isn't enabled") return nil } log.Info("Initialising Actions storage with type: %s", setting.Actions.LogStorage.Type) diff --git a/services/user/avatar.go b/services/user/avatar.go index 2d6c3faf9a..3f87466eaa 100644 --- a/services/user/avatar.go +++ b/services/user/avatar.go @@ -5,8 +5,10 @@ package user import ( "context" + "errors" "fmt" "io" + "os" "code.gitea.io/gitea/models/db" user_model "code.gitea.io/gitea/models/user" @@ -48,16 +50,24 @@ func UploadAvatar(ctx context.Context, u *user_model.User, data []byte) error { func DeleteAvatar(ctx context.Context, u *user_model.User) error { aPath := u.CustomAvatarRelativePath() log.Trace("DeleteAvatar[%d]: %s", u.ID, aPath) - if len(u.Avatar) > 0 { - if err := storage.Avatars.Delete(aPath); err != nil { - return fmt.Errorf("Failed to remove %s: %w", aPath, err) - } - } - u.UseCustomAvatar = false - u.Avatar = "" - if _, err := db.GetEngine(ctx).ID(u.ID).Cols("avatar, use_custom_avatar").Update(u); err != nil { - return fmt.Errorf("DeleteAvatar: %w", err) - } - return nil + return db.WithTx(ctx, func(ctx context.Context) error { + hasAvatar := len(u.Avatar) > 0 + u.UseCustomAvatar = false + u.Avatar = "" + if _, err := db.GetEngine(ctx).ID(u.ID).Cols("avatar, use_custom_avatar").Update(u); err != nil { + return fmt.Errorf("DeleteAvatar: %w", err) + } + + if hasAvatar { + if err := storage.Avatars.Delete(aPath); err != nil { + if !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("failed to remove %s: %w", aPath, err) + } + log.Warn("Deleting avatar %s but it doesn't exist", aPath) + } + } + + return nil + }) } diff --git a/services/user/avatar_test.go b/services/user/avatar_test.go new file mode 100644 index 0000000000..557ddccec0 --- /dev/null +++ b/services/user/avatar_test.go @@ -0,0 +1,80 @@ +// Copyright The Forgejo Authors. +// SPDX-License-Identifier: MIT + +package user + +import ( + "bytes" + "image" + "image/png" + "os" + "testing" + + "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/models/unittest" + user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/modules/storage" + "code.gitea.io/gitea/modules/test" + + "github.com/stretchr/testify/assert" +) + +type alreadyDeletedStorage struct { + storage.DiscardStorage +} + +func (s alreadyDeletedStorage) Delete(_ string) error { + return os.ErrNotExist +} + +func TestUserDeleteAvatar(t *testing.T) { + myImage := image.NewRGBA(image.Rect(0, 0, 1, 1)) + var buff bytes.Buffer + png.Encode(&buff, myImage) + + t.Run("AtomicStorageFailure", func(t *testing.T) { + defer test.MockProtect[storage.ObjectStorage](&storage.Avatars)() + + assert.NoError(t, unittest.PrepareTestDatabase()) + user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1}) + + err := UploadAvatar(db.DefaultContext, user, buff.Bytes()) + assert.NoError(t, err) + verification := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1}) + assert.NotEqual(t, "", verification.Avatar) + + // fail to delete ... + storage.Avatars = storage.UninitializedStorage + err = DeleteAvatar(db.DefaultContext, user) + assert.Error(t, err) + + // ... the avatar is not removed from the database + verification = unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1}) + assert.True(t, verification.UseCustomAvatar) + + // already deleted ... + storage.Avatars = alreadyDeletedStorage{} + err = DeleteAvatar(db.DefaultContext, user) + assert.NoError(t, err) + + // ... the avatar is removed from the database + verification = unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1}) + assert.Equal(t, "", verification.Avatar) + }) + + t.Run("Success", func(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1}) + + err := UploadAvatar(db.DefaultContext, user, buff.Bytes()) + assert.NoError(t, err) + verification := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1}) + assert.NotEqual(t, "", verification.Avatar) + + err = DeleteAvatar(db.DefaultContext, user) + assert.NoError(t, err) + + verification = unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1}) + assert.Equal(t, "", verification.Avatar) + }) +}