From 5458b3a822feb727d1fb9ea84b6d401fae33685e Mon Sep 17 00:00:00 2001 From: Albin Chaboissier Date: Thu, 12 Feb 2026 20:59:45 +0100 Subject: [PATCH] Domains test --- src/ast.rs | 5 + src/domain.rs | 512 +++++++++++++++++++++++++++++++ src/lib.rs | 1 + src/main.rs | 81 +++-- src/parsing.rs | 67 +++- src/prover.rs | 1 + src/prover/constraints.rs | 88 ++++-- src/prover/internal_operators.rs | 50 +++ src/prover/unification.rs | 34 +- 9 files changed, 783 insertions(+), 56 deletions(-) create mode 100644 src/domain.rs create mode 100644 src/prover/internal_operators.rs diff --git a/src/ast.rs b/src/ast.rs index 23a819e..4811ba8 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -2,6 +2,8 @@ use owo_colors::{OwoColorize, colors::css::Gray}; use std::fmt::Display; +use crate::domain::Domain; + #[derive(Clone, PartialEq, Eq, Debug, Hash)] pub struct Variable(pub String, pub Option); @@ -31,6 +33,7 @@ pub enum Predicate { Variable(Variable), // Upercase variable like X Fixed(Functor, Vec), + Domain(Domain), } #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -135,6 +138,7 @@ impl Display for Predicate Ok(()) } } + Predicate::Domain(domain) => write!(f, "{domain}"), } } } @@ -174,6 +178,7 @@ impl Display for Functor } } } + impl Display for Operator { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result diff --git a/src/domain.rs b/src/domain.rs new file mode 100644 index 0000000..2c04a09 --- /dev/null +++ b/src/domain.rs @@ -0,0 +1,512 @@ +use std::{cmp, fmt::Display}; + +use crate::domain; + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct Domain +{ + pub union: Vec, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct DomainRange +{ + pub start: DomainBound, + pub end: DomainBound, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum DomainBound +{ + Unbounded, + Bounded(i64), +} + +impl Domain +{ + pub fn all() -> Domain + { + Domain { + union: vec![DomainRange { + start: DomainBound::Unbounded, + end: DomainBound::Unbounded, + }], + } + } + + pub fn included_in(&self, domain: &Domain) -> bool + { + self.union(domain) == *domain + } + + pub fn empty(&self) -> bool + { + self.union.len() == 0 + } + + pub fn disjoint(&self, domain: &Domain) -> bool + { + self.intersection(domain).union.len() == 0 + } + + pub fn intersection(&self, domain: &Domain) -> Domain + { + let mut i = 0; + let mut j = 0; + let mut stack = vec![]; + + while i < self.union.len() || j < domain.union.len() + { + if i < self.union.len() && j < domain.union.len() + { + let a = self.union[i]; + let b = domain.union[j]; + + if a.disjoint(&b) + { + if a < b + { + i += 1; + } + else + { + j += 1; + } + } + else + { + let inter = a.intersection(&b); + if let Some(range) = inter + { + stack.push(range); + }; + if a.ends_furthest(&b) + { + i += 1; + } + else + { + j += 1; + } + } + } + else + { + break; + } + } + + Domain { union: stack } + } + + pub fn union(&self, domain: &Domain) -> Domain + { + let mut i = 0; + let mut j = 0; + let mut stack = vec![]; + + while i < self.union.len() || j < domain.union.len() + { + if i < self.union.len() && j < domain.union.len() + { + let a = self.union[i]; + let b = domain.union[j]; + + let (min, max) = (a.min(b), a.max(b)); + + if let DomainBound::Bounded(end) = min.end + && let DomainBound::Bounded(start) = max.start + && end == start + { + let union = a.union(&b); + stack.extend_from_slice(union.union.as_slice()); + i += 1; + j += 1; + } + else if a.disjoint(&b) + { + if a < b + { + stack.push(a); + i += 1; + } + else + { + stack.push(b); + j += 1; + } + } + else + { + let union = a.union(&b); + stack.extend_from_slice(union.union.as_slice()); + i += 1; + j += 1; + } + } + else if i < self.union.len() + { + stack.extend_from_slice(&self.union[i..]); + i = self.union.len(); + } + else + { + stack.extend_from_slice(&domain.union[j..]); + j = self.union.len(); + } + } + + Domain { union: stack } + } +} + +impl From for Domain +{ + fn from(value: DomainRange) -> Self + { + Domain { union: vec![value] } + } +} + +impl DomainBound +{ + pub fn min_start(&self, bound: DomainBound) -> DomainBound + { + match (self, bound) + { + (DomainBound::Unbounded, DomainBound::Unbounded) => DomainBound::Unbounded, + (DomainBound::Bounded(_), DomainBound::Unbounded) + | (DomainBound::Unbounded, DomainBound::Bounded(_)) => DomainBound::Unbounded, + (DomainBound::Bounded(a), DomainBound::Bounded(b)) => DomainBound::Bounded(*a.min(&b)), + } + } + + pub fn max_end(&self, bound: DomainBound) -> DomainBound + { + match (self, bound) + { + (DomainBound::Unbounded, DomainBound::Unbounded) => DomainBound::Unbounded, + (DomainBound::Bounded(_), DomainBound::Unbounded) + | (DomainBound::Unbounded, DomainBound::Bounded(_)) => DomainBound::Unbounded, + (DomainBound::Bounded(a), DomainBound::Bounded(b)) => DomainBound::Bounded(*a.max(&b)), + } + } +} + +impl PartialOrd for DomainRange +{ + fn partial_cmp(&self, other: &Self) -> Option + { + Some(self.cmp(other)) + } +} + +impl Ord for DomainRange +{ + fn cmp(&self, other: &Self) -> std::cmp::Ordering + { + match (self.start, other.start) + { + (DomainBound::Unbounded, DomainBound::Unbounded) => cmp::Ordering::Equal, + (DomainBound::Unbounded, DomainBound::Bounded(_)) => cmp::Ordering::Less, + (DomainBound::Bounded(_), DomainBound::Unbounded) => cmp::Ordering::Greater, + (DomainBound::Bounded(a), DomainBound::Bounded(b)) => a.cmp(&b), + } + } +} + +impl DomainRange +{ + pub fn intersection(&self, range: &DomainRange) -> Option + { + let (min, max) = (self.min(range), self.max(range)); + match (min.end, max.start) + { + (DomainBound::Unbounded, DomainBound::Unbounded) => Some(DomainRange { + start: DomainBound::Unbounded, + end: DomainBound::Unbounded, + }), + (DomainBound::Unbounded, DomainBound::Bounded(a)) => Some(DomainRange { + start: DomainBound::Bounded(a), + end: max.end, + }), + (DomainBound::Bounded(a), DomainBound::Unbounded) => Some(DomainRange { + start: max.start, + end: DomainBound::Bounded(a), + }), + (DomainBound::Bounded(a), DomainBound::Bounded(b)) if b < a => Some(DomainRange { + start: DomainBound::Bounded(b), + end: DomainBound::Bounded(a), + }), + _ => None, + } + } + + pub fn union(&self, range: &DomainRange) -> Domain + { + if self.disjoint(range) + { + let (a, b) = (self.min(range), self.max(range)); + if let DomainBound::Bounded(end) = a.end + && let DomainBound::Bounded(start) = b.start + && end == start + { + Domain { + union: vec![DomainRange { + start: a.start, + end: b.end, + }], + } + } + else + { + Domain { + union: vec![*self.min(range), *self.max(range)], + } + } + } + else + { + let (a, b) = (self.min(range), self.max(range)); + let range = DomainRange { + start: match (a.start, b.start) + { + (DomainBound::Bounded(a), DomainBound::Bounded(b)) => + { + DomainBound::Bounded(a.min(b)) + } + (DomainBound::Unbounded, DomainBound::Bounded(_)) + | (DomainBound::Bounded(_), DomainBound::Unbounded) + | (DomainBound::Unbounded, DomainBound::Unbounded) => DomainBound::Unbounded, + }, + end: match (a.end, b.end) + { + (DomainBound::Bounded(a), DomainBound::Bounded(b)) => + { + DomainBound::Bounded(a.max(b)) + } + (DomainBound::Unbounded, DomainBound::Bounded(_)) + | (DomainBound::Bounded(_), DomainBound::Unbounded) + | (DomainBound::Unbounded, DomainBound::Unbounded) => DomainBound::Unbounded, + }, + }; + + Domain { union: vec![range] } + } + } + + pub fn ends_furthest(&self, end: &DomainRange) -> bool + { + match (self.end, end.end) + { + (DomainBound::Unbounded, DomainBound::Unbounded) => true, + (DomainBound::Unbounded, DomainBound::Bounded(_)) => true, + (DomainBound::Bounded(_), DomainBound::Unbounded) => false, + (DomainBound::Bounded(a), DomainBound::Bounded(b)) => a >= b, + } + } + + pub fn disjoint(&self, range: &DomainRange) -> bool + { + let (a, b) = (self.min(range), self.max(range)); + matches!((a.end, b.start), (DomainBound::Bounded(a), DomainBound::Bounded(b)) if a <= b) + } + + pub fn singleton(value: i64) -> DomainRange + { + DomainRange { + start: DomainBound::Bounded(value), + end: DomainBound::Bounded(value + 1), + } + } + + pub fn from(value: i64) -> DomainRange + { + DomainRange { + start: DomainBound::Bounded(value), + end: DomainBound::Unbounded, + } + } + + pub fn to(value: i64) -> DomainRange + { + DomainRange { + start: DomainBound::Unbounded, + end: DomainBound::Bounded(value), + } + } + + pub fn closed(start: i64, end: i64) -> DomainRange + { + DomainRange { + start: DomainBound::Bounded(start), + end: DomainBound::Bounded(end), + } + } +} + +impl Display for Domain +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + { + let len = self.union.len(); + for (i, range) in self.union.iter().enumerate() + { + write!(f, "{}", range)?; + if i != len - 1 + { + write!(f, " ∪ ")?; + } + } + Ok(()) + } +} + +impl Display for DomainRange +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + { + match (self.start, self.end) + { + (DomainBound::Bounded(a), DomainBound::Bounded(b)) if b == a + 1 => write!(f, "{}", a), + (a, b) => write!(f, "{}..{}", a, b), + } + } +} + +impl Display for DomainBound +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result + { + match self + { + DomainBound::Unbounded => Ok(()), + DomainBound::Bounded(n) => write!(f, "{}", n), + } + } +} + +mod test +{ + #[allow(unused_imports)] + use crate::domain::{Domain, DomainBound, DomainRange}; + + #[test] + pub fn domain_range_parse() + { + let a: DomainRange = "..".into(); + assert_eq!( + a, + DomainRange { + start: DomainBound::Unbounded, + end: DomainBound::Unbounded + } + ); + + let a: DomainRange = "42..".into(); + assert_eq!( + a, + DomainRange { + start: DomainBound::Bounded(42), + end: DomainBound::Unbounded + } + ); + + let a: DomainRange = "..42".into(); + assert_eq!( + a, + DomainRange { + start: DomainBound::Unbounded, + end: DomainBound::Bounded(42), + } + ); + + let a: DomainRange = "0..42".into(); + assert_eq!( + a, + DomainRange { + start: DomainBound::Bounded(0), + end: DomainBound::Bounded(42), + } + ); + } + + #[test] + pub fn domain_range_union() + { + let inf: DomainRange = "..".into(); + let inf_b: DomainRange = "..42".into(); + let a_inf: DomainRange = "0..".into(); + let a_b: DomainRange = "10..20".into(); + let a_b2: DomainRange = "30..50".into(); + let a_b3: DomainRange = "20..30".into(); + + assert_eq!(inf.union(&inf), "..".into()); + assert_eq!(inf.union(&inf_b), "..".into()); + assert_eq!(inf.union(&a_inf), "..".into()); + assert_eq!(inf.union(&a_b), "..".into()); + + assert_eq!(inf.union(&inf), "..".into()); + assert_eq!(inf_b.union(&inf), "..".into()); + assert_eq!(a_inf.union(&inf), "..".into()); + assert_eq!(a_b.union(&inf), "..".into()); + + assert_eq!(inf.union(&inf), "..".into()); + assert_eq!(inf_b.union(&inf_b), "..42".into()); + assert_eq!(a_inf.union(&a_inf), "0..".into()); + assert_eq!(a_b.union(&a_b), "10..20".into()); + + assert_eq!( + a_b.union(&a_b2), + Domain { + union: vec!["10..20".into(), "30..50".into()] + } + ); + + assert_eq!(a_b.union(&a_b3), "10..30".into()); + } + + #[test] + pub fn domain_union() + { + let k: Domain = "2..5".into(); + let k2: Domain = "5..10".into(); + assert_eq!( + k.union(&k2), + Domain { + union: vec![DomainRange { + start: DomainBound::Bounded(2), + end: DomainBound::Bounded(10) + }] + } + ); + } + + #[test] + pub fn domain_intersection() + { + let k: Domain = "0..5".into(); + let k2: Domain = "10..15".into(); + + let k3: Domain = "3..7".into(); + let k4: Domain = "13..17".into(); + + let a = k.union(&k2); + let b = k3.union(&k4); + assert_eq!( + a.intersection(&b), + Domain { + union: vec![ + DomainRange { + start: DomainBound::Bounded(3), + end: DomainBound::Bounded(5) + }, + DomainRange { + start: DomainBound::Bounded(13), + end: DomainBound::Bounded(15) + } + ] + } + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index 8b27f29..cc907a2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ pub mod ast; +pub mod domain; pub mod parsing; pub mod prover; diff --git a/src/main.rs b/src/main.rs index 0cb8466..e52da74 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,7 @@ use picolog::ast::Body; use picolog::ast::Module; +use picolog::ast::Predicate; +use picolog::domain::Domain; fn main() { @@ -8,48 +10,67 @@ fn main() .format_timestamp(None) .init(); //println!("{:#?}", Module::parse_from_file("1.pl")); + // let module: Module = " + // integer(zero). + // integer(s(X)) :- integer(X). + // + // add(X, zero, X). + // add(X, s(Y), Z) :- add(s(X), Y, Z). + // + // mult(zero, X, zero). + // mult(s(Y), X, Z) :- mult(Y, X, W), add(W, X, Z). + // + // op(10, yfx, +). + // op(8, yfx, ^). + // op(6, xfy, ::). + // op(2, fx, [). + // op(2, xf, ]). + // op(3, yfx, |). + // + // A + B :- test. + // + // A ^ B + C :- test. + // A::B::C :- A. + // [Hd|Tl] :- Hd::Tl. + // + // l(X) :- in(X, 3..7). + // " + // .into(); + + let d1: Domain = "5..10".into(); + let d2: Domain = "7..18".into(); + let module: Module = " - integer(zero). - integer(s(X)) :- integer(X). - - add(X, zero, X). - add(X, s(Y), Z) :- add(s(X), Y, Z). - - mult(zero, X, zero). - mult(s(Y), X, Z) :- mult(Y, X, W), add(W, X, Z). - - op(10, yfx, +). - op(8, yfx, ^). - op(6, xfy, ::). - op(2, fx, [). - op(2, xf, ]). - op(3, yfx, |). - - A + B :- test. - - A ^ B + C :- test. - A::B::C :- A. - [Hd|Tl] :- Hd::Tl. + hello(5..10). + world(7..18). " .into(); - // println!("{}", module); - // let prop: Body = "mult(X, s(s(s(zero))), s(s(s(s(s(s(s(s(s(zero))))))))))".into(); - //let prop: Body = "integer(s(X))".into(); - let prop: Body = "mult(X, s(s(zero)), s(s(s(s(zero)))))".into(); - //let prop: Body = "mult(X, Y, Z)".into(); - //let prop: Body = "mult(s(s(zero)), s(s(zero)), X)".into(); - for c in module.prove(&prop) + //let pred: Predicate = "hello(0..10).".into(); + let show: Body = "hello(X), world(X)".into(); + + for c in module.prove(&show) { println!("true:"); println!("{}", c.simplified()); let _ = std::io::stdin().read_line(&mut String::new()); } - + //println!("{}", module.prove(&pred).unwrap()); + // let prop: Body = "mult(X, s(s(s(zero))), s(s(s(s(s(s(s(s(s(zero))))))))))".into(); + //let prop: Body = "integer(s(X))".into(); + //let prop: Body = "mult(X, s(s(zero)), s(s(s(s(zero)))))".into(); + //let prop: Body = "mult(X, Y, Z)".into(); + //let prop: Body = "mult(s(s(zero)), s(s(zero)), X)".into(); + // for c in module.prove(&prop) + // { + // println!("true:"); + // println!("{}", c.simplified()); + // let _ = std::io::stdin().read_line(&mut String::new()); + // } + // // let p: Predicate = "add(s(zero), zero, Y)".into(); // let p1: Predicate = "add(X, zero, X)".into(); // // let p: Predicate = "integer(s(zero))".into(); // // let p1: Predicate = "integer(s(X))".into(); // println!("{}", p.matches(&p1).unwrap()); - // } diff --git a/src/parsing.rs b/src/parsing.rs index 34fcba8..6c7cf58 100644 --- a/src/parsing.rs +++ b/src/parsing.rs @@ -1,13 +1,14 @@ use std::{collections::HashMap, path::Path}; use winnow::{ - ascii::{self, alphanumeric1, multispace0}, + Parser, Result, Stateful, + ascii::{self, alphanumeric1, dec_int, multispace0}, combinator::{ - alt, delimited, expression, opt, preceded, repeat, separated, seq, terminated, Infix, - Postfix, Prefix, + Infix, Postfix, Prefix, alt, delimited, expression, opt, preceded, repeat, separated, seq, + terminated, }, error::ContextError, - Parser, Result, Stateful, + token::literal, }; use crate::ast::Clause; @@ -15,8 +16,9 @@ use crate::ast::Functor; use crate::ast::Module; use crate::ast::Operator; use crate::ast::Predicate; -use crate::ast::{Body, OperatorType}; use crate::ast::Variable; +use crate::ast::{Body, OperatorType}; +use crate::domain::{Domain, DomainBound, DomainRange}; impl Operator { @@ -106,6 +108,28 @@ impl State type Stream<'is> = Stateful<&'is str, State>; +pub fn domain_range_parse(input: &mut Stream) -> Result +{ + let start = opt(dec_int.map(DomainBound::Bounded)) + .map(|v| v.unwrap_or(DomainBound::Unbounded)) + .parse_next(input)?; + literal("..").parse_next(input)?; + let end = opt(dec_int.map(DomainBound::Bounded)) + .map(|v| v.unwrap_or(DomainBound::Unbounded)) + .parse_next(input)?; + Ok(DomainRange { start, end }) +} + +pub fn domain_parse(input: &mut Stream) -> Result +{ + // Domain is either a..b, .. + alt((domain_range_parse, dec_int.map(DomainRange::singleton))) + .map(|domain| Domain { + union: vec![domain], + }) + .parse_next(input) +} + pub fn operator_parse(input: &mut Stream) -> Result { delimited(multispace0, repeat(1.., alt(OPERATORS)), multispace0) @@ -226,6 +250,7 @@ pub fn predicate_parse_recursive(input: &mut Stream) -> Result { alt(( delimited("(", predicate_parse, ")"), + domain_parse.map(Predicate::Domain), predicate_parse_variable_or_functor, )) .parse_next(input) @@ -342,3 +367,35 @@ where .unwrap() } } + +impl From for Domain +where + T: AsRef, +{ + fn from(value: T) -> Self + { + let str: &str = value.as_ref(); + domain_parse + .parse_next(&mut Stream { + input: str, + state: State::new(), + }) + .unwrap() + } +} + +impl From for DomainRange +where + T: AsRef, +{ + fn from(value: T) -> Self + { + let str: &str = value.as_ref(); + domain_range_parse + .parse_next(&mut Stream { + input: str, + state: State::new(), + }) + .unwrap() + } +} diff --git a/src/prover.rs b/src/prover.rs index 00177c3..28e60dc 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -135,6 +135,7 @@ impl Predicate .map(|x| x.make_unique(counter.clone(), unique_map)) .collect(), ), + Predicate::Domain(domain) => Predicate::Domain(domain.clone()), } } } diff --git a/src/prover/constraints.rs b/src/prover/constraints.rs index bf17625..5e592ff 100644 --- a/src/prover/constraints.rs +++ b/src/prover/constraints.rs @@ -4,11 +4,15 @@ use std::fmt::Display; use crate::ast::Body; use crate::ast::Predicate; use crate::ast::Variable; +use crate::domain; +use crate::domain::Domain; +use crate::prover::predicate; #[derive(Clone, Debug)] pub struct Constraints { - pub(crate) set: HashMap, + pub(crate) predicates: HashMap, + pub(crate) domains: HashMap, } impl Constraints @@ -16,20 +20,33 @@ impl Constraints pub fn none() -> Self { Constraints { - set: HashMap::new(), + predicates: HashMap::new(), + domains: HashMap::new(), } } - pub fn with(variable: Variable, predicate: Predicate) -> Self + pub fn with(variable: Variable, predicate: Option, domain: Option) -> Self { let mut c = Constraints::none(); - c.set.insert(variable, predicate); + if let Some(predicate) = predicate + { + c.predicates.insert(variable.clone(), predicate); + } + if let Some(domain) = domain + { + c.domains.insert(variable, domain); + } c } - pub fn try_append(&mut self, variable: &Variable, predicate: &Predicate) -> bool + pub fn try_append( + &mut self, + variable: &Variable, + predicate: &Predicate, + domain: &Domain, + ) -> bool { - if let Some(other_predicate) = self.set.get(variable) + let predicates = if let Some(other_predicate) = self.predicates.get(variable) { if predicate == other_predicate { @@ -46,7 +63,7 @@ impl Constraints // We can try adding the unification contraints which is implicitely the same if self.try_merge(&unification_contraints) { - self.set.insert(variable.clone(), predicate.clone()); + self.predicates.insert(variable.clone(), predicate.clone()); true } else @@ -63,18 +80,40 @@ impl Constraints else { // No constraint - self.set.insert(variable.clone(), predicate.clone()); + self.predicates + .insert(variable.clone(), (predicate.clone(), domain.clone())); true + }; + + // Check if domains are compatible + let domains = if let Some((_, other_domain)) = self.predicates.get_mut(variable) + { + let intersection = domain.intersection(other_domain); + if intersection.empty() + { + false + } + else + { + *other_domain = intersection; + true + } } + else + { + true + }; + + domains && predicates } pub fn try_merge(&mut self, constraints: &Constraints) -> bool { // Trying to merge, is just trying to add all of the constraints into self let mut ok = self.clone(); - for (var, pred) in constraints.set.iter() + for (var, (pred, domain)) in constraints.predicates.iter() { - if !ok.try_append(var, pred) + if !ok.try_append(var, pred, domain) { return false; } @@ -93,24 +132,26 @@ impl Constraints pub fn simplified(&self) -> Constraints { let mut max_sub = Constraints::none(); - for (var, pred) in self.set.iter() + for (var, (pred, domain)) in self.predicates.iter() { - max_sub.set.insert(var.clone(), pred.substitute(self)); + max_sub + .predicates + .insert(var.clone(), (pred.substitute(self), domain.clone())); } let mut stripped = max_sub.clone(); - 'outer: for (var, _) in max_sub.set.iter() + 'outer: for (var, _) in max_sub.predicates.iter() { if var.0.chars().next().is_some_and(|x| x == '_') || var.1.is_some() { - for (_, other_pred) in max_sub.set.iter() + for (_, (other_pred, _)) in max_sub.predicates.iter() { if other_pred.contains_variable(var) { continue 'outer; } } - stripped.set.remove(var); + stripped.predicates.remove(var); } } @@ -126,7 +167,7 @@ impl Predicate { Predicate::Variable(name) => { - if let Some(pred) = constraints.set.get(name) + if let Some((pred, _)) = constraints.predicates.get(name) { pred.substitute(constraints) } @@ -142,6 +183,7 @@ impl Predicate .map(|x| x.substitute(constraints)) .collect(), ), + Predicate::Domain(domain) => Predicate::Domain(domain.clone()), } } @@ -151,6 +193,7 @@ impl Predicate { Predicate::Variable(var_name) => name == var_name, Predicate::Fixed(_, predicates) => predicates.iter().any(|x| x.contains_variable(name)), + Predicate::Domain(_) => false, } } } @@ -183,8 +226,8 @@ impl Display for Constraints { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let len = self.set.len(); - for (i, (var, pred)) in self.set.iter().enumerate() + let len = self.predicates.len(); + for (i, (var, (pred, _))) in self.predicates.iter().enumerate() { write!(f, "{} = {}", var, pred)?; if i != len - 1 @@ -192,6 +235,15 @@ impl Display for Constraints write!(f, ", ")?; } } + write!(f, " ;; ")?; + for (i, (var, (_, domain))) in self.predicates.iter().enumerate() + { + write!(f, "{} in {}", var, domain)?; + if i != len - 1 + { + write!(f, ", ")?; + } + } Ok(()) } } diff --git a/src/prover/internal_operators.rs b/src/prover/internal_operators.rs new file mode 100644 index 0000000..367e31b --- /dev/null +++ b/src/prover/internal_operators.rs @@ -0,0 +1,50 @@ +use crate::{ast::Predicate, prover::constraints::Constraints}; + +impl Constraints +{ + pub fn collapsed_operators(&self) -> Constraints + { + let mut new_constraints = Constraints::none(); + for (var, pred) in self.set.iter() + { + new_constraints + .set + .insert(var.clone(), pred.substitute(self).collapsed_operators()); + } + new_constraints + } +} + +impl Predicate +{ + pub fn collapsed_operators(&self) -> Predicate + { + match self + { + Predicate::Variable(variable) => Predicate::Variable(variable.clone()), + Predicate::Fixed(crate::ast::Functor::Operator(op), predicates) + if op.op == "+" && predicates.len() == 2 => + { + match (predicates[0].clone(), predicates[1].clone()) + { + (Predicate::Number(a), Predicate::Number(b)) => Predicate::Number(a + b), + _ => self.clone(), + } + } + Predicate::Fixed(crate::ast::Functor::Operator(op), predicates) + if op.op == "-" && predicates.len() == 2 => + { + match (predicates[0].clone(), predicates[1].clone()) + { + (Predicate::Number(a), Predicate::Number(b)) => Predicate::Number(a - b), + _ => self.clone(), + } + } + Predicate::Fixed(functor, predicates) => Predicate::Fixed( + functor.clone(), + predicates.iter().map(|x| x.collapsed_operators()).collect(), + ), + Predicate::Number(n) => Predicate::Number(*n), + } + } +} diff --git a/src/prover/unification.rs b/src/prover/unification.rs index d6d5042..9917a7c 100644 --- a/src/prover/unification.rs +++ b/src/prover/unification.rs @@ -1,4 +1,7 @@ +use std::clone; + use crate::ast::Predicate; +use crate::domain::{self, Domain}; use crate::prover::constraints::Constraints; impl Predicate @@ -16,9 +19,20 @@ impl Predicate //debug!("Unifying var {} against {}", self, other); // We are trying to see if X (any) matches the other Predicate. // This is always true if X = other_predicate - Some(Constraints::with(variable.clone(), other.clone())) + match other + { + Predicate::Domain(domain) => Some(Constraints::with( + variable.clone(), + other.clone(), + domain.clone(), + )), + _ => Some(Constraints::with( + variable.clone(), + other.clone(), + Domain::all(), + )), + } } - Predicate::Fixed(name, arguments) => { match other @@ -29,7 +43,7 @@ impl Predicate // (any) // This is always true //debug!("Unifying pred {} against var {}", self, other); - Some(Constraints::with(var.clone(), self.clone())) + Some(Constraints::with(var.clone(), self.clone(), Domain::all())) } // Match pred(X, Y, Z, ...) with pred(_X, _Y, _Z, ...) Predicate::Fixed(other_name, other_arguments) @@ -58,6 +72,20 @@ impl Predicate _ => None, } } + Predicate::Domain(domain) => match other + { + Predicate::Variable(variable) => Some(Constraints::with( + variable.clone(), + other.clone(), + domain.clone(), + )), + Predicate::Fixed(_, _) => None, + Predicate::Domain(other_domain) if !domain.intersection(other_domain).empty() => + { + Some(Constraints::none()) + } + Predicate::Domain(_) => None, + }, } } }