Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 38 additions & 19 deletions providers/onedrive/onedrive.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,19 @@
package onedrive

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"

"github.com/markbates/goth"
"golang.org/x/oauth2"
)

const (
authURL string = "https://login.live.com/oauth20_authorize.srf"
tokenURL string = "https://login.live.com/oauth20_token.srf"
endpointProfile string = "https://apis.live.net/v5.0/me"
var (
authURL = "https://login.live.com/oauth20_authorize.srf"
tokenURL = "https://login.live.com/oauth20_token.srf"
endpointProfile = "https://graph.microsoft.com/v1.0/me"
)

// Provider is the implementation of `goth.Provider` for accessing Onedrive.
Expand All @@ -30,6 +28,18 @@ type Provider struct {
providerName string
}

func (p *Provider) SetAuthURL(url string) {
authURL = url
}

func (p *Provider) SetTokenURL(url string) {
tokenURL = url
}

func (p *Provider) SetEndpointProfile(url string) {
endpointProfile = url
}

// New creates a new Onedrive provider and sets up important connection details.
// You should always call `onedrive.New` to get a new provider. Never try to
// create one manually.
Expand Down Expand Up @@ -68,7 +78,7 @@ func (p *Provider) BeginAuth(state string) (goth.Session, error) {
}, nil
}

// FetchUser will go to Onedrive and access basic information about the user.
// FetchUser will go to Microsoft Graph API and access basic information about the user.
func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
sess := session.(*Session)
user := goth.User{
Expand All @@ -79,31 +89,40 @@ func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
}

if user.AccessToken == "" {
// data is not yet retrieved since accessToken is still empty
return user, fmt.Errorf("%s cannot get user information without accessToken", p.providerName)
}

response, err := p.Client().Get(endpointProfile + "?access_token=" + url.QueryEscape(sess.AccessToken))
req, err := http.NewRequest("GET", endpointProfile, nil)
if err != nil {
return user, err
}
defer response.Body.Close()

if response.StatusCode != http.StatusOK {
return user, fmt.Errorf("%s responded with a %d trying to fetch user information", p.providerName, response.StatusCode)
}
req.Header.Set("Authorization", "Bearer "+sess.AccessToken)

bits, err := io.ReadAll(response.Body)
resp, err := p.Client().Do(req)
if err != nil {
return user, err
}
defer resp.Body.Close()

err = json.NewDecoder(bytes.NewReader(bits)).Decode(&user.RawData)
if err != nil {
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return user, fmt.Errorf("%s responded with %d: %s", p.providerName, resp.StatusCode, string(body))
}

var raw map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&raw); err != nil {
return user, err
}
err = userFromReader(bytes.NewReader(bits), &user)
return user, err

user.RawData = raw
user.UserID = raw["id"].(string)
user.Email = raw["mail"].(string)
if user.Email == "" {
user.Email = raw["userPrincipalName"].(string) // Fallback if mail is missing
}
user.Name = raw["displayName"].(string)

return user, nil
}

func newConfig(provider *Provider, scopes []string) *oauth2.Config {
Expand Down