- Extracts Direct Preference Optimization training pairs from session transcripts - Detects corrections via regex patterns (direct, redirect, frustration, forgotten, etc.) - Supports session JSONL files (primary) and NATS events (fallback) - Async NATS fetching via nats-py ordered consumer for bulk reads - Outputs training format (prompt/chosen/rejected) and detailed format with metadata - 41 tests covering correction detection, false positives, event parsing, pair building - CLI: python -m cortex.dpo_extractor --since 30d --source sessions --dry-run
This commit is contained in:
parent
c5e5ce9dc0
commit
a3764c627d
2 changed files with 967 additions and 0 deletions
626
cortex/dpo_extractor.py
Normal file
626
cortex/dpo_extractor.py
Normal file
|
|
@ -0,0 +1,626 @@
|
|||
#!/usr/bin/env python3
|
||||
"""DPO Preference Pair Extractor for Darkplex.
|
||||
|
||||
Extracts Direct Preference Optimization (DPO) training pairs from NATS event store.
|
||||
DPO pairs consist of (prompt, chosen, rejected) where:
|
||||
- prompt: the user's original request
|
||||
- rejected: Claudia's response that was corrected
|
||||
- chosen: the corrected/better response (derived from the correction)
|
||||
|
||||
Correction detection strategy:
|
||||
1. Find user messages that explicitly correct the previous assistant response
|
||||
2. Pair the corrected response (rejected) with the correction context (chosen)
|
||||
3. Quality-filter to avoid false positives
|
||||
|
||||
Usage:
|
||||
python -m cortex.dpo_extractor --since 7d
|
||||
python -m cortex.dpo_extractor --since 30d --output ~/clawd/training-data/dpo-pairs.json
|
||||
python -m cortex.dpo_extractor --dry-run --since 7d
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
# --- Correction Detection ---
|
||||
|
||||
# Explicit correction: user tells Claudia she's wrong or should do something differently
|
||||
CORRECTION_PATTERNS = [
|
||||
# Direct negation + instruction
|
||||
(r'(?:^|\s)nein[,.]?\s+(?:du\s+sollst|ich\s+(?:wollte|meinte|meine))', 'direct_correction', 0.9),
|
||||
(r'nicht\s+so[,.]?\s+(?:sondern|ich\s+(?:meinte|wollte))', 'redirect', 0.9),
|
||||
(r'das\s+(?:stimmt|ist)\s+(?:nicht|falsch)', 'factual_correction', 0.85),
|
||||
(r'(?:^|\s)falsch[.!]', 'wrong', 0.85),
|
||||
|
||||
# Implicit correction: frustration + redirect
|
||||
(r'(?:bist\s+du\s+)?bescheuert\??', 'frustration_correction', 0.95),
|
||||
(r'(?:du\s+hast\s+(?:es\s+)?(?:vergessen|übersehen))', 'forgotten', 0.8),
|
||||
(r'ich\s+(?:hab|habe)\s+(?:doch\s+)?gesagt', 'repeated_instruction', 0.85),
|
||||
(r'(?:das\s+)?(?:war|ist)\s+(?:nicht\s+(?:das|was)\s+ich|falsch|quatsch)', 'rejection', 0.85),
|
||||
(r'(?:nochmal|noch\s*mal)[,:]?\s+(?:ich|du|wir|das)', 'retry_request', 0.7),
|
||||
|
||||
# Mild corrections
|
||||
(r'(?:naja|nee|hmm)[,.]?\s+(?:eher|lieber|besser|anders)', 'mild_redirect', 0.7),
|
||||
(r'(?:finds?|finde)\s+(?:ich\s+)?(?:blöd|schlecht|nicht\s+gut)', 'negative_feedback', 0.75),
|
||||
(r'du\s+(?:solltest|musst|kannst)\s+(?:eher|lieber|besser)', 'should_redirect', 0.7),
|
||||
]
|
||||
|
||||
# Anti-patterns: things that look like corrections but aren't
|
||||
FALSE_POSITIVE_PATTERNS = [
|
||||
r'^system:', # System messages
|
||||
r'heartbeat', # Heartbeat
|
||||
r'^\[media\s+attached', # Media attachments
|
||||
r'^\[cron:', # Cron messages
|
||||
r'^pre-compaction', # Memory flush
|
||||
r'exec\s+(?:completed|failed)', # Exec results
|
||||
r'^read\s+heartbeat', # Heartbeat instructions
|
||||
]
|
||||
|
||||
# Compile patterns
|
||||
CORRECTION_RE = [(re.compile(p, re.IGNORECASE | re.MULTILINE), label, conf)
|
||||
for p, label, conf in CORRECTION_PATTERNS]
|
||||
FALSE_POSITIVE_RE = [re.compile(p, re.IGNORECASE) for p in FALSE_POSITIVE_PATTERNS]
|
||||
|
||||
# Minimum lengths
|
||||
MIN_PROMPT_LEN = 15 # User prompt must be meaningful
|
||||
MIN_RESPONSE_LEN = 80 # Assistant response must be substantive
|
||||
MIN_CORRECTION_LEN = 20 # Correction must explain what's wrong
|
||||
|
||||
|
||||
def is_false_positive(text: str) -> bool:
|
||||
"""Check if text matches false positive patterns."""
|
||||
return any(p.search(text) for p in FALSE_POSITIVE_RE)
|
||||
|
||||
|
||||
def detect_correction(text: str) -> Optional[tuple[str, float]]:
|
||||
"""Detect if user message is a correction. Returns (label, confidence) or None."""
|
||||
# Clean first, then check false positives on cleaned text
|
||||
clean = clean_user_text(text)
|
||||
|
||||
if is_false_positive(clean):
|
||||
return None
|
||||
|
||||
if len(clean) < MIN_CORRECTION_LEN:
|
||||
return None
|
||||
|
||||
best_match = None
|
||||
best_conf = 0.0
|
||||
|
||||
for pattern, label, conf in CORRECTION_RE:
|
||||
if pattern.search(clean):
|
||||
if conf > best_conf:
|
||||
best_match = (label, conf)
|
||||
best_conf = conf
|
||||
|
||||
return best_match
|
||||
|
||||
|
||||
# --- Event Parsing ---
|
||||
|
||||
def extract_user_text(event: dict) -> Optional[str]:
|
||||
"""Extract user text from a conversation.message.in event."""
|
||||
if event.get('type') != 'conversation.message.in':
|
||||
return None
|
||||
|
||||
payload = event.get('payload', {})
|
||||
|
||||
# text_preview format (most common)
|
||||
if isinstance(payload.get('text_preview'), list) and payload['text_preview']:
|
||||
return payload['text_preview'][0].get('text', '')
|
||||
|
||||
# Direct content
|
||||
if 'content' in payload:
|
||||
return payload['content']
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_assistant_text(event: dict) -> Optional[str]:
|
||||
"""Extract assistant text from a conversation.message.out event."""
|
||||
if event.get('type') != 'conversation.message.out':
|
||||
return None
|
||||
|
||||
payload = event.get('payload', {})
|
||||
|
||||
if isinstance(payload.get('data'), dict) and 'text' in payload['data']:
|
||||
return payload['data']['text']
|
||||
|
||||
if 'content' in payload:
|
||||
return payload['content']
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_session_id(event: dict) -> str:
|
||||
"""Extract session identifier from event."""
|
||||
payload = event.get('payload', {})
|
||||
if event.get('type') == 'conversation.message.out' and payload.get('runId'):
|
||||
return payload['runId']
|
||||
if payload.get('sessionId'):
|
||||
return payload['sessionId']
|
||||
return event.get('session', 'unknown')
|
||||
|
||||
|
||||
def clean_user_text(text: str) -> str:
|
||||
"""Strip metadata from user message, return the actual content."""
|
||||
# Remove System: [timestamp] Matrix message from X: prefix
|
||||
text = re.sub(r'^System:\s*\[.*?\]\s*(?:Matrix\s+message\s+from\s+\w+:\s*)?', '', text).strip()
|
||||
# Remove [Day YYYY-MM-DD HH:MM TZ] or [YYYY-MM-DD ...] timestamp prefix
|
||||
text = re.sub(r'^\[(?:\w+\s+)?\d{4}-\d{2}-\d{2}[^\]]*\]\s*', '', text).strip()
|
||||
# Remove [Matrix user ...] prefix
|
||||
text = re.sub(r'^\[Matrix\s+\w+[^\]]*\]\s*', '', text).strip()
|
||||
# Remove message_id lines
|
||||
text = re.sub(r'\[message_id:.*?\]', '', text).strip()
|
||||
return text
|
||||
|
||||
|
||||
# --- DPO Pair Construction ---
|
||||
|
||||
def build_dpo_pair(
|
||||
prompt_event: dict,
|
||||
rejected_event: dict,
|
||||
correction_event: dict,
|
||||
correction_label: str,
|
||||
correction_confidence: float,
|
||||
) -> Optional[dict]:
|
||||
"""Build a DPO training pair from a correction sequence.
|
||||
|
||||
Returns dict with: prompt, chosen, rejected, metadata
|
||||
"""
|
||||
prompt_text = clean_user_text(extract_user_text(prompt_event) or '')
|
||||
rejected_text = extract_assistant_text(rejected_event) or ''
|
||||
correction_text = clean_user_text(extract_user_text(correction_event) or '')
|
||||
|
||||
# Validate lengths
|
||||
if len(prompt_text) < MIN_PROMPT_LEN:
|
||||
return None
|
||||
if len(rejected_text) < MIN_RESPONSE_LEN:
|
||||
return None
|
||||
if len(correction_text) < MIN_CORRECTION_LEN:
|
||||
return None
|
||||
|
||||
# The "chosen" response is constructed from the correction context.
|
||||
# We use the correction as a signal — the chosen text is what the user
|
||||
# wanted instead. For DPO training, we need an actual better response.
|
||||
# Strategy: use the correction itself as context for what "chosen" should be.
|
||||
# In practice, after the correction, the assistant usually gives a better response.
|
||||
# We'll look for that in the caller.
|
||||
|
||||
return {
|
||||
'prompt': prompt_text,
|
||||
'rejected': rejected_text,
|
||||
'chosen': None, # To be filled by caller with post-correction response
|
||||
'correction': correction_text,
|
||||
'metadata': {
|
||||
'correction_type': correction_label,
|
||||
'confidence': correction_confidence,
|
||||
'prompt_seq': prompt_event.get('seq'),
|
||||
'rejected_seq': rejected_event.get('seq'),
|
||||
'correction_seq': correction_event.get('seq'),
|
||||
'session': get_session_id(prompt_event),
|
||||
'timestamp': correction_event.get('timestamp'),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# --- NATS Fetching ---
|
||||
|
||||
# --- Session Transcript Parsing ---
|
||||
|
||||
def load_session_transcripts(sessions_dir: str, since_hours: int = 168) -> list[dict]:
|
||||
"""Load conversation messages from OpenClaw session JSONL files.
|
||||
|
||||
Returns a flat list of events with 'role', 'text', 'session', 'timestamp', 'seq'.
|
||||
"""
|
||||
from pathlib import Path
|
||||
import time
|
||||
|
||||
sessions_path = Path(sessions_dir)
|
||||
if not sessions_path.exists():
|
||||
print(f" ⚠️ Sessions dir not found: {sessions_dir}", file=sys.stderr)
|
||||
return []
|
||||
|
||||
cutoff_time = time.time() - (since_hours * 3600)
|
||||
events = []
|
||||
files_processed = 0
|
||||
|
||||
for jsonl_file in sorted(sessions_path.glob('*.jsonl')):
|
||||
# Skip old files
|
||||
if jsonl_file.stat().st_mtime < cutoff_time:
|
||||
continue
|
||||
|
||||
session_id = jsonl_file.stem
|
||||
seq = 0
|
||||
|
||||
try:
|
||||
with open(jsonl_file) as fh:
|
||||
for line in fh:
|
||||
try:
|
||||
entry = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
|
||||
if entry.get('type') != 'message':
|
||||
continue
|
||||
|
||||
msg = entry.get('message', {})
|
||||
role = msg.get('role', '')
|
||||
if role not in ('user', 'assistant'):
|
||||
continue
|
||||
|
||||
# Extract text content
|
||||
content = msg.get('content', '')
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get('type') == 'text':
|
||||
text_parts.append(part.get('text', ''))
|
||||
content = '\n'.join(text_parts)
|
||||
|
||||
if not content:
|
||||
continue
|
||||
|
||||
events.append({
|
||||
'type': f'conversation.message.{"in" if role == "user" else "out"}',
|
||||
'role': role,
|
||||
'text': content,
|
||||
'session': session_id,
|
||||
'timestamp': entry.get('timestamp', 0),
|
||||
'seq': seq,
|
||||
'payload': {
|
||||
'text_preview': [{'type': 'text', 'text': content}] if role == 'user' else {},
|
||||
'data': {'text': content} if role == 'assistant' else {},
|
||||
'sessionId': session_id,
|
||||
},
|
||||
})
|
||||
seq += 1
|
||||
|
||||
files_processed += 1
|
||||
except Exception as e:
|
||||
print(f" ⚠️ Error reading {jsonl_file.name}: {e}", file=sys.stderr)
|
||||
|
||||
print(f" Loaded {len(events)} messages from {files_processed} session files", file=sys.stderr)
|
||||
return events
|
||||
|
||||
|
||||
def fetch_events_by_sequence(start_seq: int, end_seq: int, batch_size: int = 500) -> list[dict]:
|
||||
"""Fetch events from NATS by sequence range using nats-py (fast) or CLI (fallback)."""
|
||||
# Try fast async fetch first
|
||||
try:
|
||||
import asyncio
|
||||
events = asyncio.run(_fetch_events_async(start_seq, end_seq))
|
||||
if events:
|
||||
return events
|
||||
except Exception as e:
|
||||
print(f" Async fetch failed ({e}), falling back to CLI", file=sys.stderr)
|
||||
|
||||
return _fetch_events_cli(start_seq, end_seq)
|
||||
|
||||
|
||||
async def _fetch_events_async(start_seq: int, end_seq: int) -> list[dict]:
|
||||
"""Fast bulk fetch using nats-py consumer."""
|
||||
import nats as nats_lib
|
||||
from nats.js.api import ConsumerConfig, DeliverPolicy
|
||||
|
||||
user = os.environ.get('NATS_USER', 'claudia')
|
||||
password = os.environ.get('NATS_PASSWORD', '')
|
||||
|
||||
nc = await nats_lib.connect(
|
||||
'nats://localhost:4222',
|
||||
user=user, password=password,
|
||||
)
|
||||
js = nc.jetstream()
|
||||
|
||||
events = []
|
||||
# Create ephemeral ordered consumer starting at our sequence
|
||||
sub = await js.subscribe(
|
||||
'openclaw.events.>',
|
||||
ordered_consumer=True,
|
||||
config=ConsumerConfig(
|
||||
deliver_policy=DeliverPolicy.BY_START_SEQUENCE,
|
||||
opt_start_seq=start_seq,
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
count = 0
|
||||
while True:
|
||||
try:
|
||||
msg = await asyncio.wait_for(sub.next_msg(), timeout=2.0)
|
||||
try:
|
||||
event = json.loads(msg.data.decode())
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
count += 1
|
||||
continue
|
||||
event['seq'] = count + start_seq
|
||||
events.append(event)
|
||||
count += 1
|
||||
|
||||
if count % 1000 == 0:
|
||||
print(f" Fetched {count} events...", file=sys.stderr, end='\r')
|
||||
|
||||
# Stop at end_seq (approximate via count)
|
||||
if count >= (end_seq - start_seq + 1):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
finally:
|
||||
await sub.unsubscribe()
|
||||
await nc.drain()
|
||||
|
||||
print(f" Fetched {len(events)} events (async) ", file=sys.stderr)
|
||||
return events
|
||||
|
||||
|
||||
def _fetch_events_cli(start_seq: int, end_seq: int) -> list[dict]:
|
||||
"""Fallback: fetch events one by one via nats CLI."""
|
||||
import subprocess
|
||||
|
||||
events = []
|
||||
errors = 0
|
||||
|
||||
for seq in range(start_seq, end_seq + 1):
|
||||
try:
|
||||
result = subprocess.run(
|
||||
['nats', 'stream', 'get', 'openclaw-events', str(seq)],
|
||||
capture_output=True, text=True, timeout=5,
|
||||
)
|
||||
for line in result.stdout.split('\n'):
|
||||
if line.startswith('{'):
|
||||
event = json.loads(line)
|
||||
event['seq'] = seq
|
||||
events.append(event)
|
||||
break
|
||||
except Exception:
|
||||
errors += 1
|
||||
if errors > 50:
|
||||
print(f" ⚠️ Too many errors ({errors}), stopping", file=sys.stderr)
|
||||
break
|
||||
|
||||
if (seq - start_seq) % 200 == 0 and seq > start_seq:
|
||||
print(f" Fetched {seq - start_seq}/{end_seq - start_seq} events...",
|
||||
file=sys.stderr, end='\r')
|
||||
|
||||
print(f" Fetched {len(events)} events ({errors} errors) ", file=sys.stderr)
|
||||
return events
|
||||
|
||||
|
||||
def get_stream_info() -> dict:
|
||||
"""Get NATS stream info."""
|
||||
import subprocess
|
||||
result = subprocess.run(
|
||||
['nats', 'stream', 'info', 'openclaw-events', '--json'],
|
||||
capture_output=True, text=True, timeout=10,
|
||||
)
|
||||
return json.loads(result.stdout)
|
||||
|
||||
|
||||
# --- Main Extraction Pipeline ---
|
||||
|
||||
def extract_dpo_pairs(events: list[dict], verbose: bool = False) -> list[dict]:
|
||||
"""Extract DPO pairs from a list of events.
|
||||
|
||||
Scanning strategy:
|
||||
1. Group events by session
|
||||
2. Within each session, find the pattern:
|
||||
user_msg → assistant_response → correction → (optional) better_response
|
||||
3. Build DPO pair: prompt=user_msg, rejected=assistant_response, chosen=better_response
|
||||
"""
|
||||
# Group by session
|
||||
sessions: dict[str, list[dict]] = {}
|
||||
for event in events:
|
||||
sid = get_session_id(event)
|
||||
sessions.setdefault(sid, []).append(event)
|
||||
|
||||
pairs = []
|
||||
stats = {'sessions': 0, 'corrections_found': 0, 'pairs_built': 0, 'pairs_with_chosen': 0}
|
||||
|
||||
for sid, session_events in sessions.items():
|
||||
session_events.sort(key=lambda e: e.get('seq', 0))
|
||||
stats['sessions'] += 1
|
||||
|
||||
# Build conversation sequence: list of (role, text, event)
|
||||
conversation = []
|
||||
for event in session_events:
|
||||
user_text = extract_user_text(event)
|
||||
if user_text:
|
||||
conversation.append(('user', user_text, event))
|
||||
continue
|
||||
asst_text = extract_assistant_text(event)
|
||||
if asst_text:
|
||||
# Keep the longest assistant response in a streak
|
||||
if conversation and conversation[-1][0] == 'assistant':
|
||||
if len(asst_text) > len(conversation[-1][1]):
|
||||
conversation[-1] = ('assistant', asst_text, event)
|
||||
else:
|
||||
conversation.append(('assistant', asst_text, event))
|
||||
|
||||
# Scan for correction pattern: user → assistant → user(correction) → assistant(better)
|
||||
for i in range(len(conversation) - 2):
|
||||
if (conversation[i][0] == 'user' and
|
||||
conversation[i + 1][0] == 'assistant' and
|
||||
conversation[i + 2][0] == 'user'):
|
||||
|
||||
correction_result = detect_correction(conversation[i + 2][1])
|
||||
if not correction_result:
|
||||
continue
|
||||
|
||||
label, confidence = correction_result
|
||||
stats['corrections_found'] += 1
|
||||
|
||||
pair = build_dpo_pair(
|
||||
prompt_event=conversation[i][2],
|
||||
rejected_event=conversation[i + 1][2],
|
||||
correction_event=conversation[i + 2][2],
|
||||
correction_label=label,
|
||||
correction_confidence=confidence,
|
||||
)
|
||||
|
||||
if not pair:
|
||||
continue
|
||||
|
||||
# Look for the better response after the correction
|
||||
if (i + 3 < len(conversation) and
|
||||
conversation[i + 3][0] == 'assistant'):
|
||||
chosen_text = conversation[i + 3][1]
|
||||
if len(chosen_text) >= MIN_RESPONSE_LEN:
|
||||
pair['chosen'] = chosen_text
|
||||
stats['pairs_with_chosen'] += 1
|
||||
|
||||
stats['pairs_built'] += 1
|
||||
pairs.append(pair)
|
||||
|
||||
if verbose:
|
||||
print(f"\n 📌 [{label}] conf={confidence:.0%}", file=sys.stderr)
|
||||
print(f" prompt: {pair['prompt'][:80]}...", file=sys.stderr)
|
||||
print(f" correction: {pair['correction'][:80]}...", file=sys.stderr)
|
||||
|
||||
return pairs, stats
|
||||
|
||||
|
||||
def to_dpo_training_format(pairs: list[dict]) -> list[dict]:
|
||||
"""Convert to standard DPO training format for trl.DPOTrainer.
|
||||
|
||||
Only includes pairs that have both chosen and rejected responses.
|
||||
"""
|
||||
training_pairs = []
|
||||
for pair in pairs:
|
||||
if not pair.get('chosen'):
|
||||
continue
|
||||
|
||||
training_pairs.append({
|
||||
'prompt': pair['prompt'],
|
||||
'chosen': pair['chosen'],
|
||||
'rejected': pair['rejected'],
|
||||
})
|
||||
|
||||
return training_pairs
|
||||
|
||||
|
||||
def to_detailed_format(pairs: list[dict]) -> list[dict]:
|
||||
"""Full format with metadata for inspection and debugging."""
|
||||
return [{
|
||||
'prompt': p['prompt'],
|
||||
'chosen': p.get('chosen', ''),
|
||||
'rejected': p['rejected'],
|
||||
'correction': p['correction'],
|
||||
'has_chosen': bool(p.get('chosen')),
|
||||
**p['metadata'],
|
||||
} for p in pairs]
|
||||
|
||||
|
||||
# --- CLI ---
|
||||
|
||||
def parse_duration(s: str) -> timedelta:
|
||||
"""Parse '7d', '24h', '30m' to timedelta."""
|
||||
m = re.match(r'^(\d+)([dhm])$', s.lower())
|
||||
if not m:
|
||||
raise ValueError(f"Invalid duration: {s}")
|
||||
v, u = int(m.group(1)), m.group(2)
|
||||
return {'d': timedelta(days=v), 'h': timedelta(hours=v), 'm': timedelta(minutes=v)}[u]
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Extract DPO preference pairs from NATS event store',
|
||||
)
|
||||
parser.add_argument('--since', default='7d', help='Time window (e.g. 7d, 24h)')
|
||||
parser.add_argument('--output', '-o', help='Output file (default: auto)')
|
||||
parser.add_argument('--format', choices=['training', 'detailed', 'both'], default='both',
|
||||
help='Output format')
|
||||
parser.add_argument('--min-confidence', type=float, default=0.7,
|
||||
help='Minimum correction confidence (0-1)')
|
||||
parser.add_argument('--dry-run', action='store_true', help='Show stats only, no output')
|
||||
parser.add_argument('--verbose', '-v', action='store_true', help='Show each found pair')
|
||||
parser.add_argument('--sessions-dir', default=None,
|
||||
help='Path to OpenClaw session JSONL dir (default: ~/.openclaw/agents/main/sessions)')
|
||||
parser.add_argument('--source', choices=['sessions', 'nats', 'auto'], default='auto',
|
||||
help='Data source: session transcripts (preferred) or NATS events')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("🔍 DPO Preference Pair Extractor", file=sys.stderr)
|
||||
|
||||
duration = parse_duration(args.since)
|
||||
hours = duration.total_seconds() / 3600
|
||||
|
||||
# Determine data source
|
||||
sessions_dir = args.sessions_dir or os.path.expanduser('~/.openclaw/agents/main/sessions')
|
||||
use_sessions = args.source == 'sessions' or (
|
||||
args.source == 'auto' and os.path.isdir(sessions_dir)
|
||||
)
|
||||
|
||||
if use_sessions:
|
||||
print(f" Source: Session transcripts ({sessions_dir})", file=sys.stderr)
|
||||
conv_events = load_session_transcripts(sessions_dir, since_hours=int(hours))
|
||||
else:
|
||||
print(f" Source: NATS event store", file=sys.stderr)
|
||||
info = get_stream_info()
|
||||
last_seq = info['state']['last_seq']
|
||||
first_seq = info['state']['first_seq']
|
||||
estimated_events = int(hours * 50)
|
||||
start_seq = max(first_seq, last_seq - estimated_events)
|
||||
print(f" Scanning sequences {start_seq}-{last_seq}", file=sys.stderr)
|
||||
events = fetch_events_by_sequence(start_seq, last_seq)
|
||||
conv_events = [e for e in events if e.get('type', '').startswith('conversation.message')]
|
||||
print(f" {len(conv_events)} conversation events out of {len(events)} total", file=sys.stderr)
|
||||
|
||||
# Extract pairs
|
||||
pairs, stats = extract_dpo_pairs(conv_events, verbose=args.verbose)
|
||||
|
||||
# Filter by confidence
|
||||
pairs = [p for p in pairs if p['metadata']['confidence'] >= args.min_confidence]
|
||||
|
||||
# Stats
|
||||
print(f"\n📊 Results:", file=sys.stderr)
|
||||
print(f" Sessions scanned: {stats['sessions']}", file=sys.stderr)
|
||||
print(f" Corrections detected: {stats['corrections_found']}", file=sys.stderr)
|
||||
print(f" DPO pairs built: {stats['pairs_built']}", file=sys.stderr)
|
||||
print(f" Pairs with chosen response: {stats['pairs_with_chosen']}", file=sys.stderr)
|
||||
print(f" After confidence filter (≥{args.min_confidence}): {len(pairs)}", file=sys.stderr)
|
||||
|
||||
complete = [p for p in pairs if p.get('chosen')]
|
||||
incomplete = [p for p in pairs if not p.get('chosen')]
|
||||
print(f" Complete (prompt+chosen+rejected): {len(complete)}", file=sys.stderr)
|
||||
print(f" Incomplete (no chosen response): {len(incomplete)}", file=sys.stderr)
|
||||
|
||||
if args.dry_run:
|
||||
# Show sample pairs
|
||||
if pairs:
|
||||
print(f"\n📝 Sample pairs:", file=sys.stderr)
|
||||
for p in pairs[:5]:
|
||||
print(f"\n [{p['metadata']['correction_type']}] "
|
||||
f"conf={p['metadata']['confidence']:.0%}", file=sys.stderr)
|
||||
print(f" prompt: {p['prompt'][:100]}", file=sys.stderr)
|
||||
print(f" rejected: {p['rejected'][:100]}", file=sys.stderr)
|
||||
print(f" correction: {p['correction'][:100]}", file=sys.stderr)
|
||||
if p.get('chosen'):
|
||||
print(f" chosen: {p['chosen'][:100]}", file=sys.stderr)
|
||||
return
|
||||
|
||||
# Output
|
||||
output_dir = Path(os.environ.get('CLAWD_DIR', Path.home() / 'clawd')) / 'training-data'
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = datetime.now().strftime('%Y-%m-%d')
|
||||
|
||||
if args.format in ('training', 'both'):
|
||||
training_data = to_dpo_training_format(pairs)
|
||||
path = Path(args.output) if args.output else output_dir / f'dpo-training-{timestamp}.json'
|
||||
path.write_text(json.dumps(training_data, indent=2, ensure_ascii=False))
|
||||
print(f"\n✅ Training format: {path} ({len(training_data)} pairs)", file=sys.stderr)
|
||||
|
||||
if args.format in ('detailed', 'both'):
|
||||
detailed_data = to_detailed_format(pairs)
|
||||
path = output_dir / f'dpo-detailed-{timestamp}.json'
|
||||
path.write_text(json.dumps(detailed_data, indent=2, ensure_ascii=False))
|
||||
print(f"✅ Detailed format: {path} ({len(detailed_data)} pairs)", file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
341
tests/test_dpo_extractor.py
Normal file
341
tests/test_dpo_extractor.py
Normal file
|
|
@ -0,0 +1,341 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Tests for DPO Preference Pair Extractor."""
|
||||
|
||||
import pytest
|
||||
from cortex.dpo_extractor import (
|
||||
detect_correction,
|
||||
is_false_positive,
|
||||
extract_user_text,
|
||||
extract_assistant_text,
|
||||
get_session_id,
|
||||
clean_user_text,
|
||||
build_dpo_pair,
|
||||
extract_dpo_pairs,
|
||||
to_dpo_training_format,
|
||||
to_detailed_format,
|
||||
MIN_PROMPT_LEN,
|
||||
MIN_RESPONSE_LEN,
|
||||
MIN_CORRECTION_LEN,
|
||||
)
|
||||
|
||||
|
||||
# --- Correction Detection ---
|
||||
|
||||
class TestDetectCorrection:
|
||||
"""Test correction pattern detection."""
|
||||
|
||||
def test_direct_negation_with_instruction(self):
|
||||
result = detect_correction("nein, du sollst ollama anhalten und das model trainieren")
|
||||
assert result is not None
|
||||
label, conf = result
|
||||
assert label == 'direct_correction'
|
||||
assert conf >= 0.85
|
||||
|
||||
def test_redirect(self):
|
||||
result = detect_correction("nicht so, sondern mit apply_chat_template")
|
||||
assert result is not None
|
||||
assert result[0] == 'redirect'
|
||||
|
||||
def test_factual_correction(self):
|
||||
result = detect_correction("das stimmt nicht, die GPU hat 20GB VRAM")
|
||||
assert result is not None
|
||||
assert result[0] == 'factual_correction'
|
||||
|
||||
def test_frustration_correction(self):
|
||||
result = detect_correction("bist du bescheuert? Wir haben doch alles da")
|
||||
assert result is not None
|
||||
assert result[0] == 'frustration_correction'
|
||||
assert result[1] >= 0.9
|
||||
|
||||
def test_forgotten(self):
|
||||
result = detect_correction("du hast vergessen warum wir Darkplex gebaut haben")
|
||||
assert result is not None
|
||||
assert result[0] == 'forgotten'
|
||||
|
||||
def test_repeated_instruction(self):
|
||||
result = detect_correction("ich habe doch gesagt wir sollen Gemma nehmen")
|
||||
assert result is not None
|
||||
assert result[0] == 'repeated_instruction'
|
||||
|
||||
def test_mild_redirect(self):
|
||||
result = detect_correction("naja, eher so dass wir die Pipeline automatisieren")
|
||||
assert result is not None
|
||||
assert result[0] == 'mild_redirect'
|
||||
|
||||
def test_negative_feedback(self):
|
||||
result = detect_correction("finds blöd dass ich dich immer so testen muss")
|
||||
assert result is not None
|
||||
assert result[0] == 'negative_feedback'
|
||||
|
||||
def test_should_redirect(self):
|
||||
result = detect_correction("du solltest eher den bestehenden Code checken statt neu zu bauen")
|
||||
assert result is not None
|
||||
assert result[0] == 'should_redirect'
|
||||
|
||||
# --- Non-corrections ---
|
||||
|
||||
def test_normal_message_no_correction(self):
|
||||
result = detect_correction("kannst du mir den Status von Ollama zeigen?")
|
||||
assert result is None
|
||||
|
||||
def test_positive_feedback_no_correction(self):
|
||||
result = detect_correction("super, genau so wollte ich das haben!")
|
||||
assert result is None
|
||||
|
||||
def test_short_message_no_correction(self):
|
||||
result = detect_correction("ja")
|
||||
assert result is None
|
||||
|
||||
def test_question_no_correction(self):
|
||||
result = detect_correction("wie viel VRAM hat die GPU auf Desktop01?")
|
||||
assert result is None
|
||||
|
||||
def test_timestamp_prefix_stripped(self):
|
||||
result = detect_correction("[Thu 2026-02-12 09:17 GMT+1] nein, du sollst ollama anhalten")
|
||||
assert result is not None
|
||||
assert result[0] == 'direct_correction'
|
||||
|
||||
def test_matrix_prefix_stripped(self):
|
||||
result = detect_correction(
|
||||
"System: [2026-02-12 09:17:00 GMT+1] Matrix message from albert: "
|
||||
"nein, du sollst das anders machen, ich meinte die andere Config"
|
||||
)
|
||||
assert result is not None
|
||||
|
||||
|
||||
class TestFalsePositives:
|
||||
"""Ensure we don't flag system messages as corrections."""
|
||||
|
||||
def test_system_message(self):
|
||||
assert is_false_positive("System: heartbeat check ok") is True
|
||||
|
||||
def test_heartbeat(self):
|
||||
assert is_false_positive("Read HEARTBEAT.md if it exists") is True
|
||||
|
||||
def test_media_attachment(self):
|
||||
assert is_false_positive("[media attached: /home/keller/file.pdf]") is True
|
||||
|
||||
def test_cron_message(self):
|
||||
assert is_false_positive("[cron:abc-123] Run learning context") is True
|
||||
|
||||
def test_exec_result(self):
|
||||
assert is_false_positive("exec completed with code 0") is True
|
||||
|
||||
def test_precompaction(self):
|
||||
assert is_false_positive("Pre-compaction memory flush. Store durable memories now") is True
|
||||
|
||||
def test_normal_message_not_false_positive(self):
|
||||
assert is_false_positive("lass uns ein anderes model wählen") is False
|
||||
|
||||
|
||||
# --- Event Parsing ---
|
||||
|
||||
class TestEventParsing:
|
||||
"""Test extraction from various event formats."""
|
||||
|
||||
def test_extract_user_text_preview(self):
|
||||
event = {
|
||||
'type': 'conversation.message.in',
|
||||
'payload': {
|
||||
'text_preview': [{'type': 'text', 'text': 'hallo welt'}],
|
||||
}
|
||||
}
|
||||
assert extract_user_text(event) == 'hallo welt'
|
||||
|
||||
def test_extract_user_content(self):
|
||||
event = {
|
||||
'type': 'conversation.message.in',
|
||||
'payload': {'content': 'hallo welt'},
|
||||
}
|
||||
assert extract_user_text(event) == 'hallo welt'
|
||||
|
||||
def test_extract_user_wrong_type(self):
|
||||
event = {
|
||||
'type': 'conversation.message.out',
|
||||
'payload': {'content': 'response'},
|
||||
}
|
||||
assert extract_user_text(event) is None
|
||||
|
||||
def test_extract_assistant_data_text(self):
|
||||
event = {
|
||||
'type': 'conversation.message.out',
|
||||
'payload': {'data': {'text': 'here is my response'}},
|
||||
}
|
||||
assert extract_assistant_text(event) == 'here is my response'
|
||||
|
||||
def test_extract_assistant_content(self):
|
||||
event = {
|
||||
'type': 'conversation.message.out',
|
||||
'payload': {'content': 'response text'},
|
||||
}
|
||||
assert extract_assistant_text(event) == 'response text'
|
||||
|
||||
def test_session_id_from_payload(self):
|
||||
event = {
|
||||
'type': 'conversation.message.in',
|
||||
'payload': {'sessionId': 'abc-123'},
|
||||
}
|
||||
assert get_session_id(event) == 'abc-123'
|
||||
|
||||
def test_session_id_from_run_id(self):
|
||||
event = {
|
||||
'type': 'conversation.message.out',
|
||||
'payload': {'runId': 'run-456'},
|
||||
}
|
||||
assert get_session_id(event) == 'run-456'
|
||||
|
||||
|
||||
class TestCleanUserText:
|
||||
def test_strip_timestamp(self):
|
||||
assert clean_user_text("[Thu 2026-02-12 09:17 GMT+1] hallo") == "hallo"
|
||||
|
||||
def test_strip_matrix_prefix(self):
|
||||
text = "System: [2026-02-12 09:17:00 GMT+1] Matrix message from albert: hallo"
|
||||
assert clean_user_text(text) == "hallo"
|
||||
|
||||
def test_strip_message_id(self):
|
||||
text = "some text\n[message_id: abc-123]"
|
||||
assert clean_user_text(text) == "some text"
|
||||
|
||||
def test_plain_text_unchanged(self):
|
||||
assert clean_user_text("just a normal message") == "just a normal message"
|
||||
|
||||
|
||||
# --- DPO Pair Building ---
|
||||
|
||||
def _make_event(type_: str, text: str, seq: int = 1, session: str = 'test') -> dict:
|
||||
"""Helper to create test events."""
|
||||
if type_ == 'conversation.message.in':
|
||||
return {
|
||||
'type': type_,
|
||||
'seq': seq,
|
||||
'timestamp': 1000000 + seq,
|
||||
'payload': {
|
||||
'text_preview': [{'type': 'text', 'text': text}],
|
||||
'sessionId': session,
|
||||
},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
'type': type_,
|
||||
'seq': seq,
|
||||
'timestamp': 1000000 + seq,
|
||||
'payload': {
|
||||
'data': {'text': text},
|
||||
'runId': session,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class TestBuildDPOPair:
|
||||
def test_valid_pair(self):
|
||||
prompt = _make_event('conversation.message.in',
|
||||
'Kannst du mir ein Training-Script für Gemma erstellen?', seq=1)
|
||||
rejected = _make_event('conversation.message.out',
|
||||
'Hier ist ein Script mit hardcoded chat template tags... ' * 5, seq=2)
|
||||
correction = _make_event('conversation.message.in',
|
||||
'nein, du sollst tokenizer.apply_chat_template benutzen statt hardcoded tags',
|
||||
seq=3)
|
||||
|
||||
pair = build_dpo_pair(prompt, rejected, correction, 'direct_correction', 0.9)
|
||||
assert pair is not None
|
||||
assert 'Gemma' in pair['prompt']
|
||||
assert 'hardcoded' in pair['rejected']
|
||||
assert pair['chosen'] is None # Caller fills this
|
||||
assert pair['metadata']['correction_type'] == 'direct_correction'
|
||||
|
||||
def test_too_short_prompt(self):
|
||||
prompt = _make_event('conversation.message.in', 'ja?', seq=1)
|
||||
rejected = _make_event('conversation.message.out', 'x' * 100, seq=2)
|
||||
correction = _make_event('conversation.message.in', 'nein, ich meinte etwas ganz anderes', seq=3)
|
||||
|
||||
pair = build_dpo_pair(prompt, rejected, correction, 'direct_correction', 0.9)
|
||||
assert pair is None
|
||||
|
||||
def test_too_short_response(self):
|
||||
prompt = _make_event('conversation.message.in', 'Erstelle mir einen DPO Extraktor', seq=1)
|
||||
rejected = _make_event('conversation.message.out', 'Ok.', seq=2)
|
||||
correction = _make_event('conversation.message.in', 'das ist falsch, mach das bitte richtig', seq=3)
|
||||
|
||||
pair = build_dpo_pair(prompt, rejected, correction, 'factual_correction', 0.85)
|
||||
assert pair is None
|
||||
|
||||
|
||||
# --- Full Pipeline ---
|
||||
|
||||
class TestExtractDPOPairs:
|
||||
def test_basic_correction_flow(self):
|
||||
"""Test: user → assistant → correction → better response."""
|
||||
events = [
|
||||
_make_event('conversation.message.in',
|
||||
'Erstelle mir ein Training Script für Gemma 2 auf der 7800 XT', seq=1),
|
||||
_make_event('conversation.message.out',
|
||||
'Hier ist das Script mit <start_of_turn>user tags... ' * 5, seq=2),
|
||||
_make_event('conversation.message.in',
|
||||
'nein, du sollst tokenizer.apply_chat_template() benutzen, nicht hardcoded tags',
|
||||
seq=3),
|
||||
_make_event('conversation.message.out',
|
||||
'Du hast recht, hier ist die korrigierte Version mit apply_chat_template()... ' * 5,
|
||||
seq=4),
|
||||
]
|
||||
|
||||
pairs, stats = extract_dpo_pairs(events)
|
||||
assert len(pairs) >= 1
|
||||
assert pairs[0].get('chosen') is not None
|
||||
assert stats['corrections_found'] >= 1
|
||||
|
||||
def test_no_corrections(self):
|
||||
"""Normal conversation without corrections → no pairs."""
|
||||
events = [
|
||||
_make_event('conversation.message.in', 'Was ist der Status von Ollama?', seq=1),
|
||||
_make_event('conversation.message.out', 'Ollama läuft und hat das claudia-memory Modell geladen. ' * 5, seq=2),
|
||||
_make_event('conversation.message.in', 'Super, danke für die Info!', seq=3),
|
||||
]
|
||||
|
||||
pairs, stats = extract_dpo_pairs(events)
|
||||
assert len(pairs) == 0
|
||||
|
||||
def test_multiple_sessions(self):
|
||||
"""Corrections from different sessions are handled independently."""
|
||||
events = [
|
||||
_make_event('conversation.message.in', 'Mach mir ein neues Feature für Darkplex', seq=1, session='s1'),
|
||||
_make_event('conversation.message.out', 'Ich baue jetzt ein komplett neues System dafür... ' * 5, seq=2, session='s1'),
|
||||
_make_event('conversation.message.in',
|
||||
'bist du bescheuert? Wir haben doch alles da, schau erst was existiert',
|
||||
seq=3, session='s1'),
|
||||
_make_event('conversation.message.in', 'Zeig mir den Wetterbericht', seq=4, session='s2'),
|
||||
_make_event('conversation.message.out', 'Das Wetter in Berlin ist sonnig bei 5 Grad... ' * 5, seq=5, session='s2'),
|
||||
]
|
||||
|
||||
pairs, stats = extract_dpo_pairs(events)
|
||||
# Only session s1 should produce a pair
|
||||
assert len(pairs) == 1
|
||||
assert pairs[0]['metadata']['session'] == 's1'
|
||||
|
||||
|
||||
class TestOutputFormats:
|
||||
def test_training_format_only_complete(self):
|
||||
"""Training format only includes pairs with both chosen and rejected."""
|
||||
pairs = [
|
||||
{'prompt': 'q1', 'chosen': 'good answer', 'rejected': 'bad answer',
|
||||
'correction': 'fix it', 'metadata': {}},
|
||||
{'prompt': 'q2', 'chosen': None, 'rejected': 'wrong',
|
||||
'correction': 'nope', 'metadata': {}},
|
||||
]
|
||||
training = to_dpo_training_format(pairs)
|
||||
assert len(training) == 1
|
||||
assert training[0]['prompt'] == 'q1'
|
||||
|
||||
def test_detailed_format_all_pairs(self):
|
||||
"""Detailed format includes all pairs with metadata."""
|
||||
pairs = [
|
||||
{'prompt': 'q1', 'chosen': 'good', 'rejected': 'bad',
|
||||
'correction': 'fix', 'metadata': {'correction_type': 'test', 'confidence': 0.9,
|
||||
'prompt_seq': 1, 'rejected_seq': 2,
|
||||
'correction_seq': 3, 'session': 's1',
|
||||
'timestamp': 1000}},
|
||||
]
|
||||
detailed = to_detailed_format(pairs)
|
||||
assert len(detailed) == 1
|
||||
assert detailed[0]['correction_type'] == 'test'
|
||||
assert detailed[0]['has_chosen'] is True
|
||||
Loading…
Reference in a new issue