diff --git a/CHANGELOG.md b/CHANGELOG.md index f3b311fd..cf2864d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- Added root River CLI flag `--statement-timeout` so Postgres session statement timeout can be set explicitly for commands like migrations. Explicit flag values take priority over database URL query params, and query params still take priority over built-in defaults. [PR #1142](https://github.com/riverqueue/river/pull/1142). + ### Fixed - `JobCountByQueueAndState` now returns consistent results across drivers, including requested queues with zero jobs, and deduplicates repeated queue names in input. This resolves an issue with the sqlite driver in River UI reported in [riverqueue/riverui#496](https://github.com/riverqueue/riverui#496). [PR #1140](https://github.com/riverqueue/river/pull/1140). diff --git a/cmd/river/rivercli/command.go b/cmd/river/rivercli/command.go index 4f96f51c..b6f8e1dc 100644 --- a/cmd/river/rivercli/command.go +++ b/cmd/river/rivercli/command.go @@ -1,6 +1,7 @@ package rivercli import ( + "cmp" "context" "database/sql" "fmt" @@ -46,11 +47,12 @@ type CommandOpts interface { // RunCommandBundle is a bundle of utilities for RunCommand. type RunCommandBundle struct { - DatabaseURL *string - DriverProcurer DriverProcurer - Logger *slog.Logger - OutStd io.Writer - Schema string + DatabaseURL *string + DriverProcurer DriverProcurer + Logger *slog.Logger + OutStd io.Writer + Schema string + StatementTimeout *time.Duration } // RunCommand bootstraps and runs a River CLI subcommand. @@ -81,7 +83,7 @@ func RunCommand[TOpts CommandOpts](ctx context.Context, bundle *RunCommandBundle if databaseURL != nil { switch protocol { case "postgres", "postgresql": - dbPool, err := openPgxV5DBPool(ctx, *databaseURL) + dbPool, err := openPgxV5DBPool(ctx, *databaseURL, bundle.StatementTimeout) if err != nil { return false, err } @@ -128,7 +130,7 @@ func RunCommand[TOpts CommandOpts](ctx context.Context, bundle *RunCommandBundle return nil } -func openPgxV5DBPool(ctx context.Context, databaseURL string) (*pgxpool.Pool, error) { +func openPgxV5DBPool(ctx context.Context, databaseURL string, statementTimeout *time.Duration) (*pgxpool.Pool, error) { const ( defaultIdleInTransactionSessionTimeout = 11 * time.Second // should be greater than statement timeout because statements count towards idle-in-transaction defaultStatementTimeout = 10 * time.Second @@ -149,9 +151,24 @@ func openPgxV5DBPool(ctx context.Context, databaseURL string) (*pgxpool.Pool, er runtimeParams[name] = val } - setParamIfUnset(pgxConfig.ConnConfig.RuntimeParams, "application_name", "river CLI") - setParamIfUnset(pgxConfig.ConnConfig.RuntimeParams, "idle_in_transaction_session_timeout", strconv.Itoa(int(defaultIdleInTransactionSessionTimeout.Milliseconds()))) - setParamIfUnset(pgxConfig.ConnConfig.RuntimeParams, "statement_timeout", strconv.Itoa(int(defaultStatementTimeout.Milliseconds()))) + runtimeParams := pgxConfig.ConnConfig.RuntimeParams + if runtimeParams == nil { + runtimeParams = make(map[string]string) + pgxConfig.ConnConfig.RuntimeParams = runtimeParams + } + + var statementTimeoutMilliseconds string + if statementTimeout != nil { + statementTimeoutMilliseconds = strconv.Itoa(int(statementTimeout.Milliseconds())) + } + + setParamIfUnset(runtimeParams, "application_name", "river CLI") + setParamIfUnset(runtimeParams, "idle_in_transaction_session_timeout", strconv.Itoa(int(defaultIdleInTransactionSessionTimeout.Milliseconds()))) + runtimeParams["statement_timeout"] = cmp.Or( + statementTimeoutMilliseconds, + runtimeParams["statement_timeout"], + strconv.Itoa(int(defaultStatementTimeout.Milliseconds())), + ) dbPool, err := pgxpool.NewWithConfig(ctx, pgxConfig) if err != nil { diff --git a/cmd/river/rivercli/river_cli.go b/cmd/river/rivercli/river_cli.go index 88e8e599..0b0dedc4 100644 --- a/cmd/river/rivercli/river_cli.go +++ b/cmd/river/rivercli/river_cli.go @@ -60,9 +60,11 @@ func (c *CLI) BaseCommandSet() *cobra.Command { ctx := context.Background() var globalOpts struct { - Debug bool - Verbose bool + Debug bool + StatementTimeout time.Duration + Verbose bool } + var rootCmd *cobra.Command makeLogger := func() *slog.Logger { switch { @@ -75,18 +77,35 @@ func (c *CLI) BaseCommandSet() *cobra.Command { } } + statementTimeoutFlagSet := func() bool { + return rootCmd.PersistentFlags().Changed("statement-timeout") + } + + validateGlobalOpts := func() error { + if statementTimeoutFlagSet() && globalOpts.StatementTimeout <= time.Millisecond { + return errors.New("`--statement-timeout` must be greater than 1ms when set") + } + + return nil + } + // Make a bundle for RunCommand. Takes a database URL pointer because not every command is required to take a database URL. makeCommandBundle := func(databaseURL *string, schema string) *RunCommandBundle { + var statementTimeout *time.Duration + if statementTimeoutFlagSet() { + statementTimeout = &globalOpts.StatementTimeout + } + return &RunCommandBundle{ - DatabaseURL: databaseURL, - DriverProcurer: c.driverProcurer, - Logger: makeLogger(), - OutStd: c.out, - Schema: schema, + DatabaseURL: databaseURL, + DriverProcurer: c.driverProcurer, + Logger: makeLogger(), + OutStd: c.out, + Schema: schema, + StatementTimeout: statementTimeout, } } - var rootCmd *cobra.Command { var rootOpts struct { Version bool @@ -103,7 +122,15 @@ also accept Postgres configuration through the standard set of libpq environment variables like PGHOST, PGPORT, PGDATABASE, PGUSER, PGPASSWORD, and PGSSLMODE, with a minimum of PGDATABASE required. --database-url will take precedence of PG* vars if it's been specified. + +Use --statement-timeout to explicitly set Postgres statement_timeout for +Postgres-backed commands. Precedence is: --statement-timeout, then a +statement_timeout query parameter in --database-url, then the built-in 10s +default. `), + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + return validateGlobalOpts() + }, RunE: func(cmd *cobra.Command, args []string) error { if rootOpts.Version { return RunCommand(ctx, makeCommandBundle(nil, ""), &version{}, &versionOpts{Name: c.name}) @@ -116,6 +143,7 @@ PG* vars if it's been specified. rootCmd.SetOut(c.out) rootCmd.PersistentFlags().BoolVar(&globalOpts.Debug, "debug", false, "output maximum logging verbosity (debug level)") + rootCmd.PersistentFlags().DurationVar(&globalOpts.StatementTimeout, "statement-timeout", 0, "override Postgres statement_timeout for Postgres commands (Go duration >1ms, e.g. 10s, 1m); precedence: flag > --database-url statement_timeout > default 10s") rootCmd.PersistentFlags().BoolVarP(&globalOpts.Verbose, "verbose", "v", false, "output additional logging verbosity (info level)") rootCmd.MarkFlagsMutuallyExclusive("debug", "verbose") diff --git a/cmd/river/rivercli/river_cli_test.go b/cmd/river/rivercli/river_cli_test.go index ca6cf600..65e1134d 100644 --- a/cmd/river/rivercli/river_cli_test.go +++ b/cmd/river/rivercli/river_cli_test.go @@ -5,6 +5,7 @@ import ( "cmp" "context" "fmt" + "maps" "net/url" "runtime/debug" "strings" @@ -169,6 +170,34 @@ func TestBaseCommandSetIntegration(t *testing.T) { require.EqualError(t, cmd.Execute(), "either PG* env vars or --database-url must be set") }) + t.Run("StatementTimeoutValidation", func(t *testing.T) { + t.Parallel() + + t.Run("AllowsGreaterThanOneMillisecond", func(t *testing.T) { + t.Parallel() + + cmd, _ := setup(t) + cmd.SetArgs([]string{"--statement-timeout", "2ms", "--version"}) + require.NoError(t, cmd.Execute()) + }) + + t.Run("RejectsOneMillisecond", func(t *testing.T) { + t.Parallel() + + cmd, _ := setup(t) + cmd.SetArgs([]string{"--statement-timeout", "1ms", "--version"}) + require.EqualError(t, cmd.Execute(), "`--statement-timeout` must be greater than 1ms when set") + }) + + t.Run("RejectsZero", func(t *testing.T) { + t.Parallel() + + cmd, _ := setup(t) + cmd.SetArgs([]string{"--statement-timeout", "0", "--version"}) + require.EqualError(t, cmd.Execute(), "`--statement-timeout` must be greater than 1ms when set") + }) + }) + t.Run("VersionFlag", func(t *testing.T) { t.Parallel() @@ -263,6 +292,116 @@ func TestBaseCommandSetNonParallel(t *testing.T) { }) } +func TestBaseCommandSetPostgresTimeoutPrecedence(t *testing.T) { + t.Parallel() + + type testCase struct { + databaseURLStatementTimeout string + expectedStatementTimeoutMS string + name string + statementTimeoutFlag string + } + + makeCommandAndParams := func(t *testing.T) (*cobra.Command, func() map[string]string) { + t.Helper() + + var capturedRuntimeParams map[string]string + + migratorStub := &MigratorStub{} + migratorStub.allVersionsStub = func() []rivermigrate.Migration { return []rivermigrate.Migration{testMigration01} } + migratorStub.getVersionStub = func(version int) (rivermigrate.Migration, error) { + if version == 1 { + return testMigration01, nil + } + + return rivermigrate.Migration{}, fmt.Errorf("unknown version: %d", version) + } + migratorStub.existingVersionsStub = func(ctx context.Context) ([]rivermigrate.Migration, error) { return nil, nil } + + cli := NewCLI(&Config{ + DriverProcurer: &DriverProcurerStub{ + getMigratorStub: func(config *rivermigrate.Config) (MigratorInterface, error) { + return migratorStub, nil + }, + initPgxV5Stub: func(pool *pgxpool.Pool) { + capturedRuntimeParams = maps.Clone(pool.Config().ConnConfig.RuntimeParams) + }, + }, + Name: "River", + }) + + var out bytes.Buffer + cli.SetOut(&out) + + return cli.BaseCommandSet(), func() map[string]string { + return capturedRuntimeParams + } + } + + makeBaseDatabaseURL := func(t *testing.T) *url.URL { + t.Helper() + + testDatabaseURL := riversharedtest.TestDatabaseURL() + parsedDatabaseURL, err := url.Parse(testDatabaseURL) + require.NoError(t, err) + + return parsedDatabaseURL + } + + testCases := []testCase{ + { + name: "DefaultsAppliedWhenNothingSpecified", + expectedStatementTimeoutMS: "10000", + }, + { + databaseURLStatementTimeout: "11234", + name: "DatabaseURLQueryParamsOverrideDefaults", + expectedStatementTimeoutMS: "11234", + }, + { + databaseURLStatementTimeout: "12345", + name: "ExplicitFlagsOverrideDatabaseURLQueryParams", + statementTimeoutFlag: "1m3.123s", + expectedStatementTimeoutMS: "63123", + }, + { + databaseURLStatementTimeout: "12345", + name: "ExplicitFlagsUseMillisecondValue", + statementTimeoutFlag: "2ms", + expectedStatementTimeoutMS: "2", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + cmd, getRuntimeParams := makeCommandAndParams(t) + + databaseURL := makeBaseDatabaseURL(t) + if testCase.databaseURLStatementTimeout != "" { + queryValues := databaseURL.Query() + queryValues.Set("statement_timeout", testCase.databaseURLStatementTimeout) + databaseURL.RawQuery = queryValues.Encode() + } + + args := []string{ + "migrate-get", "--up", "--version", "1", "--database-url", databaseURL.String(), + } + if testCase.statementTimeoutFlag != "" { + args = append(args, "--statement-timeout", testCase.statementTimeoutFlag) + } + cmd.SetArgs(args) + require.NoError(t, cmd.Execute()) + + runtimeParams := getRuntimeParams() + require.NotNil(t, runtimeParams) + + require.Equal(t, testCase.expectedStatementTimeoutMS, runtimeParams["statement_timeout"]) + }) + } +} + func TestBaseCommandSetDriverProcurerPgxV5(t *testing.T) { t.Parallel() diff --git a/docs/development.md b/docs/development.md index 11d16de8..837c2997 100644 --- a/docs/development.md +++ b/docs/development.md @@ -30,6 +30,11 @@ To run programs locally outside of tests, create and raise a development databas createdb river_dev go run ./cmd/river migrate-up --database-url postgres:///river_dev --line main +If needed, override Postgres timeouts for long-running migrations with root CLI +flags: + + go run ./cmd/river --statement-timeout 2m migrate-up --database-url postgres:///river_dev --line main + ## Releasing a new version 1. Fetch changes to the repo and any new tags. Export `VERSION` by incrementing the last tag. Execute `update-mod-version` to add it the project's `go.mod` files: