diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 153bbd2d73..7f7ecccfd8 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -45,6 +45,7 @@ GEMINI_MODELS, OPENAI_MODELS, ) +from crewai.llms.cache import CACHE_BREAKPOINT_KEY from crewai.utilities import InternalInstructor from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededError, @@ -342,6 +343,7 @@ def _ensure_litellm() -> bool: "cerebras", "dashscope", "snowflake", + "groq", ] @@ -431,6 +433,7 @@ def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM: "cerebras": "cerebras", "dashscope": "dashscope", "snowflake": "snowflake", + "groq": "groq", } canonical_provider = provider_mapping.get(prefix.lower()) @@ -553,6 +556,12 @@ def _matches_provider_pattern(cls, model: str, provider: str) -> bool: if provider == "snowflake": return True + if provider == "groq": + return any( + model_lower.startswith(prefix) + for prefix in ["llama", "gemma", "whisper"] + ) + return False @classmethod @@ -664,6 +673,7 @@ def _get_native_provider(cls, provider: str) -> type | None: "hosted_vllm", "cerebras", "dashscope", + "groq", } if provider in openai_compatible_providers: from crewai.llms.providers.openai_compatible.completion import ( @@ -2286,6 +2296,16 @@ def _format_messages_for_provider( "Invalid message format. Each message must be a dict with 'role' and 'content' keys" ) + # Strip cache_breakpoint key from messages if not using Anthropic + cleaned_messages = [] + for msg in messages: + if not self.is_anthropic: + msg_copy = {k: v for k, v in msg.items() if k != CACHE_BREAKPOINT_KEY} + else: + msg_copy = dict(msg) + cleaned_messages.append(msg_copy) + messages = cleaned_messages + if "o1" in self.model.lower(): formatted_messages = [] for msg in messages: diff --git a/lib/crewai/src/crewai/llms/providers/openai_compatible/completion.py b/lib/crewai/src/crewai/llms/providers/openai_compatible/completion.py index da4cfd03db..152eac2abe 100644 --- a/lib/crewai/src/crewai/llms/providers/openai_compatible/completion.py +++ b/lib/crewai/src/crewai/llms/providers/openai_compatible/completion.py @@ -43,6 +43,12 @@ class ProviderConfig: OPENAI_COMPATIBLE_PROVIDERS: dict[str, ProviderConfig] = { + "groq": ProviderConfig( + base_url="https://api.groq.com/openai/v1", + api_key_env="GROQ_API_KEY", + base_url_env="GROQ_BASE_URL", + api_key_required=True, + ), "openrouter": ProviderConfig( base_url="https://openrouter.ai/api/v1", api_key_env="OPENROUTER_API_KEY", diff --git a/lib/crewai/tests/llms/openai_compatible/test_openai_compatible.py b/lib/crewai/tests/llms/openai_compatible/test_openai_compatible.py index ce856a5334..6096937b3c 100644 --- a/lib/crewai/tests/llms/openai_compatible/test_openai_compatible.py +++ b/lib/crewai/tests/llms/openai_compatible/test_openai_compatible.py @@ -95,6 +95,14 @@ def test_dashscope_config(self): assert config.api_key_env == "DASHSCOPE_API_KEY" assert config.api_key_required is True + def test_groq_config(self): + """Test Groq provider configuration.""" + config = OPENAI_COMPATIBLE_PROVIDERS["groq"] + assert config.base_url == "https://api.groq.com/openai/v1" + assert config.api_key_env == "GROQ_API_KEY" + assert config.base_url_env == "GROQ_BASE_URL" + assert config.api_key_required is True + class TestNormalizeOllamaBaseUrl: """Tests for _normalize_ollama_base_url helper.""" @@ -271,6 +279,13 @@ def test_llm_creates_openai_compatible_for_dashscope(self): assert isinstance(llm, OpenAICompatibleCompletion) assert llm.provider == "dashscope" + def test_llm_creates_openai_compatible_for_groq(self): + """Test LLM factory creates OpenAICompatibleCompletion for Groq.""" + with patch.dict(os.environ, {"GROQ_API_KEY": "test-key"}): + llm = LLM(model="groq/llama-3.3-70b-versatile") + assert isinstance(llm, OpenAICompatibleCompletion) + assert llm.provider == "groq" + def test_llm_with_explicit_provider(self): """Test LLM with explicit provider parameter.""" with patch.dict(os.environ, {"DEEPSEEK_API_KEY": "test-key"}): diff --git a/lib/crewai/tests/test_llm.py b/lib/crewai/tests/test_llm.py index 1c98d751ef..f13fdf3440 100644 --- a/lib/crewai/tests/test_llm.py +++ b/lib/crewai/tests/test_llm.py @@ -806,6 +806,31 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm): assert formatted == original_messages +def test_format_messages_strips_cache_breakpoint_for_non_anthropic(): + """cache_breakpoint keys must be removed for non-Anthropic providers.""" + from crewai.llms.cache import CACHE_BREAKPOINT_KEY + + llm = LLM(model="gpt-4o-mini", is_litellm=True) + # Simulate a message that carries the cache breakpoint marker + messages = [ + {"role": "user", "content": "Hello", CACHE_BREAKPOINT_KEY: True}, + ] + formatted = llm._format_messages_for_provider(messages) + assert all(CACHE_BREAKPOINT_KEY not in msg for msg in formatted) + + +def test_format_messages_preserves_cache_breakpoint_for_anthropic(): + """cache_breakpoint keys must be preserved for Anthropic providers.""" + from crewai.llms.cache import CACHE_BREAKPOINT_KEY + + llm = LLM(model="claude-3-5-sonnet-20241022", is_litellm=True) + messages = [ + {"role": "user", "content": "Hello", CACHE_BREAKPOINT_KEY: True}, + ] + formatted = llm._format_messages_for_provider(messages) + assert any(CACHE_BREAKPOINT_KEY in msg for msg in formatted) + + def test_native_provider_raises_error_when_supported_but_fails(): """Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error.""" with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]): @@ -853,6 +878,12 @@ def test_prefixed_models_with_valid_constants_use_native_sdk(): assert llm3.is_litellm is False assert llm3.provider == "gemini" + # Test groq/ prefix with Groq model in constants → Native SDK + with patch.dict(os.environ, {"GROQ_API_KEY": "test-key"}): + llm4 = LLM(model="groq/llama-3.1-70b-versatile", is_litellm=False) + assert llm4.is_litellm is False + assert llm4.provider == "groq" + def test_prefixed_models_with_invalid_constants_use_litellm(): """Test that models with native provider prefixes use LiteLLM when model is NOT in constants and does NOT match patterns.""" @@ -889,10 +920,10 @@ def test_prefixed_models_with_valid_patterns_use_native_sdk(): def test_prefixed_models_with_non_native_providers_use_litellm(): """Test that models with non-native provider prefixes always use LiteLLM.""" - # Test groq/ prefix (not a native provider) → LiteLLM - llm = LLM(model="groq/llama-3.3-70b", is_litellm=False) + # Test replicate/ prefix (not a native provider) → LiteLLM + llm = LLM(model="replicate/llama-3-70b", is_litellm=False) assert llm.is_litellm is True - assert llm.model == "groq/llama-3.3-70b" + assert llm.model == "replicate/llama-3-70b" # Test together/ prefix (not a native provider) → LiteLLM llm2 = LLM(model="together/qwen-2.5-72b", is_litellm=False)