Skip to content

Commit fafa55d

Browse files
tests: improve tests for client and config
1 parent 3d661e1 commit fafa55d

4 files changed

Lines changed: 149 additions & 73 deletions

File tree

tests/units/test_client.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,103 @@
1+
import pytest
2+
from unittest.mock import Mock, patch
13

4+
from tfe import client, config
5+
6+
7+
@pytest.fixture
8+
def test_config():
9+
return config.Config(
10+
address="https://app.terraform.io",
11+
token="test-token"
12+
)
13+
14+
15+
@pytest.fixture
16+
def mock_response():
17+
response = Mock()
18+
response.headers = {
19+
"TFP-API-Version": "2.5.0",
20+
"X-TFE-Version": "v202308-1",
21+
"TFP-AppName": "HCP Terraform"
22+
}
23+
response.raise_for_status.return_value = None
24+
return response
25+
26+
27+
class TestClient:
28+
@patch('requests.Session.get')
29+
def test_client_initialization(self, mock_get, test_config, mock_response):
30+
"""Test basic client setup works."""
31+
mock_get.return_value = mock_response
32+
33+
client_instance = client.Client(config=test_config)
34+
35+
assert client_instance.config.address == "https://app.terraform.io"
36+
assert client_instance.config.token == "test-token"
37+
assert client_instance.base_url == "https://app.terraform.io/api/v2/"
38+
assert client_instance.registry_base_url == "https://app.terraform.io/api/registry/"
39+
40+
@patch('requests.Session.get')
41+
def test_url_normalization(self, mock_get, mock_response):
42+
"""Test that paths get normalized with trailing slashes."""
43+
mock_get.return_value = mock_response
44+
45+
cfg = config.Config(
46+
address="https://example.com",
47+
token="test",
48+
base_path="/custom/api", # no trailing slash
49+
registry_base_path="/registry" # no trailing slash
50+
)
51+
52+
client_instance = client.Client(config=cfg)
53+
54+
assert client_instance.base_url == "https://example.com/custom/api/"
55+
assert client_instance.registry_base_url == "https://example.com/registry/"
56+
57+
@patch('requests.Session.get')
58+
def test_api_metadata_extraction(self, mock_get, test_config, mock_response):
59+
"""Test that API metadata gets extracted from response headers."""
60+
mock_get.return_value = mock_response
61+
62+
client_instance = client.Client(config=test_config)
63+
64+
assert client_instance.remote_api_version == "2.5.0"
65+
assert client_instance.remote_tfe_version == "v202308-1"
66+
assert client_instance.app_name == "HCP Terraform"
67+
68+
@patch('requests.Session.get')
69+
def test_cloud_vs_enterprise_detection(self, mock_get, test_config):
70+
"""Test detection between cloud and enterprise instances."""
71+
# Test HCP Terraform (cloud)
72+
cloud_response = Mock()
73+
cloud_response.headers = {"TFP-AppName": "HCP Terraform"}
74+
cloud_response.raise_for_status.return_value = None
75+
mock_get.return_value = cloud_response
76+
77+
cloud_client = client.Client(config=test_config)
78+
assert cloud_client.is_cloud() is True
79+
assert cloud_client.is_enterprise() is False
80+
81+
# Test Terraform Enterprise
82+
enterprise_response = Mock()
83+
enterprise_response.headers = {"TFP-AppName": "Terraform Enterprise"}
84+
enterprise_response.raise_for_status.return_value = None
85+
mock_get.return_value = enterprise_response
86+
87+
enterprise_client = client.Client(config=test_config)
88+
assert enterprise_client.is_cloud() is False
89+
assert enterprise_client.is_enterprise() is True
90+
91+
@patch('requests.Session.get')
92+
def test_fake_api_version_for_testing(self, mock_get, test_config, mock_response):
93+
"""Test the fake API version setter for testing scenarios."""
94+
mock_get.return_value = mock_response
95+
96+
client_instance = client.Client(config=test_config)
97+
98+
# Original version from mock
99+
assert client_instance.remote_api_version == "2.5.0"
100+
101+
# Set fake version
102+
client_instance.set_fake_remote_api_version("3.0.0")
103+
assert client_instance.remote_api_version == "3.0.0"

tests/units/test_config.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ def reset_environment(monkeypatch):
1010
monkeypatch.delenv("TFE_ADDRESS", raising=False)
1111
monkeypatch.delenv("TFE_TOKEN", raising=False)
1212
monkeypatch.delenv("TFE_HOST", raising=False)
13+
monkeypatch.setenv("TFE_TOKEN", "abc123")
1314
yield
1415

1516

@@ -34,17 +35,14 @@ def test_default_config(self, cfg):
3435
assert cfg.address == config.DEFAULT_ADDRESS
3536
assert cfg.base_path == config.DEFAULT_BASE_PATH
3637
assert cfg.registry_base_path == config.DEFAULT_REGISTRY_PATH
37-
assert cfg.token == ""
3838
assert isinstance(cfg.http_client, requests.Session)
3939
assert "User-Agent" in cfg.http_client.headers
40-
assert "Authorization" not in cfg.http_client.headers
4140
assert cfg.retry_log_hook is None
4241
assert cfg.retry_server_errors is False
4342

4443
def test_env_address_and_token(self, monkeypatch):
4544
"""Test that environment variables TFE_ADDRESS and TFE_TOKEN are read correctly."""
4645
monkeypatch.setenv("TFE_ADDRESS", "https://custom.tfe")
47-
monkeypatch.setenv("TFE_TOKEN", "abc123")
4846
cfg = config.Config()
4947
assert cfg.address == "https://custom.tfe"
5048
assert cfg.token == "abc123"
@@ -86,3 +84,14 @@ def test_custom_session(self, test_session):
8684
assert "User-Agent" in cfg.http_client.headers
8785
assert cfg.http_client.headers["User-Agent"] == "test"
8886
assert cfg.http_client.headers["Authorization"] == "Bearer test"
87+
88+
def test_validate_config(self, monkeypatch):
89+
"""Test that configuration validation works as expected."""
90+
with pytest.raises(ValueError, match="API token is required") as _:
91+
monkeypatch.setenv("TFE_TOKEN", "")
92+
cfg = config.Config(token="")
93+
94+
with pytest.raises(ValueError, match="Address must include protocol") as _:
95+
monkeypatch.setenv("TFE_TOKEN", "test-token")
96+
monkeypatch.setenv("TFE_ADDRESS", "test.foo.bar")
97+
cfg = config.Config()

tfe/client.py

Lines changed: 11 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
import logging
6-
from urllib.parse import urljoin, urlparse
6+
from urllib.parse import urljoin
77

88
from tfe.config import Config
99

@@ -23,82 +23,33 @@ class Client:
2323
"""
2424

2525
def __init__(self, config: Config | None = None) -> None:
26-
self.config = Config()
27-
28-
if config is not None:
29-
self._merge_config(config)
30-
31-
self._validate_config()
26+
self.config = config or Config()
3227
self._setup_urls()
3328

3429
self._api_version = ""
3530
self._tfe_version = ""
3631
self._app_name = ""
3732
self._fetch_api_metadata()
3833

39-
def _merge_config(self, config: Config) -> None:
40-
"""Merge provided config with defaults, preserving non-empty values."""
41-
if config.address:
42-
self.config.address = config.address
43-
if config.base_path:
44-
self.config.base_path = config.base_path
45-
if config.registry_base_path:
46-
self.config.registry_base_path = config.registry_base_path
47-
if config.token:
48-
self.config.token = config.token
49-
if config.headers:
50-
if self.config.headers is None:
51-
self.config.headers = {}
52-
self.config.headers.update(config.headers)
53-
if config.http_client:
54-
self.config.http_client = config.http_client
55-
if config.retry_log_hook:
56-
self.config.retry_log_hook = config.retry_log_hook
57-
self.config.retry_server_errors = config.retry_server_errors
58-
59-
def _validate_config(self) -> None:
60-
"""Validate required configuration."""
61-
if not self.config.token:
62-
raise TFEClientError("API token is required")
63-
64-
if not self.config.address:
65-
raise TFEClientError("API address is required")
66-
67-
# Type narrowing: after validation, we know these are not None
68-
assert self.config.address is not None
69-
assert self.config.token is not None
70-
7134
def _setup_urls(self) -> None:
7235
"""Parse and setup base URLs."""
73-
try:
74-
# After validation, we know address is not None
75-
parsed_url = urlparse(self.config.address)
76-
if not parsed_url.scheme:
77-
raise ValueError("Address must include protocol (http/https)")
36+
# Ensure base path ends with /
37+
base_path = self.config.base_path
38+
if not base_path.endswith("/"):
39+
base_path += "/"
7840

79-
# Ensure base path ends with /
80-
base_path = self.config.base_path
81-
if not base_path.endswith("/"):
82-
base_path += "/"
41+
registry_path = self.config.registry_base_path
42+
if not registry_path.endswith("/"):
43+
registry_path += "/"
8344

84-
registry_path = self.config.registry_base_path
85-
if not registry_path.endswith("/"):
86-
registry_path += "/"
87-
88-
self.base_url = urljoin(self.config.address, base_path)
89-
self.registry_base_url = urljoin(self.config.address, registry_path)
90-
91-
except Exception as e:
92-
raise TFEClientError(f"Invalid address '{self.config.address}': {e}") from e
45+
self.base_url = urljoin(self.config.address, base_path)
46+
self.registry_base_url = urljoin(self.config.address, registry_path)
9347

9448
def _fetch_api_metadata(self) -> None:
9549
"""Fetch API metadata from the server."""
9650
ping_url = urljoin(self.base_url, "ping")
97-
98-
# After validation, we know token is not None
9951
headers = {
10052
"Accept": "application/vnd.api+json",
101-
"Authorization": f"Bearer {self.config.token}",
10253
}
10354
if self.config.headers:
10455
headers.update(self.config.headers)

tfe/config.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
from collections.abc import Callable
44
from dataclasses import dataclass, field
5+
from urllib.parse import urlparse
56

67
import requests
78

@@ -28,7 +29,7 @@ class Config:
2829

2930
# Headers to include in API requests
3031
# TODO: Do we need headers ? we can pass them directly to http_client, but this will differ from the go-tfe module
31-
headers: dict[str, str] | None = None
32+
headers: dict[str, str] = field(default_factory=dict)
3233

3334
# Custom request session which needs to be used
3435
http_client: requests.Session = field(default_factory=requests.Session)
@@ -39,7 +40,7 @@ class Config:
3940
# Enable/Disable retry logic
4041
retry_server_errors: bool = False
4142

42-
def __post_init__(self) -> None:
43+
def _set_address(self) -> None:
4344
tfe_address = os.getenv("TFE_ADDRESS", "")
4445
if tfe_address:
4546
self.address = tfe_address
@@ -50,23 +51,36 @@ def __post_init__(self) -> None:
5051
else:
5152
self.address = DEFAULT_ADDRESS
5253

54+
def _set_token(self) -> None:
5355
if not self.token:
5456
self.token = os.getenv("TFE_TOKEN", "")
5557

56-
if self.headers is None:
57-
self.headers = {}
58+
if (
59+
self.token
60+
and "Authorization" not in self.http_client.headers
61+
and "Authorization" not in self.headers
62+
):
63+
self.headers["Authorization"] = f"Bearer {self.token}"
5864

65+
def _set_user_agent(self) -> None:
5966
if (
6067
"User-Agent" not in self.http_client.headers
6168
and "User-Agent" not in self.headers
6269
):
6370
self.headers["User-Agent"] = "python-tfe"
6471

65-
if (
66-
self.token
67-
and "Authorization" not in self.http_client.headers
68-
and "Authorization" not in self.headers
69-
):
70-
self.headers["Authorization"] = f"Bearer {self.token}"
72+
def _validate_config(self) -> None:
73+
if not self.http_client.headers.get("Authorization"):
74+
raise ValueError(
75+
"API token is required, please set the TFE_TOKEN environment variable or the token field in the configuration."
76+
)
77+
parsed_url = urlparse(self.address)
78+
if not parsed_url.scheme:
79+
raise ValueError("Address must include protocol (http/https)")
7180

81+
def __post_init__(self) -> None:
82+
self._set_address()
83+
self._set_token()
84+
self._set_user_agent()
7285
self.http_client.headers.update(self.headers)
86+
self._validate_config()

0 commit comments

Comments
 (0)