"""Tests for the onboarding content moderator and the onboarding generate consumer."""

import json
from unittest.mock import AsyncMock, patch

from asgiref.sync import async_to_sync
from django.contrib.auth import get_user_model
from django.test import TestCase

from apps.accounts.models import Organization, Role
from apps.onboarding.consumers import OnboardingGenerateConsumer
from apps.onboarding.models import AgentConfig
from apps.onboarding.utils.content_moderator import ContentModerator

User = get_user_model()

# Sample profanity used as the "should be flagged" input throughout these tests.
_PROFANE = "fuck"


class ContentModeratorTests(TestCase):
    """Exercise ``ContentModerator.is_clean`` and ``ContentModerator.censor``."""

    def setUp(self):
        self.moderator = ContentModerator()

    def test_clean_text_passes(self):
        """Ordinary text is reported as clean."""
        self.assertTrue(self.moderator.is_clean("What is the onboarding process?"))

    def test_profane_text_blocked(self):
        """Text containing the sample profanity is reported as not clean."""
        self.assertFalse(self.moderator.is_clean(f"this is {_PROFANE} content"))

    def test_non_string_input_passes(self):
        """Non-string inputs are treated as clean rather than raising."""
        self.assertTrue(self.moderator.is_clean(None))
        self.assertTrue(self.moderator.is_clean(123))

    def test_censor_replaces_profanity(self):
        """The censored output no longer contains the profane word."""
        result = self.moderator.censor(f"this is {_PROFANE} content")
        self.assertNotIn(_PROFANE, result)

    def test_censor_passes_clean_text_unchanged(self):
        """Clean text comes back from ``censor`` unmodified."""
        text = "Please review the onboarding materials."
        self.assertEqual(self.moderator.censor(text), text)

    def test_censor_non_string_returned_as_is(self):
        """Non-string inputs are returned from ``censor`` untouched."""
        self.assertIsNone(self.moderator.censor(None))
        self.assertEqual(self.moderator.censor(42), 42)


class ConsumerModerationTests(TestCase):
    """Check how ``OnboardingGenerateConsumer.receive`` reacts to clean and profane text.

    The tests expect profane ``query`` / ``message`` fields to be reported
    through ``send_error`` with a message mentioning "moderation", and clean
    ones to reach ``action_start_full_onboarding``.
    """

    # Action name sent in every payload built by these tests.
    _ACTION = "start_full_onboarding"

    def setUp(self):
        self.user = User.objects.create_user(
            email_address='moderation-test@example.com',
            password='pass1234',
            first_name='Mod',
            last_name='Tester',
            date_of_birth='1995-05-05',
            is_manager=True,
        )
        self.org = Organization.objects.create(name='Moderation Test Org', owner=self.user)
        self.org.members.add(self.user)
        self.role = Role.objects.create(name='Mod Role', organization=self.org)
        self.consumer = OnboardingGenerateConsumer()
        self.consumer.user = self.user
        # Replace the error channel with a mock so each test can inspect
        # whether (and with what message) an error was reported.
        self.consumer.send_error = AsyncMock()

    def _run_receive(self, payload: str):
        """Drive the async ``receive`` handler synchronously with *payload*."""
        return async_to_sync(self.consumer.receive)(payload)

    def _payload(self, field: str, text: str) -> str:
        """Return the JSON message carrying *text* under *field*."""
        return json.dumps({"action": self._ACTION, field: text})

    def _assert_dispatched(self, field: str, text: str):
        """Send clean *text* in *field* and assert the action handler ran without errors."""
        with patch.object(
            self.consumer, 'action_start_full_onboarding', new=AsyncMock()
        ) as mock_action:
            self._run_receive(self._payload(field, text))
        mock_action.assert_called_once()
        self.consumer.send_error.assert_not_called()

    def _assert_blocked(self, field: str):
        """Send profane text in *field* and assert a single moderation error was reported."""
        self._run_receive(self._payload(field, f"this is {_PROFANE} content"))
        self.consumer.send_error.assert_called_once()
        args = self.consumer.send_error.call_args[0]
        self.assertIn("moderation", args[0].lower())

    def test_clean_query_is_dispatched(self):
        self._assert_dispatched("query", "Tell me about onboarding")

    def test_profane_query_is_blocked(self):
        self._assert_blocked("query")

    def test_profane_message_field_is_blocked(self):
        self._assert_blocked("message")

    def test_clean_message_field_is_dispatched(self):
        self._assert_dispatched("message", "begin onboarding")


class OnboardingConsumerConfigSelectionTests(TestCase):
    """Cover agent-config lookup and the quiz JSON helpers on the consumer."""

    def setUp(self):
        self.user = User.objects.create_user(
            email_address='consumer-test@example.com',
            password='pass1234',
            first_name='Consumer',
            last_name='Tester',
            date_of_birth='1992-02-02',
            is_manager=True,
        )
        self.org = Organization.objects.create(name='Consumer Test Org', owner=self.user)
        self.org.members.add(self.user)
        self.quant_role = Role.objects.create(name='Quant Role Consumer', organization=self.org)
        self.ux_role = Role.objects.create(name='UX Role Consumer', organization=self.org)
        self.consumer = OnboardingGenerateConsumer()

    def test_get_config_by_type_prefers_exact_role(self):
        """A config bound to the requested role wins over another role's config."""
        quant_cfg = AgentConfig.objects.create(
            organization=self.org,
            role=self.quant_role,
            name='Quant Curriculum Override',
            agent_type='curriculum',
            system_prompt='Quant-specific prompt',
        )
        AgentConfig.objects.create(
            organization=self.org,
            role=self.ux_role,
            name='UX Curriculum Override',
            agent_type='curriculum',
            system_prompt='UX-specific prompt',
        )

        selected = async_to_sync(self.consumer.get_config_by_type)(
            str(self.quant_role.uuid), 'curriculum'
        )

        self.assertIsNotNone(selected)
        self.assertEqual(selected.uuid, quant_cfg.uuid)
        self.assertEqual(selected.role_id, self.quant_role.id)

    def test_get_config_by_type_falls_back_to_org_default(self):
        """With no role-specific config, the role-less organization config is used."""
        # Ensure no role-level monitor config exists, so only the
        # organization-level one created below can match.
        AgentConfig.objects.filter(role=self.quant_role, agent_type='monitor').delete()
        org_default = AgentConfig.objects.create(
            organization=self.org,
            role=None,
            name='Org Monitor Default',
            agent_type='monitor',
            system_prompt='Organization-level monitor prompt',
        )

        selected = async_to_sync(self.consumer.get_config_by_type)(
            str(self.quant_role.uuid), 'monitor'
        )

        self.assertIsNotNone(selected)
        self.assertEqual(selected.uuid, org_default.uuid)
        self.assertIsNone(selected.role)

    def test_extract_json_list_supports_wrapped_questions_payload(self):
        """A fenced JSON object with a ``questions`` key yields the question list."""
        payload = (
            "Here is your quiz output:\n"
            "```json\n"
            '{"questions": [{"key": "q1", "label": "Question?", "field_type": "select", "options": ["A", "B"], "required": true, "validation": {"correct_option": "A", "explanation": "A"}}]}\n'
            "```"
        )

        extracted = self.consumer._extract_json_list(payload)

        self.assertIsInstance(extracted, list)
        self.assertEqual(len(extracted), 1)
        self.assertEqual(extracted[0]['key'], 'q1')

    def test_build_fallback_quiz_fields_generates_eight_valid_questions(self):
        """The fallback quiz has eight items with well-formed select and short-answer fields."""
        fallback = self.consumer._build_fallback_quiz_fields(['Topic A', 'Topic B'])
        self.assertEqual(len(fallback), 8)

        select_items = [item for item in fallback if item.get('field_type') == 'select']
        short_answer_items = [
            item for item in fallback if item.get('field_type') in ('text', 'textarea')
        ]
        self.assertGreaterEqual(len(select_items), 2)
        self.assertGreaterEqual(len(short_answer_items), 2)

        # Check items one at a time so a failure identifies the offending item
        # instead of reporting a bare "False is not true".
        for item in select_items:
            with self.subTest(select_item=item):
                options = item.get('options', [])
                self.assertGreaterEqual(len(options), 4)
                self.assertIn(item.get('validation', {}).get('correct_option'), options)

        for item in short_answer_items:
            with self.subTest(short_answer_item=item):
                accepted = item.get('validation', {}).get('accepted_answers', [])
                self.assertIsInstance(accepted, list)
                self.assertGreater(len(accepted), 0)