python3Packages.triton: fix cuda (ptxas, cudart paths)

This commit is contained in:
SomeoneSerge 2024-10-14 17:27:10 +00:00
parent e262792bf1
commit ae560061d8
5 changed files with 258 additions and 49 deletions

View File

@ -0,0 +1,35 @@
From 2751c5de5c61c90b56e3e392a41847f4c47258fd Mon Sep 17 00:00:00 2001
From: SomeoneSerge <else+aalto@someonex.net>
Date: Sun, 13 Oct 2024 14:16:48 +0000
Subject: [PATCH 1/3] _build: allow extra cc flags
---
python/triton/runtime/build.py | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py
index d7baeb286..d334dce77 100644
--- a/python/triton/runtime/build.py
+++ b/python/triton/runtime/build.py
@@ -42,9 +42,17 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
include_dirs = include_dirs + [srcdir, py_include_dir]
cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so]
+
+ # Nixpkgs support branch
+ # Allows passing e.g. extra -Wl,-rpath
+ cc_cmd_extra_flags = "@ccCmdExtraFlags@"
+ if cc_cmd_extra_flags != ("@" + "ccCmdExtraFlags@"): # substituteAll hack
+ import shlex
+ cc_cmd.extend(shlex.split(cc_cmd_extra_flags))
+
cc_cmd += [f'-l{lib}' for lib in libraries]
cc_cmd += [f"-L{dir}" for dir in library_dirs]
- cc_cmd += [f"-I{dir}" for dir in include_dirs]
+ cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so
--
2.46.0

View File

@ -0,0 +1,70 @@
From 7407cb03eec82768e333909d87b7668b633bfe86 Mon Sep 17 00:00:00 2001
From: SomeoneSerge <else+aalto@someonex.net>
Date: Sun, 13 Oct 2024 14:28:48 +0000
Subject: [PATCH 2/3] {nvidia,amd}/driver: short-circuit before ldconfig
---
python/triton/runtime/build.py | 6 +++---
third_party/amd/backend/driver.py | 7 +++++++
third_party/nvidia/backend/driver.py | 3 +++
3 files changed, 13 insertions(+), 3 deletions(-)
diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py
index d334dce77..a64e98da0 100644
--- a/python/triton/runtime/build.py
+++ b/python/triton/runtime/build.py
@@ -42,6 +42,9 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
include_dirs = include_dirs + [srcdir, py_include_dir]
cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so]
+ cc_cmd += [f'-l{lib}' for lib in libraries]
+ cc_cmd += [f"-L{dir}" for dir in library_dirs]
+ cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
# Nixpkgs support branch
# Allows passing e.g. extra -Wl,-rpath
@@ -50,9 +53,6 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
import shlex
cc_cmd.extend(shlex.split(cc_cmd_extra_flags))
- cc_cmd += [f'-l{lib}' for lib in libraries]
- cc_cmd += [f"-L{dir}" for dir in library_dirs]
- cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so
diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py
index 0a8cd7bed..aab8805f6 100644
--- a/third_party/amd/backend/driver.py
+++ b/third_party/amd/backend/driver.py
@@ -24,6 +24,13 @@ def _get_path_to_hip_runtime_dylib():
return env_libhip_path
raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")
+ # ...on release/3.1.x:
+ # return mmapped_path
+ # raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}")
+
+ if os.path.isdir("@libhipDir@"):
+ return ["@libhipDir@"]
+
paths = []
import site
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index 90f71138b..30fbadb2a 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -21,6 +21,9 @@ def libcuda_dirs():
if env_libcuda_path:
return [env_libcuda_path]
+ if os.path.exists("@libcudaStubsDir@"):
+ return ["@libcudaStubsDir@"]
+
libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
# each line looks like the following:
# libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
--
2.46.0

View File

@ -0,0 +1,46 @@
From 6f92d54e5a544bc34bb07f2808d554a71cc0e4c3 Mon Sep 17 00:00:00 2001
From: SomeoneSerge <else+aalto@someonex.net>
Date: Sun, 13 Oct 2024 14:30:19 +0000
Subject: [PATCH 3/3] nvidia: cudart a systempath
---
third_party/nvidia/backend/driver.c | 2 +-
third_party/nvidia/backend/driver.py | 5 +++--
2 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c
index 44524da27..fbdf0d156 100644
--- a/third_party/nvidia/backend/driver.c
+++ b/third_party/nvidia/backend/driver.c
@@ -1,4 +1,4 @@
-#include "cuda.h"
+#include <cuda.h>
#include <dlfcn.h>
#include <stdbool.h>
#define PY_SSIZE_T_CLEAN
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index 30fbadb2a..65c0562ed 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -10,7 +10,8 @@ from triton.backends.compiler import GPUTarget
from triton.backends.driver import GPUDriver
dirname = os.path.dirname(os.path.realpath(__file__))
-include_dir = [os.path.join(dirname, "include")]
+import shlex
+include_dir = [*shlex.split("@cudaToolkitIncludeDirs@"), os.path.join(dirname, "include")]
libdevice_dir = os.path.join(dirname, "lib")
libraries = ['cuda']
@@ -149,7 +150,7 @@ def make_launcher(constants, signature, ids):
# generate glue code
params = [i for i in signature.keys() if i not in constants]
src = f"""
-#include \"cuda.h\"
+#include <cuda.h>
#include <stdbool.h>
#include <Python.h>
#include <dlfcn.h>
--
2.46.0

View File

@ -0,0 +1,26 @@
From e503e572b6d444cd27f1cdf124aaf553aa3a8665 Mon Sep 17 00:00:00 2001
From: SomeoneSerge <else+aalto@someonex.net>
Date: Mon, 14 Oct 2024 00:12:05 +0000
Subject: [PATCH 4/4] nvidia: allow static ptxas path
---
third_party/nvidia/backend/compiler.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
index 6d7994923..6720e8f97 100644
--- a/third_party/nvidia/backend/compiler.py
+++ b/third_party/nvidia/backend/compiler.py
@@ -20,6 +20,9 @@ def _path_to_binary(binary: str):
os.path.join(os.path.dirname(__file__), "bin", binary),
]
+ import shlex
+ paths.extend(shlex.split("@nixpkgsExtraBinaryPaths@"))
+
for bin in paths:
if os.path.exists(bin) and os.path.isfile(bin):
result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
--
2.46.0

View File

@ -1,5 +1,6 @@
{
lib,
addDriverRunpath,
buildPythonPackage,
cmake,
config,
@ -15,10 +16,13 @@
pybind11,
python,
runCommand,
substituteAll,
setuptools,
torchWithRocm,
zlib,
cudaSupport ? config.cudaSupport,
rocmSupport ? config.rocmSupport,
rocmPackages,
}:
buildPythonPackage {
@ -34,29 +38,53 @@ buildPythonPackage {
hash = "sha256-L5KqiR+TgSyKjEBlkE0yOU1pemMHFk2PhEmxLdbbxUU=";
};
patches = [
./0001-setup.py-introduce-TRITON_OFFLINE_BUILD.patch
];
patches =
[
./0001-setup.py-introduce-TRITON_OFFLINE_BUILD.patch
(substituteAll {
src = ./0001-_build-allow-extra-cc-flags.patch;
ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
})
(substituteAll (
{
src = ./0002-nvidia-amd-driver-short-circuit-before-ldconfig.patch;
}
// lib.optionalAttrs rocmSupport { libhipDir = "${lib.getLib rocmPackages.clr}/lib"; }
// lib.optionalAttrs cudaSupport {
libcudaStubsDir = "${lib.getLib cudaPackages.cuda_cudart}/lib/stubs";
ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
}
))
]
++ lib.optionals cudaSupport [
(substituteAll {
src = ./0003-nvidia-cudart-a-systempath.patch;
cudaToolkitIncludeDirs = "${lib.getInclude cudaPackages.cuda_cudart}/include";
})
(substituteAll {
src = ./0004-nvidia-allow-static-ptxas-path.patch;
nixpkgsExtraBinaryPaths = lib.escapeShellArgs [ (lib.getExe' cudaPackages.cuda_nvcc "ptxas") ];
})
];
postPatch =
''
# Use our `cmakeFlags` instead and avoid downloading dependencies
# remove any downloads
substituteInPlace python/setup.py \
--replace-fail "get_json_package_info(), get_pybind11_package_info()" ""\
--replace-fail "get_pybind11_package_info(), get_llvm_package_info()" ""\
--replace-fail 'packages += ["triton/profiler"]' ""\
--replace-fail "curr_version != version" "False"
postPatch = ''
# Use our `cmakeFlags` instead and avoid downloading dependencies
# remove any downloads
substituteInPlace python/setup.py \
--replace-fail "get_json_package_info(), get_pybind11_package_info()" ""\
--replace-fail "get_pybind11_package_info(), get_llvm_package_info()" ""\
--replace-fail 'packages += ["triton/profiler"]' ""\
--replace-fail "curr_version != version" "False"
# Don't fetch googletest
substituteInPlace unittest/CMakeLists.txt \
--replace-fail "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" ""\
--replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"
'';
# Don't fetch googletest
substituteInPlace unittest/CMakeLists.txt \
--replace-fail "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" ""\
--replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"
'';
build-system = [ setuptools ];
nativeBuildInputs = [
setuptools
# pytestCheckHook # Requires torch (circular dependency) and probably needs GPUs:
cmake
ninja
@ -76,7 +104,7 @@ buildPythonPackage {
zlib
];
propagatedBuildInputs = [
dependencies = [
filelock
# triton uses setuptools at runtime:
# https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
@ -106,26 +134,40 @@ buildPythonPackage {
cd python
'';
env = {
TRITON_BUILD_PROTON = "OFF";
TRITON_OFFLINE_BUILD = true;
} // lib.optionalAttrs cudaSupport {
CC = "${cudaPackages.backendStdenv.cc}/bin/cc";
CXX = "${cudaPackages.backendStdenv.cc}/bin/c++";
env =
{
TRITON_BUILD_PROTON = "OFF";
TRITON_OFFLINE_BUILD = true;
}
// lib.optionalAttrs cudaSupport {
CC = lib.getExe' cudaPackages.backendStdenv.cc "cc";
CXX = lib.getExe' cudaPackages.backendStdenv.cc "c++";
TRITON_PTXAS_PATH = lib.getExe' cudaPackages.cuda_nvcc "ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
TRITON_CUOBJDUMP_PATH = cudaPackages.cuda_cuobjdump;
TRITON_NVDISASM_PATH = cudaPackages.cuda_nvdisasm;
TRITON_CUDACRT_PATH = cudaPackages.cuda_nvcc;
TRITON_CUDART_PATH = cudaPackages.cuda_cudart;
TRITON_CUPTI_PATH = cudaPackages.cuda_cupti;
};
# TODO: Unused because of how TRITON_OFFLINE_BUILD currently works (subject to change)
TRITON_PTXAS_PATH = lib.getExe' cudaPackages.cuda_nvcc "ptxas"; # Make sure cudaPackages is the right version on each update (see python/setup.py)
TRITON_CUOBJDUMP_PATH = lib.getExe' cudaPackages.cuda_cuobjdump "cuobjdump";
TRITON_NVDISASM_PATH = lib.getExe' cudaPackages.cuda_nvdisasm "nvdisasm";
TRITON_CUDACRT_PATH = lib.getInclude cudaPackages.cuda_nvcc;
TRITON_CUDART_PATH = lib.getInclude cudaPackages.cuda_cudart;
TRITON_CUPTI_PATH = cudaPackages.cuda_cupti;
};
pythonRemoveDeps = [
# Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
"torch"
# CLI tools without dist-info
"cmake"
"lit"
];
# CMake is run by setup.py instead
dontUseCmakeConfigure = true;
checkInputs = [ cmake ]; # ctest
dontUseSetuptoolsCheck = true;
nativeCheckInputs = [
cmake
# Requires torch (circular dependency) and GPU access: pytestCheckHook
];
preCheck = ''
# build/temp* refers to build_ext.build_temp (looked up in the build logs)
(cd ./build/temp* ; ctest)
@ -134,11 +176,10 @@ buildPythonPackage {
cd test/unit
'';
# Circular dependency on torch
# pythonImportsCheck = [
# "triton"
# "triton.language"
# ];
pythonImportsCheck = [
"triton"
"triton.language"
];
# Ultimately, torch is our test suite:
passthru.tests = {
@ -157,15 +198,6 @@ buildPythonPackage {
'';
};
pythonRemoveDeps = [
# Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
"torch"
# CLI tools without dist-info
"cmake"
"lit"
];
meta = with lib; {
description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
homepage = "https://github.com/triton-lang/triton";