python3Packages.openai-triton: 2.0.0 -> 2.1.0

This commit is contained in:
Thiago Franco de Moraes 2023-12-05 18:02:45 -03:00
parent 9f9f0356bc
commit 2b25625c1f
No known key found for this signature in database
GPG Key ID: 1B96996EE6559B7A
6 changed files with 39 additions and 76 deletions

View File

@ -1,39 +1,15 @@
diff --git a/python/setup.py b/python/setup.py
index 2ac3accd2..f26161c72 100644
index 18764ec13..b3bb5b60a 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -101,25 +101,6 @@ def get_thirdparty_packages(triton_cache_path):
# ---- package data ---
@@ -269,10 +269,6 @@ class CMakeBuild(build_ext):
subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env)
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir)
-def download_and_copy_ptxas():
- base_dir = os.path.dirname(__file__)
- src_path = "bin/ptxas"
- url = "https://conda.anaconda.org/nvidia/label/cuda-12.0.0/linux-64/cuda-nvcc-12.0.76-0.tar.bz2"
- dst_prefix = os.path.join(base_dir, "triton")
- dst_suffix = os.path.join("third_party", "cuda", src_path)
- dst_path = os.path.join(dst_prefix, dst_suffix)
- if not os.path.exists(dst_path):
- print(f'downloading and extracting {url} ...')
- ftpstream = urllib.request.urlopen(url)
- file = tarfile.open(fileobj=ftpstream, mode="r|*")
- with tempfile.TemporaryDirectory() as temp_dir:
- file.extractall(path=temp_dir)
- src_path = os.path.join(temp_dir, src_path)
- os.makedirs(os.path.split(dst_path)[0], exist_ok=True)
- shutil.copy(src_path, dst_path)
- return dst_suffix
-
-
# ---- cmake extension ----
@@ -200,8 +181,6 @@ class CMakeBuild(build_ext):
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
-download_and_copy_ptxas()
-
-
setup(
name="triton",
version="2.0.0",
version="2.1.0",

View File

@ -18,7 +18,7 @@
buildPythonPackage rec {
pname = "triton";
version = "2.0.0";
version = "2.1.0";
format = "wheel";
src =
@ -62,7 +62,7 @@ buildPythonPackage rec {
newStr = lib.concatMapStringsSep ", " quote new;
in
''
substituteInPlace $out/${python.sitePackages}/triton/compiler.py \
substituteInPlace $out/${python.sitePackages}/triton/common/build.py \
--replace '${oldStr}' '${newStr}'
'');

View File

@ -6,26 +6,26 @@
# To add a new version, run "prefetch.sh 'new-version'" to paste the generated file as follows.
version : builtins.getAttr version {
"2.0.0" = {
"2.1.0" = {
x86_64-linux-38 = {
name = "triton-2.0.0-1-cp38-cp38-linux_x86_64.whl";
url = "https://download.pytorch.org/whl/triton-2.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl";
hash = "sha256-nUl4KYt0/PWadf5x5TXAkrAjCIkzsvHfkz7DJhXkvu8=";
name = "triton-2.1.0-cp38-cp38-linux_x86_64.whl";
url = "https://download.pytorch.org/whl/triton-2.1.0-0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl";
hash = "sha256-Ofb7a9zLPpjzFS4/vqck8a6ufXSUErux+pxEHUdOuiY=";
};
x86_64-linux-39 = {
name = "triton-2.0.0-1-cp39-cp39-linux_x86_64.whl";
url = "https://download.pytorch.org/whl/triton-2.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl";
hash = "sha256-dPEYwStDf7LKJeGgR1kXO1F1gvz0x74RkTMWx2QhNlY=";
name = "triton-2.1.0-cp39-cp39-linux_x86_64.whl";
url = "https://download.pytorch.org/whl/triton-2.1.0-0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl";
hash = "sha256-IVROUiwCAFpibIrWPTm9/y8x1BBpWSkZ7ygelk7SZEY=";
};
x86_64-linux-310 = {
name = "triton-2.0.0-1-cp310-cp310-linux_x86_64.whl";
url = "https://download.pytorch.org/whl/triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl";
hash = "sha256-OIBu6WY/Sw981keQ6WxXk3QInlj0mqxKZggSGqVeJQU=";
name = "triton-2.1.0-cp310-cp310-linux_x86_64.whl";
url = "https://download.pytorch.org/whl/triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl";
hash = "sha256-ZkOZI6MNXUg5mwip6uEDcPbCYaXshkpkmDuuYxUtOdc=";
};
x86_64-linux-311 = {
name = "triton-2.0.0-1-cp311-cp311-linux_x86_64.whl";
url = "https://download.pytorch.org/whl/triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl";
hash = "sha256-ImlBx7hZUhnd71mh/bgh6MdEKJoTJBXd1YT6zt60dbE=";
name = "triton-2.1.0-cp311-cp311-linux_x86_64.whl";
url = "https://download.pytorch.org/whl/triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl";
hash = "sha256-kZsGRT8AM+pSwT6veDPeDlfbMXjSPU4E+fxxxPLDK/g=";
};
};
}

View File

@ -2,10 +2,11 @@
, config
, buildPythonPackage
, fetchFromGitHub
, fetchpatch
, addOpenGLRunpath
, setuptools
, pytestCheckHook
, pythonRelaxDepsHook
, pkgsTargetTarget
, cmake
, ninja
, pybind11
@ -23,46 +24,32 @@
}:
let
# A time may come we'll want to be cross-friendly
#
# Short explanation: we need pkgsTargetTarget, because we use string
# interpolation instead of buildInputs.
#
# Long explanation: OpenAI/triton downloads and vendors a copy of NVidia's
# ptxas compiler. We're not running this ptxas on the build machine, but on
# the user's machine, i.e. our Target platform. The second "Target" in
# pkgsTargetTarget maybe doesn't matter, because ptxas compiles programs to
# be executed on the GPU.
# Cf. https://nixos.org/manual/nixpkgs/unstable/#sec-cross-infra
ptxas = "${pkgsTargetTarget.cudaPackages.cuda_nvcc}/bin/ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
ptxas = "${cudaPackages.cuda_nvcc}/bin/ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
in
buildPythonPackage rec {
pname = "triton";
version = "2.0.0";
format = "setuptools";
version = "2.1.0";
pyproject = true;
src = fetchFromGitHub {
owner = "openai";
repo = pname;
rev = "v${version}";
hash = "sha256-9GZzugab+Pdt74Dj6zjlEzjj4BcJ69rzMJmqcVMxsKU=";
hash = "sha256-8UTUwLH+SriiJnpejdrzz9qIquP2zBp1/uwLdHmv0XQ=";
};
patches = [
# TODO: there have been commits upstream aimed at removing the "torch"
# circular dependency, but the patches fail to apply on the release
# revision. Keeping the link for future reference
# Also cf. https://github.com/openai/triton/issues/1374
# (fetchpatch {
# url = "https://github.com/openai/triton/commit/fc7c0b0e437a191e421faa61494b2ff4870850f1.patch";
# hash = "sha256-f0shIqHJkVvuil2Yku7vuqWFn7VCRKFSFjYRlwx25ig=";
# })
# fix overflow error
(fetchpatch {
url = "https://github.com/openai/triton/commit/52c146f66b79b6079bcd28c55312fc6ea1852519.patch";
hash = "sha256-098/TCQrzvrBAbQiaVGCMaF3o5Yc3yWDxzwSkzIuAtY=";
})
] ++ lib.optionals (!cudaSupport) [
./0000-dont-download-ptxas.patch
];
nativeBuildInputs = [
setuptools
pythonRelaxDepsHook
# pytestCheckHook # Requires torch (circular dependency) and probably needs GPUs:
cmake
@ -111,7 +98,7 @@ buildPythonPackage rec {
--replace "include(GoogleTest)" "find_package(GTest REQUIRED)"
'' + lib.optionalString cudaSupport ''
# Use our linker flags
substituteInPlace python/triton/compiler.py \
substituteInPlace python/triton/common/build.py \
--replace '${oldStr}' '${newStr}'
'';

View File

@ -8,10 +8,10 @@ version=$1
linux_bucket="https://download.pytorch.org/whl"
url_and_key_list=(
"x86_64-linux-38 $linux_bucket/triton-${version}-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp38-cp38-linux_x86_64.whl"
"x86_64-linux-39 $linux_bucket/triton-${version}-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp39-cp39-linux_x86_64.whl"
"x86_64-linux-310 $linux_bucket/triton-${version}-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp310-cp310-linux_x86_64.whl"
"x86_64-linux-311 $linux_bucket/triton-${version}-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp311-cp311-linux_x86_64.whl"
"x86_64-linux-38 $linux_bucket/triton-${version}-0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp38-cp38-linux_x86_64.whl"
"x86_64-linux-39 $linux_bucket/triton-${version}-0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp39-cp39-linux_x86_64.whl"
"x86_64-linux-310 $linux_bucket/triton-${version}-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp310-cp310-linux_x86_64.whl"
"x86_64-linux-311 $linux_bucket/triton-${version}-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp311-cp311-linux_x86_64.whl"
)
hashfile=binary-hashes-"$version".nix

View File

@ -8566,7 +8566,7 @@ self: super: with self; {
openai-triton = callPackage ../development/python-modules/openai-triton {
llvm = pkgs.openai-triton-llvm;
cudaPackages = pkgs.cudaPackages_12_0;
cudaPackages = pkgs.cudaPackages_12_1;
};
openai-triton-cuda = self.openai-triton.override {