refactor: use shared SSRF-safe HTTP client in webhook code
This commit is contained in:
parent
a94109e1be
commit
e5a1c05771
|
|
@ -25,9 +25,7 @@ import (
|
|||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
|
@ -37,10 +35,10 @@ import (
|
|||
"code.vikunja.io/api/pkg/events"
|
||||
"code.vikunja.io/api/pkg/log"
|
||||
"code.vikunja.io/api/pkg/user"
|
||||
"code.vikunja.io/api/pkg/utils"
|
||||
"code.vikunja.io/api/pkg/version"
|
||||
"code.vikunja.io/api/pkg/web"
|
||||
|
||||
"code.dny.dev/ssrf"
|
||||
"xorm.io/xorm"
|
||||
)
|
||||
|
||||
|
|
@ -289,40 +287,14 @@ func (w *Webhook) Delete(s *xorm.Session, _ web.Auth) (err error) {
|
|||
}
|
||||
|
||||
func getWebhookHTTPClient() (client *http.Client) {
|
||||
|
||||
if webhookClient != nil {
|
||||
return webhookClient
|
||||
}
|
||||
|
||||
client = &http.Client{}
|
||||
client = utils.NewSSRFSafeHTTPClient()
|
||||
client.Timeout = time.Duration(config.WebhooksTimeoutSeconds.GetInt()) * time.Second
|
||||
|
||||
transport := &http.Transport{}
|
||||
|
||||
// SSRF protection: block connections to non-globally-routable IPs unless
|
||||
// explicitly allowed. Uses daenney/ssrf which validates resolved IPs
|
||||
// against IANA Special Purpose Registries after DNS resolution,
|
||||
// preventing DNS rebinding attacks.
|
||||
if !config.WebhooksAllowNonRoutableIPs.GetBool() {
|
||||
guardian := ssrf.New(ssrf.WithAnyPort())
|
||||
transport.DialContext = (&net.Dialer{
|
||||
Control: guardian.Safe,
|
||||
}).DialContext
|
||||
}
|
||||
|
||||
if config.WebhooksProxyURL.GetString() != "" && config.WebhooksProxyPassword.GetString() != "" {
|
||||
proxyURL, _ := url.Parse(config.WebhooksProxyURL.GetString())
|
||||
transport.Proxy = http.ProxyURL(proxyURL)
|
||||
transport.ProxyConnectHeader = http.Header{
|
||||
"Proxy-Authorization": []string{"Basic " + base64.StdEncoding.EncodeToString([]byte("vikunja:"+config.WebhooksProxyPassword.GetString()))},
|
||||
"User-Agent": []string{"Vikunja/" + version.Version},
|
||||
}
|
||||
}
|
||||
|
||||
client.Transport = transport
|
||||
|
||||
webhookClient = client
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ func TestWebhookSSRFProtection(t *testing.T) {
|
|||
|
||||
t.Run("blocks requests to loopback addresses", func(t *testing.T) {
|
||||
resetWebhookClient()
|
||||
config.WebhooksAllowNonRoutableIPs.Set(false)
|
||||
config.OutgoingRequestsAllowNonRoutableIPs.Set("false")
|
||||
config.WebhooksProxyURL.Set("")
|
||||
config.WebhooksProxyPassword.Set("")
|
||||
|
||||
|
|
@ -53,7 +53,7 @@ func TestWebhookSSRFProtection(t *testing.T) {
|
|||
|
||||
t.Run("allows requests to public addresses", func(t *testing.T) {
|
||||
resetWebhookClient()
|
||||
config.WebhooksAllowNonRoutableIPs.Set(false)
|
||||
config.OutgoingRequestsAllowNonRoutableIPs.Set("false")
|
||||
config.WebhooksProxyURL.Set("")
|
||||
config.WebhooksProxyPassword.Set("")
|
||||
|
||||
|
|
@ -80,7 +80,7 @@ func TestWebhookSSRFProtection(t *testing.T) {
|
|||
|
||||
t.Run("allows loopback when allownonroutableips is true", func(t *testing.T) {
|
||||
resetWebhookClient()
|
||||
config.WebhooksAllowNonRoutableIPs.Set(true)
|
||||
config.OutgoingRequestsAllowNonRoutableIPs.Set("true")
|
||||
config.WebhooksProxyURL.Set("")
|
||||
config.WebhooksProxyPassword.Set("")
|
||||
|
||||
|
|
@ -102,7 +102,7 @@ func TestWebhookSSRFProtection(t *testing.T) {
|
|||
|
||||
t.Run("blocks requests to private RFC1918 addresses", func(t *testing.T) {
|
||||
resetWebhookClient()
|
||||
config.WebhooksAllowNonRoutableIPs.Set(false)
|
||||
config.OutgoingRequestsAllowNonRoutableIPs.Set("false")
|
||||
config.WebhooksProxyURL.Set("")
|
||||
config.WebhooksProxyPassword.Set("")
|
||||
|
||||
|
|
@ -128,7 +128,7 @@ func TestWebhookSSRFProtection(t *testing.T) {
|
|||
|
||||
t.Run("blocks requests to metadata endpoint", func(t *testing.T) {
|
||||
resetWebhookClient()
|
||||
config.WebhooksAllowNonRoutableIPs.Set(false)
|
||||
config.OutgoingRequestsAllowNonRoutableIPs.Set("false")
|
||||
config.WebhooksProxyURL.Set("")
|
||||
config.WebhooksProxyPassword.Set("")
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
|
@ -37,13 +38,15 @@ func TestNewSSRFSafeHTTPClient(t *testing.T) {
|
|||
config.OutgoingRequestsAllowNonRoutableIPs.Set("true")
|
||||
defer config.OutgoingRequestsAllowNonRoutableIPs.Set("false")
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewSSRFSafeHTTPClient()
|
||||
resp, err := client.Get(server.URL)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
|
@ -54,7 +57,9 @@ func TestNewSSRFSafeHTTPClient(t *testing.T) {
|
|||
client := NewSSRFSafeHTTPClient()
|
||||
|
||||
// Attempt to connect to localhost (non-routable)
|
||||
_, err := client.Get("http://127.0.0.1:1/test")
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://127.0.0.1:1/test", nil)
|
||||
require.NoError(t, err)
|
||||
_, err = client.Do(req) //nolint:bodyclose
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
|
|
@ -62,13 +67,15 @@ func TestNewSSRFSafeHTTPClient(t *testing.T) {
|
|||
config.OutgoingRequestsAllowNonRoutableIPs.Set("true")
|
||||
defer config.OutgoingRequestsAllowNonRoutableIPs.Set("false")
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
client := NewSSRFSafeHTTPClient()
|
||||
resp, err := client.Get(server.URL)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
|
||||
require.NoError(t, err)
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
defer resp.Body.Close()
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
|
|
|||
Loading…
Reference in New Issue