Skip to main content

xlog_prob/kc/
ddnnf.rs

1//! Decision-DNNF parser and CPU reference evaluator.
2
3use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
4
5use xlog_core::{Result, XlogError};
6
/// Kind of a Decision-DNNF node.
///
/// `Or` and `And` are internal nodes whose children are reached through edges;
/// `True` and `False` are constant leaves, which the parser rejects as edge sources.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DdnnfNodeKind {
    /// Disjunction over the node's outgoing edges (record tag `o`).
    Or,
    /// Conjunction over the node's outgoing edges (record tag `a`).
    And,
    /// Constant-true leaf (record tag `t`).
    True,
    /// Constant-false leaf (record tag `f`).
    False,
}
14
15#[derive(Debug, Clone)]
16pub struct DdnnfNode {
17    pub kind: DdnnfNodeKind,
18}
19
/// A directed edge `from -> to` labelled with a (possibly empty) list of literals.
///
/// A positive literal `v` denotes variable `v`, a negative literal `-v` its negation.
/// The parser never produces a `0` literal.
#[derive(Clone, Debug)]
pub struct DdnnfEdge {
    /// Id of the source node (an OR or AND node once parsed).
    pub from: u32,
    /// Id of the target node.
    pub to: u32,
    /// Literals attached to this edge, in file order.
    pub lits: Vec<i32>,
}
26
27#[derive(Debug, Clone)]
28pub struct DecisionDnnf {
29    root: u32,
30    nodes: BTreeMap<u32, DdnnfNode>,
31    edges: Vec<DdnnfEdge>,
32    outgoing: BTreeMap<u32, Vec<usize>>,
33    max_var: u32,
34}
35
36impl DecisionDnnf {
37    pub fn root(&self) -> u32 {
38        self.root
39    }
40
41    pub fn max_var(&self) -> u32 {
42        self.max_var
43    }
44
45    pub fn node_kind(&self, node_id: u32) -> Option<DdnnfNodeKind> {
46        self.nodes.get(&node_id).map(|n| n.kind)
47    }
48
49    pub fn outgoing_edge_indices(&self, node_id: u32) -> Option<&[usize]> {
50        self.outgoing.get(&node_id).map(|v| v.as_slice())
51    }
52
53    pub fn edge(&self, edge_idx: usize) -> Option<&DdnnfEdge> {
54        self.edges.get(edge_idx)
55    }
56
57    pub fn parse_str(input: &str) -> Result<Self> {
58        let mut nodes: BTreeMap<u32, DdnnfNode> = BTreeMap::new();
59        let mut edges: Vec<DdnnfEdge> = Vec::new();
60        let mut targets: HashSet<u32> = HashSet::new();
61        let mut max_var: u32 = 0;
62
63        for (line_no, raw_line) in input.lines().enumerate() {
64            let line = raw_line.trim();
65            if line.is_empty() {
66                continue;
67            }
68
69            let mut tokens: Vec<&str> = line.split_whitespace().collect();
70            if tokens.is_empty() {
71                continue;
72            }
73
74            if tokens.last() != Some(&"0") {
75                return Err(XlogError::Compilation(format!(
76                    "Decision-DNNF parse error at line {}: missing 0 terminator",
77                    line_no + 1
78                )));
79            }
80            tokens.pop();
81            if tokens.is_empty() {
82                return Err(XlogError::Compilation(format!(
83                    "Decision-DNNF parse error at line {}: empty record before terminator",
84                    line_no + 1
85                )));
86            }
87
88            match tokens[0] {
89                "o" | "a" | "t" | "f" => {
90                    if tokens.len() < 2 {
91                        return Err(XlogError::Compilation(format!(
92                            "Decision-DNNF parse error at line {}: node record missing id",
93                            line_no + 1
94                        )));
95                    }
96                    let id: u32 = tokens[1].parse().map_err(|_| {
97                        XlogError::Compilation(format!(
98                            "Decision-DNNF parse error at line {}: invalid node id '{}'",
99                            line_no + 1,
100                            tokens[1]
101                        ))
102                    })?;
103
104                    let kind = match tokens[0] {
105                        "o" => DdnnfNodeKind::Or,
106                        "a" => DdnnfNodeKind::And,
107                        "t" => DdnnfNodeKind::True,
108                        "f" => DdnnfNodeKind::False,
109                        _ => unreachable!(),
110                    };
111
112                    if nodes.insert(id, DdnnfNode { kind }).is_some() {
113                        return Err(XlogError::Compilation(format!(
114                            "Decision-DNNF parse error at line {}: duplicate node id {}",
115                            line_no + 1,
116                            id
117                        )));
118                    }
119                }
120                _ => {
121                    if tokens.len() < 2 {
122                        return Err(XlogError::Compilation(format!(
123                            "Decision-DNNF parse error at line {}: edge record missing dst",
124                            line_no + 1
125                        )));
126                    }
127                    let from: u32 = tokens[0].parse().map_err(|_| {
128                        XlogError::Compilation(format!(
129                            "Decision-DNNF parse error at line {}: invalid edge src '{}'",
130                            line_no + 1,
131                            tokens[0]
132                        ))
133                    })?;
134                    let to: u32 = tokens[1].parse().map_err(|_| {
135                        XlogError::Compilation(format!(
136                            "Decision-DNNF parse error at line {}: invalid edge dst '{}'",
137                            line_no + 1,
138                            tokens[1]
139                        ))
140                    })?;
141
142                    let mut lits: Vec<i32> = Vec::new();
143                    for &tok in &tokens[2..] {
144                        let lit: i32 = tok.parse().map_err(|_| {
145                            XlogError::Compilation(format!(
146                                "Decision-DNNF parse error at line {}: invalid literal '{}'",
147                                line_no + 1,
148                                tok
149                            ))
150                        })?;
151                        if lit == 0 {
152                            return Err(XlogError::Compilation(format!(
153                                "Decision-DNNF parse error at line {}: literal cannot be 0",
154                                line_no + 1
155                            )));
156                        }
157                        max_var = max_var.max(lit.unsigned_abs());
158                        lits.push(lit);
159                    }
160
161                    let edge_id = edges.len();
162                    edges.push(DdnnfEdge { from, to, lits });
163                    targets.insert(to);
164
165                    // outgoing filled later after validation.
166                    let _ = edge_id;
167                }
168            }
169        }
170
171        if nodes.is_empty() {
172            return Err(XlogError::Compilation(
173                "Decision-DNNF parse error: no nodes found".to_string(),
174            ));
175        }
176
177        for edge in &edges {
178            let from_kind = nodes.get(&edge.from).ok_or_else(|| {
179                XlogError::Compilation(format!(
180                    "Decision-DNNF parse error: edge references unknown src node {}",
181                    edge.from
182                ))
183            })?;
184            let _to_kind = nodes.get(&edge.to).ok_or_else(|| {
185                XlogError::Compilation(format!(
186                    "Decision-DNNF parse error: edge references unknown dst node {}",
187                    edge.to
188                ))
189            })?;
190
191            match from_kind.kind {
192                DdnnfNodeKind::Or | DdnnfNodeKind::And => {}
193                DdnnfNodeKind::True | DdnnfNodeKind::False => {
194                    return Err(XlogError::Compilation(format!(
195                        "Decision-DNNF parse error: leaf node {} cannot have outgoing edges",
196                        edge.from
197                    )));
198                }
199            }
200        }
201
202        let declared: BTreeSet<u32> = nodes.keys().copied().collect();
203        let target_set: BTreeSet<u32> = targets.into_iter().collect();
204        let roots: Vec<u32> = declared.difference(&target_set).copied().collect();
205        let root = match roots.as_slice() {
206            [only] => *only,
207            [] => {
208                return Err(XlogError::Compilation(
209                    "Decision-DNNF parse error: could not infer root (no root candidates)"
210                        .to_string(),
211                ))
212            }
213            many => {
214                return Err(XlogError::Compilation(format!(
215                    "Decision-DNNF parse error: could not infer unique root (candidates: {:?})",
216                    many
217                )))
218            }
219        };
220
221        let mut outgoing: BTreeMap<u32, Vec<usize>> = BTreeMap::new();
222        for (idx, edge) in edges.iter().enumerate() {
223            outgoing.entry(edge.from).or_default().push(idx);
224        }
225
226        // Optional cycle check (defensive).
227        Self::check_acyclic(root, &nodes, &edges, &outgoing)?;
228
229        Ok(Self {
230            root,
231            nodes,
232            edges,
233            outgoing,
234            max_var,
235        })
236    }
237
238    fn check_acyclic(
239        root: u32,
240        nodes: &BTreeMap<u32, DdnnfNode>,
241        edges: &[DdnnfEdge],
242        outgoing: &BTreeMap<u32, Vec<usize>>,
243    ) -> Result<()> {
244        let mut visiting: HashSet<u32> = HashSet::new();
245        let mut visited: HashSet<u32> = HashSet::new();
246
247        fn dfs(
248            node_id: u32,
249            nodes: &BTreeMap<u32, DdnnfNode>,
250            edges: &[DdnnfEdge],
251            outgoing: &BTreeMap<u32, Vec<usize>>,
252            visiting: &mut HashSet<u32>,
253            visited: &mut HashSet<u32>,
254        ) -> Result<()> {
255            if visited.contains(&node_id) {
256                return Ok(());
257            }
258            if !visiting.insert(node_id) {
259                return Err(XlogError::Compilation(format!(
260                    "Decision-DNNF parse error: cycle detected at node {}",
261                    node_id
262                )));
263            }
264
265            let node = nodes.get(&node_id).ok_or_else(|| {
266                XlogError::Compilation(format!(
267                    "Decision-DNNF parse error: unknown node {} during cycle check",
268                    node_id
269                ))
270            })?;
271
272            match node.kind {
273                DdnnfNodeKind::True | DdnnfNodeKind::False => {}
274                DdnnfNodeKind::Or | DdnnfNodeKind::And => {
275                    if let Some(out) = outgoing.get(&node_id) {
276                        for &edge_idx in out {
277                            let edge = &edges[edge_idx];
278                            dfs(edge.to, nodes, edges, outgoing, visiting, visited)?;
279                        }
280                    }
281                }
282            }
283
284            visiting.remove(&node_id);
285            visited.insert(node_id);
286            Ok(())
287        }
288
289        dfs(root, nodes, edges, outgoing, &mut visiting, &mut visited)
290    }
291
292    pub fn eval_log_wmc<F>(&self, var_log_weights: F) -> Result<f64>
293    where
294        F: Fn(u32) -> (f64, f64),
295    {
296        let mut memo: HashMap<u32, f64> = HashMap::new();
297
298        fn logsumexp(values: &[f64]) -> f64 {
299            let mut max = f64::NEG_INFINITY;
300            for &v in values {
301                if v > max {
302                    max = v;
303                }
304            }
305            if max.is_infinite() {
306                return max;
307            }
308            let mut sum = 0.0;
309            for &v in values {
310                sum += (v - max).exp();
311            }
312            max + sum.ln()
313        }
314
315        fn eval_node<F>(
316            node_id: u32,
317            ddnnf: &DecisionDnnf,
318            memo: &mut HashMap<u32, f64>,
319            var_log_weights: &F,
320        ) -> Result<f64>
321        where
322            F: Fn(u32) -> (f64, f64),
323        {
324            if let Some(&v) = memo.get(&node_id) {
325                return Ok(v);
326            }
327
328            let node = ddnnf.nodes.get(&node_id).ok_or_else(|| {
329                XlogError::Compilation(format!(
330                    "Decision-DNNF eval error: unknown node {}",
331                    node_id
332                ))
333            })?;
334
335            let value = match node.kind {
336                DdnnfNodeKind::True => 0.0,
337                DdnnfNodeKind::False => f64::NEG_INFINITY,
338                DdnnfNodeKind::And => {
339                    let out = ddnnf.outgoing.get(&node_id).ok_or_else(|| {
340                        XlogError::Compilation(format!(
341                            "Decision-DNNF eval error: AND node {} has no children",
342                            node_id
343                        ))
344                    })?;
345
346                    let mut acc = 0.0;
347                    for &edge_idx in out {
348                        let edge = &ddnnf.edges[edge_idx];
349                        let child = eval_node(edge.to, ddnnf, memo, var_log_weights)?;
350                        let mut lit_sum = 0.0;
351                        for &lit in &edge.lits {
352                            let var = lit.unsigned_abs();
353                            let (t, f) = var_log_weights(var);
354                            lit_sum += if lit > 0 { t } else { f };
355                        }
356                        acc += lit_sum + child;
357                    }
358                    acc
359                }
360                DdnnfNodeKind::Or => {
361                    let out = ddnnf.outgoing.get(&node_id).ok_or_else(|| {
362                        XlogError::Compilation(format!(
363                            "Decision-DNNF eval error: OR node {} has no children",
364                            node_id
365                        ))
366                    })?;
367
368                    let mut branch_vals: Vec<f64> = Vec::with_capacity(out.len());
369                    for &edge_idx in out {
370                        let edge = &ddnnf.edges[edge_idx];
371                        let child = eval_node(edge.to, ddnnf, memo, var_log_weights)?;
372                        let mut lit_sum = 0.0;
373                        for &lit in &edge.lits {
374                            let var = lit.unsigned_abs();
375                            let (t, f) = var_log_weights(var);
376                            lit_sum += if lit > 0 { t } else { f };
377                        }
378                        branch_vals.push(lit_sum + child);
379                    }
380                    logsumexp(&branch_vals)
381                }
382            };
383
384            memo.insert(node_id, value);
385            Ok(value)
386        }
387
388        eval_node(self.root, self, &mut memo, &var_log_weights)
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses the single-variable formula `x1` and checks that its log-WMC is `ln p`.
    #[test]
    fn test_parse_and_eval_identity_variable() {
        // Root OR node 1 branches on x1: literal `1` leads to TRUE (node 2) and
        // literal `-1` leads to FALSE (node 3). Represents the formula: x1
        let nnf = ["o 1 0", "t 2 0", "f 3 0", "1 2 1 0", "1 3 -1 0"].join("\n");

        let ddnnf = DecisionDnnf::parse_str(&nnf).unwrap();
        assert_eq!(ddnnf.root(), 1);
        assert_eq!(ddnnf.max_var(), 1);

        let p = 0.3_f64;
        let weights = |var: u32| {
            if var != 1 {
                panic!("unexpected var {}", var);
            }
            (p.ln(), (1.0 - p).ln())
        };
        let log_wmc = ddnnf.eval_log_wmc(weights).unwrap();

        // WMC(x1) = p, so the log-count must be ln p.
        assert!((log_wmc - p.ln()).abs() < 1e-9, "log_wmc={}", log_wmc);
    }

    /// A record without the trailing `0` token must be rejected with a clear message.
    #[test]
    fn test_parse_detects_missing_terminator() {
        let msg = DecisionDnnf::parse_str("t 1").unwrap_err().to_string();
        assert!(msg.contains("terminator"), "msg={}", msg);
    }
}