diff --git a/pkg/models/webhooks.go b/pkg/models/webhooks.go index 9622c9e7f..98975dd7b 100644 --- a/pkg/models/webhooks.go +++ b/pkg/models/webhooks.go @@ -25,9 +25,7 @@ import ( "encoding/hex" "encoding/json" "io" - "net" "net/http" - "net/url" "sort" "strings" "sync" @@ -37,10 +35,10 @@ import ( "code.vikunja.io/api/pkg/events" "code.vikunja.io/api/pkg/log" "code.vikunja.io/api/pkg/user" + "code.vikunja.io/api/pkg/utils" "code.vikunja.io/api/pkg/version" "code.vikunja.io/api/pkg/web" - "code.dny.dev/ssrf" "xorm.io/xorm" ) @@ -289,40 +287,14 @@ func (w *Webhook) Delete(s *xorm.Session, _ web.Auth) (err error) { } func getWebhookHTTPClient() (client *http.Client) { - if webhookClient != nil { return webhookClient } - client = &http.Client{} + client = utils.NewSSRFSafeHTTPClient() client.Timeout = time.Duration(config.WebhooksTimeoutSeconds.GetInt()) * time.Second - transport := &http.Transport{} - - // SSRF protection: block connections to non-globally-routable IPs unless - // explicitly allowed. Uses daenney/ssrf which validates resolved IPs - // against IANA Special Purpose Registries after DNS resolution, - // preventing DNS rebinding attacks. - if !config.WebhooksAllowNonRoutableIPs.GetBool() { - guardian := ssrf.New(ssrf.WithAnyPort()) - transport.DialContext = (&net.Dialer{ - Control: guardian.Safe, - }).DialContext - } - - if config.WebhooksProxyURL.GetString() != "" && config.WebhooksProxyPassword.GetString() != "" { - proxyURL, _ := url.Parse(config.WebhooksProxyURL.GetString()) - transport.Proxy = http.ProxyURL(proxyURL) - transport.ProxyConnectHeader = http.Header{ - "Proxy-Authorization": []string{"Basic " + base64.StdEncoding.EncodeToString([]byte("vikunja:"+config.WebhooksProxyPassword.GetString()))}, - "User-Agent": []string{"Vikunja/" + version.Version}, - } - } - - client.Transport = transport - webhookClient = client - return } diff --git a/pkg/models/webhooks_ssrf_test.go b/pkg/models/webhooks_ssrf_test.go index fda2dca75..2d6ef0865 100644 --- a/pkg/models/webhooks_ssrf_test.go +++ b/pkg/models/webhooks_ssrf_test.go @@ -35,7 +35,7 @@ func TestWebhookSSRFProtection(t *testing.T) { t.Run("blocks requests to loopback addresses", func(t *testing.T) { resetWebhookClient() - config.WebhooksAllowNonRoutableIPs.Set(false) + config.OutgoingRequestsAllowNonRoutableIPs.Set("false") config.WebhooksProxyURL.Set("") config.WebhooksProxyPassword.Set("") @@ -53,7 +53,7 @@ func TestWebhookSSRFProtection(t *testing.T) { t.Run("allows requests to public addresses", func(t *testing.T) { resetWebhookClient() - config.WebhooksAllowNonRoutableIPs.Set(false) + config.OutgoingRequestsAllowNonRoutableIPs.Set("false") config.WebhooksProxyURL.Set("") config.WebhooksProxyPassword.Set("") @@ -80,7 +80,7 @@ func TestWebhookSSRFProtection(t *testing.T) { t.Run("allows loopback when allownonroutableips is true", func(t *testing.T) { resetWebhookClient() - config.WebhooksAllowNonRoutableIPs.Set(true) + config.OutgoingRequestsAllowNonRoutableIPs.Set("true") config.WebhooksProxyURL.Set("") config.WebhooksProxyPassword.Set("") @@ -102,7 +102,7 @@ func TestWebhookSSRFProtection(t *testing.T) { t.Run("blocks requests to private RFC1918 addresses", func(t *testing.T) { resetWebhookClient() - config.WebhooksAllowNonRoutableIPs.Set(false) + config.OutgoingRequestsAllowNonRoutableIPs.Set("false") config.WebhooksProxyURL.Set("") config.WebhooksProxyPassword.Set("") @@ -128,7 +128,7 @@ func TestWebhookSSRFProtection(t *testing.T) { t.Run("blocks requests to metadata endpoint", func(t *testing.T) { resetWebhookClient() - config.WebhooksAllowNonRoutableIPs.Set(false) + config.OutgoingRequestsAllowNonRoutableIPs.Set("false") config.WebhooksProxyURL.Set("") config.WebhooksProxyPassword.Set("") diff --git a/pkg/utils/httpclient_test.go b/pkg/utils/httpclient_test.go index 82280b676..c1392f0f5 100644 --- a/pkg/utils/httpclient_test.go +++ b/pkg/utils/httpclient_test.go @@ -17,6 +17,7 @@ package utils import ( + "context" "net/http" "net/http/httptest" "testing" @@ -37,13 +38,15 @@ func TestNewSSRFSafeHTTPClient(t *testing.T) { config.OutgoingRequestsAllowNonRoutableIPs.Set("true") defer config.OutgoingRequestsAllowNonRoutableIPs.Set("false") - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })) defer server.Close() client := NewSSRFSafeHTTPClient() - resp, err := client.Get(server.URL) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + require.NoError(t, err) + resp, err := client.Do(req) require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode) @@ -54,7 +57,9 @@ func TestNewSSRFSafeHTTPClient(t *testing.T) { client := NewSSRFSafeHTTPClient() // Attempt to connect to localhost (non-routable) - _, err := client.Get("http://127.0.0.1:1/test") + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://127.0.0.1:1/test", nil) + require.NoError(t, err) + _, err = client.Do(req) //nolint:bodyclose require.Error(t, err) }) @@ -62,13 +67,15 @@ func TestNewSSRFSafeHTTPClient(t *testing.T) { config.OutgoingRequestsAllowNonRoutableIPs.Set("true") defer config.OutgoingRequestsAllowNonRoutableIPs.Set("false") - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })) defer server.Close() client := NewSSRFSafeHTTPClient() - resp, err := client.Get(server.URL) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + require.NoError(t, err) + resp, err := client.Do(req) require.NoError(t, err) defer resp.Body.Close() assert.Equal(t, http.StatusOK, resp.StatusCode)