Dynavera/apps/onboarding/tests/test_consumers.py
2026-03-18 01:04:16 +00:00

192 lines
7.6 KiB
Python

import json
from unittest.mock import AsyncMock, patch

from asgiref.sync import async_to_sync
from django.contrib.auth import get_user_model
from django.test import TestCase

from apps.accounts.models import Organization, Role
from apps.onboarding.consumers import OnboardingGenerateConsumer
from apps.onboarding.models import AgentConfig
from apps.onboarding.utils.content_moderator import ContentModerator
# The project's configured user model, resolved once for every test case below.
User = get_user_model()

# A single known-profane token used to trip the content moderator; kept in one
# place so the tests never repeat the literal.
_PROFANE: str = "fuck"
class ContentModeratorTests(TestCase):
    """Exercise ``ContentModerator.is_clean`` and ``ContentModerator.censor`` directly."""

    def setUp(self):
        self.moderator = ContentModerator()

    def test_clean_text_passes(self):
        """Ordinary onboarding language is reported as clean."""
        sample = "What is the onboarding process?"
        self.assertTrue(self.moderator.is_clean(sample))

    def test_profane_text_blocked(self):
        """Text containing the profane token is reported as not clean."""
        sample = f"this is {_PROFANE} content"
        self.assertFalse(self.moderator.is_clean(sample))

    def test_non_string_input_passes(self):
        """Non-string values (None, ints) are reported as clean."""
        for value in (None, 123):
            self.assertTrue(self.moderator.is_clean(value))

    def test_censor_replaces_profanity(self):
        """``censor`` output no longer contains the profane token."""
        censored = self.moderator.censor(f"this is {_PROFANE} content")
        self.assertNotIn(_PROFANE, censored)

    def test_censor_passes_clean_text_unchanged(self):
        """``censor`` returns text with nothing to mask unchanged."""
        original = "Please review the onboarding materials."
        censored = self.moderator.censor(original)
        self.assertEqual(censored, original)

    def test_censor_non_string_returned_as_is(self):
        """``censor`` hands non-string inputs back untouched."""
        self.assertIsNone(self.moderator.censor(None))
        number = 42
        self.assertEqual(self.moderator.censor(number), number)
class ConsumerModerationTests(TestCase):
    """Check that ``OnboardingGenerateConsumer.receive`` screens user text before dispatching.

    Every test stubs ``action_start_full_onboarding`` so the real onboarding
    action never runs here. This holds even if moderation regresses, and it lets
    the blocked-input tests assert that nothing was dispatched.
    """

    def setUp(self):
        self.user = User.objects.create_user(
            email_address='moderation-test@example.com',
            password='pass1234',
            first_name='Mod',
            last_name='Tester',
            date_of_birth='1995-05-05',
            is_manager=True,
        )
        self.org = Organization.objects.create(name='Moderation Test Org', owner=self.user)
        self.org.members.add(self.user)
        self.role = Role.objects.create(name='Mod Role', organization=self.org)
        self.consumer = OnboardingGenerateConsumer()
        self.consumer.user = self.user
        # Capture error reports instead of sending them over a real socket.
        self.consumer.send_error = AsyncMock()

    def _run_receive(self, payload: str):
        """Drive the async ``receive`` handler synchronously with a raw text frame."""
        return async_to_sync(self.consumer.receive)(payload)

    def _receive_with_stubbed_action(self, data: dict) -> AsyncMock:
        """Send *data* as a JSON frame with the onboarding action stubbed.

        Returns the ``AsyncMock`` that stood in for ``action_start_full_onboarding``
        so callers can assert whether it was dispatched.
        """
        with patch.object(
            self.consumer, 'action_start_full_onboarding', new=AsyncMock()
        ) as mock_action:
            self._run_receive(json.dumps(data))
        return mock_action

    def _assert_dispatched(self, mock_action: AsyncMock) -> None:
        """The action ran exactly once and no error was reported."""
        mock_action.assert_called_once()
        self.consumer.send_error.assert_not_called()

    def _assert_blocked_by_moderation(self, mock_action: AsyncMock) -> None:
        """The action never ran and a single moderation error was reported."""
        mock_action.assert_not_called()
        self.consumer.send_error.assert_called_once()
        message = self.consumer.send_error.call_args.args[0]
        self.assertIn("moderation", message.lower())

    def test_clean_query_is_dispatched(self):
        mock_action = self._receive_with_stubbed_action(
            {"action": "start_full_onboarding", "query": "Tell me about onboarding"}
        )
        self._assert_dispatched(mock_action)

    def test_profane_query_is_blocked(self):
        mock_action = self._receive_with_stubbed_action(
            {"action": "start_full_onboarding", "query": f"this is {_PROFANE} content"}
        )
        self._assert_blocked_by_moderation(mock_action)

    def test_profane_message_field_is_blocked(self):
        mock_action = self._receive_with_stubbed_action(
            {"action": "start_full_onboarding", "message": f"this is {_PROFANE} content"}
        )
        self._assert_blocked_by_moderation(mock_action)

    def test_clean_message_field_is_dispatched(self):
        mock_action = self._receive_with_stubbed_action(
            {"action": "start_full_onboarding", "message": "begin onboarding"}
        )
        self._assert_dispatched(mock_action)
class OnboardingConsumerConfigSelectionTests(TestCase):
    """Tests for the consumer's AgentConfig lookup and quiz-payload helpers."""

    def setUp(self):
        self.user = User.objects.create_user(
            email_address='consumer-test@example.com',
            password='pass1234',
            first_name='Consumer',
            last_name='Tester',
            date_of_birth='1992-02-02',
            is_manager=True,
        )
        self.org = Organization.objects.create(name='Consumer Test Org', owner=self.user)
        self.org.members.add(self.user)
        self.quant_role = Role.objects.create(name='Quant Role Consumer', organization=self.org)
        self.ux_role = Role.objects.create(name='UX Role Consumer', organization=self.org)
        self.consumer = OnboardingGenerateConsumer()

    def test_get_config_by_type_prefers_exact_role(self):
        """A config bound to the requested role wins over a sibling role's config."""
        quant_cfg = AgentConfig.objects.create(
            organization=self.org,
            role=self.quant_role,
            name='Quant Curriculum Override',
            agent_type='curriculum',
            system_prompt='Quant-specific prompt',
        )
        AgentConfig.objects.create(
            organization=self.org,
            role=self.ux_role,
            name='UX Curriculum Override',
            agent_type='curriculum',
            system_prompt='UX-specific prompt',
        )

        selected = async_to_sync(self.consumer.get_config_by_type)(
            str(self.quant_role.uuid), 'curriculum'
        )

        self.assertIsNotNone(selected)
        self.assertEqual(selected.uuid, quant_cfg.uuid)
        self.assertEqual(selected.role_id, self.quant_role.id)

    def test_get_config_by_type_falls_back_to_org_default(self):
        """With no role-scoped config, the org-level (role=None) config is used."""
        # Remove any role-scoped monitor config that may already exist (e.g. one
        # created automatically with the role; unverified) so the lookup must fall back.
        AgentConfig.objects.filter(role=self.quant_role, agent_type='monitor').delete()
        org_default = AgentConfig.objects.create(
            organization=self.org,
            role=None,
            name='Org Monitor Default',
            agent_type='monitor',
            system_prompt='Organization-level monitor prompt',
        )

        selected = async_to_sync(self.consumer.get_config_by_type)(
            str(self.quant_role.uuid), 'monitor'
        )

        self.assertIsNotNone(selected)
        self.assertEqual(selected.uuid, org_default.uuid)
        self.assertIsNone(selected.role)

    def test_extract_json_list_supports_wrapped_questions_payload(self):
        """A fenced ```json block wrapping ``{"questions": [...]}`` yields the inner list."""
        payload = (
            "Here is your quiz output:\n"
            "```json\n"
            '{"questions": [{"key": "q1", "label": "Question?", "field_type": "select", "options": ["A", "B"], "required": true, "validation": {"correct_option": "A", "explanation": "A"}}]}\n'
            "```"
        )

        extracted = self.consumer._extract_json_list(payload)

        self.assertIsInstance(extracted, list)
        self.assertEqual(len(extracted), 1)
        self.assertEqual(extracted[0]['key'], 'q1')

    def test_build_fallback_quiz_fields_generates_eight_valid_questions(self):
        """The fallback quiz has 8 items and at least 2 select and 2 short-answer questions.

        Each select item needs at least 4 options and a ``correct_option`` drawn
        from them. Each short-answer item needs a non-empty ``accepted_answers``
        list. Per-item ``subTest`` blocks identify the offending question on
        failure instead of reporting a bare "False is not true".
        """
        fallback = self.consumer._build_fallback_quiz_fields(['Topic A', 'Topic B'])
        self.assertEqual(len(fallback), 8)

        select_items = [item for item in fallback if item.get('field_type') == 'select']
        short_answer_items = [
            item
            for item in fallback
            if item.get('field_type') in ('text', 'textarea')
        ]
        self.assertGreaterEqual(len(select_items), 2)
        self.assertGreaterEqual(len(short_answer_items), 2)

        for index, item in enumerate(select_items):
            with self.subTest(kind='select', index=index, key=item.get('key')):
                options = item.get('options', [])
                self.assertGreaterEqual(len(options), 4)
                self.assertIn(item.get('validation', {}).get('correct_option'), options)

        for index, item in enumerate(short_answer_items):
            with self.subTest(kind='short_answer', index=index, key=item.get('key')):
                accepted = item.get('validation', {}).get('accepted_answers', [])
                self.assertIsInstance(accepted, list)
                self.assertGreater(len(accepted), 0)