feat(veans): switch OAuth handshake to a loopback HTTP server
This commit is contained in:
parent
b18762171d
commit
c1d5272afe
|
|
@ -24,20 +24,24 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"code.vikunja.io/veans/internal/client"
|
||||
"code.vikunja.io/veans/internal/output"
|
||||
)
|
||||
|
||||
// OAuth client identity. Vikunja's authorization server requires no
|
||||
// pre-registration — these values just need to be consistent between the
|
||||
// browser-side authorize step and the CLI-side token exchange.
|
||||
const (
|
||||
oauthClientID = "veans-cli"
|
||||
oauthRedirectURI = "vikunja-veans-cli://callback"
|
||||
)
|
||||
// oauthClientID is what veans presents to Vikunja's authorization server.
|
||||
// Vikunja's OAuth provider doesn't require client registration — the value
|
||||
// just needs to be consistent across the authorize and token-exchange steps.
|
||||
const oauthClientID = "veans-cli"
|
||||
|
||||
// loopbackTimeout caps how long we wait for the user to complete the
|
||||
// browser-side handshake before giving up.
|
||||
const loopbackTimeout = 5 * time.Minute
|
||||
|
||||
// PKCEPair holds the challenge sent to /oauth/authorize and the verifier
|
||||
// kept locally until token exchange.
|
||||
|
|
@ -61,8 +65,7 @@ func generatePKCE() (PKCEPair, error) {
|
|||
return PKCEPair{Verifier: verifier, Challenge: challenge}, nil
|
||||
}
|
||||
|
||||
// generateState returns a random opaque string for CSRF protection on the
|
||||
// authorize redirect. We verify it matches when the user pastes back.
|
||||
// generateState returns a random opaque string for CSRF protection.
|
||||
func generateState() (string, error) {
|
||||
buf := make([]byte, 24)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
|
|
@ -71,68 +74,35 @@ func generateState() (string, error) {
|
|||
return base64.RawURLEncoding.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
// buildAuthorizeURL renders the browser-side redirect target. The user
|
||||
// follows it, authenticates if necessary, and is redirected to the custom
|
||||
// scheme with `?code=...&state=...`.
|
||||
func buildAuthorizeURL(server string, pkce PKCEPair, state string) string {
|
||||
// buildAuthorizeURL renders the browser-side redirect target.
|
||||
func buildAuthorizeURL(server, redirectURI string, pkce PKCEPair, state string) string {
|
||||
q := url.Values{}
|
||||
q.Set("response_type", "code")
|
||||
q.Set("client_id", oauthClientID)
|
||||
q.Set("redirect_uri", oauthRedirectURI)
|
||||
q.Set("redirect_uri", redirectURI)
|
||||
q.Set("code_challenge", pkce.Challenge)
|
||||
q.Set("code_challenge_method", "S256")
|
||||
q.Set("state", state)
|
||||
return strings.TrimRight(server, "/") + "/oauth/authorize?" + q.Encode()
|
||||
}
|
||||
|
||||
// extractCodeAndState pulls the OAuth callback parameters out of whatever
|
||||
// the user pasted. We accept three shapes:
|
||||
// - the full custom-scheme URL: `vikunja-veans-cli://callback?code=...&state=...`
|
||||
// - just the query: `code=ABC&state=XYZ`
|
||||
// - just the code (state verification then skipped, with a warning)
|
||||
func extractCodeAndState(raw string) (code, state string, err error) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return "", "", errors.New("empty callback paste")
|
||||
// callbackResult carries the parsed query parameters from the loopback
|
||||
// callback request, or any error that prevented a clean handshake.
|
||||
type callbackResult struct {
|
||||
code string
|
||||
state string
|
||||
err error
|
||||
}
|
||||
|
||||
// Full URL form?
|
||||
if strings.Contains(raw, "://") || strings.HasPrefix(raw, "vikunja-") {
|
||||
u, perr := url.Parse(raw)
|
||||
if perr != nil {
|
||||
return "", "", fmt.Errorf("parse callback URL: %w", perr)
|
||||
}
|
||||
v := u.Query()
|
||||
// Some browsers strip the query and put it in Fragment when they
|
||||
// can't open the scheme — handle both.
|
||||
if v.Get("code") == "" && u.RawQuery == "" && u.Fragment != "" {
|
||||
v, _ = url.ParseQuery(u.Fragment)
|
||||
}
|
||||
return v.Get("code"), v.Get("state"), nil
|
||||
}
|
||||
|
||||
// Query-string form?
|
||||
if strings.Contains(raw, "code=") {
|
||||
v, perr := url.ParseQuery(raw)
|
||||
if perr != nil {
|
||||
return "", "", fmt.Errorf("parse callback query: %w", perr)
|
||||
}
|
||||
return v.Get("code"), v.Get("state"), nil
|
||||
}
|
||||
|
||||
// Bare code.
|
||||
return raw, "", nil
|
||||
}
|
||||
|
||||
// runOAuthFlow drives the manual paste-back OAuth Authorization Code +
|
||||
// PKCE handshake against Vikunja's server.
|
||||
// runOAuthFlow drives an OAuth Authorization Code + PKCE handshake against
|
||||
// Vikunja's server using a localhost loopback listener (RFC 8252):
|
||||
// bind 127.0.0.1:0, open the authorize URL in the browser, capture the
|
||||
// callback, exchange the code for a token.
|
||||
//
|
||||
// The user-facing UX: print the authorize URL, ask the user to open it in
|
||||
// their browser, sign in there, and paste the resulting (failed-to-open)
|
||||
// `vikunja-veans-cli://callback?code=...` URL back into the CLI. The
|
||||
// browser will show a "can't open this scheme" error, but the URL bar
|
||||
// contains the code we need.
|
||||
func runOAuthFlow(ctx context.Context, c *client.Client, p Prompter, w io.Writer) (string, error) {
|
||||
// The prompter is retained on the signature for symmetry with the
|
||||
// password flow but isn't called — the loopback handshake completes
|
||||
// without further user input beyond the in-browser sign-in.
|
||||
func runOAuthFlow(ctx context.Context, c *client.Client, _ Prompter, w io.Writer) (string, error) {
|
||||
pkce, err := generatePKCE()
|
||||
if err != nil {
|
||||
return "", output.Wrap(output.CodeUnknown, err, "generate PKCE: %v", err)
|
||||
|
|
@ -142,40 +112,40 @@ func runOAuthFlow(ctx context.Context, c *client.Client, p Prompter, w io.Writer
|
|||
return "", output.Wrap(output.CodeUnknown, err, "generate state: %v", err)
|
||||
}
|
||||
|
||||
authURL := buildAuthorizeURL(c.BaseURL, pkce, state)
|
||||
if w != nil {
|
||||
fmt.Fprintln(w, "")
|
||||
fmt.Fprintln(w, "Open the following URL in your browser:")
|
||||
fmt.Fprintln(w, "")
|
||||
fmt.Fprintln(w, " "+authURL)
|
||||
fmt.Fprintln(w, "")
|
||||
fmt.Fprintln(w, "After signing in, your browser will try to open")
|
||||
fmt.Fprintln(w, " "+oauthRedirectURI+"?code=...&state=...")
|
||||
fmt.Fprintln(w, "and show a 'can't open this URL' error. That's expected.")
|
||||
fmt.Fprintln(w, "Copy the URL from the address bar and paste it here.")
|
||||
fmt.Fprintln(w, "")
|
||||
}
|
||||
|
||||
pasted, err := p.ReadLine("Paste callback URL (or just the code): ")
|
||||
listener, redirectURI, err := bindLoopbackListener(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
code, returnedState, err := extractCodeAndState(pasted)
|
||||
|
||||
server, resultCh := newCallbackServer(listener)
|
||||
go func() { _ = server.Serve(listener) }()
|
||||
// Shutdown uses a detached context derived from ctx so cancellation
|
||||
// at the outer scope still allows the graceful-stop to drain.
|
||||
shutdownParent := context.WithoutCancel(ctx)
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(shutdownParent, 2*time.Second)
|
||||
defer cancel()
|
||||
_ = server.Shutdown(shutdownCtx)
|
||||
}()
|
||||
|
||||
authURL := buildAuthorizeURL(c.BaseURL, redirectURI, pkce, state)
|
||||
announceBrowserStep(w, authURL)
|
||||
openBrowser(ctx, authURL)
|
||||
|
||||
result, err := waitForCallback(ctx, resultCh)
|
||||
if err != nil {
|
||||
return "", output.Wrap(output.CodeAuth, err, "%v", err)
|
||||
return "", err
|
||||
}
|
||||
if code == "" {
|
||||
return "", output.New(output.CodeAuth, "no `code` found in pasted callback")
|
||||
}
|
||||
if returnedState != "" && returnedState != state {
|
||||
return "", output.New(output.CodeAuth, "state mismatch on OAuth callback (possible CSRF)")
|
||||
if result.state != state {
|
||||
return "", output.New(output.CodeAuth,
|
||||
"state mismatch on OAuth callback (possible CSRF)")
|
||||
}
|
||||
|
||||
resp, err := c.ExchangeOAuthCode(ctx, &client.OAuthTokenRequest{
|
||||
GrantType: "authorization_code",
|
||||
Code: code,
|
||||
Code: result.code,
|
||||
ClientID: oauthClientID,
|
||||
RedirectURI: oauthRedirectURI,
|
||||
RedirectURI: redirectURI,
|
||||
CodeVerifier: pkce.Verifier,
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -186,3 +156,114 @@ func runOAuthFlow(ctx context.Context, c *client.Client, p Prompter, w io.Writer
|
|||
}
|
||||
return resp.AccessToken, nil
|
||||
}
|
||||
|
||||
// bindLoopbackListener picks a free port on 127.0.0.1 and returns a
|
||||
// listener + the corresponding `http://127.0.0.1:NNN/callback` URI for
|
||||
// the OAuth `redirect_uri` parameter.
|
||||
func bindLoopbackListener(ctx context.Context) (net.Listener, string, error) {
|
||||
var lc net.ListenConfig
|
||||
listener, err := lc.Listen(ctx, "tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, "", output.Wrap(output.CodeUnknown, err,
|
||||
"bind loopback port for OAuth callback: %v", err)
|
||||
}
|
||||
port := listener.Addr().(*net.TCPAddr).Port
|
||||
return listener, fmt.Sprintf("http://127.0.0.1:%d/callback", port), nil
|
||||
}
|
||||
|
||||
// newCallbackServer returns an http.Server bound to `listener` whose
|
||||
// /callback handler parses the authorization-server redirect query and
|
||||
// pushes the result onto the returned channel.
|
||||
func newCallbackServer(listener net.Listener) (*http.Server, <-chan callbackResult) {
|
||||
resultCh := make(chan callbackResult, 1)
|
||||
server := &http.Server{
|
||||
Addr: listener.Addr().String(),
|
||||
ReadHeaderTimeout: 5 * time.Second,
|
||||
Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/callback" {
|
||||
http.NotFound(rw, r)
|
||||
return
|
||||
}
|
||||
q := r.URL.Query()
|
||||
res := callbackResult{code: q.Get("code"), state: q.Get("state")}
|
||||
if errCode := q.Get("error"); errCode != "" {
|
||||
desc := q.Get("error_description")
|
||||
if desc == "" {
|
||||
desc = errCode
|
||||
}
|
||||
res.err = fmt.Errorf("authorization failed: %s", desc)
|
||||
}
|
||||
renderCallbackPage(rw, res.err)
|
||||
select {
|
||||
case resultCh <- res:
|
||||
default:
|
||||
}
|
||||
}),
|
||||
}
|
||||
return server, resultCh
|
||||
}
|
||||
|
||||
// waitForCallback blocks until the loopback handler fires, ctx cancels,
|
||||
// or loopbackTimeout elapses.
|
||||
func waitForCallback(ctx context.Context, resultCh <-chan callbackResult) (callbackResult, error) {
|
||||
timer := time.NewTimer(loopbackTimeout)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
if result.err != nil {
|
||||
return result, output.Wrap(output.CodeAuth, result.err, "%v", result.err)
|
||||
}
|
||||
if result.code == "" {
|
||||
return result, output.New(output.CodeAuth, "no `code` returned from OAuth callback")
|
||||
}
|
||||
return result, nil
|
||||
case <-timer.C:
|
||||
return callbackResult{}, output.New(output.CodeAuth,
|
||||
"OAuth flow timed out after %s — re-run init with --use-password or --token", loopbackTimeout)
|
||||
case <-ctx.Done():
|
||||
return callbackResult{}, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func announceBrowserStep(w io.Writer, authURL string) {
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
fmt.Fprintln(w)
|
||||
fmt.Fprintln(w, "Opening your browser to authorize veans:")
|
||||
fmt.Fprintln(w, " "+authURL)
|
||||
fmt.Fprintln(w)
|
||||
fmt.Fprintln(w, "If the browser doesn't open, paste the URL above manually.")
|
||||
fmt.Fprintln(w)
|
||||
}
|
||||
|
||||
// renderCallbackPage writes a minimal HTML response to the user's browser
|
||||
// after the loopback callback fires. We don't ship any framework — a few
|
||||
// lines of inlined HTML are enough to confirm completion.
|
||||
func renderCallbackPage(w http.ResponseWriter, err error) {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
if err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = fmt.Fprintf(w, `<!doctype html><html><body style="font-family:system-ui,sans-serif;max-width:32rem;margin:4rem auto;padding:0 1rem">
|
||||
<h1>veans: authorization failed</h1>
|
||||
<p>%s</p>
|
||||
<p>You can close this tab and re-run <code>veans init</code>.</p>
|
||||
</body></html>`, err.Error())
|
||||
return
|
||||
}
|
||||
_, _ = w.Write([]byte(`<!doctype html><html><body style="font-family:system-ui,sans-serif;max-width:32rem;margin:4rem auto;padding:0 1rem">
|
||||
<h1>veans is authorized</h1>
|
||||
<p>You can close this tab and return to the terminal.</p>
|
||||
</body></html>`))
|
||||
}
|
||||
|
||||
// openBrowser tries to launch the user's default browser at `url`. Failure
|
||||
// is ignored — the calling flow already prints the URL to stderr so the
|
||||
// user can open it themselves.
|
||||
func openBrowser(ctx context.Context, url string) {
|
||||
_ = osOpen(ctx, url)
|
||||
}
|
||||
|
||||
// silence the unused-import linter when errors isn't referenced elsewhere.
|
||||
var _ = errors.New
|
||||
|
|
|
|||
|
|
@ -61,44 +61,13 @@ func TestGeneratePKCE_Unique(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestExtractCodeAndState_FullURL(t *testing.T) {
|
||||
code, state, err := extractCodeAndState("vikunja-veans-cli://callback?code=ABC123&state=XYZ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if code != "ABC123" || state != "XYZ" {
|
||||
t.Fatalf("got code=%q state=%q", code, state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCodeAndState_QueryOnly(t *testing.T) {
|
||||
code, state, err := extractCodeAndState("code=ABC&state=XYZ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if code != "ABC" || state != "XYZ" {
|
||||
t.Fatalf("got code=%q state=%q", code, state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCodeAndState_BareCode(t *testing.T) {
|
||||
code, state, err := extractCodeAndState("plain-code-value")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if code != "plain-code-value" || state != "" {
|
||||
t.Fatalf("got code=%q state=%q", code, state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractCodeAndState_EmptyError(t *testing.T) {
|
||||
if _, _, err := extractCodeAndState(" "); err == nil {
|
||||
t.Fatal("expected error on empty paste")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizeURL(t *testing.T) {
|
||||
u := buildAuthorizeURL("https://vikunja.example.com", PKCEPair{Challenge: "CHL"}, "S")
|
||||
u := buildAuthorizeURL(
|
||||
"https://vikunja.example.com",
|
||||
"http://127.0.0.1:54321/callback",
|
||||
PKCEPair{Challenge: "CHL"},
|
||||
"S",
|
||||
)
|
||||
if !strings.HasPrefix(u, "https://vikunja.example.com/oauth/authorize?") {
|
||||
t.Fatalf("unexpected prefix: %s", u)
|
||||
}
|
||||
|
|
@ -108,6 +77,8 @@ func TestBuildAuthorizeURL(t *testing.T) {
|
|||
"code_challenge=CHL",
|
||||
"code_challenge_method=S256",
|
||||
"state=S",
|
||||
// redirect_uri carried through (URL-encoded)
|
||||
"redirect_uri=http%3A%2F%2F127.0.0.1%3A54321%2Fcallback",
|
||||
} {
|
||||
if !strings.Contains(u, want) {
|
||||
t.Errorf("authorize URL missing %q: %s", want, u)
|
||||
|
|
@ -115,7 +86,7 @@ func TestBuildAuthorizeURL(t *testing.T) {
|
|||
}
|
||||
// Server URL with trailing slash should still produce a single slash
|
||||
// before the path.
|
||||
u2 := buildAuthorizeURL("https://vikunja.example.com/", PKCEPair{}, "")
|
||||
u2 := buildAuthorizeURL("https://vikunja.example.com/", "", PKCEPair{}, "")
|
||||
if strings.Contains(u2, "//oauth") {
|
||||
t.Errorf("trailing slash leaked into URL: %s", u2)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue