Skip to content

Commit 932c7ea

Browse files
committed
fix: keep live music API keys out of websocket urls
1 parent 7c331c6 commit 932c7ea

2 files changed

Lines changed: 30 additions & 3 deletions

File tree

google/genai/live_music.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,10 @@ async def connect(self, *, model: str) -> AsyncIterator[AsyncMusicSession]:
175175
transformed_model = t.t_model(self._api_client, model)
176176

177177
if self._api_client.api_key:
178-
api_key = self._api_client.api_key
179178
version = self._api_client._http_options.api_version
180-
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateMusic?key={api_key}'
181-
headers = self._api_client._http_options.headers
179+
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateMusic'
180+
original_headers = self._api_client._http_options.headers
181+
headers = original_headers.copy() if original_headers is not None else {}
182182

183183
# Only mldev supported
184184
request_dict = _common.convert_to_dict(

google/genai/tests/live/test_live_music.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,33 @@ async def _test_connect():
133133
return await _test_connect()
134134

135135

136+
@pytest.mark.asyncio
137+
async def test_connect_uses_header_auth_without_query_key(mock_websocket):
138+
api_client = mock_api_client(vertexai=False)
139+
api_client._websocket_base_url = lambda: 'wss://generativelanguage.googleapis.com'
140+
api_client._http_options.api_version = 'v1beta'
141+
api_client._http_options.headers['x-goog-api-key'] = 'TEST_API_KEY'
142+
captured = {}
143+
144+
@contextlib.asynccontextmanager
145+
async def mock_connect(uri, additional_headers=None):
146+
captured['uri'] = uri
147+
captured['headers'] = additional_headers
148+
yield mock_websocket
149+
150+
@patch.object(live_music, 'connect', new=mock_connect)
151+
async def _test_connect():
152+
live_module = live.AsyncLive(api_client)
153+
async with live_module.music.connect(model='test_model'):
154+
pass
155+
156+
await _test_connect()
157+
158+
assert 'TEST_API_KEY' not in captured['uri']
159+
assert '?key=' not in captured['uri']
160+
assert captured['headers']['x-goog-api-key'] == 'TEST_API_KEY'
161+
162+
136163
def test_mldev_from_env(monkeypatch):
137164
api_key = 'google_api_key'
138165
monkeypatch.setenv('GOOGLE_API_KEY', api_key)

0 commit comments

Comments
 (0)