Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion exasol/saas/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
)
from typing import Final

from exasol.saas.client.openapi.models.status import Status
from exasol.saas.client.openapi_facade_types import Status

SAAS_HOST = "https://cloud.exasol.com"

Expand Down
146 changes: 45 additions & 101 deletions exasol/saas/client/api_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,17 @@
wait_fixed,
)

from exasol.saas.client import (
Limits,
openapi,
)
from exasol.saas.client.openapi.api.clusters import (
get_cluster_connection,
list_clusters,
)
from exasol.saas.client.openapi.api.databases import (
create_database,
delete_database,
get_database,
list_databases,
)
from exasol.saas.client.openapi.api.security import (
add_allowed_ip,
delete_allowed_ip,
list_allowed_i_ps,
)
from exasol.saas.client.openapi.models import (
from exasol.saas.client import Limits
from exasol.saas.client.openapi_facade import (
AllowedIP,
ApiError,
AuthenticatedClient,
Cluster,
ClusterConnection,
ExasolDatabase,
Status,
)
from exasol.saas.client.openapi.types import UNSET
from exasol.saas.client.openapi_facade import OpenApiFacade as OpenApiClient

LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
Expand Down Expand Up @@ -138,34 +124,23 @@ def create_saas_client(
host: str,
pat: str,
raise_on_unexpected_status: bool = True,
) -> openapi.AuthenticatedClient:
return openapi.AuthenticatedClient(
base_url=host,
token=pat,
) -> AuthenticatedClient:
return OpenApiClient(
host=host,
pat=pat,
raise_on_unexpected_status=raise_on_unexpected_status,
)


def _get_database_id(
account_id: str,
client: openapi.AuthenticatedClient,
client: AuthenticatedClient,
database_name: str,
) -> str:
"""
Finds the database id, given the database name.
"""
dbs = list_databases.sync(account_id, client=client)
dbs = list(
filter(
lambda db: (db.name == database_name) # type: ignore
and (db.deleted_at is UNSET) # type: ignore
and (db.deleted_by is UNSET),
dbs, # type: ignore
)
) # type: ignore
if not dbs:
raise RuntimeError(f"SaaS database {database_name} was not found.")
return dbs[0].id
return client.find_database_id(account_id, database_name)


def get_database_id(
Expand Down Expand Up @@ -223,15 +198,18 @@ def get_connection_params(
database_id = _get_database_id(
account_id, client, database_name=database_name
)
clusters = list_clusters.sync(account_id, database_id, client=client)
clusters = ensure_type(
list,
client.list_clusters(account_id, database_id),
"Failed to list clusters of "
f"host {host}, account {account_id}, database {database_id}",
)
cluster_id = next(
filter(lambda cl: cl.main_cluster, clusters) # type: ignore
).id
resp = get_cluster_connection.sync(
account_id, database_id, cluster_id, client=client
)
resp = client.get_cluster_connection(account_id, database_id, cluster_id)
connection = ensure_type(
openapi.models.ClusterConnection,
ClusterConnection,
resp,
"Failed to get the connection data to"
f" host {host}, account {account_id},"
Expand All @@ -251,7 +229,7 @@ class OpenApiAccess:
planned to only use fixture ``saas_database_id()``.
"""

def __init__(self, client: openapi.AuthenticatedClient, account_id: str):
def __init__(self, client: AuthenticatedClient, account_id: str):
self._client = client
self._account_id = account_id

Expand All @@ -262,29 +240,14 @@ def create_database(
region: str = "eu-central-1",
idle_time: timedelta | None = None,
) -> ExasolDatabase | None:
def minutes(x: timedelta) -> int:
return x.seconds // 60

idle_time = idle_time or Limits.AUTOSTOP_MIN_IDLE_TIME
cluster_spec = openapi.models.CreateDatabaseInitialCluster(
name="my-cluster",
size=cluster_size,
auto_stop=openapi.models.AutoStop(
enabled=True,
idle_time=minutes(idle_time),
),
)
LOG.info("Creating database %s", name)
resp = create_database.sync(
resp = self._client.create_database(
self._account_id,
client=self._client,
body=openapi.models.CreateDatabase(
name=name,
initial_cluster=cluster_spec,
provider="aws",
region=region,
stream_type="innovation-release",
),
name=name,
cluster_size=cluster_size,
region=region,
idle_time=idle_time,
)
database = ensure_type(
ExasolDatabase, resp, f"Failed to create database {name}"
Expand Down Expand Up @@ -341,10 +304,9 @@ def is_retry(resp: ApiError) -> bool:
)
def delete_with_retry() -> None:
LOG.info("- Trying to delete ...")
resp = delete_database.sync(
resp = self._client.delete_database(
self._account_id,
database_id,
client=self._client,
)
if not isinstance(resp, ApiError):
# success
Expand All @@ -366,7 +328,7 @@ def delete_with_retry() -> None:
raise DatabaseDeleteError(msg) from ex

def list_database_ids(self) -> Iterable[str]:
resp = list_databases.sync(self._account_id, client=self._client) or []
resp = self._client.list_databases(self._account_id) or []
# actually list[ExasolDatabase]
dbs = ensure_type(list, resp, "Failed to list databases")
return (db.id for db in dbs)
Expand Down Expand Up @@ -397,11 +359,7 @@ def get_database(
self,
database_id: str,
) -> ExasolDatabase | None:
resp = get_database.sync(
self._account_id,
database_id,
client=self._client,
)
resp = self._client.get_database(self._account_id, database_id)
return ensure_type(
ExasolDatabase, resp, f"Failed to get database {database_id}"
)
Expand Down Expand Up @@ -430,16 +388,9 @@ def poll_status() -> Status:
def clusters(
self,
database_id: str,
) -> list[openapi.models.Cluster] | None:
resp = (
list_clusters.sync(
self._account_id,
database_id,
client=self._client,
)
or []
)
# actually list[openapi.models.Cluster]
) -> list[Cluster] | None:
resp = self._client.list_clusters(self._account_id, database_id) or []
# actually list[Cluster]
return ensure_type(
list, resp, f"Failed to list clusters of database {database_id}"
)
Expand All @@ -448,54 +399,47 @@ def get_connection(
self,
database_id: str,
cluster_id: str,
) -> openapi.models.ClusterConnection | None:
resp = get_cluster_connection.sync(
self._account_id,
database_id,
cluster_id,
client=self._client,
) -> ClusterConnection | None:
resp = self._client.get_cluster_connection(
self._account_id, database_id, cluster_id
)
return ensure_type(
openapi.models.ClusterConnection,
ClusterConnection,
resp,
"Failed to retrieve a connection to "
f"database {database_id} cluster {cluster_id}",
)

def list_allowed_ip_ids(self) -> Iterable[str]:
resp = list_allowed_i_ps.sync(self._account_id, client=self._client) or []
# actually list[openapi.models.AllowedIP]
resp = self._client.list_allowed_ips(self._account_id) or []
# actually list[AllowedIP]
ips = ensure_type(list, resp, "Failed to retrieve the list of allowed ips")
return (x.id for x in ips)

def add_allowed_ip(
self,
cidr_ip: str = "0.0.0.0/0",
) -> openapi.models.AllowedIP | None:
) -> AllowedIP | None:
"""
Suggested values for cidr_ip:
* 185.17.207.78/32
* 0.0.0.0/0 = all ipv4
* ::/0 = all ipv6
"""
rule = openapi.models.CreateAllowedIP(
name=timestamp_name(),
cidr_ip=cidr_ip,
)
resp = add_allowed_ip.sync(
resp = self._client.add_allowed_ip(
self._account_id,
client=self._client,
body=rule,
cidr_ip=cidr_ip,
name=timestamp_name(),
)
return ensure_type(
openapi.models.AllowedIP,
AllowedIP,
resp,
f"Failed to add allowed IP address {cidr_ip}",
)

def delete_allowed_ip(self, id: str, ignore_failures=False) -> Any | None:
with self._ignore_failures(ignore_failures) as client:
return delete_allowed_ip.sync(self._account_id, id, client=client)
return client.delete_allowed_ip(self._account_id, id)

@contextmanager
def allowed_ip(
Expand Down
Loading
Loading