refactor: use shared SSRF-safe HTTP client in webhook code

This commit is contained in:
kolaente 2026-03-23 16:17:02 +01:00 committed by kolaente
parent a94109e1be
commit e5a1c05771
3 changed files with 19 additions and 40 deletions

View File

@ -25,9 +25,7 @@ import (
"encoding/hex"
"encoding/json"
"io"
"net"
"net/http"
"net/url"
"sort"
"strings"
"sync"
@ -37,10 +35,10 @@ import (
"code.vikunja.io/api/pkg/events"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/api/pkg/utils"
"code.vikunja.io/api/pkg/version"
"code.vikunja.io/api/pkg/web"
"code.dny.dev/ssrf"
"xorm.io/xorm"
)
@ -289,40 +287,14 @@ func (w *Webhook) Delete(s *xorm.Session, _ web.Auth) (err error) {
}
func getWebhookHTTPClient() (client *http.Client) {
if webhookClient != nil {
return webhookClient
}
client = &http.Client{}
client = utils.NewSSRFSafeHTTPClient()
client.Timeout = time.Duration(config.WebhooksTimeoutSeconds.GetInt()) * time.Second
transport := &http.Transport{}
// SSRF protection: block connections to non-globally-routable IPs unless
// explicitly allowed. Uses daenney/ssrf which validates resolved IPs
// against IANA Special Purpose Registries after DNS resolution,
// preventing DNS rebinding attacks.
if !config.WebhooksAllowNonRoutableIPs.GetBool() {
guardian := ssrf.New(ssrf.WithAnyPort())
transport.DialContext = (&net.Dialer{
Control: guardian.Safe,
}).DialContext
}
if config.WebhooksProxyURL.GetString() != "" && config.WebhooksProxyPassword.GetString() != "" {
proxyURL, _ := url.Parse(config.WebhooksProxyURL.GetString())
transport.Proxy = http.ProxyURL(proxyURL)
transport.ProxyConnectHeader = http.Header{
"Proxy-Authorization": []string{"Basic " + base64.StdEncoding.EncodeToString([]byte("vikunja:"+config.WebhooksProxyPassword.GetString()))},
"User-Agent": []string{"Vikunja/" + version.Version},
}
}
client.Transport = transport
webhookClient = client
return
}

View File

@ -35,7 +35,7 @@ func TestWebhookSSRFProtection(t *testing.T) {
t.Run("blocks requests to loopback addresses", func(t *testing.T) {
resetWebhookClient()
config.WebhooksAllowNonRoutableIPs.Set(false)
config.OutgoingRequestsAllowNonRoutableIPs.Set("false")
config.WebhooksProxyURL.Set("")
config.WebhooksProxyPassword.Set("")
@ -53,7 +53,7 @@ func TestWebhookSSRFProtection(t *testing.T) {
t.Run("allows requests to public addresses", func(t *testing.T) {
resetWebhookClient()
config.WebhooksAllowNonRoutableIPs.Set(false)
config.OutgoingRequestsAllowNonRoutableIPs.Set("false")
config.WebhooksProxyURL.Set("")
config.WebhooksProxyPassword.Set("")
@ -80,7 +80,7 @@ func TestWebhookSSRFProtection(t *testing.T) {
t.Run("allows loopback when allownonroutableips is true", func(t *testing.T) {
resetWebhookClient()
config.WebhooksAllowNonRoutableIPs.Set(true)
config.OutgoingRequestsAllowNonRoutableIPs.Set("true")
config.WebhooksProxyURL.Set("")
config.WebhooksProxyPassword.Set("")
@ -102,7 +102,7 @@ func TestWebhookSSRFProtection(t *testing.T) {
t.Run("blocks requests to private RFC1918 addresses", func(t *testing.T) {
resetWebhookClient()
config.WebhooksAllowNonRoutableIPs.Set(false)
config.OutgoingRequestsAllowNonRoutableIPs.Set("false")
config.WebhooksProxyURL.Set("")
config.WebhooksProxyPassword.Set("")
@ -128,7 +128,7 @@ func TestWebhookSSRFProtection(t *testing.T) {
t.Run("blocks requests to metadata endpoint", func(t *testing.T) {
resetWebhookClient()
config.WebhooksAllowNonRoutableIPs.Set(false)
config.OutgoingRequestsAllowNonRoutableIPs.Set("false")
config.WebhooksProxyURL.Set("")
config.WebhooksProxyPassword.Set("")

View File

@ -17,6 +17,7 @@
package utils
import (
"context"
"net/http"
"net/http/httptest"
"testing"
@ -37,13 +38,15 @@ func TestNewSSRFSafeHTTPClient(t *testing.T) {
config.OutgoingRequestsAllowNonRoutableIPs.Set("true")
defer config.OutgoingRequestsAllowNonRoutableIPs.Set("false")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewSSRFSafeHTTPClient()
resp, err := client.Get(server.URL)
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)
@ -54,7 +57,9 @@ func TestNewSSRFSafeHTTPClient(t *testing.T) {
client := NewSSRFSafeHTTPClient()
// Attempt to connect to localhost (non-routable)
_, err := client.Get("http://127.0.0.1:1/test")
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://127.0.0.1:1/test", nil)
require.NoError(t, err)
_, err = client.Do(req) //nolint:bodyclose
require.Error(t, err)
})
@ -62,13 +67,15 @@ func TestNewSSRFSafeHTTPClient(t *testing.T) {
config.OutgoingRequestsAllowNonRoutableIPs.Set("true")
defer config.OutgoingRequestsAllowNonRoutableIPs.Set("false")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewSSRFSafeHTTPClient()
resp, err := client.Get(server.URL)
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
require.NoError(t, err)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
assert.Equal(t, http.StatusOK, resp.StatusCode)