#!/usr/bin/env python3
"""Tests for DPO Preference Pair Extractor."""

import pytest

from cortex.dpo_extractor import (
    detect_correction,
    is_false_positive,
    extract_user_text,
    extract_assistant_text,
    get_session_id,
    clean_user_text,
    build_dpo_pair,
    extract_dpo_pairs,
    to_dpo_training_format,
    to_detailed_format,
    MIN_PROMPT_LEN,
    MIN_RESPONSE_LEN,
    MIN_CORRECTION_LEN,
)


# --- Correction Detection ---

class TestDetectCorrection:
    """Test correction pattern detection."""

    def test_direct_negation_with_instruction(self):
        """A 'nein, du sollst ...' message yields a direct_correction label."""
        result = detect_correction("nein, du sollst ollama anhalten und das model trainieren")
        assert result is not None
        label, conf = result
        assert label == 'direct_correction'
        assert conf >= 0.85

    def test_redirect(self):
        """A 'nicht so, sondern ...' message yields a redirect label."""
        result = detect_correction("nicht so, sondern mit apply_chat_template")
        assert result is not None
        assert result[0] == 'redirect'

    def test_factual_correction(self):
        """A 'das stimmt nicht' message yields a factual_correction label."""
        result = detect_correction("das stimmt nicht, die GPU hat 20GB VRAM")
        assert result is not None
        assert result[0] == 'factual_correction'

    def test_frustration_correction(self):
        """An openly frustrated message is labelled with confidence >= 0.9."""
        result = detect_correction("bist du bescheuert? Wir haben doch alles da")
        assert result is not None
        assert result[0] == 'frustration_correction'
        assert result[1] >= 0.9

    def test_forgotten(self):
        """A 'du hast vergessen' message yields a forgotten label."""
        result = detect_correction("du hast vergessen warum wir Darkplex gebaut haben")
        assert result is not None
        assert result[0] == 'forgotten'

    def test_repeated_instruction(self):
        """An 'ich habe doch gesagt' message yields a repeated_instruction label."""
        result = detect_correction("ich habe doch gesagt wir sollen Gemma nehmen")
        assert result is not None
        assert result[0] == 'repeated_instruction'

    def test_mild_redirect(self):
        """A 'naja, eher so' message yields a mild_redirect label."""
        result = detect_correction("naja, eher so dass wir die Pipeline automatisieren")
        assert result is not None
        assert result[0] == 'mild_redirect'

    def test_negative_feedback(self):
        """A 'finds blöd' message yields a negative_feedback label."""
        result = detect_correction("finds blöd dass ich dich immer so testen muss")
        assert result is not None
        assert result[0] == 'negative_feedback'

    def test_should_redirect(self):
        """A 'du solltest eher' message yields a should_redirect label."""
        result = detect_correction("du solltest eher den bestehenden Code checken statt neu zu bauen")
        assert result is not None
        assert result[0] == 'should_redirect'

    # --- Non-corrections ---

    def test_normal_message_no_correction(self):
        """An ordinary request is not reported as a correction."""
        result = detect_correction("kannst du mir den Status von Ollama zeigen?")
        assert result is None

    def test_positive_feedback_no_correction(self):
        """Positive feedback is not reported as a correction."""
        result = detect_correction("super, genau so wollte ich das haben!")
        assert result is None

    def test_short_message_no_correction(self):
        """A bare 'ja' is not reported as a correction."""
        result = detect_correction("ja")
        assert result is None

    def test_question_no_correction(self):
        """A plain factual question is not reported as a correction."""
        result = detect_correction("wie viel VRAM hat die GPU auf Desktop01?")
        assert result is None

    def test_timestamp_prefix_stripped(self):
        """A leading '[weekday date time tz]' prefix does not hide the correction."""
        result = detect_correction("[Thu 2026-02-12 09:17 GMT+1] nein, du sollst ollama anhalten")
        assert result is not None
        assert result[0] == 'direct_correction'

    def test_matrix_prefix_stripped(self):
        """A 'System: [...] Matrix message from <user>:' prefix does not hide the correction."""
        result = detect_correction(
            "System: [2026-02-12 09:17:00 GMT+1] Matrix message from albert: "
            "nein, du sollst das anders machen, ich meinte die andere Config"
        )
        assert result is not None


class TestFalsePositives:
    """Ensure we don't flag system messages as corrections."""

    def test_system_message(self):
        assert is_false_positive("System: heartbeat check ok") is True

    def test_heartbeat(self):
        assert is_false_positive("Read HEARTBEAT.md if it exists") is True

    def test_media_attachment(self):
        assert is_false_positive("[media attached: /home/keller/file.pdf]") is True

    def test_cron_message(self):
        assert is_false_positive("[cron:abc-123] Run learning context") is True

    def test_exec_result(self):
        assert is_false_positive("exec completed with code 0") is True

    def test_precompaction(self):
        assert is_false_positive("Pre-compaction memory flush. Store durable memories now") is True

    def test_normal_message_not_false_positive(self):
        assert is_false_positive("lass uns ein anderes model wählen") is False


# --- Event Parsing ---

class TestEventParsing:
    """Test extraction from various event formats."""

    def test_extract_user_text_preview(self):
        """User text is read from a 'text_preview' list of text parts."""
        event = {
            'type': 'conversation.message.in',
            'payload': {
                'text_preview': [{'type': 'text', 'text': 'hallo welt'}],
            }
        }
        assert extract_user_text(event) == 'hallo welt'

    def test_extract_user_content(self):
        """User text is read from a plain 'content' string."""
        event = {
            'type': 'conversation.message.in',
            'payload': {'content': 'hallo welt'},
        }
        assert extract_user_text(event) == 'hallo welt'

    def test_extract_user_wrong_type(self):
        """An outgoing message yields no user text."""
        event = {
            'type': 'conversation.message.out',
            'payload': {'content': 'response'},
        }
        assert extract_user_text(event) is None

    def test_extract_assistant_data_text(self):
        """Assistant text is read from payload['data']['text']."""
        event = {
            'type': 'conversation.message.out',
            'payload': {'data': {'text': 'here is my response'}},
        }
        assert extract_assistant_text(event) == 'here is my response'

    def test_extract_assistant_content(self):
        """Assistant text is read from a plain 'content' string."""
        event = {
            'type': 'conversation.message.out',
            'payload': {'content': 'response text'},
        }
        assert extract_assistant_text(event) == 'response text'

    def test_session_id_from_payload(self):
        """The session id is taken from payload['sessionId']."""
        event = {
            'type': 'conversation.message.in',
            'payload': {'sessionId': 'abc-123'},
        }
        assert get_session_id(event) == 'abc-123'

    def test_session_id_from_run_id(self):
        """The session id is taken from payload['runId'] when that is what the payload carries."""
        event = {
            'type': 'conversation.message.out',
            'payload': {'runId': 'run-456'},
        }
        assert get_session_id(event) == 'run-456'


class TestCleanUserText:
    """Test removal of transport prefixes and suffixes from user text."""

    def test_strip_timestamp(self):
        assert clean_user_text("[Thu 2026-02-12 09:17 GMT+1] hallo") == "hallo"

    def test_strip_matrix_prefix(self):
        text = "System: [2026-02-12 09:17:00 GMT+1] Matrix message from albert: hallo"
        assert clean_user_text(text) == "hallo"

    def test_strip_message_id(self):
        text = "some text\n[message_id: abc-123]"
        assert clean_user_text(text) == "some text"

    def test_plain_text_unchanged(self):
        assert clean_user_text("just a normal message") == "just a normal message"


# --- DPO Pair Building ---

def _make_event(type_: str, text: str, seq: int = 1, session: str = 'test') -> dict:
    """Helper to create test events.

    Incoming messages ('conversation.message.in') carry the text as a
    'text_preview' part list and the session under 'sessionId'; every other
    type carries it under payload['data']['text'] with the session as 'runId'.
    The timestamp is derived from ``seq`` so events sort in sequence order.
    """
    if type_ == 'conversation.message.in':
        payload = {
            'text_preview': [{'type': 'text', 'text': text}],
            'sessionId': session,
        }
    else:
        payload = {
            'data': {'text': text},
            'runId': session,
        }
    return {
        'type': type_,
        'seq': seq,
        'timestamp': 1000000 + seq,
        'payload': payload,
    }


class TestBuildDPOPair:
    """Test building a single pair from prompt, rejected response and correction."""

    def test_valid_pair(self):
        """A well-formed triple yields a pair whose 'chosen' is still unset."""
        prompt = _make_event(
            'conversation.message.in',
            'Kannst du mir ein Training-Script für Gemma erstellen?',
            seq=1,
        )
        rejected = _make_event(
            'conversation.message.out',
            'Hier ist ein Script mit hardcoded chat template tags... ' * 5,
            seq=2,
        )
        correction = _make_event(
            'conversation.message.in',
            'nein, du sollst tokenizer.apply_chat_template benutzen statt hardcoded tags',
            seq=3,
        )
        pair = build_dpo_pair(prompt, rejected, correction, 'direct_correction', 0.9)
        assert pair is not None
        assert 'Gemma' in pair['prompt']
        assert 'hardcoded' in pair['rejected']
        assert pair['chosen'] is None  # Caller fills this
        assert pair['metadata']['correction_type'] == 'direct_correction'

    def test_too_short_prompt(self):
        """A three-character prompt produces no pair."""
        prompt = _make_event('conversation.message.in', 'ja?', seq=1)
        rejected = _make_event('conversation.message.out', 'x' * 100, seq=2)
        correction = _make_event(
            'conversation.message.in',
            'nein, ich meinte etwas ganz anderes',
            seq=3,
        )
        pair = build_dpo_pair(prompt, rejected, correction, 'direct_correction', 0.9)
        assert pair is None

    def test_too_short_response(self):
        """A three-character rejected response produces no pair."""
        prompt = _make_event('conversation.message.in', 'Erstelle mir einen DPO Extraktor', seq=1)
        rejected = _make_event('conversation.message.out', 'Ok.', seq=2)
        correction = _make_event(
            'conversation.message.in',
            'das ist falsch, mach das bitte richtig',
            seq=3,
        )
        pair = build_dpo_pair(prompt, rejected, correction, 'factual_correction', 0.85)
        assert pair is None


# --- Full Pipeline ---

class TestExtractDPOPairs:
    """Test end-to-end pair extraction from an event stream."""

    def test_basic_correction_flow(self):
        """Test: user → assistant → correction → better response."""
        events = [
            _make_event(
                'conversation.message.in',
                'Erstelle mir ein Training Script für Gemma 2 auf der 7800 XT',
                seq=1,
            ),
            _make_event(
                'conversation.message.out',
                'Hier ist das Script mit user tags... ' * 5,
                seq=2,
            ),
            _make_event(
                'conversation.message.in',
                'nein, du sollst tokenizer.apply_chat_template() benutzen, nicht hardcoded tags',
                seq=3,
            ),
            _make_event(
                'conversation.message.out',
                'Du hast recht, hier ist die korrigierte Version mit apply_chat_template()... ' * 5,
                seq=4,
            ),
        ]
        pairs, stats = extract_dpo_pairs(events)
        assert len(pairs) >= 1
        assert pairs[0].get('chosen') is not None
        assert stats['corrections_found'] >= 1

    def test_no_corrections(self):
        """Normal conversation without corrections → no pairs."""
        events = [
            _make_event('conversation.message.in', 'Was ist der Status von Ollama?', seq=1),
            _make_event(
                'conversation.message.out',
                'Ollama läuft und hat das claudia-memory Modell geladen. ' * 5,
                seq=2,
            ),
            _make_event('conversation.message.in', 'Super, danke für die Info!', seq=3),
        ]
        pairs, _stats = extract_dpo_pairs(events)
        assert len(pairs) == 0

    def test_multiple_sessions(self):
        """Corrections from different sessions are handled independently."""
        events = [
            _make_event(
                'conversation.message.in',
                'Mach mir ein neues Feature für Darkplex',
                seq=1,
                session='s1',
            ),
            _make_event(
                'conversation.message.out',
                'Ich baue jetzt ein komplett neues System dafür... ' * 5,
                seq=2,
                session='s1',
            ),
            _make_event(
                'conversation.message.in',
                'bist du bescheuert? Wir haben doch alles da, schau erst was existiert',
                seq=3,
                session='s1',
            ),
            _make_event(
                'conversation.message.in',
                'Zeig mir den Wetterbericht',
                seq=4,
                session='s2',
            ),
            _make_event(
                'conversation.message.out',
                'Das Wetter in Berlin ist sonnig bei 5 Grad... ' * 5,
                seq=5,
                session='s2',
            ),
        ]
        pairs, _stats = extract_dpo_pairs(events)
        # Only session s1 should produce a pair
        assert len(pairs) == 1
        assert pairs[0]['metadata']['session'] == 's1'


class TestOutputFormats:
    """Test conversion of extracted pairs into the output formats."""

    def test_training_format_only_complete(self):
        """Training format only includes pairs with both chosen and rejected."""
        pairs = [
            {
                'prompt': 'q1',
                'chosen': 'good answer',
                'rejected': 'bad answer',
                'correction': 'fix it',
                'metadata': {},
            },
            {
                'prompt': 'q2',
                'chosen': None,
                'rejected': 'wrong',
                'correction': 'nope',
                'metadata': {},
            },
        ]
        training = to_dpo_training_format(pairs)
        assert len(training) == 1
        assert training[0]['prompt'] == 'q1'

    def test_detailed_format_all_pairs(self):
        """Detailed format includes all pairs with metadata."""
        pairs = [
            {
                'prompt': 'q1',
                'chosen': 'good',
                'rejected': 'bad',
                'correction': 'fix',
                'metadata': {
                    'correction_type': 'test',
                    'confidence': 0.9,
                    'prompt_seq': 1,
                    'rejected_seq': 2,
                    'correction_seq': 3,
                    'session': 's1',
                    'timestamp': 1000,
                },
            },
        ]
        detailed = to_detailed_format(pairs)
        assert len(detailed) == 1
        assert detailed[0]['correction_type'] == 'test'
        assert detailed[0]['has_chosen'] is True