|
use crate::ast::{BinaryOp, Literal, Node, UnaryOp, Unit}; |
|
use crate::context::EvalContext; |
|
use crate::value::{Complex, Number, Value}; |
|
use lazy_static::lazy_static; |
|
use num_complex::ComplexFloat; |
|
use pest::Parser; |
|
use pest::iterators::{Pair, Pairs}; |
|
use pest::pratt_parser::{Assoc, Op, PrattParser}; |
|
use pest_derive::Parser; |
|
use std::num::{ParseFloatError, ParseIntError}; |
|
use thiserror::Error; |
|
|
|
/// Parser type generated by `pest_derive` from the grammar file named in the
/// `#[grammar]` attribute below.
///
/// The derive also supplies the `Rule` enum that the rest of this file matches on.
// NOTE(review): the grammar file name is spelled `grammer.pest`; it has to match
// the file on disk, so it must not be "corrected" here without renaming the file.
#[derive(Parser)]
#[grammar = "./grammer.pest"]
struct ExprParser;
|
|
|
lazy_static! { |
|
static ref PRATT_PARSER: PrattParser<Rule> = { |
|
PrattParser::new() |
|
.op(Op::infix(Rule::add, Assoc::Left) | Op::infix(Rule::sub, Assoc::Left)) |
|
.op(Op::infix(Rule::mul, Assoc::Left) | Op::infix(Rule::div, Assoc::Left) | Op::infix(Rule::paren, Assoc::Left)) |
|
.op(Op::infix(Rule::pow, Assoc::Right)) |
|
.op(Op::postfix(Rule::fac) | Op::postfix(Rule::EOI)) |
|
.op(Op::prefix(Rule::sqrt)) |
|
.op(Op::prefix(Rule::neg)) |
|
}; |
|
} |
|
|
|
#[derive(Error, Debug)] |
|
pub enum TypeError { |
|
#[error("Invalid BinOp: {0:?} {1:?} {2:?}")] |
|
InvalidBinaryOp(Unit, BinaryOp, Unit), |
|
|
|
#[error("Invalid UnaryOp: {0:?}")] |
|
InvalidUnaryOp(Unit, UnaryOp), |
|
} |
|
|
|
#[derive(Error, Debug)] |
|
pub enum ParseError { |
|
#[error("ParseIntError: {0}")] |
|
ParseInt(#[from] ParseIntError), |
|
#[error("ParseFloatError: {0}")] |
|
ParseFloat(#[from] ParseFloatError), |
|
|
|
#[error("TypeError: {0}")] |
|
Type(#[from] TypeError), |
|
|
|
#[error("PestError: {0}")] |
|
Pest(#[from] Box<pest::error::Error<Rule>>), |
|
} |
|
|
|
impl Node { |
|
pub fn try_parse_from_str(s: &str) -> Result<(Node, Unit), ParseError> { |
|
let pairs = ExprParser::parse(Rule::program, s).map_err(Box::new)?; |
|
let (node, metadata) = parse_expr(pairs)?; |
|
Ok((node, metadata.unit)) |
|
} |
|
} |
|
|
|
struct NodeMetadata { |
|
pub unit: Unit, |
|
} |
|
|
|
impl NodeMetadata { |
|
pub fn new(unit: Unit) -> Self { |
|
Self { unit } |
|
} |
|
} |
|
|
|
fn parse_unit(pairs: Pairs<Rule>) -> Result<(Unit, f64), ParseError> { |
|
let mut scale = 1.0; |
|
let mut length = 0; |
|
let mut mass = 0; |
|
let mut time = 0; |
|
|
|
for pair in pairs { |
|
println!("found rule: {:?}", pair.as_rule()); |
|
match pair.as_rule() { |
|
Rule::nano => scale *= 1e-9, |
|
Rule::micro => scale *= 1e-6, |
|
Rule::milli => scale *= 1e-3, |
|
Rule::centi => scale *= 1e-2, |
|
Rule::deci => scale *= 1e-1, |
|
Rule::deca => scale *= 1e1, |
|
Rule::hecto => scale *= 1e2, |
|
Rule::kilo => scale *= 1e3, |
|
Rule::mega => scale *= 1e6, |
|
Rule::giga => scale *= 1e9, |
|
Rule::tera => scale *= 1e12, |
|
|
|
Rule::meter => length = 1, |
|
Rule::gram => mass = 1, |
|
Rule::second => time = 1, |
|
|
|
_ => unreachable!(), |
|
} |
|
} |
|
|
|
Ok((Unit { length, mass, time }, scale)) |
|
} |
|
|
|
fn parse_const(pair: Pair<Rule>) -> Literal { |
|
match pair.as_rule() { |
|
Rule::infinity => Literal::Float(f64::INFINITY), |
|
Rule::imaginary_unit => Literal::Complex(Complex::new(0.0, 1.0)), |
|
Rule::pi => Literal::Float(std::f64::consts::PI), |
|
Rule::tau => Literal::Float(2.0 * std::f64::consts::PI), |
|
Rule::euler_number => Literal::Float(std::f64::consts::E), |
|
Rule::golden_ratio => Literal::Float(1.61803398875), |
|
_ => unreachable!("Unexpected constant: {:?}", pair), |
|
} |
|
} |
|
|
|
fn parse_lit(mut pairs: Pairs<Rule>) -> Result<(Literal, Unit), ParseError> { |
|
let literal = match pairs.next() { |
|
Some(lit) => match lit.as_rule() { |
|
Rule::int => { |
|
let value = lit.as_str().parse::<i32>()? as f64; |
|
Literal::Float(value) |
|
} |
|
Rule::float => { |
|
let value = lit.as_str().parse::<f64>()?; |
|
Literal::Float(value) |
|
} |
|
Rule::unit => { |
|
let (unit, scale) = parse_unit(lit.into_inner())?; |
|
return Ok((Literal::Float(scale), unit)); |
|
} |
|
rule => unreachable!("unexpected rule: {:?}", rule), |
|
}, |
|
None => unreachable!("expected rule"), |
|
}; |
|
|
|
if let Some(unit_pair) = pairs.next() { |
|
let unit_pairs = unit_pair.into_inner(); |
|
let (unit, scale) = parse_unit(unit_pairs)?; |
|
|
|
println!("found unit: {:?}", unit); |
|
|
|
Ok(( |
|
match literal { |
|
Literal::Float(num) => Literal::Float(num * scale), |
|
Literal::Complex(num) => Literal::Complex(num * scale), |
|
}, |
|
unit, |
|
)) |
|
} else { |
|
Ok((literal, Unit::BASE_UNIT)) |
|
} |
|
} |
|
|
|
fn parse_expr(pairs: Pairs<Rule>) -> Result<(Node, NodeMetadata), ParseError> { |
|
PRATT_PARSER |
|
.map_primary(|primary| { |
|
Ok(match primary.as_rule() { |
|
Rule::lit => { |
|
let (lit, unit) = parse_lit(primary.into_inner())?; |
|
|
|
(Node::Lit(lit), NodeMetadata { unit }) |
|
} |
|
Rule::fn_call => { |
|
let mut pairs = primary.into_inner(); |
|
let name = pairs.next().expect("fn_call always has 2 children").as_str().to_string(); |
|
|
|
( |
|
Node::FnCall { |
|
name, |
|
expr: pairs.map(|p| parse_expr(p.into_inner()).map(|expr| expr.0)).collect::<Result<Vec<Node>, ParseError>>()?, |
|
}, |
|
NodeMetadata::new(Unit::BASE_UNIT), |
|
) |
|
} |
|
Rule::constant => { |
|
let lit = parse_const(primary.into_inner().next().expect("constant should have atleast 1 child")); |
|
|
|
(Node::Lit(lit), NodeMetadata::new(Unit::BASE_UNIT)) |
|
} |
|
Rule::ident => { |
|
let name = primary.as_str().to_string(); |
|
|
|
(Node::Var(name), NodeMetadata::new(Unit::BASE_UNIT)) |
|
} |
|
Rule::expr => parse_expr(primary.into_inner())?, |
|
Rule::float => { |
|
let value = primary.as_str().parse::<f64>()?; |
|
(Node::Lit(Literal::Float(value)), NodeMetadata::new(Unit::BASE_UNIT)) |
|
} |
|
rule => unreachable!("unexpected rule: {:?}", rule), |
|
}) |
|
}) |
|
.map_prefix(|op, rhs| { |
|
let (rhs, rhs_metadata) = rhs?; |
|
let op = match op.as_rule() { |
|
Rule::neg => UnaryOp::Neg, |
|
Rule::sqrt => UnaryOp::Sqrt, |
|
|
|
rule => unreachable!("unexpected rule: {:?}", rule), |
|
}; |
|
|
|
let node = Node::UnaryOp { expr: Box::new(rhs), op }; |
|
let unit = rhs_metadata.unit; |
|
|
|
let unit = if !unit.is_base() { |
|
match op { |
|
UnaryOp::Sqrt if unit.length % 2 == 0 && unit.mass % 2 == 0 && unit.time % 2 == 0 => Unit { |
|
length: unit.length / 2, |
|
mass: unit.mass / 2, |
|
time: unit.time / 2, |
|
}, |
|
UnaryOp::Neg => unit, |
|
op => return Err(ParseError::Type(TypeError::InvalidUnaryOp(unit, op))), |
|
} |
|
} else { |
|
Unit::BASE_UNIT |
|
}; |
|
|
|
Ok((node, NodeMetadata::new(unit))) |
|
}) |
|
.map_postfix(|lhs, op| { |
|
let (lhs_node, lhs_metadata) = lhs?; |
|
|
|
let op = match op.as_rule() { |
|
Rule::EOI => return Ok((lhs_node, lhs_metadata)), |
|
Rule::fac => UnaryOp::Fac, |
|
rule => unreachable!("unexpected rule: {:?}", rule), |
|
}; |
|
|
|
if !lhs_metadata.unit.is_base() { |
|
return Err(ParseError::Type(TypeError::InvalidUnaryOp(lhs_metadata.unit, op))); |
|
} |
|
|
|
Ok((Node::UnaryOp { expr: Box::new(lhs_node), op }, lhs_metadata)) |
|
}) |
|
.map_infix(|lhs, op, rhs| { |
|
let (lhs, lhs_metadata) = lhs?; |
|
let (rhs, rhs_metadata) = rhs?; |
|
|
|
let op = match op.as_rule() { |
|
Rule::add => BinaryOp::Add, |
|
Rule::sub => BinaryOp::Sub, |
|
Rule::mul => BinaryOp::Mul, |
|
Rule::div => BinaryOp::Div, |
|
Rule::pow => BinaryOp::Pow, |
|
Rule::paren => BinaryOp::Mul, |
|
rule => unreachable!("unexpected rule: {:?}", rule), |
|
}; |
|
|
|
let (lhs_unit, rhs_unit) = (lhs_metadata.unit, rhs_metadata.unit); |
|
|
|
let unit = match (!lhs_unit.is_base(), !rhs_unit.is_base()) { |
|
(true, true) => match op { |
|
BinaryOp::Mul => Unit { |
|
length: lhs_unit.length + rhs_unit.length, |
|
mass: lhs_unit.mass + rhs_unit.mass, |
|
time: lhs_unit.time + rhs_unit.time, |
|
}, |
|
BinaryOp::Div => Unit { |
|
length: lhs_unit.length - rhs_unit.length, |
|
mass: lhs_unit.mass - rhs_unit.mass, |
|
time: lhs_unit.time - rhs_unit.time, |
|
}, |
|
BinaryOp::Add | BinaryOp::Sub => { |
|
if lhs_unit == rhs_unit { |
|
lhs_unit |
|
} else { |
|
return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit))); |
|
} |
|
} |
|
BinaryOp::Pow => { |
|
return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, rhs_unit))); |
|
} |
|
}, |
|
|
|
(true, false) => match op { |
|
BinaryOp::Add | BinaryOp::Sub => return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT))), |
|
BinaryOp::Pow => { |
|
|
|
|
|
if let Ok(Value::Number(Number::Real(val))) = rhs.eval(&EvalContext::default()) { |
|
if (val - val as i32 as f64).abs() <= f64::EPSILON { |
|
Unit { |
|
length: lhs_unit.length * val as i32, |
|
mass: lhs_unit.mass * val as i32, |
|
time: lhs_unit.time * val as i32, |
|
} |
|
} else { |
|
return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT))); |
|
} |
|
} else { |
|
return Err(ParseError::Type(TypeError::InvalidBinaryOp(lhs_unit, op, Unit::BASE_UNIT))); |
|
} |
|
} |
|
_ => lhs_unit, |
|
}, |
|
(false, true) => match op { |
|
BinaryOp::Add | BinaryOp::Sub | BinaryOp::Pow => return Err(ParseError::Type(TypeError::InvalidBinaryOp(Unit::BASE_UNIT, op, rhs_unit))), |
|
_ => rhs_unit, |
|
}, |
|
(false, false) => Unit::BASE_UNIT, |
|
}; |
|
|
|
let node = Node::BinOp { |
|
lhs: Box::new(lhs), |
|
op, |
|
rhs: Box::new(rhs), |
|
}; |
|
|
|
Ok((node, NodeMetadata::new(unit))) |
|
}) |
|
.parse(pairs) |
|
} |
|
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected-tree shorthand: a float literal node.
    fn num(value: f64) -> Node {
        Node::Lit(Literal::Float(value))
    }

    /// Expected-tree shorthand: a unary operation applied to `expr`.
    fn unary(op: UnaryOp, expr: Node) -> Node {
        Node::UnaryOp { expr: Box::new(expr), op }
    }

    /// Expected-tree shorthand: a binary operation on `lhs` and `rhs`.
    fn binary(lhs: Node, op: BinaryOp, rhs: Node) -> Node {
        Node::BinOp {
            lhs: Box::new(lhs),
            op,
            rhs: Box::new(rhs),
        }
    }

    /// Generates one `#[test]` per `name: source => expected_tree` entry. Each
    /// test parses the source text and compares the resulting node (the unit is
    /// ignored) with the expected tree.
    macro_rules! test_parser {
        ($($case:ident: $source:expr_2021 => $want:expr_2021),* $(,)?) => {
            $(
                #[test]
                fn $case() {
                    let (node, _unit) = Node::try_parse_from_str($source).unwrap();
                    assert_eq!(node, $want);
                }
            )*
        };
    }

    test_parser! {
        test_parse_int_literal: "42" => num(42.0),
        test_parse_float_literal: "3.14" => num(#[allow(clippy::approx_constant)] 3.14),
        test_parse_ident: "x" => Node::Var("x".to_string()),
        test_parse_unary_neg: "-42" => unary(UnaryOp::Neg, num(42.0)),
        test_parse_binary_add: "1 + 2" => binary(num(1.0), BinaryOp::Add, num(2.0)),
        test_parse_binary_mul: "3 * 4" => binary(num(3.0), BinaryOp::Mul, num(4.0)),
        test_parse_binary_pow: "2 ^ 3" => binary(num(2.0), BinaryOp::Pow, num(3.0)),
        test_parse_unary_sqrt: "sqrt(16)" => unary(UnaryOp::Sqrt, num(16.0)),
        test_parse_sqr_ident: "sqr(16)" => Node::FnCall {
            name: "sqr".to_string(),
            expr: vec![num(16.0)],
        },

        // Implicit multiplication binds like `*`, and `^` binds tighter than `-`.
        test_parse_complex_expr: "(1 + 2) 3 - 4 ^ 2" => binary(
            binary(
                binary(num(1.0), BinaryOp::Add, num(2.0)),
                BinaryOp::Mul,
                num(3.0),
            ),
            BinaryOp::Sub,
            binary(num(4.0), BinaryOp::Pow, num(2.0)),
        )
    }
}
|
|