python3Packages.openai-triton: 2.0.0 -> 2.1.0

This commit is contained in:
Thiago Franco de Moraes 2023-12-05 18:02:45 -03:00
parent 9f9f0356bc
commit 2b25625c1f
No known key found for this signature in database
GPG Key ID: 1B96996EE6559B7A
6 changed files with 39 additions and 76 deletions

View File

@ -1,39 +1,15 @@
diff --git a/python/setup.py b/python/setup.py diff --git a/python/setup.py b/python/setup.py
index 2ac3accd2..f26161c72 100644 index 18764ec13..b3bb5b60a 100644
--- a/python/setup.py --- a/python/setup.py
+++ b/python/setup.py +++ b/python/setup.py
@@ -101,25 +101,6 @@ def get_thirdparty_packages(triton_cache_path): @@ -269,10 +269,6 @@ class CMakeBuild(build_ext):
# ---- package data --- subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env)
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir)
-def download_and_copy_ptxas():
- base_dir = os.path.dirname(__file__)
- src_path = "bin/ptxas"
- url = "https://conda.anaconda.org/nvidia/label/cuda-12.0.0/linux-64/cuda-nvcc-12.0.76-0.tar.bz2"
- dst_prefix = os.path.join(base_dir, "triton")
- dst_suffix = os.path.join("third_party", "cuda", src_path)
- dst_path = os.path.join(dst_prefix, dst_suffix)
- if not os.path.exists(dst_path):
- print(f'downloading and extracting {url} ...')
- ftpstream = urllib.request.urlopen(url)
- file = tarfile.open(fileobj=ftpstream, mode="r|*")
- with tempfile.TemporaryDirectory() as temp_dir:
- file.extractall(path=temp_dir)
- src_path = os.path.join(temp_dir, src_path)
- os.makedirs(os.path.split(dst_path)[0], exist_ok=True)
- shutil.copy(src_path, dst_path)
- return dst_suffix
- -
-
# ---- cmake extension ----
@@ -200,8 +181,6 @@ class CMakeBuild(build_ext):
subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=self.build_temp)
-download_and_copy_ptxas() -download_and_copy_ptxas()
-
- -
setup( setup(
name="triton", name="triton",
version="2.0.0", version="2.1.0",

View File

@ -18,7 +18,7 @@
buildPythonPackage rec { buildPythonPackage rec {
pname = "triton"; pname = "triton";
version = "2.0.0"; version = "2.1.0";
format = "wheel"; format = "wheel";
src = src =
@ -62,7 +62,7 @@ buildPythonPackage rec {
newStr = lib.concatMapStringsSep ", " quote new; newStr = lib.concatMapStringsSep ", " quote new;
in in
'' ''
substituteInPlace $out/${python.sitePackages}/triton/compiler.py \ substituteInPlace $out/${python.sitePackages}/triton/common/build.py \
--replace '${oldStr}' '${newStr}' --replace '${oldStr}' '${newStr}'
''); '');

View File

@ -6,26 +6,26 @@
# To add a new version, run "prefetch.sh 'new-version'" to paste the generated file as follows. # To add a new version, run "prefetch.sh 'new-version'" to paste the generated file as follows.
version : builtins.getAttr version { version : builtins.getAttr version {
"2.0.0" = { "2.1.0" = {
x86_64-linux-38 = { x86_64-linux-38 = {
name = "triton-2.0.0-1-cp38-cp38-linux_x86_64.whl"; name = "triton-2.1.0-cp38-cp38-linux_x86_64.whl";
url = "https://download.pytorch.org/whl/triton-2.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"; url = "https://download.pytorch.org/whl/triton-2.1.0-0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl";
hash = "sha256-nUl4KYt0/PWadf5x5TXAkrAjCIkzsvHfkz7DJhXkvu8="; hash = "sha256-Ofb7a9zLPpjzFS4/vqck8a6ufXSUErux+pxEHUdOuiY=";
}; };
x86_64-linux-39 = { x86_64-linux-39 = {
name = "triton-2.0.0-1-cp39-cp39-linux_x86_64.whl"; name = "triton-2.1.0-cp39-cp39-linux_x86_64.whl";
url = "https://download.pytorch.org/whl/triton-2.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"; url = "https://download.pytorch.org/whl/triton-2.1.0-0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl";
hash = "sha256-dPEYwStDf7LKJeGgR1kXO1F1gvz0x74RkTMWx2QhNlY="; hash = "sha256-IVROUiwCAFpibIrWPTm9/y8x1BBpWSkZ7ygelk7SZEY=";
}; };
x86_64-linux-310 = { x86_64-linux-310 = {
name = "triton-2.0.0-1-cp310-cp310-linux_x86_64.whl"; name = "triton-2.1.0-cp310-cp310-linux_x86_64.whl";
url = "https://download.pytorch.org/whl/triton-2.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"; url = "https://download.pytorch.org/whl/triton-2.1.0-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl";
hash = "sha256-OIBu6WY/Sw981keQ6WxXk3QInlj0mqxKZggSGqVeJQU="; hash = "sha256-ZkOZI6MNXUg5mwip6uEDcPbCYaXshkpkmDuuYxUtOdc=";
}; };
x86_64-linux-311 = { x86_64-linux-311 = {
name = "triton-2.0.0-1-cp311-cp311-linux_x86_64.whl"; name = "triton-2.1.0-cp311-cp311-linux_x86_64.whl";
url = "https://download.pytorch.org/whl/triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl"; url = "https://download.pytorch.org/whl/triton-2.1.0-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl";
hash = "sha256-ImlBx7hZUhnd71mh/bgh6MdEKJoTJBXd1YT6zt60dbE="; hash = "sha256-kZsGRT8AM+pSwT6veDPeDlfbMXjSPU4E+fxxxPLDK/g=";
}; };
}; };
} }

View File

@ -2,10 +2,11 @@
, config , config
, buildPythonPackage , buildPythonPackage
, fetchFromGitHub , fetchFromGitHub
, fetchpatch
, addOpenGLRunpath , addOpenGLRunpath
, setuptools
, pytestCheckHook , pytestCheckHook
, pythonRelaxDepsHook , pythonRelaxDepsHook
, pkgsTargetTarget
, cmake , cmake
, ninja , ninja
, pybind11 , pybind11
@ -23,46 +24,32 @@
}: }:
let let
# A time may come we'll want to be cross-friendly ptxas = "${cudaPackages.cuda_nvcc}/bin/ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
#
# Short explanation: we need pkgsTargetTarget, because we use string
# interpolation instead of buildInputs.
#
# Long explanation: OpenAI/triton downloads and vendors a copy of NVidia's
# ptxas compiler. We're not running this ptxas on the build machine, but on
# the user's machine, i.e. our Target platform. The second "Target" in
# pkgsTargetTarget maybe doesn't matter, because ptxas compiles programs to
# be executed on the GPU.
# Cf. https://nixos.org/manual/nixpkgs/unstable/#sec-cross-infra
ptxas = "${pkgsTargetTarget.cudaPackages.cuda_nvcc}/bin/ptxas"; # Make sure cudaPackages is the right version each update (See python/setup.py)
in in
buildPythonPackage rec { buildPythonPackage rec {
pname = "triton"; pname = "triton";
version = "2.0.0"; version = "2.1.0";
format = "setuptools"; pyproject = true;
src = fetchFromGitHub { src = fetchFromGitHub {
owner = "openai"; owner = "openai";
repo = pname; repo = pname;
rev = "v${version}"; rev = "v${version}";
hash = "sha256-9GZzugab+Pdt74Dj6zjlEzjj4BcJ69rzMJmqcVMxsKU="; hash = "sha256-8UTUwLH+SriiJnpejdrzz9qIquP2zBp1/uwLdHmv0XQ=";
}; };
patches = [ patches = [
# TODO: there have been commits upstream aimed at removing the "torch" # fix overflow error
# circular dependency, but the patches fail to apply on the release (fetchpatch {
# revision. Keeping the link for future reference url = "https://github.com/openai/triton/commit/52c146f66b79b6079bcd28c55312fc6ea1852519.patch";
# Also cf. https://github.com/openai/triton/issues/1374 hash = "sha256-098/TCQrzvrBAbQiaVGCMaF3o5Yc3yWDxzwSkzIuAtY=";
})
# (fetchpatch {
# url = "https://github.com/openai/triton/commit/fc7c0b0e437a191e421faa61494b2ff4870850f1.patch";
# hash = "sha256-f0shIqHJkVvuil2Yku7vuqWFn7VCRKFSFjYRlwx25ig=";
# })
] ++ lib.optionals (!cudaSupport) [ ] ++ lib.optionals (!cudaSupport) [
./0000-dont-download-ptxas.patch ./0000-dont-download-ptxas.patch
]; ];
nativeBuildInputs = [ nativeBuildInputs = [
setuptools
pythonRelaxDepsHook pythonRelaxDepsHook
# pytestCheckHook # Requires torch (circular dependency) and probably needs GPUs: # pytestCheckHook # Requires torch (circular dependency) and probably needs GPUs:
cmake cmake
@ -111,7 +98,7 @@ buildPythonPackage rec {
--replace "include(GoogleTest)" "find_package(GTest REQUIRED)" --replace "include(GoogleTest)" "find_package(GTest REQUIRED)"
'' + lib.optionalString cudaSupport '' '' + lib.optionalString cudaSupport ''
# Use our linker flags # Use our linker flags
substituteInPlace python/triton/compiler.py \ substituteInPlace python/triton/common/build.py \
--replace '${oldStr}' '${newStr}' --replace '${oldStr}' '${newStr}'
''; '';

View File

@ -8,10 +8,10 @@ version=$1
linux_bucket="https://download.pytorch.org/whl" linux_bucket="https://download.pytorch.org/whl"
url_and_key_list=( url_and_key_list=(
"x86_64-linux-38 $linux_bucket/triton-${version}-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp38-cp38-linux_x86_64.whl" "x86_64-linux-38 $linux_bucket/triton-${version}-0-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp38-cp38-linux_x86_64.whl"
"x86_64-linux-39 $linux_bucket/triton-${version}-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp39-cp39-linux_x86_64.whl" "x86_64-linux-39 $linux_bucket/triton-${version}-0-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp39-cp39-linux_x86_64.whl"
"x86_64-linux-310 $linux_bucket/triton-${version}-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp310-cp310-linux_x86_64.whl" "x86_64-linux-310 $linux_bucket/triton-${version}-0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp310-cp310-linux_x86_64.whl"
"x86_64-linux-311 $linux_bucket/triton-${version}-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp311-cp311-linux_x86_64.whl" "x86_64-linux-311 $linux_bucket/triton-${version}-0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl triton-${version}-cp311-cp311-linux_x86_64.whl"
) )
hashfile=binary-hashes-"$version".nix hashfile=binary-hashes-"$version".nix

View File

@ -8566,7 +8566,7 @@ self: super: with self; {
openai-triton = callPackage ../development/python-modules/openai-triton { openai-triton = callPackage ../development/python-modules/openai-triton {
llvm = pkgs.openai-triton-llvm; llvm = pkgs.openai-triton-llvm;
cudaPackages = pkgs.cudaPackages_12_0; cudaPackages = pkgs.cudaPackages_12_1;
}; };
openai-triton-cuda = self.openai-triton.override { openai-triton-cuda = self.openai-triton.override {