python3Packages.triton: fix cuda (ptxas, cudart paths)
This commit is contained in:
parent
e262792bf1
commit
ae560061d8
@ -0,0 +1,35 @@
|
||||
From 2751c5de5c61c90b56e3e392a41847f4c47258fd Mon Sep 17 00:00:00 2001
|
||||
From: SomeoneSerge <else+aalto@someonex.net>
|
||||
Date: Sun, 13 Oct 2024 14:16:48 +0000
|
||||
Subject: [PATCH 1/3] _build: allow extra cc flags
|
||||
|
||||
---
|
||||
python/triton/runtime/build.py | 10 +++++++++-
|
||||
1 file changed, 9 insertions(+), 1 deletion(-)
|
||||
|
||||
diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py
|
||||
index d7baeb286..d334dce77 100644
|
||||
--- a/python/triton/runtime/build.py
|
||||
+++ b/python/triton/runtime/build.py
|
||||
@@ -42,9 +42,17 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
|
||||
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
|
||||
include_dirs = include_dirs + [srcdir, py_include_dir]
|
||||
cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so]
|
||||
+
|
||||
+ # Nixpkgs support branch
|
||||
+ # Allows passing e.g. extra -Wl,-rpath
|
||||
+ cc_cmd_extra_flags = "@ccCmdExtraFlags@"
|
||||
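+ if cc_cmd_extra_flags != ("@" + "ccCmdExtraFlags@"): # placeholder is split in two so substituteAll cannot rewrite this sentinel; the test passes only once the flags were actually substituted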
|
||||
+ import shlex
|
||||
+ cc_cmd.extend(shlex.split(cc_cmd_extra_flags))
|
||||
+
|
||||
cc_cmd += [f'-l{lib}' for lib in libraries]
|
||||
cc_cmd += [f"-L{dir}" for dir in library_dirs]
|
||||
- cc_cmd += [f"-I{dir}" for dir in include_dirs]
|
||||
+ cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
|
||||
ret = subprocess.check_call(cc_cmd)
|
||||
if ret == 0:
|
||||
return so
|
||||
--
|
||||
2.46.0
|
||||
|
@ -0,0 +1,70 @@
|
||||
From 7407cb03eec82768e333909d87b7668b633bfe86 Mon Sep 17 00:00:00 2001
|
||||
From: SomeoneSerge <else+aalto@someonex.net>
|
||||
Date: Sun, 13 Oct 2024 14:28:48 +0000
|
||||
Subject: [PATCH 2/3] {nvidia,amd}/driver: short-circuit before ldconfig
|
||||
|
||||
---
|
||||
python/triton/runtime/build.py | 6 +++---
|
||||
third_party/amd/backend/driver.py | 7 +++++++
|
||||
third_party/nvidia/backend/driver.py | 3 +++
|
||||
3 files changed, 13 insertions(+), 3 deletions(-)
|
||||
|
||||
diff --git a/python/triton/runtime/build.py b/python/triton/runtime/build.py
|
||||
index d334dce77..a64e98da0 100644
|
||||
--- a/python/triton/runtime/build.py
|
||||
+++ b/python/triton/runtime/build.py
|
||||
@@ -42,6 +42,9 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
|
||||
py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
|
||||
include_dirs = include_dirs + [srcdir, py_include_dir]
|
||||
cc_cmd = [cc, src, "-O3", "-shared", "-fPIC", "-o", so]
|
||||
+ cc_cmd += [f'-l{lib}' for lib in libraries]
|
||||
+ cc_cmd += [f"-L{dir}" for dir in library_dirs]
|
||||
+ cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
|
||||
|
||||
# Nixpkgs support branch
|
||||
# Allows passing e.g. extra -Wl,-rpath
|
||||
@@ -50,9 +53,6 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
|
||||
import shlex
|
||||
cc_cmd.extend(shlex.split(cc_cmd_extra_flags))
|
||||
|
||||
- cc_cmd += [f'-l{lib}' for lib in libraries]
|
||||
- cc_cmd += [f"-L{dir}" for dir in library_dirs]
|
||||
- cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
|
||||
ret = subprocess.check_call(cc_cmd)
|
||||
if ret == 0:
|
||||
return so
|
||||
diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py
|
||||
index 0a8cd7bed..aab8805f6 100644
|
||||
--- a/third_party/amd/backend/driver.py
|
||||
+++ b/third_party/amd/backend/driver.py
|
||||
@@ -24,6 +24,13 @@ def _get_path_to_hip_runtime_dylib():
|
||||
return env_libhip_path
|
||||
raise RuntimeError(f"TRITON_LIBHIP_PATH '{env_libhip_path}' does not point to a valid {lib_name}")
|
||||
|
||||
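+ # NOTE(review): on release/3.1.x the code just above this point appears to read as follows instead (confirm when rebasing):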
|
||||
+ # return mmapped_path
|
||||
+ # raise RuntimeError(f"memory mapped '{mmapped_path}' in process does not point to a valid {lib_name}")
|
||||
+
|
||||
+ if os.path.isdir("@libhipDir@"):
|
||||
+ return ["@libhipDir@"]
|
||||
+
|
||||
paths = []
|
||||
|
||||
import site
|
||||
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
|
||||
index 90f71138b..30fbadb2a 100644
|
||||
--- a/third_party/nvidia/backend/driver.py
|
||||
+++ b/third_party/nvidia/backend/driver.py
|
||||
@@ -21,6 +21,9 @@ def libcuda_dirs():
|
||||
if env_libcuda_path:
|
||||
return [env_libcuda_path]
|
||||
|
||||
+ if os.path.exists("@libcudaStubsDir@"):
|
||||
+ return ["@libcudaStubsDir@"]
|
||||
+
|
||||
libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode()
|
||||
# each line looks like the following:
|
||||
# libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1
|
||||
--
|
||||
2.46.0
|
||||
|
@ -0,0 +1,46 @@
|
||||
From 6f92d54e5a544bc34bb07f2808d554a71cc0e4c3 Mon Sep 17 00:00:00 2001
|
||||
From: SomeoneSerge <else+aalto@someonex.net>
|
||||
Date: Sun, 13 Oct 2024 14:30:19 +0000
|
||||
Subject: [PATCH 3/3] nvidia: cudart a systempath
|
||||
|
||||
---
|
||||
third_party/nvidia/backend/driver.c | 2 +-
|
||||
third_party/nvidia/backend/driver.py | 5 +++--
|
||||
2 files changed, 4 insertions(+), 3 deletions(-)
|
||||
|
||||
diff --git a/third_party/nvidia/backend/driver.c b/third_party/nvidia/backend/driver.c
|
||||
index 44524da27..fbdf0d156 100644
|
||||
--- a/third_party/nvidia/backend/driver.c
|
||||
+++ b/third_party/nvidia/backend/driver.c
|
||||
@@ -1,4 +1,4 @@
|
||||
-#include "cuda.h"
|
||||
+#include <cuda.h>
|
||||
#include <dlfcn.h>
|
||||
#include <stdbool.h>
|
||||
#define PY_SSIZE_T_CLEAN
|
||||
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
|
||||
index 30fbadb2a..65c0562ed 100644
|
||||
--- a/third_party/nvidia/backend/driver.py
|
||||
+++ b/third_party/nvidia/backend/driver.py
|
||||
@@ -10,7 +10,8 @@ from triton.backends.compiler import GPUTarget
|
||||
from triton.backends.driver import GPUDriver
|
||||
|
||||
dirname = os.path.dirname(os.path.realpath(__file__))
|
||||
-include_dir = [os.path.join(dirname, "include")]
|
||||
+import shlex
|
||||
+include_dir = [*shlex.split("@cudaToolkitIncludeDirs@"), os.path.join(dirname, "include")]
|
||||
libdevice_dir = os.path.join(dirname, "lib")
|
||||
libraries = ['cuda']
|
||||
|
||||
@@ -149,7 +150,7 @@ def make_launcher(constants, signature, ids):
|
||||
# generate glue code
|
||||
params = [i for i in signature.keys() if i not in constants]
|
||||
src = f"""
|
||||
-#include \"cuda.h\"
|
||||
+#include <cuda.h>
|
||||
#include <stdbool.h>
|
||||
#include <Python.h>
|
||||
#include <dlfcn.h>
|
||||
--
|
||||
2.46.0
|
||||
|
@ -0,0 +1,26 @@
|
||||
From e503e572b6d444cd27f1cdf124aaf553aa3a8665 Mon Sep 17 00:00:00 2001
|
||||
From: SomeoneSerge <else+aalto@someonex.net>
|
||||
Date: Mon, 14 Oct 2024 00:12:05 +0000
|
||||
Subject: [PATCH 4/4] nvidia: allow static ptxas path
|
||||
|
||||
---
|
||||
third_party/nvidia/backend/compiler.py | 3 +++
|
||||
1 file changed, 3 insertions(+)
|
||||
|
||||
diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py
|
||||
index 6d7994923..6720e8f97 100644
|
||||
--- a/third_party/nvidia/backend/compiler.py
|
||||
+++ b/third_party/nvidia/backend/compiler.py
|
||||
@@ -20,6 +20,9 @@ def _path_to_binary(binary: str):
|
||||
os.path.join(os.path.dirname(__file__), "bin", binary),
|
||||
]
|
||||
|
||||
+ import shlex
|
||||
+ paths.extend(shlex.split("@nixpkgsExtraBinaryPaths@"))
|
||||
+
|
||||
for bin in paths:
|
||||
if os.path.exists(bin) and os.path.isfile(bin):
|
||||
result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT)
|
||||
--
|
||||
2.46.0
|
||||
|
@ -1,5 +1,6 @@
|
||||
{
|
||||
lib,
|
||||
addDriverRunpath,
|
||||
buildPythonPackage,
|
||||
cmake,
|
||||
config,
|
||||
@ -15,10 +16,13 @@
|
||||
pybind11,
|
||||
python,
|
||||
runCommand,
|
||||
substituteAll,
|
||||
setuptools,
|
||||
torchWithRocm,
|
||||
zlib,
|
||||
cudaSupport ? config.cudaSupport,
|
||||
rocmSupport ? config.rocmSupport,
|
||||
rocmPackages,
|
||||
}:
|
||||
|
||||
buildPythonPackage {
|
||||
@ -34,29 +38,53 @@ buildPythonPackage {
|
||||
hash = "sha256-L5KqiR+TgSyKjEBlkE0yOU1pemMHFk2PhEmxLdbbxUU=";
|
||||
};
|
||||
|
||||
patches = [
|
||||
./0001-setup.py-introduce-TRITON_OFFLINE_BUILD.patch
|
||||
];
|
||||
patches =
|
||||
[
|
||||
./0001-setup.py-introduce-TRITON_OFFLINE_BUILD.patch
|
||||
(substituteAll {
|
||||
src = ./0001-_build-allow-extra-cc-flags.patch;
|
||||
ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
|
||||
})
|
||||
(substituteAll (
|
||||
{
|
||||
src = ./0002-nvidia-amd-driver-short-circuit-before-ldconfig.patch;
|
||||
}
|
||||
// lib.optionalAttrs rocmSupport { libhipDir = "${lib.getLib rocmPackages.clr}/lib"; }
|
||||
// lib.optionalAttrs cudaSupport {
|
||||
libcudaStubsDir = "${lib.getLib cudaPackages.cuda_cudart}/lib/stubs";
|
||||
ccCmdExtraFlags = "-Wl,-rpath,${addDriverRunpath.driverLink}/lib";
|
||||
}
|
||||
))
|
||||
]
|
||||
++ lib.optionals cudaSupport [
|
||||
(substituteAll {
|
||||
src = ./0003-nvidia-cudart-a-systempath.patch;
|
||||
cudaToolkitIncludeDirs = "${lib.getInclude cudaPackages.cuda_cudart}/include";
|
||||
})
|
||||
(substituteAll {
|
||||
src = ./0004-nvidia-allow-static-ptxas-path.patch;
|
||||
nixpkgsExtraBinaryPaths = lib.escapeShellArgs [ (lib.getExe' cudaPackages.cuda_nvcc "ptxas") ];
|
||||
})
|
||||
];
|
||||
|
||||
postPatch =
|
||||
''
|
||||
# Use our `cmakeFlags` instead and avoid downloading dependencies
|
||||
# remove any downloads
|
||||
substituteInPlace python/setup.py \
|
||||
--replace-fail "get_json_package_info(), get_pybind11_package_info()" ""\
|
||||
--replace-fail "get_pybind11_package_info(), get_llvm_package_info()" ""\
|
||||
--replace-fail 'packages += ["triton/profiler"]' ""\
|
||||
--replace-fail "curr_version != version" "False"
|
||||
postPatch = ''
|
||||
# Use our `cmakeFlags` instead and avoid downloading dependencies
|
||||
# remove any downloads
|
||||
substituteInPlace python/setup.py \
|
||||
--replace-fail "get_json_package_info(), get_pybind11_package_info()" ""\
|
||||
--replace-fail "get_pybind11_package_info(), get_llvm_package_info()" ""\
|
||||
--replace-fail 'packages += ["triton/profiler"]' ""\
|
||||
--replace-fail "curr_version != version" "False"
|
||||
|
||||
# Don't fetch googletest
|
||||
substituteInPlace unittest/CMakeLists.txt \
|
||||
--replace-fail "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" ""\
|
||||
--replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"
|
||||
'';
|
||||
# Don't fetch googletest
|
||||
substituteInPlace unittest/CMakeLists.txt \
|
||||
--replace-fail "include (\''${CMAKE_CURRENT_SOURCE_DIR}/googletest.cmake)" ""\
|
||||
--replace-fail "include(GoogleTest)" "find_package(GTest REQUIRED)"
|
||||
'';
|
||||
|
||||
build-system = [ setuptools ];
|
||||
|
||||
nativeBuildInputs = [
|
||||
setuptools
|
||||
# pytestCheckHook # Requires torch (circular dependency) and probably needs GPUs:
|
||||
cmake
|
||||
ninja
|
||||
|
||||
@ -76,7 +104,7 @@ buildPythonPackage {
|
||||
zlib
|
||||
];
|
||||
|
||||
propagatedBuildInputs = [
|
||||
dependencies = [
|
||||
filelock
|
||||
# triton uses setuptools at runtime:
|
||||
# https://github.com/NixOS/nixpkgs/pull/286763/#discussion_r1480392652
|
||||
@ -106,26 +134,40 @@ buildPythonPackage {
|
||||
cd python
|
||||
'';
|
||||
|
||||
env = {
|
||||
TRITON_BUILD_PROTON = "OFF";
|
||||
TRITON_OFFLINE_BUILD = true;
|
||||
} // lib.optionalAttrs cudaSupport {
|
||||
CC = "${cudaPackages.backendStdenv.cc}/bin/cc";
|
||||
CXX = "${cudaPackages.backendStdenv.cc}/bin/c++";
|
||||
env =
|
||||
{
|
||||
TRITON_BUILD_PROTON = "OFF";
|
||||
TRITON_OFFLINE_BUILD = true;
|
||||
}
|
||||
// lib.optionalAttrs cudaSupport {
|
||||
CC = lib.getExe' cudaPackages.backendStdenv.cc "cc";
|
||||
CXX = lib.getExe' cudaPackages.backendStdenv.cc "c++";
|
||||
|
||||
TRITON_PTXAS_PATH = lib.getExe' cudaPackages.cuda_nvcc "ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
|
||||
TRITON_CUOBJDUMP_PATH = cudaPackages.cuda_cuobjdump;
|
||||
TRITON_NVDISASM_PATH = cudaPackages.cuda_nvdisasm;
|
||||
TRITON_CUDACRT_PATH = cudaPackages.cuda_nvcc;
|
||||
TRITON_CUDART_PATH = cudaPackages.cuda_cudart;
|
||||
TRITON_CUPTI_PATH = cudaPackages.cuda_cupti;
|
||||
};
|
||||
# TODO: Unused because of how TRITON_OFFLINE_BUILD currently works (subject to change)
|
||||
TRITON_PTXAS_PATH = lib.getExe' cudaPackages.cuda_nvcc "ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
|
||||
TRITON_CUOBJDUMP_PATH = lib.getExe' cudaPackages.cuda_cuobjdump "cuobjdump";
|
||||
TRITON_NVDISASM_PATH = lib.getExe' cudaPackages.cuda_nvdisasm "nvdisasm";
|
||||
TRITON_CUDACRT_PATH = lib.getInclude cudaPackages.cuda_nvcc;
|
||||
TRITON_CUDART_PATH = lib.getInclude cudaPackages.cuda_cudart;
|
||||
TRITON_CUPTI_PATH = cudaPackages.cuda_cupti;
|
||||
};
|
||||
|
||||
pythonRemoveDeps = [
|
||||
# Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
|
||||
"torch"
|
||||
|
||||
# CLI tools without dist-info
|
||||
"cmake"
|
||||
"lit"
|
||||
];
|
||||
|
||||
# CMake is run by setup.py instead
|
||||
dontUseCmakeConfigure = true;
|
||||
checkInputs = [ cmake ]; # ctest
|
||||
dontUseSetuptoolsCheck = true;
|
||||
|
||||
nativeCheckInputs = [
|
||||
cmake
|
||||
# pytestCheckHook is left out: it requires torch (circular dependency) and GPU access
|
||||
];
|
||||
preCheck = ''
|
||||
# build/temp* refers to build_ext.build_temp (looked up in the build logs)
|
||||
(cd ./build/temp* ; ctest)
|
||||
@ -134,11 +176,10 @@ buildPythonPackage {
|
||||
cd test/unit
|
||||
'';
|
||||
|
||||
# Circular dependency on torch
|
||||
# pythonImportsCheck = [
|
||||
# "triton"
|
||||
# "triton.language"
|
||||
# ];
|
||||
pythonImportsCheck = [
|
||||
"triton"
|
||||
"triton.language"
|
||||
];
|
||||
|
||||
# Ultimately, torch is our test suite:
|
||||
passthru.tests = {
|
||||
@ -157,15 +198,6 @@ buildPythonPackage {
|
||||
'';
|
||||
};
|
||||
|
||||
pythonRemoveDeps = [
|
||||
# Circular dependency, cf. https://github.com/triton-lang/triton/issues/1374
|
||||
"torch"
|
||||
|
||||
# CLI tools without dist-info
|
||||
"cmake"
|
||||
"lit"
|
||||
];
|
||||
|
||||
meta = with lib; {
|
||||
description = "Language and compiler for writing highly efficient custom Deep-Learning primitives";
|
||||
homepage = "https://github.com/triton-lang/triton";
|
||||
|
Loading…
Reference in New Issue
Block a user