|
import hashlib |
|
import copy |
|
import json |
|
from citekit.prompt.prompt import Prompt, ALCEDocPrompt,DocPrompt,NewALCEVanillaPrompt |
|
|
|
# Best-effort load of the sample results shipped next to the working directory.
# A missing, unreadable or malformed 'res.json' simply yields an empty list;
# OSError covers file-system problems, ValueError covers invalid JSON and
# undecodable bytes (json.JSONDecodeError and UnicodeDecodeError subclass it).
try:
    with open('res.json') as _results_file:
        example_results = json.load(_results_file)
except (OSError, ValueError):
    example_results = []
|
|
|
|
|
def weight_check(module, target):
    """Return the sign of the edge that leads from ``module`` to ``target``.

    Only modules whose ``model_type`` is ``'verifier'`` are probed; every other
    module yields 0.  A deep copy of the verifier is asked where it would send
    its result, first with ``last_message`` set to True and then to False:

    * 1  -- ``target`` is the destination for ``last_message = True``
    * -1 -- ``target`` is the destination for ``last_message = False``
    * 0  -- neither probe matches

    Two destinations match when their string forms are equal, or when both
    string forms contain the word "output" (case-insensitive).
    """
    if module.model_type != 'verifier':
        return 0

    def _same_destination(candidate):
        wanted, actual = str(target), str(candidate)
        if wanted == actual:
            return True
        return 'output' in wanted.lower() and 'output' in actual.lower()

    # Work on a copy so that toggling ``last_message`` never touches the caller's module.
    probe = copy.deepcopy(module)
    for last_message, sign in ((True, 1), (False, -1)):
        probe.last_message = last_message
        if _same_destination(probe.send()):
            return sign
    return 0
|
|
|
def get_params(module):
    """Collect the displayable parameters of a pipeline module.

    The result is a dict with a subset of the keys ``'type'``, ``'Model'``,
    ``'Mode'``, ``'Max Turn'``, ``'Destination'``, ``'Prompt'`` and
    ``'Global Prompt'``; entries whose value is falsy are left out.

    Args:
        module: Object exposing ``model_type`` and ``get_destinations()``.
            ``prompt_maker``, ``self_prompt``, ``model_name``, ``max_turn``,
            ``iterative``/``parallel`` and ``if_add_output_to_head``/``head_key``
            are read when present.

    Returns:
        dict: The non-empty parameters described above.
    """
    import logging

    prompt_maker = getattr(module, 'prompt_maker', None)

    # Deep-copy so that merging ``self_prompt`` never mutates the prompt maker's own components.
    components = copy.deepcopy(getattr(prompt_maker, 'components', None))
    if components:
        if isinstance(components, tuple):
            components = components[0]
        components.update(module.self_prompt)

    destinations = module.get_destinations()
    if not destinations:
        # Covers an empty collection as well as a module that reports None.
        destination = 'N/A'
    elif len(destinations) == 1:
        destination = str(destinations[0])
    else:
        destination = 'Not Available for multiple destinations'

    mode = None
    if hasattr(module, 'iterative'):
        if module.iterative:
            mode = 'iterative'
        elif getattr(module, 'parallel', False):
            # ``parallel`` is optional: a module may define ``iterative`` alone.
            mode = 'parallel'

    global_prompt = 'N/A'
    if hasattr(module, 'if_add_output_to_head') and hasattr(module, 'head_key'):
        # Equality (not truthiness) is kept on purpose: only False/0 disable
        # the head key, a None flag still reports it.
        if not (module.if_add_output_to_head == False):  # noqa: E712
            global_prompt = module.head_key

    params = {
        'type': module.model_type,
        'Model': getattr(module, 'model_name', None),
        'Mode': mode,
        'Max Turn': getattr(module, 'max_turn', None),
        'Destination': destination,
        'Prompt': prompt_maker.make_prompt(components) if prompt_maker else None,
        'Global Prompt': global_prompt,
    }
    logging.getLogger(__name__).debug('module params: %s', params)

    return {key: value for key, value in params.items() if value}
|
|
|
class PipelineGraph:
    """Directed graph describing how the modules of a pipeline are connected.

    Attributes:
        pipeline: The pipeline the graph was built from, or None.
        nodes: Maps a node name to its parameter dict.  A fresh graph always
            contains the two terminal nodes ``'input'`` and ``'output'``.
        edges: Maps a source node name to ``{target name: weight}``; weights
            are restricted to 0, 1 and -1 (see ``add_edge``).
        node_count: Counter backing ``get_auto_node_name``.
    """

    def __init__(self, pipeline=None):
        """Create an empty graph, or build one from ``pipeline`` when given."""
        self.pipeline = pipeline
        self.nodes = {'input': {}, 'output': {}}
        self.edges = {}
        self.node_count = 0
        if pipeline:
            self.load_pipeline(pipeline)
            import logging
            logging.getLogger(__name__).debug(
                'FINDING ATTR: %s',
                [str(module) for module in getattr(pipeline, 'module', [])],
            )

    def get_auto_node_name(self):
        """Return the next automatic node name (``Node1``, ``Node2``, ...)."""
        self.node_count += 1
        return f"Node{self.node_count}"

    def update(self):
        """Rebuild the whole graph from the pipeline stored on this instance."""
        self.__init__(pipeline=self.pipeline)

    # ------------------------------------------------------------------
    # Building the graph from a pipeline
    # ------------------------------------------------------------------

    def load_pipeline(self, pipeline):
        """Populate the graph from an existing pipeline object.

        The walk starts at ``pipeline.get_initial_module()`` and falls back to
        ``pipeline.llm`` when no initial module is reported.
        """
        initial_module = pipeline.get_initial_module()
        if not initial_module:
            initial_module = pipeline.llm

        self.add_node(name=str(initial_module),
                      params=get_params(module=initial_module))
        self.add_edge(from_node='input', to_node=str(initial_module), weight=0)

        processed_nodes = [initial_module]
        self._link_destinations(initial_module, processed_nodes)

    def load_node(self, module, processed_nodes=None):
        """Add ``module`` and, recursively, everything reachable from it.

        Args:
            module: The pipeline module to add.
            processed_nodes: Modules already visited; used to stop at cycles.
                A new list is started when omitted.
        """
        if processed_nodes is None:
            processed_nodes = []
        if module in processed_nodes:
            return
        self.add_node(name=str(module), params=get_params(module=module))
        processed_nodes.append(module)
        self._link_destinations(module, processed_nodes)

    def _link_destinations(self, module, processed_nodes):
        """Load every destination of ``module`` and connect ``module`` to it.

        Edge weights come from ``weight_check`` for both the module-to-module
        edges and the conditional edge to ``'output'``, so the initial module
        and all later modules are treated the same way.
        """
        source = str(module)
        subnodes = module.get_destinations()
        if not subnodes:
            # A module without destinations feeds the final output directly.
            self.add_edge(from_node=source, to_node='output', weight=0)
            return

        for subnode in subnodes:
            # The destination must exist as a node before an edge can point at it.
            self.load_node(subnode, processed_nodes=processed_nodes)
            self.add_edge(from_node=source, to_node=str(subnode),
                          weight=weight_check(module, subnode))

        if getattr(module, 'output_cond', None):
            self.add_edge(from_node=source, to_node='output',
                          weight=weight_check(module, 'output'))

    # ------------------------------------------------------------------
    # Import / export
    # ------------------------------------------------------------------

    def export(self):
        """Return the graph as ``{'nodes': {...}, 'edges': [...]}``.

        Node parameter dicts are shallow copies; each edge is a dict with the
        keys ``'from'``, ``'to'`` and ``'weight'``.
        """
        nodes_export = {name: params.copy() for name, params in self.nodes.items()}
        edges_export = [
            {'from': from_node, 'to': to_node, 'weight': weight}
            for from_node, connections in self.edges.items()
            for to_node, weight in connections.items()
        ]
        return {'nodes': nodes_export, 'edges': edges_export}

    @classmethod
    def import_from_dict(cls, data):
        """Create a new graph from a dict in the format produced by ``export``."""
        graph = cls()
        # Start from a clean slate: only the nodes listed in ``data`` are kept.
        graph.nodes = {}
        for name, params in data['nodes'].items():
            graph.add_node(name, params)
        for edge in data['edges']:
            graph.add_edge(edge['from'], edge['to'], edge['weight'])
        return graph

    # ------------------------------------------------------------------
    # Node / edge editing
    # ------------------------------------------------------------------

    def add_node(self, name, params=None):
        """Add a node, or replace it if it exists; ``params`` is copied."""
        if params is None:
            params = {}
        self.nodes[name] = params.copy()

    def remove_node(self, name):
        """Remove a node together with every edge that touches it.

        Raises:
            ValueError: If ``name`` is ``'input'`` or ``'output'``.
        """
        if name in ['input', 'output']:
            raise ValueError("Cannot remove 'input' or 'output' nodes.")
        if name not in self.nodes:
            return
        del self.nodes[name]

        # Outgoing edges.
        if name in self.edges:
            del self.edges[name]

        # Incoming edges; iterate over a snapshot because sources may be dropped.
        for from_node in list(self.edges.keys()):
            if name in self.edges[from_node]:
                del self.edges[from_node][name]
                if not self.edges[from_node]:
                    del self.edges[from_node]

    def add_edge(self, from_node, to_node, weight):
        """Add or overwrite the edge ``from_node -> to_node``.

        Raises:
            ValueError: If ``weight`` is not 0, -1 or 1.
            KeyError: If either endpoint is not an existing node.
        """
        if weight not in {0, -1, 1}:
            raise ValueError("Weight must be 0, -1, or 1.")
        if from_node not in self.nodes:
            raise KeyError(f"Node '{from_node}' does not exist.")
        if to_node not in self.nodes:
            raise KeyError(f"Node '{to_node}' does not exist.")
        self.edges.setdefault(from_node, {})[to_node] = weight

    def remove_edge(self, from_node, to_node):
        """Remove the given edge if it exists; unknown edges are ignored."""
        if from_node in self.edges and to_node in self.edges[from_node]:
            del self.edges[from_node][to_node]
            # Do not keep empty adjacency dicts around.
            if not self.edges[from_node]:
                del self.edges[from_node]

    def update_node_params(self, name, params):
        """Merge ``params`` into the parameters of an existing node.

        Raises:
            KeyError: If the node does not exist.
        """
        if name not in self.nodes:
            raise KeyError(f"Node '{name}' does not exist.")
        self.nodes[name].update(params)

    # ------------------------------------------------------------------
    # Visualisation
    # ------------------------------------------------------------------

    def visualize(self, use_graphviz=True):
        """Render the graph.

        With ``use_graphviz`` the ``graphviz`` package is used to render
        ``pipeline_graph`` and open it in a viewer; if the package cannot be
        imported, or ``use_graphviz`` is False, a text view is printed instead.
        """
        if not use_graphviz:
            self._visualize_text()
            return

        try:
            import graphviz
        except ImportError:
            print("Graphviz未安装,切换到文本模式。安装命令:pip install graphviz")
            return self._visualize_text()

        dot = graphviz.Digraph(comment='Pipeline Graph', graph_attr={'rankdir': 'LR'})

        for node_name, params in self.nodes.items():
            params_str = ", ".join(f"{k}={v}" for k, v in params.items())
            label = f"{node_name}\n({params_str})" if params_str else node_name
            if node_name in ('input', 'output'):
                # Terminal nodes get their own shape and colour.
                dot.node(node_name, label=label, shape='doublecircle',
                         style='filled', fillcolor='#e6f3ff')
            else:
                dot.node(node_name, label=label, shape='box',
                         style='rounded,filled', fillcolor='#f0f0f0')

        for from_node, connections in self.edges.items():
            for to_node, weight in connections.items():
                dot.edge(from_node, to_node, label=str(weight),
                         color='#666666', fontcolor='#ff5555')

        dot.render('pipeline_graph', view=True, cleanup=True)

    def _visualize_text(self):
        """Print a plain-text view of the nodes and edges (fallback renderer)."""
        arrows = {
            1: "──(+)──>",
            -1: "──(-)──>",
            0: "───────>",
        }

        print("="*40 + "\nPipeline Graph 文本视图\n" + "="*40)

        print("\n[节点列表]")
        for node, params in self.nodes.items():
            param_desc = " ".join(f"{k}={v}" for k, v in params.items())
            print(f"· {node:10} {param_desc}")

        print("\n[边列表]")
        if not self.edges:
            print("(暂无边连接)")
        else:
            for from_node, connections in self.edges.items():
                for to_node, weight in connections.items():
                    arrow = arrows[weight]
                    print(f"{from_node:10} {arrow} {to_node}")
        print("="*40 + "\n")

    # ------------------------------------------------------------------
    # JSON / HTML output
    # ------------------------------------------------------------------

    def _nodes_data(self):
        """Return the nodes as a list of ``{'id', 'type', 'params'}`` dicts."""
        return [
            {"id": name, "type": params.get("type", name), "params": params}
            for name, params in self.nodes.items()
        ]

    def _edges_data(self):
        """Return the edges as a list of ``{'source', 'target', 'weight'}`` dicts."""
        return [
            {"source": from_node, "target": to_node, "weight": weight}
            for from_node, to_nodes in self.edges.items()
            for to_node, weight in to_nodes.items()
        ]

    def _fill_template(self, template_path, results, extra_replacements):
        """Read an HTML template and substitute the graph placeholders.

        ``<NODE_JS>`` and ``<EDGE_JS>`` are replaced first, then the
        ``(placeholder, value)`` pairs in ``extra_replacements``, and finally
        ``<OPTIONS>`` with one ``<option>`` element per entry of ``results``.
        """
        with open(template_path, 'r', encoding='utf-8') as f:
            template = f.read()

        options = "".join(
            f'<option value="{i}">Result {i+1}</option>' for i in range(len(results))
        )
        substitutions = [
            ('<NODE_JS>', json.dumps(self._nodes_data())),
            ('<EDGE_JS>', json.dumps(self._edges_data())),
            *extra_replacements,
            ('<OPTIONS>', options),
        ]
        for placeholder, value in substitutions:
            template = template.replace(placeholder, value)
        return template

    def generate_html(self, results=None):
        """Build the interactive visualisation page that loads results by path.

        Args:
            results: Path of a JSON results file.  The path itself is written
                into the page (``<FILE_PATH>``) and the file is read to count
                the results.  When omitted, the module-level
                ``example_results`` is used, which only passes the type check
                if it happens to be a string.

        Returns:
            str: The filled-in content of ``pipeline_template.txt``.
        """
        if results is None:
            results = example_results
        assert isinstance(results, str), "results must be the path of a JSON file"

        result_path = results
        with open(result_path) as f:
            results = json.load(f)

        return self._fill_template(
            'htmls/html_templates/pipeline_template.txt',
            results,
            [('<FILE_PATH>', result_path)],
        )

    def get_nodes(self):
        """Return the nodes as a JSON string (list of id/type/params objects)."""
        return json.dumps(self._nodes_data())

    def get_edges(self):
        """Return the edges as a JSON string (list of source/target/weight objects)."""
        return json.dumps(self._edges_data())

    def get_json(self):
        """Return ``{'nodes': <JSON string>, 'edges': <JSON string>}``."""
        return {
            'nodes': self.get_nodes(),
            'edges': self.get_edges()
        }

    def generate_html_embed(self, results=None):
        """Build the interactive visualisation page with the results embedded.

        Args:
            results: Either a list of results or the path of a JSON file
                containing them.  Defaults to the module-level
                ``example_results``.

        Returns:
            str: The filled-in content of ``pipeline_template_emb.txt``.
        """
        if results is None:
            results = example_results
        assert isinstance(results, (str, list)), "results must be a path or a list"

        if isinstance(results, str):
            with open(results) as f:
                results = json.load(f)

        return self._fill_template(
            'htmls/html_templates/pipeline_template_emb.txt',
            results,
            [('<RESULTS>', json.dumps(results))],
        )
|
|
|
if __name__ == '__main__':
    # Small manual demo: build a three-edge graph by hand and render it.
    demo_nodes = (
        ('processor', {'type': 'transform', 'rate': 0.8}),
        ('validator', {'threshold': 0.5}),
    )
    demo_edges = (
        ('input', 'processor', 1),
        ('processor', 'validator', -1),
        ('validator', 'output', 0),
    )

    graph = PipelineGraph()
    for node_name, node_params in demo_nodes:
        graph.add_node(node_name, node_params)
    for source, target, sign in demo_edges:
        graph.add_edge(source, target, sign)

    # Show that parameters of an existing node can be changed afterwards.
    graph.update_node_params('processor', {'rate': 1.0})

    graph.visualize()