@ -26,7 +26,10 @@ import (
"code.google.com/p/goauth2/oauth"
"github.com/go-martini/martini"
"github.com/martini-contrib/sessions"
"github.com/gogits/session"
"github.com/gogits/gogs/modules/middleware"
)
const (
@ -142,23 +145,23 @@ func NewOAuth2Provider(opts *Options) martini.Handler {
Transport : http . DefaultTransport ,
}
return func ( s sessions . Session , c martini . Context , w http . ResponseWriter , r * http . Reques t) {
if r . Method == "GET" {
switch r . URL . Path {
return func ( c martini . Context , ctx * middleware . Contex t) {
if ctx . Req . Method == "GET" {
switch ctx . Req . URL . Path {
case PathLogin :
login ( transport , s , w , r )
login ( transport , ctx )
case PathLogout :
logout ( transport , s , w , r )
logout ( transport , ctx )
case PathCallback :
handleOAuth2Callback ( transport , s , w , r )
handleOAuth2Callback ( transport , ctx )
}
}
tk := unmarshallToken ( s )
tk := unmarshallToken ( ctx . Session )
if tk != nil {
// check if the access token is expired
if tk . IsExpired ( ) && tk . Refresh ( ) == "" {
s . Delete ( keyToken )
ctx . Session . Delete ( keyToken )
tk = nil
}
}
@ -172,49 +175,49 @@ func NewOAuth2Provider(opts *Options) martini.Handler {
// Sample usage:
// m.Get("/login-required", oauth2.LoginRequired, func() ... {})
var LoginRequired martini . Handler = func ( ) martini . Handler {
return func ( s sessions . Session , c martini . Context , w http . ResponseWriter , r * http . Reques t) {
token := unmarshallToken ( s )
return func ( c martini . Context , ctx * middleware . Contex t) {
token := unmarshallToken ( ctx . Session )
if token == nil || token . IsExpired ( ) {
next := url . QueryEscape ( r . URL . RequestURI ( ) )
http . Redirect ( w , r , PathLogin + "?next=" + next , codeRedirect )
next := url . QueryEscape ( ctx . Req . URL . RequestURI ( ) )
ctx . Redirect ( PathLogin + "?next=" + next , codeRedirect )
}
}
} ( )
func login ( t * oauth . Transport , s sessions . Session , w http . ResponseWriter , r * http . Reques t) {
next := extractPath ( r . URL . Query ( ) . Get ( keyNextPage ) )
if s . Get ( keyToken ) == nil {
func login ( t * oauth . Transport , ctx * middleware . Contex t) {
next := extractPath ( ctx . Req . URL . Query ( ) . Get ( keyNextPage ) )
if ctx . Session . Get ( keyToken ) == nil {
// User is not logged in.
http . Redirect ( w , r , t . Config . AuthCodeURL ( next ) , codeRedirect )
ctx . Redirect ( t . Config . AuthCodeURL ( next ) , codeRedirect )
return
}
// No need to login, redirect to the next page.
http . Redirect ( w , r , next , codeRedirect )
ctx . Redirect ( next , codeRedirect )
}
func logout ( t * oauth . Transport , s sessions . Session , w http . ResponseWriter , r * http . Reques t) {
next := extractPath ( r . URL . Query ( ) . Get ( keyNextPage ) )
s . Delete ( keyToken )
http . Redirect ( w , r , next , codeRedirect )
func logout ( t * oauth . Transport , ctx * middleware . Contex t) {
next := extractPath ( ctx . Req . URL . Query ( ) . Get ( keyNextPage ) )
ctx . Session . Delete ( keyToken )
ctx . Redirect ( next , codeRedirect )
}
func handleOAuth2Callback ( t * oauth . Transport , s sessions . Session , w http . ResponseWriter , r * http . Reques t) {
next := extractPath ( r . URL . Query ( ) . Get ( "state" ) )
code := r . URL . Query ( ) . Get ( "code" )
func handleOAuth2Callback ( t * oauth . Transport , ctx * middleware . Contex t) {
next := extractPath ( ctx . Req . URL . Query ( ) . Get ( "state" ) )
code := ctx . Req . URL . Query ( ) . Get ( "code" )
tk , err := t . Exchange ( code )
if err != nil {
// Pass the error message, or allow dev to provide its own
// error handler.
http . Redirect ( w , r , PathError , codeRedirect )
ctx . Redirect ( PathError , codeRedirect )
return
}
// Store the credentials in the session.
val , _ := json . Marshal ( tk )
s . Set ( keyToken , val )
http . Redirect ( w , r , next , codeRedirect )
ctx . Session . Set ( keyToken , val )
ctx . Redirect ( next , codeRedirect )
}
func unmarshallToken ( s sessions . Session ) ( t * token ) {
func unmarshallToken ( s session . SessionStore ) ( t * token ) {
if s . Get ( keyToken ) == nil {
return
}