diff --git a/veans/internal/auth/oauth_test.go b/veans/internal/auth/oauth_test.go index 906addba1..5f6a68ca3 100644 --- a/veans/internal/auth/oauth_test.go +++ b/veans/internal/auth/oauth_test.go @@ -17,8 +17,12 @@ package auth import ( + "context" "crypto/sha256" "encoding/base64" + "net" + "net/http" + "net/http/httptest" "strings" "testing" ) @@ -91,3 +95,222 @@ func TestBuildAuthorizeURL(t *testing.T) { t.Errorf("trailing slash leaked into URL: %s", u2) } } + +func TestGenerateState_Shape(t *testing.T) { + s, err := generateState() + if err != nil { + t.Fatal(err) + } + // 24 random bytes base64url-no-pad → 32 chars. + if len(s) < 30 { + t.Fatalf("state length %d shorter than expected", len(s)) + } + for _, r := range s { + switch { + case r >= 'A' && r <= 'Z', + r >= 'a' && r <= 'z', + r >= '0' && r <= '9', + r == '-', r == '_': + default: + t.Fatalf("state contains illegal rune %q", r) + } + } + // Decodes cleanly as base64url-no-pad. + if _, err := base64.RawURLEncoding.DecodeString(s); err != nil { + t.Fatalf("state isn't base64url-no-pad: %v", err) + } + // Two consecutive states should differ — sanity for entropy. + s2, _ := generateState() + if s == s2 { + t.Fatal("two consecutive states are identical — entropy is broken") + } +} + +// newCallbackHandler returns just the http.Handler portion of +// newCallbackServer so tests can drive it directly with httptest.NewRecorder +// without binding a real loopback socket. +func newCallbackHandler(t *testing.T) (http.Handler, <-chan callbackResult) { + t.Helper() + var lc net.ListenConfig + listener, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + t.Cleanup(func() { _ = listener.Close() }) + server, ch := newCallbackServer(listener) + return server.Handler, ch +} + +func TestNewCallbackServer_HappyPath(t *testing.T) { + handler, ch := newCallbackHandler(t) + req := httptest.NewRequest(http.MethodGet, "/callback?code=abc&state=xyz", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + select { + case res := <-ch: + if res.code != "abc" { + t.Errorf("code = %q, want abc", res.code) + } + if res.state != "xyz" { + t.Errorf("state = %q, want xyz", res.state) + } + if res.err != nil { + t.Errorf("err = %v, want nil", res.err) + } + default: + t.Fatal("no result pushed to channel") + } + + if ct := rec.Header().Get("Content-Type"); !strings.HasPrefix(ct, "text/html") { + t.Errorf("Content-Type = %q, want text/html…", ct) + } + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want 200", rec.Code) + } +} + +func TestNewCallbackServer_AuthzServerError(t *testing.T) { + handler, ch := newCallbackHandler(t) + req := httptest.NewRequest(http.MethodGet, + "/callback?error=access_denied&error_description=user+declined", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + select { + case res := <-ch: + if res.err == nil { + t.Fatal("err = nil, want non-nil") + } + // renderCallbackPage uses error_description when present; the + // handler also stuffs it into res.err. "user declined" comes + // straight from error_description. + if !strings.Contains(res.err.Error(), "user declined") { + t.Errorf("err = %q, want it to mention error_description", res.err.Error()) + } + default: + t.Fatal("no result pushed to channel") + } + + if rec.Code != http.StatusBadRequest { + t.Errorf("status = %d, want 400", rec.Code) + } +} + +func TestNewCallbackServer_AuthzServerErrorOnlyCode(t *testing.T) { + // When error_description is empty, the handler falls back to the + // `error` code in the user-visible message. + handler, ch := newCallbackHandler(t) + req := httptest.NewRequest(http.MethodGet, "/callback?error=access_denied", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + select { + case res := <-ch: + if res.err == nil { + t.Fatal("err = nil, want non-nil") + } + if !strings.Contains(res.err.Error(), "access_denied") { + t.Errorf("err = %q, want it to mention error code", res.err.Error()) + } + default: + t.Fatal("no result pushed to channel") + } +} + +func TestNewCallbackServer_EmptyCode(t *testing.T) { + // Empty `code` without an `error` parameter is the handler's job + // only to forward verbatim — waitForCallback is the one that + // upgrades that to an error. + handler, ch := newCallbackHandler(t) + req := httptest.NewRequest(http.MethodGet, "/callback?state=xyz", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + select { + case res := <-ch: + if res.code != "" { + t.Errorf("code = %q, want empty", res.code) + } + if res.state != "xyz" { + t.Errorf("state = %q, want xyz", res.state) + } + if res.err != nil { + t.Errorf("err = %v, want nil", res.err) + } + default: + t.Fatal("no result pushed to channel") + } +} + +func TestNewCallbackServer_MethodNotAllowed(t *testing.T) { + handler, ch := newCallbackHandler(t) + req := httptest.NewRequest(http.MethodPost, "/callback?code=abc&state=xyz", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusMethodNotAllowed { + t.Errorf("status = %d, want 405", rec.Code) + } + if a := rec.Header().Get("Allow"); a != "GET" { + t.Errorf("Allow header = %q, want GET", a) + } + select { + case res := <-ch: + t.Fatalf("nothing should be pushed for a rejected method, got %+v", res) + default: + } +} + +func TestNewCallbackServer_WrongPath(t *testing.T) { + handler, ch := newCallbackHandler(t) + req := httptest.NewRequest(http.MethodGet, "/something-else", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("status = %d, want 404", rec.Code) + } + select { + case res := <-ch: + t.Fatalf("nothing should be pushed for a 404, got %+v", res) + default: + } +} + +func TestRenderCallbackPage_HTMLEscapesError(t *testing.T) { + rec := httptest.NewRecorder() + renderCallbackPage(rec, &fakeError{msg: ``}) + + body := rec.Body.String() + if strings.Contains(body, "