diff --git a/internal/driverutil/operation.go b/internal/driverutil/operation.go index e37cba5903..74142a56e8 100644 --- a/internal/driverutil/operation.go +++ b/internal/driverutil/operation.go @@ -6,6 +6,12 @@ package driverutil +import ( + "context" + "math" + "time" +) + // Operation Names should be sourced from the command reference documentation: // https://www.mongodb.com/docs/manual/reference/command/ const ( @@ -30,3 +36,34 @@ const ( UpdateOp = "update" // UpdateOp is the name for updating BulkWriteOp = "bulkWrite" // BulkWriteOp is the name for client-level bulk write ) + +// CalculateMaxTimeMS calculates the maxTimeMS value to send to the server +// based on the context deadline and the minimum round trip time. If the +// calculated maxTimeMS is likely to cause a socket timeout, then this function +// will return 0 and false. +func CalculateMaxTimeMS(ctx context.Context, rttMin time.Duration) (int64, bool) { + deadline, ok := ctx.Deadline() + if !ok { + return 0, true + } + + remainingTimeout := time.Until(deadline) + + // Always round up to the next millisecond value so we never truncate the calculated + // maxTimeMS value (e.g. 400 microseconds evaluates to 1ms, not 0ms). + maxTimeMS := int64((remainingTimeout - rttMin + time.Millisecond - 1) / time.Millisecond) + if maxTimeMS <= 0 { + return 0, false + } + + // The server will return a "BadValue" error if maxTimeMS is greater + // than the maximum positive int32 value (about 24.9 days). If the + // user specified a timeout value greater than that, omit maxTimeMS + // and let the client-side timeout handle cancelling the op if the + // timeout is ever reached. + if maxTimeMS > math.MaxInt32 { + return 0, true + } + + return maxTimeMS, true +} diff --git a/internal/driverutil/operation_test.go b/internal/driverutil/operation_test.go new file mode 100644 index 0000000000..474c3e1aa1 --- /dev/null +++ b/internal/driverutil/operation_test.go @@ -0,0 +1,113 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package driverutil + +import ( + "context" + "math" + "testing" + "time" + + "go.mongodb.org/mongo-driver/v2/internal/assert" +) + +func TestCalculateMaxTimeMS(t *testing.T) { + tests := []struct { + name string + ctx context.Context + rttMin time.Duration + wantZero bool + wantOk bool + wantPositive bool + wantExact int64 + }{ + { + name: "no deadline", + ctx: context.Background(), + rttMin: 10 * time.Millisecond, + wantZero: true, + wantOk: true, + wantPositive: false, + }, + { + name: "deadline expired", + ctx: func() context.Context { + ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) //nolint:govet + return ctx + }(), + wantZero: true, + wantOk: false, + wantPositive: false, + }, + { + name: "remaining timeout < rttMin", + ctx: func() context.Context { + ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(1*time.Millisecond)) //nolint:govet + return ctx + }(), + rttMin: 10 * time.Millisecond, + wantZero: true, + wantOk: false, + wantPositive: false, + }, + { + name: "normal positive result", + ctx: func() context.Context { + ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(100*time.Millisecond)) //nolint:govet + return ctx + }(), + wantZero: false, + wantOk: true, + wantPositive: true, + }, + { + name: "beyond maxInt32", + ctx: func() context.Context { + dur := time.Now().Add(time.Duration(math.MaxInt32+1000) * time.Millisecond) + ctx, _ := context.WithDeadline(context.Background(), dur) //nolint:govet + return ctx + }(), + wantZero: true, + wantOk: true, + wantPositive: false, + }, + { + name: "round up to 1ms", + ctx: func() context.Context { + ctx, _ := context.WithDeadline(context.Background(), time.Now().Add(999*time.Microsecond)) //nolint:govet + return ctx + }(), + wantOk: true, + wantExact: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := CalculateMaxTimeMS(tt.ctx, tt.rttMin) + + assert.Equal(t, tt.wantOk, got1) + + if tt.wantExact > 0 && got != tt.wantExact { + t.Errorf("CalculateMaxTimeMS() got = %v, want %v", got, tt.wantExact) + } + + if tt.wantZero && got != 0 { + t.Errorf("CalculateMaxTimeMS() got = %v, want 0", got) + } + + if !tt.wantZero && got == 0 { + t.Errorf("CalculateMaxTimeMS() got = %v, want > 0", got) + } + + if !tt.wantZero && tt.wantPositive && got <= 0 { + t.Errorf("CalculateMaxTimeMS() got = %v, want > 0", got) + } + }) + } + +} diff --git a/internal/integration/cursor_test.go b/internal/integration/cursor_test.go index 5ee9986ec2..6376e78e74 100644 --- a/internal/integration/cursor_test.go +++ b/internal/integration/cursor_test.go @@ -17,6 +17,7 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/failpoint" "go.mongodb.org/mongo-driver/v2/internal/integration/mtest" + "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" ) @@ -303,6 +304,75 @@ func TestCursor(t *testing.T) { batchSize = sizeVal.Int32() assert.Equal(mt, int32(4), batchSize, "expected batchSize 4, got %v", batchSize) }) + + tailableAwaitDataCursorOpts := mtest.NewOptions().MinServerVersion("4.4"). + Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single) + + mt.RunOpts("tailable awaitData cursor", tailableAwaitDataCursorOpts, func(mt *mtest.T) { + mt.Run("apply remaining timeoutMS if less than maxAwaitTimeMS", func(mt *mtest.T) { + initCollection(mt, mt.Coll) + mt.ClearEvents() + + // Create a find cursor + opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(100 * time.Millisecond) + + cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts) + require.NoError(mt, err) + + _ = mt.GetStartedEvent() // Empty find from started list. + + defer cursor.Close(context.Background()) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + // Iterate twice to force a getMore + cursor.Next(ctx) + cursor.Next(ctx) + + cmd := mt.GetStartedEvent().Command + + maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS") + require.NoError(mt, err) + + got, ok := maxTimeMSRaw.AsInt64OK() + require.True(mt, ok) + + assert.LessOrEqual(mt, got, int64(50)) + }) + + mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", tailableAwaitDataCursorOpts, func(mt *mtest.T) { + initCollection(mt, mt.Coll) + mt.ClearEvents() + + // Create a find cursor + opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond) + + cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts) + require.NoError(mt, err) + + _ = mt.GetStartedEvent() // Empty find from started list. + + defer cursor.Close(context.Background()) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + // Iterate twice to force a getMore + cursor.Next(ctx) + cursor.Next(ctx) + + cmd := mt.GetStartedEvent().Command + + maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS") + require.NoError(mt, err) + + got, ok := maxTimeMSRaw.AsInt64OK() + require.True(mt, ok) + + assert.LessOrEqual(mt, got, int64(50)) + }) + }) } type tryNextCursor interface { diff --git a/internal/integration/unified/collection_operation_execution.go b/internal/integration/unified/collection_operation_execution.go index 6c8b38145a..c3e7040256 100644 --- a/internal/integration/unified/collection_operation_execution.go +++ b/internal/integration/unified/collection_operation_execution.go @@ -10,6 +10,7 @@ import ( "context" "errors" "fmt" + "strings" "time" "go.mongodb.org/mongo-driver/v2/bson" @@ -1485,6 +1486,20 @@ func createFindCursor(ctx context.Context, operation *operation) (*cursorResult, opts.SetSkip(int64(val.Int32())) case "sort": opts.SetSort(val.Document()) + case "timeoutMode": + return nil, newSkipTestError("timeoutMode is not supported") + case "cursorType": + switch strings.ToLower(val.StringValue()) { + case "tailable": + opts.SetCursorType(options.Tailable) + case "tailableawait": + opts.SetCursorType(options.TailableAwait) + case "nontailable": + opts.SetCursorType(options.NonTailable) + } + case "maxAwaitTimeMS": + maxAwaitTimeMS := time.Duration(val.Int32()) * time.Millisecond + opts.SetMaxAwaitTime(maxAwaitTimeMS) default: return nil, fmt.Errorf("unrecognized find option %q", key) } diff --git a/internal/spectest/skip.go b/internal/spectest/skip.go index 396590841d..c3c5d04fe3 100644 --- a/internal/spectest/skip.go +++ b/internal/spectest/skip.go @@ -346,6 +346,10 @@ var skipTests = map[string][]string{ "TestUnifiedSpec/client-side-operations-timeout/tests/retryability-timeoutMS.json/operation_is_retried_multiple_times_for_non-zero_timeoutMS_-_aggregate_on_collection", "TestUnifiedSpec/client-side-operations-timeout/tests/retryability-timeoutMS.json/operation_is_retried_multiple_times_for_non-zero_timeoutMS_-_aggregate_on_database", "TestUnifiedSpec/client-side-operations-timeout/tests/gridfs-find.json/timeoutMS_applied_to_find_command", + "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_applied_to_find", + "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_not_set", + "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_set", + "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_-_failure", }, // TODO(GODRIVER-3411): Tests require "getMore" with "maxTimeMS" settings. Not @@ -448,7 +452,6 @@ var skipTests = map[string][]string{ "TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/change_stream_can_be_iterated_again_if_previous_iteration_times_out", "TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/timeoutMS_is_refreshed_for_getMore_-_failure", "TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_greater_than_timeoutMS", - "TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS", }, // Unknown CSOT: @@ -584,12 +587,10 @@ var skipTests = map[string][]string{ "TestUnifiedSpec/client-side-operations-timeout/tests/sessions-override-timeoutMS.json", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_timeoutMode_is_cursor_lifetime", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_greater_than_timeoutMS", - "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_applied_to_find", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_not_set", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_if_maxAwaitTimeMS_is_set", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/timeoutMS_is_refreshed_for_getMore_-_failure", - "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_remaining_timeoutMS_if_less_than_maxAwaitTimeMS", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_maxAwaitTimeMS_if_less_than_remaining_timeout", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-non-awaitData.json/error_if_timeoutMode_is_cursor_lifetime", "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-non-awaitData.json/timeoutMS_applied_to_find", @@ -819,6 +820,21 @@ var skipTests = map[string][]string{ "TestUnifiedSpec/transactions-convenient-api/tests/unified/transaction-options.json/withTransaction_explicit_transaction_options_override_client_options", "TestUnifiedSpec/transactions-convenient-api/tests/unified/commit.json/withTransaction_commits_after_callback_returns", }, + + // GODRIVER-3473: the implementation of DRIVERS-2868 makes it clear that the + // Go Driver does not correctly implement the following validation for + // tailable awaitData cursors: + // + // Drivers MUST error if this option is set, timeoutMS is set to a + // non-zero value, and maxAwaitTimeMS is greater than or equal to + // timeoutMS. + // + // Once GODRIVER-3473 is completed, we can continue running these tests. + "When constructing tailable awaitData cusors must validate, timeoutMS is set to a non-zero value, and maxAwaitTimeMS is greater than or equal to timeoutMS (GODRIVER-3473)": { + "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/apply_remaining_timeoutMS_if_less_than_maxAwaitTimeMS", + "TestUnifiedSpec/client-side-operations-timeout/tests/tailable-awaitData.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS", + "TestUnifiedSpec/client-side-operations-timeout/tests/change-streams.json/error_if_maxAwaitTimeMS_is_equal_to_timeoutMS", + }, } // CheckSkip checks if the fully-qualified test name matches a list of skipped test names for a given reason. diff --git a/x/mongo/driver/batch_cursor.go b/x/mongo/driver/batch_cursor.go index f444739661..6d6cd211a5 100644 --- a/x/mongo/driver/batch_cursor.go +++ b/x/mongo/driver/batch_cursor.go @@ -381,14 +381,40 @@ func (bc *BatchCursor) getMore(ctx context.Context) { bc.err = Operation{ CommandFn: func(dst []byte, _ description.SelectedServer) ([]byte, error) { + // If maxAwaitTime > remaining timeoutMS - minRoundTripTime, then use + // send remaining TimeoutMS - minRoundTripTime allowing the server an + // opportunity to respond with an empty batch. + var maxTimeMS int64 + if bc.maxAwaitTime != nil { + _, ctxDeadlineSet := ctx.Deadline() + + if ctxDeadlineSet { + rttMonitor := bc.Server().RTTMonitor() + + var ok bool + maxTimeMS, ok = driverutil.CalculateMaxTimeMS(ctx, rttMonitor.Min()) + if !ok && maxTimeMS <= 0 { + return nil, fmt.Errorf( + "calculated server-side timeout (%v ms) is less than or equal to 0 (%v): %w", + maxTimeMS, + rttMonitor.Stats(), + ErrDeadlineWouldBeExceeded) + } + } + + if !ctxDeadlineSet || bc.maxAwaitTime.Milliseconds() < maxTimeMS { + maxTimeMS = bc.maxAwaitTime.Milliseconds() + } + } + dst = bsoncore.AppendInt64Element(dst, "getMore", bc.id) dst = bsoncore.AppendStringElement(dst, "collection", bc.collection) if numToReturn > 0 { dst = bsoncore.AppendInt32Element(dst, "batchSize", numToReturn) } - if bc.maxAwaitTime != nil && *bc.maxAwaitTime > 0 { - dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(*bc.maxAwaitTime)/int64(time.Millisecond)) + if maxTimeMS > 0 { + dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", maxTimeMS) } comment, err := codecutil.MarshalValue(bc.comment, bc.encoderFn) diff --git a/x/mongo/driver/operation.go b/x/mongo/driver/operation.go index 2597a5de66..50136456e4 100644 --- a/x/mongo/driver/operation.go +++ b/x/mongo/driver/operation.go @@ -1724,34 +1724,16 @@ func (op Operation) calculateMaxTimeMS(ctx context.Context, rttMin time.Duration return 0, nil } - deadline, ok := ctx.Deadline() - if !ok { - return 0, nil - } - - remainingTimeout := time.Until(deadline) - - // Always round up to the next millisecond value so we never truncate the calculated - // maxTimeMS value (e.g. 400 microseconds evaluates to 1ms, not 0ms). - maxTimeMS := int64((remainingTimeout - rttMin + time.Millisecond - 1) / time.Millisecond) - if maxTimeMS <= 0 { + // Calculate maxTimeMS value to potentially be appended to the wire message. + maxTimeMS, ok := driverutil.CalculateMaxTimeMS(ctx, rttMin) + if !ok && maxTimeMS <= 0 { return 0, fmt.Errorf( - "remaining time %v until context deadline is less than or equal to min network round-trip time %v (%v): %w", - remainingTimeout, - rttMin, + "calculated server-side timeout (%v ms) is less than or equal to 0 (%v): %w", + maxTimeMS, rttStats, ErrDeadlineWouldBeExceeded) } - // The server will return a "BadValue" error if maxTimeMS is greater - // than the maximum positive int32 value (about 24.9 days). If the - // user specified a timeout value greater than that, omit maxTimeMS - // and let the client-side timeout handle cancelling the op if the - // timeout is ever reached. - if maxTimeMS > math.MaxInt32 { - return 0, nil - } - return maxTimeMS, nil }