convert_gguf.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. #!/usr/bin/env python3
  2. # /// script
  3. # requires-python = ">=3.10"
  4. # dependencies = [
  5. # "transformers>=4.36.0",
  6. # "peft>=0.7.0",
  7. # "torch>=2.0.0",
  8. # "accelerate>=0.24.0",
  9. # "huggingface_hub>=0.20.0",
  10. # "sentencepiece>=0.1.99",
  11. # "protobuf>=3.20.0",
  12. # "numpy",
  13. # "gguf",
  14. # ]
  15. # ///
  16. """
  17. Convert QMD query expansion model to GGUF format.
  18. Loads the base model, merges SFT and GRPO adapters, then converts to
  19. GGUF with multiple quantizations for use with Ollama/llama.cpp/LM Studio.
  20. Usage:
  21. uv run convert_gguf.py --size 1.7B
  22. uv run convert_gguf.py --size 4B --skip-quantize
  23. uv run convert_gguf.py --base Qwen/Qwen3-1.7B \
  24. --sft tobil/qmd-query-expansion-1.7B-sft \
  25. --grpo tobil/qmd-query-expansion-1.7B-grpo \
  26. --output tobil/qmd-query-expansion-1.7B-gguf
  27. """
  28. import argparse
  29. import os
  30. import subprocess
  31. import sys
  32. import torch
  33. from huggingface_hub import HfApi, login
  34. from peft import PeftModel
  35. from transformers import AutoModelForCausalLM, AutoTokenizer
  # Preset configurations, keyed by the value passed to --size.
  # Each preset names the base model, the SFT and GRPO adapter repos that are
  # merged into it (in that order), and the Hub repo the GGUF files go to.
  # NOTE(review): "ollama_name" is not read anywhere in this script; presumably
  # other tooling consumes it — confirm before removing.
  PRESETS = {
      "1.7B": dict(
          base="Qwen/Qwen3-1.7B",
          sft="tobil/qmd-query-expansion-1.7B-sft",
          grpo="tobil/qmd-query-expansion-1.7B-grpo",
          output="tobil/qmd-query-expansion-1.7B-gguf",
          ollama_name="qmd-expand",
      ),
      "4B": dict(
          base="Qwen/Qwen3-4B",
          sft="tobil/qmd-query-expansion-4B-sft",
          grpo="tobil/qmd-query-expansion-4B-grpo",
          output="tobil/qmd-query-expansion-4B-gguf",
          ollama_name="qmd-expand-4b",
      ),
  }
  def run_cmd(cmd, description, max_chars=500):
      """Run a command without a shell, reporting failures instead of raising.

      Args:
          cmd: Argument list passed directly to ``subprocess.run``.
          description: Short label printed before the command runs.
          max_chars: Maximum number of trailing characters of the failing
              command's output to echo (default 500).

      Returns:
          True if the command exited with status 0; False if it exited
          non-zero, its executable was not found, or it could not be started.
      """
      print(f" {description}...")
      try:
          subprocess.run(cmd, check=True, capture_output=True, text=True)
      except subprocess.CalledProcessError as e:
          print(f" FAILED: {' '.join(str(part) for part in cmd)}")
          # Tracebacks and build tools end with the actual error, so echo the
          # tail of the output. Fall back to stdout for tools that report
          # their errors there instead of on stderr.
          output = (e.stderr or "").strip() or (e.stdout or "").strip()
          tail = output[max(len(output) - max_chars, 0):]
          if tail:
              print(f" {tail}")
          return False
      except FileNotFoundError:
          print(f" Command not found: {cmd[0]}")
          return False
      except OSError as e:
          # e.g. the file exists but is not executable.
          print(f" Could not run {cmd[0]}: {e}")
          return False
      return True
  def _resolve_config(parser, args):
      """Return (base, sft, grpo, output) repo ids from a --size preset and/or explicit flags.

      Explicit flags override preset values. Without --size, all four flags are
      required; otherwise ``parser.error`` prints usage and exits.
      """
      if args.size:
          preset = PRESETS[args.size]
          return (
              args.base or preset["base"],
              args.sft or preset["sft"],
              args.grpo or preset["grpo"],
              args.output or preset["output"],
          )
      if args.base and args.sft and args.grpo and args.output:
          return args.base, args.sft, args.grpo, args.output
      parser.error("Either --size or all of --base/--sft/--grpo/--output are required")


  def _merge_and_save(base_model, sft_model, grpo_model, merged_dir):
      """Load the base model, merge the SFT then GRPO PEFT adapters into it, and save to merged_dir."""
      print(f"\nStep 1: Loading base model {base_model}...")
      model = AutoModelForCausalLM.from_pretrained(
          base_model, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True,
      )
      # Merged sequentially: the GRPO adapter is loaded on top of the
      # SFT-merged weights (presumably how it was trained — confirm).
      print(f"Step 2: Merging SFT adapter {sft_model}...")
      model = PeftModel.from_pretrained(model, sft_model).merge_and_unload()
      print(f"Step 3: Merging GRPO adapter {grpo_model}...")
      model = PeftModel.from_pretrained(model, grpo_model).merge_and_unload()
      # NOTE(review): the tokenizer comes from the base model; this assumes the
      # adapters added no tokens and kept the chat template — confirm.
      tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

      print(f"\nStep 4: Saving merged model to {merged_dir}...")
      model.save_pretrained(merged_dir, safe_serialization=True)
      tokenizer.save_pretrained(merged_dir)

      # Release the merged weights before the GGUF converter (a separate
      # process) reloads them from disk, to keep peak memory down.
      del model
      if torch.cuda.is_available():
          torch.cuda.empty_cache()


  def _setup_llama_cpp(llama_dir):
      """Clone llama.cpp into llama_dir if absent and install its Python requirements.

      Exits with status 1 if the clone fails, since every later step needs it.
      """
      print("\nStep 5: Setting up llama.cpp...")
      if not os.path.exists(llama_dir):
          cloned = run_cmd(
              ["git", "clone", "--depth", "1",
               "https://github.com/ggerganov/llama.cpp.git", llama_dir],
              "Cloning llama.cpp",
          )
          if not cloned:
              sys.exit(1)
      # Best effort: a failure is reported but not fatal, because the
      # converter's dependencies may already be installed.
      run_cmd(
          [sys.executable, "-m", "pip", "install", "-q", "-r",
           os.path.join(llama_dir, "requirements.txt")],
          "Installing llama.cpp requirements",
      )


  def _build_quantizer(llama_dir):
      """Build llama.cpp's CPU-only llama-quantize; return its path, or None if unavailable."""
      print("\nStep 7: Building quantize tool...")
      build_dir = os.path.join(llama_dir, "build")
      os.makedirs(build_dir, exist_ok=True)
      jobs = str(os.cpu_count() or 4)
      if run_cmd(["cmake", "-B", build_dir, "-S", llama_dir, "-DGGML_CUDA=OFF"],
                 "CMake configure"):
          run_cmd(["cmake", "--build", build_dir, "--target", "llama-quantize", "-j", jobs],
                  "Building llama-quantize")
      quantize_bin = os.path.join(build_dir, "bin", "llama-quantize")
      if not os.path.exists(quantize_bin):
          print(f" llama-quantize not found at {quantize_bin}; skipping quantization")
          return None
      return quantize_bin


  def _quantize(quantize_bin, gguf_file, gguf_dir, model_name):
      """Quantize gguf_file to each preset type; return [(path, quant_type)] for the successes."""
      print("\nStep 8: Quantizing...")
      quantized_files = []
      for quant_type, desc in [("Q4_K_M", "4-bit"), ("Q5_K_M", "5-bit"), ("Q8_0", "8-bit")]:
          qfile = f"{gguf_dir}/{model_name}-{quant_type.lower()}.gguf"
          if run_cmd([quantize_bin, gguf_file, qfile, quant_type], f"{quant_type} ({desc})"):
              qsize = os.path.getsize(qfile) / (1024 * 1024)
              print(f" {quant_type}: {qsize:.1f} MB")
              quantized_files.append((qfile, quant_type))
      return quantized_files


  def _build_readme(model_name, base_model, sft_model, grpo_model):
      """Return the Markdown model card (with YAML front matter) for the GGUF repo."""
      return "\n".join([
          "---",
          f"base_model: {base_model}",
          "tags: [gguf, llama.cpp, quantized, query-expansion, qmd]",
          "---",
          "",
          f"# {model_name} (GGUF)",
          "",
          "GGUF conversion of the QMD Query Expansion model.",
          "",
          "## Details",
          "",
          f"- **Base:** {base_model}",
          f"- **SFT:** {sft_model}",
          f"- **GRPO:** {grpo_model}",
          "- **Task:** Query expansion (lex/vec/hyde format)",
          "",
          "## Prompt Format",
          "",
          "```",
          "<|im_start|>user",
          "/no_think Expand this search query: your query here<|im_end|>",
          "<|im_start|>assistant",
          "```",
          "",
      ])


  def _upload(output_repo, model_name, gguf_file, quantized_files, readme):
      """Create output_repo if needed and upload the FP16 GGUF, quantized GGUFs and README."""
      print(f"\nStep 9: Uploading to {output_repo}...")
      api = HfApi()
      api.create_repo(repo_id=output_repo, repo_type="model", exist_ok=True)
      api.upload_file(path_or_fileobj=gguf_file,
                      path_in_repo=f"{model_name}-f16.gguf", repo_id=output_repo)
      for qfile, qtype in quantized_files:
          api.upload_file(path_or_fileobj=qfile,
                          path_in_repo=f"{model_name}-{qtype.lower()}.gguf",
                          repo_id=output_repo)
      api.upload_file(path_or_fileobj=readme.encode(),
                      path_in_repo="README.md", repo_id=output_repo)


  def main():
      """Merge the QMD adapters into the base model, convert to GGUF, quantize, and upload.

      Pipeline:
      1. Resolve repo ids from --size presets and/or explicit flags.
      2. Load the base model, merge the SFT then GRPO adapters, and save the
         merged checkpoint.
      3. Convert the checkpoint to an FP16 GGUF with llama.cpp.
      4. Unless --skip-quantize, quantize to Q4_K_M, Q5_K_M and Q8_0.
      5. Unless --no-upload, upload the files plus a README to the Hub.

      Exits with status 1 if cloning llama.cpp or the FP16 conversion fails.
      """
      parser = argparse.ArgumentParser(description="Convert QMD model to GGUF")
      parser.add_argument("--size", choices=PRESETS.keys(), help="Use preset config for model size")
      parser.add_argument("--base", help="Base model (overrides preset)")
      parser.add_argument("--sft", help="SFT adapter (overrides preset)")
      parser.add_argument("--grpo", help="GRPO adapter (overrides preset)")
      parser.add_argument("--output", help="Output HF repo (overrides preset)")
      parser.add_argument("--skip-quantize", action="store_true", help="Only produce FP16 GGUF")
      parser.add_argument("--no-upload", action="store_true", help="Don't upload to HF Hub")
      args = parser.parse_args()
      base_model, sft_model, grpo_model, output_repo = _resolve_config(parser, args)

      # Local file names use the repo name with only a trailing "-gguf" removed.
      model_name = output_repo.split("/")[-1].removesuffix("-gguf")
      print(f"QMD GGUF Conversion: {model_name}")
      print("=" * 60)

      # Best-effort install of build tools for Colab/cloud images. Failures are
      # ignored here and surface later if a required tool is actually missing.
      print("\nInstalling build dependencies...")
      try:
          subprocess.run(["apt-get", "update", "-qq"], capture_output=True)
          subprocess.run(["apt-get", "install", "-y", "-qq", "build-essential", "cmake", "git"],
                         capture_output=True)
      except FileNotFoundError:
          print(" apt-get not available; assuming build tools are already installed")

      hf_token = os.environ.get("HF_TOKEN")
      if hf_token:
          print("Logging in to HuggingFace...")
          login(token=hf_token)

      # Steps 1-4: load, merge, and save the merged checkpoint.
      merged_dir = "/tmp/merged_model"
      _merge_and_save(base_model, sft_model, grpo_model, merged_dir)

      # Step 5: fetch the llama.cpp converter.
      llama_dir = "/tmp/llama.cpp"
      _setup_llama_cpp(llama_dir)

      # Step 6: convert to FP16 GGUF; nothing later can run without it.
      gguf_dir = "/tmp/gguf_output"
      os.makedirs(gguf_dir, exist_ok=True)
      gguf_file = f"{gguf_dir}/{model_name}-f16.gguf"
      print("\nStep 6: Converting to FP16 GGUF...")
      if not run_cmd([sys.executable, os.path.join(llama_dir, "convert_hf_to_gguf.py"),
                      merged_dir, "--outfile", gguf_file, "--outtype", "f16"],
                     "Converting"):
          sys.exit(1)
      size_mb = os.path.getsize(gguf_file) / (1024 * 1024)
      print(f" FP16: {size_mb:.1f} MB")

      # Steps 7-8: quantization is best effort; the FP16 file is still uploaded.
      quantized_files = []
      if not args.skip_quantize:
          quantize_bin = _build_quantizer(llama_dir)
          if quantize_bin:
              quantized_files = _quantize(quantize_bin, gguf_file, gguf_dir, model_name)

      # Step 9: upload, unless disabled.
      if args.no_upload:
          print(f"\nDone! Files in {gguf_dir}")
      else:
          readme = _build_readme(model_name, base_model, sft_model, grpo_model)
          _upload(output_repo, model_name, gguf_file, quantized_files, readme)
          print(f"\nDone! Repository: https://huggingface.co/{output_repo}")


  if __name__ == "__main__":
      main()