diff --git a/veans/internal/auth/oauth.go b/veans/internal/auth/oauth.go index 101b3f9e3..1c3e87a66 100644 --- a/veans/internal/auth/oauth.go +++ b/veans/internal/auth/oauth.go @@ -24,20 +24,24 @@ import ( "errors" "fmt" "io" + "net" + "net/http" "net/url" "strings" + "time" "code.vikunja.io/veans/internal/client" "code.vikunja.io/veans/internal/output" ) -// OAuth client identity. Vikunja's authorization server requires no -// pre-registration — these values just need to be consistent between the -// browser-side authorize step and the CLI-side token exchange. -const ( - oauthClientID = "veans-cli" - oauthRedirectURI = "vikunja-veans-cli://callback" -) +// oauthClientID is what veans presents to Vikunja's authorization server. +// Vikunja's OAuth provider doesn't require client registration — the value +// just needs to be consistent across the authorize and token-exchange steps. +const oauthClientID = "veans-cli" + +// loopbackTimeout caps how long we wait for the user to complete the +// browser-side handshake before giving up. +const loopbackTimeout = 5 * time.Minute // PKCEPair holds the challenge sent to /oauth/authorize and the verifier // kept locally until token exchange. @@ -61,8 +65,7 @@ func generatePKCE() (PKCEPair, error) { return PKCEPair{Verifier: verifier, Challenge: challenge}, nil } -// generateState returns a random opaque string for CSRF protection on the -// authorize redirect. We verify it matches when the user pastes back. +// generateState returns a random opaque string for CSRF protection. func generateState() (string, error) { buf := make([]byte, 24) if _, err := rand.Read(buf); err != nil { @@ -71,68 +74,35 @@ func generateState() (string, error) { return base64.RawURLEncoding.EncodeToString(buf), nil } -// buildAuthorizeURL renders the browser-side redirect target. The user -// follows it, authenticates if necessary, and is redirected to the custom -// scheme with `?code=...&state=...`. -func buildAuthorizeURL(server string, pkce PKCEPair, state string) string { +// buildAuthorizeURL renders the browser-side redirect target. +func buildAuthorizeURL(server, redirectURI string, pkce PKCEPair, state string) string { q := url.Values{} q.Set("response_type", "code") q.Set("client_id", oauthClientID) - q.Set("redirect_uri", oauthRedirectURI) + q.Set("redirect_uri", redirectURI) q.Set("code_challenge", pkce.Challenge) q.Set("code_challenge_method", "S256") q.Set("state", state) return strings.TrimRight(server, "/") + "/oauth/authorize?" + q.Encode() } -// extractCodeAndState pulls the OAuth callback parameters out of whatever -// the user pasted. We accept three shapes: -// - the full custom-scheme URL: `vikunja-veans-cli://callback?code=...&state=...` -// - just the query: `code=ABC&state=XYZ` -// - just the code (state verification then skipped, with a warning) -func extractCodeAndState(raw string) (code, state string, err error) { - raw = strings.TrimSpace(raw) - if raw == "" { - return "", "", errors.New("empty callback paste") - } - - // Full URL form? - if strings.Contains(raw, "://") || strings.HasPrefix(raw, "vikunja-") { - u, perr := url.Parse(raw) - if perr != nil { - return "", "", fmt.Errorf("parse callback URL: %w", perr) - } - v := u.Query() - // Some browsers strip the query and put it in Fragment when they - // can't open the scheme — handle both. - if v.Get("code") == "" && u.RawQuery == "" && u.Fragment != "" { - v, _ = url.ParseQuery(u.Fragment) - } - return v.Get("code"), v.Get("state"), nil - } - - // Query-string form? - if strings.Contains(raw, "code=") { - v, perr := url.ParseQuery(raw) - if perr != nil { - return "", "", fmt.Errorf("parse callback query: %w", perr) - } - return v.Get("code"), v.Get("state"), nil - } - - // Bare code. - return raw, "", nil +// callbackResult carries the parsed query parameters from the loopback +// callback request, or any error that prevented a clean handshake. +type callbackResult struct { + code string + state string + err error } -// runOAuthFlow drives the manual paste-back OAuth Authorization Code + -// PKCE handshake against Vikunja's server. +// runOAuthFlow drives an OAuth Authorization Code + PKCE handshake against +// Vikunja's server using a localhost loopback listener (RFC 8252): +// bind 127.0.0.1:0, open the authorize URL in the browser, capture the +// callback, exchange the code for a token. // -// The user-facing UX: print the authorize URL, ask the user to open it in -// their browser, sign in there, and paste the resulting (failed-to-open) -// `vikunja-veans-cli://callback?code=...` URL back into the CLI. The -// browser will show a "can't open this scheme" error, but the URL bar -// contains the code we need. -func runOAuthFlow(ctx context.Context, c *client.Client, p Prompter, w io.Writer) (string, error) { +// The prompter is retained on the signature for symmetry with the +// password flow but isn't called — the loopback handshake completes +// without further user input beyond the in-browser sign-in. +func runOAuthFlow(ctx context.Context, c *client.Client, _ Prompter, w io.Writer) (string, error) { pkce, err := generatePKCE() if err != nil { return "", output.Wrap(output.CodeUnknown, err, "generate PKCE: %v", err) @@ -142,40 +112,40 @@ func runOAuthFlow(ctx context.Context, c *client.Client, p Prompter, w io.Writer return "", output.Wrap(output.CodeUnknown, err, "generate state: %v", err) } - authURL := buildAuthorizeURL(c.BaseURL, pkce, state) - if w != nil { - fmt.Fprintln(w, "") - fmt.Fprintln(w, "Open the following URL in your browser:") - fmt.Fprintln(w, "") - fmt.Fprintln(w, " "+authURL) - fmt.Fprintln(w, "") - fmt.Fprintln(w, "After signing in, your browser will try to open") - fmt.Fprintln(w, " "+oauthRedirectURI+"?code=...&state=...") - fmt.Fprintln(w, "and show a 'can't open this URL' error. That's expected.") - fmt.Fprintln(w, "Copy the URL from the address bar and paste it here.") - fmt.Fprintln(w, "") - } - - pasted, err := p.ReadLine("Paste callback URL (or just the code): ") + listener, redirectURI, err := bindLoopbackListener(ctx) if err != nil { return "", err } - code, returnedState, err := extractCodeAndState(pasted) + + server, resultCh := newCallbackServer(listener) + go func() { _ = server.Serve(listener) }() + // Shutdown uses a detached context derived from ctx so cancellation + // at the outer scope still allows the graceful-stop to drain. + shutdownParent := context.WithoutCancel(ctx) + defer func() { + shutdownCtx, cancel := context.WithTimeout(shutdownParent, 2*time.Second) + defer cancel() + _ = server.Shutdown(shutdownCtx) + }() + + authURL := buildAuthorizeURL(c.BaseURL, redirectURI, pkce, state) + announceBrowserStep(w, authURL) + openBrowser(ctx, authURL) + + result, err := waitForCallback(ctx, resultCh) if err != nil { - return "", output.Wrap(output.CodeAuth, err, "%v", err) + return "", err } - if code == "" { - return "", output.New(output.CodeAuth, "no `code` found in pasted callback") - } - if returnedState != "" && returnedState != state { - return "", output.New(output.CodeAuth, "state mismatch on OAuth callback (possible CSRF)") + if result.state != state { + return "", output.New(output.CodeAuth, + "state mismatch on OAuth callback (possible CSRF)") } resp, err := c.ExchangeOAuthCode(ctx, &client.OAuthTokenRequest{ GrantType: "authorization_code", - Code: code, + Code: result.code, ClientID: oauthClientID, - RedirectURI: oauthRedirectURI, + RedirectURI: redirectURI, CodeVerifier: pkce.Verifier, }) if err != nil { @@ -186,3 +156,114 @@ func runOAuthFlow(ctx context.Context, c *client.Client, p Prompter, w io.Writer } return resp.AccessToken, nil } + +// bindLoopbackListener picks a free port on 127.0.0.1 and returns a +// listener + the corresponding `http://127.0.0.1:NNN/callback` URI for +// the OAuth `redirect_uri` parameter. +func bindLoopbackListener(ctx context.Context) (net.Listener, string, error) { + var lc net.ListenConfig + listener, err := lc.Listen(ctx, "tcp", "127.0.0.1:0") + if err != nil { + return nil, "", output.Wrap(output.CodeUnknown, err, + "bind loopback port for OAuth callback: %v", err) + } + port := listener.Addr().(*net.TCPAddr).Port + return listener, fmt.Sprintf("http://127.0.0.1:%d/callback", port), nil +} + +// newCallbackServer returns an http.Server bound to `listener` whose +// /callback handler parses the authorization-server redirect query and +// pushes the result onto the returned channel. +func newCallbackServer(listener net.Listener) (*http.Server, <-chan callbackResult) { + resultCh := make(chan callbackResult, 1) + server := &http.Server{ + Addr: listener.Addr().String(), + ReadHeaderTimeout: 5 * time.Second, + Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/callback" { + http.NotFound(rw, r) + return + } + q := r.URL.Query() + res := callbackResult{code: q.Get("code"), state: q.Get("state")} + if errCode := q.Get("error"); errCode != "" { + desc := q.Get("error_description") + if desc == "" { + desc = errCode + } + res.err = fmt.Errorf("authorization failed: %s", desc) + } + renderCallbackPage(rw, res.err) + select { + case resultCh <- res: + default: + } + }), + } + return server, resultCh +} + +// waitForCallback blocks until the loopback handler fires, ctx cancels, +// or loopbackTimeout elapses. +func waitForCallback(ctx context.Context, resultCh <-chan callbackResult) (callbackResult, error) { + timer := time.NewTimer(loopbackTimeout) + defer timer.Stop() + select { + case result := <-resultCh: + if result.err != nil { + return result, output.Wrap(output.CodeAuth, result.err, "%v", result.err) + } + if result.code == "" { + return result, output.New(output.CodeAuth, "no `code` returned from OAuth callback") + } + return result, nil + case <-timer.C: + return callbackResult{}, output.New(output.CodeAuth, + "OAuth flow timed out after %s — re-run init with --use-password or --token", loopbackTimeout) + case <-ctx.Done(): + return callbackResult{}, ctx.Err() + } +} + +func announceBrowserStep(w io.Writer, authURL string) { + if w == nil { + return + } + fmt.Fprintln(w) + fmt.Fprintln(w, "Opening your browser to authorize veans:") + fmt.Fprintln(w, " "+authURL) + fmt.Fprintln(w) + fmt.Fprintln(w, "If the browser doesn't open, paste the URL above manually.") + fmt.Fprintln(w) +} + +// renderCallbackPage writes a minimal HTML response to the user's browser +// after the loopback callback fires. We don't ship any framework — a few +// lines of inlined HTML are enough to confirm completion. +func renderCallbackPage(w http.ResponseWriter, err error) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.Header().Set("Cache-Control", "no-store") + if err != nil { + w.WriteHeader(http.StatusBadRequest) + _, _ = fmt.Fprintf(w, ` +

veans: authorization failed

+

%s

+

You can close this tab and re-run veans init.

+`, err.Error()) + return + } + _, _ = w.Write([]byte(` +

veans is authorized

+

You can close this tab and return to the terminal.

+`)) +} + +// openBrowser tries to launch the user's default browser at `url`. Failure +// is ignored — the calling flow already prints the URL to stderr so the +// user can open it themselves. +func openBrowser(ctx context.Context, url string) { + _ = osOpen(ctx, url) +} + +// silence the unused-import linter when errors isn't referenced elsewhere. +var _ = errors.New diff --git a/veans/internal/auth/oauth_test.go b/veans/internal/auth/oauth_test.go index 4496bc8cc..906addba1 100644 --- a/veans/internal/auth/oauth_test.go +++ b/veans/internal/auth/oauth_test.go @@ -61,44 +61,13 @@ func TestGeneratePKCE_Unique(t *testing.T) { } } -func TestExtractCodeAndState_FullURL(t *testing.T) { - code, state, err := extractCodeAndState("vikunja-veans-cli://callback?code=ABC123&state=XYZ") - if err != nil { - t.Fatal(err) - } - if code != "ABC123" || state != "XYZ" { - t.Fatalf("got code=%q state=%q", code, state) - } -} - -func TestExtractCodeAndState_QueryOnly(t *testing.T) { - code, state, err := extractCodeAndState("code=ABC&state=XYZ") - if err != nil { - t.Fatal(err) - } - if code != "ABC" || state != "XYZ" { - t.Fatalf("got code=%q state=%q", code, state) - } -} - -func TestExtractCodeAndState_BareCode(t *testing.T) { - code, state, err := extractCodeAndState("plain-code-value") - if err != nil { - t.Fatal(err) - } - if code != "plain-code-value" || state != "" { - t.Fatalf("got code=%q state=%q", code, state) - } -} - -func TestExtractCodeAndState_EmptyError(t *testing.T) { - if _, _, err := extractCodeAndState(" "); err == nil { - t.Fatal("expected error on empty paste") - } -} - func TestBuildAuthorizeURL(t *testing.T) { - u := buildAuthorizeURL("https://vikunja.example.com", PKCEPair{Challenge: "CHL"}, "S") + u := buildAuthorizeURL( + "https://vikunja.example.com", + "http://127.0.0.1:54321/callback", + PKCEPair{Challenge: "CHL"}, + "S", + ) if !strings.HasPrefix(u, "https://vikunja.example.com/oauth/authorize?") { t.Fatalf("unexpected prefix: %s", u) } @@ -108,6 +77,8 @@ func TestBuildAuthorizeURL(t *testing.T) { "code_challenge=CHL", "code_challenge_method=S256", "state=S", + // redirect_uri carried through (URL-encoded) + "redirect_uri=http%3A%2F%2F127.0.0.1%3A54321%2Fcallback", } { if !strings.Contains(u, want) { t.Errorf("authorize URL missing %q: %s", want, u) @@ -115,7 +86,7 @@ func TestBuildAuthorizeURL(t *testing.T) { } // Server URL with trailing slash should still produce a single slash // before the path. - u2 := buildAuthorizeURL("https://vikunja.example.com/", PKCEPair{}, "") + u2 := buildAuthorizeURL("https://vikunja.example.com/", "", PKCEPair{}, "") if strings.Contains(u2, "//oauth") { t.Errorf("trailing slash leaked into URL: %s", u2) }