feat: expand signal extraction — positive reinforcement, teaching, soft redirects
Some checks failed
Tests / test (push) Failing after 5s
Some checks failed
Tests / test (push) Failing after 5s
- Renamed to 'Preference & Learning Pair Extractor'
- NEW: Positive reinforcement detection (praise, affirmation, emoji, accept+continue)
- NEW: Teaching moment detection (rules, explanations, reminders, preferences)
- NEW: Soft redirect detection (let's rather, alternative plan, switch to)
- Outputs: DPO pairs, SFT pairs (alpaca format), teaching pairs
- Improved false positive filters (subagent output, apt, system messages)
- Roughly 4x more training signal: 3 → 11 pairs from the same 30-day window
- 446 tests passing
This commit is contained in:
parent
a3764c627d
commit
d60d337da3
1 changed file with 376 additions and 92 deletions
|
|
@ -1,16 +1,19 @@
|
|||
#!/usr/bin/env python3
|
||||
"""DPO Preference Pair Extractor for Darkplex.
|
||||
"""Preference & Learning Pair Extractor for Darkplex.
|
||||
|
||||
Extracts Direct Preference Optimization (DPO) training pairs from NATS event store.
|
||||
DPO pairs consist of (prompt, chosen, rejected) where:
|
||||
- prompt: the user's original request
|
||||
- rejected: Claudia's response that was corrected
|
||||
- chosen: the corrected/better response (derived from the correction)
|
||||
Extracts training signal from session transcripts in multiple categories:
|
||||
|
||||
Correction detection strategy:
|
||||
1. Find user messages that explicitly correct the previous assistant response
|
||||
2. Pair the corrected response (rejected) with the correction context (chosen)
|
||||
3. Quality-filter to avoid false positives
|
||||
1. **DPO Pairs** (correction → preference): prompt + chosen + rejected
|
||||
- Hard corrections: "nein, falsch, das stimmt nicht"
|
||||
- Soft redirects: "lass uns lieber...", "eher so..."
|
||||
|
||||
2. **SFT Pairs** (positive reinforcement → good examples): prompt + response
|
||||
- Positive signals: "super", "genau so", "perfekt", 👍
|
||||
- These are responses worth reinforcing
|
||||
|
||||
3. **Teaching Pairs** (knowledge transfer → learning): context + lesson
|
||||
- Albert explains something new
|
||||
- "wir haben mal gesagt...", "das ist weil...", "denk dran..."
|
||||
|
||||
Usage:
|
||||
python -m cortex.dpo_extractor --since 7d
|
||||
|
|
@ -60,6 +63,9 @@ FALSE_POSITIVE_PATTERNS = [
|
|||
r'^pre-compaction', # Memory flush
|
||||
r'exec\s+(?:completed|failed)', # Exec results
|
||||
r'^read\s+heartbeat', # Heartbeat instructions
|
||||
r'^a\s+subagent\s+task', # Sub-agent completion
|
||||
r'^creating\s+(?:config|symlink)', # System output
|
||||
r'apt\.conf', # Package manager output
|
||||
]
|
||||
|
||||
# Compile patterns
|
||||
|
|
@ -67,10 +73,115 @@ CORRECTION_RE = [(re.compile(p, re.IGNORECASE | re.MULTILINE), label, conf)
|
|||
for p, label, conf in CORRECTION_PATTERNS]
|
||||
# Pre-compiled, case-insensitive versions of the false-positive filters.
FALSE_POSITIVE_RE = [
    re.compile(source, re.IGNORECASE)
    for source in FALSE_POSITIVE_PATTERNS
]
|
||||
|
||||
# --- Positive Reinforcement Detection ---
|
||||
|
||||
# Each entry is (regex source, signal label, confidence).
POSITIVE_PATTERNS = [
    # Explicit praise
    (r'(?:^|\s)(?:super|perfekt|genau\s*so|klasse|toll|prima|excellent|great|nice)[\s!\.]*$', 'praise', 0.85),
    (r'(?:^|\s)(?:genau|exactly|perfect|richtig)[\s!\.]*$', 'affirmation', 0.8),
    (r'(?:das|so)\s+(?:ist|war)\s+(?:super|gut|perfekt|genau\s+richtig)', 'quality_praise', 0.85),
    # The negative lookahead rejects negated forms ("gefällt mir nicht",
    # "mag ich gar nicht"): those are complaints and must not be harvested
    # as positively reinforced examples.
    (r'\b(?:gefällt|mag)\s+(?:mir|ich)\b(?!\s+(?:(?:gar|überhaupt)\s+)?nicht\b)', 'preference_positive', 0.75),
    (r'(?:gut|super|toll)\s+gemacht', 'task_praise', 0.9),
    # The heart is U+2764 optionally followed by the (invisible) U+FE0F
    # variation selector; accept both spellings.
    (r'👍|👏|🙌|❤\uFE0F?|🔥|💯|✅', 'emoji_positive', 0.7),
    # Implicit positive: user builds on the response.
    # The leading \b keeps "ja"/"ok" from matching inside longer words.
    (r'\b(?:ja|ok|alles\s*klar)[,.]?\s+(?:und\s+jetzt|dann|mach\s+(?:mal|jetzt|weiter))', 'accept_continue', 0.7),
]

# Compiled once at import time; MULTILINE lets ^/$ anchor on every line of
# a multi-line user message.
POSITIVE_RE = [(re.compile(p, re.IGNORECASE | re.MULTILINE), label, conf)
               for p, label, conf in POSITIVE_PATTERNS]
|
||||
|
||||
# --- Teaching/Knowledge Transfer Detection ---
|
||||
|
||||
# Each entry is (regex source, signal label, confidence).
TEACHING_PATTERNS = [
    # The user explains why/how something is done
    (
        r'(?:wir\s+haben\s+(?:mal\s+)?gesagt|wir\s+machen\s+das\s+(?:so|weil))',
        'established_rule',
        0.85,
    ),
    (
        r'(?:das\s+ist\s+weil|der\s+grund\s+(?:ist|dafür))',
        'explanation',
        0.8,
    ),
    (
        r'(?:denk\s+dran|vergiss\s+nicht|wichtig(?:\s+ist)?:)',
        'reminder_teaching',
        0.85,
    ),
    (
        r'(?:du\s+(?:hast|hattest)\s+(?:das\s+)?(?:schon\s+)?(?:mal\s+)?(?:gemacht|installiert|gebaut))',
        'prior_knowledge',
        0.8,
    ),
    (
        r'(?:das\s+(?:problem|thema)\s+ist\s+(?:ja\s+)?(?:dass|weil|-))',
        'problem_framing',
        0.75,
    ),
    (
        r'(?:ich\s+(?:will|möchte|hätte\s+gerne)\s+(?:dass|lieber|eher))',
        'preference_statement',
        0.8,
    ),
    (
        r'(?:die\s+(?:regel|strategie|idee)\s+ist)',
        'strategy_teaching',
        0.85,
    ),
    # The user shares context the assistant should remember
    (
        r'(?:(?:zur\s+)?info:|fyi:?|heads\s*up:?)',
        'info_sharing',
        0.7,
    ),
]

# Compiled counterparts, in the same order as TEACHING_PATTERNS.
TEACHING_RE = [
    (re.compile(source, re.IGNORECASE | re.MULTILINE), tag, weight)
    for source, tag, weight in TEACHING_PATTERNS
]
|
||||
|
||||
|
||||
# --- Soft Redirect Detection (milder than corrections) ---
|
||||
|
||||
# Each entry is (regex source, signal label, confidence).
SOFT_REDIRECT_PATTERNS = [
    (
        r'(?:lass\s+uns\s+(?:lieber|eher|besser|mal))',
        'lets_rather',
        0.75,
    ),
    (
        r'(?:ich\s+würde?\s+(?:eher|lieber|besser))',
        'i_would_rather',
        0.7,
    ),
    (
        r'(?:können?\s+wir\s+(?:nicht\s+)?(?:lieber|eher|stattdessen))',
        'can_we_instead',
        0.7,
    ),
    (
        r'(?:anderer?\s+(?:plan|idee|ansatz|weg|vorschlag))',
        'alternative_plan',
        0.75,
    ),
    (
        r'(?:wechsel(?:n)?\s+(?:wir|mal)\s+(?:zu|auf|den))',
        'switch_to',
        0.7,
    ),
]

# Compiled counterparts, in the same order as SOFT_REDIRECT_PATTERNS.
SOFT_REDIRECT_RE = [
    (re.compile(source, re.IGNORECASE | re.MULTILINE), tag, weight)
    for source, tag, weight in SOFT_REDIRECT_PATTERNS
]
|
||||
|
||||
|
||||
# Minimum lengths
|
||||
# User prompt must be meaningful
MIN_PROMPT_LEN = 15
# Assistant response must be substantive
MIN_RESPONSE_LEN = 80
# Correction must explain what's wrong
MIN_CORRECTION_LEN = 20
# Teaching must have substance
MIN_TEACHING_LEN = 30
|
||||
|
||||
|
||||
def detect_positive(text: str) -> Optional[tuple[str, float]]:
    """Classify a user message as positive reinforcement.

    The text is cleaned with ``clean_user_text`` first. Messages flagged by
    ``is_false_positive`` or longer than 200 characters are rejected.

    Returns:
        ``(label, confidence)`` of the highest-confidence entry in
        ``POSITIVE_RE`` that matches (the earliest entry wins a tie), or
        ``None`` when nothing matches or the message was rejected.
    """
    cleaned = clean_user_text(text)
    # Long messages are rarely just praise
    if is_false_positive(cleaned) or len(cleaned) > 200:
        return None

    winner = None
    for regex, tag, weight in POSITIVE_RE:
        current_top = winner[1] if winner else 0.0
        if regex.search(cleaned) and weight > current_top:
            winner = (tag, weight)
    return winner
|
||||
|
||||
|
||||
def detect_teaching(text: str) -> Optional[tuple[str, float]]:
    """Classify a user message as a teaching/knowledge-transfer moment.

    The text is cleaned with ``clean_user_text`` first. Messages flagged by
    ``is_false_positive`` or shorter than ``MIN_TEACHING_LEN`` are rejected.

    Returns:
        ``(label, confidence)`` of the highest-confidence entry in
        ``TEACHING_RE`` that matches (the earliest entry wins a tie), or
        ``None`` when nothing matches or the message was rejected.
    """
    cleaned = clean_user_text(text)
    if is_false_positive(cleaned):
        return None
    if len(cleaned) < MIN_TEACHING_LEN:
        return None

    # Only strictly positive confidences can win, mirroring a running
    # maximum that starts at 0.0.
    hits = [
        (tag, weight)
        for regex, tag, weight in TEACHING_RE
        if weight > 0.0 and regex.search(cleaned)
    ]
    # max() returns the first maximal item, so earlier patterns win ties.
    return max(hits, key=lambda hit: hit[1], default=None)
|
||||
|
||||
|
||||
def detect_soft_redirect(text: str) -> Optional[tuple[str, float]]:
    """Classify a user message as a soft redirect (mild preference signal).

    The text is cleaned with ``clean_user_text`` first. Messages flagged by
    ``is_false_positive`` or shorter than ``MIN_CORRECTION_LEN`` are rejected.

    Returns:
        ``(label, confidence)`` of the highest-confidence entry in
        ``SOFT_REDIRECT_RE`` that matches (the earliest entry wins a tie), or
        ``None`` when nothing matches or the message was rejected.
    """
    cleaned = clean_user_text(text)
    if is_false_positive(cleaned):
        return None
    if len(cleaned) < MIN_CORRECTION_LEN:
        return None

    top_label = None
    top_conf = 0.0
    for regex, tag, weight in SOFT_REDIRECT_RE:
        # Skip entries that cannot beat the current leader or do not match.
        if weight <= top_conf or not regex.search(cleaned):
            continue
        top_label, top_conf = tag, weight

    if top_label is None:
        return None
    return (top_label, top_conf)
|
||||
|
||||
|
||||
def is_false_positive(text: str) -> bool:
|
||||
|
|
@ -403,14 +514,34 @@ def get_stream_info() -> dict:
|
|||
|
||||
# --- Main Extraction Pipeline ---
|
||||
|
||||
def extract_dpo_pairs(events: list[dict], verbose: bool = False) -> list[dict]:
|
||||
"""Extract DPO pairs from a list of events.
|
||||
def _build_conversation(events: list[dict]) -> list[tuple[str, str, dict]]:
    """Turn a list of events into an ordered list of ``(role, text, event)``.

    Events are visited in the order given. An event yielding user text
    becomes a ``'user'`` turn; otherwise, if it yields assistant text, it
    becomes an ``'assistant'`` turn. Consecutive assistant events collapse
    into a single turn that keeps the longest text seen in that streak.
    Events yielding neither kind of text are dropped.
    """
    turns: list[tuple[str, str, dict]] = []
    for evt in events:
        spoken = extract_user_text(evt)
        if spoken:
            turns.append(('user', spoken, evt))
            continue

        reply = extract_assistant_text(evt)
        if not reply:
            continue

        in_assistant_streak = bool(turns) and turns[-1][0] == 'assistant'
        if not in_assistant_streak:
            turns.append(('assistant', reply, evt))
        elif len(reply) > len(turns[-1][1]):
            # Keep the longest assistant response in a streak
            turns[-1] = ('assistant', reply, evt)
    return turns
|
||||
|
||||
Scanning strategy:
|
||||
1. Group events by session
|
||||
2. Within each session, find the pattern:
|
||||
user_msg → assistant_response → correction → (optional) better_response
|
||||
3. Build DPO pair: prompt=user_msg, rejected=assistant_response, chosen=better_response
|
||||
|
||||
def extract_all_signals(events: list[dict], verbose: bool = False) -> dict:
|
||||
"""Extract ALL learning signals from events.
|
||||
|
||||
Returns dict with:
|
||||
- dpo_pairs: hard correction DPO pairs (prompt + chosen + rejected)
|
||||
- soft_redirect_pairs: soft redirect DPO pairs
|
||||
- sft_pairs: positively reinforced exchanges (prompt + good response)
|
||||
- teaching_pairs: knowledge transfer moments (context + lesson)
|
||||
- stats: extraction statistics
|
||||
"""
|
||||
# Group by session
|
||||
sessions: dict[str, list[dict]] = {}
|
||||
|
|
@ -418,70 +549,180 @@ def extract_dpo_pairs(events: list[dict], verbose: bool = False) -> list[dict]:
|
|||
sid = get_session_id(event)
|
||||
sessions.setdefault(sid, []).append(event)
|
||||
|
||||
pairs = []
|
||||
stats = {'sessions': 0, 'corrections_found': 0, 'pairs_built': 0, 'pairs_with_chosen': 0}
|
||||
dpo_pairs = []
|
||||
soft_redirect_pairs = []
|
||||
sft_pairs = []
|
||||
teaching_pairs = []
|
||||
stats = {
|
||||
'sessions': 0,
|
||||
'corrections_found': 0, 'dpo_pairs_built': 0, 'dpo_with_chosen': 0,
|
||||
'soft_redirects_found': 0, 'soft_redirect_pairs_built': 0,
|
||||
'positives_found': 0, 'sft_pairs_built': 0,
|
||||
'teachings_found': 0, 'teaching_pairs_built': 0,
|
||||
}
|
||||
|
||||
for sid, session_events in sessions.items():
|
||||
session_events.sort(key=lambda e: e.get('seq', 0))
|
||||
stats['sessions'] += 1
|
||||
|
||||
# Build conversation sequence: list of (role, text, event)
|
||||
conversation = []
|
||||
for event in session_events:
|
||||
user_text = extract_user_text(event)
|
||||
if user_text:
|
||||
conversation.append(('user', user_text, event))
|
||||
continue
|
||||
asst_text = extract_assistant_text(event)
|
||||
if asst_text:
|
||||
# Keep the longest assistant response in a streak
|
||||
if conversation and conversation[-1][0] == 'assistant':
|
||||
if len(asst_text) > len(conversation[-1][1]):
|
||||
conversation[-1] = ('assistant', asst_text, event)
|
||||
else:
|
||||
conversation.append(('assistant', asst_text, event))
|
||||
conversation = _build_conversation(session_events)
|
||||
|
||||
# Scan for correction pattern: user → assistant → user(correction) → assistant(better)
|
||||
for i in range(len(conversation) - 2):
|
||||
if (conversation[i][0] == 'user' and
|
||||
for i in range(len(conversation)):
|
||||
# === Pattern 1: user → assistant → user(correction) → assistant(better) ===
|
||||
if (i + 2 < len(conversation) and
|
||||
conversation[i][0] == 'user' and
|
||||
conversation[i + 1][0] == 'assistant' and
|
||||
conversation[i + 2][0] == 'user'):
|
||||
|
||||
correction_result = detect_correction(conversation[i + 2][1])
|
||||
if not correction_result:
|
||||
continue
|
||||
user_text = conversation[i + 2][1]
|
||||
|
||||
label, confidence = correction_result
|
||||
stats['corrections_found'] += 1
|
||||
# Hard correction
|
||||
correction_result = detect_correction(user_text)
|
||||
if correction_result:
|
||||
label, confidence = correction_result
|
||||
stats['corrections_found'] += 1
|
||||
|
||||
pair = build_dpo_pair(
|
||||
prompt_event=conversation[i][2],
|
||||
rejected_event=conversation[i + 1][2],
|
||||
correction_event=conversation[i + 2][2],
|
||||
correction_label=label,
|
||||
correction_confidence=confidence,
|
||||
)
|
||||
pair = build_dpo_pair(
|
||||
prompt_event=conversation[i][2],
|
||||
rejected_event=conversation[i + 1][2],
|
||||
correction_event=conversation[i + 2][2],
|
||||
correction_label=label,
|
||||
correction_confidence=confidence,
|
||||
)
|
||||
|
||||
if not pair:
|
||||
continue
|
||||
if pair:
|
||||
if (i + 3 < len(conversation) and
|
||||
conversation[i + 3][0] == 'assistant'):
|
||||
chosen_text = conversation[i + 3][1]
|
||||
if len(chosen_text) >= MIN_RESPONSE_LEN:
|
||||
pair['chosen'] = chosen_text
|
||||
stats['dpo_with_chosen'] += 1
|
||||
|
||||
# Look for the better response after the correction
|
||||
if (i + 3 < len(conversation) and
|
||||
conversation[i + 3][0] == 'assistant'):
|
||||
chosen_text = conversation[i + 3][1]
|
||||
if len(chosen_text) >= MIN_RESPONSE_LEN:
|
||||
pair['chosen'] = chosen_text
|
||||
stats['pairs_with_chosen'] += 1
|
||||
stats['dpo_pairs_built'] += 1
|
||||
dpo_pairs.append(pair)
|
||||
|
||||
stats['pairs_built'] += 1
|
||||
pairs.append(pair)
|
||||
if verbose:
|
||||
print(f"\n 🔴 CORRECTION [{label}] conf={confidence:.0%}",
|
||||
file=sys.stderr)
|
||||
print(f" prompt: {pair['prompt'][:80]}...", file=sys.stderr)
|
||||
continue
|
||||
|
||||
if verbose:
|
||||
print(f"\n 📌 [{label}] conf={confidence:.0%}", file=sys.stderr)
|
||||
print(f" prompt: {pair['prompt'][:80]}...", file=sys.stderr)
|
||||
print(f" correction: {pair['correction'][:80]}...", file=sys.stderr)
|
||||
# Soft redirect
|
||||
redirect_result = detect_soft_redirect(user_text)
|
||||
if redirect_result:
|
||||
label, confidence = redirect_result
|
||||
stats['soft_redirects_found'] += 1
|
||||
|
||||
return pairs, stats
|
||||
pair = build_dpo_pair(
|
||||
prompt_event=conversation[i][2],
|
||||
rejected_event=conversation[i + 1][2],
|
||||
correction_event=conversation[i + 2][2],
|
||||
correction_label=f'soft_{label}',
|
||||
correction_confidence=confidence,
|
||||
)
|
||||
|
||||
if pair:
|
||||
if (i + 3 < len(conversation) and
|
||||
conversation[i + 3][0] == 'assistant'):
|
||||
chosen_text = conversation[i + 3][1]
|
||||
if len(chosen_text) >= MIN_RESPONSE_LEN:
|
||||
pair['chosen'] = chosen_text
|
||||
|
||||
stats['soft_redirect_pairs_built'] += 1
|
||||
soft_redirect_pairs.append(pair)
|
||||
|
||||
if verbose:
|
||||
print(f"\n 🟡 REDIRECT [{label}] conf={confidence:.0%}",
|
||||
file=sys.stderr)
|
||||
print(f" redirect: {clean_user_text(user_text)[:80]}...",
|
||||
file=sys.stderr)
|
||||
continue
|
||||
|
||||
# Positive reinforcement: user praises → previous exchange is good
|
||||
positive_result = detect_positive(user_text)
|
||||
if positive_result:
|
||||
label, confidence = positive_result
|
||||
stats['positives_found'] += 1
|
||||
|
||||
prompt_text = clean_user_text(conversation[i][1])
|
||||
response_text = conversation[i + 1][1]
|
||||
|
||||
if len(prompt_text) >= MIN_PROMPT_LEN and len(response_text) >= MIN_RESPONSE_LEN:
|
||||
sft_pairs.append({
|
||||
'prompt': prompt_text,
|
||||
'response': response_text,
|
||||
'signal_type': 'positive_reinforcement',
|
||||
'signal_label': label,
|
||||
'confidence': confidence,
|
||||
'metadata': {
|
||||
'session': sid,
|
||||
'prompt_seq': conversation[i][2].get('seq'),
|
||||
'response_seq': conversation[i + 1][2].get('seq'),
|
||||
'timestamp': conversation[i + 2][2].get('timestamp'),
|
||||
'praise_text': clean_user_text(user_text)[:200],
|
||||
}
|
||||
})
|
||||
stats['sft_pairs_built'] += 1
|
||||
|
||||
if verbose:
|
||||
print(f"\n 🟢 POSITIVE [{label}] conf={confidence:.0%}",
|
||||
file=sys.stderr)
|
||||
print(f" prompt: {prompt_text[:80]}...", file=sys.stderr)
|
||||
continue
|
||||
|
||||
# Teaching moment: user teaches something
|
||||
teaching_result = detect_teaching(user_text)
|
||||
if teaching_result:
|
||||
label, confidence = teaching_result
|
||||
stats['teachings_found'] += 1
|
||||
|
||||
cleaned_teaching = clean_user_text(user_text)
|
||||
# Context is what came before (the assistant response that triggered teaching)
|
||||
context = conversation[i + 1][1] if conversation[i + 1][0] == 'assistant' else ''
|
||||
|
||||
if len(cleaned_teaching) >= MIN_TEACHING_LEN:
|
||||
teaching_pairs.append({
|
||||
'lesson': cleaned_teaching,
|
||||
'context': context[:500] if context else '',
|
||||
'signal_type': 'teaching',
|
||||
'signal_label': label,
|
||||
'confidence': confidence,
|
||||
'metadata': {
|
||||
'session': sid,
|
||||
'seq': conversation[i + 2][2].get('seq'),
|
||||
'timestamp': conversation[i + 2][2].get('timestamp'),
|
||||
}
|
||||
})
|
||||
stats['teaching_pairs_built'] += 1
|
||||
|
||||
if verbose:
|
||||
print(f"\n 📚 TEACHING [{label}] conf={confidence:.0%}",
|
||||
file=sys.stderr)
|
||||
print(f" lesson: {cleaned_teaching[:80]}...", file=sys.stderr)
|
||||
|
||||
return {
|
||||
'dpo_pairs': dpo_pairs,
|
||||
'soft_redirect_pairs': soft_redirect_pairs,
|
||||
'sft_pairs': sft_pairs,
|
||||
'teaching_pairs': teaching_pairs,
|
||||
'stats': stats,
|
||||
}
|
||||
|
||||
|
||||
# Keep backward-compatible function
|
||||
def extract_dpo_pairs(events: list[dict], verbose: bool = False) -> tuple[list[dict], dict]:
    """Backward-compatible wrapper around ``extract_all_signals``.

    Returns:
        A ``(pairs, stats)`` tuple. ``pairs`` is the hard-correction DPO
        pairs followed by the soft-redirect pairs. ``stats`` uses the legacy
        keys ``sessions``, ``corrections_found``, ``pairs_built`` and
        ``pairs_with_chosen``.
    """
    signals = extract_all_signals(events, verbose=verbose)
    merged_pairs = signals['dpo_pairs'] + signals['soft_redirect_pairs']
    new_stats = signals['stats']

    # Translate the new counters to the legacy stats layout.
    legacy_stats = {
        'sessions': new_stats['sessions'],
        'corrections_found': (
            new_stats['corrections_found'] + new_stats['soft_redirects_found']
        ),
        'pairs_built': (
            new_stats['dpo_pairs_built'] + new_stats['soft_redirect_pairs_built']
        ),
        # NOTE(review): this counts hard corrections only, while pairs_built
        # and the returned list also include soft redirects -- confirm
        # whether soft-redirect pairs with a chosen response should be
        # counted here as well.
        'pairs_with_chosen': new_stats['dpo_with_chosen'],
    }
    return merged_pairs, legacy_stats
|
||||
|
||||
|
||||
def to_dpo_training_format(pairs: list[dict]) -> list[dict]:
|
||||
|
|
@ -571,37 +812,57 @@ def main():
|
|||
conv_events = [e for e in events if e.get('type', '').startswith('conversation.message')]
|
||||
print(f" {len(conv_events)} conversation events out of {len(events)} total", file=sys.stderr)
|
||||
|
||||
# Extract pairs
|
||||
pairs, stats = extract_dpo_pairs(conv_events, verbose=args.verbose)
|
||||
# Extract ALL signals
|
||||
result = extract_all_signals(conv_events, verbose=args.verbose)
|
||||
stats = result['stats']
|
||||
|
||||
# Filter by confidence
|
||||
pairs = [p for p in pairs if p['metadata']['confidence'] >= args.min_confidence]
|
||||
dpo_pairs = [p for p in result['dpo_pairs']
|
||||
if p['metadata']['confidence'] >= args.min_confidence]
|
||||
soft_pairs = [p for p in result['soft_redirect_pairs']
|
||||
if p['metadata']['confidence'] >= args.min_confidence]
|
||||
sft_pairs = [p for p in result['sft_pairs']
|
||||
if p['confidence'] >= args.min_confidence]
|
||||
teaching_pairs = [p for p in result['teaching_pairs']
|
||||
if p['confidence'] >= args.min_confidence]
|
||||
|
||||
# Stats
|
||||
print(f"\n📊 Results:", file=sys.stderr)
|
||||
print(f" Sessions scanned: {stats['sessions']}", file=sys.stderr)
|
||||
print(f" Corrections detected: {stats['corrections_found']}", file=sys.stderr)
|
||||
print(f" DPO pairs built: {stats['pairs_built']}", file=sys.stderr)
|
||||
print(f" Pairs with chosen response: {stats['pairs_with_chosen']}", file=sys.stderr)
|
||||
print(f" After confidence filter (≥{args.min_confidence}): {len(pairs)}", file=sys.stderr)
|
||||
print(f"", file=sys.stderr)
|
||||
print(f" 🔴 Hard corrections: {stats['corrections_found']:3d} detected → "
|
||||
f"{len(dpo_pairs)} pairs ({stats['dpo_with_chosen']} with chosen)", file=sys.stderr)
|
||||
print(f" 🟡 Soft redirects: {stats['soft_redirects_found']:3d} detected → "
|
||||
f"{len(soft_pairs)} pairs", file=sys.stderr)
|
||||
print(f" 🟢 Positive signals: {stats['positives_found']:3d} detected → "
|
||||
f"{len(sft_pairs)} SFT pairs", file=sys.stderr)
|
||||
print(f" 📚 Teaching moments: {stats['teachings_found']:3d} detected → "
|
||||
f"{len(teaching_pairs)} pairs", file=sys.stderr)
|
||||
|
||||
complete = [p for p in pairs if p.get('chosen')]
|
||||
incomplete = [p for p in pairs if not p.get('chosen')]
|
||||
print(f" Complete (prompt+chosen+rejected): {len(complete)}", file=sys.stderr)
|
||||
print(f" Incomplete (no chosen response): {len(incomplete)}", file=sys.stderr)
|
||||
total = len(dpo_pairs) + len(soft_pairs) + len(sft_pairs) + len(teaching_pairs)
|
||||
print(f"\n 📦 Total training signal: {total} pairs", file=sys.stderr)
|
||||
|
||||
if args.dry_run:
|
||||
# Show sample pairs
|
||||
if pairs:
|
||||
print(f"\n📝 Sample pairs:", file=sys.stderr)
|
||||
for p in pairs[:5]:
|
||||
print(f"\n [{p['metadata']['correction_type']}] "
|
||||
f"conf={p['metadata']['confidence']:.0%}", file=sys.stderr)
|
||||
print(f" prompt: {p['prompt'][:100]}", file=sys.stderr)
|
||||
print(f" rejected: {p['rejected'][:100]}", file=sys.stderr)
|
||||
print(f" correction: {p['correction'][:100]}", file=sys.stderr)
|
||||
if p.get('chosen'):
|
||||
print(f" chosen: {p['chosen'][:100]}", file=sys.stderr)
|
||||
for label, pairs_list, emoji in [
|
||||
('DPO (hard)', dpo_pairs, '🔴'),
|
||||
('DPO (soft)', soft_pairs, '🟡'),
|
||||
('SFT (positive)', sft_pairs, '🟢'),
|
||||
('Teaching', teaching_pairs, '📚'),
|
||||
]:
|
||||
if pairs_list:
|
||||
print(f"\n{emoji} {label} — sample:", file=sys.stderr)
|
||||
for p in pairs_list[:3]:
|
||||
if 'prompt' in p:
|
||||
print(f" prompt: {p['prompt'][:100]}", file=sys.stderr)
|
||||
if 'correction' in p:
|
||||
print(f" correction: {p['correction'][:100]}", file=sys.stderr)
|
||||
if 'response' in p:
|
||||
print(f" response: {p['response'][:100]}", file=sys.stderr)
|
||||
if 'lesson' in p:
|
||||
print(f" lesson: {p['lesson'][:100]}", file=sys.stderr)
|
||||
if p.get('metadata', {}).get('praise_text'):
|
||||
print(f" praise: {p['metadata']['praise_text'][:80]}", file=sys.stderr)
|
||||
print(f" ---", file=sys.stderr)
|
||||
return
|
||||
|
||||
# Output
|
||||
|
|
@ -609,17 +870,40 @@ def main():
|
|||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
timestamp = datetime.now().strftime('%Y-%m-%d')
|
||||
|
||||
all_dpo = dpo_pairs + soft_pairs
|
||||
|
||||
if args.format in ('training', 'both'):
|
||||
training_data = to_dpo_training_format(pairs)
|
||||
# DPO training format
|
||||
training_data = to_dpo_training_format(all_dpo)
|
||||
path = Path(args.output) if args.output else output_dir / f'dpo-training-{timestamp}.json'
|
||||
path.write_text(json.dumps(training_data, indent=2, ensure_ascii=False))
|
||||
print(f"\n✅ Training format: {path} ({len(training_data)} pairs)", file=sys.stderr)
|
||||
print(f"\n✅ DPO training: {path} ({len(training_data)} pairs)", file=sys.stderr)
|
||||
|
||||
# SFT training format (positive reinforcement)
|
||||
sft_data = [{'instruction': p['prompt'], 'input': '', 'output': p['response']}
|
||||
for p in sft_pairs]
|
||||
sft_path = output_dir / f'sft-positive-{timestamp}.json'
|
||||
sft_path.write_text(json.dumps(sft_data, indent=2, ensure_ascii=False))
|
||||
print(f"✅ SFT positive: {sft_path} ({len(sft_data)} pairs)", file=sys.stderr)
|
||||
|
||||
# Teaching pairs
|
||||
teach_data = [{'lesson': p['lesson'], 'context': p.get('context', ''),
|
||||
'label': p['signal_label']} for p in teaching_pairs]
|
||||
teach_path = output_dir / f'teaching-{timestamp}.json'
|
||||
teach_path.write_text(json.dumps(teach_data, indent=2, ensure_ascii=False))
|
||||
print(f"✅ Teaching: {teach_path} ({len(teach_data)} pairs)", file=sys.stderr)
|
||||
|
||||
if args.format in ('detailed', 'both'):
|
||||
detailed_data = to_detailed_format(pairs)
|
||||
path = output_dir / f'dpo-detailed-{timestamp}.json'
|
||||
detailed_data = {
|
||||
'dpo_pairs': to_detailed_format(all_dpo),
|
||||
'sft_pairs': sft_pairs,
|
||||
'teaching_pairs': teaching_pairs,
|
||||
'stats': stats,
|
||||
'extracted_at': datetime.now().isoformat(),
|
||||
}
|
||||
path = output_dir / f'signals-detailed-{timestamp}.json'
|
||||
path.write_text(json.dumps(detailed_data, indent=2, ensure_ascii=False))
|
||||
print(f"✅ Detailed format: {path} ({len(detailed_data)} pairs)", file=sys.stderr)
|
||||
print(f"✅ Detailed: {path}", file=sys.stderr)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
Loading…
Reference in a new issue