bump deps

Signed-off-by: Alex Suraci <asuraci@pivotal.io>
This commit is contained in:
Chris Brown
2015-02-17 18:08:25 -08:00
committed by Alex Suraci
parent efbfb91683
commit 8b3dc6edbe
57 changed files with 1827 additions and 255 deletions

View File

@@ -67,7 +67,7 @@ func main() {
// Get an authorization code from the data provider.
// ("Please ask the user if I can access this resource.")
url := config.AuthCodeURL("")
fmt.Println("Visit this URL to get a code, then run again with -code=YOUR_CODE\n")
fmt.Print("Visit this URL to get a code, then run again with -code=YOUR_CODE\n\n")
fmt.Println(url)
return
}

View File

@@ -2,8 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// The oauth package provides support for making
// OAuth2-authenticated HTTP requests.
// Package oauth supports making OAuth2-authenticated HTTP requests.
//
// Example usage:
//
@@ -39,14 +38,23 @@ package oauth
import (
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"mime"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
)
// OAuthError is the error type returned by many operations.
//
// In retrospect it should not exist. Don't depend on it.
type OAuthError struct {
prefix string
msg string
@@ -122,7 +130,16 @@ type Config struct {
// TokenCache allows tokens to be cached for subsequent requests.
TokenCache Cache
AccessType string // Optional, "online" (default) or "offline", no refresh token if "online"
// AccessType is an OAuth extension that gets sent as the
// "access_type" field in the URL from AuthCodeURL.
// See https://developers.google.com/accounts/docs/OAuth2WebServer.
// It may be "online" (the default) or "offline".
// If your application needs to refresh access tokens when the
// user is not present at the browser, then use offline. This
// will result in your application obtaining a refresh token
// the first time your application exchanges an authorization
// code for a user.
AccessType string
// ApprovalPrompt indicates whether the user should be
// re-prompted for consent. If set to "auto" (default) the
@@ -139,11 +156,20 @@ type Config struct {
type Token struct {
AccessToken string
RefreshToken string
Expiry time.Time // If zero the token has no (known) expiry time.
Extra map[string]string // May be nil.
Expiry time.Time // If zero the token has no (known) expiry time.
// Extra optionally contains extra metadata from the server
// when updating a token. The only current key that may be
// populated is "id_token". It may be nil and will be
// initialized as needed.
Extra map[string]string
}
// Expired reports whether the token has expired or is invalid.
func (t *Token) Expired() bool {
if t.AccessToken == "" {
return true
}
if t.Expiry.IsZero() {
return false
}
@@ -164,6 +190,9 @@ type Transport struct {
*Config
*Token
// mu guards modifying the token.
mu sync.Mutex
// Transport is the HTTP transport to use when making requests.
// It will default to http.DefaultTransport if nil.
// (It should never be an oauth.Transport.)
@@ -192,11 +221,11 @@ func (c *Config) AuthCodeURL(state string) string {
q := url.Values{
"response_type": {"code"},
"client_id": {c.ClientId},
"redirect_uri": {c.RedirectURL},
"scope": {c.Scope},
"state": {state},
"access_type": {c.AccessType},
"approval_prompt": {c.ApprovalPrompt},
"state": condVal(state),
"scope": condVal(c.Scope),
"redirect_uri": condVal(c.RedirectURL),
"access_type": condVal(c.AccessType),
"approval_prompt": condVal(c.ApprovalPrompt),
}.Encode()
if url_.RawQuery == "" {
url_.RawQuery = q
@@ -206,6 +235,13 @@ func (c *Config) AuthCodeURL(state string) string {
return url_.String()
}
func condVal(v string) []string {
if v == "" {
return nil
}
return []string{v}
}
// Exchange takes a code and gets access Token from the remote server.
func (t *Transport) Exchange(code string) (*Token, error) {
if t.Config == nil {
@@ -246,35 +282,48 @@ func (t *Transport) Exchange(code string) (*Token, error) {
// If the Token is invalid callers should expect HTTP-level errors,
// as indicated by the Response's StatusCode.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
accessToken, err := t.getAccessToken()
if err != nil {
return nil, err
}
// To set the Authorization header, we must make a copy of the Request
// so that we don't modify the Request we were given.
// This is required by the specification of http.RoundTripper.
req = cloneRequest(req)
req.Header.Set("Authorization", "Bearer "+accessToken)
// Make the HTTP request.
return t.transport().RoundTrip(req)
}
func (t *Transport) getAccessToken() (string, error) {
t.mu.Lock()
defer t.mu.Unlock()
if t.Token == nil {
if t.Config == nil {
return nil, OAuthError{"RoundTrip", "no Config supplied"}
return "", OAuthError{"RoundTrip", "no Config supplied"}
}
if t.TokenCache == nil {
return nil, OAuthError{"RoundTrip", "no Token supplied"}
return "", OAuthError{"RoundTrip", "no Token supplied"}
}
var err error
t.Token, err = t.TokenCache.Token()
if err != nil {
return nil, err
return "", err
}
}
// Refresh the Token if it has expired.
if t.Expired() {
if err := t.Refresh(); err != nil {
return nil, err
return "", err
}
}
// To set the Authorization header, we must make a copy of the Request
// so that we don't modify the Request we were given.
// This is required by the specification of http.RoundTripper.
req = cloneRequest(req)
req.Header.Set("Authorization", "Bearer "+t.AccessToken)
// Make the HTTP request.
return t.transport().RoundTrip(req)
if t.AccessToken == "" {
return "", errors.New("no access token obtained from refresh")
}
return t.AccessToken, nil
}
// cloneRequest returns a clone of the provided *http.Request.
@@ -316,31 +365,80 @@ func (t *Transport) Refresh() error {
return nil
}
// AuthenticateClient gets an access Token using the client_credentials grant
// type.
func (t *Transport) AuthenticateClient() error {
if t.Config == nil {
return OAuthError{"Exchange", "no Config supplied"}
}
if t.Token == nil {
t.Token = &Token{}
}
return t.updateToken(t.Token, url.Values{"grant_type": {"client_credentials"}})
}
// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL
// implements the OAuth2 spec correctly
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
// In summary:
// - Reddit only accepts client secret in the Authorization header
// - Dropbox accepts either it in URL param or Auth header, but not both.
// - Google only accepts URL param (not spec compliant?), not Auth header
func providerAuthHeaderWorks(tokenURL string) bool {
if strings.HasPrefix(tokenURL, "https://accounts.google.com/") ||
strings.HasPrefix(tokenURL, "https://github.com/") ||
strings.HasPrefix(tokenURL, "https://api.instagram.com/") ||
strings.HasPrefix(tokenURL, "https://www.douban.com/") {
// Some sites fail to implement the OAuth2 spec fully.
return false
}
// Assume the provider implements the spec properly
// otherwise. We can add more exceptions as they're
// discovered. We will _not_ be adding configurable hooks
// to this package to let users select server bugs.
return true
}
// updateToken mutates both tok and v.
func (t *Transport) updateToken(tok *Token, v url.Values) error {
v.Set("client_id", t.ClientId)
v.Set("client_secret", t.ClientSecret)
r, err := (&http.Client{Transport: t.transport()}).PostForm(t.TokenURL, v)
bustedAuth := !providerAuthHeaderWorks(t.TokenURL)
if bustedAuth {
v.Set("client_secret", t.ClientSecret)
}
client := &http.Client{Transport: t.transport()}
req, err := http.NewRequest("POST", t.TokenURL, strings.NewReader(v.Encode()))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if !bustedAuth {
req.SetBasicAuth(t.ClientId, t.ClientSecret)
}
r, err := client.Do(req)
if err != nil {
return err
}
defer r.Body.Close()
if r.StatusCode != 200 {
return OAuthError{"updateToken", r.Status}
return OAuthError{"updateToken", "Unexpected HTTP status " + r.Status}
}
var b struct {
Access string `json:"access_token"`
Refresh string `json:"refresh_token"`
ExpiresIn time.Duration `json:"expires_in"`
Id string `json:"id_token"`
Access string `json:"access_token"`
Refresh string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"` // seconds
Id string `json:"id_token"`
}
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
if err != nil {
return err
}
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
switch content {
case "application/x-www-form-urlencoded", "text/plain":
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return err
}
vals, err := url.ParseQuery(string(body))
if err != nil {
return err
@@ -348,25 +446,25 @@ func (t *Transport) updateToken(tok *Token, v url.Values) error {
b.Access = vals.Get("access_token")
b.Refresh = vals.Get("refresh_token")
b.ExpiresIn, _ = time.ParseDuration(vals.Get("expires_in") + "s")
b.ExpiresIn, _ = strconv.ParseInt(vals.Get("expires_in"), 10, 64)
b.Id = vals.Get("id_token")
default:
if err = json.NewDecoder(r.Body).Decode(&b); err != nil {
return err
if err = json.Unmarshal(body, &b); err != nil {
return fmt.Errorf("got bad response from server: %q", body)
}
// The JSON parser treats the unitless ExpiresIn like 'ns' instead of 's' as above,
// so compensate here.
b.ExpiresIn *= time.Second
}
if b.Access == "" {
return errors.New("received empty access token from authorization server")
}
tok.AccessToken = b.Access
// Don't overwrite `RefreshToken` with an empty value
if len(b.Refresh) > 0 {
if b.Refresh != "" {
tok.RefreshToken = b.Refresh
}
if b.ExpiresIn == 0 {
tok.Expiry = time.Time{}
} else {
tok.Expiry = time.Now().Add(b.ExpiresIn)
tok.Expiry = time.Now().Add(time.Duration(b.ExpiresIn) * time.Second)
}
if b.Id != "" {
if tok.Extra == nil {

View File

@@ -25,6 +25,7 @@ var requests = []struct {
path: "/token",
query: "grant_type=authorization_code&code=c0d3&client_id=cl13nt1d",
contenttype: "application/json",
auth: "Basic Y2wxM250MWQ6czNjcjN0",
body: `
{
"access_token":"token1",
@@ -39,6 +40,7 @@ var requests = []struct {
path: "/token",
query: "grant_type=refresh_token&refresh_token=refreshtoken1&client_id=cl13nt1d",
contenttype: "application/json",
auth: "Basic Y2wxM250MWQ6czNjcjN0",
body: `
{
"access_token":"token2",
@@ -54,8 +56,22 @@ var requests = []struct {
query: "grant_type=refresh_token&refresh_token=refreshtoken2&client_id=cl13nt1d",
contenttype: "application/x-www-form-urlencoded",
body: "access_token=token3&refresh_token=refreshtoken3&id_token=idtoken3&expires_in=3600",
auth: "Basic Y2wxM250MWQ6czNjcjN0",
},
{path: "/secure", auth: "Bearer token3", body: "third payload"},
{
path: "/token",
query: "grant_type=client_credentials&client_id=cl13nt1d",
contenttype: "application/json",
auth: "Basic Y2wxM250MWQ6czNjcjN0",
body: `
{
"access_token":"token4",
"expires_in":3600
}
`,
},
{path: "/secure", auth: "Bearer token4", body: "fourth payload"},
}
func TestOAuth(t *testing.T) {
@@ -131,6 +147,18 @@ func TestOAuth(t *testing.T) {
}
checkBody(t, resp, "third payload")
checkToken(t, transport.Token, "token3", "refreshtoken3", "idtoken3")
transport.Token = &Token{}
err = transport.AuthenticateClient()
if err != nil {
t.Fatalf("AuthenticateClient: %v", err)
}
checkToken(t, transport.Token, "token4", "", "")
resp, err = c.Get(server.URL + "/secure")
if err != nil {
t.Fatalf("Get: %v", err)
}
checkBody(t, resp, "fourth payload")
}
func checkToken(t *testing.T, tok *Token, access, refresh, id string) {
@@ -143,16 +171,21 @@ func checkToken(t *testing.T, tok *Token, access, refresh, id string) {
if g, w := tok.Extra["id_token"], id; g != w {
t.Errorf("Extra['id_token'] = %q, want %q", g, w)
}
exp := tok.Expiry.Sub(time.Now())
if (time.Hour-time.Second) > exp || exp > time.Hour {
t.Errorf("Expiry = %v, want ~1 hour", exp)
if tok.Expiry.IsZero() {
t.Errorf("Expiry is zero; want ~1 hour")
} else {
exp := tok.Expiry.Sub(time.Now())
const slop = 3 * time.Second // time moving during test
if (time.Hour-slop) > exp || exp > time.Hour {
t.Errorf("Expiry = %v, want ~1 hour", exp)
}
}
}
func checkBody(t *testing.T, r *http.Response, body string) {
b, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Error("reading reponse body: %v, want %q", err, body)
t.Errorf("reading reponse body: %v, want %q", err, body)
}
if g, w := string(b), body; g != w {
t.Errorf("request body mismatch: got %q, want %q", g, w)
@@ -184,3 +217,20 @@ func TestCachePermissions(t *testing.T) {
t.Errorf("Created cache file has mode %#o, want non-accessible to group+other", fi.Mode())
}
}
func TestTokenExpired(t *testing.T) {
tests := []struct {
token Token
expired bool
}{
{Token{AccessToken: "foo"}, false},
{Token{AccessToken: ""}, true},
{Token{AccessToken: "foo", Expiry: time.Now().Add(-1 * time.Hour)}, true},
{Token{AccessToken: "foo", Expiry: time.Now().Add(1 * time.Hour)}, false},
}
for _, tt := range tests {
if got := tt.token.Expired(); got != tt.expired {
t.Errorf("token %+v Expired = %v; want %v", tt.token, got, !got)
}
}
}