Source code for chipiron.scripts.generate_datasets.generate_boards

import random
import shutil
import subprocess
from datetime import datetime
from pathlib import Path
from typing import Generator

import chess
import chess.pgn
import pandas as pd
import zstandard
from pandas import DataFrame

from chipiron.environments.chess_env.board.utils import fen
from chipiron.utils.path_variables import (  # removed LICHESS_PGN_FILE usage
    EXTERNAL_DATA_DIR,
)

# Sampling configuration variables
# Within one game, a position is stored every DEFAULT_SAMPLING_FREQUENCY moves
# once the per-game random starting offset has been reached.
DEFAULT_SAMPLING_FREQUENCY = 50
# Lower bound of the per-game random starting offset (clamped to the game length).
DEFAULT_OFFSET_MIN = 5
# Passed to random.seed() when not None; None leaves the RNG unseeded.
DEFAULT_RANDOM_SEED: int | None = 0  # Set to an int for reproducibility

# Lichess monthly database parameters
# A monthly archive URL is built as f"{LICHESS_STANDARD_BASE_URL}/{file_name}".
LICHESS_STANDARD_BASE_URL = "https://database.lichess.org/standard"
MONTHLY_FILE_TEMPLATE = "lichess_db_standard_rated_{month}.pgn.zst"  # month = YYYY-MM


def save_dataset_progress(
    the_dic: list[dict[str, fen]],
    output_file_path: str,
    count_game: int,
    total_count_move: int,
    max_boards: int,
    total_games_in_file: int | None,
    total_moves_in_file: int | None,
    input_pgn_file_path: str,
    sampling_frequency: int,
    offset_min: int,
    seed: int | None,
    is_final: bool = False,
    months_used: list[str] | None = None,
) -> int:
    """Save dataset progress (intermediate or final) and display statistics.

    The collected positions are converted to a DataFrame, annotated with
    provenance metadata in ``DataFrame.attrs`` and pickled to
    ``output_file_path``. Missing parent directories of the output file are
    created first.

    Args:
        the_dic: Current list of board positions (one ``{"fen": ...}`` dict each).
        output_file_path: Path where to save the pickle file.
        count_game: Number of games processed so far.
        total_count_move: Number of moves processed so far.
        max_boards: Maximum target board positions. A value <= 0 only
            suppresses the percentage in the progress line.
        total_games_in_file: Total games in source file (optional).
        total_moves_in_file: Total moves in source file (optional).
        input_pgn_file_path: Source PGN file path (stored as metadata).
        sampling_frequency: Sampling frequency for moves (stored as metadata).
        offset_min: Minimum offset for sampling (stored as metadata).
        seed: Random seed used (stored as metadata).
        is_final: Whether this is the final save (adds additional metadata).
        months_used: List of months processed (for dynamic mode); only
            recorded when ``is_final`` is true.

    Returns:
        Number of board positions recorded so far.
    """
    new_data_frame_states: DataFrame = pd.DataFrame(the_dic)
    recorded_board = len(new_data_frame_states.index)

    save_type = "Final" if is_final else "Progress"
    if max_boards > 0:
        print(
            f"{save_type}: {recorded_board:,} / {max_boards:,} board positions collected ({recorded_board / max_boards * 100:.1f}%)"
        )
    else:
        # No meaningful target to compute a percentage against.
        print(f"{save_type}: {recorded_board:,} board positions collected")

    # Enhanced progress with file totals
    games_progress = f"{count_game:,}"
    moves_progress = f"{total_count_move:,}"
    if total_games_in_file:
        games_progress += f" / {total_games_in_file:,} ({count_game / total_games_in_file * 100:.1f}%)"
    if total_moves_in_file:
        moves_progress += f" / {total_moves_in_file:,} ({total_count_move / total_moves_in_file * 100:.1f}%)"
    print(f" Games processed: {games_progress}")
    print(f" Moves processed: {moves_progress}")

    # Add metadata
    attrs = new_data_frame_states.attrs
    attrs["source_pgn_file"] = input_pgn_file_path
    attrs["sampling_frequency"] = sampling_frequency
    attrs["offset_min"] = offset_min
    attrs["offset_max_strategy"] = "per-game total move count"
    attrs["random_seed"] = seed
    attrs["creation_date"] = pd.Timestamp.now().isoformat()
    attrs["filter_criteria"] = (
        "positions sampled at offset + k*frequency (no engine eval filter)"
    )
    attrs["games_processed"] = count_game
    attrs["moves_processed"] = total_count_move

    # Additional metadata for final saves
    if is_final:
        if months_used:
            attrs["months_used"] = months_used
            attrs["source_pgn_mode"] = "dynamic_monthly"
        attrs["total_games_processed"] = count_game
        attrs["total_moves_processed"] = total_count_move
        attrs["final_dataset_size"] = recorded_board

    # The first checkpoint may come after a long download/parse phase; do not
    # let it fail just because the target folder has not been created yet.
    Path(output_file_path).parent.mkdir(parents=True, exist_ok=True)
    new_data_frame_states.to_pickle(output_file_path)

    if is_final and months_used:
        print(
            f"Final dataset saved ({recorded_board} positions) using months {months_used} -> {output_file_path}"
        )
    return recorded_board
def process_game(
    game: chess.pgn.GameNode,
    total_count_move: int,
    the_dic: list[dict[str, fen]],
    sampling_frequency: int,
    offset_min: int = DEFAULT_OFFSET_MIN,
) -> int:
    """Process a single game and extract positions (no eval requirement).

    A random starting offset is drawn once per game, between ``offset_min``
    (clamped to ``[0, game length]``) and the game length. The position reached
    after move number ``offset``, ``offset + sampling_frequency``,
    ``offset + 2 * sampling_frequency``, ... is appended to ``the_dic`` as
    ``{"fen": <FEN string>}``. Move numbers start at 1, so the position before
    the first mainline move is never sampled.

    Args:
        game: Node whose mainline is replayed, starting from ``game.board()``.
        total_count_move: Running count of moves processed before this game.
        the_dic: Output list, extended in place with the sampled positions.
        sampling_frequency: Distance in moves between two sampled positions.
        offset_min: Smallest allowed starting offset; negative values count as 0.

    Returns:
        ``total_count_move`` increased by the number of mainline moves replayed.
    """
    chess_board: chess.Board = game.board()

    # Materialize moves to know game length
    moves_list = list(game.mainline_moves())
    game_total_moves = len(moves_list)

    # Clamping into [0, game_total_moves] guarantees a valid randint range, so
    # no separate "degenerate" branch is needed. Exactly one draw is made per
    # game, which keeps seeded runs reproducible.
    effective_min = min(max(offset_min, 0), game_total_moves)
    random_offset = random.randint(effective_min, game_total_moves)

    for game_move_count, move in enumerate(moves_list, start=1):
        chess_board.push(move)
        # Sample at: offset, offset + frequency, offset + 2*frequency, etc.
        if (
            game_move_count >= random_offset
            and (game_move_count - random_offset) % sampling_frequency == 0
        ):
            # Only store FEN strings since that's all we need
            the_dic.append({"fen": chess_board.fen()})

    return total_count_move + game_total_moves
# --- Dynamic monthly download helpers ---
def iterate_months(start_month: str) -> Generator[str, None, None]:
    """Yield consecutive month strings (YYYY-MM), beginning with start_month, without end.

    Args:
        start_month: First month to yield, formatted as ``YYYY-MM``.

    Yields:
        ``start_month`` first, then each following calendar month in order.
    """
    current = datetime.strptime(start_month, "%Y-%m")
    while True:
        yield current.strftime("%Y-%m")
        # Step to the next calendar month, rolling over into January.
        if current.month == 12:
            current = current.replace(year=current.year + 1, month=1)
        else:
            current = current.replace(month=current.month + 1)
def download_month_zst(month: str, dest_dir: Path) -> Path:
    """Download the compressed monthly PGN (.zst) file for a given month.

    If the archive is already present in ``dest_dir`` its size is compared with
    the ``Content-Length`` reported by the server; a matching file is reused,
    anything else is downloaded again.

    Args:
        month: Month to fetch, formatted as ``YYYY-MM``.
        dest_dir: Directory receiving the archive (created if missing).

    Returns:
        Path of the local ``.pgn.zst`` file.
    """
    # Must be bound before any branch: a function-level import makes `urllib`
    # a local name, so importing it only inside the "file exists" branch left
    # it unbound (UnboundLocalError) on the fresh-download path.
    import urllib.request

    dest_dir.mkdir(parents=True, exist_ok=True)
    file_name = MONTHLY_FILE_TEMPLATE.format(month=month)
    url = f"{LICHESS_STANDARD_BASE_URL}/{file_name}"
    local_path = dest_dir / file_name

    # Check if file exists and verify its size
    if local_path.exists():
        try:
            # Get remote file size to compare
            with urllib.request.urlopen(url) as response:
                remote_size = int(response.headers.get("Content-Length", 0))
            local_size = local_path.stat().st_size
            if remote_size > 0 and local_size == remote_size:
                print(f"Compressed file already exists for {month}: {local_path}")
                return local_path
            print(
                f"Incomplete file detected for {month} (local: {local_size}, remote: {remote_size}). Re-downloading..."
            )
        except Exception as e:
            # Deliberately best-effort: any failure to verify just triggers a
            # fresh download below.
            print(f"Could not verify file size for {month}: {e}. Re-downloading...")

    print(f"Downloading {url} -> {local_path}")

    def progress_hook(block_num: int, block_size: int, total_size: int) -> None:
        """Print a single-line progress indicator while downloading."""
        if total_size > 0:
            downloaded = block_num * block_size
            percent = min(100.0, downloaded * 100.0 / total_size)
            downloaded_mb = downloaded / (1024 * 1024)
            total_mb = total_size / (1024 * 1024)
            print(
                f"\rProgress: {percent:.1f}% ({downloaded_mb:.1f}/{total_mb:.1f} MB)",
                end="",
                flush=True,
            )

    urllib.request.urlretrieve(url, local_path, reporthook=progress_hook)
    print()  # New line after download completes
    return local_path
def decompress_zst(zst_path: Path, output_pgn_path: Path) -> None:
    """Decompress .zst to .pgn using zstd command line or Python fallback.

    Data is first written to a sibling ``<name>.part`` file and renamed to
    ``output_pgn_path`` only once decompression succeeded. A failed or
    interrupted run therefore never leaves a truncated PGN behind that a later
    run would mistake for a complete one.

    Args:
        zst_path: Compressed source archive.
        output_pgn_path: Destination of the decompressed PGN; if it already
            exists nothing is done.

    Raises:
        RuntimeError: If the Python zstandard fallback fails.
    """
    if output_pgn_path.exists():
        print(f"Decompressed PGN already present: {output_pgn_path}")
        return

    print(f"Decompressing {zst_path.name}...")
    partial_path = output_pgn_path.with_name(output_pgn_path.name + ".part")
    try:
        _decompress_zst_into(zst_path, partial_path)
        partial_path.replace(output_pgn_path)
    finally:
        # After a successful rename this is a no-op; on any failure (including
        # KeyboardInterrupt) it removes the incomplete output.
        partial_path.unlink(missing_ok=True)


def _decompress_zst_into(zst_path: Path, target_path: Path) -> None:
    """Write the decompressed content of zst_path to target_path (CLI first, then Python)."""
    # Try zstd command first
    if shutil.which("zstd"):
        try:
            cmd = ["zstd", "-d", "-f", "-o", str(target_path), str(zst_path)]
            print(f"Running: {' '.join(cmd)}")
            subprocess.run(cmd, check=True, capture_output=True, text=True)
            print("Decompression completed with zstd command")
            return
        except subprocess.CalledProcessError as e:
            print(f"zstd command failed (exit code {e.returncode}): {e.stderr}")
            print("Falling back to Python decompression...")

    # Fall back to Python zstandard
    try:
        print("Using Python zstandard library for decompression...")
        dctx = zstandard.ZstdDecompressor()
        with zst_path.open("rb") as src, target_path.open("wb") as dst:
            dctx.copy_stream(src, dst)
        print("Decompression completed with Python zstandard")
    except Exception as exc:
        raise RuntimeError(f"Both zstd CLI and Python zstandard failed: {exc}") from exc
def ensure_month_pgn(
    month: str, dest_dir: Path, keep_decompressed: bool = False
) -> Path:
    """Return the path of the decompressed PGN for a month, fetching it if absent.

    When the ``.pgn`` file is not yet in ``dest_dir`` the monthly archive is
    downloaded and decompressed, after which the archive itself is removed.

    Args:
        month: Month to provide, formatted as ``YYYY-MM``.
        dest_dir: Directory holding the archive and the decompressed PGN.
        keep_decompressed: Accepted for interface compatibility; not read by
            this function.

    Returns:
        Path of the decompressed ``.pgn`` file.
    """
    archive_name = MONTHLY_FILE_TEMPLATE.format(month=month)
    pgn_path = dest_dir / archive_name.replace(".pgn.zst", ".pgn")

    if not pgn_path.exists():
        # Nothing decompressed yet: fetch the archive and unpack it.
        archive_path = download_month_zst(month, dest_dir)
        decompress_zst(archive_path, pgn_path)

        # The archive is never kept once it has been unpacked.
        if archive_path.exists():
            archive_path.unlink(missing_ok=True)
            print(f"Deleted compressed file: {archive_path}")
        return pgn_path

    print(f"Decompressed PGN already exists for {month}: {pgn_path}")
    return pgn_path
def generate_board_dataset_multi_months(
    output_file_path: str,
    max_boards: int = 10_000_000,
    sampling_frequency: int = DEFAULT_SAMPLING_FREQUENCY,
    offset_min: int = DEFAULT_OFFSET_MIN,
    seed: int | None = DEFAULT_RANDOM_SEED,
    start_month: str = "2015-03",
    max_months: int | None = None,
    delete_pgn_after_use: bool = True,
    intermediate_every_games: int = 10_000,
    dest_dir: Path | None = None,
) -> None:
    """Generate dataset streaming through monthly Lichess dumps downloaded on-the-fly.

    Stops when max_boards collected or month limit reached. Each month PGN is
    deleted when done (optional). If the run is aborted by an exception, the
    positions collected so far are still written as a progress save before the
    exception propagates.

    Args:
        output_file_path: Pickle file receiving the dataset and checkpoints.
        max_boards: Collection stops once at least this many positions exist.
        sampling_frequency: Distance in moves between sampled positions (>= 1).
        offset_min: Minimum per-game starting offset for sampling.
        seed: Seed for the ``random`` module; ``None`` leaves it unseeded.
        start_month: First month to process, formatted as ``YYYY-MM``.
        max_months: Maximum number of months to process (``None`` = no limit).
        delete_pgn_after_use: Delete each decompressed PGN once processed.
        intermediate_every_games: Write a checkpoint every this many games (>= 1).
        dest_dir: Download directory; defaults to ``EXTERNAL_DATA_DIR / "lichess_pgn"``.

    Raises:
        ValueError: If ``sampling_frequency`` or ``intermediate_every_games``
            is smaller than 1.
    """
    # Both values are used as modulo divisors; reject bad input up front
    # instead of failing after a multi-gigabyte download.
    if sampling_frequency < 1:
        raise ValueError(f"sampling_frequency must be >= 1, got {sampling_frequency}")
    if intermediate_every_games < 1:
        raise ValueError(
            f"intermediate_every_games must be >= 1, got {intermediate_every_games}"
        )

    if seed is not None:
        random.seed(seed)
    if dest_dir is None:
        dest_dir = EXTERNAL_DATA_DIR / "lichess_pgn"
    dest_dir.mkdir(parents=True, exist_ok=True)

    the_dic: list[dict[str, fen]] = []
    months_used: list[str] = []
    count_game = 0
    total_count_move = 0
    recorded_board = 0
    months_processed = 0
    month_iter = iterate_months(start_month)
    completed = False

    try:
        while recorded_board < max_boards:
            if max_months is not None and months_processed >= max_months:
                print("Reached max_months limit.")
                break
            month = next(month_iter)

            print(f"\n=== Processing month {month} ===")
            pgn_path = ensure_month_pgn(
                month, dest_dir, keep_decompressed=not delete_pgn_after_use
            )
            months_used.append(month)
            months_processed += 1

            with open(pgn_path, "r", encoding="utf-8") as pgn_file:
                while recorded_board < max_boards:
                    game = chess.pgn.read_game(pgn_file)
                    if game is None:
                        break
                    count_game += 1
                    total_count_move = process_game(
                        game, total_count_move, the_dic, sampling_frequency, offset_min
                    )
                    recorded_board = len(the_dic)

                    # Checkpoint after the game has been processed so the
                    # stored counters match the stored positions.
                    if count_game % intermediate_every_games == 0:
                        save_dataset_progress(
                            the_dic,
                            output_file_path,
                            count_game,
                            total_count_move,
                            max_boards,
                            None,
                            None,
                            f"months:{','.join(months_used)}",
                            sampling_frequency,
                            offset_min,
                            seed,
                            is_final=False,
                        )

            if delete_pgn_after_use:
                print(f"Deleting processed PGN for month {month}: {pgn_path}")
                try:
                    Path(pgn_path).unlink(missing_ok=True)
                except OSError as exc:
                    print(f"Warning: could not delete {pgn_path}: {exc}")
            # continue to next month if need more boards
        completed = True
    finally:
        # Final save using unified function. On an aborted run (download error,
        # Ctrl-C, ...) the data is kept as a progress save rather than lost.
        if the_dic:
            save_dataset_progress(
                the_dic,
                output_file_path,
                count_game,
                total_count_move,
                max_boards,
                None,
                None,
                f"months:{','.join(months_used)}",
                sampling_frequency,
                offset_min,
                seed,
                is_final=completed,
                months_used=months_used,
            )
# --- CLI integration (dynamic only) ---
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate chess board dataset via on-the-fly monthly Lichess downloads (dynamic only)"
    )
    parser.add_argument(
        "--start-month", default="2015-03", help="Start month YYYY-MM for dynamic mode"
    )
    parser.add_argument(
        "--max-months", type=int, default=None, help="Maximum number of months to fetch"
    )
    parser.add_argument("--max-boards", type=int, default=10_000_000)
    parser.add_argument(
        "--sampling-frequency", type=int, default=DEFAULT_SAMPLING_FREQUENCY
    )
    parser.add_argument("--offset-min", type=int, default=DEFAULT_OFFSET_MIN)
    # Use the module constant like the other sampling options, so the CLI
    # default cannot drift from the library default.
    parser.add_argument("--seed", type=int, default=DEFAULT_RANDOM_SEED)
    parser.add_argument(
        "--output",
        type=str,
        default=str(EXTERNAL_DATA_DIR / "datasets" / "only_boards.pkl"),
    )
    parser.add_argument(
        "--keep-pgn",
        action="store_true",
        help="Keep monthly PGN files after processing",
    )
    parser.add_argument(
        "--intermediate-games",
        type=int,
        default=10_000,
        help="Games interval for intermediate saves",
    )
    args = parser.parse_args()

    print("Running dynamic monthly download mode (legacy single-file mode disabled)")
    generate_board_dataset_multi_months(
        output_file_path=args.output,
        max_boards=args.max_boards,
        sampling_frequency=args.sampling_frequency,
        offset_min=args.offset_min,
        seed=args.seed,
        start_month=args.start_month,
        max_months=args.max_months,
        delete_pgn_after_use=not args.keep_pgn,
        intermediate_every_games=args.intermediate_games,
    )