diff --git a/pkg/modules/dump/restore.go b/pkg/modules/dump/restore.go index 14742e649..5c9062cfc 100644 --- a/pkg/modules/dump/restore.go +++ b/pkg/modules/dump/restore.go @@ -45,7 +45,7 @@ import ( "src.techknowlogick.com/xormigrate" ) -const maxConfigSize = 5 * 1024 * 1024 // 5 MB, should be largely enough +const maxConfigSize = 5 * 1024 * 1024 // 5 MB, should be largely enough const maxDumpEntrySize = 500 * 1024 * 1024 // 500 MB // parseDbFileName validates and extracts the table name from a database dump filename. @@ -180,26 +180,8 @@ func Restore(filename string, overrideConfig bool) error { lastMigration := ms[len(ms)-2] - // Pre-validate all table data JSON before wiping to avoid leaving the database - // in a destroyed state when the archive contains corrupted table data. - for table, d := range dbfiles { - if table == "migration" { - continue - } - rc, err := d.Open() - if err != nil { - return fmt.Errorf("could not open table data for %s: %w", table, err) - } - var bufValidate bytes.Buffer - if _, err := bufValidate.ReadFrom(io.LimitReader(rc, maxDumpEntrySize)); err != nil { - rc.Close() - return fmt.Errorf("could not read table data for %s: %w", table, err) - } - rc.Close() - var test []map[string]interface{} - if err := json.Unmarshal(bufValidate.Bytes(), &test); err != nil { - return fmt.Errorf("invalid JSON in table data for %s: %w", table, err) - } + if err := preValidateTableData(dbfiles); err != nil { + return err } // Start by wiping everything - only after we've validated the archive @@ -404,6 +386,32 @@ func restoreTableData(tables map[string]*zip.File) error { return nil } +// preValidateTableData checks that all table data JSON files in the archive +// are parseable before wiping the database, to avoid leaving the database +// in a destroyed state when the archive contains corrupted data. +func preValidateTableData(dbfiles map[string]*zip.File) error { + for table, d := range dbfiles { + if table == "migration" { + continue + } + rc, err := d.Open() + if err != nil { + return fmt.Errorf("could not open table data for %s: %w", table, err) + } + var buf bytes.Buffer + if _, err := buf.ReadFrom(io.LimitReader(rc, maxDumpEntrySize)); err != nil { + rc.Close() + return fmt.Errorf("could not read table data for %s: %w", table, err) + } + rc.Close() + var test []map[string]interface{} + if err := json.Unmarshal(buf.Bytes(), &test); err != nil { + return fmt.Errorf("invalid JSON in table data for %s: %w", table, err) + } + } + return nil +} + func unmarshalFileToJSON(file *zip.File) (contents []map[string]interface{}, err error) { rc, err := file.Open() if err != nil {