|
use crate::document::{InlineRust, value}; |
|
use crate::document::{NodeId, OriginalLocation}; |
|
pub use graphene_core::registry::*; |
|
use graphene_core::*; |
|
use rustc_hash::FxHashMap; |
|
use std::borrow::Cow; |
|
use std::collections::{HashMap, HashSet}; |
|
use std::fmt::Debug; |
|
use std::hash::Hash; |
|
|
|
/// A flat network of proto nodes: a list of `(id, node)` pairs plus the ids designated as the
/// network's inputs and output.
#[derive(Debug, Default, PartialEq, Clone, Hash, Eq, serde::Serialize, serde::Deserialize)]
pub struct ProtoNetwork {
	/// Ids of the nodes treated as the network's inputs (kept in sync by `reorder_ids` and `replace_node_id`).
	pub inputs: Vec<NodeId>,
	/// Id of the node producing the network's result; `topological_sort` starts its traversal here.
	pub output: NodeId,
	/// All nodes of the network. After `reorder_ids`, each id equals the node's position in this vector.
	pub nodes: Vec<(NodeId, ProtoNode)>,
}
|
|
|
impl core::fmt::Display for ProtoNetwork { |
|
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { |
|
f.write_str("Proto Network with nodes: ")?; |
|
fn write_node(f: &mut core::fmt::Formatter<'_>, network: &ProtoNetwork, id: NodeId, indent: usize) -> core::fmt::Result { |
|
f.write_str(&"\t".repeat(indent))?; |
|
let Some((_, node)) = network.nodes.iter().find(|(node_id, _)| *node_id == id) else { |
|
return f.write_str("{{Unknown Node}}"); |
|
}; |
|
f.write_str("Node: ")?; |
|
f.write_str(&node.identifier.name)?; |
|
|
|
f.write_str("\n")?; |
|
f.write_str(&"\t".repeat(indent))?; |
|
f.write_str("{\n")?; |
|
|
|
f.write_str(&"\t".repeat(indent + 1))?; |
|
f.write_str("Input: ")?; |
|
match &node.input { |
|
ProtoNodeInput::None => f.write_str("None")?, |
|
ProtoNodeInput::ManualComposition(ty) => f.write_fmt(format_args!("Manual Composition (type = {ty:?})"))?, |
|
ProtoNodeInput::Node(_) => f.write_str("Node")?, |
|
ProtoNodeInput::NodeLambda(_) => f.write_str("Lambda Node")?, |
|
} |
|
f.write_str("\n")?; |
|
|
|
match &node.construction_args { |
|
ConstructionArgs::Value(value) => { |
|
f.write_str(&"\t".repeat(indent + 1))?; |
|
f.write_fmt(format_args!("Value construction argument: {value:?}"))? |
|
} |
|
ConstructionArgs::Nodes(nodes) => { |
|
for id in nodes { |
|
write_node(f, network, id.0, indent + 1)?; |
|
} |
|
} |
|
ConstructionArgs::Inline(inline) => { |
|
f.write_str(&"\t".repeat(indent + 1))?; |
|
f.write_fmt(format_args!("Inline construction argument: {inline:?}"))? |
|
} |
|
} |
|
f.write_str(&"\t".repeat(indent))?; |
|
f.write_str("}\n")?; |
|
Ok(()) |
|
} |
|
|
|
let id = self.output; |
|
write_node(f, self, id, 0) |
|
} |
|
} |
|
|
|
/// The arguments a [`ProtoNode`] is constructed with.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum ConstructionArgs {
	/// A constant value.
	Value(MemoHash<value::TaggedValue>),
	/// References to other nodes of the network. The `bool` flags the reference as a lambda:
	/// `ProtoNode::map_ids` leaves flagged entries untouched when `skip_lambdas` is set, and
	/// `ProtoNetwork::resolve_inputs` creates compose nodes with `(input, false), (node, true)`.
	Nodes(Vec<(NodeId, bool)>),
	/// Inline Rust; its `expr` and `ty` fields are read by `new_function_args` and `TypingContext::infer`.
	Inline(InlineRust),
}
|
|
|
impl Eq for ConstructionArgs {} |
|
|
|
impl PartialEq for ConstructionArgs { |
|
fn eq(&self, other: &Self) -> bool { |
|
match (&self, &other) { |
|
(Self::Nodes(n1), Self::Nodes(n2)) => n1 == n2, |
|
(Self::Value(v1), Self::Value(v2)) => v1 == v2, |
|
_ => { |
|
use std::hash::Hasher; |
|
let hash = |input: &Self| { |
|
let mut hasher = rustc_hash::FxHasher::default(); |
|
input.hash(&mut hasher); |
|
hasher.finish() |
|
}; |
|
hash(self) == hash(other) |
|
} |
|
} |
|
} |
|
} |
|
|
|
impl Hash for ConstructionArgs { |
|
fn hash<H: std::hash::Hasher>(&self, state: &mut H) { |
|
core::mem::discriminant(self).hash(state); |
|
match self { |
|
Self::Nodes(nodes) => { |
|
for node in nodes { |
|
node.hash(state); |
|
} |
|
} |
|
Self::Value(value) => value.hash(state), |
|
Self::Inline(inline) => inline.hash(state), |
|
} |
|
} |
|
} |
|
|
|
impl ConstructionArgs { |
|
|
|
pub fn new_function_args(&self) -> Vec<String> { |
|
match self { |
|
ConstructionArgs::Nodes(nodes) => nodes.iter().map(|(n, _)| format!("n{:0x}", n.0)).collect(), |
|
ConstructionArgs::Value(value) => vec![value.to_primitive_string()], |
|
ConstructionArgs::Inline(inline) => vec![inline.expr.clone()], |
|
} |
|
} |
|
} |
|
|
|
/// A single node of a [`ProtoNetwork`].
#[derive(Debug, Clone, PartialEq, Hash, Eq, serde::Serialize, serde::Deserialize)]
pub struct ProtoNode {
	/// Arguments the node is constructed with.
	pub construction_args: ConstructionArgs,
	/// The node's primary input.
	pub input: ProtoNodeInput,
	/// Identifier used to look up implementations in the registry (see `TypingContext::infer`).
	pub identifier: ProtoNodeIdentifier,
	/// Where the node originates; its `path` is reported in `GraphError`s.
	pub original_location: OriginalLocation,
	/// When set, `stable_node_id` also hashes `original_location.path`, so nodes at different paths hash differently.
	pub skip_deduplication: bool,
}
|
|
|
impl Default for ProtoNode { |
|
fn default() -> Self { |
|
Self { |
|
identifier: ProtoNodeIdentifier::new("graphene_core::ops::IdentityNode"), |
|
construction_args: ConstructionArgs::Value(value::TaggedValue::U32(0).into()), |
|
input: ProtoNodeInput::None, |
|
original_location: OriginalLocation::default(), |
|
skip_deduplication: false, |
|
} |
|
} |
|
} |
|
|
|
|
|
/// The primary input of a [`ProtoNode`].
#[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)]
pub enum ProtoNodeInput {
	/// No input; `TypingContext::infer` uses `()` as the call argument type.
	None,
	/// The call argument has the given type; `TypingContext::infer` uses it directly instead of
	/// looking up another node.
	ManualComposition(Type),
	/// Input from the node with this id. `ProtoNetwork::resolve_inputs` inserts a `ComposeNode`
	/// for every node with this kind of input and redirects references to it.
	Node(NodeId),
	/// A reference to another node flagged as a lambda: `ProtoNode::map_ids` leaves it untouched when
	/// `skip_lambdas` is set, and `resolve_inputs` does not compose it.
	NodeLambda(NodeId),
}
|
|
|
impl ProtoNode { |
|
|
|
|
|
pub fn stable_node_id(&self) -> Option<NodeId> { |
|
use std::hash::Hasher; |
|
let mut hasher = rustc_hash::FxHasher::default(); |
|
|
|
self.identifier.name.hash(&mut hasher); |
|
self.construction_args.hash(&mut hasher); |
|
if self.skip_deduplication { |
|
self.original_location.path.hash(&mut hasher); |
|
} |
|
|
|
std::mem::discriminant(&self.input).hash(&mut hasher); |
|
match self.input { |
|
ProtoNodeInput::None => (), |
|
ProtoNodeInput::ManualComposition(ref ty) => { |
|
ty.hash(&mut hasher); |
|
} |
|
ProtoNodeInput::Node(id) => (id, false).hash(&mut hasher), |
|
ProtoNodeInput::NodeLambda(id) => (id, true).hash(&mut hasher), |
|
}; |
|
|
|
Some(NodeId(hasher.finish())) |
|
} |
|
|
|
|
|
pub fn value(value: ConstructionArgs, path: Vec<NodeId>) -> Self { |
|
let inputs_exposed = match &value { |
|
ConstructionArgs::Nodes(nodes) => nodes.len() + 1, |
|
_ => 2, |
|
}; |
|
Self { |
|
identifier: ProtoNodeIdentifier::new("graphene_core::value::ClonedNode"), |
|
construction_args: value, |
|
input: ProtoNodeInput::ManualComposition(concrete!(Context)), |
|
original_location: OriginalLocation { |
|
path: Some(path), |
|
inputs_exposed: vec![false; inputs_exposed], |
|
..Default::default() |
|
}, |
|
skip_deduplication: false, |
|
} |
|
} |
|
|
|
|
|
|
|
pub fn map_ids(&mut self, f: impl Fn(NodeId) -> NodeId, skip_lambdas: bool) { |
|
match self.input { |
|
ProtoNodeInput::Node(id) => self.input = ProtoNodeInput::Node(f(id)), |
|
ProtoNodeInput::NodeLambda(id) => { |
|
if !skip_lambdas { |
|
self.input = ProtoNodeInput::NodeLambda(f(id)) |
|
} |
|
} |
|
_ => (), |
|
} |
|
|
|
if let ConstructionArgs::Nodes(ids) = &mut self.construction_args { |
|
ids.iter_mut().filter(|(_, lambda)| !(skip_lambdas && *lambda)).for_each(|(id, _)| *id = f(*id)); |
|
} |
|
} |
|
|
|
pub fn unwrap_construction_nodes(&self) -> Vec<(NodeId, bool)> { |
|
match &self.construction_args { |
|
ConstructionArgs::Nodes(nodes) => nodes.clone(), |
|
_ => panic!("tried to unwrap nodes from non node construction args \n node: {self:#?}"), |
|
} |
|
} |
|
} |
|
|
|
/// Per-node traversal state used by `ProtoNetwork::topological_sort`.
#[derive(Clone, Copy, PartialEq)]
enum NodeState {
	/// Not yet expanded.
	Unvisited,
	/// Expanded and its dependencies are being processed; reaching it again through a dependency edge means a cycle.
	Visiting,
	/// Emitted into the sorted order.
	Visited,
}
|
|
|
impl ProtoNetwork { |
|
fn check_ref(&self, ref_id: &NodeId, id: &NodeId) { |
|
debug_assert!( |
|
self.nodes.iter().any(|(check_id, _)| check_id == ref_id), |
|
"Node id:{id} has a reference which uses node id:{ref_id} which doesn't exist in network {self:#?}" |
|
); |
|
} |
|
|
|
#[cfg(debug_assertions)] |
|
pub fn example() -> (Self, NodeId, ProtoNode) { |
|
let node_id = NodeId(1); |
|
let proto_node = ProtoNode::default(); |
|
let proto_network = ProtoNetwork { |
|
inputs: vec![node_id], |
|
output: node_id, |
|
nodes: vec![(node_id, proto_node.clone())], |
|
}; |
|
(proto_network, node_id, proto_node) |
|
} |
|
|
|
|
|
pub fn collect_outwards_edges(&self) -> HashMap<NodeId, Vec<NodeId>> { |
|
let mut edges: HashMap<NodeId, Vec<NodeId>> = HashMap::new(); |
|
for (id, node) in &self.nodes { |
|
match &node.input { |
|
ProtoNodeInput::Node(ref_id) | ProtoNodeInput::NodeLambda(ref_id) => { |
|
self.check_ref(ref_id, id); |
|
edges.entry(*ref_id).or_default().push(*id) |
|
} |
|
_ => (), |
|
} |
|
|
|
if let ConstructionArgs::Nodes(ref_nodes) = &node.construction_args { |
|
for (ref_id, _) in ref_nodes { |
|
self.check_ref(ref_id, id); |
|
edges.entry(*ref_id).or_default().push(*id) |
|
} |
|
} |
|
} |
|
edges |
|
} |
|
|
|
|
|
|
|
pub fn generate_stable_node_ids(&mut self) { |
|
debug_assert!(self.is_topologically_sorted()); |
|
let outwards_edges = self.collect_outwards_edges(); |
|
|
|
for index in 0..self.nodes.len() { |
|
let Some(sni) = self.nodes[index].1.stable_node_id() else { |
|
panic!("failed to generate stable node id for node {:#?}", self.nodes[index].1); |
|
}; |
|
self.replace_node_id(&outwards_edges, NodeId(index as u64), sni, false); |
|
self.nodes[index].0 = sni; |
|
} |
|
} |
|
|
|
|
|
|
|
pub fn collect_inwards_edges(&self) -> HashMap<NodeId, Vec<NodeId>> { |
|
let mut edges: HashMap<NodeId, Vec<NodeId>> = HashMap::new(); |
|
for (id, node) in &self.nodes { |
|
match &node.input { |
|
ProtoNodeInput::Node(ref_id) | ProtoNodeInput::NodeLambda(ref_id) => { |
|
self.check_ref(ref_id, id); |
|
edges.entry(*id).or_default().push(*ref_id) |
|
} |
|
_ => (), |
|
} |
|
|
|
if let ConstructionArgs::Nodes(ref_nodes) = &node.construction_args { |
|
for (ref_id, _) in ref_nodes { |
|
self.check_ref(ref_id, id); |
|
edges.entry(*id).or_default().push(*ref_id) |
|
} |
|
} |
|
} |
|
edges |
|
} |
|
|
|
fn collect_inwards_edges_with_mapping(&self) -> (Vec<Vec<usize>>, FxHashMap<NodeId, usize>) { |
|
let id_map: FxHashMap<_, _> = self.nodes.iter().enumerate().map(|(idx, (id, _))| (*id, idx)).collect(); |
|
|
|
|
|
let mut inwards_edges = vec![Vec::new(); self.nodes.len()]; |
|
for (node_id, node) in &self.nodes { |
|
let node_index = id_map[node_id]; |
|
match &node.input { |
|
ProtoNodeInput::Node(ref_id) | ProtoNodeInput::NodeLambda(ref_id) => { |
|
self.check_ref(ref_id, &NodeId(node_index as u64)); |
|
inwards_edges[node_index].push(id_map[ref_id]); |
|
} |
|
_ => {} |
|
} |
|
|
|
if let ConstructionArgs::Nodes(ref_nodes) = &node.construction_args { |
|
for (ref_id, _) in ref_nodes { |
|
self.check_ref(ref_id, &NodeId(node_index as u64)); |
|
inwards_edges[node_index].push(id_map[ref_id]); |
|
} |
|
} |
|
} |
|
|
|
(inwards_edges, id_map) |
|
} |
|
|
|
|
|
pub fn resolve_inputs(&mut self) -> Result<(), String> { |
|
|
|
self.reorder_ids()?; |
|
|
|
let max_id = self.nodes.len() as u64 - 1; |
|
|
|
|
|
let outwards_edges = self.collect_outwards_edges(); |
|
|
|
|
|
for node_id in 0..=max_id { |
|
let node_id = NodeId(node_id); |
|
|
|
let (_, node) = &mut self.nodes[node_id.0 as usize]; |
|
|
|
if let ProtoNodeInput::Node(input_node_id) = node.input { |
|
|
|
let compose_node_id = NodeId(self.nodes.len() as u64); |
|
|
|
let (_, input_node_id_proto) = &self.nodes[input_node_id.0 as usize]; |
|
|
|
let input = input_node_id_proto.input.clone(); |
|
|
|
let mut path = input_node_id_proto.original_location.path.clone(); |
|
if let Some(path) = &mut path { |
|
path.push(node_id); |
|
} |
|
|
|
self.nodes.push(( |
|
compose_node_id, |
|
ProtoNode { |
|
identifier: ProtoNodeIdentifier::new("graphene_core::structural::ComposeNode"), |
|
construction_args: ConstructionArgs::Nodes(vec![(input_node_id, false), (node_id, true)]), |
|
input, |
|
original_location: OriginalLocation { path, ..Default::default() }, |
|
skip_deduplication: false, |
|
}, |
|
)); |
|
|
|
self.replace_node_id(&outwards_edges, node_id, compose_node_id, true); |
|
} |
|
} |
|
self.reorder_ids()?; |
|
Ok(()) |
|
} |
|
|
|
|
|
fn replace_node_id(&mut self, outwards_edges: &HashMap<NodeId, Vec<NodeId>>, node_id: NodeId, compose_node_id: NodeId, skip_lambdas: bool) { |
|
|
|
if let Some(referring_nodes) = outwards_edges.get(&node_id) { |
|
for &referring_node_id in referring_nodes { |
|
let (_, referring_node) = &mut self.nodes[referring_node_id.0 as usize]; |
|
referring_node.map_ids(|id| if id == node_id { compose_node_id } else { id }, skip_lambdas) |
|
} |
|
} |
|
|
|
if self.output == node_id { |
|
self.output = compose_node_id; |
|
} |
|
|
|
self.inputs.iter_mut().for_each(|id| { |
|
if *id == node_id { |
|
*id = compose_node_id; |
|
} |
|
}); |
|
} |
|
|
|
|
|
|
|
pub fn topological_sort(&self) -> Result<(Vec<NodeId>, FxHashMap<NodeId, usize>), String> { |
|
let (inwards_edges, id_map) = self.collect_inwards_edges_with_mapping(); |
|
let mut sorted = Vec::with_capacity(self.nodes.len()); |
|
let mut stack = vec![id_map[&self.output]]; |
|
let mut state = vec![NodeState::Unvisited; self.nodes.len()]; |
|
|
|
while let Some(&node_index) = stack.last() { |
|
match state[node_index] { |
|
NodeState::Unvisited => { |
|
state[node_index] = NodeState::Visiting; |
|
for &dep_index in inwards_edges[node_index].iter().rev() { |
|
match state[dep_index] { |
|
NodeState::Visiting => { |
|
return Err(format!("Cycle detected involving node {}", self.nodes[dep_index].0)); |
|
} |
|
NodeState::Unvisited => { |
|
stack.push(dep_index); |
|
} |
|
NodeState::Visited => {} |
|
} |
|
} |
|
} |
|
NodeState::Visiting => { |
|
stack.pop(); |
|
state[node_index] = NodeState::Visited; |
|
sorted.push(NodeId(node_index as u64)); |
|
} |
|
NodeState::Visited => { |
|
stack.pop(); |
|
} |
|
} |
|
} |
|
|
|
Ok((sorted, id_map)) |
|
} |
|
|
|
fn is_topologically_sorted(&self) -> bool { |
|
let mut visited = HashSet::new(); |
|
|
|
let inwards_edges = self.collect_inwards_edges(); |
|
for (id, _) in &self.nodes { |
|
for &dependency in inwards_edges.get(id).unwrap_or(&Vec::new()) { |
|
if !visited.contains(&dependency) { |
|
dbg!(id, dependency); |
|
dbg!(&visited); |
|
dbg!(&self.nodes); |
|
return false; |
|
} |
|
} |
|
visited.insert(*id); |
|
} |
|
true |
|
} |
|
|
|
|
|
fn reorder_ids(&mut self) -> Result<(), String> { |
|
let (order, _id_map) = self.topological_sort()?; |
|
|
|
|
|
|
|
|
|
|
|
let new_positions: FxHashMap<_, _> = order.iter().enumerate().map(|(pos, id)| (self.nodes[id.0 as usize].0, pos)).collect(); |
|
|
|
|
|
|
|
|
|
let mut new_nodes = Vec::with_capacity(order.len()); |
|
for (index, &id) in order.iter().enumerate() { |
|
let mut node = std::mem::take(&mut self.nodes[id.0 as usize].1); |
|
|
|
node.map_ids(|id| NodeId(*new_positions.get(&id).expect("node not found in lookup table") as u64), false); |
|
new_nodes.push((NodeId(index as u64), node)); |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.nodes = new_nodes; |
|
self.inputs = self.inputs.iter().filter_map(|id| new_positions.get(id).map(|x| NodeId(*x as u64))).collect(); |
|
self.output = NodeId(*new_positions.get(&self.output).unwrap() as u64); |
|
|
|
assert_eq!(order.len(), self.nodes.len()); |
|
Ok(()) |
|
} |
|
} |
|
/// The kind of failure reported while type-checking a proto node.
#[derive(Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum GraphErrorType {
	/// A node referenced by the construction arguments has no inferred type yet.
	NodeNotFound(NodeId),
	/// The node referenced by the primary input has no inferred type yet.
	InputNodeNotFound(NodeId),
	/// A function-typed argument whose output is generic; `index` is its position in `inputs`.
	UnexpectedGenerics { index: usize, inputs: Vec<Type> },
	/// The registry has no entry for the node's identifier.
	NoImplementations,
	/// No constructor is available for the node (not produced within this file).
	NoConstructor,
	/// No implementation accepts the received types.
	///
	/// `inputs` is a preformatted list of the received input types. Each entry of `error_inputs`
	/// lists `(input index, (found, expected))` mismatches for one of the candidates with the fewest
	/// mismatches.
	InvalidImplementations { inputs: String, error_inputs: Vec<Vec<(usize, (Type, Type))>> },
	/// More than one implementation accepts the received types.
	MultipleImplementations { inputs: String, valid: Vec<NodeIOTypes> },
}
|
impl Debug for GraphErrorType { |
|
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
|
match self { |
|
GraphErrorType::NodeNotFound(id) => write!(f, "Input node {id} is not present in the typing context"), |
|
GraphErrorType::InputNodeNotFound(id) => write!(f, "Input node {id} is not present in the typing context"), |
|
GraphErrorType::UnexpectedGenerics { index, inputs } => write!(f, "Generic inputs should not exist but found at {index}: {inputs:?}"), |
|
GraphErrorType::NoImplementations => write!(f, "No implementations found"), |
|
GraphErrorType::NoConstructor => write!(f, "No construct found for node"), |
|
GraphErrorType::InvalidImplementations { inputs, error_inputs } => { |
|
let format_error = |(index, (found, expected)): &(usize, (Type, Type))| { |
|
let index = index + 1; |
|
format!( |
|
"\ |
|
• Input {index}:\n\ |
|
…found: {found}\n\ |
|
…expected: {expected}\ |
|
" |
|
) |
|
}; |
|
let format_error_list = |errors: &Vec<(usize, (Type, Type))>| errors.iter().map(format_error).collect::<Vec<_>>().join("\n"); |
|
let mut errors = error_inputs.iter().map(format_error_list).collect::<Vec<_>>(); |
|
errors.sort(); |
|
let errors = errors.join("\n"); |
|
let incompatibility = if errors.chars().filter(|&c| c == '•').count() == 1 { |
|
"This input type is incompatible:" |
|
} else { |
|
"These input types are incompatible:" |
|
}; |
|
|
|
write!( |
|
f, |
|
"\ |
|
{incompatibility}\n\ |
|
{errors}\n\ |
|
\n\ |
|
The node is currently receiving all of the following input types:\n\ |
|
{inputs}\n\ |
|
This is not a supported arrangement of types for the node.\ |
|
" |
|
) |
|
} |
|
GraphErrorType::MultipleImplementations { inputs, valid } => write!(f, "Multiple implementations found ({inputs}):\n{valid:#?}"), |
|
} |
|
} |
|
} |
|
/// A type-checking error attributed to a specific proto node.
#[derive(Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct GraphError {
	/// Document path of the offending node (empty when the node has no recorded path).
	pub node_path: Vec<NodeId>,
	/// Identifier name of the offending node.
	pub identifier: Cow<'static, str>,
	/// What went wrong.
	pub error: GraphErrorType,
}
|
impl GraphError { |
|
pub fn new(node: &ProtoNode, text: impl Into<GraphErrorType>) -> Self { |
|
Self { |
|
node_path: node.original_location.path.clone().unwrap_or_default(), |
|
identifier: node.identifier.name.clone(), |
|
error: text.into(), |
|
} |
|
} |
|
} |
|
impl Debug for GraphError { |
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
|
f.debug_struct("NodeGraphError") |
|
.field("path", &self.node_path.iter().map(|id| id.0).collect::<Vec<_>>()) |
|
.field("identifier", &self.identifier.to_string()) |
|
.field("error", &self.error) |
|
.finish() |
|
} |
|
} |
|
/// A list of [`GraphError`]s; the error type returned by [`TypingContext`] inference.
pub type GraphErrors = Vec<GraphError>;
|
|
|
|
|
/// Infers and caches the I/O types and selected constructors of proto nodes.
#[derive(Default, Clone, dyn_any::DynAny)]
pub struct TypingContext {
	/// Registry: for each node identifier, the available implementations keyed by their I/O types.
	lookup: Cow<'static, HashMap<ProtoNodeIdentifier, HashMap<NodeIOTypes, NodeConstructor>>>,
	/// Types inferred so far, keyed by node id.
	inferred: HashMap<NodeId, NodeIOTypes>,
	/// Constructor selected for each node during inference.
	constructor: HashMap<NodeId, NodeConstructor>,
}
|
|
|
impl TypingContext {
	/// Creates a context that resolves implementations from the given static registry.
	///
	/// The registry is borrowed; the inferred types and selected constructors start empty.
	pub fn new(lookup: &'static HashMap<ProtoNodeIdentifier, HashMap<NodeIOTypes, NodeConstructor>>) -> Self {
		Self {
			lookup: Cow::Borrowed(lookup),
			..Default::default()
		}
	}

	/// Infers types for every node of `network`, in the order of `network.nodes`.
	///
	/// A node can only be inferred after the nodes it references; otherwise `NodeNotFound` or
	/// `InputNodeNotFound` is reported. The network is therefore presumably expected to be
	/// topologically sorted. NOTE(review): confirm callers guarantee this.
	///
	/// # Errors
	/// Stops at the first node whose inference fails and returns its errors.
	pub fn update(&mut self, network: &ProtoNetwork) -> Result<(), GraphErrors> {
		for (id, node) in network.nodes.iter() {
			self.infer(*id, node)?;
		}

		Ok(())
	}

	/// Forgets the selected constructor and inferred types of `node_id`, returning the previously inferred types, if any.
	pub fn remove_inference(&mut self, node_id: NodeId) -> Option<NodeIOTypes> {
		self.constructor.remove(&node_id);
		self.inferred.remove(&node_id)
	}

	/// Returns the constructor selected for `node_id` during inference, if one was recorded.
	///
	/// Value nodes never get an entry, because `infer` returns early for them.
	pub fn constructor(&self, node_id: NodeId) -> Option<NodeConstructor> {
		self.constructor.get(&node_id).copied()
	}

	/// Returns the inferred I/O types of `node_id`, if it has been inferred.
	pub fn type_of(&self, node_id: NodeId) -> Option<&NodeIOTypes> {
		self.inferred.get(&node_id)
	}

	/// Infers the I/O types of `node` (identified by `node_id`), caches them, records the selected constructor, and returns the types.
	///
	/// If `node_id` was already inferred, the cached types are returned without re-checking `node`.
	///
	/// # Errors
	/// - `NodeNotFound` / `InputNodeNotFound` if a referenced node has not been inferred yet.
	/// - `NoImplementations` if the registry has no entry for `node.identifier`.
	/// - `UnexpectedGenerics` if a function-typed argument returns a generic type.
	/// - `InvalidImplementations` if no implementation matches the argument types.
	/// - `MultipleImplementations` if more than one implementation matches and the tie cannot be broken.
	///
	/// # Panics
	/// If `node` has a value construction argument but an input other than `None` or `ManualComposition(Context)`.
	pub fn infer(&mut self, node_id: NodeId, node: &ProtoNode) -> Result<NodeIOTypes, GraphErrors> {
		// Cached result.
		if let Some(inferred) = self.inferred.get(&node_id) {
			return Ok(inferred.clone());
		}

		// Types of the construction arguments.
		let inputs = match node.construction_args {
			// A value node is typed directly: called with a `Context`, returning a future of the value's type.
			// No registry lookup happens and no constructor is recorded.
			ConstructionArgs::Value(ref v) => {
				assert!(matches!(node.input, ProtoNodeInput::None) || matches!(node.input, ProtoNodeInput::ManualComposition(ref x) if x == &concrete!(Context)));
				let types = NodeIOTypes::new(concrete!(Context), Type::Future(Box::new(v.ty())), vec![]);
				self.inferred.insert(node_id, types.clone());
				return Ok(types);
			}
			// Node arguments contribute the (already inferred) types of the referenced nodes.
			ConstructionArgs::Nodes(ref nodes) => nodes
				.iter()
				.map(|(id, _)| {
					self.inferred
						.get(id)
						.ok_or_else(|| vec![GraphError::new(node, GraphErrorType::NodeNotFound(*id))])
						.map(|node| node.ty())
				})
				.collect::<Result<Vec<Type>, GraphErrors>>()?,
			ConstructionArgs::Inline(ref inline) => vec![inline.ty.clone()],
		};

		// Type of the primary input / call argument: `()` with no input, the declared type under manual
		// composition, otherwise the return type of the input node.
		let primary_input_or_call_argument = match node.input {
			ProtoNodeInput::None => concrete!(()),
			ProtoNodeInput::ManualComposition(ref ty) => ty.clone(),
			ProtoNodeInput::Node(id) | ProtoNodeInput::NodeLambda(id) => {
				let input = self.inferred.get(&id).ok_or_else(|| vec![GraphError::new(node, GraphErrorType::InputNodeNotFound(id))])?;
				input.return_value.clone()
			}
		};
		// Affects how input indices are numbered in error messages below.
		let using_manual_composition = matches!(node.input, ProtoNodeInput::ManualComposition(_) | ProtoNodeInput::None);
		let impls = self.lookup.get(&node.identifier).ok_or_else(|| vec![GraphError::new(node, GraphErrorType::NoImplementations)])?;

		// Function-typed arguments must have a concrete output type.
		if let Some(index) = inputs.iter().position(|p| matches!(p, Type::Fn(_, b) if matches!(b.as_ref(), Type::Generic(_)))) {
			return Err(vec![GraphError::new(node, GraphErrorType::UnexpectedGenerics { index, inputs })]);
		}

		/// Returns whether a value of type `from` may be supplied where `to` is declared.
		fn valid_type(from: &Type, to: &Type) -> bool {
			match (from, to) {
				// Concrete types must match exactly.
				(Type::Concrete(type1), Type::Concrete(type2)) => type1 == type2,
				// Futures are compared by their output types.
				(Type::Future(type1), Type::Future(type2)) => valid_type(type1, type2),
				// Function types: both outputs and inputs must be compatible.
				(Type::Fn(in1, out1), Type::Fn(in2, out2)) => valid_type(out2, out1) && valid_type(in1, in2),
				// A generic on either side matches anything here; `check_generic` later requires a single consistent binding.
				(Type::Generic(_), _) | (_, Type::Generic(_)) => true,
				_ => false,
			}
		}

		// Candidates whose call argument and parameters are all compatible with the actual types.
		let valid_output_types = impls
			.keys()
			.filter(|node_io| valid_type(&node_io.call_argument, &primary_input_or_call_argument) && inputs.iter().zip(node_io.inputs.iter()).all(|(p1, p2)| valid_type(p1, p2)))
			.collect::<Vec<_>>();

		// Bind each candidate's generics to concrete types; yields `(substituted, original)` pairs.
		let substitution_results = valid_output_types
			.iter()
			.map(|node_io| {
				let generics_lookup: Result<HashMap<_, _>, _> = collect_generics(node_io)
					.iter()
					.map(|generic| check_generic(node_io, &primary_input_or_call_argument, &inputs, generic).map(|x| (generic.to_string(), x)))
					.collect();

				generics_lookup.map(|generics_lookup| {
					let orig_node_io = (*node_io).clone();
					let mut new_node_io = orig_node_io.clone();
					replace_generics(&mut new_node_io, &generics_lookup);
					(new_node_io, orig_node_io)
				})
			})
			.collect::<Vec<_>>();

		// Candidates whose generics could be bound consistently.
		let valid_impls = substitution_results.iter().filter_map(|result| result.as_ref().ok()).collect::<Vec<_>>();

		match valid_impls.as_slice() {
			// No match: report the mismatches of the candidates with the fewest incompatible inputs.
			[] => {
				let mut best_errors = usize::MAX;
				let mut error_inputs = Vec::new();
				for node_io in impls.keys() {
					let current_errors = [&primary_input_or_call_argument]
						.into_iter()
						.chain(&inputs)
						.cloned()
						.zip([&node_io.call_argument].into_iter().chain(&node_io.inputs).cloned())
						.enumerate()
						.filter(|(_, (p1, p2))| !valid_type(p1, p2))
						.map(|(index, ty)| {
							// Map the internal index to the document input index where one is recorded.
							let i = node.original_location.inputs(index).min_by_key(|s| s.node.len()).map(|s| s.index).unwrap_or(index);
							let i = if using_manual_composition { i } else { i + 1 };
							(i, ty)
						})
						.collect::<Vec<_>>();
					if current_errors.len() < best_errors {
						best_errors = current_errors.len();
						error_inputs.clear();
					}
					if current_errors.len() <= best_errors {
						error_inputs.push(current_errors);
					}
				}
				let inputs = [&primary_input_or_call_argument]
					.into_iter()
					.chain(&inputs)
					.enumerate()
					// Position 0 (the call argument under manual composition) is omitted from the listing.
					.filter_map(|(i, t)| {
						let i = if using_manual_composition { i } else { i + 1 };
						if i == 0 { None } else { Some(format!("• Input {i}: {t}")) }
					})
					.collect::<Vec<_>>()
					.join("\n");
				Err(vec![GraphError::new(node, GraphErrorType::InvalidImplementations { inputs, error_inputs })])
			}
			// Exactly one match: record the substituted types and the original implementation's constructor.
			[(node_io, org_nio)] => {
				let node_io = node_io.clone();

				self.inferred.insert(node_id, node_io.clone());
				self.constructor.insert(node_id, impls[org_nio]);
				Ok(node_io)
			}
			// Two matches with different call arguments: prefer the one whose call argument is `()`.
			[first, second] => {
				if first.0.call_argument != second.0.call_argument {
					for (node_io, orig_nio) in [first, second] {
						if node_io.call_argument != concrete!(()) {
							continue;
						}

						self.inferred.insert(node_id, node_io.clone());
						self.constructor.insert(node_id, impls[orig_nio]);
						return Ok(node_io.clone());
					}
				}
				let inputs = [&primary_input_or_call_argument].into_iter().chain(&inputs).map(|t| t.to_string()).collect::<Vec<_>>().join(", ");
				let valid = valid_output_types.into_iter().cloned().collect();
				Err(vec![GraphError::new(node, GraphErrorType::MultipleImplementations { inputs, valid })])
			}
			// Three or more matches: ambiguous.
			_ => {
				let inputs = [&primary_input_or_call_argument].into_iter().chain(&inputs).map(|t| t.to_string()).collect::<Vec<_>>().join(", ");
				let valid = valid_output_types.into_iter().cloned().collect();
				Err(vec![GraphError::new(node, GraphErrorType::MultipleImplementations { inputs, valid })])
			}
		}
	}
}
|
|
|
|
|
fn collect_generics(types: &NodeIOTypes) -> Vec<Cow<'static, str>> { |
|
let inputs = [&types.call_argument].into_iter().chain(types.inputs.iter().map(|x| x.nested_type())); |
|
let mut generics = inputs |
|
.filter_map(|t| match t { |
|
Type::Generic(out) => Some(out.clone()), |
|
_ => None, |
|
}) |
|
.collect::<Vec<_>>(); |
|
if let Type::Generic(out) = &types.return_value { |
|
generics.push(out.clone()); |
|
} |
|
generics.dedup(); |
|
generics |
|
} |
|
|
|
|
|
fn check_generic(types: &NodeIOTypes, input: &Type, parameters: &[Type], generic: &str) -> Result<Type, String> { |
|
let inputs = [(Some(&types.call_argument), Some(input))] |
|
.into_iter() |
|
.chain(types.inputs.iter().map(|x| x.fn_input()).zip(parameters.iter().map(|x| x.fn_input()))) |
|
.chain(types.inputs.iter().map(|x| x.fn_output()).zip(parameters.iter().map(|x| x.fn_output()))); |
|
let concrete_inputs = inputs.filter(|(ni, _)| matches!(ni, Some(Type::Generic(input)) if generic == input)); |
|
let mut outputs = concrete_inputs.flat_map(|(_, out)| out); |
|
let out_ty = outputs |
|
.next() |
|
.ok_or_else(|| format!("Generic output type {generic} is not dependent on input {input:?} or parameters {parameters:?}",))?; |
|
if outputs.any(|ty| ty != out_ty) { |
|
return Err(format!("Generic output type {generic} is dependent on multiple inputs or parameters",)); |
|
} |
|
Ok(out_ty.clone()) |
|
} |
|
|
|
|
|
fn replace_generics(types: &mut NodeIOTypes, lookup: &HashMap<String, Type>) { |
|
let replace = |ty: &Type| { |
|
let Type::Generic(ident) = ty else { |
|
return None; |
|
}; |
|
lookup.get(ident.as_ref()).cloned() |
|
}; |
|
types.call_argument.replace_nested(replace); |
|
types.return_value.replace_nested(replace); |
|
for input in &mut types.inputs { |
|
input.replace_nested(replace); |
|
} |
|
} |
|
|
|
#[cfg(test)]
mod test {
	use super::*;
	use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode, ProtoNodeInput};

	/// Sorting `test_network` must put dependencies first and omit the unreachable node 7.
	/// `topological_sort` returns positions, so they are mapped back to the original ids here.
	#[test]
	fn topological_sort() {
		let construction_network = test_network();
		let (sorted, _) = construction_network.topological_sort().expect("Error when calling 'topological_sort' on 'construction_network.");
		let sorted: Vec<_> = sorted.iter().map(|x| construction_network.nodes[x.0 as usize].0).collect();
		println!("{sorted:#?}");
		assert_eq!(sorted, vec![NodeId(14), NodeId(10), NodeId(11), NodeId(1)]);
	}

	/// A two-node cycle must be reported as an error.
	#[test]
	fn topological_sort_with_cycles() {
		let construction_network = test_network_with_cycles();
		let sorted = construction_network.topological_sort();

		assert!(sorted.is_err())
	}

	/// After `reorder_ids`, ids equal positions and the `value` leaf comes first.
	#[test]
	fn id_reordering() {
		let mut construction_network = test_network();
		construction_network.reorder_ids().expect("Error when calling 'reorder_ids' on 'construction_network.");
		let (sorted, _) = construction_network.topological_sort().expect("Error when calling 'topological_sort' on 'construction_network.");
		let sorted: Vec<_> = sorted.iter().map(|x| construction_network.nodes[x.0 as usize].0).collect();
		println!("nodes: {:#?}", construction_network.nodes);
		assert_eq!(sorted, vec![NodeId(0), NodeId(1), NodeId(2), NodeId(3)]);
		let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect();
		println!("{ids:#?}");
		println!("nodes: {:#?}", construction_network.nodes);
		assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
		assert_eq!(ids, vec![NodeId(0), NodeId(1), NodeId(2), NodeId(3)]);
	}

	/// Reordering an already reordered network leaves it unchanged.
	#[test]
	fn id_reordering_idempotent() {
		let mut construction_network = test_network();
		construction_network.reorder_ids().expect("Error when calling 'reorder_ids' on 'construction_network.");
		construction_network.reorder_ids().expect("Error when calling 'reorder_ids' on 'construction_network.");
		let (sorted, _) = construction_network.topological_sort().expect("Error when calling 'topological_sort' on 'construction_network.");
		assert_eq!(sorted, vec![NodeId(0), NodeId(1), NodeId(2), NodeId(3)]);
		let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect();
		println!("{ids:#?}");
		assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
		assert_eq!(ids, vec![NodeId(0), NodeId(1), NodeId(2), NodeId(3)]);
	}

	/// `resolve_inputs` adds one `ComposeNode` for each reachable node with a `Node` input (`add` and the
	/// output `id`), growing the network from 4 to 6 nodes. The last compose node wraps positions 3 and 4.
	#[test]
	fn input_resolution() {
		let mut construction_network = test_network();
		construction_network.resolve_inputs().expect("Error when calling 'resolve_inputs' on 'construction_network.");
		println!("{construction_network:#?}");
		assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
		assert_eq!(construction_network.nodes.len(), 6);
		assert_eq!(construction_network.nodes[5].1.construction_args, ConstructionArgs::Nodes(vec![(NodeId(3), false), (NodeId(4), true)]));
	}

	/// Pins the exact stable ids of the resolved `test_network`. Changing the hashing in
	/// `ProtoNode::stable_node_id` or in `ConstructionArgs`'s `Hash` impl may change these values.
	#[test]
	fn stable_node_id_generation() {
		let mut construction_network = test_network();
		construction_network.resolve_inputs().expect("Error when calling 'resolve_inputs' on 'construction_network.");
		construction_network.generate_stable_node_ids();
		assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
		let ids: Vec<_> = construction_network.nodes.iter().map(|(id, _)| *id).collect();
		assert_eq!(
			ids,
			vec![
				NodeId(16997244687192517417),
				NodeId(12226224850522777131),
				NodeId(9162113827627229771),
				NodeId(12793582657066318419),
				NodeId(16945623684036608820),
				NodeId(2640415155091892458)
			]
		);
	}

	/// Chain `value (14) -> cons (10) -> add (11) -> id (1)` with output 1.
	/// Node 7 also reads from 11 but is unreachable from the output.
	fn test_network() -> ProtoNetwork {
		ProtoNetwork {
			inputs: vec![NodeId(10)],
			output: NodeId(1),
			nodes: [
				(
					NodeId(7),
					ProtoNode {
						identifier: "id".into(),
						input: ProtoNodeInput::Node(NodeId(11)),
						construction_args: ConstructionArgs::Nodes(vec![]),
						..Default::default()
					},
				),
				(
					NodeId(1),
					ProtoNode {
						identifier: "id".into(),
						input: ProtoNodeInput::Node(NodeId(11)),
						construction_args: ConstructionArgs::Nodes(vec![]),
						..Default::default()
					},
				),
				(
					NodeId(10),
					ProtoNode {
						identifier: "cons".into(),
						input: ProtoNodeInput::ManualComposition(concrete!(u32)),
						construction_args: ConstructionArgs::Nodes(vec![(NodeId(14), false)]),
						..Default::default()
					},
				),
				(
					NodeId(11),
					ProtoNode {
						identifier: "add".into(),
						input: ProtoNodeInput::Node(NodeId(10)),
						construction_args: ConstructionArgs::Nodes(vec![]),
						..Default::default()
					},
				),
				(
					NodeId(14),
					ProtoNode {
						identifier: "value".into(),
						input: ProtoNodeInput::None,
						construction_args: ConstructionArgs::Value(value::TaggedValue::U32(2).into()),
						..Default::default()
					},
				),
			]
			.into_iter()
			.collect(),
		}
	}

	/// Two `id` nodes whose inputs point at each other, forming a cycle.
	fn test_network_with_cycles() -> ProtoNetwork {
		ProtoNetwork {
			inputs: vec![NodeId(1)],
			output: NodeId(1),
			nodes: [
				(
					NodeId(1),
					ProtoNode {
						identifier: "id".into(),
						input: ProtoNodeInput::Node(NodeId(2)),
						construction_args: ConstructionArgs::Nodes(vec![]),
						..Default::default()
					},
				),
				(
					NodeId(2),
					ProtoNode {
						identifier: "id".into(),
						input: ProtoNodeInput::Node(NodeId(1)),
						construction_args: ConstructionArgs::Nodes(vec![]),
						..Default::default()
					},
				),
			]
			.into_iter()
			.collect(),
		}
	}
}
|
|