diff --git a/src/aap_eda/api/metadata.py b/src/aap_eda/api/metadata.py index 29165e678..249c039c8 100644 --- a/src/aap_eda/api/metadata.py +++ b/src/aap_eda/api/metadata.py @@ -83,7 +83,7 @@ def _customize_field_attributes(method: str, action: dict): action.pop(field) # For PUT/PATCH/POST methods, remove read-only fields. - if method in ("PUT", "PATCH", "POST"): + elif method in ("PUT", "PATCH", "POST"): # file-based read-only settings can't be updated meta.pop("defined_in_file", False) diff --git a/src/aap_eda/api/views/root.py b/src/aap_eda/api/views/root.py index b62750b84..3a63e9d9c 100644 --- a/src/aap_eda/api/views/root.py +++ b/src/aap_eda/api/views/root.py @@ -52,32 +52,34 @@ def get(self, request, *args, **kwargs): return Response(urls) +def _list_urls(url_patterns, request=None): + """Collect named URL patterns into a dict of name to URL.""" + url_list = {} + for url in url_patterns: + if isinstance(url, URLResolver): + url_list.update(_list_urls(url.url_patterns, request)) + elif isinstance(url, URLPattern): + name = url.name + if not name: + LOGGER.warning( + "URL %s has no name, DRF browsable API will omit it", + url.pattern, + ) + continue + if url.pattern.regex.groups: + continue + url_list[name] = reverse(name, request=request) + return url_list + + def get_api_v1_urls(request=None): from aap_eda.api import urls - def list_urls(urls): - url_list = {} - for url in urls: - if isinstance(url, URLResolver): - url_list.update(list_urls(url.url_patterns)) - elif isinstance(url, URLPattern): - name = url.name - if not name: - LOGGER.warning( - "URL %s has no name, DRF browsable API will omit it", - url.pattern, - ) - continue - if url.pattern.regex.groups: - continue - url_list[name] = reverse(name, request=request) - return url_list - if settings.ALLOW_LOCAL_RESOURCE_MANAGEMENT: - return list_urls(urls.v1_urls) + return _list_urls(urls.v1_urls, request) - url_list = list_urls(urls.eda_v1_urls) - all_urls = list_urls(urls.dab_urls) + url_list = _list_urls(urls.eda_v1_urls, request) + all_urls = _list_urls(urls.dab_urls, request) for name, url in all_urls.items(): if name in ALWAYS_VISIBLE_ENDPOINTS: url_list[name] = url diff --git a/src/aap_eda/core/management/commands/create_initial_data.py b/src/aap_eda/core/management/commands/create_initial_data.py index f65b952b6..809ca8cf3 100644 --- a/src/aap_eda/core/management/commands/create_initial_data.py +++ b/src/aap_eda/core/management/commands/create_initial_data.py @@ -2655,150 +2655,160 @@ def _create_obj_roles(self): for cls in permission_registry.all_registered_models: ct = self.content_type_model.objects.get_for_model(cls) parent_model = permission_registry.get_parent_model(cls) - # ignore if the model is organization, covered by org roles - # or child model, inherits permissions from parent model + # ignore if the model is organization, covered by org + # roles or child model, inherits permissions from + # parent model if cls._meta.model_name == "organization" or ( parent_model and parent_model._meta.model_name != "organization" ): continue - permissions = self._create_permissions_for_content_type(ct) - desc = f"Has all permissions to a single {cls._meta.verbose_name}" - # parent model should add permissions related to its child models - child_models = permission_registry.get_child_models(cls) - child_names = [] - for _, child_model in child_models: - child_ct = self.content_type_model.objects.get_for_model( - child_model - ) - permissions.extend( - self._create_permissions_for_content_type(child_ct) - ) - child_names.append(child_model._meta.verbose_name) - if child_names: - desc += f" and its child resources - {', '.join(child_names)}" # noqa: E501 - - # create resource admin role - admin_role_name = f"{cls._meta.verbose_name.title()} Admin" - if cls._meta.model_name == "project": - admin_role_name = f"EDA {admin_role_name}" - elif cls._meta.model_name == "edacredential": - admin_role_name = admin_role_name.replace("Eda ", "EDA ") - - role, created = RoleDefinition.objects.update_or_create( - name=admin_role_name, - defaults={ - "description": desc, - "content_type": ct, - "managed": True, - }, - ) - role.permissions.set(permissions) - if created: - self.stdout.write( - f"Added role {role.name} with {len(permissions)} " - "permissions to itself" - ) - # create resource use role - # ignore team model as it makes no sense to have Use role for it - # and should be managed by Admin users only - if cls._meta.model_name != "team": - use_role_name = f"{cls._meta.verbose_name.title()} Use" - if cls._meta.model_name == "project": - use_role_name = f"EDA {use_role_name}" - elif cls._meta.model_name == "edacredential": - use_role_name = use_role_name.replace("Eda ", "EDA ") - - ( - use_role, - use_role_created, - ) = RoleDefinition.objects.update_or_create( - name=use_role_name, - defaults={ - "description": f"Has use permissions to a single {cls._meta.verbose_name}", # noqa: E501 - "content_type": ct, - "managed": True, - }, - ) - use_permissions = [ - perm - for perm in permissions - if perm.codename.startswith("view_") - ] - use_role.permissions.set(use_permissions) - if use_role_created: - self.stdout.write( - f"Added role {use_role.name} with " - f"{len(use_permissions)} permissions to itself" - ) + permissions = self._collect_permissions(cls, ct) + desc = self._build_admin_description(cls) + self._create_admin_role(cls, ct, permissions, desc) - # create org-level admin roles for each resource type - org_role_name = ( - f"Organization {cls._meta.verbose_name.title()} Admin" - ) - if cls._meta.model_name == "project": - org_role_name = f"EDA {org_role_name}" - elif cls._meta.model_name == "edacredential": - org_role_name = org_role_name.replace("Eda ", "EDA ") - - ( - org_role, - org_role_created, - ) = RoleDefinition.objects.update_or_create( - name=org_role_name, - defaults={ - "description": f"Has all permissions to {cls._meta.verbose_name}s within an organization", # noqa: E501 - "content_type": self.content_type_model.objects.get( - model="organization" - ), - "managed": True, - }, - ) - permissions.extend( - DABPermission.objects.filter( - content_type=ct, codename__startswith="add_" - ) - ) - # Add organization view permission for organization-level admin roles, # noqa: E501 - org_ct = self.content_type_model.objects.get( - model="organization" - ) - org_view_permission = DABPermission.objects.filter( - content_type=org_ct, codename="view_organization" - ).first() - if ( - org_view_permission - and org_view_permission not in permissions - ): - permissions.append(org_view_permission) - - org_role.permissions.set(permissions) - if org_role_created: - self.stdout.write( - f"Added role {org_role.name} with {len(permissions)} " - "permissions to itself" - ) + # ignore team model as it makes no sense to have Use + # role for it and should be managed by Admin users only + if cls._meta.model_name != "team": + self._create_use_role(cls, ct, permissions) + self._create_org_admin_role(cls, ct, permissions) # Special case to create team member role if cls._meta.model_name == "team": - member_permissions = [ - p - for p in permissions - if p.codename in ("view_team", "member_team") - ] - desc = "Inherits permissions assigned to this team" - role, created = RoleDefinition.objects.update_or_create( - name="Team Member", - defaults={ - "description": desc, - "content_type": ct, - "managed": True, - }, - ) - role.permissions.set(member_permissions) - if created: - self.stdout.write( - f"Added role {role.name} with " - f"{len(member_permissions)} permissions to itself" - ) + self._create_team_member_role(ct, permissions) + + def _collect_permissions(self, cls, ct): + """Gather permissions for a model including child models.""" + permissions = self._create_permissions_for_content_type(ct) + # parent model should add permissions related to its + # child models + for _, child_model in permission_registry.get_child_models(cls): + child_ct = self.content_type_model.objects.get_for_model( + child_model + ) + permissions.extend( + self._create_permissions_for_content_type(child_ct) + ) + return permissions + + def _build_admin_description(self, cls): + """Build the admin role description with child resources.""" + desc = "Has all permissions to a single " f"{cls._meta.verbose_name}" + child_names = [ + cm._meta.verbose_name + for _, cm in permission_registry.get_child_models(cls) + ] + if child_names: + desc += " and its child resources" f" - {', '.join(child_names)}" + return desc + + def _format_role_name(self, cls, template): + """Format a role name with EDA prefix for special models.""" + name = template.format(verbose=cls._meta.verbose_name.title()) + if cls._meta.model_name == "project": + name = f"EDA {name}" + elif cls._meta.model_name == "edacredential": + name = name.replace("Eda ", "EDA ") + return name + + def _create_admin_role(self, cls, ct, permissions, desc): + """Create resource admin role.""" + name = self._format_role_name(cls, "{verbose} Admin") + role, created = RoleDefinition.objects.update_or_create( + name=name, + defaults={ + "description": desc, + "content_type": ct, + "managed": True, + }, + ) + role.permissions.set(permissions) + if created: + self.stdout.write( + f"Added role {role.name} with " + f"{len(permissions)} permissions to itself" + ) + + def _create_use_role(self, cls, ct, permissions): + """Create resource use role (view-only permissions).""" + name = self._format_role_name(cls, "{verbose} Use") + use_role, created = RoleDefinition.objects.update_or_create( + name=name, + defaults={ + "description": ( + "Has use permissions to a single " + f"{cls._meta.verbose_name}" + ), + "content_type": ct, + "managed": True, + }, + ) + use_permissions = [ + p for p in permissions if p.codename.startswith("view_") + ] + use_role.permissions.set(use_permissions) + if created: + self.stdout.write( + f"Added role {use_role.name} with " + f"{len(use_permissions)} permissions to itself" + ) + + def _create_org_admin_role(self, cls, ct, permissions): + """Create org-level admin role for a resource type.""" + name = self._format_role_name(cls, "Organization {verbose} Admin") + org_ct = self.content_type_model.objects.get(model="organization") + org_role, created = RoleDefinition.objects.update_or_create( + name=name, + defaults={ + "description": ( + "Has all permissions to " + f"{cls._meta.verbose_name}s " + "within an organization" + ), + "content_type": org_ct, + "managed": True, + }, + ) + permissions.extend( + DABPermission.objects.filter( + content_type=ct, codename__startswith="add_" + ) + ) + # Add organization view permission for + # organization-level admin roles + org_view_permission = DABPermission.objects.filter( + content_type=org_ct, codename="view_organization" + ).first() + if org_view_permission and org_view_permission not in permissions: + permissions.append(org_view_permission) + + org_role.permissions.set(permissions) + if created: + self.stdout.write( + f"Added role {org_role.name} with " + f"{len(permissions)} permissions to itself" + ) + + def _create_team_member_role(self, ct, permissions): + """Create team member role.""" + member_permissions = [ + p + for p in permissions + if p.codename in ("view_team", "member_team") + ] + role, created = RoleDefinition.objects.update_or_create( + name="Team Member", + defaults={ + "description": ("Inherits permissions assigned to this team"), + "content_type": ct, + "managed": True, + }, + ) + role.permissions.set(member_permissions) + if created: + self.stdout.write( + f"Added role {role.name} with " + f"{len(member_permissions)} " + "permissions to itself" + ) diff --git a/src/aap_eda/core/models/mixins.py b/src/aap_eda/core/models/mixins.py index addc7083d..237abb364 100644 --- a/src/aap_eda/core/models/mixins.py +++ b/src/aap_eda/core/models/mixins.py @@ -48,38 +48,39 @@ def update_status( ) def save(self, *args, **kwargs): - # when creating + """Save the activation status.""" if self._state.adding: if self.status_message is None: self._set_status_message() else: - if not bool(kwargs) or "update_fields" not in kwargs: - raise UpdateFieldsRequiredError( - "update_fields is required to use when saving " - "due to race conditions" - ) - else: - if "status" in kwargs["update_fields"]: - self._is_valid_status() - - if ( - "status_message" in kwargs["update_fields"] - and "status" not in kwargs["update_fields"] - ): - raise StatusRequiredError( - "status_message cannot be set by itself, " - "it requires status and status_message together" - ) - # when updating without status_message - elif ( - "status" in kwargs["update_fields"] - and "status_message" not in kwargs["update_fields"] - ): - self._set_status_message() - kwargs["update_fields"].append("status_message") + self._validate_update_fields(kwargs) super().save(*args, **kwargs) + def _validate_update_fields(self, kwargs): + """Validate kwargs for update operations.""" + if not bool(kwargs) or "update_fields" not in kwargs: + raise UpdateFieldsRequiredError( + "update_fields is required to use when saving " + "due to race conditions" + ) + + update_fields = kwargs["update_fields"] + if "status" in update_fields: + self._is_valid_status() + + has_status_message = "status_message" in update_fields + has_status = "status" in update_fields + + if has_status_message and not has_status: + raise StatusRequiredError( + "status_message cannot be set by itself, " + "it requires status and status_message together" + ) + if has_status and not has_status_message: + self._set_status_message() + update_fields.append("status_message") + def _set_status_message(self): self.status_message = self._get_default_status_message() diff --git a/src/aap_eda/core/models/rulebook_process.py b/src/aap_eda/core/models/rulebook_process.py index 350701c08..71d830997 100644 --- a/src/aap_eda/core/models/rulebook_process.py +++ b/src/aap_eda/core/models/rulebook_process.py @@ -80,9 +80,8 @@ def __str__(self) -> str: return f"Rulebook Process id {self.id}" def save(self, *args, **kwargs): - # when creating + """Save the rulebook process.""" if self._state.adding: - # ensure type is set self._set_parent_type() parent = self.get_parent() parent.latest_instance = self @@ -90,37 +89,37 @@ def save(self, *args, **kwargs): if self.status_message is None: self.status_message = self._get_default_status_message() else: - if not bool(kwargs) or "update_fields" not in kwargs: - raise UpdateFieldsRequiredError( - "update_fields is required to use when saving " - "due to race conditions" - ) - else: - if "status" in kwargs["update_fields"]: - self._is_valid_status() - - if ( - "status_message" in kwargs["update_fields"] - and "status" not in kwargs["update_fields"] - ): - raise StatusRequiredError( - "status_message cannot be set by itself, " - "it requires status and status_message together" - ) - # when updating without status_message - elif ( - "status" in kwargs["update_fields"] - and "status_message" not in kwargs["update_fields"] - ): - self.status_message = self._get_default_status_message() - kwargs["update_fields"].append("status_message") + self._validate_update_fields(kwargs) super().save(*args, **kwargs) - # update parent's latest_instance parent = self.get_parent() parent.save(update_fields=["latest_instance"]) + def _validate_update_fields(self, kwargs): + """Validate kwargs for update operations.""" + if not bool(kwargs) or "update_fields" not in kwargs: + raise UpdateFieldsRequiredError( + "update_fields is required to use when saving " + "due to race conditions" + ) + + update_fields = kwargs["update_fields"] + if "status" in update_fields: + self._is_valid_status() + + has_status_message = "status_message" in update_fields + has_status = "status" in update_fields + + if has_status_message and not has_status: + raise StatusRequiredError( + "status_message cannot be set by itself, " + "it requires status and status_message together" + ) + if has_status and not has_status_message: + self.status_message = self._get_default_status_message() + update_fields.append("status_message") + def _check_parent(self): """Clean method for RulebookProcess model.""" # Check that activation is set diff --git a/src/aap_eda/core/utils/credentials.py b/src/aap_eda/core/utils/credentials.py index a60353fba..c626c1c05 100644 --- a/src/aap_eda/core/utils/credentials.py +++ b/src/aap_eda/core/utils/credentials.py @@ -191,106 +191,168 @@ def validate_inputs( aes_salt = secrets.token_hex(32) for data in section_fields: - field = data["id"] - required = field in required_fields - default = data.get("default") - user_input = inputs.get(field) - display_field = f"inputs.{field}" - - if user_input is None: - if default: - inputs[field] = default - if required and not default: - errors[display_field] = ["Cannot be blank"] - continue - else: - if not isinstance(user_input, str) and not isinstance( - user_input, bool - ): - msg = ( - f"Input fields must have a boolean or string value. " - f"The value provided in the '{field}' field is of " - f"type {type(user_input).__name__}." - ) + _validate_field( + data, + inputs, + required_fields, + errors, + credential_type, + schema, + old_inputs, + aes_salt, + ) - errors[display_field] = [msg] - continue - else: - if required and len(user_input.strip()) == 0: - errors[display_field] = ["Cannot be blank"] - continue - if data.get("format") and user_input: - result = _validate_format( - schema=schema, - data_type=data.get("format"), - data=user_input, - inputs=inputs, - ) - if bool(result): - if PROTECTED_PASSPHRASE_ERROR in result: - errors["inputs.ssh_key_unlock"] = result - else: - errors[display_field] = result - - # We apply particular requirements on "host" when it is - # associated with a container registry. - if ( - (credential_type.name == enums.DefaultCredentialType.REGISTRY) - and (field == "host") - and user_input - ): - result = _validate_registry_host_name(user_input) - if bool(result): - errors[display_field] = result - - if field == "gpg_public_key": - result = _validate_gpg_public_key(user_input) - if bool(result): - errors[display_field] = result - - # Special validation for mTLS certificate and subject - if credential_type.name == enums.EventStreamCredentialType.MTLS: - # Also validate subject format if provided - if field == "subject" and user_input: - subject_errors = _validate_certificate_subject_format( - user_input - ) - if bool(subject_errors): - errors["inputs.subject"] = subject_errors + return errors - if data.get("type") == "boolean": - if user_input and not isinstance(user_input, bool): - errors[display_field] = ["Must be a boolean"] - continue - choices = data.get("choices") - if choices and user_input and user_input not in choices: - errors[display_field] = [f"Must be one of the choices: {choices}"] - continue +def _validate_field( + data, + inputs, + required_fields, + errors, + credential_type, + schema, + old_inputs, + aes_salt, +): + """Validate a single field against the schema.""" + field = data["id"] + required = field in required_fields + default = data.get("default") + user_input = inputs.get(field) + display_field = f"inputs.{field}" + + if user_input is None: + if default: + inputs[field] = default + if required and not default: + errors[display_field] = ["Cannot be blank"] + return + elif not isinstance(user_input, (str, bool)): + msg = ( + "Input fields must have a boolean or string " + f"value. The value provided in the '{field}' " + f"field is of type " + f"{type(user_input).__name__}." + ) + errors[display_field] = [msg] + return + elif required and len(user_input.strip()) == 0: + errors[display_field] = ["Cannot be blank"] + return + + _validate_field_format( + data, + field, + user_input, + inputs, + errors, + display_field, + credential_type, + schema, + ) - if data.get("format") == "aes_salt" and aes_salt is not None: - inputs[field] = aes_salt + if data.get("type") == "boolean": + if user_input and not isinstance(user_input, bool): + errors[display_field] = ["Must be a boolean"] + return + + choices = data.get("choices") + if choices and user_input and user_input not in choices: + errors[display_field] = [f"Must be one of the choices: {choices}"] + return + + _apply_aes_fields( + data, + field, + inputs, + errors, + display_field, + old_inputs, + aes_salt, + ) - if data.get("format") == "aes_key" and inputs.get(field): - if aes_salt is None: - errors[display_field] = [AES_SALT_ERROR] - else: - # If this is update preserve the old values - if old_inputs.get(field) != inputs.get(field): - inputs[field] = _get_aes_key(inputs[field], aes_salt) - return errors +def _apply_aes_fields( + data, field, inputs, errors, display_field, old_inputs, aes_salt +): + """Apply AES salt and key field processing.""" + if data.get("format") == "aes_salt" and aes_salt is not None: + inputs[field] = aes_salt + + if data.get("format") == "aes_key" and inputs.get(field): + if aes_salt is None: + errors[display_field] = [AES_SALT_ERROR] + # If this is update preserve the old values + elif old_inputs.get(field) != inputs.get(field): + inputs[field] = _get_aes_key(inputs[field], aes_salt) + + +def _validate_field_format( + data, + field, + user_input, + inputs, + errors, + display_field, + credential_type, + schema, +): + """Run format and type-specific validations.""" + _check_format_errors( + data, user_input, inputs, errors, display_field, schema + ) + + # We apply particular requirements on "host" when it + # is associated with a container registry. + if ( + credential_type.name == enums.DefaultCredentialType.REGISTRY + and field == "host" + and user_input + ): + result = _validate_registry_host_name(user_input) + if bool(result): + errors[display_field] = result + + if field == "gpg_public_key": + result = _validate_gpg_public_key(user_input) + if bool(result): + errors[display_field] = result + + # Special validation for mTLS certificate and subject + if ( + credential_type.name == enums.EventStreamCredentialType.MTLS + and field == "subject" + and user_input + ): + subject_errors = _validate_certificate_subject_format(user_input) + if bool(subject_errors): + errors["inputs.subject"] = subject_errors + + +def _check_format_errors( + data, user_input, inputs, errors, display_field, schema +): + """Check format-specific validation errors.""" + if not (data.get("format") and user_input): + return + result = _validate_format( + schema=schema, + data_type=data.get("format"), + data=user_input, + inputs=inputs, + ) + if not bool(result): + return + if PROTECTED_PASSPHRASE_ERROR in result: + errors["inputs.ssh_key_unlock"] = result + else: + errors[display_field] = result def validate_schema(schema: dict) -> list[str]: """Validate a credential schema. - Sample output: - [ - "label must exist and be a string", - "type must be either string or boolean" - ] - Return an empty list if no errors. """ errors = [] @@ -304,98 +366,123 @@ def validate_schema(schema: dict) -> list[str]: if not isinstance(fields, list): errors.append("'fields' must be a list") - else: - id_fields = _get_id_fields(schema) - metadata_fields = _get_id_fields(schema, "metadata") - duplicates = [] - uniqs = [] - for id in id_fields: - if id in uniqs: - duplicates.append(id) - else: - uniqs.append(id) - - if len(duplicates) > 0: - errors.append(f"Duplicate fields: {set(duplicates)} found") - - for id in id_fields: - if id.upper().startswith(EDA_PREFIX): - errors.append(f"{id} should not start with {EDA_PREFIX}") - - if not bool(re.match(r"^\w+$", id)): - errors.append( - f"{id} can only contain alphanumeric and " - "underscore characters" - ) - - formats_found = [] - for field in fields: - for option in ["id", "label"]: - value = field.get(option) - if not value or not isinstance(value, str): - errors.append(f"{option} must exist and be a string") - - field_type = field.get("type") - if field_type and field_type not in ["string", "boolean"]: - errors.append("type must be either string or boolean") - - choices = field.get("choices") - if choices: - if not isinstance(choices, list) or any( - not isinstance(choice, str) for choice in choices - ): - errors.append("choices must be a list of strings") - - for option in ["secret", "multiline"]: - value = field.get(option) - if value is not None and not isinstance(value, bool): - errors.append(f"{option} must be a boolean") - - for option in ["help_text", "format"]: - value = field.get(option) - if value is not None and not isinstance(value, str): - errors.append(f"{option} must be a string") - - default_value = field.get("default") - if default_value is not None: - if not isinstance(default_value, (str, bool)): - errors.append( - f"default for field '{field.get('id')}' " - "must be a string or boolean" - ) + return errors - field_format = field.get("format") - if field_format: - # Only validate format if it's a string (non-string types - # are caught by the validation above) - if isinstance(field_format, str): - formats_found.append(field_format) - if field_format not in VALID_FIELD_FORMATS: - errors.append( - f"invalid format: {field_format} for field " - f"{field.get('id')} " - "must be one of " - f"{' '.join(sorted(VALID_FIELD_FORMATS))}" - ) - - if "aes_key" in formats_found and "aes_salt" not in formats_found: - errors.append(AES_SALT_ERROR) + id_fields = _get_id_fields(schema) + metadata_fields = _get_id_fields(schema, "metadata") + _validate_field_ids(id_fields, errors) + formats_found = _validate_schema_fields(fields, errors) - required_fields = schema.get("required") - if required_fields: - if not isinstance(required_fields, list): - errors.append("required must be a list of strings") - else: - for field_id in required_fields: - if ( - field_id not in id_fields - and field_id not in metadata_fields - ): - errors.append(f"required field {field_id} does not exist") + if "aes_key" in formats_found and "aes_salt" not in formats_found: + errors.append(AES_SALT_ERROR) + _validate_required_fields(schema, id_fields, metadata_fields, errors) return errors +def _validate_field_ids(id_fields, errors): + """Check for duplicate and invalid field IDs.""" + seen = set() + duplicates = set() + for field_id in id_fields: + if field_id in seen: + duplicates.add(field_id) + seen.add(field_id) + + if duplicates: + errors.append(f"Duplicate fields: {duplicates} found") + + for field_id in id_fields: + if field_id.upper().startswith(EDA_PREFIX): + errors.append(f"{field_id} should not start with {EDA_PREFIX}") + if not bool(re.match(r"^\w+$", field_id)): + errors.append( + f"{field_id} can only contain alphanumeric " + "and underscore characters" + ) + + +def _validate_schema_fields(fields, errors): + """Validate individual field definitions in a schema.""" + formats_found = [] + for field in fields: + _validate_single_schema_field(field, errors, formats_found) + return formats_found + + +def _validate_single_schema_field(field, errors, formats_found): + """Validate a single field definition.""" + for option in ["id", "label"]: + value = field.get(option) + if not value or not isinstance(value, str): + errors.append(f"{option} must exist and be a string") + + field_type = field.get("type") + if field_type and field_type not in ["string", "boolean"]: + errors.append("type must be either string or boolean") + + choices = field.get("choices") + if choices and ( + not isinstance(choices, list) + or any(not isinstance(choice, str) for choice in choices) + ): + errors.append("choices must be a list of strings") + + for option, expected_type in [ + ("secret", bool), + ("multiline", bool), + ("help_text", str), + ("format", str), + ]: + value = field.get(option) + type_name = "a boolean" if expected_type is bool else "a string" + if value is not None and not isinstance(value, expected_type): + errors.append(f"{option} must be {type_name}") + + _validate_field_default(field, errors) + _validate_field_format_definition(field, errors, formats_found) + + +def _validate_field_default(field, errors): + default_value = field.get("default") + if default_value is not None and not isinstance( + default_value, (str, bool) + ): + errors.append( + f"default for field '{field.get('id')}' " + "must be a string or boolean" + ) + + +def _validate_field_format_definition(field, errors, formats_found): + field_format = field.get("format") + # Only validate format if it's a string + # (non-string types are caught by the validation above) + if not field_format or not isinstance(field_format, str): + return + formats_found.append(field_format) + if field_format not in VALID_FIELD_FORMATS: + errors.append( + f"invalid format: {field_format} " + f"for field {field.get('id')} " + "must be one of " + f"{' '.join(sorted(VALID_FIELD_FORMATS))}" + ) + + +def _validate_required_fields(schema, id_fields, metadata_fields, errors): + """Validate the required fields list.""" + required_fields = schema.get("required") + if not required_fields: + return + if not isinstance(required_fields, list): + errors.append("required must be a list of strings") + return + for field_id in required_fields: + if field_id not in id_fields and field_id not in metadata_fields: + errors.append(f"required field {field_id} does not exist") + + def validate_injectors(schema: dict, injectors: dict) -> dict: errors = [] @@ -412,42 +499,49 @@ def validate_injectors(schema: dict, injectors: dict) -> dict: context = _default_context(schema, injectors) key_names = [] for field in SUPPORTED_KEYS_IN_INJECTORS: - input_data = injectors.get(field) - if not input_data: - continue + _validate_injector_field(field, injectors, context, key_names, errors) - if not isinstance(input_data, dict): - errors.append(f"{field} must be a dict type") - continue + return {"injectors": errors} if bool(errors) else {} - try: - if field in ["extra_vars", "env"]: - check_reserved_keys_in_extra_vars(input_data) - except ValidationError as e: - errors.append(e.message) - continue - for k, v in input_data.items(): - try: - if k in key_names: - raise InjectorDuplicateKey( - f"Injector {field} key: {k} already exists" - ) +def _validate_injector_field(field, injectors, context, key_names, errors): + """Validate a single injector field.""" + input_data = injectors.get(field) + if not input_data: + return - if field == "file": - _validate_file_template_key(k, key_names) - if isinstance(v, str): - _check_jinja_string(v, context) - key_names.append(k) - except InjectorMissingKeyException as e: - errors.append( - f"Injector key: {k} has a value which refers to an" - f" undefined key error: {e}" - ) - except (InjectorInvalidTemplateKey, InjectorDuplicateKey) as e: - errors.append(f"{e}") + if not isinstance(input_data, dict): + errors.append(f"{field} must be a dict type") + return - return {"injectors": errors} if bool(errors) else {} + try: + if field in ["extra_vars", "env"]: + check_reserved_keys_in_extra_vars(input_data) + except ValidationError as e: + errors.append(e.message) + return + + for k, v in input_data.items(): + try: + if k in key_names: + raise InjectorDuplicateKey( + f"Injector {field} key: {k} already exists" + ) + if field == "file": + _validate_file_template_key(k, key_names) + if isinstance(v, str): + _check_jinja_string(v, context) + key_names.append(k) + except InjectorMissingKeyException as e: + errors.append( + f"Injector key: {k} has a value which " + f"refers to an undefined key error: {e}" + ) + except ( + InjectorInvalidTemplateKey, + InjectorDuplicateKey, + ) as e: + errors.append(f"{e}") def validate_registry_host_name(host: str) -> None: diff --git a/src/aap_eda/core/validators.py b/src/aap_eda/core/validators.py index e57c82825..879f75f51 100644 --- a/src/aap_eda/core/validators.py +++ b/src/aap_eda/core/validators.py @@ -70,16 +70,27 @@ def check_if_de_valid( image_url: str, eda_credential_id: tp.Optional[int] = None, ): - # The OCI standard format for the image url is a combination of a host - # (with optional port) separated from the image path (with optional tag) by - # a slash: [:port]/[:tag]. - # - # https://github.com/opencontainers/distribution-spec/blob/8376368dd8aadc33bf6c88a8b765df90287bb5c8/spec.md?plain=1#L155 # noqa: E501 - # - # We split the image url on the first slash into the host and path. The - # path is further split into a name and tag on the rightmost colon. - # - # The path and tag are validated using the OCI regexes for each. + host, name, tag, digest = _parse_image_url(image_url) + _validate_oci_name(image_url, name) + _validate_oci_tag(image_url, tag, digest) + if eda_credential_id: + _validate_credential_host(image_url, host, eda_credential_id) + + +def _parse_image_url(image_url): + """Parse an OCI image URL into host, name, tag, and digest flag. + + The OCI standard format for the image url is a combination of a + host (with optional port) separated from the image path (with + optional tag) by a slash: [:port]/[:tag]. + + https://github.com/opencontainers/distribution-spec/blob/8376368dd8aadc33bf6c88a8b765df90287bb5c8/spec.md?plain=1#L155 + + We split the image url on the first slash into the host and path. + The path is further split into a name and tag on the rightmost + colon. The path and tag are validated using the OCI regexes for + each. + """ split = image_url.split("/", 1) host = split[0] path = split[1] if len(split) > 1 else None @@ -93,15 +104,15 @@ def check_if_de_valid( try: validate_registry_host_name(host) except serializers.ValidationError: - # We raise our own instance of this exception in order to assert - # control over the format of the message. + # We raise our own instance of this exception in order to + # assert control over the format of the message. message = _( "Image url %(image_url)s is malformed; " "invalid host name: '%(host)s'" ) % {"image_url": image_url, "host": host} raise serializers.ValidationError({"image_url": message}) - if (path is None) or (path == ""): + if not path: message = _( "Image url %(image_url)s is malformed; no image path found" ) % {"image_url": image_url} @@ -110,18 +121,23 @@ def check_if_de_valid( digest = False if "@sha256" in path or "@sha512" in path: split = path.split("@", 1) - name = split[0] digest = True else: split = path.split(":", 1) - name = split[0] + name = split[0] # Get the tag sans any additional content. Any additional content # is passed without validation. - tag = split[1] if (len(split) > 1) else None - tag = tag if tag is None else tag.split("@", 1)[0] + tag = split[1] if len(split) > 1 else None + if tag is not None: + tag = tag.split("@", 1)[0] + + return host, name, tag, digest + +def _validate_oci_name(image_url, name): if not re.fullmatch( - r"[[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)*(\/[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)*)*", # noqa: E501 + r"[[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)" + r"*(\/[a-z0-9]+((\.|_|__|-+)[a-z0-9]+)*)*", name, ): message = _( @@ -130,48 +146,50 @@ def check_if_de_valid( ) % {"image_url": image_url, "name": name} raise serializers.ValidationError({"image_url": message}) - if (not digest and tag is not None) and ( - not re.fullmatch(r"[a-zA-Z0-9_][a-zA-Z0-9._-]{0,127}", tag) - ): + +def _validate_oci_tag(image_url, tag, digest): + if digest or tag is None: + return + if not re.fullmatch(r"\w[\w.-]{0,127}", tag): message = _( "Image url %(image_url)s is malformed; " "'%(tag)s' does not match OCI tag standard" ) % {"image_url": image_url, "tag": tag} raise serializers.ValidationError({"image_url": message}) - if eda_credential_id: - credential = get_credential_if_exists(eda_credential_id) - inputs = yaml.safe_load(credential.inputs.get_secret_value()) - credential_host = inputs.get("host") - - if not credential_host: - message = _( - "Credential %(name)s needs to have host information" - ) % {"name": credential.name} - raise serializers.ValidationError({"image_url": message}) - - # Check that the host matches the credential host. - # For backward compatibility when creating a new DE with - # an old credential we need to separate any - # scheme from the host before doing the compare. - parsed_credential_host = urllib.parse.urlparse(credential_host) - # If there's a netloc that's the host to use; if not, it's the path if - # there is no scheme else it's the scheme and path joined by a colon. - if parsed_credential_host.netloc: - parsed_host = parsed_credential_host.netloc - else: - parsed_host = parsed_credential_host.path - if parsed_credential_host.scheme: - parsed_host = ":".join( - [parsed_credential_host.scheme, parsed_host] - ) - - if host != parsed_host: - message = _( - "DecisionEnvironment image url: %(image_url)s does " - "not match with the credential host: %(host)s" - ) % {"image_url": image_url, "host": credential_host} - raise serializers.ValidationError({"image_url": message}) + +def _validate_credential_host(image_url, host, eda_credential_id): + """Validate that the image host matches the credential host.""" + credential = get_credential_if_exists(eda_credential_id) + inputs = yaml.safe_load(credential.inputs.get_secret_value()) + credential_host = inputs.get("host") + + if not credential_host: + message = _("Credential %(name)s needs to have host information") % { + "name": credential.name + } + raise serializers.ValidationError({"image_url": message}) + + # For backward compatibility when creating a new DE with + # an old credential we need to separate any + # scheme from the host before doing the compare. + parsed = urllib.parse.urlparse(credential_host) + # If there's a netloc that's the host to use; if not, it's the + # path if there is no scheme else it's the scheme and path + # joined by a colon. + if parsed.netloc: + parsed_host = parsed.netloc + elif parsed.scheme: + parsed_host = ":".join([parsed.scheme, parsed.path]) + else: + parsed_host = parsed.path + + if host != parsed_host: + message = _( + "DecisionEnvironment image url: %(image_url)s does " + "not match with the credential host: %(host)s" + ) % {"image_url": image_url, "host": credential_host} + raise serializers.ValidationError({"image_url": message}) def get_credential_if_exists(eda_credential_id: int) -> models.EdaCredential: diff --git a/src/aap_eda/services/activation/activation_manager.py b/src/aap_eda/services/activation/activation_manager.py index d3867d809..63b4e2da9 100644 --- a/src/aap_eda/services/activation/activation_manager.py +++ b/src/aap_eda/services/activation/activation_manager.py @@ -786,7 +786,7 @@ def delete(self): """User requested delete.""" LOGGER.info( f"Delete operation requested for activation id: " - f"{self.db_instance.id},", + f"{self.db_instance.id}", ) try: self._cleanup() @@ -896,25 +896,8 @@ def monitor(self): self._detect_running_status() - # get the status of the container - container_status = None - try: - container_status = self.container_engine.get_status( - container_id=self.latest_instance.activation_pod_id, - ) - except engine_exceptions.ContainerNotFoundError: - pass - except engine_exceptions.ContainerEngineError as exc: - msg = ( - f"Monitor operation: activation id: {self.db_instance.id} " - f"Failed to get status of the container. Reason: {exc}" - ) - LOGGER.warning(msg) - self.set_status(ActivationStatus.WORKERS_OFFLINE, msg) - self.set_latest_instance_status( - ActivationStatus.WORKERS_OFFLINE, - msg, - ) + container_status, engine_error = self._get_container_status() + if engine_error: return # Activations in running status must have a container @@ -998,6 +981,38 @@ def monitor(self): "is in an stopped state.", ) + def _get_container_status(self): + """Get the container status for the latest instance. + + Returns a (status, engine_error) tuple. When the container + engine reports a transient error the activation is moved to + WORKERS_OFFLINE and engine_error is True so the caller can + exit early. + """ + try: + return ( + self.container_engine.get_status( + container_id=self.latest_instance.activation_pod_id, + ), + False, + ) + except engine_exceptions.ContainerNotFoundError: + return None, False + except engine_exceptions.ContainerEngineError as exc: + msg = ( + f"Monitor operation: activation id: " + f"{self.db_instance.id} " + f"Failed to get status of the container. " + f"Reason: {exc}" + ) + LOGGER.warning(msg) + self.set_status(ActivationStatus.WORKERS_OFFLINE, msg) + self.set_latest_instance_status( + ActivationStatus.WORKERS_OFFLINE, + msg, + ) + return None, True + def _detect_running_status(self): if ( self.latest_instance.status == ActivationStatus.STARTING diff --git a/src/aap_eda/services/activation/engine/kubernetes.py b/src/aap_eda/services/activation/engine/kubernetes.py index 3495ee327..8f6ced516 100644 --- a/src/aap_eda/services/activation/engine/kubernetes.py +++ b/src/aap_eda/services/activation/engine/kubernetes.py @@ -661,6 +661,7 @@ def _process_pod_start_event(self, event) -> bool: return False def _wait_for_pod_to_start(self, log_handler: LogHandler) -> None: + """Wait for the pod to reach a running state.""" LOGGER.info("Waiting for pod to start") desc = f"watch pod start {self.job_name}" last_exc = None diff --git a/src/aap_eda/services/activation/engine/ports.py b/src/aap_eda/services/activation/engine/ports.py index bcad9e1a4..9e4ed2353 100644 --- a/src/aap_eda/services/activation/engine/ports.py +++ b/src/aap_eda/services/activation/engine/ports.py @@ -19,54 +19,44 @@ def render_string(value: str, context: dict) -> str: def find_ports(rulebook_text: str, context: dict = None) -> list[tuple]: - """ - Return (host, port) pairs for all sources in a rulebook. - - Walk the rulebook and find ports in source parameters - Assume the rulebook is valid if it imported - """ + """Return (host, port) pairs for all sources in a rulebook.""" rulebook = yaml.safe_load(rulebook_text) - - # Make a list of host, port pairs found in all sources in - # rulesets in a rulebook found_ports = [] - - # Walk all rulesets in a rulebook for ruleset in rulebook: - # Walk through all sources in a ruleset for source in ruleset.get("sources", []): - # Remove name from source - if "name" in source: - del source["name"] - # The first remaining key is the type and the arguments - source_plugin = list(source.keys())[0] - - if source_plugin not in settings.SAFE_PLUGINS_FOR_PORT_FORWARD: - continue - - source_args = source[source_plugin] - if source_args is None: - continue - # Get host if it exists - # Maybe check for "0.0.0.0" in the future - host = source_args.get("host") - # Get port if it exists - maybe_port = source_args.get("port") - # port may be a string or an integer - if maybe_port is None: - continue - - try: - maybe_port = render_string(str(maybe_port), context or {}) + result = _extract_port(source, context or {}) + if result is not None: + found_ports.append(result) + return found_ports - with contextlib.suppress(ValueError): - found_ports.append((host, int(maybe_port))) - except ValueError as e: - LOGGER.error(f"find_ports error: {e}") - raise exceptions.ActivationStartError(str(e)) - except UndefinedError as e: - raise exceptions.ActivationStartError(str(e)) - except SecurityError as e: - raise exceptions.ActivationStartError(str(e)) - return found_ports +def _extract_port(source, context): + """Extract a (host, port) pair from a single source.""" + if "name" in source: + del source["name"] + # The first remaining key is the type and the arguments + source_plugin = next(iter(source)) + + if source_plugin not in settings.SAFE_PLUGINS_FOR_PORT_FORWARD: + return None + + source_args = source[source_plugin] + if source_args is None: + return None + + host = source_args.get("host") + # port may be a string or an integer + maybe_port = source_args.get("port") + if maybe_port is None: + return None + + try: + maybe_port = render_string(str(maybe_port), context) + with contextlib.suppress(ValueError): + return (host, int(maybe_port)) + return None + except ValueError as e: + LOGGER.exception(f"find_ports error: {e}") + raise exceptions.ActivationStartError(str(e)) + except (UndefinedError, SecurityError) as e: + raise exceptions.ActivationStartError(str(e)) diff --git a/src/aap_eda/services/project/imports.py b/src/aap_eda/services/project/imports.py index 0507d03b1..eff01b70a 100644 --- a/src/aap_eda/services/project/imports.py +++ b/src/aap_eda/services/project/imports.py @@ -227,38 +227,44 @@ def _sync_rulebook( ) def _find_rulebooks(self, repo: StrPath) -> Iterator[RulebookInfo]: - rulebooks_dir = None - for name in ["extensions/eda/rulebooks", "rulebooks"]: - if os.path.exists(os.path.join(repo, name)): - rulebooks_dir = os.path.join(repo, name) - break - - if not rulebooks_dir: - raise ProjectImportWarning( - "The 'extensions/eda/rulebooks' or 'rulebooks' directory" - " doesn't exist within the project root." - ) + rulebooks_dir = self._locate_rulebooks_dir(repo) for root, _dirs, files in os.walk(rulebooks_dir): for filename in files: - path = os.path.join(root, filename) _base, ext = os.path.splitext(filename) if ext not in YAML_EXTENSIONS: continue - try: - info = self._try_load_rulebook(rulebooks_dir, path) - except Exception: - logger.error( - "Unexpected exception when scanning file %s." - " Skipping.", - path, - exc_info=settings.DEBUG, - ) - continue - if not info: - logger.warning("Not a rulebook file: %s", path) - continue - yield info + path = os.path.join(root, filename) + info = self._safe_load_rulebook(rulebooks_dir, path) + if info: + yield info + + def _locate_rulebooks_dir(self, repo: StrPath) -> str: + for name in ["extensions/eda/rulebooks", "rulebooks"]: + candidate = os.path.join(repo, name) + if os.path.exists(candidate): + return candidate + raise ProjectImportWarning( + "The 'extensions/eda/rulebooks' or 'rulebooks' directory" + " doesn't exist within the project root." + ) + + def _safe_load_rulebook( + self, rulebooks_dir: StrPath, path: str + ) -> Optional[RulebookInfo]: + try: + info = self._try_load_rulebook(rulebooks_dir, path) + except Exception: + logger.error( + "Unexpected exception when scanning file %s. Skipping.", + path, + exc_info=settings.DEBUG, + ) + return None + if not info: + logger.warning("Not a rulebook file: %s", path) + return None + return info def _try_load_rulebook( self, rulebooks_dir: StrPath, rulebook_path: StrPath diff --git a/src/aap_eda/services/project/scm.py b/src/aap_eda/services/project/scm.py index 014bb1181..f107ba525 100644 --- a/src/aap_eda/services/project/scm.py +++ b/src/aap_eda/services/project/scm.py @@ -179,43 +179,13 @@ def clone( os.makedirs(path) extra_vars = {"project_path": path} env_vars = {} - final_url = url - secret = "" - key_file = None - key_password = None - gpg_key_file = None - gpg_home_dir = None with set_proxy_environ(proxy): - if credential: - inputs = credentials.get_resolved_secrets(credential) - secret = inputs.get("password", "") - key_data = inputs.get("ssh_key_data", "") - - final_url = cls.build_url( - url, - inputs.get("username", ""), - secret, - key_data, - ) - - if key_data: # ssh - key_file = tempfile.NamedTemporaryFile("w+t") - key_file.write(key_data) - key_file.write("\n") - key_file.flush() - extra_vars["key_file"] = key_file.name - key_password = inputs.get("ssh_key_unlock") - - if gpg_credential: - gpg_inputs = credentials.get_resolved_secrets(gpg_credential) - gpg_key = gpg_inputs.get("gpg_public_key") - gpg_key_file = tempfile.NamedTemporaryFile("w+t") - gpg_key_file.write(gpg_key) - gpg_key_file.write("\n") - gpg_key_file.flush() - extra_vars["verify_commit"] = "true" - gpg_home_dir = tempfile.TemporaryDirectory() - env_vars["GNUPGHOME"] = gpg_home_dir.name + final_url, secret, key_file, key_password = cls._setup_credential( + url, credential, extra_vars + ) + gpg_key_file, gpg_home_dir = cls._setup_gpg_credential( + gpg_credential, extra_vars, env_vars + ) if not verify_ssl: extra_vars["ssl_no_verify"] = "true" @@ -242,18 +212,7 @@ def clone( with contextlib.chdir(path): git_hash = _executor(extra_vars=extra_vars, env_vars=env_vars) except ScmError as e: - msg = str(e) - # Replace credential-embedded URL with clean URL instead - # of redacting the password substring (which would create - # an oracle attack vector — see AAP-72813). - if final_url != url: - msg = msg.replace(final_url, url) - # Handle URL-decoded form for passwords with special chars - # (e.g., "p@ss" is encoded as "p%40ss" in final_url, but - # git may print the decoded form in error messages). - if secret and secret != quote(secret, safe=""): - raw_url = final_url.replace(quote(secret, safe=""), secret) - msg = msg.replace(raw_url, url) + msg = cls._sanitize_clone_error(str(e), url, final_url, secret) logger.warning("SCM clone failed: %s", msg) raise e.__class__(msg) from None finally: @@ -266,6 +225,66 @@ def clone( instance.git_hash = git_hash return instance + @classmethod + def _setup_credential(cls, url, credential, extra_vars): + """Set up SCM credential, returning URL and key info.""" + final_url = url + secret = "" + key_file = None + key_password = None + if credential: + inputs = credentials.get_resolved_secrets(credential) + secret = inputs.get("password", "") + key_data = inputs.get("ssh_key_data", "") + + final_url = cls.build_url( + url, + inputs.get("username", ""), + secret, + key_data, + ) + + if key_data: # ssh + key_file = tempfile.NamedTemporaryFile("w+t") + key_file.write(key_data) + key_file.write("\n") + key_file.flush() + extra_vars["key_file"] = key_file.name + key_password = inputs.get("ssh_key_unlock") + return final_url, secret, key_file, key_password + + @classmethod + def _setup_gpg_credential(cls, gpg_credential, extra_vars, env_vars): + """Set up GPG credential for commit verification.""" + gpg_key_file = None + gpg_home_dir = None + if gpg_credential: + gpg_inputs = credentials.get_resolved_secrets(gpg_credential) + gpg_key = gpg_inputs.get("gpg_public_key") + gpg_key_file = tempfile.NamedTemporaryFile("w+t") + gpg_key_file.write(gpg_key) + gpg_key_file.write("\n") + gpg_key_file.flush() + extra_vars["verify_commit"] = "true" + gpg_home_dir = tempfile.TemporaryDirectory() + env_vars["GNUPGHOME"] = gpg_home_dir.name + return gpg_key_file, gpg_home_dir + + @classmethod + def _sanitize_clone_error(cls, msg, url, final_url, secret): + """Redact credentials from clone error messages. + + Replaces credential-embedded URL with clean URL instead + of redacting the password substring (which would create + an oracle attack vector — see AAP-72813). + """ + if final_url != url: + msg = msg.replace(final_url, url) + if secret and secret != quote(secret, safe=""): + raw_url = final_url.replace(quote(secret, safe=""), secret) + msg = msg.replace(raw_url, url) + return msg + @classmethod def build_url( cls, url: str, user: str, password: str, ssh_key: str @@ -284,7 +303,9 @@ def build_url( if user and password: encoded_user = quote(user, safe="") encoded_password = quote(password, safe="") - domain = f"{encoded_user}:{encoded_password}@{domain}" + domain = ( + f"{encoded_user}:{encoded_password}@{domain}" # noqa: E231 + ) elif password: encoded_token = quote(password, safe="") domain = f"{encoded_token}@{domain}" diff --git a/src/aap_eda/tasks/activation_request_queue.py b/src/aap_eda/tasks/activation_request_queue.py index 495eaba0d..f57770867 100644 --- a/src/aap_eda/tasks/activation_request_queue.py +++ b/src/aap_eda/tasks/activation_request_queue.py @@ -86,41 +86,45 @@ def _arbitrate( ref_request = request continue - # nothing can be done after delete - # or dedup - # or skip auto_start - if ( - ref_request.request == ActivationRequest.DELETE - or request.request == ref_request.request - or request.request == ActivationRequest.AUTO_START - ): - request.delete() - continue - - if ref_request.request == ActivationRequest.AUTO_START: - ref_request.delete() - ref_request = request - continue - - if ( - request.request == ActivationRequest.STOP - or request.request == ActivationRequest.DELETE - ): - while qualified_requests: - qualified = qualified_requests.pop() - qualified.delete() - ref_request.delete() - ref_request = request - continue - - if request.request in starts and ref_request.request in starts: - request.delete() - continue - - qualified_requests.append(ref_request) - ref_request = request + ref_request = _resolve_request_pair( + ref_request, request, qualified_requests, starts + ) if ref_request: qualified_requests.append(ref_request) return qualified_requests + + +def _resolve_request_pair(ref_request, request, qualified_requests, starts): + """Resolve a pair of activation requests, returning the new reference.""" + # nothing can be done after delete + # or dedup + # or skip auto_start + if ( + ref_request.request == ActivationRequest.DELETE + or request.request == ref_request.request + or request.request == ActivationRequest.AUTO_START + ): + request.delete() + return ref_request + + if ref_request.request == ActivationRequest.AUTO_START: + ref_request.delete() + return request + + if request.request in ( + ActivationRequest.STOP, + ActivationRequest.DELETE, + ): + while qualified_requests: + qualified_requests.pop().delete() + ref_request.delete() + return request + + if request.request in starts and ref_request.request in starts: + request.delete() + return ref_request + + qualified_requests.append(ref_request) + return request diff --git a/src/aap_eda/tasks/orchestrator.py b/src/aap_eda/tasks/orchestrator.py index 14f90a9a0..a39f308c2 100644 --- a/src/aap_eda/tasks/orchestrator.py +++ b/src/aap_eda/tasks/orchestrator.py @@ -266,105 +266,15 @@ def queue_dispatch( return else: - queue_name = get_queue_name_by_parent_id( + queue_name = _resolve_existing_queue( process_parent_type, process_parent_id, + request_type, + process_parent, + status_manager, ) - - # If there is not an associated queue or the queue does not exist - # within the configured queues (i.e., it is from a previous deployment - # with different queues) we get a queue to use. - if (not queue_name) or ( - queue_name not in settings.RULEBOOK_WORKER_QUEUES - ): - if not queue_name: - LOGGER.info( - "Scheduling request " - f"{request_type} for {process_parent_type} " - f"{process_parent_id} to the least busy queue; " - "it is not currently associated with a queue.", - ) - else: - LOGGER.info( - "Scheduling request" - f"{request_type} for {process_parent_type} " - f"{process_parent_id} to the least busy queue; " - f"its associated queue '{queue_name}' is from " - "previous configuation settings.", - ) - try: - queue_name = get_least_busy_queue_name() - except HealthyQueueNotFoundError: - msg = ( - f"There are no healthy queues to process operation " - f"{request_type} for {process_parent_type} " - f"{process_parent_id}. Waiting for a worker. " - "There may be an issue with the system; please " - "contact the administrator." - ) - LOGGER.error(msg) - status_manager.set_status( - ActivationStatus.PENDING, - msg, - ) - return - elif not check_rulebook_queue_health(queue_name): - # The queue is unhealthy. If we're not restarting it there's - # nothing we can do except update its status to WORKERS_OFFLINE. - if request_type != ActivationRequest.RESTART: - # A process in PENDING status don't need to update its status. - # A monitor can be scheduled for an activation in PENDING - # status if its latest process is in workers-offline status - # and it is scheduled for restart. - if process_parent.status == ActivationStatus.PENDING: - return - - # If the process is in WORKERS_OFFLINE status, it is already - # in a bad state. We don't need to update its status. - if process_parent.status == ActivationStatus.WORKERS_OFFLINE: - return - - msg = ( - f"{process_parent_type} {process_parent_id} is in an " - "unknown state. The workers of its associated queue " - f"'{queue_name}' are failing liveness checks. " - "There may be an issue with the worker node; " - "please contact the administrator." - ) - status_manager.set_status( - ActivationStatus.WORKERS_OFFLINE, - msg, - ) - status_manager.set_latest_instance_status( - ActivationStatus.WORKERS_OFFLINE, - msg, - ) - LOGGER.error(msg) - return - - # The queue is unhealthy, but this is a restart. - # The priority is to adhere to the restart policy and - # execute the task. - LOGGER.warning( - f"Forcing user restart of {process_parent_type} " - f"{process_parent_id} on the least busy queue; " - "after failing liveness checks of current associated queue" - ) - try: - queue_name = get_least_busy_queue_name() - except HealthyQueueNotFoundError: - msg = ( - f"There are no healthy queues to process the " - f"restart request for {process_parent_type} " - f"{process_parent_id}. There may be an issue " - "with the system; please contact the administrator." - ) - LOGGER.error(msg) - status_manager.set_status( - ActivationStatus.PENDING, - msg, - ) - return + if queue_name is None: + return LOGGER.info( f"Trying to enqueue {process_parent_type} {process_parent_id} " f"request {request_type} to queue {queue_name}" @@ -385,6 +295,151 @@ def queue_dispatch( ) +def _resolve_existing_queue( + process_parent_type, + process_parent_id, + request_type, + process_parent, + status_manager, +): + """Resolve the queue for an existing process. + + Returns the queue name, or None if the request cannot + be dispatched. + """ + queue_name = get_queue_name_by_parent_id( + process_parent_type, + process_parent_id, + ) + + # If there is not an associated queue or the queue does not + # exist within the configured queues (i.e., it is from a + # previous deployment with different queues) we get a queue + # to use. + if not queue_name or (queue_name not in settings.RULEBOOK_WORKER_QUEUES): + if not queue_name: + LOGGER.info( + "Scheduling request " + f"{request_type} for {process_parent_type} " + f"{process_parent_id} to the least busy queue; " + "it is not currently associated with a queue.", + ) + else: + LOGGER.info( + "Scheduling request" + f"{request_type} for {process_parent_type} " + f"{process_parent_id} to the least busy queue; " + f"its associated queue '{queue_name}' is from " + "previous configuation settings.", + ) + try: + return get_least_busy_queue_name() + except HealthyQueueNotFoundError: + msg = ( + f"There are no healthy queues to process " + f"operation {request_type} for " + f"{process_parent_type} " + f"{process_parent_id}. Waiting for a worker. " + "There may be an issue with the system; " + "please contact the administrator." + ) + LOGGER.error(msg) + status_manager.set_status( + ActivationStatus.PENDING, + msg, + ) + return None + + if not check_rulebook_queue_health(queue_name): + return _handle_unhealthy_queue( + queue_name, + process_parent_type, + process_parent_id, + request_type, + process_parent, + status_manager, + ) + + return queue_name + + +def _handle_unhealthy_queue( + queue_name, + process_parent_type, + process_parent_id, + request_type, + process_parent, + status_manager, +): + """Handle dispatch when the associated queue is unhealthy. + + Returns a fallback queue name, or None if the request + cannot be dispatched. + """ + # The queue is unhealthy. If we're not restarting it there's + # nothing we can do except update its status to + # WORKERS_OFFLINE. + if request_type != ActivationRequest.RESTART: + # A process in PENDING status don't need to update its + # status. A monitor can be scheduled for an activation + # in PENDING status if its latest process is in + # workers-offline status and it is scheduled for restart. + # + # If the process is in WORKERS_OFFLINE status, it is + # already in a bad state. We don't need to update its + # status. + if process_parent.status in ( + ActivationStatus.PENDING, + ActivationStatus.WORKERS_OFFLINE, + ): + return None + + msg = ( + f"{process_parent_type} {process_parent_id} is " + "in an unknown state. The workers of its " + f"associated queue '{queue_name}' are failing " + "liveness checks. There may be an issue with " + "the worker node; please contact the " + "administrator." + ) + status_manager.set_status( + ActivationStatus.WORKERS_OFFLINE, + msg, + ) + status_manager.set_latest_instance_status( + ActivationStatus.WORKERS_OFFLINE, + msg, + ) + LOGGER.error(msg) + return None + + # The queue is unhealthy, but this is a restart. + # The priority is to adhere to the restart policy and + # execute the task. + LOGGER.warning( + f"Forcing user restart of {process_parent_type} " + f"{process_parent_id} on the least busy queue; " + "after failing liveness checks of current " + "associated queue" + ) + try: + return get_least_busy_queue_name() + except HealthyQueueNotFoundError: + msg = ( + f"There are no healthy queues to process the " + f"restart request for {process_parent_type} " + f"{process_parent_id}. There may be an issue " + "with the system; please contact the " + "administrator." + ) + LOGGER.error(msg) + status_manager.set_status( + ActivationStatus.PENDING, + msg, + ) + return None + + def get_least_busy_queue_name() -> str: """Return the queue name with the least running processes.""" queue_counter = Counter() diff --git a/src/aap_eda/utils/openapi.py b/src/aap_eda/utils/openapi.py index 41e6cf7ad..6a67e9bd3 100644 --- a/src/aap_eda/utils/openapi.py +++ b/src/aap_eda/utils/openapi.py @@ -17,41 +17,47 @@ from drf_spectacular.utils import OpenApiParameter from rest_framework.serializers import Serializer +_FIELD_TYPE_MAP = { + models.CharField: OpenApiTypes.STR, + models.IntegerField: OpenApiTypes.NUMBER, + models.DateField: OpenApiTypes.DATETIME, + models.BooleanField: OpenApiTypes.BOOL, +} -def generate_query_params(serializer: Serializer) -> list[OpenApiParameter]: - """Generate OpenAPI query parameters dynamically based on the view's serializer fields and model.""" # noqa: E501 + +def _get_openapi_type(field): + for field_cls, api_type in _FIELD_TYPE_MAP.items(): + if isinstance(field, field_cls): + return api_type + return OpenApiTypes.STR + + +def _resolve_param_name(name, field_names): + if name in field_names: + return name + id_name = f"{name}_id" + if id_name in field_names: + return id_name + return None + + +def generate_query_params( + serializer: Serializer, +) -> list[OpenApiParameter]: + """Generate OpenAPI query parameters dynamically.""" query_params = [] model = serializer.Meta.model - fields = serializer.get_fields() - field_names = fields.keys() + field_names = set(serializer.get_fields().keys()) for field in model._meta.get_fields(): - # check if model field name is defined in the serializer - if ( - field.name in field_names - or "_".join([field.name, "id"]) in field_names - ): - param_name = ( - field.name - if field.name in field_names - else "_".join([field.name, "id"]) - ) - query_params.append( - OpenApiParameter( - name=param_name, - description=f"Filter by {param_name}", - required=False, - type=( - OpenApiTypes.STR - if isinstance(field, models.CharField) - else OpenApiTypes.NUMBER - if isinstance(field, models.IntegerField) - else OpenApiTypes.DATETIME - if isinstance(field, models.DateField) - else OpenApiTypes.BOOL - if isinstance(field, models.BooleanField) - else OpenApiTypes.STR - ), - ) + param_name = _resolve_param_name(field.name, field_names) + if param_name is None: + continue + query_params.append( + OpenApiParameter( + name=param_name, + description=f"Filter by {param_name}", + required=False, + type=_get_openapi_type(field), ) - + ) return query_params diff --git a/src/aap_eda/wsapi/consumers.py b/src/aap_eda/wsapi/consumers.py index edc19c3cc..1d785267f 100644 --- a/src/aap_eda/wsapi/consumers.py +++ b/src/aap_eda/wsapi/consumers.py @@ -273,16 +273,7 @@ def insert_event_related_data(self, message: AnsibleEventMessage) -> None: @database_sync_to_async def insert_audit_rule_data(self, message: ActionMessage) -> None: - job_instance_id = None - if message.job_id: - job_instance = models.JobInstance.objects.filter( - uuid=message.job_id - ).first() - job_instance_id = job_instance.id if job_instance else None - - audit_rule = models.AuditRule.objects.filter( - rule_uuid=message.rule_uuid, fired_at=message.rule_run_at - ).first() + job_instance_id = self._resolve_job_instance_id(message) try: activation_instance = models.RulebookProcess.objects.get( @@ -292,6 +283,30 @@ def insert_audit_rule_data(self, message: ActionMessage) -> None: logger.error(f"RulebookProcess {message.activation_id} not found") raise + audit_rule = self._get_or_create_audit_rule( + message, activation_instance, job_instance_id + ) + audit_action = self._get_or_create_audit_action( + message, activation_instance, audit_rule + ) + self._process_matching_events(message, audit_action) + + def _resolve_job_instance_id(self, message): + if not message.job_id: + return None + job_instance = models.JobInstance.objects.filter( + uuid=message.job_id + ).first() + return job_instance.id if job_instance else None + + def _get_or_create_audit_rule( + self, message, activation_instance, job_instance_id + ): + audit_rule = models.AuditRule.objects.filter( + rule_uuid=message.rule_uuid, + fired_at=message.rule_run_at, + ).first() + if audit_rule is None: activation_org = models.Organization.objects.filter( id=activation_instance.organization.id @@ -307,71 +322,76 @@ def insert_audit_rule_data(self, message: ActionMessage) -> None: status=message.status, organization=activation_org, ) - logger.info(f"Audit rule [{audit_rule.name}] is created.") - else: - # if rule has multiple actions and one of its action's status is - # 'failed', keep rule's status as 'failed' - if ( - audit_rule.status != message.status - and audit_rule.status != "failed" - ): - audit_rule.status = message.status - audit_rule.save() + # if rule has multiple actions and one of its action's + # status is 'failed', keep rule's status as 'failed' + elif ( + audit_rule.status != message.status + and audit_rule.status != "failed" + ): + audit_rule.status = message.status + audit_rule.save() + + return audit_rule + def _get_or_create_audit_action( + self, message, activation_instance, audit_rule + ): audit_action = models.AuditAction.objects.filter( id=message.action_uuid ).first() - if audit_action is None: - inputs = {} - aap_credential_type = models.CredentialType.objects.filter( - name=DefaultCredentialType.AAP - ) - if aap_credential_type: - credentials = ( - activation_instance.get_parent().eda_credentials.filter( - credential_type_id=aap_credential_type[0].id - ) - ) - if credentials: - inputs = get_resolved_secrets(credentials[0]) - - url = self._get_url(message, inputs) - audit_action = models.AuditAction.objects.create( - id=message.action_uuid, - fired_at=message.run_at, - name=message.action, - url=url, - status=message.status, - rule_fired_at=message.rule_run_at, - audit_rule_id=audit_rule.id, - status_message=message.message, - ) + if audit_action is not None: + return audit_action + + inputs = self._resolve_aap_inputs(activation_instance) + url = self._get_url(message, inputs) + audit_action = models.AuditAction.objects.create( + id=message.action_uuid, + fired_at=message.run_at, + name=message.action, + url=url, + status=message.status, + rule_fired_at=message.rule_run_at, + audit_rule_id=audit_rule.id, + status_message=message.message, + ) + logger.info(f"Audit action [{audit_action.name}] is created.") + return audit_action - logger.info(f"Audit action [{audit_action.name}] is created.") + def _resolve_aap_inputs(self, activation_instance): + aap_credential_type = models.CredentialType.objects.filter( + name=DefaultCredentialType.AAP + ) + if not aap_credential_type: + return {} + credentials = activation_instance.get_parent().eda_credentials.filter( + credential_type_id=aap_credential_type[0].id + ) + if credentials: + return get_resolved_secrets(credentials[0]) + return {} - matching_events = message.matching_events - for event_meta in matching_events.values(): + def _process_matching_events(self, message, audit_action): + for event_meta in message.matching_events.values(): meta = event_meta.pop("meta") - if meta: - audit_event = models.AuditEvent.objects.filter( - id=meta.get("uuid") - ).first() - - if audit_event is None: - audit_event = models.AuditEvent.objects.create( - id=meta.get("uuid"), - source_name=meta.get("source", {}).get("name"), - source_type=meta.get("source", {}).get("type"), - payload=event_meta, - received_at=meta.get("received_at"), - rule_fired_at=message.rule_run_at, - ) - logger.info(f"Audit event [{audit_event.id}] is created.") - - audit_event.audit_actions.add(audit_action) - audit_event.save() + if not meta: + continue + audit_event = models.AuditEvent.objects.filter( + id=meta.get("uuid") + ).first() + if audit_event is None: + audit_event = models.AuditEvent.objects.create( + id=meta.get("uuid"), + source_name=meta.get("source", {}).get("name"), + source_type=meta.get("source", {}).get("type"), + payload=event_meta, + received_at=meta.get("received_at"), + rule_fired_at=message.rule_run_at, + ) + logger.info(f"Audit event [{audit_event.id}] is created.") + audit_event.audit_actions.add(audit_action) + audit_event.save() @database_sync_to_async def insert_job_related_data( diff --git a/tests/unit/test_orchestrator.py b/tests/unit/test_orchestrator.py index 34fa03cd3..d255d0ab5 100644 --- a/tests/unit/test_orchestrator.py +++ b/tests/unit/test_orchestrator.py @@ -20,12 +20,19 @@ import pytest from django.conf import settings -from aap_eda.core.enums import ActivationStatus, ProcessParentType +from aap_eda.core.enums import ( + ActivationRequest, + ActivationStatus, + ProcessParentType, +) from aap_eda.core.models import Activation, RulebookProcess from aap_eda.tasks import orchestrator from aap_eda.tasks.exceptions import UnknownProcessParentType from aap_eda.tasks.orchestrator import ( + HealthyQueueNotFoundError, + _handle_unhealthy_queue, _manage, + _resolve_existing_queue, get_least_busy_queue_name, get_process_parent, ) @@ -737,3 +744,247 @@ def test_manage_monitor_runs_when_queue_returns_none( ) manager_mock.monitor.assert_called_once() + + +################################################################# +# Tests for _resolve_existing_queue and _handle_unhealthy_queue +################################################################# + + +@pytest.mark.django_db +def test_resolve_existing_queue_no_queue_assigned(): + """When no queue is assigned, fall back to least busy.""" + status_manager = mock.Mock() + with mock.patch( + "aap_eda.tasks.orchestrator.get_queue_name_by_parent_id", + return_value=None, + ), mock.patch( + "aap_eda.tasks.orchestrator.get_least_busy_queue_name", + return_value="healthy-queue", + ): + result = _resolve_existing_queue( + ProcessParentType.ACTIVATION, + 1, + "Monitor", + mock.Mock(), + status_manager, + ) + + assert result == "healthy-queue" + + +@pytest.mark.django_db +def test_resolve_existing_queue_stale_queue(): + """When queue is not in configured queues, fall back to least busy.""" + status_manager = mock.Mock() + with mock.patch( + "aap_eda.tasks.orchestrator.get_queue_name_by_parent_id", + return_value="old-queue", + ), mock.patch( + "aap_eda.tasks.orchestrator.settings" + ) as mock_settings, mock.patch( + "aap_eda.tasks.orchestrator.get_least_busy_queue_name", + return_value="healthy-queue", + ): + mock_settings.RULEBOOK_WORKER_QUEUES = ["activation"] + result = _resolve_existing_queue( + ProcessParentType.ACTIVATION, + 1, + "Monitor", + mock.Mock(), + status_manager, + ) + + assert result == "healthy-queue" + + +@pytest.mark.django_db +def test_resolve_existing_queue_no_healthy_queues(): + """When no healthy queues exist, set PENDING and return None.""" + status_manager = mock.Mock() + with mock.patch( + "aap_eda.tasks.orchestrator.get_queue_name_by_parent_id", + return_value=None, + ), mock.patch( + "aap_eda.tasks.orchestrator.get_least_busy_queue_name", + side_effect=HealthyQueueNotFoundError, + ): + result = _resolve_existing_queue( + ProcessParentType.ACTIVATION, + 1, + "Monitor", + mock.Mock(), + status_manager, + ) + + assert result is None + status_manager.set_status.assert_called_once_with( + ActivationStatus.PENDING, mock.ANY + ) + + +@pytest.mark.django_db +def test_resolve_existing_queue_healthy(): + """When assigned queue is healthy, return it.""" + status_manager = mock.Mock() + with mock.patch( + "aap_eda.tasks.orchestrator.get_queue_name_by_parent_id", + return_value="activation", + ), mock.patch( + "aap_eda.tasks.orchestrator.settings" + ) as mock_settings, mock.patch( + "aap_eda.tasks.orchestrator.check_rulebook_queue_health", + return_value=True, + ): + mock_settings.RULEBOOK_WORKER_QUEUES = ["activation"] + result = _resolve_existing_queue( + ProcessParentType.ACTIVATION, + 1, + "Monitor", + mock.Mock(), + status_manager, + ) + + assert result == "activation" + + +@pytest.mark.django_db +def test_resolve_existing_queue_unhealthy_delegates(): + """When queue is unhealthy, delegate to _handle_unhealthy_queue.""" + status_manager = mock.Mock() + process_parent = mock.Mock() + with mock.patch( + "aap_eda.tasks.orchestrator.get_queue_name_by_parent_id", + return_value="activation", + ), mock.patch( + "aap_eda.tasks.orchestrator.settings" + ) as mock_settings, mock.patch( + "aap_eda.tasks.orchestrator.check_rulebook_queue_health", + return_value=False, + ), mock.patch( + "aap_eda.tasks.orchestrator._handle_unhealthy_queue", + return_value="fallback-queue", + ) as mock_handle: + mock_settings.RULEBOOK_WORKER_QUEUES = ["activation"] + result = _resolve_existing_queue( + ProcessParentType.ACTIVATION, + 1, + "Monitor", + process_parent, + status_manager, + ) + + assert result == "fallback-queue" + mock_handle.assert_called_once() + + +@pytest.mark.django_db +def test_handle_unhealthy_queue_non_restart_pending(): + """Non-restart with PENDING parent returns None without update.""" + process_parent = mock.Mock() + process_parent.status = ActivationStatus.PENDING + status_manager = mock.Mock() + + result = _handle_unhealthy_queue( + "activation", + ProcessParentType.ACTIVATION, + 1, + "Monitor", + process_parent, + status_manager, + ) + + assert result is None + status_manager.set_status.assert_not_called() + + +@pytest.mark.django_db +def test_handle_unhealthy_queue_non_restart_workers_offline(): + """Non-restart with WORKERS_OFFLINE parent returns None.""" + process_parent = mock.Mock() + process_parent.status = ActivationStatus.WORKERS_OFFLINE + status_manager = mock.Mock() + + result = _handle_unhealthy_queue( + "activation", + ProcessParentType.ACTIVATION, + 1, + "Monitor", + process_parent, + status_manager, + ) + + assert result is None + status_manager.set_status.assert_not_called() + + +@pytest.mark.django_db +def test_handle_unhealthy_queue_non_restart_running(): + """Non-restart with RUNNING parent sets WORKERS_OFFLINE.""" + process_parent = mock.Mock() + process_parent.status = ActivationStatus.RUNNING + status_manager = mock.Mock() + + result = _handle_unhealthy_queue( + "activation", + ProcessParentType.ACTIVATION, + 1, + "Monitor", + process_parent, + status_manager, + ) + + assert result is None + status_manager.set_status.assert_called_once_with( + ActivationStatus.WORKERS_OFFLINE, mock.ANY + ) + status_manager.set_latest_instance_status.assert_called_once_with( + ActivationStatus.WORKERS_OFFLINE, mock.ANY + ) + + +@pytest.mark.django_db +def test_handle_unhealthy_queue_restart_fallback(): + """Restart request falls back to least busy queue.""" + process_parent = mock.Mock() + status_manager = mock.Mock() + + with mock.patch( + "aap_eda.tasks.orchestrator.get_least_busy_queue_name", + return_value="healthy-queue", + ): + result = _handle_unhealthy_queue( + "activation", + ProcessParentType.ACTIVATION, + 1, + ActivationRequest.RESTART, + process_parent, + status_manager, + ) + + assert result == "healthy-queue" + + +@pytest.mark.django_db +def test_handle_unhealthy_queue_restart_no_healthy(): + """Restart with no healthy queues sets PENDING.""" + process_parent = mock.Mock() + status_manager = mock.Mock() + + with mock.patch( + "aap_eda.tasks.orchestrator.get_least_busy_queue_name", + side_effect=HealthyQueueNotFoundError, + ): + result = _handle_unhealthy_queue( + "activation", + ProcessParentType.ACTIVATION, + 1, + ActivationRequest.RESTART, + process_parent, + status_manager, + ) + + assert result is None + status_manager.set_status.assert_called_once_with( + ActivationStatus.PENDING, mock.ANY + )