Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improve async sharding #977

Merged
merged 6 commits (base and head branch names were not captured)
Jun 11, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Next commit
use semaphore to control shard concurrency
  • Loading branch information
daniel-sanche committed Jun 4, 2024
commit 376b38438aaaf894c957740c0b9464e66f4184da
58 changes: 29 additions & 29 deletions google/cloud/bigtable/data/_async/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,43 +740,43 @@ async def read_rows_sharded(
"""
if not sharded_query:
raise ValueError("empty sharded_query")
# reduce operation_timeout between batches
operation_timeout, attempt_timeout = _get_timeouts(
operation_timeout, attempt_timeout, self
)
timeout_generator = _attempt_timeout_generator(
# make sure each rpc stays within overall operation timeout
rpc_timeout_generator = _attempt_timeout_generator(
operation_timeout, operation_timeout
)
# submit shards in batches if the number of shards goes over _CONCURRENCY_LIMIT
batched_queries = [
sharded_query[i : i + _CONCURRENCY_LIMIT]
for i in range(0, len(sharded_query), _CONCURRENCY_LIMIT)
]
# run batches and collect results
results_list = []
error_dict = {}
shard_idx = 0
for batch in batched_queries:
batch_operation_timeout = next(timeout_generator)
routine_list = [
self.read_rows(

# limit the number of concurrent requests using a semaphore
concurrency_sem = asyncio.Semaphore(_CONCURRENCY_LIMIT)
async def read_rows_with_semaphore(query):
async with concurrency_sem:
# calculate new timeout based on time left in overall operation
shard_timeout = next(rpc_timeout_generator)
return await self.read_rows(
query,
operation_timeout=batch_operation_timeout,
attempt_timeout=min(attempt_timeout, batch_operation_timeout),
operation_timeout=shard_timeout,
attempt_timeout=min(attempt_timeout, shard_timeout),
retryable_errors=retryable_errors,
)
for query in batch
]
batch_result = await asyncio.gather(*routine_list, return_exceptions=True)
for result in batch_result:
if isinstance(result, Exception):
error_dict[shard_idx] = result
elif isinstance(result, BaseException):
# BaseException not expected; raise immediately
raise result
else:
results_list.extend(result)
shard_idx += 1

routine_list = [read_rows_with_semaphore(query) for query in sharded_query]
batch_result = await asyncio.gather(*routine_list, return_exceptions=True)

# collect results and errors
error_dict = {}
shard_idx = 0
results_list = []
for result in batch_result:
if isinstance(result, Exception):
error_dict[shard_idx] = result
elif isinstance(result, BaseException):
# BaseException not expected; raise immediately
raise result
else:
results_list.extend(result)
shard_idx += 1
if error_dict:
# if any sub-request failed, raise an exception instead of returning results
raise ShardedReadRowsExceptionGroup(
Expand Down
Loading