| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | #include "cutlass_preprocessors.h" |
| | #include "cuda_utils.h" |
| | #include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" |
| |
|
| | #include <vector> |
| |
|
| | namespace fastertransformer { |
| |
|
| | int get_bits_in_quant_type(QuantType quant_type) { |
| | switch (quant_type) { |
| | case QuantType::INT8_WEIGHT_ONLY: |
| | return 8; |
| | case QuantType::PACKED_INT4_WEIGHT_ONLY: |
| | return 4; |
| | default: |
| | return -1; |
| | } |
| | } |
| |
|
// Runtime description of how the weight matrix B must be laid out in memory
// for a given target architecture / weight type, derived from the CUTLASS
// kernel traits (see getLayoutDetailsForArchAndQuantType below).
struct LayoutDetails {
    enum class Layout {
        UNKNOWN,
        ROW_MAJOR,
        COLUMN_MAJOR
    };

    // Layout of the B operand.
    Layout layoutB = Layout::UNKNOWN;
    // Tiling factors for ColumnMajorTileInterleave layouts; both stay 1 for
    // plain row-major / column-major layouts.
    int rows_per_column_tile = 1;
    int columns_interleaved = 1;

    // True when the MMA operator is the interleaved dequantize-B variant;
    // preprocess_weights_for_mixed_gemm permutes B's rows when this is set.
    bool uses_imma_ldsm = false;
};
| |
|
// Maps a CUTLASS layout type to a runtime LayoutDetails record. Only the
// specializations below are usable; instantiating the primary template
// fails at compile time because it has no operator().
template<typename Layout>
struct getLayoutDetails {
};
| |
|
| | template<> |
| | struct getLayoutDetails<cutlass::layout::RowMajor> { |
| | LayoutDetails operator()() |
| | { |
| | LayoutDetails layout_details; |
| | layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR; |
| | return layout_details; |
| | } |
| | }; |
| |
|
| | template<> |
| | struct getLayoutDetails<cutlass::layout::ColumnMajor> { |
| | LayoutDetails operator()() |
| | { |
| | LayoutDetails layout_details; |
| | layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; |
| | return layout_details; |
| | } |
| | }; |
| |
|
| | template<int RowsPerTile, int ColumnsInterleaved> |
| | struct getLayoutDetails<cutlass::layout::ColumnMajorTileInterleave<RowsPerTile, ColumnsInterleaved>> { |
| | LayoutDetails operator()() |
| | { |
| | LayoutDetails layout_details; |
| | layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; |
| | layout_details.rows_per_column_tile = RowsPerTile; |
| | layout_details.columns_interleaved = ColumnsInterleaved; |
| | return layout_details; |
| | } |
| | }; |
| |
|
// Instantiates the CUTLASS mixed-GEMM B-operand traits for the given
// architecture / weight type and converts the compile-time layout choice
// into a runtime LayoutDetails record.
template<typename cutlassArch, typename TypeB>
LayoutDetails getLayoutDetailsForArchAndQuantType()
{

    using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB<TypeB, cutlassArch>;
    using LayoutB = typename CompileTraits::Layout;
    using MmaOperator = typename CompileTraits::Operator;
    LayoutDetails details = getLayoutDetails<LayoutB>()();
    // Flag whether the selected MMA operator is the interleaved
    // dequantize-B variant; downstream preprocessing permutes B's rows when
    // this is set (see preprocess_weights_for_mixed_gemm).
    details.uses_imma_ldsm = std::is_same<MmaOperator, cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA>::value;
    return details;
}
| |
|
| | template<typename cutlassArch> |
| | LayoutDetails getLayoutDetailsForArch(QuantType quant_type) |
| | { |
| | LayoutDetails details; |
| | if (quant_type == QuantType::INT8_WEIGHT_ONLY) { |
| | details = getLayoutDetailsForArchAndQuantType<cutlassArch, uint8_t>(); |
| | } |
| | else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) { |
| | details = getLayoutDetailsForArchAndQuantType<cutlassArch, cutlass::uint4b_t>(); |
| | } |
| | else { |
| | FT_CHECK_WITH_INFO(false, "Unsupported quantization type"); |
| | } |
| | return details; |
| | } |
| |
|
| | LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) |
| | { |
| | if (arch >= 70 && arch < 75) { |
| | return getLayoutDetailsForArch<cutlass::arch::Sm70>(quant_type); |
| | } |
| | else if (arch >= 75 && arch < 80) { |
| | return getLayoutDetailsForArch<cutlass::arch::Sm75>(quant_type); |
| | } |
| | else if (arch >= 80 && arch < 90) { |
| | return getLayoutDetailsForArch<cutlass::arch::Sm80>(quant_type); |
| | } |
| | else { |
| | FT_CHECK_WITH_INFO(false, "Unsupported Arch"); |
| | return LayoutDetails(); |
| | } |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
// Permutes the rows of the row-major quantized B matrix within each band of
// B_ROWS_PER_MMA rows. Rows are copied as whole uint32 words (elts_in_int32
// packed elements at a time); only the row order changes, never the contents
// of a word.
// NOTE(review): the permutation formula is assumed to match the Turing/Ampere
// device-side B load pattern -- confirm against the device kernel.
void permute_B_rows_for_mixed_gemm(int8_t *permuted_quantized_tensor,
                                   const int8_t *quantized_tensor,
                                   const std::vector<size_t> &shape,
                                   QuantType quant_type,
                                   const int64_t arch_version) {
  // shape = {num_rows, num_cols} in elements (not bytes).
  const size_t num_rows = shape[0];
  const size_t num_cols = shape[1];

  const int BITS_PER_ELT = get_bits_in_quant_type(quant_type);
  const int K = 16 / BITS_PER_ELT;        // elements per 16 bits
  const int ELTS_PER_REG = 32 / BITS_PER_ELT;

  const uint32_t *input_byte_ptr =
      reinterpret_cast<const uint32_t *>(quantized_tensor);
  uint32_t *output_byte_ptr =
      reinterpret_cast<uint32_t *>(permuted_quantized_tensor);

  int MMA_SHAPE_N = 8;
  int B_ROWS_PER_MMA = 8 * K;
  const int elts_in_int32 = 32 / BITS_PER_ELT;

  // Number of uint32 words per row.
  const int num_vec_cols = num_cols / elts_in_int32;

  FT_CHECK_WITH_INFO(arch_version >= 75,
                     "Unsupported Arch. Pre-volta not supported. Column "
                     "interleave not needed on Volta.");

  FT_CHECK_WITH_INFO(num_rows % B_ROWS_PER_MMA == 0,
                     fmtstr("Invalid shape for quantized tensor. Number of "
                            "rows of quantized matrix must be a multiple of %d",
                            B_ROWS_PER_MMA));

  FT_CHECK_WITH_INFO(
      num_cols % MMA_SHAPE_N == 0,
      fmtstr("Invalid shape for quantized tensor. On turing/Ampere, the number "
             "of cols must be a multiple of %d.",
             MMA_SHAPE_N));

  // For each band of B_ROWS_PER_MMA rows, write destination row `tile_row`
  // from permuted source row `tile_read_row` of the same band.
  for (size_t base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) {
    for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) {

      for (int write_col = 0; write_col < num_vec_cols; ++write_col) {
        const int write_row = base_row + tile_row;
        // Source row within the band: interleaves register-sized row groups.
        const int tile_read_row = 8 * (((tile_row % ELTS_PER_REG) / 2)) +
                                  tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
        const int read_row = base_row + tile_read_row;
        const int read_col = write_col;

        const int64_t read_offset = int64_t(read_row) * num_vec_cols + read_col;
        const int64_t write_offset =
            int64_t(write_row) * num_vec_cols + write_col;

        output_byte_ptr[write_offset] = input_byte_ptr[read_offset];
      }
    }
  }
}
| |
|
| | |
| | |
| | |
| | |
| | template <QuantType quant_type> |
| | void subbyte_transpose_impl(int8_t *transposed_quantized_tensor, |
| | const int8_t *quantized_tensor, |
| | const std::vector<size_t> &shape) { |
| | const int bits_per_elt = get_bits_in_quant_type(quant_type); |
| | const size_t num_rows = shape[0]; |
| | const size_t num_cols = shape[1]; |
| |
|
| | const size_t col_bytes = num_cols * bits_per_elt / 8; |
| | const size_t col_bytes_trans = num_rows * bits_per_elt / 8; |
| |
|
| | const uint8_t *input_byte_ptr = |
| | reinterpret_cast<const uint8_t *>(quantized_tensor); |
| | uint8_t *output_byte_ptr = |
| | reinterpret_cast<uint8_t *>(transposed_quantized_tensor); |
| |
|
| | static constexpr int ELTS_PER_BYTE = |
| | quant_type == QuantType::INT8_WEIGHT_ONLY ? 1 : 2; |
| |
|
| | static constexpr int M_TILE_L1 = 64; |
| | static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; |
| | uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; |
| |
|
| | static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); |
| |
|
| | |
| | |
| | |
| | |
| | FT_CHECK_WITH_INFO( |
| | !(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH), |
| | fmtstr("Number of bytes for rows and cols must be a multiple of %d. " |
| | "However, num_rows_bytes = %ld and num_col_bytes = %d.", |
| | VECTOR_WIDTH, col_bytes_trans, col_bytes)); |
| |
|
| | for (size_t row_tile_start = 0; row_tile_start < num_rows; |
| | row_tile_start += M_TILE_L1) { |
| | for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; |
| | col_tile_start_byte += N_TILE_L1) { |
| |
|
| | const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); |
| | const int col_limit = |
| | std::min(col_tile_start_byte + N_TILE_L1, col_bytes); |
| |
|
| | for (int ii = 0; ii < M_TILE_L1; ++ii) { |
| | const int row = row_tile_start + ii; |
| |
|
| | for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { |
| | const int col = col_tile_start_byte + jj; |
| |
|
| | const size_t logical_src_offset = row * col_bytes + col; |
| |
|
| | if (row < row_limit && col < col_limit) { |
| | for (int v = 0; v < VECTOR_WIDTH; ++v) { |
| | cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; |
| | } |
| | } |
| | } |
| | } |
| |
|
| | if (quant_type == QuantType::INT8_WEIGHT_ONLY) { |
| | for (int ii = 0; ii < M_TILE_L1; ++ii) { |
| | for (int jj = ii + 1; jj < N_TILE_L1; ++jj) { |
| | std::swap(cache_buf[ii][jj], cache_buf[jj][ii]); |
| | } |
| | } |
| | } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) { |
| |
|
| | for (int ii = 0; ii < M_TILE_L1; ++ii) { |
| | |
| | |
| | |
| | for (int jj = ii + 1; jj < M_TILE_L1; ++jj) { |
| | const int ii_byte = ii / ELTS_PER_BYTE; |
| | const int ii_bit_offset = ii % ELTS_PER_BYTE; |
| |
|
| | const int jj_byte = jj / ELTS_PER_BYTE; |
| | const int jj_bit_offset = jj % ELTS_PER_BYTE; |
| |
|
| | uint8_t src_elt = |
| | 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); |
| | uint8_t tgt_elt = |
| | 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); |
| |
|
| | cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); |
| | cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); |
| |
|
| | cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); |
| | cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); |
| | } |
| | } |
| | } else { |
| | FT_CHECK_WITH_INFO(false, "Unsupported quantization type."); |
| | } |
| |
|
| | const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; |
| | const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; |
| |
|
| | const int row_limit_trans = |
| | std::min(row_tile_start_trans + M_TILE_L1, num_cols); |
| | const int col_limit_trans = |
| | std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); |
| |
|
| | for (int ii = 0; ii < M_TILE_L1; ++ii) { |
| | const int row = row_tile_start_trans + ii; |
| | for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) { |
| | const int col = col_tile_start_byte_trans + jj; |
| |
|
| | const size_t logical_tgt_offset = row * col_bytes_trans + col; |
| |
|
| | if (row < row_limit_trans && col < col_limit_trans) { |
| | for (int v = 0; v < VECTOR_WIDTH; ++v) { |
| | output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; |
| | } |
| | } |
| | } |
| | } |
| | } |
| | } |
| | } |
| |
|
| | void subbyte_transpose(int8_t *transposed_quantized_tensor, |
| | const int8_t *quantized_tensor, |
| | const std::vector<size_t> &shape, QuantType quant_type) { |
| |
|
| | if (quant_type == QuantType::INT8_WEIGHT_ONLY) { |
| | subbyte_transpose_impl<QuantType::INT8_WEIGHT_ONLY>( |
| | transposed_quantized_tensor, quantized_tensor, shape); |
| | } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) { |
| | subbyte_transpose_impl<QuantType::PACKED_INT4_WEIGHT_ONLY>( |
| | transposed_quantized_tensor, quantized_tensor, shape); |
| | } else { |
| | FT_CHECK_WITH_INFO(false, "Invalid quant_tye"); |
| | } |
| | } |
| |
|
| | void add_bias_and_interleave_int8s_inplace(int8_t *int8_tensor, |
| | const size_t num_elts) { |
| | for (size_t ii = 0; ii < num_elts; ++ii) { |
| | int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | FT_CHECK_WITH_INFO(num_elts % 4 == 0, "Dimensions of int8 tensor must be a " |
| | "multiple of 4 for register relayout"); |
| | for (size_t base = 0; base < num_elts; base += 4) { |
| | std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); |
| | } |
| | } |
| |
|
// In-place transform of a packed-int4 tensor: adds a bias of 8 to every
// element (mapping [-8, 7] -> [0, 15]) and then interleaves the elements of
// each 32-bit register for the device kernel's unpacking pattern.
// `num_elts` counts int4 elements, so the tensor occupies num_elts / 2 bytes.
void add_bias_and_interleave_int4s_inplace(int8_t *packed_int4_tensor,
                                           const size_t num_elts) {
  const size_t num_bytes = num_elts / 2;

  // Step 1: bias both nibbles of every byte by +8.
  for (size_t ii = 0; ii < num_bytes; ++ii) {
    int8_t transformed_packed_int4s = 0;
    // (x << 4) >> 4 sign-extends the low nibble to a signed value in [-8, 7].
    int8_t transformed_first_elt =
        (int8_t(packed_int4_tensor[ii] << 4) >> 4) +
        8;
    // Arithmetic right shift extracts the signed high nibble.
    int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8;

    FT_CHECK_WITH_INFO(transformed_first_elt >= 0 &&
                           transformed_first_elt <= 15,
                       "Illegal result for int4 transform (first elt)");
    FT_CHECK_WITH_INFO(transformed_second_elt >= 0 &&
                           transformed_second_elt <= 15,
                       "Illegal result for int4 transform (second elt)");

    // Repack: low nibble = first element, high nibble = second element.
    transformed_packed_int4s |= transformed_first_elt;
    transformed_packed_int4s |= (transformed_second_elt << 4);
    packed_int4_tensor[ii] = transformed_packed_int4s;
  }

  // Step 2: permute the 8 nibbles of every 32-bit register from source order
  // [0,1,2,3,4,5,6,7] to [0,2,4,6,1,3,5,7] (even-indexed elements first).
  FT_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a "
                                         "multiple of 8 for register relayout");
  const size_t num_registers = num_bytes / 4;

  uint32_t *register_ptr = reinterpret_cast<uint32_t *>(packed_int4_tensor);
  for (size_t ii = 0; ii < num_registers; ++ii) {
    const uint32_t current_register = register_ptr[ii];
    uint32_t transformed_register = 0;

    for (int dest_idx = 0; dest_idx < 8; ++dest_idx) {
      // Destinations 0..3 take even source nibbles, 4..7 take odd ones.
      const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
      const int src_shift = 4 * src_idx;
      const int dest_shift = 4 * dest_idx;

      const uint32_t src_bits = (current_register >> src_shift) & 0xF;
      transformed_register |= (src_bits << dest_shift);
    }
    register_ptr[ii] = transformed_register;
  }
}
| |
|
| | void add_bias_and_interleave_quantized_tensor_inplace(int8_t *tensor, |
| | const size_t num_elts, |
| | QuantType quant_type) { |
| | if (quant_type == QuantType::INT8_WEIGHT_ONLY) { |
| | add_bias_and_interleave_int8s_inplace(tensor, num_elts); |
| | } else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) { |
| | add_bias_and_interleave_int4s_inplace(tensor, num_elts); |
| | } else { |
| | FT_CHECK_WITH_INFO(false, "Invalid quantization type for interleaving."); |
| | } |
| | } |
| |
|
| | void interleave_column_major_tensor(int8_t *interleaved_quantized_tensor, |
| | const int8_t *quantized_tensor, |
| | const std::vector<size_t> &shape, |
| | QuantType quant_type, |
| | LayoutDetails details) { |
| | |
| | FT_CHECK(quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY || |
| | quant_type == QuantType::INT8_WEIGHT_ONLY); |
| | FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D"); |
| |
|
| | const size_t num_rows = shape[0]; |
| | const size_t num_cols = shape[1]; |
| |
|
| | const int BITS_PER_ELT = get_bits_in_quant_type(quant_type); |
| | const int elts_in_int32 = 32 / BITS_PER_ELT; |
| |
|
| | const int rows_per_tile = details.rows_per_column_tile; |
| |
|
| | FT_CHECK_WITH_INFO(!(num_rows % elts_in_int32), |
| | fmtstr("The number of rows must be a multiple of %d but " |
| | "the number of rows is %d.", |
| | elts_in_int32, num_rows)); |
| |
|
| | FT_CHECK_WITH_INFO(!(num_cols % rows_per_tile), |
| | fmtstr("The number of columns must be a multiple of %d " |
| | "but the number of columns is %ld", |
| | rows_per_tile, num_cols)); |
| |
|
| | const uint32_t *input_byte_ptr = |
| | reinterpret_cast<const uint32_t *>(quantized_tensor); |
| | uint32_t *output_byte_ptr = |
| | reinterpret_cast<uint32_t *>(interleaved_quantized_tensor); |
| |
|
| | FT_CHECK_WITH_INFO(!(num_cols % rows_per_tile), |
| | fmtstr("The number of columns must be a multiple of %d " |
| | "but the number of columns is %d.", |
| | rows_per_tile, num_cols)); |
| |
|
| | const int num_vec_rows = num_rows / elts_in_int32; |
| | const int vec_rows_per_tile = rows_per_tile / elts_in_int32; |
| | const int interleave = details.columns_interleaved; |
| |
|
| | for (size_t read_col = 0; read_col < num_cols; ++read_col) { |
| | const auto write_col = read_col / interleave; |
| | for (int base_vec_row = 0; base_vec_row < num_vec_rows; |
| | base_vec_row += vec_rows_per_tile) { |
| | for (int vec_read_row = base_vec_row; |
| | vec_read_row < |
| | std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); |
| | ++vec_read_row) { |
| | const int64_t vec_write_row = |
| | interleave * base_vec_row + |
| | vec_rows_per_tile * (read_col % interleave) + |
| | vec_read_row % vec_rows_per_tile; |
| |
|
| | const int64_t read_offset = |
| | int64_t(read_col) * num_vec_rows + vec_read_row; |
| | const int64_t write_offset = |
| | int64_t(write_col) * num_vec_rows * interleave + vec_write_row; |
| | output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; |
| | } |
| | } |
| | } |
| | } |
| |
|
// Host-side weight preprocessing pipeline for mixed-type GEMM. Applies, in
// order and only as required by the target layout: B-row permutation (imma
// path), row-major -> column-major transpose, column interleaving, and the
// bias-add + register interleave. Ping-pongs between two scratch buffers and
// writes the final result to preprocessed_quantized_weight.
void preprocess_weights_for_mixed_gemm(int8_t *preprocessed_quantized_weight,
                                       const int8_t *row_major_quantized_weight,
                                       const std::vector<size_t> &shape,
                                       QuantType quant_type, int arch) {
  LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch);

  FT_CHECK_WITH_INFO(shape.size() == 2, "Shape must be 2-D");

  size_t num_elts = 1;
  for (const auto &dim : shape) {
    num_elts *= dim;
  }

  // Total packed size in bytes (int4 packs two elements per byte).
  const size_t num_bytes = num_elts * get_bits_in_quant_type(quant_type) / 8;

  std::vector<int8_t> src_buf(num_bytes);
  std::vector<int8_t> dst_buf(num_bytes);
  std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin());

  // After each stage, swap so src_buf always holds the latest result.
  if (details.uses_imma_ldsm) {
    permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch);
    src_buf.swap(dst_buf);
  }

  if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) {
    subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type);
    src_buf.swap(dst_buf);
  }

  if (details.columns_interleaved > 1) {
    interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details);
    src_buf.swap(dst_buf);
  }

  // Final stage operates in place on src_buf.
  add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type);
  std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight);
}
| |
|
| | void preprocess_weights(int8_t *preprocessed_quantized_weight, |
| | const int8_t *row_major_quantized_weight, size_t rows, |
| | size_t cols, bool is_int4, int arch) { |
| | QuantType qtype = is_int4 ? QuantType::PACKED_INT4_WEIGHT_ONLY |
| | : QuantType::INT8_WEIGHT_ONLY; |
| | preprocess_weights_for_mixed_gemm(preprocessed_quantized_weight, |
| | row_major_quantized_weight, {rows, cols}, |
| | qtype, arch); |
| | } |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
// Symmetric per-column quantization of a row-major weight matrix (optionally
// batched over experts for a 3-D shape), followed by the full device-layout
// preprocessing pipeline.
//
// @param processed_quantized_weight    output: packed weights after layout
//                                      preprocessing
// @param unprocessed_quantized_weight  optional output: raw quantized weights
//                                      before preprocessing; may be nullptr,
//                                      in which case a local scratch buffer
//                                      is used instead
// @param scale_ptr                     output: one scale per column per expert
// @param input_weight_ptr              input: row-major weights
// @param shape                         {rows, cols} or {experts, rows, cols}
// @param quant_type                    target quantized type
template<typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight,
                        int8_t* unprocessed_quantized_weight,
                        ComputeType* scale_ptr,
                        const WeightType* input_weight_ptr,
                        const std::vector<size_t>& shape,
                        QuantType quant_type)
{

    FT_CHECK_WITH_INFO(processed_quantized_weight, "Processed quantized tensor is NULL");
    FT_CHECK_WITH_INFO(scale_ptr, "Scale output pointer is NULL");
    FT_CHECK_WITH_INFO(input_weight_ptr, "Input weight pointer is NULL");

    FT_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
    const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
    const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
    const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];

    const int bits_in_type = get_bits_in_quant_type(quant_type);
    const int bytes_per_out_col = num_cols * bits_in_type / 8;

    // If the caller did not supply a buffer for the raw quantized weights,
    // quantize into a local scratch buffer instead.
    std::vector<int8_t> weight_buf;
    if (unprocessed_quantized_weight == nullptr) {
        weight_buf.resize(num_experts * num_rows * num_cols);
        unprocessed_quantized_weight = weight_buf.data();
    }

    const int input_mat_size = num_rows * num_cols;
    const int quantized_mat_size = num_rows * bytes_per_out_col;
    // Scale so the max-magnitude element maps to the edge of the signed range.
    const float quant_range_scale = 1.f / float(1 << (bits_in_type - 1));

    std::vector<float> per_col_max(num_cols);

    for (int expert = 0; expert < num_experts; ++expert) {
        const WeightType* current_weight = input_weight_ptr + expert * input_mat_size;
        int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size;

        // Pass 1: per-column absolute maximum.
        for (int jj = 0; jj < num_cols; ++jj) {
            per_col_max[jj] = 0.f;
        }

        for (int ii = 0; ii < num_rows; ++ii) {
            const WeightType* current_weight_row = current_weight + ii * num_cols;
            for (int jj = 0; jj < num_cols; ++jj) {
                per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj])));
            }
        }

        // Convert maxima to per-column scales and export them.
        // NOTE(review): an all-zero column leaves col_scale == 0, so the
        // divisions below produce non-finite values -- confirm inputs cannot
        // contain all-zero columns.
        ComputeType* current_scales = scale_ptr + expert * num_cols;
        for (int jj = 0; jj < num_cols; ++jj) {
            per_col_max[jj] *= quant_range_scale;
            current_scales[jj] = ComputeType(per_col_max[jj]);
        }

        // Pass 2: quantize row by row, iterating over OUTPUT bytes.
        for (int ii = 0; ii < num_rows; ++ii) {
            int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col;
            const WeightType* current_weight_row = current_weight + ii * num_cols;
            for (int jj = 0; jj < bytes_per_out_col; ++jj) {

                if (quant_type == QuantType::INT8_WEIGHT_ONLY) {
                    // One element per output byte: scale, round, clip to [-128, 127].
                    const float col_scale = per_col_max[jj];
                    const float weight_elt = float(current_weight_row[jj]);
                    const float scaled_weight = round(weight_elt / col_scale);
                    const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight)));
                    current_quantized_weight_row[jj] = clipped_weight;
                }
                else if (quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY) {
                    // Two elements per output byte: element 2*jj in the low
                    // nibble, 2*jj+1 in the high nibble, each clipped to [-8, 7].
                    int8_t packed_int4s = 0;
                    for (int packed_idx = 0; packed_idx < 2; ++packed_idx) {
                        const int input_idx = 2 * jj + packed_idx;
                        if (input_idx < num_cols) {
                            const float col_scale = per_col_max[input_idx];
                            const float weight_elt = float(current_weight_row[input_idx]);
                            const float scaled_weight = round(weight_elt / col_scale);
                            int int_weight = int(scaled_weight);
                            const int8_t clipped_weight = std::max(-8, std::min(7, int_weight));

                            packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx));
                        }
                    }
                    current_quantized_weight_row[jj] = packed_int4s;
                }
                else {
                    FT_CHECK_WITH_INFO(false, "Unsupported quantization type");
                }
            }
        }
    }
    // Apply the layout transforms for the device we are running on.
    // NOTE(review): preprocess_weights_for_mixed_gemm asserts a 2-D shape, so
    // the 3-D (multi-expert) path would trip that check -- verify intended.
    const int arch = fastertransformer::getSMVersion();
    preprocess_weights_for_mixed_gemm(processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, arch);
}
| |
|
// Explicit instantiations (six-argument form) for the scale/weight type
// combinations used by the rest of the library.
template void
symmetric_quantize<half, float>(int8_t*, int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);

template void
symmetric_quantize<half, half>(int8_t*, int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);
| |
|
| |
|
| | template<typename ComputeType, typename WeightType> |
| | void symmetric_quantize(int8_t* processed_quantized_weight, |
| | ComputeType* scale_ptr, |
| | const WeightType* input_weight_ptr, |
| | const std::vector<size_t>& shape, |
| | QuantType quant_type) |
| | { |
| | symmetric_quantize(processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type); |
| | } |
| |
|
// Explicit instantiations (five-argument form) for the scale/weight type
// combinations used by the rest of the library.
template void symmetric_quantize<float, float>(int8_t*, float*, const float*, const std::vector<size_t>&, QuantType);

template void symmetric_quantize<half, float>(int8_t*, half*, const float*, const std::vector<size_t>&, QuantType);

template void symmetric_quantize<half, half>(int8_t*, half*, const half*, const std::vector<size_t>&, QuantType);
| |
|
| | } |
| |
|