diff --git a/client.go b/client.go index c34fa1d..4384f3b 100644 --- a/client.go +++ b/client.go @@ -92,12 +92,19 @@ func (c *Client) ValidateToken(ctx context.Context) (userID *int64, err error) { return } var r struct { + Result *bool `json:"result"` UserID *int64 `json:"user_id"` } _, err = c.Do(req, &r) // nolint:bodyclose if err != nil { return nil, err } + if r.Result != nil && !*r.Result { + return nil, ErrInvalidToken + } + if r.UserID == nil { + return nil, ErrInvalidToken + } return r.UserID, nil } diff --git a/client_test.go b/client_test.go index 619990b..02e397a 100644 --- a/client_test.go +++ b/client_test.go @@ -2,6 +2,7 @@ package putio import ( "context" + "errors" "fmt" "net/http" "net/http/httptest" @@ -72,3 +73,59 @@ func TestNewRequest_customUserAgent(t *testing.T) { t.Errorf("got: %v, want: %v", got, userAgent) } } + +func TestValidateToken(t *testing.T) { + t.Run("valid token", func(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/v2/oauth2/validate", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, http.MethodGet) + _, _ = w.Write([]byte(`{"result":true,"user_id":123}`)) + }) + + got, err := client.ValidateToken(context.Background()) + if err != nil { + t.Fatalf("ValidateToken returned error: %v", err) + } + if got == nil || *got != 123 { + t.Fatalf("ValidateToken userID = %v, want 123", got) + } + }) + + t.Run("invalid token result", func(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/v2/oauth2/validate", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, http.MethodGet) + _, _ = w.Write([]byte(`{"result":false,"token_id":null,"token_scope":null,"user_id":null}`)) + }) + + got, err := client.ValidateToken(context.Background()) + if got != nil { + t.Fatalf("ValidateToken userID = %v, want nil", got) + } + if !errors.Is(err, ErrInvalidToken) { + t.Fatalf("ValidateToken error = %v, want %v", err, ErrInvalidToken) + } + }) + + t.Run("missing user id", func(t *testing.T) { + setup() + defer teardown() + + mux.HandleFunc("/v2/oauth2/validate", func(w http.ResponseWriter, r *http.Request) { + testMethod(t, r, http.MethodGet) + _, _ = w.Write([]byte(`{}`)) + }) + + got, err := client.ValidateToken(context.Background()) + if got != nil { + t.Fatalf("ValidateToken userID = %v, want nil", got) + } + if !errors.Is(err, ErrInvalidToken) { + t.Fatalf("ValidateToken error = %v, want %v", err, ErrInvalidToken) + } + }) +} diff --git a/error.go b/error.go index 7aa2ca1..45f5b0e 100644 --- a/error.go +++ b/error.go @@ -19,6 +19,7 @@ var ( ErrNoFileIsGiven = errors.New("no files given") ErrEmptyUserName = errors.New("empty username") ErrEmptyURL = errors.New("empty URL") + ErrInvalidToken = errors.New("invalid token") ErrUnexpected = errors.New("unexpected error") )