Spaces:
Runtime error
Runtime error
File size: 6,777 Bytes
7b127f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import glob
import ray
import numpy as np
import argparse
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface
from OCC.Core.BRepGProp import brepgprop
from OCC.Core.BRepLProp import BRepLProp_SLProps
from OCC.Core.GProp import GProp_GProps
from lightning_fabric import seed_everything
from eval.eval_condition import *
import networkx as nx
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.TopAbs import TopAbs_VERTEX, TopAbs_EDGE
from OCC.Core.BRep import BRep_Tool
from OCC.Core.gp import gp_Pnt
def remove_outliers_zscore(data, threshold=3, max_value=50):
if len(data) == 0 or sum(data) == 0:
return data
mean = np.mean(data)
std_dev = np.std(data)
return [x for x in data if abs((x - mean) / (std_dev + 1e-8)) <= threshold and x < max_value]
def extract_edges_and_vertices(shape):
explorer_edges = TopExp_Explorer(shape, TopAbs_EDGE)
explorer_vertices = TopExp_Explorer(shape, TopAbs_VERTEX)
vertex_map = {}
edges = []
while explorer_edges.More():
edge = explorer_edges.Current()
vertices_on_edge = []
vertex_explorer = TopExp_Explorer(edge, TopAbs_VERTEX)
while vertex_explorer.More():
vertex = vertex_explorer.Current()
point = BRep_Tool.Pnt(vertex)
coord = (round(point.X(), 6), round(point.Y(), 6), round(point.Z(), 6))
if coord not in vertex_map:
vertex_map[coord] = len(vertex_map)
vertices_on_edge.append(vertex_map[coord])
vertex_explorer.Next()
if len(vertices_on_edge) == 2:
edges.append(tuple(vertices_on_edge))
explorer_edges.Next()
return vertex_map, edges
def create_nx_graph(vertex_map, edges):
graph = nx.Graph()
for coord, node_id in vertex_map.items():
graph.add_node(node_id, coord=coord)
for edge in edges:
graph.add_edge(edge[0], edge[1])
return graph
def calculate_cyclomatic_complexity(graph):
num_nodes = graph.number_of_nodes() # N
num_edges = graph.number_of_edges() # E
if graph.is_directed():
num_components = nx.number_strongly_connected_components(graph)
else:
num_components = nx.number_connected_components(graph)
# M = E - N + 2P
cyclomatic_complexity = num_edges - num_nodes + 2 * num_components
return cyclomatic_complexity
def eval_complexity_one(step_file_path):
isvalid, shape = check_step_valid_soild(step_file_path, return_shape=True)
if not isvalid:
return None
vertex_map, edges = extract_edges_and_vertices(shape)
graph = create_nx_graph(vertex_map, edges)
cyclomatic_complexity = calculate_cyclomatic_complexity(graph)
face_list = get_primitives(shape, TopAbs_FACE)
num_face = len(face_list)
num_edge = len(vertex_map.keys())
num_vertex = len(edges)
sample_point_curvature = []
num_samples = 256
for face in face_list:
surf_adaptor = BRepAdaptor_Surface(face)
u_min, u_max, v_min, v_max = (surf_adaptor.FirstUParameter(), surf_adaptor.LastUParameter(), surf_adaptor.FirstVParameter(),
surf_adaptor.LastVParameter())
u_samples = np.linspace(u_min, u_max, int(np.sqrt(num_samples)))
v_samples = np.linspace(v_min, v_max, int(np.sqrt(num_samples)))
face_sample_point_curvature = []
for u in u_samples:
for v in v_samples:
props = BRepLProp_SLProps(surf_adaptor, u, v, 2, 1e-8)
if props.IsCurvatureDefined():
mean_curvature = props.MeanCurvature()
face_sample_point_curvature.append(abs(mean_curvature))
face_sample_point_curvature = remove_outliers_zscore(face_sample_point_curvature)
if len(face_sample_point_curvature) == 0:
continue
sample_point_curvature.append(np.median(face_sample_point_curvature))
mean_curvature = np.mean(sample_point_curvature) if len(sample_point_curvature) > 0 else np.nan
if num_face == 0 or mean_curvature == np.nan:
return None
return {
'num_face' : int(num_face),
'num_edge' : int(num_edge),
'num_vertex' : int(num_vertex),
'cyclomatic_complexity': cyclomatic_complexity,
'mean_curvature' : mean_curvature,
}
eval_complexity_one_remote = ray.remote(eval_complexity_one)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Evaluate Brep Complexity')
parser.add_argument('--eval_root', type=str)
parser.add_argument('--only_valid', action='store_true')
args = parser.parse_args()
# 设置随机种子
seed_everything(0)
ray.init(ignore_reinit_error=True, local_mode=False)
all_folders = os.listdir(args.eval_root)
is_valid_list = []
futures = []
for folder in tqdm(all_folders):
step_path_list = glob.glob(os.path.join(args.eval_root, folder, '*.step'))
if len(step_path_list) == 0:
is_valid_list.append(False)
futures.append(None)
continue
is_valid_list.append(check_step_valid_soild(step_path_list[0], return_shape=False))
futures.append(eval_complexity_one_remote.remote(step_path_list[0]))
assert len(is_valid_list) == len(futures) == len(all_folders)
all_result = {}
for i, future in enumerate(tqdm(futures)):
if future is None:
continue
result = ray.get(future)
if args.only_valid and not is_valid_list[i]:
continue
all_result[all_folders[i]] = result
num_face_list = []
num_edge_list = []
num_vertex_list = []
cyclomatic_complexity_list = []
mean_curvature_list = []
exception_folder = []
for folder, result in tqdm(all_result.items()):
if result is None:
continue
result = dict(result)
num_face_list.append(result['num_face'])
num_edge_list.append(result['num_edge'])
num_vertex_list.append(result['num_vertex'])
cyclomatic_complexity_list.append(result['cyclomatic_complexity'])
mean_curvature_list.append(result['mean_curvature'])
exception_folder.append(folder)
print(f'Num Face: {np.mean(num_face_list)}')
print(f'Num Edge: {np.mean(num_edge_list)}')
print(f'Num Vertex: {np.mean(num_vertex_list)}')
print(f'Cyclomatic Complexity: {np.mean(cyclomatic_complexity_list)}')
print(f'Mean Curvature: {np.mean(mean_curvature_list)}')
print(f"{np.mean(num_face_list)} {np.mean(num_edge_list)} {np.mean(num_vertex_list)} "
f"{np.mean(cyclomatic_complexity_list)} {np.mean(mean_curvature_list)}")
ray.shutdown()
|