darkplex-core/tests/test_dpo_extractor.py
Claudia a3764c627d
Some checks failed
Tests / test (push) Failing after 5s
feat: DPO preference pair extractor
- Extracts Direct Preference Optimization training pairs from session transcripts
- Detects corrections via regex patterns (direct, redirect, frustration, forgotten, etc.)
- Supports session JSONL files (primary) and NATS events (fallback)
- Async NATS fetching via nats-py ordered consumer for bulk reads
- Outputs training format (prompt/chosen/rejected) and detailed format with metadata
- 41 tests covering correction detection, false positives, event parsing, pair building
- CLI: python -m cortex.dpo_extractor --since 30d --source sessions --dry-run
2026-02-12 10:01:32 +01:00

341 lines
13 KiB
Python

#!/usr/bin/env python3
"""Tests for DPO Preference Pair Extractor."""
import pytest
from cortex.dpo_extractor import (
detect_correction,
is_false_positive,
extract_user_text,
extract_assistant_text,
get_session_id,
clean_user_text,
build_dpo_pair,
extract_dpo_pairs,
to_dpo_training_format,
to_detailed_format,
MIN_PROMPT_LEN,
MIN_RESPONSE_LEN,
MIN_CORRECTION_LEN,
)
# --- Correction Detection ---
class TestDetectCorrection:
    """Check which correction category ``detect_correction`` assigns.

    Positive cases look at the returned (label, confidence) result; negative
    cases require ``None`` for ordinary requests, praise, questions and very
    short replies. Prefix-stripping cases feed messages wrapped in timestamp
    or Matrix envelopes.
    """

    @staticmethod
    def _detected(message):
        """Return ``detect_correction(message)``, failing the test on ``None``."""
        found = detect_correction(message)
        assert found is not None
        return found

    def test_direct_negation_with_instruction(self):
        label, confidence = self._detected(
            "nein, du sollst ollama anhalten und das model trainieren")
        assert label == 'direct_correction'
        assert confidence >= 0.85

    def test_redirect(self):
        found = self._detected("nicht so, sondern mit apply_chat_template")
        assert found[0] == 'redirect'

    def test_factual_correction(self):
        found = self._detected("das stimmt nicht, die GPU hat 20GB VRAM")
        assert found[0] == 'factual_correction'

    def test_frustration_correction(self):
        found = self._detected("bist du bescheuert? Wir haben doch alles da")
        assert found[0] == 'frustration_correction'
        assert found[1] >= 0.9

    def test_forgotten(self):
        found = self._detected("du hast vergessen warum wir Darkplex gebaut haben")
        assert found[0] == 'forgotten'

    def test_repeated_instruction(self):
        found = self._detected("ich habe doch gesagt wir sollen Gemma nehmen")
        assert found[0] == 'repeated_instruction'

    def test_mild_redirect(self):
        found = self._detected("naja, eher so dass wir die Pipeline automatisieren")
        assert found[0] == 'mild_redirect'

    def test_negative_feedback(self):
        found = self._detected("finds blöd dass ich dich immer so testen muss")
        assert found[0] == 'negative_feedback'

    def test_should_redirect(self):
        found = self._detected(
            "du solltest eher den bestehenden Code checken statt neu zu bauen")
        assert found[0] == 'should_redirect'

    # --- Non-corrections ---
    def test_normal_message_no_correction(self):
        assert detect_correction("kannst du mir den Status von Ollama zeigen?") is None

    def test_positive_feedback_no_correction(self):
        assert detect_correction("super, genau so wollte ich das haben!") is None

    def test_short_message_no_correction(self):
        assert detect_correction("ja") is None

    def test_question_no_correction(self):
        assert detect_correction("wie viel VRAM hat die GPU auf Desktop01?") is None

    def test_timestamp_prefix_stripped(self):
        found = self._detected("[Thu 2026-02-12 09:17 GMT+1] nein, du sollst ollama anhalten")
        assert found[0] == 'direct_correction'

    def test_matrix_prefix_stripped(self):
        wrapped = (
            "System: [2026-02-12 09:17:00 GMT+1] Matrix message from albert: "
            "nein, du sollst das anders machen, ich meinte die andere Config"
        )
        # Only requires that *some* correction is found behind the envelope.
        self._detected(wrapped)
class TestFalsePositives:
    """``is_false_positive`` flags system/tool noise and passes real user text.

    Each case checks for the exact singletons ``True`` / ``False``, not mere
    truthiness.
    """

    def test_system_message(self):
        flagged = is_false_positive("System: heartbeat check ok")
        assert flagged is True

    def test_heartbeat(self):
        flagged = is_false_positive("Read HEARTBEAT.md if it exists")
        assert flagged is True

    def test_media_attachment(self):
        flagged = is_false_positive("[media attached: /home/keller/file.pdf]")
        assert flagged is True

    def test_cron_message(self):
        flagged = is_false_positive("[cron:abc-123] Run learning context")
        assert flagged is True

    def test_exec_result(self):
        flagged = is_false_positive("exec completed with code 0")
        assert flagged is True

    def test_precompaction(self):
        flagged = is_false_positive("Pre-compaction memory flush. Store durable memories now")
        assert flagged is True

    def test_normal_message_not_false_positive(self):
        flagged = is_false_positive("lass uns ein anderes model wählen")
        assert flagged is False
# --- Event Parsing ---
class TestEventParsing:
    """Extract user text, assistant text and session ids from raw event dicts.

    Covers the payload layouts exercised here: ``text_preview`` lists,
    plain ``content``, nested ``data.text``, and ``sessionId`` / ``runId``.
    """

    def test_extract_user_text_preview(self):
        preview_blocks = [{'type': 'text', 'text': 'hallo welt'}]
        event = dict(type='conversation.message.in',
                     payload={'text_preview': preview_blocks})
        assert extract_user_text(event) == 'hallo welt'

    def test_extract_user_content(self):
        event = dict(type='conversation.message.in',
                     payload={'content': 'hallo welt'})
        assert extract_user_text(event) == 'hallo welt'

    def test_extract_user_wrong_type(self):
        # An outbound event must not be read as user text.
        event = dict(type='conversation.message.out',
                     payload={'content': 'response'})
        assert extract_user_text(event) is None

    def test_extract_assistant_data_text(self):
        event = dict(type='conversation.message.out',
                     payload={'data': {'text': 'here is my response'}})
        assert extract_assistant_text(event) == 'here is my response'

    def test_extract_assistant_content(self):
        event = dict(type='conversation.message.out',
                     payload={'content': 'response text'})
        assert extract_assistant_text(event) == 'response text'

    def test_session_id_from_payload(self):
        event = dict(type='conversation.message.in',
                     payload={'sessionId': 'abc-123'})
        assert get_session_id(event) == 'abc-123'

    def test_session_id_from_run_id(self):
        event = dict(type='conversation.message.out',
                     payload={'runId': 'run-456'})
        assert get_session_id(event) == 'run-456'
class TestCleanUserText:
    """``clean_user_text`` strips transport prefixes and message-id trailers."""

    def test_strip_timestamp(self):
        raw = "[Thu 2026-02-12 09:17 GMT+1] hallo"
        assert clean_user_text(raw) == "hallo"

    def test_strip_matrix_prefix(self):
        raw = "System: [2026-02-12 09:17:00 GMT+1] Matrix message from albert: hallo"
        assert clean_user_text(raw) == "hallo"

    def test_strip_message_id(self):
        raw = "some text\n[message_id: abc-123]"
        assert clean_user_text(raw) == "some text"

    def test_plain_text_unchanged(self):
        plain = "just a normal message"
        assert clean_user_text(plain) == plain
# --- DPO Pair Building ---
def _make_event(type_: str, text: str, seq: int = 1, session: str = 'test') -> dict:
"""Helper to create test events."""
if type_ == 'conversation.message.in':
return {
'type': type_,
'seq': seq,
'timestamp': 1000000 + seq,
'payload': {
'text_preview': [{'type': 'text', 'text': text}],
'sessionId': session,
},
}
else:
return {
'type': type_,
'seq': seq,
'timestamp': 1000000 + seq,
'payload': {
'data': {'text': text},
'runId': session,
},
}
class TestBuildDPOPair:
    """``build_dpo_pair``: pair contents and minimum-length filtering.

    The "too short" tests check their own premises against the thresholds
    exported by ``cortex.dpo_extractor`` (``MIN_PROMPT_LEN``,
    ``MIN_RESPONSE_LEN``, ``MIN_CORRECTION_LEN``). A ``None`` result can then
    only be blamed on the field under test, not on another field that also
    happens to fall below its threshold. Premise checks use strict
    comparisons so they hold whether the extractor rejects with ``<`` or
    ``<=``.
    """

    def test_valid_pair(self):
        prompt = _make_event('conversation.message.in',
                             'Kannst du mir ein Training-Script für Gemma erstellen?', seq=1)
        rejected = _make_event('conversation.message.out',
                               'Hier ist ein Script mit hardcoded chat template tags... ' * 5, seq=2)
        correction = _make_event('conversation.message.in',
                                 'nein, du sollst tokenizer.apply_chat_template benutzen statt hardcoded tags',
                                 seq=3)
        pair = build_dpo_pair(prompt, rejected, correction, 'direct_correction', 0.9)
        assert pair is not None
        assert 'Gemma' in pair['prompt']
        assert 'hardcoded' in pair['rejected']
        assert pair['chosen'] is None  # Caller fills this
        assert pair['metadata']['correction_type'] == 'direct_correction'

    def test_too_short_prompt(self):
        prompt_text = 'ja?'
        # Keep the response comfortably above its threshold so only the
        # prompt length can cause rejection.
        response_text = 'x' * max(100, MIN_RESPONSE_LEN + 1)
        correction_text = 'nein, ich meinte etwas ganz anderes'
        # Premise: only the prompt is below its minimum length.
        assert len(prompt_text) < MIN_PROMPT_LEN
        assert len(correction_text) > MIN_CORRECTION_LEN

        prompt = _make_event('conversation.message.in', prompt_text, seq=1)
        rejected = _make_event('conversation.message.out', response_text, seq=2)
        correction = _make_event('conversation.message.in', correction_text, seq=3)
        pair = build_dpo_pair(prompt, rejected, correction, 'direct_correction', 0.9)
        assert pair is None

    def test_too_short_response(self):
        prompt_text = 'Erstelle mir einen DPO Extraktor'
        response_text = 'Ok.'
        correction_text = 'das ist falsch, mach das bitte richtig'
        # Premise: only the rejected response is below its minimum length.
        assert len(response_text) < MIN_RESPONSE_LEN
        assert len(prompt_text) > MIN_PROMPT_LEN
        assert len(correction_text) > MIN_CORRECTION_LEN

        prompt = _make_event('conversation.message.in', prompt_text, seq=1)
        rejected = _make_event('conversation.message.out', response_text, seq=2)
        correction = _make_event('conversation.message.in', correction_text, seq=3)
        pair = build_dpo_pair(prompt, rejected, correction, 'factual_correction', 0.85)
        assert pair is None
# --- Full Pipeline ---
class TestExtractDPOPairs:
    """End-to-end ``extract_dpo_pairs`` over short synthetic event streams."""

    @staticmethod
    def _stream(*specs):
        """Build events numbered seq=1, 2, ... from ``(type, text[, session])`` tuples.

        When a spec omits the session, ``_make_event``'s own default is used.
        """
        events = []
        for seq, (type_, text, *rest) in enumerate(specs, start=1):
            extra = {'session': rest[0]} if rest else {}
            events.append(_make_event(type_, text, seq=seq, **extra))
        return events

    def test_basic_correction_flow(self):
        """Test: user → assistant → correction → better response."""
        events = self._stream(
            ('conversation.message.in',
             'Erstelle mir ein Training Script für Gemma 2 auf der 7800 XT'),
            ('conversation.message.out',
             'Hier ist das Script mit <start_of_turn>user tags... ' * 5),
            ('conversation.message.in',
             'nein, du sollst tokenizer.apply_chat_template() benutzen, nicht hardcoded tags'),
            ('conversation.message.out',
             'Du hast recht, hier ist die korrigierte Version mit apply_chat_template()... ' * 5),
        )
        pairs, stats = extract_dpo_pairs(events)
        assert len(pairs) >= 1
        first = pairs[0]
        assert first.get('chosen') is not None
        assert stats['corrections_found'] >= 1

    def test_no_corrections(self):
        """Normal conversation without corrections → no pairs."""
        events = self._stream(
            ('conversation.message.in', 'Was ist der Status von Ollama?'),
            ('conversation.message.out',
             'Ollama läuft und hat das claudia-memory Modell geladen. ' * 5),
            ('conversation.message.in', 'Super, danke für die Info!'),
        )
        pairs, _stats = extract_dpo_pairs(events)
        assert len(pairs) == 0

    def test_multiple_sessions(self):
        """Corrections from different sessions are handled independently."""
        events = self._stream(
            ('conversation.message.in', 'Mach mir ein neues Feature für Darkplex', 's1'),
            ('conversation.message.out',
             'Ich baue jetzt ein komplett neues System dafür... ' * 5, 's1'),
            ('conversation.message.in',
             'bist du bescheuert? Wir haben doch alles da, schau erst was existiert', 's1'),
            ('conversation.message.in', 'Zeig mir den Wetterbericht', 's2'),
            ('conversation.message.out',
             'Das Wetter in Berlin ist sonnig bei 5 Grad... ' * 5, 's2'),
        )
        pairs, _stats = extract_dpo_pairs(events)
        # Only session s1 should produce a pair
        assert len(pairs) == 1
        only_pair = pairs[0]
        assert only_pair['metadata']['session'] == 's1'
class TestOutputFormats:
    """Conversion of extracted pairs into training and detailed output records."""

    def test_training_format_only_complete(self):
        """Training format only includes pairs with both chosen and rejected."""
        complete = {'prompt': 'q1', 'chosen': 'good answer', 'rejected': 'bad answer',
                    'correction': 'fix it', 'metadata': {}}
        missing_chosen = {'prompt': 'q2', 'chosen': None, 'rejected': 'wrong',
                          'correction': 'nope', 'metadata': {}}
        training = to_dpo_training_format([complete, missing_chosen])
        assert len(training) == 1
        assert training[0]['prompt'] == 'q1'

    def test_detailed_format_all_pairs(self):
        """Detailed format includes all pairs with metadata."""
        metadata = {
            'correction_type': 'test',
            'confidence': 0.9,
            'prompt_seq': 1,
            'rejected_seq': 2,
            'correction_seq': 3,
            'session': 's1',
            'timestamp': 1000,
        }
        pair = {'prompt': 'q1', 'chosen': 'good', 'rejected': 'bad',
                'correction': 'fix', 'metadata': metadata}
        detailed = to_detailed_format([pair])
        assert len(detailed) == 1
        record = detailed[0]
        assert record['correction_type'] == 'test'
        assert record['has_chosen'] is True