diff --git a/cortex/dpo_extractor.py b/cortex/dpo_extractor.py new file mode 100644 index 0000000..c8bf89f --- /dev/null +++ b/cortex/dpo_extractor.py @@ -0,0 +1,626 @@ +#!/usr/bin/env python3 +"""DPO Preference Pair Extractor for Darkplex. + +Extracts Direct Preference Optimization (DPO) training pairs from NATS event store. +DPO pairs consist of (prompt, chosen, rejected) where: +- prompt: the user's original request +- rejected: Claudia's response that was corrected +- chosen: the corrected/better response (derived from the correction) + +Correction detection strategy: +1. Find user messages that explicitly correct the previous assistant response +2. Pair the corrected response (rejected) with the correction context (chosen) +3. Quality-filter to avoid false positives + +Usage: + python -m cortex.dpo_extractor --since 7d + python -m cortex.dpo_extractor --since 30d --output ~/clawd/training-data/dpo-pairs.json + python -m cortex.dpo_extractor --dry-run --since 7d +""" + +import argparse +import asyncio +import json +import os +import re +import sys +from datetime import datetime, timedelta +from pathlib import Path +from typing import Optional + +# --- Correction Detection --- + +# Explicit correction: user tells Claudia she's wrong or should do something differently +CORRECTION_PATTERNS = [ + # Direct negation + instruction + (r'(?:^|\s)nein[,.]?\s+(?:du\s+sollst|ich\s+(?:wollte|meinte|meine))', 'direct_correction', 0.9), + (r'nicht\s+so[,.]?\s+(?:sondern|ich\s+(?:meinte|wollte))', 'redirect', 0.9), + (r'das\s+(?:stimmt|ist)\s+(?:nicht|falsch)', 'factual_correction', 0.85), + (r'(?:^|\s)falsch[.!]', 'wrong', 0.85), + + # Implicit correction: frustration + redirect + (r'(?:bist\s+du\s+)?bescheuert\??', 'frustration_correction', 0.95), + (r'(?:du\s+hast\s+(?:es\s+)?(?:vergessen|übersehen))', 'forgotten', 0.8), + (r'ich\s+(?:hab|habe)\s+(?:doch\s+)?gesagt', 'repeated_instruction', 0.85), + (r'(?:das\s+)?(?:war|ist)\s+(?:nicht\s+(?:das|was)\s+ich|falsch|quatsch)', 'rejection', 0.85), + (r'(?:nochmal|noch\s*mal)[,:]?\s+(?:ich|du|wir|das)', 'retry_request', 0.7), + + # Mild corrections + (r'(?:naja|nee|hmm)[,.]?\s+(?:eher|lieber|besser|anders)', 'mild_redirect', 0.7), + (r'(?:finds?|finde)\s+(?:ich\s+)?(?:blöd|schlecht|nicht\s+gut)', 'negative_feedback', 0.75), + (r'du\s+(?:solltest|musst|kannst)\s+(?:eher|lieber|besser)', 'should_redirect', 0.7), +] + +# Anti-patterns: things that look like corrections but aren't +FALSE_POSITIVE_PATTERNS = [ + r'^system:', # System messages + r'heartbeat', # Heartbeat + r'^\[media\s+attached', # Media attachments + r'^\[cron:', # Cron messages + r'^pre-compaction', # Memory flush + r'exec\s+(?:completed|failed)', # Exec results + r'^read\s+heartbeat', # Heartbeat instructions +] + +# Compile patterns +CORRECTION_RE = [(re.compile(p, re.IGNORECASE | re.MULTILINE), label, conf) + for p, label, conf in CORRECTION_PATTERNS] +FALSE_POSITIVE_RE = [re.compile(p, re.IGNORECASE) for p in FALSE_POSITIVE_PATTERNS] + +# Minimum lengths +MIN_PROMPT_LEN = 15 # User prompt must be meaningful +MIN_RESPONSE_LEN = 80 # Assistant response must be substantive +MIN_CORRECTION_LEN = 20 # Correction must explain what's wrong + + +def is_false_positive(text: str) -> bool: + """Check if text matches false positive patterns.""" + return any(p.search(text) for p in FALSE_POSITIVE_RE) + + +def detect_correction(text: str) -> Optional[tuple[str, float]]: + """Detect if user message is a correction. Returns (label, confidence) or None.""" + # Clean first, then check false positives on cleaned text + clean = clean_user_text(text) + + if is_false_positive(clean): + return None + + if len(clean) < MIN_CORRECTION_LEN: + return None + + best_match = None + best_conf = 0.0 + + for pattern, label, conf in CORRECTION_RE: + if pattern.search(clean): + if conf > best_conf: + best_match = (label, conf) + best_conf = conf + + return best_match + + +# --- Event Parsing --- + +def extract_user_text(event: dict) -> Optional[str]: + """Extract user text from a conversation.message.in event.""" + if event.get('type') != 'conversation.message.in': + return None + + payload = event.get('payload', {}) + + # text_preview format (most common) + if isinstance(payload.get('text_preview'), list) and payload['text_preview']: + return payload['text_preview'][0].get('text', '') + + # Direct content + if 'content' in payload: + return payload['content'] + + return None + + +def extract_assistant_text(event: dict) -> Optional[str]: + """Extract assistant text from a conversation.message.out event.""" + if event.get('type') != 'conversation.message.out': + return None + + payload = event.get('payload', {}) + + if isinstance(payload.get('data'), dict) and 'text' in payload['data']: + return payload['data']['text'] + + if 'content' in payload: + return payload['content'] + + return None + + +def get_session_id(event: dict) -> str: + """Extract session identifier from event.""" + payload = event.get('payload', {}) + if event.get('type') == 'conversation.message.out' and payload.get('runId'): + return payload['runId'] + if payload.get('sessionId'): + return payload['sessionId'] + return event.get('session', 'unknown') + + +def clean_user_text(text: str) -> str: + """Strip metadata from user message, return the actual content.""" + # Remove System: [timestamp] Matrix message from X: prefix + text = re.sub(r'^System:\s*\[.*?\]\s*(?:Matrix\s+message\s+from\s+\w+:\s*)?', '', text).strip() + # Remove [Day YYYY-MM-DD HH:MM TZ] or [YYYY-MM-DD ...] timestamp prefix + text = re.sub(r'^\[(?:\w+\s+)?\d{4}-\d{2}-\d{2}[^\]]*\]\s*', '', text).strip() + # Remove [Matrix user ...] prefix + text = re.sub(r'^\[Matrix\s+\w+[^\]]*\]\s*', '', text).strip() + # Remove message_id lines + text = re.sub(r'\[message_id:.*?\]', '', text).strip() + return text + + +# --- DPO Pair Construction --- + +def build_dpo_pair( + prompt_event: dict, + rejected_event: dict, + correction_event: dict, + correction_label: str, + correction_confidence: float, +) -> Optional[dict]: + """Build a DPO training pair from a correction sequence. + + Returns dict with: prompt, chosen, rejected, metadata + """ + prompt_text = clean_user_text(extract_user_text(prompt_event) or '') + rejected_text = extract_assistant_text(rejected_event) or '' + correction_text = clean_user_text(extract_user_text(correction_event) or '') + + # Validate lengths + if len(prompt_text) < MIN_PROMPT_LEN: + return None + if len(rejected_text) < MIN_RESPONSE_LEN: + return None + if len(correction_text) < MIN_CORRECTION_LEN: + return None + + # The "chosen" response is constructed from the correction context. + # We use the correction as a signal — the chosen text is what the user + # wanted instead. For DPO training, we need an actual better response. + # Strategy: use the correction itself as context for what "chosen" should be. + # In practice, after the correction, the assistant usually gives a better response. + # We'll look for that in the caller. + + return { + 'prompt': prompt_text, + 'rejected': rejected_text, + 'chosen': None, # To be filled by caller with post-correction response + 'correction': correction_text, + 'metadata': { + 'correction_type': correction_label, + 'confidence': correction_confidence, + 'prompt_seq': prompt_event.get('seq'), + 'rejected_seq': rejected_event.get('seq'), + 'correction_seq': correction_event.get('seq'), + 'session': get_session_id(prompt_event), + 'timestamp': correction_event.get('timestamp'), + } + } + + +# --- NATS Fetching --- + +# --- Session Transcript Parsing --- + +def load_session_transcripts(sessions_dir: str, since_hours: int = 168) -> list[dict]: + """Load conversation messages from OpenClaw session JSONL files. + + Returns a flat list of events with 'role', 'text', 'session', 'timestamp', 'seq'. + """ + from pathlib import Path + import time + + sessions_path = Path(sessions_dir) + if not sessions_path.exists(): + print(f" ⚠️ Sessions dir not found: {sessions_dir}", file=sys.stderr) + return [] + + cutoff_time = time.time() - (since_hours * 3600) + events = [] + files_processed = 0 + + for jsonl_file in sorted(sessions_path.glob('*.jsonl')): + # Skip old files + if jsonl_file.stat().st_mtime < cutoff_time: + continue + + session_id = jsonl_file.stem + seq = 0 + + try: + with open(jsonl_file) as fh: + for line in fh: + try: + entry = json.loads(line) + except json.JSONDecodeError: + continue + + if entry.get('type') != 'message': + continue + + msg = entry.get('message', {}) + role = msg.get('role', '') + if role not in ('user', 'assistant'): + continue + + # Extract text content + content = msg.get('content', '') + if isinstance(content, list): + text_parts = [] + for part in content: + if isinstance(part, dict) and part.get('type') == 'text': + text_parts.append(part.get('text', '')) + content = '\n'.join(text_parts) + + if not content: + continue + + events.append({ + 'type': f'conversation.message.{"in" if role == "user" else "out"}', + 'role': role, + 'text': content, + 'session': session_id, + 'timestamp': entry.get('timestamp', 0), + 'seq': seq, + 'payload': { + 'text_preview': [{'type': 'text', 'text': content}] if role == 'user' else {}, + 'data': {'text': content} if role == 'assistant' else {}, + 'sessionId': session_id, + }, + }) + seq += 1 + + files_processed += 1 + except Exception as e: + print(f" ⚠️ Error reading {jsonl_file.name}: {e}", file=sys.stderr) + + print(f" Loaded {len(events)} messages from {files_processed} session files", file=sys.stderr) + return events + + +def fetch_events_by_sequence(start_seq: int, end_seq: int, batch_size: int = 500) -> list[dict]: + """Fetch events from NATS by sequence range using nats-py (fast) or CLI (fallback).""" + # Try fast async fetch first + try: + import asyncio + events = asyncio.run(_fetch_events_async(start_seq, end_seq)) + if events: + return events + except Exception as e: + print(f" Async fetch failed ({e}), falling back to CLI", file=sys.stderr) + + return _fetch_events_cli(start_seq, end_seq) + + +async def _fetch_events_async(start_seq: int, end_seq: int) -> list[dict]: + """Fast bulk fetch using nats-py consumer.""" + import nats as nats_lib + from nats.js.api import ConsumerConfig, DeliverPolicy + + user = os.environ.get('NATS_USER', 'claudia') + password = os.environ.get('NATS_PASSWORD', '') + + nc = await nats_lib.connect( + 'nats://localhost:4222', + user=user, password=password, + ) + js = nc.jetstream() + + events = [] + # Create ephemeral ordered consumer starting at our sequence + sub = await js.subscribe( + 'openclaw.events.>', + ordered_consumer=True, + config=ConsumerConfig( + deliver_policy=DeliverPolicy.BY_START_SEQUENCE, + opt_start_seq=start_seq, + ), + ) + + try: + count = 0 + while True: + try: + msg = await asyncio.wait_for(sub.next_msg(), timeout=2.0) + try: + event = json.loads(msg.data.decode()) + except (json.JSONDecodeError, UnicodeDecodeError): + count += 1 + continue + event['seq'] = count + start_seq + events.append(event) + count += 1 + + if count % 1000 == 0: + print(f" Fetched {count} events...", file=sys.stderr, end='\r') + + # Stop at end_seq (approximate via count) + if count >= (end_seq - start_seq + 1): + break + except asyncio.TimeoutError: + break + finally: + await sub.unsubscribe() + await nc.drain() + + print(f" Fetched {len(events)} events (async) ", file=sys.stderr) + return events + + +def _fetch_events_cli(start_seq: int, end_seq: int) -> list[dict]: + """Fallback: fetch events one by one via nats CLI.""" + import subprocess + + events = [] + errors = 0 + + for seq in range(start_seq, end_seq + 1): + try: + result = subprocess.run( + ['nats', 'stream', 'get', 'openclaw-events', str(seq)], + capture_output=True, text=True, timeout=5, + ) + for line in result.stdout.split('\n'): + if line.startswith('{'): + event = json.loads(line) + event['seq'] = seq + events.append(event) + break + except Exception: + errors += 1 + if errors > 50: + print(f" ⚠️ Too many errors ({errors}), stopping", file=sys.stderr) + break + + if (seq - start_seq) % 200 == 0 and seq > start_seq: + print(f" Fetched {seq - start_seq}/{end_seq - start_seq} events...", + file=sys.stderr, end='\r') + + print(f" Fetched {len(events)} events ({errors} errors) ", file=sys.stderr) + return events + + +def get_stream_info() -> dict: + """Get NATS stream info.""" + import subprocess + result = subprocess.run( + ['nats', 'stream', 'info', 'openclaw-events', '--json'], + capture_output=True, text=True, timeout=10, + ) + return json.loads(result.stdout) + + +# --- Main Extraction Pipeline --- + +def extract_dpo_pairs(events: list[dict], verbose: bool = False) -> list[dict]: + """Extract DPO pairs from a list of events. + + Scanning strategy: + 1. Group events by session + 2. Within each session, find the pattern: + user_msg → assistant_response → correction → (optional) better_response + 3. Build DPO pair: prompt=user_msg, rejected=assistant_response, chosen=better_response + """ + # Group by session + sessions: dict[str, list[dict]] = {} + for event in events: + sid = get_session_id(event) + sessions.setdefault(sid, []).append(event) + + pairs = [] + stats = {'sessions': 0, 'corrections_found': 0, 'pairs_built': 0, 'pairs_with_chosen': 0} + + for sid, session_events in sessions.items(): + session_events.sort(key=lambda e: e.get('seq', 0)) + stats['sessions'] += 1 + + # Build conversation sequence: list of (role, text, event) + conversation = [] + for event in session_events: + user_text = extract_user_text(event) + if user_text: + conversation.append(('user', user_text, event)) + continue + asst_text = extract_assistant_text(event) + if asst_text: + # Keep the longest assistant response in a streak + if conversation and conversation[-1][0] == 'assistant': + if len(asst_text) > len(conversation[-1][1]): + conversation[-1] = ('assistant', asst_text, event) + else: + conversation.append(('assistant', asst_text, event)) + + # Scan for correction pattern: user → assistant → user(correction) → assistant(better) + for i in range(len(conversation) - 2): + if (conversation[i][0] == 'user' and + conversation[i + 1][0] == 'assistant' and + conversation[i + 2][0] == 'user'): + + correction_result = detect_correction(conversation[i + 2][1]) + if not correction_result: + continue + + label, confidence = correction_result + stats['corrections_found'] += 1 + + pair = build_dpo_pair( + prompt_event=conversation[i][2], + rejected_event=conversation[i + 1][2], + correction_event=conversation[i + 2][2], + correction_label=label, + correction_confidence=confidence, + ) + + if not pair: + continue + + # Look for the better response after the correction + if (i + 3 < len(conversation) and + conversation[i + 3][0] == 'assistant'): + chosen_text = conversation[i + 3][1] + if len(chosen_text) >= MIN_RESPONSE_LEN: + pair['chosen'] = chosen_text + stats['pairs_with_chosen'] += 1 + + stats['pairs_built'] += 1 + pairs.append(pair) + + if verbose: + print(f"\n 📌 [{label}] conf={confidence:.0%}", file=sys.stderr) + print(f" prompt: {pair['prompt'][:80]}...", file=sys.stderr) + print(f" correction: {pair['correction'][:80]}...", file=sys.stderr) + + return pairs, stats + + +def to_dpo_training_format(pairs: list[dict]) -> list[dict]: + """Convert to standard DPO training format for trl.DPOTrainer. + + Only includes pairs that have both chosen and rejected responses. + """ + training_pairs = [] + for pair in pairs: + if not pair.get('chosen'): + continue + + training_pairs.append({ + 'prompt': pair['prompt'], + 'chosen': pair['chosen'], + 'rejected': pair['rejected'], + }) + + return training_pairs + + +def to_detailed_format(pairs: list[dict]) -> list[dict]: + """Full format with metadata for inspection and debugging.""" + return [{ + 'prompt': p['prompt'], + 'chosen': p.get('chosen', ''), + 'rejected': p['rejected'], + 'correction': p['correction'], + 'has_chosen': bool(p.get('chosen')), + **p['metadata'], + } for p in pairs] + + +# --- CLI --- + +def parse_duration(s: str) -> timedelta: + """Parse '7d', '24h', '30m' to timedelta.""" + m = re.match(r'^(\d+)([dhm])$', s.lower()) + if not m: + raise ValueError(f"Invalid duration: {s}") + v, u = int(m.group(1)), m.group(2) + return {'d': timedelta(days=v), 'h': timedelta(hours=v), 'm': timedelta(minutes=v)}[u] + + +def main(): + parser = argparse.ArgumentParser( + description='Extract DPO preference pairs from NATS event store', + ) + parser.add_argument('--since', default='7d', help='Time window (e.g. 7d, 24h)') + parser.add_argument('--output', '-o', help='Output file (default: auto)') + parser.add_argument('--format', choices=['training', 'detailed', 'both'], default='both', + help='Output format') + parser.add_argument('--min-confidence', type=float, default=0.7, + help='Minimum correction confidence (0-1)') + parser.add_argument('--dry-run', action='store_true', help='Show stats only, no output') + parser.add_argument('--verbose', '-v', action='store_true', help='Show each found pair') + parser.add_argument('--sessions-dir', default=None, + help='Path to OpenClaw session JSONL dir (default: ~/.openclaw/agents/main/sessions)') + parser.add_argument('--source', choices=['sessions', 'nats', 'auto'], default='auto', + help='Data source: session transcripts (preferred) or NATS events') + + args = parser.parse_args() + + print("🔍 DPO Preference Pair Extractor", file=sys.stderr) + + duration = parse_duration(args.since) + hours = duration.total_seconds() / 3600 + + # Determine data source + sessions_dir = args.sessions_dir or os.path.expanduser('~/.openclaw/agents/main/sessions') + use_sessions = args.source == 'sessions' or ( + args.source == 'auto' and os.path.isdir(sessions_dir) + ) + + if use_sessions: + print(f" Source: Session transcripts ({sessions_dir})", file=sys.stderr) + conv_events = load_session_transcripts(sessions_dir, since_hours=int(hours)) + else: + print(f" Source: NATS event store", file=sys.stderr) + info = get_stream_info() + last_seq = info['state']['last_seq'] + first_seq = info['state']['first_seq'] + estimated_events = int(hours * 50) + start_seq = max(first_seq, last_seq - estimated_events) + print(f" Scanning sequences {start_seq}-{last_seq}", file=sys.stderr) + events = fetch_events_by_sequence(start_seq, last_seq) + conv_events = [e for e in events if e.get('type', '').startswith('conversation.message')] + print(f" {len(conv_events)} conversation events out of {len(events)} total", file=sys.stderr) + + # Extract pairs + pairs, stats = extract_dpo_pairs(conv_events, verbose=args.verbose) + + # Filter by confidence + pairs = [p for p in pairs if p['metadata']['confidence'] >= args.min_confidence] + + # Stats + print(f"\n📊 Results:", file=sys.stderr) + print(f" Sessions scanned: {stats['sessions']}", file=sys.stderr) + print(f" Corrections detected: {stats['corrections_found']}", file=sys.stderr) + print(f" DPO pairs built: {stats['pairs_built']}", file=sys.stderr) + print(f" Pairs with chosen response: {stats['pairs_with_chosen']}", file=sys.stderr) + print(f" After confidence filter (≥{args.min_confidence}): {len(pairs)}", file=sys.stderr) + + complete = [p for p in pairs if p.get('chosen')] + incomplete = [p for p in pairs if not p.get('chosen')] + print(f" Complete (prompt+chosen+rejected): {len(complete)}", file=sys.stderr) + print(f" Incomplete (no chosen response): {len(incomplete)}", file=sys.stderr) + + if args.dry_run: + # Show sample pairs + if pairs: + print(f"\n📝 Sample pairs:", file=sys.stderr) + for p in pairs[:5]: + print(f"\n [{p['metadata']['correction_type']}] " + f"conf={p['metadata']['confidence']:.0%}", file=sys.stderr) + print(f" prompt: {p['prompt'][:100]}", file=sys.stderr) + print(f" rejected: {p['rejected'][:100]}", file=sys.stderr) + print(f" correction: {p['correction'][:100]}", file=sys.stderr) + if p.get('chosen'): + print(f" chosen: {p['chosen'][:100]}", file=sys.stderr) + return + + # Output + output_dir = Path(os.environ.get('CLAWD_DIR', Path.home() / 'clawd')) / 'training-data' + output_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime('%Y-%m-%d') + + if args.format in ('training', 'both'): + training_data = to_dpo_training_format(pairs) + path = Path(args.output) if args.output else output_dir / f'dpo-training-{timestamp}.json' + path.write_text(json.dumps(training_data, indent=2, ensure_ascii=False)) + print(f"\n✅ Training format: {path} ({len(training_data)} pairs)", file=sys.stderr) + + if args.format in ('detailed', 'both'): + detailed_data = to_detailed_format(pairs) + path = output_dir / f'dpo-detailed-{timestamp}.json' + path.write_text(json.dumps(detailed_data, indent=2, ensure_ascii=False)) + print(f"✅ Detailed format: {path} ({len(detailed_data)} pairs)", file=sys.stderr) + + +if __name__ == '__main__': + main() diff --git a/tests/test_dpo_extractor.py b/tests/test_dpo_extractor.py new file mode 100644 index 0000000..ecf2cfe --- /dev/null +++ b/tests/test_dpo_extractor.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python3 +"""Tests for DPO Preference Pair Extractor.""" + +import pytest +from cortex.dpo_extractor import ( + detect_correction, + is_false_positive, + extract_user_text, + extract_assistant_text, + get_session_id, + clean_user_text, + build_dpo_pair, + extract_dpo_pairs, + to_dpo_training_format, + to_detailed_format, + MIN_PROMPT_LEN, + MIN_RESPONSE_LEN, + MIN_CORRECTION_LEN, +) + + +# --- Correction Detection --- + +class TestDetectCorrection: + """Test correction pattern detection.""" + + def test_direct_negation_with_instruction(self): + result = detect_correction("nein, du sollst ollama anhalten und das model trainieren") + assert result is not None + label, conf = result + assert label == 'direct_correction' + assert conf >= 0.85 + + def test_redirect(self): + result = detect_correction("nicht so, sondern mit apply_chat_template") + assert result is not None + assert result[0] == 'redirect' + + def test_factual_correction(self): + result = detect_correction("das stimmt nicht, die GPU hat 20GB VRAM") + assert result is not None + assert result[0] == 'factual_correction' + + def test_frustration_correction(self): + result = detect_correction("bist du bescheuert? Wir haben doch alles da") + assert result is not None + assert result[0] == 'frustration_correction' + assert result[1] >= 0.9 + + def test_forgotten(self): + result = detect_correction("du hast vergessen warum wir Darkplex gebaut haben") + assert result is not None + assert result[0] == 'forgotten' + + def test_repeated_instruction(self): + result = detect_correction("ich habe doch gesagt wir sollen Gemma nehmen") + assert result is not None + assert result[0] == 'repeated_instruction' + + def test_mild_redirect(self): + result = detect_correction("naja, eher so dass wir die Pipeline automatisieren") + assert result is not None + assert result[0] == 'mild_redirect' + + def test_negative_feedback(self): + result = detect_correction("finds blöd dass ich dich immer so testen muss") + assert result is not None + assert result[0] == 'negative_feedback' + + def test_should_redirect(self): + result = detect_correction("du solltest eher den bestehenden Code checken statt neu zu bauen") + assert result is not None + assert result[0] == 'should_redirect' + + # --- Non-corrections --- + + def test_normal_message_no_correction(self): + result = detect_correction("kannst du mir den Status von Ollama zeigen?") + assert result is None + + def test_positive_feedback_no_correction(self): + result = detect_correction("super, genau so wollte ich das haben!") + assert result is None + + def test_short_message_no_correction(self): + result = detect_correction("ja") + assert result is None + + def test_question_no_correction(self): + result = detect_correction("wie viel VRAM hat die GPU auf Desktop01?") + assert result is None + + def test_timestamp_prefix_stripped(self): + result = detect_correction("[Thu 2026-02-12 09:17 GMT+1] nein, du sollst ollama anhalten") + assert result is not None + assert result[0] == 'direct_correction' + + def test_matrix_prefix_stripped(self): + result = detect_correction( + "System: [2026-02-12 09:17:00 GMT+1] Matrix message from albert: " + "nein, du sollst das anders machen, ich meinte die andere Config" + ) + assert result is not None + + +class TestFalsePositives: + """Ensure we don't flag system messages as corrections.""" + + def test_system_message(self): + assert is_false_positive("System: heartbeat check ok") is True + + def test_heartbeat(self): + assert is_false_positive("Read HEARTBEAT.md if it exists") is True + + def test_media_attachment(self): + assert is_false_positive("[media attached: /home/keller/file.pdf]") is True + + def test_cron_message(self): + assert is_false_positive("[cron:abc-123] Run learning context") is True + + def test_exec_result(self): + assert is_false_positive("exec completed with code 0") is True + + def test_precompaction(self): + assert is_false_positive("Pre-compaction memory flush. Store durable memories now") is True + + def test_normal_message_not_false_positive(self): + assert is_false_positive("lass uns ein anderes model wählen") is False + + +# --- Event Parsing --- + +class TestEventParsing: + """Test extraction from various event formats.""" + + def test_extract_user_text_preview(self): + event = { + 'type': 'conversation.message.in', + 'payload': { + 'text_preview': [{'type': 'text', 'text': 'hallo welt'}], + } + } + assert extract_user_text(event) == 'hallo welt' + + def test_extract_user_content(self): + event = { + 'type': 'conversation.message.in', + 'payload': {'content': 'hallo welt'}, + } + assert extract_user_text(event) == 'hallo welt' + + def test_extract_user_wrong_type(self): + event = { + 'type': 'conversation.message.out', + 'payload': {'content': 'response'}, + } + assert extract_user_text(event) is None + + def test_extract_assistant_data_text(self): + event = { + 'type': 'conversation.message.out', + 'payload': {'data': {'text': 'here is my response'}}, + } + assert extract_assistant_text(event) == 'here is my response' + + def test_extract_assistant_content(self): + event = { + 'type': 'conversation.message.out', + 'payload': {'content': 'response text'}, + } + assert extract_assistant_text(event) == 'response text' + + def test_session_id_from_payload(self): + event = { + 'type': 'conversation.message.in', + 'payload': {'sessionId': 'abc-123'}, + } + assert get_session_id(event) == 'abc-123' + + def test_session_id_from_run_id(self): + event = { + 'type': 'conversation.message.out', + 'payload': {'runId': 'run-456'}, + } + assert get_session_id(event) == 'run-456' + + +class TestCleanUserText: + def test_strip_timestamp(self): + assert clean_user_text("[Thu 2026-02-12 09:17 GMT+1] hallo") == "hallo" + + def test_strip_matrix_prefix(self): + text = "System: [2026-02-12 09:17:00 GMT+1] Matrix message from albert: hallo" + assert clean_user_text(text) == "hallo" + + def test_strip_message_id(self): + text = "some text\n[message_id: abc-123]" + assert clean_user_text(text) == "some text" + + def test_plain_text_unchanged(self): + assert clean_user_text("just a normal message") == "just a normal message" + + +# --- DPO Pair Building --- + +def _make_event(type_: str, text: str, seq: int = 1, session: str = 'test') -> dict: + """Helper to create test events.""" + if type_ == 'conversation.message.in': + return { + 'type': type_, + 'seq': seq, + 'timestamp': 1000000 + seq, + 'payload': { + 'text_preview': [{'type': 'text', 'text': text}], + 'sessionId': session, + }, + } + else: + return { + 'type': type_, + 'seq': seq, + 'timestamp': 1000000 + seq, + 'payload': { + 'data': {'text': text}, + 'runId': session, + }, + } + + +class TestBuildDPOPair: + def test_valid_pair(self): + prompt = _make_event('conversation.message.in', + 'Kannst du mir ein Training-Script für Gemma erstellen?', seq=1) + rejected = _make_event('conversation.message.out', + 'Hier ist ein Script mit hardcoded chat template tags... ' * 5, seq=2) + correction = _make_event('conversation.message.in', + 'nein, du sollst tokenizer.apply_chat_template benutzen statt hardcoded tags', + seq=3) + + pair = build_dpo_pair(prompt, rejected, correction, 'direct_correction', 0.9) + assert pair is not None + assert 'Gemma' in pair['prompt'] + assert 'hardcoded' in pair['rejected'] + assert pair['chosen'] is None # Caller fills this + assert pair['metadata']['correction_type'] == 'direct_correction' + + def test_too_short_prompt(self): + prompt = _make_event('conversation.message.in', 'ja?', seq=1) + rejected = _make_event('conversation.message.out', 'x' * 100, seq=2) + correction = _make_event('conversation.message.in', 'nein, ich meinte etwas ganz anderes', seq=3) + + pair = build_dpo_pair(prompt, rejected, correction, 'direct_correction', 0.9) + assert pair is None + + def test_too_short_response(self): + prompt = _make_event('conversation.message.in', 'Erstelle mir einen DPO Extraktor', seq=1) + rejected = _make_event('conversation.message.out', 'Ok.', seq=2) + correction = _make_event('conversation.message.in', 'das ist falsch, mach das bitte richtig', seq=3) + + pair = build_dpo_pair(prompt, rejected, correction, 'factual_correction', 0.85) + assert pair is None + + +# --- Full Pipeline --- + +class TestExtractDPOPairs: + def test_basic_correction_flow(self): + """Test: user → assistant → correction → better response.""" + events = [ + _make_event('conversation.message.in', + 'Erstelle mir ein Training Script für Gemma 2 auf der 7800 XT', seq=1), + _make_event('conversation.message.out', + 'Hier ist das Script mit user tags... ' * 5, seq=2), + _make_event('conversation.message.in', + 'nein, du sollst tokenizer.apply_chat_template() benutzen, nicht hardcoded tags', + seq=3), + _make_event('conversation.message.out', + 'Du hast recht, hier ist die korrigierte Version mit apply_chat_template()... ' * 5, + seq=4), + ] + + pairs, stats = extract_dpo_pairs(events) + assert len(pairs) >= 1 + assert pairs[0].get('chosen') is not None + assert stats['corrections_found'] >= 1 + + def test_no_corrections(self): + """Normal conversation without corrections → no pairs.""" + events = [ + _make_event('conversation.message.in', 'Was ist der Status von Ollama?', seq=1), + _make_event('conversation.message.out', 'Ollama läuft und hat das claudia-memory Modell geladen. ' * 5, seq=2), + _make_event('conversation.message.in', 'Super, danke für die Info!', seq=3), + ] + + pairs, stats = extract_dpo_pairs(events) + assert len(pairs) == 0 + + def test_multiple_sessions(self): + """Corrections from different sessions are handled independently.""" + events = [ + _make_event('conversation.message.in', 'Mach mir ein neues Feature für Darkplex', seq=1, session='s1'), + _make_event('conversation.message.out', 'Ich baue jetzt ein komplett neues System dafür... ' * 5, seq=2, session='s1'), + _make_event('conversation.message.in', + 'bist du bescheuert? Wir haben doch alles da, schau erst was existiert', + seq=3, session='s1'), + _make_event('conversation.message.in', 'Zeig mir den Wetterbericht', seq=4, session='s2'), + _make_event('conversation.message.out', 'Das Wetter in Berlin ist sonnig bei 5 Grad... ' * 5, seq=5, session='s2'), + ] + + pairs, stats = extract_dpo_pairs(events) + # Only session s1 should produce a pair + assert len(pairs) == 1 + assert pairs[0]['metadata']['session'] == 's1' + + +class TestOutputFormats: + def test_training_format_only_complete(self): + """Training format only includes pairs with both chosen and rejected.""" + pairs = [ + {'prompt': 'q1', 'chosen': 'good answer', 'rejected': 'bad answer', + 'correction': 'fix it', 'metadata': {}}, + {'prompt': 'q2', 'chosen': None, 'rejected': 'wrong', + 'correction': 'nope', 'metadata': {}}, + ] + training = to_dpo_training_format(pairs) + assert len(training) == 1 + assert training[0]['prompt'] == 'q1' + + def test_detailed_format_all_pairs(self): + """Detailed format includes all pairs with metadata.""" + pairs = [ + {'prompt': 'q1', 'chosen': 'good', 'rejected': 'bad', + 'correction': 'fix', 'metadata': {'correction_type': 'test', 'confidence': 0.9, + 'prompt_seq': 1, 'rejected_seq': 2, + 'correction_seq': 3, 'session': 's1', + 'timestamp': 1000}}, + ] + detailed = to_detailed_format(pairs) + assert len(detailed) == 1 + assert detailed[0]['correction_type'] == 'test' + assert detailed[0]['has_chosen'] is True