Source code for chipiron.players.move_selector.treevalue.nodes.tree_traversal

"""
This module provides functions for traversing a tree of nodes.

The functions in this module allow you to retrieve descendants of a given node in a tree structure.
"""

from typing import Any

from .algorithm_node import AlgorithmNode
from .itree_node import ITreeNode


[docs]def get_descendants(from_tree_node: ITreeNode[Any]) -> dict[ITreeNode[Any], None]: """ Get all descendants of a given tree node. Args: from_tree_node (ITreeNode): The starting tree node. Returns: dict[ITreeNode, None]: A dictionary containing all descendants of the starting tree node. """ des: dict[ITreeNode[Any], None] = {from_tree_node: None} # include itself generation: set[ITreeNode[Any]] = set( [node for node in from_tree_node.moves_children.values() if node is not None] ) while generation: next_depth_generation: set[ITreeNode[Any]] = set() for node in generation: assert node is not None des[node] = None for _, next_generation_child in node.moves_children.items(): if next_generation_child is not None: next_depth_generation.add(next_generation_child) generation = next_depth_generation return des
[docs]def get_descendants_candidate_to_open( from_tree_node: AlgorithmNode, max_depth: int | None = None ) -> list[AlgorithmNode]: """ Get descendants of a given tree node that are not over. Args: from_tree_node (AlgorithmNode): The starting tree node. max_depth (int | None, optional): The maximum depth to traverse. Defaults to None. Returns: list[AlgorithmNode]: A list of descendants that are not over. """ if not from_tree_node.all_legal_moves_generated and not from_tree_node.is_over(): # should use are_all_moves_and_children_opened() but its messy! # also using is_over is messy as over_events are defined in a child class!!! des = {from_tree_node: None} # include itself maybe else: des = {} generation: set[ITreeNode[Any]] = set( [node for node in from_tree_node.moves_children.values() if node is not None] ) depth: int = 1 assert max_depth is not None while generation and depth <= max_depth: next_depth_generation: set[ITreeNode[Any]] = set() for node in generation: assert isinstance(node, AlgorithmNode) if not node.all_legal_moves_generated and not node.is_over(): des[node] = None for _, next_generation_child in node.moves_children.items(): if next_generation_child is not None: next_depth_generation.add(next_generation_child) generation = next_depth_generation return list(des.keys())
[docs]def get_descendants_candidate_not_over( from_tree_node: ITreeNode[Any], max_depth: int | None = None ) -> list[ITreeNode[Any]]: """ Get descendants of a given tree node that are not over. Args: from_tree_node (ITreeNode): The starting tree node. max_depth (int | None, optional): The maximum depth to traverse. Defaults to None. Returns: list[ITreeNode]: A list of descendants that are not over. """ assert not from_tree_node.is_over() if not from_tree_node.moves_children: return [from_tree_node] des: dict[ITreeNode[Any], None] = {} generation: set[ITreeNode[Any]] = set( [node for node in from_tree_node.moves_children.values() if node is not None] ) depth: int = 1 assert max_depth is not None while generation and depth <= max_depth: next_depth_generation: set[ITreeNode[Any]] = set() for node in generation: assert isinstance(node, AlgorithmNode) if not node.is_over(): des[node] = None for _, next_generation_child in node.moves_children.items(): if next_generation_child is not None: next_depth_generation.add(next_generation_child) generation = next_depth_generation return list(des.keys())