Skip to content

Commit

Permalink
fix: Catch rst stream error for all transactions (#934)
Browse files Browse the repository at this point in the history
* fix: rst retry for txn

* rst changes and tests

* fix

* rst stream comment changes

* lint

* lint
  • Loading branch information
asthamohta committed May 24, 2023
1 parent c53f273 commit d317d2e
Show file tree
Hide file tree
Showing 8 changed files with 268 additions and 14 deletions.
54 changes: 54 additions & 0 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import datetime
import decimal
import math
import time

from google.protobuf.struct_pb2 import ListValue
from google.protobuf.struct_pb2 import Value
Expand Down Expand Up @@ -294,6 +295,59 @@ def _metadata_with_prefix(prefix, **kw):
return [("google-cloud-resource-prefix", prefix)]


def _retry(
func,
retry_count=5,
delay=2,
allowed_exceptions=None,
):
"""
Retry a function with a specified number of retries, delay between retries, and list of allowed exceptions.
Args:
func: The function to be retried.
retry_count: The maximum number of times to retry the function.
delay: The delay in seconds between retries.
allowed_exceptions: A tuple of exceptions that are allowed to occur without triggering a retry.
Passing allowed_exceptions as None will lead to retrying for all exceptions.
Returns:
The result of the function if it is successful, or raises the last exception if all retries fail.
"""
retries = 0
while retries <= retry_count:
try:
return func()
except Exception as exc:
if (
allowed_exceptions is None or exc.__class__ in allowed_exceptions
) and retries < retry_count:
if (
allowed_exceptions is not None
and allowed_exceptions[exc.__class__] is not None
):
allowed_exceptions[exc.__class__](exc)
time.sleep(delay)
delay = delay * 2
retries = retries + 1
else:
raise exc


def _check_rst_stream_error(exc):
resumable_error = (
any(
resumable_message in exc.message
for resumable_message in (
"RST_STREAM",
"Received unexpected EOS on DATA frame from server",
)
),
)
if not resumable_error:
raise


def _metadata_with_leader_aware_routing(value, **kw):
"""Create RPC metadata containing a leader aware routing header
Expand Down
11 changes: 10 additions & 1 deletion google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Context manager for Cloud Spanner batched writes."""
import functools

from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import Mutation
Expand All @@ -26,6 +27,9 @@
)
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1._helpers import _retry
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
from google.api_core.exceptions import InternalServerError


class _BatchBase(_SessionWrapper):
Expand Down Expand Up @@ -186,10 +190,15 @@ def commit(self, return_commit_stats=False, request_options=None):
request_options=request_options,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
response = api.commit(
method = functools.partial(
api.commit,
request=request,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self.committed = response.commit_timestamp
self.commit_stats = response.commit_stats
return self.committed
Expand Down
29 changes: 23 additions & 6 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@
from google.api_core.exceptions import ServiceUnavailable
from google.api_core.exceptions import InvalidArgument
from google.api_core import gapic_v1
from google.cloud.spanner_v1._helpers import _make_value_pb
from google.cloud.spanner_v1._helpers import _merge_query_options
from google.cloud.spanner_v1._helpers import (
_make_value_pb,
_merge_query_options,
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
_retry,
_check_rst_stream_error,
_SessionWrapper,
)
from google.cloud.spanner_v1._helpers import _SessionWrapper
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1.streamed import StreamedResultSet
from google.cloud.spanner_v1 import RequestOptions
Expand Down Expand Up @@ -560,12 +562,17 @@ def partition_read(
with trace_call(
"CloudSpanner.PartitionReadOnlyTransaction", self._session, trace_attributes
):
response = api.partition_read(
method = functools.partial(
api.partition_read,
request=request,
metadata=metadata,
retry=retry,
timeout=timeout,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)

return [partition.partition_token for partition in response.partitions]

Expand Down Expand Up @@ -659,12 +666,17 @@ def partition_query(
self._session,
trace_attributes,
):
response = api.partition_query(
method = functools.partial(
api.partition_query,
request=request,
metadata=metadata,
retry=retry,
timeout=timeout,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)

return [partition.partition_token for partition in response.partitions]

Expand Down Expand Up @@ -791,10 +803,15 @@ def begin(self):
)
txn_selector = self._make_txn_selector()
with trace_call("CloudSpanner.BeginTransaction", self._session):
response = api.begin_transaction(
method = functools.partial(
api.begin_transaction,
session=self._session.name,
options=txn_selector.begin,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self._transaction_id = response.id
return self._transaction_id
34 changes: 29 additions & 5 deletions google/cloud/spanner_v1/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
_merge_query_options,
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
_retry,
_check_rst_stream_error,
)
from google.cloud.spanner_v1 import CommitRequest
from google.cloud.spanner_v1 import ExecuteBatchDmlRequest
Expand All @@ -33,6 +35,7 @@
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1 import RequestOptions
from google.api_core import gapic_v1
from google.api_core.exceptions import InternalServerError


class Transaction(_SnapshotBase, _BatchBase):
Expand Down Expand Up @@ -102,7 +105,11 @@ def _execute_request(
transaction = self._make_txn_selector()
request.transaction = transaction
with trace_call(trace_name, session, attributes):
response = method(request=request)
method = functools.partial(method, request=request)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)

return response

Expand Down Expand Up @@ -132,8 +139,15 @@ def begin(self):
)
txn_options = TransactionOptions(read_write=TransactionOptions.ReadWrite())
with trace_call("CloudSpanner.BeginTransaction", self._session):
response = api.begin_transaction(
session=self._session.name, options=txn_options, metadata=metadata
method = functools.partial(
api.begin_transaction,
session=self._session.name,
options=txn_options,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self._transaction_id = response.id
return self._transaction_id
Expand All @@ -153,11 +167,16 @@ def rollback(self):
)
)
with trace_call("CloudSpanner.Rollback", self._session):
api.rollback(
method = functools.partial(
api.rollback,
session=self._session.name,
transaction_id=self._transaction_id,
metadata=metadata,
)
_retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self.rolled_back = True
del self._session._transaction

Expand Down Expand Up @@ -212,10 +231,15 @@ def commit(self, return_commit_stats=False, request_options=None):
request_options=request_options,
)
with trace_call("CloudSpanner.Commit", self._session, trace_attributes):
response = api.commit(
method = functools.partial(
api.commit,
request=request,
metadata=metadata,
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)
self.committed = response.commit_timestamp
if return_commit_stats:
self.commit_stats = response.commit_stats
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/spanner_dbapi/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test__session_checkout(self, mock_database):
connection._session_checkout()
self.assertEqual(connection._session, "db_session")

def test__session_checkout_database_error(self):
def test_session_checkout_database_error(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)
Expand All @@ -191,7 +191,7 @@ def test__release_session(self, mock_database):
pool.put.assert_called_once_with("session")
self.assertIsNone(connection._session)

def test__release_session_database_error(self):
def test_release_session_database_error(self):
from google.cloud.spanner_dbapi import Connection

connection = Connection(INSTANCE)
Expand Down
78 changes: 78 additions & 0 deletions tests/unit/test__helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


import unittest
import mock


class Test_merge_query_options(unittest.TestCase):
Expand Down Expand Up @@ -671,6 +672,83 @@ def test(self):
self.assertEqual(metadata, [("google-cloud-resource-prefix", prefix)])


class Test_retry(unittest.TestCase):
class test_class:
def test_fxn(self):
return True

def test_retry_on_error(self):
from google.api_core.exceptions import InternalServerError, NotFound
from google.cloud.spanner_v1._helpers import _retry
import functools

test_api = mock.create_autospec(self.test_class)
test_api.test_fxn.side_effect = [
InternalServerError("testing"),
NotFound("testing"),
True,
]

_retry(functools.partial(test_api.test_fxn))

self.assertEqual(test_api.test_fxn.call_count, 3)

def test_retry_allowed_exceptions(self):
from google.api_core.exceptions import InternalServerError, NotFound
from google.cloud.spanner_v1._helpers import _retry
import functools

test_api = mock.create_autospec(self.test_class)
test_api.test_fxn.side_effect = [
NotFound("testing"),
InternalServerError("testing"),
True,
]

with self.assertRaises(InternalServerError):
_retry(
functools.partial(test_api.test_fxn),
allowed_exceptions={NotFound: None},
)

self.assertEqual(test_api.test_fxn.call_count, 2)

def test_retry_count(self):
from google.api_core.exceptions import InternalServerError
from google.cloud.spanner_v1._helpers import _retry
import functools

test_api = mock.create_autospec(self.test_class)
test_api.test_fxn.side_effect = [
InternalServerError("testing"),
InternalServerError("testing"),
]

with self.assertRaises(InternalServerError):
_retry(functools.partial(test_api.test_fxn), retry_count=1)

self.assertEqual(test_api.test_fxn.call_count, 2)

def test_check_rst_stream_error(self):
from google.api_core.exceptions import InternalServerError
from google.cloud.spanner_v1._helpers import _retry, _check_rst_stream_error
import functools

test_api = mock.create_autospec(self.test_class)
test_api.test_fxn.side_effect = [
InternalServerError("Received unexpected EOS on DATA frame from server"),
InternalServerError("RST_STREAM"),
True,
]

_retry(
functools.partial(test_api.test_fxn),
allowed_exceptions={InternalServerError: _check_rst_stream_error},
)

self.assertEqual(test_api.test_fxn.call_count, 3)


class Test_metadata_with_leader_aware_routing(unittest.TestCase):
def _call_fut(self, *args, **kw):
from google.cloud.spanner_v1._helpers import _metadata_with_leader_aware_routing
Expand Down
Loading

0 comments on commit d317d2e

Please sign in to comment.