diff --git a/providers/onedrive/onedrive.go b/providers/onedrive/onedrive.go index 877894f8..4d61e2cd 100644 --- a/providers/onedrive/onedrive.go +++ b/providers/onedrive/onedrive.go @@ -3,21 +3,19 @@ package onedrive import ( - "bytes" "encoding/json" "fmt" "io" "net/http" - "net/url" "github.com/markbates/goth" "golang.org/x/oauth2" ) -const ( - authURL string = "https://login.live.com/oauth20_authorize.srf" - tokenURL string = "https://login.live.com/oauth20_token.srf" - endpointProfile string = "https://apis.live.net/v5.0/me" +var ( + authURL = "https://login.live.com/oauth20_authorize.srf" + tokenURL = "https://login.live.com/oauth20_token.srf" + endpointProfile = "https://graph.microsoft.com/v1.0/me" ) // Provider is the implementation of `goth.Provider` for accessing Onedrive. @@ -30,6 +28,18 @@ type Provider struct { providerName string } +func (p *Provider) SetAuthURL(url string) { + authURL = url +} + +func (p *Provider) SetTokenURL(url string) { + tokenURL = url +} + +func (p *Provider) SetEndpointProfile(url string) { + endpointProfile = url +} + // New creates a new Onedrive provider and sets up important connection details. // You should always call `onedrive.New` to get a new provider. Never try to // create one manually. @@ -68,7 +78,7 @@ func (p *Provider) BeginAuth(state string) (goth.Session, error) { }, nil } -// FetchUser will go to Onedrive and access basic information about the user. +// FetchUser will go to Microsoft Graph API and access basic information about the user. func (p *Provider) FetchUser(session goth.Session) (goth.User, error) { sess := session.(*Session) user := goth.User{ @@ -79,31 +89,40 @@ func (p *Provider) FetchUser(session goth.Session) (goth.User, error) { } if user.AccessToken == "" { - // data is not yet retrieved since accessToken is still empty return user, fmt.Errorf("%s cannot get user information without accessToken", p.providerName) } - response, err := p.Client().Get(endpointProfile + "?access_token=" + url.QueryEscape(sess.AccessToken)) + req, err := http.NewRequest("GET", endpointProfile, nil) if err != nil { return user, err } - defer response.Body.Close() - - if response.StatusCode != http.StatusOK { - return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, response.StatusCode) - } + req.Header.Set("Authorization", "Bearer "+sess.AccessToken) - bits, err := io.ReadAll(response.Body) + resp, err := p.Client().Do(req) if err != nil { return user, err } + defer resp.Body.Close() - err = json.NewDecoder(bytes.NewReader(bits)).Decode(&user.RawData) - if err != nil { + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return user, fmt.Errorf("%s responded with %d: %s", p.providerName, resp.StatusCode, string(body)) + } + + var raw map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&raw); err != nil { return user, err } - err = userFromReader(bytes.NewReader(bits), &user) - return user, err + + user.RawData = raw + user.UserID = raw["id"].(string) + user.Email = raw["mail"].(string) + if user.Email == "" { + user.Email = raw["userPrincipalName"].(string) // Fallback if mail is missing + } + user.Name = raw["displayName"].(string) + + return user, nil } func newConfig(provider *Provider, scopes []string) *oauth2.Config {