setup_ai_env.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. # encoding: utf-8
  2. from __future__ import unicode_literals
  3. from .service import BaseService, NodesConfig
  4. from .parser import inject_ssh_hosts_options, help_d
  5. class AIEnvService(BaseService):
  6. def __init__(self, subparsers, action):
  7. super(AIEnvService, self).__init__(subparsers,
  8. action, "%s services of hosts" % action)
  9. def inject_options(self, parser):
  10. inject_ssh_hosts_options(parser)
  11. # NVIDIA driver and CUDA installer full paths
  12. parser.add_argument("--nvidia-driver-installer-path",
  13. dest="nvidia_driver_installer_path",
  14. required=True,
  15. help="Full path to NVIDIA driver installer (e.g., /root/nvidia/NVIDIA-Linux-x86_64-570.133.07.run)")
  16. parser.add_argument("--cuda-installer-path",
  17. dest="cuda_installer_path",
  18. required=True,
  19. help="Full path to CUDA installer (e.g., /root/nvidia/cuda_12.8.1_570.124.06_linux.run)")
  20. parser.add_argument("--gpu-device-virtual-number",
  21. dest="gpu_device_virtual_number",
  22. type=int,
  23. default=2,
  24. help=help_d("Virtual number for NVIDIA GPU share device (default: 2)"))
  25. def do_action(self, args):
  26. config = NodesConfig(args.target_node_hosts,
  27. args.ssh_user,
  28. args.ssh_private_file,
  29. args.ssh_port)
  30. # Prepare ansible variables from command line arguments (using full paths)
  31. vars = {
  32. 'nvidia_driver_installer_path': args.nvidia_driver_installer_path,
  33. 'cuda_installer_path': args.cuda_installer_path,
  34. 'gpu_device_virtual_number': args.gpu_device_virtual_number,
  35. # Add SSH configuration for rsync commands
  36. 'ansible_ssh_private_key_file': args.ssh_private_file,
  37. }
  38. return config.run(self.action, vars=vars)
  39. def add_command(subparsers):
  40. AIEnvService(subparsers, 'setup-ai-env')