Initial commit, source code of Gitea 1.23.1

Author: Taldariner
Date: 2025-01-28 08:57:29 +02:00
Commit: 65b18d67e8
6122 changed files with 702653 additions and 0 deletions

models/actions/artifact.go (new file, 183 lines added)

@@ -0,0 +1,183 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
// This artifact server is inspired by https://github.com/nektos/act/blob/master/pkg/artifacts/server.go.
// It updates url setting and uses ObjectStore to handle artifacts persistence.
package actions
import (
"context"
"errors"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
"xorm.io/builder"
)
// ArtifactStatus is the status of an artifact, uploading, expired or need-delete
type ArtifactStatus int64
const (
ArtifactStatusUploadPending ArtifactStatus = iota + 1 // 1 ArtifactStatusUploadPending is the status of an artifact upload that is pending
ArtifactStatusUploadConfirmed // 2 ArtifactStatusUploadConfirmed is the status of an artifact upload that is confirmed
ArtifactStatusUploadError // 3 ArtifactStatusUploadError is the status of an artifact upload that is errored
ArtifactStatusExpired // 4, ArtifactStatusExpired is the status of an artifact that is expired
ArtifactStatusPendingDeletion // 5, ArtifactStatusPendingDeletion is the status of an artifact that is pending deletion
ArtifactStatusDeleted // 6, ArtifactStatusDeleted is the status of an artifact that is deleted
)
func init() {
db.RegisterModel(new(ActionArtifact))
}
// ActionArtifact is a file that is stored in the artifact storage.
type ActionArtifact struct {
ID int64 `xorm:"pk autoincr"`
RunID int64 `xorm:"index unique(runid_name_path)"` // The run id of the artifact
RunnerID int64
RepoID int64 `xorm:"index"`
OwnerID int64
CommitSHA string
StoragePath string // The path to the artifact in the storage
FileSize int64 // The size of the artifact in bytes
FileCompressedSize int64 // The size of the artifact in bytes after gzip compression
ContentEncoding string // The content encoding of the artifact
ArtifactPath string `xorm:"index unique(runid_name_path)"` // The path to the artifact when runner uploads it
ArtifactName string `xorm:"index unique(runid_name_path)"` // The name of the artifact when runner uploads it
Status int64 `xorm:"index"` // The status of the artifact, uploading, expired or need-delete
CreatedUnix timeutil.TimeStamp `xorm:"created"`
UpdatedUnix timeutil.TimeStamp `xorm:"updated index"`
ExpiredUnix timeutil.TimeStamp `xorm:"index"` // The time when the artifact will be expired
}
func CreateArtifact(ctx context.Context, t *ActionTask, artifactName, artifactPath string, expiredDays int64) (*ActionArtifact, error) {
if err := t.LoadJob(ctx); err != nil {
return nil, err
}
artifact, err := getArtifactByNameAndPath(ctx, t.Job.RunID, artifactName, artifactPath)
if errors.Is(err, util.ErrNotExist) {
artifact := &ActionArtifact{
ArtifactName: artifactName,
ArtifactPath: artifactPath,
RunID: t.Job.RunID,
RunnerID: t.RunnerID,
RepoID: t.RepoID,
OwnerID: t.OwnerID,
CommitSHA: t.CommitSHA,
Status: int64(ArtifactStatusUploadPending),
ExpiredUnix: timeutil.TimeStamp(time.Now().Unix() + timeutil.Day*expiredDays),
}
if _, err := db.GetEngine(ctx).Insert(artifact); err != nil {
return nil, err
}
return artifact, nil
} else if err != nil {
return nil, err
}
if _, err := db.GetEngine(ctx).ID(artifact.ID).Cols("expired_unix").Update(&ActionArtifact{
ExpiredUnix: timeutil.TimeStamp(time.Now().Unix() + timeutil.Day*expiredDays),
}); err != nil {
return nil, err
}
return artifact, nil
}
func getArtifactByNameAndPath(ctx context.Context, runID int64, name, fpath string) (*ActionArtifact, error) {
var art ActionArtifact
has, err := db.GetEngine(ctx).Where("run_id = ? AND artifact_name = ? AND artifact_path = ?", runID, name, fpath).Get(&art)
if err != nil {
return nil, err
} else if !has {
return nil, util.ErrNotExist
}
return &art, nil
}
// UpdateArtifactByID updates an artifact by id
func UpdateArtifactByID(ctx context.Context, id int64, art *ActionArtifact) error {
art.ID = id
_, err := db.GetEngine(ctx).ID(id).AllCols().Update(art)
return err
}
type FindArtifactsOptions struct {
db.ListOptions
RepoID int64
RunID int64
ArtifactName string
Status int
}
func (opts FindArtifactsOptions) ToConds() builder.Cond {
cond := builder.NewCond()
if opts.RepoID > 0 {
cond = cond.And(builder.Eq{"repo_id": opts.RepoID})
}
if opts.RunID > 0 {
cond = cond.And(builder.Eq{"run_id": opts.RunID})
}
if opts.ArtifactName != "" {
cond = cond.And(builder.Eq{"artifact_name": opts.ArtifactName})
}
if opts.Status > 0 {
cond = cond.And(builder.Eq{"status": opts.Status})
}
return cond
}
// ActionArtifactMeta is the metadata of an artifact
type ActionArtifactMeta struct {
ArtifactName string
FileSize int64
Status ArtifactStatus
}
// ListUploadedArtifactsMeta returns the metadata of all uploaded artifacts of a run
func ListUploadedArtifactsMeta(ctx context.Context, runID int64) ([]*ActionArtifactMeta, error) {
arts := make([]*ActionArtifactMeta, 0, 10)
return arts, db.GetEngine(ctx).Table("action_artifact").
Where("run_id=? AND (status=? OR status=?)", runID, ArtifactStatusUploadConfirmed, ArtifactStatusExpired).
GroupBy("artifact_name").
Select("artifact_name, sum(file_size) as file_size, max(status) as status").
Find(&arts)
}
// ListNeedExpiredArtifacts returns all artifacts that have passed their expiry time but have not been deleted yet
func ListNeedExpiredArtifacts(ctx context.Context) ([]*ActionArtifact, error) {
arts := make([]*ActionArtifact, 0, 10)
return arts, db.GetEngine(ctx).
Where("expired_unix < ? AND status = ?", timeutil.TimeStamp(time.Now().Unix()), ArtifactStatusUploadConfirmed).Find(&arts)
}
// ListPendingDeleteArtifacts returns all artifacts in pending-delete status.
// limit is the max number of artifacts to return.
func ListPendingDeleteArtifacts(ctx context.Context, limit int) ([]*ActionArtifact, error) {
arts := make([]*ActionArtifact, 0, limit)
return arts, db.GetEngine(ctx).
Where("status = ?", ArtifactStatusPendingDeletion).Limit(limit).Find(&arts)
}
// SetArtifactExpired sets an artifact to expired
func SetArtifactExpired(ctx context.Context, artifactID int64) error {
_, err := db.GetEngine(ctx).Where("id=? AND status = ?", artifactID, ArtifactStatusUploadConfirmed).Cols("status").Update(&ActionArtifact{Status: int64(ArtifactStatusExpired)})
return err
}
// SetArtifactNeedDelete sets an artifact to need-delete; a cron job will delete it later
func SetArtifactNeedDelete(ctx context.Context, runID int64, name string) error {
_, err := db.GetEngine(ctx).Where("run_id=? AND artifact_name=? AND status = ?", runID, name, ArtifactStatusUploadConfirmed).Cols("status").Update(&ActionArtifact{Status: int64(ArtifactStatusPendingDeletion)})
return err
}
// SetArtifactDeleted sets an artifact to deleted
func SetArtifactDeleted(ctx context.Context, artifactID int64) error {
_, err := db.GetEngine(ctx).ID(artifactID).Cols("status").Update(&ActionArtifact{Status: int64(ArtifactStatusDeleted)})
return err
}
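
The artifact lifecycle above runs in one direction: CreateArtifact inserts a pending row (or refreshes the expiry of an existing one), later updates move it to confirmed, expired, or pending-deletion, and a background job drains the pending-deletion queue. A minimal sketch of such a cleanup pass, assuming it lives alongside this code in the actions package; removeFromStorage is a hypothetical helper standing in for the real storage deletion done in the service layer:

// Hypothetical cleanup pass over artifacts flagged for deletion.
// removeFromStorage is an assumed callback; the actual deletion logic lives
// outside this model package.
func cleanupPendingArtifacts(ctx context.Context, removeFromStorage func(path string) error) error {
	arts, err := ListPendingDeleteArtifacts(ctx, 100) // process in batches of 100
	if err != nil {
		return err
	}
	for _, art := range arts {
		if err := removeFromStorage(art.StoragePath); err != nil {
			return err
		}
		if err := SetArtifactDeleted(ctx, art.ID); err != nil {
			return err
		}
	}
	return nil
}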


@@ -0,0 +1,18 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"testing"
"code.gitea.io/gitea/models/unittest"
)
func TestMain(m *testing.M) {
unittest.MainTest(m, &unittest.TestOptions{
FixtureFiles: []string{
"action_runner_token.yml",
},
})
}

models/actions/run.go (new file, 437 lines added)

@@ -0,0 +1,437 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"fmt"
"slices"
"strings"
"time"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/json"
api "code.gitea.io/gitea/modules/structs"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
webhook_module "code.gitea.io/gitea/modules/webhook"
"github.com/nektos/act/pkg/jobparser"
"xorm.io/builder"
)
// ActionRun represents a run of a workflow file
type ActionRun struct {
ID int64
Title string
RepoID int64 `xorm:"index unique(repo_index)"`
Repo *repo_model.Repository `xorm:"-"`
OwnerID int64 `xorm:"index"`
WorkflowID string `xorm:"index"` // the name of workflow file
Index int64 `xorm:"index unique(repo_index)"` // a unique number for each run of a repository
TriggerUserID int64 `xorm:"index"`
TriggerUser *user_model.User `xorm:"-"`
ScheduleID int64
Ref string `xorm:"index"` // the commit/tag/… that caused the run
IsRefDeleted bool `xorm:"-"`
CommitSHA string
IsForkPullRequest bool // If this is triggered by a PR from a forked repository or an untrusted user, we need to check if it is approved and limit permissions when running the workflow.
NeedApproval bool // may need approval if it's a fork pull request
ApprovedBy int64 `xorm:"index"` // who approved
Event webhook_module.HookEventType // the webhook event that causes the workflow to run
EventPayload string `xorm:"LONGTEXT"`
TriggerEvent string // the trigger event defined in the `on` configuration of the triggered workflow
Status Status `xorm:"index"`
Version int `xorm:"version default 0"` // Status could be updated concurrently, so an optimistic lock is needed
// Started and Stopped is used for recording last run time, if rerun happened, they will be reset to 0
Started timeutil.TimeStamp
Stopped timeutil.TimeStamp
// PreviousDuration is used for recording previous duration
PreviousDuration time.Duration
Created timeutil.TimeStamp `xorm:"created"`
Updated timeutil.TimeStamp `xorm:"updated"`
}
func init() {
db.RegisterModel(new(ActionRun))
db.RegisterModel(new(ActionRunIndex))
}
func (run *ActionRun) HTMLURL() string {
if run.Repo == nil {
return ""
}
return fmt.Sprintf("%s/actions/runs/%d", run.Repo.HTMLURL(), run.Index)
}
func (run *ActionRun) Link() string {
if run.Repo == nil {
return ""
}
return fmt.Sprintf("%s/actions/runs/%d", run.Repo.Link(), run.Index)
}
func (run *ActionRun) WorkflowLink() string {
if run.Repo == nil {
return ""
}
return fmt.Sprintf("%s/actions/?workflow=%s", run.Repo.Link(), run.WorkflowID)
}
// RefLink returns the URL of the run's ref
func (run *ActionRun) RefLink() string {
refName := git.RefName(run.Ref)
if refName.IsPull() {
return run.Repo.Link() + "/pulls/" + refName.ShortName()
}
return git.RefURL(run.Repo.Link(), run.Ref)
}
// PrettyRef returns #id for a pull ref, or ShortName for others
func (run *ActionRun) PrettyRef() string {
refName := git.RefName(run.Ref)
if refName.IsPull() {
return "#" + strings.TrimSuffix(strings.TrimPrefix(run.Ref, git.PullPrefix), "/head")
}
return refName.ShortName()
}
// LoadAttributes loads Repo and TriggerUser if not already loaded
func (run *ActionRun) LoadAttributes(ctx context.Context) error {
if run == nil {
return nil
}
if err := run.LoadRepo(ctx); err != nil {
return err
}
if err := run.Repo.LoadAttributes(ctx); err != nil {
return err
}
if run.TriggerUser == nil {
u, err := user_model.GetPossibleUserByID(ctx, run.TriggerUserID)
if err != nil {
return err
}
run.TriggerUser = u
}
return nil
}
func (run *ActionRun) LoadRepo(ctx context.Context) error {
if run == nil || run.Repo != nil {
return nil
}
repo, err := repo_model.GetRepositoryByID(ctx, run.RepoID)
if err != nil {
return err
}
run.Repo = repo
return nil
}
func (run *ActionRun) Duration() time.Duration {
return calculateDuration(run.Started, run.Stopped, run.Status) + run.PreviousDuration
}
func (run *ActionRun) GetPushEventPayload() (*api.PushPayload, error) {
if run.Event == webhook_module.HookEventPush {
var payload api.PushPayload
if err := json.Unmarshal([]byte(run.EventPayload), &payload); err != nil {
return nil, err
}
return &payload, nil
}
return nil, fmt.Errorf("event %s is not a push event", run.Event)
}
func (run *ActionRun) GetPullRequestEventPayload() (*api.PullRequestPayload, error) {
if run.Event == webhook_module.HookEventPullRequest || run.Event == webhook_module.HookEventPullRequestSync {
var payload api.PullRequestPayload
if err := json.Unmarshal([]byte(run.EventPayload), &payload); err != nil {
return nil, err
}
return &payload, nil
}
return nil, fmt.Errorf("event %s is not a pull request event", run.Event)
}
func (run *ActionRun) IsSchedule() bool {
return run.ScheduleID > 0
}
func updateRepoRunsNumbers(ctx context.Context, repo *repo_model.Repository) error {
_, err := db.GetEngine(ctx).ID(repo.ID).
SetExpr("num_action_runs",
builder.Select("count(*)").From("action_run").
Where(builder.Eq{"repo_id": repo.ID}),
).
SetExpr("num_closed_action_runs",
builder.Select("count(*)").From("action_run").
Where(builder.Eq{
"repo_id": repo.ID,
}.And(
builder.In("status",
StatusSuccess,
StatusFailure,
StatusCancelled,
StatusSkipped,
),
),
),
).
Update(repo)
return err
}
// CancelPreviousJobs cancels all previous jobs of the same repository, reference, workflow, and event.
// It's useful when a new run is triggered, and all previous runs needn't be continued anymore.
func CancelPreviousJobs(ctx context.Context, repoID int64, ref, workflowID string, event webhook_module.HookEventType) error {
// Find all runs in the specified repository, reference, and workflow with non-final status
runs, total, err := db.FindAndCount[ActionRun](ctx, FindRunOptions{
RepoID: repoID,
Ref: ref,
WorkflowID: workflowID,
TriggerEvent: event,
Status: []Status{StatusRunning, StatusWaiting, StatusBlocked},
})
if err != nil {
return err
}
// If there are no runs found, there's no need to proceed with cancellation, so return nil.
if total == 0 {
return nil
}
// Iterate over each found run and cancel its associated jobs.
for _, run := range runs {
// Find all jobs associated with the current run.
jobs, err := db.Find[ActionRunJob](ctx, FindRunJobOptions{
RunID: run.ID,
})
if err != nil {
return err
}
// Iterate over each job and attempt to cancel it.
for _, job := range jobs {
// Skip jobs that are already in a terminal state (completed, cancelled, etc.).
status := job.Status
if status.IsDone() {
continue
}
// If the job has no associated task (probably an error), set its status to 'Cancelled' and stop it.
if job.TaskID == 0 {
job.Status = StatusCancelled
job.Stopped = timeutil.TimeStampNow()
// Update the job's status and stopped time in the database.
n, err := UpdateRunJob(ctx, job, builder.Eq{"task_id": 0}, "status", "stopped")
if err != nil {
return err
}
// If the update affected 0 rows, it means the job has changed in the meantime, so we need to try again.
if n == 0 {
return fmt.Errorf("job has changed, try again")
}
// Continue with the next job.
continue
}
// If the job has an associated task, try to stop the task, effectively cancelling the job.
if err := StopTask(ctx, job.TaskID, StatusCancelled); err != nil {
return err
}
}
}
// Return nil to indicate successful cancellation of all running and waiting jobs.
return nil
}
// InsertRun inserts a run.
// The title will be truncated to 255 characters if it is longer than that.
func InsertRun(ctx context.Context, run *ActionRun, jobs []*jobparser.SingleWorkflow) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
defer committer.Close()
index, err := db.GetNextResourceIndex(ctx, "action_run_index", run.RepoID)
if err != nil {
return err
}
run.Index = index
run.Title, _ = util.SplitStringAtByteN(run.Title, 255)
if err := db.Insert(ctx, run); err != nil {
return err
}
if run.Repo == nil {
repo, err := repo_model.GetRepositoryByID(ctx, run.RepoID)
if err != nil {
return err
}
run.Repo = repo
}
if err := updateRepoRunsNumbers(ctx, run.Repo); err != nil {
return err
}
runJobs := make([]*ActionRunJob, 0, len(jobs))
var hasWaiting bool
for _, v := range jobs {
id, job := v.Job()
needs := job.Needs()
if err := v.SetJob(id, job.EraseNeeds()); err != nil {
return err
}
payload, _ := v.Marshal()
status := StatusWaiting
if len(needs) > 0 || run.NeedApproval {
status = StatusBlocked
} else {
hasWaiting = true
}
job.Name, _ = util.SplitStringAtByteN(job.Name, 255)
runJobs = append(runJobs, &ActionRunJob{
RunID: run.ID,
RepoID: run.RepoID,
OwnerID: run.OwnerID,
CommitSHA: run.CommitSHA,
IsForkPullRequest: run.IsForkPullRequest,
Name: job.Name,
WorkflowPayload: payload,
JobID: id,
Needs: needs,
RunsOn: job.RunsOn(),
Status: status,
})
}
if err := db.Insert(ctx, runJobs); err != nil {
return err
}
// if there is a job in the waiting status, increase tasks version.
if hasWaiting {
if err := IncreaseTaskVersion(ctx, run.OwnerID, run.RepoID); err != nil {
return err
}
}
return committer.Commit()
}
func GetRunByID(ctx context.Context, id int64) (*ActionRun, error) {
var run ActionRun
has, err := db.GetEngine(ctx).Where("id=?", id).Get(&run)
if err != nil {
return nil, err
} else if !has {
return nil, fmt.Errorf("run with id %d: %w", id, util.ErrNotExist)
}
return &run, nil
}
func GetRunByIndex(ctx context.Context, repoID, index int64) (*ActionRun, error) {
run := &ActionRun{
RepoID: repoID,
Index: index,
}
has, err := db.GetEngine(ctx).Get(run)
if err != nil {
return nil, err
} else if !has {
return nil, fmt.Errorf("run with index %d %d: %w", repoID, index, util.ErrNotExist)
}
return run, nil
}
func GetLatestRun(ctx context.Context, repoID int64) (*ActionRun, error) {
run := &ActionRun{
RepoID: repoID,
}
has, err := db.GetEngine(ctx).Where("repo_id=?", repoID).Desc("index").Get(run)
if err != nil {
return nil, err
} else if !has {
return nil, fmt.Errorf("latest run with repo_id %d: %w", repoID, util.ErrNotExist)
}
return run, nil
}
func GetWorkflowLatestRun(ctx context.Context, repoID int64, workflowFile, branch, event string) (*ActionRun, error) {
var run ActionRun
q := db.GetEngine(ctx).Where("repo_id=?", repoID).
And("ref = ?", branch).
And("workflow_id = ?", workflowFile)
if event != "" {
q.And("event = ?", event)
}
has, err := q.Desc("id").Get(&run)
if err != nil {
return nil, err
} else if !has {
return nil, util.NewNotExistErrorf("run with repo_id %d, ref %s, workflow_id %s", repoID, branch, workflowFile)
}
return &run, nil
}
// UpdateRun updates a run.
// It requires the input run to have Version set.
// It returns an error if the version does not match (meaning the run was changed after it was loaded).
func UpdateRun(ctx context.Context, run *ActionRun, cols ...string) error {
sess := db.GetEngine(ctx).ID(run.ID)
if len(cols) > 0 {
sess.Cols(cols...)
}
run.Title, _ = util.SplitStringAtByteN(run.Title, 255)
affected, err := sess.Update(run)
if err != nil {
return err
}
if affected == 0 {
return fmt.Errorf("run has changed")
// It's impossible that the run is not found, since Gitea never deletes runs.
}
if run.Status != 0 || slices.Contains(cols, "status") {
if run.RepoID == 0 {
run, err = GetRunByID(ctx, run.ID)
if err != nil {
return err
}
}
if run.Repo == nil {
repo, err := repo_model.GetRepositoryByID(ctx, run.RepoID)
if err != nil {
return err
}
run.Repo = repo
}
if err := updateRepoRunsNumbers(ctx, run.Repo); err != nil {
return err
}
}
return nil
}
type ActionRunIndex db.ResourceIndex
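
UpdateRun relies on the Version column for optimistic locking: if another writer updated the run after it was loaded, the update affects zero rows and the caller gets a "run has changed" error. A minimal retry sketch under that assumption; markRunCancelled and the retry bound of 3 are illustrative, not part of this file:

// Sketch: cancel a run with the optimistic-lock pattern UpdateRun expects,
// reloading the run (and its Version) before each attempt.
func markRunCancelled(ctx context.Context, runID int64) error {
	for attempt := 0; attempt < 3; attempt++ {
		run, err := GetRunByID(ctx, runID) // reload to pick up the current Version
		if err != nil {
			return err
		}
		run.Status = StatusCancelled
		run.Stopped = timeutil.TimeStampNow()
		if err = UpdateRun(ctx, run, "status", "stopped"); err == nil {
			return nil
		}
		// UpdateRun reports "run has changed" when the optimistic lock fails.
		// A real caller would distinguish that from other database errors;
		// this sketch simply retries with a freshly loaded Version.
	}
	return fmt.Errorf("markRunCancelled: run %d kept changing concurrently", runID)
}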

models/actions/run_job.go (new file, 186 lines added)

@@ -0,0 +1,186 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"fmt"
"slices"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
"xorm.io/builder"
)
// ActionRunJob represents a job of a run
type ActionRunJob struct {
ID int64
RunID int64 `xorm:"index"`
Run *ActionRun `xorm:"-"`
RepoID int64 `xorm:"index"`
OwnerID int64 `xorm:"index"`
CommitSHA string `xorm:"index"`
IsForkPullRequest bool
Name string `xorm:"VARCHAR(255)"`
Attempt int64
WorkflowPayload []byte
JobID string `xorm:"VARCHAR(255)"` // job id in workflow, not job's id
Needs []string `xorm:"JSON TEXT"`
RunsOn []string `xorm:"JSON TEXT"`
TaskID int64 // the latest task of the job
Status Status `xorm:"index"`
Started timeutil.TimeStamp
Stopped timeutil.TimeStamp
Created timeutil.TimeStamp `xorm:"created"`
Updated timeutil.TimeStamp `xorm:"updated index"`
}
func init() {
db.RegisterModel(new(ActionRunJob))
}
func (job *ActionRunJob) Duration() time.Duration {
return calculateDuration(job.Started, job.Stopped, job.Status)
}
func (job *ActionRunJob) LoadRun(ctx context.Context) error {
if job.Run == nil {
run, err := GetRunByID(ctx, job.RunID)
if err != nil {
return err
}
job.Run = run
}
return nil
}
// LoadAttributes loads Run if not already loaded
func (job *ActionRunJob) LoadAttributes(ctx context.Context) error {
if job == nil {
return nil
}
if err := job.LoadRun(ctx); err != nil {
return err
}
return job.Run.LoadAttributes(ctx)
}
func GetRunJobByID(ctx context.Context, id int64) (*ActionRunJob, error) {
var job ActionRunJob
has, err := db.GetEngine(ctx).Where("id=?", id).Get(&job)
if err != nil {
return nil, err
} else if !has {
return nil, fmt.Errorf("run job with id %d: %w", id, util.ErrNotExist)
}
return &job, nil
}
func GetRunJobsByRunID(ctx context.Context, runID int64) ([]*ActionRunJob, error) {
var jobs []*ActionRunJob
if err := db.GetEngine(ctx).Where("run_id=?", runID).OrderBy("id").Find(&jobs); err != nil {
return nil, err
}
return jobs, nil
}
func UpdateRunJob(ctx context.Context, job *ActionRunJob, cond builder.Cond, cols ...string) (int64, error) {
e := db.GetEngine(ctx)
sess := e.ID(job.ID)
if len(cols) > 0 {
sess.Cols(cols...)
}
if cond != nil {
sess.Where(cond)
}
affected, err := sess.Update(job)
if err != nil {
return 0, err
}
if affected == 0 || (!slices.Contains(cols, "status") && job.Status == 0) {
return affected, nil
}
if affected != 0 && slices.Contains(cols, "status") && job.Status.IsWaiting() {
// if the status of job changes to waiting again, increase tasks version.
if err := IncreaseTaskVersion(ctx, job.OwnerID, job.RepoID); err != nil {
return 0, err
}
}
if job.RunID == 0 {
var err error
if job, err = GetRunJobByID(ctx, job.ID); err != nil {
return 0, err
}
}
{
// Other goroutines may aggregate the status of the run and update it too.
// So we need to load the run and its jobs before updating the run.
run, err := GetRunByID(ctx, job.RunID)
if err != nil {
return 0, err
}
jobs, err := GetRunJobsByRunID(ctx, job.RunID)
if err != nil {
return 0, err
}
run.Status = AggregateJobStatus(jobs)
if run.Started.IsZero() && run.Status.IsRunning() {
run.Started = timeutil.TimeStampNow()
}
if run.Stopped.IsZero() && run.Status.IsDone() {
run.Stopped = timeutil.TimeStampNow()
}
if err := UpdateRun(ctx, run, "status", "started", "stopped"); err != nil {
return 0, fmt.Errorf("update run %d: %w", run.ID, err)
}
}
return affected, nil
}
func AggregateJobStatus(jobs []*ActionRunJob) Status {
allSuccessOrSkipped := len(jobs) != 0
allSkipped := len(jobs) != 0
var hasFailure, hasCancelled, hasWaiting, hasRunning, hasBlocked bool
for _, job := range jobs {
allSuccessOrSkipped = allSuccessOrSkipped && (job.Status == StatusSuccess || job.Status == StatusSkipped)
allSkipped = allSkipped && job.Status == StatusSkipped
hasFailure = hasFailure || job.Status == StatusFailure
hasCancelled = hasCancelled || job.Status == StatusCancelled
hasWaiting = hasWaiting || job.Status == StatusWaiting
hasRunning = hasRunning || job.Status == StatusRunning
hasBlocked = hasBlocked || job.Status == StatusBlocked
}
switch {
case allSkipped:
return StatusSkipped
case allSuccessOrSkipped:
return StatusSuccess
case hasCancelled:
return StatusCancelled
case hasFailure:
return StatusFailure
case hasRunning:
return StatusRunning
case hasWaiting:
return StatusWaiting
case hasBlocked:
return StatusBlocked
default:
return StatusUnknown // it shouldn't happen
}
}
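
UpdateRunJob does two things: it applies a conditional column update to the job, and it then re-aggregates the parent run's status from all of the run's jobs. A sketch of the typical caller pattern, assuming it sits in the same package; finishJob is a hypothetical name:

// Sketch: mark a running job as finished and let UpdateRunJob propagate the
// aggregated status to the parent run. The builder.Eq guard makes the update
// conditional, so a zero "affected" count means the job was no longer running
// (already done, or changed by a concurrent writer).
func finishJob(ctx context.Context, job *ActionRunJob, result Status) error {
	job.Status = result
	job.Stopped = timeutil.TimeStampNow()
	affected, err := UpdateRunJob(ctx, job, builder.Eq{"status": StatusRunning}, "status", "stopped")
	if err != nil {
		return err
	}
	if affected == 0 {
		return fmt.Errorf("job %d was not running anymore", job.ID)
	}
	return nil
}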


@@ -0,0 +1,80 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/container"
"code.gitea.io/gitea/modules/timeutil"
"xorm.io/builder"
)
type ActionJobList []*ActionRunJob
func (jobs ActionJobList) GetRunIDs() []int64 {
return container.FilterSlice(jobs, func(j *ActionRunJob) (int64, bool) {
return j.RunID, j.RunID != 0
})
}
func (jobs ActionJobList) LoadRuns(ctx context.Context, withRepo bool) error {
runIDs := jobs.GetRunIDs()
runs := make(map[int64]*ActionRun, len(runIDs))
if err := db.GetEngine(ctx).In("id", runIDs).Find(&runs); err != nil {
return err
}
for _, j := range jobs {
if j.RunID > 0 && j.Run == nil {
j.Run = runs[j.RunID]
}
}
if withRepo {
var runsList RunList = make([]*ActionRun, 0, len(runs))
for _, r := range runs {
runsList = append(runsList, r)
}
return runsList.LoadRepos(ctx)
}
return nil
}
func (jobs ActionJobList) LoadAttributes(ctx context.Context, withRepo bool) error {
return jobs.LoadRuns(ctx, withRepo)
}
type FindRunJobOptions struct {
db.ListOptions
RunID int64
RepoID int64
OwnerID int64
CommitSHA string
Statuses []Status
UpdatedBefore timeutil.TimeStamp
}
func (opts FindRunJobOptions) ToConds() builder.Cond {
cond := builder.NewCond()
if opts.RunID > 0 {
cond = cond.And(builder.Eq{"run_id": opts.RunID})
}
if opts.RepoID > 0 {
cond = cond.And(builder.Eq{"repo_id": opts.RepoID})
}
if opts.OwnerID > 0 {
cond = cond.And(builder.Eq{"owner_id": opts.OwnerID})
}
if opts.CommitSHA != "" {
cond = cond.And(builder.Eq{"commit_sha": opts.CommitSHA})
}
if len(opts.Statuses) > 0 {
cond = cond.And(builder.In("status", opts.Statuses))
}
if opts.UpdatedBefore > 0 {
cond = cond.And(builder.Lt{"updated": opts.UpdatedBefore})
}
return cond
}
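
FindRunJobOptions plugs into the generic db.Find/db.FindAndCount helpers used throughout this commit. A sketch of listing a repository's waiting jobs and attaching their runs; waitingJobsWithRuns is illustrative:

// Sketch: list the waiting jobs of a repository and load their runs
// (and, with true, the runs' repositories as well).
func waitingJobsWithRuns(ctx context.Context, repoID int64) (ActionJobList, error) {
	jobs, err := db.Find[ActionRunJob](ctx, FindRunJobOptions{
		RepoID:   repoID,
		Statuses: []Status{StatusWaiting},
	})
	if err != nil {
		return nil, err
	}
	list := ActionJobList(jobs)
	if err := list.LoadAttributes(ctx, true); err != nil {
		return nil, err
	}
	return list, nil
}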


@@ -0,0 +1,85 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAggregateJobStatus(t *testing.T) {
testStatuses := func(expected Status, statuses []Status) {
t.Helper()
var jobs []*ActionRunJob
for _, v := range statuses {
jobs = append(jobs, &ActionRunJob{Status: v})
}
actual := AggregateJobStatus(jobs)
if !assert.Equal(t, expected, actual) {
var statusStrings []string
for _, s := range statuses {
statusStrings = append(statusStrings, s.String())
}
t.Errorf("AggregateJobStatus(%v) = %v, want %v", statusStrings, statusNames[actual], statusNames[expected])
}
}
cases := []struct {
statuses []Status
expected Status
}{
// unknown cases, maybe it shouldn't happen in real world
{[]Status{}, StatusUnknown},
{[]Status{StatusUnknown, StatusSuccess}, StatusUnknown},
{[]Status{StatusUnknown, StatusSkipped}, StatusUnknown},
{[]Status{StatusUnknown, StatusFailure}, StatusFailure},
{[]Status{StatusUnknown, StatusCancelled}, StatusCancelled},
{[]Status{StatusUnknown, StatusWaiting}, StatusWaiting},
{[]Status{StatusUnknown, StatusRunning}, StatusRunning},
{[]Status{StatusUnknown, StatusBlocked}, StatusBlocked},
// success with other status
{[]Status{StatusSuccess}, StatusSuccess},
{[]Status{StatusSuccess, StatusSkipped}, StatusSuccess}, // skipped doesn't affect success
{[]Status{StatusSuccess, StatusFailure}, StatusFailure},
{[]Status{StatusSuccess, StatusCancelled}, StatusCancelled},
{[]Status{StatusSuccess, StatusWaiting}, StatusWaiting},
{[]Status{StatusSuccess, StatusRunning}, StatusRunning},
{[]Status{StatusSuccess, StatusBlocked}, StatusBlocked},
// any cancelled, then cancelled
{[]Status{StatusCancelled}, StatusCancelled},
{[]Status{StatusCancelled, StatusSuccess}, StatusCancelled},
{[]Status{StatusCancelled, StatusSkipped}, StatusCancelled},
{[]Status{StatusCancelled, StatusFailure}, StatusCancelled},
{[]Status{StatusCancelled, StatusWaiting}, StatusCancelled},
{[]Status{StatusCancelled, StatusRunning}, StatusCancelled},
{[]Status{StatusCancelled, StatusBlocked}, StatusCancelled},
// failure with other status, fail fast
// Should "running" win? Maybe no: old code does make "running" win, but GitHub does fail fast.
{[]Status{StatusFailure}, StatusFailure},
{[]Status{StatusFailure, StatusSuccess}, StatusFailure},
{[]Status{StatusFailure, StatusSkipped}, StatusFailure},
{[]Status{StatusFailure, StatusCancelled}, StatusCancelled},
{[]Status{StatusFailure, StatusWaiting}, StatusFailure},
{[]Status{StatusFailure, StatusRunning}, StatusFailure},
{[]Status{StatusFailure, StatusBlocked}, StatusFailure},
// skipped with other status
// "all skipped" is also considered as "mergeable" by "services/actions.toCommitStatus", the same as GitHub
{[]Status{StatusSkipped}, StatusSkipped},
{[]Status{StatusSkipped, StatusSuccess}, StatusSuccess},
{[]Status{StatusSkipped, StatusFailure}, StatusFailure},
{[]Status{StatusSkipped, StatusCancelled}, StatusCancelled},
{[]Status{StatusSkipped, StatusWaiting}, StatusWaiting},
{[]Status{StatusSkipped, StatusRunning}, StatusRunning},
{[]Status{StatusSkipped, StatusBlocked}, StatusBlocked},
}
for _, c := range cases {
testStatuses(c.expected, c.statuses)
}
}

models/actions/run_list.go (new file, 138 lines added)

@@ -0,0 +1,138 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/container"
webhook_module "code.gitea.io/gitea/modules/webhook"
"xorm.io/builder"
)
type RunList []*ActionRun
// GetUserIDs returns a slice of user IDs
func (runs RunList) GetUserIDs() []int64 {
return container.FilterSlice(runs, func(run *ActionRun) (int64, bool) {
return run.TriggerUserID, true
})
}
func (runs RunList) GetRepoIDs() []int64 {
return container.FilterSlice(runs, func(run *ActionRun) (int64, bool) {
return run.RepoID, true
})
}
func (runs RunList) LoadTriggerUser(ctx context.Context) error {
userIDs := runs.GetUserIDs()
users := make(map[int64]*user_model.User, len(userIDs))
if err := db.GetEngine(ctx).In("id", userIDs).Find(&users); err != nil {
return err
}
for _, run := range runs {
if run.TriggerUserID == user_model.ActionsUserID {
run.TriggerUser = user_model.NewActionsUser()
} else {
run.TriggerUser = users[run.TriggerUserID]
if run.TriggerUser == nil {
run.TriggerUser = user_model.NewGhostUser()
}
}
}
return nil
}
func (runs RunList) LoadRepos(ctx context.Context) error {
repoIDs := runs.GetRepoIDs()
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
if err != nil {
return err
}
for _, run := range runs {
run.Repo = repos[run.RepoID]
}
return nil
}
type FindRunOptions struct {
db.ListOptions
RepoID int64
OwnerID int64
WorkflowID string
Ref string // the commit/tag/… that caused this workflow
TriggerUserID int64
TriggerEvent webhook_module.HookEventType
Approved bool // not util.OptionalBool, it works only when it's true
Status []Status
}
func (opts FindRunOptions) ToConds() builder.Cond {
cond := builder.NewCond()
if opts.RepoID > 0 {
cond = cond.And(builder.Eq{"repo_id": opts.RepoID})
}
if opts.OwnerID > 0 {
cond = cond.And(builder.Eq{"owner_id": opts.OwnerID})
}
if opts.WorkflowID != "" {
cond = cond.And(builder.Eq{"workflow_id": opts.WorkflowID})
}
if opts.TriggerUserID > 0 {
cond = cond.And(builder.Eq{"trigger_user_id": opts.TriggerUserID})
}
if opts.Approved {
cond = cond.And(builder.Gt{"approved_by": 0})
}
if len(opts.Status) > 0 {
cond = cond.And(builder.In("status", opts.Status))
}
if opts.Ref != "" {
cond = cond.And(builder.Eq{"ref": opts.Ref})
}
if opts.TriggerEvent != "" {
cond = cond.And(builder.Eq{"trigger_event": opts.TriggerEvent})
}
return cond
}
func (opts FindRunOptions) ToOrders() string {
return "`id` DESC"
}
type StatusInfo struct {
Status int
DisplayedStatus string
}
// GetStatusInfoList returns a slice of StatusInfo
func GetStatusInfoList(ctx context.Context) []StatusInfo {
// same as those in aggregateJobStatus
allStatus := []Status{StatusSuccess, StatusFailure, StatusWaiting, StatusRunning}
statusInfoList := make([]StatusInfo, 0, 4)
for _, s := range allStatus {
statusInfoList = append(statusInfoList, StatusInfo{
Status: int(s),
DisplayedStatus: s.String(),
})
}
return statusInfoList
}
// GetActors returns a slice of Actors
func GetActors(ctx context.Context, repoID int64) ([]*user_model.User, error) {
actors := make([]*user_model.User, 0, 10)
return actors, db.GetEngine(ctx).Where(builder.In("id", builder.Select("`action_run`.trigger_user_id").From("`action_run`").
GroupBy("`action_run`.trigger_user_id").
Where(builder.Eq{"`action_run`.repo_id": repoID}))).
Cols("id", "name", "full_name", "avatar", "avatar_email", "use_custom_avatar").
OrderBy(user_model.GetOrderByName()).
Find(&actors)
}
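
FindRunOptions works the same way, and RunList then resolves trigger users and repositories for display. A sketch of a paged listing, assuming the Page/PageSize fields of db.ListOptions as used elsewhere in Gitea; listRepoRuns is a hypothetical helper:

// Sketch: page through a repository's runs the way a list view would,
// then attach trigger users and repositories.
func listRepoRuns(ctx context.Context, repoID int64, page int) (RunList, int64, error) {
	runs, total, err := db.FindAndCount[ActionRun](ctx, FindRunOptions{
		ListOptions: db.ListOptions{Page: page, PageSize: 20},
		RepoID:      repoID,
	})
	if err != nil {
		return nil, 0, err
	}
	list := RunList(runs)
	if err := list.LoadTriggerUser(ctx); err != nil {
		return nil, 0, err
	}
	if err := list.LoadRepos(ctx); err != nil {
		return nil, 0, err
	}
	return list, total, nil
}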

models/actions/runner.go (new file, 330 lines added)

@@ -0,0 +1,330 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"fmt"
"strings"
"time"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/shared/types"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/optional"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/translation"
"code.gitea.io/gitea/modules/util"
runnerv1 "code.gitea.io/actions-proto-go/runner/v1"
"xorm.io/builder"
)
// ActionRunner represents runner machines
//
// It can be:
// 1. global runner, OwnerID is 0 and RepoID is 0
// 2. org/user level runner, OwnerID is org/user ID and RepoID is 0
// 3. repo level runner, OwnerID is 0 and RepoID is repo ID
//
// Please note that it's not acceptable to have both OwnerID and RepoID to be non-zero,
// or it will be complicated to find runners belonging to a specific owner.
// For example, conditions like `OwnerID = 1` will also return runner {OwnerID: 1, RepoID: 1},
// but it's a repo level runner, not an org/user level runner.
// To avoid this, make it clear with {OwnerID: 0, RepoID: 1} for repo level runners.
type ActionRunner struct {
ID int64
UUID string `xorm:"CHAR(36) UNIQUE"`
Name string `xorm:"VARCHAR(255)"`
Version string `xorm:"VARCHAR(64)"`
OwnerID int64 `xorm:"index"`
Owner *user_model.User `xorm:"-"`
RepoID int64 `xorm:"index"`
Repo *repo_model.Repository `xorm:"-"`
Description string `xorm:"TEXT"`
Base int // 0 native 1 docker 2 virtual machine
RepoRange string // glob match which repositories could use this runner
Token string `xorm:"-"`
TokenHash string `xorm:"UNIQUE"` // sha256 of token
TokenSalt string
// TokenLastEight string `xorm:"token_last_eight"` // it's unnecessary because we don't find runners by token
LastOnline timeutil.TimeStamp `xorm:"index"`
LastActive timeutil.TimeStamp `xorm:"index"`
// Store labels defined in state file (default: .runner file) of `act_runner`
AgentLabels []string `xorm:"TEXT"`
Created timeutil.TimeStamp `xorm:"created"`
Updated timeutil.TimeStamp `xorm:"updated"`
Deleted timeutil.TimeStamp `xorm:"deleted"`
}
const (
RunnerOfflineTime = time.Minute
RunnerIdleTime = 10 * time.Second
)
// BelongsToOwnerName returns the name of the runner's owner; before calling it, the caller should guarantee that all attributes are loaded
func (r *ActionRunner) BelongsToOwnerName() string {
if r.RepoID != 0 {
return r.Repo.FullName()
}
if r.OwnerID != 0 {
return r.Owner.Name
}
return ""
}
func (r *ActionRunner) BelongsToOwnerType() types.OwnerType {
if r.RepoID != 0 {
return types.OwnerTypeRepository
}
if r.OwnerID != 0 {
if r.Owner.Type == user_model.UserTypeOrganization {
return types.OwnerTypeOrganization
} else if r.Owner.Type == user_model.UserTypeIndividual {
return types.OwnerTypeIndividual
}
}
return types.OwnerTypeSystemGlobal
}
// If the logic here changes, FindRunnerOptions.ToConds should be modified as well
func (r *ActionRunner) Status() runnerv1.RunnerStatus {
if time.Since(r.LastOnline.AsTime()) > RunnerOfflineTime {
return runnerv1.RunnerStatus_RUNNER_STATUS_OFFLINE
}
if time.Since(r.LastActive.AsTime()) > RunnerIdleTime {
return runnerv1.RunnerStatus_RUNNER_STATUS_IDLE
}
return runnerv1.RunnerStatus_RUNNER_STATUS_ACTIVE
}
func (r *ActionRunner) StatusName() string {
return strings.ToLower(strings.TrimPrefix(r.Status().String(), "RUNNER_STATUS_"))
}
func (r *ActionRunner) StatusLocaleName(lang translation.Locale) string {
return lang.TrString("actions.runners.status." + r.StatusName())
}
func (r *ActionRunner) IsOnline() bool {
status := r.Status()
if status == runnerv1.RunnerStatus_RUNNER_STATUS_IDLE || status == runnerv1.RunnerStatus_RUNNER_STATUS_ACTIVE {
return true
}
return false
}
// Editable checks if the runner is editable by the user
func (r *ActionRunner) Editable(ownerID, repoID int64) bool {
if ownerID == 0 && repoID == 0 {
return true
}
if ownerID > 0 && r.OwnerID == ownerID {
return true
}
return repoID > 0 && r.RepoID == repoID
}
// LoadAttributes loads the attributes of the runner
func (r *ActionRunner) LoadAttributes(ctx context.Context) error {
if r.OwnerID > 0 {
var user user_model.User
has, err := db.GetEngine(ctx).ID(r.OwnerID).Get(&user)
if err != nil {
return err
}
if has {
r.Owner = &user
}
}
if r.RepoID > 0 {
var repo repo_model.Repository
has, err := db.GetEngine(ctx).ID(r.RepoID).Get(&repo)
if err != nil {
return err
}
if has {
r.Repo = &repo
}
}
return nil
}
func (r *ActionRunner) GenerateToken() (err error) {
r.Token, r.TokenSalt, r.TokenHash, _, err = generateSaltedToken()
return err
}
func init() {
db.RegisterModel(&ActionRunner{})
}
type FindRunnerOptions struct {
db.ListOptions
RepoID int64
OwnerID int64 // it will be ignored if RepoID is set
Sort string
Filter string
IsOnline optional.Option[bool]
WithAvailable bool // not only runners belong to, but also runners can be used
}
func (opts FindRunnerOptions) ToConds() builder.Cond {
cond := builder.NewCond()
if opts.RepoID > 0 {
c := builder.NewCond().And(builder.Eq{"repo_id": opts.RepoID})
if opts.WithAvailable {
c = c.Or(builder.Eq{"owner_id": builder.Select("owner_id").From("repository").Where(builder.Eq{"id": opts.RepoID})})
c = c.Or(builder.Eq{"repo_id": 0, "owner_id": 0})
}
cond = cond.And(c)
} else if opts.OwnerID > 0 { // OwnerID is ignored if RepoID is set
c := builder.NewCond().And(builder.Eq{"owner_id": opts.OwnerID})
if opts.WithAvailable {
c = c.Or(builder.Eq{"repo_id": 0, "owner_id": 0})
}
cond = cond.And(c)
}
if opts.Filter != "" {
cond = cond.And(builder.Like{"name", opts.Filter})
}
if opts.IsOnline.Has() {
if opts.IsOnline.Value() {
cond = cond.And(builder.Gt{"last_online": time.Now().Add(-RunnerOfflineTime).Unix()})
} else {
cond = cond.And(builder.Lte{"last_online": time.Now().Add(-RunnerOfflineTime).Unix()})
}
}
return cond
}
func (opts FindRunnerOptions) ToOrders() string {
switch opts.Sort {
case "online":
return "last_online DESC"
case "offline":
return "last_online ASC"
case "alphabetically":
return "name ASC"
case "reversealphabetically":
return "name DESC"
case "newest":
return "id DESC"
case "oldest":
return "id ASC"
}
return "last_online DESC"
}
// GetRunnerByUUID returns a runner via uuid
func GetRunnerByUUID(ctx context.Context, uuid string) (*ActionRunner, error) {
var runner ActionRunner
has, err := db.GetEngine(ctx).Where("uuid=?", uuid).Get(&runner)
if err != nil {
return nil, err
} else if !has {
return nil, fmt.Errorf("runner with uuid %s: %w", uuid, util.ErrNotExist)
}
return &runner, nil
}
// GetRunnerByID returns a runner via id
func GetRunnerByID(ctx context.Context, id int64) (*ActionRunner, error) {
var runner ActionRunner
has, err := db.GetEngine(ctx).Where("id=?", id).Get(&runner)
if err != nil {
return nil, err
} else if !has {
return nil, fmt.Errorf("runner with id %d: %w", id, util.ErrNotExist)
}
return &runner, nil
}
// UpdateRunner updates runner's information.
func UpdateRunner(ctx context.Context, r *ActionRunner, cols ...string) error {
e := db.GetEngine(ctx)
r.Name, _ = util.SplitStringAtByteN(r.Name, 255)
var err error
if len(cols) == 0 {
_, err = e.ID(r.ID).AllCols().Update(r)
} else {
_, err = e.ID(r.ID).Cols(cols...).Update(r)
}
return err
}
// DeleteRunner deletes a runner by given ID.
func DeleteRunner(ctx context.Context, id int64) error {
if _, err := GetRunnerByID(ctx, id); err != nil {
return err
}
_, err := db.DeleteByID[ActionRunner](ctx, id)
return err
}
// CreateRunner creates new runner.
func CreateRunner(ctx context.Context, t *ActionRunner) error {
if t.OwnerID != 0 && t.RepoID != 0 {
// It's trying to create a runner that belongs to a repository, but OwnerID has been set accidentally.
// Remove OwnerID to avoid confusion; it's not worth returning an error here.
t.OwnerID = 0
}
t.Name, _ = util.SplitStringAtByteN(t.Name, 255)
return db.Insert(ctx, t)
}
func CountRunnersWithoutBelongingOwner(ctx context.Context) (int64, error) {
// Only consider action runners where an owner ID is set, as action runners
// can also be created on a repository.
return db.GetEngine(ctx).Table("action_runner").
Join("LEFT", "`user`", "`action_runner`.owner_id = `user`.id").
Where("`action_runner`.owner_id != ?", 0).
And(builder.IsNull{"`user`.id"}).
Count(new(ActionRunner))
}
func FixRunnersWithoutBelongingOwner(ctx context.Context) (int64, error) {
subQuery := builder.Select("`action_runner`.id").
From("`action_runner`").
Join("LEFT", "`user`", "`action_runner`.owner_id = `user`.id").
Where(builder.Neq{"`action_runner`.owner_id": 0}).
And(builder.IsNull{"`user`.id"})
b := builder.Delete(builder.In("id", subQuery)).From("`action_runner`")
res, err := db.GetEngine(ctx).Exec(b)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
func CountRunnersWithoutBelongingRepo(ctx context.Context) (int64, error) {
return db.GetEngine(ctx).Table("action_runner").
Join("LEFT", "`repository`", "`action_runner`.repo_id = `repository`.id").
Where("`action_runner`.repo_id != ?", 0).
And(builder.IsNull{"`repository`.id"}).
Count(new(ActionRunner))
}
func FixRunnersWithoutBelongingRepo(ctx context.Context) (int64, error) {
subQuery := builder.Select("`action_runner`.id").
From("`action_runner`").
Join("LEFT", "`repository`", "`action_runner`.repo_id = `repository`.id").
Where(builder.Neq{"`action_runner`.repo_id": 0}).
And(builder.IsNull{"`repository`.id"})
b := builder.Delete(builder.In("id", subQuery)).From("`action_runner`")
res, err := db.GetEngine(ctx).Exec(b)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
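
With WithAvailable set, the conditions above return not only the runners registered on the repository itself but also its owner's runners and global runners. A sketch of that query, assuming optional.Some as the constructor for optional.Option values; runnersAvailableToRepo is illustrative:

// Sketch: find every runner a repository may use (its own, its owner's, and
// global runners), restricted to runners that currently look online.
func runnersAvailableToRepo(ctx context.Context, repoID int64) ([]*ActionRunner, error) {
	return db.Find[ActionRunner](ctx, FindRunnerOptions{
		RepoID:        repoID,
		WithAvailable: true,
		IsOnline:      optional.Some(true),
	})
}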


@@ -0,0 +1,65 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/container"
)
type RunnerList []*ActionRunner
// GetUserIDs returns a slice of user IDs
func (runners RunnerList) GetUserIDs() []int64 {
return container.FilterSlice(runners, func(runner *ActionRunner) (int64, bool) {
return runner.OwnerID, runner.OwnerID != 0
})
}
func (runners RunnerList) LoadOwners(ctx context.Context) error {
userIDs := runners.GetUserIDs()
users := make(map[int64]*user_model.User, len(userIDs))
if err := db.GetEngine(ctx).In("id", userIDs).Find(&users); err != nil {
return err
}
for _, runner := range runners {
if runner.OwnerID > 0 && runner.Owner == nil {
runner.Owner = users[runner.OwnerID]
}
}
return nil
}
func (runners RunnerList) getRepoIDs() []int64 {
return container.FilterSlice(runners, func(runner *ActionRunner) (int64, bool) {
return runner.RepoID, runner.RepoID > 0
})
}
func (runners RunnerList) LoadRepos(ctx context.Context) error {
repoIDs := runners.getRepoIDs()
repos := make(map[int64]*repo_model.Repository, len(repoIDs))
if err := db.GetEngine(ctx).In("id", repoIDs).Find(&repos); err != nil {
return err
}
for _, runner := range runners {
if runner.RepoID > 0 && runner.Repo == nil {
runner.Repo = repos[runner.RepoID]
}
}
return nil
}
func (runners RunnerList) LoadAttributes(ctx context.Context) error {
if err := runners.LoadOwners(ctx); err != nil {
return err
}
return runners.LoadRepos(ctx)
}


@@ -0,0 +1,125 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"fmt"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
)
// ActionRunnerToken represents runner tokens
//
// It can be:
// 1. global token, OwnerID is 0 and RepoID is 0
// 2. org/user level token, OwnerID is org/user ID and RepoID is 0
// 3. repo level token, OwnerID is 0 and RepoID is repo ID
//
// Please note that it's not acceptable to have both OwnerID and RepoID to be non-zero,
// or it will be complicated to find tokens belonging to a specific owner.
// For example, conditions like `OwnerID = 1` will also return token {OwnerID: 1, RepoID: 1},
// but it's a repo level token, not an org/user level token.
// To avoid this, make it clear with {OwnerID: 0, RepoID: 1} for repo level tokens.
type ActionRunnerToken struct {
ID int64
Token string `xorm:"UNIQUE"`
OwnerID int64 `xorm:"index"`
Owner *user_model.User `xorm:"-"`
RepoID int64 `xorm:"index"`
Repo *repo_model.Repository `xorm:"-"`
IsActive bool // true means it can be used
Created timeutil.TimeStamp `xorm:"created"`
Updated timeutil.TimeStamp `xorm:"updated"`
Deleted timeutil.TimeStamp `xorm:"deleted"`
}
func init() {
db.RegisterModel(new(ActionRunnerToken))
}
// GetRunnerToken returns an action runner token by the token value
func GetRunnerToken(ctx context.Context, token string) (*ActionRunnerToken, error) {
var runnerToken ActionRunnerToken
has, err := db.GetEngine(ctx).Where("token=?", token).Get(&runnerToken)
if err != nil {
return nil, err
} else if !has {
return nil, fmt.Errorf(`runner token "%s...": %w`, base.TruncateString(token, 3), util.ErrNotExist)
}
return &runnerToken, nil
}
// UpdateRunnerToken updates runner token information.
func UpdateRunnerToken(ctx context.Context, r *ActionRunnerToken, cols ...string) (err error) {
e := db.GetEngine(ctx)
if len(cols) == 0 {
_, err = e.ID(r.ID).AllCols().Update(r)
} else {
_, err = e.ID(r.ID).Cols(cols...).Update(r)
}
return err
}
// NewRunnerTokenWithValue creates a new active runner token and invalidates all old tokens
// ownerID will be ignored and treated as 0 if repoID is non-zero.
func NewRunnerTokenWithValue(ctx context.Context, ownerID, repoID int64, token string) (*ActionRunnerToken, error) {
if ownerID != 0 && repoID != 0 {
// It's trying to create a runner token that belongs to a repository, but OwnerID has been set accidentally.
// Remove OwnerID to avoid confusion; it's not worth returning an error here.
ownerID = 0
}
runnerToken := &ActionRunnerToken{
OwnerID: ownerID,
RepoID: repoID,
IsActive: true,
Token: token,
}
return runnerToken, db.WithTx(ctx, func(ctx context.Context) error {
if _, err := db.GetEngine(ctx).Where("owner_id =? AND repo_id = ?", ownerID, repoID).Cols("is_active").Update(&ActionRunnerToken{
IsActive: false,
}); err != nil {
return err
}
_, err := db.GetEngine(ctx).Insert(runnerToken)
return err
})
}
func NewRunnerToken(ctx context.Context, ownerID, repoID int64) (*ActionRunnerToken, error) {
token, err := util.CryptoRandomString(40)
if err != nil {
return nil, err
}
return NewRunnerTokenWithValue(ctx, ownerID, repoID, token)
}
// GetLatestRunnerToken returns the latest runner token
func GetLatestRunnerToken(ctx context.Context, ownerID, repoID int64) (*ActionRunnerToken, error) {
if ownerID != 0 && repoID != 0 {
// It's trying to get a runner token that belongs to a repository, but OwnerID has been set accidentally.
// Remove OwnerID to avoid confusion; it's not worth returning an error here.
ownerID = 0
}
var runnerToken ActionRunnerToken
has, err := db.GetEngine(ctx).Where("owner_id=? AND repo_id=?", ownerID, repoID).
OrderBy("id DESC").Get(&runnerToken)
if err != nil {
return nil, err
} else if !has {
return nil, fmt.Errorf("runner token: %w", util.ErrNotExist)
}
return &runnerToken, nil
}
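
NewRunnerTokenWithValue rotates tokens within a scope: it deactivates every previous token for the same owner/repo pair and inserts the new one inside a single transaction, and GetLatestRunnerToken then returns the active token. A small sketch of owner-level rotation; rotateOwnerToken is a hypothetical wrapper:

// Sketch: rotate the registration token of an org/user scope and return the
// fresh value. Passing repoID = 0 keeps this an owner-level token.
func rotateOwnerToken(ctx context.Context, ownerID int64) (string, error) {
	t, err := NewRunnerToken(ctx, ownerID, 0) // deactivates previous tokens in the same scope
	if err != nil {
		return "", err
	}
	return t.Token, nil
}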


@@ -0,0 +1,40 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"github.com/stretchr/testify/assert"
)
func TestGetLatestRunnerToken(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token := unittest.AssertExistsAndLoadBean(t, &ActionRunnerToken{ID: 3})
expectedToken, err := GetLatestRunnerToken(db.DefaultContext, 1, 0)
assert.NoError(t, err)
assert.EqualValues(t, expectedToken, token)
}
func TestNewRunnerToken(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token, err := NewRunnerToken(db.DefaultContext, 1, 0)
assert.NoError(t, err)
expectedToken, err := GetLatestRunnerToken(db.DefaultContext, 1, 0)
assert.NoError(t, err)
assert.EqualValues(t, expectedToken, token)
}
func TestUpdateRunnerToken(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token := unittest.AssertExistsAndLoadBean(t, &ActionRunnerToken{ID: 3})
token.IsActive = true
assert.NoError(t, UpdateRunnerToken(db.DefaultContext, token))
expectedToken, err := GetLatestRunnerToken(db.DefaultContext, 1, 0)
assert.NoError(t, err)
assert.EqualValues(t, expectedToken, token)
}

models/actions/schedule.go (new file, 140 lines added)

@@ -0,0 +1,140 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"fmt"
"time"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
webhook_module "code.gitea.io/gitea/modules/webhook"
)
// ActionSchedule represents a schedule of a workflow file
type ActionSchedule struct {
ID int64
Title string
Specs []string
RepoID int64 `xorm:"index"`
Repo *repo_model.Repository `xorm:"-"`
OwnerID int64 `xorm:"index"`
WorkflowID string
TriggerUserID int64
TriggerUser *user_model.User `xorm:"-"`
Ref string
CommitSHA string
Event webhook_module.HookEventType
EventPayload string `xorm:"LONGTEXT"`
Content []byte
Created timeutil.TimeStamp `xorm:"created"`
Updated timeutil.TimeStamp `xorm:"updated"`
}
func init() {
db.RegisterModel(new(ActionSchedule))
}
// GetSchedulesMapByIDs returns the schedules by given id slice.
func GetSchedulesMapByIDs(ctx context.Context, ids []int64) (map[int64]*ActionSchedule, error) {
schedules := make(map[int64]*ActionSchedule, len(ids))
return schedules, db.GetEngine(ctx).In("id", ids).Find(&schedules)
}
// GetReposMapByIDs returns the repos by given id slice.
func GetReposMapByIDs(ctx context.Context, ids []int64) (map[int64]*repo_model.Repository, error) {
repos := make(map[int64]*repo_model.Repository, len(ids))
return repos, db.GetEngine(ctx).In("id", ids).Find(&repos)
}
// CreateScheduleTask creates new schedule task.
func CreateScheduleTask(ctx context.Context, rows []*ActionSchedule) error {
// Return early if there are no rows to insert
if len(rows) == 0 {
return nil
}
// Begin transaction
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
defer committer.Close()
// Loop through each schedule row
for _, row := range rows {
row.Title, _ = util.SplitStringAtByteN(row.Title, 255)
// Create new schedule row
if err = db.Insert(ctx, row); err != nil {
return err
}
// Loop through each schedule spec and create a new spec row
now := time.Now()
for _, spec := range row.Specs {
specRow := &ActionScheduleSpec{
RepoID: row.RepoID,
ScheduleID: row.ID,
Spec: spec,
}
// Parse the spec and check for errors
schedule, err := specRow.Parse()
if err != nil {
continue // skip to the next spec if there's an error
}
specRow.Next = timeutil.TimeStamp(schedule.Next(now).Unix())
// Insert the new schedule spec row
if err = db.Insert(ctx, specRow); err != nil {
return err
}
}
}
// Commit transaction
return committer.Commit()
}
func DeleteScheduleTaskByRepo(ctx context.Context, id int64) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
defer committer.Close()
if _, err := db.GetEngine(ctx).Delete(&ActionSchedule{RepoID: id}); err != nil {
return err
}
if _, err := db.GetEngine(ctx).Delete(&ActionScheduleSpec{RepoID: id}); err != nil {
return err
}
return committer.Commit()
}
func CleanRepoScheduleTasks(ctx context.Context, repo *repo_model.Repository) error {
// If Actions is disabled while schedule tasks exist, this removes the outdated schedule tasks.
// There is no other place to do this, because app.ini is changed manually.
if err := DeleteScheduleTaskByRepo(ctx, repo.ID); err != nil {
return fmt.Errorf("DeleteCronTaskByRepo: %v", err)
}
// cancel running cron jobs of this repository and delete old schedules
if err := CancelPreviousJobs(
ctx,
repo.ID,
repo.DefaultBranch,
"",
webhook_module.HookEventSchedule,
); err != nil {
return fmt.Errorf("CancelPreviousJobs: %v", err)
}
return nil
}
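
CreateScheduleTask persists one ActionSchedule row per workflow plus one ActionScheduleSpec row per cron expression, computing each spec's next trigger time and silently skipping specs that fail to parse. A sketch of feeding it a single nightly schedule; saveNightlySchedule, the cron expression, and the workflow metadata are illustrative assumptions, not part of this commit:

// Sketch: persist one schedule entry as a workflow's "on: schedule" section
// might define it. CreateScheduleTask parses each spec and fills its Next time.
func saveNightlySchedule(ctx context.Context, repo *repo_model.Repository, workflowID, ref, sha string) error {
	return CreateScheduleTask(ctx, []*ActionSchedule{{
		Title:         "nightly build",
		Specs:         []string{"0 2 * * *"}, // illustrative cron expression
		RepoID:        repo.ID,
		OwnerID:       repo.OwnerID,
		WorkflowID:    workflowID,
		TriggerUserID: user_model.ActionsUserID,
		Ref:           ref,
		CommitSHA:     sha,
		Event:         webhook_module.HookEventSchedule,
	}})
}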


@@ -0,0 +1,83 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/container"
"xorm.io/builder"
)
type ScheduleList []*ActionSchedule
// GetUserIDs returns a slice of user IDs
func (schedules ScheduleList) GetUserIDs() []int64 {
return container.FilterSlice(schedules, func(schedule *ActionSchedule) (int64, bool) {
return schedule.TriggerUserID, true
})
}
func (schedules ScheduleList) GetRepoIDs() []int64 {
return container.FilterSlice(schedules, func(schedule *ActionSchedule) (int64, bool) {
return schedule.RepoID, true
})
}
func (schedules ScheduleList) LoadTriggerUser(ctx context.Context) error {
userIDs := schedules.GetUserIDs()
users := make(map[int64]*user_model.User, len(userIDs))
if err := db.GetEngine(ctx).In("id", userIDs).Find(&users); err != nil {
return err
}
for _, schedule := range schedules {
if schedule.TriggerUserID == user_model.ActionsUserID {
schedule.TriggerUser = user_model.NewActionsUser()
} else {
schedule.TriggerUser = users[schedule.TriggerUserID]
if schedule.TriggerUser == nil {
schedule.TriggerUser = user_model.NewGhostUser()
}
}
}
return nil
}
func (schedules ScheduleList) LoadRepos(ctx context.Context) error {
repoIDs := schedules.GetRepoIDs()
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
if err != nil {
return err
}
for _, schedule := range schedules {
schedule.Repo = repos[schedule.RepoID]
}
return nil
}
type FindScheduleOptions struct {
db.ListOptions
RepoID int64
OwnerID int64
}
func (opts FindScheduleOptions) ToConds() builder.Cond {
cond := builder.NewCond()
if opts.RepoID > 0 {
cond = cond.And(builder.Eq{"repo_id": opts.RepoID})
}
if opts.OwnerID > 0 {
cond = cond.And(builder.Eq{"owner_id": opts.OwnerID})
}
return cond
}
func (opts FindScheduleOptions) ToOrders() string {
return "`id` DESC"
}


@@ -0,0 +1,73 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"strings"
"time"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/modules/timeutil"
"github.com/robfig/cron/v3"
)
// ActionScheduleSpec represents a schedule spec of a workflow file
type ActionScheduleSpec struct {
ID int64
RepoID int64 `xorm:"index"`
Repo *repo_model.Repository `xorm:"-"`
ScheduleID int64 `xorm:"index"`
Schedule *ActionSchedule `xorm:"-"`
// Next time the job will run, or the zero time if Cron has not been
// started or this entry's schedule is unsatisfiable
Next timeutil.TimeStamp `xorm:"index"`
// Prev is the last time this job was run, or the zero time if never.
Prev timeutil.TimeStamp
Spec string
Created timeutil.TimeStamp `xorm:"created"`
Updated timeutil.TimeStamp `xorm:"updated"`
}
// Parse parses the spec and returns a cron.Schedule
// Unlike the default cron parser, Parse uses UTC timezone as the default if none is specified.
func (s *ActionScheduleSpec) Parse() (cron.Schedule, error) {
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor)
schedule, err := parser.Parse(s.Spec)
if err != nil {
return nil, err
}
// If the spec has specified a timezone, use it
if strings.HasPrefix(s.Spec, "TZ=") || strings.HasPrefix(s.Spec, "CRON_TZ=") {
return schedule, nil
}
specSchedule, ok := schedule.(*cron.SpecSchedule)
// If it's not a spec schedule, like "@every 5m", timezone is not relevant
if !ok {
return schedule, nil
}
// Set the timezone to UTC
specSchedule.Location = time.UTC
return specSchedule, nil
}
func init() {
db.RegisterModel(new(ActionScheduleSpec))
}
func UpdateScheduleSpec(ctx context.Context, spec *ActionScheduleSpec, cols ...string) error {
sess := db.GetEngine(ctx).ID(spec.ID)
if len(cols) > 0 {
sess.Cols(cols...)
}
_, err := sess.Update(spec)
return err
}
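
Parse defaults to UTC unless the spec carries an explicit TZ= or CRON_TZ= prefix, which is what keeps Next deterministic across servers in different local timezones. A sketch of computing the next trigger time the same way CreateScheduleTask fills the Next column; nextRunTime is illustrative:

// Sketch: compute the next trigger time for a stored spec. The returned
// schedule is evaluated in UTC unless the spec names its own timezone.
func nextRunTime(spec *ActionScheduleSpec) (timeutil.TimeStamp, error) {
	schedule, err := spec.Parse()
	if err != nil {
		return 0, err
	}
	return timeutil.TimeStamp(schedule.Next(time.Now()).Unix()), nil
}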


@@ -0,0 +1,97 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/modules/container"
"xorm.io/builder"
)
type SpecList []*ActionScheduleSpec
func (specs SpecList) GetScheduleIDs() []int64 {
return container.FilterSlice(specs, func(spec *ActionScheduleSpec) (int64, bool) {
return spec.ScheduleID, true
})
}
func (specs SpecList) LoadSchedules(ctx context.Context) error {
scheduleIDs := specs.GetScheduleIDs()
schedules, err := GetSchedulesMapByIDs(ctx, scheduleIDs)
if err != nil {
return err
}
for _, spec := range specs {
spec.Schedule = schedules[spec.ScheduleID]
}
repoIDs := specs.GetRepoIDs()
repos, err := GetReposMapByIDs(ctx, repoIDs)
if err != nil {
return err
}
for _, spec := range specs {
spec.Repo = repos[spec.RepoID]
}
return nil
}
func (specs SpecList) GetRepoIDs() []int64 {
return container.FilterSlice(specs, func(spec *ActionScheduleSpec) (int64, bool) {
return spec.RepoID, true
})
}
func (specs SpecList) LoadRepos(ctx context.Context) error {
repoIDs := specs.GetRepoIDs()
repos, err := repo_model.GetRepositoriesMapByIDs(ctx, repoIDs)
if err != nil {
return err
}
for _, spec := range specs {
spec.Repo = repos[spec.RepoID]
}
return nil
}
type FindSpecOptions struct {
db.ListOptions
RepoID int64
Next int64
}
func (opts FindSpecOptions) ToConds() builder.Cond {
cond := builder.NewCond()
if opts.RepoID > 0 {
cond = cond.And(builder.Eq{"repo_id": opts.RepoID})
}
if opts.Next > 0 {
cond = cond.And(builder.Lte{"next": opts.Next})
}
return cond
}
func (opts FindSpecOptions) ToOrders() string {
return "`id` DESC"
}
func FindSpecs(ctx context.Context, opts FindSpecOptions) (SpecList, int64, error) {
specs, total, err := db.FindAndCount[ActionScheduleSpec](ctx, opts)
if err != nil {
return nil, 0, err
}
if err := SpecList(specs).LoadSchedules(ctx); err != nil {
return nil, 0, err
}
return specs, total, nil
}

View File

@@ -0,0 +1,69 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"testing"
"time"
"code.gitea.io/gitea/modules/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestActionScheduleSpec_Parse(t *testing.T) {
// Mock the local timezone so that it is not UTC
tz, err := time.LoadLocation("Asia/Shanghai")
require.NoError(t, err)
defer test.MockVariableValue(&time.Local, tz)()
now, err := time.Parse(time.RFC3339, "2024-07-31T15:47:55+08:00")
require.NoError(t, err)
tests := []struct {
name string
spec string
want string
wantErr assert.ErrorAssertionFunc
}{
{
name: "regular",
spec: "0 10 * * *",
want: "2024-07-31T10:00:00Z",
wantErr: assert.NoError,
},
{
name: "invalid",
spec: "0 10 * *",
want: "",
wantErr: assert.Error,
},
{
name: "with timezone",
spec: "TZ=America/New_York 0 10 * * *",
want: "2024-07-31T14:00:00Z",
wantErr: assert.NoError,
},
{
name: "timezone irrelevant",
spec: "@every 5m",
want: "2024-07-31T07:52:55Z",
wantErr: assert.NoError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &ActionScheduleSpec{
Spec: tt.spec,
}
got, err := s.Parse()
tt.wantErr(t, err)
if err == nil {
assert.Equal(t, tt.want, got.Next(now).UTC().Format(time.RFC3339))
}
})
}
}

104
models/actions/status.go Normal file
View File

@@ -0,0 +1,104 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"code.gitea.io/gitea/modules/translation"
runnerv1 "code.gitea.io/actions-proto-go/runner/v1"
)
// Status represents the status of ActionRun, ActionRunJob, ActionTask, or ActionTaskStep
type Status int
const (
StatusUnknown Status = iota // 0, consistent with runnerv1.Result_RESULT_UNSPECIFIED
StatusSuccess // 1, consistent with runnerv1.Result_RESULT_SUCCESS
StatusFailure // 2, consistent with runnerv1.Result_RESULT_FAILURE
StatusCancelled // 3, consistent with runnerv1.Result_RESULT_CANCELLED
StatusSkipped // 4, consistent with runnerv1.Result_RESULT_SKIPPED
StatusWaiting // 5, isn't a runnerv1.Result
StatusRunning // 6, isn't a runnerv1.Result
StatusBlocked // 7, isn't a runnerv1.Result
)
var statusNames = map[Status]string{
StatusUnknown: "unknown",
StatusWaiting: "waiting",
StatusRunning: "running",
StatusSuccess: "success",
StatusFailure: "failure",
StatusCancelled: "cancelled",
StatusSkipped: "skipped",
StatusBlocked: "blocked",
}
// String returns the string name of the Status
func (s Status) String() string {
return statusNames[s]
}
// LocaleString returns the locale string name of the Status
func (s Status) LocaleString(lang translation.Locale) string {
return lang.TrString("actions.status." + s.String())
}
// IsDone returns whether the Status is final
func (s Status) IsDone() bool {
return s.In(StatusSuccess, StatusFailure, StatusCancelled, StatusSkipped)
}
// HasRun returns whether the Status is a result of running
func (s Status) HasRun() bool {
return s.In(StatusSuccess, StatusFailure)
}
func (s Status) IsUnknown() bool {
return s == StatusUnknown
}
func (s Status) IsSuccess() bool {
return s == StatusSuccess
}
func (s Status) IsFailure() bool {
return s == StatusFailure
}
func (s Status) IsCancelled() bool {
return s == StatusCancelled
}
func (s Status) IsSkipped() bool {
return s == StatusSkipped
}
func (s Status) IsWaiting() bool {
return s == StatusWaiting
}
func (s Status) IsRunning() bool {
return s == StatusRunning
}
func (s Status) IsBlocked() bool {
return s == StatusBlocked
}
// In returns whether s is one of the given statuses
func (s Status) In(statuses ...Status) bool {
for _, v := range statuses {
if s == v {
return true
}
}
return false
}
func (s Status) AsResult() runnerv1.Result {
if s.IsDone() {
return runnerv1.Result(s)
}
return runnerv1.Result_RESULT_UNSPECIFIED
}
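// Illustrative sketch (not part of upstream Gitea): the first five Status values are
// numerically aligned with runnerv1.Result, which is what AsResult relies on.
//
//	s := StatusFailure
//	if s.IsDone() {
//		_ = s.AsResult() // runnerv1.Result_RESULT_FAILURE
//	}
//	_ = StatusRunning.AsResult() // RESULT_UNSPECIFIED, because running is not a final result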

529
models/actions/task.go Normal file
View File

@@ -0,0 +1,529 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"crypto/subtle"
"fmt"
"time"
auth_model "code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unit"
"code.gitea.io/gitea/modules/container"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
runnerv1 "code.gitea.io/actions-proto-go/runner/v1"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/nektos/act/pkg/jobparser"
"google.golang.org/protobuf/types/known/timestamppb"
"xorm.io/builder"
)
// ActionTask represents a distribution of a job to a runner
type ActionTask struct {
ID int64
JobID int64
Job *ActionRunJob `xorm:"-"`
Steps []*ActionTaskStep `xorm:"-"`
Attempt int64
RunnerID int64 `xorm:"index"`
Status Status `xorm:"index"`
Started timeutil.TimeStamp `xorm:"index"`
Stopped timeutil.TimeStamp `xorm:"index(stopped_log_expired)"`
RepoID int64 `xorm:"index"`
OwnerID int64 `xorm:"index"`
CommitSHA string `xorm:"index"`
IsForkPullRequest bool
Token string `xorm:"-"`
TokenHash string `xorm:"UNIQUE"` // sha256 of token
TokenSalt string
TokenLastEight string `xorm:"index token_last_eight"`
LogFilename string // file name of log
LogInStorage bool // read log from database or from storage
LogLength int64 // lines count
LogSize int64 // blob size
LogIndexes LogIndexes `xorm:"LONGBLOB"` // line number to offset
LogExpired bool `xorm:"index(stopped_log_expired)"` // files that are too old will be deleted
Created timeutil.TimeStamp `xorm:"created"`
Updated timeutil.TimeStamp `xorm:"updated index"`
}
var successfulTokenTaskCache *lru.Cache[string, any]
func init() {
db.RegisterModel(new(ActionTask), func() error {
if setting.SuccessfulTokensCacheSize > 0 {
var err error
successfulTokenTaskCache, err = lru.New[string, any](setting.SuccessfulTokensCacheSize)
if err != nil {
return fmt.Errorf("unable to allocate Task cache: %v", err)
}
} else {
successfulTokenTaskCache = nil
}
return nil
})
}
func (task *ActionTask) Duration() time.Duration {
return calculateDuration(task.Started, task.Stopped, task.Status)
}
func (task *ActionTask) IsStopped() bool {
return task.Stopped > 0
}
func (task *ActionTask) GetRunLink() string {
if task.Job == nil || task.Job.Run == nil {
return ""
}
return task.Job.Run.Link()
}
func (task *ActionTask) GetCommitLink() string {
if task.Job == nil || task.Job.Run == nil || task.Job.Run.Repo == nil {
return ""
}
return task.Job.Run.Repo.CommitLink(task.CommitSHA)
}
func (task *ActionTask) GetRepoName() string {
if task.Job == nil || task.Job.Run == nil || task.Job.Run.Repo == nil {
return ""
}
return task.Job.Run.Repo.FullName()
}
func (task *ActionTask) GetRepoLink() string {
if task.Job == nil || task.Job.Run == nil || task.Job.Run.Repo == nil {
return ""
}
return task.Job.Run.Repo.Link()
}
func (task *ActionTask) LoadJob(ctx context.Context) error {
if task.Job == nil {
job, err := GetRunJobByID(ctx, task.JobID)
if err != nil {
return err
}
task.Job = job
}
return nil
}
// LoadAttributes loads the Job and the Steps if they are not loaded
func (task *ActionTask) LoadAttributes(ctx context.Context) error {
if task == nil {
return nil
}
if err := task.LoadJob(ctx); err != nil {
return err
}
if err := task.Job.LoadAttributes(ctx); err != nil {
return err
}
if task.Steps == nil { // be careful, an empty slice (not nil) also means loaded
steps, err := GetTaskStepsByTaskID(ctx, task.ID)
if err != nil {
return err
}
task.Steps = steps
}
return nil
}
func (task *ActionTask) GenerateToken() (err error) {
task.Token, task.TokenSalt, task.TokenHash, task.TokenLastEight, err = generateSaltedToken()
return err
}
func GetTaskByID(ctx context.Context, id int64) (*ActionTask, error) {
var task ActionTask
has, err := db.GetEngine(ctx).Where("id=?", id).Get(&task)
if err != nil {
return nil, err
} else if !has {
return nil, fmt.Errorf("task with id %d: %w", id, util.ErrNotExist)
}
return &task, nil
}
func GetRunningTaskByToken(ctx context.Context, token string) (*ActionTask, error) {
errNotExist := fmt.Errorf("task with token %q: %w", token, util.ErrNotExist)
if token == "" {
return nil, errNotExist
}
// A token has the same shape as a hex-encoded SHA1 sum: 40 hexadecimal characters
if len(token) != 40 {
return nil, errNotExist
}
for _, x := range []byte(token) {
if x < '0' || (x > '9' && x < 'a') || x > 'f' {
return nil, errNotExist
}
}
lastEight := token[len(token)-8:]
if id := getTaskIDFromCache(token); id > 0 {
task := &ActionTask{
TokenLastEight: lastEight,
}
// Re-get the task from the db in case it has been deleted in the intervening period
has, err := db.GetEngine(ctx).ID(id).Get(task)
if err != nil {
return nil, err
}
if has {
return task, nil
}
successfulTokenTaskCache.Remove(token)
}
var tasks []*ActionTask
err := db.GetEngine(ctx).Where("token_last_eight = ? AND status = ?", lastEight, StatusRunning).Find(&tasks)
if err != nil {
return nil, err
} else if len(tasks) == 0 {
return nil, errNotExist
}
for _, t := range tasks {
tempHash := auth_model.HashToken(token, t.TokenSalt)
if subtle.ConstantTimeCompare([]byte(t.TokenHash), []byte(tempHash)) == 1 {
if successfulTokenTaskCache != nil {
successfulTokenTaskCache.Add(token, t.ID)
}
return t, nil
}
}
return nil, errNotExist
}
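// Illustrative sketch (not part of upstream Gitea): the token shape that
// GetRunningTaskByToken expects, using a hypothetical value.
//
//	token := "0123456789abcdef0123456789abcdef01234567" // 40 lowercase hex characters
//	lastEight := token[len(token)-8:]                    // narrows the candidate rows in the DB
//	// for each candidate, HashToken(token, t.TokenSalt) is compared against
//	// t.TokenHash in constant time before the task is accepted.
//	_ = lastEight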
func CreateTaskForRunner(ctx context.Context, runner *ActionRunner) (*ActionTask, bool, error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, false, err
}
defer committer.Close()
e := db.GetEngine(ctx)
jobCond := builder.NewCond()
if runner.RepoID != 0 {
jobCond = builder.Eq{"repo_id": runner.RepoID}
} else if runner.OwnerID != 0 {
jobCond = builder.In("repo_id", builder.Select("`repository`.id").From("repository").
Join("INNER", "repo_unit", "`repository`.id = `repo_unit`.repo_id").
Where(builder.Eq{"`repository`.owner_id": runner.OwnerID, "`repo_unit`.type": unit.TypeActions}))
}
if jobCond.IsValid() {
jobCond = builder.In("run_id", builder.Select("id").From("action_run").Where(jobCond))
}
var jobs []*ActionRunJob
if err := e.Where("task_id=? AND status=?", 0, StatusWaiting).And(jobCond).Asc("updated", "id").Find(&jobs); err != nil {
return nil, false, err
}
// TODO: a more efficient way to filter labels
var job *ActionRunJob
log.Trace("runner labels: %v", runner.AgentLabels)
for _, v := range jobs {
if isSubset(runner.AgentLabels, v.RunsOn) {
job = v
break
}
}
if job == nil {
return nil, false, nil
}
if err := job.LoadAttributes(ctx); err != nil {
return nil, false, err
}
now := timeutil.TimeStampNow()
job.Attempt++
job.Started = now
job.Status = StatusRunning
task := &ActionTask{
JobID: job.ID,
Attempt: job.Attempt,
RunnerID: runner.ID,
Started: now,
Status: StatusRunning,
RepoID: job.RepoID,
OwnerID: job.OwnerID,
CommitSHA: job.CommitSHA,
IsForkPullRequest: job.IsForkPullRequest,
}
if err := task.GenerateToken(); err != nil {
return nil, false, err
}
var workflowJob *jobparser.Job
if gots, err := jobparser.Parse(job.WorkflowPayload); err != nil {
return nil, false, fmt.Errorf("parse workflow of job %d: %w", job.ID, err)
} else if len(gots) != 1 {
return nil, false, fmt.Errorf("workflow of job %d: not single workflow", job.ID)
} else { //nolint:revive
_, workflowJob = gots[0].Job()
}
if _, err := e.Insert(task); err != nil {
return nil, false, err
}
task.LogFilename = logFileName(job.Run.Repo.FullName(), task.ID)
if err := UpdateTask(ctx, task, "log_filename"); err != nil {
return nil, false, err
}
if len(workflowJob.Steps) > 0 {
steps := make([]*ActionTaskStep, len(workflowJob.Steps))
for i, v := range workflowJob.Steps {
name, _ := util.SplitStringAtByteN(v.String(), 255)
steps[i] = &ActionTaskStep{
Name: name,
TaskID: task.ID,
Index: int64(i),
RepoID: task.RepoID,
Status: StatusWaiting,
}
}
if _, err := e.Insert(steps); err != nil {
return nil, false, err
}
task.Steps = steps
}
job.TaskID = task.ID
if n, err := UpdateRunJob(ctx, job, builder.Eq{"task_id": 0}); err != nil {
return nil, false, err
} else if n != 1 {
return nil, false, nil
}
task.Job = job
if err := committer.Commit(); err != nil {
return nil, false, err
}
return task, true, nil
}
func UpdateTask(ctx context.Context, task *ActionTask, cols ...string) error {
sess := db.GetEngine(ctx).ID(task.ID)
if len(cols) > 0 {
sess.Cols(cols...)
}
_, err := sess.Update(task)
return err
}
// UpdateTaskByState updates the task by the state.
// It will always update the task if the state is not final, even if there is no change.
// So it will update ActionTask.Updated to avoid the task being judged as a zombie task.
func UpdateTaskByState(ctx context.Context, runnerID int64, state *runnerv1.TaskState) (*ActionTask, error) {
stepStates := map[int64]*runnerv1.StepState{}
for _, v := range state.Steps {
stepStates[v.Id] = v
}
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
defer committer.Close()
e := db.GetEngine(ctx)
task := &ActionTask{}
if has, err := e.ID(state.Id).Get(task); err != nil {
return nil, err
} else if !has {
return nil, util.ErrNotExist
} else if runnerID != task.RunnerID {
return nil, fmt.Errorf("invalid runner for task")
}
if task.Status.IsDone() {
// the state is final, do nothing
return task, nil
}
// If state.Result is not unspecified, the task is finished
if state.Result != runnerv1.Result_RESULT_UNSPECIFIED {
task.Status = Status(state.Result)
task.Stopped = timeutil.TimeStamp(state.StoppedAt.AsTime().Unix())
if err := UpdateTask(ctx, task, "status", "stopped"); err != nil {
return nil, err
}
if _, err := UpdateRunJob(ctx, &ActionRunJob{
ID: task.JobID,
Status: task.Status,
Stopped: task.Stopped,
}, nil); err != nil {
return nil, err
}
} else {
// Force update ActionTask.Updated to avoid the task being judged as a zombie task
task.Updated = timeutil.TimeStampNow()
if err := UpdateTask(ctx, task, "updated"); err != nil {
return nil, err
}
}
if err := task.LoadAttributes(ctx); err != nil {
return nil, err
}
for _, step := range task.Steps {
var result runnerv1.Result
if v, ok := stepStates[step.Index]; ok {
result = v.Result
step.LogIndex = v.LogIndex
step.LogLength = v.LogLength
step.Started = convertTimestamp(v.StartedAt)
step.Stopped = convertTimestamp(v.StoppedAt)
}
if result != runnerv1.Result_RESULT_UNSPECIFIED {
step.Status = Status(result)
} else if step.Started != 0 {
step.Status = StatusRunning
}
if _, err := e.ID(step.ID).Update(step); err != nil {
return nil, err
}
}
if err := committer.Commit(); err != nil {
return nil, err
}
return task, nil
}
func StopTask(ctx context.Context, taskID int64, status Status) error {
if !status.IsDone() {
return fmt.Errorf("cannot stop task with status %v", status)
}
e := db.GetEngine(ctx)
task := &ActionTask{}
if has, err := e.ID(taskID).Get(task); err != nil {
return err
} else if !has {
return util.ErrNotExist
}
if task.Status.IsDone() {
return nil
}
now := timeutil.TimeStampNow()
task.Status = status
task.Stopped = now
if _, err := UpdateRunJob(ctx, &ActionRunJob{
ID: task.JobID,
Status: task.Status,
Stopped: task.Stopped,
}, nil); err != nil {
return err
}
if err := UpdateTask(ctx, task, "status", "stopped"); err != nil {
return err
}
if err := task.LoadAttributes(ctx); err != nil {
return err
}
for _, step := range task.Steps {
if !step.Status.IsDone() {
step.Status = status
if step.Started == 0 {
step.Started = now
}
step.Stopped = now
}
if _, err := e.ID(step.ID).Update(step); err != nil {
return err
}
}
return nil
}
func FindOldTasksToExpire(ctx context.Context, olderThan timeutil.TimeStamp, limit int) ([]*ActionTask, error) {
e := db.GetEngine(ctx)
tasks := make([]*ActionTask, 0, limit)
// Check "stopped > 0" to avoid deleting tasks that are still running
return tasks, e.Where("stopped > 0 AND stopped < ? AND log_expired = ?", olderThan, false).
Limit(limit).
Find(&tasks)
}
func isSubset(set, subset []string) bool {
m := make(container.Set[string], len(set))
for _, v := range set {
m.Add(v)
}
for _, v := range subset {
if !m.Contains(v) {
return false
}
}
return true
}
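// Illustrative sketch (not part of upstream Gitea): how the label matching above
// decides whether a runner can take a job, with hypothetical labels.
//
//	runnerLabels := []string{"ubuntu-latest", "docker", "gpu"}
//	_ = isSubset(runnerLabels, []string{"ubuntu-latest"})          // true, the runner has every required label
//	_ = isSubset(runnerLabels, []string{"ubuntu-latest", "arm64"}) // false, "arm64" is missing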
func convertTimestamp(timestamp *timestamppb.Timestamp) timeutil.TimeStamp {
if timestamp.GetSeconds() == 0 && timestamp.GetNanos() == 0 {
return timeutil.TimeStamp(0)
}
return timeutil.TimeStamp(timestamp.AsTime().Unix())
}
func logFileName(repoFullName string, taskID int64) string {
ret := fmt.Sprintf("%s/%02x/%d.log", repoFullName, taskID%256, taskID)
if setting.Actions.LogCompression.IsZstd() {
ret += ".zst"
}
return ret
}
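// Illustrative sketch (not part of upstream Gitea): logFileName shards logs by
// taskID%256, so for a hypothetical repo "org/repo" and task ID 255 the result is
// "org/repo/ff/255.log", with a ".zst" suffix when zstd log compression is enabled.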
func getTaskIDFromCache(token string) int64 {
if successfulTokenTaskCache == nil {
return 0
}
tInterface, ok := successfulTokenTaskCache.Get(token)
if !ok {
return 0
}
t, ok := tInterface.(int64)
if !ok {
return 0
}
return t
}

View File

@@ -0,0 +1,87 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/container"
"code.gitea.io/gitea/modules/timeutil"
"xorm.io/builder"
)
type TaskList []*ActionTask
func (tasks TaskList) GetJobIDs() []int64 {
return container.FilterSlice(tasks, func(t *ActionTask) (int64, bool) {
return t.JobID, t.JobID != 0
})
}
func (tasks TaskList) LoadJobs(ctx context.Context) error {
jobIDs := tasks.GetJobIDs()
jobs := make(map[int64]*ActionRunJob, len(jobIDs))
if err := db.GetEngine(ctx).In("id", jobIDs).Find(&jobs); err != nil {
return err
}
for _, t := range tasks {
if t.JobID > 0 && t.Job == nil {
t.Job = jobs[t.JobID]
}
}
// TODO: Replace with "ActionJobList(maps.Values(jobs))" once available
var jobsList ActionJobList = make([]*ActionRunJob, 0, len(jobs))
for _, j := range jobs {
jobsList = append(jobsList, j)
}
return jobsList.LoadAttributes(ctx, true)
}
func (tasks TaskList) LoadAttributes(ctx context.Context) error {
return tasks.LoadJobs(ctx)
}
type FindTaskOptions struct {
db.ListOptions
RepoID int64
OwnerID int64
CommitSHA string
Status Status
UpdatedBefore timeutil.TimeStamp
StartedBefore timeutil.TimeStamp
RunnerID int64
}
func (opts FindTaskOptions) ToConds() builder.Cond {
cond := builder.NewCond()
if opts.RepoID > 0 {
cond = cond.And(builder.Eq{"repo_id": opts.RepoID})
}
if opts.OwnerID > 0 {
cond = cond.And(builder.Eq{"owner_id": opts.OwnerID})
}
if opts.CommitSHA != "" {
cond = cond.And(builder.Eq{"commit_sha": opts.CommitSHA})
}
if opts.Status > StatusUnknown {
cond = cond.And(builder.Eq{"status": opts.Status})
}
if opts.UpdatedBefore > 0 {
cond = cond.And(builder.Lt{"updated": opts.UpdatedBefore})
}
if opts.StartedBefore > 0 {
cond = cond.And(builder.Lt{"started": opts.StartedBefore})
}
if opts.RunnerID > 0 {
cond = cond.And(builder.Eq{"runner_id": opts.RunnerID})
}
return cond
}
func (opts FindTaskOptions) ToOrders() string {
return "`id` DESC"
}

View File

@@ -0,0 +1,55 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"code.gitea.io/gitea/models/db"
)
// ActionTaskOutput represents an output of an ActionTask.
// The outputs are bound to a task, which means that when a completed job is rerun,
// the outputs of the job are reset because a new task is created.
// This is by design, to avoid mixing the outputs of the old task with those of the new one.
type ActionTaskOutput struct {
ID int64
TaskID int64 `xorm:"INDEX UNIQUE(task_id_output_key)"`
OutputKey string `xorm:"VARCHAR(255) UNIQUE(task_id_output_key)"`
OutputValue string `xorm:"MEDIUMTEXT"`
}
func init() {
db.RegisterModel(new(ActionTaskOutput))
}
// FindTaskOutputByTaskID returns the outputs of the task.
func FindTaskOutputByTaskID(ctx context.Context, taskID int64) ([]*ActionTaskOutput, error) {
var outputs []*ActionTaskOutput
return outputs, db.GetEngine(ctx).Where("task_id=?", taskID).Find(&outputs)
}
// FindTaskOutputKeyByTaskID returns the keys of the outputs of the task.
func FindTaskOutputKeyByTaskID(ctx context.Context, taskID int64) ([]string, error) {
var keys []string
return keys, db.GetEngine(ctx).Table(ActionTaskOutput{}).Where("task_id=?", taskID).Cols("output_key").Find(&keys)
}
// InsertTaskOutputIfNotExist inserts a new task output if it does not exist.
func InsertTaskOutputIfNotExist(ctx context.Context, taskID int64, key, value string) error {
return db.WithTx(ctx, func(ctx context.Context) error {
sess := db.GetEngine(ctx)
if exist, err := sess.Exist(&ActionTaskOutput{TaskID: taskID, OutputKey: key}); err != nil {
return err
} else if exist {
return nil
}
_, err := sess.Insert(&ActionTaskOutput{
TaskID: taskID,
OutputKey: key,
OutputValue: value,
})
return err
})
}
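// Illustrative sketch (not part of upstream Gitea): because outputs are keyed by
// task ID, a rerun job (which gets a new task) starts with an empty output set.
// The ctx, taskID and key/value below are hypothetical.
//
//	_ = InsertTaskOutputIfNotExist(ctx, taskID, "IMAGE_TAG", "v1.2.3")
//	outputs, _ := FindTaskOutputByTaskID(ctx, taskID) // contains IMAGE_TAG
//	// querying with the rerun's new task ID returns nothing until the rerun
//	// writes its own outputs.
//	_ = outputs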

View File

@@ -0,0 +1,41 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/timeutil"
)
// ActionTaskStep represents a step of ActionTask
type ActionTaskStep struct {
ID int64
Name string `xorm:"VARCHAR(255)"`
TaskID int64 `xorm:"index unique(task_index)"`
Index int64 `xorm:"index unique(task_index)"`
RepoID int64 `xorm:"index"`
Status Status `xorm:"index"`
LogIndex int64
LogLength int64
Started timeutil.TimeStamp
Stopped timeutil.TimeStamp
Created timeutil.TimeStamp `xorm:"created"`
Updated timeutil.TimeStamp `xorm:"updated"`
}
func (step *ActionTaskStep) Duration() time.Duration {
return calculateDuration(step.Started, step.Stopped, step.Status)
}
func init() {
db.RegisterModel(new(ActionTaskStep))
}
func GetTaskStepsByTaskID(ctx context.Context, taskID int64) ([]*ActionTaskStep, error) {
var steps []*ActionTaskStep
return steps, db.GetEngine(ctx).Where("task_id=?", taskID).OrderBy("`index` ASC").Find(&steps)
}

View File

@@ -0,0 +1,105 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/timeutil"
)
// ActionTasksVersion
// If both ownerID and repoID are zero, its scope is global.
// If ownerID is not zero and repoID is zero, its scope is org (there is no user-level runner currently).
// If ownerID is zero and repoID is not zero, its scope is repo.
type ActionTasksVersion struct {
ID int64 `xorm:"pk autoincr"`
OwnerID int64 `xorm:"UNIQUE(owner_repo)"`
RepoID int64 `xorm:"INDEX UNIQUE(owner_repo)"`
Version int64
CreatedUnix timeutil.TimeStamp `xorm:"created"`
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
}
func init() {
db.RegisterModel(new(ActionTasksVersion))
}
func GetTasksVersionByScope(ctx context.Context, ownerID, repoID int64) (int64, error) {
var tasksVersion ActionTasksVersion
has, err := db.GetEngine(ctx).Where("owner_id = ? AND repo_id = ?", ownerID, repoID).Get(&tasksVersion)
if err != nil {
return 0, err
} else if !has {
return 0, nil
}
return tasksVersion.Version, err
}
func insertTasksVersion(ctx context.Context, ownerID, repoID int64) (*ActionTasksVersion, error) {
tasksVersion := &ActionTasksVersion{
OwnerID: ownerID,
RepoID: repoID,
Version: 1,
}
if _, err := db.GetEngine(ctx).Insert(tasksVersion); err != nil {
return nil, err
}
return tasksVersion, nil
}
func increaseTasksVersionByScope(ctx context.Context, ownerID, repoID int64) error {
result, err := db.GetEngine(ctx).Exec("UPDATE action_tasks_version SET version = version + 1 WHERE owner_id = ? AND repo_id = ?", ownerID, repoID)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
// if the UPDATE statement does not affect any rows, the version row does not exist yet
// (or the database is broken), so re-insert the row of version data here.
if _, err := insertTasksVersion(ctx, ownerID, repoID); err != nil {
return err
}
}
return nil
}
func IncreaseTaskVersion(ctx context.Context, ownerID, repoID int64) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
defer committer.Close()
// 1. increase global
if err := increaseTasksVersionByScope(ctx, 0, 0); err != nil {
log.Error("IncreaseTasksVersionByScope(Global): %v", err)
return err
}
// 2. increase owner
if ownerID > 0 {
if err := increaseTasksVersionByScope(ctx, ownerID, 0); err != nil {
log.Error("IncreaseTasksVersionByScope(Owner): %v", err)
return err
}
}
// 3. increase repo
if repoID > 0 {
if err := increaseTasksVersionByScope(ctx, 0, repoID); err != nil {
log.Error("IncreaseTasksVersionByScope(Repo): %v", err)
return err
}
}
return committer.Commit()
}
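// Illustrative sketch (not part of upstream Gitea): a repo-level event bumps the
// global, owner, and repo rows, so a runner polling any of the three scopes sees a
// version change. The IDs below are hypothetical.
//
//	_ = IncreaseTaskVersion(ctx, 7, 42)
//	// bumps the (owner_id=0, repo_id=0), (owner_id=7, repo_id=0) and
//	// (owner_id=0, repo_id=42) rows of action_tasks_version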

84
models/actions/utils.go Normal file
View File

@@ -0,0 +1,84 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"bytes"
"encoding/binary"
"encoding/hex"
"errors"
"fmt"
"io"
"time"
auth_model "code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
)
func generateSaltedToken() (string, string, string, string, error) {
salt, err := util.CryptoRandomString(10)
if err != nil {
return "", "", "", "", err
}
buf, err := util.CryptoRandomBytes(20)
if err != nil {
return "", "", "", "", err
}
token := hex.EncodeToString(buf)
hash := auth_model.HashToken(token, salt)
return token, salt, hash, token[len(token)-8:], nil
}
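// Illustrative sketch (not part of upstream Gitea): the four values returned above
// and how they map onto an ActionTask.
//
//	token, salt, hash, lastEight, _ := generateSaltedToken()
//	// token     -> 40 hex characters, handed to the runner, never stored in plain text
//	// salt      -> random salt used when hashing the token
//	// hash      -> auth_model.HashToken(token, salt), stored in ActionTask.TokenHash
//	// lastEight -> last 8 characters of the token, stored to narrow later lookups
//	_, _, _, _ = token, salt, hash, lastEight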
/*
LogIndexes is the index for mapping log line number to buffer offset.
Because it uses varint encoding, it is impossible to predict its size.
But we can make a simple estimate, assuming that each log line is about 200 bytes, then:
| lines | file size | index size |
|-----------|---------------------|--------------------|
| 100 | 20 KiB(20000) | 258 B(258) |
| 1000 | 195 KiB(200000) | 2.9 KiB(2958) |
| 10000 | 1.9 MiB(2000000) | 34 KiB(34715) |
| 100000 | 19 MiB(20000000) | 386 KiB(394715) |
| 1000000 | 191 MiB(200000000) | 4.1 MiB(4323626) |
| 10000000 | 1.9 GiB(2000000000) | 47 MiB(49323626) |
| 100000000 | 19 GiB(20000000000) | 490 MiB(513424280) |
*/
type LogIndexes []int64
func (indexes *LogIndexes) FromDB(b []byte) error {
reader := bytes.NewReader(b)
for {
v, err := binary.ReadVarint(reader)
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return fmt.Errorf("binary ReadVarint: %w", err)
}
*indexes = append(*indexes, v)
}
}
func (indexes *LogIndexes) ToDB() ([]byte, error) {
buf, i := make([]byte, binary.MaxVarintLen64*len(*indexes)), 0
for _, v := range *indexes {
n := binary.PutVarint(buf[i:], v)
i += n
}
return buf[:i], nil
}
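// Illustrative sketch (not part of upstream Gitea): LogIndexes round-trips through
// the varint encoding implemented by ToDB and FromDB above.
//
//	in := LogIndexes{0, 120, 345, 700}
//	blob, _ := in.ToDB()
//	var out LogIndexes
//	_ = out.FromDB(blob) // out now equals in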
var timeSince = time.Since
func calculateDuration(started, stopped timeutil.TimeStamp, status Status) time.Duration {
if started == 0 {
return 0
}
s := started.AsTime()
if status.IsDone() {
return stopped.AsTime().Sub(s)
}
return timeSince(s).Truncate(time.Second)
}

View File

@@ -0,0 +1,90 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"math"
"testing"
"time"
"code.gitea.io/gitea/modules/timeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLogIndexes_ToDB(t *testing.T) {
tests := []struct {
indexes LogIndexes
}{
{
indexes: []int64{1, 2, 0, -1, -2, math.MaxInt64, math.MinInt64},
},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
got, err := tt.indexes.ToDB()
require.NoError(t, err)
indexes := LogIndexes{}
require.NoError(t, indexes.FromDB(got))
assert.Equal(t, tt.indexes, indexes)
})
}
}
func Test_calculateDuration(t *testing.T) {
oldTimeSince := timeSince
defer func() {
timeSince = oldTimeSince
}()
timeSince = func(t time.Time) time.Duration {
return timeutil.TimeStamp(1000).AsTime().Sub(t)
}
type args struct {
started timeutil.TimeStamp
stopped timeutil.TimeStamp
status Status
}
tests := []struct {
name string
args args
want time.Duration
}{
{
name: "unknown",
args: args{
started: 0,
stopped: 0,
status: StatusUnknown,
},
want: 0,
},
{
name: "running",
args: args{
started: 500,
stopped: 0,
status: StatusRunning,
},
want: 500 * time.Second,
},
{
name: "done",
args: args{
started: 500,
stopped: 600,
status: StatusSuccess,
},
want: 100 * time.Second,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equalf(t, tt.want, calculateDuration(tt.args.started, tt.args.stopped, tt.args.status), "calculateDuration(%v, %v, %v)", tt.args.started, tt.args.stopped, tt.args.status)
})
}
}

139
models/actions/variable.go Normal file
View File

@@ -0,0 +1,139 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package actions
import (
"context"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/timeutil"
"xorm.io/builder"
)
// ActionVariable represents a variable that can be used in actions
//
// It can be:
// 1. global variable, OwnerID is 0 and RepoID is 0
// 2. org/user level variable, OwnerID is org/user ID and RepoID is 0
// 3. repo level variable, OwnerID is 0 and RepoID is repo ID
//
// Please note that it's not acceptable to have both OwnerID and RepoID be non-zero,
// or it will be complicated to find variables belonging to a specific owner.
// For example, conditions like `OwnerID = 1` will also return variable {OwnerID: 1, RepoID: 1},
// but it's a repo level variable, not an org/user level variable.
// To avoid this, make it clear with {OwnerID: 0, RepoID: 1} for repo level variables.
type ActionVariable struct {
ID int64 `xorm:"pk autoincr"`
OwnerID int64 `xorm:"UNIQUE(owner_repo_name)"`
RepoID int64 `xorm:"INDEX UNIQUE(owner_repo_name)"`
Name string `xorm:"UNIQUE(owner_repo_name) NOT NULL"`
Data string `xorm:"LONGTEXT NOT NULL"`
CreatedUnix timeutil.TimeStamp `xorm:"created NOT NULL"`
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
}
func init() {
db.RegisterModel(new(ActionVariable))
}
func InsertVariable(ctx context.Context, ownerID, repoID int64, name, data string) (*ActionVariable, error) {
if ownerID != 0 && repoID != 0 {
// It's trying to create a variable that belongs to a repository, but OwnerID has been set accidentally.
// Remove OwnerID to avoid confusion; it's not worth returning an error here.
ownerID = 0
}
variable := &ActionVariable{
OwnerID: ownerID,
RepoID: repoID,
Name: strings.ToUpper(name),
Data: data,
}
return variable, db.Insert(ctx, variable)
}
type FindVariablesOpts struct {
db.ListOptions
RepoID int64
OwnerID int64 // it will be ignored if RepoID is set
Name string
}
func (opts FindVariablesOpts) ToConds() builder.Cond {
cond := builder.NewCond()
// Since we now support instance-level variables,
// there is no need to check for null values for `owner_id` and `repo_id`
cond = cond.And(builder.Eq{"repo_id": opts.RepoID})
if opts.RepoID != 0 { // if RepoID is set
// ignore OwnerID and treat it as 0
cond = cond.And(builder.Eq{"owner_id": 0})
} else {
cond = cond.And(builder.Eq{"owner_id": opts.OwnerID})
}
if opts.Name != "" {
cond = cond.And(builder.Eq{"name": strings.ToUpper(opts.Name)})
}
return cond
}
func FindVariables(ctx context.Context, opts FindVariablesOpts) ([]*ActionVariable, error) {
return db.Find[ActionVariable](ctx, opts)
}
func UpdateVariable(ctx context.Context, variable *ActionVariable) (bool, error) {
count, err := db.GetEngine(ctx).ID(variable.ID).Cols("name", "data").
Update(&ActionVariable{
Name: variable.Name,
Data: variable.Data,
})
return count != 0, err
}
func DeleteVariable(ctx context.Context, id int64) error {
if _, err := db.DeleteByID[ActionVariable](ctx, id); err != nil {
return err
}
return nil
}
func GetVariablesOfRun(ctx context.Context, run *ActionRun) (map[string]string, error) {
variables := map[string]string{}
if err := run.LoadRepo(ctx); err != nil {
log.Error("LoadRepo: %v", err)
return nil, err
}
// Global
globalVariables, err := db.Find[ActionVariable](ctx, FindVariablesOpts{})
if err != nil {
log.Error("find global variables: %v", err)
return nil, err
}
// Org / User level
ownerVariables, err := db.Find[ActionVariable](ctx, FindVariablesOpts{OwnerID: run.Repo.OwnerID})
if err != nil {
log.Error("find variables of org: %d, error: %v", run.Repo.OwnerID, err)
return nil, err
}
// Repo level
repoVariables, err := db.Find[ActionVariable](ctx, FindVariablesOpts{RepoID: run.RepoID})
if err != nil {
log.Error("find variables of repo: %d, error: %v", run.RepoID, err)
return nil, err
}
// Level precedence: Repo > Org / User > Global
for _, v := range append(globalVariables, append(ownerVariables, repoVariables...)...) {
variables[v.Name] = v.Data
}
return variables, nil
}
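// Illustrative sketch (not part of upstream Gitea): with the precedence above, a
// variable defined at several levels resolves to the most specific one. The data
// below is hypothetical.
//
//	// global DEPLOY_ENV=dev, org DEPLOY_ENV=staging, repo DEPLOY_ENV=prod
//	vars, _ := GetVariablesOfRun(ctx, run)
//	_ = vars["DEPLOY_ENV"] // "prod", because repo-level values overwrite org and global ones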

736
models/activities/action.go Normal file
View File

@@ -0,0 +1,736 @@
// Copyright 2014 The Gogs Authors. All rights reserved.
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package activities
import (
"context"
"fmt"
"net/url"
"path"
"strconv"
"strings"
"time"
"code.gitea.io/gitea/models/db"
issues_model "code.gitea.io/gitea/models/issues"
"code.gitea.io/gitea/models/organization"
access_model "code.gitea.io/gitea/models/perm/access"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unit"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/structs"
"code.gitea.io/gitea/modules/timeutil"
"xorm.io/builder"
"xorm.io/xorm/schemas"
)
// ActionType represents the type of an action.
type ActionType int
// Possible action types.
const (
ActionCreateRepo ActionType = iota + 1 // 1
ActionRenameRepo // 2
ActionStarRepo // 3
ActionWatchRepo // 4
ActionCommitRepo // 5
ActionCreateIssue // 6
ActionCreatePullRequest // 7
ActionTransferRepo // 8
ActionPushTag // 9
ActionCommentIssue // 10
ActionMergePullRequest // 11
ActionCloseIssue // 12
ActionReopenIssue // 13
ActionClosePullRequest // 14
ActionReopenPullRequest // 15
ActionDeleteTag // 16
ActionDeleteBranch // 17
ActionMirrorSyncPush // 18
ActionMirrorSyncCreate // 19
ActionMirrorSyncDelete // 20
ActionApprovePullRequest // 21
ActionRejectPullRequest // 22
ActionCommentPull // 23
ActionPublishRelease // 24
ActionPullReviewDismissed // 25
ActionPullRequestReadyForReview // 26
ActionAutoMergePullRequest // 27
)
func (at ActionType) String() string {
switch at {
case ActionCreateRepo:
return "create_repo"
case ActionRenameRepo:
return "rename_repo"
case ActionStarRepo:
return "star_repo"
case ActionWatchRepo:
return "watch_repo"
case ActionCommitRepo:
return "commit_repo"
case ActionCreateIssue:
return "create_issue"
case ActionCreatePullRequest:
return "create_pull_request"
case ActionTransferRepo:
return "transfer_repo"
case ActionPushTag:
return "push_tag"
case ActionCommentIssue:
return "comment_issue"
case ActionMergePullRequest:
return "merge_pull_request"
case ActionCloseIssue:
return "close_issue"
case ActionReopenIssue:
return "reopen_issue"
case ActionClosePullRequest:
return "close_pull_request"
case ActionReopenPullRequest:
return "reopen_pull_request"
case ActionDeleteTag:
return "delete_tag"
case ActionDeleteBranch:
return "delete_branch"
case ActionMirrorSyncPush:
return "mirror_sync_push"
case ActionMirrorSyncCreate:
return "mirror_sync_create"
case ActionMirrorSyncDelete:
return "mirror_sync_delete"
case ActionApprovePullRequest:
return "approve_pull_request"
case ActionRejectPullRequest:
return "reject_pull_request"
case ActionCommentPull:
return "comment_pull"
case ActionPublishRelease:
return "publish_release"
case ActionPullReviewDismissed:
return "pull_review_dismissed"
case ActionPullRequestReadyForReview:
return "pull_request_ready_for_review"
case ActionAutoMergePullRequest:
return "auto_merge_pull_request"
default:
return "action-" + strconv.Itoa(int(at))
}
}
func (at ActionType) InActions(actions ...string) bool {
for _, action := range actions {
if action == at.String() {
return true
}
}
return false
}
// Action represents a user operation on a repository, together with related information.
// It implements the base.Actioner interface so that it can be
// used in template rendering.
type Action struct {
ID int64 `xorm:"pk autoincr"`
UserID int64 `xorm:"INDEX"` // Receiver user id.
OpType ActionType
ActUserID int64 // Action user id.
ActUser *user_model.User `xorm:"-"`
RepoID int64
Repo *repo_model.Repository `xorm:"-"`
CommentID int64 `xorm:"INDEX"`
Comment *issues_model.Comment `xorm:"-"`
Issue *issues_model.Issue `xorm:"-"` // get the issue id from content
IsDeleted bool `xorm:"NOT NULL DEFAULT false"`
RefName string
IsPrivate bool `xorm:"NOT NULL DEFAULT false"`
Content string `xorm:"TEXT"`
CreatedUnix timeutil.TimeStamp `xorm:"created"`
}
func init() {
db.RegisterModel(new(Action))
}
// TableIndices implements xorm's TableIndices interface
func (a *Action) TableIndices() []*schemas.Index {
repoIndex := schemas.NewIndex("r_u_d", schemas.IndexType)
repoIndex.AddColumn("repo_id", "user_id", "is_deleted")
actUserIndex := schemas.NewIndex("au_r_c_u_d", schemas.IndexType)
actUserIndex.AddColumn("act_user_id", "repo_id", "created_unix", "user_id", "is_deleted")
cudIndex := schemas.NewIndex("c_u_d", schemas.IndexType)
cudIndex.AddColumn("created_unix", "user_id", "is_deleted")
cuIndex := schemas.NewIndex("c_u", schemas.IndexType)
cuIndex.AddColumn("user_id", "is_deleted")
indices := []*schemas.Index{actUserIndex, repoIndex, cudIndex, cuIndex}
return indices
}
// GetOpType gets the ActionType of this action.
func (a *Action) GetOpType() ActionType {
return a.OpType
}
// LoadActUser loads a.ActUser
func (a *Action) LoadActUser(ctx context.Context) {
if a.ActUser != nil {
return
}
var err error
a.ActUser, err = user_model.GetUserByID(ctx, a.ActUserID)
if err == nil {
return
} else if user_model.IsErrUserNotExist(err) {
a.ActUser = user_model.NewGhostUser()
} else {
log.Error("GetUserByID(%d): %v", a.ActUserID, err)
}
}
func (a *Action) LoadRepo(ctx context.Context) {
if a.Repo != nil {
return
}
var err error
a.Repo, err = repo_model.GetRepositoryByID(ctx, a.RepoID)
if err != nil {
log.Error("repo_model.GetRepositoryByID(%d): %v", a.RepoID, err)
}
}
// GetActFullName gets the action's user full name.
func (a *Action) GetActFullName(ctx context.Context) string {
a.LoadActUser(ctx)
return a.ActUser.FullName
}
// GetActUserName gets the action's user name.
func (a *Action) GetActUserName(ctx context.Context) string {
a.LoadActUser(ctx)
return a.ActUser.Name
}
// ShortActUserName gets the action's user name trimmed to max 20
// chars.
func (a *Action) ShortActUserName(ctx context.Context) string {
return base.EllipsisString(a.GetActUserName(ctx), 20)
}
// GetActDisplayName gets the action's display name based on DEFAULT_SHOW_FULL_NAME, or falls back to the username if it is blank.
func (a *Action) GetActDisplayName(ctx context.Context) string {
if setting.UI.DefaultShowFullName {
trimmedFullName := strings.TrimSpace(a.GetActFullName(ctx))
if len(trimmedFullName) > 0 {
return trimmedFullName
}
}
return a.ShortActUserName(ctx)
}
// GetActDisplayNameTitle gets the action's display name used for the title (tooltip) based on DEFAULT_SHOW_FULL_NAME
func (a *Action) GetActDisplayNameTitle(ctx context.Context) string {
if setting.UI.DefaultShowFullName {
return a.ShortActUserName(ctx)
}
return a.GetActFullName(ctx)
}
// GetRepoUserName returns the name of the action repository owner.
func (a *Action) GetRepoUserName(ctx context.Context) string {
a.LoadRepo(ctx)
if a.Repo == nil {
return "(non-existing-repo)"
}
return a.Repo.OwnerName
}
// ShortRepoUserName returns the name of the action repository owner
// trimmed to max 20 chars.
func (a *Action) ShortRepoUserName(ctx context.Context) string {
return base.EllipsisString(a.GetRepoUserName(ctx), 20)
}
// GetRepoName returns the name of the action repository.
func (a *Action) GetRepoName(ctx context.Context) string {
a.LoadRepo(ctx)
if a.Repo == nil {
return "(non-existing-repo)"
}
return a.Repo.Name
}
// ShortRepoName returns the name of the action repository
// trimmed to max 33 chars.
func (a *Action) ShortRepoName(ctx context.Context) string {
return base.EllipsisString(a.GetRepoName(ctx), 33)
}
// GetRepoPath returns the virtual path to the action repository.
func (a *Action) GetRepoPath(ctx context.Context) string {
return path.Join(a.GetRepoUserName(ctx), a.GetRepoName(ctx))
}
// ShortRepoPath returns the virtual path to the action repository
// trimmed to max 20 + 1 + 33 chars.
func (a *Action) ShortRepoPath(ctx context.Context) string {
return path.Join(a.ShortRepoUserName(ctx), a.ShortRepoName(ctx))
}
// GetRepoLink returns relative link to action repository.
func (a *Action) GetRepoLink(ctx context.Context) string {
// path.Join will skip empty strings
return path.Join(setting.AppSubURL, "/", url.PathEscape(a.GetRepoUserName(ctx)), url.PathEscape(a.GetRepoName(ctx)))
}
// GetRepoAbsoluteLink returns the absolute link to action repository.
func (a *Action) GetRepoAbsoluteLink(ctx context.Context) string {
return setting.AppURL + url.PathEscape(a.GetRepoUserName(ctx)) + "/" + url.PathEscape(a.GetRepoName(ctx))
}
func (a *Action) loadComment(ctx context.Context) (err error) {
if a.CommentID == 0 || a.Comment != nil {
return nil
}
a.Comment, err = issues_model.GetCommentByID(ctx, a.CommentID)
return err
}
// GetCommentHTMLURL returns link to action comment.
func (a *Action) GetCommentHTMLURL(ctx context.Context) string {
if a == nil {
return "#"
}
_ = a.loadComment(ctx)
if a.Comment != nil {
return a.Comment.HTMLURL(ctx)
}
if err := a.LoadIssue(ctx); err != nil || a.Issue == nil {
return "#"
}
if err := a.Issue.LoadRepo(ctx); err != nil {
return "#"
}
return a.Issue.HTMLURL()
}
// GetCommentLink returns link to action comment.
func (a *Action) GetCommentLink(ctx context.Context) string {
if a == nil {
return "#"
}
_ = a.loadComment(ctx)
if a.Comment != nil {
return a.Comment.Link(ctx)
}
if err := a.LoadIssue(ctx); err != nil || a.Issue == nil {
return "#"
}
if err := a.Issue.LoadRepo(ctx); err != nil {
return "#"
}
return a.Issue.Link()
}
// GetBranch returns the action's repository branch.
func (a *Action) GetBranch() string {
return strings.TrimPrefix(a.RefName, git.BranchPrefix)
}
// GetRefLink returns the action's ref link.
func (a *Action) GetRefLink(ctx context.Context) string {
return git.RefURL(a.GetRepoLink(ctx), a.RefName)
}
// GetTag returns the action's repository tag.
func (a *Action) GetTag() string {
return strings.TrimPrefix(a.RefName, git.TagPrefix)
}
// GetContent returns the action's content.
func (a *Action) GetContent() string {
return a.Content
}
// GetCreate returns the action creation time.
func (a *Action) GetCreate() time.Time {
return a.CreatedUnix.AsTime()
}
func (a *Action) IsIssueEvent() bool {
return a.OpType.InActions("comment_issue", "approve_pull_request", "reject_pull_request", "comment_pull", "merge_pull_request")
}
// GetIssueInfos returns a list of information associated with the action.
func (a *Action) GetIssueInfos() []string {
// make sure it always returns 3 elements, because there are accesses to a[1] and a[2] in some places without checking the length
ret := strings.SplitN(a.Content, "|", 3)
for len(ret) < 3 {
ret = append(ret, "")
}
return ret
}
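// Illustrative sketch (not part of upstream Gitea): the "|"-separated Content format
// that GetIssueInfos splits, using a hypothetical comment action.
//
//	a := &Action{Content: "123|fixed in the latest commit"}
//	infos := a.GetIssueInfos()
//	// infos[0] == "123" (issue index), infos[1] == "fixed in the latest commit",
//	// infos[2] == "" (padded so that there are always 3 elements)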
func (a *Action) getIssueIndex() int64 {
infos := a.GetIssueInfos()
if len(infos) == 0 {
return 0
}
index, _ := strconv.ParseInt(infos[0], 10, 64)
return index
}
func (a *Action) LoadIssue(ctx context.Context) error {
if a.Issue != nil {
return nil
}
if index := a.getIssueIndex(); index > 0 {
issue, err := issues_model.GetIssueByIndex(ctx, a.RepoID, index)
if err != nil {
return err
}
a.Issue = issue
a.Issue.Repo = a.Repo
}
return nil
}
// GetIssueTitle returns the title of the first issue associated with the action.
func (a *Action) GetIssueTitle(ctx context.Context) string {
if err := a.LoadIssue(ctx); err != nil {
log.Error("LoadIssue: %v", err)
return "<500 when get issue>"
}
if a.Issue == nil {
return "<Issue not found>"
}
return a.Issue.Title
}
// GetIssueContent returns the content of the first issue associated with this action.
func (a *Action) GetIssueContent(ctx context.Context) string {
if err := a.LoadIssue(ctx); err != nil {
log.Error("LoadIssue: %v", err)
return "<500 when get issue>"
}
if a.Issue == nil {
return "<Content not found>"
}
return a.Issue.Content
}
// GetFeedsOptions options for retrieving feeds
type GetFeedsOptions struct {
db.ListOptions
RequestedUser *user_model.User // the user we want activity for
RequestedTeam *organization.Team // the team we want activity for
RequestedRepo *repo_model.Repository // the repo we want activity for
Actor *user_model.User // the user viewing the activity
IncludePrivate bool // include private actions
OnlyPerformedBy bool // only actions performed by requested user
IncludeDeleted bool // include deleted actions
Date string // the day we want activity for: YYYY-MM-DD
}
// ActivityReadable returns whether the doer can read the activities of the user
func ActivityReadable(user, doer *user_model.User) bool {
return !user.KeepActivityPrivate ||
doer != nil && (doer.IsAdmin || user.ID == doer.ID)
}
func ActivityQueryCondition(ctx context.Context, opts GetFeedsOptions) (builder.Cond, error) {
cond := builder.NewCond()
if opts.RequestedTeam != nil && opts.RequestedUser == nil {
org, err := user_model.GetUserByID(ctx, opts.RequestedTeam.OrgID)
if err != nil {
return nil, err
}
opts.RequestedUser = org
}
// check activity visibility for the actor (similar to ActivityReadable())
if opts.Actor == nil {
cond = cond.And(builder.In("act_user_id",
builder.Select("`user`.id").Where(
builder.Eq{"keep_activity_private": false, "visibility": structs.VisibleTypePublic},
).From("`user`"),
))
} else if !opts.Actor.IsAdmin {
uidCond := builder.Select("`user`.id").From("`user`").Where(
builder.Eq{"keep_activity_private": false}.
And(builder.In("visibility", structs.VisibleTypePublic, structs.VisibleTypeLimited))).
Or(builder.Eq{"id": opts.Actor.ID})
if opts.RequestedUser != nil {
if opts.RequestedUser.IsOrganization() {
// An organization can always see the activities whose `act_user_id` is the same as its id.
uidCond = uidCond.Or(builder.Eq{"id": opts.RequestedUser.ID})
} else {
// A user can always see the activities of the organizations to which the user belongs.
uidCond = uidCond.Or(
builder.Eq{"type": user_model.UserTypeOrganization}.
And(builder.In("`user`.id", builder.Select("org_id").
Where(builder.Eq{"uid": opts.RequestedUser.ID}).
From("team_user"))),
)
}
}
cond = cond.And(builder.In("act_user_id", uidCond))
}
// check readable repositories by doer/actor
if opts.Actor == nil || !opts.Actor.IsAdmin {
cond = cond.And(builder.In("repo_id", repo_model.AccessibleRepoIDsQuery(opts.Actor)))
}
if opts.RequestedRepo != nil {
// repo's actions could have duplicate items, see the comment of NotifyWatchers
// so here we only filter the "original items", aka: user_id == act_user_id
cond = cond.And(
builder.Eq{"`action`.repo_id": opts.RequestedRepo.ID},
builder.Expr("`action`.user_id = `action`.act_user_id"),
)
}
if opts.RequestedTeam != nil {
env := organization.OrgFromUser(opts.RequestedUser).AccessibleTeamReposEnv(ctx, opts.RequestedTeam)
teamRepoIDs, err := env.RepoIDs(1, opts.RequestedUser.NumRepos)
if err != nil {
return nil, fmt.Errorf("GetTeamRepositories: %w", err)
}
cond = cond.And(builder.In("repo_id", teamRepoIDs))
}
if opts.RequestedUser != nil {
cond = cond.And(builder.Eq{"user_id": opts.RequestedUser.ID})
if opts.OnlyPerformedBy {
cond = cond.And(builder.Eq{"act_user_id": opts.RequestedUser.ID})
}
}
if !opts.IncludePrivate {
cond = cond.And(builder.Eq{"`action`.is_private": false})
}
if !opts.IncludeDeleted {
cond = cond.And(builder.Eq{"is_deleted": false})
}
if opts.Date != "" {
dateLow, err := time.ParseInLocation("2006-01-02", opts.Date, setting.DefaultUILocation)
if err != nil {
log.Warn("Unable to parse %s, filter not applied: %v", opts.Date, err)
} else {
dateHigh := dateLow.Add(86399000000000) // 23h59m59s
cond = cond.And(builder.Gte{"`action`.created_unix": dateLow.Unix()})
cond = cond.And(builder.Lte{"`action`.created_unix": dateHigh.Unix()})
}
}
return cond, nil
}
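// Illustrative sketch (not part of upstream Gitea): the Date filter above turns a
// day such as "2006-01-02" into an inclusive window in the configured UI timezone:
// created_unix >= 00:00:00 of that day and created_unix <= 23:59:59 of the same day.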
// DeleteOldActions deletes all old actions from database.
func DeleteOldActions(ctx context.Context, olderThan time.Duration) (err error) {
if olderThan <= 0 {
return nil
}
_, err = db.GetEngine(ctx).Where("created_unix < ?", time.Now().Add(-olderThan).Unix()).Delete(&Action{})
return err
}
// NotifyWatchers creates a batch of actions for every watcher.
// It could insert duplicate actions for a repository action, like this:
// * Original action: UserID=1 (the real actor), ActUserID=1
// * Organization action: UserID=100 (the repo's org), ActUserID=1
// * Watcher action: UserID=20 (a user who is watching a repo), ActUserID=1
func NotifyWatchers(ctx context.Context, actions ...*Action) error {
var watchers []*repo_model.Watch
var repo *repo_model.Repository
var err error
var permCode []bool
var permIssue []bool
var permPR []bool
e := db.GetEngine(ctx)
for _, act := range actions {
repoChanged := repo == nil || repo.ID != act.RepoID
if repoChanged {
// Add feeds for user self and all watchers.
watchers, err = repo_model.GetWatchers(ctx, act.RepoID)
if err != nil {
return fmt.Errorf("get watchers: %w", err)
}
}
// Add feed for actioner.
act.UserID = act.ActUserID
if _, err = e.Insert(act); err != nil {
return fmt.Errorf("insert new actioner: %w", err)
}
if repoChanged {
act.LoadRepo(ctx)
repo = act.Repo
// check that the repo owner exists.
if err := act.Repo.LoadOwner(ctx); err != nil {
return fmt.Errorf("can't get repo owner: %w", err)
}
} else if act.Repo == nil {
act.Repo = repo
}
// Add feed for organization
if act.Repo.Owner.IsOrganization() && act.ActUserID != act.Repo.Owner.ID {
act.ID = 0
act.UserID = act.Repo.Owner.ID
if err = db.Insert(ctx, act); err != nil {
return fmt.Errorf("insert new actioner: %w", err)
}
}
if repoChanged {
permCode = make([]bool, len(watchers))
permIssue = make([]bool, len(watchers))
permPR = make([]bool, len(watchers))
for i, watcher := range watchers {
user, err := user_model.GetUserByID(ctx, watcher.UserID)
if err != nil {
permCode[i] = false
permIssue[i] = false
permPR[i] = false
continue
}
perm, err := access_model.GetUserRepoPermission(ctx, repo, user)
if err != nil {
permCode[i] = false
permIssue[i] = false
permPR[i] = false
continue
}
permCode[i] = perm.CanRead(unit.TypeCode)
permIssue[i] = perm.CanRead(unit.TypeIssues)
permPR[i] = perm.CanRead(unit.TypePullRequests)
}
}
for i, watcher := range watchers {
if act.ActUserID == watcher.UserID {
continue
}
act.ID = 0
act.UserID = watcher.UserID
act.Repo.Units = nil
switch act.OpType {
case ActionCommitRepo, ActionPushTag, ActionDeleteTag, ActionPublishRelease, ActionDeleteBranch:
if !permCode[i] {
continue
}
case ActionCreateIssue, ActionCommentIssue, ActionCloseIssue, ActionReopenIssue:
if !permIssue[i] {
continue
}
case ActionCreatePullRequest, ActionCommentPull, ActionMergePullRequest, ActionClosePullRequest, ActionReopenPullRequest, ActionAutoMergePullRequest:
if !permPR[i] {
continue
}
}
if err = db.Insert(ctx, act); err != nil {
return fmt.Errorf("insert new action: %w", err)
}
}
}
return nil
}
// NotifyWatchersActions creates a batch of actions for every watcher.
func NotifyWatchersActions(ctx context.Context, acts []*Action) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
defer committer.Close()
for _, act := range acts {
if err := NotifyWatchers(ctx, act); err != nil {
return err
}
}
return committer.Commit()
}
// DeleteIssueActions deletes all actions related to issueID
func DeleteIssueActions(ctx context.Context, repoID, issueID, issueIndex int64) error {
// delete actions assigned to this issue
e := db.GetEngine(ctx)
// MariaDB has a performance bug: https://jira.mariadb.org/browse/MDEV-16289
// so here it uses "DELETE ... WHERE IN" with pre-queried IDs.
var lastCommentID int64
commentIDs := make([]int64, 0, db.DefaultMaxInSize)
for {
commentIDs = commentIDs[:0]
err := e.Select("`id`").Table(&issues_model.Comment{}).
Where(builder.Eq{"issue_id": issueID}).And("`id` > ?", lastCommentID).
OrderBy("`id`").Limit(db.DefaultMaxInSize).
Find(&commentIDs)
if err != nil {
return err
} else if len(commentIDs) == 0 {
break
} else if _, err = db.GetEngine(ctx).In("comment_id", commentIDs).Delete(&Action{}); err != nil {
return err
}
lastCommentID = commentIDs[len(commentIDs)-1]
}
_, err := e.Where("repo_id = ?", repoID).
In("op_type", ActionCreateIssue, ActionCreatePullRequest).
Where("content LIKE ?", strconv.FormatInt(issueIndex, 10)+"|%"). // "IssueIndex|content..."
Delete(&Action{})
return err
}
// CountActionCreatedUnixString counts the actions where created_unix is an empty string
func CountActionCreatedUnixString(ctx context.Context) (int64, error) {
if setting.Database.Type.IsSQLite3() {
return db.GetEngine(ctx).Where(`created_unix = ''`).Count(new(Action))
}
return 0, nil
}
// FixActionCreatedUnixString sets created_unix to zero where it is an empty string
func FixActionCreatedUnixString(ctx context.Context) (int64, error) {
if setting.Database.Type.IsSQLite3() {
res, err := db.GetEngine(ctx).Exec(`UPDATE action SET created_unix = 0 WHERE created_unix = ''`)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
return 0, nil
}

View File

@@ -0,0 +1,255 @@
// Copyright 2018 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package activities
import (
"context"
"fmt"
"strconv"
"code.gitea.io/gitea/models/db"
issues_model "code.gitea.io/gitea/models/issues"
repo_model "code.gitea.io/gitea/models/repo"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/container"
"code.gitea.io/gitea/modules/util"
"xorm.io/builder"
)
// ActionList defines a list of actions
type ActionList []*Action
func (actions ActionList) getUserIDs() []int64 {
return container.FilterSlice(actions, func(action *Action) (int64, bool) {
return action.ActUserID, true
})
}
func (actions ActionList) LoadActUsers(ctx context.Context) (map[int64]*user_model.User, error) {
if len(actions) == 0 {
return nil, nil
}
userIDs := actions.getUserIDs()
userMaps := make(map[int64]*user_model.User, len(userIDs))
err := db.GetEngine(ctx).
In("id", userIDs).
Find(&userMaps)
if err != nil {
return nil, fmt.Errorf("find user: %w", err)
}
for _, action := range actions {
action.ActUser = userMaps[action.ActUserID]
}
return userMaps, nil
}
func (actions ActionList) getRepoIDs() []int64 {
return container.FilterSlice(actions, func(action *Action) (int64, bool) {
return action.RepoID, true
})
}
func (actions ActionList) LoadRepositories(ctx context.Context) error {
if len(actions) == 0 {
return nil
}
repoIDs := actions.getRepoIDs()
repoMaps := make(map[int64]*repo_model.Repository, len(repoIDs))
err := db.GetEngine(ctx).In("id", repoIDs).Find(&repoMaps)
if err != nil {
return fmt.Errorf("find repository: %w", err)
}
for _, action := range actions {
action.Repo = repoMaps[action.RepoID]
}
repos := repo_model.RepositoryList(util.ValuesOfMap(repoMaps))
return repos.LoadUnits(ctx)
}
func (actions ActionList) loadRepoOwner(ctx context.Context, userMap map[int64]*user_model.User) (err error) {
if userMap == nil {
userMap = make(map[int64]*user_model.User)
}
missingUserIDs := container.FilterSlice(actions, func(action *Action) (int64, bool) {
if action.Repo == nil {
return 0, false
}
_, alreadyLoaded := userMap[action.Repo.OwnerID]
return action.Repo.OwnerID, !alreadyLoaded
})
if len(missingUserIDs) == 0 {
return nil
}
if err := db.GetEngine(ctx).
In("id", missingUserIDs).
Find(&userMap); err != nil {
return fmt.Errorf("find user: %w", err)
}
for _, action := range actions {
if action.Repo != nil {
action.Repo.Owner = userMap[action.Repo.OwnerID]
}
}
return nil
}
// LoadAttributes loads all attributes
func (actions ActionList) LoadAttributes(ctx context.Context) error {
// the load sequence cannot be changed because of the dependencies
userMap, err := actions.LoadActUsers(ctx)
if err != nil {
return err
}
if err := actions.LoadRepositories(ctx); err != nil {
return err
}
if err := actions.loadRepoOwner(ctx, userMap); err != nil {
return err
}
if err := actions.LoadIssues(ctx); err != nil {
return err
}
return actions.LoadComments(ctx)
}
func (actions ActionList) LoadComments(ctx context.Context) error {
if len(actions) == 0 {
return nil
}
commentIDs := make([]int64, 0, len(actions))
for _, action := range actions {
if action.CommentID > 0 {
commentIDs = append(commentIDs, action.CommentID)
}
}
if len(commentIDs) == 0 {
return nil
}
commentsMap := make(map[int64]*issues_model.Comment, len(commentIDs))
if err := db.GetEngine(ctx).In("id", commentIDs).Find(&commentsMap); err != nil {
return fmt.Errorf("find comment: %w", err)
}
for _, action := range actions {
if action.CommentID > 0 {
action.Comment = commentsMap[action.CommentID]
if action.Comment != nil {
action.Comment.Issue = action.Issue
}
}
}
return nil
}
func (actions ActionList) LoadIssues(ctx context.Context) error {
if len(actions) == 0 {
return nil
}
conditions := builder.NewCond()
issueNum := 0
for _, action := range actions {
if action.IsIssueEvent() {
infos := action.GetIssueInfos()
if len(infos) == 0 {
continue
}
index, _ := strconv.ParseInt(infos[0], 10, 64)
if index > 0 {
conditions = conditions.Or(builder.Eq{
"repo_id": action.RepoID,
"`index`": index,
})
issueNum++
}
}
}
if !conditions.IsValid() {
return nil
}
issuesMap := make(map[string]*issues_model.Issue, issueNum)
issues := make([]*issues_model.Issue, 0, issueNum)
if err := db.GetEngine(ctx).Where(conditions).Find(&issues); err != nil {
return fmt.Errorf("find issue: %w", err)
}
for _, issue := range issues {
issuesMap[fmt.Sprintf("%d-%d", issue.RepoID, issue.Index)] = issue
}
for _, action := range actions {
if !action.IsIssueEvent() {
continue
}
if index := action.getIssueIndex(); index > 0 {
if issue, ok := issuesMap[fmt.Sprintf("%d-%d", action.RepoID, index)]; ok {
action.Issue = issue
action.Issue.Repo = action.Repo
}
}
}
return nil
}
// GetFeeds returns actions according to the provided options
func GetFeeds(ctx context.Context, opts GetFeedsOptions) (ActionList, int64, error) {
if opts.RequestedUser == nil && opts.RequestedTeam == nil && opts.RequestedRepo == nil {
return nil, 0, fmt.Errorf("need at least one of these filters: RequestedUser, RequestedTeam, RequestedRepo")
}
cond, err := ActivityQueryCondition(ctx, opts)
if err != nil {
return nil, 0, err
}
actions := make([]*Action, 0, opts.PageSize)
var count int64
opts.SetDefaultValues()
if opts.Page < 10 { // TODO: why 10 and not another value? It is an empirical value.
sess := db.GetEngine(ctx).Where(cond)
sess = db.SetSessionPagination(sess, &opts)
count, err = sess.Desc("`action`.created_unix").FindAndCount(&actions)
if err != nil {
return nil, 0, fmt.Errorf("FindAndCount: %w", err)
}
} else {
// First, only query which IDs are necessary, and only then query all actions to speed up the overall query
sess := db.GetEngine(ctx).Where(cond).Select("`action`.id")
sess = db.SetSessionPagination(sess, &opts)
actionIDs := make([]int64, 0, opts.PageSize)
if err := sess.Table("action").Desc("`action`.created_unix").Find(&actionIDs); err != nil {
return nil, 0, fmt.Errorf("Find(actionsIDs): %w", err)
}
count, err = db.GetEngine(ctx).Where(cond).
Table("action").
Cols("`action`.id").Count()
if err != nil {
return nil, 0, fmt.Errorf("Count: %w", err)
}
if err := db.GetEngine(ctx).In("`action`.id", actionIDs).Desc("`action`.created_unix").Find(&actions); err != nil {
return nil, 0, fmt.Errorf("Find: %w", err)
}
}
if err := ActionList(actions).LoadAttributes(ctx); err != nil {
return nil, 0, fmt.Errorf("LoadAttributes: %w", err)
}
return actions, count, nil
}
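// A minimal, hypothetical usage sketch of GetFeeds: fetch the first page of a
// user's own activity. The option values are illustrative, and it assumes that
// GetFeedsOptions embeds db.ListOptions, as its use with SetSessionPagination
// above suggests; real callers build GetFeedsOptions from the request context.
func listOwnFeedPage(ctx context.Context, user *user_model.User) (ActionList, int64, error) {
    return GetFeeds(ctx, GetFeedsOptions{
        ListOptions:     db.ListOptions{Page: 1, PageSize: 20},
        RequestedUser:   user,
        Actor:           user,
        IncludePrivate:  true,
        OnlyPerformedBy: true,
    })
}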

View File

@@ -0,0 +1,196 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package activities_test
import (
"fmt"
"path"
"testing"
activities_model "code.gitea.io/gitea/models/activities"
"code.gitea.io/gitea/models/db"
issue_model "code.gitea.io/gitea/models/issues"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/test"
"github.com/stretchr/testify/assert"
)
func TestAction_GetRepoPath(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
owner := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: repo.OwnerID})
action := &activities_model.Action{RepoID: repo.ID}
assert.Equal(t, path.Join(owner.Name, repo.Name), action.GetRepoPath(db.DefaultContext))
}
func TestAction_GetRepoLink(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: 1})
owner := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: repo.OwnerID})
comment := unittest.AssertExistsAndLoadBean(t, &issue_model.Comment{ID: 2})
action := &activities_model.Action{RepoID: repo.ID, CommentID: comment.ID}
defer test.MockVariableValue(&setting.AppURL, "https://try.gitea.io/suburl/")()
defer test.MockVariableValue(&setting.AppSubURL, "/suburl")()
expected := path.Join(setting.AppSubURL, owner.Name, repo.Name)
assert.Equal(t, expected, action.GetRepoLink(db.DefaultContext))
assert.Equal(t, repo.HTMLURL(), action.GetRepoAbsoluteLink(db.DefaultContext))
assert.Equal(t, comment.HTMLURL(db.DefaultContext), action.GetCommentHTMLURL(db.DefaultContext))
}
func TestActivityReadable(t *testing.T) {
tt := []struct {
desc string
user *user_model.User
doer *user_model.User
result bool
}{{
desc: "user should see own activity",
user: &user_model.User{ID: 1},
doer: &user_model.User{ID: 1},
result: true,
}, {
desc: "anon should see activity if public",
user: &user_model.User{ID: 1},
result: true,
}, {
desc: "anon should NOT see activity",
user: &user_model.User{ID: 1, KeepActivityPrivate: true},
result: false,
}, {
desc: "user should see own activity if private too",
user: &user_model.User{ID: 1, KeepActivityPrivate: true},
doer: &user_model.User{ID: 1},
result: true,
}, {
desc: "other user should NOT see activity",
user: &user_model.User{ID: 1, KeepActivityPrivate: true},
doer: &user_model.User{ID: 2},
result: false,
}, {
desc: "admin should see activity",
user: &user_model.User{ID: 1, KeepActivityPrivate: true},
doer: &user_model.User{ID: 2, IsAdmin: true},
result: true,
}}
for _, test := range tt {
assert.Equal(t, test.result, activities_model.ActivityReadable(test.user, test.doer), test.desc)
}
}
func TestNotifyWatchers(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
action := &activities_model.Action{
ActUserID: 8,
RepoID: 1,
OpType: activities_model.ActionStarRepo,
}
assert.NoError(t, activities_model.NotifyWatchers(db.DefaultContext, action))
// One watcher is inactive, thus actions are only created for users 8, 1, 4 and 11
unittest.AssertExistsAndLoadBean(t, &activities_model.Action{
ActUserID: action.ActUserID,
UserID: 8,
RepoID: action.RepoID,
OpType: action.OpType,
})
unittest.AssertExistsAndLoadBean(t, &activities_model.Action{
ActUserID: action.ActUserID,
UserID: 1,
RepoID: action.RepoID,
OpType: action.OpType,
})
unittest.AssertExistsAndLoadBean(t, &activities_model.Action{
ActUserID: action.ActUserID,
UserID: 4,
RepoID: action.RepoID,
OpType: action.OpType,
})
unittest.AssertExistsAndLoadBean(t, &activities_model.Action{
ActUserID: action.ActUserID,
UserID: 11,
RepoID: action.RepoID,
OpType: action.OpType,
})
}
func TestConsistencyUpdateAction(t *testing.T) {
if !setting.Database.Type.IsSQLite3() {
t.Skip("Test is only for SQLite database.")
}
assert.NoError(t, unittest.PrepareTestDatabase())
id := 8
unittest.AssertExistsAndLoadBean(t, &activities_model.Action{
ID: int64(id),
})
_, err := db.GetEngine(db.DefaultContext).Exec(`UPDATE action SET created_unix = '' WHERE id = ?`, id)
assert.NoError(t, err)
actions := make([]*activities_model.Action, 0, 1)
//
// XORM returns an error when created_unix is a string
//
err = db.GetEngine(db.DefaultContext).Where("id = ?", id).Find(&actions)
if assert.Error(t, err) {
assert.Contains(t, err.Error(), "type string to a int64: invalid syntax")
}
//
// Get rid of incorrectly set created_unix
//
count, err := activities_model.CountActionCreatedUnixString(db.DefaultContext)
assert.NoError(t, err)
assert.EqualValues(t, 1, count)
count, err = activities_model.FixActionCreatedUnixString(db.DefaultContext)
assert.NoError(t, err)
assert.EqualValues(t, 1, count)
count, err = activities_model.CountActionCreatedUnixString(db.DefaultContext)
assert.NoError(t, err)
assert.EqualValues(t, 0, count)
count, err = activities_model.FixActionCreatedUnixString(db.DefaultContext)
assert.NoError(t, err)
assert.EqualValues(t, 0, count)
//
// XORM must be happy now
//
assert.NoError(t, db.GetEngine(db.DefaultContext).Where("id = ?", id).Find(&actions))
unittest.CheckConsistencyFor(t, &activities_model.Action{})
}
func TestDeleteIssueActions(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
// load an issue
issue := unittest.AssertExistsAndLoadBean(t, &issue_model.Issue{ID: 4})
assert.NotEqualValues(t, issue.ID, issue.Index) // the issue must have different ID and Index values so that DeleteIssueActions is really exercised for deleting actions by IssueIndex
// insert a comment
err := db.Insert(db.DefaultContext, &issue_model.Comment{Type: issue_model.CommentTypeComment, IssueID: issue.ID})
assert.NoError(t, err)
comment := unittest.AssertExistsAndLoadBean(t, &issue_model.Comment{Type: issue_model.CommentTypeComment, IssueID: issue.ID})
// truncate action table and insert some actions
err = db.TruncateBeans(db.DefaultContext, &activities_model.Action{})
assert.NoError(t, err)
err = db.Insert(db.DefaultContext, &activities_model.Action{
OpType: activities_model.ActionCommentIssue,
CommentID: comment.ID,
})
assert.NoError(t, err)
err = db.Insert(db.DefaultContext, &activities_model.Action{
OpType: activities_model.ActionCreateIssue,
RepoID: issue.RepoID,
Content: fmt.Sprintf("%d|content...", issue.Index),
})
assert.NoError(t, err)
// assert that the actions exist, then delete them
unittest.AssertCount(t, &activities_model.Action{}, 2)
assert.NoError(t, activities_model.DeleteIssueActions(db.DefaultContext, issue.RepoID, issue.ID, issue.Index))
unittest.AssertCount(t, &activities_model.Action{}, 0)
}

View File

@@ -0,0 +1,17 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package activities_test
import (
"testing"
"code.gitea.io/gitea/models/unittest"
_ "code.gitea.io/gitea/models"
_ "code.gitea.io/gitea/models/actions"
)
func TestMain(m *testing.M) {
unittest.MainTest(m)
}

View File

@@ -0,0 +1,418 @@
// Copyright 2016 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package activities
import (
"context"
"fmt"
"net/url"
"strconv"
"code.gitea.io/gitea/models/db"
issues_model "code.gitea.io/gitea/models/issues"
"code.gitea.io/gitea/models/organization"
repo_model "code.gitea.io/gitea/models/repo"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil"
"xorm.io/builder"
"xorm.io/xorm/schemas"
)
type (
// NotificationStatus is the status of the notification (read or unread)
NotificationStatus uint8
// NotificationSource is the source of the notification (issue, PR, commit, etc)
NotificationSource uint8
)
const (
// NotificationStatusUnread represents an unread notification
NotificationStatusUnread NotificationStatus = iota + 1
// NotificationStatusRead represents a read notification
NotificationStatusRead
// NotificationStatusPinned represents a pinned notification
NotificationStatusPinned
)
const (
// NotificationSourceIssue is a notification of an issue
NotificationSourceIssue NotificationSource = iota + 1
// NotificationSourcePullRequest is a notification of a pull request
NotificationSourcePullRequest
// NotificationSourceCommit is a notification of a commit
NotificationSourceCommit
// NotificationSourceRepository is a notification for a repository
NotificationSourceRepository
)
// Notification represents a notification
type Notification struct {
ID int64 `xorm:"pk autoincr"`
UserID int64 `xorm:"NOT NULL"`
RepoID int64 `xorm:"NOT NULL"`
Status NotificationStatus `xorm:"SMALLINT NOT NULL"`
Source NotificationSource `xorm:"SMALLINT NOT NULL"`
IssueID int64 `xorm:"NOT NULL"`
CommitID string
CommentID int64
UpdatedBy int64 `xorm:"NOT NULL"`
Issue *issues_model.Issue `xorm:"-"`
Repository *repo_model.Repository `xorm:"-"`
Comment *issues_model.Comment `xorm:"-"`
User *user_model.User `xorm:"-"`
CreatedUnix timeutil.TimeStamp `xorm:"created NOT NULL"`
UpdatedUnix timeutil.TimeStamp `xorm:"updated NOT NULL"`
}
// TableIndices implements xorm's TableIndices interface
func (n *Notification) TableIndices() []*schemas.Index {
indices := make([]*schemas.Index, 0, 8)
usuuIndex := schemas.NewIndex("u_s_uu", schemas.IndexType)
usuuIndex.AddColumn("user_id", "status", "updated_unix")
indices = append(indices, usuuIndex)
// Add the individual indices that were previously defined in struct tags
userIDIndex := schemas.NewIndex("idx_notification_user_id", schemas.IndexType)
userIDIndex.AddColumn("user_id")
indices = append(indices, userIDIndex)
repoIDIndex := schemas.NewIndex("idx_notification_repo_id", schemas.IndexType)
repoIDIndex.AddColumn("repo_id")
indices = append(indices, repoIDIndex)
statusIndex := schemas.NewIndex("idx_notification_status", schemas.IndexType)
statusIndex.AddColumn("status")
indices = append(indices, statusIndex)
sourceIndex := schemas.NewIndex("idx_notification_source", schemas.IndexType)
sourceIndex.AddColumn("source")
indices = append(indices, sourceIndex)
issueIDIndex := schemas.NewIndex("idx_notification_issue_id", schemas.IndexType)
issueIDIndex.AddColumn("issue_id")
indices = append(indices, issueIDIndex)
commitIDIndex := schemas.NewIndex("idx_notification_commit_id", schemas.IndexType)
commitIDIndex.AddColumn("commit_id")
indices = append(indices, commitIDIndex)
updatedByIndex := schemas.NewIndex("idx_notification_updated_by", schemas.IndexType)
updatedByIndex.AddColumn("updated_by")
indices = append(indices, updatedByIndex)
return indices
}
func init() {
db.RegisterModel(new(Notification))
}
// CreateRepoTransferNotification creates a notification for the user a repository was transferred to
func CreateRepoTransferNotification(ctx context.Context, doer, newOwner *user_model.User, repo *repo_model.Repository) error {
return db.WithTx(ctx, func(ctx context.Context) error {
var notify []*Notification
if newOwner.IsOrganization() {
users, err := organization.GetUsersWhoCanCreateOrgRepo(ctx, newOwner.ID)
if err != nil || len(users) == 0 {
return err
}
for i := range users {
notify = append(notify, &Notification{
UserID: i,
RepoID: repo.ID,
Status: NotificationStatusUnread,
UpdatedBy: doer.ID,
Source: NotificationSourceRepository,
})
}
} else {
notify = []*Notification{{
UserID: newOwner.ID,
RepoID: repo.ID,
Status: NotificationStatusUnread,
UpdatedBy: doer.ID,
Source: NotificationSourceRepository,
}}
}
return db.Insert(ctx, notify)
})
}
func createIssueNotification(ctx context.Context, userID int64, issue *issues_model.Issue, commentID, updatedByID int64) error {
notification := &Notification{
UserID: userID,
RepoID: issue.RepoID,
Status: NotificationStatusUnread,
IssueID: issue.ID,
CommentID: commentID,
UpdatedBy: updatedByID,
}
if issue.IsPull {
notification.Source = NotificationSourcePullRequest
} else {
notification.Source = NotificationSourceIssue
}
return db.Insert(ctx, notification)
}
func updateIssueNotification(ctx context.Context, userID, issueID, commentID, updatedByID int64) error {
notification, err := GetIssueNotification(ctx, userID, issueID)
if err != nil {
return err
}
// NOTICE: Only update the comment ID when the previous notification on this issue has already been read; otherwise old comments may be missed.
// But UpdatedBy still needs to be updated so that the notification gets reordered.
var cols []string
if notification.Status == NotificationStatusRead {
notification.Status = NotificationStatusUnread
notification.CommentID = commentID
cols = []string{"status", "update_by", "comment_id"}
} else {
notification.UpdatedBy = updatedByID
cols = []string{"update_by"}
}
_, err = db.GetEngine(ctx).ID(notification.ID).Cols(cols...).Update(notification)
return err
}
// GetIssueNotification returns the notification about an issue
func GetIssueNotification(ctx context.Context, userID, issueID int64) (*Notification, error) {
notification := new(Notification)
_, err := db.GetEngine(ctx).
Where("user_id = ?", userID).
And("issue_id = ?", issueID).
Get(notification)
return notification, err
}
// LoadAttributes loads Repo, Issue, User and Comment if they are not already loaded
func (n *Notification) LoadAttributes(ctx context.Context) (err error) {
if err = n.loadRepo(ctx); err != nil {
return err
}
if err = n.loadIssue(ctx); err != nil {
return err
}
if err = n.loadUser(ctx); err != nil {
return err
}
if err = n.loadComment(ctx); err != nil {
return err
}
return err
}
func (n *Notification) loadRepo(ctx context.Context) (err error) {
if n.Repository == nil {
n.Repository, err = repo_model.GetRepositoryByID(ctx, n.RepoID)
if err != nil {
return fmt.Errorf("getRepositoryByID [%d]: %w", n.RepoID, err)
}
}
return nil
}
func (n *Notification) loadIssue(ctx context.Context) (err error) {
if n.Issue == nil && n.IssueID != 0 {
n.Issue, err = issues_model.GetIssueByID(ctx, n.IssueID)
if err != nil {
return fmt.Errorf("getIssueByID [%d]: %w", n.IssueID, err)
}
return n.Issue.LoadAttributes(ctx)
}
return nil
}
func (n *Notification) loadComment(ctx context.Context) (err error) {
if n.Comment == nil && n.CommentID != 0 {
n.Comment, err = issues_model.GetCommentByID(ctx, n.CommentID)
if err != nil {
if issues_model.IsErrCommentNotExist(err) {
return issues_model.ErrCommentNotExist{
ID: n.CommentID,
IssueID: n.IssueID,
}
}
return err
}
}
return nil
}
func (n *Notification) loadUser(ctx context.Context) (err error) {
if n.User == nil {
n.User, err = user_model.GetUserByID(ctx, n.UserID)
if err != nil {
return fmt.Errorf("getUserByID [%d]: %w", n.UserID, err)
}
}
return nil
}
// GetRepo returns the repo of the notification
func (n *Notification) GetRepo(ctx context.Context) (*repo_model.Repository, error) {
return n.Repository, n.loadRepo(ctx)
}
// GetIssue returns the issue of the notification
func (n *Notification) GetIssue(ctx context.Context) (*issues_model.Issue, error) {
return n.Issue, n.loadIssue(ctx)
}
// HTMLURL formats a URL-string to the notification
func (n *Notification) HTMLURL(ctx context.Context) string {
switch n.Source {
case NotificationSourceIssue, NotificationSourcePullRequest:
if n.Comment != nil {
return n.Comment.HTMLURL(ctx)
}
return n.Issue.HTMLURL()
case NotificationSourceCommit:
return n.Repository.HTMLURL() + "/commit/" + url.PathEscape(n.CommitID)
case NotificationSourceRepository:
return n.Repository.HTMLURL()
}
return ""
}
// Link formats a relative URL-string to the notification
func (n *Notification) Link(ctx context.Context) string {
switch n.Source {
case NotificationSourceIssue, NotificationSourcePullRequest:
if n.Comment != nil {
return n.Comment.Link(ctx)
}
return n.Issue.Link()
case NotificationSourceCommit:
return n.Repository.Link() + "/commit/" + url.PathEscape(n.CommitID)
case NotificationSourceRepository:
return n.Repository.Link()
}
return ""
}
// APIURL formats a URL-string to the notification
func (n *Notification) APIURL() string {
return setting.AppURL + "api/v1/notifications/threads/" + strconv.FormatInt(n.ID, 10)
}
func notificationExists(notifications []*Notification, issueID, userID int64) bool {
for _, notification := range notifications {
if notification.IssueID == issueID && notification.UserID == userID {
return true
}
}
return false
}
// UserIDCount is a simple combination of UserID and Count
type UserIDCount struct {
UserID int64
Count int64
}
// GetUIDsAndNotificationCounts returns the unread counts for every user between the two provided times.
// It must return all user IDs which appear during the period, including count=0 for users who have read all.
func GetUIDsAndNotificationCounts(ctx context.Context, since, until timeutil.TimeStamp) ([]UserIDCount, error) {
sql := `SELECT user_id, sum(case when status= ? then 1 else 0 end) AS count FROM notification ` +
`WHERE user_id IN (SELECT user_id FROM notification WHERE updated_unix >= ? AND ` +
`updated_unix < ?) GROUP BY user_id`
var res []UserIDCount
return res, db.GetEngine(ctx).SQL(sql, NotificationStatusUnread, since, until).Find(&res)
}
// SetIssueReadBy sets issue to be read by given user.
func SetIssueReadBy(ctx context.Context, issueID, userID int64) error {
if err := issues_model.UpdateIssueUserByRead(ctx, userID, issueID); err != nil {
return err
}
return setIssueNotificationStatusReadIfUnread(ctx, userID, issueID)
}
func setIssueNotificationStatusReadIfUnread(ctx context.Context, userID, issueID int64) error {
notification, err := GetIssueNotification(ctx, userID, issueID)
// ignore if not exists
if err != nil {
return nil
}
if notification.Status != NotificationStatusUnread {
return nil
}
notification.Status = NotificationStatusRead
_, err = db.GetEngine(ctx).ID(notification.ID).Cols("status").Update(notification)
return err
}
// SetRepoReadBy sets repo to be visited by given user.
func SetRepoReadBy(ctx context.Context, userID, repoID int64) error {
_, err := db.GetEngine(ctx).Where(builder.Eq{
"user_id": userID,
"status": NotificationStatusUnread,
"source": NotificationSourceRepository,
"repo_id": repoID,
}).Cols("status").Update(&Notification{Status: NotificationStatusRead})
return err
}
// SetNotificationStatus change the notification status
func SetNotificationStatus(ctx context.Context, notificationID int64, user *user_model.User, status NotificationStatus) (*Notification, error) {
notification, err := GetNotificationByID(ctx, notificationID)
if err != nil {
return notification, err
}
if notification.UserID != user.ID {
return nil, fmt.Errorf("Can't change notification of another user: %d, %d", notification.UserID, user.ID)
}
notification.Status = status
_, err = db.GetEngine(ctx).ID(notificationID).Update(notification)
return notification, err
}
// GetNotificationByID returns the notification with the given ID
func GetNotificationByID(ctx context.Context, notificationID int64) (*Notification, error) {
notification := new(Notification)
ok, err := db.GetEngine(ctx).
Where("id = ?", notificationID).
Get(notification)
if err != nil {
return nil, err
}
if !ok {
return nil, db.ErrNotExist{Resource: "notification", ID: notificationID}
}
return notification, nil
}
// UpdateNotificationStatuses updates the statuses of all of a user's notifications that are of the currentStatus type to the desiredStatus
func UpdateNotificationStatuses(ctx context.Context, user *user_model.User, currentStatus, desiredStatus NotificationStatus) error {
n := &Notification{Status: desiredStatus, UpdatedBy: user.ID}
_, err := db.GetEngine(ctx).
Where("user_id = ? AND status = ?", user.ID, currentStatus).
Cols("status", "updated_by", "updated_unix").
Update(n)
return err
}
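// A minimal, hypothetical sketch of a "mark all as read" helper built on top
// of UpdateNotificationStatuses; the wrapper itself is illustrative only.
func markAllNotificationsRead(ctx context.Context, user *user_model.User) error {
    return UpdateNotificationStatuses(ctx, user, NotificationStatusUnread, NotificationStatusRead)
}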

View File

@@ -0,0 +1,499 @@
// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package activities
import (
"context"
"code.gitea.io/gitea/models/db"
issues_model "code.gitea.io/gitea/models/issues"
access_model "code.gitea.io/gitea/models/perm/access"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unit"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/container"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/util"
"xorm.io/builder"
)
// FindNotificationOptions represents the filters for notifications. If an ID is 0 it will be ignored.
type FindNotificationOptions struct {
db.ListOptions
UserID int64
RepoID int64
IssueID int64
Status []NotificationStatus
Source []NotificationSource
UpdatedAfterUnix int64
UpdatedBeforeUnix int64
}
// ToConds converts the find options into a xorm builder.Cond
func (opts FindNotificationOptions) ToConds() builder.Cond {
cond := builder.NewCond()
if opts.UserID != 0 {
cond = cond.And(builder.Eq{"notification.user_id": opts.UserID})
}
if opts.RepoID != 0 {
cond = cond.And(builder.Eq{"notification.repo_id": opts.RepoID})
}
if opts.IssueID != 0 {
cond = cond.And(builder.Eq{"notification.issue_id": opts.IssueID})
}
if len(opts.Status) > 0 {
if len(opts.Status) == 1 {
cond = cond.And(builder.Eq{"notification.status": opts.Status[0]})
} else {
cond = cond.And(builder.In("notification.status", opts.Status))
}
}
if len(opts.Source) > 0 {
cond = cond.And(builder.In("notification.source", opts.Source))
}
if opts.UpdatedAfterUnix != 0 {
cond = cond.And(builder.Gte{"notification.updated_unix": opts.UpdatedAfterUnix})
}
if opts.UpdatedBeforeUnix != 0 {
cond = cond.And(builder.Lte{"notification.updated_unix": opts.UpdatedBeforeUnix})
}
return cond
}
func (opts FindNotificationOptions) ToOrders() string {
return "notification.updated_unix DESC"
}
// CreateOrUpdateIssueNotifications creates an issue notification
// for each watcher, or updates it if it already exists.
// If receiverID > 0 the notification is sent only to the receiver, otherwise it goes to all watchers.
func CreateOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, notificationAuthorID, receiverID int64) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
defer committer.Close()
if err := createOrUpdateIssueNotifications(ctx, issueID, commentID, notificationAuthorID, receiverID); err != nil {
return err
}
return committer.Commit()
}
func createOrUpdateIssueNotifications(ctx context.Context, issueID, commentID, notificationAuthorID, receiverID int64) error {
// init
var toNotify container.Set[int64]
notifications, err := db.Find[Notification](ctx, FindNotificationOptions{
IssueID: issueID,
})
if err != nil {
return err
}
issue, err := issues_model.GetIssueByID(ctx, issueID)
if err != nil {
return err
}
if receiverID > 0 {
toNotify = make(container.Set[int64], 1)
toNotify.Add(receiverID)
} else {
toNotify = make(container.Set[int64], 32)
issueWatches, err := issues_model.GetIssueWatchersIDs(ctx, issueID, true)
if err != nil {
return err
}
toNotify.AddMultiple(issueWatches...)
if !(issue.IsPull && issues_model.HasWorkInProgressPrefix(issue.Title)) {
repoWatches, err := repo_model.GetRepoWatchersIDs(ctx, issue.RepoID)
if err != nil {
return err
}
toNotify.AddMultiple(repoWatches...)
}
issueParticipants, err := issue.GetParticipantIDsByIssue(ctx)
if err != nil {
return err
}
toNotify.AddMultiple(issueParticipants...)
// don't notify the user who caused the notification
delete(toNotify, notificationAuthorID)
// explicit unwatch on issue
issueUnWatches, err := issues_model.GetIssueWatchersIDs(ctx, issueID, false)
if err != nil {
return err
}
for _, id := range issueUnWatches {
toNotify.Remove(id)
}
}
err = issue.LoadRepo(ctx)
if err != nil {
return err
}
// notify
for userID := range toNotify {
issue.Repo.Units = nil
user, err := user_model.GetUserByID(ctx, userID)
if err != nil {
if user_model.IsErrUserNotExist(err) {
continue
}
return err
}
if issue.IsPull && !access_model.CheckRepoUnitUser(ctx, issue.Repo, user, unit.TypePullRequests) {
continue
}
if !issue.IsPull && !access_model.CheckRepoUnitUser(ctx, issue.Repo, user, unit.TypeIssues) {
continue
}
if notificationExists(notifications, issue.ID, userID) {
if err = updateIssueNotification(ctx, userID, issue.ID, commentID, notificationAuthorID); err != nil {
return err
}
continue
}
if err = createIssueNotification(ctx, userID, issue, commentID, notificationAuthorID); err != nil {
return err
}
}
return nil
}
// NotificationList contains a list of notifications
type NotificationList []*Notification
// LoadAttributes loads Repo, Issue, User and Comment if they are not already loaded
func (nl NotificationList) LoadAttributes(ctx context.Context) error {
if _, _, err := nl.LoadRepos(ctx); err != nil {
return err
}
if _, err := nl.LoadIssues(ctx); err != nil {
return err
}
if _, err := nl.LoadUsers(ctx); err != nil {
return err
}
if _, err := nl.LoadComments(ctx); err != nil {
return err
}
return nil
}
func (nl NotificationList) getPendingRepoIDs() []int64 {
return container.FilterSlice(nl, func(n *Notification) (int64, bool) {
if n.Repository != nil {
return 0, false
}
return n.RepoID, true
})
}
// LoadRepos loads repositories from database
func (nl NotificationList) LoadRepos(ctx context.Context) (repo_model.RepositoryList, []int, error) {
if len(nl) == 0 {
return repo_model.RepositoryList{}, []int{}, nil
}
repoIDs := nl.getPendingRepoIDs()
repos := make(map[int64]*repo_model.Repository, len(repoIDs))
left := len(repoIDs)
for left > 0 {
limit := db.DefaultMaxInSize
if left < limit {
limit = left
}
rows, err := db.GetEngine(ctx).
In("id", repoIDs[:limit]).
Rows(new(repo_model.Repository))
if err != nil {
return nil, nil, err
}
for rows.Next() {
var repo repo_model.Repository
err = rows.Scan(&repo)
if err != nil {
rows.Close()
return nil, nil, err
}
repos[repo.ID] = &repo
}
_ = rows.Close()
left -= limit
repoIDs = repoIDs[limit:]
}
failed := []int{}
reposList := make(repo_model.RepositoryList, 0, len(repoIDs))
for i, notification := range nl {
if notification.Repository == nil {
notification.Repository = repos[notification.RepoID]
}
if notification.Repository == nil {
log.Error("Notification[%d]: RepoID: %d not found", notification.ID, notification.RepoID)
failed = append(failed, i)
continue
}
var found bool
for _, r := range reposList {
if r.ID == notification.RepoID {
found = true
break
}
}
if !found {
reposList = append(reposList, notification.Repository)
}
}
return reposList, failed, nil
}
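// The loaders in this file (LoadRepos above, LoadIssues, LoadUsers and
// LoadComments below) all share the same chunked "IN (...)" pattern to stay
// under db.DefaultMaxInSize. A minimal, hypothetical generic sketch of that
// pattern (not part of Gitea's API) could look like this:
func findInChunks[T any](ids []int64, chunkSize int, find func([]int64) ([]T, error)) ([]T, error) {
    result := make([]T, 0, len(ids))
    for len(ids) > 0 {
        limit := chunkSize
        if len(ids) < limit {
            limit = len(ids)
        }
        part, err := find(ids[:limit])
        if err != nil {
            return nil, err
        }
        result = append(result, part...)
        ids = ids[limit:]
    }
    return result, nil
}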
func (nl NotificationList) getPendingIssueIDs() []int64 {
ids := make(container.Set[int64], len(nl))
for _, notification := range nl {
if notification.Issue != nil {
continue
}
ids.Add(notification.IssueID)
}
return ids.Values()
}
// LoadIssues loads issues from database
func (nl NotificationList) LoadIssues(ctx context.Context) ([]int, error) {
if len(nl) == 0 {
return []int{}, nil
}
issueIDs := nl.getPendingIssueIDs()
issues := make(map[int64]*issues_model.Issue, len(issueIDs))
left := len(issueIDs)
for left > 0 {
limit := db.DefaultMaxInSize
if left < limit {
limit = left
}
rows, err := db.GetEngine(ctx).
In("id", issueIDs[:limit]).
Rows(new(issues_model.Issue))
if err != nil {
return nil, err
}
for rows.Next() {
var issue issues_model.Issue
err = rows.Scan(&issue)
if err != nil {
rows.Close()
return nil, err
}
issues[issue.ID] = &issue
}
_ = rows.Close()
left -= limit
issueIDs = issueIDs[limit:]
}
failures := []int{}
for i, notification := range nl {
if notification.Issue == nil {
notification.Issue = issues[notification.IssueID]
if notification.Issue == nil {
if notification.IssueID != 0 {
log.Error("Notification[%d]: IssueID: %d Not Found", notification.ID, notification.IssueID)
failures = append(failures, i)
}
continue
}
notification.Issue.Repo = notification.Repository
}
}
return failures, nil
}
// Without returns the notification list without the failures
func (nl NotificationList) Without(failures []int) NotificationList {
if len(failures) == 0 {
return nl
}
remaining := make([]*Notification, 0, len(nl))
last := -1
var i int
for _, i = range failures {
remaining = append(remaining, nl[last+1:i]...)
last = i
}
if len(nl) > i {
remaining = append(remaining, nl[i+1:]...)
}
return remaining
}
func (nl NotificationList) getPendingCommentIDs() []int64 {
ids := make(container.Set[int64], len(nl))
for _, notification := range nl {
if notification.CommentID == 0 || notification.Comment != nil {
continue
}
ids.Add(notification.CommentID)
}
return ids.Values()
}
func (nl NotificationList) getUserIDs() []int64 {
ids := make(container.Set[int64], len(nl))
for _, notification := range nl {
if notification.UserID == 0 || notification.User != nil {
continue
}
ids.Add(notification.UserID)
}
return ids.Values()
}
// LoadUsers loads users from database
func (nl NotificationList) LoadUsers(ctx context.Context) ([]int, error) {
if len(nl) == 0 {
return []int{}, nil
}
userIDs := nl.getUserIDs()
users := make(map[int64]*user_model.User, len(userIDs))
left := len(userIDs)
for left > 0 {
limit := db.DefaultMaxInSize
if left < limit {
limit = left
}
rows, err := db.GetEngine(ctx).
In("id", userIDs[:limit]).
Rows(new(user_model.User))
if err != nil {
return nil, err
}
for rows.Next() {
var user user_model.User
err = rows.Scan(&user)
if err != nil {
rows.Close()
return nil, err
}
users[user.ID] = &user
}
_ = rows.Close()
left -= limit
userIDs = userIDs[limit:]
}
failures := []int{}
for i, notification := range nl {
if notification.UserID > 0 && notification.User == nil && users[notification.UserID] != nil {
notification.User = users[notification.UserID]
if notification.User == nil {
log.Error("Notification[%d]: UserID[%d] failed to load", notification.ID, notification.UserID)
failures = append(failures, i)
continue
}
}
}
return failures, nil
}
// LoadComments loads comments from database
func (nl NotificationList) LoadComments(ctx context.Context) ([]int, error) {
if len(nl) == 0 {
return []int{}, nil
}
commentIDs := nl.getPendingCommentIDs()
comments := make(map[int64]*issues_model.Comment, len(commentIDs))
left := len(commentIDs)
for left > 0 {
limit := db.DefaultMaxInSize
if left < limit {
limit = left
}
rows, err := db.GetEngine(ctx).
In("id", commentIDs[:limit]).
Rows(new(issues_model.Comment))
if err != nil {
return nil, err
}
for rows.Next() {
var comment issues_model.Comment
err = rows.Scan(&comment)
if err != nil {
rows.Close()
return nil, err
}
comments[comment.ID] = &comment
}
_ = rows.Close()
left -= limit
commentIDs = commentIDs[limit:]
}
failures := []int{}
for i, notification := range nl {
if notification.CommentID > 0 && notification.Comment == nil && comments[notification.CommentID] != nil {
notification.Comment = comments[notification.CommentID]
if notification.Comment == nil {
log.Error("Notification[%d]: CommentID[%d] failed to load", notification.ID, notification.CommentID)
failures = append(failures, i)
continue
}
notification.Comment.Issue = notification.Issue
}
}
return failures, nil
}
// LoadIssuePullRequests loads all issues' pull requests if possible
func (nl NotificationList) LoadIssuePullRequests(ctx context.Context) error {
issues := make(map[int64]*issues_model.Issue, len(nl))
for _, notification := range nl {
if notification.Issue != nil && notification.Issue.IsPull && notification.Issue.PullRequest == nil {
issues[notification.Issue.ID] = notification.Issue
}
}
if len(issues) == 0 {
return nil
}
pulls, err := issues_model.GetPullRequestByIssueIDs(ctx, util.KeysOfMap(issues))
if err != nil {
return err
}
for _, pull := range pulls {
if issue := issues[pull.IssueID]; issue != nil {
issue.PullRequest = pull
issue.PullRequest.Issue = issue
}
}
return nil
}
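// A minimal, hypothetical sketch of the typical load pipeline for a
// NotificationList: load each attribute level, then prune the entries whose
// referenced rows could not be resolved. The wrapper is illustrative; Gitea's
// service layer composes these calls in a similar fashion.
func loadAndPrune(ctx context.Context, nl NotificationList) (NotificationList, error) {
    _, failures, err := nl.LoadRepos(ctx)
    if err != nil {
        return nil, err
    }
    nl = nl.Without(failures)
    if failures, err = nl.LoadIssues(ctx); err != nil {
        return nil, err
    }
    nl = nl.Without(failures)
    if failures, err = nl.LoadUsers(ctx); err != nil {
        return nil, err
    }
    nl = nl.Without(failures)
    if failures, err = nl.LoadComments(ctx); err != nil {
        return nil, err
    }
    return nl.Without(failures), nil
}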

View File

@@ -0,0 +1,140 @@
// Copyright 2017 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package activities_test
import (
"context"
"testing"
activities_model "code.gitea.io/gitea/models/activities"
"code.gitea.io/gitea/models/db"
issues_model "code.gitea.io/gitea/models/issues"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
"github.com/stretchr/testify/assert"
)
func TestCreateOrUpdateIssueNotifications(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 1})
assert.NoError(t, activities_model.CreateOrUpdateIssueNotifications(db.DefaultContext, issue.ID, 0, 2, 0))
// User 9 is inactive, thus notifications for users 1 and 4 are created
notf := unittest.AssertExistsAndLoadBean(t, &activities_model.Notification{UserID: 1, IssueID: issue.ID})
assert.Equal(t, activities_model.NotificationStatusUnread, notf.Status)
unittest.CheckConsistencyFor(t, &issues_model.Issue{ID: issue.ID})
notf = unittest.AssertExistsAndLoadBean(t, &activities_model.Notification{UserID: 4, IssueID: issue.ID})
assert.Equal(t, activities_model.NotificationStatusUnread, notf.Status)
}
func TestNotificationsForUser(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
notfs, err := db.Find[activities_model.Notification](db.DefaultContext, activities_model.FindNotificationOptions{
UserID: user.ID,
Status: []activities_model.NotificationStatus{
activities_model.NotificationStatusRead,
activities_model.NotificationStatusUnread,
},
})
assert.NoError(t, err)
if assert.Len(t, notfs, 3) {
assert.EqualValues(t, 5, notfs[0].ID)
assert.EqualValues(t, user.ID, notfs[0].UserID)
assert.EqualValues(t, 4, notfs[1].ID)
assert.EqualValues(t, user.ID, notfs[1].UserID)
assert.EqualValues(t, 2, notfs[2].ID)
assert.EqualValues(t, user.ID, notfs[2].UserID)
}
}
func TestNotification_GetRepo(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
notf := unittest.AssertExistsAndLoadBean(t, &activities_model.Notification{RepoID: 1})
repo, err := notf.GetRepo(db.DefaultContext)
assert.NoError(t, err)
assert.Equal(t, repo, notf.Repository)
assert.EqualValues(t, notf.RepoID, repo.ID)
}
func TestNotification_GetIssue(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
notf := unittest.AssertExistsAndLoadBean(t, &activities_model.Notification{RepoID: 1})
issue, err := notf.GetIssue(db.DefaultContext)
assert.NoError(t, err)
assert.Equal(t, issue, notf.Issue)
assert.EqualValues(t, notf.IssueID, issue.ID)
}
func TestGetNotificationCount(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1})
cnt, err := db.Count[activities_model.Notification](db.DefaultContext, activities_model.FindNotificationOptions{
UserID: user.ID,
Status: []activities_model.NotificationStatus{
activities_model.NotificationStatusRead,
},
})
assert.NoError(t, err)
assert.EqualValues(t, 0, cnt)
cnt, err = db.Count[activities_model.Notification](db.DefaultContext, activities_model.FindNotificationOptions{
UserID: user.ID,
Status: []activities_model.NotificationStatus{
activities_model.NotificationStatusUnread,
},
})
assert.NoError(t, err)
assert.EqualValues(t, 1, cnt)
}
func TestSetNotificationStatus(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
notf := unittest.AssertExistsAndLoadBean(t,
&activities_model.Notification{UserID: user.ID, Status: activities_model.NotificationStatusRead})
_, err := activities_model.SetNotificationStatus(db.DefaultContext, notf.ID, user, activities_model.NotificationStatusPinned)
assert.NoError(t, err)
unittest.AssertExistsAndLoadBean(t,
&activities_model.Notification{ID: notf.ID, Status: activities_model.NotificationStatusPinned})
_, err = activities_model.SetNotificationStatus(db.DefaultContext, 1, user, activities_model.NotificationStatusRead)
assert.Error(t, err)
_, err = activities_model.SetNotificationStatus(db.DefaultContext, unittest.NonexistentID, user, activities_model.NotificationStatusRead)
assert.Error(t, err)
}
func TestUpdateNotificationStatuses(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2})
notfUnread := unittest.AssertExistsAndLoadBean(t,
&activities_model.Notification{UserID: user.ID, Status: activities_model.NotificationStatusUnread})
notfRead := unittest.AssertExistsAndLoadBean(t,
&activities_model.Notification{UserID: user.ID, Status: activities_model.NotificationStatusRead})
notfPinned := unittest.AssertExistsAndLoadBean(t,
&activities_model.Notification{UserID: user.ID, Status: activities_model.NotificationStatusPinned})
assert.NoError(t, activities_model.UpdateNotificationStatuses(db.DefaultContext, user, activities_model.NotificationStatusUnread, activities_model.NotificationStatusRead))
unittest.AssertExistsAndLoadBean(t,
&activities_model.Notification{ID: notfUnread.ID, Status: activities_model.NotificationStatusRead})
unittest.AssertExistsAndLoadBean(t,
&activities_model.Notification{ID: notfRead.ID, Status: activities_model.NotificationStatusRead})
unittest.AssertExistsAndLoadBean(t,
&activities_model.Notification{ID: notfPinned.ID, Status: activities_model.NotificationStatusPinned})
}
func TestSetIssueReadBy(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1})
issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{ID: 1})
assert.NoError(t, db.WithTx(db.DefaultContext, func(ctx context.Context) error {
return activities_model.SetIssueReadBy(ctx, issue.ID, user.ID)
}))
nt, err := activities_model.GetIssueNotification(db.DefaultContext, user.ID, issue.ID)
assert.NoError(t, err)
assert.EqualValues(t, activities_model.NotificationStatusRead, nt.Status)
}

View File

@@ -0,0 +1,395 @@
// Copyright 2017 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package activities
import (
"context"
"fmt"
"sort"
"time"
"code.gitea.io/gitea/models/db"
issues_model "code.gitea.io/gitea/models/issues"
repo_model "code.gitea.io/gitea/models/repo"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/gitrepo"
"xorm.io/builder"
"xorm.io/xorm"
)
// ActivityAuthorData represents statistical git commit count data
type ActivityAuthorData struct {
Name string `json:"name"`
Login string `json:"login"`
AvatarLink string `json:"avatar_link"`
HomeLink string `json:"home_link"`
Commits int64 `json:"commits"`
}
// ActivityStats represents issue and pull request information.
type ActivityStats struct {
OpenedPRs issues_model.PullRequestList
OpenedPRAuthorCount int64
MergedPRs issues_model.PullRequestList
MergedPRAuthorCount int64
ActiveIssues issues_model.IssueList
OpenedIssues issues_model.IssueList
OpenedIssueAuthorCount int64
ClosedIssues issues_model.IssueList
ClosedIssueAuthorCount int64
UnresolvedIssues issues_model.IssueList
PublishedReleases []*repo_model.Release
PublishedReleaseAuthorCount int64
Code *git.CodeActivityStats
}
// GetActivityStats returns stats for the repository in the given time range
func GetActivityStats(ctx context.Context, repo *repo_model.Repository, timeFrom time.Time, releases, issues, prs, code bool) (*ActivityStats, error) {
stats := &ActivityStats{Code: &git.CodeActivityStats{}}
if releases {
if err := stats.FillReleases(ctx, repo.ID, timeFrom); err != nil {
return nil, fmt.Errorf("FillReleases: %w", err)
}
}
if prs {
if err := stats.FillPullRequests(ctx, repo.ID, timeFrom); err != nil {
return nil, fmt.Errorf("FillPullRequests: %w", err)
}
}
if issues {
if err := stats.FillIssues(ctx, repo.ID, timeFrom); err != nil {
return nil, fmt.Errorf("FillIssues: %w", err)
}
}
if err := stats.FillUnresolvedIssues(ctx, repo.ID, timeFrom, issues, prs); err != nil {
return nil, fmt.Errorf("FillUnresolvedIssues: %w", err)
}
if code {
gitRepo, closer, err := gitrepo.RepositoryFromContextOrOpen(ctx, repo)
if err != nil {
return nil, fmt.Errorf("OpenRepository: %w", err)
}
defer closer.Close()
code, err := gitRepo.GetCodeActivityStats(timeFrom, repo.DefaultBranch)
if err != nil {
return nil, fmt.Errorf("FillFromGit: %w", err)
}
stats.Code = code
}
return stats, nil
}
// GetActivityStatsTopAuthors returns top author stats for git commits for all branches
func GetActivityStatsTopAuthors(ctx context.Context, repo *repo_model.Repository, timeFrom time.Time, count int) ([]*ActivityAuthorData, error) {
gitRepo, closer, err := gitrepo.RepositoryFromContextOrOpen(ctx, repo)
if err != nil {
return nil, fmt.Errorf("OpenRepository: %w", err)
}
defer closer.Close()
code, err := gitRepo.GetCodeActivityStats(timeFrom, "")
if err != nil {
return nil, fmt.Errorf("FillFromGit: %w", err)
}
if code.Authors == nil {
return nil, nil
}
users := make(map[int64]*ActivityAuthorData)
var unknownUserID int64
unknownUserAvatarLink := user_model.NewGhostUser().AvatarLink(ctx)
for _, v := range code.Authors {
if len(v.Email) == 0 {
continue
}
u, err := user_model.GetUserByEmail(ctx, v.Email)
if u == nil || user_model.IsErrUserNotExist(err) {
unknownUserID--
users[unknownUserID] = &ActivityAuthorData{
Name: v.Name,
AvatarLink: unknownUserAvatarLink,
Commits: v.Commits,
}
continue
}
if err != nil {
return nil, err
}
if user, ok := users[u.ID]; !ok {
users[u.ID] = &ActivityAuthorData{
Name: u.DisplayName(),
Login: u.LowerName,
AvatarLink: u.AvatarLink(ctx),
HomeLink: u.HomeLink(),
Commits: v.Commits,
}
} else {
user.Commits += v.Commits
}
}
v := make([]*ActivityAuthorData, 0, len(users))
for _, u := range users {
v = append(v, u)
}
sort.Slice(v, func(i, j int) bool {
return v[i].Commits > v[j].Commits
})
cnt := count
if cnt > len(v) {
cnt = len(v)
}
return v[:cnt], nil
}
// ActivePRCount returns total active pull request count
func (stats *ActivityStats) ActivePRCount() int {
return stats.OpenedPRCount() + stats.MergedPRCount()
}
// OpenedPRCount returns opened pull request count
func (stats *ActivityStats) OpenedPRCount() int {
return len(stats.OpenedPRs)
}
// OpenedPRPerc returns the percentage of opened pull requests among all active ones
func (stats *ActivityStats) OpenedPRPerc() int {
return int(float32(stats.OpenedPRCount()) / float32(stats.ActivePRCount()) * 100.0)
}
// MergedPRCount returns merged pull request count
func (stats *ActivityStats) MergedPRCount() int {
return len(stats.MergedPRs)
}
// MergedPRPerc returns the percentage of merged pull requests among all active ones
func (stats *ActivityStats) MergedPRPerc() int {
return int(float32(stats.MergedPRCount()) / float32(stats.ActivePRCount()) * 100.0)
}
// ActiveIssueCount returns total active issue count
func (stats *ActivityStats) ActiveIssueCount() int {
return len(stats.ActiveIssues)
}
// OpenedIssueCount returns open issue count
func (stats *ActivityStats) OpenedIssueCount() int {
return len(stats.OpenedIssues)
}
// OpenedIssuePerc returns the percentage of opened issues among all active ones
func (stats *ActivityStats) OpenedIssuePerc() int {
return int(float32(stats.OpenedIssueCount()) / float32(stats.ActiveIssueCount()) * 100.0)
}
// ClosedIssueCount returns closed issue count
func (stats *ActivityStats) ClosedIssueCount() int {
return len(stats.ClosedIssues)
}
// ClosedIssuePerc returns the percentage of closed issues among all active ones
func (stats *ActivityStats) ClosedIssuePerc() int {
return int(float32(stats.ClosedIssueCount()) / float32(stats.ActiveIssueCount()) * 100.0)
}
// UnresolvedIssueCount returns unresolved issue and pull request count
func (stats *ActivityStats) UnresolvedIssueCount() int {
return len(stats.UnresolvedIssues)
}
// PublishedReleaseCount returns published release count
func (stats *ActivityStats) PublishedReleaseCount() int {
return len(stats.PublishedReleases)
}
// FillPullRequests collects pull request information for the activity page
func (stats *ActivityStats) FillPullRequests(ctx context.Context, repoID int64, fromTime time.Time) error {
var err error
var count int64
// Merged pull requests
sess := pullRequestsForActivityStatement(ctx, repoID, fromTime, true)
sess.OrderBy("pull_request.merged_unix DESC")
stats.MergedPRs = make(issues_model.PullRequestList, 0)
if err = sess.Find(&stats.MergedPRs); err != nil {
return err
}
if err = stats.MergedPRs.LoadAttributes(ctx); err != nil {
return err
}
// Merged pull request authors
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, true)
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("pull_request").Get(&count); err != nil {
return err
}
stats.MergedPRAuthorCount = count
// Opened pull requests
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, false)
sess.OrderBy("issue.created_unix ASC")
stats.OpenedPRs = make(issues_model.PullRequestList, 0)
if err = sess.Find(&stats.OpenedPRs); err != nil {
return err
}
if err = stats.OpenedPRs.LoadAttributes(ctx); err != nil {
return err
}
// Opened pull request authors
sess = pullRequestsForActivityStatement(ctx, repoID, fromTime, false)
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("pull_request").Get(&count); err != nil {
return err
}
stats.OpenedPRAuthorCount = count
return nil
}
func pullRequestsForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time, merged bool) *xorm.Session {
sess := db.GetEngine(ctx).Where("pull_request.base_repo_id=?", repoID).
Join("INNER", "issue", "pull_request.issue_id = issue.id")
if merged {
sess.And("pull_request.has_merged = ?", true)
sess.And("pull_request.merged_unix >= ?", fromTime.Unix())
} else {
sess.And("issue.is_closed = ?", false)
sess.And("issue.created_unix >= ?", fromTime.Unix())
}
return sess
}
// FillIssues collects issue information for the activity page
func (stats *ActivityStats) FillIssues(ctx context.Context, repoID int64, fromTime time.Time) error {
var err error
var count int64
// Closed issues
sess := issuesForActivityStatement(ctx, repoID, fromTime, true, false)
sess.OrderBy("issue.closed_unix DESC")
stats.ClosedIssues = make(issues_model.IssueList, 0)
if err = sess.Find(&stats.ClosedIssues); err != nil {
return err
}
// Closed issue authors
sess = issuesForActivityStatement(ctx, repoID, fromTime, true, false)
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("issue").Get(&count); err != nil {
return err
}
stats.ClosedIssueAuthorCount = count
// New issues
sess = newlyCreatedIssues(ctx, repoID, fromTime)
sess.OrderBy("issue.created_unix ASC")
stats.OpenedIssues = make(issues_model.IssueList, 0)
if err = sess.Find(&stats.OpenedIssues); err != nil {
return err
}
// Active issues
sess = activeIssues(ctx, repoID, fromTime)
sess.OrderBy("issue.created_unix ASC")
stats.ActiveIssues = make(issues_model.IssueList, 0)
if err = sess.Find(&stats.ActiveIssues); err != nil {
return err
}
// Opened issue authors
sess = issuesForActivityStatement(ctx, repoID, fromTime, false, false)
if _, err = sess.Select("count(distinct issue.poster_id) as `count`").Table("issue").Get(&count); err != nil {
return err
}
stats.OpenedIssueAuthorCount = count
return nil
}
// FillUnresolvedIssues collects unresolved issue and pull request information for the activity page
func (stats *ActivityStats) FillUnresolvedIssues(ctx context.Context, repoID int64, fromTime time.Time, issues, prs bool) error {
// Check if we need to select anything
if !issues && !prs {
return nil
}
sess := issuesForActivityStatement(ctx, repoID, fromTime, false, true)
if !issues || !prs {
sess.And("issue.is_pull = ?", prs)
}
sess.OrderBy("issue.updated_unix DESC")
stats.UnresolvedIssues = make(issues_model.IssueList, 0)
return sess.Find(&stats.UnresolvedIssues)
}
func newlyCreatedIssues(ctx context.Context, repoID int64, fromTime time.Time) *xorm.Session {
sess := db.GetEngine(ctx).Where("issue.repo_id = ?", repoID).
And("issue.is_pull = ?", false). // Retain the is_pull check to exclude pull requests
And("issue.created_unix >= ?", fromTime.Unix()) // Include all issues created after fromTime
return sess
}
func activeIssues(ctx context.Context, repoID int64, fromTime time.Time) *xorm.Session {
sess := db.GetEngine(ctx).Where("issue.repo_id = ?", repoID).
And("issue.is_pull = ?", false).
And(builder.Or(
builder.Gte{"issue.created_unix": fromTime.Unix()},
builder.Gte{"issue.closed_unix": fromTime.Unix()},
))
return sess
}
func issuesForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time, closed, unresolved bool) *xorm.Session {
sess := db.GetEngine(ctx).Where("issue.repo_id = ?", repoID).
And("issue.is_closed = ?", closed)
if !unresolved {
sess.And("issue.is_pull = ?", false)
if closed {
sess.And("issue.closed_unix >= ?", fromTime.Unix())
} else {
sess.And("issue.created_unix >= ?", fromTime.Unix())
}
} else {
sess.And("issue.created_unix < ?", fromTime.Unix())
sess.And("issue.updated_unix >= ?", fromTime.Unix())
}
return sess
}
// FillReleases collects release information for the activity page
func (stats *ActivityStats) FillReleases(ctx context.Context, repoID int64, fromTime time.Time) error {
var err error
var count int64
// Published releases list
sess := releasesForActivityStatement(ctx, repoID, fromTime)
sess.OrderBy("`release`.created_unix DESC")
stats.PublishedReleases = make([]*repo_model.Release, 0)
if err = sess.Find(&stats.PublishedReleases); err != nil {
return err
}
// Published releases authors
sess = releasesForActivityStatement(ctx, repoID, fromTime)
if _, err = sess.Select("count(distinct `release`.publisher_id) as `count`").Table("release").Get(&count); err != nil {
return err
}
stats.PublishedReleaseAuthorCount = count
return nil
}
func releasesForActivityStatement(ctx context.Context, repoID int64, fromTime time.Time) *xorm.Session {
return db.GetEngine(ctx).Where("`release`.repo_id = ?", repoID).
And("`release`.is_draft = ?", false).
And("`release`.created_unix >= ?", fromTime.Unix())
}
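// A minimal, hypothetical usage sketch: gather the complete activity stats for
// the last seven days of a repository. The helper name and the time window are
// illustrative only.
func weeklyActivity(ctx context.Context, repo *repo_model.Repository) (*ActivityStats, error) {
    timeFrom := time.Now().Add(-7 * 24 * time.Hour)
    return GetActivityStats(ctx, repo, timeFrom, true, true, true, true)
}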

View File

@@ -0,0 +1,120 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package activities
import (
"context"
asymkey_model "code.gitea.io/gitea/models/asymkey"
"code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/db"
git_model "code.gitea.io/gitea/models/git"
issues_model "code.gitea.io/gitea/models/issues"
"code.gitea.io/gitea/models/organization"
access_model "code.gitea.io/gitea/models/perm/access"
project_model "code.gitea.io/gitea/models/project"
repo_model "code.gitea.io/gitea/models/repo"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/models/webhook"
"code.gitea.io/gitea/modules/setting"
)
// Statistic contains the database statistics
type Statistic struct {
Counter struct {
User, Org, PublicKey,
Repo, Watch, Star, Access,
Issue, IssueClosed, IssueOpen,
Comment, Oauth, Follow,
Mirror, Release, AuthSource, Webhook,
Milestone, Label, HookTask,
Team, UpdateTask, Project,
ProjectColumn, Attachment,
Branches, Tags, CommitStatus int64
IssueByLabel []IssueByLabelCount
IssueByRepository []IssueByRepositoryCount
}
}
// IssueByLabelCount contains the number of issues grouped by label
type IssueByLabelCount struct {
Count int64
Label string
}
// IssueByRepositoryCount contains the number of issues grouped by repository
type IssueByRepositoryCount struct {
Count int64
OwnerName string
Repository string
}
// GetStatistic returns the database statistics
func GetStatistic(ctx context.Context) (stats Statistic) {
e := db.GetEngine(ctx)
stats.Counter.User = user_model.CountUsers(ctx, nil)
stats.Counter.Org, _ = db.Count[organization.Organization](ctx, organization.FindOrgOptions{IncludePrivate: true})
stats.Counter.PublicKey, _ = e.Count(new(asymkey_model.PublicKey))
stats.Counter.Repo, _ = repo_model.CountRepositories(ctx, repo_model.CountRepositoryOptions{})
stats.Counter.Watch, _ = e.Count(new(repo_model.Watch))
stats.Counter.Star, _ = e.Count(new(repo_model.Star))
stats.Counter.Access, _ = e.Count(new(access_model.Access))
stats.Counter.Branches, _ = e.Count(new(git_model.Branch))
stats.Counter.Tags, _ = e.Where("is_draft=?", false).Count(new(repo_model.Release))
stats.Counter.CommitStatus, _ = e.Count(new(git_model.CommitStatus))
type IssueCount struct {
Count int64
IsClosed bool
}
if setting.Metrics.EnabledIssueByLabel {
stats.Counter.IssueByLabel = []IssueByLabelCount{}
_ = e.Select("COUNT(*) AS count, l.name AS label").
Join("LEFT", "label l", "l.id=il.label_id").
Table("issue_label il").
GroupBy("l.name").
Find(&stats.Counter.IssueByLabel)
}
if setting.Metrics.EnabledIssueByRepository {
stats.Counter.IssueByRepository = []IssueByRepositoryCount{}
_ = e.Select("COUNT(*) AS count, r.owner_name, r.name AS repository").
Join("LEFT", "repository r", "r.id=i.repo_id").
Table("issue i").
GroupBy("r.owner_name, r.name").
Find(&stats.Counter.IssueByRepository)
}
var issueCounts []IssueCount
_ = e.Select("COUNT(*) AS count, is_closed").Table("issue").GroupBy("is_closed").Find(&issueCounts)
for _, c := range issueCounts {
if c.IsClosed {
stats.Counter.IssueClosed = c.Count
} else {
stats.Counter.IssueOpen = c.Count
}
}
stats.Counter.Issue = stats.Counter.IssueClosed + stats.Counter.IssueOpen
stats.Counter.Comment, _ = e.Count(new(issues_model.Comment))
stats.Counter.Oauth = 0
stats.Counter.Follow, _ = e.Count(new(user_model.Follow))
stats.Counter.Mirror, _ = e.Count(new(repo_model.Mirror))
stats.Counter.Release, _ = e.Count(new(repo_model.Release))
stats.Counter.AuthSource, _ = db.Count[auth.Source](ctx, auth.FindSourcesOptions{})
stats.Counter.Webhook, _ = e.Count(new(webhook.Webhook))
stats.Counter.Milestone, _ = e.Count(new(issues_model.Milestone))
stats.Counter.Label, _ = e.Count(new(issues_model.Label))
stats.Counter.HookTask, _ = e.Count(new(webhook.HookTask))
stats.Counter.Team, _ = e.Count(new(organization.Team))
stats.Counter.Attachment, _ = e.Count(new(repo_model.Attachment))
stats.Counter.Project, _ = e.Count(new(project_model.Project))
stats.Counter.ProjectColumn, _ = e.Count(new(project_model.Column))
return stats
}
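A minimal usage sketch (editor's illustration, not part of the Gitea source): a caller that already holds a database-backed context can read the collected counters directly from the returned Statistic. The helper name is hypothetical, and the "context" and "fmt" imports are assumed to be available.

func printStatisticSketch(ctx context.Context) {
	stats := GetStatistic(ctx)
	fmt.Printf("users=%d orgs=%d repos=%d issues=%d (open=%d closed=%d)\n",
		stats.Counter.User, stats.Counter.Org, stats.Counter.Repo,
		stats.Counter.Issue, stats.Counter.IssueOpen, stats.Counter.IssueClosed)
}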


@@ -0,0 +1,82 @@
// Copyright 2018 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package activities
import (
"context"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/organization"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil"
)
// UserHeatmapData represents the data needed to create a heatmap
type UserHeatmapData struct {
Timestamp timeutil.TimeStamp `json:"timestamp"`
Contributions int64 `json:"contributions"`
}
// GetUserHeatmapDataByUser returns an array of UserHeatmapData
func GetUserHeatmapDataByUser(ctx context.Context, user, doer *user_model.User) ([]*UserHeatmapData, error) {
return getUserHeatmapData(ctx, user, nil, doer)
}
// GetUserHeatmapDataByUserTeam returns an array of UserHeatmapData
func GetUserHeatmapDataByUserTeam(ctx context.Context, user *user_model.User, team *organization.Team, doer *user_model.User) ([]*UserHeatmapData, error) {
return getUserHeatmapData(ctx, user, team, doer)
}
func getUserHeatmapData(ctx context.Context, user *user_model.User, team *organization.Team, doer *user_model.User) ([]*UserHeatmapData, error) {
hdata := make([]*UserHeatmapData, 0)
if !ActivityReadable(user, doer) {
return hdata, nil
}
// Group by 15 minute intervals which will allow the client to accurately shift the timestamp to their timezone.
// The interval is based on the fact that there are timezones such as UTC +5:30 and UTC +12:45.
groupBy := "created_unix / 900 * 900"
groupByName := "timestamp" // We need this extra case because mssql doesn't allow grouping by alias
switch {
case setting.Database.Type.IsMySQL():
groupBy = "created_unix DIV 900 * 900"
case setting.Database.Type.IsMSSQL():
groupByName = groupBy
}
cond, err := ActivityQueryCondition(ctx, GetFeedsOptions{
RequestedUser: user,
RequestedTeam: team,
Actor: doer,
IncludePrivate: true, // don't filter by private, as we already filter by repo access
IncludeDeleted: true,
// * Heatmaps for individual users only include actions that the user themself did.
// * For organizations, actions by all users made in repositories owned
// by the organization are counted.
OnlyPerformedBy: !user.IsOrganization(),
})
if err != nil {
return nil, err
}
return hdata, db.GetEngine(ctx).
Select(groupBy+" AS timestamp, count(user_id) as contributions").
Table("action").
Where(cond).
And("created_unix > ?", timeutil.TimeStampNow()-31536000).
GroupBy(groupByName).
OrderBy("timestamp").
Find(&hdata)
}
// GetTotalContributionsInHeatmap returns the total number of contributions in a heatmap
func GetTotalContributionsInHeatmap(hdata []*UserHeatmapData) int64 {
var total int64
for _, v := range hdata {
total += v.Contributions
}
return total
}
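As a brief illustration (editor's sketch, not from the Gitea source), the two exported helpers above compose naturally: fetch a user's heatmap and total the contributions. The function name is hypothetical; ctx, user and doer are assumed to come from the caller.

func heatmapTotalSketch(ctx context.Context, user, doer *user_model.User) (int64, error) {
	hdata, err := GetUserHeatmapDataByUser(ctx, user, doer)
	if err != nil {
		return 0, err
	}
	return GetTotalContributionsInHeatmap(hdata), nil
}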


@@ -0,0 +1,100 @@
// Copyright 2018 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package activities_test
import (
"testing"
"time"
activities_model "code.gitea.io/gitea/models/activities"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/json"
"code.gitea.io/gitea/modules/timeutil"
"github.com/stretchr/testify/assert"
)
func TestGetUserHeatmapDataByUser(t *testing.T) {
testCases := []struct {
desc string
userID int64
doerID int64
CountResult int
JSONResult string
}{
{
"self looks at action in private repo",
2, 2, 1, `[{"timestamp":1603227600,"contributions":1}]`,
},
{
"admin looks at action in private repo",
2, 1, 1, `[{"timestamp":1603227600,"contributions":1}]`,
},
{
"other user looks at action in private repo",
2, 3, 0, `[]`,
},
{
"nobody looks at action in private repo",
2, 0, 0, `[]`,
},
{
"collaborator looks at action in private repo",
16, 15, 1, `[{"timestamp":1603267200,"contributions":1}]`,
},
{
"no action action not performed by target user",
3, 3, 0, `[]`,
},
{
"multiple actions performed with two grouped together",
10, 10, 3, `[{"timestamp":1603009800,"contributions":1},{"timestamp":1603010700,"contributions":2}]`,
},
}
// Prepare
assert.NoError(t, unittest.PrepareTestDatabase())
// Mock time
timeutil.MockSet(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC))
defer timeutil.MockUnset()
for _, tc := range testCases {
user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: tc.userID})
doer := &user_model.User{ID: tc.doerID}
_, err := unittest.LoadBeanIfExists(doer)
assert.NoError(t, err)
if tc.doerID == 0 {
doer = nil
}
// get the action for comparison
actions, count, err := activities_model.GetFeeds(db.DefaultContext, activities_model.GetFeedsOptions{
RequestedUser: user,
Actor: doer,
IncludePrivate: true,
OnlyPerformedBy: true,
IncludeDeleted: true,
})
assert.NoError(t, err)
// Get the heatmap and compare
heatmap, err := activities_model.GetUserHeatmapDataByUser(db.DefaultContext, user, doer)
var contributions int
for _, hm := range heatmap {
contributions += int(hm.Contributions)
}
assert.NoError(t, err)
assert.Len(t, actions, contributions, "invalid action count: did the test data become too old?")
assert.Equal(t, count, int64(contributions))
assert.Equal(t, tc.CountResult, contributions, "testcase '%s'", tc.desc)
// Test JSON rendering
jsonData, err := json.Marshal(heatmap)
assert.NoError(t, err)
assert.JSONEq(t, tc.JSONResult, string(jsonData))
}
}

models/admin/task.go Normal file

@@ -0,0 +1,211 @@
// Copyright 2019 Gitea. All rights reserved.
// SPDX-License-Identifier: MIT
package admin
import (
"context"
"fmt"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/json"
"code.gitea.io/gitea/modules/migration"
"code.gitea.io/gitea/modules/secret"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/structs"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
)
// Task represents a task
type Task struct {
ID int64
DoerID int64 `xorm:"index"` // operator
Doer *user_model.User `xorm:"-"`
OwnerID int64 `xorm:"index"` // repo owner id, when creating, the repoID maybe zero
Owner *user_model.User `xorm:"-"`
RepoID int64 `xorm:"index"`
Repo *repo_model.Repository `xorm:"-"`
Type structs.TaskType
Status structs.TaskStatus `xorm:"index"`
StartTime timeutil.TimeStamp
EndTime timeutil.TimeStamp
PayloadContent string `xorm:"TEXT"`
Message string `xorm:"TEXT"` // if task failed, saved the error reason, it could be a JSON string of TranslatableMessage or a plain message
Created timeutil.TimeStamp `xorm:"created"`
}
func init() {
db.RegisterModel(new(Task))
}
// TranslatableMessage represents a JSON struct that can be translated with a Locale
type TranslatableMessage struct {
Format string
Args []any `json:"omitempty"`
}
// LoadRepo loads repository of the task
func (task *Task) LoadRepo(ctx context.Context) error {
if task.Repo != nil {
return nil
}
var repo repo_model.Repository
has, err := db.GetEngine(ctx).ID(task.RepoID).Get(&repo)
if err != nil {
return err
} else if !has {
return repo_model.ErrRepoNotExist{
ID: task.RepoID,
}
}
task.Repo = &repo
return nil
}
// LoadDoer loads the doer user
func (task *Task) LoadDoer(ctx context.Context) error {
if task.Doer != nil {
return nil
}
var doer user_model.User
has, err := db.GetEngine(ctx).ID(task.DoerID).Get(&doer)
if err != nil {
return err
} else if !has {
return user_model.ErrUserNotExist{
UID: task.DoerID,
}
}
task.Doer = &doer
return nil
}
// LoadOwner loads owner user
func (task *Task) LoadOwner(ctx context.Context) error {
if task.Owner != nil {
return nil
}
var owner user_model.User
has, err := db.GetEngine(ctx).ID(task.OwnerID).Get(&owner)
if err != nil {
return err
} else if !has {
return user_model.ErrUserNotExist{
UID: task.OwnerID,
}
}
task.Owner = &owner
return nil
}
// UpdateCols updates some columns
func (task *Task) UpdateCols(ctx context.Context, cols ...string) error {
_, err := db.GetEngine(ctx).ID(task.ID).Cols(cols...).Update(task)
return err
}
// MigrateConfig returns the task config for a repository migration
func (task *Task) MigrateConfig() (*migration.MigrateOptions, error) {
if task.Type == structs.TaskTypeMigrateRepo {
var opts migration.MigrateOptions
err := json.Unmarshal([]byte(task.PayloadContent), &opts)
if err != nil {
return nil, err
}
// decrypt credentials
if opts.CloneAddrEncrypted != "" {
if opts.CloneAddr, err = secret.DecryptSecret(setting.SecretKey, opts.CloneAddrEncrypted); err != nil {
return nil, err
}
}
if opts.AuthPasswordEncrypted != "" {
if opts.AuthPassword, err = secret.DecryptSecret(setting.SecretKey, opts.AuthPasswordEncrypted); err != nil {
return nil, err
}
}
if opts.AuthTokenEncrypted != "" {
if opts.AuthToken, err = secret.DecryptSecret(setting.SecretKey, opts.AuthTokenEncrypted); err != nil {
return nil, err
}
}
return &opts, nil
}
return nil, fmt.Errorf("Task type is %s, not Migrate Repo", task.Type.Name())
}
// ErrTaskDoesNotExist represents a "TaskDoesNotExist" kind of error.
type ErrTaskDoesNotExist struct {
ID int64
RepoID int64
Type structs.TaskType
}
// IsErrTaskDoesNotExist checks if an error is a ErrTaskDoesNotExist.
func IsErrTaskDoesNotExist(err error) bool {
_, ok := err.(ErrTaskDoesNotExist)
return ok
}
func (err ErrTaskDoesNotExist) Error() string {
return fmt.Sprintf("task does not exist [id: %d, repo_id: %d, type: %d]",
err.ID, err.RepoID, err.Type)
}
func (err ErrTaskDoesNotExist) Unwrap() error {
return util.ErrNotExist
}
// GetMigratingTask returns the migrating task by repo's id
func GetMigratingTask(ctx context.Context, repoID int64) (*Task, error) {
task := Task{
RepoID: repoID,
Type: structs.TaskTypeMigrateRepo,
}
has, err := db.GetEngine(ctx).Get(&task)
if err != nil {
return nil, err
} else if !has {
return nil, ErrTaskDoesNotExist{0, repoID, task.Type}
}
return &task, nil
}
// CreateTask creates a task on database
func CreateTask(ctx context.Context, task *Task) error {
return db.Insert(ctx, task)
}
// FinishMigrateTask updates the database when a migrate task has finished
func FinishMigrateTask(ctx context.Context, task *Task) error {
task.Status = structs.TaskStatusFinished
task.EndTime = timeutil.TimeStampNow()
// delete credentials when we're done, they're a liability.
conf, err := task.MigrateConfig()
if err != nil {
return err
}
conf.AuthPassword = ""
conf.AuthToken = ""
conf.CloneAddr = util.SanitizeCredentialURLs(conf.CloneAddr)
conf.AuthPasswordEncrypted = ""
conf.AuthTokenEncrypted = ""
conf.CloneAddrEncrypted = ""
confBytes, err := json.Marshal(conf)
if err != nil {
return err
}
task.PayloadContent = string(confBytes)
_, err = db.GetEngine(ctx).ID(task.ID).Cols("status", "end_time", "payload_content").Update(task)
return err
}
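A short usage sketch (editor's illustration, not part of the Gitea source): look up a repository's migration task and treat "not found" as a non-fatal condition, using the typed helper defined above. The function name is hypothetical.

func pendingMigrationSketch(ctx context.Context, repoID int64) (*Task, error) {
	task, err := GetMigratingTask(ctx, repoID)
	if err != nil {
		if IsErrTaskDoesNotExist(err) { // equivalently: errors.Is(err, util.ErrNotExist)
			return nil, nil // no migration in progress
		}
		return nil, err
	}
	return task, nil
}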

models/asymkey/error.go Normal file

@@ -0,0 +1,318 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"fmt"
"code.gitea.io/gitea/modules/util"
)
// ErrKeyUnableVerify represents a "KeyUnableVerify" kind of error.
type ErrKeyUnableVerify struct {
Result string
}
// IsErrKeyUnableVerify checks if an error is a ErrKeyUnableVerify.
func IsErrKeyUnableVerify(err error) bool {
_, ok := err.(ErrKeyUnableVerify)
return ok
}
func (err ErrKeyUnableVerify) Error() string {
return fmt.Sprintf("Unable to verify key content [result: %s]", err.Result)
}
// ErrKeyIsPrivate is returned when the provided key is a private key not a public key
var ErrKeyIsPrivate = util.NewSilentWrapErrorf(util.ErrInvalidArgument, "the provided key is a private key")
// ErrKeyNotExist represents a "KeyNotExist" kind of error.
type ErrKeyNotExist struct {
ID int64
}
// IsErrKeyNotExist checks if an error is a ErrKeyNotExist.
func IsErrKeyNotExist(err error) bool {
_, ok := err.(ErrKeyNotExist)
return ok
}
func (err ErrKeyNotExist) Error() string {
return fmt.Sprintf("public key does not exist [id: %d]", err.ID)
}
func (err ErrKeyNotExist) Unwrap() error {
return util.ErrNotExist
}
// ErrKeyAlreadyExist represents a "KeyAlreadyExist" kind of error.
type ErrKeyAlreadyExist struct {
OwnerID int64
Fingerprint string
Content string
}
// IsErrKeyAlreadyExist checks if an error is a ErrKeyAlreadyExist.
func IsErrKeyAlreadyExist(err error) bool {
_, ok := err.(ErrKeyAlreadyExist)
return ok
}
func (err ErrKeyAlreadyExist) Error() string {
return fmt.Sprintf("public key already exists [owner_id: %d, finger_print: %s, content: %s]",
err.OwnerID, err.Fingerprint, err.Content)
}
func (err ErrKeyAlreadyExist) Unwrap() error {
return util.ErrAlreadyExist
}
// ErrKeyNameAlreadyUsed represents a "KeyNameAlreadyUsed" kind of error.
type ErrKeyNameAlreadyUsed struct {
OwnerID int64
Name string
}
// IsErrKeyNameAlreadyUsed checks if an error is a ErrKeyNameAlreadyUsed.
func IsErrKeyNameAlreadyUsed(err error) bool {
_, ok := err.(ErrKeyNameAlreadyUsed)
return ok
}
func (err ErrKeyNameAlreadyUsed) Error() string {
return fmt.Sprintf("public key already exists [owner_id: %d, name: %s]", err.OwnerID, err.Name)
}
func (err ErrKeyNameAlreadyUsed) Unwrap() error {
return util.ErrAlreadyExist
}
// ErrGPGNoEmailFound represents a "ErrGPGNoEmailFound" kind of error.
type ErrGPGNoEmailFound struct {
FailedEmails []string
ID string
}
// IsErrGPGNoEmailFound checks if an error is a ErrGPGNoEmailFound.
func IsErrGPGNoEmailFound(err error) bool {
_, ok := err.(ErrGPGNoEmailFound)
return ok
}
func (err ErrGPGNoEmailFound) Error() string {
return fmt.Sprintf("none of the emails attached to the GPG key could be found: %v", err.FailedEmails)
}
// ErrGPGInvalidTokenSignature represents a "ErrGPGInvalidTokenSignature" kind of error.
type ErrGPGInvalidTokenSignature struct {
Wrapped error
ID string
}
// IsErrGPGInvalidTokenSignature checks if an error is a ErrGPGInvalidTokenSignature.
func IsErrGPGInvalidTokenSignature(err error) bool {
_, ok := err.(ErrGPGInvalidTokenSignature)
return ok
}
func (err ErrGPGInvalidTokenSignature) Error() string {
return "the provided signature does not sign the token with the provided key"
}
// ErrGPGKeyParsing represents a "ErrGPGKeyParsing" kind of error.
type ErrGPGKeyParsing struct {
ParseError error
}
// IsErrGPGKeyParsing checks if an error is a ErrGPGKeyParsing.
func IsErrGPGKeyParsing(err error) bool {
_, ok := err.(ErrGPGKeyParsing)
return ok
}
func (err ErrGPGKeyParsing) Error() string {
return fmt.Sprintf("failed to parse gpg key %s", err.ParseError.Error())
}
// ErrGPGKeyNotExist represents a "GPGKeyNotExist" kind of error.
type ErrGPGKeyNotExist struct {
ID int64
}
// IsErrGPGKeyNotExist checks if an error is a ErrGPGKeyNotExist.
func IsErrGPGKeyNotExist(err error) bool {
_, ok := err.(ErrGPGKeyNotExist)
return ok
}
func (err ErrGPGKeyNotExist) Error() string {
return fmt.Sprintf("public gpg key does not exist [id: %d]", err.ID)
}
func (err ErrGPGKeyNotExist) Unwrap() error {
return util.ErrNotExist
}
// ErrGPGKeyImportNotExist represents a "GPGKeyImportNotExist" kind of error.
type ErrGPGKeyImportNotExist struct {
ID string
}
// IsErrGPGKeyImportNotExist checks if an error is a ErrGPGKeyImportNotExist.
func IsErrGPGKeyImportNotExist(err error) bool {
_, ok := err.(ErrGPGKeyImportNotExist)
return ok
}
func (err ErrGPGKeyImportNotExist) Error() string {
return fmt.Sprintf("public gpg key import does not exist [id: %s]", err.ID)
}
func (err ErrGPGKeyImportNotExist) Unwrap() error {
return util.ErrNotExist
}
// ErrGPGKeyIDAlreadyUsed represents a "GPGKeyIDAlreadyUsed" kind of error.
type ErrGPGKeyIDAlreadyUsed struct {
KeyID string
}
// IsErrGPGKeyIDAlreadyUsed checks if an error is an ErrGPGKeyIDAlreadyUsed.
func IsErrGPGKeyIDAlreadyUsed(err error) bool {
_, ok := err.(ErrGPGKeyIDAlreadyUsed)
return ok
}
func (err ErrGPGKeyIDAlreadyUsed) Error() string {
return fmt.Sprintf("public key already exists [key_id: %s]", err.KeyID)
}
func (err ErrGPGKeyIDAlreadyUsed) Unwrap() error {
return util.ErrAlreadyExist
}
// ErrGPGKeyAccessDenied represents a "GPGKeyAccessDenied" kind of Error.
type ErrGPGKeyAccessDenied struct {
UserID int64
KeyID int64
}
// IsErrGPGKeyAccessDenied checks if an error is a ErrGPGKeyAccessDenied.
func IsErrGPGKeyAccessDenied(err error) bool {
_, ok := err.(ErrGPGKeyAccessDenied)
return ok
}
// Error pretty-prints an error of type ErrGPGKeyAccessDenied.
func (err ErrGPGKeyAccessDenied) Error() string {
return fmt.Sprintf("user does not have access to the key [user_id: %d, key_id: %d]",
err.UserID, err.KeyID)
}
func (err ErrGPGKeyAccessDenied) Unwrap() error {
return util.ErrPermissionDenied
}
// ErrKeyAccessDenied represents a "KeyAccessDenied" kind of error.
type ErrKeyAccessDenied struct {
UserID int64
KeyID int64
Note string
}
// IsErrKeyAccessDenied checks if an error is a ErrKeyAccessDenied.
func IsErrKeyAccessDenied(err error) bool {
_, ok := err.(ErrKeyAccessDenied)
return ok
}
func (err ErrKeyAccessDenied) Error() string {
return fmt.Sprintf("user does not have access to the key [user_id: %d, key_id: %d, note: %s]",
err.UserID, err.KeyID, err.Note)
}
func (err ErrKeyAccessDenied) Unwrap() error {
return util.ErrPermissionDenied
}
// ErrDeployKeyNotExist represents a "DeployKeyNotExist" kind of error.
type ErrDeployKeyNotExist struct {
ID int64
KeyID int64
RepoID int64
}
// IsErrDeployKeyNotExist checks if an error is a ErrDeployKeyNotExist.
func IsErrDeployKeyNotExist(err error) bool {
_, ok := err.(ErrDeployKeyNotExist)
return ok
}
func (err ErrDeployKeyNotExist) Error() string {
return fmt.Sprintf("Deploy key does not exist [id: %d, key_id: %d, repo_id: %d]", err.ID, err.KeyID, err.RepoID)
}
func (err ErrDeployKeyNotExist) Unwrap() error {
return util.ErrNotExist
}
// ErrDeployKeyAlreadyExist represents a "DeployKeyAlreadyExist" kind of error.
type ErrDeployKeyAlreadyExist struct {
KeyID int64
RepoID int64
}
// IsErrDeployKeyAlreadyExist checks if an error is a ErrDeployKeyAlreadyExist.
func IsErrDeployKeyAlreadyExist(err error) bool {
_, ok := err.(ErrDeployKeyAlreadyExist)
return ok
}
func (err ErrDeployKeyAlreadyExist) Error() string {
return fmt.Sprintf("public key already exists [key_id: %d, repo_id: %d]", err.KeyID, err.RepoID)
}
func (err ErrDeployKeyAlreadyExist) Unwrap() error {
return util.ErrAlreadyExist
}
// ErrDeployKeyNameAlreadyUsed represents a "DeployKeyNameAlreadyUsed" kind of error.
type ErrDeployKeyNameAlreadyUsed struct {
RepoID int64
Name string
}
// IsErrDeployKeyNameAlreadyUsed checks if an error is a ErrDeployKeyNameAlreadyUsed.
func IsErrDeployKeyNameAlreadyUsed(err error) bool {
_, ok := err.(ErrDeployKeyNameAlreadyUsed)
return ok
}
func (err ErrDeployKeyNameAlreadyUsed) Error() string {
return fmt.Sprintf("public key with name already exists [repo_id: %d, name: %s]", err.RepoID, err.Name)
}
func (err ErrDeployKeyNameAlreadyUsed) Unwrap() error {
return util.ErrAlreadyExist
}
// ErrSSHInvalidTokenSignature represents a "ErrSSHInvalidTokenSignature" kind of error.
type ErrSSHInvalidTokenSignature struct {
Wrapped error
Fingerprint string
}
// IsErrSSHInvalidTokenSignature checks if an error is a ErrSSHInvalidTokenSignature.
func IsErrSSHInvalidTokenSignature(err error) bool {
_, ok := err.(ErrSSHInvalidTokenSignature)
return ok
}
func (err ErrSSHInvalidTokenSignature) Error() string {
return "the provided signature does not sign the token with the provided key"
}
func (err ErrSSHInvalidTokenSignature) Unwrap() error {
return util.ErrInvalidArgument
}
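To show how these errors are meant to be consumed (editor's sketch, not from the Gitea source): because each error type also unwraps to a util sentinel, callers can branch either on the typed helpers above or on errors.Is. The function name is hypothetical.

func describeKeyErrorSketch(err error) string {
	switch {
	case err == nil:
		return "ok"
	case IsErrKeyNotExist(err): // also matches errors.Is(err, util.ErrNotExist)
		return "key not found"
	case IsErrKeyAccessDenied(err):
		return "access denied"
	default:
		return err.Error()
	}
}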

models/asymkey/gpg_key.go Normal file

@@ -0,0 +1,267 @@
// Copyright 2017 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"context"
"fmt"
"strings"
"time"
"code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/timeutil"
"github.com/keybase/go-crypto/openpgp"
"github.com/keybase/go-crypto/openpgp/packet"
"xorm.io/builder"
)
// GPGKey represents a GPG key.
type GPGKey struct {
ID int64 `xorm:"pk autoincr"`
OwnerID int64 `xorm:"INDEX NOT NULL"`
KeyID string `xorm:"INDEX CHAR(16) NOT NULL"`
PrimaryKeyID string `xorm:"CHAR(16)"`
Content string `xorm:"MEDIUMTEXT NOT NULL"`
CreatedUnix timeutil.TimeStamp `xorm:"created"`
ExpiredUnix timeutil.TimeStamp
AddedUnix timeutil.TimeStamp
SubsKey []*GPGKey `xorm:"-"`
Emails []*user_model.EmailAddress
Verified bool `xorm:"NOT NULL DEFAULT false"`
CanSign bool
CanEncryptComms bool
CanEncryptStorage bool
CanCertify bool
}
func init() {
db.RegisterModel(new(GPGKey))
}
// BeforeInsert will be invoked by XORM before inserting a record
func (key *GPGKey) BeforeInsert() {
key.AddedUnix = timeutil.TimeStampNow()
}
func (key *GPGKey) LoadSubKeys(ctx context.Context) error {
if err := db.GetEngine(ctx).Where("primary_key_id=?", key.KeyID).Find(&key.SubsKey); err != nil {
return fmt.Errorf("find Sub GPGkeys[%s]: %v", key.KeyID, err)
}
return nil
}
// PaddedKeyID returns the KeyID padded to 16 characters
func (key *GPGKey) PaddedKeyID() string {
return PaddedKeyID(key.KeyID)
}
// PaddedKeyID returns the given KeyID padded to 16 characters
func PaddedKeyID(keyID string) string {
if len(keyID) > 15 {
return keyID
}
zeros := "0000000000000000"
return zeros[0:16-len(keyID)] + keyID
}
type FindGPGKeyOptions struct {
db.ListOptions
OwnerID int64
KeyID string
IncludeSubKeys bool
}
func (opts FindGPGKeyOptions) ToConds() builder.Cond {
cond := builder.NewCond()
if !opts.IncludeSubKeys {
cond = cond.And(builder.Eq{"primary_key_id": ""})
}
if opts.OwnerID > 0 {
cond = cond.And(builder.Eq{"owner_id": opts.OwnerID})
}
if opts.KeyID != "" {
cond = cond.And(builder.Eq{"key_id": opts.KeyID})
}
return cond
}
func GetGPGKeyForUserByID(ctx context.Context, ownerID, keyID int64) (*GPGKey, error) {
key := new(GPGKey)
has, err := db.GetEngine(ctx).Where("id=? AND owner_id=?", keyID, ownerID).Get(key)
if err != nil {
return nil, err
} else if !has {
return nil, ErrGPGKeyNotExist{keyID}
}
return key, nil
}
// GPGKeyToEntity retrieves the imported key and the corresponding openpgp entity
func GPGKeyToEntity(ctx context.Context, k *GPGKey) (*openpgp.Entity, error) {
impKey, err := GetGPGImportByKeyID(ctx, k.KeyID)
if err != nil {
return nil, err
}
keys, err := checkArmoredGPGKeyString(impKey.Content)
if err != nil {
return nil, err
}
return keys[0], err
}
// parseSubGPGKey parses a subkey
func parseSubGPGKey(ownerID int64, primaryID string, pubkey *packet.PublicKey, expiry time.Time) (*GPGKey, error) {
content, err := base64EncPubKey(pubkey)
if err != nil {
return nil, err
}
return &GPGKey{
OwnerID: ownerID,
KeyID: pubkey.KeyIdString(),
PrimaryKeyID: primaryID,
Content: content,
CreatedUnix: timeutil.TimeStamp(pubkey.CreationTime.Unix()),
ExpiredUnix: timeutil.TimeStamp(expiry.Unix()),
CanSign: pubkey.CanSign(),
CanEncryptComms: pubkey.PubKeyAlgo.CanEncrypt(),
CanEncryptStorage: pubkey.PubKeyAlgo.CanEncrypt(),
CanCertify: pubkey.PubKeyAlgo.CanSign(),
}, nil
}
// parseGPGKey parses a PrimaryKey entity (primary key + subkeys + self-signature)
func parseGPGKey(ctx context.Context, ownerID int64, e *openpgp.Entity, verified bool) (*GPGKey, error) {
pubkey := e.PrimaryKey
expiry := getExpiryTime(e)
// Parse Subkeys
subkeys := make([]*GPGKey, len(e.Subkeys))
for i, k := range e.Subkeys {
subs, err := parseSubGPGKey(ownerID, pubkey.KeyIdString(), k.PublicKey, expiry)
if err != nil {
return nil, ErrGPGKeyParsing{ParseError: err}
}
subkeys[i] = subs
}
// Check emails
userEmails, err := user_model.GetEmailAddresses(ctx, ownerID)
if err != nil {
return nil, err
}
emails := make([]*user_model.EmailAddress, 0, len(e.Identities))
for _, ident := range e.Identities {
if ident.Revocation != nil {
continue
}
email := strings.ToLower(strings.TrimSpace(ident.UserId.Email))
for _, e := range userEmails {
if e.IsActivated && e.LowerEmail == email {
emails = append(emails, e)
break
}
}
}
if !verified {
// In case no email has been found
if len(emails) == 0 {
failedEmails := make([]string, 0, len(e.Identities))
for _, ident := range e.Identities {
failedEmails = append(failedEmails, ident.UserId.Email)
}
return nil, ErrGPGNoEmailFound{failedEmails, e.PrimaryKey.KeyIdString()}
}
}
content, err := base64EncPubKey(pubkey)
if err != nil {
return nil, err
}
return &GPGKey{
OwnerID: ownerID,
KeyID: pubkey.KeyIdString(),
PrimaryKeyID: "",
Content: content,
CreatedUnix: timeutil.TimeStamp(pubkey.CreationTime.Unix()),
ExpiredUnix: timeutil.TimeStamp(expiry.Unix()),
Emails: emails,
SubsKey: subkeys,
Verified: verified,
CanSign: pubkey.CanSign(),
CanEncryptComms: pubkey.PubKeyAlgo.CanEncrypt(),
CanEncryptStorage: pubkey.PubKeyAlgo.CanEncrypt(),
CanCertify: pubkey.PubKeyAlgo.CanSign(),
}, nil
}
// deleteGPGKey does the actual key deletion
func deleteGPGKey(ctx context.Context, keyID string) (int64, error) {
if keyID == "" {
return 0, fmt.Errorf("empty KeyId forbidden") // Should never happen but just to be sure
}
// Delete imported key
n, err := db.GetEngine(ctx).Where("key_id=?", keyID).Delete(new(GPGKeyImport))
if err != nil {
return n, err
}
return db.GetEngine(ctx).Where("key_id=?", keyID).Or("primary_key_id=?", keyID).Delete(new(GPGKey))
}
// DeleteGPGKey deletes GPG key information in database.
func DeleteGPGKey(ctx context.Context, doer *user_model.User, id int64) (err error) {
key, err := GetGPGKeyForUserByID(ctx, doer.ID, id)
if err != nil {
if IsErrGPGKeyNotExist(err) {
return nil
}
return fmt.Errorf("GetPublicKeyByID: %w", err)
}
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
defer committer.Close()
if _, err = deleteGPGKey(ctx, key.KeyID); err != nil {
return err
}
return committer.Commit()
}
func checkKeyEmails(ctx context.Context, email string, keys ...*GPGKey) (bool, string) {
uid := int64(0)
var userEmails []*user_model.EmailAddress
var user *user_model.User
for _, key := range keys {
for _, e := range key.Emails {
if e.IsActivated && (email == "" || strings.EqualFold(e.Email, email)) {
return true, e.Email
}
}
if key.Verified && key.OwnerID != 0 {
if uid != key.OwnerID {
userEmails, _ = user_model.GetEmailAddresses(ctx, key.OwnerID)
uid = key.OwnerID
user = &user_model.User{ID: uid}
_, _ = user_model.GetUser(ctx, user)
}
for _, e := range userEmails {
if e.IsActivated && (email == "" || strings.EqualFold(e.Email, email)) {
return true, e.Email
}
}
if user.KeepEmailPrivate && strings.EqualFold(email, user.GetEmail()) {
return true, user.GetEmail()
}
}
}
return false, email
}
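A usage sketch (editor's illustration, not part of the Gitea source): list a user's primary GPG keys with db.Find and the options type defined above, then attach subkeys per key. The function name is hypothetical.

func listUserGPGKeysSketch(ctx context.Context, ownerID int64) ([]*GPGKey, error) {
	keys, err := db.Find[GPGKey](ctx, FindGPGKeyOptions{OwnerID: ownerID})
	if err != nil {
		return nil, err
	}
	for _, key := range keys {
		if err := key.LoadSubKeys(ctx); err != nil {
			return nil, err
		}
	}
	return keys, nil
}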


@@ -0,0 +1,167 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"context"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"github.com/keybase/go-crypto/openpgp"
)
// __________________ ________ ____ __.
// / _____/\______ \/ _____/ | |/ _|____ ___.__.
// / \ ___ | ___/ \ ___ | <_/ __ < | |
// \ \_\ \| | \ \_\ \ | | \ ___/\___ |
// \______ /|____| \______ / |____|__ \___ > ____|
// \/ \/ \/ \/\/
// _____ .___ .___
// / _ \ __| _/__| _/
// / /_\ \ / __ |/ __ |
// / | \/ /_/ / /_/ |
// \____|__ /\____ \____ |
// \/ \/ \/
// This file contains functions relating to adding GPG Keys
// addGPGKey adds the key, its import record and its subkeys to the database
func addGPGKey(ctx context.Context, key *GPGKey, content string) (err error) {
// Add GPGKeyImport
if err = db.Insert(ctx, &GPGKeyImport{
KeyID: key.KeyID,
Content: content,
}); err != nil {
return err
}
// Save GPG primary key.
if err = db.Insert(ctx, key); err != nil {
return err
}
// Save GPG subkeys.
for _, subkey := range key.SubsKey {
if err := addGPGSubKey(ctx, subkey); err != nil {
return err
}
}
return nil
}
// addGPGSubKey adds a subkey (and any nested subkeys) to the database
func addGPGSubKey(ctx context.Context, key *GPGKey) (err error) {
// Save the GPG subkey.
if err = db.Insert(ctx, key); err != nil {
return err
}
// Save nested subkeys, if any.
for _, subkey := range key.SubsKey {
if err := addGPGSubKey(ctx, subkey); err != nil {
return err
}
}
return nil
}
// AddGPGKey adds a new public key to the database.
func AddGPGKey(ctx context.Context, ownerID int64, content, token, signature string) ([]*GPGKey, error) {
ekeys, err := checkArmoredGPGKeyString(content)
if err != nil {
return nil, err
}
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
defer committer.Close()
keys := make([]*GPGKey, 0, len(ekeys))
verified := false
// Handle provided signature
if signature != "" {
signer, err := openpgp.CheckArmoredDetachedSignature(ekeys, strings.NewReader(token), strings.NewReader(signature))
if err != nil {
signer, err = openpgp.CheckArmoredDetachedSignature(ekeys, strings.NewReader(token+"\n"), strings.NewReader(signature))
}
if err != nil {
signer, err = openpgp.CheckArmoredDetachedSignature(ekeys, strings.NewReader(token+"\r\n"), strings.NewReader(signature))
}
if err != nil {
log.Error("Unable to validate token signature. Error: %v", err)
return nil, ErrGPGInvalidTokenSignature{
ID: ekeys[0].PrimaryKey.KeyIdString(),
Wrapped: err,
}
}
ekeys = []*openpgp.Entity{signer}
verified = true
}
if len(ekeys) > 1 {
id2key := map[string]*openpgp.Entity{}
newEKeys := make([]*openpgp.Entity, 0, len(ekeys))
for _, ekey := range ekeys {
id := ekey.PrimaryKey.KeyIdString()
if original, has := id2key[id]; has {
// Coalesce this with the other one
for _, subkey := range ekey.Subkeys {
if subkey.PublicKey == nil {
continue
}
found := false
for _, originalSubkey := range original.Subkeys {
if originalSubkey.PublicKey == nil {
continue
}
if originalSubkey.PublicKey.KeyId == subkey.PublicKey.KeyId {
found = true
break
}
}
if !found {
original.Subkeys = append(original.Subkeys, subkey)
}
}
for name, identity := range ekey.Identities {
if _, has := original.Identities[name]; has {
continue
}
original.Identities[name] = identity
}
continue
}
id2key[id] = ekey
newEKeys = append(newEKeys, ekey)
}
ekeys = newEKeys
}
for _, ekey := range ekeys {
// Key ID cannot be duplicated.
has, err := db.GetEngine(ctx).Where("key_id=?", ekey.PrimaryKey.KeyIdString()).
Get(new(GPGKey))
if err != nil {
return nil, err
} else if has {
return nil, ErrGPGKeyIDAlreadyUsed{ekey.PrimaryKey.KeyIdString()}
}
// Parse the entity into a GPGKey model
key, err := parseGPGKey(ctx, ownerID, ekey, verified)
if err != nil {
return nil, err
}
if err = addGPGKey(ctx, key, content); err != nil {
return nil, err
}
keys = append(keys, key)
}
return keys, committer.Commit()
}
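A short usage sketch (editor's illustration, not part of the Gitea source): add an armored public key without a verification token and signature, so the key is stored unverified, as the flow above allows. The function name is hypothetical; the log call reuses the package's existing log import.

func addUnverifiedKeySketch(ctx context.Context, ownerID int64, armoredKey string) error {
	keys, err := AddGPGKey(ctx, ownerID, armoredKey, "", "")
	if err != nil {
		return err
	}
	log.Info("added %d GPG key(s) for user %d", len(keys), ownerID)
	return nil
}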


@@ -0,0 +1,536 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"context"
"fmt"
"hash"
"strings"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"github.com/keybase/go-crypto/openpgp/packet"
)
// __________________ ________ ____ __.
// / _____/\______ \/ _____/ | |/ _|____ ___.__.
// / \ ___ | ___/ \ ___ | <_/ __ < | |
// \ \_\ \| | \ \_\ \ | | \ ___/\___ |
// \______ /|____| \______ / |____|__ \___ > ____|
// \/ \/ \/ \/\/
// _________ .__ __
// \_ ___ \ ____ _____ _____ |__|/ |_
// / \ \/ / _ \ / \ / \| \ __\
// \ \___( <_> ) Y Y \ Y Y \ || |
// \______ /\____/|__|_| /__|_| /__||__|
// \/ \/ \/
// ____ ____ .__ _____.__ __ .__
// \ \ / /___________|__|/ ____\__| ____ _____ _/ |_|__| ____ ____
// \ Y // __ \_ __ \ \ __\| |/ ___\\__ \\ __\ |/ _ \ / \
// \ /\ ___/| | \/ || | | \ \___ / __ \| | | ( <_> ) | \
// \___/ \___ >__| |__||__| |__|\___ >____ /__| |__|\____/|___| /
// \/ \/ \/ \/
// This file provides functions relating to commit verification
// CommitVerification represents the signature verification state of a commit
type CommitVerification struct {
Verified bool
Warning bool
Reason string
SigningUser *user_model.User
CommittingUser *user_model.User
SigningEmail string
SigningKey *GPGKey
SigningSSHKey *PublicKey
TrustStatus string
}
// SignCommit represents a commit with validation of signature.
type SignCommit struct {
Verification *CommitVerification
*user_model.UserCommit
}
const (
// BadSignature is used as the reason when the signature has a KeyID that is in the db
// but no key that has that ID verifies the signature. This is a suspicious failure.
BadSignature = "gpg.error.probable_bad_signature"
// BadDefaultSignature is used as the reason when the signature has a KeyID that matches the
// default Key but is not verified by the default key. This is a suspicious failure.
BadDefaultSignature = "gpg.error.probable_bad_default_signature"
// NoKeyFound is used as the reason when no key can be found to verify the signature.
NoKeyFound = "gpg.error.no_gpg_keys_found"
)
// ParseCommitsWithSignature checks whether the signatures of the given commits correspond to the users' GPG keys.
func ParseCommitsWithSignature(ctx context.Context, oldCommits []*user_model.UserCommit, repoTrustModel repo_model.TrustModelType, isOwnerMemberCollaborator func(*user_model.User) (bool, error)) []*SignCommit {
newCommits := make([]*SignCommit, 0, len(oldCommits))
keyMap := map[string]bool{}
for _, c := range oldCommits {
signCommit := &SignCommit{
UserCommit: c,
Verification: ParseCommitWithSignature(ctx, c.Commit),
}
_ = CalculateTrustStatus(signCommit.Verification, repoTrustModel, isOwnerMemberCollaborator, &keyMap)
newCommits = append(newCommits, signCommit)
}
return newCommits
}
// ParseCommitWithSignature checks whether the commit signature is valid against the keystore.
func ParseCommitWithSignature(ctx context.Context, c *git.Commit) *CommitVerification {
var committer *user_model.User
if c.Committer != nil {
var err error
// Find Committer account
committer, err = user_model.GetUserByEmail(ctx, c.Committer.Email) // This finds the user by primary or activated email, so the commit will not be valid if the email is not activated
if err != nil { // No user account found for the committer; fall back to a placeholder user
committer = &user_model.User{
Name: c.Committer.Name,
Email: c.Committer.Email,
}
// We can expect this to often be an ErrUserNotExist. in the case
// it is not, however, it is important to log it.
if !user_model.IsErrUserNotExist(err) {
log.Error("GetUserByEmail: %v", err)
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Reason: "gpg.error.no_committer_account",
}
}
}
}
// If no signature just report the committer
if c.Signature == nil {
return &CommitVerification{
CommittingUser: committer,
Verified: false, // Default value
Reason: "gpg.error.not_signed_commit", // Default value
}
}
// If this is an SSH signature, handle it differently
if strings.HasPrefix(c.Signature.Signature, "-----BEGIN SSH SIGNATURE-----") {
return ParseCommitWithSSHSignature(ctx, c, committer)
}
// Parsing signature
sig, err := extractSignature(c.Signature.Signature)
if err != nil { // Skip: the signature could not be extracted
log.Error("SignatureRead err: %v", err)
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Reason: "gpg.error.extract_sign",
}
}
keyID := tryGetKeyIDFromSignature(sig)
defaultReason := NoKeyFound
// First check if the sig has a keyID and if so just look at that
if commitVerification := hashAndVerifyForKeyID(
ctx,
sig,
c.Signature.Payload,
committer,
keyID,
setting.AppName,
""); commitVerification != nil {
if commitVerification.Reason == BadSignature {
defaultReason = BadSignature
} else {
return commitVerification
}
}
// Now try to associate the signature with the committer, if present
if committer.ID != 0 {
keys, err := db.Find[GPGKey](ctx, FindGPGKeyOptions{
OwnerID: committer.ID,
})
if err != nil { // Skip: the user's GPG keys could not be retrieved
log.Error("ListGPGKeys: %v", err)
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Reason: "gpg.error.failed_retrieval_gpg_keys",
}
}
if err := GPGKeyList(keys).LoadSubKeys(ctx); err != nil {
log.Error("LoadSubKeys: %v", err)
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Reason: "gpg.error.failed_retrieval_gpg_keys",
}
}
committerEmailAddresses, _ := user_model.GetEmailAddresses(ctx, committer.ID)
activated := false
for _, e := range committerEmailAddresses {
if e.IsActivated && strings.EqualFold(e.Email, c.Committer.Email) {
activated = true
break
}
}
for _, k := range keys {
// Pre-check (& optimization) that emails attached to key can be attached to the committer email and can validate
canValidate := false
email := ""
if k.Verified && activated {
canValidate = true
email = c.Committer.Email
}
if !canValidate {
for _, e := range k.Emails {
if e.IsActivated && strings.EqualFold(e.Email, c.Committer.Email) {
canValidate = true
email = e.Email
break
}
}
}
if !canValidate {
continue // Skip this key
}
commitVerification := hashAndVerifyWithSubKeysCommitVerification(sig, c.Signature.Payload, k, committer, committer, email)
if commitVerification != nil {
return commitVerification
}
}
}
if setting.Repository.Signing.SigningKey != "" && setting.Repository.Signing.SigningKey != "default" && setting.Repository.Signing.SigningKey != "none" {
// OK we should try the default key
gpgSettings := git.GPGSettings{
Sign: true,
KeyID: setting.Repository.Signing.SigningKey,
Name: setting.Repository.Signing.SigningName,
Email: setting.Repository.Signing.SigningEmail,
}
if err := gpgSettings.LoadPublicKeyContent(); err != nil {
log.Error("Error getting default signing key: %s %v", gpgSettings.KeyID, err)
} else if commitVerification := verifyWithGPGSettings(ctx, &gpgSettings, sig, c.Signature.Payload, committer, keyID); commitVerification != nil {
if commitVerification.Reason == BadSignature {
defaultReason = BadSignature
} else {
return commitVerification
}
}
}
defaultGPGSettings, err := c.GetRepositoryDefaultPublicGPGKey(false)
if err != nil {
log.Error("Error getting default public gpg key: %v", err)
} else if defaultGPGSettings == nil {
log.Warn("Unable to get defaultGPGSettings for unattached commit: %s", c.ID.String())
} else if defaultGPGSettings.Sign {
if commitVerification := verifyWithGPGSettings(ctx, defaultGPGSettings, sig, c.Signature.Payload, committer, keyID); commitVerification != nil {
if commitVerification.Reason == BadSignature {
defaultReason = BadSignature
} else {
return commitVerification
}
}
}
return &CommitVerification{ // Default at this stage
CommittingUser: committer,
Verified: false,
Warning: defaultReason != NoKeyFound,
Reason: defaultReason,
SigningKey: &GPGKey{
KeyID: keyID,
},
}
}
func verifyWithGPGSettings(ctx context.Context, gpgSettings *git.GPGSettings, sig *packet.Signature, payload string, committer *user_model.User, keyID string) *CommitVerification {
// First try to find the key in the db
if commitVerification := hashAndVerifyForKeyID(ctx, sig, payload, committer, gpgSettings.KeyID, gpgSettings.Name, gpgSettings.Email); commitVerification != nil {
return commitVerification
}
// Otherwise we have to parse the key
ekeys, err := checkArmoredGPGKeyString(gpgSettings.PublicKeyContent)
if err != nil {
log.Error("Unable to get default signing key: %v", err)
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Reason: "gpg.error.generate_hash",
}
}
for _, ekey := range ekeys {
pubkey := ekey.PrimaryKey
content, err := base64EncPubKey(pubkey)
if err != nil {
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Reason: "gpg.error.generate_hash",
}
}
k := &GPGKey{
Content: content,
CanSign: pubkey.CanSign(),
KeyID: pubkey.KeyIdString(),
}
for _, subKey := range ekey.Subkeys {
content, err := base64EncPubKey(subKey.PublicKey)
if err != nil {
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Reason: "gpg.error.generate_hash",
}
}
k.SubsKey = append(k.SubsKey, &GPGKey{
Content: content,
CanSign: subKey.PublicKey.CanSign(),
KeyID: subKey.PublicKey.KeyIdString(),
})
}
if commitVerification := hashAndVerifyWithSubKeysCommitVerification(sig, payload, k, committer, &user_model.User{
Name: gpgSettings.Name,
Email: gpgSettings.Email,
}, gpgSettings.Email); commitVerification != nil {
return commitVerification
}
if keyID == k.KeyID {
// This is a bad situation ... We have a key id that matches our default key but the signature doesn't match.
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Warning: true,
Reason: BadSignature,
}
}
}
return nil
}
func verifySign(s *packet.Signature, h hash.Hash, k *GPGKey) error {
// Check if key can sign
if !k.CanSign {
return fmt.Errorf("key can not sign")
}
// Decode key
pkey, err := base64DecPubKey(k.Content)
if err != nil {
return err
}
return pkey.VerifySignature(h, s)
}
func hashAndVerify(sig *packet.Signature, payload string, k *GPGKey) (*GPGKey, error) {
// Generating hash of commit
hash, err := populateHash(sig.Hash, []byte(payload))
if err != nil { // Skip: failed to generate the hash
log.Error("PopulateHash: %v", err)
return nil, err
}
// We will ignore errors in verification as they don't need to be propagated up
err = verifySign(sig, hash, k)
if err != nil {
return nil, nil
}
return k, nil
}
func hashAndVerifyWithSubKeys(sig *packet.Signature, payload string, k *GPGKey) (*GPGKey, error) {
verified, err := hashAndVerify(sig, payload, k)
if err != nil || verified != nil {
return verified, err
}
for _, sk := range k.SubsKey {
verified, err := hashAndVerify(sig, payload, sk)
if err != nil || verified != nil {
return verified, err
}
}
return nil, nil
}
func hashAndVerifyWithSubKeysCommitVerification(sig *packet.Signature, payload string, k *GPGKey, committer, signer *user_model.User, email string) *CommitVerification {
key, err := hashAndVerifyWithSubKeys(sig, payload, k)
if err != nil { // Skip: failed to generate the hash
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Reason: "gpg.error.generate_hash",
}
}
if key != nil {
return &CommitVerification{ // Everything is ok
CommittingUser: committer,
Verified: true,
Reason: fmt.Sprintf("%s / %s", signer.Name, key.KeyID),
SigningUser: signer,
SigningKey: key,
SigningEmail: email,
}
}
return nil
}
func hashAndVerifyForKeyID(ctx context.Context, sig *packet.Signature, payload string, committer *user_model.User, keyID, name, email string) *CommitVerification {
if keyID == "" {
return nil
}
keys, err := db.Find[GPGKey](ctx, FindGPGKeyOptions{
KeyID: keyID,
IncludeSubKeys: true,
})
if err != nil {
log.Error("GetGPGKeysByKeyID: %v", err)
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Reason: "gpg.error.failed_retrieval_gpg_keys",
}
}
if len(keys) == 0 {
return nil
}
for _, key := range keys {
var primaryKeys []*GPGKey
if key.PrimaryKeyID != "" {
primaryKeys, err = db.Find[GPGKey](ctx, FindGPGKeyOptions{
KeyID: key.PrimaryKeyID,
IncludeSubKeys: true,
})
if err != nil {
log.Error("GetGPGKeysByKeyID: %v", err)
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Reason: "gpg.error.failed_retrieval_gpg_keys",
}
}
}
activated, email := checkKeyEmails(ctx, email, append([]*GPGKey{key}, primaryKeys...)...)
if !activated {
continue
}
signer := &user_model.User{
Name: name,
Email: email,
}
if key.OwnerID != 0 {
owner, err := user_model.GetUserByID(ctx, key.OwnerID)
if err == nil {
signer = owner
} else if !user_model.IsErrUserNotExist(err) {
log.Error("Failed to user_model.GetUserByID: %d for key ID: %d (%s) %v", key.OwnerID, key.ID, key.KeyID, err)
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Reason: "gpg.error.no_committer_account",
}
}
}
commitVerification := hashAndVerifyWithSubKeysCommitVerification(sig, payload, key, committer, signer, email)
if commitVerification != nil {
return commitVerification
}
}
// This is a bad situation ... We have a key id that is in our database but the signature doesn't match.
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Warning: true,
Reason: BadSignature,
}
}
// CalculateTrustStatus will calculate the TrustStatus for a commit verification within a repository
// There are several trust models in Gitea
func CalculateTrustStatus(verification *CommitVerification, repoTrustModel repo_model.TrustModelType, isOwnerMemberCollaborator func(*user_model.User) (bool, error), keyMap *map[string]bool) error {
if !verification.Verified {
return nil
}
// In the Committer trust model a signature is trusted if it matches the committer
// - it doesn't matter if they're a collaborator, the owner, Gitea or Github
// NB: This model is commit verification only
if repoTrustModel == repo_model.CommitterTrustModel {
// default to "unmatched"
verification.TrustStatus = "unmatched"
// We can only verify against users in our database, but the default key will be matched
// by email if it is not in the db.
if (verification.SigningUser.ID != 0 &&
verification.CommittingUser.ID == verification.SigningUser.ID) ||
(verification.SigningUser.ID == 0 && verification.CommittingUser.ID == 0 &&
verification.SigningUser.Email == verification.CommittingUser.Email) {
verification.TrustStatus = "trusted"
}
return nil
}
// Now we drop to the more nuanced trust models...
verification.TrustStatus = "trusted"
if verification.SigningUser.ID == 0 {
// This commit is signed by the default key - but this key is not assigned to a user in the DB.
// However in the repo_model.CollaboratorCommitterTrustModel we cannot mark this as trusted
// unless the default key matches the email of a non-user.
if repoTrustModel == repo_model.CollaboratorCommitterTrustModel && (verification.CommittingUser.ID != 0 ||
verification.SigningUser.Email != verification.CommittingUser.Email) {
verification.TrustStatus = "untrusted"
}
return nil
}
// Check we actually have a GPG SigningKey
var err error
if verification.SigningKey != nil {
var isMember bool
if keyMap != nil {
var has bool
isMember, has = (*keyMap)[verification.SigningKey.KeyID]
if !has {
isMember, err = isOwnerMemberCollaborator(verification.SigningUser)
(*keyMap)[verification.SigningKey.KeyID] = isMember
}
} else {
isMember, err = isOwnerMemberCollaborator(verification.SigningUser)
}
if !isMember {
verification.TrustStatus = "untrusted"
if verification.CommittingUser.ID != verification.SigningUser.ID {
// The committing user and the signing user are not the same
// This should be marked as questionable unless the signing user is a collaborator/team member etc.
verification.TrustStatus = "unmatched"
}
} else if repoTrustModel == repo_model.CollaboratorCommitterTrustModel && verification.CommittingUser.ID != verification.SigningUser.ID {
// The committing user and the signing user are not the same and our trustmodel states that they must match
verification.TrustStatus = "unmatched"
}
}
return err
}
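For context (editor's sketch, not from the Gitea source): when verifying many commits, a shared keyMap lets CalculateTrustStatus cache the collaborator lookup per signing key, exactly as ParseCommitsWithSignature does above. The function name and the isCollaborator callback are hypothetical.

func trustStatusBatchSketch(verifications []*CommitVerification, model repo_model.TrustModelType,
	isCollaborator func(*user_model.User) (bool, error)) error {
	keyMap := map[string]bool{}
	for _, v := range verifications {
		if err := CalculateTrustStatus(v, model, isCollaborator, &keyMap); err != nil {
			return err
		}
	}
	return nil
}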


@@ -0,0 +1,146 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"bytes"
"crypto"
"encoding/base64"
"fmt"
"hash"
"io"
"strings"
"time"
"github.com/keybase/go-crypto/openpgp"
"github.com/keybase/go-crypto/openpgp/armor"
"github.com/keybase/go-crypto/openpgp/packet"
)
// __________________ ________ ____ __.
// / _____/\______ \/ _____/ | |/ _|____ ___.__.
// / \ ___ | ___/ \ ___ | <_/ __ < | |
// \ \_\ \| | \ \_\ \ | | \ ___/\___ |
// \______ /|____| \______ / |____|__ \___ > ____|
// \/ \/ \/ \/\/
// _________
// \_ ___ \ ____ _____ _____ ____ ____
// / \ \/ / _ \ / \ / \ / _ \ / \
// \ \___( <_> ) Y Y \ Y Y ( <_> ) | \
// \______ /\____/|__|_| /__|_| /\____/|___| /
// \/ \/ \/ \/
// This file provides common functions relating to GPG Keys
// checkArmoredGPGKeyString checks if the given key string is a valid GPG armored key.
// The function returns the actual public key on success
func checkArmoredGPGKeyString(content string) (openpgp.EntityList, error) {
list, err := openpgp.ReadArmoredKeyRing(strings.NewReader(content))
if err != nil {
return nil, ErrGPGKeyParsing{err}
}
return list, nil
}
// base64EncPubKey encodes public key content to base64
func base64EncPubKey(pubkey *packet.PublicKey) (string, error) {
var w bytes.Buffer
err := pubkey.Serialize(&w)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(w.Bytes()), nil
}
func readerFromBase64(s string) (io.Reader, error) {
bs, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return nil, err
}
return bytes.NewBuffer(bs), nil
}
// base64DecPubKey decodes public key content from base64
func base64DecPubKey(content string) (*packet.PublicKey, error) {
b, err := readerFromBase64(content)
if err != nil {
return nil, err
}
// Read key
p, err := packet.Read(b)
if err != nil {
return nil, err
}
// Check type
pkey, ok := p.(*packet.PublicKey)
if !ok {
return nil, fmt.Errorf("key is not a public key")
}
return pkey, nil
}
// getExpiryTime extracts the expiry time of the primary key based on its self-signature
func getExpiryTime(e *openpgp.Entity) time.Time {
expiry := time.Time{}
// Extract the self-signature for the expiry date, based on: https://github.com/golang/crypto/blob/master/openpgp/keys.go#L165
var selfSig *packet.Signature
for _, ident := range e.Identities {
if selfSig == nil {
selfSig = ident.SelfSignature
} else if ident.SelfSignature.IsPrimaryId != nil && *ident.SelfSignature.IsPrimaryId {
selfSig = ident.SelfSignature
break
}
}
if selfSig.KeyLifetimeSecs != nil {
expiry = e.PrimaryKey.CreationTime.Add(time.Duration(*selfSig.KeyLifetimeSecs) * time.Second)
}
return expiry
}
func populateHash(hashFunc crypto.Hash, msg []byte) (hash.Hash, error) {
h := hashFunc.New()
if _, err := h.Write(msg); err != nil {
return nil, err
}
return h, nil
}
// readArmoredSign reads an armored signature block of the expected type. https://sourcegraph.com/github.com/golang/crypto/-/blob/openpgp/read.go#L24:6-24:17
func readArmoredSign(r io.Reader) (body io.Reader, err error) {
block, err := armor.Decode(r)
if err != nil {
return nil, err
}
if block.Type != openpgp.SignatureType {
return nil, fmt.Errorf("expected '%s', got: %s", openpgp.SignatureType, block.Type)
}
return block.Body, nil
}
func extractSignature(s string) (*packet.Signature, error) {
r, err := readArmoredSign(strings.NewReader(s))
if err != nil {
return nil, fmt.Errorf("Failed to read signature armor")
}
p, err := packet.Read(r)
if err != nil {
return nil, fmt.Errorf("Failed to read signature packet")
}
sig, ok := p.(*packet.Signature)
if !ok {
return nil, fmt.Errorf("Packet is not a signature")
}
return sig, nil
}
func tryGetKeyIDFromSignature(sig *packet.Signature) string {
if sig.IssuerKeyId != nil && (*sig.IssuerKeyId) != 0 {
return fmt.Sprintf("%016X", *sig.IssuerKeyId)
}
if len(sig.IssuerFingerprint) > 0 {
return fmt.Sprintf("%016X", sig.IssuerFingerprint[12:20])
}
return ""
}
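To tie the helpers together (editor's sketch, not from the Gitea source): extractSignature, populateHash and verifySign compose into a small "does this armored signature over this payload verify against this key" check, mirroring what the commit verification code does. The function name is hypothetical; verifySign is the package helper defined earlier.

func verifyArmoredSignatureSketch(armoredSig, payload string, k *GPGKey) error {
	sig, err := extractSignature(armoredSig)
	if err != nil {
		return err
	}
	h, err := populateHash(sig.Hash, []byte(payload))
	if err != nil {
		return err
	}
	return verifySign(sig, h, k)
}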


@@ -0,0 +1,47 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"context"
"code.gitea.io/gitea/models/db"
)
// __________________ ________ ____ __.
// / _____/\______ \/ _____/ | |/ _|____ ___.__.
// / \ ___ | ___/ \ ___ | <_/ __ < | |
// \ \_\ \| | \ \_\ \ | | \ ___/\___ |
// \______ /|____| \______ / |____|__ \___ > ____|
// \/ \/ \/ \/\/
// .___ __
// | | _____ ______ ____________/ |_
// | |/ \\____ \ / _ \_ __ \ __\
// | | Y Y \ |_> > <_> ) | \/| |
// |___|__|_| / __/ \____/|__| |__|
// \/|__|
// This file contains functions related to the original import of a key
// GPGKeyImport stores the original import of a key
type GPGKeyImport struct {
KeyID string `xorm:"pk CHAR(16) NOT NULL"`
Content string `xorm:"MEDIUMTEXT NOT NULL"`
}
func init() {
db.RegisterModel(new(GPGKeyImport))
}
// GetGPGImportByKeyID returns the imported public armored key for the given KeyID.
func GetGPGImportByKeyID(ctx context.Context, keyID string) (*GPGKeyImport, error) {
key := new(GPGKeyImport)
has, err := db.GetEngine(ctx).ID(keyID).Get(key)
if err != nil {
return nil, err
} else if !has {
return nil, ErrGPGKeyImportNotExist{keyID}
}
return key, nil
}
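A brief usage sketch (editor's illustration, not part of the Gitea source): fetch the originally imported armored content and treat a missing import record as a non-fatal condition. The function name is hypothetical.

func importedKeyContentSketch(ctx context.Context, keyID string) (string, bool, error) {
	imp, err := GetGPGImportByKeyID(ctx, keyID)
	if err != nil {
		if IsErrGPGKeyImportNotExist(err) {
			return "", false, nil
		}
		return "", false, err
	}
	return imp.Content, true, nil
}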


@@ -0,0 +1,38 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"context"
"code.gitea.io/gitea/models/db"
)
type GPGKeyList []*GPGKey
func (keys GPGKeyList) keyIDs() []string {
ids := make([]string, len(keys))
for i, key := range keys {
ids[i] = key.KeyID
}
return ids
}
func (keys GPGKeyList) LoadSubKeys(ctx context.Context) error {
subKeys := make([]*GPGKey, 0, len(keys))
if err := db.GetEngine(ctx).In("primary_key_id", keys.keyIDs()).Find(&subKeys); err != nil {
return err
}
subKeysMap := make(map[string][]*GPGKey, len(subKeys))
for _, key := range subKeys {
subKeysMap[key.PrimaryKeyID] = append(subKeysMap[key.PrimaryKeyID], key)
}
for _, key := range keys {
if subKeys, ok := subKeysMap[key.KeyID]; ok {
key.SubsKey = subKeys
}
}
return nil
}
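One last sketch (editor's illustration, not part of the Gitea source): batch-load subkeys for a whole result set in a single query via GPGKeyList, instead of calling LoadSubKeys on every key. The function name is hypothetical.

func listKeysWithSubkeysSketch(ctx context.Context, ownerID int64) (GPGKeyList, error) {
	keys, err := db.Find[GPGKey](ctx, FindGPGKeyOptions{OwnerID: ownerID})
	if err != nil {
		return nil, err
	}
	list := GPGKeyList(keys)
	if err := list.LoadSubKeys(ctx); err != nil {
		return nil, err
	}
	return list, nil
}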


@@ -0,0 +1,405 @@
// Copyright 2017 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"testing"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
"github.com/keybase/go-crypto/openpgp/packet"
"github.com/stretchr/testify/assert"
)
func TestCheckArmoredGPGKeyString(t *testing.T) {
testGPGArmor := `-----BEGIN PGP PUBLIC KEY BLOCK-----
mQENBFh91QoBCADciaDd7aqegYkn4ZIG7J0p1CRwpqMGjxFroJEMg6M1ZiuEVTRv
z49P4kcr1+98NvFmcNc+x5uJgvPCwr/N8ZW5nqBUs2yrklbFF4MeQomyZJJegP8m
/dsRT3BwIT8YMUtJuCj0iqD9vuKYfjrztcMgC1sYwcE9E9OlA0pWBvUdU2i0TIB1
vOq6slWGvHHa5l5gPfm09idlVxfH5+I+L1uIMx5ovbiVVU5x2f1AR1T18f0t2TVN
0agFTyuoYE1ATmvJHmMcsfgM1Gpd9hIlr9vlupT2kKTPoNzVzsJsOU6Ku/Lf/bac
mF+TfSbRCtmG7dkYZ4metLj7zG/WkW8IvJARABEBAAG0HUFudG9pbmUgR0lSQVJE
IDxzYXBrQHNhcGsuZnI+iQFUBBMBCAA+FiEEEIOwJg/1vpF1itJ4roJVuKDYKOQF
Alh91QoCGwMFCQPCZwAFCwkIBwIGFQgJCgsCBBYCAwECHgECF4AACgkQroJVuKDY
KORreggAlIkC2QjHP5tb7b0+LksB2JMXdY+UzZBcJxtNmvA7gNQaGvWRrhrbePpa
MKDP+3A4BPDBsWFbbB7N56vQ5tROpmWbNKuFOVER4S1bj0JZV0E+xkDLqt9QwQtQ
ojd7oIZJwDUwdud1PvCza2mjgBqqiFE+twbc3i9xjciCGspMniUul1eQYLxRJ0w+
sbvSOUnujnq5ByMSz9ij00O6aiPfNQS5oB5AALfpjYZDvWAAljLVrtmlQJWZ6dZo
T/YNwsW2dECPuti8+Nmu5FxPGDTXxdbnRaeJTQ3T6q1oUVAv7yTXBx5NXfXkMa5i
iEayQIH8Joq5Ev5ja/lRGQQhArMQ2bkBDQRYfdUKAQgAv7B3coLSrOQbuTZSlgWE
QeT+7DWbmqE1LAQA1pQPcUPXLBUVd60amZJxF9nzUYcY83ylDi0gUNJS+DJGOXpT
pzX2IOuOMGbtUSeKwg5s9O4SUO7f2yCc3RGaegER5zgESxelmOXG+b/hoNt7JbdU
JtxcnLr91Jw2PBO/Xf0ZKJ01CQG2Yzdrrj6jnrHyx94seHy0i6xH1o0OuvfVMLfN
/Vbb/ZHh6ym2wHNqRX62b0VAbchcJXX/MEehXGknKTkO6dDUd+mhRgWMf9ZGRFWx
ag4qALimkf1FXtAyD0vxFYeyoWUQzrOvUsm2BxIN/986R08fhkBQnp5nz07mrU02
cQARAQABiQE8BBgBCAAmFiEEEIOwJg/1vpF1itJ4roJVuKDYKOQFAlh91QoCGwwF
CQPCZwAACgkQroJVuKDYKOT32wf/UZqMdPn5OhyhffFzjQx7wolrf92WkF2JkxtH
6c3Htjlt/p5RhtKEeErSrNAxB4pqB7dznHaJXiOdWEZtRVXXjlNHjrokGTesqtKk
lHWtK62/MuyLdr+FdCl68F3ewuT2iu/MDv+D4HPqA47zma9xVgZ9ZNwJOpv3fCOo
RfY66UjGEnfgYifgtI5S84/mp2jaSc9UNvlZB6RSf8cfbJUL74kS2lq+xzSlf0yP
Av844q/BfRuVsJsK1NDNG09LC30B0l3LKBqlrRmRTUMHtgchdX2dY+p7GPOoSzlR
MkM/fdpyc2hY7Dl/+qFmN5MG5yGmMpQcX+RNNR222ibNC1D3wg==
=i9b7
-----END PGP PUBLIC KEY BLOCK-----`
key, err := checkArmoredGPGKeyString(testGPGArmor)
assert.NoError(t, err, "Could not parse a valid GPG public armored rsa key", key)
// TODO verify value of key
}
func TestCheckArmoredbrainpoolP256r1GPGKeyString(t *testing.T) {
testGPGArmor := `-----BEGIN PGP PUBLIC KEY BLOCK-----
Version: GnuPG v2
mFMEV6HwkhMJKyQDAwIIAQEHAgMEUsvJO/j5dFMRRj67qeZC9fSKBsGZdOHRj2+6
8wssmbUuLTfT/ZjIbExETyY8hFnURRGpD2Ifyz0cKjXcbXfJtrQTRm9vYmFyIDxm
b29AYmFyLmRlPoh/BBMTCAAnBQJZOsDIAhsDBQkJZgGABQsJCAcCBhUICQoLAgQW
AgMBAh4BAheAAAoJEGuJTd/DBMzmNVQA/2beUrv1yU4gyvCiPDEm3pK42cSfaL5D
muCtPCUg9hlWAP4yq6M78NW8STfsXgn6oeziMYiHSTmV14nOamLuwwDWM7hXBFeh
8JISCSskAwMCCAEBBwIDBG3A+XfINAZp1CTse2mRNgeUE5DbUtEpO8ALXKA1UQsQ
DLKq27b7zTgawgXIGUGP6mWsJ5oH7MNAJ/uKTsYmX40DAQgHiGcEGBMIAA8FAleh
8JICGwwFCQlmAYAACgkQa4lN38MEzOZwKAD/QKyerAgcvzzLaqvtap3XvpYcw9tc
OyjLLnFQiVmq7kEA/0z0CQe3ZQiQIq5zrs7Nh1XRkFAo8GlU/SGC9XFFi722
=ZiSe
-----END PGP PUBLIC KEY BLOCK-----`
key, err := checkArmoredGPGKeyString(testGPGArmor)
assert.NoError(t, err, "Could not parse a valid GPG public armored brainpoolP256r1 key", key)
// TODO verify value of key
}
func TestExtractSignature(t *testing.T) {
testGPGArmor := `-----BEGIN PGP PUBLIC KEY BLOCK-----
mQENBFh91QoBCADciaDd7aqegYkn4ZIG7J0p1CRwpqMGjxFroJEMg6M1ZiuEVTRv
z49P4kcr1+98NvFmcNc+x5uJgvPCwr/N8ZW5nqBUs2yrklbFF4MeQomyZJJegP8m
/dsRT3BwIT8YMUtJuCj0iqD9vuKYfjrztcMgC1sYwcE9E9OlA0pWBvUdU2i0TIB1
vOq6slWGvHHa5l5gPfm09idlVxfH5+I+L1uIMx5ovbiVVU5x2f1AR1T18f0t2TVN
0agFTyuoYE1ATmvJHmMcsfgM1Gpd9hIlr9vlupT2kKTPoNzVzsJsOU6Ku/Lf/bac
mF+TfSbRCtmG7dkYZ4metLj7zG/WkW8IvJARABEBAAG0HUFudG9pbmUgR0lSQVJE
IDxzYXBrQHNhcGsuZnI+iQFUBBMBCAA+FiEEEIOwJg/1vpF1itJ4roJVuKDYKOQF
Alh91QoCGwMFCQPCZwAFCwkIBwIGFQgJCgsCBBYCAwECHgECF4AACgkQroJVuKDY
KORreggAlIkC2QjHP5tb7b0+LksB2JMXdY+UzZBcJxtNmvA7gNQaGvWRrhrbePpa
MKDP+3A4BPDBsWFbbB7N56vQ5tROpmWbNKuFOVER4S1bj0JZV0E+xkDLqt9QwQtQ
ojd7oIZJwDUwdud1PvCza2mjgBqqiFE+twbc3i9xjciCGspMniUul1eQYLxRJ0w+
sbvSOUnujnq5ByMSz9ij00O6aiPfNQS5oB5AALfpjYZDvWAAljLVrtmlQJWZ6dZo
T/YNwsW2dECPuti8+Nmu5FxPGDTXxdbnRaeJTQ3T6q1oUVAv7yTXBx5NXfXkMa5i
iEayQIH8Joq5Ev5ja/lRGQQhArMQ2bkBDQRYfdUKAQgAv7B3coLSrOQbuTZSlgWE
QeT+7DWbmqE1LAQA1pQPcUPXLBUVd60amZJxF9nzUYcY83ylDi0gUNJS+DJGOXpT
pzX2IOuOMGbtUSeKwg5s9O4SUO7f2yCc3RGaegER5zgESxelmOXG+b/hoNt7JbdU
JtxcnLr91Jw2PBO/Xf0ZKJ01CQG2Yzdrrj6jnrHyx94seHy0i6xH1o0OuvfVMLfN
/Vbb/ZHh6ym2wHNqRX62b0VAbchcJXX/MEehXGknKTkO6dDUd+mhRgWMf9ZGRFWx
ag4qALimkf1FXtAyD0vxFYeyoWUQzrOvUsm2BxIN/986R08fhkBQnp5nz07mrU02
cQARAQABiQE8BBgBCAAmFiEEEIOwJg/1vpF1itJ4roJVuKDYKOQFAlh91QoCGwwF
CQPCZwAACgkQroJVuKDYKOT32wf/UZqMdPn5OhyhffFzjQx7wolrf92WkF2JkxtH
6c3Htjlt/p5RhtKEeErSrNAxB4pqB7dznHaJXiOdWEZtRVXXjlNHjrokGTesqtKk
lHWtK62/MuyLdr+FdCl68F3ewuT2iu/MDv+D4HPqA47zma9xVgZ9ZNwJOpv3fCOo
RfY66UjGEnfgYifgtI5S84/mp2jaSc9UNvlZB6RSf8cfbJUL74kS2lq+xzSlf0yP
Av844q/BfRuVsJsK1NDNG09LC30B0l3LKBqlrRmRTUMHtgchdX2dY+p7GPOoSzlR
MkM/fdpyc2hY7Dl/+qFmN5MG5yGmMpQcX+RNNR222ibNC1D3wg==
=i9b7
-----END PGP PUBLIC KEY BLOCK-----`
keys, err := checkArmoredGPGKeyString(testGPGArmor)
if !assert.NotEmpty(t, keys) {
return
}
ekey := keys[0]
assert.NoError(t, err, "Could not parse a valid GPG armored key", ekey)
pubkey := ekey.PrimaryKey
content, err := base64EncPubKey(pubkey)
assert.NoError(t, err, "Could not base64 encode a valid PublicKey content", ekey)
key := &GPGKey{
KeyID: pubkey.KeyIdString(),
Content: content,
CreatedUnix: timeutil.TimeStamp(pubkey.CreationTime.Unix()),
CanSign: pubkey.CanSign(),
CanEncryptComms: pubkey.PubKeyAlgo.CanEncrypt(),
CanEncryptStorage: pubkey.PubKeyAlgo.CanEncrypt(),
CanCertify: pubkey.PubKeyAlgo.CanSign(),
}
cannotsignkey := &GPGKey{
KeyID: pubkey.KeyIdString(),
Content: content,
CreatedUnix: timeutil.TimeStamp(pubkey.CreationTime.Unix()),
CanSign: false,
CanEncryptComms: false,
CanEncryptStorage: false,
CanCertify: false,
}
testGoodSigArmor := `-----BEGIN PGP SIGNATURE-----
iQEzBAABCAAdFiEEEIOwJg/1vpF1itJ4roJVuKDYKOQFAljAiQIACgkQroJVuKDY
KORvCgf6A/Ehh0r7QbO2tFEghT+/Ab+bN7jRN3zP9ed6/q/ophYmkrU0NibtbJH9
AwFVdHxCmj78SdiRjaTKyevklXw34nvMftmvnOI4lBNUdw6KWl25/n/7wN0l2oZW
rW3UawYpZgodXiLTYarfEimkDQmT67ArScjRA6lLbkEYKO0VdwDu+Z6yBUH3GWtm
45RkXpnsF6AXUfuD7YxnfyyDE1A7g7zj4vVYUAfWukJjqow/LsCUgETETJOqj9q3
52/oQDs04fVkIEtCDulcY+K/fKlukBPJf9WceNDEqiENUzN/Z1y0E+tJ07cSy4bk
yIJb+d0OAaG8bxloO7nJq4Res1Qa8Q==
=puvG
-----END PGP SIGNATURE-----`
testGoodPayload := `tree 56ae8d2799882b20381fc11659db06c16c68c61a
parent c7870c39e4e6b247235ca005797703ec4254613f
author Antoine GIRARD <sapk@sapk.fr> 1489012989 +0100
committer Antoine GIRARD <sapk@sapk.fr> 1489012989 +0100
Goog GPG
`
testBadSigArmor := `-----BEGIN PGP SIGNATURE-----
iQEzBAABCAAdFiEE5yr4rn9ulbdMxJFiPYI/ySNrtNkFAljAiYkACgkQPYI/ySNr
tNmDdQf+NXhVRiOGt0GucpjJCGrOnK/qqVUmQyRUfrqzVUdb/1/Ws84V5/wE547I
6z3oxeBKFsJa1CtIlxYaUyVhYnDzQtphJzub+Aw3UG0E2ywiE+N7RCa1Ufl7pPxJ
U0SD6gvNaeTDQV/Wctu8v8DkCtEd3N8cMCDWhvy/FQEDztVtzm8hMe0Vdm0ozEH6
P0W93sDNkLC5/qpWDN44sFlYDstW5VhMrnF0r/ohfaK2kpYHhkPk7WtOoHSUwQSg
c4gfhjvXIQrWFnII1Kr5jFGlmgNSR02qpb31VGkMzSnBhWVf2OaHS/kI49QHJakq
AhVDEnoYLCgoDGg9c3p1Ll2452/c6Q==
=uoGV
-----END PGP SIGNATURE-----`
testBadPayload := `tree 3074ff04951956a974e8b02d57733b0766f7cf6c
parent fd3577542f7ad1554c7c7c0eb86bb57a1324ad91
author Antoine GIRARD <sapk@sapk.fr> 1489013107 +0100
committer Antoine GIRARD <sapk@sapk.fr> 1489013107 +0100
Unknown GPG key with good email
`
// Reading signatures
goodSig, err := extractSignature(testGoodSigArmor)
assert.NoError(t, err, "Could not parse a valid GPG armored signature", testGoodSigArmor)
badSig, err := extractSignature(testBadSigArmor)
assert.NoError(t, err, "Could not parse a valid GPG armored signature", testBadSigArmor)
// Generating hash of commit
goodHash, err := populateHash(goodSig.Hash, []byte(testGoodPayload))
assert.NoError(t, err, "Could not generate a valid hash of payload", testGoodPayload)
badHash, err := populateHash(badSig.Hash, []byte(testBadPayload))
assert.NoError(t, err, "Could not generate a valid hash of payload", testBadPayload)
// Verify
err = verifySign(goodSig, goodHash, key)
assert.NoError(t, err, "Could not validate a good signature")
err = verifySign(badSig, badHash, key)
assert.Error(t, err, "Validate a bad signature")
err = verifySign(goodSig, goodHash, cannotsignkey)
assert.Error(t, err, "Validate a bad signature with a key that cannot sign")
}
func TestCheckGPGUserEmail(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
_ = unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1})
testEmailWithUpperCaseLetters := `-----BEGIN PGP PUBLIC KEY BLOCK-----
Version: GnuPG v1
mQENBFlEBvMBCADe+EQcfv/aKbMFy7YB8e/DE+hY39sfjvdvSgeXtNhfmYvIOUjT
ORMCvce2Oxzb3HTI0rjYsJpzo9jEQ53dB3vdr0ne5Juby6N7QPjof3NR+ko50Ki2
0ilOjYuA0v6VHLIn70UBa9NEf+XDuE7P+Lbtl2L9B9OMXtcTAZoA3cJySgtNFNIG
AVefPi8LeOcekL39wxJEA8OzdCyO5oENEwAG1tzjy9DDNJf74/dBBh2NiXeSeMxZ
RYeYzqEa2UTDP1fkUl7d2/hV36cKZWZr+l4SQ5bM7HeLj2SsfabLfqKoVWgkfAzQ
VwtkbRpzMiDLMte2ZAyTJUc+77YbFoyAmOcjABEBAAG0HFVzZXIgT25lIDxVc2Vy
MUBFeGFtcGxlLmNvbT6JATgEEwECACIFAllEBvMCGwMGCwkIBwMCBhUIAgkKCwQW
AgMBAh4BAheAAAoJEFMOzOY274DFw5EIAKc4jiYaMb1HDKrSv0tphgNxPFEY83/J
9CZggO7BINxlb7z/lH1i0U2h2Ha9E3VJTJQF80zBCaIvtU2UNrgVmSKoc0BdE/2S
rS9MAl29sXxf1BfvXHu12Suvo8O/ZFP45Vm/3kkHuasHyOV1GwUWnynt1qo0zUEn
WMIcB8USlmMT1TnSb10YKBd/BpGF3crFDJLfAHRumZUk4knDDWUOWy5RCOG8cedc
VTAhfdoKRRO3PchOfz6Rls/hew12mRNayqxuLQl2+BX+BWu+25dR3qyiS+twLbk6
Rjpb0S+RQTkYIUoI0SEZpxcTZso11xF5KNpKZ9aAoiLJqkNF5h4oPSe5AQ0EWUQG
8wEIALiMMqh3NF3ON/z7hQfeU24bCl/WdfJwCR9CWU/jx4X4gZq2C2aGtytGN5g/
qoYQ3poTOPzh/4Dvs+r6CtHqi0CvPiEOfSxzmaK+F+vA0GMn2i3Sx5gq/VB0mr+j
RIYMCjf68Tifo2RAT0VDzn6t304l5+VPr4OgbobMRH+wDe7Hhd2pZXl7ty8DooBn
vqaqoKgdiccUXGBKe4Oihl/oZ4qrYH6K4ACP1Sco1rs4mNeKDAW8k/Y7zLjg6d59
g0YQ1YI+CX/bKB7/cpMHLupyMLqvCcqIpjBXRJNMdjuMHgKckjr89DwnqXqgXz7W
u0B39MZQn9nn6vq8BdkoDFgrTQ8AEQEAAYkBHwQYAQIACQUCWUQG8wIbDAAKCRBT
DszmNu+Axf4IB/0S9NTc6kpwW+ZPZQNTWR5oKDEaXVCRLccOlkt33txMvk/z2jNM
trEke99ss5L1bRyWB5fRA+XVsPmW9kIk8pmGFmxqp2nSxr9m9rlL5oTYH8u6dfSm
zwGhqkfITjPI7hyNN52PLANwoS0o4dLzIE65ewigx6cnRlrT2IENObxG/tlxaYg1
NHahJX0uFlVk0W0bLBrs3fTDw1lS/N8HpyQb+5ryQmiIb2a48aygCS/h2qeRlX1d
Q0KHb+QcycSgbDx0ZAvdIacuKvBBcbxrsmFUI4LR+oIup0G9gUc0roPvr014jYQL
7f8r/8fpcN8t+I/41QHCs6L/BEIdTHW3rTQ6
=zHo9
-----END PGP PUBLIC KEY BLOCK-----`
keys, err := AddGPGKey(db.DefaultContext, 1, testEmailWithUpperCaseLetters, "", "")
assert.NoError(t, err)
if assert.NotEmpty(t, keys) {
key := keys[0]
if assert.Len(t, key.Emails, 1) {
assert.Equal(t, "user1@example.com", key.Emails[0].Email)
}
}
}
func TestCheckGParseGPGExpire(t *testing.T) {
testIssue6599 := `-----BEGIN PGP PUBLIC KEY BLOCK-----
mQINBFlFJRsBEAClNcRT5El+EaTtQEYs/eNAhr/bqiyt6fPMtabDq2x6a8wFWMX0
yhRh4vZuLzhi95DU/pmhZARt0W15eiN0AhWdOKxry1KtZNiZBzMm1f0qZJMuBG8g
YJ7aRkCqdWRxy1Q+U/yhr6z7ucD8/yn7u5wke/jsPdF/L8I/HKNHoawI1FcMC9v+
QoG3pIX8NVGdzaUYygFG1Gxofc3pb3i4pcpOUxpOP12t6PfwTCoAWZtRLgxTdwWn
DGvY6SCIIIxn4AC6u3+tHz9HDXx+4eiB7VxMsiIsEuHW9DVBzen9jFNNjRnNaFkL
pTAFOyGsSzGRGhuJpb7j7hByoWkaItqaw+clnzVrDqhfbxS1B8dmgMANh9pzNsv7
J/OnNdGsbgDX5RytSKMaXclK2ZGH6Txatgezo167z6EdthNR1daj1QfqWADiqKbR
UXp7Xz9b+/CBedUNEXPbIExva9mPsFJo2IEntRGtdhhjuO4a6HLG7k1i0o0dHxqb
a9HrOW7fO902L7JHIgnjpDWDGLGGnVGcGWdEEZggfpnvjxADeTgyMb2XkALTQ0GG
yRywByxG8/zjXeEkqUng/mxNbBCcHcuIRVsqYwGQLiLubYxnRudqtNst8Tdu+0+q
AL0bb8ueQC1M3WHsMUxvTjknFJdJzRicNyLf6AdfRv6yy6Ra+t4SFoSbsQARAQAB
tB90YXN0eXRlYSA8dGFzdHl0ZWFAdGFzdHl0ZWEuZGU+iQJXBBMBCABBAhsDBQsJ
CAcCBhUICQoLAgQWAgMBAh4BAheAAhkBFiEE1bTEO0ioefY1KTbmWTRuDqNcZ+UF
Alyo2K0FCQVE5xIACgkQWTRuDqNcZ+UTFA/+IygU02oz19tRVNgVmKyXv1GhnkaY
O/oGxp7cRGJ0gf0bjhbJpFf4+6OHaS0ei47Qp8XTuStfWry6V6rXLSV/ZOOhFaCq
VpFvoG2JcPZbSTB+CR/lL5fWwx3w5PAOUwipGRFs7mYLgy8U/E3U7u+ioP4ZqCXS
heclyXAGNlrjUwvwOWRLxvcEQr4ztQR0Lk2tv1QYYDzbaXUSdnsM1YK9YpYP7BE2
luKtwwXaubdwcXPs96FEmGLGfsWC/dWnAxkYXPo9q7O6c5GKbGiP3xFhBaBCzzm0
PAqAJ+NyIWL63yI1aNNz4xC1marU7UPLzBnv5fG1WdscYqAbj8XbZ96mPPM80y0A
j5/7YecRXce4yedxRHhi3bD8MEzDMHWfkQPpWCZj/KwjDFiZwSMgpQUqeAllDKQx
Ld0CLkLuUe20b+/5h6dGtGpoACkoOPxMl6zi9uihztvR5iYdkwnmcxKmnEtz+WV4
1efhS3QRZro3QAHjhCqU1Xjl0hnwSCgP5nUhTq6dJqgeZ7c5D4Uhg55MXwQ68Oe4
NrQfhdO8IOSVPDPDEeQ2kuP7/HEZsjKZBMKhKoUcdXM6y9T2tYw3wv5JDuDxT2Q1
3IuFVr1uFm/spVyFCpPpPSQM1wfdtoPLRjiJ/KVh777AWUlywP2b7cWyKShYJb4P
QzTQ/udx94916cSJAlQEEwEIAD4WIQTVtMQ7SKh59jUpNuZZNG4Oo1xn5QUCWUUl
GwIbAwUJA8ORBQULCQgHAgYVCAkKCwIEFgIDAQIeAQIXgAAKCRBZNG4Oo1xn5Uoa
D/9tdmXECDZS1th0xmdNIsecxhI9dBGJyaJwfhH7UVkL+e86EsmTSzyJhBAepDDe
4wTEaW/NnjVX+ulO7rKFN4/qvSCOaeIdP0MEn7zfZVVKG8gMW4mb/piLvUnsZvsM
eWfv9AL/b3H1MRkl9S6XsE0ove72pmbBSZEhh2rNHqf+tIGr/RTtn80efTv3w+75
0UJtaFPsAKoAzNRy+ouhf9IHy9pEMJRA/hZ0Ho04QCDAC65mWz7iwI7v9VRDVfng
UjJPJahoM4vTpB30vJiFYT2oFTgdxGckfEUezsk8Rx/o6x4u6igKypPbeqM/7SMw
H61sCWR7nHJhCK55WeEIbzHEhwCZTf1pgvHj5oGUOjzksp2DmFV3ma3WCh8JyqyA
zw2OvOXBlayIaGIoyD5tSHS40rTi9JmOUfhg6WPN3MIrvsSVEV7JNdiZs/Tb07eQ
l71O7wv/LXZZCYP5NLV0PJbN2pHMf8cysWulfHN/mNgpEiLJpPBYVVyVbzWLg54X
FcNQMrT70kRF4M2GBRahXchkWi6+1pd3jPtvCFfcNiYBnHcrKu2R/UdSpFYdclDi
y6u7xMxXt0AVeLLtlXq7+ChOANMH5aPdUjCXeQDNJawLx41KL9fETsjScodmmpKi
SNhkC03FNfbkPJzZthoTxCfUBQeHYWgDpN3Gjb/OdSWC34kCVwQTAQgAQQIbAwUJ
A8ORBQULCQgHAgYVCAkKCwIEFgIDAQIeAQIXgBYhBNW0xDtIqHn2NSk25lk0bg6j
XGflBQJcqNWQAhkBAAoJEFk0bg6jXGfldcEP/iz4UbJPd/kr8D008ky7vI7hnYs8
VQIxL6ljQJ75XmVx/Lz1MVo4Vdsu6+qEta5gvqbGwjuEugaHcFVbHCZEBKI0QHSQ
UNHfXT8eZP/BwwFWawUokLTbF//Dg5xd5ejo/TeltNleyq1r0AoxcoMv1srrY4yK
GvWE5V8SVSi/E71y4VarS58ZH3NZ6sW5slnYvgAHTVgOjkVvMYk5JmrWsFsycYf8
Rs5BvCuXQpUV9N8UFfW8pAxYhLvUTqhf34m24syyFn9j1udEO1c+IeX7h7hX2CFL
+P6wS9Ok2Z++IKvhIXLy/OoBULxKXjM04aLxDDlRW3qEyeLKvbFiEHGSnlaDz27L
LBAGGRxzLLr0g1evV33AHUU2N8pATnzXHJaRiMjExjRi5IkHjbiEaxiqIwr8CSnS
4RlZ+owxhJ/4MjnsqBL3ELhkSnN+HGkPBQkbFDhCm0ICm78EK2x4+bWo/YUUfoky
Hq92XB6RNbO0RcdGyltFsJ02Ev20Hc4MClF7jT7xm7VJfbeYNmxZ6GNXZ7kEsl87
7qzFtr2BcEfw/ieyyoOrwAC9FBJc/9CALex3p3TGWpM43C+IdqZIsr9QHAzvJfY7
/n5/wJyCPhIZSSE3b8PZRIAdh6NA2IF877OCzIl2UFUNJE1zaEcTvjxZzCZ1SHGU
YzQeSbODHUuPDbhytBJnZW50b29AdGFzdHl0ZWEuZGWJAlQEEwEIAD4CGwMFCwkI
BwIGFQoJCAsCBBYCAwECHgECF4AWIQTVtMQ7SKh59jUpNuZZNG4Oo1xn5QUCXKjY
rQUJBUTnEgAKCRBZNG4Oo1xn5VhkD/42pGYstRMvrO37wJDnnLDm+ZPb0RGy80Ru
Nt3S6OmU3TFuU9mj/FBc8VNs6xr0CCMVVM/CXX1gXCHhADss1YDaOcRsl5wVJ6EF
tbpEXT/USMw3dV4Y8OYUSNxyEitzKt25CnOdWGPYaJG3YOtAR0qwopMiAgLrgLy9
mugXqnrykF7yN27i6iRi2Jk9K7tSb4owpw1kuToJrNGThAkz+3nvXG5oRiYFTlH3
pATx34r+QOg1o3giomP49cP4ohxvQFP90w2/cURhLqEKdR6N1X0bTXRQvy8G+4Wl
QMl8WYPzQUrKGMgj/f7Uhb3pFFLCcnCaYFdUj+fvshg5NMLGVztENz9x7Vr5n51o
Hj9WuM3s65orKrGhMUk4NJCsQWJUHnSNsEXsuir9ocwCv4unIJuoOukNJigL4d5o
i0fKPKuLpdIah1dmcrWLIoid0wPeA8unKQg3h6VL5KXpUudo8CiPw/kk1KTLtYQR
7lezb1oldqfWgGHmqnOK+u6sOhxGj2fcrTi4139ULMph+LCIB3JEtgaaw4lTTt0t
S8h6db6LalzsQyL2sIHgl/rmLmZ5sqZhmi/DsAjZWfpz+inUP6rgap+OgAmtLCit
BwsDAy7ux44mUNtW1KExuY2W/bmSLlV28H+fHJ3fhpHDQMNAFYc5n4NgTe6eT/KY
WA4KGfp7KYkCVAQTAQgAPhYhBNW0xDtIqHn2NSk25lk0bg6jXGflBQJcqNTKAhsD
BQkDw5EFBQsJCAcCBhUKCQgLAgQWAgMBAh4BAheAAAoJEFk0bg6jXGflazAP/iae
7/PIaWhIyDw14NvyJG4D8FMSV9bC1cJ+ICo0qkx0dxcZMsxTp7fD8ODaSWzJEI4X
mGDvJp5fJ7ZALFhp7IBIsj9CHRWyVBCzwhnAXgSmGF+qzBFE7WjQORdn5ytTiWAN
PqyJV0sAw46jLJNvYv/LaFb2bzR/z6U1wQ2qvqXZj8vh2eLvY2XfQa1HnKaPi8h9
OqtLM80/6uai2scdYAI6usB8wxTJY2b2B8flDB7c8DruCDRL1QmrK5o70yIIai2c
4fXHHglulT9GnwD01a5DA2dgn5nxb81xgofgofXQjIOYARUKvcuZsF/tsR5S+C5k
CJnq8V9xdABbWz/FvwXz7ejf2jPtAnD6gcvuPnLX/dsxFHio2n4HHzXboUrVMKid
zcvuIrmlNtvKHYGxC9Dk3vNM+9rTlaY2BRt0zkgakDpMhqFu6A/TCEDZK0ukQLtc
h0g806AWding6gr4vQDeX6dSCuJMFKTu/2q85R1w2vGuyWYSm6QR6sM+KumOX3vJ
c/zvOodhRWXQBWYHTuSw6QGDCI115lWO8DAK4T6u7SVXfthHKm+38dpDH1tSfcHo
KaG7XJKExEPgdcNLvJIN/xCx5lX6fy0ohj7oF1dEpeBpIgqTC0l5I8bLAjcLKZl9
4YwJSSS8aTedptCmBTAHWd6y3W/hgFJrdKsqbHVGuQINBFlFJRsBEAC1EFjL9rvn
O9UIJ2dfaPdfm2GjH/sKfOInfWp4KKEDWtS59Pssld4gnjcmDNgunYYhHYcok61K
9J4x33KvkNAhEbw9y5AGW0tb7p2I6NxiOaWZjmZbg7AJMBFenipdUXBEjbu4LzEd
yyIm3/lQiV4bW6GR14cKdQLZm/inVmbEaGSpq2g19WA+X7SwBxzZR9O80Iohm3RL
X8Z8lXzUj/fUWCCstfXZwdy4vbZv8ms7kmq+3TUOwOiVavgWYhbal+nO0kLdVFbb
i7YRvZh6afxfgMyJ3v1goXvsW1W8jno2ikUmkwZiiPY/cKOPmOwEzj3hl73i6qrx
vm9SjEwEzI/gFXlJD8cOKMc6/g8kUeCepDfdKjgo1SYynLUk4NW9QeucJo6BSPEP
llamHsTaUGzT4tj9qZqAQ0dwSnWYvyi19EMCGssLoy7bAoNueHOYZtHN5TskKShQ
XzEG9IRZvXGmaWAT17sFesqXK0g47jQswmwobDsXyvXJfree36jQRj7SAVVK44Im
bqBe6BT9QYIBkfThAWjwTibg0P1CPGk5TPpssAQgM3jxXVEyD6iKCS4LKWrtm+Sk
MlGaPNyO8OcwHp6p5QaYAE6vlSfT8fsZ0iGd06ua5miZRbkM2i94/jVKvZLRvWv4
S8SMZemAYnVMc0YFWEJCbaKdZp35rb5e4QARAQABiQI8BBgBCAAmAhsMFiEE1bTE
O0ioefY1KTbmWTRuDqNcZ+UFAlyo2PAFCQVE51UACgkQWTRuDqNcZ+V+Hg/9HhVI
No0ID4o8y0jlhyNg8n/Fy08uDALQ6JlbN6buLw+IYU75GTDIysGjx+9bgt+Mjvtp
bbWkeT6okKkyB3H/x7w7v9GTYWlnzMA/KwHF7L7Wqy0afcVjg+fchWXPJQ3H5Jxh
bcX3FKkIN9kpfdHN87C8//s4LzDOWeYCxFwkxkbx4tc1K4HhezpvYDKiLmFMVbaU
qB0pzP8IM3hU1GJeAC2skfjstuaKJPuF895aFSF6++DYodXBFu3UlSJbJGfDEBYC
9PgSrxX1qlNUFw+6Hr2uSdPnmcKgCDFGhxB1d/Z2Xa/QFhvuj7U38eyqla3dzXxu
4+/9BOoJwdyRlUxd1Jcy3q7l8V4Hk1vMwKICdXBadAcAgSi0ImXt7UttpTYB7WNV
nlFmFFi8eVnmMll08LWV6LygG8GBSzW5NUZnUhxHbFVFcEuHo6W1lIEgJooOnGwd
H2rqKXpkcv86q7ODxdt9nb0txUPzgukusHes6Q0cnTMWcd0YT75frKjjK6TK8KZA
XMH0zobogpnr/n2ji87cn9sSlL3/2NtxfAwqyDWomECKOtKYfx10OPjrPrScDFG0
aF6w50Xg5DH/I38zzBVanEgwzWHosIVKNQHgoSYijErnShbRefA8+zCsyn0q/9Rg
cToAM7X3ro+tQQHWDIhiayHvJMeGN/R/u1U4Kv25BK4EW0zMChEMALGnffpA/rz6
oRXV++syFI6AaByfiatYgKh+d2LkhyeAAnp93VBV8c2YArsSp7XookhxlRA7XAGw
x71VKouHjdcMpZM76OcEJgC2fKCbsLrMhkjKOjux6Lru1mY4bFmXBxex0pssvIoc
zefV00qVvQ0e2JkvUmuKKIplyH0GAapDRnF3R8/doNNUXfVufHButKHlmK7yaFkK
UBXLFUc3c8mCm/UQcMrFYrlyRNd6Axir2LpD8ya8gIwOM49nH+DDSla4d23zP+4M
kTaWZ5QlX4FGN8kfPE4rzVxhCP0jtC5m2oqFp8dIKtxzX836YkHG7wlAPsaoPmhl
kJMylGSwvjRvjxNLHWodMJfrQgajnW0UEd1XrfO48i/OD3f1Z22/sHRY2VejD4KJ
49QBienKCUlNbZRfpaGOQn2HqbOX6/wUfS/83rhBVNrsU2kNb/+6OKsJV2YtokPK
saS88q8225YEcsDLPS/3V5VrFW0CQwXJM4AbVweHhE7486VtSfkQswEAjTMJSbTO
4IgjWYDaQ57m77bc4N9z0oCWaChlaAjdzSsL/0JQx5GJXUcxW1GvEGhP/Fx1IFd3
oCR8OmY6oZHYmB1fNvFLSmJN0dJcQjm3hebrSQiWg/JvVAlF2S7f+j0pjeki09kM
0RqAHOkDpLeY6ifU8+QW5DP5yh8d9ZDc4wjPdz53ycwJzaMqESOIr9eHYtOWN6Hi
0rItsMN8FB5A70te1IcKG5UWh3cCRg7fEbKVofIYTSU2V98RLkp+iEHLKfa6wObx
Mt60OVU/xbrO28w93cLpWUIH1Csow3k3wSbNmw3d9mWc7cVESct+IM5W4ZSYMcjG
cvcMELWCwuT1mPSkR0hv2oz5xFOBlUV1KUViIcxpKzTrjj69JAaBbJ3f5OEfEbj/
G+aa30EoddPBhwF7XnQUeC/DLRJQh2MH1ohMnkpBttDipHOuFS1CZh8xoxr/8moW
nj5FRG+FAZeCmcqj5PE+du7KF2XRPBlxhc1Nu+kPejlr6qa5qdwo4MzfuzmxWmvc
WQuNMtaPqQvYL1A09MH0uMH65MtJNsqbSvHa5AwAlletPw6Wr0qrBLBCmOpNf+Q7
7nBQBrK5VPMcto9IkGB4/bwhx7gQ0O2dD4dD4DPpGY9p52KpOG2ECoCWMtbsPD2P
bs+WNHN8V+3ZCxZukEj25wDhc5941P01BhKVFevGLHyYNWk34mQk7RdHj9OiEL8n
GpQ9l/R58+mvVwarzs898/y5onQieWi0Zu3WfMvjTOG3D3NIKMuthzRytfV5C/tJ
+W5ZX/jLVR3bzvzx8Pnpvf602xCST9/7LbgFhljfXQq0bq0d9si9hvyaMOh1PQFU
2+PzmWtHcsiVoyXfQp6ztJYFkoYaaD+Mc2jWG2Qy9kAyUGTXj/WfkPn7hr5hvuwk
0kNDSan8NY2f1mtG253qr6fMOmCgrUfaumpafd9xIJ65x1G2BGAr8bzjLJufEUaG
D2wBYWE6tlRqT4j7u6u9vRjShKH+A1UpLV2pEtaIQ3wfbt6GIwFJHWU506m3RCCn
pL46fAOVKS1GSuf79koXsZeECJRSbipXz3TJs0TqiQKzBBgBCAAmAhsCFiEE1bTE
O0ioefY1KTbmWTRuDqNcZ+UFAlyo2PAFCQM9QGYAgXYgBBkRCAAdFiEENVUmaGTK
bX/0Wqbnz8OUl/GybgcFAltMzAoACgkQz8OUl/Gybgf0OwD/c4hwqsfZ79t7pM9d
PPWYQ1jyq2g3ELMKyPp79GmL0qsA/2t2qkaOEX3y7egmhL/iKyqASb4y/JTABGMU
hy5GjBhxCRBZNG4Oo1xn5WBvEACbCAQRC00FYoktuRzQQy2LCJe13AUS1/lCWv8B
Qu7hTmM8TC/iNmYk71qeYInQMp/12b0HSWcv8IBmOlMy2GTjgnTgiwpqY5nhtb9O
uB5H2g6fpu7FFG9ARhtH9PiTMwOUzfZFUz0tDdEEG5sayzWUcY3zjmJFmHSg5A9B
/Q/yctqZ1eINtyEECINo/OVEfD7bmyZwK/vrxAg285iF6lB11wVl+5E7sNy9Hvu8
4kCKPksqyjFWUd0XoEu9AH6+XVeEPF7CQKHpRfhc4uweT9O5nTb7aaPcqq0B4lUL
unG6KSCm88zaZczp2SUCFwENegmBT/YKN5ZoHsPh1nwLxh194EP/qRjW9IvFKTlJ
EsB4uCpfDeC233oH5nDkvvphcPYdUuOsVH1uPQ7PyWNTf1ufd9bDSDtK8epIcDPe
abOuphxQbrMVP4JJsBXnVW5raZO7s5lmSA8Ovce//+xJSAq9u0GTsGu1hWDe60ro
uOZwqjo/cU5G4y7WHRaC3oshH+DO8ajdXDogoDVs8DzYkTfWND2DDNEVhVrn7lGf
a4739sFIDagtBq6RzJGL0X82eJZzXPFiYvmy0OVbNDUgH+Drva/wRv/tN8RvBiS6
bsn8+GBGaU5RASu67UbqxHiytFnN4OnADA5ZHcwQbMgRHHiiMMIf+tJWH/pFMp00
epiDVQ==
=VSKJ
-----END PGP PUBLIC KEY BLOCK-----
`
keys, err := checkArmoredGPGKeyString(testIssue6599)
assert.NoError(t, err)
if assert.NotEmpty(t, keys) {
ekey := keys[0]
expire := getExpiryTime(ekey)
assert.Equal(t, time.Unix(1586105389, 0), expire)
}
}
func TestTryGetKeyIDFromSignature(t *testing.T) {
assert.Empty(t, tryGetKeyIDFromSignature(&packet.Signature{}))
assert.Equal(t, "038D1A3EADDBEA9C", tryGetKeyIDFromSignature(&packet.Signature{
IssuerKeyId: util.ToPointer(uint64(0x38D1A3EADDBEA9C)),
}))
assert.Equal(t, "038D1A3EADDBEA9C", tryGetKeyIDFromSignature(&packet.Signature{
IssuerFingerprint: []uint8{0xb, 0x23, 0x24, 0xc7, 0xe6, 0xfe, 0x4f, 0x3a, 0x6, 0x26, 0xc1, 0x21, 0x3, 0x8d, 0x1a, 0x3e, 0xad, 0xdb, 0xea, 0x9c},
}))
}

View File

@@ -0,0 +1,119 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"context"
"strconv"
"time"
"code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/log"
)
// __________________ ________ ____ __.
// / _____/\______ \/ _____/ | |/ _|____ ___.__.
// / \ ___ | ___/ \ ___ | <_/ __ < | |
// \ \_\ \| | \ \_\ \ | | \ ___/\___ |
// \______ /|____| \______ / |____|__ \___ > ____|
// \/ \/ \/ \/\/
// ____ ____ .__ _____
// \ \ / /___________|__|/ ____\__.__.
// \ Y // __ \_ __ \ \ __< | |
// \ /\ ___/| | \/ || | \___ |
// \___/ \___ >__| |__||__| / ____|
// \/ \/
// This file provides functions related to verifying GPG keys
// VerifyGPGKey marks a GPG key as verified
func VerifyGPGKey(ctx context.Context, ownerID int64, keyID, token, signature string) (string, error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return "", err
}
defer committer.Close()
key := new(GPGKey)
has, err := db.GetEngine(ctx).Where("owner_id = ? AND key_id = ?", ownerID, keyID).Get(key)
if err != nil {
return "", err
} else if !has {
return "", ErrGPGKeyNotExist{}
}
if err := key.LoadSubKeys(ctx); err != nil {
return "", err
}
sig, err := extractSignature(signature)
if err != nil {
return "", ErrGPGInvalidTokenSignature{
ID: key.KeyID,
Wrapped: err,
}
}
signer, err := hashAndVerifyWithSubKeys(sig, token, key)
if err != nil {
return "", ErrGPGInvalidTokenSignature{
ID: key.KeyID,
Wrapped: err,
}
}
if signer == nil {
signer, err = hashAndVerifyWithSubKeys(sig, token+"\n", key)
if err != nil {
return "", ErrGPGInvalidTokenSignature{
ID: key.KeyID,
Wrapped: err,
}
}
}
if signer == nil {
signer, err = hashAndVerifyWithSubKeys(sig, token+"\n\n", key)
if err != nil {
return "", ErrGPGInvalidTokenSignature{
ID: key.KeyID,
Wrapped: err,
}
}
}
if signer == nil {
log.Error("Unable to validate token signature. Error: %v", err)
return "", ErrGPGInvalidTokenSignature{
ID: key.KeyID,
}
}
if signer.PrimaryKeyID != key.KeyID && signer.KeyID != key.KeyID {
return "", ErrGPGKeyNotExist{}
}
key.Verified = true
if _, err := db.GetEngine(ctx).ID(key.ID).SetExpr("verified", true).Update(new(GPGKey)); err != nil {
return "", err
}
if err := committer.Commit(); err != nil {
return "", err
}
return key.KeyID, nil
}
// VerificationToken returns a token for the user that will be valid for the given number of minutes
func VerificationToken(user *user_model.User, minutes int) string {
return base.EncodeSha256(
time.Now().Truncate(1*time.Minute).Add(time.Duration(minutes)*time.Minute).Format(
time.RFC1123Z) + ":" +
user.CreatedUnix.Format(time.RFC1123Z) + ":" +
user.Name + ":" +
user.Email + ":" +
strconv.FormatInt(user.ID, 10))
}
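// exampleVerifyGPGKeyFlow is an illustrative sketch of how VerificationToken and
// VerifyGPGKey might be wired together: the server issues a token, the user
// clear-signs it out of band with the GPG key being verified, and the armored
// signature comes back through this call. The ownerID, keyID and
// armoredSignature values are hypothetical inputs.
func exampleVerifyGPGKeyFlow(ctx context.Context, u *user_model.User, keyID, armoredSignature string) error {
	// Token valid for roughly five minutes; the user signs exactly this string.
	token := VerificationToken(u, 5)
	// VerifyGPGKey also retries with trailing newlines appended to the token.
	_, err := VerifyGPGKey(ctx, u.ID, keyID, token, armoredSignature)
	return err
}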

View File

@@ -0,0 +1,23 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"testing"
"code.gitea.io/gitea/models/unittest"
)
func TestMain(m *testing.M) {
unittest.MainTest(m, &unittest.TestOptions{
FixtureFiles: []string{
"gpg_key.yml",
"public_key.yml",
"deploy_key.yml",
"gpg_key_import.yml",
"user.yml",
"email_address.yml",
},
})
}

427
models/asymkey/ssh_key.go Normal file
View File

@@ -0,0 +1,427 @@
// Copyright 2014 The Gogs Authors. All rights reserved.
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"context"
"fmt"
"strings"
"time"
"code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/perm"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
"golang.org/x/crypto/ssh"
"xorm.io/builder"
)
// KeyType specifies the key type
type KeyType int
const (
// KeyTypeUser specifies the user key
KeyTypeUser = iota + 1
// KeyTypeDeploy specifies the deploy key
KeyTypeDeploy
// KeyTypePrincipal specifies the authorized principal key
KeyTypePrincipal
)
// PublicKey represents a user or deploy SSH public key.
type PublicKey struct {
ID int64 `xorm:"pk autoincr"`
OwnerID int64 `xorm:"INDEX NOT NULL"`
Name string `xorm:"NOT NULL"`
Fingerprint string `xorm:"INDEX NOT NULL"`
Content string `xorm:"MEDIUMTEXT NOT NULL"`
Mode perm.AccessMode `xorm:"NOT NULL DEFAULT 2"`
Type KeyType `xorm:"NOT NULL DEFAULT 1"`
LoginSourceID int64 `xorm:"NOT NULL DEFAULT 0"`
CreatedUnix timeutil.TimeStamp `xorm:"created"`
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
HasRecentActivity bool `xorm:"-"`
HasUsed bool `xorm:"-"`
Verified bool `xorm:"NOT NULL DEFAULT false"`
}
func init() {
db.RegisterModel(new(PublicKey))
}
// AfterLoad is invoked from XORM after setting the values of all fields of this object.
func (key *PublicKey) AfterLoad() {
key.HasUsed = key.UpdatedUnix > key.CreatedUnix
key.HasRecentActivity = key.UpdatedUnix.AddDuration(7*24*time.Hour) > timeutil.TimeStampNow()
}
// OmitEmail returns content of public key without email address.
func (key *PublicKey) OmitEmail() string {
return strings.Join(strings.Split(key.Content, " ")[:2], " ")
}
// AuthorizedString returns formatted public key string for authorized_keys file.
//
// TODO: Consider dropping this function
func (key *PublicKey) AuthorizedString() string {
return AuthorizedStringForKey(key)
}
func addKey(ctx context.Context, key *PublicKey) (err error) {
if len(key.Fingerprint) == 0 {
key.Fingerprint, err = CalcFingerprint(key.Content)
if err != nil {
return err
}
}
// Save SSH key.
if err = db.Insert(ctx, key); err != nil {
return err
}
return appendAuthorizedKeysToFile(key)
}
// AddPublicKey adds a new public key to the database and the authorized_keys file.
func AddPublicKey(ctx context.Context, ownerID int64, name, content string, authSourceID int64) (*PublicKey, error) {
log.Trace(content)
fingerprint, err := CalcFingerprint(content)
if err != nil {
return nil, err
}
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
defer committer.Close()
if err := checkKeyFingerprint(ctx, fingerprint); err != nil {
return nil, err
}
// Key name of same user cannot be duplicated.
has, err := db.GetEngine(ctx).
Where("owner_id = ? AND name = ?", ownerID, name).
Get(new(PublicKey))
if err != nil {
return nil, err
} else if has {
return nil, ErrKeyNameAlreadyUsed{ownerID, name}
}
key := &PublicKey{
OwnerID: ownerID,
Name: name,
Fingerprint: fingerprint,
Content: content,
Mode: perm.AccessModeWrite,
Type: KeyTypeUser,
LoginSourceID: authSourceID,
}
if err = addKey(ctx, key); err != nil {
return nil, fmt.Errorf("addKey: %w", err)
}
return key, committer.Commit()
}
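// exampleAddUserKey is an illustrative sketch of calling AddPublicKey for a
// manually added key (no auth source, hence authSourceID 0); the key name and
// content below are placeholder test values.
func exampleAddUserKey(ctx context.Context, ownerID int64) (*PublicKey, error) {
	const content = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAICV0MGX/W9IvLA4FXpIuUcdDcbj5KX4syHgsTy7soVgf nocomment"
	return AddPublicKey(ctx, ownerID, "example-key", content, 0)
}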
// GetPublicKeyByID returns public key by given ID.
func GetPublicKeyByID(ctx context.Context, keyID int64) (*PublicKey, error) {
key := new(PublicKey)
has, err := db.GetEngine(ctx).
ID(keyID).
Get(key)
if err != nil {
return nil, err
} else if !has {
return nil, ErrKeyNotExist{keyID}
}
return key, nil
}
// SearchPublicKeyByContent searches for a public key whose content starts with the given prefix (this may leak the e-mail part)
// and returns the public key found.
func SearchPublicKeyByContent(ctx context.Context, content string) (*PublicKey, error) {
key := new(PublicKey)
has, err := db.GetEngine(ctx).
Where("content like ?", content+"%").
Get(key)
if err != nil {
return nil, err
} else if !has {
return nil, ErrKeyNotExist{}
}
return key, nil
}
// SearchPublicKeyByContentExact searches for a public key with exactly the given content
// and returns the public key found.
func SearchPublicKeyByContentExact(ctx context.Context, content string) (*PublicKey, error) {
key := new(PublicKey)
has, err := db.GetEngine(ctx).
Where("content = ?", content).
Get(key)
if err != nil {
return nil, err
} else if !has {
return nil, ErrKeyNotExist{}
}
return key, nil
}
type FindPublicKeyOptions struct {
db.ListOptions
OwnerID int64
Fingerprint string
KeyTypes []KeyType
NotKeytype KeyType
LoginSourceID int64
}
func (opts FindPublicKeyOptions) ToConds() builder.Cond {
cond := builder.NewCond()
if opts.OwnerID > 0 {
cond = cond.And(builder.Eq{"owner_id": opts.OwnerID})
}
if opts.Fingerprint != "" {
cond = cond.And(builder.Eq{"fingerprint": opts.Fingerprint})
}
if len(opts.KeyTypes) > 0 {
cond = cond.And(builder.In("`type`", opts.KeyTypes))
}
if opts.NotKeytype > 0 {
cond = cond.And(builder.Neq{"`type`": opts.NotKeytype})
}
if opts.LoginSourceID > 0 {
cond = cond.And(builder.Eq{"login_source_id": opts.LoginSourceID})
}
return cond
}
// UpdatePublicKeyUpdated updates the public key's last-used time.
func UpdatePublicKeyUpdated(ctx context.Context, id int64) error {
// Check if key exists before update as affected rows count is unreliable
// and will return 0 affected rows if two updates are made at the same time
if cnt, err := db.GetEngine(ctx).ID(id).Count(&PublicKey{}); err != nil {
return err
} else if cnt != 1 {
return ErrKeyNotExist{id}
}
_, err := db.GetEngine(ctx).ID(id).Cols("updated_unix").Update(&PublicKey{
UpdatedUnix: timeutil.TimeStampNow(),
})
if err != nil {
return err
}
return nil
}
// PublicKeysAreExternallyManaged returns, for each of the provided keys, whether it is externally managed
func PublicKeysAreExternallyManaged(ctx context.Context, keys []*PublicKey) ([]bool, error) {
sourceCache := make(map[int64]*auth.Source, len(keys))
externals := make([]bool, len(keys))
for i, key := range keys {
if key.LoginSourceID == 0 {
externals[i] = false
continue
}
source, ok := sourceCache[key.LoginSourceID]
if !ok {
var err error
source, err = auth.GetSourceByID(ctx, key.LoginSourceID)
if err != nil {
if auth.IsErrSourceNotExist(err) {
externals[i] = false
sourceCache[key.LoginSourceID] = &auth.Source{
ID: key.LoginSourceID,
}
continue
}
return nil, err
}
}
if sshKeyProvider, ok := source.Cfg.(auth.SSHKeyProvider); ok && sshKeyProvider.ProvidesSSHKeys() {
// Disable setting SSH keys for this user
externals[i] = true
}
}
return externals, nil
}
// PublicKeyIsExternallyManaged returns whether the provided KeyID represents an externally managed Key
func PublicKeyIsExternallyManaged(ctx context.Context, id int64) (bool, error) {
key, err := GetPublicKeyByID(ctx, id)
if err != nil {
return false, err
}
if key.LoginSourceID == 0 {
return false, nil
}
source, err := auth.GetSourceByID(ctx, key.LoginSourceID)
if err != nil {
if auth.IsErrSourceNotExist(err) {
return false, nil
}
return false, err
}
if sshKeyProvider, ok := source.Cfg.(auth.SSHKeyProvider); ok && sshKeyProvider.ProvidesSSHKeys() {
// Disable setting SSH keys for this user
return true, nil
}
return false, nil
}
// deleteKeysMarkedForDeletion deletes the given keys and returns true if the ssh keys need to be updated
func deleteKeysMarkedForDeletion(ctx context.Context, keys []string) (bool, error) {
// Start session
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return false, err
}
defer committer.Close()
// Delete keys marked for deletion
var sshKeysNeedUpdate bool
for _, KeyToDelete := range keys {
key, err := SearchPublicKeyByContent(ctx, KeyToDelete)
if err != nil {
log.Error("SearchPublicKeyByContent: %v", err)
continue
}
if _, err = db.DeleteByID[PublicKey](ctx, key.ID); err != nil {
log.Error("DeleteByID[PublicKey]: %v", err)
continue
}
sshKeysNeedUpdate = true
}
if err := committer.Commit(); err != nil {
return false, err
}
return sshKeysNeedUpdate, nil
}
// AddPublicKeysBySource adds a user's public keys. Returns true if there are changes.
func AddPublicKeysBySource(ctx context.Context, usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool {
var sshKeysNeedUpdate bool
for _, sshKey := range sshPublicKeys {
var err error
found := false
keys := []byte(sshKey)
loop:
for len(keys) > 0 && err == nil {
var out ssh.PublicKey
// We ignore options as they are not relevant to Gitea
out, _, _, keys, err = ssh.ParseAuthorizedKey(keys)
if err != nil {
break loop
}
found = true
marshalled := string(ssh.MarshalAuthorizedKey(out))
marshalled = marshalled[:len(marshalled)-1]
sshKeyName := fmt.Sprintf("%s-%s", s.Name, ssh.FingerprintSHA256(out))
if _, err := AddPublicKey(ctx, usr.ID, sshKeyName, marshalled, s.ID); err != nil {
if IsErrKeyAlreadyExist(err) {
log.Trace("AddPublicKeysBySource[%s]: Public SSH Key %s already exists for user", sshKeyName, usr.Name)
} else {
log.Error("AddPublicKeysBySource[%s]: Error adding Public SSH Key for user %s: %v", sshKeyName, usr.Name, err)
}
} else {
log.Trace("AddPublicKeysBySource[%s]: Added Public SSH Key for user %s", sshKeyName, usr.Name)
sshKeysNeedUpdate = true
}
}
if !found && err != nil {
log.Warn("AddPublicKeysBySource[%s]: Skipping invalid Public SSH Key for user %s: %v", s.Name, usr.Name, sshKey)
}
}
return sshKeysNeedUpdate
}
// SynchronizePublicKeys updates a user's public keys. Returns true if there are changes.
func SynchronizePublicKeys(ctx context.Context, usr *user_model.User, s *auth.Source, sshPublicKeys []string) bool {
var sshKeysNeedUpdate bool
log.Trace("synchronizePublicKeys[%s]: Handling Public SSH Key synchronization for user %s", s.Name, usr.Name)
// Get Public Keys from DB with current LDAP source
var giteaKeys []string
keys, err := db.Find[PublicKey](ctx, FindPublicKeyOptions{
OwnerID: usr.ID,
LoginSourceID: s.ID,
})
if err != nil {
log.Error("synchronizePublicKeys[%s]: Error listing Public SSH Keys for user %s: %v", s.Name, usr.Name, err)
}
for _, v := range keys {
giteaKeys = append(giteaKeys, v.OmitEmail())
}
// Process the provided keys to remove duplicates and name part
var providedKeys []string
for _, v := range sshPublicKeys {
sshKeySplit := strings.Split(v, " ")
if len(sshKeySplit) > 1 {
key := strings.Join(sshKeySplit[:2], " ")
if !util.SliceContainsString(providedKeys, key) {
providedKeys = append(providedKeys, key)
}
}
}
// Check if Public Key sync is needed
if util.SliceSortedEqual(giteaKeys, providedKeys) {
log.Trace("synchronizePublicKeys[%s]: Public Keys are already in sync for %s (Source:%v/DB:%v)", s.Name, usr.Name, len(providedKeys), len(giteaKeys))
return false
}
log.Trace("synchronizePublicKeys[%s]: Public Key needs update for user %s (Source:%v/DB:%v)", s.Name, usr.Name, len(providedKeys), len(giteaKeys))
// Add new Public SSH Keys that don't already exist in the DB
var newKeys []string
for _, key := range providedKeys {
if !util.SliceContainsString(giteaKeys, key) {
newKeys = append(newKeys, key)
}
}
if AddPublicKeysBySource(ctx, usr, s, newKeys) {
sshKeysNeedUpdate = true
}
// Mark keys from DB that no longer exist in the source for deletion
var giteaKeysToDelete []string
for _, giteaKey := range giteaKeys {
if !util.SliceContainsString(providedKeys, giteaKey) {
log.Trace("synchronizePublicKeys[%s]: Marking Public SSH Key for deletion for user %s: %v", s.Name, usr.Name, giteaKey)
giteaKeysToDelete = append(giteaKeysToDelete, giteaKey)
}
}
// Delete keys from DB that no longer exist in the source
needUpd, err := deleteKeysMarkedForDeletion(ctx, giteaKeysToDelete)
if err != nil {
log.Error("synchronizePublicKeys[%s]: Error deleting Public Keys marked for deletion for user %s: %v", s.Name, usr.Name, err)
}
if needUpd {
sshKeysNeedUpdate = true
}
return sshKeysNeedUpdate
}

View File

@@ -0,0 +1,161 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"bufio"
"context"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"
)
// _____ __ .__ .__ .___
// / _ \ __ ___/ |_| |__ ___________|__|_______ ____ __| _/
// / /_\ \| | \ __\ | \ / _ \_ __ \ \___ // __ \ / __ |
// / | \ | /| | | Y ( <_> ) | \/ |/ /\ ___// /_/ |
// \____|__ /____/ |__| |___| /\____/|__| |__/_____ \\___ >____ |
// \/ \/ \/ \/ \/
// ____ __.
// | |/ _|____ ___.__. ______
// | <_/ __ < | |/ ___/
// | | \ ___/\___ |\___ \
// |____|__ \___ > ____/____ >
// \/ \/\/ \/
//
// This file contains functions for creating authorized_keys files
//
// There is a dependency on the database within RegeneratePublicKeys; however, most of these functions probably belong in a module
const (
tplCommentPrefix = `# gitea public key`
tplPublicKey = tplCommentPrefix + "\n" + `command=%s,no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty,no-user-rc,restrict %s` + "\n"
)
var sshOpLocker sync.Mutex
func WithSSHOpLocker(f func() error) error {
sshOpLocker.Lock()
defer sshOpLocker.Unlock()
return f()
}
// AuthorizedStringForKey creates the authorized keys string appropriate for the provided key
func AuthorizedStringForKey(key *PublicKey) string {
sb := &strings.Builder{}
_ = setting.SSH.AuthorizedKeysCommandTemplateTemplate.Execute(sb, map[string]any{
"AppPath": util.ShellEscape(setting.AppPath),
"AppWorkPath": util.ShellEscape(setting.AppWorkPath),
"CustomConf": util.ShellEscape(setting.CustomConf),
"CustomPath": util.ShellEscape(setting.CustomPath),
"Key": key,
})
return fmt.Sprintf(tplPublicKey, util.ShellEscape(sb.String()), key.Content)
}
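// For illustration (an assumption about the rendered output, not taken from
// this file): with a default command template an authorized_keys entry produced
// by AuthorizedStringForKey looks roughly like
//
//	# gitea public key
//	command="/path/to/gitea --config=/path/to/app.ini serv key-42",no-port-forwarding,no-X11-forwarding,no-agent-forwarding,no-pty,no-user-rc,restrict ssh-ed25519 AAAA... user@host
//
// where the command part, paths and key ID are hypothetical and depend on the
// configured AuthorizedKeysCommandTemplate.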
// appendAuthorizedKeysToFile appends new SSH keys' content to authorized_keys file.
func appendAuthorizedKeysToFile(keys ...*PublicKey) error {
// No need to rewrite this file if the builtin SSH server is enabled or writing the authorized_keys file is disabled.
if setting.SSH.StartBuiltinServer || !setting.SSH.CreateAuthorizedKeysFile {
return nil
}
sshOpLocker.Lock()
defer sshOpLocker.Unlock()
if setting.SSH.RootPath != "" {
// First of all, ensure that the RootPath is present, and if not create it with 0700 permissions
// This of course doesn't guarantee that this is the right directory for authorized_keys
// but at least if it's supposed to be this directory and it doesn't exist and we're the
// right user it will at least be created properly.
err := os.MkdirAll(setting.SSH.RootPath, 0o700)
if err != nil {
log.Error("Unable to MkdirAll(%s): %v", setting.SSH.RootPath, err)
return err
}
}
fPath := filepath.Join(setting.SSH.RootPath, "authorized_keys")
f, err := os.OpenFile(fPath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0o600)
if err != nil {
return err
}
defer f.Close()
// Note: the chmod command is not supported on Windows.
if !setting.IsWindows {
fi, err := f.Stat()
if err != nil {
return err
}
// .ssh directory should have mode 700, and authorized_keys file should have mode 600.
if fi.Mode().Perm() > 0o600 {
log.Error("authorized_keys file has unusual permission flags: %s - setting to -rw-------", fi.Mode().Perm().String())
if err = f.Chmod(0o600); err != nil {
return err
}
}
}
for _, key := range keys {
if key.Type == KeyTypePrincipal {
continue
}
if _, err = f.WriteString(key.AuthorizedString()); err != nil {
return err
}
}
return nil
}
// RegeneratePublicKeys regenerates the authorized_keys file
func RegeneratePublicKeys(ctx context.Context, t io.StringWriter) error {
if err := db.GetEngine(ctx).Where("type != ?", KeyTypePrincipal).Iterate(new(PublicKey), func(idx int, bean any) (err error) {
_, err = t.WriteString((bean.(*PublicKey)).AuthorizedString())
return err
}); err != nil {
return err
}
fPath := filepath.Join(setting.SSH.RootPath, "authorized_keys")
isExist, err := util.IsExist(fPath)
if err != nil {
log.Error("Unable to check if %s exists. Error: %v", fPath, err)
return err
}
if isExist {
f, err := os.Open(fPath)
if err != nil {
return err
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, tplCommentPrefix) {
scanner.Scan()
continue
}
_, err = t.WriteString(line + "\n")
if err != nil {
return err
}
}
if err = scanner.Err(); err != nil {
return fmt.Errorf("RegeneratePublicKeys scan: %w", err)
}
}
return nil
}

View File

@@ -0,0 +1,80 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"bytes"
"context"
"fmt"
"strings"
"code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/log"
"github.com/42wim/sshsig"
)
// ParseCommitWithSSHSignature checks whether the commit signature is valid against the keystore.
func ParseCommitWithSSHSignature(ctx context.Context, c *git.Commit, committer *user_model.User) *CommitVerification {
// Now try to associate the signature with the committer, if present
if committer.ID != 0 {
keys, err := db.Find[PublicKey](ctx, FindPublicKeyOptions{
OwnerID: committer.ID,
NotKeytype: KeyTypePrincipal,
})
if err != nil { // skip because we failed to get the user's ssh keys
log.Error("ListPublicKeys: %v", err)
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Reason: "gpg.error.failed_retrieval_gpg_keys",
}
}
committerEmailAddresses, err := user_model.GetEmailAddresses(ctx, committer.ID)
if err != nil {
log.Error("GetEmailAddresses: %v", err)
}
activated := false
for _, e := range committerEmailAddresses {
if e.IsActivated && strings.EqualFold(e.Email, c.Committer.Email) {
activated = true
break
}
}
for _, k := range keys {
if k.Verified && activated {
commitVerification := verifySSHCommitVerification(c.Signature.Signature, c.Signature.Payload, k, committer, committer, c.Committer.Email)
if commitVerification != nil {
return commitVerification
}
}
}
}
return &CommitVerification{
CommittingUser: committer,
Verified: false,
Reason: NoKeyFound,
}
}
func verifySSHCommitVerification(sig, payload string, k *PublicKey, committer, signer *user_model.User, email string) *CommitVerification {
if err := sshsig.Verify(bytes.NewBuffer([]byte(payload)), []byte(sig), []byte(k.Content), "git"); err != nil {
return nil
}
return &CommitVerification{ // Everything is ok
CommittingUser: committer,
Verified: true,
Reason: fmt.Sprintf("%s / %s", signer.Name, k.Fingerprint),
SigningUser: signer,
SigningSSHKey: k,
SigningEmail: email,
}
}

View File

@@ -0,0 +1,218 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"context"
"fmt"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/perm"
"code.gitea.io/gitea/modules/timeutil"
"xorm.io/builder"
)
// ________ .__ ____ __.
// \______ \ ____ ______ | | ____ ___.__.| |/ _|____ ___.__.
// | | \_/ __ \\____ \| | / _ < | || <_/ __ < | |
// | ` \ ___/| |_> > |_( <_> )___ || | \ ___/\___ |
// /_______ /\___ > __/|____/\____// ____||____|__ \___ > ____|
// \/ \/|__| \/ \/ \/\/
//
// This file contains functions specific to DeployKeys
// DeployKey represents deploy key information and its relation with repository.
type DeployKey struct {
ID int64 `xorm:"pk autoincr"`
KeyID int64 `xorm:"UNIQUE(s) INDEX"`
RepoID int64 `xorm:"UNIQUE(s) INDEX"`
Name string
Fingerprint string
Content string `xorm:"-"`
Mode perm.AccessMode `xorm:"NOT NULL DEFAULT 1"`
CreatedUnix timeutil.TimeStamp `xorm:"created"`
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
HasRecentActivity bool `xorm:"-"`
HasUsed bool `xorm:"-"`
}
// AfterLoad is invoked from XORM after setting the values of all fields of this object.
func (key *DeployKey) AfterLoad() {
key.HasUsed = key.UpdatedUnix > key.CreatedUnix
key.HasRecentActivity = key.UpdatedUnix.AddDuration(7*24*time.Hour) > timeutil.TimeStampNow()
}
// GetContent gets associated public key content.
func (key *DeployKey) GetContent(ctx context.Context) error {
pkey, err := GetPublicKeyByID(ctx, key.KeyID)
if err != nil {
return err
}
key.Content = pkey.Content
return nil
}
// IsReadOnly checks if the key can only be used for read operations, used by template
func (key *DeployKey) IsReadOnly() bool {
return key.Mode == perm.AccessModeRead
}
func init() {
db.RegisterModel(new(DeployKey))
}
func checkDeployKey(ctx context.Context, keyID, repoID int64, name string) error {
// Note: We want error detail, not just true or false here.
has, err := db.GetEngine(ctx).
Where("key_id = ? AND repo_id = ?", keyID, repoID).
Get(new(DeployKey))
if err != nil {
return err
} else if has {
return ErrDeployKeyAlreadyExist{keyID, repoID}
}
has, err = db.GetEngine(ctx).
Where("repo_id = ? AND name = ?", repoID, name).
Get(new(DeployKey))
if err != nil {
return err
} else if has {
return ErrDeployKeyNameAlreadyUsed{repoID, name}
}
return nil
}
// addDeployKey adds new key-repo relation.
func addDeployKey(ctx context.Context, keyID, repoID int64, name, fingerprint string, mode perm.AccessMode) (*DeployKey, error) {
if err := checkDeployKey(ctx, keyID, repoID, name); err != nil {
return nil, err
}
key := &DeployKey{
KeyID: keyID,
RepoID: repoID,
Name: name,
Fingerprint: fingerprint,
Mode: mode,
}
return key, db.Insert(ctx, key)
}
// HasDeployKey returns true if public key is a deploy key of given repository.
func HasDeployKey(ctx context.Context, keyID, repoID int64) bool {
has, _ := db.GetEngine(ctx).
Where("key_id = ? AND repo_id = ?", keyID, repoID).
Get(new(DeployKey))
return has
}
// AddDeployKey adds a new deploy key to the database and the authorized_keys file.
func AddDeployKey(ctx context.Context, repoID int64, name, content string, readOnly bool) (*DeployKey, error) {
fingerprint, err := CalcFingerprint(content)
if err != nil {
return nil, err
}
accessMode := perm.AccessModeRead
if !readOnly {
accessMode = perm.AccessModeWrite
}
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
defer committer.Close()
pkey, exist, err := db.Get[PublicKey](ctx, builder.Eq{"fingerprint": fingerprint})
if err != nil {
return nil, err
} else if exist {
if pkey.Type != KeyTypeDeploy {
return nil, ErrKeyAlreadyExist{0, fingerprint, ""}
}
} else {
// First time use this deploy key.
pkey = &PublicKey{
Fingerprint: fingerprint,
Mode: accessMode,
Type: KeyTypeDeploy,
Content: content,
Name: name,
}
if err = addKey(ctx, pkey); err != nil {
return nil, fmt.Errorf("addKey: %w", err)
}
}
key, err := addDeployKey(ctx, pkey.ID, repoID, name, pkey.Fingerprint, accessMode)
if err != nil {
return nil, err
}
return key, committer.Commit()
}
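// exampleAddReadOnlyDeployKey is an illustrative sketch of adding a read-only
// deploy key to a repository; the repository ID, key name and key content are
// placeholder values.
func exampleAddReadOnlyDeployKey(ctx context.Context, repoID int64) (*DeployKey, error) {
	const content = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAICV0MGX/W9IvLA4FXpIuUcdDcbj5KX4syHgsTy7soVgf deploy@example"
	return AddDeployKey(ctx, repoID, "example-deploy-key", content, true /* readOnly */)
}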
// GetDeployKeyByID returns deploy key by given ID.
func GetDeployKeyByID(ctx context.Context, id int64) (*DeployKey, error) {
key, exist, err := db.GetByID[DeployKey](ctx, id)
if err != nil {
return nil, err
} else if !exist {
return nil, ErrDeployKeyNotExist{id, 0, 0}
}
return key, nil
}
// GetDeployKeyByRepo returns deploy key by given public key ID and repository ID.
func GetDeployKeyByRepo(ctx context.Context, keyID, repoID int64) (*DeployKey, error) {
key, exist, err := db.Get[DeployKey](ctx, builder.Eq{"key_id": keyID, "repo_id": repoID})
if err != nil {
return nil, err
} else if !exist {
return nil, ErrDeployKeyNotExist{0, keyID, repoID}
}
return key, nil
}
// IsDeployKeyExistByKeyID returns true if there is at least one deploy key with the given key ID
func IsDeployKeyExistByKeyID(ctx context.Context, keyID int64) (bool, error) {
return db.GetEngine(ctx).
Where("key_id = ?", keyID).
Get(new(DeployKey))
}
// UpdateDeployKeyCols updates deploy key information in the specified columns.
func UpdateDeployKeyCols(ctx context.Context, key *DeployKey, cols ...string) error {
_, err := db.GetEngine(ctx).ID(key.ID).Cols(cols...).Update(key)
return err
}
// ListDeployKeysOptions are options for ListDeployKeys
type ListDeployKeysOptions struct {
db.ListOptions
RepoID int64
KeyID int64
Fingerprint string
}
func (opt ListDeployKeysOptions) ToConds() builder.Cond {
cond := builder.NewCond()
if opt.RepoID != 0 {
cond = cond.And(builder.Eq{"repo_id": opt.RepoID})
}
if opt.KeyID != 0 {
cond = cond.And(builder.Eq{"key_id": opt.KeyID})
}
if opt.Fingerprint != "" {
cond = cond.And(builder.Eq{"fingerprint": opt.Fingerprint})
}
return cond
}

View File

@@ -0,0 +1,89 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"context"
"fmt"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/process"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"
"golang.org/x/crypto/ssh"
"xorm.io/builder"
)
// ___________.__ .__ __
// \_ _____/|__| ____ ____ ________________________|__| _____/ |_
// | __) | |/ \ / ___\_/ __ \_ __ \____ \_ __ \ |/ \ __\
// | \ | | | \/ /_/ > ___/| | \/ |_> > | \/ | | \ |
// \___ / |__|___| /\___ / \___ >__| | __/|__| |__|___| /__|
// \/ \//_____/ \/ |__| \/
//
// This file contains functions for fingerprinting SSH keys
//
// The database is used in checkKeyFingerprint; however, most of these functions probably belong in a module
// checkKeyFingerprint only checks whether the key fingerprint has already been used as a public key;
// it is OK to use the same key as a deploy key for multiple repositories/users.
func checkKeyFingerprint(ctx context.Context, fingerprint string) error {
has, err := db.Exist[PublicKey](ctx, builder.Eq{"fingerprint": fingerprint})
if err != nil {
return err
} else if has {
return ErrKeyAlreadyExist{0, fingerprint, ""}
}
return nil
}
func calcFingerprintSSHKeygen(publicKeyContent string) (string, error) {
// Calculate fingerprint.
tmpPath, err := writeTmpKeyFile(publicKeyContent)
if err != nil {
return "", err
}
defer func() {
if err := util.Remove(tmpPath); err != nil {
log.Warn("Unable to remove temporary key file: %s: Error: %v", tmpPath, err)
}
}()
stdout, stderr, err := process.GetManager().Exec("AddPublicKey", "ssh-keygen", "-lf", tmpPath)
if err != nil {
if strings.Contains(stderr, "is not a public key file") {
return "", ErrKeyUnableVerify{stderr}
}
return "", util.NewInvalidArgumentErrorf("'ssh-keygen -lf %s' failed with error '%s': %s", tmpPath, err, stderr)
} else if len(stdout) < 2 {
return "", util.NewInvalidArgumentErrorf("not enough output for calculating fingerprint: %s", stdout)
}
return strings.Split(stdout, " ")[1], nil
}
func calcFingerprintNative(publicKeyContent string) (string, error) {
// Calculate fingerprint.
pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(publicKeyContent))
if err != nil {
return "", err
}
return ssh.FingerprintSHA256(pk), nil
}
// CalcFingerprint calculates the public key's fingerprint
func CalcFingerprint(publicKeyContent string) (string, error) {
// Call the method based on configuration
useNative := setting.SSH.KeygenPath == ""
calcFn := util.Iif(useNative, calcFingerprintNative, calcFingerprintSSHKeygen)
fp, err := calcFn(publicKeyContent)
if err != nil {
if IsErrKeyUnableVerify(err) {
return "", err
}
return "", fmt.Errorf("CalcFingerprint(%s): %w", util.Iif(useNative, "native", "ssh-keygen"), err)
}
return fp, nil
}
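// exampleFingerprint is an illustrative sketch showing the expected shape of a
// calculated fingerprint; the key content is a placeholder test value and the
// logged value has the form "SHA256:<43 base64 characters>".
func exampleFingerprint() {
	const content = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAICV0MGX/W9IvLA4FXpIuUcdDcbj5KX4syHgsTy7soVgf"
	if fp, err := CalcFingerprint(content); err == nil {
		log.Trace("example fingerprint: %s", fp) // exact value depends on the key
	}
}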

View File

@@ -0,0 +1,312 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"crypto/rsa"
"crypto/x509"
"encoding/asn1"
"encoding/base64"
"encoding/binary"
"encoding/pem"
"fmt"
"math/big"
"os"
"strconv"
"strings"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/process"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"
"golang.org/x/crypto/ssh"
)
// ____ __. __________
// | |/ _|____ ___.__. \______ \_____ _______ ______ ___________
// | <_/ __ < | | | ___/\__ \\_ __ \/ ___// __ \_ __ \
// | | \ ___/\___ | | | / __ \| | \/\___ \\ ___/| | \/
// |____|__ \___ > ____| |____| (____ /__| /____ >\___ >__|
// \/ \/\/ \/ \/ \/
//
// This file contains functions for parsing ssh-keys
//
// TODO: Consider if these functions belong in models - no other models function call them or are called by them
// They may belong in a service or a module
const ssh2keyStart = "---- BEGIN SSH2 PUBLIC KEY ----"
func extractTypeFromBase64Key(key string) (string, error) {
b, err := base64.StdEncoding.DecodeString(key)
if err != nil || len(b) < 4 {
return "", fmt.Errorf("invalid key format: %w", err)
}
keyLength := int(binary.BigEndian.Uint32(b))
if len(b) < 4+keyLength {
return "", fmt.Errorf("invalid key format: not enough length %d", keyLength)
}
return string(b[4 : 4+keyLength]), nil
}
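// exampleExtractKeyType is an illustrative sketch: the base64 payload of an
// OpenSSH public key starts with a uint32 length prefix followed by the
// algorithm name, so for the ed25519 blob below (a test key reused from the
// test suite) extractTypeFromBase64Key returns "ssh-ed25519".
func exampleExtractKeyType() (string, error) {
	const blob = "AAAAC3NzaC1lZDI1NTE5AAAAICV0MGX/W9IvLA4FXpIuUcdDcbj5KX4syHgsTy7soVgf"
	return extractTypeFromBase64Key(blob)
}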
// parseKeyString parses any key string in OpenSSH or SSH2 format into a clean OpenSSH string (RFC4253).
func parseKeyString(content string) (string, error) {
// remove whitespace at start and end
content = strings.TrimSpace(content)
var keyType, keyContent, keyComment string
if strings.HasPrefix(content, ssh2keyStart) {
// Parse SSH2 file format.
// Transform all legal line endings to a single "\n".
content = strings.NewReplacer("\r\n", "\n", "\r", "\n").Replace(content)
lines := strings.Split(content, "\n")
continuationLine := false
for _, line := range lines {
// Skip lines that:
// 1) are a continuation of the previous line,
// 2) contain ":", as those are comment lines
// 3) contain "-", as those are begin and end tags
if continuationLine || strings.ContainsAny(line, ":-") {
continuationLine = strings.HasSuffix(line, "\\")
} else {
keyContent += line
}
}
t, err := extractTypeFromBase64Key(keyContent)
if err != nil {
return "", fmt.Errorf("extractTypeFromBase64Key: %w", err)
}
keyType = t
} else {
if strings.Contains(content, "-----BEGIN") {
// Convert PEM Keys to OpenSSH format
// Transform all legal line endings to a single "\n".
content = strings.NewReplacer("\r\n", "\n", "\r", "\n").Replace(content)
block, _ := pem.Decode([]byte(content))
if block == nil {
return "", fmt.Errorf("failed to parse PEM block containing the public key")
}
if strings.Contains(block.Type, "PRIVATE") {
return "", ErrKeyIsPrivate
}
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
if err != nil {
var pk rsa.PublicKey
_, err2 := asn1.Unmarshal(block.Bytes, &pk)
if err2 != nil {
return "", fmt.Errorf("failed to parse DER encoded public key as either PKIX or PEM RSA Key: %v %w", err, err2)
}
pub = &pk
}
sshKey, err := ssh.NewPublicKey(pub)
if err != nil {
return "", fmt.Errorf("unable to convert to ssh public key: %w", err)
}
content = string(ssh.MarshalAuthorizedKey(sshKey))
}
// Parse OpenSSH format.
// Remove all newlines
content = strings.NewReplacer("\r\n", "", "\n", "").Replace(content)
parts := strings.SplitN(content, " ", 3)
switch len(parts) {
case 0:
return "", util.NewInvalidArgumentErrorf("empty key")
case 1:
keyContent = parts[0]
case 2:
keyType = parts[0]
keyContent = parts[1]
default:
keyType = parts[0]
keyContent = parts[1]
keyComment = parts[2]
}
// If keyType is not given, extract it from content. If given, validate it.
t, err := extractTypeFromBase64Key(keyContent)
if err != nil {
return "", fmt.Errorf("extractTypeFromBase64Key: %w", err)
}
if len(keyType) == 0 {
keyType = t
} else if keyType != t {
return "", fmt.Errorf("key type and content do not match: %s - %s", keyType, t)
}
}
// Finally we need to check whether we can actually read the proposed key:
_, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keyType + " " + keyContent + " " + keyComment))
if err != nil {
return "", fmt.Errorf("invalid ssh public key: %w", err)
}
return keyType + " " + keyContent + " " + keyComment, nil
}
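// exampleParseSSH2Key is an illustrative sketch: an RFC4716 / SSH2-formatted
// block is normalized by parseKeyString into a single OpenSSH-style line such
// as "ssh-ed25519 AAAA... "; the comment header of the SSH2 block is dropped in
// the process. The key material is a placeholder test value.
func exampleParseSSH2Key() (string, error) {
	const ssh2 = `---- BEGIN SSH2 PUBLIC KEY ----
Comment: "example key"
AAAAC3NzaC1lZDI1NTE5AAAAICV0MGX/W9IvLA4FXpIuUcdDcbj5KX4syHgsTy7soVgf
---- END SSH2 PUBLIC KEY ----`
	return parseKeyString(ssh2)
}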
// CheckPublicKeyString checks if the given public key string is recognized by SSH.
// It returns the actual public key line on success.
func CheckPublicKeyString(content string) (_ string, err error) {
content, err = parseKeyString(content)
if err != nil {
return "", err
}
content = strings.TrimRight(content, "\n\r")
if strings.ContainsAny(content, "\n\r") {
return "", util.NewInvalidArgumentErrorf("only a single line with a single key please")
}
// remove any unnecessary whitespace now
content = strings.TrimSpace(content)
if !setting.SSH.MinimumKeySizeCheck {
return content, nil
}
var (
fnName string
keyType string
length int
)
if len(setting.SSH.KeygenPath) == 0 {
fnName = "SSHNativeParsePublicKey"
keyType, length, err = SSHNativeParsePublicKey(content)
} else {
fnName = "SSHKeyGenParsePublicKey"
keyType, length, err = SSHKeyGenParsePublicKey(content)
}
if err != nil {
return "", fmt.Errorf("%s: %w", fnName, err)
}
log.Trace("Key info [native: %v]: %s-%d", setting.SSH.StartBuiltinServer, keyType, length)
if minLen, found := setting.SSH.MinimumKeySizes[keyType]; found && length >= minLen {
return content, nil
} else if found && length < minLen {
return "", fmt.Errorf("key length is not enough: got %d, needs %d", length, minLen)
}
return "", fmt.Errorf("key type is not allowed: %s", keyType)
}
// SSHNativeParsePublicKey extracts the key type and length using the golang SSH library.
func SSHNativeParsePublicKey(keyLine string) (string, int, error) {
fields := strings.Fields(keyLine)
if len(fields) < 2 {
return "", 0, fmt.Errorf("not enough fields in public key line: %s", keyLine)
}
raw, err := base64.StdEncoding.DecodeString(fields[1])
if err != nil {
return "", 0, err
}
pkey, err := ssh.ParsePublicKey(raw)
if err != nil {
if strings.Contains(err.Error(), "ssh: unknown key algorithm") {
return "", 0, ErrKeyUnableVerify{err.Error()}
}
return "", 0, fmt.Errorf("ParsePublicKey: %w", err)
}
// The ssh library can parse the key, so next we find out what key exactly we have.
switch pkey.Type() {
case ssh.KeyAlgoDSA:
rawPub := struct {
Name string
P, Q, G, Y *big.Int
}{}
if err := ssh.Unmarshal(pkey.Marshal(), &rawPub); err != nil {
return "", 0, err
}
// as per https://bugzilla.mindrot.org/show_bug.cgi?id=1647 we should never
// see dsa keys != 1024 bit, but as it seems to work, we will not check here
return "dsa", rawPub.P.BitLen(), nil // use P as per crypto/dsa/dsa.go (is L)
case ssh.KeyAlgoRSA:
rawPub := struct {
Name string
E *big.Int
N *big.Int
}{}
if err := ssh.Unmarshal(pkey.Marshal(), &rawPub); err != nil {
return "", 0, err
}
return "rsa", rawPub.N.BitLen(), nil // use N as per crypto/rsa/rsa.go (is bits)
case ssh.KeyAlgoECDSA256:
return "ecdsa", 256, nil
case ssh.KeyAlgoECDSA384:
return "ecdsa", 384, nil
case ssh.KeyAlgoECDSA521:
return "ecdsa", 521, nil
case ssh.KeyAlgoED25519:
return "ed25519", 256, nil
case ssh.KeyAlgoSKECDSA256:
return "ecdsa-sk", 256, nil
case ssh.KeyAlgoSKED25519:
return "ed25519-sk", 256, nil
}
return "", 0, fmt.Errorf("unsupported key length detection for type: %s", pkey.Type())
}
// writeTmpKeyFile writes key content to a temporary file
// and returns the name of that file, along with any possible errors.
func writeTmpKeyFile(content string) (string, error) {
tmpFile, err := os.CreateTemp(setting.SSH.KeyTestPath, "gitea_keytest")
if err != nil {
return "", fmt.Errorf("TempFile: %w", err)
}
defer tmpFile.Close()
if _, err = tmpFile.WriteString(content); err != nil {
return "", fmt.Errorf("WriteString: %w", err)
}
return tmpFile.Name(), nil
}
// SSHKeyGenParsePublicKey extracts key type and length using ssh-keygen.
func SSHKeyGenParsePublicKey(key string) (string, int, error) {
tmpName, err := writeTmpKeyFile(key)
if err != nil {
return "", 0, fmt.Errorf("writeTmpKeyFile: %w", err)
}
defer func() {
if err := util.Remove(tmpName); err != nil {
log.Warn("Unable to remove temporary key file: %s: Error: %v", tmpName, err)
}
}()
keygenPath := setting.SSH.KeygenPath
if len(keygenPath) == 0 {
keygenPath = "ssh-keygen"
}
stdout, stderr, err := process.GetManager().Exec("SSHKeyGenParsePublicKey", keygenPath, "-lf", tmpName)
if err != nil {
return "", 0, fmt.Errorf("fail to parse public key: %s - %s", err, stderr)
}
if strings.Contains(stdout, "is not a public key file") {
return "", 0, ErrKeyUnableVerify{stdout}
}
fields := strings.Split(stdout, " ")
if len(fields) < 4 {
return "", 0, fmt.Errorf("invalid public key line: %s", stdout)
}
keyType := strings.Trim(fields[len(fields)-1], "()\r\n")
length, err := strconv.ParseInt(fields[0], 10, 32)
if err != nil {
return "", 0, err
}
return strings.ToLower(keyType), int(length), nil
}
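// For illustration (an assumption about typical ssh-keygen output, not taken
// from this file): `ssh-keygen -lf` prints a line of the form
//
//	256 SHA256:aBcD... comment (ED25519)
//
// from which the code above takes fields[0] as the bit length and the trailing
// parenthesised token, trimmed and lower-cased, as the key type.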

View File

@@ -0,0 +1,56 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"context"
"fmt"
"strings"
"code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"
)
// CheckPrincipalKeyString strips spaces and returns an error if the given principal contains newlines
func CheckPrincipalKeyString(ctx context.Context, user *user_model.User, content string) (_ string, err error) {
if setting.SSH.Disabled {
return "", db.ErrSSHDisabled{}
}
content = strings.TrimSpace(content)
if strings.ContainsAny(content, "\r\n") {
return "", util.NewInvalidArgumentErrorf("only a single line with a single principal please")
}
// check all the allowed principals, email, username or anything
// if any matches, return ok
for _, v := range setting.SSH.AuthorizedPrincipalsAllow {
switch v {
case "anything":
return content, nil
case "email":
emails, err := user_model.GetEmailAddresses(ctx, user.ID)
if err != nil {
return "", err
}
for _, email := range emails {
if !email.IsActivated {
continue
}
if content == email.Email {
return content, nil
}
}
case "username":
if content == user.Name {
return content, nil
}
}
}
return "", fmt.Errorf("didn't match allowed principals: %s", setting.SSH.AuthorizedPrincipalsAllow)
}
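// For illustration (a hypothetical configuration, not taken from this file):
// with setting.SSH.AuthorizedPrincipalsAllow set to ["email", "username"], a
// principal is accepted only if it equals one of the user's activated e-mail
// addresses or the user name; with ["anything"] every single-line principal is
// accepted as-is.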

View File

@@ -0,0 +1,513 @@
// Copyright 2016 The Gogs Authors. All rights reserved.
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"bytes"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"code.gitea.io/gitea/modules/setting"
"github.com/42wim/sshsig"
"github.com/stretchr/testify/assert"
)
func Test_SSHParsePublicKey(t *testing.T) {
testCases := []struct {
name string
skipSSHKeygen bool
keyType string
length int
content string
}{
{"rsa-1024", false, "rsa", 1024, "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAAgQDAu7tvIvX6ZHrRXuZNfkR3XLHSsuCK9Zn3X58lxBcQzuo5xZgB6vRwwm/QtJuF+zZPtY5hsQILBLmF+BZ5WpKZp1jBeSjH2G7lxet9kbcH+kIVj0tPFEoyKI9wvWqIwC4prx/WVk2wLTJjzBAhyNxfEq7C9CeiX9pQEbEqJfkKCQ== nocomment\n"},
{"rsa-2048", false, "rsa", 2048, "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDMZXh+1OBUwSH9D45wTaxErQIN9IoC9xl7MKJkqvTvv6O5RR9YW/IK9FbfjXgXsppYGhsCZo1hFOOsXHMnfOORqu/xMDx4yPuyvKpw4LePEcg4TDipaDFuxbWOqc/BUZRZcXu41QAWfDLrInwsltWZHSeG7hjhpacl4FrVv9V1pS6Oc5Q1NxxEzTzuNLS/8diZrTm/YAQQ/+B+mzWI3zEtF4miZjjAljWd1LTBPvU23d29DcBmmFahcZ441XZsTeAwGxG/Q6j8NgNXj9WxMeWwxXV2jeAX/EBSpZrCVlCQ1yJswT6xCp8TuBnTiGWYMBNTbOZvPC4e0WI2/yZW/s5F nocomment"},
{"ecdsa-256", false, "ecdsa", 256, "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBFQacN3PrOll7PXmN5B/ZNVahiUIqI05nbBlZk1KXsO3d06ktAWqbNflv2vEmA38bTFTfJ2sbn2B5ksT52cDDbA= nocomment"},
{"ecdsa-384", false, "ecdsa", 384, "ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBINmioV+XRX1Fm9Qk2ehHXJ2tfVxW30ypUWZw670Zyq5GQfBAH6xjygRsJ5wWsHXBsGYgFUXIHvMKVAG1tpw7s6ax9oA+dJOJ7tj+vhn8joFqT+sg3LYHgZkHrfqryRasQ== nocomment"},
{"ecdsa-sk", true, "ecdsa-sk", 256, "sk-ecdsa-sha2-nistp256@openssh.com AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3BlbnNzaC5jb20AAAAIbmlzdHAyNTYAAABBBGXEEzWmm1dxb+57RoK5KVCL0w2eNv9cqJX2AGGVlkFsVDhOXHzsadS3LTK4VlEbbrDMJdoti9yM8vclA8IeRacAAAAEc3NoOg== nocomment"},
{"ed25519-sk", true, "ed25519-sk", 256, "sk-ssh-ed25519@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5QG9wZW5zc2guY29tAAAAIE7kM1R02+4ertDKGKEDcKG0s+2vyDDcIvceJ0Gqv5f1AAAABHNzaDo= nocomment"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Run("Native", func(t *testing.T) {
keyTypeN, lengthN, err := SSHNativeParsePublicKey(tc.content)
assert.NoError(t, err)
assert.Equal(t, tc.keyType, keyTypeN)
assert.EqualValues(t, tc.length, lengthN)
})
if tc.skipSSHKeygen {
return
}
t.Run("SSHKeygen", func(t *testing.T) {
keyTypeK, lengthK, err := SSHKeyGenParsePublicKey(tc.content)
if err != nil {
// Some servers do not support ecdsa format.
if !strings.Contains(err.Error(), "line 1 too long:") {
assert.FailNow(t, "%v", err)
}
}
assert.Equal(t, tc.keyType, keyTypeK)
assert.EqualValues(t, tc.length, lengthK)
})
t.Run("SSHParseKeyNative", func(t *testing.T) {
keyTypeK, lengthK, err := SSHNativeParsePublicKey(tc.content)
if err != nil {
assert.FailNow(t, "%v", err)
}
assert.Equal(t, tc.keyType, keyTypeK)
assert.EqualValues(t, tc.length, lengthK)
})
})
}
}
func Test_CheckPublicKeyString(t *testing.T) {
oldValue := setting.SSH.MinimumKeySizeCheck
setting.SSH.MinimumKeySizeCheck = false
for _, test := range []struct {
content string
}{
{"ssh-dss AAAAB3NzaC1kc3MAAACBAOChCC7lf6Uo9n7BmZ6M8St19PZf4Tn59NriyboW2x/DZuYAz3ibZ2OkQ3S0SqDIa0HXSEJ1zaExQdmbO+Ux/wsytWZmCczWOVsaszBZSl90q8UnWlSH6P+/YA+RWJm5SFtuV9PtGIhyZgoNuz5kBQ7K139wuQsecdKktISwTakzAAAAFQCzKsO2JhNKlL+wwwLGOcLffoAmkwAAAIBpK7/3xvduajLBD/9vASqBQIHrgK2J+wiQnIb/Wzy0UsVmvfn8A+udRbBo+csM8xrSnlnlJnjkJS3qiM5g+eTwsLIV1IdKPEwmwB+VcP53Cw6lSyWyJcvhFb0N6s08NZysLzvj0N+ZC/FnhKTLzIyMtkHf/IrPCwlM+pV/M/96YgAAAIEAqQcGn9CKgzgPaguIZooTAOQdvBLMI5y0bQjOW6734XOpqQGf/Kra90wpoasLKZjSYKNPjE+FRUOrStLrxcNs4BeVKhy2PYTRnybfYVk1/dmKgH6P1YSRONsGKvTsH6c5IyCRG0ncCgYeF8tXppyd642982daopE7zQ/NPAnJfag= nocomment"},
{"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAAgQDAu7tvIvX6ZHrRXuZNfkR3XLHSsuCK9Zn3X58lxBcQzuo5xZgB6vRwwm/QtJuF+zZPtY5hsQILBLmF+BZ5WpKZp1jBeSjH2G7lxet9kbcH+kIVj0tPFEoyKI9wvWqIwC4prx/WVk2wLTJjzBAhyNxfEq7C9CeiX9pQEbEqJfkKCQ== nocomment\n"},
{"ssh-rsa AAAAB3NzaC1yc2EA\r\nAAADAQABAAAAgQDAu7tvIvX6ZHrRXuZNfkR3XLHSsuCK9Zn3X58lxBcQzuo5xZgB6vRwwm/QtJuF+zZPtY5hsQILBLmF+\r\nBZ5WpKZp1jBeSjH2G7lxet9kbcH+kIVj0tPFEoyKI9wvWqIwC4prx/WVk2wLTJjzBAhyNx\r\nfEq7C9CeiX9pQEbEqJfkKCQ== nocomment\r\n\r\n"},
{"ssh-rsa AAAAB3NzaC1yc2EA\r\nAAADAQABAAAAgQDAu7tvI\nvX6ZHrRXuZNfkR3XLHSsuCK9Zn3X58lxBcQzuo5xZgB6vRwwm/QtJuF+zZPtY5hsQILBLmF+\r\nBZ5WpKZp1jBeSjH2G7lxet9kbcH+kIVj0tPFEoyKI9wvW\nqIwC4prx/WVk2wLTJjzBAhyNx\r\nfEq7C9CeiX9pQEbEqJfkKCQ== nocomment\r\n\r\n"},
{"ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAICV0MGX/W9IvLA4FXpIuUcdDcbj5KX4syHgsTy7soVgf"},
{"\r\nssh-ed25519 \r\nAAAAC3NzaC1lZDI1NTE5AAAAICV0MGX/W9IvLA4FXpIuUcdDcbj5KX4syHgsTy7soVgf\r\n\r\n"},
{"sk-ecdsa-sha2-nistp256@openssh.com AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3BlbnNzaC5jb20AAAAIbmlzdHAyNTYAAABBBGXEEzWmm1dxb+57RoK5KVCL0w2eNv9cqJX2AGGVlkFsVDhOXHzsadS3LTK4VlEbbrDMJdoti9yM8vclA8IeRacAAAAEc3NoOg== nocomment"},
{"sk-ssh-ed25519@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5QG9wZW5zc2guY29tAAAAIE7kM1R02+4ertDKGKEDcKG0s+2vyDDcIvceJ0Gqv5f1AAAABHNzaDo= nocomment"},
{`---- BEGIN SSH2 PUBLIC KEY ----
Comment: "1024-bit DSA, converted by andrew@phaedra from OpenSSH"
AAAAB3NzaC1kc3MAAACBAOChCC7lf6Uo9n7BmZ6M8St19PZf4Tn59NriyboW2x/DZuYAz3
ibZ2OkQ3S0SqDIa0HXSEJ1zaExQdmbO+Ux/wsytWZmCczWOVsaszBZSl90q8UnWlSH6P+/
YA+RWJm5SFtuV9PtGIhyZgoNuz5kBQ7K139wuQsecdKktISwTakzAAAAFQCzKsO2JhNKlL
+wwwLGOcLffoAmkwAAAIBpK7/3xvduajLBD/9vASqBQIHrgK2J+wiQnIb/Wzy0UsVmvfn8
A+udRbBo+csM8xrSnlnlJnjkJS3qiM5g+eTwsLIV1IdKPEwmwB+VcP53Cw6lSyWyJcvhFb
0N6s08NZysLzvj0N+ZC/FnhKTLzIyMtkHf/IrPCwlM+pV/M/96YgAAAIEAqQcGn9CKgzgP
aguIZooTAOQdvBLMI5y0bQjOW6734XOpqQGf/Kra90wpoasLKZjSYKNPjE+FRUOrStLrxc
Ns4BeVKhy2PYTRnybfYVk1/dmKgH6P1YSRONsGKvTsH6c5IyCRG0ncCgYeF8tXppyd6429
82daopE7zQ/NPAnJfag=
---- END SSH2 PUBLIC KEY ----
`},
{`---- BEGIN SSH2 PUBLIC KEY ----
Comment: "1024-bit RSA, converted by andrew@phaedra from OpenSSH"
AAAAB3NzaC1yc2EAAAADAQABAAAAgQDAu7tvIvX6ZHrRXuZNfkR3XLHSsuCK9Zn3X58lxB
cQzuo5xZgB6vRwwm/QtJuF+zZPtY5hsQILBLmF+BZ5WpKZp1jBeSjH2G7lxet9kbcH+kIV
j0tPFEoyKI9wvWqIwC4prx/WVk2wLTJjzBAhyNxfEq7C9CeiX9pQEbEqJfkKCQ==
---- END SSH2 PUBLIC KEY ----
`},
{`-----BEGIN RSA PUBLIC KEY-----
MIGJAoGBAMC7u28i9fpketFe5k1+RHdcsdKy4Ir1mfdfnyXEFxDO6jnFmAHq9HDC
b9C0m4X7Nk+1jmGxAgsEuYX4FnlakpmnWMF5KMfYbuXF632Rtwf6QhWPS08USjIo
j3C9aojALimvH9ZWTbAtMmPMECHI3F8SrsL0J6Jf2lARsSol+QoJAgMBAAE=
-----END RSA PUBLIC KEY-----
`},
{`-----BEGIN PUBLIC KEY-----
MIIBtzCCASsGByqGSM44BAEwggEeAoGBAOChCC7lf6Uo9n7BmZ6M8St19PZf4Tn5
9NriyboW2x/DZuYAz3ibZ2OkQ3S0SqDIa0HXSEJ1zaExQdmbO+Ux/wsytWZmCczW
OVsaszBZSl90q8UnWlSH6P+/YA+RWJm5SFtuV9PtGIhyZgoNuz5kBQ7K139wuQse
cdKktISwTakzAhUAsyrDtiYTSpS/sMMCxjnC336AJpMCgYBpK7/3xvduajLBD/9v
ASqBQIHrgK2J+wiQnIb/Wzy0UsVmvfn8A+udRbBo+csM8xrSnlnlJnjkJS3qiM5g
+eTwsLIV1IdKPEwmwB+VcP53Cw6lSyWyJcvhFb0N6s08NZysLzvj0N+ZC/FnhKTL
zIyMtkHf/IrPCwlM+pV/M/96YgOBhQACgYEAqQcGn9CKgzgPaguIZooTAOQdvBLM
I5y0bQjOW6734XOpqQGf/Kra90wpoasLKZjSYKNPjE+FRUOrStLrxcNs4BeVKhy2
PYTRnybfYVk1/dmKgH6P1YSRONsGKvTsH6c5IyCRG0ncCgYeF8tXppyd642982da
opE7zQ/NPAnJfag=
-----END PUBLIC KEY-----
`},
{`-----BEGIN PUBLIC KEY-----
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDAu7tvIvX6ZHrRXuZNfkR3XLHS
suCK9Zn3X58lxBcQzuo5xZgB6vRwwm/QtJuF+zZPtY5hsQILBLmF+BZ5WpKZp1jB
eSjH2G7lxet9kbcH+kIVj0tPFEoyKI9wvWqIwC4prx/WVk2wLTJjzBAhyNxfEq7C
9CeiX9pQEbEqJfkKCQIDAQAB
-----END PUBLIC KEY-----
`},
{`-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzGV4ftTgVMEh/Q+OcE2s
RK0CDfSKAvcZezCiZKr077+juUUfWFvyCvRW3414F7KaWBobAmaNYRTjrFxzJ3zj
karv8TA8eMj7sryqcOC3jxHIOEw4qWgxbsW1jqnPwVGUWXF7uNUAFnwy6yJ8LJbV
mR0nhu4Y4aWnJeBa1b/VdaUujnOUNTccRM087jS0v/HYma05v2AEEP/gfps1iN8x
LReJomY4wJY1ndS0wT71Nt3dvQ3AZphWoXGeONV2bE3gMBsRv0Oo/DYDV4/VsTHl
sMV1do3gF/xAUqWawlZQkNcibME+sQqfE7gZ04hlmDATU2zmbzwuHtFiNv8mVv7O
RQIDAQAB
-----END PUBLIC KEY-----
`},
{`---- BEGIN SSH2 PUBLIC KEY ----
Comment: "256-bit ED25519, converted by andrew@phaedra from OpenSSH"
AAAAC3NzaC1lZDI1NTE5AAAAICV0MGX/W9IvLA4FXpIuUcdDcbj5KX4syHgsTy7soVgf
---- END SSH2 PUBLIC KEY ----
`},
} {
_, err := CheckPublicKeyString(test.content)
assert.NoError(t, err)
}
setting.SSH.MinimumKeySizeCheck = oldValue
for _, invalidKeys := range []struct {
content string
}{
{"test"},
{"---- NOT A REAL KEY ----"},
{"bad\nkey"},
{"\t\t:)\t\r\n"},
{"\r\ntest \r\ngitea\r\n\r\n"},
} {
_, err := CheckPublicKeyString(invalidKeys.content)
assert.Error(t, err)
}
}
func Test_calcFingerprint(t *testing.T) {
testCases := []struct {
name string
skipSSHKeygen bool
fp string
content string
}{
{"rsa-1024", false, "SHA256:vSnDkvRh/xM6kMxPidLgrUhq3mCN7CDaronCEm2joyQ", "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAAAgQDAu7tvIvX6ZHrRXuZNfkR3XLHSsuCK9Zn3X58lxBcQzuo5xZgB6vRwwm/QtJuF+zZPtY5hsQILBLmF+BZ5WpKZp1jBeSjH2G7lxet9kbcH+kIVj0tPFEoyKI9wvWqIwC4prx/WVk2wLTJjzBAhyNxfEq7C9CeiX9pQEbEqJfkKCQ== nocomment\n"},
{"rsa-2048", false, "SHA256:ZHD//a1b9VuTq9XSunAeYjKeU1xDa2tBFZYrFr2Okkg", "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDMZXh+1OBUwSH9D45wTaxErQIN9IoC9xl7MKJkqvTvv6O5RR9YW/IK9FbfjXgXsppYGhsCZo1hFOOsXHMnfOORqu/xMDx4yPuyvKpw4LePEcg4TDipaDFuxbWOqc/BUZRZcXu41QAWfDLrInwsltWZHSeG7hjhpacl4FrVv9V1pS6Oc5Q1NxxEzTzuNLS/8diZrTm/YAQQ/+B+mzWI3zEtF4miZjjAljWd1LTBPvU23d29DcBmmFahcZ441XZsTeAwGxG/Q6j8NgNXj9WxMeWwxXV2jeAX/EBSpZrCVlCQ1yJswT6xCp8TuBnTiGWYMBNTbOZvPC4e0WI2/yZW/s5F nocomment"},
{"ecdsa-256", false, "SHA256:Bqx/xgWqRKLtkZ0Lr4iZpgb+5lYsFpSwXwVZbPwuTRw", "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBFQacN3PrOll7PXmN5B/ZNVahiUIqI05nbBlZk1KXsO3d06ktAWqbNflv2vEmA38bTFTfJ2sbn2B5ksT52cDDbA= nocomment"},
{"ecdsa-384", false, "SHA256:4qfJOgJDtUd8BrEjyVNdI8IgjiZKouztVde43aDhe1E", "ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABhBINmioV+XRX1Fm9Qk2ehHXJ2tfVxW30ypUWZw670Zyq5GQfBAH6xjygRsJ5wWsHXBsGYgFUXIHvMKVAG1tpw7s6ax9oA+dJOJ7tj+vhn8joFqT+sg3LYHgZkHrfqryRasQ== nocomment"},
{"ecdsa-sk", true, "SHA256:4wcIu4z+53gHc+db85OPfy8IydyNzPLCr6kHIs625LQ", "sk-ecdsa-sha2-nistp256@openssh.com AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3BlbnNzaC5jb20AAAAIbmlzdHAyNTYAAABBBGXEEzWmm1dxb+57RoK5KVCL0w2eNv9cqJX2AGGVlkFsVDhOXHzsadS3LTK4VlEbbrDMJdoti9yM8vclA8IeRacAAAAEc3NoOg== nocomment"},
{"ed25519-sk", true, "SHA256:RB4ku1OeWKN7fLMrjxz38DK0mp1BnOPBx4BItjTvJ0g", "sk-ssh-ed25519@openssh.com AAAAGnNrLXNzaC1lZDI1NTE5QG9wZW5zc2guY29tAAAAIE7kM1R02+4ertDKGKEDcKG0s+2vyDDcIvceJ0Gqv5f1AAAABHNzaDo= nocomment"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Run("Native", func(t *testing.T) {
fpN, err := calcFingerprintNative(tc.content)
assert.NoError(t, err)
assert.Equal(t, tc.fp, fpN)
})
if tc.skipSSHKeygen {
return
}
t.Run("SSHKeygen", func(t *testing.T) {
fpK, err := calcFingerprintSSHKeygen(tc.content)
assert.NoError(t, err)
assert.Equal(t, tc.fp, fpK)
})
})
}
}
var (
// Generated with "ssh-keygen -C test@rekor.dev -f id_rsa"
sshPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABlwAAAAdzc2gtcn
NhAAAAAwEAAQAAAYEA16H5ImoRO7mr41r8Z8JFBdu6jIM+6XU8M0r9F81RuhLYqzr9zw1n
LeGCqFxPXNBKm8ZyH2BCsBHsbXbwe85IMHM3SUh8X/9fI0Lpi5/xbqAproFUpNR+UJYv6s
8AaWk5zpN1rmpBrqGFJfGQKJCioDiiwNGmSdVkUNmQmYIANxJMDWYmNe8vUOh6nYEHB+lz
fGgDAAzVSXTACW994UkSY47AD05swU4rIT/JWA6BkUrEhO//F0QQhFeROCPJiPRhJXGcFf
9SicffJqR/ELzM1zNYnRXMD0bbdTUwDrIcIFFNBbtcfJVOUUCGumSlt+qjUC7y8cvwbHAu
wf5nS6baA7P6LfTYplF2XIAkdWtkN6O1ouoyIHICXMlddDW2vNaJeEXTeKjx51WSM7qPnQ
ZKsBtwjLQeEY/OPkIvu88lNNYSD63qMUA12msohjwVFCIgJVvYLIrkViczZ7t3L7lgy1X0
CJI4e1roOfM/r9jTieyDHchEYpZYcw3L1R2qtePlAAAFiHdJQKl3SUCpAAAAB3NzaC1yc2
EAAAGBANeh+SJqETu5q+Na/GfCRQXbuoyDPul1PDNK/RfNUboS2Ks6/c8NZy3hgqhcT1zQ
SpvGch9gQrAR7G128HvOSDBzN0lIfF//XyNC6Yuf8W6gKa6BVKTUflCWL+rPAGlpOc6Tda
5qQa6hhSXxkCiQoqA4osDRpknVZFDZkJmCADcSTA1mJjXvL1Doep2BBwfpc3xoAwAM1Ul0
wAlvfeFJEmOOwA9ObMFOKyE/yVgOgZFKxITv/xdEEIRXkTgjyYj0YSVxnBX/UonH3yakfx
C8zNczWJ0VzA9G23U1MA6yHCBRTQW7XHyVTlFAhrpkpbfqo1Au8vHL8GxwLsH+Z0um2gOz
+i302KZRdlyAJHVrZDejtaLqMiByAlzJXXQ1trzWiXhF03io8edVkjO6j50GSrAbcIy0Hh
GPzj5CL7vPJTTWEg+t6jFANdprKIY8FRQiICVb2CyK5FYnM2e7dy+5YMtV9AiSOHta6Dnz
P6/Y04nsgx3IRGKWWHMNy9UdqrXj5QAAAAMBAAEAAAGAJyaOcFQnuttUPRxY9ZHNLGofrc
Fqm8KgYoO7/iVWMF2Zn0U/rec2E5t9OIpCEozy7uOR9uZoVUV70sgkk6X5b2qL4C9b/aYF
JQbSFnq8wCQuTTPIJYE7SfBq1Mwuu/TR/RLC7B74u/cxkJkSXnscO9Dso+ussH0hEJjf6y
8yUM1up4Qjbel2gs8i7BPwLdySDkVoPgsWcpbTAyOODGhTAWZ6soy/rD1AEXJeYTGJDtMv
aR+WBihig1TO1g2RWt9bqqiG7PIlljd3ZsjSSU5y3t6ZN/8j5keKD032EtxbZB0WFD3Ar4
FbFwlW+urb2MQ0JyNKOio3nhdjolXYkJa+C6LXdaaml/8BhMR1eLoMe8nS45w76o8mdJWX
wsirB8tvjCLY0QBXgGv/1DTsKu/wEFCW2/Y0e50gF7pHAlYFNmKDcgI9OyORRYhFbV4D82
fI8JLQ42ZJkS/0t6xQma8WC88pbHGEuVSB6CE/p25fyYRX+UPTQ79tWFvLV4kNQAaBAAAA
wEvyd6H8ePyBXImg8JzGxthufB0eXSfZBrabjf6e6bR2ivpJsHmB64gbMkV6MFV7EWYX1B
wYPQxf4gA2Ez7aJvDtfE7uV6pa0WJS3hW1+be8DHEftmLSbTy/TEvDujNb2gqoi7uWQXWJ
yYWZlYO65r1a6HucryQ8+78fTuTRbZALO43vNGz0oXH1hPSddkcbNAhZTsD0rQKNwqVTe5
wl+6Cduy/CQwjHLYrY73MyWy1Vh1LXhAdGMPnWZwGIu/dnkgAAAMEA9KuaoGnfnLQkrjeR
tO4RCRS2quNRvm4L6i4vHgTDsYtoSlR1ujge7SGOOmIPS4XVjZN5zzCOA7+EDVnuz3WWmx
hmkjpG1YxzmJGaWoYdeo3a6UgJtisfMp8eUKqjJT1mhsCliCWtaOQNRoQieDQmgwZzSX/v
ZiGsOIKa6cR37eKvOJSjVrHsAUzdtYrmi8P2gvAUFWyzXobAtpzHcWrwWkOEIm04G0OGXb
J46hfIX3f45E5EKXvFzexGgVOD2I7hAAAAwQDhniYAizfW9YfG7UJWekkl42xMP7Cb8b0W
SindSIuE8bFTukV1yxbmNZp/f0pKvn/DWc2n0I0bwSGZpy8BCY46RKKB2DYQavY/tGcC1N
AynKuvbtWs11A0mTXmq3WwHVXQDozMwJ2nnHpm0UHspPuHqkYpurlP+xoFsocaQ9QwITyp
lL4qHtXBEzaT8okkcGZBHdSx3gk4TzCsEDOP7ZZPLq42lpKMK10zFPTMd0maXtJDYKU/b4
gAATvvPoylyYUAAAAOdGVzdEByZWtvci5kZXYBAgMEBQ==
-----END OPENSSH PRIVATE KEY-----
`
sshPublicKey = `ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDXofkiahE7uavjWvxnwkUF27qMgz7pdTwzSv0XzVG6EtirOv3PDWct4YKoXE9c0EqbxnIfYEKwEextdvB7zkgwczdJSHxf/18jQumLn/FuoCmugVSk1H5Qli/qzwBpaTnOk3WuakGuoYUl8ZAokKKgOKLA0aZJ1WRQ2ZCZggA3EkwNZiY17y9Q6HqdgQcH6XN8aAMADNVJdMAJb33hSRJjjsAPTmzBTishP8lYDoGRSsSE7/8XRBCEV5E4I8mI9GElcZwV/1KJx98mpH8QvMzXM1idFcwPRtt1NTAOshwgUU0Fu1x8lU5RQIa6ZKW36qNQLvLxy/BscC7B/mdLptoDs/ot9NimUXZcgCR1a2Q3o7Wi6jIgcgJcyV10Nba81ol4RdN4qPHnVZIzuo+dBkqwG3CMtB4Rj84+Qi+7zyU01hIPreoxQDXaayiGPBUUIiAlW9gsiuRWJzNnu3cvuWDLVfQIkjh7Wug58z+v2NOJ7IMdyERillhzDcvVHaq14+U= test@rekor.dev
`
// Generated with "ssh-keygen -C other-test@rekor.dev -f id_rsa"
otherSSHPrivateKey = `-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABlwAAAAdzc2gtcn
NhAAAAAwEAAQAAAYEAw/WCSWC9TEvCQOwO+T68EvNa3OSIv1Y0+sT8uSvyjPyEO0+p0t8C
g/zy67vOxiQpU5jN6MItjXAjMmeCm8GKMt6gk+cDoaAev/ZfjuzSL7RayExpmhBleh2X3G
KLkkXF9ABFNchlTqSLOZiEjDoNpbFv16KT1sE6CqW8DjxXQkQk9JK65hLH+BxeWMNCEJVa
Cma4X04aJmC7zJAi5yGeeT0SKVqMohavF90O6XiYFCQHuwXPPyHfocqgudmXnozz+6D6ax
JKZMwQsNp3WKumOjlzWnxBCCB1l2jN6Rag8aJ2277iMFXRwjTL/8jaEsW4KkysDf0GjV2/
iqbr0q5b0arDYbv7CrGBR+uH0wGz/Zog1x5iZANObhZULpDrLVJidEMc27HXBb7PMsNDy7
BGYRB1yc0d0y83p8mUqvOlWSArxn1WnAZO04pAgTrclrhEh4ZXOkn2Sn82eu3DpQ8inkol
Y4IfnhIfbOIeemoUNq1tOUquhow9GLRM6INieHLBAAAFkPPnA1jz5wNYAAAAB3NzaC1yc2
EAAAGBAMP1gklgvUxLwkDsDvk+vBLzWtzkiL9WNPrE/Lkr8oz8hDtPqdLfAoP88uu7zsYk
KVOYzejCLY1wIzJngpvBijLeoJPnA6GgHr/2X47s0i+0WshMaZoQZXodl9xii5JFxfQART
XIZU6kizmYhIw6DaWxb9eik9bBOgqlvA48V0JEJPSSuuYSx/gcXljDQhCVWgpmuF9OGiZg
u8yQIuchnnk9EilajKIWrxfdDul4mBQkB7sFzz8h36HKoLnZl56M8/ug+msSSmTMELDad1
irpjo5c1p8QQggdZdozekWoPGidtu+4jBV0cI0y//I2hLFuCpMrA39Bo1dv4qm69KuW9Gq
w2G7+wqxgUfrh9MBs/2aINceYmQDTm4WVC6Q6y1SYnRDHNux1wW+zzLDQ8uwRmEQdcnNHd
MvN6fJlKrzpVkgK8Z9VpwGTtOKQIE63Ja4RIeGVzpJ9kp/Nnrtw6UPIp5KJWOCH54SH2zi
HnpqFDatbTlKroaMPRi0TOiDYnhywQAAAAMBAAEAAAGAYycx4oEhp55Zz1HijblxnsEmQ8
kbbH1pV04fdm7HTxFis0Qu8PVIp5JxNFiWWunnQ1Z5MgI23G9WT+XST4+RpwXBCLWGv9xu
UsGOPpqUC/FdUiZf9MXBIxYgRjJS3xORA1KzsnAQ2sclb2I+B1pEl4d9yQWJesvQ25xa2H
Utzej/LgWkrk/ogSGRl6ZNImj/421wc0DouGyP+gUgtATt0/jT3LrlmAqUVCXVqssLYH2O
r9JTuGUibBJEW2W/c0lsM0jaHa5bGAdL3nhDuF1Q6KFB87mZoNw8c2znYoTzQ3FyWtIEZI
V/9oWrkS7V6242SKSR9tJoEzK0jtrKC/FZwBiI4hPcwoqY6fZbT1701i/n50xWEfEUOLVm
d6VqNKyAbIaZIPN0qfZuD+xdrHuM3V6k/rgFxGl4XTrp/N4AsruiQs0nRQKNTw3fHE0zPq
UTxSeMvjywRCepxhBFCNh8NHydapclHtEPEGdTVHohL3krJehstPO/IuRyKLfSVtL1AAAA
wQCmGA8k+uW6mway9J3jp8mlMhhp3DCX6DAcvalbA/S5OcqMyiTM3c/HD5OJ6OYFDldcqu
MPEgLRL2HfxL29LsbQSzjyOIrfp5PLJlo70P5lXS8u2QPbo4/KQJmQmsIX18LDyU2zRtNA
C2WfBiHSZV+guLhmHms9S5gQYKt2T5OnY/W0tmnInx9lmFCMC+XKS1iSQ2o433IrtCPQJp
IXZd59OQpO9QjJABgJIDtXxFIXt45qpXduDPJuggrhg81stOwAAADBAPX73u/CY+QUPts+
LV185Z4mZ2y+qu2ZMCAU3BnpHktGZZ1vFN1Xq9o8KdnuPZ+QJRdO8eKMWpySqrIdIbTYLm
9nXmVH0uNECIEAvdU+wgKeR+BSHxCRVuTF4YSygmNadgH/z+oRWLgOblGo2ywFBoXsIAKQ
paNu1MFGRUmhz67+dcpkkBUDRU9loAgBKexMo8D9vkR0YiHLOUjCrtmEZRNm0YRZt0gQhD
ZSD1fOH0fZDcCVNpGP2zqAKos4EGLnkwAAAMEAy/AuLtPKA2u9oCA8e18ZnuQRAi27FBVU
rU2D7bMg1eS0IakG8v0gE9K6WdYzyArY1RoKB3ZklK5VmJ1cOcWc2x3Ejc5jcJgc8cC6lZ
wwjpE8HfWL1kIIYgPdcexqFc+l6MdgH6QMKU3nLg1LsM4v5FEldtk/2dmnw620xnFfstpF
VxSZNdKrYfM/v9o6sRaDRqSfH1dG8BvkUxPznTAF+JDxBENcKXYECcq9f6dcl1w5IEnNTD
Wry/EKQvgvOUjbAAAAFG90aGVyLXRlc3RAcmVrb3IuZGV2AQIDBAUG
-----END OPENSSH PRIVATE KEY-----
`
otherSSHPublicKey = `ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDD9YJJYL1MS8JA7A75PrwS81rc5Ii/VjT6xPy5K/KM/IQ7T6nS3wKD/PLru87GJClTmM3owi2NcCMyZ4KbwYoy3qCT5wOhoB6/9l+O7NIvtFrITGmaEGV6HZfcYouSRcX0AEU1yGVOpIs5mISMOg2lsW/XopPWwToKpbwOPFdCRCT0krrmEsf4HF5Yw0IQlVoKZrhfThomYLvMkCLnIZ55PRIpWoyiFq8X3Q7peJgUJAe7Bc8/Id+hyqC52ZeejPP7oPprEkpkzBCw2ndYq6Y6OXNafEEIIHWXaM3pFqDxonbbvuIwVdHCNMv/yNoSxbgqTKwN/QaNXb+KpuvSrlvRqsNhu/sKsYFH64fTAbP9miDXHmJkA05uFlQukOstUmJ0QxzbsdcFvs8yw0PLsEZhEHXJzR3TLzenyZSq86VZICvGfVacBk7TikCBOtyWuESHhlc6SfZKfzZ67cOlDyKeSiVjgh+eEh9s4h56ahQ2rW05Sq6GjD0YtEzog2J4csE= other-test@rekor.dev
`
// Generated with ssh-keygen -C test@rekor.dev -t ed25519 -f id_ed25519
ed25519PrivateKey = `-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
QyNTUxOQAAACBB45zRHxPPFtabwS3Vd6Lb9vMe+tIHZj2qN5VQ+bgLfQAAAJgyRa3cMkWt
3AAAAAtzc2gtZWQyNTUxOQAAACBB45zRHxPPFtabwS3Vd6Lb9vMe+tIHZj2qN5VQ+bgLfQ
AAAED7y4N/DsVnRQiBZNxEWdsJ9RmbranvtQ3X9jnb6gFed0HjnNEfE88W1pvBLdV3otv2
8x760gdmPao3lVD5uAt9AAAADnRlc3RAcmVrb3IuZGV2AQIDBAUGBw==
-----END OPENSSH PRIVATE KEY-----
`
ed25519PublicKey = `ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIEHjnNEfE88W1pvBLdV3otv28x760gdmPao3lVD5uAt9 test@rekor.dev
`
)
func TestFromOpenSSH(t *testing.T) {
for _, tt := range []struct {
name string
pub string
priv string
}{
{
name: "rsa",
pub: sshPublicKey,
priv: sshPrivateKey,
},
{
name: "ed25519",
pub: ed25519PublicKey,
priv: ed25519PrivateKey,
},
} {
if _, err := exec.LookPath("ssh-keygen"); err != nil {
t.Skip("skip TestFromOpenSSH: missing ssh-keygen in PATH")
}
t.Run(tt.name, func(t *testing.T) {
tt := tt
// Test that a signature from the cli can validate here.
td := t.TempDir()
data := []byte("hello, ssh world")
dataPath := write(t, data, td, "data")
privPath := write(t, []byte(tt.priv), td, "id")
write(t, []byte(tt.pub), td, "id.pub")
sigPath := dataPath + ".sig"
run(t, nil, "ssh-keygen", "-Y", "sign", "-n", "file", "-f", privPath, dataPath)
sigBytes, err := os.ReadFile(sigPath)
if err != nil {
t.Fatal(err)
}
if err := sshsig.Verify(bytes.NewReader(data), sigBytes, []byte(tt.pub), "file"); err != nil {
t.Error(err)
}
// It should not verify if we check against another public key
if err := sshsig.Verify(bytes.NewReader(data), sigBytes, []byte(otherSSHPublicKey), "file"); err == nil {
t.Error("expected error with incorrect key")
}
// It should not verify if the data is tampered
if err := sshsig.Verify(strings.NewReader("bad data"), sigBytes, []byte(sshPublicKey), "file"); err == nil {
t.Error("expected error with incorrect data")
}
})
}
}
func TestToOpenSSH(t *testing.T) {
for _, tt := range []struct {
name string
pub string
priv string
}{
{
name: "rsa",
pub: sshPublicKey,
priv: sshPrivateKey,
},
{
name: "ed25519",
pub: ed25519PublicKey,
priv: ed25519PrivateKey,
},
} {
if _, err := exec.LookPath("ssh-keygen"); err != nil {
t.Skip("skip TestToOpenSSH: missing ssh-keygen in PATH")
}
t.Run(tt.name, func(t *testing.T) {
tt := tt
// Test that a signature from here can validate in the CLI.
td := t.TempDir()
data := []byte("hello, ssh world")
write(t, data, td, "data")
armored, err := sshsig.Sign([]byte(tt.priv), bytes.NewReader(data), "file")
if err != nil {
t.Fatal(err)
}
sigPath := write(t, armored, td, "oursig")
// Create an allowed_signers file with two keys to check against.
allowedSigner := "test@rekor.dev " + tt.pub + "\n"
allowedSigner += "othertest@rekor.dev " + otherSSHPublicKey + "\n"
allowedSigners := write(t, []byte(allowedSigner), td, "allowed_signer")
// We use the correct principal here so it should work.
run(t, data, "ssh-keygen", "-Y", "verify", "-f", allowedSigners,
"-I", "test@rekor.dev", "-n", "file", "-s", sigPath)
// Just to be sure, check against the other public key as well.
runErr(t, data, "ssh-keygen", "-Y", "verify", "-f", allowedSigners,
"-I", "othertest@rekor.dev", "-n", "file", "-s", sigPath)
// It should error if we run it against other data
data = []byte("other data!")
runErr(t, data, "ssh-keygen", "-Y", "check-novalidate", "-n", "file", "-s", sigPath)
})
}
}
func TestRoundTrip(t *testing.T) {
data := []byte("my good data to be signed!")
// Create one extra signature for all the tests.
otherSig, err := sshsig.Sign([]byte(otherSSHPrivateKey), bytes.NewReader(data), "file")
if err != nil {
t.Fatal(err)
}
for _, tt := range []struct {
name string
pub string
priv string
}{
{
name: "rsa",
pub: sshPublicKey,
priv: sshPrivateKey,
},
{
name: "ed25519",
pub: ed25519PublicKey,
priv: ed25519PrivateKey,
},
} {
t.Run(tt.name, func(t *testing.T) {
tt := tt
sig, err := sshsig.Sign([]byte(tt.priv), bytes.NewReader(data), "file")
if err != nil {
t.Fatal(err)
}
// Check the signature against that data and public key
if err := sshsig.Verify(bytes.NewReader(data), sig, []byte(tt.pub), "file"); err != nil {
t.Error(err)
}
// Now check it against invalid data.
if err := sshsig.Verify(strings.NewReader("invalid data!"), sig, []byte(tt.pub), "file"); err == nil {
t.Error("expected error!")
}
// Now check it against the wrong key.
if err := sshsig.Verify(bytes.NewReader(data), sig, []byte(otherSSHPublicKey), "file"); err == nil {
t.Error("expected error!")
}
// Now check it against an invalid signature data.
if err := sshsig.Verify(bytes.NewReader(data), []byte("invalid signature!"), []byte(tt.pub), "file"); err == nil {
t.Error("expected error!")
}
// Once more, use the wrong signature and check it against the original (wrong public key)
if err := sshsig.Verify(bytes.NewReader(data), otherSig, []byte(tt.pub), "file"); err == nil {
t.Error("expected error!")
}
// It should work against the correct public key.
if err := sshsig.Verify(bytes.NewReader(data), otherSig, []byte(otherSSHPublicKey), "file"); err != nil {
t.Error(err)
}
})
}
}
func write(t *testing.T, d []byte, fp ...string) string {
p := filepath.Join(fp...)
if err := os.WriteFile(p, d, 0o600); err != nil {
t.Fatal(err)
}
return p
}
func run(t *testing.T, stdin []byte, args ...string) {
t.Helper()
/* #nosec */
cmd := exec.Command(args[0], args[1:]...)
cmd.Stdin = bytes.NewReader(stdin)
out, err := cmd.CombinedOutput()
t.Logf("cmd %v: %s", cmd, string(out))
if err != nil {
t.Fatal(err)
}
}
func runErr(t *testing.T, stdin []byte, args ...string) {
t.Helper()
/* #nosec */
cmd := exec.Command(args[0], args[1:]...)
cmd.Stdin = bytes.NewReader(stdin)
out, err := cmd.CombinedOutput()
t.Logf("cmd %v: %s", cmd, string(out))
if err == nil {
t.Fatal("expected error")
}
}
func Test_PublicKeysAreExternallyManaged(t *testing.T) {
key1 := unittest.AssertExistsAndLoadBean(t, &PublicKey{ID: 1})
externals, err := PublicKeysAreExternallyManaged(db.DefaultContext, []*PublicKey{key1})
assert.NoError(t, err)
assert.Len(t, externals, 1)
assert.False(t, externals[0])
}

View File

@@ -0,0 +1,55 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package asymkey
import (
"bytes"
"context"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"github.com/42wim/sshsig"
)
// VerifySSHKey marks an SSH key as verified
func VerifySSHKey(ctx context.Context, ownerID int64, fingerprint, token, signature string) (string, error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return "", err
}
defer committer.Close()
key := new(PublicKey)
has, err := db.GetEngine(ctx).Where("owner_id = ? AND fingerprint = ?", ownerID, fingerprint).Get(key)
if err != nil {
return "", err
} else if !has {
return "", ErrKeyNotExist{}
}
err = sshsig.Verify(bytes.NewBuffer([]byte(token)), []byte(signature), []byte(key.Content), "gitea")
if err != nil {
// edge case for Windows-based shells that add CR LF when the token is piped to the ssh-keygen command
// see https://github.com/PowerShell/PowerShell/issues/5974
if sshsig.Verify(bytes.NewBuffer([]byte(token+"\r\n")), []byte(signature), []byte(key.Content), "gitea") != nil {
log.Error("Unable to validate token signature. Error: %v", err)
return "", ErrSSHInvalidTokenSignature{
Fingerprint: key.Fingerprint,
}
}
}
key.Verified = true
if _, err := db.GetEngine(ctx).ID(key.ID).Cols("verified").Update(key); err != nil {
return "", err
}
if err := committer.Commit(); err != nil {
return "", err
}
return key.Fingerprint, nil
}
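
A minimal standalone sketch of the client side of this flow: the armored signature that VerifySSHKey checks can be produced with sshsig.Sign over the token, using the same "gitea" namespace. The private key and token values below are hypothetical placeholders.

package main

import (
	"bytes"
	"fmt"

	"github.com/42wim/sshsig"
)

func main() {
	// Hypothetical placeholders; in practice the key is the user's own
	// OpenSSH private key and the token is the one shown by the server.
	privateKeyPEM := []byte("-----BEGIN OPENSSH PRIVATE KEY-----\n...\n-----END OPENSSH PRIVATE KEY-----\n")
	token := "example-verification-token"

	// Sign the token under the "gitea" namespace, matching sshsig.Verify above.
	armored, err := sshsig.Sign(privateKeyPEM, bytes.NewReader([]byte(token)), "gitea")
	if err != nil {
		fmt.Println("sign failed:", err)
		return
	}
	fmt.Printf("signature to submit:\n%s\n", armored)
}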

236
models/auth/access_token.go Normal file
View File

@@ -0,0 +1,236 @@
// Copyright 2014 The Gogs Authors. All rights reserved.
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"context"
"crypto/subtle"
"encoding/hex"
"fmt"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
lru "github.com/hashicorp/golang-lru/v2"
"xorm.io/builder"
)
// ErrAccessTokenNotExist represents an "AccessTokenNotExist" kind of error.
type ErrAccessTokenNotExist struct {
Token string
}
// IsErrAccessTokenNotExist checks if an error is an ErrAccessTokenNotExist.
func IsErrAccessTokenNotExist(err error) bool {
_, ok := err.(ErrAccessTokenNotExist)
return ok
}
func (err ErrAccessTokenNotExist) Error() string {
return fmt.Sprintf("access token does not exist [sha: %s]", err.Token)
}
func (err ErrAccessTokenNotExist) Unwrap() error {
return util.ErrNotExist
}
// ErrAccessTokenEmpty represents an "AccessTokenEmpty" kind of error.
type ErrAccessTokenEmpty struct{}
// IsErrAccessTokenEmpty checks if an error is an ErrAccessTokenEmpty.
func IsErrAccessTokenEmpty(err error) bool {
_, ok := err.(ErrAccessTokenEmpty)
return ok
}
func (err ErrAccessTokenEmpty) Error() string {
return "access token is empty"
}
func (err ErrAccessTokenEmpty) Unwrap() error {
return util.ErrInvalidArgument
}
var successfulAccessTokenCache *lru.Cache[string, any]
// AccessToken represents a personal access token.
type AccessToken struct {
ID int64 `xorm:"pk autoincr"`
UID int64 `xorm:"INDEX"`
Name string
Token string `xorm:"-"`
TokenHash string `xorm:"UNIQUE"` // sha256 of token
TokenSalt string
TokenLastEight string `xorm:"INDEX token_last_eight"`
Scope AccessTokenScope
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
HasRecentActivity bool `xorm:"-"`
HasUsed bool `xorm:"-"`
}
// AfterLoad is invoked from XORM after setting the values of all fields of this object.
func (t *AccessToken) AfterLoad() {
t.HasUsed = t.UpdatedUnix > t.CreatedUnix
t.HasRecentActivity = t.UpdatedUnix.AddDuration(7*24*time.Hour) > timeutil.TimeStampNow()
}
func init() {
db.RegisterModel(new(AccessToken), func() error {
if setting.SuccessfulTokensCacheSize > 0 {
var err error
successfulAccessTokenCache, err = lru.New[string, any](setting.SuccessfulTokensCacheSize)
if err != nil {
return fmt.Errorf("unable to allocate AccessToken cache: %w", err)
}
} else {
successfulAccessTokenCache = nil
}
return nil
})
}
// NewAccessToken creates a new access token.
func NewAccessToken(ctx context.Context, t *AccessToken) error {
salt, err := util.CryptoRandomString(10)
if err != nil {
return err
}
token, err := util.CryptoRandomBytes(20)
if err != nil {
return err
}
t.TokenSalt = salt
t.Token = hex.EncodeToString(token)
t.TokenHash = HashToken(t.Token, t.TokenSalt)
t.TokenLastEight = t.Token[len(t.Token)-8:]
_, err = db.GetEngine(ctx).Insert(t)
return err
}
// DisplayPublicOnly returns whether to display this as a public-only token.
func (t *AccessToken) DisplayPublicOnly() bool {
publicOnly, err := t.Scope.PublicOnly()
if err != nil {
return false
}
return publicOnly
}
func getAccessTokenIDFromCache(token string) int64 {
if successfulAccessTokenCache == nil {
return 0
}
tInterface, ok := successfulAccessTokenCache.Get(token)
if !ok {
return 0
}
t, ok := tInterface.(int64)
if !ok {
return 0
}
return t
}
// GetAccessTokenBySHA returns the access token matching the given token value
func GetAccessTokenBySHA(ctx context.Context, token string) (*AccessToken, error) {
if token == "" {
return nil, ErrAccessTokenEmpty{}
}
// A token is defined as being a hex-encoded SHA1 sum, i.e. 40 hexadecimal characters long
if len(token) != 40 {
return nil, ErrAccessTokenNotExist{token}
}
for _, x := range []byte(token) {
if x < '0' || (x > '9' && x < 'a') || x > 'f' {
return nil, ErrAccessTokenNotExist{token}
}
}
lastEight := token[len(token)-8:]
if id := getAccessTokenIDFromCache(token); id > 0 {
accessToken := &AccessToken{
TokenLastEight: lastEight,
}
// Re-get the token from the db in case it has been deleted in the intervening period
has, err := db.GetEngine(ctx).ID(id).Get(accessToken)
if err != nil {
return nil, err
}
if has {
return accessToken, nil
}
successfulAccessTokenCache.Remove(token)
}
var tokens []AccessToken
err := db.GetEngine(ctx).Table(&AccessToken{}).Where("token_last_eight = ?", lastEight).Find(&tokens)
if err != nil {
return nil, err
} else if len(tokens) == 0 {
return nil, ErrAccessTokenNotExist{token}
}
for _, t := range tokens {
tempHash := HashToken(token, t.TokenSalt)
if subtle.ConstantTimeCompare([]byte(t.TokenHash), []byte(tempHash)) == 1 {
if successfulAccessTokenCache != nil {
successfulAccessTokenCache.Add(token, t.ID)
}
return &t, nil
}
}
return nil, ErrAccessTokenNotExist{token}
}
// AccessTokenByNameExists checks if a token name has been used already by a user.
func AccessTokenByNameExists(ctx context.Context, token *AccessToken) (bool, error) {
return db.GetEngine(ctx).Table("access_token").Where("name = ?", token.Name).And("uid = ?", token.UID).Exist()
}
// ListAccessTokensOptions contain filter options
type ListAccessTokensOptions struct {
db.ListOptions
Name string
UserID int64
}
func (opts ListAccessTokensOptions) ToConds() builder.Cond {
cond := builder.NewCond()
// user id is required; otherwise it would return all results, which would likely be a bug
cond = cond.And(builder.Eq{"uid": opts.UserID})
if len(opts.Name) > 0 {
cond = cond.And(builder.Eq{"name": opts.Name})
}
return cond
}
func (opts ListAccessTokensOptions) ToOrders() string {
return "created_unix DESC"
}
// UpdateAccessToken updates information of access token.
func UpdateAccessToken(ctx context.Context, t *AccessToken) error {
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
return err
}
// DeleteAccessTokenByID deletes access token by given ID.
func DeleteAccessTokenByID(ctx context.Context, id, userID int64) error {
cnt, err := db.GetEngine(ctx).ID(id).Delete(&AccessToken{
UID: userID,
})
if err != nil {
return err
} else if cnt != 1 {
return ErrAccessTokenNotExist{}
}
return nil
}
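
A minimal usage sketch of the functions above, assuming an initialized database and db.DefaultContext (as the tests in this package use); the user ID and token name are hypothetical.

package main

import (
	"fmt"

	auth_model "code.gitea.io/gitea/models/auth"
	"code.gitea.io/gitea/models/db"
)

func main() {
	// Create a token; NewAccessToken fills in Token, TokenHash, TokenSalt and TokenLastEight.
	t := &auth_model.AccessToken{UID: 1, Name: "example-token"}
	if err := auth_model.NewAccessToken(db.DefaultContext, t); err != nil {
		fmt.Println("create failed:", err)
		return
	}

	// t.Token holds the plaintext value only here; the database stores its hash.
	found, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, t.Token)
	if err != nil {
		fmt.Println("lookup failed:", err)
		return
	}
	fmt.Println(found.Name, found.TokenLastEight)
}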

View File

@@ -0,0 +1,366 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"fmt"
"strings"
"code.gitea.io/gitea/models/perm"
)
// AccessTokenScopeCategory represents the scope category for an access token
type AccessTokenScopeCategory int
const (
AccessTokenScopeCategoryActivityPub = iota
AccessTokenScopeCategoryAdmin
AccessTokenScopeCategoryMisc // WARN: this is now just a placeholder; don't remove it, as removing it would change the values that follow
AccessTokenScopeCategoryNotification
AccessTokenScopeCategoryOrganization
AccessTokenScopeCategoryPackage
AccessTokenScopeCategoryIssue
AccessTokenScopeCategoryRepository
AccessTokenScopeCategoryUser
)
// AllAccessTokenScopeCategories contains all access token scope categories
var AllAccessTokenScopeCategories = []AccessTokenScopeCategory{
AccessTokenScopeCategoryActivityPub,
AccessTokenScopeCategoryAdmin,
AccessTokenScopeCategoryMisc,
AccessTokenScopeCategoryNotification,
AccessTokenScopeCategoryOrganization,
AccessTokenScopeCategoryPackage,
AccessTokenScopeCategoryIssue,
AccessTokenScopeCategoryRepository,
AccessTokenScopeCategoryUser,
}
// AccessTokenScopeLevel represents the access levels within a given scope category
type AccessTokenScopeLevel int
const (
NoAccess AccessTokenScopeLevel = iota
Read
Write
)
// AccessTokenScope represents the scope for an access token.
type AccessTokenScope string
// for all categories, write implies read
const (
AccessTokenScopeAll AccessTokenScope = "all"
AccessTokenScopePublicOnly AccessTokenScope = "public-only" // limited to public orgs/repos
AccessTokenScopeReadActivityPub AccessTokenScope = "read:activitypub"
AccessTokenScopeWriteActivityPub AccessTokenScope = "write:activitypub"
AccessTokenScopeReadAdmin AccessTokenScope = "read:admin"
AccessTokenScopeWriteAdmin AccessTokenScope = "write:admin"
AccessTokenScopeReadMisc AccessTokenScope = "read:misc"
AccessTokenScopeWriteMisc AccessTokenScope = "write:misc"
AccessTokenScopeReadNotification AccessTokenScope = "read:notification"
AccessTokenScopeWriteNotification AccessTokenScope = "write:notification"
AccessTokenScopeReadOrganization AccessTokenScope = "read:organization"
AccessTokenScopeWriteOrganization AccessTokenScope = "write:organization"
AccessTokenScopeReadPackage AccessTokenScope = "read:package"
AccessTokenScopeWritePackage AccessTokenScope = "write:package"
AccessTokenScopeReadIssue AccessTokenScope = "read:issue"
AccessTokenScopeWriteIssue AccessTokenScope = "write:issue"
AccessTokenScopeReadRepository AccessTokenScope = "read:repository"
AccessTokenScopeWriteRepository AccessTokenScope = "write:repository"
AccessTokenScopeReadUser AccessTokenScope = "read:user"
AccessTokenScopeWriteUser AccessTokenScope = "write:user"
)
// accessTokenScopeBitmap represents a bitmap of access token scopes.
type accessTokenScopeBitmap uint64
// Bitmap of each scope, including the child scopes.
const (
// accessTokenScopeAllBits is the bitmap of all access token scopes
accessTokenScopeAllBits accessTokenScopeBitmap = accessTokenScopeWriteActivityPubBits |
accessTokenScopeWriteAdminBits | accessTokenScopeWriteMiscBits | accessTokenScopeWriteNotificationBits |
accessTokenScopeWriteOrganizationBits | accessTokenScopeWritePackageBits | accessTokenScopeWriteIssueBits |
accessTokenScopeWriteRepositoryBits | accessTokenScopeWriteUserBits
accessTokenScopePublicOnlyBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeReadActivityPubBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteActivityPubBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadActivityPubBits
accessTokenScopeReadAdminBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteAdminBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadAdminBits
accessTokenScopeReadMiscBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteMiscBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadMiscBits
accessTokenScopeReadNotificationBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteNotificationBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadNotificationBits
accessTokenScopeReadOrganizationBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteOrganizationBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadOrganizationBits
accessTokenScopeReadPackageBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWritePackageBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadPackageBits
accessTokenScopeReadIssueBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteIssueBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadIssueBits
accessTokenScopeReadRepositoryBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteRepositoryBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadRepositoryBits
accessTokenScopeReadUserBits accessTokenScopeBitmap = 1 << iota
accessTokenScopeWriteUserBits accessTokenScopeBitmap = 1<<iota | accessTokenScopeReadUserBits
// The current implementation only supports up to 64 token scopes.
// If we need to support > 64 scopes,
// refactoring the whole implementation in this file (and only this file) is needed.
)
// allAccessTokenScopes contains all access token scopes.
// The order is important: parent scope must precede child scopes.
var allAccessTokenScopes = []AccessTokenScope{
AccessTokenScopePublicOnly,
AccessTokenScopeWriteActivityPub, AccessTokenScopeReadActivityPub,
AccessTokenScopeWriteAdmin, AccessTokenScopeReadAdmin,
AccessTokenScopeWriteMisc, AccessTokenScopeReadMisc,
AccessTokenScopeWriteNotification, AccessTokenScopeReadNotification,
AccessTokenScopeWriteOrganization, AccessTokenScopeReadOrganization,
AccessTokenScopeWritePackage, AccessTokenScopeReadPackage,
AccessTokenScopeWriteIssue, AccessTokenScopeReadIssue,
AccessTokenScopeWriteRepository, AccessTokenScopeReadRepository,
AccessTokenScopeWriteUser, AccessTokenScopeReadUser,
}
// allAccessTokenScopeBits contains all access token scopes.
var allAccessTokenScopeBits = map[AccessTokenScope]accessTokenScopeBitmap{
AccessTokenScopeAll: accessTokenScopeAllBits,
AccessTokenScopePublicOnly: accessTokenScopePublicOnlyBits,
AccessTokenScopeReadActivityPub: accessTokenScopeReadActivityPubBits,
AccessTokenScopeWriteActivityPub: accessTokenScopeWriteActivityPubBits,
AccessTokenScopeReadAdmin: accessTokenScopeReadAdminBits,
AccessTokenScopeWriteAdmin: accessTokenScopeWriteAdminBits,
AccessTokenScopeReadMisc: accessTokenScopeReadMiscBits,
AccessTokenScopeWriteMisc: accessTokenScopeWriteMiscBits,
AccessTokenScopeReadNotification: accessTokenScopeReadNotificationBits,
AccessTokenScopeWriteNotification: accessTokenScopeWriteNotificationBits,
AccessTokenScopeReadOrganization: accessTokenScopeReadOrganizationBits,
AccessTokenScopeWriteOrganization: accessTokenScopeWriteOrganizationBits,
AccessTokenScopeReadPackage: accessTokenScopeReadPackageBits,
AccessTokenScopeWritePackage: accessTokenScopeWritePackageBits,
AccessTokenScopeReadIssue: accessTokenScopeReadIssueBits,
AccessTokenScopeWriteIssue: accessTokenScopeWriteIssueBits,
AccessTokenScopeReadRepository: accessTokenScopeReadRepositoryBits,
AccessTokenScopeWriteRepository: accessTokenScopeWriteRepositoryBits,
AccessTokenScopeReadUser: accessTokenScopeReadUserBits,
AccessTokenScopeWriteUser: accessTokenScopeWriteUserBits,
}
// accessTokenScopes maps a scope level and category to the corresponding access token scope
var accessTokenScopes = map[AccessTokenScopeLevel]map[AccessTokenScopeCategory]AccessTokenScope{
Read: {
AccessTokenScopeCategoryActivityPub: AccessTokenScopeReadActivityPub,
AccessTokenScopeCategoryAdmin: AccessTokenScopeReadAdmin,
AccessTokenScopeCategoryMisc: AccessTokenScopeReadMisc,
AccessTokenScopeCategoryNotification: AccessTokenScopeReadNotification,
AccessTokenScopeCategoryOrganization: AccessTokenScopeReadOrganization,
AccessTokenScopeCategoryPackage: AccessTokenScopeReadPackage,
AccessTokenScopeCategoryIssue: AccessTokenScopeReadIssue,
AccessTokenScopeCategoryRepository: AccessTokenScopeReadRepository,
AccessTokenScopeCategoryUser: AccessTokenScopeReadUser,
},
Write: {
AccessTokenScopeCategoryActivityPub: AccessTokenScopeWriteActivityPub,
AccessTokenScopeCategoryAdmin: AccessTokenScopeWriteAdmin,
AccessTokenScopeCategoryMisc: AccessTokenScopeWriteMisc,
AccessTokenScopeCategoryNotification: AccessTokenScopeWriteNotification,
AccessTokenScopeCategoryOrganization: AccessTokenScopeWriteOrganization,
AccessTokenScopeCategoryPackage: AccessTokenScopeWritePackage,
AccessTokenScopeCategoryIssue: AccessTokenScopeWriteIssue,
AccessTokenScopeCategoryRepository: AccessTokenScopeWriteRepository,
AccessTokenScopeCategoryUser: AccessTokenScopeWriteUser,
},
}
// GetRequiredScopes gets the specific scopes for a given level and categories
func GetRequiredScopes(level AccessTokenScopeLevel, scopeCategories ...AccessTokenScopeCategory) []AccessTokenScope {
scopes := make([]AccessTokenScope, 0, len(scopeCategories))
for _, cat := range scopeCategories {
scopes = append(scopes, accessTokenScopes[level][cat])
}
return scopes
}
// ContainsCategory checks if a list of categories contains a specific category
func ContainsCategory(categories []AccessTokenScopeCategory, category AccessTokenScopeCategory) bool {
for _, c := range categories {
if c == category {
return true
}
}
return false
}
// GetScopeLevelFromAccessMode converts permission access mode to scope level
func GetScopeLevelFromAccessMode(mode perm.AccessMode) AccessTokenScopeLevel {
switch mode {
case perm.AccessModeNone:
return NoAccess
case perm.AccessModeRead:
return Read
case perm.AccessModeWrite:
return Write
case perm.AccessModeAdmin:
return Write
case perm.AccessModeOwner:
return Write
default:
return NoAccess
}
}
// parse the scope string into a bitmap, thus removing possible duplicates.
func (s AccessTokenScope) parse() (accessTokenScopeBitmap, error) {
var bitmap accessTokenScopeBitmap
// The following is the more performant equivalent of 'for _, v := range strings.Split(remainingScopes, ",")' as this is hot code
remainingScopes := string(s)
for len(remainingScopes) > 0 {
i := strings.IndexByte(remainingScopes, ',')
var v string
if i < 0 {
v = remainingScopes
remainingScopes = ""
} else if i+1 >= len(remainingScopes) {
v = remainingScopes[:i]
remainingScopes = ""
} else {
v = remainingScopes[:i]
remainingScopes = remainingScopes[i+1:]
}
singleScope := AccessTokenScope(v)
if singleScope == "" {
continue
}
if singleScope == AccessTokenScopeAll {
bitmap |= accessTokenScopeAllBits
continue
}
bits, ok := allAccessTokenScopeBits[singleScope]
if !ok {
return 0, fmt.Errorf("invalid access token scope: %s", singleScope)
}
bitmap |= bits
}
return bitmap, nil
}
// StringSlice returns the AccessTokenScope as a []string
func (s AccessTokenScope) StringSlice() []string {
return strings.Split(string(s), ",")
}
// Normalize returns a normalized scope string without any duplicates.
func (s AccessTokenScope) Normalize() (AccessTokenScope, error) {
bitmap, err := s.parse()
if err != nil {
return "", err
}
return bitmap.toScope(), nil
}
// PublicOnly checks if this token scope is limited to public resources
func (s AccessTokenScope) PublicOnly() (bool, error) {
bitmap, err := s.parse()
if err != nil {
return false, err
}
return bitmap.hasScope(AccessTokenScopePublicOnly)
}
// HasScope returns true if the scope string contains all of the given scopes
func (s AccessTokenScope) HasScope(scopes ...AccessTokenScope) (bool, error) {
bitmap, err := s.parse()
if err != nil {
return false, err
}
for _, s := range scopes {
if has, err := bitmap.hasScope(s); !has || err != nil {
return has, err
}
}
return true, nil
}
// HasAnyScope returns true if any of the scopes is contained in the string
func (s AccessTokenScope) HasAnyScope(scopes ...AccessTokenScope) (bool, error) {
bitmap, err := s.parse()
if err != nil {
return false, err
}
for _, s := range scopes {
if has, err := bitmap.hasScope(s); has || err != nil {
return has, err
}
}
return false, nil
}
// hasScope returns true if the bitmap contains the given scope
func (bitmap accessTokenScopeBitmap) hasScope(scope AccessTokenScope) (bool, error) {
expectedBits, ok := allAccessTokenScopeBits[scope]
if !ok {
return false, fmt.Errorf("invalid access token scope: %s", scope)
}
return bitmap&expectedBits == expectedBits, nil
}
// toScope returns a normalized scope string without any duplicates.
func (bitmap accessTokenScopeBitmap) toScope() AccessTokenScope {
var scopes []string
// iterate over all scopes, and reconstruct the bitmap
// if the reconstructed bitmap doesn't change, then the scope is already included
var reconstruct accessTokenScopeBitmap
for _, singleScope := range allAccessTokenScopes {
// no need for error checking here, since we know the scope is valid
if ok, _ := bitmap.hasScope(singleScope); ok {
current := reconstruct | allAccessTokenScopeBits[singleScope]
if current == reconstruct {
continue
}
reconstruct = current
scopes = append(scopes, string(singleScope))
}
}
scope := AccessTokenScope(strings.Join(scopes, ","))
scope = AccessTokenScope(strings.ReplaceAll(
string(scope),
"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user",
"all",
))
return scope
}
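
A minimal sketch of how the parsing above behaves from a caller's point of view: duplicates collapse, write implies read, and the full set of write scopes normalizes to "all".

package main

import (
	"fmt"

	auth_model "code.gitea.io/gitea/models/auth"
)

func main() {
	s := auth_model.AccessTokenScope("read:repository,write:repository,write:repository")

	normalized, err := s.Normalize()
	if err != nil {
		fmt.Println("invalid scope:", err)
		return
	}
	fmt.Println(normalized) // "write:repository"

	// write:repository implies read:repository, so this reports true.
	ok, _ := normalized.HasScope(auth_model.AccessTokenScopeReadRepository)
	fmt.Println(ok)
}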

View File

@@ -0,0 +1,90 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
type scopeTestNormalize struct {
in AccessTokenScope
out AccessTokenScope
err error
}
func TestAccessTokenScope_Normalize(t *testing.T) {
tests := []scopeTestNormalize{
{"", "", nil},
{"write:misc,write:notification,read:package,write:notification,public-only", "public-only,write:misc,write:notification,read:package", nil},
{"all", "all", nil},
{"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user", "all", nil},
{"write:activitypub,write:admin,write:misc,write:notification,write:organization,write:package,write:issue,write:repository,write:user,public-only", "public-only,all", nil},
}
for _, scope := range []string{"activitypub", "admin", "misc", "notification", "organization", "package", "issue", "repository", "user"} {
tests = append(tests,
scopeTestNormalize{AccessTokenScope(fmt.Sprintf("read:%s", scope)), AccessTokenScope(fmt.Sprintf("read:%s", scope)), nil},
scopeTestNormalize{AccessTokenScope(fmt.Sprintf("write:%s", scope)), AccessTokenScope(fmt.Sprintf("write:%s", scope)), nil},
scopeTestNormalize{AccessTokenScope(fmt.Sprintf("write:%[1]s,read:%[1]s", scope)), AccessTokenScope(fmt.Sprintf("write:%s", scope)), nil},
scopeTestNormalize{AccessTokenScope(fmt.Sprintf("read:%[1]s,write:%[1]s", scope)), AccessTokenScope(fmt.Sprintf("write:%s", scope)), nil},
scopeTestNormalize{AccessTokenScope(fmt.Sprintf("read:%[1]s,write:%[1]s,write:%[1]s", scope)), AccessTokenScope(fmt.Sprintf("write:%s", scope)), nil},
)
}
for _, test := range tests {
t.Run(string(test.in), func(t *testing.T) {
scope, err := test.in.Normalize()
assert.Equal(t, test.out, scope)
assert.Equal(t, test.err, err)
})
}
}
type scopeTestHasScope struct {
in AccessTokenScope
scope AccessTokenScope
out bool
err error
}
func TestAccessTokenScope_HasScope(t *testing.T) {
tests := []scopeTestHasScope{
{"read:admin", "write:package", false, nil},
{"all", "write:package", true, nil},
{"write:package", "all", false, nil},
{"public-only", "read:issue", false, nil},
}
for _, scope := range []string{"activitypub", "admin", "misc", "notification", "organization", "package", "issue", "repository", "user"} {
tests = append(tests,
scopeTestHasScope{
AccessTokenScope(fmt.Sprintf("read:%s", scope)),
AccessTokenScope(fmt.Sprintf("read:%s", scope)), true, nil,
},
scopeTestHasScope{
AccessTokenScope(fmt.Sprintf("write:%s", scope)),
AccessTokenScope(fmt.Sprintf("write:%s", scope)), true, nil,
},
scopeTestHasScope{
AccessTokenScope(fmt.Sprintf("write:%s", scope)),
AccessTokenScope(fmt.Sprintf("read:%s", scope)), true, nil,
},
scopeTestHasScope{
AccessTokenScope(fmt.Sprintf("read:%s", scope)),
AccessTokenScope(fmt.Sprintf("write:%s", scope)), false, nil,
},
)
}
for _, test := range tests {
t.Run(string(test.in), func(t *testing.T) {
hasScope, err := test.in.HasScope(test.scope)
assert.Equal(t, test.out, hasScope)
assert.Equal(t, test.err, err)
})
}
}

View File

@@ -0,0 +1,132 @@
// Copyright 2016 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth_test
import (
"testing"
auth_model "code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"github.com/stretchr/testify/assert"
)
func TestNewAccessToken(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token := &auth_model.AccessToken{
UID: 3,
Name: "Token C",
}
assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
unittest.AssertExistsAndLoadBean(t, token)
invalidToken := &auth_model.AccessToken{
ID: token.ID, // duplicate
UID: 2,
Name: "Token F",
}
assert.Error(t, auth_model.NewAccessToken(db.DefaultContext, invalidToken))
}
func TestAccessTokenByNameExists(t *testing.T) {
name := "Token Gitea"
assert.NoError(t, unittest.PrepareTestDatabase())
token := &auth_model.AccessToken{
UID: 3,
Name: name,
}
// Check to make sure it doesn't exist already
exist, err := auth_model.AccessTokenByNameExists(db.DefaultContext, token)
assert.NoError(t, err)
assert.False(t, exist)
// Save it to the database
assert.NoError(t, auth_model.NewAccessToken(db.DefaultContext, token))
unittest.AssertExistsAndLoadBean(t, token)
// This token must be found by name in the DB now
exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, token)
assert.NoError(t, err)
assert.True(t, exist)
user4Token := &auth_model.AccessToken{
UID: 4,
Name: name,
}
// Name matches but different user ID, so this shouldn't exist in the
// database
exist, err = auth_model.AccessTokenByNameExists(db.DefaultContext, user4Token)
assert.NoError(t, err)
assert.False(t, exist)
}
func TestGetAccessTokenBySHA(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "d2c6c1ba3890b309189a8e618c72a162e4efbf36")
assert.NoError(t, err)
assert.Equal(t, int64(1), token.UID)
assert.Equal(t, "Token A", token.Name)
assert.Equal(t, "2b3668e11cb82d3af8c6e4524fc7841297668f5008d1626f0ad3417e9fa39af84c268248b78c481daa7e5dc437784003494f", token.TokenHash)
assert.Equal(t, "e4efbf36", token.TokenLastEight)
_, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "notahash")
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
_, err = auth_model.GetAccessTokenBySHA(db.DefaultContext, "")
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenEmpty(err))
}
func TestListAccessTokens(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
tokens, err := db.Find[auth_model.AccessToken](db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 1})
assert.NoError(t, err)
if assert.Len(t, tokens, 2) {
assert.Equal(t, int64(1), tokens[0].UID)
assert.Equal(t, int64(1), tokens[1].UID)
assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token A")
assert.Contains(t, []string{tokens[0].Name, tokens[1].Name}, "Token B")
}
tokens, err = db.Find[auth_model.AccessToken](db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 2})
assert.NoError(t, err)
if assert.Len(t, tokens, 1) {
assert.Equal(t, int64(2), tokens[0].UID)
assert.Equal(t, "Token A", tokens[0].Name)
}
tokens, err = db.Find[auth_model.AccessToken](db.DefaultContext, auth_model.ListAccessTokensOptions{UserID: 100})
assert.NoError(t, err)
assert.Empty(t, tokens)
}
func TestUpdateAccessToken(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
assert.NoError(t, err)
token.Name = "Token Z"
assert.NoError(t, auth_model.UpdateAccessToken(db.DefaultContext, token))
unittest.AssertExistsAndLoadBean(t, token)
}
func TestDeleteAccessTokenByID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
token, err := auth_model.GetAccessTokenBySHA(db.DefaultContext, "4c6f36e6cf498e2a448662f915d932c09c5a146c")
assert.NoError(t, err)
assert.Equal(t, int64(1), token.UID)
assert.NoError(t, auth_model.DeleteAccessTokenByID(db.DefaultContext, token.ID, 1))
unittest.AssertNotExistsBean(t, token)
err = auth_model.DeleteAccessTokenByID(db.DefaultContext, 100, 100)
assert.Error(t, err)
assert.True(t, auth_model.IsErrAccessTokenNotExist(err))
}

65
models/auth/auth_token.go Normal file
View File

@@ -0,0 +1,65 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"context"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
"xorm.io/builder"
)
var ErrAuthTokenNotExist = util.NewNotExistErrorf("auth token does not exist")
type AuthToken struct { //nolint:revive
ID string `xorm:"pk"`
TokenHash string
UserID int64 `xorm:"INDEX"`
ExpiresUnix timeutil.TimeStamp `xorm:"INDEX"`
}
func init() {
db.RegisterModel(new(AuthToken))
}
func InsertAuthToken(ctx context.Context, t *AuthToken) error {
_, err := db.GetEngine(ctx).Insert(t)
return err
}
func GetAuthTokenByID(ctx context.Context, id string) (*AuthToken, error) {
at := &AuthToken{}
has, err := db.GetEngine(ctx).ID(id).Get(at)
if err != nil {
return nil, err
}
if !has {
return nil, ErrAuthTokenNotExist
}
return at, nil
}
func UpdateAuthTokenByID(ctx context.Context, t *AuthToken) error {
_, err := db.GetEngine(ctx).ID(t.ID).Cols("token_hash", "expires_unix").Update(t)
return err
}
func DeleteAuthTokenByID(ctx context.Context, id string) error {
_, err := db.GetEngine(ctx).ID(id).Delete(&AuthToken{})
return err
}
func DeleteAuthTokensByUserID(ctx context.Context, uid int64) error {
_, err := db.GetEngine(ctx).Where(builder.Eq{"user_id": uid}).Delete(&AuthToken{})
return err
}
func DeleteExpiredAuthTokens(ctx context.Context) error {
_, err := db.GetEngine(ctx).Where(builder.Lt{"expires_unix": timeutil.TimeStampNow()}).Delete(&AuthToken{})
return err
}
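
A minimal lifecycle sketch for the AuthToken model above, assuming an initialized database; the ID and hash values are hypothetical placeholders.

package main

import (
	"fmt"
	"time"

	auth_model "code.gitea.io/gitea/models/auth"
	"code.gitea.io/gitea/models/db"
	"code.gitea.io/gitea/modules/timeutil"
)

func main() {
	t := &auth_model.AuthToken{
		ID:          "example-token-id",
		TokenHash:   "example-token-hash",
		UserID:      1,
		ExpiresUnix: timeutil.TimeStampNow().AddDuration(24 * time.Hour),
	}
	if err := auth_model.InsertAuthToken(db.DefaultContext, t); err != nil {
		fmt.Println("insert failed:", err)
		return
	}

	got, err := auth_model.GetAuthTokenByID(db.DefaultContext, t.ID)
	if err != nil {
		fmt.Println("lookup failed:", err)
		return
	}
	fmt.Println(got.UserID)

	// Periodic cleanup removes rows whose expires_unix is already in the past.
	_ = auth_model.DeleteExpiredAuthTokens(db.DefaultContext)
}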

20
models/auth/main_test.go Normal file
View File

@@ -0,0 +1,20 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth_test
import (
"testing"
"code.gitea.io/gitea/models/unittest"
_ "code.gitea.io/gitea/models"
_ "code.gitea.io/gitea/models/actions"
_ "code.gitea.io/gitea/models/activities"
_ "code.gitea.io/gitea/models/auth"
_ "code.gitea.io/gitea/models/perm/access"
)
func TestMain(m *testing.M) {
unittest.MainTest(m)
}

650
models/auth/oauth2.go Normal file
View File

@@ -0,0 +1,650 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"context"
"crypto/sha256"
"encoding/base32"
"encoding/base64"
"errors"
"fmt"
"net"
"net/url"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/container"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
uuid "github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"xorm.io/builder"
"xorm.io/xorm"
)
// OAuth2Application represents an OAuth2 client (RFC 6749)
type OAuth2Application struct {
ID int64 `xorm:"pk autoincr"`
UID int64 `xorm:"INDEX"`
Name string
ClientID string `xorm:"unique"`
ClientSecret string
// OAuth defines both Confidential and Public client types
// https://datatracker.ietf.org/doc/html/rfc6749#section-2.1
// "Authorization servers MUST record the client type in the client registration details"
// https://datatracker.ietf.org/doc/html/rfc8252#section-8.4
ConfidentialClient bool `xorm:"NOT NULL DEFAULT TRUE"`
SkipSecondaryAuthorization bool `xorm:"NOT NULL DEFAULT FALSE"`
RedirectURIs []string `xorm:"redirect_uris JSON TEXT"`
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}
func init() {
db.RegisterModel(new(OAuth2Application))
db.RegisterModel(new(OAuth2AuthorizationCode))
db.RegisterModel(new(OAuth2Grant))
}
type BuiltinOAuth2Application struct {
ConfigName string
DisplayName string
RedirectURIs []string
}
func BuiltinApplications() map[string]*BuiltinOAuth2Application {
m := make(map[string]*BuiltinOAuth2Application)
m["a4792ccc-144e-407e-86c9-5e7d8d9c3269"] = &BuiltinOAuth2Application{
ConfigName: "git-credential-oauth",
DisplayName: "git-credential-oauth",
RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"},
}
m["e90ee53c-94e2-48ac-9358-a874fb9e0662"] = &BuiltinOAuth2Application{
ConfigName: "git-credential-manager",
DisplayName: "Git Credential Manager",
RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"},
}
m["d57cb8c4-630c-4168-8324-ec79935e18d4"] = &BuiltinOAuth2Application{
ConfigName: "tea",
DisplayName: "tea",
RedirectURIs: []string{"http://127.0.0.1", "https://127.0.0.1"},
}
return m
}
func Init(ctx context.Context) error {
builtinApps := BuiltinApplications()
var builtinAllClientIDs []string
for clientID := range builtinApps {
builtinAllClientIDs = append(builtinAllClientIDs, clientID)
}
var registeredApps []*OAuth2Application
if err := db.GetEngine(ctx).In("client_id", builtinAllClientIDs).Find(&registeredApps); err != nil {
return err
}
clientIDsToAdd := container.Set[string]{}
for _, configName := range setting.OAuth2.DefaultApplications {
found := false
for clientID, builtinApp := range builtinApps {
if builtinApp.ConfigName == configName {
clientIDsToAdd.Add(clientID) // add all user-configured apps to the "add" list
found = true
}
}
if !found {
return fmt.Errorf("unknown oauth2 application: %q", configName)
}
}
clientIDsToDelete := container.Set[string]{}
for _, app := range registeredApps {
if !clientIDsToAdd.Contains(app.ClientID) {
clientIDsToDelete.Add(app.ClientID) // if a registered app is not in the "add" list, it should be deleted
}
}
for _, app := range registeredApps {
clientIDsToAdd.Remove(app.ClientID) // no need to re-add existing (registered) apps, so remove them from the set
}
for _, app := range registeredApps {
if clientIDsToDelete.Contains(app.ClientID) {
if err := deleteOAuth2Application(ctx, app.ID, 0); err != nil {
return err
}
}
}
for clientID := range clientIDsToAdd {
builtinApp := builtinApps[clientID]
if err := db.Insert(ctx, &OAuth2Application{
Name: builtinApp.DisplayName,
ClientID: clientID,
RedirectURIs: builtinApp.RedirectURIs,
}); err != nil {
return err
}
}
return nil
}
// TableName sets the table name to `oauth2_application`
func (app *OAuth2Application) TableName() string {
return "oauth2_application"
}
// ContainsRedirectURI checks if redirectURI is allowed for app
func (app *OAuth2Application) ContainsRedirectURI(redirectURI string) bool {
// OAuth2 requires the redirect URI to be an exact match, no dynamic parts are allowed.
// https://stackoverflow.com/questions/55524480/should-dynamic-query-parameters-be-present-in-the-redirection-uri-for-an-oauth2
// https://www.rfc-editor.org/rfc/rfc6819#section-5.2.3.3
// https://openid.net/specs/openid-connect-core-1_0.html#AuthRequest
// https://datatracker.ietf.org/doc/html/draft-ietf-oauth-security-topics-12#section-3.1
contains := func(s string) bool {
s = strings.TrimSuffix(strings.ToLower(s), "/")
for _, u := range app.RedirectURIs {
if strings.TrimSuffix(strings.ToLower(u), "/") == s {
return true
}
}
return false
}
if !app.ConfidentialClient {
uri, err := url.Parse(redirectURI)
// ignore port for http loopback uris following https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
if err == nil && uri.Scheme == "http" && uri.Port() != "" {
ip := net.ParseIP(uri.Hostname())
if ip != nil && ip.IsLoopback() {
// strip port
uri.Host = uri.Hostname()
if contains(uri.String()) {
return true
}
}
}
}
return contains(redirectURI)
}
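// NOTE (illustrative, not part of the original file): with RedirectURIs
// containing only "http://127.0.0.1/" and ConfidentialClient set to false,
// "http://127.0.0.1:34567/" is accepted because the loopback port is ignored
// (RFC 8252 section 7.3), while "https://127.0.0.1:34567/" and
// "http://intranet:34567/" are rejected by the exact-match rule above.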
// Base32 characters, but lowercased.
const lowerBase32Chars = "abcdefghijklmnopqrstuvwxyz234567"
// base32 encoder that uses lowered characters without padding.
var base32Lower = base32.NewEncoding(lowerBase32Chars).WithPadding(base32.NoPadding)
// GenerateClientSecret generates a new client secret, stores its bcrypt hash in the database and returns the plaintext secret
func (app *OAuth2Application) GenerateClientSecret(ctx context.Context) (string, error) {
rBytes, err := util.CryptoRandomBytes(32)
if err != nil {
return "", err
}
// Add a prefix to the base32 string so that secret scanners can more easily
// recognize leaked tokens.
clientSecret := "gto_" + base32Lower.EncodeToString(rBytes)
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(clientSecret), bcrypt.DefaultCost)
if err != nil {
return "", err
}
app.ClientSecret = string(hashedSecret)
if _, err := db.GetEngine(ctx).ID(app.ID).Cols("client_secret").Update(app); err != nil {
return "", err
}
return clientSecret, nil
}
// ValidateClientSecret validates the given secret against the hash saved in the database
func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool {
return bcrypt.CompareHashAndPassword([]byte(app.ClientSecret), secret) == nil
}
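// exampleRotateClientSecret is an illustrative sketch, not part of the original
// file: it shows the intended round trip between the two methods above. The
// plaintext "gto_..." secret is returned to the caller exactly once, while only
// its bcrypt hash is persisted in ClientSecret.
func exampleRotateClientSecret(ctx context.Context, app *OAuth2Application) (bool, error) {
	plaintext, err := app.GenerateClientSecret(ctx)
	if err != nil {
		return false, err
	}
	// validation compares the presented secret against the stored hash only
	return app.ValidateClientSecret([]byte(plaintext)), nil // true for the freshly issued secret
}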
// GetGrantByUserID returns an OAuth2Grant by its user and application ID
func (app *OAuth2Application) GetGrantByUserID(ctx context.Context, userID int64) (grant *OAuth2Grant, err error) {
grant = new(OAuth2Grant)
if has, err := db.GetEngine(ctx).Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
return nil, err
} else if !has {
return nil, nil
}
return grant, nil
}
// CreateGrant generates a grant for a user
func (app *OAuth2Application) CreateGrant(ctx context.Context, userID int64, scope string) (*OAuth2Grant, error) {
grant := &OAuth2Grant{
ApplicationID: app.ID,
UserID: userID,
Scope: scope,
}
err := db.Insert(ctx, grant)
if err != nil {
return nil, err
}
return grant, nil
}
// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found.
func GetOAuth2ApplicationByClientID(ctx context.Context, clientID string) (app *OAuth2Application, err error) {
app = new(OAuth2Application)
has, err := db.GetEngine(ctx).Where("client_id = ?", clientID).Get(app)
if !has {
return nil, ErrOAuthClientIDInvalid{ClientID: clientID}
}
return app, err
}
// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found.
func GetOAuth2ApplicationByID(ctx context.Context, id int64) (app *OAuth2Application, err error) {
app = new(OAuth2Application)
has, err := db.GetEngine(ctx).ID(id).Get(app)
if err != nil {
return nil, err
}
if !has {
return nil, ErrOAuthApplicationNotFound{ID: id}
}
return app, nil
}
// CreateOAuth2ApplicationOptions holds options to create an oauth2 application
type CreateOAuth2ApplicationOptions struct {
Name string
UserID int64
ConfidentialClient bool
SkipSecondaryAuthorization bool
RedirectURIs []string
}
// CreateOAuth2Application inserts a new oauth2 application
func CreateOAuth2Application(ctx context.Context, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
clientID := uuid.New().String()
app := &OAuth2Application{
UID: opts.UserID,
Name: opts.Name,
ClientID: clientID,
RedirectURIs: opts.RedirectURIs,
ConfidentialClient: opts.ConfidentialClient,
SkipSecondaryAuthorization: opts.SkipSecondaryAuthorization,
}
if err := db.Insert(ctx, app); err != nil {
return nil, err
}
return app, nil
}
// UpdateOAuth2ApplicationOptions holds options to update an oauth2 application
type UpdateOAuth2ApplicationOptions struct {
ID int64
Name string
UserID int64
ConfidentialClient bool
SkipSecondaryAuthorization bool
RedirectURIs []string
}
// UpdateOAuth2Application updates an oauth2 application
func UpdateOAuth2Application(ctx context.Context, opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
defer committer.Close()
app, err := GetOAuth2ApplicationByID(ctx, opts.ID)
if err != nil {
return nil, err
}
if app.UID != opts.UserID {
return nil, errors.New("UID mismatch")
}
builtinApps := BuiltinApplications()
if _, builtin := builtinApps[app.ClientID]; builtin {
return nil, fmt.Errorf("failed to edit OAuth2 application: application is locked: %s", app.ClientID)
}
app.Name = opts.Name
app.RedirectURIs = opts.RedirectURIs
app.ConfidentialClient = opts.ConfidentialClient
app.SkipSecondaryAuthorization = opts.SkipSecondaryAuthorization
if err = updateOAuth2Application(ctx, app); err != nil {
return nil, err
}
app.ClientSecret = ""
return app, committer.Commit()
}
func updateOAuth2Application(ctx context.Context, app *OAuth2Application) error {
if _, err := db.GetEngine(ctx).ID(app.ID).UseBool("confidential_client", "skip_secondary_authorization").Update(app); err != nil {
return err
}
return nil
}
func deleteOAuth2Application(ctx context.Context, id, userid int64) error {
sess := db.GetEngine(ctx)
// the userid could be 0 if the app is instance-wide
if deleted, err := sess.Where(builder.Eq{"id": id, "uid": userid}).Delete(&OAuth2Application{}); err != nil {
return err
} else if deleted == 0 {
return ErrOAuthApplicationNotFound{ID: id}
}
codes := make([]*OAuth2AuthorizationCode, 0)
// delete correlating auth codes
if err := sess.Join("INNER", "oauth2_grant",
"oauth2_authorization_code.grant_id = oauth2_grant.id AND oauth2_grant.application_id = ?", id).Find(&codes); err != nil {
return err
}
codeIDs := make([]int64, 0, len(codes))
	for _, code := range codes {
		codeIDs = append(codeIDs, code.ID)
}
if _, err := sess.In("id", codeIDs).Delete(new(OAuth2AuthorizationCode)); err != nil {
return err
}
if _, err := sess.Where("application_id = ?", id).Delete(new(OAuth2Grant)); err != nil {
return err
}
return nil
}
// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app.
func DeleteOAuth2Application(ctx context.Context, id, userid int64) error {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return err
}
defer committer.Close()
app, err := GetOAuth2ApplicationByID(ctx, id)
if err != nil {
return err
}
builtinApps := BuiltinApplications()
if _, builtin := builtinApps[app.ClientID]; builtin {
return fmt.Errorf("failed to delete OAuth2 application: application is locked: %s", app.ClientID)
}
if err := deleteOAuth2Application(ctx, id, userid); err != nil {
return err
}
return committer.Commit()
}
//////////////////////////////////////////////////////
// OAuth2AuthorizationCode is a one-time code that, together with the client secret, can be exchanged for an access token. It has a limited lifetime.
type OAuth2AuthorizationCode struct {
ID int64 `xorm:"pk autoincr"`
Grant *OAuth2Grant `xorm:"-"`
GrantID int64
Code string `xorm:"INDEX unique"`
CodeChallenge string
CodeChallengeMethod string
RedirectURI string
ValidUntil timeutil.TimeStamp `xorm:"index"`
}
// TableName sets the table name to `oauth2_authorization_code`
func (code *OAuth2AuthorizationCode) TableName() string {
return "oauth2_authorization_code"
}
// GenerateRedirectURI generates a redirect URI for a successful authorization request. State will be used if not empty.
func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (*url.URL, error) {
redirect, err := url.Parse(code.RedirectURI)
if err != nil {
return nil, err
}
q := redirect.Query()
if state != "" {
q.Set("state", state)
}
q.Set("code", code.Code)
redirect.RawQuery = q.Encode()
return redirect, err
}
// Invalidate deletes the auth code from the database to invalidate this code
func (code *OAuth2AuthorizationCode) Invalidate(ctx context.Context) error {
_, err := db.GetEngine(ctx).ID(code.ID).NoAutoCondition().Delete(code)
return err
}
// ValidateCodeChallenge validates the given verifier against the saved code challenge. This is part of the PKCE implementation.
func (code *OAuth2AuthorizationCode) ValidateCodeChallenge(verifier string) bool {
switch code.CodeChallengeMethod {
case "S256":
// base64url(SHA256(verifier)) see https://tools.ietf.org/html/rfc7636#section-4.6
h := sha256.Sum256([]byte(verifier))
hashedVerifier := base64.RawURLEncoding.EncodeToString(h[:])
return hashedVerifier == code.CodeChallenge
case "plain":
return verifier == code.CodeChallenge
case "":
return true
default:
// unsupported method -> return false
return false
}
}
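// examplePKCEPair is an illustrative sketch, not part of the original file: it
// derives the S256 code challenge from a verifier as described in RFC 7636,
// which is the same transformation ValidateCodeChallenge applies server-side,
// so examplePKCEPair(v).ValidateCodeChallenge(v) returns true.
func examplePKCEPair(verifier string) *OAuth2AuthorizationCode {
	h := sha256.Sum256([]byte(verifier))
	return &OAuth2AuthorizationCode{
		CodeChallengeMethod: "S256",
		CodeChallenge:       base64.RawURLEncoding.EncodeToString(h[:]),
	}
}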
// GetOAuth2AuthorizationByCode returns an authorization by its code
func GetOAuth2AuthorizationByCode(ctx context.Context, code string) (auth *OAuth2AuthorizationCode, err error) {
auth = new(OAuth2AuthorizationCode)
if has, err := db.GetEngine(ctx).Where("code = ?", code).Get(auth); err != nil {
return nil, err
} else if !has {
return nil, nil
}
auth.Grant = new(OAuth2Grant)
if has, err := db.GetEngine(ctx).ID(auth.GrantID).Get(auth.Grant); err != nil {
return nil, err
} else if !has {
return nil, nil
}
return auth, nil
}
//////////////////////////////////////////////////////
// OAuth2Grant represents the permission of a user for a specific application to access resources
type OAuth2Grant struct {
ID int64 `xorm:"pk autoincr"`
UserID int64 `xorm:"INDEX unique(user_application)"`
Application *OAuth2Application `xorm:"-"`
ApplicationID int64 `xorm:"INDEX unique(user_application)"`
Counter int64 `xorm:"NOT NULL DEFAULT 1"`
Scope string `xorm:"TEXT"`
Nonce string `xorm:"TEXT"`
CreatedUnix timeutil.TimeStamp `xorm:"created"`
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
}
// TableName sets the table name to `oauth2_grant`
func (grant *OAuth2Grant) TableName() string {
return "oauth2_grant"
}
// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(ctx context.Context, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
rBytes, err := util.CryptoRandomBytes(32)
if err != nil {
return &OAuth2AuthorizationCode{}, err
}
// Add a prefix to the base32 string so that secret scanners can more easily
// recognize leaked tokens.
codeSecret := "gta_" + base32Lower.EncodeToString(rBytes)
code = &OAuth2AuthorizationCode{
Grant: grant,
GrantID: grant.ID,
RedirectURI: redirectURI,
Code: codeSecret,
CodeChallenge: codeChallenge,
CodeChallengeMethod: codeChallengeMethod,
}
if err := db.Insert(ctx, code); err != nil {
return nil, err
}
return code, nil
}
// IncreaseCounter increases the counter and updates the grant
func (grant *OAuth2Grant) IncreaseCounter(ctx context.Context) error {
_, err := db.GetEngine(ctx).ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
if err != nil {
return err
}
updatedGrant, err := GetOAuth2GrantByID(ctx, grant.ID)
if err != nil {
return err
}
grant.Counter = updatedGrant.Counter
return nil
}
// ScopeContains returns true if the grant scope contains the specified scope
func (grant *OAuth2Grant) ScopeContains(scope string) bool {
for _, currentScope := range strings.Split(grant.Scope, " ") {
if scope == currentScope {
return true
}
}
return false
}
// SetNonce updates the current nonce value of a grant
func (grant *OAuth2Grant) SetNonce(ctx context.Context, nonce string) error {
grant.Nonce = nonce
_, err := db.GetEngine(ctx).ID(grant.ID).Cols("nonce").Update(grant)
if err != nil {
return err
}
return nil
}
// GetOAuth2GrantByID returns the grant with the given ID
func GetOAuth2GrantByID(ctx context.Context, id int64) (grant *OAuth2Grant, err error) {
grant = new(OAuth2Grant)
if has, err := db.GetEngine(ctx).ID(id).Get(grant); err != nil {
return nil, err
} else if !has {
return nil, nil
}
return grant, err
}
// GetOAuth2GrantsByUserID lists all grants of a certain user
func GetOAuth2GrantsByUserID(ctx context.Context, uid int64) ([]*OAuth2Grant, error) {
type joinedOAuth2Grant struct {
Grant *OAuth2Grant `xorm:"extends"`
Application *OAuth2Application `xorm:"extends"`
}
var results *xorm.Rows
var err error
if results, err = db.GetEngine(ctx).
Table("oauth2_grant").
Where("user_id = ?", uid).
Join("INNER", "oauth2_application", "application_id = oauth2_application.id").
Rows(new(joinedOAuth2Grant)); err != nil {
return nil, err
}
defer results.Close()
grants := make([]*OAuth2Grant, 0)
for results.Next() {
joinedGrant := new(joinedOAuth2Grant)
if err := results.Scan(joinedGrant); err != nil {
return nil, err
}
joinedGrant.Grant.Application = joinedGrant.Application
grants = append(grants, joinedGrant.Grant)
}
return grants, nil
}
// RevokeOAuth2Grant deletes the grant with grantID and userID
func RevokeOAuth2Grant(ctx context.Context, grantID, userID int64) error {
_, err := db.GetEngine(ctx).Where(builder.Eq{"id": grantID, "user_id": userID}).Delete(&OAuth2Grant{})
return err
}
// ErrOAuthClientIDInvalid will be returned if the client id cannot be found
type ErrOAuthClientIDInvalid struct {
ClientID string
}
// IsErrOauthClientIDInvalid checks if an error is a ErrOAuthClientIDInvalid.
func IsErrOauthClientIDInvalid(err error) bool {
_, ok := err.(ErrOAuthClientIDInvalid)
return ok
}
// Error returns the error message
func (err ErrOAuthClientIDInvalid) Error() string {
return fmt.Sprintf("Client ID invalid [Client ID: %s]", err.ClientID)
}
// Unwrap unwraps this as a ErrNotExist err
func (err ErrOAuthClientIDInvalid) Unwrap() error {
return util.ErrNotExist
}
// ErrOAuthApplicationNotFound will be returned if the id cannot be found
type ErrOAuthApplicationNotFound struct {
ID int64
}
// IsErrOAuthApplicationNotFound checks if an error is an ErrOAuthApplicationNotFound.
func IsErrOAuthApplicationNotFound(err error) bool {
_, ok := err.(ErrOAuthApplicationNotFound)
return ok
}
// Error returns the error message
func (err ErrOAuthApplicationNotFound) Error() string {
return fmt.Sprintf("OAuth application not found [ID: %d]", err.ID)
}
// Unwrap unwraps this as a ErrNotExist err
func (err ErrOAuthApplicationNotFound) Unwrap() error {
return util.ErrNotExist
}
// GetActiveOAuth2SourceByName returns a OAuth2 AuthSource based on the given name
func GetActiveOAuth2SourceByName(ctx context.Context, name string) (*Source, error) {
authSource := new(Source)
has, err := db.GetEngine(ctx).Where("name = ? and type = ? and is_active = ?", name, OAuth2, true).Get(authSource)
if err != nil {
return nil, err
}
if !has {
return nil, fmt.Errorf("oauth2 source not found, name: %q", name)
}
return authSource, nil
}
func DeleteOAuth2RelictsByUserID(ctx context.Context, userID int64) error {
deleteCond := builder.Select("id").From("oauth2_grant").Where(builder.Eq{"oauth2_grant.user_id": userID})
if _, err := db.GetEngine(ctx).In("grant_id", deleteCond).
Delete(&OAuth2AuthorizationCode{}); err != nil {
return err
}
if err := db.DeleteBeans(ctx,
&OAuth2Application{UID: userID},
&OAuth2Grant{UserID: userID},
); err != nil {
return fmt.Errorf("DeleteBeans: %w", err)
}
return nil
}


@@ -0,0 +1,32 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"code.gitea.io/gitea/models/db"
"xorm.io/builder"
)
type FindOAuth2ApplicationsOptions struct {
db.ListOptions
// OwnerID is the user id or org id of the owner of the application
OwnerID int64
// find global applications; if true, OwnerID will be ignored
IsGlobal bool
}
func (opts FindOAuth2ApplicationsOptions) ToConds() builder.Cond {
conds := builder.NewCond()
if opts.IsGlobal {
conds = conds.And(builder.Eq{"uid": 0})
} else if opts.OwnerID != 0 {
conds = conds.And(builder.Eq{"uid": opts.OwnerID})
}
return conds
}
func (opts FindOAuth2ApplicationsOptions) ToOrders() string {
return "id DESC"
}
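// exampleListGlobalApps is an illustrative sketch, not part of the original
// file: it assumes a "context" import and the generic db.Find helper used
// elsewhere in this codebase to turn these options into a query.
func exampleListGlobalApps(ctx context.Context) ([]*OAuth2Application, error) {
	return db.Find[OAuth2Application](ctx, FindOAuth2ApplicationsOptions{
		ListOptions: db.ListOptions{Page: 1, PageSize: 10},
		IsGlobal:    true, // OwnerID is ignored when IsGlobal is set
	})
}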

265
models/auth/oauth2_test.go Normal file

@@ -0,0 +1,265 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth_test
import (
"testing"
auth_model "code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"github.com/stretchr/testify/assert"
)
func TestOAuth2Application_GenerateClientSecret(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
secret, err := app.GenerateClientSecret(db.DefaultContext)
assert.NoError(t, err)
assert.NotEmpty(t, secret)
unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1, ClientSecret: app.ClientSecret})
}
func BenchmarkOAuth2Application_GenerateClientSecret(b *testing.B) {
assert.NoError(b, unittest.PrepareTestDatabase())
app := unittest.AssertExistsAndLoadBean(b, &auth_model.OAuth2Application{ID: 1})
for i := 0; i < b.N; i++ {
_, _ = app.GenerateClientSecret(db.DefaultContext)
}
}
func TestOAuth2Application_ContainsRedirectURI(t *testing.T) {
app := &auth_model.OAuth2Application{
RedirectURIs: []string{"a", "b", "c"},
}
assert.True(t, app.ContainsRedirectURI("a"))
assert.True(t, app.ContainsRedirectURI("b"))
assert.True(t, app.ContainsRedirectURI("c"))
assert.False(t, app.ContainsRedirectURI("d"))
}
func TestOAuth2Application_ContainsRedirectURI_WithPort(t *testing.T) {
app := &auth_model.OAuth2Application{
RedirectURIs: []string{"http://127.0.0.1/", "http://::1/", "http://192.168.0.1/", "http://intranet/", "https://127.0.0.1/"},
ConfidentialClient: false,
}
// http loopback uris should ignore port
// https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1:3456/"))
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/"))
assert.True(t, app.ContainsRedirectURI("http://[::1]:3456/"))
// not http
assert.False(t, app.ContainsRedirectURI("https://127.0.0.1:3456/"))
// not loopback
assert.False(t, app.ContainsRedirectURI("http://192.168.0.1:9954/"))
assert.False(t, app.ContainsRedirectURI("http://intranet:3456/"))
// unparseable
assert.False(t, app.ContainsRedirectURI(":"))
}
func TestOAuth2Application_ContainsRedirect_Slash(t *testing.T) {
app := &auth_model.OAuth2Application{RedirectURIs: []string{"http://127.0.0.1"}}
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1"))
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/"))
assert.False(t, app.ContainsRedirectURI("http://127.0.0.1/other"))
app = &auth_model.OAuth2Application{RedirectURIs: []string{"http://127.0.0.1/"}}
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1"))
assert.True(t, app.ContainsRedirectURI("http://127.0.0.1/"))
assert.False(t, app.ContainsRedirectURI("http://127.0.0.1/other"))
}
func TestOAuth2Application_ValidateClientSecret(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
secret, err := app.GenerateClientSecret(db.DefaultContext)
assert.NoError(t, err)
assert.True(t, app.ValidateClientSecret([]byte(secret)))
assert.False(t, app.ValidateClientSecret([]byte("fewijfowejgfiowjeoifew")))
}
func TestGetOAuth2ApplicationByClientID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app, err := auth_model.GetOAuth2ApplicationByClientID(db.DefaultContext, "da7da3ba-9a13-4167-856f-3899de0b0138")
assert.NoError(t, err)
assert.Equal(t, "da7da3ba-9a13-4167-856f-3899de0b0138", app.ClientID)
app, err = auth_model.GetOAuth2ApplicationByClientID(db.DefaultContext, "invalid client id")
assert.Error(t, err)
assert.Nil(t, app)
}
func TestCreateOAuth2Application(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app, err := auth_model.CreateOAuth2Application(db.DefaultContext, auth_model.CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1})
assert.NoError(t, err)
assert.Equal(t, "newapp", app.Name)
assert.Len(t, app.ClientID, 36)
unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{Name: "newapp"})
}
func TestOAuth2Application_TableName(t *testing.T) {
assert.Equal(t, "oauth2_application", new(auth_model.OAuth2Application).TableName())
}
func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
grant, err := app.GetGrantByUserID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Equal(t, int64(1), grant.UserID)
grant, err = app.GetGrantByUserID(db.DefaultContext, 34923458)
assert.NoError(t, err)
assert.Nil(t, grant)
}
func TestOAuth2Application_CreateGrant(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
app := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Application{ID: 1})
grant, err := app.CreateGrant(db.DefaultContext, 2, "")
assert.NoError(t, err)
assert.NotNil(t, grant)
assert.Equal(t, int64(2), grant.UserID)
assert.Equal(t, int64(1), grant.ApplicationID)
assert.Equal(t, "", grant.Scope)
}
//////////////////// Grant
func TestGetOAuth2GrantByID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
grant, err := auth_model.GetOAuth2GrantByID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Equal(t, int64(1), grant.ID)
grant, err = auth_model.GetOAuth2GrantByID(db.DefaultContext, 34923458)
assert.NoError(t, err)
assert.Nil(t, grant)
}
func TestOAuth2Grant_IncreaseCounter(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Counter: 1})
assert.NoError(t, grant.IncreaseCounter(db.DefaultContext))
assert.Equal(t, int64(2), grant.Counter)
unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Counter: 2})
}
func TestOAuth2Grant_ScopeContains(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1, Scope: "openid profile"})
assert.True(t, grant.ScopeContains("openid"))
assert.True(t, grant.ScopeContains("profile"))
assert.False(t, grant.ScopeContains("profil"))
assert.False(t, grant.ScopeContains("profile2"))
}
func TestOAuth2Grant_GenerateNewAuthorizationCode(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
grant := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2Grant{ID: 1})
code, err := grant.GenerateNewAuthorizationCode(db.DefaultContext, "https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256")
assert.NoError(t, err)
assert.NotNil(t, code)
assert.Greater(t, len(code.Code), 32) // secret length > 32
}
func TestOAuth2Grant_TableName(t *testing.T) {
assert.Equal(t, "oauth2_grant", new(auth_model.OAuth2Grant).TableName())
}
func TestGetOAuth2GrantsByUserID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
result, err := auth_model.GetOAuth2GrantsByUserID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Len(t, result, 1)
assert.Equal(t, int64(1), result[0].ID)
assert.Equal(t, result[0].ApplicationID, result[0].Application.ID)
result, err = auth_model.GetOAuth2GrantsByUserID(db.DefaultContext, 34134)
assert.NoError(t, err)
assert.Empty(t, result)
}
func TestRevokeOAuth2Grant(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
assert.NoError(t, auth_model.RevokeOAuth2Grant(db.DefaultContext, 1, 1))
unittest.AssertNotExistsBean(t, &auth_model.OAuth2Grant{ID: 1, UserID: 1})
}
//////////////////// Authorization Code
func TestGetOAuth2AuthorizationByCode(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
code, err := auth_model.GetOAuth2AuthorizationByCode(db.DefaultContext, "authcode")
assert.NoError(t, err)
assert.NotNil(t, code)
assert.Equal(t, "authcode", code.Code)
assert.Equal(t, int64(1), code.ID)
code, err = auth_model.GetOAuth2AuthorizationByCode(db.DefaultContext, "does not exist")
assert.NoError(t, err)
assert.Nil(t, code)
}
func TestOAuth2AuthorizationCode_ValidateCodeChallenge(t *testing.T) {
// test plain
code := &auth_model.OAuth2AuthorizationCode{
CodeChallengeMethod: "plain",
CodeChallenge: "test123",
}
assert.True(t, code.ValidateCodeChallenge("test123"))
assert.False(t, code.ValidateCodeChallenge("ierwgjoergjio"))
// test S256
code = &auth_model.OAuth2AuthorizationCode{
CodeChallengeMethod: "S256",
CodeChallenge: "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg",
}
assert.True(t, code.ValidateCodeChallenge("N1Zo9-8Rfwhkt68r1r29ty8YwIraXR8eh_1Qwxg7yQXsonBt"))
assert.False(t, code.ValidateCodeChallenge("wiogjerogorewngoenrgoiuenorg"))
// test unknown
code = &auth_model.OAuth2AuthorizationCode{
CodeChallengeMethod: "monkey",
CodeChallenge: "foiwgjioriogeiogjerger",
}
assert.False(t, code.ValidateCodeChallenge("foiwgjioriogeiogjerger"))
// test no code challenge
code = &auth_model.OAuth2AuthorizationCode{
CodeChallengeMethod: "",
CodeChallenge: "foierjiogerogerg",
}
assert.True(t, code.ValidateCodeChallenge(""))
}
func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) {
code := &auth_model.OAuth2AuthorizationCode{
RedirectURI: "https://example.com/callback",
Code: "thecode",
}
redirect, err := code.GenerateRedirectURI("thestate")
assert.NoError(t, err)
assert.Equal(t, "https://example.com/callback?code=thecode&state=thestate", redirect.String())
redirect, err = code.GenerateRedirectURI("")
assert.NoError(t, err)
assert.Equal(t, "https://example.com/callback?code=thecode", redirect.String())
}
func TestOAuth2AuthorizationCode_Invalidate(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
code := unittest.AssertExistsAndLoadBean(t, &auth_model.OAuth2AuthorizationCode{Code: "authcode"})
assert.NoError(t, code.Invalidate(db.DefaultContext))
unittest.AssertNotExistsBean(t, &auth_model.OAuth2AuthorizationCode{Code: "authcode"})
}
func TestOAuth2AuthorizationCode_TableName(t *testing.T) {
assert.Equal(t, "oauth2_authorization_code", new(auth_model.OAuth2AuthorizationCode).TableName())
}

120
models/auth/session.go Normal file

@@ -0,0 +1,120 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"context"
"fmt"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/timeutil"
"xorm.io/builder"
)
// Session represents a session compatible with go-chi/session
type Session struct {
Key string `xorm:"pk CHAR(16)"` // has to be Key to match with go-chi/session
Data []byte `xorm:"BLOB"` // on MySQL this has a maximum size of 64Kb - this may need to be increased
Expiry timeutil.TimeStamp // has to be Expiry to match with go-chi/session
}
func init() {
db.RegisterModel(new(Session))
}
// UpdateSession updates the session with the provided key
func UpdateSession(ctx context.Context, key string, data []byte) error {
_, err := db.GetEngine(ctx).ID(key).Update(&Session{
Data: data,
Expiry: timeutil.TimeStampNow(),
})
return err
}
// ReadSession reads the data for the provided session key, creating the session if it does not yet exist
func ReadSession(ctx context.Context, key string) (*Session, error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
defer committer.Close()
session, exist, err := db.Get[Session](ctx, builder.Eq{"`key`": key})
if err != nil {
return nil, err
} else if !exist {
session = &Session{
Key: key,
Expiry: timeutil.TimeStampNow(),
}
if err := db.Insert(ctx, session); err != nil {
return nil, err
}
}
return session, committer.Commit()
}
// ExistSession checks if a session exists
func ExistSession(ctx context.Context, key string) (bool, error) {
return db.Exist[Session](ctx, builder.Eq{"`key`": key})
}
// DestroySession destroys a session
func DestroySession(ctx context.Context, key string) error {
_, err := db.GetEngine(ctx).Delete(&Session{
Key: key,
})
return err
}
// RegenerateSession regenerates a session, moving its data from the old key to the new key
func RegenerateSession(ctx context.Context, oldKey, newKey string) (*Session, error) {
ctx, committer, err := db.TxContext(ctx)
if err != nil {
return nil, err
}
defer committer.Close()
if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": newKey}); err != nil {
return nil, err
} else if has {
return nil, fmt.Errorf("session Key: %s already exists", newKey)
}
if has, err := db.Exist[Session](ctx, builder.Eq{"`key`": oldKey}); err != nil {
return nil, err
} else if !has {
if err := db.Insert(ctx, &Session{
Key: oldKey,
Expiry: timeutil.TimeStampNow(),
}); err != nil {
return nil, err
}
}
if _, err := db.Exec(ctx, "UPDATE "+db.TableName(&Session{})+" SET `key` = ? WHERE `key`=?", newKey, oldKey); err != nil {
return nil, err
}
s, _, err := db.Get[Session](ctx, builder.Eq{"`key`": newKey})
if err != nil {
// the session was just moved to newKey, so it should be impossible for it not to exist here
return nil, err
}
return s, committer.Commit()
}
// CountSessions returns the number of sessions
func CountSessions(ctx context.Context) (int64, error) {
return db.GetEngine(ctx).Count(&Session{})
}
// CleanupSessions cleans up expired sessions
func CleanupSessions(ctx context.Context, maxLifetime int64) error {
_, err := db.GetEngine(ctx).Where("expiry <= ?", timeutil.TimeStampNow().Add(-maxLifetime)).Delete(&Session{})
return err
}

395
models/auth/source.go Normal file

@@ -0,0 +1,395 @@
// Copyright 2014 The Gogs Authors. All rights reserved.
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"context"
"fmt"
"reflect"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/optional"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
"xorm.io/builder"
"xorm.io/xorm"
"xorm.io/xorm/convert"
)
// Type represents a login type.
type Type int
// Note: new type must append to the end of list to maintain compatibility.
const (
NoType Type = iota
Plain // 1
LDAP // 2
SMTP // 3
PAM // 4
DLDAP // 5
OAuth2 // 6
SSPI // 7
)
// String returns the string name of the LoginType
func (typ Type) String() string {
return Names[typ]
}
// Int returns the int value of the LoginType
func (typ Type) Int() int {
return int(typ)
}
// Names contains the name of LoginType values.
var Names = map[Type]string{
LDAP: "LDAP (via BindDN)",
DLDAP: "LDAP (simple auth)", // Via direct bind
SMTP: "SMTP",
PAM: "PAM",
OAuth2: "OAuth2",
SSPI: "SPNEGO with SSPI",
}
// Config represents login config as far as the db is concerned
type Config interface {
convert.Conversion
}
// SkipVerifiable configurations provide an IsSkipVerify method to check if SkipVerify is set
type SkipVerifiable interface {
IsSkipVerify() bool
}
// HasTLSer configurations provide a HasTLS to check if TLS can be enabled
type HasTLSer interface {
HasTLS() bool
}
// UseTLSer configurations provide a UseTLS method to check if TLS is enabled
type UseTLSer interface {
UseTLS() bool
}
// SSHKeyProvider configurations provide ProvidesSSHKeys to check if they provide SSHKeys
type SSHKeyProvider interface {
ProvidesSSHKeys() bool
}
// RegisterableSource configurations provide RegisterSource which needs to be run on creation
type RegisterableSource interface {
RegisterSource() error
UnregisterSource() error
}
var registeredConfigs = map[Type]func() Config{}
// RegisterTypeConfig register a config for a provided type
func RegisterTypeConfig(typ Type, exemplar Config) {
if reflect.TypeOf(exemplar).Kind() == reflect.Ptr {
// Pointer:
registeredConfigs[typ] = func() Config {
return reflect.New(reflect.ValueOf(exemplar).Elem().Type()).Interface().(Config)
}
return
}
// Not a Pointer
registeredConfigs[typ] = func() Config {
return reflect.New(reflect.TypeOf(exemplar)).Elem().Interface().(Config)
}
}
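// NOTE (illustrative, not part of the original file): each auth backend calls
// RegisterTypeConfig with an exemplar of its Config type so that BeforeSet can
// reconstruct the right concrete type when a Source row is loaded, roughly:
//
//	func init() {
//		RegisterTypeConfig(PAM, &somePAMConfig{}) // somePAMConfig: hypothetical Config implementation
//	}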
// SourceSettable configurations can have their authSource set on them
type SourceSettable interface {
SetAuthSource(*Source)
}
// Source represents an external way for authorizing users.
type Source struct {
ID int64 `xorm:"pk autoincr"`
Type Type
Name string `xorm:"UNIQUE"`
IsActive bool `xorm:"INDEX NOT NULL DEFAULT false"`
IsSyncEnabled bool `xorm:"INDEX NOT NULL DEFAULT false"`
Cfg convert.Conversion `xorm:"TEXT"`
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}
// TableName xorm will read the table name from this method
func (Source) TableName() string {
return "login_source"
}
func init() {
db.RegisterModel(new(Source))
}
// BeforeSet is invoked from XORM before setting the value of a field of this object.
func (source *Source) BeforeSet(colName string, val xorm.Cell) {
if colName == "type" {
typ := Type(db.Cell2Int64(val))
constructor, ok := registeredConfigs[typ]
if !ok {
return
}
source.Cfg = constructor()
if settable, ok := source.Cfg.(SourceSettable); ok {
settable.SetAuthSource(source)
}
}
}
// TypeName return name of this login source type.
func (source *Source) TypeName() string {
return Names[source.Type]
}
// IsLDAP returns true if this source is of the LDAP type.
func (source *Source) IsLDAP() bool {
return source.Type == LDAP
}
// IsDLDAP returns true if this source is of the DLDAP type.
func (source *Source) IsDLDAP() bool {
return source.Type == DLDAP
}
// IsSMTP returns true if this source is of the SMTP type.
func (source *Source) IsSMTP() bool {
return source.Type == SMTP
}
// IsPAM returns true if this source is of the PAM type.
func (source *Source) IsPAM() bool {
return source.Type == PAM
}
// IsOAuth2 returns true if this source is of the OAuth2 type.
func (source *Source) IsOAuth2() bool {
return source.Type == OAuth2
}
// IsSSPI returns true if this source is of the SSPI type.
func (source *Source) IsSSPI() bool {
return source.Type == SSPI
}
// HasTLS returns true if this source supports TLS.
func (source *Source) HasTLS() bool {
hasTLSer, ok := source.Cfg.(HasTLSer)
return ok && hasTLSer.HasTLS()
}
// UseTLS returns true if this source is configured to use TLS.
func (source *Source) UseTLS() bool {
useTLSer, ok := source.Cfg.(UseTLSer)
return ok && useTLSer.UseTLS()
}
// SkipVerify returns true if this source is configured to skip SSL
// verification.
func (source *Source) SkipVerify() bool {
skipVerifiable, ok := source.Cfg.(SkipVerifiable)
return ok && skipVerifiable.IsSkipVerify()
}
// CreateSource inserts an AuthSource into the DB if one with the given name
// does not already exist.
func CreateSource(ctx context.Context, source *Source) error {
has, err := db.GetEngine(ctx).Where("name=?", source.Name).Exist(new(Source))
if err != nil {
return err
} else if has {
return ErrSourceAlreadyExist{source.Name}
}
// Synchronization is only available for LDAP and OAuth2 sources for now
if !source.IsLDAP() && !source.IsOAuth2() {
source.IsSyncEnabled = false
}
_, err = db.GetEngine(ctx).Insert(source)
if err != nil {
return err
}
if !source.IsActive {
return nil
}
if settable, ok := source.Cfg.(SourceSettable); ok {
settable.SetAuthSource(source)
}
registerableSource, ok := source.Cfg.(RegisterableSource)
if !ok {
return nil
}
err = registerableSource.RegisterSource()
if err != nil {
// remove the AuthSource in case of errors while registering configuration
if _, err := db.GetEngine(ctx).ID(source.ID).Delete(new(Source)); err != nil {
log.Error("CreateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
}
}
return err
}
type FindSourcesOptions struct {
db.ListOptions
IsActive optional.Option[bool]
LoginType Type
}
func (opts FindSourcesOptions) ToConds() builder.Cond {
conds := builder.NewCond()
if opts.IsActive.Has() {
conds = conds.And(builder.Eq{"is_active": opts.IsActive.Value()})
}
if opts.LoginType != NoType {
conds = conds.And(builder.Eq{"`type`": opts.LoginType})
}
return conds
}
// IsSSPIEnabled returns true if there is at least one activated login
// source of type SSPI
func IsSSPIEnabled(ctx context.Context) bool {
exist, err := db.Exist[Source](ctx, FindSourcesOptions{
IsActive: optional.Some(true),
LoginType: SSPI,
}.ToConds())
if err != nil {
log.Error("IsSSPIEnabled: failed to query active SSPI sources: %v", err)
return false
}
return exist
}
// GetSourceByID returns login source by given ID.
func GetSourceByID(ctx context.Context, id int64) (*Source, error) {
source := new(Source)
if id == 0 {
source.Cfg = registeredConfigs[NoType]()
// Set this source to active
// FIXME: allow disabling of db based password authentication in future
source.IsActive = true
return source, nil
}
has, err := db.GetEngine(ctx).ID(id).Get(source)
if err != nil {
return nil, err
} else if !has {
return nil, ErrSourceNotExist{id}
}
return source, nil
}
// UpdateSource updates a Source record in DB.
func UpdateSource(ctx context.Context, source *Source) error {
var originalSource *Source
if source.IsOAuth2() {
// keep track of the original values so we can restore in case of errors while registering OAuth2 providers
var err error
if originalSource, err = GetSourceByID(ctx, source.ID); err != nil {
return err
}
}
has, err := db.GetEngine(ctx).Where("name=? AND id!=?", source.Name, source.ID).Exist(new(Source))
if err != nil {
return err
} else if has {
return ErrSourceAlreadyExist{source.Name}
}
_, err = db.GetEngine(ctx).ID(source.ID).AllCols().Update(source)
if err != nil {
return err
}
if !source.IsActive {
return nil
}
if settable, ok := source.Cfg.(SourceSettable); ok {
settable.SetAuthSource(source)
}
registerableSource, ok := source.Cfg.(RegisterableSource)
if !ok {
return nil
}
err = registerableSource.RegisterSource()
if err != nil {
// restore original values since we cannot update the provider itself
if _, err := db.GetEngine(ctx).ID(source.ID).AllCols().Update(originalSource); err != nil {
log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
}
}
return err
}
// ErrSourceNotExist represents a "SourceNotExist" kind of error.
type ErrSourceNotExist struct {
ID int64
}
// IsErrSourceNotExist checks if an error is a ErrSourceNotExist.
func IsErrSourceNotExist(err error) bool {
_, ok := err.(ErrSourceNotExist)
return ok
}
func (err ErrSourceNotExist) Error() string {
return fmt.Sprintf("login source does not exist [id: %d]", err.ID)
}
// Unwrap unwraps this as a ErrNotExist err
func (err ErrSourceNotExist) Unwrap() error {
return util.ErrNotExist
}
// ErrSourceAlreadyExist represents a "SourceAlreadyExist" kind of error.
type ErrSourceAlreadyExist struct {
Name string
}
// IsErrSourceAlreadyExist checks if an error is a ErrSourceAlreadyExist.
func IsErrSourceAlreadyExist(err error) bool {
_, ok := err.(ErrSourceAlreadyExist)
return ok
}
func (err ErrSourceAlreadyExist) Error() string {
return fmt.Sprintf("login source already exists [name: %s]", err.Name)
}
// Unwrap unwraps this as a ErrExist err
func (err ErrSourceAlreadyExist) Unwrap() error {
return util.ErrAlreadyExist
}
// ErrSourceInUse represents a "SourceInUse" kind of error.
type ErrSourceInUse struct {
ID int64
}
// IsErrSourceInUse checks if an error is a ErrSourceInUse.
func IsErrSourceInUse(err error) bool {
_, ok := err.(ErrSourceInUse)
return ok
}
func (err ErrSourceInUse) Error() string {
return fmt.Sprintf("login source is still used by some users [id: %d]", err.ID)
}


@@ -0,0 +1,60 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth_test
import (
"strings"
"testing"
auth_model "code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"code.gitea.io/gitea/modules/json"
"github.com/stretchr/testify/assert"
"xorm.io/xorm/schemas"
)
type TestSource struct {
Provider string
ClientID string
ClientSecret string
OpenIDConnectAutoDiscoveryURL string
IconURL string
}
// FromDB fills up a TestSource from serialized format.
func (source *TestSource) FromDB(bs []byte) error {
return json.Unmarshal(bs, &source)
}
// ToDB exports a TestSource to a serialized format.
func (source *TestSource) ToDB() ([]byte, error) {
return json.Marshal(source)
}
func TestDumpAuthSource(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
authSourceSchema, err := db.TableInfo(new(auth_model.Source))
assert.NoError(t, err)
auth_model.RegisterTypeConfig(auth_model.OAuth2, new(TestSource))
auth_model.CreateSource(db.DefaultContext, &auth_model.Source{
Type: auth_model.OAuth2,
Name: "TestSource",
IsActive: false,
Cfg: &TestSource{
Provider: "ConvertibleSourceName",
ClientID: "42",
},
})
sb := new(strings.Builder)
db.DumpTables([]*schemas.Table{authSourceSchema}, sb)
assert.Contains(t, sb.String(), `"Provider":"ConvertibleSourceName"`)
}

166
models/auth/twofactor.go Normal file

@@ -0,0 +1,166 @@
// Copyright 2017 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"context"
"crypto/md5"
"crypto/sha256"
"crypto/subtle"
"encoding/base32"
"encoding/base64"
"encoding/hex"
"fmt"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/secret"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
"github.com/pquerna/otp/totp"
"golang.org/x/crypto/pbkdf2"
)
//
// Two-factor authentication
//
// ErrTwoFactorNotEnrolled indicates that a user is not enrolled in two-factor authentication.
type ErrTwoFactorNotEnrolled struct {
UID int64
}
// IsErrTwoFactorNotEnrolled checks if an error is a ErrTwoFactorNotEnrolled.
func IsErrTwoFactorNotEnrolled(err error) bool {
_, ok := err.(ErrTwoFactorNotEnrolled)
return ok
}
func (err ErrTwoFactorNotEnrolled) Error() string {
return fmt.Sprintf("user not enrolled in 2FA [uid: %d]", err.UID)
}
// Unwrap unwraps this as a ErrNotExist err
func (err ErrTwoFactorNotEnrolled) Unwrap() error {
return util.ErrNotExist
}
// TwoFactor represents a two-factor authentication token.
type TwoFactor struct {
ID int64 `xorm:"pk autoincr"`
UID int64 `xorm:"UNIQUE"`
Secret string
ScratchSalt string
ScratchHash string
LastUsedPasscode string `xorm:"VARCHAR(10)"`
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}
func init() {
db.RegisterModel(new(TwoFactor))
}
// GenerateScratchToken generates a new scratch token for the user, replacing any previous one.
func (t *TwoFactor) GenerateScratchToken() (string, error) {
tokenBytes, err := util.CryptoRandomBytes(6)
if err != nil {
return "", err
}
// these chars are specially chosen to avoid ambiguous chars like `0`, `O`, `1`, `I`.
const base32Chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
token := base32.NewEncoding(base32Chars).WithPadding(base32.NoPadding).EncodeToString(tokenBytes)
t.ScratchSalt, _ = util.CryptoRandomString(10)
t.ScratchHash = HashToken(token, t.ScratchSalt)
return token, nil
}
// HashToken returns the hex-encoded PBKDF2 hash of the token using the given salt
func HashToken(token, salt string) string {
tempHash := pbkdf2.Key([]byte(token), []byte(salt), 10000, 50, sha256.New)
return hex.EncodeToString(tempHash)
}
// VerifyScratchToken verifies if the specified scratch token is valid.
func (t *TwoFactor) VerifyScratchToken(token string) bool {
if len(token) == 0 {
return false
}
tempHash := HashToken(token, t.ScratchSalt)
return subtle.ConstantTimeCompare([]byte(t.ScratchHash), []byte(tempHash)) == 1
}
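// exampleScratchTokenRoundTrip is an illustrative sketch, not part of the
// original file: the plaintext scratch token is shown to the user exactly once,
// while only its salted PBKDF2 hash is kept on the TwoFactor row.
func exampleScratchTokenRoundTrip() (bool, error) {
	t := &TwoFactor{}
	token, err := t.GenerateScratchToken()
	if err != nil {
		return false, err
	}
	// the constant-time comparison above succeeds only for the issued token
	return t.VerifyScratchToken(token), nil
}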
func (t *TwoFactor) getEncryptionKey() []byte {
k := md5.Sum([]byte(setting.SecretKey))
return k[:]
}
// SetSecret sets the 2FA secret.
func (t *TwoFactor) SetSecret(secretString string) error {
secretBytes, err := secret.AesEncrypt(t.getEncryptionKey(), []byte(secretString))
if err != nil {
return err
}
t.Secret = base64.StdEncoding.EncodeToString(secretBytes)
return nil
}
// ValidateTOTP validates the provided passcode.
func (t *TwoFactor) ValidateTOTP(passcode string) (bool, error) {
decodedStoredSecret, err := base64.StdEncoding.DecodeString(t.Secret)
if err != nil {
return false, err
}
secretBytes, err := secret.AesDecrypt(t.getEncryptionKey(), decodedStoredSecret)
if err != nil {
return false, err
}
secretStr := string(secretBytes)
return totp.Validate(passcode, secretStr), nil
}
// NewTwoFactor creates a new two-factor authentication token.
func NewTwoFactor(ctx context.Context, t *TwoFactor) error {
_, err := db.GetEngine(ctx).Insert(t)
return err
}
// UpdateTwoFactor updates a two-factor authentication token.
func UpdateTwoFactor(ctx context.Context, t *TwoFactor) error {
_, err := db.GetEngine(ctx).ID(t.ID).AllCols().Update(t)
return err
}
// GetTwoFactorByUID returns the two-factor authentication token associated with
// the user, if any.
func GetTwoFactorByUID(ctx context.Context, uid int64) (*TwoFactor, error) {
twofa := &TwoFactor{}
has, err := db.GetEngine(ctx).Where("uid=?", uid).Get(twofa)
if err != nil {
return nil, err
} else if !has {
return nil, ErrTwoFactorNotEnrolled{uid}
}
return twofa, nil
}
// HasTwoFactorByUID returns whether the given user has a two-factor
// authentication token.
func HasTwoFactorByUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("uid=?", uid).Exist(&TwoFactor{})
}
// DeleteTwoFactorByID deletes two-factor authentication token by given ID.
func DeleteTwoFactorByID(ctx context.Context, id, userID int64) error {
cnt, err := db.GetEngine(ctx).ID(id).Delete(&TwoFactor{
UID: userID,
})
if err != nil {
return err
} else if cnt != 1 {
return ErrTwoFactorNotEnrolled{userID}
}
return nil
}

212
models/auth/webauthn.go Normal file

@@ -0,0 +1,212 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"context"
"fmt"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
)
// ErrWebAuthnCredentialNotExist represents a "ErrWebAuthnCredentialNotExist" kind of error.
type ErrWebAuthnCredentialNotExist struct {
ID int64
CredentialID []byte
}
func (err ErrWebAuthnCredentialNotExist) Error() string {
if len(err.CredentialID) == 0 {
return fmt.Sprintf("WebAuthn credential does not exist [id: %d]", err.ID)
}
return fmt.Sprintf("WebAuthn credential does not exist [credential_id: %x]", err.CredentialID)
}
// Unwrap unwraps this as a ErrNotExist err
func (err ErrWebAuthnCredentialNotExist) Unwrap() error {
return util.ErrNotExist
}
// IsErrWebAuthnCredentialNotExist checks if an error is a ErrWebAuthnCredentialNotExist.
func IsErrWebAuthnCredentialNotExist(err error) bool {
_, ok := err.(ErrWebAuthnCredentialNotExist)
return ok
}
// WebAuthnCredential represents the WebAuthn credential data for a public-key
// credential conformant to WebAuthn Level 1
type WebAuthnCredential struct {
ID int64 `xorm:"pk autoincr"`
Name string
LowerName string `xorm:"unique(s)"`
UserID int64 `xorm:"INDEX unique(s)"`
CredentialID []byte `xorm:"INDEX VARBINARY(1024)"`
PublicKey []byte
AttestationType string
AAGUID []byte
SignCount uint32 `xorm:"BIGINT"`
CloneWarning bool
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}
func init() {
db.RegisterModel(new(WebAuthnCredential))
}
// TableName returns a better table name for WebAuthnCredential
func (cred WebAuthnCredential) TableName() string {
return "webauthn_credential"
}
// UpdateSignCount will update the database value of SignCount
func (cred *WebAuthnCredential) UpdateSignCount(ctx context.Context) error {
_, err := db.GetEngine(ctx).ID(cred.ID).Cols("sign_count").Update(cred)
return err
}
// BeforeInsert will be invoked by XORM before inserting a record
func (cred *WebAuthnCredential) BeforeInsert() {
cred.LowerName = strings.ToLower(cred.Name)
}
// BeforeUpdate will be invoked by XORM before updating a record
func (cred *WebAuthnCredential) BeforeUpdate() {
cred.LowerName = strings.ToLower(cred.Name)
}
// AfterLoad is invoked from XORM after setting the values of all fields of this object.
func (cred *WebAuthnCredential) AfterLoad() {
cred.LowerName = strings.ToLower(cred.Name)
}
// WebAuthnCredentialList is a list of *WebAuthnCredential
type WebAuthnCredentialList []*WebAuthnCredential
// newCredentialFlagsFromAuthenticatorFlags is copied from https://github.com/go-webauthn/webauthn/pull/337
// to convert protocol.AuthenticatorFlags to webauthn.CredentialFlags
func newCredentialFlagsFromAuthenticatorFlags(flags protocol.AuthenticatorFlags) webauthn.CredentialFlags {
return webauthn.CredentialFlags{
UserPresent: flags.HasUserPresent(),
UserVerified: flags.HasUserVerified(),
BackupEligible: flags.HasBackupEligible(),
BackupState: flags.HasBackupState(),
}
}
// ToCredentials will convert all WebAuthnCredentials to webauthn.Credentials
func (list WebAuthnCredentialList) ToCredentials(defaultAuthFlags ...protocol.AuthenticatorFlags) []webauthn.Credential {
// TODO: at the moment, Gitea doesn't store or check the flags
// so we need to use the default flags from the authenticator to make the login validation pass
// In the future, we should:
// 1. store the flags when registering the credential
// 2. provide the stored flags when converting the credentials (for login)
// 3. for old users, still use this fallback to the default flags
defAuthFlags := util.OptionalArg(defaultAuthFlags)
creds := make([]webauthn.Credential, 0, len(list))
for _, cred := range list {
creds = append(creds, webauthn.Credential{
ID: cred.CredentialID,
PublicKey: cred.PublicKey,
AttestationType: cred.AttestationType,
Flags: newCredentialFlagsFromAuthenticatorFlags(defAuthFlags),
Authenticator: webauthn.Authenticator{
AAGUID: cred.AAGUID,
SignCount: cred.SignCount,
CloneWarning: cred.CloneWarning,
},
})
}
return creds
}
// GetWebAuthnCredentialsByUID returns all WebAuthn credentials of the given user
func GetWebAuthnCredentialsByUID(ctx context.Context, uid int64) (WebAuthnCredentialList, error) {
creds := make(WebAuthnCredentialList, 0)
return creds, db.GetEngine(ctx).Where("user_id = ?", uid).Find(&creds)
}
// ExistsWebAuthnCredentialsForUID returns whether the given user has WebAuthn credentials
func ExistsWebAuthnCredentialsForUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
}
// GetWebAuthnCredentialByName returns WebAuthn credential by name
func GetWebAuthnCredentialByName(ctx context.Context, uid int64, name string) (*WebAuthnCredential, error) {
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).Where("user_id = ? AND lower_name = ?", uid, strings.ToLower(name)).Get(cred); err != nil {
return nil, err
} else if !found {
return nil, ErrWebAuthnCredentialNotExist{}
}
return cred, nil
}
// GetWebAuthnCredentialByID returns WebAuthn credential by id
func GetWebAuthnCredentialByID(ctx context.Context, id int64) (*WebAuthnCredential, error) {
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).ID(id).Get(cred); err != nil {
return nil, err
} else if !found {
return nil, ErrWebAuthnCredentialNotExist{ID: id}
}
return cred, nil
}
// HasWebAuthnRegistrationsByUID returns whether a given user has WebAuthn registrations
func HasWebAuthnRegistrationsByUID(ctx context.Context, uid int64) (bool, error) {
return db.GetEngine(ctx).Where("user_id = ?", uid).Exist(&WebAuthnCredential{})
}
// GetWebAuthnCredentialByCredID returns WebAuthn credential by credential ID
func GetWebAuthnCredentialByCredID(ctx context.Context, userID int64, credID []byte) (*WebAuthnCredential, error) {
cred := new(WebAuthnCredential)
if found, err := db.GetEngine(ctx).Where("user_id = ? AND credential_id = ?", userID, credID).Get(cred); err != nil {
return nil, err
} else if !found {
return nil, ErrWebAuthnCredentialNotExist{CredentialID: credID}
}
return cred, nil
}
// CreateCredential will create a new WebAuthnCredential from the given Credential
func CreateCredential(ctx context.Context, userID int64, name string, cred *webauthn.Credential) (*WebAuthnCredential, error) {
c := &WebAuthnCredential{
UserID: userID,
Name: name,
CredentialID: cred.ID,
PublicKey: cred.PublicKey,
AttestationType: cred.AttestationType,
AAGUID: cred.Authenticator.AAGUID,
SignCount: cred.Authenticator.SignCount,
CloneWarning: false,
}
if err := db.Insert(ctx, c); err != nil {
return nil, err
}
return c, nil
}
// DeleteCredential will delete WebAuthnCredential
func DeleteCredential(ctx context.Context, id, userID int64) (bool, error) {
had, err := db.GetEngine(ctx).ID(id).Where("user_id = ?", userID).Delete(&WebAuthnCredential{})
return had > 0, err
}
// WebAuthnCredentials implements the webauthn.User interface
func WebAuthnCredentials(ctx context.Context, userID int64) ([]webauthn.Credential, error) {
dbCreds, err := GetWebAuthnCredentialsByUID(ctx, userID)
if err != nil {
return nil, err
}
return dbCreds.ToCredentials(), nil
}


@@ -0,0 +1,67 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth_test
import (
"testing"
auth_model "code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"github.com/go-webauthn/webauthn/webauthn"
"github.com/stretchr/testify/assert"
)
func TestGetWebAuthnCredentialByID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
res, err := auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 1)
assert.NoError(t, err)
assert.Equal(t, "WebAuthn credential", res.Name)
_, err = auth_model.GetWebAuthnCredentialByID(db.DefaultContext, 342432)
assert.Error(t, err)
assert.True(t, auth_model.IsErrWebAuthnCredentialNotExist(err))
}
func TestGetWebAuthnCredentialsByUID(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
res, err := auth_model.GetWebAuthnCredentialsByUID(db.DefaultContext, 32)
assert.NoError(t, err)
assert.Len(t, res, 1)
assert.Equal(t, "WebAuthn credential", res[0].Name)
}
func TestWebAuthnCredential_TableName(t *testing.T) {
assert.Equal(t, "webauthn_credential", auth_model.WebAuthnCredential{}.TableName())
}
func TestWebAuthnCredential_UpdateSignCount(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
cred.SignCount = 1
assert.NoError(t, cred.UpdateSignCount(db.DefaultContext))
unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 1})
}
func TestWebAuthnCredential_UpdateLargeCounter(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
cred := unittest.AssertExistsAndLoadBean(t, &auth_model.WebAuthnCredential{ID: 1})
cred.SignCount = 0xffffffff
assert.NoError(t, cred.UpdateSignCount(db.DefaultContext))
unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{ID: 1, SignCount: 0xffffffff})
}
func TestCreateCredential(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
res, err := auth_model.CreateCredential(db.DefaultContext, 1, "WebAuthn Created Credential", &webauthn.Credential{ID: []byte("Test")})
assert.NoError(t, err)
assert.Equal(t, "WebAuthn Created Credential", res.Name)
assert.Equal(t, []byte("Test"), res.CredentialID)
unittest.AssertExistsIf(t, true, &auth_model.WebAuthnCredential{Name: "WebAuthn Created Credential", UserID: 1})
}

238
models/avatars/avatar.go Normal file

@@ -0,0 +1,238 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package avatars
import (
"context"
"crypto/md5"
"encoding/hex"
"fmt"
"net/url"
"path"
"strconv"
"strings"
"sync/atomic"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/cache"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"strk.kbt.io/projects/go/libravatar"
)
const (
// DefaultAvatarClass is the default class of a rendered avatar
DefaultAvatarClass = "ui avatar tw-align-middle"
// DefaultAvatarPixelSize is the default size in pixels of a rendered avatar
DefaultAvatarPixelSize = 28
)
// EmailHash represents a pre-generated email/hash pair (mainly used by LibravatarURL, which queries the email server's DNS records)
type EmailHash struct {
Hash string `xorm:"pk varchar(32)"`
Email string `xorm:"UNIQUE NOT NULL"`
}
func init() {
db.RegisterModel(new(EmailHash))
}
type avatarSettingStruct struct {
defaultAvatarLink string
gravatarSource string
gravatarSourceURL *url.URL
libravatar *libravatar.Libravatar
}
var avatarSettingAtomic atomic.Pointer[avatarSettingStruct]
func loadAvatarSetting() (*avatarSettingStruct, error) {
s := avatarSettingAtomic.Load()
if s == nil || s.gravatarSource != setting.GravatarSource {
s = &avatarSettingStruct{}
u, err := url.Parse(setting.AppSubURL)
if err != nil {
return nil, fmt.Errorf("unable to parse AppSubURL: %w", err)
}
u.Path = path.Join(u.Path, "/assets/img/avatar_default.png")
s.defaultAvatarLink = u.String()
s.gravatarSourceURL, err = url.Parse(setting.GravatarSource)
if err != nil {
return nil, fmt.Errorf("unable to parse GravatarSource %q: %w", setting.GravatarSource, err)
}
s.libravatar = libravatar.New()
if s.gravatarSourceURL.Scheme == "https" {
s.libravatar.SetUseHTTPS(true)
s.libravatar.SetSecureFallbackHost(s.gravatarSourceURL.Host)
} else {
s.libravatar.SetUseHTTPS(false)
s.libravatar.SetFallbackHost(s.gravatarSourceURL.Host)
}
avatarSettingAtomic.Store(s)
}
return s, nil
}
// DefaultAvatarLink returns the default avatar link
func DefaultAvatarLink() string {
a, err := loadAvatarSetting()
if err != nil {
log.Error("Failed to loadAvatarSetting: %v", err)
return ""
}
return a.defaultAvatarLink
}
// HashEmail hashes email address to MD5 string. https://en.gravatar.com/site/implement/hash/
func HashEmail(email string) string {
m := md5.New()
_, _ = m.Write([]byte(strings.ToLower(strings.TrimSpace(email))))
return hex.EncodeToString(m.Sum(nil))
}
// GetEmailForHash converts a provided md5sum to the email
func GetEmailForHash(ctx context.Context, md5Sum string) (string, error) {
return cache.GetString("Avatar:"+md5Sum, func() (string, error) {
emailHash := EmailHash{
Hash: strings.ToLower(strings.TrimSpace(md5Sum)),
}
_, err := db.GetEngine(ctx).Get(&emailHash)
return emailHash.Email, err
})
}
// LibravatarURL returns the URL for the given email. Slow due to the DNS lookup.
// This function should only be called if a federated avatar service is enabled.
func LibravatarURL(email string) (*url.URL, error) {
a, err := loadAvatarSetting()
if err != nil {
return nil, err
}
urlStr, err := a.libravatar.FromEmail(email)
if err != nil {
log.Error("LibravatarService.FromEmail(email=%s): error %v", email, err)
return nil, err
}
u, err := url.Parse(urlStr)
if err != nil {
log.Error("Failed to parse libravatar url(%s): error %v", urlStr, err)
return nil, err
}
return u, nil
}
// saveEmailHash saves the lowercased email and its hash into the database (for later lookup
// by GetEmailForHash) and returns the hash for the provided email
func saveEmailHash(ctx context.Context, email string) string {
lowerEmail := strings.ToLower(strings.TrimSpace(email))
emailHash := HashEmail(lowerEmail)
_, _ = cache.GetString("Avatar:"+emailHash, func() (string, error) {
emailHash := &EmailHash{
Email: lowerEmail,
Hash: emailHash,
}
// Open a transaction mainly to hide away any problems with postgres reporting errors on conflicting inserts
if err := db.WithTx(ctx, func(ctx context.Context) error {
has, err := db.GetEngine(ctx).Where("email = ? AND hash = ?", emailHash.Email, emailHash.Hash).Get(new(EmailHash))
if has || err != nil {
// Ignore any DB problem and just return lowerEmail; the row usually already exists, so the transaction is expected to be a no-op most of the time
return nil
}
_, _ = db.GetEngine(ctx).Insert(emailHash)
return nil
}); err != nil {
// Ignore any DB problem and just return lowerEmail; the transaction is expected to fail most of the time because the row already exists
return lowerEmail, nil
}
return lowerEmail, nil
})
return emailHash
}
// GenerateUserAvatarFastLink returns a fast link (302) to the user's avatar: "/user/avatar/${User.Name}/${size}"
func GenerateUserAvatarFastLink(userName string, size int) string {
if size < 0 {
size = 0
}
return setting.AppSubURL + "/user/avatar/" + url.PathEscape(userName) + "/" + strconv.Itoa(size)
}
// GenerateUserAvatarImageLink returns a link for `User.Avatar` image file: "/avatars/${User.Avatar}"
func GenerateUserAvatarImageLink(userAvatar string, size int) string {
if size > 0 {
return setting.AppSubURL + "/avatars/" + url.PathEscape(userAvatar) + "?size=" + strconv.Itoa(size)
}
return setting.AppSubURL + "/avatars/" + url.PathEscape(userAvatar)
}
// generateRecognizedAvatarURL generates a recognized avatar (Gravatar/Libravatar) URL; it modifies the URL, so the parameter is passed by copy
func generateRecognizedAvatarURL(u url.URL, size int) string {
urlQuery := u.Query()
urlQuery.Set("d", "identicon")
if size > 0 {
urlQuery.Set("s", strconv.Itoa(size))
}
u.RawQuery = urlQuery.Encode()
return u.String()
}
// generateEmailAvatarLink returns an email avatar link.
// if final is true, it may use a slow path (eg: query DNS).
// if final is false, it always uses a fast path.
func generateEmailAvatarLink(ctx context.Context, email string, size int, final bool) string {
email = strings.TrimSpace(email)
if email == "" {
return DefaultAvatarLink()
}
avatarSetting, err := loadAvatarSetting()
if err != nil {
return DefaultAvatarLink()
}
enableFederatedAvatar := setting.Config().Picture.EnableFederatedAvatar.Value(ctx)
if enableFederatedAvatar {
emailHash := saveEmailHash(ctx, email)
if final {
// for final link, we can spend more time on slow external query
var avatarURL *url.URL
if avatarURL, err = LibravatarURL(email); err != nil {
return DefaultAvatarLink()
}
return generateRecognizedAvatarURL(*avatarURL, size)
}
// for non-final link, we should return fast (use a 302 redirection link)
urlStr := setting.AppSubURL + "/avatar/" + url.PathEscape(emailHash)
if size > 0 {
urlStr += "?size=" + strconv.Itoa(size)
}
return urlStr
}
disableGravatar := setting.Config().Picture.DisableGravatar.Value(ctx)
if !disableGravatar {
// copy GravatarSourceURL, because we will modify its Path.
avatarURLCopy := *avatarSetting.gravatarSourceURL
avatarURLCopy.Path = path.Join(avatarURLCopy.Path, HashEmail(email))
return generateRecognizedAvatarURL(avatarURLCopy, size)
}
return DefaultAvatarLink()
}
// GenerateEmailAvatarFastLink returns an avatar link (fast, the link may be a delegated one: "/avatar/${hash}")
func GenerateEmailAvatarFastLink(ctx context.Context, email string, size int) string {
return generateEmailAvatarLink(ctx, email, size, false)
}
// GenerateEmailAvatarFinalLink returns a final avatar link (may be slow)
func GenerateEmailAvatarFinalLink(ctx context.Context, email string, size int) string {
return generateEmailAvatarLink(ctx, email, size, true)
}
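
For orientation, here is a minimal usage sketch of the link helpers above. It assumes Gitea's settings (and, for the federated path, the database) have already been initialized; the email address and the doubled pixel size are illustrative only.

package main

import (
	"context"
	"fmt"

	avatars_model "code.gitea.io/gitea/models/avatars"
)

func main() {
	ctx := context.Background() // in real Gitea code this is usually a request or db context

	// Fast link: cheap to compute, may be a delegated "/avatar/${hash}" redirect.
	fast := avatars_model.GenerateEmailAvatarFastLink(ctx, "gitea@example.com", 2*avatars_model.DefaultAvatarPixelSize)

	// Final link: may take the slow path (Libravatar DNS lookup) when federated avatars are enabled.
	final := avatars_model.GenerateEmailAvatarFinalLink(ctx, "gitea@example.com", 2*avatars_model.DefaultAvatarPixelSize)

	fmt.Println(fast, final)
}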


@@ -0,0 +1,58 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package avatars_test
import (
"testing"
avatars_model "code.gitea.io/gitea/models/avatars"
"code.gitea.io/gitea/models/db"
system_model "code.gitea.io/gitea/models/system"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/setting/config"
"github.com/stretchr/testify/assert"
)
const gravatarSource = "https://secure.gravatar.com/avatar/"
func disableGravatar(t *testing.T) {
err := system_model.SetSettings(db.DefaultContext, map[string]string{setting.Config().Picture.EnableFederatedAvatar.DynKey(): "false"})
assert.NoError(t, err)
err = system_model.SetSettings(db.DefaultContext, map[string]string{setting.Config().Picture.DisableGravatar.DynKey(): "true"})
assert.NoError(t, err)
}
func enableGravatar(t *testing.T) {
err := system_model.SetSettings(db.DefaultContext, map[string]string{setting.Config().Picture.DisableGravatar.DynKey(): "false"})
assert.NoError(t, err)
setting.GravatarSource = gravatarSource
}
func TestHashEmail(t *testing.T) {
assert.Equal(t,
"d41d8cd98f00b204e9800998ecf8427e",
avatars_model.HashEmail(""),
)
assert.Equal(t,
"353cbad9b58e69c96154ad99f92bedc7",
avatars_model.HashEmail("gitea@example.com"),
)
}
func TestSizedAvatarLink(t *testing.T) {
setting.AppSubURL = "/testsuburl"
disableGravatar(t)
config.GetDynGetter().InvalidateCache()
assert.Equal(t, "/testsuburl/assets/img/avatar_default.png",
avatars_model.GenerateEmailAvatarFastLink(db.DefaultContext, "gitea@example.com", 100))
enableGravatar(t)
config.GetDynGetter().InvalidateCache()
assert.Equal(t,
"https://secure.gravatar.com/avatar/353cbad9b58e69c96154ad99f92bedc7?d=identicon&s=100",
avatars_model.GenerateEmailAvatarFastLink(db.DefaultContext, "gitea@example.com", 100),
)
}


@@ -0,0 +1,18 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package avatars_test
import (
"testing"
"code.gitea.io/gitea/models/unittest"
_ "code.gitea.io/gitea/models"
_ "code.gitea.io/gitea/models/activities"
_ "code.gitea.io/gitea/models/perm/access"
)
func TestMain(m *testing.M) {
unittest.MainTest(m)
}

191
models/db/collation.go Normal file

@@ -0,0 +1,191 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"errors"
"fmt"
"strings"
"code.gitea.io/gitea/modules/container"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"xorm.io/xorm"
"xorm.io/xorm/schemas"
)
type CheckCollationsResult struct {
ExpectedCollation string
AvailableCollation container.Set[string]
DatabaseCollation string
IsCollationCaseSensitive func(s string) bool
CollationEquals func(a, b string) bool
ExistingTableNumber int
InconsistentCollationColumns []string
}
func findAvailableCollationsMySQL(x *xorm.Engine) (ret container.Set[string], err error) {
var res []struct {
Collation string
}
if err = x.SQL("SHOW COLLATION WHERE (Collation = 'utf8mb4_bin') OR (Collation LIKE '%\\_as\\_cs%')").Find(&res); err != nil {
return nil, err
}
ret = make(container.Set[string], len(res))
for _, r := range res {
ret.Add(r.Collation)
}
return ret, nil
}
func findAvailableCollationsMSSQL(x *xorm.Engine) (ret container.Set[string], err error) {
var res []struct {
Name string
}
if err = x.SQL("SELECT * FROM sys.fn_helpcollations() WHERE name LIKE '%[_]CS[_]AS%'").Find(&res); err != nil {
return nil, err
}
ret = make(container.Set[string], len(res))
for _, r := range res {
ret.Add(r.Name)
}
return ret, nil
}
func CheckCollations(x *xorm.Engine) (*CheckCollationsResult, error) {
dbTables, err := x.DBMetas()
if err != nil {
return nil, err
}
res := &CheckCollationsResult{
ExistingTableNumber: len(dbTables),
CollationEquals: func(a, b string) bool { return a == b },
}
var candidateCollations []string
if x.Dialect().URI().DBType == schemas.MYSQL {
_, err = x.SQL("SELECT DEFAULT_COLLATION_NAME FROM INFORMATION_SCHEMA.SCHEMATA WHERE SCHEMA_NAME = ?", setting.Database.Name).Get(&res.DatabaseCollation)
if err != nil {
return nil, err
}
res.IsCollationCaseSensitive = func(s string) bool {
return s == "utf8mb4_bin" || strings.HasSuffix(s, "_as_cs")
}
candidateCollations = []string{"utf8mb4_0900_as_cs", "uca1400_as_cs", "utf8mb4_bin"}
res.AvailableCollation, err = findAvailableCollationsMySQL(x)
if err != nil {
return nil, err
}
res.CollationEquals = func(a, b string) bool {
// MariaDB adds the "utf8mb4_" prefix, eg: "utf8mb4_uca1400_as_cs", but not the name "uca1400_as_cs" in "SHOW COLLATION"
// At the moment, it's safe to ignore the database difference, just trim the prefix and compare. It could be fixed easily if there is any problem in the future.
return a == b || strings.TrimPrefix(a, "utf8mb4_") == strings.TrimPrefix(b, "utf8mb4_")
}
} else if x.Dialect().URI().DBType == schemas.MSSQL {
if _, err = x.SQL("SELECT DATABASEPROPERTYEX(DB_NAME(), 'Collation')").Get(&res.DatabaseCollation); err != nil {
return nil, err
}
res.IsCollationCaseSensitive = func(s string) bool {
return strings.HasSuffix(s, "_CS_AS")
}
candidateCollations = []string{"Latin1_General_CS_AS"}
res.AvailableCollation, err = findAvailableCollationsMSSQL(x)
if err != nil {
return nil, err
}
} else {
return nil, nil
}
if res.DatabaseCollation == "" {
return nil, errors.New("unable to get collation for current database")
}
res.ExpectedCollation = setting.Database.CharsetCollation
if res.ExpectedCollation == "" {
for _, collation := range candidateCollations {
if res.AvailableCollation.Contains(collation) {
res.ExpectedCollation = collation
break
}
}
}
if res.ExpectedCollation == "" {
return nil, errors.New("unable to find a suitable collation for current database")
}
allColumnsMatchExpected := true
allColumnsMatchDatabase := true
for _, table := range dbTables {
for _, col := range table.Columns() {
if col.Collation != "" {
allColumnsMatchExpected = allColumnsMatchExpected && res.CollationEquals(col.Collation, res.ExpectedCollation)
allColumnsMatchDatabase = allColumnsMatchDatabase && res.CollationEquals(col.Collation, res.DatabaseCollation)
if !res.IsCollationCaseSensitive(col.Collation) || !res.CollationEquals(col.Collation, res.DatabaseCollation) {
res.InconsistentCollationColumns = append(res.InconsistentCollationColumns, fmt.Sprintf("%s.%s", table.Name, col.Name))
}
}
}
}
// if all columns match expected collation or all match database collation, then it could also be considered as "consistent"
if allColumnsMatchExpected || allColumnsMatchDatabase {
res.InconsistentCollationColumns = nil
}
return res, nil
}
func CheckCollationsDefaultEngine() (*CheckCollationsResult, error) {
return CheckCollations(x)
}
func alterDatabaseCollation(x *xorm.Engine, collation string) error {
if x.Dialect().URI().DBType == schemas.MYSQL {
_, err := x.Exec("ALTER DATABASE CHARACTER SET utf8mb4 COLLATE " + collation)
return err
} else if x.Dialect().URI().DBType == schemas.MSSQL {
// TODO: MSSQL has many limitations on changing database collation, it could fail in many cases.
_, err := x.Exec("ALTER DATABASE CURRENT COLLATE " + collation)
return err
}
return errors.New("unsupported database type")
}
// preprocessDatabaseCollation checks database & table column collation, and alters the database collation if needed
func preprocessDatabaseCollation(x *xorm.Engine) {
r, err := CheckCollations(x)
if err != nil {
log.Error("Failed to check database collation: %v", err)
}
if r == nil {
return // no check result means the database doesn't need to do such check/process (at the moment ....)
}
// try to alter database collation to expected if the database is empty, it might fail in some cases (and it isn't necessary to succeed)
// at the moment, there is no "altering" solution for MSSQL, site admin should manually change the database collation
if !r.CollationEquals(r.DatabaseCollation, r.ExpectedCollation) && r.ExistingTableNumber == 0 {
if err = alterDatabaseCollation(x, r.ExpectedCollation); err != nil {
log.Error("Failed to change database collation to %q: %v", r.ExpectedCollation, err)
} else {
_, _ = x.Exec("SELECT 1") // after "altering", MSSQL's session becomes invalid, so make a simple query to "refresh" the session
if r, err = CheckCollations(x); err != nil {
log.Error("Failed to check database collation again after altering: %v", err) // impossible case
return
}
log.Warn("Current database has been altered to use collation %q", r.DatabaseCollation)
}
}
// check column collation, and show warning/error to end users -- no need to fatal, do not block the startup
if !r.IsCollationCaseSensitive(r.DatabaseCollation) {
log.Warn("Current database is using a case-insensitive collation %q. Although Gitea could work with it, there might be some rare cases which don't work as expected.", r.DatabaseCollation)
}
if len(r.InconsistentCollationColumns) > 0 {
log.Error("There are %d table columns using inconsistent collation, they should use %q. Please go to the admin panel's Self Check page", len(r.InconsistentCollationColumns), r.DatabaseCollation)
}
}
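
CheckCollationsDefaultEngine exposes the same check that preprocessDatabaseCollation runs at startup. Below is a small sketch of a caller consuming the result; it assumes the default engine has been initialized, and the reporting function itself is illustrative, not part of Gitea.

package main

import (
	"fmt"
	"log"

	"code.gitea.io/gitea/models/db"
)

// reportCollations is an illustrative diagnostic built on CheckCollationsDefaultEngine.
func reportCollations() {
	res, err := db.CheckCollationsDefaultEngine()
	if err != nil {
		log.Fatalf("collation check failed: %v", err)
	}
	if res == nil {
		fmt.Println("collation check is not applicable to this database type")
		return
	}
	fmt.Printf("database collation: %s (expected: %s)\n", res.DatabaseCollation, res.ExpectedCollation)
	for _, col := range res.InconsistentCollationColumns {
		fmt.Printf("inconsistent column: %s\n", col)
	}
}

func main() { reportCollations() }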

55
models/db/common.go Normal file

@@ -0,0 +1,55 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"strings"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"
"xorm.io/builder"
)
// BuildCaseInsensitiveLike returns a condition to check if the given column (key) matches the given value case-insensitively.
// SQLite needs special handling, since its UPPER only transforms ASCII letters.
func BuildCaseInsensitiveLike(key, value string) builder.Cond {
if setting.Database.Type.IsSQLite3() {
return builder.Like{"UPPER(" + key + ")", util.ToUpperASCII(value)}
}
return builder.Like{"UPPER(" + key + ")", strings.ToUpper(value)}
}
// BuildCaseInsensitiveIn returns a condition to check if the given column (key) is in the given values case-insensitively.
// SQLite needs special handling, since its UPPER only transforms ASCII letters.
func BuildCaseInsensitiveIn(key string, values []string) builder.Cond {
uppers := make([]string, 0, len(values))
if setting.Database.Type.IsSQLite3() {
for _, value := range values {
uppers = append(uppers, util.ToUpperASCII(value))
}
} else {
for _, value := range values {
uppers = append(uppers, strings.ToUpper(value))
}
}
return builder.In("UPPER("+key+")", uppers)
}
// BuilderDialect returns the xorm.Builder dialect of the engine
func BuilderDialect() string {
switch {
case setting.Database.Type.IsMySQL():
return builder.MYSQL
case setting.Database.Type.IsSQLite3():
return builder.SQLITE
case setting.Database.Type.IsPostgreSQL():
return builder.POSTGRES
case setting.Database.Type.IsMSSQL():
return builder.MSSQL
default:
return ""
}
}
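
A short sketch of how these helpers plug into a query. The User bean, the "name" column and the search function are illustrative only; they are not part of this file.

package main

import (
	"context"

	"code.gitea.io/gitea/models/db"

	"xorm.io/builder"
)

// User is an illustrative bean with a "name" column.
type User struct {
	ID   int64
	Name string
}

// SearchUsersNoCase matches either a case-insensitive LIKE on the keyword
// or a case-insensitive IN on the exact names.
func SearchUsersNoCase(ctx context.Context, keyword string, names []string) ([]*User, error) {
	users := make([]*User, 0, 10)
	cond := builder.Or(
		db.BuildCaseInsensitiveLike("name", keyword), // builder.Like adds the surrounding wildcards
		db.BuildCaseInsensitiveIn("name", names),
	)
	return users, db.GetEngine(ctx).Where(cond).Find(&users)
}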

31
models/db/consistency.go Normal file

@@ -0,0 +1,31 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"xorm.io/builder"
)
// CountOrphanedObjects counts subjects that no longer have an existing refobject
func CountOrphanedObjects(ctx context.Context, subject, refObject, joinCond string) (int64, error) {
return GetEngine(ctx).
Table("`"+subject+"`").
Join("LEFT", "`"+refObject+"`", joinCond).
Where(builder.IsNull{"`" + refObject + "`.id"}).
Select("COUNT(`" + subject + "`.`id`)").
Count()
}
// DeleteOrphanedObjects deletes subjects that no longer have an existing refobject
func DeleteOrphanedObjects(ctx context.Context, subject, refObject, joinCond string) error {
subQuery := builder.Select("`"+subject+"`.id").
From("`"+subject+"`").
Join("LEFT", "`"+refObject+"`", joinCond).
Where(builder.IsNull{"`" + refObject + "`.id"})
b := builder.Delete(builder.In("id", subQuery)).From("`" + subject + "`")
_, err := GetEngine(ctx).Exec(b)
return err
}
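
The two helpers are normally used as a pair: count first (for reporting), then delete. The sketch below reuses the pull_request/issue join that the engine tests further down use; it assumes the default engine has been initialized.

package main

import (
	"context"
	"fmt"

	"code.gitea.io/gitea/models/db"
)

// cleanOrphanedPullRequests removes pull_request rows whose issue no longer exists.
func cleanOrphanedPullRequests(ctx context.Context) error {
	joinCond := "pull_request.issue_id=issue.id"
	count, err := db.CountOrphanedObjects(ctx, "pull_request", "issue", joinCond)
	if err != nil {
		return err
	}
	fmt.Printf("found %d orphaned pull_request rows\n", count)
	if count == 0 {
		return nil
	}
	return db.DeleteOrphanedObjects(ctx, "pull_request", "issue", joinCond)
}

func main() {
	if err := cleanOrphanedPullRequests(db.DefaultContext); err != nil {
		fmt.Println("cleanup failed:", err)
	}
}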

351
models/db/context.go Normal file

@@ -0,0 +1,351 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"database/sql"
"errors"
"runtime"
"slices"
"sync"
"code.gitea.io/gitea/modules/setting"
"xorm.io/builder"
"xorm.io/xorm"
)
// DefaultContext is the default context to run xorm queries in
// will be overwritten by Init with HammerContext
var DefaultContext context.Context
type engineContextKeyType struct{}
var engineContextKey = engineContextKeyType{}
// Context represents a db context
type Context struct {
context.Context
engine Engine
}
func newContext(ctx context.Context, e Engine) *Context {
return &Context{Context: ctx, engine: e}
}
// Value shadows Value for context.Context but allows us to get ourselves and an Engined object
func (ctx *Context) Value(key any) any {
if key == engineContextKey {
return ctx
}
return ctx.Context.Value(key)
}
// WithContext returns this engine tied to this context
func (ctx *Context) WithContext(other context.Context) *Context {
return newContext(ctx, ctx.engine.Context(other))
}
var (
contextSafetyOnce sync.Once
contextSafetyDeniedFuncPCs []uintptr
)
func contextSafetyCheck(e Engine) {
if setting.IsProd && !setting.IsInTesting {
return
}
if e == nil {
return
}
// Only do this check for non-end-users. If the problem could be fixed in the future, this code could be removed.
contextSafetyOnce.Do(func() {
// try to figure out the bad functions to deny
type m struct{}
_ = e.SQL("SELECT 1").Iterate(&m{}, func(int, any) error {
callers := make([]uintptr, 32)
callerNum := runtime.Callers(1, callers)
for i := 0; i < callerNum; i++ {
if funcName := runtime.FuncForPC(callers[i]).Name(); funcName == "xorm.io/xorm.(*Session).Iterate" {
contextSafetyDeniedFuncPCs = append(contextSafetyDeniedFuncPCs, callers[i])
}
}
return nil
})
if len(contextSafetyDeniedFuncPCs) != 1 {
panic(errors.New("unable to determine the functions to deny"))
}
})
// it should be very fast: xxxx ns/op
callers := make([]uintptr, 32)
callerNum := runtime.Callers(3, callers) // skip 3: runtime.Callers, contextSafetyCheck, GetEngine
for i := 0; i < callerNum; i++ {
if slices.Contains(contextSafetyDeniedFuncPCs, callers[i]) {
panic(errors.New("using database context in an iterator would cause corrupted results"))
}
}
}
// GetEngine gets an existing db Engine/Statement or creates a new Session
func GetEngine(ctx context.Context) Engine {
if e := getExistingEngine(ctx); e != nil {
return e
}
return x.Context(ctx)
}
// getExistingEngine gets an existing db Engine/Statement from this context or returns nil
func getExistingEngine(ctx context.Context) (e Engine) {
defer func() { contextSafetyCheck(e) }()
if engined, ok := ctx.(*Context); ok {
return engined.engine
}
if engined, ok := ctx.Value(engineContextKey).(*Context); ok {
return engined.engine
}
return nil
}
// Committer represents an interface to Commit or Close the Context
type Committer interface {
Commit() error
Close() error
}
// halfCommitter is a wrapper of Committer.
// It can be closed early, but can't be committed early; it is useful for reusing a transaction.
type halfCommitter struct {
committer Committer
committed bool
}
func (c *halfCommitter) Commit() error {
c.committed = true
// should do nothing, and the parent committer will commit later
return nil
}
func (c *halfCommitter) Close() error {
if c.committed {
// it's "commit and close", should do nothing, and the parent committer will commit later
return nil
}
// it's "rollback and close", let the parent committer rollback right now
return c.committer.Close()
}
// TxContext represents a transaction Context,
// it will reuse the existing transaction in the parent context or create a new one.
// Some tips to use:
//
// 1. It's always recommended to use `WithTx` in new code instead of `TxContext`, since `WithTx` will handle the transaction automatically.
// 2. To maintain the old code which uses `TxContext`:
// a. Always call `Close()` before returning regardless of whether `Commit()` has been called.
// b. Always call `Commit()` before returning if there are no errors, even if the code did not change any data.
// c. Remember the `Committer` will be a halfCommitter when a transaction is being reused.
// So calling `Commit()` will do nothing, but calling `Close()` without calling `Commit()` will roll back the transaction.
// And all operations submitted by the caller stack will be rolled back as well, not only the operations in the current function.
// d. It doesn't mean rollback is forbidden, but always do it only when there is an error, and you do want to rollback.
func TxContext(parentCtx context.Context) (*Context, Committer, error) {
if sess, ok := inTransaction(parentCtx); ok {
return newContext(parentCtx, sess), &halfCommitter{committer: sess}, nil
}
sess := x.NewSession()
if err := sess.Begin(); err != nil {
_ = sess.Close()
return nil, nil, err
}
return newContext(DefaultContext, sess), sess, nil
}
// WithTx represents executing database operations on a transaction: if a transaction already exists in the parent context,
// this function will reuse it, otherwise it will create a new one and close it when finished.
func WithTx(parentCtx context.Context, f func(ctx context.Context) error) error {
if sess, ok := inTransaction(parentCtx); ok {
err := f(newContext(parentCtx, sess))
if err != nil {
// rollback immediately, in case the caller ignores returned error and tries to commit the transaction.
_ = sess.Close()
}
return err
}
return txWithNoCheck(parentCtx, f)
}
func txWithNoCheck(parentCtx context.Context, f func(ctx context.Context) error) error {
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
}
if err := f(newContext(parentCtx, sess)); err != nil {
return err
}
return sess.Commit()
}
// Insert inserts records into database
func Insert(ctx context.Context, beans ...any) error {
_, err := GetEngine(ctx).Insert(beans...)
return err
}
// Exec executes a sql with args
func Exec(ctx context.Context, sqlAndArgs ...any) (sql.Result, error) {
return GetEngine(ctx).Exec(sqlAndArgs...)
}
func Get[T any](ctx context.Context, cond builder.Cond) (object *T, exist bool, err error) {
if !cond.IsValid() {
panic("cond is invalid in db.Get(ctx, cond). This should not be possible.")
}
var bean T
has, err := GetEngine(ctx).Where(cond).NoAutoCondition().Get(&bean)
if err != nil {
return nil, false, err
} else if !has {
return nil, false, nil
}
return &bean, true, nil
}
func GetByID[T any](ctx context.Context, id int64) (object *T, exist bool, err error) {
var bean T
has, err := GetEngine(ctx).ID(id).NoAutoCondition().Get(&bean)
if err != nil {
return nil, false, err
} else if !has {
return nil, false, nil
}
return &bean, true, nil
}
func Exist[T any](ctx context.Context, cond builder.Cond) (bool, error) {
if !cond.IsValid() {
panic("cond is invalid in db.Exist(ctx, cond). This should not be possible.")
}
var bean T
return GetEngine(ctx).Where(cond).NoAutoCondition().Exist(&bean)
}
func ExistByID[T any](ctx context.Context, id int64) (bool, error) {
var bean T
return GetEngine(ctx).ID(id).NoAutoCondition().Exist(&bean)
}
// DeleteByID deletes the given bean with the given ID
func DeleteByID[T any](ctx context.Context, id int64) (int64, error) {
var bean T
return GetEngine(ctx).ID(id).NoAutoCondition().NoAutoTime().Delete(&bean)
}
func DeleteByIDs[T any](ctx context.Context, ids ...int64) error {
if len(ids) == 0 {
return nil
}
var bean T
_, err := GetEngine(ctx).In("id", ids).NoAutoCondition().NoAutoTime().Delete(&bean)
return err
}
func Delete[T any](ctx context.Context, opts FindOptions) (int64, error) {
if opts == nil || !opts.ToConds().IsValid() {
panic("opts are empty or invalid in db.Delete(ctx, opts). This should not be possible.")
}
var bean T
return GetEngine(ctx).Where(opts.ToConds()).NoAutoCondition().NoAutoTime().Delete(&bean)
}
// DeleteByBean deletes all records according to the non-empty fields of the bean as conditions.
func DeleteByBean(ctx context.Context, bean any) (int64, error) {
return GetEngine(ctx).Delete(bean)
}
// FindIDs finds the IDs for the given table name satisfying the given condition
// By passing a different value than "id" for "idCol", you can query for foreign IDs, i.e. the repo IDs which satisfy the condition
func FindIDs(ctx context.Context, tableName, idCol string, cond builder.Cond) ([]int64, error) {
ids := make([]int64, 0, 10)
if err := GetEngine(ctx).Table(tableName).
Cols(idCol).
Where(cond).
Find(&ids); err != nil {
return nil, err
}
return ids, nil
}
// DecrByIDs decreases the given column by one for entities of the "bean" type whose ID is one of the given ids
// Timestamps of the entities won't be updated
func DecrByIDs(ctx context.Context, ids []int64, decrCol string, bean any) error {
_, err := GetEngine(ctx).Decr(decrCol).In("id", ids).NoAutoCondition().NoAutoTime().Update(bean)
return err
}
// DeleteBeans deletes all given beans, beans must contain delete conditions.
func DeleteBeans(ctx context.Context, beans ...any) (err error) {
e := GetEngine(ctx)
for i := range beans {
if _, err = e.Delete(beans[i]); err != nil {
return err
}
}
return nil
}
// TruncateBeans deletes all given beans, beans may contain delete conditions.
func TruncateBeans(ctx context.Context, beans ...any) (err error) {
e := GetEngine(ctx)
for i := range beans {
if _, err = e.Truncate(beans[i]); err != nil {
return err
}
}
return nil
}
// CountByBean counts the number of database records according to the non-empty fields of the bean as conditions.
func CountByBean(ctx context.Context, bean any) (int64, error) {
return GetEngine(ctx).Count(bean)
}
// TableName returns the table name according to a bean object
func TableName(bean any) string {
return x.TableName(bean)
}
// InTransaction returns true if the engine is in a transaction otherwise return false
func InTransaction(ctx context.Context) bool {
_, ok := inTransaction(ctx)
return ok
}
func inTransaction(ctx context.Context) (*xorm.Session, bool) {
e := getExistingEngine(ctx)
if e == nil {
return nil, false
}
switch t := e.(type) {
case *xorm.Engine:
return nil, false
case *xorm.Session:
if t.IsInTx() {
return t, true
}
return nil, false
default:
return nil, false
}
}
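
The doc comment on TxContext above spells out the rules for reusing transactions; the sketch below shows the recommended WithTx form next to the older TxContext pattern, together with the generic helpers. The Webhook bean and both functions are illustrative only.

package main

import (
	"context"

	"code.gitea.io/gitea/models/db"
)

// Webhook is an illustrative bean; any registered model with an int64 ID works the same way.
type Webhook struct {
	ID     int64
	RepoID int64
	URL    string
}

// createTwoWebhooks shows the recommended WithTx form: both inserts commit or roll back together.
func createTwoWebhooks(ctx context.Context, repoID int64) error {
	return db.WithTx(ctx, func(ctx context.Context) error {
		if err := db.Insert(ctx, &Webhook{RepoID: repoID, URL: "https://example.com/a"}); err != nil {
			return err
		}
		return db.Insert(ctx, &Webhook{RepoID: repoID, URL: "https://example.com/b"})
	})
}

// deleteWebhook shows the older TxContext pattern described above:
// always defer Close, and Commit only on the success path.
func deleteWebhook(ctx context.Context, id int64) error {
	ctx, committer, err := db.TxContext(ctx)
	if err != nil {
		return err
	}
	defer committer.Close()

	if _, err := db.DeleteByID[Webhook](ctx, id); err != nil {
		return err
	}
	return committer.Commit()
}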


@@ -0,0 +1,102 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db // it's not db_test, because this file is for testing the private type halfCommitter
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
type MockCommitter struct {
wants []string
gots []string
}
func NewMockCommitter(wants ...string) *MockCommitter {
return &MockCommitter{
wants: wants,
}
}
func (c *MockCommitter) Commit() error {
c.gots = append(c.gots, "commit")
return nil
}
func (c *MockCommitter) Close() error {
c.gots = append(c.gots, "close")
return nil
}
func (c *MockCommitter) Assert(t *testing.T) {
assert.Equal(t, c.wants, c.gots, "want operations %v, but got %v", c.wants, c.gots)
}
func Test_halfCommitter(t *testing.T) {
/*
Do something like:
ctx, committer, err := db.TxContext(db.DefaultContext)
if err != nil {
return err
}
defer committer.Close()
// ...
if err != nil {
return err
}
// ...
return committer.Commit()
*/
testWithCommitter := func(committer Committer, f func(committer Committer) error) {
if err := f(&halfCommitter{committer: committer}); err == nil {
committer.Commit()
}
committer.Close()
}
t.Run("commit and close", func(t *testing.T) {
mockCommitter := NewMockCommitter("commit", "close")
testWithCommitter(mockCommitter, func(committer Committer) error {
defer committer.Close()
return committer.Commit()
})
mockCommitter.Assert(t)
})
t.Run("rollback and close", func(t *testing.T) {
mockCommitter := NewMockCommitter("close", "close")
testWithCommitter(mockCommitter, func(committer Committer) error {
defer committer.Close()
if true {
return errors.New("error")
}
return committer.Commit()
})
mockCommitter.Assert(t)
})
t.Run("close and commit", func(t *testing.T) {
mockCommitter := NewMockCommitter("close", "close")
testWithCommitter(mockCommitter, func(committer Committer) error {
committer.Close()
committer.Commit()
return errors.New("error")
})
mockCommitter.Assert(t)
})
}

130
models/db/context_test.go Normal file

@@ -0,0 +1,130 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db_test
import (
"context"
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"github.com/stretchr/testify/assert"
)
func TestInTransaction(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
assert.False(t, db.InTransaction(db.DefaultContext))
assert.NoError(t, db.WithTx(db.DefaultContext, func(ctx context.Context) error {
assert.True(t, db.InTransaction(ctx))
return nil
}))
ctx, committer, err := db.TxContext(db.DefaultContext)
assert.NoError(t, err)
defer committer.Close()
assert.True(t, db.InTransaction(ctx))
assert.NoError(t, db.WithTx(ctx, func(ctx context.Context) error {
assert.True(t, db.InTransaction(ctx))
return nil
}))
}
func TestTxContext(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
{ // create new transaction
ctx, committer, err := db.TxContext(db.DefaultContext)
assert.NoError(t, err)
assert.True(t, db.InTransaction(ctx))
assert.NoError(t, committer.Commit())
}
{ // reuse the transaction created by TxContext and commit it
ctx, committer, err := db.TxContext(db.DefaultContext)
engine := db.GetEngine(ctx)
assert.NoError(t, err)
assert.True(t, db.InTransaction(ctx))
{
ctx, committer, err := db.TxContext(ctx)
assert.NoError(t, err)
assert.True(t, db.InTransaction(ctx))
assert.Equal(t, engine, db.GetEngine(ctx))
assert.NoError(t, committer.Commit())
}
assert.NoError(t, committer.Commit())
}
{ // reuse the transaction created by TxContext and close it
ctx, committer, err := db.TxContext(db.DefaultContext)
engine := db.GetEngine(ctx)
assert.NoError(t, err)
assert.True(t, db.InTransaction(ctx))
{
ctx, committer, err := db.TxContext(ctx)
assert.NoError(t, err)
assert.True(t, db.InTransaction(ctx))
assert.Equal(t, engine, db.GetEngine(ctx))
assert.NoError(t, committer.Close())
}
assert.NoError(t, committer.Close())
}
{ // reuse the transaction created by WithTx
assert.NoError(t, db.WithTx(db.DefaultContext, func(ctx context.Context) error {
assert.True(t, db.InTransaction(ctx))
{
ctx, committer, err := db.TxContext(ctx)
assert.NoError(t, err)
assert.True(t, db.InTransaction(ctx))
assert.NoError(t, committer.Commit())
}
return nil
}))
}
}
func TestContextSafety(t *testing.T) {
type TestModel1 struct {
ID int64
}
type TestModel2 struct {
ID int64
}
assert.NoError(t, unittest.GetXORMEngine().Sync(&TestModel1{}, &TestModel2{}))
assert.NoError(t, db.TruncateBeans(db.DefaultContext, &TestModel1{}, &TestModel2{}))
testCount := 10
for i := 1; i <= testCount; i++ {
assert.NoError(t, db.Insert(db.DefaultContext, &TestModel1{ID: int64(i)}))
assert.NoError(t, db.Insert(db.DefaultContext, &TestModel2{ID: int64(-i)}))
}
actualCount := 0
// here: db.GetEngine(db.DefaultContext) is a new *Session created from *Engine
_ = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
_ = db.GetEngine(ctx).Iterate(&TestModel1{}, func(i int, bean any) error {
// here: db.GetEngine(ctx) is always the unclosed "Iterate" *Session with autoResetStatement=false,
// and the internal states (including "cond" and others) are always there and not be reset in this callback.
m1 := bean.(*TestModel1)
assert.EqualValues(t, i+1, m1.ID)
// here: XORM bug, it fails because the SQL becomes "WHERE id=-1", "WHERE id=-1 AND id=-2", "WHERE id=-1 AND id=-2 AND id=-3" ...
// and it conflicts with the "Iterate"'s internal states.
// has, err := db.GetEngine(ctx).Get(&TestModel2{ID: -m1.ID})
actualCount++
return nil
})
return nil
})
assert.EqualValues(t, testCount, actualCount)
// deny the bad usages
assert.PanicsWithError(t, "using database context in an iterator would cause corrupted results", func() {
_ = unittest.GetXORMEngine().Iterate(&TestModel1{}, func(i int, bean any) error {
_ = db.GetEngine(db.DefaultContext)
return nil
})
})
}

88
models/db/convert.go Normal file

@@ -0,0 +1,88 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"fmt"
"strconv"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"xorm.io/xorm"
"xorm.io/xorm/schemas"
)
// ConvertDatabaseTable converts the database and tables from utf8 to utf8mb4 if it's MySQL and sets ROW_FORMAT=dynamic
func ConvertDatabaseTable() error {
if x.Dialect().URI().DBType != schemas.MYSQL {
return nil
}
r, err := CheckCollations(x)
if err != nil {
return err
}
_, err = x.Exec(fmt.Sprintf("ALTER DATABASE `%s` CHARACTER SET utf8mb4 COLLATE %s", setting.Database.Name, r.ExpectedCollation))
if err != nil {
return err
}
tables, err := x.DBMetas()
if err != nil {
return err
}
for _, table := range tables {
if _, err := x.Exec(fmt.Sprintf("ALTER TABLE `%s` ROW_FORMAT=dynamic", table.Name)); err != nil {
return err
}
if _, err := x.Exec(fmt.Sprintf("ALTER TABLE `%s` CONVERT TO CHARACTER SET utf8mb4 COLLATE %s", table.Name, r.ExpectedCollation)); err != nil {
return err
}
}
return nil
}
// ConvertVarcharToNVarchar converts the database's varchar columns to nvarchar if it's MSSQL
func ConvertVarcharToNVarchar() error {
if x.Dialect().URI().DBType != schemas.MSSQL {
return nil
}
sess := x.NewSession()
defer sess.Close()
res, err := sess.QuerySliceString(`SELECT 'ALTER TABLE ' + OBJECT_NAME(SC.object_id) + ' ALTER COLUMN ' + SC.name + ' NVARCHAR(' + CONVERT(VARCHAR(5),SC.max_length) + ')'
FROM SYS.columns SC
JOIN SYS.types ST
ON SC.system_type_id = ST.system_type_id
AND SC.user_type_id = ST.user_type_id
WHERE ST.name ='varchar'`)
if err != nil {
return err
}
for _, row := range res {
if len(row) == 1 {
if _, err = sess.Exec(row[0]); err != nil {
return err
}
}
}
return err
}
// Cell2Int64 converts a xorm.Cell type to int64,
// and handles possible irregular cases.
func Cell2Int64(val xorm.Cell) int64 {
switch (*val).(type) {
case []uint8:
log.Trace("Cell2Int64 ([]uint8): %v", *val)
v, _ := strconv.ParseInt(string((*val).([]uint8)), 10, 64)
return v
}
return (*val).(int64)
}
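
Cell2Int64 exists because some drivers (notably MySQL) hand back numeric columns as []uint8 in raw query results. A tiny sketch of the two shapes it accepts:

package main

import (
	"fmt"

	"code.gitea.io/gitea/models/db"
)

func main() {
	// Raw rows may carry an int64 directly or the textual []uint8 form; Cell2Int64 handles both.
	var asBytes any = []uint8("42")
	var asInt any = int64(7)
	fmt.Println(db.Cell2Int64(&asBytes), db.Cell2Int64(&asInt)) // 42 7
}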

331
models/db/engine.go Normal file

@@ -0,0 +1,331 @@
// Copyright 2014 The Gogs Authors. All rights reserved.
// Copyright 2018 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"database/sql"
"fmt"
"io"
"reflect"
"strings"
"time"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"xorm.io/xorm"
"xorm.io/xorm/contexts"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
_ "github.com/go-sql-driver/mysql" // Needed for the MySQL driver
_ "github.com/lib/pq" // Needed for the Postgresql driver
_ "github.com/microsoft/go-mssqldb" // Needed for the MSSQL driver
)
var (
x *xorm.Engine
tables []any
initFuncs []func() error
)
// Engine represents a xorm engine or session.
type Engine interface {
Table(tableNameOrBean any) *xorm.Session
Count(...any) (int64, error)
Decr(column string, arg ...any) *xorm.Session
Delete(...any) (int64, error)
Truncate(...any) (int64, error)
Exec(...any) (sql.Result, error)
Find(any, ...any) error
Get(beans ...any) (bool, error)
ID(any) *xorm.Session
In(string, ...any) *xorm.Session
Incr(column string, arg ...any) *xorm.Session
Insert(...any) (int64, error)
Iterate(any, xorm.IterFunc) error
Join(joinOperator string, tablename, condition any, args ...any) *xorm.Session
SQL(any, ...any) *xorm.Session
Where(any, ...any) *xorm.Session
Asc(colNames ...string) *xorm.Session
Desc(colNames ...string) *xorm.Session
Limit(limit int, start ...int) *xorm.Session
NoAutoTime() *xorm.Session
SumInt(bean any, columnName string) (res int64, err error)
Sync(...any) error
Select(string) *xorm.Session
SetExpr(string, any) *xorm.Session
NotIn(string, ...any) *xorm.Session
OrderBy(any, ...any) *xorm.Session
Exist(...any) (bool, error)
Distinct(...string) *xorm.Session
Query(...any) ([]map[string][]byte, error)
Cols(...string) *xorm.Session
Context(ctx context.Context) *xorm.Session
Ping() error
}
// TableInfo returns table's information via an object
func TableInfo(v any) (*schemas.Table, error) {
return x.TableInfo(v)
}
// DumpTables dump tables information
func DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error {
return x.DumpTables(tables, w, tp...)
}
// RegisterModel registers a model; if an initFunc is provided, it will be invoked after the data model sync
func RegisterModel(bean any, initFunc ...func() error) {
tables = append(tables, bean)
if len(initFunc) > 0 && initFunc[0] != nil {
initFuncs = append(initFuncs, initFunc[0])
}
}
func init() {
gonicNames := []string{"SSL", "UID"}
for _, name := range gonicNames {
names.LintGonicMapper[name] = true
}
}
// newXORMEngine returns a new XORM engine from the configuration
func newXORMEngine() (*xorm.Engine, error) {
connStr, err := setting.DBConnStr()
if err != nil {
return nil, err
}
var engine *xorm.Engine
if setting.Database.Type.IsPostgreSQL() && len(setting.Database.Schema) > 0 {
// OK whilst we sort out our schema issues - create a schema aware postgres
registerPostgresSchemaDriver()
engine, err = xorm.NewEngine("postgresschema", connStr)
} else {
engine, err = xorm.NewEngine(setting.Database.Type.String(), connStr)
}
if err != nil {
return nil, err
}
if setting.Database.Type == "mysql" {
engine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"})
} else if setting.Database.Type == "mssql" {
engine.Dialect().SetParams(map[string]string{"DEFAULT_VARCHAR": "nvarchar"})
}
engine.SetSchema(setting.Database.Schema)
return engine, nil
}
// SyncAllTables syncs the schemas of all tables; it is required by unit test code
func SyncAllTables() error {
_, err := x.StoreEngine("InnoDB").SyncWithOptions(xorm.SyncOptions{
WarnIfDatabaseColumnMissed: true,
}, tables...)
return err
}
// InitEngine initializes the xorm.Engine and sets it as db.DefaultContext
func InitEngine(ctx context.Context) error {
xormEngine, err := newXORMEngine()
if err != nil {
if strings.Contains(err.Error(), "SQLite3 support") {
return fmt.Errorf(`sqlite3 requires: -tags sqlite,sqlite_unlock_notify%s%w`, "\n", err)
}
return fmt.Errorf("failed to connect to database: %w", err)
}
xormEngine.SetMapper(names.GonicMapper{})
// WARNING: for the serv command, output to os.Stdout MUST be avoided,
// so log to a file instead of printing to stdout.
xormEngine.SetLogger(NewXORMLogger(setting.Database.LogSQL))
xormEngine.ShowSQL(setting.Database.LogSQL)
xormEngine.SetMaxOpenConns(setting.Database.MaxOpenConns)
xormEngine.SetMaxIdleConns(setting.Database.MaxIdleConns)
xormEngine.SetConnMaxLifetime(setting.Database.ConnMaxLifetime)
xormEngine.SetDefaultContext(ctx)
if setting.Database.SlowQueryThreshold > 0 {
xormEngine.AddHook(&SlowQueryHook{
Threshold: setting.Database.SlowQueryThreshold,
Logger: log.GetLogger("xorm"),
})
}
SetDefaultEngine(ctx, xormEngine)
return nil
}
// SetDefaultEngine sets the default engine for db
func SetDefaultEngine(ctx context.Context, eng *xorm.Engine) {
x = eng
DefaultContext = &Context{Context: ctx, engine: x}
}
// UnsetDefaultEngine closes and unsets the default engine
// We hope that SetDefaultEngine and UnsetDefaultEngine can be paired, but it's impossible now:
// there are many calls to InitEngine -> SetDefaultEngine that directly overwrite `x` and DefaultContext without closing them.
// Global database engine related functions are all racy and there is no graceful close right now.
func UnsetDefaultEngine() {
if x != nil {
_ = x.Close()
x = nil
}
DefaultContext = nil
}
// InitEngineWithMigration initializes a new xorm.Engine and sets it as the db.DefaultContext
// This function must never call .Sync() if the provided migration function fails.
// When called from the "doctor" command, the migration function is a version check
// that prevents the doctor from fixing anything in the database if the migration level
// is different from the expected value.
func InitEngineWithMigration(ctx context.Context, migrateFunc func(*xorm.Engine) error) (err error) {
if err = InitEngine(ctx); err != nil {
return err
}
if err = x.Ping(); err != nil {
return err
}
preprocessDatabaseCollation(x)
// We have to run migrateFunc here in case the user is re-running installation on a previously created DB.
// If we do not then table schemas will be changed and there will be conflicts when the migrations run properly.
//
// Installation should only be being re-run if users want to recover an old database.
// However, we should think carefully about should we support re-install on an installed instance,
// as there may be other problems due to secret reinitialization.
if err = migrateFunc(x); err != nil {
return fmt.Errorf("migrate: %w", err)
}
if err = SyncAllTables(); err != nil {
return fmt.Errorf("sync database struct error: %w", err)
}
for _, initFunc := range initFuncs {
if err := initFunc(); err != nil {
return fmt.Errorf("initFunc failed: %w", err)
}
}
return nil
}
// NamesToBean returns a list of beans or an error
func NamesToBean(names ...string) ([]any, error) {
beans := []any{}
if len(names) == 0 {
beans = append(beans, tables...)
return beans, nil
}
// Need to map provided names to beans...
beanMap := make(map[string]any)
for _, bean := range tables {
beanMap[strings.ToLower(reflect.Indirect(reflect.ValueOf(bean)).Type().Name())] = bean
beanMap[strings.ToLower(x.TableName(bean))] = bean
beanMap[strings.ToLower(x.TableName(bean, true))] = bean
}
gotBean := make(map[any]bool)
for _, name := range names {
bean, ok := beanMap[strings.ToLower(strings.TrimSpace(name))]
if !ok {
return nil, fmt.Errorf("no table found that matches: %s", name)
}
if !gotBean[bean] {
beans = append(beans, bean)
gotBean[bean] = true
}
}
return beans, nil
}
// DumpDatabase dumps all data from the database to the file system using the SQL syntax of the given database type.
func DumpDatabase(filePath, dbType string) error {
var tbs []*schemas.Table
for _, t := range tables {
t, err := x.TableInfo(t)
if err != nil {
return err
}
tbs = append(tbs, t)
}
type Version struct {
ID int64 `xorm:"pk autoincr"`
Version int64
}
t, err := x.TableInfo(&Version{})
if err != nil {
return err
}
tbs = append(tbs, t)
if len(dbType) > 0 {
return x.DumpTablesToFile(tbs, filePath, schemas.DBType(dbType))
}
return x.DumpTablesToFile(tbs, filePath)
}
// MaxBatchInsertSize returns the table's max batch insert size
func MaxBatchInsertSize(bean any) int {
t, err := x.TableInfo(bean)
if err != nil {
return 50
}
return 999 / len(t.ColumnsSeq())
}
// IsTableNotEmpty returns true if table has at least one record
func IsTableNotEmpty(beanOrTableName any) (bool, error) {
return x.Table(beanOrTableName).Exist()
}
// DeleteAllRecords will delete all the records of this table
func DeleteAllRecords(tableName string) error {
_, err := x.Exec(fmt.Sprintf("DELETE FROM %s", tableName))
return err
}
// GetMaxID will return max id of the table
func GetMaxID(beanOrTableName any) (maxID int64, err error) {
_, err = x.Select("MAX(id)").Table(beanOrTableName).Get(&maxID)
return maxID, err
}
func SetLogSQL(ctx context.Context, on bool) {
e := GetEngine(ctx)
if x, ok := e.(*xorm.Engine); ok {
x.ShowSQL(on)
} else if sess, ok := e.(*xorm.Session); ok {
sess.Engine().ShowSQL(on)
}
}
type SlowQueryHook struct {
Threshold time.Duration
Logger log.Logger
}
var _ contexts.Hook = &SlowQueryHook{}
func (SlowQueryHook) BeforeProcess(c *contexts.ContextHook) (context.Context, error) {
return c.Ctx, nil
}
func (h *SlowQueryHook) AfterProcess(c *contexts.ContextHook) error {
if c.ExecuteTime >= h.Threshold {
// 8 is the amount of skips passed to runtime.Caller, so that in the log the correct function
// is being displayed (the function that ultimately wants to execute the query in the code)
// instead of the function of the slow query hook being called.
h.Logger.Log(8, log.WARN, "[Slow SQL Query] %s %v - %v", c.SQL, c.Args, c.ExecuteTime)
}
return nil
}
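
RegisterModel is how model packages add their beans to the table list that SyncAllTables and InitEngineWithMigration sync. Below is a sketch of a model package doing so; the Bookmark bean, its package name and its init function are illustrative and not a real Gitea model.

package bookmarks // illustrative model package

import "code.gitea.io/gitea/models/db"

// Bookmark is an illustrative bean; the xorm tags follow the usual Gitea conventions.
type Bookmark struct {
	ID          int64  `xorm:"pk autoincr"`
	UserID      int64  `xorm:"index"`
	Target      string `xorm:"NOT NULL"`
	CreatedUnix int64  `xorm:"created"`
}

func init() {
	// The optional init function is collected into initFuncs and invoked by
	// InitEngineWithMigration after the data model sync.
	db.RegisterModel(new(Bookmark), func() error {
		return nil // e.g. seed defaults or warm caches here
	})
}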

86
models/db/engine_test.go Normal file

@@ -0,0 +1,86 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db_test
import (
"path/filepath"
"testing"
"code.gitea.io/gitea/models/db"
issues_model "code.gitea.io/gitea/models/issues"
"code.gitea.io/gitea/models/unittest"
"code.gitea.io/gitea/modules/setting"
_ "code.gitea.io/gitea/cmd" // for TestPrimaryKeys
"github.com/stretchr/testify/assert"
)
func TestDumpDatabase(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
dir := t.TempDir()
type Version struct {
ID int64 `xorm:"pk autoincr"`
Version int64
}
assert.NoError(t, db.GetEngine(db.DefaultContext).Sync(new(Version)))
for _, dbType := range setting.SupportedDatabaseTypes {
assert.NoError(t, db.DumpDatabase(filepath.Join(dir, dbType+".sql"), dbType))
}
}
func TestDeleteOrphanedObjects(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
countBefore, err := db.GetEngine(db.DefaultContext).Count(&issues_model.PullRequest{})
assert.NoError(t, err)
_, err = db.GetEngine(db.DefaultContext).Insert(&issues_model.PullRequest{IssueID: 1000}, &issues_model.PullRequest{IssueID: 1001}, &issues_model.PullRequest{IssueID: 1003})
assert.NoError(t, err)
orphaned, err := db.CountOrphanedObjects(db.DefaultContext, "pull_request", "issue", "pull_request.issue_id=issue.id")
assert.NoError(t, err)
assert.EqualValues(t, 3, orphaned)
err = db.DeleteOrphanedObjects(db.DefaultContext, "pull_request", "issue", "pull_request.issue_id=issue.id")
assert.NoError(t, err)
countAfter, err := db.GetEngine(db.DefaultContext).Count(&issues_model.PullRequest{})
assert.NoError(t, err)
assert.EqualValues(t, countBefore, countAfter)
}
func TestPrimaryKeys(t *testing.T) {
// Some dbs require that all tables have primary keys, see
// https://github.com/go-gitea/gitea/issues/21086
// https://github.com/go-gitea/gitea/issues/16802
// To avoid creating tables without primary key again, this test will check them.
// Import "code.gitea.io/gitea/cmd" to make sure each db.RegisterModel in init functions has been called.
beans, err := db.NamesToBean()
if err != nil {
t.Fatal(err)
}
whitelist := map[string]string{
"the_table_name_to_skip_checking": "Write a note here to explain why",
}
for _, bean := range beans {
table, err := db.TableInfo(bean)
if err != nil {
t.Fatal(err)
}
if why, ok := whitelist[table.Name]; ok {
t.Logf("ignore %q because %q", table.Name, why)
continue
}
if len(table.PrimaryKeys) == 0 {
t.Errorf("table %q has no primary key", table.Name)
}
}
}

74
models/db/error.go Normal file

@@ -0,0 +1,74 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"fmt"
"code.gitea.io/gitea/modules/util"
)
// ErrCancelled represents an error due to context cancellation
type ErrCancelled struct {
Message string
}
// IsErrCancelled checks if an error is an ErrCancelled.
func IsErrCancelled(err error) bool {
_, ok := err.(ErrCancelled)
return ok
}
func (err ErrCancelled) Error() string {
return "Cancelled: " + err.Message
}
// ErrCancelledf returns an ErrCancelled for the provided format and args
func ErrCancelledf(format string, args ...any) error {
return ErrCancelled{
fmt.Sprintf(format, args...),
}
}
// ErrSSHDisabled represents an "SSH disabled" error.
type ErrSSHDisabled struct{}
// IsErrSSHDisabled checks if an error is an ErrSSHDisabled.
func IsErrSSHDisabled(err error) bool {
_, ok := err.(ErrSSHDisabled)
return ok
}
func (err ErrSSHDisabled) Error() string {
return "SSH is disabled"
}
// ErrNotExist represents a non-exist error.
type ErrNotExist struct {
Resource string
ID int64
}
// IsErrNotExist checks if an error is an ErrNotExist
func IsErrNotExist(err error) bool {
_, ok := err.(ErrNotExist)
return ok
}
func (err ErrNotExist) Error() string {
name := "record"
if err.Resource != "" {
name = err.Resource
}
if err.ID != 0 {
return fmt.Sprintf("%s does not exist [id: %d]", name, err.ID)
}
return fmt.Sprintf("%s does not exist", name)
}
// Unwrap unwraps this as an ErrNotExist error
func (err ErrNotExist) Unwrap() error {
return util.ErrNotExist
}
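
ErrNotExist unwraps to util.ErrNotExist, so callers can use either the typed check or errors.Is. A small sketch (the findWidget function is illustrative only):

package main

import (
	"errors"
	"fmt"

	"code.gitea.io/gitea/models/db"
	"code.gitea.io/gitea/modules/util"
)

// findWidget is illustrative: it only shows how a "not found" condition is usually reported.
func findWidget(id int64) error {
	return db.ErrNotExist{Resource: "widget", ID: id}
}

func main() {
	err := findWidget(42)
	fmt.Println(db.IsErrNotExist(err))            // true: typed check
	fmt.Println(errors.Is(err, util.ErrNotExist)) // true: via Unwrap
	fmt.Println(err)                              // "widget does not exist [id: 42]"
}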

177
models/db/index.go Normal file

@@ -0,0 +1,177 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"errors"
"fmt"
"strconv"
"code.gitea.io/gitea/modules/setting"
)
// ResourceIndex represents a resource index which could be used for issues, releases and others.
// We can create different tables, e.g. issue_index, release_index, etc.
type ResourceIndex struct {
GroupID int64 `xorm:"pk"`
MaxIndex int64 `xorm:"index"`
}
var (
// ErrResouceOutdated represents an error when the requested resource is outdated
ErrResouceOutdated = errors.New("resource outdated")
// ErrGetResourceIndexFailed represents an error when getting the resource index still fails after 3 retries
ErrGetResourceIndexFailed = errors.New("get resource index failed")
)
// SyncMaxResourceIndex syncs the max index with the resource
func SyncMaxResourceIndex(ctx context.Context, tableName string, groupID, maxIndex int64) (err error) {
e := GetEngine(ctx)
// try to update the max_index and acquire the write-lock for the record
res, err := e.Exec(fmt.Sprintf("UPDATE %s SET max_index=? WHERE group_id=? AND max_index<?", tableName), maxIndex, groupID, maxIndex)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
// if nothing is updated, the record might not exist or might be larger, it's safe to try to insert it again and then check whether the record exists
_, errIns := e.Exec(fmt.Sprintf("INSERT INTO %s (group_id, max_index) VALUES (?, ?)", tableName), groupID, maxIndex)
var savedIdx int64
has, err := e.SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id=?", tableName), groupID).Get(&savedIdx)
if err != nil {
return err
}
// if the record still doesn't exist, there must be some errors (insert error)
if !has {
if errIns == nil {
return errors.New("impossible error when SyncMaxResourceIndex, insert succeeded but no record is saved")
}
return errIns
}
}
return nil
}
func postgresGetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
res, err := GetEngine(ctx).Query(fmt.Sprintf("INSERT INTO %s (group_id, max_index) "+
"VALUES (?,1) ON CONFLICT (group_id) DO UPDATE SET max_index = %s.max_index+1 RETURNING max_index",
tableName, tableName), groupID)
if err != nil {
return 0, err
}
if len(res) == 0 {
return 0, ErrGetResourceIndexFailed
}
return strconv.ParseInt(string(res[0]["max_index"]), 10, 64)
}
func mysqlGetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
if _, err := GetEngine(ctx).Exec(fmt.Sprintf("INSERT INTO %s (group_id, max_index) "+
"VALUES (?,1) ON DUPLICATE KEY UPDATE max_index = max_index+1",
tableName), groupID); err != nil {
return 0, err
}
var idx int64
_, err := GetEngine(ctx).SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id = ?", tableName), groupID).Get(&idx)
if err != nil {
return 0, err
}
if idx == 0 {
return 0, errors.New("cannot get the correct index")
}
return idx, nil
}
func mssqlGetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
if _, err := GetEngine(ctx).Exec(fmt.Sprintf(`
MERGE INTO %s WITH (HOLDLOCK) AS target
USING (SELECT %d AS group_id) AS source
(group_id)
ON target.group_id = source.group_id
WHEN MATCHED
THEN UPDATE
SET max_index = max_index + 1
WHEN NOT MATCHED
THEN INSERT (group_id, max_index)
VALUES (%d, 1);
`, tableName, groupID, groupID)); err != nil {
return 0, err
}
var idx int64
_, err := GetEngine(ctx).SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id = ?", tableName), groupID).Get(&idx)
if err != nil {
return 0, err
}
if idx == 0 {
return 0, errors.New("cannot get the correct index")
}
return idx, nil
}
// GetNextResourceIndex generates a resource index; it must run in the same transaction in which the resource is created
func GetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
switch {
case setting.Database.Type.IsPostgreSQL():
return postgresGetNextResourceIndex(ctx, tableName, groupID)
case setting.Database.Type.IsMySQL():
return mysqlGetNextResourceIndex(ctx, tableName, groupID)
case setting.Database.Type.IsMSSQL():
return mssqlGetNextResourceIndex(ctx, tableName, groupID)
}
e := GetEngine(ctx)
// try to update the max_index to next value, and acquire the write-lock for the record
res, err := e.Exec(fmt.Sprintf("UPDATE %s SET max_index=max_index+1 WHERE group_id=?", tableName), groupID)
if err != nil {
return 0, err
}
affected, err := res.RowsAffected()
if err != nil {
return 0, err
}
if affected == 0 {
// this slow path is only for the first time of creating a resource index
_, errIns := e.Exec(fmt.Sprintf("INSERT INTO %s (group_id, max_index) VALUES (?, 0)", tableName), groupID)
res, err = e.Exec(fmt.Sprintf("UPDATE %s SET max_index=max_index+1 WHERE group_id=?", tableName), groupID)
if err != nil {
return 0, err
}
affected, err = res.RowsAffected()
if err != nil {
return 0, err
}
// if the update still can not update any records, the record must not exist and there must be some errors (insert error)
if affected == 0 {
if errIns == nil {
return 0, errors.New("impossible error when GetNextResourceIndex, insert and update both succeeded but no record is updated")
}
return 0, errIns
}
}
// now, the new index is in database (protected by the transaction and write-lock)
var newIdx int64
has, err := e.SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id=?", tableName), groupID).Get(&newIdx)
if err != nil {
return 0, err
}
if !has {
return 0, errors.New("impossible error when GetNextResourceIndex, upsert succeeded but no record can be selected")
}
return newIdx, nil
}
// DeleteResourceIndex deletes the resource index
func DeleteResourceIndex(ctx context.Context, tableName string, groupID int64) error {
_, err := Exec(ctx, fmt.Sprintf("DELETE FROM %s WHERE group_id=?", tableName), groupID)
return err
}
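
As the comment on GetNextResourceIndex says, it must run inside the transaction that creates the resource. A sketch follows; the issue_index table name matches the example given above, but the Issue bean here is simplified and illustrative.

package main

import (
	"context"

	"code.gitea.io/gitea/models/db"
)

// Issue is a simplified, illustrative bean: only the fields needed to show the index flow.
type Issue struct {
	ID     int64 `xorm:"pk autoincr"`
	RepoID int64 `xorm:"index"`
	Index  int64
	Title  string
}

// createIssue allocates the per-repository issue index and inserts the row in one transaction,
// as required by the GetNextResourceIndex contract.
func createIssue(ctx context.Context, repoID int64, title string) (*Issue, error) {
	issue := &Issue{RepoID: repoID, Title: title}
	err := db.WithTx(ctx, func(ctx context.Context) error {
		idx, err := db.GetNextResourceIndex(ctx, "issue_index", repoID)
		if err != nil {
			return err
		}
		issue.Index = idx
		return db.Insert(ctx, issue)
	})
	if err != nil {
		return nil, err
	}
	return issue, nil
}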

126
models/db/index_test.go Normal file

@@ -0,0 +1,126 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db_test
import (
"context"
"errors"
"fmt"
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/models/unittest"
"github.com/stretchr/testify/assert"
)
type TestIndex db.ResourceIndex
func getCurrentResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
e := db.GetEngine(ctx)
var idx int64
has, err := e.SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id=?", tableName), groupID).Get(&idx)
if err != nil {
return 0, err
}
if !has {
return 0, errors.New("no record")
}
return idx, nil
}
func TestSyncMaxResourceIndex(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
xe := unittest.GetXORMEngine()
assert.NoError(t, xe.Sync(&TestIndex{}))
err := db.SyncMaxResourceIndex(db.DefaultContext, "test_index", 10, 51)
assert.NoError(t, err)
// sync new max index
maxIndex, err := getCurrentResourceIndex(db.DefaultContext, "test_index", 10)
assert.NoError(t, err)
assert.EqualValues(t, 51, maxIndex)
// smaller index doesn't change
err = db.SyncMaxResourceIndex(db.DefaultContext, "test_index", 10, 30)
assert.NoError(t, err)
maxIndex, err = getCurrentResourceIndex(db.DefaultContext, "test_index", 10)
assert.NoError(t, err)
assert.EqualValues(t, 51, maxIndex)
// larger index changes
err = db.SyncMaxResourceIndex(db.DefaultContext, "test_index", 10, 62)
assert.NoError(t, err)
maxIndex, err = getCurrentResourceIndex(db.DefaultContext, "test_index", 10)
assert.NoError(t, err)
assert.EqualValues(t, 62, maxIndex)
// commit transaction
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 73)
assert.NoError(t, err)
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
assert.NoError(t, err)
assert.EqualValues(t, 73, maxIndex)
return nil
})
assert.NoError(t, err)
maxIndex, err = getCurrentResourceIndex(db.DefaultContext, "test_index", 10)
assert.NoError(t, err)
assert.EqualValues(t, 73, maxIndex)
// rollback transaction
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
err = db.SyncMaxResourceIndex(ctx, "test_index", 10, 84)
maxIndex, err = getCurrentResourceIndex(ctx, "test_index", 10)
assert.NoError(t, err)
assert.EqualValues(t, 84, maxIndex)
return errors.New("test rollback")
})
assert.Error(t, err)
maxIndex, err = getCurrentResourceIndex(db.DefaultContext, "test_index", 10)
assert.NoError(t, err)
assert.EqualValues(t, 73, maxIndex) // the max index doesn't change because the transaction was rolled back
}
func TestGetNextResourceIndex(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
xe := unittest.GetXORMEngine()
assert.NoError(t, xe.Sync(&TestIndex{}))
// create a new record
maxIndex, err := db.GetNextResourceIndex(db.DefaultContext, "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 1, maxIndex)
// increase the existing record
maxIndex, err = db.GetNextResourceIndex(db.DefaultContext, "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 2, maxIndex)
// commit transaction
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 3, maxIndex)
return nil
})
assert.NoError(t, err)
maxIndex, err = getCurrentResourceIndex(db.DefaultContext, "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 3, maxIndex)
// rollback transaction
err = db.WithTx(db.DefaultContext, func(ctx context.Context) error {
maxIndex, err = db.GetNextResourceIndex(ctx, "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 4, maxIndex)
return errors.New("test rollback")
})
assert.Error(t, err)
maxIndex, err = getCurrentResourceIndex(db.DefaultContext, "test_index", 20)
assert.NoError(t, err)
assert.EqualValues(t, 3, maxIndex) // the max index doesn't change because the transaction was rolled back
}

64
models/db/install/db.go Normal file
View File

@@ -0,0 +1,64 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package install
import (
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
"xorm.io/xorm"
)
func getXORMEngine() *xorm.Engine {
return db.GetEngine(db.DefaultContext).(*xorm.Engine)
}
// CheckDatabaseConnection checks the database connection
func CheckDatabaseConnection() error {
e := db.GetEngine(db.DefaultContext)
_, err := e.Exec("SELECT 1")
return err
}
// GetMigrationVersion gets the database migration version
func GetMigrationVersion() (int64, error) {
var installedDbVersion int64
x := getXORMEngine()
exist, err := x.IsTableExist("version")
if err != nil {
return 0, err
}
if !exist {
return 0, nil
}
_, err = x.Table("version").Cols("version").Get(&installedDbVersion)
if err != nil {
return 0, err
}
return installedDbVersion, nil
}
// HasPostInstallationUsers checks whether there are users after installation
func HasPostInstallationUsers() (bool, error) {
x := getXORMEngine()
exist, err := x.IsTableExist("user")
if err != nil {
return false, err
}
if !exist {
return false, nil
}
// if there are 2 or more users in the database, we consider that users have been created after installation
threshold := 2
if !setting.IsProd {
// to make debugging easier, lower the threshold to 1 when RUN_MODE is not prod
threshold = 1
}
res, err := x.Table("user").Cols("id").Limit(threshold).Query()
if err != nil {
return false, err
}
return len(res) >= threshold, nil
}
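
A hedged sketch (not part of this commit) of how these install helpers could be combined; the output wording is illustrative only.

package example

import (
	"fmt"

	"code.gitea.io/gitea/models/db/install"
)

// reportInstallState checks connectivity, then reports the migration version
// and whether any post-installation users exist.
func reportInstallState() error {
	if err := install.CheckDatabaseConnection(); err != nil {
		return fmt.Errorf("cannot connect to the database: %w", err)
	}
	version, err := install.GetMigrationVersion()
	if err != nil {
		return err
	}
	hasUsers, err := install.HasPostInstallationUsers()
	if err != nil {
		return err
	}
	fmt.Printf("migration version: %d, post-installation users: %v\n", version, hasUsers)
	return nil
}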

43
models/db/iterate.go Normal file
View File

@@ -0,0 +1,43 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"code.gitea.io/gitea/modules/setting"
"xorm.io/builder"
)
// Iterate iterates over all Bean objects matching the condition, in batches of setting.Database.IterateBufferSize
func Iterate[Bean any](ctx context.Context, cond builder.Cond, f func(ctx context.Context, bean *Bean) error) error {
var start int
batchSize := setting.Database.IterateBufferSize
sess := GetEngine(ctx)
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
beans := make([]*Bean, 0, batchSize)
if cond != nil {
sess = sess.Where(cond)
}
if err := sess.Limit(batchSize, start).Find(&beans); err != nil {
return err
}
if len(beans) == 0 {
return nil
}
start += len(beans)
for _, bean := range beans {
if err := f(ctx, bean); err != nil {
return err
}
}
}
}
}
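
A short usage sketch (not part of this commit) mirroring the test below; the repo_id condition is an assumption for illustration.

package example

import (
	"context"

	"code.gitea.io/gitea/models/db"
	repo_model "code.gitea.io/gitea/models/repo"

	"xorm.io/builder"
)

// walkRepoUnits visits every repo_unit row of one repository in batches.
func walkRepoUnits(ctx context.Context, repoID int64) error {
	return db.Iterate(ctx, builder.Eq{"repo_id": repoID}, func(ctx context.Context, unit *repo_model.RepoUnit) error {
		// process one unit per callback; returning an error stops the iteration
		return nil
	})
}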

44
models/db/iterate_test.go Normal file
View File

@@ -0,0 +1,44 @@
// Copyright 2022 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db_test
import (
"context"
"testing"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unittest"
"github.com/stretchr/testify/assert"
)
func TestIterate(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
xe := unittest.GetXORMEngine()
assert.NoError(t, xe.Sync(&repo_model.RepoUnit{}))
cnt, err := db.GetEngine(db.DefaultContext).Count(&repo_model.RepoUnit{})
assert.NoError(t, err)
var repoUnitCnt int
err = db.Iterate(db.DefaultContext, nil, func(ctx context.Context, repo *repo_model.RepoUnit) error {
repoUnitCnt++
return nil
})
assert.NoError(t, err)
assert.EqualValues(t, cnt, repoUnitCnt)
err = db.Iterate(db.DefaultContext, nil, func(ctx context.Context, repoUnit *repo_model.RepoUnit) error {
has, err := db.ExistByID[repo_model.RepoUnit](ctx, repoUnit.ID)
if err != nil {
return err
}
if !has {
return db.ErrNotExist{Resource: "repo_unit", ID: repoUnit.ID}
}
return nil
})
assert.NoError(t, err)
}

215
models/db/list.go Normal file
View File

@@ -0,0 +1,215 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"code.gitea.io/gitea/modules/setting"
"xorm.io/builder"
"xorm.io/xorm"
)
const (
// DefaultMaxInSize is the default maximum number of variables in an SQL IN () clause
DefaultMaxInSize = 50
defaultFindSliceSize = 10
)
// Paginator is the base for different ListOptions types
type Paginator interface {
GetSkipTake() (skip, take int)
IsListAll() bool
}
// SetSessionPagination sets pagination for a database session
func SetSessionPagination(sess Engine, p Paginator) *xorm.Session {
skip, take := p.GetSkipTake()
return sess.Limit(take, skip)
}
// ListOptions options to paginate results
type ListOptions struct {
PageSize int
Page int // start from 1
ListAll bool // if true, PageSize and Page are ignored
}
var ListOptionsAll = ListOptions{ListAll: true}
var (
_ Paginator = &ListOptions{}
_ FindOptions = ListOptions{}
)
// GetSkipTake returns the skip and take values
func (opts *ListOptions) GetSkipTake() (skip, take int) {
opts.SetDefaultValues()
return (opts.Page - 1) * opts.PageSize, opts.PageSize
}
func (opts ListOptions) GetPage() int {
return opts.Page
}
func (opts ListOptions) GetPageSize() int {
return opts.PageSize
}
// IsListAll indicates PageSize and Page will be ignored
func (opts ListOptions) IsListAll() bool {
return opts.ListAll
}
// SetDefaultValues sets default values
func (opts *ListOptions) SetDefaultValues() {
if opts.PageSize <= 0 {
opts.PageSize = setting.API.DefaultPagingNum
}
if opts.PageSize > setting.API.MaxResponseItems {
opts.PageSize = setting.API.MaxResponseItems
}
if opts.Page <= 0 {
opts.Page = 1
}
}
func (opts ListOptions) ToConds() builder.Cond {
return builder.NewCond()
}
// AbsoluteListOptions absolute options to paginate results
type AbsoluteListOptions struct {
skip int
take int
}
var _ Paginator = &AbsoluteListOptions{}
// NewAbsoluteListOptions creates a list option with applied limits
func NewAbsoluteListOptions(skip, take int) *AbsoluteListOptions {
if skip < 0 {
skip = 0
}
if take <= 0 {
take = setting.API.DefaultPagingNum
}
if take > setting.API.MaxResponseItems {
take = setting.API.MaxResponseItems
}
return &AbsoluteListOptions{skip, take}
}
// IsListAll will always return false
func (opts *AbsoluteListOptions) IsListAll() bool {
return false
}
// GetSkipTake returns the skip and take values
func (opts *AbsoluteListOptions) GetSkipTake() (skip, take int) {
return opts.skip, opts.take
}
// FindOptions represents the options of a find operation
type FindOptions interface {
GetPage() int
GetPageSize() int
IsListAll() bool
ToConds() builder.Cond
}
type JoinFunc func(sess Engine) error
type FindOptionsJoin interface {
ToJoins() []JoinFunc
}
type FindOptionsOrder interface {
ToOrders() string
}
// Find represents a common find function which accepts an options interface
func Find[T any](ctx context.Context, opts FindOptions) ([]*T, error) {
sess := GetEngine(ctx).Where(opts.ToConds())
if joinOpt, ok := opts.(FindOptionsJoin); ok {
for _, joinFunc := range joinOpt.ToJoins() {
if err := joinFunc(sess); err != nil {
return nil, err
}
}
}
if orderOpt, ok := opts.(FindOptionsOrder); ok {
if order := orderOpt.ToOrders(); order != "" {
sess.OrderBy(order)
}
}
page, pageSize := opts.GetPage(), opts.GetPageSize()
if !opts.IsListAll() && pageSize > 0 {
if page == 0 {
page = 1
}
sess.Limit(pageSize, (page-1)*pageSize)
}
findPageSize := defaultFindSliceSize
if pageSize > 0 {
findPageSize = pageSize
}
objects := make([]*T, 0, findPageSize)
if err := sess.Find(&objects); err != nil {
return nil, err
}
return objects, nil
}
// Count represents a common count function which accepts an options interface
func Count[T any](ctx context.Context, opts FindOptions) (int64, error) {
sess := GetEngine(ctx).Where(opts.ToConds())
if joinOpt, ok := opts.(FindOptionsJoin); ok {
for _, joinFunc := range joinOpt.ToJoins() {
if err := joinFunc(sess); err != nil {
return 0, err
}
}
}
var object T
return sess.Count(&object)
}
// FindAndCount represents a common find-and-count function which accepts an options interface
func FindAndCount[T any](ctx context.Context, opts FindOptions) ([]*T, int64, error) {
sess := GetEngine(ctx).Where(opts.ToConds())
page, pageSize := opts.GetPage(), opts.GetPageSize()
if !opts.IsListAll() && pageSize > 0 && page >= 1 {
sess.Limit(pageSize, (page-1)*pageSize)
}
if joinOpt, ok := opts.(FindOptionsJoin); ok {
for _, joinFunc := range joinOpt.ToJoins() {
if err := joinFunc(sess); err != nil {
return nil, 0, err
}
}
}
if orderOpt, ok := opts.(FindOptionsOrder); ok {
if order := orderOpt.ToOrders(); order != "" {
sess.OrderBy(order)
}
}
findPageSize := defaultFindSliceSize
if pageSize > 0 {
findPageSize = pageSize
}
objects := make([]*T, 0, findPageSize)
cnt, err := sess.FindAndCount(&objects)
if err != nil {
return nil, 0, err
}
return objects, cnt, nil
}
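
A hedged sketch (not part of this commit) of a concrete FindOptions implementation; the type, field, and column names are assumptions chosen to match the repo_unit table used in the tests below.

package example

import (
	"context"

	"code.gitea.io/gitea/models/db"
	repo_model "code.gitea.io/gitea/models/repo"

	"xorm.io/builder"
)

// unitFindOptions is a hypothetical FindOptions implementation.
type unitFindOptions struct {
	db.ListOptions
	RepoID int64
}

// ToConds builds the WHERE condition from the options.
func (o unitFindOptions) ToConds() builder.Cond {
	cond := builder.NewCond()
	if o.RepoID > 0 {
		cond = cond.And(builder.Eq{"repo_id": o.RepoID})
	}
	return cond
}

// ToOrders makes the result ordering deterministic (FindOptionsOrder).
func (o unitFindOptions) ToOrders() string {
	return "id ASC"
}

// findUnits pages through the repo units of a single repository.
func findUnits(ctx context.Context, repoID int64) ([]*repo_model.RepoUnit, int64, error) {
	return db.FindAndCount[repo_model.RepoUnit](ctx, unitFindOptions{
		ListOptions: db.ListOptions{Page: 1, PageSize: 10},
		RepoID:      repoID,
	})
}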

52
models/db/list_test.go Normal file
View File

@@ -0,0 +1,52 @@
// Copyright 2023 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db_test
import (
"testing"
"code.gitea.io/gitea/models/db"
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unittest"
"github.com/stretchr/testify/assert"
"xorm.io/builder"
)
type mockListOptions struct {
db.ListOptions
}
func (opts mockListOptions) IsListAll() bool {
return true
}
func (opts mockListOptions) ToConds() builder.Cond {
return builder.NewCond()
}
func TestFind(t *testing.T) {
assert.NoError(t, unittest.PrepareTestDatabase())
xe := unittest.GetXORMEngine()
assert.NoError(t, xe.Sync(&repo_model.RepoUnit{}))
var repoUnitCount int
_, err := db.GetEngine(db.DefaultContext).SQL("SELECT COUNT(*) FROM repo_unit").Get(&repoUnitCount)
assert.NoError(t, err)
assert.NotEmpty(t, repoUnitCount)
opts := mockListOptions{}
repoUnits, err := db.Find[repo_model.RepoUnit](db.DefaultContext, opts)
assert.NoError(t, err)
assert.Len(t, repoUnits, repoUnitCount)
cnt, err := db.Count[repo_model.RepoUnit](db.DefaultContext, opts)
assert.NoError(t, err)
assert.EqualValues(t, repoUnitCount, cnt)
repoUnits, newCnt, err := db.FindAndCount[repo_model.RepoUnit](db.DefaultContext, opts)
assert.NoError(t, err)
assert.EqualValues(t, cnt, newCnt)
assert.Len(t, repoUnits, repoUnitCount)
}

107
models/db/log.go Normal file
View File

@@ -0,0 +1,107 @@
// Copyright 2017 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"fmt"
"sync/atomic"
"code.gitea.io/gitea/modules/log"
xormlog "xorm.io/xorm/log"
)
// XORMLogBridge a logger bridge from Logger to xorm
type XORMLogBridge struct {
showSQL atomic.Bool
logger log.Logger
}
// NewXORMLogger inits a log bridge for xorm
func NewXORMLogger(showSQL bool) xormlog.Logger {
l := &XORMLogBridge{logger: log.GetLogger("xorm")}
l.showSQL.Store(showSQL)
return l
}
const stackLevel = 8
// Log logs a message at the given level, skipping the given number of stack frames
func (l *XORMLogBridge) Log(skip int, level log.Level, format string, v ...any) {
l.logger.Log(skip+1, level, format, v...)
}
// Debug show debug log
func (l *XORMLogBridge) Debug(v ...any) {
l.Log(stackLevel, log.DEBUG, "%s", fmt.Sprint(v...))
}
// Debugf show debug log
func (l *XORMLogBridge) Debugf(format string, v ...any) {
l.Log(stackLevel, log.DEBUG, format, v...)
}
// Error show error log
func (l *XORMLogBridge) Error(v ...any) {
l.Log(stackLevel, log.ERROR, "%s", fmt.Sprint(v...))
}
// Errorf show error log
func (l *XORMLogBridge) Errorf(format string, v ...any) {
l.Log(stackLevel, log.ERROR, format, v...)
}
// Info show information level log
func (l *XORMLogBridge) Info(v ...any) {
l.Log(stackLevel, log.INFO, "%s", fmt.Sprint(v...))
}
// Infof show information level log
func (l *XORMLogBridge) Infof(format string, v ...any) {
l.Log(stackLevel, log.INFO, format, v...)
}
// Warn show warning log
func (l *XORMLogBridge) Warn(v ...any) {
l.Log(stackLevel, log.WARN, "%s", fmt.Sprint(v...))
}
// Warnf show warning log
func (l *XORMLogBridge) Warnf(format string, v ...any) {
l.Log(stackLevel, log.WARN, format, v...)
}
// Level get logger level
func (l *XORMLogBridge) Level() xormlog.LogLevel {
switch l.logger.GetLevel() {
case log.TRACE, log.DEBUG:
return xormlog.LOG_DEBUG
case log.INFO:
return xormlog.LOG_INFO
case log.WARN:
return xormlog.LOG_WARNING
case log.ERROR:
return xormlog.LOG_ERR
case log.NONE:
return xormlog.LOG_OFF
}
return xormlog.LOG_UNKNOWN
}
// SetLevel set the logger level
func (l *XORMLogBridge) SetLevel(lvl xormlog.LogLevel) {
}
// ShowSQL sets whether SQL statements should be logged
func (l *XORMLogBridge) ShowSQL(show ...bool) {
if len(show) == 0 {
show = []bool{true}
}
l.showSQL.Store(show[0])
}
// IsShowSQL reports whether SQL statements are logged
func (l *XORMLogBridge) IsShowSQL() bool {
return l.showSQL.Load()
}
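
A brief sketch (not part of this commit) of wiring the bridge into an xorm engine; whether SQL statements are logged is controlled by ShowSQL rather than by the log level.

package example

import (
	"code.gitea.io/gitea/models/db"

	"xorm.io/xorm"
)

// attachLogger installs the XORM log bridge and enables SQL logging as requested.
func attachLogger(x *xorm.Engine, logSQL bool) {
	x.SetLogger(db.NewXORMLogger(logSQL))
	x.ShowSQL(logSQL)
}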

17
models/db/main_test.go Normal file
View File

@@ -0,0 +1,17 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db_test
import (
"testing"
"code.gitea.io/gitea/models/unittest"
_ "code.gitea.io/gitea/models"
_ "code.gitea.io/gitea/models/repo"
)
func TestMain(m *testing.M) {
unittest.MainTest(m)
}

106
models/db/name.go Normal file
View File

@@ -0,0 +1,106 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"fmt"
"regexp"
"strings"
"unicode/utf8"
"code.gitea.io/gitea/modules/util"
)
var (
// ErrNameEmpty name is empty error
ErrNameEmpty = util.SilentWrap{Message: "name is empty", Err: util.ErrInvalidArgument}
// AlphaDashDotPattern characters prohibited in a user name (anything except A-Za-z0-9_.-)
AlphaDashDotPattern = regexp.MustCompile(`[^\w-\.]`)
)
// ErrNameReserved represents a "reserved name" error.
type ErrNameReserved struct {
Name string
}
// IsErrNameReserved checks if an error is an ErrNameReserved.
func IsErrNameReserved(err error) bool {
_, ok := err.(ErrNameReserved)
return ok
}
func (err ErrNameReserved) Error() string {
return fmt.Sprintf("name is reserved [name: %s]", err.Name)
}
// Unwrap unwraps this as an ErrInvalidArgument error
func (err ErrNameReserved) Unwrap() error {
return util.ErrInvalidArgument
}
// ErrNamePatternNotAllowed represents a "pattern not allowed" error.
type ErrNamePatternNotAllowed struct {
Pattern string
}
// IsErrNamePatternNotAllowed checks if an error is an ErrNamePatternNotAllowed.
func IsErrNamePatternNotAllowed(err error) bool {
_, ok := err.(ErrNamePatternNotAllowed)
return ok
}
func (err ErrNamePatternNotAllowed) Error() string {
return fmt.Sprintf("name pattern is not allowed [pattern: %s]", err.Pattern)
}
// Unwrap unwraps this as an ErrInvalidArgument error
func (err ErrNamePatternNotAllowed) Unwrap() error {
return util.ErrInvalidArgument
}
// ErrNameCharsNotAllowed represents a "character not allowed in name" error.
type ErrNameCharsNotAllowed struct {
Name string
}
// IsErrNameCharsNotAllowed checks if an error is an ErrNameCharsNotAllowed.
func IsErrNameCharsNotAllowed(err error) bool {
_, ok := err.(ErrNameCharsNotAllowed)
return ok
}
func (err ErrNameCharsNotAllowed) Error() string {
return fmt.Sprintf("name is invalid [%s]: must be valid alpha or numeric or dash(-_) or dot characters", err.Name)
}
// Unwrap unwraps this as an ErrInvalidArgument error
func (err ErrNameCharsNotAllowed) Unwrap() error {
return util.ErrInvalidArgument
}
// IsUsableName checks if name is reserved or pattern of name is not allowed
// based on given reserved names and patterns.
// Names are exact match, patterns can be prefix or suffix match with placeholder '*'.
func IsUsableName(names, patterns []string, name string) error {
name = strings.TrimSpace(strings.ToLower(name))
if utf8.RuneCountInString(name) == 0 {
return ErrNameEmpty
}
for i := range names {
if name == names[i] {
return ErrNameReserved{name}
}
}
for _, pat := range patterns {
if pat[0] == '*' && strings.HasSuffix(name, pat[1:]) ||
(pat[len(pat)-1] == '*' && strings.HasPrefix(name, pat[:len(pat)-1])) {
return ErrNamePatternNotAllowed{pat}
}
}
return nil
}
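
A minimal usage sketch (not part of this commit); the reserved names and patterns below are assumptions for illustration, while Gitea normally loads them from its settings.

package example

import "code.gitea.io/gitea/models/db"

// Hypothetical reserved names and patterns, used only for this example.
var (
	reservedNames    = []string{"admin", "api"}
	reservedPatterns = []string{"*.keys", "release-*"}
)

// checkName returns ErrNameEmpty, ErrNameReserved or ErrNamePatternNotAllowed
// (all wrapping util.ErrInvalidArgument) when the name is rejected.
func checkName(name string) error {
	return db.IsUsableName(reservedNames, reservedPatterns, name)
}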

View File

@@ -0,0 +1,14 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package paginator
import (
"testing"
"code.gitea.io/gitea/models/unittest"
)
func TestMain(m *testing.M) {
unittest.MainTest(m)
}

View File

@@ -0,0 +1,7 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package paginator
// Dummy only. In the future, models/db/list_options.go should be moved here to decouple it from the db package;
// otherwise the unit tests would cause an import cycle.

View File

@@ -0,0 +1,59 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package paginator
import (
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
"github.com/stretchr/testify/assert"
)
func TestPaginator(t *testing.T) {
cases := []struct {
db.Paginator
Skip int
Take int
Start int
End int
}{
{
Paginator: &db.ListOptions{Page: -1, PageSize: -1},
Skip: 0,
Take: setting.API.DefaultPagingNum,
Start: 0,
End: setting.API.DefaultPagingNum,
},
{
Paginator: &db.ListOptions{Page: 2, PageSize: 10},
Skip: 10,
Take: 10,
Start: 10,
End: 20,
},
{
Paginator: db.NewAbsoluteListOptions(-1, -1),
Skip: 0,
Take: setting.API.DefaultPagingNum,
Start: 0,
End: setting.API.DefaultPagingNum,
},
{
Paginator: db.NewAbsoluteListOptions(2, 10),
Skip: 2,
Take: 10,
Start: 2,
End: 12,
},
}
for _, c := range cases {
skip, take := c.Paginator.GetSkipTake()
assert.Equal(t, c.Skip, skip)
assert.Equal(t, c.Take, take)
}
}

35
models/db/search.go Normal file
View File

@@ -0,0 +1,35 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
// SearchOrderBy is used to sort the result
type SearchOrderBy string
func (s SearchOrderBy) String() string {
return string(s)
}
// Strings for sorting result
const (
SearchOrderByAlphabetically SearchOrderBy = "name ASC"
SearchOrderByAlphabeticallyReverse SearchOrderBy = "name DESC"
SearchOrderByLeastUpdated SearchOrderBy = "updated_unix ASC"
SearchOrderByRecentUpdated SearchOrderBy = "updated_unix DESC"
SearchOrderByOldest SearchOrderBy = "created_unix ASC"
SearchOrderByNewest SearchOrderBy = "created_unix DESC"
SearchOrderByID SearchOrderBy = "id ASC"
SearchOrderByIDReverse SearchOrderBy = "id DESC"
SearchOrderByStars SearchOrderBy = "num_stars ASC"
SearchOrderByStarsReverse SearchOrderBy = "num_stars DESC"
SearchOrderByForks SearchOrderBy = "num_forks ASC"
SearchOrderByForksReverse SearchOrderBy = "num_forks DESC"
)
// NoConditionID means a condition to filter the records which don't match any id.
// eg: "milestone_id=-1" means "find the items without any milestone".
const NoConditionID int64 = -1
// NonExistingID means a condition to match no result (eg: a non-existing user)
// It doesn't use -1 or -2 because those IDs are used by built-in users.
const NonExistingID int64 = -1000000
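
A short sketch (not part of this commit) of applying one of the predefined order strings; it assumes the underlying table has a created_unix column, which repo_unit does.

package example

import (
	"context"

	"code.gitea.io/gitea/models/db"
	repo_model "code.gitea.io/gitea/models/repo"
)

// newestRepoUnits returns the ten most recently created repo units.
func newestRepoUnits(ctx context.Context) ([]*repo_model.RepoUnit, error) {
	units := make([]*repo_model.RepoUnit, 0, 10)
	err := db.GetEngine(ctx).
		OrderBy(db.SearchOrderByNewest.String()).
		Limit(10).
		Find(&units)
	return units, err
}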

70
models/db/sequence.go Normal file
View File

@@ -0,0 +1,70 @@
// Copyright 2018 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"context"
"fmt"
"regexp"
"code.gitea.io/gitea/modules/setting"
)
// CountBadSequences looks for broken sequences from recreate-table mistakes
func CountBadSequences(_ context.Context) (int64, error) {
if !setting.Database.Type.IsPostgreSQL() {
return 0, nil
}
sess := x.NewSession()
defer sess.Close()
var sequences []string
schema := x.Dialect().URI().Schema
sess.Engine().SetSchema("")
if err := sess.Table("information_schema.sequences").Cols("sequence_name").Where("sequence_name LIKE 'tmp_recreate__%_id_seq%' AND sequence_catalog = ?", setting.Database.Name).Find(&sequences); err != nil {
return 0, err
}
sess.Engine().SetSchema(schema)
return int64(len(sequences)), nil
}
// FixBadSequences fixes broken sequences from recreate-table mistakes
func FixBadSequences(_ context.Context) error {
if !setting.Database.Type.IsPostgreSQL() {
return nil
}
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
}
var sequences []string
schema := sess.Engine().Dialect().URI().Schema
sess.Engine().SetSchema("")
if err := sess.Table("information_schema.sequences").Cols("sequence_name").Where("sequence_name LIKE 'tmp_recreate__%_id_seq%' AND sequence_catalog = ?", setting.Database.Name).Find(&sequences); err != nil {
return err
}
sess.Engine().SetSchema(schema)
sequenceRegexp := regexp.MustCompile(`tmp_recreate__(\w+)_id_seq.*`)
for _, sequence := range sequences {
tableName := sequenceRegexp.FindStringSubmatch(sequence)[1]
newSequenceName := tableName + "_id_seq"
if _, err := sess.Exec(fmt.Sprintf("ALTER SEQUENCE `%s` RENAME TO `%s`", sequence, newSequenceName)); err != nil {
return err
}
if _, err := sess.Exec(fmt.Sprintf("SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM `%s`), 1), false)", newSequenceName, tableName)); err != nil {
return err
}
}
return sess.Commit()
}

View File

@@ -0,0 +1,74 @@
// Copyright 2020 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package db
import (
"database/sql"
"database/sql/driver"
"sync"
"code.gitea.io/gitea/modules/setting"
"github.com/lib/pq"
"xorm.io/xorm/dialects"
)
var registerOnce sync.Once
func registerPostgresSchemaDriver() {
registerOnce.Do(func() {
sql.Register("postgresschema", &postgresSchemaDriver{})
dialects.RegisterDriver("postgresschema", dialects.QueryDriver("postgres"))
})
}
type postgresSchemaDriver struct {
pq.Driver
}
// Open opens a new connection to the database. name is a connection string.
// This function opens the postgres connection in the default manner but immediately
// runs set_config to set the search_path appropriately
func (d *postgresSchemaDriver) Open(name string) (driver.Conn, error) {
conn, err := d.Driver.Open(name)
if err != nil {
return conn, err
}
schemaValue, _ := driver.String.ConvertValue(setting.Database.Schema)
// golangci lint is incorrect here - there is no benefit to using driver.ExecerContext here
// and in any case pq does not implement it
if execer, ok := conn.(driver.Execer); ok { //nolint:staticcheck
_, err := execer.Exec(`SELECT set_config(
'search_path',
$1 || ',' || current_setting('search_path'),
false)`, []driver.Value{schemaValue})
if err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
stmt, err := conn.Prepare(`SELECT set_config(
'search_path',
$1 || ',' || current_setting('search_path'),
false)`)
if err != nil {
_ = conn.Close()
return nil, err
}
defer stmt.Close()
// driver.String.ConvertValue will never return err for string
// golangci lint is incorrect here - there is no benefit to using stmt.ExecWithContext here
_, err = stmt.Exec([]driver.Value{schemaValue}) //nolint:staticcheck
if err != nil {
_ = conn.Close()
return nil, err
}
return conn, nil
}
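
A hedged sketch (not part of this commit) of how the schema driver could be used. Because registerPostgresSchemaDriver is unexported, such a helper would have to live in the db package; it assumes setting.Database.Schema is configured.

package db

import "xorm.io/xorm"

// openWithSchema registers the wrapper driver once and opens an engine whose
// connections prepend setting.Database.Schema to the search_path.
func openWithSchema(connStr string) (*xorm.Engine, error) {
	registerPostgresSchemaDriver()
	return xorm.NewEngine("postgresschema", connStr)
}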

Some files were not shown because too many files have changed in this diff.