Some checks failed
Tests / test (push) Failing after 5s
- Extracts Direct Preference Optimization (DPO) training pairs from session transcripts
- Detects corrections via regex patterns (direct, redirect, frustration, forgotten, etc.)
- Supports session JSONL files (primary) and NATS events (fallback)
- Fetches NATS events asynchronously via a nats-py ordered consumer for bulk reads
- Outputs a training format (prompt/chosen/rejected) and a detailed format with metadata
- 41 tests covering correction detection, false positives, event parsing, and pair building
- CLI: python -m cortex.dpo_extractor --since 30d --source sessions --dry-run
341 lines
13 KiB
Python
341 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
"""Tests for DPO Preference Pair Extractor."""
|
|
|
|
import pytest
|
|
from cortex.dpo_extractor import (
|
|
detect_correction,
|
|
is_false_positive,
|
|
extract_user_text,
|
|
extract_assistant_text,
|
|
get_session_id,
|
|
clean_user_text,
|
|
build_dpo_pair,
|
|
extract_dpo_pairs,
|
|
to_dpo_training_format,
|
|
to_detailed_format,
|
|
MIN_PROMPT_LEN,
|
|
MIN_RESPONSE_LEN,
|
|
MIN_CORRECTION_LEN,
|
|
)
|
|
|
|
|
|
# --- Correction Detection ---
|
|
|
|
class TestDetectCorrection:
    """Check that detect_correction labels correction patterns and ignores normal chat."""

    @staticmethod
    def _detect(text):
        """Run detect_correction on *text*, require a hit, and return it."""
        found = detect_correction(text)
        assert found is not None
        return found

    def test_direct_negation_with_instruction(self):
        label, conf = self._detect("nein, du sollst ollama anhalten und das model trainieren")
        assert label == 'direct_correction'
        assert conf >= 0.85

    def test_redirect(self):
        found = self._detect("nicht so, sondern mit apply_chat_template")
        assert found[0] == 'redirect'

    def test_factual_correction(self):
        found = self._detect("das stimmt nicht, die GPU hat 20GB VRAM")
        assert found[0] == 'factual_correction'

    def test_frustration_correction(self):
        found = self._detect("bist du bescheuert? Wir haben doch alles da")
        assert found[0] == 'frustration_correction'
        # Frustrated corrections are expected to score at the high end.
        assert found[1] >= 0.9

    def test_forgotten(self):
        found = self._detect("du hast vergessen warum wir Darkplex gebaut haben")
        assert found[0] == 'forgotten'

    def test_repeated_instruction(self):
        found = self._detect("ich habe doch gesagt wir sollen Gemma nehmen")
        assert found[0] == 'repeated_instruction'

    def test_mild_redirect(self):
        found = self._detect("naja, eher so dass wir die Pipeline automatisieren")
        assert found[0] == 'mild_redirect'

    def test_negative_feedback(self):
        found = self._detect("finds blöd dass ich dich immer so testen muss")
        assert found[0] == 'negative_feedback'

    def test_should_redirect(self):
        found = self._detect("du solltest eher den bestehenden Code checken statt neu zu bauen")
        assert found[0] == 'should_redirect'

    # --- Non-corrections ---

    def test_normal_message_no_correction(self):
        assert detect_correction("kannst du mir den Status von Ollama zeigen?") is None

    def test_positive_feedback_no_correction(self):
        assert detect_correction("super, genau so wollte ich das haben!") is None

    def test_short_message_no_correction(self):
        assert detect_correction("ja") is None

    def test_question_no_correction(self):
        assert detect_correction("wie viel VRAM hat die GPU auf Desktop01?") is None

    def test_timestamp_prefix_stripped(self):
        # The leading "[Thu ... GMT+1]" stamp must not hide the correction.
        found = self._detect("[Thu 2026-02-12 09:17 GMT+1] nein, du sollst ollama anhalten")
        assert found[0] == 'direct_correction'

    def test_matrix_prefix_stripped(self):
        # A Matrix relay prefix in front of the text must not hide the correction.
        self._detect(
            "System: [2026-02-12 09:17:00 GMT+1] Matrix message from albert: "
            "nein, du sollst das anders machen, ich meinte die andere Config"
        )
|
|
|
|
|
|
class TestFalsePositives:
    """is_false_positive must flag system/tooling noise but not real user messages."""

    def test_system_message(self):
        flagged = is_false_positive("System: heartbeat check ok")
        assert flagged is True

    def test_heartbeat(self):
        flagged = is_false_positive("Read HEARTBEAT.md if it exists")
        assert flagged is True

    def test_media_attachment(self):
        flagged = is_false_positive("[media attached: /home/keller/file.pdf]")
        assert flagged is True

    def test_cron_message(self):
        flagged = is_false_positive("[cron:abc-123] Run learning context")
        assert flagged is True

    def test_exec_result(self):
        flagged = is_false_positive("exec completed with code 0")
        assert flagged is True

    def test_precompaction(self):
        flagged = is_false_positive("Pre-compaction memory flush. Store durable memories now")
        assert flagged is True

    def test_normal_message_not_false_positive(self):
        flagged = is_false_positive("lass uns ein anderes model wählen")
        assert flagged is False
|
|
|
|
|
|
# --- Event Parsing ---
|
|
|
|
class TestEventParsing:
    """Check text and session-id extraction across the supported event shapes."""

    @staticmethod
    def _event(kind, payload):
        """Return a minimal event dict with the given type and payload."""
        return {'type': kind, 'payload': payload}

    def test_extract_user_text_preview(self):
        preview = [{'type': 'text', 'text': 'hallo welt'}]
        event = self._event('conversation.message.in', {'text_preview': preview})
        assert extract_user_text(event) == 'hallo welt'

    def test_extract_user_content(self):
        event = self._event('conversation.message.in', {'content': 'hallo welt'})
        assert extract_user_text(event) == 'hallo welt'

    def test_extract_user_wrong_type(self):
        # An outbound (assistant) event never yields user text.
        event = self._event('conversation.message.out', {'content': 'response'})
        assert extract_user_text(event) is None

    def test_extract_assistant_data_text(self):
        event = self._event('conversation.message.out', {'data': {'text': 'here is my response'}})
        assert extract_assistant_text(event) == 'here is my response'

    def test_extract_assistant_content(self):
        event = self._event('conversation.message.out', {'content': 'response text'})
        assert extract_assistant_text(event) == 'response text'

    def test_session_id_from_payload(self):
        event = self._event('conversation.message.in', {'sessionId': 'abc-123'})
        assert get_session_id(event) == 'abc-123'

    def test_session_id_from_run_id(self):
        event = self._event('conversation.message.out', {'runId': 'run-456'})
        assert get_session_id(event) == 'run-456'
|
|
|
|
|
|
class TestCleanUserText:
    """clean_user_text strips transport prefixes and trailing message ids."""

    def test_strip_timestamp(self):
        raw = "[Thu 2026-02-12 09:17 GMT+1] hallo"
        assert clean_user_text(raw) == "hallo"

    def test_strip_matrix_prefix(self):
        raw = "System: [2026-02-12 09:17:00 GMT+1] Matrix message from albert: hallo"
        expected = "hallo"
        assert clean_user_text(raw) == expected

    def test_strip_message_id(self):
        raw = "\n".join(["some text", "[message_id: abc-123]"])
        assert clean_user_text(raw) == "some text"

    def test_plain_text_unchanged(self):
        plain = "just a normal message"
        assert clean_user_text(plain) == plain
|
|
|
|
|
|
# --- DPO Pair Building ---
|
|
|
|
def _make_event(type_: str, text: str, seq: int = 1, session: str = 'test') -> dict:
|
|
"""Helper to create test events."""
|
|
if type_ == 'conversation.message.in':
|
|
return {
|
|
'type': type_,
|
|
'seq': seq,
|
|
'timestamp': 1000000 + seq,
|
|
'payload': {
|
|
'text_preview': [{'type': 'text', 'text': text}],
|
|
'sessionId': session,
|
|
},
|
|
}
|
|
else:
|
|
return {
|
|
'type': type_,
|
|
'seq': seq,
|
|
'timestamp': 1000000 + seq,
|
|
'payload': {
|
|
'data': {'text': text},
|
|
'runId': session,
|
|
},
|
|
}
|
|
|
|
|
|
class TestBuildDPOPair:
    """Test build_dpo_pair on a single (prompt, rejected, correction) triple."""

    # Texts that together produce a valid pair (pinned by test_valid_pair).
    # The negative tests reuse them and shorten exactly one field, so a
    # ``None`` result can only come from the length check under test and
    # not from some other, accidentally invalid fixture text.
    VALID_PROMPT = 'Kannst du mir ein Training-Script für Gemma erstellen?'
    VALID_REJECTED = 'Hier ist ein Script mit hardcoded chat template tags... ' * 5
    VALID_CORRECTION = 'nein, du sollst tokenizer.apply_chat_template benutzen statt hardcoded tags'

    def test_valid_pair(self):
        prompt = _make_event('conversation.message.in', self.VALID_PROMPT, seq=1)
        rejected = _make_event('conversation.message.out', self.VALID_REJECTED, seq=2)
        correction = _make_event('conversation.message.in', self.VALID_CORRECTION, seq=3)

        pair = build_dpo_pair(prompt, rejected, correction, 'direct_correction', 0.9)

        assert pair is not None
        assert 'Gemma' in pair['prompt']
        assert 'hardcoded' in pair['rejected']
        assert pair['chosen'] is None  # Caller fills this
        assert pair['metadata']['correction_type'] == 'direct_correction'

    def test_too_short_prompt(self):
        short_prompt = 'ja?'
        # Guard: the fixture must really be below the threshold, otherwise the
        # ``None`` assertion below would not be testing the prompt check.
        assert len(short_prompt) < MIN_PROMPT_LEN

        prompt = _make_event('conversation.message.in', short_prompt, seq=1)
        rejected = _make_event('conversation.message.out', self.VALID_REJECTED, seq=2)
        correction = _make_event('conversation.message.in', self.VALID_CORRECTION, seq=3)

        pair = build_dpo_pair(prompt, rejected, correction, 'direct_correction', 0.9)
        assert pair is None

    def test_too_short_response(self):
        short_response = 'Ok.'
        # Guard: the fixture must really be below the threshold, otherwise the
        # ``None`` assertion below would not be testing the response check.
        assert len(short_response) < MIN_RESPONSE_LEN

        prompt = _make_event('conversation.message.in', self.VALID_PROMPT, seq=1)
        rejected = _make_event('conversation.message.out', short_response, seq=2)
        correction = _make_event('conversation.message.in', self.VALID_CORRECTION, seq=3)

        pair = build_dpo_pair(prompt, rejected, correction, 'factual_correction', 0.85)
        assert pair is None
|
|
|
|
|
|
# --- Full Pipeline ---
|
|
|
|
class TestExtractDPOPairs:
    """End-to-end checks of extract_dpo_pairs over synthetic event streams."""

    _IN = 'conversation.message.in'
    _OUT = 'conversation.message.out'

    def test_basic_correction_flow(self):
        """Test: user → assistant → correction → better response."""
        user_ask = _make_event(self._IN, 'Erstelle mir ein Training Script für Gemma 2 auf der 7800 XT', seq=1)
        bad_reply = _make_event(self._OUT, 'Hier ist das Script mit <start_of_turn>user tags... ' * 5, seq=2)
        user_fix = _make_event(
            self._IN,
            'nein, du sollst tokenizer.apply_chat_template() benutzen, nicht hardcoded tags',
            seq=3,
        )
        good_reply = _make_event(
            self._OUT,
            'Du hast recht, hier ist die korrigierte Version mit apply_chat_template()... ' * 5,
            seq=4,
        )

        pairs, stats = extract_dpo_pairs([user_ask, bad_reply, user_fix, good_reply])

        assert len(pairs) >= 1
        # The reply following the correction supplies the "chosen" side.
        assert pairs[0].get('chosen') is not None
        assert stats['corrections_found'] >= 1

    def test_no_corrections(self):
        """Normal conversation without corrections → no pairs."""
        stream = [
            _make_event(self._IN, 'Was ist der Status von Ollama?', seq=1),
            _make_event(self._OUT, 'Ollama läuft und hat das claudia-memory Modell geladen. ' * 5, seq=2),
            _make_event(self._IN, 'Super, danke für die Info!', seq=3),
        ]

        pairs, _stats = extract_dpo_pairs(stream)

        assert not pairs

    def test_multiple_sessions(self):
        """Corrections from different sessions are handled independently."""
        session_one = [
            _make_event(self._IN, 'Mach mir ein neues Feature für Darkplex', seq=1, session='s1'),
            _make_event(self._OUT, 'Ich baue jetzt ein komplett neues System dafür... ' * 5, seq=2, session='s1'),
            _make_event(
                self._IN,
                'bist du bescheuert? Wir haben doch alles da, schau erst was existiert',
                seq=3,
                session='s1',
            ),
        ]
        session_two = [
            _make_event(self._IN, 'Zeig mir den Wetterbericht', seq=4, session='s2'),
            _make_event(self._OUT, 'Das Wetter in Berlin ist sonnig bei 5 Grad... ' * 5, seq=5, session='s2'),
        ]

        pairs, _stats = extract_dpo_pairs(session_one + session_two)

        # Only session s1 contains a correction, so only it yields a pair.
        assert len(pairs) == 1
        assert pairs[0]['metadata']['session'] == 's1'
|
|
|
|
|
|
class TestOutputFormats:
    """Check the training and detailed serializations of extracted pairs."""

    def test_training_format_only_complete(self):
        """Training format only includes pairs with both chosen and rejected."""
        complete = dict(prompt='q1', chosen='good answer', rejected='bad answer',
                        correction='fix it', metadata={})
        missing_chosen = dict(prompt='q2', chosen=None, rejected='wrong',
                              correction='nope', metadata={})

        training = to_dpo_training_format([complete, missing_chosen])

        assert len(training) == 1
        assert training[0]['prompt'] == 'q1'

    def test_detailed_format_all_pairs(self):
        """Detailed format includes all pairs with metadata."""
        meta = dict(
            correction_type='test',
            confidence=0.9,
            prompt_seq=1,
            rejected_seq=2,
            correction_seq=3,
            session='s1',
            timestamp=1000,
        )
        pair = dict(prompt='q1', chosen='good', rejected='bad', correction='fix', metadata=meta)

        detailed = to_detailed_format([pair])

        assert len(detailed) == 1
        first = detailed[0]
        assert first['correction_type'] == 'test'
        assert first['has_chosen'] is True
|