diff --git a/care_scribe/migrations/0012_scribe_transcript_only.py b/care_scribe/migrations/0012_scribe_transcript_only.py new file mode 100644 index 0000000..7147115 --- /dev/null +++ b/care_scribe/migrations/0012_scribe_transcript_only.py @@ -0,0 +1,98 @@ +from django.db import migrations, models + + +def rename_processing_meta_keys_forward(apps, schema_editor): + """Rename old processing meta keys to the new keys. + + Old -> New: + provider -> chat_provider + audio_model -> transcribe_model + + Also adds `transcribe_provider` mirroring `chat_provider` so historical + entries match the new shape. + """ + Scribe = apps.get_model("care_scribe", "Scribe") + to_update = [] + for scribe in Scribe.objects.exclude(meta={}).iterator(): + meta = scribe.meta or {} + processings = meta.get("processings") + if not processings: + continue + changed = False + for processing in processings: + if not isinstance(processing, dict): + continue + if "provider" in processing and "chat_provider" not in processing: + processing["chat_provider"] = processing.pop("provider") + changed = True + if "audio_model" in processing and "transcribe_model" not in processing: + processing["transcribe_model"] = processing.pop("audio_model") + changed = True + if ( + "chat_provider" in processing + and "transcribe_provider" not in processing + ): + processing["transcribe_provider"] = processing["chat_provider"] + changed = True + if changed: + scribe.meta = meta + to_update.append(scribe) + if len(to_update) >= 500: + Scribe.objects.bulk_update(to_update, ["meta"]) + to_update = [] + if to_update: + Scribe.objects.bulk_update(to_update, ["meta"]) + + +def rename_processing_meta_keys_reverse(apps, schema_editor): + """Revert the rename: new keys -> old keys.""" + Scribe = apps.get_model("care_scribe", "Scribe") + to_update = [] + for scribe in Scribe.objects.exclude(meta={}).iterator(): + meta = scribe.meta or {} + processings = meta.get("processings") + if not processings: + continue + changed = False + for processing in processings: + if not isinstance(processing, dict): + continue + if "chat_provider" in processing and "provider" not in processing: + processing["provider"] = processing.pop("chat_provider") + changed = True + if "transcribe_model" in processing and "audio_model" not in processing: + processing["audio_model"] = processing.pop("transcribe_model") + changed = True + if "transcribe_provider" in processing: + processing.pop("transcribe_provider") + changed = True + if changed: + scribe.meta = meta + to_update.append(scribe) + if len(to_update) >= 500: + Scribe.objects.bulk_update(to_update, ["meta"]) + to_update = [] + if to_update: + Scribe.objects.bulk_update(to_update, ["meta"]) + + +class Migration(migrations.Migration): + + dependencies = [ + ('care_scribe', '0011_scribefile_mime_type'), + ] + + operations = [ + migrations.AddField( + model_name='scribe', + name='transcript_only', + field=models.BooleanField( + default=False, + help_text='If True, only transcribe the audio without running any AI form-fill processing.', + ), + ), + migrations.RunPython( + rename_processing_meta_keys_forward, + rename_processing_meta_keys_reverse, + ), + ] diff --git a/care_scribe/models/scribe.py b/care_scribe/models/scribe.py index e7e070f..0d9bcc3 100644 --- a/care_scribe/models/scribe.py +++ b/care_scribe/models/scribe.py @@ -124,6 +124,10 @@ class Status(models.TextChoices): chat_model = models.CharField(max_length=100, null=True, blank=True) audio_model = models.CharField(max_length=100, null=True, blank=True) chat_model_temperature = models.FloatField(null=True, blank=True) + transcript_only = models.BooleanField( + default=False, + help_text="If True, only transcribe the audio without running any AI form-fill processing.", + ) is_feedback_positive = models.BooleanField(null=True, blank=True, help_text="Whether the user has given positive feedback on the AI response") feedback_comments = models.TextField(null=True, blank=True, help_text="Details of the feedback provided by the user") diff --git a/care_scribe/serializers/scribe.py b/care_scribe/serializers/scribe.py index d093b00..e36e53b 100644 --- a/care_scribe/serializers/scribe.py +++ b/care_scribe/serializers/scribe.py @@ -77,6 +77,7 @@ class Meta: "chat_model", "audio_model", "chat_model_temperature", + "transcript_only", "is_feedback_positive", "feedback_comments", ] diff --git a/care_scribe/settings.py b/care_scribe/settings.py index f31199a..dbf3852 100644 --- a/care_scribe/settings.py +++ b/care_scribe/settings.py @@ -86,21 +86,32 @@ def validate(self) -> None: f'Please set the "{setting}" in the environment or the {PLUGIN_NAME} plugin config.' ) - if getattr(self, "SCRIBE_API_PROVIDER") not in ("openai", "azure", "google"): + valid_providers = ("openai", "azure", "google") + providers_in_use = set() + + for setting_name in ("SCRIBE_CHAT_MODEL_NAME", "SCRIBE_TRANSCRIBE_MODEL_NAME"): + value = getattr(self, setting_name) + if "/" not in value: + raise ImproperlyConfigured( + f'Invalid value for "{setting_name}". ' + f'Expected format "provider/model-name" ' + f'(provider must be one of {valid_providers}).' + ) + provider = value.split("/", 1)[0] + if provider not in valid_providers: + raise ImproperlyConfigured( + f'Invalid provider "{provider}" in "{setting_name}". ' + f'Provider must be one of {valid_providers}.' + ) + providers_in_use.add(provider) + + if "openai" in providers_in_use and not getattr(self, "SCRIBE_OPENAI_API_KEY"): raise ImproperlyConfigured( - 'Invalid value for "SCRIBE_API_PROVIDER". ' - 'Please set the "SCRIBE_API_PROVIDER" to "openai", "google" or "azure".' + 'The "SCRIBE_OPENAI_API_KEY" setting is required when using OpenAI API. ' + f'Please set it in the environment or the {PLUGIN_NAME} plugin config.' ) - if getattr(self, "SCRIBE_API_PROVIDER") == "openai": - for setting in ("SCRIBE_OPENAI_API_KEY",): - if not getattr(self, setting): - raise ImproperlyConfigured( - f'The "{setting}" setting is required when using OpenAI API. ' - f'Please set the "{setting}" in the environment or the {PLUGIN_NAME} plugin config.' - ) - - if getattr(self, "SCRIBE_API_PROVIDER") == "azure": + if "azure" in providers_in_use: for setting in ("SCRIBE_AZURE_API_VERSION", "SCRIBE_AZURE_ENDPOINT", "SCRIBE_AZURE_API_KEY"): if not getattr(self, setting): raise ImproperlyConfigured( @@ -108,7 +119,7 @@ def validate(self) -> None: f'Please set the "{setting}" in the environment or the {PLUGIN_NAME} plugin config.' ) - if getattr(self, "SCRIBE_API_PROVIDER") == "google": + if "google" in providers_in_use: for setting in ("SCRIBE_GOOGLE_PROJECT_ID", "SCRIBE_GOOGLE_LOCATION"): if not getattr(self, setting): raise ImproperlyConfigured( @@ -129,19 +140,19 @@ def reload(self) -> None: REQUIRED_SETTINGS = { "SCRIBE_CHAT_MODEL_NAME", - "SCRIBE_API_PROVIDER", + "SCRIBE_TRANSCRIBE_MODEL_NAME", } DEFAULTS = { "SCRIBE_OPENAI_API_KEY": "", "SCRIBE_AZURE_API_KEY": "", - "SCRIBE_AUDIO_MODEL_NAME": "whisper-1", - "SCRIBE_CHAT_MODEL_NAME": "gpt-4o", - "SCRIBE_API_PROVIDER": "openai", + "SCRIBE_TRANSCRIBE_MODEL_NAME": "openai/whisper-1", + "SCRIBE_CHAT_MODEL_NAME": "openai/gpt-4o", "SCRIBE_AZURE_API_VERSION": "", "SCRIBE_AZURE_ENDPOINT": "", "SCRIBE_GOOGLE_PROJECT_ID" : "", "SCRIBE_GOOGLE_LOCATION" : "", + "SCRIBE_TRANSCRIBE_LANGUAGE": "", # only works for google. OpenAI can return source language or only translate to English. "SCRIBE_TNC": "", } diff --git a/care_scribe/tasks/scribe.py b/care_scribe/tasks/scribe.py index 4f2368b..a0c6202 100644 --- a/care_scribe/tasks/scribe.py +++ b/care_scribe/tasks/scribe.py @@ -21,7 +21,118 @@ logger = logging.getLogger(__name__) -def ai_client(provider=plugin_settings.SCRIBE_API_PROVIDER): +def _google_credentials(): + b64_credentials = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_B64") + if not b64_credentials: + return None + info = json.loads(base64.b64decode(b64_credentials).decode("utf-8")) + return service_account.Credentials.from_service_account_info( + info, scopes=["https://www.googleapis.com/auth/cloud-platform"] + ) + +def _google_llm_transcribe(audio_file_object, model_name): + """Transcribe a single audio file using a Google Gemini model. + + The audio is sent to the configured Gemini model with a prompt instructing + it to return ONLY the transcribed text. If ``SCRIBE_TRANSCRIBE_LANGUAGE`` + is set, the model is asked to translate into that language; otherwise the + transcript is returned in the original spoken language. + """ + target_language = (plugin_settings.SCRIBE_TRANSCRIBE_LANGUAGE or "").strip() + + _, audio_data = audio_file_object.files_manager.file_contents(audio_file_object) + fmt = audio_file_object.internal_name.split(".")[-1] + + client = ai_client("google") + if target_language: + prompt = ( + "You are an audio transcription engine. Transcribe the provided " + f"audio and translate the transcript into the language with BCP-47 " + f"code '{target_language}'.\n" + "Strict output rules:\n" + f"- Output ONLY the final transcript in '{target_language}'.\n" + "- Do NOT include the original-language transcription.\n" + "- Do NOT include both languages or any side-by-side text.\n" + "- Do NOT add explanations, labels, preambles, quotes, or markdown.\n" + "- If the audio is empty or unintelligible, or contains no speech, output an empty string." + ) + else: + prompt = ( + "You are an audio transcription engine. Transcribe the provided " + "audio in the original spoken language. Do not translate.\n" + "Strict output rules:\n" + "- Output ONLY the transcript text.\n" + "- Do NOT add explanations, labels, preambles, quotes, or markdown.\n" + "- If the audio is empty or unintelligible, or contains no speech, output an empty string." + ) + response = client.models.generate_content( + model=model_name, + contents=[ + types.Content( + role="user", + parts=[ + types.Part.from_text(text=prompt), + types.Part.from_bytes( + data=audio_data, + mime_type=f"audio/{fmt}", + ), + ], + ) + ], + config=types.GenerateContentConfig( + temperature=0, + thinking_config=( + types.ThinkingConfig(thinking_budget=0) + if "2.5" in model_name and "pro" not in model_name + else None + ), + ), + ) + return (response.text or "").strip() + + +def transcribe_audio_file(audio_file_object, provider, audio_model): + """Transcribe a single audio file using the configured provider.""" + if provider == "google": + return _google_llm_transcribe(audio_file_object, audio_model) + + client = ai_client(provider) + _, audio_file_data = audio_file_object.files_manager.file_contents( + audio_file_object + ) + fmt = audio_file_object.internal_name.split(".")[-1] + buffer = io.BytesIO(audio_file_data) + buffer.name = "file." + fmt + # Only whisper-1 supports the /audio/translations endpoint. + # Newer models (gpt-4o-transcribe, gpt-4o-mini-transcribe, etc.) are + # transcription-only and must use /audio/transcriptions. + if audio_model == "whisper-1": + transcription = client.audio.translations.create( + model=audio_model, file=buffer + ) + else: + transcription = client.audio.transcriptions.create( + model=audio_model, file=buffer + ) + return transcription.text + + +def _parse_provider_model(value: str): + """Split a 'provider/model-name' string into (provider, model). + + The model portion may itself contain '/' characters (kept intact). + """ + if not value or "/" not in value: + raise ValueError( + f"Expected 'provider/model-name' format, got: {value!r}" + ) + provider, model = value.split("/", 1) + if provider == "openai" and plugin_settings.SCRIBE_AZURE_API_KEY: + provider = "azure" + return provider, model + + +def ai_client(provider): if provider == "azure": AiClient = AzureOpenAI( api_key=plugin_settings.SCRIBE_AZURE_API_KEY, @@ -34,25 +145,18 @@ def ai_client(provider=plugin_settings.SCRIBE_API_PROVIDER): ) elif provider == "google": - credentials = None - b64_credentials = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_B64") - - if b64_credentials: - info = json.loads(base64.b64decode(b64_credentials).decode("utf-8")) - credentials = service_account.Credentials.from_service_account_info(info, scopes=["https://www.googleapis.com/auth/cloud-platform"]) - AiClient = genai.Client( vertexai=True, project=plugin_settings.SCRIBE_GOOGLE_PROJECT_ID, location=plugin_settings.SCRIBE_GOOGLE_LOCATION, - credentials=credentials, + credentials=_google_credentials(), ) else: raise Exception("Invalid api provider") return AiClient -def chat_message(provider=plugin_settings.SCRIBE_API_PROVIDER, role="user", text=None, file_object=None, file_type="audio"): +def chat_message(provider, role="user", text=None, file_object=None, file_type="audio"): """ Generates a chat message compatible with the given AI provider client.""" if file_object: _, file_data = file_object.files_manager.file_contents(file_object) @@ -184,35 +288,88 @@ def process_ai_form_fill(external_id): form.save() return - api_provider = plugin_settings.SCRIBE_API_PROVIDER - chat_model = plugin_settings.SCRIBE_CHAT_MODEL_NAME - audio_model = plugin_settings.SCRIBE_AUDIO_MODEL_NAME + chat_provider, chat_model = _parse_provider_model( + plugin_settings.SCRIBE_CHAT_MODEL_NAME + ) + transcribe_provider, transcribe_model = _parse_provider_model( + plugin_settings.SCRIBE_TRANSCRIBE_MODEL_NAME + ) temperature = 0 if form.chat_model: - api_provider = form.chat_model.split("/")[0] - if api_provider == "openai" and plugin_settings.SCRIBE_AZURE_API_KEY is not "": - api_provider = "azure" - chat_model = form.chat_model.split("/")[1] + chat_provider, chat_model = _parse_provider_model(form.chat_model) if form.audio_model: - audio_model = form.audio_model + # Form override may be either "provider/model" or just a model name + if "/" in form.audio_model: + transcribe_provider, transcribe_model = _parse_provider_model( + form.audio_model + ) + else: + transcribe_model = form.audio_model if form.chat_model_temperature is not None: temperature = form.chat_model_temperature - processing["provider"] = api_provider + processing["chat_provider"] = chat_provider processing["chat_model"] = chat_model - processing["audio_model"] = audio_model if api_provider != "google" else None + processing["transcribe_provider"] = transcribe_provider + processing["transcribe_model"] = ( + transcribe_model if chat_provider != "google" else None + ) processing["form_data"] = form.form_data - # Instantiate the AI client once to avoid premature closure and resource management issues, - # especially with the Google GenAI provider. Reuse this client instance throughout the function. - client = ai_client(api_provider) - audio_files = ScribeFile.objects.filter(external_id__in=form.audio_file_ids) total_audio_duration = sum(file.meta.get("length", 0) for file in audio_files) + if form.transcript_only: + logger.info(f"=== Processing transcript-only Scribe {form.external_id} ===") + processing["transcript_only"] = True + processing["transcribe_model"] = transcribe_model + try: + form.status = Scribe.Status.GENERATING_TRANSCRIPT + form.save() + transcript = form.transcript or "" + if not transcript: + transcription_start = perf_counter() + for audio_file_object in audio_files: + transcript += ( + transcribe_audio_file( + audio_file_object=audio_file_object, + provider=transcribe_provider, + audio_model=transcribe_model, + ) + or "" + ) + processing["transcription_time"] = perf_counter() - transcription_start + form.transcript = transcript + form.meta["processings"] = [ + *form.meta.get("processings", []), + processing, + ] + form.status = Scribe.Status.COMPLETED + form.save() + if not is_benchmark: + user_quota.calculate_used() + facility_quota.calculate_used() + except Exception as e: + logger.error( + f"Transcript-only processing failed at line " + f"{e.__traceback__.tb_lineno}: {e}" + ) + processing["error"] = str(e) + form.meta["processings"] = [ + *form.meta.get("processings", []), + processing, + ] + form.status = Scribe.Status.FAILED + form.save() + return + + # Instantiate the AI client once to avoid premature closure and resource management issues, + # especially with the Google GenAI provider. Reuse this client instance throughout the function. + client = ai_client(chat_provider) + processed_fields = {} def process_fields(fields: list, indent: int = 0): @@ -247,7 +404,7 @@ def process_fields(fields: list, indent: int = 0): # Asking for the full transcription on longer audio would eat up too many tokens. output_schema["properties"]["__scribe__transcription"]["description"] = f"A short summarized transcription of the {'image' if len(form.document_file_ids) > 0 else 'audio'} content, focusing on key points and insights in English." - if api_provider != "google" and len(form.document_file_ids) == 0: + if chat_provider != "google" and len(form.document_file_ids) == 0: # As we are transcribing using whisper, we do not need the transcription field in the output schema del output_schema["properties"]["__scribe__transcription"] output_schema["required"].remove("__scribe__transcription") @@ -261,7 +418,7 @@ def process_fields(fields: list, indent: int = 0): messages.append( chat_message( - provider=api_provider, + provider=chat_provider, role="system", text=base_prompt, ) @@ -270,7 +427,7 @@ def process_fields(fields: list, indent: int = 0): if form.text: messages.append( chat_message( - provider=api_provider, + provider=chat_provider, role="user", text=form.text, ) @@ -286,10 +443,10 @@ def process_fields(fields: list, indent: int = 0): for audio_file_object in audio_files: - if api_provider == "google": + if chat_provider == "google": messages.append( chat_message( - provider=api_provider, + provider=chat_provider, role="user", file_object=audio_file_object, file_type="audio", @@ -297,13 +454,13 @@ def process_fields(fields: list, indent: int = 0): ) else: - _, audio_file_data = audio_file_object.files_manager.file_contents(audio_file_object) - format = audio_file_object.internal_name.split(".")[-1] - buffer = io.BytesIO(audio_file_data) - buffer.name = "file." + format logger.info(f"=== Generating transcript for AI form fill {form.external_id} ===") try: - transcription = client.audio.translations.create(model=audio_model, file=buffer) + transcription_text = transcribe_audio_file( + audio_file_object=audio_file_object, + provider=transcribe_provider, + audio_model=transcribe_model, + ) except Exception as e: logger.error(f"Error generating transcript: {e}") processing["error"] = f"Error generating transcript: {e}" @@ -315,7 +472,7 @@ def process_fields(fields: list, indent: int = 0): form.save() return - transcript += transcription.text + transcript += transcription_text or "" logger.info(f"Transcript: {transcript}") transcription_time = perf_counter() - initiation_time @@ -333,7 +490,7 @@ def process_fields(fields: list, indent: int = 0): for document_file_object in document_file_objects: messages.append( chat_message( - provider=api_provider, + provider=chat_provider, role="user", file_object=document_file_object, file_type="image", @@ -343,7 +500,7 @@ def process_fields(fields: list, indent: int = 0): if transcript != "": messages.append( chat_message( - provider=api_provider, + provider=chat_provider, role="user", text=transcript, ) @@ -355,7 +512,7 @@ def process_fields(fields: list, indent: int = 0): completion_start_time = perf_counter() - if api_provider == "google": + if chat_provider == "google": output_schema_hash = hash_string(json.dumps(output_schema, sort_keys=True)) try: