Project structure of: waterhorse1/LLM_Tree_Search
README.md - Accelerate language model inference with AlphaZero-like techniques.
requirement.txt - Required packages: ctranslate2, transformers, torch, numpy, flash-attn, dm-tree.
setup.py - Setup Python package for AlphaZero-like learning framework using LLMs.
ds_config.json - Trains ML model with params: micro-batch size, grad accum steps, mixed precision.
mcts_game24_llama_deepspeed.yaml - DeepSpeed MCTS Game24 Training Config.
test_policy_and_value.sh - Test offline RL model for game24 with specified CT2 and critic models.
train_game24_critic.py - Train MCTS algorithm for game24 task using Flash attention.
train_game24_sft.py - Trains Game24 RL model with Llama-2-7b base, config params.
ds_config.json - Training parameters JSON for ML model.
it1_gsm8k.ipynb - Deduplicates, merges, subsamples Q&A data for JSONL files.
mcts_gsm8k_llama_deepspeed.yaml - Training MCTS on game24 with DeepSpeed, 8 processes.
README.md - GSM8K rollout and training process using rollouts, SFT, and critic.
test_policy_and_value.sh - Trains ML model with env settings, tests policy and value functions.
train_gsm8k_critic.py - Trains MCTS-based language model using PEFT.
train_gsm8k_sft.py - Trains Llama-2-7b on GSM8K with SFT data and MCTS algorithm.
ds_config.json - Trains ML model with config parameters.
mcts_prontoqa_llama_deepspeed.yaml - DeepSpeed MCTS game24 training config with 8 processes.
test_policy_and_value.sh - Run PrOntoQA test with multi-GPU offline RL model.
train_prontoqa_critic.py - Trains Llama with LoRA and Flash Attention.
train_prontoqa_sft.py - Train PrOntoQA with Flash Attention.
accelerate_config.yaml - Local DeepSpeed, 8 processes, RDZV backend.
ds_config_no_offload.json - Optimize ML training: micro-batch, grad accum steps, BF16/FP16, ZeRO optimization.
filter_top_data_policy_training.py - Filter top 5 MCTS data for policy training.
mix_value_data.py - Mixes and loads data for RL critic training.
README.md - Trains RL model with preprocessed data using RLHF.
test_policy_and_value.sh - Test policy and value using CT2 model cache.
train_rlhf_critic.py - Trains RL model using MCTS.
train_rlhf_policy.py - Trains RL model for language modeling with MCTS.
argparse_utils.py - Utility functions for boolean and list of integers parsing.
utils.py - Distributed computation utilities for tensor collection.
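
As an illustration of the kind of helper the argparse_utils.py entry describes, here is a minimal sketch of boolean and integer-list parsers for `argparse`. The function names `str2bool` and `list_of_ints`, the accepted spellings, and the flags `--use-cache` and `--widths` are assumptions made for this example and are not taken from the repository.

```python
import argparse
from typing import List


def str2bool(value: str) -> bool:
    """Parse common textual spellings of a boolean (hypothetical helper)."""
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def list_of_ints(value: str) -> List[int]:
    """Parse a comma-separated string such as '4,8,16' into a list of ints."""
    try:
        return [int(item) for item in value.split(",") if item.strip()]
    except ValueError as exc:
        raise argparse.ArgumentTypeError(
            f"expected comma-separated integers, got {value!r}"
        ) from exc


def build_parser() -> argparse.ArgumentParser:
    # Flag names are illustrative only.
    parser = argparse.ArgumentParser()
    parser.add_argument("--use-cache", type=str2bool, default=False)
    parser.add_argument("--widths", type=list_of_ints, default=[])
    return parser
```

Passing the helpers as `type=` lets `argparse` report malformed values as ordinary usage errors instead of raising a traceback.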
__init__.py - Imports envs, defines tasks with datasets.
base_env.py - Base class for Monte Carlo Tree Search in text-based games.
__init__.py - Game24 environment setup.
data.py - Trains and tests datasets for game puzzles.
env.py - Game24Env: Solving 24 math puzzle environment.
prompt.py - Solve 24 with given numbers and ops.
__init__.py - Environment initialization and data handling.
data.py - GSM8K dataset loader with truncation.
env.py - Gsm8kEnv: CoTEnv class, answer/ground truth extraction, init/reward methods.
prompt.py - Python math prompts, step-by-step with multiple-choice.
__init__.py - Imports for PrOntoQA environment.
data.py - Data generation scripts for question-answer datasets.
env.py - Defines QA environment class for answer extraction and ground truth.
prompt.py - Logical reasoning with arthropods examples.
__init__.py - Imports for RLHF environment.
data.py - Generates datasets, customizes for models, encodes answers.
env.py - Reinforcement learning environment for language models.
prompt.py - Task format constants and templates.
test_game24.py - Test Game24Env environment, data loading, and lengths display.
test_gsm8k.py - Test GSM8K dataset few-shot learning model.
test_prontoqa.py - Setup environment for PrOntoQA, tokenize with Llama-2-7b-hf.
test_rlhf.py - Contextual understanding task environment setup and model fine-tuning.
utils.py - Constructs data component, tokenizes, generates queries and responses. Assigns rewards.
vote_utils.py - Aggregate votes based on rules using Counter and defaultdict.
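
The vote aggregation mentioned for vote_utils.py can be pictured with a short sketch built on `Counter` and `defaultdict`. The two rules shown (plain majority and weight-summed voting) and the names `majority_vote` and `weighted_vote` are assumptions chosen for illustration; the repository's actual rules may differ.

```python
from collections import Counter, defaultdict
from typing import Hashable, Sequence


def majority_vote(answers: Sequence[Hashable]) -> Hashable:
    """Return the most frequent answer (raises IndexError on empty input)."""
    return Counter(answers).most_common(1)[0][0]


def weighted_vote(answers: Sequence[Hashable], weights: Sequence[float]) -> Hashable:
    """Sum a weight (for example a value estimate) per answer and pick the largest total."""
    totals = defaultdict(float)
    for answer, weight in zip(answers, weights):
        totals[answer] += weight
    return max(totals, key=totals.get)
```

With answers `["a", "b", "b"]` and weights `[0.9, 0.3, 0.3]`, the majority rule picks `"b"` while the weighted rule picks `"a"`, which shows why the two rules are worth keeping separate.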
lm_self_value.py - Calculate mean values and generate prompts.
trajectory_collector.py - Monte Carlo Tree Search, trajectory data collection.
value.py - Value function from critic model on text input.
ct2_utils.py - Load ctranslate2 model and converter.
text_generation.py - Text generation with ChatGLM, handling missing substrings.
tree.py - Tree-based search algorithms for games and language tasks using MCTS, Beam Search, and JSON data.
utils.py - Retrieves root node from tree by traversing upwards.
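
Retrieving a root node by traversing upwards, as the utils.py entry describes, can be sketched as follows. The `Node` class, its `parent` and `children` attributes, and the function name `get_root` are hypothetical stand-ins for this example, not the repository's search-tree types.

```python
from typing import List, Optional


class Node:
    """Minimal tree node with a parent link (hypothetical stand-in)."""

    def __init__(self, name: str, parent: Optional["Node"] = None) -> None:
        self.name = name
        self.parent = parent
        self.children: List["Node"] = []

    def add_child(self, name: str) -> "Node":
        child = Node(name, parent=self)
        self.children.append(child)
        return child


def get_root(node: Node) -> Node:
    """Follow parent links until reaching the node that has no parent."""
    while node.parent is not None:
        node = node.parent
    return node
```

The loop runs in time proportional to the depth of the starting node and needs no recursion, so it stays safe on deep search trees.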
merge_jsonl.py - Merge multiple JSONL files, excluding "merged.jsonl", into one.
__init__.py - Initializes critic model based on value_model_type_name and model path.
llama_flash_attn_monkey_patch.py - Llama model's attention modification with GPU restriction.
modeling_actor_critic.py - Actor-Critic Transformer Language Model.
modeling_base.py - Wrapper for PreTrainedModel with Peft support and efficient downloads.
modeling_prm.py - Improved sequential decoding for causal language models.
utils.py - Simplify attribute access with chainable functions.
dedup.py - Deduplicates JSONL text entries and tracks counts.
gen_3.sh - Generate 3-episode game24 data using CUDA devices.
process.sh - Preprocesses Game24 data for offline RL training.
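
The deduplication with count tracking described for dedup.py might look roughly like the sketch below. It works on an iterable of JSONL lines so it can be tried without files; the function name `dedup_jsonl_lines` and the `"text"` key are assumptions for illustration only.

```python
import json
from collections import Counter
from typing import Dict, Iterable, List, Tuple


def dedup_jsonl_lines(
    lines: Iterable[str], key: str = "text"
) -> Tuple[List[Dict], Counter]:
    """Keep the first record for each distinct record[key] and count all occurrences."""
    counts: Counter = Counter()
    unique: List[Dict] = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines, which often end JSONL files
        record = json.loads(line)
        value = record[key]
        if counts[value] == 0:
            unique.append(record)
        counts[value] += 1
    return unique, counts
```

To apply it to a file, pass the open file handle as `lines`, then write each kept record back out with `json.dumps`, one per line.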
generate_data.py - Generate Llama responses with ThreadPoolExecutor.
gen_3.sh - Generates data for GSM8K environment thrice.
process.sh - Processes GSM8K data for offline RL training.
merge.py - Merges, filters, and deduplicates JSONL pairs.
gen_3.sh - Offline RL data generation script.
process.sh - Process script for offline RL split.
gen_3.sh - Trains RLHF model using 8 processes, specified settings.
process.py - Loads, merges JSONL data and produces a new JSONL file.
process.sh - Process input files, set env vars, create output dir if needed.
sample.py - Randomly samples JSONL data for training or testing.
split_two_test.py - Trains and splits JSONL data for offline RL.
test_sft_and_v.py - Test offline RL agent with various methods and evaluate accuracy.
test_sft_and_v_rlhf.py - Offline RL for LLMs with search parameters and CoT-SC comparisons.
utils.py - Utils for JSONL file operations and random seeding.
config.py - Configurable RL settings class.
buffer.py - Buffer class with padding validation and DataLoader support.
node_types_new.py - TimeStep, tokenizing, reinforcement learning, returns calculation.
sft_buffer.py - SFTBuffer: Batch, Pad, Collate SftInstance data.
traj_buffer.py - TrajBuffer management and MultiTrajBuffer for trajectories.
base_trainer.py - BaseTrainer: MCTS-based model training abstract class with get_arch, setup_optimizer methods.
mcts_trainer_traj_ct2_sft.py - Trains MCTS models for Monte Carlo search algorithms.
mcts_trainer_traj_ct2_value.py - Trains/tests RL and language models with PyTorch DataLoader.
opt_utils.py - Optimizer class, Adam, AdamW, SGD, learning rate schedulers.
utils.py - Flattens dict, retrieves git commit details.
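
Dict flattening of the kind listed for the trainer's utils.py is commonly used to log nested configs as flat key/value pairs. A minimal sketch follows; the function name `flatten_dict`, the `/` separator, and the example config are assumptions for this example.

```python
from typing import Any, Dict


def flatten_dict(
    nested: Dict[str, Any], parent_key: str = "", sep: str = "/"
) -> Dict[str, Any]:
    """Collapse nested dicts into one level, joining keys with `sep`."""
    flat: Dict[str, Any] = {}
    for key, value in nested.items():
        full_key = f"{parent_key}{sep}{key}" if parent_key else str(key)
        if isinstance(value, dict):
            # Recurse into sub-dicts, carrying the accumulated key prefix.
            flat.update(flatten_dict(value, full_key, sep))
        else:
            flat[full_key] = value
    return flat


config = {"train": {"lr": 1e-5, "opt": {"name": "adamw"}}, "seed": 0}
flat_config = flatten_dict(config)
```

In this sketch an empty nested dict contributes no keys, so it disappears from the flattened result.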