diff --git a/veans/internal/auth/oauth_test.go b/veans/internal/auth/oauth_test.go
index 906addba1..5f6a68ca3 100644
--- a/veans/internal/auth/oauth_test.go
+++ b/veans/internal/auth/oauth_test.go
@@ -17,8 +17,12 @@
package auth
import (
+ "context"
"crypto/sha256"
"encoding/base64"
+ "net"
+ "net/http"
+ "net/http/httptest"
"strings"
"testing"
)
@@ -91,3 +95,222 @@ func TestBuildAuthorizeURL(t *testing.T) {
t.Errorf("trailing slash leaked into URL: %s", u2)
}
}
+
+func TestGenerateState_Shape(t *testing.T) {
+ s, err := generateState()
+ if err != nil {
+ t.Fatal(err)
+ }
+ // 24 random bytes base64url-no-pad → 32 chars.
+ if len(s) < 30 {
+ t.Fatalf("state length %d shorter than expected", len(s))
+ }
+ for _, r := range s {
+ switch {
+ case r >= 'A' && r <= 'Z',
+ r >= 'a' && r <= 'z',
+ r >= '0' && r <= '9',
+ r == '-', r == '_':
+ default:
+ t.Fatalf("state contains illegal rune %q", r)
+ }
+ }
+ // Decodes cleanly as base64url-no-pad.
+ if _, err := base64.RawURLEncoding.DecodeString(s); err != nil {
+ t.Fatalf("state isn't base64url-no-pad: %v", err)
+ }
+ // Two consecutive states should differ — sanity for entropy.
+ s2, _ := generateState()
+ if s == s2 {
+ t.Fatal("two consecutive states are identical — entropy is broken")
+ }
+}
+
+// newCallbackHandler returns just the http.Handler portion of
+// newCallbackServer so tests can drive it directly with httptest.NewRecorder
+// without binding a real loopback socket.
+func newCallbackHandler(t *testing.T) (http.Handler, <-chan callbackResult) {
+ t.Helper()
+ var lc net.ListenConfig
+ listener, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0")
+ if err != nil {
+ t.Fatalf("listen: %v", err)
+ }
+ t.Cleanup(func() { _ = listener.Close() })
+ server, ch := newCallbackServer(listener)
+ return server.Handler, ch
+}
+
+func TestNewCallbackServer_HappyPath(t *testing.T) {
+ handler, ch := newCallbackHandler(t)
+ req := httptest.NewRequest(http.MethodGet, "/callback?code=abc&state=xyz", nil)
+ rec := httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+
+ select {
+ case res := <-ch:
+ if res.code != "abc" {
+ t.Errorf("code = %q, want abc", res.code)
+ }
+ if res.state != "xyz" {
+ t.Errorf("state = %q, want xyz", res.state)
+ }
+ if res.err != nil {
+ t.Errorf("err = %v, want nil", res.err)
+ }
+ default:
+ t.Fatal("no result pushed to channel")
+ }
+
+ if ct := rec.Header().Get("Content-Type"); !strings.HasPrefix(ct, "text/html") {
+ t.Errorf("Content-Type = %q, want text/html…", ct)
+ }
+ if rec.Code != http.StatusOK {
+ t.Errorf("status = %d, want 200", rec.Code)
+ }
+}
+
+func TestNewCallbackServer_AuthzServerError(t *testing.T) {
+ handler, ch := newCallbackHandler(t)
+ req := httptest.NewRequest(http.MethodGet,
+ "/callback?error=access_denied&error_description=user+declined", nil)
+ rec := httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+
+ select {
+ case res := <-ch:
+ if res.err == nil {
+ t.Fatal("err = nil, want non-nil")
+ }
+ // renderCallbackPage uses error_description when present; the
+ // handler also stuffs it into res.err. "user declined" comes
+ // straight from error_description.
+ if !strings.Contains(res.err.Error(), "user declined") {
+ t.Errorf("err = %q, want it to mention error_description", res.err.Error())
+ }
+ default:
+ t.Fatal("no result pushed to channel")
+ }
+
+ if rec.Code != http.StatusBadRequest {
+ t.Errorf("status = %d, want 400", rec.Code)
+ }
+}
+
+func TestNewCallbackServer_AuthzServerErrorOnlyCode(t *testing.T) {
+ // When error_description is empty, the handler falls back to the
+ // `error` code in the user-visible message.
+ handler, ch := newCallbackHandler(t)
+ req := httptest.NewRequest(http.MethodGet, "/callback?error=access_denied", nil)
+ rec := httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+
+ select {
+ case res := <-ch:
+ if res.err == nil {
+ t.Fatal("err = nil, want non-nil")
+ }
+ if !strings.Contains(res.err.Error(), "access_denied") {
+ t.Errorf("err = %q, want it to mention error code", res.err.Error())
+ }
+ default:
+ t.Fatal("no result pushed to channel")
+ }
+}
+
+func TestNewCallbackServer_EmptyCode(t *testing.T) {
+ // Empty `code` without an `error` parameter is the handler's job
+ // only to forward verbatim — waitForCallback is the one that
+ // upgrades that to an error.
+ handler, ch := newCallbackHandler(t)
+ req := httptest.NewRequest(http.MethodGet, "/callback?state=xyz", nil)
+ rec := httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+
+ select {
+ case res := <-ch:
+ if res.code != "" {
+ t.Errorf("code = %q, want empty", res.code)
+ }
+ if res.state != "xyz" {
+ t.Errorf("state = %q, want xyz", res.state)
+ }
+ if res.err != nil {
+ t.Errorf("err = %v, want nil", res.err)
+ }
+ default:
+ t.Fatal("no result pushed to channel")
+ }
+}
+
+func TestNewCallbackServer_MethodNotAllowed(t *testing.T) {
+ handler, ch := newCallbackHandler(t)
+ req := httptest.NewRequest(http.MethodPost, "/callback?code=abc&state=xyz", nil)
+ rec := httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusMethodNotAllowed {
+ t.Errorf("status = %d, want 405", rec.Code)
+ }
+ if a := rec.Header().Get("Allow"); a != "GET" {
+ t.Errorf("Allow header = %q, want GET", a)
+ }
+ select {
+ case res := <-ch:
+ t.Fatalf("nothing should be pushed for a rejected method, got %+v", res)
+ default:
+ }
+}
+
+func TestNewCallbackServer_WrongPath(t *testing.T) {
+ handler, ch := newCallbackHandler(t)
+ req := httptest.NewRequest(http.MethodGet, "/something-else", nil)
+ rec := httptest.NewRecorder()
+
+ handler.ServeHTTP(rec, req)
+
+ if rec.Code != http.StatusNotFound {
+ t.Errorf("status = %d, want 404", rec.Code)
+ }
+ select {
+ case res := <-ch:
+ t.Fatalf("nothing should be pushed for a 404, got %+v", res)
+ default:
+ }
+}
+
+func TestRenderCallbackPage_HTMLEscapesError(t *testing.T) {
+ rec := httptest.NewRecorder()
+ renderCallbackPage(rec, &fakeError{msg: ``})
+
+ body := rec.Body.String()
+ if strings.Contains(body, "