Move macaron to chi (#14293)
Use [chi](https://github.com/go-chi/chi) instead of the forked [macaron](https://gitea.com/macaron/macaron). Since macaron and chi have conflicts with session share, this big PR becomes a have-to thing. According my previous idea, we can replace macaron step by step but I'm wrong. :( Below is a list of big changes on this PR. - [x] Define `context.ResponseWriter` interface with an implementation `context.Response`. - [x] Use chi instead of macaron, and also a customize `Route` to wrap chi so that the router usage is similar as before. - [x] Create different routers for `web`, `api`, `internal` and `install` so that the codes will be more clear and no magic . - [x] Use https://github.com/unrolled/render instead of macaron's internal render - [x] Use https://github.com/NYTimes/gziphandler instead of https://gitea.com/macaron/gzip - [x] Use https://gitea.com/go-chi/session which is a modified version of https://gitea.com/macaron/session and removed `nodb` support since it will not be maintained. **BREAK** - [x] Use https://gitea.com/go-chi/captcha which is a modified version of https://gitea.com/macaron/captcha - [x] Use https://gitea.com/go-chi/cache which is a modified version of https://gitea.com/macaron/cache - [x] Use https://gitea.com/go-chi/binding which is a modified version of https://gitea.com/macaron/binding - [x] Use https://github.com/go-chi/cors instead of https://gitea.com/macaron/cors - [x] Dropped https://gitea.com/macaron/i18n and make a new one in `code.gitea.io/gitea/modules/translation` - [x] Move validation form structs from `code.gitea.io/gitea/modules/auth` to `code.gitea.io/gitea/modules/forms` to avoid dependency cycle. - [x] Removed macaron log service because it's not need any more. **BREAK** - [x] All form structs have to be get by `web.GetForm(ctx)` in the route function but not as a function parameter on routes definition. - [x] Move Git HTTP protocol implementation to use routers directly. - [x] Fix the problem that chi routes don't support trailing slash but macaron did. - [x] `/api/v1/swagger` now will be redirect to `/api/swagger` but not render directly so that `APIContext` will not create a html render. Notices: - Chi router don't support request with trailing slash - Integration test `TestUserHeatmap` maybe mysql version related. It's failed on my macOS(mysql 5.7.29 installed via brew) but succeed on CI. Co-authored-by: 6543 <6543@obermui.de>tokarchuk/v1.17
parent
3adbbb4255
commit
6433ba0ec3
@ -0,0 +1,26 @@ |
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package context |
||||
|
||||
import ( |
||||
"sync" |
||||
|
||||
"code.gitea.io/gitea/modules/setting" |
||||
|
||||
"gitea.com/go-chi/captcha" |
||||
) |
||||
|
||||
var imageCaptchaOnce sync.Once |
||||
var cpt *captcha.Captcha |
||||
|
||||
// GetImageCaptcha returns global image captcha
|
||||
func GetImageCaptcha() *captcha.Captcha { |
||||
imageCaptchaOnce.Do(func() { |
||||
cpt = captcha.NewCaptcha(captcha.Options{ |
||||
SubURL: setting.AppSubURL, |
||||
}) |
||||
}) |
||||
return cpt |
||||
} |
@ -0,0 +1,227 @@ |
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package context |
||||
|
||||
import ( |
||||
"errors" |
||||
"net/http" |
||||
"net/url" |
||||
"strconv" |
||||
"strings" |
||||
"text/template" |
||||
|
||||
"code.gitea.io/gitea/modules/log" |
||||
) |
||||
|
||||
// Forms a new enhancement of http.Request
|
||||
type Forms http.Request |
||||
|
||||
// Values returns http.Request values
|
||||
func (f *Forms) Values() url.Values { |
||||
return (*http.Request)(f).Form |
||||
} |
||||
|
||||
// String returns request form as string
|
||||
func (f *Forms) String(key string) (string, error) { |
||||
return (*http.Request)(f).FormValue(key), nil |
||||
} |
||||
|
||||
// Trimmed returns request form as string with trimed spaces left and right
|
||||
func (f *Forms) Trimmed(key string) (string, error) { |
||||
return strings.TrimSpace((*http.Request)(f).FormValue(key)), nil |
||||
} |
||||
|
||||
// Strings returns request form as strings
|
||||
func (f *Forms) Strings(key string) ([]string, error) { |
||||
if (*http.Request)(f).Form == nil { |
||||
if err := (*http.Request)(f).ParseMultipartForm(32 << 20); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
if v, ok := (*http.Request)(f).Form[key]; ok { |
||||
return v, nil |
||||
} |
||||
return nil, errors.New("not exist") |
||||
} |
||||
|
||||
// Escape returns request form as escaped string
|
||||
func (f *Forms) Escape(key string) (string, error) { |
||||
return template.HTMLEscapeString((*http.Request)(f).FormValue(key)), nil |
||||
} |
||||
|
||||
// Int returns request form as int
|
||||
func (f *Forms) Int(key string) (int, error) { |
||||
return strconv.Atoi((*http.Request)(f).FormValue(key)) |
||||
} |
||||
|
||||
// Int32 returns request form as int32
|
||||
func (f *Forms) Int32(key string) (int32, error) { |
||||
v, err := strconv.ParseInt((*http.Request)(f).FormValue(key), 10, 32) |
||||
return int32(v), err |
||||
} |
||||
|
||||
// Int64 returns request form as int64
|
||||
func (f *Forms) Int64(key string) (int64, error) { |
||||
return strconv.ParseInt((*http.Request)(f).FormValue(key), 10, 64) |
||||
} |
||||
|
||||
// Uint returns request form as uint
|
||||
func (f *Forms) Uint(key string) (uint, error) { |
||||
v, err := strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 64) |
||||
return uint(v), err |
||||
} |
||||
|
||||
// Uint32 returns request form as uint32
|
||||
func (f *Forms) Uint32(key string) (uint32, error) { |
||||
v, err := strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 32) |
||||
return uint32(v), err |
||||
} |
||||
|
||||
// Uint64 returns request form as uint64
|
||||
func (f *Forms) Uint64(key string) (uint64, error) { |
||||
return strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 64) |
||||
} |
||||
|
||||
// Bool returns request form as bool
|
||||
func (f *Forms) Bool(key string) (bool, error) { |
||||
return strconv.ParseBool((*http.Request)(f).FormValue(key)) |
||||
} |
||||
|
||||
// Float32 returns request form as float32
|
||||
func (f *Forms) Float32(key string) (float32, error) { |
||||
v, err := strconv.ParseFloat((*http.Request)(f).FormValue(key), 64) |
||||
return float32(v), err |
||||
} |
||||
|
||||
// Float64 returns request form as float64
|
||||
func (f *Forms) Float64(key string) (float64, error) { |
||||
return strconv.ParseFloat((*http.Request)(f).FormValue(key), 64) |
||||
} |
||||
|
||||
// MustString returns request form as string with default
|
||||
func (f *Forms) MustString(key string, defaults ...string) string { |
||||
if v := (*http.Request)(f).FormValue(key); len(v) > 0 { |
||||
return v |
||||
} |
||||
if len(defaults) > 0 { |
||||
return defaults[0] |
||||
} |
||||
return "" |
||||
} |
||||
|
||||
// MustTrimmed returns request form as string with default
|
||||
func (f *Forms) MustTrimmed(key string, defaults ...string) string { |
||||
return strings.TrimSpace(f.MustString(key, defaults...)) |
||||
} |
||||
|
||||
// MustStrings returns request form as strings with default
|
||||
func (f *Forms) MustStrings(key string, defaults ...[]string) []string { |
||||
if (*http.Request)(f).Form == nil { |
||||
if err := (*http.Request)(f).ParseMultipartForm(32 << 20); err != nil { |
||||
log.Error("ParseMultipartForm: %v", err) |
||||
return []string{} |
||||
} |
||||
} |
||||
|
||||
if v, ok := (*http.Request)(f).Form[key]; ok { |
||||
return v |
||||
} |
||||
if len(defaults) > 0 { |
||||
return defaults[0] |
||||
} |
||||
return []string{} |
||||
} |
||||
|
||||
// MustEscape returns request form as escaped string with default
|
||||
func (f *Forms) MustEscape(key string, defaults ...string) string { |
||||
if v := (*http.Request)(f).FormValue(key); len(v) > 0 { |
||||
return template.HTMLEscapeString(v) |
||||
} |
||||
if len(defaults) > 0 { |
||||
return defaults[0] |
||||
} |
||||
return "" |
||||
} |
||||
|
||||
// MustInt returns request form as int with default
|
||||
func (f *Forms) MustInt(key string, defaults ...int) int { |
||||
v, err := strconv.Atoi((*http.Request)(f).FormValue(key)) |
||||
if len(defaults) > 0 && err != nil { |
||||
return defaults[0] |
||||
} |
||||
return v |
||||
} |
||||
|
||||
// MustInt32 returns request form as int32 with default
|
||||
func (f *Forms) MustInt32(key string, defaults ...int32) int32 { |
||||
v, err := strconv.ParseInt((*http.Request)(f).FormValue(key), 10, 32) |
||||
if len(defaults) > 0 && err != nil { |
||||
return defaults[0] |
||||
} |
||||
return int32(v) |
||||
} |
||||
|
||||
// MustInt64 returns request form as int64 with default
|
||||
func (f *Forms) MustInt64(key string, defaults ...int64) int64 { |
||||
v, err := strconv.ParseInt((*http.Request)(f).FormValue(key), 10, 64) |
||||
if len(defaults) > 0 && err != nil { |
||||
return defaults[0] |
||||
} |
||||
return v |
||||
} |
||||
|
||||
// MustUint returns request form as uint with default
|
||||
func (f *Forms) MustUint(key string, defaults ...uint) uint { |
||||
v, err := strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 64) |
||||
if len(defaults) > 0 && err != nil { |
||||
return defaults[0] |
||||
} |
||||
return uint(v) |
||||
} |
||||
|
||||
// MustUint32 returns request form as uint32 with default
|
||||
func (f *Forms) MustUint32(key string, defaults ...uint32) uint32 { |
||||
v, err := strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 32) |
||||
if len(defaults) > 0 && err != nil { |
||||
return defaults[0] |
||||
} |
||||
return uint32(v) |
||||
} |
||||
|
||||
// MustUint64 returns request form as uint64 with default
|
||||
func (f *Forms) MustUint64(key string, defaults ...uint64) uint64 { |
||||
v, err := strconv.ParseUint((*http.Request)(f).FormValue(key), 10, 64) |
||||
if len(defaults) > 0 && err != nil { |
||||
return defaults[0] |
||||
} |
||||
return v |
||||
} |
||||
|
||||
// MustFloat32 returns request form as float32 with default
|
||||
func (f *Forms) MustFloat32(key string, defaults ...float32) float32 { |
||||
v, err := strconv.ParseFloat((*http.Request)(f).FormValue(key), 32) |
||||
if len(defaults) > 0 && err != nil { |
||||
return defaults[0] |
||||
} |
||||
return float32(v) |
||||
} |
||||
|
||||
// MustFloat64 returns request form as float64 with default
|
||||
func (f *Forms) MustFloat64(key string, defaults ...float64) float64 { |
||||
v, err := strconv.ParseFloat((*http.Request)(f).FormValue(key), 64) |
||||
if len(defaults) > 0 && err != nil { |
||||
return defaults[0] |
||||
} |
||||
return v |
||||
} |
||||
|
||||
// MustBool returns request form as bool with default
|
||||
func (f *Forms) MustBool(key string, defaults ...bool) bool { |
||||
v, err := strconv.ParseBool((*http.Request)(f).FormValue(key)) |
||||
if len(defaults) > 0 && err != nil { |
||||
return defaults[0] |
||||
} |
||||
return v |
||||
} |
@ -0,0 +1,45 @@ |
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package context |
||||
|
||||
import ( |
||||
"context" |
||||
"net/http" |
||||
) |
||||
|
||||
// PrivateContext represents a context for private routes
|
||||
type PrivateContext struct { |
||||
*Context |
||||
} |
||||
|
||||
var ( |
||||
privateContextKey interface{} = "default_private_context" |
||||
) |
||||
|
||||
// WithPrivateContext set up private context in request
|
||||
func WithPrivateContext(req *http.Request, ctx *PrivateContext) *http.Request { |
||||
return req.WithContext(context.WithValue(req.Context(), privateContextKey, ctx)) |
||||
} |
||||
|
||||
// GetPrivateContext returns a context for Private routes
|
||||
func GetPrivateContext(req *http.Request) *PrivateContext { |
||||
return req.Context().Value(privateContextKey).(*PrivateContext) |
||||
} |
||||
|
||||
// PrivateContexter returns apicontext as macaron middleware
|
||||
func PrivateContexter() func(http.Handler) http.Handler { |
||||
return func(next http.Handler) http.Handler { |
||||
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { |
||||
ctx := &PrivateContext{ |
||||
Context: &Context{ |
||||
Resp: NewResponse(w), |
||||
Data: map[string]interface{}{}, |
||||
}, |
||||
} |
||||
ctx.Req = WithPrivateContext(req, ctx) |
||||
next.ServeHTTP(ctx.Resp, ctx.Req) |
||||
}) |
||||
} |
||||
} |
@ -0,0 +1,100 @@ |
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package context |
||||
|
||||
import ( |
||||
"crypto/aes" |
||||
"crypto/cipher" |
||||
"crypto/rand" |
||||
"crypto/sha256" |
||||
"encoding/base64" |
||||
"errors" |
||||
"io" |
||||
) |
||||
|
||||
// NewSecret creates a new secret
|
||||
func NewSecret() (string, error) { |
||||
return NewSecretWithLength(32) |
||||
} |
||||
|
||||
// NewSecretWithLength creates a new secret for a given length
|
||||
func NewSecretWithLength(length int64) (string, error) { |
||||
return randomString(length) |
||||
} |
||||
|
||||
func randomBytes(len int64) ([]byte, error) { |
||||
b := make([]byte, len) |
||||
if _, err := rand.Read(b); err != nil { |
||||
return nil, err |
||||
} |
||||
return b, nil |
||||
} |
||||
|
||||
func randomString(len int64) (string, error) { |
||||
b, err := randomBytes(len) |
||||
return base64.URLEncoding.EncodeToString(b), err |
||||
} |
||||
|
||||
// AesEncrypt encrypts text and given key with AES.
|
||||
func AesEncrypt(key, text []byte) ([]byte, error) { |
||||
block, err := aes.NewCipher(key) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
b := base64.StdEncoding.EncodeToString(text) |
||||
ciphertext := make([]byte, aes.BlockSize+len(b)) |
||||
iv := ciphertext[:aes.BlockSize] |
||||
if _, err := io.ReadFull(rand.Reader, iv); err != nil { |
||||
return nil, err |
||||
} |
||||
cfb := cipher.NewCFBEncrypter(block, iv) |
||||
cfb.XORKeyStream(ciphertext[aes.BlockSize:], []byte(b)) |
||||
return ciphertext, nil |
||||
} |
||||
|
||||
// AesDecrypt decrypts text and given key with AES.
|
||||
func AesDecrypt(key, text []byte) ([]byte, error) { |
||||
block, err := aes.NewCipher(key) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if len(text) < aes.BlockSize { |
||||
return nil, errors.New("ciphertext too short") |
||||
} |
||||
iv := text[:aes.BlockSize] |
||||
text = text[aes.BlockSize:] |
||||
cfb := cipher.NewCFBDecrypter(block, iv) |
||||
cfb.XORKeyStream(text, text) |
||||
data, err := base64.StdEncoding.DecodeString(string(text)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return data, nil |
||||
} |
||||
|
||||
// EncryptSecret encrypts a string with given key into a hex string
|
||||
func EncryptSecret(key string, str string) (string, error) { |
||||
keyHash := sha256.Sum256([]byte(key)) |
||||
plaintext := []byte(str) |
||||
ciphertext, err := AesEncrypt(keyHash[:], plaintext) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil |
||||
} |
||||
|
||||
// DecryptSecret decrypts a previously encrypted hex string
|
||||
func DecryptSecret(key string, cipherhex string) (string, error) { |
||||
keyHash := sha256.Sum256([]byte(key)) |
||||
ciphertext, err := base64.StdEncoding.DecodeString(cipherhex) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
plaintext, err := AesDecrypt(keyHash[:], ciphertext) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
return string(plaintext), nil |
||||
} |
@ -0,0 +1,90 @@ |
||||
// Copyright 2012 Google Inc. All Rights Reserved.
|
||||
// Copyright 2014 The Macaron Authors
|
||||
// Copyright 2020 The Gitea Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package context |
||||
|
||||
import ( |
||||
"encoding/base64" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
const ( |
||||
key = "quay" |
||||
userID = "12345678" |
||||
actionID = "POST /form" |
||||
) |
||||
|
||||
var ( |
||||
now = time.Now() |
||||
oneMinuteFromNow = now.Add(1 * time.Minute) |
||||
) |
||||
|
||||
func Test_ValidToken(t *testing.T) { |
||||
t.Run("Validate token", func(t *testing.T) { |
||||
tok := generateTokenAtTime(key, userID, actionID, now) |
||||
assert.True(t, validTokenAtTime(tok, key, userID, actionID, oneMinuteFromNow)) |
||||
assert.True(t, validTokenAtTime(tok, key, userID, actionID, now.Add(Timeout-1*time.Nanosecond))) |
||||
assert.True(t, validTokenAtTime(tok, key, userID, actionID, now.Add(-1*time.Minute))) |
||||
}) |
||||
} |
||||
|
||||
// Test_SeparatorReplacement tests that separators are being correctly substituted
|
||||
func Test_SeparatorReplacement(t *testing.T) { |
||||
t.Run("Test two separator replacements", func(t *testing.T) { |
||||
assert.NotEqual(t, generateTokenAtTime("foo:bar", "baz", "wah", now), |
||||
generateTokenAtTime("foo", "bar:baz", "wah", now)) |
||||
}) |
||||
} |
||||
|
||||
func Test_InvalidToken(t *testing.T) { |
||||
t.Run("Test invalid tokens", func(t *testing.T) { |
||||
invalidTokenTests := []struct { |
||||
name, key, userID, actionID string |
||||
t time.Time |
||||
}{ |
||||
{"Bad key", "foobar", userID, actionID, oneMinuteFromNow}, |
||||
{"Bad userID", key, "foobar", actionID, oneMinuteFromNow}, |
||||
{"Bad actionID", key, userID, "foobar", oneMinuteFromNow}, |
||||
{"Expired", key, userID, actionID, now.Add(Timeout)}, |
||||
{"More than 1 minute from the future", key, userID, actionID, now.Add(-1*time.Nanosecond - 1*time.Minute)}, |
||||
} |
||||
|
||||
tok := generateTokenAtTime(key, userID, actionID, now) |
||||
for _, itt := range invalidTokenTests { |
||||
assert.False(t, validTokenAtTime(tok, itt.key, itt.userID, itt.actionID, itt.t)) |
||||
} |
||||
}) |
||||
} |
||||
|
||||
// Test_ValidateBadData primarily tests that no unexpected panics are triggered during parsing
|
||||
func Test_ValidateBadData(t *testing.T) { |
||||
t.Run("Validate bad data", func(t *testing.T) { |
||||
badDataTests := []struct { |
||||
name, tok string |
||||
}{ |
||||
{"Invalid Base64", "ASDab24(@)$*=="}, |
||||
{"No delimiter", base64.URLEncoding.EncodeToString([]byte("foobar12345678"))}, |
||||
{"Invalid time", base64.URLEncoding.EncodeToString([]byte("foobar:foobar"))}, |
||||
} |
||||
|
||||
for _, bdt := range badDataTests { |
||||
assert.False(t, validTokenAtTime(bdt.tok, key, userID, actionID, oneMinuteFromNow)) |
||||
} |
||||
}) |
||||
} |
@ -0,0 +1,10 @@ |
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package middlewares |
||||
|
||||
// DataStore represents a data store
|
||||
type DataStore interface { |
||||
GetData() map[string]interface{} |
||||
} |
@ -0,0 +1,65 @@ |
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package middlewares |
||||
|
||||
import "net/url" |
||||
|
||||
// flashes enumerates all the flash types
|
||||
const ( |
||||
SuccessFlash = "SuccessMsg" |
||||
ErrorFlash = "ErrorMsg" |
||||
WarnFlash = "WarningMsg" |
||||
InfoFlash = "InfoMsg" |
||||
) |
||||
|
||||
var ( |
||||
// FlashNow FIXME:
|
||||
FlashNow bool |
||||
) |
||||
|
||||
// Flash represents a one time data transfer between two requests.
|
||||
type Flash struct { |
||||
DataStore |
||||
url.Values |
||||
ErrorMsg, WarningMsg, InfoMsg, SuccessMsg string |
||||
} |
||||
|
||||
func (f *Flash) set(name, msg string, current ...bool) { |
||||
isShow := false |
||||
if (len(current) == 0 && FlashNow) || |
||||
(len(current) > 0 && current[0]) { |
||||
isShow = true |
||||
} |
||||
|
||||
if isShow { |
||||
f.GetData()["Flash"] = f |
||||
} else { |
||||
f.Set(name, msg) |
||||
} |
||||
} |
||||
|
||||
// Error sets error message
|
||||
func (f *Flash) Error(msg string, current ...bool) { |
||||
f.ErrorMsg = msg |
||||
f.set("error", msg, current...) |
||||
} |
||||
|
||||
// Warning sets warning message
|
||||
func (f *Flash) Warning(msg string, current ...bool) { |
||||
f.WarningMsg = msg |
||||
f.set("warning", msg, current...) |
||||
} |
||||
|
||||
// Info sets info message
|
||||
func (f *Flash) Info(msg string, current ...bool) { |
||||
f.InfoMsg = msg |
||||
f.set("info", msg, current...) |
||||
} |
||||
|
||||
// Success sets success message
|
||||
func (f *Flash) Success(msg string, current ...bool) { |
||||
f.SuccessMsg = msg |
||||
f.set("success", msg, current...) |
||||
} |
@ -1,217 +0,0 @@ |
||||
// Copyright 2013 Beego Authors
|
||||
// Copyright 2014 The Macaron Authors
|
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License"): you may
|
||||
// not use this file except in compliance with the License. You may obtain
|
||||
// a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
// License for the specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package middlewares |
||||
|
||||
import ( |
||||
"fmt" |
||||
"sync" |
||||
"time" |
||||
|
||||
"code.gitea.io/gitea/modules/nosql" |
||||
|
||||
"gitea.com/go-chi/session" |
||||
"github.com/go-redis/redis/v7" |
||||
) |
||||
|
||||
// RedisStore represents a redis session store implementation.
|
||||
// TODO: copied from modules/session/redis.go and should remove that one until macaron removed.
|
||||
type RedisStore struct { |
||||
c redis.UniversalClient |
||||
prefix, sid string |
||||
duration time.Duration |
||||
lock sync.RWMutex |
||||
data map[interface{}]interface{} |
||||
} |
||||
|
||||
// NewRedisStore creates and returns a redis session store.
|
||||
func NewRedisStore(c redis.UniversalClient, prefix, sid string, dur time.Duration, kv map[interface{}]interface{}) *RedisStore { |
||||
return &RedisStore{ |
||||
c: c, |
||||
prefix: prefix, |
||||
sid: sid, |
||||
duration: dur, |
||||
data: kv, |
||||
} |
||||
} |
||||
|
||||
// Set sets value to given key in session.
|
||||
func (s *RedisStore) Set(key, val interface{}) error { |
||||
s.lock.Lock() |
||||
defer s.lock.Unlock() |
||||
|
||||
s.data[key] = val |
||||
return nil |
||||
} |
||||
|
||||
// Get gets value by given key in session.
|
||||
func (s *RedisStore) Get(key interface{}) interface{} { |
||||
s.lock.RLock() |
||||
defer s.lock.RUnlock() |
||||
|
||||
return s.data[key] |
||||
} |
||||
|
||||
// Delete delete a key from session.
|
||||
func (s *RedisStore) Delete(key interface{}) error { |
||||
s.lock.Lock() |
||||
defer s.lock.Unlock() |
||||
|
||||
delete(s.data, key) |
||||
return nil |
||||
} |
||||
|
||||
// ID returns current session ID.
|
||||
func (s *RedisStore) ID() string { |
||||
return s.sid |
||||
} |
||||
|
||||
// Release releases resource and save data to provider.
|
||||
func (s *RedisStore) Release() error { |
||||
// Skip encoding if the data is empty
|
||||
if len(s.data) == 0 { |
||||
return nil |
||||
} |
||||
|
||||
data, err := session.EncodeGob(s.data) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
return s.c.Set(s.prefix+s.sid, string(data), s.duration).Err() |
||||
} |
||||
|
||||
// Flush deletes all session data.
|
||||
func (s *RedisStore) Flush() error { |
||||
s.lock.Lock() |
||||
defer s.lock.Unlock() |
||||
|
||||
s.data = make(map[interface{}]interface{}) |
||||
return nil |
||||
} |
||||
|
||||
// RedisProvider represents a redis session provider implementation.
|
||||
type RedisProvider struct { |
||||
c redis.UniversalClient |
||||
duration time.Duration |
||||
prefix string |
||||
} |
||||
|
||||
// Init initializes redis session provider.
|
||||
// configs: network=tcp,addr=:6379,password=macaron,db=0,pool_size=100,idle_timeout=180,prefix=session;
|
||||
func (p *RedisProvider) Init(maxlifetime int64, configs string) (err error) { |
||||
p.duration, err = time.ParseDuration(fmt.Sprintf("%ds", maxlifetime)) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
uri := nosql.ToRedisURI(configs) |
||||
|
||||
for k, v := range uri.Query() { |
||||
switch k { |
||||
case "prefix": |
||||
p.prefix = v[0] |
||||
} |
||||
} |
||||
|
||||
p.c = nosql.GetManager().GetRedisClient(uri.String()) |
||||
return p.c.Ping().Err() |
||||
} |
||||
|
||||
// Read returns raw session store by session ID.
|
||||
func (p *RedisProvider) Read(sid string) (session.RawStore, error) { |
||||
psid := p.prefix + sid |
||||
if !p.Exist(sid) { |
||||
if err := p.c.Set(psid, "", p.duration).Err(); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
var kv map[interface{}]interface{} |
||||
kvs, err := p.c.Get(psid).Result() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if len(kvs) == 0 { |
||||
kv = make(map[interface{}]interface{}) |
||||
} else { |
||||
kv, err = session.DecodeGob([]byte(kvs)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
return NewRedisStore(p.c, p.prefix, sid, p.duration, kv), nil |
||||
} |
||||
|
||||
// Exist returns true if session with given ID exists.
|
||||
func (p *RedisProvider) Exist(sid string) bool { |
||||
v, err := p.c.Exists(p.prefix + sid).Result() |
||||
return err == nil && v == 1 |
||||
} |
||||
|
||||
// Destroy deletes a session by session ID.
|
||||
func (p *RedisProvider) Destroy(sid string) error { |
||||
return p.c.Del(p.prefix + sid).Err() |
||||
} |
||||
|
||||
// Regenerate regenerates a session store from old session ID to new one.
|
||||
func (p *RedisProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) { |
||||
poldsid := p.prefix + oldsid |
||||
psid := p.prefix + sid |
||||
|
||||
if p.Exist(sid) { |
||||
return nil, fmt.Errorf("new sid '%s' already exists", sid) |
||||
} else if !p.Exist(oldsid) { |
||||
// Make a fake old session.
|
||||
if err = p.c.Set(poldsid, "", p.duration).Err(); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
if err = p.c.Rename(poldsid, psid).Err(); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
var kv map[interface{}]interface{} |
||||
kvs, err := p.c.Get(psid).Result() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if len(kvs) == 0 { |
||||
kv = make(map[interface{}]interface{}) |
||||
} else { |
||||
kv, err = session.DecodeGob([]byte(kvs)) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
return NewRedisStore(p.c, p.prefix, sid, p.duration, kv), nil |
||||
} |
||||
|
||||
// Count counts and returns number of sessions.
|
||||
func (p *RedisProvider) Count() int { |
||||
return int(p.c.DBSize().Val()) |
||||
} |
||||
|
||||
// GC calls GC to clean expired sessions.
|
||||
func (*RedisProvider) GC() {} |
||||
|
||||
func init() { |
||||
session.Register("redis", &RedisProvider{}) |
||||
} |
@ -1,196 +0,0 @@ |
||||
// Copyright 2019 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package middlewares |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"fmt" |
||||
"sync" |
||||
|
||||
"gitea.com/go-chi/session" |
||||
couchbase "gitea.com/go-chi/session/couchbase" |
||||
memcache "gitea.com/go-chi/session/memcache" |
||||
mysql "gitea.com/go-chi/session/mysql" |
||||
postgres "gitea.com/go-chi/session/postgres" |
||||
) |
||||
|
||||
// VirtualSessionProvider represents a shadowed session provider implementation.
|
||||
// TODO: copied from modules/session/redis.go and should remove that one until macaron removed.
|
||||
type VirtualSessionProvider struct { |
||||
lock sync.RWMutex |
||||
provider session.Provider |
||||
} |
||||
|
||||
// Init initializes the cookie session provider with given root path.
|
||||
func (o *VirtualSessionProvider) Init(gclifetime int64, config string) error { |
||||
var opts session.Options |
||||
if err := json.Unmarshal([]byte(config), &opts); err != nil { |
||||
return err |
||||
} |
||||
// Note that these options are unprepared so we can't just use NewManager here.
|
||||
// Nor can we access the provider map in session.
|
||||
// So we will just have to do this by hand.
|
||||
// This is only slightly more wrong than modules/setting/session.go:23
|
||||
switch opts.Provider { |
||||
case "memory": |
||||
o.provider = &session.MemProvider{} |
||||
case "file": |
||||
o.provider = &session.FileProvider{} |
||||
case "redis": |
||||
o.provider = &RedisProvider{} |
||||
case "mysql": |
||||
o.provider = &mysql.MysqlProvider{} |
||||
case "postgres": |
||||
o.provider = &postgres.PostgresProvider{} |
||||
case "couchbase": |
||||
o.provider = &couchbase.CouchbaseProvider{} |
||||
case "memcache": |
||||
o.provider = &memcache.MemcacheProvider{} |
||||
default: |
||||
return fmt.Errorf("VirtualSessionProvider: Unknown Provider: %s", opts.Provider) |
||||
} |
||||
return o.provider.Init(gclifetime, opts.ProviderConfig) |
||||
} |
||||
|
||||
// Read returns raw session store by session ID.
|
||||
func (o *VirtualSessionProvider) Read(sid string) (session.RawStore, error) { |
||||
o.lock.RLock() |
||||
defer o.lock.RUnlock() |
||||
if o.provider.Exist(sid) { |
||||
return o.provider.Read(sid) |
||||
} |
||||
kv := make(map[interface{}]interface{}) |
||||
kv["_old_uid"] = "0" |
||||
return NewVirtualStore(o, sid, kv), nil |
||||
} |
||||
|
||||
// Exist returns true if session with given ID exists.
|
||||
func (o *VirtualSessionProvider) Exist(sid string) bool { |
||||
return true |
||||
} |
||||
|
||||
// Destroy deletes a session by session ID.
|
||||
func (o *VirtualSessionProvider) Destroy(sid string) error { |
||||
o.lock.Lock() |
||||
defer o.lock.Unlock() |
||||
return o.provider.Destroy(sid) |
||||
} |
||||
|
||||
// Regenerate regenerates a session store from old session ID to new one.
|
||||
func (o *VirtualSessionProvider) Regenerate(oldsid, sid string) (session.RawStore, error) { |
||||
o.lock.Lock() |
||||
defer o.lock.Unlock() |
||||
return o.provider.Regenerate(oldsid, sid) |
||||
} |
||||
|
||||
// Count counts and returns number of sessions.
|
||||
func (o *VirtualSessionProvider) Count() int { |
||||
o.lock.RLock() |
||||
defer o.lock.RUnlock() |
||||
return o.provider.Count() |
||||
} |
||||
|
||||
// GC calls GC to clean expired sessions.
|
||||
func (o *VirtualSessionProvider) GC() { |
||||
o.provider.GC() |
||||
} |
||||
|
||||
func init() { |
||||
session.Register("VirtualSession", &VirtualSessionProvider{}) |
||||
} |
||||
|
||||
// VirtualStore represents a virtual session store implementation.
|
||||
type VirtualStore struct { |
||||
p *VirtualSessionProvider |
||||
sid string |
||||
lock sync.RWMutex |
||||
data map[interface{}]interface{} |
||||
released bool |
||||
} |
||||
|
||||
// NewVirtualStore creates and returns a virtual session store.
|
||||
func NewVirtualStore(p *VirtualSessionProvider, sid string, kv map[interface{}]interface{}) *VirtualStore { |
||||
return &VirtualStore{ |
||||
p: p, |
||||
sid: sid, |
||||
data: kv, |
||||
} |
||||
} |
||||
|
||||
// Set sets value to given key in session.
|
||||
func (s *VirtualStore) Set(key, val interface{}) error { |
||||
s.lock.Lock() |
||||
defer s.lock.Unlock() |
||||
|
||||
s.data[key] = val |
||||
return nil |
||||
} |
||||
|
||||
// Get gets value by given key in session.
|
||||
func (s *VirtualStore) Get(key interface{}) interface{} { |
||||
s.lock.RLock() |
||||
defer s.lock.RUnlock() |
||||
|
||||
return s.data[key] |
||||
} |
||||
|
||||
// Delete delete a key from session.
|
||||
func (s *VirtualStore) Delete(key interface{}) error { |
||||
s.lock.Lock() |
||||
defer s.lock.Unlock() |
||||
|
||||
delete(s.data, key) |
||||
return nil |
||||
} |
||||
|
||||
// ID returns current session ID.
|
||||
func (s *VirtualStore) ID() string { |
||||
return s.sid |
||||
} |
||||
|
||||
// Release releases resource and save data to provider.
|
||||
func (s *VirtualStore) Release() error { |
||||
s.lock.Lock() |
||||
defer s.lock.Unlock() |
||||
// Now need to lock the provider
|
||||
s.p.lock.Lock() |
||||
defer s.p.lock.Unlock() |
||||
if oldUID, ok := s.data["_old_uid"]; (ok && (oldUID != "0" || len(s.data) > 1)) || (!ok && len(s.data) > 0) { |
||||
// Now ensure that we don't exist!
|
||||
realProvider := s.p.provider |
||||
|
||||
if !s.released && realProvider.Exist(s.sid) { |
||||
// This is an error!
|
||||
return fmt.Errorf("new sid '%s' already exists", s.sid) |
||||
} |
||||
realStore, err := realProvider.Read(s.sid) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if err := realStore.Flush(); err != nil { |
||||
return err |
||||
} |
||||
for key, value := range s.data { |
||||
if err := realStore.Set(key, value); err != nil { |
||||
return err |
||||
} |
||||
} |
||||
err = realStore.Release() |
||||
if err == nil { |
||||
s.released = true |
||||
} |
||||
return err |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
// Flush deletes all session data.
|
||||
func (s *VirtualStore) Flush() error { |
||||
s.lock.Lock() |
||||
defer s.lock.Unlock() |
||||
|
||||
s.data = make(map[interface{}]interface{}) |
||||
return nil |
||||
} |
@ -0,0 +1,12 @@ |
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package session |
||||
|
||||
// Store represents a session store
|
||||
type Store interface { |
||||
Get(interface{}) interface{} |
||||
Set(interface{}, interface{}) error |
||||
Delete(interface{}) error |
||||
} |
@ -0,0 +1,322 @@ |
||||
// Copyright 2020 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package web |
||||
|
||||
import ( |
||||
"fmt" |
||||
"net/http" |
||||
"reflect" |
||||
"strings" |
||||
|
||||
"code.gitea.io/gitea/modules/context" |
||||
"code.gitea.io/gitea/modules/middlewares" |
||||
|
||||
"gitea.com/go-chi/binding" |
||||
"github.com/go-chi/chi" |
||||
) |
||||
|
||||
// Wrap converts all kinds of routes to standard library one
|
||||
func Wrap(handlers ...interface{}) http.HandlerFunc { |
||||
if len(handlers) == 0 { |
||||
panic("No handlers found") |
||||
} |
||||
|
||||
for _, handler := range handlers { |
||||
switch t := handler.(type) { |
||||
case http.HandlerFunc, func(http.ResponseWriter, *http.Request), |
||||
func(ctx *context.Context), |
||||
func(*context.APIContext), |
||||
func(*context.PrivateContext), |
||||
func(http.Handler) http.Handler: |
||||
default: |
||||
panic(fmt.Sprintf("Unsupported handler type: %#v", t)) |
||||
} |
||||
} |
||||
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { |
||||
for i := 0; i < len(handlers); i++ { |
||||
handler := handlers[i] |
||||
switch t := handler.(type) { |
||||
case http.HandlerFunc: |
||||
t(resp, req) |
||||
if r, ok := resp.(context.ResponseWriter); ok && r.Status() > 0 { |
||||
return |
||||
} |
||||
case func(http.ResponseWriter, *http.Request): |
||||
t(resp, req) |
||||
if r, ok := resp.(context.ResponseWriter); ok && r.Status() > 0 { |
||||
return |
||||
} |
||||
case func(ctx *context.Context): |
||||
ctx := context.GetContext(req) |
||||
t(ctx) |
||||
if ctx.Written() { |
||||
return |
||||
} |
||||
case func(*context.APIContext): |
||||
ctx := context.GetAPIContext(req) |
||||
t(ctx) |
||||
if ctx.Written() { |
||||
return |
||||
} |
||||
case func(*context.PrivateContext): |
||||
ctx := context.GetPrivateContext(req) |
||||
t(ctx) |
||||
if ctx.Written() { |
||||
return |
||||
} |
||||
case func(http.Handler) http.Handler: |
||||
var next = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) |
||||
t(next).ServeHTTP(resp, req) |
||||
if r, ok := resp.(context.ResponseWriter); ok && r.Status() > 0 { |
||||
return |
||||
} |
||||
default: |
||||
panic(fmt.Sprintf("Unsupported handler type: %#v", t)) |
||||
} |
||||
} |
||||
}) |
||||
} |
||||
|
||||
// Middle wrap a context function as a chi middleware
|
||||
func Middle(f func(ctx *context.Context)) func(netx http.Handler) http.Handler { |
||||
return func(next http.Handler) http.Handler { |
||||
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { |
||||
ctx := context.GetContext(req) |
||||
f(ctx) |
||||
if ctx.Written() { |
||||
return |
||||
} |
||||
next.ServeHTTP(ctx.Resp, ctx.Req) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
// MiddleAPI wrap a context function as a chi middleware
|
||||
func MiddleAPI(f func(ctx *context.APIContext)) func(netx http.Handler) http.Handler { |
||||
return func(next http.Handler) http.Handler { |
||||
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { |
||||
ctx := context.GetAPIContext(req) |
||||
f(ctx) |
||||
if ctx.Written() { |
||||
return |
||||
} |
||||
next.ServeHTTP(ctx.Resp, ctx.Req) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
// Bind binding an obj to a handler
|
||||
func Bind(obj interface{}) http.HandlerFunc { |
||||
var tp = reflect.TypeOf(obj) |
||||
if tp.Kind() == reflect.Ptr { |
||||
tp = tp.Elem() |
||||
} |
||||
if tp.Kind() != reflect.Struct { |
||||
panic("Only structs are allowed to bind") |
||||
} |
||||
return Wrap(func(ctx *context.Context) { |
||||
var theObj = reflect.New(tp).Interface() // create a new form obj for every request but not use obj directly
|
||||
binding.Bind(ctx.Req, theObj) |
||||
SetForm(ctx, theObj) |
||||
middlewares.AssignForm(theObj, ctx.Data) |
||||
}) |
||||
} |
||||
|
||||
// SetForm set the form object
|
||||
func SetForm(data middlewares.DataStore, obj interface{}) { |
||||
data.GetData()["__form"] = obj |
||||
} |
||||
|
||||
// GetForm returns the validate form information
|
||||
func GetForm(data middlewares.DataStore) interface{} { |
||||
return data.GetData()["__form"] |
||||
} |
||||
|
||||
// Route defines a route based on chi's router
|
||||
type Route struct { |
||||
R chi.Router |
||||
curGroupPrefix string |
||||
curMiddlewares []interface{} |
||||
} |
||||
|
||||
// NewRoute creates a new route
|
||||
func NewRoute() *Route { |
||||
r := chi.NewRouter() |
||||
return &Route{ |
||||
R: r, |
||||
curGroupPrefix: "", |
||||
curMiddlewares: []interface{}{}, |
||||
} |
||||
} |
||||
|
||||
// Use supports two middlewares
|
||||
func (r *Route) Use(middlewares ...interface{}) { |
||||
if r.curGroupPrefix != "" { |
||||
r.curMiddlewares = append(r.curMiddlewares, middlewares...) |
||||
} else { |
||||
for _, middle := range middlewares { |
||||
switch t := middle.(type) { |
||||
case func(http.Handler) http.Handler: |
||||
r.R.Use(t) |
||||
case func(*context.Context): |
||||
r.R.Use(Middle(t)) |
||||
case func(*context.APIContext): |
||||
r.R.Use(MiddleAPI(t)) |
||||
default: |
||||
panic(fmt.Sprintf("Unsupported middleware type: %#v", t)) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Group mounts a sub-Router along a `pattern` string.
|
||||
func (r *Route) Group(pattern string, fn func(), middlewares ...interface{}) { |
||||
var previousGroupPrefix = r.curGroupPrefix |
||||
var previousMiddlewares = r.curMiddlewares |
||||
r.curGroupPrefix += pattern |
||||
r.curMiddlewares = append(r.curMiddlewares, middlewares...) |
||||
|
||||
fn() |
||||
|
||||
r.curGroupPrefix = previousGroupPrefix |
||||
r.curMiddlewares = previousMiddlewares |
||||
} |
||||
|
||||
func (r *Route) getPattern(pattern string) string { |
||||
newPattern := r.curGroupPrefix + pattern |
||||
if !strings.HasPrefix(newPattern, "/") { |
||||
newPattern = "/" + newPattern |
||||
} |
||||
if newPattern == "/" { |
||||
return newPattern |
||||
} |
||||
return strings.TrimSuffix(newPattern, "/") |
||||
} |
||||
|
||||
// Mount attaches another Route along ./pattern/*
|
||||
func (r *Route) Mount(pattern string, subR *Route) { |
||||
var middlewares = make([]interface{}, len(r.curMiddlewares)) |
||||
copy(middlewares, r.curMiddlewares) |
||||
subR.Use(middlewares...) |
||||
r.R.Mount(r.getPattern(pattern), subR.R) |
||||
} |
||||
|
||||
// Any delegate requests for all methods
|
||||
func (r *Route) Any(pattern string, h ...interface{}) { |
||||
var middlewares = r.getMiddlewares(h) |
||||
r.R.HandleFunc(r.getPattern(pattern), Wrap(middlewares...)) |
||||
} |
||||
|
||||
// Route delegate special methods
|
||||
func (r *Route) Route(pattern string, methods string, h ...interface{}) { |
||||
p := r.getPattern(pattern) |
||||
ms := strings.Split(methods, ",") |
||||
var middlewares = r.getMiddlewares(h) |
||||
for _, method := range ms { |
||||
r.R.MethodFunc(strings.TrimSpace(method), p, Wrap(middlewares...)) |
||||
} |
||||
} |
||||
|
||||
// Delete delegate delete method
|
||||
func (r *Route) Delete(pattern string, h ...interface{}) { |
||||
var middlewares = r.getMiddlewares(h) |
||||
r.R.Delete(r.getPattern(pattern), Wrap(middlewares...)) |
||||
} |
||||
|
||||
func (r *Route) getMiddlewares(h []interface{}) []interface{} { |
||||
var middlewares = make([]interface{}, len(r.curMiddlewares), len(r.curMiddlewares)+len(h)) |
||||
copy(middlewares, r.curMiddlewares) |
||||
middlewares = append(middlewares, h...) |
||||
return middlewares |
||||
} |
||||
|
||||
// Get delegate get method
|
||||
func (r *Route) Get(pattern string, h ...interface{}) { |
||||
var middlewares = r.getMiddlewares(h) |
||||
r.R.Get(r.getPattern(pattern), Wrap(middlewares...)) |
||||
} |
||||
|
||||
// Head delegate head method
|
||||
func (r *Route) Head(pattern string, h ...interface{}) { |
||||
var middlewares = r.getMiddlewares(h) |
||||
r.R.Head(r.getPattern(pattern), Wrap(middlewares...)) |
||||
} |
||||
|
||||
// Post delegate post method
|
||||
func (r *Route) Post(pattern string, h ...interface{}) { |
||||
var middlewares = r.getMiddlewares(h) |
||||
r.R.Post(r.getPattern(pattern), Wrap(middlewares...)) |
||||
} |
||||
|
||||
// Put delegate put method
|
||||
func (r *Route) Put(pattern string, h ...interface{}) { |
||||
var middlewares = r.getMiddlewares(h) |
||||
r.R.Put(r.getPattern(pattern), Wrap(middlewares...)) |
||||
} |
||||
|
||||
// Patch delegate patch method
|
||||
func (r *Route) Patch(pattern string, h ...interface{}) { |
||||
var middlewares = r.getMiddlewares(h) |
||||
r.R.Patch(r.getPattern(pattern), Wrap(middlewares...)) |
||||
} |
||||
|
||||
// ServeHTTP implements http.Handler
|
||||
func (r *Route) ServeHTTP(w http.ResponseWriter, req *http.Request) { |
||||
r.R.ServeHTTP(w, req) |
||||
} |
||||
|
||||
// NotFound defines a handler to respond whenever a route could
|
||||
// not be found.
|
||||
func (r *Route) NotFound(h http.HandlerFunc) { |
||||
r.R.NotFound(h) |
||||
} |
||||
|
||||
// MethodNotAllowed defines a handler to respond whenever a method is
|
||||
// not allowed.
|
||||
func (r *Route) MethodNotAllowed(h http.HandlerFunc) { |
||||
r.R.MethodNotAllowed(h) |
||||
} |
||||
|
||||
// Combo deletegate requests to Combo
|
||||
func (r *Route) Combo(pattern string, h ...interface{}) *Combo { |
||||
return &Combo{r, pattern, h} |
||||
} |
||||
|
||||
// Combo represents a tiny group routes with same pattern
|
||||
type Combo struct { |
||||
r *Route |
||||
pattern string |
||||
h []interface{} |
||||
} |
||||
|
||||
// Get deletegate Get method
|
||||
func (c *Combo) Get(h ...interface{}) *Combo { |
||||
c.r.Get(c.pattern, append(c.h, h...)...) |
||||
return c |
||||
} |
||||
|
||||
// Post deletegate Post method
|
||||
func (c *Combo) Post(h ...interface{}) *Combo { |
||||
c.r.Post(c.pattern, append(c.h, h...)...) |
||||
return c |
||||
} |
||||
|
||||
// Delete deletegate Delete method
|
||||
func (c *Combo) Delete(h ...interface{}) *Combo { |
||||
c.r.Delete(c.pattern, append(c.h, h...)...) |
||||
return c |
||||
} |
||||
|
||||
// Put deletegate Put method
|
||||
func (c *Combo) Put(h ...interface{}) *Combo { |
||||
c.r.Put(c.pattern, append(c.h, h...)...) |
||||
return c |
||||
} |
||||
|
||||
// Patch deletegate Patch method
|
||||
func (c *Combo) Patch(h ...interface{}) *Combo { |
||||
c.r.Patch(c.pattern, append(c.h, h...)...) |
||||
return c |
||||
} |
@ -0,0 +1,169 @@ |
||||
// Copyright 2021 The Gitea Authors. All rights reserved.
|
||||
// Use of this source code is governed by a MIT-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package web |
||||
|
||||
import ( |
||||
"bytes" |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"testing" |
||||
|
||||
"github.com/go-chi/chi" |
||||
"github.com/stretchr/testify/assert" |
||||
) |
||||
|
||||
func TestRoute1(t *testing.T) { |
||||
buff := bytes.NewBufferString("") |
||||
recorder := httptest.NewRecorder() |
||||
recorder.Body = buff |
||||
|
||||
r := NewRoute() |
||||
r.Get("/{username}/{reponame}/{type:issues|pulls}", func(resp http.ResponseWriter, req *http.Request) { |
||||
username := chi.URLParam(req, "username") |
||||
assert.EqualValues(t, "gitea", username) |
||||
reponame := chi.URLParam(req, "reponame") |
||||
assert.EqualValues(t, "gitea", reponame) |
||||
tp := chi.URLParam(req, "type") |
||||
assert.EqualValues(t, "issues", tp) |
||||
}) |
||||
|
||||
req, err := http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues", nil) |
||||
assert.NoError(t, err) |
||||
r.ServeHTTP(recorder, req) |
||||
assert.EqualValues(t, http.StatusOK, recorder.Code) |
||||
} |
||||
|
||||
func TestRoute2(t *testing.T) { |
||||
buff := bytes.NewBufferString("") |
||||
recorder := httptest.NewRecorder() |
||||
recorder.Body = buff |
||||
|
||||
var route int |
||||
|
||||
r := NewRoute() |
||||
r.Group("/{username}/{reponame}", func() { |
||||
r.Group("", func() { |
||||
r.Get("/{type:issues|pulls}", func(resp http.ResponseWriter, req *http.Request) { |
||||
username := chi.URLParam(req, "username") |
||||
assert.EqualValues(t, "gitea", username) |
||||
reponame := chi.URLParam(req, "reponame") |
||||
assert.EqualValues(t, "gitea", reponame) |
||||
tp := chi.URLParam(req, "type") |
||||
assert.EqualValues(t, "issues", tp) |
||||
route = 0 |
||||
}) |
||||
|
||||
r.Get("/{type:issues|pulls}/{index}", func(resp http.ResponseWriter, req *http.Request) { |
||||
username := chi.URLParam(req, "username") |
||||
assert.EqualValues(t, "gitea", username) |
||||
reponame := chi.URLParam(req, "reponame") |
||||
assert.EqualValues(t, "gitea", reponame) |
||||
tp := chi.URLParam(req, "type") |
||||
assert.EqualValues(t, "issues", tp) |
||||
index := chi.URLParam(req, "index") |
||||
assert.EqualValues(t, "1", index) |
||||
route = 1 |
||||
}) |
||||
}, func(resp http.ResponseWriter, req *http.Request) { |
||||
resp.WriteHeader(200) |
||||
}) |
||||
|
||||
r.Group("/issues/{index}", func() { |
||||
r.Get("/view", func(resp http.ResponseWriter, req *http.Request) { |
||||
username := chi.URLParam(req, "username") |
||||
assert.EqualValues(t, "gitea", username) |
||||
reponame := chi.URLParam(req, "reponame") |
||||
assert.EqualValues(t, "gitea", reponame) |
||||
index := chi.URLParam(req, "index") |
||||
assert.EqualValues(t, "1", index) |
||||
route = 2 |
||||
}) |
||||
}) |
||||
}) |
||||
|
||||
req, err := http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues", nil) |
||||
assert.NoError(t, err) |
||||
r.ServeHTTP(recorder, req) |
||||
assert.EqualValues(t, http.StatusOK, recorder.Code) |
||||
assert.EqualValues(t, 0, route) |
||||
|
||||
req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1", nil) |
||||
assert.NoError(t, err) |
||||
r.ServeHTTP(recorder, req) |
||||
assert.EqualValues(t, http.StatusOK, recorder.Code) |
||||
assert.EqualValues(t, 1, route) |
||||
|
||||
req, err = http.NewRequest("GET", "http://localhost:8000/gitea/gitea/issues/1/view", nil) |
||||
assert.NoError(t, err) |
||||
r.ServeHTTP(recorder, req) |
||||
assert.EqualValues(t, http.StatusOK, recorder.Code) |
||||
assert.EqualValues(t, 2, route) |
||||
} |
||||
|
||||
func TestRoute3(t *testing.T) { |
||||
buff := bytes.NewBufferString("") |
||||
recorder := httptest.NewRecorder() |
||||
recorder.Body = buff |
||||
|
||||
var route int |
||||
|
||||
m := NewRoute() |
||||
r := NewRoute() |
||||
r.Mount("/api/v1", m) |
||||
|
||||
m.Group("/repos", func() { |
||||
m.Group("/{username}/{reponame}", func() { |
||||
m.Group("/branch_protections", func() { |
||||
m.Get("", func(resp http.ResponseWriter, req *http.Request) { |
||||
route = 0 |
||||
}) |
||||
m.Post("", func(resp http.ResponseWriter, req *http.Request) { |
||||
route = 1 |
||||
}) |
||||
m.Group("/{name}", func() { |
||||
m.Get("", func(resp http.ResponseWriter, req *http.Request) { |
||||
route = 2 |
||||
}) |
||||
m.Patch("", func(resp http.ResponseWriter, req *http.Request) { |
||||
route = 3 |
||||
}) |
||||
m.Delete("", func(resp http.ResponseWriter, req *http.Request) { |
||||
route = 4 |
||||
}) |
||||
}) |
||||
}) |
||||
}) |
||||
}) |
||||
|
||||
req, err := http.NewRequest("GET", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections", nil) |
||||
assert.NoError(t, err) |
||||
r.ServeHTTP(recorder, req) |
||||
assert.EqualValues(t, http.StatusOK, recorder.Code) |
||||
assert.EqualValues(t, 0, route) |
||||
|
||||
req, err = http.NewRequest("POST", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections", nil) |
||||
assert.NoError(t, err) |
||||
r.ServeHTTP(recorder, req) |
||||
assert.EqualValues(t, http.StatusOK, recorder.Code, http.StatusOK) |
||||
assert.EqualValues(t, 1, route) |
||||
|
||||
req, err = http.NewRequest("GET", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) |
||||
assert.NoError(t, err) |
||||
r.ServeHTTP(recorder, req) |
||||
assert.EqualValues(t, http.StatusOK, recorder.Code) |
||||
assert.EqualValues(t, 2, route) |
||||
|
||||
req, err = http.NewRequest("PATCH", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) |
||||
assert.NoError(t, err) |
||||
r.ServeHTTP(recorder, req) |
||||
assert.EqualValues(t, http.StatusOK, recorder.Code) |
||||
assert.EqualValues(t, 3, route) |
||||
|
||||
req, err = http.NewRequest("DELETE", "http://localhost:8000/api/v1/repos/gitea/gitea/branch_protections/master", nil) |
||||
assert.NoError(t, err) |
||||
r.ServeHTTP(recorder, req) |
||||
assert.EqualValues(t, http.StatusOK, recorder.Code) |
||||
assert.EqualValues(t, 4, route) |
||||
} |
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue