Spaces:
Runtime error
Runtime error
class VQAReTokenRelation(object):
    """Turn raw entity-id links into question->answer relation records.

    Pops ``id2label``, ``empty_entity`` and ``entity_id_to_index_map`` from
    the sample dict. Replaces ``data["relations"]`` with a list of
    ``{"head", "tail", "start_index", "end_index"}`` dicts.
    """

    def __init__(self, **kwargs):
        # No configuration is used; keyword arguments are accepted and ignored.
        pass

    def __call__(self, data):
        """Build question->answer relations for one sample.

        Args:
            data: sample dict. Reads ``"entities"``, which is indexed by the
                values of ``entity_id_to_index_map``; each item must have
                ``"start"`` and ``"end"``. Also reads ``"relations"``, an
                iterable of entity-id pairs (tuples or lists). Pops
                ``"id2label"``, ``"empty_entity"`` and
                ``"entity_id_to_index_map"``.

        Returns:
            The same dict. ``data["relations"]`` is replaced by a list of
            unique dicts ``{"head", "tail", "start_index", "end_index"}``,
            sorted by ``(head, tail)``. ``head`` is always the index of the
            "question" entity and ``tail`` that of the "answer" entity.
            Links between any other label pair, or touching an entity listed
            in ``empty_entity``, are dropped.
        """
        entities = data["entities"]
        raw_relations = data["relations"]
        id2label = data.pop("id2label")
        # Built once so each membership test is O(1).
        empty_entity = set(data.pop("empty_entity"))
        entity_id_to_index_map = data.pop("entity_id_to_index_map")

        # Normalise each link to a hashable tuple (so list-valued pairs work)
        # and drop exact duplicate links.
        links = {tuple(rel) for rel in raw_relations}

        # Collect (head_index, tail_index) pairs oriented question -> answer.
        # A set is used because (q, a) and (a, q) links normalise to the same
        # pair and must not yield duplicate relations.
        kv_pairs = set()
        for rel in links:
            head_id, tail_id = rel[0], rel[1]
            if head_id in empty_entity or tail_id in empty_entity:
                continue
            pair = (id2label[head_id], id2label[tail_id])
            if pair == ("answer", "question"):
                head_id, tail_id = tail_id, head_id
            elif pair != ("question", "answer"):
                continue
            kv_pairs.add(
                (entity_id_to_index_map[head_id], entity_id_to_index_map[tail_id])
            )

        # Sort on (head, tail) so the output order does not depend on set
        # iteration order.
        relations = []
        for head, tail in sorted(kv_pairs):
            start, end = self.get_relation_span({"head": head, "tail": tail}, entities)
            relations.append(
                {
                    "head": head,
                    "tail": tail,
                    "start_index": start,
                    "end_index": end,
                }
            )
        data["relations"] = relations
        return data

    def get_relation_span(self, rel, entities):
        """Return ``(min, max)`` over the start/end bounds of both entities.

        Args:
            rel: dict with integer ``"head"`` and ``"tail"`` indices into
                ``entities``.
            entities: indexable collection whose items have ``"start"`` and
                ``"end"`` keys.
        """
        bound = []
        for entity_index in (rel["head"], rel["tail"]):
            entity = entities[entity_index]
            bound.append(entity["start"])
            bound.append(entity["end"])
        return min(bound), max(bound)