File size: 2,589 Bytes
3dfe8fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include <map>
#include <string>
#include <iostream>
#include <memory>
#include <mutex>

// Tree node holding visit statistics (visit count and value sum) and a prior
// probability. Field names suggest it is used by an MCTS-style search — this
// is an assumption, not something this file establishes.
//
// Ownership: a Node owns every child pointer stored in `children` and deletes
// them in its destructor, so children must be allocated with `new` and must not
// be deleted by anyone else. The `parent` pointer is non-owning.
// Copying is disabled: a copy would hold the same child pointers as the
// original, and both destructors would delete them (double free).
class Node {
public:
    // Constructs a node with a non-owning parent pointer and a prior
    // probability. Visit count and value sum start at zero.
    Node(Node* parent = nullptr, float prior_p = 1.0f)
        : parent(parent), prior_p(prior_p), visit_count(0), value_sum(0.0f) {}

    // Non-copyable: copies would alias, and later double-delete, the children.
    Node(const Node&) = delete;
    Node& operator=(const Node&) = delete;

    // Deletes every owned child; each child's destructor in turn deletes its
    // own children, so the whole subtree is freed.
    ~Node() {
        for (auto& entry : children) {
            delete entry.second;
        }
    }

    // Returns the mean value (value_sum / visit_count), or 0 if never visited.
    float get_value() const {
        return visit_count == 0 ? 0.0f : value_sum / visit_count;
    }

    // Records one visit with the given value.
    void update(float value) {
        ++visit_count;
        value_sum += value;
    }

    // Propagates leaf_value from this node up to the root, calling update() on
    // every node along the way (this node first, the root last).
    //  - "self_play_mode": the value's sign flips at each level, so each
    //    node's parent receives the negation of what the node received.
    //  - "play_with_bot_mode": every node on the path receives leaf_value as is.
    //  - Any other mode string: no node is updated.
    // The mode is resolved once and the path is walked iteratively. This avoids
    // per-level string copies/comparisons and unbounded recursion depth.
    void update_recursive(float leaf_value, const std::string& battle_mode_in_simulation_env) {
        bool flip_sign;
        if (battle_mode_in_simulation_env == "self_play_mode") {
            flip_sign = true;
        } else if (battle_mode_in_simulation_env == "play_with_bot_mode") {
            flip_sign = false;
        } else {
            return;  // Unrecognized mode: leave all statistics unchanged.
        }

        float value = leaf_value;
        for (Node* node = this; node != nullptr; node = node->parent) {
            node->update(value);
            if (flip_sign) {
                value = -value;
            }
        }
    }

    // Returns true if the node has no children.
    bool is_leaf() const {
        return children.empty();
    }

    // Returns true if the node has no parent.
    bool is_root() const {
        return parent == nullptr;
    }

    // Returns the (non-owning) parent pointer; nullptr for the root.
    Node* get_parent() const {
        return parent;
    }

    // Returns a copy of the action -> child map. The child pointers in the
    // copy are still owned by this node; callers must not delete them.
    std::map<int, Node*> get_children() const {
        return children;
    }

    // Returns the number of times update() has been called on this node.
    int get_visit_count() const {
        return visit_count;
    }

    // Stores `node` as the child for `action` and takes ownership of it.
    // If a different child was already stored under `action`, that previous
    // child (and its subtree) is deleted so it does not leak. Re-adding the
    // same pointer is a no-op. This does not set node->parent.
    void add_child(int action, Node* node) {
        auto it = children.find(action);
        if (it == children.end()) {
            children.emplace(action, node);
            return;
        }
        if (it->second != node) {
            Node* previous = it->second;
            it->second = node;
            delete previous;
        }
    }

public:
    Node* parent;  // Pointer to the parent node (non-owning; nullptr for the root)
    float prior_p;  // Prior probability of the node
    int visit_count;  // Number of update() calls on this node
    float value_sum;  // Sum of all values passed to update()
    std::map<int, Node*> children;  // Owned child nodes keyed by action
};