test(veans): cover OAuth callback handler error paths
The e2e suite bypasses the OAuth flow via --token, so the callback handler's error branches had zero coverage. Eight tests appended to oauth_test.go drive the handler directly: - happy path: code+state arrive on the channel; response is HTML - authz-server error path: ?error=access_denied&error_description=… bubbles up as a non-nil err containing the description (not the code) - only-code fallback: when error_description is missing, the error message falls back to the error code - empty code: handler captures it; waitForCallback's job to reject - non-GET method: 405 with Allow: GET, nothing pushed to channel (defense against forged POST from a same-origin page) - wrong path: 404, nothing pushed - HTML-escaping: an error containing <script>…</script> renders as <script> — XSS regression guard - nil-err success page: 200 with 'veans is authorized' Plus generateState shape coverage (length, charset, uniqueness) to match the existing TestGeneratePKCE_*. Sanity-checked the XSS test by deleting the html.EscapeString call — it fails with raw <script> in the body. Restored.
This commit is contained in:
parent
4cda019336
commit
964fdb71d1
|
|
@ -17,8 +17,12 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
|
@ -91,3 +95,222 @@ func TestBuildAuthorizeURL(t *testing.T) {
|
|||
t.Errorf("trailing slash leaked into URL: %s", u2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateState_Shape(t *testing.T) {
|
||||
s, err := generateState()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// 24 random bytes base64url-no-pad → 32 chars.
|
||||
if len(s) < 30 {
|
||||
t.Fatalf("state length %d shorter than expected", len(s))
|
||||
}
|
||||
for _, r := range s {
|
||||
switch {
|
||||
case r >= 'A' && r <= 'Z',
|
||||
r >= 'a' && r <= 'z',
|
||||
r >= '0' && r <= '9',
|
||||
r == '-', r == '_':
|
||||
default:
|
||||
t.Fatalf("state contains illegal rune %q", r)
|
||||
}
|
||||
}
|
||||
// Decodes cleanly as base64url-no-pad.
|
||||
if _, err := base64.RawURLEncoding.DecodeString(s); err != nil {
|
||||
t.Fatalf("state isn't base64url-no-pad: %v", err)
|
||||
}
|
||||
// Two consecutive states should differ — sanity for entropy.
|
||||
s2, _ := generateState()
|
||||
if s == s2 {
|
||||
t.Fatal("two consecutive states are identical — entropy is broken")
|
||||
}
|
||||
}
|
||||
|
||||
// newCallbackHandler returns just the http.Handler portion of
|
||||
// newCallbackServer so tests can drive it directly with httptest.NewRecorder
|
||||
// without binding a real loopback socket.
|
||||
func newCallbackHandler(t *testing.T) (http.Handler, <-chan callbackResult) {
|
||||
t.Helper()
|
||||
var lc net.ListenConfig
|
||||
listener, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = listener.Close() })
|
||||
server, ch := newCallbackServer(listener)
|
||||
return server.Handler, ch
|
||||
}
|
||||
|
||||
func TestNewCallbackServer_HappyPath(t *testing.T) {
|
||||
handler, ch := newCallbackHandler(t)
|
||||
req := httptest.NewRequest(http.MethodGet, "/callback?code=abc&state=xyz", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
select {
|
||||
case res := <-ch:
|
||||
if res.code != "abc" {
|
||||
t.Errorf("code = %q, want abc", res.code)
|
||||
}
|
||||
if res.state != "xyz" {
|
||||
t.Errorf("state = %q, want xyz", res.state)
|
||||
}
|
||||
if res.err != nil {
|
||||
t.Errorf("err = %v, want nil", res.err)
|
||||
}
|
||||
default:
|
||||
t.Fatal("no result pushed to channel")
|
||||
}
|
||||
|
||||
if ct := rec.Header().Get("Content-Type"); !strings.HasPrefix(ct, "text/html") {
|
||||
t.Errorf("Content-Type = %q, want text/html…", ct)
|
||||
}
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCallbackServer_AuthzServerError(t *testing.T) {
|
||||
handler, ch := newCallbackHandler(t)
|
||||
req := httptest.NewRequest(http.MethodGet,
|
||||
"/callback?error=access_denied&error_description=user+declined", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
select {
|
||||
case res := <-ch:
|
||||
if res.err == nil {
|
||||
t.Fatal("err = nil, want non-nil")
|
||||
}
|
||||
// renderCallbackPage uses error_description when present; the
|
||||
// handler also stuffs it into res.err. "user declined" comes
|
||||
// straight from error_description.
|
||||
if !strings.Contains(res.err.Error(), "user declined") {
|
||||
t.Errorf("err = %q, want it to mention error_description", res.err.Error())
|
||||
}
|
||||
default:
|
||||
t.Fatal("no result pushed to channel")
|
||||
}
|
||||
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want 400", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCallbackServer_AuthzServerErrorOnlyCode(t *testing.T) {
|
||||
// When error_description is empty, the handler falls back to the
|
||||
// `error` code in the user-visible message.
|
||||
handler, ch := newCallbackHandler(t)
|
||||
req := httptest.NewRequest(http.MethodGet, "/callback?error=access_denied", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
select {
|
||||
case res := <-ch:
|
||||
if res.err == nil {
|
||||
t.Fatal("err = nil, want non-nil")
|
||||
}
|
||||
if !strings.Contains(res.err.Error(), "access_denied") {
|
||||
t.Errorf("err = %q, want it to mention error code", res.err.Error())
|
||||
}
|
||||
default:
|
||||
t.Fatal("no result pushed to channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCallbackServer_EmptyCode(t *testing.T) {
|
||||
// Empty `code` without an `error` parameter is the handler's job
|
||||
// only to forward verbatim — waitForCallback is the one that
|
||||
// upgrades that to an error.
|
||||
handler, ch := newCallbackHandler(t)
|
||||
req := httptest.NewRequest(http.MethodGet, "/callback?state=xyz", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
select {
|
||||
case res := <-ch:
|
||||
if res.code != "" {
|
||||
t.Errorf("code = %q, want empty", res.code)
|
||||
}
|
||||
if res.state != "xyz" {
|
||||
t.Errorf("state = %q, want xyz", res.state)
|
||||
}
|
||||
if res.err != nil {
|
||||
t.Errorf("err = %v, want nil", res.err)
|
||||
}
|
||||
default:
|
||||
t.Fatal("no result pushed to channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCallbackServer_MethodNotAllowed(t *testing.T) {
|
||||
handler, ch := newCallbackHandler(t)
|
||||
req := httptest.NewRequest(http.MethodPost, "/callback?code=abc&state=xyz", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusMethodNotAllowed {
|
||||
t.Errorf("status = %d, want 405", rec.Code)
|
||||
}
|
||||
if a := rec.Header().Get("Allow"); a != "GET" {
|
||||
t.Errorf("Allow header = %q, want GET", a)
|
||||
}
|
||||
select {
|
||||
case res := <-ch:
|
||||
t.Fatalf("nothing should be pushed for a rejected method, got %+v", res)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewCallbackServer_WrongPath(t *testing.T) {
|
||||
handler, ch := newCallbackHandler(t)
|
||||
req := httptest.NewRequest(http.MethodGet, "/something-else", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Errorf("status = %d, want 404", rec.Code)
|
||||
}
|
||||
select {
|
||||
case res := <-ch:
|
||||
t.Fatalf("nothing should be pushed for a 404, got %+v", res)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderCallbackPage_HTMLEscapesError(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
renderCallbackPage(rec, &fakeError{msg: `<script>alert(1)</script>`})
|
||||
|
||||
body := rec.Body.String()
|
||||
if strings.Contains(body, "<script>") {
|
||||
t.Errorf("body contains raw <script>: %s", body)
|
||||
}
|
||||
if !strings.Contains(body, "<script>") {
|
||||
t.Errorf("body missing escaped script tag: %s", body)
|
||||
}
|
||||
if rec.Code != http.StatusBadRequest {
|
||||
t.Errorf("status = %d, want 400 on error", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenderCallbackPage_SuccessNoErrorPath(t *testing.T) {
|
||||
rec := httptest.NewRecorder()
|
||||
renderCallbackPage(rec, nil)
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want 200", rec.Code)
|
||||
}
|
||||
if !strings.Contains(rec.Body.String(), "veans is authorized") {
|
||||
t.Errorf("missing success message: %s", rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
type fakeError struct{ msg string }
|
||||
|
||||
func (f *fakeError) Error() string { return f.msg }
|
||||
|
|
|
|||
Loading…
Reference in New Issue