Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions lib/crewai/src/crewai/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
GEMINI_MODELS,
OPENAI_MODELS,
)
from crewai.llms.cache import CACHE_BREAKPOINT_KEY
from crewai.utilities import InternalInstructor
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
Expand Down Expand Up @@ -342,6 +343,7 @@ def _ensure_litellm() -> bool:
"cerebras",
"dashscope",
"snowflake",
"groq",
]


Expand Down Expand Up @@ -431,6 +433,7 @@ def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
"cerebras": "cerebras",
"dashscope": "dashscope",
"snowflake": "snowflake",
"groq": "groq",
}

canonical_provider = provider_mapping.get(prefix.lower())
Expand Down Expand Up @@ -553,6 +556,12 @@ def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
if provider == "snowflake":
return True

if provider == "groq":
return any(
model_lower.startswith(prefix)
for prefix in ["llama", "gemma", "whisper"]
)

return False

@classmethod
Expand Down Expand Up @@ -664,6 +673,7 @@ def _get_native_provider(cls, provider: str) -> type | None:
"hosted_vllm",
"cerebras",
"dashscope",
"groq",
}
if provider in openai_compatible_providers:
from crewai.llms.providers.openai_compatible.completion import (
Expand Down Expand Up @@ -2286,6 +2296,16 @@ def _format_messages_for_provider(
"Invalid message format. Each message must be a dict with 'role' and 'content' keys"
)

# Strip cache_breakpoint key from messages if not using Anthropic
cleaned_messages = []
for msg in messages:
if not self.is_anthropic:
msg_copy = {k: v for k, v in msg.items() if k != CACHE_BREAKPOINT_KEY}
else:
msg_copy = dict(msg)
cleaned_messages.append(msg_copy)
messages = cleaned_messages

if "o1" in self.model.lower():
formatted_messages = []
for msg in messages:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ class ProviderConfig:


OPENAI_COMPATIBLE_PROVIDERS: dict[str, ProviderConfig] = {
"groq": ProviderConfig(
base_url="https://api.groq.com/openai/v1",
api_key_env="GROQ_API_KEY",
base_url_env="GROQ_BASE_URL",
api_key_required=True,
),
Comment thread
coderabbitai[bot] marked this conversation as resolved.
"openrouter": ProviderConfig(
base_url="https://openrouter.ai/api/v1",
api_key_env="OPENROUTER_API_KEY",
Expand Down
15 changes: 15 additions & 0 deletions lib/crewai/tests/llms/openai_compatible/test_openai_compatible.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ def test_dashscope_config(self):
assert config.api_key_env == "DASHSCOPE_API_KEY"
assert config.api_key_required is True

def test_groq_config(self):
    """Check the Groq entry registered in OPENAI_COMPATIBLE_PROVIDERS.

    Verifies the default base URL, the environment variable names used for
    the API key and base-URL override, and that an API key is mandatory.
    """
    groq = OPENAI_COMPATIBLE_PROVIDERS["groq"]

    # Endpoint and environment-variable names, checked field by field.
    assert groq.base_url == "https://api.groq.com/openai/v1"
    assert groq.api_key_env == "GROQ_API_KEY"
    assert groq.base_url_env == "GROQ_BASE_URL"

    # The flag must be the literal True, not merely truthy.
    key_is_mandatory = groq.api_key_required
    assert key_is_mandatory is True


class TestNormalizeOllamaBaseUrl:
"""Tests for _normalize_ollama_base_url helper."""
Expand Down Expand Up @@ -271,6 +279,13 @@ def test_llm_creates_openai_compatible_for_dashscope(self):
assert isinstance(llm, OpenAICompatibleCompletion)
assert llm.provider == "dashscope"

def test_llm_creates_openai_compatible_for_groq(self):
    """Check that a ``groq/``-prefixed model yields an OpenAICompatibleCompletion."""
    # A placeholder key is injected for the duration of the test only.
    fake_env = {"GROQ_API_KEY": "test-key"}
    with patch.dict(os.environ, fake_env):
        groq_llm = LLM(model="groq/llama-3.3-70b-versatile")
        assert isinstance(groq_llm, OpenAICompatibleCompletion)
        assert groq_llm.provider == "groq"

def test_llm_with_explicit_provider(self):
"""Test LLM with explicit provider parameter."""
with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}):
Expand Down
37 changes: 34 additions & 3 deletions lib/crewai/tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,31 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
assert formatted == original_messages


def test_format_messages_strips_cache_breakpoint_for_non_anthropic():
    """cache_breakpoint keys must be removed for non-Anthropic providers.

    Builds a single user message carrying the cache-breakpoint marker, runs it
    through ``_format_messages_for_provider`` on a non-Anthropic model, and
    checks that no returned message still carries the marker.
    """
    from crewai.llms.cache import CACHE_BREAKPOINT_KEY

    llm = LLM(model="gpt-4o-mini", is_litellm=True)
    # Simulate a message that carries the cache breakpoint marker
    messages = [
        {"role": "user", "content": "Hello", CACHE_BREAKPOINT_KEY: True},
    ]
    formatted = llm._format_messages_for_provider(messages)

    # all() is vacuously true on an empty list, so make sure the formatter
    # actually returned messages before checking that the key is gone.
    assert formatted, "formatter returned no messages"
    assert all(CACHE_BREAKPOINT_KEY not in msg for msg in formatted)


def test_format_messages_preserves_cache_breakpoint_for_anthropic():
    """cache_breakpoint keys must survive formatting for Anthropic models.

    Sends one marked user message through ``_format_messages_for_provider``
    on a Claude model and expects at least one returned message to still
    carry the marker.
    """
    from crewai.llms.cache import CACHE_BREAKPOINT_KEY

    anthropic_llm = LLM(model="claude-3-5-sonnet-20241022", is_litellm=True)
    marked_message = {"role": "user", "content": "Hello", CACHE_BREAKPOINT_KEY: True}
    result = anthropic_llm._format_messages_for_provider([marked_message])

    # At least one returned message must still carry the marker.
    still_marked = [entry for entry in result if CACHE_BREAKPOINT_KEY in entry]
    assert still_marked


def test_native_provider_raises_error_when_supported_but_fails():
"""Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error."""
with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
Expand Down Expand Up @@ -853,6 +878,12 @@ def test_prefixed_models_with_valid_constants_use_native_sdk():
assert llm3.is_litellm is False
assert llm3.provider == "gemini"

# Test groq/ prefix with Groq model in constants → Native SDK
with patch.dict(os.environ, {"GROQ_API_KEY": "test-key"}):
llm4 = LLM(model="groq/llama-3.1-70b-versatile", is_litellm=False)
assert llm4.is_litellm is False
assert llm4.provider == "groq"


def test_prefixed_models_with_invalid_constants_use_litellm():
"""Test that models with native provider prefixes use LiteLLM when model is NOT in constants and does NOT match patterns."""
Expand Down Expand Up @@ -889,10 +920,10 @@ def test_prefixed_models_with_valid_patterns_use_native_sdk():

def test_prefixed_models_with_non_native_providers_use_litellm():
"""Test that models with non-native provider prefixes always use LiteLLM."""
# Test groq/ prefix (not a native provider) → LiteLLM
llm = LLM(model="groq/llama-3.3-70b", is_litellm=False)
# Test replicate/ prefix (not a native provider) → LiteLLM
llm = LLM(model="replicate/llama-3-70b", is_litellm=False)
assert llm.is_litellm is True
assert llm.model == "groq/llama-3.3-70b"
assert llm.model == "replicate/llama-3-70b"

# Test together/ prefix (not a native provider) → LiteLLM
llm2 = LLM(model="together/qwen-2.5-72b", is_litellm=False)
Expand Down