From 376b38438aaaf894c957740c0b9464e66f4184da Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Tue, 4 Jun 2024 16:36:18 -0700
Subject: [PATCH 1/6] use semaphore to control shard concurrency

---
 google/cloud/bigtable/data/_async/client.py | 58 ++++++++++-----------
 1 file changed, 29 insertions(+), 29 deletions(-)

diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py
index 7d75fab00..46b8089af 100644
--- a/google/cloud/bigtable/data/_async/client.py
+++ b/google/cloud/bigtable/data/_async/client.py
@@ -740,43 +740,43 @@ async def read_rows_sharded(
         """
         if not sharded_query:
             raise ValueError("empty sharded_query")
-        # reduce operation_timeout between batches
         operation_timeout, attempt_timeout = _get_timeouts(
             operation_timeout, attempt_timeout, self
         )
-        timeout_generator = _attempt_timeout_generator(
+        # make sure each rpc stays within overall operation timeout
+        rpc_timeout_generator = _attempt_timeout_generator(
             operation_timeout, operation_timeout
         )
-        # submit shards in batches if the number of shards goes over _CONCURRENCY_LIMIT
-        batched_queries = [
-            sharded_query[i : i + _CONCURRENCY_LIMIT]
-            for i in range(0, len(sharded_query), _CONCURRENCY_LIMIT)
-        ]
-        # run batches and collect results
-        results_list = []
-        error_dict = {}
-        shard_idx = 0
-        for batch in batched_queries:
-            batch_operation_timeout = next(timeout_generator)
-            routine_list = [
-                self.read_rows(
+
+        # limit the number of concurrent requests using a semaphore
+        concurrency_sem = asyncio.Semaphore(_CONCURRENCY_LIMIT)
+        async def read_rows_with_semaphore(query):
+            async with concurrency_sem:
+                # calculate new timeout based on time left in overall operation
+                shard_timeout = next(rpc_timeout_generator)
+                return await self.read_rows(
                     query,
-                    operation_timeout=batch_operation_timeout,
-                    attempt_timeout=min(attempt_timeout, batch_operation_timeout),
+                    operation_timeout=shard_timeout,
+                    attempt_timeout=min(attempt_timeout, shard_timeout),
                     retryable_errors=retryable_errors,
                 )
-                for query in batch
-            ]
-            batch_result = await asyncio.gather(*routine_list, return_exceptions=True)
-            for result in batch_result:
-                if isinstance(result, Exception):
-                    error_dict[shard_idx] = result
-                elif isinstance(result, BaseException):
-                    # BaseException not expected; raise immediately
-                    raise result
-                else:
-                    results_list.extend(result)
-                shard_idx += 1
+
+        routine_list = [read_rows_with_semaphore(query) for query in sharded_query]
+        batch_result = await asyncio.gather(*routine_list, return_exceptions=True)
+
+        # collect results and errors
+        error_dict = {}
+        shard_idx = 0
+        results_list = []
+        for result in batch_result:
+            if isinstance(result, Exception):
+                error_dict[shard_idx] = result
+            elif isinstance(result, BaseException):
+                # BaseException not expected; raise immediately
+                raise result
+            else:
+                results_list.extend(result)
+            shard_idx += 1
         if error_dict:
             # if any sub-request failed, raise an exception instead of returning results
             raise ShardedReadRowsExceptionGroup(
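The core change in this patch: instead of slicing sharded_query into fixed batches of _CONCURRENCY_LIMIT and gathering each batch to completion before starting the next, every shard gets its own coroutine and an asyncio.Semaphore admits at most _CONCURRENCY_LIMIT of them at a time, so a queued shard starts the moment any running shard finishes. A minimal self-contained sketch of the same pattern (fetch_shard and the timings are illustrative stand-ins, not part of the patch):

    import asyncio

    CONCURRENCY_LIMIT = 10  # mirrors _CONCURRENCY_LIMIT in the patch

    async def fetch_shard(shard_id: int) -> list[int]:
        # stand-in for self.read_rows(query, ...)
        await asyncio.sleep(0.01)
        return [shard_id]

    async def run_sharded(num_shards: int) -> list[int]:
        sem = asyncio.Semaphore(CONCURRENCY_LIMIT)

        async def bounded(shard_id: int) -> list[int]:
            # at most CONCURRENCY_LIMIT coroutines run past this point;
            # the rest wait here until a running one releases the semaphore
            async with sem:
                return await fetch_shard(shard_id)

        # return_exceptions=True keeps one failed shard from cancelling
        # the others; failures come back as values in the result list
        results = await asyncio.gather(
            *(bounded(i) for i in range(num_shards)), return_exceptions=True
        )
        combined: list[int] = []
        for result in results:
            if isinstance(result, BaseException):
                raise result
            combined.extend(result)
        return combined

    print(asyncio.run(run_sharded(15)))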
From 0dbecea45a70b5ff6c95913fb36a864023c27e50 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Wed, 5 Jun 2024 11:53:04 -0700
Subject: [PATCH 2/6] fixed lint

---
 google/cloud/bigtable/data/_async/client.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py
index 46b8089af..4442e5854 100644
--- a/google/cloud/bigtable/data/_async/client.py
+++ b/google/cloud/bigtable/data/_async/client.py
@@ -750,6 +750,7 @@ async def read_rows_sharded(
 
         # limit the number of concurrent requests using a semaphore
        concurrency_sem = asyncio.Semaphore(_CONCURRENCY_LIMIT)
+
         async def read_rows_with_semaphore(query):
             async with concurrency_sem:
                 # calculate new timeout based on time left in overall operation

From 48f8bc232f6d5e21ebe5034844afafd9f5e6a958 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Wed, 5 Jun 2024 12:43:18 -0700
Subject: [PATCH 3/6] fixed test

---
 tests/unit/data/_async/test_client.py | 89 ++++++++++++---------
 1 file changed, 40 insertions(+), 49 deletions(-)

diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py
index 7593572d8..7e4bd67da 100644
--- a/tests/unit/data/_async/test_client.py
+++ b/tests/unit/data/_async/test_client.py
@@ -1927,62 +1927,53 @@ async def mock_call(*args, **kwargs):
         assert call_time < 0.2
 
     @pytest.mark.asyncio
-    async def test_read_rows_sharded_batching(self):
+    async def test_read_rows_sharded_concurrency_limit(self):
         """
-        Large queries should be processed in batches to limit concurrency
-        operation timeout should change between batches
+        Only 10 queries should be processed concurrently. Others should be queued
+
+        Should start a new query as soon as a previous one finishes
         """
-        from google.cloud.bigtable.data._async.client import TableAsync
         from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT
 
         assert _CONCURRENCY_LIMIT == 10  # change this test if this changes
+        num_queries = 15
 
-        n_queries = 90
-        expected_num_batches = n_queries // _CONCURRENCY_LIMIT
-        query_list = [ReadRowsQuery() for _ in range(n_queries)]
-
-        table_mock = AsyncMock()
-        start_operation_timeout = 10
-        start_attempt_timeout = 3
-        table_mock.default_read_rows_operation_timeout = start_operation_timeout
-        table_mock.default_read_rows_attempt_timeout = start_attempt_timeout
-        # clock ticks one second on each check
-        with mock.patch("time.monotonic", side_effect=range(0, 100000)):
-            with mock.patch("asyncio.gather", AsyncMock()) as gather_mock:
-                await TableAsync.read_rows_sharded(table_mock, query_list)
-                # should have individual calls for each query
-                assert table_mock.read_rows.call_count == n_queries
-                # should have single gather call for each batch
-                assert gather_mock.call_count == expected_num_batches
-                # ensure that timeouts decrease over time
-                kwargs = [
-                    table_mock.read_rows.call_args_list[idx][1]
-                    for idx in range(n_queries)
-                ]
-                for batch_idx in range(expected_num_batches):
-                    batch_kwargs = kwargs[
-                        batch_idx
-                        * _CONCURRENCY_LIMIT : (batch_idx + 1)
-                        * _CONCURRENCY_LIMIT
+        # each of the first 10 queries takes longer than the one before it
+        # later rpcs will have to wait on the first 10
+        increment_time = 0.05
+        max_time = increment_time * (_CONCURRENCY_LIMIT - 1)
+        rpc_times = [min(i * increment_time, max_time) for i in range(num_queries)]
+
+        async def mock_call(*args, **kwargs):
+            next_sleep = rpc_times.pop(0)
+            await asyncio.sleep(next_sleep)
+            return [mock.Mock()]
+
+        starting_timeout = 10
+
+        async with _make_client() as client:
+            async with client.get_table("instance", "table") as table:
+                with mock.patch.object(table, "read_rows") as read_rows:
+                    read_rows.side_effect = mock_call
+                    queries = [ReadRowsQuery() for _ in range(num_queries)]
+                    await table.read_rows_sharded(
+                        queries, operation_timeout=starting_timeout
+                    )
+                    assert read_rows.call_count == num_queries
+                    # check operation timeouts to see how far into the operation each rpc started
+                    rpc_start_list = [
+                        starting_timeout - kwargs["operation_timeout"]
+                        for _, kwargs in read_rows.call_args_list
                     ]
-                    for req_kwargs in batch_kwargs:
-                        # each batch should have the same operation_timeout, and it should decrease in each batch
-                        expected_operation_timeout = start_operation_timeout - (
-                            batch_idx + 1
-                        )
-                        assert (
-                            req_kwargs["operation_timeout"]
-                            == expected_operation_timeout
-                        )
-                        # each attempt_timeout should start with default value, but decrease when operation_timeout reaches it
-                        expected_attempt_timeout = min(
-                            start_attempt_timeout, expected_operation_timeout
-                        )
-                        assert req_kwargs["attempt_timeout"] == expected_attempt_timeout
-                # await all created coroutines to avoid warnings
-                for i in range(len(gather_mock.call_args_list)):
-                    for j in range(len(gather_mock.call_args_list[i][0])):
-                        await gather_mock.call_args_list[i][0][j]
+                    eps = 0.01
+                    # first 10 should start immediately
+                    assert all(
+                        rpc_start_list[i] < eps for i in range(_CONCURRENCY_LIMIT)
+                    )
+                    # next rpcs should start as first ones finish
+                    for i in range(num_queries - _CONCURRENCY_LIMIT):
+                        idx = i + _CONCURRENCY_LIMIT
+                        assert rpc_start_list[idx] - (i * increment_time) < eps
 
 
 class TestSampleRowKeys:
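The rewritten test leans on one property of the timeout plumbing: each shard's operation_timeout is the time remaining before the overall deadline, so starting_timeout - operation_timeout recovers how far into the operation that rpc began. A rough sketch of a deadline-aware generator with that behavior (a simplification for illustration; the real _attempt_timeout_generator lives in the library's helpers and may differ in detail):

    import time
    from typing import Iterator

    def remaining_timeout_generator(cap: float, overall_timeout: float) -> Iterator[float]:
        # each value is the time left before the overall deadline, clamped
        # to a per-request cap; it can go negative once the deadline passes
        deadline = time.monotonic() + overall_timeout
        while True:
            yield min(cap, deadline - time.monotonic())

Calling it as remaining_timeout_generator(operation_timeout, operation_timeout), as the patch does, gives the first shard roughly the full budget and later shards whatever is left.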
From f6b187dd99fc7585b2a4bf7736cec0ec293563d0 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Fri, 7 Jun 2024 13:55:36 -0700
Subject: [PATCH 4/6] added test

---
 tests/unit/data/_async/test_client.py | 40 +++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py
index 7e4bd67da..6767b2505 100644
--- a/tests/unit/data/_async/test_client.py
+++ b/tests/unit/data/_async/test_client.py
@@ -1975,6 +1975,46 @@ async def mock_call(*args, **kwargs):
                         idx = i + _CONCURRENCY_LIMIT
                         assert rpc_start_list[idx] - (i * increment_time) < eps
 
+    @pytest.mark.asyncio
+    async def test_read_rows_sharded_expiry(self):
+        """
+        If the operation times out before all shards complete, should raise
+        a ShardedReadRowsExceptionGroup
+        """
+        from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT
+        from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
+        from google.api_core.exceptions import DeadlineExceeded
+
+        operation_timeout = 0.1
+
+        # let the first group of shards complete, but the rest time out
+        num_queries = 15
+        sleeps = [0] * _CONCURRENCY_LIMIT + [DeadlineExceeded("time's up")] * (
+            num_queries - _CONCURRENCY_LIMIT
+        )
+
+        async def mock_call(*args, **kwargs):
+            next_item = sleeps.pop(0)
+            if isinstance(next_item, Exception):
+                raise next_item
+            else:
+                await asyncio.sleep(next_item)
+            return [mock.Mock()]
+
+        async with _make_client() as client:
+            async with client.get_table("instance", "table") as table:
+                with mock.patch.object(table, "read_rows") as read_rows:
+                    read_rows.side_effect = mock_call
+                    queries = [ReadRowsQuery() for _ in range(num_queries)]
+                    with pytest.raises(ShardedReadRowsExceptionGroup) as exc:
+                        await table.read_rows_sharded(
+                            queries, operation_timeout=operation_timeout
+                        )
+                    assert isinstance(exc.value, ShardedReadRowsExceptionGroup)
+                    assert len(exc.value.exceptions) == num_queries - _CONCURRENCY_LIMIT
+                    # should keep successful queries
+                    assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT
+
 
 class TestSampleRowKeys:
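The assertions above pin down the contract this series preserves: when some shards fail, ShardedReadRowsExceptionGroup still carries the rows from shards that completed. A caller can lean on that to salvage partial results; a sketch of caller-side handling under that assumption (read_with_partial_results is an illustrative helper, not part of the library):

    from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup

    async def read_with_partial_results(table, shard_queries):
        try:
            return await table.read_rows_sharded(shard_queries)
        except ShardedReadRowsExceptionGroup as group:
            # group.exceptions holds one entry per failed shard; rows from
            # shards that succeeded are kept in successful_rows
            for shard_error in group.exceptions:
                print(f"shard failed: {shard_error!r}")
            return group.successful_rows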
From 73f5852b489006003b2a86569489319315881032 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Mon, 10 Jun 2024 12:30:50 -0700
Subject: [PATCH 5/6] added test for timeout before running shard

---
 google/cloud/bigtable/data/_async/client.py |  4 +++
 tests/unit/data/_async/test_client.py       | 29 +++++++++++++++++++++
 2 files changed, 33 insertions(+)

diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py
index 4442e5854..74f1e9837 100644
--- a/google/cloud/bigtable/data/_async/client.py
+++ b/google/cloud/bigtable/data/_async/client.py
@@ -755,6 +755,10 @@ async def read_rows_with_semaphore(query):
             async with concurrency_sem:
                 # calculate new timeout based on time left in overall operation
                 shard_timeout = next(rpc_timeout_generator)
+                if shard_timeout <= 0:
+                    raise DeadlineExceeded(
+                        "Operation timeout exceeded before starting query"
+                    )
                 return await self.read_rows(
                     query,
                     operation_timeout=shard_timeout,
diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py
index 6767b2505..68c590cb7 100644
--- a/tests/unit/data/_async/test_client.py
+++ b/tests/unit/data/_async/test_client.py
@@ -2015,6 +2015,35 @@ async def mock_call(*args, **kwargs):
                     # should keep successful queries
                     assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT
 
+    @pytest.mark.asyncio
+    async def test_read_rows_sharded_negative_batch_timeout(self):
+        """
+        Try to run shards that would start after the operation timeout has expired
+
+        They should raise DeadlineExceeded errors
+        """
+        import time
+        from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
+        from google.api_core.exceptions import DeadlineExceeded
+
+        async def mock_call(*args, **kwargs):
+            await asyncio.sleep(0.05)
+            return [mock.Mock()]
+
+        async with _make_client() as client:
+            async with client.get_table("instance", "table") as table:
+                with mock.patch.object(table, "read_rows") as read_rows:
+                    read_rows.side_effect = mock_call
+                    queries = [ReadRowsQuery() for _ in range(15)]
+                    with pytest.raises(ShardedReadRowsExceptionGroup) as exc:
+                        await table.read_rows_sharded(queries, operation_timeout=0.01)
+                    assert isinstance(exc.value, ShardedReadRowsExceptionGroup)
+                    assert len(exc.value.exceptions) == 5
+                    assert all(
+                        isinstance(e.__cause__, DeadlineExceeded)
+                        for e in exc.value.exceptions
+                    )
+
 
 class TestSampleRowKeys:

From e65e49f0e301af316a3af46a8126255d772cda54 Mon Sep 17 00:00:00 2001
From: Daniel Sanche
Date: Mon, 10 Jun 2024 12:47:10 -0700
Subject: [PATCH 6/6] removed import

---
 tests/unit/data/_async/test_client.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py
index 68c590cb7..9ebc403ce 100644
--- a/tests/unit/data/_async/test_client.py
+++ b/tests/unit/data/_async/test_client.py
@@ -2022,7 +2022,6 @@ async def test_read_rows_sharded_negative_batch_timeout(self):
 
         They should raise DeadlineExceeded errors
         """
-        import time
         from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
         from google.api_core.exceptions import DeadlineExceeded
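For reference, patches 1, 2, and 5 together leave the bounded worker looking roughly like this (reassembled here for readability, not an additional change):

    async def read_rows_with_semaphore(query):
        async with concurrency_sem:
            # calculate new timeout based on time left in overall operation
            shard_timeout = next(rpc_timeout_generator)
            if shard_timeout <= 0:
                raise DeadlineExceeded(
                    "Operation timeout exceeded before starting query"
                )
            return await self.read_rows(
                query,
                operation_timeout=shard_timeout,
                attempt_timeout=min(attempt_timeout, shard_timeout),
                retryable_errors=retryable_errors,
            )

The guard is what the negative-timeout test exercises: with 15 shards, a 0.01 s budget, and mocked rpcs that each hold the semaphore for 0.05 s, the 10 admitted shards outlive the deadline, so the 5 queued shards draw a non-positive shard_timeout and fail with DeadlineExceeded, giving the 5 entries asserted in the exception group.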