diff --git a/pkg/modules/auth/ldap/ldap.go b/pkg/modules/auth/ldap/ldap.go index 39766df90..dbc90b06e 100644 --- a/pkg/modules/auth/ldap/ldap.go +++ b/pkg/modules/auth/ldap/ldap.go @@ -100,16 +100,64 @@ func ConnectAndBindToLDAPDirectory() (l *ldap.Conn, err error) { return } +// escapeLDAPFilterValue escapes special characters in LDAP filter values according to RFC 4515. +// This prevents LDAP injection attacks by properly escaping all special characters. +func escapeLDAPFilterValue(value string) string { + var buf strings.Builder + buf.Grow(len(value) * 2) // Pre-allocate to avoid reallocations + + for _, r := range value { + switch r { + case 0x00: // NULL + buf.WriteString(`\00`) + case '(': + buf.WriteString(`\28`) + case ')': + buf.WriteString(`\29`) + case '*': + buf.WriteString(`\2a`) + case '\\': + buf.WriteString(`\5c`) + case '&': + buf.WriteString(`\26`) + case '|': + buf.WriteString(`\7c`) + case '=': + buf.WriteString(`\3d`) + case '<': + buf.WriteString(`\3c`) + case '>': + buf.WriteString(`\3e`) + case '~': + buf.WriteString(`\7e`) + default: + buf.WriteRune(r) + } + } + + return buf.String() +} + // Adjusted from https://github.com/go-gitea/gitea/blob/6ca91f555ab9778310ac46cbbe33849c59286793/services/auth/source/ldap/source_search.go#L34 func sanitizedUserQuery(username string) (string, bool) { - // See http://tools.ietf.org/search/rfc4515 - badCharacters := "\x00()*\\" - if strings.ContainsAny(username, badCharacters) { - log.Debugf("'%s' contains invalid query characters. Aborting.", username) + // Validate username is not empty and doesn't contain control characters + if username == "" { + log.Debugf("Empty username provided. Aborting.") return "", false } - return fmt.Sprintf(config.AuthLdapUserFilter.GetString(), username), true + // Check for control characters that shouldn't be in usernames + for _, r := range username { + if r < 32 && r != 9 && r != 10 && r != 13 { // Allow tab, LF, CR but block other control chars + log.Debugf("Username contains control character 0x%02x. Aborting.", r) + return "", false + } + } + + // Escape the username according to RFC 4515 to prevent LDAP injection + escapedUsername := escapeLDAPFilterValue(username) + + return fmt.Sprintf(config.AuthLdapUserFilter.GetString(), escapedUsername), true } func AuthenticateUserInLDAP(s *xorm.Session, username, password string, syncGroups bool, avatarSyncAttribute string) (u *user.User, err error) { diff --git a/pkg/modules/auth/ldap/ldap_test.go b/pkg/modules/auth/ldap/ldap_test.go index f3386a0c0..c48c8db08 100644 --- a/pkg/modules/auth/ldap/ldap_test.go +++ b/pkg/modules/auth/ldap/ldap_test.go @@ -17,7 +17,9 @@ package ldap import ( + "fmt" "os" + "strings" "testing" "code.vikunja.io/api/pkg/config" @@ -116,3 +118,227 @@ func TestLdapLogin(t *testing.T) { }, false) }) } + +func TestEscapeLDAPFilterValue(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "normal username", + input: "testuser", + expected: "testuser", + }, + { + name: "username with parentheses", + input: "test(user)", + expected: `test\28user\29`, + }, + { + name: "username with asterisk", + input: "test*user", + expected: `test\2auser`, + }, + { + name: "username with backslash", + input: `test\user`, + expected: `test\5cuser`, + }, + { + name: "username with ampersand", + input: "test&user", + expected: `test\26user`, + }, + { + name: "username with pipe", + input: "test|user", + expected: `test\7cuser`, + }, + { + name: "username with equals", + input: "test=user", + expected: `test\3duser`, + }, + { + name: "username with less than", + input: "testuser", + expected: `test\3euser`, + }, + { + name: "username with tilde", + input: "test~user", + expected: `test\7euser`, + }, + { + name: "username with null byte", + input: "test\x00user", + expected: `test\00user`, + }, + { + name: "complex injection attempt", + input: "admin)(|(objectClass=*", + expected: `admin\29\28\7c\28objectClass\3d\2a`, + }, + { + name: "LDAP injection with OR operator", + input: "testuser)|(&(objectClass=user", + expected: `testuser\29\7c\28\26\28objectClass\3duser`, + }, + { + name: "multiple special characters", + input: "test()&|=<>~*\\user", + expected: `test\28\29\26\7c\3d\3c\3e\7e\2a\5cuser`, + }, + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "unicode characters", + input: "testuser_unicode", + expected: "testuser_unicode", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := escapeLDAPFilterValue(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestSanitizedUserQuery(t *testing.T) { + // Set up a test filter for this test + originalFilter := config.AuthLdapUserFilter.GetString() + config.AuthLdapUserFilter.Set("(&(objectClass=user)(sAMAccountName=%[1]s))") + defer func() { + if originalFilter != "" { + config.AuthLdapUserFilter.Set(originalFilter) + } + }() + + tests := []struct { + name string + input string + expectedResult bool + expectedFilter string + }{ + { + name: "normal username", + input: "testuser", + expectedResult: true, + expectedFilter: "(&(objectClass=user)(sAMAccountName=testuser))", + }, + { + name: "username with injection attempt", + input: "admin)(|(objectClass=*", + expectedResult: true, + expectedFilter: `(&(objectClass=user)(sAMAccountName=admin\29\28\7c\28objectClass\3d\2a))`, + }, + { + name: "username with OR operator", + input: "test|admin", + expectedResult: true, + expectedFilter: `(&(objectClass=user)(sAMAccountName=test\7cadmin))`, + }, + { + name: "empty username", + input: "", + expectedResult: false, + expectedFilter: "", + }, + { + name: "username with null byte", + input: "test\x00user", + expectedResult: false, + expectedFilter: "", + }, + { + name: "username with other control characters", + input: "test\x01user", + expectedResult: false, + expectedFilter: "", + }, + { + name: "username with allowed whitespace", + input: "test user", + expectedResult: true, + expectedFilter: "(&(objectClass=user)(sAMAccountName=test user))", + }, + { + name: "username with tab (allowed)", + input: "test\tuser", + expectedResult: true, + expectedFilter: "(&(objectClass=user)(sAMAccountName=test\tuser))", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, ok := sanitizedUserQuery(tt.input) + assert.Equal(t, tt.expectedResult, ok) + if ok { + assert.Equal(t, tt.expectedFilter, result) + } else { + assert.Empty(t, result) + } + }) + } +} + +func TestSanitizedUserQueryPreventsInjection(t *testing.T) { + // Set up a test filter + config.AuthLdapUserFilter.Set("(&(objectClass=user)(uid=%[1]s))") + defer config.AuthLdapUserFilter.Set("") + + // Test various injection attempts + injectionAttempts := []string{ + "admin)(uid=*", // Try to match any uid + "*)(|(uid=admin", // OR injection + "admin))(&(objectClass=*", // Try to match any object class + "admin))(|(|(uid=admin)(uid=root", // Complex OR injection + "admin&admin", // AND injection + "admin=admin", // Equals injection + "adminadmin", // Greater than injection + "admin~admin", // Approximate match injection + } + + for i, attempt := range injectionAttempts { + t.Run(fmt.Sprintf("injection_attempt_%d", i+1), func(t *testing.T) { + result, ok := sanitizedUserQuery(attempt) + assert.True(t, ok, "Query should be sanitized, not rejected") + + // Verify that all special characters are properly escaped + assert.NotContains(t, result, ")(uid=*", "Should not contain unescaped injection") + assert.NotContains(t, result, "|(", "Should not contain unescaped OR operator") + assert.NotContains(t, result, "))(", "Should not contain unescaped parentheses") + assert.NotContains(t, result, "=*", "Should not contain unescaped equals with wildcard") + + // Verify escaping is present where expected + if strings.Contains(attempt, "(") { + assert.Contains(t, result, `\28`, "Should contain escaped opening parenthesis") + } + if strings.Contains(attempt, ")") { + assert.Contains(t, result, `\29`, "Should contain escaped closing parenthesis") + } + if strings.Contains(attempt, "|") { + assert.Contains(t, result, `\7c`, "Should contain escaped pipe") + } + if strings.Contains(attempt, "&") { + assert.Contains(t, result, `\26`, "Should contain escaped ampersand") + } + if strings.Contains(attempt, "=") { + assert.Contains(t, result, `\3d`, "Should contain escaped equals") + } + }) + } +}