From fd7d83ace60258acf7139c4c787aa8af75b7ba8c Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 20 May 2022 22:08:52 +0800 Subject: [PATCH] Move almost all functions' parameter db.Engine to context.Context (#19748) * Move almost all functions' parameter db.Engine to context.Context * remove some unnecessary wrap functions --- cmd/admin.go | 6 +- integrations/api_issue_tracked_time_test.go | 6 +- integrations/api_repo_test.go | 2 +- integrations/auth_ldap_test.go | 4 +- integrations/git_test.go | 3 +- integrations/mirror_pull_test.go | 2 +- integrations/pull_merge_test.go | 3 +- integrations/pull_update_test.go | 3 +- models/action.go | 12 +- models/action_list.go | 21 +-- models/asymkey/gpg_key.go | 8 +- models/asymkey/gpg_key_add.go | 17 +- models/asymkey/ssh_key.go | 36 ++--- models/asymkey/ssh_key_authorized_keys.go | 11 +- .../asymkey/ssh_key_authorized_principals.go | 18 +-- models/asymkey/ssh_key_deploy.go | 29 ++-- models/asymkey/ssh_key_fingerprint.go | 5 +- models/asymkey/ssh_key_principals.go | 16 +- models/auth/oauth2.go | 134 +++++----------- models/auth/oauth2_test.go | 33 ++-- models/branches.go | 10 +- models/commit_status.go | 25 ++- models/consistency.go | 2 +- models/db/index.go | 5 +- models/issue.go | 145 ++++++++--------- models/issue_assignees.go | 70 +++----- models/issue_assignees_test.go | 11 +- models/issue_comment.go | 61 +++---- models/issue_comment_list.go | 55 ++++--- models/issue_dependency.go | 9 +- models/issue_label.go | 124 +++++---------- models/issue_label_test.go | 54 +++---- models/issue_list.go | 73 ++++----- models/issue_project.go | 26 ++- models/issue_stopwatch.go | 15 +- models/issue_stopwatch_test.go | 5 +- models/issue_test.go | 16 +- models/issue_tracked_time.go | 51 +++--- models/issue_tracked_time_test.go | 15 +- models/issue_watch.go | 43 ++--- models/issue_watch_test.go | 14 +- models/issue_xref.go | 25 ++- models/issue_xref_test.go | 2 +- models/issues/content_history.go | 13 +- models/issues/content_history_test.go | 21 ++- models/issues/milestone.go | 26 +-- models/issues/milestone_test.go | 4 +- models/notification.go | 108 ++++++------- models/notification_test.go | 7 +- models/org.go | 6 +- models/org_team.go | 12 +- models/organization/org.go | 65 ++------ models/organization/org_test.go | 39 +---- models/organization/org_user_test.go | 4 +- models/organization/team.go | 31 ++-- models/organization/team_test.go | 10 +- models/perm/access/access.go | 17 +- models/perm/access/repo_permission.go | 18 +-- models/project/board.go | 41 ++--- models/project/issue.go | 8 +- models/project/project.go | 45 ++---- models/project/project_test.go | 13 +- models/pull.go | 74 +++------ models/pull_list.go | 6 +- models/pull_test.go | 10 +- models/release.go | 24 +-- models/repo.go | 57 +++---- models/repo/attachment.go | 66 ++------ models/repo/attachment_test.go | 16 +- models/repo/avatar.go | 17 +- models/repo/collaboration.go | 11 +- models/repo/fork.go | 10 +- models/repo/language_stats.go | 26 ++- models/repo/mirror.go | 44 ++--- models/repo/pushmirror.go | 17 +- models/repo/repo.go | 82 ++++------ models/repo/repo_indexer.go | 28 ++-- models/repo/repo_list.go | 3 +- models/repo/repo_test.go | 19 ++- models/repo/repo_unit.go | 5 +- models/repo/star.go | 12 +- models/repo/star_test.go | 4 +- models/repo/topic.go | 34 ++-- models/repo/update.go | 13 +- models/repo/user_repo.go | 37 ----- models/repo/watch.go | 17 +- models/repo/watch_test.go | 14 +- models/repo_collaboration.go | 8 +- models/repo_generate.go | 4 +- models/repo_list.go | 20 +-- models/repo_test.go | 6 +- models/repo_transfer.go | 18 +-- models/review.go | 150 +++++++----------- models/review_test.go | 12 +- models/statistic.go | 4 +- models/task.go | 14 +- models/user.go | 8 +- models/user/avatar.go | 9 +- models/user/email_address.go | 27 ++-- models/user/email_address_test.go | 4 +- models/user/list.go | 13 +- models/user/openid.go | 18 +-- models/user/user.go | 100 ++++-------- models/user/user_test.go | 12 +- models/webhook/webhook.go | 28 +--- models/webhook/webhook_test.go | 8 +- modules/context/repo.go | 14 +- modules/convert/convert.go | 2 +- modules/convert/issue.go | 2 +- modules/convert/issue_comment.go | 3 +- modules/convert/repository.go | 2 +- modules/doctor/authorizedkeys.go | 2 +- modules/gitgraph/graph_models.go | 2 +- modules/indexer/code/git.go | 2 +- modules/indexer/code/indexer.go | 2 +- modules/indexer/issues/indexer.go | 2 +- modules/indexer/stats/db.go | 2 +- modules/indexer/stats/indexer_test.go | 3 +- modules/ssh/ssh.go | 4 +- routers/api/v1/admin/adopt.go | 8 +- routers/api/v1/admin/user.go | 2 +- routers/api/v1/api.go | 6 +- routers/api/v1/notify/notifications.go | 2 +- routers/api/v1/notify/repo.go | 4 +- routers/api/v1/notify/user.go | 4 +- routers/api/v1/org/hook.go | 2 +- routers/api/v1/org/label.go | 8 +- routers/api/v1/repo/branch.go | 18 +-- routers/api/v1/repo/collaborators.go | 8 +- routers/api/v1/repo/hook.go | 2 +- routers/api/v1/repo/issue.go | 6 +- routers/api/v1/repo/issue_comment.go | 12 +- routers/api/v1/repo/issue_label.go | 6 +- routers/api/v1/repo/issue_reaction.go | 4 +- routers/api/v1/repo/issue_subscription.go | 6 +- routers/api/v1/repo/issue_tracked_time.go | 16 +- routers/api/v1/repo/label.go | 8 +- routers/api/v1/repo/language.go | 2 +- routers/api/v1/repo/migrate.go | 2 +- routers/api/v1/repo/mirror.go | 2 +- routers/api/v1/repo/pull.go | 22 +-- routers/api/v1/repo/pull_review.go | 18 +-- routers/api/v1/repo/release_attachment.go | 8 +- routers/api/v1/repo/repo.go | 6 +- routers/api/v1/repo/status.go | 2 +- routers/api/v1/repo/teams.go | 4 +- routers/api/v1/repo/transfer.go | 4 +- routers/api/v1/user/app.go | 4 +- routers/api/v1/user/helper.go | 2 +- routers/api/v1/user/settings.go | 2 +- routers/api/v1/user/star.go | 2 +- routers/api/v1/user/user.go | 2 +- routers/api/v1/user/watch.go | 4 +- routers/install/install.go | 2 +- routers/private/hook_post_receive.go | 4 +- routers/private/hook_pre_receive.go | 2 +- routers/private/key.go | 4 +- routers/private/mail.go | 2 +- routers/private/serv.go | 6 +- routers/web/admin/hooks.go | 4 +- routers/web/admin/repos.go | 4 +- routers/web/admin/users.go | 2 +- routers/web/admin/users_test.go | 8 +- routers/web/auth/auth.go | 2 +- routers/web/auth/linkaccount.go | 2 +- routers/web/auth/oauth.go | 39 ++--- routers/web/auth/oauth_test.go | 7 +- routers/web/auth/openid.go | 6 +- routers/web/org/org_labels.go | 4 +- routers/web/org/setting.go | 6 +- routers/web/org/teams.go | 2 +- routers/web/repo/attachment.go | 4 +- routers/web/repo/commit.go | 2 +- routers/web/repo/compare.go | 2 +- routers/web/repo/issue.go | 64 ++++---- routers/web/repo/issue_content_history.go | 4 +- routers/web/repo/issue_label.go | 10 +- routers/web/repo/issue_stopwatch.go | 2 +- routers/web/repo/projects.go | 32 ++-- routers/web/repo/pull.go | 12 +- routers/web/repo/pull_review.go | 4 +- routers/web/repo/release.go | 4 +- routers/web/repo/repo.go | 6 +- routers/web/repo/setting.go | 17 +- routers/web/repo/setting_protected_branch.go | 4 +- routers/web/repo/view.go | 6 +- routers/web/repo/webhook.go | 2 +- routers/web/user/avatar.go | 2 +- routers/web/user/home.go | 2 +- routers/web/user/notification.go | 8 +- routers/web/user/profile.go | 4 +- routers/web/user/setting/account.go | 2 +- routers/web/user/setting/adopt.go | 2 +- routers/web/user/setting/applications.go | 4 +- routers/web/user/setting/oauth2.go | 8 +- routers/web/user/setting/profile.go | 2 +- routers/web/user/setting/security/openid.go | 2 +- routers/web/webfinger.go | 4 +- services/asymkey/sign.go | 2 +- services/asymkey/ssh_key.go | 2 +- services/attachment/attachment_test.go | 3 +- services/auth/oauth2.go | 2 +- services/auth/reverseproxy.go | 2 +- .../auth/source/ldap/source_authenticate.go | 4 +- services/auth/source/ldap/source_sync.go | 3 +- services/auth/sspi_windows.go | 2 +- services/automerge/automerge.go | 2 +- services/comments/comments.go | 4 +- services/context/user.go | 2 +- services/cron/tasks_extended.go | 3 +- services/issue/assignee.go | 11 +- services/issue/assignee_test.go | 11 +- services/issue/issue.go | 4 +- services/issue/label.go | 2 +- services/mailer/mail_issue.go | 4 +- services/migrations/gitea_uploader.go | 4 +- services/migrations/gitea_uploader_test.go | 5 +- services/mirror/mirror_pull.go | 10 +- services/mirror/mirror_push.go | 3 + services/org/org.go | 2 +- services/pull/check.go | 2 +- services/pull/commit_status.go | 2 +- services/pull/pull.go | 6 +- services/pull/review.go | 10 +- services/release/release_test.go | 7 +- services/repository/adopt.go | 4 +- services/repository/avatar.go | 6 +- services/repository/files/patch.go | 2 +- services/repository/files/update.go | 2 +- services/repository/push.go | 4 +- services/user/user.go | 4 +- services/webhook/webhook.go | 12 +- 232 files changed, 1448 insertions(+), 2093 deletions(-) diff --git a/cmd/admin.go b/cmd/admin.go index fcf331751..0629dfc2d 100644 --- a/cmd/admin.go +++ b/cmd/admin.go @@ -490,7 +490,7 @@ func runChangePassword(c *cli.Context) error { return errors.New("The password you chose is on a list of stolen passwords previously exposed in public data breaches. Please try again with a different password.\nFor more details, see https://haveibeenpwned.com/Passwords") } uname := c.String("username") - user, err := user_model.GetUserByName(uname) + user, err := user_model.GetUserByName(ctx, uname) if err != nil { return err } @@ -659,7 +659,7 @@ func runDeleteUser(c *cli.Context) error { if c.IsSet("email") { user, err = user_model.GetUserByEmail(c.String("email")) } else if c.IsSet("username") { - user, err = user_model.GetUserByName(c.String("username")) + user, err = user_model.GetUserByName(ctx, c.String("username")) } else { user, err = user_model.GetUserByID(c.Int64("id")) } @@ -689,7 +689,7 @@ func runGenerateAccessToken(c *cli.Context) error { return err } - user, err := user_model.GetUserByName(c.String("username")) + user, err := user_model.GetUserByName(ctx, c.String("username")) if err != nil { return err } diff --git a/integrations/api_issue_tracked_time_test.go b/integrations/api_issue_tracked_time_test.go index b6f709101..7c69d4eb9 100644 --- a/integrations/api_issue_tracked_time_test.go +++ b/integrations/api_issue_tracked_time_test.go @@ -33,7 +33,7 @@ func TestAPIGetTrackedTimes(t *testing.T) { resp := session.MakeRequest(t, req, http.StatusOK) var apiTimes api.TrackedTimeList DecodeJSON(t, resp, &apiTimes) - expect, err := models.GetTrackedTimes(&models.FindTrackedTimesOptions{IssueID: issue2.ID}) + expect, err := models.GetTrackedTimes(db.DefaultContext, &models.FindTrackedTimesOptions{IssueID: issue2.ID}) assert.NoError(t, err) assert.Len(t, apiTimes, 3) @@ -83,7 +83,7 @@ func TestAPIDeleteTrackedTime(t *testing.T) { session.MakeRequest(t, req, http.StatusNotFound) // Reset time of user 2 on issue 2 - trackedSeconds, err := models.GetTrackedSeconds(models.FindTrackedTimesOptions{IssueID: 2, UserID: 2}) + trackedSeconds, err := models.GetTrackedSeconds(db.DefaultContext, models.FindTrackedTimesOptions{IssueID: 2, UserID: 2}) assert.NoError(t, err) assert.Equal(t, int64(3661), trackedSeconds) @@ -91,7 +91,7 @@ func TestAPIDeleteTrackedTime(t *testing.T) { session.MakeRequest(t, req, http.StatusNoContent) session.MakeRequest(t, req, http.StatusNotFound) - trackedSeconds, err = models.GetTrackedSeconds(models.FindTrackedTimesOptions{IssueID: 2, UserID: 2}) + trackedSeconds, err = models.GetTrackedSeconds(db.DefaultContext, models.FindTrackedTimesOptions{IssueID: 2, UserID: 2}) assert.NoError(t, err) assert.Equal(t, int64(0), trackedSeconds) } diff --git a/integrations/api_repo_test.go b/integrations/api_repo_test.go index e635a4f37..8f08da16a 100644 --- a/integrations/api_repo_test.go +++ b/integrations/api_repo_test.go @@ -388,7 +388,7 @@ func testAPIRepoMigrateConflict(t *testing.T, u *url.URL) { defer util.RemoveAll(dstPath) t.Run("CreateRepo", doAPICreateRepository(httpContext, false)) - user, err := user_model.GetUserByName(httpContext.Username) + user, err := user_model.GetUserByName(db.DefaultContext, httpContext.Username) assert.NoError(t, err) userID := user.ID diff --git a/integrations/auth_ldap_test.go b/integrations/auth_ldap_test.go index 0eee5ae0c..296b647e6 100644 --- a/integrations/auth_ldap_test.go +++ b/integrations/auth_ldap_test.go @@ -321,7 +321,7 @@ func TestLDAPGroupTeamSyncAddMember(t *testing.T) { addAuthSourceLDAP(t, "", "on", `{"cn=ship_crew,ou=people,dc=planetexpress,dc=com":{"org26": ["team11"]},"cn=admin_staff,ou=people,dc=planetexpress,dc=com": {"non-existent": ["non-existent"]}}`) org, err := organization.GetOrgByName("org26") assert.NoError(t, err) - team, err := organization.GetTeam(org.ID, "team11") + team, err := organization.GetTeam(db.DefaultContext, org.ID, "team11") assert.NoError(t, err) auth.SyncExternalUsers(context.Background(), true) for _, gitLDAPUser := range gitLDAPUsers { @@ -366,7 +366,7 @@ func TestLDAPGroupTeamSyncRemoveMember(t *testing.T) { addAuthSourceLDAP(t, "", "on", `{"cn=dispatch,ou=people,dc=planetexpress,dc=com": {"org26": ["team11"]}}`) org, err := organization.GetOrgByName("org26") assert.NoError(t, err) - team, err := organization.GetTeam(org.ID, "team11") + team, err := organization.GetTeam(db.DefaultContext, org.ID, "team11") assert.NoError(t, err) loginUserWithPassword(t, gitLDAPUsers[0].UserName, gitLDAPUsers[0].Password) user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ diff --git a/integrations/git_test.go b/integrations/git_test.go index 04cdf633b..63afc7913 100644 --- a/integrations/git_test.go +++ b/integrations/git_test.go @@ -18,6 +18,7 @@ import ( "time" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/perm" repo_model "code.gitea.io/gitea/models/repo" "code.gitea.io/gitea/models/unittest" @@ -438,7 +439,7 @@ func doProtectBranch(ctx APITestContext, branch, userToWhitelist, unprotectedFil }) ctx.Session.MakeRequest(t, req, http.StatusSeeOther) } else { - user, err := user_model.GetUserByName(userToWhitelist) + user, err := user_model.GetUserByName(db.DefaultContext, userToWhitelist) assert.NoError(t, err) // Change branch to protected req := NewRequestWithValues(t, "POST", fmt.Sprintf("/%s/%s/settings/branches/%s", url.PathEscape(ctx.Username), url.PathEscape(ctx.Reponame), url.PathEscape(branch)), map[string]string{ diff --git a/integrations/mirror_pull_test.go b/integrations/mirror_pull_test.go index dd66974e0..8f74d5fe1 100644 --- a/integrations/mirror_pull_test.go +++ b/integrations/mirror_pull_test.go @@ -75,7 +75,7 @@ func TestMirrorPull(t *testing.T) { IsTag: true, }, nil, "")) - _, err = repo_model.GetMirrorByRepoID(mirror.ID) + _, err = repo_model.GetMirrorByRepoID(ctx, mirror.ID) assert.NoError(t, err) ok := mirror_service.SyncPullMirror(ctx, mirror.ID) diff --git a/integrations/pull_merge_test.go b/integrations/pull_merge_test.go index cfe8b4afe..6c5b67caa 100644 --- a/integrations/pull_merge_test.go +++ b/integrations/pull_merge_test.go @@ -18,6 +18,7 @@ import ( "time" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" repo_model "code.gitea.io/gitea/models/repo" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" @@ -407,7 +408,7 @@ func TestConflictChecking(t *testing.T) { assert.NoError(t, err) issue := unittest.AssertExistsAndLoadBean(t, &models.Issue{Title: "PR with conflict!"}).(*models.Issue) - conflictingPR, err := models.GetPullRequestByIssueID(issue.ID) + conflictingPR, err := models.GetPullRequestByIssueID(db.DefaultContext, issue.ID) assert.NoError(t, err) // Ensure conflictedFiles is populated. diff --git a/integrations/pull_update_test.go b/integrations/pull_update_test.go index 20b4eaeb4..f11eacf14 100644 --- a/integrations/pull_update_test.go +++ b/integrations/pull_update_test.go @@ -11,6 +11,7 @@ import ( "time" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/git" @@ -165,7 +166,7 @@ func createOutdatedPR(t *testing.T, actor, forkOrg *user_model.User) *models.Pul assert.NoError(t, err) issue := unittest.AssertExistsAndLoadBean(t, &models.Issue{Title: "Test Pull -to-update-"}).(*models.Issue) - pr, err := models.GetPullRequestByIssueID(issue.ID) + pr, err := models.GetPullRequestByIssueID(db.DefaultContext, issue.ID) assert.NoError(t, err) return pr diff --git a/models/action.go b/models/action.go index 6b662b310..29e2ea47b 100644 --- a/models/action.go +++ b/models/action.go @@ -222,9 +222,8 @@ func (a *Action) getCommentLink(ctx context.Context) string { if a == nil { return "#" } - e := db.GetEngine(ctx) if a.Comment == nil && a.CommentID != 0 { - a.Comment, _ = getCommentByID(e, a.CommentID) + a.Comment, _ = GetCommentByID(ctx, a.CommentID) } if a.Comment != nil { return a.Comment.HTMLURL() @@ -239,7 +238,7 @@ func (a *Action) getCommentLink(ctx context.Context) string { return "#" } - issue, err := getIssueByID(e, issueID) + issue, err := getIssueByID(ctx, issueID) if err != nil { return "#" } @@ -340,8 +339,7 @@ func GetFeeds(ctx context.Context, opts GetFeedsOptions) (ActionList, error) { return nil, err } - e := db.GetEngine(ctx) - sess := e.Where(cond). + sess := db.GetEngine(ctx).Where(cond). Select("`action`.*"). // this line will avoid select other joined table's columns Join("INNER", "repository", "`repository`.id = `action`.repo_id") @@ -354,7 +352,7 @@ func GetFeeds(ctx context.Context, opts GetFeedsOptions) (ActionList, error) { return nil, fmt.Errorf("Find: %v", err) } - if err := ActionList(actions).loadAttributes(e); err != nil { + if err := ActionList(actions).loadAttributes(ctx); err != nil { return nil, fmt.Errorf("LoadAttributes: %v", err) } @@ -504,7 +502,7 @@ func notifyWatchers(ctx context.Context, actions ...*Action) error { permIssue = make([]bool, len(watchers)) permPR = make([]bool, len(watchers)) for i, watcher := range watchers { - user, err := user_model.GetUserByIDEngine(e, watcher.UserID) + user, err := user_model.GetUserByIDCtx(ctx, watcher.UserID) if err != nil { permCode[i] = false permIssue[i] = false diff --git a/models/action_list.go b/models/action_list.go index 5f7b17b9d..d585ef0fc 100644 --- a/models/action_list.go +++ b/models/action_list.go @@ -5,6 +5,7 @@ package models import ( + "context" "fmt" "code.gitea.io/gitea/models/db" @@ -26,14 +27,14 @@ func (actions ActionList) getUserIDs() []int64 { return container.KeysInt64(userIDs) } -func (actions ActionList) loadUsers(e db.Engine) (map[int64]*user_model.User, error) { +func (actions ActionList) loadUsers(ctx context.Context) (map[int64]*user_model.User, error) { if len(actions) == 0 { return nil, nil } userIDs := actions.getUserIDs() userMaps := make(map[int64]*user_model.User, len(userIDs)) - err := e. + err := db.GetEngine(ctx). In("id", userIDs). Find(&userMaps) if err != nil { @@ -56,14 +57,14 @@ func (actions ActionList) getRepoIDs() []int64 { return container.KeysInt64(repoIDs) } -func (actions ActionList) loadRepositories(e db.Engine) error { +func (actions ActionList) loadRepositories(ctx context.Context) error { if len(actions) == 0 { return nil } repoIDs := actions.getRepoIDs() repoMaps := make(map[int64]*repo_model.Repository, len(repoIDs)) - err := e.In("id", repoIDs).Find(&repoMaps) + err := db.GetEngine(ctx).In("id", repoIDs).Find(&repoMaps) if err != nil { return fmt.Errorf("find repository: %v", err) } @@ -74,7 +75,7 @@ func (actions ActionList) loadRepositories(e db.Engine) error { return nil } -func (actions ActionList) loadRepoOwner(e db.Engine, userMap map[int64]*user_model.User) (err error) { +func (actions ActionList) loadRepoOwner(ctx context.Context, userMap map[int64]*user_model.User) (err error) { if userMap == nil { userMap = make(map[int64]*user_model.User) } @@ -85,7 +86,7 @@ func (actions ActionList) loadRepoOwner(e db.Engine, userMap map[int64]*user_mod } repoOwner, ok := userMap[action.Repo.OwnerID] if !ok { - repoOwner, err = user_model.GetUserByID(action.Repo.OwnerID) + repoOwner, err = user_model.GetUserByIDCtx(ctx, action.Repo.OwnerID) if err != nil { if user_model.IsErrUserNotExist(err) { continue @@ -101,15 +102,15 @@ func (actions ActionList) loadRepoOwner(e db.Engine, userMap map[int64]*user_mod } // loadAttributes loads all attributes -func (actions ActionList) loadAttributes(e db.Engine) error { - userMap, err := actions.loadUsers(e) +func (actions ActionList) loadAttributes(ctx context.Context) error { + userMap, err := actions.loadUsers(ctx) if err != nil { return err } - if err := actions.loadRepositories(e); err != nil { + if err := actions.loadRepositories(ctx); err != nil { return err } - return actions.loadRepoOwner(e, userMap) + return actions.loadRepoOwner(ctx, userMap) } diff --git a/models/asymkey/gpg_key.go b/models/asymkey/gpg_key.go index ced6ca37a..2b9997237 100644 --- a/models/asymkey/gpg_key.go +++ b/models/asymkey/gpg_key.go @@ -198,16 +198,16 @@ func parseGPGKey(ownerID int64, e *openpgp.Entity, verified bool) (*GPGKey, erro } // deleteGPGKey does the actual key deletion -func deleteGPGKey(e db.Engine, keyID string) (int64, error) { +func deleteGPGKey(ctx context.Context, keyID string) (int64, error) { if keyID == "" { return 0, fmt.Errorf("empty KeyId forbidden") // Should never happen but just to be sure } // Delete imported key - n, err := e.Where("key_id=?", keyID).Delete(new(GPGKeyImport)) + n, err := db.GetEngine(ctx).Where("key_id=?", keyID).Delete(new(GPGKeyImport)) if err != nil { return n, err } - return e.Where("key_id=?", keyID).Or("primary_key_id=?", keyID).Delete(new(GPGKey)) + return db.GetEngine(ctx).Where("key_id=?", keyID).Or("primary_key_id=?", keyID).Delete(new(GPGKey)) } // DeleteGPGKey deletes GPG key information in database. @@ -231,7 +231,7 @@ func DeleteGPGKey(doer *user_model.User, id int64) (err error) { } defer committer.Close() - if _, err = deleteGPGKey(db.GetEngine(ctx), key.KeyID); err != nil { + if _, err = deleteGPGKey(ctx, key.KeyID); err != nil { return err } diff --git a/models/asymkey/gpg_key_add.go b/models/asymkey/gpg_key_add.go index 8f84bba1d..d01f2deb0 100644 --- a/models/asymkey/gpg_key_add.go +++ b/models/asymkey/gpg_key_add.go @@ -5,6 +5,7 @@ package asymkey import ( + "context" "strings" "code.gitea.io/gitea/models/db" @@ -29,21 +30,21 @@ import ( // This file contains functions relating to adding GPG Keys // addGPGKey add key, import and subkeys to database -func addGPGKey(e db.Engine, key *GPGKey, content string) (err error) { +func addGPGKey(ctx context.Context, key *GPGKey, content string) (err error) { // Add GPGKeyImport - if _, err = e.Insert(GPGKeyImport{ + if err = db.Insert(ctx, &GPGKeyImport{ KeyID: key.KeyID, Content: content, }); err != nil { return err } // Save GPG primary key. - if _, err = e.Insert(key); err != nil { + if err = db.Insert(ctx, key); err != nil { return err } // Save GPG subs key. for _, subkey := range key.SubsKey { - if err := addGPGSubKey(e, subkey); err != nil { + if err := addGPGSubKey(ctx, subkey); err != nil { return err } } @@ -51,14 +52,14 @@ func addGPGKey(e db.Engine, key *GPGKey, content string) (err error) { } // addGPGSubKey add subkeys to database -func addGPGSubKey(e db.Engine, key *GPGKey) (err error) { +func addGPGSubKey(ctx context.Context, key *GPGKey) (err error) { // Save GPG primary key. - if _, err = e.Insert(key); err != nil { + if err = db.Insert(ctx, key); err != nil { return err } // Save GPG subs key. for _, subkey := range key.SubsKey { - if err := addGPGSubKey(e, subkey); err != nil { + if err := addGPGSubKey(ctx, subkey); err != nil { return err } } @@ -158,7 +159,7 @@ func AddGPGKey(ownerID int64, content, token, signature string) ([]*GPGKey, erro return nil, err } - if err = addGPGKey(db.GetEngine(ctx), key, content); err != nil { + if err = addGPGKey(ctx, key, content); err != nil { return nil, err } keys = append(keys, key) diff --git a/models/asymkey/ssh_key.go b/models/asymkey/ssh_key.go index 74d2411dd..10220ea93 100644 --- a/models/asymkey/ssh_key.go +++ b/models/asymkey/ssh_key.go @@ -75,7 +75,7 @@ func (key *PublicKey) AuthorizedString() string { return AuthorizedStringForKey(key) } -func addKey(e db.Engine, key *PublicKey) (err error) { +func addKey(ctx context.Context, key *PublicKey) (err error) { if len(key.Fingerprint) == 0 { key.Fingerprint, err = calcFingerprint(key.Content) if err != nil { @@ -84,7 +84,7 @@ func addKey(e db.Engine, key *PublicKey) (err error) { } // Save SSH key. - if _, err = e.Insert(key); err != nil { + if err = db.Insert(ctx, key); err != nil { return err } @@ -105,14 +105,13 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub return nil, err } defer committer.Close() - sess := db.GetEngine(ctx) - if err := checkKeyFingerprint(sess, fingerprint); err != nil { + if err := checkKeyFingerprint(ctx, fingerprint); err != nil { return nil, err } // Key name of same user cannot be duplicated. - has, err := sess. + has, err := db.GetEngine(ctx). Where("owner_id = ? AND name = ?", ownerID, name). Get(new(PublicKey)) if err != nil { @@ -130,7 +129,7 @@ func AddPublicKey(ownerID int64, name, content string, authSourceID int64) (*Pub Type: KeyTypeUser, LoginSourceID: authSourceID, } - if err = addKey(sess, key); err != nil { + if err = addKey(ctx, key); err != nil { return nil, fmt.Errorf("addKey: %v", err) } @@ -151,9 +150,11 @@ func GetPublicKeyByID(keyID int64) (*PublicKey, error) { return key, nil } -func searchPublicKeyByContentWithEngine(e db.Engine, content string) (*PublicKey, error) { +// SearchPublicKeyByContent searches content as prefix (leak e-mail part) +// and returns public key found. +func SearchPublicKeyByContent(ctx context.Context, content string) (*PublicKey, error) { key := new(PublicKey) - has, err := e. + has, err := db.GetEngine(ctx). Where("content like ?", content+"%"). Get(key) if err != nil { @@ -164,15 +165,11 @@ func searchPublicKeyByContentWithEngine(e db.Engine, content string) (*PublicKey return key, nil } -// SearchPublicKeyByContent searches content as prefix (leak e-mail part) +// SearchPublicKeyByContentExact searches content // and returns public key found. -func SearchPublicKeyByContent(content string) (*PublicKey, error) { - return searchPublicKeyByContentWithEngine(db.GetEngine(db.DefaultContext), content) -} - -func searchPublicKeyByContentExactWithEngine(e db.Engine, content string) (*PublicKey, error) { +func SearchPublicKeyByContentExact(ctx context.Context, content string) (*PublicKey, error) { key := new(PublicKey) - has, err := e. + has, err := db.GetEngine(ctx). Where("content = ?", content). Get(key) if err != nil { @@ -183,12 +180,6 @@ func searchPublicKeyByContentExactWithEngine(e db.Engine, content string) (*Publ return key, nil } -// SearchPublicKeyByContentExact searches content -// and returns public key found. -func SearchPublicKeyByContentExact(content string) (*PublicKey, error) { - return searchPublicKeyByContentExactWithEngine(db.GetEngine(db.DefaultContext), content) -} - // SearchPublicKey returns a list of public keys matching the provided arguments. func SearchPublicKey(uid int64, fingerprint string) ([]*PublicKey, error) { keys := make([]*PublicKey, 0, 5) @@ -335,12 +326,11 @@ func deleteKeysMarkedForDeletion(keys []string) (bool, error) { return false, err } defer committer.Close() - sess := db.GetEngine(ctx) // Delete keys marked for deletion var sshKeysNeedUpdate bool for _, KeyToDelete := range keys { - key, err := searchPublicKeyByContentWithEngine(sess, KeyToDelete) + key, err := SearchPublicKeyByContent(ctx, KeyToDelete) if err != nil { log.Error("SearchPublicKeyByContent: %v", err) continue diff --git a/models/asymkey/ssh_key_authorized_keys.go b/models/asymkey/ssh_key_authorized_keys.go index dd058f5d1..ce3b0248c 100644 --- a/models/asymkey/ssh_key_authorized_keys.go +++ b/models/asymkey/ssh_key_authorized_keys.go @@ -6,6 +6,7 @@ package asymkey import ( "bufio" + "context" "fmt" "io" "os" @@ -165,7 +166,7 @@ func RewriteAllPublicKeys() error { } } - if err := RegeneratePublicKeys(t); err != nil { + if err := RegeneratePublicKeys(db.DefaultContext, t); err != nil { return err } @@ -174,12 +175,8 @@ func RewriteAllPublicKeys() error { } // RegeneratePublicKeys regenerates the authorized_keys file -func RegeneratePublicKeys(t io.StringWriter) error { - return regeneratePublicKeys(db.GetEngine(db.DefaultContext), t) -} - -func regeneratePublicKeys(e db.Engine, t io.StringWriter) error { - if err := e.Where("type != ?", KeyTypePrincipal).Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) { +func RegeneratePublicKeys(ctx context.Context, t io.StringWriter) error { + if err := db.GetEngine(ctx).Where("type != ?", KeyTypePrincipal).Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) { _, err = t.WriteString((bean.(*PublicKey)).AuthorizedString()) return err }); err != nil { diff --git a/models/asymkey/ssh_key_authorized_principals.go b/models/asymkey/ssh_key_authorized_principals.go index a8c48c50a..4b08d0dfe 100644 --- a/models/asymkey/ssh_key_authorized_principals.go +++ b/models/asymkey/ssh_key_authorized_principals.go @@ -6,6 +6,7 @@ package asymkey import ( "bufio" + "context" "fmt" "io" "os" @@ -42,11 +43,7 @@ const authorizedPrincipalsFile = "authorized_principals" // RewriteAllPrincipalKeys removes any authorized principal and rewrite all keys from database again. // Note: db.GetEngine(db.DefaultContext).Iterate does not get latest data after insert/delete, so we have to call this function // outside any session scope independently. -func RewriteAllPrincipalKeys() error { - return rewriteAllPrincipalKeys(db.GetEngine(db.DefaultContext)) -} - -func rewriteAllPrincipalKeys(e db.Engine) error { +func RewriteAllPrincipalKeys(ctx context.Context) error { // Don't rewrite key if internal server if setting.SSH.StartBuiltinServer || !setting.SSH.CreateAuthorizedPrincipalsFile { return nil @@ -92,7 +89,7 @@ func rewriteAllPrincipalKeys(e db.Engine) error { } } - if err := regeneratePrincipalKeys(e, t); err != nil { + if err := regeneratePrincipalKeys(ctx, t); err != nil { return err } @@ -100,13 +97,8 @@ func rewriteAllPrincipalKeys(e db.Engine) error { return util.Rename(tmpPath, fPath) } -// RegeneratePrincipalKeys regenerates the authorized_principals file -func RegeneratePrincipalKeys(t io.StringWriter) error { - return regeneratePrincipalKeys(db.GetEngine(db.DefaultContext), t) -} - -func regeneratePrincipalKeys(e db.Engine, t io.StringWriter) error { - if err := e.Where("type = ?", KeyTypePrincipal).Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) { +func regeneratePrincipalKeys(ctx context.Context, t io.StringWriter) error { + if err := db.GetEngine(ctx).Where("type = ?", KeyTypePrincipal).Iterate(new(PublicKey), func(idx int, bean interface{}) (err error) { _, err = t.WriteString((bean.(*PublicKey)).AuthorizedString()) return err }); err != nil { diff --git a/models/asymkey/ssh_key_deploy.go b/models/asymkey/ssh_key_deploy.go index fe2ade43a..9a97d37f9 100644 --- a/models/asymkey/ssh_key_deploy.go +++ b/models/asymkey/ssh_key_deploy.go @@ -67,9 +67,9 @@ func init() { db.RegisterModel(new(DeployKey)) } -func checkDeployKey(e db.Engine, keyID, repoID int64, name string) error { +func checkDeployKey(ctx context.Context, keyID, repoID int64, name string) error { // Note: We want error detail, not just true or false here. - has, err := e. + has, err := db.GetEngine(ctx). Where("key_id = ? AND repo_id = ?", keyID, repoID). Get(new(DeployKey)) if err != nil { @@ -78,7 +78,7 @@ func checkDeployKey(e db.Engine, keyID, repoID int64, name string) error { return ErrDeployKeyAlreadyExist{keyID, repoID} } - has, err = e. + has, err = db.GetEngine(ctx). Where("repo_id = ? AND name = ?", repoID, name). Get(new(DeployKey)) if err != nil { @@ -91,8 +91,8 @@ func checkDeployKey(e db.Engine, keyID, repoID int64, name string) error { } // addDeployKey adds new key-repo relation. -func addDeployKey(e db.Engine, keyID, repoID int64, name, fingerprint string, mode perm.AccessMode) (*DeployKey, error) { - if err := checkDeployKey(e, keyID, repoID, name); err != nil { +func addDeployKey(ctx context.Context, keyID, repoID int64, name, fingerprint string, mode perm.AccessMode) (*DeployKey, error) { + if err := checkDeployKey(ctx, keyID, repoID, name); err != nil { return nil, err } @@ -103,8 +103,7 @@ func addDeployKey(e db.Engine, keyID, repoID int64, name, fingerprint string, mo Fingerprint: fingerprint, Mode: mode, } - _, err := e.Insert(key) - return key, err + return key, db.Insert(ctx, key) } // HasDeployKey returns true if public key is a deploy key of given repository. @@ -133,12 +132,10 @@ func AddDeployKey(repoID int64, name, content string, readOnly bool) (*DeployKey } defer committer.Close() - sess := db.GetEngine(ctx) - pkey := &PublicKey{ Fingerprint: fingerprint, } - has, err := sess.Get(pkey) + has, err := db.GetByBean(ctx, pkey) if err != nil { return nil, err } @@ -153,12 +150,12 @@ func AddDeployKey(repoID int64, name, content string, readOnly bool) (*DeployKey pkey.Type = KeyTypeDeploy pkey.Content = content pkey.Name = name - if err = addKey(sess, pkey); err != nil { + if err = addKey(ctx, pkey); err != nil { return nil, fmt.Errorf("addKey: %v", err) } } - key, err := addDeployKey(sess, pkey.ID, repoID, name, pkey.Fingerprint, accessMode) + key, err := addDeployKey(ctx, pkey.ID, repoID, name, pkey.Fingerprint, accessMode) if err != nil { return nil, err } @@ -179,16 +176,12 @@ func GetDeployKeyByID(ctx context.Context, id int64) (*DeployKey, error) { } // GetDeployKeyByRepo returns deploy key by given public key ID and repository ID. -func GetDeployKeyByRepo(keyID, repoID int64) (*DeployKey, error) { - return getDeployKeyByRepo(db.GetEngine(db.DefaultContext), keyID, repoID) -} - -func getDeployKeyByRepo(e db.Engine, keyID, repoID int64) (*DeployKey, error) { +func GetDeployKeyByRepo(ctx context.Context, keyID, repoID int64) (*DeployKey, error) { key := &DeployKey{ KeyID: keyID, RepoID: repoID, } - has, err := e.Get(key) + has, err := db.GetByBean(ctx, key) if err != nil { return nil, err } else if !has { diff --git a/models/asymkey/ssh_key_fingerprint.go b/models/asymkey/ssh_key_fingerprint.go index 437f283bf..283b3d3b6 100644 --- a/models/asymkey/ssh_key_fingerprint.go +++ b/models/asymkey/ssh_key_fingerprint.go @@ -5,6 +5,7 @@ package asymkey import ( + "context" "errors" "fmt" "strings" @@ -31,8 +32,8 @@ import ( // checkKeyFingerprint only checks if key fingerprint has been used as public key, // it is OK to use same key as deploy key for multiple repositories/users. -func checkKeyFingerprint(e db.Engine, fingerprint string) error { - has, err := e.Get(&PublicKey{ +func checkKeyFingerprint(ctx context.Context, fingerprint string) error { + has, err := db.GetByBean(ctx, &PublicKey{ Fingerprint: fingerprint, }) if err != nil { diff --git a/models/asymkey/ssh_key_principals.go b/models/asymkey/ssh_key_principals.go index 5f18fd04d..7a5c234f6 100644 --- a/models/asymkey/ssh_key_principals.go +++ b/models/asymkey/ssh_key_principals.go @@ -31,10 +31,9 @@ func AddPrincipalKey(ownerID int64, content string, authSourceID int64) (*Public return nil, err } defer committer.Close() - sess := db.GetEngine(ctx) // Principals cannot be duplicated. - has, err := sess. + has, err := db.GetEngine(ctx). Where("content = ? AND type = ?", content, KeyTypePrincipal). Get(new(PublicKey)) if err != nil { @@ -51,7 +50,7 @@ func AddPrincipalKey(ownerID int64, content string, authSourceID int64) (*Public Type: KeyTypePrincipal, LoginSourceID: authSourceID, } - if err = addPrincipalKey(sess, key); err != nil { + if err = db.Insert(ctx, key); err != nil { return nil, fmt.Errorf("addKey: %v", err) } @@ -61,16 +60,7 @@ func AddPrincipalKey(ownerID int64, content string, authSourceID int64) (*Public committer.Close() - return key, RewriteAllPrincipalKeys() -} - -func addPrincipalKey(e db.Engine, key *PublicKey) (err error) { - // Save Key representing a principal. - if _, err = e.Insert(key); err != nil { - return err - } - - return nil + return key, RewriteAllPrincipalKeys(db.DefaultContext) } // CheckPrincipalKeyString strips spaces and returns an error if the given principal contains newlines diff --git a/models/auth/oauth2.go b/models/auth/oauth2.go index ca77fcdb7..c5c6e9112 100644 --- a/models/auth/oauth2.go +++ b/models/auth/oauth2.go @@ -92,13 +92,9 @@ func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool { } // GetGrantByUserID returns a OAuth2Grant by its user and application ID -func (app *OAuth2Application) GetGrantByUserID(userID int64) (*OAuth2Grant, error) { - return app.getGrantByUserID(db.GetEngine(db.DefaultContext), userID) -} - -func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant *OAuth2Grant, err error) { +func (app *OAuth2Application) GetGrantByUserID(ctx context.Context, userID int64) (grant *OAuth2Grant, err error) { grant = new(OAuth2Grant) - if has, err := e.Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil { + if has, err := db.GetEngine(ctx).Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil { return nil, err } else if !has { return nil, nil @@ -107,17 +103,13 @@ func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant } // CreateGrant generates a grant for an user -func (app *OAuth2Application) CreateGrant(userID int64, scope string) (*OAuth2Grant, error) { - return app.createGrant(db.GetEngine(db.DefaultContext), userID, scope) -} - -func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope string) (*OAuth2Grant, error) { +func (app *OAuth2Application) CreateGrant(ctx context.Context, userID int64, scope string) (*OAuth2Grant, error) { grant := &OAuth2Grant{ ApplicationID: app.ID, UserID: userID, Scope: scope, } - _, err := e.Insert(grant) + err := db.Insert(ctx, grant) if err != nil { return nil, err } @@ -125,13 +117,9 @@ func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope strin } // GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found. -func GetOAuth2ApplicationByClientID(clientID string) (app *OAuth2Application, err error) { - return getOAuth2ApplicationByClientID(db.GetEngine(db.DefaultContext), clientID) -} - -func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Application, err error) { +func GetOAuth2ApplicationByClientID(ctx context.Context, clientID string) (app *OAuth2Application, err error) { app = new(OAuth2Application) - has, err := e.Where("client_id = ?", clientID).Get(app) + has, err := db.GetEngine(ctx).Where("client_id = ?", clientID).Get(app) if !has { return nil, ErrOAuthClientIDInvalid{ClientID: clientID} } @@ -139,13 +127,9 @@ func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Ap } // GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found. -func GetOAuth2ApplicationByID(id int64) (app *OAuth2Application, err error) { - return getOAuth2ApplicationByID(db.GetEngine(db.DefaultContext), id) -} - -func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, err error) { +func GetOAuth2ApplicationByID(ctx context.Context, id int64) (app *OAuth2Application, err error) { app = new(OAuth2Application) - has, err := e.ID(id).Get(app) + has, err := db.GetEngine(ctx).ID(id).Get(app) if err != nil { return nil, err } @@ -156,13 +140,9 @@ func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, er } // GetOAuth2ApplicationsByUserID returns all oauth2 applications owned by the user -func GetOAuth2ApplicationsByUserID(userID int64) (apps []*OAuth2Application, err error) { - return getOAuth2ApplicationsByUserID(db.GetEngine(db.DefaultContext), userID) -} - -func getOAuth2ApplicationsByUserID(e db.Engine, userID int64) (apps []*OAuth2Application, err error) { +func GetOAuth2ApplicationsByUserID(ctx context.Context, userID int64) (apps []*OAuth2Application, err error) { apps = make([]*OAuth2Application, 0) - err = e.Where("uid = ?", userID).Find(&apps) + err = db.GetEngine(ctx).Where("uid = ?", userID).Find(&apps) return } @@ -174,11 +154,7 @@ type CreateOAuth2ApplicationOptions struct { } // CreateOAuth2Application inserts a new oauth2 application -func CreateOAuth2Application(opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) { - return createOAuth2Application(db.GetEngine(db.DefaultContext), opts) -} - -func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) { +func CreateOAuth2Application(ctx context.Context, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) { clientID := uuid.New().String() app := &OAuth2Application{ UID: opts.UserID, @@ -186,7 +162,7 @@ func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) ( ClientID: clientID, RedirectURIs: opts.RedirectURIs, } - if _, err := e.Insert(app); err != nil { + if err := db.Insert(ctx, app); err != nil { return nil, err } return app, nil @@ -207,9 +183,8 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic return nil, err } defer committer.Close() - sess := db.GetEngine(ctx) - app, err := getOAuth2ApplicationByID(sess, opts.ID) + app, err := GetOAuth2ApplicationByID(ctx, opts.ID) if err != nil { return nil, err } @@ -220,7 +195,7 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic app.Name = opts.Name app.RedirectURIs = opts.RedirectURIs - if err = updateOAuth2Application(sess, app); err != nil { + if err = updateOAuth2Application(ctx, app); err != nil { return nil, err } app.ClientSecret = "" @@ -228,14 +203,15 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic return app, committer.Commit() } -func updateOAuth2Application(e db.Engine, app *OAuth2Application) error { - if _, err := e.ID(app.ID).Update(app); err != nil { +func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error { + if _, err := db.GetEngine(ctx).ID(app.ID).Update(app); err != nil { return err } return nil } -func deleteOAuth2Application(sess db.Engine, id, userid int64) error { +func deleteOAuth2Application(ctx context.Context, id, userid int64) error { + sess := db.GetEngine(ctx) if deleted, err := sess.Delete(&OAuth2Application{ID: id, UID: userid}); err != nil { return err } else if deleted == 0 { @@ -269,7 +245,7 @@ func DeleteOAuth2Application(id, userid int64) error { return err } defer committer.Close() - if err := deleteOAuth2Application(db.GetEngine(ctx), id, userid); err != nil { + if err := deleteOAuth2Application(ctx, id, userid); err != nil { return err } return committer.Commit() @@ -328,21 +304,13 @@ func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (redirect } // Invalidate deletes the auth code from the database to invalidate this code -func (code *OAuth2AuthorizationCode) Invalidate() error { - return code.invalidate(db.GetEngine(db.DefaultContext)) -} - -func (code *OAuth2AuthorizationCode) invalidate(e db.Engine) error { - _, err := e.Delete(code) +func (code *OAuth2AuthorizationCode) Invalidate(ctx context.Context) error { + _, err := db.GetEngine(ctx).ID(code.ID).NoAutoCondition().Delete(code) return err } // ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation. func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool { - return code.validateCodeChallenge(verifier) -} - -func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool { switch code.CodeChallengeMethod { case "S256": // base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6 @@ -360,19 +328,15 @@ func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool } // GetOAuth2AuthorizationByCode returns an authorization by its code -func GetOAuth2AuthorizationByCode(code string) (*OAuth2AuthorizationCode, error) { - return getOAuth2AuthorizationByCode(db.GetEngine(db.DefaultContext), code) -} - -func getOAuth2AuthorizationByCode(e db.Engine, code string) (auth *OAuth2AuthorizationCode, err error) { +func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth2AuthorizationCode, err error) { auth = new(OAuth2AuthorizationCode) - if has, err := e.Where("code = ?", code).Get(auth); err != nil { + if has, err := db.GetEngine(ctx).Where("code = ?", code).Get(auth); err != nil { return nil, err } else if !has { return nil, nil } auth.Grant = new(OAuth2Grant) - if has, err := e.ID(auth.GrantID).Get(auth.Grant); err != nil { + if has, err := db.GetEngine(ctx).ID(auth.GrantID).Get(auth.Grant); err != nil { return nil, err } else if !has { return nil, nil @@ -401,11 +365,7 @@ func (grant *OAuth2Grant) TableName() string { } // GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database -func (grant *OAuth2Grant) GenerateNewAuthorizationCode(redirectURI, codeChallenge, codeChallengeMethod string) (*OAuth2AuthorizationCode, error) { - return grant.generateNewAuthorizationCode(db.GetEngine(db.DefaultContext), redirectURI, codeChallenge, codeChallengeMethod) -} - -func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) { +func (grant *OAuth2Grant) GenerateNewAuthorizationCode(ctx context.Context, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) { rBytes, err := util.CryptoRandomBytes(32) if err != nil { return &OAuth2AuthorizationCode{}, err @@ -422,23 +382,19 @@ func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, CodeChallenge: codeChallenge, CodeChallengeMethod: codeChallengeMethod, } - if _, err := e.Insert(code); err != nil { + if err := db.Insert(ctx, code); err != nil { return nil, err } return code, nil } // IncreaseCounter increases the counter and updates the grant -func (grant *OAuth2Grant) IncreaseCounter() error { - return grant.increaseCount(db.GetEngine(db.DefaultContext)) -} - -func (grant *OAuth2Grant) increaseCount(e db.Engine) error { - _, err := e.ID(grant.ID).Incr("counter").Update(new(OAuth2Grant)) +func (grant *OAuth2Grant) IncreaseCounter(ctx context.Context) error { + _, err := db.GetEngine(ctx).ID(grant.ID).Incr("counter").Update(new(OAuth2Grant)) if err != nil { return err } - updatedGrant, err := getOAuth2GrantByID(e, grant.ID) + updatedGrant, err := GetOAuth2GrantByID(ctx, grant.ID) if err != nil { return err } @@ -457,13 +413,9 @@ func (grant *OAuth2Grant) ScopeContains(scope string) bool { } // SetNonce updates the current nonce value of a grant -func (grant *OAuth2Grant) SetNonce(nonce string) error { - return grant.setNonce(db.GetEngine(db.DefaultContext), nonce) -} - -func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error { +func (grant *OAuth2Grant) SetNonce(ctx context.Context, nonce string) error { grant.Nonce = nonce - _, err := e.ID(grant.ID).Cols("nonce").Update(grant) + _, err := db.GetEngine(ctx).ID(grant.ID).Cols("nonce").Update(grant) if err != nil { return err } @@ -471,13 +423,9 @@ func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error { } // GetOAuth2GrantByID returns the grant with the given ID -func GetOAuth2GrantByID(id int64) (*OAuth2Grant, error) { - return getOAuth2GrantByID(db.GetEngine(db.DefaultContext), id) -} - -func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) { +func GetOAuth2GrantByID(ctx context.Context, id int64) (grant *OAuth2Grant, err error) { grant = new(OAuth2Grant) - if has, err := e.ID(id).Get(grant); err != nil { + if has, err := db.GetEngine(ctx).ID(id).Get(grant); err != nil { return nil, err } else if !has { return nil, nil @@ -486,18 +434,14 @@ func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) { } // GetOAuth2GrantsByUserID lists all grants of a certain user -func GetOAuth2GrantsByUserID(uid int64) ([]*OAuth2Grant, error) { - return getOAuth2GrantsByUserID(db.GetEngine(db.DefaultContext), uid) -} - -func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) { +func GetOAuth2GrantsByUserID(ctx context.Context, uid int64) ([]*OAuth2Grant, error) { type joinedOAuth2Grant struct { Grant *OAuth2Grant `xorm:"extends"` Application *OAuth2Application `xorm:"extends"` } var results *xorm.Rows var err error - if results, err = e. + if results, err = db.GetEngine(ctx). Table("oauth2_grant"). Where("user_id = ?", uid). Join("INNER", "oauth2_application", "application_id = oauth2_application.id"). @@ -518,12 +462,8 @@ func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) { } // RevokeOAuth2Grant deletes the grant with grantID and userID -func RevokeOAuth2Grant(grantID, userID int64) error { - return revokeOAuth2Grant(db.GetEngine(db.DefaultContext), grantID, userID) -} - -func revokeOAuth2Grant(e db.Engine, grantID, userID int64) error { - _, err := e.Delete(&OAuth2Grant{ID: grantID, UserID: userID}) +func RevokeOAuth2Grant(ctx context.Context, grantID, userID int64) error { + _, err := db.DeleteByBean(ctx, &OAuth2Grant{ID: grantID, UserID: userID}) return err } diff --git a/models/auth/oauth2_test.go b/models/auth/oauth2_test.go index b712fc285..cb8c4aeb6 100644 --- a/models/auth/oauth2_test.go +++ b/models/auth/oauth2_test.go @@ -7,6 +7,7 @@ package auth import ( "testing" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" "github.com/stretchr/testify/assert" @@ -52,18 +53,18 @@ func TestOAuth2Application_ValidateClientSecret(t *testing.T) { func TestGetOAuth2ApplicationByClientID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - app, err := GetOAuth2ApplicationByClientID("da7da3ba-9a13-4167-856f-3899de0b0138") + app, err := GetOAuth2ApplicationByClientID(db.DefaultContext, "da7da3ba-9a13-4167-856f-3899de0b0138") assert.NoError(t, err) assert.Equal(t, "da7da3ba-9a13-4167-856f-3899de0b0138", app.ClientID) - app, err = GetOAuth2ApplicationByClientID("invalid client id") + app, err = GetOAuth2ApplicationByClientID(db.DefaultContext, "invalid client id") assert.Error(t, err) assert.Nil(t, app) } func TestCreateOAuth2Application(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - app, err := CreateOAuth2Application(CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1}) + app, err := CreateOAuth2Application(db.DefaultContext, CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1}) assert.NoError(t, err) assert.Equal(t, "newapp", app.Name) assert.Len(t, app.ClientID, 36) @@ -77,11 +78,11 @@ func TestOAuth2Application_TableName(t *testing.T) { func TestOAuth2Application_GetGrantByUserID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application) - grant, err := app.GetGrantByUserID(1) + grant, err := app.GetGrantByUserID(db.DefaultContext, 1) assert.NoError(t, err) assert.Equal(t, int64(1), grant.UserID) - grant, err = app.GetGrantByUserID(34923458) + grant, err = app.GetGrantByUserID(db.DefaultContext, 34923458) assert.NoError(t, err) assert.Nil(t, grant) } @@ -89,7 +90,7 @@ func TestOAuth2Application_GetGrantByUserID(t *testing.T) { func TestOAuth2Application_CreateGrant(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) app := unittest.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application) - grant, err := app.CreateGrant(2, "") + grant, err := app.CreateGrant(db.DefaultContext, 2, "") assert.NoError(t, err) assert.NotNil(t, grant) assert.Equal(t, int64(2), grant.UserID) @@ -101,11 +102,11 @@ func TestOAuth2Application_CreateGrant(t *testing.T) { func TestGetOAuth2GrantByID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - grant, err := GetOAuth2GrantByID(1) + grant, err := GetOAuth2GrantByID(db.DefaultContext, 1) assert.NoError(t, err) assert.Equal(t, int64(1), grant.ID) - grant, err = GetOAuth2GrantByID(34923458) + grant, err = GetOAuth2GrantByID(db.DefaultContext, 34923458) assert.NoError(t, err) assert.Nil(t, grant) } @@ -113,7 +114,7 @@ func TestGetOAuth2GrantByID(t *testing.T) { func TestOAuth2Grant_IncreaseCounter(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 1}).(*OAuth2Grant) - assert.NoError(t, grant.IncreaseCounter()) + assert.NoError(t, grant.IncreaseCounter(db.DefaultContext)) assert.Equal(t, int64(2), grant.Counter) unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 2}) } @@ -130,7 +131,7 @@ func TestOAuth2Grant_ScopeContains(t *testing.T) { func TestOAuth2Grant_GenerateNewAuthorizationCode(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) grant := unittest.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1}).(*OAuth2Grant) - code, err := grant.GenerateNewAuthorizationCode("https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256") + code, err := grant.GenerateNewAuthorizationCode(db.DefaultContext, "https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256") assert.NoError(t, err) assert.NotNil(t, code) assert.True(t, len(code.Code) > 32) // secret length > 32 @@ -142,20 +143,20 @@ func TestOAuth2Grant_TableName(t *testing.T) { func TestGetOAuth2GrantsByUserID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - result, err := GetOAuth2GrantsByUserID(1) + result, err := GetOAuth2GrantsByUserID(db.DefaultContext, 1) assert.NoError(t, err) assert.Len(t, result, 1) assert.Equal(t, int64(1), result[0].ID) assert.Equal(t, result[0].ApplicationID, result[0].Application.ID) - result, err = GetOAuth2GrantsByUserID(34134) + result, err = GetOAuth2GrantsByUserID(db.DefaultContext, 34134) assert.NoError(t, err) assert.Empty(t, result) } func TestRevokeOAuth2Grant(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.NoError(t, RevokeOAuth2Grant(1, 1)) + assert.NoError(t, RevokeOAuth2Grant(db.DefaultContext, 1, 1)) unittest.AssertNotExistsBean(t, &OAuth2Grant{ID: 1, UserID: 1}) } @@ -163,13 +164,13 @@ func TestRevokeOAuth2Grant(t *testing.T) { func TestGetOAuth2AuthorizationByCode(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - code, err := GetOAuth2AuthorizationByCode("authcode") + code, err := GetOAuth2AuthorizationByCode(db.DefaultContext, "authcode") assert.NoError(t, err) assert.NotNil(t, code) assert.Equal(t, "authcode", code.Code) assert.Equal(t, int64(1), code.ID) - code, err = GetOAuth2AuthorizationByCode("does not exist") + code, err = GetOAuth2AuthorizationByCode(db.DefaultContext, "does not exist") assert.NoError(t, err) assert.Nil(t, code) } @@ -224,7 +225,7 @@ func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) { func TestOAuth2AuthorizationCode_Invalidate(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) code := unittest.AssertExistsAndLoadBean(t, &OAuth2AuthorizationCode{Code: "authcode"}).(*OAuth2AuthorizationCode) - assert.NoError(t, code.Invalidate()) + assert.NoError(t, code.Invalidate(db.DefaultContext)) unittest.AssertNotExistsBean(t, &OAuth2AuthorizationCode{Code: "authcode"}) } diff --git a/models/branches.go b/models/branches.go index 008cb8653..98d2c3a99 100644 --- a/models/branches.go +++ b/models/branches.go @@ -306,13 +306,9 @@ func (protectBranch *ProtectedBranch) IsUnprotectedFile(patterns []glob.Glob, pa } // GetProtectedBranchBy getting protected branch by ID/Name -func GetProtectedBranchBy(repoID int64, branchName string) (*ProtectedBranch, error) { - return getProtectedBranchBy(db.GetEngine(db.DefaultContext), repoID, branchName) -} - -func getProtectedBranchBy(e db.Engine, repoID int64, branchName string) (*ProtectedBranch, error) { +func GetProtectedBranchBy(ctx context.Context, repoID int64, branchName string) (*ProtectedBranch, error) { rel := &ProtectedBranch{RepoID: repoID, BranchName: branchName} - has, err := e.Get(rel) + has, err := db.GetByBean(ctx, rel) if err != nil { return nil, err } @@ -632,7 +628,7 @@ func RenameBranch(repo *repo_model.Repository, from, to string, gitAction func(i } // 2. Update protected branch if needed - protectedBranch, err := getProtectedBranchBy(sess, repo.ID, from) + protectedBranch, err := GetProtectedBranchBy(ctx, repo.ID, from) if err != nil { return err } diff --git a/models/commit_status.go b/models/commit_status.go index cf2143d30..ef92c5847 100644 --- a/models/commit_status.go +++ b/models/commit_status.go @@ -49,21 +49,21 @@ func init() { } // upsertCommitStatusIndex the function will not return until it acquires the lock or receives an error. -func upsertCommitStatusIndex(e db.Engine, repoID int64, sha string) (err error) { +func upsertCommitStatusIndex(ctx context.Context, repoID int64, sha string) (err error) { // An atomic UPSERT operation (INSERT/UPDATE) is the only operation // that ensures that the key is actually locked. switch { case setting.Database.UseSQLite3 || setting.Database.UsePostgreSQL: - _, err = e.Exec("INSERT INTO `commit_status_index` (repo_id, sha, max_index) "+ + _, err = db.Exec(ctx, "INSERT INTO `commit_status_index` (repo_id, sha, max_index) "+ "VALUES (?,?,1) ON CONFLICT (repo_id,sha) DO UPDATE SET max_index = `commit_status_index`.max_index+1", repoID, sha) case setting.Database.UseMySQL: - _, err = e.Exec("INSERT INTO `commit_status_index` (repo_id, sha, max_index) "+ + _, err = db.Exec(ctx, "INSERT INTO `commit_status_index` (repo_id, sha, max_index) "+ "VALUES (?,?,1) ON DUPLICATE KEY UPDATE max_index = max_index+1", repoID, sha) case setting.Database.UseMSSQL: // https://weblogs.sqlteam.com/dang/2009/01/31/upsert-race-condition-with-merge/ - _, err = e.Exec("MERGE `commit_status_index` WITH (HOLDLOCK) as target "+ + _, err = db.Exec(ctx, "MERGE `commit_status_index` WITH (HOLDLOCK) as target "+ "USING (SELECT ? AS repo_id, ? AS sha) AS src "+ "ON src.repo_id = target.repo_id AND src.sha = target.sha "+ "WHEN MATCHED THEN UPDATE SET target.max_index = target.max_index+1 "+ @@ -100,17 +100,17 @@ func getNextCommitStatusIndex(repoID int64, sha string) (int64, error) { defer commiter.Close() var preIdx int64 - _, err = ctx.Engine().SQL("SELECT max_index FROM `commit_status_index` WHERE repo_id = ? AND sha = ?", repoID, sha).Get(&preIdx) + _, err = db.GetEngine(ctx).SQL("SELECT max_index FROM `commit_status_index` WHERE repo_id = ? AND sha = ?", repoID, sha).Get(&preIdx) if err != nil { return 0, err } - if err := upsertCommitStatusIndex(ctx.Engine(), repoID, sha); err != nil { + if err := upsertCommitStatusIndex(ctx, repoID, sha); err != nil { return 0, err } var curIdx int64 - has, err := ctx.Engine().SQL("SELECT max_index FROM `commit_status_index` WHERE repo_id = ? AND sha = ? AND max_index=?", repoID, sha, preIdx+1).Get(&curIdx) + has, err := db.GetEngine(ctx).SQL("SELECT max_index FROM `commit_status_index` WHERE repo_id = ? AND sha = ? AND max_index=?", repoID, sha, preIdx+1).Get(&curIdx) if err != nil { return 0, err } @@ -131,7 +131,7 @@ func (status *CommitStatus) loadAttributes(ctx context.Context) (err error) { } } if status.Creator == nil && status.CreatorID > 0 { - status.Creator, err = user_model.GetUserByIDEngine(db.GetEngine(ctx), status.CreatorID) + status.Creator, err = user_model.GetUserByIDCtx(ctx, status.CreatorID) if err != nil { return fmt.Errorf("getUserByID [%d]: %v", status.CreatorID, err) } @@ -231,12 +231,7 @@ type CommitStatusIndex struct { } // GetLatestCommitStatus returns all statuses with a unique context for a given commit. -func GetLatestCommitStatus(repoID int64, sha string, listOptions db.ListOptions) ([]*CommitStatus, int64, error) { - return GetLatestCommitStatusCtx(db.DefaultContext, repoID, sha, listOptions) -} - -// GetLatestCommitStatusCtx returns all statuses with a unique context for a given commit. -func GetLatestCommitStatusCtx(ctx context.Context, repoID int64, sha string, listOptions db.ListOptions) ([]*CommitStatus, int64, error) { +func GetLatestCommitStatus(ctx context.Context, repoID int64, sha string, listOptions db.ListOptions) ([]*CommitStatus, int64, error) { ids := make([]int64, 0, 10) sess := db.GetEngine(ctx).Table(&CommitStatus{}). Where("repo_id = ?", repoID).And("sha = ?", sha). @@ -341,7 +336,7 @@ func ParseCommitsWithStatus(oldCommits []*asymkey_model.SignCommit, repo *repo_m commit := &SignCommitWithStatuses{ SignCommit: c, } - statuses, _, err := GetLatestCommitStatus(repo.ID, commit.ID.String(), db.ListOptions{}) + statuses, _, err := GetLatestCommitStatus(db.DefaultContext, repo.ID, commit.ID.String(), db.ListOptions{}) if err != nil { log.Error("GetLatestCommitStatus: %v", err) } else { diff --git a/models/consistency.go b/models/consistency.go index 30cafb788..abef7243f 100644 --- a/models/consistency.go +++ b/models/consistency.go @@ -117,7 +117,7 @@ func DeleteOrphanedIssues() error { var attachmentPaths []string for i := range ids { - paths, err := deleteIssuesByRepoID(db.GetEngine(ctx), ids[i]) + paths, err := deleteIssuesByRepoID(ctx, ids[i]) if err != nil { return err } diff --git a/models/db/index.go b/models/db/index.go index 0086a8f54..8598de949 100644 --- a/models/db/index.go +++ b/models/db/index.go @@ -5,6 +5,7 @@ package db import ( + "context" "errors" "fmt" @@ -74,8 +75,8 @@ func GetNextResourceIndex(tableName string, groupID int64) (int64, error) { } // DeleteResouceIndex delete resource index -func DeleteResouceIndex(e Engine, tableName string, groupID int64) error { - _, err := e.Exec(fmt.Sprintf("DELETE FROM %s WHERE group_id=?", tableName), groupID) +func DeleteResouceIndex(ctx context.Context, tableName string, groupID int64) error { + _, err := Exec(ctx, fmt.Sprintf("DELETE FROM %s WHERE group_id=?", tableName), groupID) return err } diff --git a/models/issue.go b/models/issue.go index c344998b9..1a66e5e95 100644 --- a/models/issue.go +++ b/models/issue.go @@ -107,9 +107,9 @@ func init() { db.RegisterModel(new(IssueIndex)) } -func (issue *Issue) loadTotalTimes(e db.Engine) (err error) { +func (issue *Issue) loadTotalTimes(ctx context.Context) (err error) { opts := FindTrackedTimesOptions{IssueID: issue.ID} - issue.TotalTrackedTime, err = opts.toSession(e).SumInt(&TrackedTime{}, "time") + issue.TotalTrackedTime, err = opts.toSession(db.GetEngine(ctx)).SumInt(&TrackedTime{}, "time") if err != nil { return err } @@ -154,7 +154,7 @@ func (issue *Issue) GetPullRequest() (pr *PullRequest, err error) { return nil, fmt.Errorf("Issue is not a pull request") } - pr, err = getPullRequestByIssueID(db.GetEngine(db.DefaultContext), issue.ID) + pr, err = GetPullRequestByIssueID(db.DefaultContext, issue.ID) if err != nil { return nil, err } @@ -165,7 +165,7 @@ func (issue *Issue) GetPullRequest() (pr *PullRequest, err error) { // LoadLabels loads labels func (issue *Issue) LoadLabels(ctx context.Context) (err error) { if issue.Labels == nil { - issue.Labels, err = getLabelsByIssueID(db.GetEngine(ctx), issue.ID) + issue.Labels, err = GetLabelsByIssueID(ctx, issue.ID) if err != nil { return fmt.Errorf("getLabelsByIssueID [%d]: %v", issue.ID, err) } @@ -175,12 +175,12 @@ func (issue *Issue) LoadLabels(ctx context.Context) (err error) { // LoadPoster loads poster func (issue *Issue) LoadPoster() error { - return issue.loadPoster(db.GetEngine(db.DefaultContext)) + return issue.loadPoster(db.DefaultContext) } -func (issue *Issue) loadPoster(e db.Engine) (err error) { +func (issue *Issue) loadPoster(ctx context.Context) (err error) { if issue.Poster == nil { - issue.Poster, err = user_model.GetUserByIDEngine(e, issue.PosterID) + issue.Poster, err = user_model.GetUserByIDCtx(ctx, issue.PosterID) if err != nil { issue.PosterID = -1 issue.Poster = user_model.NewGhostUser() @@ -194,9 +194,9 @@ func (issue *Issue) loadPoster(e db.Engine) (err error) { return } -func (issue *Issue) loadPullRequest(e db.Engine) (err error) { +func (issue *Issue) loadPullRequest(ctx context.Context) (err error) { if issue.IsPull && issue.PullRequest == nil { - issue.PullRequest, err = getPullRequestByIssueID(e, issue.ID) + issue.PullRequest, err = GetPullRequestByIssueID(ctx, issue.ID) if err != nil { if IsErrPullRequestNotExist(err) { return err @@ -210,23 +210,23 @@ func (issue *Issue) loadPullRequest(e db.Engine) (err error) { // LoadPullRequest loads pull request info func (issue *Issue) LoadPullRequest() error { - return issue.loadPullRequest(db.GetEngine(db.DefaultContext)) + return issue.loadPullRequest(db.DefaultContext) } -func (issue *Issue) loadComments(e db.Engine) (err error) { - return issue.loadCommentsByType(e, CommentTypeUnknown) +func (issue *Issue) loadComments(ctx context.Context) (err error) { + return issue.loadCommentsByType(ctx, CommentTypeUnknown) } // LoadDiscussComments loads discuss comments func (issue *Issue) LoadDiscussComments() error { - return issue.loadCommentsByType(db.GetEngine(db.DefaultContext), CommentTypeComment) + return issue.loadCommentsByType(db.DefaultContext, CommentTypeComment) } -func (issue *Issue) loadCommentsByType(e db.Engine, tp CommentType) (err error) { +func (issue *Issue) loadCommentsByType(ctx context.Context, tp CommentType) (err error) { if issue.Comments != nil { return nil } - issue.Comments, err = findComments(e, &FindCommentsOptions{ + issue.Comments, err = FindComments(ctx, &FindCommentsOptions{ IssueID: issue.ID, Type: tp, }) @@ -301,12 +301,11 @@ func (issue *Issue) loadMilestone(ctx context.Context) (err error) { } func (issue *Issue) loadAttributes(ctx context.Context) (err error) { - e := db.GetEngine(ctx) if err = issue.LoadRepo(ctx); err != nil { return } - if err = issue.loadPoster(e); err != nil { + if err = issue.loadPoster(ctx); err != nil { return } @@ -318,27 +317,27 @@ func (issue *Issue) loadAttributes(ctx context.Context) (err error) { return } - if err = issue.loadProject(e); err != nil { + if err = issue.loadProject(ctx); err != nil { return } - if err = issue.loadAssignees(e); err != nil { + if err = issue.LoadAssignees(ctx); err != nil { return } - if err = issue.loadPullRequest(e); err != nil && !IsErrPullRequestNotExist(err) { + if err = issue.loadPullRequest(ctx); err != nil && !IsErrPullRequestNotExist(err) { // It is possible pull request is not yet created. return err } if issue.Attachments == nil { - issue.Attachments, err = repo_model.GetAttachmentsByIssueIDCtx(ctx, issue.ID) + issue.Attachments, err = repo_model.GetAttachmentsByIssueID(ctx, issue.ID) if err != nil { return fmt.Errorf("getAttachmentsByIssueID [%d]: %v", issue.ID, err) } } - if err = issue.loadComments(e); err != nil { + if err = issue.loadComments(ctx); err != nil { return err } @@ -346,7 +345,7 @@ func (issue *Issue) loadAttributes(ctx context.Context) (err error) { return err } if issue.isTimetrackerEnabled(ctx) { - if err = issue.loadTotalTimes(e); err != nil { + if err = issue.loadTotalTimes(ctx); err != nil { return err } } @@ -449,12 +448,12 @@ func (issue *Issue) IsPoster(uid int64) bool { return issue.OriginalAuthorID == 0 && issue.PosterID == uid } -func (issue *Issue) getLabels(e db.Engine) (err error) { +func (issue *Issue) getLabels(ctx context.Context) (err error) { if len(issue.Labels) > 0 { return nil } - issue.Labels, err = getLabelsByIssueID(e, issue.ID) + issue.Labels, err = GetLabelsByIssueID(ctx, issue.ID) if err != nil { return fmt.Errorf("getLabelsByIssueID: %v", err) } @@ -462,7 +461,7 @@ func (issue *Issue) getLabels(e db.Engine) (err error) { } func clearIssueLabels(ctx context.Context, issue *Issue, doer *user_model.User) (err error) { - if err = issue.getLabels(db.GetEngine(ctx)); err != nil { + if err = issue.getLabels(ctx); err != nil { return fmt.Errorf("getLabels: %v", err) } @@ -486,7 +485,7 @@ func ClearIssueLabels(issue *Issue, doer *user_model.User) (err error) { if err := issue.LoadRepo(ctx); err != nil { return err - } else if err = issue.loadPullRequest(db.GetEngine(ctx)); err != nil { + } else if err = issue.loadPullRequest(ctx); err != nil { return err } @@ -597,7 +596,7 @@ func (issue *Issue) ReadBy(ctx context.Context, userID int64) error { return err } - return setIssueNotificationStatusReadIfUnread(db.GetEngine(db.DefaultContext), userID, issue.ID) + return setIssueNotificationStatusReadIfUnread(ctx, userID, issue.ID) } // UpdateIssueCols updates cols of issue @@ -610,7 +609,7 @@ func UpdateIssueCols(ctx context.Context, issue *Issue, cols ...string) error { func changeIssueStatus(ctx context.Context, issue *Issue, doer *user_model.User, isClosed, isMergePull bool) (*Comment, error) { // Reload the issue - currentIssue, err := getIssueByID(db.GetEngine(ctx), issue.ID) + currentIssue, err := getIssueByID(ctx, issue.ID) if err != nil { return nil, err } @@ -632,7 +631,6 @@ func changeIssueStatus(ctx context.Context, issue *Issue, doer *user_model.User, } func doChangeIssueStatus(ctx context.Context, issue *Issue, doer *user_model.User, isMergePull bool) (*Comment, error) { - e := db.GetEngine(ctx) // Check for open dependencies if issue.IsClosed && issue.Repo.IsDependenciesEnabledCtx(ctx) { // only check if dependencies are enabled and we're about to close an issue, otherwise reopening an issue would fail when there are unsatisfied dependencies @@ -657,11 +655,11 @@ func doChangeIssueStatus(ctx context.Context, issue *Issue, doer *user_model.Use } // Update issue count of labels - if err := issue.getLabels(e); err != nil { + if err := issue.getLabels(ctx); err != nil { return nil, err } for idx := range issue.Labels { - if err := updateLabelCols(e, issue.Labels[idx], "num_issues", "num_closed_issue"); err != nil { + if err := updateLabelCols(ctx, issue.Labels[idx], "num_issues", "num_closed_issue"); err != nil { return nil, err } } @@ -698,7 +696,7 @@ func ChangeIssueStatus(ctx context.Context, issue *Issue, doer *user_model.User, if err := issue.LoadRepo(ctx); err != nil { return nil, err } - if err := issue.loadPoster(db.GetEngine(ctx)); err != nil { + if err := issue.loadPoster(ctx); err != nil { return nil, err } @@ -774,7 +772,7 @@ func ChangeIssueRef(issue *Issue, doer *user_model.User, oldRef string) (err err // AddDeletePRBranchComment adds delete branch comment for pull request issue func AddDeletePRBranchComment(ctx context.Context, doer *user_model.User, repo *repo_model.Repository, issueID int64, branchName string) error { - issue, err := getIssueByID(db.GetEngine(ctx), issueID) + issue, err := getIssueByID(ctx, issueID) if err != nil { return err } @@ -802,7 +800,7 @@ func UpdateIssueAttachments(issueID int64, uuids []string) (err error) { } for i := 0; i < len(attachments); i++ { attachments[i].IssueID = issueID - if err := repo_model.UpdateAttachmentCtx(ctx, attachments[i]); err != nil { + if err := repo_model.UpdateAttachment(ctx, attachments[i]); err != nil { return fmt.Errorf("update attachment [id: %d]: %v", attachments[i].ID, err) } } @@ -822,7 +820,7 @@ func ChangeIssueContent(issue *Issue, doer *user_model.User, content string) (er return fmt.Errorf("HasIssueContentHistory: %v", err) } if !hasContentHistory { - if err = issues_model.SaveIssueContentHistory(db.GetEngine(ctx), issue.PosterID, issue.ID, 0, + if err = issues_model.SaveIssueContentHistory(ctx, issue.PosterID, issue.ID, 0, issue.CreatedUnix, issue.Content, true); err != nil { return fmt.Errorf("SaveIssueContentHistory: %v", err) } @@ -834,7 +832,7 @@ func ChangeIssueContent(issue *Issue, doer *user_model.User, content string) (er return fmt.Errorf("UpdateIssueCols: %v", err) } - if err = issues_model.SaveIssueContentHistory(db.GetEngine(ctx), doer.ID, issue.ID, 0, + if err = issues_model.SaveIssueContentHistory(ctx, doer.ID, issue.ID, 0, timeutil.TimeStampNow(), issue.Content, false); err != nil { return fmt.Errorf("SaveIssueContentHistory: %v", err) } @@ -973,7 +971,7 @@ func newIssue(ctx context.Context, doer *user_model.User, opts NewIssueOptions) return fmt.Errorf("find all labels [label_ids: %v]: %v", opts.LabelIDs, err) } - if err = opts.Issue.loadPoster(e); err != nil { + if err = opts.Issue.loadPoster(ctx); err != nil { return err } @@ -1119,9 +1117,9 @@ func GetIssueWithAttrsByIndex(repoID, index int64) (*Issue, error) { return issue, issue.LoadAttributes() } -func getIssueByID(e db.Engine, id int64) (*Issue, error) { +func getIssueByID(ctx context.Context, id int64) (*Issue, error) { issue := new(Issue) - has, err := e.ID(id).Get(issue) + has, err := db.GetEngine(ctx).ID(id).Get(issue) if err != nil { return nil, err } else if !has { @@ -1132,7 +1130,7 @@ func getIssueByID(e db.Engine, id int64) (*Issue, error) { // GetIssueWithAttrsByID returns an issue with attributes by given ID. func GetIssueWithAttrsByID(id int64) (*Issue, error) { - issue, err := getIssueByID(db.GetEngine(db.DefaultContext), id) + issue, err := getIssueByID(db.DefaultContext, id) if err != nil { return nil, err } @@ -1141,30 +1139,22 @@ func GetIssueWithAttrsByID(id int64) (*Issue, error) { // GetIssueByID returns an issue by given ID. func GetIssueByID(id int64) (*Issue, error) { - return getIssueByID(db.GetEngine(db.DefaultContext), id) + return getIssueByID(db.DefaultContext, id) } -func getIssuesByIDs(e db.Engine, issueIDs []int64) ([]*Issue, error) { +// GetIssuesByIDs return issues with the given IDs. +func GetIssuesByIDs(ctx context.Context, issueIDs []int64) ([]*Issue, error) { issues := make([]*Issue, 0, 10) - return issues, e.In("id", issueIDs).Find(&issues) + return issues, db.GetEngine(ctx).In("id", issueIDs).Find(&issues) } -func getIssueIDsByRepoID(e db.Engine, repoID int64) ([]int64, error) { +// GetIssueIDsByRepoID returns all issue ids by repo id +func GetIssueIDsByRepoID(ctx context.Context, repoID int64) ([]int64, error) { ids := make([]int64, 0, 10) - err := e.Table("issue").Cols("id").Where("repo_id = ?", repoID).Find(&ids) + err := db.GetEngine(ctx).Table("issue").Cols("id").Where("repo_id = ?", repoID).Find(&ids) return ids, err } -// GetIssueIDsByRepoID returns all issue ids by repo id -func GetIssueIDsByRepoID(repoID int64) ([]int64, error) { - return getIssueIDsByRepoID(db.GetEngine(db.DefaultContext), repoID) -} - -// GetIssuesByIDs return issues with the given IDs. -func GetIssuesByIDs(issueIDs []int64) ([]*Issue, error) { - return getIssuesByIDs(db.GetEngine(db.DefaultContext), issueIDs) -} - // IssuesOptions represents options of an issue. type IssuesOptions struct { db.ListOptions @@ -1502,7 +1492,7 @@ func GetParticipantsIDsByIssueID(issueID int64) ([]int64, error) { // IsUserParticipantsOfIssue return true if user is participants of an issue func IsUserParticipantsOfIssue(user *user_model.User, issue *Issue) bool { - userIDs, err := issue.getParticipantIDsByIssue(db.GetEngine(db.DefaultContext)) + userIDs, err := issue.getParticipantIDsByIssue(db.DefaultContext) if err != nil { log.Error(err.Error()) return false @@ -1912,19 +1902,18 @@ func UpdateIssueByAPI(issue *Issue, doer *user_model.User) (statusChangeComment return nil, false, err } defer committer.Close() - sess := db.GetEngine(ctx) if err := issue.LoadRepo(ctx); err != nil { return nil, false, fmt.Errorf("loadRepo: %v", err) } // Reload the issue - currentIssue, err := getIssueByID(sess, issue.ID) + currentIssue, err := getIssueByID(ctx, issue.ID) if err != nil { return nil, false, err } - if _, err := sess.ID(issue.ID).Cols( + if _, err := db.GetEngine(ctx).ID(issue.ID).Cols( "name", "content", "milestone_id", "priority", "deadline_unix", "updated_unix", "is_locked"). Update(issue); err != nil { @@ -2000,7 +1989,8 @@ func DeleteIssue(issue *Issue) error { return committer.Commit() } -func deleteInIssue(e db.Engine, issueID int64, beans ...interface{}) error { +func deleteInIssue(ctx context.Context, issueID int64, beans ...interface{}) error { + e := db.GetEngine(ctx) for _, bean := range beans { if _, err := e.In("issue_id", issueID).Delete(bean); err != nil { return err @@ -2061,7 +2051,7 @@ func deleteIssue(ctx context.Context, issue *Issue) error { } // delete all database data still assigned to this issue - if err := deleteInIssue(e, issue.ID, + if err := deleteInIssue(ctx, issue.ID, &issues_model.ContentHistory{}, &Comment{}, &IssueLabel{}, @@ -2105,12 +2095,12 @@ type DependencyInfo struct { } // getParticipantIDsByIssue returns all userIDs who are participated in comments of an issue and issue author -func (issue *Issue) getParticipantIDsByIssue(e db.Engine) ([]int64, error) { +func (issue *Issue) getParticipantIDsByIssue(ctx context.Context) ([]int64, error) { if issue == nil { return nil, nil } userIDs := make([]int64, 0, 5) - if err := e.Table("comment").Cols("poster_id"). + if err := db.GetEngine(ctx).Table("comment").Cols("poster_id"). Where("`comment`.issue_id = ?", issue.ID). And("`comment`.type in (?,?,?)", CommentTypeComment, CommentTypeCode, CommentTypeReview). And("`user`.is_active = ?", true). @@ -2126,9 +2116,9 @@ func (issue *Issue) getParticipantIDsByIssue(e db.Engine) ([]int64, error) { return userIDs, nil } -// Get Blocked By Dependencies, aka all issues this issue is blocked by. -func (issue *Issue) getBlockedByDependencies(e db.Engine) (issueDeps []*DependencyInfo, err error) { - err = e. +// BlockedByDependencies finds all Dependencies an issue is blocked by +func (issue *Issue) BlockedByDependencies(ctx context.Context) (issueDeps []*DependencyInfo, err error) { + err = db.GetEngine(ctx). Table("issue"). Join("INNER", "repository", "repository.id = issue.repo_id"). Join("INNER", "issue_dependency", "issue_dependency.dependency_id = issue.id"). @@ -2144,9 +2134,9 @@ func (issue *Issue) getBlockedByDependencies(e db.Engine) (issueDeps []*Dependen return issueDeps, err } -// Get Blocking Dependencies, aka all issues this issue blocks. -func (issue *Issue) getBlockingDependencies(e db.Engine) (issueDeps []*DependencyInfo, err error) { - err = e. +// BlockingDependencies returns all blocking dependencies, aka all other issues a given issue blocks +func (issue *Issue) BlockingDependencies(ctx context.Context) (issueDeps []*DependencyInfo, err error) { + err = db.GetEngine(ctx). Table("issue"). Join("INNER", "repository", "repository.id = issue.repo_id"). Join("INNER", "issue_dependency", "issue_dependency.issue_id = issue.id"). @@ -2162,16 +2152,6 @@ func (issue *Issue) getBlockingDependencies(e db.Engine) (issueDeps []*Dependenc return issueDeps, err } -// BlockedByDependencies finds all Dependencies an issue is blocked by -func (issue *Issue) BlockedByDependencies() ([]*DependencyInfo, error) { - return issue.getBlockedByDependencies(db.GetEngine(db.DefaultContext)) -} - -// BlockingDependencies returns all blocking dependencies, aka all other issues a given issue blocks -func (issue *Issue) BlockingDependencies() ([]*DependencyInfo, error) { - return issue.getBlockingDependencies(db.GetEngine(db.DefaultContext)) -} - func updateIssueClosedNum(ctx context.Context, issue *Issue) (err error) { if issue.IsPull { err = repoStatsCorrectNumClosed(ctx, issue.RepoID, true, "num_closed_pulls") @@ -2354,9 +2334,10 @@ func UpdateReactionsMigrationsByType(gitServiceType api.GitServiceType, original return err } -func deleteIssuesByRepoID(sess db.Engine, repoID int64) (attachmentPaths []string, err error) { +func deleteIssuesByRepoID(ctx context.Context, repoID int64) (attachmentPaths []string, err error) { deleteCond := builder.Select("id").From("issue").Where(builder.Eq{"issue.repo_id": repoID}) + sess := db.GetEngine(ctx) // Delete content histories if _, err = sess.In("issue_id", deleteCond). Delete(&issues_model.ContentHistory{}); err != nil { @@ -2431,7 +2412,7 @@ func deleteIssuesByRepoID(sess db.Engine, repoID int64) (attachmentPaths []strin return } - if _, err = sess.Delete(&Issue{RepoID: repoID}); err != nil { + if _, err = db.DeleteByBean(ctx, &Issue{RepoID: repoID}); err != nil { return } diff --git a/models/issue_assignees.go b/models/issue_assignees.go index 0f1f7b657..c6ccb6e9d 100644 --- a/models/issue_assignees.go +++ b/models/issue_assignees.go @@ -25,20 +25,15 @@ func init() { } // LoadAssignees load assignees of this issue. -func (issue *Issue) LoadAssignees() error { - return issue.loadAssignees(db.GetEngine(db.DefaultContext)) -} - -// This loads all assignees of an issue -func (issue *Issue) loadAssignees(e db.Engine) (err error) { +func (issue *Issue) LoadAssignees(ctx context.Context) (err error) { // Reset maybe preexisting assignees issue.Assignees = []*user_model.User{} + issue.Assignee = nil - err = e.Table("`user`"). + err = db.GetEngine(ctx).Table("`user`"). Join("INNER", "issue_assignees", "assignee_id = `user`.id"). Where("issue_assignees.issue_id = ?", issue.ID). Find(&issue.Assignees) - if err != nil { return err } @@ -47,7 +42,6 @@ func (issue *Issue) loadAssignees(e db.Engine) (err error) { if len(issue.Assignees) > 0 { issue.Assignee = issue.Assignees[0] } - return } @@ -63,33 +57,9 @@ func GetAssigneeIDsByIssue(issueID int64) ([]int64, error) { Find(&userIDs) } -// GetAssigneesByIssue returns everyone assigned to that issue -func GetAssigneesByIssue(issue *Issue) (assignees []*user_model.User, err error) { - return getAssigneesByIssue(db.GetEngine(db.DefaultContext), issue) -} - -func getAssigneesByIssue(e db.Engine, issue *Issue) (assignees []*user_model.User, err error) { - err = issue.loadAssignees(e) - if err != nil { - return assignees, err - } - - return issue.Assignees, nil -} - // IsUserAssignedToIssue returns true when the user is assigned to the issue -func IsUserAssignedToIssue(issue *Issue, user *user_model.User) (isAssigned bool, err error) { - return isUserAssignedToIssue(db.GetEngine(db.DefaultContext), issue, user) -} - -func isUserAssignedToIssue(e db.Engine, issue *Issue, user *user_model.User) (isAssigned bool, err error) { - return e.Get(&IssueAssignees{IssueID: issue.ID, AssigneeID: user.ID}) -} - -// ClearAssigneeByUserID deletes all assignments of an user -func clearAssigneeByUserID(sess db.Engine, userID int64) (err error) { - _, err = sess.Delete(&IssueAssignees{AssigneeID: userID}) - return +func IsUserAssignedToIssue(ctx context.Context, issue *Issue, user *user_model.User) (isAssigned bool, err error) { + return db.GetByBean(ctx, &IssueAssignees{IssueID: issue.ID, AssigneeID: user.ID}) } // ToggleIssueAssignee changes a user between assigned and not assigned for this issue, and make issue comment for it. @@ -113,8 +83,7 @@ func ToggleIssueAssignee(issue *Issue, doer *user_model.User, assigneeID int64) } func toggleIssueAssignee(ctx context.Context, issue *Issue, doer *user_model.User, assigneeID int64, isCreate bool) (removed bool, comment *Comment, err error) { - sess := db.GetEngine(ctx) - removed, err = toggleUserAssignee(sess, issue, assigneeID) + removed, err = toggleUserAssignee(ctx, issue, assigneeID) if err != nil { return false, nil, fmt.Errorf("UpdateIssueUserByAssignee: %v", err) } @@ -147,39 +116,38 @@ func toggleIssueAssignee(ctx context.Context, issue *Issue, doer *user_model.Use } // toggles user assignee state in database -func toggleUserAssignee(e db.Engine, issue *Issue, assigneeID int64) (removed bool, err error) { +func toggleUserAssignee(ctx context.Context, issue *Issue, assigneeID int64) (removed bool, err error) { // Check if the user exists - assignee, err := user_model.GetUserByIDEngine(e, assigneeID) + assignee, err := user_model.GetUserByIDCtx(ctx, assigneeID) if err != nil { return false, err } // Check if the submitted user is already assigned, if yes delete him otherwise add him - var i int - for i = 0; i < len(issue.Assignees); i++ { + found := false + i := 0 + for ; i < len(issue.Assignees); i++ { if issue.Assignees[i].ID == assigneeID { + found = true break } } assigneeIn := IssueAssignees{AssigneeID: assigneeID, IssueID: issue.ID} - - toBeDeleted := i < len(issue.Assignees) - if toBeDeleted { - issue.Assignees = append(issue.Assignees[:i], issue.Assignees[i:]...) - _, err = e.Delete(assigneeIn) + if found { + issue.Assignees = append(issue.Assignees[:i], issue.Assignees[i+1:]...) + _, err = db.DeleteByBean(ctx, &assigneeIn) if err != nil { - return toBeDeleted, err + return found, err } } else { issue.Assignees = append(issue.Assignees, assignee) - _, err = e.Insert(assigneeIn) - if err != nil { - return toBeDeleted, err + if err = db.Insert(ctx, &assigneeIn); err != nil { + return found, err } } - return toBeDeleted, nil + return found, nil } // MakeIDsFromAPIAssigneesToAdd returns an array with all assignee IDs diff --git a/models/issue_assignees_test.go b/models/issue_assignees_test.go index 41a3ad86e..80317e160 100644 --- a/models/issue_assignees_test.go +++ b/models/issue_assignees_test.go @@ -7,6 +7,7 @@ package models import ( "testing" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" @@ -37,28 +38,28 @@ func TestUpdateAssignee(t *testing.T) { assert.NoError(t, err) // Check if he got removed - isAssigned, err := IsUserAssignedToIssue(issue, user1) + isAssigned, err := IsUserAssignedToIssue(db.DefaultContext, issue, user1) assert.NoError(t, err) assert.False(t, isAssigned) // Check if they're all there - assignees, err := GetAssigneesByIssue(issue) + err = issue.LoadAssignees(db.DefaultContext) assert.NoError(t, err) var expectedAssignees []*user_model.User expectedAssignees = append(expectedAssignees, user2, user3) - for in, assignee := range assignees { + for in, assignee := range issue.Assignees { assert.Equal(t, assignee.ID, expectedAssignees[in].ID) } // Check if the user is assigned - isAssigned, err = IsUserAssignedToIssue(issue, user2) + isAssigned, err = IsUserAssignedToIssue(db.DefaultContext, issue, user2) assert.NoError(t, err) assert.True(t, isAssigned) // This user should not be assigned - isAssigned, err = IsUserAssignedToIssue(issue, &user_model.User{ID: 4}) + isAssigned, err = IsUserAssignedToIssue(db.DefaultContext, issue, &user_model.User{ID: 4}) assert.NoError(t, err) assert.False(t, isAssigned) } diff --git a/models/issue_comment.go b/models/issue_comment.go index 2cf3d5a61..90c95afa4 100644 --- a/models/issue_comment.go +++ b/models/issue_comment.go @@ -298,7 +298,7 @@ func (c *Comment) LoadIssueCtx(ctx context.Context) (err error) { if c.Issue != nil { return nil } - c.Issue, err = getIssueByID(db.GetEngine(ctx), c.IssueID) + c.Issue, err = getIssueByID(ctx, c.IssueID) return } @@ -329,12 +329,12 @@ func (c *Comment) AfterLoad(session *xorm.Session) { } } -func (c *Comment) loadPoster(e db.Engine) (err error) { +func (c *Comment) loadPoster(ctx context.Context) (err error) { if c.PosterID <= 0 || c.Poster != nil { return nil } - c.Poster, err = user_model.GetUserByIDEngine(e, c.PosterID) + c.Poster, err = user_model.GetUserByIDCtx(ctx, c.PosterID) if err != nil { if user_model.IsErrUserNotExist(err) { c.PosterID = -1 @@ -525,7 +525,7 @@ func (c *Comment) LoadMilestone() error { // LoadPoster loads comment poster func (c *Comment) LoadPoster() error { - return c.loadPoster(db.GetEngine(db.DefaultContext)) + return c.loadPoster(db.DefaultContext) } // LoadAttachments loads attachments (it never returns error, the error during `GetAttachmentsByCommentIDCtx` is ignored) @@ -535,7 +535,7 @@ func (c *Comment) LoadAttachments() error { } var err error - c.Attachments, err = repo_model.GetAttachmentsByCommentIDCtx(db.DefaultContext, c.ID) + c.Attachments, err = repo_model.GetAttachmentsByCommentID(db.DefaultContext, c.ID) if err != nil { log.Error("getAttachmentsByCommentID[%d]: %v", c.ID, err) } @@ -557,7 +557,7 @@ func (c *Comment) UpdateAttachments(uuids []string) error { for i := 0; i < len(attachments); i++ { attachments[i].IssueID = c.IssueID attachments[i].CommentID = c.ID - if err := repo_model.UpdateAttachmentCtx(ctx, attachments[i]); err != nil { + if err := repo_model.UpdateAttachment(ctx, attachments[i]); err != nil { return fmt.Errorf("update attachment [id: %d]: %v", attachments[i].ID, err) } } @@ -590,7 +590,7 @@ func (c *Comment) LoadAssigneeUserAndTeam() error { } if c.Issue.Repo.Owner.IsOrganization() { - c.AssigneeTeam, err = organization.GetTeamByID(c.AssigneeTeamID) + c.AssigneeTeam, err = organization.GetTeamByID(db.DefaultContext, c.AssigneeTeamID) if err != nil && !organization.IsErrTeamNotExist(err) { return err } @@ -624,7 +624,7 @@ func (c *Comment) LoadDepIssueDetails() (err error) { if c.DependentIssueID <= 0 || c.DependentIssue != nil { return nil } - c.DependentIssue, err = getIssueByID(db.GetEngine(db.DefaultContext), c.DependentIssueID) + c.DependentIssue, err = getIssueByID(db.DefaultContext, c.DependentIssueID) return err } @@ -661,9 +661,9 @@ func (c *Comment) LoadReactions(repo *repo_model.Repository) error { return c.loadReactions(db.DefaultContext, repo) } -func (c *Comment) loadReview(e db.Engine) (err error) { +func (c *Comment) loadReview(ctx context.Context) (err error) { if c.Review == nil { - if c.Review, err = getReviewByID(e, c.ReviewID); err != nil { + if c.Review, err = GetReviewByID(ctx, c.ReviewID); err != nil { return err } } @@ -673,7 +673,7 @@ func (c *Comment) loadReview(e db.Engine) (err error) { // LoadReview loads the associated review func (c *Comment) LoadReview() error { - return c.loadReview(db.GetEngine(db.DefaultContext)) + return c.loadReview(db.DefaultContext) } var notEnoughLines = regexp.MustCompile(`fatal: file .* has only \d+ lines?`) @@ -830,13 +830,12 @@ func CreateCommentCtx(ctx context.Context, opts *CreateCommentOptions) (_ *Comme } func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment *Comment) (err error) { - e := db.GetEngine(ctx) // Check comment type. switch opts.Type { case CommentTypeCode: if comment.ReviewID != 0 { if comment.Review == nil { - if err := comment.loadReview(e); err != nil { + if err := comment.loadReview(ctx); err != nil { return err } } @@ -846,7 +845,7 @@ func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment } fallthrough case CommentTypeComment: - if _, err = e.Exec("UPDATE `issue` SET num_comments=num_comments+1 WHERE id=?", opts.Issue.ID); err != nil { + if _, err = db.Exec(ctx, "UPDATE `issue` SET num_comments=num_comments+1 WHERE id=?", opts.Issue.ID); err != nil { return err } fallthrough @@ -861,7 +860,7 @@ func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment attachments[i].IssueID = opts.Issue.ID attachments[i].CommentID = comment.ID // No assign value could be 0, so ignore AllCols(). - if _, err = e.ID(attachments[i].ID).Update(attachments[i]); err != nil { + if _, err = db.GetEngine(ctx).ID(attachments[i].ID).Update(attachments[i]); err != nil { return fmt.Errorf("update attachment [%d]: %v", attachments[i].ID, err) } } @@ -1031,13 +1030,9 @@ func CreateRefComment(doer *user_model.User, repo *repo_model.Repository, issue } // GetCommentByID returns the comment by given ID. -func GetCommentByID(id int64) (*Comment, error) { - return getCommentByID(db.GetEngine(db.DefaultContext), id) -} - -func getCommentByID(e db.Engine, id int64) (*Comment, error) { +func GetCommentByID(ctx context.Context, id int64) (*Comment, error) { c := new(Comment) - has, err := e.ID(id).Get(c) + has, err := db.GetEngine(ctx).ID(id).Get(c) if err != nil { return nil, err } else if !has { @@ -1088,9 +1083,10 @@ func (opts *FindCommentsOptions) toConds() builder.Cond { return cond } -func findComments(e db.Engine, opts *FindCommentsOptions) ([]*Comment, error) { +// FindComments returns all comments according options +func FindComments(ctx context.Context, opts *FindCommentsOptions) ([]*Comment, error) { comments := make([]*Comment, 0, 10) - sess := e.Where(opts.toConds()) + sess := db.GetEngine(ctx).Where(opts.toConds()) if opts.RepoID > 0 { sess.Join("INNER", "issue", "issue.id = comment.issue_id") } @@ -1107,11 +1103,6 @@ func findComments(e db.Engine, opts *FindCommentsOptions) ([]*Comment, error) { Find(&comments) } -// FindComments returns all comments according options -func FindComments(opts *FindCommentsOptions) ([]*Comment, error) { - return findComments(db.GetEngine(db.DefaultContext), opts) -} - // CountComments count all comments according options by ignoring pagination func CountComments(opts *FindCommentsOptions) (int64, error) { sess := db.GetEngine(db.DefaultContext).Where(opts.toConds()) @@ -1167,7 +1158,7 @@ func deleteComment(ctx context.Context, comment *Comment) error { return err } - if _, err := e.Delete(&issues_model.ContentHistory{ + if _, err := db.DeleteByBean(ctx, &issues_model.ContentHistory{ CommentID: comment.ID, }); err != nil { return err @@ -1182,7 +1173,7 @@ func deleteComment(ctx context.Context, comment *Comment) error { return err } - if err := comment.neuterCrossReferences(e); err != nil { + if err := comment.neuterCrossReferences(ctx); err != nil { return err } @@ -1192,7 +1183,8 @@ func deleteComment(ctx context.Context, comment *Comment) error { // CodeComments represents comments on code by using this structure: FILENAME -> LINE (+ == proposed; - == previous) -> COMMENTS type CodeComments map[string]map[int64][]*Comment -func fetchCodeComments(ctx context.Context, issue *Issue, currentUser *user_model.User) (CodeComments, error) { +// FetchCodeComments will return a 2d-map: ["Path"]["Line"] = Comments at line +func FetchCodeComments(ctx context.Context, issue *Issue, currentUser *user_model.User) (CodeComments, error) { return fetchCodeCommentsByReview(ctx, issue, currentUser, nil) } @@ -1242,7 +1234,7 @@ func findCodeComments(ctx context.Context, opts FindCommentsOptions, issue *Issu return nil, err } - if err := CommentList(comments).loadPosters(e); err != nil { + if err := CommentList(comments).loadPosters(ctx); err != nil { return nil, err } @@ -1302,11 +1294,6 @@ func FetchCodeCommentsByLine(ctx context.Context, issue *Issue, currentUser *use return findCodeComments(ctx, opts, issue, currentUser, nil) } -// FetchCodeComments will return a 2d-map: ["Path"]["Line"] = Comments at line -func FetchCodeComments(ctx context.Context, issue *Issue, currentUser *user_model.User) (CodeComments, error) { - return fetchCodeComments(ctx, issue, currentUser) -} - // UpdateCommentsMigrationsByType updates comments' migrations information via given git service type and original id and poster id func UpdateCommentsMigrationsByType(tp structs.GitServiceType, originalAuthorID string, posterID int64) error { _, err := db.GetEngine(db.DefaultContext).Table("comment"). diff --git a/models/issue_comment_list.go b/models/issue_comment_list.go index 4133fc876..d62984c1e 100644 --- a/models/issue_comment_list.go +++ b/models/issue_comment_list.go @@ -27,7 +27,7 @@ func (comments CommentList) getPosterIDs() []int64 { return container.KeysInt64(posterIDs) } -func (comments CommentList) loadPosters(e db.Engine) error { +func (comments CommentList) loadPosters(ctx context.Context) error { if len(comments) == 0 { return nil } @@ -40,7 +40,7 @@ func (comments CommentList) loadPosters(e db.Engine) error { if left < limit { limit = left } - err := e. + err := db.GetEngine(ctx). In("id", posterIDs[:limit]). Find(&posterMaps) if err != nil { @@ -80,7 +80,7 @@ func (comments CommentList) getLabelIDs() []int64 { return container.KeysInt64(ids) } -func (comments CommentList) loadLabels(e db.Engine) error { +func (comments CommentList) loadLabels(ctx context.Context) error { if len(comments) == 0 { return nil } @@ -93,7 +93,7 @@ func (comments CommentList) loadLabels(e db.Engine) error { if left < limit { limit = left } - rows, err := e. + rows, err := db.GetEngine(ctx). In("id", labelIDs[:limit]). Rows(new(Label)) if err != nil { @@ -130,7 +130,7 @@ func (comments CommentList) getMilestoneIDs() []int64 { return container.KeysInt64(ids) } -func (comments CommentList) loadMilestones(e db.Engine) error { +func (comments CommentList) loadMilestones(ctx context.Context) error { if len(comments) == 0 { return nil } @@ -147,7 +147,7 @@ func (comments CommentList) loadMilestones(e db.Engine) error { if left < limit { limit = left } - err := e. + err := db.GetEngine(ctx). In("id", milestoneIDs[:limit]). Find(&milestoneMaps) if err != nil { @@ -173,7 +173,7 @@ func (comments CommentList) getOldMilestoneIDs() []int64 { return container.KeysInt64(ids) } -func (comments CommentList) loadOldMilestones(e db.Engine) error { +func (comments CommentList) loadOldMilestones(ctx context.Context) error { if len(comments) == 0 { return nil } @@ -190,7 +190,7 @@ func (comments CommentList) loadOldMilestones(e db.Engine) error { if left < limit { limit = left } - err := e. + err := db.GetEngine(ctx). In("id", milestoneIDs[:limit]). Find(&milestoneMaps) if err != nil { @@ -216,7 +216,7 @@ func (comments CommentList) getAssigneeIDs() []int64 { return container.KeysInt64(ids) } -func (comments CommentList) loadAssignees(e db.Engine) error { +func (comments CommentList) loadAssignees(ctx context.Context) error { if len(comments) == 0 { return nil } @@ -229,7 +229,7 @@ func (comments CommentList) loadAssignees(e db.Engine) error { if left < limit { limit = left } - rows, err := e. + rows, err := db.GetEngine(ctx). In("id", assigneeIDs[:limit]). Rows(new(user_model.User)) if err != nil { @@ -290,7 +290,7 @@ func (comments CommentList) Issues() IssueList { return issueList } -func (comments CommentList) loadIssues(e db.Engine) error { +func (comments CommentList) loadIssues(ctx context.Context) error { if len(comments) == 0 { return nil } @@ -303,7 +303,7 @@ func (comments CommentList) loadIssues(e db.Engine) error { if left < limit { limit = left } - rows, err := e. + rows, err := db.GetEngine(ctx). In("id", issueIDs[:limit]). Rows(new(Issue)) if err != nil { @@ -397,7 +397,7 @@ func (comments CommentList) loadDependentIssues(ctx context.Context) error { return nil } -func (comments CommentList) loadAttachments(e db.Engine) (err error) { +func (comments CommentList) loadAttachments(ctx context.Context) (err error) { if len(comments) == 0 { return nil } @@ -410,7 +410,7 @@ func (comments CommentList) loadAttachments(e db.Engine) (err error) { if left < limit { limit = left } - rows, err := e.Table("attachment"). + rows, err := db.GetEngine(ctx).Table("attachment"). Join("INNER", "comment", "comment.id = attachment.comment_id"). In("comment.id", commentsIDs[:limit]). Rows(new(repo_model.Attachment)) @@ -449,7 +449,7 @@ func (comments CommentList) getReviewIDs() []int64 { return container.KeysInt64(ids) } -func (comments CommentList) loadReviews(e db.Engine) error { +func (comments CommentList) loadReviews(ctx context.Context) error { if len(comments) == 0 { return nil } @@ -462,7 +462,7 @@ func (comments CommentList) loadReviews(e db.Engine) error { if left < limit { limit = left } - rows, err := e. + rows, err := db.GetEngine(ctx). In("id", reviewIDs[:limit]). Rows(new(Review)) if err != nil { @@ -493,36 +493,35 @@ func (comments CommentList) loadReviews(e db.Engine) error { // loadAttributes loads all attributes func (comments CommentList) loadAttributes(ctx context.Context) (err error) { - e := db.GetEngine(ctx) - if err = comments.loadPosters(e); err != nil { + if err = comments.loadPosters(ctx); err != nil { return } - if err = comments.loadLabels(e); err != nil { + if err = comments.loadLabels(ctx); err != nil { return } - if err = comments.loadMilestones(e); err != nil { + if err = comments.loadMilestones(ctx); err != nil { return } - if err = comments.loadOldMilestones(e); err != nil { + if err = comments.loadOldMilestones(ctx); err != nil { return } - if err = comments.loadAssignees(e); err != nil { + if err = comments.loadAssignees(ctx); err != nil { return } - if err = comments.loadAttachments(e); err != nil { + if err = comments.loadAttachments(ctx); err != nil { return } - if err = comments.loadReviews(e); err != nil { + if err = comments.loadReviews(ctx); err != nil { return } - if err = comments.loadIssues(e); err != nil { + if err = comments.loadIssues(ctx); err != nil { return } @@ -541,15 +540,15 @@ func (comments CommentList) LoadAttributes() error { // LoadAttachments loads attachments func (comments CommentList) LoadAttachments() error { - return comments.loadAttachments(db.GetEngine(db.DefaultContext)) + return comments.loadAttachments(db.DefaultContext) } // LoadPosters loads posters func (comments CommentList) LoadPosters() error { - return comments.loadPosters(db.GetEngine(db.DefaultContext)) + return comments.loadPosters(db.DefaultContext) } // LoadIssues loads issues of comments func (comments CommentList) LoadIssues() error { - return comments.loadIssues(db.GetEngine(db.DefaultContext)) + return comments.loadIssues(db.DefaultContext) } diff --git a/models/issue_dependency.go b/models/issue_dependency.go index b292db57e..af40aa45d 100644 --- a/models/issue_dependency.go +++ b/models/issue_dependency.go @@ -42,10 +42,9 @@ func CreateIssueDependency(user *user_model.User, issue, dep *Issue) error { return err } defer committer.Close() - sess := db.GetEngine(ctx) // Check if it aleready exists - exists, err := issueDepExists(sess, issue.ID, dep.ID) + exists, err := issueDepExists(ctx, issue.ID, dep.ID) if err != nil { return err } @@ -53,7 +52,7 @@ func CreateIssueDependency(user *user_model.User, issue, dep *Issue) error { return ErrDependencyExists{issue.ID, dep.ID} } // And if it would be circular - circular, err := issueDepExists(sess, dep.ID, issue.ID) + circular, err := issueDepExists(ctx, dep.ID, issue.ID) if err != nil { return err } @@ -114,8 +113,8 @@ func RemoveIssueDependency(user *user_model.User, issue, dep *Issue, depType Dep } // Check if the dependency already exists -func issueDepExists(e db.Engine, issueID, depID int64) (bool, error) { - return e.Where("(issue_id = ? AND dependency_id = ?)", issueID, depID).Exist(&IssueDependency{}) +func issueDepExists(ctx context.Context, issueID, depID int64) (bool, error) { + return db.GetEngine(ctx).Where("(issue_id = ? AND dependency_id = ?)", issueID, depID).Exist(&IssueDependency{}) } // IssueNoDependenciesLeft checks if issue can be closed diff --git a/models/issue_label.go b/models/issue_label.go index d06915393..48a48dbb7 100644 --- a/models/issue_label.go +++ b/models/issue_label.go @@ -193,12 +193,12 @@ func UpdateLabel(l *Label) error { if !LabelColorPattern.MatchString(l.Color) { return fmt.Errorf("bad color code: %s", l.Color) } - return updateLabelCols(db.GetEngine(db.DefaultContext), l, "name", "description", "color") + return updateLabelCols(db.DefaultContext, l, "name", "description", "color") } // DeleteLabel delete a label func DeleteLabel(id, labelID int64) error { - label, err := GetLabelByID(labelID) + label, err := GetLabelByID(db.DefaultContext, labelID) if err != nil { if IsErrLabelNotExist(err) { return nil @@ -237,14 +237,14 @@ func DeleteLabel(id, labelID int64) error { return committer.Commit() } -// getLabelByID returns a label by label id -func getLabelByID(e db.Engine, labelID int64) (*Label, error) { +// GetLabelByID returns a label by given ID. +func GetLabelByID(ctx context.Context, labelID int64) (*Label, error) { if labelID <= 0 { return nil, ErrLabelNotExist{labelID} } l := &Label{} - has, err := e.ID(labelID).Get(l) + has, err := db.GetEngine(ctx).ID(labelID).Get(l) if err != nil { return nil, err } else if !has { @@ -253,11 +253,6 @@ func getLabelByID(e db.Engine, labelID int64) (*Label, error) { return l, nil } -// GetLabelByID returns a label by given ID. -func GetLabelByID(id int64) (*Label, error) { - return getLabelByID(db.GetEngine(db.DefaultContext), id) -} - // GetLabelsByIDs returns a list of labels by IDs func GetLabelsByIDs(labelIDs []int64) ([]*Label, error) { labels := make([]*Label, 0, len(labelIDs)) @@ -275,8 +270,8 @@ func GetLabelsByIDs(labelIDs []int64) ([]*Label, error) { // |____|_ /\___ > __/ \____/____ >__||__| \____/|__| / ____| // \/ \/|__| \/ \/ -// getLabelInRepoByName returns a label by Name in given repository. -func getLabelInRepoByName(e db.Engine, repoID int64, labelName string) (*Label, error) { +// GetLabelInRepoByName returns a label by name in given repository. +func GetLabelInRepoByName(ctx context.Context, repoID int64, labelName string) (*Label, error) { if len(labelName) == 0 || repoID <= 0 { return nil, ErrRepoLabelNotExist{0, repoID} } @@ -285,7 +280,7 @@ func getLabelInRepoByName(e db.Engine, repoID int64, labelName string) (*Label, Name: labelName, RepoID: repoID, } - has, err := e.Get(l) + has, err := db.GetByBean(ctx, l) if err != nil { return nil, err } else if !has { @@ -294,8 +289,8 @@ func getLabelInRepoByName(e db.Engine, repoID int64, labelName string) (*Label, return l, nil } -// getLabelInRepoByID returns a label by ID in given repository. -func getLabelInRepoByID(e db.Engine, repoID, labelID int64) (*Label, error) { +// GetLabelInRepoByID returns a label by ID in given repository. +func GetLabelInRepoByID(ctx context.Context, repoID, labelID int64) (*Label, error) { if labelID <= 0 || repoID <= 0 { return nil, ErrRepoLabelNotExist{labelID, repoID} } @@ -304,7 +299,7 @@ func getLabelInRepoByID(e db.Engine, repoID, labelID int64) (*Label, error) { ID: labelID, RepoID: repoID, } - has, err := e.Get(l) + has, err := db.GetByBean(ctx, l) if err != nil { return nil, err } else if !has { @@ -313,11 +308,6 @@ func getLabelInRepoByID(e db.Engine, repoID, labelID int64) (*Label, error) { return l, nil } -// GetLabelInRepoByName returns a label by name in given repository. -func GetLabelInRepoByName(repoID int64, labelName string) (*Label, error) { - return getLabelInRepoByName(db.GetEngine(db.DefaultContext), repoID, labelName) -} - // GetLabelIDsInRepoByNames returns a list of labelIDs by names in a given // repository. // it silently ignores label names that do not belong to the repository. @@ -342,11 +332,6 @@ func BuildLabelNamesIssueIDsCondition(labelNames []string) *builder.Builder { GroupBy("issue_label.issue_id") } -// GetLabelInRepoByID returns a label by ID in given repository. -func GetLabelInRepoByID(repoID, labelID int64) (*Label, error) { - return getLabelInRepoByID(db.GetEngine(db.DefaultContext), repoID, labelID) -} - // GetLabelsInRepoByIDs returns a list of labels by IDs in given repository, // it silently ignores label IDs that do not belong to the repository. func GetLabelsInRepoByIDs(repoID int64, labelIDs []int64) ([]*Label, error) { @@ -358,12 +343,13 @@ func GetLabelsInRepoByIDs(repoID int64, labelIDs []int64) ([]*Label, error) { Find(&labels) } -func getLabelsByRepoID(e db.Engine, repoID int64, sortType string, listOptions db.ListOptions) ([]*Label, error) { +// GetLabelsByRepoID returns all labels that belong to given repository by ID. +func GetLabelsByRepoID(ctx context.Context, repoID int64, sortType string, listOptions db.ListOptions) ([]*Label, error) { if repoID <= 0 { return nil, ErrRepoLabelNotExist{0, repoID} } labels := make([]*Label, 0, 10) - sess := e.Where("repo_id = ?", repoID) + sess := db.GetEngine(ctx).Where("repo_id = ?", repoID) switch sortType { case "reversealphabetically": @@ -383,11 +369,6 @@ func getLabelsByRepoID(e db.Engine, repoID int64, sortType string, listOptions d return labels, sess.Find(&labels) } -// GetLabelsByRepoID returns all labels that belong to given repository by ID. -func GetLabelsByRepoID(repoID int64, sortType string, listOptions db.ListOptions) ([]*Label, error) { - return getLabelsByRepoID(db.GetEngine(db.DefaultContext), repoID, sortType, listOptions) -} - // CountLabelsByRepoID count number of all labels that belong to given repository by ID. func CountLabelsByRepoID(repoID int64) (int64, error) { return db.GetEngine(db.DefaultContext).Where("repo_id = ?", repoID).Count(&Label{}) @@ -400,8 +381,8 @@ func CountLabelsByRepoID(repoID int64) (int64, error) { // \_______ /__| \___ / // \/ /_____/ -// getLabelInOrgByName returns a label by Name in given organization -func getLabelInOrgByName(e db.Engine, orgID int64, labelName string) (*Label, error) { +// GetLabelInOrgByName returns a label by name in given organization. +func GetLabelInOrgByName(ctx context.Context, orgID int64, labelName string) (*Label, error) { if len(labelName) == 0 || orgID <= 0 { return nil, ErrOrgLabelNotExist{0, orgID} } @@ -410,7 +391,7 @@ func getLabelInOrgByName(e db.Engine, orgID int64, labelName string) (*Label, er Name: labelName, OrgID: orgID, } - has, err := e.Get(l) + has, err := db.GetByBean(ctx, l) if err != nil { return nil, err } else if !has { @@ -419,8 +400,8 @@ func getLabelInOrgByName(e db.Engine, orgID int64, labelName string) (*Label, er return l, nil } -// getLabelInOrgByID returns a label by ID in given organization. -func getLabelInOrgByID(e db.Engine, orgID, labelID int64) (*Label, error) { +// GetLabelInOrgByID returns a label by ID in given organization. +func GetLabelInOrgByID(ctx context.Context, orgID, labelID int64) (*Label, error) { if labelID <= 0 || orgID <= 0 { return nil, ErrOrgLabelNotExist{labelID, orgID} } @@ -429,7 +410,7 @@ func getLabelInOrgByID(e db.Engine, orgID, labelID int64) (*Label, error) { ID: labelID, OrgID: orgID, } - has, err := e.Get(l) + has, err := db.GetByBean(ctx, l) if err != nil { return nil, err } else if !has { @@ -438,11 +419,6 @@ func getLabelInOrgByID(e db.Engine, orgID, labelID int64) (*Label, error) { return l, nil } -// GetLabelInOrgByName returns a label by name in given organization. -func GetLabelInOrgByName(orgID int64, labelName string) (*Label, error) { - return getLabelInOrgByName(db.GetEngine(db.DefaultContext), orgID, labelName) -} - // GetLabelIDsInOrgByNames returns a list of labelIDs by names in a given // organization. func GetLabelIDsInOrgByNames(orgID int64, labelNames []string) ([]int64, error) { @@ -459,11 +435,6 @@ func GetLabelIDsInOrgByNames(orgID int64, labelNames []string) ([]int64, error) Find(&labelIDs) } -// GetLabelInOrgByID returns a label by ID in given organization. -func GetLabelInOrgByID(orgID, labelID int64) (*Label, error) { - return getLabelInOrgByID(db.GetEngine(db.DefaultContext), orgID, labelID) -} - // GetLabelsInOrgByIDs returns a list of labels by IDs in given organization, // it silently ignores label IDs that do not belong to the organization. func GetLabelsInOrgByIDs(orgID int64, labelIDs []int64) ([]*Label, error) { @@ -475,12 +446,13 @@ func GetLabelsInOrgByIDs(orgID int64, labelIDs []int64) ([]*Label, error) { Find(&labels) } -func getLabelsByOrgID(e db.Engine, orgID int64, sortType string, listOptions db.ListOptions) ([]*Label, error) { +// GetLabelsByOrgID returns all labels that belong to given organization by ID. +func GetLabelsByOrgID(ctx context.Context, orgID int64, sortType string, listOptions db.ListOptions) ([]*Label, error) { if orgID <= 0 { return nil, ErrOrgLabelNotExist{0, orgID} } labels := make([]*Label, 0, 10) - sess := e.Where("org_id = ?", orgID) + sess := db.GetEngine(ctx).Where("org_id = ?", orgID) switch sortType { case "reversealphabetically": @@ -500,11 +472,6 @@ func getLabelsByOrgID(e db.Engine, orgID int64, sortType string, listOptions db. return labels, sess.Find(&labels) } -// GetLabelsByOrgID returns all labels that belong to given organization by ID. -func GetLabelsByOrgID(orgID int64, sortType string, listOptions db.ListOptions) ([]*Label, error) { - return getLabelsByOrgID(db.GetEngine(db.DefaultContext), orgID, sortType, listOptions) -} - // CountLabelsByOrgID count all labels that belong to given organization by ID. func CountLabelsByOrgID(orgID int64) (int64, error) { return db.GetEngine(db.DefaultContext).Where("org_id = ?", orgID).Count(&Label{}) @@ -517,21 +484,17 @@ func CountLabelsByOrgID(orgID int64) (int64, error) { // |___/____ >____ >____/ \___ | // \/ \/ \/ -func getLabelsByIssueID(e db.Engine, issueID int64) ([]*Label, error) { +// GetLabelsByIssueID returns all labels that belong to given issue by ID. +func GetLabelsByIssueID(ctx context.Context, issueID int64) ([]*Label, error) { var labels []*Label - return labels, e.Where("issue_label.issue_id = ?", issueID). + return labels, db.GetEngine(ctx).Where("issue_label.issue_id = ?", issueID). Join("LEFT", "issue_label", "issue_label.label_id = label.id"). Asc("label.name"). Find(&labels) } -// GetLabelsByIssueID returns all labels that belong to given issue by ID. -func GetLabelsByIssueID(issueID int64) ([]*Label, error) { - return getLabelsByIssueID(db.GetEngine(db.DefaultContext), issueID) -} - -func updateLabelCols(e db.Engine, l *Label, cols ...string) error { - _, err := e.ID(l.ID). +func updateLabelCols(ctx context.Context, l *Label, cols ...string) error { + _, err := db.GetEngine(ctx).ID(l.ID). SetExpr("num_issues", builder.Select("count(*)").From("issue_label"). Where(builder.Eq{"label_id": l.ID}), @@ -562,21 +525,16 @@ type IssueLabel struct { LabelID int64 `xorm:"UNIQUE(s)"` } -func hasIssueLabel(e db.Engine, issueID, labelID int64) bool { - has, _ := e.Where("issue_id = ? AND label_id = ?", issueID, labelID).Get(new(IssueLabel)) - return has -} - // HasIssueLabel returns true if issue has been labeled. -func HasIssueLabel(issueID, labelID int64) bool { - return hasIssueLabel(db.GetEngine(db.DefaultContext), issueID, labelID) +func HasIssueLabel(ctx context.Context, issueID, labelID int64) bool { + has, _ := db.GetEngine(ctx).Where("issue_id = ? AND label_id = ?", issueID, labelID).Get(new(IssueLabel)) + return has } // newIssueLabel this function creates a new label it does not check if the label is valid for the issue // YOU MUST CHECK THIS BEFORE THIS FUNCTION func newIssueLabel(ctx context.Context, issue *Issue, label *Label, doer *user_model.User) (err error) { - e := db.GetEngine(ctx) - if _, err = e.Insert(&IssueLabel{ + if err = db.Insert(ctx, &IssueLabel{ IssueID: issue.ID, LabelID: label.ID, }); err != nil { @@ -599,12 +557,12 @@ func newIssueLabel(ctx context.Context, issue *Issue, label *Label, doer *user_m return err } - return updateLabelCols(e, label, "num_issues", "num_closed_issue") + return updateLabelCols(ctx, label, "num_issues", "num_closed_issue") } // NewIssueLabel creates a new issue-label relation. func NewIssueLabel(issue *Issue, label *Label, doer *user_model.User) (err error) { - if HasIssueLabel(issue.ID, label.ID) { + if HasIssueLabel(db.DefaultContext, issue.ID, label.ID) { return nil } @@ -637,13 +595,12 @@ func NewIssueLabel(issue *Issue, label *Label, doer *user_model.User) (err error // newIssueLabels add labels to an issue. It will check if the labels are valid for the issue func newIssueLabels(ctx context.Context, issue *Issue, labels []*Label, doer *user_model.User) (err error) { - e := db.GetEngine(ctx) if err = issue.LoadRepo(ctx); err != nil { return err } for _, label := range labels { // Don't add already present labels and invalid labels - if hasIssueLabel(e, issue.ID, label.ID) || + if HasIssueLabel(ctx, issue.ID, label.ID) || (label.RepoID != issue.RepoID && label.OrgID != issue.Repo.OwnerID) { continue } @@ -677,8 +634,7 @@ func NewIssueLabels(issue *Issue, labels []*Label, doer *user_model.User) (err e } func deleteIssueLabel(ctx context.Context, issue *Issue, label *Label, doer *user_model.User) (err error) { - e := db.GetEngine(ctx) - if count, err := e.Delete(&IssueLabel{ + if count, err := db.DeleteByBean(ctx, &IssueLabel{ IssueID: issue.ID, LabelID: label.ID, }); err != nil { @@ -702,7 +658,7 @@ func deleteIssueLabel(ctx context.Context, issue *Issue, label *Label, doer *use return err } - return updateLabelCols(e, label, "num_issues", "num_closed_issue") + return updateLabelCols(ctx, label, "num_issues", "num_closed_issue") } // DeleteIssueLabel deletes issue-label relation. @@ -715,14 +671,14 @@ func DeleteIssueLabel(ctx context.Context, issue *Issue, label *Label, doer *use return issue.LoadLabels(ctx) } -func deleteLabelsByRepoID(sess db.Engine, repoID int64) error { +func deleteLabelsByRepoID(ctx context.Context, repoID int64) error { deleteCond := builder.Select("id").From("label").Where(builder.Eq{"label.repo_id": repoID}) - if _, err := sess.In("label_id", deleteCond). + if _, err := db.GetEngine(ctx).In("label_id", deleteCond). Delete(&IssueLabel{}); err != nil { return err } - _, err := sess.Delete(&Label{RepoID: repoID}) + _, err := db.DeleteByBean(ctx, &Label{RepoID: repoID}) return err } diff --git a/models/issue_label_test.go b/models/issue_label_test.go index 2dd0cf98e..67a09151d 100644 --- a/models/issue_label_test.go +++ b/models/issue_label_test.go @@ -59,25 +59,25 @@ func TestNewLabels(t *testing.T) { func TestGetLabelByID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - label, err := GetLabelByID(1) + label, err := GetLabelByID(db.DefaultContext, 1) assert.NoError(t, err) assert.EqualValues(t, 1, label.ID) - _, err = GetLabelByID(unittest.NonexistentID) + _, err = GetLabelByID(db.DefaultContext, unittest.NonexistentID) assert.True(t, IsErrLabelNotExist(err)) } func TestGetLabelInRepoByName(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - label, err := GetLabelInRepoByName(1, "label1") + label, err := GetLabelInRepoByName(db.DefaultContext, 1, "label1") assert.NoError(t, err) assert.EqualValues(t, 1, label.ID) assert.Equal(t, "label1", label.Name) - _, err = GetLabelInRepoByName(1, "") + _, err = GetLabelInRepoByName(db.DefaultContext, 1, "") assert.True(t, IsErrRepoLabelNotExist(err)) - _, err = GetLabelInRepoByName(unittest.NonexistentID, "nonexistent") + _, err = GetLabelInRepoByName(db.DefaultContext, unittest.NonexistentID, "nonexistent") assert.True(t, IsErrRepoLabelNotExist(err)) } @@ -107,14 +107,14 @@ func TestGetLabelInRepoByNamesDiscardsNonExistentLabels(t *testing.T) { func TestGetLabelInRepoByID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - label, err := GetLabelInRepoByID(1, 1) + label, err := GetLabelInRepoByID(db.DefaultContext, 1, 1) assert.NoError(t, err) assert.EqualValues(t, 1, label.ID) - _, err = GetLabelInRepoByID(1, -1) + _, err = GetLabelInRepoByID(db.DefaultContext, 1, -1) assert.True(t, IsErrRepoLabelNotExist(err)) - _, err = GetLabelInRepoByID(unittest.NonexistentID, unittest.NonexistentID) + _, err = GetLabelInRepoByID(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID) assert.True(t, IsErrRepoLabelNotExist(err)) } @@ -131,7 +131,7 @@ func TestGetLabelsInRepoByIDs(t *testing.T) { func TestGetLabelsByRepoID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) testSuccess := func(repoID int64, sortType string, expectedIssueIDs []int64) { - labels, err := GetLabelsByRepoID(repoID, sortType, db.ListOptions{}) + labels, err := GetLabelsByRepoID(db.DefaultContext, repoID, sortType, db.ListOptions{}) assert.NoError(t, err) assert.Len(t, labels, len(expectedIssueIDs)) for i, label := range labels { @@ -148,21 +148,21 @@ func TestGetLabelsByRepoID(t *testing.T) { func TestGetLabelInOrgByName(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - label, err := GetLabelInOrgByName(3, "orglabel3") + label, err := GetLabelInOrgByName(db.DefaultContext, 3, "orglabel3") assert.NoError(t, err) assert.EqualValues(t, 3, label.ID) assert.Equal(t, "orglabel3", label.Name) - _, err = GetLabelInOrgByName(3, "") + _, err = GetLabelInOrgByName(db.DefaultContext, 3, "") assert.True(t, IsErrOrgLabelNotExist(err)) - _, err = GetLabelInOrgByName(0, "orglabel3") + _, err = GetLabelInOrgByName(db.DefaultContext, 0, "orglabel3") assert.True(t, IsErrOrgLabelNotExist(err)) - _, err = GetLabelInOrgByName(-1, "orglabel3") + _, err = GetLabelInOrgByName(db.DefaultContext, -1, "orglabel3") assert.True(t, IsErrOrgLabelNotExist(err)) - _, err = GetLabelInOrgByName(unittest.NonexistentID, "nonexistent") + _, err = GetLabelInOrgByName(db.DefaultContext, unittest.NonexistentID, "nonexistent") assert.True(t, IsErrOrgLabelNotExist(err)) } @@ -192,20 +192,20 @@ func TestGetLabelInOrgByNamesDiscardsNonExistentLabels(t *testing.T) { func TestGetLabelInOrgByID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - label, err := GetLabelInOrgByID(3, 3) + label, err := GetLabelInOrgByID(db.DefaultContext, 3, 3) assert.NoError(t, err) assert.EqualValues(t, 3, label.ID) - _, err = GetLabelInOrgByID(3, -1) + _, err = GetLabelInOrgByID(db.DefaultContext, 3, -1) assert.True(t, IsErrOrgLabelNotExist(err)) - _, err = GetLabelInOrgByID(0, 3) + _, err = GetLabelInOrgByID(db.DefaultContext, 0, 3) assert.True(t, IsErrOrgLabelNotExist(err)) - _, err = GetLabelInOrgByID(-1, 3) + _, err = GetLabelInOrgByID(db.DefaultContext, -1, 3) assert.True(t, IsErrOrgLabelNotExist(err)) - _, err = GetLabelInOrgByID(unittest.NonexistentID, unittest.NonexistentID) + _, err = GetLabelInOrgByID(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID) assert.True(t, IsErrOrgLabelNotExist(err)) } @@ -222,7 +222,7 @@ func TestGetLabelsInOrgByIDs(t *testing.T) { func TestGetLabelsByOrgID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) testSuccess := func(orgID int64, sortType string, expectedIssueIDs []int64) { - labels, err := GetLabelsByOrgID(orgID, sortType, db.ListOptions{}) + labels, err := GetLabelsByOrgID(db.DefaultContext, orgID, sortType, db.ListOptions{}) assert.NoError(t, err) assert.Len(t, labels, len(expectedIssueIDs)) for i, label := range labels { @@ -235,10 +235,10 @@ func TestGetLabelsByOrgID(t *testing.T) { testSuccess(3, "default", []int64{3, 4}) var err error - _, err = GetLabelsByOrgID(0, "leastissues", db.ListOptions{}) + _, err = GetLabelsByOrgID(db.DefaultContext, 0, "leastissues", db.ListOptions{}) assert.True(t, IsErrOrgLabelNotExist(err)) - _, err = GetLabelsByOrgID(-1, "leastissues", db.ListOptions{}) + _, err = GetLabelsByOrgID(db.DefaultContext, -1, "leastissues", db.ListOptions{}) assert.True(t, IsErrOrgLabelNotExist(err)) } @@ -246,13 +246,13 @@ func TestGetLabelsByOrgID(t *testing.T) { func TestGetLabelsByIssueID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - labels, err := GetLabelsByIssueID(1) + labels, err := GetLabelsByIssueID(db.DefaultContext, 1) assert.NoError(t, err) if assert.Len(t, labels, 1) { assert.EqualValues(t, 1, labels[0].ID) } - labels, err = GetLabelsByIssueID(unittest.NonexistentID) + labels, err = GetLabelsByIssueID(db.DefaultContext, unittest.NonexistentID) assert.NoError(t, err) assert.Len(t, labels, 0) } @@ -293,9 +293,9 @@ func TestDeleteLabel(t *testing.T) { func TestHasIssueLabel(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.True(t, HasIssueLabel(1, 1)) - assert.False(t, HasIssueLabel(1, 2)) - assert.False(t, HasIssueLabel(unittest.NonexistentID, unittest.NonexistentID)) + assert.True(t, HasIssueLabel(db.DefaultContext, 1, 1)) + assert.False(t, HasIssueLabel(db.DefaultContext, 1, 2)) + assert.False(t, HasIssueLabel(db.DefaultContext, unittest.NonexistentID, unittest.NonexistentID)) } func TestNewIssueLabel(t *testing.T) { diff --git a/models/issue_list.go b/models/issue_list.go index 3116b49d8..4a8f72a48 100644 --- a/models/issue_list.go +++ b/models/issue_list.go @@ -5,6 +5,7 @@ package models import ( + "context" "fmt" "code.gitea.io/gitea/models/db" @@ -37,7 +38,7 @@ func (issues IssueList) getRepoIDs() []int64 { return container.KeysInt64(repoIDs) } -func (issues IssueList) loadRepositories(e db.Engine) ([]*repo_model.Repository, error) { +func (issues IssueList) loadRepositories(ctx context.Context) ([]*repo_model.Repository, error) { if len(issues) == 0 { return nil, nil } @@ -50,7 +51,7 @@ func (issues IssueList) loadRepositories(e db.Engine) ([]*repo_model.Repository, if left < limit { limit = left } - err := e. + err := db.GetEngine(ctx). In("id", repoIDs[:limit]). Find(&repoMaps) if err != nil { @@ -75,7 +76,7 @@ func (issues IssueList) loadRepositories(e db.Engine) ([]*repo_model.Repository, // LoadRepositories loads issues' all repositories func (issues IssueList) LoadRepositories() ([]*repo_model.Repository, error) { - return issues.loadRepositories(db.GetEngine(db.DefaultContext)) + return issues.loadRepositories(db.DefaultContext) } func (issues IssueList) getPosterIDs() []int64 { @@ -88,7 +89,7 @@ func (issues IssueList) getPosterIDs() []int64 { return container.KeysInt64(posterIDs) } -func (issues IssueList) loadPosters(e db.Engine) error { +func (issues IssueList) loadPosters(ctx context.Context) error { if len(issues) == 0 { return nil } @@ -101,7 +102,7 @@ func (issues IssueList) loadPosters(e db.Engine) error { if left < limit { limit = left } - err := e. + err := db.GetEngine(ctx). In("id", posterIDs[:limit]). Find(&posterMaps) if err != nil { @@ -131,7 +132,7 @@ func (issues IssueList) getIssueIDs() []int64 { return ids } -func (issues IssueList) loadLabels(e db.Engine) error { +func (issues IssueList) loadLabels(ctx context.Context) error { if len(issues) == 0 { return nil } @@ -149,7 +150,7 @@ func (issues IssueList) loadLabels(e db.Engine) error { if left < limit { limit = left } - rows, err := e.Table("label"). + rows, err := db.GetEngine(ctx).Table("label"). Join("LEFT", "issue_label", "issue_label.label_id = label.id"). In("issue_label.issue_id", issueIDs[:limit]). Asc("label.name"). @@ -194,7 +195,7 @@ func (issues IssueList) getMilestoneIDs() []int64 { return container.KeysInt64(ids) } -func (issues IssueList) loadMilestones(e db.Engine) error { +func (issues IssueList) loadMilestones(ctx context.Context) error { milestoneIDs := issues.getMilestoneIDs() if len(milestoneIDs) == 0 { return nil @@ -207,7 +208,7 @@ func (issues IssueList) loadMilestones(e db.Engine) error { if left < limit { limit = left } - err := e. + err := db.GetEngine(ctx). In("id", milestoneIDs[:limit]). Find(&milestoneMaps) if err != nil { @@ -223,7 +224,7 @@ func (issues IssueList) loadMilestones(e db.Engine) error { return nil } -func (issues IssueList) loadAssignees(e db.Engine) error { +func (issues IssueList) loadAssignees(ctx context.Context) error { if len(issues) == 0 { return nil } @@ -241,7 +242,7 @@ func (issues IssueList) loadAssignees(e db.Engine) error { if left < limit { limit = left } - rows, err := e.Table("issue_assignees"). + rows, err := db.GetEngine(ctx).Table("issue_assignees"). Join("INNER", "`user`", "`user`.id = `issue_assignees`.assignee_id"). In("`issue_assignees`.issue_id", issueIDs[:limit]). Rows(new(AssigneeIssue)) @@ -284,7 +285,7 @@ func (issues IssueList) getPullIssueIDs() []int64 { return ids } -func (issues IssueList) loadPullRequests(e db.Engine) error { +func (issues IssueList) loadPullRequests(ctx context.Context) error { issuesIDs := issues.getPullIssueIDs() if len(issuesIDs) == 0 { return nil @@ -297,7 +298,7 @@ func (issues IssueList) loadPullRequests(e db.Engine) error { if left < limit { limit = left } - rows, err := e. + rows, err := db.GetEngine(ctx). In("issue_id", issuesIDs[:limit]). Rows(new(PullRequest)) if err != nil { @@ -328,7 +329,7 @@ func (issues IssueList) loadPullRequests(e db.Engine) error { return nil } -func (issues IssueList) loadAttachments(e db.Engine) (err error) { +func (issues IssueList) loadAttachments(ctx context.Context) (err error) { if len(issues) == 0 { return nil } @@ -341,7 +342,7 @@ func (issues IssueList) loadAttachments(e db.Engine) (err error) { if left < limit { limit = left } - rows, err := e.Table("attachment"). + rows, err := db.GetEngine(ctx).Table("attachment"). Join("INNER", "issue", "issue.id = attachment.issue_id"). In("issue.id", issuesIDs[:limit]). Rows(new(repo_model.Attachment)) @@ -373,7 +374,7 @@ func (issues IssueList) loadAttachments(e db.Engine) (err error) { return nil } -func (issues IssueList) loadComments(e db.Engine, cond builder.Cond) (err error) { +func (issues IssueList) loadComments(ctx context.Context, cond builder.Cond) (err error) { if len(issues) == 0 { return nil } @@ -386,7 +387,7 @@ func (issues IssueList) loadComments(e db.Engine, cond builder.Cond) (err error) if left < limit { limit = left } - rows, err := e.Table("comment"). + rows, err := db.GetEngine(ctx).Table("comment"). Join("INNER", "issue", "issue.id = comment.issue_id"). In("issue.id", issuesIDs[:limit]). Where(cond). @@ -419,7 +420,7 @@ func (issues IssueList) loadComments(e db.Engine, cond builder.Cond) (err error) return nil } -func (issues IssueList) loadTotalTrackedTimes(e db.Engine) (err error) { +func (issues IssueList) loadTotalTrackedTimes(ctx context.Context) (err error) { type totalTimesByIssue struct { IssueID int64 Time int64 @@ -444,7 +445,7 @@ func (issues IssueList) loadTotalTrackedTimes(e db.Engine) (err error) { } // select issue_id, sum(time) from tracked_time where issue_id in () group by issue_id - rows, err := e.Table("tracked_time"). + rows, err := db.GetEngine(ctx).Table("tracked_time"). Where("deleted = ?", false). Select("issue_id, sum(time) as time"). In("issue_id", ids[:limit]). @@ -479,32 +480,32 @@ func (issues IssueList) loadTotalTrackedTimes(e db.Engine) (err error) { } // loadAttributes loads all attributes, expect for attachments and comments -func (issues IssueList) loadAttributes(e db.Engine) error { - if _, err := issues.loadRepositories(e); err != nil { +func (issues IssueList) loadAttributes(ctx context.Context) error { + if _, err := issues.loadRepositories(ctx); err != nil { return fmt.Errorf("issue.loadAttributes: loadRepositories: %v", err) } - if err := issues.loadPosters(e); err != nil { + if err := issues.loadPosters(ctx); err != nil { return fmt.Errorf("issue.loadAttributes: loadPosters: %v", err) } - if err := issues.loadLabels(e); err != nil { + if err := issues.loadLabels(ctx); err != nil { return fmt.Errorf("issue.loadAttributes: loadLabels: %v", err) } - if err := issues.loadMilestones(e); err != nil { + if err := issues.loadMilestones(ctx); err != nil { return fmt.Errorf("issue.loadAttributes: loadMilestones: %v", err) } - if err := issues.loadAssignees(e); err != nil { + if err := issues.loadAssignees(ctx); err != nil { return fmt.Errorf("issue.loadAttributes: loadAssignees: %v", err) } - if err := issues.loadPullRequests(e); err != nil { + if err := issues.loadPullRequests(ctx); err != nil { return fmt.Errorf("issue.loadAttributes: loadPullRequests: %v", err) } - if err := issues.loadTotalTrackedTimes(e); err != nil { + if err := issues.loadTotalTrackedTimes(ctx); err != nil { return fmt.Errorf("issue.loadAttributes: loadTotalTrackedTimes: %v", err) } @@ -514,42 +515,38 @@ func (issues IssueList) loadAttributes(e db.Engine) error { // LoadAttributes loads attributes of the issues, except for attachments and // comments func (issues IssueList) LoadAttributes() error { - return issues.loadAttributes(db.GetEngine(db.DefaultContext)) + return issues.loadAttributes(db.DefaultContext) } // LoadAttachments loads attachments func (issues IssueList) LoadAttachments() error { - return issues.loadAttachments(db.GetEngine(db.DefaultContext)) + return issues.loadAttachments(db.DefaultContext) } // LoadComments loads comments func (issues IssueList) LoadComments() error { - return issues.loadComments(db.GetEngine(db.DefaultContext), builder.NewCond()) + return issues.loadComments(db.DefaultContext, builder.NewCond()) } // LoadDiscussComments loads discuss comments func (issues IssueList) LoadDiscussComments() error { - return issues.loadComments(db.GetEngine(db.DefaultContext), builder.Eq{"comment.type": CommentTypeComment}) + return issues.loadComments(db.DefaultContext, builder.Eq{"comment.type": CommentTypeComment}) } // LoadPullRequests loads pull requests func (issues IssueList) LoadPullRequests() error { - return issues.loadPullRequests(db.GetEngine(db.DefaultContext)) + return issues.loadPullRequests(db.DefaultContext) } // GetApprovalCounts returns a map of issue ID to slice of approval counts // FIXME: only returns official counts due to double counting of non-official approvals -func (issues IssueList) GetApprovalCounts() (map[int64][]*ReviewCount, error) { - return issues.getApprovalCounts(db.GetEngine(db.DefaultContext)) -} - -func (issues IssueList) getApprovalCounts(e db.Engine) (map[int64][]*ReviewCount, error) { +func (issues IssueList) GetApprovalCounts(ctx context.Context) (map[int64][]*ReviewCount, error) { rCounts := make([]*ReviewCount, 0, 2*len(issues)) ids := make([]int64, len(issues)) for i, issue := range issues { ids[i] = issue.ID } - sess := e.In("issue_id", ids) + sess := db.GetEngine(ctx).In("issue_id", ids) err := sess.Select("issue_id, type, count(id) as `count`"). Where("official = ? AND dismissed = ?", true, false). GroupBy("issue_id, type"). diff --git a/models/issue_project.go b/models/issue_project.go index 0e993b39c..0f8c61977 100644 --- a/models/issue_project.go +++ b/models/issue_project.go @@ -15,13 +15,13 @@ import ( // LoadProject load the project the issue was assigned to func (i *Issue) LoadProject() (err error) { - return i.loadProject(db.GetEngine(db.DefaultContext)) + return i.loadProject(db.DefaultContext) } -func (i *Issue) loadProject(e db.Engine) (err error) { +func (i *Issue) loadProject(ctx context.Context) (err error) { if i.Project == nil { var p project_model.Project - if _, err = e.Table("project"). + if _, err = db.GetEngine(ctx).Table("project"). Join("INNER", "project_issue", "project.id=project_issue.project_id"). Where("project_issue.issue_id = ?", i.ID). Get(&p); err != nil { @@ -34,12 +34,12 @@ func (i *Issue) loadProject(e db.Engine) (err error) { // ProjectID return project id if issue was assigned to one func (i *Issue) ProjectID() int64 { - return i.projectID(db.GetEngine(db.DefaultContext)) + return i.projectID(db.DefaultContext) } -func (i *Issue) projectID(e db.Engine) int64 { +func (i *Issue) projectID(ctx context.Context) int64 { var ip project_model.ProjectIssue - has, err := e.Where("issue_id=?", i.ID).Get(&ip) + has, err := db.GetEngine(ctx).Where("issue_id=?", i.ID).Get(&ip) if err != nil || !has { return 0 } @@ -48,12 +48,12 @@ func (i *Issue) projectID(e db.Engine) int64 { // ProjectBoardID return project board id if issue was assigned to one func (i *Issue) ProjectBoardID() int64 { - return i.projectBoardID(db.GetEngine(db.DefaultContext)) + return i.projectBoardID(db.DefaultContext) } -func (i *Issue) projectBoardID(e db.Engine) int64 { +func (i *Issue) projectBoardID(ctx context.Context) int64 { var ip project_model.ProjectIssue - has, err := e.Where("issue_id=?", i.ID).Get(&ip) + has, err := db.GetEngine(ctx).Where("issue_id=?", i.ID).Get(&ip) if err != nil || !has { return 0 } @@ -122,10 +122,9 @@ func ChangeProjectAssign(issue *Issue, doer *user_model.User, newProjectID int64 } func addUpdateIssueProject(ctx context.Context, issue *Issue, doer *user_model.User, newProjectID int64) error { - e := db.GetEngine(ctx) - oldProjectID := issue.projectID(e) + oldProjectID := issue.projectID(ctx) - if _, err := e.Where("project_issue.issue_id=?", issue.ID).Delete(&project_model.ProjectIssue{}); err != nil { + if _, err := db.GetEngine(ctx).Where("project_issue.issue_id=?", issue.ID).Delete(&project_model.ProjectIssue{}); err != nil { return err } @@ -146,11 +145,10 @@ func addUpdateIssueProject(ctx context.Context, issue *Issue, doer *user_model.U } } - _, err := e.Insert(&project_model.ProjectIssue{ + return db.Insert(ctx, &project_model.ProjectIssue{ IssueID: issue.ID, ProjectID: newProjectID, }) - return err } // MoveIssueAcrossProjectBoards move a card from one board to another diff --git a/models/issue_stopwatch.go b/models/issue_stopwatch.go index 81459ba44..5b9f5c402 100644 --- a/models/issue_stopwatch.go +++ b/models/issue_stopwatch.go @@ -125,13 +125,9 @@ func StopwatchExists(userID, issueID int64) bool { } // HasUserStopwatch returns true if the user has a stopwatch -func HasUserStopwatch(userID int64) (exists bool, sw *Stopwatch, err error) { - return hasUserStopwatch(db.GetEngine(db.DefaultContext), userID) -} - -func hasUserStopwatch(e db.Engine, userID int64) (exists bool, sw *Stopwatch, err error) { +func HasUserStopwatch(ctx context.Context, userID int64) (exists bool, sw *Stopwatch, err error) { sw = new(Stopwatch) - exists, err = e. + exists, err = db.GetEngine(ctx). Where("user_id = ?", userID). Get(sw) return @@ -203,24 +199,23 @@ func FinishIssueStopwatch(ctx context.Context, user *user_model.User, issue *Iss }); err != nil { return err } - _, err = db.GetEngine(ctx).Delete(sw) + _, err = db.DeleteByBean(ctx, sw) return err } // CreateIssueStopwatch creates a stopwatch if not exist, otherwise return an error func CreateIssueStopwatch(ctx context.Context, user *user_model.User, issue *Issue) error { - e := db.GetEngine(ctx) if err := issue.LoadRepo(ctx); err != nil { return err } // if another stopwatch is running: stop it - exists, sw, err := hasUserStopwatch(e, user.ID) + exists, sw, err := HasUserStopwatch(ctx, user.ID) if err != nil { return err } if exists { - issue, err := getIssueByID(e, sw.IssueID) + issue, err := getIssueByID(ctx, sw.IssueID) if err != nil { return err } diff --git a/models/issue_stopwatch_test.go b/models/issue_stopwatch_test.go index 0f578f101..15d5f234f 100644 --- a/models/issue_stopwatch_test.go +++ b/models/issue_stopwatch_test.go @@ -7,6 +7,7 @@ package models import ( "testing" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/timeutil" @@ -44,12 +45,12 @@ func TestStopwatchExists(t *testing.T) { func TestHasUserStopwatch(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - exists, sw, err := HasUserStopwatch(1) + exists, sw, err := HasUserStopwatch(db.DefaultContext, 1) assert.NoError(t, err) assert.True(t, exists) assert.Equal(t, int64(1), sw.ID) - exists, _, err = HasUserStopwatch(3) + exists, _, err = HasUserStopwatch(db.DefaultContext, 3) assert.NoError(t, err) assert.False(t, exists) } diff --git a/models/issue_test.go b/models/issue_test.go index 9a8b7bd53..5b2f461a8 100644 --- a/models/issue_test.go +++ b/models/issue_test.go @@ -52,7 +52,7 @@ func TestIssue_ReplaceLabels(t *testing.T) { func Test_GetIssueIDsByRepoID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - ids, err := GetIssueIDsByRepoID(1) + ids, err := GetIssueIDsByRepoID(db.DefaultContext, 1) assert.NoError(t, err) assert.Len(t, ids, 5) } @@ -69,7 +69,7 @@ func TestIssueAPIURL(t *testing.T) { func TestGetIssuesByIDs(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) testSuccess := func(expectedIssueIDs, nonExistentIssueIDs []int64) { - issues, err := GetIssuesByIDs(append(expectedIssueIDs, nonExistentIssueIDs...)) + issues, err := GetIssuesByIDs(db.DefaultContext, append(expectedIssueIDs, nonExistentIssueIDs...)) assert.NoError(t, err) actualIssueIDs := make([]int64, len(issues)) for i, issue := range issues { @@ -87,7 +87,7 @@ func TestGetParticipantIDsByIssue(t *testing.T) { checkParticipants := func(issueID int64, userIDs []int) { issue, err := GetIssueByID(issueID) assert.NoError(t, err) - participants, err := issue.getParticipantIDsByIssue(db.GetEngine(db.DefaultContext)) + participants, err := issue.getParticipantIDsByIssue(db.DefaultContext) if assert.NoError(t, err) { participantsIDs := make([]int, len(participants)) for i, uid := range participants { @@ -317,7 +317,7 @@ func TestIssue_loadTotalTimes(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) ms, err := GetIssueByID(2) assert.NoError(t, err) - assert.NoError(t, ms.loadTotalTimes(db.GetEngine(db.DefaultContext))) + assert.NoError(t, ms.loadTotalTimes(db.DefaultContext)) assert.Equal(t, int64(3682), ms.TotalTrackedTime) } @@ -419,7 +419,7 @@ func TestIssue_InsertIssue(t *testing.T) { func TestIssue_DeleteIssue(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - issueIDs, err := GetIssueIDsByRepoID(1) + issueIDs, err := GetIssueIDsByRepoID(db.DefaultContext, 1) assert.NoError(t, err) assert.EqualValues(t, 5, len(issueIDs)) @@ -430,12 +430,12 @@ func TestIssue_DeleteIssue(t *testing.T) { err = DeleteIssue(issue) assert.NoError(t, err) - issueIDs, err = GetIssueIDsByRepoID(1) + issueIDs, err = GetIssueIDsByRepoID(db.DefaultContext, 1) assert.NoError(t, err) assert.EqualValues(t, 4, len(issueIDs)) // check attachment removal - attachments, err := repo_model.GetAttachmentsByIssueID(4) + attachments, err := repo_model.GetAttachmentsByIssueID(db.DefaultContext, 4) assert.NoError(t, err) issue, err = GetIssueByID(4) assert.NoError(t, err) @@ -443,7 +443,7 @@ func TestIssue_DeleteIssue(t *testing.T) { assert.NoError(t, err) assert.EqualValues(t, 2, len(attachments)) for i := range attachments { - attachment, err := repo_model.GetAttachmentByUUID(attachments[i].UUID) + attachment, err := repo_model.GetAttachmentByUUID(db.DefaultContext, attachments[i].UUID) assert.Error(t, err) assert.True(t, repo_model.IsErrAttachmentNotExist(err)) assert.Nil(t, attachment) diff --git a/models/issue_tracked_time.go b/models/issue_tracked_time.go index 76ff874c5..30b3905bb 100644 --- a/models/issue_tracked_time.go +++ b/models/issue_tracked_time.go @@ -47,9 +47,8 @@ func (t *TrackedTime) LoadAttributes() (err error) { } func (t *TrackedTime) loadAttributes(ctx context.Context) (err error) { - e := db.GetEngine(ctx) if t.Issue == nil { - t.Issue, err = getIssueByID(e, t.IssueID) + t.Issue, err = getIssueByID(ctx, t.IssueID) if err != nil { return } @@ -59,7 +58,7 @@ func (t *TrackedTime) loadAttributes(ctx context.Context) (err error) { } } if t.User == nil { - t.User, err = user_model.GetUserByIDEngine(e, t.UserID) + t.User, err = user_model.GetUserByIDCtx(ctx, t.UserID) if err != nil { return } @@ -128,14 +127,10 @@ func (opts *FindTrackedTimesOptions) toSession(e db.Engine) db.Engine { return sess } -func getTrackedTimes(e db.Engine, options *FindTrackedTimesOptions) (trackedTimes TrackedTimeList, err error) { - err = options.toSession(e).Find(&trackedTimes) - return -} - // GetTrackedTimes returns all tracked times that fit to the given options. -func GetTrackedTimes(opts *FindTrackedTimesOptions) (TrackedTimeList, error) { - return getTrackedTimes(db.GetEngine(db.DefaultContext), opts) +func GetTrackedTimes(ctx context.Context, options *FindTrackedTimesOptions) (trackedTimes TrackedTimeList, err error) { + err = options.toSession(db.GetEngine(ctx)).Find(&trackedTimes) + return } // CountTrackedTimes returns count of tracked times that fit to the given options. @@ -147,13 +142,9 @@ func CountTrackedTimes(opts *FindTrackedTimesOptions) (int64, error) { return sess.Count(&TrackedTime{}) } -func getTrackedSeconds(e db.Engine, opts FindTrackedTimesOptions) (trackedSeconds int64, err error) { - return opts.toSession(e).SumInt(&TrackedTime{}, "time") -} - // GetTrackedSeconds return sum of seconds -func GetTrackedSeconds(opts FindTrackedTimesOptions) (int64, error) { - return getTrackedSeconds(db.GetEngine(db.DefaultContext), opts) +func GetTrackedSeconds(ctx context.Context, opts FindTrackedTimesOptions) (trackedSeconds int64, err error) { + return opts.toSession(db.GetEngine(ctx)).SumInt(&TrackedTime{}, "time") } // AddTime will add the given time (in seconds) to the issue @@ -163,9 +154,8 @@ func AddTime(user *user_model.User, issue *Issue, amount int64, created time.Tim return nil, err } defer committer.Close() - sess := db.GetEngine(ctx) - t, err := addTime(sess, user, issue, amount, created) + t, err := addTime(ctx, user, issue, amount, created) if err != nil { return nil, err } @@ -188,7 +178,7 @@ func AddTime(user *user_model.User, issue *Issue, amount int64, created time.Tim return t, committer.Commit() } -func addTime(e db.Engine, user *user_model.User, issue *Issue, amount int64, created time.Time) (*TrackedTime, error) { +func addTime(ctx context.Context, user *user_model.User, issue *Issue, amount int64, created time.Time) (*TrackedTime, error) { if created.IsZero() { created = time.Now() } @@ -198,16 +188,12 @@ func addTime(e db.Engine, user *user_model.User, issue *Issue, amount int64, cre Time: amount, Created: created, } - if _, err := e.Insert(tt); err != nil { - return nil, err - } - - return tt, nil + return tt, db.Insert(ctx, tt) } // TotalTimes returns the spent time for each user by an issue func TotalTimes(options *FindTrackedTimesOptions) (map[*user_model.User]string, error) { - trackedTimes, err := GetTrackedTimes(options) + trackedTimes, err := GetTrackedTimes(db.DefaultContext, options) if err != nil { return nil, err } @@ -239,14 +225,13 @@ func DeleteIssueUserTimes(issue *Issue, user *user_model.User) error { return err } defer committer.Close() - sess := db.GetEngine(ctx) opts := FindTrackedTimesOptions{ IssueID: issue.ID, UserID: user.ID, } - removedTime, err := deleteTimes(sess, opts) + removedTime, err := deleteTimes(ctx, opts) if err != nil { return err } @@ -282,7 +267,7 @@ func DeleteTime(t *TrackedTime) error { return err } - if err := deleteTime(db.GetEngine(ctx), t); err != nil { + if err := deleteTime(ctx, t); err != nil { return err } @@ -299,22 +284,22 @@ func DeleteTime(t *TrackedTime) error { return committer.Commit() } -func deleteTimes(e db.Engine, opts FindTrackedTimesOptions) (removedTime int64, err error) { - removedTime, err = getTrackedSeconds(e, opts) +func deleteTimes(ctx context.Context, opts FindTrackedTimesOptions) (removedTime int64, err error) { + removedTime, err = GetTrackedSeconds(ctx, opts) if err != nil || removedTime == 0 { return } - _, err = opts.toSession(e).Table("tracked_time").Cols("deleted").Update(&TrackedTime{Deleted: true}) + _, err = opts.toSession(db.GetEngine(ctx)).Table("tracked_time").Cols("deleted").Update(&TrackedTime{Deleted: true}) return } -func deleteTime(e db.Engine, t *TrackedTime) error { +func deleteTime(ctx context.Context, t *TrackedTime) error { if t.Deleted { return db.ErrNotExist{ID: t.ID} } t.Deleted = true - _, err := e.ID(t.ID).Cols("deleted").Update(t) + _, err := db.GetEngine(ctx).ID(t.ID).Cols("deleted").Update(t) return err } diff --git a/models/issue_tracked_time_test.go b/models/issue_tracked_time_test.go index 68e78c71c..a62832971 100644 --- a/models/issue_tracked_time_test.go +++ b/models/issue_tracked_time_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" @@ -41,27 +42,27 @@ func TestGetTrackedTimes(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) // by Issue - times, err := GetTrackedTimes(&FindTrackedTimesOptions{IssueID: 1}) + times, err := GetTrackedTimes(db.DefaultContext, &FindTrackedTimesOptions{IssueID: 1}) assert.NoError(t, err) assert.Len(t, times, 1) assert.Equal(t, int64(400), times[0].Time) - times, err = GetTrackedTimes(&FindTrackedTimesOptions{IssueID: -1}) + times, err = GetTrackedTimes(db.DefaultContext, &FindTrackedTimesOptions{IssueID: -1}) assert.NoError(t, err) assert.Len(t, times, 0) // by User - times, err = GetTrackedTimes(&FindTrackedTimesOptions{UserID: 1}) + times, err = GetTrackedTimes(db.DefaultContext, &FindTrackedTimesOptions{UserID: 1}) assert.NoError(t, err) assert.Len(t, times, 3) assert.Equal(t, int64(400), times[0].Time) - times, err = GetTrackedTimes(&FindTrackedTimesOptions{UserID: 3}) + times, err = GetTrackedTimes(db.DefaultContext, &FindTrackedTimesOptions{UserID: 3}) assert.NoError(t, err) assert.Len(t, times, 0) // by Repo - times, err = GetTrackedTimes(&FindTrackedTimesOptions{RepositoryID: 2}) + times, err = GetTrackedTimes(db.DefaultContext, &FindTrackedTimesOptions{RepositoryID: 2}) assert.NoError(t, err) assert.Len(t, times, 3) assert.Equal(t, int64(1), times[0].Time) @@ -69,11 +70,11 @@ func TestGetTrackedTimes(t *testing.T) { assert.NoError(t, err) assert.Equal(t, issue.RepoID, int64(2)) - times, err = GetTrackedTimes(&FindTrackedTimesOptions{RepositoryID: 1}) + times, err = GetTrackedTimes(db.DefaultContext, &FindTrackedTimesOptions{RepositoryID: 1}) assert.NoError(t, err) assert.Len(t, times, 5) - times, err = GetTrackedTimes(&FindTrackedTimesOptions{RepositoryID: 10}) + times, err = GetTrackedTimes(db.DefaultContext, &FindTrackedTimesOptions{RepositoryID: 10}) assert.NoError(t, err) assert.Len(t, times, 0) } diff --git a/models/issue_watch.go b/models/issue_watch.go index 92dc84741..9f41d36e1 100644 --- a/models/issue_watch.go +++ b/models/issue_watch.go @@ -5,6 +5,8 @@ package models import ( + "context" + "code.gitea.io/gitea/models/db" repo_model "code.gitea.io/gitea/models/repo" user_model "code.gitea.io/gitea/models/user" @@ -30,7 +32,7 @@ type IssueWatchList []*IssueWatch // CreateOrUpdateIssueWatch set watching for a user and issue func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error { - iw, exists, err := getIssueWatch(db.GetEngine(db.DefaultContext), userID, issueID) + iw, exists, err := GetIssueWatch(db.DefaultContext, userID, issueID) if err != nil { return err } @@ -57,14 +59,9 @@ func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error { // GetIssueWatch returns all IssueWatch objects from db by user and issue // the current Web-UI need iw object for watchers AND explicit non-watchers -func GetIssueWatch(userID, issueID int64) (iw *IssueWatch, exists bool, err error) { - return getIssueWatch(db.GetEngine(db.DefaultContext), userID, issueID) -} - -// Return watcher AND explicit non-watcher if entry in db exist -func getIssueWatch(e db.Engine, userID, issueID int64) (iw *IssueWatch, exists bool, err error) { +func GetIssueWatch(ctx context.Context, userID, issueID int64) (iw *IssueWatch, exists bool, err error) { iw = new(IssueWatch) - exists, err = e. + exists, err = db.GetEngine(ctx). Where("user_id = ?", userID). And("issue_id = ?", issueID). Get(iw) @@ -74,7 +71,7 @@ func getIssueWatch(e db.Engine, userID, issueID int64) (iw *IssueWatch, exists b // CheckIssueWatch check if an user is watching an issue // it takes participants and repo watch into account func CheckIssueWatch(user *user_model.User, issue *Issue) (bool, error) { - iw, exist, err := getIssueWatch(db.GetEngine(db.DefaultContext), user.ID, issue.ID) + iw, exist, err := GetIssueWatch(db.DefaultContext, user.ID, issue.ID) if err != nil { return false, err } @@ -91,13 +88,9 @@ func CheckIssueWatch(user *user_model.User, issue *Issue) (bool, error) { // GetIssueWatchersIDs returns IDs of subscribers or explicit unsubscribers to a given issue id // but avoids joining with `user` for performance reasons // User permissions must be verified elsewhere if required -func GetIssueWatchersIDs(issueID int64, watching bool) ([]int64, error) { - return getIssueWatchersIDs(db.GetEngine(db.DefaultContext), issueID, watching) -} - -func getIssueWatchersIDs(e db.Engine, issueID int64, watching bool) ([]int64, error) { +func GetIssueWatchersIDs(ctx context.Context, issueID int64, watching bool) ([]int64, error) { ids := make([]int64, 0, 64) - return ids, e.Table("issue_watch"). + return ids, db.GetEngine(ctx).Table("issue_watch"). Where("issue_id=?", issueID). And("is_watching = ?", watching). Select("user_id"). @@ -105,12 +98,8 @@ func getIssueWatchersIDs(e db.Engine, issueID int64, watching bool) ([]int64, er } // GetIssueWatchers returns watchers/unwatchers of a given issue -func GetIssueWatchers(issueID int64, listOptions db.ListOptions) (IssueWatchList, error) { - return getIssueWatchers(db.GetEngine(db.DefaultContext), issueID, listOptions) -} - -func getIssueWatchers(e db.Engine, issueID int64, listOptions db.ListOptions) (IssueWatchList, error) { - sess := e. +func GetIssueWatchers(ctx context.Context, issueID int64, listOptions db.ListOptions) (IssueWatchList, error) { + sess := db.GetEngine(ctx). Where("`issue_watch`.issue_id = ?", issueID). And("`issue_watch`.is_watching = ?", true). And("`user`.is_active = ?", true). @@ -127,12 +116,8 @@ func getIssueWatchers(e db.Engine, issueID int64, listOptions db.ListOptions) (I } // CountIssueWatchers count watchers/unwatchers of a given issue -func CountIssueWatchers(issueID int64) (int64, error) { - return countIssueWatchers(db.GetEngine(db.DefaultContext), issueID) -} - -func countIssueWatchers(e db.Engine, issueID int64) (int64, error) { - return e. +func CountIssueWatchers(ctx context.Context, issueID int64) (int64, error) { + return db.GetEngine(ctx). Where("`issue_watch`.issue_id = ?", issueID). And("`issue_watch`.is_watching = ?", true). And("`user`.is_active = ?", true). @@ -140,8 +125,8 @@ func countIssueWatchers(e db.Engine, issueID int64) (int64, error) { Join("INNER", "`user`", "`user`.id = `issue_watch`.user_id").Count(new(IssueWatch)) } -func removeIssueWatchersByRepoID(e db.Engine, userID, repoID int64) error { - _, err := e. +func removeIssueWatchersByRepoID(ctx context.Context, userID, repoID int64) error { + _, err := db.GetEngine(ctx). Join("INNER", "issue", "`issue`.id = `issue_watch`.issue_id AND `issue`.repo_id = ?", repoID). Where("`issue_watch`.user_id = ?", userID). Delete(new(IssueWatch)) diff --git a/models/issue_watch_test.go b/models/issue_watch_test.go index f75677d68..b686196ae 100644 --- a/models/issue_watch_test.go +++ b/models/issue_watch_test.go @@ -28,16 +28,16 @@ func TestCreateOrUpdateIssueWatch(t *testing.T) { func TestGetIssueWatch(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - _, exists, err := GetIssueWatch(9, 1) + _, exists, err := GetIssueWatch(db.DefaultContext, 9, 1) assert.True(t, exists) assert.NoError(t, err) - iw, exists, err := GetIssueWatch(2, 2) + iw, exists, err := GetIssueWatch(db.DefaultContext, 2, 2) assert.True(t, exists) assert.NoError(t, err) assert.False(t, iw.IsWatching) - _, exists, err = GetIssueWatch(3, 1) + _, exists, err = GetIssueWatch(db.DefaultContext, 3, 1) assert.False(t, exists) assert.NoError(t, err) } @@ -45,22 +45,22 @@ func TestGetIssueWatch(t *testing.T) { func TestGetIssueWatchers(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - iws, err := GetIssueWatchers(1, db.ListOptions{}) + iws, err := GetIssueWatchers(db.DefaultContext, 1, db.ListOptions{}) assert.NoError(t, err) // Watcher is inactive, thus 0 assert.Len(t, iws, 0) - iws, err = GetIssueWatchers(2, db.ListOptions{}) + iws, err = GetIssueWatchers(db.DefaultContext, 2, db.ListOptions{}) assert.NoError(t, err) // Watcher is explicit not watching assert.Len(t, iws, 0) - iws, err = GetIssueWatchers(5, db.ListOptions{}) + iws, err = GetIssueWatchers(db.DefaultContext, 5, db.ListOptions{}) assert.NoError(t, err) // Issue has no Watchers assert.Len(t, iws, 0) - iws, err = GetIssueWatchers(7, db.ListOptions{}) + iws, err = GetIssueWatchers(db.DefaultContext, 7, db.ListOptions{}) assert.NoError(t, err) // Issue has one watcher assert.Len(t, iws, 1) diff --git a/models/issue_xref.go b/models/issue_xref.go index 3c9e67c3f..0c1623b5a 100644 --- a/models/issue_xref.go +++ b/models/issue_xref.go @@ -30,16 +30,16 @@ type crossReferencesContext struct { RemoveOld bool } -func findOldCrossReferences(e db.Engine, issueID, commentID int64) ([]*Comment, error) { +func findOldCrossReferences(ctx context.Context, issueID, commentID int64) ([]*Comment, error) { active := make([]*Comment, 0, 10) - return active, e.Where("`ref_action` IN (?, ?, ?)", references.XRefActionNone, references.XRefActionCloses, references.XRefActionReopens). + return active, db.GetEngine(ctx).Where("`ref_action` IN (?, ?, ?)", references.XRefActionNone, references.XRefActionCloses, references.XRefActionReopens). And("`ref_issue_id` = ?", issueID). And("`ref_comment_id` = ?", commentID). Find(&active) } -func neuterCrossReferences(e db.Engine, issueID, commentID int64) error { - active, err := findOldCrossReferences(e, issueID, commentID) +func neuterCrossReferences(ctx context.Context, issueID, commentID int64) error { + active, err := findOldCrossReferences(ctx, issueID, commentID) if err != nil { return err } @@ -47,11 +47,11 @@ func neuterCrossReferences(e db.Engine, issueID, commentID int64) error { for i, c := range active { ids[i] = c.ID } - return neuterCrossReferencesIds(e, ids) + return neuterCrossReferencesIds(ctx, ids) } -func neuterCrossReferencesIds(e db.Engine, ids []int64) error { - _, err := e.In("id", ids).Cols("`ref_action`").Update(&Comment{RefAction: references.XRefActionNeutered}) +func neuterCrossReferencesIds(ctx context.Context, ids []int64) error { + _, err := db.GetEngine(ctx).In("id", ids).Cols("`ref_action`").Update(&Comment{RefAction: references.XRefActionNeutered}) return err } @@ -80,7 +80,6 @@ func (issue *Issue) addCrossReferences(stdCtx context.Context, doer *user_model. } func (issue *Issue) createCrossReferences(stdCtx context.Context, ctx *crossReferencesContext, plaincontent, mdcontent string) error { - e := db.GetEngine(stdCtx) xreflist, err := ctx.OrigIssue.getCrossReferences(stdCtx, ctx, plaincontent, mdcontent) if err != nil { return err @@ -90,7 +89,7 @@ func (issue *Issue) createCrossReferences(stdCtx context.Context, ctx *crossRefe if ctx.OrigComment != nil { commentID = ctx.OrigComment.ID } - active, err := findOldCrossReferences(e, ctx.OrigIssue.ID, commentID) + active, err := findOldCrossReferences(stdCtx, ctx.OrigIssue.ID, commentID) if err != nil { return err } @@ -109,7 +108,7 @@ func (issue *Issue) createCrossReferences(stdCtx context.Context, ctx *crossRefe } } if len(ids) > 0 { - if err = neuterCrossReferencesIds(e, ids); err != nil { + if err = neuterCrossReferencesIds(stdCtx, ids); err != nil { return err } } @@ -263,8 +262,8 @@ func (comment *Comment) addCrossReferences(stdCtx context.Context, doer *user_mo return comment.Issue.createCrossReferences(stdCtx, ctx, "", comment.Content) } -func (comment *Comment) neuterCrossReferences(e db.Engine) error { - return neuterCrossReferences(e, comment.IssueID, comment.ID) +func (comment *Comment) neuterCrossReferences(ctx context.Context) error { + return neuterCrossReferences(ctx, comment.IssueID, comment.ID) } // LoadRefComment loads comment that created this reference from database @@ -272,7 +271,7 @@ func (comment *Comment) LoadRefComment() (err error) { if comment.RefComment != nil { return nil } - comment.RefComment, err = GetCommentByID(comment.RefCommentID) + comment.RefComment, err = GetCommentByID(db.DefaultContext, comment.RefCommentID) return } diff --git a/models/issue_xref_test.go b/models/issue_xref_test.go index 677caa48b..b4ad5b270 100644 --- a/models/issue_xref_test.go +++ b/models/issue_xref_test.go @@ -150,7 +150,7 @@ func testCreateIssue(t *testing.T, repo, doer int64, title, content string, ispu Issue: i, }) assert.NoError(t, err) - i, err = getIssueByID(db.GetEngine(ctx), i.ID) + i, err = getIssueByID(ctx, i.ID) assert.NoError(t, err) assert.NoError(t, i.addCrossReferences(ctx, d, false)) assert.NoError(t, committer.Commit()) diff --git a/models/issues/content_history.go b/models/issues/content_history.go index 13aadcb1e..4c5af13db 100644 --- a/models/issues/content_history.go +++ b/models/issues/content_history.go @@ -38,7 +38,7 @@ func init() { } // SaveIssueContentHistory save history -func SaveIssueContentHistory(e db.Engine, posterID, issueID, commentID int64, editTime timeutil.TimeStamp, contentText string, isFirstCreated bool) error { +func SaveIssueContentHistory(ctx context.Context, posterID, issueID, commentID int64, editTime timeutil.TimeStamp, contentText string, isFirstCreated bool) error { ch := &ContentHistory{ PosterID: posterID, IssueID: issueID, @@ -47,27 +47,26 @@ func SaveIssueContentHistory(e db.Engine, posterID, issueID, commentID int64, ed EditedUnix: editTime, IsFirstCreated: isFirstCreated, } - _, err := e.Insert(ch) - if err != nil { + if err := db.Insert(ctx, ch); err != nil { log.Error("can not save issue content history. err=%v", err) return err } // We only keep at most 20 history revisions now. It is enough in most cases. // If there is a special requirement to keep more, we can consider introducing a new setting option then, but not now. - keepLimitedContentHistory(e, issueID, commentID, 20) + keepLimitedContentHistory(ctx, issueID, commentID, 20) return nil } // keepLimitedContentHistory keeps at most `limit` history revisions, it will hard delete out-dated revisions, sorting by revision interval // we can ignore all errors in this function, so we just log them -func keepLimitedContentHistory(e db.Engine, issueID, commentID int64, limit int) { +func keepLimitedContentHistory(ctx context.Context, issueID, commentID int64, limit int) { type IDEditTime struct { ID int64 EditedUnix timeutil.TimeStamp } var res []*IDEditTime - err := e.Select("id, edited_unix").Table("issue_content_history"). + err := db.GetEngine(ctx).Select("id, edited_unix").Table("issue_content_history"). Where(builder.Eq{"issue_id": issueID, "comment_id": commentID}). OrderBy("edited_unix ASC"). Find(&res) @@ -96,7 +95,7 @@ func keepLimitedContentHistory(e db.Engine, issueID, commentID int64, limit int) } // hard delete the found one - _, err = e.Delete(&ContentHistory{ID: res[indexToDelete].ID}) + _, err = db.GetEngine(ctx).Delete(&ContentHistory{ID: res[indexToDelete].ID}) if err != nil { log.Error("can not delete out-dated content history, err=%v", err) break diff --git a/models/issues/content_history_test.go b/models/issues/content_history_test.go index 71ccc6e6a..3cbc0ad5e 100644 --- a/models/issues/content_history_test.go +++ b/models/issues/content_history_test.go @@ -18,18 +18,17 @@ func TestContentHistory(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) dbCtx := db.DefaultContext - dbEngine := db.GetEngine(dbCtx) timeStampNow := timeutil.TimeStampNow() - _ = SaveIssueContentHistory(dbEngine, 1, 10, 0, timeStampNow, "i-a", true) - _ = SaveIssueContentHistory(dbEngine, 1, 10, 0, timeStampNow.Add(2), "i-b", false) - _ = SaveIssueContentHistory(dbEngine, 1, 10, 0, timeStampNow.Add(7), "i-c", false) + _ = SaveIssueContentHistory(dbCtx, 1, 10, 0, timeStampNow, "i-a", true) + _ = SaveIssueContentHistory(dbCtx, 1, 10, 0, timeStampNow.Add(2), "i-b", false) + _ = SaveIssueContentHistory(dbCtx, 1, 10, 0, timeStampNow.Add(7), "i-c", false) - _ = SaveIssueContentHistory(dbEngine, 1, 10, 100, timeStampNow, "c-a", true) - _ = SaveIssueContentHistory(dbEngine, 1, 10, 100, timeStampNow.Add(5), "c-b", false) - _ = SaveIssueContentHistory(dbEngine, 1, 10, 100, timeStampNow.Add(20), "c-c", false) - _ = SaveIssueContentHistory(dbEngine, 1, 10, 100, timeStampNow.Add(50), "c-d", false) - _ = SaveIssueContentHistory(dbEngine, 1, 10, 100, timeStampNow.Add(51), "c-e", false) + _ = SaveIssueContentHistory(dbCtx, 1, 10, 100, timeStampNow, "c-a", true) + _ = SaveIssueContentHistory(dbCtx, 1, 10, 100, timeStampNow.Add(5), "c-b", false) + _ = SaveIssueContentHistory(dbCtx, 1, 10, 100, timeStampNow.Add(20), "c-c", false) + _ = SaveIssueContentHistory(dbCtx, 1, 10, 100, timeStampNow.Add(50), "c-d", false) + _ = SaveIssueContentHistory(dbCtx, 1, 10, 100, timeStampNow.Add(51), "c-e", false) h1, _ := GetIssueContentHistoryByID(dbCtx, 1) assert.EqualValues(t, 1, h1.ID) @@ -47,7 +46,7 @@ func TestContentHistory(t *testing.T) { Name string FullName string } - _ = dbEngine.Sync2(&User{}) + _ = db.GetEngine(dbCtx).Sync2(&User{}) list1, _ := FetchIssueContentHistoryList(dbCtx, 10, 0) assert.Len(t, list1, 3) @@ -70,7 +69,7 @@ func TestContentHistory(t *testing.T) { assert.EqualValues(t, 4, h6Prev.ID) // only keep 3 history revisions for comment_id=100, the first and the last should never be deleted - keepLimitedContentHistory(dbEngine, 10, 100, 3) + keepLimitedContentHistory(dbCtx, 10, 100, 3) list1, _ = FetchIssueContentHistoryList(dbCtx, 10, 0) assert.Len(t, list1, 3) list2, _ = FetchIssueContentHistoryList(dbCtx, 10, 100) diff --git a/models/issues/milestone.go b/models/issues/milestone.go index 07c38754d..f7172f644 100644 --- a/models/issues/milestone.go +++ b/models/issues/milestone.go @@ -292,11 +292,11 @@ func DeleteMilestoneByRepoID(repoID, id int64) error { return err } - numMilestones, err := countRepoMilestones(sess, repo.ID) + numMilestones, err := countRepoMilestones(ctx, repo.ID) if err != nil { return err } - numClosedMilestones, err := countRepoClosedMilestones(sess, repo.ID) + numClosedMilestones, err := countRepoClosedMilestones(ctx, repo.ID) if err != nil { return err } @@ -503,21 +503,21 @@ func GetMilestonesStatsByRepoCondAndKw(repoCond builder.Cond, keyword string) (* return stats, nil } -func countRepoMilestones(e db.Engine, repoID int64) (int64, error) { - return e. +func countRepoMilestones(ctx context.Context, repoID int64) (int64, error) { + return db.GetEngine(ctx). Where("repo_id=?", repoID). Count(new(Milestone)) } -func countRepoClosedMilestones(e db.Engine, repoID int64) (int64, error) { - return e. +func countRepoClosedMilestones(ctx context.Context, repoID int64) (int64, error) { + return db.GetEngine(ctx). Where("repo_id=? AND is_closed=?", repoID, true). Count(new(Milestone)) } // CountRepoClosedMilestones returns number of closed milestones in given repository. func CountRepoClosedMilestones(repoID int64) (int64, error) { - return countRepoClosedMilestones(db.GetEngine(db.DefaultContext), repoID) + return countRepoClosedMilestones(db.DefaultContext, repoID) } // CountMilestonesByRepoCond map from repo conditions to number of milestones matching the options` @@ -590,7 +590,7 @@ func updateRepoMilestoneNum(ctx context.Context, repoID int64) error { // |_||_| \__,_|\___|_|\_\___|\__,_| |_| |_|_| |_| |_|\___||___/ // -func (milestones MilestoneList) loadTotalTrackedTimes(e db.Engine) error { +func (milestones MilestoneList) loadTotalTrackedTimes(ctx context.Context) error { type totalTimesByMilestone struct { MilestoneID int64 Time int64 @@ -601,7 +601,7 @@ func (milestones MilestoneList) loadTotalTrackedTimes(e db.Engine) error { trackedTimes := make(map[int64]int64, len(milestones)) // Get total tracked time by milestone_id - rows, err := e.Table("issue"). + rows, err := db.GetEngine(ctx).Table("issue"). Join("INNER", "milestone", "issue.milestone_id = milestone.id"). Join("LEFT", "tracked_time", "tracked_time.issue_id = issue.id"). Where("tracked_time.deleted = ?", false). @@ -630,13 +630,13 @@ func (milestones MilestoneList) loadTotalTrackedTimes(e db.Engine) error { return nil } -func (m *Milestone) loadTotalTrackedTime(e db.Engine) error { +func (m *Milestone) loadTotalTrackedTime(ctx context.Context) error { type totalTimesByMilestone struct { MilestoneID int64 Time int64 } totalTime := &totalTimesByMilestone{MilestoneID: m.ID} - has, err := e.Table("issue"). + has, err := db.GetEngine(ctx).Table("issue"). Join("INNER", "milestone", "issue.milestone_id = milestone.id"). Join("LEFT", "tracked_time", "tracked_time.issue_id = issue.id"). Where("tracked_time.deleted = ?", false). @@ -655,10 +655,10 @@ func (m *Milestone) loadTotalTrackedTime(e db.Engine) error { // LoadTotalTrackedTimes loads for every milestone in the list the TotalTrackedTime by a batch request func (milestones MilestoneList) LoadTotalTrackedTimes() error { - return milestones.loadTotalTrackedTimes(db.GetEngine(db.DefaultContext)) + return milestones.loadTotalTrackedTimes(db.DefaultContext) } // LoadTotalTrackedTime loads the tracked time for the milestone func (m *Milestone) LoadTotalTrackedTime() error { - return m.loadTotalTrackedTime(db.GetEngine(db.DefaultContext)) + return m.loadTotalTrackedTime(db.DefaultContext) } diff --git a/models/issues/milestone_test.go b/models/issues/milestone_test.go index 09f51de45..e08731832 100644 --- a/models/issues/milestone_test.go +++ b/models/issues/milestone_test.go @@ -149,7 +149,7 @@ func TestCountRepoMilestones(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) test := func(repoID int64) { repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: repoID}).(*repo_model.Repository) - count, err := countRepoMilestones(db.GetEngine(db.DefaultContext), repoID) + count, err := countRepoMilestones(db.DefaultContext, repoID) assert.NoError(t, err) assert.EqualValues(t, repo.NumMilestones, count) } @@ -157,7 +157,7 @@ func TestCountRepoMilestones(t *testing.T) { test(2) test(3) - count, err := countRepoMilestones(db.GetEngine(db.DefaultContext), unittest.NonexistentID) + count, err := countRepoMilestones(db.DefaultContext, unittest.NonexistentID) assert.NoError(t, err) assert.EqualValues(t, 0, count) } diff --git a/models/notification.go b/models/notification.go index d0b7852cd..548362d19 100644 --- a/models/notification.go +++ b/models/notification.go @@ -119,22 +119,18 @@ func (opts *FindNotificationOptions) ToCond() builder.Cond { } // ToSession will convert the given options to a xorm Session by using the conditions from ToCond and joining with issue table if required -func (opts *FindNotificationOptions) ToSession(e db.Engine) *xorm.Session { - sess := e.Where(opts.ToCond()) +func (opts *FindNotificationOptions) ToSession(ctx context.Context) *xorm.Session { + sess := db.GetEngine(ctx).Where(opts.ToCond()) if opts.Page != 0 { sess = db.SetSessionPagination(sess, opts) } return sess } -func getNotifications(e db.Engine, options *FindNotificationOptions) (nl NotificationList, err error) { - err = options.ToSession(e).OrderBy("notification.updated_unix DESC").Find(&nl) - return -} - // GetNotifications returns all notifications that fit to the given options. -func GetNotifications(opts *FindNotificationOptions) (NotificationList, error) { - return getNotifications(db.GetEngine(db.DefaultContext), opts) +func GetNotifications(ctx context.Context, options *FindNotificationOptions) (nl NotificationList, err error) { + err = options.ToSession(ctx).OrderBy("notification.updated_unix DESC").Find(&nl) + return } // CountNotifications count all notifications that fit to the given options and ignore pagination. @@ -201,15 +197,14 @@ func CreateOrUpdateIssueNotifications(issueID, commentID, notificationAuthorID, } func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, notificationAuthorID, receiverID int64) error { - e := db.GetEngine(ctx) // init var toNotify map[int64]struct{} - notifications, err := getNotificationsByIssueID(e, issueID) + notifications, err := getNotificationsByIssueID(ctx, issueID) if err != nil { return err } - issue, err := getIssueByID(e, issueID) + issue, err := getIssueByID(ctx, issueID) if err != nil { return err } @@ -219,7 +214,7 @@ func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, n toNotify[receiverID] = struct{}{} } else { toNotify = make(map[int64]struct{}, 32) - issueWatches, err := getIssueWatchersIDs(e, issueID, true) + issueWatches, err := GetIssueWatchersIDs(ctx, issueID, true) if err != nil { return err } @@ -235,7 +230,7 @@ func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, n toNotify[id] = struct{}{} } } - issueParticipants, err := issue.getParticipantIDsByIssue(e) + issueParticipants, err := issue.getParticipantIDsByIssue(ctx) if err != nil { return err } @@ -246,7 +241,7 @@ func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, n // dont notify user who cause notification delete(toNotify, notificationAuthorID) // explicit unwatch on issue - issueUnWatches, err := getIssueWatchersIDs(e, issueID, false) + issueUnWatches, err := GetIssueWatchersIDs(ctx, issueID, false) if err != nil { return err } @@ -263,7 +258,7 @@ func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, n // notify for userID := range toNotify { issue.Repo.Units = nil - user, err := user_model.GetUserByIDEngine(e, userID) + user, err := user_model.GetUserByIDCtx(ctx, userID) if err != nil { if user_model.IsErrUserNotExist(err) { continue @@ -279,20 +274,20 @@ func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, n } if notificationExists(notifications, issue.ID, userID) { - if err = updateIssueNotification(e, userID, issue.ID, commentID, notificationAuthorID); err != nil { + if err = updateIssueNotification(ctx, userID, issue.ID, commentID, notificationAuthorID); err != nil { return err } continue } - if err = createIssueNotification(e, userID, issue, commentID, notificationAuthorID); err != nil { + if err = createIssueNotification(ctx, userID, issue, commentID, notificationAuthorID); err != nil { return err } } return nil } -func getNotificationsByIssueID(e db.Engine, issueID int64) (notifications []*Notification, err error) { - err = e. +func getNotificationsByIssueID(ctx context.Context, issueID int64) (notifications []*Notification, err error) { + err = db.GetEngine(ctx). Where("issue_id = ?", issueID). Find(¬ifications) return @@ -308,7 +303,7 @@ func notificationExists(notifications []*Notification, issueID, userID int64) bo return false } -func createIssueNotification(e db.Engine, userID int64, issue *Issue, commentID, updatedByID int64) error { +func createIssueNotification(ctx context.Context, userID int64, issue *Issue, commentID, updatedByID int64) error { notification := &Notification{ UserID: userID, RepoID: issue.RepoID, @@ -324,12 +319,11 @@ func createIssueNotification(e db.Engine, userID int64, issue *Issue, commentID, notification.Source = NotificationSourceIssue } - _, err := e.Insert(notification) - return err + return db.Insert(ctx, notification) } -func updateIssueNotification(e db.Engine, userID, issueID, commentID, updatedByID int64) error { - notification, err := getIssueNotification(e, userID, issueID) +func updateIssueNotification(ctx context.Context, userID, issueID, commentID, updatedByID int64) error { + notification, err := getIssueNotification(ctx, userID, issueID) if err != nil { return err } @@ -346,13 +340,13 @@ func updateIssueNotification(e db.Engine, userID, issueID, commentID, updatedByI cols = []string{"update_by"} } - _, err = e.ID(notification.ID).Cols(cols...).Update(notification) + _, err = db.GetEngine(ctx).ID(notification.ID).Cols(cols...).Update(notification) return err } -func getIssueNotification(e db.Engine, userID, issueID int64) (*Notification, error) { +func getIssueNotification(ctx context.Context, userID, issueID int64) (*Notification, error) { notification := new(Notification) - _, err := e. + _, err := db.GetEngine(ctx). Where("user_id = ?", userID). And("issue_id = ?", issueID). Get(notification) @@ -360,16 +354,12 @@ func getIssueNotification(e db.Engine, userID, issueID int64) (*Notification, er } // NotificationsForUser returns notifications for a given user and status -func NotificationsForUser(user *user_model.User, statuses []NotificationStatus, page, perPage int) (NotificationList, error) { - return notificationsForUser(db.GetEngine(db.DefaultContext), user, statuses, page, perPage) -} - -func notificationsForUser(e db.Engine, user *user_model.User, statuses []NotificationStatus, page, perPage int) (notifications []*Notification, err error) { +func NotificationsForUser(ctx context.Context, user *user_model.User, statuses []NotificationStatus, page, perPage int) (notifications NotificationList, err error) { if len(statuses) == 0 { return } - sess := e. + sess := db.GetEngine(ctx). Where("user_id = ?", user.ID). In("status", statuses). OrderBy("updated_unix DESC") @@ -383,12 +373,8 @@ func notificationsForUser(e db.Engine, user *user_model.User, statuses []Notific } // CountUnread count unread notifications for a user -func CountUnread(user *user_model.User) int64 { - return countUnread(db.GetEngine(db.DefaultContext), user.ID) -} - -func countUnread(e db.Engine, userID int64) int64 { - exist, err := e.Where("user_id = ?", userID).And("status = ?", NotificationStatusUnread).Count(new(Notification)) +func CountUnread(ctx context.Context, userID int64) int64 { + exist, err := db.GetEngine(ctx).Where("user_id = ?", userID).And("status = ?", NotificationStatusUnread).Count(new(Notification)) if err != nil { log.Error("countUnread", err) return 0 @@ -402,17 +388,16 @@ func (n *Notification) LoadAttributes() (err error) { } func (n *Notification) loadAttributes(ctx context.Context) (err error) { - e := db.GetEngine(ctx) if err = n.loadRepo(ctx); err != nil { return } if err = n.loadIssue(ctx); err != nil { return } - if err = n.loadUser(e); err != nil { + if err = n.loadUser(ctx); err != nil { return } - if err = n.loadComment(e); err != nil { + if err = n.loadComment(ctx); err != nil { return } return @@ -430,7 +415,7 @@ func (n *Notification) loadRepo(ctx context.Context) (err error) { func (n *Notification) loadIssue(ctx context.Context) (err error) { if n.Issue == nil && n.IssueID != 0 { - n.Issue, err = getIssueByID(db.GetEngine(ctx), n.IssueID) + n.Issue, err = getIssueByID(ctx, n.IssueID) if err != nil { return fmt.Errorf("getIssueByID [%d]: %v", n.IssueID, err) } @@ -439,9 +424,9 @@ func (n *Notification) loadIssue(ctx context.Context) (err error) { return nil } -func (n *Notification) loadComment(e db.Engine) (err error) { +func (n *Notification) loadComment(ctx context.Context) (err error) { if n.Comment == nil && n.CommentID != 0 { - n.Comment, err = getCommentByID(e, n.CommentID) + n.Comment, err = GetCommentByID(ctx, n.CommentID) if err != nil { if IsErrCommentNotExist(err) { return ErrCommentNotExist{ @@ -455,9 +440,9 @@ func (n *Notification) loadComment(e db.Engine) (err error) { return nil } -func (n *Notification) loadUser(e db.Engine) (err error) { +func (n *Notification) loadUser(ctx context.Context) (err error) { if n.User == nil { - n.User, err = user_model.GetUserByIDEngine(e, n.UserID) + n.User, err = user_model.GetUserByIDCtx(ctx, n.UserID) if err != nil { return fmt.Errorf("getUserByID [%d]: %v", n.UserID, err) } @@ -739,12 +724,8 @@ func (nl NotificationList) LoadComments() ([]int, error) { } // GetNotificationCount returns the notification count for user -func GetNotificationCount(user *user_model.User, status NotificationStatus) (int64, error) { - return getNotificationCount(db.GetEngine(db.DefaultContext), user, status) -} - -func getNotificationCount(e db.Engine, user *user_model.User, status NotificationStatus) (count int64, err error) { - count, err = e. +func GetNotificationCount(ctx context.Context, user *user_model.User, status NotificationStatus) (count int64, err error) { + count, err = db.GetEngine(ctx). Where("user_id = ?", user.ID). And("status = ?", status). Count(&Notification{}) @@ -766,8 +747,8 @@ func GetUIDsAndNotificationCounts(since, until timeutil.TimeStamp) ([]UserIDCoun return res, db.GetEngine(db.DefaultContext).SQL(sql, since, until, NotificationStatusUnread).Find(&res) } -func setIssueNotificationStatusReadIfUnread(e db.Engine, userID, issueID int64) error { - notification, err := getIssueNotification(e, userID, issueID) +func setIssueNotificationStatusReadIfUnread(ctx context.Context, userID, issueID int64) error { + notification, err := getIssueNotification(ctx, userID, issueID) // ignore if not exists if err != nil { return nil @@ -779,12 +760,13 @@ func setIssueNotificationStatusReadIfUnread(e db.Engine, userID, issueID int64) notification.Status = NotificationStatusRead - _, err = e.ID(notification.ID).Update(notification) + _, err = db.GetEngine(ctx).ID(notification.ID).Update(notification) return err } -func setRepoNotificationStatusReadIfUnread(e db.Engine, userID, repoID int64) error { - _, err := e.Where(builder.Eq{ +// SetRepoReadBy sets repo to be visited by given user. +func SetRepoReadBy(ctx context.Context, userID, repoID int64) error { + _, err := db.GetEngine(ctx).Where(builder.Eq{ "user_id": userID, "status": NotificationStatusUnread, "source": NotificationSourceRepository, @@ -795,7 +777,7 @@ func setRepoNotificationStatusReadIfUnread(e db.Engine, userID, repoID int64) er // SetNotificationStatus change the notification status func SetNotificationStatus(notificationID int64, user *user_model.User, status NotificationStatus) (*Notification, error) { - notification, err := getNotificationByID(db.GetEngine(db.DefaultContext), notificationID) + notification, err := getNotificationByID(db.DefaultContext, notificationID) if err != nil { return notification, err } @@ -812,12 +794,12 @@ func SetNotificationStatus(notificationID int64, user *user_model.User, status N // GetNotificationByID return notification by ID func GetNotificationByID(notificationID int64) (*Notification, error) { - return getNotificationByID(db.GetEngine(db.DefaultContext), notificationID) + return getNotificationByID(db.DefaultContext, notificationID) } -func getNotificationByID(e db.Engine, notificationID int64) (*Notification, error) { +func getNotificationByID(ctx context.Context, notificationID int64) (*Notification, error) { notification := new(Notification) - ok, err := e. + ok, err := db.GetEngine(ctx). Where("id = ?", notificationID). Get(notification) if err != nil { diff --git a/models/notification_test.go b/models/notification_test.go index 3b05f34c5..15c29389c 100644 --- a/models/notification_test.go +++ b/models/notification_test.go @@ -7,6 +7,7 @@ package models import ( "testing" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" @@ -32,7 +33,7 @@ func TestNotificationsForUser(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}).(*user_model.User) statuses := []NotificationStatus{NotificationStatusRead, NotificationStatusUnread} - notfs, err := NotificationsForUser(user, statuses, 1, 10) + notfs, err := NotificationsForUser(db.DefaultContext, user, statuses, 1, 10) assert.NoError(t, err) if assert.Len(t, notfs, 3) { assert.EqualValues(t, 5, notfs[0].ID) @@ -65,11 +66,11 @@ func TestNotification_GetIssue(t *testing.T) { func TestGetNotificationCount(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1}).(*user_model.User) - cnt, err := GetNotificationCount(user, NotificationStatusRead) + cnt, err := GetNotificationCount(db.DefaultContext, user, NotificationStatusRead) assert.NoError(t, err) assert.EqualValues(t, 0, cnt) - cnt, err = GetNotificationCount(user, NotificationStatusUnread) + cnt, err = GetNotificationCount(db.DefaultContext, user, NotificationStatusUnread) assert.NoError(t, err) assert.EqualValues(t, 1, cnt) } diff --git a/models/org.go b/models/org.go index 1efa0504e..681b367f4 100644 --- a/models/org.go +++ b/models/org.go @@ -95,7 +95,7 @@ func removeOrgUser(ctx context.Context, orgID, userID int64) error { return nil } - org, err := organization.GetOrgByIDCtx(ctx, orgID) + org, err := organization.GetOrgByID(ctx, orgID) if err != nil { return fmt.Errorf("GetUserByID [%d]: %v", orgID, err) } @@ -120,7 +120,7 @@ func removeOrgUser(ctx context.Context, orgID, userID int64) error { if _, err := sess.ID(ou.ID).Delete(ou); err != nil { return err - } else if _, err = sess.Exec("UPDATE `user` SET num_members=num_members-1 WHERE id=?", orgID); err != nil { + } else if _, err = db.Exec(ctx, "UPDATE `user` SET num_members=num_members-1 WHERE id=?", orgID); err != nil { return err } @@ -134,7 +134,7 @@ func removeOrgUser(ctx context.Context, orgID, userID int64) error { return fmt.Errorf("GetUserRepositories [%d]: %v", userID, err) } for _, repoID := range repoIDs { - if err = repo_model.WatchRepoCtx(ctx, userID, repoID, false); err != nil { + if err = repo_model.WatchRepo(ctx, userID, repoID, false); err != nil { return err } } diff --git a/models/org_team.go b/models/org_team.go index 0aba0cbb2..f1d35ee18 100644 --- a/models/org_team.go +++ b/models/org_team.go @@ -44,7 +44,7 @@ func addRepository(ctx context.Context, t *organization.Team, repo *repo_model.R return fmt.Errorf("getMembers: %v", err) } for _, u := range t.Members { - if err = repo_model.WatchRepoCtx(ctx, u.ID, repo.ID, true); err != nil { + if err = repo_model.WatchRepo(ctx, u.ID, repo.ID, true); err != nil { return fmt.Errorf("watchRepo: %v", err) } } @@ -147,12 +147,12 @@ func removeAllRepositories(ctx context.Context, t *organization.Team) (err error continue } - if err = repo_model.WatchRepoCtx(ctx, user.ID, repo.ID, false); err != nil { + if err = repo_model.WatchRepo(ctx, user.ID, repo.ID, false); err != nil { return err } // Remove all IssueWatches a user has subscribed to in the repositories - if err = removeIssueWatchersByRepoID(e, user.ID, repo.ID); err != nil { + if err = removeIssueWatchersByRepoID(ctx, user.ID, repo.ID); err != nil { return err } } @@ -210,12 +210,12 @@ func removeRepository(ctx context.Context, t *organization.Team, repo *repo_mode continue } - if err = repo_model.WatchRepoCtx(ctx, teamUser.UID, repo.ID, false); err != nil { + if err = repo_model.WatchRepo(ctx, teamUser.UID, repo.ID, false); err != nil { return err } // Remove all IssueWatches a user has subscribed to in the repositories - if err := removeIssueWatchersByRepoID(e, teamUser.UID, repo.ID); err != nil { + if err := removeIssueWatchersByRepoID(ctx, teamUser.UID, repo.ID); err != nil { return err } } @@ -555,7 +555,7 @@ func AddTeamMember(team *organization.Team, userID int64) error { } go func(repos []*repo_model.Repository) { for _, repo := range repos { - if err = repo_model.WatchRepoCtx(db.DefaultContext, userID, repo.ID, true); err != nil { + if err = repo_model.WatchRepo(db.DefaultContext, userID, repo.ID, true); err != nil { log.Error("watch repo failed: %v", err) } } diff --git a/models/organization/org.go b/models/organization/org.go index 43d96793b..0d4a5e337 100644 --- a/models/organization/org.go +++ b/models/organization/org.go @@ -102,7 +102,7 @@ func (org *Organization) CanCreateOrgRepo(uid int64) (bool, error) { } func (org *Organization) getTeam(ctx context.Context, name string) (*Team, error) { - return getTeam(ctx, org.ID, name) + return GetTeam(ctx, org.ID, name) } // GetTeam returns named team of organization. @@ -203,7 +203,7 @@ func CountOrgMembers(opts *FindOrgMembersOpts) (int64, error) { // FindOrgMembers loads organization members according conditions func FindOrgMembers(opts *FindOrgMembersOpts) (user_model.UserList, map[int64]bool, error) { - ous, err := GetOrgUsersByOrgID(opts) + ous, err := GetOrgUsersByOrgID(db.DefaultContext, opts) if err != nil { return nil, nil, err } @@ -248,7 +248,7 @@ func CreateOrganization(org *Organization, owner *user_model.User) (err error) { return err } - isExist, err := user_model.IsUserExist(0, org.Name) + isExist, err := user_model.IsUserExist(db.DefaultContext, 0, org.Name) if err != nil { return err } else if isExist { @@ -281,7 +281,7 @@ func CreateOrganization(org *Organization, owner *user_model.User) (err error) { if err = db.Insert(ctx, org); err != nil { return fmt.Errorf("insert organization: %v", err) } - if err = user_model.GenerateRandomAvatarCtx(ctx, org.AsUser()); err != nil { + if err = user_model.GenerateRandomAvatar(ctx, org.AsUser()); err != nil { return fmt.Errorf("generate random avatar: %v", err) } @@ -350,14 +350,6 @@ func GetOrgByName(name string) (*Organization, error) { return u, nil } -// CountOrganizations returns number of organizations. -func CountOrganizations() int64 { - count, _ := db.GetEngine(db.DefaultContext). - Where("type=1"). - Count(new(Organization)) - return count -} - // DeleteOrganization deletes models associated to an organization. func DeleteOrganization(ctx context.Context, org *Organization) error { if org.Type != user_model.UserTypeOrganization { @@ -425,7 +417,7 @@ func queryUserOrgIDs(userID int64, includePrivate bool) *builder.Builder { } func (opts FindOrgOptions) toConds() builder.Cond { - cond := builder.NewCond() + var cond builder.Cond = builder.Eq{"`user`.`type`": user_model.UserTypeOrganization} if opts.UserID > 0 { cond = cond.And(builder.In("`user`.`id`", queryUserOrgIDs(opts.UserID, opts.IncludePrivate))) } @@ -451,18 +443,7 @@ func FindOrgs(opts FindOrgOptions) ([]*Organization, error) { func CountOrgs(opts FindOrgOptions) (int64, error) { return db.GetEngine(db.DefaultContext). Where(opts.toConds()). - Count(new(user_model.User)) -} - -func getOwnedOrgsByUserID(sess db.Engine, userID int64) ([]*Organization, error) { - orgs := make([]*Organization, 0, 10) - return orgs, sess. - Join("INNER", "`team_user`", "`team_user`.org_id=`user`.id"). - Join("INNER", "`team`", "`team`.id=`team_user`.team_id"). - Where("`team_user`.uid=?", userID). - And("`team`.authorize=?", perm.AccessModeOwner). - Asc("`user`.name"). - Find(&orgs) + Count(new(Organization)) } // HasOrgOrUserVisible tells if the given user can see the given org or user @@ -496,17 +477,6 @@ func HasOrgsVisible(orgs []*Organization, user *user_model.User) bool { return false } -// GetOwnedOrgsByUserID returns a list of organizations are owned by given user ID. -func GetOwnedOrgsByUserID(userID int64) ([]*Organization, error) { - return getOwnedOrgsByUserID(db.GetEngine(db.DefaultContext), userID) -} - -// GetOwnedOrgsByUserIDDesc returns a list of organizations are owned by -// given user ID, ordered descending by the given condition. -func GetOwnedOrgsByUserIDDesc(userID int64, desc string) ([]*Organization, error) { - return getOwnedOrgsByUserID(db.GetEngine(db.DefaultContext).Desc(desc), userID) -} - // GetOrgsCanCreateRepoByUserID returns a list of organizations where given user ID // are allowed to create repos. func GetOrgsCanCreateRepoByUserID(userID int64) ([]*Organization, error) { @@ -543,12 +513,8 @@ func GetOrgUsersByUserID(uid int64, opts *SearchOrganizationsOptions) ([]*OrgUse } // GetOrgUsersByOrgID returns all organization-user relations by organization ID. -func GetOrgUsersByOrgID(opts *FindOrgMembersOpts) ([]*OrgUser, error) { - return getOrgUsersByOrgID(db.GetEngine(db.DefaultContext), opts) -} - -func getOrgUsersByOrgID(e db.Engine, opts *FindOrgMembersOpts) ([]*OrgUser, error) { - sess := e.Where("org_id=?", opts.OrgID) +func GetOrgUsersByOrgID(ctx context.Context, opts *FindOrgMembersOpts) ([]*OrgUser, error) { + sess := db.GetEngine(ctx).Where("org_id=?", opts.OrgID) if opts.PublicOnly { sess.And("is_public = ?", true) } @@ -615,8 +581,8 @@ func AddOrgUser(orgID, uid int64) error { return committer.Commit() } -// GetOrgByIDCtx returns the user object by given ID if exists. -func GetOrgByIDCtx(ctx context.Context, id int64) (*Organization, error) { +// GetOrgByID returns the user object by given ID if exists. +func GetOrgByID(ctx context.Context, id int64) (*Organization, error) { u := new(Organization) has, err := db.GetEngine(ctx).ID(id).Get(u) if err != nil { @@ -631,11 +597,6 @@ func GetOrgByIDCtx(ctx context.Context, id int64) (*Organization, error) { return u, nil } -// GetOrgByID returns the user object by given ID if exists. -func GetOrgByID(id int64) (*Organization, error) { - return GetOrgByIDCtx(db.DefaultContext, id) -} - // RemoveOrgRepo removes all team-repository relations of organization. func RemoveOrgRepo(ctx context.Context, orgID, repoID int64) error { teamRepos := make([]*TeamRepo, 0, 10) @@ -664,9 +625,9 @@ func RemoveOrgRepo(ctx context.Context, orgID, repoID int64) error { return err } -func (org *Organization) getUserTeams(e db.Engine, userID int64, cols ...string) ([]*Team, error) { +func (org *Organization) getUserTeams(ctx context.Context, userID int64, cols ...string) ([]*Team, error) { teams := make([]*Team, 0, org.NumTeams) - return teams, e. + return teams, db.GetEngine(ctx). Where("`team_user`.org_id = ?", org.ID). Join("INNER", "team_user", "`team_user`.team_id = team.id"). Join("INNER", "`user`", "`user`.id=team_user.uid"). @@ -700,7 +661,7 @@ func (org *Organization) GetUserTeamIDs(userID int64) ([]int64, error) { // GetUserTeams returns all teams that belong to user, // and that the user has joined. func (org *Organization) GetUserTeams(userID int64) ([]*Team, error) { - return org.getUserTeams(db.GetEngine(db.DefaultContext), userID) + return org.getUserTeams(db.DefaultContext, userID) } // AccessibleReposEnvironment operations involving the repositories that are diff --git a/models/organization/org_test.go b/models/organization/org_test.go index 71cdbd869..b408a2f36 100644 --- a/models/organization/org_test.go +++ b/models/organization/org_test.go @@ -128,9 +128,11 @@ func TestGetOrgByName(t *testing.T) { func TestCountOrganizations(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - expected, err := db.GetEngine(db.DefaultContext).Where("type=?", user_model.UserTypeOrganization).Count(&user_model.User{}) + expected, err := db.GetEngine(db.DefaultContext).Where("type=?", user_model.UserTypeOrganization).Count(&Organization{}) assert.NoError(t, err) - assert.Equal(t, expected, CountOrganizations()) + cnt, err := CountOrgs(FindOrgOptions{IncludePrivate: true}) + assert.NoError(t, err) + assert.Equal(t, expected, cnt) } func TestIsOrganizationOwner(t *testing.T) { @@ -204,35 +206,6 @@ func TestFindOrgs(t *testing.T) { assert.EqualValues(t, 1, total) } -func TestGetOwnedOrgsByUserID(t *testing.T) { - assert.NoError(t, unittest.PrepareTestDatabase()) - - orgs, err := GetOwnedOrgsByUserID(2) - assert.NoError(t, err) - if assert.Len(t, orgs, 1) { - assert.EqualValues(t, 3, orgs[0].ID) - } - - orgs, err = GetOwnedOrgsByUserID(4) - assert.NoError(t, err) - assert.Len(t, orgs, 0) -} - -func TestGetOwnedOrgsByUserIDDesc(t *testing.T) { - assert.NoError(t, unittest.PrepareTestDatabase()) - - orgs, err := GetOwnedOrgsByUserIDDesc(5, "id") - assert.NoError(t, err) - if assert.Len(t, orgs, 2) { - assert.EqualValues(t, 7, orgs[0].ID) - assert.EqualValues(t, 6, orgs[1].ID) - } - - orgs, err = GetOwnedOrgsByUserIDDesc(4, "id") - assert.NoError(t, err) - assert.Len(t, orgs, 0) -} - func TestGetOrgUsersByUserID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) @@ -266,7 +239,7 @@ func TestGetOrgUsersByUserID(t *testing.T) { func TestGetOrgUsersByOrgID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - orgUsers, err := GetOrgUsersByOrgID(&FindOrgMembersOpts{ + orgUsers, err := GetOrgUsersByOrgID(db.DefaultContext, &FindOrgMembersOpts{ ListOptions: db.ListOptions{}, OrgID: 3, PublicOnly: false, @@ -287,7 +260,7 @@ func TestGetOrgUsersByOrgID(t *testing.T) { }, *orgUsers[1]) } - orgUsers, err = GetOrgUsersByOrgID(&FindOrgMembersOpts{ + orgUsers, err = GetOrgUsersByOrgID(db.DefaultContext, &FindOrgMembersOpts{ ListOptions: db.ListOptions{}, OrgID: unittest.NonexistentID, PublicOnly: false, diff --git a/models/organization/org_user_test.go b/models/organization/org_user_test.go index 5fee253c4..1e85f4ed2 100644 --- a/models/organization/org_user_test.go +++ b/models/organization/org_user_test.go @@ -91,7 +91,7 @@ func TestUserListIsPublicMember(t *testing.T) { } func testUserListIsPublicMember(t *testing.T, orgID int64, expected map[int64]bool) { - org, err := GetOrgByID(orgID) + org, err := GetOrgByID(db.DefaultContext, orgID) assert.NoError(t, err) _, membersIsPublic, err := org.GetMembers() assert.NoError(t, err) @@ -118,7 +118,7 @@ func TestUserListIsUserOrgOwner(t *testing.T) { } func testUserListIsUserOrgOwner(t *testing.T, orgID int64, expected map[int64]bool) { - org, err := GetOrgByID(orgID) + org, err := GetOrgByID(db.DefaultContext, orgID) assert.NoError(t, err) members, _, err := org.GetMembers() assert.NoError(t, err) diff --git a/models/organization/team.go b/models/organization/team.go index 077fba6a6..b32ffa6ca 100644 --- a/models/organization/team.go +++ b/models/organization/team.go @@ -272,7 +272,8 @@ func IsUsableTeamName(name string) error { } } -func getTeam(ctx context.Context, orgID int64, name string) (*Team, error) { +// GetTeam returns team by given team name and organization. +func GetTeam(ctx context.Context, orgID int64, name string) (*Team, error) { t := &Team{ OrgID: orgID, LowerName: strings.ToLower(name), @@ -286,16 +287,11 @@ func getTeam(ctx context.Context, orgID int64, name string) (*Team, error) { return t, nil } -// GetTeam returns team by given team name and organization. -func GetTeam(orgID int64, name string) (*Team, error) { - return getTeam(db.DefaultContext, orgID, name) -} - // GetTeamIDsByNames returns a slice of team ids corresponds to names. func GetTeamIDsByNames(orgID int64, names []string, ignoreNonExistent bool) ([]int64, error) { ids := make([]int64, 0, len(names)) for _, name := range names { - u, err := GetTeam(orgID, name) + u, err := GetTeam(db.DefaultContext, orgID, name) if err != nil { if ignoreNonExistent { continue @@ -310,11 +306,11 @@ func GetTeamIDsByNames(orgID int64, names []string, ignoreNonExistent bool) ([]i // GetOwnerTeam returns team by given team name and organization. func GetOwnerTeam(ctx context.Context, orgID int64) (*Team, error) { - return getTeam(ctx, orgID, OwnerTeamName) + return GetTeam(ctx, orgID, OwnerTeamName) } -// GetTeamByIDCtx returns team by given ID. -func GetTeamByIDCtx(ctx context.Context, teamID int64) (*Team, error) { +// GetTeamByID returns team by given ID. +func GetTeamByID(ctx context.Context, teamID int64) (*Team, error) { t := new(Team) has, err := db.GetEngine(ctx).ID(teamID).Get(t) if err != nil { @@ -325,11 +321,6 @@ func GetTeamByIDCtx(ctx context.Context, teamID int64) (*Team, error) { return t, nil } -// GetTeamByID returns team by given ID. -func GetTeamByID(teamID int64) (*Team, error) { - return GetTeamByIDCtx(db.DefaultContext, teamID) -} - // GetTeamNamesByID returns team's lower name from a list of team ids. func GetTeamNamesByID(teamIDs []int64) ([]string, error) { if len(teamIDs) == 0 { @@ -346,16 +337,12 @@ func GetTeamNamesByID(teamIDs []int64) ([]string, error) { return teamNames, err } -func getRepoTeams(e db.Engine, repo *repo_model.Repository) (teams []*Team, err error) { - return teams, e. +// GetRepoTeams gets the list of teams that has access to the repository +func GetRepoTeams(ctx context.Context, repo *repo_model.Repository) (teams []*Team, err error) { + return teams, db.GetEngine(ctx). Join("INNER", "team_repo", "team_repo.team_id = team.id"). Where("team.org_id = ?", repo.OwnerID). And("team_repo.repo_id=?", repo.ID). OrderBy("CASE WHEN name LIKE '" + OwnerTeamName + "' THEN '' ELSE name END"). Find(&teams) } - -// GetRepoTeams gets the list of teams that has access to the repository -func GetRepoTeams(repo *repo_model.Repository) ([]*Team, error) { - return getRepoTeams(db.GetEngine(db.DefaultContext), repo) -} diff --git a/models/organization/team_test.go b/models/organization/team_test.go index bbf9f789f..860a7107e 100644 --- a/models/organization/team_test.go +++ b/models/organization/team_test.go @@ -71,7 +71,7 @@ func TestGetTeam(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) testSuccess := func(orgID int64, name string) { - team, err := GetTeam(orgID, name) + team, err := GetTeam(db.DefaultContext, orgID, name) assert.NoError(t, err) assert.EqualValues(t, orgID, team.OrgID) assert.Equal(t, name, team.Name) @@ -79,9 +79,9 @@ func TestGetTeam(t *testing.T) { testSuccess(3, "Owners") testSuccess(3, "team1") - _, err := GetTeam(3, "nonexistent") + _, err := GetTeam(db.DefaultContext, 3, "nonexistent") assert.Error(t, err) - _, err = GetTeam(unittest.NonexistentID, "Owners") + _, err = GetTeam(db.DefaultContext, unittest.NonexistentID, "Owners") assert.Error(t, err) } @@ -89,7 +89,7 @@ func TestGetTeamByID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) testSuccess := func(teamID int64) { - team, err := GetTeamByID(teamID) + team, err := GetTeamByID(db.DefaultContext, teamID) assert.NoError(t, err) assert.EqualValues(t, teamID, team.ID) } @@ -98,7 +98,7 @@ func TestGetTeamByID(t *testing.T) { testSuccess(3) testSuccess(4) - _, err := GetTeamByID(unittest.NonexistentID) + _, err := GetTeamByID(db.DefaultContext, unittest.NonexistentID) assert.Error(t, err) } diff --git a/models/perm/access/access.go b/models/perm/access/access.go index 75a3a93a3..764751902 100644 --- a/models/perm/access/access.go +++ b/models/perm/access/access.go @@ -30,7 +30,7 @@ func init() { db.RegisterModel(new(Access)) } -func accessLevel(e db.Engine, user *user_model.User, repo *repo_model.Repository) (perm.AccessMode, error) { +func accessLevel(ctx context.Context, user *user_model.User, repo *repo_model.Repository) (perm.AccessMode, error) { mode := perm.AccessModeNone var userID int64 restricted := false @@ -53,7 +53,7 @@ func accessLevel(e db.Engine, user *user_model.User, repo *repo_model.Repository } a := &Access{UserID: userID, RepoID: repo.ID} - if has, err := e.Get(a); !has || err != nil { + if has, err := db.GetByBean(ctx, a); !has || err != nil { return mode, err } return a.Mode, nil @@ -84,7 +84,7 @@ func updateUserAccess(accessMap map[int64]*userAccess, user *user_model.User, mo } // FIXME: do cross-comparison so reduce deletions and additions to the minimum? -func refreshAccesses(e db.Engine, repo *repo_model.Repository, accessMap map[int64]*userAccess) (err error) { +func refreshAccesses(ctx context.Context, repo *repo_model.Repository, accessMap map[int64]*userAccess) (err error) { minMode := perm.AccessModeRead if !repo.IsPrivate { minMode = perm.AccessModeWrite @@ -104,14 +104,14 @@ func refreshAccesses(e db.Engine, repo *repo_model.Repository, accessMap map[int } // Delete old accesses and insert new ones for repository. - if _, err = e.Delete(&Access{RepoID: repo.ID}); err != nil { + if _, err = db.DeleteByBean(ctx, &Access{RepoID: repo.ID}); err != nil { return fmt.Errorf("delete old accesses: %v", err) } if len(newAccesses) == 0 { return nil } - if _, err = e.Insert(newAccesses); err != nil { + if err = db.Insert(ctx, newAccesses); err != nil { return fmt.Errorf("insert new accesses: %v", err) } return nil @@ -144,8 +144,6 @@ func RecalculateTeamAccesses(ctx context.Context, repo *repo_model.Repository, i return fmt.Errorf("owner is not an organization: %d", repo.OwnerID) } - e := db.GetEngine(ctx) - if err = refreshCollaboratorAccesses(ctx, repo.ID, accessMap); err != nil { return fmt.Errorf("refreshCollaboratorAccesses: %v", err) } @@ -176,7 +174,7 @@ func RecalculateTeamAccesses(ctx context.Context, repo *repo_model.Repository, i } } - return refreshAccesses(e, repo, accessMap) + return refreshAccesses(ctx, repo, accessMap) } // RecalculateUserAccess recalculates new access for a single user @@ -235,10 +233,9 @@ func RecalculateAccesses(ctx context.Context, repo *repo_model.Repository) error return RecalculateTeamAccesses(ctx, repo, 0) } - e := db.GetEngine(ctx) accessMap := make(map[int64]*userAccess, 20) if err := refreshCollaboratorAccesses(ctx, repo.ID, accessMap); err != nil { return fmt.Errorf("refreshCollaboratorAccesses: %v", err) } - return refreshAccesses(e, repo, accessMap) + return refreshAccesses(ctx, repo, accessMap) } diff --git a/models/perm/access/repo_permission.go b/models/perm/access/repo_permission.go index 090c78ff2..6bc1c8270 100644 --- a/models/perm/access/repo_permission.go +++ b/models/perm/access/repo_permission.go @@ -168,8 +168,6 @@ func GetUserRepoPermission(ctx context.Context, repo *repo_model.Repository, use return } - e := db.GetEngine(ctx) - var is bool if user != nil { is, err = repo_model.IsCollaborator(ctx, repo.ID, user.ID) @@ -208,7 +206,7 @@ func GetUserRepoPermission(ctx context.Context, repo *repo_model.Repository, use } // plain user - perm.AccessMode, err = accessLevel(e, user, repo) + perm.AccessMode, err = accessLevel(ctx, user, repo) if err != nil { return } @@ -288,7 +286,7 @@ func IsUserRealRepoAdmin(repo *repo_model.Repository, user *user_model.User) (bo return false, err } - accessMode, err := accessLevel(db.GetEngine(db.DefaultContext), user, repo) + accessMode, err := accessLevel(db.DefaultContext, user, repo) if err != nil { return false, err } @@ -297,12 +295,7 @@ func IsUserRealRepoAdmin(repo *repo_model.Repository, user *user_model.User) (bo } // IsUserRepoAdmin return true if user has admin right of a repo -func IsUserRepoAdmin(repo *repo_model.Repository, user *user_model.User) (bool, error) { - return IsUserRepoAdminCtx(db.DefaultContext, repo, user) -} - -// IsUserRepoAdminCtx return true if user has admin right of a repo -func IsUserRepoAdminCtx(ctx context.Context, repo *repo_model.Repository, user *user_model.User) (bool, error) { +func IsUserRepoAdmin(ctx context.Context, repo *repo_model.Repository, user *user_model.User) (bool, error) { if user == nil || repo == nil { return false, nil } @@ -310,8 +303,7 @@ func IsUserRepoAdminCtx(ctx context.Context, repo *repo_model.Repository, user * return true, nil } - e := db.GetEngine(ctx) - mode, err := accessLevel(e, user, repo) + mode, err := accessLevel(ctx, user, repo) if err != nil { return false, err } @@ -377,7 +369,7 @@ func HasAccess(ctx context.Context, userID int64, repo *repo_model.Repository) ( var user *user_model.User var err error if userID > 0 { - user, err = user_model.GetUserByIDEngine(db.GetEngine(ctx), userID) + user, err = user_model.GetUserByIDCtx(ctx, userID) if err != nil { return false, err } diff --git a/models/project/board.go b/models/project/board.go index f770a18f5..be7119ee4 100644 --- a/models/project/board.go +++ b/models/project/board.go @@ -147,8 +147,7 @@ func DeleteBoardByID(boardID int64) error { } func deleteBoardByID(ctx context.Context, boardID int64) error { - e := db.GetEngine(ctx) - board, err := getBoard(e, boardID) + board, err := GetBoard(ctx, boardID) if err != nil { if IsErrProjectBoardNotExist(err) { return nil @@ -157,30 +156,26 @@ func deleteBoardByID(ctx context.Context, boardID int64) error { return err } - if err = board.removeIssues(e); err != nil { + if err = board.removeIssues(ctx); err != nil { return err } - if _, err := e.ID(board.ID).Delete(board); err != nil { + if _, err := db.GetEngine(ctx).ID(board.ID).NoAutoCondition().Delete(board); err != nil { return err } return nil } -func deleteBoardByProjectID(e db.Engine, projectID int64) error { - _, err := e.Where("project_id=?", projectID).Delete(&Board{}) +func deleteBoardByProjectID(ctx context.Context, projectID int64) error { + _, err := db.GetEngine(ctx).Where("project_id=?", projectID).Delete(&Board{}) return err } // GetBoard fetches the current board of a project -func GetBoard(boardID int64) (*Board, error) { - return getBoard(db.GetEngine(db.DefaultContext), boardID) -} - -func getBoard(e db.Engine, boardID int64) (*Board, error) { +func GetBoard(ctx context.Context, boardID int64) (*Board, error) { board := new(Board) - has, err := e.ID(boardID).Get(board) + has, err := db.GetEngine(ctx).ID(boardID).Get(board) if err != nil { return nil, err } else if !has { @@ -191,11 +186,7 @@ func getBoard(e db.Engine, boardID int64) (*Board, error) { } // UpdateBoard updates a project board -func UpdateBoard(board *Board) error { - return updateBoard(db.GetEngine(db.DefaultContext), board) -} - -func updateBoard(e db.Engine, board *Board) error { +func UpdateBoard(ctx context.Context, board *Board) error { var fieldToUpdate []string if board.Sorting != 0 { @@ -211,25 +202,21 @@ func updateBoard(e db.Engine, board *Board) error { } fieldToUpdate = append(fieldToUpdate, "color") - _, err := e.ID(board.ID).Cols(fieldToUpdate...).Update(board) + _, err := db.GetEngine(ctx).ID(board.ID).Cols(fieldToUpdate...).Update(board) return err } // GetBoards fetches all boards related to a project // if no default board set, first board is a temporary "Uncategorized" board -func GetBoards(projectID int64) (BoardList, error) { - return getBoards(db.GetEngine(db.DefaultContext), projectID) -} - -func getBoards(e db.Engine, projectID int64) ([]*Board, error) { +func GetBoards(ctx context.Context, projectID int64) (BoardList, error) { boards := make([]*Board, 0, 5) - if err := e.Where("project_id=? AND `default`=?", projectID, false).OrderBy("Sorting").Find(&boards); err != nil { + if err := db.GetEngine(ctx).Where("project_id=? AND `default`=?", projectID, false).OrderBy("Sorting").Find(&boards); err != nil { return nil, err } - defaultB, err := getDefaultBoard(e, projectID) + defaultB, err := getDefaultBoard(ctx, projectID) if err != nil { return nil, err } @@ -238,9 +225,9 @@ func getBoards(e db.Engine, projectID int64) ([]*Board, error) { } // getDefaultBoard return default board and create a dummy if none exist -func getDefaultBoard(e db.Engine, projectID int64) (*Board, error) { +func getDefaultBoard(ctx context.Context, projectID int64) (*Board, error) { var board Board - exist, err := e.Where("project_id=? AND `default`=?", projectID, true).Get(&board) + exist, err := db.GetEngine(ctx).Where("project_id=? AND `default`=?", projectID, true).Get(&board) if err != nil { return nil, err } diff --git a/models/project/issue.go b/models/project/issue.go index 6bde91668..04efc0e74 100644 --- a/models/project/issue.go +++ b/models/project/issue.go @@ -28,8 +28,8 @@ func init() { db.RegisterModel(new(ProjectIssue)) } -func deleteProjectIssuesByProjectID(e db.Engine, projectID int64) error { - _, err := e.Where("project_id=?", projectID).Delete(&ProjectIssue{}) +func deleteProjectIssuesByProjectID(ctx context.Context, projectID int64) error { + _, err := db.GetEngine(ctx).Where("project_id=?", projectID).Delete(&ProjectIssue{}) return err } @@ -97,7 +97,7 @@ func MoveIssuesOnProjectBoard(board *Board, sortedIssueIDs map[int64]int64) erro }) } -func (pb *Board) removeIssues(e db.Engine) error { - _, err := e.Exec("UPDATE `project_issue` SET project_board_id = 0 WHERE project_board_id = ? ", pb.ID) +func (pb *Board) removeIssues(ctx context.Context) error { + _, err := db.GetEngine(ctx).Exec("UPDATE `project_issue` SET project_board_id = 0 WHERE project_board_id = ? ", pb.ID) return err } diff --git a/models/project/project.go b/models/project/project.go index a639879e7..0aa37cc5c 100644 --- a/models/project/project.go +++ b/models/project/project.go @@ -121,12 +121,7 @@ type SearchOptions struct { } // GetProjects returns a list of all projects that have been created in the repository -func GetProjects(opts SearchOptions) ([]*Project, int64, error) { - return GetProjectsCtx(db.DefaultContext, opts) -} - -// GetProjectsCtx returns a list of all projects that have been created in the repository -func GetProjectsCtx(ctx context.Context, opts SearchOptions) ([]*Project, int64, error) { +func GetProjects(ctx context.Context, opts SearchOptions) ([]*Project, int64, error) { e := db.GetEngine(ctx) projects := make([]*Project, 0, setting.UI.IssuePagingNum) @@ -199,14 +194,10 @@ func NewProject(p *Project) error { } // GetProjectByID returns the projects in a repository -func GetProjectByID(id int64) (*Project, error) { - return getProjectByID(db.GetEngine(db.DefaultContext), id) -} - -func getProjectByID(e db.Engine, id int64) (*Project, error) { +func GetProjectByID(ctx context.Context, id int64) (*Project, error) { p := new(Project) - has, err := e.ID(id).Get(p) + has, err := db.GetEngine(ctx).ID(id).Get(p) if err != nil { return nil, err } else if !has { @@ -217,20 +208,16 @@ func getProjectByID(e db.Engine, id int64) (*Project, error) { } // UpdateProject updates project properties -func UpdateProject(p *Project) error { - return updateProject(db.GetEngine(db.DefaultContext), p) -} - -func updateProject(e db.Engine, p *Project) error { - _, err := e.ID(p.ID).Cols( +func UpdateProject(ctx context.Context, p *Project) error { + _, err := db.GetEngine(ctx).ID(p.ID).Cols( "title", "description", ).Update(p) return err } -func updateRepositoryProjectCount(e db.Engine, repoID int64) error { - if _, err := e.Exec(builder.Update( +func updateRepositoryProjectCount(ctx context.Context, repoID int64) error { + if _, err := db.GetEngine(ctx).Exec(builder.Update( builder.Eq{ "`num_projects`": builder.Select("count(*)").From("`project`"). Where(builder.Eq{"`project`.`repo_id`": repoID}. @@ -239,7 +226,7 @@ func updateRepositoryProjectCount(e db.Engine, repoID int64) error { return err } - if _, err := e.Exec(builder.Update( + if _, err := db.GetEngine(ctx).Exec(builder.Update( builder.Eq{ "`num_closed_projects`": builder.Select("count(*)").From("`project`"). Where(builder.Eq{"`project`.`repo_id`": repoID}. @@ -293,8 +280,7 @@ func ChangeProjectStatus(p *Project, isClosed bool) error { func changeProjectStatus(ctx context.Context, p *Project, isClosed bool) error { p.IsClosed = isClosed p.ClosedDateUnix = timeutil.TimeStampNow() - e := db.GetEngine(ctx) - count, err := e.ID(p.ID).Where("repo_id = ? AND is_closed = ?", p.RepoID, !isClosed).Cols("is_closed", "closed_date_unix").Update(p) + count, err := db.GetEngine(ctx).ID(p.ID).Where("repo_id = ? AND is_closed = ?", p.RepoID, !isClosed).Cols("is_closed", "closed_date_unix").Update(p) if err != nil { return err } @@ -302,7 +288,7 @@ func changeProjectStatus(ctx context.Context, p *Project, isClosed bool) error { return nil } - return updateRepositoryProjectCount(e, p.RepoID) + return updateRepositoryProjectCount(ctx, p.RepoID) } // DeleteProjectByID deletes a project from a repository. @@ -322,8 +308,7 @@ func DeleteProjectByID(id int64) error { // DeleteProjectByIDCtx deletes a project from a repository. func DeleteProjectByIDCtx(ctx context.Context, id int64) error { - e := db.GetEngine(ctx) - p, err := getProjectByID(e, id) + p, err := GetProjectByID(ctx, id) if err != nil { if IsErrProjectNotExist(err) { return nil @@ -331,17 +316,17 @@ func DeleteProjectByIDCtx(ctx context.Context, id int64) error { return err } - if err := deleteProjectIssuesByProjectID(e, id); err != nil { + if err := deleteProjectIssuesByProjectID(ctx, id); err != nil { return err } - if err := deleteBoardByProjectID(e, id); err != nil { + if err := deleteBoardByProjectID(ctx, id); err != nil { return err } - if _, err = e.ID(p.ID).Delete(new(Project)); err != nil { + if _, err = db.GetEngine(ctx).ID(p.ID).Delete(new(Project)); err != nil { return err } - return updateRepositoryProjectCount(e, p.RepoID) + return updateRepositoryProjectCount(ctx, p.RepoID) } diff --git a/models/project/project_test.go b/models/project/project_test.go index 211a89087..f33fb3351 100644 --- a/models/project/project_test.go +++ b/models/project/project_test.go @@ -7,6 +7,7 @@ package project import ( "testing" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" "code.gitea.io/gitea/modules/timeutil" @@ -34,13 +35,13 @@ func TestIsProjectTypeValid(t *testing.T) { func TestGetProjects(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - projects, _, err := GetProjects(SearchOptions{RepoID: 1}) + projects, _, err := GetProjects(db.DefaultContext, SearchOptions{RepoID: 1}) assert.NoError(t, err) // 1 value for this repo exists in the fixtures assert.Len(t, projects, 1) - projects, _, err = GetProjects(SearchOptions{RepoID: 3}) + projects, _, err = GetProjects(db.DefaultContext, SearchOptions{RepoID: 3}) assert.NoError(t, err) // 1 value for this repo exists in the fixtures @@ -61,14 +62,14 @@ func TestProject(t *testing.T) { assert.NoError(t, NewProject(project)) - _, err := GetProjectByID(project.ID) + _, err := GetProjectByID(db.DefaultContext, project.ID) assert.NoError(t, err) // Update project project.Title = "Updated title" - assert.NoError(t, UpdateProject(project)) + assert.NoError(t, UpdateProject(db.DefaultContext, project)) - projectFromDB, err := GetProjectByID(project.ID) + projectFromDB, err := GetProjectByID(db.DefaultContext, project.ID) assert.NoError(t, err) assert.Equal(t, project.Title, projectFromDB.Title) @@ -76,7 +77,7 @@ func TestProject(t *testing.T) { assert.NoError(t, ChangeProjectStatus(project, true)) // Retrieve from DB afresh to check if it is truly closed - projectFromDB, err = GetProjectByID(project.ID) + projectFromDB, err = GetProjectByID(db.DefaultContext, project.ID) assert.NoError(t, err) assert.True(t, projectFromDB.IsClosed) diff --git a/models/pull.go b/models/pull.go index 8eab7569c..bb5bb1181 100644 --- a/models/pull.go +++ b/models/pull.go @@ -97,22 +97,22 @@ func init() { db.RegisterModel(new(PullRequest)) } -func deletePullsByBaseRepoID(sess db.Engine, repoID int64) error { +func deletePullsByBaseRepoID(ctx context.Context, repoID int64) error { deleteCond := builder.Select("id").From("pull_request").Where(builder.Eq{"pull_request.base_repo_id": repoID}) // Delete scheduled auto merges - if _, err := sess.In("pull_id", deleteCond). + if _, err := db.GetEngine(ctx).In("pull_id", deleteCond). Delete(&pull_model.AutoMerge{}); err != nil { return err } // Delete review states - if _, err := sess.In("pull_id", deleteCond). + if _, err := db.GetEngine(ctx).In("pull_id", deleteCond). Delete(&pull_model.ReviewState{}); err != nil { return err } - _, err := sess.Delete(&PullRequest{BaseRepoID: repoID}) + _, err := db.DeleteByBean(ctx, &PullRequest{BaseRepoID: repoID}) return err } @@ -133,9 +133,9 @@ func (pr *PullRequest) MustHeadUserName() string { } // Note: don't try to get Issue because will end up recursive querying. -func (pr *PullRequest) loadAttributes(e db.Engine) (err error) { +func (pr *PullRequest) loadAttributes(ctx context.Context) (err error) { if pr.HasMerged && pr.Merger == nil { - pr.Merger, err = user_model.GetUserByIDEngine(e, pr.MergerID) + pr.Merger, err = user_model.GetUserByIDCtx(ctx, pr.MergerID) if user_model.IsErrUserNotExist(err) { pr.MergerID = -1 pr.Merger = user_model.NewGhostUser() @@ -149,7 +149,7 @@ func (pr *PullRequest) loadAttributes(e db.Engine) (err error) { // LoadAttributes loads pull request attributes from database func (pr *PullRequest) LoadAttributes() error { - return pr.loadAttributes(db.GetEngine(db.DefaultContext)) + return pr.loadAttributes(db.DefaultContext) } // LoadHeadRepoCtx loads the head repository @@ -218,7 +218,7 @@ func (pr *PullRequest) LoadIssueCtx(ctx context.Context) (err error) { return nil } - pr.Issue, err = getIssueByID(db.GetEngine(ctx), pr.IssueID) + pr.Issue, err = getIssueByID(ctx, pr.IssueID) if err == nil { pr.Issue.PullRequest = pr } @@ -242,7 +242,7 @@ func (pr *PullRequest) LoadProtectedBranchCtx(ctx context.Context) (err error) { return } } - pr.ProtectedBranch, err = getProtectedBranchBy(db.GetEngine(ctx), pr.BaseRepo.ID, pr.BaseBranch) + pr.ProtectedBranch, err = GetProtectedBranchBy(ctx, pr.BaseRepo.ID, pr.BaseBranch) } return } @@ -256,13 +256,9 @@ type ReviewCount struct { // GetApprovalCounts returns the approval counts by type // FIXME: Only returns official counts due to double counting of non-official counts -func (pr *PullRequest) GetApprovalCounts() ([]*ReviewCount, error) { - return pr.getApprovalCounts(db.GetEngine(db.DefaultContext)) -} - -func (pr *PullRequest) getApprovalCounts(e db.Engine) ([]*ReviewCount, error) { +func (pr *PullRequest) GetApprovalCounts(ctx context.Context) ([]*ReviewCount, error) { rCounts := make([]*ReviewCount, 0, 6) - sess := e.Where("issue_id = ?", pr.IssueID) + sess := db.GetEngine(ctx).Where("issue_id = ?", pr.IssueID) return rCounts, sess.Select("issue_id, type, count(id) as `count`").Where("official = ? AND dismissed = ?", true, false).GroupBy("issue_id, type").Table("review").Find(&rCounts) } @@ -289,10 +285,9 @@ func (pr *PullRequest) getReviewedByLines(writer io.Writer) error { return err } defer committer.Close() - sess := db.GetEngine(ctx) // Note: This doesn't page as we only expect a very limited number of reviews - reviews, err := findReviews(sess, FindReviewOptions{ + reviews, err := FindReviews(ctx, FindReviewOptions{ Type: ReviewTypeApprove, IssueID: pr.IssueID, OfficialOnly: setting.Repository.PullRequest.DefaultMergeMessageOfficialApproversOnly, @@ -309,7 +304,7 @@ func (pr *PullRequest) getReviewedByLines(writer io.Writer) error { break } - if err := review.loadReviewer(sess); err != nil && !user_model.IsErrUserNotExist(err) { + if err := review.loadReviewer(ctx); err != nil && !user_model.IsErrUserNotExist(err) { log.Error("Unable to LoadReviewer[%d] for PR ID %d : %v", review.ReviewerID, pr.ID, err) return err } else if review.Reviewer == nil { @@ -374,7 +369,7 @@ func (pr *PullRequest) SetMerged(ctx context.Context) (bool, error) { return false, err } - if tmpPr, err := getPullRequestByID(sess, pr.ID); err != nil { + if tmpPr, err := GetPullRequestByID(ctx, pr.ID); err != nil { return false, err } else if tmpPr.HasMerged { if pr.Issue.IsClosed { @@ -484,12 +479,7 @@ func GetLatestPullRequestByHeadInfo(repoID int64, branch string) (*PullRequest, } // GetPullRequestByIndex returns a pull request by the given index -func GetPullRequestByIndex(repoID, index int64) (*PullRequest, error) { - return GetPullRequestByIndexCtx(db.DefaultContext, repoID, index) -} - -// GetPullRequestByIndexCtx returns a pull request by the given index -func GetPullRequestByIndexCtx(ctx context.Context, repoID, index int64) (*PullRequest, error) { +func GetPullRequestByIndex(ctx context.Context, repoID, index int64) (*PullRequest, error) { if index < 1 { return nil, ErrPullRequestNotExist{} } @@ -505,7 +495,7 @@ func GetPullRequestByIndexCtx(ctx context.Context, repoID, index int64) (*PullRe return nil, ErrPullRequestNotExist{0, 0, 0, repoID, "", ""} } - if err = pr.loadAttributes(db.GetEngine(ctx)); err != nil { + if err = pr.loadAttributes(ctx); err != nil { return nil, err } if err = pr.LoadIssueCtx(ctx); err != nil { @@ -515,20 +505,16 @@ func GetPullRequestByIndexCtx(ctx context.Context, repoID, index int64) (*PullRe return pr, nil } -func getPullRequestByID(e db.Engine, id int64) (*PullRequest, error) { +// GetPullRequestByID returns a pull request by given ID. +func GetPullRequestByID(ctx context.Context, id int64) (*PullRequest, error) { pr := new(PullRequest) - has, err := e.ID(id).Get(pr) + has, err := db.GetEngine(ctx).ID(id).Get(pr) if err != nil { return nil, err } else if !has { return nil, ErrPullRequestNotExist{id, 0, 0, 0, "", ""} } - return pr, pr.loadAttributes(e) -} - -// GetPullRequestByID returns a pull request by given ID. -func GetPullRequestByID(ctx context.Context, id int64) (*PullRequest, error) { - return getPullRequestByID(db.GetEngine(ctx), id) + return pr, pr.loadAttributes(ctx) } // GetPullRequestByIssueIDWithNoAttributes returns pull request with no attributes loaded by given issue ID. @@ -544,17 +530,18 @@ func GetPullRequestByIssueIDWithNoAttributes(issueID int64) (*PullRequest, error return &pr, nil } -func getPullRequestByIssueID(e db.Engine, issueID int64) (*PullRequest, error) { +// GetPullRequestByIssueID returns pull request by given issue ID. +func GetPullRequestByIssueID(ctx context.Context, issueID int64) (*PullRequest, error) { pr := &PullRequest{ IssueID: issueID, } - has, err := e.Get(pr) + has, err := db.GetByBean(ctx, pr) if err != nil { return nil, err } else if !has { return nil, ErrPullRequestNotExist{0, issueID, 0, 0, "", ""} } - return pr, pr.loadAttributes(e) + return pr, pr.loadAttributes(ctx) } // GetAllUnmergedAgitPullRequestByPoster get all unmerged agit flow pull request @@ -571,11 +558,6 @@ func GetAllUnmergedAgitPullRequestByPoster(uid int64) ([]*PullRequest, error) { return pulls, err } -// GetPullRequestByIssueID returns pull request by given issue ID. -func GetPullRequestByIssueID(issueID int64) (*PullRequest, error) { - return getPullRequestByIssueID(db.GetEngine(db.DefaultContext), issueID) -} - // Update updates all fields of pull request. func (pr *PullRequest) Update() error { _, err := db.GetEngine(db.DefaultContext).ID(pr.ID).AllCols().Update(pr) @@ -635,17 +617,13 @@ func (pr *PullRequest) GetWorkInProgressPrefix() string { } // UpdateCommitDivergence update Divergence of a pull request -func (pr *PullRequest) UpdateCommitDivergence(ahead, behind int) error { - return pr.updateCommitDivergence(db.GetEngine(db.DefaultContext), ahead, behind) -} - -func (pr *PullRequest) updateCommitDivergence(e db.Engine, ahead, behind int) error { +func (pr *PullRequest) UpdateCommitDivergence(ctx context.Context, ahead, behind int) error { if pr.ID == 0 { return fmt.Errorf("pull ID is 0") } pr.CommitsAhead = ahead pr.CommitsBehind = behind - _, err := e.ID(pr.ID).Cols("commits_ahead", "commits_behind").Update(pr) + _, err := db.GetEngine(ctx).ID(pr.ID).Cols("commits_ahead", "commits_behind").Update(pr) return err } diff --git a/models/pull_list.go b/models/pull_list.go index 60b829977..fb14d3bea 100644 --- a/models/pull_list.go +++ b/models/pull_list.go @@ -156,7 +156,7 @@ func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest, // PullRequestList defines a list of pull requests type PullRequestList []*PullRequest -func (prs PullRequestList) loadAttributes(e db.Engine) error { +func (prs PullRequestList) loadAttributes(ctx context.Context) error { if len(prs) == 0 { return nil } @@ -164,7 +164,7 @@ func (prs PullRequestList) loadAttributes(e db.Engine) error { // Load issues. issueIDs := prs.getIssueIDs() issues := make([]*Issue, 0, len(issueIDs)) - if err := e. + if err := db.GetEngine(ctx). Where("id > 0"). In("id", issueIDs). Find(&issues); err != nil { @@ -191,7 +191,7 @@ func (prs PullRequestList) getIssueIDs() []int64 { // LoadAttributes load all the prs attributes func (prs PullRequestList) LoadAttributes() error { - return prs.loadAttributes(db.GetEngine(db.DefaultContext)) + return prs.loadAttributes(db.DefaultContext) } // InvalidateCodeComments will lookup the prs for code comments which got invalidated by change diff --git a/models/pull_test.go b/models/pull_test.go index 6119bca69..00bbfc798 100644 --- a/models/pull_test.go +++ b/models/pull_test.go @@ -140,16 +140,16 @@ func TestGetUnmergedPullRequestsByBaseInfo(t *testing.T) { func TestGetPullRequestByIndex(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - pr, err := GetPullRequestByIndex(1, 2) + pr, err := GetPullRequestByIndex(db.DefaultContext, 1, 2) assert.NoError(t, err) assert.Equal(t, int64(1), pr.BaseRepoID) assert.Equal(t, int64(2), pr.Index) - _, err = GetPullRequestByIndex(9223372036854775807, 9223372036854775807) + _, err = GetPullRequestByIndex(db.DefaultContext, 9223372036854775807, 9223372036854775807) assert.Error(t, err) assert.True(t, IsErrPullRequestNotExist(err)) - _, err = GetPullRequestByIndex(1, 0) + _, err = GetPullRequestByIndex(db.DefaultContext, 1, 0) assert.Error(t, err) assert.True(t, IsErrPullRequestNotExist(err)) } @@ -168,11 +168,11 @@ func TestGetPullRequestByID(t *testing.T) { func TestGetPullRequestByIssueID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - pr, err := GetPullRequestByIssueID(2) + pr, err := GetPullRequestByIssueID(db.DefaultContext, 2) assert.NoError(t, err) assert.Equal(t, int64(2), pr.IssueID) - _, err = GetPullRequestByIssueID(9223372036854775807) + _, err = GetPullRequestByIssueID(db.DefaultContext, 9223372036854775807) assert.Error(t, err) assert.True(t, IsErrPullRequestNotExist(err)) } diff --git a/models/release.go b/models/release.go index 0285f6bd5..c7e8bff83 100644 --- a/models/release.go +++ b/models/release.go @@ -52,7 +52,7 @@ func init() { db.RegisterModel(new(Release)) } -func (r *Release) loadAttributes(e db.Engine) error { +func (r *Release) loadAttributes(ctx context.Context) error { var err error if r.Repo == nil { r.Repo, err = repo_model.GetRepositoryByID(r.RepoID) @@ -61,7 +61,7 @@ func (r *Release) loadAttributes(e db.Engine) error { } } if r.Publisher == nil { - r.Publisher, err = user_model.GetUserByIDEngine(e, r.PublisherID) + r.Publisher, err = user_model.GetUserByIDCtx(ctx, r.PublisherID) if err != nil { if user_model.IsErrUserNotExist(err) { r.Publisher = user_model.NewGhostUser() @@ -70,12 +70,12 @@ func (r *Release) loadAttributes(e db.Engine) error { } } } - return getReleaseAttachments(e, r) + return GetReleaseAttachments(ctx, r) } // LoadAttributes load repo and publisher attributes for a release func (r *Release) LoadAttributes() error { - return r.loadAttributes(db.GetEngine(db.DefaultContext)) + return r.loadAttributes(db.DefaultContext) } // APIURL the api url for a release. release must have attributes loaded @@ -282,11 +282,7 @@ func (s releaseMetaSearch) Less(i, j int) bool { } // GetReleaseAttachments retrieves the attachments for releases -func GetReleaseAttachments(rels ...*Release) (err error) { - return getReleaseAttachments(db.GetEngine(db.DefaultContext), rels...) -} - -func getReleaseAttachments(e db.Engine, rels ...*Release) (err error) { +func GetReleaseAttachments(ctx context.Context, rels ...*Release) (err error) { if len(rels) == 0 { return } @@ -306,7 +302,7 @@ func getReleaseAttachments(e db.Engine, rels ...*Release) (err error) { sort.Sort(sortedRels) // Select attachments - err = e. + err = db.GetEngine(ctx). Asc("release_id", "name"). In("release_id", sortedRels.ID). Find(&attachments, repo_model.Attachment{}) @@ -373,10 +369,6 @@ func UpdateReleasesMigrationsByType(gitServiceType structs.GitServiceType, origi // PushUpdateDeleteTagsContext updates a number of delete tags with context func PushUpdateDeleteTagsContext(ctx context.Context, repo *repo_model.Repository, tags []string) error { - return pushUpdateDeleteTags(db.GetEngine(ctx), repo, tags) -} - -func pushUpdateDeleteTags(e db.Engine, repo *repo_model.Repository, tags []string) error { if len(tags) == 0 { return nil } @@ -385,14 +377,14 @@ func pushUpdateDeleteTags(e db.Engine, repo *repo_model.Repository, tags []strin lowerTags = append(lowerTags, strings.ToLower(tag)) } - if _, err := e. + if _, err := db.GetEngine(ctx). Where("repo_id = ? AND is_tag = ?", repo.ID, true). In("lower_tag_name", lowerTags). Delete(new(Release)); err != nil { return fmt.Errorf("Delete: %v", err) } - if _, err := e. + if _, err := db.GetEngine(ctx). Where("repo_id = ? AND is_tag = ?", repo.ID, false). In("lower_tag_name", lowerTags). Cols("is_draft", "num_commits", "sha1"). diff --git a/models/repo.go b/models/repo.go index 598eec7c9..d2ad56094 100644 --- a/models/repo.go +++ b/models/repo.go @@ -204,27 +204,23 @@ func GetReviewerTeams(repo *repo_model.Repository) ([]*organization.Team, error) return teams, err } -func updateRepoSize(e db.Engine, repo *repo_model.Repository) error { +// UpdateRepoSize updates the repository size, calculating it using util.GetDirectorySize +func UpdateRepoSize(ctx context.Context, repo *repo_model.Repository) error { size, err := util.GetDirectorySize(repo.RepoPath()) if err != nil { return fmt.Errorf("updateSize: %v", err) } - lfsSize, err := e.Where("repository_id = ?", repo.ID).SumInt(new(LFSMetaObject), "size") + lfsSize, err := db.GetEngine(ctx).Where("repository_id = ?", repo.ID).SumInt(new(LFSMetaObject), "size") if err != nil { return fmt.Errorf("updateSize: GetLFSMetaObjects: %v", err) } repo.Size = size + lfsSize - _, err = e.ID(repo.ID).Cols("size").NoAutoTime().Update(repo) + _, err = db.GetEngine(ctx).ID(repo.ID).Cols("size").NoAutoTime().Update(repo) return err } -// UpdateRepoSize updates the repository size, calculating it using util.GetDirectorySize -func UpdateRepoSize(ctx context.Context, repo *repo_model.Repository) error { - return updateRepoSize(db.GetEngine(ctx), repo) -} - // CanUserForkRepo returns true if specified user can fork repository. func CanUserForkRepo(user *user_model.User, repo *repo_model.Repository) (bool, error) { if user == nil { @@ -303,11 +299,6 @@ func CanUserDelete(repo *repo_model.Repository, user *user_model.User) (bool, er return false, nil } -// SetRepoReadBy sets repo to be visited by given user. -func SetRepoReadBy(repoID, userID int64) error { - return setRepoNotificationStatusReadIfUnread(db.GetEngine(db.DefaultContext), userID, repoID) -} - // CreateRepoOptions contains the create repository options type CreateRepoOptions struct { Name string @@ -334,7 +325,7 @@ func CreateRepository(ctx context.Context, doer, u *user_model.User, repo *repo_ return err } - has, err := repo_model.IsRepositoryExistCtx(ctx, u, repo.Name) + has, err := repo_model.IsRepositoryExist(ctx, u, repo.Name) if err != nil { return fmt.Errorf("IsRepositoryExist: %v", err) } else if has { @@ -398,7 +389,7 @@ func CreateRepository(ctx context.Context, doer, u *user_model.User, repo *repo_ // Remember visibility preference. u.LastRepoVisibility = repo.IsPrivate - if err = user_model.UpdateUserColsEngine(db.GetEngine(ctx), u, "last_repo_visibility"); err != nil { + if err = user_model.UpdateUserCols(ctx, u, "last_repo_visibility"); err != nil { return fmt.Errorf("updateUser: %v", err) } @@ -421,7 +412,7 @@ func CreateRepository(ctx context.Context, doer, u *user_model.User, repo *repo_ } } - if isAdmin, err := access_model.IsUserRepoAdminCtx(ctx, repo, doer); err != nil { + if isAdmin, err := access_model.IsUserRepoAdmin(ctx, repo, doer); err != nil { return fmt.Errorf("IsUserRepoAdminCtx: %v", err) } else if !isAdmin { // Make creator repo admin if it wasn't assigned automatically @@ -438,7 +429,7 @@ func CreateRepository(ctx context.Context, doer, u *user_model.User, repo *repo_ } if setting.Service.AutoWatchNewRepos { - if err = repo_model.WatchRepoCtx(ctx, doer.ID, repo.ID, true); err != nil { + if err = repo_model.WatchRepo(ctx, doer.ID, repo.ID, true); err != nil { return fmt.Errorf("watchRepo: %v", err) } } @@ -510,7 +501,7 @@ func UpdateRepositoryCtx(ctx context.Context, repo *repo_model.Repository, visib return fmt.Errorf("update: %v", err) } - if err = updateRepoSize(e, repo); err != nil { + if err = UpdateRepoSize(ctx, repo); err != nil { log.Error("Failed to update size for repository: %v", err) } @@ -536,13 +527,13 @@ func UpdateRepositoryCtx(ctx context.Context, repo *repo_model.Repository, visib } // Create/Remove git-daemon-export-ok for git-daemon... - if err := CheckDaemonExportOK(db.WithEngine(ctx, e), repo); err != nil { + if err := CheckDaemonExportOK(ctx, repo); err != nil { return err } forkRepos, err := repo_model.GetRepositoriesByForkID(ctx, repo.ID) if err != nil { - return fmt.Errorf("getRepositoriesByForkID: %v", err) + return fmt.Errorf("GetRepositoriesByForkID: %v", err) } for i := range forkRepos { forkRepos[i].IsPrivate = repo.IsPrivate || repo.Owner.Visibility == api.VisibleTypePrivate @@ -581,7 +572,7 @@ func DeleteRepository(doer *user_model.User, uid, repoID int64) error { sess := db.GetEngine(ctx) // In case is a organization. - org, err := user_model.GetUserByIDEngine(sess, uid) + org, err := user_model.GetUserByIDCtx(ctx, uid) if err != nil { return err } @@ -647,7 +638,7 @@ func DeleteRepository(doer *user_model.User, uid, repoID int64) error { releaseAttachments = append(releaseAttachments, attachments[i].RelativePath()) } - if _, err := sess.Exec("UPDATE `user` SET num_stars=num_stars-1 WHERE id IN (SELECT `uid` FROM `star` WHERE repo_id = ?)", repo.ID); err != nil { + if _, err := db.Exec(ctx, "UPDATE `user` SET num_stars=num_stars-1 WHERE id IN (SELECT `uid` FROM `star` WHERE repo_id = ?)", repo.ID); err != nil { return err } @@ -680,33 +671,33 @@ func DeleteRepository(doer *user_model.User, uid, repoID int64) error { } // Delete Labels and related objects - if err := deleteLabelsByRepoID(sess, repoID); err != nil { + if err := deleteLabelsByRepoID(ctx, repoID); err != nil { return err } // Delete Pulls and related objects - if err := deletePullsByBaseRepoID(sess, repoID); err != nil { + if err := deletePullsByBaseRepoID(ctx, repoID); err != nil { return err } // Delete Issues and related objects var attachmentPaths []string - if attachmentPaths, err = deleteIssuesByRepoID(sess, repoID); err != nil { + if attachmentPaths, err = deleteIssuesByRepoID(ctx, repoID); err != nil { return err } // Delete issue index - if err := db.DeleteResouceIndex(sess, "issue_index", repoID); err != nil { + if err := db.DeleteResouceIndex(ctx, "issue_index", repoID); err != nil { return err } if repo.IsFork { - if _, err := sess.Exec("UPDATE `repository` SET num_forks=num_forks-1 WHERE id=?", repo.ForkID); err != nil { + if _, err := db.Exec(ctx, "UPDATE `repository` SET num_forks=num_forks-1 WHERE id=?", repo.ForkID); err != nil { return fmt.Errorf("decrease fork count: %v", err) } } - if _, err := sess.Exec("UPDATE `user` SET num_repos=num_repos-1 WHERE id=?", uid); err != nil { + if _, err := db.Exec(ctx, "UPDATE `user` SET num_repos=num_repos-1 WHERE id=?", uid); err != nil { return err } @@ -716,7 +707,7 @@ func DeleteRepository(doer *user_model.User, uid, repoID int64) error { } } - projects, _, err := project_model.GetProjectsCtx(ctx, project_model.SearchOptions{ + projects, _, err := project_model.GetProjects(ctx, project_model.SearchOptions{ RepoID: repoID, }) if err != nil { @@ -736,7 +727,7 @@ func DeleteRepository(doer *user_model.User, uid, repoID int64) error { lfsPaths := make([]string, 0, len(lfsObjects)) for _, v := range lfsObjects { - count, err := sess.Count(&LFSMetaObject{Pointer: lfs.Pointer{Oid: v.Oid}}) + count, err := db.CountByBean(ctx, &LFSMetaObject{Pointer: lfs.Pointer{Oid: v.Oid}}) if err != nil { return err } @@ -747,7 +738,7 @@ func DeleteRepository(doer *user_model.User, uid, repoID int64) error { lfsPaths = append(lfsPaths, v.RelativePath()) } - if _, err := sess.Delete(&LFSMetaObject{RepositoryID: repoID}); err != nil { + if _, err := db.DeleteByBean(ctx, &LFSMetaObject{RepositoryID: repoID}); err != nil { return err } @@ -763,7 +754,7 @@ func DeleteRepository(doer *user_model.User, uid, repoID int64) error { archivePaths = append(archivePaths, p) } - if _, err := sess.Delete(&repo_model.RepoArchiver{RepoID: repoID}); err != nil { + if _, err := db.DeleteByBean(ctx, &repo_model.RepoArchiver{RepoID: repoID}); err != nil { return err } @@ -1181,7 +1172,7 @@ func DeleteDeployKey(ctx context.Context, doer *user_model.User, id int64) error if err != nil { return fmt.Errorf("GetRepositoryByID: %v", err) } - has, err := access_model.IsUserRepoAdminCtx(ctx, repo, doer) + has, err := access_model.IsUserRepoAdmin(ctx, repo, doer) if err != nil { return fmt.Errorf("GetUserRepoPermission: %v", err) } else if !has { diff --git a/models/repo/attachment.go b/models/repo/attachment.go index f5351578f..ddddac2c3 100644 --- a/models/repo/attachment.go +++ b/models/repo/attachment.go @@ -60,11 +60,6 @@ func (a *Attachment) DownloadURL() string { return setting.AppURL + "attachments/" + url.PathEscape(a.UUID) } -// GetAttachmentByID returns attachment by given id -func GetAttachmentByID(id int64) (*Attachment, error) { - return getAttachmentByID(db.GetEngine(db.DefaultContext), id) -} - // _____ __ __ .__ __ // / _ \_/ |__/ |______ ____ | |__ _____ ____ _____/ |_ // / /_\ \ __\ __\__ \ _/ ___\| | \ / \_/ __ \ / \ __\ @@ -88,9 +83,10 @@ func (err ErrAttachmentNotExist) Error() string { return fmt.Sprintf("attachment does not exist [id: %d, uuid: %s]", err.ID, err.UUID) } -func getAttachmentByID(e db.Engine, id int64) (*Attachment, error) { +// GetAttachmentByID returns attachment by given id +func GetAttachmentByID(ctx context.Context, id int64) (*Attachment, error) { attach := &Attachment{} - if has, err := e.ID(id).Get(attach); err != nil { + if has, err := db.GetEngine(ctx).ID(id).Get(attach); err != nil { return nil, err } else if !has { return nil, ErrAttachmentNotExist{ID: id, UUID: ""} @@ -98,9 +94,10 @@ func getAttachmentByID(e db.Engine, id int64) (*Attachment, error) { return attach, nil } -func getAttachmentByUUID(e db.Engine, uuid string) (*Attachment, error) { +// GetAttachmentByUUID returns attachment by given UUID. +func GetAttachmentByUUID(ctx context.Context, uuid string) (*Attachment, error) { attach := &Attachment{} - has, err := e.Where("uuid=?", uuid).Get(attach) + has, err := db.GetEngine(ctx).Where("uuid=?", uuid).Get(attach) if err != nil { return nil, err } else if !has { @@ -111,22 +108,13 @@ func getAttachmentByUUID(e db.Engine, uuid string) (*Attachment, error) { // GetAttachmentsByUUIDs returns attachment by given UUID list. func GetAttachmentsByUUIDs(ctx context.Context, uuids []string) ([]*Attachment, error) { - return getAttachmentsByUUIDs(db.GetEngine(ctx), uuids) -} - -func getAttachmentsByUUIDs(e db.Engine, uuids []string) ([]*Attachment, error) { if len(uuids) == 0 { return []*Attachment{}, nil } // Silently drop invalid uuids. attachments := make([]*Attachment, 0, len(uuids)) - return attachments, e.In("uuid", uuids).Find(&attachments) -} - -// GetAttachmentByUUID returns attachment by given UUID. -func GetAttachmentByUUID(uuid string) (*Attachment, error) { - return getAttachmentByUUID(db.GetEngine(db.DefaultContext), uuid) + return attachments, db.GetEngine(ctx).In("uuid", uuids).Find(&attachments) } // ExistAttachmentsByUUID returns true if attachment is exist by given UUID @@ -134,37 +122,22 @@ func ExistAttachmentsByUUID(uuid string) (bool, error) { return db.GetEngine(db.DefaultContext).Where("`uuid`=?", uuid).Exist(new(Attachment)) } -// GetAttachmentByReleaseIDFileName returns attachment by given releaseId and fileName. -func GetAttachmentByReleaseIDFileName(releaseID int64, fileName string) (*Attachment, error) { - return getAttachmentByReleaseIDFileName(db.GetEngine(db.DefaultContext), releaseID, fileName) -} - -// GetAttachmentsByIssueIDCtx returns all attachments of an issue. -func GetAttachmentsByIssueIDCtx(ctx context.Context, issueID int64) ([]*Attachment, error) { +// GetAttachmentsByIssueID returns all attachments of an issue. +func GetAttachmentsByIssueID(ctx context.Context, issueID int64) ([]*Attachment, error) { attachments := make([]*Attachment, 0, 10) return attachments, db.GetEngine(ctx).Where("issue_id = ? AND comment_id = 0", issueID).Find(&attachments) } -// GetAttachmentsByIssueID returns all attachments of an issue. -func GetAttachmentsByIssueID(issueID int64) ([]*Attachment, error) { - return GetAttachmentsByIssueIDCtx(db.DefaultContext, issueID) -} - // GetAttachmentsByCommentID returns all attachments if comment by given ID. -func GetAttachmentsByCommentID(commentID int64) ([]*Attachment, error) { - return GetAttachmentsByCommentIDCtx(db.DefaultContext, commentID) -} - -// GetAttachmentsByCommentIDCtx returns all attachments if comment by given ID. -func GetAttachmentsByCommentIDCtx(ctx context.Context, commentID int64) ([]*Attachment, error) { +func GetAttachmentsByCommentID(ctx context.Context, commentID int64) ([]*Attachment, error) { attachments := make([]*Attachment, 0, 10) return attachments, db.GetEngine(ctx).Where("comment_id=?", commentID).Find(&attachments) } -// getAttachmentByReleaseIDFileName return a file based on the the following infos: -func getAttachmentByReleaseIDFileName(e db.Engine, releaseID int64, fileName string) (*Attachment, error) { +// GetAttachmentByReleaseIDFileName returns attachment by given releaseId and fileName. +func GetAttachmentByReleaseIDFileName(ctx context.Context, releaseID int64, fileName string) (*Attachment, error) { attach := &Attachment{ReleaseID: releaseID, Name: fileName} - has, err := e.Get(attach) + has, err := db.GetEngine(ctx).Get(attach) if err != nil { return nil, err } else if !has { @@ -207,7 +180,7 @@ func DeleteAttachments(ctx context.Context, attachments []*Attachment, remove bo // DeleteAttachmentsByIssue deletes all attachments associated with the given issue. func DeleteAttachmentsByIssue(issueID int64, remove bool) (int, error) { - attachments, err := GetAttachmentsByIssueID(issueID) + attachments, err := GetAttachmentsByIssueID(db.DefaultContext, issueID) if err != nil { return 0, err } @@ -217,7 +190,7 @@ func DeleteAttachmentsByIssue(issueID int64, remove bool) (int, error) { // DeleteAttachmentsByComment deletes all attachments associated with the given comment. func DeleteAttachmentsByComment(commentID int64, remove bool) (int, error) { - attachments, err := GetAttachmentsByCommentID(commentID) + attachments, err := GetAttachmentsByCommentID(db.DefaultContext, commentID) if err != nil { return 0, err } @@ -225,11 +198,6 @@ func DeleteAttachmentsByComment(commentID int64, remove bool) (int, error) { return DeleteAttachments(db.DefaultContext, attachments, remove) } -// UpdateAttachment updates the given attachment in database -func UpdateAttachment(atta *Attachment) error { - return UpdateAttachmentCtx(db.DefaultContext, atta) -} - // UpdateAttachmentByUUID Updates attachment via uuid func UpdateAttachmentByUUID(ctx context.Context, attach *Attachment, cols ...string) error { if attach.UUID == "" { @@ -239,8 +207,8 @@ func UpdateAttachmentByUUID(ctx context.Context, attach *Attachment, cols ...str return err } -// UpdateAttachmentCtx updates the given attachment in database -func UpdateAttachmentCtx(ctx context.Context, atta *Attachment) error { +// UpdateAttachment updates the given attachment in database +func UpdateAttachment(ctx context.Context, atta *Attachment) error { sess := db.GetEngine(ctx).Cols("name", "issue_id", "release_id", "comment_id", "download_count") if atta.ID != 0 && atta.UUID == "" { sess = sess.ID(atta.ID) diff --git a/models/repo/attachment_test.go b/models/repo/attachment_test.go index 53c28d532..da486fdb2 100644 --- a/models/repo/attachment_test.go +++ b/models/repo/attachment_test.go @@ -16,7 +16,7 @@ import ( func TestIncreaseDownloadCount(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - attachment, err := GetAttachmentByUUID("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11") + attachment, err := GetAttachmentByUUID(db.DefaultContext, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11") assert.NoError(t, err) assert.Equal(t, int64(0), attachment.DownloadCount) @@ -24,7 +24,7 @@ func TestIncreaseDownloadCount(t *testing.T) { err = attachment.IncreaseDownloadCount() assert.NoError(t, err) - attachment, err = GetAttachmentByUUID("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11") + attachment, err = GetAttachmentByUUID(db.DefaultContext, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11") assert.NoError(t, err) assert.Equal(t, int64(1), attachment.DownloadCount) } @@ -33,11 +33,11 @@ func TestGetByCommentOrIssueID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) // count of attachments from issue ID - attachments, err := GetAttachmentsByIssueID(1) + attachments, err := GetAttachmentsByIssueID(db.DefaultContext, 1) assert.NoError(t, err) assert.Len(t, attachments, 1) - attachments, err = GetAttachmentsByCommentID(1) + attachments, err = GetAttachmentsByCommentID(db.DefaultContext, 1) assert.NoError(t, err) assert.Len(t, attachments, 2) } @@ -56,7 +56,7 @@ func TestDeleteAttachments(t *testing.T) { err = DeleteAttachment(&Attachment{ID: 8}, false) assert.NoError(t, err) - attachment, err := GetAttachmentByUUID("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a18") + attachment, err := GetAttachmentByUUID(db.DefaultContext, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a18") assert.Error(t, err) assert.True(t, IsErrAttachmentNotExist(err)) assert.Nil(t, attachment) @@ -65,7 +65,7 @@ func TestDeleteAttachments(t *testing.T) { func TestGetAttachmentByID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - attach, err := GetAttachmentByID(1) + attach, err := GetAttachmentByID(db.DefaultContext, 1) assert.NoError(t, err) assert.Equal(t, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11", attach.UUID) } @@ -81,12 +81,12 @@ func TestAttachment_DownloadURL(t *testing.T) { func TestUpdateAttachment(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - attach, err := GetAttachmentByID(1) + attach, err := GetAttachmentByID(db.DefaultContext, 1) assert.NoError(t, err) assert.Equal(t, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11", attach.UUID) attach.Name = "new_name" - assert.NoError(t, UpdateAttachment(attach)) + assert.NoError(t, UpdateAttachment(db.DefaultContext, attach)) unittest.AssertExistsAndLoadBean(t, &Attachment{Name: "new_name"}) } diff --git a/models/repo/avatar.go b/models/repo/avatar.go index f11f868d6..cdf85bf1a 100644 --- a/models/repo/avatar.go +++ b/models/repo/avatar.go @@ -5,6 +5,7 @@ package repo import ( + "context" "fmt" "image/png" "io" @@ -25,11 +26,11 @@ func (repo *Repository) CustomAvatarRelativePath() string { // RelAvatarLink returns a relative link to the repository's avatar. func (repo *Repository) RelAvatarLink() string { - return repo.relAvatarLink(db.GetEngine(db.DefaultContext)) + return repo.relAvatarLink(db.DefaultContext) } // generateRandomAvatar generates a random avatar for repository. -func generateRandomAvatar(e db.Engine, repo *Repository) error { +func generateRandomAvatar(ctx context.Context, repo *Repository) error { idToString := fmt.Sprintf("%d", repo.ID) seed := idToString @@ -51,14 +52,14 @@ func generateRandomAvatar(e db.Engine, repo *Repository) error { log.Info("New random avatar created for repository: %d", repo.ID) - if _, err := e.ID(repo.ID).Cols("avatar").NoAutoTime().Update(repo); err != nil { + if _, err := db.GetEngine(ctx).ID(repo.ID).Cols("avatar").NoAutoTime().Update(repo); err != nil { return err } return nil } -func (repo *Repository) relAvatarLink(e db.Engine) string { +func (repo *Repository) relAvatarLink(ctx context.Context) string { // If no avatar - path is empty avatarPath := repo.CustomAvatarRelativePath() if len(avatarPath) == 0 { @@ -66,7 +67,7 @@ func (repo *Repository) relAvatarLink(e db.Engine) string { case "image": return setting.RepoAvatar.FallbackImage case "random": - if err := generateRandomAvatar(e, repo); err != nil { + if err := generateRandomAvatar(ctx, repo); err != nil { log.Error("generateRandomAvatar: %v", err) } default: @@ -79,12 +80,12 @@ func (repo *Repository) relAvatarLink(e db.Engine) string { // AvatarLink returns a link to the repository's avatar. func (repo *Repository) AvatarLink() string { - return repo.avatarLink(db.GetEngine(db.DefaultContext)) + return repo.avatarLink(db.DefaultContext) } // avatarLink returns user avatar absolute link. -func (repo *Repository) avatarLink(e db.Engine) string { - link := repo.relAvatarLink(e) +func (repo *Repository) avatarLink(ctx context.Context) string { + link := repo.relAvatarLink(ctx) // we only prepend our AppURL to our known (relative, internal) avatar link to get an absolute URL if strings.HasPrefix(link, "/") && !strings.HasPrefix(link, "//") { return setting.AppURL + strings.TrimPrefix(link, setting.AppSubURL)[1:] diff --git a/models/repo/collaboration.go b/models/repo/collaboration.go index 3ebb68814..09397dd17 100644 --- a/models/repo/collaboration.go +++ b/models/repo/collaboration.go @@ -37,15 +37,14 @@ type Collaborator struct { // GetCollaborators returns the collaborators for a repository func GetCollaborators(ctx context.Context, repoID int64, listOptions db.ListOptions) ([]*Collaborator, error) { - e := db.GetEngine(ctx) - collaborations, err := getCollaborations(e, repoID, listOptions) + collaborations, err := getCollaborations(ctx, repoID, listOptions) if err != nil { return nil, fmt.Errorf("getCollaborations: %v", err) } collaborators := make([]*Collaborator, 0, len(collaborations)) for _, c := range collaborations { - user, err := user_model.GetUserByIDEngine(e, c.UserID) + user, err := user_model.GetUserByIDCtx(ctx, c.UserID) if err != nil { if user_model.IsErrUserNotExist(err) { log.Warn("Inconsistent DB: User: %d is listed as collaborator of %-v but does not exist", c.UserID, repoID) @@ -85,12 +84,14 @@ func IsCollaborator(ctx context.Context, repoID, userID int64) (bool, error) { return db.GetEngine(ctx).Get(&Collaboration{RepoID: repoID, UserID: userID}) } -func getCollaborations(e db.Engine, repoID int64, listOptions db.ListOptions) ([]*Collaboration, error) { +func getCollaborations(ctx context.Context, repoID int64, listOptions db.ListOptions) ([]*Collaboration, error) { if listOptions.Page == 0 { collaborations := make([]*Collaboration, 0, 8) - return collaborations, e.Find(&collaborations, &Collaboration{RepoID: repoID}) + return collaborations, db.GetEngine(ctx).Find(&collaborations, &Collaboration{RepoID: repoID}) } + e := db.GetEngine(ctx) + e = db.SetEnginePagination(e, &listOptions) collaborations := make([]*Collaboration, 0, listOptions.PageSize) diff --git a/models/repo/fork.go b/models/repo/fork.go index ae7882a02..b48126253 100644 --- a/models/repo/fork.go +++ b/models/repo/fork.go @@ -10,18 +10,14 @@ import ( "code.gitea.io/gitea/models/db" ) -func getRepositoriesByForkID(e db.Engine, forkID int64) ([]*Repository, error) { +// GetRepositoriesByForkID returns all repositories with given fork ID. +func GetRepositoriesByForkID(ctx context.Context, forkID int64) ([]*Repository, error) { repos := make([]*Repository, 0, 10) - return repos, e. + return repos, db.GetEngine(ctx). Where("fork_id=?", forkID). Find(&repos) } -// GetRepositoriesByForkID returns all repositories with given fork ID. -func GetRepositoriesByForkID(ctx context.Context, forkID int64) ([]*Repository, error) { - return getRepositoriesByForkID(db.GetEngine(ctx), forkID) -} - // GetForkedRepo checks if given user has already forked a repository with given ID. func GetForkedRepo(ownerID, repoID int64) *Repository { repo := new(Repository) diff --git a/models/repo/language_stats.go b/models/repo/language_stats.go index 3b0888b6b..b047046ae 100644 --- a/models/repo/language_stats.go +++ b/models/repo/language_stats.go @@ -5,6 +5,7 @@ package repo import ( + "context" "math" "strings" @@ -66,22 +67,18 @@ func (stats LanguageStatList) getLanguagePercentages() map[string]float32 { return langPerc } -func getLanguageStats(e db.Engine, repo *Repository) (LanguageStatList, error) { +// GetLanguageStats returns the language statistics for a repository +func GetLanguageStats(ctx context.Context, repo *Repository) (LanguageStatList, error) { stats := make(LanguageStatList, 0, 6) - if err := e.Where("`repo_id` = ?", repo.ID).Desc("`size`").Find(&stats); err != nil { + if err := db.GetEngine(ctx).Where("`repo_id` = ?", repo.ID).Desc("`size`").Find(&stats); err != nil { return nil, err } return stats, nil } -// GetLanguageStats returns the language statistics for a repository -func GetLanguageStats(repo *Repository) (LanguageStatList, error) { - return getLanguageStats(db.GetEngine(db.DefaultContext), repo) -} - // GetTopLanguageStats returns the top language statistics for a repository func GetTopLanguageStats(repo *Repository, limit int) (LanguageStatList, error) { - stats, err := getLanguageStats(db.GetEngine(db.DefaultContext), repo) + stats, err := GetLanguageStats(db.DefaultContext, repo) if err != nil { return nil, err } @@ -120,7 +117,7 @@ func UpdateLanguageStats(repo *Repository, commitID string, stats map[string]int defer committer.Close() sess := db.GetEngine(ctx) - oldstats, err := getLanguageStats(sess, repo) + oldstats, err := GetLanguageStats(ctx, repo) if err != nil { return err } @@ -151,7 +148,7 @@ func UpdateLanguageStats(repo *Repository, commitID string, stats map[string]int } // Insert new language if !upd { - if _, err := sess.Insert(&LanguageStat{ + if err := db.Insert(ctx, &LanguageStat{ RepoID: repo.ID, CommitID: commitID, IsPrimary: llang == topLang, @@ -176,7 +173,7 @@ func UpdateLanguageStats(repo *Repository, commitID string, stats map[string]int } // Update indexer status - if err = updateIndexerStatus(sess, repo, RepoIndexerTypeStats, commitID); err != nil { + if err = UpdateIndexerStatus(ctx, repo, RepoIndexerTypeStats, commitID); err != nil { return err } @@ -190,10 +187,9 @@ func CopyLanguageStat(originalRepo, destRepo *Repository) error { return err } defer committer.Close() - sess := db.GetEngine(ctx) RepoLang := make(LanguageStatList, 0, 6) - if err := sess.Where("`repo_id` = ?", originalRepo.ID).Desc("`size`").Find(&RepoLang); err != nil { + if err := db.GetEngine(ctx).Where("`repo_id` = ?", originalRepo.ID).Desc("`size`").Find(&RepoLang); err != nil { return err } if len(RepoLang) > 0 { @@ -204,10 +200,10 @@ func CopyLanguageStat(originalRepo, destRepo *Repository) error { } // update destRepo's indexer status tmpCommitID := RepoLang[0].CommitID - if err := updateIndexerStatus(sess, destRepo, RepoIndexerTypeStats, tmpCommitID); err != nil { + if err := UpdateIndexerStatus(ctx, destRepo, RepoIndexerTypeStats, tmpCommitID); err != nil { return err } - if _, err := sess.Insert(&RepoLang); err != nil { + if err := db.Insert(ctx, &RepoLang); err != nil { return err } } diff --git a/models/repo/mirror.go b/models/repo/mirror.go index df4e32075..5d20b7f83 100644 --- a/models/repo/mirror.go +++ b/models/repo/mirror.go @@ -14,8 +14,6 @@ import ( "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/timeutil" - - "xorm.io/xorm" ) // ErrMirrorNotExist mirror does not exist error @@ -56,21 +54,16 @@ func (m *Mirror) BeforeInsert() { } } -// AfterLoad is invoked from XORM after setting the values of all fields of this object. -func (m *Mirror) AfterLoad(session *xorm.Session) { - if m == nil { - return +// GetRepository returns the repository. +func (m *Mirror) GetRepository() *Repository { + if m.Repo != nil { + return m.Repo } - var err error - m.Repo, err = getRepositoryByID(session, m.RepoID) + m.Repo, err = GetRepositoryByIDCtx(db.DefaultContext, m.RepoID) if err != nil { log.Error("getRepositoryByID[%d]: %v", m.ID, err) } -} - -// GetRepository returns the repository. -func (m *Mirror) GetRepository() *Repository { return m.Repo } @@ -88,9 +81,10 @@ func (m *Mirror) ScheduleNextUpdate() { } } -func getMirrorByRepoID(e db.Engine, repoID int64) (*Mirror, error) { +// GetMirrorByRepoID returns mirror information of a repository. +func GetMirrorByRepoID(ctx context.Context, repoID int64) (*Mirror, error) { m := &Mirror{RepoID: repoID} - has, err := e.Get(m) + has, err := db.GetEngine(ctx).Get(m) if err != nil { return nil, err } else if !has { @@ -99,19 +93,10 @@ func getMirrorByRepoID(e db.Engine, repoID int64) (*Mirror, error) { return m, nil } -// GetMirrorByRepoID returns mirror information of a repository. -func GetMirrorByRepoID(repoID int64) (*Mirror, error) { - return getMirrorByRepoID(db.GetEngine(db.DefaultContext), repoID) -} - -func updateMirror(e db.Engine, m *Mirror) error { - _, err := e.ID(m.ID).AllCols().Update(m) - return err -} - // UpdateMirror updates the mirror -func UpdateMirror(m *Mirror) error { - return updateMirror(db.GetEngine(db.DefaultContext), m) +func UpdateMirror(ctx context.Context, m *Mirror) error { + _, err := db.GetEngine(ctx).ID(m.ID).AllCols().Update(m) + return err } // TouchMirror updates the mirror updatedUnix @@ -146,7 +131,7 @@ func InsertMirror(mirror *Mirror) error { // MirrorRepositoryList contains the mirror repositories type MirrorRepositoryList []*Repository -func (repos MirrorRepositoryList) loadAttributes(e db.Engine) error { +func (repos MirrorRepositoryList) loadAttributes(ctx context.Context) error { if len(repos) == 0 { return nil } @@ -161,7 +146,7 @@ func (repos MirrorRepositoryList) loadAttributes(e db.Engine) error { repoIDs = append(repoIDs, repos[i].ID) } mirrors := make([]*Mirror, 0, len(repoIDs)) - if err := e. + if err := db.GetEngine(ctx). Where("id > 0"). In("repo_id", repoIDs). Find(&mirrors); err != nil { @@ -174,11 +159,12 @@ func (repos MirrorRepositoryList) loadAttributes(e db.Engine) error { } for i := range repos { repos[i].Mirror = set[repos[i].ID] + repos[i].Mirror.Repo = repos[i] } return nil } // LoadAttributes loads the attributes for the given MirrorRepositoryList func (repos MirrorRepositoryList) LoadAttributes() error { - return repos.loadAttributes(db.GetEngine(db.DefaultContext)) + return repos.loadAttributes(db.DefaultContext) } diff --git a/models/repo/pushmirror.go b/models/repo/pushmirror.go index b5c6411bd..048c0c348 100644 --- a/models/repo/pushmirror.go +++ b/models/repo/pushmirror.go @@ -11,8 +11,6 @@ import ( "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/timeutil" - - "xorm.io/xorm" ) // ErrPushMirrorNotExist mirror does not exist error @@ -35,21 +33,16 @@ func init() { db.RegisterModel(new(PushMirror)) } -// AfterLoad is invoked from XORM after setting the values of all fields of this object. -func (m *PushMirror) AfterLoad(session *xorm.Session) { - if m == nil { - return +// GetRepository returns the path of the repository. +func (m *PushMirror) GetRepository() *Repository { + if m.Repo != nil { + return m.Repo } - var err error - m.Repo, err = getRepositoryByID(session, m.RepoID) + m.Repo, err = GetRepositoryByIDCtx(db.DefaultContext, m.RepoID) if err != nil { log.Error("getRepositoryByID[%d]: %v", m.ID, err) } -} - -// GetRepository returns the path of the repository. -func (m *PushMirror) GetRepository() *Repository { return m.Repo } diff --git a/models/repo/repo.go b/models/repo/repo.go index 8af6357bf..3fd6b94eb 100644 --- a/models/repo/repo.go +++ b/models/repo/repo.go @@ -289,7 +289,7 @@ func (repo *Repository) LoadUnits(ctx context.Context) (err error) { return nil } - repo.Units, err = getUnitsByRepoID(db.GetEngine(ctx), repo.ID) + repo.Units, err = getUnitsByRepoID(ctx, repo.ID) if log.IsTrace() { unitTypeStrings := make([]string, len(repo.Units)) for i, unit := range repo.Units { @@ -383,7 +383,7 @@ func (repo *Repository) GetOwner(ctx context.Context) (err error) { return nil } - repo.Owner, err = user_model.GetUserByIDEngine(db.GetEngine(ctx), repo.OwnerID) + repo.Owner, err = user_model.GetUserByIDCtx(ctx, repo.OwnerID) return err } @@ -454,15 +454,15 @@ func (repo *Repository) ComposeDocumentMetas() map[string]string { // returns an error on failure (NOTE: no error is returned for // non-fork repositories, and BaseRepo will be left untouched) func (repo *Repository) GetBaseRepo() (err error) { - return repo.getBaseRepo(db.GetEngine(db.DefaultContext)) + return repo.getBaseRepo(db.DefaultContext) } -func (repo *Repository) getBaseRepo(e db.Engine) (err error) { +func (repo *Repository) getBaseRepo(ctx context.Context) (err error) { if !repo.IsFork { return nil } - repo.BaseRepo, err = getRepositoryByID(e, repo.ForkID) + repo.BaseRepo, err = GetRepositoryByIDCtx(ctx, repo.ForkID) return err } @@ -481,16 +481,6 @@ func (repo *Repository) RepoPath() string { return RepoPath(repo.OwnerName, repo.Name) } -// GitConfigPath returns the path to a repository's git config/ directory -func GitConfigPath(repoPath string) string { - return filepath.Join(repoPath, "config") -} - -// GitConfigPath returns the repository git config path -func (repo *Repository) GitConfigPath() string { - return GitConfigPath(repo.RepoPath()) -} - // Link returns the repository link func (repo *Repository) Link() string { return setting.AppSubURL + "/" + url.PathEscape(repo.OwnerName) + "/" + url.PathEscape(repo.Name) @@ -669,9 +659,10 @@ func GetRepositoryByName(ownerID int64, name string) (*Repository, error) { return repo, err } -func getRepositoryByID(e db.Engine, id int64) (*Repository, error) { +// GetRepositoryByIDCtx returns the repository by given id if exists. +func GetRepositoryByIDCtx(ctx context.Context, id int64) (*Repository, error) { repo := new(Repository) - has, err := e.ID(id).Get(repo) + has, err := db.GetEngine(ctx).ID(id).Get(repo) if err != nil { return nil, err } else if !has { @@ -682,12 +673,7 @@ func getRepositoryByID(e db.Engine, id int64) (*Repository, error) { // GetRepositoryByID returns the repository by given id if exists. func GetRepositoryByID(id int64) (*Repository, error) { - return getRepositoryByID(db.GetEngine(db.DefaultContext), id) -} - -// GetRepositoryByIDCtx returns the repository by given id if exists. -func GetRepositoryByIDCtx(ctx context.Context, id int64) (*Repository, error) { - return getRepositoryByID(db.GetEngine(ctx), id) + return GetRepositoryByIDCtx(db.DefaultContext, id) } // GetRepositoriesMapByIDs returns the repositories by given id slice. @@ -696,8 +682,8 @@ func GetRepositoriesMapByIDs(ids []int64) (map[int64]*Repository, error) { return repos, db.GetEngine(db.DefaultContext).In("id", ids).Find(&repos) } -// IsRepositoryExistCtx returns true if the repository with given name under user has already existed. -func IsRepositoryExistCtx(ctx context.Context, u *user_model.User, repoName string) (bool, error) { +// IsRepositoryExist returns true if the repository with given name under user has already existed. +func IsRepositoryExist(ctx context.Context, u *user_model.User, repoName string) (bool, error) { has, err := db.GetEngine(ctx).Get(&Repository{ OwnerID: u.ID, LowerName: strings.ToLower(repoName), @@ -709,29 +695,20 @@ func IsRepositoryExistCtx(ctx context.Context, u *user_model.User, repoName stri return has && isDir, err } -// IsRepositoryExist returns true if the repository with given name under user has already existed. -func IsRepositoryExist(u *user_model.User, repoName string) (bool, error) { - return IsRepositoryExistCtx(db.DefaultContext, u, repoName) -} - // GetTemplateRepo populates repo.TemplateRepo for a generated repository and // returns an error on failure (NOTE: no error is returned for // non-generated repositories, and TemplateRepo will be left untouched) -func GetTemplateRepo(repo *Repository) (*Repository, error) { - return getTemplateRepo(db.GetEngine(db.DefaultContext), repo) -} - -func getTemplateRepo(e db.Engine, repo *Repository) (*Repository, error) { +func GetTemplateRepo(ctx context.Context, repo *Repository) (*Repository, error) { if !repo.IsGenerated() { return nil, nil } - return getRepositoryByID(e, repo.TemplateID) + return GetRepositoryByIDCtx(ctx, repo.TemplateID) } // TemplateRepo returns the repository, which is template of this repository func (repo *Repository) TemplateRepo() *Repository { - repo, err := GetTemplateRepo(repo) + repo, err := GetTemplateRepo(db.DefaultContext, repo) if err != nil { log.Error("TemplateRepo: %v", err) return nil @@ -739,26 +716,27 @@ func (repo *Repository) TemplateRepo() *Repository { return repo } -func countRepositories(userID int64, private bool) int64 { - sess := db.GetEngine(db.DefaultContext).Where("id > 0") +type CountRepositoryOptions struct { + OwnerID int64 + Private util.OptionalBool +} - if userID > 0 { - sess.And("owner_id = ?", userID) +// CountRepositories returns number of repositories. +// Argument private only takes effect when it is false, +// set it true to count all repositories. +func CountRepositories(ctx context.Context, opts CountRepositoryOptions) (int64, error) { + sess := db.GetEngine(ctx).Where("id > 0") + + if opts.OwnerID > 0 { + sess.And("owner_id = ?", opts.OwnerID) } - if !private { - sess.And("is_private=?", false) + if !opts.Private.IsNone() { + sess.And("is_private=?", opts.Private.IsTrue()) } count, err := sess.Count(new(Repository)) if err != nil { - log.Error("countRepositories: %v", err) + return 0, fmt.Errorf("countRepositories: %v", err) } - return count -} - -// CountRepositories returns number of repositories. -// Argument private only takes effect when it is false, -// set it true to count all repositories. -func CountRepositories(private bool) int64 { - return countRepositories(-1, private) + return count, nil } diff --git a/models/repo/repo_indexer.go b/models/repo/repo_indexer.go index f442cad4d..cba70a14e 100644 --- a/models/repo/repo_indexer.go +++ b/models/repo/repo_indexer.go @@ -5,6 +5,7 @@ package repo import ( + "context" "fmt" "code.gitea.io/gitea/models/db" @@ -62,8 +63,8 @@ func GetUnindexedRepos(indexerType RepoIndexerType, maxRepoID int64, page, pageS return ids, err } -// getIndexerStatus loads repo codes indxer status -func getIndexerStatus(e db.Engine, repo *Repository, indexerType RepoIndexerType) (*RepoIndexerStatus, error) { +// GetIndexerStatus loads repo codes indxer status +func GetIndexerStatus(ctx context.Context, repo *Repository, indexerType RepoIndexerType) (*RepoIndexerStatus, error) { switch indexerType { case RepoIndexerTypeCode: if repo.CodeIndexerStatus != nil { @@ -75,7 +76,7 @@ func getIndexerStatus(e db.Engine, repo *Repository, indexerType RepoIndexerType } } status := &RepoIndexerStatus{RepoID: repo.ID} - if has, err := e.Where("`indexer_type` = ?", indexerType).Get(status); err != nil { + if has, err := db.GetEngine(ctx).Where("`indexer_type` = ?", indexerType).Get(status); err != nil { return nil, err } else if !has { status.IndexerType = indexerType @@ -90,36 +91,25 @@ func getIndexerStatus(e db.Engine, repo *Repository, indexerType RepoIndexerType return status, nil } -// GetIndexerStatus loads repo codes indxer status -func GetIndexerStatus(repo *Repository, indexerType RepoIndexerType) (*RepoIndexerStatus, error) { - return getIndexerStatus(db.GetEngine(db.DefaultContext), repo, indexerType) -} - -// updateIndexerStatus updates indexer status -func updateIndexerStatus(e db.Engine, repo *Repository, indexerType RepoIndexerType, sha string) error { - status, err := getIndexerStatus(e, repo, indexerType) +// UpdateIndexerStatus updates indexer status +func UpdateIndexerStatus(ctx context.Context, repo *Repository, indexerType RepoIndexerType, sha string) error { + status, err := GetIndexerStatus(ctx, repo, indexerType) if err != nil { return fmt.Errorf("UpdateIndexerStatus: Unable to getIndexerStatus for repo: %s Error: %v", repo.FullName(), err) } if len(status.CommitSha) == 0 { status.CommitSha = sha - _, err := e.Insert(status) - if err != nil { + if err := db.Insert(ctx, status); err != nil { return fmt.Errorf("UpdateIndexerStatus: Unable to insert repoIndexerStatus for repo: %s Sha: %s Error: %v", repo.FullName(), sha, err) } return nil } status.CommitSha = sha - _, err = e.ID(status.ID).Cols("commit_sha"). + _, err = db.GetEngine(ctx).ID(status.ID).Cols("commit_sha"). Update(status) if err != nil { return fmt.Errorf("UpdateIndexerStatus: Unable to update repoIndexerStatus for repo: %s Sha: %s Error: %v", repo.FullName(), sha, err) } return nil } - -// UpdateIndexerStatus updates indexer status -func UpdateIndexerStatus(repo *Repository, indexerType RepoIndexerType, sha string) error { - return updateIndexerStatus(db.GetEngine(db.DefaultContext), repo, indexerType, sha) -} diff --git a/models/repo/repo_list.go b/models/repo/repo_list.go index 571604a2c..23cdd6cad 100644 --- a/models/repo/repo_list.go +++ b/models/repo/repo_list.go @@ -22,9 +22,10 @@ func GetUserMirrorRepositories(userID int64) ([]*Repository, error) { func IterateRepository(f func(repo *Repository) error) error { var start int batchSize := setting.Database.IterateBufferSize + sess := db.GetEngine(db.DefaultContext) for { repos := make([]*Repository, 0, batchSize) - if err := db.GetEngine(db.DefaultContext).Limit(batchSize, start).Find(&repos); err != nil { + if err := sess.Limit(batchSize, start).Find(&repos); err != nil { return err } if len(repos) == 0 { diff --git a/models/repo/repo_test.go b/models/repo/repo_test.go index 92b95f1d4..cf6ee8b67 100644 --- a/models/repo/repo_test.go +++ b/models/repo/repo_test.go @@ -9,17 +9,24 @@ import ( "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" - user_model "code.gitea.io/gitea/models/user" + "code.gitea.io/gitea/modules/util" "github.com/stretchr/testify/assert" ) +var ( + countRepospts = CountRepositoryOptions{OwnerID: 10} + countReposptsPublic = CountRepositoryOptions{OwnerID: 10, Private: util.OptionalBoolFalse} + countReposptsPrivate = CountRepositoryOptions{OwnerID: 10, Private: util.OptionalBoolTrue} +) + func TestGetRepositoryCount(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - count, err1 := GetRepositoryCount(db.DefaultContext, 10) - privateCount, err2 := GetPrivateRepositoryCount(&user_model.User{ID: int64(10)}) - publicCount, err3 := GetPublicRepositoryCount(&user_model.User{ID: int64(10)}) + ctx := db.DefaultContext + count, err1 := CountRepositories(ctx, countRepospts) + privateCount, err2 := CountRepositories(ctx, countReposptsPrivate) + publicCount, err3 := CountRepositories(ctx, countReposptsPublic) assert.NoError(t, err1) assert.NoError(t, err2) assert.NoError(t, err3) @@ -30,7 +37,7 @@ func TestGetRepositoryCount(t *testing.T) { func TestGetPublicRepositoryCount(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - count, err := GetPublicRepositoryCount(&user_model.User{ID: int64(10)}) + count, err := CountRepositories(db.DefaultContext, countReposptsPublic) assert.NoError(t, err) assert.Equal(t, int64(1), count) } @@ -38,7 +45,7 @@ func TestGetPublicRepositoryCount(t *testing.T) { func TestGetPrivateRepositoryCount(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - count, err := GetPrivateRepositoryCount(&user_model.User{ID: int64(10)}) + count, err := CountRepositories(db.DefaultContext, countReposptsPrivate) assert.NoError(t, err) assert.Equal(t, int64(2), count) } diff --git a/models/repo/repo_unit.go b/models/repo/repo_unit.go index de79eb1c9..a73678c6e 100644 --- a/models/repo/repo_unit.go +++ b/models/repo/repo_unit.go @@ -5,6 +5,7 @@ package repo import ( + "context" "fmt" "code.gitea.io/gitea/models/db" @@ -206,9 +207,9 @@ func (r *RepoUnit) ExternalTrackerConfig() *ExternalTrackerConfig { return r.Config.(*ExternalTrackerConfig) } -func getUnitsByRepoID(e db.Engine, repoID int64) (units []*RepoUnit, err error) { +func getUnitsByRepoID(ctx context.Context, repoID int64) (units []*RepoUnit, err error) { var tmpUnits []*RepoUnit - if err := e.Where("repo_id = ?", repoID).Find(&tmpUnits); err != nil { + if err := db.GetEngine(ctx).Where("repo_id = ?", repoID).Find(&tmpUnits); err != nil { return nil, err } diff --git a/models/repo/star.go b/models/repo/star.go index 8db297e3b..113b56f59 100644 --- a/models/repo/star.go +++ b/models/repo/star.go @@ -5,6 +5,8 @@ package repo import ( + "context" + "code.gitea.io/gitea/models/db" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/timeutil" @@ -29,7 +31,7 @@ func StarRepo(userID, repoID int64, star bool) error { return err } defer committer.Close() - staring := isStaring(db.GetEngine(ctx), userID, repoID) + staring := IsStaring(ctx, userID, repoID) if star { if staring { @@ -65,12 +67,8 @@ func StarRepo(userID, repoID int64, star bool) error { } // IsStaring checks if user has starred given repository. -func IsStaring(userID, repoID int64) bool { - return isStaring(db.GetEngine(db.DefaultContext), userID, repoID) -} - -func isStaring(e db.Engine, userID, repoID int64) bool { - has, _ := e.Get(&Star{UID: userID, RepoID: repoID}) +func IsStaring(ctx context.Context, userID, repoID int64) bool { + has, _ := db.GetEngine(ctx).Get(&Star{UID: userID, RepoID: repoID}) return has } diff --git a/models/repo/star_test.go b/models/repo/star_test.go index 20c4b6bef..2dde09c74 100644 --- a/models/repo/star_test.go +++ b/models/repo/star_test.go @@ -28,8 +28,8 @@ func TestStarRepo(t *testing.T) { func TestIsStaring(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.True(t, IsStaring(2, 4)) - assert.False(t, IsStaring(3, 4)) + assert.True(t, IsStaring(db.DefaultContext, 2, 4)) + assert.False(t, IsStaring(db.DefaultContext, 3, 4)) } func TestRepository_GetStargazers(t *testing.T) { diff --git a/models/repo/topic.go b/models/repo/topic.go index 121863519..2a1646721 100644 --- a/models/repo/topic.go +++ b/models/repo/topic.go @@ -99,8 +99,9 @@ func GetTopicByName(name string) (*Topic, error) { // addTopicByNameToRepo adds a topic name to a repo and increments the topic count. // Returns topic after the addition -func addTopicByNameToRepo(e db.Engine, repoID int64, topicName string) (*Topic, error) { +func addTopicByNameToRepo(ctx context.Context, repoID int64, topicName string) (*Topic, error) { var topic Topic + e := db.GetEngine(ctx) has, err := e.Where("name = ?", topicName).Get(&topic) if err != nil { return nil, err @@ -108,7 +109,7 @@ func addTopicByNameToRepo(e db.Engine, repoID int64, topicName string) (*Topic, if !has { topic.Name = topicName topic.RepoCount = 1 - if _, err := e.Insert(&topic); err != nil { + if err := db.Insert(ctx, &topic); err != nil { return nil, err } } else { @@ -118,7 +119,7 @@ func addTopicByNameToRepo(e db.Engine, repoID int64, topicName string) (*Topic, } } - if _, err := e.Insert(&RepoTopic{ + if err := db.Insert(ctx, &RepoTopic{ RepoID: repoID, TopicID: topic.ID, }); err != nil { @@ -129,8 +130,9 @@ func addTopicByNameToRepo(e db.Engine, repoID int64, topicName string) (*Topic, } // removeTopicFromRepo remove a topic from a repo and decrements the topic repo count -func removeTopicFromRepo(e db.Engine, repoID int64, topic *Topic) error { +func removeTopicFromRepo(ctx context.Context, repoID int64, topic *Topic) error { topic.RepoCount-- + e := db.GetEngine(ctx) if _, err := e.ID(topic.ID).Cols("repo_count").Update(topic); err != nil { return err } @@ -208,17 +210,13 @@ func CountTopics(opts *FindTopicOptions) (int64, error) { } // GetRepoTopicByName retrieves topic from name for a repo if it exist -func GetRepoTopicByName(repoID int64, topicName string) (*Topic, error) { - return getRepoTopicByName(db.GetEngine(db.DefaultContext), repoID, topicName) -} - -func getRepoTopicByName(e db.Engine, repoID int64, topicName string) (*Topic, error) { +func GetRepoTopicByName(ctx context.Context, repoID int64, topicName string) (*Topic, error) { cond := builder.NewCond() var topic Topic cond = cond.And(builder.Eq{"repo_topic.repo_id": repoID}).And(builder.Eq{"topic.name": topicName}) - sess := e.Table("topic").Where(cond) + sess := db.GetEngine(ctx).Table("topic").Where(cond) sess.Join("INNER", "repo_topic", "repo_topic.topic_id = topic.id") - has, err := sess.Get(&topic) + has, err := sess.Select("topic.*").Get(&topic) if has { return &topic, err } @@ -234,7 +232,7 @@ func AddTopic(repoID int64, topicName string) (*Topic, error) { defer committer.Close() sess := db.GetEngine(ctx) - topic, err := getRepoTopicByName(sess, repoID, topicName) + topic, err := GetRepoTopicByName(ctx, repoID, topicName) if err != nil { return nil, err } @@ -243,7 +241,7 @@ func AddTopic(repoID int64, topicName string) (*Topic, error) { return topic, nil } - topic, err = addTopicByNameToRepo(sess, repoID, topicName) + topic, err = addTopicByNameToRepo(ctx, repoID, topicName) if err != nil { return nil, err } @@ -266,7 +264,7 @@ func AddTopic(repoID int64, topicName string) (*Topic, error) { // DeleteTopic removes a topic name from a repository (if it has it) func DeleteTopic(repoID int64, topicName string) (*Topic, error) { - topic, err := GetRepoTopicByName(repoID, topicName) + topic, err := GetRepoTopicByName(db.DefaultContext, repoID, topicName) if err != nil { return nil, err } @@ -275,7 +273,7 @@ func DeleteTopic(repoID int64, topicName string) (*Topic, error) { return nil, nil } - err = removeTopicFromRepo(db.GetEngine(db.DefaultContext), repoID, topic) + err = removeTopicFromRepo(db.DefaultContext, repoID, topic) return topic, err } @@ -329,14 +327,14 @@ func SaveTopics(repoID int64, topicNames ...string) error { } for _, topicName := range addedTopicNames { - _, err := addTopicByNameToRepo(sess, repoID, topicName) + _, err := addTopicByNameToRepo(ctx, repoID, topicName) if err != nil { return err } } for _, topic := range removeTopics { - err := removeTopicFromRepo(sess, repoID, topic) + err := removeTopicFromRepo(ctx, repoID, topic) if err != nil { return err } @@ -361,7 +359,7 @@ func SaveTopics(repoID int64, topicNames ...string) error { // GenerateTopics generates topics from a template repository func GenerateTopics(ctx context.Context, templateRepo, generateRepo *Repository) error { for _, topic := range templateRepo.Topics { - if _, err := addTopicByNameToRepo(db.GetEngine(ctx), generateRepo.ID, topic); err != nil { + if _, err := addTopicByNameToRepo(ctx, generateRepo.ID, topic); err != nil { return err } } diff --git a/models/repo/update.go b/models/repo/update.go index efc562a40..7fb51c959 100644 --- a/models/repo/update.go +++ b/models/repo/update.go @@ -42,17 +42,12 @@ func UpdateRepositoryUpdatedTime(repoID int64, updateTime time.Time) error { return err } -// UpdateRepositoryColsCtx updates repository's columns -func UpdateRepositoryColsCtx(ctx context.Context, repo *Repository, cols ...string) error { +// UpdateRepositoryCols updates repository's columns +func UpdateRepositoryCols(ctx context.Context, repo *Repository, cols ...string) error { _, err := db.GetEngine(ctx).ID(repo.ID).Cols(cols...).Update(repo) return err } -// UpdateRepositoryCols updates repository's columns -func UpdateRepositoryCols(repo *Repository, cols ...string) error { - return UpdateRepositoryColsCtx(db.DefaultContext, repo, cols...) -} - // ErrReachLimitOfRepo represents a "ReachLimitOfRepo" kind of error. type ErrReachLimitOfRepo struct { Limit int @@ -110,7 +105,7 @@ func CheckCreateRepository(doer, u *user_model.User, name string, overwriteOrAdo return err } - has, err := IsRepositoryExist(u, name) + has, err := IsRepositoryExist(db.DefaultContext, u, name) if err != nil { return fmt.Errorf("IsRepositoryExist: %v", err) } else if has { @@ -141,7 +136,7 @@ func ChangeRepositoryName(doer *user_model.User, repo *Repository, newRepoName s return err } - has, err := IsRepositoryExist(repo.Owner, newRepoName) + has, err := IsRepositoryExist(db.DefaultContext, repo.Owner, newRepoName) if err != nil { return fmt.Errorf("IsRepositoryExist: %v", err) } else if has { diff --git a/models/repo/user_repo.go b/models/repo/user_repo.go index 18a04f726..fe9677179 100644 --- a/models/repo/user_repo.go +++ b/models/repo/user_repo.go @@ -5,10 +5,7 @@ package repo import ( - "context" - "code.gitea.io/gitea/models/db" - user_model "code.gitea.io/gitea/models/user" ) // GetStarredRepos returns the repos starred by a particular user @@ -51,37 +48,3 @@ func GetWatchedRepos(userID int64, private bool, listOptions db.ListOptions) ([] total, err := sess.FindAndCount(&repos) return repos, total, err } - -// CountUserRepositories returns number of repositories user owns. -// Argument private only takes effect when it is false, -// set it true to count all repositories. -func CountUserRepositories(userID int64, private bool) int64 { - return countRepositories(userID, private) -} - -func getRepositoryCount(e db.Engine, ownerID int64) (int64, error) { - return e.Count(&Repository{OwnerID: ownerID}) -} - -func getPublicRepositoryCount(e db.Engine, u *user_model.User) (int64, error) { - return e.Where("is_private = ?", false).Count(&Repository{OwnerID: u.ID}) -} - -func getPrivateRepositoryCount(e db.Engine, u *user_model.User) (int64, error) { - return e.Where("is_private = ?", true).Count(&Repository{OwnerID: u.ID}) -} - -// GetRepositoryCount returns the total number of repositories of user. -func GetRepositoryCount(ctx context.Context, ownerID int64) (int64, error) { - return getRepositoryCount(db.GetEngine(ctx), ownerID) -} - -// GetPublicRepositoryCount returns the total number of public repositories of user. -func GetPublicRepositoryCount(u *user_model.User) (int64, error) { - return getPublicRepositoryCount(db.GetEngine(db.DefaultContext), u) -} - -// GetPrivateRepositoryCount returns the total number of private repositories of user. -func GetPrivateRepositoryCount(u *user_model.User) (int64, error) { - return getPrivateRepositoryCount(db.GetEngine(db.DefaultContext), u) -} diff --git a/models/repo/watch.go b/models/repo/watch.go index 8e54f0970..ecc25ee32 100644 --- a/models/repo/watch.go +++ b/models/repo/watch.go @@ -116,8 +116,8 @@ func WatchRepoMode(userID, repoID int64, mode WatchMode) (err error) { return watchRepoMode(db.DefaultContext, watch, mode) } -// WatchRepoCtx watch or unwatch repository. -func WatchRepoCtx(ctx context.Context, userID, repoID int64, doWatch bool) (err error) { +// WatchRepo watch or unwatch repository. +func WatchRepo(ctx context.Context, userID, repoID int64, doWatch bool) (err error) { var watch Watch if watch, err = GetWatch(ctx, userID, repoID); err != nil { return err @@ -132,11 +132,6 @@ func WatchRepoCtx(ctx context.Context, userID, repoID int64, doWatch bool) (err return err } -// WatchRepo watch or unwatch repository. -func WatchRepo(userID, repoID int64, watch bool) (err error) { - return WatchRepoCtx(db.DefaultContext, userID, repoID, watch) -} - // GetWatchers returns all watchers of given repository. func GetWatchers(ctx context.Context, repoID int64) ([]*Watch, error) { watches := make([]*Watch, 0, 10) @@ -176,7 +171,8 @@ func GetRepoWatchers(repoID int64, opts db.ListOptions) ([]*user_model.User, err return users, sess.Find(&users) } -func watchIfAuto(ctx context.Context, userID, repoID int64, isWrite bool) error { +// WatchIfAuto subscribes to repo if AutoWatchOnChanges is set +func WatchIfAuto(ctx context.Context, userID, repoID int64, isWrite bool) error { if !isWrite || !setting.Service.AutoWatchOnChanges { return nil } @@ -189,8 +185,3 @@ func watchIfAuto(ctx context.Context, userID, repoID int64, isWrite bool) error } return watchRepoMode(ctx, watch, WatchModeAuto) } - -// WatchIfAuto subscribes to repo if AutoWatchOnChanges is set -func WatchIfAuto(userID, repoID int64, isWrite bool) error { - return watchIfAuto(db.DefaultContext, userID, repoID, isWrite) -} diff --git a/models/repo/watch_test.go b/models/repo/watch_test.go index 2ff3ced2d..2f4e04ab1 100644 --- a/models/repo/watch_test.go +++ b/models/repo/watch_test.go @@ -73,13 +73,13 @@ func TestWatchIfAuto(t *testing.T) { prevCount := repo.NumWatches // Must not add watch - assert.NoError(t, WatchIfAuto(8, 1, true)) + assert.NoError(t, WatchIfAuto(db.DefaultContext, 8, 1, true)) watchers, err = GetRepoWatchers(repo.ID, db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, prevCount) // Should not add watch - assert.NoError(t, WatchIfAuto(10, 1, true)) + assert.NoError(t, WatchIfAuto(db.DefaultContext, 10, 1, true)) watchers, err = GetRepoWatchers(repo.ID, db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, prevCount) @@ -87,31 +87,31 @@ func TestWatchIfAuto(t *testing.T) { setting.Service.AutoWatchOnChanges = true // Must not add watch - assert.NoError(t, WatchIfAuto(8, 1, true)) + assert.NoError(t, WatchIfAuto(db.DefaultContext, 8, 1, true)) watchers, err = GetRepoWatchers(repo.ID, db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, prevCount) // Should not add watch - assert.NoError(t, WatchIfAuto(12, 1, false)) + assert.NoError(t, WatchIfAuto(db.DefaultContext, 12, 1, false)) watchers, err = GetRepoWatchers(repo.ID, db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, prevCount) // Should add watch - assert.NoError(t, WatchIfAuto(12, 1, true)) + assert.NoError(t, WatchIfAuto(db.DefaultContext, 12, 1, true)) watchers, err = GetRepoWatchers(repo.ID, db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, prevCount+1) // Should remove watch, inhibit from adding auto - assert.NoError(t, WatchRepo(12, 1, false)) + assert.NoError(t, WatchRepo(db.DefaultContext, 12, 1, false)) watchers, err = GetRepoWatchers(repo.ID, db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, prevCount) // Must not add watch - assert.NoError(t, WatchIfAuto(12, 1, true)) + assert.NoError(t, WatchIfAuto(db.DefaultContext, 12, 1, true)) watchers, err = GetRepoWatchers(repo.ID, db.ListOptions{Page: 1}) assert.NoError(t, err) assert.Len(t, watchers, prevCount) diff --git a/models/repo_collaboration.go b/models/repo_collaboration.go index 2069ce8cc..e20e96815 100644 --- a/models/repo_collaboration.go +++ b/models/repo_collaboration.go @@ -76,7 +76,7 @@ func DeleteCollaboration(repo *repo_model.Repository, uid int64) (err error) { return err } - if err = repo_model.WatchRepoCtx(ctx, uid, repo.ID, false); err != nil { + if err = repo_model.WatchRepo(ctx, uid, repo.ID, false); err != nil { return err } @@ -93,7 +93,7 @@ func DeleteCollaboration(repo *repo_model.Repository, uid int64) (err error) { } func reconsiderRepoIssuesAssignee(ctx context.Context, repo *repo_model.Repository, uid int64) error { - user, err := user_model.GetUserByIDEngine(db.GetEngine(ctx), uid) + user, err := user_model.GetUserByIDCtx(ctx, uid) if err != nil { return err } @@ -114,12 +114,12 @@ func reconsiderWatches(ctx context.Context, repo *repo_model.Repository, uid int if has, err := access_model.HasAccess(ctx, uid, repo); err != nil || has { return err } - if err := repo_model.WatchRepoCtx(ctx, uid, repo.ID, false); err != nil { + if err := repo_model.WatchRepo(ctx, uid, repo.ID, false); err != nil { return err } // Remove all IssueWatches a user has subscribed to in the repository - return removeIssueWatchersByRepoID(db.GetEngine(ctx), uid, repo.ID) + return removeIssueWatchersByRepoID(ctx, uid, repo.ID) } // IsOwnerMemberCollaborator checks if a provided user is the owner, a collaborator or a member of a team in a repository diff --git a/models/repo_generate.go b/models/repo_generate.go index 7d6d262aa..6b720b496 100644 --- a/models/repo_generate.go +++ b/models/repo_generate.go @@ -70,7 +70,7 @@ func (gt GiteaTemplate) Globs() []glob.Glob { // GenerateWebhooks generates webhooks from a template repository func GenerateWebhooks(ctx context.Context, templateRepo, generateRepo *repo_model.Repository) error { - templateWebhooks, err := webhook.ListWebhooksByOpts(&webhook.ListWebhookOptions{RepoID: templateRepo.ID}) + templateWebhooks, err := webhook.ListWebhooksByOpts(ctx, &webhook.ListWebhookOptions{RepoID: templateRepo.ID}) if err != nil { return err } @@ -98,7 +98,7 @@ func GenerateWebhooks(ctx context.Context, templateRepo, generateRepo *repo_mode // GenerateIssueLabels generates issue labels from a template repository func GenerateIssueLabels(ctx context.Context, templateRepo, generateRepo *repo_model.Repository) error { - templateLabels, err := getLabelsByRepoID(db.GetEngine(ctx), templateRepo.ID, "", db.ListOptions{}) + templateLabels, err := GetLabelsByRepoID(ctx, templateRepo.ID, "", db.ListOptions{}) if err != nil { return err } diff --git a/models/repo_list.go b/models/repo_list.go index 4b76cbc08..d1974d77e 100644 --- a/models/repo_list.go +++ b/models/repo_list.go @@ -5,6 +5,7 @@ package models import ( + "context" "fmt" "strings" @@ -57,7 +58,7 @@ func RepositoryListOfMap(repoMap map[int64]*repo_model.Repository) RepositoryLis return RepositoryList(valuesRepository(repoMap)) } -func (repos RepositoryList) loadAttributes(e db.Engine) error { +func (repos RepositoryList) loadAttributes(ctx context.Context) error { if len(repos) == 0 { return nil } @@ -71,7 +72,7 @@ func (repos RepositoryList) loadAttributes(e db.Engine) error { // Load owners. users := make(map[int64]*user_model.User, len(set)) - if err := e. + if err := db.GetEngine(ctx). Where("id > 0"). In("id", container.KeysInt64(set)). Find(&users); err != nil { @@ -83,7 +84,7 @@ func (repos RepositoryList) loadAttributes(e db.Engine) error { // Load primary language. stats := make(repo_model.LanguageStatList, 0, len(repos)) - if err := e. + if err := db.GetEngine(ctx). Where("`is_primary` = ? AND `language` != ?", true, "other"). In("`repo_id`", repoIDs). Find(&stats); err != nil { @@ -104,7 +105,7 @@ func (repos RepositoryList) loadAttributes(e db.Engine) error { // LoadAttributes loads the attributes for the given RepositoryList func (repos RepositoryList) LoadAttributes() error { - return repos.loadAttributes(db.GetEngine(db.DefaultContext)) + return repos.loadAttributes(db.DefaultContext) } // SearchRepoOptions holds the search options @@ -509,7 +510,8 @@ func SearchRepository(opts *SearchRepoOptions) (RepositoryList, int64, error) { // SearchRepositoryByCondition search repositories by condition func SearchRepositoryByCondition(opts *SearchRepoOptions, cond builder.Cond, loadAttributes bool) (RepositoryList, int64, error) { - sess, count, err := searchRepositoryByCondition(opts, cond) + ctx := db.DefaultContext + sess, count, err := searchRepositoryByCondition(ctx, opts, cond) if err != nil { return nil, 0, err } @@ -528,7 +530,7 @@ func SearchRepositoryByCondition(opts *SearchRepoOptions, cond builder.Cond, loa } if loadAttributes { - if err := repos.loadAttributes(sess); err != nil { + if err := repos.loadAttributes(ctx); err != nil { return nil, 0, fmt.Errorf("LoadAttributes: %v", err) } } @@ -536,7 +538,7 @@ func SearchRepositoryByCondition(opts *SearchRepoOptions, cond builder.Cond, loa return repos, count, nil } -func searchRepositoryByCondition(opts *SearchRepoOptions, cond builder.Cond) (db.Engine, int64, error) { +func searchRepositoryByCondition(ctx context.Context, opts *SearchRepoOptions, cond builder.Cond) (db.Engine, int64, error) { if opts.Page <= 0 { opts.Page = 1 } @@ -549,7 +551,7 @@ func searchRepositoryByCondition(opts *SearchRepoOptions, cond builder.Cond) (db opts.OrderBy = db.SearchOrderBy(fmt.Sprintf("CASE WHEN owner_id = %d THEN 0 ELSE owner_id END, %s", opts.PriorityOwnerID, opts.OrderBy)) } - sess := db.GetEngine(db.DefaultContext) + sess := db.GetEngine(ctx) var count int64 if opts.PageSize > 0 { @@ -619,7 +621,7 @@ func SearchRepositoryIDs(opts *SearchRepoOptions) ([]int64, int64, error) { cond := SearchRepositoryCondition(opts) - sess, count, err := searchRepositoryByCondition(opts, cond) + sess, count, err := searchRepositoryByCondition(db.DefaultContext, opts, cond) if err != nil { return nil, 0, err } diff --git a/models/repo_test.go b/models/repo_test.go index a93d84b81..dd1673f6b 100644 --- a/models/repo_test.go +++ b/models/repo_test.go @@ -27,11 +27,11 @@ func TestWatchRepo(t *testing.T) { const repoID = 3 const userID = 2 - assert.NoError(t, repo_model.WatchRepo(userID, repoID, true)) + assert.NoError(t, repo_model.WatchRepo(db.DefaultContext, userID, repoID, true)) unittest.AssertExistsAndLoadBean(t, &repo_model.Watch{RepoID: repoID, UserID: userID}) unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: repoID}) - assert.NoError(t, repo_model.WatchRepo(userID, repoID, false)) + assert.NoError(t, repo_model.WatchRepo(db.DefaultContext, userID, repoID, false)) unittest.AssertNotExistsBean(t, &repo_model.Watch{RepoID: repoID, UserID: userID}) unittest.CheckConsistencyFor(t, &repo_model.Repository{ID: repoID}) } @@ -179,7 +179,7 @@ func TestLinkedRepository(t *testing.T) { } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - attach, err := repo_model.GetAttachmentByID(tc.attachID) + attach, err := repo_model.GetAttachmentByID(db.DefaultContext, tc.attachID) assert.NoError(t, err) repo, unitType, err := LinkedRepository(attach) assert.NoError(t, err) diff --git a/models/repo_transfer.go b/models/repo_transfer.go index b283bc8c7..79cfc699c 100644 --- a/models/repo_transfer.go +++ b/models/repo_transfer.go @@ -51,7 +51,7 @@ func (r *RepoTransfer) LoadAttributes() error { if r.Recipient.IsOrganization() && len(r.TeamIDs) != len(r.Teams) { for _, v := range r.TeamIDs { - team, err := organization.GetTeamByID(v) + team, err := organization.GetTeamByID(db.DefaultContext, v) if err != nil { return err } @@ -130,7 +130,7 @@ func CancelRepositoryTransfer(repo *repo_model.Repository) error { defer committer.Close() repo.Status = repo_model.RepositoryReady - if err := repo_model.UpdateRepositoryColsCtx(ctx, repo, "status"); err != nil { + if err := repo_model.UpdateRepositoryCols(ctx, repo, "status"); err != nil { return err } @@ -172,12 +172,12 @@ func CreatePendingRepositoryTransfer(doer, newOwner *user_model.User, repoID int } repo.Status = repo_model.RepositoryPendingTransfer - if err := repo_model.UpdateRepositoryColsCtx(ctx, repo, "status"); err != nil { + if err := repo_model.UpdateRepositoryCols(ctx, repo, "status"); err != nil { return err } // Check if new owner has repository with same name. - if has, err := repo_model.IsRepositoryExistCtx(ctx, newOwner, repo.Name); err != nil { + if has, err := repo_model.IsRepositoryExist(ctx, newOwner, repo.Name); err != nil { return fmt.Errorf("IsRepositoryExist: %v", err) } else if has { return repo_model.ErrRepoAlreadyExist{ @@ -250,14 +250,14 @@ func TransferOwnership(doer *user_model.User, newOwnerName string, repo *repo_mo sess := db.GetEngine(ctx) - newOwner, err := user_model.GetUserByNameCtx(ctx, newOwnerName) + newOwner, err := user_model.GetUserByName(ctx, newOwnerName) if err != nil { return fmt.Errorf("get new owner '%s': %v", newOwnerName, err) } newOwnerName = newOwner.Name // ensure capitalisation matches // Check if new owner has repository with same name. - if has, err := repo_model.IsRepositoryExistCtx(ctx, newOwner, repo.Name); err != nil { + if has, err := repo_model.IsRepositoryExist(ctx, newOwner, repo.Name); err != nil { return fmt.Errorf("IsRepositoryExist: %v", err) } else if has { return repo_model.ErrRepoAlreadyExist{ @@ -343,13 +343,13 @@ func TransferOwnership(doer *user_model.User, newOwnerName string, repo *repo_mo return fmt.Errorf("decrease old owner repository count: %v", err) } - if err := repo_model.WatchRepoCtx(ctx, doer.ID, repo.ID, true); err != nil { + if err := repo_model.WatchRepo(ctx, doer.ID, repo.ID, true); err != nil { return fmt.Errorf("watchRepo: %v", err) } // Remove watch for organization. if oldOwner.IsOrganization() { - if err := repo_model.WatchRepoCtx(ctx, oldOwner.ID, repo.ID, false); err != nil { + if err := repo_model.WatchRepo(ctx, oldOwner.ID, repo.ID, false); err != nil { return fmt.Errorf("watchRepo [false]: %v", err) } } @@ -410,7 +410,7 @@ func TransferOwnership(doer *user_model.User, newOwnerName string, repo *repo_mo return fmt.Errorf("deleteRepositoryTransfer: %v", err) } repo.Status = repo_model.RepositoryReady - if err := repo_model.UpdateRepositoryColsCtx(ctx, repo, "status"); err != nil { + if err := repo_model.UpdateRepositoryCols(ctx, repo, "status"); err != nil { return err } diff --git a/models/review.go b/models/review.go index 8917ea714..296c9ce04 100644 --- a/models/review.go +++ b/models/review.go @@ -93,26 +93,26 @@ func (r *Review) LoadCodeComments(ctx context.Context) (err error) { if r.CodeComments != nil { return } - if err = r.loadIssue(db.GetEngine(ctx)); err != nil { + if err = r.loadIssue(ctx); err != nil { return } r.CodeComments, err = fetchCodeCommentsByReview(ctx, r.Issue, nil, r) return } -func (r *Review) loadIssue(e db.Engine) (err error) { +func (r *Review) loadIssue(ctx context.Context) (err error) { if r.Issue != nil { return } - r.Issue, err = getIssueByID(e, r.IssueID) + r.Issue, err = getIssueByID(ctx, r.IssueID) return } -func (r *Review) loadReviewer(e db.Engine) (err error) { +func (r *Review) loadReviewer(ctx context.Context) (err error) { if r.ReviewerID == 0 || r.Reviewer != nil { return } - r.Reviewer, err = user_model.GetUserByIDEngine(e, r.ReviewerID) + r.Reviewer, err = user_model.GetUserByIDCtx(ctx, r.ReviewerID) return } @@ -121,13 +121,13 @@ func (r *Review) loadReviewerTeam(ctx context.Context) (err error) { return } - r.ReviewerTeam, err = organization.GetTeamByIDCtx(ctx, r.ReviewerTeamID) + r.ReviewerTeam, err = organization.GetTeamByID(ctx, r.ReviewerTeamID) return } // LoadReviewer loads reviewer func (r *Review) LoadReviewer() error { - return r.loadReviewer(db.GetEngine(db.DefaultContext)) + return r.loadReviewer(db.DefaultContext) } // LoadReviewerTeam loads reviewer team @@ -137,14 +137,13 @@ func (r *Review) LoadReviewerTeam() error { // LoadAttributes loads all attributes except CodeComments func (r *Review) LoadAttributes(ctx context.Context) (err error) { - e := db.GetEngine(ctx) - if err = r.loadIssue(e); err != nil { + if err = r.loadIssue(ctx); err != nil { return } if err = r.LoadCodeComments(ctx); err != nil { return } - if err = r.loadReviewer(e); err != nil { + if err = r.loadReviewer(ctx); err != nil { return } if err = r.loadReviewerTeam(ctx); err != nil { @@ -153,9 +152,10 @@ func (r *Review) LoadAttributes(ctx context.Context) (err error) { return } -func getReviewByID(e db.Engine, id int64) (*Review, error) { +// GetReviewByID returns the review by the given ID +func GetReviewByID(ctx context.Context, id int64) (*Review, error) { review := new(Review) - if has, err := e.ID(id).Get(review); err != nil { + if has, err := db.GetEngine(ctx).ID(id).Get(review); err != nil { return nil, err } else if !has { return nil, ErrReviewNotExist{ID: id} @@ -164,11 +164,6 @@ func getReviewByID(e db.Engine, id int64) (*Review, error) { } } -// GetReviewByID returns the review by the given ID -func GetReviewByID(id int64) (*Review, error) { - return getReviewByID(db.GetEngine(db.DefaultContext), id) -} - // FindReviewOptions represent possible filters to find reviews type FindReviewOptions struct { db.ListOptions @@ -195,9 +190,10 @@ func (opts *FindReviewOptions) toCond() builder.Cond { return cond } -func findReviews(e db.Engine, opts FindReviewOptions) ([]*Review, error) { +// FindReviews returns reviews passing FindReviewOptions +func FindReviews(ctx context.Context, opts FindReviewOptions) ([]*Review, error) { reviews := make([]*Review, 0, 10) - sess := e.Where(opts.toCond()) + sess := db.GetEngine(ctx).Where(opts.toCond()) if opts.Page > 0 { sess = db.SetSessionPagination(sess, &opts) } @@ -207,11 +203,6 @@ func findReviews(e db.Engine, opts FindReviewOptions) ([]*Review, error) { Find(&reviews) } -// FindReviews returns reviews passing FindReviewOptions -func FindReviews(opts FindReviewOptions) ([]*Review, error) { - return findReviews(db.GetEngine(db.DefaultContext), opts) -} - // CountReviews returns count of reviews passing FindReviewOptions func CountReviews(opts FindReviewOptions) (int64, error) { return db.GetEngine(db.DefaultContext).Where(opts.toCond()).Count(&Review{}) @@ -230,12 +221,8 @@ type CreateReviewOptions struct { } // IsOfficialReviewer check if at least one of the provided reviewers can make official reviews in issue (counts towards required approvals) -func IsOfficialReviewer(issue *Issue, reviewers ...*user_model.User) (bool, error) { - return isOfficialReviewer(db.DefaultContext, issue, reviewers...) -} - -func isOfficialReviewer(ctx context.Context, issue *Issue, reviewers ...*user_model.User) (bool, error) { - pr, err := getPullRequestByIssueID(db.GetEngine(ctx), issue.ID) +func IsOfficialReviewer(ctx context.Context, issue *Issue, reviewers ...*user_model.User) (bool, error) { + pr, err := GetPullRequestByIssueID(ctx, issue.ID) if err != nil { return false, err } @@ -257,12 +244,8 @@ func isOfficialReviewer(ctx context.Context, issue *Issue, reviewers ...*user_mo } // IsOfficialReviewerTeam check if reviewer in this team can make official reviews in issue (counts towards required approvals) -func IsOfficialReviewerTeam(issue *Issue, team *organization.Team) (bool, error) { - return isOfficialReviewerTeam(db.DefaultContext, issue, team) -} - -func isOfficialReviewerTeam(ctx context.Context, issue *Issue, team *organization.Team) (bool, error) { - pr, err := getPullRequestByIssueID(db.GetEngine(ctx), issue.ID) +func IsOfficialReviewerTeam(ctx context.Context, issue *Issue, team *organization.Team) (bool, error) { + pr, err := GetPullRequestByIssueID(ctx, issue.ID) if err != nil { return false, err } @@ -280,7 +263,8 @@ func isOfficialReviewerTeam(ctx context.Context, issue *Issue, team *organizatio return base.Int64sContains(pr.ProtectedBranch.ApprovalsWhitelistTeamIDs, team.ID), nil } -func createReview(e db.Engine, opts CreateReviewOptions) (*Review, error) { +// CreateReview creates a new review based on opts +func CreateReview(ctx context.Context, opts CreateReviewOptions) (*Review, error) { review := &Review{ Type: opts.Type, Issue: opts.Issue, @@ -300,23 +284,15 @@ func createReview(e db.Engine, opts CreateReviewOptions) (*Review, error) { } review.ReviewerTeamID = opts.ReviewerTeam.ID } - if _, err := e.Insert(review); err != nil { - return nil, err - } - - return review, nil -} - -// CreateReview creates a new review based on opts -func CreateReview(opts CreateReviewOptions) (*Review, error) { - return createReview(db.GetEngine(db.DefaultContext), opts) + return review, db.Insert(ctx, review) } -func getCurrentReview(e db.Engine, reviewer *user_model.User, issue *Issue) (*Review, error) { +// GetCurrentReview returns the current pending review of reviewer for given issue +func GetCurrentReview(ctx context.Context, reviewer *user_model.User, issue *Issue) (*Review, error) { if reviewer == nil { return nil, nil } - reviews, err := findReviews(e, FindReviewOptions{ + reviews, err := FindReviews(ctx, FindReviewOptions{ Type: ReviewTypePending, IssueID: issue.ID, ReviewerID: reviewer.ID, @@ -337,11 +313,6 @@ func ReviewExists(issue *Issue, treePath string, line int64) (bool, error) { return db.GetEngine(db.DefaultContext).Cols("id").Exist(&Comment{IssueID: issue.ID, TreePath: treePath, Line: line, Type: CommentTypeCode}) } -// GetCurrentReview returns the current pending review of reviewer for given issue -func GetCurrentReview(reviewer *user_model.User, issue *Issue) (*Review, error) { - return getCurrentReview(db.GetEngine(db.DefaultContext), reviewer, issue) -} - // ContentEmptyErr represents an content empty error type ContentEmptyErr struct{} @@ -366,7 +337,7 @@ func SubmitReview(doer *user_model.User, issue *Issue, reviewType ReviewType, co official := false - review, err := getCurrentReview(sess, doer, issue) + review, err := GetCurrentReview(ctx, doer, issue) if err != nil { if !IsErrReviewNotExist(err) { return nil, nil, err @@ -378,16 +349,16 @@ func SubmitReview(doer *user_model.User, issue *Issue, reviewType ReviewType, co if reviewType == ReviewTypeApprove || reviewType == ReviewTypeReject { // Only reviewers latest review of type approve and reject shall count as "official", so existing reviews needs to be cleared - if _, err := sess.Exec("UPDATE `review` SET official=? WHERE issue_id=? AND reviewer_id=?", false, issue.ID, doer.ID); err != nil { + if _, err := db.Exec(ctx, "UPDATE `review` SET official=? WHERE issue_id=? AND reviewer_id=?", false, issue.ID, doer.ID); err != nil { return nil, nil, err } - if official, err = isOfficialReviewer(ctx, issue, doer); err != nil { + if official, err = IsOfficialReviewer(ctx, issue, doer); err != nil { return nil, nil, err } } // No current review. Create a new one! - if review, err = createReview(sess, CreateReviewOptions{ + if review, err = CreateReview(ctx, CreateReviewOptions{ Type: reviewType, Issue: issue, Reviewer: doer, @@ -408,10 +379,10 @@ func SubmitReview(doer *user_model.User, issue *Issue, reviewType ReviewType, co if reviewType == ReviewTypeApprove || reviewType == ReviewTypeReject { // Only reviewers latest review of type approve and reject shall count as "official", so existing reviews needs to be cleared - if _, err := sess.Exec("UPDATE `review` SET official=? WHERE issue_id=? AND reviewer_id=?", false, issue.ID, doer.ID); err != nil { + if _, err := db.Exec(ctx, "UPDATE `review` SET official=? WHERE issue_id=? AND reviewer_id=?", false, issue.ID, doer.ID); err != nil { return nil, nil, err } - if official, err = isOfficialReviewer(ctx, issue, doer); err != nil { + if official, err = IsOfficialReviewer(ctx, issue, doer); err != nil { return nil, nil, err } } @@ -456,7 +427,7 @@ func SubmitReview(doer *user_model.User, issue *Issue, reviewType ReviewType, co continue } - if _, err := sess.Delete(teamReviewRequest); err != nil { + if _, err := sess.ID(teamReviewRequest.ID).NoAutoCondition().Delete(teamReviewRequest); err != nil { return nil, nil, err } } @@ -508,14 +479,10 @@ func GetReviewersFromOriginalAuthorsByIssueID(issueID int64) ([]*Review, error) } // GetReviewByIssueIDAndUserID get the latest review of reviewer for a pull request -func GetReviewByIssueIDAndUserID(issueID, userID int64) (*Review, error) { - return getReviewByIssueIDAndUserID(db.GetEngine(db.DefaultContext), issueID, userID) -} - -func getReviewByIssueIDAndUserID(e db.Engine, issueID, userID int64) (*Review, error) { +func GetReviewByIssueIDAndUserID(ctx context.Context, issueID, userID int64) (*Review, error) { review := new(Review) - has, err := e.SQL("SELECT * FROM review WHERE id IN (SELECT max(id) as id FROM review WHERE issue_id = ? AND reviewer_id = ? AND original_author_id = 0 AND type in (?, ?, ?))", + has, err := db.GetEngine(ctx).SQL("SELECT * FROM review WHERE id IN (SELECT max(id) as id FROM review WHERE issue_id = ? AND reviewer_id = ? AND original_author_id = 0 AND type in (?, ?, ?))", issueID, userID, ReviewTypeApprove, ReviewTypeReject, ReviewTypeRequest). Get(review) if err != nil { @@ -530,15 +497,11 @@ func getReviewByIssueIDAndUserID(e db.Engine, issueID, userID int64) (*Review, e } // GetTeamReviewerByIssueIDAndTeamID get the latest review request of reviewer team for a pull request -func GetTeamReviewerByIssueIDAndTeamID(issueID, teamID int64) (review *Review, err error) { - return getTeamReviewerByIssueIDAndTeamID(db.GetEngine(db.DefaultContext), issueID, teamID) -} - -func getTeamReviewerByIssueIDAndTeamID(e db.Engine, issueID, teamID int64) (review *Review, err error) { +func GetTeamReviewerByIssueIDAndTeamID(ctx context.Context, issueID, teamID int64) (review *Review, err error) { review = new(Review) has := false - if has, err = e.SQL("SELECT * FROM review WHERE id IN (SELECT max(id) as id FROM review WHERE issue_id = ? AND reviewer_team_id = ?)", + if has, err = db.GetEngine(ctx).SQL("SELECT * FROM review WHERE id IN (SELECT max(id) as id FROM review WHERE issue_id = ? AND reviewer_team_id = ?)", issueID, teamID). Get(review); err != nil { return nil, err @@ -633,7 +596,7 @@ func AddReviewRequest(issue *Issue, reviewer, doer *user_model.User) (*Comment, defer committer.Close() sess := db.GetEngine(ctx) - review, err := getReviewByIssueIDAndUserID(sess, issue.ID, reviewer.ID) + review, err := GetReviewByIssueIDAndUserID(ctx, issue.ID, reviewer.ID) if err != nil && !IsErrReviewNotExist(err) { return nil, err } @@ -643,7 +606,7 @@ func AddReviewRequest(issue *Issue, reviewer, doer *user_model.User) (*Comment, return nil, nil } - official, err := isOfficialReviewer(ctx, issue, reviewer, doer) + official, err := IsOfficialReviewer(ctx, issue, reviewer, doer) if err != nil { return nil, err } else if official { @@ -652,7 +615,7 @@ func AddReviewRequest(issue *Issue, reviewer, doer *user_model.User) (*Comment, } } - review, err = createReview(sess, CreateReviewOptions{ + review, err = CreateReview(ctx, CreateReviewOptions{ Type: ReviewTypeRequest, Issue: issue, Reviewer: reviewer, @@ -686,9 +649,8 @@ func RemoveReviewRequest(issue *Issue, reviewer, doer *user_model.User) (*Commen return nil, err } defer committer.Close() - sess := db.GetEngine(ctx) - review, err := getReviewByIssueIDAndUserID(sess, issue.ID, reviewer.ID) + review, err := GetReviewByIssueIDAndUserID(ctx, issue.ID, reviewer.ID) if err != nil && !IsErrReviewNotExist(err) { return nil, err } @@ -697,22 +659,22 @@ func RemoveReviewRequest(issue *Issue, reviewer, doer *user_model.User) (*Commen return nil, nil } - if _, err = sess.Delete(review); err != nil { + if _, err = db.DeleteByBean(ctx, review); err != nil { return nil, err } - official, err := isOfficialReviewer(ctx, issue, reviewer) + official, err := IsOfficialReviewer(ctx, issue, reviewer) if err != nil { return nil, err } else if official { // recalculate the latest official review for reviewer - review, err := getReviewByIssueIDAndUserID(sess, issue.ID, reviewer.ID) + review, err := GetReviewByIssueIDAndUserID(ctx, issue.ID, reviewer.ID) if err != nil && !IsErrReviewNotExist(err) { return nil, err } if review != nil { - if _, err := sess.Exec("UPDATE `review` SET official=? WHERE id=?", true, review.ID); err != nil { + if _, err := db.Exec(ctx, "UPDATE `review` SET official=? WHERE id=?", true, review.ID); err != nil { return nil, err } } @@ -740,9 +702,8 @@ func AddTeamReviewRequest(issue *Issue, reviewer *organization.Team, doer *user_ return nil, err } defer committer.Close() - sess := db.GetEngine(ctx) - review, err := getTeamReviewerByIssueIDAndTeamID(sess, issue.ID, reviewer.ID) + review, err := GetTeamReviewerByIssueIDAndTeamID(ctx, issue.ID, reviewer.ID) if err != nil && !IsErrReviewNotExist(err) { return nil, err } @@ -752,16 +713,16 @@ func AddTeamReviewRequest(issue *Issue, reviewer *organization.Team, doer *user_ return nil, nil } - official, err := isOfficialReviewerTeam(ctx, issue, reviewer) + official, err := IsOfficialReviewerTeam(ctx, issue, reviewer) if err != nil { return nil, fmt.Errorf("isOfficialReviewerTeam(): %v", err) } else if !official { - if official, err = isOfficialReviewer(ctx, issue, doer); err != nil { + if official, err = IsOfficialReviewer(ctx, issue, doer); err != nil { return nil, fmt.Errorf("isOfficialReviewer(): %v", err) } } - if review, err = createReview(sess, CreateReviewOptions{ + if review, err = CreateReview(ctx, CreateReviewOptions{ Type: ReviewTypeRequest, Issue: issue, ReviewerTeam: reviewer, @@ -772,7 +733,7 @@ func AddTeamReviewRequest(issue *Issue, reviewer *organization.Team, doer *user_ } if official { - if _, err := sess.Exec("UPDATE `review` SET official=? WHERE issue_id=? AND reviewer_team_id=?", false, issue.ID, reviewer.ID); err != nil { + if _, err := db.Exec(ctx, "UPDATE `review` SET official=? WHERE issue_id=? AND reviewer_team_id=?", false, issue.ID, reviewer.ID); err != nil { return nil, err } } @@ -800,9 +761,8 @@ func RemoveTeamReviewRequest(issue *Issue, reviewer *organization.Team, doer *us return nil, err } defer committer.Close() - sess := db.GetEngine(ctx) - review, err := getTeamReviewerByIssueIDAndTeamID(sess, issue.ID, reviewer.ID) + review, err := GetTeamReviewerByIssueIDAndTeamID(ctx, issue.ID, reviewer.ID) if err != nil && !IsErrReviewNotExist(err) { return nil, err } @@ -811,24 +771,24 @@ func RemoveTeamReviewRequest(issue *Issue, reviewer *organization.Team, doer *us return nil, nil } - if _, err = sess.Delete(review); err != nil { + if _, err = db.DeleteByBean(ctx, review); err != nil { return nil, err } - official, err := isOfficialReviewerTeam(ctx, issue, reviewer) + official, err := IsOfficialReviewerTeam(ctx, issue, reviewer) if err != nil { return nil, fmt.Errorf("isOfficialReviewerTeam(): %v", err) } if official { // recalculate which is the latest official review from that team - review, err := getReviewByIssueIDAndUserID(sess, issue.ID, -reviewer.ID) + review, err := GetReviewByIssueIDAndUserID(ctx, issue.ID, -reviewer.ID) if err != nil && !IsErrReviewNotExist(err) { return nil, err } if review != nil { - if _, err := sess.Exec("UPDATE `review` SET official=? WHERE id=?", true, review.ID); err != nil { + if _, err := db.Exec(ctx, "UPDATE `review` SET official=? WHERE id=?", true, review.ID); err != nil { return nil, err } } @@ -899,7 +859,7 @@ func CanMarkConversation(issue *Issue, doer *user_model.User) (permResult bool, permResult = p.CanAccess(perm.AccessModeWrite, unit.TypePullRequests) if !permResult { - if permResult, err = IsOfficialReviewer(issue, doer); err != nil { + if permResult, err = IsOfficialReviewer(db.DefaultContext, issue, doer); err != nil { return false, err } } diff --git a/models/review_test.go b/models/review_test.go index a4a71cc70..93291f9f5 100644 --- a/models/review_test.go +++ b/models/review_test.go @@ -16,12 +16,12 @@ import ( func TestGetReviewByID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - review, err := GetReviewByID(1) + review, err := GetReviewByID(db.DefaultContext, 1) assert.NoError(t, err) assert.Equal(t, "Demo Review", review.Content) assert.Equal(t, ReviewTypeApprove, review.Type) - _, err = GetReviewByID(23892) + _, err = GetReviewByID(db.DefaultContext, 23892) assert.Error(t, err) assert.True(t, IsErrReviewNotExist(err), "IsErrReviewNotExist") } @@ -61,7 +61,7 @@ func TestReviewType_Icon(t *testing.T) { func TestFindReviews(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - reviews, err := FindReviews(FindReviewOptions{ + reviews, err := FindReviews(db.DefaultContext, FindReviewOptions{ Type: ReviewTypeApprove, IssueID: 2, ReviewerID: 1, @@ -76,14 +76,14 @@ func TestGetCurrentReview(t *testing.T) { issue := unittest.AssertExistsAndLoadBean(t, &Issue{ID: 2}).(*Issue) user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1}).(*user_model.User) - review, err := GetCurrentReview(user, issue) + review, err := GetCurrentReview(db.DefaultContext, user, issue) assert.NoError(t, err) assert.NotNil(t, review) assert.Equal(t, ReviewTypePending, review.Type) assert.Equal(t, "Pending Review", review.Content) user2 := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 7}).(*user_model.User) - review2, err := GetCurrentReview(user2, issue) + review2, err := GetCurrentReview(db.DefaultContext, user2, issue) assert.Error(t, err) assert.True(t, IsErrReviewNotExist(err)) assert.Nil(t, review2) @@ -95,7 +95,7 @@ func TestCreateReview(t *testing.T) { issue := unittest.AssertExistsAndLoadBean(t, &Issue{ID: 2}).(*Issue) user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1}).(*user_model.User) - review, err := CreateReview(CreateReviewOptions{ + review, err := CreateReview(db.DefaultContext, CreateReviewOptions{ Content: "New Review", Type: ReviewTypePending, Issue: issue, diff --git a/models/statistic.go b/models/statistic.go index 0f6359fb3..2b6b3a182 100644 --- a/models/statistic.go +++ b/models/statistic.go @@ -51,9 +51,9 @@ type IssueByRepositoryCount struct { func GetStatistic() (stats Statistic) { e := db.GetEngine(db.DefaultContext) stats.Counter.User = user_model.CountUsers(nil) - stats.Counter.Org = organization.CountOrganizations() + stats.Counter.Org, _ = organization.CountOrgs(organization.FindOrgOptions{IncludePrivate: true}) stats.Counter.PublicKey, _ = e.Count(new(asymkey_model.PublicKey)) - stats.Counter.Repo = repo_model.CountRepositories(true) + stats.Counter.Repo, _ = repo_model.CountRepositories(db.DefaultContext, repo_model.CountRepositoryOptions{}) stats.Counter.Watch, _ = e.Count(new(repo_model.Watch)) stats.Counter.Star, _ = e.Count(new(repo_model.Star)) stats.Counter.Action, _ = e.Count(new(Action)) diff --git a/models/task.go b/models/task.go index 5528573ca..cabb96c60 100644 --- a/models/task.go +++ b/models/task.go @@ -5,6 +5,7 @@ package models import ( + "context" "fmt" "code.gitea.io/gitea/models/db" @@ -51,15 +52,15 @@ type TranslatableMessage struct { // LoadRepo loads repository of the task func (task *Task) LoadRepo() error { - return task.loadRepo(db.GetEngine(db.DefaultContext)) + return task.loadRepo(db.DefaultContext) } -func (task *Task) loadRepo(e db.Engine) error { +func (task *Task) loadRepo(ctx context.Context) error { if task.Repo != nil { return nil } var repo repo_model.Repository - has, err := e.ID(task.RepoID).Get(&repo) + has, err := db.GetEngine(ctx).ID(task.RepoID).Get(&repo) if err != nil { return err } else if !has { @@ -233,12 +234,7 @@ func FindTasks(opts FindTaskOptions) ([]*Task, error) { // CreateTask creates a task on database func CreateTask(task *Task) error { - return createTask(db.GetEngine(db.DefaultContext), task) -} - -func createTask(e db.Engine, task *Task) error { - _, err := e.Insert(task) - return err + return db.Insert(db.DefaultContext, task) } // FinishMigrateTask updates database when migrate task finished diff --git a/models/user.go b/models/user.go index 6816527e4..e8a412ca2 100644 --- a/models/user.go +++ b/models/user.go @@ -166,7 +166,7 @@ func DeleteUser(ctx context.Context, u *user_model.User) (err error) { // ***** END: Branch Protections ***** // ***** START: PublicKey ***** - if _, err = e.Delete(&asymkey_model.PublicKey{OwnerID: u.ID}); err != nil { + if _, err = db.DeleteByBean(ctx, &asymkey_model.PublicKey{OwnerID: u.ID}); err != nil { return fmt.Errorf("deletePublicKeys: %v", err) } // ***** END: PublicKey ***** @@ -178,17 +178,17 @@ func DeleteUser(ctx context.Context, u *user_model.User) (err error) { } // Delete GPGKeyImport(s). for _, key := range keys { - if _, err = e.Delete(&asymkey_model.GPGKeyImport{KeyID: key.KeyID}); err != nil { + if _, err = db.DeleteByBean(ctx, &asymkey_model.GPGKeyImport{KeyID: key.KeyID}); err != nil { return fmt.Errorf("deleteGPGKeyImports: %v", err) } } - if _, err = e.Delete(&asymkey_model.GPGKey{OwnerID: u.ID}); err != nil { + if _, err = db.DeleteByBean(ctx, &asymkey_model.GPGKey{OwnerID: u.ID}); err != nil { return fmt.Errorf("deleteGPGKeys: %v", err) } // ***** END: GPGPublicKey ***** // Clear assignee. - if err = clearAssigneeByUserID(e, u.ID); err != nil { + if _, err = db.DeleteByBean(ctx, &IssueAssignees{AssigneeID: u.ID}); err != nil { return fmt.Errorf("clear assignee: %v", err) } diff --git a/models/user/avatar.go b/models/user/avatar.go index c881642b5..6a44a3bcb 100644 --- a/models/user/avatar.go +++ b/models/user/avatar.go @@ -26,12 +26,7 @@ func (u *User) CustomAvatarRelativePath() string { } // GenerateRandomAvatar generates a random avatar for user. -func GenerateRandomAvatar(u *User) error { - return GenerateRandomAvatarCtx(db.DefaultContext, u) -} - -// GenerateRandomAvatarCtx generates a random avatar for user. -func GenerateRandomAvatarCtx(ctx context.Context, u *User) error { +func GenerateRandomAvatar(ctx context.Context, u *User) error { seed := u.Email if len(seed) == 0 { seed = u.Name @@ -82,7 +77,7 @@ func (u *User) AvatarLinkWithSize(size int) string { if useLocalAvatar { if u.Avatar == "" && autoGenerateAvatar { - if err := GenerateRandomAvatar(u); err != nil { + if err := GenerateRandomAvatar(db.DefaultContext, u); err != nil { log.Error("GenerateRandomAvatar: %v", err) } } diff --git a/models/user/email_address.go b/models/user/email_address.go index 564d018da..c931db9c1 100644 --- a/models/user/email_address.go +++ b/models/user/email_address.go @@ -207,7 +207,8 @@ func IsEmailUsed(ctx context.Context, email string) (bool, error) { return db.GetEngine(ctx).Where("lower_email=?", strings.ToLower(email)).Get(&EmailAddress{}) } -func addEmailAddress(ctx context.Context, email *EmailAddress) error { +// AddEmailAddress adds an email address to given user. +func AddEmailAddress(ctx context.Context, email *EmailAddress) error { email.Email = strings.TrimSpace(email.Email) used, err := IsEmailUsed(ctx, email.Email) if err != nil { @@ -223,11 +224,6 @@ func addEmailAddress(ctx context.Context, email *EmailAddress) error { return db.Insert(ctx, email) } -// AddEmailAddress adds an email address to given user. -func AddEmailAddress(email *EmailAddress) error { - return addEmailAddress(db.DefaultContext, email) -} - // AddEmailAddresses adds an email address to given user. func AddEmailAddresses(emails []*EmailAddress) error { if len(emails) == 0 { @@ -311,14 +307,14 @@ func ActivateEmail(email *EmailAddress) error { return err } defer committer.Close() - if err := updateActivation(db.GetEngine(ctx), email, true); err != nil { + if err := updateActivation(ctx, email, true); err != nil { return err } return committer.Commit() } -func updateActivation(e db.Engine, email *EmailAddress, activate bool) error { - user, err := GetUserByIDEngine(e, email.UID) +func updateActivation(ctx context.Context, email *EmailAddress, activate bool) error { + user, err := GetUserByIDCtx(ctx, email.UID) if err != nil { return err } @@ -326,10 +322,10 @@ func updateActivation(e db.Engine, email *EmailAddress, activate bool) error { return err } email.IsActivated = activate - if _, err := e.ID(email.ID).Cols("is_activated").Update(email); err != nil { + if _, err := db.GetEngine(ctx).ID(email.ID).Cols("is_activated").Update(email); err != nil { return err } - return UpdateUserColsEngine(e, user, "rands") + return UpdateUserCols(ctx, user, "rands") } // MakeEmailPrimary sets primary email address of given user. @@ -500,12 +496,11 @@ func ActivateUserEmail(userID int64, email string, activate bool) (err error) { return err } defer committer.Close() - sess := db.GetEngine(ctx) // Activate/deactivate a user's secondary email address // First check if there's another user active with the same address addr := EmailAddress{UID: userID, LowerEmail: strings.ToLower(email)} - if has, err := sess.Get(&addr); err != nil { + if has, err := db.GetByBean(ctx, &addr); err != nil { return err } else if !has { return fmt.Errorf("no such email: %d (%s)", userID, email) @@ -521,14 +516,14 @@ func ActivateUserEmail(userID int64, email string, activate bool) (err error) { return ErrEmailAlreadyUsed{Email: email} } } - if err = updateActivation(sess, &addr, activate); err != nil { + if err = updateActivation(ctx, &addr, activate); err != nil { return fmt.Errorf("unable to updateActivation() for %d:%s: %w", addr.ID, addr.Email, err) } // Activate/deactivate a user's primary email address and account if addr.IsPrimary { user := User{ID: userID, Email: email} - if has, err := sess.Get(&user); err != nil { + if has, err := db.GetByBean(ctx, &user); err != nil { return err } else if !has { return fmt.Errorf("no user with ID: %d and Email: %s", userID, email) @@ -539,7 +534,7 @@ func ActivateUserEmail(userID int64, email string, activate bool) (err error) { if user.Rands, err = GetUserSalt(); err != nil { return fmt.Errorf("unable to generate salt: %v", err) } - if err = UpdateUserColsEngine(sess, &user, "is_active", "rands"); err != nil { + if err = UpdateUserCols(ctx, &user, "is_active", "rands"); err != nil { return fmt.Errorf("unable to updateUserCols() for user ID: %d: %v", userID, err) } } diff --git a/models/user/email_address_test.go b/models/user/email_address_test.go index 7eeb469b2..79de4c0b4 100644 --- a/models/user/email_address_test.go +++ b/models/user/email_address_test.go @@ -45,7 +45,7 @@ func TestIsEmailUsed(t *testing.T) { func TestAddEmailAddress(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - assert.NoError(t, AddEmailAddress(&EmailAddress{ + assert.NoError(t, AddEmailAddress(db.DefaultContext, &EmailAddress{ Email: "user1234567890@example.com", LowerEmail: "user1234567890@example.com", IsPrimary: true, @@ -53,7 +53,7 @@ func TestAddEmailAddress(t *testing.T) { })) // ErrEmailAlreadyUsed - err := AddEmailAddress(&EmailAddress{ + err := AddEmailAddress(db.DefaultContext, &EmailAddress{ Email: "user1234567890@example.com", LowerEmail: "user1234567890@example.com", }) diff --git a/models/user/list.go b/models/user/list.go index 5cdc92ba4..68e62ca15 100644 --- a/models/user/list.go +++ b/models/user/list.go @@ -5,6 +5,7 @@ package user import ( + "context" "fmt" "code.gitea.io/gitea/models/auth" @@ -31,13 +32,13 @@ func (users UserList) GetTwoFaStatus() map[int64]bool { results[user.ID] = false // Set default to false } - if tokenMaps, err := users.loadTwoFactorStatus(db.GetEngine(db.DefaultContext)); err == nil { + if tokenMaps, err := users.loadTwoFactorStatus(db.DefaultContext); err == nil { for _, token := range tokenMaps { results[token.UID] = true } } - if ids, err := users.userIDsWithWebAuthn(db.GetEngine(db.DefaultContext)); err == nil { + if ids, err := users.userIDsWithWebAuthn(db.DefaultContext); err == nil { for _, id := range ids { results[id] = true } @@ -46,25 +47,25 @@ func (users UserList) GetTwoFaStatus() map[int64]bool { return results } -func (users UserList) loadTwoFactorStatus(e db.Engine) (map[int64]*auth.TwoFactor, error) { +func (users UserList) loadTwoFactorStatus(ctx context.Context) (map[int64]*auth.TwoFactor, error) { if len(users) == 0 { return nil, nil } userIDs := users.GetUserIDs() tokenMaps := make(map[int64]*auth.TwoFactor, len(userIDs)) - if err := e.In("uid", userIDs).Find(&tokenMaps); err != nil { + if err := db.GetEngine(ctx).In("uid", userIDs).Find(&tokenMaps); err != nil { return nil, fmt.Errorf("find two factor: %v", err) } return tokenMaps, nil } -func (users UserList) userIDsWithWebAuthn(e db.Engine) ([]int64, error) { +func (users UserList) userIDsWithWebAuthn(ctx context.Context) ([]int64, error) { if len(users) == 0 { return nil, nil } ids := make([]int64, 0, len(users)) - if err := e.Table(new(auth.WebAuthnCredential)).In("user_id", users.GetUserIDs()).Select("user_id").Distinct("user_id").Find(&ids); err != nil { + if err := db.GetEngine(ctx).Table(new(auth.WebAuthnCredential)).In("user_id", users.GetUserIDs()).Select("user_id").Distinct("user_id").Find(&ids); err != nil { return nil, fmt.Errorf("find two factor: %v", err) } return ids, nil diff --git a/models/user/openid.go b/models/user/openid.go index 8ca3c7f2c..8ef0ce5ed 100644 --- a/models/user/openid.go +++ b/models/user/openid.go @@ -5,6 +5,7 @@ package user import ( + "context" "errors" "fmt" @@ -41,12 +42,12 @@ func GetUserOpenIDs(uid int64) ([]*UserOpenID, error) { } // isOpenIDUsed returns true if the openid has been used. -func isOpenIDUsed(e db.Engine, uri string) (bool, error) { +func isOpenIDUsed(ctx context.Context, uri string) (bool, error) { if len(uri) == 0 { return true, nil } - return e.Get(&UserOpenID{URI: uri}) + return db.GetEngine(ctx).Get(&UserOpenID{URI: uri}) } // ErrOpenIDAlreadyUsed represents a "OpenIDAlreadyUsed" kind of error. @@ -64,22 +65,17 @@ func (err ErrOpenIDAlreadyUsed) Error() string { return fmt.Sprintf("OpenID already in use [oid: %s]", err.OpenID) } +// AddUserOpenID adds an pre-verified/normalized OpenID URI to given user. // NOTE: make sure openid.URI is normalized already -func addUserOpenID(e db.Engine, openid *UserOpenID) error { - used, err := isOpenIDUsed(e, openid.URI) +func AddUserOpenID(ctx context.Context, openid *UserOpenID) error { + used, err := isOpenIDUsed(ctx, openid.URI) if err != nil { return err } else if used { return ErrOpenIDAlreadyUsed{openid.URI} } - _, err = e.Insert(openid) - return err -} - -// AddUserOpenID adds an pre-verified/normalized OpenID URI to given user. -func AddUserOpenID(openid *UserOpenID) error { - return addUserOpenID(db.GetEngine(db.DefaultContext), openid) + return db.Insert(ctx, openid) } // DeleteUserOpenID deletes an openid address of given user. diff --git a/models/user/user.go b/models/user/user.go index 6aa63a0a5..f7d457b91 100644 --- a/models/user/user.go +++ b/models/user/user.go @@ -509,23 +509,19 @@ func SetEmailNotifications(u *User, set string) error { return nil } -func isUserExist(e db.Engine, uid int64, name string) (bool, error) { +// IsUserExist checks if given user name exist, +// the user name should be noncased unique. +// If uid is presented, then check will rule out that one, +// it is used when update a user name in settings page. +func IsUserExist(ctx context.Context, uid int64, name string) (bool, error) { if len(name) == 0 { return false, nil } - return e. + return db.GetEngine(ctx). Where("id!=?", uid). Get(&User{LowerName: strings.ToLower(name)}) } -// IsUserExist checks if given user name exist, -// the user name should be noncased unique. -// If uid is presented, then check will rule out that one, -// it is used when update a user name in settings page. -func IsUserExist(uid int64, name string) (bool, error) { - return isUserExist(db.GetEngine(db.DefaultContext), uid, name) -} - // Note: As of the beginning of 2022, it is recommended to use at least // 64 bits of salt, but NIST is already recommending to use to 128 bits. // (16 bytes = 16 * 8 = 128 bits) @@ -691,9 +687,7 @@ func CreateUser(u *User, overwriteDefault ...*CreateUserOverwriteOptions) (err e } defer committer.Close() - sess := db.GetEngine(ctx) - - isExist, err := isUserExist(sess, 0, u.Name) + isExist, err := IsUserExist(ctx, 0, u.Name) if err != nil { return err } else if isExist { @@ -774,7 +768,7 @@ func GetVerifyUser(code string) (user *User) { // use tail hex username query user hexStr := code[base.TimeLimitCodeLength:] if b, err := hex.DecodeString(hexStr); err == nil { - if user, err = GetUserByName(string(b)); user != nil { + if user, err = GetUserByName(db.DefaultContext, string(b)); user != nil { return user } log.Error("user.getVerifyUser: %v", err) @@ -811,16 +805,15 @@ func ChangeUserName(u *User, newUserName string) (err error) { return err } defer committer.Close() - sess := db.GetEngine(ctx) - isExist, err := isUserExist(sess, 0, newUserName) + isExist, err := IsUserExist(ctx, 0, newUserName) if err != nil { return err } else if isExist { return ErrUserAlreadyExist{newUserName} } - if _, err = sess.Exec("UPDATE `repository` SET owner_name=? WHERE owner_name=?", newUserName, oldUserName); err != nil { + if _, err = db.GetEngine(ctx).Exec("UPDATE `repository` SET owner_name=? WHERE owner_name=?", newUserName, oldUserName); err != nil { return fmt.Errorf("Change repo owner name: %v", err) } @@ -845,9 +838,9 @@ func ChangeUserName(u *User, newUserName string) (err error) { } // checkDupEmail checks whether there are the same email with the user -func checkDupEmail(e db.Engine, u *User) error { +func checkDupEmail(ctx context.Context, u *User) error { u.Email = strings.ToLower(u.Email) - has, err := e. + has, err := db.GetEngine(ctx). Where("id!=?", u.ID). And("type=?", u.Type). And("email=?", u.Email). @@ -872,7 +865,8 @@ func validateUser(u *User) error { return ValidateEmail(u.Email) } -func updateUser(ctx context.Context, u *User, changePrimaryEmail bool, cols ...string) error { +// UpdateUser updates user's information. +func UpdateUser(ctx context.Context, u *User, changePrimaryEmail bool, cols ...string) error { err := validateUser(u) if err != nil { return err @@ -932,27 +926,13 @@ func updateUser(ctx context.Context, u *User, changePrimaryEmail bool, cols ...s return err } -// UpdateUser updates user's information. -func UpdateUser(u *User, emailChanged bool, cols ...string) error { - return updateUser(db.DefaultContext, u, emailChanged, cols...) -} - // UpdateUserCols update user according special columns func UpdateUserCols(ctx context.Context, u *User, cols ...string) error { - return updateUserCols(db.GetEngine(ctx), u, cols...) -} - -// UpdateUserColsEngine update user according special columns -func UpdateUserColsEngine(e db.Engine, u *User, cols ...string) error { - return updateUserCols(e, u, cols...) -} - -func updateUserCols(e db.Engine, u *User, cols ...string) error { if err := validateUser(u); err != nil { return err } - _, err := e.ID(u.ID).Cols(cols...).Update(u) + _, err := db.GetEngine(ctx).ID(u.ID).Cols(cols...).Update(u) return err } @@ -965,11 +945,11 @@ func UpdateUserSetting(u *User) (err error) { defer committer.Close() if !u.IsOrganization() { - if err = checkDupEmail(db.GetEngine(ctx), u); err != nil { + if err = checkDupEmail(ctx, u); err != nil { return err } } - if err = updateUser(ctx, u, false); err != nil { + if err = UpdateUser(ctx, u, false); err != nil { return err } return committer.Commit() @@ -994,18 +974,6 @@ func UserPath(userName string) string { //revive:disable-line:exported return filepath.Join(setting.RepoRootPath, strings.ToLower(userName)) } -// GetUserByIDEngine returns the user object by given ID if exists. -func GetUserByIDEngine(e db.Engine, id int64) (*User, error) { - u := new(User) - has, err := e.ID(id).Get(u) - if err != nil { - return nil, err - } else if !has { - return nil, ErrUserNotExist{id, "", 0} - } - return u, nil -} - // GetUserByID returns the user object by given ID if exists. func GetUserByID(id int64) (*User, error) { return GetUserByIDCtx(db.DefaultContext, id) @@ -1013,16 +981,18 @@ func GetUserByID(id int64) (*User, error) { // GetUserByIDCtx returns the user object by given ID if exists. func GetUserByIDCtx(ctx context.Context, id int64) (*User, error) { - return GetUserByIDEngine(db.GetEngine(ctx), id) -} - -// GetUserByName returns user by given name. -func GetUserByName(name string) (*User, error) { - return GetUserByNameCtx(db.DefaultContext, name) + u := new(User) + has, err := db.GetEngine(ctx).ID(id).Get(u) + if err != nil { + return nil, err + } else if !has { + return nil, ErrUserNotExist{id, "", 0} + } + return u, nil } // GetUserByNameCtx returns user by given name. -func GetUserByNameCtx(ctx context.Context, name string) (*User, error) { +func GetUserByName(ctx context.Context, name string) (*User, error) { if len(name) == 0 { return nil, ErrUserNotExist{0, name, 0} } @@ -1038,14 +1008,10 @@ func GetUserByNameCtx(ctx context.Context, name string) (*User, error) { // GetUserEmailsByNames returns a list of e-mails corresponds to names of users // that have their email notifications set to enabled or onmention. -func GetUserEmailsByNames(names []string) []string { - return getUserEmailsByNames(db.DefaultContext, names) -} - -func getUserEmailsByNames(ctx context.Context, names []string) []string { +func GetUserEmailsByNames(ctx context.Context, names []string) []string { mails := make([]string, 0, len(names)) for _, name := range names { - u, err := GetUserByNameCtx(ctx, name) + u, err := GetUserByName(ctx, name) if err != nil { continue } @@ -1108,7 +1074,7 @@ func GetUserNameByID(ctx context.Context, id int64) (string, error) { func GetUserIDsByNames(names []string, ignoreNonExistent bool) ([]int64, error) { ids := make([]int64, 0, len(names)) for _, name := range names { - u, err := GetUserByName(name) + u, err := GetUserByName(db.DefaultContext, name) if err != nil { if ignoreNonExistent { continue @@ -1254,11 +1220,7 @@ func GetAdminUser() (*User, error) { } // IsUserVisibleToViewer check if viewer is able to see user profile -func IsUserVisibleToViewer(u, viewer *User) bool { - return isUserVisibleToViewer(db.GetEngine(db.DefaultContext), u, viewer) -} - -func isUserVisibleToViewer(e db.Engine, u, viewer *User) bool { +func IsUserVisibleToViewer(ctx context.Context, u, viewer *User) bool { if viewer != nil && viewer.IsAdmin { return true } @@ -1283,7 +1245,7 @@ func isUserVisibleToViewer(e db.Engine, u, viewer *User) bool { } // Now we need to check if they in some organization together - count, err := e.Table("team_user"). + count, err := db.GetEngine(ctx).Table("team_user"). Where( builder.And( builder.Eq{"uid": viewer.ID}, diff --git a/models/user/user_test.go b/models/user/user_test.go index 335537aa1..0dbf2fc20 100644 --- a/models/user/user_test.go +++ b/models/user/user_test.go @@ -31,10 +31,10 @@ func TestGetUserEmailsByNames(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) // ignore none active user email - assert.Equal(t, []string{"user8@example.com"}, GetUserEmailsByNames([]string{"user8", "user9"})) - assert.Equal(t, []string{"user8@example.com", "user5@example.com"}, GetUserEmailsByNames([]string{"user8", "user5"})) + assert.Equal(t, []string{"user8@example.com"}, GetUserEmailsByNames(db.DefaultContext, []string{"user8", "user9"})) + assert.Equal(t, []string{"user8@example.com", "user5@example.com"}, GetUserEmailsByNames(db.DefaultContext, []string{"user8", "user5"})) - assert.Equal(t, []string{"user8@example.com"}, GetUserEmailsByNames([]string{"user8", "user7"})) + assert.Equal(t, []string{"user8@example.com"}, GetUserEmailsByNames(db.DefaultContext, []string{"user8", "user7"})) } func TestCanCreateOrganization(t *testing.T) { @@ -287,19 +287,19 @@ func TestUpdateUser(t *testing.T) { user := unittest.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User) user.KeepActivityPrivate = true - assert.NoError(t, UpdateUser(user, false)) + assert.NoError(t, UpdateUser(db.DefaultContext, user, false)) user = unittest.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User) assert.True(t, user.KeepActivityPrivate) setting.Service.AllowedUserVisibilityModesSlice = []bool{true, false, false} user.KeepActivityPrivate = false user.Visibility = structs.VisibleTypePrivate - assert.Error(t, UpdateUser(user, false)) + assert.Error(t, UpdateUser(db.DefaultContext, user, false)) user = unittest.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User) assert.True(t, user.KeepActivityPrivate) user.Email = "no mail@mail.org" - assert.Error(t, UpdateUser(user, true)) + assert.Error(t, UpdateUser(db.DefaultContext, user, true)) } func TestNewUserRedirect(t *testing.T) { diff --git a/models/webhook/webhook.go b/models/webhook/webhook.go index 941a3f15c..5eea97772 100644 --- a/models/webhook/webhook.go +++ b/models/webhook/webhook.go @@ -454,8 +454,9 @@ func (opts *ListWebhookOptions) toCond() builder.Cond { return cond } -func listWebhooksByOpts(e db.Engine, opts *ListWebhookOptions) ([]*Webhook, error) { - sess := e.Where(opts.toCond()) +// ListWebhooksByOpts return webhooks based on options +func ListWebhooksByOpts(ctx context.Context, opts *ListWebhookOptions) ([]*Webhook, error) { + sess := db.GetEngine(ctx).Where(opts.toCond()) if opts.Page != 0 { sess = db.SetSessionPagination(sess, opts) @@ -469,22 +470,13 @@ func listWebhooksByOpts(e db.Engine, opts *ListWebhookOptions) ([]*Webhook, erro return webhooks, err } -// ListWebhooksByOpts return webhooks based on options -func ListWebhooksByOpts(opts *ListWebhookOptions) ([]*Webhook, error) { - return listWebhooksByOpts(db.GetEngine(db.DefaultContext), opts) -} - // CountWebhooksByOpts count webhooks based on options and ignore pagination func CountWebhooksByOpts(opts *ListWebhookOptions) (int64, error) { return db.GetEngine(db.DefaultContext).Where(opts.toCond()).Count(&Webhook{}) } // GetDefaultWebhooks returns all admin-default webhooks. -func GetDefaultWebhooks() ([]*Webhook, error) { - return getDefaultWebhooks(db.DefaultContext) -} - -func getDefaultWebhooks(ctx context.Context) ([]*Webhook, error) { +func GetDefaultWebhooks(ctx context.Context) ([]*Webhook, error) { webhooks := make([]*Webhook, 0, 5) return webhooks, db.GetEngine(ctx). Where("repo_id=? AND org_id=? AND is_system_webhook=?", 0, 0, false). @@ -506,18 +498,14 @@ func GetSystemOrDefaultWebhook(id int64) (*Webhook, error) { } // GetSystemWebhooks returns all admin system webhooks. -func GetSystemWebhooks(isActive util.OptionalBool) ([]*Webhook, error) { - return getSystemWebhooks(db.GetEngine(db.DefaultContext), isActive) -} - -func getSystemWebhooks(e db.Engine, isActive util.OptionalBool) ([]*Webhook, error) { +func GetSystemWebhooks(ctx context.Context, isActive util.OptionalBool) ([]*Webhook, error) { webhooks := make([]*Webhook, 0, 5) if isActive.IsNone() { - return webhooks, e. + return webhooks, db.GetEngine(ctx). Where("repo_id=? AND org_id=? AND is_system_webhook=?", 0, 0, true). Find(&webhooks) } - return webhooks, e. + return webhooks, db.GetEngine(ctx). Where("repo_id=? AND org_id=? AND is_system_webhook=? AND is_active = ?", 0, 0, true, isActive.IsTrue()). Find(&webhooks) } @@ -596,7 +584,7 @@ func DeleteDefaultSystemWebhook(id int64) error { // CopyDefaultWebhooksToRepo creates copies of the default webhooks in a new repo func CopyDefaultWebhooksToRepo(ctx context.Context, repoID int64) error { - ws, err := getDefaultWebhooks(ctx) + ws, err := GetDefaultWebhooks(ctx) if err != nil { return fmt.Errorf("GetDefaultWebhooks: %v", err) } diff --git a/models/webhook/webhook_test.go b/models/webhook/webhook_test.go index 5ce564b77..4bc811586 100644 --- a/models/webhook/webhook_test.go +++ b/models/webhook/webhook_test.go @@ -122,7 +122,7 @@ func TestGetWebhookByOrgID(t *testing.T) { func TestGetActiveWebhooksByRepoID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - hooks, err := ListWebhooksByOpts(&ListWebhookOptions{RepoID: 1, IsActive: util.OptionalBoolTrue}) + hooks, err := ListWebhooksByOpts(db.DefaultContext, &ListWebhookOptions{RepoID: 1, IsActive: util.OptionalBoolTrue}) assert.NoError(t, err) if assert.Len(t, hooks, 1) { assert.Equal(t, int64(1), hooks[0].ID) @@ -132,7 +132,7 @@ func TestGetActiveWebhooksByRepoID(t *testing.T) { func TestGetWebhooksByRepoID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - hooks, err := ListWebhooksByOpts(&ListWebhookOptions{RepoID: 1}) + hooks, err := ListWebhooksByOpts(db.DefaultContext, &ListWebhookOptions{RepoID: 1}) assert.NoError(t, err) if assert.Len(t, hooks, 2) { assert.Equal(t, int64(1), hooks[0].ID) @@ -142,7 +142,7 @@ func TestGetWebhooksByRepoID(t *testing.T) { func TestGetActiveWebhooksByOrgID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - hooks, err := ListWebhooksByOpts(&ListWebhookOptions{OrgID: 3, IsActive: util.OptionalBoolTrue}) + hooks, err := ListWebhooksByOpts(db.DefaultContext, &ListWebhookOptions{OrgID: 3, IsActive: util.OptionalBoolTrue}) assert.NoError(t, err) if assert.Len(t, hooks, 1) { assert.Equal(t, int64(3), hooks[0].ID) @@ -152,7 +152,7 @@ func TestGetActiveWebhooksByOrgID(t *testing.T) { func TestGetWebhooksByOrgID(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - hooks, err := ListWebhooksByOpts(&ListWebhookOptions{OrgID: 3}) + hooks, err := ListWebhooksByOpts(db.DefaultContext, &ListWebhookOptions{OrgID: 3}) assert.NoError(t, err) if assert.Len(t, hooks, 1) { assert.Equal(t, int64(3), hooks[0].ID) diff --git a/modules/context/repo.go b/modules/context/repo.go index eb773dfb2..539b111f1 100644 --- a/modules/context/repo.go +++ b/modules/context/repo.go @@ -16,6 +16,7 @@ import ( "strings" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" access_model "code.gitea.io/gitea/models/perm/access" repo_model "code.gitea.io/gitea/models/repo" unit_model "code.gitea.io/gitea/models/unit" @@ -116,7 +117,7 @@ type CanCommitToBranchResults struct { // CanCommitToBranch returns true if repository is editable and user has proper access level // and branch is not protected for push func (r *Repository) CanCommitToBranch(ctx context.Context, doer *user_model.User) (CanCommitToBranchResults, error) { - protectedBranch, err := models.GetProtectedBranchBy(r.Repository.ID, r.BranchName) + protectedBranch, err := models.GetProtectedBranchBy(ctx, r.Repository.ID, r.BranchName) if err != nil { return CanCommitToBranchResults{}, err } @@ -159,7 +160,7 @@ func (r *Repository) CanUseTimetracker(issue *models.Issue, user *user_model.Use // Checking for following: // 1. Is timetracker enabled // 2. Is the user a contributor, admin, poster or assignee and do the repository policies require this? - isAssigned, _ := models.IsUserAssignedToIssue(issue, user) + isAssigned, _ := models.IsUserAssignedToIssue(db.DefaultContext, issue, user) return r.Repository.IsTimetrackerEnabled() && (!r.Repository.AllowOnlyContributorsToTrackTime() || r.Permission.CanWriteIssuesOrPulls(issue.IsPull) || issue.IsPoster(user.ID) || isAssigned) } @@ -278,7 +279,7 @@ func RetrieveBaseRepo(ctx *Context, repo *repo_model.Repository) { // RetrieveTemplateRepo retrieves template repository used to generate this repository func RetrieveTemplateRepo(ctx *Context, repo *repo_model.Repository) { // Non-generated repository will not return error in this method. - templateRepo, err := repo_model.GetTemplateRepo(repo) + templateRepo, err := repo_model.GetTemplateRepo(ctx, repo) if err != nil { if repo_model.IsErrRepoNotExist(err) { repo.TemplateID = 0 @@ -385,11 +386,12 @@ func repoAssignment(ctx *Context, repo *repo_model.Repository) { return } if finishedMigrating { - ctx.Repo.Mirror, err = repo_model.GetMirrorByRepoID(repo.ID) + ctx.Repo.Mirror, err = repo_model.GetMirrorByRepoID(ctx, repo.ID) if err != nil { ctx.ServerError("GetMirrorByRepoID", err) return } + ctx.Repo.Mirror.Repo = ctx.Repo.Repository ctx.Data["MirrorEnablePrune"] = ctx.Repo.Mirror.EnablePrune ctx.Data["MirrorInterval"] = ctx.Repo.Mirror.Interval ctx.Data["Mirror"] = ctx.Repo.Mirror @@ -451,7 +453,7 @@ func RepoAssignment(ctx *Context) (cancel context.CancelFunc) { if ctx.IsSigned && ctx.Doer.LowerName == strings.ToLower(userName) { owner = ctx.Doer } else { - owner, err = user_model.GetUserByName(userName) + owner, err = user_model.GetUserByName(ctx, userName) if err != nil { if user_model.IsErrUserNotExist(err) { if ctx.FormString("go-get") == "1" { @@ -587,7 +589,7 @@ func RepoAssignment(ctx *Context) (cancel context.CancelFunc) { if ctx.IsSigned { ctx.Data["IsWatchingRepo"] = repo_model.IsWatching(ctx.Doer.ID, repo.ID) - ctx.Data["IsStaringRepo"] = repo_model.IsStaring(ctx.Doer.ID, repo.ID) + ctx.Data["IsStaringRepo"] = repo_model.IsStaring(ctx, ctx.Doer.ID, repo.ID) } if repo.IsFork { diff --git a/modules/convert/convert.go b/modules/convert/convert.go index 53357e750..67b3902cd 100644 --- a/modules/convert/convert.go +++ b/modules/convert/convert.go @@ -340,7 +340,7 @@ func ToTeams(teams []*organization.Team, loadOrgs bool) ([]*api.Team, error) { if loadOrgs { apiOrg, ok := cache[teams[i].OrgID] if !ok { - org, err := organization.GetOrgByID(teams[i].OrgID) + org, err := organization.GetOrgByID(db.DefaultContext, teams[i].OrgID) if err != nil { return nil, err } diff --git a/modules/convert/issue.go b/modules/convert/issue.go index bf116e228..a4512e424 100644 --- a/modules/convert/issue.go +++ b/modules/convert/issue.go @@ -72,7 +72,7 @@ func ToAPIIssue(issue *models.Issue) *api.Issue { apiIssue.Milestone = ToAPIMilestone(issue.Milestone) } - if err := issue.LoadAssignees(); err != nil { + if err := issue.LoadAssignees(db.DefaultContext); err != nil { return &api.Issue{} } if len(issue.Assignees) > 0 { diff --git a/modules/convert/issue_comment.go b/modules/convert/issue_comment.go index 6d72849bc..eaa7f64ea 100644 --- a/modules/convert/issue_comment.go +++ b/modules/convert/issue_comment.go @@ -6,6 +6,7 @@ package convert import ( "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" repo_model "code.gitea.io/gitea/models/repo" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/log" @@ -113,7 +114,7 @@ func ToTimelineComment(c *models.Comment, doer *user_model.User) *api.TimelineCo } if c.RefCommentID != 0 { - com, err := models.GetCommentByID(c.RefCommentID) + com, err := models.GetCommentByID(db.DefaultContext, c.RefCommentID) if err != nil { log.Error("GetCommentByID(%d): %v", c.RefCommentID, err) return nil diff --git a/modules/convert/repository.go b/modules/convert/repository.go index b813d6969..eb6bb3770 100644 --- a/modules/convert/repository.go +++ b/modules/convert/repository.go @@ -104,7 +104,7 @@ func innerToRepo(repo *repo_model.Repository, mode perm.AccessMode, isParent boo var mirrorUpdated time.Time if repo.IsMirror { var err error - repo.Mirror, err = repo_model.GetMirrorByRepoID(repo.ID) + repo.Mirror, err = repo_model.GetMirrorByRepoID(db.DefaultContext, repo.ID) if err == nil { mirrorInterval = repo.Mirror.Interval.String() mirrorUpdated = repo.Mirror.UpdatedUnix.AsTime() diff --git a/modules/doctor/authorizedkeys.go b/modules/doctor/authorizedkeys.go index 18e7a3cbf..34dfe939d 100644 --- a/modules/doctor/authorizedkeys.go +++ b/modules/doctor/authorizedkeys.go @@ -54,7 +54,7 @@ func checkAuthorizedKeys(ctx context.Context, logger log.Logger, autofix bool) e // now we regenerate and check if there are any lines missing regenerated := &bytes.Buffer{} - if err := asymkey_model.RegeneratePublicKeys(regenerated); err != nil { + if err := asymkey_model.RegeneratePublicKeys(ctx, regenerated); err != nil { logger.Critical("Unable to regenerate authorized_keys file. ERROR: %v", err) return fmt.Errorf("Unable to regenerate authorized_keys file. ERROR: %v", err) } diff --git a/modules/gitgraph/graph_models.go b/modules/gitgraph/graph_models.go index 653384252..551e56f63 100644 --- a/modules/gitgraph/graph_models.go +++ b/modules/gitgraph/graph_models.go @@ -120,7 +120,7 @@ func (graph *Graph) LoadAndProcessCommits(repository *repo_model.Repository, git return models.IsOwnerMemberCollaborator(repository, user.ID) }, &keyMap) - statuses, _, err := models.GetLatestCommitStatus(repository.ID, c.Commit.ID.String(), db.ListOptions{}) + statuses, _, err := models.GetLatestCommitStatus(db.DefaultContext, repository.ID, c.Commit.ID.String(), db.ListOptions{}) if err != nil { log.Error("GetLatestCommitStatus: %v", err) } else { diff --git a/modules/indexer/code/git.go b/modules/indexer/code/git.go index 60018af20..66d76377a 100644 --- a/modules/indexer/code/git.go +++ b/modules/indexer/code/git.go @@ -38,7 +38,7 @@ func getDefaultBranchSha(ctx context.Context, repo *repo_model.Repository) (stri // getRepoChanges returns changes to repo since last indexer update func getRepoChanges(ctx context.Context, repo *repo_model.Repository, revision string) (*repoChanges, error) { - status, err := repo_model.GetIndexerStatus(repo, repo_model.RepoIndexerTypeCode) + status, err := repo_model.GetIndexerStatus(ctx, repo, repo_model.RepoIndexerTypeCode) if err != nil { return nil, err } diff --git a/modules/indexer/code/indexer.go b/modules/indexer/code/indexer.go index f15b8d865..9845ade3d 100644 --- a/modules/indexer/code/indexer.go +++ b/modules/indexer/code/indexer.go @@ -108,7 +108,7 @@ func index(ctx context.Context, indexer Indexer, repoID int64) error { return err } - return repo_model.UpdateIndexerStatus(repo, repo_model.RepoIndexerTypeCode, sha) + return repo_model.UpdateIndexerStatus(ctx, repo, repo_model.RepoIndexerTypeCode, sha) } // Init initialize the repo indexer diff --git a/modules/indexer/issues/indexer.go b/modules/indexer/issues/indexer.go index d4df4f8a4..7adc938dc 100644 --- a/modules/indexer/issues/indexer.go +++ b/modules/indexer/issues/indexer.go @@ -362,7 +362,7 @@ func UpdateIssueIndexer(issue *models.Issue) { // DeleteRepoIssueIndexer deletes repo's all issues indexes func DeleteRepoIssueIndexer(repo *repo_model.Repository) { var ids []int64 - ids, err := models.GetIssueIDsByRepoID(repo.ID) + ids, err := models.GetIssueIDsByRepoID(db.DefaultContext, repo.ID) if err != nil { log.Error("getIssueIDsByRepoID failed: %v", err) return diff --git a/modules/indexer/stats/db.go b/modules/indexer/stats/db.go index bb3385ab6..d39b1dcf2 100644 --- a/modules/indexer/stats/db.go +++ b/modules/indexer/stats/db.go @@ -30,7 +30,7 @@ func (db *DBIndexer) Index(id int64) error { return nil } - status, err := repo_model.GetIndexerStatus(repo, repo_model.RepoIndexerTypeStats) + status, err := repo_model.GetIndexerStatus(ctx, repo, repo_model.RepoIndexerTypeStats) if err != nil { return err } diff --git a/modules/indexer/stats/indexer_test.go b/modules/indexer/stats/indexer_test.go index c8bd8d178..9d9de5413 100644 --- a/modules/indexer/stats/indexer_test.go +++ b/modules/indexer/stats/indexer_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "code.gitea.io/gitea/models/db" repo_model "code.gitea.io/gitea/models/repo" "code.gitea.io/gitea/models/unittest" "code.gitea.io/gitea/modules/git" @@ -49,7 +50,7 @@ func TestRepoStatsIndex(t *testing.T) { queue.GetManager().FlushAll(context.Background(), 5*time.Second) - status, err := repo_model.GetIndexerStatus(repo, repo_model.RepoIndexerTypeStats) + status, err := repo_model.GetIndexerStatus(db.DefaultContext, repo, repo_model.RepoIndexerTypeStats) assert.NoError(t, err) assert.Equal(t, "65f1bf27bc3bf70f64657658635e66094edbcb4d", status.CommitSha) langs, err := repo_model.GetTopLanguageStats(repo, 5) diff --git a/modules/ssh/ssh.go b/modules/ssh/ssh.go index 44ed431c9..fe3561cef 100644 --- a/modules/ssh/ssh.go +++ b/modules/ssh/ssh.go @@ -174,7 +174,7 @@ func publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool { // look for the exact principal principalLoop: for _, principal := range cert.ValidPrincipals { - pkey, err := asymkey_model.SearchPublicKeyByContentExact(principal) + pkey, err := asymkey_model.SearchPublicKeyByContentExact(ctx, principal) if err != nil { if asymkey_model.IsErrKeyNotExist(err) { log.Debug("Principal Rejected: %s Unknown Principal: %s", ctx.RemoteAddr(), principal) @@ -234,7 +234,7 @@ func publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool { log.Debug("Handle Public Key: %s Fingerprint: %s is not a certificate", ctx.RemoteAddr(), gossh.FingerprintSHA256(key)) } - pkey, err := asymkey_model.SearchPublicKeyByContent(strings.TrimSpace(string(gossh.MarshalAuthorizedKey(key)))) + pkey, err := asymkey_model.SearchPublicKeyByContent(ctx, strings.TrimSpace(string(gossh.MarshalAuthorizedKey(key)))) if err != nil { if asymkey_model.IsErrKeyNotExist(err) { if log.IsWarn() { diff --git a/routers/api/v1/admin/adopt.go b/routers/api/v1/admin/adopt.go index 3c39d7c2b..8f11ab67f 100644 --- a/routers/api/v1/admin/adopt.go +++ b/routers/api/v1/admin/adopt.go @@ -85,7 +85,7 @@ func AdoptRepository(ctx *context.APIContext) { ownerName := ctx.Params(":username") repoName := ctx.Params(":reponame") - ctxUser, err := user_model.GetUserByName(ownerName) + ctxUser, err := user_model.GetUserByName(ctx, ownerName) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.NotFound() @@ -96,7 +96,7 @@ func AdoptRepository(ctx *context.APIContext) { } // check not a repo - has, err := repo_model.IsRepositoryExist(ctxUser, repoName) + has, err := repo_model.IsRepositoryExist(ctx, ctxUser, repoName) if err != nil { ctx.InternalServerError(err) return @@ -147,7 +147,7 @@ func DeleteUnadoptedRepository(ctx *context.APIContext) { ownerName := ctx.Params(":username") repoName := ctx.Params(":reponame") - ctxUser, err := user_model.GetUserByName(ownerName) + ctxUser, err := user_model.GetUserByName(ctx, ownerName) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.NotFound() @@ -158,7 +158,7 @@ func DeleteUnadoptedRepository(ctx *context.APIContext) { } // check not a repo - has, err := repo_model.IsRepositoryExist(ctxUser, repoName) + has, err := repo_model.IsRepositoryExist(ctx, ctxUser, repoName) if err != nil { ctx.InternalServerError(err) return diff --git a/routers/api/v1/admin/user.go b/routers/api/v1/admin/user.go index 6263a6704..71932136b 100644 --- a/routers/api/v1/admin/user.go +++ b/routers/api/v1/admin/user.go @@ -269,7 +269,7 @@ func EditUser(ctx *context.APIContext) { ctx.ContextUser.IsRestricted = *form.Restricted } - if err := user_model.UpdateUser(ctx.ContextUser, emailChanged); err != nil { + if err := user_model.UpdateUser(ctx, ctx.ContextUser, emailChanged); err != nil { if user_model.IsErrEmailAlreadyUsed(err) || user_model.IsErrEmailCharIsNotSupported(err) || user_model.IsErrEmailInvalid(err) { diff --git a/routers/api/v1/api.go b/routers/api/v1/api.go index 9db1d80f7..62c4a8934 100644 --- a/routers/api/v1/api.go +++ b/routers/api/v1/api.go @@ -108,7 +108,7 @@ func sudo() func(ctx *context.APIContext) { if len(sudo) > 0 { if ctx.IsSigned && ctx.Doer.IsAdmin { - user, err := user_model.GetUserByName(sudo) + user, err := user_model.GetUserByName(ctx, sudo) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.NotFound() @@ -143,7 +143,7 @@ func repoAssignment() func(ctx *context.APIContext) { if ctx.IsSigned && ctx.Doer.LowerName == strings.ToLower(userName) { owner = ctx.Doer } else { - owner, err = user_model.GetUserByName(userName) + owner, err = user_model.GetUserByName(ctx, userName) if err != nil { if user_model.IsErrUserNotExist(err) { if redirectUserID, err := user_model.LookupUserRedirect(userName); err == nil { @@ -467,7 +467,7 @@ func orgAssignment(args ...bool) func(ctx *context.APIContext) { } if assignTeam { - ctx.Org.Team, err = organization.GetTeamByID(ctx.ParamsInt64(":teamid")) + ctx.Org.Team, err = organization.GetTeamByID(ctx, ctx.ParamsInt64(":teamid")) if err != nil { if organization.IsErrTeamNotExist(err) { ctx.NotFound() diff --git a/routers/api/v1/notify/notifications.go b/routers/api/v1/notify/notifications.go index c707cf452..0a3684fbe 100644 --- a/routers/api/v1/notify/notifications.go +++ b/routers/api/v1/notify/notifications.go @@ -22,7 +22,7 @@ func NewAvailable(ctx *context.APIContext) { // responses: // "200": // "$ref": "#/responses/NotificationCount" - ctx.JSON(http.StatusOK, api.NotificationCount{New: models.CountUnread(ctx.Doer)}) + ctx.JSON(http.StatusOK, api.NotificationCount{New: models.CountUnread(ctx, ctx.Doer.ID)}) } func getFindNotificationOptions(ctx *context.APIContext) *models.FindNotificationOptions { diff --git a/routers/api/v1/notify/repo.go b/routers/api/v1/notify/repo.go index 0f6b90b05..4e9dd806d 100644 --- a/routers/api/v1/notify/repo.go +++ b/routers/api/v1/notify/repo.go @@ -115,7 +115,7 @@ func ListRepoNotifications(ctx *context.APIContext) { return } - nl, err := models.GetNotifications(opts) + nl, err := models.GetNotifications(ctx, opts) if err != nil { ctx.InternalServerError(err) return @@ -203,7 +203,7 @@ func ReadRepoNotifications(ctx *context.APIContext) { opts.Status = statusStringsToNotificationStatuses(statuses, []string{"unread"}) log.Error("%v", opts.Status) } - nl, err := models.GetNotifications(opts) + nl, err := models.GetNotifications(ctx, opts) if err != nil { ctx.InternalServerError(err) return diff --git a/routers/api/v1/notify/user.go b/routers/api/v1/notify/user.go index ac3d0591d..b92330778 100644 --- a/routers/api/v1/notify/user.go +++ b/routers/api/v1/notify/user.go @@ -75,7 +75,7 @@ func ListNotifications(ctx *context.APIContext) { return } - nl, err := models.GetNotifications(opts) + nl, err := models.GetNotifications(ctx, opts) if err != nil { ctx.InternalServerError(err) return @@ -148,7 +148,7 @@ func ReadNotifications(ctx *context.APIContext) { statuses := ctx.FormStrings("status-types") opts.Status = statusStringsToNotificationStatuses(statuses, []string{"unread"}) } - nl, err := models.GetNotifications(opts) + nl, err := models.GetNotifications(ctx, opts) if err != nil { ctx.InternalServerError(err) return diff --git a/routers/api/v1/org/hook.go b/routers/api/v1/org/hook.go index 67957430d..ddf0ddefe 100644 --- a/routers/api/v1/org/hook.go +++ b/routers/api/v1/org/hook.go @@ -51,7 +51,7 @@ func ListHooks(ctx *context.APIContext) { return } - orgHooks, err := webhook.ListWebhooksByOpts(opts) + orgHooks, err := webhook.ListWebhooksByOpts(ctx, opts) if err != nil { ctx.InternalServerError(err) return diff --git a/routers/api/v1/org/label.go b/routers/api/v1/org/label.go index d36b1d9a9..9844ea21d 100644 --- a/routers/api/v1/org/label.go +++ b/routers/api/v1/org/label.go @@ -43,7 +43,7 @@ func ListLabels(ctx *context.APIContext) { // "200": // "$ref": "#/responses/LabelList" - labels, err := models.GetLabelsByOrgID(ctx.Org.Organization.ID, ctx.FormString("sort"), utils.GetListOptions(ctx)) + labels, err := models.GetLabelsByOrgID(ctx, ctx.Org.Organization.ID, ctx.FormString("sort"), utils.GetListOptions(ctx)) if err != nil { ctx.Error(http.StatusInternalServerError, "GetLabelsByOrgID", err) return @@ -136,9 +136,9 @@ func GetLabel(ctx *context.APIContext) { ) strID := ctx.Params(":id") if intID, err2 := strconv.ParseInt(strID, 10, 64); err2 != nil { - label, err = models.GetLabelInOrgByName(ctx.Org.Organization.ID, strID) + label, err = models.GetLabelInOrgByName(ctx, ctx.Org.Organization.ID, strID) } else { - label, err = models.GetLabelInOrgByID(ctx.Org.Organization.ID, intID) + label, err = models.GetLabelInOrgByID(ctx, ctx.Org.Organization.ID, intID) } if err != nil { if models.IsErrOrgLabelNotExist(err) { @@ -183,7 +183,7 @@ func EditLabel(ctx *context.APIContext) { // "422": // "$ref": "#/responses/validationError" form := web.GetForm(ctx).(*api.EditLabelOption) - label, err := models.GetLabelInOrgByID(ctx.Org.Organization.ID, ctx.ParamsInt64(":id")) + label, err := models.GetLabelInOrgByID(ctx, ctx.Org.Organization.ID, ctx.ParamsInt64(":id")) if err != nil { if models.IsErrOrgLabelNotExist(err) { ctx.NotFound() diff --git a/routers/api/v1/repo/branch.go b/routers/api/v1/repo/branch.go index c030a896a..09e6ccf23 100644 --- a/routers/api/v1/repo/branch.go +++ b/routers/api/v1/repo/branch.go @@ -70,7 +70,7 @@ func GetBranch(ctx *context.APIContext) { return } - branchProtection, err := models.GetProtectedBranchBy(ctx.Repo.Repository.ID, branchName) + branchProtection, err := models.GetProtectedBranchBy(ctx, ctx.Repo.Repository.ID, branchName) if err != nil { ctx.Error(http.StatusInternalServerError, "GetBranchProtection", err) return @@ -206,7 +206,7 @@ func CreateBranch(ctx *context.APIContext) { return } - branchProtection, err := models.GetProtectedBranchBy(ctx.Repo.Repository.ID, branch.Name) + branchProtection, err := models.GetProtectedBranchBy(ctx, ctx.Repo.Repository.ID, branch.Name) if err != nil { ctx.Error(http.StatusInternalServerError, "GetBranchProtection", err) return @@ -271,7 +271,7 @@ func ListBranches(ctx *context.APIContext) { ctx.Error(http.StatusInternalServerError, "GetCommit", err) return } - branchProtection, err := models.GetProtectedBranchBy(ctx.Repo.Repository.ID, branches[i].Name) + branchProtection, err := models.GetProtectedBranchBy(ctx, ctx.Repo.Repository.ID, branches[i].Name) if err != nil { ctx.Error(http.StatusInternalServerError, "GetBranchProtection", err) return @@ -320,7 +320,7 @@ func GetBranchProtection(ctx *context.APIContext) { repo := ctx.Repo.Repository bpName := ctx.Params(":name") - bp, err := models.GetProtectedBranchBy(repo.ID, bpName) + bp, err := models.GetProtectedBranchBy(ctx, repo.ID, bpName) if err != nil { ctx.Error(http.StatusInternalServerError, "GetProtectedBranchByID", err) return @@ -412,7 +412,7 @@ func CreateBranchProtection(ctx *context.APIContext) { return } - protectBranch, err := models.GetProtectedBranchBy(repo.ID, form.BranchName) + protectBranch, err := models.GetProtectedBranchBy(ctx, repo.ID, form.BranchName) if err != nil { ctx.Error(http.StatusInternalServerError, "GetProtectBranchOfRepoByName", err) return @@ -523,7 +523,7 @@ func CreateBranchProtection(ctx *context.APIContext) { } // Reload from db to get all whitelists - bp, err := models.GetProtectedBranchBy(ctx.Repo.Repository.ID, form.BranchName) + bp, err := models.GetProtectedBranchBy(ctx, ctx.Repo.Repository.ID, form.BranchName) if err != nil { ctx.Error(http.StatusInternalServerError, "GetProtectedBranchByID", err) return @@ -575,7 +575,7 @@ func EditBranchProtection(ctx *context.APIContext) { form := web.GetForm(ctx).(*api.EditBranchProtectionOption) repo := ctx.Repo.Repository bpName := ctx.Params(":name") - protectBranch, err := models.GetProtectedBranchBy(repo.ID, bpName) + protectBranch, err := models.GetProtectedBranchBy(ctx, repo.ID, bpName) if err != nil { ctx.Error(http.StatusInternalServerError, "GetProtectedBranchByID", err) return @@ -758,7 +758,7 @@ func EditBranchProtection(ctx *context.APIContext) { } // Reload from db to ensure get all whitelists - bp, err := models.GetProtectedBranchBy(repo.ID, bpName) + bp, err := models.GetProtectedBranchBy(ctx, repo.ID, bpName) if err != nil { ctx.Error(http.StatusInternalServerError, "GetProtectedBranchBy", err) return @@ -802,7 +802,7 @@ func DeleteBranchProtection(ctx *context.APIContext) { repo := ctx.Repo.Repository bpName := ctx.Params(":name") - bp, err := models.GetProtectedBranchBy(repo.ID, bpName) + bp, err := models.GetProtectedBranchBy(ctx, repo.ID, bpName) if err != nil { ctx.Error(http.StatusInternalServerError, "GetProtectedBranchByID", err) return diff --git a/routers/api/v1/repo/collaborators.go b/routers/api/v1/repo/collaborators.go index fed2d9f05..248497a56 100644 --- a/routers/api/v1/repo/collaborators.go +++ b/routers/api/v1/repo/collaborators.go @@ -103,7 +103,7 @@ func IsCollaborator(ctx *context.APIContext) { // "422": // "$ref": "#/responses/validationError" - user, err := user_model.GetUserByName(ctx.Params(":collaborator")) + user, err := user_model.GetUserByName(ctx, ctx.Params(":collaborator")) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.Error(http.StatusUnprocessableEntity, "", err) @@ -159,7 +159,7 @@ func AddCollaborator(ctx *context.APIContext) { form := web.GetForm(ctx).(*api.AddCollaboratorOption) - collaborator, err := user_model.GetUserByName(ctx.Params(":collaborator")) + collaborator, err := user_model.GetUserByName(ctx, ctx.Params(":collaborator")) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.Error(http.StatusUnprocessableEntity, "", err) @@ -218,7 +218,7 @@ func DeleteCollaborator(ctx *context.APIContext) { // "422": // "$ref": "#/responses/validationError" - collaborator, err := user_model.GetUserByName(ctx.Params(":collaborator")) + collaborator, err := user_model.GetUserByName(ctx, ctx.Params(":collaborator")) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.Error(http.StatusUnprocessableEntity, "", err) @@ -271,7 +271,7 @@ func GetRepoPermissions(ctx *context.APIContext) { return } - collaborator, err := user_model.GetUserByName(ctx.Params(":collaborator")) + collaborator, err := user_model.GetUserByName(ctx, ctx.Params(":collaborator")) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.Error(http.StatusNotFound, "GetUserByName", err) diff --git a/routers/api/v1/repo/hook.go b/routers/api/v1/repo/hook.go index 7ec6cd88a..8a546e581 100644 --- a/routers/api/v1/repo/hook.go +++ b/routers/api/v1/repo/hook.go @@ -60,7 +60,7 @@ func ListHooks(ctx *context.APIContext) { return } - hooks, err := webhook.ListWebhooksByOpts(opts) + hooks, err := webhook.ListWebhooksByOpts(ctx, opts) if err != nil { ctx.InternalServerError(err) return diff --git a/routers/api/v1/repo/issue.go b/routers/api/v1/repo/issue.go index b45069ad4..62959c3a7 100644 --- a/routers/api/v1/repo/issue.go +++ b/routers/api/v1/repo/issue.go @@ -145,7 +145,7 @@ func SearchIssues(ctx *context.APIContext) { opts.AllLimited = true } if ctx.FormString("owner") != "" { - owner, err := user_model.GetUserByName(ctx.FormString("owner")) + owner, err := user_model.GetUserByName(ctx, ctx.FormString("owner")) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.Error(http.StatusBadRequest, "Owner not found", err) @@ -164,7 +164,7 @@ func SearchIssues(ctx *context.APIContext) { ctx.Error(http.StatusBadRequest, "", "Owner organisation is required for filtering on team") return } - team, err := organization.GetTeam(opts.OwnerID, ctx.FormString("team")) + team, err := organization.GetTeam(ctx, opts.OwnerID, ctx.FormString("team")) if err != nil { if organization.IsErrTeamNotExist(err) { ctx.Error(http.StatusBadRequest, "Team not found", err) @@ -502,7 +502,7 @@ func getUserIDForFilter(ctx *context.APIContext, queryName string) int64 { return 0 } - user, err := user_model.GetUserByName(userName) + user, err := user_model.GetUserByName(ctx, userName) if user_model.IsErrUserNotExist(err) { ctx.NotFound(err) return 0 diff --git a/routers/api/v1/repo/issue_comment.go b/routers/api/v1/repo/issue_comment.go index 6065adc27..22533c381 100644 --- a/routers/api/v1/repo/issue_comment.go +++ b/routers/api/v1/repo/issue_comment.go @@ -79,7 +79,7 @@ func ListIssueComments(ctx *context.APIContext) { Type: models.CommentTypeComment, } - comments, err := models.FindComments(opts) + comments, err := models.FindComments(ctx, opts) if err != nil { ctx.Error(http.StatusInternalServerError, "FindComments", err) return @@ -172,7 +172,7 @@ func ListIssueCommentsAndTimeline(ctx *context.APIContext) { Type: models.CommentTypeUnknown, } - comments, err := models.FindComments(opts) + comments, err := models.FindComments(ctx, opts) if err != nil { ctx.Error(http.StatusInternalServerError, "FindComments", err) return @@ -269,7 +269,7 @@ func ListRepoIssueComments(ctx *context.APIContext) { Before: before, } - comments, err := models.FindComments(opts) + comments, err := models.FindComments(ctx, opts) if err != nil { ctx.Error(http.StatusInternalServerError, "FindComments", err) return @@ -399,7 +399,7 @@ func GetIssueComment(ctx *context.APIContext) { // "404": // "$ref": "#/responses/notFound" - comment, err := models.GetCommentByID(ctx.ParamsInt64(":id")) + comment, err := models.GetCommentByID(ctx, ctx.ParamsInt64(":id")) if err != nil { if models.IsErrCommentNotExist(err) { ctx.NotFound(err) @@ -526,7 +526,7 @@ func EditIssueCommentDeprecated(ctx *context.APIContext) { } func editIssueComment(ctx *context.APIContext, form api.EditIssueCommentOption) { - comment, err := models.GetCommentByID(ctx.ParamsInt64(":id")) + comment, err := models.GetCommentByID(ctx, ctx.ParamsInt64(":id")) if err != nil { if models.IsErrCommentNotExist(err) { ctx.NotFound(err) @@ -629,7 +629,7 @@ func DeleteIssueCommentDeprecated(ctx *context.APIContext) { } func deleteIssueComment(ctx *context.APIContext) { - comment, err := models.GetCommentByID(ctx.ParamsInt64(":id")) + comment, err := models.GetCommentByID(ctx, ctx.ParamsInt64(":id")) if err != nil { if models.IsErrCommentNotExist(err) { ctx.NotFound(err) diff --git a/routers/api/v1/repo/issue_label.go b/routers/api/v1/repo/issue_label.go index e314e756d..0193eb423 100644 --- a/routers/api/v1/repo/issue_label.go +++ b/routers/api/v1/repo/issue_label.go @@ -111,7 +111,7 @@ func AddIssueLabels(ctx *context.APIContext) { return } - labels, err = models.GetLabelsByIssueID(issue.ID) + labels, err = models.GetLabelsByIssueID(ctx, issue.ID) if err != nil { ctx.Error(http.StatusInternalServerError, "GetLabelsByIssueID", err) return @@ -173,7 +173,7 @@ func DeleteIssueLabel(ctx *context.APIContext) { return } - label, err := models.GetLabelByID(ctx.ParamsInt64(":id")) + label, err := models.GetLabelByID(ctx, ctx.ParamsInt64(":id")) if err != nil { if models.IsErrLabelNotExist(err) { ctx.Error(http.StatusUnprocessableEntity, "", err) @@ -237,7 +237,7 @@ func ReplaceIssueLabels(ctx *context.APIContext) { return } - labels, err = models.GetLabelsByIssueID(issue.ID) + labels, err = models.GetLabelsByIssueID(ctx, issue.ID) if err != nil { ctx.Error(http.StatusInternalServerError, "GetLabelsByIssueID", err) return diff --git a/routers/api/v1/repo/issue_reaction.go b/routers/api/v1/repo/issue_reaction.go index 5aa736679..45be7a92d 100644 --- a/routers/api/v1/repo/issue_reaction.go +++ b/routers/api/v1/repo/issue_reaction.go @@ -49,7 +49,7 @@ func GetIssueCommentReactions(ctx *context.APIContext) { // "403": // "$ref": "#/responses/forbidden" - comment, err := models.GetCommentByID(ctx.ParamsInt64(":id")) + comment, err := models.GetCommentByID(ctx, ctx.ParamsInt64(":id")) if err != nil { if models.IsErrCommentNotExist(err) { ctx.NotFound(err) @@ -176,7 +176,7 @@ func DeleteIssueCommentReaction(ctx *context.APIContext) { } func changeIssueCommentReaction(ctx *context.APIContext, form api.EditReactionOption, isCreateType bool) { - comment, err := models.GetCommentByID(ctx.ParamsInt64(":id")) + comment, err := models.GetCommentByID(ctx, ctx.ParamsInt64(":id")) if err != nil { if models.IsErrCommentNotExist(err) { ctx.NotFound(err) diff --git a/routers/api/v1/repo/issue_subscription.go b/routers/api/v1/repo/issue_subscription.go index f00c85b12..a608ba227 100644 --- a/routers/api/v1/repo/issue_subscription.go +++ b/routers/api/v1/repo/issue_subscription.go @@ -116,7 +116,7 @@ func setIssueSubscription(ctx *context.APIContext, watch bool) { return } - user, err := user_model.GetUserByName(ctx.Params(":user")) + user, err := user_model.GetUserByName(ctx, ctx.Params(":user")) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.NotFound() @@ -263,7 +263,7 @@ func GetIssueSubscribers(ctx *context.APIContext) { return } - iwl, err := models.GetIssueWatchers(issue.ID, utils.GetListOptions(ctx)) + iwl, err := models.GetIssueWatchers(ctx, issue.ID, utils.GetListOptions(ctx)) if err != nil { ctx.Error(http.StatusInternalServerError, "GetIssueWatchers", err) return @@ -284,7 +284,7 @@ func GetIssueSubscribers(ctx *context.APIContext) { apiUsers = append(apiUsers, convert.ToUser(v, ctx.Doer)) } - count, err := models.CountIssueWatchers(issue.ID) + count, err := models.CountIssueWatchers(ctx, issue.ID) if err != nil { ctx.Error(http.StatusInternalServerError, "CountIssueWatchers", err) return diff --git a/routers/api/v1/repo/issue_tracked_time.go b/routers/api/v1/repo/issue_tracked_time.go index 8ccad8783..19e1a8259 100644 --- a/routers/api/v1/repo/issue_tracked_time.go +++ b/routers/api/v1/repo/issue_tracked_time.go @@ -94,7 +94,7 @@ func ListTrackedTimes(ctx *context.APIContext) { qUser := ctx.FormTrim("user") if qUser != "" { - user, err := user_model.GetUserByName(qUser) + user, err := user_model.GetUserByName(ctx, qUser) if user_model.IsErrUserNotExist(err) { ctx.Error(http.StatusNotFound, "User does not exist", err) } else if err != nil { @@ -128,7 +128,7 @@ func ListTrackedTimes(ctx *context.APIContext) { return } - trackedTimes, err := models.GetTrackedTimes(opts) + trackedTimes, err := models.GetTrackedTimes(ctx, opts) if err != nil { ctx.Error(http.StatusInternalServerError, "GetTrackedTimes", err) return @@ -203,7 +203,7 @@ func AddTime(ctx *context.APIContext) { if form.User != "" { if (ctx.IsUserRepoAdmin() && ctx.Doer.Name != form.User) || ctx.Doer.IsAdmin { // allow only RepoAdmin, Admin and User to add time - user, err = user_model.GetUserByName(form.User) + user, err = user_model.GetUserByName(ctx, form.User) if err != nil { ctx.Error(http.StatusInternalServerError, "GetUserByName", err) } @@ -415,7 +415,7 @@ func ListTrackedTimesByUser(ctx *context.APIContext) { ctx.Error(http.StatusBadRequest, "", "time tracking disabled") return } - user, err := user_model.GetUserByName(ctx.Params(":timetrackingusername")) + user, err := user_model.GetUserByName(ctx, ctx.Params(":timetrackingusername")) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.NotFound(err) @@ -439,7 +439,7 @@ func ListTrackedTimesByUser(ctx *context.APIContext) { RepositoryID: ctx.Repo.Repository.ID, } - trackedTimes, err := models.GetTrackedTimes(opts) + trackedTimes, err := models.GetTrackedTimes(ctx, opts) if err != nil { ctx.Error(http.StatusInternalServerError, "GetTrackedTimes", err) return @@ -512,7 +512,7 @@ func ListTrackedTimesByRepository(ctx *context.APIContext) { // Filters qUser := ctx.FormTrim("user") if qUser != "" { - user, err := user_model.GetUserByName(qUser) + user, err := user_model.GetUserByName(ctx, qUser) if user_model.IsErrUserNotExist(err) { ctx.Error(http.StatusNotFound, "User does not exist", err) } else if err != nil { @@ -547,7 +547,7 @@ func ListTrackedTimesByRepository(ctx *context.APIContext) { return } - trackedTimes, err := models.GetTrackedTimes(opts) + trackedTimes, err := models.GetTrackedTimes(ctx, opts) if err != nil { ctx.Error(http.StatusInternalServerError, "GetTrackedTimes", err) return @@ -609,7 +609,7 @@ func ListMyTrackedTimes(ctx *context.APIContext) { return } - trackedTimes, err := models.GetTrackedTimes(opts) + trackedTimes, err := models.GetTrackedTimes(ctx, opts) if err != nil { ctx.Error(http.StatusInternalServerError, "GetTrackedTimesByUser", err) return diff --git a/routers/api/v1/repo/label.go b/routers/api/v1/repo/label.go index ab559a2ee..4332b8e62 100644 --- a/routers/api/v1/repo/label.go +++ b/routers/api/v1/repo/label.go @@ -49,7 +49,7 @@ func ListLabels(ctx *context.APIContext) { // "200": // "$ref": "#/responses/LabelList" - labels, err := models.GetLabelsByRepoID(ctx.Repo.Repository.ID, ctx.FormString("sort"), utils.GetListOptions(ctx)) + labels, err := models.GetLabelsByRepoID(ctx, ctx.Repo.Repository.ID, ctx.FormString("sort"), utils.GetListOptions(ctx)) if err != nil { ctx.Error(http.StatusInternalServerError, "GetLabelsByRepoID", err) return @@ -99,9 +99,9 @@ func GetLabel(ctx *context.APIContext) { ) strID := ctx.Params(":id") if intID, err2 := strconv.ParseInt(strID, 10, 64); err2 != nil { - label, err = models.GetLabelInRepoByName(ctx.Repo.Repository.ID, strID) + label, err = models.GetLabelInRepoByName(ctx, ctx.Repo.Repository.ID, strID) } else { - label, err = models.GetLabelInRepoByID(ctx.Repo.Repository.ID, intID) + label, err = models.GetLabelInRepoByID(ctx, ctx.Repo.Repository.ID, intID) } if err != nil { if models.IsErrRepoLabelNotExist(err) { @@ -206,7 +206,7 @@ func EditLabel(ctx *context.APIContext) { // "$ref": "#/responses/validationError" form := web.GetForm(ctx).(*api.EditLabelOption) - label, err := models.GetLabelInRepoByID(ctx.Repo.Repository.ID, ctx.ParamsInt64(":id")) + label, err := models.GetLabelInRepoByID(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":id")) if err != nil { if models.IsErrRepoLabelNotExist(err) { ctx.NotFound() diff --git a/routers/api/v1/repo/language.go b/routers/api/v1/repo/language.go index f47b0a0e7..ca803cb68 100644 --- a/routers/api/v1/repo/language.go +++ b/routers/api/v1/repo/language.go @@ -68,7 +68,7 @@ func GetLanguages(ctx *context.APIContext) { // "200": // "$ref": "#/responses/LanguageStatistics" - langs, err := repo_model.GetLanguageStats(ctx.Repo.Repository) + langs, err := repo_model.GetLanguageStats(ctx, ctx.Repo.Repository) if err != nil { log.Error("GetLanguageStats failed: %v", err) ctx.InternalServerError(err) diff --git a/routers/api/v1/repo/migrate.go b/routers/api/v1/repo/migrate.go index f5851bfca..f868c5395 100644 --- a/routers/api/v1/repo/migrate.go +++ b/routers/api/v1/repo/migrate.go @@ -65,7 +65,7 @@ func Migrate(ctx *context.APIContext) { err error ) if len(form.RepoOwner) != 0 { - repoOwner, err = user_model.GetUserByName(form.RepoOwner) + repoOwner, err = user_model.GetUserByName(ctx, form.RepoOwner) } else if form.RepoOwnerID != 0 { repoOwner, err = user_model.GetUserByID(form.RepoOwnerID) } else { diff --git a/routers/api/v1/repo/mirror.go b/routers/api/v1/repo/mirror.go index d7facd24d..1af63c55b 100644 --- a/routers/api/v1/repo/mirror.go +++ b/routers/api/v1/repo/mirror.go @@ -50,7 +50,7 @@ func MirrorSync(ctx *context.APIContext) { return } - if _, err := repo_model.GetMirrorByRepoID(repo.ID); err != nil { + if _, err := repo_model.GetMirrorByRepoID(ctx, repo.ID); err != nil { if errors.Is(err, repo_model.ErrMirrorNotExist) { ctx.Error(http.StatusBadRequest, "MirrorSync", "Repository is not a mirror") return diff --git a/routers/api/v1/repo/pull.go b/routers/api/v1/repo/pull.go index a01efda1b..393f2d157 100644 --- a/routers/api/v1/repo/pull.go +++ b/routers/api/v1/repo/pull.go @@ -160,7 +160,7 @@ func GetPullRequest(ctx *context.APIContext) { // "404": // "$ref": "#/responses/notFound" - pr, err := models.GetPullRequestByIndex(ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) + pr, err := models.GetPullRequestByIndex(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) if err != nil { if models.IsErrPullRequestNotExist(err) { ctx.NotFound() @@ -220,7 +220,7 @@ func DownloadPullDiffOrPatch(ctx *context.APIContext) { // "$ref": "#/responses/string" // "404": // "$ref": "#/responses/notFound" - pr, err := models.GetPullRequestByIndex(ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) + pr, err := models.GetPullRequestByIndex(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) if err != nil { if models.IsErrPullRequestNotExist(err) { ctx.NotFound() @@ -470,7 +470,7 @@ func EditPullRequest(ctx *context.APIContext) { // "$ref": "#/responses/validationError" form := web.GetForm(ctx).(*api.EditPullRequestOption) - pr, err := models.GetPullRequestByIndex(ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) + pr, err := models.GetPullRequestByIndex(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) if err != nil { if models.IsErrPullRequestNotExist(err) { ctx.NotFound() @@ -632,7 +632,7 @@ func EditPullRequest(ctx *context.APIContext) { } // Refetch from database - pr, err = models.GetPullRequestByIndex(ctx.Repo.Repository.ID, pr.Index) + pr, err = models.GetPullRequestByIndex(ctx, ctx.Repo.Repository.ID, pr.Index) if err != nil { if models.IsErrPullRequestNotExist(err) { ctx.NotFound() @@ -676,7 +676,7 @@ func IsPullRequestMerged(ctx *context.APIContext) { // "404": // description: pull request has not been merged - pr, err := models.GetPullRequestByIndex(ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) + pr, err := models.GetPullRequestByIndex(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) if err != nil { if models.IsErrPullRequestNotExist(err) { ctx.NotFound() @@ -730,7 +730,7 @@ func MergePullRequest(ctx *context.APIContext) { form := web.GetForm(ctx).(*forms.MergePullRequestForm) - pr, err := models.GetPullRequestByIndexCtx(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) + pr, err := models.GetPullRequestByIndex(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) if err != nil { if models.IsErrPullRequestNotExist(err) { ctx.NotFound("GetPullRequestByIndex", err) @@ -938,7 +938,7 @@ func parseCompareInfo(ctx *context.APIContext, form api.CreatePullRequestOption) headBranch = headInfos[0] } else if len(headInfos) == 2 { - headUser, err = user_model.GetUserByName(headInfos[0]) + headUser, err = user_model.GetUserByName(ctx, headInfos[0]) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.NotFound("GetUserByName") @@ -1079,7 +1079,7 @@ func UpdatePullRequest(ctx *context.APIContext) { // "422": // "$ref": "#/responses/validationError" - pr, err := models.GetPullRequestByIndex(ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) + pr, err := models.GetPullRequestByIndex(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) if err != nil { if models.IsErrPullRequestNotExist(err) { ctx.NotFound() @@ -1177,7 +1177,7 @@ func CancelScheduledAutoMerge(ctx *context.APIContext) { // "$ref": "#/responses/notFound" pullIndex := ctx.ParamsInt64(":index") - pull, err := models.GetPullRequestByIndex(ctx.Repo.Repository.ID, pullIndex) + pull, err := models.GetPullRequestByIndex(ctx, ctx.Repo.Repository.ID, pullIndex) if err != nil { if models.IsErrPullRequestNotExist(err) { ctx.NotFound() @@ -1198,7 +1198,7 @@ func CancelScheduledAutoMerge(ctx *context.APIContext) { } if ctx.Doer.ID != autoMerge.DoerID { - allowed, err := access_model.IsUserRepoAdminCtx(ctx, ctx.Repo.Repository, ctx.Doer) + allowed, err := access_model.IsUserRepoAdmin(ctx, ctx.Repo.Repository, ctx.Doer) if err != nil { ctx.InternalServerError(err) return @@ -1254,7 +1254,7 @@ func GetPullRequestCommits(ctx *context.APIContext) { // "404": // "$ref": "#/responses/notFound" - pr, err := models.GetPullRequestByIndex(ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) + pr, err := models.GetPullRequestByIndex(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) if err != nil { if models.IsErrPullRequestNotExist(err) { ctx.NotFound() diff --git a/routers/api/v1/repo/pull_review.go b/routers/api/v1/repo/pull_review.go index 0cf540ce7..5175fa921 100644 --- a/routers/api/v1/repo/pull_review.go +++ b/routers/api/v1/repo/pull_review.go @@ -61,7 +61,7 @@ func ListPullReviews(ctx *context.APIContext) { // "404": // "$ref": "#/responses/notFound" - pr, err := models.GetPullRequestByIndex(ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) + pr, err := models.GetPullRequestByIndex(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) if err != nil { if models.IsErrPullRequestNotExist(err) { ctx.NotFound("GetPullRequestByIndex", err) @@ -87,7 +87,7 @@ func ListPullReviews(ctx *context.APIContext) { IssueID: pr.IssueID, } - allReviews, err := models.FindReviews(opts) + allReviews, err := models.FindReviews(ctx, opts) if err != nil { ctx.InternalServerError(err) return @@ -307,7 +307,7 @@ func CreatePullReview(ctx *context.APIContext) { // "$ref": "#/responses/validationError" opts := web.GetForm(ctx).(*api.CreatePullReviewOptions) - pr, err := models.GetPullRequestByIndex(ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) + pr, err := models.GetPullRequestByIndex(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) if err != nil { if models.IsErrPullRequestNotExist(err) { ctx.NotFound("GetPullRequestByIndex", err) @@ -526,7 +526,7 @@ func preparePullReviewType(ctx *context.APIContext, pr *models.PullRequest, even // prepareSingleReview return review, related pull and false or nil, nil and true if an error happen func prepareSingleReview(ctx *context.APIContext) (*models.Review, *models.PullRequest, bool) { - pr, err := models.GetPullRequestByIndex(ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) + pr, err := models.GetPullRequestByIndex(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) if err != nil { if models.IsErrPullRequestNotExist(err) { ctx.NotFound("GetPullRequestByIndex", err) @@ -536,7 +536,7 @@ func prepareSingleReview(ctx *context.APIContext) (*models.Review, *models.PullR return nil, nil, true } - review, err := models.GetReviewByID(ctx.ParamsInt64(":id")) + review, err := models.GetReviewByID(ctx, ctx.ParamsInt64(":id")) if err != nil { if models.IsErrReviewNotExist(err) { ctx.NotFound("GetReviewByID", err) @@ -648,7 +648,7 @@ func DeleteReviewRequests(ctx *context.APIContext) { } func apiReviewRequest(ctx *context.APIContext, opts api.PullReviewRequestOptions, isAdd bool) { - pr, err := models.GetPullRequestByIndex(ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) + pr, err := models.GetPullRequestByIndex(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) if err != nil { if models.IsErrPullRequestNotExist(err) { ctx.NotFound("GetPullRequestByIndex", err) @@ -676,7 +676,7 @@ func apiReviewRequest(ctx *context.APIContext, opts api.PullReviewRequestOptions if strings.Contains(r, "@") { reviewer, err = user_model.GetUserByEmail(r) } else { - reviewer, err = user_model.GetUserByName(r) + reviewer, err = user_model.GetUserByName(ctx, r) } if err != nil { @@ -727,7 +727,7 @@ func apiReviewRequest(ctx *context.APIContext, opts api.PullReviewRequestOptions teamReviewers := make([]*organization.Team, 0, len(opts.TeamReviewers)) for _, t := range opts.TeamReviewers { var teamReviewer *organization.Team - teamReviewer, err = organization.GetTeam(ctx.Repo.Owner.ID, t) + teamReviewer, err = organization.GetTeam(ctx, ctx.Repo.Owner.ID, t) if err != nil { if organization.IsErrTeamNotExist(err) { ctx.NotFound("TeamNotExist", fmt.Sprintf("Team '%s' not exist", t)) @@ -892,7 +892,7 @@ func dismissReview(ctx *context.APIContext, msg string, isDismiss bool) { return } - if review, err = models.GetReviewByID(review.ID); err != nil { + if review, err = models.GetReviewByID(ctx, review.ID); err != nil { ctx.Error(http.StatusInternalServerError, "GetReviewByID", err) return } diff --git a/routers/api/v1/repo/release_attachment.go b/routers/api/v1/repo/release_attachment.go index c172b6612..b7807e5e8 100644 --- a/routers/api/v1/repo/release_attachment.go +++ b/routers/api/v1/repo/release_attachment.go @@ -55,7 +55,7 @@ func GetReleaseAttachment(ctx *context.APIContext) { releaseID := ctx.ParamsInt64(":id") attachID := ctx.ParamsInt64(":asset") - attach, err := repo_model.GetAttachmentByID(attachID) + attach, err := repo_model.GetAttachmentByID(ctx, attachID) if err != nil { ctx.Error(http.StatusInternalServerError, "GetAttachmentByID", err) return @@ -242,7 +242,7 @@ func EditReleaseAttachment(ctx *context.APIContext) { // Check if release exists an load release releaseID := ctx.ParamsInt64(":id") attachID := ctx.ParamsInt64(":asset") - attach, err := repo_model.GetAttachmentByID(attachID) + attach, err := repo_model.GetAttachmentByID(ctx, attachID) if err != nil { ctx.Error(http.StatusInternalServerError, "GetAttachmentByID", err) return @@ -257,7 +257,7 @@ func EditReleaseAttachment(ctx *context.APIContext) { attach.Name = form.Name } - if err := repo_model.UpdateAttachment(attach); err != nil { + if err := repo_model.UpdateAttachment(ctx, attach); err != nil { ctx.Error(http.StatusInternalServerError, "UpdateAttachment", attach) } ctx.JSON(http.StatusCreated, convert.ToReleaseAttachment(attach)) @@ -300,7 +300,7 @@ func DeleteReleaseAttachment(ctx *context.APIContext) { // Check if release exists an load release releaseID := ctx.ParamsInt64(":id") attachID := ctx.ParamsInt64(":asset") - attach, err := repo_model.GetAttachmentByID(attachID) + attach, err := repo_model.GetAttachmentByID(ctx, attachID) if err != nil { ctx.Error(http.StatusInternalServerError, "GetAttachmentByID", err) return diff --git a/routers/api/v1/repo/repo.go b/routers/api/v1/repo/repo.go index a11f579ee..8485ffbac 100644 --- a/routers/api/v1/repo/repo.go +++ b/routers/api/v1/repo/repo.go @@ -365,7 +365,7 @@ func Generate(ctx *context.APIContext) { ctxUser := ctx.Doer var err error if form.Owner != ctxUser.Name { - ctxUser, err = user_model.GetUserByName(form.Owner) + ctxUser, err = user_model.GetUserByName(ctx, form.Owner) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.JSON(http.StatusNotFound, map[string]interface{}{ @@ -962,7 +962,7 @@ func updateMirror(ctx *context.APIContext, opts api.EditRepoOption) error { } // get the mirror from the repo - mirror, err := repo_model.GetMirrorByRepoID(repo.ID) + mirror, err := repo_model.GetMirrorByRepoID(ctx, repo.ID) if err != nil { log.Error("Failed to get mirror: %s", err) ctx.Error(http.StatusInternalServerError, "MirrorInterval", err) @@ -1000,7 +1000,7 @@ func updateMirror(ctx *context.APIContext, opts api.EditRepoOption) error { } // finally update the mirror in the DB - if err := repo_model.UpdateMirror(mirror); err != nil { + if err := repo_model.UpdateMirror(ctx, mirror); err != nil { log.Error("Failed to Set Mirror Interval: %s", err) ctx.Error(http.StatusUnprocessableEntity, "MirrorInterval", err) return err diff --git a/routers/api/v1/repo/status.go b/routers/api/v1/repo/status.go index f4c0ebd38..09597dc4e 100644 --- a/routers/api/v1/repo/status.go +++ b/routers/api/v1/repo/status.go @@ -253,7 +253,7 @@ func GetCombinedCommitStatusByRef(ctx *context.APIContext) { repo := ctx.Repo.Repository - statuses, count, err := models.GetLatestCommitStatus(repo.ID, sha, utils.GetListOptions(ctx)) + statuses, count, err := models.GetLatestCommitStatus(ctx, repo.ID, sha, utils.GetListOptions(ctx)) if err != nil { ctx.Error(http.StatusInternalServerError, "GetLatestCommitStatus", fmt.Errorf("GetLatestCommitStatus[%s, %s]: %v", repo.FullName(), sha, err)) return diff --git a/routers/api/v1/repo/teams.go b/routers/api/v1/repo/teams.go index e414d8b60..47c69d722 100644 --- a/routers/api/v1/repo/teams.go +++ b/routers/api/v1/repo/teams.go @@ -41,7 +41,7 @@ func ListTeams(ctx *context.APIContext) { return } - teams, err := organization.GetRepoTeams(ctx.Repo.Repository) + teams, err := organization.GetRepoTeams(ctx, ctx.Repo.Repository) if err != nil { ctx.InternalServerError(err) return @@ -216,7 +216,7 @@ func changeRepoTeam(ctx *context.APIContext, add bool) { } func getTeamByParam(ctx *context.APIContext) *organization.Team { - team, err := organization.GetTeam(ctx.Repo.Owner.ID, ctx.Params(":team")) + team, err := organization.GetTeam(ctx, ctx.Repo.Owner.ID, ctx.Params(":team")) if err != nil { if organization.IsErrTeamNotExist(err) { ctx.Error(http.StatusNotFound, "TeamNotExit", err) diff --git a/routers/api/v1/repo/transfer.go b/routers/api/v1/repo/transfer.go index 241c578e6..067a4ebe1 100644 --- a/routers/api/v1/repo/transfer.go +++ b/routers/api/v1/repo/transfer.go @@ -57,7 +57,7 @@ func Transfer(ctx *context.APIContext) { opts := web.GetForm(ctx).(*api.TransferRepoOption) - newOwner, err := user_model.GetUserByName(opts.NewOwner) + newOwner, err := user_model.GetUserByName(ctx, opts.NewOwner) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.Error(http.StatusNotFound, "", "The new owner does not exist or cannot be found") @@ -84,7 +84,7 @@ func Transfer(ctx *context.APIContext) { org := convert.ToOrganization(organization.OrgFromUser(newOwner)) for _, tID := range *opts.TeamIDs { - team, err := organization.GetTeamByID(tID) + team, err := organization.GetTeamByID(ctx, tID) if err != nil { ctx.Error(http.StatusUnprocessableEntity, "team", fmt.Errorf("team %d not found", tID)) return diff --git a/routers/api/v1/user/app.go b/routers/api/v1/user/app.go index 165b8f005..0d2e8401c 100644 --- a/routers/api/v1/user/app.go +++ b/routers/api/v1/user/app.go @@ -213,7 +213,7 @@ func CreateOauth2Application(ctx *context.APIContext) { data := web.GetForm(ctx).(*api.CreateOAuth2ApplicationOptions) - app, err := auth.CreateOAuth2Application(auth.CreateOAuth2ApplicationOptions{ + app, err := auth.CreateOAuth2Application(ctx, auth.CreateOAuth2ApplicationOptions{ Name: data.Name, UserID: ctx.Doer.ID, RedirectURIs: data.RedirectURIs, @@ -320,7 +320,7 @@ func GetOauth2Application(ctx *context.APIContext) { // "404": // "$ref": "#/responses/notFound" appID := ctx.ParamsInt64(":id") - app, err := auth.GetOAuth2ApplicationByID(appID) + app, err := auth.GetOAuth2ApplicationByID(ctx, appID) if err != nil { if auth.IsErrOauthClientIDInvalid(err) || auth.IsErrOAuthApplicationNotFound(err) { ctx.NotFound() diff --git a/routers/api/v1/user/helper.go b/routers/api/v1/user/helper.go index fab3ce2ae..ae7fa5248 100644 --- a/routers/api/v1/user/helper.go +++ b/routers/api/v1/user/helper.go @@ -14,7 +14,7 @@ import ( // GetUserByParamsName get user by name func GetUserByParamsName(ctx *context.APIContext, name string) *user_model.User { username := ctx.Params(name) - user, err := user_model.GetUserByName(username) + user, err := user_model.GetUserByName(ctx, username) if err != nil { if user_model.IsErrUserNotExist(err) { if redirectUserID, err2 := user_model.LookupUserRedirect(username); err2 == nil { diff --git a/routers/api/v1/user/settings.go b/routers/api/v1/user/settings.go index dc7e7f116..f00bf8c26 100644 --- a/routers/api/v1/user/settings.go +++ b/routers/api/v1/user/settings.go @@ -74,7 +74,7 @@ func UpdateUserSettings(ctx *context.APIContext) { ctx.Doer.KeepActivityPrivate = *form.HideActivity } - if err := user_model.UpdateUser(ctx.Doer, false); err != nil { + if err := user_model.UpdateUser(ctx, ctx.Doer, false); err != nil { ctx.InternalServerError(err) return } diff --git a/routers/api/v1/user/star.go b/routers/api/v1/user/star.go index d2e4f36cd..9cb9ec79b 100644 --- a/routers/api/v1/user/star.go +++ b/routers/api/v1/user/star.go @@ -124,7 +124,7 @@ func IsStarring(ctx *context.APIContext) { // "404": // "$ref": "#/responses/notFound" - if repo_model.IsStaring(ctx.Doer.ID, ctx.Repo.Repository.ID) { + if repo_model.IsStaring(ctx, ctx.Doer.ID, ctx.Repo.Repository.ID) { ctx.Status(http.StatusNoContent) } else { ctx.NotFound() diff --git a/routers/api/v1/user/user.go b/routers/api/v1/user/user.go index 018f75762..2a3cb15c0 100644 --- a/routers/api/v1/user/user.go +++ b/routers/api/v1/user/user.go @@ -98,7 +98,7 @@ func GetInfo(ctx *context.APIContext) { // "404": // "$ref": "#/responses/notFound" - if !user_model.IsUserVisibleToViewer(ctx.ContextUser, ctx.Doer) { + if !user_model.IsUserVisibleToViewer(ctx, ctx.ContextUser, ctx.Doer) { // fake ErrUserNotExist error message to not leak information about existence ctx.NotFound("GetUserByName", user_model.ErrUserNotExist{Name: ctx.Params(":username")}) return diff --git a/routers/api/v1/user/watch.go b/routers/api/v1/user/watch.go index 652ef3f8a..83f23db15 100644 --- a/routers/api/v1/user/watch.go +++ b/routers/api/v1/user/watch.go @@ -156,7 +156,7 @@ func Watch(ctx *context.APIContext) { // "200": // "$ref": "#/responses/WatchInfo" - err := repo_model.WatchRepo(ctx.Doer.ID, ctx.Repo.Repository.ID, true) + err := repo_model.WatchRepo(ctx, ctx.Doer.ID, ctx.Repo.Repository.ID, true) if err != nil { ctx.Error(http.StatusInternalServerError, "WatchRepo", err) return @@ -191,7 +191,7 @@ func Unwatch(ctx *context.APIContext) { // "204": // "$ref": "#/responses/empty" - err := repo_model.WatchRepo(ctx.Doer.ID, ctx.Repo.Repository.ID, false) + err := repo_model.WatchRepo(ctx, ctx.Doer.ID, ctx.Repo.Repository.ID, false) if err != nil { ctx.Error(http.StatusInternalServerError, "UnwatchRepo", err) return diff --git a/routers/install/install.go b/routers/install/install.go index 41b11aef3..3fc5f0536 100644 --- a/routers/install/install.go +++ b/routers/install/install.go @@ -521,7 +521,7 @@ func SubmitInstall(ctx *context.Context) { return } log.Info("Admin account already exist") - u, _ = user_model.GetUserByName(u.Name) + u, _ = user_model.GetUserByName(ctx, u.Name) } days := 86400 * setting.LogInRememberDays diff --git a/routers/private/hook_post_receive.go b/routers/private/hook_post_receive.go index 5e315ede4..eb2bbc1e5 100644 --- a/routers/private/hook_post_receive.go +++ b/routers/private/hook_post_receive.go @@ -106,7 +106,7 @@ func HookPostReceive(ctx *gitea_context.PrivateContext) { repo.IsPrivate = opts.GitPushOptions.Bool(private.GitPushOptionRepoPrivate, repo.IsPrivate) repo.IsTemplate = opts.GitPushOptions.Bool(private.GitPushOptionRepoTemplate, repo.IsTemplate) - if err := repo_model.UpdateRepositoryCols(repo, "is_private", "is_template"); err != nil { + if err := repo_model.UpdateRepositoryCols(ctx, repo, "is_private", "is_template"); err != nil { log.Error("Failed to Update: %s/%s Error: %v", ownerName, repoName, err) ctx.JSON(http.StatusInternalServerError, private.HookPostReceiveResult{ Err: fmt.Sprintf("Failed to Update: %s/%s Error: %v", ownerName, repoName, err), @@ -141,7 +141,7 @@ func HookPostReceive(ctx *gitea_context.PrivateContext) { continue } - pr, err := models.GetPullRequestByIndex(repo.ID, pullIndex) + pr, err := models.GetPullRequestByIndex(ctx, repo.ID, pullIndex) if err != nil && !models.IsErrPullRequestNotExist(err) { log.Error("Failed to get PR by index %v Error: %v", pullIndex, err) ctx.JSON(http.StatusInternalServerError, private.Response{ diff --git a/routers/private/hook_pre_receive.go b/routers/private/hook_pre_receive.go index 9d3d66526..1f005d35b 100644 --- a/routers/private/hook_pre_receive.go +++ b/routers/private/hook_pre_receive.go @@ -155,7 +155,7 @@ func preReceiveBranch(ctx *preReceiveContext, oldCommitID, newCommitID, refFullN return } - protectBranch, err := models.GetProtectedBranchBy(repo.ID, branchName) + protectBranch, err := models.GetProtectedBranchBy(ctx, repo.ID, branchName) if err != nil { log.Error("Unable to get protected branch: %s in %-v Error: %v", branchName, repo, err) ctx.JSON(http.StatusInternalServerError, private.Response{ diff --git a/routers/private/key.go b/routers/private/key.go index 3366b764e..9977492c6 100644 --- a/routers/private/key.go +++ b/routers/private/key.go @@ -25,7 +25,7 @@ func UpdatePublicKeyInRepo(ctx *context.PrivateContext) { return } - deployKey, err := asymkey_model.GetDeployKeyByRepo(keyID, repoID) + deployKey, err := asymkey_model.GetDeployKeyByRepo(ctx, keyID, repoID) if err != nil { if asymkey_model.IsErrDeployKeyNotExist(err) { ctx.PlainText(http.StatusOK, "success") @@ -52,7 +52,7 @@ func UpdatePublicKeyInRepo(ctx *context.PrivateContext) { func AuthorizedPublicKeyByContent(ctx *context.PrivateContext) { content := ctx.FormString("content") - publicKey, err := asymkey_model.SearchPublicKeyByContent(content) + publicKey, err := asymkey_model.SearchPublicKeyByContent(ctx, content) if err != nil { ctx.JSON(http.StatusInternalServerError, private.Response{ Err: err.Error(), diff --git a/routers/private/mail.go b/routers/private/mail.go index 853b58b09..966a83816 100644 --- a/routers/private/mail.go +++ b/routers/private/mail.go @@ -44,7 +44,7 @@ func SendEmail(ctx *context.PrivateContext) { var emails []string if len(mail.To) > 0 { for _, uname := range mail.To { - user, err := user_model.GetUserByName(uname) + user, err := user_model.GetUserByName(ctx, uname) if err != nil { err := fmt.Sprintf("Failed to get user information: %v", err) log.Error(err) diff --git a/routers/private/serv.go b/routers/private/serv.go index 77877e1ad..803d51e9d 100644 --- a/routers/private/serv.go +++ b/routers/private/serv.go @@ -109,7 +109,7 @@ func ServCommand(ctx *context.PrivateContext) { results.RepoName = repoName[:len(repoName)-5] } - owner, err := user_model.GetUserByName(results.OwnerName) + owner, err := user_model.GetUserByName(ctx, results.OwnerName) if err != nil { if user_model.IsErrUserNotExist(err) { // User is fetching/cloning a non-existent repository @@ -230,7 +230,7 @@ func ServCommand(ctx *context.PrivateContext) { var user *user_model.User if key.Type == asymkey_model.KeyTypeDeploy { var err error - deployKey, err = asymkey_model.GetDeployKeyByRepo(key.ID, repo.ID) + deployKey, err = asymkey_model.GetDeployKeyByRepo(ctx, key.ID, repo.ID) if err != nil { if asymkey_model.IsErrDeployKeyNotExist(err) { ctx.JSON(http.StatusNotFound, private.ErrServCommand{ @@ -345,7 +345,7 @@ func ServCommand(ctx *context.PrivateContext) { // We already know we aren't using a deploy key if !repoExist { - owner, err := user_model.GetUserByName(ownerName) + owner, err := user_model.GetUserByName(ctx, ownerName) if err != nil { ctx.JSON(http.StatusInternalServerError, private.ErrServCommand{ Results: results, diff --git a/routers/web/admin/hooks.go b/routers/web/admin/hooks.go index 1483d0959..bf71cb559 100644 --- a/routers/web/admin/hooks.go +++ b/routers/web/admin/hooks.go @@ -35,7 +35,7 @@ func DefaultOrSystemWebhooks(ctx *context.Context) { sys["Title"] = ctx.Tr("admin.systemhooks") sys["Description"] = ctx.Tr("admin.systemhooks.desc") - sys["Webhooks"], err = webhook.GetSystemWebhooks(util.OptionalBoolNone) + sys["Webhooks"], err = webhook.GetSystemWebhooks(ctx, util.OptionalBoolNone) sys["BaseLink"] = setting.AppSubURL + "/admin/hooks" sys["BaseLinkNew"] = setting.AppSubURL + "/admin/system-hooks" if err != nil { @@ -45,7 +45,7 @@ func DefaultOrSystemWebhooks(ctx *context.Context) { def["Title"] = ctx.Tr("admin.defaulthooks") def["Description"] = ctx.Tr("admin.defaulthooks.desc") - def["Webhooks"], err = webhook.GetDefaultWebhooks() + def["Webhooks"], err = webhook.GetDefaultWebhooks(ctx) def["BaseLink"] = setting.AppSubURL + "/admin/hooks" def["BaseLinkNew"] = setting.AppSubURL + "/admin/default-hooks" if err != nil { diff --git a/routers/web/admin/repos.go b/routers/web/admin/repos.go index fb7be12c3..809d1de74 100644 --- a/routers/web/admin/repos.go +++ b/routers/web/admin/repos.go @@ -121,7 +121,7 @@ func AdoptOrDeleteRepository(ctx *context.Context) { return } - ctxUser, err := user_model.GetUserByName(dirSplit[0]) + ctxUser, err := user_model.GetUserByName(ctx, dirSplit[0]) if err != nil { if user_model.IsErrUserNotExist(err) { log.Debug("User does not exist: %s", dirSplit[0]) @@ -135,7 +135,7 @@ func AdoptOrDeleteRepository(ctx *context.Context) { repoName := dirSplit[1] // check not a repo - has, err := repo_model.IsRepositoryExist(ctxUser, repoName) + has, err := repo_model.IsRepositoryExist(ctx, ctxUser, repoName) if err != nil { ctx.ServerError("IsRepositoryExist", err) return diff --git a/routers/web/admin/users.go b/routers/web/admin/users.go index 7841ac569..c37ecfd71 100644 --- a/routers/web/admin/users.go +++ b/routers/web/admin/users.go @@ -389,7 +389,7 @@ func EditUserPost(ctx *context.Context) { u.ProhibitLogin = form.ProhibitLogin } - if err := user_model.UpdateUser(u, emailChanged); err != nil { + if err := user_model.UpdateUser(ctx, u, emailChanged); err != nil { if user_model.IsErrEmailAlreadyUsed(err) { ctx.Data["Err_Email"] = true ctx.RenderWithErr(ctx.Tr("form.email_been_used"), tplUserEdit, &form) diff --git a/routers/web/admin/users_test.go b/routers/web/admin/users_test.go index 9de548685..e63367ccf 100644 --- a/routers/web/admin/users_test.go +++ b/routers/web/admin/users_test.go @@ -47,7 +47,7 @@ func TestNewUserPost_MustChangePassword(t *testing.T) { assert.NotEmpty(t, ctx.Flash.SuccessMsg) - u, err := user_model.GetUserByName(username) + u, err := user_model.GetUserByName(ctx, username) assert.NoError(t, err) assert.Equal(t, username, u.Name) @@ -84,7 +84,7 @@ func TestNewUserPost_MustChangePasswordFalse(t *testing.T) { assert.NotEmpty(t, ctx.Flash.SuccessMsg) - u, err := user_model.GetUserByName(username) + u, err := user_model.GetUserByName(ctx, username) assert.NoError(t, err) assert.Equal(t, username, u.Name) @@ -151,7 +151,7 @@ func TestNewUserPost_VisibilityDefaultPublic(t *testing.T) { assert.NotEmpty(t, ctx.Flash.SuccessMsg) - u, err := user_model.GetUserByName(username) + u, err := user_model.GetUserByName(ctx, username) assert.NoError(t, err) assert.Equal(t, username, u.Name) @@ -190,7 +190,7 @@ func TestNewUserPost_VisibilityPrivate(t *testing.T) { assert.NotEmpty(t, ctx.Flash.SuccessMsg) - u, err := user_model.GetUserByName(username) + u, err := user_model.GetUserByName(ctx, username) assert.NoError(t, err) assert.Equal(t, username, u.Name) diff --git a/routers/web/auth/auth.go b/routers/web/auth/auth.go index be936d223..4d5a2c933 100644 --- a/routers/web/auth/auth.go +++ b/routers/web/auth/auth.go @@ -64,7 +64,7 @@ func AutoSignIn(ctx *context.Context) (bool, error) { } }() - u, err := user_model.GetUserByName(uname) + u, err := user_model.GetUserByName(ctx, uname) if err != nil { if !user_model.IsErrUserNotExist(err) { return false, fmt.Errorf("GetUserByName: %v", err) diff --git a/routers/web/auth/linkaccount.go b/routers/web/auth/linkaccount.go index c3e96f077..a2d76e9c5 100644 --- a/routers/web/auth/linkaccount.go +++ b/routers/web/auth/linkaccount.go @@ -70,7 +70,7 @@ func LinkAccount(ctx *context.Context) { ctx.Data["user_exists"] = true } } else if len(uname) != 0 { - u, err := user_model.GetUserByName(uname) + u, err := user_model.GetUserByName(ctx, uname) if err != nil && !user_model.IsErrUserNotExist(err) { ctx.ServerError("UserSignIn", err) return diff --git a/routers/web/auth/oauth.go b/routers/web/auth/oauth.go index 4c3e3c3ac..9aa31c1c0 100644 --- a/routers/web/auth/oauth.go +++ b/routers/web/auth/oauth.go @@ -5,6 +5,7 @@ package auth import ( + stdContext "context" "encoding/base64" "errors" "fmt" @@ -135,9 +136,9 @@ type AccessTokenResponse struct { IDToken string `json:"id_token,omitempty"` } -func newAccessTokenResponse(grant *auth.OAuth2Grant, serverKey, clientKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) { +func newAccessTokenResponse(ctx stdContext.Context, grant *auth.OAuth2Grant, serverKey, clientKey oauth2.JWTSigningKey) (*AccessTokenResponse, *AccessTokenError) { if setting.OAuth2.InvalidateRefreshTokens { - if err := grant.IncreaseCounter(); err != nil { + if err := grant.IncreaseCounter(ctx); err != nil { return nil, &AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidGrant, ErrorDescription: "cannot increase the grant counter", @@ -182,7 +183,7 @@ func newAccessTokenResponse(grant *auth.OAuth2Grant, serverKey, clientKey oauth2 // generate OpenID Connect id_token signedIDToken := "" if grant.ScopeContains("openid") { - app, err := auth.GetOAuth2ApplicationByID(grant.ApplicationID) + app, err := auth.GetOAuth2ApplicationByID(ctx, grant.ApplicationID) if err != nil { return nil, &AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidRequest, @@ -333,9 +334,9 @@ func IntrospectOAuth(ctx *context.Context) { token, err := oauth2.ParseToken(form.Token, oauth2.DefaultSigningKey) if err == nil { if token.Valid() == nil { - grant, err := auth.GetOAuth2GrantByID(token.GrantID) + grant, err := auth.GetOAuth2GrantByID(ctx, token.GrantID) if err == nil && grant != nil { - app, err := auth.GetOAuth2ApplicationByID(grant.ApplicationID) + app, err := auth.GetOAuth2ApplicationByID(ctx, grant.ApplicationID) if err == nil && app != nil { response.Active = true response.Scope = grant.Scope @@ -364,7 +365,7 @@ func AuthorizeOAuth(ctx *context.Context) { return } - app, err := auth.GetOAuth2ApplicationByClientID(form.ClientID) + app, err := auth.GetOAuth2ApplicationByClientID(ctx, form.ClientID) if err != nil { if auth.IsErrOauthClientIDInvalid(err) { handleAuthorizeError(ctx, AuthorizeError{ @@ -438,7 +439,7 @@ func AuthorizeOAuth(ctx *context.Context) { return } - grant, err := app.GetGrantByUserID(ctx.Doer.ID) + grant, err := app.GetGrantByUserID(ctx, ctx.Doer.ID) if err != nil { handleServerError(ctx, form.State, form.RedirectURI) return @@ -446,7 +447,7 @@ func AuthorizeOAuth(ctx *context.Context) { // Redirect if user already granted access if grant != nil { - code, err := grant.GenerateNewAuthorizationCode(form.RedirectURI, form.CodeChallenge, form.CodeChallengeMethod) + code, err := grant.GenerateNewAuthorizationCode(ctx, form.RedirectURI, form.CodeChallenge, form.CodeChallengeMethod) if err != nil { handleServerError(ctx, form.State, form.RedirectURI) return @@ -458,7 +459,7 @@ func AuthorizeOAuth(ctx *context.Context) { } // Update nonce to reflect the new session if len(form.Nonce) > 0 { - err := grant.SetNonce(form.Nonce) + err := grant.SetNonce(ctx, form.Nonce) if err != nil { log.Error("Unable to update nonce: %v", err) } @@ -510,12 +511,12 @@ func GrantApplicationOAuth(ctx *context.Context) { ctx.Error(http.StatusBadRequest) return } - app, err := auth.GetOAuth2ApplicationByClientID(form.ClientID) + app, err := auth.GetOAuth2ApplicationByClientID(ctx, form.ClientID) if err != nil { ctx.ServerError("GetOAuth2ApplicationByClientID", err) return } - grant, err := app.CreateGrant(ctx.Doer.ID, form.Scope) + grant, err := app.CreateGrant(ctx, ctx.Doer.ID, form.Scope) if err != nil { handleAuthorizeError(ctx, AuthorizeError{ State: form.State, @@ -525,7 +526,7 @@ func GrantApplicationOAuth(ctx *context.Context) { return } if len(form.Nonce) > 0 { - err := grant.SetNonce(form.Nonce) + err := grant.SetNonce(ctx, form.Nonce) if err != nil { log.Error("Unable to update nonce: %v", err) } @@ -535,7 +536,7 @@ func GrantApplicationOAuth(ctx *context.Context) { codeChallenge, _ = ctx.Session.Get("CodeChallenge").(string) codeChallengeMethod, _ = ctx.Session.Get("CodeChallengeMethod").(string) - code, err := grant.GenerateNewAuthorizationCode(form.RedirectURI, codeChallenge, codeChallengeMethod) + code, err := grant.GenerateNewAuthorizationCode(ctx, form.RedirectURI, codeChallenge, codeChallengeMethod) if err != nil { handleServerError(ctx, form.State, form.RedirectURI) return @@ -648,7 +649,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, server return } // get grant before increasing counter - grant, err := auth.GetOAuth2GrantByID(token.GrantID) + grant, err := auth.GetOAuth2GrantByID(ctx, token.GrantID) if err != nil || grant == nil { handleAccessTokenError(ctx, AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidGrant, @@ -666,7 +667,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, server log.Warn("A client tried to use a refresh token for grant_id = %d was used twice!", grant.ID) return } - accessToken, tokenErr := newAccessTokenResponse(grant, serverKey, clientKey) + accessToken, tokenErr := newAccessTokenResponse(ctx, grant, serverKey, clientKey) if tokenErr != nil { handleAccessTokenError(ctx, *tokenErr) return @@ -675,7 +676,7 @@ func handleRefreshToken(ctx *context.Context, form forms.AccessTokenForm, server } func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, serverKey, clientKey oauth2.JWTSigningKey) { - app, err := auth.GetOAuth2ApplicationByClientID(form.ClientID) + app, err := auth.GetOAuth2ApplicationByClientID(ctx, form.ClientID) if err != nil { handleAccessTokenError(ctx, AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidClient, @@ -697,7 +698,7 @@ func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, s }) return } - authorizationCode, err := auth.GetOAuth2AuthorizationByCode(form.Code) + authorizationCode, err := auth.GetOAuth2AuthorizationByCode(ctx, form.Code) if err != nil || authorizationCode == nil { handleAccessTokenError(ctx, AccessTokenError{ ErrorCode: AccessTokenErrorCodeUnauthorizedClient, @@ -722,13 +723,13 @@ func handleAuthorizationCode(ctx *context.Context, form forms.AccessTokenForm, s return } // remove token from database to deny duplicate usage - if err := authorizationCode.Invalidate(); err != nil { + if err := authorizationCode.Invalidate(ctx); err != nil { handleAccessTokenError(ctx, AccessTokenError{ ErrorCode: AccessTokenErrorCodeInvalidRequest, ErrorDescription: "cannot proceed your request", }) } - resp, tokenErr := newAccessTokenResponse(authorizationCode.Grant, serverKey, clientKey) + resp, tokenErr := newAccessTokenResponse(ctx, authorizationCode.Grant, serverKey, clientKey) if tokenErr != nil { handleAccessTokenError(ctx, *tokenErr) return diff --git a/routers/web/auth/oauth_test.go b/routers/web/auth/oauth_test.go index 669d7431f..5a09a9510 100644 --- a/routers/web/auth/oauth_test.go +++ b/routers/web/auth/oauth_test.go @@ -8,6 +8,7 @@ import ( "testing" "code.gitea.io/gitea/models/auth" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/services/auth/source/oauth2" @@ -21,7 +22,7 @@ func createAndParseToken(t *testing.T, grant *auth.OAuth2Grant) *oauth2.OIDCToke assert.NoError(t, err) assert.NotNil(t, signingKey) - response, terr := newAccessTokenResponse(grant, signingKey, signingKey) + response, terr := newAccessTokenResponse(db.DefaultContext, grant, signingKey, signingKey) assert.Nil(t, terr) assert.NotNil(t, response) @@ -43,7 +44,7 @@ func createAndParseToken(t *testing.T, grant *auth.OAuth2Grant) *oauth2.OIDCToke func TestNewAccessTokenResponse_OIDCToken(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) - grants, err := auth.GetOAuth2GrantsByUserID(3) + grants, err := auth.GetOAuth2GrantsByUserID(db.DefaultContext, 3) assert.NoError(t, err) assert.Len(t, grants, 1) @@ -59,7 +60,7 @@ func TestNewAccessTokenResponse_OIDCToken(t *testing.T) { assert.False(t, oidcToken.EmailVerified) user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 5}).(*user_model.User) - grants, err = auth.GetOAuth2GrantsByUserID(user.ID) + grants, err = auth.GetOAuth2GrantsByUserID(db.DefaultContext, user.ID) assert.NoError(t, err) assert.Len(t, grants, 1) diff --git a/routers/web/auth/openid.go b/routers/web/auth/openid.go index 3012d8c5a..32ae91da4 100644 --- a/routers/web/auth/openid.go +++ b/routers/web/auth/openid.go @@ -217,7 +217,7 @@ func signInOpenIDVerify(ctx *context.Context) { } if u == nil && nickname != "" { - u, _ = user_model.GetUserByName(nickname) + u, _ = user_model.GetUserByName(ctx, nickname) if err != nil { if !user_model.IsErrUserNotExist(err) { ctx.RenderWithErr(err.Error(), tplSignInOpenID, &forms.SignInOpenIDForm{ @@ -307,7 +307,7 @@ func ConnectOpenIDPost(ctx *context.Context) { // add OpenID for the user userOID := &user_model.UserOpenID{UID: u.ID, URI: oid} - if err = user_model.AddUserOpenID(userOID); err != nil { + if err = user_model.AddUserOpenID(ctx, userOID); err != nil { if user_model.IsErrOpenIDAlreadyUsed(err) { ctx.RenderWithErr(ctx.Tr("form.openid_been_used", oid), tplConnectOID, &form) return @@ -434,7 +434,7 @@ func RegisterOpenIDPost(ctx *context.Context) { // add OpenID for the user userOID := &user_model.UserOpenID{UID: u.ID, URI: oid} - if err = user_model.AddUserOpenID(userOID); err != nil { + if err = user_model.AddUserOpenID(ctx, userOID); err != nil { if user_model.IsErrOpenIDAlreadyUsed(err) { ctx.RenderWithErr(ctx.Tr("form.openid_been_used", oid), tplSignUpOID, &form) return diff --git a/routers/web/org/org_labels.go b/routers/web/org/org_labels.go index d79ffc597..bfa9f162c 100644 --- a/routers/web/org/org_labels.go +++ b/routers/web/org/org_labels.go @@ -17,7 +17,7 @@ import ( // RetrieveLabels find all the labels of an organization func RetrieveLabels(ctx *context.Context) { - labels, err := models.GetLabelsByOrgID(ctx.Org.Organization.ID, ctx.FormString("sort"), db.ListOptions{}) + labels, err := models.GetLabelsByOrgID(ctx, ctx.Org.Organization.ID, ctx.FormString("sort"), db.ListOptions{}) if err != nil { ctx.ServerError("RetrieveLabels.GetLabels", err) return @@ -59,7 +59,7 @@ func NewLabel(ctx *context.Context) { // UpdateLabel update a label's name and color func UpdateLabel(ctx *context.Context) { form := web.GetForm(ctx).(*forms.CreateLabelForm) - l, err := models.GetLabelInOrgByID(ctx.Org.Organization.ID, form.ID) + l, err := models.GetLabelInOrgByID(ctx, ctx.Org.Organization.ID, form.ID) if err != nil { switch { case models.IsErrOrgLabelNotExist(err): diff --git a/routers/web/org/setting.go b/routers/web/org/setting.go index 5cd245ef0..758ca47af 100644 --- a/routers/web/org/setting.go +++ b/routers/web/org/setting.go @@ -66,7 +66,7 @@ func SettingsPost(ctx *context.Context) { // Check if organization name has been changed. if org.LowerName != strings.ToLower(form.Name) { - isExist, err := user_model.IsUserExist(org.ID, form.Name) + isExist, err := user_model.IsUserExist(ctx, org.ID, form.Name) if err != nil { ctx.ServerError("IsUserExist", err) return @@ -110,7 +110,7 @@ func SettingsPost(ctx *context.Context) { visibilityChanged := form.Visibility != org.Visibility org.Visibility = form.Visibility - if err := user_model.UpdateUser(org.AsUser(), false); err != nil { + if err := user_model.UpdateUser(ctx, org.AsUser(), false); err != nil { ctx.ServerError("UpdateUser", err) return } @@ -207,7 +207,7 @@ func Webhooks(ctx *context.Context) { ctx.Data["BaseLinkNew"] = ctx.Org.OrgLink + "/settings/hooks" ctx.Data["Description"] = ctx.Tr("org.settings.hooks_desc") - ws, err := webhook.ListWebhooksByOpts(&webhook.ListWebhookOptions{OrgID: ctx.Org.Organization.ID}) + ws, err := webhook.ListWebhooksByOpts(ctx, &webhook.ListWebhookOptions{OrgID: ctx.Org.Organization.ID}) if err != nil { ctx.ServerError("GetWebhooksByOrgId", err) return diff --git a/routers/web/org/teams.go b/routers/web/org/teams.go index 3689ffe93..284fb096f 100644 --- a/routers/web/org/teams.go +++ b/routers/web/org/teams.go @@ -122,7 +122,7 @@ func TeamsAction(ctx *context.Context) { } uname := utils.RemoveUsernameParameterSuffix(strings.ToLower(ctx.FormString("uname"))) var u *user_model.User - u, err = user_model.GetUserByName(uname) + u, err = user_model.GetUserByName(ctx, uname) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.Flash.Error(ctx.Tr("form.user_not_exist")) diff --git a/routers/web/repo/attachment.go b/routers/web/repo/attachment.go index b05d6448d..701236f83 100644 --- a/routers/web/repo/attachment.go +++ b/routers/web/repo/attachment.go @@ -64,7 +64,7 @@ func uploadAttachment(ctx *context.Context, repoID int64, allowedTypes string) { // DeleteAttachment response for deleting issue's attachment func DeleteAttachment(ctx *context.Context) { file := ctx.FormString("file") - attach, err := repo_model.GetAttachmentByUUID(file) + attach, err := repo_model.GetAttachmentByUUID(ctx, file) if err != nil { ctx.Error(http.StatusBadRequest, err.Error()) return @@ -85,7 +85,7 @@ func DeleteAttachment(ctx *context.Context) { // GetAttachment serve attachements func GetAttachment(ctx *context.Context) { - attach, err := repo_model.GetAttachmentByUUID(ctx.Params(":uuid")) + attach, err := repo_model.GetAttachmentByUUID(ctx, ctx.Params(":uuid")) if err != nil { if repo_model.IsErrAttachmentNotExist(err) { ctx.Error(http.StatusNotFound) diff --git a/routers/web/repo/commit.go b/routers/web/repo/commit.go index 7f68fd3dd..1636a6c29 100644 --- a/routers/web/repo/commit.go +++ b/routers/web/repo/commit.go @@ -335,7 +335,7 @@ func Diff(ctx *context.Context) { ctx.Data["Commit"] = commit ctx.Data["Diff"] = diff - statuses, _, err := models.GetLatestCommitStatus(ctx.Repo.Repository.ID, commitID, db.ListOptions{}) + statuses, _, err := models.GetLatestCommitStatus(ctx, ctx.Repo.Repository.ID, commitID, db.ListOptions{}) if err != nil { log.Error("GetLatestCommitStatus: %v", err) } diff --git a/routers/web/repo/compare.go b/routers/web/repo/compare.go index 3ea888454..9f9448923 100644 --- a/routers/web/repo/compare.go +++ b/routers/web/repo/compare.go @@ -254,7 +254,7 @@ func ParseCompareInfo(ctx *context.Context) *CompareInfo { } else if len(headInfos) == 2 { headInfosSplit := strings.Split(headInfos[0], "/") if len(headInfosSplit) == 1 { - ci.HeadUser, err = user_model.GetUserByName(headInfos[0]) + ci.HeadUser, err = user_model.GetUserByName(ctx, headInfos[0]) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.NotFound("GetUserByName", nil) diff --git a/routers/web/repo/issue.go b/routers/web/repo/issue.go index f50b30da1..079ccbf6c 100644 --- a/routers/web/repo/issue.go +++ b/routers/web/repo/issue.go @@ -255,7 +255,7 @@ func issues(ctx *context.Context, milestoneID, projectID int64, isPullOption uti } issueList := models.IssueList(issues) - approvalCounts, err := issueList.GetApprovalCounts() + approvalCounts, err := issueList.GetApprovalCounts(ctx) if err != nil { ctx.ServerError("ApprovalCounts", err) return @@ -294,14 +294,14 @@ func issues(ctx *context.Context, milestoneID, projectID int64, isPullOption uti return } - labels, err := models.GetLabelsByRepoID(repo.ID, "", db.ListOptions{}) + labels, err := models.GetLabelsByRepoID(ctx, repo.ID, "", db.ListOptions{}) if err != nil { ctx.ServerError("GetLabelsByRepoID", err) return } if repo.Owner.IsOrganization() { - orgLabels, err := models.GetLabelsByOrgID(repo.Owner.ID, ctx.FormString("sort"), db.ListOptions{}) + orgLabels, err := models.GetLabelsByOrgID(ctx, repo.Owner.ID, ctx.FormString("sort"), db.ListOptions{}) if err != nil { ctx.ServerError("GetLabelsByOrgID", err) return @@ -343,7 +343,7 @@ func issues(ctx *context.Context, milestoneID, projectID int64, isPullOption uti } if ctx.Repo.CanWriteIssuesOrPulls(ctx.Params(":type") == "pulls") { - projects, _, err := project_model.GetProjects(project_model.SearchOptions{ + projects, _, err := project_model.GetProjects(ctx, project_model.SearchOptions{ RepoID: repo.ID, Type: project_model.TypeRepository, IsClosed: util.OptionalBoolOf(isShowClosed), @@ -453,7 +453,7 @@ func RetrieveRepoMilestonesAndAssignees(ctx *context.Context, repo *repo_model.R func retrieveProjects(ctx *context.Context, repo *repo_model.Repository) { var err error - ctx.Data["OpenProjects"], _, err = project_model.GetProjects(project_model.SearchOptions{ + ctx.Data["OpenProjects"], _, err = project_model.GetProjects(ctx, project_model.SearchOptions{ RepoID: repo.ID, Page: -1, IsClosed: util.OptionalBoolFalse, @@ -464,7 +464,7 @@ func retrieveProjects(ctx *context.Context, repo *repo_model.Repository) { return } - ctx.Data["ClosedProjects"], _, err = project_model.GetProjects(project_model.SearchOptions{ + ctx.Data["ClosedProjects"], _, err = project_model.GetProjects(ctx, project_model.SearchOptions{ RepoID: repo.ID, Page: -1, IsClosed: util.OptionalBoolTrue, @@ -673,14 +673,14 @@ func RetrieveRepoMetas(ctx *context.Context, repo *repo_model.Repository, isPull return nil } - labels, err := models.GetLabelsByRepoID(repo.ID, "", db.ListOptions{}) + labels, err := models.GetLabelsByRepoID(ctx, repo.ID, "", db.ListOptions{}) if err != nil { ctx.ServerError("GetLabelsByRepoID", err) return nil } ctx.Data["Labels"] = labels if repo.Owner.IsOrganization() { - orgLabels, err := models.GetLabelsByOrgID(repo.Owner.ID, ctx.FormString("sort"), db.ListOptions{}) + orgLabels, err := models.GetLabelsByOrgID(ctx, repo.Owner.ID, ctx.FormString("sort"), db.ListOptions{}) if err != nil { return nil } @@ -761,10 +761,10 @@ func setTemplateIfExists(ctx *context.Context, ctxDataKey string, possibleDirs, ctx.Data[issueTemplateTitleKey] = meta.Title ctx.Data[ctxDataKey] = templateBody labelIDs := make([]string, 0, len(meta.Labels)) - if repoLabels, err := models.GetLabelsByRepoID(ctx.Repo.Repository.ID, "", db.ListOptions{}); err == nil { + if repoLabels, err := models.GetLabelsByRepoID(ctx, ctx.Repo.Repository.ID, "", db.ListOptions{}); err == nil { ctx.Data["Labels"] = repoLabels if ctx.Repo.Owner.IsOrganization() { - if orgLabels, err := models.GetLabelsByOrgID(ctx.Repo.Owner.ID, ctx.FormString("sort"), db.ListOptions{}); err == nil { + if orgLabels, err := models.GetLabelsByOrgID(ctx, ctx.Repo.Owner.ID, ctx.FormString("sort"), db.ListOptions{}); err == nil { ctx.Data["OrgLabels"] = orgLabels repoLabels = append(repoLabels, orgLabels...) } @@ -818,7 +818,7 @@ func NewIssue(ctx *context.Context) { projectID := ctx.FormInt64("project") if projectID > 0 { - project, err := project_model.GetProjectByID(projectID) + project, err := project_model.GetProjectByID(ctx, projectID) if err != nil { log.Error("GetProjectByID: %d: %v", projectID, err) } else if project.RepoID != ctx.Repo.Repository.ID { @@ -930,7 +930,7 @@ func ValidateRepoMetas(ctx *context.Context, form forms.CreateIssueForm, isPull } if form.ProjectID > 0 { - p, err := project_model.GetProjectByID(form.ProjectID) + p, err := project_model.GetProjectByID(ctx, form.ProjectID) if err != nil { ctx.ServerError("GetProjectByID", err) return nil, nil, 0, 0 @@ -1237,7 +1237,7 @@ func ViewIssue(ctx *context.Context) { for i := range issue.Labels { labelIDMark[issue.Labels[i].ID] = true } - labels, err := models.GetLabelsByRepoID(repo.ID, "", db.ListOptions{}) + labels, err := models.GetLabelsByRepoID(ctx, repo.ID, "", db.ListOptions{}) if err != nil { ctx.ServerError("GetLabelsByRepoID", err) return @@ -1245,7 +1245,7 @@ func ViewIssue(ctx *context.Context) { ctx.Data["Labels"] = labels if repo.Owner.IsOrganization() { - orgLabels, err := models.GetLabelsByOrgID(repo.Owner.ID, ctx.FormString("sort"), db.ListOptions{}) + orgLabels, err := models.GetLabelsByOrgID(ctx, repo.Owner.ID, ctx.FormString("sort"), db.ListOptions{}) if err != nil { ctx.ServerError("GetLabelsByOrgID", err) return @@ -1277,7 +1277,7 @@ func ViewIssue(ctx *context.Context) { if issue.IsPull { canChooseReviewer := ctx.Repo.CanWrite(unit.TypePullRequests) if !canChooseReviewer && ctx.Doer != nil && ctx.IsSigned { - canChooseReviewer, err = models.IsOfficialReviewer(issue, ctx.Doer) + canChooseReviewer, err = models.IsOfficialReviewer(ctx, issue, ctx.Doer) if err != nil { ctx.ServerError("IsOfficialReviewer", err) return @@ -1312,7 +1312,7 @@ func ViewIssue(ctx *context.Context) { if !ctx.Data["IsStopwatchRunning"].(bool) { var exists bool var sw *models.Stopwatch - if exists, sw, err = models.HasUserStopwatch(ctx.Doer.ID); err != nil { + if exists, sw, err = models.HasUserStopwatch(ctx, ctx.Doer.ID); err != nil { ctx.ServerError("HasUserStopwatch", err) return } @@ -1688,12 +1688,12 @@ func ViewIssue(ctx *context.Context) { } // Get Dependencies - ctx.Data["BlockedByDependencies"], err = issue.BlockedByDependencies() + ctx.Data["BlockedByDependencies"], err = issue.BlockedByDependencies(ctx) if err != nil { ctx.ServerError("BlockedByDependencies", err) return } - ctx.Data["BlockingDependencies"], err = issue.BlockingDependencies() + ctx.Data["BlockingDependencies"], err = issue.BlockingDependencies(ctx) if err != nil { ctx.ServerError("BlockingDependencies", err) return @@ -1767,7 +1767,7 @@ func getActionIssues(ctx *context.Context) []*models.Issue { } issueIDs = append(issueIDs, issueID) } - issues, err := models.GetIssuesByIDs(issueIDs) + issues, err := models.GetIssuesByIDs(ctx, issueIDs) if err != nil { ctx.ServerError("GetIssuesByIDs", err) return nil @@ -1873,7 +1873,7 @@ func UpdateIssueContent(ctx *context.Context) { // when update the request doesn't intend to update attachments (eg: change checkbox state), ignore attachment updates if !ctx.FormBool("ignore_attachments") { - if err := updateAttachments(issue, ctx.FormStrings("files[]")); err != nil { + if err := updateAttachments(ctx, issue, ctx.FormStrings("files[]")); err != nil { ctx.ServerError("UpdateAttachments", err) return } @@ -2047,7 +2047,7 @@ func UpdatePullReviewRequest(ctx *context.Context) { return } - team, err := organization.GetTeamByID(-reviewID) + team, err := organization.GetTeamByID(ctx, -reviewID) if err != nil { ctx.ServerError("GetTeamByID", err) return @@ -2160,7 +2160,7 @@ func SearchIssues(ctx *context.Context) { opts.AllLimited = true } if ctx.FormString("owner") != "" { - owner, err := user_model.GetUserByName(ctx.FormString("owner")) + owner, err := user_model.GetUserByName(ctx, ctx.FormString("owner")) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.Error(http.StatusBadRequest, "Owner not found", err.Error()) @@ -2179,7 +2179,7 @@ func SearchIssues(ctx *context.Context) { ctx.Error(http.StatusBadRequest, "", "Owner organisation is required for filtering on team") return } - team, err := organization.GetTeam(opts.OwnerID, ctx.FormString("team")) + team, err := organization.GetTeam(ctx, opts.OwnerID, ctx.FormString("team")) if err != nil { if organization.IsErrTeamNotExist(err) { ctx.Error(http.StatusBadRequest, "Team not found", err.Error()) @@ -2307,7 +2307,7 @@ func getUserIDForFilter(ctx *context.Context, queryName string) int64 { return 0 } - user, err := user_model.GetUserByName(userName) + user, err := user_model.GetUserByName(ctx, userName) if user_model.IsErrUserNotExist(err) { ctx.NotFound("", err) return 0 @@ -2631,7 +2631,7 @@ func NewComment(ctx *context.Context) { // UpdateCommentContent change comment of issue's content func UpdateCommentContent(ctx *context.Context) { - comment, err := models.GetCommentByID(ctx.ParamsInt64(":id")) + comment, err := models.GetCommentByID(ctx, ctx.ParamsInt64(":id")) if err != nil { ctx.NotFoundOrServerError("GetCommentByID", models.IsErrCommentNotExist, err) return @@ -2672,7 +2672,7 @@ func UpdateCommentContent(ctx *context.Context) { // when the update request doesn't intend to update attachments (eg: change checkbox state), ignore attachment updates if !ctx.FormBool("ignore_attachments") { - if err := updateAttachments(comment, ctx.FormStrings("files[]")); err != nil { + if err := updateAttachments(ctx, comment, ctx.FormStrings("files[]")); err != nil { ctx.ServerError("UpdateAttachments", err) return } @@ -2697,7 +2697,7 @@ func UpdateCommentContent(ctx *context.Context) { // DeleteComment delete comment of issue func DeleteComment(ctx *context.Context) { - comment, err := models.GetCommentByID(ctx.ParamsInt64(":id")) + comment, err := models.GetCommentByID(ctx, ctx.ParamsInt64(":id")) if err != nil { ctx.NotFoundOrServerError("GetCommentByID", models.IsErrCommentNotExist, err) return @@ -2823,7 +2823,7 @@ func ChangeIssueReaction(ctx *context.Context) { // ChangeCommentReaction create a reaction for comment func ChangeCommentReaction(ctx *context.Context) { form := web.GetForm(ctx).(*forms.ReactionForm) - comment, err := models.GetCommentByID(ctx.ParamsInt64(":id")) + comment, err := models.GetCommentByID(ctx, ctx.ParamsInt64(":id")) if err != nil { ctx.NotFoundOrServerError("GetCommentByID", models.IsErrCommentNotExist, err) return @@ -2968,7 +2968,7 @@ func GetIssueAttachments(ctx *context.Context) { // GetCommentAttachments returns attachments for the comment func GetCommentAttachments(ctx *context.Context) { - comment, err := models.GetCommentByID(ctx.ParamsInt64(":id")) + comment, err := models.GetCommentByID(ctx, ctx.ParamsInt64(":id")) if err != nil { ctx.NotFoundOrServerError("GetCommentByID", models.IsErrCommentNotExist, err) return @@ -2986,7 +2986,7 @@ func GetCommentAttachments(ctx *context.Context) { ctx.JSON(http.StatusOK, attachments) } -func updateAttachments(item interface{}, files []string) error { +func updateAttachments(ctx *context.Context, item interface{}, files []string) error { var attachments []*repo_model.Attachment switch content := item.(type) { case *models.Issue: @@ -3020,9 +3020,9 @@ func updateAttachments(item interface{}, files []string) error { } switch content := item.(type) { case *models.Issue: - content.Attachments, err = repo_model.GetAttachmentsByIssueID(content.ID) + content.Attachments, err = repo_model.GetAttachmentsByIssueID(ctx, content.ID) case *models.Comment: - content.Attachments, err = repo_model.GetAttachmentsByCommentID(content.ID) + content.Attachments, err = repo_model.GetAttachmentsByCommentID(ctx, content.ID) default: return fmt.Errorf("unknown Type: %T", content) } diff --git a/routers/web/repo/issue_content_history.go b/routers/web/repo/issue_content_history.go index 11cc8a2a6..407832dff 100644 --- a/routers/web/repo/issue_content_history.go +++ b/routers/web/repo/issue_content_history.go @@ -130,7 +130,7 @@ func GetContentHistoryDetail(ctx *context.Context) { var comment *models.Comment if history.CommentID != 0 { var err error - if comment, err = models.GetCommentByID(history.CommentID); err != nil { + if comment, err = models.GetCommentByID(ctx, history.CommentID); err != nil { log.Error("can not get comment for issue content history %v. err=%v", historyID, err) return } @@ -190,7 +190,7 @@ func SoftDeleteContentHistory(ctx *context.Context) { var history *issuesModel.ContentHistory var err error if commentID != 0 { - if comment, err = models.GetCommentByID(commentID); err != nil { + if comment, err = models.GetCommentByID(ctx, commentID); err != nil { log.Error("can not get comment for issue content history %v. err=%v", historyID, err) return } diff --git a/routers/web/repo/issue_label.go b/routers/web/repo/issue_label.go index 887bbc115..2e72d659b 100644 --- a/routers/web/repo/issue_label.go +++ b/routers/web/repo/issue_label.go @@ -56,7 +56,7 @@ func InitializeLabels(ctx *context.Context) { // RetrieveLabels find all the labels of a repository and organization func RetrieveLabels(ctx *context.Context) { - labels, err := models.GetLabelsByRepoID(ctx.Repo.Repository.ID, ctx.FormString("sort"), db.ListOptions{}) + labels, err := models.GetLabelsByRepoID(ctx, ctx.Repo.Repository.ID, ctx.FormString("sort"), db.ListOptions{}) if err != nil { ctx.ServerError("RetrieveLabels.GetLabels", err) return @@ -69,7 +69,7 @@ func RetrieveLabels(ctx *context.Context) { ctx.Data["Labels"] = labels if ctx.Repo.Owner.IsOrganization() { - orgLabels, err := models.GetLabelsByOrgID(ctx.Repo.Owner.ID, ctx.FormString("sort"), db.ListOptions{}) + orgLabels, err := models.GetLabelsByOrgID(ctx, ctx.Repo.Owner.ID, ctx.FormString("sort"), db.ListOptions{}) if err != nil { ctx.ServerError("GetLabelsByOrgID", err) return @@ -127,7 +127,7 @@ func NewLabel(ctx *context.Context) { // UpdateLabel update a label's name and color func UpdateLabel(ctx *context.Context) { form := web.GetForm(ctx).(*forms.CreateLabelForm) - l, err := models.GetLabelInRepoByID(ctx.Repo.Repository.ID, form.ID) + l, err := models.GetLabelInRepoByID(ctx, ctx.Repo.Repository.ID, form.ID) if err != nil { switch { case models.IsErrRepoLabelNotExist(err): @@ -177,7 +177,7 @@ func UpdateIssueLabel(ctx *context.Context) { } } case "attach", "detach", "toggle": - label, err := models.GetLabelByID(ctx.FormInt64("id")) + label, err := models.GetLabelByID(ctx, ctx.FormInt64("id")) if err != nil { if models.IsErrRepoLabelNotExist(err) { ctx.Error(http.StatusNotFound, "GetLabelByID") @@ -191,7 +191,7 @@ func UpdateIssueLabel(ctx *context.Context) { // detach if any issues already have label, otherwise attach action = "attach" for _, issue := range issues { - if models.HasIssueLabel(issue.ID, label.ID) { + if models.HasIssueLabel(ctx, issue.ID, label.ID) { action = "detach" break } diff --git a/routers/web/repo/issue_stopwatch.go b/routers/web/repo/issue_stopwatch.go index 83e4ecedb..4e1f6af03 100644 --- a/routers/web/repo/issue_stopwatch.go +++ b/routers/web/repo/issue_stopwatch.go @@ -87,7 +87,7 @@ func GetActiveStopwatch(ctx *context.Context) { return } - _, sw, err := models.HasUserStopwatch(ctx.Doer.ID) + _, sw, err := models.HasUserStopwatch(ctx, ctx.Doer.ID) if err != nil { ctx.ServerError("HasUserStopwatch", err) return diff --git a/routers/web/repo/projects.go b/routers/web/repo/projects.go index a6f843d84..c1805944d 100644 --- a/routers/web/repo/projects.go +++ b/routers/web/repo/projects.go @@ -70,7 +70,7 @@ func Projects(ctx *context.Context) { total = repo.NumClosedProjects } - projects, count, err := project_model.GetProjects(project_model.SearchOptions{ + projects, count, err := project_model.GetProjects(ctx, project_model.SearchOptions{ RepoID: repo.ID, Page: page, IsClosed: util.OptionalBoolOf(isShowClosed), @@ -182,7 +182,7 @@ func ChangeProjectStatus(ctx *context.Context) { // DeleteProject delete a project func DeleteProject(ctx *context.Context) { - p, err := project_model.GetProjectByID(ctx.ParamsInt64(":id")) + p, err := project_model.GetProjectByID(ctx, ctx.ParamsInt64(":id")) if err != nil { if project_model.IsErrProjectNotExist(err) { ctx.NotFound("", nil) @@ -213,7 +213,7 @@ func EditProject(ctx *context.Context) { ctx.Data["PageIsEditProjects"] = true ctx.Data["CanWriteProjects"] = ctx.Repo.Permission.CanWrite(unit.TypeProjects) - p, err := project_model.GetProjectByID(ctx.ParamsInt64(":id")) + p, err := project_model.GetProjectByID(ctx, ctx.ParamsInt64(":id")) if err != nil { if project_model.IsErrProjectNotExist(err) { ctx.NotFound("", nil) @@ -245,7 +245,7 @@ func EditProjectPost(ctx *context.Context) { return } - p, err := project_model.GetProjectByID(ctx.ParamsInt64(":id")) + p, err := project_model.GetProjectByID(ctx, ctx.ParamsInt64(":id")) if err != nil { if project_model.IsErrProjectNotExist(err) { ctx.NotFound("", nil) @@ -261,7 +261,7 @@ func EditProjectPost(ctx *context.Context) { p.Title = form.Title p.Description = form.Content - if err = project_model.UpdateProject(p); err != nil { + if err = project_model.UpdateProject(ctx, p); err != nil { ctx.ServerError("UpdateProjects", err) return } @@ -272,7 +272,7 @@ func EditProjectPost(ctx *context.Context) { // ViewProject renders the project board for a project func ViewProject(ctx *context.Context) { - project, err := project_model.GetProjectByID(ctx.ParamsInt64(":id")) + project, err := project_model.GetProjectByID(ctx, ctx.ParamsInt64(":id")) if err != nil { if project_model.IsErrProjectNotExist(err) { ctx.NotFound("", nil) @@ -286,7 +286,7 @@ func ViewProject(ctx *context.Context) { return } - boards, err := project_model.GetBoards(project.ID) + boards, err := project_model.GetBoards(ctx, project.ID) if err != nil { ctx.ServerError("GetProjectBoards", err) return @@ -385,7 +385,7 @@ func DeleteProjectBoard(ctx *context.Context) { return } - project, err := project_model.GetProjectByID(ctx.ParamsInt64(":id")) + project, err := project_model.GetProjectByID(ctx, ctx.ParamsInt64(":id")) if err != nil { if project_model.IsErrProjectNotExist(err) { ctx.NotFound("", nil) @@ -395,7 +395,7 @@ func DeleteProjectBoard(ctx *context.Context) { return } - pb, err := project_model.GetBoard(ctx.ParamsInt64(":boardID")) + pb, err := project_model.GetBoard(ctx, ctx.ParamsInt64(":boardID")) if err != nil { ctx.ServerError("GetProjectBoard", err) return @@ -434,7 +434,7 @@ func AddBoardToProjectPost(ctx *context.Context) { return } - project, err := project_model.GetProjectByID(ctx.ParamsInt64(":id")) + project, err := project_model.GetProjectByID(ctx, ctx.ParamsInt64(":id")) if err != nil { if project_model.IsErrProjectNotExist(err) { ctx.NotFound("", nil) @@ -474,7 +474,7 @@ func checkProjectBoardChangePermissions(ctx *context.Context) (*project_model.Pr return nil, nil } - project, err := project_model.GetProjectByID(ctx.ParamsInt64(":id")) + project, err := project_model.GetProjectByID(ctx, ctx.ParamsInt64(":id")) if err != nil { if project_model.IsErrProjectNotExist(err) { ctx.NotFound("", nil) @@ -484,7 +484,7 @@ func checkProjectBoardChangePermissions(ctx *context.Context) (*project_model.Pr return nil, nil } - board, err := project_model.GetBoard(ctx.ParamsInt64(":boardID")) + board, err := project_model.GetBoard(ctx, ctx.ParamsInt64(":boardID")) if err != nil { ctx.ServerError("GetProjectBoard", err) return nil, nil @@ -523,7 +523,7 @@ func EditProjectBoard(ctx *context.Context) { board.Sorting = form.Sorting } - if err := project_model.UpdateBoard(board); err != nil { + if err := project_model.UpdateBoard(ctx, board); err != nil { ctx.ServerError("UpdateProjectBoard", err) return } @@ -566,7 +566,7 @@ func MoveIssues(ctx *context.Context) { return } - project, err := project_model.GetProjectByID(ctx.ParamsInt64(":id")) + project, err := project_model.GetProjectByID(ctx, ctx.ParamsInt64(":id")) if err != nil { if project_model.IsErrProjectNotExist(err) { ctx.NotFound("ProjectNotExist", nil) @@ -589,7 +589,7 @@ func MoveIssues(ctx *context.Context) { Title: ctx.Tr("repo.projects.type.uncategorized"), } } else { - board, err = project_model.GetBoard(ctx.ParamsInt64(":boardID")) + board, err = project_model.GetBoard(ctx, ctx.ParamsInt64(":boardID")) if err != nil { if project_model.IsErrProjectBoardNotExist(err) { ctx.NotFound("ProjectBoardNotExist", nil) @@ -622,7 +622,7 @@ func MoveIssues(ctx *context.Context) { issueIDs = append(issueIDs, issue.IssueID) sortedIssueIDs[issue.Sorting] = issue.IssueID } - movedIssues, err := models.GetIssuesByIDs(issueIDs) + movedIssues, err := models.GetIssuesByIDs(ctx, issueIDs) if err != nil { if models.IsErrIssueNotExist(err) { ctx.NotFound("IssueNotExisting", nil) diff --git a/routers/web/repo/pull.go b/routers/web/repo/pull.go index fd224a22e..3f24be33d 100644 --- a/routers/web/repo/pull.go +++ b/routers/web/repo/pull.go @@ -377,7 +377,7 @@ func PrepareMergedViewPullInfo(ctx *context.Context, issue *models.Issue) *git.C if len(compareInfo.Commits) != 0 { sha := compareInfo.Commits[0].ID.String() - commitStatuses, _, err := models.GetLatestCommitStatus(ctx.Repo.Repository.ID, sha, db.ListOptions{}) + commitStatuses, _, err := models.GetLatestCommitStatus(ctx, ctx.Repo.Repository.ID, sha, db.ListOptions{}) if err != nil { ctx.ServerError("GetLatestCommitStatus", err) return nil @@ -438,7 +438,7 @@ func PrepareViewPullInfo(ctx *context.Context, issue *models.Issue) *git.Compare ctx.ServerError(fmt.Sprintf("GetRefCommitID(%s)", pull.GetGitRefName()), err) return nil } - commitStatuses, _, err := models.GetLatestCommitStatus(repo.ID, sha, db.ListOptions{}) + commitStatuses, _, err := models.GetLatestCommitStatus(ctx, repo.ID, sha, db.ListOptions{}) if err != nil { ctx.ServerError("GetLatestCommitStatus", err) return nil @@ -528,7 +528,7 @@ func PrepareViewPullInfo(ctx *context.Context, issue *models.Issue) *git.Compare return nil } - commitStatuses, _, err := models.GetLatestCommitStatus(repo.ID, sha, db.ListOptions{}) + commitStatuses, _, err := models.GetLatestCommitStatus(ctx, repo.ID, sha, db.ListOptions{}) if err != nil { ctx.ServerError("GetLatestCommitStatus", err) return nil @@ -767,7 +767,7 @@ func ViewPullFiles(ctx *context.Context) { return } - currentReview, err := models.GetCurrentReview(ctx.Doer, issue) + currentReview, err := models.GetCurrentReview(ctx, ctx.Doer, issue) if err != nil && !models.IsErrReviewNotExist(err) { ctx.ServerError("GetCurrentReview", err) return @@ -1354,7 +1354,7 @@ func DownloadPullPatch(ctx *context.Context) { // DownloadPullDiffOrPatch render a pull's raw diff or patch func DownloadPullDiffOrPatch(ctx *context.Context, patch bool) { - pr, err := models.GetPullRequestByIndex(ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) + pr, err := models.GetPullRequestByIndex(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) if err != nil { if models.IsErrPullRequestNotExist(err) { ctx.NotFound("GetPullRequestByIndex", err) @@ -1447,7 +1447,7 @@ func UpdatePullRequestTarget(ctx *context.Context) { func SetAllowEdits(ctx *context.Context) { form := web.GetForm(ctx).(*forms.UpdateAllowEditsForm) - pr, err := models.GetPullRequestByIndex(ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) + pr, err := models.GetPullRequestByIndex(ctx, ctx.Repo.Repository.ID, ctx.ParamsInt64(":index")) if err != nil { if models.IsErrPullRequestNotExist(err) { ctx.NotFound("GetPullRequestByIndex", err) diff --git a/routers/web/repo/pull_review.go b/routers/web/repo/pull_review.go index 98272ed48..e05129020 100644 --- a/routers/web/repo/pull_review.go +++ b/routers/web/repo/pull_review.go @@ -31,7 +31,7 @@ func RenderNewCodeCommentForm(ctx *context.Context) { if !issue.IsPull { return } - currentReview, err := models.GetCurrentReview(ctx.Doer, issue) + currentReview, err := models.GetCurrentReview(ctx, ctx.Doer, issue) if err != nil && !models.IsErrReviewNotExist(err) { ctx.ServerError("GetCurrentReview", err) return @@ -107,7 +107,7 @@ func UpdateResolveConversation(ctx *context.Context) { action := ctx.FormString("action") commentID := ctx.FormInt64("comment_id") - comment, err := models.GetCommentByID(commentID) + comment, err := models.GetCommentByID(ctx, commentID) if err != nil { ctx.ServerError("GetIssueByID", err) return diff --git a/routers/web/repo/release.go b/routers/web/repo/release.go index ebc650080..fba3ef7a0 100644 --- a/routers/web/repo/release.go +++ b/routers/web/repo/release.go @@ -126,7 +126,7 @@ func releasesOrTags(ctx *context.Context, isTagList bool) { return } - if err = models.GetReleaseAttachments(releases...); err != nil { + if err = models.GetReleaseAttachments(ctx, releases...); err != nil { ctx.ServerError("GetReleaseAttachments", err) return } @@ -202,7 +202,7 @@ func SingleRelease(ctx *context.Context) { return } - err = models.GetReleaseAttachments(release) + err = models.GetReleaseAttachments(ctx, release) if err != nil { ctx.ServerError("GetReleaseAttachments", err) return diff --git a/routers/web/repo/repo.go b/routers/web/repo/repo.go index 199651b2f..30cb888dc 100644 --- a/routers/web/repo/repo.go +++ b/routers/web/repo/repo.go @@ -285,9 +285,9 @@ func Action(ctx *context.Context) { var err error switch ctx.Params(":action") { case "watch": - err = repo_model.WatchRepo(ctx.Doer.ID, ctx.Repo.Repository.ID, true) + err = repo_model.WatchRepo(ctx, ctx.Doer.ID, ctx.Repo.Repository.ID, true) case "unwatch": - err = repo_model.WatchRepo(ctx.Doer.ID, ctx.Repo.Repository.ID, false) + err = repo_model.WatchRepo(ctx, ctx.Doer.ID, ctx.Repo.Repository.ID, false) case "star": err = repo_model.StarRepo(ctx.Doer.ID, ctx.Repo.Repository.ID, true) case "unstar": @@ -369,7 +369,7 @@ func RedirectDownload(ctx *context.Context) { } if len(releases) == 1 { release := releases[0] - att, err := repo_model.GetAttachmentByReleaseIDFileName(release.ID, fileName) + att, err := repo_model.GetAttachmentByReleaseIDFileName(ctx, release.ID, fileName) if err != nil { ctx.Error(http.StatusNotFound) return diff --git a/routers/web/repo/setting.go b/routers/web/repo/setting.go index dd1cb412c..a60bf5262 100644 --- a/routers/web/repo/setting.go +++ b/routers/web/repo/setting.go @@ -72,14 +72,14 @@ func Settings(ctx *context.Context) { ctx.Data["CodeIndexerEnabled"] = setting.Indexer.RepoIndexerEnabled if ctx.Doer.IsAdmin { if setting.Indexer.RepoIndexerEnabled { - status, err := repo_model.GetIndexerStatus(ctx.Repo.Repository, repo_model.RepoIndexerTypeCode) + status, err := repo_model.GetIndexerStatus(ctx, ctx.Repo.Repository, repo_model.RepoIndexerTypeCode) if err != nil { ctx.ServerError("repo.indexer_status", err) return } ctx.Data["CodeIndexerStatus"] = status } - status, err := repo_model.GetIndexerStatus(ctx.Repo.Repository, repo_model.RepoIndexerTypeStats) + status, err := repo_model.GetIndexerStatus(ctx, ctx.Repo.Repository, repo_model.RepoIndexerTypeStats) if err != nil { ctx.ServerError("repo.indexer_status", err) return @@ -195,7 +195,7 @@ func SettingsPost(ctx *context.Context) { ctx.Repo.Mirror.EnablePrune = form.EnablePrune ctx.Repo.Mirror.Interval = interval ctx.Repo.Mirror.ScheduleNextUpdate() - if err := repo_model.UpdateMirror(ctx.Repo.Mirror); err != nil { + if err := repo_model.UpdateMirror(ctx, ctx.Repo.Mirror); err != nil { ctx.Data["Err_Interval"] = true ctx.RenderWithErr(ctx.Tr("repo.mirror_interval_invalid"), tplSettingsOptions, &form) return @@ -241,7 +241,7 @@ func SettingsPost(ctx *context.Context) { ctx.Repo.Mirror.LFS = form.LFS ctx.Repo.Mirror.LFSEndpoint = form.LFSEndpoint - if err := repo_model.UpdateMirror(ctx.Repo.Mirror); err != nil { + if err := repo_model.UpdateMirror(ctx, ctx.Repo.Mirror); err != nil { ctx.ServerError("UpdateMirror", err) return } @@ -642,7 +642,7 @@ func SettingsPost(ctx *context.Context) { return } - newOwner, err := user_model.GetUserByName(ctx.FormString("new_owner_name")) + newOwner, err := user_model.GetUserByName(ctx, ctx.FormString("new_owner_name")) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.RenderWithErr(ctx.Tr("form.enterred_invalid_owner_name"), tplSettingsOptions, nil) @@ -840,7 +840,7 @@ func Collaboration(ctx *context.Context) { } ctx.Data["Collaborators"] = users - teams, err := organization.GetRepoTeams(ctx.Repo.Repository) + teams, err := organization.GetRepoTeams(ctx, ctx.Repo.Repository) if err != nil { ctx.ServerError("GetRepoTeams", err) return @@ -863,7 +863,7 @@ func CollaborationPost(ctx *context.Context) { return } - u, err := user_model.GetUserByName(name) + u, err := user_model.GetUserByName(ctx, name) if err != nil { if user_model.IsErrUserNotExist(err) { ctx.Flash.Error(ctx.Tr("form.user_not_exist")) @@ -983,7 +983,7 @@ func DeleteTeam(ctx *context.Context) { return } - team, err := organization.GetTeamByID(ctx.FormInt64("id")) + team, err := organization.GetTeamByID(ctx, ctx.FormInt64("id")) if err != nil { ctx.ServerError("GetTeamByID", err) return @@ -1215,6 +1215,7 @@ func selectPushMirrorByForm(form *forms.RepoSettingForm, repo *repo_model.Reposi for _, m := range pushMirrors { if m.ID == id { + m.Repo = repo return m, nil } } diff --git a/routers/web/repo/setting_protected_branch.go b/routers/web/repo/setting_protected_branch.go index 35f35163c..6c2f3ce01 100644 --- a/routers/web/repo/setting_protected_branch.go +++ b/routers/web/repo/setting_protected_branch.go @@ -111,7 +111,7 @@ func SettingsProtectedBranch(c *context.Context) { c.Data["Title"] = c.Tr("repo.settings.protected_branch") + " - " + branch c.Data["PageIsSettingsBranches"] = true - protectBranch, err := models.GetProtectedBranchBy(c.Repo.Repository.ID, branch) + protectBranch, err := models.GetProtectedBranchBy(c, c.Repo.Repository.ID, branch) if err != nil { if !git.IsErrBranchNotExist(err) { c.ServerError("GetProtectBranchOfRepoByName", err) @@ -184,7 +184,7 @@ func SettingsProtectedBranchPost(ctx *context.Context) { return } - protectBranch, err := models.GetProtectedBranchBy(ctx.Repo.Repository.ID, branch) + protectBranch, err := models.GetProtectedBranchBy(ctx, ctx.Repo.Repository.ID, branch) if err != nil { if !git.IsErrBranchNotExist(err) { ctx.ServerError("GetProtectBranchOfRepoByName", err) diff --git a/routers/web/repo/view.go b/routers/web/repo/view.go index 86fc36fad..95ca81c44 100644 --- a/routers/web/repo/view.go +++ b/routers/web/repo/view.go @@ -684,7 +684,7 @@ func checkHomeCodeViewable(ctx *context.Context) { if ctx.IsSigned { // Set repo notification-status read if unread - if err := models.SetRepoReadBy(ctx.Repo.Repository.ID, ctx.Doer.ID); err != nil { + if err := models.SetRepoReadBy(ctx, ctx.Repo.Repository.ID, ctx.Doer.ID); err != nil { ctx.ServerError("ReadBy", err) return } @@ -839,7 +839,7 @@ func renderDirectoryFiles(ctx *context.Context, timeout time.Duration) git.Entri ctx.Data["LatestCommitUser"] = user_model.ValidateCommitWithEmail(latestCommit) } - statuses, _, err := models.GetLatestCommitStatus(ctx.Repo.Repository.ID, ctx.Repo.Commit.ID.String(), db.ListOptions{}) + statuses, _, err := models.GetLatestCommitStatus(ctx, ctx.Repo.Repository.ID, ctx.Repo.Commit.ID.String(), db.ListOptions{}) if err != nil { log.Error("GetLatestCommitStatus: %v", err) } @@ -901,7 +901,7 @@ func renderCode(ctx *context.Context) { // it's possible for a repository to be non-empty by that flag but still 500 // because there are no branches - only tags -or the default branch is non-extant as it has been 0-pushed. ctx.Repo.Repository.IsEmpty = false - if err = repo_model.UpdateRepositoryCols(ctx.Repo.Repository, "is_empty"); err != nil { + if err = repo_model.UpdateRepositoryCols(ctx, ctx.Repo.Repository, "is_empty"); err != nil { ctx.ServerError("UpdateRepositoryCols", err) return } diff --git a/routers/web/repo/webhook.go b/routers/web/repo/webhook.go index d2e246118..a9b14ee21 100644 --- a/routers/web/repo/webhook.go +++ b/routers/web/repo/webhook.go @@ -44,7 +44,7 @@ func Webhooks(ctx *context.Context) { ctx.Data["BaseLinkNew"] = ctx.Repo.RepoLink + "/settings/hooks" ctx.Data["Description"] = ctx.Tr("repo.settings.hooks_desc", "https://docs.gitea.io/en-us/webhooks/") - ws, err := webhook.ListWebhooksByOpts(&webhook.ListWebhookOptions{RepoID: ctx.Repo.Repository.ID}) + ws, err := webhook.ListWebhooksByOpts(ctx, &webhook.ListWebhookOptions{RepoID: ctx.Repo.Repository.ID}) if err != nil { ctx.ServerError("GetWebhooksByRepoID", err) return diff --git a/routers/web/user/avatar.go b/routers/web/user/avatar.go index c8bca9dc2..53a603fab 100644 --- a/routers/web/user/avatar.go +++ b/routers/web/user/avatar.go @@ -30,7 +30,7 @@ func AvatarByUserName(ctx *context.Context) { var user *user_model.User if strings.ToLower(userName) != "ghost" { var err error - if user, err = user_model.GetUserByName(userName); err != nil { + if user, err = user_model.GetUserByName(ctx, userName); err != nil { ctx.ServerError("Invalid user: "+userName, err) return } diff --git a/routers/web/user/home.go b/routers/web/user/home.go index 37f6b8835..2a802053f 100644 --- a/routers/web/user/home.go +++ b/routers/web/user/home.go @@ -615,7 +615,7 @@ func buildIssueOverview(ctx *context.Context, unitType unit.Type) { ctx.Data["Issues"] = issues - approvalCounts, err := models.IssueList(issues).GetApprovalCounts() + approvalCounts, err := models.IssueList(issues).GetApprovalCounts(ctx) if err != nil { ctx.ServerError("ApprovalCounts", err) return diff --git a/routers/web/user/notification.go b/routers/web/user/notification.go index 05421cf55..0b1789dcf 100644 --- a/routers/web/user/notification.go +++ b/routers/web/user/notification.go @@ -34,7 +34,7 @@ func GetNotificationCount(c *context.Context) { } c.Data["NotificationUnreadCount"] = func() int64 { - count, err := models.GetNotificationCount(c.Doer, models.NotificationStatusUnread) + count, err := models.GetNotificationCount(c, c.Doer, models.NotificationStatusUnread) if err != nil { c.ServerError("GetNotificationCount", err) return -1 @@ -79,7 +79,7 @@ func getNotifications(c *context.Context) { status = models.NotificationStatusUnread } - total, err := models.GetNotificationCount(c.Doer, status) + total, err := models.GetNotificationCount(c, c.Doer, status) if err != nil { c.ServerError("ErrGetNotificationCount", err) return @@ -93,7 +93,7 @@ func getNotifications(c *context.Context) { } statuses := []models.NotificationStatus{status, models.NotificationStatusPinned} - notifications, err := models.NotificationsForUser(c.Doer, statuses, page, perPage) + notifications, err := models.NotificationsForUser(c, c.Doer, statuses, page, perPage) if err != nil { c.ServerError("ErrNotificationsForUser", err) return @@ -195,5 +195,5 @@ func NotificationPurgePost(c *context.Context) { // NewAvailable returns the notification counts func NewAvailable(ctx *context.Context) { - ctx.JSON(http.StatusOK, structs.NotificationCount{New: models.CountUnread(ctx.Doer)}) + ctx.JSON(http.StatusOK, structs.NotificationCount{New: models.CountUnread(ctx, ctx.Doer.ID)}) } diff --git a/routers/web/user/profile.go b/routers/web/user/profile.go index 85870eddf..8bce5460c 100644 --- a/routers/web/user/profile.go +++ b/routers/web/user/profile.go @@ -42,7 +42,7 @@ func Profile(ctx *context.Context) { } // check view permissions - if !user_model.IsUserVisibleToViewer(ctx.ContextUser, ctx.Doer) { + if !user_model.IsUserVisibleToViewer(ctx, ctx.ContextUser, ctx.Doer) { ctx.NotFound("user", fmt.Errorf(ctx.ContextUser.Name)) return } @@ -217,7 +217,7 @@ func Profile(ctx *context.Context) { total = int(count) case "projects": - ctx.Data["OpenProjects"], _, err = project_model.GetProjects(project_model.SearchOptions{ + ctx.Data["OpenProjects"], _, err = project_model.GetProjects(ctx, project_model.SearchOptions{ Page: -1, IsClosed: util.OptionalBoolFalse, Type: project_model.TypeIndividual, diff --git a/routers/web/user/setting/account.go b/routers/web/user/setting/account.go index b2476dff9..92f6c9a18 100644 --- a/routers/web/user/setting/account.go +++ b/routers/web/user/setting/account.go @@ -181,7 +181,7 @@ func EmailPost(ctx *context.Context) { Email: form.Email, IsActivated: !setting.Service.RegisterEmailConfirm, } - if err := user_model.AddEmailAddress(email); err != nil { + if err := user_model.AddEmailAddress(ctx, email); err != nil { if user_model.IsErrEmailAlreadyUsed(err) { loadAccountData(ctx) diff --git a/routers/web/user/setting/adopt.go b/routers/web/user/setting/adopt.go index ce2377a99..c7139f8bb 100644 --- a/routers/web/user/setting/adopt.go +++ b/routers/web/user/setting/adopt.go @@ -32,7 +32,7 @@ func AdoptOrDeleteRepository(ctx *context.Context) { root := user_model.UserPath(ctxUser.LowerName) // check not a repo - has, err := repo_model.IsRepositoryExist(ctxUser, dir) + has, err := repo_model.IsRepositoryExist(ctx, ctxUser, dir) if err != nil { ctx.ServerError("IsRepositoryExist", err) return diff --git a/routers/web/user/setting/applications.go b/routers/web/user/setting/applications.go index b0f599fc4..4ffec4780 100644 --- a/routers/web/user/setting/applications.go +++ b/routers/web/user/setting/applications.go @@ -93,12 +93,12 @@ func loadApplicationsData(ctx *context.Context) { ctx.Data["Tokens"] = tokens ctx.Data["EnableOAuth2"] = setting.OAuth2.Enable if setting.OAuth2.Enable { - ctx.Data["Applications"], err = auth.GetOAuth2ApplicationsByUserID(ctx.Doer.ID) + ctx.Data["Applications"], err = auth.GetOAuth2ApplicationsByUserID(ctx, ctx.Doer.ID) if err != nil { ctx.ServerError("GetOAuth2ApplicationsByUserID", err) return } - ctx.Data["Grants"], err = auth.GetOAuth2GrantsByUserID(ctx.Doer.ID) + ctx.Data["Grants"], err = auth.GetOAuth2GrantsByUserID(ctx, ctx.Doer.ID) if err != nil { ctx.ServerError("GetOAuth2GrantsByUserID", err) return diff --git a/routers/web/user/setting/oauth2.go b/routers/web/user/setting/oauth2.go index 76c50852a..db76a12f1 100644 --- a/routers/web/user/setting/oauth2.go +++ b/routers/web/user/setting/oauth2.go @@ -34,7 +34,7 @@ func OAuthApplicationsPost(ctx *context.Context) { return } // TODO validate redirect URI - app, err := auth.CreateOAuth2Application(auth.CreateOAuth2ApplicationOptions{ + app, err := auth.CreateOAuth2Application(ctx, auth.CreateOAuth2ApplicationOptions{ Name: form.Name, RedirectURIs: []string{form.RedirectURI}, UserID: ctx.Doer.ID, @@ -85,7 +85,7 @@ func OAuthApplicationsRegenerateSecret(ctx *context.Context) { ctx.Data["Title"] = ctx.Tr("settings") ctx.Data["PageIsSettingsApplications"] = true - app, err := auth.GetOAuth2ApplicationByID(ctx.ParamsInt64("id")) + app, err := auth.GetOAuth2ApplicationByID(ctx, ctx.ParamsInt64("id")) if err != nil { if auth.IsErrOAuthApplicationNotFound(err) { ctx.NotFound("Application not found", err) @@ -110,7 +110,7 @@ func OAuthApplicationsRegenerateSecret(ctx *context.Context) { // OAuth2ApplicationShow displays the given application func OAuth2ApplicationShow(ctx *context.Context) { - app, err := auth.GetOAuth2ApplicationByID(ctx.ParamsInt64("id")) + app, err := auth.GetOAuth2ApplicationByID(ctx, ctx.ParamsInt64("id")) if err != nil { if auth.IsErrOAuthApplicationNotFound(err) { ctx.NotFound("Application not found", err) @@ -147,7 +147,7 @@ func RevokeOAuth2Grant(ctx *context.Context) { ctx.ServerError("RevokeOAuth2Grant", fmt.Errorf("user id or grant id is zero")) return } - if err := auth.RevokeOAuth2Grant(ctx.FormInt64("id"), ctx.Doer.ID); err != nil { + if err := auth.RevokeOAuth2Grant(ctx, ctx.FormInt64("id"), ctx.Doer.ID); err != nil { ctx.ServerError("RevokeOAuth2Grant", err) return } diff --git a/routers/web/user/setting/profile.go b/routers/web/user/setting/profile.go index 0123b9b52..c2a406b18 100644 --- a/routers/web/user/setting/profile.go +++ b/routers/web/user/setting/profile.go @@ -180,7 +180,7 @@ func UpdateAvatarSetting(ctx *context.Context, form *forms.AvatarForm, ctxUser * } else if ctxUser.UseCustomAvatar && ctxUser.Avatar == "" { // No avatar is uploaded but setting has been changed to enable, // generate a random one when needed. - if err := user_model.GenerateRandomAvatar(ctxUser); err != nil { + if err := user_model.GenerateRandomAvatar(ctx, ctxUser); err != nil { log.Error("GenerateRandomAvatar[%d]: %v", ctxUser.ID, err) } } diff --git a/routers/web/user/setting/security/openid.go b/routers/web/user/setting/security/openid.go index 2ecc9b053..a378c8bf6 100644 --- a/routers/web/user/setting/security/openid.go +++ b/routers/web/user/setting/security/openid.go @@ -90,7 +90,7 @@ func settingsOpenIDVerify(ctx *context.Context) { log.Trace("Verified ID: " + id) oid := &user_model.UserOpenID{UID: ctx.Doer.ID, URI: id} - if err = user_model.AddUserOpenID(oid); err != nil { + if err = user_model.AddUserOpenID(ctx, oid); err != nil { if user_model.IsErrOpenIDAlreadyUsed(err) { ctx.RenderWithErr(ctx.Tr("form.openid_been_used", id), tplSettingsSecurity, &forms.AddOpenIDForm{Openid: id}) return diff --git a/routers/web/webfinger.go b/routers/web/webfinger.go index 27d0351b8..840296786 100644 --- a/routers/web/webfinger.go +++ b/routers/web/webfinger.go @@ -59,7 +59,7 @@ func WebfingerQuery(ctx *context.Context) { return } - u, err = user_model.GetUserByNameCtx(ctx, parts[0]) + u, err = user_model.GetUserByName(ctx, parts[0]) case "mailto": u, err = user_model.GetUserByEmailContext(ctx, resource.Opaque) if u != nil && u.KeepEmailPrivate { @@ -79,7 +79,7 @@ func WebfingerQuery(ctx *context.Context) { return } - if !user_model.IsUserVisibleToViewer(u, ctx.Doer) { + if !user_model.IsUserVisibleToViewer(ctx, u, ctx.Doer) { ctx.Error(http.StatusNotFound) return } diff --git a/services/asymkey/sign.go b/services/asymkey/sign.go index 6b17c017f..2431146f9 100644 --- a/services/asymkey/sign.go +++ b/services/asymkey/sign.go @@ -310,7 +310,7 @@ Loop: return false, "", nil, &ErrWontSign{twofa} } case approved: - protectedBranch, err := models.GetProtectedBranchBy(repo.ID, pr.BaseBranch) + protectedBranch, err := models.GetProtectedBranchBy(ctx, repo.ID, pr.BaseBranch) if err != nil { return false, "", nil, err } diff --git a/services/asymkey/ssh_key.go b/services/asymkey/ssh_key.go index 1f6b93eb2..143c807a1 100644 --- a/services/asymkey/ssh_key.go +++ b/services/asymkey/ssh_key.go @@ -42,7 +42,7 @@ func DeletePublicKey(doer *user_model.User, id int64) (err error) { committer.Close() if key.Type == asymkey_model.KeyTypePrincipal { - return asymkey_model.RewriteAllPrincipalKeys() + return asymkey_model.RewriteAllPrincipalKeys(db.DefaultContext) } return asymkey_model.RewriteAllPublicKeys() diff --git a/services/attachment/attachment_test.go b/services/attachment/attachment_test.go index ffce5943e..889151d8f 100644 --- a/services/attachment/attachment_test.go +++ b/services/attachment/attachment_test.go @@ -9,6 +9,7 @@ import ( "path/filepath" "testing" + "code.gitea.io/gitea/models/db" repo_model "code.gitea.io/gitea/models/repo" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" @@ -39,7 +40,7 @@ func TestUploadAttachment(t *testing.T) { }, f) assert.NoError(t, err) - attachment, err := repo_model.GetAttachmentByUUID(attach.UUID) + attachment, err := repo_model.GetAttachmentByUUID(db.DefaultContext, attach.UUID) assert.NoError(t, err) assert.EqualValues(t, user.ID, attachment.UploaderID) assert.Equal(t, int64(0), attachment.DownloadCount) diff --git a/services/auth/oauth2.go b/services/auth/oauth2.go index 42c91fac3..68638a080 100644 --- a/services/auth/oauth2.go +++ b/services/auth/oauth2.go @@ -38,7 +38,7 @@ func CheckOAuthAccessToken(accessToken string) int64 { return 0 } var grant *auth.OAuth2Grant - if grant, err = auth.GetOAuth2GrantByID(token.GrantID); err != nil || grant == nil { + if grant, err = auth.GetOAuth2GrantByID(db.DefaultContext, token.GrantID); err != nil || grant == nil { return 0 } if token.Type != oauth2.TypeAccessToken { diff --git a/services/auth/reverseproxy.go b/services/auth/reverseproxy.go index 299d7abd3..05d6af78f 100644 --- a/services/auth/reverseproxy.go +++ b/services/auth/reverseproxy.go @@ -63,7 +63,7 @@ func (r *ReverseProxy) Verify(req *http.Request, w http.ResponseWriter, store Da } log.Trace("ReverseProxy Authorization: Found username: %s", username) - user, err := user_model.GetUserByName(username) + user, err := user_model.GetUserByName(req.Context(), username) if err != nil { if !user_model.IsErrUserNotExist(err) || !r.isAutoRegisterAllowed() { log.Error("GetUserByName: %v", err) diff --git a/services/auth/source/ldap/source_authenticate.go b/services/auth/source/ldap/source_authenticate.go index d8d11f18e..785cb8ed3 100644 --- a/services/auth/source/ldap/source_authenticate.go +++ b/services/auth/source/ldap/source_authenticate.go @@ -34,11 +34,11 @@ func (source *Source) Authenticate(user *user_model.User, userName, password str isAttributeSSHPublicKeySet := len(strings.TrimSpace(source.AttributeSSHPublicKey)) > 0 // Update User admin flag if exist - if isExist, err := user_model.IsUserExist(0, sr.Username); err != nil { + if isExist, err := user_model.IsUserExist(db.DefaultContext, 0, sr.Username); err != nil { return nil, err } else if isExist { if user == nil { - user, err = user_model.GetUserByName(sr.Username) + user, err = user_model.GetUserByName(db.DefaultContext, sr.Username) if err != nil { return nil, err } diff --git a/services/auth/source/ldap/source_sync.go b/services/auth/source/ldap/source_sync.go index a245f4c6f..eb5ee8463 100644 --- a/services/auth/source/ldap/source_sync.go +++ b/services/auth/source/ldap/source_sync.go @@ -118,7 +118,6 @@ func (source *Source) Sync(ctx context.Context, updateExisting bool) error { } err = user_model.CreateUser(usr, overwriteDefault) - if err != nil { log.Error("SyncExternalUsers[%s]: Error creating user %s: %v", source.authSource.Name, su.Username, err) } @@ -161,7 +160,7 @@ func (source *Source) Sync(ctx context.Context, updateExisting bool) error { } usr.IsActive = true - err = user_model.UpdateUser(usr, emailChanged, "full_name", "email", "is_admin", "is_restricted", "is_active") + err = user_model.UpdateUser(ctx, usr, emailChanged, "full_name", "email", "is_admin", "is_restricted", "is_active") if err != nil { log.Error("SyncExternalUsers[%s]: Error updating user %s: %v", source.authSource.Name, usr.Name, err) } diff --git a/services/auth/sspi_windows.go b/services/auth/sspi_windows.go index 9bc4041a7..7c9529a76 100644 --- a/services/auth/sspi_windows.go +++ b/services/auth/sspi_windows.go @@ -127,7 +127,7 @@ func (s *SSPI) Verify(req *http.Request, w http.ResponseWriter, store DataStore, } log.Info("Authenticated as %s\n", username) - user, err := user_model.GetUserByName(username) + user, err := user_model.GetUserByName(req.Context(), username) if err != nil { if !user_model.IsErrUserNotExist(err) { log.Error("GetUserByName: %v", err) diff --git a/services/automerge/automerge.go b/services/automerge/automerge.go index 3ce4883aa..3c7346ab5 100644 --- a/services/automerge/automerge.go +++ b/services/automerge/automerge.go @@ -145,7 +145,7 @@ func getPullRequestsByHeadSHA(ctx context.Context, sha string, repo *repo_model. continue } - p, err := models.GetPullRequestByIndexCtx(ctx, repo.ID, prIndex) + p, err := models.GetPullRequestByIndex(ctx, repo.ID, prIndex) if err != nil { // If there is no pull request for this branch, we don't try to merge it. if models.IsErrPullRequestNotExist(err) { diff --git a/services/comments/comments.go b/services/comments/comments.go index c1b3ab73c..b80fddf93 100644 --- a/services/comments/comments.go +++ b/services/comments/comments.go @@ -48,7 +48,7 @@ func UpdateComment(c *models.Comment, doer *user_model.User, oldContent string) return err } if !hasContentHistory { - if err = issues.SaveIssueContentHistory(db.GetEngine(db.DefaultContext), c.PosterID, c.IssueID, c.ID, + if err = issues.SaveIssueContentHistory(db.DefaultContext, c.PosterID, c.IssueID, c.ID, c.CreatedUnix, oldContent, true); err != nil { return err } @@ -60,7 +60,7 @@ func UpdateComment(c *models.Comment, doer *user_model.User, oldContent string) } if needsContentHistory { - err := issues.SaveIssueContentHistory(db.GetEngine(db.DefaultContext), doer.ID, c.IssueID, c.ID, timeutil.TimeStampNow(), c.Content, false) + err := issues.SaveIssueContentHistory(db.DefaultContext, doer.ID, c.IssueID, c.ID, timeutil.TimeStampNow(), c.Content, false) if err != nil { return err } diff --git a/services/context/user.go b/services/context/user.go index c5efd4378..1c92d24d4 100644 --- a/services/context/user.go +++ b/services/context/user.go @@ -44,7 +44,7 @@ func userAssignment(ctx *context.Context, errCb func(int, string, interface{})) ctx.ContextUser = ctx.Doer } else { var err error - ctx.ContextUser, err = user_model.GetUserByName(username) + ctx.ContextUser, err = user_model.GetUserByName(ctx, username) if err != nil { if user_model.IsErrUserNotExist(err) { if redirectUserID, err := user_model.LookupUserRedirect(username); err == nil { diff --git a/services/cron/tasks_extended.go b/services/cron/tasks_extended.go index ec7ced99e..41bd5c442 100644 --- a/services/cron/tasks_extended.go +++ b/services/cron/tasks_extended.go @@ -11,6 +11,7 @@ import ( "code.gitea.io/gitea/models" "code.gitea.io/gitea/models/admin" asymkey_model "code.gitea.io/gitea/models/asymkey" + "code.gitea.io/gitea/models/db" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/updatechecker" @@ -79,7 +80,7 @@ func registerRewriteAllPrincipalKeys() { RunAtStart: false, Schedule: "@every 72h", }, func(_ context.Context, _ *user_model.User, _ Config) error { - return asymkey_model.RewriteAllPrincipalKeys() + return asymkey_model.RewriteAllPrincipalKeys(db.DefaultContext) }) } diff --git a/services/issue/assignee.go b/services/issue/assignee.go index 0b6d0045f..8cad03351 100644 --- a/services/issue/assignee.go +++ b/services/issue/assignee.go @@ -21,9 +21,10 @@ import ( // DeleteNotPassedAssignee deletes all assignees who aren't passed via the "assignees" array func DeleteNotPassedAssignee(issue *models.Issue, doer *user_model.User, assignees []*user_model.User) (err error) { var found bool + oriAssignes := make([]*user_model.User, len(issue.Assignees)) + _ = copy(oriAssignes, issue.Assignees) - for _, assignee := range issue.Assignees { - + for _, assignee := range oriAssignes { found = false for _, alreadyAssignee := range assignees { if assignee.ID == alreadyAssignee.ID { @@ -110,7 +111,7 @@ func IsValidReviewRequest(ctx context.Context, reviewer, doer *user_model.User, } } - lastreview, err := models.GetReviewByIssueIDAndUserID(issue.ID, reviewer.ID) + lastreview, err := models.GetReviewByIssueIDAndUserID(ctx, issue.ID, reviewer.ID) if err != nil && !models.IsErrReviewNotExist(err) { return err } @@ -132,7 +133,7 @@ func IsValidReviewRequest(ctx context.Context, reviewer, doer *user_model.User, pemResult = permDoer.CanAccessAny(perm.AccessModeWrite, unit.TypePullRequests) if !pemResult { - pemResult, err = models.IsOfficialReviewer(issue, doer) + pemResult, err = models.IsOfficialReviewer(ctx, issue, doer) if err != nil { return err } @@ -201,7 +202,7 @@ func IsValidTeamReviewRequest(ctx context.Context, reviewer *organization.Team, doerCanWrite := permission.CanAccessAny(perm.AccessModeWrite, unit.TypePullRequests) if !doerCanWrite { - official, err := models.IsOfficialReviewer(issue, doer) + official, err := models.IsOfficialReviewer(ctx, issue, doer) if err != nil { log.Error("Unable to Check if IsOfficialReviewer for %-v in %-v#%d", doer, issue.Repo, issue.Index) return err diff --git a/services/issue/assignee_test.go b/services/issue/assignee_test.go index d3d7ad74f..ff4d7029e 100644 --- a/services/issue/assignee_test.go +++ b/services/issue/assignee_test.go @@ -8,6 +8,7 @@ import ( "testing" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" @@ -20,21 +21,23 @@ func TestDeleteNotPassedAssignee(t *testing.T) { // Fake issue with assignees issue, err := models.GetIssueWithAttrsByID(1) assert.NoError(t, err) + assert.EqualValues(t, 1, len(issue.Assignees)) user1, err := user_model.GetUserByID(1) // This user is already assigned (see the definition in fixtures), so running UpdateAssignee should unassign him assert.NoError(t, err) // Check if he got removed - isAssigned, err := models.IsUserAssignedToIssue(issue, user1) + isAssigned, err := models.IsUserAssignedToIssue(db.DefaultContext, issue, user1) assert.NoError(t, err) assert.True(t, isAssigned) // Clean everyone err = DeleteNotPassedAssignee(issue, user1, []*user_model.User{}) assert.NoError(t, err) + assert.EqualValues(t, 0, len(issue.Assignees)) // Check they're gone - assignees, err := models.GetAssigneesByIssue(issue) - assert.NoError(t, err) - assert.Empty(t, assignees) + assert.NoError(t, issue.LoadAssignees(db.DefaultContext)) + assert.EqualValues(t, 0, len(issue.Assignees)) + assert.Empty(t, issue.Assignee) } diff --git a/services/issue/issue.go b/services/issue/issue.go index db304a46b..78a486727 100644 --- a/services/issue/issue.go +++ b/services/issue/issue.go @@ -100,7 +100,7 @@ func UpdateAssignees(issue *models.Issue, oneAssignee string, multipleAssignees // Loop through all assignees to add them for _, assigneeName := range multipleAssignees { - assignee, err := user_model.GetUserByName(assigneeName) + assignee, err := user_model.GetUserByName(db.DefaultContext, assigneeName) if err != nil { return err } @@ -164,7 +164,7 @@ func AddAssigneeIfNotAssigned(issue *models.Issue, doer *user_model.User, assign } // Check if the user is already assigned - isAssigned, err := models.IsUserAssignedToIssue(issue, assignee) + isAssigned, err := models.IsUserAssignedToIssue(db.DefaultContext, issue, assignee) if err != nil { return err } diff --git a/services/issue/label.go b/services/issue/label.go index 94e52482f..289466f60 100644 --- a/services/issue/label.go +++ b/services/issue/label.go @@ -80,7 +80,7 @@ func RemoveLabel(issue *models.Issue, doer *user_model.User, label *models.Label // ReplaceLabels removes all current labels and add new labels to the issue. func ReplaceLabels(issue *models.Issue, doer *user_model.User, labels []*models.Label) error { - old, err := models.GetLabelsByIssueID(issue.ID) + old, err := models.GetLabelsByIssueID(db.DefaultContext, issue.ID) if err != nil { return err } diff --git a/services/mailer/mail_issue.go b/services/mailer/mail_issue.go index c24edf50c..4abf7eefd 100644 --- a/services/mailer/mail_issue.go +++ b/services/mailer/mail_issue.go @@ -71,7 +71,7 @@ func mailIssueCommentToParticipants(ctx *mailCommentContext, mentions []*user_mo unfiltered = append(unfiltered, ids...) // =========== Issue watchers =========== - ids, err = models.GetIssueWatchersIDs(ctx.Issue.ID, true) + ids, err = models.GetIssueWatchersIDs(ctx, ctx.Issue.ID, true) if err != nil { return fmt.Errorf("GetIssueWatchersIDs(%d): %v", ctx.Issue.ID, err) } @@ -98,7 +98,7 @@ func mailIssueCommentToParticipants(ctx *mailCommentContext, mentions []*user_mo } // Avoid mailing explicit unwatched - ids, err = models.GetIssueWatchersIDs(ctx.Issue.ID, false) + ids, err = models.GetIssueWatchersIDs(ctx, ctx.Issue.ID, false) if err != nil { return fmt.Errorf("GetIssueWatchersIDs(%d): %v", ctx.Issue.ID, err) } diff --git a/services/migrations/gitea_uploader.go b/services/migrations/gitea_uploader.go index 34dd59d7f..fec1fc8c7 100644 --- a/services/migrations/gitea_uploader.go +++ b/services/migrations/gitea_uploader.go @@ -93,7 +93,7 @@ func (g *GiteaLocalUploader) MaxBatchInsertSize(tp string) int { // CreateRepo creates a repository func (g *GiteaLocalUploader) CreateRepo(repo *base.Repository, opts base.MigrateOptions) error { - owner, err := user_model.GetUserByName(g.repoOwner) + owner, err := user_model.GetUserByName(g.ctx, g.repoOwner) if err != nil { return err } @@ -826,7 +826,7 @@ func (g *GiteaLocalUploader) Finish() error { } g.repo.Status = repo_model.RepositoryReady - return repo_model.UpdateRepositoryCols(g.repo, "status") + return repo_model.UpdateRepositoryCols(g.ctx, g.repo, "status") } func (g *GiteaLocalUploader) remapUser(source user_model.ExternalUserMigrated, target user_model.ExternalUserRemappable) error { diff --git a/services/migrations/gitea_uploader_test.go b/services/migrations/gitea_uploader_test.go index f57c8e233..bd7c6e065 100644 --- a/services/migrations/gitea_uploader_test.go +++ b/services/migrations/gitea_uploader_test.go @@ -40,7 +40,8 @@ func TestGiteaUploadRepo(t *testing.T) { user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1}).(*user_model.User) var ( - downloader = NewGithubDownloaderV3(context.Background(), "https://github.com", "", "", "", "go-xorm", "builder") + ctx = context.Background() + downloader = NewGithubDownloaderV3(ctx, "https://github.com", "", "", "", "go-xorm", "builder") repoName = "builder-" + time.Now().Format("2006-01-02-15-04-05") uploader = NewGiteaLocalUploader(graceful.GetManager().HammerContext(), user, user.Name, repoName) ) @@ -80,7 +81,7 @@ func TestGiteaUploadRepo(t *testing.T) { assert.NoError(t, err) assert.Empty(t, milestones) - labels, err := models.GetLabelsByRepoID(repo.ID, "", db.ListOptions{}) + labels, err := models.GetLabelsByRepoID(ctx, repo.ID, "", db.ListOptions{}) assert.NoError(t, err) assert.Len(t, labels, 12) diff --git a/services/mirror/mirror_pull.go b/services/mirror/mirror_pull.go index ecd031b38..c51483821 100644 --- a/services/mirror/mirror_pull.go +++ b/services/mirror/mirror_pull.go @@ -12,6 +12,7 @@ import ( "code.gitea.io/gitea/models" admin_model "code.gitea.io/gitea/models/admin" + "code.gitea.io/gitea/models/db" repo_model "code.gitea.io/gitea/models/repo" "code.gitea.io/gitea/modules/cache" "code.gitea.io/gitea/modules/git" @@ -71,7 +72,7 @@ func UpdateAddress(ctx context.Context, m *repo_model.Mirror, addr string) error } m.Repo.OriginalURL = addr - return repo_model.UpdateRepositoryCols(m.Repo, "original_url") + return repo_model.UpdateRepositoryCols(ctx, m.Repo, "original_url") } // mirrorSyncResult contains information of a updated reference. @@ -395,11 +396,12 @@ func SyncPullMirror(ctx context.Context, repoID int64) bool { log.Error("PANIC whilst SyncMirrors[repo_id: %d] Panic: %v\nStacktrace: %s", repoID, err, log.Stack(2)) }() - m, err := repo_model.GetMirrorByRepoID(repoID) + m, err := repo_model.GetMirrorByRepoID(ctx, repoID) if err != nil { log.Error("SyncMirrors [repo_id: %v]: unable to GetMirrorByRepoID: %v", repoID, err) return false } + _ = m.GetRepository() // force load repository of mirror ctx, _, finished := process.GetManager().AddContext(ctx, fmt.Sprintf("Syncing Mirror %s/%s", m.Repo.OwnerName, m.Repo.Name)) defer finished() @@ -415,7 +417,7 @@ func SyncPullMirror(ctx context.Context, repoID int64) bool { log.Trace("SyncMirrors [repo: %-v]: Scheduling next update", m.Repo) m.ScheduleNextUpdate() - if err = repo_model.UpdateMirror(m); err != nil { + if err = repo_model.UpdateMirror(ctx, m); err != nil { log.Error("SyncMirrors [repo: %-v]: failed to UpdateMirror with next update date: %v", m.Repo, err) return false } @@ -574,7 +576,7 @@ func checkAndUpdateEmptyRepository(m *repo_model.Mirror, gitRepo *git.Repository } m.Repo.IsEmpty = false // Update the is empty and default_branch columns - if err := repo_model.UpdateRepositoryCols(m.Repo, "default_branch", "is_empty"); err != nil { + if err := repo_model.UpdateRepositoryCols(db.DefaultContext, m.Repo, "default_branch", "is_empty"); err != nil { log.Error("Failed to update default branch of repository %-v. Error: %v", m.Repo, err) desc := fmt.Sprintf("Failed to uupdate default branch of repository '%s': %v", m.Repo.RepoPath(), err) if err = admin_model.CreateRepositoryNotice(desc); err != nil { diff --git a/services/mirror/mirror_push.go b/services/mirror/mirror_push.go index 5c0c14c62..138ebb737 100644 --- a/services/mirror/mirror_push.go +++ b/services/mirror/mirror_push.go @@ -66,6 +66,7 @@ func AddPushMirrorRemote(ctx context.Context, m *repo_model.PushMirror, addr str // RemovePushMirrorRemote removes the push mirror remote. func RemovePushMirrorRemote(ctx context.Context, m *repo_model.PushMirror) error { cmd := git.NewCommand(ctx, "remote", "rm", m.RemoteName) + _ = m.GetRepository() if _, _, err := cmd.RunStdString(&git.RunOpts{Dir: m.Repo.RepoPath()}); err != nil { return err @@ -99,6 +100,8 @@ func SyncPushMirror(ctx context.Context, mirrorID int64) bool { return false } + _ = m.GetRepository() + m.LastError = "" ctx, _, finished := process.GetManager().AddContext(ctx, fmt.Sprintf("Syncing PushMirror %s/%s to %s", m.Repo.OwnerName, m.Repo.Name, m.RemoteName)) diff --git a/services/org/org.go b/services/org/org.go index d7b3019e7..b24b7e34c 100644 --- a/services/org/org.go +++ b/services/org/org.go @@ -26,7 +26,7 @@ func DeleteOrganization(org *organization.Organization) error { defer commiter.Close() // Check ownership of repository. - count, err := repo_model.GetRepositoryCount(ctx, org.ID) + count, err := repo_model.CountRepositories(ctx, repo_model.CountRepositoryOptions{OwnerID: org.ID}) if err != nil { return fmt.Errorf("GetRepositoryCount: %v", err) } else if count > 0 { diff --git a/services/pull/check.go b/services/pull/check.go index d88dd3a55..94e7ca716 100644 --- a/services/pull/check.go +++ b/services/pull/check.go @@ -99,7 +99,7 @@ func CheckPullMergable(stdCtx context.Context, doer *user_model.User, perm *acce if err := CheckPullBranchProtections(ctx, pr, false); err != nil { if models.IsErrDisallowedToMerge(err) { if force { - if isRepoAdmin, err2 := access_model.IsUserRepoAdminCtx(ctx, pr.BaseRepo, doer); err2 != nil { + if isRepoAdmin, err2 := access_model.IsUserRepoAdmin(ctx, pr.BaseRepo, doer); err2 != nil { return err2 } else if !isRepoAdmin { return err diff --git a/services/pull/commit_status.go b/services/pull/commit_status.go index ec4cc2aa0..539b3c852 100644 --- a/services/pull/commit_status.go +++ b/services/pull/commit_status.go @@ -132,7 +132,7 @@ func GetPullRequestCommitStatusState(ctx context.Context, pr *models.PullRequest return "", errors.Wrap(err, "LoadBaseRepo") } - commitStatuses, _, err := models.GetLatestCommitStatusCtx(ctx, pr.BaseRepo.ID, sha, db.ListOptions{}) + commitStatuses, _, err := models.GetLatestCommitStatus(ctx, pr.BaseRepo.ID, sha, db.ListOptions{}) if err != nil { return "", errors.Wrap(err, "GetLatestCommitStatus") } diff --git a/services/pull/pull.go b/services/pull/pull.go index b94b6769a..efac3f019 100644 --- a/services/pull/pull.go +++ b/services/pull/pull.go @@ -290,7 +290,7 @@ func AddTestPullRequestTask(doer *user_model.User, repoID int64, branch string, if err != nil { log.Error("GetDiverging: %v", err) } else { - err = pr.UpdateCommitDivergence(divergence.Ahead, divergence.Behind) + err = pr.UpdateCommitDivergence(ctx, divergence.Ahead, divergence.Behind) if err != nil { log.Error("UpdateCommitDivergence: %v", err) } @@ -336,7 +336,7 @@ func AddTestPullRequestTask(doer *user_model.User, repoID int64, branch string, log.Error("GetDiverging: %v", err) } } else { - err = pr.UpdateCommitDivergence(divergence.Ahead, divergence.Behind) + err = pr.UpdateCommitDivergence(ctx, divergence.Ahead, divergence.Behind) if err != nil { log.Error("UpdateCommitDivergence: %v", err) } @@ -793,7 +793,7 @@ func getAllCommitStatus(gitRepo *git.Repository, pr *models.PullRequest) (status return nil, nil, shaErr } - statuses, _, err = models.GetLatestCommitStatus(pr.BaseRepo.ID, sha, db.ListOptions{}) + statuses, _, err = models.GetLatestCommitStatus(db.DefaultContext, pr.BaseRepo.ID, sha, db.ListOptions{}) lastStatus = models.CalcCommitStatus(statuses) return statuses, lastStatus, err } diff --git a/services/pull/review.go b/services/pull/review.go index 940fe4470..eac7279f9 100644 --- a/services/pull/review.go +++ b/services/pull/review.go @@ -71,13 +71,13 @@ func CreateCodeComment(ctx context.Context, doer *user_model.User, gitRepo *git. return comment, nil } - review, err := models.GetCurrentReview(doer, issue) + review, err := models.GetCurrentReview(ctx, doer, issue) if err != nil { if !models.IsErrReviewNotExist(err) { return nil, err } - if review, err = models.CreateReview(models.CreateReviewOptions{ + if review, err = models.CreateReview(ctx, models.CreateReviewOptions{ Type: models.ReviewTypePending, Reviewer: doer, Issue: issue, @@ -135,7 +135,7 @@ func createCodeComment(ctx context.Context, doer *user_model.User, repo *repo_mo head := pr.GetGitRefName() if line > 0 { if reviewID != 0 { - first, err := models.FindComments(&models.FindCommentsOptions{ + first, err := models.FindComments(ctx, &models.FindCommentsOptions{ ReviewID: reviewID, Line: line, TreePath: treePath, @@ -152,7 +152,7 @@ func createCodeComment(ctx context.Context, doer *user_model.User, repo *repo_mo } else if err != nil && !models.IsErrCommentNotExist(err) { return nil, fmt.Errorf("Find first comment for %d line %d path %s. Error: %v", reviewID, line, treePath, err) } else { - review, err := models.GetReviewByID(reviewID) + review, err := models.GetReviewByID(ctx, reviewID) if err == nil && len(review.CommitID) > 0 { head = review.CommitID } else if err != nil && !models.IsErrReviewNotExist(err) { @@ -272,7 +272,7 @@ func SubmitReview(ctx context.Context, doer *user_model.User, gitRepo *git.Repos // DismissReview dismissing stale review by repo admin func DismissReview(ctx context.Context, reviewID int64, message string, doer *user_model.User, isDismiss bool) (comment *models.Comment, err error) { - review, err := models.GetReviewByID(reviewID) + review, err := models.GetReviewByID(ctx, reviewID) if err != nil { return } diff --git a/services/release/release_test.go b/services/release/release_test.go index 19d985491..823560a09 100644 --- a/services/release/release_test.go +++ b/services/release/release_test.go @@ -11,6 +11,7 @@ import ( "time" "code.gitea.io/gitea/models" + "code.gitea.io/gitea/models/db" repo_model "code.gitea.io/gitea/models/repo" "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" @@ -248,7 +249,7 @@ func TestRelease_Update(t *testing.T) { assert.NoError(t, err) assert.NoError(t, UpdateRelease(user, gitRepo, release, []string{attach.UUID}, nil, nil)) - assert.NoError(t, models.GetReleaseAttachments(release)) + assert.NoError(t, models.GetReleaseAttachments(db.DefaultContext, release)) assert.Len(t, release.Attachments, 1) assert.EqualValues(t, attach.UUID, release.Attachments[0].UUID) assert.EqualValues(t, release.ID, release.Attachments[0].ReleaseID) @@ -259,7 +260,7 @@ func TestRelease_Update(t *testing.T) { attach.UUID: "test2.txt", })) release.Attachments = nil - assert.NoError(t, models.GetReleaseAttachments(release)) + assert.NoError(t, models.GetReleaseAttachments(db.DefaultContext, release)) assert.Len(t, release.Attachments, 1) assert.EqualValues(t, attach.UUID, release.Attachments[0].UUID) assert.EqualValues(t, release.ID, release.Attachments[0].ReleaseID) @@ -268,7 +269,7 @@ func TestRelease_Update(t *testing.T) { // delete the attachment assert.NoError(t, UpdateRelease(user, gitRepo, release, nil, []string{attach.UUID}, nil)) release.Attachments = nil - assert.NoError(t, models.GetReleaseAttachments(release)) + assert.NoError(t, models.GetReleaseAttachments(db.DefaultContext, release)) assert.Empty(t, release.Attachments) } diff --git a/services/repository/adopt.go b/services/repository/adopt.go index b287d94f9..1e8c22a47 100644 --- a/services/repository/adopt.go +++ b/services/repository/adopt.go @@ -208,7 +208,7 @@ func DeleteUnadoptedRepository(doer, u *user_model.User, repoName string) error } } - if exist, err := repo_model.IsRepositoryExist(u, repoName); err != nil { + if exist, err := repo_model.IsRepositoryExist(db.DefaultContext, u, repoName); err != nil { return err } else if exist { return repo_model.ErrRepoAlreadyExist{ @@ -238,7 +238,7 @@ func checkUnadoptedRepositories(userName string, repoNamesToCheck []string, unad if len(repoNamesToCheck) == 0 { return nil } - ctxUser, err := user_model.GetUserByName(userName) + ctxUser, err := user_model.GetUserByName(db.DefaultContext, userName) if err != nil { if user_model.IsErrUserNotExist(err) { log.Debug("Missing user: %s", userName) diff --git a/services/repository/avatar.go b/services/repository/avatar.go index f51a312e1..dcf04c7e5 100644 --- a/services/repository/avatar.go +++ b/services/repository/avatar.go @@ -44,7 +44,7 @@ func UploadAvatar(repo *repo_model.Repository, data []byte) error { // Users can upload the same image to other repo - prefix it with ID // Then repo will be removed - only it avatar file will be removed repo.Avatar = newAvatar - if err := repo_model.UpdateRepositoryColsCtx(ctx, repo, "avatar"); err != nil { + if err := repo_model.UpdateRepositoryCols(ctx, repo, "avatar"); err != nil { return fmt.Errorf("UploadAvatar: Update repository avatar: %v", err) } @@ -83,7 +83,7 @@ func DeleteAvatar(repo *repo_model.Repository) error { defer committer.Close() repo.Avatar = "" - if err := repo_model.UpdateRepositoryColsCtx(ctx, repo, "avatar"); err != nil { + if err := repo_model.UpdateRepositoryCols(ctx, repo, "avatar"); err != nil { return fmt.Errorf("DeleteAvatar: Update repository avatar: %v", err) } @@ -117,5 +117,5 @@ func generateAvatar(ctx context.Context, templateRepo, generateRepo *repo_model. return err } - return repo_model.UpdateRepositoryColsCtx(ctx, generateRepo, "avatar") + return repo_model.UpdateRepositoryCols(ctx, generateRepo, "avatar") } diff --git a/services/repository/files/patch.go b/services/repository/files/patch.go index 240cb4fe2..73464f31f 100644 --- a/services/repository/files/patch.go +++ b/services/repository/files/patch.go @@ -66,7 +66,7 @@ func (opts *ApplyDiffPatchOptions) Validate(ctx context.Context, repo *repo_mode return err } } else { - protectedBranch, err := models.GetProtectedBranchBy(repo.ID, opts.OldBranch) + protectedBranch, err := models.GetProtectedBranchBy(ctx, repo.ID, opts.OldBranch) if err != nil { return err } diff --git a/services/repository/files/update.go b/services/repository/files/update.go index 2cb40aac4..a093ee5da 100644 --- a/services/repository/files/update.go +++ b/services/repository/files/update.go @@ -462,7 +462,7 @@ func CreateOrUpdateRepoFile(ctx context.Context, repo *repo_model.Repository, do // VerifyBranchProtection verify the branch protection for modifying the given treePath on the given branch func VerifyBranchProtection(ctx context.Context, repo *repo_model.Repository, doer *user_model.User, branchName, treePath string) error { - protectedBranch, err := models.GetProtectedBranchBy(repo.ID, branchName) + protectedBranch, err := models.GetProtectedBranchBy(ctx, repo.ID, branchName) if err != nil { return err } diff --git a/services/repository/push.go b/services/repository/push.go index 4eb52c18c..5ca8c7398 100644 --- a/services/repository/push.go +++ b/services/repository/push.go @@ -181,7 +181,7 @@ func pushUpdates(optsList []*repo_module.PushUpdateOptions) error { } } // Update the is empty and default_branch columns - if err := repo_model.UpdateRepositoryCols(repo, "default_branch", "is_empty"); err != nil { + if err := repo_model.UpdateRepositoryCols(db.DefaultContext, repo, "default_branch", "is_empty"); err != nil { return fmt.Errorf("UpdateRepositoryCols: %v", err) } } @@ -269,7 +269,7 @@ func pushUpdates(optsList []*repo_module.PushUpdateOptions) error { } // Even if user delete a branch on a repository which he didn't watch, he will be watch that. - if err = repo_model.WatchIfAuto(opts.PusherID, repo.ID, true); err != nil { + if err = repo_model.WatchIfAuto(db.DefaultContext, opts.PusherID, repo.ID, true); err != nil { log.Warn("Fail to perform auto watch on user %v for repo %v: %v", opts.PusherID, repo.ID, err) } } else { diff --git a/services/user/user.go b/services/user/user.go index d41fc4249..4db4d7ca1 100644 --- a/services/user/user.go +++ b/services/user/user.go @@ -44,7 +44,7 @@ func DeleteUser(u *user_model.User) error { // cannot perform delete operation. // Check ownership of repository. - count, err := repo_model.GetRepositoryCount(ctx, u.ID) + count, err := repo_model.CountRepositories(ctx, repo_model.CountRepositoryOptions{OwnerID: u.ID}) if err != nil { return fmt.Errorf("GetRepositoryCount: %v", err) } else if count > 0 { @@ -78,7 +78,7 @@ func DeleteUser(u *user_model.User) error { if err = asymkey_model.RewriteAllPublicKeys(); err != nil { return err } - if err = asymkey_model.RewriteAllPrincipalKeys(); err != nil { + if err = asymkey_model.RewriteAllPrincipalKeys(db.DefaultContext); err != nil { return err } diff --git a/services/webhook/webhook.go b/services/webhook/webhook.go index b15b8173f..68cfe147a 100644 --- a/services/webhook/webhook.go +++ b/services/webhook/webhook.go @@ -5,10 +5,12 @@ package webhook import ( + "context" "fmt" "strconv" "strings" + "code.gitea.io/gitea/models/db" repo_model "code.gitea.io/gitea/models/repo" webhook_model "code.gitea.io/gitea/models/webhook" "code.gitea.io/gitea/modules/git" @@ -218,15 +220,15 @@ func prepareWebhook(w *webhook_model.Webhook, repo *repo_model.Repository, event // PrepareWebhooks adds new webhooks to task queue for given payload. func PrepareWebhooks(repo *repo_model.Repository, event webhook_model.HookEventType, p api.Payloader) error { - if err := prepareWebhooks(repo, event, p); err != nil { + if err := prepareWebhooks(db.DefaultContext, repo, event, p); err != nil { return err } return addToTask(repo.ID) } -func prepareWebhooks(repo *repo_model.Repository, event webhook_model.HookEventType, p api.Payloader) error { - ws, err := webhook_model.ListWebhooksByOpts(&webhook_model.ListWebhookOptions{ +func prepareWebhooks(ctx context.Context, repo *repo_model.Repository, event webhook_model.HookEventType, p api.Payloader) error { + ws, err := webhook_model.ListWebhooksByOpts(ctx, &webhook_model.ListWebhookOptions{ RepoID: repo.ID, IsActive: util.OptionalBoolTrue, }) @@ -237,7 +239,7 @@ func prepareWebhooks(repo *repo_model.Repository, event webhook_model.HookEventT // check if repo belongs to org and append additional webhooks if repo.MustOwner().IsOrganization() { // get hooks for org - orgHooks, err := webhook_model.ListWebhooksByOpts(&webhook_model.ListWebhookOptions{ + orgHooks, err := webhook_model.ListWebhooksByOpts(ctx, &webhook_model.ListWebhookOptions{ OrgID: repo.OwnerID, IsActive: util.OptionalBoolTrue, }) @@ -248,7 +250,7 @@ func prepareWebhooks(repo *repo_model.Repository, event webhook_model.HookEventT } // Add any admin-defined system webhooks - systemHooks, err := webhook_model.GetSystemWebhooks(util.OptionalBoolTrue) + systemHooks, err := webhook_model.GetSystemWebhooks(ctx, util.OptionalBoolTrue) if err != nil { return fmt.Errorf("GetSystemWebhooks: %v", err) }