diff --git a/boto3_helpers/s3.py b/boto3_helpers/s3.py index 6e01872..3aa4982 100644 --- a/boto3_helpers/s3.py +++ b/boto3_helpers/s3.py @@ -2,7 +2,6 @@ from boto3 import client as boto3_client - SELECT_FORMATS = { 'json': {'JSON': {'Type': 'DOCUMENT'}}, 'json.gz': {'JSON': {'Type': 'DOCUMENT'}, 'CompressionType': 'GZIP'}, diff --git a/boto3_helpers/signed_requests.py b/boto3_helpers/signed_requests.py index 6f0038e..bf4aff4 100644 --- a/boto3_helpers/signed_requests.py +++ b/boto3_helpers/signed_requests.py @@ -15,7 +15,14 @@ def __init__(self, status_code, content): def sigv4_request( - service, method, endpoint, client=None, base_url=None, operation_name=None, **kwargs + service, + method, + endpoint, + client=None, + base_url=None, + operation_name=None, + decode_json=True, + **kwargs, ): """Make a signed request to the AWS API and return the JSON payload. @@ -28,6 +35,8 @@ def sigv4_request( * *base_url* is the URL for the target AWS API. If not given, a guess will be made based on the service name and client region. * *operation_name* is the name of the API operation to use when signing the request + * *decode_json* controls whether responses are automatically deserialized from + JSON (default: ``True``). * **kwargs** are passed on to an ``AWSRequest`` object. If the API response indicates an error, @@ -83,4 +92,7 @@ def sigv4_request( if not (200 <= resp.status_code <= 299): raise SigV4RequestException(resp.status_code, resp.content) - return loads(resp.content) + if decode_json: + return loads(resp.content) + else: + return resp.content diff --git a/tests/test_signed_requests.py b/tests/test_signed_requests.py index bde3c4f..0a8dcef 100644 --- a/tests/test_signed_requests.py +++ b/tests/test_signed_requests.py @@ -39,6 +39,40 @@ def test_call_succeeds(self): _client._endpoint.http_session.send.assert_called_once_with(sign_call[0][1]) + def test_no_json_succeeds(self): + _client = MagicMock() + _client.meta.region_name = 'test-region-1' + _client._endpoint.http_session.send.return_value = MagicMock( + status_code=200, content=b'Hello world' + ) + + service = 'scheduler' + method = 'POST' + endpoint = '/schedules?MaxResults=1' + operation_name = 'ListSchedules' + actual = sigv4_request( + service, + method, + endpoint, + client=_client, + operation_name=operation_name, + decode_json=False, + data='{"test": "payload"}', + ) + self.assertEqual(actual, b'Hello world') + + sign_call = _client._request_signer.sign.call_args + self.assertEqual(sign_call[0][0], operation_name) + self.assertEqual(sign_call[0][1].method, method) + self.assertEqual(sign_call[0][1].data, '{"test": "payload"}') + self.assertEqual( + sign_call[0][1].url, + 'https://scheduler.test-region-1.amazonaws.com/schedules?MaxResults=1', + ) + self.assertEqual(sign_call[1], {'signing_name': service}) + + _client._endpoint.http_session.send.assert_called_once_with(sign_call[0][1]) + def test_call_fails(self): _client = MagicMock() _client.meta.region_name = 'test-region-1'