diff --git a/pkg/config/config.go b/pkg/config/config.go index af529b757..27b913c1b 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -223,6 +223,7 @@ const ( OutgoingRequestsAllowNonRoutableIPs Key = `outgoingrequests.allownonroutableips` OutgoingRequestsProxyURL Key = `outgoingrequests.proxyurl` OutgoingRequestsProxyPassword Key = `outgoingrequests.proxypassword` + OutgoingRequestsTimeoutSeconds Key = `outgoingrequests.timeoutseconds` AutoTLSEnabled Key = `autotls.enabled` AutoTLSEmail Key = `autotls.email` @@ -480,6 +481,7 @@ func InitDefaultConfig() { WebhooksAllowNonRoutableIPs.setDefault(false) // Outgoing Requests OutgoingRequestsAllowNonRoutableIPs.setDefault(false) + OutgoingRequestsTimeoutSeconds.setDefault(30) // AutoTLS AutoTLSRenewBefore.setDefault("720h") // 30days in hours // Plugins diff --git a/pkg/modules/avatar/gravatar/gravatar.go b/pkg/modules/avatar/gravatar/gravatar.go index e44fd0334..572cdecc3 100644 --- a/pkg/modules/avatar/gravatar/gravatar.go +++ b/pkg/modules/avatar/gravatar/gravatar.go @@ -90,7 +90,7 @@ func (g *Provider) GetAvatar(user *user.User, size int64) ([]byte, string, error if err != nil { return nil, err } - resp, err := (&http.Client{}).Do(req) // #nosec G704 -- URL is from config (AvatarGravatarBaseURL) + resp, err := (&http.Client{Timeout: 5 * time.Second}).Do(req) // #nosec G704 -- URL is from config (AvatarGravatarBaseURL) if err != nil { return nil, err } diff --git a/pkg/modules/background/unsplash/unsplash.go b/pkg/modules/background/unsplash/unsplash.go index afeed3496..07a3b4ad8 100644 --- a/pkg/modules/background/unsplash/unsplash.go +++ b/pkg/modules/background/unsplash/unsplash.go @@ -108,7 +108,7 @@ func doGet(url string, result ...interface{}) (err error) { } req.Header.Add("Authorization", "Client-ID "+config.BackgroundsUnsplashAccessToken.GetString()) - hc := http.Client{} + hc := http.Client{Timeout: 10 * time.Second} resp, err := hc.Do(req) // #nosec G704 -- URL is constructed from hardcoded Unsplash API base if err != nil { return diff --git a/pkg/utils/httpclient.go b/pkg/utils/httpclient.go index 57b9b0b58..ed6a902bc 100644 --- a/pkg/utils/httpclient.go +++ b/pkg/utils/httpclient.go @@ -21,6 +21,7 @@ import ( "net" "net/http" "net/url" + "time" "code.vikunja.io/api/pkg/config" "code.vikunja.io/api/pkg/version" @@ -37,7 +38,9 @@ import ( // config init time (see config.InitDefaultConfig), so this function only // reads the new keys. func NewSSRFSafeHTTPClient() *http.Client { - client := &http.Client{} + client := &http.Client{ + Timeout: time.Duration(config.OutgoingRequestsTimeoutSeconds.GetInt()) * time.Second, + } transport := &http.Transport{} if !config.OutgoingRequestsAllowNonRoutableIPs.GetBool() { diff --git a/pkg/utils/httpclient_test.go b/pkg/utils/httpclient_test.go index aa1a79724..03ccdb117 100644 --- a/pkg/utils/httpclient_test.go +++ b/pkg/utils/httpclient_test.go @@ -21,6 +21,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "code.vikunja.io/api/pkg/config" @@ -63,6 +64,40 @@ func TestNewSSRFSafeHTTPClient(t *testing.T) { require.Error(t, err) }) + t.Run("has default timeout from config", func(t *testing.T) { + config.OutgoingRequestsTimeoutSeconds.Set("30") + client := NewSSRFSafeHTTPClient() + assert.Equal(t, 30*time.Second, client.Timeout) + }) + + t.Run("respects custom timeout config", func(t *testing.T) { + config.OutgoingRequestsTimeoutSeconds.Set("15") + defer config.OutgoingRequestsTimeoutSeconds.Set("30") + + client := NewSSRFSafeHTTPClient() + assert.Equal(t, 15*time.Second, client.Timeout) + }) + + t.Run("timeout fires on slow server", func(t *testing.T) { + config.OutgoingRequestsAllowNonRoutableIPs.Set("true") + config.OutgoingRequestsTimeoutSeconds.Set("1") + defer config.OutgoingRequestsAllowNonRoutableIPs.Set("false") + defer config.OutgoingRequestsTimeoutSeconds.Set("30") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + time.Sleep(3 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := NewSSRFSafeHTTPClient() + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + require.NoError(t, err) + _, err = client.Do(req) //nolint:bodyclose,gosec + require.Error(t, err) + assert.Contains(t, err.Error(), "Client.Timeout") + }) + t.Run("allows non-routable IPs when config is true", func(t *testing.T) { config.OutgoingRequestsAllowNonRoutableIPs.Set("true") defer config.OutgoingRequestsAllowNonRoutableIPs.Set("false")