Project structure of: waterhorse1/LLM_Tree_Search
README.md
Accelerates language model inference with AlphaZero-like techniques.
requirement.txt
Required packages: ctranslate2, transformers, torch, numpy, flash-attn, dm-tree.
setup.py
Sets up the Python package for the AlphaZero-like learning framework for LLMs.
ds_config.json
DeepSpeed configuration: micro-batch size, gradient accumulation steps, mixed precision.
mcts_game24_llama_deepspeed.yaml
DeepSpeed MCTS Game24 training config.
test_policy_and_value.sh
Tests the offline-RL model on Game24 with the specified CT2 (CTranslate2) policy and critic models.
train_game24_critic.py
Trains the Game24 critic (value) model using Flash Attention.
train_game24_sft.py
Fine-tunes (SFT) the Game24 policy from a Llama-2-7b base model with the configured parameters.
ds_config.json
DeepSpeed training parameters (JSON).
it1_gsm8k.ipynb
Deduplicates, merges, and subsamples Q&A data in JSONL files.
mcts_gsm8k_llama_deepspeed.yaml
DeepSpeed MCTS GSM8K training config with 8 processes.
README.md
Describes the GSM8K pipeline: collecting rollouts, then SFT and critic training on them.
test_policy_and_value.sh
Sets environment variables and tests the policy and value models.
train_gsm8k_critic.py
Trains the GSM8K critic model using PEFT.
train_gsm8k_sft.py
Trains Llama-2-7b on GSM8K with SFT data for the MCTS pipeline.
ds_config.json
DeepSpeed training configuration.
mcts_prontoqa_llama_deepspeed.yaml
DeepSpeed MCTS ProntoQA training config with 8 processes.
test_policy_and_value.sh
Runs the ProntoQA policy and value test on multiple GPUs with the offline-RL model.
train_prontoqa_critic.py
Trains the ProntoQA critic (Llama) with LoRA and Flash Attention.
train_prontoqa_sft.py
Trains the ProntoQA SFT policy with Flash Attention.
accelerate_config.yaml
Accelerate config: local DeepSpeed launch, 8 processes, rendezvous (RDZV) backend.
ds_config_no_offload.json
DeepSpeed config without offloading: micro-batch size, gradient accumulation steps, BF16/FP16, ZeRO optimization.
filter_top_data_policy_training.py
Filters the top 5 MCTS samples for policy training.
mix_value_data.py
Loads and mixes data for critic (value) training.
README.md
Describes training the RLHF models on preprocessed data.
test_policy_and_value.sh
Tests the policy and value models using a cached CT2 (CTranslate2) model.
train_rlhf_critic.py
Trains the RLHF critic model using MCTS data.
train_rlhf_policy.py
Trains the RLHF language-model policy with MCTS data.
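filter_top_data_policy_training.py is described only as keeping the top 5 MCTS samples. A minimal sketch of per-question top-k filtering follows, assuming each record is a dict with hypothetical "question" and "value" fields and that "top" means highest value; the script's actual ranking criterion is not stated in this listing.

```python
from collections import defaultdict
from typing import Dict, Iterable, List


def filter_top_k(records: Iterable[dict], k: int = 5,
                 key_field: str = "question",
                 score_field: str = "value") -> List[dict]:
    """Keep the k highest-scoring records per question (field names are hypothetical)."""
    groups: Dict[str, List[dict]] = defaultdict(list)
    for rec in records:
        groups[rec[key_field]].append(rec)
    kept: List[dict] = []
    for recs in groups.values():
        # Sort descending by score; sorted() is stable, so ties keep input order.
        kept.extend(sorted(recs, key=lambda r: r[score_field], reverse=True)[:k])
    return kept
```

Grouping before sorting keeps a question with many rollouts from crowding out questions with few.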
argparse_utils.py
Utility functions for parsing booleans and lists of integers.
utils.py
Distributed computation utilities for tensor collection
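argparse_utils.py provides parsers for booleans and integer lists. The sketch below shows one common way to write such argparse `type=` functions; the names str2bool and list_of_ints and the flags are hypothetical, not taken from the repository.

```python
import argparse
from typing import List


def str2bool(value: str) -> bool:
    """Parse common textual spellings of a boolean flag."""
    lowered = value.strip().lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def list_of_ints(value: str) -> List[int]:
    """Parse a comma-separated string such as '1,2,3' into [1, 2, 3]."""
    return [int(part) for part in value.split(",") if part.strip()]


# Usage with hypothetical flags; parse_args is given an explicit argv list.
parser = argparse.ArgumentParser()
parser.add_argument("--use_value", type=str2bool, default=False)
parser.add_argument("--seeds", type=list_of_ints, default=[0])
args = parser.parse_args(["--use_value", "true", "--seeds", "1,2,3"])
```

A plain `type=bool` would not work here, because `bool("false")` is `True`; that is why such helpers exist.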
__init__.py
Imports the environments and defines the tasks with their datasets.
base_env.py
Base class for Monte Carlo Tree Search in text-based games.
__init__.py
Game24 environment setup.
data.py
Train and test datasets for the Game24 puzzles.
env.py
Game24Env: environment for solving the 24 math puzzle.
prompt.py
Prompts for solving 24 with the given numbers and operations.
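The Game24 task asks for an arithmetic expression that uses each of four given numbers exactly once and evaluates to 24. As a hedged sketch of how a final answer could be checked (not the repository's Game24Env code), the checker below parses the expression with Python's ast module instead of eval, and uses Fraction so that cases such as 8 / (3 - 8 / 3) are evaluated exactly.

```python
import ast
from collections import Counter
from fractions import Fraction
from typing import Callable, Dict, List

# Allowed binary operators, applied to exact Fractions.
_OPS: Dict[type, Callable[[Fraction, Fraction], Fraction]] = {
    ast.Add: lambda a, b: a + b,
    ast.Sub: lambda a, b: a - b,
    ast.Mult: lambda a, b: a * b,
    ast.Div: lambda a, b: a / b,  # raises ZeroDivisionError on division by 0
}


def _evaluate(node: ast.AST, used: List[int]) -> Fraction:
    """Recursively evaluate integer literals and + - * /, recording each literal."""
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        left = _evaluate(node.left, used)
        right = _evaluate(node.right, used)
        return _OPS[type(node.op)](left, right)
    if isinstance(node, ast.Constant) and type(node.value) is int:
        used.append(node.value)
        return Fraction(node.value)
    raise ValueError("only integer literals and + - * / are allowed")


def is_valid_24(expression: str, numbers: List[int]) -> bool:
    """True if `expression` uses each of `numbers` exactly once and equals 24."""
    used: List[int] = []
    try:
        value = _evaluate(ast.parse(expression, mode="eval").body, used)
    except (SyntaxError, ValueError, ZeroDivisionError):
        return False
    return Counter(used) == Counter(numbers) and value == 24


print(is_valid_24("8 / (3 - 8 / 3)", [3, 3, 8, 8]))  # True
```

With floats, 8 / (3 - 8 / 3) can land a rounding error away from 24, so exact rational arithmetic avoids rejecting a correct answer.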
__init__.py
Environment initialization and data handling.
data.py
GSM8K dataset loader with truncation.
env.py
Gsm8kEnv (a CoTEnv class): answer and ground-truth extraction, init and reward methods.
prompt.py
Python math prompts, step-by-step with multiple-choice
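GSM8K reference solutions end with a line of the form "#### <answer>", which is what ground-truth extraction has to parse. The sketch below pairs that with a model-answer extractor; the "The answer is" trigger phrase and the fallback to the last number are assumptions about the prompt format, not Gsm8kEnv's confirmed behavior.

```python
import re
from typing import Optional

# Comma-grouped numbers such as 1,200 first, then plain integers/decimals.
_NUMBER = r"-?\d{1,3}(?:,\d{3})+(?:\.\d+)?|-?\d+(?:\.\d+)?"


def extract_groundtruth(solution: str) -> str:
    """GSM8K reference solutions end with '#### <answer>'; return that answer."""
    return solution.split("####")[-1].strip().replace(",", "")


def extract_answer(completion: str) -> Optional[str]:
    """Number after 'The answer is' (assumed prompt format), else the last number."""
    match = re.search(r"The answer is\s*(" + _NUMBER + r")", completion)
    candidates = [match.group(1)] if match else re.findall(_NUMBER, completion)
    if not candidates:
        return None
    return candidates[-1].replace(",", "")


def is_correct(completion: str, solution: str) -> bool:
    """Compare numerically so that '72' and '72.0' match."""
    predicted = extract_answer(completion)
    if predicted is None:
        return False
    try:
        return float(predicted) == float(extract_groundtruth(solution))
    except ValueError:
        return False
```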
__init__.py
Imports for the PrOntoQA environment.
data.py
Data generation scripts for question-answer datasets.
env.py
Defines the QA environment class, including answer and ground-truth extraction.
prompt.py
Logical-reasoning prompts with arthropod examples.
__init__.py
Imports for the RLHF environment.
data.py
Generates datasets, customizes them for specific models, and encodes answers.
env.py
Reinforcement learning environment for language models.
prompt.py
Task format constants and templates.
test_game24.py
Tests Game24Env and data loading, and displays lengths.
test_gsm8k.py
Tests few-shot learning on the GSM8K dataset.
test_prontoqa.py
Sets up the PrOntoQA environment and tokenizes with Llama-2-7b-hf.
test_rlhf.py
Contextual understanding task environment setup and model fine-tuning.
utils.py
Constructs data component, tokenizes, generates queries and responses. Assigns rewards.
vote_utils.py
Aggregate votes based on rules using Counter and defaultdict.
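vote_utils.py aggregates answers with Counter and defaultdict. Two aggregation rules that fit that description are plain majority voting and value-weighted voting, where each candidate's vote is weighted by a critic score. The sketch below shows both with hypothetical function names; it is not necessarily the rule set the repository implements.

```python
from collections import Counter, defaultdict
from typing import Dict, Sequence


def majority_vote(answers: Sequence[str]) -> str:
    """Most frequent answer; most_common breaks ties by first occurrence."""
    return Counter(answers).most_common(1)[0][0]


def value_weighted_vote(answers: Sequence[str], values: Sequence[float]) -> str:
    """Answer whose candidates have the largest summed value (e.g. critic scores)."""
    totals: Dict[str, float] = defaultdict(float)
    for answer, value in zip(answers, values):
        totals[answer] += value
    return max(totals, key=totals.get)


answers = ["12", "15", "12", "15", "15"]
print(majority_vote(answers))                                    # 15
print(value_weighted_vote(answers, [0.9, 0.1, 0.8, 0.2, 0.3]))   # 12
```

The two rules can disagree: "15" wins on count, while "12" wins once low-value candidates count for less.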
lm_self_value.py
Calculates mean values and generates prompts.
trajectory_collector.py
Collects trajectory data with Monte Carlo Tree Search.
value.py
Value function from critic model on text input.
ct2_utils.py
Loads a CTranslate2 model and converter.
text_generation.py
Text generation with ChatGLM, handling missing substrings.
tree.py
Tree-based search algorithms (MCTS, beam search) for games and language tasks, with JSON data handling.
utils.py
Retrieves root node from tree by traversing upwards.
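Retrieving the root by walking upward only requires each node to keep a reference to its parent. A minimal sketch with a hypothetical Node class (the repository's node type is not shown in this listing):

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Node:
    """Hypothetical tree node; only the parent link matters for get_root."""
    text: str
    # Excluded from repr/eq so printing or comparing a node never recurses upward.
    parent: Optional["Node"] = field(default=None, repr=False, compare=False)
    children: List["Node"] = field(default_factory=list)

    def add_child(self, text: str) -> "Node":
        child = Node(text, parent=self)
        self.children.append(child)
        return child


def get_root(node: Node) -> Node:
    """Follow parent links upward until reaching a node without a parent."""
    while node.parent is not None:
        node = node.parent
    return node
```

The loop runs in time proportional to the node's depth and needs no recursion, so deep search trees do not hit Python's recursion limit.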
merge_jsonl.py
Merges multiple JSONL files, excluding "merged.jsonl", into one.
__init__.py
Initializes the critic model based on value_model_type_name and the model path.
llama_flash_attn_monkey_patch.py
Modifies the Llama model's attention (monkey patch), with a GPU restriction.
modeling_actor_critic.py
Actor-critic Transformer language model.
modeling_base.py
Wrapper for PreTrainedModel with PEFT support and efficient downloads.
modeling_prm.py
Improved sequential decoding for causal language models.
utils.py
Simplify attribute access with chainable functions
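Chainable attribute access usually means resolving a dotted path such as "config.hidden_size" one attribute at a time. The following sketch of that pattern uses functools.reduce. The names rgetattr, rsetattr, and rhasattr are illustrative and are not confirmed to be what utils.py uses.

```python
import functools
from types import SimpleNamespace
from typing import Any

_MISSING = object()  # sentinel so that None can be a legitimate default


def rgetattr(obj: Any, path: str, default: Any = _MISSING) -> Any:
    """getattr along a dotted path, e.g. rgetattr(model, "config.hidden_size")."""
    try:
        return functools.reduce(getattr, path.split("."), obj)
    except AttributeError:
        if default is _MISSING:
            raise
        return default


def rhasattr(obj: Any, path: str) -> bool:
    """True if every attribute along the dotted path exists."""
    try:
        rgetattr(obj, path)
    except AttributeError:
        return False
    return True


def rsetattr(obj: Any, path: str, value: Any) -> None:
    """setattr on the last component of a dotted path."""
    parent_path, _, name = path.rpartition(".")
    target = rgetattr(obj, parent_path) if parent_path else obj
    setattr(target, name, value)


# Usage with a stand-in object instead of a real model.
model = SimpleNamespace(config=SimpleNamespace(hidden_size=4096))
rsetattr(model, "config.num_layers", 32)
```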
dedup.py
Deduplicates JSONL text entries and tracks counts.
gen_3.sh
Generates 3-episode Game24 data using CUDA devices.
process.sh
Preprocesses Game24 data for offline RL training.
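dedup.py removes duplicate JSONL text entries and tracks how often each appears. One way to do that in a single pass keeps the first occurrence and counts the rest. The sketch assumes each line is a JSON object whose hypothetical "text" field identifies duplicates.

```python
import json
from collections import Counter
from typing import Dict, Iterable, List, Tuple


def dedup_jsonl_lines(lines: Iterable[str],
                      key: str = "text") -> Tuple[List[dict], Dict[str, int]]:
    """Keep the first record per distinct `key` value and count every occurrence."""
    counts: Counter = Counter()
    unique: List[dict] = []
    for line in lines:
        if not line.strip():
            continue  # skip blank lines, e.g. a trailing newline
        record = json.loads(line)
        text = record[key]
        if counts[text] == 0:  # Counter returns 0 for unseen keys
            unique.append(record)
        counts[text] += 1
    return unique, dict(counts)
```

Because it accepts any iterable of lines, an open file handle can be passed directly, so the whole file is never loaded as one string.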
generate_data.py
Generates Llama responses with ThreadPoolExecutor.
gen_3.sh
Generates data for the GSM8K environment three times.
process.sh
Processes GSM8K data for offline RL training.
merge.py
Merges, filters, and deduplicates JSONL pairs.
gen_3.sh
Offline RL data generation script.
process.sh
Process script for offline RL split.
gen_3.sh
Generates RLHF data using 8 processes with the specified settings.
process.py
Loads and merges JSONL data into a new JSONL file.
process.sh
Process input files, set env vars, create output dir if needed.
sample.py
Randomly samples JSONL data for training or testing.
split_two_test.py
Splits JSONL data for offline-RL training and testing.
test_sft_and_v.py
Tests the offline-RL agent with various methods and evaluates accuracy.
test_sft_and_v_rlhf.py
Offline-RL evaluation for LLMs on the RLHF task, with search parameters and CoT-SC comparisons.
utils.py
Utils for JSONL file operations and random seeding
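The JSONL helpers and seeding utility listed above, together with sample.py's random sampling, can be sketched with the standard library alone. Function names are hypothetical. A training script would typically also seed NumPy and PyTorch (numpy.random.seed, torch.manual_seed), which this stdlib-only sketch omits.

```python
import json
import os
import random
import tempfile
from typing import Iterable, List


def write_jsonl(path: str, records: Iterable[dict]) -> None:
    """Write one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def read_jsonl(path: str) -> List[dict]:
    """Read a JSONL file, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def sample_records(records: List[dict], n: int, seed: int = 0) -> List[dict]:
    """Sample up to n records without replacement, reproducibly for a given seed."""
    rng = random.Random(seed)  # local RNG, so global random state is untouched
    return rng.sample(records, min(n, len(records)))


def set_seed(seed: int) -> None:
    """Seed Python's global RNG (NumPy/PyTorch seeding omitted in this sketch)."""
    random.seed(seed)


# Usage: round-trip two records through a temporary file.
demo_path = os.path.join(tempfile.mkdtemp(), "demo.jsonl")
write_jsonl(demo_path, [{"question": "1+1", "answer": "2"},
                        {"question": "2+2", "answer": "4"}])
demo_records = read_jsonl(demo_path)
```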
config.py
Configurable RL settings class.
buffer.py
Buffer class with padding validation and DataLoader support.
node_types_new.py
TimeStep types, tokenization, reinforcement learning utilities, and returns calculation.
sft_buffer.py
SFTBuffer: batches, pads, and collates SftInstance data.
traj_buffer.py
TrajBuffer management and MultiTrajBuffer for trajectories.
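node_types_new.py includes a returns calculation. A sketch follows that assumes the standard discounted return G_t = r_t + γ·G_{t+1}, computed backward from the last step. The listing does not say whether the repository uses this exact form, a particular γ, or a different estimator such as GAE.

```python
from typing import List, Sequence


def discounted_returns(rewards: Sequence[float], gamma: float = 1.0) -> List[float]:
    """G_t = r_t + gamma * G_{t+1}, computed backwards from the final step."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in range(len(rewards) - 1, -1, -1):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


# A sparse terminal reward, as when only the final answer is scored.
print(discounted_returns([0.0, 0.0, 1.0], gamma=0.5))  # [0.25, 0.5, 1.0]
```

With gamma = 1.0 every step of a successful trajectory receives the full terminal reward. With gamma < 1, earlier steps receive less.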
base_trainer.py
BaseTrainer: abstract class for MCTS-based model training, with get_arch and setup_optimizer methods.
mcts_trainer_traj_ct2_sft.py
Trainer for supervised fine-tuning on MCTS trajectories.
mcts_trainer_traj_ct2_value.py
Trains and tests RL and language models with a PyTorch DataLoader.
opt_utils.py
Optimizer classes (Adam, AdamW, SGD) and learning-rate schedulers.
utils.py
Flattens dict, retrieves git commit details.
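Flattening a nested dict into dotted keys is a common step before logging a config to an experiment tracker. Below is a minimal recursive sketch with a hypothetical flatten_dict signature. The git-commit helper is left out because it depends on a local repository.

```python
from collections.abc import Mapping
from typing import Any, Dict


def flatten_dict(d: Mapping, parent_key: str = "", sep: str = ".") -> Dict[str, Any]:
    """Flatten nested mappings: {"a": {"b": 1}} -> {"a.b": 1}.

    Empty nested mappings produce no keys in the output.
    """
    items: Dict[str, Any] = {}
    for key, value in d.items():
        new_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, Mapping):
            items.update(flatten_dict(value, new_key, sep=sep))
        else:
            items[new_key] = value
    return items
```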